diff --git a/CMakeLists.txt b/CMakeLists.txt index 987e4ae709..c4da105cac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,10 @@ else() set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") endif() +if (ENABLE_PYTHON) + add_compile_definitions(ENABLE_PYTHON) +endif() + set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D'_LIBCPP_EXTERN_TEMPLATE(...)=' -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC") diff --git a/RELEASE.md b/RELEASE.md index 4b829152a2..def72cbb20 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -70,6 +70,22 @@ Alexey Shevlyakov, avakh, baihuawei, BowenK, buxue, caifubi, caojian05, Cathy Wo Contributions of any kind are welcome! +# Release 0.3.1-alpha + +## Major Features and Improvements + +### Ascend 910 Training and Inference Framework +* Frontend and User Interface + * Independent model init interface. +* Data processing, augmentation, and save format + * Support sample padding for minddataset. + +## Bugfixes +* Python API + * Fix bugs in the lars optimizer([!1894](https://gitee.com/mindspore/mindspore/pulls/1894)) +* Data processing + * Fix accuracy problem of RandomCropDecodeResize ([!2340](https://gitee.com/mindspore/mindspore/pulls/2340)) + # Release 0.3.0-alpha ## Major Features and Improvements diff --git a/build.sh b/build.sh index 059478b9af..cfa657ff3e 100755 --- a/build.sh +++ b/build.sh @@ -24,8 +24,8 @@ usage() { echo "Usage:" echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" - echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" - echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K] [-B on|off] [-E]" + echo " [-a on|off] [-Q on|off] [-S on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" + echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K] [-B on|off] [-E] [-l on|off]" echo "" echo "Options:" echo " -d Debug mode" @@ -48,6 +48,7 @@ usage() echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" echo " -Q Enable dump memory, default off" echo " -D Enable dumping of function graph ir, default on" + echo " -S Enable async data dump, default off" echo " -z Compile dataset & mindrecord, default on" echo " -M Enable MPI and NCCL for GPU training, gpu default on" echo " -V Specify the minimum required cuda version, default CUDA 10.1" @@ -56,6 +57,7 @@ usage() echo " -s Enable serving module, default off" echo " -B Enable debugger, default off" echo " -E Enable IBVERBS for parameter server, default off" + echo " -l Compile with python dependency, default on" } # check value of input is 'on' or 'off' @@ -87,6 +89,7 @@ checkopts() ENABLE_TIMELINE="off" ENABLE_DUMP2PROTO="on" ENABLE_DUMPE2E="off" + ENABLE_DATA_DUMP="off" ENABLE_DUMP_IR="on" COMPILE_MINDDATA="on" ENABLE_MPI="off" @@ -98,9 +101,10 @@ checkopts() ENABLE_SERVING="off" ENABLE_DEBUGGER="off" ENABLE_IBVERBS="off" + ENABLE_PYTHON="on" # Process the options - while getopts 'drvj:c:t:hsb:a:g:p:ie:m:I:LRP:Q:D:zM:V:K:sB:E' opt + while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:Q:S:D:zM:V:K:sB:E' opt do OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in @@ -151,6 +155,10 @@ checkopts() check_on_off $OPTARG p ENABLE_PROFILE="$OPTARG" ;; + l) + check_on_off $OPTARG l + ENABLE_PYTHON="$OPTARG" + ;; i) INC_BUILD="on" ;; @@ -212,6 +220,11 @@ checkopts() ENABLE_DUMPE2E="$OPTARG" echo "enable dump end to end" ;; + S) + check_on_off $OPTARG S + ENABLE_DATA_DUMP="$OPTARG" + echo "enable data dump" + ;; D) check_on_off $OPTARG D ENABLE_DUMP_IR="$OPTARG" @@ -315,7 +328,11 @@ build_mindspore() if [[ "X$ENABLE_DUMPE2E" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_E2E=ON" fi + if [[ "X$ENABLE_DATA_DUMP" = "Xon" ]]; then + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DATA_DUMP=ON" + fi CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_DUMP_IR=${ENABLE_DUMP_IR}" + CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_PYTHON=${ENABLE_PYTHON}" if [[ "X$ENABLE_MPI" = "Xon" ]]; then CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_MPI=ON" fi diff --git a/cmake/external_libs/icu4c.cmake b/cmake/external_libs/icu4c.cmake index 7d13e4fd2a..af69328e55 100644 --- a/cmake/external_libs/icu4c.cmake +++ b/cmake/external_libs/icu4c.cmake @@ -9,11 +9,11 @@ else() LIBS ${LIB_ICU_COMMON} ${LIB_ICU_DATA} ${LIB_ICU_I18N} URL https://github.com/unicode-org/icu/archive/release-67-1.tar.gz MD5 0c2662a2b0bc80b0eb56495205247c8f - CONFIGURE_COMMAND ./icu4c/source/runConfigureICU Linux --enable-rpath --disable-tests --disable-samples --disable-icuio --disable-extras ICU_DATA_FILTER_FILE=${CMAKE_SOURCE_DIR}/third_party/icu4c/filter.json + CONFIGURE_COMMAND ${CMAKE_SOURCE_DIR}/scripts/build_icu4c.sh ) include_directories(${icu4c_INC}) add_library(mindspore::icuuc ALIAS icu4c::${LIB_ICU_COMMON}) add_library(mindspore::icudata ALIAS icu4c::${LIB_ICU_DATA}) add_library(mindspore::icui18n ALIAS icu4c::${LIB_ICU_I18N}) add_definitions(-D ENABLE_ICU4C) -endif() \ No newline at end of file +endif() diff --git a/cmake/mind_expression.cmake b/cmake/mind_expression.cmake index 63a65cd533..9002c23976 100644 --- a/cmake/mind_expression.cmake +++ b/cmake/mind_expression.cmake @@ -15,7 +15,7 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake) include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) -if (ENABLE_DEBUGGER) +if (ENABLE_DEBUGGER OR ENABLE_SERVING) # build dependencies of gRPC include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake) @@ -30,7 +30,7 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake) if(USE_GLOG) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake) endif() -if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows") +if (NOT ${CMAKE_SYSTEM_NAME} MATCHES "Windows" AND NOT ENABLE_GE) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/zeromq.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pslite.cmake) endif() diff --git a/cmake/options.cmake b/cmake/options.cmake index 18db942d68..2470c25a90 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -19,6 +19,7 @@ option(ENABLE_MPI "enable mpi" OFF) option(ENABLE_AKG "enable akg" OFF) option(ENABLE_DEBUGGER "enable debugger" OFF) option(ENABLE_IBVERBS "enable IBVERBS for parameter server" OFF) +option(ENABLE_PYTHON "Enable python" ON) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if (WIN32) @@ -115,6 +116,10 @@ if(ENABLE_DUMP_E2E) add_compile_definitions(ENABLE_DUMP_E2E) endif() +if(ENABLE_DATA_DUMP) + add_compile_definitions(ENABLE_DATA_DUMP) +endif() + if(ENABLE_DEBUGGER) add_compile_definitions(ENABLE_DEBUGGER) endif() diff --git a/cmake/package.cmake b/cmake/package.cmake index 2fde01af4f..7b3c2f7bb2 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -213,7 +213,6 @@ install( ${CMAKE_SOURCE_DIR}/mindspore/parallel ${CMAKE_SOURCE_DIR}/mindspore/mindrecord ${CMAKE_SOURCE_DIR}/mindspore/train - ${CMAKE_SOURCE_DIR}/mindspore/model_zoo ${CMAKE_SOURCE_DIR}/mindspore/common ${CMAKE_SOURCE_DIR}/mindspore/ops ${CMAKE_SOURCE_DIR}/mindspore/communication @@ -261,3 +260,17 @@ if (EXISTS ${CMAKE_SOURCE_DIR}/mindspore/dataset) COMPONENT mindspore ) endif () + +if (ENABLE_SERVING) + install( + TARGETS ms_serving + DESTINATION ${INSTALL_BASE_DIR} + COMPONENT mindspore + ) + + install( + TARGETS inference + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore + ) +endif () diff --git a/config/data_dump.json b/config/data_dump.json new file mode 100644 index 0000000000..fc08f78590 --- /dev/null +++ b/config/data_dump.json @@ -0,0 +1,15 @@ +{ + "DumpSettings": { + "net_name": "ResNet50", + "mode": 1, + "iteration": 0, + "kernels": ["Default/Conv2D-op2", "Default/TensorAdd-op10"] + }, + + "DumpSettingsSpec": { + "net_name": "net name eg:ResNet50", + "mode": "0: dump all kernels, 1: dump kernels in kernels list", + "iteration": "specified iteration ", + "kernels": "op's full scope name which need to be dump" + } +} \ No newline at end of file diff --git a/config/op_info.config b/config/op_info.config new file mode 100644 index 0000000000..6ab9eba875 --- /dev/null +++ b/config/op_info.config @@ -0,0 +1,383 @@ +{"op_name": "InitData", "inputs": [], "outputs": [], "attr": [{"name": "queue_name", "type": "str"}], "fusion_type": "OPAQUE", "dtype_format": [], "imply_type": "AiCPU"} +{"op_name": "DropoutGenMask", "inputs": [{"index": 0, "name": "x1", "param_type": "required"}, {"index": 1, "name": "x2", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "Seed0", "type": "int"}, {"name": "Seed1", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "NCHW"], ["float16", "NCHW"], ["uint8", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "GetNext", "inputs": [], "outputs": [{"index": 0, "name": "y", "param_type": "dynamic"}], "attr": [{"name": "shared_name", "type": "str"}], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"]], [["int8", "DefaultFormat"]], [["int16", "DefaultFormat"]], [["int32", "DefaultFormat"]], [["int64", "DefaultFormat"]], [["float16", "DefaultFormat"]], [["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"]], [["float32", "DefaultFormat"]]], "imply_type": "AiCPU"} +{"op_name": "Print", "inputs": [{"index": 0, "name": "x", "param_type": "dynamic"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AiCPU"} +{"op_name": "TopK", "inputs": [{"index": 0, "name": "intput", "param_type": "required"}, {"index": 1, "name": "k", "param_type": "required"}], "outputs": [{"index": 0, "name": "values", "param_type": "required"}, {"index": 1, "name": "indices", "param_type": "required"}], "attr": [{"name": "sorted", "type": "bool"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "AiCPU"} +{"op_name": "IsFinite", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int16", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int64", "DefaultFormat"], ["bool", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["bool", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["bool", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["bool", "DefaultFormat"]], [["float16", "DefaultFormat"], ["bool", "DefaultFormat"]], [["float32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["float64", "DefaultFormat"], ["bool", "DefaultFormat"]], [["bool", "NCHW"], ["bool", "NCHW"]], [["int8", "NCHW"], ["bool", "NCHW"]], [["int16", "NCHW"], ["bool", "NCHW"]], [["int32", "NCHW"], ["bool", "NCHW"]], [["int64", "NCHW"], ["bool", "NCHW"]], [["uint8", "NCHW"], ["bool", "NCHW"]], [["uint16", "NCHW"], ["bool", "NCHW"]], [["uint32", "NCHW"], ["bool", "NCHW"]], [["uint64", "NCHW"], ["bool", "NCHW"]], [["float16", "NCHW"], ["bool", "NCHW"]], [["float32", "NCHW"], ["bool", "NCHW"]], [["float64", "NCHW"], ["bool", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "Reshape", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float64", "DefaultFormat"], ["float64", "DefaultFormat"]], [["bool", "NCHW"], ["bool", "NCHW"]], [["int8", "NCHW"], ["int8", "NCHW"]], [["int16", "NCHW"], ["int16", "NCHW"]], [["int32", "NCHW"], ["int32", "NCHW"]], [["int64", "NCHW"], ["int64", "NCHW"]], [["uint8", "NCHW"], ["uint8", "NCHW"]], [["uint16", "NCHW"], ["uint16", "NCHW"]], [["uint32", "NCHW"], ["uint32", "NCHW"]], [["uint64", "NCHW"], ["uint64", "NCHW"]], [["float16", "NCHW"], ["float16", "NCHW"]], [["float32", "NCHW"], ["float32", "NCHW"]], [["float64", "NCHW"], ["float64", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "Flatten", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "NCHW"], ["int8", "NCHW"]], [["int16", "NCHW"], ["int16", "NCHW"]], [["int32", "NCHW"], ["int32", "NCHW"]], [["int64", "NCHW"], ["int64", "NCHW"]], [["uint8", "NCHW"], ["uint8", "NCHW"]], [["uint16", "NCHW"], ["uint16", "NCHW"]], [["uint32", "NCHW"], ["uint32", "NCHW"]], [["uint64", "NCHW"], ["uint64", "NCHW"]], [["float16", "NCHW"], ["float16", "NCHW"]], [["float32", "NCHW"], ["float32", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "Squeeze", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float64", "DefaultFormat"], ["float64", "DefaultFormat"]], [["bool", "NCHW"], ["bool", "NCHW"]], [["int8", "NCHW"], ["int8", "NCHW"]], [["int16", "NCHW"], ["int16", "NCHW"]], [["int32", "NCHW"], ["int32", "NCHW"]], [["int64", "NCHW"], ["int64", "NCHW"]], [["uint8", "NCHW"], ["uint8", "NCHW"]], [["uint16", "NCHW"], ["uint16", "NCHW"]], [["uint32", "NCHW"], ["uint32", "NCHW"]], [["uint64", "NCHW"], ["uint64", "NCHW"]], [["float16", "NCHW"], ["float16", "NCHW"]], [["float32", "NCHW"], ["float32", "NCHW"]], [["float64", "NCHW"], ["float64", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "ExpandDims", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float64", "DefaultFormat"], ["float64", "DefaultFormat"]], [["bool", "NCHW"], ["bool", "NCHW"]], [["int8", "NCHW"], ["int8", "NCHW"]], [["int16", "NCHW"], ["int16", "NCHW"]], [["int32", "NCHW"], ["int32", "NCHW"]], [["int64", "NCHW"], ["int64", "NCHW"]], [["uint8", "NCHW"], ["uint8", "NCHW"]], [["uint16", "NCHW"], ["uint16", "NCHW"]], [["uint32", "NCHW"], ["uint32", "NCHW"]], [["uint64", "NCHW"], ["uint64", "NCHW"]], [["float16", "NCHW"], ["float16", "NCHW"]], [["float32", "NCHW"], ["float32", "NCHW"]], [["float64", "NCHW"], ["float64", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "RandomChoiceWithMask", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}, {"index": 1, "name": "mask", "param_type": "required"}], "attr": [{"name": "count", "type": "int"}, {"name": "seed", "type": "int"}, {"name": "seed2", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "NCHW"], ["int32", "NCHW"], ["bool", "NCHW"]], [["bool", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AiCPU"} +{"op_name": "Pack", "inputs": [{"index": 0, "name": "x", "param_type": "dynamic"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "axis", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float64", "DefaultFormat"], ["float64", "DefaultFormat"]], [["bool", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AiCPU"} +{"op_name": "Normal", "inputs": [{"index": 0, "name": "shape", "param_type": "required"}, {"index": 1, "name": "mean", "param_type": "required"}, {"index": 2, "name": "stddev", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "seed", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "CTCLoss", "inputs": [{"index": 0, "name": "inputs", "param_type": "required"}, {"index": 1, "name": "labels_indices", "param_type": "required"}, {"index": 2, "name": "labels_values", "param_type": "required"}, {"index": 3, "name": "sequence_length", "param_type": "required"}], "outputs": [{"index": 0, "name": "loss", "param_type": "required"}, {"index": 1, "name": "gradient", "param_type": "required"}], "attr": [{"name": "preprocess_collapse_repeated", "type": "bool"}, {"name": "ctc_merge_repeated", "type": "bool"}, {"name": "ignore_longer_outputs_than_inputs", "type": "bool"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float64", "DefaultFormat"], ["float64", "DefaultFormat"]], [["float32", "NCHW"], ["int64", "NCHW"], ["int32", "NCHW"], ["int32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float64", "NCHW"], ["int64", "NCHW"], ["int32", "NCHW"], ["int32", "NCHW"], ["float64", "NCHW"], ["float64", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "ReverseSequence", "inputs": [{"index": 0, "name": "x", "param_type": "required"}, {"index": 1, "name": "seq_lengths", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "seq_dim", "type": "int"}, {"name": "batch_dim", "type": "int"}], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int32", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int32", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float64", "DefaultFormat"], ["int32", "DefaultFormat"], ["float64", "DefaultFormat"]], [["bool", "NCHW"], ["int32", "NCHW"], ["bool", "NCHW"]], [["int8", "NCHW"], ["int32", "NCHW"], ["int8", "NCHW"]], [["int16", "NCHW"], ["int32", "NCHW"], ["int16", "NCHW"]], [["int32", "NCHW"], ["int32", "NCHW"], ["int32", "NCHW"]], [["int64", "NCHW"], ["int32", "NCHW"], ["int64", "NCHW"]], [["uint8", "NCHW"], ["int32", "NCHW"], ["uint8", "NCHW"]], [["uint16", "NCHW"], ["int32", "NCHW"], ["uint16", "NCHW"]], [["uint32", "NCHW"], ["int32", "NCHW"], ["uint32", "NCHW"]], [["uint64", "NCHW"], ["int32", "NCHW"], ["uint64", "NCHW"]], [["float16", "NCHW"], ["int32", "NCHW"], ["float16", "NCHW"]], [["float32", "NCHW"], ["int32", "NCHW"], ["float32", "NCHW"]], [["float64", "NCHW"], ["int32", "NCHW"], ["float64", "NCHW"]], [["bool", "DefaultFormat"], ["int64", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int64", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float64", "DefaultFormat"], ["int64", "DefaultFormat"], ["float64", "DefaultFormat"]], [["bool", "NCHW"], ["int64", "NCHW"], ["bool", "NCHW"]], [["int8", "NCHW"], ["int64", "NCHW"], ["int8", "NCHW"]], [["int16", "NCHW"], ["int64", "NCHW"], ["int16", "NCHW"]], [["int32", "NCHW"], ["int64", "NCHW"], ["int32", "NCHW"]], [["int64", "NCHW"], ["int64", "NCHW"], ["int64", "NCHW"]], [["uint8", "NCHW"], ["int64", "NCHW"], ["uint8", "NCHW"]], [["uint16", "NCHW"], ["int64", "NCHW"], ["uint16", "NCHW"]], [["uint32", "NCHW"], ["int64", "NCHW"], ["uint32", "NCHW"]], [["uint64", "NCHW"], ["int64", "NCHW"], ["uint64", "NCHW"]], [["float16", "NCHW"], ["int64", "NCHW"], ["float16", "NCHW"]], [["float32", "NCHW"], ["int64", "NCHW"], ["float32", "NCHW"]], [["float64", "NCHW"], ["int64", "NCHW"], ["float64", "NCHW"]]], "imply_type": "AiCPU"} +{"op_name": "CropAndResize", "inputs": [{"index": 0, "name": "image", "param_type": "required"}, {"index": 1, "name": "boxes", "param_type": "required"}, {"index": 2, "name": "box_index", "param_type": "required"}, {"index": 3, "name": "crop_size", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [{"name": "method", "type": "str"}, {"name": "extrapolation_value", "type": "float"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float64", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]], [["int16", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]], [["int32", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]], [["int64", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]], [["float16", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]], [["float32", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]], [["float64", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]], [["uint8", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]], [["uint16", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"]]], "imply_type": "AiCPU"} +{"op_name": "EndOfSequence", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "param_type": "required"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "AiCPU"} +{"op_name": "Abs", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "AddN", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "dynamic", "name": "inputs"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "TensorAdd", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0", "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}, {"index": 1, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0", "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "y"}], "outputs": [{"index": 0, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0", "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "ApplyMomentum", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "use_nesterov", "param_type": "optional", "type": "bool"}, {"name": "gradient_scale", "param_type": "optional", "type": "float"}], "inputs": [{"index": 0, "dtype": ["float32", "float32", "float32"], "format": ["DefaultFormat", "NC1HWC0", "FracZ"], "name": "variable"}, {"index": 1, "dtype": ["float32", "float32", "float32"], "format": ["DefaultFormat", "NC1HWC0", "FracZ"], "name": "accumulation"}, {"index": 2, "dtype": ["float32", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "learning_rate"}, {"index": 3, "dtype": ["float32", "float32", "float32"], "format": ["DefaultFormat", "NC1HWC0", "FracZ"], "name": "gradient"}, {"index": 4, "dtype": ["float32", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "momentum"}], "outputs": [{"index": 0, "dtype": ["float32", "float32", "float32"], "format": ["DefaultFormat", "NC1HWC0", "FracZ"], "name": "output"}]} +{"op_name": "Assign", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ"], "name": "ref"}, {"index": 1, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ"], "name": "value"}], "outputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ"], "name": "output"}]} +{"op_name": "InplaceAssign", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "fake_output", "param_type": "optional", "type": "bool"}], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ"], "name": "x"}, {"index": 1, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ"], "name": "y"}, {"index": 2, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ"], "name": "z"}], "outputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ"], "name": "output"}]} +{"op_name": "AssignAdd", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "ref"}, {"index": 1, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "value"}], "outputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "BiasAddGrad", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "data_format", "param_type": "optional", "type": "listStr"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["NHWC", "NHWC", "NC1HWC0", "NC1HWC0", "DefaultFormat", "DefaultFormat"], "name": "dout"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "DefaultFormat"], "name": "output"}]} +{"op_name": "BiasAdd", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "data_format", "param_type": "optional", "type": "listStr"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["NHWC", "NHWC", "NC1HWC0", "NC1HWC0", "DefaultFormat", "DefaultFormat"], "name": "x"}, {"index": 1, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["NHWC", "NHWC", "NC1HWC0", "NC1HWC0", "DefaultFormat", "DefaultFormat"], "name": "b"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "DefaultFormat"], "name": "output"}]} +{"op_name": "Cast", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "dst_type", "param_type": "required", "type": "str"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "bool", "bool", "float16", "float32", "int32", "int32", "bool", "float16", "float32", "bool", "bool", "float16", "float32", "bool", "bool"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float32", "float16", "int32", "float16", "int32", "int32", "float16", "float32", "float32", "float32", "float16", "int32", "float32", "float32", "float16", "int32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "ClearZero", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "pad_mod", "param_type": "optional", "type": "string"}, {"name": "window", "param_type": "optional", "type": "int"}, {"name": "pad", "param_type": "optional", "type": "int"}, {"name": "stride", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": []} +{"op_name": "ConvBN1", "imply_type": "AutoDiff", "fusion_type": "CONVLUTION", "attr": [{"name": "x_shape", "param_type": "required", "type": "listInt"}, {"name": "w_shape", "param_type": "required", "type": "listInt"}, {"name": "pad_list", "param_type": "required", "type": "listInt"}, {"name": "stride", "param_type": "optional", "type": "int"}, {"name": "dilation", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float16"], "format": ["FracZ"], "name": "w"}], "outputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "conv_res_16"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "var_part"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "mean"}]} +{"op_name": "Conv2DBackpropFilter", "imply_type": "AutoDiff", "fusion_type": "CONVLUTION", "attr": [{"name": "input_shape", "param_type": "required", "type": "listInt"}, {"name": "filter_sizes", "param_type": "required", "type": "listInt"}, {"name": "stride", "param_type": "optional", "type": "int"}, {"name": "pad_list", "param_type": "required", "type": "listInt"}, {"name": "dilation", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "out_backprop"}, {"index": 1, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "input"}], "outputs": [{"index": 0, "dtype": ["float32"], "format": ["FracZ"], "name": "output"}]} +{"op_name": "Conv2DBackpropInput", "imply_type": "AutoDiff", "fusion_type": "CONVLUTION", "attr": [{"name": "input_sizes", "param_type": "required", "type": "listInt"}, {"name": "filter_shape", "param_type": "required", "type": "listInt"}, {"name": "stride", "param_type": "optional", "type": "int"}, {"name": "pad_list", "param_type": "required", "type": "listInt"}, {"name": "dilation", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "out_backprop"}, {"index": 1, "dtype": ["float16"], "format": ["FracZ"], "name": "filter"}], "outputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "output"}]} +{"op_name": "Conv2D", "imply_type": "AutoDiff", "fusion_type": "CONVLUTION", "attr": [{"name": "x_shape", "param_type": "required", "type": "listInt"}, {"name": "w_shape", "param_type": "required", "type": "listInt"}, {"name": "pad_list", "param_type": "required", "type": "listInt"}, {"name": "stride", "param_type": "optional", "type": "int"}, {"name": "dilation", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float16"], "format": ["FracZ"], "name": "w"}], "outputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "output"}]} +{"op_name": "Div", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "y"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "EqualCount", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [], "inputs": [{"index": 0, "dtype": ["int32"], "format": ["DefaultFormat"], "name": "x"}, {"index": 1, "dtype": ["int32"], "format": ["DefaultFormat"], "name": "y"}], "outputs": [{"index": 0, "dtype": ["int32"], "format": ["DefaultFormat"], "name": "output"}]} +{"op_name": "Exp", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "Five2Four", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "shape4d", "param_type": "required", "type": "listInt"}, {"name": "dstType", "param_type": "required", "type": "str"}, {"name": "output_format", "param_type": "required", "type": "str"}], "inputs": [{"index": 0, "dtype": ["float16", "float16", "float16", "float32", "float16", "float32"], "format": ["NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float16", "float32", "float32", "float32", "float32"], "format": ["DefaultFormat", "NHWC", "DefaultFormat", "DefaultFormat", "NHWC", "NHWC"], "name": "output"}]} +{"op_name": "Four2Five", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "data_format", "param_type": "optional", "type": "listStr"}, {"name": "dst_type", "param_type": "required", "type": "str"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float32", "float16", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NHWC", "NHWC", "NHWC"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float16", "float32", "float16", "float16", "float32"], "format": ["NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "FusedBatchNormGrad", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "data_format", "param_type": "optional", "type": "listStr"}], "inputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "dy"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "x"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "scale"}, {"index": 3, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "save_mean"}, {"index": 4, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "save_inv_variance"}], "outputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "dx"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "bn_scale"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "bn_bias"}]} +{"op_name": "FusedBatchNormInfer", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "momentum", "param_type": "optional", "type": "float"}, {"name": "epsilon", "param_type": "optional", "type": "float"}, {"name": "data_format", "param_type": "optional", "type": "listStr"}], "inputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "scale"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "b"}, {"index": 3, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "mean"}, {"index": 4, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "variance"}], "outputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "y"}]} +{"op_name": "FusedBatchNorm", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "momentum", "param_type": "optional", "type": "float"}, {"name": "epsilon", "param_type": "optional", "type": "float"}, {"name": "data_format", "param_type": "optional", "type": "listStr"}], "inputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "scale"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "b"}, {"index": 3, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "mean"}, {"index": 4, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "variance"}], "outputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "y"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "running_mean"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "running_variance"}, {"index": 3, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "save_mean"}, {"index": 4, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "save_inv_variance"}]} +{"op_name": "BNGrad1", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "dy"}, {"index": 1, "dtype": ["float16", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "data"}, {"index": 2, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "mean"}], "outputs": [{"index": 0, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "output"}, {"index": 1, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "output"}, {"index": 2, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "FusedBN1", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "data"}], "outputs": [{"index": 0, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "output"}, {"index": 1, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "BNGrad2", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "eps", "param_type": "optional", "type": "float"}, {"name": "data_shape", "param_type": "optional", "type": "listInt"}], "inputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "dgamma_red_hw"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "dbeta_red_hw"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "variance"}, {"index": 3, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "gamma"}], "outputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "output"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "output"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "output"}, {"index": 3, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "output"}, {"index": 4, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "output"}]} +{"op_name": "FusedBN2", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "momentum", "param_type": "optional", "type": "float"}], "inputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "mean"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "var_part"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "running_mean"}, {"index": 3, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "running_var"}], "outputs": [{"index": 0, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "output"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "output"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "output"}]} +{"op_name": "BNGrad3", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "dy"}, {"index": 1, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "rs"}, {"index": 2, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "dgamma_dx"}, {"index": 3, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "dbeta_dx"}, {"index": 4, "dtype": ["float32", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "data_minus_mean"}], "outputs": [{"index": 0, "dtype": ["float16", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "FusedBN3", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "eps", "param_type": "optional", "type": "float"}], "inputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "data"}, {"index": 1, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "mean"}, {"index": 2, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "variance"}, {"index": 3, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "gamma"}, {"index": 4, "dtype": ["float32"], "format": ["NC1HWC0"], "name": "beta"}], "outputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "output"}]} +{"op_name": "GatherV2", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "axis", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "params"}, {"index": 1, "dtype": ["int32", "int32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "indices"}], "outputs": [{"index": 0, "dtype": ["int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "output"}]} +{"op_name": "Less", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float16"], "format": ["DefaultFormat", "NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float16", "float16"], "format": ["DefaultFormat", "NC1HWC0"], "name": "y"}], "outputs": [{"index": 0, "dtype": ["bool", "bool"], "format": ["DefaultFormat", "NC1HWC0"], "name": "output"}]} +{"op_name": "Log", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "MatMul", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "transpose_a", "param_type": "optional", "type": "bool"}, {"name": "transpose_b", "param_type": "optional", "type": "bool"}], "inputs": [{"index": 0, "dtype": ["float16", "float32"], "format": ["DefaultFormat", "DefaultFormat"], "name": "x1"}, {"index": 1, "dtype": ["float16", "float32"], "format": ["DefaultFormat", "DefaultFormat"], "name": "x2"}], "outputs": [{"index": 0, "dtype": ["float16", "float32"], "format": ["DefaultFormat", "DefaultFormat"], "name": "output"}]} +{"op_name": "BatchMatMul", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "transpose_a", "param_type": "optional", "type": "bool"}, {"name": "transpose_b", "param_type": "optional", "type": "bool"}], "inputs": [{"index": 0, "dtype": ["float16"], "format": ["FRACTAL_NZ"], "name": "x1"}, {"index": 1, "dtype": ["float16"], "format": ["FRACTAL_NZ"], "name": "x2"}], "outputs": [{"index": 0, "dtype": ["float16"], "format": ["FRACTAL_NZ"], "name": "output"}]} +{"op_name": "MaxPoolGradWithArgmax", "imply_type": "AutoDiff", "fusion_type": "CONVLUTION", "attr": [{"name": "pad_mode", "param_type": "optional", "type": "str"}, {"name": "window", "param_type": "optional", "type": "int"}, {"name": "pad", "param_type": "optional", "type": "int"}, {"name": "stride", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16", "float16"], "format": ["NC1HWC0", "NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float16", "float32"], "format": ["DefaultFormat", "DefaultFormat"], "name": "argmax"}, {"index": 2, "dtype": ["float16", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "grad"}], "outputs": [{"index": 0, "dtype": ["float16", "float32"], "format": ["NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "MaxPoolWithArgmax", "imply_type": "AutoDiff", "fusion_type": "CONVLUTION", "attr": [{"name": "pad_mode", "param_type": "optional", "type": "str"}, {"name": "window", "param_type": "optional", "type": "int"}, {"name": "pad", "param_type": "optional", "type": "int"}, {"name": "stride", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16"], "format": ["NC1HWC0"], "name": "output"}, {"index": 1, "dtype": ["float16"], "format": ["DefaultFormat"], "name": "argmax"}]} +{"op_name": "Max", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "axis", "param_type": "required", "type": "listInt"}, {"name": "keep_dims", "param_type": "required", "type": "bool"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "Maximum", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "param_type": "required", "name": "x"}, {"index": 1, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "param_type": "required", "name": "y"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "SimpleMeanGrad", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "input_shape", "param_type": "required", "type": "listInt"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "HEAD"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "SimpleMean", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "Minimum", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}, {"index": 1, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "y"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "Mul", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "x_shape", "param_type": "required", "type": "listInt"}, {"name": "y_shape", "param_type": "required", "type": "listInt"}, {"name": "data_format", "param_type": "required", "type": "listStr"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32", "float16", "float32"], "format": ["FracZ", "FracZ", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}, {"index": 1, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32", "float16", "float32"], "format": ["FracZ", "FracZ", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "y"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32", "float16", "float32"], "format": ["FracZ", "FracZ", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "Neg", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "OneHot", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "depth", "param_type": "required", "type": "int"}, {"name": "axis", "param_type": "required", "type": "int"}], "inputs": [{"index": 0, "dtype": ["int32", "int32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "indices"}, {"index": 1, "dtype": ["int32", "float32", "float16"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "on_value"}, {"index": 2, "dtype": ["int32", "float32", "float16"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "off_value"}], "outputs": [{"index": 0, "dtype": ["int32", "float32", "float16"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "output"}]} +{"op_name": "Pow", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "param_type": "required", "name": "x"}, {"index": 1, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "param_type": "required", "name": "power"}], "outputs": [{"index": 0, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "name": "output"}]} +{"op_name": "RealDiv", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}, {"index": 1, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "y"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "Reciprocal", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "ReduceMax", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "axis", "param_type": "required", "type": "listInt"}, {"name": "keep_dims", "param_type": "required", "type": "bool"}], "inputs": [{"index": 0, "dtype": ["float16", "float16"], "format": ["DefaultFormat", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float16"], "format": ["DefaultFormat", "NC1HWC0"], "name": "output"}]} +{"op_name": "ReduceMean", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "axis", "param_type": "required", "type": "listInt"}, {"name": "keep_dims", "param_type": "required", "type": "bool"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "ReduceSum", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "axis", "param_type": "required", "type": "listInt"}, {"name": "keep_dims", "param_type": "required", "type": "bool"}, {"name": "atomic_add", "param_type": "optional", "type": "str"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "ReluGrad", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0"], "name": "y_backprop"}, {"index": 1, "dtype": ["float16", "float32", "float16"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0"], "name": "output"}]} +{"op_name": "ReLU", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0"], "name": "output"}]} +{"op_name": "Reshape", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "shape", "param_type": "required", "type": "listInt"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "tensor"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "Round", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "Rsqrt", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "param_type": "required", "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "Select", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["bool", "bool", "bool", "bool", "bool", "bool"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "param_type": "required", "name": "condition"}, {"index": 1, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "param_type": "required", "name": "x"}, {"index": 2, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "param_type": "required", "name": "y"}], "outputs": [{"index": 0, "dtype": ["float16", "int32", "float16", "int32", "float32", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "name": "output"}]} +{"op_name": "Softmax", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [{"name": "axis", "param_type": "required", "type": "listInt"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "SparseSoftmaxCrossEntropyWithLogits", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "is_grad", "param_type": "optional", "type": "bool"}, {"name": "sens", "param_type": "optional", "type": "float"}], "inputs": [{"index": 0, "dtype": ["float32"], "format": ["DefaultFormat"], "name": "features"}, {"index": 1, "dtype": ["int32"], "format": ["DefaultFormat"], "name": "labels"}], "outputs": [{"index": 0, "dtype": ["float32"], "format": ["DefaultFormat"], "name": "output"}]} +{"op_name": "Sqrt", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "param_type": "required", "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "StridedSlice", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "begin", "param_type": "required", "type": "listInt"}, {"name": "end", "param_type": "required", "type": "listInt"}, {"name": "strides", "param_type": "required", "type": "listInt"}, {"name": "begin_mask", "param_type": "required", "type": "int"}, {"name": "end_mask", "param_type": "required", "type": "int"}, {"name": "ellipsis_mask", "param_type": "required", "type": "int"}, {"name": "new_axis_mask", "param_type": "required", "type": "int"}, {"name": "shrink_axis_mask", "param_type": "required", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "Sub", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}, {"index": 1, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "y"}], "outputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "Sum", "imply_type": "AutoDiff", "fusion_type": "COMMREDUCE", "attr": [{"name": "axis", "param_type": "required", "type": "listInt"}, {"name": "keepdims", "param_type": "required", "type": "bool"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "param_type": "required", "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "output"}]} +{"op_name": "Tile", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "multiples", "param_type": "required", "type": "listInt"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32", "float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "ZerosLike", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "Argmax", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "axis", "param_type": "optional", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["int32", "int32", "int32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "FloorDiv", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "y"}], "outputs": [{"index": 0, "dtype": ["int32", "int32", "int32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "Equal", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "y"}], "outputs": [{"index": 0, "dtype": ["bool", "bool", "bool", "bool", "bool", "bool"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "GreaterEqual", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "y"}], "outputs": [{"index": 0, "dtype": ["bool", "bool", "bool", "bool", "bool", "bool"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "LessEqual", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["int32", "float16", "float32", "int32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "y"}], "outputs": [{"index": 0, "dtype": ["bool", "bool", "bool", "bool", "bool", "bool"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0"], "name": "output"}]} +{"op_name": "ExpandDims", "imply_type": "AutoDiff", "fusion_type": "OPAQUE", "attr": [{"name": "axis", "param_type": "required", "type": "int"}], "inputs": [{"index": 0, "dtype": ["float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "int32"], "format": ["DefaultFormat", "DefaultFormat", "DefaultFormat"], "name": "y"}]} +{"op_name": "Greater", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float16", "float32", "float32"], "format": ["DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "name": "x"}, {"index": 1, "dtype": ["float16", "float16", "float32", "float32"], "format": ["DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "name": "y"}], "outputs": [{"index": 0, "dtype": ["bool", "bool", "bool", "bool"], "format": ["DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0"], "name": "output"}]} +{"op_name": "EquivFormat", "imply_type": "AutoDiff", "fusion_type": "ELEMWISE", "attr": [], "inputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["DefaultFormat", "DefaultFormat", "FRACTAL_NZ", "FRACTAL_NZ"], "name": "x"}], "outputs": [{"index": 0, "dtype": ["float16", "float32", "float16", "float32"], "format": ["FRACTAL_NZ", "FRACTAL_NZ", "DefaultFormat", "DefaultFormat"], "name": "output"}]} +{"op_name": "Cast", "inputs": [{"index": 0, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [{"name": "dst_type", "param_type": "required", "type": "str"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["bool", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "Equal", "inputs": [{"index": 0, "name": "x"}, {"index": 1, "name": "y"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["bool", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "SimpleMean", "inputs": [{"index": 0, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "SimpleMeanGrad", "inputs": [{"index": 0, "name": "HEAD"}], "outputs": [{"index": 0, "name": "output"}], "attr": [{"name": "input_shape", "param_type": "required", "type": "listInt"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "Mul", "inputs": [{"index": 0, "name": "x"}, {"index": 1, "name": "y"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "ReLU6", "inputs": [{"index": 0, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "ReLU6Grad", "inputs": [{"index": 0, "name": "y_grad"}, {"index": 1, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "Squeeze", "inputs": [{"index": 0, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "SqueezeGrad", "inputs": [{"index": 0, "name": "y_grad"}], "outputs": [{"index": 0, "name": "output"}], "attr": [{"name": "x_shape", "param_type": "required", "type": "listInt"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "Tile", "inputs": [{"index": 0, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [{"name": "multiples", "param_type": "required", "type": "listInt"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "HSigmoid", "inputs": [{"index": 0, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "HSigmoidGrad", "inputs": [{"index": 0, "name": "y_grad"}, {"index": 1, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "HSwish", "inputs": [{"index": 0, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "HSwishGrad", "inputs": [{"index": 0, "name": "y_grad"}, {"index": 1, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "Sub", "inputs": [{"index": 0, "name": "x"}, {"index": 1, "name": "y"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "LogicalAnd", "inputs": [{"index": 0, "name": "x"}, {"index": 1, "name": "y"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "LogicalNot", "inputs": [{"index": 0, "name": "x"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "LogicalOr", "inputs": [{"index": 0, "name": "x"}, {"index": 1, "name": "y"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "LessEqual", "inputs": [{"index": 0, "name": "x"}, {"index": 1, "name": "y"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["bool", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "NotEqual", "inputs": [{"index": 0, "name": "x"}, {"index": 1, "name": "y"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["bool", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "GreaterEqual", "inputs": [{"index": 0, "name": "x"}, {"index": 1, "name": "y"}], "outputs": [{"index": 0, "name": "output"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["bool", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "AutoDiff", "processor": "cuda"} +{"op_name": "Abs", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]], [["int32", ""], ["int32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "abs.so", "compute_cost": 10, "kernel_name": "abs", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "InplaceAdd", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "v", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "indices", "param_type": "required", "type": "listInt", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "inplace_add_d.so", "compute_cost": 10, "kernel_name": "inplace_add_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "InplaceSub", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "v", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "indices", "param_type": "required", "type": "listInt", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "inplace_sub_d.so", "compute_cost": 10, "kernel_name": "inplace_sub_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "AbsGrad", "inputs": [{"index": 0, "name": "y", "param_type": "required"}, {"index": 1, "name": "dy", "param_type": "required"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "abs_grad.so", "compute_cost": 10, "kernel_name": "abs_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ACos", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "acos.so", "compute_cost": 10, "kernel_name": "acos", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "ACosGrad", "inputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "acos_grad.so", "compute_cost": 10, "kernel_name": "acos_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Acosh", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "acosh.so", "compute_cost": 10, "kernel_name": "acosh", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "AcoshGrad", "inputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "acosh_grad.so", "compute_cost": 10, "kernel_name": "acosh_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "AdamApplyOneWithDecay", "inputs": [{"index": 0, "name": "input0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "input3", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "input4", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "mul0_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "mul1_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "mul2_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 8, "name": "mul3_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 9, "name": "mul4_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 10, "name": "add2_y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "output1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "output2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "adam_apply_one_with_decay.so", "compute_cost": 10, "kernel_name": "adam_apply_one_with_decay", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Add", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "add.so", "compute_cost": 10, "kernel_name": "add", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "ApplyCenteredRMSProp", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "mg", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "ms", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "mom", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "rho", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "momentum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "epsilon", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 8, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_centered_rms_prop.so", "compute_cost": 10, "kernel_name": "apply_centered_rms_prop", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "AddN", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "dynamic", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "n", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]], [["int32", ""], ["int32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "add_n.so", "compute_cost": 10, "kernel_name": "add_n", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "AccumulateNV2", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "dynamic", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "n", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]], [["int32", ""], ["int32", ""]], [["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "accumulate_n_v2.so", "compute_cost": 10, "kernel_name": "accumulate_n_v2", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ApplyFtrl", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "linear", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "l1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "l2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "lr_power", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "linear", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_ftrl.so", "compute_cost": 10, "kernel_name": "apply_ftrl", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyMomentum", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "momentum", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_nesterov", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_momentum.so", "compute_cost": 10, "kernel_name": "apply_momentum", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Adam", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "m", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "v", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "beta1_power", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "beta2_power", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "beta1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "beta2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 8, "name": "epsilon", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 9, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "m", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "v", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}, {"name": "use_nesterov", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_adam.so", "compute_cost": 10, "kernel_name": "apply_adam", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyAdaMax", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "m", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "v", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "beta1_power", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "beta1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "beta2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "epsilon", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 8, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "m", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "v", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_ada_max_d.so", "compute_cost": 10, "kernel_name": "apply_ada_max_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyAdadelta", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "accum_update", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "rho", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "epsilon", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "accum_update", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_adadelta_d.so", "compute_cost": 10, "kernel_name": "apply_adadelta_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyAdagrad", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "update_slots", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_adagrad_d.so", "compute_cost": 10, "kernel_name": "apply_adagrad_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyAdagradV2", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "epsilon", "param_type": "required", "type": "float", "value": "all"}, {"name": "update_slots", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_adagradv2_d.so", "compute_cost": 10, "kernel_name": "apply_adagradv2_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyAddSign", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "m", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "alpha", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "sign_decay", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "beta", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "m", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_add_sign_d.so", "compute_cost": 10, "kernel_name": "apply_add_sign_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyPowerSign", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "m", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "logbase", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "sign_decay", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "beta", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "m", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_power_sign_d.so", "compute_cost": 10, "kernel_name": "apply_power_sign_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyGradientDescent", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "alpha", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "delta", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_gradient_descent.so", "compute_cost": 10, "kernel_name": "apply_gradient_descent", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyProximalGradientDescent", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "alpha", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "l1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "l2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "delta", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_proximal_gradient_descent.so", "compute_cost": 10, "kernel_name": "apply_proximal_gradient_descent", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SparseApplyFtrlV2", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "linear", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "linear", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "lr", "param_type": "required", "type": "float", "value": "all"}, {"name": "l1", "param_type": "required", "type": "float", "value": "all"}, {"name": "l2", "param_type": "required", "type": "float", "value": "all"}, {"name": "l2_shrinkage", "param_type": "required", "type": "float", "value": "all"}, {"name": "lr_power", "param_type": "required", "type": "float", "value": "all"}, {"name": "use_locking", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["int32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sparse_apply_ftrl_v2_d.so", "compute_cost": 10, "kernel_name": "sparse_apply_ftrl_v2_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SparseApplyAdagradV2", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "lr", "param_type": "required", "type": "float", "value": "all"}, {"name": "epsilon", "param_type": "required", "type": "float", "value": "all"}, {"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}, {"name": "update_slots", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["int32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sparse_apply_adagrad_v2_d.so", "compute_cost": 10, "kernel_name": "sparse_apply_adagrad_v2_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApproximateEqual", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "tolerance", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""], ["bool", ""]], [["float32", ""], ["float32", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "approximate_equal.so", "compute_cost": 10, "kernel_name": "approximate_equal", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "AdamApplyOne", "inputs": [{"index": 0, "name": "input0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "input3", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "input4", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "mul0_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "mul1_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "mul2_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 8, "name": "mul3_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 9, "name": "add2_y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "output1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "output2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "adam_apply_one.so", "compute_cost": 10, "kernel_name": "adam_apply_one", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Assign", "inputs": [{"index": 0, "name": "ref", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "value", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "ref", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["bool", "NC1HWC0"], ["bool", "NC1HWC0"], ["bool", "NC1HWC0"]], [["bool", "C1HWNCoC0"], ["bool", "C1HWNCoC0"], ["bool", "C1HWNCoC0"]], [["bool", "FracZ"], ["bool", "FracZ"], ["bool", "FracZ"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"]], [["int8", "FracZ"], ["int8", "FracZ"], ["int8", "FracZ"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"]], [["uint8", "FracZ"], ["uint8", "FracZ"], ["uint8", "FracZ"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int16", "NC1HWC0"], ["int16", "NC1HWC0"], ["int16", "NC1HWC0"]], [["int16", "C1HWNCoC0"], ["int16", "C1HWNCoC0"], ["int16", "C1HWNCoC0"]], [["int16", "FracZ"], ["int16", "FracZ"], ["int16", "FracZ"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint16", "NC1HWC0"], ["uint16", "NC1HWC0"], ["uint16", "NC1HWC0"]], [["uint16", "C1HWNCoC0"], ["uint16", "C1HWNCoC0"], ["uint16", "C1HWNCoC0"]], [["uint16", "FracZ"], ["uint16", "FracZ"], ["uint16", "FracZ"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"]], [["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint32", "NC1HWC0"], ["uint32", "NC1HWC0"], ["uint32", "NC1HWC0"]], [["uint32", "C1HWNCoC0"], ["uint32", "C1HWNCoC0"], ["uint32", "C1HWNCoC0"]], [["uint32", "FracZ"], ["uint32", "FracZ"], ["uint32", "FracZ"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["int64", "NC1HWC0"], ["int64", "NC1HWC0"], ["int64", "NC1HWC0"]], [["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"]], [["int64", "FracZ"], ["int64", "FracZ"], ["int64", "FracZ"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["uint64", "NC1HWC0"], ["uint64", "NC1HWC0"], ["uint64", "NC1HWC0"]], [["uint64", "C1HWNCoC0"], ["uint64", "C1HWNCoC0"], ["uint64", "C1HWNCoC0"]], [["uint64", "FracZ"], ["uint64", "FracZ"], ["uint64", "FracZ"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "assign.so", "compute_cost": 10, "kernel_name": "assign", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "AssignAdd", "inputs": [{"index": 0, "name": "ref", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "value", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "ref", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"]], [["int8", "FracZ"], ["int8", "FracZ"], ["int8", "FracZ"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"]], [["uint8", "FracZ"], ["uint8", "FracZ"], ["uint8", "FracZ"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"]], [["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["int64", "NC1HWC0"], ["int64", "NC1HWC0"], ["int64", "NC1HWC0"]], [["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"], ["int64", "C1HWNCoC0"]], [["int64", "FracZ"], ["int64", "FracZ"], ["int64", "FracZ"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "assignadd.so", "compute_cost": 10, "kernel_name": "assignadd", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "AssignSub", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "value", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"]], [["int8", "FracZ"], ["int8", "FracZ"], ["int8", "FracZ"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"]], [["uint8", "FracZ"], ["uint8", "FracZ"], ["uint8", "FracZ"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"]], [["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "assign_sub.so", "compute_cost": 10, "kernel_name": "assign_sub", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BatchMatMul", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "bias", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "transpose_x1", "param_type": "required", "type": "bool", "value": "all"}, {"name": "transpose_x2", "param_type": "required", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", ""], ["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "batch_matmul.so", "compute_cost": 10, "kernel_name": "batch_matmul", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "BatchNorm", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "offset", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "mean", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 4, "name": "variance", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "batch_mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "batch_variance", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "reserve_space_1", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 4, "name": "reserve_space_2", "need_compile": false, "param_type": "optional", "shape": "all"}], "attr": [{"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}, {"name": "is_training", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "batch_norm.so", "compute_cost": 10, "kernel_name": "batch_norm", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BatchNormGrad", "inputs": [{"index": 0, "name": "y_backprop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "reserve_space_1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "reserve_space_2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "x_backprop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "scale_backprop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "offset_backprop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "reserve_space_4", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 4, "name": "reserve_space_5", "need_compile": false, "param_type": "optional", "shape": "all"}], "attr": [{"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}, {"name": "is_training", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "batchnormgrad.so", "compute_cost": 10, "kernel_name": "batchnormgrad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BiasAdd", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "bias", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "data_format", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "COMMREDUCE", "dtype_format": [[["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bias_add.so", "compute_cost": 10, "kernel_name": "bias_add", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "BiasAddGrad", "inputs": [{"index": 0, "name": "output_backprop", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "data_format", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "COMMREDUCE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FRACTAL_NZ"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FRACTAL_NZ"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "biasaddgrad.so", "compute_cost": 10, "kernel_name": "biasaddgrad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Cast", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "dst_type", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["bool", ""], ["float16", ""]], [["bool", ""], ["uint8", ""]], [["bool", ""], ["float32", ""]], [["bool", ""], ["int32", ""]], [["int8", ""], ["float16", ""]], [["int8", ""], ["float32", ""]], [["int8", ""], ["int32", ""]], [["uint8", ""], ["float16", ""]], [["uint8", ""], ["float32", ""]], [["uint8", ""], ["int32", ""]], [["int32", ""], ["bool", ""]], [["int32", ""], ["float16", ""]], [["int32", ""], ["float32", ""]], [["int32", ""], ["int8", ""]], [["int32", ""], ["uint8", ""]], [["float16", ""], ["uint8", ""]], [["float16", ""], ["float32", ""]], [["float16", ""], ["int32", ""]], [["float32", ""], ["float16", ""]], [["float32", ""], ["int32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "cast.so", "compute_cost": 10, "kernel_name": "cast", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Conv2D", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "filter", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "bias", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 3, "name": "offset_w", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "stride", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_list", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "dilation", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "offset_a", "param_type": "optional", "type": "int", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""], ["int8", ""], ["float16", ""]], [["int8", ""], ["int8", ""], ["int32", ""], ["int8", ""], ["int32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "conv2d.so", "compute_cost": 10, "kernel_name": "conv2d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "Conv2DBackpropFilter", "inputs": [{"index": 0, "name": "out_backprop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "filter_sizes", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "stride", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_list", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "dilation", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "groups", "param_type": "optional", "type": "int", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "conv2d_backprop_filter_d.so", "compute_cost": 10, "kernel_name": "conv2d_backprop_filter_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Conv2DBackpropInput", "inputs": [{"index": 0, "name": "out_backprop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "filter", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "input_sizes", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "stride", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pad_list", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "dilation", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "group", "param_type": "optional", "type": "int", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "FracZ"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "conv2d_backprop_input_d.so", "compute_cost": 10, "kernel_name": "conv2d_backprop_input_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ConfusionMulGrad", "inputs": [{"index": 0, "name": "input0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "output1", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "required", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "", "compute_cost": 10, "kernel_name": "", "partial_flag": false, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "DropoutDoMask", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "mask", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "keep_prob", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "drop_out_do_mask.so", "compute_cost": 10, "kernel_name": "drop_out_do_mask", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "Gelu", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gelu.so", "compute_cost": 10, "kernel_name": "gelu", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "GeluGrad", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gelu_grad.so", "compute_cost": 10, "kernel_name": "gelu_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "MaxPool", "inputs": [{"index": 0, "name": "input_data", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output_data", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "ksize", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "padding", "param_type": "required", "type": "str", "value": "all"}, {"name": "data_format", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool.so", "compute_cost": 10, "kernel_name": "max_pool", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "MaxPoolGrad", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "ksize", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "padding", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_grad.so", "compute_cost": 10, "kernel_name": "max_pool_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "MaxPoolGradWithArgmax", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "argmax", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "ksize", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "padding", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["uint16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["int64", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_grad_with_argmax.so", "compute_cost": 10, "kernel_name": "max_pool_grad_with_argmax", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "MaxPoolWithArgmax", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "argmax", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "ksize", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "padding", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["uint16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_with_argmax.so", "compute_cost": 10, "kernel_name": "max_pool_with_argmax", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Mul", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "mul.so", "compute_cost": 10, "kernel_name": "mul", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "RealDiv", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "realdiv.so", "compute_cost": 10, "kernel_name": "realdiv", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ReLU", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""]], [["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "relu.so", "compute_cost": 10, "kernel_name": "relu", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "ReluGrad", "inputs": [{"index": 0, "name": "gradients", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "features", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "backprops", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""], ["uint8", ""]], [["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "relugrad.so", "compute_cost": 10, "kernel_name": "relugrad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ReLU6", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]], [["int32", ""], ["int32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "relu6.so", "compute_cost": 10, "kernel_name": "relu6", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "ReLU6Grad", "inputs": [{"index": 0, "name": "gradients", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "features", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "backprops", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "relu6_grad.so", "compute_cost": 10, "kernel_name": "relu6_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ReLUV2", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "mask", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["uint8", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["uint8", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["uint8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "relu_v2.so", "compute_cost": 10, "kernel_name": "relu_v2", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ReluGradV2", "inputs": [{"index": 0, "name": "gradients", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "mask", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "backprops", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["uint8", "DefaultFormat"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["uint8", "DefaultFormat"], ["float32", "NC1HWC0"]], [["int32", "NC1HWC0"], ["uint8", "DefaultFormat"], ["int32", "NC1HWC0"]], [["int8", "NC1HWC0"], ["uint8", "DefaultFormat"], ["int8", "NC1HWC0"]], [["uint8", "NC1HWC0"], ["uint8", "DefaultFormat"], ["uint8", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "relu_grad_v2.so", "compute_cost": 10, "kernel_name": "relu_grad_v2", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SoftmaxCrossEntropyWithLogits", "inputs": [{"index": 0, "name": "input_features", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input_labels", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output_loss", "need_compile": true, "param_type": "required", "shape": "all"}, {"index": 1, "name": "output_backprop", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "softmax_cross_entropy_with_logits.so", "compute_cost": 10, "kernel_name": "softmax_cross_entropy_with_logits", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SigmoidCrossEntropyWithLogits", "inputs": [{"index": 0, "name": "predict", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "target", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "loss", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sigmoid_cross_entropy_with_logits.so", "compute_cost": 10, "kernel_name": "sigmoid_cross_entropy_with_logits", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SigmoidCrossEntropyWithLogitsGrad", "inputs": [{"index": 0, "name": "predict", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "target", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "dout", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "gradient", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sigmoid_cross_entropy_with_logits_grad.so", "compute_cost": 10, "kernel_name": "sigmoid_cross_entropy_with_logits_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "TensorAdd", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "add.so", "compute_cost": 10, "kernel_name": "add", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "TransData", "inputs": [{"index": 0, "name": "src", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "dst", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "src_format", "param_type": "required", "type": "str", "value": "DefaultFormat, NC1HWC0, FracZ, FRACTAL_NZ, HWCN, C1HWNCoC0, NDHWC, NHWC"}, {"name": "dst_format", "param_type": "required", "type": "str", "value": "DefaultFormat, NC1HWC0, FracZ, FRACTAL_NZ, HWCN, C1HWNCoC0, NDHWC, NHWC"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NHWC"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NHWC"]], [["float32", "NC1HWC0"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "FracZ"]], [["float32", "HWCN"], ["float32", "FracZ"]], [["float32", "FracZ"], ["float32", "HWCN"]], [["float32", "C1HWNCoC0"], ["float32", "HWCN"]], [["float32", "HWCN"], ["float32", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "FracZ"]], [["float16", "NHWC"], ["float16", "FracZ"]], [["float16", "HWCN"], ["float16", "FracZ"]], [["float16", "DefaultFormat"], ["float16", "NC1HWC0"]], [["float16", "NHWC"], ["float16", "NC1HWC0"]], [["float16", "HWCN"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NHWC"]], [["float16", "NC1HWC0"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "FracZ"]], [["float16", "HWCN"], ["float16", "FracZ"]], [["float16", "FracZ"], ["float16", "HWCN"]], [["float16", "C1HWNCoC0"], ["float16", "HWCN"]], [["float16", "HWCN"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "FRACTAL_NZ"]], [["float32", "DefaultFormat"], ["float32", "FRACTAL_NZ"]], [["float16", "FRACTAL_NZ"], ["float16", "DefaultFormat"]], [["float32", "FRACTAL_NZ"], ["float32", "DefaultFormat"]], [["bool", "NHWC"], ["bool", "NC1HWC0"]], [["bool", "DefaultFormat"], ["bool", "NC1HWC0"]], [["bool", "NC1HWC0"], ["bool", "NHWC"]], [["bool", "NC1HWC0"], ["bool", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "NHWC"]], [["float16", "DefaultFormat"], ["float16", "HWCN"]], [["float16", "NHWC"], ["float16", "DefaultFormat"]], [["float16", "NHWC"], ["float16", "HWCN"]], [["float16", "HWCN"], ["float16", "DefaultFormat"]], [["float16", "HWCN"], ["float16", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "HWCN"]], [["float32", "NHWC"], ["float32", "DefaultFormat"]], [["float32", "NHWC"], ["float32", "HWCN"]], [["float32", "HWCN"], ["float32", "DefaultFormat"]], [["float32", "HWCN"], ["float32", "NHWC"]], [["int8", "DefaultFormat"], ["int8", "FRACTAL_NZ"]], [["int8", "DefaultFormat"], ["int8", "FracZ"]], [["int8", "DefaultFormat"], ["int8", "NHWC"]], [["int8", "DefaultFormat"], ["int8", "HWCN"]], [["int8", "NHWC"], ["int8", "DefaultFormat"]], [["int8", "NHWC"], ["int8", "HWCN"]], [["int8", "HWCN"], ["int8", "DefaultFormat"]], [["int8", "HWCN"], ["int8", "NHWC"]], [["int16", "DefaultFormat"], ["int16", "NHWC"]], [["int16", "DefaultFormat"], ["int16", "HWCN"]], [["int16", "NHWC"], ["int16", "DefaultFormat"]], [["int16", "NHWC"], ["int16", "HWCN"]], [["int16", "HWCN"], ["int16", "DefaultFormat"]], [["int16", "HWCN"], ["int16", "NHWC"]], [["int32", "DefaultFormat"], ["int32", "NHWC"]], [["int32", "DefaultFormat"], ["int32", "HWCN"]], [["int32", "NHWC"], ["int32", "DefaultFormat"]], [["int32", "NHWC"], ["int32", "HWCN"]], [["int32", "HWCN"], ["int32", "DefaultFormat"]], [["int32", "HWCN"], ["int32", "NHWC"]], [["int64", "DefaultFormat"], ["int64", "NHWC"]], [["int64", "DefaultFormat"], ["int64", "HWCN"]], [["int64", "NHWC"], ["int64", "DefaultFormat"]], [["int64", "NHWC"], ["int64", "HWCN"]], [["int64", "HWCN"], ["int64", "DefaultFormat"]], [["int64", "HWCN"], ["int64", "NHWC"]], [["uint8", "DefaultFormat"], ["uint8", "NHWC"]], [["uint8", "DefaultFormat"], ["uint8", "HWCN"]], [["uint8", "NHWC"], ["uint8", "DefaultFormat"]], [["uint8", "NHWC"], ["uint8", "HWCN"]], [["uint8", "HWCN"], ["uint8", "DefaultFormat"]], [["uint8", "HWCN"], ["uint8", "NHWC"]], [["uint16", "DefaultFormat"], ["uint16", "NHWC"]], [["uint16", "DefaultFormat"], ["uint16", "HWCN"]], [["uint16", "NHWC"], ["uint16", "DefaultFormat"]], [["uint16", "NHWC"], ["uint16", "HWCN"]], [["uint16", "HWCN"], ["uint16", "DefaultFormat"]], [["uint16", "HWCN"], ["uint16", "NHWC"]], [["uint32", "DefaultFormat"], ["uint32", "NHWC"]], [["uint32", "DefaultFormat"], ["uint32", "HWCN"]], [["uint32", "NHWC"], ["uint32", "DefaultFormat"]], [["uint32", "NHWC"], ["uint32", "HWCN"]], [["uint32", "HWCN"], ["uint32", "DefaultFormat"]], [["uint32", "HWCN"], ["uint32", "NHWC"]], [["uint64", "DefaultFormat"], ["uint64", "NHWC"]], [["uint64", "DefaultFormat"], ["uint64", "HWCN"]], [["uint64", "NHWC"], ["uint64", "DefaultFormat"]], [["uint64", "NHWC"], ["uint64", "HWCN"]], [["uint64", "HWCN"], ["uint64", "DefaultFormat"]], [["uint64", "HWCN"], ["uint64", "NHWC"]], [["int32", "FRACTAL_NZ"], ["int32", "DefaultFormat"]], [["float16", "NDHWC"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NDHWC"]], [["int8", "HWCN"], ["int8", "C1HWNCoC0"]], [["float16", "HWCN"], ["float16", "FracZ"]], [["float16", "FracZ"], ["float16", "HWCN"]], [["float16", "HWCN"], ["float16", "FRACTAL_NZ"]], [["float32", "HWCN"], ["float16", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "trans_data.so", "compute_cost": 10, "kernel_name": "trans_data", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "TopK", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "assist_seq", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "values", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "dim", "param_type": "optional", "type": "int", "value": "all"}, {"name": "k", "param_type": "required", "type": "int", "value": "all"}, {"name": "largest", "param_type": "optional", "type": "bool", "value": "all"}, {"name": "sorted", "param_type": "optional", "type": "bool", "value": "true"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "top_k_d.so", "compute_cost": 10, "kernel_name": "top_k_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "MatMul", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "bias", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 3, "name": "offset_w", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "transpose_x1", "param_type": "required", "type": "bool", "value": "all"}, {"name": "transpose_x2", "param_type": "required", "type": "bool", "value": "all"}, {"name": "offset_x", "param_type": "optional", "type": "int", "value": "all"}], "fusion_type": "DYNAMIC", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "DefaultFormat"], ["int8", "DefaultFormat"], ["float16", "FRACTAL_NZ"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float32", "DefaultFormat"], ["int8", "DefaultFormat"], ["float32", "FRACTAL_NZ"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["int8", "DefaultFormat"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int8", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NHWC"], ["int32", "NHWC"], ["int32", "NHWC"], ["int8", "DefaultFormat"], ["int32", "NHWC"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "matmul.so", "compute_cost": 10, "kernel_name": "matmul", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Sub", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sub.so", "compute_cost": 10, "kernel_name": "sub", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ReduceMeanD", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reduce_mean_d.so", "compute_cost": 10, "kernel_name": "reduce_mean_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "reduce"} +{"op_name": "ScatterNd", "inputs": [{"index": 0, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "shape", "param_type": "optional", "type": "listInt", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_nd_d.so", "compute_cost": 10, "kernel_name": "scatter_nd_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ScatterNdD", "inputs": [{"index": 0, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "shape", "param_type": "optional", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_nd_d.so", "compute_cost": 10, "kernel_name": "scatter_nd_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ReduceMean", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reduce_mean.so", "compute_cost": 10, "kernel_name": "reduce_mean", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "reduce"} +{"op_name": "Tile", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "multiples", "param_type": "optional", "type": "listInt", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "tile_d.so", "compute_cost": 10, "kernel_name": "tile_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "AtomicAddrClean", "inputs": [], "outputs": [], "attr": [{"name": "automic_add_mem_size", "param_type": "required", "type": "listUInt64", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [], "imply_type": "TBE", "async_flag": false, "binfile_name": "atomic_addr_clean.so", "compute_cost": 10, "kernel_name": "atomic_addr_clean", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "GatherV2", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int32", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "NC1HWC0"], ["int64", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "FracZ"], ["int32", "FracZ"], ["int8", "FracZ"]], [["int8", "FracZ"], ["int64", "FracZ"], ["int8", "FracZ"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["int32", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "NC1HWC0"], ["int64", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "FracZ"], ["int32", "FracZ"], ["uint8", "FracZ"]], [["uint8", "FracZ"], ["int64", "FracZ"], ["uint8", "FracZ"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["int32", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["int32", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["int64", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["int32", "FracZ"], ["float16", "FracZ"]], [["float16", "FracZ"], ["int64", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["int32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NC1HWC0"], ["int64", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["int32", "FracZ"], ["float32", "FracZ"]], [["float32", "FracZ"], ["int64", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gather_v2_d.so", "compute_cost": 10, "kernel_name": "gather_v2_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "GatherNd", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["bool", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]], [["bool", "DefaultFormat"], ["int64", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gather_nd.so", "compute_cost": 10, "kernel_name": "gather_nd", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BNTrainingReduce", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}], "outputs": [{"index": 0, "name": "sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "square_sum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float32", ""], ["float32", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_training_reduce.so", "compute_cost": 10, "kernel_name": "bn_training_reduce", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "BNTrainingReduceGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "x_norm", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 2, "name": "diff_scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "diff_offset", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "batch_mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "batch_variance", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}], "attr": [{"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_training_reduce_grad.so", "compute_cost": 10, "kernel_name": "bn_training_reduce_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "BNTrainingUpdate", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "square_sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "offset", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "variance", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "variance", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "batch_mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "batch_variance", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "factor", "param_type": "optional", "type": "float", "value": "all"}, {"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}, {"name": "isRef", "param_type": "optional", "type": "bool", "value": "all", "default_value": "true"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_training_update.so", "compute_cost": 10, "kernel_name": "bn_training_update", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BNTrainingUpdateGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 2, "name": "batch_mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "batch_variance", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "diff_scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "diff_offset", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_training_update_grad.so", "compute_cost": 10, "kernel_name": "bn_training_update_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "BNInfer", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "offset", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "variance", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}], "attr": [{"name": "epsilon", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_infer.so", "compute_cost": 10, "kernel_name": "bn_infer", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BNInferGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "batch_variance", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "x_backprop", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}], "attr": [{"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_infer_grad.so", "compute_cost": 10, "kernel_name": "bn_infer_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Reciprocal", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reciprocal.so", "compute_cost": 10, "kernel_name": "reciprocal", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "StridedSlice", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "begin", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "end", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "begin_mask", "param_type": "required", "type": "int", "value": "all"}, {"name": "end_mask", "param_type": "required", "type": "int", "value": "all"}, {"name": "ellipsis_mask", "param_type": "required", "type": "int", "value": "all"}, {"name": "new_axis_mask", "param_type": "required", "type": "int", "value": "all"}, {"name": "shrink_axis_mask", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "strided_slice_d.so", "compute_cost": 10, "kernel_name": "strided_slice_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "StridedSliceGrad", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "shapex", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "begin", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "end", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "begin_mask", "param_type": "optional", "type": "int", "value": "all"}, {"name": "end_mask", "param_type": "optional", "type": "int", "value": "all"}, {"name": "ellipsis_mask", "param_type": "optional", "type": "int", "value": "all"}, {"name": "new_axis_mask", "param_type": "optional", "type": "int", "value": "all"}, {"name": "shrink_axis_mask", "param_type": "optional", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "strided_slice_grad_d.so", "compute_cost": 10, "kernel_name": "strided_slice_grad_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Split", "inputs": [{"index": 0, "name": "value", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output", "need_compile": false, "param_type": "dynamic", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}, {"name": "output_num", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "split_d.so", "compute_cost": 10, "kernel_name": "split_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "Exp", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "exp.so", "compute_cost": 10, "kernel_name": "exp", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Expm1", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "expm1.so", "compute_cost": 10, "kernel_name": "expm1", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Elu", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "alpha", "param_type": "optional", "type": "float", "value": "all", "default_value": "1.0"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "elu.so", "compute_cost": 10, "kernel_name": "elu", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "EluGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "activations", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "elu_grad.so", "compute_cost": 10, "kernel_name": "elu_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Div", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""], ["uint8", ""]], [["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "div.so", "compute_cost": 10, "kernel_name": "div", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "Log", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "log.so", "compute_cost": 10, "kernel_name": "log", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "FloorDiv", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""], ["uint8", ""]], [["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "floordiv.so", "compute_cost": 10, "kernel_name": "floordiv", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ZerosLike", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["bool", ""], ["bool", ""]], [["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""]], [["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "zeros_like.so", "compute_cost": 10, "kernel_name": "zeros_like", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Neg", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "neg.so", "compute_cost": 10, "kernel_name": "neg", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "NPUClearFloatStatus", "inputs": [{"index": 0, "name": "addr", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "data", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "n_p_u_clear_float_status.so", "compute_cost": 10, "kernel_name": "n_p_u_clear_float_status", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "NPUGetFloatStatus", "inputs": [{"index": 0, "name": "addr", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "data", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "n_p_u_get_float_status.so", "compute_cost": 10, "kernel_name": "n_p_u_get_float_status", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "NPUAllocFloatStatus", "inputs": [], "outputs": [{"index": 0, "name": "data", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "n_p_u_alloc_float_status.so", "compute_cost": 10, "kernel_name": "n_p_u_alloc_float_status", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "OneHot", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "on_value", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "off_value", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "depth", "param_type": "required", "type": "int", "value": "all"}, {"name": "axis", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["uint8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "one_hot.so", "compute_cost": 10, "kernel_name": "one_hot", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Equal", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""], ["bool", ""]], [["uint8", ""], ["uint8", ""], ["bool", ""]], [["int32", ""], ["int32", ""], ["bool", ""]], [["float16", ""], ["float16", ""], ["bool", ""]], [["float32", ""], ["float32", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "equal.so", "compute_cost": 10, "kernel_name": "equal", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "Less", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""], ["bool", ""]], [["uint8", ""], ["uint8", ""], ["bool", ""]], [["int32", ""], ["int32", ""], ["bool", ""]], [["float16", ""], ["float16", ""], ["bool", ""]], [["float32", ""], ["float32", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "less.so", "compute_cost": 10, "kernel_name": "less", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "LessEqual", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "begin_norm_axis", "param_type": "required", "type": "int", "value": "all"}, {"name": "begin_params_axis", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""], ["bool", ""]], [["uint8", ""], ["uint8", ""], ["bool", ""]], [["int32", ""], ["int32", ""], ["bool", ""]], [["float16", ""], ["float16", ""], ["bool", ""]], [["float32", ""], ["float32", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "less_equal.so", "compute_cost": 10, "kernel_name": "less_equal", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "LogicalAnd", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["bool", ""], ["bool", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "logical_and.so", "compute_cost": 10, "kernel_name": "logical_and", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "LogicalNot", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["bool", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "logical_not.so", "compute_cost": 10, "kernel_name": "logical_not", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "LogicalOr", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["bool", ""], ["bool", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "logical_or.so", "compute_cost": 10, "kernel_name": "logical_or", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ReduceMax", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["bool", ""], ["bool", ""]], [["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""]], [["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reduce_max_d.so", "compute_cost": 10, "kernel_name": "reduce_max_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "reduce"} +{"op_name": "ReduceMin", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "required", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reduce_min_d.so", "compute_cost": 10, "kernel_name": "reduce_min_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "reduce"} +{"op_name": "ReduceSum", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reduce_sum_d.so", "compute_cost": 10, "kernel_name": "reduce_sum_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "reduce"} +{"op_name": "Round", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]], [["int32", ""], ["int32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "round.so", "compute_cost": 10, "kernel_name": "round", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Tanh", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "tanh.so", "compute_cost": 10, "kernel_name": "tanh", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "TanhGrad", "inputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "tanh_grad.so", "compute_cost": 10, "kernel_name": "tanh_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Softmax", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "softmax.so", "compute_cost": 10, "kernel_name": "softmax", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Softsign", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "softsign.so", "compute_cost": 10, "kernel_name": "softsign", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Softplus", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "softplus.so", "compute_cost": 10, "kernel_name": "softplus", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "SoftplusGrad", "inputs": [{"index": 0, "name": "gradients", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "features", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "backprops", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "softplus_grad.so", "compute_cost": 10, "kernel_name": "softplus_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "SoftmaxGradExt", "inputs": [{"index": 0, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "keepdims", "param_type": "required", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "softmax_grad_ext.so", "compute_cost": 10, "kernel_name": "softmax_grad_ext", "partial_flag": true, "reshape_type": "", "dynamic_format": true, "op_pattern": "dynamicFormat"} +{"op_name": "Square", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "square.so", "compute_cost": 10, "kernel_name": "square", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Sqrt", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sqrt.so", "compute_cost": 10, "kernel_name": "sqrt", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "SparseApplyFtrl", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "linear", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "linear", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "lr", "param_type": "required", "type": "float", "value": "all"}, {"name": "l1", "param_type": "required", "type": "float", "value": "all"}, {"name": "l2", "param_type": "required", "type": "float", "value": "all"}, {"name": "lr_power", "param_type": "required", "type": "float", "value": "all"}, {"name": "use_locking", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["int32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sparse_apply_ftrl.so", "compute_cost": 10, "kernel_name": "sparse_apply_ftrl", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SparseApplyProximalAdagrad", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "l1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "l2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["int16", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["int16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["int16", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["int16", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["int32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["int32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["int32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["int64", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["int64", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["int64", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["int64", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["uint16", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["uint16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["uint16", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["uint16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["uint16", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["uint32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["uint32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["uint32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["uint32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["uint32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["uint64", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["uint64", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["uint64", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["uint64", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["uint64", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sparse_apply_proximal_adagrad.so", "compute_cost": 10, "kernel_name": "sparse_apply_proximal_adagrad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyProximalAdagrad", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "l1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "l2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_proximal_adagrad_d.so", "compute_cost": 10, "kernel_name": "apply_proximal_adagrad_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Transpose", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "perm", "param_type": "optional", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "transpose_d.so", "compute_cost": 10, "kernel_name": "transpose_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "UnsortedSegmentSum", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "segment_ids", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "num_segments", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "unsorted_segment_sum_d.so", "compute_cost": 10, "kernel_name": "unsorted_segment_sum_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "UnsortedSegmentProd", "inputs": [{"index": 0, "name": "data", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "segment_ids", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "num_segments", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["int32", "DefaultFormat"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["int32", "DefaultFormat"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["int32", "DefaultFormat"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["int32", "DefaultFormat"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["int32", "DefaultFormat"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["int32", "DefaultFormat"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "DefaultFormat"], ["int32", "NC1HWC0"]], [["int32", "FracZ"], ["int32", "DefaultFormat"], ["int32", "FracZ"]], [["int32", "C1HWNCoC0"], ["int32", "DefaultFormat"], ["int32", "C1HWNCoC0"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "unsorted_segment_prod_d.so", "compute_cost": 10, "kernel_name": "unsorted_segment_prod_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LogSoftmaxGrad", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "log_softmax_grad.so", "compute_cost": 10, "kernel_name": "log_softmax_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LogSoftmax", "inputs": [{"index": 0, "name": "logits", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "logsoftmax", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "log_softmax.so", "compute_cost": 10, "kernel_name": "log_softmax", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Select", "inputs": [{"index": 0, "name": "condition", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "select.so", "compute_cost": 10, "kernel_name": "select", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "Pow", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""], ["uint8", ""]], [["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "pow.so", "compute_cost": 10, "kernel_name": "pow", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "Maximum", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "maximum.so", "compute_cost": 10, "kernel_name": "maximum", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "Minimum", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "minimum.so", "compute_cost": 10, "kernel_name": "minimum", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "MinimumGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "grad_x", "param_type": "optional", "type": "bool", "value": "all"}, {"name": "grad_y", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", ""], ["int32", ""], ["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "minimum_grad.so", "compute_cost": 10, "kernel_name": "minimum_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "MaximumGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "grad_x", "param_type": "optional", "type": "bool", "value": "all"}, {"name": "grad_y", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", ""], ["int32", ""], ["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "maximum_grad.so", "compute_cost": 10, "kernel_name": "maximum_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "Concat", "inputs": [{"index": 0, "name": "input_values", "need_compile": false, "param_type": "dynamic", "shape": "all"}], "outputs": [{"index": 0, "name": "output_data", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "concat_d.so", "compute_cost": 10, "kernel_name": "concat_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "Slice", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "begin", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "size", "param_type": "required", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "slice_d.so", "compute_cost": 10, "kernel_name": "slice_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Sign", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]], [["int32", ""], ["int32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sign.so", "compute_cost": 10, "kernel_name": "sign", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Greater", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""], ["bool", ""]], [["uint8", ""], ["uint8", ""], ["bool", ""]], [["int32", ""], ["int32", ""], ["bool", ""]], [["float16", ""], ["float16", ""], ["bool", ""]], [["float32", ""], ["float32", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "greater.so", "compute_cost": 10, "kernel_name": "greater", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ClipByNormNoDivSum", "inputs": [{"index": 0, "name": "input_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "input3", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output_y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "clip_by_norm_no_div_sum.so", "compute_cost": 10, "kernel_name": "clip_by_norm_no_div_sum", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ClipByValue", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "clip_value_min", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "clip_value_max", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "dst_type", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["int32", ""], ["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "clip_by_value.so", "compute_cost": 10, "kernel_name": "clip_by_value", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "LayerNormBetaGammaBackprop", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "variance", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "mean", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "pd_gamma", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "pd_beta", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "shape_gamma", "param_type": "required", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""], ["float32", ""], ["float32", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "layer_norm_beta_gamma_backprop.so", "compute_cost": 10, "kernel_name": "layer_norm_beta_gamma_backprop", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "LayerNorm", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "gamma", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "beta", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "variance", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "begin_norm_axis", "param_type": "required", "type": "int", "value": "all"}, {"name": "begin_params_axis", "param_type": "required", "type": "int", "value": "all"}, {"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "layer_norm.so", "compute_cost": 10, "kernel_name": "layer_norm", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "LayerNormGrad", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "variance", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "gamma", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "pd_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "pd_gamma", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "pd_beta", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "layer_norm_grad.so", "compute_cost": 10, "kernel_name": "layer_norm_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LayerNormXBackprop", "inputs": [{"index": 0, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "variance", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "gamma", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "pd_x", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "layer_norm_x_backprop.so", "compute_cost": 10, "kernel_name": "layer_norm_x_backprop", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "L2Loss", "inputs": [{"index": 0, "name": "x", "param_type": "required"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "DefaultFormat"]], [["float16", "FRACTAL_NZ"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "DefaultFormat"]], [["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "DefaultFormat"]], [["float32", "FRACTAL_NZ"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "DefaultFormat"]], [["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "l2_loss.so", "compute_cost": 10, "kernel_name": "l2_loss", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "L2Normalize", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "epsilon", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "l2_normalize.so", "compute_cost": 10, "kernel_name": "l2_normalize", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "L2NormalizeGrad", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "dx", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "epsilon", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "l2_normalize_grad.so", "compute_cost": 10, "kernel_name": "l2_normalize_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SquareSumV1", "inputs": [{"index": 0, "name": "input_x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output1", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "square_sum_v1.so", "compute_cost": 10, "kernel_name": "square_sum_v1", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SquareSumV2", "inputs": [{"index": 0, "name": "input_x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "output2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "square_sum_v2.so", "compute_cost": 10, "kernel_name": "square_sum_v2", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ConfusionTransposeD", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "perm", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "shape", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "transpose_first", "param_type": "required", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "confusion_transpose_d.so", "compute_cost": 10, "kernel_name": "confusion_transpose_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "ConfusionSoftmaxGrad", "inputs": [{"index": 0, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "confusion_softmax_grad.so", "compute_cost": 10, "kernel_name": "confusion_softmax_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LambUpdateWithLrV2", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "x3", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "x4", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "x5", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "greater_y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "select_e", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lamb_update_with_lr_v2.so", "compute_cost": 10, "kernel_name": "lamb_update_with_lr_v2", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LambNextMV", "inputs": [{"index": 0, "name": "input1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input3", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "input4", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "input5", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "input6", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "input7", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "input8", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 8, "name": "input9", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 9, "name": "inputx0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 10, "name": "inputx1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 11, "name": "inputx2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 12, "name": "inputx3", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "output2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "output3", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "output4", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lamb_next_m_v.so", "compute_cost": 10, "kernel_name": "lamb_next_m_v", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LambNextMVWithDecay", "inputs": [{"index": 0, "name": "input_mul3", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input_mul2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input_realdiv1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "input_mul1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "input_mul0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "input_realdiv0", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "input_mul4", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "mul0_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 8, "name": "mul1_sub", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 9, "name": "mul2_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 10, "name": "mul3_sub1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 11, "name": "mul4_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 12, "name": "add2_y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y1", "need_compile": true, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y2", "need_compile": true, "param_type": "required", "shape": "all"}, {"index": 2, "name": "y3", "need_compile": true, "param_type": "required", "shape": "all"}, {"index": 3, "name": "y4", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lamb_next_m_v_with_decay.so", "compute_cost": 10, "kernel_name": "lamb_next_m_v_with_decay", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LambUpdateWithLR", "inputs": [{"index": 0, "name": "input1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "input3", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "input4", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "input5", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "input6", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "input7", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "input8", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 8, "name": "input9", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output_y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lamb_update_with_lr.so", "compute_cost": 10, "kernel_name": "lamb_update_with_lr", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Rsqrt", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "rsqrt.so", "compute_cost": 10, "kernel_name": "rsqrt", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Sigmoid", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sigmoid.so", "compute_cost": 10, "kernel_name": "sigmoid", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "SigmoidGrad", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sigmoid_grad.so", "compute_cost": 10, "kernel_name": "sigmoid_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ResizeNearestNeighbor", "inputs": [{"index": 0, "name": "images", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "align_corners", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "resize_nearest_neighbor_d.so", "compute_cost": 10, "kernel_name": "resize_nearest_neighbor_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ResizeNearestNeighborGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "align_corners", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "resize_nearest_neighbor_grad_d.so", "compute_cost": 10, "kernel_name": "resize_nearest_neighbor_grad_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Pad", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "paddings", "param_type": "optional", "type": "listListInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "pad_d.so", "compute_cost": 10, "kernel_name": "pad_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ArgMaxWithValue", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "indice", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "values", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "arg_max_with_value.so", "compute_cost": 10, "kernel_name": "arg_max_with_value", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ArgMinWithValue", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "indice", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "values", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "arg_min_with_value.so", "compute_cost": 10, "kernel_name": "arg_min_with_value", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SmoothL1Loss", "inputs": [{"index": 0, "name": "predict", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "label", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "loss", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "sigma", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "smooth_l1_loss.so", "compute_cost": 10, "kernel_name": "smooth_l1_loss", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SmoothL1LossGrad", "inputs": [{"index": 0, "name": "predict", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "label", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "dout", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "loss", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "sigma", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "smooth_l1_loss_grad.so", "compute_cost": 10, "kernel_name": "smooth_l1_loss_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "FusedMulAdd", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "x3", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fused_mul_add.so", "compute_cost": 10, "kernel_name": "fused_mul_add", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "FusedMulAddN", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "x3", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fused_mul_add_n.so", "compute_cost": 10, "kernel_name": "fused_mul_add_n", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "FusedMulApplyMomentum", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "momentum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_nesterov", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fused_mul_apply_momentum.so", "compute_cost": 10, "kernel_name": "fused_mul_apply_momentum", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Fill", "inputs": [{"index": 0, "name": "value", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "dims", "param_type": "required", "type": "listInt", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "FracZ"], ["int32", "FracZ"]], [["int32", "C1HWNCoC0"], ["int32", "C1HWNCoC0"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "FracZ"], ["int8", "FracZ"]], [["int8", "C1HWNCoC0"], ["int8", "C1HWNCoC0"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "FracZ"], ["uint8", "FracZ"]], [["uint8", "C1HWNCoC0"], ["uint8", "C1HWNCoC0"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fill_d.so", "compute_cost": 10, "kernel_name": "fill_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Erf", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "erf.so", "compute_cost": 10, "kernel_name": "erf", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Erfc", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "erfc.so", "compute_cost": 10, "kernel_name": "erfc", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "DepthwiseConv2dNative", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "filter", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "bias", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 3, "name": "offset_w", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "stride", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "dilation", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pads", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "data_format", "param_type": "required", "type": "str", "value": "all"}, {"name": "offset_a", "param_type": "optional", "type": "int", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "depthwise_conv2d.so", "compute_cost": 10, "kernel_name": "depthwise_conv2d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "DepthwiseConv2dNativeBackpropFilter", "inputs": [{"index": 0, "name": "input", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "out_backprop", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "filter_grad", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "filter_size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "stride", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "dilation", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pads", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "data_format", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float32", "C1HWNCoC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "depthwise_conv2d_backprop_filter_d.so", "compute_cost": 10, "kernel_name": "depthwise_conv2d_backprop_filter_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "DepthwiseConv2dNativeBackpropInput", "inputs": [{"index": 0, "name": "filter", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "out_backprop", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "input_grad", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "input_size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "stride", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "dilation", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pads", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "data_format", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "CONVLUTION", "dtype_format": [[["float16", "C1HWNCoC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "depthwise_conv2d_backprop_input_d.so", "compute_cost": 10, "kernel_name": "depthwise_conv2d_backprop_input_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "GreaterEqual", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""], ["bool", ""]], [["uint8", ""], ["uint8", ""], ["bool", ""]], [["int32", ""], ["int32", ""], ["bool", ""]], [["float16", ""], ["float16", ""], ["bool", ""]], [["float32", ""], ["float32", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "greater_equal.so", "compute_cost": 10, "kernel_name": "greater_equal", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "NotEqual", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""], ["bool", ""]], [["uint8", ""], ["uint8", ""], ["bool", ""]], [["int32", ""], ["int32", ""], ["bool", ""]], [["float16", ""], ["float16", ""], ["bool", ""]], [["float32", ""], ["float32", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "not_equal.so", "compute_cost": 10, "kernel_name": "not_equal", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "FloorMod", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]], [["int32", ""], ["int32", ""], ["int32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "floor_mod.so", "compute_cost": 10, "kernel_name": "floor_mod", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ScatterNdUpdate", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["bool", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_nd_update.so", "compute_cost": 10, "kernel_name": "scatter_nd_update", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "AvgPool", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "ksize", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "padding", "param_type": "required", "type": "str", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "avg_pool.so", "compute_cost": 10, "kernel_name": "avg_pool", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "AvgPoolGrad", "inputs": [{"index": 0, "name": "input_grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "mean_matrix", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 2, "name": "kernel_matrix", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "out_grad", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [{"name": "x_origin", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "ksize", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "padding", "param_type": "required", "type": "str", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "C1HWNCoC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "avg_pool_grad_d.so", "compute_cost": 10, "kernel_name": "avg_pool_grad_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "OnesLike", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["uint8", ""], ["uint8", ""]], [["int8", ""], ["int8", ""]], [["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "ones_like.so", "compute_cost": 10, "kernel_name": "ones_like", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "BatchToSpace", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "block_size", "param_type": "required", "type": "int", "value": "all"}, {"name": "crops", "param_type": "required", "type": "listListInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "batch_to_space_d.so", "compute_cost": 10, "kernel_name": "batch_to_space_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SpaceToBatch", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "block_size", "param_type": "required", "type": "int", "value": "all"}, {"name": "paddings", "param_type": "required", "type": "listListInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "space_to_batch_d.so", "compute_cost": 10, "kernel_name": "space_to_batch_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "DepthToSpace", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "block_size", "param_type": "required", "type": "int", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NHWC"], ["float16", "NHWC"]], [["float32", "NHWC"], ["float32", "NHWC"]], [["int8", "NHWC"], ["int8", "NHWC"]], [["int16", "NHWC"], ["int16", "NHWC"]], [["int32", "NHWC"], ["int32", "NHWC"]], [["int64", "NHWC"], ["int64", "NHWC"]], [["uint8", "NHWC"], ["uint8", "NHWC"]], [["uint16", "NHWC"], ["uint16", "NHWC"]], [["uint32", "NHWC"], ["uint32", "NHWC"]], [["uint64", "NHWC"], ["uint64", "NHWC"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "depth_to_space.so", "compute_cost": 10, "kernel_name": "depth_to_space", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SpaceToDepth", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "block_size", "param_type": "required", "type": "int", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NHWC"], ["float16", "NHWC"]], [["float32", "NHWC"], ["float32", "NHWC"]], [["int8", "NHWC"], ["int8", "NHWC"]], [["int16", "NHWC"], ["int16", "NHWC"]], [["int32", "NHWC"], ["int32", "NHWC"]], [["int64", "NHWC"], ["int64", "NHWC"]], [["uint8", "NHWC"], ["uint8", "NHWC"]], [["uint16", "NHWC"], ["uint16", "NHWC"]], [["uint32", "NHWC"], ["uint32", "NHWC"]], [["uint64", "NHWC"], ["uint64", "NHWC"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "space_to_depth.so", "compute_cost": 10, "kernel_name": "space_to_depth", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Floor", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "floor.so", "compute_cost": 10, "kernel_name": "floor", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Ceil", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "ceil.so", "compute_cost": 10, "kernel_name": "ceil", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Log1p", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "log1p.so", "compute_cost": 10, "kernel_name": "log1p", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "ResizeBilinear", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "size", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "align_corners", "param_type": "optional", "type": "bool", "value": "all"}, {"name": "half_pixel_centers", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "resize_bilinear_v2_d.so", "compute_cost": 10, "kernel_name": "resize_bilinear_v2_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ResizeBilinearGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "original_image", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "align_corners", "param_type": "optional", "type": "bool", "value": "all"}, {"name": "half_pixel_centers", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "resize_bilinear_v2_grad.so", "compute_cost": 10, "kernel_name": "resize_bilinear_v2_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Flatten", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "flatten.so", "compute_cost": 10, "kernel_name": "flatten", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ROIAlign", "inputs": [{"index": 0, "name": "features", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "rois", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "rois_n", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "spatial_scale", "param_type": "required", "type": "float", "value": "all"}, {"name": "pooled_height", "param_type": "required", "type": "int", "value": "all"}, {"name": "pooled_width", "param_type": "required", "type": "int", "value": "all"}, {"name": "sample_num", "param_type": "optional", "type": "int", "value": "all", "default_value": "2"}, {"name": "roi_end_mode", "param_type": "optional", "type": "0,1", "value": "1"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "roi_align.so", "compute_cost": 10, "kernel_name": "roi_align", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ROIAlignGrad", "inputs": [{"index": 0, "name": "ydiff", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "rois", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "rois_n", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "xdiff", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "xdiff_shape", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "pooled_width", "param_type": "required", "type": "int", "value": "all"}, {"name": "pooled_height", "param_type": "required", "type": "int", "value": "all"}, {"name": "spatial_scale", "param_type": "required", "type": "float", "value": "all"}, {"name": "sample_num", "param_type": "optional", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "roi_align_grad.so", "compute_cost": 10, "kernel_name": "roi_align_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BoundingBoxDecode", "inputs": [{"index": 0, "name": "rois", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "deltas", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "bboxes", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "means", "param_type": "optional", "type": "listFloat", "value": "all"}, {"name": "stds", "param_type": "optional", "type": "listFloat", "value": "all"}, {"name": "max_shape", "param_type": "optional", "type": "listInt", "value": "all"}, {"name": "wh_ratio_clip", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bounding_box_decode.so", "compute_cost": 10, "kernel_name": "bounding_box_decode", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BoundingBoxEncode", "inputs": [{"index": 0, "name": "anchor_box", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "ground_truth_box", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "delats", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "means", "param_type": "optional", "type": "listFloat", "value": "all"}, {"name": "stds", "param_type": "optional", "type": "listFloat", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bounding_box_encode.so", "compute_cost": 10, "kernel_name": "bounding_box_encode", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "CheckValid", "inputs": [{"index": 0, "name": "bbox_tensor", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "img_tas", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "valid_tensor", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float16", ""], ["int8", ""]], [["float16", ""], ["float16", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "check_valid.so", "compute_cost": 10, "kernel_name": "check_valid", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "IOU", "inputs": [{"index": 0, "name": "bboxes", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "gtboxes", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "overlap", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "mode", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "iou.so", "compute_cost": 10, "kernel_name": "iou", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Argmax", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}, {"name": "output_dtype", "param_type": "optional", "type": "type", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "arg_max_d.so", "compute_cost": 10, "kernel_name": "arg_max_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "NMSWithMask", "inputs": [{"index": 0, "name": "box_scores", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "selected_boxes", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 0, "name": "selected_idx", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 0, "name": "selected_mask", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "iou_threshold", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "nms_with_mask.so", "compute_cost": 10, "kernel_name": "nms_with_mask", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SGD", "inputs": [{"index": 0, "name": "parameters", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "gradient", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "learning_rate", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "momentum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "stat", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "parameters", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "dampening", "param_type": "optional", "type": "float", "value": "all"}, {"name": "weight_decay", "param_type": "optional", "type": "float", "value": "all"}, {"name": "nesterov", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sgd.so", "compute_cost": 10, "kernel_name": "sgd", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LARSUpdate", "inputs": [{"index": 0, "name": "w", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "g", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "w_square_sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "g_square_sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "weight_decay", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "learning_rate", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "g_new", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "hyperpara", "param_type": "optional", "type": "float", "value": "all"}, {"name": "epsilon", "param_type": "optional", "type": "float", "value": "all"}, {"name": "use_clip", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lars_v2_update.so", "compute_cost": 10, "kernel_name": "lars_v2_update", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Argmin", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}, {"name": "output_dtype", "param_type": "optional", "type": "type", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "arg_min_d.so", "compute_cost": 10, "kernel_name": "arg_min_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BNTrainingUpdateV2", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "square_sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "offset", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "batch_mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "batch_variance", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "epsilon", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float16", ""], ["float32", ""], ["float32", ""]], [["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_training_update_v2.so", "compute_cost": 10, "kernel_name": "bn_training_update_v2", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "BNTrainingUpdateV3", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "square_sum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "scale", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "offset", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NC"}, {"index": 1, "name": "batch_mean", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "batch_variance", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "reserve_1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "reserve_2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "epsilon", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bn_training_update_v3.so", "compute_cost": 10, "kernel_name": "bn_training_update_v3", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SquareSumAll", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "square_sum_all.so", "compute_cost": 10, "kernel_name": "square_sum", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Pack", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "dynamic", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["int8", "NDHWC"], ["int8", "NDHWC"]], [["int16", "NDHWC"], ["int16", "NDHWC"]], [["int32", "NDHWC"], ["int32", "NDHWC"]], [["int64", "NDHWC"], ["int64", "NDHWC"]], [["uint8", "NDHWC"], ["uint8", "NDHWC"]], [["uint16", "NDHWC"], ["uint16", "NDHWC"]], [["uint32", "NDHWC"], ["uint32", "NDHWC"]], [["uint64", "NDHWC"], ["uint64", "NDHWC"]], [["float16", "NDHWC"], ["float16", "NDHWC"]], [["float32", "NDHWC"], ["float32", "NDHWC"]], [["bool", "NDHWC"], ["bool", "NDHWC"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "pack.so", "compute_cost": 10, "kernel_name": "pack", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Unpack", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "dynamic", "shape": "all"}], "attr": [{"name": "num", "param_type": "optional", "type": "int", "value": "all"}, {"name": "axis", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int16", "NC1HWC0"], ["int16", "NC1HWC0"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int64", "NC1HWC0"], ["int64", "NC1HWC0"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint16", "NC1HWC0"], ["uint16", "NC1HWC0"]], [["uint32", "NC1HWC0"], ["uint32", "NC1HWC0"]], [["uint64", "NC1HWC0"], ["uint64", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "unpack.so", "compute_cost": 10, "kernel_name": "unpack", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ScatterUpdate", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["bool", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_update.so", "compute_cost": 10, "kernel_name": "scatter_update", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "PReLU", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "weight", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NCHW"], ["float16", "DefaultFormat"], ["float16", "NCHW"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NCHW"], ["float32", "DefaultFormat"], ["float32", "NCHW"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "prelu.so", "compute_cost": 10, "kernel_name": "prelu", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "PReLUGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "features", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "weights", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "dx", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 0, "name": "da", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float32", "NCHW"], ["float32", "NCHW"], ["float32", "DefaultFormat"], ["float32", "NCHW"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "prelu_grad.so", "compute_cost": 10, "kernel_name": "prelu_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BinaryCrossEntropy", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "weight", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "output", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "reduction", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["", ""], ["", ""], ["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "binary_cross_entropy.so", "compute_cost": 10, "kernel_name": "binary_cross_entropy", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "BinaryCrossEntropyGrad", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "grad_output", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "weight", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "output", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "reduction", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "binary_cross_entropy_grad.so", "compute_cost": 10, "kernel_name": "binary_cross_entropy_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Sin", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sin.so", "compute_cost": 10, "kernel_name": "sin", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Cos", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "cos.so", "compute_cost": 10, "kernel_name": "cos", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "CumSum", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "int", "value": "all", "default_value": "0"}, {"name": "exclusive", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "fales"}, {"name": "reverse", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "cumsum_d.so", "compute_cost": 10, "kernel_name": "cumsum_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ApplyRMSProp", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "ms", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "mom", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "ms", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "mom", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "rho", "param_type": "required", "type": "float", "value": "all"}, {"name": "momentum", "param_type": "required", "type": "float", "value": "all"}, {"name": "epsilon", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "apply_rms_prop.so", "compute_cost": 10, "kernel_name": "apply_rms_prop_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "CumProd", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "int", "value": "all"}, {"name": "exclusive", "param_type": "optional", "type": "bool", "value": "all"}, {"name": "reverse", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "cumprod_d.so", "compute_cost": 10, "kernel_name": "cumprod_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ReduceProd", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""]], [["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reduce_prod_d.so", "compute_cost": 10, "kernel_name": "reduce_prod_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "reduce"} +{"op_name": "FlattenGrad", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "shape", "param_type": "required", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reshape.so", "compute_cost": 10, "kernel_name": "reshape", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ScatterAdd", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_add.so", "compute_cost": 10, "kernel_name": "scatter_add", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Atan2", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "atan2.so", "compute_cost": 10, "kernel_name": "atan2", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "BesselI0e", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bessel_i0e.so", "compute_cost": 10, "kernel_name": "bessel_i0e", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "BesselI1e", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bessel_i1e.so", "compute_cost": 10, "kernel_name": "bessel_i1e", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "BatchToSpaceND", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NH"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NH"}], "attr": [{"name": "block_shape", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "crops", "param_type": "required", "type": "listListInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "batch_to_space_nd_d.so", "compute_cost": 10, "kernel_name": "batch_to_space_nd_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SpaceToBatchND", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NH"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all", "reshape_type": "NH"}], "attr": [{"name": "block_shape", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "paddings", "param_type": "required", "type": "listListInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "space_to_batch_nd_d.so", "compute_cost": 10, "kernel_name": "space_to_batch_nd_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BitwiseAnd", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int16", ""], ["int16", ""], ["int16", ""]], [["uint16", ""], ["uint16", ""], ["uint16", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bitwise_and.so", "compute_cost": 10, "kernel_name": "bitwise_and", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "BitwiseOr", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int16", ""], ["int16", ""], ["int16", ""]], [["uint16", ""], ["uint16", ""], ["uint16", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bitwise_or.so", "compute_cost": 10, "kernel_name": "bitwise_or", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "BitwiseXor", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int16", ""], ["int16", ""], ["int16", ""]], [["uint16", ""], ["uint16", ""], ["uint16", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "bitwise_xor.so", "compute_cost": 10, "kernel_name": "bitwise_xor", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "ReduceAll", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "keep_dims", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["bool", ""], ["bool", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "reduce_all_d.so", "compute_cost": 10, "kernel_name": "reduce_all_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "reduce"} +{"op_name": "SparseApplyAdagrad", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "lr", "param_type": "required", "type": "float", "value": "all"}, {"name": "update_slots", "param_type": "optional", "type": "bool", "value": "all"}, {"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["int32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]], [["float32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"], ["int32", "NHWC"], ["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sparse_apply_adagrad_d.so", "compute_cost": 10, "kernel_name": "sparse_apply_adagrad_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "UnsortedSegmentMin", "inputs": [{"index": 0, "name": "data", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "segment_ids", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "num_segments", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["int32", "DefaultFormat"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["int32", "DefaultFormat"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["int32", "DefaultFormat"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["int32", "DefaultFormat"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["int32", "DefaultFormat"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["int32", "DefaultFormat"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "DefaultFormat"], ["int32", "NC1HWC0"]], [["int32", "FracZ"], ["int32", "DefaultFormat"], ["int32", "FracZ"]], [["int32", "C1HWNCoC0"], ["int32", "DefaultFormat"], ["int32", "C1HWNCoC0"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "unsorted_segment_min_d.so", "compute_cost": 10, "kernel_name": "unsorted_segment_min_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Asin", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "asin.so", "compute_cost": 10, "kernel_name": "asin", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "AsinGrad", "inputs": [{"index": 0, "name": "y", "param_type": "required", "shape": "all"}, {"index": 1, "name": "dy", "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "asin_grad.so", "compute_cost": 10, "kernel_name": "asin_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Asinh", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "asinh.so", "compute_cost": 10, "kernel_name": "asinh", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "AsinhGrad", "inputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "asinh_grad.so", "compute_cost": 10, "kernel_name": "asinh_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "DivNoNan", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""], ["uint8", ""]], [["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "div_no_nan.so", "compute_cost": 10, "kernel_name": "div_no_nan", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "Atan", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "atan.so", "compute_cost": 10, "kernel_name": "atan", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "AtanGrad", "inputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "dy", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["float16", "FRACTAL_NZ"], ["float16", "FracZ"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["float32", "FRACTAL_NZ"], ["float32", "FracZ"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "atan_grad.so", "compute_cost": 10, "kernel_name": "atan_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Atanh", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "atanh.so", "compute_cost": 10, "kernel_name": "atanh", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Cosh", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "cosh.so", "compute_cost": 10, "kernel_name": "cosh", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Sinh", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": true, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "sinh.so", "compute_cost": 10, "kernel_name": "sinh", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "Inv", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int32", ""], ["int32", ""]], [["float32", ""], ["float32", ""]], [["float16", ""], ["float16", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "inv.so", "compute_cost": 10, "kernel_name": "inv", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "InvGrad", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]], [["int32", ""], ["int32", ""], ["int32", ""]], [["int8", ""], ["int8", ""], ["int8", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "inv_grad.so", "compute_cost": 10, "kernel_name": "inv_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "Invert", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int16", ""], ["int16", ""]], [["uint16", ""], ["uint16", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "invert.so", "compute_cost": 10, "kernel_name": "invert", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "formatAgnostic"} +{"op_name": "BasicLSTMCell", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "h", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "c", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "w", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "b", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "mask", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "ct", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "ht", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "it", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 3, "name": "jt", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 4, "name": "ft", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 5, "name": "ot", "need_compile": false, "param_type": "optional", "shape": "all"}, {"index": 6, "name": "tanhct", "need_compile": false, "param_type": "optional", "shape": "all"}], "attr": [{"name": "keep_prob", "param_type": "optional", "type": "float", "value": "all"}, {"name": "forget_bias", "param_type": "optional", "type": "float", "value": "all"}, {"name": "state_is_tuple", "param_type": "optional", "type": "bool", "value": "true"}, {"name": "activation", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float16", "FracZ"], ["float32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["float32", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["uint8", "DefaultFormat"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "basic_lstm_cell.so", "compute_cost": 10, "kernel_name": "basic_lstm_cell", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BasicLSTMCellCStateGrad", "inputs": [{"index": 0, "name": "c", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "dht", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "dct", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "it", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "jt", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "ft", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "ot", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 7, "name": "tanhct", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "dgate", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "dct_1", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "forget_bias", "param_type": "optional", "type": "float", "value": "all"}, {"name": "activation", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "basic_lstm_cell_c_state_grad.so", "compute_cost": 10, "kernel_name": "basic_lstm_cell_c_state_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BasicLSTMCellWeightGrad", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "h", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "dgate", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "dw", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "db", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FracZ"], ["float32", "DefaultFormat"]], [["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "basic_lstm_cell_weight_grad.so", "compute_cost": 10, "kernel_name": "basic_lstm_cell_weight_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BasicLSTMCellInputGrad", "inputs": [{"index": 0, "name": "dgate", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "w", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "dropout_mask", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "dxt", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "dht", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "keep_prob", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "FRACTAL_NZ"], ["float16", "FracZ"], ["uint8", "DefaultFormat"], ["float32", "FRACTAL_NZ"], ["float32", "FRACTAL_NZ"]], [["float16", "FRACTAL_NZ"], ["float16", "FracZ"], ["uint8", "DefaultFormat"], ["float16", "FRACTAL_NZ"], ["float16", "FRACTAL_NZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "basic_lstm_cell_input_grad.so", "compute_cost": 10, "kernel_name": "basic_lstm_cell_input_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ConfusionMatrix", "inputs": [{"index": 0, "name": "labels", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "predictions", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "weights", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "num_classes", "param_type": "required", "type": "int", "value": "all"}, {"name": "dtype", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "confusion_matrix.so", "compute_cost": 10, "kernel_name": "confusion_matrix", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "BroadcastTo", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "shape", "param_type": "required", "type": "listInt", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["uint16", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "broadcast_to_d.so", "compute_cost": 10, "kernel_name": "broadcast_to_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "StridedRead", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}, {"name": "stride", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "strided_read.so", "compute_cost": 10, "kernel_name": "strided_read", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "StridedWrite", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "required", "type": "int", "value": "all"}, {"name": "stride", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "strided_write.so", "compute_cost": 10, "kernel_name": "strided_write", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Range", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "start", "param_type": "required", "type": "float", "value": "all"}, {"name": "limit", "param_type": "required", "type": "float", "value": "all"}, {"name": "delta", "param_type": "required", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "range_d.so", "compute_cost": 10, "kernel_name": "range_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "FusedMulAddNL2loss", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "x3", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"]], [["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fused_mul_addn_l2loss.so", "compute_cost": 10, "kernel_name": "fused_mul_addn_l2loss", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "FusedMulApplyMomentumExtern", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "lr", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "momentum", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 6, "name": "var_copy", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "var_copy", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "accum", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_nesterov", "param_type": "optional", "type": "bool", "value": "true,false", "default_value": "false"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float16", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "FracZ"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "FracZ"], ["float32", "FracZ"], ["float16", "FracZ"], ["float16", "FracZ"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "NC1HWC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "NC1HWC0"], ["float32", "NC1HWC0"], ["float16", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "C1HWNCoC0"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "C1HWNCoC0"], ["float32", "C1HWNCoC0"], ["float16", "C1HWNCoC0"], ["float32", "C1HWNCoC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "FracZ"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "FracZ"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float16", "FracZ"], ["float32", "FracZ"], ["float16", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "fused_mul_apply_momentum_extern.so", "compute_cost": 10, "kernel_name": "fused_mul_apply_momentum_extern", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LambNextRight", "inputs": [{"index": 0, "name": "input_square", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "input_mul2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "mul2_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "mul3_x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 4, "name": "truediv1_recip", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 5, "name": "add2_y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "y2", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lamb_next_right.so", "compute_cost": 10, "kernel_name": "lamb_next_right", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SparseGatherV2", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "axis", "param_type": "optional", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int64", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int32", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "NC1HWC0"], ["int64", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "FracZ"], ["int32", "FracZ"], ["int8", "FracZ"]], [["int8", "FracZ"], ["int64", "FracZ"], ["int8", "FracZ"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int64", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["int32", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "NC1HWC0"], ["int64", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "FracZ"], ["int32", "FracZ"], ["uint8", "FracZ"]], [["uint8", "FracZ"], ["int64", "FracZ"], ["uint8", "FracZ"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int64", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "NC1HWC0"], ["int64", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "FracZ"], ["int32", "FracZ"], ["int32", "FracZ"]], [["int32", "FracZ"], ["int64", "FracZ"], ["int32", "FracZ"]], [["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "DefaultFormat"], ["int64", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["int32", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["int64", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "FracZ"], ["int32", "FracZ"], ["float16", "FracZ"]], [["float16", "FracZ"], ["int64", "FracZ"], ["float16", "FracZ"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int64", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["int32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "NC1HWC0"], ["int64", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "FracZ"], ["int32", "FracZ"], ["float32", "FracZ"]], [["float32", "FracZ"], ["int64", "FracZ"], ["float32", "FracZ"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "gather_v2_d.so", "compute_cost": 10, "kernel_name": "gather_v2_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "DataFormatDimMap", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "dst_format", "param_type": "optional", "type": "str", "value": "all"}, {"name": "src_format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "data_format_dim_map.so", "compute_cost": 10, "kernel_name": "data_format_dim_map", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "HistogramFixedWidth", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "range", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "nbins", "param_type": "required", "type": "int", "value": "all"}, {"name": "dtype", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["int32", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "histogram_fixed_width_d.so", "compute_cost": 10, "kernel_name": "histogram_fixed_width_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "TensorScatterUpdate", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "tensor_scatter_update.so", "compute_cost": 10, "kernel_name": "tensor_scatter_update", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "InplaceUpdate", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "v", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "indices", "param_type": "required", "type": "listInt", "value": "all"}], "fusion_type": "INPLACE", "dtype_format": [[["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "inplace_update_d.so", "compute_cost": 10, "kernel_name": "inplace_update_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "SplitV", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "dynamic", "shape": "all"}], "attr": [{"name": "size_splits", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "split_dim", "param_type": "required", "type": "int", "value": "all"}, {"name": "num_split", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["", ""], ["", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "split_v_d.so", "compute_cost": 10, "kernel_name": "split_v_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "dynamicFormat"} +{"op_name": "InTopK", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "k", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["bool", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "in_top_k.so", "compute_cost": 10, "kernel_name": "in_top_k", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LinSpace", "inputs": [{"index": 0, "name": "assist", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "start", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "stop", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 3, "name": "num", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "output", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["float32", ""], ["float32", ""], ["float32", ""], ["int32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lin_space.so", "compute_cost": 10, "kernel_name": "lin_space", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "MatrixDiag", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "assist", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "matrix_diag_d.so", "compute_cost": 10, "kernel_name": "matrix_diag_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "MatrixDiagPart", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "assist", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "matrix_diag_part_d.so", "compute_cost": 10, "kernel_name": "matrix_diag_part_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "MatrixSetDiag", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "diagonal", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "assist", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "matrix_diag_d.so", "compute_cost": 10, "kernel_name": "matrix_diag_d", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LRN", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "depth_radius", "param_type": "optional", "type": "int", "value": "all", "default_value": "5"}, {"name": "bias", "param_type": "optional", "type": "float", "value": "all", "default_value": "1.0"}, {"name": "alpha", "param_type": "optional", "type": "float", "value": "all", "default_value": "1.0"}, {"name": "beta", "param_type": "optional", "type": "float", "value": "all", "default_value": "0.5"}, {"name": "norm_region", "param_type": "optional", "type": "str", "value": "all", "default_value": "ACROSS_CHANNELS"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NCHW"], ["float16", "NCHW"]], [["float32", "NCHW"], ["float32", "NCHW"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lrn.so", "compute_cost": 10, "kernel_name": "lrn", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "LRNGrad", "inputs": [{"index": 0, "name": "grads", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "z", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "depth_radius", "param_type": "optional", "type": "int", "value": "all"}, {"name": "bias", "param_type": "optional", "type": "float", "value": "all"}, {"name": "alpha", "param_type": "optional", "type": "float", "value": "all"}, {"name": "beta", "param_type": "optional", "type": "float", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NCHW"], ["float16", "NCHW"], ["float16", "NCHW"], ["float16", "NCHW"]], [["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"], ["float32", "NCHW"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "lrn_grad.so", "compute_cost": 10, "kernel_name": "lrn_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ScatterMax", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_max.so", "compute_cost": 10, "kernel_name": "scatter_max", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ScatterMin", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_min.so", "compute_cost": 10, "kernel_name": "scatter_min", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ScatterSub", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_sub.so", "compute_cost": 10, "kernel_name": "scatter_sub", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ScatterMul", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_mul.so", "compute_cost": 10, "kernel_name": "scatter_mul", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ScatterDiv", "inputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "indices", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "updates", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "var", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "use_locking", "param_type": "optional", "type": "bool", "value": "all"}], "fusion_type": "ELEMWISE", "dtype_format": [[["float16", "DefaultFormat"], ["int32", "DefaultFormat"], ["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float32", "DefaultFormat"], ["int32", "DefaultFormat"], ["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int8", "DefaultFormat"], ["int32", "DefaultFormat"], ["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["uint8", "DefaultFormat"], ["int32", "DefaultFormat"], ["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "scatter_div.so", "compute_cost": 10, "kernel_name": "scatter_div", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "Mod", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "ELEMWISE", "dtype_format": [[["int8", ""], ["int8", ""], ["int8", ""]], [["uint8", ""], ["uint8", ""], ["uint8", ""]], [["int32", ""], ["int32", ""], ["int32", ""]], [["float16", ""], ["float16", ""], ["float16", ""]], [["float32", ""], ["float32", ""], ["float32", ""]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "mod.so", "compute_cost": 10, "kernel_name": "mod", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": "broadcast"} +{"op_name": "MaxPoolGradGrad", "inputs": [{"index": 0, "name": "x1", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "x2", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "ksize", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "padding", "param_type": "required", "type": "str", "value": "all"}, {"name": "data_format", "param_type": "optional", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_grad_grad.so", "compute_cost": 10, "kernel_name": "max_pool_grad_grad", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "MaxPoolGradGradWithArgmax", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 1, "name": "grad", "need_compile": false, "param_type": "required", "shape": "all"}, {"index": 2, "name": "argmax", "need_compile": false, "param_type": "optional", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "ksize", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "strides", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "padding", "param_type": "required", "type": "str", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["uint16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"], ["int64", "NC1HWC0"], ["float16", "NC1HWC0"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "max_pool_grad_grad_with_argmax.so", "compute_cost": 10, "kernel_name": "max_pool_grad_grad_with_argmax", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "PopulationCount", "inputs": [{"index": 0, "name": "x", "need_compile": false, "param_type": "required", "shape": "all"}], "outputs": [{"index": 0, "name": "y", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [], "fusion_type": "OPAQUE", "dtype_format": [[["int16", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["int16", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint16", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["uint16", "DefaultFormat"], ["uint8", "DefaultFormat"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "population_count.so", "compute_cost": 10, "kernel_name": "population_count", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} +{"op_name": "ParallelConcat", "inputs": [{"index": 0, "name": "values", "need_compile": false, "param_type": "dynamic", "shape": "all"}], "outputs": [{"index": 0, "name": "output_data", "need_compile": false, "param_type": "required", "shape": "all"}], "attr": [{"name": "shape", "param_type": "required", "type": "listInt", "value": "all"}, {"name": "N", "param_type": "required", "type": "int", "value": "all"}], "fusion_type": "OPAQUE", "dtype_format": [[["bool", "DefaultFormat"], ["bool", "DefaultFormat"]], [["bool", "NC1HWC0"], ["bool", "NC1HWC0"]], [["int8", "DefaultFormat"], ["int8", "DefaultFormat"]], [["int8", "NC1HWC0"], ["int8", "NC1HWC0"]], [["uint8", "DefaultFormat"], ["uint8", "DefaultFormat"]], [["uint8", "NC1HWC0"], ["uint8", "NC1HWC0"]], [["int16", "DefaultFormat"], ["int16", "DefaultFormat"]], [["int16", "NC1HWC0"], ["int16", "NC1HWC0"]], [["uint16", "DefaultFormat"], ["uint16", "DefaultFormat"]], [["uint16", "NC1HWC0"], ["uint16", "NC1HWC0"]], [["int32", "DefaultFormat"], ["int32", "DefaultFormat"]], [["int32", "NC1HWC0"], ["int32", "NC1HWC0"]], [["uint32", "DefaultFormat"], ["uint32", "DefaultFormat"]], [["uint32", "NC1HWC0"], ["uint32", "NC1HWC0"]], [["int64", "DefaultFormat"], ["int64", "DefaultFormat"]], [["int64", "NC1HWC0"], ["int64", "NC1HWC0"]], [["uint64", "DefaultFormat"], ["uint64", "DefaultFormat"]], [["uint64", "NC1HWC0"], ["uint64", "NC1HWC0"]], [["float16", "DefaultFormat"], ["float16", "DefaultFormat"]], [["float16", "NC1HWC0"], ["float16", "NC1HWC0"]], [["float32", "DefaultFormat"], ["float32", "DefaultFormat"]], [["float32", "NC1HWC0"], ["float32", "NC1HWC0"]], [["bool", "NHWC"], ["bool", "NHWC"]], [["bool", "NCHW"], ["bool", "NCHW"]], [["int8", "NHWC"], ["int8", "NHWC"]], [["int8", "NCHW"], ["int8", "NCHW"]], [["uint8", "NHWC"], ["uint8", "NHWC"]], [["uint8", "NCHW"], ["uint8", "NCHW"]], [["int16", "NHWC"], ["int16", "NHWC"]], [["int16", "NCHW"], ["int16", "NCHW"]], [["uint16", "NHWC"], ["uint16", "NHWC"]], [["uint16", "NCHW"], ["uint16", "NCHW"]], [["int32", "NHWC"], ["int32", "NHWC"]], [["int32", "NCHW"], ["int32", "NCHW"]], [["uint32", "NHWC"], ["uint32", "NHWC"]], [["uint32", "NCHW"], ["uint32", "NCHW"]], [["int64", "NHWC"], ["int64", "NHWC"]], [["int64", "NCHW"], ["int64", "NCHW"]], [["uint64", "NHWC"], ["uint64", "NHWC"]], [["uint64", "NCHW"], ["uint64", "NCHW"]], [["float16", "NHWC"], ["float16", "NHWC"]], [["float16", "NCHW"], ["float16", "NCHW"]], [["float32", "NHWC"], ["float32", "NHWC"]], [["float32", "NCHW"], ["float32", "NCHW"]]], "imply_type": "TBE", "async_flag": false, "binfile_name": "parallel_concat.so", "compute_cost": 10, "kernel_name": "parallel_concat", "partial_flag": true, "reshape_type": "", "dynamic_format": false, "op_pattern": ""} diff --git a/graphengine b/graphengine index 4084909d62..31aa96ef41 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 4084909d62c159da6ba316f61ad3d02a4857b34b +Subproject commit 31aa96ef41067a0ecdc4113ef245f8ede48f3457 diff --git a/include/ms_tensor.h b/include/ms_tensor.h index 1f9661df5e..fc59e12328 100644 --- a/include/ms_tensor.h +++ b/include/ms_tensor.h @@ -20,7 +20,7 @@ #include #include #include -#include "ir/dtype/type_id.h" +#include "mindspore/core/ir/dtype/type_id.h" namespace mindspore { #define MS_API __attribute__((visibility("default"))) diff --git a/mindspore/_extends/parse/parser.py b/mindspore/_extends/parse/parser.py index a6043eb787..9d715fdf53 100644 --- a/mindspore/_extends/parse/parser.py +++ b/mindspore/_extends/parse/parser.py @@ -334,7 +334,7 @@ class Parser: def __init__(self, fn: (types.FunctionType, types.MethodType), parse_method=None) -> None: self.fn = fn self.parse_method = parse_method - _, self.line_offset = inspect.getsourcelines(self.fn) + self.line_offset = 0 self.filename: str = inspect.getfile(self.fn) # Used to resolve the function's globals Namespace. @@ -350,7 +350,8 @@ class Parser: logger.debug("fn = %r", self.fn) tree = None if isinstance(self.fn, (types.FunctionType, types.MethodType)): - original_src = inspect.getsource(self.fn) + lines, self.line_offset = inspect.getsourcelines(self.fn) + original_src = ''.join(lines) hexstr = hashlib.sha256(original_src.encode()).hexdigest() tree = Parser.ast_cache.get(hexstr) if not tree: diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 936099a4fb..d70c6edcf4 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -108,7 +108,8 @@ def enumerate_(x, start=0): """Enumerate list or tuple.""" x_type = F.typeof(x) ret = () - if check_is_tuple_or_list(x_type, "enumerate"): + op_name = "enumerate" + if check_is_tuple_or_list(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"): ret = zip(range(start, start + len(x)), x) return ret @@ -123,11 +124,22 @@ def while_cond(x): @constexpr -def check_is_tuple_or_list(x, op_name): +def check_is_tuple_or_list(x, op_name, arg_name): """check whether x is list or tuple.""" if isinstance(x, (mstype.list_type, mstype.tuple_type)): return True - raise TypeError(f"For '{op_name}', the input parameter should be tuple or list, but got {x}.") + raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.") + + +@constexpr +def check_is_const_int(x, op_name, arg_name): + """check whether x is const int.""" + if x is None: + raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.") + if not isinstance(x, int): + raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.") + return True + @constexpr def check_is_tensor_bool_cond(shp): diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 9dc1502aa5..bb02f338f6 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -1,4 +1,5 @@ ## common setting +include_directories(${CMAKE_SOURCE_DIR}/mindspore/core) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_BINARY_DIR}) link_directories(${CMAKE_SOURCE_DIR}/build/mindspore/graphengine) @@ -35,20 +36,20 @@ if(ENABLE_GPU) include_directories(${CUDNN_PATH} ${CUDA_PATH} ${CUDA_INCLUDE_DIRS}) file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "device/gpu/*.cc" - "device/gpu/*.cu" - "kernel/gpu/*.cu" - "kernel/akg/gpu/*.cc" - "kernel/akg/akg_kernel_build.cc" - "kernel/akg/akg_kernel_attrs_process.cc" + "runtime/device/gpu/*.cc" + "runtime/device/gpu/*.cu" + "backend/kernel_compiler/gpu/*.cu" + "backend/kernel_compiler/akg/gpu/*.cc" + "backend/kernel_compiler/akg/akg_kernel_build.cc" + "backend/kernel_compiler/akg/akg_kernel_attrs_process.cc" ) list(APPEND CUDA_NVCC_FLAGS -arch=sm_53) - list(REMOVE_ITEM GPU_SRC_LIST "device/gpu/blocking_queue.cc" "device/gpu/gpu_buffer_mgr.cc") - list(REMOVE_ITEM GPU_SRC_LIST "device/gpu/mpi/mpi_initializer.cc" - "device/gpu/distribution/collective_wrapper.cc" - "device/gpu/distribution/mpi_wrapper.cc" - "device/gpu/distribution/nccl_wrapper.cc" + list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/blocking_queue.cc" "runtime/device/gpu/gpu_buffer_mgr.cc") + list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/mpi/mpi_initializer.cc" + "runtime/device/gpu/distribution/collective_wrapper.cc" + "runtime/device/gpu/distribution/mpi_wrapper.cc" + "runtime/device/gpu/distribution/nccl_wrapper.cc" ) set(NVCC_TMP_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) @@ -56,6 +57,7 @@ if(ENABLE_GPU) set_property(SOURCE ${GPU_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) cuda_add_library(gpu_cuda_lib STATIC ${GPU_SRC_LIST}) set(CMAKE_CXX_FLAGS ${NVCC_TMP_CMAKE_CXX_FLAGS}) + add_compile_definitions(ENABLE_GPU) endif () ## make flatuffer files @@ -101,16 +103,20 @@ if (ENABLE_DUMP_PROTO) endif () if (ENABLE_D) - include_directories("${CMAKE_BINARY_DIR}/kernel/aicpu") + include_directories("${CMAKE_BINARY_DIR}/backend/kernel_compiler/aicpu") include_directories("${CMAKE_BINARY_DIR}/predict/generator/ir") - file(GLOB_RECURSE PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "kernel/aicpu/proto/*.proto") + file(GLOB_RECURSE PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "backend/kernel_compiler/aicpu/proto/*.proto") ms_protobuf_generate(PROTOSRCS PROTOHDRS ${PROTO_IN}) file(GLOB_RECURSE PROTO_INNER RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "predict/proto/*.proto") ms_protobuf_generate(PREDICT_PROTOSRCS PREDICT_PROTOHDRS ${PROTO_INNER}) + file(GLOB_RECURSE PROTO_DUMP RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "runtime/device/ascend/dump/proto/*.proto") + ms_protobuf_generate(DUMP_PROTOSRCS PROTOHDRS ${PROTO_DUMP}) + list(APPEND MINDSPORE_PROTO_LIST ${PROTOSRCS}) list(APPEND MINDSPORE_PROTO_LIST ${PREDICT_PROTOSRCS}) + list(APPEND MINDSPORE_PROTO_LIST ${DUMP_PROTOSRCS}) add_compile_definitions(ENABLE_D) endif () @@ -121,18 +127,36 @@ if (MINDSPORE_PROTO_LIST) endif() ## make sub objects -set(SUB_COMP - transform pre_activate parallel pipeline device kernel common debug gvar ir onnx operator optimizer predict - pybind_api pynative session utils vm +set(SUB_COMP + transform/graph_ir + transform/onnx + backend/optimizer + backend/kernel_compiler + backend/session + runtime/device + frontend/optimizer + frontend/parallel + frontend/operator + pipeline/jit + pipeline/pynative + common debug gvar predict pybind_api utils vm ) foreach (_comp ${SUB_COMP}) add_subdirectory(${_comp}) - if (TARGET _mindspore_${_comp}_obj) - list(APPEND SUB_OBJECTS_SRC $) - add_dependencies(_mindspore_${_comp}_obj proto_input flat_input) + string(REPLACE "/" "_" sub ${_comp}) + if (TARGET _mindspore_${sub}_obj) + list(APPEND SUB_OBJECTS_SRC $) + add_dependencies(_mindspore_${sub}_obj proto_input flat_input) endif () endforeach () +add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/base base) +list(APPEND SUB_OBJECTS_SRC $) +add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/abstract abstract) +list(APPEND SUB_OBJECTS_SRC $) +add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/ir ir) +list(APPEND SUB_OBJECTS_SRC $) +add_dependencies(_mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj proto_input flat_input) set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME) add_library(mindspore STATIC ${SUB_OBJECTS_SRC}) @@ -204,8 +228,8 @@ endif() # set c_expression building set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) -set_property(SOURCE "pipeline/init.cc" PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) -pybind11_add_module(_c_expression "pipeline/init.cc") +set_property(SOURCE "pipeline/jit/init.cc" PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) +pybind11_add_module(_c_expression "pipeline/jit/init.cc") MESSAGE(STATUS "operation system is ${CMAKE_SYSTEM}") if (CMAKE_SYSTEM_NAME MATCHES "Linux") @@ -231,9 +255,11 @@ else () target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive) target_link_libraries(_c_expression PRIVATE mindspore::pybind11_module) target_link_libraries(_c_expression PRIVATE mindspore_gvar) - target_link_libraries(_c_expression PRIVATE mindspore::pslite mindspore::protobuf ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) - if (${ENABLE_IBVERBS} STREQUAL "ON") - target_link_libraries(_c_expression PRIVATE ibverbs rdmacm) + if (NOT ENABLE_GE) + target_link_libraries(_c_expression PRIVATE mindspore::pslite mindspore::protobuf ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) + if (${ENABLE_IBVERBS} STREQUAL "ON") + target_link_libraries(_c_expression PRIVATE ibverbs rdmacm) + endif() endif() endif () @@ -260,8 +286,8 @@ if (ENABLE_CPU) endif () if (ENABLE_MINDDATA) - add_subdirectory(mindrecord) - add_subdirectory(dataset) + add_subdirectory(minddata/mindrecord) + add_subdirectory(minddata/dataset) endif () # build inference @@ -270,7 +296,7 @@ set(LOAD_ONNX_SRC ${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc ) add_library(inference SHARED - ${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/session.cc ${LOAD_ONNX_SRC} ) target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt new file mode 100644 index 0000000000..b412d83d11 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt @@ -0,0 +1,66 @@ +file(GLOB_RECURSE KERNEL_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "kernel_build_info.cc" + "kash/*.cc" + "common_utils.cc" + "oplib/*.cc" +) + +if (ENABLE_D) + file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "kernel_query.cc" + "kernel_fusion.cc" + "akg/ascend/*.cc" + "akg/akg_kernel_build.cc" + "akg/akg_kernel_attrs_process.cc" + "akg/akg_kernel_metadata.cc" + "tbe/*.cc" + "aicpu/*.cc" + "rts/*.cc" + "hccl/*.cc" + ) + add_compile_definitions(ENABLE_D) +endif () + +if (ENABLE_CPU) + file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "cpu/*.cc" + ) + + list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc" + "cpu/ps/pull_kernel.cc" + "cpu/ps/embedding_look_up_ps_kernel.cc" + "cpu/ps/embedding_look_up_proxy_kernel.cc" + "cpu/ps/apply_momentum_ps_kernel.cc" + "cpu/ps/sparse_apply_adam_ps_kernel.cc" + "cpu/ps/sparse_apply_ftrl_ps_kernel.cc") + + if (NOT ENABLE_MPI) + list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_comm_grad_cpu_kernel.cc") + endif () +endif () + +if (ENABLE_GPU) + file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "gpu/*.cu" + "akg/gpu/*.cc" + "akg/akg_kernel_build.cc" + "akg/akg_kernel_attrs_process.cc" + ) + + file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc") + list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_gpu_kernel.cc") + + if (ENABLE_MPI) + include(ExternalProject) + file(GLOB_RECURSE GPU_NCCL_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/nccl/*.cc") + list(APPEND GPU_SRC_LIST ${GPU_NCCL_LIST}) + endif () + + # add_library(_mindspore_kernel_cuda_obj OBJECT ${CUDA_SRC_LIST}) +endif() + +set_property(SOURCE ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_KERNEL) +add_library(_mindspore_backend_kernel_compiler_obj OBJECT ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST}) diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc new file mode 100644 index 0000000000..7e7fd20f39 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc @@ -0,0 +1,312 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "runtime/device/kernel_runtime.h" +#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include "proto/tensor.pb.h" +#include "proto/tensor_shape.pb.h" +#include "proto/attr.pb.h" +#include "proto/node_def.pb.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +using FNodeAttrHandle = std::function &anf_node, mindspore::NodeDef *proto)>; + +bool SetIOIputSize(const std::shared_ptr &anf_node, const size_t &input_num, + std::vector *input_size_list) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_size_list); + for (size_t i = 0; i < input_num; i++) { + std::vector shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i); + if (AnfAlgo::GetInputDeviceDataType(anf_node, i) == kObjectTypeString) { + if (!anf_node->isa()) { + MS_LOG(EXCEPTION) << "anf_node is not CNode."; + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() < (i + 1)) { + MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1; + return false; + } + auto input_node = cnode->inputs()[i + 1]; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa()) { + auto value_ptr = GetValueNode(input_node); + auto value = GetValue(value_ptr); + input_size_list->push_back(value.size()); + } + } else { + auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); + MS_EXCEPTION_IF_NULL(type_ptr); + int64_t size_i = 1; + for (size_t j = 0; j < shape_i.size(); j++) { + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); + } + size_t type_byte = GetTypeByte(type_ptr); + if (type_byte == 0) { + return false; + } + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); + input_size_list->push_back(LongToSize(size_i)); + } + } + return true; +} + +bool SetIOSize(const std::shared_ptr &anf_node, const std::shared_ptr &kernel_mod_ptr) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + std::vector input_size_list; + std::vector output_size_list; + size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); + + if (!SetIOIputSize(anf_node, input_num, &input_size_list)) { + return false; + } + kernel_mod_ptr->SetInputSizeList(input_size_list); + + for (size_t i = 0; i < output_num; i++) { + std::vector shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); + TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); + MS_EXCEPTION_IF_NULL(type_ptr); + int64_t size_i = 1; + for (size_t j = 0; j < shape_i.size(); j++) { + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); + } + size_t type_byte = GetTypeByte(type_ptr); + if (type_byte == 0) { + return false; + } + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); + output_size_list.push_back(LongToSize(size_i)); + } + kernel_mod_ptr->SetOutputSizeList(output_size_list); + return true; +} + +void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value, + ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) { + MS_EXCEPTION_IF_NULL(node_attr); + MS_EXCEPTION_IF_NULL(value); + if (type == "int") { + auto attr_value = GetValue(value); + (*node_attr)[attr_name].set_i(attr_value); + } else if (type == "str") { + auto attr_value = GetValue(value); + (*node_attr)[attr_name].set_s(attr_value); + } else if (type == "bool") { + auto attr_value = GetValue(value); + (*node_attr)[attr_name].set_b(attr_value); + } else if (type == "float") { + auto attr_value = GetValue(value); + (*node_attr)[attr_name].set_f(attr_value); + } else if (type == "listInt") { + std::vector attr_value; + auto value_type = value->type(); + MS_EXCEPTION_IF_NULL(value_type); + auto value_type_str = value_type->ToString(); + if (value_type_str == "Int32") { + int data = GetValue(value); + attr_value.push_back(data); + } else { + attr_value = GetValue>(value); + } + mindspore::AttrValue input_shape_attr; + mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array(); + MS_EXCEPTION_IF_NULL(input_shape_attr_list); + for (const auto shape : attr_value) { + input_shape_attr_list->add_i(shape); + } + (*node_attr)[attr_name] = input_shape_attr; + } else { + MS_LOG(EXCEPTION) << "type: " << type << "not support"; + } +} + +void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(proto); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kInitDataSetQueue) { + op_name = kInitData; + } + if (op_name == kPrint) { + return; + } + + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); + MS_EXCEPTION_IF_NULL(op_info_ptr); + auto attrs_ptr = op_info_ptr->attrs_ptr(); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); + for (const auto &attr_ptr : attrs_ptr) { + MS_EXCEPTION_IF_NULL(attr_ptr); + std::string attr_name = attr_ptr->name(); + auto value = primitive->GetAttr(attr_name); + if (value != nullptr) { + if (attr_name == kQueueName || attr_name == kSharedName) { + attr_name = kChannelName; + } else if (attr_name == kSeed0) { + attr_name = kSeed; + } else if (attr_name == kSeed1) { + attr_name = kSeed2; + } + std::string type = attr_ptr->type(); + ParseAttrValue(type, attr_name, value, node_attr); + } + } + MS_LOG(INFO) << "Set node attr end!"; +} + +void SetNodeInputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(proto); + MS_EXCEPTION_IF_NULL(anf_node); + size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); + if (input_num == 0) { + MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input."; + return; + } + + for (size_t input_index = 0; input_index < input_num; input_index++) { + ::mindspore::Tensor *node_inputs = proto->add_inputs(); + MS_EXCEPTION_IF_NULL(node_inputs); + TypeId input_type = AnfAlgo::GetInputDeviceDataType(anf_node, input_index); + std::vector input_shape; + int32_t input_data_type; + if (input_type == kObjectTypeString) { + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input_node = cnode->inputs()[input_index + 1]; + auto value_ptr = GetValueNode(input_node); + auto value = GetValue(value_ptr); + input_shape.push_back(1); + input_shape.push_back(value.size()); + input_data_type = AicpuOpUtil::MsTypeToProtoType(kTypeUnknown); + } else { + input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index); + input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type); + } + + mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape(); + for (auto item : input_shape) { + mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); + dim->set_size((::google::protobuf::int64)item); + } + node_inputs->set_tensor_type((mindspore::DataType)input_data_type); + node_inputs->set_mem_device("HBM"); + } +} + +void SetNodeOutputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(proto); + MS_EXCEPTION_IF_NULL(anf_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); + if (output_num == 0) { + MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. "; + return; + } + + for (size_t output_index = 0; output_index < output_num; output_index++) { + ::mindspore::Tensor *node_outputs = proto->add_outputs(); + MS_EXCEPTION_IF_NULL(node_outputs); + std::vector output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index); + mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape(); + MS_EXCEPTION_IF_NULL(tensorShape); + for (auto item : output_shape) { + mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); + MS_EXCEPTION_IF_NULL(dim); + dim->set_size((::google::protobuf::int64)item); + } + TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index); + int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type); + node_outputs->set_tensor_type((mindspore::DataType)output_data_type); + node_outputs->set_mem_device("HBM"); + } +} + +void SetNodedefProto(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(proto); + MS_LOG(INFO) << "SetNodedefProto entry"; + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kInitDataSetQueue) { + op_name = kInitData; + } + // set op name + proto->set_op(op_name); + // set inputs tensor + SetNodeInputs(anf_node, proto); + // set outputs tensor + SetNodeOutputs(anf_node, proto); + // set node attr + SetNodeAttr(anf_node, proto); + MS_LOG(INFO) << "SetNodedefProto end!"; +} + +bool CreateNodeDefBytes(const std::shared_ptr &anf_node, + const std::shared_ptr &kernel_mod_ptr) { + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "CreateNodeDefBytes entry"; + + mindspore::NodeDef proto; + SetNodedefProto(anf_node, &proto); + std::string nodeDefStr; + if (!proto.SerializeToString(&nodeDefStr)) { + MS_LOG(ERROR) << "Serialize nodeDef to string failed."; + return false; + } + kernel_mod_ptr->SetNodeDef(nodeDefStr); + MS_LOG(INFO) << "CreateNodeDefBytes end!"; + return true; +} + +KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kInitDataSetQueue) { + op_name = kInitData; + } + auto kernel_mod_ptr = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + kernel_mod_ptr->SetAnfNode(anf_node); + kernel_mod_ptr->SetNodeName(op_name); + if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) { + MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!"; + } + if (!SetIOSize(anf_node, kernel_mod_ptr)) { + MS_LOG(EXCEPTION) << "Set input output size list failed."; + } + return kernel_mod_ptr; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h new file mode 100644 index 0000000000..6e2ee3959b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ +#include +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc new file mode 100644 index 0000000000..76c29b9f5c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h" +#include +#include +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + MS_LOG(INFO) << "AicpuMetadataInfo."; + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + if (op_name == kInitDataSetQueue) { + op_name = kInitData; + } + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); + if (op_info_ptr == nullptr) { + MS_LOG(DEBUG) << "Aicpu does not have op [" << op_name << "]"; + return; + } + // For compatibility with the current framework + if (op_name == kPrint || op_name == kGetNext || op_name == kPack) { + std::vector inputs_format{}; + std::vector inputs_type{}; + if (op_name == kPrint || op_name == kPack) { + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + inputs_format.emplace_back(kOpFormat_DEFAULT); + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); + } + } + std::vector outputs_format; + std::vector outputs_type; + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + outputs_format.emplace_back(kOpFormat_DEFAULT); + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); + } + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetProcessor(AICPU); + builder.SetKernelType(AICPU_KERNEL); + builder.SetFusionType(OPAQUE); + kernel_info_list->push_back(builder.Build()); + return; + } + if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) { + MS_LOG(WARNING) << "Aicpu parsed metadata op [" << op_name << "] failed"; + return; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h new file mode 100644 index 0000000000..e21f4eace4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ + +#include +#include +#include +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc new file mode 100644 index 0000000000..e18b3169f3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h" + +#include +#include +#include +#include + +#include "runtime/mem.h" +#include "runtime/rt.h" +#include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h" +#include "utils/convert_utils.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +#include "utils/context/ms_context.h" + +using AicpuTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +constexpr auto AICPU_OPS_SO_NAME = "libaicpu_kernels.so"; + +AicpuOpKernelMod::AicpuOpKernelMod() : anf_node_(nullptr) {} + +AicpuOpKernelMod::~AicpuOpKernelMod() { + args_.clear(); + inputList_.clear(); + outputList_.clear(); + anf_node_ = nullptr; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); +} + +void AicpuOpKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } +const std::vector &AicpuOpKernelMod::GetInputSizeList() const { return input_size_list_; } +void AicpuOpKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } +const std::vector &AicpuOpKernelMod::GetOutputSizeList() const { return output_size_list_; } +void AicpuOpKernelMod::SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } +const std::vector &AicpuOpKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } +void AicpuOpKernelMod::SetInputList(const std::vector &inputList) { inputList_ = inputList; } +void AicpuOpKernelMod::SetOutputList(const std::vector &outputList) { outputList_ = outputList; } +void AicpuOpKernelMod::SetNodeDef(const std::string &nodeDef) { (void)node_def_str_.assign(nodeDef); } +void AicpuOpKernelMod::SetNodeName(const std::string &node_name) { node_name_ = node_name; } +void AicpuOpKernelMod::SetAnfNode(const mindspore::AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + anf_node_ = anf_node; +} + +void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector &inputs, + const std::vector &outputs) { + MS_LOG(INFO) << "CreateCpuKernelInfoOffline start"; + + node_so_ = AICPU_OPS_SO_NAME; + + // InputOutputAddr + vector io_addrs; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(io_addrs), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(io_addrs), + [](const AddressPtr &output) -> void * { return output->addr; }); + + auto io_addrs_num = io_addrs.size(); + // calculate paramLen: AicpuParamHead.len + ioAddrsSize + notifyId.len + customizedAttr.len + auto param_len = sizeof(AicpuParamHead); + + // get input and output addrs size, no need to check overflow + auto io_addrs_size = io_addrs_num * sizeof(uint64_t); + // refresh paramLen, no need to check overflow + param_len += io_addrs_size; + + auto node_def_len = node_def_str_.length(); + param_len += node_def_len; + + // Create taskArgs: AicpuParamHead + ioAddrs + notifyId + customizedAttr + AicpuParamHead paramHead = {static_cast(param_len), static_cast(io_addrs_num)}; + args_.clear(); + (void)args_.append(reinterpret_cast(¶mHead), sizeof(AicpuParamHead)); + // TaskArgs append ioAddrs + if (io_addrs_size != 0) { + (void)args_.append(reinterpret_cast(io_addrs.data()), io_addrs_size); + } + + // When it's aicpu customized ops, taskArgs should append customized attr + if (node_def_len != 0) { + (void)args_.append(reinterpret_cast(node_def_str_.data()), node_def_len); + } + + MS_LOG(INFO) << "CreateCpuKernelInfoOffline end"; +} + +bool AicpuOpKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + if (stream_ptr == nullptr) { + MS_LOG(ERROR) << "stream_ptr should not be nullptr."; + return false; + } + + CreateCpuKernelInfo(inputs, outputs); + if (node_name_ == kTopK) { + node_name_ = kTopKV2; + } + MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_ + << ", args_size:" << args_.length(); + if (rtCpuKernelLaunch(reinterpret_cast(node_so_.c_str()), + reinterpret_cast(node_name_.c_str()), 1, + reinterpret_cast(args_.data()), static_cast(args_.length()), nullptr, + stream_ptr) != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Aicpu op launch failed!"; + + return false; + } + return true; +} + +std::vector AicpuOpKernelMod::GenTask(const std::vector &inputs, + const std::vector &, + const std::vector &outputs, uint32_t stream_id) { + MS_LOG(INFO) << "AicpuOpKernelMod GenTask start"; + + stream_id_ = stream_id; + node_so_ = AICPU_OPS_SO_NAME; + std::vector input_data_addrs; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), + [](const AddressPtr &input) -> void * { return input->addr; }); + + std::vector output_data_addrs; + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), + [](const AddressPtr &output) -> void * { return output->addr; }); + + if (node_name_ == kTopK) { + node_name_ = kTopKV2; + } + + AicpuTaskInfoPtr task_info_ptr = make_shared( + kernel_name_, stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs, NeedDump()); + + MS_LOG(INFO) << "AicpuOpKernelMod GenTask end"; + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h new file mode 100644 index 0000000000..82260010ea --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +namespace mindspore { +namespace kernel { +class AicpuOpKernelMod : public AscendKernelMod { + public: + AicpuOpKernelMod(); + ~AicpuOpKernelMod() override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + void SetInputList(const std::vector &inputList); + void SetOutputList(const std::vector &outputList); + void SetAnfNode(const AnfNodePtr &anf_node); + void SetNodeDef(const std::string &nodeDef); + void SetNodeName(const std::string &node_name); + + /** + * @brief Build AICPU Engine kernel structure, and allocate device memory for offline task generate + * @return SUCCESS + * @return FAIL + * + */ + void CreateCpuKernelInfo(const std::vector &inputs, const std::vector &outputs); + + void SetInputSizeList(const std::vector &size_list); + void SetOutputSizeList(const std::vector &size_list); + void SetWorkspaceSizeList(const std::vector &size_list); + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + + private: + std::string args_; + std::string node_def_str_; + std::string node_name_; + std::string node_so_; + std::vector inputList_; + std::vector outputList_; + AnfNodePtr anf_node_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using AicpuOpKernelModPtr = std::shared_ptr; +using AicputOpKernelModPtrList = std::vector; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc new file mode 100644 index 0000000000..790319daa6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +#include +#include +#include "proto/types.pb.h" +#include "runtime/mem.h" +#include "runtime/rt.h" +#include "utils/convert_utils.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +static std::map MS_PROTO_DATA_TYPE_MAP = { + {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN}, + {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL}, + {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32}, + {mindspore::TypeId::kNumberTypeInt8, mindspore::DataType::MS_INT8}, + {mindspore::TypeId::kNumberTypeInt16, mindspore::DataType::MS_INT16}, + {mindspore::TypeId::kNumberTypeInt32, mindspore::DataType::MS_INT32}, + {mindspore::TypeId::kNumberTypeInt64, mindspore::DataType::MS_INT64}, + {mindspore::TypeId::kNumberTypeUInt, mindspore::DataType::MS_UINT32}, + {mindspore::TypeId::kNumberTypeUInt8, mindspore::DataType::MS_UINT8}, + {mindspore::TypeId::kNumberTypeUInt16, mindspore::DataType::MS_UINT16}, + {mindspore::TypeId::kNumberTypeUInt32, mindspore::DataType::MS_UINT32}, + {mindspore::TypeId::kNumberTypeUInt64, mindspore::DataType::MS_UINT64}, + {mindspore::TypeId::kNumberTypeFloat16, mindspore::DataType::MS_FLOAT16}, + {mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32}, + {mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32}, + {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64}, +}; + +int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) { + auto iter = MS_PROTO_DATA_TYPE_MAP.find(ms_type); + if (iter != MS_PROTO_DATA_TYPE_MAP.end()) { + return MS_PROTO_DATA_TYPE_MAP[ms_type]; + } else { + MS_LOG(ERROR) << "UnSupported ms_type value" << static_cast(ms_type); + return -1; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h new file mode 100644 index 0000000000..fd4495afeb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +constexpr auto kInitDataSetQueue = "InitDataSetQueue"; +constexpr auto kInitData = "InitData"; +constexpr auto kGetNext = "GetNext"; +constexpr auto kPrint = "Print"; +constexpr auto kPack = "Pack"; +constexpr auto kOutputTypes = "output_types"; +constexpr auto kOutputShapes = "output_shapes"; +constexpr auto kChannelName = "channel_name"; +constexpr auto kSharedName = "shared_name"; +constexpr auto kShapes = "shapes"; +constexpr auto kTypes = "types"; +constexpr auto kQueueName = "queue_name"; +constexpr auto kSeed = "seed"; +constexpr auto kSeed0 = "Seed0"; +constexpr auto kSeed1 = "Seed1"; +constexpr auto kSeed2 = "seed2"; +constexpr auto kTopK = "TopK"; +constexpr auto kTopKV2 = "TopKV2"; + +struct AicpuParamHead { + uint32_t length; // Total length: include cunstom message + uint32_t ioAddrNum; // Input and output address number + uint32_t extInfoLength; // extInfo struct Length + uint64_t extInfoAddr; // extInfo address +} __attribute__((packed)); + +class AicpuOpUtil { + public: + static int MsTypeToProtoType(TypeId ms_type); + + private: + // kernel id + static uint64_t KernelId_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ diff --git a/mindspore/ccsrc/kernel/aicpu/proto/attr.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/attr.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/attr.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/attr.proto diff --git a/mindspore/ccsrc/kernel/aicpu/proto/node_def.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/node_def.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/node_def.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/node_def.proto diff --git a/mindspore/ccsrc/kernel/aicpu/proto/tensor.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/tensor.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/tensor.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/tensor.proto diff --git a/mindspore/ccsrc/kernel/aicpu/proto/tensor_shape.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/tensor_shape.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/tensor_shape.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/tensor_shape.proto diff --git a/mindspore/ccsrc/kernel/aicpu/proto/types.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/types.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/types.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/types.proto diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc new file mode 100644 index 0000000000..73fdb5c11b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc @@ -0,0 +1,180 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" + +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace kernel { +void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // The x and output are akg op input and output param. + std::vector input_names = {"x"}; + std::vector output_names = {"output"}; + AnfAlgo::SetNodeAttr("input_names", MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr("output_names", MakeValue(output_names), anf_node); + + TypeId dst_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); + std::string dst_type; + if (dst_type_id == kFloat32->type_id()) { + dst_type = "float32"; + } else if (dst_type_id == kFloat16->type_id()) { + dst_type = "float16"; + } + AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); +} + +void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector input_names = {"x"}; + std::vector output_names = {"output"}; + AnfAlgo::SetNodeAttr("input_names", MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr("output_names", MakeValue(output_names), anf_node); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(anf_node, 0); + if (origin_shape.size() != kShape4dDims) { + MS_LOG(EXCEPTION) << "The dim of origin_shape is not equal to 4, but it's dim is " << origin_shape.size() << "."; + } + std::vector shape_transform; + (void)std::transform(origin_shape.begin(), origin_shape.end(), std::back_inserter(shape_transform), + [](const int &origin_shape) { return static_cast(origin_shape); }); + AnfAlgo::SetNodeAttr("shape4d", MakeValue(shape_transform), anf_node); + AnfAlgo::SetNodeAttr("output_format", MakeValue(kOpFormat_NCHW), anf_node); + + TypeId dst_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); + std::string dst_type; + if (dst_type_id == kFloat32->type_id()) { + dst_type = "float32"; + } else if (dst_type_id == kFloat16->type_id()) { + dst_type = "float16"; + } + AnfAlgo::SetNodeAttr("dstType", MakeValue(dst_type), anf_node); +} + +void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // The x and output are akg op input and output param. + std::vector input_names = {"x", "dst_type"}; + std::vector output_names = {"output"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); + + std::string dst_type; + TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); + if (output_type == kFloat32->type_id()) { + dst_type = "float32"; + } else if (output_type == kFloat16->type_id()) { + dst_type = "float16"; + } else if (output_type == kInt32->type_id()) { + dst_type = "int32"; + } else { + MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString(); + } + AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); +} + +void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector input_names{"dy", "data", "mean"}; + std::vector output_names{"dgamma_red_hw", "dbeta_red_hw", "data_minus_mean"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); +} + +void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node) { + const size_t kBNGrad2InputSize = 5; + MS_EXCEPTION_IF_NULL(anf_node); + std::vector input_names{"dgamma_red_hw", "dbeta_red_hw", "variance", "gamma"}; + std::vector output_names{"bn_scale", "bn_bias", "rs", "dgamma_dx", "dbeta_dx"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() < kBNGrad2InputSize) { + MS_LOG(EXCEPTION) << "The inputs size of BNGrad2 is less then " << kBNGrad2InputSize; + } + auto input1 = cnode->input(1); + MS_EXCEPTION_IF_NULL(input1); + auto tuple_getitem = input1->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + if (tuple_getitem->inputs().size() < kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The inputs size of tuple_getitem is less then " << kTupleGetItemInputSize; + } + auto bn_grad1 = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); + std::vector data_shape = AnfAlgo::GetInputDeviceShape(bn_grad1, 0); + AnfAlgo::SetNodeAttr(kAttrDataShape, MakeValue(opt::Convert2Int(data_shape)), anf_node); +} + +void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector input_names{"dy", "rs", "dgamma_dx", "dbeta_dx", "data_minus_mean"}; + std::vector output_names{"dx"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); +} + +void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // Set attr for fused_bn1 + std::vector fused_bn1_input_names{"data"}; + std::vector fused_bn1_output_names{"mean", "var_part"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn1_input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn1_output_names), anf_node); +} + +void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // Set attr for fused_bn2 + std::vector fused_bn2_input_names{"mean", "var_part", "running_mean", "running_var"}; + std::vector fused_bn2_output_names{"variance", "running_mean", "running_variance"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn2_input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn2_output_names), anf_node); +} + +void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // Set attr for fused_bn3 + std::vector fused_bn3_input_names{"data", "mean", "variance", "gamma", "beta"}; + std::vector fused_bn3_output_names{"y"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn3_input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn3_output_names), anf_node); +} + +void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector conv_bn1_output_names{"data", "var_part", "mean"}; + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(conv_bn1_output_names), anf_node); +} + +void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector bn2_add_relu_input_names{"data", "var_part", "mean", "other_branch_data", + "gamma", "beta", "running_mean", "running_var"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_add_relu_input_names), anf_node); + std::vector bn2_add_relu_output_names{"output", "running_mean", "running_variance", "save_inv_variance"}; + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_add_relu_output_names), anf_node); +} + +void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector bn2_input_names{"data", "var_part", "mean", "gamma", "beta", "running_mean", "running_var"}; + std::vector bn2_output_names{"y", "running_mean", "running_variance", "save_inv_variance"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h new file mode 100644 index 0000000000..9ba724db42 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H +#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H + +#include +#include +#include +#include +#include "ir/anf.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace kernel { +void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node); +void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node); +void SetAkgAttrsForCast(const AnfNodePtr &anf_node); +void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node); +void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node); +void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node); +void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node); +void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node); +void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node); +void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node); +void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node); +void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node); + +const std::unordered_map> kAkgKernelAttrsProcessMap = { + {kFour2FiveOpName, SetAkgAttrsForFour2Five}, + {kFive2FourOpName, SetAkgAttrsForFive2Four}, + {"Cast", SetAkgAttrsForCast}, + {kBNGrad1OpName, SetAkgAttrsForBNGrad1}, + {kBNGrad2OpName, SetAkgAttrsForBNGrad2}, + {kBNGrad3OpName, SetAkgAttrsForBNGrad3}, + {kFusedBN1OpName, SetAkgAttrsForFusedBN1}, + {kFusedBN2OpName, SetAkgAttrsForFusedBN2}, + {kFusedBN3OpName, SetAkgAttrsForFusedBN3}, + {kConvBN1OpName, SetAkgAttrsForConvBN1}, + {kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu}, + {kBN2ReLUOpName, SetAkgAttrsForBN2Relu}, +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc new file mode 100644 index 0000000000..9c13629b1b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc @@ -0,0 +1,623 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "utils/convert_utils.h" +#include "utils/any.h" +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" + +namespace mindspore { +namespace kernel { +constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; +constexpr int32_t ARGS_SIZE = 1; +constexpr auto kCompileWithJsonFunc = "compilewithjson"; + +// json key +constexpr auto kOpDesc = "op_desc"; +constexpr auto kInputDesc = "input_desc"; +constexpr auto kShape = "shape"; +constexpr auto kDataType = "data_type"; +constexpr auto kOutputDesc = "output_desc"; +constexpr auto kName = "name"; +constexpr auto kTensorName = "tensor_name"; +constexpr auto kValue = "value"; +constexpr auto KDynInputSizes = "dyn_input_sizes"; +constexpr auto KInputNames = "input_names"; +constexpr auto KInput = "input"; +constexpr auto KDtype = "dtype"; +namespace { +template +std::string Vector2Str(const std::vector &inputs) { + if (!inputs.empty()) { + std::ostringstream oss; + (void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator(oss, ", ")); + oss << inputs.back(); + return oss.str(); + } + return ""; +} +} // namespace + +std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) { + char *pChar = nullptr; + std::string str_res; + if (PyObj == nullptr) { + MS_LOG(ERROR) << "Input parameter is nullptr."; + return str_res; + } + PyObject *strArgs = PyObject_Str(PyObj); + if (strArgs != nullptr) { + (void)PyArg_Parse(strArgs, "s", &pChar); + } + if (pChar == nullptr) { + MS_LOG(ERROR) << "pChar is nullptr."; + return str_res; + } + str_res = pChar; + return str_res; +} + +std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, + const std::pair &position) { + if (node_json.count(tag) == 0) { + MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "]."; + return ""; + } + + auto const &tag_desc = node_json[tag]; + nlohmann::json first_index; + if (tag == kOutputDesc) { + first_index = tag_desc; + } else if (!tag_desc.is_array() || tag_desc.size() <= position.first) { + MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "]."; + return ""; + } else { + first_index = tag_desc[position.first]; + } + + if (!first_index.is_array() || first_index.size() <= position.second) { + MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "]."; + return ""; + } + auto const &second_index = first_index[position.second]; + if (second_index.count(kTensorName) == 0) { + MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "]."; + return ""; + } + + return second_index[kTensorName]; +} + +void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair &position, + nlohmann::json *const node_json) { + MS_EXCEPTION_IF_NULL(node_json); + if (node_json->count(tag) == 0) { + MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "]."; + return; + } + + nlohmann::json *tag_desc = &((*node_json)[tag]); + nlohmann::json *first_index; + if (tag == kOutputDesc) { + first_index = tag_desc; + } else if (!tag_desc->is_array() || tag_desc->size() <= position.first) { + MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "]."; + return; + } else { + first_index = &((*tag_desc)[position.first]); + } + + if (!first_index->is_array() || first_index->size() <= position.second) { + MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "]."; + return; + } + nlohmann::json *second_index = &((*first_index)[position.second]); + if (second_index->count(kTensorName) == 0) { + MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "]."; + return; + } + (*second_index)[kTensorName] = new_name; + return; +} + +int AkgKernelBuild::op_cnt_ = 0; +std::mutex AkgKernelBuild::op_cnt_mtx_; + +std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string device; + switch (AnfAlgo::GetProcessor(anf_node)) { + case Processor::AICORE: + device = kProcessorAiCore; + break; + + case Processor::AICPU: + device = kProcessorAiCpu; + break; + + case Processor::CUDA: + device = kProcessorCuda; + break; + + default: + MS_LOG(ERROR) << "Unknown processor type."; + break; + } + + return device; +} + +bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, + std::vector *const output_size) { + if (input_size == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "input size or output size is nullptr"; + return false; + } + input_size->clear(); + output_size->clear(); + + for (size_t i = 0; i < node_json[kInputDesc].size(); i++) { + for (size_t m = 0; m < node_json[kInputDesc][i].size(); m++) { + std::string dtype = node_json[kInputDesc][i][m][kDataType]; + size_t nbyte = GetDtypeNbyte(dtype); + size_t size_i = std::accumulate(node_json[kInputDesc][i][m][kShape].begin(), + node_json[kInputDesc][i][m][kShape].end(), nbyte, std::multiplies()); + input_size->push_back(size_i); + } + } + + for (size_t i = 0; i < node_json[kOutputDesc].size(); i++) { + std::string dtype = node_json[kOutputDesc][i][kDataType]; + size_t nbyte = GetDtypeNbyte(dtype); + size_t size_i = std::accumulate(node_json[kOutputDesc][i][kShape].begin(), node_json[kOutputDesc][i][kShape].end(), + nbyte, std::multiplies()); + output_size->push_back(size_i); + } + + return true; +} + +int AkgKernelBuild::GetOpCntInc() { + op_cnt_mtx_.lock(); + int cnt = op_cnt_++; + op_cnt_mtx_.unlock(); + return cnt; +} + +bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(inputs_json); + + // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); + if (op_info == nullptr) { + MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op_info is nullptr"; + return false; + } + + std::vector> inputs_ptr = op_info->inputs_ptr(); + if (inputs_ptr.empty()) { + MS_LOG(INFO) << "Apply kernel [" << op_name << "] regist info has no input info"; + return true; + } + auto op_info_input_num = inputs_ptr.size(); + + // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. + std::vector dyn_input_sizes; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + + if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { + dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); + } + + size_t real_input_index = 0; + std::vector input_list; + for (size_t i = 0; i < op_info_input_num; i++) { + size_t input_tensor_num; + std::shared_ptr input_ptr = inputs_ptr[i]; + std::string op_input_name; + if (input_ptr == nullptr) { + MS_LOG(ERROR) << "Apply kernel [" << op_name << "] regist input[" << i << "] is nullptr"; + return false; + } + + op_input_name = input_ptr->name(); + if (dyn_input_sizes.empty()) { + input_tensor_num = 1; + } else { + input_tensor_num = IntToSize(dyn_input_sizes[i]); + } + + input_list.clear(); + for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { + // dtype : float16 + auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index); + std::string dtype = TypeId2String(type_id); + if (dtype.empty()) { + MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. "; + return false; + } + nlohmann::json input_desc_json; + input_desc_json[kDataType] = dtype; + input_desc_json[kName] = op_input_name; + input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); + auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); + if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && + GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { + MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2) + << "] as const tensor, shape: [" << Vector2Str(input_shape) + << "], value: " << input_desc_json[kValue]; + + input_shape.clear(); + } + if (input_shape.empty()) { + input_shape.push_back(1); + } + input_desc_json[kShape] = input_shape; + input_list.emplace_back(input_desc_json); + real_input_index++; + } + inputs_json->emplace_back(input_list); + } + return true; +} + +bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(outputs_json); + size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); + auto outputs = op_info_ptr->outputs_ptr(); + for (size_t i = 0; i < output_tensor_num; i++) { + nlohmann::json output_json; + auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i); + std::string dtype = TypeId2String(type_id); + if (dtype.empty()) { + MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. "; + return false; + } + + std::string output_name = outputs[i]->name(); + output_json[kDataType] = dtype; + output_json[kName] = output_name; + output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc()); + output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i); + outputs_json->push_back(output_json); + } + return true; +} + +void GetJson(const AnfNodePtr &anf_node, const std::vector &dyn_input_sizes, + const std::shared_ptr &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_attr); + MS_EXCEPTION_IF_NULL(attr_json); + std::string type = op_attr->type(); + if (type == "int") { + (*attr_json)[kValue] = GetValue(attr_value); + } else if (type == "str") { + (*attr_json)[kValue] = GetValue(attr_value); + } else if (type == "bool") { + (*attr_json)[kValue] = GetValue(attr_value); + } else if (type == "float") { + (*attr_json)[kValue] = GetValue(attr_value); + } else if (type == "listInt") { + (*attr_json)[kValue] = GetValue>(attr_value); + } else if (type == "listStr") { + std::vector data_format; + if (op_attr->name() == kArgDataformat) { + size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); + for (size_t format_i = 0; format_i < tensor_args_num; format_i++) { + auto input_format = AnfAlgo::GetInputFormat(anf_node, format_i); + data_format.push_back(input_format); + } + } else { + data_format = GetValue>(attr_value); + } + (*attr_json)[kValue] = data_format; + } else { + MS_LOG(WARNING) << "attr type:" << type; + } +} + +bool AkgKernelBuild::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, + const std::shared_ptr &op_info, nlohmann::json *const attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(attrs_json); + MS_EXCEPTION_IF_NULL(op_info); + std::vector> attrs = op_info->attrs_ptr(); + if (attrs.empty()) { + MS_LOG(INFO) << "Apply kernel [" << op_name << "] op info attrs is empty"; + return true; + } + std::vector> inputs = op_info->inputs_ptr(); + + std::vector dyn_input_sizes; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { + dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); + } + + if (inputs.empty()) { + MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op info inputs is empty"; + return false; + } + + // create input name list for atch "x_shape" in att with "x" in primitive. + std::map op_info_shape_name; + for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) { + std::string input_name = inputs[op_info_input_i]->name(); + std::string x_shape_name = input_name + "_shape"; + (void)op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name)); + } + + for (const auto &op_attr : attrs) { + nlohmann::json attr_json; + ValuePtr attr_value = primitive->GetAttr(op_attr->name()); + if (attr_value == nullptr && op_attr->name() != kArgDataformat) { + if (op_attr->param_type() == "required") { + // match "x_shape" in att with "x" in primitive. + std::string attr_name = op_attr->name(); + auto find_item = std::find_if( + op_info_shape_name.begin(), op_info_shape_name.end(), + [attr_name](const std::map::value_type item) { return item.second == attr_name; }); + if (find_item != op_info_shape_name.end()) { + if (!dyn_input_sizes.empty()) { + if (find_item->first >= dyn_input_sizes.size() - 1) { + MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first + << " is out of range:" << dyn_input_sizes.size() - 1 << "."; + return false; + } + size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0)); + for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) { + attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx); + attr_json[kName] = op_attr->name(); + attrs_json->push_back(attr_json); + tensor_idx++; + } + } else { + attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first); + attr_json[kName] = op_attr->name(); + attrs_json->push_back(attr_json); + } + } else { + MS_LOG(ERROR) << "op [" << op_name << "] should have attr :" << op_attr->name(); + return false; + } + } + continue; + } + + GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value); + + attr_json[kName] = op_attr->name(); + attrs_json->push_back(attr_json); + } + return true; +} + +bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, + nlohmann::json *const node_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(node_json); + int op_cnt = GetOpCntInc(); + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); + MS_EXCEPTION_IF_NULL(op_info_ptr); + + // get basic params from currentNodeOpDesc + (*node_json)[kName] = op_name; + (*node_json)["impl_path"] = op_info_ptr->impl_path(); + (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); + (*node_json)["composite"] = false; + + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + ValuePtr input_names_v = primitive->GetAttr(KInputNames); + if (input_names_v == nullptr) { + MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; + return false; + } + std::vector prim_input_names = GetValue>(input_names_v); + std::string inputs_name; + for (const auto &prim_input_name : prim_input_names) { + (void)inputs_name.append("_input_").append(prim_input_name).append("_"); + } + + // input desc + nlohmann::json inputs_json; + if (!CreateInputDescJson(anf_node, &inputs_json)) { + MS_LOG(ERROR) << "Create input desc json failed, op[" << op_name << "]."; + return false; + } + (*node_json)[kInputDesc] = inputs_json; + MS_LOG(INFO) << "Akg create input desc json success."; + std::string inputs_shape = "inputs_shape_"; + for (auto &i : inputs_json) { + for (auto &m : i) { + std::string data_type = m[kDataType]; + (void)inputs_shape.append("_").append(data_type).append("_"); + for (auto &j : m[kShape]) { + size_t n = j; + (void)inputs_shape.append(std::to_string(n)).append("_"); + } + } + } + + // output desc + nlohmann::json outputs_json; + if (!CreateOutputDescJson(anf_node, &outputs_json)) { + MS_LOG(ERROR) << "Create output desc json failed, op[" << op_name << "]."; + return false; + } + + (*node_json)[kOutputDesc] = outputs_json; + MS_LOG(INFO) << "Akg create output desc json success."; + std::string outputs_shape = "outputs_shape_"; + for (auto &i : outputs_json) { + std::string data_type = i[kDataType]; + (void)outputs_shape.append("_").append(data_type).append("_"); + for (auto &j : i[kShape]) { + size_t m = j; + (void)outputs_shape.append(std::to_string(m)).append("_"); + } + } + + // attribute desc + nlohmann::json attrs_json; + if (!CreateAttrDescJson(anf_node, op_name, op_info_ptr, &attrs_json)) { + MS_LOG(ERROR) << "Create attr desc json failed, op[" << op_name << "]."; + return false; + } + (*node_json)["attr"] = attrs_json; + std::string json_str = node_json->dump(); + size_t hash_id = std::hash()(json_str); + json_name_ = op_name + "_"; + (void)json_name_.append(std::to_string(hash_id)); + MS_LOG(INFO) << "full scope name is : " << anf_node->fullname_with_scope() << ", json info name is : " << json_name_; + json_info_ = json_str; + (*node_json)["id"] = op_cnt; + (*node_json)["op"] = json_name_; + MS_LOG(INFO) << "Akg create node desc json success."; + return true; +} + +KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + auto processor = AkgKernelBuild::GetProcessor(anf_node); + auto cached_kernel_pack = SearchCache(json_name_, processor); + if (cached_kernel_pack != nullptr) { + MS_LOG(INFO) << "Use cached kernel, json_name_[" << json_name_ << "], fullname_with_scope[" + << anf_node->fullname_with_scope() << "]."; + return cached_kernel_pack; + } + + PyObject *pModule = nullptr; + PyObject *pFunc = nullptr; + PyObject *pArg = nullptr; + PyObject *pRes = nullptr; + + pModule = PyImport_ImportModule(kAkgModule); + if (pModule == nullptr) { + MS_LOG(ERROR) << "Failed to import [" << kAkgModule << "]."; + return nullptr; + } + + pFunc = PyObject_GetAttrString(pModule, kCompileWithJsonFunc); + pArg = PyTuple_New(ARGS_SIZE); + (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", node_json.c_str())); + + (void)alarm(AUTODIFF_COMPILE_OVERTIME); + pRes = PyEval_CallObject(pFunc, pArg); + (void)alarm(0); + if (pRes == nullptr) { + MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" + << AkgKernelBuild::PyObjectToStr(pArg) << ")."; + return nullptr; + } + if (PyObject_IsTrue(pRes) != 1) { + MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" + << AkgKernelBuild::PyObjectToStr(pArg) << ")."; + return nullptr; + } + + auto new_kernel_pack = InsertCache(json_name_, processor); + kernel::SaveJsonInfo(json_name_, json_info_); + if (new_kernel_pack == nullptr) { + MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name_ << "], fullname_with_scope[" + << anf_node->fullname_with_scope() << "]."; + return nullptr; + } + return new_kernel_pack; +} + +KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vector *const input_size, + std::vector *const output_size) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + auto it = kAkgKernelAttrsProcessMap.find(op_name); + if (it != kAkgKernelAttrsProcessMap.end()) { + it->second(anf_node); + } + MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; + nlohmann::json node_json; + if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { + MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; + } + + std::string json_str = node_json.dump(); + auto kernel_pack = OpBuild(json_str, anf_node); + if (kernel_pack == nullptr) { + MS_LOG(ERROR) << "Akg build failed op[" << op_name << "], json:" << json_str; + return nullptr; + } + + if (!GetIOSize(node_json, input_size, output_size)) { + MS_LOG(ERROR) << "Cal mem size failed."; + return nullptr; + } + MS_LOG(INFO) << "Akg compile success, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) + << "]"; + return kernel_pack; +} + +size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(anf_node); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (input_idx + 1 >= cnode->inputs().size()) { + MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" + << cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]"; + } + + auto input_node = cnode->input(input_idx + 1); + if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) { + size_t index = input_tensor_idx_.size(); + input_tensor_idx_[input_node] = index; + } + + return input_tensor_idx_[input_node]; +} + +size_t AkgKernelBuild::GetOutputTensorIdxInc() { + size_t idx = output_tensor_idx_++; + return idx; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h new file mode 100644 index 0000000000..7b6a2f0b86 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "ir/dtype.h" +#include +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" + +namespace mindspore { +namespace kernel { +class AkgKernelBuild { + public: + AkgKernelBuild() { + input_tensor_idx_ = {}; + output_tensor_idx_ = 0; + } + ~AkgKernelBuild() = default; + + KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector *const input_size, + std::vector *const output_size); + static std::string GetProcessor(const AnfNodePtr &anf_node); + static std::string PyObjectToStr(PyObject *const PyObj); + + protected: + bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); + bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); + bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, + const std::shared_ptr &op_info, nlohmann::json *const attrs_json); + KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node); + int GetOpCntInc(); + size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx); + size_t GetOutputTensorIdxInc(); + bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, + nlohmann::json *const node_json); + + static int op_cnt_; + // lock for variable fusionOpCnt in singleton mode + static std::mutex op_cnt_mtx_; + std::string json_name_; + std::string json_info_; + std::unordered_map input_tensor_idx_; + size_t output_tensor_idx_; +}; + +bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, + std::vector *const output_size); +void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair &position, + nlohmann::json *const node_json); +std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, + const std::pair &position); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.cc new file mode 100644 index 0000000000..f3567428d3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/akg/akg_kernel_metadata.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +void AkgMetadataInfo(const CNodePtr &kernel_node, + std::vector> *const kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + for (size_t i = 0; i < support_devices.size(); i++) { + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); + if (op_info_ptr == nullptr) { + continue; + } + + if (!ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) { + MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed."; + } else { + MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "]."; + break; + } + } + + if (kernel_info_list->empty()) { + MS_LOG(WARNING) << "Akg dose not has metadata of op[" << op_name << "]."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h new file mode 100644 index 0000000000..02785c6cdb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc new file mode 100644 index 0000000000..d698c89bc9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc @@ -0,0 +1,422 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir/dtype.h" +#include "ir/func_graph.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h" +#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +constexpr int32_t PARALLEL_ARGS_SIZE = 3; +constexpr int32_t PROCESS_NUM = 16; +constexpr int32_t TIME_OUT = 300; + +constexpr auto kOpDesc = "op_desc"; +constexpr auto kShape = "shape"; +constexpr auto kDataType = "data_type"; +constexpr auto kInputDesc = "input_desc"; +constexpr auto kOutputDesc = "output_desc"; +constexpr auto kTensorName = "tensor_name"; +constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel"; +constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"; +namespace { +void UpdateTensorNameInJson(const std::vector &anf_nodes, + std::map *node_json_map) { + for (auto const &anf_node : anf_nodes) { + std::vector dyn_input_sizes; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + + if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { + dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); + } + + bool is_dynamic_input = !dyn_input_sizes.empty(); + size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); + size_t real_input_index = 0; + for (size_t i = 0; i < input_num; ++i) { + size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1; + for (size_t j = 0; j < input_tensor_num; ++j) { + auto tmp_input = GetKernelInput(anf_node, real_input_index); + std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kInputDesc, std::make_pair(i, j)); + if (node_json_map->find(tmp_input.first) != node_json_map->end()) { + std::string new_tensor_name = + GetTensorName((*node_json_map)[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second)); + SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node])); + MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of [" + << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output [" + << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "]."; + } else { + MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of [" + << anf_node->fullname_with_scope() << "] is out input."; + } + real_input_index++; + } + } + } +} + +nlohmann::json GetInputsJson(const std::vector &anf_nodes, const std::vector &input_list, + std::map *node_json_map) { + nlohmann::json inputs_json; + auto input_index = GetInputIndex(anf_nodes, input_list); + for (size_t i = 0; i < input_index.size(); ++i) { + auto tmp_input = input_index[i]; + auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first); + std::string dtype = TypeId2String(type_id); + nlohmann::json input_desc_json; + input_desc_json[kTensorName] = GetTensorName((*node_json_map)[tmp_input.first], kInputDesc, tmp_input.second); + input_desc_json[kDataType] = dtype; + input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first); + inputs_json.emplace_back(std::vector{input_desc_json}); + } + + return inputs_json; +} + +nlohmann::json GetOutputsJson(const std::vector &anf_nodes, const std::vector &input_list, + const std::vector &output_list, const nlohmann::json &inputs_json, + std::map *node_json_map) { + nlohmann::json outputs_json; + auto output_index = GetOutputIndex(anf_nodes, input_list, output_list); + for (size_t i = 0; i < output_index.size(); ++i) { + auto tmp_output = output_index[i]; + bool found = false; + nlohmann::json output_desc_json; + for (size_t input_i = 0; input_i < input_list.size(); ++input_i) { + if (tmp_output.first == input_list[input_i]) { + output_desc_json = inputs_json[input_i][0]; + found = true; + break; + } + } + if (!found) { + auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second); + std::string dtype = TypeId2String(type_id); + output_desc_json[kTensorName] = + GetTensorName((*node_json_map)[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second)); + output_desc_json[kDataType] = dtype; + auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second); + if (output_shape.empty()) { + output_shape.push_back(1); + } + output_desc_json[kShape] = output_shape; + } + outputs_json.emplace_back(output_desc_json); + } + + return outputs_json; +} + +std::pair, std::vector>> PreProcessJsonForBuild( + const std::vector> &build_args) { + // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess. + std::vector jsons; + std::vector> repeat_nodes; + std::unordered_set json_name_set; + for (const auto &[builder, anf_node] : build_args) { + MS_EXCEPTION_IF_NULL(anf_node); + auto json_name = builder.json_name(); + MS_LOG(DEBUG) << "Akg start compile op: " << json_name; + auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); + if (cached_kernel_pack != nullptr) { + MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope[" + << anf_node->fullname_with_scope() << "]."; + auto kernel_mod_ptr = std::make_shared(cached_kernel_pack); + kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); + kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); + AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); + continue; + } + + if (json_name_set.count(json_name) != 0) { + repeat_nodes.push_back({builder, anf_node}); + continue; + } + json_name_set.insert(json_name); + auto node_json = builder.kernel_json(); + kernel::SaveJsonInfo(json_name, node_json); + jsons.push_back(node_json); + } + + return std::make_pair(jsons, repeat_nodes); +} + +bool PostProcessAfterCompile(const std::vector> &build_args, + const std::vector> &repeat_nodes) { + for (const auto &[builder, anf_node] : build_args) { + auto json_name = builder.json_name(); + auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); + if (new_kernel_pack == nullptr) { + MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope[" + << anf_node->fullname_with_scope() << "]."; + return false; + } + auto kernel_mod_ptr = std::make_shared(new_kernel_pack); + kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); + kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); + AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); + MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!"; + } + + for (const auto &[builder, anf_node] : repeat_nodes) { + auto node_json = builder.kernel_json(); + auto json_name = builder.json_name(); + auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); + if (cached_kernel_pack == nullptr) { + return false; + } + MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope[" + << anf_node->fullname_with_scope() << "]."; + auto kernel_mod_ptr = std::make_shared(cached_kernel_pack); + kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); + kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); + AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); + } + + return true; +} +} // namespace + +bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; + auto it = kAkgKernelAttrsProcessMap.find(op_name); + if (it != kAkgKernelAttrsProcessMap.end()) { + it->second(anf_node); + } + MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; + nlohmann::json node_json; + if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { + MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; + } + + kernel_json_ = node_json.dump(); + + if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) { + MS_LOG(ERROR) << "Cal mem size failed."; + return false; + } + + return true; +} + +bool AkgAscendKernelBuilder::GenJsonAndPreprocess4Fused(const std::vector &anf_nodes, + std::map *node_json_map) { + for (auto const &anf_node : anf_nodes) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (!AnfAlgo::IsRealKernel(anf_node)) { + MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "]."; + return false; + } + auto it = kAkgKernelAttrsProcessMap.find(op_name); + if (it != kAkgKernelAttrsProcessMap.end()) { + it->second(anf_node); + } + + nlohmann::json node_json; + if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { + MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed."; + return false; + } + // No need for composite op. + node_json.erase("id"); + node_json.erase("op"); + node_json.erase("composite"); + + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + + if (primitive->GetAttr("fusion") != nullptr) { + node_json["fusion"] = primitive->GetAttr("fusion")->ToString(); + } + + (*node_json_map)[anf_node] = node_json; + } + return true; +} + +bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector &anf_nodes, + const std::vector &input_list, + const std::vector &output_list) { + if (anf_nodes.empty() || input_list.empty()) { + MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size() + << "]."; + return false; + } + MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list [" + << input_list.size() << "]."; + + std::map node_json_map; + if (!GenJsonAndPreprocess4Fused(anf_nodes, &node_json_map)) { + return false; + } + + UpdateTensorNameInJson(anf_nodes, &node_json_map); + + nlohmann::json fused_node_json; + std::vector node_json_desc; + std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc), + [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; }); + fused_node_json[kOpDesc] = node_json_desc; + fused_node_json[kInputDesc] = GetInputsJson(anf_nodes, input_list, &node_json_map); + fused_node_json[kOutputDesc] = + GetOutputsJson(anf_nodes, input_list, output_list, fused_node_json[kInputDesc], &node_json_map); + + size_t hash_id = std::hash()(fused_node_json.dump()); + json_name_ = "Fused_"; + auto fg = anf_nodes[0]->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (attr_val != nullptr) { + auto fg_attr = GetValue(attr_val); + (void)json_name_.append(fg_attr).append("_"); + } + (void)json_name_.append(std::to_string(hash_id)); + fused_node_json["composite_graph"] = fg->ToString(); + fused_node_json["op"] = json_name_; + fused_node_json["platform"] = "AKG"; + fused_node_json["process"] = "aicore"; + fused_node_json["composite"] = true; + + kernel_json_ = fused_node_json.dump(); + + if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) { + MS_LOG(ERROR) << "Cal mem size failed."; + return false; + } + + return true; +} + +void GenParallelCompileFuncArgs(const std::vector &kernel_jsons, PyObject **p_args) { + MS_EXCEPTION_IF_NULL(p_args); + *p_args = PyTuple_New(PARALLEL_ARGS_SIZE); + + PyObject *arg1 = PyList_New(kernel_jsons.size()); + for (int i = 0; i < PyList_Size(arg1); ++i) { + PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str())); + } + PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM); + PyObject *arg3 = Py_BuildValue("i", TIME_OUT); + + (void)PyTuple_SetItem(*p_args, 0, arg1); + (void)PyTuple_SetItem(*p_args, 1, arg2); + (void)PyTuple_SetItem(*p_args, 2, arg3); +} + +bool AkgOpParallelBuild(const std::vector> &build_args) { + auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args); + if (jsons.empty()) { + return true; + } + + // Try to call python method to compile nodes parallely. + PyObject *p_module = nullptr; + PyObject *p_func = nullptr; + PyObject *p_arg = nullptr; + PyObject *p_res = nullptr; + + p_module = PyImport_ImportModule(kMultiProcModule); + if (p_module == nullptr) { + MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "]."; + return false; + } + + p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); + GenParallelCompileFuncArgs(jsons, &p_arg); + MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size() + << " Akg kernels parallelly."; + p_res = PyEval_CallObject(p_func, p_arg); + if (p_res == nullptr) { + PyErr_Print(); + MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" + << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; + return false; + } + if (PyObject_IsTrue(p_res) != 1) { + PyErr_Print(); + MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" + << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; + return false; + } + + if (!PostProcessAfterCompile(build_args, repeat_nodes)) { + return false; + } + + return true; +} + +bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes) { + std::vector> json_and_node; + for (const auto &anf_node : anf_nodes) { + MS_EXCEPTION_IF_NULL(anf_node); + AkgAscendKernelBuilder akg_cce_kernel_builder; + KernelPackPtr kernel_pack = nullptr; + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::IsGraphKernel(cnode)) { + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + MS_EXCEPTION_IF_NULL(func_graph); + std::vector node_list; + std::vector input_list; + std::vector output_list; + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]"; + GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); + if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) { + MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "]."; + } + } else { + if (!akg_cce_kernel_builder.CollectJson(anf_node)) { + MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "]."; + } + } + json_and_node.push_back({akg_cce_kernel_builder, anf_node}); + } + + if (json_and_node.empty()) { + MS_LOG(DEBUG) << "There is no kernel needed to be compiled."; + return true; + } + + return AkgOpParallelBuild(json_and_node); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h new file mode 100644 index 0000000000..713b65a451 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" + +namespace mindspore { +namespace kernel { +class AkgAscendKernelBuilder : public AkgKernelBuild { + public: + AkgAscendKernelBuilder() = default; + ~AkgAscendKernelBuilder() = default; + + bool CollectJson(const AnfNodePtr &anf_node); + bool CollectFusedJson(const std::vector &anf_nodes, const std::vector &input_list, + const std::vector &output_list); + std::string json_name() const { return json_name_; } + std::string kernel_json() const { return kernel_json_; } + const std::vector &input_size_list() const { return input_size_list_; } + const std::vector &output_size_list() const { return output_size_list_; } + + private: + bool GenJsonAndPreprocess4Fused(const std::vector &anf_nodes, + std::map *node_json_map); + + std::string kernel_json_; + std::vector input_size_list_; + std::vector output_size_list_; +}; + +bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc new file mode 100644 index 0000000000..8bb4940778 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h" +#include +#include +#include +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "runtime/rt.h" +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +using std::fstream; +using std::map; +using std::mutex; +using std::string; +using TbeTaskInfoPtr = std::shared_ptr; +using tbe::KernelManager; +constexpr uint32_t DEFAULT_BLOCK_DIM = 1; +/** + * @brief infotable contain func_stub\blockdim\kernel file buffer + */ +AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} + +void AkgKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } + +void AkgKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } + +void AkgKernelMod::SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } + +const std::vector &AkgKernelMod::GetInputSizeList() const { return input_size_list_; } + +const std::vector &AkgKernelMod::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool AkgKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + if (stream_ptr == nullptr) { + MS_LOG(ERROR) << "stream_ptr should not be nullptr."; + return false; + } + + if (kernel_pack_ == nullptr) { + MS_LOG(ERROR) << "kernel pack should not be nullptr."; + return false; + } + + uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. + auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); + if (func_stub == 0) { + MS_LOG(ERROR) << "GenFuncStub failed."; + return false; + } + + // pack all addresses into a vector. + std::vector runtime_args; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args), + [](const AddressPtr &output) -> void * { return output->addr; }); + + rtL2Ctrl_t *l2ctrl = nullptr; + auto stream = reinterpret_cast(stream_ptr); + if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast(func_stub), block_dim, runtime_args.data(), + SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) { + MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; + return false; + } + + return true; +} + +std::vector AkgKernelMod::GenTask(const std::vector &inputs, const std::vector &, + const std::vector &outputs, uint32_t stream_id) { + if (kernel_pack_ == nullptr) { + MS_LOG(EXCEPTION) << "kernel pack should not be nullptr."; + } + + std::vector args; + const uint32_t args_size = 0; + std::vector sm_desc; + void *binary = nullptr; + const uint32_t binary_size = 0; + std::vector meta_data; + std::vector input_data_addrs; + std::vector output_data_addrs; + std::vector workspace_addrs; + + // pack all addresses into a vector. + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), + [](const AddressPtr &output) -> void * { return output->addr; }); + + uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. + auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); + if (func_stub == 0) { + MS_LOG(EXCEPTION) << "GenFuncStub failed."; + } + + std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); + + MS_LOG(DEBUG) << "The block_dim is:" << block_dim; + + TbeTaskInfoPtr task_info_ptr = make_shared( + kernel_name_, stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, + input_data_addrs, output_data_addrs, workspace_addrs, NeedDump()); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h new file mode 100644 index 0000000000..3ea36f1a23 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +namespace mindspore { +namespace kernel { +class AkgKernelMod : public AscendKernelMod { + public: + explicit AkgKernelMod(const KernelPackPtr &kernel_pack); + ~AkgKernelMod() final {} + + void SetInputSizeList(const std::vector &size_list); + void SetOutputSizeList(const std::vector &size_list); + void SetWorkspaceSizeList(const std::vector &size_list); + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + KernelPackPtr kernel_pack_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using AkgKernelModPtr = std::shared_ptr; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc new file mode 100644 index 0000000000..96fcd1869e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h" +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + AkgKernelBuild akg_kernel_build; + + std::vector input_size_list; + std::vector output_size_list; + KernelPackPtr kernel_pack = akg_kernel_build.BuildByJson(anf_node, &input_size_list, &output_size_list); + MS_EXCEPTION_IF_NULL(kernel_pack); + + auto kernel_mod_ptr = std::make_shared(kernel_pack); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + kernel_mod_ptr->SetInputSizeList(input_size_list); + kernel_mod_ptr->SetOutputSizeList(output_size_list); + return kernel_mod_ptr; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h new file mode 100644 index 0000000000..abb6d1f030 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ +#include "backend/kernel_compiler/kernel.h" +#include "base/base.h" + +namespace mindspore { +namespace kernel { +KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc new file mode 100644 index 0000000000..d527f8ec76 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h" +#include +#include +#include "nlohmann/json.hpp" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +using std::fstream; +using std::string; +using std::vector; + +GpuKernelManagerPtr GpuKernelMod::kernelmanager_ = std::make_shared(); +GpuKernelManager::GpuKernelManager() {} + +CUresult GpuKernelManager::GetFunction(const KernelPackPtr &kernel_pack, bool force_reload, + vector *thread_info, CUfunction *func) { + if (kernel_pack->GetJson() == nullptr || kernel_pack->GetJson()->contents == nullptr || + kernel_pack->GetKernel() == nullptr || kernel_pack->GetKernel()->contents == nullptr) { + MS_LOG(ERROR) << "GPU:Invalid kernel pack, json or kernel is nullptr."; + return CUDA_ERROR_INVALID_IMAGE; + } + auto js = nlohmann::json::parse(kernel_pack->GetJson()->contents, + kernel_pack->GetJson()->contents + kernel_pack->GetJson()->len); + string fn = js["kernelName"]; + if (!force_reload) { + auto iter = infotable_.find(fn); + if (iter != infotable_.end()) { + auto kernelmeta = iter->second; + *thread_info = kernelmeta->thread_info_; + *func = kernelmeta->func_addr_; + return CUDA_SUCCESS; + } + } + thread_info->emplace_back(js["blockIdx.x"]); + thread_info->emplace_back(js["blockIdx.y"]); + thread_info->emplace_back(js["blockIdx.z"]); + thread_info->emplace_back(js["threadIdx.x"]); + thread_info->emplace_back(js["threadIdx.y"]); + thread_info->emplace_back(js["threadIdx.z"]); + CUmodule module; + CUresult result = cuModuleLoadData(&module, kernel_pack->GetKernel()->contents); + if (result != CUDA_SUCCESS) { + MS_LOG(ERROR) << "cuModuleLoadData failed."; + return result; + } + result = cuModuleGetFunction(func, module, fn.c_str()); + if (result != CUDA_SUCCESS) { + MS_LOG(ERROR) << "cuModuleGetFunction failed."; + return result; + } + infotable_[fn] = std::make_shared(*func, module, *thread_info); + return result; +} + +GpuKernelMod::GpuKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} + +void GpuKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } + +void GpuKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } + +const std::vector &GpuKernelMod::GetInputSizeList() const { return input_size_list_; } + +const std::vector &GpuKernelMod::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &GpuKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool GpuKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + if (stream_ptr == 0) { + MS_LOG(ERROR) << "stream_ptr should not be nullptr."; + return false; + } + if (kernel_pack_ == nullptr) { + MS_LOG(ERROR) << "kernel pack should not be nullptr."; + return false; + } + vector thread_info; + CUfunction kernel_addr; + CUresult result = kernelmanager_->GetFunction(kernel_pack_, false, &thread_info, &kernel_addr); + if (result != CUDA_SUCCESS) { + MS_LOG(ERROR) << "GetFunction failed."; + return false; + } + std::vector runtimeargs; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs), + [](const AddressPtr &input) -> void * { return reinterpret_cast(&(input->addr)); }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs), + [](const AddressPtr &output) -> void * { return reinterpret_cast(&(output->addr)); }); + result = cuLaunchKernel(kernel_addr, thread_info[0], thread_info[1], thread_info[2], thread_info[3], thread_info[4], + thread_info[5], 0, reinterpret_cast(stream_ptr), + reinterpret_cast(&runtimeargs[0]), 0); + if (result != CUDA_SUCCESS) { + MS_LOG(ERROR) << "Launch Kernel failed."; + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h new file mode 100644 index 0000000000..a6a17d033f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +struct GpuKernelMeta { + CUfunction func_addr_; + CUmodule module_; + std::vector thread_info_; + GpuKernelMeta(CUfunction funcAddr, CUmodule module, const std::vector &thread_info) + : func_addr_(funcAddr), module_(module), thread_info_(thread_info) {} +}; +using GpuKernelMetaPtr = std::shared_ptr; + +class GpuKernelManager { + public: + GpuKernelManager(); + virtual ~GpuKernelManager() { + for (auto iter = infotable_.begin(); iter != infotable_.end(); ++iter) { + CUresult ret = cuModuleUnload(iter->second->module_); + if (ret != CUDA_SUCCESS && ret != CUDA_ERROR_DEINITIALIZED) { + MS_LOG(ERROR) << "Unload GPU Module failed."; + } + } + } + CUresult GetFunction(const KernelPackPtr &kernel_pack, bool force_reload, std::vector *thread_info, + CUfunction *func); + + private: + std::unordered_map infotable_; +}; +using GpuKernelManagerPtr = std::shared_ptr; + +class GpuKernelMod : public KernelMod { + public: + explicit GpuKernelMod(const KernelPackPtr &kernel_pack); + virtual ~GpuKernelMod() {} + + void SetInputSizeList(const std::vector &size_list); + void SetOutputSizeList(const std::vector &size_list); + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + static GpuKernelManagerPtr kernelmanager_; + + private: + KernelPackPtr kernel_pack_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using GpuKernelModPtr = std::shared_ptr; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h new file mode 100644 index 0000000000..c6398eda9e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ + +#include +#include +#include "framework/ge_runtime/task_info.h" +#include "backend/kernel_compiler/kernel.h" +#ifdef ENABLE_DATA_DUMP +#include "debug/data_dump_parser.h" +#endif + +using TaskInfoPtr = std::shared_ptr; +namespace mindspore { +namespace kernel { +class AscendKernelMod : public KernelMod { + public: + virtual std::vector GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t) = 0; + uint32_t block_dim() { return block_dim_; } + uint32_t stream_id() { return stream_id_; } + virtual bool NeedDump() { +#ifdef ENABLE_DATA_DUMP + return DataDumpParser::GetInstance().NeedDump(kernel_name_); +#else + return false; +#endif + } + + protected: + uint32_t block_dim_{1}; + uint32_t stream_id_{0}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc new file mode 100644 index 0000000000..f4495cdb9d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc @@ -0,0 +1,1029 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/common_utils.h" +#include +#include +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "ir/manager.h" +#include "ir/meta_tensor.h" +#include "ir/func_graph.h" +#include "frontend/operator/ops.h" +#include "utils/graph_utils.h" + +namespace mindspore { +namespace kernel { +constexpr char kAxis[] = "axis"; +constexpr char kTypeInt32[] = "Int32"; +const std::unordered_map type_id_maps = { + {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, + {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, + {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, + {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, + {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, + {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, + {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, + {"bool", TypeId::kNumberTypeBool}, +}; + +const std::map type_id_str_map = { + {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, + {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, + {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, + {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, + {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, + {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, + {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, + {TypeId::kNumberTypeBool, "bool"}, +}; + +const std::unordered_map dtype_shortdtype_map_ = { + {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, + {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, +}; + +const std::unordered_map dtype_nbyte_map = { + {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, + {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, + {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, + {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, +}; + +const std::unordered_map fusion_type_maps = { + {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, + {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, +}; + +void KernelMeta::Initialize() { + kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; + // remove old kernel cache + RemoveKernelCache(); + +#if defined(_WIN32) || defined(_WIN64) + auto ret = mkdir(kernel_meta_path_.c_str()); +#else + auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU); +#endif + if (ret != 0) { + MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later"; + } + initialized_ = true; +} + +void KernelMeta::RemoveKernelCache() { + DIR *dir = opendir(kernel_meta_path_.c_str()); + if (dir == nullptr) { + return; + } + struct dirent *entry; + while ((entry = readdir(dir)) != nullptr) { + std::string kernel_file = entry->d_name; + std::string kernel_file_realpath = kernel_meta_path_ + kernel_file; + (void)remove(kernel_file_realpath.c_str()); + } + (void)closedir(dir); + (void)rmdir(kernel_meta_path_.c_str()); +} + +std::string KernelMeta::Search(const std::string &kernel_name) const { + if (!initialized_) { + return ""; + } + + auto iter = kernel_meta_map_.find(kernel_name); + if (iter == kernel_meta_map_.end()) { + return ""; + } else { + return iter->second; + } +} + +bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) { + if (!initialized_) { + return false; + } + kernel_meta_map_[kernel_name] = kernel_json; + return true; +} + +bool CheckCache(const std::string &kernel_name) { + // check cache. + KernelMeta *bin_map = KernelMeta::GetInstance(); + if (bin_map == nullptr) { + MS_LOG(DEBUG) << "kernel cache is invalid."; + return false; + } + std::string kernel_json = bin_map->Search(kernel_name); + bool ret = (!kernel_json.empty()); + if (ret) { + MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed."; + } else { + MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed."; + } + return ret; +} + +KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) { + // search cache. + KernelMeta *bin_map = KernelMeta::GetInstance(); + if (bin_map == nullptr) { + MS_LOG(DEBUG) << "kernel cache is invalid."; + return nullptr; + } + + std::string kernel_json = bin_map->Search(kernel_name); + if (!kernel_json.empty()) { + KernelPackPtr kernel_pack = std::make_shared(); + // just a tmp solution. + if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) { + MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "]."; + return nullptr; + } else { + return kernel_pack; + } + } else { + MS_LOG(INFO) << "cache kernel not found[" << kernel_name << "]."; + return nullptr; + } +} + +KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) { + MS_LOG(INFO) << "kernel name:" << kernel_name << ", processr:" << processor; + KernelMeta *bin_map = KernelMeta::GetInstance(); + std::string kernel_json; + if (processor == kProcessorAiCore || processor == kProcessorAiCpu) { + kernel_json = kCceKernelMeta; + } else { + kernel_json = bin_map->GetKernelMetaPath(); + } + (void)kernel_json.append(kernel_name).append(kJsonSuffix); + KernelPackPtr kernel_pack = std::make_shared(); + if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) { + MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "]."; + return nullptr; + } + + if (bin_map == nullptr) { + MS_LOG(DEBUG) << "kernel cache is invalid."; + return nullptr; + } + if (bin_map->Insert(kernel_name, kernel_json)) { + MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernelname[" << kernel_name << "]."; + } + return kernel_pack; +} + +TypeId DtypeToTypeId(const std::string &dtypes) { + auto iter = type_id_maps.find(dtypes); + if (iter != type_id_maps.end()) { + return iter->second; + } else { + MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes; + } +} + +std::string TypeId2String(TypeId type_id) { + auto iter = type_id_str_map.find(type_id); + if (iter == type_id_str_map.end()) { + return std::string(TypeIdLabel(type_id)); + } + return iter->second; +} + +std::string Dtype2ShortType(const std::string &dtypes) { + auto iter = dtype_shortdtype_map_.find(dtypes); + if (iter != dtype_shortdtype_map_.end()) { + return iter->second; + } else { + MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes; + } +} + +size_t GetDtypeNbyte(const std::string &dtypes) { + auto iter = dtype_nbyte_map.find(dtypes); + if (iter != dtype_nbyte_map.end()) { + return iter->second; + } else { + MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes; + } +} + +bool SetInputKernelBuilderInfo(const std::vector> &inputs, size_t real_input_num, + size_t builder_idex, const std::vector &dyn_input_sizes, + const std::shared_ptr &builder) { + MS_EXCEPTION_IF_NULL(builder); + + std::vector inputs_device_type; + std::vector inputs_format; + size_t dyn_input_idx = 0; + size_t kernel_info_index = 0; + MS_EXCEPTION_IF_NULL(inputs[0]); + size_t kernel_info_cnt = inputs[0]->dtypes().size(); + + for (const auto &input : inputs) { + MS_EXCEPTION_IF_NULL(input); + std::string param_type = input->param_type(); + std::vector dtypes = input->dtypes(); + std::vector formats = input->formats(); + if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { + MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size."; + return false; + } + + if (param_type == "dynamic") { + if (dyn_input_sizes.empty()) { + MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; + return false; + } + + for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) { + kernel_info_index++; + auto type_id = DtypeToTypeId(dtypes[builder_idex]); + inputs_device_type.push_back(type_id); + inputs_format.push_back(formats[builder_idex]); + } + dyn_input_idx++; + } else if (param_type == "required") { + kernel_info_index++; + auto type_id = DtypeToTypeId(dtypes[builder_idex]); + inputs_device_type.push_back(type_id); + inputs_format.push_back(formats[builder_idex]); + } else { + if (kernel_info_index < real_input_num) { + MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index; + kernel_info_index++; + auto type_id = DtypeToTypeId(dtypes[builder_idex]); + inputs_device_type.push_back(type_id); + inputs_format.push_back(formats[builder_idex]); + } + } + } + + builder->SetInputsDeviceType(inputs_device_type); + builder->SetInputsFormat(inputs_format); + return true; +} + +bool SetOutputKernelBuilderInfo(const std::vector> &outputs, size_t builder_idex, + const size_t &real_output_num, + const std::shared_ptr &builder) { + // not now but in the next we need to support dynamic output case + MS_EXCEPTION_IF_NULL(builder); + + size_t output_idx = 0; + std::vector outputs_device_type; + std::vector outputs_format; + MS_EXCEPTION_IF_NULL(outputs[0]); + size_t kernel_info_cnt = outputs[0]->dtypes().size(); + + for (const auto &output : outputs) { + MS_EXCEPTION_IF_NULL(output); + if (output_idx >= real_output_num) { + MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!"; + continue; + } + size_t output_num = 0; + if (output->param_type() == "dynamic") { + if (outputs.size() > 1) { + MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!"; + } + output_num = real_output_num; + } else if (output->param_type() == "required") { + output_num = 1; + } else { + if (output_idx < real_output_num) { + MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; + output_num = 1; + } + } + + for (size_t i = 0; i < output_num; i++) { + std::vector dtypes = output->dtypes(); + std::vector formats = output->formats(); + if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { + MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size."; + return false; + } + auto type_id = DtypeToTypeId(dtypes[builder_idex]); + outputs_device_type.push_back(type_id); + outputs_format.push_back(formats[builder_idex]); + output_idx++; + } + } + + builder->SetOutputsFormat(outputs_format); + builder->SetOutputsDeviceType(outputs_device_type); + return true; +} + +void SetKernelBuildInfo(const std::shared_ptr &builder, Processor processor, + const std::shared_ptr &op_info_ptr) { + MS_EXCEPTION_IF_NULL(builder); + MS_EXCEPTION_IF_NULL(op_info_ptr); + + auto imply_type = op_info_ptr->imply_type(); + builder->SetProcessor(processor); + std::string fusion_type = op_info_ptr->fusion_type(); + auto iter = fusion_type_maps.find(fusion_type); + if (iter != fusion_type_maps.end()) { + builder->SetFusionType(iter->second); + } else { + if (imply_type == kAKG) { + MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type; + } + } + + if (imply_type == kAKG) { + builder->SetKernelType(AKG_KERNEL); + } else if (imply_type == kAICPU) { + builder->SetKernelType(AICPU_KERNEL); + } else { + builder->SetKernelType(TBE_KERNEL); + } +} + +bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &op_info_ptr, Processor processor, + std::vector> *const kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node); + size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + std::vector> inputs = op_info_ptr->inputs_ptr(); + std::vector> outputs = op_info_ptr->outputs_ptr(); + std::vector dyn_input_sizes; + auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node); + MS_EXCEPTION_IF_NULL(primitive); + if (primitive->GetAttr("dyn_input_sizes") != nullptr) { + dyn_input_sizes = GetValue>(primitive->GetAttr("dyn_input_sizes")); + } + if (inputs.size() > 0) { + MS_EXCEPTION_IF_NULL(inputs[0]); + size_t kernel_info_cnt = inputs[0]->dtypes().size(); + for (size_t j = 0; j < kernel_info_cnt; j++) { + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + SetKernelBuildInfo(builder, processor, op_info_ptr); + + if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) { + MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed."; + return false; + } + + if (outputs.size() > 0) { + if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) { + MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed."; + return false; + } + } + + kernel_info_list->push_back(builder->Build()); + } + } else if (outputs.size() > 0) { + MS_EXCEPTION_IF_NULL(outputs[0]); + size_t kernel_info_cnt = outputs[0]->dtypes().size(); + for (size_t j = 0; j < kernel_info_cnt; j++) { + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + SetKernelBuildInfo(builder, processor, op_info_ptr); + + if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) { + MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed."; + return false; + } + + kernel_info_list->push_back(builder->Build()); + } + } else { + if (processor == AICPU) { + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + SetKernelBuildInfo(builder, processor, op_info_ptr); + kernel_info_list->push_back(builder->Build()); + } + } + return true; +} + +void SaveJsonInfo(const std::string &json_name, const std::string &info) { + char real_path[PATH_MAX] = {0}; + std::string path = kCceKernelMeta + json_name + kInfoSuffix; + if (path.size() > PATH_MAX) { + MS_LOG(DEBUG) << "file path " << path << " is too long."; + return; + } + std::ofstream filewrite; + filewrite.open(path); + if (!filewrite.is_open()) { + return; + } + filewrite << info << std::endl; + filewrite.close(); +#if defined(_WIN32) || defined(_WIN64) + if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) { + MS_LOG(DEBUG) << "dir " << path << " does not exit."; + return; + } +#else + if (nullptr == realpath(path.c_str(), real_path)) { + MS_LOG(DEBUG) << "dir " << path << " does not exit."; + return; + } +#endif + MS_LOG(INFO) << "real path is :" << real_path; + if (chmod(real_path, S_IRUSR) == -1) { + MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail."; + } +} + +std::string GetProcessor(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string device; + switch (AnfAlgo::GetProcessor(anf_node)) { + case Processor::AICORE: + device = kProcessorAiCore; + break; + + case Processor::AICPU: + device = kProcessorAiCpu; + break; + + case Processor::CUDA: + device = kProcessorCuda; + break; + + default: + MS_LOG(DEBUG) << "Unknown processor type."; + break; + } + return device; +} + +bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b) { + if (shape_a.size() != shape_b.size()) { + return false; + } + for (size_t i = 0; i < shape_a.size(); ++i) { + if (shape_a[i] != shape_b[i]) { + return false; + } + } + return true; +} + +int Sign(float x) { + if (x > 0) { + return 1; + } + if (x < 0) { + return -1; + } + return 0; +} + +void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim) { + MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); + MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); + MS_EXCEPTION_IF_NULL(unique_grad); + MS_EXCEPTION_IF_NULL(unique_grad->value_); + MS_EXCEPTION_IF_NULL(unique_grad->indices_); + std::unordered_map index_map; + size_t unique_indices_size = 0; + for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { + int index = origin_sparse_grad.indices_[i]; + if (index < 0 || IntToSize(index) >= first_dim) { + continue; + } + auto iter = index_map.find(index); + if (iter == index_map.end()) { + index_map[index] = unique_indices_size; + unique_grad->indices_[unique_indices_size] = index; + size_t start_index = unique_indices_size * outer_dim; + size_t end_index = start_index + outer_dim; + for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { + unique_grad->value_[j] = origin_sparse_grad.value_[k]; + } + unique_indices_size++; + } else { + size_t first_index = iter->second; + size_t start_index = first_index * outer_dim; + size_t end_index = start_index + outer_dim; + for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { + unique_grad->value_[j] += origin_sparse_grad.value_[k]; + } + } + } + unique_grad->indices_size_ = unique_indices_size; +} + +struct WorkerParamsForReduceSparseGradient { + size_t slice_start_{0}; + size_t slice_end_{0}; + size_t max_length_{0}; + size_t outer_dim_{0}; + std::vector> *sorted_indices_{nullptr}; + std::vector *slice_positions_{nullptr}; + float *src_value_{nullptr}; + SparseGradient *unique_grad_{nullptr}; +}; + +void WorkerForReduceSparseGradient(WorkerParamsForReduceSparseGradient param) { + MS_EXCEPTION_IF_NULL(param.sorted_indices_); + MS_EXCEPTION_IF_NULL(param.slice_positions_); + MS_EXCEPTION_IF_NULL(param.src_value_); + MS_EXCEPTION_IF_NULL(param.unique_grad_); + auto outer_dim = param.outer_dim_; + auto &sorted_indices = *(param.sorted_indices_); + auto &slice_positions = *(param.slice_positions_); + auto unique_grad = param.unique_grad_; + for (size_t slice_id = param.slice_start_; slice_id < param.slice_end_; ++slice_id) { + size_t cur_pos = slice_positions[slice_id]; + int index = sorted_indices[cur_pos].first; + unique_grad->indices_[slice_id] = index; + size_t start_index = slice_id * outer_dim; + auto ret_code = memcpy_s(unique_grad->value_ + start_index, (param.max_length_ - start_index) * sizeof(float), + param.src_value_ + sorted_indices[cur_pos].second, outer_dim * sizeof(float)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + cur_pos++; + size_t end_pos; + if (slice_id + 1 < slice_positions.size()) { + end_pos = slice_positions[slice_id + 1]; + } else { + end_pos = sorted_indices.size(); + } + while (cur_pos < end_pos) { + for (size_t i = 0; i < outer_dim; ++i) { + unique_grad->value_[start_index + i] += param.src_value_[sorted_indices[cur_pos].second + i]; + } + cur_pos++; + } + } +} + +void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, + size_t outer_dim, std::vector> *sorted_indices, + std::vector *slice_positions) { + MS_LOG(DEBUG) << "Start"; + size_t thread_num = 24; + if (slice_positions->size() < thread_num) { + thread_num = slice_positions->size(); + } + size_t stride = (slice_positions->size() + thread_num - 1) / thread_num; + thread_num = (slice_positions->size() + stride - 1) / stride; + std::vector threads; + size_t max_length = sorted_indices->size() * outer_dim; + for (size_t i = 0; i < thread_num; ++i) { + size_t slice_start = i * stride; + size_t slice_end = 0; + if (i == thread_num - 1) { + slice_end = slice_positions->size(); + } else { + slice_end = slice_start + stride; + } + WorkerParamsForReduceSparseGradient params{ + slice_start, slice_end, max_length, outer_dim, sorted_indices, slice_positions, origin_sparse_grad.value_, + unique_grad}; + threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params)); + } + for (size_t i = 0; i < thread_num; ++i) { + threads[i].join(); + } + MS_LOG(DEBUG) << "End"; +} + +void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim, bool use_multi_threads) { + MS_LOG(DEBUG) << "Start"; + MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); + MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); + MS_EXCEPTION_IF_NULL(unique_grad); + MS_EXCEPTION_IF_NULL(unique_grad->value_); + MS_EXCEPTION_IF_NULL(unique_grad->indices_); + std::vector> sorted_indices; + sorted_indices.reserve(origin_sparse_grad.indices_size_); + for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { + int index = origin_sparse_grad.indices_[i]; + if (index >= 0 && IntToSize(index) < first_dim) { + sorted_indices.emplace_back(std::pair(index, i * outer_dim)); + } + } + std::sort( + sorted_indices.begin(), sorted_indices.end(), + [](const std::pair &left, const std::pair &right) { return left.first < right.first; }); + int last_index = 0; + std::vector slice_positions; + slice_positions.reserve(sorted_indices.size()); + for (size_t i = 0; i < sorted_indices.size(); ++i) { + if (i == 0 || last_index != sorted_indices[i].first) { + slice_positions.emplace_back(i); + } + last_index = sorted_indices[i].first; + } + if (use_multi_threads) { + RunMultiThreadReduceSparseGradient(origin_sparse_grad, unique_grad, outer_dim, &sorted_indices, &slice_positions); + } else { + size_t max_length = sorted_indices.size() * outer_dim; + WorkerParamsForReduceSparseGradient params{0, + slice_positions.size(), + max_length, + outer_dim, + &sorted_indices, + &slice_positions, + origin_sparse_grad.value_, + unique_grad}; + WorkerForReduceSparseGradient(params); + } + unique_grad->indices_size_ = slice_positions.size(); + MS_LOG(DEBUG) << "End"; +} + +void ReduceMultiSparseGradient(const std::vector> &unique_slice_grads, + SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim) { + MS_LOG(DEBUG) << "Start"; + if (unique_slice_grads.empty()) { + return; + } + size_t index_data_size = outer_dim * sizeof(float); + size_t unique_indices_size = 0; + for (size_t i = 0; i < unique_slice_grads.size(); ++i) { + auto &slice_grad = unique_slice_grads[i]; + auto ret_code = memcpy_s(tmp_grad->value_ + unique_indices_size * outer_dim, + (tmp_grad->indices_size_ - unique_indices_size) * index_data_size, slice_grad->value_, + slice_grad->indices_size_ * index_data_size); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + ret_code = + memcpy_s(tmp_grad->indices_ + unique_indices_size, (tmp_grad->indices_size_ - unique_indices_size) * sizeof(int), + slice_grad->indices_, slice_grad->indices_size_ * sizeof(int)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + unique_indices_size += slice_grad->indices_size_; + } + tmp_grad->indices_size_ = unique_indices_size; + ReduceSparseGradient(*tmp_grad, unique_grad, first_dim, outer_dim); + MS_LOG(DEBUG) << "End"; +} + +void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad, + SparseGradient *unique_grad, size_t first_dim, size_t outer_dim) { + MS_LOG(DEBUG) << "Start"; + MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); + MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); + MS_EXCEPTION_IF_NULL(unique_grad); + MS_EXCEPTION_IF_NULL(unique_grad->value_); + MS_EXCEPTION_IF_NULL(unique_grad->indices_); + MS_EXCEPTION_IF_NULL(tmp_grad); + MS_EXCEPTION_IF_NULL(tmp_grad->value_); + MS_EXCEPTION_IF_NULL(tmp_grad->indices_); + size_t thread_num = 24; + if (origin_sparse_grad.indices_size_ < thread_num) { + thread_num = origin_sparse_grad.indices_size_; + } + size_t thread_indices_size = origin_sparse_grad.indices_size_ / thread_num; + size_t left_indices_size = origin_sparse_grad.indices_size_ % thread_num; + std::vector threads; + threads.reserve(thread_num); + std::vector> unique_slice_grads; + for (size_t i = 0; i < thread_num; ++i) { + size_t indices_size = thread_indices_size; + if (i == thread_num - 1) { + indices_size = thread_indices_size + left_indices_size; + } + size_t value_offset = i * thread_indices_size * outer_dim; + size_t indices_offset = i * thread_indices_size; + auto slice_grad = SparseGradient( + {origin_sparse_grad.value_ + value_offset, origin_sparse_grad.indices_ + indices_offset, indices_size}); + unique_slice_grads.emplace_back(std::make_shared()); + unique_slice_grads[i]->value_ = unique_grad->value_ + value_offset; + unique_slice_grads[i]->indices_ = unique_grad->indices_ + indices_offset; + unique_slice_grads[i]->indices_size_ = indices_size; + threads.emplace_back( + std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim, false)); + } + for (size_t i = 0; i < thread_num; ++i) { + threads[i].join(); + } + ReduceMultiSparseGradient(unique_slice_grads, tmp_grad, unique_grad, first_dim, outer_dim); + MS_LOG(DEBUG) << "End"; +} + +std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + + if (index >= AnfAlgo::GetInputTensorNum(anf_node)) { + MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs."; + } + + auto cnode = anf_node->cast(); + if (cnode == nullptr) { + return AnfAlgo::VisitKernel(anf_node, 0); + } else { + return AnfAlgo::VisitKernel(anf_node->cast()->input(index + 1), 0); + } +} + +std::vector>> GetInputIndex(const std::vector &node_list, + const std::vector &input_list) { + std::vector>> input_index; + for (size_t i = 0; i < input_list.size(); ++i) { + auto const &input = input_list[i]; + MS_EXCEPTION_IF_NULL(input); + bool found = false; + // using NodeUsersMap = std::unordered_map>>; + auto mng = input->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(mng); + const NodeUsersMap &users = mng->node_users(); + auto input_users = users.find(input); + if (input_users == users.end() || input_users->second.empty()) { + MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" + << input->func_graph()->ToString() << "] has no users."; + } + + for (auto const &input_user : input_users->second) { + for (auto const &anf_node : node_list) { + if (anf_node != input_user.first) { + continue; + } + + std::vector dyn_input_sizes; + auto prim = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(prim); + if (prim->GetAttr(kAttrDynInputSizes) != nullptr) { + dyn_input_sizes = GetValue>(prim->GetAttr(kAttrDynInputSizes)); + } + + if (dyn_input_sizes.empty()) { + input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0))); + found = true; + break; + } else { + int used_as_idx = input_user.second - 1; + int accum_idx = 0; + size_t dyn_i = 0; + for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) { + accum_idx += dyn_input_sizes[dyn_i]; + if (used_as_idx < accum_idx) { + input_index.push_back(std::make_pair( + anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i]))))); + break; + } + } + if (dyn_i != dyn_input_sizes.size()) { + found = true; + break; + } + } + } + if (found) { + break; + } + } + + if (!found) { + MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" + << input->func_graph()->ToString() << "] found no related kernel info."; + } + } + return input_index; +} + +std::vector> GetOutputIndex(const std::vector &node_list, + const std::vector &input_list, + const std::vector &output_list) { + std::vector> output_index; + for (size_t i = 0; i < output_list.size(); ++i) { + auto const &output = output_list[i]; + MS_EXCEPTION_IF_NULL(output); + bool found = false; + auto pree_node = AnfAlgo::VisitKernel(output, 0); + auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first); + if (pos != std::end(node_list)) { + output_index.push_back(pree_node); + continue; + } + auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first); + if (ret != std::end(input_list)) { + output_index.push_back(std::make_pair(pree_node.first, 0)); + found = true; + } + if (!found) { + MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of [" + << output->func_graph()->ToString() << "] found no related kernel info."; + } + } + return output_index; +} + +void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list) { + MS_EXCEPTION_IF_NULL(node_list); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector node_lists = TopoSort(func_graph->get_return()); + for (auto const &node : node_lists) { + if (!AnfAlgo::IsRealKernel(node) || !node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (IsValueNode(cnode->input(kAnfPrimitiveIndex))) { + node_list->push_back(node); + } + } +} + +void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list, + std::vector *input_list, std::vector *output_list) { + MS_EXCEPTION_IF_NULL(node_list); + MS_EXCEPTION_IF_NULL(input_list); + MS_EXCEPTION_IF_NULL(output_list); + MS_EXCEPTION_IF_NULL(func_graph); + + GetValidKernelNodes(func_graph, node_list); + + auto parameters = func_graph->parameters(); + input_list->insert(input_list->begin(), parameters.begin(), parameters.end()); + + auto func_output = func_graph->output(); + MS_EXCEPTION_IF_NULL(func_output); + if (func_output->isa()) { + // multi output. + auto cnode = func_output->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input0 = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(input0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) { + auto input_node = cnode->input(input_idx); + MS_EXCEPTION_IF_NULL(input_node); + output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first); + } + } else { + // single output. + output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); + } + } else { + // single output. + output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); + } +} + +bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(node_json); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (input_idx + 1 >= cnode->size()) { + MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" + << cnode->inputs().size() << "][" << cnode->DebugString() << "]"; + } + + auto input_node = cnode->input(input_idx + 1); + if (!IsValueNode(input_node)) { + return false; + } + + auto tensor = GetValueNode(input_node); + if (tensor == nullptr) { + return false; + } + + auto type_id = tensor->data_type(); + auto *data = tensor->data_c(); + MS_EXCEPTION_IF_NULL(data); + if (tensor->DataDim() > 1 || tensor->DataSize() != 1) { + // not const tensor. + MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]"; + } + + if (type_id == kFloat32->type_id()) { + float *val = static_cast(data); + MS_EXCEPTION_IF_NULL(val); + (*node_json)["value"] = val[0]; + MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "]."; + return true; + } else if (type_id == kFloat16->type_id()) { + float16 *val = static_cast(data); + MS_EXCEPTION_IF_NULL(val); + (*node_json)["value"] = static_cast(val[0]); + MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "]."; + return true; + } else if (type_id == kInt32->type_id()) { + int *val = static_cast(data); + MS_EXCEPTION_IF_NULL(val); + (*node_json)["value"] = val[0]; + MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "]."; + return true; + } + MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]"; + return false; +} + +void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *node_list) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node_list); + auto output = func_graph->output(); + MS_EXCEPTION_IF_NULL(output); + if (AnfAlgo::IsRealKernel(output)) { + // single output. + node_list->push_back(std::make_pair(output, 0)); + return; + } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + // multi output. + auto &inputs = output_cnode->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0); + node_list->push_back(in_with_idx); + } + return; + } + MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2) + << " of graph: " << func_graph->ToString(); +} + +bool IsWeightBoundary(const AnfNodePtr &node) { + if (node->isa()) { + return true; + } + if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { + return true; + } + return false; +} + +void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params, + size_t total_compute_size) { + const size_t kThreadNum = 24; + std::vector threads; + threads.reserve(kThreadNum); + size_t start = 0; + size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum; + while (start < total_compute_size) { + size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size); + threads.emplace_back(std::thread(func, params, start, end)); + start += once_compute_size; + } + for (size_t i = 0; i < threads.size(); ++i) { + threads[i].join(); + } +} + +std::vector GetReduceAttrAxis(const CNodePtr &cnode) { + if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) && + AnfAlgo::GetInputTensorNum(cnode) != 1) { + MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString() + << "] is not single input or single output "; + } + std::vector axis; + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + auto axis_attr = primitive->GetAttr(kAxis); + if (axis_attr == nullptr) { + MS_LOG(ERROR) << "This node does't have axie attr."; + return std::vector(); + } + auto type = axis_attr->type(); + MS_EXCEPTION_IF_NULL(type); + std::vector axis_list; + if (type->ToString() == kTypeInt32) { + axis_list.emplace_back(GetValue(axis_attr)); + } else { + axis_list = GetValue>(axis_attr); + } + for (const auto &elem : axis_list) { + if (elem < 0) { + axis.emplace_back(input_shape.size() + elem); + } else { + axis.emplace_back(elem); + } + } + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode); + return axis; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h new file mode 100644 index 0000000000..8c9ea84b34 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h @@ -0,0 +1,145 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ +#define MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/oplib/opinfo.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +constexpr auto kCceKernelMeta = "./kernel_meta/"; +constexpr auto kGpuKernelMeta = "./cuda_meta"; +constexpr auto kProcessorAiCore = "aicore"; +constexpr auto kProcessorAiCpu = "aicpu"; +constexpr auto kProcessorCuda = "cuda"; +constexpr auto kJsonSuffix = ".json"; +constexpr auto kInfoSuffix = ".info"; +constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600; +constexpr auto kAkgModule = "_akg"; +constexpr auto kArgDataformat = "data_format"; + +const std::vector support_devices = {"aicore", "aicpu", "cuda"}; + +struct KernelMetaInfo { + uintptr_t func_stub_; + uint32_t block_dim_; +}; +using KernelMetaPtr = std::shared_ptr; + +class KernelMeta { + public: + KernelMeta() = default; + void Initialize(); + void RemoveKernelCache(); + std::string Search(const std::string &kernel_name) const; + bool Insert(const std::string &kernel_name, const std::string &kernel_json); + std::string GetKernelMetaPath() { return kernel_meta_path_; } + + static KernelMeta *GetInstance() { + static KernelMeta kernel_meta; + return &kernel_meta; + } + ~KernelMeta() = default; + + private: + bool initialized_ = false; + std::string kernel_meta_path_; + std::unordered_map kernel_meta_map_; +}; + +struct SparseGradient { + float *value_; + int *indices_; + size_t indices_size_; +}; + +struct MultiThreadComputeParams { + float *var_; + float *accum_; + float *linear_; + float *m_; + float *m_t_; + float *v_; + float lr_; + float l1_; + float l2_; + float lr_power_; + float beta1_; + float beta2_; + float epsilon_; + SparseGradient sparse_grad_; + size_t var_first_dim_size_; + size_t var_outer_dim_size_; + bool use_nesterov_; +}; +using MultiThreadComputeFunc = std::function; + +bool CheckCache(const std::string &kernel_name); +KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); +KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); +TypeId DtypeToTypeId(const std::string &dtypes); +std::string Dtype2ShortType(const std::string &dtypes); +std::string TypeId2String(TypeId type_id); +size_t GetDtypeNbyte(const std::string &dtypes); +bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &op_info_ptr, Processor processor, + std::vector> *const kernel_info_list); +void SaveJsonInfo(const std::string &json_name, const std::string &info); +std::string GetProcessor(const AnfNodePtr &anf_node); +bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b); +int Sign(float x); +void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim); +void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim, bool use_multi_threads = true); +std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index); +std::vector>> GetInputIndex(const std::vector &node_list, + const std::vector &input_list); +std::vector> GetOutputIndex(const std::vector &node_list, + const std::vector &input_list, + const std::vector &output_list); +void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list, + std::vector *input_list, std::vector *output_list); +void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list); +bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json); +void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *node_list); +bool IsWeightBoundary(const AnfNodePtr &node); +void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params, + size_t total_compute_size); +void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, + size_t outer_dim, std::vector> *sorted_indices, + std::vector *slice_positions); +void ReduceMultiSparseGradient(const std::vector> &unique_slice_grads, + SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim); +void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad, + SparseGradient *unique_grad, size_t first_dim, size_t outer_dim); +std::vector GetReduceAttrAxis(const CNodePtr &cnode); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.cc new file mode 100644 index 0000000000..1300847d40 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/addn_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + input_num_ = AnfAlgo::GetInputTensorNum(kernel_node); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +bool AddNCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + + size_t offset = 0; + for (size_t i = 0; i < output_shape_[0]; ++i) { + for (size_t j = 0; j < output_shape_[1]; ++j) { + for (size_t k = 0; k < output_shape_[2]; ++k) { + for (size_t m = 0; m < output_shape_[3]; ++m) { + float sum = 0; + for (size_t index = 0; index < input_num_; ++index) { + auto input_addr = reinterpret_cast(inputs[index]->addr); + sum += input_addr[offset]; + } + output_addr[offset++] = sum; + } + } + } + } + + return true; +} + +void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but AddNCPUKernel olny support 4d or lower."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AddNCPUKernel needs 1 output."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h new file mode 100644 index 0000000000..925f0fab50 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class AddNCPUKernel : public CPUKernel { + public: + AddNCPUKernel() : input_num_(0) {} + ~AddNCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + size_t input_num_; + std::vector output_shape_; +}; + +MS_REG_CPU_KERNEL(AddN, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AddNCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.cc new file mode 100644 index 0000000000..55afecb8fa --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/allgather_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "runtime/device/cpu/mpi/mpi_adapter.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto kRanksGroup = "group"; +constexpr auto kAllGatherInputNum = 1; +} // namespace + +void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != kAllGatherInputNum) { + MS_LOG(EXCEPTION) << "allgather input num:" << input_num; + } + + auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); + if (ranks_group != nullptr) { + ranks_group_ = GetValue>(ranks_group); + } else { + MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup; + } +} + +bool AllGatherCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto input_data_num = inputs[0]->size / sizeof(float); + auto mpi_instance = device::cpu::MPIAdapter::Instance(); + MS_EXCEPTION_IF_NULL(mpi_instance); + return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h new file mode 100644 index 0000000000..42c83ccf0b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class AllGatherCPUKernel : public CPUKernel { + public: + AllGatherCPUKernel() = default; + ~AllGatherCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::vector ranks_group_; +}; + +MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AllGatherCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc new file mode 100644 index 0000000000..c1ff8d54bd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void ApplyMomentumCPUKernel::InitKernel(const CNodePtr & /*kernel_node*/) {} + +bool ApplyMomentumCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector & /*outputs*/) { + if (inputs.size() < 5) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[3]->size) { + MS_LOG(EXCEPTION) << "error input data size!"; + } + auto weight = reinterpret_cast(inputs[0]->addr); + auto accumulate = reinterpret_cast(inputs[1]->addr); + float learning_rate = reinterpret_cast(inputs[2]->addr)[0]; + auto gradient = reinterpret_cast(inputs[3]->addr); + float moment = reinterpret_cast(inputs[4]->addr)[0]; + size_t elem_num = inputs[0]->size / sizeof(float); + for (size_t i = 0; i < elem_num; ++i) { + accumulate[i] = accumulate[i] * moment + gradient[i]; + weight[i] -= accumulate[i] * learning_rate; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h new file mode 100644 index 0000000000..23e8488890 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class ApplyMomentumCPUKernel : public MKLCPUKernel { + public: + ApplyMomentumCPUKernel() = default; + ~ApplyMomentumCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + ApplyMomentumCPUKernel); +MS_REG_CPU_KERNEL(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + ApplyMomentumCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc new file mode 100644 index 0000000000..d67c4d47ff --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/argmax_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void ArgmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (shape.size() != 2) { + MS_LOG(EXCEPTION) << "argmax kernel dims invalid " << shape.size(); + } + batch_size_ = shape[0]; + class_num_ = shape[1]; + + int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis != -1 && axis != 1) { + MS_LOG(EXCEPTION) << "argmax kernel not support axis " << axis; + } +} + +bool ArgmaxCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspaces*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + + size_t batch_float_size = batch_size_ * sizeof(float); + size_t batch_class_float_size = class_num_ * batch_float_size; + if (inputs[0]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { + MS_LOG(EXCEPTION) << "invalid input or output data size!"; + } + auto input = reinterpret_cast(inputs[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + size_t row_start = 0; + for (size_t i = 0; i < batch_size_; ++i) { + size_t max_index = 0; + float max_value = input[row_start]; + for (size_t j = 1; j < class_num_; ++j) { + size_t index = row_start + j; + if (input[index] > max_value) { + max_value = input[index]; + max_index = j; + } + } + output[i] = SizeToInt(max_index); + row_start += class_num_; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h new file mode 100644 index 0000000000..3883344f96 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ArgmaxCPUKernel : public CPUKernel { + public: + ArgmaxCPUKernel() = default; + ~ArgmaxCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t class_num_{0}; + size_t batch_size_{0}; +}; + +MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + ArgmaxCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.cc new file mode 100644 index 0000000000..f42bb6807d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/bias_add_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +void BiasAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + bias_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + if (input_shape_.size() == 4) { + data_shape_ = 4; + } else if (input_shape_.size() == 2) { + data_shape_ = 2; + } else { + MS_LOG(EXCEPTION) << "bias add input data format should be NCHW or NC"; + } + if (input_shape_.size() != 2 && input_shape_.size() != 4) { + MS_LOG(EXCEPTION) << "bias add input shape nchw or nc"; + } + if (bias_shape_.size() != 1) { + MS_LOG(EXCEPTION) << "bias shape invalid"; + } + if (input_shape_[1] != bias_shape_[0]) { + MS_LOG(EXCEPTION) << "bias shape not match"; + } +} + +bool BiasAddCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() != 2 || outputs.size() != 1) { + MS_LOG(EXCEPTION) << "inputs outputs size not supoort"; + } + + auto src_addr = reinterpret_cast(inputs[0]->addr); + auto bias_addr = reinterpret_cast(inputs[1]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + if (data_shape_ == 4) { + size_t h_size = input_shape_[3]; + size_t c_size = input_shape_[2] * h_size; + size_t n_size = input_shape_[1] * c_size; + size_t hw_size = input_shape_[2] * input_shape_[3]; + size_t n_offset = 0; + for (size_t n = 0; n < input_shape_[0]; ++n) { + size_t c_offset = 0; + for (size_t c = 0; c < input_shape_[1]; ++c) { + for (size_t hw = 0; hw < hw_size; ++hw) { + size_t offset = n_offset + c_offset + hw; + output_addr[offset] = src_addr[offset] + bias_addr[c]; + } + c_offset += c_size; + } + n_offset += n_size; + } + } else { + size_t n_offset = 0; + for (size_t n = 0; n < input_shape_[0]; ++n) { + for (size_t c = 0; c < input_shape_[1]; ++c) { + output_addr[n_offset + c] = src_addr[n_offset + c] + bias_addr[c]; + } + n_offset += input_shape_[1]; + } + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h new file mode 100644 index 0000000000..c572f68230 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class BiasAddCPUKernel : public CPUKernel { + public: + BiasAddCPUKernel() = default; + ~BiasAddCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + uint8_t data_shape_{0}; + std::vector input_shape_; + std::vector bias_shape_; +}; +MS_REG_CPU_KERNEL( + BiasAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BiasAddCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.cc new file mode 100644 index 0000000000..8b6e2d0188 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +void BiasAddGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (input_shape_.size() != 4 && input_shape_.size() != 2) { + MS_LOG(EXCEPTION) << "input data format not support"; + } +} + +bool BiasAddGradCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() != 1 || outputs.size() != 1) { + MS_LOG(EXCEPTION) << "input output size not support"; + } + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto input_addr = reinterpret_cast(inputs[0]->addr); + + if (input_shape_.size() == 4) { + size_t h_size = input_shape_[3]; + size_t c_size = h_size * input_shape_[2]; + size_t n_size = c_size * input_shape_[1]; + size_t hw_size = input_shape_[2] * input_shape_[3]; + size_t c_offset = 0; + for (size_t c = 0; c < input_shape_[1]; ++c) { + output_addr[c] = 0; + size_t n_offset = 0; + for (size_t n = 0; n < input_shape_[0]; ++n) { + for (size_t hw = 0; hw < hw_size; ++hw) { + size_t offset = c_offset + n_offset + hw; + output_addr[c] += input_addr[offset]; + } + n_offset += n_size; + } + c_offset += c_size; + } + } else if (input_shape_.size() == 2) { + for (size_t c = 0; c < input_shape_[1]; ++c) { + output_addr[c] = 0; + size_t n_offset = 0; + for (size_t n = 0; n < input_shape_[0]; ++n) { + output_addr[c] += input_addr[c + n_offset]; + n_offset += input_shape_[1]; + } + } + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h new file mode 100644 index 0000000000..a5743879a7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class BiasAddGradCPUKernel : public CPUKernel { + public: + BiasAddGradCPUKernel() = default; + ~BiasAddGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::vector input_shape_; +}; +MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BiasAddGradCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc new file mode 100644 index 0000000000..6776c0f154 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/concat_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + + axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(input_1_shape.size()); + } + axis_ += 4 - input_1_shape.size(); + + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; i++) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); + CPUKernelUtils::ExpandDimsTo4(&input_shape); + input_shape_list_.push_back(input_shape); + } + + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +bool ConcatCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto buff_size = outputs[0]->size; + size_t dim0 = output_shape_[0]; + size_t dim1 = output_shape_[1]; + size_t dim2 = output_shape_[2]; + + if (axis_ == 3) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + for (size_t k = 0; k < dim2; ++k) { + CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size); + } + } + } + } else if (axis_ == 2) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size); + } + } + } else if (axis_ == 1) { + for (size_t i = 0; i < dim0; ++i) { + CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size); + } + } else if (axis_ == 0) { + CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); + } + return true; +} + +void ConcatCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, + size_t dim2, float **output_addr, size_t *buff_size) { + for (size_t i = 0; i < input_shape_list_.size(); ++i) { + auto input_i_shape = input_shape_list_[i]; + auto input_i_addr = reinterpret_cast(inputs[i]->addr); + + size_t num = CPUKernelUtils::GetElementNumOnAxis(input_i_shape, axis_); + num *= input_i_shape[axis_]; + auto pos = CPUKernelUtils::CalcOffset(input_i_shape, dim0, dim1, dim2, 0); + auto ret = memcpy_s(*output_addr, *buff_size, input_i_addr + pos, num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy failed."; + } + *output_addr += num; + *buff_size -= num * sizeof(float); + } +} + +void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but ConcatCPUKernel olny support 4d or lower."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h new file mode 100644 index 0000000000..94e4ad40f3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ConcatCPUKernel : public CPUKernel { + public: + ConcatCPUKernel() : axis_(0) {} + ~ConcatCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, + float **output_addr, size_t *buff_size); + int axis_; + std::vector> input_shape_list_; + std::vector output_shape_; +}; + +MS_REG_CPU_KERNEL(Concat, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ConcatCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc new file mode 100644 index 0000000000..fb9398e7c4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/cpu_kernel.h" + +namespace mindspore { +namespace kernel { +void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + size_t type_size = sizeof(float); + for (size_t input_index = 0; input_index < input_num; ++input_index) { + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index); + size_t tensor_size = + shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + input_size_list_.emplace_back(tensor_size); + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { + std::vector shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index); + size_t tensor_size = + shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + output_size_list_.emplace_back(tensor_size); + } +} + +void CPUKernel::Init(const CNodePtr &kernel_node) { + InitKernel(kernel_node); + InitInputOutputSize(kernel_node); +} + +void CPUKernelUtils::ExpandDimsTo4(std::vector *shape) { + auto len = shape->size(); + if (len < 4) { + for (size_t i = 0; i < 4 - len; ++i) { + shape->insert(shape->begin(), 1); + } + } +} + +size_t CPUKernelUtils::CalcOffset(const std::vector &shape, size_t dim0, size_t dim1, size_t dim2, + size_t dim3) { + size_t offset = dim0 * shape[1] * shape[2] * shape[3] + dim1 * shape[2] * shape[3] + dim2 * shape[3] + dim3; + return offset; +} + +size_t CPUKernelUtils::GetElementNumOnAxis(const std::vector &shape, int axis) { + if (axis < 0) { + axis = axis + SizeToInt(shape.size()); + } + size_t result = 1; + for (int j = 3; j > axis; --j) { + result *= shape[j]; + } + return result; +} + +void CPUKernelUtils::GetElementNumEveryDim(const std::vector &shape, std::vector *element_num) { + size_t accumulation = 1; + element_num->emplace_back(1); + for (size_t i = shape.size() - 1; i > 0; --i) { + accumulation *= shape[i]; + element_num->emplace_back(accumulation); + } + std::reverse(element_num->begin(), element_num->end()); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h new file mode 100644 index 0000000000..f2aa292c6e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -0,0 +1,87 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "ir/anf.h" +#include "backend/session/anf_runtime_algorithm.h" + +using mindspore::kernel::Address; +using mindspore::kernel::AddressPtr; +namespace mindspore { +namespace kernel { +const char KSIZE[] = "ksize"; +const char STRIDE[] = "stride"; +const char STRIDES[] = "strides"; +const char DILATION[] = "dilation"; +const char PAD[] = "pad"; +const char PAD_MODE[] = "pad_mode"; +const char PADDING[] = "padding"; +const char PAD_MODE_LOWER_SAME[] = "same"; +const char PAD_MODE_LOWER_VALID[] = "valid"; +const char PAD_MODE_UPPER_SAME[] = "SAME"; +const char PAD_MODE_UPPER_VALID[] = "VALID"; +const char TRANSPOSE_A[] = "transpose_a"; +const char TRANSPOSE_B[] = "transpose_b"; +const char IS_GRAD[] = "is_grad"; +const char TRANSPOSE_NO = 'N'; +const char TRANSPOSE_YES = 'T'; +const char AXIS[] = "axis"; +const char BEGIN[] = "begin"; +const char END[] = "end"; +const char SIZE[] = "size"; +const char USE_NESTEROV[] = "use_nesterov"; + +class CPUKernel : public kernel::KernelMod { + public: + CPUKernel() = default; + ~CPUKernel() override = default; + virtual void Init(const CNodePtr &kernel_node); + virtual void InitKernel(const CNodePtr &kernel_node) = 0; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void * /*stream_ptr*/) override { + return Launch(inputs, workspace, outputs); + }; + virtual bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) = 0; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + protected: + virtual void InitInputOutputSize(const CNodePtr &kernel_node); + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +class CPUKernelUtils { + public: + static void ExpandDimsTo4(std::vector *shape); + static size_t CalcOffset(const std::vector &shape, size_t dim0, size_t dim1, size_t dim2, size_t dim3); + static size_t GetElementNumOnAxis(const std::vector &shape, int axis); + static void GetElementNumEveryDim(const std::vector &shape, std::vector *element_num); +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc new file mode 100644 index 0000000000..accd742976 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +#include +#include +#include + +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace kernel { +CPUKernelFactory &CPUKernelFactory::GetInstance() { + static CPUKernelFactory instance; + return instance; +} + +void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr, + CPUKernelCreator &&kernel_creator) { + (void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator); +#if !defined(_WIN32) && !defined(_WIN64) + MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name; +#endif +} + +std::shared_ptr CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { + auto kernel_info = dynamic_cast(apply_kernel->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(kernel_build_Info); + std::pair ret_pair = CPUKernelAttrCheck(kernel_name, *kernel_build_Info); + if (ret_pair.first) { + return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second(); + } + return nullptr; +} + +std::pair CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name, + const KernelBuildInfo &kernel_info) { + auto iter = name_to_attr_creator_.find(kernel_name); + if (iter == name_to_attr_creator_.end()) { + MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!"; + return std::make_pair(false, 0); + } + auto creators = iter->second; + for (size_t index = 0; index < creators.size(); ++index) { + auto attr_creator = creators[index]; + if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) { + return std::make_pair(true, index); + } + } + return std::make_pair(false, 0); +} + +bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) { + for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) { + auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first; + if (kernel_info.GetInputDeviceType(i) != dtype) { + MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i) + << ", register type:" << dtype; + return false; + } + } + for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) { + auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first; + if (kernel_info.GetOutputDeviceType(i) != dtype) { + MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i) + << ", register type:" << dtype; + return false; + } + } + return true; +} + +std::vector CPUKernelFactory::GetSupportedKernelAttrList(const std::string &kernel_name) { + std::vector result; + auto iter = name_to_attr_creator_.find(kernel_name); + if (iter == name_to_attr_creator_.end()) { + MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!"; + return result; + } + auto creators = iter->second; + for (size_t index = 0; index < creators.size(); ++index) { + auto attr_creator = creators[index]; + result.push_back(attr_creator.first); + } + return result; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h new file mode 100644 index 0000000000..80f9a342ac --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "runtime/device/cpu/kernel_select_cpu.h" + +namespace mindspore { +namespace kernel { +using mindspore::device::cpu::KernelAttr; +using CPUKernelCreator = std::function()>; +class CPUKernelFactory { + public: + static CPUKernelFactory &GetInstance(); + void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator); + std::shared_ptr Create(const std::string &kernel_name, const CNodePtr &apply_kernel); + std::vector GetSupportedKernelAttrList(const std::string &kernel_name); + + private: + CPUKernelFactory() = default; + ~CPUKernelFactory() = default; + DISABLE_COPY_AND_ASSIGN(CPUKernelFactory) + std::pair CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info); + bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info); + std::map>> name_to_attr_creator_; +}; + +class CPUKernelRegistrar { + public: + CPUKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator) { + CPUKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator)); + } + ~CPUKernelRegistrar() = default; +}; + +#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS) +#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) +#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \ + static_assert(std::is_base_of::value, " must be base of CPUKernel"); \ + static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ + []() { return std::make_shared(); }); + +#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T) +#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) +#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \ + static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ + static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \ + #OPNAME, ATTR, []() { return std::make_shared>(); }); + +#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ + static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ + static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg( \ + #OPNAME, ATTR, []() { return std::make_shared>(); }); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc new file mode 100644 index 0000000000..344f03cc53 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/debug_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif + +namespace mindspore { +namespace kernel { +void DebugCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); } + +bool DebugCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 1 || outputs.empty()) { + MS_LOG(EXCEPTION) << " input or output empty!"; + } + auto val = reinterpret_cast(inputs[0]->addr); + MS_LOG(DEBUG) << " launch DebugCountCPUKernel val " << *val; + + auto output = reinterpret_cast(outputs[0]->addr); + size_t elem_num = inputs[0]->size / sizeof(int); + for (size_t i = 0; i < elem_num; i++) { + output[i] = val[i]; + } + +#ifdef ENABLE_DEBUGGER + // debugger will suspend execution is neccessary + Debugger::GetInstance()->PostDebugOp(); +#endif + + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h new file mode 100644 index 0000000000..18302e8992 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class DebugCPUKernel : public CPUKernel { + public: + DebugCPUKernel() = default; + ~DebugCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(Debug, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), DebugCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.cc new file mode 100644 index 0000000000..1bcc36faa4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "runtime/device/cpu/mpi/mpi_adapter.h" + +namespace mindspore { +namespace kernel { +void EmbeddingLookUpCommGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + split_num_ = AnfAlgo::GetNodeAttr(kernel_node, "split_num"); + MS_LOG(INFO) << "split_num: " << split_num_; + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape[0] % split_num_ != 0) { + MS_LOG(EXCEPTION) << "Input shape[0] is " << input_shape[0] << ", but it must be multiple of split_num."; + } +} + +bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); +#endif + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + size_t input_size = inputs[0]->size; + size_t output_size = outputs[0]->size; + MS_LOG(DEBUG) << "input addr: " << input_addr << "input size: " << input_size; + MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << output_size; + memset_s(output_addr, output_size, 0, output_size); + const std::vector &rank_group = {0, 1, 2, 3, 4, 5, 6, 7}; + size_t input_split_lens = input_size / split_num_ / sizeof(float_t); + size_t output_split_lens = output_size / split_num_ / sizeof(float_t); + auto mpi_instance = device::cpu::MPIAdapter::Instance(); + MS_EXCEPTION_IF_NULL(mpi_instance); + for (int i = 0; i < split_num_; i++) { + mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group, + input_split_lens); + } +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "EmbeddingLookUpCommGradCPUKernel, used time: " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); + time += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "EmbeddingLookUpCommGradCPUKernel, used time: " << time << " us"; +#endif + return true; +} + +void EmbeddingLookUpCommGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookUpCommGradCPUKernel needs 1."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h new file mode 100644 index 0000000000..3e3807f58e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class EmbeddingLookUpCommGradCPUKernel : public CPUKernel { + public: + EmbeddingLookUpCommGradCPUKernel() : split_num_(1) {} + ~EmbeddingLookUpCommGradCPUKernel() override{}; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + int split_num_; +}; + +MS_REG_CPU_KERNEL(EmbeddingLookupCommGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + EmbeddingLookUpCommGradCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc new file mode 100644 index 0000000000..b2feb9204f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "runtime/device/cpu/mpi/mpi_adapter.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_lens_ = 1; + for (auto shape : input_shape_) { + input_lens_ = input_lens_ * shape; + } + indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + indices_lens_ = 1; + for (auto shape : indices_shape_) { + indices_lens_ = indices_lens_ * shape; + } + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + axis_ = 4 - input_shape_.size(); + if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) { + reduce_scatter_flag_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrReduceScatterFlag); + } +#ifdef ENABLE_MPI + if (reduce_scatter_flag_) { + size_t gatherv2_out_lens = 1; + for (int i = 0; i < SizeToInt(input_shape_.size()); i++) { + if (i == 0) { + for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) { + gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j]; + } + } else { + gatherv2_out_lens = gatherv2_out_lens * input_shape_[i]; + } + } + gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float); + gather_v2_out_ = malloc(gatherv2_out_lens_); + if (gather_v2_out_ == nullptr) { + MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_; + } + auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_); + if (ret != 0) { + MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed"; + } + split_num_ = AnfAlgo::GetNodeAttr(kernel_node, "split_num"); + } +#else + if (reduce_scatter_flag_) { + MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true"; + } +#endif + if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) { + offset_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrOffset); + } + CPUKernelUtils::ExpandDimsTo4(&input_shape_); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +bool EmbeddingLookUpCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast(gather_v2_out_) : output_addr; + size_t dim0 = input_shape_[0]; + size_t dim1 = input_shape_[1]; + size_t dim2 = input_shape_[2]; + if (axis_ == 3) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + for (size_t k = 0; k < dim2; ++k) { + LookUpTable(inputs, i, j, k, &gather_out_addr); + } + } + } + } else if (axis_ == 2) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + LookUpTable(inputs, i, j, 0, &gather_out_addr); + } + } + } else if (axis_ == 1) { + for (size_t i = 0; i < dim0; ++i) { + LookUpTable(inputs, i, 0, 0, &gather_out_addr); + } + } else if (axis_ == 0) { + LookUpTable(inputs, 0, 0, 0, &gather_out_addr); + } +#ifdef ENABLE_MPI + if (reduce_scatter_flag_) { + size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float); + size_t reduce_scatter_out_lens = one_split_lens / 8; + const std::vector &group = {0, 1, 2, 3, 4, 5, 6, 7}; + auto mpi_instance = device::cpu::MPIAdapter::Instance(); + MS_EXCEPTION_IF_NULL(mpi_instance); + for (int i = 0; i < split_num_; i++) { + mpi_instance->ReduceScatter(reinterpret_cast(gather_v2_out_) + i * one_split_lens, + output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum"); + } + } +#endif + return true; +} + +void LookUpTable_task(const float *input_addr, float *output_addr, const int *indices_addr, size_t indices_lens, + size_t num, size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, + std::vector input_shape, size_t input_lens) { + size_t lens = num * sizeof(float); + for (size_t i = 0; i < indices_lens; ++i) { + int indices = indices_addr[i] - offset; + if (indices >= 0) { + size_t index = IntToSize(indices); + if (index < input_shape[axis]) { + size_t pos = 0; + if (axis == 3) { + pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, dim2, index); + } else if (axis == 2) { + pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, index, 0); + } else if (axis == 1) { + pos = CPUKernelUtils::CalcOffset(input_shape, dim0, index, 0, 0); + } else if (axis == 0) { + pos = CPUKernelUtils::CalcOffset(input_shape, index, 0, 0, 0); + } + if (pos + num <= input_lens) { + auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; + } + } else { + auto ret = memset_s(output_addr, lens, 0, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; + } + } + } else { + auto ret = memset_s(output_addr, lens, 0, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; + } + } + } else { + auto ret = memset_s(output_addr, lens, 0, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; + } + } + output_addr += num; + } +} + +void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, + size_t dim2, float **output_addr) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto indices_addr = reinterpret_cast(inputs[1]->addr); + size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); + float *task_out_addr = *output_addr; + const size_t thread_num = 8; + std::thread threads[8]; + size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; + size_t i; + size_t task_offset = 0; + MS_LOG(DEBUG) << "indices_lens_: " << indices_lens_ << " one task proc lens:" << task_proc_lens; + for (i = 0; i < thread_num; i++) { + if (task_offset >= indices_lens_) { + break; + } + MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; + threads[i] = + std::thread(LookUpTable_task, input_addr, task_out_addr + task_offset * num, indices_addr + task_offset, + task_proc_lens, num, dim0, dim1, dim2, offset_, axis_, input_shape_, input_lens_); + task_offset += task_proc_lens; + if (task_offset + task_proc_lens > indices_lens_) { + task_proc_lens = indices_lens_ - task_offset; + } + } + for (size_t j = 0; j < i; j++) { + threads[j].join(); + } + *output_addr += num * indices_lens_; +} + +void EmbeddingLookUpCPUKernel::CheckParam(const CNodePtr &kernel_node) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() + << ", but EmbeddingLookUpCPUKernel olny support 4d or lower."; + } + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookUpCPUKernel needs 2."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h new file mode 100644 index 0000000000..6c61ee346c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class EmbeddingLookUpCPUKernel : public CPUKernel { + public: + EmbeddingLookUpCPUKernel() { + axis_ = 0; + offset_ = 0; + split_num_ = 0; + input_lens_ = 0; + indices_lens_ = 0; + gatherv2_out_lens_ = 0; + reduce_scatter_flag_ = false; + gather_v2_out_ = nullptr; + } + ~EmbeddingLookUpCPUKernel() override { + if (gather_v2_out_ != nullptr) { + free(gather_v2_out_); + gather_v2_out_ = nullptr; + } + } + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, + float **output_addr); + void CheckParam(const CNodePtr &kernel_node); + std::vector input_shape_; + std::vector indices_shape_; + std::vector output_shape_; + int axis_; + int offset_; + int split_num_; + size_t input_lens_; + size_t indices_lens_; + size_t gatherv2_out_lens_; + bool reduce_scatter_flag_; + + void *gather_v2_out_; +}; + +MS_REG_CPU_KERNEL( + EmbeddingLookup, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + EmbeddingLookUpCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.cc new file mode 100644 index 0000000000..a61cd185c6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/equal_count_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void EqualCountCPUKernel::InitKernel(const CNodePtr & /*kernel_node*/) {} + +bool EqualCountCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + if (inputs[0]->size != inputs[1]->size) { + MS_LOG(EXCEPTION) << "input or output size!"; + } + int count = 0; + auto left = reinterpret_cast(inputs[0]->addr); + auto right = reinterpret_cast(inputs[1]->addr); + size_t elem_num = inputs[0]->size / sizeof(int); + for (size_t i = 0; i < elem_num; i++) { + if (left[i] == right[i]) { + count++; + } + } + auto output = reinterpret_cast(outputs[0]->addr); + output[0] = count; + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h new file mode 100644 index 0000000000..6e4ed6d5f1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class EqualCountCPUKernel : public CPUKernel { + public: + EqualCountCPUKernel() = default; + ~EqualCountCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + EqualCount, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + EqualCountCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc new file mode 100644 index 0000000000..73b11f1c01 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/gather_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(input_shape_.size()); + } + axis_ += 4 - input_shape_.size(); + CPUKernelUtils::ExpandDimsTo4(&input_shape_); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +bool GatherV2CPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto buff_size = outputs[0]->size; + size_t dim0 = input_shape_[0]; + size_t dim1 = input_shape_[1]; + size_t dim2 = input_shape_[2]; + if (axis_ == 3) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + for (size_t k = 0; k < dim2; ++k) { + CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size); + } + } + } + } else if (axis_ == 2) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size); + } + } + } else if (axis_ == 1) { + for (size_t i = 0; i < dim0; ++i) { + CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size); + } + } else if (axis_ == 0) { + CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); + } + return true; +} + +void GatherV2CPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, + size_t dim2, float **output_addr, size_t *buff_size) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto indices_addr = reinterpret_cast(inputs[1]->addr); + size_t elem_num = inputs[1]->size / 4; + size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); + for (size_t i = 0; i < elem_num; ++i) { + if (indices_addr[i] < 0) { + MS_LOG(EXCEPTION) << "The indices value is less than 0."; + } + size_t index = IntToSize(indices_addr[i]); + if (index >= input_shape_[IntToSize(axis_)]) { + auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memset failed."; + } + } else { + size_t pos = 0; + if (axis_ == 3) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); + } else if (axis_ == 2) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); + } else if (axis_ == 1) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); + } else if (axis_ == 0) { + pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); + } + auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy failed."; + } + } + *output_addr += num; + *buff_size -= num * sizeof(float); + } +} // namespace kernel + +void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel olny support 4d or lower."; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h new file mode 100644 index 0000000000..8fdac0dfde --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class GatherV2CPUKernel : public CPUKernel { + public: + GatherV2CPUKernel() : axis_(0) {} + ~GatherV2CPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, + float **output_addr, size_t *buff_size); + void CheckParam(const CNodePtr &kernel_node); + std::vector input_shape_; + std::vector indices_shape_; + std::vector output_shape_; + int axis_; +}; + +MS_REG_CPU_KERNEL( + GatherV2, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherV2CPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc new file mode 100644 index 0000000000..e58b1d319c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h" +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 || weight_shape.size() != 4) { + MS_LOG(EXCEPTION) << "conv2d only support nchw input!"; + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); + dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); + + int kernel_size = SizeToInt(weight_shape[3]); + auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); + auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); + if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) { + MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!"; + } + if (stride_ori[0] != 1 || stride_ori[1] != 1) { + MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!"; + } + if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { + MS_LOG(EXCEPTION) << "conv2d dilation only support 1, and dilation must be 4d!"; + } + if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { + MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!"; + } + int stride = stride_ori[2]; + int dilation = dilation_ori[2]; + + dnnl::memory::dims strides{stride, stride}; + dnnl::memory::dims dilates{dilation - 1, dilation - 1}; + std::vector int_padding_l; + std::vector int_padding_r; + + const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); + GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); + if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { + MS_LOG(EXCEPTION) << "get padding failed"; + } + dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; + dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; + dnnl::convolution_forward::desc desc = + dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, + weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_WEIGHTS, weights_desc); + AddArgument(DNNL_ARG_DST, dst_desc); +} + +bool Conv2dCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h new file mode 100644 index 0000000000..c0c64ba4da --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class Conv2dCPUKernel : public MKLCPUKernel { + public: + Conv2dCPUKernel() = default; + ~Conv2dCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + Conv2D, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Conv2dCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc new file mode 100644 index 0000000000..3fa6a91405 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h" +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector weight_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + std::vector dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 || weight_shape.size() != 4) { + MS_LOG(EXCEPTION) << ("conv2d grad filter only support nchw input!"); + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); + dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); + + int kernel_size = SizeToInt(weight_shape[3]); + auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); + auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); + if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { + MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel only support equal stride, and stride must be 2d!"; + } + if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { + MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1, and dilation must be 4d!"; + } + if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { + MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1 in N axis and C axis!"; + } + int stride = stride_ori[0]; + int dilation = dilation_ori[2]; + + dnnl::memory::dims strides{stride, stride}; + dnnl::memory::dims dilates{dilation - 1, dilation - 1}; + const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); + std::vector int_padding_l; + std::vector int_padding_r; + GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); + if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { + MS_LOG(EXCEPTION) << "get padding failed"; + } + dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; + dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; + dnnl::convolution_forward::desc forward_desc = + dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, + weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); + + dnnl::convolution_backward_weights::desc backward_desc = dnnl::convolution_backward_weights::desc( + dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto backward_prim_desc = dnnl::convolution_backward_weights::primitive_desc( + backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + primitive_ = std::make_shared(backward_prim_desc); + + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DIFF_DST, dst_desc); + AddArgument(DNNL_ARG_DIFF_WEIGHTS, weights_desc); +} + +bool Conv2dGradFilterCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h new file mode 100644 index 0000000000..ae8269c142 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class Conv2dGradFilterCPUKernel : public MKLCPUKernel { + public: + Conv2dGradFilterCPUKernel() = default; + ~Conv2dGradFilterCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + Conv2DBackpropFilter, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Conv2dGradFilterCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc new file mode 100644 index 0000000000..1f02d70f86 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h" +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 || weight_shape.size() != 4) { + MS_LOG(EXCEPTION) << "conv2d grad filter only support nchw input!"; + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); + dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); + + int kernel_size = SizeToInt(weight_shape[3]); + auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); + auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); + if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { + MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel only support equal stride, and stride must be 2d!"; + } + if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { + MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1, and dilation must be 4d!"; + } + if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { + MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1 in N axis and C axis!"; + } + int stride = stride_ori[0]; + int dilation = dilation_ori[2]; + dnnl::memory::dims strides{stride, stride}; + dnnl::memory::dims dilates{dilation - 1, dilation - 1}; + std::vector int_padding_l; + std::vector int_padding_r; + const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); + GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); + if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { + MS_LOG(EXCEPTION) << "conv2d grad get padding failed"; + } + dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; + dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; + dnnl::convolution_forward::desc forward_desc = + dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, + weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); + + dnnl::convolution_backward_data::desc backward_desc = dnnl::convolution_backward_data::desc( + dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto backward_prim_desc = + dnnl::convolution_backward_data::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + primitive_ = std::make_shared(backward_prim_desc); + + AddArgument(DNNL_ARG_DIFF_SRC, src_desc); + AddArgument(DNNL_ARG_DIFF_DST, dst_desc); + AddArgument(DNNL_ARG_WEIGHTS, weights_desc); +} + +bool Conv2dGradInputCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h new file mode 100644 index 0000000000..6f699130a8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class Conv2dGradInputCPUKernel : public MKLCPUKernel { + public: + Conv2dGradInputCPUKernel() = default; + ~Conv2dGradInputCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + Conv2DBackpropInput, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Conv2dGradInputCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc new file mode 100644 index 0000000000..626fd1934e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h" +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { +#ifdef PLATFORM_86 + _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); +#endif + MS_EXCEPTION_IF_NULL(kernel_node); + using tag = dnnl::memory::format_tag; + using dim = dnnl::memory::dims; + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); + input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); + hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); + num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); + has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); + batch_size_ = SizeToInt(src_shape[1]); + seq_len_ = SizeToInt(src_shape[0]); + num_directions_ = 1; + if (bidirectional_) { + num_directions_ = 2; + } + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { + MS_LOG(EXCEPTION) << "error iteration shape!"; + } + if (num_layers_ <= 0) { + MS_LOG(EXCEPTION) << "layers must be greater than zero!"; + } + if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { + MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; + } + const int gate_size = 4 * hidden_size_; + for (int i = 0; i < num_layers_; ++i) { + weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); + weight_h_size_ += gate_size * hidden_size_; + } + weight_size_ = weight_size_ * num_directions_; + weight_h_size_ = weight_h_size_ * num_directions_; + auto eng = MKLKernelEngine::Get().engine(); + dnnl::stream s(eng); + dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; + if (bidirectional_) { + direction = dnnl::rnn_direction::bidirectional_concat; + } + dim src_dims = {seq_len_, batch_size_, input_size_}; + dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; + weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; + bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; + dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; + dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); + dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); + dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); + dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); + dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); + dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); + dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); + auto desc = std::make_shared(dnnl::prop_kind::forward_training, direction, src_desc, + src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), + formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, + dst_h_desc, dst_c_desc); + prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng); + primitive_ = std::make_shared(prim_desc_); + AddArgument(DNNL_ARG_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc()); + AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc()); + AddArgument(DNNL_ARG_BIAS, bias_desc); + AddArgument(DNNL_ARG_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); + AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); +} + +bool LstmCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + auto eng = MKLKernelEngine::Get().engine(); + auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); + auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); + auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng); + user_weights_memory.set_data_handle(inputs[3]->addr); + user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); + Reorder(&user_weights_memory, &weights_memory); + Reorder(&user_weights_h_memory, &weights_h_memory); + auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng); + if (has_bias_) { + bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); + } else { + auto ret = + memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "bias memset error"; + } + } + // set handle + SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h new file mode 100644 index 0000000000..761494a931 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ +#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) +#define PLATFORM_86 +#endif +#ifdef PLATFORM_86 +#include +#endif +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" +namespace mindspore { +namespace kernel { +class LstmCPUKernel : public MKLCPUKernel { + public: + LstmCPUKernel() = default; + ~LstmCPUKernel() override = default; + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + int weight_size_ = 0; + int weight_h_size_ = 0; + int input_size_; + int hidden_size_; + int num_layers_; + int batch_size_; + int seq_len_; + int num_directions_; + bool bidirectional_; + bool has_bias_; + dnnl::memory::dims weights_dims_; + dnnl::memory::dims weights_h_dims_; + dnnl::memory::dims bias_dims_; + dnnl::lstm_forward::primitive_desc prim_desc_; +}; + +MS_REG_CPU_KERNEL(LSTM, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LstmCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc new file mode 100644 index 0000000000..56da8ec808 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -0,0 +1,196 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h" +#include +#include +#include +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + using tag = dnnl::memory::format_tag; + using dim = dnnl::memory::dims; + auto eng = MKLKernelEngine::Get().engine(); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); + input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); + hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); + num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); + has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); + batch_size_ = SizeToInt(src_shape[1]); + seq_len_ = SizeToInt(src_shape[0]); + num_directions_ = 1; + if (bidirectional_) { + num_directions_ = 2; + } + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { + MS_LOG(EXCEPTION) << "error iteration shape!"; + } + if (num_layers_ <= 0) { + MS_LOG(EXCEPTION) << "layers must be greater than zero!"; + } + if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { + MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; + } + const int gate_size = 4 * hidden_size_; + for (int i = 0; i < num_layers_; ++i) { + weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); + weight_h_size_ += gate_size * hidden_size_; + } + weight_size_ = weight_size_ * num_directions_; + weight_h_size_ = weight_h_size_ * num_directions_; + dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; + if (bidirectional_) { + direction = dnnl::rnn_direction::bidirectional_concat; + } + dim src_dims = {seq_len_, batch_size_, input_size_}; + dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; + weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; + bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; + dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; + dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); + dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); + dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); + dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); + dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); + dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); + dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); + auto forward_desc = std::make_shared( + dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, + formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, + dst_c_desc); + auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng); + auto backward_desc = std::make_shared( + dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), + formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, + src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, + dst_h_desc, dst_c_desc); + prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc); + primitive_ = std::make_shared(prim_backward_desc_); + + AddArgument(DNNL_ARG_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc()); + AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc()); + AddArgument(DNNL_ARG_BIAS, bias_desc); + AddArgument(DNNL_ARG_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); + AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); + AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc()); + AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc()); + AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); + AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc); +} + +bool LSTMGradCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace /*workspace*/, + const std::vector &outputs) { + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + auto eng = MKLKernelEngine::Get().engine(); + // construct fw memory + auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); + auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng); + auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng); + auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng); + user_weights_memory.set_data_handle(inputs[3]->addr); + user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); + Reorder(&user_weights_memory, &weights_memory); + Reorder(&user_weights_h_memory, &weights_h_memory); + if (has_bias_) { + bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); + } else { + if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0, + prim_backward_desc_.bias_desc().get_size())) { + MS_LOG(EXCEPTION) << "bias memset error"; + } + } + // construct bw memory + auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); + auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng); + auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng); + auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); + user_diff_weights_memory.set_data_handle(outputs[3]->addr); + user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); + if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0, + user_diff_weights_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "user weights grad memset error"; + } + if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0, + user_diff_weights_h_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "user weights iter grad memset error"; + } + if (has_bias_) { + diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); + } + if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0, + prim_backward_desc_.diff_bias_desc().get_size())) { + MS_LOG(EXCEPTION) << "bias grad memset error"; + } + if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0, + diff_weights_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "weights grad memset error"; + } + if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0, + diff_weights_h_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "weights iter grad memset error"; + } + SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); + SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); + ExecutePrimitive(); + Reorder(&diff_weights_memory, &user_diff_weights_memory); + Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h new file mode 100644 index 0000000000..b95b5ba792 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class LSTMGradCPUKernel : public MKLCPUKernel { + public: + LSTMGradCPUKernel() = default; + ~LSTMGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + int weight_size_ = 0; + int weight_h_size_ = 0; + int input_size_; + int hidden_size_; + int num_layers_; + int batch_size_; + int seq_len_; + int num_directions_; + bool bidirectional_; + bool has_bias_; + dnnl::memory::dims weights_dims_; + dnnl::memory::dims weights_h_dims_; + dnnl::memory::dims bias_dims_; + dnnl::lstm_backward::primitive_desc prim_backward_desc_; +}; + +MS_REG_CPU_KERNEL(LSTMGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LSTMGradCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc new file mode 100644 index 0000000000..4bbaa6459f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h" +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "common/utils.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + + if (src_shape.size() != 2 || weight_shape.size() != 2 || dst_shape.size() != 2) { + MS_LOG(EXCEPTION) << "matmul invalid input size"; + } + bool trans_a = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_A); + bool trans_b = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_B); + if (trans_a) { + trans_a_ = TRANSPOSE_YES; + dim_m_ = static_cast(src_shape[1]); + dim_k_ = static_cast(src_shape[0]); + } else { + dim_m_ = static_cast(src_shape[0]); + dim_k_ = static_cast(src_shape[1]); + } + if (trans_b) { + trans_b_ = TRANSPOSE_YES; + } + dim_n_ = static_cast(dst_shape[1]); +} + +bool MatMulCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "matmul error input output size!"; + } + dnnl_dim_t lda = dim_m_; + if (trans_a_ == TRANSPOSE_NO) { + lda = dim_k_; + } + dnnl_dim_t ldb = dim_k_; + if (trans_b_ == TRANSPOSE_NO) { + ldb = dim_n_; + } + auto input_a = reinterpret_cast(inputs[0]->addr); + auto input_b = reinterpret_cast(inputs[1]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + (void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, 1.f, input_a, lda, input_b, ldb, 0.f, output, dim_n_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h new file mode 100644 index 0000000000..ef52f652d0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class MatMulCPUKernel : public MKLCPUKernel { + public: + MatMulCPUKernel() = default; + ~MatMulCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + char trans_a_{TRANSPOSE_NO}; + char trans_b_{TRANSPOSE_NO}; + dnnl_dim_t dim_m_{0}; + dnnl_dim_t dim_n_{0}; + dnnl_dim_t dim_k_{0}; +}; + +MS_REG_CPU_KERNEL( + MatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MatMulCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc new file mode 100644 index 0000000000..c71abe809d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" +#include +#include +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" + +namespace mindspore { +namespace kernel { +void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, + const std::vector &src_shape, int kernel_size, int stride, + std::vector *padding_l, std::vector *padding_r) { + MS_EXCEPTION_IF_NULL(kernel_node); + if (src_shape.size() < 2) { + MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!"; + } + std::vector weight_height; + weight_height.emplace_back(src_shape[src_shape.size() - 2]); + weight_height.emplace_back(src_shape[src_shape.size() - 1]); + int rad = kernel_size / 2; + int need_pad = kernel_size - 1; + MS_LOG(INFO) << "pad mode " << pad_mode; + if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) { + for (auto wh : weight_height) { + int re = (wh - 1) % stride; + int pad = std::max(rad - (re / 2), 0); + padding_r->emplace_back(pad); + pad = std::max(need_pad - pad - re, 0); + padding_l->emplace_back(pad); + } + } else if (pad_mode == PAD_MODE_LOWER_VALID || pad_mode == PAD_MODE_UPPER_VALID) { + MS_LOG(INFO) << "pad valid"; + padding_l->emplace_back(0); + padding_l->emplace_back(0); + padding_r->emplace_back(0); + padding_r->emplace_back(0); + } else { + std::vector pad = AnfAlgo::GetNodeAttr>(kernel_node, PAD); + if (pad.size() != 4) { + MS_LOG(EXCEPTION) << "wrong pad size in max pooling " << pad.size(); + } + padding_l->emplace_back(pad[0]); + padding_l->emplace_back(pad[1]); + padding_r->emplace_back(pad[2]); + padding_r->emplace_back(pad[3]); + } +} + +dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const { + dnnl::memory::format_tag mem_tag; + auto dim_size = dims.size(); + if (dim_size == 4) { + mem_tag = dnnl::memory::format_tag::abcd; + } else if (dim_size == 3) { + mem_tag = dnnl::memory::format_tag::abc; + } else if (dim_size == 2) { + mem_tag = dnnl::memory::format_tag::ab; + } else if (dim_size == 1) { + mem_tag = dnnl::memory::format_tag::a; + } else { + MS_LOG(EXCEPTION) << "kernel dims invalid " << dim_size; + } + return mem_tag; +} + +dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector &shape) { + dnnl::memory::dims dims; + dims.insert(dims.end(), shape.begin(), shape.end()); + dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims); + dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag); + return mem_desc; +} + +void MKLCPUKernel::AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc) { + arguments_[arg_key] = MKLKernelEngine::Get().CreateMemory(mem_desc, alloc); +} + +void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { + auto arg_iter = arguments_.find(arg_key); + if (arg_iter != arguments_.end()) { + arg_iter->second.set_data_handle(ptr); + } +} + +void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } + +void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { + MKLKernelEngine::Get().Reorder(src_mem, dst_mem); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h new file mode 100644 index 0000000000..fc7128b10e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include "dnnl.hpp" +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class MKLCPUKernel : public CPUKernel { + public: + MKLCPUKernel() = default; + ~MKLCPUKernel() override = default; + + protected: + void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector &src_shape, + int kernel_size, int stride, std::vector *padding_l, std::vector *padding_r); + void AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc = false); + void SetArgumentHandle(int arg_key, void *ptr); + dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; + dnnl::memory::desc GetDefaultMemDesc(const std::vector &shape); + void ExecutePrimitive(); + std::unordered_map arguments_; + std::shared_ptr primitive_{nullptr}; + inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) { + return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout}; + } + void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc new file mode 100644 index 0000000000..777668f960 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "utils/log_adapter.h" +#include "dnnl.hpp" + +namespace mindspore { +namespace kernel { +void MKLKernelEngine::Execute(const std::shared_ptr &primitive, + const std::unordered_map &arguments) { + MS_EXCEPTION_IF_NULL(primitive); + primitive->execute(stream_, arguments); + (void)stream_.wait(); +} + +dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, bool alloc) { + if (alloc) { + return dnnl::memory(mem_desc, engine_); + } else { + return dnnl::memory(mem_desc, engine_, nullptr); + } +} +void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { + dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h similarity index 100% rename from mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h rename to mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc new file mode 100644 index 0000000000..fddd769047 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) { + MS_LOG(EXCEPTION) << "mul only support same dim input or tensor * scalar " << src0_shape.size() << " vs " + << src1_shape.size(); + } + if (src1_shape.size() < src0_shape.size()) { + for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) { + src1_shape.emplace_back(1); + } + } + dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape); + dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape); + dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape); + dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_mul, src0_mem_desc, src1_mem_desc, dst_mem_desc); + auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + AddArgument(DNNL_ARG_SRC_0, src0_mem_desc); + AddArgument(DNNL_ARG_SRC_1, src1_mem_desc); + AddArgument(DNNL_ARG_DST, dst_mem_desc); +} + +bool MulCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "mul error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h new file mode 100644 index 0000000000..182679f59d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class MulCPUKernel : public MKLCPUKernel { + public: + MulCPUKernel() = default; + ~MulCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MulCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc new file mode 100644 index 0000000000..e4bedf23b9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h" +#include +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); + std::vector kernel_sizes = AnfAlgo::GetNodeAttr>(kernel_node, KSIZE); + std::vector strides = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); + if (kernel_sizes.size() != 4 || strides.size() != 4) { + MS_LOG(EXCEPTION) << "invalid kernel size " << kernel_sizes.size() << " or stride size " << strides.size(); + } + dnnl::memory::dims strides_dims{strides[2], strides[3]}; + dnnl::memory::dims kernels_dims{kernel_sizes[2], kernel_sizes[3]}; + const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PADDING); + std::vector int_padding_l; + std::vector int_padding_r; + GetPadding(kernel_node, pad_mode, src_shape, kernel_sizes[3], strides[3], &int_padding_l, &int_padding_r); + if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { + MS_LOG(EXCEPTION) << "pooling get padding failed"; + } + dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; + dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; + dnnl::pooling_forward::desc desc = + dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc, dst_desc, + strides_dims, kernels_dims, padding_l, padding_r); + auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DST, dst_desc); + AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc()); +} + +bool PoolingCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h new file mode 100644 index 0000000000..8187eaffda --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class PoolingCPUKernel : public MKLCPUKernel { + public: + PoolingCPUKernel() = default; + ~PoolingCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + PoolingCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc new file mode 100644 index 0000000000..8189df07ff --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h" +#include +#include +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void PoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + src_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + dst_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector kernel_sizes = AnfAlgo::GetNodeAttr>(kernel_node, KSIZE); + std::vector strides = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); + if (kernel_sizes.size() != 4 || strides.size() != 4 || src_shape_.size() != 4 || dst_shape_.size() != 4) { + MS_LOG(EXCEPTION) << "pooling grad invalid input size"; + } + std::vector padding_r; + const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PADDING); + kernel_size_ = kernel_sizes[3]; + stride_ = strides[3]; + GetPadding(kernel_node, pad_mode, src_shape_, kernel_size_, stride_, &padding_l_, &padding_r); +} + +void PoolingGradCPUKernel::RowPoolingGrad(const float *input, float *output, float diff, + const std::vector> &box, + std::vector> *row_max_pair) { + float max_value = 0; + size_t max_index = box[1].second; + size_t src_width = src_shape_[3]; + size_t index_start; + size_t index; + for (size_t i = box[1].first; i < box[1].second; ++i) { + if ((*row_max_pair)[i].first == 0) { + index_start = box[0].first * src_width; + for (size_t j = box[0].first; j < box[0].second; ++j) { + index = index_start + i; + if (input[index] > (*row_max_pair)[i].second || j == box[0].first) { + (*row_max_pair)[i].second = input[index]; + (*row_max_pair)[i].first = index; + } + index_start += src_width; + } + } + if ((*row_max_pair)[i].second > max_value || max_index == box[1].second) { + max_value = (*row_max_pair)[i].second; + max_index = i; + } + } + + output[(*row_max_pair)[max_index].first] += diff; +} + +void PoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *diff, float *output) { + int src_width = SizeToInt(src_shape_[3]); + int src_height = SizeToInt(src_shape_[2]); + std::vector> row_max_pair(src_shape_[3]); + std::vector> box(2); + int h_start = -padding_l_[0]; + size_t diff_index = 0; + for (size_t h = 0; h < dst_shape_[2]; ++h) { + box[0].first = IntToSize(std::max(h_start, 0)); + box[0].second = IntToSize(std::min(h_start + kernel_size_, src_height)); + for (size_t w = 0; w < src_shape_[3]; ++w) { + row_max_pair[w].first = 0; + row_max_pair[w].second = 0; + } + int w_start = -padding_l_[1]; + for (size_t w = 0; w < dst_shape_[3]; ++w) { + box[1].first = IntToSize(std::max(w_start, 0)); + box[1].second = IntToSize(std::min(w_start + kernel_size_, src_width)); + RowPoolingGrad(input, output, diff[diff_index], box, &row_max_pair); + diff_index += 1; + w_start += stride_; + } + h_start += stride_; + } +} + +bool PoolingGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 3 || outputs.empty()) { + MS_LOG(EXCEPTION) << "pooling grad error input output size!"; + } + + auto input = reinterpret_cast(inputs[0]->addr); + auto diff = reinterpret_cast(inputs[2]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + auto ret = memset_s(output, outputs[0]->size, 0, outputs[0]->size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "pooling grad memset error"; + } + size_t src_wh = src_shape_[2] * src_shape_[3]; + size_t dst_wh = dst_shape_[2] * dst_shape_[3]; + for (size_t n = 0; n < src_shape_[0]; ++n) { + for (size_t c = 0; c < src_shape_[1]; ++c) { + ChannelPoolingGrad(input, diff, output); + input = input + src_wh; + output = output + src_wh; + diff = diff + dst_wh; + } + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h new file mode 100644 index 0000000000..95a7bb3f66 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class PoolingGradCPUKernel : public MKLCPUKernel { + public: + PoolingGradCPUKernel() = default; + ~PoolingGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void RowPoolingGrad(const float *input, float *output, float diff, const std::vector> &box, + std::vector> *row_max_pair); + void ChannelPoolingGrad(const float *input, const float *diff, float *output); + int stride_{0}, kernel_size_{0}; + std::vector padding_l_; + std::vector src_shape_; + std::vector dst_shape_; +}; + +MS_REG_CPU_KERNEL(MaxPoolGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + PoolingGradCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc new file mode 100644 index 0000000000..29ac9a1062 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void ReluCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 && src_shape.size() != 2) { + MS_LOG(EXCEPTION) << "relu kernel dims invalid " << src_shape.size(); + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + + dnnl::eltwise_forward::desc desc = + dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0); + auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DST, src_desc); +} + +bool ReluCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h new file mode 100644 index 0000000000..a2da2480e2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class ReluCPUKernel : public MKLCPUKernel { + public: + ReluCPUKernel() = default; + ~ReluCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReluCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc new file mode 100644 index 0000000000..9139aa7862 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void ReluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 && src_shape.size() != 2) { + MS_LOG(EXCEPTION) << "relu grad kernel dims invalid " << src_shape.size(); + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + + dnnl::eltwise_forward::desc forward_desc = + dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0); + auto forward_prim_desc = dnnl::eltwise_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); + + dnnl::eltwise_backward::desc backward_desc = + dnnl::eltwise_backward::desc(dnnl::algorithm::eltwise_relu, src_desc, src_desc, 0.0, 0.0); + auto backward_prim_desc = + dnnl::eltwise_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + primitive_ = std::make_shared(backward_prim_desc); + + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DIFF_SRC, src_desc); + AddArgument(DNNL_ARG_DIFF_DST, src_desc); +} + +bool ReluGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "relu grad error input output size!"; + } + if (inputs[0]->size != outputs[0]->size) { + MS_LOG(EXCEPTION) << "relu grad error input output data size!"; + } + + SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); + ExecutePrimitive(); + size_t mem_bits = outputs[0]->size; + auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h new file mode 100644 index 0000000000..c895ab2756 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class ReluGradCPUKernel : public MKLCPUKernel { + public: + ReluGradCPUKernel() = default; + ~ReluGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + ReluGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReluGradCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc new file mode 100644 index 0000000000..94271b8a69 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector axis_list = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); + if (axis_list.size() != 1) { + MS_LOG(EXCEPTION) << "cpu softmax only support input axis size 1"; + } + int axis = axis_list[0]; + if (axis == -1 || axis >= SizeToInt(src_shape.size())) { + axis = SizeToInt(src_shape.size()) - 1; + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DST, src_desc); +} + +bool SoftmaxCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "softmax error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h new file mode 100644 index 0000000000..2812dd31af --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class SoftmaxCPUKernel : public MKLCPUKernel { + public: + SoftmaxCPUKernel() = default; + ~SoftmaxCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SoftmaxCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc new file mode 100644 index 0000000000..889e2abdec --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h" +#include +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void SoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + size_t type_size = sizeof(float); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + workspace_size_list_.emplace_back(tensor_size); +} + +void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + dnnl::memory::dims mem_dims; + mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); + if (mem_dims.size() != 2) { + MS_LOG(EXCEPTION) << "SoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); + } + batch_size_ = shape[0]; + class_num_ = shape[1]; + if (batch_size_ == 0 || class_num_ == 0) { + MS_LOG(EXCEPTION) << "invalid batch size or class num input!"; + } + dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); + + dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, mem_desc); + AddArgument(DNNL_ARG_DST, mem_desc); +} + +void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *logits, const float *labels, + float *output1, float *output2) const { + float epsilon = 1e-6; + for (size_t i = 0; i < batch_size_; ++i) { + output1[i] = 0; + float loss = 0.0; + for (size_t j = 0; j < class_num_; ++j) { + float logit = logf(logits[i * class_num_ + j] <= 0.0 ? epsilon : logits[i * class_num_ + j]); + output2[i * class_num_ + j] = logits[i * class_num_ + j] - labels[i * class_num_ + j]; + loss += labels[i * class_num_ + j] * logit; + } + output1[i] = -loss; + } +} + +bool SoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + if (inputs.empty() || workspace.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + size_t batch_float_size = batch_size_ * sizeof(float); + size_t batch_class_float_size = class_num_ * batch_float_size; + if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || + inputs[1]->size != batch_class_float_size) { + MS_LOG(EXCEPTION) << "error input data size!"; + } + if (outputs[1]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { + MS_LOG(EXCEPTION) << "error output data size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); + ExecutePrimitive(); + auto labels = reinterpret_cast(inputs[1]->addr); + auto logits = reinterpret_cast(workspace[0]->addr); + auto output1 = reinterpret_cast(outputs[0]->addr); + auto output2 = reinterpret_cast(outputs[1]->addr); + ForwardPostExecute(logits, labels, output1, output2); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h new file mode 100644 index 0000000000..d05cb49b7b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class SoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { + public: + SoftmaxCrossEntropyWithLogitsCPUKernel() = default; + ~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + void InitInputOutputSize(const CNodePtr &kernel_node) override; + + private: + void ForwardPostExecute(const float *logits, const float *labels, float *output1, float *output2) const; + size_t class_num_{0}; + size_t batch_size_{0}; +}; +MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SoftmaxCrossEntropyWithLogitsCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc new file mode 100644 index 0000000000..b8bf7b318a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h" +#include +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + size_t type_size = sizeof(float); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + workspace_size_list_.emplace_back(tensor_size); +} + +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + dnnl::memory::dims mem_dims; + mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); + if (mem_dims.size() != 2) { + MS_LOG(EXCEPTION) << "SparseSoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); + } + batch_size_ = shape[0]; + class_num_ = shape[1]; + if (batch_size_ == 0 || class_num_ == 0) { + MS_LOG(EXCEPTION) << "invalid batch size or class num input!"; + } + is_grad_ = AnfAlgo::GetNodeAttr(kernel_node, IS_GRAD); + dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); + + dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, mem_desc); + AddArgument(DNNL_ARG_DST, mem_desc); +} + +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, + float *output) const { + float total_loss = 0; + for (size_t i = 0; i < batch_size_; ++i) { + if (labels[i] < 0) { + MS_LOG(EXCEPTION) << "label value must >= 0"; + } + size_t label = IntToSize(labels[i]); + if (label > class_num_) { + MS_LOG(EXCEPTION) << "error label input!"; + } + total_loss -= logf(losses[i * class_num_ + label]); + } + output[0] = total_loss / batch_size_; +} + +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, + float *output) const { + size_t row_start = 0; + for (size_t i = 0; i < batch_size_; ++i) { + if (labels[i] < 0) { + MS_LOG(EXCEPTION) << "label value must >= 0"; + } + size_t label = IntToSize(labels[i]); + if (label > class_num_) { + MS_LOG(EXCEPTION) << "error label input!"; + } + for (size_t j = 0; j < class_num_; ++j) { + size_t index = row_start + j; + if (j == label) { + output[index] = (losses[index] - 1) / batch_size_; + } else { + output[index] = losses[index] / batch_size_; + } + } + row_start += class_num_; + } +} + +bool SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + if (inputs.empty() || workspace.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + size_t batch_float_size = batch_size_ * sizeof(float); + size_t batch_class_float_size = class_num_ * batch_float_size; + if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || + inputs[1]->size != batch_float_size) { + MS_LOG(EXCEPTION) << "error input data size!"; + } + if (is_grad_ && outputs[0]->size != batch_class_float_size) { + MS_LOG(EXCEPTION) << "error output data size!"; + } else if (!is_grad_ && outputs[0]->size != sizeof(float)) { + MS_LOG(EXCEPTION) << "error output data size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); + ExecutePrimitive(); + auto labels = reinterpret_cast(inputs[1]->addr); + auto losses = reinterpret_cast(workspace[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + if (is_grad_) { + GradPostExecute(labels, losses, output); + } else { + ForwardPostExecute(labels, losses, output); + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h new file mode 100644 index 0000000000..0d79b0514b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { + public: + SparseSoftmaxCrossEntropyWithLogitsCPUKernel() = default; + ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + void InitInputOutputSize(const CNodePtr &kernel_node) override; + + private: + void ForwardPostExecute(const int *labels, const float *losses, float *output) const; + void GradPostExecute(const int *labels, const float *losses, float *output) const; + bool is_grad_{false}; + size_t class_num_{0}; + size_t batch_size_{0}; +}; + +MS_REG_CPU_KERNEL( + SparseSoftmaxCrossEntropyWithLogits, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + SparseSoftmaxCrossEntropyWithLogitsCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.cc new file mode 100644 index 0000000000..5bbc9f49a2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/one_hot_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void OneHotCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + if (output_shape.size() < 2) { + MS_LOG(EXCEPTION) << "invalid output shape size: " << output_shape.size(); + } + int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis != -1 && IntToSize(axis) >= output_shape.size()) { + MS_LOG(EXCEPTION) << "invalid axis: " << axis; + } + if (axis == -1) { + axis_ = output_shape.size() - 1; + } else { + axis_ = IntToSize(axis); + } + depth_ = output_shape[axis_]; + stride_ = 1; + for (size_t i = axis_ + 1; i < output_shape.size(); ++i) { + stride_ *= output_shape[i]; + } +} + +bool OneHotCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 3 || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output invalid!"; + } + auto indices = reinterpret_cast(inputs[0]->addr); + auto on_value = reinterpret_cast(inputs[1]->addr)[0]; + auto off_value = reinterpret_cast(inputs[2]->addr)[0]; + auto output = reinterpret_cast(outputs[0]->addr); + size_t elem_num = inputs[0]->size / sizeof(int); + + for (size_t i = 0; i < elem_num; i++) { + size_t stride_num = i / stride_; + size_t output_index = stride_num * depth_ * stride_ + i % stride_; + size_t index = IntToSize(indices[i]); + for (size_t j = 0; j < depth_; j++) { + if (index == j) { + output[output_index] = on_value; + } else { + output[output_index] = off_value; + } + output_index += stride_; + } + } + + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h new file mode 100644 index 0000000000..393b0e8c41 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class OneHotCPUKernel : public CPUKernel { + public: + OneHotCPUKernel() = default; + ~OneHotCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t depth_; + size_t stride_; + size_t axis_; +}; + +MS_REG_CPU_KERNEL(OneHot, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + OneHotCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.cc new file mode 100644 index 0000000000..6537c88840 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +bool ApplyMomentumPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + return Launch(inputs, workspace, outputs); +} + +const std::vector &ApplyMomentumPSKernel::input_sizes() const { return GetInputSizeList(); } + +const std::vector &ApplyMomentumPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &ApplyMomentumPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h new file mode 100644 index 0000000000..a78f40d04b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +class ApplyMomentumPSKernel : public ApplyMomentumCPUKernel, public PServerKernel { + public: + ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~ApplyMomentumPSKernel() override = default; + + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc new file mode 100644 index 0000000000..59ab65014b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h" +#include +#include "frontend/parallel/ps/worker.h" + +namespace mindspore { +namespace kernel { +namespace ps { +void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { + EmbeddingLookUpCPUKernel::InitKernel(kernel_node); + + for (auto dim : input_shape_) { + input_dims_ *= dim; + } + + if (mindspore::parallel::ps::Util::IsRoleOfWorker()) { + key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); + } + std::vector keys{key_, key_, key_}; + std::vector values; + values.insert(values.end(), input_shape_.begin(), input_shape_.end()); + values.insert(values.end(), indices_shape_.begin(), indices_shape_.end()); + values.insert(values.end(), output_shape_.begin(), output_shape_.end()); + std::vector lens{SizeToInt(input_shape_.size()), SizeToInt(indices_shape_.size()), + SizeToInt(output_shape_.size())}; + const char *env_role = getenv(mindspore::parallel::ps::kEnvRole); + if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) { + parallel::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape_[axis_]); + parallel::ps::Worker::GetInstance().InitPSEmbeddingTable(keys, values, lens); + } +} + +bool EmbeddingLookUpProxyKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto indices_addr = reinterpret_cast(inputs[1]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + size_t input_size = inputs[1]->size; + size_t output_size = outputs[0]->size; + + size_t size = input_size / sizeof(float); + ::ps::SArray lookup_ids(size, 0); + ::ps::SArray lengths{size}; + ::ps::SArray lookup_result; + + auto ret = memcpy_s(lookup_ids.data(), input_size, indices_addr, input_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + } + parallel::ps::Worker::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, lookup_result, + parallel::ps::kEmbeddingLookupCmd); + + auto ret2 = memcpy_s(output_addr, output_size, lookup_result.data(), output_size); + if (ret2 != EOK) { + MS_LOG(EXCEPTION) << "Lookup result memcpy failed."; + } + return true; +} +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h new file mode 100644 index 0000000000..45e0a23fcb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ + +#include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h" +#include +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +namespace ps { +class EmbeddingLookUpProxyKernel : public EmbeddingLookUpCPUKernel { + public: + EmbeddingLookUpProxyKernel() = default; + ~EmbeddingLookUpProxyKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t key_{0}; + size_t input_dims_{1}; +}; + +MS_REG_CPU_KERNEL( + EmbeddingLookupProxy, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + EmbeddingLookUpProxyKernel); +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc new file mode 100644 index 0000000000..bcb3ca8ae8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" +#include +#include +#include +#include "backend/kernel_compiler/common_utils.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::parallel::ps::Util; +void EmbeddingLookUpPSKernel::InitKernel( + const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + input_shape_ = *(shape_vec[0]); + input_lens_ = 1; + for (auto shape : input_shape_) { + input_lens_ = input_lens_ * shape; + } + indices_shape_ = *(shape_vec[1]); + indices_lens_ = 1; + for (auto shape : indices_shape_) { + indices_lens_ = indices_lens_ * shape; + } + output_shape_ = *(shape_vec[2]); + axis_ = 2; + reduce_scatter_flag_ = false; + + size_t offset = 0; + for (size_t i = 0; i < rank_id_; i++) { + offset += Util::LocalShard(input_shape_[axis_], i, pserver_num_); + } + offset_ = offset; + split_num_ = pserver_num_; + + // input shape should be sharded after computing offset_; + Shard(input_shape_, axis_); + + size_t output_size = + std::accumulate(output_shape_.begin(), output_shape_.end(), sizeof(float), std::multiplies()); + output_size_list_.emplace_back(output_size); + CPUKernelUtils::ExpandDimsTo4(&input_shape_); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + const auto &indices_shape_ = *(shape_vec[0]); + indices_lens_ = indices_shape_[0]; + + size_t output_size = sizeof(float) * indices_lens_; + for (size_t i = axis_ + 1; i < input_shape_.size(); i++) { + output_size *= input_shape_[i]; + } + output_size_list_.clear(); + output_size_list_.emplace_back(output_size); +} + +bool EmbeddingLookUpPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + return Launch(inputs, workspace, outputs); +} + +const std::vector &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; } + +const std::vector &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &EmbeddingLookUpPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h new file mode 100644 index 0000000000..e23a90a11c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerKernel { + public: + EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~EmbeddingLookUpPSKernel() override = default; + + void InitKernel(const std::shared_ptr>>> &) override; + void ReInit(const std::shared_ptr>>> &) override; + + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc new file mode 100644 index 0000000000..3aa421881a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps {} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h new file mode 100644 index 0000000000..a2b6c4fa61 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::parallel::ps::Util; +class PServerKernel { + public: + PServerKernel(size_t rank_id, size_t pserver_num) : rank_id_(rank_id), pserver_num_(pserver_num) {} + ~PServerKernel() = default; + PServerKernel(const PServerKernel &) = delete; + PServerKernel &operator=(const PServerKernel &) = delete; + + virtual void InitKernel(const std::shared_ptr>>> &) {} + virtual void ReInit(const std::shared_ptr>>> &) {} + virtual bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) = 0; + + virtual const std::vector &input_sizes() const = 0; + virtual const std::vector &output_sizes() const = 0; + virtual const std::vector &workspace_sizes() const = 0; + + protected: + virtual void ReInit(const std::vector &) {} + void Shard(std::vector *shape, int axis) { + (*shape)[axis] = Util::LocalShard((*shape)[axis], rank_id_, pserver_num_); + } + + size_t rank_id_; + size_t pserver_num_; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.cc new file mode 100644 index 0000000000..92c901d4c8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/ps/pull_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_CPU_KERNEL_T( + Pull, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + PullKernel, float); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h new file mode 100644 index 0000000000..84dd9b819e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_ + +#include +#include +#include "frontend/parallel/ps/worker.h" +#include "frontend/parallel/ps/util.h" +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class PullKernel : public CPUKernel { + public: + PullKernel() : keys_size_(sizeof(size_t)), var_size_(sizeof(size_t)) {} + ~PullKernel() override = default; + + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &) { + // If the paramter is embedding table, don't Pull from PServer. + if (param_name_.find("embedding") == std::string::npos && param_name_.find("wide_w") == std::string::npos) { + parallel::ps::Worker::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size); + } + return true; + } + void Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but pull needs 2 inputs."; + return; + } + + auto key_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < key_shape.size(); i++) { + keys_size_ *= key_shape[i]; + } + auto var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < var_shape.size(); i++) { + var_size_ *= var_shape[i]; + } + auto param_node = AnfAlgo::GetInputNode(kernel_node, 1); + MS_EXCEPTION_IF_NULL(param_node); + param_name_ = param_node->fullname_with_scope(); + + if (mindspore::parallel::ps::Util::IsRoleOfWorker()) { + key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); + } + InitSizeLists(); + return; + } + void InitKernel(const CNodePtr &kernel_node) { return; } + + protected: + void InitSizeLists() { + input_size_list_.push_back(keys_size_); + input_size_list_.push_back(var_size_); + output_size_list_.push_back(0); + } + + private: + size_t key_; + size_t keys_size_; + size_t var_size_; + std::string param_name_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc new file mode 100644 index 0000000000..96c1f15bda --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/ps/push_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_CPU_KERNEL_T(Push, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt64), + PushKernel, float); + +MS_REG_CPU_KERNEL_T( + Push, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), + PushKernel, float); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h new file mode 100644 index 0000000000..938792f3bf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ + +#include +#include +#include "frontend/parallel/ps/worker.h" +#include "frontend/parallel/ps/util.h" +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class PushKernel : public CPUKernel { + public: + PushKernel() : key_(UINT64_MAX) {} + ~PushKernel() override = default; + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + std::vector keys; + std::vector addrs; + std::vector sizes; + for (auto input : inputs) { + keys.push_back(key_); + addrs.push_back(reinterpret_cast(input->addr)); + sizes.push_back(SizeToInt(input->size) / sizeof(T)); + } + parallel::ps::Worker::GetInstance().Push(keys, addrs, sizes); + memcpy(outputs[0]->addr, &key_, sizeof(size_t)); + return true; + } + + void Init(const CNodePtr &kernel_node) { + key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); + auto optim_input_shapes = AnfAlgo::GetNodeAttr>>(kernel_node, "optim_input_shapes"); + std::vector only_shape_indices = AnfAlgo::GetNodeAttr>(kernel_node, "only_shape_indices"); + MS_LOG(INFO) << "Key " << key_ << " optimizer input shapes are:" << optim_input_shapes; + MS_LOG(INFO) << "Only init shape indices are " << only_shape_indices; + for (size_t i = 0; i < optim_input_shapes.size(); i++) { + auto shape = optim_input_shapes[i]; + mindspore::parallel::ps::Worker::GetInstance().SetOptimInputShapes(key_, shape); + if (std::count(only_shape_indices.begin(), only_shape_indices.end(), i) == 0) { + size_t size = sizeof(T); + for (size_t j = 0; j < shape.size(); j++) { + size *= shape[j]; + } + input_size_list_.push_back(size); + } + } + + output_size_list_.push_back(sizeof(size_t)); + return; + } + + void InitKernel(const CNodePtr &kernel_node) { return; } + + private: + size_t key_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc new file mode 100644 index 0000000000..c7283954f8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h" +#include +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps { +void SparseApplyAdamPSKernel::InitKernel( + const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + std::vector &var_shape = *(shape_vec[0]); + std::vector &m_shape = *(shape_vec[1]); + std::vector &v_shape = *(shape_vec[2]); + const std::vector &grad_shape = *(shape_vec[9]); + const std::vector &indices_shape = *(shape_vec[10]); + + Shard(&var_shape, 0); + Shard(&m_shape, 0); + Shard(&v_shape, 0); + + if (!IsSameShape(var_shape, m_shape)) { + MS_LOG(EXCEPTION) << "var and m should have the same shape"; + } + if (!IsSameShape(var_shape, v_shape)) { + MS_LOG(EXCEPTION) << "var and v should have the same shape"; + } + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be 1D"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; + } + /* + if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { + use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); + } + */ + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); +} + +void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + const std::vector &indices_shape = *(shape_vec[0]); + indices_size_ = indices_shape[0]; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +void SparseApplyAdamPSKernel::ReInit(const std::vector &inputs) { + const auto &indices_addr = inputs[10]; + indices_size_ = indices_addr->size; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +bool SparseApplyAdamPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + ReInit(inputs); + int *indices = reinterpret_cast(inputs[10]->addr); + for (size_t i = 0; i < inputs[10]->size / sizeof(int); i++) { + indices[i] -= rank_id_ * var_first_dim_size_; + } + return Launch(inputs, workspace, outputs); +} + +const std::vector &SparseApplyAdamPSKernel::input_sizes() const { return GetInputSizeList(); } + +const std::vector &SparseApplyAdamPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &SparseApplyAdamPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h new file mode 100644 index 0000000000..337fcb3bf0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::kernel::SparseApplyAdamCPUKernel; +class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerKernel { + public: + SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~SparseApplyAdamPSKernel() override = default; + + void InitKernel(const std::shared_ptr>>> &) override; + void ReInit(const std::shared_ptr>>> &) override; + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; + + protected: + void ReInit(const std::vector &) override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc new file mode 100644 index 0000000000..0392bd5a69 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace ps { +void SparseApplyFtrlPSKernel::InitKernel( + const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + std::vector var_shape = *(shape_vec[0]); + std::vector accum_shape = *(shape_vec[1]); + std::vector linear_shape = *(shape_vec[2]); + std::vector grad_shape = *(shape_vec[3]); + std::vector indices_shape = *(shape_vec[4]); + + Shard(&var_shape, 0); + Shard(&accum_shape, 0); + Shard(&linear_shape, 0); + + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be a 1D vector"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; + } + lr_ = 0.01; + l1_ = 1e-8; + l2_ = 1e-8; + lr_power_ = -0.5; + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); +} + +void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + std::vector indices_shape = *(shape_vec[0]); + indices_size_ = indices_shape[0]; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +void SparseApplyFtrlPSKernel::ReInit(const std::vector &inputs) { + const auto &indices_addr = inputs[4]; + indices_size_ = indices_addr->size; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +bool SparseApplyFtrlPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + ReInit(inputs); + int *indices = reinterpret_cast(inputs[4]->addr); + for (size_t i = 0; i < inputs[4]->size / sizeof(int); i++) { + indices[i] -= rank_id_ * var_first_dim_size_; + } + return Launch(inputs, workspace, outputs); +} + +const std::vector &SparseApplyFtrlPSKernel::input_sizes() const { return GetInputSizeList(); } + +const std::vector &SparseApplyFtrlPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &SparseApplyFtrlPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h new file mode 100644 index 0000000000..d97f19d349 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::kernel::SparseApplyFtrlCPUKernel; +class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerKernel { + public: + SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~SparseApplyFtrlPSKernel() override = default; + + void InitKernel(const std::shared_ptr>>> &) override; + void ReInit(const std::shared_ptr>>> &) override; + + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; + + protected: + void ReInit(const std::vector &) override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc new file mode 100644 index 0000000000..0dddf1d3c4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "backend/kernel_compiler/cpu/reduce_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +const size_t kReduceTypeMax = 0; +const size_t kReduceTypeMean = 1; +const size_t kReduceTypeSum = 2; +const size_t kMaxDim = 100; +void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == "ReduceMax") { + reduce_type_ = kReduceTypeMax; + } else if (kernel_name == "ReduceMean") { + reduce_type_ = kReduceTypeMean; + } else if (kernel_name == "ReduceSum") { + reduce_type_ = kReduceTypeSum; + } else { + MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; + } + shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); + if (axis_addr->isa()) { + auto attr_axis = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); + if (attr_axis.size() > shape_.size()) { + MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size(); + } else if (attr_axis.empty()) { + axis_.push_back(shape_.size() - 1); + } else { + for (auto axis : attr_axis) { + if (IntToSize(axis) >= (shape_.size())) { + MS_LOG(EXCEPTION) << "axis value is oversize."; + } + axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); + } + } + } else if (axis_addr->isa()) { + int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis >= 0 && IntToSize(axis) >= shape_.size()) { + MS_LOG(EXCEPTION) << "axis value is oversize."; + } + axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; + } + for (size_t i = 0; i < shape_.size(); ++i) { + if (shape_[i] <= 0) { + MS_LOG(EXCEPTION) << "shape value is invalid."; + } + left_dims_ *= shape_[i]; + } + for (size_t i = 0; i < axis_.size(); ++i) { + stride_ *= shape_[axis_[i]]; + } + if (stride_ <= 0) { + MS_LOG(EXCEPTION) << "stride_ must greater than zero."; + } + left_dims_ = left_dims_ / stride_; +} +bool ReduceCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspaces*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + size_t out_float_size = left_dims_ * sizeof(float); + size_t in_float_size = stride_ * out_float_size; + if (inputs[0]->size != in_float_size || outputs[0]->size != out_float_size) { + MS_LOG(EXCEPTION) << "invalid input or output data size!"; + } + auto input = reinterpret_cast(inputs[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + int size = inputs[0]->size / sizeof(float); + std::vector new_input(IntToSize(size), 0.0); + std::vector transpose_axis; + for (size_t i = 0; i < shape_.size(); ++i) { + bool insert = true; + for (size_t j = 0; j < axis_.size(); ++j) { + if (axis_[j] == i) { + insert = false; + break; + } + } + if (insert) { + transpose_axis.push_back(i); + } + } + (void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end()); + Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_input[0]); + if (reduce_type_ == kReduceTypeMax) { + for (size_t i = 0; i < left_dims_; ++i) { + float value = new_input[i * stride_]; + for (size_t k = 0; k < stride_; ++k) { + if (value < new_input[i * stride_ + k]) { + value = new_input[i * stride_ + k]; + } + } + output[i] = value; + } + } else { + for (size_t i = 0; i < left_dims_; ++i) { + float value = 0.0; + for (size_t k = 0; k < stride_; ++k) { + value += new_input[i * stride_ + k]; + } + if (reduce_type_ == kReduceTypeMean) { + output[i] = value / stride_; + } else { + output[i] = value; + } + } + } + return true; +} +void ReduceCPUKernel::Transpose(const int size, const float *input, const std::vector &input_shape, + const std::vector &input_axis, const int shape_size, float *output) { + int pos_array[kMaxDim]; + int size_offset[kMaxDim]; + size_offset[0] = size / SizeToInt(input_shape[0]); + for (int i = 1; i < shape_size; i++) { + size_offset[i] = size_offset[i - 1] / SizeToInt(input_shape[i]); + } + for (int position = 0; position < size; position += 1) { + int temp_position = position; + pos_array[0] = temp_position / size_offset[0]; + for (int i = 1; i < shape_size; i++) { + temp_position -= pos_array[i - 1] * size_offset[i - 1]; + pos_array[i] = temp_position / size_offset[i]; + } + int new_position = pos_array[SizeToInt(input_axis[shape_size - 1])]; + int new_position_size = 1; + for (int j = shape_size - 2; j >= 0; j--) { + new_position_size *= SizeToInt(input_shape[SizeToInt(input_axis[j + 1])]); + new_position += pos_array[SizeToInt(input_axis[j])] * new_position_size; + } + output[new_position] = input[position]; + } + return; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h new file mode 100644 index 0000000000..a9696bad49 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ReduceCPUKernel : public CPUKernel { + public: + ReduceCPUKernel() = default; + ~ReduceCPUKernel() override = default; + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void Transpose(const int size, const float *input, const std::vector &input_shape, + const std::vector &input_axis, const int shape_size, float *output); + size_t reduce_type_; + std::vector axis_; + std::vector shape_; + size_t left_dims_ = 1; + size_t stride_ = 1; +}; +MS_REG_CPU_KERNEL(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); +MS_REG_CPU_KERNEL(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); +MS_REG_CPU_KERNEL(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.cc new file mode 100644 index 0000000000..f44c109ace --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "runtime/device/cpu/mpi/mpi_adapter.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto kRanksGroup = "group"; +} // namespace + +ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {} + +void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) { + auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); + if (op != nullptr) { + op_type_ = GetValue(op); + } + + auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); + if (ranks_group != nullptr) { + ranks_group_ = GetValue>(ranks_group); + } else { + MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup; + } +} + +bool ReduceScatterCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto output_data_num = outputs[0]->size / sizeof(float); + auto mpi_instance = device::cpu::MPIAdapter::Instance(); + MS_EXCEPTION_IF_NULL(mpi_instance); + return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h new file mode 100644 index 0000000000..317d7df443 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ReduceScatterCPUKernel : public CPUKernel { + public: + ReduceScatterCPUKernel(); + ~ReduceScatterCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::string op_type_; + std::vector ranks_group_; +}; + +MS_REG_CPU_KERNEL(_HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceScatterCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc new file mode 100644 index 0000000000..6370fdc78a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/reshape_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); } + +bool ReshapeCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + if (inputs[0]->size != outputs[0]->size) { + return false; + } + + if (inputs[0]->addr == outputs[0]->addr) { + return true; + } + + size_t mem_bits = outputs[0]->size; + auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h new file mode 100644 index 0000000000..04f1db3304 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ReshapeCPUKernel : public CPUKernel { + public: + ReshapeCPUKernel() = default; + ~ReshapeCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); + +MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); + +MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc new file mode 100644 index 0000000000..c6657a845a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc @@ -0,0 +1,179 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/slice_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + begin_[i] = begin_[i] + input_shape_[i]; + } + } + auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); + MS_EXCEPTION_IF_NULL(prim); + auto strides = prim->GetAttr(STRIDES); + if (strides != nullptr) { + strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); + end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); + if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) { + MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; + } + for (size_t i = 0; i < strides_.size(); ++i) { + if (strides_[i] < 0) { + strides_[i] = (strides_[i] + input_shape_[i]) > 0 ? (strides_[i] + input_shape_[i]) : 0; + } + if (end_[i] < 0) { + end_[i] = (end_[i] + input_shape_[i]) > 0 ? (end_[i] + input_shape_[i]) : 0; + } + } + } else { + auto sizes = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); + if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { + MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; + } + for (size_t i = 0; i < sizes.size(); ++i) { + if (sizes[i] < 0) { + sizes[i] = (sizes[i] + input_shape_[i]) > 0 ? (sizes[i] + input_shape_[i]) : 0; + } + strides_.emplace_back(1); + end_.emplace_back(begin_[i] + sizes[i]); + } + } + + ExpandAllMemberDims(); + CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); + CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); +} + +void SliceCPUKernel::ExpandAllMemberDims() { + CPUKernelUtils::ExpandDimsTo4(&output_shape_); + + auto input_len = input_shape_.size(); + if (input_len < 4) { + for (size_t i = 0; i < 4 - input_len; ++i) { + input_shape_.insert(input_shape_.begin(), 1); + begin_.insert(begin_.begin(), 0); + strides_.insert(strides_.begin(), 1); + end_.insert(end_.begin(), 1); + } + } +} + +bool SliceCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; + size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1], + begin_[2] * input_element_num_[2]}; + size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1], + strides_[2] * input_element_num_[2]}; + + auto in_n_offset = in_start_offset[0]; + auto out_n_offset = 0; + for (int i = begin_[0]; i < end_[0]; + i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) { + if (can_copy_memory[0]) { + CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]); + continue; + } + auto in_c_offset = in_start_offset[1]; + auto out_c_offset = 0; + for (int j = begin_[1]; j < end_[1]; + j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) { + if (can_copy_memory[1]) { + CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, + input_element_num_[1]); + continue; + } + auto in_h_offset = in_start_offset[2]; + auto out_h_offset = 0; + for (int k = begin_[2]; k < end_[2]; + k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) { + if (can_copy_memory[2]) { + CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs, + out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]); + continue; + } + for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { + *output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m]; + } + } + } + } + + return true; +} + +bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const { + for (size_t i = dim + 1; i < 4; ++i) { + if (begin_[i] != 0 || end_[i] != SizeToInt(input_shape_[i]) || strides_[i] != 1) { + return false; + } + } + return true; +} + +void SliceCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t in_offset, + const std::vector &outputs, size_t out_offset, + size_t copy_num) const { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto in_buff_size = inputs[0]->size; + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto out_buff_size = outputs[0]->size; + + if ((in_offset + copy_num) * sizeof(float) > in_buff_size) { + MS_LOG(EXCEPTION) << "input memory out of bounds."; + } + if ((out_offset + copy_num) * sizeof(float) > out_buff_size) { + MS_LOG(EXCEPTION) << "output memory out of bounds."; + } + + auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset, + copy_num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret; + } +} + +void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 inputs."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output."; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower."; + } + if (input_shape.size() == 0) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h new file mode 100644 index 0000000000..03b7ecdc17 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SliceCPUKernel : public CPUKernel { + public: + SliceCPUKernel() = default; + ~SliceCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void ExpandAllMemberDims(); + bool CanCopyMemoryOnAxis(size_t dim) const; + void CopyDataToOutput(const std::vector &inputs, size_t in_offset, + const std::vector &outputs, size_t out_offset, size_t copy_num) const; + void CheckParam(const CNodePtr &kernel_node) const; + std::vector begin_; + std::vector end_; + std::vector strides_; + std::vector input_shape_; + std::vector input_element_num_; + std::vector output_shape_; + std::vector output_element_num_; +}; + +MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceCPUKernel); +MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc new file mode 100644 index 0000000000..20904e0504 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + + begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + begin_[i] = begin_[i] + output_shape_[i]; + } + } + + auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); + MS_EXCEPTION_IF_NULL(prim); + auto strides = prim->GetAttr(STRIDES); + if (strides != nullptr) { + strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); + end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); + if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) { + MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; + } + for (size_t i = 0; i < strides_.size(); ++i) { + if (strides_[i] < 0) { + strides_[i] = (strides_[i] + output_shape_[i]) > 0 ? (strides_[i] + output_shape_[i]) : 0; + } + if (end_[i] < 0) { + end_[i] = (end_[i] + output_shape_[i]) > 0 ? (end_[i] + output_shape_[i]) : 0; + } + } + } else { + auto sizes = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); + if (sizes.size() != output_shape_.size() || begin_.size() != output_shape_.size()) { + MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; + } + for (size_t i = 0; i < sizes.size(); ++i) { + if (sizes[i] < 0) { + sizes[i] = (sizes[i] + output_shape_[i]) > 0 ? (sizes[i] + output_shape_[i]) : 0; + } + strides_.emplace_back(1); + end_.emplace_back(begin_[i] + sizes[i]); + } + } + + ExpandAllMemberDims(); + CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); + CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); +} + +void SliceGradCPUKernel::ExpandAllMemberDims() { + CPUKernelUtils::ExpandDimsTo4(&input_shape_); + + auto output_len = output_shape_.size(); + if (output_len < 4) { + for (size_t i = 0; i < 4 - output_len; ++i) { + output_shape_.insert(output_shape_.begin(), 1); + begin_.insert(begin_.begin(), 0); + strides_.insert(strides_.begin(), 1); + end_.insert(end_.begin(), 1); + } + } +} + +bool SliceGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size); + if (ret != EOK) { + MS_LOG(ERROR) << "output buff memset fail. ret:" << ret; + return false; + } + + bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; + size_t out_start_offset[3] = {begin_[0] * output_element_num_[0], begin_[1] * output_element_num_[1], + begin_[2] * output_element_num_[2]}; + size_t out_step_size[3] = {strides_[0] * output_element_num_[0], strides_[1] * output_element_num_[1], + strides_[2] * output_element_num_[2]}; + + auto in_n_offset = 0; + auto out_n_offset = out_start_offset[0]; + for (int i = begin_[0]; i < end_[0]; + i += strides_[0], in_n_offset += input_element_num_[0], out_n_offset += out_step_size[0]) { + if (can_copy_memory[0]) { + CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]); + continue; + } + auto in_c_offset = 0; + auto out_c_offset = out_start_offset[1]; + for (int j = begin_[1]; j < end_[1]; + j += strides_[1], in_c_offset += input_element_num_[1], out_c_offset += out_step_size[1]) { + if (can_copy_memory[1]) { + CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, + input_element_num_[1]); + continue; + } + auto in_h_offset = 0; + auto out_h_offset = out_start_offset[2]; + for (int k = begin_[2]; k < end_[2]; + k += strides_[2], in_h_offset += input_element_num_[2], out_h_offset += out_step_size[2]) { + if (can_copy_memory[2]) { + CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs, + out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]); + continue; + } + for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { + output_addr[out_n_offset + out_c_offset + out_h_offset + m] = *input_addr++; + } + } + } + } + return true; +} + +bool SliceGradCPUKernel::CanCopyMemoryOnAxis(size_t dim) const { + for (size_t i = dim + 1; i < 4; ++i) { + if (begin_[i] != 0 || end_[i] != SizeToInt(output_shape_[i]) || strides_[i] != 1) { + return false; + } + } + return true; +} + +void SliceGradCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t in_offset, + const std::vector &outputs, size_t out_offset, + size_t copy_num) const { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto in_buff_size = inputs[0]->size; + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto out_buff_size = outputs[0]->size; + + if ((in_offset + copy_num) * sizeof(float) > in_buff_size) { + MS_LOG(EXCEPTION) << "input memory out of bounds."; + } + if ((out_offset + copy_num) * sizeof(float) > out_buff_size) { + MS_LOG(EXCEPTION) << "output memory out of bounds."; + } + + auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset, + copy_num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret; + } +} + +void SliceGradCPUKernel::CheckParam(const CNodePtr &kernel_node) const { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output."; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceGradGpuKernel only support 4d or lower."; + } + if (input_shape.size() == 0) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h new file mode 100644 index 0000000000..ec480d7e80 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SliceGradCPUKernel : public CPUKernel { + public: + SliceGradCPUKernel() = default; + ~SliceGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void ExpandAllMemberDims(); + bool CanCopyMemoryOnAxis(size_t dim) const; + void CopyDataToOutput(const std::vector &inputs, size_t in_offset, + const std::vector &outputs, size_t out_offset, size_t copy_num) const; + void CheckParam(const CNodePtr &kernel_node) const; + std::vector begin_; + std::vector end_; + std::vector strides_; + std::vector input_shape_; + std::vector input_element_num_; + std::vector output_shape_; + std::vector output_element_num_; +}; + +MS_REG_CPU_KERNEL( + SliceGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceGradCPUKernel); +MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceGradCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc new file mode 100644 index 0000000000..2ff8e77fcd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc @@ -0,0 +1,177 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseApplyAdamInputSize = 11; + +void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) { + MS_EXCEPTION_IF_NULL(input_params); + auto m = input_params->m_; + auto m_t = input_params->m_t_; + auto v = input_params->v_; + auto beta1 = input_params->beta1_; + auto beta2 = input_params->beta2_; + auto use_nesterov = input_params->use_nesterov_; + auto unique_sparse_grad = input_params->sparse_grad_; + auto var_first_dim_size = input_params->var_first_dim_size_; + auto var_outer_dim_size = input_params->var_outer_dim_size_; + for (size_t i = start; i < end; ++i) { + int index = unique_sparse_grad.indices_[i]; + if (index < 0 || IntToSize(index) >= var_first_dim_size) { + MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; + } + size_t start_index = var_outer_dim_size * index; + size_t end_index = start_index + var_outer_dim_size; + for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { + auto summed_grad = unique_sparse_grad.value_[k]; + m[j] += (1 - beta1) * summed_grad; + v[j] += (1 - beta2) * summed_grad * summed_grad; + if (use_nesterov) { + m_t[j] = m[j] * beta1 + (1 - beta1) * summed_grad; + } + } + } +} + +void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_t end) { + MS_EXCEPTION_IF_NULL(input_params); + auto m = input_params->m_; + auto v = input_params->v_; + auto beta1 = input_params->beta1_; + auto beta2 = input_params->beta2_; + for (size_t i = start; i < end; ++i) { + m[i] *= beta1; + v[i] *= beta2; + } +} + +void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) { + MS_EXCEPTION_IF_NULL(input_params); + auto var = input_params->var_; + auto m = input_params->m_; + auto v = input_params->v_; + auto lr = input_params->lr_; + auto epsilon = input_params->epsilon_; + for (size_t i = start; i < end; ++i) { + var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); + } +} +} // namespace + +void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); +} + +void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + std::vector m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + std::vector v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); + std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 10); + if (!IsSameShape(var_shape, m_shape)) { + MS_LOG(EXCEPTION) << "var and m should have the same shape"; + } + if (!IsSameShape(var_shape, v_shape)) { + MS_LOG(EXCEPTION) << "var and v should have the same shape"; + } + if (var_shape.empty()) { + MS_LOG(EXCEPTION) << "var must be at least 1D"; + } + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be 1D"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; + } + if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { + use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); + } +} + +bool SparseApplyAdamCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector & /*outputs*/) { + if (inputs.size() < kSparseApplyAdamInputSize) { + MS_LOG(EXCEPTION) << "Error input size!"; + } + + auto var = reinterpret_cast(inputs[0]->addr); + auto m = reinterpret_cast(inputs[1]->addr); + auto v = reinterpret_cast(inputs[2]->addr); + auto beta1_power = reinterpret_cast(inputs[3]->addr)[0]; + if (beta1_power == 1) { + MS_LOG(EXCEPTION) << "The beta1_power should not be 1"; + } + auto beta2_power = reinterpret_cast(inputs[4]->addr)[0]; + auto lr = reinterpret_cast(inputs[5]->addr)[0]; + auto beta1 = reinterpret_cast(inputs[6]->addr)[0]; + auto beta2 = reinterpret_cast(inputs[7]->addr)[0]; + auto epsilon = reinterpret_cast(inputs[8]->addr)[0]; + auto grad = reinterpret_cast(inputs[9]->addr); + auto indices = reinterpret_cast(inputs[10]->addr); + auto new_grad = reinterpret_cast(workspace[0]->addr); + auto new_indices = reinterpret_cast(workspace[1]->addr); + auto m_t = reinterpret_cast(workspace[2]->addr); + + SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); + ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, + var_outer_dim_size_); + size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_; + lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); + + MultiThreadComputeParams input_params; + input_params.m_ = m; + input_params.v_ = v; + input_params.beta1_ = beta1; + input_params.beta2_ = beta2; + MultiThreadCompute(ComputeMomentum, &input_params, total_dim_size); + + input_params.m_t_ = m_t; + input_params.use_nesterov_ = use_nesterov_; + input_params.sparse_grad_ = unique_sparse_grad; + input_params.var_first_dim_size_ = var_first_dim_size_; + input_params.var_outer_dim_size_ = var_outer_dim_size_; + MultiThreadCompute(ComputeAdam, &input_params, unique_sparse_grad.indices_size_); + + if (use_nesterov_) { + input_params.m_ = input_params.m_t_; + } + input_params.var_ = var; + input_params.lr_ = lr; + input_params.epsilon_ = epsilon; + MultiThreadCompute(ComputeWeight, &input_params, total_dim_size); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h new file mode 100644 index 0000000000..5d3d4193f7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SparseApplyAdamCPUKernel : public CPUKernel { + public: + SparseApplyAdamCPUKernel() = default; + ~SparseApplyAdamCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + void InitInputOutputSize(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + size_t indices_size_{0}; + size_t var_first_dim_size_{0}; + size_t var_outer_dim_size_{1}; + bool use_nesterov_{false}; +}; + +MS_REG_CPU_KERNEL(SparseApplyAdam, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SparseApplyAdamCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc new file mode 100644 index 0000000000..2662604e19 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseApplyFtrlInputSize = 5; + +void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t end) { + MS_EXCEPTION_IF_NULL(input_params); + auto var = input_params->var_; + auto accum = input_params->accum_; + auto linear = input_params->linear_; + auto lr = input_params->lr_; + auto l1 = input_params->l1_; + auto l2_plus = 2 * input_params->l2_; + auto lr_power = input_params->lr_power_; + auto unique_sparse_grad = input_params->sparse_grad_; + auto var_first_dim_size = input_params->var_first_dim_size_; + auto var_outer_dim_size = input_params->var_outer_dim_size_; + for (size_t i = start; i < end; ++i) { + int index = unique_sparse_grad.indices_[i]; + if (index < 0 || IntToSize(index) >= var_first_dim_size) { + MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; + } + size_t start_index = var_outer_dim_size * index; + size_t end_index = start_index + var_outer_dim_size; + for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { + auto summed_grad = unique_sparse_grad.value_[k]; + auto accum_new = accum[j] + summed_grad * summed_grad; + float y; + if (lr_power == -0.5) { + y = std::sqrt(accum_new); + linear[j] += summed_grad - (y - std::sqrt(accum[j])) / lr * var[j]; + } else { + y = std::pow(accum_new, -lr_power); + linear[j] += summed_grad - (y - std::pow(accum[j], -lr_power)) / lr * var[j]; + } + accum[j] = accum_new; + auto x = Sign(linear[j]) * l1 - linear[j]; + y = y / lr + l2_plus; + var[j] = std::fabs(linear[j]) > l1 ? x / y : 0; + } + } +} +} // namespace + +void SparseApplyFtrlCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); +} + +void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + std::vector accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + std::vector linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); + if (!IsSameShape(var_shape, accum_shape)) { + MS_LOG(EXCEPTION) << "var and accum should have the same shape"; + } + if (!IsSameShape(var_shape, linear_shape)) { + MS_LOG(EXCEPTION) << "var and linear should have the same shape"; + } + if (var_shape.empty()) { + MS_LOG(EXCEPTION) << "var must be at least 1D"; + } + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be a 1D vector"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; + } + lr_ = AnfAlgo::GetNodeAttr(kernel_node, "lr"); + if (lr_ <= 0) { + MS_LOG(EXCEPTION) << "lr should be a positive scalar"; + } + l1_ = AnfAlgo::GetNodeAttr(kernel_node, "l1"); + if (l1_ < 0) { + MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar"; + } + l2_ = AnfAlgo::GetNodeAttr(kernel_node, "l2"); + if (l2_ < 0) { + MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar"; + } + lr_power_ = AnfAlgo::GetNodeAttr(kernel_node, "lr_power"); + if (lr_power_ > 0) { + MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar"; + } +} + +bool SparseApplyFtrlCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector & /*outputs*/) { + if (inputs.size() < kSparseApplyFtrlInputSize) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + + auto var = reinterpret_cast(inputs[0]->addr); + auto accum = reinterpret_cast(inputs[1]->addr); + auto linear = reinterpret_cast(inputs[2]->addr); + auto grad = reinterpret_cast(inputs[3]->addr); + auto indices = reinterpret_cast(inputs[4]->addr); + auto new_grad = reinterpret_cast(workspace[0]->addr); + auto new_indices = reinterpret_cast(workspace[1]->addr); + auto tmp_grad = reinterpret_cast(workspace[2]->addr); + auto tmp_indices = reinterpret_cast(workspace[3]->addr); + SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); + SparseGradient tmp_sparse_grad({tmp_grad, tmp_indices, indices_size_}); + TwoLevelReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &tmp_sparse_grad, &unique_sparse_grad, + var_first_dim_size_, var_outer_dim_size_); + + MultiThreadComputeParams input_params; + input_params.var_ = var; + input_params.accum_ = accum; + input_params.linear_ = linear; + input_params.lr_ = lr_; + input_params.l1_ = l1_; + input_params.l2_ = l2_; + input_params.lr_power_ = lr_power_; + input_params.sparse_grad_ = unique_sparse_grad; + input_params.var_first_dim_size_ = var_first_dim_size_; + input_params.var_outer_dim_size_ = var_outer_dim_size_; + MultiThreadCompute(ComputeFtrl, &input_params, unique_sparse_grad.indices_size_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h new file mode 100644 index 0000000000..af8796d8a5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SparseApplyFtrlCPUKernel : public CPUKernel { + public: + SparseApplyFtrlCPUKernel() = default; + ~SparseApplyFtrlCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + void InitInputOutputSize(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + size_t indices_size_{0}; + size_t var_first_dim_size_{0}; + size_t var_outer_dim_size_{1}; + float lr_{0}; + float l1_{0}; + float l2_{0}; + float lr_power_{0}; +}; + +MS_REG_CPU_KERNEL(SparseApplyFtrl, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SparseApplyFtrlCPUKernel); + +MS_REG_CPU_KERNEL(SparseApplyFtrlNoReturn, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SparseApplyFtrlCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc new file mode 100644 index 0000000000..636d92dcbb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseApplyLazyAdamInputSize = 11; + +void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) { + MS_EXCEPTION_IF_NULL(input_params); + auto var = input_params->var_; + auto m = input_params->m_; + auto v = input_params->v_; + auto lr = input_params->lr_; + auto beta1 = input_params->beta1_; + auto beta2 = input_params->beta2_; + auto epsilon = input_params->epsilon_; + auto use_nesterov = input_params->use_nesterov_; + auto unique_sparse_grad = input_params->sparse_grad_; + auto var_first_dim_size = input_params->var_first_dim_size_; + auto var_outer_dim_size = input_params->var_outer_dim_size_; + for (size_t i = start; i < end; ++i) { + int index = unique_sparse_grad.indices_[i]; + if (index < 0 || IntToSize(index) >= var_first_dim_size) { + MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range"; + } + size_t start_index = var_outer_dim_size * index; + size_t end_index = start_index + var_outer_dim_size; + for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { + auto summed_grad = unique_sparse_grad.value_[k]; + m[j] = beta1 * m[j] + (1 - beta1) * summed_grad; + v[j] = beta2 * v[j] + (1 - beta2) * summed_grad * summed_grad; + if (use_nesterov) { + var[j] -= lr * (m[j] * beta1 + (1 - beta1) * summed_grad) / (std::sqrt(v[j]) + epsilon); + } else { + var[j] -= lr * m[j] / (std::sqrt(v[j]) + epsilon); + } + } + } +} +} // namespace + +void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); +} + +void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + std::vector m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + std::vector v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); + std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 10); + if (!IsSameShape(var_shape, m_shape)) { + MS_LOG(EXCEPTION) << "var and m should have the same shape"; + } + if (!IsSameShape(var_shape, v_shape)) { + MS_LOG(EXCEPTION) << "var and v should have the same shape"; + } + if (var_shape.empty()) { + MS_LOG(EXCEPTION) << "var must be at least 1D"; + } + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be 1D"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; + } + if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { + use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); + } +} + +bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector & /*outputs*/) { + if (inputs.size() < kSparseApplyLazyAdamInputSize) { + MS_LOG(EXCEPTION) << "Error input size!"; + } + + auto var = reinterpret_cast(inputs[0]->addr); + auto m = reinterpret_cast(inputs[1]->addr); + auto v = reinterpret_cast(inputs[2]->addr); + auto beta1_power = reinterpret_cast(inputs[3]->addr)[0]; + if (beta1_power == 1) { + MS_LOG(EXCEPTION) << "The beta1_power should not be 1"; + } + auto beta2_power = reinterpret_cast(inputs[4]->addr)[0]; + auto lr = reinterpret_cast(inputs[5]->addr)[0]; + auto beta1 = reinterpret_cast(inputs[6]->addr)[0]; + auto beta2 = reinterpret_cast(inputs[7]->addr)[0]; + auto epsilon = reinterpret_cast(inputs[8]->addr)[0]; + auto grad = reinterpret_cast(inputs[9]->addr); + auto indices = reinterpret_cast(inputs[10]->addr); + auto new_grad = reinterpret_cast(workspace[0]->addr); + auto new_indices = reinterpret_cast(workspace[1]->addr); + auto tmp_grad = reinterpret_cast(workspace[2]->addr); + auto tmp_indices = reinterpret_cast(workspace[3]->addr); + + SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); + SparseGradient tmp_sparse_grad({tmp_grad, tmp_indices, indices_size_}); + TwoLevelReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &tmp_sparse_grad, &unique_sparse_grad, + var_first_dim_size_, var_outer_dim_size_); + + lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); + MultiThreadComputeParams input_params; + input_params.var_ = var; + input_params.m_ = m; + input_params.v_ = v; + input_params.lr_ = lr; + input_params.beta1_ = beta1; + input_params.beta2_ = beta2; + input_params.epsilon_ = epsilon; + input_params.use_nesterov_ = use_nesterov_; + input_params.sparse_grad_ = unique_sparse_grad; + input_params.var_first_dim_size_ = var_first_dim_size_; + input_params.var_outer_dim_size_ = var_outer_dim_size_; + MultiThreadCompute(ComputeLazyAdam, &input_params, unique_sparse_grad.indices_size_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h new file mode 100644 index 0000000000..ee95db8f33 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SparseApplyLazyAdamCPUKernel : public CPUKernel { + public: + SparseApplyLazyAdamCPUKernel() = default; + ~SparseApplyLazyAdamCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + void InitInputOutputSize(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t indices_size_{0}; + size_t var_first_dim_size_{0}; + size_t var_outer_dim_size_{1}; + bool use_nesterov_{false}; +}; + +MS_REG_CPU_KERNEL(SparseApplyLazyAdam, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SparseApplyLazyAdamCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc new file mode 100644 index 0000000000..efba35ad8c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc @@ -0,0 +1,139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kSparseApplyProximalAdagradInputSize = 7; + +void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start, size_t end) { + MS_EXCEPTION_IF_NULL(input_params); + auto var = input_params->var_; + auto accum = input_params->accum_; + auto lr = input_params->lr_; + auto l1 = input_params->l1_; + auto l2 = input_params->l2_; + auto unique_sparse_grad = input_params->sparse_grad_; + auto var_first_dim_size = input_params->var_first_dim_size_; + auto var_outer_dim_size = input_params->var_outer_dim_size_; + for (size_t i = start; i < end; ++i) { + int index = unique_sparse_grad.indices_[i]; + if (index < 0 || IntToSize(index) >= var_first_dim_size) { + MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; + } + size_t start_index = var_outer_dim_size * index; + size_t end_index = start_index + var_outer_dim_size; + for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { + auto summed_grad = unique_sparse_grad.value_[k]; + accum[j] += summed_grad * summed_grad; + auto learning_rate = lr * (1 / std::sqrt(accum[j])); + auto prox_v = var[j]; + prox_v -= summed_grad * learning_rate; + if (l1 > 0) { + var[j] = Sign(prox_v) * std::fmax(std::fabs(prox_v) - learning_rate * l1, static_cast(0.0)) / + (1 + l2 * learning_rate); + } else { + var[j] = prox_v / (1 + l2 * learning_rate); + } + } + } +} +} // namespace + +void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); +} + +void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + std::vector accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + std::vector lr_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + std::vector l1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + std::vector l2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); + std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 5); + std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 6); + if (!IsSameShape(var_shape, accum_shape)) { + MS_LOG(EXCEPTION) << "var and accum should have the same shape"; + } + if (var_shape.empty()) { + MS_LOG(EXCEPTION) << "var must be at least 1D"; + } + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be a 1D vector"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; + } + if (!lr_shape.empty()) { + MS_LOG(EXCEPTION) << "lr is not a scalar"; + } + if (!l1_shape.empty()) { + MS_LOG(EXCEPTION) << "l1 is not a scalar"; + } + if (!l2_shape.empty()) { + MS_LOG(EXCEPTION) << "l2 is not a scalar"; + } +} + +bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector & /*outputs*/) { + if (inputs.size() < kSparseApplyProximalAdagradInputSize) { + MS_LOG(EXCEPTION) << "Wrong input size!"; + } + + auto var = reinterpret_cast(inputs[0]->addr); + auto accum = reinterpret_cast(inputs[1]->addr); + auto lr = reinterpret_cast(inputs[2]->addr)[0]; + auto l1 = reinterpret_cast(inputs[3]->addr)[0]; + auto l2 = reinterpret_cast(inputs[4]->addr)[0]; + auto grad = reinterpret_cast(inputs[5]->addr); + auto indices = reinterpret_cast(inputs[6]->addr); + auto new_grad = reinterpret_cast(workspace[0]->addr); + auto new_indices = reinterpret_cast(workspace[1]->addr); + SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); + ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, + var_outer_dim_size_); + + MultiThreadComputeParams input_params; + input_params.var_ = var; + input_params.accum_ = accum; + input_params.lr_ = lr; + input_params.l1_ = l1; + input_params.l2_ = l2; + input_params.sparse_grad_ = unique_sparse_grad; + input_params.var_first_dim_size_ = var_first_dim_size_; + input_params.var_outer_dim_size_ = var_outer_dim_size_; + MultiThreadCompute(ComputeProximalAdagrad, &input_params, unique_sparse_grad.indices_size_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h new file mode 100644 index 0000000000..56b180ec0b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SparseApplyProximalAdagradCPUKernel : public CPUKernel { + public: + SparseApplyProximalAdagradCPUKernel() = default; + ~SparseApplyProximalAdagradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + void InitInputOutputSize(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t indices_size_{0}; + size_t var_first_dim_size_{0}; + size_t var_outer_dim_size_{1}; +}; + +MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SparseApplyProximalAdagradCPUKernel); + +MS_REG_CPU_KERNEL(SparseApplyProximalAdagradNoReturn, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SparseApplyProximalAdagradCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc new file mode 100644 index 0000000000..1e759390a2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "backend/kernel_compiler/cpu/sub_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + if (shape.size() == 1) { + if (shape[0] != 1) { + MS_LOG(EXCEPTION) << "input 1 only support scalar"; + } + } else { + MS_LOG(EXCEPTION) << "input 1 only support scalar"; + } +} + +void sub_task(const int *in_addr, int *out_addr, size_t lens, int offset) { + for (size_t i = 0; i < lens; i++) { + out_addr[i] = in_addr[i] - offset; + } +} + +bool SubCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); +#endif + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + offset_ = *reinterpret_cast(inputs[1]->addr); + MS_LOG(INFO) << "offset: " << offset_; + auto lens = inputs[0]->size / sizeof(int); + if (lens < 10000) { + for (size_t i = 0; i < lens; i++) { + output_addr[i] = input_addr[i] - offset_; + } + } else { + const size_t thread_num = 4; + std::thread threads[4]; + size_t process_lens = (lens + thread_num - 1) / thread_num; + size_t process_offset = 0; + for (size_t i = 0; i < thread_num; i++) { + threads[i] = + std::thread(sub_task, input_addr + process_offset, output_addr + process_offset, process_lens, offset_); + if (process_offset + process_lens > lens) { + process_lens = lens - process_offset; + process_offset = lens; + } else { + process_offset += process_lens; + } + } + for (size_t i = 0; i < thread_num; i++) { + threads[i].join(); + } + } +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "SubscaleCPUKernel, used time: " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); + time += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "SubCPUKernel, used time: " << time << " us"; +#endif + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h new file mode 100644 index 0000000000..d1b55ded90 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SubCPUKernel : public CPUKernel { + public: + SubCPUKernel() : offset_(0) {} + ~SubCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + int offset_; +}; + +MS_REG_CPU_KERNEL( + Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SubCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc new file mode 100644 index 0000000000..8ec3698cf6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/transpose_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +namespace mindspore { +namespace kernel { +const size_t kMaxDim = 100; +void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + axis_ = AnfAlgo::GetNodeAttr>(kernel_node, "perm"); + if (shape_.size() != axis_.size()) { + MS_LOG(EXCEPTION) << "The size of input shape and transpose axis shape must be equal."; + } +} +bool TransposeCPUFwdKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input = reinterpret_cast(inputs[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + size_t size = IntToSize(inputs[0]->size / sizeof(float)); + size_t shape_size = IntToSize(shape_.size()); + if (shape_size > kMaxDim) { + MS_LOG(EXCEPTION) << "Input is " << shape_size << "-D, but transpose supports max " << kMaxDim << "-D inputs."; + } + size_t pos_array[kMaxDim]; + size_t size_offset[kMaxDim]; + size_offset[0] = size / shape_[0]; + for (size_t i = 1; i < shape_size; i++) { + size_offset[i] = size_offset[SizeToInt(i) - 1] / shape_[i]; + } + for (size_t position = 0; position < size; position += 1) { + size_t temp_position = position; + pos_array[0] = temp_position / size_offset[0]; + for (size_t i = 1; i < shape_size; i++) { + temp_position -= pos_array[SizeToInt(i) - 1] * size_offset[i - 1]; + pos_array[i] = temp_position / size_offset[i]; + } + size_t new_position = pos_array[axis_[SizeToInt(shape_size) - 1]]; + size_t new_position_size = 1; + for (int j = shape_size - 2; j >= 0; j--) { + new_position_size *= shape_[axis_[j + 1]]; + new_position += pos_array[axis_[j]] * new_position_size; + } + output[new_position] = input[position]; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h new file mode 100644 index 0000000000..15796f9f3c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +namespace mindspore { +namespace kernel { +class TransposeCPUFwdKernel : public CPUKernel { + public: + TransposeCPUFwdKernel() = default; + ~TransposeCPUFwdKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::vector shape_; + std::vector axis_; +}; + +MS_REG_CPU_KERNEL(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + TransposeCPUFwdKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc new file mode 100644 index 0000000000..39f535a2af --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), + ArgmaxGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), + ArgmaxGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h new file mode 100644 index 0000000000..61a53c5b40 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h @@ -0,0 +1,106 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh" +namespace mindspore { +namespace kernel { +#define ARGMAX_MAX_DIMENSION 2 +template +class ArgmaxGpuKernel : public GpuKernel { + public: + ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), batch_size_(0), channel_size_(0), axis_(0) {} + ~ArgmaxGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + int *output = GetDeviceAddress(outputs, 0); + CalArgmax(input, SizeToInt(batch_size_), SizeToInt(channel_size_), axis_, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but argmax needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but argmax needs 1 output."; + return false; + } + auto output_type = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("output_type")); + if (output_type->type_id() != TypeId::kNumberTypeInt32) { + MS_LOG(EXCEPTION) << "Argmax only supports int32 output type."; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > ARGMAX_MAX_DIMENSION) { + MS_LOG(EXCEPTION) << "Input is " << input_shape.size() << "-D, but argmax supports max " << ARGMAX_MAX_DIMENSION + << "-D inputs."; + } + + axis_ = GetAttr(kernel_node, "axis"); + if (axis_ < 0) { + axis_ += SizeToInt(input_shape.size()); + } + if (input_shape.size() == 1) { + batch_size_ = 0; + channel_size_ = input_shape[0]; + input_size_ = sizeof(T) * channel_size_; + output_size_ = sizeof(int); + } else { + batch_size_ = input_shape[0]; + channel_size_ = input_shape[1]; + input_size_ = sizeof(T) * batch_size_ * channel_size_; + output_size_ = (axis_ == 1) ? sizeof(int) * batch_size_ : sizeof(int) * channel_size_; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t batch_size_; + size_t channel_size_; + int axis_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc new file mode 100644 index 0000000000..5ead387ccc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + ArgmaxWithValueGpuKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + ArgmaxWithValueGpuKernel, half, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h new file mode 100644 index 0000000000..d2369023fb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh" +namespace mindspore { +namespace kernel { +template +class ArgmaxWithValueGpuKernel : public GpuKernel { + public: + ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {} + ~ArgmaxWithValueGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 1); + S *index = GetDeviceAddress(outputs, 0); + CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + std::vector shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); + int dims = shape.size(); + int axis = GetAttr(kernel_node, "axis"); + if (axis < 0) { + axis += dims; + } + input_size_ = sizeof(T); + for (auto x : shape) { + input_size_ *= x; + } + output_size_ = sizeof(S); + for (auto x : output_shape) { + output_size_ *= x; + } + bound_ = shape[axis]; + outerSize_ = 1; + for (int i = axis - 1; i >= 0; i--) { + outerSize_ *= shape[i]; + } + + innerSize_ = 1; + for (int i = axis + 1; i < dims; i++) { + innerSize_ *= shape[i]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + output_size_list_.push_back(output_size_ / sizeof(S) * sizeof(T)); + } + + private: + size_t input_size_; + size_t output_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + int bound_; + int outerSize_; + int innerSize_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc new file mode 100644 index 0000000000..5d34a1c9c2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArrayReduceGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ArrayReduceGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArrayReduceGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ArrayReduceGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArrayReduceGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ArrayReduceGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h new file mode 100644 index 0000000000..b96f63670d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h @@ -0,0 +1,237 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +namespace mindspore { +namespace kernel { +const std::map kReduceTypeMap = { + {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, + {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, + {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, +}; +template +class ArrayReduceGpuKernel : public GpuKernel { + public: + ArrayReduceGpuKernel() + : cudnn_handle_(nullptr), + reduce_tensor_op_(CUDNN_REDUCE_TENSOR_ADD), + data_type_(CUDNN_DATA_FLOAT), + nan_prop_(CUDNN_NOT_PROPAGATE_NAN), + reduce_indices_(CUDNN_REDUCE_TENSOR_NO_INDICES), + reduce_tensor_descriptor_(nullptr), + inputA_descriptor_(nullptr), + outputC_descriptor_(nullptr), + keep_dims_(false), + all_match_(false), + is_null_input_(false), + input_size_(0), + output_size_(0), + workspace_size_(0) {} + ~ArrayReduceGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + T *workspace_addr = GetDeviceAddress(workspace, 0); + + const float alpha = 1; + const float beta = 0; + if (all_match_) { + MS_LOG(WARNING) + << "The corresponding dimensions of the input and output tensors all match. No need to call cuDNN kernel."; + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed in ArrayReduceGpuKernel::Launch."); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha, + inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr), + "cudnnReduceTensor failed."); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but reduce op needs 1 output."; + return false; + } + int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size()); + + if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa() || + AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa()) { + auto attr_axis = GetAttr>(kernel_node, "axis"); + if (attr_axis.empty()) { + axis_.push_back(-1); + } else { + for (auto axis : attr_axis) { + axis < 0 ? axis_.push_back(axis + input_dim_length) : axis_.push_back(axis); + } + } + } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa()) { + int axis = GetAttr(kernel_node, "axis"); + axis < 0 ? axis_.push_back(axis + input_dim_length) : axis_.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; + } + keep_dims_ = GetAttr(kernel_node, "keep_dims"); + + auto inputA_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto outputC_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(inputA_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "ArrayReduceGpuKernel input is null"; + InitSizeLists(); + return true; + } + InferInAndOutDesc(inputA_shape, outputC_shape); + InferArrayReduceType(kernel_node); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_), + "cudnnCreateReduceTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputA_descriptor_), + "cudnnCreateTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&outputC_descriptor_), + "cudnnCreateTensorDescriptor failed."); + } + void InitSizeLists() override { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputA_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed."); + input_size_list_.push_back(input_size_); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(outputC_descriptor_, &output_size_), + "cudnnGetTensorSizeInBytes failed."); + output_size_list_.push_back(output_size_); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, inputA_descriptor_, outputC_descriptor_, + &workspace_size_), + "cudnnGetReductionWorkspaceSize failed."); + workspace_size_list_.push_back(workspace_size_); + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_), + "cudnnDestroyReduceTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_), + "cudnnDestroyTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(outputC_descriptor_), + "cudnnDestroyTensorDescriptor failed."); + } + void InferArrayReduceType(const CNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kReduceTypeMap.find(kernel_name); + if (iter == kReduceTypeMap.end()) { + MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; + } else { + reduce_tensor_op_ = iter->second; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, reduce_tensor_op_, CUDNN_DATA_FLOAT, nan_prop_, + reduce_indices_, CUDNN_32BIT_INDICES), + "cudnnSetReduceTensorDescriptor failed"); + return; + } + void InferInAndOutDesc(const std::vector &input_shape, const std::vector &output_shape) { + std::vector inputA; + std::vector outputC_shape = output_shape; + ShapeNdTo4d(input_shape, &inputA); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0], + inputA[1], inputA[2], inputA[3]), + "cudnnSetTensor4dDescriptor failed"); + + if (axis_[0] == -1) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1), + "cudnnSetTensor4dDescriptor failed"); + if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) { + all_match_ = true; + } + return; + } + if (!keep_dims_) { + for (auto i : axis_) { + (void)(outputC_shape.insert(outputC_shape.begin() + i, 1)); + } + } + std::vector outputC; + ShapeNdTo4d(outputC_shape, &outputC); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, + outputC[0], outputC[1], outputC[2], outputC[3]), + "cudnnSetTensor4dDescriptor failed"); + if (inputA == outputC) { + all_match_ = true; + } + return; + } + + cudnnHandle_t cudnn_handle_; + cudnnReduceTensorOp_t reduce_tensor_op_; + cudnnDataType_t data_type_; + cudnnNanPropagation_t nan_prop_; + cudnnReduceTensorIndices_t reduce_indices_; + cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_; + cudnnTensorDescriptor_t inputA_descriptor_; + cudnnTensorDescriptor_t outputC_descriptor_; + + std::vector axis_; + bool keep_dims_; + bool all_match_; + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc new file mode 100644 index 0000000000..f5979dc62d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ConcatV2GpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Concat, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ConcatV2GpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE( + Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ConcatV2GpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h new file mode 100644 index 0000000000..15ccedcaec --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h @@ -0,0 +1,128 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class ConcatV2GpuFwdKernel : public GpuKernel { + public: + ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {} + ~ConcatV2GpuFwdKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (inputs.size() == 2) { + T *input_0 = GetDeviceAddress(inputs, 0); + T *input_1 = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, + reinterpret_cast(stream_ptr)); + } + + if (inputs.size() == 3) { + T *input_0 = GetDeviceAddress(inputs, 0); + T *input_1 = GetDeviceAddress(inputs, 1); + T *input_2 = GetDeviceAddress(inputs, 2); + T *output = GetDeviceAddress(outputs, 0); + ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output, + reinterpret_cast(stream_ptr)); + } + + if (inputs.size() == 4) { + T *input_0 = GetDeviceAddress(inputs, 0); + T *input_1 = GetDeviceAddress(inputs, 1); + T *input_2 = GetDeviceAddress(inputs, 2); + T *input_3 = GetDeviceAddress(inputs, 3); + T *output = GetDeviceAddress(outputs, 0); + ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output, + reinterpret_cast(stream_ptr)); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + + axis_ = GetAttr(kernel_node, "axis"); + if (axis_ < 0) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + axis_ += SizeToInt(input_shape.size()); + } + + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; i++) { + auto input_size = sizeof(T); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); + for (size_t j = 0; j < input_shape.size(); j++) { + input_size *= SizeToInt(input_shape[j]); + if (j >= IntToSize(axis_)) { + w_[i] *= SizeToInt(input_shape[j]); + } + input_size_list_.push_back(input_size); + } + } + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + output_size_ = sizeof(T); + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + } + output_size_list_.push_back(output_size_); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override {} + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num < 2 || input_num > 4) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output."; + return false; + } + return true; + } + int w_[4] = {1, 1, 1, 1}; + int axis_; + size_t output_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc new file mode 100644 index 0000000000..8d3c06e805 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + GatherV2, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + GatherV2, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + GatherGpuFwdKernel, half, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h new file mode 100644 index 0000000000..2211361cee --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h @@ -0,0 +1,130 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GATHER_GPU_KERNEL_H +#define MINDSPORE_GATHER_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh" + +namespace mindspore { +namespace kernel { +template +class GatherGpuFwdKernel : public GpuKernel { + public: + GatherGpuFwdKernel() : axis_(0), handle_(nullptr) {} + ~GatherGpuFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input_addr = GetDeviceAddress(inputs, 0); + S *indices_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + auto input_dim1 = input_shapes_[IntToSize(axis_)]; + Gather(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuFwdKernel needs 2."; + } + input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + axis_ = GetAttr(kernel_node, "axis"); + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(input_shapes_.size()); + } + + Reshape(); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + void InitSizeLists() override { + size_t size = GetSize(input_shapes_); + input_size_list_.push_back(size); + + size = GetSize(indices_shapes_); + input_size_list_.push_back(size); + + size = GetSize(output_shapes_); + output_size_list_.push_back(size); + } + + private: + void Reshape() { + size_t dim_before_axis = 1; + for (size_t i = 0; i < IntToSize(axis_); i++) { + dim_before_axis *= output_shapes_[i]; + } + + size_t dim_of_indices = 1; + for (size_t i = 0; i < indices_shapes_.size(); i++) { + dim_of_indices *= indices_shapes_[i]; + } + + size_t dim_after_indices = 1; + for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) { + dim_after_indices *= output_shapes_[i]; + } + + dims_[0] = dim_before_axis; + dims_[1] = dim_of_indices; + dims_[2] = dim_after_indices; + return; + } + size_t GetSize(const std::vector &shape) const { + if (shape.size() == 0) { + return 0; + } + size_t result = sizeof(T); + for (size_t i = 0; i < shape.size(); i++) { + result *= shape[i]; + } + return result; + } + + std::vector input_shapes_; + std::vector indices_shapes_; + std::vector output_shapes_; + + size_t dims_[3] = {}; + int axis_; + cudnnHandle_t handle_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_GATHER_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc new file mode 100644 index 0000000000..e764a08dc8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(OneHot, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + OneHotGpuFwdKernel, float, int) +MS_REG_GPU_KERNEL_TWO(OneHot, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + OneHotGpuFwdKernel, half, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h new file mode 100644 index 0000000000..6c46a63e69 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h @@ -0,0 +1,105 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class OneHotGpuFwdKernel : public GpuKernel { + public: + OneHotGpuFwdKernel() : input_size_(1), output_size_(1), depth_(0), left_dim_size_(1), right_dim_size_(1) {} + ~OneHotGpuFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + const S *indices = GetDeviceAddress(inputs, 0); + const T *on_value = GetDeviceAddress(inputs, 1); + const T *off_value = GetDeviceAddress(inputs, 2); + T *output = GetDeviceAddress(outputs, 0); + OneHot(indices, depth_, on_value, off_value, left_dim_size_, right_dim_size_, output, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + int axis = GetAttr(kernel_node, "axis"); + auto input = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto output = AnfAlgo::GetOutputInferShape(kernel_node, 0); + int input_size = SizeToInt(input.size()); + const int default_axis = -1; + + // Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims). + for (int i = 0; i < input_size; i++) { + auto dim_size = input[IntToSize(i)]; + if (axis == default_axis || i < axis) { + left_dim_size_ *= dim_size; + } + if (axis != default_axis && i >= axis) { + right_dim_size_ *= dim_size; + } + } + for (auto size : input) { + input_size_ *= size; + } + for (auto size : output) { + output_size_ *= size; + } + if (axis >= input_size) { + MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input.size(); + return false; + } + if (axis == default_axis) { + depth_ = output[output.size() - 1]; + } else { + depth_ = output[IntToSize(axis)]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + // inputs: indices, depth + input_size_list_.push_back((input_size_ + 1) * sizeof(S)); + output_size_list_.push_back(output_size_ * sizeof(T)); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t output_size_; + + size_t depth_; + size_t left_dim_size_; + size_t right_dim_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc new file mode 100644 index 0000000000..3c1323de07 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Select, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SelectGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Select, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + SelectGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Select, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + SelectGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h new file mode 100644 index 0000000000..73e60c44bd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SelectGpuKernel : public GpuKernel { + public: + SelectGpuKernel() : input_size_(0), output_size_(0) {} + ~SelectGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + bool *input_cond = GetDeviceAddress(inputs, 0); + T *input_x = GetDeviceAddress(inputs, 1); + T *input_y = GetDeviceAddress(inputs, 2); + T *output = GetDeviceAddress(outputs, 0); + CalSelect(output_size_ / sizeof(T), input_cond, input_x, input_y, output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(bool); + output_size_ = sizeof(T); + for (size_t x : shape) { + input_size_ = input_size_ * x; + output_size_ = output_size_ * x; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(output_size_); + input_size_list_.push_back(output_size_); + output_size_list_.push_back(output_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SelectGpuKernel needs 3 output."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but SelectGpuKernel needs 1 output."; + return false; + } + return true; + } + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc new file mode 100644 index 0000000000..4c9ff2b7f4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SliceGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SliceGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SliceGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SliceGpuFwdKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h new file mode 100644 index 0000000000..f8ecb9ccf0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h @@ -0,0 +1,162 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SliceGpuFwdKernel : public GpuKernel { + public: + SliceGpuFwdKernel() + : is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {} + ~SliceGpuFwdKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + if (is_strided_slice_) { + CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output, + reinterpret_cast(stream_ptr)); + } else { + Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0], + input_shape_[1], input_shape_[2], input_shape_[3], input, output, + reinterpret_cast(stream_ptr)); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + ShapeNdTo4d(input_shape, &input_shape_); + auto strides = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"); + if (strides) { + strides_ = GetAttr>(kernel_node, "strides"); + for (auto i = strides_.size(); i < 4; i++) { + (void)strides_.insert(strides_.begin(), 1); + } + size_ = GetAttr>(kernel_node, "end"); + is_strided_slice_ = true; + } else { + size_ = GetAttr>(kernel_node, "size"); + } + for (auto i = begin_.size(); i < 4; i++) { + (void)begin_.insert(begin_.begin(), 0); + } + for (size_t i = size_.size(); i < 4; i++) { + (void)size_.insert(size_.begin(), 1); + } + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + begin_[i] = begin_[i] + input_shape_[i]; + } + } + for (size_t i = 0; i < size_.size(); i++) { + if (size_[i] < 0) { + size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; + } + if (begin_[i] == size_[i] && is_strided_slice_) { + MS_LOG(WARNING) << "Output is null."; + is_null_input_ = true; + } + if (size_[i] == 0 && strides_[i] > 0) { + size_[i] = begin_[i] + 1; + } + } + + input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); + auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + output_size_ = sizeof(T); + for (size_t x : out_shape) { + output_size_ = output_size_ * x; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SliceGpuFwdKernel needs 1 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but SliceGpuFwdKernel needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGpuFwdKernel olny support 4d or lower."; + return false; + } + if (input_shape.size() == 0) { + MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported."; + return false; + } + begin_ = GetAttr>(kernel_node, "begin"); + for (size_t i = 0; i < input_shape.size(); i++) { + if ((begin_[i] > 0 && (begin_[i] > SizeToInt(input_shape[i]))) || + (begin_[i] < 0 && (std::abs(begin_[i]) > SizeToInt(input_shape[i])))) { + MS_LOG(INFO) << "Input out of bounds " << input_shape[i] << " in axis " << i << "."; + begin_[i] = 0; + } + } + return true; + } + std::vector begin_; + std::vector size_; + std::vector strides_; + std::vector input_shape_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + bool is_strided_slice_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc new file mode 100644 index 0000000000..2eeb3acf73 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + SliceGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SliceGradGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + SliceGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SliceGradGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + SliceGradGpuKernel, int) +MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SliceGradGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h new file mode 100644 index 0000000000..006cbf0266 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h @@ -0,0 +1,147 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SliceGradGpuKernel : public GpuKernel { + public: + SliceGradGpuKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {} + ~SliceGradGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *dy = GetDeviceAddress(inputs, 0); + T *dx = GetDeviceAddress(outputs, 0); + FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast(stream_ptr)); + if (is_strided_slice_) { + CalStridedSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, strides_, dx, + reinterpret_cast(stream_ptr)); + } else { + CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx, + reinterpret_cast(stream_ptr)); + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == "StridedSliceGrad") { + is_strided_slice_ = true; + input_shape_ = GetAttr>(kernel_node, "shapex"); + for (auto i = input_shape_.size(); i < 4; i++) { + (void)input_shape_.insert(input_shape_.begin(), 1); + } + strides_ = GetAttr>(kernel_node, "strides"); + for (auto i = strides_.size(); i < 4; i++) { + (void)strides_.insert(strides_.begin(), 1); + } + size_ = GetAttr>(kernel_node, "end"); + } else { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + ShapeNdTo4d(input_shape, &input_shape_); + size_ = GetAttr>(kernel_node, "size"); + } + + auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + ShapeNdTo4d(dy_shape, &dy_shape_); + begin_ = GetAttr>(kernel_node, "begin"); + DealParam(); + input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); + + output_size_ = sizeof(T); + for (auto x : dy_shape_) { + output_size_ = output_size_ * IntToSize(x); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(output_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGradGpuKernel only support 4d or lower."; + return false; + } + if (input_shape.size() == 0) { + MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported."; + return false; + } + return true; + } + void DealParam() { + for (auto i = begin_.size(); i < 4; i++) { + (void)begin_.insert(begin_.begin(), 0); + } + for (auto i = size_.size(); i < 4; i++) { + (void)size_.insert(size_.begin(), 1); + } + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + begin_[i] = begin_[i] + input_shape_[i]; + } + } + for (size_t i = 0; i < size_.size(); i++) { + if (size_[i] < 0) { + size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; + } + } + } + std::vector begin_; + std::vector size_; + std::vector strides_; + std::vector input_shape_; + std::vector dy_shape_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + bool is_strided_slice_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc new file mode 100644 index 0000000000..77e7de6fef --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + TransposeGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + TransposeGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h new file mode 100644 index 0000000000..0f9c710e3e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h @@ -0,0 +1,111 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" +namespace mindspore { +namespace kernel { +template +class TransposeGpuFwdKernel : public GpuKernel { + public: + TransposeGpuFwdKernel() : shape_size_(0), input_size_(0), output_size_(0), workspace_size_(0) {} + ~TransposeGpuFwdKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + int *input_shape = GetDeviceAddress(workspace, 0); + int *input_axis = GetDeviceAddress(workspace, 1); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_axis failed"); + int size = SizeToInt(input_size_ / sizeof(T)); + CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + shape_size_ = input_shape.size(); + if (shape_size_ > TRANSPOSE_MAX_DIMENSION) { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION + << "-D inputs."; + } + + input_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + input_size_ *= input_shape[i]; + input_shape_.push_back(input_shape[i]); + } + input_size_ *= sizeof(T); + output_size_ = input_size_; + auto perm = GetAttr>(kernel_node, "perm"); + for (size_t j = 0; j < perm.size(); j++) { + input_axis_.push_back(perm[j]); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + workspace_size_ = shape_size_ * sizeof(int); + workspace_size_list_.push_back(workspace_size_); + workspace_size_list_.push_back(workspace_size_); + return; + } + + private: + std::vector input_shape_; + std::vector input_axis_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t shape_size_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc new file mode 100644 index 0000000000..4be887ec79 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentSumGpuKernel, float, int) + +MS_REG_GPU_KERNEL_TWO( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentSumGpuKernel, float, int64_t) + +MS_REG_GPU_KERNEL_TWO( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentSumGpuKernel, int, int) + +MS_REG_GPU_KERNEL_TWO( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentSumGpuKernel, int, int64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h new file mode 100644 index 0000000000..1f7884c650 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh" + +namespace mindspore { +namespace kernel { +template +class UnsortedSegmentSumGpuKernel : public GpuKernel { + public: + UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {} + ~UnsortedSegmentSumGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + S *indices_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemsetAsync(output_addr, 0, outputs[0]->size, reinterpret_cast(stream_ptr)), + "cudaMemSet Failed"); + UnsortedSegmentSum(input_dim0_, input_dim1_, output_dim0_, output_dim1_, input_addr, indices_addr, output_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + auto axis = ids_shapes.size(); + for (size_t i = 0; i < input_shapes.size(); i++) { + if (i < axis) { + input_dim0_ *= input_shapes[i]; + } else { + input_dim1_ *= input_shapes[i]; + } + } + + output_dim0_ = output_shapes[0]; + for (size_t j = 1; j < output_shapes.size(); j++) { + output_dim1_ *= output_shapes[j]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_dim0_ * input_dim1_ * sizeof(T)); + input_size_list_.push_back(input_dim0_ * sizeof(S)); + output_size_list_.push_back(output_dim0_ * output_dim1_ * sizeof(T)); + } + + private: + size_t input_dim0_; + size_t input_dim1_; + size_t output_dim0_; + size_t output_dim1_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc new file mode 100644 index 0000000000..a89d4e9baf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/control/recv_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h new file mode 100644 index 0000000000..7de32ade4f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class RecvGpuKernel : public GpuKernel { + public: + RecvGpuKernel() {} + ~RecvGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &, const std::vector &, const std::vector &, + void *) override { + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamWaitEvent(wait_stream_, wait_event_, 0), "Waiting cuda event failed."); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + wait_stream_ = reinterpret_cast(GetAttr(kernel_node, "wait_event_stream")); + wait_event_ = reinterpret_cast(GetAttr(kernel_node, "wait_event")); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + return; + } + + private: + cudaStream_t wait_stream_{nullptr}; + cudaEvent_t wait_event_{nullptr}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc new file mode 100644 index 0000000000..946038bb18 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/control/send_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h new file mode 100644 index 0000000000..beea19a435 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SendGpuKernel : public GpuKernel { + public: + SendGpuKernel() {} + ~SendGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &, const std::vector &, const std::vector &, + void *) override { + CHECK_CUDA_RET_WITH_EXCEPT(cudaEventRecord(record_event_, record_stream_), "Recording cuda event failed."); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + record_stream_ = reinterpret_cast(GetAttr(kernel_node, "record_event_stream")); + record_event_ = reinterpret_cast(GetAttr(kernel_node, "record_event")); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + return; + } + + private: + cudaStream_t record_stream_{nullptr}; + cudaEvent_t record_event_{nullptr}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu new file mode 100644 index 0000000000..615b94723d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh" + +template +__device__ __forceinline__ T SqrtFunc(T input) { + return sqrt(input); +} + +template <> +__device__ __forceinline__ half SqrtFunc(half input) { + return hsqrt(input); +} + +template +__global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, + const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, + T *m, T *v) { + const T one = static_cast(1.0); + const T new_learning_rate = learning_rate[0] * SqrtFunc(one - beta2_power[0]) / (one - beta1_power[0]); + + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + m[i] += (gradient[i] - m[i]) * (one - beta1[0]); + v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2[0]); + variable[i] -= new_learning_rate * m[i] / (SqrtFunc(v[i]) + epsilon[0]); + } +} + +template +void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate, + const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream) { + ApplyAdamKernel<<>>( + size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v); +} + +template void ApplyAdam(const size_t size, const float *gradient, const float *beta1_power, + const float *beta2_power, const float *learning_rate, const float *beta1, + const float *beta2, const float *epsilon, float *variable, float *m, float *v, + cudaStream_t cuda_stream); +template void ApplyAdam(const size_t size, const half *gradient, const half *beta1_power, const half *beta2_power, + const half *learning_rate, const half *beta1, const half *beta2, const half *epsilon, + half *variable, half *m, half *v, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh new file mode 100644 index 0000000000..7fc4a3e949 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate, + const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cu new file mode 100644 index 0000000000..3bad9a61e1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cu @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "adam_weight_decay_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1, + const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2, + const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v, + T *param, T *gradient) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) { + float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i]; + float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i]; + float update = next_m / (sqrt(next_v) + epsilon[0]); + if (need_decay && weight_decay != nullptr) { + update += weight_decay[0] * param[i]; + } + param[i] -= lr[0] * update; + m[i] = next_m; + v[i] = next_v; + } +} + +template +void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1, + const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr, + const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) { + AdamWeightDecayKernel<<>>( + element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param, + gradient); +} + +template void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, + const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2, + const float *epsilon, const float *lr, const float *weight_decay, float *m, float *v, + float *param, float *gradient, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu new file mode 100755 index 0000000000..a4f1f6680b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu @@ -0,0 +1,88 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "argmax_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +#include "include/cuda_fp16.h" +template +__global__ void Argmax1D(const T* input, const int channel_size, int* output) { + int max_index = 0; + T max = input[0]; + for (int pos = 1; pos < channel_size; pos++) { + if (max < input[pos]) { + max = input[pos]; + max_index = pos; + } + } + output[0] = max_index; + return; +} +template +__global__ void ArgmaxDefault2D(const T* input, const int batch_size, const int channel_size, int* output) { + int pos; + int max_index; + T max; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) { + max = input[i * channel_size]; + max_index = 0; + for (int j = 1; j < channel_size; j++) { + pos = i * channel_size + j; + if (max < input[pos]) { + max = input[pos]; + max_index = j; + } + } + + output[i] = max_index; + } + return; +} +template +__global__ void ArgmaxAxis2D(const T* input, const int batch_size, const int channel_size, int* output) { + int pos; + int max_index; + T max; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { + max = input[i]; + max_index = 0; + for (int j = 1; j < batch_size; j++) { + pos = j * channel_size + i; + if (max < input[pos]) { + max = input[pos]; + max_index = j; + } + } + output[i] = max_index; + } + return; +} +template +void CalArgmax(const T* input, const int batch_size, const int channel_size, const int axis, int* output, + cudaStream_t cuda_stream) { + if (batch_size == 0) { + Argmax1D<<<1, 1, 0, cuda_stream>>>(input, channel_size, output); + } else if (axis == 1) { + ArgmaxDefault2D<<>>(input, batch_size, channel_size, output); + } else { + ArgmaxAxis2D<<>>(input, batch_size, channel_size, output); + } + return; +} + +template void CalArgmax(const float* input, const int batch_size, const int channel_size, const int axis, + int* output, cudaStream_t cuda_stream); +template void CalArgmax(const half* input, const int batch_size, const int channel_size, const int axis, + int* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu new file mode 100644 index 0000000000..46a8a75af9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "argmaxwithvalue_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +#include "include/cuda_fp16.h" +template +__global__ void ArgmaxWithValue(const T* input, const int bound, int outerSize, int innerSize, S* index, + T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outerSize); pos += blockDim.x * gridDim.x) { + int inputOutterOffset = pos * innerSize * bound; + int outputOutterOffset = pos * innerSize; + for (int j = 0; j < innerSize; j++) { + auto outputInnerOffset = outputOutterOffset + j; + S idx = 0; + T maxData = input[j + inputOutterOffset]; + for (S c = 0; c < bound; c++) { + int offset = j + c * innerSize; + auto inputData = input[inputOutterOffset + offset]; + idx = inputData > maxData ? c : idx; + maxData = inputData > maxData ? inputData : maxData; + } + output[outputInnerOffset] = maxData; + index[outputInnerOffset] = idx; + } + } + return; +} + +template +void CalArgmaxWithValue(const T* input, const int bound_, const int outerSize_, const int innerSize_, + S* index, T* output, cudaStream_t cuda_stream) { + ArgmaxWithValue<<>>(input, bound_, outerSize_, innerSize_, + index, output); + return; +} + +template void CalArgmaxWithValue(const float* input, const int bound_, const int outerSize_, + const int innerSize_, int* index, float* output, + cudaStream_t cuda_stream); +template void CalArgmaxWithValue(const half* input, const int bound_, const int outerSize_, + const int innerSize_, int* index, half* output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu new file mode 100644 index 0000000000..604391ccf3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "assign_add_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +#include "include/cuda_fp16.h" +template +__global__ void AssignAdd(const size_t size, T* ref, const T* value, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + output[pos] = ref[pos] + value[pos]; + ref[pos] = output[pos]; + } + return; +} + +template +void CalAssignAdd(const size_t size, T* ref, const T* value, T* output, cudaStream_t cuda_stream) { + AssignAdd<<>>(size, ref, value, output); + + return; +} + +template void CalAssignAdd(const size_t size, float* ref, const float* value, float* output, + cudaStream_t cuda_stream); +template void CalAssignAdd(const size_t size, half* ref, const half* value, half* output, + cudaStream_t cuda_stream); +template void CalAssignAdd(const size_t size, int* ref, const int* value, int* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cuh diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh new file mode 100644 index 0000000000..3a895405b1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean, + const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn, + size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream); +template +void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, + const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, + T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream); +template +void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, + const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, + T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream); +template +void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N, + size_t C, size_t H, size_t W, cudaStream_t cuda_stream); + +template +void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H, + size_t W, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cu new file mode 100755 index 0000000000..dae9a7d629 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cu @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "batchnorm_fold_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void UpdateRunningStd(int channel_size, const double epsilon, T* running_std) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { + running_std[i] = sqrtf(running_std[i] + epsilon); + } + return; +} + +template +__global__ void UpdateBatchStd(int channel_size, T* batch_std) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { + batch_std[i] = 1 / batch_std[i]; + } + return; +} + +template +__global__ void CalDx(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, const T* batch_std, + int batch_size, int channel_size, int height, int width, T* dx) { + int n = batch_size * channel_size * height * width; + int normal_size = batch_size * height * width; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + int channel_index = i / (height * width) % channel_size; + dx[i] = d_batch_mean[channel_index] / normal_size + + d_batch_std[channel_index] * (x[i] - batch_mean[channel_index]) / batch_std[channel_index] / normal_size; + } + return; +} + +template +void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream) { + UpdateRunningStd<<>>(channel_size, epsilon, running_std); + return; +} + +template void CalUpdateRunningStd(int channel_size, double epsilon, float* running_std, + cudaStream_t cuda_stream); + +template +void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream) { + UpdateBatchStd<<>>(channel_size, batch_std); + return; +} + +template void CalUpdateBatchStd(int channel_size, float* batch_std, cudaStream_t cuda_stream); + +template +void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, + const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx, + cudaStream_t cuda_stream) { + CalDx<<>>( + d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_size, channel_size, height, width, dx); +} + +template void CalBatchNormFoldGrad(const float* d_batch_mean, const float* d_batch_std, const float* x, + const float* batch_mean, const float* batch_std, int batch_size, + int channel_size, int height, int width, float* dx, cudaStream_t cuda_stream); + +template +void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream) { + thrust::device_ptr dev_ptr(array); + thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + size, tofill); +} + +template void ThrustFillWith(float* array, int size, float tofill, cudaStream_t cuda_stream); + diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu new file mode 100644 index 0000000000..262d4c438d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +struct MinimumGradFunc { + __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) { + if (x1 < x2) { + atomicAdd(dx1, dy); + } else { + atomicAdd(dx2, dy); + } + } +}; + +template +struct MaximumGradFunc { + __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) { + if (x1 > x2) { + atomicAdd(dx1, dy); + } else { + atomicAdd(dx2, dy); + } + } +}; + +__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } + +template +__device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &l1, const int &l2, const int &l3, + const int &r0, const int &r1, const int &r2, const int &r3, + const int &d0, const int &d1, const int &d2, const int &d3, + const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) { + int i = pos / (d1 * d2 * d3) % d0; + int j = pos / (d2 * d3) % d1; + int k = pos / d3 % d2; + int l = pos % d3; + + int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3); + int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3); + Func()(x1[l_index], x2[r_index], dy[pos], dx1 + l_index, dx2 + r_index); + } +} + +template +__global__ void BroadcastGradKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1, + const int r2, const int r3, const int d0, const int d1, const int d2, const int d3, + enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, + T *dx2) { + switch (op) { + case BROADCAST_GRAD_TYPE_MINIMUM: + return BroadcastGradOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy, + dx1, dx2); + case BROADCAST_GRAD_TYPE_MAXIMUM: + return BroadcastGradOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy, + dx1, dx2); + } +} + +template +void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, + cudaStream_t stream) { + int size = d0 * d1 * d2 * d3; + BroadcastGradKernel<<>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, + x1, x2, dy, dx1, dx2); +} + +template +__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *x1, const T *x2, const T *dy, T *dx1, + T *dx2) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { + Func()(x1[pos], x2[pos], dy[pos], dx1 + pos, dx2 + pos); + } +} + +template +__global__ void NoBroadcastGradKernel(const int nums, enum BroadcastGradOpType op, const T *x1, const T *x2, + const T *dy, T *dx1, T *dx2) { + switch (op) { + case BROADCAST_GRAD_TYPE_MINIMUM: + return NoBroadcastOperator>(nums, x1, x2, dy, dx1, dx2); + case BROADCAST_GRAD_TYPE_MAXIMUM: + return NoBroadcastOperator>(nums, x1, x2, dy, dx1, dx2); + } +} + +template +void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, + T *dx2, cudaStream_t stream) { + NoBroadcastGradKernel<<>>(nums, op, x1, x2, dy, dx1, dx2); +} + +template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const float *x1, const float *x2, + const float *dy, float *dx1, float *dx2, cudaStream_t stream); +template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const int *x1, const int *x2, + const int *dy, int *dx1, int *dx2, cudaStream_t stream); +template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1, + float *dx2, cudaStream_t stream); +template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastGradOpType op, const int *x1, const int *x2, const int *dy, int *dx1, + int *dx2, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh new file mode 100644 index 0000000000..7742043592 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ + +#include "runtime/device/gpu/cuda_common.h" + +enum BroadcastGradOpType { + BROADCAST_GRAD_TYPE_MAXIMUM = 0, + BROADCAST_GRAD_TYPE_MINIMUM = 1, + BROADCAST_GRAD_TYPE_INVALID = 0xffffffff, +}; + +template +void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, + cudaStream_t stream); + +template +void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, + T *dx2, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu new file mode 100644 index 0000000000..a72daa4234 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +struct GreaterFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; } +}; + +template +struct LessFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; } +}; + +template +struct MinimumFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; } +}; + +template +struct MaximumFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; } +}; + +template +struct PowerFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); } +}; + +template <> +struct PowerFunc { + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { + return __float2half(pow(__half2float(lhs), __half2float(rhs))); + } +}; + +template +struct RealDivFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); } +}; + +template +struct MulFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); } +}; + +template +struct SubFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); } +}; + +template +struct AddFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); } +}; + +template <> +struct PowerFunc { + // invalid branch + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; } +}; + +__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } + +template +__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3, + const int &r0, const int &r1, const int &r2, const int &r3, + const int &d0, const int &d1, const int &d2, const int &d3, + const T *input0, const T *input1, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) { + int i = pos / (d1 * d2 * d3) % d0; + int j = pos / (d2 * d3) % d1; + int k = pos / d3 % d2; + int l = pos % d3; + + int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3); + int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3); + output[pos] = Func()(input0[l_index], input1[r_index]); + } +} + +template +__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1, + const int r2, const int r3, const int d0, const int d1, const int d2, const int d3, + enum BroadcastOpType op, const T *input0, const T *input1, S *output) { + switch (op) { + case BROADCAST_TYPE_GREATER: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_LESS: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_MINIMUM: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_MAXIMUM: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_POWER: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_REALDIV: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_MUL: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_SUB: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_ADD: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + } +} + +template +void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, + const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, + const T *input0, const T *input1, S *output, cudaStream_t stream) { + int size = d0 * d1 * d2 * d3; + BroadcastKernel<<>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, + input0, input1, output); +} + +template +__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { + output[pos] = Func()(input0[pos], input1[pos]); + } +} + +template +__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1, + S *output) { + switch (op) { + case BROADCAST_TYPE_GREATER: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_LESS: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_MINIMUM: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_MAXIMUM: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_POWER: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_REALDIV: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_MUL: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_SUB: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_ADD: + return NoBroadcastOperator>(nums, input0, input1, output); + } +} + +template +void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output, + cudaStream_t stream) { + NoBroadcastKernel<<>>(nums, op, input0, input1, output); +} + +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const float *input0, const float *input1, bool *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const float *input0, const float *input1, float *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const half *input0, const half *input1, bool *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const half *input0, const half *input1, half *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const int *input0, const int *input1, int *output, + cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, + bool *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, + float *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, + bool *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, + half *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, + int *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh new file mode 100644 index 0000000000..dfc4c75c93 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ + +#include "runtime/device/gpu/cuda_common.h" + +enum BroadcastOpType { + BROADCAST_TYPE_GREATER = 0, + BROADCAST_TYPE_LESS = 1, + BROADCAST_TYPE_MAXIMUM = 2, + BROADCAST_TYPE_MINIMUM = 3, + BROADCAST_TYPE_POWER = 4, + BROADCAST_TYPE_REALDIV = 5, + BROADCAST_TYPE_MUL = 6, + BROADCAST_TYPE_SUB = 7, + BROADCAST_TYPE_ADD = 8, + BROADCAST_TYPE_INVALID = 0xffffffff, +}; + +template +void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, + const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, + const T *input0, const T *input1, S *output, cudaStream_t stream); + +template +void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output, + cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu new file mode 100755 index 0000000000..147782591a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh" +template +__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int n = pos / (w1 + w2); + int m = pos % (w1 + w2); + output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m]; + } + return; +} + +template +__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, + const T* input_1, const T* input_2, const T* input_3, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int n = pos / (w1 + w2 + w3); + int m = pos % (w1 + w2 + w3); + output[pos] = m < w1 ? input_1[n * w1 + m] : + m < w1 + w2 ? input_2[n * w2 + m - w1] : + input_3[n * w3 + m - w1 - w2]; + } + return; +} + +template +__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4, + const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int n = pos / (w1 + w2 + w3 + w4); + int m = pos % (w1 + w2 + w3 + w4); + output[pos] = m < w1 ? input_1[n * w1 + m] : + m < w1 + w2 ? input_2[n * w2 + m - w1]: + m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]: + input_4[n * w4 + m - w1 - w2 - w3]; + } + return; +} + +template +void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, + cudaStream_t cuda_stream) { + Concat<<>>(size, w1, w2, input_1, input_2, output); + return; +} + +template +void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const T* input_1, const T* input_2, const T* input_3, T* output, + cudaStream_t cuda_stream) { + Concat<<>>(size, w1, w2, w3, input_1, input_2, input_3, output); + return; +} + +template +void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, + cudaStream_t cuda_stream) { + Concat<<>>(size, w1, w2, w3, w4, input_1, + input_2, input_3, input_4, output); + return; +} + +template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, + float* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, + int* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, + half* output, cudaStream_t cuda_stream); + +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const float* input_1, const float* input_2, const float* input_3, + float* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const int* input_1, const int* input_2, const int* input_3, + int* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const half* input_1, const half* input_2, const half* input_3, + half* output, cudaStream_t cuda_stream); + +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const float* input_1, const float* input_2, const float* input_3, const float* input_4, + float* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const int* input_1, const int* input_2, const int* input_3, const int* input_4, + int* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const half* input_1, const half* input_2, const half* input_3, const half* input_4, + half* output, cudaStream_t cuda_stream); + diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh new file mode 100755 index 0000000000..7bd32c140f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, + cudaStream_t cuda_stream); +template +void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream); +template +void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cu new file mode 100755 index 0000000000..87aaf1351c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cu @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "correction_mul_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void CorrectionMul(const T* weight, const T* gamma, const T* running_std, const int batchsize, const int chw, + T* output) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batchsize * chw; i += blockDim.x * gridDim.x) { + int n = i / chw; + output[i] = weight[i] * gamma[n] / running_std[n]; + } + return; +} + +template +__global__ void Mul(int N, const T* a, const T* b, T* c) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { + c[i] = a[i] * b[i]; + } + return; +} + +template +__global__ void Reduce(int N, int CHW, const T* tmp, const T* running_std, T* d_gamma) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { + d_gamma[i] = thrust::reduce(thrust::seq, tmp + i * CHW, tmp + (i + 1) * CHW, 0.f, thrust::plus()); + d_gamma[i] = d_gamma[i] / running_std[i]; + } + return; +} + +template +void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int N, int C, int H, int W, T* output, + cudaStream_t cuda_stream) { + CorrectionMul<<>>(weight, gamma, running_std, N, C * H * W, + output); +} + +template void CalCorrectionMul(const float* weight, const float* gamma, const float* running_std, int N, int C, + int H, int W, float* output, cudaStream_t cuda_stream); + +template +void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int N, int C, int H, int W, T* d_gamma, + T* tmp, cudaStream_t cuda_stream) { + Mul<<>>(N * C * H * W, d_out, weight, tmp); + Reduce<<>>(N, C * H * W, tmp, running_std, d_gamma); +} + +template void CalCorrectionMulGrad(const float* d_out, const float* weight, const float* running_std, int N, + int C, int H, int W, float* d_gamma, float* tmp, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh new file mode 100644 index 0000000000..cb4ccc2c44 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CrossEntropyWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *loss, + cudaStream_t cuda_stream); + +template +void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, + T *grad, cudaStream_t cuda_stream); + +template +void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, + T *dlogits, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh new file mode 100644 index 0000000000..3ba27eeeea --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float keep_prob, + cudaStream_t cuda_stream); +template +void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float keep_prob, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cu new file mode 100755 index 0000000000..e6f424c661 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cu @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "equalcount_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void EqualCount(const int size, const T* input1, const T* input2, T* output) { + T equal_count = 0; + + for (int i = 0; i < size; i++) { + if (input1[i] == input2[i]) { + equal_count++; + } + } + + output[0] = equal_count; + return; +} +template +void CalEqualCount(const int size, const T* input1, const T* input2, T* output, cudaStream_t cuda_stream) { + EqualCount<<<1, 1, 0, cuda_stream>>>(size, input1, input2, output); + return; +} + +template void CalEqualCount(const int size, const int* input1, const int* input2, int* output, + cudaStream_t cuda_stream); +template void CalEqualCount(const int size, const float* input1, const float* input2, float* output, + cudaStream_t cuda_stream); +template void CalEqualCount(const int size, const half* input1, const half* input2, half* output, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cuh diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh new file mode 100644 index 0000000000..e17615db67 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric, + cudaStream_t cuda_stream); + +void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num, + const float *nudge_min, const float *nudge_max, const float *scale, + cudaStream_t cuda_stream); + +void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num, + const int channel_num, const float *nudge_min, const float *nudge_max, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh new file mode 100644 index 0000000000..5f6675b2d7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ + +#include "runtime/device/gpu/cuda_common.h" + +void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream); + +void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, + const float *nudge_max, const float *scale, cudaStream_t cuda_stream); + +void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu new file mode 100644 index 0000000000..bc400eb704 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/cuda_runtime.h" +#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh" + +template +__global__ void IsNan(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsNan(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void IsInf(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) != 0) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsInf(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) != 0) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void IsFinite(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) == 0 && !isnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsFinite(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) == 0 && !__hisnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void FloatStatus(const size_t size, const T* input, T* out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) != 0 || isnan(input[pos])) { + out[0] = 1; + } + } + return; +} +template <> +__global__ void FloatStatus(const size_t size, const half* input, half* out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) { + out[0] = 1; + } + } + return; +} + +template +void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) { + FloatStatus<<>>(size, input, output); + return; +} +template +void CalIsNan(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsNan<<>>(size, input, output); + return; +} +template +void CalIsInf(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsInf<<>>(size, input, output); + return; +} +template +void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsFinite<<>>(size, input, output); + return; +} + +template void CalFloatStatus(const size_t size, const float* input, float* output, cudaStream_t cuda_stream); +template void CalFloatStatus(const size_t size, const half* input, half* output, cudaStream_t cuda_stream); +template void CalIsInf(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsInf(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); +template void CalIsNan(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsNan(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); +template void CalIsFinite(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsFinite(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh new file mode 100644 index 0000000000..fbe063e72a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ +#include "runtime/device/gpu/cuda_common.h" +template +void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream); +template +void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream); +template +void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream); +template +void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cu new file mode 100644 index 0000000000..be4415d509 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cu @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh" + +template +__device__ __forceinline__ T PowFunc(T x, T y) { + return pow(x, y); +} + +template <> +__device__ __forceinline__ half PowFunc(half x, half y) { + return __float2half(pow(__half2float(x), __half2float(y))); +} + +template +__device__ __forceinline__ bool CompareFunc(T x, T y) { + return abs(x) > y; +} + +template <> +__device__ __forceinline__ bool CompareFunc(half x, half y) { + return abs(__half2float(x)) > __half2float(y); +} + +template +__device__ __forceinline__ T Sgn(T x) { + return static_cast(x != 0 ? (x > 0 ? 1 : -1) : 0); +} + +template <> +__device__ __forceinline__ half Sgn(half x) { + return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0); +} + +template +__global__ void ApplyFtrlKernel(const size_t size, const T *gradient, const T *learning_rate, + const T *l1_regularization, const T *l2_regularization, const T *learning_rate_power, + T *variable, T *accumulation, T *linear) { + const T two = static_cast(2.0); + const T learning_rate_power_val = -learning_rate_power[0]; + + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const T cur_accumulation = accumulation[i] + gradient[i] * gradient[i]; + const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val); + const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val); + const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate[0]; + + linear[i] += gradient[i] - sigma * variable[i]; + variable[i] = CompareFunc(linear[i], l1_regularization[0]) + ? ((l1_regularization[0] * Sgn(linear[i]) - linear[i]) / + (cur_accumulation_power / learning_rate[0] + two * l2_regularization[0])) + : static_cast(0); + accumulation[i] = cur_accumulation; + } +} + +template +void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, + const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, + cudaStream_t cuda_stream) { + ApplyFtrlKernel<<>>(size, gradient, learning_rate, l1_regularization, + l2_regularization, learning_rate_power, variable, + accumulation, linear); +} + +template void ApplyFtrl(const size_t size, const float *gradient, const float *learning_rate, + const float *l1_regularization, const float *l2_regularization, + const float *learning_rate_power, float *variable, float *accumulation, float *linear, + cudaStream_t cuda_stream); +template void ApplyFtrl(const size_t size, const half *gradient, const half *learning_rate, + const half *l1_regularization, const half *l2_regularization, + const half *learning_rate_power, half *variable, half *accumulation, half *linear, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh new file mode 100644 index 0000000000..b5f0f82afe --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, + const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu new file mode 100755 index 0000000000..03b58b81a0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void GatherKernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1) { + int num = output_dim0 * output_dim1 * output_dim2; + int i, j, k; + for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; + write_index += blockDim.x * gridDim.x) { + i = write_index / (output_dim1 * output_dim2) % output_dim0; + j = write_index / output_dim2 % output_dim1; + k = write_index % output_dim2; + + if ((indices[j] >= 0) && (indices[j] < input_dim1)) { + int read_index = i * input_dim1 * output_dim2 + indices[j] * output_dim2 + k; + output[write_index] = input[read_index]; + } else { + output[write_index] = 0; + } + } + + return; +} +template +void Gather(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream) { + int size = output_dim0 * output_dim1 * output_dim2; + GatherKernel<<>>(input, indices, output, output_dim0, output_dim1, + output_dim2, input_dim1); + return; +} + +template void Gather(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); + +template void Gather(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cu new file mode 100644 index 0000000000..a4dc6648cc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cu @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void GeluKernel(size_t size, T *input_addr, T *output_addr) { + // formula: + // gelu(x) = 0.5 * x * (1.0 + tanh(y)) + // tanh(y) = 2 / (1 + exp(-2y)) - 1) + // y = sqrt(2/pi) * (x + 0.044715 * x^3) + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + float x = input_addr[pos]; + float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); + output_addr[pos] = 0.5 * x * (1.0 + tanh_res); + } +} + +template <> +__global__ void GeluKernel(size_t size, half *input_addr, half *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + half x = input_addr[pos]; + float tanh_res = tanh(__half2float(half(0.7978845608) * (x + half(0.044715) * x * x * x))); + output_addr[pos] = half(0.5) * x * (half(1.0) + __float2half(tanh_res)); + } +} + +template <> +__global__ void GeluKernel(size_t size, half2 *input_addr, half2 *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + half2 x = input_addr[pos]; + float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); + float2 tanh_res; + tanh_res.x = tanh(tanh_param.x); + tanh_res.y = tanh(tanh_param.y); + output_addr[pos] = half2(0.5, 0.5) * x * (half2(1.0, 1.0) + __float22half2_rn(tanh_res)); + } +} + +template +void Gelu(size_t size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) { + GeluKernel<<>>(size, input_addr, output_addr); + return; +} + +template <> +void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream) { + if (size % 2 == 0) { + GeluKernel<<>>( + size / 2, reinterpret_cast(input_addr), reinterpret_cast(output_addr)); + } else { + GeluKernel<<>>(size, input_addr, output_addr); + } + return; +} + +template +__global__ void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr) { + // formula: + // dx = dy * y' + // y' = 0.5 * (1 + tanh(tanh_para)) + + // 0.5 * x * (1 - tanh(tanh_para) * tanh(tanh_para)) * mul_right + // tanh_para = sqrt(2/pi) * (x + 0.044715 * x^3) + // mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)) + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + T x = x_addr[pos]; + T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); + T mul_right = 0.7978845608 + 0.1070322244 * x * x; + T y_res = 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res * tanh_res) * mul_right; + dx_addr[pos] = dy_addr[pos] * y_res; + } +} + +template +__global__ void GeluGradKernel(size_t size, half2 *dy_addr, half2 *x_addr, half2 *dx_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + half2 x = x_addr[pos]; + float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); + float2 tanh_res; + tanh_res.x = tanh(tanh_param.x); + tanh_res.y = tanh(tanh_param.y); + half2 tanh_res_half = __float22half2_rn(tanh_res); + half2 mul_right = half2(0.7978845608, 0.7978845608) + half2(0.1070322244, 0.1070322244) * x * x; + half2 y_res = half2(0.5, 0.5) * (half2(1.0, 1.0) + tanh_res_half) + + half2(0.5, 0.5) * x * (half2(1.0, 1.0) - tanh_res_half * tanh_res_half) * mul_right; + dx_addr[pos] = dy_addr[pos] * y_res; + } +} + +template +__global__ void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + half x = x_addr[pos]; + half tanh_param = half(0.7978845608) * (x + half(0.044715) * x * x * x); + half tanh_res = __float2half_rn(tanh(__half2float(tanh_param))); + half mul_right = half(0.7978845608) + half(0.1070322244) * x * x; + half y_res = half(0.5) * (half(1.0) + tanh_res) + half(0.5) * x * (half(1.0) - tanh_res * tanh_res) * mul_right; + dx_addr[pos] = dy_addr[pos] * y_res; + } +} + +template +void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr, cudaStream_t cuda_stream) { + GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); +} + +template <> +void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream) { + if (size % 2 == 0) { + GeluGradKernel<<>>( + size / 2, reinterpret_cast(dy_addr), reinterpret_cast(x_addr), + reinterpret_cast(dx_addr)); + } else { + GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); + } + return; +} + +template void Gelu(size_t size, float *input_addr, float *output_addr, cudaStream_t cuda_stream); +template void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream); +template void GeluGradKernel(size_t size, float *dy_addr, float *x_addr, float *dx_addr, cudaStream_t cuda_stream); +template void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh new file mode 100644 index 0000000000..1e69f26d57 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void Gelu(size_t input_size, T* input_addr, T* output_addr, cudaStream_t cuda_stream); + +template +void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu new file mode 100644 index 0000000000..fcb7418952 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu @@ -0,0 +1,259 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh" + +constexpr int NUM_PER_THREAD_REDUCE = 4; +constexpr int WARP_SIZE = 32; + +template +inline __device__ T my_pow(T a, double b) { + return pow(a, static_cast(b)); +} + +template <> +inline __device__ half my_pow(half a, double b) { + return __float2half(pow(__half2float(a), static_cast(b))); +} + +template +inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim, + const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, + T* dg, T* db) { + int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int row = NUM_PER_THREAD_REDUCE * i + j; + if (row >= row_dim) { + return; + } + + int pos = row * col_dim + col; + dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); + db[0] += dy[pos]; + } + } +} + +template +inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) { + for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { + dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta); + db[0] += __shfl_down_sync(0xffffffff, db[0], delta); + } +} + +template +inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr, + T* db_addr) { + if (threadIdx.x >= row_dim) { + return; + } + + // load data to share memory + // thread(0, 32, 64, 96, ...) keep the data + DynamicSharedMem share_mem; + if (threadIdx.x % WARP_SIZE == 0) { + int offset = threadIdx.x / WARP_SIZE * 2; + share_mem.addr()[offset] = dg[0]; + share_mem.addr()[offset + 1] = db[0]; + } + __syncthreads(); + + for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + int offset = (threadIdx.x + stride) * 2; + share_mem.addr()[threadIdx.x * 2] += share_mem.addr()[offset]; + share_mem.addr()[threadIdx.x * 2 + 1] += share_mem.addr()[offset + 1]; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + dg_addr[col] = share_mem.addr()[0]; + db_addr[col] = share_mem.addr()[1]; + } +} + +template +__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x, + const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) { + // row: [0:param_axis] + // col: [param_axis:] + // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i]) + // dg[j] = \Sigma_{j}dg[i][j] + for (int col = blockIdx.x; col < col_dim; col += gridDim.x) { + T dg = 0; + T db = 0; + GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); + GammaAndBetaWarpReduce(&dg, &db); + GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr); + } +} + +template +inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, + T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean, + const T* var, const T* gamma) { + int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int col = NUM_PER_THREAD_REDUCE * i + j; + if (col >= col_dim) { + return; + } + + int pos = row * col_dim + col; + int gamma_offset = pos % param_dim; + T v1 = dy[pos] * gamma[gamma_offset]; + T v2 = x[pos] - mean[row]; + + sum1[0] += -0.5 * v1 * v2 * my_pow(var[row] + epsilon, -1.5); + sum2[0] += v1; + sum3[0] += -2.0 * v2; + } + } +} + +template <> +inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, + half* sum1, half* sum2, half* sum3, const half* dy, const half* x, + const half* mean, const half* var, const half* gamma) { + int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int col = NUM_PER_THREAD_REDUCE * i + j; + if (col >= col_dim) { + return; + } + + int pos = row * col_dim + col; + int gamma_offset = pos % param_dim; + half v1 = dy[pos] * gamma[gamma_offset]; + half v2 = x[pos] - mean[row]; + + sum1[0] += __float2half(-0.5) * v1 * v2 * my_pow(var[row] + epsilon, -1.5); + sum2[0] += v1; + sum3[0] += __float2half(-2.0) * v2; + } + } +} + +template +inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { + for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { + sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta); + sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta); + sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta); + } +} + +template +inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) { + if (threadIdx.x >= col_dim) { + return; + } + + // load data to share memory + // thread(0, 32, 64, 96, ...) keep the data + if (threadIdx.x % WARP_SIZE == 0) { + int offset = threadIdx.x / WARP_SIZE * 3; + share_mem[offset] = sum1[0]; + share_mem[offset + 1] = sum2[0]; + share_mem[offset + 2] = sum3[0]; + } + __syncthreads(); + + for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + int offset = (threadIdx.x + stride) * 3; + share_mem[threadIdx.x * 3] += share_mem[offset]; + share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1]; + share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2]; + } + } + __syncthreads(); +} + +template +inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, + const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx, + const T* share_mem) { + for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { + int pos = (row * col_dim + col); + int gamma_offset = pos % param_dim; + T v1 = dy[pos] * gamma[gamma_offset]; + T v2 = x[pos] - mean[row]; + T v3 = my_pow(var[row] + epsilon, -0.5); + dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 + + (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim); + } +} + +template <> +inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, + const half* dy, const half* x, const half* mean, const half* var, const half* gamma, + half* dx, const half* share_mem) { + for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { + int pos = (row * col_dim + col); + int gamma_offset = pos % param_dim; + half v1 = dy[pos] * gamma[gamma_offset]; + half v2 = x[pos] - mean[row]; + half v3 = my_pow(var[row] + epsilon, -0.5); + dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 + + (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\ + * __float2half(1.0 / col_dim); + } +} + +template +__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy, + const T* x, const T* mean, const T* var, const T* gamma, T* dx) { + for (int row = blockIdx.x; row < row_dim; row += gridDim.x) { + T sum1 = 0; + T sum2 = 0; + T sum3 = 0; + DynamicSharedMem share_mem; + InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma); + InputWarpReduce(&sum1, &sum2, &sum3); + InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr()); + InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr()); + } +} + +template +void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, + const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) { + int share_mem_size = + ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); + InputPropKernel<<>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, + gamma, dx); + + share_mem_size = + ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T); + GammaAndBetaPropKernel<<>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db); +} + +template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, + const float* dy, const float* x, const float* mean, const float* var, const float* gamma, + float* dx, float* dg, float* db, cudaStream_t stream); +template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon, + const half* dy, const half* x, const half* mean, const half* var, const half* gamma, + half* dx, half* dg, half* db, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh new file mode 100644 index 0000000000..13d7a58614 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, + const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu new file mode 100644 index 0000000000..138300b303 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh" + +constexpr int NUM_PER_THREAD_REDUCE = 4; +constexpr int WARP_SIZE = 32; + +template +inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val) { + // Welford Algorithm: + // \mu_k = \mu_{k-1} + (x_k - \mu_{k-1})/k + // \sigma_k^2 = \sigma_{k-1}^2 + (x_k - \mu_{k-1}) * (x_k - \mu_k) + num[0]++; + T mean_new = mean[0] + (val - mean[0]) / num[0]; + var[0] = var[0] + (val - mean[0]) * (val - mean_new); + mean[0] = mean_new; +} + +template +inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) { + T zero = 0; + if (n2 == zero) { + return; + } + + T count = n1[0] + n2; + v1[0] = v1[0] + v2 + (m1[0] - m2) * (m1[0] - m2) * n1[0] * n2 / count; + m1[0] = (n1[0] * m1[0] + n2 * m2) / count; + n1[0] = count; +} + +template +inline __device__ void ThreadReduce(const int &col_dim, const T *block_addr, T *mean, T *var, T *num) { + int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int pos = NUM_PER_THREAD_REDUCE * i + j; + if (pos >= col_dim) { + return; + } + MeanAndVarAccumulation(mean, var, num, block_addr[pos]); + } + } +} + +template +inline __device__ void WarpReduce(T *mean, T *var, T *num) { + for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { + T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta); + T var_other = __shfl_down_sync(0xffffffff, var[0], delta); + T num_other = __shfl_down_sync(0xffffffff, num[0], delta); + MeanAndVarMerge(mean, var, num, mean_other, var_other, num_other); + } +} + +template +inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr, + T *share_mem) { + if (threadIdx.x >= col_dim) { + return; + } + + // load data to share memory + // thread(0, 32, 64, 96, ...) keep the data + if (threadIdx.x % WARP_SIZE == 0) { + int offset = threadIdx.x / WARP_SIZE * 3; + share_mem[offset] = mean[0]; + share_mem[offset + 1] = var[0]; + share_mem[offset + 2] = num[0]; + } + __syncthreads(); + + for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + int offset = (threadIdx.x + stride) * 3; + MeanAndVarMerge(&share_mem[threadIdx.x * 3], &share_mem[threadIdx.x * 3 + 1], &share_mem[threadIdx.x * 3 + 2], + share_mem[offset], share_mem[offset + 1], share_mem[offset + 2]); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + mean_addr[blockIdx.x] = share_mem[0]; + share_mem[1] /= col_dim; + var_addr[blockIdx.x] = share_mem[1]; + } +} + +template +inline __device__ void LayerNorm(const int &row, const int &col_dim, const int ¶m_dim, const T *x, + const T *share_mem, const T *gamma, const T *beta, const T epsilon, T *y) { + for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { + int pos = row * col_dim + col; + int i = pos % param_dim; + y[pos] = (x[pos] - share_mem[0]) / sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i]; + } +} + +template <> +inline __device__ void LayerNorm(const int &row, const int &col_dim, const int ¶m_dim, const half *x, + const half *share_mem, const half *gamma, const half *beta, const half epsilon, + half *y) { + for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { + int pos = row * col_dim + col; + int i = pos % param_dim; + y[pos] = (x[pos] - share_mem[0]) / hsqrt(share_mem[1] + epsilon) * gamma[i] + beta[i]; + } +} + +template +__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x, + const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) { + for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) { + T mean = 0; + T var = 0; + T num = 0; + const T *block_addr = x + row * col_dim; + DynamicSharedMem share_mem; + + ThreadReduce(col_dim, block_addr, &mean, &var, &num); + WarpReduce(&mean, &var, &num); + BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr()); + + __syncthreads(); + LayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y); + } +} + +template +void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const T &epsilon, const T *x, + const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) { + const dim3 block(row_dim); + const dim3 thread(256); + // keep the mean/var/num after warp reduce + int share_mem_size = + ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); + LayerNormKernel<<>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, + mean, var); +} + +template void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const float &epsilon, + const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var, + cudaStream_t stream); +template void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const half &epsilon, + const half *x, const half *gamma, const half *beta, half *y, half *mean, half *var, + cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh new file mode 100644 index 0000000000..9548b30d44 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +struct DynamicSharedMem; +template<> +struct DynamicSharedMem { + __device__ float *addr() { + extern __shared__ float addr_float[]; + return addr_float; + } +}; +template<> +struct DynamicSharedMem { + __device__ half *addr() { + extern __shared__ half addr_half[]; + return addr_half; + } +}; + +template +void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x, const T* gamma, + const T* beta, T* y, T* mean, T* var, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cu new file mode 100644 index 0000000000..3915dba172 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cu @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "minmax_update_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min, + float *output_max, const float min, const float max, + const float decay) { + output_min[0] = decay * (min) + (1 - decay) * (input_min[0]); + output_min[0] = input_min[0] > 0 ? 0 : input_min[0]; + output_max[0] = decay * (max) + (1 - decay) * (input_max[0]); + output_max[0] = input_max[0] < 0 ? 0 : input_max[0]; + return; +} + +__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) { + output_min[0] = min > 0 ? 0 : min; + output_max[0] = max < 0 ? 0 : max; + return; +} + +__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, + float *output_max, int channels, int per_channel_nums, bool ema, + float ema_decay) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { + thrust::pair sum = + thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); + if (ema) { + output_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i]; + output_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i]; + } else { + output_min[i] = sum.first[0]; + output_max[i] = sum.second[0]; + } + output_min[i] = input_min[i] > 0 ? 0 : input_min[i]; + output_max[i] = input_max[i] < 0 ? 0 : input_max[i]; + } + return; +} + +void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int total_num, const int channel_num, const float ema_decay, const bool ema, + cudaStream_t cuda_stream) { + int per_channel_num = total_num / channel_num; + UpdateInputMinMaxPerChannel<<>>( + input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay); + return; +} + +void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) { + float minel = 0.f; + float maxel = 0.f; + auto policy = thrust::cuda::par.on(cuda_stream); + thrust::pair, thrust::device_ptr> tuple; + tuple = + thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + total_num); + minel = tuple.first[0]; + maxel = tuple.second[0]; + + if (ema) { + UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel, + maxel, ema_decay); + } else { + UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel); + } + return; +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh new file mode 100644 index 0000000000..b4b4d582ee --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int total_num, const int channel_num, const float ema_decay, const bool ema, + cudaStream_t cuda_stream); + +void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh new file mode 100755 index 0000000000..62708663ad --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, + const S *momentum, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu new file mode 100644 index 0000000000..6dc4d676f2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu @@ -0,0 +1,51 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "one_hot_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void OneHotKernel(size_t size, const S *indices, size_t depth, const T *on_value, const T *off_value, + size_t left_dim_size, size_t right_dim_size, T *output) { + T on_v = *on_value; + T off_v = *off_value; + for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; + thread_idx += blockDim.x * gridDim.x) { + if (thread_idx < size) { + int left_idx = (thread_idx / (depth * right_dim_size)) % left_dim_size; + int d_idx = thread_idx / right_dim_size % depth; + int right_idx = thread_idx % right_dim_size; + int input_idx = left_idx * right_dim_size + right_idx; + int output_idx = left_idx * depth * right_dim_size + d_idx * right_dim_size + right_idx; + if (indices[input_idx] == d_idx) { + output[output_idx] = on_v; + } else { + output[output_idx] = off_v; + } + } + } +} +template +void OneHot(const S *indices, size_t depth, const T *on_value, const T *off_value, size_t left_dim_size, + size_t right_dim_size, T *output, cudaStream_t cuda_stream) { + size_t size = left_dim_size * depth * right_dim_size; + OneHotKernel<<>>(size, indices, depth, on_value, off_value, + left_dim_size, right_dim_size, output); + return; +} +template void OneHot(const int *indices, size_t depth, const float *on_value, const float *off_value, + size_t left_dim_size, size_t right_dim_size, float *output, cudaStream_t cuda_stream); +template void OneHot(const int *indices, size_t depth, const half *on_value, const half *off_value, + size_t left_dim_size, size_t right_dim_size, half *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu new file mode 100755 index 0000000000..3bb4d04a01 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu @@ -0,0 +1,87 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" + +template +__global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, + const int pad_left, float pad_value, T* output) { + T pad_value_ = static_cast(pad_value); + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int block_num = pos / padded_width / padded_height; + const int padded_w = pos % padded_width; + const int padded_h = pos / padded_width % padded_height; + if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height || + padded_w - pad_left >= old_width) { + output[pos] = pad_value_; + } else { + output[pos] = input[(block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left]; + } + } + return; +} + +template +__global__ void PadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, + const int pad_left, T* dx) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int block_num = pos / old_width / old_height; + const int padded_w = pos % old_width + pad_left; + const int padded_h = pos / old_width % old_height + pad_top; + dx[pos] = dy[(block_num * padded_height + padded_h) * padded_width + padded_w]; + } + return; +} + +template +void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, + const float pad_value, T* output, cudaStream_t cuda_stream) { + Pad<<>>(size, input, num, channels, old_height, old_width, + padded_height, padded_width, pad_top, pad_left, pad_value, + output); + return; +} + +template +void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, + const int pad_left, T* dx, cudaStream_t cuda_stream) { + PadGrad<<>>(size, dy, num, channels, old_height, old_width, + padded_height, padded_width, pad_top, pad_left, dx); + return; +} + +template void CalPad(const size_t size, const float* input, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, const int padded_width, + const int pad_top, const int pad_left, float pad_value, float* output, + cudaStream_t cuda_stream); +template void CalPadGrad(const size_t size, const float* dy, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, const int pad_top, const int pad_left, float* dx, + cudaStream_t cuda_stream); +template void CalPad(const size_t size, const half* input, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, const int padded_width, + const int pad_top, const int pad_left, float pad_value, half* output, + cudaStream_t cuda_stream); +template void CalPadGrad(const size_t size, const half* dy, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, const int pad_top, const int pad_left, half* dx, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh new file mode 100755 index 0000000000..b10804fdab --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, + float pad_value, T* output, cudaStream_t cuda_stream); +template +void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, + const int pad_left, T* dx, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu new file mode 100644 index 0000000000..6f99394562 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "random_op_impl.cuh" +template +__global__ void NormalKernel(int seed, curandState *globalState, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + curand_init(seed, i, 0, &globalState[i]); + output[i] = curand_normal(&globalState[i]); + } + return; +} + +template +void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) { + int RNG_seed = 0; + if (seed2 != 0) { + RNG_seed = seed2; + } else if (seed != 0) { + RNG_seed = seed; + } else { + RNG_seed = time(NULL); + } + NormalKernel<<>>(RNG_seed, globalState, output, count); + return; +} + +template void StandardNormal(int seed, int seed2, curandState *globalState, + float *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh new file mode 100644 index 0000000000..b099ead9bf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ + +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void StandardNormal(int seed, int seed2, curandState *globalState, + T *output, size_t count, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cu new file mode 100644 index 0000000000..80806b552f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cu @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void RmsPropKernel(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, + T* mean_square, T*moment, T* gradients, const size_t size) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + mean_square[i] = decay * mean_square[i] + (1.0 - decay) * gradients[i] * gradients[i]; + moment[i] = momentum * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon) * gradients[i]; + variable[i] -= moment[i]; + } +} + +template +void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, + T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) { + RmsPropKernel<<>>(learning_rate, decay, momentum, epsilon, + variable, mean_square, moment, gradients, size); +} + +template +__global__ void RmsPropCenterKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, + T* variable, T* mean_gradients, T* mean_square, T*moment, T* gradients, + const size_t size) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + mean_gradients[i] = decay[0] * mean_gradients[i] + (1.0 - decay[0]) * gradients[i]; + mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i]; + moment[i] = momentum[0] * moment[i] + learning_rate[0] * + rsqrt(mean_square[i] - mean_gradients[i] * mean_gradients[i] + epsilon[0]) * gradients[i]; + variable[i] -= moment[i]; + } +} + +template +void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, + T* mean_gradients, T* mean_square, T*moment, T* gradients, const size_t size, + cudaStream_t cuda_stream) { + RmsPropCenterKernel<<>>(learning_rate, decay, momentum, epsilon, + variable, mean_gradients, mean_square, + moment, gradients, size); +} + +template +void RmsProp(const float* learning_rate, const float decay, const float momentum, const float epsilon, + float* variable, float* mean_square, float* moment, float* gradients, const size_t size, + cudaStream_t cuda_stream); + +template +void RmsPropCenter(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon, + float* variable, float* mean_gradients, float* mean_square, float*moment, float* gradients, + const size_t size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh new file mode 100644 index 0000000000..16ad611381 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ +#include "runtime/device/gpu/cuda_common.h" + +template +void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, T* mean_square, + T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream); + +template +void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, + T* mean_gradients, T* mean_square, T* moment, T* gradients, const size_t size, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu new file mode 100644 index 0000000000..f7086f8093 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh" + +template +__global__ void Select(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + output[pos] = cond[pos] ? input_x[pos] : input_y[pos]; + } + return; +} + +template +void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, + cudaStream_t cuda_stream) { + Select<<>>(size, cond, input_x, input_y, output); + return; +} + +template void CalSelect(const size_t size, const bool* cond, const float* input_X, const float* input_y, + float* output, cudaStream_t cuda_stream); +template void CalSelect(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output, + cudaStream_t cuda_stream); +template void CalSelect(const size_t size, const bool* cond, const half* input_X, const half* input_y, + half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh new file mode 100644 index 0000000000..e201ab352c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu new file mode 100644 index 0000000000..f0c64bfb01 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh" + +template +__global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels, + T *outputs) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + if (logits[i] >= 0) { + outputs[i] = 1. / (1. + exp(-logits[i])) - labels[i]; + } else { + const T exp_val = exp(logits[i]); + outputs[i] = exp_val / (1. + exp_val) - labels[i]; + } + } +} + +template +void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, + cudaStream_t cuda_stream) { + SigmoidCrossEntropyWithLogitsGradKernel<<>>(size, logits, labels, + outputs); +} + +template void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const float *logits, + const float *labels, float *outputs, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh new file mode 100644 index 0000000000..6b444d6c02 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu new file mode 100644 index 0000000000..7425ac3809 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" + +template +__global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const T reverse_factor = static_cast(logits[i] >= 0); + outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor); + } +} + +template +void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs, + cudaStream_t cuda_stream) { + SigmoidCrossEntropyWithLogitsKernel<<>>(size, logits, labels, outputs); +} + +template void SigmoidCrossEntropyWithLogits(const size_t size, const float *logits, const float *labels, + float *outputs, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh new file mode 100644 index 0000000000..7e9130857f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu new file mode 100755 index 0000000000..dd4effc174 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu @@ -0,0 +1,191 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" + +template +__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) { + int i = pos / (l2 * l3 * l4) % l1; + int j = pos / (l3 * l4) % l2; + int k = pos / l4 % l3; + int o = pos % l4; + + int offset = (i + s1) * (d2 * d3 * d4) + + (j + s2) * (d3 * d4) + + (k + s3) * d4 + + (o + s4); + output[pos] = input[offset]; + } +} +template +__global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) { + output[start + pos] = dy[p + pos]; + } + return; +} +template +__global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); + pos += blockDim.x * gridDim.x) { + output[p + pos] = input[start + pos * stride]; + } + return; +} +template +__global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); + pos += blockDim.x * gridDim.x) { + dx[start + pos * stride] = dy[p + pos]; + } + return; +} +template +__global__ void FillArray(T* addr, const size_t len, const float value) { + T value_ = static_cast(value); + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < len; pos += blockDim.x * gridDim.x) { + addr[pos] = value_; + } + return; +} +template +void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream) { + FillArray<<>>(addr, input_size, value); + return; +} +template +void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output, cudaStream_t stream) { + Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, + d1, d2, d3, d4, input, output); +} +template +void CalSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, const std::vector begin, + const std::vector size, T* output, cudaStream_t cuda_stream) { + int block = in_shape[1] * in_shape[2] * in_shape[3]; + int map = in_shape[2] * in_shape[3]; + int w = in_shape[3]; + int length = size[3]; + int p = 0; + for (int i = begin[0]; i < size[0] + begin[0]; i++) { + for (int j = begin[1]; j < size[1] + begin[1]; j++) { + for (int k = begin[2]; k < size[2] + begin[2]; k++) { + SliceGrad<<>>( + dy, p, i * block + j * map + k * w + begin[3], length, output); + p = p + size[3]; + } + } + } +} +template +void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, + const std::vector begin, const std::vector end, const std::vector strides, + T* output, cudaStream_t cuda_stream) { + int block = in_shape[1] * in_shape[2] * in_shape[3]; + int map = in_shape[2] * in_shape[3]; + int w = in_shape[3]; + int ended = end[3]; + int p = 0; + int start = 0; + for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0])); i += std::abs(strides[0])) { + for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1])); j += std::abs(strides[1])) { + for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2])); k += std::abs(strides[2])) { + start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + + (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; + StridedSlice<<>>(input, p, start, begin[3], strides[3], + ended, output); + p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); + } + } + } +} +template +void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, const std::vector strides, + T* dx, cudaStream_t cuda_stream) { + int block = in_shape[1] * in_shape[2] * in_shape[3]; + int map = in_shape[2] * in_shape[3]; + int w = in_shape[3]; + int ended = end[3]; + int p = 0; + int start = 0; + for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0] + 1)); i += std::abs(strides[0])) { + for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1] + 1)); + j += std::abs(strides[1])) { + for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2] + 1)); + k += std::abs(strides[2])) { + start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + + (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; + StridedSliceGrad<<>>(dy, p, start, begin[3], strides[3], + ended, dx); + p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); + } + } + } +} + +template void FillDeviceArray(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const float *input, float *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, + const std::vector begin, const std::vector size, float* output, + cudaStream_t cuda_stream); +template void CalStridedSlice(const size_t input_size, const float* input, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, float* output, cudaStream_t cuda_stream); +template void CalStridedSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, float* dx, cudaStream_t cuda_stream); +template void FillDeviceArray(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const half *input, half *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, + const std::vector begin, const std::vector size, half* output, + cudaStream_t cuda_stream); +template void CalStridedSlice(const size_t input_size, const half* input, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, half* output, cudaStream_t cuda_stream); +template void CalStridedSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, half* dx, cudaStream_t cuda_stream); +template void FillDeviceArray(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const int *input, int *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, + const std::vector begin, const std::vector size, int* output, + cudaStream_t cuda_stream); +template void CalStridedSlice(const size_t input_size, const int* input, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, int* output, cudaStream_t cuda_stream); +template void CalStridedSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, int* dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh new file mode 100755 index 0000000000..e04f277c3d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ + +#include +#include +#include "runtime/device/gpu/cuda_common.h" + + +template +void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output, cudaStream_t stream); +template +void CalSliceGrad(const size_t input_size, const T* input, const std::vector in_shape, + const std::vector begin, const std::vector size, T* output, cudaStream_t cuda_stream); +template +void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, + const std::vector begin, const std::vector end, const std::vector strides, + T* output, cudaStream_t cuda_stream); +template +void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, const std::vector strides, + T* dx, cudaStream_t cuda_stream); +template +void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu new file mode 100644 index 0000000000..9050044b7f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "smooth_l1_loss_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target, + T *loss) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T value = (prediction[i] - target[i]) > 0 ? (prediction[i] - target[i]) : (target[i] - prediction[i]); + if (value < sigma) { + loss[i] = static_cast(0.5) * value * value; + } else { + loss[i] = value - static_cast(0.5); + } + } +} + +template +void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss, + cudaStream_t stream) { + SmoothL1LossKernel<<>>(input_size, sigma, prediction, target, loss); +} + +template +__global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target, + const T *dloss, T *dx) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T value = prediction[i] - target[i]; + if (value > static_cast(sigma)) { + dx[i] = dloss[i]; + } else if (value < static_cast(-sigma)) { + dx[i] = -dloss[i]; + } else { + dx[i] = value * dloss[i]; + } + } +} + +template +void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss, + T *dx, cudaStream_t stream) { + SmoothL1LossGradKernel<<>>(input_size, sigma, prediction, target, + dloss, dx); +} + +template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target, + float *loss, cudaStream_t stream); +template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction, const float *target, + const float *dloss, float *dx, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh new file mode 100644 index 0000000000..7938e18a3b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ +template +void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss, + cudaStream_t stream); +template +void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss, + T *dx, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh new file mode 100755 index 0000000000..fa32260381 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss, + cudaStream_t cuda_stream); + +template +void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu new file mode 100755 index 0000000000..ffcb2c8052 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "transpose_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis, + const int shape_size, T* output) { + int pos_size; + int temp_pos; + int newpos; + int newpos_size; + int pos_array[TRANSPOSE_MAX_DIMENSION]; + + // for example 4-D: pos = posArray[0] * input_shape[1] * input_shape[2] * input_shape[3] + + // posArray[1] * input_shape[2] * input_shape[3] + + // posArray[2] * input_shape[3] + + // posArray[3] + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + temp_pos = pos; + pos_size = size / input_shape[0]; + pos_array[0] = temp_pos / pos_size; + for (int i = 1; i < shape_size; i++) { + temp_pos -= pos_array[i - 1] * pos_size; + pos_size = pos_size / input_shape[i]; + pos_array[i] = temp_pos / pos_size; + } + + newpos = pos_array[input_axis[shape_size - 1]]; + newpos_size = 1; + for (int j = shape_size - 2; j >= 0; j--) { + newpos_size *= input_shape[input_axis[j + 1]]; + newpos += pos_array[input_axis[j]] * newpos_size; + } + + output[newpos] = input[pos]; + } + return; +} +template +void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size, + T* output, cudaStream_t cuda_stream) { + Transpose<<>>(size, input, input_shape, input_axis, shape_size, + output); + return; +} + +template void CalTranspose(const int size, const float* input, const int* input_shape, const int* input_axis, + const int shape_size, float* output, cudaStream_t cuda_stream); +template void CalTranspose(const int size, const half* input, const int* input_shape, const int* input_axis, + const int shape_size, half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh new file mode 100755 index 0000000000..cf8b30866e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Logarithm(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cu new file mode 100644 index 0000000000..3d299c2352 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cu @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh" + +template +__global__ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, + T* input_addr, S* ids_addr, T* output_addr) { + for (int input_index = blockIdx.x * blockDim.x + threadIdx.x; input_index < input_dim0 * input_dim1; + input_index += blockDim.x * gridDim.x) { + size_t j = input_index / input_dim1; + size_t k = input_index % input_dim1; + + S i = ids_addr[j]; + if (i < 0 || i >= output_dim0) { + continue; + } + size_t output_index = i * output_dim1 + k; + atomicAdd(output_addr + output_index, input_addr[input_index]); + } +} + +template +void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, + T* input_addr, S* ids_addr, T* output_addr, cudaStream_t stream) { + int size = input_dim0 * input_dim1; + UnsortedSegmentSum<<>>(input_dim0, input_dim1, + output_dim0, output_dim1, input_addr, ids_addr, output_addr); + return; +} + +template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, + float* input_addr, int* ids_addr, float* output_addr, cudaStream_t stream); +template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, + float* input_addr, int64_t* ids_addr, float* output_addr, cudaStream_t stream); + +template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, + int* input_addr, int* ids_addr, int* output_addr, cudaStream_t stream); +template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, + int* input_addr, int64_t* ids_addr, int* output_addr, cudaStream_t stream); + + + diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh new file mode 100644 index 0000000000..315677fde4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ + +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, + T* input_addr, S* ids, T* output_addr, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.cc new file mode 100644 index 0000000000..3c88b88c74 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/data/dataset_init_kernel.h" +#include "backend/kernel_compiler/gpu/data/dataset_utils.h" +#include "runtime/device/gpu/gpu_buffer_mgr.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +using mindspore::device::GpuBufferMgr; + +DatasetInitKernel::DatasetInitKernel() : total_bytes_(0) {} + +const std::vector &DatasetInitKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &DatasetInitKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool DatasetInitKernel::Init(const CNodePtr &kernel_node) { + queue_name_ = GetAttr(kernel_node, "queue_name"); + auto shapes = GetAttr>>(kernel_node, "shapes"); + auto types = GetAttr>(kernel_node, "types"); + if (shapes.size() != types.size()) { + MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types; + } + + for (size_t i = 0; i < shapes.size(); i++) { + int unit = UnitSizeInBytes(types[i]->type_id()); + int nums = ElementNums(shapes[i]); + int bytes = unit * nums; + shapes_.push_back(bytes); + total_bytes_ += bytes; + } + return true; +} + +void DatasetInitKernel::InitSizeLists() { return; } + +bool DatasetInitKernel::Launch(const std::vector &, const std::vector &, + const std::vector &, void *) { + void *addr = nullptr; + size_t len = total_bytes_ * buffer_q_capacity_; + + if (!device::gpu::GPUMemoryAllocator::GetInstance().AllocBufferQueueMem(len, &addr)) { + MS_LOG(EXCEPTION) << "Memory not enough: failed to allocate GPU buffer queue memory[" << len << "]."; + } + + auto status = GpuBufferMgr::GetInstance().Create(0, queue_name_, addr, shapes_, buffer_q_capacity_); + if (status) { + MS_LOG(EXCEPTION) << "Init Dataset Failed. len: " << len << ", status:" << status; + } + + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.h new file mode 100644 index 0000000000..f8cc9b19ea --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_DATASET_INIT_KERNEL_H +#define MINDSPORE_DATASET_INIT_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class DatasetInitKernel : public GpuKernel { + public: + DatasetInitKernel(); + ~DatasetInitKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel_node) override; + + protected: + void InitSizeLists() override; + + private: + std::string queue_name_; + std::vector shapes_; + size_t total_bytes_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + // The capacity of buffer Q. + size_t buffer_q_capacity_{2}; +}; + +MS_REG_GPU_KERNEL(InitDataSetQueue, DatasetInitKernel) +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_QUEUE_CPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc new file mode 100644 index 0000000000..67a487ce28 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h" +#include +#include +#include +#include "runtime/device/gpu/gpu_buffer_mgr.h" +#include "runtime/device/gpu/gpu_common.h" +#include "backend/kernel_compiler/gpu/data/dataset_utils.h" + +namespace mindspore { +namespace kernel { +using mindspore::device::GpuBufferMgr; +using mindspore::device::HandleMgr; + +DatasetIteratorKernel::DatasetIteratorKernel() : handle_(HandleMgr::INVALID_HANDLE), total_bytes_(0) {} + +DatasetIteratorKernel::~DatasetIteratorKernel() { GpuBufferMgr::GetInstance().Close(handle_); } + +const std::vector &DatasetIteratorKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &DatasetIteratorKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) { + queue_name_ = GetAttr(kernel_node, "shared_name"); + auto shapes = GetAttr>>(kernel_node, "shapes"); + auto types = GetAttr>(kernel_node, "types"); + if (shapes.size() != types.size()) { + MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types; + } + + for (size_t i = 0; i < shapes.size(); i++) { + int unit = UnitSizeInBytes(types[i]->type_id()); + int nums = ElementNums(shapes[i]); + int bytes = unit * nums; + output_size_list_.push_back(bytes); + total_bytes_ += bytes; + } + + handle_ = GpuBufferMgr::GetInstance().Open(0, queue_name_, output_size_list_); + if (handle_ == HandleMgr::INVALID_HANDLE) { + MS_LOG(EXCEPTION) << "Gpu Queue(" << queue_name_ << ") Open Failed"; + } + + return true; +} + +void DatasetIteratorKernel::InitSizeLists() { return; } + +bool DatasetIteratorKernel::Launch(const std::vector &, const std::vector &, + const std::vector &outputs, void *stream) { + void *addr = nullptr; + size_t len = 0; + + int repeat = 0; + while (true) { + auto ret = GpuBufferMgr::GetInstance().Front(handle_, &addr, &len); + if (ret == device::SUCCESS) { + break; + } + + if (ret == device::TIMEOUT) { + repeat++; + if (repeat < 10) { + MS_LOG(INFO) << "Waiting for data...(" << repeat << " / 10)"; + continue; + } else { + MS_LOG(ERROR) << "Get data timeout"; + return false; + } + } + + MS_LOG(ERROR) << "Get data failed, errcode " << ret; + return false; + } + + if (total_bytes_ != len) { + MS_LOG(ERROR) << "Dataset front error. read: " << len << ", expect: " << total_bytes_ << ", "; + return false; + } + + for (size_t i = 0; i < output_size_list_.size(); i++) { + void *output_addr = GetDeviceAddress(outputs, i); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice, + reinterpret_cast(stream)), + "Cuda Memcpy Failed"); + addr = reinterpret_cast(addr) + output_size_list_[i]; + } + + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream)), + "cudaStreamSynchronize failed"); + (void)GpuBufferMgr::GetInstance().Pop(handle_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h new file mode 100644 index 0000000000..746aed3294 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GET_NEXT_KERNEL_H +#define MINDSPORE_GET_NEXT_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class DatasetIteratorKernel : public GpuKernel { + public: + DatasetIteratorKernel(); + ~DatasetIteratorKernel(); + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel_node) override; + + protected: + void InitSizeLists() override; + + private: + std::string queue_name_; + unsigned int handle_; + size_t total_bytes_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +MS_REG_GPU_KERNEL(GetNext, DatasetIteratorKernel) +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_QUEUE_CPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.cc new file mode 100644 index 0000000000..cb014a3d2b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/data/dataset_utils.h" + +namespace mindspore { +namespace kernel { +size_t UnitSizeInBytes(const mindspore::TypeId &t) { + size_t bytes = 0; + switch (t) { + case kNumberTypeBool: + case kNumberTypeInt8: + case kNumberTypeUInt8: + bytes = 1; + break; + case kNumberTypeInt16: + case kNumberTypeUInt16: + case kNumberTypeFloat16: + bytes = 2; + break; + case kNumberTypeInt: + case kNumberTypeUInt: + case kNumberTypeInt32: + case kNumberTypeUInt32: + case kNumberTypeFloat: + case kNumberTypeFloat32: + bytes = 4; + break; + case kNumberTypeUInt64: + case kNumberTypeInt64: + case kNumberTypeFloat64: + bytes = 8; + break; + default: + MS_LOG(EXCEPTION) << "Invalid types " << t; + break; + } + + return bytes; +} + +int ElementNums(const std::vector &shape) { + if (shape.size() == 0) { + return 0; + } + + int nums = 1; + for (size_t i = 0; i < shape.size(); i++) { + nums *= shape[i]; + } + + return nums; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_utils.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h similarity index 100% rename from mindspore/ccsrc/kernel/gpu/data/dataset_utils.h rename to mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h new file mode 100644 index 0000000000..4c179f2173 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h @@ -0,0 +1,106 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "runtime/device/gpu/gpu_device_manager.h" +#include "runtime/device/gpu/gpu_common.h" +#include "backend/session/anf_runtime_algorithm.h" +using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; + +namespace mindspore { +namespace kernel { +class GpuKernel : public KernelMod { + public: + virtual ~GpuKernel() = default; + virtual bool Init(const CNodePtr &kernel_node) = 0; + + protected: + virtual void InitResource() {} + virtual void InitSizeLists() = 0; + + template + inline T *GetDeviceAddress(const std::vector &addr_list, size_t index) { + if (index >= addr_list.size()) { + MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")"; + } + // Kernels may run normally without workspace, the addr_list[index] maybe nullptr. + if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) { + return nullptr; + } + MS_EXCEPTION_IF_NULL(addr_list[index]->addr); + return reinterpret_cast(addr_list[index]->addr); + } + + template + inline T GetAttr(const CNodePtr &kernel_node, const std::string &key) const { + const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node); + const ValuePtr &attr = prim->GetAttr(key); + if (attr == nullptr) { + const std::string &prim_name = AnfAlgo::GetCNodeName(kernel_node); + MS_LOG(EXCEPTION) << "The attr(" << key << ") of kernel(" << prim_name << ") not exist"; + } + return GetValue(attr); + } + // expand Nd Shape to 4d (N in [0,4]) + void ShapeNdTo4d(const std::vector &src, std::vector *dst) { + if (src.size() > 4) { + MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!"; + } + dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4])); + dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3])); + dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2])); + dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1])); + } + + inline void CheckBroadcast4TensorOp(const std::vector &A, const std::vector &B, + const std::vector &Out) { + if (A != Out && B != Out) { + MS_EXCEPTION(ValueError) + << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n" + "InputA must match the corresponding dimension of the destination tensor outC, and each " + "dimension of the inputB " + "must match the corresponding dimension of outC or must be equal to 1."; + } + } + + // choose the suitable datatype for cudnn/cublas + inline cudnnDataType_t GetCudnnDataType(const std::string &Type) { + auto type = kCudnnDtypeMap.find(Type); + if (type == kCudnnDtypeMap.end()) { + MS_EXCEPTION(TypeError) << Type << " is not supported."; + } + return type->second; + } + inline cudaDataType_t GetCudaDataType(const std::string &Type) { + auto type = kCudaDtypeMap.find(Type); + if (type == kCudaDtypeMap.end()) { + MS_EXCEPTION(TypeError) << Type << " is not supported."; + } + return type->second; + } +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc new file mode 100644 index 0000000000..4a0191abd7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +#include +#include + +#include "common/utils.h" +#include "runtime/device/kernel_info.h" +#include "runtime/device/gpu/cuda_common.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +GpuKernelFactory &GpuKernelFactory::GetInstance() { + static GpuKernelFactory instance; + return instance; +} + +void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr, + GpuKernelCreater &&creater) { + map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater); +} + +void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, + std::vector> *iter_second, + size_t attr_index) { + if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { + if (iter_second->at(attr_index).first.GetAllSame()) { + auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first; + for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) { + (void)iter_second->at(attr_index).first.AddInputAttr(dtype); + } + } else { + MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; + } + } + if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { + if (iter_second->at(attr_index).first.GetAllSame()) { + auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first; + for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) { + (void)iter_second->at(attr_index).first.AddOutputAttr(dtype); + } + } else { + MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; + } + } +} + +std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) { + std::string type_lists = ""; + auto iter = map_kernel_name_to_creater_.find(kernel_name); + if (map_kernel_name_to_creater_.end() == iter) { + return type_lists; + } + for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { + std::string type_list = "in["; + auto attr = (iter->second)[attr_index].first; + for (size_t input_index = 0; input_index < attr.GetInputSize(); ++input_index) { + type_list = type_list + TypeId2String(attr.GetInputAttr(input_index).first) + + ((input_index == (attr.GetInputSize() - 1)) ? "" : " "); + } + type_list = type_list + "], out["; + for (size_t input_index = 0; input_index < attr.GetOutputSize(); ++input_index) { + type_list = type_list + TypeId2String(attr.GetOutputAttr(input_index).first) + + ((input_index == (attr.GetOutputSize() - 1)) ? "" : " "); + } + type_lists = type_lists + type_list + "]; "; + } + return type_lists; +} + +std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string &kernel_name, + const KernelBuildInfo *kernel_info) { + auto iter = map_kernel_name_to_creater_.find(kernel_name); + const int marjor_sm = GET_MAJOR_SM; + if (map_kernel_name_to_creater_.end() == iter) { + MS_LOG(INFO) << "Not registered GPU kernel: op[" << kernel_name << "]!"; + return std::make_pair(false, 0); + } + if ((iter->second).size() == 1 && (iter->second)[0].first.GetInputSize() == 0) { + return std::make_pair(true, 0); + } + + for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { + CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index); + bool flag = true; + // data type matching check of all input parameters of kernel + for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) { + if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { + if (marjor_sm < MINIUM_SM) { + MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM + << ", but the current device's computing capacity is " << marjor_sm; + } + MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM + << ", but the current device's computing capacity is " << marjor_sm; + } + if (kernel_info->GetInputDeviceType(input_index) != + (iter->second)[attr_index].first.GetInputAttr(input_index).first) { + flag = false; + break; + } + } + if (!flag) { + continue; + } + // data type matching check of all output parameters of kernel + for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) { + if (kernel_info->GetOutputDeviceType(output_index) != + (iter->second)[attr_index].first.GetOutputAttr(output_index).first) { + flag = false; + break; + } + } + // finish data type matching check and return a pair maintain the whether matching is success, + // if first is true, second is index of matching KernelAttr and creater pair in vector; + if (flag) { + size_t match_index = attr_index; + return std::make_pair(true, match_index); + } + } + return std::make_pair(false, 0); +} + +GpuKernel *GpuKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { + auto kernel_info = dynamic_cast(apply_kernel->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(kernel_build_Info); + std::pair ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_Info); + if (ret_pair.first) { + return (map_kernel_name_to_creater_.find(kernel_name)->second)[ret_pair.second].second(); + } + return nullptr; +} + +bool GpuKernelFactory::SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_build_info) { + std::pair ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_info.get()); + return ret_pair.first; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h new file mode 100644 index 0000000000..8834fa0f1a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h @@ -0,0 +1,93 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ + +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "runtime/device/gpu/kernel_info_setter.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +using mindspore::device::gpu::KernelAttr; +using GpuKernelCreater = std::function; +class GpuKernelFactory { + public: + ~GpuKernelFactory() = default; + + static GpuKernelFactory &GetInstance(); + + void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, GpuKernelCreater &&creater); + + GpuKernel *Create(const std::string &kernel_name, const CNodePtr &apply_kernel); + + bool SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_info); + + std::string SupportedTypeList(const std::string &kernel_name); + + private: + GpuKernelFactory() = default; + + GpuKernelFactory(GpuKernelFactory const &); + + GpuKernelFactory &operator=(const GpuKernelFactory &); + + std::pair GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info); + void CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, + std::vector> *iter_second, size_t attr_index); + // map to maintain kernel and creater, KernelAttr object and creater must be registered as a pair. + std::map>> map_kernel_name_to_creater_; +}; + +class GpuKernelRegister { + public: + GpuKernelRegister(const std::string &kernel_name, const KernelAttr &kernel_attr, GpuKernelCreater &&creater) { + GpuKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(creater)); + } +}; + +#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \ + static_assert(std::is_base_of::value, " must be base of GpuKernel"); \ + static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, KernelAttr(), []() { return new OPCLASS(); }); + +// regular register of fixed accuracy kernels +#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \ + static_assert(std::is_base_of::value, " must be base of GpuKernel"); \ + static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); + +// register of mixed accuracy kernels which use template and maintain one typename, ignore input num +#define MS_REG_GPU_KERNEL_SAME(OPNAME, ATTR, OPCLASS, T) \ + static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ + static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); + +// register of mixed accuracy kernels which use template and maintain one typename +#define MS_REG_GPU_KERNEL_ONE(OPNAME, ATTR, OPCLASS, T) \ + static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ + static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); + +// register of mixed accuracy kernels which use template and maintain two typename +#define MS_REG_GPU_KERNEL_TWO(OPNAME, ATTR, OPCLASS, T, S) \ + static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ + static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \ + []() { return new OPCLASS(); }); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ diff --git a/mindspore/ccsrc/kernel/gpu/kernel_constants.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h similarity index 100% rename from mindspore/ccsrc/kernel/gpu/kernel_constants.h rename to mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc new file mode 100644 index 0000000000..86c7d8c108 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/addn_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AddNGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + AddNGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(AddN, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + AddNGpuFwdKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h new file mode 100644 index 0000000000..b69bd20216 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h @@ -0,0 +1,143 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class AddNGpuFwdKernel : public GpuKernel { + public: + AddNGpuFwdKernel() + : cudnn_handle_(nullptr), + input_descriptor_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + input_size_(0), + output_size_(0), + workspace_size_(0), + is_null_input_(false), + num_input_(0) {} + ~AddNGpuFwdKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *output_addr = GetDeviceAddress(outputs, 0); + if (cudnn_data_type_ == CUDNN_DATA_INT32) { + FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast(stream_ptr)); + } + const float alpha = 1; + const float beta = 0; + for (size_t i = 0; i < IntToSize(num_input_); i++) { + T *input_addr = GetDeviceAddress(inputs, i); + if (cudnn_data_type_ == CUDNN_DATA_INT32) { + NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr, + reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr, + &(i > 0 ? alpha : beta), input_descriptor_, output_addr), + "cudnnAddTensor failed"); + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + num_input_ = GetAttr(kernel_node, "n"); + if (IntToSize(num_input_) != input_num) { + MS_LOG(ERROR) << "Input number is " << num_input_ << " in attr, but got " << input_num << "input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "AddNGpuFwdKernel input is null"; + InitSizeLists(); + return true; + } + for (size_t i = input_shape.size(); i < 4; i++) { + (void)input_shape.insert(input_shape.begin(), 1); + } + int dimA[4]; + for (size_t i = 0; i < input_shape.size(); i++) { + dimA[i] = SizeToInt(input_shape[i]); + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + SizeToInt(input_shape.size()), dimA), + "cudnnSetTensorNdDescriptor failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + } + for (int i = 0; i < num_input_; i++) { + input_size_list_.push_back(input_size_); + } + output_size_list_.push_back(input_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_descriptor_; + cudnnDataType_t cudnn_data_type_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + bool is_null_input_; + int num_input_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc new file mode 100644 index 0000000000..bffcca158b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + AssignAddGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE( + AssignAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AssignAddGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + AssignAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + AssignAddGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h new file mode 100644 index 0000000000..04a74b3412 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h @@ -0,0 +1,95 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cuh" +namespace mindspore { +namespace kernel { +template +class AssignAddGpuFwdKernel : public GpuKernel { + public: + AssignAddGpuFwdKernel() : is_null_input_(false), input_size_(0) {} + ~AssignAddGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *input_addr2 = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + CalAssignAdd(input_size_ / sizeof(T), input_addr, input_addr2, output_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "AssignAddGpuFwdKernel input is null"; + InitSizeLists(); + return true; + } + input_size_ = sizeof(T); + for (size_t i : input_shape) { + input_size_ = i * input_size_; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + } + + private: + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc new file mode 100644 index 0000000000..a07fb6ddf6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + BiasAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BiasAddGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + BiasAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BiasAddGpuKernel, float16) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h new file mode 100644 index 0000000000..fd344be28a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h @@ -0,0 +1,149 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_BIAS_ADD_GPU_KERNEL_H +#define MINDSPORE_BIAS_ADD_GPU_KERNEL_H +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class BiasAddGpuKernel : public GpuKernel { + public: + BiasAddGpuKernel() + : cudnn_handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + x_desc_(nullptr), + b_desc_(nullptr), + op_desc_(nullptr), + is_null_input_(false) {} + ~BiasAddGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + VARIABLE_NOT_USED(stream_ptr); + if (is_null_input_) { + return true; + } + + T *x_addr = GetDeviceAddress(inputs, 0); + T *b_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + try { + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_, + b_addr, &beta, x_desc_, output_addr), + "cudnnOpTensor failed"); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cudnnOpTensor"; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto num_dims = x_shape.size(); + is_null_input_ = CHECK_NULL_INPUT(x_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "input is null"; + InitSizeLists(); + return true; + } + + if (num_dims < 2) { + MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims; + } + + std::string format = GetAttr(kernel_node, "data_format"); + string::size_type pos = format.find("C"); + if (pos == std::string::npos || pos >= num_dims) { + MS_LOG(EXCEPTION) << "format '" << format << "' invalid"; + } + + // Expand to 4 dims for cudnnSetTensorNdDescriptorEx. + auto cudnn_dims = std::max(num_dims, 4UL); + std::unique_ptr x_dims = std::make_unique(cudnn_dims); + std::unique_ptr b_dims = std::make_unique(cudnn_dims); + for (size_t i = 0; i < cudnn_dims; i++) { + x_dims[i] = (i < num_dims) ? SizeToInt(x_shape[i]) : 1; + b_dims[i] = (i == pos) ? SizeToInt(x_shape[i]) : 1; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()), + "cudnnSetTensorNdDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(b_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()), + "cudnnSetTensorNdDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), + "cudnnSetOpTensorDescriptor failed"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&b_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateOpTensorDescriptor(&op_desc_), "cudnnCreateOpTensorDescriptor failed"); + } + void InitSizeLists() override { + size_t x_size, b_size; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &x_size), "cudnnGetTensorSizeInBytes failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(b_desc_, &b_size), "cudnnGetTensorSizeInBytes failed."); + input_size_list_.push_back(x_size); + input_size_list_.push_back(b_size); + output_size_list_.push_back(x_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(op_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(b_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyOpTensorDescriptor failed"); + } + + cudnnHandle_t cudnn_handle_; + cudnnDataType_t cudnn_data_type_; + cudnnTensorDescriptor_t x_desc_; + cudnnTensorDescriptor_t b_desc_; + cudnnOpTensorDescriptor_t op_desc_; + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_BIAS_ADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc new file mode 100644 index 0000000000..41e7147328 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +// fp32 +MS_REG_GPU_KERNEL_TWO( + Greater, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float, bool) +MS_REG_GPU_KERNEL_TWO( + Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float, bool) +MS_REG_GPU_KERNEL_TWO( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + RealDiv, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + TensorAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) + +// fp16 +MS_REG_GPU_KERNEL_TWO( + Greater, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half, bool) +MS_REG_GPU_KERNEL_TWO( + Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half, bool) +MS_REG_GPU_KERNEL_TWO( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + RealDiv, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + TensorAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) + +// int32 +MS_REG_GPU_KERNEL_TWO( + TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h new file mode 100644 index 0000000000..aaf827723a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +namespace mindspore { +namespace kernel { +template +class BroadcastOpGpuKernel : public GpuKernel { + public: + BroadcastOpGpuKernel() + : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} + ~BroadcastOpGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *lhs = GetDeviceAddress(inputs, 0); + T *rhs = GetDeviceAddress(inputs, 1); + S *output = GetDeviceAddress(outputs, 0); + + if (need_broadcast_) { + Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2], + rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs, + rhs, output, reinterpret_cast(stream_ptr)); + } else { + NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast(stream_ptr)); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + GetOpType(kernel_node); + auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0); + need_broadcast_ = IsBroadcast(shape1, shape2); + if (need_broadcast_ && shape1.size() > 4) { + MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; + } + + for (size_t i = 0; i < shape3.size(); i++) { + output_shape_[i] = shape3[i]; + output_num_ *= shape3[i]; + } + int lhs_offset = shape3.size() - shape1.size(); + for (size_t j = 0; j < shape1.size(); j++) { + lhs_shape_[j + lhs_offset] = shape1[j]; + input1_num_ *= shape1[j]; + } + int rhs_offset = shape3.size() - shape2.size(); + for (size_t k = 0; k < shape2.size(); k++) { + rhs_shape_[k + rhs_offset] = shape2[k]; + input2_num_ *= shape2[k]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { return; } + void InitSizeLists() override { + input_size_list_.push_back(input1_num_ * sizeof(T)); + input_size_list_.push_back(input2_num_ * sizeof(T)); + output_size_list_.push_back(output_num_ * sizeof(S)); + } + + private: + void GetOpType(const CNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + + static std::map kBroadcastTypeMap = { + {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, + {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, + {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, + }; + + auto iter = kBroadcastTypeMap.find(kernel_name); + if (iter == kBroadcastTypeMap.end()) { + MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; + } else { + op_type_ = iter->second; + } + } + + bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { + if (lhs.size() != rhs.size()) { + return true; + } + for (size_t i = 0; i < lhs.size(); i++) { + if (lhs[i] != rhs[i]) { + return true; + } + } + return false; + } + + BroadcastOpType op_type_; + bool need_broadcast_; + int input1_num_; + int input2_num_; + int output_num_; + int lhs_shape_[4] = {1, 1, 1, 1}; + int rhs_shape_[4] = {1, 1, 1, 1}; + int output_shape_[4] = {1, 1, 1, 1}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc new file mode 100644 index 0000000000..49be2fd9a6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(MaximumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + BroadcastOpGradGpuKernel, int) +MS_REG_GPU_KERNEL_ONE(MaximumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + BroadcastOpGradGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h new file mode 100644 index 0000000000..6258c5c4e2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +namespace mindspore { +namespace kernel { +template +class BroadcastOpGradGpuKernel : public GpuKernel { + public: + BroadcastOpGradGpuKernel() + : op_type_(BROADCAST_GRAD_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} + ~BroadcastOpGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *x1 = GetDeviceAddress(inputs, 0); + T *x2 = GetDeviceAddress(inputs, 1); + T *dy = GetDeviceAddress(inputs, 2); + T *dx1 = GetDeviceAddress(outputs, 0); + T *dx2 = GetDeviceAddress(outputs, 1); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx1, 0, outputs[0]->size, reinterpret_cast(stream_ptr)), + "cudaMemSet Failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx2, 0, outputs[1]->size, reinterpret_cast(stream_ptr)), + "cudaMemSet Failed"); + if (need_broadcast_) { + BroadcastGrad(x1_shape_[0], x1_shape_[1], x1_shape_[2], x1_shape_[3], x2_shape_[0], x2_shape_[1], x2_shape_[2], + x2_shape_[3], dy_shape_[0], dy_shape_[1], dy_shape_[2], dy_shape_[3], op_type_, x1, x2, dy, dx1, + dx2, reinterpret_cast(stream_ptr)); + } else { + NoBroadcastGrad(output_num_, op_type_, x1, x2, dy, dx1, dx2, reinterpret_cast(stream_ptr)); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + GetOpType(kernel_node); + auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto shape3 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + need_broadcast_ = IsBroadcast(shape1, shape2); + if (need_broadcast_ && shape1.size() > 4) { + MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; + } + + for (size_t i = 0; i < shape3.size(); i++) { + dy_shape_[i] = shape3[i]; + output_num_ *= shape3[i]; + } + int x1_offset = shape3.size() - shape1.size(); + for (size_t i = 0; i < shape1.size(); i++) { + x1_shape_[i + x1_offset] = shape1[i]; + input1_num_ *= shape1[i]; + } + int x2_offset = shape3.size() - shape2.size(); + for (size_t i = 0; i < shape2.size(); i++) { + x2_shape_[i + x2_offset] = shape2[i]; + input2_num_ *= shape2[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { return; } + void InitSizeLists() override { + input_size_list_.push_back(input1_num_ * sizeof(T)); + input_size_list_.push_back(input2_num_ * sizeof(T)); + input_size_list_.push_back(output_num_ * sizeof(T)); + output_size_list_.push_back(input1_num_ * sizeof(T)); + output_size_list_.push_back(input2_num_ * sizeof(T)); + } + + private: + void GetOpType(const CNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + + static std::map kBroadcastTypeMap = { + {"MaximumGrad", BROADCAST_GRAD_TYPE_MAXIMUM}, + {"MinimumGrad", BROADCAST_GRAD_TYPE_MINIMUM}, + }; + + auto iter = kBroadcastTypeMap.find(kernel_name); + if (iter == kBroadcastTypeMap.end()) { + MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; + } else { + op_type_ = iter->second; + } + } + + bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { + if (lhs.size() != rhs.size()) { + return true; + } + for (size_t i = 0; i < lhs.size(); i++) { + if (lhs[i] != rhs[i]) { + return true; + } + } + return false; + } + + BroadcastGradOpType op_type_; + bool need_broadcast_; + int input1_num_; + int input2_num_; + int output_num_; + int x1_shape_[4] = {1, 1, 1, 1}; + int x2_shape_[4] = {1, 1, 1, 1}; + int dy_shape_[4] = {1, 1, 1, 1}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.cc new file mode 100644 index 0000000000..3103f30f52 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + EqualCount, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + EqualCountGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + EqualCount, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + EqualCountGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + EqualCount, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + EqualCountGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.h new file mode 100644 index 0000000000..eae7a893b7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.h @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_EQUALCOUNT_GPU_KERNEL_H +#define MINDSPORE_EQUALCOUNT_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class EqualCountGpuKernel : public GpuKernel { + public: + EqualCountGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} + ~EqualCountGpuKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input1 = GetDeviceAddress(inputs, 0); + T *input2 = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + int size = SizeToInt(input_size_ / sizeof(T)); + CalEqualCount(size, input1, input2, output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but equalcount needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but equalcount needs 1 output."; + return false; + } + + output_size_ = sizeof(T); + input_size_ = sizeof(T); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + return; + } + + private: + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc new file mode 100644 index 0000000000..313669a647 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h new file mode 100644 index 0000000000..be74f2e9dc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh" + +namespace mindspore { +namespace kernel { +enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 }; +static const std::map kOpTypeMap = { + {"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}}; +template +class FloatStatusGpuKernel : public GpuKernel { + public: + FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {} + ~FloatStatusGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + + switch (kernel_name_) { + case OP_STATUS: { + T *output = GetDeviceAddress(outputs, 0); + CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_INF: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_NAN: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_FINITE: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + default: { + MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported."; + } + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (size_t x : shape) { + input_size_ = input_size_ * x; + } + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kOpTypeMap.find(kernel_name); + if (iter == kOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported."; + } else { + kernel_name_ = iter->second; + } + if (kernel_name_ == OP_STATUS) { + output_size_ = sizeof(T); + } else { + output_size_ = input_size_ / sizeof(T) * sizeof(bool); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + return true; + } + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + Optype kernel_name_; + size_t input_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.cc new file mode 100644 index 0000000000..471c394598 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + MatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MatMulGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + MatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + MatMulGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + BatchMatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MatMulGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + BatchMatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + MatMulGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h new file mode 100644 index 0000000000..7888d442c9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h @@ -0,0 +1,155 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MATMUL_GPU_KERNEL_H +#define MINDSPORE_MATMUL_GPU_KERNEL_H + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +template +class MatMulGpuKernel : public GpuKernel { + public: + MatMulGpuKernel() + : batch_(0), + m_(0), + n_(0), + k_(0), + is_null_input_(false), + transpose_x1_(CUBLAS_OP_N), + transpose_x2_(CUBLAS_OP_N), + handle_(nullptr), + dtype_a_(CUDA_R_32F), + dtype_b_(CUDA_R_32F), + dtype_c_(CUDA_R_32F), + algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {} + ~MatMulGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + VARIABLE_NOT_USED(stream_ptr); + if (is_null_input_) { + return true; + } + auto input1_addr = GetDeviceAddress(inputs, 0); + auto input2_addr = GetDeviceAddress(inputs, 1); + auto output_addr = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + const int lda = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_); + const int ldb = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_); + const int ldc = n_; + + auto stride_a = SizeToInt(m_ * k_); + auto stride_b = SizeToInt(k_ * n_); + auto stride_c = SizeToInt(m_ * n_); + + try { + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), + &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a, + &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), + "cublasSgemm Call Fail"); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx"; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); + dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(output_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "input is null"; + InitSizeLists(); + return true; + } + auto dims = output_shape.size(); + if (dims < 2) { + MS_LOG(EXCEPTION) << "Output dims " << dims << " not support."; + } + + m_ = output_shape[dims - 2]; + n_ = output_shape[dims - 1]; + batch_ = 1; + for (size_t i = 0; i < dims - 2; i++) { + batch_ *= output_shape[i]; + } + + bool transpose = GetAttr(kernel_node, "transpose_x1"); + transpose_x1_ = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + k_ = transpose ? input1_shape[dims - 2] : input1_shape[dims - 1]; + + transpose = GetAttr(kernel_node, "transpose_x2"); + transpose_x2_ = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t unit_size = sizeof(T); + + size_t input_size = batch_ * m_ * k_ * unit_size; + input_size_list_.push_back(input_size); + + input_size = batch_ * n_ * k_ * unit_size; + input_size_list_.push_back(input_size); + + size_t output_size = batch_ * m_ * n_ * unit_size; + output_size_list_.push_back(output_size); + } + + private: + size_t batch_; + size_t m_; + size_t n_; + size_t k_; + bool is_null_input_; + + cublasOperation_t transpose_x1_; + cublasOperation_t transpose_x2_; + cublasHandle_t handle_; + cudaDataType_t dtype_a_; + cudaDataType_t dtype_b_; + cudaDataType_t dtype_c_; + cublasGemmAlgo_t algo_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc new file mode 100644 index 0000000000..c72c271c52 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + RandomOpGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h new file mode 100644 index 0000000000..785ac02ee5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh" + +namespace mindspore { +namespace kernel { +enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 }; + +const std::map kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}}; +template +class RandomOpGpuKernel : public GpuKernel { + public: + RandomOpGpuKernel() + : random_op_type_(RANDOM_OP_INVALID_TYPE), + input_size_0_(0), + output_size_(sizeof(T)), + workspace_size_(sizeof(curandState)) {} + ~RandomOpGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + void *workspace_addr = GetDeviceAddress(workspace, 0); + curandState *devStates = reinterpret_cast(workspace_addr); + T *output_addr = GetDeviceAddress(outputs, 0); + + switch (random_op_type_) { + case RANDOM_OP_NORMAL: { + StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T), + reinterpret_cast(stream_ptr)); + break; + } + default: { + MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kRandomOpTypeMap.find(kernel_name); + if (iter == kRandomOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "Random operation " << kernel_name << " is not supported."; + } else { + random_op_type_ = iter->second; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output."; + return false; + } + auto input_shape_0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape_0.size(); i++) { + input_size_0_ += input_shape_0[i]; + } + input_size_0_ *= sizeof(int); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + workspace_size_ *= output_shape[i]; + } + seed_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); + seed2_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_0_); + output_size_list_.push_back(output_size_); + workspace_size_list_.push_back(workspace_size_); + } + + private: + RandomOptype random_op_type_; + size_t input_size_0_; + size_t output_size_; + size_t workspace_size_; + int seed_; + int seed2_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc new file mode 100644 index 0000000000..ae8e7bbd0b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h new file mode 100644 index 0000000000..26993bc3bd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h @@ -0,0 +1,161 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh" + +namespace mindspore { +namespace kernel { +enum UnaryOptype { + UNARY_OP_EXP = 0, + UNARY_OP_LOG, + UNARY_OP_NEG, + UNARY_OP_RECIPROCAL, + UNARY_OP_ZEROSLIKE, + UNARY_OP_SQUARE, + UNARY_OP_SQRT, + UNARY_OP_RSQRT, + UNARY_OP_INVALID_TYPE = 255 +}; +static const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, + {"Log", UNARY_OP_LOG}, + {"Neg", UNARY_OP_NEG}, + {"Reciprocal", UNARY_OP_RECIPROCAL}, + {"ZerosLike", UNARY_OP_ZEROSLIKE}, + {"Square", UNARY_OP_SQUARE}, + {"Sqrt", UNARY_OP_SQRT}, + {"Rsqrt", UNARY_OP_RSQRT}}; +template +class UnaryOpGpuKernel : public GpuKernel { + public: + UnaryOpGpuKernel() + : unary_op_type_(UNARY_OP_INVALID_TYPE), + input_size_(sizeof(T)), + output_size_(sizeof(T)), + workspace_size_(0), + is_null_input_(false) {} + ~UnaryOpGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + + switch (unary_op_type_) { + case UNARY_OP_EXP: { + Exponential(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_LOG: { + Logarithm(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_NEG: { + Negative(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_RECIPROCAL: { + Reciprocal(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_SQUARE: { + Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_SQRT: { + Sqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_RSQRT: { + Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_ZEROSLIKE: { + Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); + return true; + } + default: { + MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported."; + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kUnaryOpTypeMap.find(kernel_name); + if (iter == kUnaryOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "Unary operation " << kernel_name << " is not supported."; + } else { + unary_op_type_ = iter->second; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "UnaryOpGpuKernel input is null"; + InitSizeLists(); + return true; + } + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + output_size_ = input_size_; + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + UnaryOptype unary_op_type_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc new file mode 100644 index 0000000000..c6e3c4c043 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + NcclGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + NcclGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + NcclGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + NcclGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + NcclGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + NcclGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h new file mode 100644 index 0000000000..9701738bfc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h @@ -0,0 +1,188 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "runtime/device/gpu/distribution/collective_init.h" + +namespace mindspore { +namespace kernel { +enum NcclKernelType { NCCL_ALL_REDUCE = 0, NCCL_ALL_GATHER, NCCL_REDUCE_SCATTER, NCCL_INVALID_TYPE = 255 }; +const std::map kNcclTypeMap = { + {"AllReduce", NCCL_ALL_REDUCE}, + {"AllGather", NCCL_ALL_GATHER}, + {"ReduceScatter", NCCL_REDUCE_SCATTER}, +}; + +static std::map kNcclDtypeMap = { + {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}}; + +typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t, + const std::string &); +typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t, const std::string &); +typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t, + const std::string &); + +template +class NcclGpuKernel : public GpuKernel { + public: + NcclGpuKernel() + : nccl_kernel_type_(NCCL_INVALID_TYPE), + nccl_reduce_type_(ncclSum), + group_name_(""), + input_size_(0), + output_size_(0), + collective_handle_(nullptr), + comm_stream_(nullptr) {} + ~NcclGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + + cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); + switch (nccl_kernel_type_) { + case NCCL_ALL_REDUCE: { + auto all_reduce_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); + MS_EXCEPTION_IF_NULL(all_reduce_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), + nccl_data_type_, nccl_reduce_type_, stream, group_name_), + "ncclAllReduce failed"); + break; + } + case NCCL_ALL_GATHER: { + auto all_gather_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); + MS_EXCEPTION_IF_NULL(all_gather_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT( + (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_), + "ncclAllGather failed"); + break; + } + case NCCL_REDUCE_SCATTER: { + auto reduce_scatter_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "ReduceScatter")); + MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), + nccl_data_type_, nccl_reduce_type_, stream, group_name_), + "ncclReduceScatter failed"); + break; + } + default: { + MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported."; + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + nccl_data_type_ = kNcclDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; ++i) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); + size_t size = sizeof(T); + for (size_t j = 0; j < shape.size(); j++) { + size *= IntToSize(shape[j]); + } + input_size_list_.push_back(size); + input_size_ += size; + } + for (size_t i = 0; i < output_num; ++i) { + auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); + size_t size = sizeof(T); + for (size_t j = 0; j < shape.size(); j++) { + size *= IntToSize(shape[j]); + } + output_size_list_.push_back(size); + output_size_ += size; + } + + InferCommType(kernel_node); + group_name_ = GetAttr(kernel_node, kAttrGroup); + MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_; + auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); + if (comm_stream_attr) { + comm_stream_ = reinterpret_cast(GetValue(comm_stream_attr)); + MS_EXCEPTION_IF_NULL(comm_stream_); + } + + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + return true; + } + + protected: + void InitSizeLists() override { return; } + + private: + void InferCommType(const CNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kNcclTypeMap.find(kernel_name); + if (iter == kNcclTypeMap.end()) { + MS_LOG(EXCEPTION) << "Kernel " << kernel_name << " is not supported."; + } else { + nccl_kernel_type_ = iter->second; + } + + auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrOp); + if (reduce_op) { + std::string type = GetValue(reduce_op); + if (type == "sum") { + nccl_reduce_type_ = ncclSum; + } else if (type == "max") { + nccl_reduce_type_ = ncclMax; + } else if (type == "min") { + nccl_reduce_type_ = ncclMin; + } else if (type == "prod") { + nccl_reduce_type_ = ncclProd; + } else { + MS_LOG(EXCEPTION) << "Nccl reduce type " << type << " is not supported."; + } + } + return; + } + + NcclKernelType nccl_kernel_type_; + ncclRedOp_t nccl_reduce_type_; + ncclDataType_t nccl_data_type_; + std::string group_name_; + size_t input_size_; + size_t output_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + const void *collective_handle_; + cudaStream_t comm_stream_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc new file mode 100644 index 0000000000..334550b213 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGpuFwdKernel, half) + +MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGpuFwdKernel, half) + +MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h new file mode 100644 index 0000000000..d651da75e0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h @@ -0,0 +1,142 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class ActivationGpuFwdKernel : public GpuKernel { + public: + ActivationGpuFwdKernel() + : cudnn_handle_(nullptr), + activation_desc_(nullptr), + mode_(CUDNN_ACTIVATION_RELU), + data_descriptor_(nullptr), + is_null_input_(false), + cudnn_data_type_(CUDNN_DATA_FLOAT), + input_size_(0), + output_size_(0), + workspace_size_(0) {} + ~ActivationGpuFwdKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *) override { + if (is_null_input_) { + return true; + } + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, + &beta, data_descriptor_, output), + "cudnnActivationForward failed"); + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kernel_map.find(node_name); + if (iter == kernel_map.end()) { + MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; + } + mode_ = iter->second; + + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null."; + InitSizeLists(); + return true; + } + std::vector shape; + ShapeNdTo4d(input_shape, &shape); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0), + "cudnnSetActivationDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + shape[0], shape[1], shape[2], shape[3]), + "cudnnSetTensor4dDescriptor failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_), + "cudnnCreateActivationDescriptor failed"); + } + + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + output_size_ = input_size_; + } + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_), + "cudnnDestroyActivationDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + + std::map kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU}, + {"Tanh", CUDNN_ACTIVATION_TANH}, + {"ELU", CUDNN_ACTIVATION_ELU}, + {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}}; + + cudnnHandle_t cudnn_handle_; + cudnnActivationDescriptor_t activation_desc_; + cudnnActivationMode_t mode_; + cudnnTensorDescriptor_t data_descriptor_; + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + cudnnDataType_t cudnn_data_type_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc new file mode 100644 index 0000000000..8fd486c08c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/activation_grad_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + ReluGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + ReluGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGradGpuKernel, half) + +MS_REG_GPU_KERNEL_ONE( + TanhGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + TanhGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGradGpuKernel, half) + +MS_REG_GPU_KERNEL_ONE( + SigmoidGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + SigmoidGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGradGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h new file mode 100644 index 0000000000..ffdb618098 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h @@ -0,0 +1,146 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class ActivationGradGpuKernel : public GpuKernel { + public: + ActivationGradGpuKernel() + : cudnn_handle_(nullptr), + activation_desc_(nullptr), + mode_(CUDNN_ACTIVATION_RELU), + data_descriptor_(nullptr), + is_null_input_(false), + cudnn_data_type_(CUDNN_DATA_FLOAT), + input_size_(0) {} + ~ActivationGradGpuKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *) override { + if (is_null_input_) { + return true; + } + T *dy = nullptr; + T *y = nullptr; + if (mode_ == CUDNN_ACTIVATION_RELU || mode_ == CUDNN_ACTIVATION_ELU) { + dy = GetDeviceAddress(inputs, 0); + y = GetDeviceAddress(inputs, 1); + } else { + y = GetDeviceAddress(inputs, 0); + dy = GetDeviceAddress(inputs, 1); + } + T *dx = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy, + data_descriptor_, y, &beta, data_descriptor_, dx), + "cudnnActivationBackward failed"); + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kernel_map.find(node_name); + if (iter == kernel_map.end()) { + MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; + } + mode_ = iter->second; + + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGradGpuKernel needs 2."; + return false; + } + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "ActivationGradGpuKernel input is null."; + InitSizeLists(); + return true; + } + std::vector shape; + ShapeNdTo4d(input_shape, &shape); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0), + "SetActivationDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + shape[0], shape[1], shape[2], shape[3]), + "SetTensor4dDescriptor failed"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_), + "cudnnCreateActivationDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_), + "cudnnDestroyActivationDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + + std::map kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU}, + {"TanhGrad", CUDNN_ACTIVATION_TANH}, + {"ELUGrad", CUDNN_ACTIVATION_ELU}, + {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}}; + cudnnHandle_t cudnn_handle_; + cudnnActivationDescriptor_t activation_desc_; + cudnnActivationMode_t mode_; + cudnnTensorDescriptor_t data_descriptor_; + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + cudnnDataType_t cudnn_data_type_; + size_t input_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.cc new file mode 100644 index 0000000000..0f89eb4419 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Adam, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + AdamGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Adam, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + AdamGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h new file mode 100644 index 0000000000..e2fc87ed51 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h @@ -0,0 +1,142 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh" +namespace mindspore { +namespace kernel { +template +class AdamGpuKernel : public GpuKernel { + public: + AdamGpuKernel() + : variable_size_(0), + m_size_(0), + v_size_(0), + beta1_power_size_(0), + beta2_power_size_(0), + learning_rate_size_(0), + beta1_size_(0), + beta2_size_(0), + epsilon_size_(0), + gradient_size_(0) {} + + ~AdamGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, + void *stream_ptr) override { + T *variable = GetDeviceAddress(inputs, 0); + T *m = GetDeviceAddress(inputs, 1); + T *v = GetDeviceAddress(inputs, 2); + T *beta1_power = GetDeviceAddress(inputs, 3); + T *beta2_power = GetDeviceAddress(inputs, 4); + T *learning_rate = GetDeviceAddress(inputs, 5); + T *beta1 = GetDeviceAddress(inputs, 6); + T *beta2 = GetDeviceAddress(inputs, 7); + T *epsilon = GetDeviceAddress(inputs, 8); + T *gradient = GetDeviceAddress(inputs, 9); + ApplyAdam(inputs[0]->size / sizeof(T), gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, + variable, m, v, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 10) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 10 inputs."; + return false; + } + + variable_size_ = sizeof(T); + m_size_ = sizeof(T); + v_size_ = sizeof(T); + beta1_power_size_ = sizeof(T); + beta2_power_size_ = sizeof(T); + learning_rate_size_ = sizeof(T); + beta1_size_ = sizeof(T); + beta2_size_ = sizeof(T); + epsilon_size_ = sizeof(T); + gradient_size_ = sizeof(T); + + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < variable_shape.size(); i++) { + variable_size_ *= variable_shape[i]; + } + + auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < m_shape.size(); i++) { + m_size_ *= m_shape[i]; + } + + auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + for (size_t i = 0; i < v_shape.size(); i++) { + v_size_ *= v_shape[i]; + } + + auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); + for (size_t i = 0; i < gradient_shape.size(); i++) { + gradient_size_ *= gradient_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(variable_size_); + input_size_list_.push_back(m_size_); + input_size_list_.push_back(v_size_); + input_size_list_.push_back(beta1_power_size_); + input_size_list_.push_back(beta2_power_size_); + input_size_list_.push_back(learning_rate_size_); + input_size_list_.push_back(beta1_size_); + input_size_list_.push_back(beta2_size_); + input_size_list_.push_back(epsilon_size_); + input_size_list_.push_back(gradient_size_); + output_size_list_.push_back(0); + output_size_list_.push_back(0); + output_size_list_.push_back(0); + } + + private: + size_t variable_size_; + size_t m_size_; + size_t v_size_; + size_t beta1_power_size_; + size_t beta2_power_size_; + size_t learning_rate_size_; + size_t beta1_size_; + size_t beta2_size_; + size_t epsilon_size_; + size_t gradient_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.cc new file mode 100644 index 0000000000..6131aa8568 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BiasAddGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BiasAddGradGpuKernel, float16) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h new file mode 100644 index 0000000000..3e15b818be --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h @@ -0,0 +1,158 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ + +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class BiasAddGradGpuKernel : public GpuKernel { + public: + BiasAddGradGpuKernel() + : same_dims_(true), + cudnn_handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + dy_desc_(nullptr), + db_desc_(nullptr), + op_desc_(nullptr) {} + ~BiasAddGradGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *dy_addr = GetDeviceAddress(inputs, 0); + T *db_addr = GetDeviceAddress(outputs, 0); + T *indices_addr = GetDeviceAddress(workspace, 0); + T *workspace_addr = GetDeviceAddress(workspace, 1); + + const float alpha = 1; + const float beta = 0; + if (same_dims_) { + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(db_addr, dy_addr, output_size_list_[0], cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync failed."); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnReduceTensor(cudnn_handle_, op_desc_, indices_addr, workspace_size_list_[0], workspace_addr, + workspace_size_list_[1], &alpha, dy_desc_, dy_addr, &beta, db_desc_, db_addr), + "cudnnReduceTensor failed"); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto num_dims = dy_shape.size(); + if (num_dims < 2) { + MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims; + } + + std::string format = GetAttr(kernel_node, "data_format"); + string::size_type pos = format.find("C"); + if (pos == std::string::npos || pos >= num_dims) { + MS_LOG(EXCEPTION) << "format '" << format << "' invalid"; + } + + // Expand to 4 dims for cudnnSetTensorNdDescriptorEx. + auto cudnn_dims = std::max(num_dims, 4UL); + std::unique_ptr dy_dims = std::make_unique(cudnn_dims); + std::unique_ptr db_dims = std::make_unique(cudnn_dims); + for (size_t i = 0; i < cudnn_dims; i++) { + dy_dims[i] = (i < num_dims) ? SizeToInt(dy_shape[i]) : 1; + db_dims[i] = (i == pos) ? SizeToInt(dy_shape[i]) : 1; + + if (dy_dims[i] != db_dims[i]) { + same_dims_ = false; + } + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()), + "cudnnSetTensorNdDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), + "cudnnSetTensorNdDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, + CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES), + "cudnnSetReduceTensorDescriptor failed"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&db_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&op_desc_), "cudnnCreateOpTensorDescriptor failed"); + } + void InitSizeLists() override { + size_t dy_size, db_size; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, &dy_size), "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(db_desc_, &db_size), "cudnnGetTensorSizeInBytes failed"); + input_size_list_.push_back(dy_size); + output_size_list_.push_back(db_size); + + size_t indices_size, workspace_size; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetReductionIndicesSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &indices_size), + "cudnnGetReductionIndicesSize failed") + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetReductionWorkspaceSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &workspace_size), + "cudnnGetReductionWorkspaceSize failed") + workspace_size_list_.push_back(indices_size); + workspace_size_list_.push_back(workspace_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDestroyReduceTensorDescriptor(op_desc_), + "cudnnDestroyReduceTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(db_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyOpTensorDescriptor failed"); + } + + bool same_dims_; + cudnnHandle_t cudnn_handle_; + cudnnDataType_t cudnn_data_type_; + cudnnTensorDescriptor_t dy_desc_; + cudnnTensorDescriptor_t db_desc_; + cudnnReduceTensorDescriptor_t op_desc_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.cc new file mode 100644 index 0000000000..f9bb710b94 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Conv2D, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Conv2dGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + Conv2D, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + Conv2dGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h new file mode 100644 index 0000000000..6072614e22 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h @@ -0,0 +1,320 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class Conv2dGpuFwdKernel : public GpuKernel { + public: + Conv2dGpuFwdKernel() + : cudnn_handle_(nullptr), + input_desc_(nullptr), + output_desc_(nullptr), + filter_desc_(nullptr), + conv_desc_(nullptr), + padded_desc_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + group_(1), + is_null_input_(false), + input_size_(0), + filter_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~Conv2dGpuFwdKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *filter_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + T *workspace_addr = nullptr; + if (workspace_size_ != 0) { + workspace_addr = GetDeviceAddress(workspace, 0); + } + + const float alpha = 1; + const float beta = 0; + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded_addr = GetDeviceAddress(workspace, 1); + CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionForward(cudnn_handle_, &alpha, padded_desc_, padded_addr, filter_desc_, filter_addr, conv_desc_, + conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), + "cudnnConvolutionForward failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionForward(cudnn_handle_, &alpha, input_desc_, input_addr, filter_desc_, filter_addr, conv_desc_, + conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), + "cudnnConvolutionForward failed"); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(in_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "Conv2dGpuFwdKernel input is null."; + InitSizeLists(); + return true; + } + Set4DDesc(in_shape, filter_shape, output_shape); + group_ = GetAttr(kernel_node, "group"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); + pad_height_ = GetAttr(kernel_node, "pad"); + pad_width_ = pad_height_; + pad_mode_ = GetAttr(kernel_node, "pad_mode"); + SetStrideAndDilation(kernel_node); + cudnnTensorDescriptor_t input_descriptor_real = nullptr; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(in_shape, kernel_node); + input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_; + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2], + dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + input_descriptor_real = input_desc_; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + SelectAlgorithm(input_descriptor_real); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&filter_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast(&input_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(filter_desc_, reinterpret_cast(&filter_size_)), + "cudnnGetFilterSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast(&output_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast(&padded_size_)), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + input_size_list_.push_back(filter_size_); + output_size_list_.push_back(output_size_); + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, padded_desc_, filter_desc_, conv_desc_, output_desc_, + conv_algorithm_, &workspace_size_), + "cudnnGetConvolutionForwardWorkspaceSize failed"); + workspace_size_list_.push_back(padded_size_); + } else { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, input_desc_, filter_desc_, conv_desc_, output_desc_, + conv_algorithm_, &workspace_size_), + "cudnnGetConvolutionForwardWorkspaceSize failed"); + } + } + (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); + + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); + } + bool CheckParam(const CNodePtr &kernel_node) { + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but conv2d needs 1 output."; + return false; + } + return true; + } + void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { + auto pad_list = GetAttr>(kernel_node, "pad_list"); + + n_ = SizeToInt(in_shape[0]); + c_ = SizeToInt(in_shape[1]); + old_height_ = SizeToInt(in_shape[2]); + old_width_ = SizeToInt(in_shape[3]); + pad_height_ = pad_list[0] + pad_list[1]; + pad_width_ = pad_list[2] + pad_list[3]; + pad_top_ = pad_list[0]; + pad_left_ = pad_list[2]; + + // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_, + old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( + conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3], + dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + + void Set4DDesc(const std::vector &in_shape, const std::vector &filter_shape, + const std::vector &output_shape) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), + SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(filter_shape[0]), + SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), + "cudnnSetFilter4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), + SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + } + void SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) { + if (group_ > 1 || CUDNN_MAJOR < 7) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetConvolutionForwardAlgorithm( + cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, output_desc_, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 0, &conv_algorithm_), + "cudnnGetConvolutionForwardAlgorithm failed"); + } else { + constexpr int requested_algo_count = 1; + int returned_algo_count; + cudnnConvolutionFwdAlgoPerf_t perf_results; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, + output_desc_, requested_algo_count, &returned_algo_count, &perf_results), + "cudnnGetConvolutionForwardAlgorithm_v7 failed"); + conv_algorithm_ = perf_results.algo; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + } + } + void SetStrideAndDilation(const CNodePtr &kernel_node) { + stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); + dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); + if (stride_.size() != 4) { + MS_LOG(EXCEPTION) << "Conv2d's' stride must be 4d!"; + } + if (stride_[0] != 1 || stride_[1] != 1) { + MS_LOG(EXCEPTION) << "Conv2d stride only support 1 in N axis and C axis!"; + } + if (dilation_.size() != 4) { + MS_LOG(EXCEPTION) << "Conv2d's dilation must be 4d!"; + } + if (dilation_[0] != 1 || dilation_[1] != 1) { + MS_LOG(EXCEPTION) << "Conv2d dilation only support 1 in N axis and C axis!"; + } + } + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_desc_; + cudnnTensorDescriptor_t output_desc_; + cudnnFilterDescriptor_t filter_desc_; + cudnnConvolutionFwdAlgo_t conv_algorithm_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnTensorDescriptor_t padded_desc_; + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + const float pad_value_ = 0.0; + cudnnDataType_t cudnn_data_type_; + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + std::vector stride_; + std::vector dilation_; + int group_; + bool is_null_input_; + size_t input_size_; + size_t filter_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.cc new file mode 100644 index 0000000000..ca16e1a18c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Conv2DBackpropFilter, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ConvGradFilterGpuBkwKernel, float) +MS_REG_GPU_KERNEL_ONE( + Conv2DBackpropFilter, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ConvGradFilterGpuBkwKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h new file mode 100644 index 0000000000..638da4a99f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h @@ -0,0 +1,320 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class ConvGradFilterGpuBkwKernel : public GpuKernel { + public: + ConvGradFilterGpuBkwKernel() + : cudnn_handle_(nullptr), + dw_desc_(nullptr), + conv_desc_(nullptr), + dy_desc_(nullptr), + x_desc_(nullptr), + padded_descriptor_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + group_(1), + is_null_input_(false), + input_size_(0), + dy_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~ConvGradFilterGpuBkwKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *dy = GetDeviceAddress(inputs, 0); + T *x = GetDeviceAddress(inputs, 1); + T *dw = GetDeviceAddress(outputs, 0); + T *work_space = nullptr; + if (workspace_size_ != 0) { + work_space = GetDeviceAddress(workspace, 0); + } + + const float alpha = 1; + const float beta = 0; + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded = GetDeviceAddress(workspace, 1); + CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, + reinterpret_cast(stream_ptr)); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_, + algo_, work_space, workspace_size_, &beta, dw_desc_, dw), + "ConvolutionBackwardFilter failed"); + return true; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, x_desc_, x, dy_desc_, dy, conv_desc_, algo_, work_space, + workspace_size_, &beta, dw_desc_, dw), + "ConvolutionBackwardFilter failed"); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "ConvGradFilterGpuBkwKernel input is null."; + InitSizeLists(); + return true; + } + std::vector filter_shape; + GetFilterShape(kernel_node, &filter_shape); + Set4DDesc(dy_shape, filter_shape, in_shape); + group_ = GetAttr(kernel_node, "group"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); + + pad_height_ = GetAttr(kernel_node, "pad"); + pad_width_ = pad_height_; + pad_mode_ = GetAttr(kernel_node, "pad_mode"); + SetStrideAndDilation(kernel_node); + cudnnTensorDescriptor_t x_desc_real = nullptr; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(in_shape, kernel_node); + x_desc_real = use_pad_ ? padded_descriptor_ : x_desc_; + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], + dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "GetConvolution2dDescriptor failed"); + x_desc_real = x_desc_; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + SelectAlgorithm(x_desc_real); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&dw_desc_), "cudnnCreateFilterDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, reinterpret_cast(&dy_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, reinterpret_cast(&input_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(dw_desc_, reinterpret_cast(&output_size_)), + "cudnnGetFilterSizeInBytes failed"); + } + input_size_list_.push_back(dy_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast(&padded_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, padded_descriptor_, dy_desc_, conv_desc_, + dw_desc_, algo_, reinterpret_cast(&workspace_size_)), + "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); + workspace_size_list_.push_back(padded_size_); + } else { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, x_desc_, dy_desc_, conv_desc_, dw_desc_, algo_, + reinterpret_cast(&workspace_size_)), + "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); + } + } + (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "cudnnDestroyFilterDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyTensorDescriptor failed"); + } + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ConvGradFilter needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ConvGradFilter needs 1 output."; + return false; + } + return true; + } + void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { + auto pad_list = GetAttr>(kernel_node, "pad_list"); + n_ = SizeToInt(in_shape[0]); + c_ = SizeToInt(in_shape[1]); + old_height_ = SizeToInt(in_shape[2]); + old_width_ = SizeToInt(in_shape[3]); + pad_height_ = pad_list[0] + pad_list[1]; + pad_width_ = pad_list[2] + pad_list[3]; + pad_top_ = pad_list[0]; + pad_left_ = pad_list[2]; + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, + c_, old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( + conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1], + dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { + if (group_ > 1 || CUDNN_MAJOR < 7) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, 0, &algo_), + "GetConvolutionBackwardFilterAlgorithm failed"); + } else { + constexpr int requested_algo_count = 1; + int returned_algo_count; + cudnnConvolutionBwdFilterAlgoPerf_t perf_results; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, + requested_algo_count, &returned_algo_count, &perf_results), + "GetConvolutionBackwardFilterAlgorithm failed"); + algo_ = perf_results.algo; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } + } + void GetFilterShape(const CNodePtr &kernel_node, std::vector *filter_shape) { + auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast()->value(); + (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape), + [](const ValuePtr &e) -> int { return e->cast()->value(); }); + } + void Set4DDesc(const std::vector &dy_shape, const std::vector &filter_shape, + const std::vector &in_shape) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), + SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), + "SetTensor4dDescriptor failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetFilter4dDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), filter_shape[1], + filter_shape[2], filter_shape[3]), + "SetFilter4dDescriptor failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), + SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), + "SetTensor4dDescriptor failed"); + } + void SetStrideAndDilation(const CNodePtr &kernel_node) { + stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); + dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); + if (stride_.size() != 2) { + MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel's stride must be 2d!"; + } + if (dilation_.size() != 4) { + MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel's dilation must be 4d!"; + } + if (dilation_[0] != 1 || dilation_[1] != 1) { + MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel dilation only support 1 in N axis and C axis!"; + } + } + cudnnHandle_t cudnn_handle_; + cudnnFilterDescriptor_t dw_desc_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnTensorDescriptor_t dy_desc_; + cudnnTensorDescriptor_t x_desc_; + cudnnTensorDescriptor_t padded_descriptor_; + cudnnConvolutionBwdFilterAlgo_t algo_; + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + const float pad_value_ = 0.0; + cudnnDataType_t cudnn_data_type_; + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + std::vector stride_; + std::vector dilation_; + int group_; + bool is_null_input_; + size_t input_size_; + size_t dy_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.cc new file mode 100644 index 0000000000..d8441fb67c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Conv2DBackpropInput, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ConvGradInputGpuBkwKernel, float) +MS_REG_GPU_KERNEL_ONE( + Conv2DBackpropInput, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ConvGradInputGpuBkwKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h new file mode 100644 index 0000000000..a9a1e5c0cc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h @@ -0,0 +1,315 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class ConvGradInputGpuBkwKernel : public GpuKernel { + public: + ConvGradInputGpuBkwKernel() + : cudnn_handle_(nullptr), + w_desc_(nullptr), + conv_desc_(nullptr), + dy_desc_(nullptr), + dx_desc_(nullptr), + padded_descriptor_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + group_(1), + is_null_input_(false), + dy_size_(0), + w_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~ConvGradInputGpuBkwKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *dy = GetDeviceAddress(inputs, 0); + T *w = GetDeviceAddress(inputs, 1); + T *dx = GetDeviceAddress(outputs, 0); + T *work_space = nullptr; + if (workspace_size_ != 0) { + work_space = GetDeviceAddress(workspace, 0); + } + + const float alpha = 1; + const float beta = 0; + + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded = GetDeviceAddress(workspace, 1); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, + workspace_size_, &beta, padded_descriptor_, padded), + "ConvolutionBackwardData failed"); + CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, + workspace_size_, &beta, dx_desc_, dx), + "ConvolutionBackwardData failed"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + is_null_input_ = CHECK_NULL_INPUT(dy_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "ConvGradInputGpuBkwKernel input is null."; + InitSizeLists(); + return true; + } + std::vector input_shape; + GetInputShape(kernel_node, &input_shape); + Set4DDesc(dy_shape, input_shape, filter_shape); + + group_ = GetAttr(kernel_node, "group"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); + + pad_height_ = GetAttr(kernel_node, "pad"); + pad_width_ = pad_height_; + pad_mode_ = GetAttr(kernel_node, "pad_mode"); + SetStrideAndDilation(kernel_node); + cudnnTensorDescriptor_t dx_desc_real = nullptr; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(input_shape, kernel_node); + dx_desc_real = use_pad_ ? padded_descriptor_ : dx_desc_; + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], + dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + dx_desc_real = dx_desc_; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + SelectAlgorithm(dx_desc_real); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "cudnnCreateFilterDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, &dy_size_), "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(w_desc_, &w_size_), "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dx_desc_, &output_size_), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(dy_size_); + input_size_list_.push_back(w_size_); + output_size_list_.push_back(output_size_); + + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), + "cudnnGetTensorSizeInBytes failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, padded_descriptor_, + algo_, &workspace_size_), + "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); + workspace_size_list_.push_back(padded_size_); + } else { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_, algo_, &workspace_size_), + "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); + } + } + (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "cudnnDestroyFilterDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "cudnnDestroyTensorDescriptor failed"); + } + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ConvGradInput needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ConvGradInput needs 1 output."; + return false; + } + return true; + } + void SetPad(const std::vector &input_shape, const CNodePtr &kernel_node) { + auto pad_list = GetAttr>(kernel_node, "pad_list"); + n_ = input_shape[0]; + c_ = input_shape[1]; + old_height_ = input_shape[2]; + old_width_ = input_shape[3]; + pad_height_ = pad_list[0] + pad_list[1]; + pad_width_ = pad_list[2] + pad_list[3]; + pad_top_ = pad_list[0]; + pad_left_ = pad_list[2]; + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, + c_, old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( + conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1], + dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { + if (group_ > 1 || CUDNN_MAJOR < 7) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, 0, &algo_), + "cudnnGetConvolutionBackwardDataAlgorithm failed"); + } else { + constexpr int requested_algo_count = 1; + int returned_algo_count; + cudnnConvolutionBwdDataAlgoPerf_t perf_results; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, + requested_algo_count, &returned_algo_count, &perf_results), + "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed"); + algo_ = perf_results.algo; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } + } + void GetInputShape(const CNodePtr &kernel_node, std::vector *input_shape) { + auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast()->value(); + (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape), + [](const ValuePtr &e) -> int { return e->cast()->value(); }); + } + void Set4DDesc(const std::vector &dy_shape, const std::vector &input_shape, + const std::vector &filter_shape) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetFilter4dDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), + SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), + "SetFilter4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), + SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), + "SetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, input_shape[0], input_shape[1], + input_shape[2], input_shape[3]), + "SetTensor4dDescriptor failed"); + } + void SetStrideAndDilation(const CNodePtr &kernel_node) { + stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); + dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); + if (stride_.size() != 2) { + MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's stride must be 2d!"; + } + if (dilation_.size() != 4) { + MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's dilation must be 4d!"; + } + if (dilation_[0] != 1 || dilation_[1] != 1) { + MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel dilation only support 1 in N axis and C axis!"; + } + } + cudnnHandle_t cudnn_handle_; + cudnnFilterDescriptor_t w_desc_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnTensorDescriptor_t dy_desc_; + cudnnTensorDescriptor_t dx_desc_; + cudnnTensorDescriptor_t padded_descriptor_; + cudnnConvolutionBwdDataAlgo_t algo_; + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + cudnnDataType_t cudnn_data_type_; + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + std::vector stride_; + std::vector dilation_; + int group_; + bool is_null_input_; + size_t dy_size_; + size_t w_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.cc new file mode 100644 index 0000000000..155451875c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(CTCLossV2, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CtcLossGpuKernel, float) + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h new file mode 100644 index 0000000000..8b02354516 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h @@ -0,0 +1,166 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" + +namespace mindspore { +namespace kernel { +template +class CtcLossGpuKernel : public GpuKernel { + public: + CtcLossGpuKernel() + : cudnn_handle_(nullptr), + probs_desc_(nullptr), + ctcloss_desc_(nullptr), + label_size_(0), + input_lengths_size_(0), + label_lengths_size_(0) {} + ~CtcLossGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + float *probs = GetDeviceAddress(inputs, 0); + int *labels = GetDeviceAddress(inputs, 1); + int *input_lengths = GetDeviceAddress(inputs, 2); + int *label_lengths = GetDeviceAddress(inputs, 3); + float *costs = GetDeviceAddress(outputs, 0); + float *grads = GetDeviceAddress(outputs, 1); + + // Copy labels/input_lengths/label_length to host as cudnn7.x.x requires + void *labels_host = nullptr; + void *input_lengths_host = nullptr; + void *label_lengths_host = nullptr; + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed."); + cudaStream_t stream = reinterpret_cast(stream_ptr); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(labels_host, labels, inputs[1]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(input_lengths_host, input_lengths, inputs[2]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(label_lengths_host, label_lengths, inputs[3]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast(labels_host), + reinterpret_cast(label_lengths_host), + reinterpret_cast(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, + ctcloss_desc_, &workspace_size), + "cudnnGetCTCLossWorkspaceSize failed."); + void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size); + if (workspace == nullptr) { + MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast(labels_host), + reinterpret_cast(label_lengths_host), reinterpret_cast(input_lengths_host), costs, + probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size), + "cudnnCtcLoss failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); + + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed."); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (probs_shape.size() != 3) { + MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support."; + } + probs_dims_[0] = probs_shape[0]; + probs_dims_[1] = probs_shape[1]; + probs_dims_[2] = probs_shape[2]; + + auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + if (labels_dims.size() != 1 && labels_dims.size() != 2) { + MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support."; + } + label_size_ = sizeof(int); + for (auto i : labels_dims) { + label_size_ *= i; + } + + auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + input_lengths_size_ = input_length_dims[0] * sizeof(int); + auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + label_lengths_size_ = label_length_dims[0] * sizeof(int); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_), + "cudnnSetTensorNdDescriptorEx failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT, + CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN), + "cudnnSetCTCLossDescriptorEx failed."); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed."); + } + + void InitSizeLists() override { + input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float)); + input_size_list_.push_back(label_size_); + input_size_list_.push_back(input_lengths_size_); + input_size_list_.push_back(label_lengths_size_); + + output_size_list_.push_back(probs_dims_[1] * sizeof(float)); + output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float)); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed."); + } + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t probs_desc_; + cudnnCTCLossDescriptor_t ctcloss_desc_; + int probs_dims_[3] = {0}; + int label_size_; + int input_lengths_size_; + int label_lengths_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.cc new file mode 100644 index 0000000000..423a230b6e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Dropout, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + DropoutGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + Dropout, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + DropoutGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h new file mode 100644 index 0000000000..2104d7af35 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh" +#include "include/curand.h" + +namespace mindspore { +namespace kernel { +template +class DropoutGpuFwdKernel : public GpuKernel { + public: + DropoutGpuFwdKernel() + : cudnn_handle_(nullptr), + is_null_input_(false), + num_count_(0), + keep_prob_(0.0), + states_init_(false), + mask_generator_(nullptr) {} + + ~DropoutGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + T *mask = GetDeviceAddress(outputs, 1); + float *mask_f = GetDeviceAddress(workspace, 0); + + if (!states_init_) { + curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL)); + states_init_ = true; + } + // curandGen only support float or double for mask. + curandGenerateUniform(mask_generator_, mask_f, num_count_); + DropoutForward(input, mask, output, mask_f, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); + + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1."; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + InitSizeLists(); + return true; + } + + num_count_ = 1; + for (size_t x : input_shape) { + num_count_ *= x; + } + keep_prob_ = GetAttr(kernel_node, "keep_prob"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + + void InitSizeLists() override { + size_t input_size = num_count_ * sizeof(T); + input_size_list_.push_back(input_size); + output_size_list_.push_back(input_size); // output size: the same with input size + output_size_list_.push_back(input_size); // mask size: the same with input size + workspace_size_list_.push_back(num_count_ * sizeof(float)); // temp mask_f for curandGen + } + + private: + cudnnHandle_t cudnn_handle_; + bool is_null_input_; + size_t num_count_; + float keep_prob_; + bool states_init_; + curandGenerator_t mask_generator_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.cc new file mode 100644 index 0000000000..faf884c2eb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + DropoutGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + DropoutGradGpuBwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + DropoutGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + DropoutGradGpuBwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h new file mode 100644 index 0000000000..a3a7250c9b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class DropoutGradGpuBwdKernel : public GpuKernel { + public: + DropoutGradGpuBwdKernel() : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {} + ~DropoutGradGpuBwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + + T *dy = GetDeviceAddress(inputs, 0); + T *mask = GetDeviceAddress(inputs, 1); + T *dx = GetDeviceAddress(outputs, 0); + + DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuBwdKernel needs 2."; + return false; + } + + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + InitSizeLists(); + return true; + } + + num_count_ = 1; + for (size_t x : input_shape) { + num_count_ *= x; + } + keep_prob_ = GetAttr(kernel_node, "keep_prob"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + void InitSizeLists() override { + size_t dy_size = num_count_ * sizeof(T); + size_t mask_size = dy_size; + size_t dx_size = dy_size; + + input_size_list_.push_back(dy_size); + input_size_list_.push_back(mask_size); + output_size_list_.push_back(dx_size); + } + + private: + cudnnHandle_t cudnn_handle_; + bool is_null_input_; + size_t num_count_; + float keep_prob_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.cc new file mode 100644 index 0000000000..d8206aedcd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FlattenGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + FlattenGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FlattenGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FlattenGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + FlattenGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FlattenGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FlattenGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FlattenGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + FlattenGpuFwdKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h new file mode 100644 index 0000000000..a140579a3c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class FlattenGpuFwdKernel : public GpuKernel { + public: + FlattenGpuFwdKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} + ~FlattenGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + cudaError_t ret = + cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)); + if (ret) { + MS_LOG(ERROR) << "cudaMemcpyAsync error in FlattenGpuFwdKernel::Launch, error code is " << ret; + return false; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (size_t i = 0; i < shape.size(); ++i) { + input_size_ *= shape[i]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_ = input_size_; + output_size_list_.push_back(output_size_); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.cc new file mode 100644 index 0000000000..c07126a2ed --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FlattenGardGpuBkwKernel, float) +MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FlattenGardGpuBkwKernel, half) +MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + FlattenGardGpuBkwKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h new file mode 100644 index 0000000000..b21327bc3b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class FlattenGardGpuBkwKernel : public GpuKernel { + public: + FlattenGardGpuBkwKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} + ~FlattenGardGpuBkwKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + cudaError_t ret = + cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)); + if (ret) { + MS_LOG(ERROR) << "cudaMemcpyAsync error in FlattenGardGpuFwdKernel::Launch, error code is " << ret; + return false; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but FlattenGardGpuFwdKernel needs 1."; + return false; + } + + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < shape.size(); ++i) { + if (input_size_ == 0) { + input_size_ = 1; + } + input_size_ *= shape[i]; + } + input_size_ = input_size_ * sizeof(T); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_ = input_size_; + output_size_list_.push_back(output_size_); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.cc new file mode 100644 index 0000000000..0186153745 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ApplyFtrl, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FtrlGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ApplyFtrl, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + FtrlGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h new file mode 100644 index 0000000000..ea08741dba --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh" +namespace mindspore { +namespace kernel { +template +class FtrlGpuKernel : public GpuKernel { + public: + FtrlGpuKernel() + : variable_size_(0), + accumulation_size_(0), + linear_size_(0), + gradient_size_(0), + learning_rate_size_(0), + l1_regularization_size_(0), + l2_regularization_size_(0), + learning_rate_power_size_(0) {} + + ~FtrlGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, + void *stream_ptr) override { + T *variable = GetDeviceAddress(inputs, 0); + T *accumulation = GetDeviceAddress(inputs, 1); + T *linear = GetDeviceAddress(inputs, 2); + T *gradient = GetDeviceAddress(inputs, 3); + T *learning_rate = GetDeviceAddress(inputs, 4); + T *l1_regularization = GetDeviceAddress(inputs, 5); + T *l2_regularization = GetDeviceAddress(inputs, 6); + T *learning_rate_power = GetDeviceAddress(inputs, 7); + ApplyFtrl(inputs[0]->size / sizeof(T), gradient, learning_rate, l1_regularization, l2_regularization, + learning_rate_power, variable, accumulation, linear, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 8) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 8 inputs."; + return false; + } + + variable_size_ = sizeof(T); + accumulation_size_ = sizeof(T); + linear_size_ = sizeof(T); + gradient_size_ = sizeof(T); + learning_rate_size_ = sizeof(T); + l1_regularization_size_ = sizeof(T); + l2_regularization_size_ = sizeof(T); + learning_rate_power_size_ = sizeof(T); + + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < variable_shape.size(); i++) { + variable_size_ *= variable_shape[i]; + } + + auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < accumulation_shape.size(); i++) { + accumulation_size_ *= accumulation_shape[i]; + } + + auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + for (size_t i = 0; i < linear_shape.size(); i++) { + linear_size_ *= linear_shape[i]; + } + + auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + for (size_t i = 0; i < gradient_shape.size(); i++) { + gradient_size_ *= gradient_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(variable_size_); + input_size_list_.push_back(accumulation_size_); + input_size_list_.push_back(linear_size_); + input_size_list_.push_back(gradient_size_); + input_size_list_.push_back(learning_rate_size_); + input_size_list_.push_back(l1_regularization_size_); + input_size_list_.push_back(l2_regularization_size_); + input_size_list_.push_back(learning_rate_power_size_); + output_size_list_.push_back(0); + } + + private: + size_t variable_size_; + size_t accumulation_size_; + size_t linear_size_; + size_t gradient_size_; + size_t learning_rate_size_; + size_t l1_regularization_size_; + size_t l2_regularization_size_; + size_t learning_rate_power_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.cc new file mode 100644 index 0000000000..5ef2fd8786 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedAdamWeightDecayGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FusedAdam, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedAdamWeightDecayGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h new file mode 100644 index 0000000000..c4fd31a737 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class FusedAdamWeightDecayGpuKernel : public GpuKernel { + public: + FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {} + ~FusedAdamWeightDecayGpuKernel() override = default; + + bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if (node_name == "AdamWeighDecay") { + weight_decay_ = true; + } + + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7); + element_nums_ = 1; + for (auto i : shape) { + element_nums_ *= i; + } + + InitSizeLists(); + return true; + } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + float *beta1 = GetDeviceAddress(inputs, 0); + float *one_sub_beta1 = GetDeviceAddress(inputs, 1); + float *beta2 = GetDeviceAddress(inputs, 2); + float *one_sub_beta2 = GetDeviceAddress(inputs, 3); + float *epsilon = GetDeviceAddress(inputs, 4); + float *lr = GetDeviceAddress(inputs, 5); + T *param = GetDeviceAddress(inputs, 6); + T *m = GetDeviceAddress(inputs, 7); + T *v = GetDeviceAddress(inputs, 8); + T *gradient = GetDeviceAddress(inputs, 9); + float *weight_decay = nullptr; + if (weight_decay_) { + weight_decay = GetDeviceAddress(inputs, 10); + } + AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, + param, gradient, reinterpret_cast(stream_ptr)); + return true; + } + + protected: + void InitResource() override{}; + void InitSizeLists() override { + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(element_nums_ * sizeof(T)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(element_nums_ * sizeof(T)); + if (weight_decay_) { + input_size_list_.push_back(sizeof(float)); + } + output_size_list_.push_back(element_nums_ * sizeof(T)); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int element_nums_; + bool weight_decay_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc new file mode 100644 index 0000000000..2ce39b63a0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedBatchNormGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + FusedBatchNormGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(BatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedBatchNormGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(BatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + FusedBatchNormGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h new file mode 100644 index 0000000000..774428dc40 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h @@ -0,0 +1,190 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class FusedBatchNormGpuKernel : public GpuKernel { + public: + FusedBatchNormGpuKernel() + : batch_(0), + channel_(0), + height_(0), + width_(0), + mode_(CUDNN_BATCHNORM_SPATIAL), + epsilon_(10e-5), + exp_avg_factor_(0.1), + is_train_(false), + is_null_input_(false), + x_desc_(nullptr), + y_desc_(nullptr), + scale_bias_mean_var_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~FusedBatchNormGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + VARIABLE_NOT_USED(stream_ptr); + if (is_null_input_) { + return true; + } + auto x = GetDeviceAddress(inputs, 0); + auto scale = GetDeviceAddress(inputs, 1); + auto bias = GetDeviceAddress(inputs, 2); + auto runing_mean = GetDeviceAddress(inputs, 3); + auto runnig_variance = GetDeviceAddress(inputs, 4); + auto y = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + if (is_train_) { + auto save_mean = GetDeviceAddress(outputs, 3); + auto save_variance = GetDeviceAddress(outputs, 4); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, + scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean, + runnig_variance, epsilon_, save_mean, save_variance), + "Kernel launch failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardInference(handle_, mode_, &alpha, &beta, x_desc_, x, + y_desc_, y, scale_bias_mean_var_desc_, scale, + bias, runing_mean, runnig_variance, epsilon_), + "Kernel launch failed"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 5) { + MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5"; + } + + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (shape.size() != 4) { + MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGpuKernel should be >= 4"; + } + is_null_input_ = CHECK_NULL_INPUT(shape); + if (is_null_input_) { + MS_LOG(WARNING) << "FusedBatchNormGpuKernel input is null"; + InitSizeLists(); + return true; + } + batch_ = SizeToInt(shape[0]); + channel_ = SizeToInt(shape[1]); + height_ = SizeToInt(shape[2]); + width_ = SizeToInt(shape[3]); + + mode_ = CUDNN_BATCHNORM_SPATIAL; + epsilon_ = GetAttr(kernel_node, "epsilon"); + // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if (node_name == "FusedBatchNorm") { + is_train_ = true; + exp_avg_factor_ = GetAttr(kernel_node, "momentum"); + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set x desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set y desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), + "Set para desc failed"); + + InitSizeLists(); + + return true; + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); + } + void InitSizeLists() override { + size_t input_size = 0; + size_t para_size = 0; + size_t output_size = 0; + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &input_size), "Get input size failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, ¶_size), + "Get para size failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(y_desc_, &output_size), "Get para size failed"); + } + input_size_list_.push_back(input_size); + input_size_list_.push_back(para_size); // scale + input_size_list_.push_back(para_size); // bias + input_size_list_.push_back(para_size); // mean + input_size_list_.push_back(para_size); // variance + + output_size_list_.push_back(output_size); + output_size_list_.push_back(para_size); // running mean + output_size_list_.push_back(para_size); // running variance + output_size_list_.push_back(para_size); // save mean + output_size_list_.push_back(para_size); // save variance + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed"); + } + + int batch_; + int channel_; + int height_; + int width_; + cudnnBatchNormMode_t mode_; + double epsilon_; + double exp_avg_factor_; + bool is_train_; + bool is_null_input_; + cudnnTensorDescriptor_t x_desc_; + cudnnTensorDescriptor_t y_desc_; + cudnnTensorDescriptor_t scale_bias_mean_var_desc_; + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc new file mode 100644 index 0000000000..546e034f6b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedBatchNormGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + FusedBatchNormGradGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h new file mode 100644 index 0000000000..a2d0d741b1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h @@ -0,0 +1,178 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class FusedBatchNormGradGpuKernel : public GpuKernel { + public: + FusedBatchNormGradGpuKernel() + : batch_(0), + channel_(0), + height_(0), + width_(0), + mode_(CUDNN_BATCHNORM_SPATIAL), + epsilon_(10e-5), + is_null_input_(false), + x_desc_(nullptr), + dy_desc_(nullptr), + dx_desc_(nullptr), + scale_bias_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~FusedBatchNormGradGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + VARIABLE_NOT_USED(stream_ptr); + if (is_null_input_) { + return true; + } + auto dy = GetDeviceAddress(inputs, 0); + auto x = GetDeviceAddress(inputs, 1); + auto scale = GetDeviceAddress(inputs, 2); + auto save_mean = GetDeviceAddress(inputs, 3); + auto save_variance = GetDeviceAddress(inputs, 4); + auto dx = GetDeviceAddress(outputs, 0); + auto bn_scale = GetDeviceAddress(outputs, 1); + auto bn_bias = GetDeviceAddress(outputs, 2); + + const float alpha_data_diff = 1; + const float beta_data_diff = 0; + const float alpha_param_diff = 1; + const float beta_param_diff = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff, + &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale, + bn_scale, bn_bias, epsilon_, save_mean, save_variance), + "Kernel Launch Failed."); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 5) { + MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5"; + } + + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (shape.size() != 4) { + MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradGpuKernel should be 4"; + return false; + } + is_null_input_ = CHECK_NULL_INPUT(shape); + if (is_null_input_) { + MS_LOG(WARNING) << "FusedBatchNormGradGpuKernel input is null"; + InitSizeLists(); + return true; + } + batch_ = SizeToInt(shape[0]); + channel_ = SizeToInt(shape[1]); + height_ = SizeToInt(shape[2]); + width_ = SizeToInt(shape[3]); + + mode_ = CUDNN_BATCHNORM_SPATIAL; + epsilon_ = GetAttr(kernel_node, "epsilon"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set x desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set dy desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set dx desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), + "Set para desc failed"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_desc_), "Create para desc failed"); + } + + void InitSizeLists() override { + size_t input_size = 0; + size_t para_size = 0; + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &input_size), "Get input size failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_desc_, ¶_size), "Get input size failed"); + } + + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(para_size); + input_size_list_.push_back(para_size); + input_size_list_.push_back(para_size); + + output_size_list_.push_back(input_size); + output_size_list_.push_back(para_size); + output_size_list_.push_back(para_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_desc_), "Destroy para desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); + } + + int batch_; + int channel_; + int height_; + int width_; + + cudnnBatchNormMode_t mode_; + double epsilon_; + bool is_null_input_; + cudnnTensorDescriptor_t x_desc_; + cudnnTensorDescriptor_t dy_desc_; + cudnnTensorDescriptor_t dx_desc_; + cudnnTensorDescriptor_t scale_bias_desc_; + + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.cc new file mode 100644 index 0000000000..274e4896c9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(GeluGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + GeLUGpuGradKernel, float) +MS_REG_GPU_KERNEL_ONE(GeluGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + GeLUGpuGradKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h new file mode 100644 index 0000000000..823da1fe9f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class GeLUGpuGradKernel : public GpuKernel { + public: + GeLUGpuGradKernel() : input_size_(0) {} + ~GeLUGpuGradKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *dy_addr = GetDeviceAddress(inputs, 0); + T *x_addr = GetDeviceAddress(inputs, 1); + T *dx_addr = GetDeviceAddress(outputs, 0); + + GeluGradKernel(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + input_size_ = sizeof(T); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (auto dim : input_shape) { + input_size_ *= dim; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t input_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.cc new file mode 100644 index 0000000000..03cd9a155b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/gelu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + GeluGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + GeluGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h new file mode 100644 index 0000000000..76d3861d55 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class GeluGpuKernel : public GpuKernel { + public: + GeluGpuKernel() : input_size_(0) {} + ~GeluGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + + Gelu(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + input_size_ = sizeof(T); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (auto dim : input_shape) { + input_size_ *= dim; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t input_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.cc new file mode 100644 index 0000000000..49f556ae64 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LayerNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LayerNormGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LayerNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + LayerNormGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h new file mode 100644 index 0000000000..74669e03de --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class LayerNormGpuKernel : public GpuKernel { + public: + LayerNormGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} + ~LayerNormGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + auto x = GetDeviceAddress(inputs, 0); + auto gamma = GetDeviceAddress(inputs, 1); + auto beta = GetDeviceAddress(inputs, 2); + auto y = GetDeviceAddress(outputs, 0); + auto mean = GetDeviceAddress(outputs, 1); + auto variance = GetDeviceAddress(outputs, 2); + + const T epsilon = 10e-12; + LayerNorm(input_row_, input_col_, param_dim_, epsilon, x, gamma, beta, y, mean, variance, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + int begin_norm_axis = GetAttr(kernel_node, "begin_norm_axis"); + int begin_params_axis = GetAttr(kernel_node, "begin_params_axis"); + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (begin_norm_axis < 0) { + begin_norm_axis += input_shape.size(); + } + + if (begin_params_axis < 0) { + begin_params_axis += input_shape.size(); + } + + for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { + input_row_ *= input_shape[i]; + } + + for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { + input_col_ *= input_shape[i]; + } + + for (size_t i = begin_params_axis; i < input_shape.size(); i++) { + param_dim_ *= input_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); + input_size_list_.push_back(param_dim_ * sizeof(T)); + input_size_list_.push_back(param_dim_ * sizeof(T)); + + output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); + output_size_list_.push_back(input_row_ * sizeof(T)); + output_size_list_.push_back(input_row_ * sizeof(T)); + return; + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int input_row_; + int input_col_; + int param_dim_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.cc new file mode 100644 index 0000000000..b59f95b8a2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LayerNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LayerNormGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LayerNormGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + LayerNormGradGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h new file mode 100644 index 0000000000..93967adad3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class LayerNormGradGpuKernel : public GpuKernel { + public: + LayerNormGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} + ~LayerNormGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + auto x = GetDeviceAddress(inputs, 0); + auto dy = GetDeviceAddress(inputs, 1); + auto var = GetDeviceAddress(inputs, 2); + auto mean = GetDeviceAddress(inputs, 3); + auto gamma = GetDeviceAddress(inputs, 4); + auto dx = GetDeviceAddress(outputs, 0); + auto dg = GetDeviceAddress(outputs, 1); + auto db = GetDeviceAddress(outputs, 2); + + const T epsilon = 10e-12; + LayerNormGrad(input_row_, input_col_, param_dim_, epsilon, dy, x, mean, var, gamma, dx, dg, db, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + int begin_norm_axis = GetAttr(kernel_node, "begin_norm_axis"); + int begin_params_axis = GetAttr(kernel_node, "begin_params_axis"); + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (begin_norm_axis < 0) { + begin_norm_axis += input_shape.size(); + } + + if (begin_params_axis < 0) { + begin_params_axis += input_shape.size(); + } + + for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { + input_row_ *= input_shape[i]; + } + + for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { + input_col_ *= input_shape[i]; + } + + for (size_t i = begin_params_axis; i < input_shape.size(); i++) { + param_dim_ *= input_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); + input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); + input_size_list_.push_back(input_row_ * sizeof(T)); + input_size_list_.push_back(input_row_ * sizeof(T)); + input_size_list_.push_back(param_dim_ * sizeof(T)); + + output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); + output_size_list_.push_back(param_dim_ * sizeof(T)); + output_size_list_.push_back(param_dim_ * sizeof(T)); + return; + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int input_row_; + int input_col_; + int param_dim_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.cc new file mode 100644 index 0000000000..a24aaeeb96 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LSTM, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LstmGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LSTM, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + LstmGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h new file mode 100644 index 0000000000..ad3e588f00 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h @@ -0,0 +1,247 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class LstmGpuKernel : public GpuKernel { + public: + LstmGpuKernel() + : batch_size_(0), + seq_len_(0), + input_size_(0), + hidden_size_(0), + num_layers_(0), + has_bias_(false), + bidirectional_(false), + states_init_(false), + dropout_(0), + weight_size_(0), + reserved_size_(0), + x_desc_(nullptr), + hx_desc_(nullptr), + cx_desc_(nullptr), + w_desc_(nullptr), + dropout_desc_(nullptr), + y_desc_(nullptr), + hy_desc_(nullptr), + cy_desc_(nullptr), + rnn_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~LstmGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(stream_ptr); + auto x_addr = GetDeviceAddress(inputs, 0); + auto hx_addr = GetDeviceAddress(inputs, 1); + auto cx_addr = GetDeviceAddress(inputs, 2); + auto w_addr = GetDeviceAddress(inputs, 3); + auto y_addr = GetDeviceAddress(outputs, 0); + auto hy_addr = GetDeviceAddress(outputs, 1); + auto cy_addr = GetDeviceAddress(outputs, 2); + auto reserved_addr = GetDeviceAddress(outputs, 3); + auto states_addr = GetDeviceAddress(outputs, 4); + void *workspace_addr = GetDeviceAddress(workspace, 0); + + if (!states_init_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, output_size_list_[4], 0), + "set dropout_desc failed"); + states_init_ = true; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnRNNForwardTraining(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, cx_desc_, cx_addr, + w_desc_, w_addr, y_desc_.get(), y_addr, hy_desc_, hy_addr, cy_desc_, cy_addr, + workspace_addr, workspace_size_list_[0], reserved_addr, reserved_size_), + "launch lstm kernel failed"); + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + seq_len_ = SizeToInt(input_shape[0]); + batch_size_ = SizeToInt(input_shape[1]); + input_size_ = SizeToInt(input_shape[2]); + + input_size_ = GetAttr(kernel_node, "input_size"); + hidden_size_ = GetAttr(kernel_node, "hidden_size"); + num_layers_ = GetAttr(kernel_node, "num_layers"); + has_bias_ = GetAttr(kernel_node, "has_bias"); + bidirectional_ = GetAttr(kernel_node, "bidirectional"); + dropout_ = GetAttr(kernel_node, "dropout"); + + cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; + cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; + CreateTensorDescGrp(); + int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set hx_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set cx_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set hy_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set cy_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), + "set dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, + input_mode, direction, rnn_mode, algo, cudnn_data_type_), + "set rnn_desc failed"); + cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); + auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), + "get weight_size_ failed"); + if (weight_size != weight_size_) { + MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; + } + int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), + "set w_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_), + "get reserve size failed"); + InitSizeLists(); + return true; + } + void CreateTensorDescGrp() { + int x_dims[3]{batch_size_, input_size_, 1}; + int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1}; + + x_desc_ = std::make_unique(seq_len_); + y_desc_ = std::make_unique(seq_len_); + + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), "set x_desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); + } + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hy_desc_), "create hy_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cy_desc_), "create cy_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); + } + void InitSizeLists() override { + size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); + + size_t h_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); + + input_size_list_.push_back(x_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(weight_size_); + + size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); + output_size_list_.push_back(y_size); + output_size_list_.push_back(h_size); + output_size_list_.push_back(h_size); + output_size_list_.push_back(reserved_size_); + size_t state_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); + output_size_list_.push_back(state_size); + + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size), + "get workspace size failed"); + workspace_size_list_.push_back(workspace_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cy_desc_), "destroy cy_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hy_desc_), "destroy hy_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc failed"); + + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed"); + } + } + + int batch_size_; + int seq_len_; + int input_size_; + int hidden_size_; + int num_layers_; + + bool has_bias_; + bool bidirectional_; + bool states_init_; + float dropout_; + + size_t weight_size_; + size_t reserved_size_; + + // input desc + std::unique_ptr x_desc_; + cudnnTensorDescriptor_t hx_desc_; + cudnnTensorDescriptor_t cx_desc_; + cudnnFilterDescriptor_t w_desc_; + cudnnDropoutDescriptor_t dropout_desc_; + std::unique_ptr y_desc_; + cudnnTensorDescriptor_t hy_desc_; + cudnnTensorDescriptor_t cy_desc_; + cudnnRNNDescriptor_t rnn_desc_; + + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.cc new file mode 100644 index 0000000000..1fa47690b3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LSTMGradData, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LstmGradDataGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LSTMGradData, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + LstmGradDataGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h new file mode 100644 index 0000000000..6d6bed5555 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h @@ -0,0 +1,284 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class LstmGradDataGpuKernel : public GpuKernel { + public: + LstmGradDataGpuKernel() + : batch_size_(0), + seq_len_(0), + input_size_(0), + hidden_size_(0), + num_layers_(0), + has_bias_(false), + bidirectional_(false), + states_init_(false), + dropout_(0), + weight_size_(0), + reserved_size_(0), + rnn_desc_(nullptr), + y_desc_(nullptr), + dy_desc_(nullptr), + dhy_desc_(nullptr), + dcy_desc_(nullptr), + w_desc_(nullptr), + hx_desc_(nullptr), + cx_desc_(nullptr), + dropout_desc_(nullptr), + dx_desc_(nullptr), + dhx_desc_(nullptr), + dcx_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~LstmGradDataGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(stream_ptr); + auto y_addr = GetDeviceAddress(inputs, 0); + auto dy_addr = GetDeviceAddress(inputs, 1); + auto dhy_addr = GetDeviceAddress(inputs, 2); + auto dcy_addr = GetDeviceAddress(inputs, 3); + auto w_addr = GetDeviceAddress(inputs, 4); + auto hx_addr = GetDeviceAddress(inputs, 5); + auto cx_addr = GetDeviceAddress(inputs, 6); + auto reserved_addr = GetDeviceAddress(inputs, 7); + auto states_addr = GetDeviceAddress(inputs, 8); + auto dx_addr = GetDeviceAddress(outputs, 0); + auto dhx_addr = GetDeviceAddress(outputs, 1); + auto dcx_addr = GetDeviceAddress(outputs, 2); + void *workspace_addr = GetDeviceAddress(workspace, 0); + + if (!states_init_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnRestoreDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, input_size_list_[8], 0), + "restore dropout state failed"); + states_init_ = true; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnRNNBackwardData(handle_, rnn_desc_, seq_len_, y_desc_.get(), y_addr, dy_desc_.get(), dy_addr, dhy_desc_, + dhy_addr, dcy_desc_, dcy_addr, w_desc_, w_addr, hx_desc_, hx_addr, cx_desc_, cx_addr, + dx_desc_.get(), dx_addr, dhx_desc_, dhx_addr, dcx_desc_, dcx_addr, workspace_addr, + workspace_size_list_[0], reserved_addr, reserved_size_), + "launch lstm back data kernel failed"); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "stream synchronize failed."); + return true; + } + void GetAttrs(const CNodePtr &kernel_node) { + input_size_ = GetAttr(kernel_node, "input_size"); + hidden_size_ = GetAttr(kernel_node, "hidden_size"); + num_layers_ = GetAttr(kernel_node, "num_layers"); + has_bias_ = GetAttr(kernel_node, "has_bias"); + bidirectional_ = GetAttr(kernel_node, "bidirectional"); + dropout_ = GetAttr(kernel_node, "dropout"); + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + seq_len_ = SizeToInt(input_shape[0]); + batch_size_ = SizeToInt(input_shape[1]); + GetAttrs(kernel_node); + cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; + cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; + CreateTensorDescGrp(); + int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dhy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dhy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dcy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dcy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set cx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dhx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dhx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dcx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dcx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), + "set dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, + input_mode, direction, rnn_mode, algo, cudnn_data_type_), + "set rnn_desc failed"); + cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); + auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); + size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, dx_desc_[0], &weight_size_, cudnn_data_type_), + "get weight_size_ failed"); + if (weight_size != weight_size_) { + MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; + } + int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), + "set w_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, dx_desc_.get(), &reserved_size_), "get size failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dhy_desc_), "create dhy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dcy_desc_), "create dcy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dhx_desc_), "create dhx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dcx_desc_), "create dcx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); + } + + void InitSizeLists() override { + size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); + + size_t h_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); + + input_size_list_.push_back(y_size); + input_size_list_.push_back(y_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(weight_size_); + input_size_list_.push_back(h_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(reserved_size_); + size_t state_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); + input_size_list_.push_back(state_size); + + size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); + output_size_list_.push_back(x_size); + output_size_list_.push_back(h_size); + output_size_list_.push_back(h_size); + + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, dx_desc_.get(), &workspace_size), + "get workspace size failed"); + workspace_size_list_.push_back(workspace_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dcx_desc_), "destroy dcx_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dhx_desc_), "destroy dhx_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dcy_desc_), "destroy dcy_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dhy_desc_), "destroy dhy_desc_ failed"); + DestroyTensorDescGrp(); + } + void CreateTensorDescGrp() { + int x_dims[3]{batch_size_, input_size_, 1}; + int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1}; + + dx_desc_ = std::make_unique(seq_len_); + y_desc_ = std::make_unique(seq_len_); + dy_desc_ = std::make_unique(seq_len_); + + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_[i]), "create x_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dx_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), + "set dx_desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_[i]), "create dy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dy_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), + "set dy_desc_ failed"); + } + } + + void DestroyTensorDescGrp() { + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_[i]), "destroy dy_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_[i]), "destroy x_desc failed"); + } + } + + int batch_size_; + int seq_len_; + int input_size_; + int hidden_size_; + int num_layers_; + + bool has_bias_; + bool bidirectional_; + bool states_init_; + float dropout_; + + size_t weight_size_; + size_t reserved_size_; + + cudnnRNNDescriptor_t rnn_desc_; + + // input desc + std::unique_ptr y_desc_; + std::unique_ptr dy_desc_; + cudnnTensorDescriptor_t dhy_desc_; + cudnnTensorDescriptor_t dcy_desc_; + cudnnFilterDescriptor_t w_desc_; + cudnnTensorDescriptor_t hx_desc_; + cudnnTensorDescriptor_t cx_desc_; + + cudnnDropoutDescriptor_t dropout_desc_; + + // output desc + std::unique_ptr dx_desc_; + cudnnTensorDescriptor_t dhx_desc_; + cudnnTensorDescriptor_t dcx_desc_; + + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.cc new file mode 100644 index 0000000000..9ec239491f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LSTMGradWeight, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LstmGradWeightGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LSTMGradWeight, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + LstmGradWeightGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h new file mode 100644 index 0000000000..445d2ce199 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h @@ -0,0 +1,231 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +namespace mindspore { +namespace kernel { +template +class LstmGradWeightGpuKernel : public GpuKernel { + public: + LstmGradWeightGpuKernel() + : batch_size_(0), + seq_len_(0), + input_size_(0), + hidden_size_(0), + num_layers_(0), + has_bias_(false), + bidirectional_(false), + states_init_(false), + dropout_(0), + weight_size_(0), + reserved_size_(0), + rnn_desc_(nullptr), + dropout_desc_(nullptr), + x_desc_(nullptr), + hx_desc_(nullptr), + y_desc_(nullptr), + dw_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~LstmGradWeightGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(stream_ptr); + auto x_addr = GetDeviceAddress(inputs, 0); + auto hx_addr = GetDeviceAddress(inputs, 1); + auto y_addr = GetDeviceAddress(inputs, 2); + auto reserved_addr = GetDeviceAddress(inputs, 3); + auto states_addr = GetDeviceAddress(inputs, 4); + auto dw_addr = GetDeviceAddress(outputs, 0); + void *workspace_addr = GetDeviceAddress(workspace, 0); + + if (!states_init_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnRestoreDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, input_size_list_[4], 0), + "restore dropout state failed"); + states_init_ = true; + } + + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemsetAsync(dw_addr, 0, outputs[0]->size, reinterpret_cast(stream_ptr)), "cudaMemSet Failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnRNNBackwardWeights(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, y_desc_.get(), + y_addr, workspace_addr, workspace_size_list_[0], dw_desc_, dw_addr, reserved_addr, + reserved_size_), + "launch lstm back weight kernel failed"); + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + seq_len_ = SizeToInt(input_shape[0]); + batch_size_ = SizeToInt(input_shape[1]); + + input_size_ = GetAttr(kernel_node, "input_size"); + hidden_size_ = GetAttr(kernel_node, "hidden_size"); + num_layers_ = GetAttr(kernel_node, "num_layers"); + has_bias_ = GetAttr(kernel_node, "has_bias"); + bidirectional_ = GetAttr(kernel_node, "bidirectional"); + dropout_ = GetAttr(kernel_node, "dropout"); + + cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; + cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; + + CreateTensorDescGrp(); + int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), + "set dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, + input_mode, direction, rnn_mode, algo, cudnn_data_type_), + "set rnn_desc failed"); + cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); + + auto weight_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), + "get weight_size_ failed"); + if (weight_size != weight_size_) { + MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; + } + int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), + "set dw_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_), + "get reserve size failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&dw_desc_), "create dw_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); + } + void InitSizeLists() override { + size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); + + size_t h_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); + + size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); + input_size_list_.push_back(x_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(y_size); + input_size_list_.push_back(reserved_size_); + size_t state_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); + input_size_list_.push_back(state_size); + + output_size_list_.push_back(weight_size_); + + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size), + "get workspace size failed"); + workspace_size_list_.push_back(workspace_size); + } + + private: + void CreateTensorDescGrp() { + int x_dims[3]{batch_size_, input_size_, 1}; + int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1}; + + x_desc_ = std::make_unique(seq_len_); + y_desc_ = std::make_unique(seq_len_); + + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), "set x_desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); + } + } + void DestroyTensorDescGrp() { + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed"); + } + } + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "destroy dw_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc_ failed"); + DestroyTensorDescGrp(); + } + + int batch_size_; + int seq_len_; + int input_size_; + int hidden_size_; + int num_layers_; + + bool has_bias_; + bool bidirectional_; + bool states_init_; + float dropout_; + + size_t weight_size_; + size_t reserved_size_; + + cudnnRNNDescriptor_t rnn_desc_; + cudnnDropoutDescriptor_t dropout_desc_; + + // input desc + std::unique_ptr x_desc_; + cudnnTensorDescriptor_t hx_desc_; + std::unique_ptr y_desc_; + + // output desc + cudnnFilterDescriptor_t dw_desc_; + + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc new file mode 100644 index 0000000000..99ae2affe8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + MomentumGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + MomentumGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16), + MomentumGpuKernel, half, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h new file mode 100644 index 0000000000..32d3fbb079 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h @@ -0,0 +1,100 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" +namespace mindspore { +namespace kernel { +template +class MomentumGpuKernel : public GpuKernel { + public: + MomentumGpuKernel() + : variable_size_(0), accumulation_size_(0), learning_rate_size_(0), gradient_size_(0), momentum_size_(0) {} + ~MomentumGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, + void *stream_ptr) override { + T *variable = GetDeviceAddress(inputs, 0); + T *accumulation = GetDeviceAddress(inputs, 1); + S *learning_rate = GetDeviceAddress(inputs, 2); + T *gradient = GetDeviceAddress(inputs, 3); + S *momentum = GetDeviceAddress(inputs, 4); + MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 5) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but momentum needs 5 inputs."; + return false; + } + + variable_size_ = sizeof(T); + accumulation_size_ = sizeof(T); + learning_rate_size_ = sizeof(S); + gradient_size_ = sizeof(T); + momentum_size_ = sizeof(S); + + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < variable_shape.size(); i++) { + variable_size_ *= variable_shape[i]; + } + auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < accumulation_shape.size(); i++) { + accumulation_size_ *= accumulation_shape[i]; + } + auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + for (size_t i = 0; i < gradient_shape.size(); i++) { + gradient_size_ *= gradient_shape[i]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(variable_size_); + input_size_list_.push_back(accumulation_size_); + input_size_list_.push_back(learning_rate_size_); + input_size_list_.push_back(gradient_size_); + input_size_list_.push_back(momentum_size_); + output_size_list_.push_back(0); + } + + private: + size_t variable_size_; + size_t accumulation_size_; + size_t learning_rate_size_; + size_t gradient_size_; + size_t momentum_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.cc new file mode 100644 index 0000000000..902b0d9faf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + PoolingGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + PoolingGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + PoolingGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + PoolingGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h new file mode 100644 index 0000000000..908a4e9b99 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h @@ -0,0 +1,252 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class PoolingGpuFwdKernel : public GpuKernel { + public: + PoolingGpuFwdKernel() + : cudnn_handle_(nullptr), + input_descriptor_(nullptr), + output_descriptor_(nullptr), + pooling_descriptor_(nullptr), + padded_descriptor_(nullptr), + pooling_mode_(CUDNN_POOLING_MAX), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + pad_value_(0), + is_null_input_(false), + input_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~PoolingGpuFwdKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } + T *input_addr = reinterpret_cast(inputs[0]->addr); + T *output_addr = reinterpret_cast(outputs[0]->addr); + const float alpha = 1; + const float beta = 0; + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded_addr = reinterpret_cast(workspace[0]->addr); + CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, + reinterpret_cast(stream_ptr)); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, padded_descriptor_, + padded_addr, &beta, output_descriptor_, output_addr), + "cudnnPoolingForward failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, + input_addr, &beta, output_descriptor_, output_addr), + "cudnnPoolingForward failed"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "PoolingGpuFwdKernel input is null."; + InitSizeLists(); + return true; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), + SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), + SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + auto window = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize")); + int window_height = window[2]; + int window_width = window[3]; + stride_ = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides")); + SetPoolingMode(kernel_node); + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(input_shape, window_height, window_width); + } else { + pad_height_ = 0; + pad_width_ = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, + window_width, pad_height_, pad_width_, stride_[2], stride_[3]), + "cudnnSetPooling2dDescriptor failed"); + } + + InitSizeLists(); + return true; + } + + protected: + void InitResource() { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), + "cudnnCreatePoolingDescriptor failed"); + } + void InitSizeLists() { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast(&input_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetTensorSizeInBytes(output_descriptor_, reinterpret_cast(&output_size_)), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast(&padded_size_)), + "cudnnGetTensorSizeInBytes failed"); + workspace_size_list_.push_back(padded_size_); + if (padded_size_ == 0) { + MS_LOG(EXCEPTION) << "Padded size is 0."; + } + } + return; + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but pooling needs 1 inputs."; + return false; + } + return true; + } + void SetPad(const std::vector &input_shape, const int &window_height, const int &window_width) { + n_ = SizeToInt(input_shape[0]); + c_ = SizeToInt(input_shape[1]); + old_height_ = SizeToInt(input_shape[2]); + old_width_ = SizeToInt(input_shape[3]); + pad_height_ = + std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) + : (old_height_ / stride_[2]) + 1) - + 1) * + stride_[2] + + window_height - old_height_); + pad_width_ = + std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) + : (old_width_ / stride_[3]) + 1) - + 1) * + stride_[3] + + window_width - old_width_); + pad_top_ = pad_height_ / 2; + pad_left_ = pad_width_ / 2; + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, + c_, old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, + window_height, window_width, use_pad_ ? 0 : pad_top_, + use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]), + "cudnnSetPooling2dDescriptor failed"); + } + void SetPoolingMode(const CNodePtr &kernel_node) { + pad_mode_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); + mode_ = AnfAlgo::GetCNodeName(kernel_node); + if (mode_ == "AvgPool") { + pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + pad_value_ = 0.0; + } else { + pooling_mode_ = CUDNN_POOLING_MAX; + pad_value_ = kSignedMinFloat; + } + } + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), + "cudnnDestroyPoolingDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_descriptor_; + cudnnTensorDescriptor_t output_descriptor_; + cudnnPoolingDescriptor_t pooling_descriptor_; + cudnnTensorDescriptor_t padded_descriptor_; + cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; + std::vector stride_; + std::string mode_; + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + cudnnDataType_t cudnn_data_type_; + + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + float pad_value_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc new file mode 100644 index 0000000000..2948c900d2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(MaxPoolGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + PoolingGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(MaxPoolGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + PoolingGradGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + PoolingGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + PoolingGradGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h new file mode 100644 index 0000000000..a066eacfa0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h @@ -0,0 +1,296 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class PoolingGradGpuKernel : public GpuKernel { + public: + PoolingGradGpuKernel() + : cudnn_handle_(nullptr), + pooling_descriptor_(nullptr), + y_descriptor_(nullptr), + dy_descriptor_(nullptr), + x_descriptor_(nullptr), + dx_descriptor_(nullptr), + padded_descriptor_(nullptr), + pooling_mode_(CUDNN_POOLING_MAX), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + pad_value_(0), + is_null_input_(false), + input_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~PoolingGradGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *x_data = GetDeviceAddress(inputs, 0); + T *y = GetDeviceAddress(inputs, 1); + T *dy = GetDeviceAddress(inputs, 2); + T *dx = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded = GetDeviceAddress(workspace, 0); + T *padded_dx = GetDeviceAddress(workspace, 1); + + CalPad(padded_size_ / sizeof(T), x_data, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, + reinterpret_cast(stream_ptr)); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, + padded_descriptor_, padded, &beta, padded_descriptor_, padded_dx), + "cudnnPoolingBackward failed"); + + CalPadGrad(output_size_ / sizeof(T), padded_dx, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, + x_descriptor_, x_data, &beta, dx_descriptor_, dx), + "cudnnPoolingBackward failed"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + auto window = GetAttr>(kernel_node, "ksize"); + int window_height = window[2]; + int window_width = window[3]; + SetPoolingMode(kernel_node); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); + if (is_null_input_) { + MS_LOG(WARNING) << "PoolingGradGpuKernel input is null."; + InitSizeLists(); + return true; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(y_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_mask[0]), + SizeToInt(input_mask[1]), SizeToInt(input_mask[2]), SizeToInt(input_mask[3])), + "cudnnSetTensor4dDescriptor"); + + auto dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dy_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dout_shape[0]), + SizeToInt(dout_shape[1]), SizeToInt(dout_shape[2]), SizeToInt(dout_shape[3])), + "cudnnSetTensor4dDescriptor"); + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dx_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), + SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { + SetPad(input_shape, window_height, window_width); + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, + window_width, pad_height_, pad_width_, stride_[2], stride_[3]), + "cudnnSetPooling2dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), + SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), + "cudnnSetTensor4dDescriptor"); + } + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), + "cudnnCreatePoolingDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(y_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dx_descriptor_, &output_size_), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), + "cudnnGetTensorSizeInBytes failed"); + if (padded_size_ == 0) { + MS_LOG(EXCEPTION) << "Padded size is 0."; + } + workspace_size_list_.push_back(padded_size_); + workspace_size_list_.push_back(padded_size_); + } + return; + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but PoolingGradGpuKernel needs 3 inputs."; + return false; + } + return true; + } + void SetPad(const std::vector &input_shape, const int &window_height, const int &window_width) { + n_ = SizeToInt(input_shape[0]); + c_ = SizeToInt(input_shape[1]); + old_height_ = SizeToInt(input_shape[2]); + old_width_ = SizeToInt(input_shape[3]); + pad_height_ = + std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) + : (old_height_ / stride_[2]) + 1) - + 1) * + stride_[2] + + window_height - old_height_); + pad_width_ = + std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) + : (old_width_ / stride_[3]) + 1) - + 1) * + stride_[3] + + window_width - old_width_); + pad_top_ = pad_height_ / 2; + pad_left_ = pad_width_ / 2; + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, + c_, old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), + SizeToInt(input_shape[1]), SizeToInt(input_shape[2]) + (use_pad_ ? pad_height_ : 0), + SizeToInt(input_shape[3]) + (use_pad_ ? pad_width_ : 0)), + "cudnnSetTensor4dDescriptor"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, + window_height, window_width, use_pad_ ? 0 : pad_top_, + use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]), + "cudnnSetPooling2dDescriptor failed"); + } + void SetPoolingMode(const CNodePtr &kernel_node) { + pad_mode_ = GetAttr(kernel_node, "padding"); + stride_ = GetAttr>(kernel_node, "strides"); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + mode_ = AnfAlgo::GetCNodeName(kernel_node); + if (mode_ == "AvgPoolGradGpu") { + pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + pad_value_ = 0.0; + } else { + pooling_mode_ = CUDNN_POOLING_MAX; + pad_value_ = kSignedMinFloat; + } + } + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), + "cudnnDestroyPoolingDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + + cudnnHandle_t cudnn_handle_; + cudnnPoolingDescriptor_t pooling_descriptor_; + cudnnTensorDescriptor_t y_descriptor_; + cudnnTensorDescriptor_t dy_descriptor_; + cudnnTensorDescriptor_t x_descriptor_; + cudnnTensorDescriptor_t dx_descriptor_; + cudnnTensorDescriptor_t padded_descriptor_; + cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; + std::vector stride_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + std::string mode_; + std::string pad_mode_; + cudnnDataType_t cudnn_data_type_; + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + float pad_value_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.cc new file mode 100644 index 0000000000..c33909a82b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ApplyRMSProp, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RMSPropGpuKernel, float) + +MS_REG_GPU_KERNEL_ONE(ApplyCenteredRMSProp, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RMSPropGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h new file mode 100644 index 0000000000..9811c71094 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class RMSPropGpuKernel : public GpuKernel { + public: + RMSPropGpuKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {} + ~RMSPropGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream) override { + if (!use_center_) { + T *variable = GetDeviceAddress(inputs, 0); + T *mean_square = GetDeviceAddress(inputs, 1); + T *moment = GetDeviceAddress(inputs, 2); + T *learning_rate = GetDeviceAddress(inputs, 3); + T *gradients = GetDeviceAddress(inputs, 4); + + RmsProp(learning_rate, decay_, momentum_, epsilon_, variable, mean_square, moment, gradients, size_, + reinterpret_cast(stream)); + } else { + T *variable = GetDeviceAddress(inputs, 0); + T *mean_gradients = GetDeviceAddress(inputs, 1); + T *mean_square = GetDeviceAddress(inputs, 2); + T *moment = GetDeviceAddress(inputs, 3); + T *gradients = GetDeviceAddress(inputs, 4); + T *learning_rate = GetDeviceAddress(inputs, 5); + T *decay = GetDeviceAddress(inputs, 6); + T *momentum = GetDeviceAddress(inputs, 7); + T *epsilon = GetDeviceAddress(inputs, 8); + + RmsPropCenter(learning_rate, decay, momentum, epsilon, variable, mean_gradients, mean_square, moment, gradients, + size_, reinterpret_cast(stream)); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if (node_name == "ApplyCenteredRMSProp") { + use_center_ = true; + } + + if (node_name == "ApplyRMSProp") { + decay_ = GetAttr(kernel_node, "rho"); + momentum_ = GetAttr(kernel_node, "momentum"); + epsilon_ = GetAttr(kernel_node, "epsilon"); + } + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto &dim : input_shape) { + size_ *= dim; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t input_size = size_ * sizeof(T); + if (!use_center_) { + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(input_size); + output_size_list_.push_back(input_size); + } else { + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + output_size_list_.push_back(input_size); + } + } + + private: + size_t size_; + bool use_center_; + float decay_; + float momentum_; + float epsilon_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc new file mode 100644 index 0000000000..96d2d29549 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + SigmoidCrossEntropyWithLogits, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SigmoidCrossEntropyWithLogitsGpuKernel, float, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h new file mode 100644 index 0000000000..a2d3aabb68 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SigmoidCrossEntropyWithLogitsGpuKernel : public GpuKernel { + public: + SigmoidCrossEntropyWithLogitsGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} + + ~SigmoidCrossEntropyWithLogitsGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *logits_addr = GetDeviceAddress(inputs, 0); + S *labels_addr = GetDeviceAddress(inputs, 1); + T *outputs_addr = GetDeviceAddress(outputs, 0); + + SigmoidCrossEntropyWithLogits(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogits needs 2 inputs."; + return false; + } + logits_size_ = sizeof(T); + labels_size_ = sizeof(S); + outputs_size_ = sizeof(T); + + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < logits_shape.size(); i++) { + logits_size_ *= logits_shape[i]; + } + + auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < labels_shape.size(); i++) { + labels_size_ *= labels_shape[i]; + } + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + outputs_size_ *= output_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(logits_size_); + input_size_list_.push_back(labels_size_); + output_size_list_.push_back(outputs_size_); + } + + private: + size_t logits_size_; + size_t labels_size_; + size_t outputs_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc new file mode 100644 index 0000000000..05c9a4234b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h new file mode 100644 index 0000000000..88ab46a6ba --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { + public: + SigmoidCrossEntropyWithLogitsGradGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} + ~SigmoidCrossEntropyWithLogitsGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *logits_addr = GetDeviceAddress(inputs, 0); + S *labels_addr = GetDeviceAddress(inputs, 1); + T *outputs_addr = GetDeviceAddress(outputs, 0); + + SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogitsGrad needs 3 inputs."; + return false; + } + logits_size_ = sizeof(T); + labels_size_ = sizeof(S); + outputs_size_ = sizeof(T); + + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < logits_shape.size(); i++) { + logits_size_ *= logits_shape[i]; + } + + auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < labels_shape.size(); i++) { + labels_size_ *= labels_shape[i]; + } + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + outputs_size_ *= output_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(logits_size_); + input_size_list_.push_back(labels_size_); + output_size_list_.push_back(outputs_size_); + } + + private: + size_t logits_size_; + size_t labels_size_; + size_t outputs_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.cc new file mode 100644 index 0000000000..ea40bea6a4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + SmoothL1Loss, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SmoothL1LossGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h new file mode 100644 index 0000000000..dc20f75077 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh" +namespace mindspore { +namespace kernel { +template +class SmoothL1LossGpuKernel : public GpuKernel { + public: + SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {} + ~SmoothL1LossGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *prediction = GetDeviceAddress(inputs, 0); + T *target = GetDeviceAddress(inputs, 1); + T *loss = GetDeviceAddress(outputs, 0); + + SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + + sigma_ = GetAttr(kernel_node, "sigma"); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(input_size_ * sizeof(T)); + } + + private: + size_t input_size_; + float sigma_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc new file mode 100644 index 0000000000..8a4fb38460 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(SmoothL1LossGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SmoothL1LossGradGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h new file mode 100644 index 0000000000..02be336932 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh" +namespace mindspore { +namespace kernel { +template +class SmoothL1LossGradGpuKernel : public GpuKernel { + public: + SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {} + ~SmoothL1LossGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *prediction = GetDeviceAddress(inputs, 0); + T *target = GetDeviceAddress(inputs, 1); + T *dloss = GetDeviceAddress(inputs, 2); + T *dx = GetDeviceAddress(outputs, 0); + + SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + + sigma_ = GetAttr(kernel_node, "sigma"); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_ * sizeof(T)); + input_size_list_.push_back(input_size_ * sizeof(T)); + output_size_list_.push_back(input_size_ * sizeof(T)); + } + + private: + size_t input_size_; + float sigma_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc new file mode 100644 index 0000000000..8a64762c0a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(SoftmaxCrossEntropyWithLogits, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SoftmaxCrossEntropyWithLogitsGpuKernel, float, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h new file mode 100644 index 0000000000..e56cb96fd7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h @@ -0,0 +1,205 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { + public: + SoftmaxCrossEntropyWithLogitsGpuKernel() + : cudnn_handle_(nullptr), + logits_descriptor_(nullptr), + softmax_output_descriptor_(nullptr), + algo_(CUDNN_SOFTMAX_ACCURATE), + mode_(CUDNN_SOFTMAX_MODE_INSTANCE), + cudnn_data_type_(CUDNN_DATA_FLOAT), + is_null_input_(false), + logits_size_(0), + labels_size_(0), + output1_size_(0), + output2_size_(0), + softmax_output_logits_size_(0), + batch_size_(0), + channel_size_(0), + height_(0), + width_(0) {} + ~SoftmaxCrossEntropyWithLogitsGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *logits_addr = GetDeviceAddress(inputs, 0); + S *labels_addr = GetDeviceAddress(inputs, 1); + T *loss_addr = GetDeviceAddress(outputs, 0); + T *dlogits_addr = GetDeviceAddress(outputs, 1); + T *softmax_output_logits = GetDeviceAddress(workspace, 0); + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, logits_descriptor_, logits_addr, &beta, + softmax_output_descriptor_, softmax_output_logits), + "cudnnSoftmaxForward failed."); + + CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num + << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(ERROR) << "Output number is " << output_num + << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 output."; + return false; + } + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + + InferInputOutputSize(kernel_node); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + batch_size_, channel_size_, height_, width_), + "cudnnSetTensor4dDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(softmax_output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_size_, + channel_size_, height_, width_), + "cudnnSetTensor4dDescriptor failed."); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&logits_descriptor_), + "cudnnCreateTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&softmax_output_descriptor_), + "cudnnCreateTensorDescriptor failed."); + } + void InitSizeLists() override { + input_size_list_.push_back(logits_size_); + input_size_list_.push_back(labels_size_); + output_size_list_.push_back(output1_size_); + output_size_list_.push_back(output2_size_); + workspace_size_list_.push_back(softmax_output_logits_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(softmax_output_descriptor_), + "cudnnDestroyTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(logits_descriptor_), + "cudnnDestroyTensorDescriptor failed."); + } + void InferInputOutputSize(const CNodePtr &kernel_node) { + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(logits_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input1 is null"; + InitSizeLists(); + return; + } + auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + is_null_input_ = CHECK_NULL_INPUT(logits_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input2 is null"; + InitSizeLists(); + return; + } + CheckShapeValidation(logits_shape, labels_shape); + + size_t logits_dims = logits_shape.size(); + batch_size_ = 1; + for (size_t i = 0; i < logits_dims - 1; i++) { + batch_size_ *= logits_shape[i]; + } + channel_size_ = logits_shape[logits_dims - 1]; + height_ = 1; + width_ = 1; + logits_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; + + labels_size_ = 1; + size_t labels_dims = labels_shape.size(); + for (size_t i = 0; i < labels_dims; i++) { + labels_size_ *= labels_shape[i]; + } + labels_size_ *= sizeof(S); + + output1_size_ = logits_size_ / logits_shape[logits_dims - 1]; + output2_size_ = logits_size_; + softmax_output_logits_size_ = logits_size_; + return; + } + void CheckShapeValidation(const std::vector &logits_shape, const std::vector &labels_shape) { + size_t logits_dim_length = logits_shape.size(); + size_t labels_dim_length = labels_shape.size(); + if (labels_dim_length != logits_dim_length) { + MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length for " + "SoftmaxCrossEntropyWithLogits, but got Labels " + "shape length:" + << labels_dim_length << ", Logits shape length:" << logits_dim_length; + } + if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) { + MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last demension."; + } + return; + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t logits_descriptor_; + cudnnTensorDescriptor_t softmax_output_descriptor_; + cudnnSoftmaxAlgorithm_t algo_; + cudnnSoftmaxMode_t mode_; + cudnnDataType_t cudnn_data_type_; + bool is_null_input_; + + size_t logits_size_; + size_t labels_size_; + size_t output1_size_; + size_t output2_size_; + size_t softmax_output_logits_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t batch_size_; + size_t channel_size_; + size_t height_; + size_t width_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.cc new file mode 100644 index 0000000000..24c2c12601 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SoftmaxGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SoftmaxGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SoftmaxGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SoftmaxGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h new file mode 100644 index 0000000000..279bac3aa9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h @@ -0,0 +1,252 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SoftmaxGpuKernel : public GpuKernel { + public: + SoftmaxGpuKernel() + : cudnn_handle_(nullptr), + input_descriptor_(nullptr), + output_descriptor_(nullptr), + algo_(CUDNN_SOFTMAX_ACCURATE), + mode_(CUDNN_SOFTMAX_MODE_INSTANCE), + cudnn_data_type_(CUDNN_DATA_FLOAT), + is_null_input_(false), + input_size_(0), + output_size_(0), + workspace_size_(0), + axis_(0), + shape_size_(0), + batch_size_(0), + channel_size_(0), + height_(0), + width_(0) {} + ~SoftmaxGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + const float alpha = 1; + const float beta = 0; + + if (axis_ == 1) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, input_descriptor_, + input_addr, &beta, output_descriptor_, output_addr), + "cudnnSoftmaxForward failed"); + } else { + T *transpose_input_addr = GetDeviceAddress(workspace, 0); + T *transpose_output_addr = GetDeviceAddress(workspace, 1); + int *input_shape = GetDeviceAddress(workspace, 2); + int *transpose_shape = GetDeviceAddress(workspace, 3); + int *transpose_axis = GetDeviceAddress(workspace, 4); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_axis failed"); + int size = SizeToInt(input_size_ / sizeof(T)); + CalTranspose(size, input_addr, input_shape, transpose_axis, shape_size_, transpose_input_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, input_descriptor_, transpose_input_addr, &beta, + output_descriptor_, transpose_output_addr), + "cudnnSoftmaxForward failed"); + CalTranspose(size, transpose_output_addr, transpose_shape, transpose_axis, shape_size_, output_addr, + reinterpret_cast(stream_ptr)); + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "SoftmaxGpuKernel input is null"; + InitSizeLists(); + return true; + } + shape_size_ = SizeToInt(input_shape.size()); + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == "LogSoftmax") { + algo_ = CUDNN_SOFTMAX_LOG; + auto axis = GetAttr(kernel_node, "axis"); + InitSizeByAxis(input_shape, axis); + } else { + algo_ = CUDNN_SOFTMAX_ACCURATE; + auto axis = GetAttr>(kernel_node, "axis"); + InitSizeByAxis(input_shape, axis[0]); + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), + SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), + "set input_descriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), + SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), + "set output_descriptor failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "create input_descriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "create output_descriptor failed"); + } + + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + workspace_size_list_.push_back(input_size_); + workspace_size_list_.push_back(output_size_); + workspace_size_list_.push_back(workspace_size_); + workspace_size_list_.push_back(workspace_size_); + workspace_size_list_.push_back(workspace_size_); + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "destroy output_descriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "destroy input_descriptor failed"); + } + + void InitSizeByAxis(const std::vector &input_shape, const int &axis) { + if (input_shape.size() == 2) { + InitSizeByAxis2D(input_shape, axis); + } else { + InitSizeByAxisLastDim(input_shape, axis); + } + } + + void InitSizeByAxis2D(const std::vector &input_shape, const int &axis) { + axis_ = axis; + if (axis_ < 0) { + axis_ += shape_size_; + } + if (axis_ == 1) { + batch_size_ = input_shape[0]; + channel_size_ = input_shape[1]; + } else if (axis_ == 0) { + batch_size_ = input_shape[1]; + channel_size_ = input_shape[0]; + input_shape_.push_back(input_shape[0]); + input_shape_.push_back(input_shape[1]); + transpose_shape_.push_back(input_shape[1]); + transpose_shape_.push_back(input_shape[0]); + transpose_axis_.push_back(1); + transpose_axis_.push_back(0); + } else { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; + } + + height_ = 1; + width_ = 1; + input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; + output_size_ = input_size_; + workspace_size_ = IntToSize(shape_size_) * sizeof(int); + } + + void InitSizeByAxisLastDim(const std::vector &input_shape, const int &axis) { + int axis_pos = axis; + if (axis_pos < 0) { + axis_pos += input_shape.size(); + } + // axis should be -1 with ND + if (axis_pos != SizeToInt(input_shape.size() - 1)) { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; + } + // squeeze to 2d, then invoke cudnn + size_t n = 1; + for (size_t i = 0; i < input_shape.size() - 1; i++) { + n *= input_shape[i]; + } + axis_ = 1; + batch_size_ = n; + channel_size_ = input_shape[axis_pos]; + height_ = 1; + width_ = 1; + input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; + output_size_ = input_size_; + input_shape_.push_back(batch_size_); + input_shape_.push_back(channel_size_); + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_descriptor_; + cudnnTensorDescriptor_t output_descriptor_; + cudnnSoftmaxAlgorithm_t algo_; + cudnnSoftmaxMode_t mode_; + cudnnDataType_t cudnn_data_type_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + std::vector input_shape_; + std::vector transpose_shape_; + std::vector transpose_axis_; + int axis_; + int shape_size_; + + size_t batch_size_; + size_t channel_size_; + size_t height_; + size_t width_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.cc new file mode 100644 index 0000000000..bd20413d08 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + LogSoftmaxGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SoftmaxGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + LogSoftmaxGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + SoftmaxGradGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h new file mode 100644 index 0000000000..b814be9969 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h @@ -0,0 +1,219 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SoftmaxGradGpuKernel : public GpuKernel { + public: + SoftmaxGradGpuKernel() + : cudnn_handle_(nullptr), + y_desc_(nullptr), + algo_(CUDNN_SOFTMAX_ACCURATE), + mode_(CUDNN_SOFTMAX_MODE_INSTANCE), + cudnn_data_type_(CUDNN_DATA_FLOAT), + is_null_input_(false), + input_size_(0), + output_size_(0), + workspace_size_(0), + axis_(0), + shape_size_(0), + batch_size_(0), + channel_size_(0), + height_(0), + width_(0) {} + ~SoftmaxGradGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *y_addr = GetDeviceAddress(inputs, 0); + T *dy_addr = GetDeviceAddress(inputs, 1); + T *dx_addr = GetDeviceAddress(outputs, 0); + + T *transpose_y_addr = GetDeviceAddress(workspace, 0); + T *transpose_dy_addr = GetDeviceAddress(workspace, 1); + T *transpose_dx_addr = GetDeviceAddress(workspace, 2); + int *input_shape = GetDeviceAddress(workspace, 3); + int *transpose_shape = GetDeviceAddress(workspace, 4); + int *transpose_axis = GetDeviceAddress(workspace, 5); + const float alpha = 1; + const float beta = 0; + + if (axis_ == 1) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr, y_desc_, + dy_addr, &beta, y_desc_, dx_addr), + "cudnnSoftmaxBackward failed"); + } else { + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_shape failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, + cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync input_axis failed"); + int size = SizeToInt(input_size_ / sizeof(T)); + CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr, + reinterpret_cast(stream_ptr)); + CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, transpose_y_addr, + y_desc_, transpose_dy_addr, &beta, y_desc_, transpose_dx_addr), + "cudnnSoftmaxBackward failed"); + CalTranspose(size, transpose_dx_addr, transpose_shape, transpose_axis, shape_size_, dx_addr, + reinterpret_cast(stream_ptr)); + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax grad needs 2 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax grad needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "SoftmaxGradGpuKernel input is null"; + InitSizeLists(); + return true; + } + shape_size_ = SizeToInt(input_shape.size()); + if (shape_size_ != 2) { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs."; + } + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == "LogSoftmaxGrad") { + algo_ = CUDNN_SOFTMAX_LOG; + auto axis = GetAttr(kernel_node, "axis"); + InitSizeByAxis(input_shape, axis); + } else { + algo_ = CUDNN_SOFTMAX_ACCURATE; + auto axis = GetAttr>(kernel_node, "axis"); + InitSizeByAxis(input_shape, axis[0]); + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), + SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), + "set input_descriptor failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "create input_descriptor failed"); + } + + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + workspace_size_list_.push_back(input_size_); + workspace_size_list_.push_back(input_size_); + workspace_size_list_.push_back(output_size_); + workspace_size_list_.push_back(workspace_size_); + workspace_size_list_.push_back(workspace_size_); + workspace_size_list_.push_back(workspace_size_); + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "destroy output_descriptor failed"); + } + + void InitSizeByAxis(const std::vector input_shape, const int axis) { + axis_ = axis; + if (axis_ < 0) { + axis_ += shape_size_; + } + if (axis_ == 1) { + batch_size_ = input_shape[0]; + channel_size_ = input_shape[1]; + } else if (axis_ == 0) { + batch_size_ = input_shape[1]; + channel_size_ = input_shape[0]; + input_shape_.push_back(input_shape[0]); + input_shape_.push_back(input_shape[1]); + transpose_shape_.push_back(input_shape[1]); + transpose_shape_.push_back(input_shape[0]); + transpose_axis_.push_back(1); + transpose_axis_.push_back(0); + } else { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; + } + + height_ = 1; + width_ = 1; + input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; + output_size_ = input_size_; + workspace_size_ = IntToSize(shape_size_) * sizeof(int); + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t y_desc_; + cudnnSoftmaxAlgorithm_t algo_; + cudnnSoftmaxMode_t mode_; + cudnnDataType_t cudnn_data_type_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + std::vector input_shape_; + std::vector transpose_shape_; + std::vector transpose_axis_; + int axis_; + int shape_size_; + + size_t batch_size_; + size_t channel_size_; + size_t height_; + size_t width_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc new file mode 100644 index 0000000000..81b46f520c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + SparseSoftmaxCrossEntropyWithLogits, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + SparseSoftmaxCrossEntropyWithLogitsGpuKernel, float, int) +MS_REG_GPU_KERNEL_TWO( + SparseSoftmaxCrossEntropyWithLogits, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + SparseSoftmaxCrossEntropyWithLogitsGpuKernel, float, int64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h new file mode 100644 index 0000000000..bcb8a6b333 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h @@ -0,0 +1,206 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class SparseSoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { + public: + SparseSoftmaxCrossEntropyWithLogitsGpuKernel() + : cudnn_handle_(nullptr), + logits_descriptor_(nullptr), + softmax_output_descriptor_(nullptr), + algo_(CUDNN_SOFTMAX_ACCURATE), + mode_(CUDNN_SOFTMAX_MODE_INSTANCE), + cudnn_data_type_(CUDNN_DATA_FLOAT), + is_grad_(false), + is_null_input_(false), + logits_size_(0), + labels_size_(0), + output_size_(0), + softmax_output_logits_size_(0), + batch_size_(0), + channel_size_(0), + height_(0), + width_(0) {} + ~SparseSoftmaxCrossEntropyWithLogitsGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *logits_addr = GetDeviceAddress(inputs, 0); + S *labels_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + T *softmax_output_logits = GetDeviceAddress(workspace, 0); + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, logits_descriptor_, logits_addr, &beta, + softmax_output_descriptor_, softmax_output_logits), + "cudnnSoftmaxForward failed."); + + is_grad_ ? CrossEntropyGradWithSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output_addr, + reinterpret_cast(stream_ptr)) + : CrossEntropyWithSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output_addr, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num + << ", but SparseSoftmaxCrossEntropyWithLogitsGpuKernel needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num + << ", but SparseSoftmaxCrossEntropyWithLogitsGpuKernel needs 1 output."; + return false; + } + is_grad_ = GetAttr(kernel_node, "is_grad"); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + + InferInputOutputSize(kernel_node); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + batch_size_, channel_size_, height_, width_), + "cudnnSetTensor4dDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(softmax_output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_size_, + channel_size_, height_, width_), + "cudnnSetTensor4dDescriptor failed."); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&logits_descriptor_), + "cudnnCreateTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&softmax_output_descriptor_), + "cudnnCreateTensorDescriptor failed."); + } + void InitSizeLists() override { + input_size_list_.push_back(logits_size_); + input_size_list_.push_back(labels_size_); + output_size_list_.push_back(output_size_); + workspace_size_list_.push_back(softmax_output_logits_size_); + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(softmax_output_descriptor_), + "cudnnDestroyTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(logits_descriptor_), + "cudnnDestroyTensorDescriptor failed."); + } + void InferInputOutputSize(const CNodePtr &kernel_node) { + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(logits_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input1 is null"; + InitSizeLists(); + return; + } + auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + is_null_input_ = CHECK_NULL_INPUT(logits_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input2 is null"; + InitSizeLists(); + return; + } + CheckShapeValidation(logits_shape, labels_shape); + + size_t logits_dims = logits_shape.size(); + batch_size_ = 1; + for (size_t i = 0; i < logits_dims - 1; i++) { + batch_size_ *= logits_shape[i]; + } + channel_size_ = logits_shape[logits_dims - 1]; + height_ = 1; + width_ = 1; + logits_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; + + labels_size_ = 1; + size_t labels_dims = labels_shape.size(); + for (size_t i = 0; i < labels_dims; i++) { + labels_size_ *= labels_shape[i]; + } + labels_size_ *= sizeof(S); + + output_size_ = is_grad_ ? logits_size_ : sizeof(T); + softmax_output_logits_size_ = logits_size_; + return; + } + void CheckShapeValidation(const std::vector &logits_shape, const std::vector &labels_shape) { + size_t logits_dim_length = logits_shape.size(); + size_t labels_dim_length = labels_shape.size(); + if (labels_dim_length != logits_dim_length - 1) { + MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length minus 1 for " + "SparseSoftmaxCrossEntropyWithLogits, " + "but got Labels shape length:" + << labels_dim_length << ", Logits shape length:" << logits_dim_length; + } + if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) { + MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last demension."; + } + return; + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t logits_descriptor_; + cudnnTensorDescriptor_t softmax_output_descriptor_; + cudnnSoftmaxAlgorithm_t algo_; + cudnnSoftmaxMode_t mode_; + cudnnDataType_t cudnn_data_type_; + bool is_grad_; + bool is_null_input_; + + size_t logits_size_; + size_t labels_size_; + size_t output_size_; + size_t softmax_output_logits_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + size_t batch_size_; + size_t channel_size_; + size_t height_; + size_t width_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc new file mode 100644 index 0000000000..4e07463a6c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/other/assign_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Assign, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AssignGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + Assign, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + AssignGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + AssignGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h new file mode 100644 index 0000000000..76e863393c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class AssignGpuKernel : public GpuKernel { + public: + AssignGpuKernel() : input_size_(0) {} + ~AssignGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *var = GetDeviceAddress(inputs, 0); + T *value = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(var, value, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), + "cudaMemxcpyAsync failed."); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(output, value, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), + "cudaMemxcpyAsync failed."); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (size_t x : shape) { + input_size_ = input_size_ * x; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but AssignGpuKernel needs 1 output."; + return false; + } + return true; + } + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.cc new file mode 100644 index 0000000000..92652f67f9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BatchNormFold2, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + BatchNormFold2GpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h new file mode 100644 index 0000000000..83600e20df --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BatchNormFold2GpuKernel : public GpuKernel { + public: + BatchNormFold2GpuKernel() + : cudnn_handle_(nullptr), + is_null_input_(false), + batch_size_(0), + channel_(0), + height_(0), + width_(0), + freeze_bn_(0) {} + + ~BatchNormFold2GpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + + auto *input = GetDeviceAddress(inputs, 0); + auto *beta = GetDeviceAddress(inputs, 1); + auto *gamma = GetDeviceAddress(inputs, 2); + auto *batch_std = GetDeviceAddress(inputs, 3); + auto *batch_mean = GetDeviceAddress(inputs, 4); + auto *running_std = GetDeviceAddress(inputs, 5); + auto *running_mean = GetDeviceAddress(inputs, 6); + auto *global_step = GetDeviceAddress(inputs, 7); + auto *output = GetDeviceAddress(outputs, 0); + + BatchNormFold2Forward(input, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, output, + freeze_bn_, batch_size_, channel_, height_, width_, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 8) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GpuKernel needs 8."; + return false; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "BatchNormFold2GpuKernel input is null"; + InitSizeLists(); + return true; + } + + if (input_shape.size() != 4) { + MS_LOG(ERROR) << "BatchNormFold2GpuKernel input shape needs (N,C,H,W)."; + return false; + } + batch_size_ = input_shape[0]; + channel_ = input_shape[1]; + height_ = input_shape[2]; + width_ = input_shape[3]; + freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + + void InitSizeLists() override { + size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); + size_t weight_size = channel_ * sizeof(T); + input_size_list_.push_back(input_size); + input_size_list_.push_back(weight_size); // beta + input_size_list_.push_back(weight_size); // gamma + input_size_list_.push_back(weight_size); // batch_std + input_size_list_.push_back(weight_size); // batch_mean + input_size_list_.push_back(weight_size); // running_std + input_size_list_.push_back(weight_size); // running_mean + input_size_list_.push_back(sizeof(int32_t)); // global_step + output_size_list_.push_back(input_size); + } + + private: + void DestroyResource() noexcept {} + + cudnnHandle_t cudnn_handle_; + bool is_null_input_; + size_t batch_size_; + size_t channel_; + size_t height_; + size_t width_; + size_t freeze_bn_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc new file mode 100644 index 0000000000..6fc080713a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BatchNormFold2Grad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BatchNormFold2GradGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h new file mode 100644 index 0000000000..3335210925 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h @@ -0,0 +1,168 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BatchNormFold2GradGpuKernel : public GpuKernel { + public: + BatchNormFold2GradGpuKernel() + : cudnn_handle_(nullptr), + is_null_input_(false), + batch_size_(0), + channel_(0), + height_(0), + width_(0), + freeze_bn_(0) {} + + ~BatchNormFold2GradGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + + auto *dout = GetDeviceAddress(inputs, 0); + auto *x = GetDeviceAddress(inputs, 1); + auto *gamma = GetDeviceAddress(inputs, 2); + auto *batch_std = GetDeviceAddress(inputs, 3); + auto *batch_mean = GetDeviceAddress(inputs, 4); + auto *running_std = GetDeviceAddress(inputs, 5); + auto *running_mean = GetDeviceAddress(inputs, 6); + auto *global_step = GetDeviceAddress(inputs, 7); + auto *d_batch_std = GetDeviceAddress(outputs, 0); + auto *d_batch_mean = GetDeviceAddress(outputs, 1); + auto *d_beta = GetDeviceAddress(outputs, 2); + auto *d_gamma = GetDeviceAddress(outputs, 3); + auto *d_x = GetDeviceAddress(outputs, 4); + auto *tmp = GetDeviceAddress(workspace, 0); + auto *tmp2 = GetDeviceAddress(workspace, 1); + auto *reduce_x = GetDeviceAddress(workspace, 2); + auto *tmp_x = GetDeviceAddress(workspace, 3); + + int32_t current_step_host[1]; + size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), + "Failed to copy gpu memory."); + CHECK_CUDA_RET_WITH_ERROR( + cudaMemcpyAsync(d_x, dout, x_size, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), + "Failed to copy gpu memory."); + + BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_, + reinterpret_cast(stream_ptr)); + if (current_step_host[0] < freeze_bn_) { + CalBatchNormFold2GradNotFreezeDxMul(batch_std, running_std, d_x, batch_size_, channel_, height_, width_, + reinterpret_cast(stream_ptr)); + CalBatchNormFold2GradNotFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, + d_batch_mean, d_batch_std, channel_, reinterpret_cast(stream_ptr)); + } else { + CalBatchNormFold2GradFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, + d_batch_mean, d_batch_std, channel_, reinterpret_cast(stream_ptr)); + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 8) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GradGpuKernel needs 8."; + return false; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "BatchNormFold2GradGpuKernel input is null"; + InitSizeLists(); + return true; + } + + if (input_shape.size() != 4) { + MS_LOG(ERROR) << "BatchNormFold2GradGpuKernel input shape needs (N,C,H,W)."; + return false; + } + batch_size_ = input_shape[0]; + channel_ = input_shape[1]; + height_ = input_shape[2]; + width_ = input_shape[3]; + freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + + void InitSizeLists() override { + size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); + size_t weight_size = channel_ * sizeof(T); + size_t workspace_size = batch_size_ * channel_ * sizeof(T); + input_size_list_.push_back(input_size); // dout + input_size_list_.push_back(input_size); // x + input_size_list_.push_back(weight_size); // gamma + input_size_list_.push_back(weight_size); // batch_std + input_size_list_.push_back(weight_size); // batch_mean + input_size_list_.push_back(weight_size); // running_std + input_size_list_.push_back(weight_size); // running_mean + input_size_list_.push_back(sizeof(int32_t)); // global_step + + output_size_list_.push_back(weight_size); // d_batch_std + output_size_list_.push_back(weight_size); // d_batch_mean + output_size_list_.push_back(weight_size); // d_beta + output_size_list_.push_back(weight_size); // d_gamma + output_size_list_.push_back(input_size); // d_x + + workspace_size_list_.push_back(workspace_size); // tmp + workspace_size_list_.push_back(workspace_size); // tmp2 + workspace_size_list_.push_back(weight_size); // reduce_x + workspace_size_list_.push_back(input_size); // tmp_x + } + + private: + void DestroyResource() noexcept {} + + cudnnHandle_t cudnn_handle_; + bool is_null_input_; + size_t batch_size_; + size_t channel_; + size_t height_; + size_t width_; + int32_t freeze_bn_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.cc new file mode 100644 index 0000000000..95349c84aa --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BatchNormFold, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BatchNormFoldGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h new file mode 100644 index 0000000000..11b150686c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h @@ -0,0 +1,209 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BatchNormFoldGpuKernel : public GpuKernel { + public: + BatchNormFoldGpuKernel() + : input_size_(0), + output_size_(0), + exp_avg_factor_(0.9), + epsilon_(1e-12), + is_training_(true), + freeze_bn_(0), + batch_(0), + channel_(0), + height_(0), + width_(0), + mode_(CUDNN_BATCHNORM_SPATIAL), + x_desc_(nullptr), + scale_bias_mean_var_desc_(nullptr), + handle_(nullptr) {} + + ~BatchNormFoldGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + (void)workspace; + auto x = GetDeviceAddress(inputs, 0); + auto mean = GetDeviceAddress(inputs, 1); + auto variance = GetDeviceAddress(inputs, 2); + int *current_step = GetDeviceAddress(inputs, 3); + int current_step_host[1]; + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), + "Copy gpu memoy failed."); + if (x == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null."; + return false; + } + if (mean == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGpuKernel mean is null."; + return false; + } + if (variance == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGpuKernel variance is null."; + return false; + } + if (current_step == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null."; + return false; + } + auto batch_mean = GetDeviceAddress(outputs, 0); + auto batch_std = GetDeviceAddress(outputs, 1); + auto running_mean = GetDeviceAddress(outputs, 2); + auto running_std = GetDeviceAddress(outputs, 3); + auto y = GetDeviceAddress(workspace, 0); + + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Failed to copy gpu memory."); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_std, variance, output_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Failed to copy gpu memory."); + CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast(stream_ptr)); + if (!is_training_ || current_step_host[0] >= freeze_bn_) { + CHECK_CUDA_RET_WITH_ERROR(cudaMemset(batch_mean, 0, output_size_), "Failed to set gpu memory."); + ThrustFillWith(batch_std, channel_, 1.f, reinterpret_cast(stream_ptr)); + return true; + } + const T alpha = 1; + const T beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardTraining( + handle_, mode_, &alpha, &beta, x_desc_, x, x_desc_, y, scale_bias_mean_var_desc_, + mean, mean, exp_avg_factor_, mean, variance, epsilon_, batch_mean, batch_std), + "Failed to launch kernel.") + CalUpdateBatchStd(channel_, batch_std, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 4) { + MS_LOG(ERROR) << "Input number is " << input_num << " but BatchNormFold GpuKernel OP needs 4 input."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 4) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFold GpuKernel OP needs 4 output."; + return false; + } + + T momentum = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("momentum")); + exp_avg_factor_ = 1.0 - momentum; + epsilon_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon")); + is_training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training")); + freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() != 4) { + MS_LOG(ERROR) << "Input shape is " << input_shape.size() + << ", but BatchNormFold GpuKernel OP needs 4DTensor input."; + return false; + } + batch_ = input_shape[0]; + channel_ = input_shape[1]; + height_ = input_shape[2]; + width_ = input_shape[3]; + + input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; + output_size_ = sizeof(T) * channel_; + + cudnnDataType_t cudnnDataType = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, batch_, channel_, height_, width_), + "Set x desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, 1, channel_, 1, 1), + "Set para desc failed"); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + // x, mean, variance, current_step + input_size_list_.push_back(input_size_); + input_size_list_.push_back(output_size_); + input_size_list_.push_back(output_size_); + input_size_list_.push_back(sizeof(int)); + + // batch_mean, batch_std, running_mean, running_std + output_size_list_.push_back(output_size_); + output_size_list_.push_back(output_size_); + output_size_list_.push_back(output_size_); + output_size_list_.push_back(output_size_); + + // store y + workspace_size_list_.push_back(input_size_); + } + + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed"); + } + + size_t input_size_; + size_t output_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + double exp_avg_factor_; + double epsilon_; + bool is_training_; + int freeze_bn_; + int batch_; + int channel_; + int height_; + int width_; + + cudnnBatchNormMode_t mode_; + cudnnTensorDescriptor_t x_desc_; + cudnnTensorDescriptor_t scale_bias_mean_var_desc_; + + cudnnHandle_t handle_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc new file mode 100644 index 0000000000..b727c6c7df --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BatchNormFoldGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + BatchNormFoldGradGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h new file mode 100644 index 0000000000..93a3cbf46e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h @@ -0,0 +1,166 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class BatchNormFoldGradGpuKernel : public GpuKernel { + public: + BatchNormFoldGradGpuKernel() + : input_size_(0), + channel_size_(0), + workspace_size_(0), + momentum_(0.1), + epsilon_(1e-12), + is_training_(true), + freeze_bn_(0), + current_step_(0), + batch_(0), + channel_(0), + height_(0), + width_(0) {} + ~BatchNormFoldGradGpuKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' + T *d_batch_mean = GetDeviceAddress(inputs, 0); + T *d_batch_std = GetDeviceAddress(inputs, 1); + T *x = GetDeviceAddress(inputs, 2); + T *batch_mean = GetDeviceAddress(inputs, 3); + T *batch_std = GetDeviceAddress(inputs, 4); + int *current_step = GetDeviceAddress(inputs, 5); + int current_step_host[1]; + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), + "Copy gpu memoy failed."); + if (d_batch_mean == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null."; + return false; + } + if (d_batch_std == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_std is null."; + return false; + } + if (x == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel x is null."; + return false; + } + if (batch_mean == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_mean is null."; + return false; + } + if (batch_std == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_std is null."; + return false; + } + if (current_step == nullptr) { + MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null."; + return false; + } + T *dx = GetDeviceAddress(outputs, 0); + + if (!is_training_ || current_step_host[0] >= freeze_bn_) { + ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast(stream_ptr)); + return true; + } + CalBatchNormFoldGrad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_, channel_, height_, width_, dx, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 6) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFoldGrad GpuKernel OP needs 4 output."; + return false; + } + + epsilon_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon")); + is_training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training")); + freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + if (input_shape.size() != 4) { + MS_LOG(ERROR) << "Input shape is " << input_shape.size() + << ", but BatchNormFoldGrad GpuKernel OP needs 4DTensor input."; + return false; + } + batch_ = input_shape[0]; + channel_ = input_shape[1]; + height_ = input_shape[2]; + width_ = input_shape[3]; + + input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; + channel_size_ = sizeof(T) * channel_; + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' + input_size_list_.push_back(channel_size_); + input_size_list_.push_back(channel_size_); + input_size_list_.push_back(input_size_); + input_size_list_.push_back(channel_size_); + input_size_list_.push_back(channel_size_); + input_size_list_.push_back(sizeof(int)); + // 'dx' + output_size_list_.push_back(input_size_); + } + + private: + size_t input_size_; + size_t channel_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + T momentum_; + T epsilon_; + bool is_training_; + int freeze_bn_; + int current_step_; + int batch_; + int channel_; + int height_; + int width_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.cc new file mode 100644 index 0000000000..9af5451c53 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020、 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(CorrectionMul, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CorrectionMulGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h new file mode 100644 index 0000000000..4ba6285e4b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class CorrectionMulGpuKernel : public GpuKernel { + public: + CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} + ~CorrectionMulGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + auto *weight = GetDeviceAddress(inputs, 0); + auto *gamma = GetDeviceAddress(inputs, 1); + auto *running_std = GetDeviceAddress(inputs, 2); + auto *output = GetDeviceAddress(outputs, 0); + + CalCorrectionMul(weight, gamma, running_std, batch_size_, channel_, height_, width_, output, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGpuKernel needs 3."; + return false; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() != 4) { + MS_LOG(ERROR) << "CorrectionMulGpuKernel input shape needs (N,C,H,W)."; + return false; + } + batch_size_ = input_shape[0]; + channel_ = input_shape[1]; + height_ = input_shape[2]; + width_ = input_shape[3]; + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); + size_t weight_size = batch_size_ * sizeof(T); + input_size_list_.push_back(input_size); // weight + input_size_list_.push_back(weight_size); // gamma + input_size_list_.push_back(weight_size); // running_std + output_size_list_.push_back(input_size); + } + + void InitResource() override {} + + private: + void DestroyResource() noexcept {} + + size_t batch_size_; + size_t channel_; + size_t height_; + size_t width_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.cc new file mode 100644 index 0000000000..63a47bc452 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(CorrectionMulGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CorrectionMulGradGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h new file mode 100644 index 0000000000..b9fcbf0787 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class CorrectionMulGradGpuKernel : public GpuKernel { + public: + CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} + ~CorrectionMulGradGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + auto *d_out = GetDeviceAddress(inputs, 0); + auto *weight = GetDeviceAddress(inputs, 1); + auto *gamma = GetDeviceAddress(inputs, 2); + auto *running_std = GetDeviceAddress(inputs, 3); + auto *d_weight = GetDeviceAddress(outputs, 0); + auto *d_gamma = GetDeviceAddress(outputs, 1); + auto *tmp = GetDeviceAddress(workspace, 0); + + CalCorrectionMul(d_out, gamma, running_std, batch_size_, channel_, height_, width_, d_weight, + reinterpret_cast(stream_ptr)); + CalCorrectionMulGrad(d_out, weight, running_std, batch_size_, channel_, height_, width_, d_gamma, tmp, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 4) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGradGpuKernel needs 4."; + return false; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() != 4) { + MS_LOG(ERROR) << "CorrectionMulGradGpuKernel input shape needs (N,C,H,W)."; + return false; + } + batch_size_ = input_shape[0]; + channel_ = input_shape[1]; + height_ = input_shape[2]; + width_ = input_shape[3]; + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); + size_t weight_size = batch_size_ * sizeof(T); + input_size_list_.push_back(input_size); // d_out + input_size_list_.push_back(input_size); // weight + input_size_list_.push_back(weight_size); // gamma + input_size_list_.push_back(weight_size); // running_std + output_size_list_.push_back(input_size); // d_weight + output_size_list_.push_back(weight_size); // d_gamma + workspace_size_list_.push_back(input_size); // tmp d_out * weight + } + void InitResource() override {} + + private: + void DestroyResource() noexcept {} + + size_t batch_size_; + size_t channel_; + size_t height_; + size_t width_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.cc new file mode 100644 index 0000000000..8a43ce0941 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.cc @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh" +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel() + : input_size_(0), + num_channels_(0), + num_bits_(0), + training_(false), + symmetric_(false), + narrow_range_(false), + quant_delay_(0), + quant_min_(0), + quant_max_(0), + global_step_(0) {} + +const std::vector &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &FakeQuantPerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &FakeQuantPerChannelGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 input."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << " but FakeQuant GpuKernel OP needs 1 output."; + return false; + } + + // get attribute + num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); + training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); + symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); + narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); + + if (num_bits_ <= 2 || num_bits_ >= 16) { + MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16."; + return false; + } + + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0."; + return false; + } + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + + // shape info for gpu + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + num_channels_ = SizeToInt(input_shape[0]); + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void FakeQuantPerChannelGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // input in tensor + input_size_list_.push_back(sizeof(float) * num_channels_); // min one scalar + input_size_list_.push_back(sizeof(float) * num_channels_); // max on scalar + output_size_list_.push_back(input_size_); // output in tensor + workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel +} + +void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, + float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) { + CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, + symmetric_, reinterpret_cast(stream_ptr)); + CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale, + reinterpret_cast(stream_ptr)); +} + +bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + (void)workspace; + float *output = GetDeviceAddress(outputs, 0); + float *input = GetDeviceAddress(inputs, 0); + float *input_min = GetDeviceAddress(inputs, 1); + float *input_max = GetDeviceAddress(inputs, 2); + float *scale = GetDeviceAddress(workspace, 0); + float *nudge_min = GetDeviceAddress(workspace, 1); + float *nudge_max = GetDeviceAddress(workspace, 2); + + if (input == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null."; + } + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null."; + } + + if (training_) { + if (global_step_ >= quant_delay_) { + CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); + } else { + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed."); + } + global_step_++; + } else { + CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); + } + + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h new file mode 100755 index 0000000000..8e2c9524b2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class FakeQuantPerChannelGpuKernel : public GpuKernel { + public: + FakeQuantPerChannelGpuKernel(); + ~FakeQuantPerChannelGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel) override; + + protected: + void InitSizeLists() override; + + private: + void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min, + float *nudge_max, float *scale, void *stream_ptr); + + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int num_channels_; + int num_bits_; + bool training_; + bool symmetric_; + bool narrow_range_; + int quant_delay_; + float quant_min_; + float quant_max_; + int global_step_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc new file mode 100644 index 0000000000..598a6a960d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh" + +namespace mindspore { +namespace kernel { +FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel() + : input_size_(0), + num_bits_(0), + quant_min_(0), + quant_max_(0), + num_channels_(0), + quant_delay_(0), + global_step_(0), + narrow_range_(false), + symmetric_(false) {} + +const std::vector &FakeQuantPerChannelGradGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &FakeQuantPerChannelGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &FakeQuantPerChannelGradGpuKernel::GetWorkspaceSizeList() const { + return workspace_size_list_; +} + +bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 4) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output."; + } + + num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); + if (num_bits_ <= 2 || num_bits_ >= 16) { + MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; + } + + quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; + } + + symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); + narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + num_channels_ = SizeToInt(input_shape[0]); + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void FakeQuantPerChannelGradGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // gradient + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float) * num_channels_); // min + input_size_list_.push_back(sizeof(float) * num_channels_); // max + output_size_list_.push_back(input_size_); // output + workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel + workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel +} + +bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + (void)workspace; + float *output = GetDeviceAddress(outputs, 0); + float *gradient = GetDeviceAddress(inputs, 0); + float *input = GetDeviceAddress(inputs, 1); + float *input_min = GetDeviceAddress(inputs, 2); + float *input_max = GetDeviceAddress(inputs, 3); + float *scale = GetDeviceAddress(workspace, 0); + float *nudge_min = GetDeviceAddress(workspace, 1); + float *nudge_max = GetDeviceAddress(workspace, 2); + + if (gradient == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null"; + } + if (input == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input is null"; + } + if (input_min == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input min is null"; + } + if (input_max == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input max is null"; + } + + int total_size = input_size_ / sizeof(float); + if (global_step_ >= quant_delay_) { + CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, + symmetric_, reinterpret_cast(stream_ptr)); + CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, + reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed."); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h new file mode 100644 index 0000000000..c2611ab8a2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class FakeQuantPerChannelGradGpuKernel : public GpuKernel { + public: + FakeQuantPerChannelGradGpuKernel(); + ~FakeQuantPerChannelGradGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel_node) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int num_bits_; + float quant_min_; + float quant_max_; + int num_channels_; + int quant_delay_; + int global_step_; + bool narrow_range_; + bool symmetric_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.cc new file mode 100644 index 0000000000..24edec97a9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh" +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel() + : input_size_(0), + quant_min_(0), + quant_max_(0), + quant_num_(1), + global_step_(0), + num_bits_(0), + quant_delay_(0), + training_(false), + narrow_range_(false), + symmetric_(false) {} + +const std::vector &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; + } + + num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); + quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); + training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); + symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); + narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + + if (num_bits_ <= 2 || num_bits_ >= 16) { + MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; + } + + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0."; + } + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void FakeQuantPerLayerGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // x + input_size_list_.push_back(sizeof(float)); // min + input_size_list_.push_back(sizeof(float)); // max + output_size_list_.push_back(input_size_); // y + workspace_size_list_.push_back(sizeof(float)); // scale + workspace_size_list_.push_back(sizeof(float)); // nudge_min + workspace_size_list_.push_back(sizeof(float)); // nudge_max +} + +bool FakeQuantPerLayerGpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + float *output = GetDeviceAddress(outputs, 0); + float *input = GetDeviceAddress(inputs, 0); + float *input_min = GetDeviceAddress(inputs, 1); + float *input_max = GetDeviceAddress(inputs, 2); + float *scale = GetDeviceAddress(workspace, 0); + float *nudge_min = GetDeviceAddress(workspace, 1); + float *nudge_max = GetDeviceAddress(workspace, 2); + + if (input == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null."; + } + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null."; + } + + if (training_) { + // control flow for quant_delay + if (global_step_ >= quant_delay_) { + // real launch + CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, + reinterpret_cast(stream_ptr)); + CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, + reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + } else { + // real launch + CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, + reinterpret_cast(stream_ptr)); + CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, + reinterpret_cast(stream_ptr)); + } + + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h new file mode 100755 index 0000000000..6df4da3104 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class FakeQuantPerLayerGpuKernel : public GpuKernel { + public: + FakeQuantPerLayerGpuKernel(); + ~FakeQuantPerLayerGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + float quant_min_; + float quant_max_; + int quant_num_; + int global_step_; + int num_bits_; + int quant_delay_; + bool training_; + bool narrow_range_; + bool symmetric_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc new file mode 100644 index 0000000000..f96b6a48d2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh" + +namespace mindspore { +namespace kernel { +FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel() + : input_size_(0), + workspace_size_(0), + num_bits_(0), + quant_min_(0), + quant_max_(0), + quant_num_(1), + quant_delay_(0), + global_step_(0), + narrow_range_(false), + symmetric_(false) {} + +const std::vector &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 4) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output."; + } + + num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); + if (num_bits_ <= 2 || num_bits_ >= 16) { + MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; + } + + quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); + if (quant_delay_ < 0) { + MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; + } + + symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); + narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); + + // quant min and max value + quant_min_ = 0; + quant_max_ = (1 << num_bits_) - 1; + if (narrow_range_) { + quant_min_++; + } + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void FakeQuantPerLayerGradGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // gradient + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float)); // min + input_size_list_.push_back(sizeof(float)); // max + output_size_list_.push_back(input_size_); // output + workspace_size_list_.push_back(sizeof(float)); // scale + workspace_size_list_.push_back(sizeof(float)); // nudge_min + workspace_size_list_.push_back(sizeof(float)); // nudge_max +} + +bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + float *output = GetDeviceAddress(outputs, 0); + float *gradient = GetDeviceAddress(inputs, 0); + float *input = GetDeviceAddress(inputs, 1); + float *input_min = GetDeviceAddress(inputs, 2); + float *input_max = GetDeviceAddress(inputs, 3); + float *scale = GetDeviceAddress(workspace, 0); + float *nudge_min = GetDeviceAddress(workspace, 1); + float *nudge_max = GetDeviceAddress(workspace, 2); + + if (gradient == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null"; + } + if (input == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null."; + } + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null."; + } + + if (global_step_ >= quant_delay_) { + CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, + reinterpret_cast(stream_ptr)); + CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, + reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h new file mode 100644 index 0000000000..475723f684 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class FakeQuantPerLayerGradGpuKernel : public GpuKernel { + public: + FakeQuantPerLayerGradGpuKernel(); + ~FakeQuantPerLayerGradGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel_node) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int num_bits_; + float quant_min_; + float quant_max_; + int quant_num_; + int quant_delay_; + int global_step_; + bool narrow_range_; + bool symmetric_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.cc new file mode 100644 index 0000000000..742a9b8c55 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh" +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel() + : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {} + +const std::vector &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const { + return workspace_size_list_; +} + +bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; + } + + ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); + ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + num_channels_ = SizeToInt(input_shape[0]); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float) * num_channels_); // min + input_size_list_.push_back(sizeof(float) * num_channels_); // max + output_size_list_.push_back(sizeof(float) * num_channels_); // output min + output_size_list_.push_back(sizeof(float) * num_channels_); // output max +} + +bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + float *output_min = GetDeviceAddress(outputs, 0); + float *output_max = GetDeviceAddress(outputs, 1); + float *input = GetDeviceAddress(inputs, 0); + float *input_min = GetDeviceAddress(inputs, 1); + float *input_max = GetDeviceAddress(inputs, 2); + + if (input == nullptr) { + MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null."; + } + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null."; + } + + // calculate the input min and max according by the parameter ema and ema_decay. + CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_, + ema_decay_, ema_, reinterpret_cast(stream_ptr)); + return true; +} + +MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h new file mode 100644 index 0000000000..9a0fe23e6a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class MinMaxUpdatePerChannelGpuKernel : public GpuKernel { + public: + MinMaxUpdatePerChannelGpuKernel(); + ~MinMaxUpdatePerChannelGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int quant_num_; + bool ema_; + float ema_decay_; + int num_channels_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.cc new file mode 100644 index 0000000000..8f11e907e1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh" +#include +#include +#include +#include + +namespace mindspore { +namespace kernel { +MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel() + : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {} + +const std::vector &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 2) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; + } + + ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); + ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); + + // init size + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); ++i) { + quant_num_ *= SizeToInt(input_shape[i]); + } + input_size_ = sizeof(float); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; +} + +void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() { + input_size_list_.push_back(input_size_); // input + input_size_list_.push_back(sizeof(float)); // input min + input_size_list_.push_back(sizeof(float)); // input max + output_size_list_.push_back(sizeof(float)); // output min + output_size_list_.push_back(sizeof(float)); // output max +} + +bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + float *output_min = GetDeviceAddress(outputs, 0); + float *output_max = GetDeviceAddress(outputs, 1); + float *input = GetDeviceAddress(inputs, 0); + float *input_min = GetDeviceAddress(inputs, 1); + float *input_max = GetDeviceAddress(inputs, 2); + + if (input == nullptr) { + MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null."; + } + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null."; + } + + CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, + reinterpret_cast(stream_ptr)); + + return true; +} + +MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h new file mode 100644 index 0000000000..80ce6185c0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class MinMaxUpdatePerLayerGpuKernel : public GpuKernel { + public: + MinMaxUpdatePerLayerGpuKernel(); + ~MinMaxUpdatePerLayerGpuKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel) override; + + protected: + void InitSizeLists() override; + + private: + size_t input_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int quant_num_; + bool ema_; + float ema_decay_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc new file mode 100644 index 0000000000..5ec4f52574 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hccl_kernel.h" +#include "runtime/device/ascend/tasksink/runtime_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" + +using HcclTaskInfoPtr = std::shared_ptr; +using ge::model_runner::HcclTaskInfo; +using mindspore::device::ascend::tasksink::RuntimeUtils; + +namespace mindspore { +namespace kernel { +void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { + hcclKernelMap_.emplace(name, std::move(fun)); +} + +std::shared_ptr HcclKernelFactory::Get(const std::string &name) { + const auto &map = Get().hcclKernelMap_; + auto it = map.find(name); + if (it != map.end() && it->second) { + return (it->second)(); + } + return nullptr; +} + +HcclKernelFactory &HcclKernelFactory::Get() { + static HcclKernelFactory _this; + return _this; +} + +HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REP_OP_SUM), root_id_(0), anf_node_(nullptr) {} + +HcclKernel::~HcclKernel() { + hccl_kernel_input_shape_list_.clear(); + hccl_kernel_output_shape_list_.clear(); + hccl_data_type_list_.clear(); + hccl_count_ = 0; + op_type_ = HCCL_REP_OP_SUM; + root_id_ = 0; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + anf_node_ = nullptr; +} + +bool HcclKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + op_name_ = AnfAlgo::GetCNodeName(anf_node); + + if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) { + MS_LOG(ERROR) << "GetKernelInputShape fail!"; + return false; + } + if (!HcomUtil::GetKernelOutputShape(anf_node, &hccl_kernel_output_shape_list_)) { + MS_LOG(ERROR) << "GetKernelOutputShape fail!"; + return false; + } + if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) { + MS_LOG(ERROR) << "GetHcomDataType fail!"; + return false; + } + if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) { + MS_LOG(ERROR) << "GetHcomCount fail!"; + return false; + } + if (op_name_ == kAllReduce || op_name_ == kReduceScatter) { + if (!HcomUtil::GetHcomOperationType(anf_node, &op_type_)) { + MS_LOG(ERROR) << "GetHcomOperationType fail!"; + return false; + } + } + if (op_name_ == kBroadcast) { + if (!HcomUtil::GetHcomRootId(anf_node, &root_id_)) { + MS_LOG(ERROR) << "GetHcomRootId fail!"; + return false; + } + } + HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_)); + anf_node_ = anf_node; + return true; +} + +const std::vector &HcclKernel::GetInputSizeList() const { + size_t size = 0; + if (!input_size_list_.empty()) { + return input_size_list_; + } + for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) { + if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_input_shape_list_[i], &size)) { + MS_LOG(ERROR) << "GetHcclOpInputSize failed"; + } + input_size_list_.push_back(size); + } + return input_size_list_; +} + +const std::vector &HcclKernel::GetOutputSizeList() const { + size_t size = 0; + if (!output_size_list_.empty()) { + return output_size_list_; + } + for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) { + if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) { + MS_LOG(ERROR) << "GetHcclOpOutputSize failed"; + } + output_size_list_.push_back(size); + } + return output_size_list_; +} + +const std::vector &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } + +std::vector HcclKernel::GenTask(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "Inputs or outputs is empty"; + } + stream_id_ = stream_id; + std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); + MS_EXCEPTION_IF_NULL(inputs.at(0)); + auto input_data_addr = inputs.at(0)->addr; + MS_EXCEPTION_IF_NULL(outputs.at(0)); + auto output_data_addr = outputs.at(0)->addr; + void *workspace_address = nullptr; + const int64_t workspace_num = 0; + std::vector private_def; + hcclDataType_t data_type = hccl_data_type_list_[0]; + + MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_ + << ", root_id=" << root_id_ << ", op_type=" << static_cast(op_type_) + << ", data_type=" << static_cast(data_type); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + HcclTaskInfoPtr task_info_ptr = std::make_shared( + kernel_name_, stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, + private_def, nullptr, hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, + RuntimeUtils::HcomUnbindModel, RuntimeUtils::HcomDistribute, NeedDump()); + MS_EXCEPTION_IF_NULL(task_info_ptr); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h new file mode 100644 index 0000000000..db7a0fbf7c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h @@ -0,0 +1,95 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "backend/kernel_compiler/hccl/hcom_util.h" +#include "hccl/hcom.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +class HcclKernel : public AscendKernelMod { + public: + HcclKernel(); + ~HcclKernel() override; + virtual bool Init(const AnfNodePtr &anf_node); + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + protected: + std::vector> hccl_kernel_input_shape_list_; + std::vector> hccl_kernel_output_shape_list_; + std::vector hccl_data_type_list_; + std::vector hccl_format_list_; + uint64_t hccl_count_; + hcclRedOp_t op_type_; + uint32_t root_id_; + mutable std::vector input_size_list_; + mutable std::vector output_size_list_; + mutable std::vector workspace_size_list_; + AnfNodePtr anf_node_; + std::string op_name_; + std::string group_; +}; + +using HcclKernelCreater = std::function()>; + +class HcclKernelFactory { + HcclKernelFactory() = default; + ~HcclKernelFactory() = default; + + public: + static HcclKernelFactory &Get(); + void Registe(const string &name, HcclKernelCreater &&fun); + static std::shared_ptr Get(const string &name); + + private: + std::map hcclKernelMap_; +}; + +class _HcclKernelRegister { + public: + _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) { + HcclKernelFactory::Get().Registe(name, std::move(fun)); + } + ~_HcclKernelRegister() = default; +}; + +#define _MS_HCCL_REG_KERNEL_REG(KNAME, clazz) \ + static_assert(std::is_base_of::value, " must be base of HcclKernel"); \ + static const _HcclKernelRegister g_##KNAME##_##_kernel_reg(#KNAME, []() { \ + std::shared_ptr ptr = nullptr; \ + ptr = std::make_shared(); \ + MS_EXCEPTION_IF_NULL(ptr); \ + return ptr; \ + }); + +#define MS_HCCL_REG_KERNEL(KNAME, clazz) _MS_HCCL_REG_KERNEL_REG(KNAME, clazz) +} // namespace kernel +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.cc new file mode 100644 index 0000000000..8297be0b6d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hccl_kernel_build.h" + +#include +#include +#include + +#include "backend/kernel_compiler/hccl/hccl_kernel.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +KernelModPtr HcclOpBuild(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string opname = AnfAlgo::GetCNodeName(anf_node); + MS_LOG(INFO) << "Hccl op [" << opname << "]"; + auto kerPtr = HcclKernelFactory::Get(opname); + if (kerPtr == nullptr) { + MS_LOG(ERROR) << "Hccl can't find Kernel[" << opname << "]"; + return nullptr; + } + if (!kerPtr->Init(anf_node)) { + MS_LOG(ERROR) << "Kernel initialize failed!"; + return nullptr; + } + return kerPtr; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h new file mode 100644 index 0000000000..21b34d6522 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_ + +#include +#include +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +KernelModPtr HcclOpBuild(const AnfNodePtr &anf_node); +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc new file mode 100755 index 0000000000..55742d383c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h" +#include +#include +#include "utils/utils.h" +#include "backend/kernel_compiler/hccl/hcom_util.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +namespace { +std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { + const std::set kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; + auto op_name = AnfAlgo::GetCNodeName(kernel_node); + auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); + if (op_name != kReduceScatter && op_name != kAllGatherOpName) { + return format; + } + if (format == kOpFormat_FRAC_NZ && AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= 2) { + return kOpFormat_DEFAULT; + } + if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) { + return kOpFormat_DEFAULT; + } + return format; +} +} // namespace +void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + const std::vector kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, + kNumberTypeFloat32, kNumberTypeInt16}; + MS_EXCEPTION_IF_NULL(kernel_info_list); + MS_EXCEPTION_IF_NULL(kernel_node); + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter) { + MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]"; + return; + } + for (const auto &type : kHcclSupportTypes) { + std::vector inputs_format{}; + std::vector inputs_type{}; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index)); + inputs_type.push_back(type); + } + std::vector outputs_format; + std::vector outputs_type; + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index)); + outputs_type.push_back(type); + } + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetKernelType(HCCL_KERNEL); + kernel_info_list->push_back(builder.Build()); + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h new file mode 100755 index 0000000000..25891fdaf6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ +#include +#include +#include +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc new file mode 100644 index 0000000000..e9fb4c9314 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hcom_all_broadcast.h" + +#include +#include +#include + +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +bool HcomAllBroadCastKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector & /*outputs*/, void *stream_ptr) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_task_sink()) { + return true; + } + if (inputs.empty() || hccl_data_type_list_.empty()) { + MS_LOG(ERROR) << "BroadCast param is empty"; + return false; + } + const char *tag = "Hccl-BroadCast"; + MS_EXCEPTION_IF_NULL(inputs[0]); + hcclResult_t ret = + hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast(ret); + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h new file mode 100644 index 0000000000..6434b5fb9c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_ +#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_ + +#include +#include +#include "hccl/hcom.h" +#include "backend/kernel_compiler/hccl/hccl_kernel.h" + +namespace mindspore { +namespace kernel { +class HcomAllBroadCastKernel : public HcclKernel { + public: + HcomAllBroadCastKernel() = default; + ~HcomAllBroadCastKernel() override = default; + + /* Inherit from kernelmod */ + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + private: +}; +MS_HCCL_REG_KERNEL(Broadcast, HcomAllBroadCastKernel); +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc new file mode 100644 index 0000000000..201071dcb5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hcom_all_gather.h" + +#include +#include +#include + +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +bool HcomAllGatherKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs, void *stream_ptr) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_task_sink()) { + return true; + } + if (inputs.empty() || hccl_data_type_list_.empty()) { + MS_LOG(ERROR) << "AllGather param is empty"; + return false; + } + const char *tag = "Hccl-AllGather"; + hcclResult_t ret = + hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast(ret); + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h new file mode 100644 index 0000000000..21d8ffa484 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_ +#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_ + +#include +#include +#include "hccl/hcom.h" +#include "backend/kernel_compiler/hccl/hccl_kernel.h" + +namespace mindspore { +namespace kernel { +class HcomAllGatherKernel : public HcclKernel { + public: + HcomAllGatherKernel() = default; + ~HcomAllGatherKernel() override = default; + + /* Inherit from kernelmod */ + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + private: +}; +MS_HCCL_REG_KERNEL(AllGather, HcomAllGatherKernel); +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc new file mode 100644 index 0000000000..533ce1b087 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hcom_all_reduce.h" + +#include +#include +#include + +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +bool HcomAllReduceKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs, void *stream_ptr) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_task_sink()) { + return true; + } + if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { + MS_LOG(ERROR) << "AllReduce param is empty"; + return false; + } + const char *tag = "Hccl-AllReduce"; + hcclResult_t ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], + op_type_, nullptr, stream_ptr); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast(ret); + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h new file mode 100644 index 0000000000..39641f7448 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_ +#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_ + +#include +#include +#include "backend/kernel_compiler/hccl/hccl_kernel.h" + +namespace mindspore { +namespace kernel { +class HcomAllReduceKernel : public HcclKernel { + public: + HcomAllReduceKernel() = default; + ~HcomAllReduceKernel() override = default; + + /* Inherit from kernelmod */ + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + private: +}; + +MS_HCCL_REG_KERNEL(AllReduce, HcomAllReduceKernel); +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc new file mode 100644 index 0000000000..32c6dacb01 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h" + +#include +#include +#include + +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +bool HcomAllReduceScatterKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs, void *stream_ptr) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_task_sink()) { + return true; + } + if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { + MS_LOG(ERROR) << "ReduceScatter param is empty"; + return false; + } + const char *tag = "Hccl-ReduceScatter"; + hcclResult_t ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], + op_type_, nullptr, stream_ptr); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast(ret); + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h new file mode 100644 index 0000000000..2f4ace5aea --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_ +#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_ + +#include +#include +#include "hccl/hcom.h" +#include "backend/kernel_compiler/hccl/hccl_kernel.h" + +namespace mindspore { +namespace kernel { +class HcomAllReduceScatterKernel : public HcclKernel { + public: + HcomAllReduceScatterKernel() = default; + ~HcomAllReduceScatterKernel() override = default; + + /* Inherit from kernelmod */ + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + private: +}; + +MS_HCCL_REG_KERNEL(ReduceScatter, HcomAllReduceScatterKernel); +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc new file mode 100644 index 0000000000..721c1b6ba0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/hccl/hcom_util.h" + +#include + +#include "backend/kernel_compiler/common_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" + +namespace mindspore { +bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_intput_shape_list) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { + std::vector shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i); + hccl_kernel_intput_shape_list->emplace_back(shape_i); + } + + return true; +} + +bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_output_shape_list) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list); + for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) { + std::vector shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); + hccl_kernel_output_shape_list->emplace_back(shape_i); + } + + return true; +} + +bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector *data_type_list) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(data_type_list); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { + auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i); + auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr); + if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { + MS_LOG(EXCEPTION) << "HcomDataType cann't support Current Ascend Data Type : " << type_ptr; + } + data_type_list->emplace_back(iter->second); + } + auto type_base = *(std::begin(*data_type_list)); + if (std::any_of(data_type_list->begin(), data_type_list->end(), + [&type_base](hcclDataType_t type) { return type != type_base; })) { + MS_LOG(ERROR) << "hccl have different data type"; + return false; + } + return true; +} + +bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector &shape, size_t *size) { + MS_EXCEPTION_IF_NULL(size); + size_t tmp_size = 1; + uint32_t type_size = 4; + for (size_t i = 0; i < shape.size(); i++) { + tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]); + } + + if (!GetHcomTypeSize(data_type, &type_size)) { + return false; + } + + *size = SizetMulWithOverflowCheck(tmp_size, type_size); + + MS_LOG(INFO) << "size[" << *size << "]"; + return true; +} + +bool HcomUtil::GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size) { + MS_EXCEPTION_IF_NULL(size); + auto iter = CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.find(data_type); + if (iter == CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.end()) { + MS_LOG(ERROR) << "HcomUtil::HcomDataTypeSize, No DataTypeSize!"; + return false; + } + *size = iter->second; + return true; +} + +bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector &data_type_list, + const vector> &shape_list, uint64_t *total_count) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(total_count); + const uint32_t align_size = 512; + const uint32_t filled_size = 32; + uint64_t total_size = 0; + uint64_t block_size; + size_t input_size; + uint32_t type_size = 4; + + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { + if (!GetHcomTypeSize(data_type_list[i], &type_size)) { + return false; + } + + if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) { + MS_LOG(ERROR) << "Get GetHcclOpSize failed"; + return false; + } + + if (AnfAlgo::GetCNodeName(anf_node) == kReduceScatterOpName) { + int32_t rank_size; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (primitive->GetAttr("rank_size") != nullptr) { + rank_size = GetValue(primitive->GetAttr("rank_size")); + } else { + MS_LOG(ERROR) << "Get rank size failed"; + return false; + } + block_size = input_size / IntToSize(rank_size); + total_size = total_size + block_size; + } else { + if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) { + block_size = input_size; + } else { + block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; + } + total_size = total_size + block_size; + } + } + + if (type_size == 0 || total_size % type_size != 0) { + MS_LOG(ERROR) << "Total_size[" << total_size << "],Type_size[" << type_size << "] != 0, fail!"; + return false; + } + *total_count = total_size / type_size; + return true; +} + +bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_type); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (primitive->GetAttr("op") == nullptr) { + MS_LOG(ERROR) << "Get HCOM_ATTR_REDUCE_TYPE fail, not support!"; + return false; + } + auto hcom_op_type_get = GetValue(primitive->GetAttr("op")); + string hcom_op_type(hcom_op_type_get); + if (hcom_op_type == "min") { + *op_type = HCCL_REP_OP_MIN; + } else if (hcom_op_type == "max") { + *op_type = HCCL_REP_OP_MAX; + } else if (hcom_op_type == "prod") { + *op_type = HCCL_REP_OP_PROD; + } else if (hcom_op_type == "sum") { + *op_type = HCCL_REP_OP_SUM; + } else { + MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!"; + return false; + } + return true; +} + +bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(root_id); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (primitive->GetAttr("root_rank") != nullptr) { + *root_id = (uint32_t)GetValue(primitive->GetAttr("root_rank")); + } else { + MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!"; + return false; + } + return true; +} + +void HcomUtil::GetHcomGroup(NotNull anf_node, NotNull group) { + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + auto attr = primitive->GetAttr("group"); + if (attr != nullptr) { + *group = GetValue(attr); + } else { + MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed"; + } +} +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h similarity index 100% rename from mindspore/ccsrc/kernel/hccl/hcom_util.h rename to mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc b/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc new file mode 100644 index 0000000000..9933826f2b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc @@ -0,0 +1,248 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include "nlohmann/json.hpp" +#include "securec/include/securec.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" +namespace mindspore { +namespace kernel { +constexpr auto kUtilsModule = "mindspore._extends.utils"; +constexpr auto kCalSha256Func = "cal_sha256"; + +namespace { +bool CheckHash(const std::string &json_file, const std::string &bin_file, const nlohmann::json &js) { + if (js.find("sha256") == js.end()) { + MS_LOG(ERROR) << "No sha256 found in " << json_file; + return false; + } + std::string sha256_str = js["sha256"]; + py::object ret = parse::python_adapter::CallPyFn(kUtilsModule, kCalSha256Func, bin_file); + std::string sha256_cal = py::cast(ret); + if (sha256_cal.empty()) { + MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed."; + return false; + } + if (sha256_cal != sha256_str) { + MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed."; + return false; + } + return true; +} +} // namespace + +const std::string KernelPack::Serialize() const { + MS_EXCEPTION_IF_NULL(json_); + MS_EXCEPTION_IF_NULL(kernel_); + std::string buffer; + (void)buffer.append((const char *)json_, json_->len + sizeof(json_->len)); + (void)buffer.append((const char *)kernel_, kernel_->len + sizeof(kernel_->len)); + return buffer; +} + +bool KernelPack::ReadFromJsonFileHelper(std::ifstream &kernelbin) { + size_t binsize = LongToSize(kernelbin.seekg(0, std::ios::end).tellg()); + // free old data + if (kernel_ != nullptr) { + delete[] kernel_; + kernel_ = nullptr; + } + + void *ptr = static_cast(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]); + if (ptr != nullptr) { + kernel_ = static_cast(ptr); + } + if (kernel_ == nullptr) { + MS_LOG(ERROR) << "memory malloc failed."; + kernelbin.close(); + return false; + } + if (memset_s(kernel_, sizeof(KernelPack) + binsize, 0, sizeof(KernelPack) + binsize) != EOK) { + MS_LOG(ERROR) << "memset kernel_ failed."; + delete[] kernel_; + kernel_ = nullptr; + kernelbin.close(); + return false; + } + kernel_->len = binsize; + MS_LOG(INFO) << "kernel len:" << kernel_->len; + (void)kernelbin.seekg(0, std::ios::beg); + (void)kernelbin.read(kernel_->contents, SizeToLong(kernel_->len)); + return true; +} + +bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &processor) { + if (json_f.length() <= strlen(kJsonSuffix)) { + MS_LOG(ERROR) << "please check json path."; + return false; + } + + std::ifstream kerneljson(json_f); + if (!kerneljson.is_open()) { + MS_LOG(DEBUG) << "read json file error, please check kernelmeta."; + return false; + } + nlohmann::json js; + kerneljson >> js; + + size_t binsize = LongToSize(kerneljson.seekg(0, std::ios::end).tellg()); + void *ptr = static_cast(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]); + if (ptr != nullptr) { + json_ = static_cast(ptr); + } + if (json_ == nullptr) { + MS_LOG(ERROR) << "memory malloc failed."; + kerneljson.close(); + return false; + } + json_->len = binsize; + (void)kerneljson.seekg(0, std::ios::beg); + (void)kerneljson.read(json_->contents, SizeToLong(json_->len)); + + if (processor == kProcessorCuda) { + std::string bin_f = json_f.substr(0, json_f.length() - 5) + ".ptx"; + std::ifstream kernelbin(bin_f); + if (!kernelbin.is_open()) { + MS_LOG(ERROR) << "read kernel ptx file error, please check kernelmeta."; + kerneljson.close(); + return false; + } + + if (ReadFromJsonFileHelper(kernelbin) == false) { + delete[] json_; + json_ = nullptr; + kerneljson.close(); + return false; + } + kerneljson.close(); + if (!CheckHash(json_f, bin_f, js)) { + return false; + } + return true; + } + + std::string binfilesuffix = js["binFileSuffix"]; + std::string bin_f = json_f.substr(0, json_f.length() - 5) + binfilesuffix; + if (binfilesuffix.compare(".so") == 0) { + // change "xx/xx.so" -> "xx/libxx.so" + auto sp = bin_f.rfind('/'); + if (sp == std::string::npos) { + MS_LOG(ERROR) << "illegal bin file path " << bin_f; + kerneljson.close(); + return false; + } + bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1); + } + + std::ifstream kernelbin(bin_f, std::ios::binary); + if (!kernelbin.is_open()) { + MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta."; + kerneljson.close(); + delete[] json_; + json_ = nullptr; + return false; + } + + MS_LOG(INFO) << "kernelbin_name:" << bin_f; + if (ReadFromJsonFileHelper(kernelbin) == false) { + delete[] json_; + json_ = nullptr; + kerneljson.close(); + return false; + } + kerneljson.close(); + + if (!CheckHash(json_f, bin_f, js)) { + return false; + } + + return true; +} + +void KernelPack::ParseKernelJson(const nlohmann::json &js) { + kernel_json_info_.bin_file_name = js["binFileName"]; + kernel_json_info_.bin_file_suffix = js["binFileSuffix"]; + kernel_json_info_.block_dim = js["blockDim"]; + kernel_json_info_.kernel_name = js["kernelName"]; + kernel_json_info_.magic = js["magic"]; + if (js.find("parameters") != js.end()) { + if (!js.at("parameters").is_array()) { + MS_LOG(DEBUG) << "Format error!,parameters should be array."; + } + std::vector sizes = js.at("parameters"); + for (auto size : sizes) { + MS_LOG(INFO) << "parameter " << size; + kernel_json_info_.parameters.push_back(size); + } + } + if (js.find("workspace") != js.end()) { + auto workspace = js.at("workspace"); + std::vector sizes = workspace.at("size"); + for (auto size : sizes) { + MS_LOG(INFO) << "workspace_size_list " << size; + kernel_json_info_.workspaces.push_back(size); + } + } + kernel_json_info_.sha256 = js["sha256"]; +} + +bool KernelPack::LoadKernelMeta(const std::string &json_f, const std::string &processor) { + if (json_f.length() <= strlen(kJsonSuffix)) { + MS_LOG(ERROR) << "please check json path."; + return false; + } + std::ifstream kernel_json(json_f); + if (!kernel_json.is_open()) { + MS_LOG(DEBUG) << "read json file error, please check kernelmeta."; + return false; + } + nlohmann::json js; + kernel_json >> js; + ParseKernelJson(js); + kernel_json.close(); + + std::string bin_f = json_f.substr(0, json_f.length() - 5) + kernel_json_info_.bin_file_suffix; + if (kernel_json_info_.bin_file_suffix == ".so") { + // change "xx/xx.so" -> "xx/libxx.so" + auto sp = bin_f.rfind('/'); + if (sp == std::string::npos) { + MS_LOG(ERROR) << "illegal bin file path " << bin_f; + return false; + } + bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1); + } + + std::ifstream kernelbin(bin_f, std::ios::binary); + if (!kernelbin.is_open()) { + MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta."; + return false; + } + + MS_LOG(INFO) << "kernelbin_name:" << bin_f; + if (!ReadFromJsonFileHelper(kernelbin)) { + return false; + } + + return CheckHash(json_f, bin_f, js); +} + +KernelJsonInfo KernelPack::kernel_json_info() const { return kernel_json_info_; } +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel.h b/mindspore/ccsrc/backend/kernel_compiler/kernel.h new file mode 100644 index 0000000000..2d240338f3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel.h @@ -0,0 +1,141 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_KERNEL_H_ +#include +#include +#include +#include "nlohmann/json.hpp" +#include "ir/anf.h" +#include "ir/dtype.h" +#include "utils/utils.h" +#include "ir/tensor.h" +#include "abstract/dshape.h" +#include "utils/log_adapter.h" + +namespace mindspore { +enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; + +namespace kernel { + +enum Axis : int { + N = 0, + C, + H, + W, +}; + +// Supported fusion type +enum FusionType { + CONVLUTION = 0, + ELEMWISE, + COMMREDUCE, + SEGMENT, + OPAQUE, + DYNAMIC, + UNKNOWN_FUSION_TYPE = -1, +}; +enum OpPattern { + kCommonPattern = 0, + kFormatAgnosticPattern = 1, + kBroadcastPattern = 2, + kReducePattern = 3, + kDynamicFormatPattern = 4, +}; + +// Backend processor +enum Processor { + AICORE = 0, + AICPU, + CUDA, +}; + +struct FlexArray { + size_t len; + char contents[]; +}; + +struct KernelJsonInfo { + std::string bin_file_name; + std::string bin_file_suffix; + uint32_t block_dim; + std::string kernel_name; + std::string magic; + std::vector parameters; + std::string sha256; + std::vector workspaces; + KernelJsonInfo() : block_dim(0) {} +}; + +class KernelPack { + public: + KernelPack() : json_(nullptr), kernel_(nullptr) {} + KernelPack(const KernelPack &) = default; + KernelJsonInfo kernel_json_info() const; + bool LoadKernelMeta(const std::string &json_f, const std::string &processor); + bool ReadFromJsonFile(const std::string &json_f, const std::string &processor); + const std::string Serialize() const; + const FlexArray *const GetJson() const { return json_; } + const FlexArray *const GetKernel() const { return kernel_; } + ~KernelPack() { + if (json_) { + delete[] json_; + json_ = nullptr; + } + if (kernel_) { + delete[] kernel_; + kernel_ = nullptr; + } + } + + private: + bool ReadFromJsonFileHelper(std::ifstream &kernelbin); + void ParseKernelJson(const nlohmann::json &js); + KernelJsonInfo kernel_json_info_; + FlexArray *json_; + FlexArray *kernel_; +}; +using KernelPackPtr = std::shared_ptr; + +/** + * @brief base class for autotensor kernel and cce kernel. + */ +struct Address { + void *addr; + size_t size; +}; +using AddressPtr = std::shared_ptr
; + +class KernelMod { + public: + virtual const std::vector &GetInputSizeList() const = 0; + virtual const std::vector &GetOutputSizeList() const = 0; + virtual const std::vector &GetWorkspaceSizeList() const = 0; + virtual bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) = 0; + virtual std::vector GenParameters() { return {}; } + + virtual ~KernelMod() = default; + void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } + + protected: + std::string kernel_name_; +}; +using KernelModPtr = std::shared_ptr; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc new file mode 100644 index 0000000000..68392d1871 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc @@ -0,0 +1,193 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/kernel_build_info.h" +#include +#include "utils/log_adapter.h" +#include "debug/anf_ir_dump.h" +namespace mindspore { +namespace kernel { +std::string KernelBuildInfo::GetInputFormat(size_t input_index) const { + if (input_index >= inputs_format_.size()) { + MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node"; + return kInvalidFormat; + } + return inputs_format_[input_index]; +} + +std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const { + if (output_index >= outputs_format_.size()) { + MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node"; + return kInvalidFormat; + } + return outputs_format_[output_index]; +} + +TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const { + if (input_index >= inputs_device_type_.size()) { + MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input"; + return TypeId::kNumberTypeEnd; + } + return inputs_device_type_[input_index]; +} + +TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const { + if (output_index >= outputs_device_type_.size()) { + MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output"; + return TypeId::kNumberTypeEnd; + } + return outputs_device_type_[output_index]; +} + +std::vector KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; } + +std::vector KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; } + +std::vector KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; } + +std::vector KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; } + +size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); } + +size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } + +std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const { + if (input_index >= input_reshape_type_.size()) { + MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size " + << input_reshape_type_.size(); + } + return input_reshape_type_[input_index]; +} + +std::vector KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { + if (output_index >= output_reshape_type_.size()) { + MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size " + << output_reshape_type_.size(); + } + return output_reshape_type_[output_index]; +} + +std::string KernelBuildInfo::ToString() const { + std::ostringstream output_buffer; + output_buffer << "("; + for (size_t index = 0; index < GetInputNum(); ++index) { + if (index != 0) { + output_buffer << ", "; + } + output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">"; + } + output_buffer << ") -> ("; + for (size_t index = 0; index < GetOutputNum(); ++index) { + if (index != 0) { + output_buffer << ", "; + } + output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">"; + } + output_buffer << ")"; + return output_buffer.str(); +} + +bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { + if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) { + return false; + } + if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) { + if (op_pattern_ != kFormatAgnosticPattern) { + return false; + } else { + MS_LOG(INFO) << "this kernel build info:" << this->ToString() + << ", other kernel build info: " << other.ToString(); + } + } + return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_); +} + +bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); } + +bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); } + +bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); } + +void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->kernel_type_ = kernel_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector &inputs_format) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->inputs_format_ = inputs_format; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector &outputs_format) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->outputs_format_ = outputs_format; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector &inputs_device_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->inputs_device_type_ = inputs_device_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector &outputs_device_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->outputs_device_type_ = outputs_device_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->fusion_type_ = fusion_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->processor_ = processor; +} + +std::shared_ptr KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; } + +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType( + const std::vector> &input_reshape_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->input_reshape_type_ = input_reshape_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( + const std::vector> &output_reshape_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->output_reshape_type_ = output_reshape_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->op_pattern_ = pattern; +} +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + if (index >= kernel_build_info_->inputs_format_.size()) { + MS_LOG(EXCEPTION) << "index outof range!"; + } + kernel_build_info_->inputs_format_[index] = format; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + if (index >= kernel_build_info_->outputs_format_.size()) { + MS_LOG(EXCEPTION) << "index outof range!"; + } + kernel_build_info_->outputs_format_[index] = format; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h new file mode 100644 index 0000000000..be243c9ae0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h @@ -0,0 +1,147 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ +#define MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ +#include +#include +#include +#include +#include +#include "ir/dtype.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +class KernelBuildInfo { + public: + class KernelBuildInfoBuilder; + + KernelBuildInfo() { + kernel_type_ = TBE_KERNEL; + fusion_type_ = OPAQUE; + processor_ = AICORE; + op_pattern_ = kCommonPattern; + input_reshape_type_ = {}; + output_reshape_type_ = {}; + inputs_format_ = {}; + outputs_format_ = {}; + inputs_device_type_ = {}; + outputs_device_type_ = {}; + } + + ~KernelBuildInfo() = default; + + KernelType kernel_type() const { return kernel_type_; } + + std::string GetInputFormat(size_t input_index) const; + + std::string GetOutputFormat(size_t output_index) const; + + TypeId GetInputDeviceType(size_t input_index) const; + + TypeId GetOutputDeviceType(size_t output_index) const; + + std::vector GetInputReshapeType(size_t input_index) const; + + bool IsInputDefaultPadding() const; + + bool IsOutputDefaultPadding() const; + + std::vector GetOutputReshapeType(size_t input_index) const; + + std::vector GetAllInputFormats() const; + + std::vector GetAllOutputFormats() const; + + std::vector GetAllInputDeviceTypes() const; + + std::vector GetAllOutputDeviceTypes() const; + + OpPattern op_pattern() const { return op_pattern_; } + + FusionType fusion_type() const { return fusion_type_; } + + Processor processor() const { return processor_; } + + size_t GetInputNum() const; + + size_t GetOutputNum() const; + + std::string ToString() const; + + bool operator==(const KernelBuildInfo &other) const; + + bool operator!=(const KernelBuildInfo &other) const; + + public: + static auto constexpr kInvalidFormat = "InvalidFormat"; + + private: + KernelType kernel_type_; + std::vector inputs_format_; + OpPattern op_pattern_; + std::vector outputs_format_; + std::vector> input_reshape_type_; + std::vector> output_reshape_type_; + std::vector inputs_device_type_; + std::vector outputs_device_type_; + FusionType fusion_type_; + Processor processor_; +}; +using KernelBuildInfoPtr = std::shared_ptr; + +class KernelBuildInfo::KernelBuildInfoBuilder { + public: + KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared(); } + + explicit KernelBuildInfoBuilder(std::shared_ptr kernel_build_info) + : kernel_build_info_(std::move(kernel_build_info)) {} + + ~KernelBuildInfoBuilder() = default; + + void SetKernelType(const KernelType &kernel_type); + + void SetInputsFormat(const std::vector &inputs_format); + + void SetOutputsFormat(const std::vector &outputs_format); + + void SetInputsDeviceType(const std::vector &inputs_device_type); + + void SetOutputsDeviceType(const std::vector &outputs_device_type); + + void SetInputReshapeType(const std::vector> &input_reshape_type); + + void SetOutputReshapeType(const std::vector> &output_reshape_type); + + void SetFusionType(FusionType fusion_type); + + void SetProcessor(Processor processor); + + void SetOpPattern(OpPattern pattern); + + void SetInputFormat(const std::string &format, size_t index); + + void SetOutputFormat(const std::string &format, size_t index); + + std::shared_ptr Build(); + + private: + std::shared_ptr kernel_build_info_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc new file mode 100644 index 0000000000..0045e49bef --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/kernel_fusion.h" + +#include +#include +#include +#include + +#include "common/utils.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" + +namespace mindspore { +namespace kernel { +using mindspore::kernel::tbe::TbeUtils; +static bool GenPreBuildKernelJson(const std::vector &compute_nodes, + std::vector *prebuild_op_list) { + MS_EXCEPTION_IF_NULL(prebuild_op_list); + TbeKernelJsonCreator creator(PREBUILD); + for (const auto &anf_node : compute_nodes) { + nlohmann::json prebuild; + if (!creator.GenTbeSingleKernelJson(anf_node, &prebuild)) { + MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; + return false; + } + (*prebuild_op_list).push_back(prebuild); + } + return true; +} + +std::map KernelFusion(const std::vector &fusion_scopes) { + MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size(); + std::map kernel_mod_ret; + auto build_manger = std::make_shared(); + MS_EXCEPTION_IF_NULL(build_manger); + for (const auto &fusion_scope_iter : fusion_scopes) { + auto scope_id = fusion_scope_iter.scope_id; + nlohmann::json fusion_op; + string fusion_kernel = "te_fusion"; + if (!TbeKernelBuild::GenFusionScopeJson(fusion_scope_iter.input_nodes, fusion_scope_iter.compute_nodes, &fusion_op, + &fusion_kernel)) { + continue; + } + // gen kernel_name & check cache + std::string json_str = fusion_op.dump(); + size_t hash_id = std::hash()(json_str); + auto json_name = fusion_kernel.append("_").append(std::to_string(hash_id)); + fusion_op["fusion_op_name"] = json_name; + // gen json for prebuild + std::vector prebuild_op_list; + if (!GenPreBuildKernelJson(fusion_scope_iter.compute_nodes, &prebuild_op_list)) { + continue; + } + // get io size + std::vector input_size_list; + std::vector output_size_list; + if (!TbeKernelBuild::GetIOSize(fusion_op["op_list"], fusion_scope_iter.output_nodes, &input_size_list, + &output_size_list)) { + continue; + } + // search cache + auto kernel_pack = TbeUtils::SearchCache(json_name, tbe::kProcessorAiCore); + if (kernel_pack != nullptr) { + MS_LOG(INFO) << "Use cached kernel, kernel json name: " << json_name; + auto kernel_mod = + build_manger->GenKernelMod(json_name, tbe::kProcessorAiCore, input_size_list, output_size_list, kernel_pack); + if (kernel_mod != nullptr) { + kernel_mod_ret[scope_id] = kernel_mod; + continue; + } + } + // fusion build + nlohmann::json fusion_json; + fusion_json["fusion_op"] = fusion_op; + fusion_json["prebuild_ops"] = prebuild_op_list; + auto task_id = build_manger->StartCompileOp(fusion_json); + TbeUtils::SaveJsonInfo(json_name, fusion_json.dump()); + if (task_id < 0) { + MS_EXCEPTION(ArgumentError) << "start compile failed."; + } + build_manger->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list, scope_id); + } + + int build_failed_num = 0; + while (!build_manger->IsAllTaskFinish()) { + int task_id = -1; + char *task_result = nullptr; + char *pre_build_result = nullptr; + auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); + if (!ret) { + MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; + } + + if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { + MS_LOG(INFO) << "Fusion warning: Fuison op build failed, err log: " << task_result + << " change to single op build."; + build_failed_num++; + } + auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false); + if (kernel_mod_item.second != nullptr) { + (void)kernel_mod_ret.emplace(kernel_mod_item); + } + } + MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num; + return kernel_mod_ret; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h new file mode 100644 index 0000000000..2fb3a05b4b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ +#define MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ +#include +#include +#include "backend/kernel_compiler/kernel.h" +namespace mindspore { +namespace kernel { +/* + * @brief fuse op and return a callable mod + */ +struct FusionScopeInfo { + int32_t scope_id; + std::vector input_nodes; + std::vector compute_nodes; + std::vector output_nodes; +}; + +std::map KernelFusion(const std::vector &fusion_scopes); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc new file mode 100755 index 0000000000..81b5d0f996 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc @@ -0,0 +1,158 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/kernel_query.h" +#include +#include +#include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" +#include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" +#include "backend/kernel_compiler/akg/akg_kernel_metadata.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +namespace { +void FilterInvalidKernelInfo(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_info_list); + std::vector> filtered_list; + (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), + [&kernel_node](const std::shared_ptr &kernel_build_info) { + return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && + AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); + }); + if (!filtered_list.empty()) { + kernel_info_list->clear(); + (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); + } else { + MS_LOG(INFO) << "All kernel Info list does not match any kernel info "; + for (size_t index = 0; index < kernel_info_list->size(); ++index) { + std::ostringstream buffer; + auto kernel_info = kernel_info_list->at(index); + MS_EXCEPTION_IF_NULL(kernel_info); + if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) { + buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" + << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]"; + } else { + buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]" + << " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]"; + } + MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str(); + } + kernel_info_list->clear(); + MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : [" + << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" + << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; + } +} +} // namespace + +void KernelQueryAll(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + + TbeMetadataInfo(kernel_node, kernel_info_list); + + if (kernel_info_list->empty()) { + AicpuMetadataInfo(kernel_node, kernel_info_list); + if (!kernel_info_list->empty()) { + MS_LOG(INFO) << "The node [" << kernel_node->DebugString() + << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; + AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); + } + } + + if (kernel_info_list->empty()) { + GetRtKelInfo(kernel_node, kernel_info_list); + } + + if (kernel_info_list->empty()) { + HcclMetadataInfo(kernel_node, kernel_info_list); + } + if (kernel_info_list->empty()) { + MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!"; + } +} + +void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list, + KernelType kernel_type) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { + kernel_type = KernelType::AKG_KERNEL; + } + + switch (kernel_type) { + case KernelType::AKG_KERNEL: + AkgMetadataInfo(kernel_node, kernel_info_list); + break; + default: + KernelQueryAll(kernel_node, kernel_info_list); + break; + } + + if (kernel_info_list->empty()) { + MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!"; + } + // check output + FilterInvalidKernelInfo(kernel_node, kernel_info_list); +} + +void AICPUQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + kernel_info_list->clear(); + AicpuMetadataInfo(kernel_node, kernel_info_list); + FilterInvalidKernelInfo(kernel_node, kernel_info_list); +} +bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(select_kernel_build_info); + std::vector> kernel_info_list; + auto cnode = kernel_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + AICPUQuery(cnode, &kernel_info_list); + return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), + [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { + MS_EXCEPTION_IF_NULL(item); + return *item == *select_kernel_build_info; + }); +} + +bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(select_kernel_build_info); + std::vector> kernel_info_list; + auto cnode = kernel_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + TbeMetadataInfo(cnode, &kernel_info_list); + return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), + [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { + MS_EXCEPTION_IF_NULL(item); + return *item == *select_kernel_build_info; + }); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h new file mode 100644 index 0000000000..20458f48d0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ +#define MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ + +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list, + KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); +void AICPUQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); +bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h new file mode 100644 index 0000000000..64ae1009d1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h @@ -0,0 +1,175 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ +#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ +#include +#include +#include +#include +#include "ir/dtype.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU }; +enum OpIOType { kInput = 0, kOutput }; + +class OpAttr { + public: + OpAttr() = default; + ~OpAttr() = default; + + std::string name() const { return name_; } + std::string param_type() const { return param_type_; } + std::string type() const { return type_; } + std::string value() const { return value_; } + std::string default_value() const { return default_value_; } + + void set_name(const std::string &name) { name_ = name; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_type(const std::string &type) { type_ = type; } + void set_value(const std::string &value) { value_ = value; } + void set_default_value(const std::string &default_value) { default_value_ = default_value; } + + private: + std::string name_; + std::string param_type_; + std::string type_; + std::string value_; + std::string default_value_; +}; + +class OpIOInfo { + public: + OpIOInfo() = default; + ~OpIOInfo() = default; + + int index() const { return index_; } + std::string name() const { return name_; } + bool need_compile() const { return need_compile_; } + std::string param_type() const { return param_type_; } + std::string reshape_type() const { return reshape_type_; } + std::string shape() const { return shape_; } + std::vector dtypes() const { return dtypes_; } + std::vector formats() const { return formats_; } + + void set_index(const int index) { index_ = index; } + void set_name(const std::string &name) { name_ = name; } + void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } + void set_shape(const std::string &shape) { shape_ = shape; } + void set_dtypes(const std::vector &dtype) { dtypes_ = dtype; } + void set_formats(const std::vector &formats) { formats_ = formats; } + + private: + int index_ = 0; + std::string name_; + bool need_compile_ = false; + std::string param_type_; + std::string reshape_type_; + std::string shape_; + std::vector dtypes_; + std::vector formats_; +}; + +class OpInfo { + public: + OpInfo() = default; + OpInfo(const OpInfo &opinfo) { + op_name_ = opinfo.op_name(); + imply_type_ = opinfo.imply_type(); + + impl_path_ = opinfo.impl_path(); + fusion_type_ = opinfo.fusion_type(); + async_flag_ = opinfo.async_flag_; + binfile_name_ = opinfo.binfile_name_; + compute_cost_ = opinfo.compute_cost_; + kernel_name_ = opinfo.kernel_name(); + partial_flag_ = opinfo.partial_flag_; + dynamic_format_ = opinfo.dynamic_format_; + op_pattern_ = opinfo.op_pattern(); + processor_ = opinfo.processor_; + for (const auto &attr : opinfo.attrs_ptr()) { + attrs_ptr_.push_back(std::make_shared(*attr)); + } + for (const auto &input : opinfo.inputs_ptr()) { + inputs_ptr_.push_back(std::make_shared(*input)); + } + for (const auto &output : opinfo.outputs_ptr()) { + outputs_ptr_.push_back(std::make_shared(*output)); + } + ref_infos_ = opinfo.ref_infos(); + } + ~OpInfo() = default; + std::string op_name() const { return op_name_; } + OpImplyType imply_type() const { return imply_type_; } + std::string impl_path() const { return impl_path_; } + std::string fusion_type() const { return fusion_type_; } + std::string kernel_name() const { return kernel_name_; } + OpPattern op_pattern() const { return op_pattern_; } + std::string processor() const { return processor_; } + std::vector> attrs_ptr() const { return attrs_ptr_; } + std::vector> inputs_ptr() const { return inputs_ptr_; } + std::vector> outputs_ptr() const { return outputs_ptr_; } + const std::unordered_map &ref_infos() const { return ref_infos_; } + + void set_op_name(const std::string &op_name) { op_name_ = op_name; } + void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } + void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } + void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } + void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } + void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } + void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } + void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } + void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } + void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } + void set_processor(const std::string &processor) { processor_ = processor; } + void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } + void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } + void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } + bool is_ref() const { return !ref_infos_.empty(); } + bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } + void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } + void ClearInputs() { (void)inputs_ptr_.clear(); } + void ClearOutputs() { (void)outputs_ptr_.clear(); } + bool equals_to(const std::shared_ptr &other_info) const { + return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ && + this->processor_ == other_info->processor_; + } + + private: + std::string op_name_; + OpImplyType imply_type_ = kTBE; + std::string impl_path_; + std::string fusion_type_; + bool async_flag_ = false; + std::string binfile_name_; + int compute_cost_ = 0; + std::string kernel_name_; + bool partial_flag_ = false; + bool dynamic_format_ = false; + OpPattern op_pattern_ = kCommonPattern; + std::string processor_; + std::vector> attrs_ptr_; + std::vector> inputs_ptr_; + std::vector> outputs_ptr_; + std::unordered_map ref_infos_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc new file mode 100644 index 0000000000..69c4ca7db1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc @@ -0,0 +1,390 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/oplib/oplib.h" +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "utils/overload.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +constexpr auto kImplyType = "imply_type"; +constexpr auto kOpName = "op_name"; +constexpr auto kFusionType = "fusion_type"; +constexpr auto kAsyncFlag = "async_flag"; +constexpr auto kBinfileName = "binfile_name"; +constexpr auto kComputeCost = "compute_cost"; +constexpr auto kKernelName = "kernel_name"; +constexpr auto kPartialFlag = "partial_flag"; +constexpr auto kReshapeType = "reshape_type"; +constexpr auto kOpPattern = "op_pattern"; +constexpr auto kDynamicFormat = "dynamicFormat"; +constexpr auto kFormatAgnostic = "formatAgnostic"; +constexpr auto kBroadcast = "broadcast"; +constexpr auto kReduce = "reduce"; +constexpr auto kDtypeFormat = "dtype_format"; +constexpr auto kAttr = "attr"; +constexpr auto kIputs = "inputs"; +constexpr auto kOutputs = "outputs"; +constexpr auto kAiCPU = "AiCPU"; +constexpr auto kAiCore = "AiCore"; +constexpr auto kCUDA = "CUDA"; +constexpr auto kTbe = "TBE"; +constexpr auto kAkg = "AKG"; +constexpr auto kName = "name"; +constexpr auto kParamType = "param_type"; +constexpr auto kDtype = "dtype"; +constexpr auto kType = "type"; +constexpr auto kValue = "value"; +constexpr auto kDefaultValue = "default_value"; +constexpr auto kIndex = "index"; +constexpr auto kFormat = "format"; +constexpr auto kNeedCompile = "need_compile"; +constexpr auto kShape = "shape"; +constexpr auto kProcessor = "processor"; +std::vector> OpLib::op_info_; + +static std::string ImplTypeToStr(OpImplyType impl_type) { + switch (impl_type) { + case kTBE: + return kTbe; + case kAKG: + return kAkg; + case kAICPU: + return kAiCPU; + default: + return "unknow"; + } +} +bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) { + bool ret = false; + try { + auto op_json = nlohmann::json::parse(json_string); + std::string imply_type_string = op_json.at(kImplyType); + std::string op_name = op_json.at(kOpName); + if (imply_type_string == kTbe) { + OpImplyType imply_type = kTBE; + ret = DecodeOpInfo(op_json, imply_type, impl_path); + } else if (imply_type_string == kAkg) { + OpImplyType imply_type = kAKG; + ret = DecodeOpInfo(op_json, imply_type, impl_path); + } else if (imply_type_string == kAiCPU) { + OpImplyType imply_type = kAICPU; + ret = DecodeOpInfo(op_json, imply_type, impl_path); + } else { + MS_LOG(ERROR) << "Not support imply_type"; + } + if (!ret) { + MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string; + } + } catch (const std::exception &e) { + MS_LOG(ERROR) << "get op json elements failed: " << e.what(); + } + return ret; +} + +void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { + const std::map kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern}, + {kBroadcast, kBroadcastPattern}, + {kReduce, kReducePattern}, + {kDynamicFormat, kDynamicFormatPattern}}; + MS_EXCEPTION_IF_NULL(op_info); + op_info->set_async_flag(obj.at(kAsyncFlag)); + op_info->set_binfile_name(obj.at(kBinfileName)); + op_info->set_compute_cost(obj.at(kComputeCost)); + op_info->set_kernel_name(obj.at(kKernelName)); + op_info->set_partial_flag(obj.at(kPartialFlag)); + + if (obj.find(kOpPattern) != obj.end()) { + std::string op_pattern = obj.at(kOpPattern); + auto find_iter = kOpPatternMap.find(op_pattern); + if (find_iter == kOpPatternMap.end()) { + if (!op_pattern.empty()) { + MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern; + } + op_info->set_op_pattern(kCommonPattern); + } else { + op_info->set_op_pattern(find_iter->second); + } + } +} + +void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + op_info->set_processor(obj.at(kProcessor)); +} + +bool OpLib::RegOpFromLocalInfo() { + MS_LOG(INFO) << "Start"; + static bool has_load = false; + if (has_load) { + return true; + } + has_load = true; + std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH"); + if (dir.empty()) { + MS_LOG(INFO) << "MindSpore op info path does not been setted. use op info from python pass."; + return true; + } + char real_path[PATH_MAX] = {0}; + if (dir.size() >= PATH_MAX) { + MS_LOG(ERROR) << "Op info path is invalid: " << dir; + return false; + } +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) { + MS_LOG(ERROR) << "Op info path is invalid: " << dir; + return false; + } +#else + if (realpath(common::SafeCStr(dir), real_path) == nullptr) { + MS_LOG(ERROR) << "Op info path is invalid: " << dir; + return false; + } +#endif + MS_LOG(INFO) << "Start to read op info from local file."; + std::ifstream file(real_path); + if (!file.is_open()) { + MS_LOG(ERROR) << "Find op info file failed."; + return false; + } + std::string line; + while (getline(file, line)) { + if (!line.empty()) { + (void)OpLib::RegOp(line, ""); + } + } + MS_LOG(INFO) << "End"; + return true; +} + +bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, + const std::string &impl_path) { + std::shared_ptr op_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(op_info); + op_info->set_op_name(obj.at(kOpName)); + op_info->set_impl_path(impl_path); + op_info->set_imply_type(imply_type); + op_info->set_fusion_type(obj.at(kFusionType)); + if (imply_type == kTBE) { + DecodeTBESpecificInfo(obj, op_info); + } else if (imply_type == kAKG) { + DecodeAKGSpecificInfo(obj, op_info); + } + auto attrs = obj.at(kAttr); + for (const auto &attr : attrs) { + if (!DecodeAttr(attr, imply_type, op_info)) { + MS_LOG(ERROR) << "DecodeAttr Failed"; + return false; + } + } + nlohmann::json dtype_format; + if (obj.find(kDtypeFormat) != obj.end()) { + dtype_format = obj.at(kDtypeFormat); + } + auto inputs = obj.at(kIputs); + for (const auto &input : inputs) { + if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { + MS_LOG(ERROR) << "DecodeInputOutput Failed"; + return false; + } + } + auto outputs = obj.at(kOutputs); + for (const auto &output : outputs) { + if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { + MS_LOG(ERROR) << "DecodeInputOutput Failed"; + return false; + } + } + if (CheckRepetition(op_info)) { + MS_LOG(WARNING) << "This op info has been already registed. op name: " << op_info->op_name() + << ", impl type: " << ImplTypeToStr(op_info->imply_type()) + << ", impl path: " << op_info->impl_path(); + return true; + } + if (!GetRefInfo(op_info)) { + MS_LOG(ERROR) << "GetRefInfo Failed"; + return false; + } + op_info_.push_back(op_info); + return true; +} + +bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + bool ret = true; + try { + std::shared_ptr op_attr = std::make_shared(); + MS_EXCEPTION_IF_NULL(op_attr); + op_attr->set_name(obj.at(kName)); + if (imply_type != kAICPU) { + op_attr->set_param_type(obj.at(kParamType)); + } + op_attr->set_type(obj.at(kType)); + if (imply_type == kTBE) { + op_attr->set_value(obj.at(kValue)); + } + if (obj.find(kDefaultValue) != obj.end()) { + op_attr->set_default_value(obj.at(kDefaultValue)); + } + op_info->add_attrs_ptr(op_attr); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "DecodeAttr failed:" << e.what(); + ret = false; + } + return ret; +} + +bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, + size_t index) { + MS_EXCEPTION_IF_NULL(op_io); + bool ret = true; + try { + std::vector dtype; + std::vector format; + for (const auto &it : dtype_format) { + dtype.emplace_back(it[index][0]); + format.emplace_back(it[index][1]); + } + op_io->set_dtypes(dtype); + op_io->set_formats(format); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); + ret = false; + } + return ret; +} + +bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format) { + MS_EXCEPTION_IF_NULL(op_info); + bool ret = true; + try { + std::shared_ptr op_io = std::make_shared(); + MS_EXCEPTION_IF_NULL(op_io); + op_io->set_index(obj.at(kIndex)); + op_io->set_name(obj.at(kName)); + if (!dtype_format.empty()) { + if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) { + MS_LOG(ERROR) << "Decode dtype format failed"; + return false; + } + } else { + op_io->set_dtypes(obj.at(kDtype)); + op_io->set_formats(obj.at(kFormat)); + } + if (op_io->dtypes().size() != op_io->formats().size()) { + MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes() + << " is not equal to format size: " << op_io->formats(); + return false; + } + if (obj.find(kParamType) != obj.end()) { + op_io->set_param_type(obj.at(kParamType)); + } + if (imply_type == kTBE) { + if (obj.find(kNeedCompile) != obj.end()) { + op_io->set_need_compile(obj.at(kNeedCompile)); + } + if (obj.find(kShape) != obj.end()) { + op_io->set_shape(obj.at(kShape)); + } + if (obj.find(kReshapeType) != obj.end()) { + op_io->set_reshape_type(obj.at(kReshapeType)); + } + } + + if (io_type == kInput) { + op_info->add_inputs_ptr(op_io); + } else if (io_type == kOutput) { + op_info->add_outputs_ptr(op_io); + } + } catch (const std::exception &e) { + MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what(); + ret = false; + } + return ret; +} + +std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { + if (!OpLib::RegOpFromLocalInfo()) { + MS_LOG(INFO) << "Warning reg local op info failed."; + } + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool is_gpu = (context->device_target() == kGPUDevice); + if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { + MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) + << ", current op num: " << op_info_.size(); + return nullptr; + } + for (const auto &op_info : op_info_) { + MS_EXCEPTION_IF_NULL(op_info); + if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { + auto akg_processor_match = [&]() { + return is_gpu ? op_info->processor() == kCUDA : op_info->processor() == kAiCore; + }; + if (imply_type != kAKG || akg_processor_match()) { + return op_info; + } + } + } + MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) + << ", current op num: " << op_info_.size(); + return nullptr; +} + +bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + const auto &output_infos = op_info->outputs_ptr(); + const auto &input_infos = op_info->inputs_ptr(); + for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { + MS_EXCEPTION_IF_NULL(output_infos[out_index]); + const auto &out_name = output_infos[out_index]->name(); + for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { + MS_EXCEPTION_IF_NULL(input_infos[in_index]); + const auto &in_name = input_infos[in_index]->name(); + if (out_name == in_name) { + if (op_info->has_ref_index(out_index)) { + MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info"; + return false; + } + op_info->add_ref_pair(out_index, in_index); + MS_LOG(INFO) << "add ref info, op name is " << op_info->op_name() << ", outindex is " << out_index + << ", in_index is " << in_index; + } + } + } + return true; +} + +bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + for (const auto &exist_op_info : op_info_) { + MS_EXCEPTION_IF_NULL(exist_op_info); + if (exist_op_info->equals_to(op_info)) { + return true; + } + } + return false; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h new file mode 100644 index 0000000000..845edbfc2a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ +#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ +#include +#include +#include +#include +#include "backend/kernel_compiler/oplib/opinfo.h" + +namespace mindspore { +namespace kernel { +class OpLib { + public: + OpLib() = default; + virtual ~OpLib() = default; + static bool RegOp(const std::string &json_string, const std::string &impl_path); + static void RegOpInfo(const std::shared_ptr &opinfo) { op_info_.emplace_back(opinfo); } + static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type); + static const std::vector> &GetAllOpsInfo() { return op_info_; } + + protected: + static std::vector> op_info_; + + private: + static bool RegOpFromLocalInfo(); + static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path); + static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info); + static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, + size_t index); + static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info); + static void DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info); + static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format); + static bool GetRefInfo(const std::shared_ptr &op_info); + static bool CheckRepetition(const std::shared_ptr &op_info); +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h new file mode 100644 index 0000000000..6b2981e5b3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_OPLOADER_H +#define MINDSPORE_OPLOADER_H + +#include +#include "backend/kernel_compiler/oplib/oplib.h" + +namespace mindspore { +namespace kernel { +class OpInfoLoaderPy { + public: + OpInfoLoaderPy() = default; + + ~OpInfoLoaderPy() = default; + + size_t GetAllOpsInfo() { + auto ops = OpLib::GetAllOpsInfo(); + auto op_infos = new std::vector(); + for (auto op_info : ops) { + auto new_op_info = new OpInfo(*op_info); + op_infos->emplace_back(new_op_info); + } + return (size_t)op_infos; + } +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_OPLOADER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc new file mode 100644 index 0000000000..552468bb71 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/assign.h" + +#include + +#include "runtime/mem.h" +#include "common/utils.h" + +using ge::model_runner::MemcpyAsyncTaskInfo; +using MemcpyAsyncTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +AssignKernel::AssignKernel() {} + +AssignKernel::~AssignKernel() {} + +bool AssignKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector & /*outputs*/, void *stream_ptr) { + if (inputs.size() != 2) { + MS_LOG(ERROR) << "inputs size is not two"; + return false; + } + + if (inputs[0]->addr == inputs[1]->addr) { + MS_LOG(INFO) << "first addr is same with second addr , no need assign"; + return true; + } + rtError_t status = rtMemcpyAsync(inputs[0]->addr, inputs[0]->size, inputs[1]->addr, inputs[1]->size, + RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Assign op rtMemcpyAsync failed!"; + return false; + } + return true; +} + +std::vector AssignKernel::GenTask(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) { + if (inputs.size() != 2) { + MS_LOG(EXCEPTION) << "inputs size is not two"; + } + stream_id_ = stream_id; + + std::shared_ptr task_info_ptr = + std::make_shared(kernel_name_, stream_id, inputs[0]->addr, inputs[0]->size, inputs[1]->addr, + inputs[1]->size, RT_MEMCPY_DEVICE_TO_DEVICE, false); + MS_EXCEPTION_IF_NULL(task_info_ptr); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h new file mode 100644 index 0000000000..cff946cc36 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H +#define MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H + +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class AssignKernel : public RtKernel { + public: + AssignKernel(); + ~AssignKernel() override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; +}; + +MS_REG_RTKERNEL(assign, AssignKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc new file mode 100644 index 0000000000..8ec460fe0b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/label_goto.h" +#include +#include +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::LabelGotoTaskInfo; +using LabelGotoTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +LabelGotoKernel::LabelGotoKernel() { label_ = 0; } + +LabelGotoKernel::~LabelGotoKernel() {} + +bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "LabelGotoKernel init"; + auto cnode = anf_node->cast(); + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { + MS_LOG(EXCEPTION) << "LabelGotoKernel has no attr label_index"; + } + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + label_ = GetValue(primitive->GetAttr(kAttrLabelIndex)); + MS_LOG(INFO) << "LabelGotoKernel get attr label:" << label_; + return true; +} + +bool LabelGotoKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/, + const std::vector & /*outputs*/, void * /*stream_ptr*/) { + MS_LOG(INFO) << "LabelGotoKernel launch"; + return true; +} + +std::vector LabelGotoKernel::GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t stream_id) { + MS_LOG(INFO) << "LabelGotoKernel GenTask label:" << label_ << ", stream id:" << stream_id; + std::vector task_info_list; + std::shared_ptr task_info_ptr = + std::make_shared(kernel_name_, stream_id, label_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + task_info_list.emplace_back(task_info_ptr); + return task_info_list; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h new file mode 100644 index 0000000000..2680d916a5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H +#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H + +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class LabelGotoKernel : public RtKernel { + public: + LabelGotoKernel(); + ~LabelGotoKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + uint32_t label_; +}; + +MS_REG_RTKERNEL(labelgoto, LabelGotoKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc new file mode 100644 index 0000000000..909885ff17 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/label_set.h" +#include +#include +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::LabelSetTaskInfo; +using LabelSetTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +LabelSetKernel::LabelSetKernel() { label_ = 0; } + +LabelSetKernel::~LabelSetKernel() {} + +bool LabelSetKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "LabelSetKernel init"; + auto cnode = anf_node->cast(); + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { + MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; + } + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + label_ = GetValue(primitive->GetAttr(kAttrLabelIndex)); + MS_LOG(INFO) << "LabelSetKernel get attr label:" << label_; + return true; +} + +bool LabelSetKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/, + const std::vector & /*outputs*/, void * /*stream_ptr*/) { + MS_LOG(INFO) << "LabelSetKernel launch"; + return true; +} + +std::vector LabelSetKernel::GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t stream_id) { + MS_LOG(INFO) << "LabelSetKernel GenTask label:" << label_ << ", stream id:" << stream_id; + std::vector task_info_list; + std::shared_ptr task_info_ptr = std::make_shared(kernel_name_, stream_id, label_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + task_info_list.emplace_back(task_info_ptr); + return task_info_list; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h new file mode 100644 index 0000000000..8d0cfdfb20 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H +#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H + +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class LabelSetKernel : public RtKernel { + public: + LabelSetKernel(); + ~LabelSetKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + uint32_t label_; +}; + +MS_REG_RTKERNEL(labelset, LabelSetKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc new file mode 100644 index 0000000000..ccb49d9497 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/label_switch.h" +#include +#include +#include +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::LabelSwitchTaskInfo; +using LabelSwitchTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +LabelSwitchKernel::LabelSwitchKernel() { + label_list_ = {}; + cond_ = nullptr; + label_size_ = 0; +} + +LabelSwitchKernel::~LabelSwitchKernel() {} + +bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "LabelSwitchKernel init"; + auto cnode = anf_node->cast(); + if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) { + MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; + } + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + label_list_ = GetValue>(primitive->GetAttr(kAttrLabelSwitchList)); + label_size_ = label_list_.size(); + MS_LOG(INFO) << "LabelSwitchKernel get attr label size:" << label_size_; + for (auto label : label_list_) { + MS_LOG(INFO) << "label: " << label; + } + return true; +} + +bool LabelSwitchKernel::Launch(const std::vector & /*inputs*/, + const std::vector & /*workspace*/, + const std::vector & /*outputs*/, void * /*stream_ptr*/) { + MS_LOG(INFO) << "LabelSwitchKernel launch"; + return true; +} + +std::vector LabelSwitchKernel::GenTask(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) { + MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; + std::vector task_info_list; + cond_ = inputs[0]->addr; + auto task_info_ptr = std::make_shared(kernel_name_, stream_id, label_size_, label_list_, cond_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + task_info_list.emplace_back(task_info_ptr); + return task_info_list; +} + +std::vector> LabelSwitchDesc::GetKernelInfo() { + std::vector> label_switch_build_info{}; + vector input_format{kOpFormat_DEFAULT}; + vector input_type{kNumberTypeInt32}; + if (input_format.size() != input_type.size()) { + MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " + << input_type.size(); + } + for (size_t i = 0; i < input_format.size(); ++i) { + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + builder.SetInputsFormat({input_format[i]}); + builder.SetInputsDeviceType({input_type[i]}); + builder.SetProcessor(AICORE); + builder.SetKernelType(RT_KERNEL); + builder.SetFusionType(OPAQUE); + label_switch_build_info.emplace_back(builder.Build()); + } + return label_switch_build_info; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h new file mode 100644 index 0000000000..1860d38d74 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H +#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H + +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class LabelSwitchKernel : public RtKernel { + public: + LabelSwitchKernel(); + ~LabelSwitchKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + std::vector label_list_; + uint32_t label_size_; + void *cond_; +}; + +class LabelSwitchDesc : public RtKerDesc { + public: + LabelSwitchDesc() = default; + ~LabelSwitchDesc() override = default; + std::vector> GetKernelInfo() override; +}; + +MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc); +MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc new file mode 100644 index 0000000000..ca1114a83f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/memcpy_async.h" + +#include +#include + +#include "runtime/mem.h" +#include "common/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/trans.h" +#include "utils/context/ms_context.h" + +using ge::model_runner::MemcpyAsyncTaskInfo; +using MemcpyAsyncTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +MemCpyAsyncKernel::MemCpyAsyncKernel() {} + +MemCpyAsyncKernel::~MemCpyAsyncKernel() {} + +bool MemCpyAsyncKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs, void *stream_ptr) { + if (inputs.size() != 1) { + MS_LOG(ERROR) << "inputs size is not one"; + return false; + } + if (outputs.size() != 1) { + MS_LOG(ERROR) << "outputs size is not one"; + return false; + } + + if (inputs[0]->addr == outputs[0]->addr) { + MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async"; + return true; + } + if (outputs[0]->size < inputs[0]->size) { + MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size"; + } + // input x -> memcpy_async -> AllReduce + if (outputs[0]->size > inputs[0]->size) { + MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; + } + rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, + RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "MemCpyAsync op rtMemcpyAsync failed!"; + return false; + } + return true; +} + +bool MemCpyAsyncKernel::Init(const mindspore::AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + GetInputOutputDataType(anf_node); + GetInputOutputTotalCount(anf_node); + return true; +} + +void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + size_t input_size = AnfAlgo::GetInputTensorNum(anf_node); + if (input_size != 1) { + MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1"; + } + input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0); +} + +void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + size_t input_size = AnfAlgo::GetInputTensorNum(anf_node); + if (input_size != 1) { + MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1"; + } + size_t type_size = trans::TypeIdSize(input_type_id_); + std::vector shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0); + size_t total_size = 1; + for (size_t i = 0; i < shape_i.size(); i++) { + total_size = total_size * shape_i[i]; + } + total_size *= type_size; + MS_LOG(INFO) << "MemCpyAsync size[" << total_size << "]"; + input_size_list_.emplace_back(total_size); + output_size_list_.emplace_back(total_size); +} + +std::vector MemCpyAsyncKernel::GenTask(const std::vector &inputs, + const std::vector &, + const std::vector &outputs, uint32_t stream_id) { + if (inputs.size() != 1) { + MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one"; + } + + if (outputs.size() != 1) { + MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one"; + } + + if (outputs[0]->size < inputs[0]->size) { + MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size"; + } + // input x -> memcpy_async -> AllReduce + if (outputs[0]->size > inputs[0]->size) { + MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; + } + + stream_id_ = stream_id; + std::shared_ptr task_info_ptr = + std::make_shared(kernel_name_, stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, + inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE, NeedDump()); + MS_EXCEPTION_IF_NULL(task_info_ptr); + return {task_info_ptr}; +} + +const std::vector data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, + kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, + kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16, + kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool}; +const std::vector format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, + kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, + kOpFormat_C1HWNCoC0}; + +MemCpyAsyncDesc::MemCpyAsyncDesc() {} + +MemCpyAsyncDesc::~MemCpyAsyncDesc() {} + +std::vector> MemCpyAsyncDesc::GetKernelInfo() { + std::vector> memcpy_build_info{}; + for (const auto &format : format_list) { + for (const auto &type : data_type_list) { + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + vector input_format{format}; + vector input_type{type}; + vector output_format{format}; + vector output_type{type}; + builder.SetInputsFormat(input_format); + builder.SetInputsDeviceType(input_type); + builder.SetOutputsFormat(output_format); + builder.SetOutputsDeviceType(output_type); + builder.SetProcessor(AICORE); + builder.SetKernelType(RT_KERNEL); + builder.SetFusionType(OPAQUE); + memcpy_build_info.emplace_back(builder.Build()); + } + } + return memcpy_build_info; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h new file mode 100644 index 0000000000..07a782be50 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H +#define MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H + +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class MemCpyAsyncKernel : public RtKernel { + public: + MemCpyAsyncKernel(); + ~MemCpyAsyncKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + void GetInputOutputDataType(const AnfNodePtr &anf_node); + void GetInputOutputTotalCount(const AnfNodePtr &anf_node); + TypeId input_type_id_{}; +}; + +class MemCpyAsyncDesc : public RtKerDesc { + public: + MemCpyAsyncDesc(); + ~MemCpyAsyncDesc() override; + std::vector> GetKernelInfo() override; +}; + +MS_REG_RTKERNEL_DESC(memcpy_async, MemCpyAsyncDesc); +MS_REG_RTKERNEL(memcpy_async, MemCpyAsyncKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc new file mode 100644 index 0000000000..8213468b48 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/profiling_kernel_mod.h" + +#include +#include +#include + +#include "framework/ge_runtime/task_info.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "backend/session/anf_runtime_algorithm.h" + +using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo; +using mindspore::device::ascend::ProfilingUtils; + +namespace mindspore { +namespace kernel { +bool ProfilingKernelMod::Init(const AnfNodePtr &anf_node) { + MS_LOG(INFO) << "[profiling] init profiling kernel mod"; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + + ValuePtr notify_ptr = primitive->GetAttr(ProfilingUtils::kNotify); + MS_EXCEPTION_IF_NULL(notify_ptr); + + ValuePtr log_id_ptr = primitive->GetAttr(ProfilingUtils::kProfilerTraceId); + MS_EXCEPTION_IF_NULL(log_id_ptr); + + ValuePtr flags_ptr = primitive->GetAttr(ProfilingUtils::kFlags); + MS_EXCEPTION_IF_NULL(flags_ptr); + + notify_ = GetValue(notify_ptr); + log_id_ = GetValue(log_id_ptr); + flags_ = GetValue(flags_ptr); + MS_LOG(INFO) << "[profiling] profiling kernel notify_:" << notify_ << ", log_id_:" << log_id_ + << ", flags_:" << flags_; + return true; +} + +bool ProfilingKernelMod::Launch(const std::vector & /*inputs*/, + const std::vector & /*workspace*/, + const std::vector & /*outputs*/, void * /*stream_ptr*/) { + return true; +} + +std::vector ProfilingKernelMod::GenTask(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) { + MS_LOG(INFO) << "gen task inputs size:" << inputs.size() << ", workspace size:" << workspace.size() + << ", outputs size:" << outputs.size(); + stream_id_ = stream_id; + std::shared_ptr task_info_ptr = + std::make_shared(kernel_name_, stream_id, log_id_, notify_, flags_); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h new file mode 100644 index 0000000000..cdb43afb3e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +namespace mindspore { +namespace kernel { +class ProfilingKernelMod : public RtKernel { + public: + ProfilingKernelMod() = default; + ~ProfilingKernelMod() override = default; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + bool Init(const AnfNodePtr &anf_node) override; + + private: + uint64_t log_id_{0}; + bool notify_{true}; + uint32_t flags_{0}; +}; +MS_REG_RTKERNEL(profiling, ProfilingKernelMod); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc new file mode 100644 index 0000000000..cee0ef2fdc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/recv.h" +#include +#include "runtime/stream.h" +#include "utils/context/ms_context.h" +#include "runtime/device/ascend/ascend_stream_assign.h" +#include "framework/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +using ge::model_runner::EventWaitTaskInfo; +using mindspore::device::ascend::AscendStreamAssign; +using EventWaitTaskInfoPtr = std::shared_ptr; + +RecvKernel::RecvKernel() { event_id_ = 0; } + +RecvKernel::~RecvKernel() {} + +bool RecvKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (!AnfAlgo::HasNodeAttr(kAttrEventId, anf_node->cast())) { + MS_LOG(EXCEPTION) << "RecvKernel has no attr kAttrEventId"; + } + event_id_ = GetValue(primitive->GetAttr(kAttrEventId)); + MS_LOG(INFO) << "recv op event_id_:" << event_id_; + return true; +} + +bool RecvKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + rtEvent_t stream_event{}; + auto status = rtStreamWaitEvent(stream_ptr, stream_event); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Recv rtStreamWaitEvent failed!"; + return false; + } + return true; +} + +std::vector RecvKernel::GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t stream_id) { + MS_LOG(INFO) << "RecvKernel GenTask event_id_:" << event_id_ << ", stream_id_:" << stream_id; + stream_id_ = stream_id; + EventWaitTaskInfoPtr task_info_ptr = std::make_shared(kernel_name_, stream_id, event_id_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h new file mode 100644 index 0000000000..73e0214eae --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RECV_H +#define MINDSPORE_CCSRC_KERNEL_RTS_RECV_H + +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class RecvKernel : public RtKernel { + public: + RecvKernel(); + ~RecvKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + uint32_t event_id_; +}; + +MS_REG_RTKERNEL(recv, RecvKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_RECV_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.cc new file mode 100644 index 0000000000..9279a84cf0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/rt_kernel.h" + +namespace mindspore { +namespace kernel { +void RtKernelFactory::Registe(const std::string &name, RtKernelCreater &&fun) { + (void)fmap_.emplace(name, std::move(fun)); +} + +std::shared_ptr RtKernelFactory::Create(const std::string &name) { + const auto &map = Get().fmap_; + auto it = map.find(name); + if (it != map.end() && it->second) { + return (it->second)(); + } + return nullptr; +} + +RtKernelFactory &RtKernelFactory::Get() { + static RtKernelFactory _this; + return _this; +} + +RtKernel::RtKernel() {} + +RtKernel::~RtKernel() {} + +bool RtKernel::Init(const mindspore::AnfNodePtr & /*anf_node*/) { return true; } + +const std::vector &RtKernel::GetInputSizeList() const { return input_size_list_; } + +const std::vector &RtKernel::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &RtKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h new file mode 100644 index 0000000000..dc0aa3e283 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H + +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "backend/kernel_compiler/task_stream.h" + +namespace mindspore { +namespace kernel { +class RtKernel : public AscendKernelMod { + public: + RtKernel(); + ~RtKernel() override; + virtual bool Init(const AnfNodePtr &anf_node); + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + + protected: + mutable std::vector input_size_list_; + mutable std::vector output_size_list_; + mutable std::vector workspace_size_list_; +}; + +using RTKernelPtr = std::shared_ptr; + +using RtKernelCreater = std::function()>; +class RtKernelFactory { + RtKernelFactory() = default; + ~RtKernelFactory() = default; + + public: + static RtKernelFactory &Get(); + void Registe(const std::string &name, RtKernelCreater &&fun); + static std::shared_ptr Create(const std::string &name); + + private: + std::map fmap_; +}; + +class _RtKernelRegister { + public: + _RtKernelRegister(const std::string &name, RtKernelCreater &&fun) { + RtKernelFactory::Get().Registe(name, std::move(fun)); + } + ~_RtKernelRegister() = default; +}; + +#define _MS_REG_RTKERNEL_REG(KNAME, clazz) \ + static_assert(std::is_base_of::value, " must be base of RtKernel"); \ + static const _RtKernelRegister g_##KNAME##_##_RtKernel_reg(#KNAME, []() { return std::make_shared(); }); + +#define MS_REG_RTKERNEL(KNAME, clazz) _MS_REG_RTKERNEL_REG(KNAME, clazz) +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.cc new file mode 100644 index 0000000000..9704a9b97f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/rt_kernel_build.h" + +#include +#include +#include +#include + +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +KernelModPtr RtOpBuild(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + (void)std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower); + MS_LOG(INFO) << "Op Name(tolower)[" << op_name << "]"; + auto ker_ptr = RtKernelFactory::Create(op_name); + MS_EXCEPTION_IF_NULL(ker_ptr); + if (!ker_ptr->Init(anf_node)) { + MS_LOG(ERROR) << "Rt Op initialize failed!"; + return nullptr; + } + + return ker_ptr; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h new file mode 100644 index 0000000000..ccfb8d923b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h @@ -0,0 +1,29 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H +#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H + +#include +#include +#include "backend/kernel_compiler/kernel.h" +namespace mindspore { +namespace kernel { +KernelModPtr RtOpBuild(const AnfNodePtr &anf_node); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc new file mode 100755 index 0000000000..9501aed5f2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/rt_kernel_info.h" +#include +#include +#include "utils/convert_utils.h" +#include "utils/utils.h" +#include "common/utils.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +void RtKerDescFactory::Register(const std::string &name, RtKerDescCreater &&fun) { + if (fmap_.find(name) == fmap_.end()) { + (void)fmap_.emplace(name, std::move(fun)); + } +} + +std::shared_ptr RtKerDescFactory::Create(const std::string &name) { + const auto &map = Get().fmap_; + auto it = map.find(name); + if (it != map.end() && it->second) { + return (it->second)(); + } + return nullptr; +} + +RtKerDescFactory &RtKerDescFactory::Get() { + static RtKerDescFactory _this; + return _this; +} + +static bool IsDefaultKernelInfo(const std::string &name) { + static const std::set white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName, + kLabelGotoOpName}; + return white_list.find(name) != white_list.end(); +} + +void GetRtKelInfo(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_info_list); + MS_EXCEPTION_IF_NULL(kernel_node); + std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node); + (void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower); + + auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower); + if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) { + *kernel_info_list = ker_desc_ptr->GetKernelInfo(); + return; + } + // if can't find kernel info in kernel info database, use the default kernel info + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if (IsDefaultKernelInfo(node_name)) { + auto kernel_build_info_builder = std::make_shared(); + // set input infos + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + kernel_build_info_builder->SetInputsFormat(std::vector(input_num, kOpFormat_DEFAULT)); + std::vector input_types = {}; + for (size_t i = 0; i < input_num; i++) { + input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i)); + } + kernel_build_info_builder->SetInputsDeviceType(input_types); + // set output info + auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + kernel_build_info_builder->SetOutputsFormat(std::vector(output_num, kOpFormat_DEFAULT)); + kernel_build_info_builder->SetOutputsDeviceType(std::vector(output_num, TypeId::kTypeUnknown)); + // set ohter info + kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); + kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); + kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); + kernel_info_list->push_back(kernel_build_info_builder->Build()); + return; + } + MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "]."; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h new file mode 100644 index 0000000000..6048fb3779 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H +#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H + +#include +#include +#include +#include +#include +#include +#include + +#include "ir/dtype.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "utils/utils.h" + +namespace mindspore { +namespace kernel { +class RtKerDesc { + public: + virtual ~RtKerDesc() {} + virtual std::vector> GetKernelInfo() { + return std::vector>{}; + } +}; + +using RtKerDescCreater = std::function()>; +class RtKerDescFactory { + RtKerDescFactory() = default; + ~RtKerDescFactory() = default; + + public: + static RtKerDescFactory &Get(); + void Register(const std::string &name, RtKerDescCreater &&fun); + static std::shared_ptr Create(const std::string &name); + + private: + std::map fmap_; +}; + +class _RtKerDescRegister { + public: + _RtKerDescRegister(const std::string &name, RtKerDescCreater &&fun) { + RtKerDescFactory::Get().Register(name, std::move(fun)); + } + ~_RtKerDescRegister() = default; +}; + +#define _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz) \ + static_assert(std::is_base_of::value, " must be base of RtKerDesc"); \ + static const _RtKerDescRegister g_##KNAME##_##_rtkernel_desc_reg(#KNAME, []() { return std::make_shared(); }); + +#define MS_REG_RTKERNEL_DESC(KNAME, clazz) _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz) + +void GetRtKelInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc new file mode 100644 index 0000000000..11c0a7d668 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/send.h" +#include +#include "runtime/event.h" +#include "framework/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::EventRecordTaskInfo; +using EventRecordTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +SendKernel::SendKernel() { event_id_ = 0; } + +SendKernel::~SendKernel() {} + +bool SendKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (!AnfAlgo::HasNodeAttr(kAttrEventId, anf_node->cast())) { + MS_LOG(EXCEPTION) << "SendKernel has no attr kAttrEventId"; + } + event_id_ = GetValue(primitive->GetAttr(kAttrEventId)); + MS_LOG(INFO) << "send op event id:" << event_id_; + return true; +} + +bool SendKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + rtEvent_t event{}; + rtError_t status = rtEventRecord(event, stream_ptr); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Send op rtEventRecord failed!"; + return false; + } + return true; +} + +std::vector SendKernel::GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t stream_id) { + MS_LOG(INFO) << "SendKernel GenTask event id:" << event_id_ << ", stream id:" << stream_id; + stream_id_ = stream_id; + EventRecordTaskInfoPtr task_info_ptr = std::make_shared(kernel_name_, stream_id, event_id_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/send.h b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h new file mode 100644 index 0000000000..dbadb1ef44 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_SEND_H +#define MINDSPORE_CCSRC_KERNEL_RTS_SEND_H +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class SendKernel : public RtKernel { + public: + SendKernel(); + ~SendKernel() override; + bool Init(const AnfNodePtr &anf_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + uint32_t event_id_; +}; + +MS_REG_RTKERNEL(send, SendKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_SEND_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc new file mode 100644 index 0000000000..e33549973d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/stream_active.h" +#include +#include +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::StreamActiveTaskInfo; +using StreamActiveTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +StreamActiveKernel::StreamActiveKernel() { active_streams_index_ = {}; } + +StreamActiveKernel::~StreamActiveKernel() {} + +bool StreamActiveKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "stream active op init start"; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (!AnfAlgo::HasNodeAttr(kAttrActiveStreamList, anf_node->cast())) { + MS_LOG(EXCEPTION) << "StreamActiveKernel has no attr kAttrActiveStreamList"; + } + active_streams_index_ = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); + return true; +} + +bool StreamActiveKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + MS_LOG(INFO) << "Stream active op launch start"; + + if (active_streams_index_.empty()) { + MS_LOG(ERROR) << "activeStreamList_ is empty!"; + return false; + } + + rtStream_t act_stream; + rtError_t status; + for (auto index : active_streams_index_) { + act_stream = kernel::TaskStream::GetInstance()->gen_stream_list()[index]; + status = rtStreamActive(act_stream, stream_ptr); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Stream active failed!"; + return false; + } + } + return true; +} + +std::vector StreamActiveKernel::GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t stream_id) { + MS_LOG(INFO) << "StreamActiveKernel GenTask active stream size:" << active_streams_index_.size() + << ", stream id:" << stream_id; + stream_id_ = stream_id; + std::vector task_info_list; + for (auto &index : active_streams_index_) { + std::shared_ptr task_info_ptr = + std::make_shared(kernel_name_, stream_id, index); + MS_EXCEPTION_IF_NULL(task_info_ptr); + task_info_list.emplace_back(task_info_ptr); + MS_LOG(INFO) << "StreamActiveKernel GenTask: streamId:" << stream_id << ", Active streamId:" << index; + } + return task_info_list; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h new file mode 100644 index 0000000000..409c3437dc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H +#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class StreamActiveKernel : public RtKernel { + public: + StreamActiveKernel(); + ~StreamActiveKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + std::vector active_streams_index_; +}; + +MS_REG_RTKERNEL(streamactive, StreamActiveKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc new file mode 100644 index 0000000000..5fe03b1960 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/rts/stream_switch.h" + +#include +#include + +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::StreamSwitchTaskInfo; +using StreamSwitchTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +StreamSwitchKernel::StreamSwitchKernel() { + cond_ = RT_EQUAL; + true_stream_index_ = 0; + data_type_ = RT_SWITCH_INT32; +} + +StreamSwitchKernel::~StreamSwitchKernel() {} + +bool StreamSwitchKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "stream switch op init start"; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (!AnfAlgo::HasNodeAttr(kAttrSwitchCondition, anf_node->cast())) { + MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrSwitchCondition"; + } + cond_ = tagRtCondition(GetValue(primitive->GetAttr(kAttrSwitchCondition))); + if (!AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, anf_node->cast())) { + MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrTrueBranchStream"; + } + true_stream_index_ = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); + if (!AnfAlgo::HasNodeAttr(kAttrDataType, anf_node->cast())) { + MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrDataType"; + } + data_type_ = tagRtSwitchDataType(GetValue(primitive->GetAttr(kAttrDataType))); + MS_LOG(INFO) << "cond_:" << static_cast(cond_) << ", true_stream_index_:" << true_stream_index_ + << ", data_type_:" << static_cast(data_type_); + return true; +} + +bool StreamSwitchKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + MS_LOG(INFO) << "stream switch op launch start"; + if (inputs.size() != 2) { + MS_LOG(EXCEPTION) << "Stream switch inputs size is " << inputs.size() << ", only support 2"; + } + + void *loop_cnt = inputs[0]->addr; + void *ites_per_loop = inputs[1]->addr; + rtStream_t true_stream_ = kernel::TaskStream::GetInstance()->gen_stream_list()[true_stream_index_]; + rtError_t status = rtStreamSwitchEx(loop_cnt, cond_, ites_per_loop, true_stream_, stream_ptr, data_type_); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Stream switch failed!"; + return false; + } + return true; +} + +std::vector StreamSwitchKernel::GenTask(const std::vector &inputs, + const std::vector &, const std::vector &, + uint32_t stream_id) { + MS_LOG(INFO) << "StreamSwitchKernel GenTask start"; + if (inputs.size() != 2) { + MS_LOG(EXCEPTION) << "stream switch inputs size is " << inputs.size() << ", is not two"; + } + stream_id_ = stream_id; + MS_EXCEPTION_IF_NULL(inputs[0]); + MS_EXCEPTION_IF_NULL(inputs[1]); + auto loop_cnt = inputs[0]->addr; + auto ites_per_loop = inputs[1]->addr; + MS_LOG(INFO) << "cond_:" << static_cast(cond_) << ", true_stream_index_:" << true_stream_index_ + << ", stream_id:" << stream_id; + std::shared_ptr task_info_ptr = std::make_shared( + kernel_name_, stream_id, true_stream_index_, loop_cnt, ites_per_loop, cond_, data_type_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h new file mode 100644 index 0000000000..64a51f68bf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H +#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H + +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class StreamSwitchKernel : public RtKernel { + public: + StreamSwitchKernel(); + ~StreamSwitchKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + rtCondition_t cond_; + uint32_t true_stream_index_; + rtSwitchDataType_t data_type_; +}; + +MS_REG_RTKERNEL(streamswitch, StreamSwitchKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H diff --git a/mindspore/ccsrc/kernel/task_stream.h b/mindspore/ccsrc/backend/kernel_compiler/task_stream.h similarity index 100% rename from mindspore/ccsrc/kernel/task_stream.h rename to mindspore/ccsrc/backend/kernel_compiler/task_stream.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc new file mode 100644 index 0000000000..449a9f4556 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc @@ -0,0 +1,424 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_adapter.h" + +#include +#include +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/opinfo.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +static std::map tbe_func_adapter_map = { + {"softmax", "softmax_v2"}, + {"log_softmax", "log_softmax_v2"}, + {"apply_momentum", "apply_momentum_d"}, + {"apply_ftrl", "apply_ftrl_d"}, + {"re_lu6", "relu6"}, + {"re_lu6_grad", "relu6_grad"}, + {"re_lu", "relu"}, + {"re_luv2", "relu_v2"}, + {"p_re_lu", "prelu"}, + {"p_re_lu_grad", "prelu_grad"}, + {"tensor_add", "add"}, + {"reduce_mean", "reduce_mean_d"}, + {"reduce_max", "reduce_max_d"}, + {"reduce_min", "reduce_min_d"}, + {"avg_pool_grad", "avg_pool_grad_d"}, + {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, + {"conv2d_backprop_input", "conv2d_backprop_input_d"}, + {"depthwise_conv2d_native", "depthwise_conv2d"}, + {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"}, + {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"}, + {"scatter_nd", "scatter_nd_d"}, + {"tile", "tile_d"}, + {"gather_v2", "gather_v2_d"}, + {"sparse_gather_v2", "gather_v2_d"}, + {"batch_mat_mul", "batch_matmul"}, + {"b_n_training_reduce", "bn_training_reduce"}, + {"b_n_training_update", "bn_training_update"}, + {"b_n_training_update_v2", "bn_training_update_v2"}, + {"b_n_training_update_v3", "bn_training_update_v3"}, + {"b_n_training_reduce_grad", "bn_training_reduce_grad"}, + {"b_n_training_update_grad", "bn_training_update_grad"}, + {"b_n_infer", "bn_infer"}, + {"b_n_infer_grad", "bn_infer_grad"}, + {"n_pu_clear_float_status", "n_p_u_clear_float_status"}, + {"n_pu_get_float_status", "n_p_u_get_float_status"}, + {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"}, + {"dropout_do_mask", "drop_out_do_mask"}, + {"strided_slice", "strided_slice_d"}, + {"strided_slice_grad", "strided_slice_grad_d"}, + {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, + {"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"}, + {"apply_ada_max", "apply_ada_max_d"}, + {"apply_adadelta", "apply_adadelta_d"}, + {"apply_adagrad", "apply_adagrad_d"}, + {"apply_adagrad_v2", "apply_adagradv2_d"}, + {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, + {"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"}, + {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, + {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, + {"apply_add_sign", "apply_add_sign_d"}, + {"apply_power_sign", "apply_power_sign_d"}, + {"transpose", "transpose_d"}, + {"fill", "fill_d"}, + {"unsorted_segment_sum", "unsorted_segment_sum_d"}, + {"unsorted_segment_prod", "unsorted_segment_prod_d"}, + {"concat", "concat_d"}, + {"slice", "slice_d"}, + {"reduce_sum", "reduce_sum_d"}, + {"inplace_add", "inplace_add_d"}, + {"inplace_sub", "inplace_sub_d"}, + {"one_hot", "one_hot_d"}, + {"sum", "reduce_sum_d"}, + {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"}, + {"lamb_next_mv", "lamb_next_m_v"}, + {"split", "split_d"}, + {"split_v", "split_v_d"}, + {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, + {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, + {"pad", "pad_d"}, + {"argmax", "arg_max_d"}, + {"argmin", "arg_min_d"}, + {"space_to_batch", "space_to_batch_d"}, + {"batch_to_space", "batch_to_space_d"}, + {"space_to_batch_nd", "space_to_batch_nd_d"}, + {"batch_to_space_nd", "batch_to_space_nd_d"}, + {"resize_bilinear", "resize_bilinear_v2_d"}, + {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, + {"adam", "apply_adam_d"}, + {"r_oi_align", "roi_align"}, + {"r_oi_align_grad", "roi_align_grad"}, + {"i_ou", "iou"}, + {"s_gd", "sgd"}, + {"l_rn", "lrn"}, + {"l_rn_grad", "lrn_grad"}, + {"l_ars_update", "lars_v2_update"}, + {"n_ms_with_mask", "nms_with_mask"}, + {"square_sum_all", "square_sum_all"}, + {"cum_sum", "cumsum_d"}, + {"range", "range_d"}, + {"lin_space", "lin_space_d"}, + {"inv_grad", "inv_grad"}, + {"apply_rms_prop", "apply_rms_prop_d"}, + {"cum_prod", "cumprod_d"}, + {"reduce_all", "reduce_all_d"}, + {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, + {"unsorted_segment_min", "unsorted_segment_min_d"}, + {"reduce_prod", "reduce_prod_d"}, + {"a_cos", "acos"}, + {"a_cos_grad", "acos_grad"}, + {"histogram_fixed_width", "histogram_fixed_width_d"}, + {"broadcast_to", "broadcast_to_d"}, + {"inplace_update", "inplace_update_d"}, + {"matrix_diag", "matrix_diag_d"}, + {"matrix_diag_part", "matrix_diag_part_d"}, + {"matrix_set_diag", "matrix_set_diag_d"}}; + +void TbeAdapter::NormalizeFuncName(std::string *func_name) { + if (func_name == nullptr) { + MS_LOG(EXCEPTION) << "func_name is null"; + } + std::string name_tmp; + bool sub_head = false; + for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) { + if (islower(*iter)) { + sub_head = false; + } + if (isdigit(*iter)) { + sub_head = true; + } + if (isupper(*iter) && iter != func_name->begin()) { + if (!sub_head) { + (void)name_tmp.insert(name_tmp.end(), '_'); + sub_head = true; + } else { + string::iterator iter_next = iter + 1; + if (iter_next != func_name->end()) { + if (islower(*iter_next)) { + (void)name_tmp.insert(name_tmp.end(), '_'); + } + } + } + } + (void)name_tmp.insert(name_tmp.end(), *iter); + } + (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower); + *func_name = name_tmp; + auto iter = tbe_func_adapter_map.find(*func_name); + if (iter != tbe_func_adapter_map.end()) { + MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second; + *func_name = iter->second; + } +} + +void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) { + std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0); + std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0); + if (input_format == kOpFormat_DEFAULT) { + input_format = kOpFormat_NCHW; + } + if (output_format == kOpFormat_DEFAULT) { + output_format = kOpFormat_NCHW; + } + AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node); + AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node); + } +} + +std::unordered_set input_order_adjusted_ops = { + "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop", + "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"}; + +void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, + nlohmann::json *inputs_json) { + MS_EXCEPTION_IF_NULL(inputs_json); + if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { + (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json))); + } else { + if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { + inputs_json->push_back(inputs_list[2]); + inputs_json->push_back(inputs_list[0]); + inputs_json->push_back(inputs_list[1]); + for (size_t i = 3; i < inputs_list.size(); ++i) { + inputs_json->push_back(inputs_list[i]); + } + } else if (op_name == "ApplyCenteredRMSProp") { + // Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map + // TBE parameter to correspond python API parameter by latter's index using hardcode + inputs_json->push_back(inputs_list[0]); + inputs_json->push_back(inputs_list[1]); + inputs_json->push_back(inputs_list[2]); + inputs_json->push_back(inputs_list[3]); + inputs_json->push_back(inputs_list[5]); + inputs_json->push_back(inputs_list[6]); + inputs_json->push_back(inputs_list[7]); + inputs_json->push_back(inputs_list[8]); + inputs_json->push_back(inputs_list[4]); + } else { + inputs_json->push_back(inputs_list[1]); + inputs_json->push_back(inputs_list[0]); + for (size_t i = 2; i < inputs_list.size(); ++i) { + inputs_json->push_back(inputs_list[i]); + } + } + } +} + +void TbeAdapter::FusionInputOrderPass(const std::string &op_name, const std::vector &inputs_list, + std::vector *inputs_json) { + MS_EXCEPTION_IF_NULL(inputs_json); + if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { + (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json))); + } else { + if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { + inputs_json->emplace_back(inputs_list[2]); + inputs_json->emplace_back(inputs_list[0]); + inputs_json->emplace_back(inputs_list[1]); + for (size_t i = 3; i < inputs_list.size(); ++i) { + inputs_json->emplace_back(inputs_list[i]); + } + } else { + inputs_json->emplace_back(inputs_list[1]); + inputs_json->emplace_back(inputs_list[0]); + for (size_t i = 2; i < inputs_list.size(); ++i) { + inputs_json->emplace_back(inputs_list[i]); + } + } + } +} + +void TbeAdapter::FusionDataOrderPass(const std::string &op_name, const std::vector &data_layer, + std::vector *reorder_data_layer) { + MS_EXCEPTION_IF_NULL(reorder_data_layer); + if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { + (void)std::copy(data_layer.begin(), data_layer.end(), std::back_inserter((*reorder_data_layer))); + } else { + if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { + reorder_data_layer->emplace_back(data_layer[2]); + reorder_data_layer->emplace_back(data_layer[0]); + reorder_data_layer->emplace_back(data_layer[1]); + for (size_t i = 3; i < data_layer.size(); ++i) { + reorder_data_layer->emplace_back(data_layer[i]); + } + } else { + reorder_data_layer->emplace_back(data_layer[1]); + reorder_data_layer->emplace_back(data_layer[0]); + for (size_t i = 2; i < data_layer.size(); ++i) { + reorder_data_layer->emplace_back(data_layer[i]); + } + } + } +} + +std::map TbeAdapter::build_json_attr_pass_map_ = { + {"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass}, + {"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass}, + {"Cast", TbeAdapter::CastAttrJsonPass}}; + +bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(attrs_json); + auto cnode_name = AnfAlgo::GetCNodeName(anf_node); + auto FPass = build_json_attr_pass_map_.find(cnode_name); + if (FPass != build_json_attr_pass_map_.end()) { + FPass->second(anf_node, op_info_attrs, attrs_json); + return true; + } + return false; +} + +void TbeAdapter::MaximumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(attrs_json); + auto attr_num = op_info_attrs.size(); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + for (size_t i = 0; i < attr_num; i++) { + nlohmann::json attr_obj; + MS_EXCEPTION_IF_NULL(op_info_attrs[i]); + std::string attr_name = op_info_attrs[i]->name(); + auto value = primitive->GetAttr(attr_name); + if (value != nullptr) { + bool attr_value = GetValue(value); + attr_obj["value"] = attr_value; + attr_obj["valid"] = true; + } else { + attr_obj["valid"] = false; + } + attr_obj["name"] = attr_name; + attrs_json->push_back(attr_obj); + } + MS_LOG(INFO) << "MaximumGradAttrJsonPass done."; +} + +void TbeAdapter::MinimumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(attrs_json); + auto attr_num = op_info_attrs.size(); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + for (size_t i = 0; i < attr_num; i++) { + nlohmann::json attr_obj; + MS_EXCEPTION_IF_NULL(op_info_attrs[i]); + std::string attr_name = op_info_attrs[i]->name(); + auto value = primitive->GetAttr(attr_name); + if (value != nullptr) { + bool attr_value = GetValue(value); + attr_obj["value"] = attr_value; + attr_obj["valid"] = true; + } else { + attr_obj["valid"] = false; + } + attr_obj["name"] = attr_name; + attrs_json->push_back(attr_obj); + } + MS_LOG(INFO) << "MinimumGradAttrJsonPass done."; +} + +static int TypeStrToDstType(const std::string &type_str) { + int ret = -1; + if (type_str == "Float" || type_str == "Float32") { + ret = 0; + } else if (type_str == "Float16") { + ret = 1; + } else if (type_str == "Int8") { + ret = 2; + } else if (type_str == "Int32") { + ret = 3; + } else if (type_str == "UInt8") { + ret = 4; + } else if (type_str == "UInt64") { + ret = 10; + } else if (type_str == "Bool") { + ret = 12; + } else { + MS_LOG(INFO) << "Error type str is invailed: " << type_str; + } + return ret; +} + +void TbeAdapter::CastAttrJsonPass(const mindspore::AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(attrs_json); + if (op_info_attrs.size() != 1) { + MS_LOG(INFO) << "cast node should has dst_type attr"; + return; + } + auto attr_name = op_info_attrs[0]->name(); + auto type_ptr = std::make_shared(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, 0))); + MS_EXCEPTION_IF_NULL(type_ptr); + auto type_element = type_ptr->element(); + MS_EXCEPTION_IF_NULL(type_element); + auto dtype = type_element->ToString(); + auto dst_type_value = TypeStrToDstType(dtype); + nlohmann::json attr_obj; + attr_obj["value"] = dst_type_value; + attr_obj["valid"] = true; + attr_obj["name"] = attr_name; + attrs_json->push_back(attr_obj); + MS_LOG(INFO) << "CastAttrJsonPass done."; +} + +void TbeAdapter::GenTopKV2IndicesTensorInfo(const std::shared_ptr &anf_node, + size_t real_input_index, std::vector *input_list, + mindspore::kernel::kCreaterType creater_type) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_list); + auto input_x_shape = AnfAlgo::GetOutputInferShape(anf_node, 0); + size_t last_dim = input_x_shape[input_x_shape.size() - 1]; + std::vector tensor_shape = {last_dim}; + std::vector tensor_origin_shape = {last_dim}; + std::string tensor_format = AnfAlgo::GetInputFormat(anf_node, static_cast(real_input_index)); + if (tensor_format == kOpFormat_DEFAULT) { + tensor_format = kOpFormat_NCHW; + } + std::string tensor_origin_format = kOpFormat_NCHW; + std::string tensor_dtype = "float16"; + nlohmann::json input_desc_json; + input_desc_json["dtype"] = tensor_dtype; + input_desc_json["name"] = AnfAlgo::GetCNodeName(anf_node); + input_desc_json["ori_shape"] = tensor_origin_shape; + input_desc_json["ori_format"] = tensor_origin_format; + input_desc_json["shape"] = tensor_shape; + if (creater_type == OP_SELECT_FORMAT) { + input_desc_json["format"] = tensor_origin_format; + } else { + input_desc_json["format"] = tensor_format; + } + input_desc_json["valid"] = true; + input_list->emplace_back(input_desc_json); +} +} // namespace tbe +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h new file mode 100644 index 0000000000..aa09efc11f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H + +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "base/base.h" +#include "backend/kernel_compiler/oplib/opinfo.h" +// Note: This file is mainly used to adapt the ME front-end operator description and +// the TBE back-end operator implementation difference +namespace mindspore { +namespace kernel { +enum kCreaterType : int { SINGLE_BUILD = 0, PREBUILD, OP_SELECT_FORMAT, CHECK_SUPPORTED, OP_PRE_COMPILE }; +namespace tbe { +using FAttrsPass = void (*)(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); +class TbeAdapter { + public: + TbeAdapter() = default; + ~TbeAdapter() = default; + static void NormalizeFuncName(std::string *func_name); + static void SetTbeAttrsForTransDataOp(const AnfNodePtr &anf_node); + static void InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, + nlohmann::json *inputs_json); + static bool RunAttrPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); + static void GenTopKV2IndicesTensorInfo(const std::shared_ptr &anf_node, size_t real_input_index, + std::vector *input_list, kCreaterType creater_type); + + static void FusionInputOrderPass(const std::string &op_name, const std::vector &inputs_list, + std::vector *inputs_json); + static void FusionDataOrderPass(const std::string &op_name, const std::vector &data_layer, + std::vector *reorder_data_layer); + + private: + static void MaximumGradAttrJsonPass(const AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); + static void MinimumGradAttrJsonPass(const AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); + + static void CastAttrJsonPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); + + static std::map build_json_attr_pass_map_; +}; +} // namespace tbe +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc new file mode 100644 index 0000000000..e7fd94ef84 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +const std::unordered_map type_str_id_maps = { + {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, + {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, + {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, + {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, + {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, + {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, + {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, + {"bool", TypeId::kNumberTypeBool}, +}; + +const std::map type_id_str_maps = { + {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, + {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, + {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, + {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, + {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, + {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, + {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, + {TypeId::kNumberTypeBool, "int8"}, +}; + +const std::map type_str_maps = { + {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, + {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, + {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"}, +}; + +const std::unordered_map type_nbyte_maps = { + {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, + {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, + {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, + {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, +}; + +const std::unordered_map fusion_type_maps = { + {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, + {"SEGMENT", FusionType::SEGMENT}, {"DYNAMIC", FusionType::DYNAMIC}, {"OPAQUE", FusionType::OPAQUE}, +}; + +TypeId DtypeToTypeId(const std::string &dtypes) { + auto iter = type_str_id_maps.find(dtypes); + if (iter == type_str_id_maps.end()) { + MS_LOG(EXCEPTION) << "Illegal input device dtype: " << dtypes; + } + return iter->second; +} + +std::string TypeIdToString(TypeId type_id) { + auto iter = type_id_str_maps.find(type_id); + if (iter == type_id_str_maps.end()) { + MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id); + } + return iter->second; +} + +size_t GetDtypeNbyte(const std::string &dtypes) { + auto iter = type_nbyte_maps.find(dtypes); + if (iter == type_nbyte_maps.end()) { + MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes; + } + return iter->second; +} + +FusionType GetFusionType(const std::string &pattern) { + auto iter = fusion_type_maps.find(pattern); + if (iter == fusion_type_maps.end()) { + MS_LOG(INFO) << "Illegal fusion pattern: " << pattern; + return UNKNOWN_FUSION_TYPE; + } + return iter->second; +} + +std::string GetProcessor(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string device; + switch (AnfAlgo::GetProcessor(anf_node)) { + case Processor::AICORE: + device = kProcessorAiCore; + break; + default: + MS_LOG(INFO) << "Unknown processor type." << anf_node->fullname_with_scope(); + break; + } + return device; +} +} // namespace tbe +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h new file mode 100644 index 0000000000..dea058cd56 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ + +#include +#include "backend/kernel_compiler/kernel.h" +#include "base/base.h" +#include "ir/dtype/type.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +constexpr auto kProcessorAiCore = "aicore"; +TypeId DtypeToTypeId(const std::string &dtypes); + +std::string TypeIdToString(TypeId type_id); + +size_t GetDtypeNbyte(const std::string &dtypes); + +FusionType GetFusionType(const std::string &pattern); + +std::string GetProcessor(const AnfNodePtr &anf_node); +} // namespace tbe +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc new file mode 100644 index 0000000000..73642b291a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -0,0 +1,1019 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include +#include +#include +#include "frontend/operator/ops.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/tbe/tbe_adapter.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +namespace mindspore { +namespace kernel { +using mindspore::kernel::tbe::TbeAdapter; +using mindspore::kernel::tbe::TbeUtils; +constexpr auto kFusionOpList = "op_list"; +constexpr auto kFusionKernelNamePrfix = "te_fusion"; +constexpr auto kOptional = "optional_"; +constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z"; +constexpr auto kPlatform = "platform"; +constexpr auto kPlatTBE = "TBE"; +constexpr auto kGenModel = "gen_model"; +constexpr auto kSingle = "single"; +constexpr auto kImplPath = "impl_path"; +constexpr auto kJInputs = "inputs"; +constexpr auto kJOutputs = "outputs"; +constexpr auto kJAttrs = "attrs"; +constexpr auto kJKernelName = "kernel_name"; +constexpr auto kJOpInfo = "op_info"; +constexpr auto kJDtype = "dtype"; +constexpr auto kJtype = "type"; +constexpr auto kJName = "name"; +constexpr auto kJOriShape = "ori_shape"; +constexpr auto kJOriFormat = "ori_format"; +constexpr auto kJShape = "shape"; +constexpr auto kJFormat = "format"; +constexpr auto kJValid = "valid"; +constexpr auto kJParamType = "param_type"; +constexpr auto kParamDynamic = "dynamic"; +constexpr auto kParamRequred = "required"; +constexpr auto kJDataType = "data_type"; +constexpr auto kJOutputIndex = "output_index"; +constexpr auto kJOutputDesc = "output_desc"; +constexpr auto kJInputDesc = "input_desc"; +constexpr auto kVTypeInt = "int"; +constexpr auto kVTypeStr = "str"; +constexpr auto kVTypeBool = "bool"; +constexpr auto kVTypeFloat = "float"; +constexpr auto kVTypeListInt = "listInt"; +constexpr auto kVTypeInt32 = "Int32"; +constexpr auto kVTypeListUInt64 = "listUInt64"; +constexpr auto kVTypeListFloat = "listFloat"; +constexpr auto kVTypeListListInt = "listListInt"; +constexpr auto kJValue = "value"; +constexpr auto kJDynIndex = "dyn_index"; +constexpr auto kJFuncName = "func_name"; + +std::string NormalizeFullScopeName(const string &full_scope_name) { + // exp:Default/ReLU-op0 -->Default_ReLU_op0 + string normal_ret = full_scope_name; + std::replace(normal_ret.begin(), normal_ret.end(), '/', '_'); + std::replace(normal_ret.begin(), normal_ret.end(), '-', '_'); + return normal_ret; +} + +bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr &anf_node, + nlohmann::json *kernel_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(kernel_json); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); + MS_EXCEPTION_IF_NULL(op_info_ptr); + (*kernel_json)[kPlatform] = kPlatTBE; + (*kernel_json)[kGenModel] = kSingle; + (*kernel_json)[kImplPath] = op_info_ptr->impl_path(); + nlohmann::json op_info_json; + if (op_info_ptr->impl_path().empty()) { + tbe::TbeAdapter::NormalizeFuncName(&op_name); + } else { + op_name = op_info_ptr->kernel_name(); + } + op_info_json[kJName] = op_name; + // generate inputs json + nlohmann::json inputs_json; + if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) { + MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate inputs json failed"; + return false; + } + op_info_json[kJInputs] = inputs_json; + // generate outputs json + nlohmann::json outputs_json; + if (!GenTbeOutputsJson(anf_node, op_info_ptr, &outputs_json)) { + MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate outputs json failed"; + return false; + } + op_info_json[kJOutputs] = outputs_json; + // generate attrs json + nlohmann::json attrs_json; + (void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json); + op_info_json[kJAttrs] = attrs_json; + std::string json_str = op_info_json.dump(); + size_t hash_id = std::hash()(json_str); + json_name_ = op_name + "_" + std::to_string(hash_id); + json_info_ = json_str; + if (creater_type_ == PREBUILD) { + op_info_json[kJKernelName] = NormalizeFullScopeName(anf_node->fullname_with_scope()); + } else { + op_info_json[kJKernelName] = json_name_; + } + (*kernel_json)[kJOpInfo] = op_info_json; + if (creater_type_ == SINGLE_BUILD) { + TbeUtils::SaveJsonInfo(json_name_, json_info_); + } + + MS_LOG(INFO) << "Operate type:" << creater_type_ << ", full scope name is :" << anf_node->fullname_with_scope() + << ", json info name is : " << json_name_ << ", kernel json:" << kernel_json->dump(); + + return true; +} + +bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, + bool value, const std::shared_ptr &input_ptr, + const string &op_input_name, size_t input_i, + std::vector *input_list) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_ptr); + MS_EXCEPTION_IF_NULL(input_list); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) { + TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_); + } else { + auto dtype = GetDeviceInputType(anf_node, real_input_index); + auto format = GetDeviceInputFormat(anf_node, real_input_index); + auto shape = GetDeviceInputShape(anf_node, real_input_index); + auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); + if (ori_shape.empty()) { + ori_shape.emplace_back(1); + } + nlohmann::json input_desc_json; + input_desc_json[kJDtype] = dtype; + input_desc_json[kJName] = op_input_name + std::to_string(input_i); + input_desc_json[kJOriShape] = ori_shape; + input_desc_json[kJOriFormat] = kOpFormat_NCHW; + input_desc_json[kJShape] = shape; + input_desc_json[kJFormat] = format; + input_desc_json[kJValid] = value; + input_desc_json[kJParamType] = input_ptr->param_type(); + input_list->emplace_back(input_desc_json); + } + return true; +} + +bool TbeKernelJsonCreator::GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num, + const std::shared_ptr &input_ptr, size_t *real_input_index, + string *op_input_name, std::vector *input_list) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_ptr); + MS_EXCEPTION_IF_NULL(real_input_index); + MS_EXCEPTION_IF_NULL(op_input_name); + MS_EXCEPTION_IF_NULL(input_list); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node); + bool value = true; + for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { + if (*real_input_index >= real_input_num) { + if (input_ptr->param_type() == "optional") { + *op_input_name = input_ptr->name() + "_optional_"; + nlohmann::json input_desc_json; + input_desc_json[kJValid] = false; + input_desc_json[kJName] = *op_input_name + std::to_string(*real_input_index); + input_list->emplace_back(input_desc_json); + continue; + } + MS_LOG(ERROR) << "Input num: " << *real_input_index << " is not match op inputs"; + return false; + } + if (op_name == "BatchNorm") { + if (input_ptr->name() == "mean" || input_ptr->name() == "variance") { + auto attr = primitive->GetAttr("is_training"); + MS_EXCEPTION_IF_NULL(attr); + bool is_training = GetValue(attr); + MS_LOG(INFO) << "Op_name" << op_name << ", tensor_name " << input_ptr->name() << ", is_training " + << is_training; + if (is_training) { + (*real_input_index)++; + break; + } + } + } + bool ret = GenInputDescJson(anf_node, *real_input_index, value, input_ptr, *op_input_name, input_i, input_list); + (*real_input_index)++; + if (!ret) { + return false; + } + } + return true; +} + +bool GetInputNameAndRealNum(const std::shared_ptr &anf_node, const std::shared_ptr &input_ptr, + size_t *dyn_input_index, size_t *input_num, std::string *op_input_name) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_ptr); + MS_EXCEPTION_IF_NULL(dyn_input_index); + MS_EXCEPTION_IF_NULL(input_num); + MS_EXCEPTION_IF_NULL(op_input_name); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. + std::vector dyn_input_sizes; + if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { + dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); + } + + if (input_ptr->param_type() == kParamDynamic) { + if (*dyn_input_index >= dyn_input_sizes.size()) { + MS_LOG(ERROR) << "Dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size(); + return false; + } + *input_num = IntToSize(dyn_input_sizes[*dyn_input_index]); + *op_input_name = input_ptr->name() + "_dynamic_"; + (*dyn_input_index)++; + // if optional input is exist + } else { + *input_num = 1; + *op_input_name = input_ptr->name() + "_"; + } + return true; +} + +bool TbeKernelJsonCreator::GenTbeInputsJson(const std::shared_ptr &anf_node, + const std::shared_ptr &op_info, nlohmann::json *inputs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_info); + MS_EXCEPTION_IF_NULL(inputs_json); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kAtomicAddrCleanOpName) { + return true; + } + std::vector> inputs_ptr = op_info->inputs_ptr(); + if (inputs_ptr.empty()) { + MS_LOG(INFO) << "Apply kernel " << op_name << "registration info has no input info"; + return true; + } + auto op_info_input_num = inputs_ptr.size(); + size_t dyn_input_index = 0; + size_t real_input_index = 0; + std::vector> inputs_list; + for (size_t i = 0; i < op_info_input_num; i++) { + size_t input_tensor_num; + std::shared_ptr input_ptr = inputs_ptr[i]; + std::string op_input_name; + MS_EXCEPTION_IF_NULL(input_ptr); + if (!GetInputNameAndRealNum(anf_node, input_ptr, &dyn_input_index, &input_tensor_num, &op_input_name)) { + return false; + } + std::vector input_list; + if (!GenInputList(anf_node, input_tensor_num, input_ptr, &real_input_index, &op_input_name, &input_list)) { + return false; + } + inputs_list.emplace_back(input_list); + } + + TbeAdapter::InputOrderPass(op_name, inputs_list, inputs_json); + return true; +} + +bool TbeKernelJsonCreator::GenTbeOutputsJson(const std::shared_ptr &anf_node, + const std::shared_ptr &op_info, nlohmann::json *outputs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_info); + MS_EXCEPTION_IF_NULL(outputs_json); + auto op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kAtomicAddrCleanOpName) { + return true; + } + auto outputs_ptr = op_info->outputs_ptr(); + return GenOutputDescJson(anf_node, outputs_ptr, outputs_json); +} + +bool TbeKernelJsonCreator::GenOutputDescJson( + const std::shared_ptr &anf_node, + const std::vector> &outputs_ptr, nlohmann::json *outputs_json) { + MS_EXCEPTION_IF_NULL(outputs_json); + size_t output_idx = 0; + auto op_name = AnfAlgo::GetCNodeName(anf_node); + size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node); + + for (const auto &output_ptr : outputs_ptr) { + size_t output_obj_num = 0; + if (output_ptr->param_type() == kParamRequred) { + output_obj_num = 1; + } else if (output_ptr->param_type() == kParamDynamic) { + if (outputs_ptr.size() > 1) { + MS_LOG(ERROR) << "Dynamic output is unsupported multi output!"; + return false; + } + output_obj_num = real_output_num; + } else { + if (output_idx >= real_output_num) { + MS_LOG(INFO) << "Op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none."; + std::vector output_list; + nlohmann::json output_obj; + output_obj[kJName] = output_ptr->name(); + output_obj[kJValid] = false; + output_list.emplace_back(output_obj); + (*outputs_json).push_back(output_list); + continue; + } else { + output_obj_num = 1; + } + } + std::vector output_list; + GenOutputList(anf_node, output_obj_num, output_ptr, &output_idx, &output_list); + (*outputs_json).push_back(output_list); + } + return true; +} + +void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, + const std::shared_ptr &output_ptr, size_t *output_idx, + std::vector *output_list) { + MS_EXCEPTION_IF_NULL(output_idx); + MS_EXCEPTION_IF_NULL(output_list); + for (size_t i = 0; i < output_obj_num; i++) { + auto dtype = GetDeviceOutputType(anf_node, *output_idx); + auto format = GetDeviceOutputFormat(anf_node, *output_idx); + auto shape = GetDeviceOutputShape(anf_node, *output_idx); + std::vector ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); + if (ori_shape.empty()) { + ori_shape.emplace_back(1); + } + nlohmann::json output_obj; + output_obj[kJDtype] = dtype; + output_obj[kJShape] = shape; + output_obj[kJFormat] = format; + output_obj[kJOriShape] = ori_shape; + output_obj[kJOriFormat] = kOpFormat_NCHW; + output_obj[kJName] = output_ptr->name(); + output_obj[kJValid] = true; + output_obj[kJParamType] = output_ptr->param_type(); + output_list->emplace_back(output_obj); + (*output_idx)++; + } +} + +bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr &anf_node, + const std::shared_ptr &op_info, nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_info); + MS_EXCEPTION_IF_NULL(attrs_json); + auto attrs_ptr = op_info->attrs_ptr(); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (TbeAdapter::RunAttrPass(anf_node, attrs_ptr, attrs_json)) { + return true; + } + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + for (const auto &attr_ptr : attrs_ptr) { + std::string attr_name = attr_ptr->name(); + nlohmann::json attr_obj; + attr_obj[kJName] = attr_name; + if (op_name == parallel::LAYER_NORM && attr_obj[kJName] == "epsilon" && creater_type_ == OP_SELECT_FORMAT) { + continue; + } + if (primitive->GetAttr(attr_name) != nullptr) { + auto value = primitive->GetAttr(attr_name); + std::string type = attr_ptr->type(); + ParseAttrValue(type, value, &attr_obj); + attr_obj[kJValid] = true; + } else { + if (op_info->impl_path().empty()) { + attr_obj[kJValid] = false; + } else { + if (attr_ptr->param_type() == kParamRequred && creater_type_ == SINGLE_BUILD) { + MS_LOG(EXCEPTION) << "Op name: " << op_info->op_name() << " attr: " << attr_name + << " is required, but not set."; + } else { + attr_obj[kJValid] = false; + } + } + } + (*attrs_json).push_back(attr_obj); + } + return true; +} + +void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, + nlohmann::json *attr_obj) { + MS_EXCEPTION_IF_NULL(value); + MS_EXCEPTION_IF_NULL(attr_obj); + if (type == kVTypeInt) { + auto attr_value = GetValue(value); + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeStr) { + auto attr_value = GetValue(value); + if (attr_value == kOpFormat_FRAC_Z) { + attr_value = kOpFormat_FRACTAL_Z; + } + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeBool) { + auto attr_value = GetValue(value); + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeFloat) { + auto attr_value = GetValue(value); + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeListInt) { + std::vector attr_value; + auto value_type = value->type(); + MS_EXCEPTION_IF_NULL(value_type); + auto value_type_str = value_type->ToString(); + if (value_type_str == kVTypeInt32) { + int data = GetValue(value); + attr_value.push_back(data); + } else { + attr_value = GetValue>(value); + } + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeListFloat) { + std::vector attr_value; + auto value_type = value->type(); + MS_EXCEPTION_IF_NULL(value_type); + auto value_type_str = value_type->ToString(); + if (value_type_str == kVTypeFloat) { + auto data = GetValue(value); + attr_value.push_back(data); + } else { + attr_value = GetValue>(value); + } + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeListUInt64) { + auto attr_value = GetValue>(value); + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeListListInt) { + auto attr_value = GetValue>>(value); + (*attr_obj)[kJValue] = attr_value; + } else { + MS_LOG(EXCEPTION) << "Type: " << type << "not support"; + } +} + +std::vector TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector shape; + if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { + shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index); + } else { + shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index); + } + if (shape.empty()) { + shape.emplace_back(1); + } + return shape; +} + +std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + TypeId type_id; + if (creater_type_ == OP_SELECT_FORMAT) { + type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index); + } else { + type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index); + } + return tbe::TypeIdToString(type_id); +} + +std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::string format = kOpFormat_NCHW; + if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { + format = AnfAlgo::GetInputFormat(anf_node, real_index); + if (format == kOpFormat_FRAC_Z) { + format = kOpFormat_FRACTAL_Z; + } else if (format == kOpFormat_DEFAULT) { + format = kOpFormat_NCHW; + } + } + return format; +} + +std::vector TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector shape; + if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { + shape = AnfAlgo::GetOutputInferShape(anf_node, real_index); + } else { + shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index); + } + if (shape.empty()) { + shape.emplace_back(1); + } + return shape; +} + +std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + TypeId type_id; + if (creater_type_ == OP_SELECT_FORMAT) { + type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index); + } else { + type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index); + } + return tbe::TypeIdToString(type_id); +} + +std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::string format = kOpFormat_NCHW; + if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { + format = AnfAlgo::GetOutputFormat(anf_node, real_index); + if (format == kOpFormat_FRAC_Z) { + format = kOpFormat_FRACTAL_Z; + } else if (format == kOpFormat_DEFAULT) { + format = kOpFormat_NCHW; + } + } + return format; +} + +bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, + std::vector *output_size_list) { + if (input_size_list == nullptr || output_size_list == nullptr) { + MS_LOG(ERROR) << "Input size or output size is nullptr"; + return false; + } + input_size_list->clear(); + output_size_list->clear(); + for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) { + for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) { + size_t size_i = 1; + if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) { + std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName]; + MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false."; + continue; + } + for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) { + size_i *= static_cast(j); + } + std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype]; + size_t nbyte = tbe::GetDtypeNbyte(dtype); + size_i *= nbyte; + input_size_list->push_back(size_i); + } + } + for (size_t i = 0; i < kernel_json[kJOpInfo][kJOutputs].size(); i++) { + for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) { + size_t size_i = 1; + if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) { + std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName]; + MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false."; + continue; + } + for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) { + size_i *= static_cast(j); + } + std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype]; + size_t nbyte = tbe::GetDtypeNbyte(dtype); + size_i *= nbyte; + output_size_list->push_back(size_i); + } + } + return true; +} + +bool TbeKernelBuild::GenFusionScopeJson(const std::vector &input_nodes, + const std::vector &compute_nodes, + nlohmann::json *fusion_str, std::string *fusion_kernel) { + MS_EXCEPTION_IF_NULL(fusion_str); + MS_EXCEPTION_IF_NULL(fusion_kernel); + // get input layer info + std::vector> input_layers; + std::map spec_data_input; + if (!GetInputLayers(input_nodes, compute_nodes, &input_layers, &spec_data_input)) { + return false; + } + // gen fusion scopre_op jsom + std::vector compute_list; + (*fusion_kernel) = kFusionKernelNamePrfix; + // index: fusion build option input record, next one from 0 + static size_t index = 0; + auto layer_iter = input_layers.begin(); + auto compute_op_iter = compute_nodes.begin(); + for (; compute_op_iter != compute_nodes.end(); ++compute_op_iter, ++layer_iter) { + nlohmann::json compute_op_str; + (void)GenFusionComputeJson(*compute_op_iter, &layer_iter, &compute_op_str, fusion_kernel, &index); + compute_list.push_back(compute_op_str); + } + index = 0; + // gen data input json + std::vector data_list; + for (const auto &layer : input_layers) { + for (const auto &data_input : layer) { + nlohmann::json data_str; + if (!GenFusionDataInputJson(data_input, spec_data_input, &data_str, &index)) { + MS_LOG(INFO) << "Fusion error: gen fusion datainput json faild."; + return false; + } + data_list.push_back(data_str); + } + } + index = 0; + data_list.insert(data_list.end(), compute_list.begin(), compute_list.end()); + (*fusion_str)[kFusionOpList] = data_list; + return true; +} + +void TbeKernelBuild::GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) { + std::string output_desc_name = anf_node->fullname_with_scope(); + if (node_out_idx > 0) { + output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx); + } + (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name); + auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx); + (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id); + auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx); + if (ori_shape.empty()) { + ori_shape.emplace_back(1); + } + (*output_desc)[kJOriShape] = ori_shape; + auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx); + if (shape.empty()) { + shape.emplace_back(1); + } + (*output_desc)[kJShape] = shape; + auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx); + if (format == kOpFormat_DEFAULT) { + format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND; + } + (*output_desc)[kJFormat] = format; + (*output_desc)[kJOriFormat] = kOpFormat_NCHW; + (*output_desc)[kJOutputIndex] = desc_output_idx; + if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) { + std::vector spec_shape = {}; + spec_shape.emplace_back(shape[0]); + spec_shape.emplace_back(shape[1]); + spec_shape.emplace_back(shape[2] * shape[3]); + spec_shape.emplace_back(shape[4]); + (*output_desc)[kJShape] = spec_shape; + } else if (fusion_data_type == kFusionReLUGradV2) { + std::vector spec_shape = {}; + spec_shape.emplace_back(shape[0]); + spec_shape.emplace_back(shape[1]); + spec_shape.emplace_back(shape[2] * shape[3]); + spec_shape.emplace_back(16); + (*output_desc)[kJShape] = spec_shape; + (*output_desc)[kJDataType] = kVTypeBool; + } +} + +void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, + size_t output_index, nlohmann::json *output_desc) { + std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index); + (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name); + (*output_desc)[kJOutputIndex] = output_index; + std::vector shape; + (*output_desc)[kJShape] = shape; +} + +bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name, + const std::vector &reorder_layer, + std::map *spec_data_input) { + if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) { + MS_LOG(INFO) << "Fusion error: node(" << op_name << " )'s input is null. "; + return false; + } + MS_LOG(INFO) << "Fusion info: op_name: " << op_name << "input layer size: " << reorder_layer.size(); + if (op_name == kReluGradV2OpName) { + (*spec_data_input)[reorder_layer[0]] = kFusionReLUGradV2; + } else if (op_name == kAddNOpName) { + for (const auto &it : reorder_layer) { + (*spec_data_input)[it] = kFusionAddN; + } + } + return true; +} + +bool TbeKernelBuild::GetInputLayers(const std::vector &input_nodes, + const std::vector &compute_nodes, + std::vector> *input_layers, + std::map *spec_data_input) { + MS_EXCEPTION_IF_NULL(input_layers); + MS_EXCEPTION_IF_NULL(spec_data_input); + auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) { + auto op_name = AnfAlgo::GetCNodeName(it); + return op_name == kConv2DBackpropInputOpName; + }); + bool need_spec = (result != compute_nodes.end()); + size_t input_size = 0; + for (const auto &compute_node : compute_nodes) { + std::vector layer = {}; + std::vector reorder_layer = {}; + MS_EXCEPTION_IF_NULL(compute_node); + auto op_name = AnfAlgo::GetCNodeName(compute_node); + auto ccompute_node = compute_node->cast(); + if (ccompute_node == nullptr) { + MS_LOG(INFO) << "Fusion error: fusion compute node must be cnode"; + return false; + } + MS_LOG(INFO) << "Fusion info: compute name: " << compute_node->fullname_with_scope(); + for (size_t i = 1; i < ccompute_node->inputs().size(); ++i) { + auto input = ccompute_node->input(i); + auto find_iter = std::find(input_nodes.begin(), input_nodes.end(), input); + if (find_iter != input_nodes.end()) { + MS_LOG(INFO) << "Fusion info: add compute node's [" << i << "] input: " << input->fullname_with_scope(); + layer.emplace_back((*find_iter)); + } else { + MS_LOG(INFO) << "Fusion warnig: this input [" << i << "] may be pre compute(" << input->fullname_with_scope() + << ") node's output."; + } + } + TbeAdapter::FusionDataOrderPass(op_name, layer, &reorder_layer); + if (need_spec) { + MS_LOG(INFO) << "Fusion info: match conv2d backprop input + ... patten."; + if (!GetSpecInputLayers(op_name, reorder_layer, spec_data_input)) { + return false; + } + } + input_size += reorder_layer.size(); + input_layers->emplace_back(reorder_layer); + } + if (input_nodes.size() != input_size) { + MS_LOG(INFO) << "Fusion error: fusion scope error, layer input:" << input_size + << ", input_node:" << input_nodes.size(); + return false; + } + return true; +} + +bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr &data_input, + const std::map &spec_data_input, + nlohmann::json *data_str, size_t *index) { + MS_EXCEPTION_IF_NULL(data_str); + MS_EXCEPTION_IF_NULL(index); + std::vector output_desc_list; + if (!data_input) { + MS_LOG(INFO) << "Data input is optional node"; + auto name = std::string(kOptional) + std::to_string(*index); + (*data_str)[kJName] = name; + nlohmann::json output_desc; + output_desc[kJName] = name; + output_desc[kJShape] = "NULL"; + output_desc_list.push_back(output_desc); + (*index)++; + } else { + FusionDataType fusion_data_type = kFusionNormal; + if (spec_data_input.find(data_input) != spec_data_input.end()) { + fusion_data_type = spec_data_input.at(data_input); + } + auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0); + auto real_node = kernel_idx.first; + size_t real_idx = kernel_idx.second; + MS_LOG(INFO) << "Real name " << real_node->fullname_with_scope() << " index:" << real_idx; + // kJOutputDesc + nlohmann::json output_desc; + GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type); + output_desc_list.push_back(output_desc); + (*data_str)[kJName] = NormalizeFullScopeName(real_node->fullname_with_scope()); + } + (*data_str)[kJOutputDesc] = output_desc_list; + (*data_str)[kJtype] = "Data"; + return true; +} + +bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. + bool ret = false; + std::vector dyn_input_sizes; + auto dynamic_input_attr = primitive->GetAttr(kAttrDynInputSizes); + if (dynamic_input_attr != nullptr) { + dyn_input_sizes = GetValue>(dynamic_input_attr); + auto real_input_size = cnode->inputs().size() - 1; + auto dyn_input_size = dyn_input_sizes.size(); + if (dyn_input_size != 1) { + MS_LOG(INFO) << "Fusion error: fusion build not support dyn_input_sizes > 1"; + return ret; + } + if (IntToSize(dyn_input_sizes[0]) != real_input_size) { + MS_LOG(INFO) << "Fusion error: dyn_input_size" << dyn_input_sizes[0] << "not equal real_input_size" + << real_input_size; + return ret; + } + ret = true; + } + return ret; +} + +size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool is_dynamic_input) { + MS_EXCEPTION_IF_NULL(cnode); + if (is_dynamic_input) { + return 0; + } + MS_EXCEPTION_IF_NULL(cnode); + auto node_name = AnfAlgo::GetCNodeName(cnode); + auto op_info = OpLib::FindOp(node_name, kTBE); + MS_EXCEPTION_IF_NULL(cnode); + if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) { + MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope(); + } + return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size()); +} + +std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) { + static std::map buffer_fussion_op_map = { + {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}}; + string result = origin_type; + auto iter = buffer_fussion_op_map.find(origin_type); + if (iter != buffer_fussion_op_map.end()) { + result = iter->second; + } + return result; +} + +bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, + std::vector>::iterator *layer_iter, + std::vector *input_desc_list, size_t *index) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(input_desc_list); + std::vector input_desc_list_tmp = {}; + bool is_dynamic_input = IsDynamicInput(cnode); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto input = cnode->input(i); + auto kernel_idx = AnfAlgo::VisitKernel(input, 0); + auto real_node = kernel_idx.first; + size_t real_idx = kernel_idx.second; + MS_LOG(INFO) << "Real name" << real_node->fullname_with_scope() << "index:" << real_idx; + nlohmann::json input_desc; + GenDescJson(real_node, real_idx, real_idx, &input_desc); + if (is_dynamic_input) { + MS_LOG(INFO) << "Node has dynamic input."; + input_desc[kJDynIndex] = (i - 1); + } + input_desc_list_tmp.emplace_back(input_desc); + } + size_t optional_num = GetOptionalInput(cnode, is_dynamic_input); + if (optional_num > 0) { + MS_LOG(INFO) << "Node has optional input."; + for (size_t i = 0; i < optional_num; ++i) { + nlohmann::json optional_input_desc; + optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index); + (*index)++; + (*layer_iter)->emplace_back(nullptr); + input_desc_list_tmp.emplace_back(optional_input_desc); + } + } + auto op_name = AnfAlgo::GetCNodeName(cnode); + TbeAdapter::FusionInputOrderPass(op_name, input_desc_list_tmp, input_desc_list); + return true; +} + +std::vector TbeKernelBuild::GetDescOutputIndex(const std::vector &output_used_nums) { + std::vector desc_output_index = {}; + for (size_t idx = 0; idx < output_used_nums.size(); ++idx) { + auto output_use_num_item = output_used_nums[idx]; + MS_LOG(INFO) << "Output used num[" << idx << "] = " << output_use_num_item; + desc_output_index.emplace_back(idx); + if (output_use_num_item > 1) { + desc_output_index.emplace_back(idx); + } + } + return desc_output_index; +} + +bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, + std::vector *output_desc_list) { + MS_EXCEPTION_IF_NULL(output_desc_list); + auto output_size = AnfAlgo::GetOutputTensorNum(cnode); + if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { + auto output_used_nums = AnfAlgo::GetNodeAttr>(cnode, kAttrOutputUsedNum); + MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope(); + if (output_used_nums.size() != output_size) { + MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")" + << " is not match output used num(" << output_used_nums.size() << ")"; + return false; + } + auto desc_output_index = GetDescOutputIndex(output_used_nums); + for (size_t i = 0; i < output_size; ++i) { + MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i]; + nlohmann::json output_desc; + GenDescJson(cnode, i, desc_output_index[i], &output_desc); + output_desc_list->emplace_back(output_desc); + } + for (size_t j = output_size; j < desc_output_index.size(); ++j) { + MS_LOG(INFO) << "Fusion index: " << j << ", desc_output_index: " << desc_output_index[j]; + nlohmann::json output_desc; + GenReusedOutputDesc(cnode, j, desc_output_index[j], &output_desc); + output_desc_list->emplace_back(output_desc); + } + } else { + for (size_t i = 0; i < output_size; ++i) { + nlohmann::json output_desc; + GenDescJson(cnode, i, i, &output_desc); + output_desc_list->push_back(output_desc); + } + } + return true; +} + +bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, + std::vector>::iterator *layer_iter, + nlohmann::json *compute_op_str, std::string *fusion_kernel_name, + size_t *index) { + MS_EXCEPTION_IF_NULL(compute_node); + auto cnode = compute_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // gen input desc + std::vector input_desc_list; + (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index); + (*compute_op_str)[kJInputDesc] = input_desc_list; + // gen output desc + std::vector output_desc_list; + if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) { + MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node full name: " << cnode->fullname_with_scope(); + return false; + } + (*compute_op_str)[kJOutputDesc] = output_desc_list; + // gen others + auto origin_type = AnfAlgo::GetCNodeName(cnode); + // replace special op type for buffer fusion op + auto type = GetRealOpType(origin_type); + (*compute_op_str)[kJtype] = type; + tbe::TbeAdapter::NormalizeFuncName(&type); + (*compute_op_str)[kJFuncName] = type; + (*compute_op_str)[kJName] = NormalizeFullScopeName(cnode->fullname_with_scope()); + (void)(*fusion_kernel_name).append("_"); + (void)(*fusion_kernel_name).append(type); + return true; +} + +size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) { + size_t ret = 1; + for (const auto &shape_item : desc[kJShape]) { + ret *= static_cast(shape_item); + } + std::string data_type = desc[kJDataType]; + size_t nbyte = tbe::GetDtypeNbyte(data_type); + ret *= nbyte; + return ret; +} + +bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, + const std::vector &output_nodes, + std::vector *input_size_list, std::vector *output_size_list) { + MS_EXCEPTION_IF_NULL(input_size_list); + MS_EXCEPTION_IF_NULL(output_size_list); + input_size_list->clear(); + output_size_list->clear(); + + for (const auto &op : fusion_op_list) { + if (op[kJtype] == "Data") { + const auto &data_output_desc = op[kJOutputDesc]; + for (const auto &data_output : data_output_desc) { + if (data_output[kJShape] == "NULL") { + break; + } + auto ret = GetIOSizeImpl(data_output); + input_size_list->push_back(ret); + MS_LOG(INFO) << "Fusion info: scope input name: " << op[kJName] << ", size: " << ret; + } + } + } + + for (const auto &output_node : output_nodes) { + auto kernel_idx = AnfAlgo::VisitKernel(output_node, 0); + auto real_node = kernel_idx.first; + size_t real_idx = kernel_idx.second; + auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope()); + MS_LOG(INFO) << "Fusion info: real node name: " << normal_name << ", real output index: " << real_idx; + for (const auto &op : fusion_op_list) { + if (op[kJName] == normal_name) { + auto op_output_desces = op[kJOutputDesc]; + if (output_node != real_node) { + // tuple_get item + MS_LOG(INFO) << "Output is a tuple getitem node"; + auto output_desc = op_output_desces[real_idx]; + if (output_desc[kJShape].empty()) { + MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx; + return false; + } + auto ret = GetIOSizeImpl(output_desc); + output_size_list->push_back(ret); + MS_LOG(INFO) << "Fusion info: scope output index: " << real_idx << ", size: " << ret; + } else { + for (const auto &output_desc : op_output_desces) { + if (output_desc[kJShape].empty()) { + MS_LOG(INFO) << "Fusion info: output_desc's shape is empty, may be this node output"; + continue; + } + auto ret = GetIOSizeImpl(output_desc); + output_size_list->push_back(ret); + MS_LOG(INFO) << "Fusion info: scope output size: " << ret; + } + } + } + } + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h new file mode 100644 index 0000000000..768f811055 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h @@ -0,0 +1,122 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "ir/dtype.h" +#include "backend/kernel_compiler/kernel.h" +#include "pybind11/stl.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/tbe/tbe_adapter.h" + +namespace mindspore { +namespace kernel { +// kernel operate type used for generate json + +class TbeKernelBuild { + enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2 }; + + public: + static bool GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, + std::vector *output_size_list); + // Ub Fuison + static bool GenFusionScopeJson(const std::vector &input_nodes, + const std::vector &compute_nodes, nlohmann::json *fusion_str, + std::string *fusion_kernel); + static bool GetIOSize(const nlohmann::json &fusion_op_list, const std::vector &output_nodes, + std::vector *input_size_list, std::vector *output_size_list); + + private: + TbeKernelBuild() = default; + ~TbeKernelBuild() = default; + static bool GenFusionDataInputJson(const std::shared_ptr &data_input, + const std::map &spec_data_input, + nlohmann::json *data_str, size_t *index); + static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, + std::vector>::iterator *layer_iter, + nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index); + static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, + std::vector>::iterator *layer_iter, + std::vector *input_desc_list, size_t *index); + static std::vector GetDescOutputIndex(const std::vector &output_used_nums); + static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, + std::vector *output_desc_list); + static void GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc, + FusionDataType fusion_data_type = kFusionNormal); + static void GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, + size_t output_index, nlohmann::json *output_desc); + static size_t GetIOSizeImpl(const nlohmann::json &desc); + static bool GetSpecInputLayers(const std::string &op_name, const std::vector &reorder_layer, + std::map *spec_data_input); + static bool GetInputLayers(const std::vector &input_nodes, + const std::vector &compute_nodes, + std::vector> *input_layers, + std::map *spec_data_input); + static bool IsDynamicInput(const CNodePtr &cnode); + static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input); + static std::string GetRealOpType(const std::string &origin_type); +}; + +class TbeKernelJsonCreator { + public: + explicit TbeKernelJsonCreator(kCreaterType creater_type = SINGLE_BUILD) : creater_type_(creater_type) {} + ~TbeKernelJsonCreator() = default; + bool GenTbeSingleKernelJson(const std::shared_ptr &anf_node, nlohmann::json *kernel_json); + std::string json_name() { return json_name_; } + + private: + bool GenTbeInputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, + nlohmann::json *inputs_json); + bool GenTbeOutputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, + nlohmann::json *outputs_json); + bool GenTbeAttrJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, + nlohmann::json *attrs_json); + static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj); + bool GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, bool value, + const std::shared_ptr &input_ptr, const string &op_input_name, size_t input_i, + std::vector *input_list); + bool GenOutputDescJson(const std::shared_ptr &anf_node, + const std::vector> &outputs_ptr, nlohmann::json *outputs_json); + bool GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num, + const std::shared_ptr &input_ptr, size_t *real_input_index, string *op_input_name, + std::vector *input_list); + void GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, + const std::shared_ptr &output_ptr, size_t *output_idx, + std::vector *output_list); + std::vector GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const; + std::vector GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const; + + kCreaterType creater_type_; + std::string json_name_; + std::string json_info_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc new file mode 100644 index 0000000000..e6cb4cf30d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" +#include +#include "runtime/rt.h" +#include "utils/context/ms_context.h" +#include "graphengine/inc/framework/ge_runtime/task_info.h" + +namespace mindspore { +namespace kernel { +using TbeTaskInfoPtr = std::shared_ptr; +using tbe::KernelManager; +bool TbeKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (stream_ptr == nullptr) { + MS_LOG(ERROR) << "stream_ptr should not be nullptr."; + return false; + } + + if (kernel_pack_ == nullptr) { + MS_LOG(ERROR) << "kernel pack should not be nullptr."; + return false; + } + + uint32_t blockdim = 1; // default blockdim equal to 1. + auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &blockdim); + if (func_stub == 0) { + MS_LOG(ERROR) << "GenFuncStub failed."; + return false; + } + + // pack all addresses into a vector. + std::vector runtimeargs; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs), + [](const AddressPtr &output) -> void * { return output->addr; }); + if (!workspace.empty()) { + (void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs), + [](const AddressPtr &addr) -> void * { return addr->addr; }); + } + rtL2Ctrl_t *l2ctrl = nullptr; + const void *stubFunc = reinterpret_cast(func_stub); + auto argsSize = static_cast(UlongToUint(sizeof(void *)) * runtimeargs.size()); + if (RT_ERROR_NONE != rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_ptr)) { + MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; + return false; + } + + return true; +} + +std::vector TbeKernelMod::GenTask(const std::vector &inputs, + const std::vector &workspaces, + const std::vector &outputs, uint32_t stream_id) { + if (kernel_pack_ == nullptr) { + MS_EXCEPTION(ArgumentError) << "kernel pack should not be nullptr."; + } + + std::vector args; + std::vector sm_desc; + std::vector meta_data; + std::vector input_data_addrs; + std::vector output_data_addrs; + std::vector workspace_addrs; + + // pack all addresses into a vector. + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), + [](const AddressPtr &output) -> void * { return output->addr; }); + if (!workspaces.empty()) { + (void)std::transform(std::begin(workspaces), std::end(workspaces), std::back_inserter(workspace_addrs), + [](const AddressPtr &workspace) -> void * { return workspace->addr; }); + } + + stream_id_ = stream_id; + auto funcstub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim_); + if (funcstub == 0) { + MS_EXCEPTION(ArgumentError) << "GenFuncStub failed."; + } + + std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); + + MS_LOG(INFO) << "block_dim is:" << block_dim_; + + TbeTaskInfoPtr task_info_ptr = make_shared( + kernel_name_, stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0, meta_data, input_data_addrs, + output_data_addrs, workspace_addrs, NeedDump()); + return {task_info_ptr}; +} + +vector TbeKernelMod::GenParameters() { + auto kernel_json_info = kernel_pack_->kernel_json_info(); + return kernel_json_info.parameters; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h new file mode 100644 index 0000000000..de48c83d9b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +namespace mindspore { +namespace kernel { +class TbeKernelMod : public AscendKernelMod { + public: + explicit TbeKernelMod(KernelPackPtr kernel_pack) : kernel_pack_(std::move(kernel_pack)) {} + ~TbeKernelMod() override = default; + + void SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } + void SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } + void SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs, uint32_t stream_id) override; + std::vector GenParameters() override; + + private: + KernelPackPtr kernel_pack_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using TbeKernelModPtr = std::shared_ptr; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc new file mode 100644 index 0000000000..48223f40c6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc @@ -0,0 +1,326 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" + +#include +#include +#include +#include +#include +#include + +#include "utils/context/ms_context.h" +#include "backend/kernel_compiler/tbe/tbe_adapter.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "./common.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +namespace mindspore { +namespace kernel { +using mindspore::kernel::tbe::TbeUtils; +constexpr auto kParallelCompileModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; +constexpr auto kCreateParallelCompiler = "create_tbe_parallel_compiler"; +constexpr auto kStartCompileOp = "start_compile_op"; +constexpr auto kWaitOne = "wait_one"; +constexpr auto kResetTaskInfo = "reset_task_info"; + +bool TbeOpParallelPreBuild(const std::vector &anf_nodes) { + auto build_manger = std::make_shared(); + MS_EXCEPTION_IF_NULL(build_manger); + for (const auto &anf_node : anf_nodes) { + // gen kernel json + MS_EXCEPTION_IF_NULL(anf_node); + nlohmann::json kernel_json; + TbeKernelJsonCreator creator(OP_PRE_COMPILE); + if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { + MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; + return false; + } + kernel_json["compile_type"] = "pre_build"; + // op build + auto task_id = build_manger->StartCompileOp(kernel_json); + build_manger->SavePreTaskInfo(task_id, anf_node); + } + while (!build_manger->IsAllPreTaskFinish()) { + int task_id = -1; + char *task_result = nullptr; + char *pre_build_result = nullptr; + auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); + if (!ret) { + MS_EXCEPTION(ArgumentError) << "Pre Build Failed. wait one ret:" << ret << ", task id:" << task_id; + } + + if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { + MS_EXCEPTION(ArgumentError) << "task pre compile Failed, task id:" << task_id << ", cause:" << task_result; + } + + build_manger->PreTaskFinishProcess(task_id, pre_build_result); + } + return true; +} + +bool TbeOpParallelBuild(const std::vector &anf_nodes) { + auto build_manger = std::make_shared(); + MS_EXCEPTION_IF_NULL(build_manger); + set processed_kernel; + for (const auto &anf_node : anf_nodes) { + // gen kernel json + tbe::TbeAdapter::SetTbeAttrsForTransDataOp(anf_node); + if (AnfAlgo::GetKernelMod(anf_node) != nullptr) { + continue; + } + const std::string &processor = tbe::GetProcessor(anf_node); + nlohmann::json kernel_json; + TbeKernelJsonCreator creator(SINGLE_BUILD); + if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { + MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; + return false; + } + // get size + std::vector input_size_list; + std::vector output_size_list; + (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); + // search cache + const std::string &json_name = creator.json_name(); + if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get())) { + MS_LOG(INFO) << "Use cached kernel, kernel json name:." << json_name; + continue; + } + // same op not need build, but need wait build finish to set kernel mode + if (processed_kernel.find(json_name) != processed_kernel.end()) { + build_manger->SaveSameOpInfo(anf_node, json_name, input_size_list, output_size_list); + continue; + } + (void)processed_kernel.insert(json_name); + // op build + auto task_id = build_manger->StartCompileOp(kernel_json); + build_manger->SaveTaskInfo(task_id, anf_node, json_name, input_size_list, output_size_list); + } + while (!build_manger->IsAllTaskFinish()) { + int task_id = -1; + char *task_result = nullptr; + char *pre_build_result = nullptr; + auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); + if (!ret) { + MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; + } + + if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { + MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; + } + (void)build_manger->TaskFinishProcess(task_id); + } + return build_manger->GenSameOpKernelMod(); +} + +ParallelBuildManager::ParallelBuildManager() { tbe_parallel_compiler_ = TbePythonFuncs::TbeParallelCompiler(); } + +ParallelBuildManager::~ParallelBuildManager() { ResetTaskInfo(); } + +int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) const { + PyObject *pRes = nullptr; + PyObject *pArgs = PyTuple_New(1); + std::string json_str = kernel_json.dump(); + PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); + (void)PyTuple_SetItem(pArgs, 0, arg1); + pRes = PyObject_CallMethod(tbe_parallel_compiler_, kStartCompileOp, "O", pArgs); + if (pRes == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function start_compile_op"; + } + int task_id; + (void)PyArg_Parse(pRes, "i", &task_id); + MS_LOG(INFO) << "start compile , task id:" << task_id; + return task_id; +} + +bool ParallelBuildManager::WaitOne(int *task_id, char **task_result, char **pre_build_result) const { + MS_LOG(INFO) << "wait task start."; + MS_EXCEPTION_IF_NULL(task_id); + MS_EXCEPTION_IF_NULL(task_result); + PyObject *pRes = nullptr; + PyObject *pArg = Py_BuildValue("()"); + pRes = PyObject_CallMethod(tbe_parallel_compiler_, kWaitOne, "O", pArg); + if (pRes == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function wait_one"; + return false; + } + (void)PyArg_ParseTuple(pRes, "iss", task_id, task_result, pre_build_result); + return true; +} + +void ParallelBuildManager::SavePreTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node) { + MS_LOG(INFO) << "SavePreTaskInfo, task id: " << task_id; + pre_task_map_[task_id] = anf_node; +} + +void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node, + const std::string &json_name, const std::vector &input_size_list, + const std::vector &output_size_list, int32_t scope_id) { + MS_LOG(INFO) << "SaveTaskInfo, task id: " << task_id; + struct KernelBuildTaskInfo task_info; + task_info.node = anf_node.get(); + task_info.json_name = json_name; + if (anf_node == nullptr) { + task_info.processor = tbe::kProcessorAiCore; + } else { + task_info.processor = tbe::GetProcessor(anf_node); + } + task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end()); + task_info.output_size_list.assign(output_size_list.begin(), output_size_list.end()); + task_info.scope_id = scope_id; + task_map_[task_id] = task_info; +} + +bool ParallelBuildManager::IsAllPreTaskFinish() const { + MS_LOG(INFO) << "wait pre build process task_num: " << pre_task_map_.size(); + return pre_task_map_.empty(); +} + +bool ParallelBuildManager::IsAllTaskFinish() const { + MS_LOG(INFO) << "wait process task_num: " << task_map_.size(); + return task_map_.empty(); +} + +void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result) { + auto task_iter = pre_task_map_.find(task_id); + if (task_iter == pre_task_map_.end()) { + MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id; + } + auto node = task_iter->second; + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); + std::string start_flag = "fusion_pattern_start"; + std::string end_flag = "fusion_pattern_end"; + int start = pre_build_result.find(start_flag); + int end = pre_build_result.find(end_flag); + if (start != -1 && end != -1 && end >= start) { + std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size()); + if (result == "") { + (void)pre_task_map_.erase(task_iter); + return; + } + transform(result.begin(), result.end(), result.begin(), ::toupper); + FusionType fusion_type = tbe::GetFusionType(result); + builder->SetFusionType(fusion_type); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + } + (void)pre_task_map_.erase(task_iter); +} + +std::pair ParallelBuildManager::TaskFinishProcess(int32_t task_id, bool set_kernel_mod) { + auto task_iter = task_map_.find(task_id); + if (task_iter == task_map_.end()) { + MS_EXCEPTION(ArgumentError) << "can find task_id:" << task_id; + } + auto json_name = task_iter->second.json_name; + auto processor = task_iter->second.processor; + auto kernel_pack = TbeUtils::InsertCache(json_name, processor); + if (kernel_pack == nullptr) { + if (set_kernel_mod) { + MS_EXCEPTION(ArgumentError) << "build kernel name:" << task_iter->second.json_name << " failed."; + } else { + MS_LOG(INFO) << "fusion build kernel name:" << task_iter->second.json_name << "failed."; + auto ret = std::make_pair(task_iter->second.scope_id, nullptr); + (void)task_map_.erase(task_iter); + return ret; + } + } + auto kernel_mod = GenKernelMod(json_name, processor, task_iter->second.input_size_list, + task_iter->second.output_size_list, kernel_pack); + MS_EXCEPTION_IF_NULL(kernel_mod); + if (set_kernel_mod) { + AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node); + } + auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod); + (void)task_map_.erase(task_iter); + MS_LOG(INFO) << "wait process remain task_num:" << task_map_.size(); + return ret; +} + +void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node, const std::string &json_name, + const std::vector &input_size_list, + const std::vector &output_size_list) { + struct KernelBuildTaskInfo task_info; + task_info.node = anf_node.get(); + task_info.json_name = json_name; + task_info.processor = tbe::GetProcessor(anf_node); + task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end()); + task_info.output_size_list.assign(output_size_list.begin(), output_size_list.end()); + same_op_list_.push_back(task_info); +} + +bool ParallelBuildManager::GenSameOpKernelMod() const { + for (const auto &task_info : same_op_list_) { + bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list, + task_info.output_size_list, task_info.node); + if (!ret) { + MS_LOG(INFO) << "can't find " << task_info.json_name << " in cache."; + return false; + } + } + return true; +} + +bool ParallelBuildManager::SearchInCache(const std::string &json_name, const std::string &processor, + const std::vector &input_size_list, + const std::vector &output_size_list, mindspore::AnfNode *node) const { + auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor); + if (cached_kernel_pack != nullptr) { + MS_LOG(INFO) << "Find cached kernel, kernel json name" << json_name; + auto kernel_mod_ptr = GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + AnfAlgo::SetKernelMod(kernel_mod_ptr, node); + return true; + } else { + return false; + } +} + +KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const string &processor, + const vector &input_size_list, + const vector &output_size_list, + const mindspore::kernel::KernelPackPtr &kernel_pack) const { + MS_EXCEPTION_IF_NULL(kernel_pack); + auto kernel_json_info = kernel_pack->kernel_json_info(); + auto kernel_mod_ptr = std::make_shared(kernel_pack); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + kernel_mod_ptr->SetInputSizeList(input_size_list); + kernel_mod_ptr->SetOutputSizeList(output_size_list); + kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces); + return kernel_mod_ptr; +} + +void ParallelBuildManager::ResetTaskInfo() { + if (task_map_.empty()) { + MS_LOG(INFO) << "All tasks are compiled success."; + return; + } + task_map_.clear(); + same_op_list_.clear(); + if (tbe_parallel_compiler_ != nullptr) { + PyObject *pArg = Py_BuildValue("()"); + (void)PyObject_CallMethod(tbe_parallel_compiler_, kResetTaskInfo, "O", pArg); + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h new file mode 100644 index 0000000000..a29469b47c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "pybind11/stl.h" +#include +namespace mindspore { +namespace kernel { +bool TbeOpParallelPreBuild(const std::vector &anf_nodes); +bool TbeOpParallelBuild(const std::vector &anf_nodes); + +struct KernelBuildTaskInfo { + AnfNode *node; + std::string processor; + std::string json_name; + std::vector input_size_list; + std::vector output_size_list; + int32_t scope_id; +}; + +class ParallelBuildManager { + public: + ParallelBuildManager(); + ~ParallelBuildManager(); + int32_t StartCompileOp(const nlohmann::json &kernel_json) const; + void SavePreTaskInfo(int32_t task_id, const AnfNodePtr &anf_node); + void SaveTaskInfo(int32_t task_id, const AnfNodePtr &anf_node, const std::string &json_name, + const std::vector &input_size_list, const std::vector &output_size_list, + int32_t scope_id = 0); + void SaveSameOpInfo(const AnfNodePtr &anf_node, const std::string &json_name, + const std::vector &input_size_list, const std::vector &output_size_list); + bool GenSameOpKernelMod() const; + bool SearchInCache(const std::string &json_name, const std::string &processor, + const std::vector &input_size_list, const std::vector &output_size_list, + AnfNode *node) const; + + bool WaitOne(int *task_id, char **task_result, char **pre_build_result) const; + bool IsAllPreTaskFinish() const; + bool IsAllTaskFinish() const; + void PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result); + std::pair TaskFinishProcess(int32_t task_id, bool set_kernel_mod = true); + KernelModPtr GenKernelMod(const string &json_name, const string &processor, + const std::vector &input_size_list, const std::vector &output_size_list, + const KernelPackPtr &kernel_pack) const; + void ResetTaskInfo(); + + private: + PyObject *tbe_parallel_compiler_; + std::map pre_task_map_; + std::map task_map_; + std::vector same_op_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h similarity index 100% rename from mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h rename to mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc new file mode 100644 index 0000000000..c5e882949b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc @@ -0,0 +1,318 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kInputIndex_0 = 0; +constexpr size_t kChannelN = 0; +constexpr size_t kChannelC = 1; +constexpr size_t kAlignmented16 = 16; +// 1. all shape no scalar and same +// 2. part scalar : no_scalar (shape size > xxx && alig xxx) +// 3. all no_scalar and not same (broad cast xxx dim) +bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) { + MS_EXCEPTION_IF_NULL(support_format); + input_num_ = 0; + output_num_ = 0; + input_shapes_.clear(); + output_shapes_.clear(); + if (AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode_ptr_)) { + MS_LOG(INFO) << "This broadcast node has dynamic input."; + auto dynamic_size_vec = AnfAlgo::GetNodeAttr>(cnode_ptr_, kAttrDynInputSizes); + if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) { + MS_LOG(EXCEPTION) << "dynamic attr set error, please check."; + } + auto dynamic_input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); + PadScalarShape(&dynamic_input_shape0_); + input_shapes_.emplace_back(dynamic_input_shape0_); + input_num_ = 1; + } else { + input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_); + for (size_t i = 0; i < input_num_; ++i) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); + PadScalarShape(&input_shape); + input_shapes_.emplace_back(input_shape); + } + } + + output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_); + for (size_t i = 0; i < output_num_; ++i) { + auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i); + PadScalarShape(&output); + output_shapes_.emplace_back(output); + } + AssignSupportFormat(kOpFormat_DEFAULT, support_format); + return true; +} + +bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (IsSameShape()) { + if (!HasScalarInput()) { + AssignSupportFormat(kOpFormat_NC1HWC0, support_format); + return true; + } else { + return false; + } + } + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + if (HasScalarInput()) { + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + input_support_format.emplace_back(kOpFormat_DEFAULT); + } else { + if (!Is4DShape(shape)) { + return false; + } + if (shape[kChannelC] % kAlignmented16 != 0) { + return false; + } + input_support_format.emplace_back(kOpFormat_NC1HWC0); + } + } + } else { + for (const auto &shape : input_shapes_) { + if (!Is4DShape(shape)) { + return false; + } + } + auto shape_tmp = input_shapes_[0]; + auto broadcast_c_axis = std::any_of( + input_shapes_.begin(), input_shapes_.end(), + [&shape_tmp](const std::vector &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); }); + if (broadcast_c_axis) { + MS_LOG(INFO) << "This node broadcast c channel."; + return false; + } + input_support_format.assign(input_num_, kOpFormat_NC1HWC0); + } + GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); + return true; +} + +bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (IsSameShape()) { + if (!HasScalarInput()) { + AssignSupportFormat(kOpFormat_FRAC_Z, support_format); + return true; + } else { + return false; + } + } + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + if (HasScalarInput()) { + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + input_support_format.emplace_back(kOpFormat_DEFAULT); + } else { + if (!Is4DShape(shape)) { + return false; + } + if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) { + return false; + } + input_support_format.emplace_back(kOpFormat_FRAC_Z); + } + } + } else { + return false; + } + GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); + return true; +} +bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (IsSameShape()) { + if (!HasScalarInput()) { + AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format); + return true; + } else { + return false; + } + } + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + if (HasScalarInput()) { + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + input_support_format.emplace_back(kOpFormat_DEFAULT); + } else { + if (!Is4DShape(shape)) { + return false; + } + if (shape[kChannelN] % kAlignmented16 != 0) { + return false; + } + input_support_format.emplace_back(kOpFormat_C1HWNCoC0); + } + } + } else { + for (const auto &shape : input_shapes_) { + if (!Is4DShape(shape)) { + return false; + } + } + auto shape_tmp = input_shapes_[0]; + auto broadcast_nc_axis = + std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector &elem) { + return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN)); + }); + if (broadcast_nc_axis) { + MS_LOG(INFO) << "This node broadcast n || c channel."; + return false; + } + input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0); + } + GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); + return true; +} + +bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (IsSameShape()) { + if (!HasScalarInput()) { + AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); + return true; + } else { + return false; + } + } + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + if (HasScalarInput()) { + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + input_support_format.emplace_back(kOpFormat_DEFAULT); + } else { + if (shape.size() < kShape2dDims) { + return false; + } + if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) { + return false; + } + input_support_format.emplace_back(kOpFormat_FRAC_NZ); + } + } + } else { + auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(), + [](const std::vector &elem) { return elem.size() < kShape2dDims; }); + if (less_2dims) { + MS_LOG(INFO) << "This node dim less 2."; + return false; + } + + auto shape_tmp = input_shapes_[0]; + auto broadcast_last_dim = + std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector &elem) { + return (shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) || + (shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2)); + }); + if (broadcast_last_dim) { + MS_LOG(INFO) << "This node broadcast last channel."; + return false; + } + + input_support_format.assign(input_num_, kOpFormat_FRAC_NZ); + } + GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); + return true; +} + +bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + return false; +} + +bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector &shape) const { + return shape.size() == kShape4dDims; +} + +bool TbeKernelBroadCastSelecter::IsSameShape() const { + auto shape = input_shapes_.begin(); + for (const auto &item : input_shapes_) { + if (shape->size() != item.size()) { + return false; + } + for (size_t i = 0; i < shape->size(); ++i) { + if (shape->at(i) != item.at(i)) { + return false; + } + } + } + return true; +} + +void TbeKernelBroadCastSelecter::PadScalarShape(std::vector *shape) const { + MS_EXCEPTION_IF_NULL(shape); + if (shape->empty()) { + shape->emplace_back(1); + } +} + +bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector &shape) const { + return (shape.size() == 1 && shape[0] == 1); +} + +bool TbeKernelBroadCastSelecter::HasScalarInput() const { + bool ret = false; + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + ret = true; + break; + } + } + return ret; +} + +void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format, + SupportFormatItem *output_support_item) const { + MS_EXCEPTION_IF_NULL(output_support_item); + for (const auto &shape : output_shapes_) { + if (IsScalarShape(shape)) { + output_support_item->emplace_back(kOpFormat_DEFAULT); + } else { + output_support_item->emplace_back(support_format); + } + } +} + +void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str, + mindspore::kernel::SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + input_support_format.assign(input_num_, support_format_str); + output_support_format.assign(output_num_, support_format_str); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h new file mode 100644 index 0000000000..4685df6724 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +class TbeKernelBroadCastSelecter { + public: + explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} + ~TbeKernelBroadCastSelecter() = default; + bool GetShapeInfo(SupportFormat *support_format); + bool IsBroadCastSupport5HD(SupportFormat *support_format) const; + bool IsBroadCastSupportFracZ(SupportFormat *support_format) const; + bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const; + bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const; + bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const; + + private: + bool IsSameShape() const; + void PadScalarShape(std::vector *shape) const; + bool Is4DShape(const std::vector &shape) const; + bool IsScalarShape(const std::vector &shape) const; + bool HasScalarInput() const; + void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const; + void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; + // broadcast + CNodePtr cnode_ptr_; + size_t input_num_{}; + size_t output_num_{}; + std::vector> input_shapes_; + std::vector> output_shapes_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_TBE_KERNEL_BROADCAST_SELECTER_HELPER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc new file mode 100644 index 0000000000..61aa9dfb91 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc @@ -0,0 +1,152 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" +#include +#include +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kInputIndex_0 = 0; +constexpr size_t kOutputIndex_0 = 0; +constexpr size_t kChannelN = 0; +constexpr size_t kChannelC = 1; +constexpr size_t kReduceNZMinDim = 3; + +bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) { + MS_EXCEPTION_IF_NULL(support_format); + input_shape_.clear(); + output_shape_.clear(); + axis_.clear(); + auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); + auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); + if (input_num != 1 || output_num != 1) { + MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num + << ", output num: " << output_num; + } + // get input/output shape + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); + PadScalarShape(&input_shape_); + output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0); + PadScalarShape(&output_shape_); + // get keep dim attr + GetReduceAttrKeepDim(); + // get axis attr + axis_ = GetReduceAttrAxis(cnode_ptr_); + AssignSupportFormat(kOpFormat_DEFAULT, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (!Is4DShape(input_shape_)) { + return false; + } + if (!keep_dims_ || axis_.empty()) { + return false; + } + auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); + if (reduce_c_axis) { + return false; + } + AssignSupportFormat(kOpFormat_NC1HWC0, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + // like to 5HD + return false; +} + +bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { + return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format); +} + +bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const { + return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format); +} + +bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (input_shape_.size() < kReduceNZMinDim) { + return false; + } + if (axis_.empty()) { + return false; + } + auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) { + return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2)); + }); + if (reduce_last_axis) { + return false; + } + AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format, + mindspore::kernel::SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (!Is4DShape(input_shape_)) { + return false; + } + if (!keep_dims_ || axis_.empty()) { + return false; + } + auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(), + [](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); }); + if (reduce_n_c_axis) { + return false; + } + AssignSupportFormat(format, support_format); + return true; +} + +void TbeKernelReduceSelecter::GetReduceAttrKeepDim() { + if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode_ptr_)) { + MS_LOG(INFO) << "This node does't have keep_attr."; + keep_dims_ = false; + return; + } + keep_dims_ = AnfAlgo::GetNodeAttr(cnode_ptr_, kAttrKeepDims); +} + +void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str, + mindspore::kernel::SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + input_support_format.emplace_back(support_format_str); + output_support_format.emplace_back(support_format_str); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); +} + +bool TbeKernelReduceSelecter::Is4DShape(const std::vector &shape) const { return shape.size() == kShape4dDims; } + +void TbeKernelReduceSelecter::PadScalarShape(std::vector *shape) const { + MS_EXCEPTION_IF_NULL(shape); + if (shape->empty()) { + shape->emplace_back(1); + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h new file mode 100644 index 0000000000..196bb7b06a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ +#include +#include +#include +#include "ir/anf.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" +namespace mindspore { +namespace kernel { +class TbeKernelReduceSelecter { + public: + explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} + ~TbeKernelReduceSelecter() = default; + bool GetShapeInfo(SupportFormat *support_format); + bool IsReduceSupport5HD(SupportFormat *support_format) const; + bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const; + bool IsReduceSupportFracZ(SupportFormat *support_format) const; + bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const; + bool IsReduceSupportFracNZ(SupportFormat *support_format) const; + + private: + bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const; + void GetReduceAttrKeepDim(); + void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; + bool Is4DShape(const std::vector &shape) const; + void PadScalarShape(std::vector *shape) const; + CNodePtr cnode_ptr_; + std::vector input_shape_{}; + std::vector output_shape_{}; + std::vector axis_{}; + bool keep_dims_ = false; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_TBE_KERNEL_REDUCE_SELECTER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc new file mode 100644 index 0000000000..21f2347629 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -0,0 +1,622 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "nlohmann/json.hpp" +#include "utils/context/ms_context.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +constexpr auto kName = "name"; +constexpr auto kDtype = "dtype"; +constexpr auto kFormat = "format"; +constexpr auto kPrefixInput = "input"; +constexpr auto kPrefixOutput = "output"; +constexpr char kParamTypeDynamic[] = "dynamic"; +constexpr char kParamTypeRequre[] = "required"; +constexpr char kParamTypeOptional[] = "optional"; +void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list); + tbe_selecter.TbeMetadataInfoEx(); +} + +TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector> *kernel_info_list) + : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {} + +void TbeKernelSelect::TbeMetadataInfoEx() { + MS_EXCEPTION_IF_NULL(cnode_ptr_); + MS_EXCEPTION_IF_NULL(kernel_info_list_); + node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_); + auto op_info_ptr = OpLib::FindOp(node_name_, kTBE); + if (!op_info_ptr) { + MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_; + return; + } + MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_ + << ", node name: " << cnode_ptr_->fullname_with_scope(); + OpPattern pattern = op_info_ptr->op_pattern(); + if (pattern == kCommonPattern) { + GetCommonPatternKernelInfo(*op_info_ptr); + } else if (pattern == kDynamicFormatPattern) { + GetDynamicFormatPatternKernelInfo(*op_info_ptr); + } else if (pattern == kFormatAgnosticPattern) { + GetAgnosticPatternKernelInfo(*op_info_ptr); + } else if (pattern == kBroadcastPattern) { + GetBroadcastPatternKernelInfo(*op_info_ptr); + } else if (pattern == kReducePattern) { + GetReducePatternKernelInfo(*op_info_ptr); + } else { + MS_LOG(INFO) << "Warning: op pattern is invailed."; + } + // check support + FilterInVaildKernelInfo(); + MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; +} + +void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { + MS_LOG(INFO) << "start."; + // get dynamic inputs + auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); + MS_EXCEPTION_IF_NULL(primitive); + std::vector dyn_input_sizes; + if (primitive->HasAttr(kAttrDynInputSizes)) { + dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); + } + // get real input/output num + size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); + const auto inputs_info = op_info.inputs_ptr(); + size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); + const auto outputs_info = op_info.outputs_ptr(); + if (inputs_info.empty() && outputs_info.empty()) { + MS_LOG(EXCEPTION) << "op info input & output is null, please check."; + } + // create kernel build info from opinfo + size_t kernel_build_info_num = + inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size(); + for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) { + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + SetTbeBuildCommonInfo(op_info, &builder); + std::vector inputs_format; + std::vector inputs_device_type; + std::vector> inputs_reshape_type; + // input + if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes, + &inputs_format, &inputs_device_type, &inputs_reshape_type)) { + break; + } + builder.SetInputsDeviceType(inputs_device_type); + builder.SetInputsFormat(inputs_format); + builder.SetInputReshapeType(inputs_reshape_type); + // output + std::vector outputs_format; + std::vector outputs_device_type; + std::vector> outputs_reshape_type; + if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes, + &outputs_format, &outputs_device_type, &outputs_reshape_type)) { + break; + } + builder.SetOutputsDeviceType(outputs_device_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputReshapeType(outputs_reshape_type); + kernel_info_list_->emplace_back(builder.Build()); + } + MS_LOG(INFO) << "end."; +} + +void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) { + MS_LOG(INFO) << "start."; + // + OpInfo op_info_new; + CreateNewOpInfo(op_info, &op_info_new); + GetCommonPatternKernelInfo(op_info_new); + MS_LOG(INFO) << "end."; +} + +void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) { + MS_LOG(INFO) << "start."; + if (op_info.inputs_ptr().size() != 1) { + MS_LOG(EXCEPTION) << "AgnosticPattern only support one input."; + } + auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0); + if (kOpFormatList.find(format) == kOpFormatList.end()) { + MS_LOG(INFO) << "Got the unknown format " << format; + format = kOpFormat_DEFAULT; + } + SupportFormat support_format; + SupportFormatItem input_item; + SupportFormatItem output_item; + input_item.assign(op_info.inputs_ptr().size(), format); + output_item.assign(op_info.outputs_ptr().size(), format); + support_format.input_format.emplace_back(input_item); + support_format.output_format.emplace_back(output_item); + PrintSupportedFormat(support_format); + OpInfo op_info_new; + CreateNewOpInfo(op_info, support_format, &op_info_new); + GetCommonPatternKernelInfo(op_info_new); + MS_LOG(INFO) << "end."; +} + +void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { + MS_LOG(INFO) << "start."; + auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_); + SupportFormat support_format; + broadcast_selecter.GetShapeInfo(&support_format); + if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) { + MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD."; + } + if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) { + MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ."; + } + if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) { + MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0."; + } + if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { + MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; + } + PrintSupportedFormat(support_format); + OpInfo op_info_new; + CreateNewOpInfo(op_info, support_format, &op_info_new); + GetCommonPatternKernelInfo(op_info_new); + MS_LOG(INFO) << "end."; +} + +void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) { + MS_LOG(INFO) << "start."; + auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_); + SupportFormat support_format; + reduce_selecter.GetShapeInfo(&support_format); + if (!reduce_selecter.IsReduceSupport5HD(&support_format)) { + MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support 5HD."; + } + if (reduce_selecter.IsReduceSupportFracZ(&support_format)) { + MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracZ."; + } + if (reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) { + MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support C1HWNCoC0."; + } + if (reduce_selecter.IsReduceSupportFracNZ(&support_format)) { + MS_LOG(INFO) << "Node (" << node_name_ << ") reduce not support FracNZ."; + } + PrintSupportedFormat(support_format); + OpInfo op_info_new; + CreateNewOpInfo(op_info, support_format, &op_info_new); + GetCommonPatternKernelInfo(op_info_new); + MS_LOG(INFO) << "end."; +} + +void TbeKernelSelect::FilterInVaildKernelInfo() { + if (kernel_info_list_->empty()) { + MS_LOG(INFO) << "Warning: get kernel build info failed."; + return; + } + auto kernel_build_info_iter = kernel_info_list_->begin(); + while (kernel_build_info_iter != kernel_info_list_->end()) { + if (!FilterInVaildShape(kernel_build_info_iter)) { + MS_LOG(INFO) << "Filter invaild shape, filter item info: " << (*kernel_build_info_iter)->ToString(); + kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); + continue; + } + if (!TbeCheckSupported(kernel_build_info_iter)) { + MS_LOG(INFO) << "Check support shape, filter item info: " << (*kernel_build_info_iter)->ToString(); + kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter); + continue; + } + kernel_build_info_iter++; + } +} + +bool TbeKernelSelect::FilterInVaildShape( + const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { + MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); + auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats(); + for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); + auto format = kernel_build_info_inputs_format.at(i); + if (!IsShapeMatchFormat(shape, format)) { + MS_LOG(INFO) << "The " << i << "th input check failed."; + return false; + } + } + auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats(); + for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) { + auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j); + auto format = kernel_build_info_outputs_format.at(j); + if (!IsShapeMatchFormat(shape, format)) { + MS_LOG(INFO) << "The " << j << "th input check failed."; + return false; + } + } + return true; +} + +bool TbeKernelSelect::IsShapeMatchFormat(const std::vector &shape, const std::string &format) { + if (format == kOpFormat_DEFAULT) { + return true; + } + static std::set kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; + // if format is default, it remarkes support all format + if (kOpFormatList.find(format) == kOpFormatList.end()) { + MS_LOG(EXCEPTION) << "Got the unknown format " << format; + } + // server not support format with C04 suffix + if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) != + kServerNotSupportFormat.end()) { + MS_LOG(INFO) << "Warning: Server not support format with C04 suffix."; + return false; + } + // not support format: + // 1 NDHWC with shape size != 5 + // 2 FRAC_NZ with shape size < 2 + // 3 !NDHWC with shape size > 4 + if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || + (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) || + (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { + MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); + return false; + } + return true; +} + +bool TbeKernelSelect::TbeCheckSupported( + const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { + MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); + static const std::set kCheckSupportedOpType = {parallel::MATMUL, + parallel::BATCHMATMUL, + parallel::TOPK, + parallel::IN_TOPK, + parallel::PACK, + parallel::UNSORTEF_SEGMENT_MIND, + parallel::UNSORTEF_SEGMENT_PRODD, + parallel::CAST}; + auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); + if (iter == kCheckSupportedOpType.end()) { + return true; + } + MS_LOG(INFO) << "Check support start."; + // replace kernel_info with current kernel info + auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); + AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); + nlohmann::json kernel_json; + TbeKernelJsonCreator creator(CHECK_SUPPORTED); + bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); + if (!ret) { + MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed."; + } + ret = TbePythonFuncs::CheckSupported(kernel_json); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); + return ret; +} + +void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info, + mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) { + MS_EXCEPTION_IF_NULL(builder); + builder->SetProcessor(AICORE); + std::string fusion_type = op_info.fusion_type(); + if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { + builder->SetFusionType(tbe::GetFusionType(fusion_type)); + } + builder->SetOpPattern(op_info.op_pattern()); + builder->SetKernelType(TBE_KERNEL); +} + +bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, + const std::vector> &ios_info, + const std::vector &dyn_input_sizes, std::vector *formats, + std::vector *device_types, std::vector> *reshape_types) { + MS_EXCEPTION_IF_NULL(formats); + MS_EXCEPTION_IF_NULL(device_types); + MS_EXCEPTION_IF_NULL(reshape_types); + size_t dynamic_input_index = 0; + size_t real_io_tensor_index = 0; + size_t io_info_index = 0; + size_t io_info_num = ios_info.size(); + for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) { + std::shared_ptr io_info_item = ios_info[io_info_index]; + auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index); + std::string kernel_build_info_format; + if (!io_info_item->formats().empty()) { + kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index); + } + std::string io_param_type = io_info_item->param_type(); + std::vector reshape_type; + StringToAxisVector(io_info_item->reshape_type(), &reshape_type); + if (io_param_type == kParamTypeDynamic) { + // dynamic io + if (is_input) { + if (dynamic_input_index >= dyn_input_sizes.size()) { + MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index + << ", dyn_input_sizes size: " << dyn_input_sizes.size(); + } + int dynamic_input_size = dyn_input_sizes[dynamic_input_index]; + for (int i = 0; i < dynamic_input_size; ++i) { + device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); + formats->emplace_back(kernel_build_info_format); + reshape_types->emplace_back(reshape_type); + } + dynamic_input_index++; + real_io_tensor_index += dynamic_input_size; + } else { + if (ios_info.size() != 1) { + MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output."; + } + for (size_t i = 0; i < real_io_tensor_num; ++i) { + device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); + formats->emplace_back(kernel_build_info_format); + reshape_types->emplace_back(reshape_type); + } + real_io_tensor_index += real_io_tensor_num; + } + } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) { + // requre or optional io + device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); + formats->emplace_back(kernel_build_info_format); + reshape_types->emplace_back(reshape_type); + real_io_tensor_index++; + } else { + MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type; + } + } + + if (io_info_index != io_info_num) { + MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num + << "), this node may has optional input/output."; + } + if (real_io_tensor_index != real_io_tensor_num) { + std::string io_type = is_input ? "inputs " : "outputs"; + MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num + << ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index + << ") != real_io_tensor_num(" << real_io_tensor_num << ")"; + return false; + } + return true; +} + +void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec) { + MS_EXCEPTION_IF_NULL(reshape_type_vec); + for (const auto &c : reshape_type_str) { + switch (c) { + case 'N': + reshape_type_vec->push_back(kernel::N); + break; + case 'C': + reshape_type_vec->push_back(kernel::C); + break; + case 'H': + reshape_type_vec->push_back(kernel::H); + break; + case 'W': + reshape_type_vec->push_back(kernel::W); + break; + default: + MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; + } + } +} + +void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, + const std::vector> &support_format_item, size_t index, + mindspore::kernel::OpIOInfo *op_io_info_new) { + MS_EXCEPTION_IF_NULL(op_io_info_new); + op_io_info_new->set_index(op_io_info.index()); + op_io_info_new->set_name(op_io_info.name()); + op_io_info_new->set_param_type(op_io_info.param_type()); + op_io_info_new->set_need_compile(op_io_info.need_compile()); + op_io_info_new->set_reshape_type(op_io_info.reshape_type()); + op_io_info_new->set_shape(op_io_info.shape()); + // dtype + std::vector dtype_new; + auto dtype = op_io_info.dtypes(); + for (size_t i = 0; i < support_format_item.size(); ++i) { + dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end()); + } + op_io_info_new->set_dtypes(dtype_new); + // format + std::vector format_new; + for (const auto &formats : support_format_item) { + auto format = formats.at(index); + for (size_t j = 0; j < dtype.size(); ++j) { + format_new.emplace_back(format); + } + } + op_io_info_new->set_formats(format_new); +} + +std::vector TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) { + const std::map kDynamicFormatMap = { + {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}}; + if (op_select_json_item.empty()) { + MS_LOG(EXCEPTION) << "Op select ret item is null."; + } + const char space = ' '; + const char sep = ','; + std::string op_select_tmp = op_select_json_item + ","; + std::vector ret; + auto begin = op_select_tmp.find_first_not_of(space, 0); + auto sep_pos = op_select_tmp.find(sep); + if (begin >= sep_pos) { + MS_LOG(EXCEPTION) << "Select ret json is error."; + } + while (sep_pos != std::string::npos) { + auto obj = op_select_tmp.substr(begin, sep_pos - begin); + if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) { + obj = kDynamicFormatMap.at(obj); + } + ret.emplace_back(obj); + begin = op_select_tmp.find_first_not_of(space, sep_pos + 1); + sep_pos = op_select_tmp.find(sep, begin); + } + return ret; +} + +std::string TbeKernelSelect::OpSelectFormat() { + nlohmann::json kernel_json; + std::string res_json_str; + TbeKernelJsonCreator creator(OP_SELECT_FORMAT); + bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); + if (!ret) { + MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; + } + res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); + if (res_json_str.empty()) { + MS_LOG(EXCEPTION) << "op select format error."; + } + MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str; + return res_json_str; +} + +void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format, + mindspore::kernel::OpInfo *op_info_new) { + MS_EXCEPTION_IF_NULL(op_info_new); + if (op_info.inputs_ptr().size() != support_format.input_format[0].size() || + op_info.outputs_ptr().size() != support_format.output_format[0].size()) { + MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size() + << ", input support size: " << support_format.input_format[0].size() + << ", op info output size: " << op_info.outputs_ptr().size() + << ", output support size: " << support_format.output_format[0].size(); + } + *op_info_new = op_info; + op_info_new->ClearInputs(); + op_info_new->ClearOutputs(); + for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { + auto input = op_info.inputs_ptr().at(i); + auto input_new = std::make_shared(); + CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get()); + op_info_new->add_inputs_ptr(input_new); + } + for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) { + auto output = op_info.outputs_ptr().at(j); + auto output_new = std::make_shared(); + CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get()); + op_info_new->add_outputs_ptr(output_new); + } +} + +struct SelectOpIOInfo { + std::string name; + std::vector dtypes; + std::vector formats; +}; + +void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, + mindspore::kernel::OpInfo *op_info_new) { + MS_EXCEPTION_IF_NULL(op_info_new); + auto op_seclect_json = OpSelectFormat(); + if (!op_seclect_json.empty()) { + nlohmann::json json_obj = nlohmann::json::parse(op_seclect_json); + if (!json_obj.is_object()) { + MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_seclect_json; + } + std::vector inputs; + std::vector outputs; + for (const auto &item : json_obj.items()) { + const std::string &item_name = item.key(); + bool is_input = (item_name.find(kPrefixInput) != std::string::npos); + bool is_output = (item_name.find(kPrefixOutput) != std::string::npos); + if (!is_input && !is_output) { + MS_LOG(EXCEPTION) << "op select ret json is error."; + } + if (is_input) { + SelectOpIOInfo select_input; + select_input.name = item.value().at(kName); + std::string input_dtype_item = item.value().at(kDtype); + select_input.dtypes = SplitStrToVec(input_dtype_item); + std::string input_format_item = item.value().at(kFormat); + select_input.formats = SplitStrToVec(input_format_item); + inputs.emplace_back(select_input); + } else if (is_output) { + SelectOpIOInfo select_output; + select_output.name = item.value().at(kName); + std::string input_dtype_item = item.value().at(kDtype); + select_output.dtypes = SplitStrToVec(input_dtype_item); + std::string input_format_item = item.value().at(kFormat); + select_output.formats = SplitStrToVec(input_format_item); + outputs.emplace_back(select_output); + } + } + + if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) { + MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register."; + } + + *op_info_new = op_info; + op_info_new->ClearInputs(); + op_info_new->ClearOutputs(); + for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { + auto input_new = std::make_shared(); + CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get()); + op_info_new->add_inputs_ptr(input_new); + } + for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) { + auto output_new = std::make_shared(); + CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get()); + op_info_new->add_outputs_ptr(output_new); + } + } +} + +void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, + const std::vector &support_dtype, + const std::vector &support_format, + mindspore::kernel::OpIOInfo *op_io_info_new) { + MS_EXCEPTION_IF_NULL(op_io_info_new); + op_io_info_new->set_index(op_io_info.index()); + op_io_info_new->set_name(op_io_info.name()); + op_io_info_new->set_param_type(op_io_info.param_type()); + op_io_info_new->set_need_compile(op_io_info.need_compile()); + op_io_info_new->set_reshape_type(op_io_info.reshape_type()); + op_io_info_new->set_shape(op_io_info.shape()); + // dtype && format + op_io_info_new->set_dtypes(support_dtype); + op_io_info_new->set_formats(support_format); +} + +void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) { + if (support_format.input_format.size() != support_format.output_format.size()) { + MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output(" + << support_format.output_format.size() << ") size not match."; + } + for (size_t i = 0; i < support_format.input_format.size(); ++i) { + auto input_items = support_format.input_format.at(i); + auto output_items = support_format.output_format.at(i); + std::string print_str = "["; + for (const auto &input : input_items) { + print_str.append(input); + print_str.append(", "); + } + print_str.append("] -->"); + for (const auto &output : output_items) { + print_str.append(output); + print_str.append(", "); + } + MS_LOG(INFO) << "Support format: " << print_str; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h new file mode 100644 index 0000000000..679c56379f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_TBE_KERNEL_SELECT_H +#define MINDSPORE_TBE_KERNEL_SELECT_H + +#include +#include +#include +#include "backend/kernel_compiler/oplib/opinfo.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); + +class TbeKernelSelect { + using OpInfoPtr = std::shared_ptr; + using KernelBuildInfoIter = std::vector>::iterator; + + public: + TbeKernelSelect(CNodePtr kernel_node, std::vector> *kernel_info_list); + ~TbeKernelSelect() = default; + void TbeMetadataInfoEx(); + + private: + void GetCommonPatternKernelInfo(const OpInfo &op_info); + void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info); + void GetAgnosticPatternKernelInfo(const OpInfo &op_info); + void GetBroadcastPatternKernelInfo(const OpInfo &op_info); + void GetReducePatternKernelInfo(const OpInfo &op_info); + void FilterInVaildKernelInfo(); + bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); + static bool IsShapeMatchFormat(const std::vector &shape, const std::string &format); + bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); + static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); + bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, + const std::vector> &ios_info, const std::vector &dyn_input_sizes, + std::vector *formats, std::vector *device_types, + std::vector> *reshape_types); + static void StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec); + static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new); + static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, + const std::vector> &support_format_item, size_t index, + OpIOInfo *op_io_info_new); + // op select(dynamic) + void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new); + static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector &support_dtype, + const std::vector &support_format, OpIOInfo *op_io_info_new); + static std::vector SplitStrToVec(const std::string &op_select_json_item); + std::string OpSelectFormat(); + + static void PrintSupportedFormat(const SupportFormat &support_format); + + private: + CNodePtr cnode_ptr_; + std::vector> *kernel_info_list_; + std::string node_name_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_TBE_KERNEL_SELECT_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc new file mode 100644 index 0000000000..facb07991a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "common/utils.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +using mindspore::kernel::tbe::TbeUtils; +constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; +constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler"; +constexpr auto kOpSelectFormatFunc = "op_select_format"; +constexpr auto kCheckSupportedFunc = "check_supported"; +constexpr auto kTBEException = "TBEException"; + +PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr; +PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr; +PyObject *TbePythonFuncs::pOpSelectFormatFunc_ = nullptr; +PyObject *TbePythonFuncs::pCheckSupportedFunc_ = nullptr; +bool TbePythonFuncs::Init() { + static bool initialized = false; + if (initialized) { + return true; + } + // Initialize cache + TbeUtils::LoadCache(); + + // tbe_process + PyObject *pTbeProcessModule = nullptr; + pTbeProcessModule = PyImport_ImportModule(kTbeProcessModule); + if (pTbeProcessModule == nullptr) { + MS_LOG(ERROR) << "Failed to import [" << kTbeProcessModule << "] module."; + return false; + } + + pCreateTbeParallelCompilerFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCreateTbeParallelCompilerFunc); + if (pCreateTbeParallelCompilerFunc_ == nullptr) { + MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule + << "], FuncName:[" << kCreateTbeParallelCompilerFunc << "]."; + return false; + } + + pTbeCompiler_ = PyEval_CallObject(pCreateTbeParallelCompilerFunc_, nullptr); + if (pTbeCompiler_ == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function : create_parallel_compiler."; + return false; + } + + pOpSelectFormatFunc_ = PyObject_GetAttrString(pTbeProcessModule, kOpSelectFormatFunc); + if (pOpSelectFormatFunc_ == nullptr) { + MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule + << "], FuncName:[" << kOpSelectFormatFunc << "]."; + return false; + } + + pCheckSupportedFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCheckSupportedFunc); + if (pCheckSupportedFunc_ == nullptr) { + MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule + << "], FuncName:[" << kCheckSupportedFunc << "]."; + return false; + } + initialized = true; + MS_LOG(INFO) << "TbePythonFuncs initialized Success."; + return true; +} + +std::string TbePythonFuncs::PyObjectToStr(PyObject *PyObj) { + char *pChar = nullptr; + std::string str_res; + if (PyObj == nullptr) { + MS_LOG(ERROR) << "Input parameter is nullptr."; + return str_res; + } + PyObject *strArgs = PyObject_Str(PyObj); + if (strArgs != nullptr) { + (void)PyArg_Parse(strArgs, "s", &pChar); + } + if (pChar == nullptr) { + MS_LOG(ERROR) << "pChar is nullptr."; + return str_res; + } + str_res = pChar; + return str_res; +} + +std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) { + PyObject *pArg = nullptr; + PyObject *pRet = nullptr; + std::string res_json_str; + + if (!Init()) { + MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; + return res_json_str; + } + + // assembly Args + pArg = PyTuple_New(1); + std::string json_str = kernel_json.dump(); + (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", json_str.c_str())); + if (pArg == nullptr) { + MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; + return res_json_str; + } + + // call functions + if (pOpSelectFormatFunc_ == nullptr) { + MS_LOG(ERROR) << "function is nullptr."; + return res_json_str; + } + + pRet = PyEval_CallObject(pOpSelectFormatFunc_, pArg); + if (pRet == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc + << "], function args:" << PyObjectToStr(pArg); + } + + char *pstr = nullptr; + (void)PyArg_Parse(pRet, "s", &pstr); + res_json_str = pstr; + if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) { + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str + << " ,function args:" << PyObjectToStr(pArg); + } + return res_json_str; +} + +bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) { + PyObject *pArg = nullptr; + PyObject *pRes = nullptr; + bool ret = false; + + if (!Init()) { + MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; + return ret; + } + // assembly Args + pArg = PyTuple_New(1); + std::string json_str = kernel_json.dump(); + PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); + (void)PyTuple_SetItem(pArg, 0, arg1); + if (pArg == nullptr) { + MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; + return ret; + } + + // call functions + if (pCheckSupportedFunc_ == nullptr) { + MS_LOG(ERROR) << "function is nullptr."; + return ret; + } + + pRes = PyEval_CallObject(pCheckSupportedFunc_, pArg); + if (pRes == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc + << "], function args: " << PyObjectToStr(pArg); + } + if (PyBool_Check(pRes)) { + ret = PyObject_IsTrue(pRes) != 0; + } else { + char *pstr = nullptr; + (void)PyArg_Parse(pRes, "s", &pstr); + std::string res_str = pstr; + if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) { + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str + << ", function args: " << PyObjectToStr(pArg); + } + } + + return ret; +} + +PyObject *TbePythonFuncs::TbeParallelCompiler() { + if (!Init()) { + MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; + return nullptr; + } + return pTbeCompiler_; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.h similarity index 100% rename from mindspore/ccsrc/kernel/tbe/tbe_python_funcs.h rename to mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc new file mode 100644 index 0000000000..76ef7b08d5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc @@ -0,0 +1,254 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "runtime/kernel.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "runtime/device/kernel_info.h" +#include "ir/dtype/type.h" +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" +#include "securec/include/securec.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +constexpr auto kCceKernelMeta = "./kernel_meta/"; +constexpr auto kJsonSuffix = ".json"; +constexpr auto kInfoSuffix = ".info"; + +uintptr_t KernelManager::kernel_stub_gen_ = 0; +std::unordered_map KernelManager::info_table_ = {}; + +void TbeUtils::SaveJsonInfo(const std::string &json_name, const std::string &info) { + char real_path[PATH_MAX] = {0}; + std::string path = kCceKernelMeta + json_name + kInfoSuffix; + if (path.size() > PATH_MAX) { + MS_LOG(ERROR) << "file path: " << path << "is too long."; + return; + } + std::ifstream fin(path); + if (fin) { + MS_LOG(INFO) << "json file exist, no need to create."; + return; + } + std::ofstream file_write; + file_write.open(path); + if (!file_write.is_open()) { + return; + } + file_write << info << std::endl; + file_write.close(); + if (realpath(path.c_str(), real_path) == nullptr) { + MS_LOG(INFO) << "dir: " << path << "does not exit."; + return; + } + MS_LOG(INFO) << "real path is: " << real_path; + if (chmod(real_path, S_IRUSR) == -1) { + MS_LOG(INFO) << "modify file: " << real_path << "to read only fail."; + } +} + +void TbeUtils::LoadCache() { + static bool has_load = false; + if (!has_load) { + KernelMeta *bin_map = KernelMeta::GetInstance(); + if (bin_map != nullptr && !bin_map->ReadIndex(kCceKernelMeta)) { + MS_LOG(INFO) << "Cache initialize failed[" << kCceKernelMeta << "]"; + } else { + MS_LOG(INFO) << "Cache initialize to " << kCceKernelMeta; + } + has_load = true; + } +} + +KernelPackPtr TbeUtils::SearchCache(const std::string &kernel_name, const std::string &processor) { + // search cache. + KernelMeta *bin_map = KernelMeta::GetInstance(); + if (bin_map == nullptr) { + MS_LOG(INFO) << "kernel cache is invalid."; + return nullptr; + } + return bin_map->GetKernelPack(kernel_name, processor); +} + +KernelPackPtr TbeUtils::InsertCache(const std::string &kernel_name, const std::string &processor) { + MS_LOG(INFO) << "kernel name: " << kernel_name << ", processr:" << processor; + if (processor != kProcessorAiCore) { + MS_LOG(EXCEPTION) << "process type should be aicore, actually is: " << processor; + } + return SearchCache(kernel_name, processor); +} + +int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module, + const string &magic) { + static std::map magic_maps = {{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF}, + {"RT_DEV_BINARY_MAGIC_PLAIN", RT_DEV_BINARY_MAGIC_PLAIN}, + {"RT_DEV_BINARY_MAGIC_PLAIN_AICPU", RT_DEV_BINARY_MAGIC_PLAIN_AICPU}, + {"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU}}; + // object for device register. + rtDevBinary_t dev_bin; + dev_bin.data = kernel_buffer.contents; + auto iter = magic_maps.find(magic); + if (iter == magic_maps.end()) { + MS_LOG(INFO) << "Invalid magic number: " << magic; + return -1; + } + dev_bin.magic = iter->second; + dev_bin.length = kernel_buffer.len; + dev_bin.version = 2; + if (RT_ERROR_NONE != rtDevBinaryRegister(&dev_bin, module)) { + MS_LOG(INFO) << "Call runtime rtDevBinaryRegister error."; + return -1; + } + return 0; +} + +uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel_pack, bool force_reload, + uint32_t *block_dim) { + auto kernel = kernel_pack.GetKernel(); + if (kernel == nullptr) { + MS_LOG(EXCEPTION) << "Invalid kernel pack, json or kernel is nullptr."; + } + auto kernel_contents = kernel->contents; + if (kernel_contents == nullptr) { + MS_LOG(EXCEPTION) << "Invalid kernel context, json or kernel is nullptr."; + } + auto kernel_json_info = kernel_pack.kernel_json_info(); + + *block_dim = kernel_json_info.block_dim; + string func_name = kernel_json_info.kernel_name; + string magic = kernel_json_info.magic; + + if (!force_reload) { + // use the cached object. + auto iter = info_table_.find(func_name); + if (iter != info_table_.end()) { + auto kernelmeta = iter->second; + *block_dim = kernelmeta->block_dim_; + return kernelmeta->func_stub_; + } + } + void *module = nullptr; + if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic) != 0) { + MS_LOG(INFO) << "Call runtime BinaryRegister error."; + return 0; + } + // to diff different funcs. + uintptr_t func_stub = ++kernel_stub_gen_; + if (RT_ERROR_NONE != + rtFunctionRegister(module, reinterpret_cast(func_stub), func_name.c_str(), func_name.c_str(), 0)) { + MS_LOG(INFO) << "Call runtime rtFunctionRegister error."; + return 0; + } + // cache the registered kernelmeta. + info_table_[func_name] = std::make_shared(KernelMetaInfo{func_stub, *block_dim}); + return func_stub; +} + +std::string KernelManager::GetStubFuncName(const KernelPackPtr &kernel_pack) { + MS_EXCEPTION_IF_NULL(kernel_pack); + auto kernel_json_info = kernel_pack->kernel_json_info(); + return kernel_json_info.kernel_name; +} + +KernelMeta *KernelMeta::GetInstance() { + static KernelMeta inst; + return &inst; +} + +bool KernelMeta::ReadIndex(const std::string &bin_dir) { + DIR *dir = opendir(bin_dir.c_str()); + if (dir == nullptr) { + auto ret = mkdir(bin_dir.c_str(), S_IRWXG | S_IRWXU); + if (ret != 0) { + MS_LOG(INFO) << "kernel dir: " << bin_dir << "not exist"; + return false; + } + dir = opendir(bin_dir.c_str()); + } + struct dirent *entry; + while ((entry = readdir(dir)) != nullptr) { + string bin_dir_tmp = bin_dir; + std::string cce_json = entry->d_name; + if (cce_json.length() <= 5) { + continue; + } + std::string suffix = cce_json.substr(cce_json.length() - 5); + if (suffix != kJsonSuffix) { + continue; + } + auto sp = cce_json.rfind('/'); + if (sp != std::string::npos) { + continue; + } + sp = cce_json.rfind('.'); + if (sp == std::string::npos) { + continue; + } + auto kernel_name = cce_json.substr(0, sp); + (void)bin_dir_tmp.append("/"); + (void)bin_dir_tmp.append(cce_json); + kernel_index_map_[kernel_name] = bin_dir_tmp; + } + (void)closedir(dir); + + MS_LOG(INFO) << "Cache kernel initialized, kernel size: " << kernel_index_map_.size(); + return true; +} + +KernelPackPtr KernelMeta::GetKernelPack(const std::string &kernel_name, const std::string &processor) { + KernelPackPtr ret = nullptr; + // 1. pack has been created + auto kernel_pack_iter = kernel_pack_map_.find(kernel_name); + if (kernel_pack_iter != kernel_pack_map_.end()) { + MS_LOG(INFO) << "kernel pack [" << kernel_name << "]has been created."; + ret = kernel_pack_iter->second; + } else { + // 2. kernel file has been create, but pack does not been created. + std::string cce_json = kCceKernelMeta; + (void)cce_json.append(kernel_name).append(kJsonSuffix); + ret = std::make_shared(); + if (!ret->LoadKernelMeta(cce_json, processor)) { + MS_LOG(INFO) << "Read cache json and bin file failed[" << cce_json << "]"; + return nullptr; + } + kernel_pack_map_[kernel_name] = ret; + auto iter = kernel_index_map_.find(kernel_name); + if (iter == kernel_index_map_.end()) { + MS_LOG(INFO) << "kernel name [" << kernel_name << "] has been ceated first."; + kernel_index_map_[kernel_name] = cce_json; + } + } + return ret; +} +} // namespace tbe +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h new file mode 100644 index 0000000000..39ddaaa73d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h @@ -0,0 +1,86 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ +#include +#include +#include +#include +#include +#include + +#include "backend/session/kernel_graph.h" +#include "ir/anf.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +using std::string; +using std::vector; + +class TbeUtils { + public: + TbeUtils() = default; + + ~TbeUtils() = default; + + static void SaveJsonInfo(const std::string &json_name, const std::string &info); + + static void LoadCache(); + + static KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); + + static KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); +}; + +struct KernelMetaInfo { + uintptr_t func_stub_; + uint32_t block_dim_; +}; +using KernelMetaPtr = std::shared_ptr; + +class KernelManager { + public: + static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim); + static std::string GetStubFuncName(const KernelPackPtr &kernel_pack); + + private: + KernelManager() = default; + ~KernelManager() = default; + static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic); + static std::unordered_map info_table_; + static uintptr_t kernel_stub_gen_; +}; + +class KernelMeta { + public: + static KernelMeta *GetInstance(); + bool ReadIndex(const std::string &bin_dir); + KernelPackPtr GetKernelPack(const std::string &kernel_name, const std::string &processor); + + private: + KernelMeta() = default; + ~KernelMeta() = default; + std::unordered_map kernel_index_map_{}; + std::unordered_map kernel_pack_map_{}; +}; +} // namespace tbe +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/CMakeLists.txt b/mindspore/ccsrc/backend/optimizer/CMakeLists.txt new file mode 100644 index 0000000000..ee1532a416 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/CMakeLists.txt @@ -0,0 +1,14 @@ +file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "common/*.cc" + "mem_reuse/*.cc" + "pass/*.cc" + "gpu/*.cc" +) + +if (ENABLE_D) + file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc") + list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST}) +endif () + +set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT) +add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST}) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc new file mode 100644 index 0000000000..64d76ab358 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -0,0 +1,498 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ascend_backend_optimization.h" +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fission/bn_split.h" +#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h" +#include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h" +#include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h" +#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" +#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" +#include "backend/optimizer/pass/communication_op_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h" +#include "backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" +#include "backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h" +#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h" +#include "backend/optimizer/ascend/ir_fission/transdata_split.h" +#include "backend/optimizer/ascend/ir_fission/topk_split.h" +#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h" +#include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h" +#include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" +#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" +#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" +#include "backend/optimizer/pass/getitem_tuple.h" +#include "backend/optimizer/pass/optimize_dependence.h" +#include "backend/optimizer/pass/erase_visit_attr.h" +#include "backend/optimizer/ascend/format_type/insert_cast.h" +#include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" +#include "backend/optimizer/pass/eliminate_redundant_op.h" +#include "backend/optimizer/pass/common_subexpression_elimination.h" +#include "backend/optimizer/pass/fuse_graph_kernel.h" +#include "backend/optimizer/pass/fuse_basic.h" +#include "backend/optimizer/pass/add_atomic_clean.h" +#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h" +#include "backend/optimizer/ascend/format_type/check_consistency.h" +#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h" +#include "backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h" +#include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h" +#include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h" +#include "backend/optimizer/ascend/ir_fission/addn_fission.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h" +#include "backend/optimizer/ascend/ir_fission/split_fission.h" +#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h" +#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h" +#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" +#include "utils/context/ms_context.h" +#include "utils/config_manager.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" + +namespace mindspore { +namespace opt { +namespace { +void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { + MS_EXCEPTION_IF_NULL(ir_fusion_pm); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); +} +} // namespace + +void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto data_layout_pm = std::make_shared("pynative_transop_pm"); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(data_layout_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void AscendGraphKernelCommonProcess(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + MS_EXCEPTION_IF_NULL(optimizer); + auto common_process = std::make_shared("graph_kernel_common_process"); + MS_EXCEPTION_IF_NULL(common_process); + common_process->AddPass(std::make_shared()); + common_process->AddPass(std::make_shared()); + optimizer->AddPassManager(common_process); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void AscendDataLayout(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto data_layout_pm = std::make_shared("transop_pm"); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(data_layout_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void AscendMixPrecision(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto mixed_precision_pm = std::make_shared("cast_pm"); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(mixed_precision_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" + + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + DumpIRProto(kernel_graph, "before_hwopt_" + std::to_string(kernel_graph->graph_id())); + } + auto optimizer = std::make_shared(); + auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); + if (context_ptr->execution_mode() == kPynativeMode) { + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + } else { + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + } + ir_fusion_pm->AddPass(std::make_shared()); + if (context_ptr->ir_fusion_flag()) { + AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); + } + + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + } + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(ir_fusion_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } +} + +void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!context_ptr->ir_fusion_flag()) { + MS_LOG(INFO) << "IRFusion is not enable, skip"; + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir"; + DumpIR(file_path, kernel_graph); + } + auto optimizer = std::make_shared(); + auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + + optimizer->AddPassManager(ir_fusion_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir"; + DumpIR(file_path, kernel_graph); + } +} + +void AscendBackendOptimization(const std::shared_ptr &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } + // data layout optimization + AscendDataLayout(kernel_graph); + // mixed precision optimization + AscendMixPrecision(kernel_graph); + // other optimization + auto optimizer = std::make_shared(); + auto other_pm = std::make_shared("other_pm"); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(other_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + // buffer fusion + AscendBackendUBFusionOptimization(kernel_graph); + + // other2 optimization + auto optimizer2 = std::make_shared(); + auto other2_pm = std::make_shared("other2_pm"); + other2_pm->AddPass(std::make_shared()); + other2_pm->AddPass(std::make_shared()); + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + other2_pm->AddPass(std::make_shared()); + } + other2_pm->AddPass(std::make_shared()); + optimizer2->AddPassManager(other2_pm); + (void)optimizer2->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph, true); + DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id())); + kernel_graph->DumpFuncGraph("hwopt_d_end"); + } +} + +void AscendBackendGraphKernelOpt(const std::shared_ptr &kernel_graph, + bool is_before_kernel_select) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!(context_ptr->enable_graph_kernel())) { + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" + + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + + ".ir"; + DumpIR(file_path, kernel_graph); + } + + // Fuse graph kernels with basic ops + FuseGraphKernel(kernel_graph, is_before_kernel_select); + + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" + + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + + ".ir"; + DumpIR(file_path, kernel_graph, true); + } +} + +void AscendBackendFuseBasicOpt(const std::shared_ptr &kernel_graph, + bool is_before_kernel_select) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!(context_ptr->enable_graph_kernel())) { + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" + + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + + ".ir"; + DumpIR(file_path, kernel_graph, true); + } + + // Fuse basic ops with basic ops + FuseBasic(kernel_graph, is_before_kernel_select); + + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" + + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + + ".ir"; + DumpIR(file_path, kernel_graph, true); + } +} + +void AscendBackendAddAtomicClean(const std::shared_ptr &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!(context_ptr->enable_graph_kernel())) { + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" + + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } + + AddAtomicClean(kernel_graph); + + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph, true); + } +} + +void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!context_ptr->ir_fusion_flag()) { + MS_LOG(INFO) << "UBFusion is not enable, skip"; + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = + save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } + auto fusion_id_allocator = std::make_shared(); + MS_EXCEPTION_IF_NULL(fusion_id_allocator); + fusion_id_allocator->Init(); + auto optimizer = std::make_shared(); + auto ub_fusion_pm = std::make_shared("ub_fusion_pm"); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(ub_fusion_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + if (save_graphs) { + std::string file_path = + save_graphs_path + "/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h new file mode 100644 index 0000000000..8194ab467b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ +#include +#include "backend/session/kernel_graph.h" +namespace mindspore { +namespace opt { +void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph); +void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); +void AscendDataLayout(const std::shared_ptr &kernel_graph); +void AscendMixPrecision(const std::shared_ptr &kernel_graph); +void AscendBackendOptimization(const std::shared_ptr &kernel_graph); +void AscendGraphKernelCommonProcess(const std::shared_ptr &kernel_graph); +void AscendBackendGraphKernelOpt(const std::shared_ptr &kernel_graph, + bool is_before_kernel_select = false); +void AscendBackendFuseBasicOpt(const std::shared_ptr &kernel_graph, + bool is_before_kernel_select = false); +void AscendBackendAddAtomicClean(const std::shared_ptr &kernel_graph); +void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); +void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph); +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc new file mode 100644 index 0000000000..fd4c0e5952 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -0,0 +1,345 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ascend_helper.h" +#include +#include "common/trans.h" +#include "common/utils.h" +#include "backend/optimizer/common/helper.h" +#include "utils/utils.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/common_utils.h" +#include "frontend/operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; +namespace { +const std::set kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW}; +AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, + const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { + std::vector trans_inputs; + auto prim = std::make_shared(prim::kPrimReshape->name()); + trans_inputs.emplace_back(NewValueNode(prim)); + trans_inputs.emplace_back(input_node); + auto reshape = func_graph->NewCNode(trans_inputs); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape); + AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape); + reshape->set_scope(input_node->scope()); + kernel_select->SelectKernel(reshape); + return reshape; +} + +AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { + AnfNodePtr trans_node = nullptr; + AnfNodePtr input_node = node; + CNodePtr trans_data = nullptr; + std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); + std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; + std::vector padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); + MS_EXCEPTION_IF_NULL(node); + // if insert transdata for input we need to change the input + if (is_insert_input) { + if (!node->isa()) { + MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; + } + auto cnode = node->cast(); + dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); + input_node = AnfAlgo::GetInputNode(cnode, insert_index); + padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); + } + bool need_padding = false; + if (is_insert_input) { + need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); + } else { + need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); + } + if (!need_padding) { + // don't need padding insert transdata only + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); + trans_node = trans_data; + } else if (is_insert_input) { + // if need padding & is input need insert a transdata + // reshape[padding shape] -> transdata[padding shape] -> node + auto padding_shape = + trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); + auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); + trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); + trans_node = trans_data; + } else { + // if need padding & is output need insert a transdata + // node -> transdata[padding shape] -> reshape[ori_shape] + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); + auto reshape_node = + CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); + trans_node = reshape_node; + } + // refresh the transdata's format to ori format & dst format + RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); + return trans_node; +} + +AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(node); + auto input_node = AnfAlgo::GetInputNode(node, index); + auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); + MS_EXCEPTION_IF_NULL(node_with_index.first); + auto real_input = node_with_index.first; + if (real_input->isa() || real_input->isa()) { + input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); + MS_EXCEPTION_IF_NULL(input_node); + AnfAlgo::SetNodeInput(node, input_node, index); + } + std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); + std::string dest_format = AnfAlgo::GetInputFormat(node, index); + if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { + MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) + << " To DefaultFormat , index: " << index; + return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); + } + return input_node; +} + +AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(node); + std::string output_format = AnfAlgo::GetOutputFormat(node, 0); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, 0); + if (output_format == kOpFormat_NC1KHKWHWC0) { + MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " + << node->DebugString(); + } + if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { + MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; + return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); + } + return node; +} + +AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + std::vector make_tuple_inputs; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { + std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); + if (output_format == kOpFormat_NC1KHKWHWC0) { + MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node " + << node->DebugString(); + } + auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { + make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); + } else { + // No need insert trans op. + make_tuple_inputs.push_back(tuple_getitem); + } + } + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} +} // namespace +void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, + const AnfNodePtr &trans_data, const std::vector &reshape_type) { + MS_EXCEPTION_IF_NULL(trans_data); + auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); + MS_EXCEPTION_IF_NULL(ori_build_info); + auto builder = std::make_shared(ori_build_info); + builder->SetInputsFormat({input_format}); + builder->SetInputReshapeType({reshape_type}); + builder->SetOutputReshapeType({reshape_type}); + builder->SetOutputsFormat({output_format}); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); +} + +CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, + const bool need_padding, const std::string &op_name) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(input); + std::vector trans_inputs; + auto prim = std::make_shared(op_name); + trans_inputs.push_back(NewValueNode(prim)); + trans_inputs.push_back(input); + CNodePtr trans_node = func_graph->NewCNode(trans_inputs); + MS_EXCEPTION_IF_NULL(trans_node); + auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); + if (need_padding) { + // if need padding we should set the transdata node's shape to the padding shape + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, + {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, + trans_node.get()); + } else { + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, + {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); + } + // special handle for ut + if (trans_node->kernel_info() == nullptr) { + auto kernel_info = std::make_shared(); + trans_node->set_kernel_info(kernel_info); + } + MS_EXCEPTION_IF_NULL(kernel_select); + kernel_select->SelectKernel(trans_node); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); + MS_EXCEPTION_IF_NULL(trans_node); + trans_node->set_scope(input->scope()); + return trans_node; +} + +AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, + const TypeId &input_type, const TypeId &output_type, + const std::vector &origin_shape, const TypeId &origin_type) { + MS_EXCEPTION_IF_NULL(func_graph); + std::string input_format = format; + std::string output_format = format; + std::vector new_cast_inputs; + auto prim = std::make_shared(prim::kPrimCast->name()); + new_cast_inputs.push_back(NewValueNode(prim)); + new_cast_inputs.push_back(input); + CNodePtr cast = func_graph->NewCNode(new_cast_inputs); + MS_EXCEPTION_IF_NULL(cast); + // set kernel build info + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetInputsFormat({input_format}); + builder.SetOutputsFormat({output_format}); + builder.SetInputsDeviceType({input_type}); + builder.SetOutputsDeviceType({output_type}); + builder.SetFusionType(kernel::FusionType::OPAQUE); + builder.SetProcessor(kernel::Processor::AICORE); + if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) { + builder.SetKernelType(KernelType::TBE_KERNEL); + } else { + builder.SetKernelType(KernelType::AKG_KERNEL); + } + // if kernel info is null , it remarks this function is running ut + if (cast->kernel_info() == nullptr) { + auto kernel_info = std::make_shared(); + cast->set_kernel_info(kernel_info); + } + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); + AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); + return cast; +} + +AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + size_t outputs_num = AnfAlgo::GetOutputTensorNum(node); + if (outputs_num == 0) { + return node; + } + // Single output + if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) { + return InsertTransOpForSingleOutput(func_graph, node, kernel_select); + } + // Multiple output + return InsertTransOpForMultipleOutput(func_graph, node, kernel_select); +} + +AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select); + MS_EXCEPTION_IF_NULL(input_node); + new_inputs.push_back(input_node); + } + CNodePtr new_cnode = nullptr; + // cnode changed so make a new cnode to differ from original one. + auto kernel_graph = func_graph->cast>(); + if (kernel_graph == nullptr) { + new_cnode = std::make_shared(*cnode); + } else { + new_cnode = kernel_graph->NewCNode(cnode); + } + MS_EXCEPTION_IF_NULL(new_cnode); + new_cnode->set_inputs(new_inputs); + return new_cnode; +} + +CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + TypeId origin_type(kTypeUnknown); + auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); + auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); + auto real_input_node = kernel_with_index.first; + if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + // weight + origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index); + if (origin_type == kTypeUnknown) { + origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); + } + } else { + // feature map + origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + } + const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); + const std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); + const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); + // In graph kernel, we check parameter, + // the eliminate pass will not eliminate this case, so we just do not insert the noused cast. + if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode(cur_input)) { + new_inputs.push_back(cur_input); + } else if (origin_type != device_type) { + auto cast = + AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type); + MS_EXCEPTION_IF_NULL(cast); + cast->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); + new_inputs.push_back(cast); + } else { + new_inputs.push_back(cur_input); + } + } + auto kernel_graph = func_graph->cast>(); + CNodePtr new_node = nullptr; + if (kernel_graph == nullptr) { + new_node = std::make_shared(*cnode); + } else { + new_node = kernel_graph->NewCNode(cnode); + } + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_inputs(new_inputs); + return new_node; +} + +AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto prim = std::make_shared(kMemCpyAsyncOpName); + std::vector new_node_inputs = {NewValueNode(prim), node}; + auto new_node = graph->NewCNode(new_node_inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_abstract(node->abstract()); + new_node->set_scope(node->scope()); + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h new file mode 100644 index 0000000000..cb308a09a0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ + +#include +#include +#include +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "backend/kernel_compiler/kernel_query.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +class KernelSelect { + public: + KernelSelect() = default; + virtual ~KernelSelect() = default; + virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } +}; +using KernelSelectPtr = std::shared_ptr; + +class SupportedChecker { + public: + SupportedChecker() = default; + virtual ~SupportedChecker() = default; + virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, + const kernel::KernelBuildInfoPtr &select_kernel_build_info) { + return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); + } + virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, + const kernel::KernelBuildInfoPtr &select_kernel_build_info) { + return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); + } +}; +using SupportedCheckerPtr = std::shared_ptr; + +class KernelQuery { + public: + KernelQuery() = default; + virtual ~KernelQuery() = default; + virtual void Query(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { + kernel::KernelQuery(kernel_node, kernel_info_list); + } + virtual bool IsTbeRef(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); + if (op_info != nullptr) { + return op_info->is_ref(); + } + return false; + } +}; +using KernelQueryPtr = std::shared_ptr; + +class OpFinder { + public: + OpFinder() = default; + virtual ~OpFinder() = default; + virtual int GetOpRegisteredOutputNum(const std::string &op_name) { + auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE); + if (op_info == nullptr) { + return -1; + } + return op_info->outputs_ptr().size(); + } +}; +using OpFinderPtr = std::shared_ptr; + +void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, + const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); + +CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, + const bool need_padding, const std::string &op_name); + +AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, + const TypeId &input_type, const TypeId &output_type, + const std::vector &origin_shape, const TypeId &origin_type); + +AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select); + +AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select); + +CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + +AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..22183c9050 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(relu_input); + auto add = relu_input->cast(); + MS_EXCEPTION_IF_NULL(add); + auto tuple_getitem = add->input(1); + MS_EXCEPTION_IF_NULL(tuple_getitem); + if (tuple_getitem->isa() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) { + auto getitem = tuple_getitem->cast(); + MS_EXCEPTION_IF_NULL(getitem); + auto bnupdate = getitem->input(1); + MS_EXCEPTION_IF_NULL(bnupdate); + if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { + std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); + for (auto out_getitem : manager->node_users()[bnupdate]) { + MS_EXCEPTION_IF_NULL(out_getitem.first); + auto out_getitem_ptr = out_getitem.first->cast(); + MS_EXCEPTION_IF_NULL(out_getitem_ptr); + auto input2 = out_getitem_ptr->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); + } + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); + std::unordered_set record{cnode, relu_input, bnupdate}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } +} + +void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { + MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); + } + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h new file mode 100644 index 0000000000..dfc45b4688 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass { + public: + explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {} + ~BnupdateEltwiseEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..59915d43d4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(relu_input); + auto getitem = relu_input->cast(); + MS_EXCEPTION_IF_NULL(getitem); + auto bnupdate = getitem->input(1); + MS_EXCEPTION_IF_NULL(bnupdate); + if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { + std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); + for (auto out_getitem : manager->node_users()[bnupdate]) { + MS_EXCEPTION_IF_NULL(out_getitem.first); + auto out_getitem_ptr = out_getitem.first->cast(); + MS_EXCEPTION_IF_NULL(out_getitem_ptr); + auto input2 = out_getitem_ptr->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); + } + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); + std::unordered_set record{cnode, bnupdate}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { + MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); + } + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h new file mode 100644 index 0000000000..abaf264d2e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class BnupdateEltwiseFusionPass : public FusionBasePass { + public: + explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {} + ~BnupdateEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..1bfff1b50e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise( + const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + } else { + return; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + auto double_in_eltwise_input = input_cnode->input(1); + MS_EXCEPTION_IF_NULL(double_in_eltwise_input); + if (!double_in_eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { + return; + } + if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput)) { + (void)record.insert(double_in_eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && + (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { + MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h new file mode 100644 index 0000000000..6bf74d5268 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass { + public: + explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {} + ~Conv2DBackpropEltwiseEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConv2DBackpropInputEltwiseEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..144ab4b53f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) { + (void)record.insert(eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && + (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { + MatchConv2DBackpropInputEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h new file mode 100644 index 0000000000..93aa324566 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class Conv2DBackpropEltwiseFusionPass : public FusionBasePass { + public: + explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {} + ~Conv2DBackpropEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc new file mode 100644 index 0000000000..a2ebfbe79e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + auto conv = cnode->input(1); + MS_EXCEPTION_IF_NULL(conv); + if (conv->isa() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) { + std::vector output_used_num{SizeToInt(manager->node_users()[conv].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv); + std::unordered_set record{cnode, conv}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void ConvBnReduceFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { + MatchConvBnreduce(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h new file mode 100644 index 0000000000..224422530b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class ConvBnReduceFusionPass : public FusionBasePass { + public: + explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("ConvBnReduceFusionPass", idAllocator) {} + ~ConvBnReduceFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_CONV_BNREDUCE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc new file mode 100644 index 0000000000..1a67e3c39b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + } else { + return; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + auto double_in_eltwise_input = input_cnode->input(1); + MS_EXCEPTION_IF_NULL(double_in_eltwise_input); + if (!double_in_eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { + return; + } + if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONVLUTION) { + (void)record.insert(double_in_eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchConvDoubleInEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h new file mode 100644 index 0000000000..911cf744de --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class ConvDoubleInFusionPass : public FusionBasePass { + public: + explicit ConvDoubleInFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("ConvDoubleInFusionPass", idAllocator) {} + ~ConvDoubleInFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc new file mode 100644 index 0000000000..1eb26b12bc --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + if (record.size() == MAX_ELTWISE_NUM) { + break; + } + } + MS_EXCEPTION_IF_NULL(eltwise_input); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONVLUTION) { + (void)record.insert(eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void ConvSingleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchConvSingleInEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h new file mode 100644 index 0000000000..6dddd600c2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class ConvSingleInFusionPass : public FusionBasePass { + public: + explicit ConvSingleInFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("ConvSingleInFusionPass", idAllocator) {} + ~ConvSingleInFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConvSingleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..285b8f6c07 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion, bool is_order) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + if (is_order) { + // DepthwiseConvolution--->Elemwise + auto depthwise_conv = cnode->input(1); + MS_EXCEPTION_IF_NULL(depthwise_conv); + if (cnode->isa() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) { + std::vector output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv); + std::unordered_set record{cnode, depthwise_conv}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } else { + // Elemwise-->DepthwiseConvolution + auto relu = cnode->input(1); + MS_EXCEPTION_IF_NULL(relu); + if (cnode->isa() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) { + std::vector output_used_num{SizeToInt(manager->node_users()[relu].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu); + std::unordered_set record{cnode, relu}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } +} + +void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { + MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); + } + } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { + MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h new file mode 100644 index 0000000000..6746dad984 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class DepthwiseConvEltwiseFusionPass : public FusionBasePass { + public: + explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {} + ~DepthwiseConvEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion, bool is_order); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc new file mode 100644 index 0000000000..1e24cce0e4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + if (record.size() == MAX_ELTWISE_SIZE) { + break; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + } + if (record.size() < MIN_ELTWISE_SIZE) { + return; + } + candidate_fusion->push_back(record); + SetRecordFusionId(record); +} + +void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); + for (auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h new file mode 100644 index 0000000000..ae63687631 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class EltwiseFusionPass : public FusionBasePass { + public: + explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} + ~EltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc new file mode 100644 index 0000000000..27a7a786d1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include +#include +#include "debug/anf_ir_dump.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(node); + if (!node->isa() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto user_nodes = manager->node_users()[node]; + return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && + cnode->inputs().size() == ELTWISE_INPUT_SIZE; +} + +bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(node); + if (!node->isa() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto user_nodes = manager->node_users()[node]; + return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && + cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE; +} + +bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(node); + if (!node->isa() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto user_nodes = manager->node_users()[node]; + return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_MULTI_USE && + cnode->inputs().size() == ELTWISE_INPUT_SIZE; +} + +void FusionBasePass::SetRecordFusionId(const std::unordered_set &record) { + auto id = fusion_id_allocator->AllocateFusionId(); + for (auto node : record) { + fusion_id_allocator->SetFusionId(node, id); + } +} + +bool FusionBasePass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) { + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + auto return_node = kernel_graph.get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->inputs().size() <= 1) { + return false; + } + MS_LOG(DEBUG) << "MatchBufferFusionPattern start..."; + FusedNodeRecord candidate_fusion; + MatchSingleFusionPattern(kernel_graph, &candidate_fusion); + if (candidate_fusion.empty()) { + return false; + } + MS_LOG(DEBUG) << "MatchBufferFusionPattern Success..."; + return true; +} + +bool FusionBasePass::Run(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + auto kernel_graph = graph->cast>(); + MS_EXCEPTION_IF_NULL(kernel_graph); + return MatchUBFusionPattern(*kernel_graph); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h new file mode 100644 index 0000000000..dced2c2fa2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ +#include +#include +#include +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +const int8_t MAX_ELTWISE_NUM = 3; +const int8_t MIN_ELTWISE_SIZE = 2; +const int8_t ELTWISE_INPUT_SIZE = 2; +const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3; +const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3; +const int8_t CONV_QUART_IN_INPUT_SIZE = 5; +const int8_t ELTWISE_USE = 1; +const int8_t ELTWISE_MULTI_USE = 2; +const int8_t MAX_ELTWISE_SIZE = 6; +const int8_t MULTI_ELTWISE_SIZE = 4; +using FusedNodeRecord = std::vector>; + +struct BufferFusionInfo_t { + std::vector anf_nodes; + std::vector inputs_list; + std::vector outputs_list; + kernel::KernelBuildInfoPtr kernel_build_info; +}; + +class FusionBasePass : public Pass { + public: + FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator) + : Pass(name), fusion_id_allocator(idAllocator) {} + ~FusionBasePass() override = default; + bool Run(const FuncGraphPtr &graph) override; + bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph); + + protected: + virtual void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) = 0; + void SetRecordFusionId(const std::unordered_set &record); + bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); + bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); + bool CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); + FusionIdAllocatorPtr fusion_id_allocator; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..7fcc6e45e0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector output_used_num{SizeToInt(manager->node_users()[relu_input].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input); + std::unordered_set record{cnode, relu_input}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); +} + +void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { + MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); + } + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h new file mode 100644 index 0000000000..e0d08bb58d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class MatmulEltwiseFusionPass : public FusionBasePass { + public: + explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {} + ~MatmulEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc new file mode 100644 index 0000000000..58a219aec7 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) { + std::vector output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); + (void)record.insert(eltwise_input); + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + } else { + return; + } + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + if (record.size() == MULTI_ELTWISE_SIZE) { + break; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + } + if (record.size() != MULTI_ELTWISE_SIZE) { + return; + } + candidate_fusion->push_back(record); + SetRecordFusionId(record); +} + +void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchMultiOutputEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h new file mode 100644 index 0000000000..40a45360a1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class MultiOutputFusionPass : public FusionBasePass { + public: + explicit MultiOutputFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("MultiOutputFusionPass", idAllocator) {} + ~MultiOutputFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchMultiOutputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..95955818eb --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + if (record.size() == MAX_ELTWISE_NUM) { + break; + } + } + MS_EXCEPTION_IF_NULL(eltwise_input); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE) { + (void)record.insert(eltwise_input); + auto previous_input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_input_cnode); + auto previous_eltwise_input = previous_input_cnode->input(1); + auto previous_size = record.size(); + while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { + (void)record.insert(previous_eltwise_input); + auto previous_node = previous_eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_node); + previous_eltwise_input = previous_node->input(1); + if (record.size() - previous_size == MAX_ELTWISE_NUM) { + break; + } + } + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchReduceEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h new file mode 100644 index 0000000000..4d56eee7b3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class ReduceEltwiseFusionPass : public FusionBasePass { + public: + explicit ReduceEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("ReduceEltwiseFusionPass", idAllocator) {} + ~ReduceEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchReduceEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWSIE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..f2117f9374 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + if (record.size() == MAX_ELTWISE_NUM) { + break; + } + } + MS_EXCEPTION_IF_NULL(eltwise_input); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::SEGMENT) { + (void)record.insert(eltwise_input); + auto previous_input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_input_cnode); + auto previous_eltwise_input = previous_input_cnode->input(1); + auto previous_size = record.size(); + while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { + (void)record.insert(previous_eltwise_input); + auto previous_node = previous_eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_node); + previous_eltwise_input = previous_node->input(1); + if (record.size() - previous_size == MAX_ELTWISE_NUM) { + break; + } + } + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchSegmentEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h new file mode 100644 index 0000000000..f3b97f8357 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class SegmentEltwiseFusionPass : public FusionBasePass { + public: + explicit SegmentEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("SegmentEltwiseFusionPass", idAllocator) {} + ~SegmentEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchSegmentEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWSIE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc new file mode 100644 index 0000000000..d93b47b66c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h" + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto write_input = cnode->input(1); + if (CheckEltWiseNode(manager.get(), write_input)) { + (void)record.insert(write_input); + auto input_cnode = write_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + write_input = input_cnode->input(1); + } + MS_EXCEPTION_IF_NULL(write_input); + if (!write_input->isa() || !AnfAlgo::IsRealCNodeKernel(write_input) || + fusion_id_allocator->HasFusionIdAttr(write_input)) { + return; + } + auto conv_cnode = write_input->cast(); + MS_EXCEPTION_IF_NULL(conv_cnode); + if (AnfAlgo::GetKernelType(conv_cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(conv_cnode) == kernel::FusionType::CONVLUTION && + conv_cnode->inputs().size() >= CONV_DOUBLE_IN_INPUT_SIZE && + conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) { + (void)record.insert(write_input); + auto conv_input = conv_cnode->input(1); + MS_EXCEPTION_IF_NULL(conv_input); + if (!conv_input->isa() || !AnfAlgo::IsRealCNodeKernel(conv_input) || + fusion_id_allocator->HasFusionIdAttr(conv_input)) { + return; + } + if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) { + (void)record.insert(conv_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } +} + +void StridedReadConvStridedWriteFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == kStridedWriteOpName) { + MatchStridedReadConvStridedWrite(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h new file mode 100644 index 0000000000..371c206399 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class StridedReadConvStridedWriteFusionPass : public FusionBasePass { + public: + explicit StridedReadConvStridedWriteFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("StridedReadConvStridedWriteFusionPass", idAllocator) {} + ~StridedReadConvStridedWriteFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchStridedReadConvStridedWrite(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc new file mode 100644 index 0000000000..9685530705 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -0,0 +1,448 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "runtime/device/kernel_info.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +namespace { +const int8_t MAX_PATTERN_SIZE = 7; +const int8_t MIN_PATTERN_SIZE = 2; +const int8_t ELTWISE_INPUT_SIZE = 2; +const int8_t ELTWISE_USE = 1; +const int8_t MULTI_ELTWISE_USE = 2; +const int8_t MAX_MULTI_ELTWISE_SIZE = 4; +const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3; +constexpr auto kOpAttrFusionId = "fusion_id"; + +#ifdef DEBUG +std::string GetFusionTypeName(const kernel::FusionType &type) { + switch (type) { + case kernel::FusionType::COMMREDUCE: + return "COMMREDUCE"; + case kernel::FusionType::SEGMENT: + return "SEGMENT"; + case kernel::FusionType::ELEMWISE: + return "ELEMWISE"; + case kernel::FusionType::CONVLUTION: + return "CONVLUTION"; + case kernel::FusionType::OPAQUE: + return "OPAQUE"; + default: + return "OPAQUE"; + } +} + +void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) { + MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id; + for (auto &node : info.input_nodes) { + MS_LOG(INFO) << "=== Input: " << node->DebugString(); + } + for (auto &node : info.output_nodes) { + MS_LOG(INFO) << "=== Output: " << node->DebugString(); + } + for (auto &node : info.compute_nodes) { + MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-(" << GetFusionTypeName(AnfAlgo::GetFusionType(node)) + << ")"; + } + MS_LOG(INFO) << "=== Dump FusionScopeInfo end"; +} +#endif +CNodePtr CreateFusionOp(const std::vector &inputs_list, const std::vector &outputs_list, + const std::vector &anf_nodes, session::KernelGraph *kernel_graph) { + MS_LOG(DEBUG) << "Start Create FusionOp Kernel"; + MS_EXCEPTION_IF_NULL(kernel_graph); + std::string fusion_op_name = "FusionOp"; + for (auto node : anf_nodes) { + fusion_op_name += '_' + AnfAlgo::GetCNodeName(node); + } + auto fusion_op = std::make_shared(fusion_op_name); + MS_EXCEPTION_IF_NULL(fusion_op); + + std::vector input_names; + for (uint8_t i = 0; i < inputs_list.size(); i++) { + input_names.emplace_back("input" + std::to_string(i)); + } + std::vector output_names; + for (uint8_t i = 0; i < outputs_list.size(); i++) { + output_names.emplace_back("output" + std::to_string(i)); + } + + ValuePtr input_names_v = MakeValue(input_names); + ValuePtr output_names_v = MakeValue(output_names); + fusion_op->set_attr("input_names", input_names_v); + fusion_op->set_attr("output_names", output_names_v); + std::vector fusion_inputs_list = inputs_list; + auto value_node = std::make_shared(fusion_op); + (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node); + auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list); + if (buffer_fusion_kernel == nullptr) { + MS_LOG(EXCEPTION) << "New FusionOp kernel failed!"; + } + buffer_fusion_kernel->set_scope((anf_nodes.back())->scope()); + + return buffer_fusion_kernel; +} + +kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector &inputs_list, + const std::vector &outputs_list) { + MS_LOG(DEBUG) << "Start Create Kernel Info"; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + // inputs format and data type + std::vector inputs_format; + std::vector inputs_data_type; + for (const auto &input : inputs_list) { + auto real_input = AnfAlgo::VisitKernel(input, 0); + inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); + inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); + } + // outputs format and data type + std::vector outputs_format; + std::vector outputs_data_type; + for (const auto &output : outputs_list) { + if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { + auto tuple_getitem = output->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + outputs_format.push_back(AnfAlgo::GetOutputFormat( + tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); + outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( + tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); + } else { + outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0)); + outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); + } + } + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_data_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_data_type); + builder.SetKernelType(KernelType::TBE_KERNEL); + return builder.Build(); +} + +AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph, + size_t output_index) { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector tuple_getitem_inputs_list; + auto value = std::make_shared(prim::kPrimTupleGetItem); + MS_EXCEPTION_IF_NULL(value); + auto idx = NewValueNode(SizeToInt(output_index)); + MS_EXCEPTION_IF_NULL(idx); + int temp = SizeToInt(output_index); + auto imm = std::make_shared(temp); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + tuple_getitem_inputs_list.push_back(value); + tuple_getitem_inputs_list.push_back(buffer_fusion_kernel); + tuple_getitem_inputs_list.push_back(idx); + auto tuple_item = kernel_graph->NewCNode(tuple_getitem_inputs_list); + MS_EXCEPTION_IF_NULL(tuple_item); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)}, + {AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index)}, + tuple_item.get()); + return tuple_item; +} + +void ReplaceInputNodeInOtherFusionScope(std::unordered_map *buffer_fusion_infos, + int32_t fusion_id, const AnfNodePtr &output_item, + const AnfNodePtr &replace_item) { + for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) { + auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(), + output_item); + if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) { + MS_LOG(DEBUG) << "replace input of other pattern, id = " << id; + *itr = replace_item; + } + } +} + +void ReplaceOldNode(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, + const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; + if (buffer_fusion_info.outputs_list.size() == 1) { // single output + (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); + ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0], + buffer_fusion_kernel); + } else { // multiple output + for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) { + auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index); + (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item); + ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index], + tuple_item); + } + } +} + +void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto nodes = TopoSort(kernel_graph->get_return()); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { + auto fusion_id = AnfAlgo::GetNodeAttr(cnode, kOpAttrFusionId); + (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode); + } + } +} + +void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + auto fusion_id = buffer_fusion_info.first; + auto fusion_info = buffer_fusion_info.second; + for (const auto &node : fusion_info.anf_nodes) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { + auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == + fusion_info.anf_nodes.end()) { + if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), + (*buffer_fusion_infos)[fusion_id].inputs_list.end(), + cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { + (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx)); + } + } + } + } + } +} + +bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { + MS_EXCEPTION_IF_NULL(node1); + MS_EXCEPTION_IF_NULL(node2); + auto getitem1 = node1->cast(); + auto getitem2 = node2->cast(); + MS_EXCEPTION_IF_NULL(getitem1); + MS_EXCEPTION_IF_NULL(getitem2); + if (getitem1->size() < kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" + << getitem1->DebugString() << "]"; + } + if (getitem2->size() < kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" + << getitem2->DebugString() << "]"; + } + auto output_idx1 = GetValue(GetValueNode(getitem1->input(2))); + auto output_idx2 = GetValue(GetValueNode(getitem2->input(2))); + return output_idx1 < output_idx2; +} + +void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + auto fusion_id = buffer_fusion_info.first; + auto fusion_info = buffer_fusion_info.second; + for (const auto &node : fusion_info.anf_nodes) { + if (AnfAlgo::GetOutputTensorNum(node) == 1) { + for (auto use_node : manager->node_users()[node]) { + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) == + fusion_info.anf_nodes.end()) { + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node); + break; + } + } + } else { + int prev_idx = 0; + std::vector tuple_getitem_nodes; + std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), + std::back_inserter(tuple_getitem_nodes), + [](const std::pair &use_node) { return use_node.first; }); + std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); + for (auto getitem : tuple_getitem_nodes) { + MS_EXCEPTION_IF_NULL(getitem); + auto getitem_ptr = getitem->cast(); + auto input2 = getitem_ptr->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) { + auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx)); + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node); + } + prev_idx = output_idx + 1; + for (auto item_use_node : manager->node_users()[getitem]) { + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) == + fusion_info.anf_nodes.end()) { + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem); + break; + } + } + } + } + } + } +} + +void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector &outputs_list, + const AnfNodePtr &fusion_kernel) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (size_t idx = 0; idx < outputs_list.size(); ++idx) { + auto output = outputs_list[idx]; + MS_EXCEPTION_IF_NULL(output); + if (output->isa() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { + auto real_output = AnfAlgo::VisitKernel(output, 0); + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + auto input2 = output_cnode->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + session::AnfWithOutIndex out_pair(real_output.first, output_idx); + if (kernel_graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); + session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); + kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); + } + } else { + session::AnfWithOutIndex out_pair(output, 0); + if (kernel_graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); + session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); + kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); + } + } + } +} +} // namespace + +void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) const { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); + GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); + GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + buffer_fusion_info.second.kernel_build_info = + CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); + } +} + +bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + bool change = false; + std::unordered_map buffer_fusion_infos; + buffer_fusion_infos.clear(); + GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos); + + std::vector fusion_scope_infos; + for (auto &buffer_fusion_info : buffer_fusion_infos) { + mindspore::kernel::FusionScopeInfo fusion_scope_info; + fusion_scope_info.scope_id = buffer_fusion_info.first; + fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list; + fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes; + fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list; + fusion_scope_infos.push_back(fusion_scope_info); +#ifdef DEBUG + DumpFusionScopeInfo(fusion_scope_info); +#endif + } + auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos); + std::vector fusion_ids; + for (auto &buffer_fusion_info : buffer_fusion_infos) { + MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size() + << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size() + << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size(); + fusion_ids.push_back(buffer_fusion_info.first); + } + // Replace fusion op from return to head + std::sort(fusion_ids.begin(), fusion_ids.end()); + for (auto &fusion_id : fusion_ids) { + // Get kernel mod when supporting tbe + if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) { + MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; + continue; + } + change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); + } + MS_LOG(DEBUG) << "End Buffer Fusion"; + return change; +} + +bool UbPatternFusion::ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, + int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, + session::KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; + auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, + buffer_fusion_info.anf_nodes, kernel_graph); + AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get()); + // Set abstract of fusion_op node + std::vector types; + std::vector> shapes; + for (const auto &out_node : buffer_fusion_info.outputs_list) { + for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) { + types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx)); + shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx)); + } + } + if (types.empty() || shapes.empty()) { + MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty"; + return false; + } + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); + AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); + SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion); + ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); + return true; +} + +bool UbPatternFusion::Run(const FuncGraphPtr &graph) { + bool changed = false; + MS_EXCEPTION_IF_NULL(graph); + auto kernel_graph = graph->cast>(); + MS_EXCEPTION_IF_NULL(kernel_graph); + changed = FuseBufferFusionPattern(kernel_graph.get()); + // clear fusion_id attr + for (auto &node : graph->nodes()) { + if (node != nullptr && node->isa()) { + AnfAlgo::EraseNodeAttr(kAttrFusionId, node); + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h new file mode 100644 index 0000000000..69eb0f43d4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ +#include +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class UbPatternFusion : public Pass { + public: + UbPatternFusion() : Pass("TbeBufferFusion") {} + ~UbPatternFusion() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + void GetBufferFusionInfo(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) const; + bool ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, + const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const; + bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc new file mode 100644 index 0000000000..a729cdd0f9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" + +namespace mindspore::opt { + +const BaseRef GetnextMemcpyElimination::DefinePattern() const { + auto prim_memcpy = std::make_shared(kMemCpyAsyncOpName); + VarPtr x = std::make_shared(); + VectorRef memcpy_async({prim_memcpy, x}); + return memcpy_async; +} + +const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + if (graph == nullptr || node == nullptr || equiv == nullptr) { + return nullptr; + } + auto memcpy_cnode = node->cast(); + if (memcpy_cnode == nullptr) { + return nullptr; + } + + // 1. memcpy has attr kAttrLabelForInsertStreamActive + if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, memcpy_cnode)) { + MS_LOG(DEBUG) << "node has no label_for_insert_stream_active attr"; + return nullptr; + } + + // 2. memcpy's output has only one user next_node + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(memcpy_cnode) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "memcpy has no output in manager"; + } + auto next_nodes = manager->node_users()[memcpy_cnode]; + if (next_nodes.size() > 1) { + MS_LOG(DEBUG) << "node's output has more than one users"; + return nullptr; + } + + // 3. next_node is not nop node and it has only one input which is memcpy's output + for (auto &item : next_nodes) { + auto next_node = item.first->cast(); + if (opt::IsNopNode(next_node)) { + return nullptr; + } + if (next_node->inputs().size() != 2) { + MS_LOG(DEBUG) << "next node has more than one input"; + return nullptr; + } + // add attr label_for_insert_stream_active for next_node + AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), next_node); + } + + return memcpy_cnode->input(1); +} +} // namespace mindspore::opt diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h new file mode 100644 index 0000000000..365088b34a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class GetnextMemcpyElimination : public PatternProcessPass { + public: + explicit GetnextMemcpyElimination(bool multigraph = true) + : PatternProcessPass("getnext_memcpy_elimination", multigraph) {} + ~GetnextMemcpyElimination() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc new file mode 100644 index 0000000000..bac9f54ace --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h" +#include +#include +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + if (func_graph == nullptr || node == nullptr) { + return nullptr; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + if (output_num == 0) { + MS_LOG(DEBUG) << "Output number is zero, no need to insert memcpy_async!"; + return node; + } + + // getnext output is tuple and dynamic + std::vector make_tuple_inputs; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + for (size_t output_index = 0; output_index < output_num; ++output_index) { + auto tuple_get_item = CreatTupleGetItemNode(func_graph, node, output_index); + auto new_node = CreateMemcpyAsyncOp(func_graph, tuple_get_item); + if (new_node == nullptr) { + MS_LOG(EXCEPTION) << "Create memcpy_async op failed!"; + } + AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node); + make_tuple_inputs.push_back(new_node); + } + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} + +const BaseRef InsertMemcpyAsyncForGetNext::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + auto prim = std::make_shared(kGetNextOpName); + + return VectorRef({prim, Xs}); +} + +const AnfNodePtr InsertMemcpyAsyncForGetNext::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (func_graph == nullptr || node == nullptr || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + + auto cnode = node->cast(); + if (AnfAlgo::HasNodeAttr(kAttrVisited, cnode)) { + MS_LOG(DEBUG) << "Node op_name[" << kGetNextOpName << "] has visited."; + return nullptr; + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cnode); + + return InsertMemcpyAsyncForGetNextOutputs(func_graph, cnode); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h new file mode 100644 index 0000000000..6fefc32230 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class InsertMemcpyAsyncForGetNext : public PatternProcessPass { + public: + explicit InsertMemcpyAsyncForGetNext(bool multigraph = true) + : PatternProcessPass("insert_memcpy_async_for_getnext", multigraph) {} + ~InsertMemcpyAsyncForGetNext() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc new file mode 100644 index 0000000000..2585006be6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#include +#include +#include +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +namespace { +// insert memcpy for some cnode even if not a Ref cnode +const std::set kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName, + kLambUpdateWithLROpName}; + +bool IsParameterOrValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + return kernel_with_index.first->isa() || kernel_with_index.first->isa(); +} + +void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(hccl_node); + MS_EXCEPTION_IF_NULL(memcpy_async); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(hccl_node); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + // find hccl_node's output which is a control depend + for (const auto &node_index : iter->second) { + AnfNodePtr output = node_index.first; + int output_index = node_index.second; + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { + CNodePtr control_depend = output->cast(); + MS_EXCEPTION_IF_NULL(control_depend); + std::vector new_inputs; + for (size_t i = 0; i < control_depend->size(); ++i) { + if (i == IntToSize(output_index)) { + new_inputs.push_back(memcpy_async); + } else { + new_inputs.push_back(control_depend->input(i)); + } + } + control_depend->set_inputs(new_inputs); + } + } +} +} // namespace + +bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input); + // when input is a parameter or is a value node + if (IsParameterOrValueNode(input)) { + return true; + } + + // when input is a Ref or some special cnodes + if (kernel_query_->IsTbeRef(input) || + kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { + return true; + } + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(input); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + // when input is used by others + if (iter->second.size() > 1) { + return true; + } + return false; +} + +void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(hccl_node); + bool has_insert_memcpy = false; + AnfNodePtr memcpy_async = nullptr; + std::vector new_inputs = {hccl_node->input(0)}; + for (size_t i = 1; i < hccl_node->size(); ++i) { + auto input = hccl_node->input(i); + if (NeedInsertMemcpy(graph, input)) { + memcpy_async = CreateMemcpyAsyncOp(graph, input); + has_insert_memcpy = true; + new_inputs.push_back(memcpy_async); + } else { + new_inputs.push_back(input); + } + } + + if (has_insert_memcpy) { + CNodePtr new_hccl_node = std::make_shared(*hccl_node); + new_hccl_node->set_inputs(new_inputs); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; + (void)manager->Replace(hccl_node, new_hccl_node); + MS_LOG(DEBUG) << "end replace"; + + // transer hccl op's control to the memcpy_async + if (hccl_node->size() == 2) { + TransferControl(new_hccl_node, memcpy_async, graph); + } + } +} + +const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (func_graph == nullptr || node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + if (!AnfAlgo::IsCommunicationOp(node)) { + return nullptr; + } + InsertMemcpyAsync(func_graph, cnode); + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h new file mode 100644 index 0000000000..7bd730a84d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { + public: + explicit InsertMemcpyAsyncForHcclOp(bool multigraph = true) + : PatternProcessPass("insert_memcpy_async_for_hccl_op", multigraph), + kernel_query_(std::make_shared()) {} + ~InsertMemcpyAsyncForHcclOp() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; + bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; + KernelQueryPtr kernel_query_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc new file mode 100644 index 0000000000..be61833fe4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h" +#include +#include +#include +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler//oplib/oplib.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef InsertPadForNMSWithMask::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimNMSWithMask, Xs}); +} + +AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type, + const std::vector &origin_shape) { + MS_EXCEPTION_IF_NULL(func_graph); + std::vector new_pad_inputs; + auto prim = std::make_shared(prim::kPrimPad->name()); + new_pad_inputs.push_back(NewValueNode(prim)); + new_pad_inputs.push_back(input); + CNodePtr pad = func_graph->NewCNode(new_pad_inputs); + MS_EXCEPTION_IF_NULL(pad); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get()); + return pad; +} + +const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + + size_t input_num = AnfAlgo::GetInputTensorNum(node); + if (input_num == 0) { + return nullptr; + } + std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; + for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) { + auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx); + auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx); + auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx); + if (!(origin_shape.size() == 2 && origin_shape[1] == 5)) { + return nullptr; + } + origin_shape[1] = 8; + auto pad = InsertPadToGraph(func_graph, cur_input, origin_type, origin_shape); + MS_EXCEPTION_IF_NULL(pad); + pad->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector>{{0, 0}, {0, 3}}), pad); + new_inputs.push_back(pad); + } + auto kernel_graph = func_graph->cast>(); + CNodePtr new_node = nullptr; + if (kernel_graph == nullptr) { + new_node = std::make_shared(*cnode); + } else { + new_node = kernel_graph->NewCNode(cnode); + } + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_inputs(new_inputs); + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h new file mode 100644 index 0000000000..6aed678ff2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass.h" + +namespace mindspore { +namespace opt { +class InsertPadForNMSWithMask : public PatternProcessPass { + public: + explicit InsertPadForNMSWithMask(bool multigraph = true) + : PatternProcessPass("insert_pad_for_nms_with_mask", multigraph) {} + ~InsertPadForNMSWithMask() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc new file mode 100644 index 0000000000..f508bb2868 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" + +#include +#include +#include +#include + +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace opt { +namespace { +using ConvertFunction = std::function; + +void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode); +const size_t kAxis_H = 2; +const size_t kAxis_W = 3; +const size_t kAxis_6HD_H = 1; +const size_t kAxis_6HD_W = 2; +const std::map kReduceConvertMap = {{kOpFormat_FRAC_Z, ConvertReduceAttrFraczAnd6HD}, + {kOpFormat_C1HWNCoC0, ConvertReduceAttrFraczAnd6HD}}; +void SafeCheckFunction(const CNodePtr &cnode, const std::vector &reduce_axis) { + if (reduce_axis.empty()) { + MS_LOG(EXCEPTION) << "The node " << cnode->DebugString() << "'s reduce axis got a empty vector"; + } + if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) && + AnfAlgo::GetInputTensorNum(cnode) != 1) { + MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString() + << "] is not single input or single output "; + } + for (auto elem : reduce_axis) { + if (elem > 4) { + MS_LOG(INFO) << "reduce axis is larger than 4 dims reduce axis : [" << elem << "]"; + } + } +} + +void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode) { + auto axis = kernel::GetReduceAttrAxis(cnode); + std::vector convert_axis; + SafeCheckFunction(cnode, axis); + auto format = AnfAlgo::GetInputFormat(cnode, 0); + if (format != kOpFormat_FRAC_Z || format != kOpFormat_C1HWNCoC0) { + MS_LOG(EXCEPTION) << "The node [" << cnode->DebugString() << "] format " << format << " is not 5hd"; + } + for (auto elem : axis) { + switch (elem) { + case kAxis_H: + convert_axis.emplace_back(kAxis_6HD_H); + break; + case kAxis_W: + convert_axis.emplace_back(kAxis_6HD_W); + break; + default: + MS_LOG(INFO) << "reduce axis is axis : [" << elem << "]" + << " but the format is not supported this reduce axis"; + } + } + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(convert_axis), cnode); +} +} // namespace + +const BaseRef ChangeAxisOfReduceKernel::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); +} + +const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) { + return nullptr; + } + auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0)); + if (convert_map == kReduceConvertMap.end()) { + return nullptr; + } + convert_map->second(node->cast()); + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h new file mode 100644 index 0000000000..6bf1287ae7 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ChangeAxisOfReduceKernel : public PatternProcessPass { + public: + explicit ChangeAxisOfReduceKernel(bool multigraph = true) + : PatternProcessPass("change_axis_of_reduce_kernel", multigraph) {} + ~ChangeAxisOfReduceKernel() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc new file mode 100644 index 0000000000..7da0027310 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/format_type/check_consistency.h" + +#include +#include +#include + +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace opt { +namespace { +bool CheckFormatForConsistency(const CNodePtr &node, const size_t input_index) { + MS_EXCEPTION_IF_NULL(node); + // get prior node's device output format + string pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(node, input_index); + string selected_input_format = AnfAlgo::GetInputFormat(node, input_index); + if (pre_output_format == selected_input_format) { + return true; + } + auto input_origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, input_index); + if (pre_output_format == kOpFormat_DEFAULT || selected_input_format == kOpFormat_DEFAULT) { + string checking_format = (pre_output_format == kOpFormat_DEFAULT) ? selected_input_format : pre_output_format; + // when input shape size is 1D, default format and NC1HWC0 are compatible + if (input_origin_shape.size() == 1 && checking_format == kOpFormat_NC1HWC0) { + return true; + } + if (kDefaultCompatibleFormat.find(checking_format) != kDefaultCompatibleFormat.end()) { + return true; + } + } + if (input_origin_shape.size() == 0) { + return true; + } + MS_LOG(ERROR) << "Found inconsistent format! input format " << input_index << ": " << pre_output_format + << ", selected input format: " << selected_input_format; + return false; +} + +bool CheckDataTypeForConsistency(const CNodePtr &node, const size_t input_index) { + MS_EXCEPTION_IF_NULL(node); + TypeId input_data_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(node, input_index); + TypeId selected_data_type = AnfAlgo::GetInputDeviceDataType(node, input_index); + if (input_data_type == selected_data_type) { + return true; + } + MS_LOG(ERROR) << "Found inconsistent dtype! input dtype " << input_index << ": " << TypeIdLabel(input_data_type) + << ", selected dtype: " << TypeIdLabel(selected_data_type); + return false; +} +} // namespace + +const BaseRef CheckConsistency::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); +} + +const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { + if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + + std::vector todos = {node}; + if (AnfAlgo::IsGraphKernel(node)) { + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + kernel::GetValidKernelNodes(sub_graph, &todos); + } + + for (auto &t : todos) { + CNodePtr cnode = t->cast(); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { + if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { + MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "[" + << cnode->DebugString() << "]"; + } + } + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h new file mode 100644 index 0000000000..bf956895de --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class CheckConsistency : public PatternProcessPass { + public: + explicit CheckConsistency(bool multigraph = true) : PatternProcessPass("check_consistency", multigraph) {} + ~CheckConsistency() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc new file mode 100644 index 0000000000..48948dca06 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/kernel_compiler/kernel_query.h" +namespace mindspore { +namespace opt { +const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); +} + +const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &, + const mindspore::AnfNodePtr &node, + const mindspore::EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto node_name = AnfAlgo::GetCNodeName(node); + if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) { + return nullptr; + } + auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); + if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) { + return nullptr; + } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) { + auto builder = std::make_shared(kernel_builder_info); + builder->SetKernelType(AICPU_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), node); + } else { + MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" + << node->DebugString() << "]"; + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h new file mode 100644 index 0000000000..e534a851ad --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" +#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H +#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H +namespace mindspore { +namespace opt { +class ConvertUnSupportNodeToAICPU : public PatternProcessPass { + public: + explicit ConvertUnSupportNodeToAICPU(bool multigraph = true) + : PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph), + supported_checker_(std::make_shared()) {} + ~ConvertUnSupportNodeToAICPU() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + SupportedCheckerPtr supported_checker_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc new file mode 100644 index 0000000000..4375a08031 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc @@ -0,0 +1,226 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { + session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); + AnfNodePtr cur_node = kernel_with_index.first; + size_t cur_out_index = kernel_with_index.second; + MS_EXCEPTION_IF_NULL(cur_node); + if (cur_node->isa()) { + auto cnode = cur_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::string op_name = AnfAlgo::GetCNodeName(cnode); + auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); + // deal ref op + if (op_info != nullptr && op_info->is_ref()) { + auto ref_infos = op_info->ref_infos(); + if (ref_infos.count(cur_out_index) != 0) { + auto in_index = ref_infos.at(cur_out_index); + if (in_index > cnode->inputs().size()) { + MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() + << ", ref info is " << cur_out_index; + } + AnfNodePtr next_node = cnode->input(in_index + 1); + return FindRefOriginNode(next_node); + } + } + + // deal special (trans,cast,reshape) op + if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() || + op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) { + AnfNodePtr next_node = cnode->input(1); + return FindRefOriginNode(next_node); + } + } + + return kernel_with_index; +} + +void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, + const AnfNodePtr &final_node, size_t final_index, + const session::KernelWithIndex &origin_pair) { + // record the ref_pair + auto kernel_graph = func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + // if the final node is get item, means no trans or cast op is added, the final node is itself + // so add the pair for itself, because the get item will removed later + auto final_ref = (final_node == get_item ? cnode : final_node); + session::AnfWithOutIndex final_pair = std::make_pair(final_ref, final_index); + if (kernel_graph->IsInRefOutputMap(final_pair)) { + MS_LOG(EXCEPTION) << "ref_pair is already in ref map, node is " << final_ref->DebugString() << ", index is " + << final_index; + } + MS_LOG(DEBUG) << "Add Ref pair, final {node ptr " << final_pair.first.get() << " , info is " + << final_pair.first->DebugString() << " , index is " << final_pair.second << "}, origin {node ptr " + << origin_pair.first.get() << ", info is " << origin_pair.first->DebugString() << " : index " + << origin_pair.second << "}"; + kernel_graph->AddRefCorrespondPairs(final_pair, origin_pair); +} + +// if get_item is nullptr, the additional node will link to the cnode +// else the additional node will link to the get_item node (the get_item node link to cnode) +AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, + size_t input_index, const AnfNodePtr &get_item) { + AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); + size_t final_index = output_index; + AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); + session::KernelWithIndex origin_pair; + origin_pair = FindRefOriginNode(input_node); + MS_EXCEPTION_IF_NULL(origin_pair.first); + if (!origin_pair.first->isa()) { + MS_LOG(WARNING) << "ref op origin node is not parameter"; + } + MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is " + << origin_pair.first->DebugString() << ", index is " << origin_pair.second; + auto origin_format = AnfAlgo::GetOutputFormat(origin_pair.first, origin_pair.second); + auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second); + auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index); + auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index); + auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index); + // insert trans + if (origin_format != cur_format && cur_shape.size() > 1) { + auto kernel_select = std::make_shared(); + final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); + RefreshKernelBuildInfo(cur_format, origin_format, final_node); + final_index = 0; + MS_EXCEPTION_IF_NULL(final_node); + MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); + } + // insert cast + if (origin_type != cur_type) { + final_node = + AddCastOpNodeToGraph(func_graph, final_node, origin_format, cur_type, origin_type, cur_shape, cur_type); + MS_EXCEPTION_IF_NULL(final_node); + final_node->set_scope(cnode->scope()); + final_index = 0; + MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString(); + } + // add ref pair + AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair); + // insert depend + if (origin_format != cur_format || origin_type != cur_type) { + std::vector depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node}; + final_node = func_graph->NewCNode(depend_nodes); + MS_LOG(INFO) << "DealRefTransAndCast add denpend, op debug info is " << final_node->DebugString(); + } + + return final_node; +} +AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + auto ref_infos = op_info->ref_infos(); + std::vector make_tuple_inputs; + AbstractBasePtrList abstract_list; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); + // deal with ref output + if (ref_infos.count(output_index) != 0) { + auto input_index = ref_infos.at(output_index); + final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node); + } + MS_EXCEPTION_IF_NULL(final_node); + abstract_list.push_back(final_node->abstract()); + make_tuple_inputs.push_back(final_node); + } + MS_EXCEPTION_IF_NULL(func_graph); + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + make_tuple->set_abstract(std::make_shared(abstract_list)); + return make_tuple; +} + +AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(op_info); + auto ref_infos = op_info->ref_infos(); + for (const auto &ref_info : ref_infos) { + if (ref_info.second > cnode->inputs().size()) { + MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() << ", ref info is " + << ref_info.second; + } + return AddAdditionalToRefOutput(func_graph, cnode, ref_info.first, ref_info.second, nullptr); + } + return nullptr; +} +} // namespace + +const BaseRef DealRefTransAndCast::DefinePattern() const { + VarPtr V = std::make_shared(UnVisited); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { + auto input_size = AnfAlgo::GetInputTensorNum(cnode); + for (size_t i = 0; i < input_size; ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i); + auto input_node = input_node_with_index.first; + MS_EXCEPTION_IF_NULL(input_node); + MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope(); + AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index); + } + } +} + +const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::IsRealCNodeKernel(cnode)) { + return nullptr; + } + + DealBroadCastAsRef(graph, cnode); + + auto op_name = AnfAlgo::GetCNodeName(cnode); + auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); + if (op_info == nullptr || !op_info->is_ref()) { + return nullptr; + } + if (op_info->is_ref()) { + auto type = cnode->Type(); + MS_EXCEPTION_IF_NULL(type); + if (!type->isa()) { + return DealRefSigleOutput(graph, cnode, op_info); + } else { + return DealRefForMultipleOutput(graph, cnode, op_info); + } + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h new file mode 100644 index 0000000000..cb3b13dc49 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class DealRefTransAndCast : public PatternProcessPass { + public: + explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} + ~DealRefTransAndCast() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc new file mode 100644 index 0000000000..c3f7900645 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc @@ -0,0 +1,195 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/format_type/insert_cast.h" + +#include +#include +#include +#include + +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::vector &need_insert_cast) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + std::vector make_tuple_inputs; + AbstractBasePtrList abstract_list; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) { + AnfNodePtr replace_node = nullptr; + const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); + const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); + auto idx = NewValueNode(SizeToInt(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(output_idx); + idx->set_abstract(std::make_shared(imm)); + auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); + AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get()); + if (need_insert_cast[output_idx]) { + const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); + TypeId origin_type(kTypeUnknown); + if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); + } + origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; + const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); + if (origin_type != device_type) { + replace_node = + AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type); + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); + } else { + replace_node = getitem; + } + } else { + replace_node = getitem; + } + abstract_list.push_back(replace_node->abstract()); + make_tuple_inputs.push_back(replace_node); + } + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + make_tuple->set_abstract(std::make_shared(abstract_list)); + return make_tuple; +} // namespace + +AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + const std::vector &need_insert_cast) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { + return cnode; + } + MS_EXCEPTION_IF_NULL(cnode->Type()); + // Single output + if (!cnode->Type()->isa()) { + if (!need_insert_cast[0]) { + return cnode; + } + + const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); + const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0); + TypeId origin_type(kTypeUnknown); + if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); + } + origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; + const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); + AnfNodePtr replace_node = cnode; + if (origin_type != device_type) { + replace_node = + AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type); + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); + } + return replace_node; + } + // Multiple output + return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast); +} + +AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + // insert cast for ops in graph kernel. + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + auto mng = sub_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + std::vector todo; + std::vector> graph_rets; + kernel::GetValidKernelNodes(sub_graph, &todo); + kernel::GetGraphRealOutput(sub_graph, &graph_rets); + for (auto &t : todo) { + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t); + // process input + CNodePtr t_cnode = t->cast(); + MS_EXCEPTION_IF_NULL(t_cnode); + auto t_new_node = InsertCastForInput(sub_graph, t_cnode); + AnfNodePtr t_new_node_1 = nullptr; + std::vector need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true); + // process output + auto iter = std::find_if(graph_rets.begin(), graph_rets.end(), + [&t](const std::pair &ret) { return ret.first == t; }); + if (iter != graph_rets.end()) { + auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t); + auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second); + auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin()); + if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) { + need_insert_cast[iter->second] = false; + } else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) { + need_insert_cast[iter->second] = false; + } + t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); + } else { + t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); + } + + if (t_new_node_1 != nullptr && t_new_node_1 != t) { + (void)mng->Replace(t, t_new_node_1); + } + } + + // insert cast for graph kernel. + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + // process input + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto new_node = InsertCastForInput(func_graph, cnode); + // process output + return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); +} +} // namespace + +const BaseRef InsertCast::DefinePattern() const { + VarPtr V = std::make_shared(UnVisited); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { + return nullptr; + } + + if (AnfAlgo::IsGraphKernel(node)) { + return ProcessGraphKernelOp(func_graph, node); + } + // insert cast for single op. + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + // process input + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto new_node = InsertCastForInput(func_graph, cnode); + // process output + return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h new file mode 100644 index 0000000000..19c282aac9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ +#include + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "ir/anf.h" + +namespace mindspore { +namespace opt { +class InsertCast : public PatternProcessPass { + public: + explicit InsertCast(bool multigraph = true) : PatternProcessPass("insert_cast", multigraph) {} + ~InsertCast() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc new file mode 100644 index 0000000000..a22a1faa5f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#include +#include +#include "utils/utils.h" +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +const BaseRef InsertTransOp::DefinePattern() const { + std::shared_ptr V = std::make_shared(UnVisited); + std::shared_ptr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +bool IsGraphOutput(const AnfNodePtr &node, const std::vector &outputs) { + auto iter = std::find(outputs.begin(), outputs.end(), node); + if (iter != outputs.end()) { + return true; + } + + return false; +} + +const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + AnfNodePtr front_node; + auto kernel_graph = func_graph->cast>(); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { + front_node = kernel_graph->GetFrontNodeByInternalOutput(node); + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + MS_LOG(DEBUG) << "====process op: " << node->DebugString(); + AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { + if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { + return new_node; + } + } + auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_); + if (kernel_graph != nullptr && front_node != nullptr) { + auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node); + kernel_graph->ReplaceInternalOutput(old_node, final_node); + } + return final_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h new file mode 100644 index 0000000000..0b21375327 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class InsertTransOp : public PatternProcessPass { + public: + explicit InsertTransOp(bool multigraph = true) + : PatternProcessPass("insert_trans_op", multigraph), kernel_select_(std::make_shared()) {} + ~InsertTransOp() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc new file mode 100644 index 0000000000..d0b92b250d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h" +#include +#include "utils/utils.h" +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/oplib/oplib.h" + +namespace mindspore { +namespace opt { +const BaseRef RunOpInsertTransData::DefinePattern() const { + std::shared_ptr V = std::make_shared(UnVisited); + MS_EXCEPTION_IF_NULL(V); + std::shared_ptr Xs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + return VectorRef({V, Xs}); +} + +const AnfNodePtr RunOpInsertTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + MS_LOG(DEBUG) << "====process op: " << node->DebugString(); + return InsertTransOpForInput(func_graph, node, kernel_select_); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h new file mode 100644 index 0000000000..82ff5f2b9a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class RunOpInsertTransData : public PatternProcessPass { + public: + explicit RunOpInsertTransData(bool multigraph = true) + : PatternProcessPass("insert_transdata_for_runop", multigraph), + kernel_select_(std::make_shared()) {} + ~RunOpInsertTransData() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc new file mode 100644 index 0000000000..88e9fa77b8 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc @@ -0,0 +1,282 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h" + +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t kCastInputNum = 2; +const size_t kTupleGetitemInputNum = 3; +bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx, + const std::shared_ptr &candidate_kernel_info) { + if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == nullptr) { + return false; + } + + // checkout inputs' fmt and dtype except index equal change_idx + for (size_t i = 0; i < candidate_kernel_info->GetInputNum(); i++) { + if (i == change_idx) { + if (candidate_kernel_info->GetInputDeviceType(i) != dst_type || + candidate_kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(node, i)) { + return false; + } + } else if (candidate_kernel_info->GetInputDeviceType(i) != AnfAlgo::GetInputDeviceDataType(node, i) || + candidate_kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(node, i)) { + return false; + } + } + + // check outputs's fmt and dtype + for (size_t i = 0; i < candidate_kernel_info->GetOutputNum(); i++) { + if (candidate_kernel_info->GetOutputDeviceType(i) != AnfAlgo::GetOutputDeviceDataType(node, i) || + candidate_kernel_info->GetOutputFormat(i) != AnfAlgo::GetOutputFormat(node, i)) { + return false; + } + } + return true; +} + +bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node, + size_t *cast_index) { + auto output_node_list = GetRealNodeUsedList(graph, node); + MS_EXCEPTION_IF_NULL(output_node_list); + if (output_node_list->size() != 1) { + return false; + } + auto node_pair = output_node_list->at(0); + *next_node = node_pair.first; + *cast_index = node_pair.second - 1; + return true; +} + +bool CheckInputs(const CNodePtr &node, const std::shared_ptr &kernel_info) { + MS_EXCEPTION_IF_NULL(kernel_info); + if (AnfAlgo::GetInputTensorNum(node) != kernel_info->GetInputNum()) { + return false; + } + + for (size_t index = 0; index < kernel_info->GetInputNum(); ++index) { + if (AnfAlgo::GetInputFormat(node, index) != kernel_info->GetInputFormat(index) || + AnfAlgo::GetInputDeviceDataType(node, index) != kernel_info->GetInputDeviceType(index)) { + return false; + } + } + return true; +} + +bool CheckOtherOutputs(const CNodePtr &node, const std::shared_ptr &kernel_info, + const size_t idx) { + MS_EXCEPTION_IF_NULL(kernel_info); + if (AnfAlgo::GetOutputTensorNum(node) != kernel_info->GetOutputNum()) { + return false; + } + for (size_t index = 0; index < kernel_info->GetOutputNum(); ++index) { + if (idx == index) { + continue; + } + if (AnfAlgo::GetOutputFormat(node, index) != kernel_info->GetOutputFormat(index) || + AnfAlgo::GetOutputDeviceDataType(node, index) != kernel_info->GetOutputDeviceType(index)) { + return false; + } + } + return true; +} + +bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr &kernel_info, size_t index) { + if (kernel_info == nullptr) { + return false; + } + + if (AnfAlgo::GetOutputDeviceDataType(node, 0) != kernel_info->GetOutputDeviceType(index)) { + return false; + } + if (AnfAlgo::GetOutputInferShape(node, 0).size() == 4 && AnfAlgo::GetOutputFormat(node, 0) == kOpFormat_NCHW && + kernel_info->GetOutputFormat(index) == kOpFormat_DEFAULT) { + return true; + } + return AnfAlgo::GetOutputFormat(node, 0) == kernel_info->GetOutputFormat(index); +} + +void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) { + using Shape = std::vector; + auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0); + auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0); + std::vector shapes; + std::vector types; + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { + if (cast_index == index) { + shapes.emplace_back(cast_shape); + types.emplace_back(cast_dtype); + continue; + } + shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index)); + types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index)); + } + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get()); +} + +AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(kernel_query); + AnfNodePtr next_node = nullptr; + size_t cast_index = 0; + if (!GetNextNodeAndCastIndex(graph, node, &next_node, &cast_index)) { + return nullptr; + } + MS_EXCEPTION_IF_NULL(next_node); + if (!next_node->isa() || !AnfAlgo::IsRealKernel(next_node)) { + return nullptr; + } + auto next_cnode = next_node->cast(); + if (AnfAlgo::IsGraphKernel(next_node)) { + return nullptr; + } + auto next_op_name = AnfAlgo::GetCNodeName(next_node); + std::vector> kernel_info_list; + kernel_query->Query(next_cnode, &kernel_info_list); + + auto dst_type_id = AnfAlgo::GetInputDeviceDataType(node, 0); + auto alternative_kernel_info = std::find_if( + kernel_info_list.begin(), kernel_info_list.end(), + [&next_cnode, &dst_type_id, &cast_index](const std::shared_ptr &candidate_kernel_info) { + return AlternativeKernelInfoForInput(next_cnode, dst_type_id, cast_index, candidate_kernel_info); + }); + if (alternative_kernel_info == kernel_info_list.end()) { + return nullptr; + } + auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node); + MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString() + << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" + << (*alternative_kernel_info)->ToString(); + AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); + ChangeNodeInferInfo(next_cnode, node, cast_index); + if (node->inputs().size() < kCastInputNum) { + MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:"; + } + return node->input(1); +} + +bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_output, size_t *output_idx) { + MS_EXCEPTION_IF_NULL(x_node); + if (x_node->isa()) { + auto x_cnode = x_node->cast(); + *prior_op = x_cnode; + // when x_node is tuple_getitem + if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) { + if (x_cnode->inputs().size() < kTupleGetitemInputNum) { + MS_LOG(EXCEPTION) << "tuple getitem node has wrong input num" << x_cnode->inputs().size(); + } + MS_EXCEPTION_IF_NULL(output_idx); + AnfNodePtr input1 = x_cnode->input(1); + MS_EXCEPTION_IF_NULL(input1); + if (!input1->isa()) { + return false; + } + *prior_op = input1->cast(); + MS_EXCEPTION_IF_NULL(*prior_op); + AnfNodePtr input2 = x_cnode->input(2); + MS_EXCEPTION_IF_NULL(input2); + auto value_ptr = input2->cast(); + MS_EXCEPTION_IF_NULL(value_ptr); + *output_idx = IntToSize(GetValue(value_ptr->value())); + *single_output = false; + } + return AnfAlgo::IsRealKernel(*prior_op); + } + return false; +} + +AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) { + MS_EXCEPTION_IF_NULL(cur_node); + MS_EXCEPTION_IF_NULL(kernel_query); + if (cur_node->inputs().size() < kCastInputNum) { + MS_LOG(EXCEPTION) << "op[Cast] has wrong input num:"; + } + AnfNodePtr x_node = cur_node->input(1); + if (IsUsedByOthers(graph, x_node)) { + return nullptr; + } + + CNodePtr prior_op = nullptr; + bool single_output = true; + size_t output_idx = 0; + if (!GetPriorOp(x_node, &prior_op, &single_output, &output_idx)) { + return nullptr; + } + MS_EXCEPTION_IF_NULL(prior_op); + if (AnfAlgo::IsGraphKernel(prior_op)) { + return nullptr; + } + + std::vector> kernel_info_list; + kernel_query->Query(prior_op, &kernel_info_list); + auto kernel_info_it = std::find_if( + kernel_info_list.begin(), kernel_info_list.end(), + [&prior_op, &cur_node, &output_idx](const std::shared_ptr &item_kernel_info) { + return CheckInputs(prior_op, item_kernel_info) && CheckOtherOutputs(prior_op, item_kernel_info, output_idx) && + CheckIndexOutput(cur_node, item_kernel_info, output_idx); + }); + if (kernel_info_it == kernel_info_list.end()) { + return nullptr; + } + auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op); + MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString() + << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" + << (*kernel_info_it)->ToString(); + AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get()); + ChangeNodeInferInfo(prior_op, cur_node, output_idx); + if (!single_output) { + MS_EXCEPTION_IF_NULL(x_node); + ChangeNodeInferInfo(x_node->cast(), cur_node, 0); + } + auto prior_name = AnfAlgo::GetCNodeName(prior_op); + if (prior_name == kFive2FourOpName) { + AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op); + } else if (prior_name == kFour2FiveOpName) { + AnfAlgo::CopyNodeAttr("dst_type", cur_node, prior_op); + } + return single_output ? prior_op : x_node; +} +} // namespace + +const BaseRef MergeCastToOp::DefinePattern() const { + VarPtr X = std::make_shared(); + return VectorRef({prim::kPrimCast, X}); +} + +const AnfNodePtr MergeCastToOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + auto new_node = MergeCastToNextOp(graph, cnode, kernel_query_); + if (new_node == nullptr) { + new_node = MergeCastToPriorOp(graph, cnode, kernel_query_); + } + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h new file mode 100644 index 0000000000..d0e467b7a3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class MergeCastToOp : public PatternProcessPass { + public: + explicit MergeCastToOp(bool multigraph = true) + : PatternProcessPass("merge_cast_to_op", multigraph), kernel_query_(std::make_shared()) {} + ~MergeCastToOp() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + KernelQueryPtr kernel_query_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.cc new file mode 100644 index 0000000000..adca536f04 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h" +#include +#include +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); + auto input_format = AnfAlgo::GetInputFormat(cnode, 0); + if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) { + return nullptr; + } + if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { + return nullptr; + } + + AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode); + return cnode; +} + +AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) { + auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0); + if (input_shape.size() != 5) { + return nullptr; + } + if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) { + return nullptr; + } + + auto multiples = AnfAlgo::GetNodeAttr>(cnode, kAttrMultiples); + if (multiples.size() == 4 && multiples[1] == 1) { + multiples.push_back(1); + AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode); + } + + return cnode; +} + +AnfNodePtr ModifyAttrs(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto op_name = AnfAlgo::GetCNodeName(cnode); + if (op_name == prim::kPrimTile->name()) { + return ModifyTileOpAttrs(cnode); + } else if (op_name == prim::kPrimReduceSum->name()) { + // kPrimReduceMean + // kPrimReduceSum + // kPrimReduceAll + // kPrimReduceMax + // kPrimReduceMin + return ModifyReduceOpsAttrs(cnode); + } + return nullptr; +} +} // namespace + +const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa() || !AnfAlgo::IsGraphKernel(node)) { + return nullptr; + } + MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node); + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + auto manager = fg->manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector todos; + kernel::GetValidKernelNodes(fg, &todos); + for (auto &t : todos) { + auto new_node = ModifyAttrs(t->cast()); + if (new_node != nullptr && new_node != t) { + (void)manager->Replace(t, new_node); + } + } + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h new file mode 100644 index 0000000000..f5608db05a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ModifyOpAttrs : public PatternProcessPass { + public: + explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {} + ~ModifyOpAttrs() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc new file mode 100644 index 0000000000..91b9326cc1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" + +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/common_utils.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); +} + +const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kPynativeMode) { + return RectifyKernelInfoInPynativeProcess(node); + } + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { + return nullptr; + } + std::vector do_mask_node_list; + auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode); + MS_EXCEPTION_IF_NULL(gen_mask_output_nodes); + for (const auto &output_node : *gen_mask_output_nodes) { + if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) { + MS_EXCEPTION_IF_NULL(output_node.first); + auto output_cnode = output_node.first->cast(); + do_mask_node_list.push_back(output_cnode); + } + } + std::vector input_shape; + for (const auto &output_node : do_mask_node_list) { + if (input_shape.empty()) { + input_shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); + continue; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); + if (!kernel::IsSameShape(shape, input_shape)) { + MS_LOG(EXCEPTION) << "The DropOutGenMask connected with same genmask's shape must be equal!" + << " GenMask " << node->DebugString(); + } + } + RectifyKernelInfo(do_mask_node_list, graph); + return nullptr; +} + +void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_mask_node_list, + const FuncGraphPtr &graph) const { + std::map format_counter; + std::string special_format; + std::string convert_format; + for (const auto &do_mask : do_mask_node_list) { + auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0); + if (special_format.empty() && kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end()) { + special_format = do_mask_data_format; + } + if (format_counter.find(do_mask_data_format) == format_counter.end()) { + format_counter[do_mask_data_format] = 1; + } else { + format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; + } + } + if (format_counter.size() == 1) { + return; + } + if (convert_format.empty()) { + convert_format = GetConvertFormat(format_counter); + } + RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph); +} + +std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map &format_counter) const { + std::string convert_format = kOpFormat_DEFAULT; + size_t counter = 0; + if (format_counter.size() > 2) { + return kOpFormat_DEFAULT; + } + if (format_counter.size() == 2 && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) { + return kOpFormat_DEFAULT; + } + for (const auto &iter : format_counter) { + if (counter < iter.second) { + convert_format = iter.first; + counter = iter.second; + } else if (counter == iter.second && kHWSpecialFormatSet.find(iter.first) != kHWSpecialFormatSet.end()) { + convert_format = iter.first; + } + } + return convert_format; +} + +void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, + const std::string &format, + const FuncGraphPtr &graph) const { + for (const auto &do_mask : do_mask_node_list) { + if (AnfAlgo::GetInputFormat(do_mask, 0) != format) { + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); + builder->SetInputFormat(format, 0); + builder->SetOutputFormat(format, 0); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); + ReSelecChildNodeKernelInfo(do_mask, graph); + } + } +} + +AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode == nullptr) { + return nullptr; + } + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { + return nullptr; + } + auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); + if (do_mask_input_format != kOpFormat_DEFAULT) { + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); + builder->SetInputFormat(kOpFormat_DEFAULT, 0); + builder->SetOutputFormat(kOpFormat_DEFAULT, 0); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + } + return nullptr; +} + +void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const { + MS_EXCEPTION_IF_NULL(cnode); + auto output_node_list = GetRealNodeUsedList(graph, cnode); + MS_EXCEPTION_IF_NULL(output_node_list); + for (const auto &out_node_info : *output_node_list) { + MS_EXCEPTION_IF_NULL(out_node_info.first); + auto out_node = out_node_info.first->cast(); + if (AnfAlgo::IsRealKernel(out_node_info.first)) { + auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node); + kernel_selecter->SelectKernel(out_node); + auto new_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node); + MS_EXCEPTION_IF_NULL(new_build_info); + MS_EXCEPTION_IF_NULL(ori_build_info); + if ((*new_build_info) != (*ori_build_info)) { + ReSelecChildNodeKernelInfo(out_node, graph); + } + } else if (AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() || + AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) { + ReSelecChildNodeKernelInfo(out_node, graph); + } else { + MS_LOG(INFO) << "Reselected the node " << cnode->DebugString() << " failed"; + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h new file mode 100644 index 0000000000..cc9333a013 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H +#include +#include +#include +#include + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" +namespace mindspore { +namespace opt { +class RectifyDoMaskKernelInfo : public PatternProcessPass { + public: + explicit RectifyDoMaskKernelInfo(bool multigraph = true) + : PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared()) {} + ~RectifyDoMaskKernelInfo() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + void RectifyKernelInfo(const std::vector &do_mask_node_list, const FuncGraphPtr &graph) const; + AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const; + std::string GetConvertFormat(const std::map &format_counter) const; + void RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, const std::string &format, + const FuncGraphPtr &graph) const; + void ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const; + KernelSelectPtr kernel_selecter; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.cc new file mode 100644 index 0000000000..09992005a4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h" +#include +#include +#include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto op_name = AnfAlgo::GetCNodeName(cnode); + if (op_name != prim::kPrimReshape->name()) { + return nullptr; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); + auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0); + if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) { + return nullptr; + } + + return cnode->input(1); +} +} // namespace + +const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa() || !AnfAlgo::IsGraphKernel(node)) { + return nullptr; + } + MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node); + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + auto manager = fg->manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector todos; + kernel::GetValidKernelNodes(fg, &todos); + for (auto &t : todos) { + auto new_node = RemoveReshapeOp(t->cast()); + if (new_node != nullptr && new_node != t) { + (void)manager->Replace(t, new_node); + } + } + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h new file mode 100644 index 0000000000..135f11f52c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class RemoveNoUseReshapeOp : public PatternProcessPass { + public: + explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {} + ~RemoveNoUseReshapeOp() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc new file mode 100644 index 0000000000..a3fd704bc5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/addn_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode, size_t begin_index, + size_t offset) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(origin_addn_cnode); + std::vector new_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; + for (size_t i = begin_index; i < begin_index + offset; ++i) { + new_addn_inputs.push_back(origin_addn_cnode->input(i)); + } + CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs); + MS_EXCEPTION_IF_NULL(new_addn); + new_addn->set_scope(origin_addn_cnode->scope()); + new_addn->set_abstract(origin_addn_cnode->abstract()); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn); + return new_addn; +} +} // namespace + +const BaseRef AddnFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimAddN, Xs}); +} + +const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // The real input begins with index 1. + size_t origin_input_size = cnode->inputs().size() - 1; + if (origin_input_size <= inputs_divisor_) { + return nullptr; + } + CNodePtr new_cnode = cnode; + while (origin_input_size > inputs_divisor_) { + MS_EXCEPTION_IF_NULL(new_cnode); + std::vector base_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; + size_t cur_input_index = 1; + // Divide the inputs of addn by inputs_divisor_. + while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { + base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); + cur_input_index += inputs_divisor_; + } + for (size_t i = cur_input_index; i <= origin_input_size; i++) { + base_addn_inputs.push_back(new_cnode->input(i)); + } + CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); + MS_EXCEPTION_IF_NULL(base_addn); + base_addn->set_scope(new_cnode->scope()); + base_addn->set_abstract(new_cnode->abstract()); + AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); + std::vector dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn); + new_cnode = base_addn; + origin_input_size = base_addn->inputs().size() - 1; + } + + return new_cnode; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h new file mode 100644 index 0000000000..e04cdfdf7b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr size_t kAddnInputsDivisor = 63; +class AddnFission : public PatternProcessPass { + public: + explicit AddnFission(bool multigraph = true) + : PatternProcessPass("addn_fission", multigraph), inputs_divisor_(kAddnInputsDivisor) {} + ~AddnFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + size_t inputs_divisor_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc new file mode 100644 index 0000000000..f0edefd5f5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h" +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const std::vector kOutputIndex{0, 3, 4, 5}; +constexpr size_t kBatchNormRealOutputNum = 3; +constexpr size_t kBatchNormRealInputNum = 3; + +bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector *bn_outputs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(bn) == manager->node_users().end()) { + return false; + } + size_t output_num = 0; + for (const auto &node_index : manager->node_users()[bn]) { + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { + continue; + } + auto tuple_getiterm_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); + auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) { + return false; + } + bn_outputs->push_back(output); + output_num++; + } + return output_num == kBatchNormRealOutputNum; +} + +AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn); + auto bn_cnode = bn->cast(); + MS_EXCEPTION_IF_NULL(bn_cnode); + if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { + MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " + << kBatchNormRealInputNum + 1; + } + std::vector bn_training_reduce_inputs = { + NewValueNode(std::make_shared(kBNTrainingReduceOpName)), bn_cnode->input(1)}; + auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); + MS_EXCEPTION_IF_NULL(bn_training_reduce); + auto bn_input1 = bn_cnode->input(2); + MS_EXCEPTION_IF_NULL(bn_input1); + auto bn_input2 = bn_cnode->input(3); + MS_EXCEPTION_IF_NULL(bn_input2); + AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input2->abstract()}; + auto abstract_tuple = std::make_shared(abstract_list); + bn_training_reduce->set_abstract(abstract_tuple); + bn_training_reduce->set_scope(bn->scope()); + AnfAlgo::CopyNodeAttrs(bn, bn_training_reduce); + return bn_training_reduce; +} + +AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, + const std::vector &bn_training_reduce_outputs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn); + auto bn_cnode = bn->cast(); + MS_EXCEPTION_IF_NULL(bn_cnode); + if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { + MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " + << kBatchNormRealInputNum + 1; + } + if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { + MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum + << ", but it is " << bn_training_reduce_outputs.size(); + } + std::vector bn_training_update_v2_inputs = { + NewValueNode(std::make_shared(kBNTrainingUpdateV2OpName)), + bn_cnode->input(1), + bn_training_reduce_outputs[0], + bn_training_reduce_outputs[1], + bn_cnode->input(2), + bn_cnode->input(3)}; + auto bn_training_update_v2 = func_graph->NewCNode(bn_training_update_v2_inputs); + MS_EXCEPTION_IF_NULL(bn_training_update_v2); + + auto bn_abstract_tuple = dyn_cast(bn->abstract()); + MS_EXCEPTION_IF_NULL(bn_abstract_tuple); + if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { + MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " + << bn_abstract_tuple->elements().size(); + } + std::vector abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3], + bn_abstract_tuple->elements()[4]}; + auto abstract_tuple = std::make_shared(abstract_list); + bn_training_update_v2->set_abstract(abstract_tuple); + bn_training_update_v2->set_scope(bn->scope()); + AnfAlgo::CopyNodeAttrs(bn, bn_training_update_v2); + return bn_training_update_v2; +} +} // namespace + +const BaseRef BatchNormBertFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimBatchNorm, Xs}); +} + +const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + std::vector bn_outputs; + if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) { + MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed"; + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() != kBatchNormRealInputNum + 1) { + MS_LOG(INFO) << "The input size of BatchNorm should be " << kBatchNormRealInputNum + << ". The node should not be changed"; + return nullptr; + } + AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); + std::vector bn_training_reduce_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, + &bn_training_reduce_outputs); + + AnfNodePtr bn_training_update_v2 = CreateBNTrainingUpdateV2(func_graph, node, bn_training_reduce_outputs); + std::vector bn_training_update_v2_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v2, kBNTrainingUpdateV2OutputNum, + &bn_training_update_v2_outputs); + if (bn_training_update_v2_outputs.size() != kBNTrainingUpdateV2OutputNum) { + MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingUpdateV2OutputNum + << ", but it is " << bn_training_update_v2_outputs.size(); + } + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem); + size_t output_index = 0; + for (const auto &output : bn_outputs) { + (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); + output_index++; + } + // Return the new node for control depends. + return bn_training_update_v2; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h new file mode 100644 index 0000000000..23f0e56035 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormBertFission : public PatternProcessPass { + public: + explicit BatchNormBertFission(bool multigraph = true) : PatternProcessPass("batch_norm_bert_fission", multigraph) {} + ~BatchNormBertFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc new file mode 100644 index 0000000000..97c67e4441 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h" +#include +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kBatchNormGradInferOutputNum = 3; +bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(node) == manager->node_users().end()) { + MS_LOG(DEBUG) << "The node " << node->DebugString() << " should have some outputs"; + return false; + } + for (const auto &node_index : manager->node_users()[node]) { + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { + continue; + } + auto tuple_getiterm_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); + auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (index == kBatchNormGradInferOutputNum || index == kBatchNormGradInferOutputNum + 1) { + MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change"; + return false; + } + } + return true; +} +} // namespace + +AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn_grad); + MS_EXCEPTION_IF_NULL(equiv); + // Set inputs + auto iter_input0 = (*equiv).find(input0_var_); + if (iter_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; + } + auto iter_input2 = (*equiv).find(input2_var_); + if (iter_input2 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input2 var after matched."; + } + auto iter_input4 = (*equiv).find(input4_var_); + if (iter_input4 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; + } + std::vector bn_infer_grad_inputs = { + NewValueNode(std::make_shared(kBNInferGradOpName)), utils::cast(iter_input0->second), + utils::cast(iter_input2->second), utils::cast(iter_input4->second)}; + auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs); + MS_EXCEPTION_IF_NULL(bn_infer_grad); + // Set abstract, the output of new node is taking the place of the 0th output of bn_grad. + auto bn_grad_abstract_tuple = dyn_cast(bn_grad->abstract()); + MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); + if (bn_grad_abstract_tuple->elements().empty()) { + MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be empty"; + } + bn_infer_grad->set_abstract(bn_grad_abstract_tuple->elements()[0]); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_infer_grad); + bn_infer_grad->set_scope(bn_grad->scope()); + return bn_infer_grad; +} + +AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, + const AnfNodePtr &bn_grad, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn_grad); + MS_EXCEPTION_IF_NULL(equiv); + // Set inputs + auto iter_input0 = (*equiv).find(input0_var_); + if (iter_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; + } + auto iter_input1 = (*equiv).find(input1_var_); + if (iter_input1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input1 var after matched."; + } + auto iter_input3 = (*equiv).find(input3_var_); + if (iter_input3 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input3 var after matched."; + } + auto iter_input4 = (*equiv).find(input4_var_); + if (iter_input4 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; + } + std::vector bn_training_update_grad_inputs = { + NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), + utils::cast(iter_input0->second), utils::cast(iter_input1->second), + utils::cast(iter_input3->second), utils::cast(iter_input4->second)}; + auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs); + MS_EXCEPTION_IF_NULL(bn_training_update_grad); + // Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad. + auto bn_grad_abstract_tuple = dyn_cast(bn_grad->abstract()); + MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); + if (bn_grad_abstract_tuple->elements().size() < kBatchNormGradInferOutputNum) { + MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3"; + } + std::vector abstract_list{bn_grad_abstract_tuple->elements()[1], + bn_grad_abstract_tuple->elements()[2]}; + auto abstract_tuple = std::make_shared(abstract_list); + bn_training_update_grad->set_abstract(abstract_tuple); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad); + bn_training_update_grad->set_scope(bn_grad->scope()); + return bn_training_update_grad; +} + +const BaseRef BatchNormGradInferFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimBatchNormGrad, input0_var_, input1_var_, input2_var_, input3_var_, input4_var_, Xs}); +} + +const AnfNodePtr BatchNormGradInferFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, node->cast())) { + MS_LOG(DEBUG) << "The BatchNormGrad " << node->DebugString() << " has no is_training attr, should not be changed"; + return nullptr; + } + if (AnfAlgo::GetNodeAttr(node, kAttrIsTraining)) { + MS_LOG(DEBUG) << "The is_training attr value of " << node->DebugString() << " is true, no need change"; + return nullptr; + } + if (!CheckOutputsIndex(func_graph, node)) { + MS_LOG(DEBUG) << "The output 3 or 4 of BatchNormGrad is not null, no need change"; + return nullptr; + } + AnfNodePtr bn_infer_grad = CreateBNInferGrad(func_graph, node, equiv); + AnfNodePtr bn_training_update_grad = CreateBNTrainingUpdateGrad(func_graph, node, equiv); + std::vector bn_training_update_grad_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_grad, kBNTrainingUpdateGradOutputNum, + &bn_training_update_grad_outputs); + if (bn_training_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { + MS_LOG(EXCEPTION) << "The output size of " << bn_training_update_grad << " should be " + << kBNTrainingUpdateGradOutputNum << ", but it is " << bn_training_update_grad_outputs.size(); + } + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_infer_grad, + bn_training_update_grad_outputs[0], bn_training_update_grad_outputs[1]}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + return make_tuple; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h new file mode 100644 index 0000000000..97100de284 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormGradInferFission : public PatternProcessPass { + public: + explicit BatchNormGradInferFission(bool multigraph = true) + : PatternProcessPass("batch_norm_grad_infer_fission", multigraph), + input0_var_(std::make_shared()), + input1_var_(std::make_shared()), + input2_var_(std::make_shared()), + input3_var_(std::make_shared()), + input4_var_(std::make_shared()) {} + ~BatchNormGradInferFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + AnfNodePtr CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, const EquivPtr &equiv) const; + AnfNodePtr CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, + const EquivPtr &equiv) const; + + VarPtr input0_var_; + VarPtr input1_var_; + VarPtr input2_var_; + VarPtr input3_var_; + VarPtr input4_var_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc new file mode 100644 index 0000000000..97122386c6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h" + +#include +#include +#include + +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "common/utils.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, + std::vector *bn_update_grad_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(bn_grad_node); + auto bn_grad_inputs = bn_grad_node->inputs(); + if (bn_grad_inputs.size() < kBNGradInputNum) { + MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; + } + std::vector bn_update_grad_inputs = { + NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], + bn_grad_inputs[4], bn_grad_inputs[5]}; + auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs); + MS_EXCEPTION_IF_NULL(bn_update_grad); + bn_update_grad->set_kernel_info(std::make_shared()); + bn_update_grad->set_scope(bn_grad_node->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)}; + auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 1), AnfAlgo::GetOutputInferShape(bn_grad_node, 2)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_update_grad.get()); + + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad); + CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs); +} + +void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, + const std::vector &bn_update_grad_outputs, + std::vector *bn_reduce_grad_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(bn_grad_node); + auto bn_grad_inputs = bn_grad_node->inputs(); + if (bn_grad_inputs.size() < kBNGradInputNum) { + MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; + } + if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { + MS_LOG(EXCEPTION) << "BNTrainingReduceGrad_outputs has wrong size"; + } + std::vector bn_reduce_grad_inputs = { + NewValueNode(std::make_shared(kBNTrainingReduceGradOpName)), + bn_grad_inputs[1], + bn_grad_inputs[2], + bn_update_grad_outputs[0], + bn_update_grad_outputs[1], + bn_grad_inputs[3], + bn_grad_inputs[4], + bn_grad_inputs[5]}; + auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs); + MS_EXCEPTION_IF_NULL(bn_reduce_grad); + bn_reduce_grad->set_kernel_info(std::make_shared()); + bn_reduce_grad->set_scope(bn_grad_node->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_reduce_grad.get()); + + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad); + (*bn_reduce_grad_outputs).push_back(bn_reduce_grad); +} +} // namespace +const BaseRef BatchNormGradSplit::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto prim = std::make_shared(kBatchNormGradOpName); + return VectorRef({prim, Xs}); +} + +const AnfNodePtr BatchNormGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(func_graph); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + if (!primitive->HasAttr(kAttrIsTraining)) { + MS_LOG(INFO) << "Op BatchNormGrad must have attrs of is_training"; + return nullptr; + } + if (!AnfAlgo::GetNodeAttr(cnode, kAttrIsTraining)) { + MS_LOG(INFO) << "is_training must be true"; + return nullptr; + } + + std::vector bn_update_grad_outputs; + CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs); + if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { + MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; + } + + std::vector bn_reduce_grad_outputs; + CreateOutputsOfReduceGrad(func_graph, cnode, bn_update_grad_outputs, &bn_reduce_grad_outputs); + if (bn_reduce_grad_outputs.size() != kSingleOutputNum) { + MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size"; + } + + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0], + bn_update_grad_outputs[0], bn_update_grad_outputs[1]}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h new file mode 100644 index 0000000000..e5378d8332 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +class BatchNormGradSplit : public PatternProcessPass { + public: + explicit BatchNormGradSplit(bool multigraph = true) : PatternProcessPass("batch_norm_grad_split", multigraph) {} + ~BatchNormGradSplit() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc new file mode 100644 index 0000000000..6c4e226120 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h" + +#include +#include +#include + +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "common/utils.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, + std::vector *bn_update_grad_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(bn_grad_node); + auto bn_grad_inputs = bn_grad_node->inputs(); + if (bn_grad_inputs.size() != kBNGradInputNum) { + MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; + } + std::vector bn_update_grad_inputs = { + NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], + bn_grad_inputs[4], bn_grad_inputs[5]}; + auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs); + MS_EXCEPTION_IF_NULL(bn_update_grad); + bn_update_grad->set_kernel_info(std::make_shared()); + bn_update_grad->set_scope(bn_grad_node->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)}; + auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 1), AnfAlgo::GetOutputInferShape(bn_grad_node, 2)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_update_grad.get()); + + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad); + CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs); +} + +void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, + const std::vector &bn_update_grad_outputs, + std::vector *bn_reduce_grad_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(bn_grad_node); + auto bn_grad_inputs = bn_grad_node->inputs(); + if (bn_grad_inputs.size() != kBNGradInputNum) { + MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; + } + if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { + MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; + } + std::vector bn_reduce_grad_inputs = { + NewValueNode(std::make_shared(kBNTrainingReduceGradOpName)), + bn_grad_inputs[1], + bn_grad_inputs[2], + bn_update_grad_outputs[0], + bn_update_grad_outputs[1], + bn_grad_inputs[3], + bn_grad_inputs[4], + bn_grad_inputs[5]}; + auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs); + MS_EXCEPTION_IF_NULL(bn_reduce_grad); + bn_reduce_grad->set_kernel_info(std::make_shared()); + bn_reduce_grad->set_scope(bn_grad_node->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_reduce_grad.get()); + + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad); + (*bn_reduce_grad_outputs).push_back(bn_reduce_grad); +} + +CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(func_graph); + std::vector bn_update_grad_outputs; + CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs); + if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { + MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; + } + + std::vector bn_reduce_grad_outputs; + CreateOutputsOfReduceGrad(func_graph, cnode, bn_update_grad_outputs, &bn_reduce_grad_outputs); + if (bn_reduce_grad_outputs.size() != 1) { + MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size"; + } + + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0], + bn_update_grad_outputs[0], bn_update_grad_outputs[1]}; + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + return make_tuple; +} +} // namespace + +const BaseRef BnGradSplit::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimFusedBatchNormGrad, Xs}); +} + +const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + return BNGradSplitForTBE(func_graph, cnode); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h new file mode 100644 index 0000000000..6fe78d4724 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +class BnGradSplit : public PatternProcessPass { + public: + explicit BnGradSplit(bool multigraph = true) : PatternProcessPass("bn_grad_split", multigraph) {} + ~BnGradSplit() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc new file mode 100644 index 0000000000..33670e5703 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/bn_split.h" + +#include +#include +#include + +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, + std::vector *bn_training_reduce_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(bn_cnode); + if (bn_cnode->inputs().size() != kBnInputNum) { + MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString(); + return false; + } + std::vector bn_training_reduce_inputs = { + NewValueNode(std::make_shared(kBNTrainingReduceOpName))}; + bn_training_reduce_inputs.push_back(bn_cnode->input(1)); + auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs); + MS_EXCEPTION_IF_NULL(bn_training_reduce); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + bn_training_reduce->set_kernel_info(kernel_info); + std::vector bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0); + if (bn_shape_i0.size() < kShape2dDims) { + MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims; + return false; + } + std::vector bn_training_reduce_shape = {bn_shape_i0[1]}; + auto types = {kNumberTypeFloat32, kNumberTypeFloat32}; + auto shapes = {bn_training_reduce_shape, bn_training_reduce_shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_training_reduce.get()); + bn_training_reduce->set_scope(bn_cnode->scope()); + AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce); + + CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs); + return true; +} + +AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, + const std::vector &bn_training_reduce_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(bn_cnode); + if (bn_cnode->inputs().size() != kBnInputNum) { + MS_LOG(EXCEPTION) << "BN node has wrong input size"; + } + if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { + MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; + } + // the inputs of BNTrainingUpdate are from the outputs of BNTrainingReduce and the inputs of BN + std::vector bn_training_update_inputs = { + NewValueNode(std::make_shared(kBNTrainingUpdateOpName))}; + bn_training_update_inputs.push_back(bn_cnode->input(1)); + bn_training_update_inputs.push_back(bn_training_reduce_outputs[0]); + bn_training_update_inputs.push_back(bn_training_reduce_outputs[1]); + bn_training_update_inputs.push_back(bn_cnode->input(2)); + bn_training_update_inputs.push_back(bn_cnode->input(3)); + bn_training_update_inputs.push_back(bn_cnode->input(4)); + bn_training_update_inputs.push_back(bn_cnode->input(5)); + auto bn_training_update = graph->NewCNode(bn_training_update_inputs); + MS_EXCEPTION_IF_NULL(bn_training_update); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + bn_training_update->set_kernel_info(kernel_info); + bn_training_update->set_abstract(bn_cnode->abstract()); + bn_training_update->set_scope(bn_cnode->scope()); + auto factor = AnfAlgo::GetNodeAttr(bn_cnode, kAttrMomentum); + AnfAlgo::SetNodeAttr(kAttrFactor, MakeValue(factor), bn_training_update); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update); + AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update); + return bn_training_update; +} + +AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() < kBnInputNum) { + MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs."; + return nullptr; + } + // Create BNTrainingReduce node and get outputs of BNTrainingReduce + std::vector bn_training_reduce_outputs; + if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) { + MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split"; + return nullptr; + } + if (bn_training_reduce_outputs.size() != kBN1OutputNum) { + MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail"; + } + + // Create BNTrainingUpdate node + return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs); +} +} // namespace + +const BaseRef BnSplit::DefinePattern() const { + VarPtr Xs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + return VectorRef({prim::kPrimFusedBatchNorm, Xs}); +} + +const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + return SplitFusedBatchNormForTBE(func_graph, node); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h new file mode 100644 index 0000000000..4340ba0af6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +class BnSplit : public PatternProcessPass { + public: + explicit BnSplit(bool multigraph = true) : PatternProcessPass("bn_split", multigraph) {} + ~BnSplit() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc new file mode 100644 index 0000000000..e8a778b36f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/lars_v2_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +namespace { +void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2, + std::vector *square_sum_all_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(lars_v2); + if (lars_v2->size() != kLarsV2InputNum) { + MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum; + } + + std::vector inputs = {NewValueNode(std::make_shared(kSquareSumAllOpName))}; + inputs.push_back(lars_v2->input(1)); + inputs.push_back(lars_v2->input(2)); + auto square_sum_all = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(square_sum_all); + square_sum_all->set_scope(lars_v2->scope()); + + auto types = {kNumberTypeFloat32, kNumberTypeFloat32}; + std::vector shape; + auto shapes = {shape, shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sum_all.get()); + + CreateMultipleOutputsOfAnfNode(graph, square_sum_all, 2, square_sum_all_outputs); +} + +CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2, + const std::vector &square_sum_all_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(lars_v2); + if (square_sum_all_outputs.size() != 2) { + MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2"; + } + if (lars_v2->size() != kLarsV2InputNum) { + MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum; + } + std::vector inputs = {NewValueNode(std::make_shared(kLarsV2UpdateOpName))}; + inputs.push_back(lars_v2->input(1)); + inputs.push_back(lars_v2->input(2)); + inputs.push_back(square_sum_all_outputs[0]); + inputs.push_back(square_sum_all_outputs[1]); + inputs.push_back(lars_v2->input(3)); + inputs.push_back(lars_v2->input(4)); + auto lars_v2_update = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(lars_v2_update); + lars_v2_update->set_scope(lars_v2->scope()); + lars_v2_update->set_abstract(lars_v2->abstract()); + return lars_v2_update; +} +} // namespace + +const BaseRef LarsV2Fission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto lars_v2_prim = std::make_shared(kLarsV2OpName); + return VectorRef({lars_v2_prim, Xs}); +} + +const AnfNodePtr LarsV2Fission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto lars_v2 = node->cast(); + MS_EXCEPTION_IF_NULL(lars_v2); + + std::vector square_sum_all_outputs; + CreateOutputsOfSquareSumAll(graph, lars_v2, &square_sum_all_outputs); + return CreateLarsV2Update(graph, lars_v2, square_sum_all_outputs); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h new file mode 100644 index 0000000000..3a165f2b29 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class LarsV2Fission : public PatternProcessPass { + public: + explicit LarsV2Fission(bool multigraph = true) : PatternProcessPass("lars_v2_fission", multigraph) {} + ~LarsV2Fission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc new file mode 100644 index 0000000000..1d19def787 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "ir/primitive.h" +#include "common/utils.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop( + const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, + std::vector *layer_norm_x_backprop_outputs) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(layer_norm_grad); + auto prim = std::make_shared(kLayerNormXBackpropOpName); + std::vector layer_norm_x_backprop_inputs = {NewValueNode(prim)}; + for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) { + layer_norm_x_backprop_inputs.push_back(layer_norm_grad->input(i)); + } + auto layer_norm_x_backprop = graph->NewCNode(layer_norm_x_backprop_inputs); + MS_EXCEPTION_IF_NULL(layer_norm_x_backprop); + layer_norm_x_backprop->set_scope(layer_norm_grad->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_x_backprop.get()); + + (*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop); +} + +void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop( + const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, + std::vector *layer_norm_beta_gamma_backprop_outputs) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(layer_norm_grad); + auto prim = std::make_shared(kLayerNormBetaGammaBackpropOpName); + std::vector layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)}; + for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) { + layer_norm_beta_gamma_backprop_inputs.push_back(layer_norm_grad->input(i)); + } + auto layer_norm_beta_gamma_backprop = graph->NewCNode(layer_norm_beta_gamma_backprop_inputs); + MS_EXCEPTION_IF_NULL(layer_norm_beta_gamma_backprop); + auto kernel_info = std::make_shared(); + layer_norm_beta_gamma_backprop->set_kernel_info(kernel_info); + layer_norm_beta_gamma_backprop->set_scope(layer_norm_grad->scope()); + + auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 1), + AnfAlgo::GetOutputInferDataType(layer_norm_grad, 2)}; + auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 1), AnfAlgo::GetOutputInferShape(layer_norm_grad, 2)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_beta_gamma_backprop.get()); + + // get device shape of LayerNormGrad's 5th Input, and convert it to attr + std::vector shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4); + AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Int(shape_gamma)), layer_norm_beta_gamma_backprop); + + CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum, + layer_norm_beta_gamma_backprop_outputs); +} + +const BaseRef LayerNormGradSplit::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VectorRef pattern({prim::kPrimLayerNormGrad, Xs}); + return pattern; +} + +const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode->inputs().size() != kLayerNormGradInputNum) { + return nullptr; + } + + // create layer_norm_x_backprop + std::vector layer_norm_x_backprop_outputs; + CreateOutputsOfLayerNormXBackprop(graph, cnode, &layer_norm_x_backprop_outputs); + if (layer_norm_x_backprop_outputs.size() != kSingleOutputNum) { + MS_LOG(EXCEPTION) << "layer_norm_grad_outputs has wrong size"; + } + + // create layer_norm_beta_gamma_backprop + std::vector layer_norm_beta_gamma_backprop_outputs; + CreateOutputsOfLayerNormBetaGammaBackprop(graph, cnode, &layer_norm_beta_gamma_backprop_outputs); + if (layer_norm_beta_gamma_backprop_outputs.size() != kLayerNormBetaGammaBackpropOutputNum) { + MS_LOG(EXCEPTION) << "layer_norm_beta_gamma_outputs has wrong size"; + } + + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), layer_norm_x_backprop_outputs[0], + layer_norm_beta_gamma_backprop_outputs[0], + layer_norm_beta_gamma_backprop_outputs[1]}; + auto make_tuple = graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + return make_tuple; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h new file mode 100644 index 0000000000..c1501b1593 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class LayerNormGradSplit : public PatternProcessPass { + public: + explicit LayerNormGradSplit(bool multigraph = true) : PatternProcessPass("layer_norm_grad_split", multigraph) {} + ~LayerNormGradSplit() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + void CreateOutputsOfLayerNormXBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, + std::vector *layer_norm_grad_outputs) const; + void CreateOutputsOfLayerNormBetaGammaBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, + std::vector *layer_norm_beta_gamma_outputs) const; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc new file mode 100644 index 0000000000..133d51734f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h" +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kBatchNormRealInputNum = 3; + +AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn); + auto bn_cnode = bn->cast(); + MS_EXCEPTION_IF_NULL(bn_cnode); + if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { + MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " + << kBatchNormRealInputNum + 1; + } + std::vector bn_training_reduce_inputs = { + NewValueNode(std::make_shared(kBNTrainingReduceOpName)), bn_cnode->input(1)}; + auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); + MS_EXCEPTION_IF_NULL(bn_training_reduce); + + // set abstract + auto bn_input1 = bn_cnode->input(2); + MS_EXCEPTION_IF_NULL(bn_input1); + AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input1->abstract()}; + auto abstract_tuple = std::make_shared(abstract_list); + bn_training_reduce->set_abstract(abstract_tuple); + bn_training_reduce->set_scope(bn->scope()); + return bn_training_reduce; +} + +AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, + const std::vector &bn_training_reduce_outputs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn); + auto bn_cnode = bn->cast(); + MS_EXCEPTION_IF_NULL(bn_cnode); + if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { + MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " + << kBatchNormRealInputNum + 1; + } + if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { + MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum + << ", but it is " << bn_training_reduce_outputs.size(); + } + std::vector bn_training_update_v3_inputs = { + NewValueNode(std::make_shared(kBNTrainingUpdateV3OpName)), + bn_cnode->input(1), + bn_training_reduce_outputs[0], + bn_training_reduce_outputs[1], + bn_cnode->input(2), + bn_cnode->input(3)}; + auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs); + MS_EXCEPTION_IF_NULL(bn_training_update_v3); + + auto bn_abstract_tuple = dyn_cast(bn->abstract()); + MS_EXCEPTION_IF_NULL(bn_abstract_tuple); + if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { + MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " + << bn_abstract_tuple->elements().size(); + } + bn_training_update_v3->set_abstract(bn->abstract()); + bn_training_update_v3->set_scope(bn->scope()); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update_v3); + return bn_training_update_v3; +} +} // namespace + +const BaseRef SingleBatchNormFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimBatchNorm, Xs}); +} + +const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() < kBatchNormRealInputNum + 1) { + MS_LOG(INFO) << "The input num of BatchNorm less than" << kBatchNormRealInputNum + << ". The node should not be changed"; + return nullptr; + } + if (!GetBoolAttr(cnode, kAttrIsTraining)) { + MS_LOG(INFO) << "is training should be true if do fusion"; + return nullptr; + } + AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); + std::vector bn_training_reduce_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, + &bn_training_reduce_outputs); + + return CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h new file mode 100644 index 0000000000..fb641c12d6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class SingleBatchNormFission : public PatternProcessPass { + public: + explicit SingleBatchNormFission(bool multigraph = true) + : PatternProcessPass("single_batch_norm_fission", multigraph) {} + ~SingleBatchNormFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc new file mode 100644 index 0000000000..063f81a1ca --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc @@ -0,0 +1,197 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/split_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(input_node); + std::vector splitv_inputs{NewValueNode(std::make_shared(kSplitVOpName)), input_node}; + CNodePtr splitv = func_graph->NewCNode(splitv_inputs); + MS_EXCEPTION_IF_NULL(splitv); + splitv->set_scope(input_node->scope()); + return splitv; +} + +CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { + MS_EXCEPTION_IF_NULL(origin_cnode); + if (origin_cnode->inputs().size() < kSplitInputNum) { + MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " + << kSplitInputNum - 1; + } + return CreateSplitVNode(func_graph, origin_cnode->input(1)); +} + +void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector &size_splits, int split_dim, int num_split) { + AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv); + AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv); + AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv); +} + +size_t GetSmallSplitSize(const AnfNodePtr &split_node, int split_dim, int num_split) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0); + if (split_dim < 0) { + split_dim += input_shape.size(); + } + if (IntToSize(split_dim) >= input_shape.size()) { + MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0"; + } + return input_shape[split_dim] / num_split; +} + +void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int outputs_num, + std::vector *inputs) { + MS_EXCEPTION_IF_NULL(inputs); + std::vector new_splitv_output; + CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, outputs_num, &new_splitv_output); + inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end()); +} + +AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) { + MS_EXCEPTION_IF_NULL(func_graph); + auto idx = NewValueNode(SizeToInt(index)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(index)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx}); + return tuple_getitem; +} + +void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int split_size, int num_split, + std::vector *new_type_ids, + std::vector> *new_output_shapes) { + MS_EXCEPTION_IF_NULL(new_type_ids); + MS_EXCEPTION_IF_NULL(new_output_shapes); + auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); + if (split_dim < 0) { + split_dim += output_shape.size(); + } + output_shape[split_dim] = split_size; + TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); + for (int i = 0; i < num_split; ++i) { + new_type_ids->emplace_back(type_id); + new_output_shapes->emplace_back(output_shape); + } +} + +void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv, + const std::vector &size_splits_base, int split_dim, int num_split) { + SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split); + std::vector base_type_ids; + std::vector> base_output_shapes_base; + auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); + TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); + if (split_dim < 0) { + split_dim += output_shape.size(); + } + for (int i = 0; i < num_split; ++i) { + output_shape[split_dim] = size_splits_base[i]; + base_output_shapes_base.emplace_back(output_shape); + base_type_ids.emplace_back(type_id); + } + AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get()); +} + +AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int num_split, int divisor) { + MS_EXCEPTION_IF_NULL(func_graph); + auto split_dim = AnfAlgo::GetNodeAttr(cnode, kAttrAxis); + CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode); + + // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs. + auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split)); + std::vector size_splits_new; + for (int i = 0; i < divisor; ++i) { + size_splits_new.emplace_back(small_split_size); + } + // Create new output shape and new output type id for each new Splitv node which has full inputs. + std::vector new_type_ids; + std::vector> new_output_shapes; + CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes); + + // Create make_tuple input to create a make_tuple for replacing the old Split node. + std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; + // Start to divide the outputs of Split. + std::vector size_splits_base; + const auto base_split_size = divisor * small_split_size; + int nodes_num = 0; + int cur_output_index = 0; + while (num_split - cur_output_index > divisor) { + CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); + SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor); + AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get()); + AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs); + cur_output_index += divisor; + size_splits_base.emplace_back(base_split_size); + nodes_num++; + } + if (cur_output_index < num_split) { + auto last_node_num_split = num_split - cur_output_index; + if (last_node_num_split > 1) { + CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); + std::vector size_splits_new_last; + for (int i = 0; i < last_node_num_split; ++i) { + size_splits_new_last.emplace_back(small_split_size); + } + SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); + // Create new output shape and new output type id for the last Splitv node + std::vector last_new_type_ids; + std::vector> last_new_output_shapes; + CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids, + &last_new_output_shapes); + AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get()); + AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs); + size_splits_base.emplace_back(last_node_num_split * small_split_size); + } else { + make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num)); + size_splits_base.emplace_back(small_split_size); + } + nodes_num++; + } + // Set Attr and abstract for the base splitv + SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num); + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} +} // namespace + +const BaseRef SplitFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto split_prim = std::make_shared(kSplitOpName); + return VectorRef({split_prim, Xs}); +} + +const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // Check output num + if (!AnfAlgo::HasNodeAttr(kAttrOutputNum, cnode)) { + return nullptr; + } + auto num_split = AnfAlgo::GetNodeAttr(cnode, kAttrOutputNum); + if (num_split <= outputs_divisor_) { + return nullptr; + } + return DoFission(func_graph, cnode, num_split, outputs_divisor_); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h new file mode 100644 index 0000000000..6428a21e73 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr int kSplitOutputsDivisor = 63; +class SplitFission : public PatternProcessPass { + public: + explicit SplitFission(bool multigraph = true) + : PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {} + ~SplitFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + int outputs_divisor_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.cc new file mode 100644 index 0000000000..c9a879e921 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(tensor_scatter_update); + std::vector inputs = {NewValueNode(std::make_shared(kTensorMoveOpName)), + tensor_scatter_update->input(1)}; + auto tensor_move = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(tensor_move); + tensor_move->set_scope(tensor_scatter_update->scope()); + tensor_move->set_abstract(tensor_scatter_update->abstract()); + AnfAlgo::SetNodeAttr(kAttrUseLocking, MakeValue(false), tensor_move); + return tensor_move; +} + +CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update, + const CNodePtr &tensor_move) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(tensor_scatter_update); + MS_EXCEPTION_IF_NULL(tensor_move); + std::vector inputs = {NewValueNode(std::make_shared(kScatterNdUpdateOpName)), tensor_move, + tensor_scatter_update->input(2), tensor_scatter_update->input(3)}; + auto scatter_nd_update = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(scatter_nd_update); + scatter_nd_update->set_scope(tensor_scatter_update->scope()); + scatter_nd_update->set_abstract(tensor_scatter_update->abstract()); + return scatter_nd_update; +} +} // namespace + +const BaseRef TensorScatterUpdateFission::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto prim = std::make_shared(kTensorScatterUpdateOpName); + return VectorRef({prim, Xs}); +} + +const AnfNodePtr TensorScatterUpdateFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto tensor_scatter_update = node->cast(); + if (tensor_scatter_update == nullptr || tensor_scatter_update->size() != 4) { + return nullptr; + } + auto tensor_move = CreateTensorMove(func_graph, tensor_scatter_update); + return CreateScatterNdUpdate(func_graph, tensor_scatter_update, tensor_move); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h new file mode 100644 index 0000000000..0f7efb029c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class TensorScatterUpdateFission : public PatternProcessPass { + public: + explicit TensorScatterUpdateFission(bool multigraph = true) + : PatternProcessPass("tensor_scatter_update_fission", multigraph) {} + ~TensorScatterUpdateFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc new file mode 100644 index 0000000000..6eeb7a61f7 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/topk_split.h" +#include +#include +#include +#include +#include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "utils/utils.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +constexpr size_t kFloat16Len = 2; // size of float16; +constexpr size_t kTopkIndexK = 1; +namespace { +tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { + // 1 create tensor + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); + auto last_dim = shape[shape.size() - 1]; + std::vector indices_shape = {SizeToInt(last_dim * 2)}; + TensorTypePtr tensor_type = std::make_shared(kFloat16); + MS_EXCEPTION_IF_NULL(tensor_type); + tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type}; + tensor::TensorPtr indices_tensor = std::make_shared(kFloat16->type_id(), indices_shape); + MS_EXCEPTION_IF_NULL(indices_tensor); + indices_tensor->set_device_info(device_info); + + // 2 set value of tensor + auto data_ptr = indices_tensor->data_c(); + MS_EXCEPTION_IF_NULL(data_ptr); + std::vector half_data; + for (size_t i = 0; i < last_dim; ++i) { + half_data.emplace_back(Eigen::half(static_cast(i))); + } + for (size_t i = 0; i < last_dim; ++i) { + auto gap = static_cast(i) - static_cast(Eigen::half(static_cast(i))); + half_data.emplace_back(Eigen::half(static_cast(gap))); + } + auto elem_num = last_dim * kFloat16Len * 2; + auto ret_code = memcpy_s(data_ptr, static_cast(indices_tensor->data().nbytes()), half_data.data(), elem_num); + if (ret_code != 0) { + MS_LOG(ERROR) << "Failed to copy data into Tensor."; + return nullptr; + } + return indices_tensor; +} + +ValueNodePtr CreateValueNode(const AnfNodePtr &node) { + tensor::TensorPtr indices_tensor = CreateTensor(node); + MS_EXCEPTION_IF_NULL(indices_tensor); + auto indices_const = std::make_shared(indices_tensor); + MS_EXCEPTION_IF_NULL(indices_const); + auto indices_abstract = indices_tensor->ToAbstract(); + indices_const->set_abstract(indices_abstract); + auto indices_kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(indices_kernel_info); + indices_const->set_kernel_info(indices_kernel_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1; + builder1.SetOutputsFormat({kOpFormat_DEFAULT}); + builder1.SetOutputsDeviceType({kNumberTypeFloat16}); + AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get()); + return indices_const; +} + +kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetKernelType(TBE_KERNEL); + builder.SetFusionType(kernel::OPAQUE); + builder.SetProcessor(kernel::AICORE); + builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); + builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); + builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); + builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32}); + return builder.Build(); +} + +bool CheckInputNamesSize(const CNodePtr &cnode) { + auto input_names_vec = AnfAlgo::GetNodeAttr>(cnode, kAttrInputNames); + if (input_names_vec.size() < kTopkIndexK + 1) { + MS_LOG(INFO) << "The input k of topk has been converted to attr"; + return false; + } + return true; +} + +bool CheckOutputShape(const AnfNodePtr &node) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); + if (shape.empty()) { + MS_LOG(INFO) << "The output shape of topk to split must not be empty"; + return false; + } + auto last_dim = shape[shape.size() - 1]; + const size_t kMaxFloat16 = 65500; + if (last_dim > kMaxFloat16) { + MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops."; + return false; + } + return true; +} +} // namespace + +const BaseRef TopKSplit::DefinePattern() const { + VarPtr X1 = std::make_shared(); + VarPtr X2 = std::make_shared(); + auto prim = std::make_shared(kTopKOpName); + return VectorRef({prim, X1, X2}); +} + +const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto kernel_graph = func_graph->cast(); + // set value node as topk's input + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!CheckInputNamesSize(cnode)) { + return nullptr; + } + if (!CheckOutputShape(cnode)) { + return nullptr; + } + // Copy a new node to check supported. + std::vector new_inputs{NewValueNode(std::make_shared(kTopKOpName))}; + new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + CNodePtr new_cnode = func_graph->NewCNode(new_inputs); + MS_EXCEPTION_IF_NULL(new_cnode); + new_cnode->set_abstract(cnode->abstract()); + new_cnode->set_scope(cnode->scope()); + AnfAlgo::CopyNodeAttrs(cnode, new_cnode); + CheckCNodeInputSize(new_cnode, kTopkInputNum); + // Convert the tensor input to scalar and convert it to attr + auto input_k = new_cnode->input(kTopkIndexK + 1); + MS_EXCEPTION_IF_NULL(input_k); + if (!IsValueNode(input_k)) { + return nullptr; + } + ValuePtr value = GetValueNode(input_k); + MS_EXCEPTION_IF_NULL(value); + auto tensor = value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + int32_t *data = reinterpret_cast(tensor->data_c()); + MS_EXCEPTION_IF_NULL(data); + auto new_value_node = std::make_shared(MakeValue(*data)); + new_cnode->set_input(kTopkIndexK + 1, new_value_node); + + std::unordered_set attr_index{kTopkIndexK}; + ConstInputToAttr(new_cnode, attr_index); + auto indices_const = CreateValueNode(new_cnode); + new_cnode->add_input(indices_const); + MS_EXCEPTION_IF_NULL(supported_checker_); + if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { + MS_LOG(INFO) << "split topk failed, check to aicpu."; + return nullptr; + } + + if (kernel_graph != nullptr) { + MS_LOG(INFO) << "split topk success. use tbe aicore."; + kernel_graph->AddValueNodeToGraph(indices_const); + } + + return new_cnode; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h new file mode 100644 index 0000000000..e005a83a2f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class TopKSplit : public PatternProcessPass { + public: + explicit TopKSplit(bool multigraph = true) + : PatternProcessPass("topk_split", multigraph), supported_checker_(std::make_shared()) {} + ~TopKSplit() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + SupportedCheckerPtr supported_checker_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc new file mode 100644 index 0000000000..057cf8deed --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fission/transdata_split.h" +#include +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +const std::set> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, + {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, + {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, + {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; + +bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + bool changed = false; + std::vector node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { + CheckCNodeInputSize(node->cast(), kBackendTransDataInputNum); + if (IsFormatInvaild(node)) { + changed = DoSplit(func_graph, node); + } + } + } + return changed; +} +bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input_format = AnfAlgo::GetInputFormat(node, 0); + auto output_format = AnfAlgo::GetOutputFormat(node, 0); + auto format_pair = std::make_pair(input_format, output_format); + + return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); +} +// transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW) +bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input_node = node->cast()->input(1); + MS_EXCEPTION_IF_NULL(input_node); + + auto input_format = AnfAlgo::GetInputFormat(node, 0); + auto output_format = AnfAlgo::GetOutputFormat(node, 0); + AnfNodePtr new_transdata_node = nullptr; + AnfNodePtr new_transpose_node = nullptr; + AnfNodePtr new_replace_node = nullptr; + // if output_format=default transdata need split transdata->transpose else transpose->transdata + if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { + // trans input_format to hwcn + new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, + false, prim::KPrimTransData->name()); + RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node); + // trans hwcn to default_format + new_transpose_node = + NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); + RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node); + AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{3, 2, 0, 1}), new_transpose_node); + new_replace_node = new_transpose_node; + } else { + // trans default to hwcn + new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, + false, prim::kPrimTranspose->name()); + AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); + RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node); + + // trans hwcn to output_format + new_transdata_node = + NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); + RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node); + new_replace_node = new_transdata_node; + } + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(func_graph); + + if (!manager->Replace(node, new_replace_node)) { + MS_LOG(EXCEPTION) << "Manager replace node failed"; + } + MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h new file mode 100644 index 0000000000..bc681944c3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ +#include +#include +#include +#include + +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class TransDataSplit : public Pass { + public: + TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared()) {} + ~TransDataSplit() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node); + bool IsFormatInvaild(const AnfNodePtr &node); + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc new file mode 100644 index 0000000000..189ac94546 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h" +#include "backend/optimizer/common/helper.h" +namespace mindspore { +namespace opt { +AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto prim = std::make_shared(kAdamApplyOneOpName); + std::vector new_node_inputs = {NewValueNode(prim)}; + for (const auto &input_var : input_vars_) { + auto input_node = utils::cast((*equiv)[input_var]); + MS_EXCEPTION_IF_NULL(input_node); + new_node_inputs.push_back(input_node); + } + for (const auto &mul_x_input_var : mul_x_input_vars_) { + auto mul_x_input_node = utils::cast((*equiv)[mul_x_input_var]); + MS_EXCEPTION_IF_NULL(mul_x_input_node); + new_node_inputs.push_back(mul_x_input_node); + } + auto add2_y_node = utils::cast((*equiv)[add2_y_]); + MS_EXCEPTION_IF_NULL(add2_y_node); + new_node_inputs.push_back(add2_y_node); + auto new_node = func_graph->NewCNode(new_node_inputs); + return new_node; +} + +const BaseRef AdamApplyOneFusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); + return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); +} + +const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); + return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); +} + +const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]}); + VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); + return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); +} + +const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); + return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); +} + +const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); + return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); +} + +const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + auto new_node = CreateAdamApplyOneNode(func_graph, equiv); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(node->scope()); + // Set abstract of new node + AbstractBasePtrList new_node_abstract_list; + auto iter_add0 = (*equiv).find(add0_var_); + if (iter_add0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; + } + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; + } + auto add0 = utils::cast(iter_add0->second); + MS_EXCEPTION_IF_NULL(add0); + auto add1 = utils::cast(iter_add1->second); + MS_EXCEPTION_IF_NULL(add1); + new_node_abstract_list.push_back(add1->abstract()); + new_node_abstract_list.push_back(add0->abstract()); + new_node_abstract_list.push_back(node->abstract()); + auto abstract_tuple = std::make_shared(new_node_abstract_list); + new_node->set_abstract(abstract_tuple); + // Create tuple_getitem node for outputs + std::vector new_node_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, new_node, kAdamApplyOneOutputNum, &new_node_outputs); + if (new_node_outputs.size() != kAdamApplyOneOutputNum) { + MS_LOG(EXCEPTION) << "The output size of node " << new_node->DebugString() << " should be " + << kAdamApplyOneOutputNum; + } + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(add1, new_node_outputs[0]); + (void)manager->Replace(add0, new_node_outputs[1]); + return new_node_outputs[2]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h new file mode 100644 index 0000000000..683a345cdb --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +constexpr size_t kAdamApplyOneInputVarNum = 5; +constexpr size_t kAdamApplyOneMulInputVarNum = 4; + +class AdamApplyOneFusion : public PatternProcessPass { + public: + explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true) + : PatternProcessPass(name, multigraph) { + for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) { + input_vars_.push_back(std::make_shared()); + } + for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) { + mul_x_input_vars_.push_back(std::make_shared()); + } + add2_y_ = std::make_shared(); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + } + + ~AdamApplyOneFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + protected: + AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; + std::vector input_vars_; + std::vector mul_x_input_vars_; + VarPtr add2_y_; + VarPtr add0_var_; + VarPtr add1_var_; +}; + +class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneCond1Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {} + + ~AdamApplyOneCond1Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneCond2Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneCond2Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {} + + ~AdamApplyOneCond2Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneCond3Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneCond3Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {} + + ~AdamApplyOneCond3Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneCond4Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {} + + ~AdamApplyOneCond4Fusion() override = default; + const BaseRef DefinePattern() const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc new file mode 100644 index 0000000000..b1afa338d4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc @@ -0,0 +1,189 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" + +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(equiv); + auto input0 = utils::cast((*equiv)[input0_]); + auto input1 = utils::cast((*equiv)[input1_]); + auto input2 = utils::cast((*equiv)[input2_]); + auto input3 = utils::cast((*equiv)[input3_]); + auto input4 = utils::cast((*equiv)[input4_]); + auto mul0_x = utils::cast((*equiv)[mul0_x_]); + auto mul1_x = utils::cast((*equiv)[mul1_x_]); + auto mul2_x = utils::cast((*equiv)[mul2_x_]); + auto mul3_x = utils::cast((*equiv)[mul3_x_]); + auto mul4_x = utils::cast((*equiv)[mul4_x_]); + auto add2_y = utils::cast((*equiv)[add2_y_]); + auto prim = std::make_shared(kAdamApplyOneWithDecayOpName); + return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; +} + +const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, input4_, add3}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, input2_, mul0_x_}); + VectorRef mul1({prim::kPrimMul, input0_, mul1_x_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, input1_, mul2_x_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + if (graph == nullptr || node == nullptr || equiv == nullptr) { + return nullptr; + } + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + std::vector inputs = GetFusionNodeInputs(equiv); + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_scope(node->scope()); + + auto iter_add0 = (*equiv).find(add0_var_); + if (iter_add0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; + } + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; + } + auto add0 = utils::cast(iter_add0->second); + MS_EXCEPTION_IF_NULL(add0); + auto add1 = utils::cast(iter_add1->second); + MS_EXCEPTION_IF_NULL(add1); + auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), + AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), + AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); + + std::vector fusion_node_outputs; + CreateMultipleOutputsOfAnfNode(graph, fusion_node, kAdamApplyOneWithDecayOutputNum, &fusion_node_outputs); + if (fusion_node_outputs.size() != kAdamApplyOneWithDecayOutputNum) { + MS_LOG(ERROR) << "create multiple outputs for fusion node fail!"; + return nullptr; + } + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(add1, fusion_node_outputs[0]); + (void)manager->Replace(add0, fusion_node_outputs[1]); + return fusion_node_outputs[2]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h new file mode 100644 index 0000000000..2d599a8cc9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" +namespace mindspore { +namespace opt { +class AdamApplyOneWithDecayRule : public PatternProcessPass { + public: + explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true) + : PatternProcessPass(name, multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + input3_ = std::make_shared(); + input4_ = std::make_shared(); + mul0_x_ = std::make_shared(); + mul1_x_ = std::make_shared(); + mul2_x_ = std::make_shared(); + mul3_x_ = std::make_shared(); + mul4_x_ = std::make_shared(); + add2_y_ = std::make_shared(); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + } + ~AdamApplyOneWithDecayRule() override = default; + const BaseRef DefinePattern() const override = 0; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + protected: + std::vector GetFusionNodeInputs(const EquivPtr &equiv) const; + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr input3_; + VarPtr input4_; + VarPtr mul0_x_; + VarPtr mul1_x_; + VarPtr mul2_x_; + VarPtr mul3_x_; + VarPtr mul4_x_; + VarPtr add2_y_; + VarPtr add0_var_; + VarPtr add1_var_; +}; + +class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond1() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond2() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond3() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond4() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond5() override = default; + const BaseRef DefinePattern() const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc new file mode 100644 index 0000000000..cc58d2b057 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" +#include +#include +#include "backend/optimizer/ascend/ir_fusion/input_to_output_registry.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" + +namespace mindspore { +namespace opt { +namespace { +void GetInputOrOutputNames(const CNodePtr &cnode, const std::string &attr_name, std::vector *names_vec) { + MS_EXCEPTION_IF_NULL(names_vec); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + ValuePtr names_value = primitive->GetAttr(attr_name); + if (names_value == nullptr) { + return; + } + *names_vec = GetValue>(names_value); +} + +void AddOutputs(const CNodePtr &cnode, const std::vector &input_indices) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector input_names_vec; + GetInputOrOutputNames(cnode, kAttrInputNames, &input_names_vec); + std::vector output_names_vec; + GetInputOrOutputNames(cnode, kAttrOutputNames, &output_names_vec); + AbstractBasePtrList abstract_list; + auto origin_abstract = cnode->abstract(); + MS_EXCEPTION_IF_NULL(origin_abstract); + if (origin_abstract->isa()) { + auto origin_abstract_tuple = dyn_cast(origin_abstract); + MS_EXCEPTION_IF_NULL(origin_abstract_tuple); + AbstractBasePtrList origin_abstract_list = origin_abstract_tuple->elements(); + (void)std::copy(origin_abstract_list.begin(), origin_abstract_list.end(), std::back_inserter(abstract_list)); + } else { + abstract_list.emplace_back(origin_abstract); + } + + for (size_t i = 0; i < input_indices.size(); ++i) { + size_t index = input_indices[i]; + if (index + 1 >= cnode->inputs().size()) { + MS_LOG(INFO) << "The input index " << index << " for converting to output is out of range, " + << "node: " << cnode->DebugString(); + continue; + } + auto node_to_output = cnode->input(index + 1); + MS_EXCEPTION_IF_NULL(node_to_output); + abstract_list.emplace_back(node_to_output->abstract()); + if (!input_names_vec.empty() && !output_names_vec.empty() && index < input_names_vec.size()) { + output_names_vec.emplace_back(input_names_vec[index]); + } + } + if (!output_names_vec.empty()) { + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names_vec), cnode); + } + auto abstract_tuple = std::make_shared(abstract_list); + cnode->set_abstract(abstract_tuple); +} +} // namespace + +const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::string op_name = AnfAlgo::GetCNodeName(cnode); + InputToOutputRegister reg; + if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, ®)) { + return nullptr; + } + int output_num = op_finder_->GetOpRegisteredOutputNum(op_name); + // No need add output when it is not a tbe op. + if (output_num == -1) { + return nullptr; + } + // No need add output if the output num matches the registered output num for tbe. + if (AnfAlgo::GetOutputTensorNum(cnode) >= IntToSize(output_num)) { + return nullptr; + } + bool is_origin_tuple_output = AnfAlgo::IsTupleOutput(cnode); + AddOutputs(cnode, reg.input_indices()); + // No need to create tuple_getitem if the origin output is a tuple because there has already been some tuple_getitems + // pointed to the outputs. + if (is_origin_tuple_output) { + return nullptr; + } + std::vector new_outputs; + auto new_abstract_tuple = dyn_cast(cnode->abstract()); + MS_EXCEPTION_IF_NULL(new_abstract_tuple); + CreateMultipleOutputsOfAnfNode(func_graph, cnode, new_abstract_tuple->size(), &new_outputs); + if (new_outputs.size() != new_abstract_tuple->size()) { + MS_LOG(EXCEPTION) << "Failed to create outputs of " << cnode->DebugString(); + } + return new_outputs[0]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h new file mode 100644 index 0000000000..6e5560bfb0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class AddInputToOutput : public PatternProcessPass { + public: + explicit AddInputToOutput(bool multigraph = true) + : PatternProcessPass("add_input_to_output", multigraph), op_finder_(std::make_shared()) {} + ~AddInputToOutput() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + OpFinderPtr op_finder_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc new file mode 100644 index 0000000000..51bcd880cd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(batchnorm); + MS_EXCEPTION_IF_NULL(node); + auto prim = std::make_shared(kBNInferOpName); + std::vector inputs = {NewValueNode(prim)}; + for (size_t i = 1; i < batchnorm->size(); ++i) { + inputs.push_back(batchnorm->input(i)); + } + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(batchnorm->scope()); + new_node->set_abstract(node->abstract()); + AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnorm, new_node); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnorm, new_node); + return new_node; +} + +bool CheckIndex(const AnfNodePtr &index_node) { + MS_EXCEPTION_IF_NULL(index_node); + if (!IsValueNode(index_node)) { + return false; + } + ValueNodePtr value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (index != 0) { + MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNorm"; + return false; + } + return true; +} + +bool CheckBatchNorm(const FuncGraphPtr &graph, const CNodePtr &batchnorm) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(batchnorm); + if (batchnorm->size() < kBatchNormInputNum + 1) { + MS_LOG(DEBUG) << "BatchNorm's input less than " << kBatchNormInputNum; + return false; + } + if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) { + return false; + } + auto is_training = AnfAlgo::GetNodeAttr(batchnorm, kAttrIsTraining); + if (is_training) { + MS_LOG(DEBUG) << "is_training is true, no need do fusion"; + return false; + } + + if (IsUsedByOthers(graph, batchnorm)) { + MS_LOG(DEBUG) << "Only the 0th output of BatchNorm is used, then do fusion"; + return false; + } + return true; +} + +bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto tuple_getitem = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); + AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + if (!CheckIndex(index_node)) { + return false; + } + + AnfNodePtr batchnorm_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(batchnorm_anf); + MS_EXCEPTION_IF_NULL(batchnorm); + *batchnorm = batchnorm_anf->cast(); + MS_EXCEPTION_IF_NULL(*batchnorm); + return CheckBatchNorm(graph, *batchnorm); +} +} // namespace + +const BaseRef BatchNorm2BNInfer::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VarPtr Y = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Y); + VectorRef batchnorm({prim::kPrimBatchNorm, Xs}); + VectorRef pattern({prim::kPrimTupleGetItem, batchnorm, Y}); + return pattern; +} + +const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + CNodePtr batchnorm = nullptr; + if (!NeedFusion(graph, node, &batchnorm)) { + return nullptr; + } + return CreateBNInfer(graph, batchnorm, node); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h new file mode 100644 index 0000000000..46872aa959 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNorm2BNInfer : public PatternProcessPass { + public: + explicit BatchNorm2BNInfer(bool multigraph = true) : PatternProcessPass("batchnorm_to_bninfer", multigraph) {} + ~BatchNorm2BNInfer() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc new file mode 100644 index 0000000000..defb011396 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateBNInferGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(batchnormgrad); + auto prim = std::make_shared(kBNInferGradOpName); + std::vector inputs = {NewValueNode(prim)}; + inputs.push_back(batchnormgrad->input(1)); + inputs.push_back(batchnormgrad->input(3)); + inputs.push_back(batchnormgrad->input(5)); + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(batchnormgrad->scope()); + new_node->set_abstract(node->abstract()); + AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnormgrad, new_node); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnormgrad, new_node); + return new_node; +} + +bool CheckIndex(const AnfNodePtr &index_node) { + MS_EXCEPTION_IF_NULL(index_node); + if (!IsValueNode(index_node)) { + return false; + } + ValueNodePtr value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (index != 0) { + MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNormGrad"; + return false; + } + return true; +} + +bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(batchnormgrad); + if (batchnormgrad->size() < kBatchNormInputNum + 1) { + MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum; + return false; + } + if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) { + return false; + } + auto is_training = AnfAlgo::GetNodeAttr(batchnormgrad, kAttrIsTraining); + if (is_training) { + MS_LOG(DEBUG) << "is_training is true, no need do fusion"; + return false; + } + + if (IsUsedByOthers(graph, batchnormgrad)) { + MS_LOG(DEBUG) << "Only the 0th output of BatchNormGrad is used, then do fusion"; + return false; + } + return true; +} + +bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto tuple_getitem = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); + AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + if (!CheckIndex(index_node)) { + return false; + } + + AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(batchnormgrad_anf); + MS_EXCEPTION_IF_NULL(batchnormgrad); + *batchnormgrad = batchnormgrad_anf->cast(); + MS_EXCEPTION_IF_NULL(*batchnormgrad); + return CheckBatchNormGrad(graph, *batchnormgrad); +} +} // namespace + +const BaseRef BatchNormGrad2BNInferGrad::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VarPtr Y = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Y); + VectorRef batchnormgrad({prim::kPrimBatchNormGrad, Xs}); + VectorRef pattern({prim::kPrimTupleGetItem, batchnormgrad, Y}); + return pattern; +} + +const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + CNodePtr batchnormgrad = nullptr; + if (!NeedFusion(graph, node, &batchnormgrad)) { + return nullptr; + } + return CreateBNInferGrad(graph, batchnormgrad, node); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h new file mode 100644 index 0000000000..0676f8a040 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormGrad2BNInferGrad : public PatternProcessPass { + public: + explicit BatchNormGrad2BNInferGrad(bool multigraph = true) + : PatternProcessPass("batchnormgrad_to_bninfergrad", multigraph) {} + ~BatchNormGrad2BNInferGrad() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc new file mode 100644 index 0000000000..1d89bfd388 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "common/utils.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +const BaseRef ClipByNormNoDivSquareSumFusion::DefinePattern() const { + auto greater = std::make_shared(kGreaterOpName); + MS_EXCEPTION_IF_NULL(greater); + auto sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(sqrt); + + VectorRef greater_pattern({greater, input_, constant_greater_}); + VectorRef pattern( + {prim::kPrimMaximum, + VectorRef({prim::kPrimSelect, greater_pattern, + VectorRef({sqrt, VectorRef({prim::kPrimSelect, greater_pattern, input_, constant_select_})}), input_}), + constant_maximum_}); + return pattern; +} + +const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + BaseRef &input_gnode = (*equiv)[input_]; + BaseRef &constant_select_gnode = (*equiv)[constant_select_]; + BaseRef &constant_greater_gnode = (*equiv)[constant_greater_]; + BaseRef &constant_maximum_gnode = (*equiv)[constant_maximum_]; + auto input = utils::cast(input_gnode); + auto constant_select = utils::cast(constant_select_gnode); + auto constant_greater = utils::cast(constant_greater_gnode); + auto constant_maximum = utils::cast(constant_maximum_gnode); + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(constant_select); + MS_EXCEPTION_IF_NULL(constant_greater); + MS_EXCEPTION_IF_NULL(constant_maximum); + + auto prim = std::make_shared(kClipByNormNoDivSumOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = {NewValueNode(prim), input, constant_select, constant_greater, constant_maximum}; + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); + fusion_node->set_scope(node->scope()); + return fusion_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h new file mode 100644 index 0000000000..9282b75527 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr auto kInputVarName = "input"; +constexpr auto kConstantSelectVarName = "constant_select"; +constexpr auto kConstantGreaterVarName = "constant_greater"; +constexpr auto kConstantMaximumVarName = "constant_maximum"; + +class ClipByNormNoDivSquareSumFusion : public PatternProcessPass { + public: + explicit ClipByNormNoDivSquareSumFusion(bool multigraph = true) + : PatternProcessPass("clip_by_norm_no_div_square_sum_fusion", multigraph) { + input_ = std::make_shared(kInputVarName); + constant_select_ = std::make_shared(kConstantSelectVarName); + constant_greater_ = std::make_shared(kConstantGreaterVarName); + constant_maximum_ = std::make_shared(kConstantMaximumVarName); + } + ~ClipByNormNoDivSquareSumFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input_; + VarPtr constant_select_; + VarPtr constant_greater_; + VarPtr constant_maximum_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc new file mode 100644 index 0000000000..e1b0cb81e3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +bool GetMinimumOp(const AnfNodePtr &input0, const AnfNodePtr &input1, CNodePtr *minimum, bool *is_first_input) { + MS_EXCEPTION_IF_NULL(input0); + MS_EXCEPTION_IF_NULL(input1); + + CNodePtr cnode = nullptr; + if (input0->isa() && !input1->isa()) { + cnode = input0->cast(); + *is_first_input = true; + } else if (!input0->isa() && input1->isa()) { + cnode = input1->cast(); + *is_first_input = false; + } else if (input0->isa() && input1->isa()) { + if (AnfAlgo::GetCNodeName(input0) == prim::kPrimMinimum->name()) { + cnode = input0->cast(); + *is_first_input = true; + } else { + cnode = input1->cast(); + *is_first_input = false; + } + } else { + return false; + } + + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMinimum->name()) { + return false; + } + *minimum = cnode; + return true; +} +} // namespace + +const BaseRef ClipByValueFusion::DefinePattern() const { + VectorRef pattern({prim::kPrimMaximum, maximum_input0_, maximum_input1_}); + return pattern; +} + +const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto maximum_input0 = utils::cast((*equiv)[maximum_input0_]); + auto maximum_input1 = utils::cast((*equiv)[maximum_input1_]); + MS_EXCEPTION_IF_NULL(maximum_input0); + MS_EXCEPTION_IF_NULL(maximum_input1); + + CNodePtr minimum = nullptr; + bool is_first_input = true; + if (!GetMinimumOp(maximum_input0, maximum_input1, &minimum, &is_first_input)) { + return nullptr; + } + MS_EXCEPTION_IF_NULL(minimum); + if (minimum->inputs().size() != kMinimumInputNum) { + return nullptr; + } + + auto prim = std::make_shared(kClipByValueOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = {NewValueNode(prim), minimum->input(1), + is_first_input ? maximum_input1 : maximum_input0, minimum->input(2)}; + auto clip_by_value = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(clip_by_value); + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, clip_by_value.get()); + clip_by_value->set_scope(node->scope()); + return clip_by_value; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h new file mode 100644 index 0000000000..05bf713bdd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ClipByValueFusion : public PatternProcessPass { + public: + explicit ClipByValueFusion(bool multigraph = true) : PatternProcessPass("clip_by_value_fusion", multigraph) { + maximum_input0_ = std::make_shared(); + maximum_input1_ = std::make_shared(); + } + ~ClipByValueFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr maximum_input0_; + VarPtr maximum_input1_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc new file mode 100644 index 0000000000..6ccf3e29bd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" +#include +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t kConfusionMulGradOutputNum = 2; + +CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf, + const AnfNodePtr &input3) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(reduce_sum); + MS_EXCEPTION_IF_NULL(mul0_anf); + MS_EXCEPTION_IF_NULL(input3); + auto mul0 = mul0_anf->cast(); + MS_EXCEPTION_IF_NULL(mul0); + + auto prim = std::make_shared(kConfusionMulGradOpName); + std::vector inputs = {NewValueNode(prim), mul0->input(1), mul0->input(2), input3}; + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_scope(reduce_sum->scope()); + AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); + auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); + return fusion_node; +} + +AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const AnfNodePtr &mul1) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input2); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(input2) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + + AnfNodePtr mul0 = nullptr; + const AnfNodeIndexSet &outputs_set = manager->node_users()[input2]; + // input2 must be the 2rd input of mul0 + auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul1](const std::pair &node_index) { + return node_index.first != mul1 && node_index.second == 2; + }); + if (it != outputs_set.end() && AnfAlgo::GetCNodeName(it->first) == prim::kPrimMul->name()) { + mul0 = it->first; + } + return mul0; +} + +bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, + const AnfNodePtr &reduce_sum, const AnfNodePtr &input2) { + MS_EXCEPTION_IF_NULL(mul0_anf); + MS_EXCEPTION_IF_NULL(mul1_anf); + MS_EXCEPTION_IF_NULL(reduce_sum); + MS_EXCEPTION_IF_NULL(input2); + auto addn = input2->cast(); + if (addn == nullptr || AnfAlgo::GetCNodeName(addn) != prim::kPrimAddN->name()) { + MS_LOG(INFO) << "mul's second input is not addn"; + return true; + } + std::vector shape = AnfAlgo::GetOutputInferShape(addn, 0); + if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { + MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]"; + return true; + } + if (!mul0_anf->isa() || !mul1_anf->isa()) { + return true; + } + auto mul1 = mul1_anf->cast(); + MS_EXCEPTION_IF_NULL(mul1); + auto mul0 = mul0_anf->cast(); + MS_EXCEPTION_IF_NULL(mul0); + + if (IsDepend(graph, mul0->input(1), reduce_sum)) { + MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; + return true; + } + if (IsDepend(graph, mul1->input(1), mul0)) { + MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; + return true; + } + return false; +} +} // namespace + +const BaseRef ConfusionMulGradFusion::DefinePattern() const { + VectorRef mul1({prim::kPrimMul, input3_, input2_}); + VectorRef reduce_sum({prim::kPrimReduceSum, mul1}); + return reduce_sum; +} + +const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto input2 = utils::cast((*equiv)[input2_]); + auto input3 = utils::cast((*equiv)[input3_]); + auto reduce_sum = node->cast(); + MS_EXCEPTION_IF_NULL(reduce_sum); + auto mul1 = reduce_sum->input(1); + if (IsUsedByOthers(graph, mul1)) { + MS_LOG(INFO) << "Mul1 is used by others, quit fusion!"; + return nullptr; + } + auto mul0 = GetMul0(graph, input2, mul1); + if (mul0 == nullptr) { + MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; + return nullptr; + } + if (QuitFusion(graph, mul0, mul1, node, input2)) { + return nullptr; + } + + auto fusion_node = CreateFusionNode(graph, reduce_sum, mul0, input3); + std::vector fusion_node_outputs; + CreateMultipleOutputsOfAnfNode(graph, fusion_node, kConfusionMulGradOutputNum, &fusion_node_outputs); + + auto manage = graph->manager(); + MS_EXCEPTION_IF_NULL(manage); + manage->Replace(mul0, fusion_node_outputs[0]); + return fusion_node_outputs[1]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h new file mode 100644 index 0000000000..932f0d2890 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConfusionMulGradFusion : public PatternProcessPass { + public: + explicit ConfusionMulGradFusion(bool multigraph = true) + : PatternProcessPass("confusion_mul_grad_fusion", multigraph) { + input2_ = std::make_shared(); + input3_ = std::make_shared(); + } + ~ConfusionMulGradFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input2_; + VarPtr input3_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.cc new file mode 100644 index 0000000000..a8cf0af465 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h" + +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { + return VectorRef({prim::kPrimSub, input0_, VectorRef({reduce_sum_, VectorRef({prim::kPrimMul, input1_, input0_})})}); +} + +const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + AnfNodePtr input0 = GetAnfNodeByVar(equiv, input0_); + AnfNodePtr input1 = GetAnfNodeByVar(equiv, input1_); + AnfNodePtr sum_anf = GetAnfNodeByVar(equiv, reduce_sum_); + if (sum_anf == nullptr || !sum_anf->isa()) { + MS_LOG(WARNING) << "Matched ReduceSum is not a CNode!"; + return nullptr; + } + if (!GetBoolAttr(sum_anf, kAttrKeepDims)) { + MS_LOG(INFO) << "ReduceSum's attr keep_dims should be true if do fusion. Otherwise the calculation will be wrong"; + return nullptr; + } + + auto prim = std::make_shared(kConfusionSoftmaxGradOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = {NewValueNode(prim), input0, input1}; + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_abstract(node->abstract()); + fusion_node->set_scope(node->scope()); + AnfAlgo::CopyNodeAttr(kAttrAxis, sum_anf, fusion_node); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum_anf, fusion_node); + return fusion_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h new file mode 100644 index 0000000000..e3a86e22c9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConfusionSoftmaxGradRule : public PatternProcessPass { + public: + explicit ConfusionSoftmaxGradRule(bool multigraph = true) + : PatternProcessPass("confusion_softmax_grad_rule", multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + reduce_sum_ = std::make_shared(std::make_shared(prim::kPrimReduceSum->name())); + } + ~ConfusionSoftmaxGradRule() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input0_; + VarPtr input1_; + VarPtr reduce_sum_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc new file mode 100644 index 0000000000..0fe042dc4e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t kReluV2OutputNum = 2; + +CNodePtr GetRelu(const CNodePtr &relu_grad) { + MS_EXCEPTION_IF_NULL(relu_grad); + if (relu_grad->size() != kReluGradInputNum) { + MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); + } + auto relu_anf = relu_grad->input(2); + MS_EXCEPTION_IF_NULL(relu_anf); + return relu_anf->cast(); +} + +CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(relu); + if (relu->size() != kReluInputNum) { + MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); + } + + auto prim = std::make_shared(kReluV2OpName); + std::vector inputs = {NewValueNode(prim), relu->input(1)}; + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(relu->scope()); + + // ReluV2's 2rd output is mask whose data type is uint8 + TypeId mask_dtype = kNumberTypeUInt8; + std::vector mask_shape = AnfAlgo::GetOutputInferShape(relu, 0); + if (mask_shape.size() != 4) { + MS_LOG(DEBUG) << "relu's infer shape size not equal 4"; + return nullptr; + } + auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0); + if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) { + mask_shape[1] = (mask_shape[1] + 31) / 32; + mask_shape.push_back(4); + } else { + mask_shape[1] = (mask_shape[1] + 15) / 16; + mask_shape.push_back(2); + } + + auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype}; + auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); + return new_node; +} + +CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(relu_grad); + MS_EXCEPTION_IF_NULL(second_input); + + auto prim = std::make_shared(kReluGradV2OpName); + std::vector inputs = {NewValueNode(prim), relu_grad->input(1), second_input}; + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(relu_grad->scope()); + new_node->set_abstract(relu_grad->abstract()); + return new_node; +} +} // namespace + +const BaseRef DereluFusion::DefinePattern() const { + VarPtr i0 = std::make_shared(); + VarPtr i1 = std::make_shared(); + VectorRef relu({prim::kPrimRelu, i1}); + VectorRef relu_grad({prim::kPrimReluGrad, i0, relu}); + return relu_grad; +} + +const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto relu_grad = node->cast(); + MS_EXCEPTION_IF_NULL(relu_grad); + auto relu = GetRelu(relu_grad); + MS_EXCEPTION_IF_NULL(relu); + + auto relu_v2 = CreateReluV2(graph, relu); + if (relu_v2 == nullptr) { + return nullptr; + } + std::vector relu_v2_node_outputs; + CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); + + auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); + + auto manage = graph->manager(); + MS_EXCEPTION_IF_NULL(manage); + manage->Replace(relu, relu_v2_node_outputs[0]); + return relu_grad_v2; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h new file mode 100644 index 0000000000..7506960ecb --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class DereluFusion : public PatternProcessPass { + public: + explicit DereluFusion(bool multigraph = true) : PatternProcessPass("derelu_fusion", multigraph) {} + ~DereluFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc new file mode 100644 index 0000000000..dbff0374f3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc @@ -0,0 +1,340 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" +#include +#include +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kReplaceOutputIndex0 = 3; +constexpr size_t kReplaceOutputIndex1 = 4; +bool IsC(const BaseRef &n) { + if (utils::isa(n)) { + AnfNodePtr in = utils::cast(n); + MS_EXCEPTION_IF_NULL(in); + return in->isa(); + } + return false; +} + +void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector *bn_outputs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn); + MS_EXCEPTION_IF_NULL(bn_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(bn) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs"; + } + for (const auto &node_index : manager->node_users()[bn]) { + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + bn_outputs->push_back(output); + } +} +} // namespace + +const BaseRef FusedBatchNormFusion::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + VarPtr index0 = std::make_shared(IsC); + VarPtr index1 = std::make_shared(IsC); + VarPtr index2 = std::make_shared(IsC); + VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); + VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); + VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); + VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); + VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1}); + VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2}); + VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); + VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); + VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); + VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); + return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); +} + +ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(equiv); + auto iter_constant_input0 = (*equiv).find(constant_input0_var_); + if (iter_constant_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched."; + } + auto constant_input = utils::cast(iter_constant_input0->second); + MS_EXCEPTION_IF_NULL(constant_input); + if (!constant_input->isa()) { + return nullptr; + } + auto value_node = constant_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + if (!value->isa()) { + return nullptr; + } + auto tensor_ptr = value->cast(); + MS_EXCEPTION_IF_NULL(tensor_ptr); + if (tensor_ptr->data_type() == kNumberTypeFloat16) { + auto *half_data = static_cast(tensor_ptr->data_c()); + MS_EXCEPTION_IF_NULL(half_data); + float float_data = Eigen::half_impl::half_to_float(half_data[0]); + return MakeValue(float_data); + } else if (tensor_ptr->data_type() == kNumberTypeFloat32) { + auto *tensor_data = static_cast(tensor_ptr->data_c()); + MS_EXCEPTION_IF_NULL(tensor_data); + return MakeValue(tensor_data[0]); + } else { + MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32"; + return nullptr; + } +} + +AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + // Set input to create node + auto iter_data_input0 = (*equiv).find(data_input0_var_); + if (iter_data_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; + } + std::vector bn_training_reduce_inputs = { + NewValueNode(std::make_shared(kBNTrainingReduceOpName)), + utils::cast(iter_data_input0->second)}; + auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); + MS_EXCEPTION_IF_NULL(bn_training_reduce); + bn_training_reduce->set_scope(node->scope()); + // Set abstract + auto iter_data_input1 = (*equiv).find(data_input1_var_); + if (iter_data_input1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; + } + auto data_input1 = utils::cast(iter_data_input1->second); + MS_EXCEPTION_IF_NULL(data_input1); + auto iter_data_input2 = (*equiv).find(data_input2_var_); + if (iter_data_input2 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; + } + auto data_input2 = utils::cast(iter_data_input2->second); + MS_EXCEPTION_IF_NULL(data_input2); + AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()}; + auto abstract_tuple = std::make_shared(abstract_list); + bn_training_reduce->set_abstract(abstract_tuple); + return bn_training_reduce; +} + +void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv, + const std::vector &bn_training_reduce_outputs, + std::vector *bn_training_update_inputs) const { + MS_EXCEPTION_IF_NULL(equiv); + MS_EXCEPTION_IF_NULL(bn_training_update_inputs); + auto iter_data_input0 = (*equiv).find(data_input0_var_); + if (iter_data_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; + } + auto iter_data_input1 = (*equiv).find(data_input1_var_); + if (iter_data_input1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; + } + auto iter_data_input2 = (*equiv).find(data_input2_var_); + if (iter_data_input2 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; + } + auto iter_variable_input0 = (*equiv).find(variable_input0_var_); + if (iter_variable_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; + } + auto iter_variable_input1 = (*equiv).find(variable_input1_var_); + if (iter_variable_input1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; + } + if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { + MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum + << ", but it is " << bn_training_reduce_outputs.size(); + } + *bn_training_update_inputs = { + NewValueNode(std::make_shared(kBNTrainingUpdateOpName)), + utils::cast(iter_data_input0->second), + bn_training_reduce_outputs[0], + bn_training_reduce_outputs[1], + utils::cast(iter_data_input1->second), + utils::cast(iter_data_input2->second), + utils::cast(iter_variable_input0->second), + utils::cast(iter_variable_input1->second), + }; +} + +void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn, + std::vector *abstract_list) const { + MS_EXCEPTION_IF_NULL(equiv); + MS_EXCEPTION_IF_NULL(bn); + MS_EXCEPTION_IF_NULL(abstract_list); + auto bn_abstract_tuple = dyn_cast(bn->abstract()); + MS_EXCEPTION_IF_NULL(bn_abstract_tuple); + if (bn_abstract_tuple->elements().size() < kBnOutputNum) { + MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is " + << bn_abstract_tuple->elements().size(); + } + auto iter_variable_input0 = (*equiv).find(variable_input0_var_); + if (iter_variable_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; + } + auto variable_input0 = utils::cast(iter_variable_input0->second); + MS_EXCEPTION_IF_NULL(variable_input0); + auto iter_variable_input1 = (*equiv).find(variable_input1_var_); + if (iter_variable_input1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; + } + auto variable_input1 = utils::cast(iter_variable_input1->second); + MS_EXCEPTION_IF_NULL(variable_input1); + *abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(), + bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]}; +} + +AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate( + const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, + const std::vector &bn_training_reduce_outputs) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + // Set input + std::vector bn_training_update_inputs; + GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs); + auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs); + MS_EXCEPTION_IF_NULL(bn_training_update); + // Set abstract + auto iter_batch_norm = (*equiv).find(batch_norm_var_); + if (iter_batch_norm == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."; + } + AnfNodePtr bn = utils::cast(iter_batch_norm->second); + MS_EXCEPTION_IF_NULL(bn); + AbstractBasePtrList abstract_list; + GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list); + auto abstract_tuple = std::make_shared(abstract_list); + bn_training_update->set_abstract(abstract_tuple); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, bn_training_update); + ValuePtr factor = GetFactor(equiv); + if (factor == nullptr) { + return nullptr; + } + AnfAlgo::SetNodeAttr(kAttrFactor, factor, bn_training_update); + AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update); + bn_training_update->set_scope(node->scope()); + return bn_training_update; +} + +const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + MS_EXCEPTION_IF_NULL(node); + AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node, equiv); + std::vector bn_training_reduce_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, + &bn_training_reduce_outputs); + AnfNodePtr bn_training_update = CreateBNTrainingUpdate(func_graph, node, equiv, bn_training_reduce_outputs); + if (bn_training_update == nullptr) { + MS_LOG(DEBUG) << "Create BNTrainingUpdate failed for bn node " << node->DebugString(); + return nullptr; + } + std::vector bn_training_update_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update, kBNTrainingUpdateOutputNum, + &bn_training_update_outputs); + if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) { + MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBNTrainingUpdateOutputNum << ", but it is " + << bn_training_update_outputs.size(); + } + // Replace old bn outputs with new outputs + auto iter_batch_norm = (*equiv).find(batch_norm_var_); + if (iter_batch_norm == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."; + } + AnfNodePtr bn = utils::cast(iter_batch_norm->second); + std::vector bn_outputs; + GetBNOutput(func_graph, bn, &bn_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (const auto &output : bn_outputs) { + MS_EXCEPTION_IF_NULL(output); + if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { + continue; + } + auto tuple_getitem_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); + AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) { + (void)manager->Replace(output, bn_training_update_outputs[index]); + } + } + return bn_training_update_outputs[0]; +} + +const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + VarPtr index0 = std::make_shared(IsC); + VarPtr index1 = std::make_shared(IsC); + VarPtr index2 = std::make_shared(IsC); + VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); + VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); + VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); + VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); + VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); + VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); + VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); + VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); + VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); + VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); + VectorRef cast2 = VectorRef({prim::kPrimCast, mul0}); + VectorRef cast3 = VectorRef({prim::kPrimCast, mul1}); + VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2}); + VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); + return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); +} + +const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + VarPtr index0 = std::make_shared(IsC); + VarPtr index1 = std::make_shared(IsC); + VarPtr index2 = std::make_shared(IsC); + VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); + VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); + VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); + VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); + VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); + VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); + VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); + VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); + VectorRef cast0 = VectorRef({prim::kPrimCast, sub0}); + VectorRef cast1 = VectorRef({prim::kPrimCast, sub1}); + VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_}); + VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_}); + VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); + VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); + return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h new file mode 100644 index 0000000000..b3bbedc36e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +class FusedBatchNormFusion : public PatternProcessPass { + public: + explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true) + : PatternProcessPass(name, multigraph), + data_input0_var_(std::make_shared()), + data_input1_var_(std::make_shared()), + data_input2_var_(std::make_shared()), + variable_input0_var_(std::make_shared()), + variable_input1_var_(std::make_shared()), + constant_input0_var_(std::make_shared()), + constant_input1_var_(std::make_shared()), + batch_norm_var_(std::make_shared(std::make_shared(prim::kPrimBatchNorm->name()))) {} + ~FusedBatchNormFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + protected: + AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const; + void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector &bn_training_reduce_outputs, + std::vector *bn_training_update_inputs) const; + void GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn, + std::vector *abstract_list) const; + AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, + const std::vector &bn_training_reduce_outputs) const; + ValuePtr GetFactor(const EquivPtr &equiv) const; + + VarPtr data_input0_var_; + VarPtr data_input1_var_; + VarPtr data_input2_var_; + VarPtr variable_input0_var_; + VarPtr variable_input1_var_; + VarPtr constant_input0_var_; + VarPtr constant_input1_var_; + VarPtr batch_norm_var_; +}; + +class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion { + public: + explicit FusedBatchNormMixPrecisionFusion0(bool multigraph = true) + : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {} + + ~FusedBatchNormMixPrecisionFusion0() override = default; + const BaseRef DefinePattern() const override; +}; + +class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion { + public: + explicit FusedBatchNormMixPrecisionFusion1(bool multigraph = true) + : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {} + + ~FusedBatchNormMixPrecisionFusion1() override = default; + const BaseRef DefinePattern() const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.cc new file mode 100644 index 0000000000..2fb42f9bd6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/input_to_output_registry.h" +#include +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +bool ApplyRMSPropPreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool FusedMulApplyMomentumPreCheck(const CNodePtr &node) { + TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); + return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); +} + +bool SparseApplyRMSPropPreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool ApplyAdagradV2PreCheck(const CNodePtr &node) { + TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); + return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); +} + +bool ApplyKerasMomentumPreCheck(const CNodePtr &node) { + TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0); + return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16); +} + +bool SparseApplyFtrlPreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool SparseApplyFtrlV2PreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool SparseApplyAdagradV2PreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} + +bool SparseApplyAdadeltaPreCheck(const CNodePtr &node) { + return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32); +} +} // namespace +InputToOutputRegistry::InputToOutputRegistry() { + Register(kApplyRMSPropOpName, {1, 2}, ApplyRMSPropPreCheck); + Register(kFusedMulApplyMomentumOpName, {1}, FusedMulApplyMomentumPreCheck); + Register(kApplyAdagradOpName, {1}); + Register(kApplyAdagradDAName, {1, 2}); + Register(kApplyAdadeltaOpName, {1, 2}); + Register(kApplyPowerSignOpName, {1}); + Register(kApplyProximalAdagradOpName, {1}); + Register(kApplyAdaMaxOpName, {1, 2}); + Register(kApplyAdagradV2OpName, {1}, ApplyAdagradV2PreCheck); + Register(kApplyKerasMomentumOpName, {1}, ApplyKerasMomentumPreCheck); + Register(kSparseApplyFtrlOpName, {1, 2}, SparseApplyFtrlPreCheck); + Register(kSparseApplyFtrlV2OpName, {1, 2}, SparseApplyFtrlV2PreCheck); + Register(kSparseApplyAdagradV2OpName, {1}, SparseApplyAdagradV2PreCheck); + Register(kSparseApplyProximalAdagradOpName, {1}); + Register(kSparseApplyAdagradOpName, {1}); + Register(kApplyFtrlV2OpName, {1, 2}); + Register(kApplyMomentumOpName, {1}); + Register(kApplyFtrlOpName, {1, 2}); + Register(kApplyAdamOpName, {1, 2}); + Register(kApplyCenteredRMSPropOpName, {1, 2, 3}); + Register(kApplyAddSignOpName, {1}); + Register(kSparseApplyRMSPropOpName, {1, 2}, SparseApplyRMSPropPreCheck); + Register(kSparseApplyAdadeltaOpName, {1, 2}, SparseApplyAdadeltaPreCheck); + Register(kApplyAdamWithAmsgradOpName, {1, 2}); +} + +InputToOutputRegistry &InputToOutputRegistry::Instance() { + static InputToOutputRegistry instance; + return instance; +} + +void InputToOutputRegistry::Register(const InputToOutputRegister ®) { + auto op_name = reg.op_name(); + if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) { + (void)op_input_to_output_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " input2output register successfully!"; + } +} + +void InputToOutputRegistry::Register(const std::string &op_name, const std::vector &input_indices, + const PreCheckFunc &pre_check_func) { + if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) { + InputToOutputRegister reg(op_name, pre_check_func); + reg.set_input_indices(input_indices); + (void)op_input_to_output_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " input2output register successfully!"; + } +} + +bool InputToOutputRegistry::GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const { + if (op_input_to_output_map_.find(op_name) != op_input_to_output_map_.end()) { + *reg = op_input_to_output_map_.at(op_name); + MS_LOG(DEBUG) << op_name << " input2output find in registry."; + return true; + } + return false; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h new file mode 100644 index 0000000000..45738c289c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ +#include +#include +#include +#include +#include "ir/anf.h" +#include "common/utils.h" + +namespace mindspore { +namespace opt { +using PreCheckFunc = std::function; +class InputToOutputRegister { + public: + explicit InputToOutputRegister( + const std::string &op_name = "", const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; }) + : op_name_(op_name), pre_check_func_(pre_check_func) {} + virtual ~InputToOutputRegister() = default; + + void set_input_indices(const std::vector &input_indices) { input_indices_ = input_indices; } + + const std::vector &input_indices() const { return input_indices_; } + const std::string &op_name() const { return op_name_; } + + private: + std::string op_name_; + std::vector input_indices_; + PreCheckFunc pre_check_func_; +}; + +class InputToOutputRegistry { + public: + static InputToOutputRegistry &Instance(); + void Register(const InputToOutputRegister ®); + void Register( + const std::string &op_name, const std::vector &input_indices, + const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; }); + bool GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const; + + private: + InputToOutputRegistry(); + ~InputToOutputRegistry() = default; + DISABLE_COPY_AND_ASSIGN(InputToOutputRegistry) + std::unordered_map op_input_to_output_map_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.cc new file mode 100644 index 0000000000..fd9fd31f12 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.cc @@ -0,0 +1,266 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h" +#include +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, + std::vector *old_pattern_outputs) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_); + auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_); + + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &users = manager->node_users(); + if (users.find(real_div0) == users.end() || users[real_div0].size() < 2) { + return false; + } + AnfNodeIndexSet real_div0_outputs = users[real_div0]; + auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(), + [&real_div2, &equiv, this](const std::pair &node_index) { + return node_index.first != real_div2 && node_index.second == 1 && + MatchAnotherPattern(node_index.first, equiv); + }); + if (iter == real_div0_outputs.end()) { + return false; + } + + (*old_pattern_outputs).push_back(node); + (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_)); + (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_)); + (*old_pattern_outputs).push_back(iter->first); + + return true; +} + +AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph, + const std::vector &old_pattern_outputs, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + auto prim = std::make_shared(kLambNextMVOpName); + std::vector lamb_next_mv_rule_inputs = {NewValueNode(prim)}; + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input0_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input1_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input2_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input3_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input4_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input5_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input6_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul0_x_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul1_sub_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul2_x_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul3_sub1_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul4_x_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[add2_y_])); + auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs); + MS_EXCEPTION_IF_NULL(lamb_next_mv_rule); + + // Set abstract of new node + AbstractBasePtrList new_abstracts; + (void)std::transform(old_pattern_outputs.begin(), old_pattern_outputs.end(), std::back_inserter(new_abstracts), + [](const AnfNodePtr &out) { return out->abstract(); }); + auto abstract_tuple = std::make_shared(new_abstracts); + MS_EXCEPTION_IF_NULL(abstract_tuple); + lamb_next_mv_rule->set_abstract(abstract_tuple); + + // Create tuple_getitem node for outputs + std::vector lamb_next_mv_rule_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, lamb_next_mv_rule, kLambNextMVRuleOutputNum, &lamb_next_mv_rule_outputs); + + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(old_pattern_outputs[1], lamb_next_mv_rule_outputs[1]); + (void)manager->Replace(old_pattern_outputs[2], lamb_next_mv_rule_outputs[2]); + (void)manager->Replace(old_pattern_outputs[3], lamb_next_mv_rule_outputs[3]); + + return lamb_next_mv_rule_outputs[0]; +} + +bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { + return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) && + IsSameNode(equiv1, equiv2, add2_y_); +} + +const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + std::vector old_pattern_outputs; + if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) { + return nullptr; + } + + return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv); +} + +const BaseRef LambNextMVRuleCond1::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + + auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); + auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); + auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); + auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); + auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); + auto add0 = VectorRef({add0_var_, mul0, mul1}); + auto add1 = VectorRef({add1_var_, mul2, mul3}); + + auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); + auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); + + auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); + auto sqrt0 = VectorRef({prim_rsqrt, add2}); + auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); + + return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); +} + +BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + // Two patterns share: real_div0, real_div1, add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1}); + VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); + return real_div4; +} + +const BaseRef LambNextMVRuleCond2::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + + auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); + auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); + auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); + auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); + auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); + auto add0 = VectorRef({add0_var_, mul0, mul1}); + auto add1 = VectorRef({add1_var_, mul2, mul3}); + + auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); + auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); + + auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); + auto sqrt0 = VectorRef({prim_rsqrt, add2}); + auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); + + return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); +} + +BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + // Two patterns share: real_div0, real_div1, add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); + VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); + return real_div4; +} + +const BaseRef LambNextMVRuleCond3::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + + auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); + auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); + auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); + auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_}); + auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); + auto add0 = VectorRef({add0_var_, mul0, mul1}); + auto add1 = VectorRef({add1_var_, mul2, mul3}); + + auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); + auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); + + auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); + auto sqrt0 = VectorRef({prim_rsqrt, add2}); + auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); + + return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); +} + +BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + // Two patterns share: real_div0, real_div1, add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); + VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); + return real_div4; +} + +const BaseRef LambNextMVRuleCond4::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + + auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); + auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); + auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); + auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); + auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); + auto add0 = VectorRef({add0_var_, mul0, mul1}); + auto add1 = VectorRef({add1_var_, mul2, mul3}); + + auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); + auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); + + auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); + auto sqrt0 = VectorRef({prim_rsqrt, add2}); + auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0}); + + return VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); +} + +BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + // Two patterns share: real_div0, real_div1, add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); + VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); + return real_div4; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h new file mode 100644 index 0000000000..d14ce6e3fe --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ + +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class LambNextMVRule : public MultipleOutputPatternProcessPass { + public: + explicit LambNextMVRule(const std::string &name = "", bool multigraph = true) + : MultipleOutputPatternProcessPass(name, multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + input3_ = std::make_shared(); + input4_ = std::make_shared(); + input5_ = std::make_shared(); + input6_ = std::make_shared(); + mul0_x_ = std::make_shared(); + mul1_sub_ = std::make_shared(); + mul2_x_ = std::make_shared(); + mul3_sub1_ = std::make_shared(); + mul4_x_ = std::make_shared(); + add2_y_ = std::make_shared(); + real_div0_var_ = std::make_shared(std::make_shared(kRealDivOpName)); + real_div1_var_ = std::make_shared(std::make_shared(kRealDivOpName)); + real_div2_var_ = std::make_shared(std::make_shared(prim::kPrimMul->name())); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + } + ~LambNextMVRule() override = default; + const BaseRef DefinePattern() const override = 0; + BaseRef DefineAnotherPattern() const override = 0; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override; + + protected: + bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, + std::vector *old_pattern_outputs) const; + AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector &old_pattern_outputs, + const EquivPtr &equiv) const; + + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr input3_; + VarPtr input4_; + VarPtr input5_; + VarPtr input6_; + VarPtr mul0_x_; + VarPtr mul1_sub_; + VarPtr mul2_x_; + VarPtr mul3_sub1_; + VarPtr mul4_x_; + VarPtr add2_y_; + // nodes which two patterns share, and add2_y_ also. + VarPtr real_div0_var_; + VarPtr real_div1_var_; + // part of output nodes + VarPtr add0_var_; + VarPtr add1_var_; + // other node + VarPtr real_div2_var_; +}; + +class LambNextMVRuleCond1 : public LambNextMVRule { + public: + explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {} + + ~LambNextMVRuleCond1() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVRuleCond2 : public LambNextMVRule { + public: + explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {} + + ~LambNextMVRuleCond2() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVRuleCond3 : public LambNextMVRule { + public: + explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {} + + ~LambNextMVRuleCond3() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVRuleCond4 : public LambNextMVRule { + public: + explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {} + + ~LambNextMVRuleCond4() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc new file mode 100644 index 0000000000..4ef3fa269f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc @@ -0,0 +1,278 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" + +namespace mindspore { +namespace opt { +AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, + const AnfNodePtr &new_node, const AnfNodePtr &add3, + const AnfNodePtr &add5, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(new_node); + MS_EXCEPTION_IF_NULL(add3); + MS_EXCEPTION_IF_NULL(add5); + MS_EXCEPTION_IF_NULL(equiv); + auto add0 = GetAnfNodeByVar(equiv, add0_var_); + MS_EXCEPTION_IF_NULL(add0); + auto add1 = GetAnfNodeByVar(equiv, add1_var_); + MS_EXCEPTION_IF_NULL(add1); + + // Set abstract of new node + AbstractBasePtrList new_node_list; + new_node_list.push_back(add3->abstract()); + new_node_list.push_back(add0->abstract()); + new_node_list.push_back(add1->abstract()); + new_node_list.push_back(add5->abstract()); + auto abstract_tuple = std::make_shared(new_node_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + new_node->set_abstract(abstract_tuple); + // Create tuple_getitem node for outputs + std::vector new_node_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextMVWithDecayOutputNum, &new_node_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(add3, new_node_outputs[0]); + (void)manager->Replace(add0, new_node_outputs[1]); + (void)manager->Replace(add1, new_node_outputs[2]); + return new_node_outputs[3]; +} + +AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, + const AnfNodePtr &add3, const AnfNodePtr &add5, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(add3); + MS_EXCEPTION_IF_NULL(equiv); + // Create new node with all the inputs + auto prim = std::make_shared(kLambNextMVWithDecayOpName); + std::vector new_node_inputs = {NewValueNode(prim)}; + for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { + auto input_node = utils::cast((*equiv)[input_vars_[i]]); + MS_EXCEPTION_IF_NULL(input_node); + new_node_inputs.push_back(input_node); + } + for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) { + auto constant_mul_input_node = utils::cast((*equiv)[constant_mul_input_vars_[i]]); + MS_EXCEPTION_IF_NULL(constant_mul_input_node); + new_node_inputs.push_back(constant_mul_input_node); + } + auto constant_add2_y_node = utils::cast((*equiv)[constant_add2_y_]); + MS_EXCEPTION_IF_NULL(constant_add2_y_node); + new_node_inputs.push_back(constant_add2_y_node); + auto new_node = func_graph->NewCNode(new_node_inputs); + return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv); +} + +bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { + return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) && + IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_); +} + +const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_); + MS_EXCEPTION_IF_NULL(mul4); + // Get add3 and match the add3 pattern + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(mul4) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input"; + } + AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4]; + auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(), + [&node, &equiv, this](const std::pair &node_index) { + return node_index.first != node && MatchAnotherPattern(node_index.first, equiv); + }); + if (iter != mul4_outputs.end()) { + return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv); + } + return nullptr; +} + +BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); + VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); + return add5; +} + +BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1}); + VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); + return add5; +} + +BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); + VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); + return add5; +} + +BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); + VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); + return add5; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h new file mode 100644 index 0000000000..23114c37ee --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass { + public: + explicit LambNextMVWithDecayRule(const std::string &name = "", bool multigraph = true) + : MultipleOutputPatternProcessPass(name, multigraph) { + for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { + input_vars_.push_back(std::make_shared()); + } + for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) { + constant_mul_input_vars_.push_back(std::make_shared()); + } + constant_add2_y_ = std::make_shared(); + mul4_var_ = std::make_shared(std::make_shared(prim::kPrimMul->name())); + real_div0_var_ = std::make_shared(std::make_shared(kRealDivOpName)); + real_div1_var_ = std::make_shared(std::make_shared(kRealDivOpName)); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + } + + ~LambNextMVWithDecayRule() override = default; + const BaseRef DefinePattern() const override = 0; + BaseRef DefineAnotherPattern() const override = 0; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override; + + protected: + AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node, + const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const; + AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3, + const AnfNodePtr &add5, const EquivPtr &equiv) const; + std::vector input_vars_; + std::vector constant_mul_input_vars_; + // nodes which two patterns share + VarPtr constant_add2_y_; + VarPtr mul4_var_; + VarPtr real_div0_var_; + VarPtr real_div1_var_; + // part of output nodes + VarPtr add0_var_; + VarPtr add1_var_; +}; + +class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond1(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {} + + ~LambNextMVWithDecayRuleCond1() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond2(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {} + + ~LambNextMVWithDecayRuleCond2() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond3(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {} + + ~LambNextMVWithDecayRuleCond3() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond4(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond4", multigraph) {} + + ~LambNextMVWithDecayRuleCond4() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc new file mode 100644 index 0000000000..f21433b3c6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h" + +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" + +namespace mindspore { +namespace opt { +namespace { +std::tuple GetSharedNodes(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto add3 = node->cast(); + MS_EXCEPTION_IF_NULL(add3); + if (add3->inputs().size() < kAddInputNum) { + MS_LOG(EXCEPTION) << "The input size of Add3 is less than " << kAddInputNum; + } + auto real_div2_anf = add3->input(1); + MS_EXCEPTION_IF_NULL(real_div2_anf); + auto real_div2 = real_div2_anf->cast(); + MS_EXCEPTION_IF_NULL(real_div2); + if (real_div2->inputs().size() < kRealDivInputNum) { + MS_LOG(EXCEPTION) << "The input size of RealDiv2 is less than " << kRealDivInputNum; + } + auto sqrt0_anf = real_div2->input(2); + MS_EXCEPTION_IF_NULL(sqrt0_anf); + auto sqrt0 = sqrt0_anf->cast(); + MS_EXCEPTION_IF_NULL(sqrt0); + if (sqrt0->inputs().size() < kRsqrtInputNum) { + MS_LOG(EXCEPTION) << "The input size of Sqrt0 is less than " << kSqrtInputNum; + } + auto add2_anf = sqrt0->input(1); + MS_EXCEPTION_IF_NULL(add2_anf); + auto add2 = add2_anf->cast(); + if (add2->inputs().size() < kAddInputNum) { + MS_LOG(EXCEPTION) << "The input size of Add2 is less than " << kAddInputNum; + } + return std::make_tuple(add3->input(2), real_div2->input(1), add2->input(1), add2->input(2)); +} + +bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfNodePtr &real_div0, + const AnfNodePtr &real_div1, const AnfNodePtr &add2_y) { + if (node == nullptr || !node->isa()) { + return false; + } + auto add5 = node->cast(); + if (AnfAlgo::GetCNodeName(add5) != prim::kPrimTensorAdd->name() || add5->inputs().size() != kAddInputNum) { + return false; + } + auto real_div4_anf = add5->input(1); + if (real_div4_anf == nullptr || !real_div4_anf->isa()) { + return false; + } + auto real_div4 = real_div4_anf->cast(); + if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || real_div4->inputs().size() != kRealDivInputNum) { + return false; + } + auto add4_anf = real_div4->input(2); + if (add4_anf == nullptr || !add4_anf->isa()) { + return false; + } + auto add4 = add4_anf->cast(); + if (AnfAlgo::GetCNodeName(add4) != prim::kPrimTensorAdd->name() || add4->inputs().size() != kAddInputNum) { + return false; + } + auto sqrt1_anf = add4->input(1); + if (sqrt1_anf == nullptr || !sqrt1_anf->isa()) { + return false; + } + auto sqrt1 = sqrt1_anf->cast(); + if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || sqrt1->inputs().size() != kSqrtInputNum) { + return false; + } + return add5->input(2) == mul4 && real_div4->input(1) == real_div0 && sqrt1->input(1) == real_div1 && + *add4->input(2) == *add2_y; +} + +std::tuple GetAdd0Add1Nodes(const AnfNodePtr &real_div0_anf, const AnfNodePtr &real_div1_anf) { + MS_EXCEPTION_IF_NULL(real_div0_anf); + MS_EXCEPTION_IF_NULL(real_div1_anf); + auto real_div0 = real_div0_anf->cast(); + auto real_div1 = real_div1_anf->cast(); + MS_EXCEPTION_IF_NULL(real_div0); + MS_EXCEPTION_IF_NULL(real_div1); + if (real_div0->inputs().size() != kRealDivInputNum) { + MS_LOG(EXCEPTION) << "RealDiv0 has wrong input size"; + } + if (real_div1->inputs().size() != kRealDivInputNum) { + MS_LOG(EXCEPTION) << "RealDiv1 has wrong input size"; + } + return std::make_tuple(real_div0->input(1), real_div1->input(1)); +} +} // namespace + +std::vector LambNextMVWithDecayV1Rule::GetFusionNodeInputs(const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(equiv); + auto i0 = utils::cast((*equiv)[input0_]); + auto i1 = utils::cast((*equiv)[input1_]); + auto i2 = utils::cast((*equiv)[input2_]); + auto i3 = utils::cast((*equiv)[input3_]); + auto i4 = utils::cast((*equiv)[input4_]); + auto i5 = utils::cast((*equiv)[input5_]); + auto i6 = utils::cast((*equiv)[input6_]); + auto i7 = utils::cast((*equiv)[mul0_x_]); + auto i8 = utils::cast((*equiv)[mul1_sub_]); + auto i9 = utils::cast((*equiv)[mul2_x_]); + auto i10 = utils::cast((*equiv)[mul3_sub1_]); + auto i11 = utils::cast((*equiv)[mul4_x_]); + auto i12 = utils::cast((*equiv)[add2_y_]); + auto prim = std::make_shared(kLambNextMVWithDecayV1OpName); + return {NewValueNode(prim), i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12}; +} + +const BaseRef LambNextMVWithDecayV1Rule::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef add1({prim::kPrimTensorAdd, mul2, mul3}); + VectorRef real_div1({prim_real_div, add1, input2_}); + VectorRef add2({prim::kPrimTensorAdd, real_div1, add2_y_}); + VectorRef mul0({prim::kPrimMul, mul0_x_, input4_}); + VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_}); + VectorRef sqrt0({prim_rsqrt, add2}); + VectorRef add0({prim::kPrimTensorAdd, mul0, mul1}); + VectorRef real_div0({prim_real_div, add0, input5_}); + VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input6_}); + VectorRef add3({prim::kPrimTensorAdd, real_div2, mul4}); + return add3; +} + +const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + if (func_graph == nullptr || node == nullptr || equiv == nullptr) { + return nullptr; + } + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + AnfNodePtr mul4 = nullptr; + AnfNodePtr real_div0 = nullptr; + AnfNodePtr real_div1 = nullptr; + AnfNodePtr add2_y = nullptr; + std::tie(mul4, real_div0, real_div1, add2_y) = GetSharedNodes(node); + + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(mul4) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input"; + } + AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4]; + auto iter = std::find_if( + mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(), + [&node, &mul4, &real_div0, &real_div1, &add2_y](const std::pair &node_index) { + return node_index.first != node && MatchAdd5Pattern(node_index.first, mul4, real_div0, real_div1, add2_y); + }); + if (iter == mul4_output_node_index_set.end()) { + return nullptr; + } + + std::vector inputs = GetFusionNodeInputs(equiv); + auto fusion_node = func_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_scope(node->scope()); + + AnfNodePtr add0 = nullptr; + AnfNodePtr add1 = nullptr; + AnfNodePtr add5 = iter->first; + std::tie(add0, add1) = GetAdd0Add1Nodes(real_div0, real_div1); + auto types = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(add0, 0), + AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add5, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0), AnfAlgo::GetOutputInferShape(add0, 0), + AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add5, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); + + std::vector fusion_node_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, fusion_node, kLambNextMVWithDecayV1OutputNum, &fusion_node_outputs); + if (fusion_node_outputs.size() != kLambNextMVWithDecayV1OutputNum) { + MS_LOG(ERROR) << "create multiple outputs for fusion node fail!"; + return nullptr; + } + + (void)manager->Replace(add0, fusion_node_outputs[1]); + (void)manager->Replace(add1, fusion_node_outputs[2]); + (void)manager->Replace(add5, fusion_node_outputs[3]); + return fusion_node_outputs[0]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h new file mode 100644 index 0000000000..58f05c37ba --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +class LambNextMVWithDecayV1Rule : public PatternProcessPass { + public: + explicit LambNextMVWithDecayV1Rule(bool multigraph = true) + : PatternProcessPass("lamb_next_mv_with_decay_v1_rule", multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + input3_ = std::make_shared(); + input4_ = std::make_shared(); + input5_ = std::make_shared(); + input6_ = std::make_shared(); + mul0_x_ = std::make_shared(); + mul1_sub_ = std::make_shared(); + mul2_x_ = std::make_shared(); + mul3_sub1_ = std::make_shared(); + mul4_x_ = std::make_shared(); + add2_y_ = std::make_shared(); + } + + ~LambNextMVWithDecayV1Rule() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::vector GetFusionNodeInputs(const EquivPtr &equiv) const; + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr input3_; + VarPtr input4_; + VarPtr input5_; + VarPtr input6_; + VarPtr mul0_x_; + VarPtr mul1_sub_; + VarPtr mul2_x_; + VarPtr mul3_sub1_; + VarPtr mul4_x_; + VarPtr add2_y_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.cc new file mode 100644 index 0000000000..03bc1e0484 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h" +#include +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + std::vector new_node_inputs; + auto prim = std::make_shared(kLambNextRightOpName); + MS_EXCEPTION_IF_NULL(prim); + new_node_inputs.push_back(NewValueNode(prim)); + auto input0 = utils::cast((*equiv)[input0_]); + MS_EXCEPTION_IF_NULL(input0); + new_node_inputs.push_back(input0); + auto input1 = utils::cast((*equiv)[input1_]); + MS_EXCEPTION_IF_NULL(input1); + new_node_inputs.push_back(input1); + auto mul2_x = utils::cast((*equiv)[mul2_x_]); + MS_EXCEPTION_IF_NULL(mul2_x); + new_node_inputs.push_back(mul2_x); + auto mul3_x = utils::cast((*equiv)[mul3_x_]); + MS_EXCEPTION_IF_NULL(mul3_x); + new_node_inputs.push_back(mul3_x); + auto true_div1_recip = utils::cast((*equiv)[true_div1_recip_]); + MS_EXCEPTION_IF_NULL(true_div1_recip); + new_node_inputs.push_back(true_div1_recip); + auto add2_y = utils::cast((*equiv)[add2_y_]); + MS_EXCEPTION_IF_NULL(add2_y); + new_node_inputs.push_back(add2_y); + auto new_node = func_graph->NewCNode(new_node_inputs); + return new_node; +} + +const BaseRef LambNextRightRule::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); + VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); + return VectorRef( + {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); +} + +const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + auto new_node = CreateLambNextRightNode(func_graph, equiv); + MS_EXCEPTION_IF_NULL(new_node); + // Set abstract of new node + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; + } + auto add1 = utils::cast(iter_add1->second); + MS_EXCEPTION_IF_NULL(add1); + AbstractBasePtrList new_node_abstract_list; + new_node_abstract_list.push_back(add1->abstract()); + new_node_abstract_list.push_back(node->abstract()); + auto abstract_tuple = std::make_shared(new_node_abstract_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + new_node->set_abstract(abstract_tuple); + // Create tuple_getitem node for outputs + std::vector new_node_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextRightOutputNum, &new_node_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(add1, new_node_outputs[0]); + return new_node_outputs[1]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h new file mode 100644 index 0000000000..67687cc037 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +class LambNextRightRule : public PatternProcessPass { + public: + explicit LambNextRightRule(bool multigraph = true) + : PatternProcessPass("lamb_next_right_rule", multigraph), + input0_(std::make_shared()), + input1_(std::make_shared()), + mul2_x_(std::make_shared()), + mul3_x_(std::make_shared()), + true_div1_recip_(std::make_shared()), + add2_y_(std::make_shared()), + add1_var_(std::make_shared(std::make_shared(prim::kPrimTensorAdd->name()))) {} + + ~LambNextRightRule() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + AnfNodePtr CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; + + VarPtr input0_; + VarPtr input1_; + VarPtr mul2_x_; + VarPtr mul3_x_; + VarPtr true_div1_recip_; + VarPtr add2_y_; + VarPtr add1_var_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc new file mode 100644 index 0000000000..8e38c3cc2e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "common/utils.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +const BaseRef LambUpdateWithLRRuleFusion::DefinePattern() const { + auto real_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(real_div); + auto greater = std::make_shared(kGreaterOpName); + MS_EXCEPTION_IF_NULL(greater); + + VectorRef pattern_real_div0({real_div, input1_, input2_}); + VectorRef pattern_greater0({greater, input0_, constant_greater_max_}); + VectorRef pattern_greater1({greater, input1_, constant_greater_max_}); + VectorRef pattern_select0({prim::kPrimSelect, pattern_greater0, pattern_real_div0, constant_select_}); + VectorRef pattern_select1({prim::kPrimSelect, pattern_greater1, pattern_select0, constant_select_}); + VectorRef pattern_minimum0({prim::kPrimMinimum, pattern_select1, constant_minimum_}); + VectorRef pattern_maximum0({prim::kPrimMaximum, pattern_minimum0, constant_greater_max_}); + VectorRef pattern_mul0({prim::kPrimMul, pattern_maximum0, input3_}); + VectorRef pattern_mul1({prim::kPrimMul, pattern_mul0, input4_}); + VectorRef pattern({prim::kPrimSub, input5_, pattern_mul1}); + return pattern; +} + +const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + auto input0 = utils::cast((*equiv)[input0_]); + auto input1 = utils::cast((*equiv)[input1_]); + auto input2 = utils::cast((*equiv)[input2_]); + auto input3 = utils::cast((*equiv)[input3_]); + auto input4 = utils::cast((*equiv)[input4_]); + auto input5 = utils::cast((*equiv)[input5_]); + auto input6 = utils::cast((*equiv)[constant_greater_max_]); + auto input7 = utils::cast((*equiv)[constant_select_]); + auto input8 = utils::cast((*equiv)[constant_minimum_]); + + auto prim = std::make_shared(kLambUpdateWithLROpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = { + NewValueNode(prim), input0, input1, input2, input3, input4, input5, input6, input7, input8}; + auto lamb_update_with_lr = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(lamb_update_with_lr); + + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lamb_update_with_lr.get()); + lamb_update_with_lr->set_scope(node->scope()); + return lamb_update_with_lr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h new file mode 100644 index 0000000000..5ea01ccf65 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class LambUpdateWithLRRuleFusion : public PatternProcessPass { + public: + explicit LambUpdateWithLRRuleFusion(bool multigraph = true) + : PatternProcessPass("lamb_update_with_lr_rule_fusion", multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + input3_ = std::make_shared(); + input4_ = std::make_shared(); + input5_ = std::make_shared(); + constant_greater_max_ = std::make_shared(); + constant_select_ = std::make_shared(); + constant_minimum_ = std::make_shared(); + } + ~LambUpdateWithLRRuleFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr input3_; + VarPtr input4_; + VarPtr input5_; + VarPtr constant_greater_max_; + VarPtr constant_select_; + VarPtr constant_minimum_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.cc new file mode 100644 index 0000000000..59511a611a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h" +#include +#include +#include +#include "utils/utils.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef LambUpdateWithLrV2::DefinePattern() const { + const auto prim_greater = std::make_shared(kGreaterOpName); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + + VectorRef greater0({prim_greater, input_varptr_[0], input_varptr_[5]}); + VectorRef greater1({prim_greater, input_varptr_[1], input_varptr_[5]}); + VectorRef real_div0({prim_deal_div, input_varptr_[0], input_varptr_[1]}); + VectorRef select0({prim::kPrimSelect, greater1, real_div0, input_varptr_[6]}); + VectorRef select1({prim::kPrimSelect, greater0, select0, input_varptr_[6]}); + VectorRef mul0({prim::kPrimMul, select1, input_varptr_[2]}); + VectorRef mul1({prim::kPrimMul, mul0, input_varptr_[3]}); + + return VectorRef({prim::kPrimSub, input_varptr_[4], mul1}); +} + +const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + auto prim = std::make_shared(kLambUpdateWithLrV2OpName); + std::vector inputs = {NewValueNode(prim)}; + (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs), + [&equiv](const VarPtr &in) { return utils::cast((*equiv)[in]); }); + auto lamb_update_with_lr_v2 = func_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(lamb_update_with_lr_v2); + lamb_update_with_lr_v2->set_abstract(node->abstract()); + + return lamb_update_with_lr_v2; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h new file mode 100644 index 0000000000..c5396178a5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ + +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class LambUpdateWithLrV2 : public PatternProcessPass { + public: + explicit LambUpdateWithLrV2(bool multigraph = true) : PatternProcessPass("lamb_update_with_lr_v2", multigraph) { + for (size_t i = 0; i < kLambUpdateWithLrV2InputNum - 1; ++i) { + input_varptr_.push_back(std::make_shared()); + } + } + ~LambUpdateWithLrV2() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::vector input_varptr_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc new file mode 100644 index 0000000000..fa1e92120d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc @@ -0,0 +1,162 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +using common::SafeCStr; +namespace { +void GetOutputCastNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &node, std::vector *cast_nodes) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(node) == manager->node_users().end()) { + return; + } + for (const auto &node_index : manager->node_users()[node]) { + AnfNodePtr output = node_index.first; + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + if (AnfAlgo::GetCNodeName(output_cnode) != prim::kPrimTupleGetItem->name()) { + MS_LOG(EXCEPTION) << "The output of node " << node->DebugString() << " should be " + << prim::kPrimTupleGetItem->name(); + } + if (manager->node_users().find(output) == manager->node_users().end() || + manager->node_users()[output].size() != 1) { + continue; + } + AnfNodePtr transitive_output = manager->node_users()[output].begin()->first; + MS_EXCEPTION_IF_NULL(transitive_output); + auto transitive_output_cnode = transitive_output->cast(); + MS_EXCEPTION_IF_NULL(transitive_output_cnode); + if (AnfAlgo::GetCNodeName(transitive_output_cnode) == prim::kPrimCast->name()) { + cast_nodes->push_back(transitive_output_cnode); + } + } +} + +bool CheckKernelBuildInfo(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_info) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(kernel_info); + for (size_t i = 0; i < kernel_info->GetInputNum(); ++i) { + if (kernel_info->GetInputDeviceType(i) != kNumberTypeFloat16 || + kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(cnode, i)) { + return false; + } + } + for (size_t i = 0; i < kernel_info->GetOutputNum(); ++i) { + if (kernel_info->GetOutputDeviceType(i) != kNumberTypeFloat32 || + kernel_info->GetOutputFormat(i) != AnfAlgo::GetOutputFormat(cnode, i)) { + return false; + } + } + return true; +} + +bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + std::vector *cast_nodes) { + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::HasNodeAttr(kAttrShapeGamma, cnode)) { + MS_LOG(INFO) << "The node " << cnode->DebugString() << " has no " << kAttrShapeGamma << " attr"; + return false; + } + if (cnode->inputs().size() != kLayerNormBetaGammaBackpropInputNum) { + MS_LOG(INFO) << "The node " << cnode->DebugString() << " inputs num is not equal to " + << kLayerNormBetaGammaBackpropInputNum; + return false; + } + if (AnfAlgo::GetOutputTensorNum(cnode) != kLayerNormBetaGammaBackpropOutputNum) { + MS_LOG(INFO) << "The node " << cnode->DebugString() << " outputs num is not equal to " + << kLayerNormBetaGammaBackpropOutputNum; + return false; + } + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) { + if (AnfAlgo::GetInputDeviceDataType(cnode, i) != kNumberTypeFloat16) { + MS_LOG(INFO) << "The data type of node " << cnode->DebugString() << " input " << i << " is not float16"; + return false; + } + } + GetOutputCastNodes(func_graph, cnode, cast_nodes); + if (cast_nodes->size() != kLayerNormBetaGammaBackpropOutputNum) { + MS_LOG(INFO) << "The num of cast node in node " << cnode->DebugString() << " outputs is not equal to " + << kLayerNormBetaGammaBackpropOutputNum; + return false; + } + for (const auto &cast : *cast_nodes) { + if (AnfAlgo::GetInputDeviceDataType(cast, 0) != kNumberTypeFloat16 || + AnfAlgo::GetOutputDeviceDataType(cast, 0) != kNumberTypeFloat32) { + MS_LOG(INFO) << "The cast " << cast->DebugString() << " should be fp16->fp32"; + return false; + } + } + return true; +} +} // namespace + +const BaseRef LayerNormBetaGammaBackpropFusion::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + const auto prim = std::make_shared(kLayerNormBetaGammaBackpropOpName); + return VectorRef({prim, Xs}); +} + +const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + if (AnfAlgo::IsGraphKernel(node)) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::vector cast_nodes; + if (!CheckLayernormBetaGammaBackprop(func_graph, cnode, &cast_nodes)) { + return nullptr; + } + std::vector> kernel_info_list; + MS_EXCEPTION_IF_NULL(kernel_query_); + kernel_query_->Query(cnode, &kernel_info_list); + auto alternative_kernel_build_info = + std::find_if(kernel_info_list.begin(), kernel_info_list.end(), + [&cnode](const kernel::KernelBuildInfoPtr &candidate_kernel_build_info) { + return CheckKernelBuildInfo(cnode, candidate_kernel_build_info); + }); + if (alternative_kernel_build_info == kernel_info_list.end()) { + MS_LOG(INFO) << "Can not find alternative kernel build info for node " << node->DebugString(); + return nullptr; + } + AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_build_info, cnode.get()); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + // The cast_nodes size has been checked above. + MS_EXCEPTION_IF_NULL(cast_nodes[0]); + MS_EXCEPTION_IF_NULL(cast_nodes[1]); + if (cast_nodes[0]->inputs().size() != kCastInputNum) { + MS_LOG(EXCEPTION) << "The cast0 " << cast_nodes[0]->DebugString() << " input size should be " << kCastInputNum; + } + (void)manager->Replace(cast_nodes[0], cast_nodes[0]->input(1)); + if (cast_nodes[1]->inputs().size() != kCastInputNum) { + MS_LOG(EXCEPTION) << "The cast1 " << cast_nodes[1]->DebugString() << " input size should be " << kCastInputNum; + } + (void)manager->Replace(cast_nodes[1], cast_nodes[1]->input(1)); + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h new file mode 100644 index 0000000000..5bf1608143 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class LayerNormBetaGammaBackpropFusion : public PatternProcessPass { + public: + explicit LayerNormBetaGammaBackpropFusion(bool multigraph = true) + : PatternProcessPass("layer_norm_beta_gamma_backprop_fusion", multigraph), + kernel_query_(std::make_shared()) {} + + ~LayerNormBetaGammaBackpropFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + KernelQueryPtr kernel_query_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc new file mode 100644 index 0000000000..fdd390677a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h" +#include +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kMatMulInputIndex = 1; +constexpr size_t kBiasInputIndex = 2; +} // namespace + +const BaseRef MatmulBiasaddFusion::DefinePattern() const { + VarPtr X0 = std::make_shared(); + VarPtr X1 = std::make_shared(); + VarPtr X2 = std::make_shared(); + const auto prim_bias_add = std::make_shared(kBiasAddOpName); + return VectorRef({prim_bias_add, VectorRef({prim::kPrimMatMul, X0, X1}), X2}); +} + +const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + CheckCNodeInputSize(cnode, kBiasAddInputNum); + AnfNodePtr matmul = cnode->input(kMatMulInputIndex); + MS_EXCEPTION_IF_NULL(matmul); + auto matmul_cnode = matmul->cast(); + MS_EXCEPTION_IF_NULL(matmul_cnode); + matmul_cnode->add_input(cnode->input(kBiasInputIndex)); + AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul); + return matmul; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h new file mode 100644 index 0000000000..8c762435a9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class MatmulBiasaddFusion : public PatternProcessPass { + public: + explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {} + + ~MatmulBiasaddFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc new file mode 100644 index 0000000000..90c5ac19a9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h" +#include +#include +#include +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kAccumIndex = 1; +bool CheckValueNodeInputOfMul(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + std::vector mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0); + return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1); +} +} // namespace + +const BaseRef MomentumLossscaleFusion::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VarPtr X0 = std::make_shared(); + VarPtr X1 = std::make_shared(); + VarPtr X2 = std::make_shared(); + VarPtr X4 = std::make_shared(); + return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4}); +} + +const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + CheckCNodeInputSize(cnode, kApplyMomentumInputNum); + AnfNodePtr mul = cnode->input(4); + MS_EXCEPTION_IF_NULL(mul); + auto mul_cnode = mul->cast(); + MS_EXCEPTION_IF_NULL(mul_cnode); + CheckCNodeInputSize(mul_cnode, kMulInputNum); + size_t value_node_index = 0; + for (size_t i = 1; i < kMulInputNum; ++i) { + if (CheckValueNodeInputOfMul(mul_cnode->input(i))) { + value_node_index = i; + break; + } + } + if (value_node_index == 0) { + MS_LOG(DEBUG) << "The Mul " << mul->DebugString() << " to be fused must has a scalar constant input"; + return nullptr; + } + auto new_prim = std::make_shared(kFusedMulApplyMomentumOpName); + std::vector new_node_inputs{NewValueNode(new_prim), + cnode->input(1), + cnode->input(2), + cnode->input(3), + mul_cnode->input(kMulInputNum - value_node_index), + cnode->input(5), + mul_cnode->input(value_node_index)}; + auto new_node = func_graph->NewCNode(new_node_inputs); + MS_EXCEPTION_IF_NULL(new_node); + AnfAlgo::CopyNodeAttrs(node, new_node); + auto input_names_value = AnfAlgo::GetNodeAttr>(new_node, kAttrInputNames); + input_names_value[3] = "x1"; + input_names_value.emplace_back("x2"); + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node); + new_node->set_abstract(node->abstract()); + new_node->set_scope(node->scope()); + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h new file mode 100644 index 0000000000..8d36684a11 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class MomentumLossscaleFusion : public PatternProcessPass { + public: + explicit MomentumLossscaleFusion(bool multigraph = true) + : PatternProcessPass("momentum_lossscale_fusion", multigraph) {} + + ~MomentumLossscaleFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc new file mode 100644 index 0000000000..2d766891a0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h" +#include +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(add); + + for (size_t index = 1; index < add->size(); ++index) { + auto input = add->input(index); + MS_EXCEPTION_IF_NULL(input); + if (input->isa()) { + auto cnode = input->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) { + if (!opt::IsUsedByOthers(graph, cnode)) { + auto full_name = cnode->fullname_with_scope(); + // exclude lamb and adam, and only work in bert + if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") || + std::string::npos == full_name.find("bert")) { + MS_LOG(INFO) << "Mul is in adam or lamb or not a bert network, quit fusion"; + return false; + } + + *mul = cnode; + *mul_index = index; + return true; + } + } + } + } + return false; +} +} // namespace +const BaseRef MulAddFusion::DefinePattern() const { + VarPtr x = std::make_shared(); + VarPtr y = std::make_shared(); + VectorRef pattern({prim::kPrimTensorAdd, x, y}); + return pattern; +} + +const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + if (graph == nullptr || node == nullptr) { + return nullptr; + } + auto add = node->cast(); + if (add == nullptr || add->inputs().size() != kAddInputNum) { + return nullptr; + } + CNodePtr mul = nullptr; + size_t mul_index = 0; + if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) { + MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs"; + return nullptr; + } + + auto prim = std::make_shared(kFusedMulAddOpName); + std::vector inputs = {NewValueNode(prim)}; + for (size_t index = 1; index < mul->size(); ++index) { + inputs.push_back(mul->input(index)); + } + auto another_input_node = add->input(add->size() - mul_index); + if (another_input_node->isa() && + AnfAlgo::GetCNodeName(another_input_node) == prim::kPrimTupleGetItem->name()) { + MS_LOG(INFO) << "Add's another input node has multiple outputs, do not fuse"; + return nullptr; + } + inputs.push_back(another_input_node); + auto fusion_node = graph->NewCNode(inputs); + fusion_node->set_scope(add->scope()); + fusion_node->set_abstract(add->abstract()); + return fusion_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h new file mode 100644 index 0000000000..0ad13e10e6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class MulAddFusion : public PatternProcessPass { + public: + explicit MulAddFusion(bool multigraph = true) : PatternProcessPass("mul_add_fusion", multigraph) {} + ~MulAddFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc new file mode 100644 index 0000000000..3567864e2f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h" +#include +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const CNodePtr &addn, + const size_t &lossscale_input_index) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mul); + MS_EXCEPTION_IF_NULL(addn); + auto prim = std::make_shared(kFusedMulAddNOpName); + std::vector inputs = {NewValueNode(prim)}; + inputs.push_back(mul->input(kMulInputNum - lossscale_input_index)); + inputs.push_back(addn->input(2)); + // scalar input should be 3rd input + inputs.push_back(mul->input(lossscale_input_index)); + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_scope(addn->scope()); + fusion_node->set_abstract(addn->abstract()); + return fusion_node; +} +} // namespace + +const BaseRef MulAddNFusion::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Y = std::make_shared(); + VarPtr Z = std::make_shared(); + + VectorRef mul({prim::kPrimMul, X, Z}); + VectorRef addn({prim::kPrimAddN, mul, Y}); + return addn; +} + +const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + if (graph == nullptr || node == nullptr || equiv == nullptr) { + return nullptr; + } + + auto addn = node->cast(); + if (addn == nullptr || addn->inputs().size() != kAddNInputNum) { + return nullptr; + } + auto mul_anf = addn->input(1); + if (mul_anf == nullptr) { + return nullptr; + } + auto mul = mul_anf->cast(); + if (mul == nullptr || mul->inputs().size() != kMulInputNum) { + return nullptr; + } + if (IsUsedByOthers(graph, mul)) { + MS_LOG(DEBUG) << "Mul is used by more then two nodes, cannot fuse"; + return nullptr; + } + + size_t lossscale_input_index = 1; + for (size_t index = 1; index < mul->inputs().size(); ++index) { + auto input_node = mul->input(index); + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa()) { + lossscale_input_index = index; + break; + } + } + auto constant_shape = AnfAlgo::GetOutputInferShape(mul->input(lossscale_input_index), 0); + if (!(constant_shape.size() == 0 || (constant_shape.size() == 1 && constant_shape[0] == 1))) { + MS_LOG(DEBUG) << "The const input of Mul node must be scalar or shape=(1,), but shape size is " + << constant_shape.size() << " and shape[0] is " << constant_shape[0]; + return nullptr; + } + + return CreateFusionNode(graph, mul, addn, lossscale_input_index); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h new file mode 100644 index 0000000000..484cb75237 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class MulAddNFusion : public PatternProcessPass { + public: + explicit MulAddNFusion(bool multigraph = true) : PatternProcessPass("mul_addn_fusion", multigraph) {} + ~MulAddNFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc new file mode 100644 index 0000000000..0c2667e4d9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +namespace { +const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag, + std::vector *trans_road) { + if (node == nullptr) { + MS_LOG(ERROR) << "nullptr"; + return nullptr; + } + if (node->isa()) { + auto cnode = node->cast(); + auto op_name = AnfAlgo::GetCNodeName(cnode); + auto manager = func_graph->manager(); + if (manager == nullptr) { + return nullptr; + } + if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() || + op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) { + auto users = manager->node_users()[node]; + if (users.size() > 1 && !first_flag) { + return nullptr; + } + trans_road->push_back(cnode); + first_flag = false; + auto next_node = AnfAlgo::GetInputNode(cnode, 0); + if (next_node->isa() || next_node->isa()) { + return next_node; + } + return ParamTransRoad(func_graph, next_node, first_flag, trans_road); + } + } else if (node->isa() || node->isa()) { + return node; + } + return nullptr; +} + +kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type, + TypeId output_type) { + MS_EXCEPTION_IF_NULL(cast); + auto kernel_info = dynamic_cast(cast->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto cast_build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(cast_build_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsFormat({format}); + builder.SetInputsFormat({format}); + builder.SetInputsDeviceType({input_type}); + builder.SetOutputsDeviceType({output_type}); + builder.SetKernelType(cast_build_info->kernel_type()); + builder.SetFusionType(cast_build_info->fusion_type()); + builder.SetProcessor(cast_build_info->processor()); + return builder.Build(); +} +} // namespace +bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Func graph is nullptr"; + return false; + } + auto manager = func_graph->manager(); + if (manager == nullptr) { + return false; + } + std::vector node_list = TopoSort(func_graph->get_return()); + bool changed = false; + for (auto node : node_list) { + if (node == nullptr || !node->isa()) { + continue; + } + auto cnode = node->cast(); + auto node_name = AnfAlgo::GetCNodeName(cnode); + if (node_name == prim::kPrimCast->name() || node_name == prim::kPrimTranspose->name() || + node_name == prim::kPrimReshape->name() || node_name == kTransDataOpName) { + MS_LOG(DEBUG) << "Skip trans op"; + continue; + } + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { + std::vector trans_road; + bool first_flag = true; + auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road); + if (final_node != nullptr && trans_road.size() == 3 && AnfAlgo::GetCNodeName(trans_road[0]) == kTransDataOpName && + AnfAlgo::GetCNodeName(trans_road[1]) == prim::kPrimCast->name() && + AnfAlgo::GetCNodeName(trans_road[2]) == kTransDataOpName) { + auto cur_transop = trans_road[0]; + auto format = AnfAlgo::GetOutputFormat(cur_transop, 0); + auto dtype = AnfAlgo::GetOutputDeviceDataType(cur_transop, 0); + auto param_format = AnfAlgo::GetOutputFormat(final_node, 0); + auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0); + + auto cast = trans_road[1]; + if (param_format == format && param_dtype != dtype) { + AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get()); + manager->Replace(trans_road[2], final_node); + manager->Replace(cur_transop, cast); + } + changed = true; + } + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h new file mode 100644 index 0000000000..0479fd3d63 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" + +namespace mindspore { +namespace opt { +class ParameterTransOpFusion : public Pass { + public: + explicit ParameterTransOpFusion(size_t groups = 1) : Pass("Parameter_and_transop_fusion"), groups_(groups) {} + ~ParameterTransOpFusion() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + size_t groups_ = 1; +}; +} // namespace opt +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc new file mode 100644 index 0000000000..ebaa429ebf --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +void DoRefresh(const CNodePtr &cnode) { + if (cnode == nullptr) { + MS_LOG(EXCEPTION) << "node is nullptr"; + } + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { + auto input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index); + if (input_kernel_node->isa()) { + std::shared_ptr builder = + std::make_shared(); + auto cnode_input_format = AnfAlgo::GetInputFormat(cnode, input_index); + auto kernel_node_format = AnfAlgo::GetOutputFormat(input_kernel_node, 0); + auto dtype = AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0); + if (kernel_node_format != cnode_input_format) { + builder->SetOutputsFormat({cnode_input_format}); + builder->SetOutputsDeviceType({dtype}); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); + } + } + } +} + +bool RefreshParameterFormat::Run(const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + MS_LOG(ERROR) << "func_graph is nullptr."; + return false; + } + std::vector node_list = TopoSort(func_graph->get_return()); + for (auto node : node_list) { + if (node == nullptr || !node->isa()) { + continue; + } + auto cnode = node->cast(); + if (cnode == nullptr) { + continue; + } + auto node_name = AnfAlgo::GetCNodeName(cnode); + if (node_name == kBNTrainingUpdateOpName) { + DoRefresh(cnode); + } + } + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h new file mode 100644 index 0000000000..122bdf55ca --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" + +namespace mindspore { +namespace opt { +class RefreshParameterFormat : public Pass { + public: + explicit RefreshParameterFormat(size_t groups = 1) : Pass("refresh_parameter_format"), groups_(groups) {} + ~RefreshParameterFormat() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + size_t groups_ = 1; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc new file mode 100644 index 0000000000..6f48eabbc5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef RemoveReshapePair::DefinePattern() const { + VarPtr X = std::make_shared(); + MS_EXCEPTION_IF_NULL(X); + return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})}); +} + +const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_op_1); + // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly + if (IsUsedByOthers(func_graph, reshape_op_1)) { + return nullptr; + } + auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_op_2); + if (IsUsedByOthers(func_graph, reshape_op_2)) { + return nullptr; + } + auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0); + auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0); + if (input_shape == output_shape) { + auto input_node = reshape_op_2->input(1); + return input_node; + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h new file mode 100644 index 0000000000..848713201a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ + +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class RemoveReshapePair : public PatternProcessPass { + public: + explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) {} + ~RemoveReshapePair() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc new file mode 100644 index 0000000000..02a866930c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace { +bool CheckShapeDimInfo(const std::vector &shape) { + if (shape.empty()) { + return false; + } + if (shape.size() == 1 && shape[0] % kCubeSize != 0) { + return false; + } + return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0)); +} +} // namespace + +const BaseRef ReshapeTransposeFusion::DefinePattern() const { + const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); + VectorRef reshape({prim_reshape, input_varptr_}); + + return VectorRef({prim::kPrimTranspose, reshape}); +} + +const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(transpose_cnode); + auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_cnode); + std::vector reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); + std::vector transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0); + if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) { + return nullptr; + } + auto prim = std::make_shared(kConfusionTransposeDOpName); + std::vector inputs = {NewValueNode(prim), utils::cast((*equiv)[input_varptr_])}; + auto new_node = func_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_abstract(node->abstract()); + + AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); + AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); + AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node); + auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); + AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); + + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h new file mode 100644 index 0000000000..a76538019e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ReshapeTransposeFusion : public PatternProcessPass { + public: + explicit ReshapeTransposeFusion(bool multigraph = true) : PatternProcessPass("reshape_transpose_fusion", multigraph) { + input_varptr_ = std::make_shared(); + } + ~ReshapeTransposeFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input_varptr_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.cc new file mode 100644 index 0000000000..a3706bfb68 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef SoftmaxGradExtFusion::DefinePattern() const { + VectorRef mul({prim::kPrimMul, input1_, input0_}); + VectorRef sum({sum_var_, mul}); + VectorRef sub({prim::kPrimSub, input0_, sum}); + VectorRef mul1({prim::kPrimMul, input2_, input1_}); + VectorRef mul_grad({prim::kPrimMul, mul1, sub}); + return mul_grad; +} + +const BaseRef SoftmaxGradExtFusionV2::DefinePattern() const { + VectorRef mul({prim::kPrimMul, input1_, input0_}); + VectorRef sum({sum_var_, mul}); + VectorRef sub({prim::kPrimSub, input0_, sum}); + VectorRef mul1({prim::kPrimMul, input1_, sub}); + VectorRef mul_grad({prim::kPrimMul, input2_, mul1}); + return mul_grad; +} + +const BaseRef SoftmaxGradExtFusionV3::DefinePattern() const { + VectorRef mul({prim::kPrimMul, input1_, input0_}); + VectorRef sum({sum_var_, mul}); + VectorRef sub({prim::kPrimSub, input0_, sum}); + VectorRef mul1({prim::kPrimMul, input1_, sub}); + VectorRef mul_grad({prim::kPrimMul, mul1, input2_}); + return mul_grad; +} + +const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(equiv); + MS_EXCEPTION_IF_NULL(node); + auto input0 = GetAnfNodeByVar(equiv, input0_); + auto input1 = GetAnfNodeByVar(equiv, input1_); + auto input2 = GetAnfNodeByVar(equiv, input2_); + auto sum = GetAnfNodeByVar(equiv, sum_var_); + if (!GetBoolAttr(sum, kAttrKeepDims)) { + MS_LOG(INFO) << "sum's attr keep_dims should be true if do fusion"; + return nullptr; + } + + auto prim = std::make_shared(kSoftmaxGradExtOpName); + auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2}); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_scope(node->scope()); + fusion_node->set_abstract(node->abstract()); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, "keepdims", sum, fusion_node); + AnfAlgo::CopyNodeAttr(kAttrAxis, sum, fusion_node); + return fusion_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h new file mode 100644 index 0000000000..1b884b2726 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class SoftmaxGradExtFusion : public PatternProcessPass { + public: + explicit SoftmaxGradExtFusion(const std::string &name = "softmax_grad_ext_fusion", bool multigraph = true) + : PatternProcessPass(name, multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + sum_var_ = std::make_shared(std::make_shared(prim::kPrimReduceSum->name())); + } + ~SoftmaxGradExtFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + protected: + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr sum_var_; +}; + +class SoftmaxGradExtFusionV2 : public SoftmaxGradExtFusion { + public: + explicit SoftmaxGradExtFusionV2(bool multigraph = true) + : SoftmaxGradExtFusion("softmax_grad_ext_fusion_v2", multigraph) {} + ~SoftmaxGradExtFusionV2() override = default; + const BaseRef DefinePattern() const override; +}; + +class SoftmaxGradExtFusionV3 : public SoftmaxGradExtFusion { + public: + explicit SoftmaxGradExtFusionV3(bool multigraph = true) + : SoftmaxGradExtFusion("softmax_grad_ext_fusion_v3", multigraph) {} + ~SoftmaxGradExtFusionV3() override = default; + const BaseRef DefinePattern() const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc new file mode 100644 index 0000000000..67c881759a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" + +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(square); + MS_EXCEPTION_IF_NULL(sum); + if (square->inputs().size() != kSquareNodeInputNum) { + MS_LOG(EXCEPTION) << "Square node has wrong input size"; + } + auto prim = std::make_shared(kSquareSumV1OpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector square_sumv1_inputs = {NewValueNode(prim), square->input(1)}; + auto square_sumv1 = graph->NewCNode(square_sumv1_inputs); + MS_EXCEPTION_IF_NULL(square_sumv1); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + square_sumv1->set_kernel_info(kernel_info); + auto types = {AnfAlgo::GetOutputInferDataType(sum, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sumv1.get()); + square_sumv1->set_scope(sum->scope()); + AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv1); + auto names = MakeValue>({square->fullname_with_scope(), sum->fullname_with_scope()}); + AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv1); + return square_sumv1; +} + +CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(square); + MS_EXCEPTION_IF_NULL(sum); + if (square->inputs().size() != kSquareNodeInputNum) { + MS_LOG(EXCEPTION) << "Square node has wrong input size"; + } + auto prim = std::make_shared(kSquareSumV2OpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector square_sumv2_inputs = {NewValueNode(prim), square->input(1)}; + auto square_sumv2 = graph->NewCNode(square_sumv2_inputs); + MS_EXCEPTION_IF_NULL(square_sumv2); + auto types = {AnfAlgo::GetOutputInferDataType(sum, 0), AnfAlgo::GetOutputInferDataType(square, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0), AnfAlgo::GetOutputInferShape(square, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sumv2.get()); + square_sumv2->set_scope(sum->scope()); + AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv2); + auto names = MakeValue>({square->fullname_with_scope(), sum->fullname_with_scope()}); + AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv2); + return square_sumv2; +} + +std::tuple GetPrevNodes(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto sum = node->cast(); + MS_EXCEPTION_IF_NULL(sum); + if (sum->inputs().size() != kSumNodeInputNum) { + MS_LOG(EXCEPTION) << "ReduceSumD node has wrong input size"; + } + auto square_anf = sum->input(1); + MS_EXCEPTION_IF_NULL(square_anf); + auto square = square_anf->cast(); + MS_EXCEPTION_IF_NULL(square); + + return std::make_tuple(sum, square_anf, square); +} +} // namespace + +const BaseRef SquareSumFusion::DefinePattern() const { + VarPtr X = std::make_shared(); + MS_EXCEPTION_IF_NULL(X); + return VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimSquare, X})}); +} + +const AnfNodePtr SquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + CNodePtr sum = nullptr; + AnfNodePtr square_anf = nullptr; + CNodePtr square = nullptr; + std::tie(sum, square_anf, square) = GetPrevNodes(node); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(square_anf) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "Square node has no output in NodeUsersMap"; + } + AnfNodePtr ret_node = nullptr; + if (manager->node_users()[square_anf].size() == 1) { + ret_node = GenerateSquareSumV1(graph, square, sum); + } else if (manager->node_users()[square_anf].size() == 2) { + auto square_sumv2 = GenerateSquareSumV2(graph, square, sum); + + std::vector square_sumv2_outputs; + CreateMultipleOutputsOfAnfNode(graph, square_sumv2, kSquareSumv2OutputNum, &square_sumv2_outputs); + if (square_sumv2_outputs.size() != kSquareSumv2OutputNum) { + MS_LOG(EXCEPTION) << "make SquareSumV2 outputs fail"; + } + (void)manager->Replace(square, square_sumv2_outputs[1]); + ret_node = square_sumv2_outputs[0]; + } + return ret_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h new file mode 100644 index 0000000000..54189606ba --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class SquareSumFusion : public PatternProcessPass { + public: + explicit SquareSumFusion(bool multigraph = true) : PatternProcessPass("square_sum_fusion", multigraph) {} + ~SquareSumFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc new file mode 100644 index 0000000000..46bf2a8604 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace { +bool CheckShapeDimInfo(const std::vector &shape) { + if (shape.empty()) { + return false; + } + if (shape.size() == 1 && shape[0] % kCubeSize != 0) { + return false; + } + return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0)); +} +} // namespace + +const BaseRef TransposeReshapeFusion::DefinePattern() const { + const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); + VectorRef transpose({prim::kPrimTranspose, input_varptr_}); + + return VectorRef({prim_reshape, transpose}); +} + +const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_cnode); + auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(transpose_cnode); + std::vector reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); + std::vector transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); + if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { + return nullptr; + } + auto prim = std::make_shared(kConfusionTransposeDOpName); + std::vector inputs = {NewValueNode(prim), utils::cast((*equiv)[input_varptr_])}; + auto new_node = func_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + + new_node->set_abstract(node->abstract()); + AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); + AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); + AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node); + auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); + AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); + + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h new file mode 100644 index 0000000000..39b8fe4687 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class TransposeReshapeFusion : public PatternProcessPass { + public: + explicit TransposeReshapeFusion(bool multigraph = true) : PatternProcessPass("transpose_reshape_fusion", multigraph) { + input_varptr_ = std::make_shared(); + } + ~TransposeReshapeFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input_varptr_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc new file mode 100644 index 0000000000..b6da588e89 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef TransposeTransDataFusion::DefinePattern() const { + const auto prim_transdata = std::make_shared(prim::KPrimTransData->name()); + VectorRef transpose({prim::kPrimTranspose, input_varptr_}); + + return VectorRef({prim_transdata, transpose}); +} + +const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputNum); + MS_EXCEPTION_IF_NULL(transdata_cnode); + auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kBackendTransDataInputNum); + MS_EXCEPTION_IF_NULL(transpose_cnode); + auto transpose_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transpose_cnode); + auto transdata_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transdata_cnode); + MS_EXCEPTION_IF_NULL(transpose_kernel_build_info); + MS_EXCEPTION_IF_NULL(transdata_kernel_build_info); + + auto new_transdata_builder = std::make_shared(); + auto transpose_input_formats = transpose_kernel_build_info->GetAllInputFormats(); + new_transdata_builder->SetInputsFormat(transpose_input_formats); + new_transdata_builder->SetOutputsFormat(transdata_kernel_build_info->GetAllOutputFormats()); + new_transdata_builder->SetInputsDeviceType(transdata_kernel_build_info->GetAllInputDeviceTypes()); + new_transdata_builder->SetOutputsDeviceType(transdata_kernel_build_info->GetAllOutputDeviceTypes()); + new_transdata_builder->SetKernelType(transdata_kernel_build_info->kernel_type()); + new_transdata_builder->SetFusionType(transdata_kernel_build_info->fusion_type()); + new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); + + auto new_fusion_transdata = std::make_shared(kTransDataOpName); + if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) { + std::vector inputs = {NewValueNode(new_fusion_transdata), + utils::cast((*equiv)[input_varptr_])}; + auto new_node = func_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_abstract(node->abstract()); + AnfAlgo::CopyNodeAttrs(transdata_cnode, new_node); + AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(transpose_input_formats[0]), new_node); + AnfAlgo::SetSelectKernelBuildInfo(new_transdata_builder->Build(), new_node.get()); + MS_LOG(INFO) << "transpose transdata fusion node:" << node->fullname_with_scope() << " success"; + return new_node; + } else { + MS_LOG(INFO) << "transpose transdata fusion node:" << node->fullname_with_scope() << " failed"; + return node; + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h new file mode 100644 index 0000000000..852d5194ec --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class TransposeTransDataFusion : public PatternProcessPass { + public: + explicit TransposeTransDataFusion(bool multigraph = true) + : PatternProcessPass("transpose_transdata_fusion", multigraph) { + input_varptr_ = std::make_shared(); + supported_checker_ = std::make_shared(); + } + ~TransposeTransDataFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input_varptr_; + + private: + SupportedCheckerPtr supported_checker_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc new file mode 100644 index 0000000000..887b9a76a1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/common/common_backend_optimization.h" +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/pass/convert_const_input_to_attr.h" +#include "backend/optimizer/pass/convert_tuple_output_to_maketuple.h" +#include "backend/optimizer/pass/convert_const_input_to_tensor_input.h" +#include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h" +#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h" +#include "utils/context/ms_context.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +void BackendCommonOptimization(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = + save_graphs_path + "/hwopt_common_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } + auto optimizer = std::make_shared(); + auto common_pm = std::make_shared("common_pm"); + common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); + common_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(common_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + if (save_graphs) { + std::string file_path = + save_graphs_path + "/hwopt_common_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h new file mode 100644 index 0000000000..4127fc05de --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ +#include +#include "backend/session/kernel_graph.h" +namespace mindspore { +namespace opt { +void BackendCommonOptimization(const std::shared_ptr &kernel_graph); +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.cc b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.cc new file mode 100644 index 0000000000..d21cabe54a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +FusionIdAllocator::FusionIdAllocator() { fusion_id = 0; } + +FusionIdAllocator::~FusionIdAllocator() {} + +void FusionIdAllocator::Init() { fusion_id = 0; } + +int32_t FusionIdAllocator::AllocateFusionId() { + fusion_id++; + return fusion_id; +} + +bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + return AnfAlgo::HasNodeAttr(kAttrFusionId, cnode); +} + +int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) { + if (HasFusionIdAttr(node)) { + return AnfAlgo::GetNodeAttr(node, kAttrFusionId); + } + return -1; +} + +void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int32_t id) { + ValuePtr fusion_id_v = MakeValue(id); + AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h new file mode 100644 index 0000000000..bdee5ee84a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ + +#include +#include "base/base.h" + +namespace mindspore { +namespace opt { +class FusionIdAllocator { + public: + FusionIdAllocator(); + virtual ~FusionIdAllocator(); + FusionIdAllocator(const FusionIdAllocator &in) = delete; + FusionIdAllocator &operator=(const FusionIdAllocator &in) = delete; + + void Init(); + int32_t AllocateFusionId(); + bool HasFusionIdAttr(const AnfNodePtr &node); + int32_t GetFusionId(const AnfNodePtr &node); + void SetFusionId(const AnfNodePtr &node, int32_t id); + + private: + int32_t fusion_id; +}; +using FusionIdAllocatorPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc new file mode 100644 index 0000000000..266130c6b1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -0,0 +1,785 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/common/helper.h" +#include +#include +#include +#include +#include +#include +#include +#include "utils/utils.h" +#include "utils/base_ref.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "common/utils.h" +#include "runtime/device/kernel_info.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +constexpr size_t kType32Len = 4; +std::vector Convert2Int(const std::vector &v) { + std::vector result; + (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt); + return result; +} + +bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node1); + MS_EXCEPTION_IF_NULL(node2); + std::vector node_list = TopoSort(graph->get_return()); + std::map> control_depend_map; + for (auto &nd : node_list) { + MS_EXCEPTION_IF_NULL(nd); + if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { + auto control_depend = nd->cast(); + auto prior_node = control_depend->input(kControlDependPriorIndex); + auto behind_node = control_depend->input(kControlDependBehindIndex); + auto it = control_depend_map.find(behind_node); + if (it == control_depend_map.end()) { + control_depend_map[behind_node] = std::set{prior_node}; + } else { + it->second.insert(prior_node); + } + } + } + + FuncGraphManagerPtr manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + std::unordered_set seen_node; + std::deque todo{node1}; + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { + continue; + } + (void)seen_node.insert(node); + + if (node == node2) { + return true; + } + if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); + } + auto it = control_depend_map.find(node); + if (it != control_depend_map.end()) { + (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); + } + } + return false; +} + +bool UnVisited(const BaseRef &n) { + if (utils::isa(n)) { + AnfNodePtr in = utils::cast(n); + MS_EXCEPTION_IF_NULL(in); + if (IsValueNode(in)) { + auto value_node = in->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + auto prim_py = value->cast(); + MS_EXCEPTION_IF_NULL(prim_py); + return !prim_py->HasAttr(kAttrVisited); + } else if (IsValueNode(in)) { + auto func_graph = GetValueNode(in); + MS_EXCEPTION_IF_NULL(func_graph); + return !func_graph->has_flag(kAttrVisited); + } + return false; + } + return false; +} + +bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(ERROR) << "The node is expected to be a cnode"; + return false; + } + *cnode = node->cast(); + if (*cnode == nullptr) { + return false; + } + if ((*cnode)->inputs().size() < IntToSize(input_size)) { + auto op_name = AnfAlgo::GetCNodeName(*cnode); + MS_LOG(ERROR) << "op[" + op_name + "] has less than " << input_size << " inputs."; + return false; + } + return true; +} + +CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "The node is expected to be a cnode"; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() != IntToSize(input_size)) { + auto op_name = AnfAlgo::GetCNodeName(cnode); + MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs."; + } + return cnode; +} + +void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() != input_size) { + MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size; + } +} + +bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) { + MS_EXCEPTION_IF_NULL(node_x); + MS_EXCEPTION_IF_NULL(node_y); + return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) && + AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0)); +} + +const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + + auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); + MS_EXCEPTION_IF_NULL(transop_cnode); + auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); + auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); + MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); + MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1)); + auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1); + MS_EXCEPTION_IF_NULL(transed_node); + + std::vector replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node, + depend_cnode->input(kDependInputNum - 1)}; + AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs); + MS_EXCEPTION_IF_NULL(replace_depend); + auto transed_abstract = transed_node->abstract(); + replace_depend->set_abstract(transed_abstract); + return replace_depend; +} + +bool Visited(const BaseRef &n) { + if (utils::isa(n)) { + AnfNodePtr in = utils::cast(n); + MS_EXCEPTION_IF_NULL(in); + if (IsValueNode(in)) { + auto value_node = in->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + auto prim_py = value->cast(); + MS_EXCEPTION_IF_NULL(prim_py); + return prim_py->HasAttr(kAttrVisited); + } else if (IsValueNode(in)) { + auto func_graph = GetValueNode(in); + MS_EXCEPTION_IF_NULL(func_graph); + return func_graph->has_flag(kAttrVisited); + } + return false; + } + return false; +} + +void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode, + std::vector *conv_bn1_outputs) { + auto prim = std::make_shared(kConvBN1OpName); + std::vector conv_bn1_inputs = {NewValueNode(prim)}; + MS_EXCEPTION_IF_NULL(conv_cnode); + // All the inputs of conv_bn1 are from the inputs of conv + for (size_t i = 1; i < conv_cnode->inputs().size(); i++) { + conv_bn1_inputs.push_back(conv_cnode->input(i)); + } + MS_EXCEPTION_IF_NULL(func_graph); + CNodePtr conv_bn1_cnode = func_graph->NewCNode(conv_bn1_inputs); + MS_EXCEPTION_IF_NULL(conv_bn1_cnode); + auto kernel_info = std::make_shared(); + conv_bn1_cnode->set_kernel_info(kernel_info); + // Set attr for conv_bn1 + AnfAlgo::CopyNodeAttrs(conv_cnode, conv_bn1_cnode); + // Set abstract of conv_bn1 + MS_EXCEPTION_IF_NULL(bn_cnode); + auto bn_abstract_tuple = dyn_cast(bn_cnode->abstract()); + MS_EXCEPTION_IF_NULL(bn_abstract_tuple); + AbstractBasePtrList conv_bn1_abstract_list; + conv_bn1_abstract_list.push_back(conv_cnode->abstract()); + auto abstract_tensor = std::make_shared( + kFloat32, Convert2Int(AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, kVariance - 1))); + conv_bn1_abstract_list.push_back(abstract_tensor); + conv_bn1_abstract_list.push_back(bn_abstract_tuple->elements()[kSaveMean]); + auto abstract_tuple = std::make_shared(conv_bn1_abstract_list); + conv_bn1_cnode->set_abstract(abstract_tuple); + + CreateMultipleOutputsOfAnfNode(func_graph, conv_bn1_cnode, kConvBn1OutputNum, conv_bn1_outputs); +} + +void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector &fused_bn1_outputs, + const CNodePtr &bn_node, std::vector *fused_bn2_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(bn_node); + MS_EXCEPTION_IF_NULL(fused_bn2_outputs); + if (bn_node->inputs().size() != kBnInputNum) { + MS_LOG(EXCEPTION) << "BN node has wrong input size"; + } + if (fused_bn1_outputs.size() != kBN1OutputNum) { + MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; + } + + // the inputs of fused_bn2 are from the outputs of fused_bn1 and the inputs of bn + std::vector fused_bn2_inputs = {NewValueNode(std::make_shared(kFusedBN2OpName))}; + fused_bn2_inputs.push_back(fused_bn1_outputs[0]); + fused_bn2_inputs.push_back(fused_bn1_outputs[1]); + fused_bn2_inputs.push_back(bn_node->input(4)); + fused_bn2_inputs.push_back(bn_node->input(5)); + auto fused_bn2 = graph->NewCNode(fused_bn2_inputs); + MS_EXCEPTION_IF_NULL(fused_bn2); + auto kernel_info = std::make_shared(); + fused_bn2->set_kernel_info(kernel_info); + auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 4), AnfAlgo::GetOutputInferDataType(bn_node, 1), + AnfAlgo::GetOutputInferDataType(bn_node, 2)}; + auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 4), AnfAlgo::GetOutputInferShape(bn_node, 1), + AnfAlgo::GetOutputInferShape(bn_node, 2)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn2.get()); + fused_bn2->set_scope(bn_node->scope()); + AnfAlgo::CopyNodeAttr(kAttrMomentum, bn_node, fused_bn2); + + CreateMultipleOutputsOfAnfNode(graph, fused_bn2, kBN2OutputNum, fused_bn2_outputs); +} + +void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input, + const std::vector &fused_bn1_outputs, + const std::vector &fused_bn2_outputs, const CNodePtr &bn_node, + std::vector *fused_bn3_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(data_input); + MS_EXCEPTION_IF_NULL(bn_node); + MS_EXCEPTION_IF_NULL(fused_bn3_outputs); + if (bn_node->inputs().size() != kBnInputNum) { + MS_LOG(EXCEPTION) << "BN node has wrong input size"; + } + + if (fused_bn1_outputs.size() != kBN1OutputNum) { + MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; + } + + if (fused_bn2_outputs.size() != kBN2OutputNum) { + MS_LOG(EXCEPTION) << "BN2 outputs has wrong input size"; + } + + // the inputs of fused_bn3 are from the outputs of fused_bn1 and the inputs of bn + std::vector fused_bn3_inputs = {NewValueNode(std::make_shared(kFusedBN3OpName))}; + fused_bn3_inputs.push_back(data_input); + fused_bn3_inputs.push_back(fused_bn1_outputs[0]); + fused_bn3_inputs.push_back(fused_bn2_outputs[0]); + fused_bn3_inputs.push_back(bn_node->input(2)); + fused_bn3_inputs.push_back(bn_node->input(3)); + auto fused_bn3 = graph->NewCNode(fused_bn3_inputs); + MS_EXCEPTION_IF_NULL(fused_bn3); + auto kernel_info = std::make_shared(); + fused_bn3->set_kernel_info(kernel_info); + auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn3.get()); + + fused_bn3->set_scope(bn_node->scope()); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, kAttrEps, bn_node, fused_bn3); + + (*fused_bn3_outputs).push_back(fused_bn3); +} + +void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num, + std::vector *outputs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(outputs); + for (size_t i = 0; i < output_num; i++) { + auto idx = NewValueNode(SizeToInt(i)); + MS_EXCEPTION_IF_NULL(idx); + int temp = SizeToInt(i); + auto imm = std::make_shared(temp); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)}, + {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get()); + (*outputs).push_back(tuple_getitem); + } +} + +template +tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, + size_t data_length) { + MS_EXCEPTION_IF_NULL(value_tuple_ptr); + MS_EXCEPTION_IF_NULL(type_ptr); + std::vector values; + for (const auto &v : value_tuple_ptr->value()) { + MS_EXCEPTION_IF_NULL(v); + if (v->isa()) { + ScalarPtr scalar = v->cast(); + values.push_back(GetValue(scalar)); + } else { + MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; + return nullptr; + } + } + std::vector tensor_shape = {SizeToInt(values.size())}; + tensor::TensorPtr tensor = std::make_shared(type_ptr->type_id(), tensor_shape); + MS_EXCEPTION_IF_NULL(tensor); + tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr}; + tensor->set_device_info(device_info); + auto data_ptr = tensor->data_c(); + MS_EXCEPTION_IF_NULL(data_ptr); + auto elem_num = values.size() * data_length; + auto ret_code = memcpy_s(data_ptr, static_cast(tensor->data().nbytes()), values.data(), elem_num); + if (ret_code != 0) { + MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; + } + return tensor; +} + +tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) { + MS_EXCEPTION_IF_NULL(value_tuple); + tensor::TensorPtr tensor = nullptr; + if (value_tuple->value().empty()) { + MS_LOG(WARNING) << "The value tuple is empty."; + return nullptr; + } + ValuePtr v = *(value_tuple->value().begin()); + MS_EXCEPTION_IF_NULL(v); + // Currently we only deal with the scalar tuple + if (!v->isa()) { + MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; + return nullptr; + } + ScalarPtr scalar = v->cast(); + MS_EXCEPTION_IF_NULL(scalar); + if (scalar->isa()) { + tensor = CreateTensorWithValueTuple(value_tuple, kInt32, kType32Len); + } else if (scalar->isa()) { + tensor = CreateTensorWithValueTuple(value_tuple, kFloat32, kType32Len); + } else { + auto type = scalar->type(); + auto type_str = (type == nullptr) ? "nullptr" : type->ToString(); + MS_LOG(ERROR) << "Invalid scalar type: " << type_str; + return nullptr; + } + return tensor; +} + +bool IsNopNode(const AnfNodePtr &node) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) { + return false; + } + static std::unordered_set nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, + prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(), + kFlattenGradOpName}; + if (node == nullptr || !node->isa()) { + return false; + } + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end()) { + return false; + } + return true; +} + +bool IsAllNopNode(const session::KernelGraph *const graph) { + MS_EXCEPTION_IF_NULL(graph); + auto execution_order = graph->execution_order(); + for (auto &cnode : execution_order) { + MS_EXCEPTION_IF_NULL(cnode); + if (!IsNopNode(cnode)) { + return false; + } + } + return true; +} + +void HideNopNode(session::KernelGraph *const graph) { + MS_EXCEPTION_IF_NULL(graph); + if (IsAllNopNode(graph) == true) { + return; + } + auto execution_order = graph->execution_order(); + MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); + std::vector new_nodes; + for (auto &cnode : execution_order) { + MS_EXCEPTION_IF_NULL(cnode); + if (!IsNopNode(cnode)) { + new_nodes.push_back(cnode); + } + } + graph->set_execution_order(new_nodes); + MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size(); +} + +void RemoveNopNode(session::KernelGraph *const graph) { + MS_EXCEPTION_IF_NULL(graph); + if (IsAllNopNode(graph) == true) { + return; + } + bool changed = true; + while (changed) { + changed = false; + std::vector new_nodes; + for (auto &cnode : graph->execution_order()) { + MS_EXCEPTION_IF_NULL(cnode); + // ignore nop node itself + if (IsNopNode(cnode)) { + continue; + } + // Replace the input which is nop node + std::vector new_inputs; + new_inputs.push_back(cnode->input(0)); + bool need_update = false; + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto input = cnode->input(i); + MS_EXCEPTION_IF_NULL(input); + auto cinput = input->cast(); + if (cinput == nullptr || !IsNopNode(cinput)) { + new_inputs.push_back(input); + continue; + } + if (cinput->inputs().size() == 2) { + new_inputs.push_back(cinput->input(1)); + need_update = true; + changed = true; + } else { + new_inputs.push_back(input); + } + } + if (need_update) { + cnode->set_inputs(new_inputs); + } + // push into new execution list + new_nodes.push_back(cnode); + } + graph->set_execution_order(new_nodes); + } +} + +std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, + const AnfNodePtr &node) { + auto output_node_list = std::make_shared>>(); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto iter = manager->node_users().find(node); + if (iter == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + auto output_info_list = iter->second; + for (const auto &output_info : output_info_list) { + if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { + continue; + } + if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && + output_info.second == kDependAttachNodeIndex) { + continue; + } + output_node_list->push_back(output_info); + } + return output_node_list; +} + +bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto output_node_list = GetRealNodeUsedList(graph, node); + MS_EXCEPTION_IF_NULL(output_node_list); + return output_node_list->size() > 1; +} + +AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { + auto idx = NewValueNode(SizeToInt(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(output_idx)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + tuple_getitem->set_scope(node->scope()); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); + return tuple_getitem; +} + +void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector new_inputs; + std::vector new_input_names; + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + auto input_names = primitive->GetAttr(kAttrInputNames); + if (input_names == nullptr) { + MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; + return; + } + auto input_names_vec = GetValue>(input_names); + auto inputs = cnode->inputs(); + new_inputs.push_back(inputs[0]); + bool need_update = false; + for (size_t i = 0; i < inputs.size() - 1; ++i) { + auto input_node = inputs[i + 1]; + MS_EXCEPTION_IF_NULL(input_node); + if (input_attrs.find(i) != input_attrs.end() && input_node->isa()) { + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; + if (i >= input_names_vec.size()) { + MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; + } + primitive->set_attr(input_names_vec[i], value_node->value()); + need_update = true; + } else { + new_inputs.push_back(input_node); + if (i < input_names_vec.size()) { + new_input_names.push_back(input_names_vec[i]); + } + } + } + if (need_update) { + // Update cnode's inputs + cnode->set_inputs(new_inputs); + // Update cnode's input_names attr + primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); + } +} + +bool AnfEqual(const BaseRef &a, const BaseRef &b) { + if (utils::isa(a) && utils::isa(b)) { + auto a_node = utils::cast(a); + auto b_node = utils::cast(b); + MS_EXCEPTION_IF_NULL(a_node); + MS_EXCEPTION_IF_NULL(b_node); + if (IsValueNode(a_node) && IsValueNode(b_node)) { + auto a_value_node = a_node->cast(); + MS_EXCEPTION_IF_NULL(a_value_node); + auto a_value = a_value_node->value(); + MS_EXCEPTION_IF_NULL(a_value); + auto a_prim = a_value->cast(); + MS_EXCEPTION_IF_NULL(a_prim); + + auto b_value_node = b_node->cast(); + MS_EXCEPTION_IF_NULL(b_value_node); + auto b_value = b_value_node->value(); + MS_EXCEPTION_IF_NULL(b_value); + auto b_prim = b_value->cast(); + MS_EXCEPTION_IF_NULL(b_prim); + + return a_prim->name() == b_prim->name(); + } else if (a_node->isa() && b_node->isa()) { + auto a_value_node_ptr = a_node->cast(); + if (a_value_node_ptr == nullptr) { + MS_LOG(EXCEPTION) << "cast value node ptr fail"; + } + auto a_value_ptr = a_value_node_ptr->value(); + if (a_value_ptr == nullptr) { + MS_LOG(EXCEPTION) << "value ptr is nullptr"; + } + + auto b_value_node_ptr = b_node->cast(); + if (b_value_node_ptr == nullptr) { + MS_LOG(EXCEPTION) << "cast value node ptr fail"; + } + auto b_value_ptr = b_value_node_ptr->value(); + if (b_value_ptr == nullptr) { + MS_LOG(EXCEPTION) << "value ptr is nullptr"; + } + + return (*a_value_ptr) == (*b_value_ptr); + } + MS_LOG(DEBUG) << "check AnfNodePtr equal"; + } + if (utils::isa(a) && utils::isa(b)) { + MS_LOG(DEBUG) << "check GraphPtr equal"; + } + return a == b; +} + +bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { + // To matchCNode and Kernel's type + if (utils::isa(a) && utils::isa(b)) { + return true; + } + return a.type() == b.type(); +} + +namespace { +ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + return nullptr; +} + +CNodePtr CreateCNodeWithGraph(const std::vector &input_nodes, const BaseRef &graph) { + if (utils::isa(graph)) { + return std::make_shared(input_nodes, utils::cast(graph)); + } + if (utils::isa(graph)) { + return std::make_shared(input_nodes, utils::cast(graph)); + } + return nullptr; +} + +VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { + if (utils::isa(graph)) { + MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); + return std::make_shared(utils::cast(sexp), nullptr); + } + if (utils::isa(graph)) { + MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); + return std::make_shared(utils::cast(sexp), utils::cast(graph)); + } + MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); + return nullptr; +} + +AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph) { + MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + std::vector input_nodes; + const auto &tuple = utils::cast(sexp); + if (multigraph && utils::isa(graph)) { + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); + input_nodes.push_back(node); + } + VarPtr var_ptr = utils::cast(graph); + return std::make_shared(input_nodes, var_ptr); + } + + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); + input_nodes.push_back(node); + } + return CreateCNodeWithGraph(input_nodes, graph); +} +} // namespace + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { + MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + MS_EXCEPTION_IF_NULL(primitive_vars); + if (utils::isa(sexp)) { + return HandleSexpVector(sexp, graph, primitive_vars, multigraph); + } + if (utils::isa(sexp)) { + auto var_ptr = utils::cast(sexp); + MS_EXCEPTION_IF_NULL(var_ptr); + if (var_ptr->primitive()) { + (*primitive_vars)[var_ptr->primitive()] = var_ptr; + return NewValueNode(var_ptr->primitive()); + } + return CreateVarNodeWithSexp(sexp, graph); + } + if (utils::isa(sexp)) { + return utils::cast(sexp); + } + auto value_node = CreateValueNodeWithSexp(sexp); + if (value_node == nullptr) { + MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); + } + return value_node; +} + +bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) { + MS_EXCEPTION_IF_NULL(equiv1); + MS_EXCEPTION_IF_NULL(equiv2); + MS_EXCEPTION_IF_NULL(var_node); + auto equiv1_node = GetAnfNodeByVar(equiv1, var_node); + MS_EXCEPTION_IF_NULL(equiv1_node); + auto equiv2_node = GetAnfNodeByVar(equiv2, var_node); + MS_EXCEPTION_IF_NULL(equiv2_node); + return *equiv1_node == *equiv2_node; +} + +AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { + MS_EXCEPTION_IF_NULL(equiv); + MS_EXCEPTION_IF_NULL(var_node); + auto iter = (*equiv).find(var_node); + if (iter == (*equiv).end()) { + MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched."; + return nullptr; + } + auto res = utils::cast(iter->second); + if (res == nullptr) { + MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node"; + } + return res; +} + +bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { + MS_EXCEPTION_IF_NULL(n1); + MS_EXCEPTION_IF_NULL(n2); + auto n1_cnode = n1->cast(); + auto n2_cnode = n2->cast(); + MS_EXCEPTION_IF_NULL(n1_cnode); + MS_EXCEPTION_IF_NULL(n2_cnode); + auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_input1); + auto value_node1 = index_input1->cast(); + MS_EXCEPTION_IF_NULL(value_node1); + auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_input2); + auto value_node2 = index_input2->cast(); + MS_EXCEPTION_IF_NULL(value_node2); + return GetValue(value_node1->value()) < GetValue(value_node2->value()); +} + +bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(INFO) << "node is not a cnode"; + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr(node, attr_name); +} + +bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set) { + MS_EXCEPTION_IF_NULL(node); + TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0); + if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) { + return true; + } + MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString(); + return false; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h new file mode 100644 index 0000000000..a267e65b53 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -0,0 +1,199 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ + +#include +#include +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "backend/session/kernel_graph.h" +#include "common/utils.h" +#include "backend/optimizer/common/pattern_engine.h" + +namespace mindspore { +namespace opt { +constexpr size_t kTransOpInputNum = 2; +constexpr size_t kCastInputNum = 2; +constexpr size_t kDependInputNum = 3; +constexpr size_t kReluInputNum = 2; +constexpr size_t kReluGradInputNum = 3; +constexpr size_t kAddInputNum = 3; +constexpr size_t kAddNInputNum = 3; +constexpr size_t kTupleGetitemInputNum = 3; +constexpr size_t kConvInputNum = 3; +constexpr size_t kRealDivInputNum = 3; +constexpr size_t kSqrtInputNum = 2; +constexpr size_t kMulInputNum = 3; +constexpr size_t kRsqrtInputNum = 2; +constexpr size_t kSubInputNum = 3; +constexpr size_t kAssignSubInputNum = 3; + +constexpr size_t kConvBn1OutputNum = 3; +constexpr size_t kBn2ReluOutputNum = 4; + +constexpr size_t kBnInputNum = 6; +constexpr size_t kBnOutputNum = 5; +constexpr size_t kBatchNormInputNum = 5; +constexpr size_t kBatchNormOutputNum = 5; + +constexpr size_t kBN1OutputNum = 2; +constexpr size_t kBN2OutputNum = 3; +constexpr size_t kBN3OutputNum = 1; + +constexpr size_t kBNGradInputNum = 6; +constexpr size_t kBNGradOutputNum = 3; + +constexpr size_t kBNGrad1OutputNum = 3; +constexpr size_t kBNGrad2OutputNum = 5; +constexpr size_t kBNGrad3OutputNum = 1; + +constexpr size_t kBNTrainingReduceOutputNum = 2; +constexpr size_t kBNTrainingUpdateOutputNum = 5; +constexpr size_t kBNTrainingUpdateV2OutputNum = 3; +constexpr size_t kBNTrainingUpdateV3OutputNum = 5; +constexpr size_t kBNTrainingUpdateGradOutputNum = 2; + +constexpr size_t kSingleOutputNum = 1; +constexpr size_t kSumNodeInputNum = 2; +constexpr size_t kSquareNodeInputNum = 2; +constexpr size_t kSquareSumv2OutputNum = 2; +constexpr size_t kMinimumInputNum = 3; + +constexpr size_t kLambNextMVWithDecayInputNum = 7; +constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5; +constexpr size_t kLambNextMVWithDecayOutputNum = 4; +constexpr size_t kLambNextMVWithDecayV1OutputNum = 4; +constexpr size_t kLambNextRightOutputNum = 2; +constexpr size_t kLambUpdateWithLrV2InputNum = 8; +constexpr size_t kLambNextMVRuleInputNum = 14; +constexpr size_t kLambNextMVRuleOutputNum = 4; +constexpr size_t kBackendReshapeInputNum = 2; +constexpr size_t kBackendTransposeInputNum = 2; +constexpr size_t kAdamApplyOneWithDecayOutputNum = 3; +constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5; +constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2; +constexpr size_t kLayerNormGradInputNum = 6; +constexpr size_t kAdamApplyOneOutputNum = 3; +constexpr size_t kBackendTransDataInputNum = 2; +constexpr size_t kApplyMomentumInputNum = 6; +constexpr size_t kBiasAddInputNum = 3; +constexpr size_t kTopkInputNum = 3; +constexpr size_t kLarsV2InputNum = 5; +constexpr size_t kFusedMulApplyMomentumOutputNum = 2; +constexpr size_t kSplitInputNum = 2; + +enum FusedBatchNormInput { + kX = 1, + kVariance = 5, +}; +enum FusedBatchNormOutput { + kY = 0, + kRunningMean, + kRunningVariance, + kSaveMean, + kSaveInvVariance, +}; +enum ConvBn1Output { + kData = 0, + kVarPart, + kMean, +}; + +std::vector Convert2Int(const std::vector &v); + +// check whether node1 depends on node2 or not +bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); + +bool UnVisited(const BaseRef &n); + +bool Visited(const BaseRef &n); + +// check if the input node is CNode, then check it's input_size, if meet condition above, return true, otherwise return +// false. cnode can only be used when return true. +bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode); + +// check if the input node is CNode, then check it's input_size, return CNodePtr if check success. +CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size); + +void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size); + +bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y); + +const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node); + +void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode, + std::vector *conv_bn1_outputs); + +void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector &fused_bn1_outputs, + const CNodePtr &bn_node, std::vector *fused_bn2_outputs); +void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input, + const std::vector &fused_bn1_outputs, + const std::vector &fused_bn2_outputs, const CNodePtr &bn_node, + std::vector *fused_bn3_outputs); + +void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num, + std::vector *outputs); + +tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, + size_t data_length); + +tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple); + +bool IsAllNopNode(const session::KernelGraph *const graph); + +bool IsNopNode(const AnfNodePtr &node); + +void HideNopNode(session::KernelGraph *const graph); + +void RemoveNopNode(session::KernelGraph *const graph); + +AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); + +bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); + +std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, + const AnfNodePtr &node); + +void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs); + +bool AnfEqual(const BaseRef &a, const BaseRef &b); + +bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph = false); + +// Check var_node in two equivs is the same node +bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node); + +// Get anf_node from equiv by var_node +AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node); + +// Compare tuple getitem's index, return bool[n1's index < n2's index] +bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2); + +// Get attr which is bool from cnode +bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name); + +// Check node's data type is in supported data type set +bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/node_pass.cc b/mindspore/ccsrc/backend/optimizer/common/node_pass.cc new file mode 100644 index 0000000000..16f5284a57 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/node_pass.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/common/node_pass.h" + +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/manager.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +bool NodePass::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(func_graph); + + std::unordered_set seen_node; + std::deque todo{func_graph->output()}; + bool changes = false; + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { + continue; + } + (void)seen_node.insert(node); + AnfNodePtr new_node = Run(func_graph, node); + bool change = (new_node != nullptr); + if (new_node != nullptr && new_node != node) { + (void)manager->Replace(node, new_node); + (void)seen_node.erase(node); + } else if (new_node == nullptr) { + new_node = node; + } + if (new_node && IsValueNode(new_node)) { + auto const_func_graph = GetValueNode(new_node); + MS_EXCEPTION_IF_NULL(const_func_graph); + if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + todo.push_back(const_func_graph->output()); + } + } else if (new_node && new_node->isa()) { + if (AnfAlgo::IsGraphKernel(new_node)) { + todo.push_back(new_node); + } + auto cnode = new_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); + } + changes = changes || change; + } + return changes; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/node_pass.h b/mindspore/ccsrc/backend/optimizer/common/node_pass.h new file mode 100644 index 0000000000..780ae1a056 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/node_pass.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ +#include +#include + +#include "backend/optimizer/common/pass.h" + +namespace mindspore { +namespace opt { +// @brief ANF Node level optimization base pass +class NodePass : public Pass { + public: + explicit NodePass(const std::string &name) : Pass(name) {} + ~NodePass() override = default; + bool Run(const FuncGraphPtr &func_graph) final; + virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; +}; +using NodePassPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/optimizer.cc b/mindspore/ccsrc/backend/optimizer/common/optimizer.cc new file mode 100644 index 0000000000..01e9111e86 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/optimizer.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/common/optimizer.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "backend/optimizer/common/pass_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/manager.h" + +namespace mindspore { +namespace opt { +PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) + : NodePass(name), + multigraph_(multigraph), + pattern_engine_(PatternEngine(std::make_shared(), + std::function(AnfEqual), + std::function(CNodeTypeEqual))), + primitive_vars_(std::make_shared()) {} + +const BaseRef PatternProcessPass::DefinePattern() const { + VarPtr X = std::make_shared(); + return BaseRef({X}); +} + +void PatternProcessPass::Build() { + VarPtr fg = std::make_shared("RootG"); + BaseRef pattern = std::move(DefinePattern()); + pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_); +} + +AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + if (pattern_ == nullptr) { + Build(); + } + + auto empty_equiv = std::make_shared(); + MS_EXCEPTION_IF_NULL(primitive_vars_); + EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv); + if (equiv != nullptr && !equiv->empty()) { + return Process(func_graph, node, equiv); + } + return nullptr; +} + +bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + VarPtr fg = std::make_shared("RootG"); + auto empty_equiv = std::make_shared(); + MS_EXCEPTION_IF_NULL(child_primitive_vars_); + EquivPtr another_equiv = + child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node, + *child_primitive_vars_, empty_equiv); + if (another_equiv != nullptr && !another_equiv->empty()) { + return IsShareNodes(equiv, another_equiv); + } + return false; +} + +void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) { + if (pass_manager != nullptr) { + pass_managers_.push_back(pass_manager); + } +} + +FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) { + MS_EXCEPTION_IF_NULL(func_graph); + run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once; + // Performance risk by creating new manager each time + auto manager = Manage(func_graph, true); + + bool changed = true; + while (changed) { + changed = false; + for (size_t i = 0; i < pass_managers_.size(); ++i) { + const PassManagerPtr &pm = pass_managers_[i]; + if (pm != nullptr && pm->Run(func_graph)) { + changed = true; + } + } + if (run_only_once_) { + break; + } + } + + std::vector func_graphs; + func_graphs.push_back(func_graph); + manager->KeepRoots(func_graphs); + (void)TopoSort(func_graph->get_return()); + return func_graph; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/optimizer.h b/mindspore/ccsrc/backend/optimizer/common/optimizer.h new file mode 100644 index 0000000000..0b03c9c0ee --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/optimizer.h @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ + +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "utils/graph_utils.h" +#include "common/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +using PatternListType = std::initializer_list; + +class PatternProcessPass : public NodePass { + public: + explicit PatternProcessPass(const std::string &name = "", bool multigraph = true); + ~PatternProcessPass() override = default; + virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0; + virtual const BaseRef DefinePattern() const; + AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; + + private: + void Build(); + + AnfNodePtr pattern_ = nullptr; + bool multigraph_ = true; + PatternEngine pattern_engine_; + PrimitiveVarMapPtr primitive_vars_; +}; + +class MultipleOutputPatternProcessPass : public PatternProcessPass { + public: + explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true) + : PatternProcessPass(name, multigraph), + child_pattern_engine_(PatternEngine(std::make_shared(), + std::function(AnfEqual), + std::function(CNodeTypeEqual))), + child_primitive_vars_(std::make_shared()) {} + ~MultipleOutputPatternProcessPass() override = default; + virtual BaseRef DefineAnotherPattern() const = 0; + // check two patterns whether share the same nodes or not + virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0; + + protected: + bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const; + PatternEngine child_pattern_engine_; + PrimitiveVarMapPtr child_primitive_vars_; +}; + +class GraphOptimizer { + public: + explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {} + virtual ~GraphOptimizer() = default; + + void AddPassManager(const PassManagerPtr &pass_manager); + FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true); + + private: + const std::string name_ = "graph_optimizer"; + std::vector pass_managers_{}; + bool run_only_once_ = true; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/pass.h b/mindspore/ccsrc/backend/optimizer/common/pass.h new file mode 100644 index 0000000000..6e35fb1dc4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ +#include +#include + +#include "ir/anf.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +// @brief ANF Graph level optimization base pass +class Pass { + public: + explicit Pass(const std::string &name = "pass") : name_(name) {} + virtual ~Pass() = default; + virtual bool Run(const FuncGraphPtr &func_graph) = 0; + virtual std::string name() const { return name_; } + + private: + const std::string name_; +}; +using PassPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc new file mode 100644 index 0000000000..f9f41237e0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/common/pass_manager.h" + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/manager.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +const std::vector &PassManager::Passes() const { return passes_; } + +void PassManager::AddPass(const PassPtr &pass) { + if (pass != nullptr) { + passes_.push_back(pass); + } +} + +bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { + if (func_graph == nullptr) { + return false; + } + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + bool changed = false; + size_t num = 0; + for (const auto &pass : passes) { + if (pass != nullptr) { +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time {}; + struct timeval end_time {}; + (void)gettimeofday(&start_time, nullptr); +#endif + if (pass->Run(func_graph)) { + changed = true; + } +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us"; +#endif + if (save_graphs) { + auto dump_file_path = + save_graphs_path + "/" + "hwopt_" + name() + "_" + std::to_string(num) + "_" + pass->name() + ".ir"; + DumpIR(dump_file_path, func_graph); + } + num++; + } + } + return changed; +} + +bool PassManager::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + // run all passes + bool change = true; + while (change) { + change = Run(func_graph, passes_); + changed = change || changed; + if (run_only_once_) { + break; + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/pass_manager.h b/mindspore/ccsrc/backend/optimizer/common/pass_manager.h new file mode 100644 index 0000000000..51db27d250 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pass_manager.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ + +#include +#include +#include +#include + +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/node_pass.h" + +namespace mindspore { +namespace opt { +// @brief For optimization passes management +class PassManager { + public: + explicit PassManager(const std::string &name = "pm", bool run_only_once = true) + : name_(name), passes_{}, run_only_once_(run_only_once) {} + virtual ~PassManager() = default; + // Get all the passes added by AddPass + const std::vector &Passes() const; + // Add graph pass, the pass object will be freed when pass manager freed. + void AddPass(const PassPtr &pass); + // Run passes added in pass manager on the input graph + // @param [inout] graph The graph to be optimized + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph) const; + // Run the given graph passes on the input graph + // @param [inout] graph The graph to be optimized + // @param [in] passes The given graph passes + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; + std::string name() const { return name_; } + + private: + const std::string name_; + std::vector passes_; + bool run_only_once_; +}; +using PassManagerPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/pattern_engine.cc b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.cc new file mode 100644 index 0000000000..bd4efd82ef --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.cc @@ -0,0 +1,360 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/common/pattern_engine.h" + +#include +#include +#include +#include + +#include "frontend/optimizer/opt.h" + +#include "ir/anf.h" +#include "utils/convert_utils_base.h" +#include "utils/overload.h" + +namespace mindspore { +static int GetNextTag() { + static int kID = 0; + return kID++; +} + +void Var::EnsureTag() { + if (tag_.length() == 0) { + std::ostringstream buffer; + buffer << "_" << GetNextTag(); + tag_ = buffer.str(); + } +} + +bool operator==(const VarPtr &lhs, const VarPtr &rhs) { + if (lhs->isa() && rhs->isa()) { + CondVarPtr v1 = dyn_cast(lhs); + CondVarPtr v2 = dyn_cast(rhs); + return *v1 == *v2; + } + + if (lhs->isa() && rhs->isa()) { + SVarPtr v1 = dyn_cast(lhs); + SVarPtr v2 = dyn_cast(rhs); + return *v1 == *v2; + } + return (*lhs == *rhs); +} + +std::string SeqVar::ToString() const { + std::ostringstream buffer; + buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")"; + return buffer.str(); +} + +std::ostream &operator<<(std::ostream &os, const VarPtr &var) { + if (var == nullptr) { + os << ""; + } else { + os << var->ToString(); + } + return os; +} + +template <> +std::ostream &operator<<(std::ostream &os, const Equiv &equiv) { + os << "[Equiv]" + << "\n"; + for (auto &equiv_item : equiv) { + auto k = equiv_item.first; + os << k << ":"; + BaseRef x = equiv_item.second; + if (utils::isa(x)) { + auto node = utils::cast(x); + os << "TypeString[" << node->type_name() << "]"; + if (IsValueNode(node)) { + os << "IsValueNodeGraph "; + } + os << "type " << node->type_name(); + if (node->isa()) { + os << " value " << GetValueNode(node); + } + os << " addr: " << node; + } else if (utils::isa(x)) { + os << "Named " << x.ToString().c_str(); + } else if (utils::isa(x)) { + os << "TypeString[Var]"; + os << utils::cast(x); + } else if (utils::isa(x)) { + os << "TypeString[Graph]"; + } + os << "\n"; + } + return os; +} + +static BaseRef GetVar(const BaseRef &x) { + MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); + if (utils::isa(x)) { + auto node = utils::cast(x); + MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]"; + if (node->isa()) { + MS_LOG(DEBUG) << "IsVarNode " + node->cast()->var_->ToString(); + return node->cast()->var_; + } + if (node->isa()) { + MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString(); + } else { + MS_LOG(DEBUG) << "type " + node->type_name(); + } + } else if (utils::isa(x)) { + MS_LOG(DEBUG) << "Named " + x.ToString(); + } else if (utils::isa(x)) { + MS_LOG(DEBUG) << "VectorRef"; + } else if (utils::isa(x)) { + MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString(); + } + MS_LOG(DEBUG) << "GetVar end: " + x.ToString(); + return x; +} + +EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { + MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); + MS_EXCEPTION_IF_NULL(equiv); + if (utils::isa(pattern)) { + VarPtr var = utils::cast(pattern); + if (var->matches(expr)) { + (*equiv)[var] = expr; + MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString(); + return equiv; + } + } + + return nullptr; +} + +bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const { + MS_EXCEPTION_IF_NULL(values_expr); + if (utils::isa(pattern_ref)) { + *values_pattern = pattern_ref; + *values_expr = expr_ref; + return true; + } + return false; +} + +bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const { + MS_EXCEPTION_IF_NULL(values_expr); + // visitor to visite the list + auto appender_pattern = [](VectorRef &values) { + std::function fn = [&](const BaseRef &u) { + values.push_back(GetVar(u)); + return u; + }; + return fn; + }; + + visitor_->SetFn(appender_pattern(*values_pattern)); + MS_LOG(DEBUG) << "visit pattern_ref"; + bool success = visitor_->Visit(pattern_ref, nullptr); + if (!success) { + return false; + } + + auto appender_expr = [](VectorRef &values) { + std::function fn = [&](const BaseRef &u) { + values.push_back(u); + return u; + }; + return fn; + }; + + visitor_->SetFn(appender_expr(*values_expr)); + MS_LOG(DEBUG) << "visit expr_ref"; + return visitor_->Visit(expr_ref, nullptr); +} + +static int GetSVarStartIndex(const VectorRef &values) { + int index = -1; + int count = 0; + for (auto &value : values) { + if (utils::isa(value) && utils::cast(value)->isa()) { + if (index != -1) { + MS_LOG(DEBUG) << "Multiple SVars in sequence"; + return kInvalidVarIndex; + } + index = count; + } + count++; + } + return index; +} + +void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) { + if (equiv == nullptr || values_pattern.empty() || !utils::isa(values_pattern[0]) || + !utils::isa(expr_ref)) { + return; + } + auto real_node = utils::cast(expr_ref); + MS_EXCEPTION_IF_NULL(real_node); + if (!real_node->isa()) { + return; + } + auto prim_node = utils::cast(values_pattern[0]); + MS_EXCEPTION_IF_NULL(prim_node); + if (!IsValueNode(prim_node)) { + return; + } + ValuePtr value = GetValueNode(prim_node); + MS_EXCEPTION_IF_NULL(value); + auto prim = value->cast(); + MS_EXCEPTION_IF_NULL(prim); + auto iter = primitive_vars.find(prim); + if (iter == primitive_vars.end()) { + return; + } + (*equiv)[iter->second] = real_node; +} + +EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, + const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const { + int svar_index = GetSVarStartIndex(values_pattern); + if (svar_index == kInvalidVarIndex) { + return nullptr; + } + + size_t values_pattern_len = values_pattern.size(); + size_t values_expr_len = values_expr.size(); + + if (svar_index == -1) { + if (values_pattern_len != values_expr_len) { + MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len " + << values_expr_len; + return nullptr; + } + } + if (values_expr_len < values_pattern_len - 1) { + MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len; + return nullptr; + } + size_t diff = values_expr_len - values_pattern_len + 1; + for (size_t i = 0; i < values_pattern_len; i++) { + size_t expr_i = i; + if (svar_index != -1 && i == IntToSize(svar_index)) { + auto seq = + std::vector(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); + equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv); + } else { + if (svar_index != -1 && i > IntToSize(svar_index)) { + expr_i = i + diff - 1; + } + equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv); + } + if (equiv == nullptr) { + return nullptr; + } + } + return equiv; +} + +EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) const { + MS_LOG(DEBUG) << "-----[in Match]"; + MS_LOG(DEBUG) << "GetVar w"; + BaseRef pattern_ref = GetVar(pattern); + MS_LOG(DEBUG) << "GetVar v"; + BaseRef expr_ref = expr; + + if (equiv == nullptr) { + MS_LOG(EXCEPTION) << "Equiv pointer is null"; + } + + MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString(); + // 1. if pattern_ref is var and already in equiv, replace it. + if (utils::isa(pattern_ref)) { + VarPtr var = utils::cast(pattern_ref); + auto iter = equiv->find(var); + if (iter != equiv->end()) { + pattern_ref = iter->second; + } + } + + // 2. check equal + if (eq_(pattern_ref, expr_ref)) { + return equiv; + } + + // 3. match var + EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv); + if (ret_equiv) { + return ret_equiv; + } + + // 4. here the type can be std:vector, std:list, + // or cnode. + if (!type_eq_(pattern_ref, expr_ref)) { + MS_LOG(DEBUG) << "Type mismatch"; + return nullptr; + } + + // 5. transfer the Containers by visitor to std::vector + VectorRef values_pattern; + VectorRef values_expr; + if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) { + return nullptr; + } + + // 6. if any svar in both side, find the SeqVar index, + // try to pack the Var s in std::vector to a Seq and match elements one by one. + // check svar + equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv); + UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv); + return equiv; +} + +BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(equiv); + MS_LOG(DEBUG) << "-----[in Replace]"; + BaseRef ref = GetVar(pattern); + BaseRef out; + bool is_match = false; + + // w is var + if (utils::isa(ref)) { + const VarPtr &var = utils::cast(ref); + auto iter = equiv->find(var); + if (iter != equiv->end()) { + out = iter->second; + is_match = true; + } + } + if (is_match) { + return out; + } + + // visitor to visit the list + std::function fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); }; + + visitor_->SetFn(fn); + BaseRef visit_out; + if (!visitor_->Visit(pattern, &visit_out)) { + return pattern; + } + return visit_out; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h new file mode 100644 index 0000000000..51fa8801b2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h @@ -0,0 +1,204 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "backend/optimizer/common/visit.h" +#include "base/base.h" +#include "utils/log_adapter.h" +#include "utils/base_ref.h" + +namespace mindspore { +class CondVar; +class SeqVar; +using CondVarPtr = std::shared_ptr; +using SVarPtr = std::shared_ptr; +const int kInvalidVarIndex = -2; + +using ConditionFunc = std::function; + +// Base wildcard variable which could match any anf node. +class Var : public Base { + friend class VarHasher; + + public: + explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); } + explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { + EnsureTag(); + } + Var(const Var &other) : Base(other), tag_(other.tag_) {} + virtual Var &operator=(const Var &other) { + if (&other == this) { + return *this; + } + this->tag_ = other.tag_; + return *this; + } + ~Var() override = default; + MS_DECLARE_PARENT(Var, Base); + + virtual bool matches(const BaseRef &) { return true; } + + virtual bool operator==(const Var &other) const { return tag_ == other.tag_; } + bool operator!=(const Var &other) const { return !(&other == this); } + + std::string tag() const { return tag_; } + PrimitivePtr primitive() const { return primitive_; } + std::string ToString() const override { + std::ostringstream buffer; + buffer << "Var(" << tag_ << ")"; + return buffer.str(); + } + std::size_t hash() const override { return std::hash()(tag_); } + + protected: + void EnsureTag(); + + std::string tag_; + PrimitivePtr primitive_; +}; + +// VarNode means variable node, a subclass of AnfNode +class VarNode : public AnfNode { + public: + VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {} + ~VarNode() override = default; + MS_DECLARE_PARENT(VarNode, AnfNode); + + const VarPtr var_; +}; +using VarNodePtr = std::shared_ptr; + +class VarHasher { + public: + std::size_t operator()(const Var &var) const { return var.hash(); } +}; + +// Condition Var, match an anf node when condition function return true. +class CondVar : public Var { + public: + explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {} + ~CondVar() override = default; + MS_DECLARE_PARENT(CondVar, Var); + bool matches(const BaseRef &value) override { + MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); + if (utils::isa(value)) { + return false; + } + return cond_fn_(value); + } + ConditionFunc cond_fn_; +}; + +using Seq = VectorRef; +using SeqPtr = std::shared_ptr; + +// Sequence Var which could match multiple consecutive input nodes of a CNode. +class SeqVar : public Var { + public: + SeqVar() { subvar_ = std::make_shared(); } + ~SeqVar() override = default; + MS_DECLARE_PARENT(SeqVar, Var); + explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; } + bool matches(const BaseRef &value) override { + // match Seq. + if (utils::isa(value)) { + const Seq &seq = utils::cast(value); + return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) { + auto eq = subvar_->matches(v); + return eq; + }); + } + return false; + } + bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; } + std::string ToString() const override; + + private: + VarPtr subvar_; +}; + +bool operator==(const VarPtr &lhs, const VarPtr &rhs); + +inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); } + +std::ostream &operator<<(std::ostream &os, const VarPtr &var); + +using Equiv = std::map; +using EquivPtr = std::shared_ptr; +using PrimitiveVarMap = std::unordered_map; +using PrimitiveVarMapPtr = std::shared_ptr; + +inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); } + +class PatternEngine { + public: + PatternEngine(const std::shared_ptr &visitor, + const std::function &eq, + const std::function &type_eq = DefaultTypeEq) + : visitor_(visitor), eq_(eq), type_eq_(type_eq) {} + ~PatternEngine() = default; + + EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) const; + // Replace pattern with equivalent + BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const; + + private: + EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, + const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const; + bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern, + VectorRef *const values_expr) const; + bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const; + std::shared_ptr visitor_; + std::function eq_; + std::function type_eq_; +}; +} // namespace mindspore +namespace std { +using mindspore::ERROR; +using mindspore::LogStream; +using mindspore::NoExceptionType; +template <> +struct hash { + std::size_t operator()(const mindspore::VarPtr var) const { + if (var == nullptr) { + MS_LOG(ERROR) << "Invalid var ptr"; + return 0; + } + return std::hash{}(var->tag()); + } +}; +} // namespace std +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/visit.cc b/mindspore/ccsrc/backend/optimizer/common/visit.cc new file mode 100644 index 0000000000..d0b52609f8 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/visit.cc @@ -0,0 +1,166 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/common/visit.h" + +#include +#include +#include +#include + +#include "backend/optimizer/common/pattern_engine.h" +#include "utils/any.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "utils/log_adapter.h" + +/* namespace to support utils definition */ +namespace mindspore { +bool CheckIfNeedExpand(const std::vector &list) { + return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa(any); }); +} + +std::shared_ptr ExpandList(const std::vector &list) { + std::shared_ptr new_list = std::make_shared(); + for (auto &item : list) { + if (utils::isa(item)) { + const Seq &seq = utils::cast(item); + new_list->insert(new_list->end(), seq.begin(), seq.end()); + } else { + new_list->push_back(item); + } + } + return new_list; +} + +bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const { + std::vector out; + (void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out), + [this](const BaseRef &item) { return fn_(item); }); + if (visit_out != nullptr) { + *visit_out = ExpandList(out); + } + return true; +} + +bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const { + if (utils::isa(any)) { + return Visit(utils::cast(any), visit_out); + } else if (utils::isa(any)) { + auto nodeptr = utils::cast(any); + AnfNodePtr output; + AnfNodePtr *p_output = &output; + if (visit_out == nullptr) { + p_output = nullptr; + } + Visit(nodeptr, fn_, p_output); + if (visit_out != nullptr) { + *visit_out = output; + } + return true; + } + MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString(); + return false; +} + +void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const { + if (node->isa()) { + Visit(node->cast(), fn, output); + return; + } + + if (node->isa()) { + Visit(node->cast(), fn, output); + return; + } + + if (output != nullptr) { + *output = node; + } +} + +void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const { + // if output is nullptr, it's not required to make the new CNode node. + if (output == nullptr) { + for (auto &inp : cnode->inputs()) { + (void)fn(inp); + } + + if (cnode->func_graph() != nullptr) { + (void)fn(cnode->func_graph()); + } else { + (void)fn(cnode->func_graph_as_var()); + } + return; + } + + std::vector new_inputs; + std::vector after_cnode_fn; + std::shared_ptr out; + (void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn); + if (CheckIfNeedExpand(after_cnode_fn)) { + out = ExpandList(after_cnode_fn); + } + + std::vector &outs = after_cnode_fn; + if (out != nullptr) { + outs = out->elements(); + } + + for (auto &any_item : outs) { + if (!utils::isa(any_item)) { + MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr"; + } + new_inputs.push_back(utils::cast(any_item)); + } + + BaseRef any_fg; + AnfNodePtr new_cnode = nullptr; + if (cnode->func_graph() != nullptr) { + any_fg = fn(cnode->func_graph()); + if (!utils::isa(any_fg)) { + MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr"; + } + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else { + any_fg = fn(cnode->func_graph_as_var()); + if (utils::isa(any_fg)) { + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else if (utils::isa(any_fg)) { + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else { + MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr"; + } + } + new_cnode->set_abstract(cnode->abstract()); + *output = new_cnode; +} + +void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const { + const BaseRef &value = utils::cast(fn(vnode->value())); + if (utils::isa(value)) { + if (output != nullptr) { + auto ct = NewValueNode(utils::cast(value)); + ct->set_abstract(vnode->abstract()); + *output = ct; + } + return; + } + MS_LOG(EXCEPTION) << "Visit result is not ValuePtr."; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/visit.h b/mindspore/ccsrc/backend/optimizer/common/visit.h new file mode 100644 index 0000000000..9799d3f9c1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/visit.h @@ -0,0 +1,61 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ + +#include +#include +#include +#include +#include +#include + +#include "base/base.h" +#include "utils/base_ref.h" + +// namespace to support utils definition +namespace mindspore { +using VisitFn = std::function; + +class Visitor { + public: + virtual void SetFn(VisitFn fn) = 0; + virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0; + virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0; + virtual ~Visitor() = default; +}; + +class DefaultVisitor : public Visitor { + public: + DefaultVisitor() : fn_(nullptr) {} + ~DefaultVisitor() override = default; + void SetFn(VisitFn fn) override { fn_ = fn; }; + bool Visit(const VectorRef &e, BaseRef *out) const override; + bool Visit(const BaseRef &e, BaseRef *out) const override; + void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const; + void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const; + void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const; + + VisitFn fn_; +}; + +std::shared_ptr ExpandList(const std::vector &list); +bool CheckIfNeedExpand(const std::vector &list); +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc new file mode 100644 index 0000000000..41e4abee27 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/adam_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { + std::vector inputs_format; + std::vector outputs_format; + std::vector inputs_type; + std::vector outputs_type; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); + inputs_format.push_back(kOpFormat_DEFAULT); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); + outputs_format.push_back(kOpFormat_DEFAULT); + } + builder.SetInputsDeviceType(inputs_type); + builder.SetInputsFormat(inputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetOutputsFormat(outputs_format); + return builder.Build(); +} +} // namespace + +const BaseRef AdamFusion::DefinePattern() const { + VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), + VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); + VectorRef next_v = + VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), + VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); + VectorRef update = VectorRef( + {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); + VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); + VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); + VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); + VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); + return depend3; +} + +const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto beta1_input = utils::cast((*equiv)[beta1_]); + auto one_sub_beta1_input = utils::cast((*equiv)[one_sub_beta1_]); + auto beta2_input = utils::cast((*equiv)[beta2_]); + auto one_sub_beta2_input = utils::cast((*equiv)[one_sub_beta2_]); + auto eps_input = utils::cast((*equiv)[eps_]); + auto lr_input = utils::cast((*equiv)[lr_]); + auto param_input = utils::cast((*equiv)[param_]); + auto m_input = utils::cast((*equiv)[m_]); + auto v_input = utils::cast((*equiv)[v_]); + auto gradient_input = utils::cast((*equiv)[gradient_]); + MS_EXCEPTION_IF_NULL(beta1_input); + MS_EXCEPTION_IF_NULL(one_sub_beta1_input); + MS_EXCEPTION_IF_NULL(beta2_input); + MS_EXCEPTION_IF_NULL(one_sub_beta2_input); + MS_EXCEPTION_IF_NULL(eps_input); + MS_EXCEPTION_IF_NULL(lr_input); + MS_EXCEPTION_IF_NULL(param_input); + MS_EXCEPTION_IF_NULL(m_input); + MS_EXCEPTION_IF_NULL(v_input); + MS_EXCEPTION_IF_NULL(gradient_input); + + auto prim = std::make_shared(kFusedAdamName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = { + NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, + eps_input, lr_input, param_input, m_input, v_input, + gradient_input}; + auto adam = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(adam); + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get()); + adam->set_scope(node->scope()); + + auto build_info = GenerateKernelBuildInfo(adam); + AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); + return adam; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h new file mode 100644 index 0000000000..f87defc04c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class AdamFusion : public PatternProcessPass { + public: + explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) { + beta1_ = std::make_shared(); + one_sub_beta1_ = std::make_shared(); + beta2_ = std::make_shared(); + one_sub_beta2_ = std::make_shared(); + eps_ = std::make_shared(); + lr_ = std::make_shared(); + param_ = std::make_shared(); + m_ = std::make_shared(); + v_ = std::make_shared(); + gradient_ = std::make_shared(); + } + ~AdamFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr beta1_; + VarPtr one_sub_beta1_; + VarPtr beta2_; + VarPtr one_sub_beta2_; + VarPtr eps_; + VarPtr lr_; + VarPtr param_; + VarPtr m_; + VarPtr v_; + VarPtr gradient_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc new file mode 100644 index 0000000000..c95945c980 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/adam_weight_decay_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { + std::vector inputs_format; + std::vector outputs_format; + std::vector inputs_type; + std::vector outputs_type; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); + inputs_format.push_back(kOpFormat_DEFAULT); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); + outputs_format.push_back(kOpFormat_DEFAULT); + } + builder.SetInputsDeviceType(inputs_type); + builder.SetInputsFormat(inputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetOutputsFormat(outputs_format); + return builder.Build(); +} +} // namespace + +const BaseRef AdamWeightDecayFusion::DefinePattern() const { + VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), + VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); + VectorRef next_v = + VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), + VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); + VectorRef update = VectorRef( + {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); + VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); + + VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); + VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); + VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); + VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); + return depend3; +} + +const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto beta1_input = utils::cast((*equiv)[beta1_]); + auto one_sub_beta1_input = utils::cast((*equiv)[one_sub_beta1_]); + auto beta2_input = utils::cast((*equiv)[beta2_]); + auto one_sub_beta2_input = utils::cast((*equiv)[one_sub_beta2_]); + auto eps_input = utils::cast((*equiv)[eps_]); + auto lr_input = utils::cast((*equiv)[lr_]); + auto weight_decay_input = utils::cast((*equiv)[weight_decay_]); + auto param_input = utils::cast((*equiv)[param_]); + auto m_input = utils::cast((*equiv)[m_]); + auto v_input = utils::cast((*equiv)[v_]); + auto gradient_input = utils::cast((*equiv)[gradient_]); + MS_EXCEPTION_IF_NULL(beta1_input); + MS_EXCEPTION_IF_NULL(one_sub_beta1_input); + MS_EXCEPTION_IF_NULL(beta2_input); + MS_EXCEPTION_IF_NULL(one_sub_beta2_input); + MS_EXCEPTION_IF_NULL(eps_input); + MS_EXCEPTION_IF_NULL(lr_input); + MS_EXCEPTION_IF_NULL(weight_decay_input); + MS_EXCEPTION_IF_NULL(param_input); + MS_EXCEPTION_IF_NULL(m_input); + MS_EXCEPTION_IF_NULL(v_input); + MS_EXCEPTION_IF_NULL(gradient_input); + + auto prim = std::make_shared(kFusedAdamWeightDecayName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = { + NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, + eps_input, lr_input, param_input, m_input, v_input, + gradient_input, weight_decay_input}; + auto adam_weight_decay = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(adam_weight_decay); + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get()); + adam_weight_decay->set_scope(node->scope()); + + auto build_info = GenerateKernelBuildInfo(adam_weight_decay); + AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); + return adam_weight_decay; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h new file mode 100644 index 0000000000..53477ec898 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class AdamWeightDecayFusion : public PatternProcessPass { + public: + explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) { + beta1_ = std::make_shared(); + one_sub_beta1_ = std::make_shared(); + beta2_ = std::make_shared(); + one_sub_beta2_ = std::make_shared(); + eps_ = std::make_shared(); + lr_ = std::make_shared(); + weight_decay_ = std::make_shared(); + param_ = std::make_shared(); + m_ = std::make_shared(); + v_ = std::make_shared(); + gradient_ = std::make_shared(); + } + ~AdamWeightDecayFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr beta1_; + VarPtr one_sub_beta1_; + VarPtr beta2_; + VarPtr one_sub_beta2_; + VarPtr eps_; + VarPtr lr_; + VarPtr weight_decay_; + VarPtr param_; + VarPtr m_; + VarPtr v_; + VarPtr gradient_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.cc new file mode 100644 index 0000000000..b531b0caa5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/mem_reuse/kernel_refcount.h" +#include +#include "utils/log_adapter.h" +namespace mindspore { +namespace memreuse { +/** + * Add some set && get function + */ +void KernelRefCount::SetKernelRefCountInfo(int index, size_t size, RefCountType reftype) { + index_ = index; + size_ = size; + reftype_ = reftype; +} + +std::vector KernelDef::GetInputRefIndexs() const { + std::vector input_ref_indexs; + if (input_refs_.empty()) { + return input_ref_indexs; + } + (void)std::transform(input_refs_.begin(), input_refs_.end(), std::back_inserter(input_ref_indexs), + [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); + return input_ref_indexs; +} + +std::vector KernelDef::GetOutputRefIndexs() const { + std::vector output_ref_indexs; + if (output_refs_.empty()) { + return output_ref_indexs; + } + (void)std::transform(output_refs_.begin(), output_refs_.end(), std::back_inserter(output_ref_indexs), + [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); + return output_ref_indexs; +} + +std::vector KernelDef::GetWorkspaceRefIndexs() const { + std::vector wk_ref_indexs; + if (wk_space_.empty()) { + return wk_ref_indexs; + } + // only one key + auto wk_refs_iter = wk_space_.begin(); + auto wk_refs = wk_refs_iter->second; + (void)std::transform(wk_refs.begin(), wk_refs.end(), std::back_inserter(wk_ref_indexs), + [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); + return wk_ref_indexs; +} +} // namespace memreuse +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h similarity index 100% rename from mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.h rename to mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h new file mode 100644 index 0000000000..1952415515 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ + +#include +#include +#include +#include +#include +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/kernel.h" + +using HostAddress = mindspore::kernel::Address; +namespace mindspore { +namespace device { +namespace memswap { +enum class SwapKind { kDeviceToHost = 0, kHostToDevice = 1 }; + +struct TensorInfo { + size_t tensor_size_{0}; + AnfNodePtr kernel_{nullptr}; + size_t output_idx_{0}; +}; + +struct KernelExecutionInfo { + size_t topo_order_{0}; + float execution_perform_{0.0}; + bool trigger_swap_{false}; + bool need_swap_{false}; + // output index to topo orders of node users + std::map> node_users_map_; + // kernel output idx to host addr + std::map host_addrs_; + + KernelExecutionInfo() : KernelExecutionInfo(0, 0.0, false, false) {} + explicit KernelExecutionInfo(size_t topo_order) + : topo_order_(topo_order), execution_perform_(0.0), trigger_swap_(false), need_swap_(false) {} + KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap, bool need_swap) + : topo_order_(topo_order), + execution_perform_(execution_perform), + trigger_swap_(trigger_swap), + need_swap_(need_swap) {} +}; + +// trigger swap +struct MemSwapInfo { + SwapKind swap_kind_; + // kernel need to be swapped + AnfNodePtr kernel_{nullptr}; + size_t output_idx_{0}; +}; + +class MemCopyManager { + public: + MemCopyManager() = default; + + virtual ~MemCopyManager() = default; + + virtual void Init() {} + + virtual void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {} + + virtual void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {} + + virtual bool SyncMemCopyStream(SwapKind swap_kind) { return true; } + + virtual DeviceAddressPtr UpdateSwapOutQueue() { return nullptr; } + + virtual DeviceAddressPtr UpdateSwapInQueue() { return nullptr; } + + virtual bool AllocHostPinnedMem(size_t size, void **addr) const { return true; } + + virtual void FreeHostPinnedMem(void *addr) const {} + + virtual void ClearSwapQueue() {} +}; +using MemCopyManagerPtr = std::shared_ptr; +} // namespace memswap +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc new file mode 100644 index 0000000000..8f705be556 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc @@ -0,0 +1,326 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h" +#include "common/utils.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +DynamicMemPoolBestFit::~DynamicMemPoolBestFit() { + global_mem_block_list_.clear(); + global_idle_mem_buf_map_.clear(); +} + +DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) { + size_t align_size = AlignMemorySize(size); + // Find the idle memory buf by tensor size, if not find, then add new memory block and memory buf. + DeviceMemPtr device_addr = FindIdleMemBuf(align_size); + if (!device_addr) { + device_addr = AddMemBlockAndMemBuf(align_size); + } + return device_addr; +} + +std::vector DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size, + std::vector size_list) { + std::vector device_addr_list; + // Pre-alloc the one whole piece memory. + auto device_addr = AllocTensorMem(total_size); + if (!device_addr) { + return device_addr_list; + } + // Remove the pre-alloc memory. + auto mem_block = FindMemBlock(device_addr); + MS_EXCEPTION_IF_NULL(mem_block); + auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); + if (iter == mem_block->block_all_mem_buf_map_.end()) { + MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; + } + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + auto rest_size = mem_buf->size_ - total_size; + (void)mem_block->block_all_mem_buf_map_.erase(iter); + // Split the pre-alloc memory into continuous memory by the size list. + DynamicMemBufPtr continuous_mem_buf; + auto buf_addr = device_addr; + for (size_t i = 0; i < size_list.size(); i++) { + continuous_mem_buf = std::make_shared(buf_addr, kMemBufUsed, size_list[i]); + (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf); + device_addr_list.emplace_back(buf_addr); + buf_addr = AddressOffset(buf_addr, size_list[i]); + } + // Update the size of the last memory buf. + continuous_mem_buf->size_ += rest_size; + return device_addr_list; +} + +size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { + if (size == 0) { + return DYNAMIC_MEM_ALIGN_SIZE; + } + return ((size + DYNAMIC_MEM_ALIGN_SIZE - 1) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE; +} + +DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) { + auto iter = global_idle_mem_buf_map_.lower_bound(size); + if (iter != global_idle_mem_buf_map_.end()) { + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + if (mem_buf->status_ != kMemBufIdle) { + MS_LOG(EXCEPTION) << "Find the mem_buf is not idle, alloc_size[" << size << "] mem_buf_size[" << mem_buf->size_ + << "] mem_buf_address[" << mem_buf->device_addr_ << "]."; + } + mem_buf->status_ = kMemBufUsed; + // Remove map of old idle memory buf + (void)global_idle_mem_buf_map_.erase(iter); + // Divide memory buf + if (IsDivide(size, mem_buf->size_)) { + DivideMemBuf(size, mem_buf); + } + // Memory statistics + total_used_mem_statistics_ += mem_buf->size_; + if (total_used_mem_statistics_ > used_mem_peak_statistics_) { + used_mem_peak_statistics_ = total_used_mem_statistics_; + } + return mem_buf->device_addr_; + } + return nullptr; +} + +DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) { + size_t alloc_mem_size = CalMemBlockAllocSize(size); + if (alloc_mem_size == 0) { + return nullptr; + } + // Add new memory block + DeviceMemPtr device_addr = nullptr; + auto real_alloc_size = AllocDeviceMem(alloc_mem_size, &device_addr); + if (real_alloc_size < size) { + MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size + << "]."; + return nullptr; + } + auto mem_block = std::make_shared(device_addr, real_alloc_size); + MS_EXCEPTION_IF_NULL(mem_block); + auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock); + (void)global_mem_block_list_.insert(iter, mem_block); + // Add new memory buf + auto mem_buf = std::make_shared(device_addr, kMemBufUsed, real_alloc_size); + MS_EXCEPTION_IF_NULL(mem_buf); + // Add map of new memory buf in the block + (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, mem_buf); + // Divide memory buf + if (IsDivide(size, mem_buf->size_)) { + DivideMemBuf(size, mem_buf); + } + // Memory statistics + total_mem_statistics_ += real_alloc_size; + total_used_mem_statistics_ += mem_buf->size_; + if (total_used_mem_statistics_ > used_mem_peak_statistics_) { + used_mem_peak_statistics_ = total_used_mem_statistics_; + } + return mem_buf->device_addr_; +} + +size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) { + auto device_free_mem_size = free_mem_size(); + if (device_free_mem_size < size) { + MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size + << "] is smaller than required size[" << size << "]."; + return 0; + } + auto alloc_mem_size = mem_alloc_unit_size(); + // Growing at twice of alloc size + while (alloc_mem_size < size) { + alloc_mem_size = alloc_mem_size * 2; + } + alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size); + return alloc_mem_size; +} + +bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) const { + return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE; +} + +void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) { + MS_EXCEPTION_IF_NULL(mem_buf); + auto mem_block = FindMemBlock(mem_buf->device_addr_); + MS_EXCEPTION_IF_NULL(mem_block); + // Divide new memory buf + size_t newbuf_size = mem_buf->size_ - size; + mem_buf->size_ = size; + DeviceMemPtr newbuf_addr = AddressOffset(mem_buf->device_addr_, size); + auto new_mem_buf = std::make_shared(newbuf_addr, kMemBufIdle, newbuf_size); + // Add map of new memory buf in the block + (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf); + // Add map of new idle memory buf + (void)global_idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf); +} + +bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr device_addr, const DynamicMemBlockPtr mem_block) { + MS_EXCEPTION_IF_NULL(device_addr); + MS_EXCEPTION_IF_NULL(mem_block); + return device_addr < mem_block->device_addr(); +} + +DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr device_addr) { + MS_EXCEPTION_IF_NULL(device_addr); + auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock); + if (iter != global_mem_block_list_.begin()) { + return *(--iter); + } + return nullptr; +} + +void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) { + MS_EXCEPTION_IF_NULL(device_addr); + auto mem_block = FindMemBlock(device_addr); + if (mem_block == nullptr) { + MS_LOG(WARNING) << "Can't find the mem_block of the device address[" << device_addr << "]."; + return; + } + CombineMemBuf(mem_block, device_addr); +} + +void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr) { + MS_EXCEPTION_IF_NULL(mem_block); + MS_EXCEPTION_IF_NULL(device_addr); + auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); + if (iter == mem_block->block_all_mem_buf_map_.end()) { + MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; + } + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + if (mem_buf->status_ != kMemBufUsed) { + MS_LOG(EXCEPTION) << "Find the mem_buf is not used, mem_buf_address[" << mem_buf->device_addr_ << "]."; + } + mem_buf->status_ = kMemBufIdle; + total_used_mem_statistics_ -= mem_buf->size_; + // Combine backward(combine the next_mem_buf to mem_buf) + auto next_iter = iter; + (void)next_iter++; + if (next_iter != mem_block->block_all_mem_buf_map_.end()) { + auto next_mem_buf = next_iter->second; + MS_EXCEPTION_IF_NULL(next_mem_buf); + if (next_mem_buf->status_ == kMemBufIdle) { + mem_buf->size_ += next_mem_buf->size_; + EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_); + (void)mem_block->block_all_mem_buf_map_.erase(next_iter); + } + } + // Combine forward(combine the mem_buf to prev_mem_buf) + bool forward_combine = false; + DynamicMemBufPtr prev_mem_buf; + if (iter != mem_block->block_all_mem_buf_map_.begin()) { + auto prev_iter = iter; + (void)prev_iter--; + prev_mem_buf = prev_iter->second; + MS_EXCEPTION_IF_NULL(prev_mem_buf); + if (prev_mem_buf->status_ == kMemBufIdle) { + EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_); + prev_mem_buf->size_ += mem_buf->size_; + (void)mem_block->block_all_mem_buf_map_.erase(iter); + forward_combine = true; + } + } + // Add map of new idle memory + if (forward_combine) { + (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf); + } else { + (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf); + } +} + +void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr device_addr) { + MS_EXCEPTION_IF_NULL(device_addr); + auto iter = global_idle_mem_buf_map_.equal_range(size); + while (iter.first != iter.second) { + MS_EXCEPTION_IF_NULL(iter.first->second); + // Remove map of the idle memory buf by size and device address + if (iter.first->second->device_addr_ == device_addr) { + (void)global_idle_mem_buf_map_.erase(iter.first); + return; + } + (void)iter.first++; + } + MS_LOG(ERROR) << "Can't find the size[" << size << "] and device address[" << device_addr << "] in the idle mem_buf."; +} + +void DynamicMemPoolBestFit::ReleaseDeviceRes() { + MS_LOG(INFO) << "The dynamic memmory pool total size is " << total_mem_statistics_ << ", total used size is " + << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << "."; + for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) { + auto device_addr = (*iter)->device_addr(); + if (device_addr != nullptr) { + if (!FreeDeviceMem(device_addr)) { + MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error."; + } + } + } +} + +void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() { + MS_LOG(INFO) << "Start dump dynamic memory pool info."; + DeviceAddrMapMemBuf mem_block_map; + DynamicMemBufPtr mem_buf; + size_t total_mem = 0; + size_t total_used_mem = 0; + size_t total_idle_mem1 = 0; + size_t total_idle_mem2 = 0; + // Dump the memory block info and memory buf info + MS_LOG(INFO) << "Dump all mem_block info: counts[" << global_mem_block_list_.size() << "]."; + for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) { + total_mem += (*iter)->size(); + mem_block_map = (*iter)->block_all_mem_buf_map_; + MS_LOG(INFO) << "MemBlock info: number[" << iter - global_mem_block_list_.begin() << "] mem_buf_counts[" + << mem_block_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size[" + << (*iter)->size() << "]."; + for (auto iter_mem_buf = mem_block_map.begin(); iter_mem_buf != mem_block_map.end(); ++iter_mem_buf) { + mem_buf = iter_mem_buf->second; + MS_EXCEPTION_IF_NULL(mem_buf); + if (mem_buf->status_ == kMemBufIdle) { + total_idle_mem1 += mem_buf->size_; + } else { + total_used_mem += mem_buf->size_; + } + MS_LOG(INFO) << "MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status[" + << mem_buf->status_ << "]."; + } + } + // Dump all the idle memory buf info + MS_LOG(INFO) << "Dump all idle mem_buf info: counts[" << global_idle_mem_buf_map_.size() << "]."; + for (auto iter_idle = global_idle_mem_buf_map_.begin(); iter_idle != global_idle_mem_buf_map_.end(); ++iter_idle) { + mem_buf = iter_idle->second; + MS_EXCEPTION_IF_NULL(mem_buf); + total_idle_mem2 += mem_buf->size_; + MS_LOG(INFO) << "Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_ << "] status[" + << mem_buf->status_ << "]."; + } + // Dump the memory statistical info + MS_LOG(INFO) << "Total allocated memory[" << total_mem << "], used memory[" << total_used_mem << "], idle memory[" + << total_idle_mem1 << "]."; + if (total_idle_mem1 != total_idle_mem2) { + MS_LOG(ERROR) << "Check error: the idle memory in the mem_block is not equal the global idle memory."; + } + if (total_mem != total_used_mem + total_idle_mem1) { + MS_LOG(ERROR) << "Check error: the the total memory is not equal the sum of used memory and idle memory."; + } + MS_LOG(INFO) << "Finish dump dynamic memory pool info."; +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h similarity index 100% rename from mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h rename to mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc new file mode 100644 index 0000000000..263ceaec63 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc @@ -0,0 +1,436 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include +#include +#include "backend/optimizer/mem_reuse/mem_reuse_checker.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace memreuse { +bool MemReuseUtil::InitDynamicOutputKernelRef() { + int index = util_index_; + auto kernel_cnodes = graph_->execution_order(); + if (kernel_cnodes.empty()) { + return true; + } + int kernel_out_ref_num = 0; + for (auto &kernel_cnode : kernel_cnodes) { +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().CheckSignalOps(kernel_cnode); +#endif + if (kernel_cnode == nullptr) { + return false; + } + auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode); + if (kernel_mod == nullptr) { + return false; + } + auto key = kernel_cnode.get(); + // for every apply_kernel to set new output + auto iter = kernel_output_refs_.find(key); + if (iter == kernel_output_refs_.end()) { + auto output_sizes = kernel_mod->GetOutputSizeList(); + KernelRefCountPtrList kernel_refs; + for (auto size : output_sizes) { + total_dy_size_ += size; + // do not MallocDynamicMem just record this + KernelRefCountPtr kernel_ref = std::make_shared(); + index++; + auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode); + kernel_ref->stream_id_ = curr_stream_id; + kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount); + kernel_refs.push_back(kernel_ref); + kernel_out_ref_num++; + total_refs_list_.push_back(kernel_ref); + } + if (!kernel_refs.empty()) { + kernel_output_refs_[key] = kernel_refs; + } + } + } + return true; +} + +bool MemReuseUtil::InitDynamicWorkspaceKernelRef() { + int WkIndex = util_index_; + auto kernel_cnodes = graph_->execution_order(); + if (kernel_cnodes.empty()) { + return true; + } + for (auto &kernel_cnode : kernel_cnodes) { + if (kernel_cnode == nullptr) { + return false; + } + auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode); + if (kernel_mod == nullptr) { + return false; + } + auto key = kernel_cnode.get(); + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + KernelRefCountPtrList workspace_kernel_refs; + for (auto size : workspace_sizes) { + total_workspace_size_ += size; + ++WkIndex; + KernelRefCountPtr workspace_ref = std::make_shared(); + workspace_ref->SetKernelRefCountInfo(WkIndex, size, kDynamicRefCount); + workspace_kernel_refs.push_back(workspace_ref); + // total wk ref + total_wk_ref_list_.push_back(workspace_ref); + } + if (!workspace_kernel_refs.empty()) { + // every key index wk_refs + kernel_workspace_refs_[key] = workspace_kernel_refs; + } + } + return true; +} + +bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + graph_ = graph; + is_all_nop_node_ = opt::IsAllNopNode(graph); + if (!InitDynamicOutputKernelRef()) { + MS_LOG(INFO) << "InitDynamicOutputKernelRef fail"; + return false; + } + if (!InitDynamicWorkspaceKernelRef()) { + MS_LOG(INFO) << "InitDynamicWorkspaceKernelRef fail"; + return false; + } + return true; +} + +// set longest worspace list && largest workspace sizes +void MemReuseUtil::SetWorkSpaceList() { + int max_list_size = 0; + std::vector total_sizes; + std::vector max_list; + auto kernel_cnodes = graph_->execution_order(); + for (auto &kernel_cnode : kernel_cnodes) { + MS_EXCEPTION_IF_NULL(kernel_cnode); + auto cnode_key = kernel_cnode.get(); + auto cnode_iter = kernel_workspace_refs_.find(cnode_key); + if (cnode_iter != kernel_workspace_refs_.end()) { + auto kernel_refs = cnode_iter->second; + std::vector current_list; + for (size_t i = 0; i < kernel_refs.size(); ++i) { + auto size = kernel_refs[i]->size_; + current_list.push_back(size); + } + if (max_list_size < SizeToInt(current_list.size())) { + max_list_size = SizeToInt(current_list.size()); + } + (void)std::copy(current_list.begin(), current_list.end(), std::back_inserter(total_sizes)); + } + } + sort(total_sizes.rbegin(), total_sizes.rend()); + max_list.resize(IntToSize(max_list_size)); + if (SizeToInt(total_sizes.size()) < max_list_size) { + MS_LOG(EXCEPTION) << "total workspace size is less than required max list size"; + } + max_list.assign(total_sizes.begin(), total_sizes.begin() + max_list_size); + for (auto &ma : max_list) { + total_reuseworkspace_size_ += ma; + } + max_workspace_size_ = max_list_size; + max_workspace_list_ = max_list; +} + +void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_def_ptr); + auto key = kernel.get(); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + auto ref_ptr = GetKernelInputRef(kernel, i); + if (ref_ptr != nullptr) { + if (ref_ptr->reftype() == kStaticRefCount) { + continue; + } else if (ref_ptr->reftype() == kDynamicRefCount) { + auto iter = kernel_def_ptr->inputs_.find(key); + if (iter == kernel_def_ptr->inputs_.end()) { + kernel_def_ptr->inputs_[key].push_back(ref_ptr); + } else { + iter->second.push_back(ref_ptr); + } + } + } + } +} + +void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_def_ptr); + auto key = kernel.get(); + auto iter = kernel_def_ptr->outputs_.find(key); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t k = 0; k < kernel_mod->GetOutputSizeList().size(); ++k) { + KernelRefCountPtr kernel_ref = kernel_output_refs_[key][k]; + if (iter == kernel_def_ptr->outputs_.end()) { + kernel_def_ptr->outputs_[key].push_back(kernel_ref); + } else { + iter->second.push_back(kernel_ref); + } + } +} + +void MemReuseUtil::SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_def_ptr); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto key = kernel.get(); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + if (kernel_workspace_refs_.find(key) != kernel_workspace_refs_.end()) { + auto wk_refs = kernel_workspace_refs_[key]; + if (i < wk_refs.size()) { + auto wk_ref = wk_refs[i]; + kernel_def_ptr->wk_space_[key].push_back(wk_ref); + } else { + MS_LOG(EXCEPTION) << "current index: " << i << " larger than wk_refs size " << wk_refs.size(); + } + } else { + MS_LOG(EXCEPTION) << "kernel_workspace_refs_ init error"; + } + } +} + +KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { + if (node == nullptr) { + MS_LOG(EXCEPTION) << "The node pointer is a nullptr."; + } + if (node->isa()) { + auto ak_node = node->cast(); + auto key = ak_node.get(); + MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, IntToSize(output_idx)); + return kernel_output_refs_[key][IntToSize(output_idx)]; + } + return nullptr; +} + +KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { + if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { + MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " + << AnfAlgo::GetInputTensorNum(kernel); + } + auto input_node = kernel->input(input_idx + 1); + // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. + session::KernelWithIndex kernel_input; + if (is_all_nop_node_) { + // The graph does not remove the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); + } else { + // The graph removes the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); + } + if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { + MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; + } + auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second)); + return result; +} + +void MemReuseUtil::SetKernelDefMap() { + auto kernel_cnodes = graph_->execution_order(); + for (auto &kernel : kernel_cnodes) { + KernelDefPtr kernel_def_ptr = std::make_shared(); + kernel_def_ptr->set_kernel_name(AnfAlgo::GetCNodeName(kernel)); + kernel_def_ptr->set_scope_full_name(kernel->fullname_with_scope()); + kernel_def_ptr->set_stream_id(AnfAlgo::GetStreamId(kernel)); + SetInputMap(kernel, kernel_def_ptr.get()); + SetOutputMap(kernel, kernel_def_ptr.get()); + SetWkMap(kernel, kernel_def_ptr.get()); + auto key = kernel.get(); + kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]); + kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]); + kernel_def_ptr_list_.push_back(kernel_def_ptr); + kernel_map_[key] = kernel_def_ptr; + } + SetKernelDefInputs(); +} + +void MemReuseUtil::SetKernelDefInputs() { + for (const auto &kernel : graph_->execution_order()) { + MS_EXCEPTION_IF_NULL(kernel); + auto key = kernel.get(); + // find kernel_def according to cnode addr + auto iter = kernel_map_.find(key); + if (iter == kernel_map_.end()) { + MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init."; + } + auto kernel_def = iter->second; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + auto ref_ptr = GetKernelInputRef(kernel, i); + if (ref_ptr != nullptr) { + // set the inputs of this kernel_def + auto input_node = AnfAlgo::GetInputNode(kernel, i); + // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. + session::KernelWithIndex input; + if (is_all_nop_node_) { + // The graph does not remove the nop node. + input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); + } else { + // The graph removes the nop node. + input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); + } + if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { + MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; + } + auto input_key = (input.first).get(); + auto input_iter = kernel_map_.find(input_key); + if (input_iter == kernel_map_.end()) { + MS_LOG(EXCEPTION) << "kernel [" << (input.first)->fullname_with_scope() << "] is not init."; + } + kernel_def->InsertInputKernel(input_iter->second); + } + } + } +} + +void MemReuseUtil::SetReuseRefCount() { + auto kernels = graph_->execution_order(); + for (auto &kernel : kernels) { + auto key = kernel.get(); + for (auto &def : kernel_def_ptr_list_) { + auto iter = def->inputs_.find(key); + if (iter != def->inputs_.end()) { + for (auto &input : iter->second) { + input->ref_count_++; + input->ref_count_dynamic_use_++; + } + } + } + } +} + +void MemReuseUtil::SetSummaryNodesRefCount() { + bool summary_exist = graph_->summary_node_exist(); + if (!summary_exist) { + return; + } + + auto summary_nodes = graph_->summary_nodes(); + if (summary_nodes.empty()) { + return; + } + + size_t total_summary_size = 0; + for (auto &node_item : summary_nodes) { + auto node = node_item.second.first; + size_t index = IntToSize(node_item.second.second); + if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) { + KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index]; + kernel_ref->ref_count_ = kMaxRefCount; + kernel_ref->ref_count_dynamic_use_ = kMaxRefCount; + total_summary_size += kernel_ref->size_; + MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index; + } else { + MS_LOG(WARNING) << "Can't find summary node's kernel_def " << node->fullname_with_scope() << " index: " << index; + } + } +#ifdef MEM_REUSE_DEBUG + auto graph = *graph_; + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); +#endif + MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size; +} + +void MemReuseUtil::SetGraphOutputRefCount() { + auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); + for (const auto &node : nodes) { + session::KernelWithIndex kernel_input; + if (is_all_nop_node_) { + // The graph does not remove the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); + } else { + // The graph removes the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + } + MS_EXCEPTION_IF_NULL(kernel_input.first); + if (!kernel_input.first->isa() || !AnfAlgo::IsRealKernel(kernel_input.first)) { + continue; + } + auto ak_node = kernel_input.first->cast(); + auto key = ak_node.get(); + auto iter = kernel_output_refs_.find(key); + if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) { + auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second]; + MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr); + kernel_ref_count_ptr->ref_count_ = kMaxRefCount; + kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount; + } + } +#ifdef MEM_REUSE_DEBUG + auto graph = *graph_; + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); +#endif +} + +void MemReuseUtil::ResetDynamicUsedRefCount() { + for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) { + for (auto &ref_count : iter->second) { + MS_EXCEPTION_IF_NULL(ref_count); + ref_count->ref_count_dynamic_use_ = ref_count->ref_count_; + } + } +} + +void MemReuseUtil::SetAllInfo(KernelGraph *graph) { + if (!InitDynamicKernelRef(graph)) { + MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; + } + SetKernelDefMap(); + SetReuseRefCount(); + SetSummaryNodesRefCount(); + SetWorkSpaceList(); +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); +#endif +} + +uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const { + auto key = node.get(); + auto iter = kernel_output_refs_.find(key); + uint8_t *ptr = nullptr; + if (iter != kernel_output_refs_.end()) { + if (index >= iter->second.size()) { + MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]"; + } + auto output_ref = iter->second[index]; + ptr = mem_base_ + output_ref->offset_; + } else { + MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs"; + } + return ptr; +} + +uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const { + auto key = node.get(); + auto iter = kernel_workspace_refs_.find(key); + uint8_t *ptr = nullptr; + if (iter != kernel_workspace_refs_.end()) { + if (index >= iter->second.size()) { + MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]"; + } + auto wk_ref = iter->second[index]; + ptr = mem_base_ + wk_ref->offset_; + } + return ptr; +} +} // namespace memreuse +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h new file mode 100644 index 0000000000..b286bcbc2c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h @@ -0,0 +1,107 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ +#include +#include +#include +#include "backend/optimizer/mem_reuse/kernel_refcount.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +using mindspore::kernel::tbe::TbeUtils; +namespace mindspore { +namespace memreuse { +static constexpr int kMaxRefCount = 9999; +static constexpr size_t kDefaultMemAlignSize = 512; +static constexpr size_t kAttAlignSize = 31; +static constexpr int kInvalidIndex = -2; + +using KernelDefPtrMaps = std::vector; +using KernelRefs = std::map; + +using KernelGraph = mindspore::session::KernelGraph; + +class MemReuseUtil { + public: + KernelRefs kernel_output_refs_; + KernelRefCountPtrList total_refs_list_; + KernelRefCountPtrList total_wk_ref_list_; + KernelRefs kernel_workspace_refs_; + MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {} + ~MemReuseUtil() { + if (graph_ != nullptr) { + graph_ = nullptr; + } + MS_LOG(INFO) << "Total Dynamic Memory Size: " << total_dy_size_; + MS_LOG(INFO) << "Total WorkSpace Memory Size: " << total_workspace_size_; + MS_LOG(INFO) << "Total Reused WorkSpafce Memory Size: " << total_reuseworkspace_size_; + } + + void SetAllInfo(KernelGraph *graph); + bool InitDynamicOutputKernelRef(); + bool InitDynamicWorkspaceKernelRef(); + bool InitDynamicKernelRef(const KernelGraph *graph); + void SetWorkSpaceList(); + void SetKernelDefMap(); + void SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); + void SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); + void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); + void SetKernelDefInputs(); + void SetReuseRefCount(); + void SetSummaryNodesRefCount(); + // Set the reference count of graph output specially. + void SetGraphOutputRefCount(); + // Reset the dynamic used reference count by ref_count_. + void ResetDynamicUsedRefCount(); + + KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx); + KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx); + KernelRefCountPtrList total_refs_list() const { return total_refs_list_; } + KernelRefCountPtrList total_wk_ref_list() const { return total_wk_ref_list_; } + KernelDefPtrMaps kernel_def_ptr_list() const { return kernel_def_ptr_list_; } + int max_workspace_size() const { return max_workspace_size_; } + std::vector max_workspace_list() const { return max_workspace_list_; } + void set_total_refs_list(const KernelRefCountPtrList &total_refs_list) { total_refs_list_ = total_refs_list; } + void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) { + kernel_def_ptr_list_ = kernel_def_ptr_list; + } + void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; } + uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const; + uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const; + + private: + int util_index_; + const KernelGraph *graph_; + bool is_all_nop_node_; + KernelRefCountPtrList ref_list_; + KernelDefPtrMaps kernel_def_ptr_list_; + KernelRefCountPtrList last_ref_list_; + int max_workspace_size_ = 0; + std::vector max_workspace_list_; + size_t total_dy_size_ = 0; + size_t total_workspace_size_ = 0; + size_t total_reuseworkspace_size_ = 0; + uint8_t *mem_base_{nullptr}; + // kernel_map_: key is the AnfNodePtr addr, value is the KernelDef + std::map kernel_map_; +}; +using MemReuseUtilPtr = std::shared_ptr; +} // namespace memreuse +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc new file mode 100644 index 0000000000..d1a50a0dfe --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc @@ -0,0 +1,423 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include "backend/optimizer/mem_reuse/mem_reuse_checker.h" +#ifdef ENABLE_D +#include "runtime/device/ascend/ascend_stream_assign.h" +#endif +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#include "debug/debug_services.h" +#endif + +namespace mindspore { +namespace memreuse { +void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) { + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + set_tensor_ptr_list(mem_reuse_util_ptr->total_refs_list()); + set_workspace_ptr_list(mem_reuse_util_ptr->total_wk_ref_list()); + set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list()); + // check info Correctness + for (auto &tensor : tensor_ptr_list_) { + tensor->size_ = AlignMemorySize(tensor->size_); + } + // align wk size to 512 && refcount == 1 + for (auto &wk : wk_tensor_list_) { + wk->size_ = AlignMemorySize(wk->size_); + wk->ref_count_ = 1; + } +#ifdef ENABLE_D + stream_groups_ = device::ascend::AscendStreamAssign::GetInstance().get_stream_group(); +#endif +} + +void BestFitMemReuse::InitKernelDependence() { + for (const auto &kernel : op_ptr_list_) { + std::set front; + std::queue to_visit; + to_visit.push(kernel); + // find all kernels before current kernel + while (!to_visit.empty()) { + auto curr = to_visit.front(); + to_visit.pop(); + if (front.count(curr)) { + continue; + } + front.insert(curr); + auto iter = kernel_front_map_.find(curr); + if (iter != kernel_front_map_.end()) { + auto visited_front = iter->second; + front.insert(visited_front.begin(), visited_front.end()); + continue; + } + for (const auto &input : curr->input_kernels()) { + to_visit.push(input); + } + } + kernel_front_map_[kernel] = front; + } +} + +bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf) { + // determine whether the kernel_curr can reuse kernel_prev's output tensor membuf + MS_EXCEPTION_IF_NULL(kernel_curr); + MS_EXCEPTION_IF_NULL(mem_buf); + auto kernel_prev = mem_buf->used_kernel_; + MS_EXCEPTION_IF_NULL(kernel_prev); +#ifdef ENABLE_DEBUGGER + auto debugger_ = mindspore::Debugger::GetInstance(); + DebugServices *debug_services = debugger_->debug_services(); + auto watchpoint_table = debug_services->GetWatchpointTable(); + std::string current_kernel_name = kernel_curr->scope_full_name(); + if (debug_services->IsWatchPoint(current_kernel_name, watchpoint_table)) { + return false; + } +#endif + auto curr_stream_id = kernel_curr->stream_id(); + auto prev_stream_id = kernel_prev->stream_id(); + if (curr_stream_id == prev_stream_id) { + mem_buf->type_ = IN_STREAM_REUSE; + return true; + } + + bool reuse_between_streams = true; + for (auto &stream_group : stream_groups_) { + size_t cur_index = UINT32_MAX; + size_t prev_index = UINT32_MAX; + for (size_t index = 0; index < stream_group.size(); index++) { + if (curr_stream_id == stream_group[index]) { + cur_index = index; + continue; + } + if (prev_stream_id == stream_group[index]) { + prev_index = index; + continue; + } + } + if ((prev_index != UINT32_MAX) && (cur_index == UINT32_MAX || (prev_index > cur_index))) { + // previous stream and current stream are not in the same group can't be reused + // previous stream is behind current stream can't be reused + reuse_between_streams = false; + break; + } + } + + if (reuse_between_streams) { + mem_buf->type_ = BETWEEN_STREAMS_REUSE; + return true; + } + + auto iter = kernel_front_map_.find(kernel_curr); + if (iter == kernel_front_map_.end()) { + MS_LOG(EXCEPTION) << kernel_curr->scope_full_name() << " is not init."; + } + auto kernel_curr_front = iter->second; + auto depend_count = kernel_curr_front.count(kernel_prev); + if (depend_count) { + mem_buf->type_ = KERNEL_DEPENDENCE_REUSE; + return true; + } + + return false; +} + +void BestFitMemReuse::AssignNodeOutputOffset() { + for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + size_t index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_); + if (!reusable_membuf_map.empty()) { + auto membuf_index = reusable_membuf_map.begin()->second; + // find the best suitable membuf in membuf list, and reuse it + ReuseExistMembuf(tensor_desc.get(), membuf_index, kDynamicMem); + } else { + // no membuf can reuse, add new membuf after the membuf_ptr_list + AddNewMembufPtr(tensor_desc.get(), kDynamicMem); +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().IsAddNewMembuf_ = true; +#endif + } + } +} + +void BestFitMemReuse::AssignNodeWorkspaceOffset() { + for (auto &wk_idx : current_kernel_->GetWorkspaceRefIndexs()) { + size_t index = GetWorkspaceIndex(wk_idx); + auto wk_ref = wk_tensor_list_[index]; + MS_EXCEPTION_IF_NULL(wk_ref); + auto re_wk_membuf_map = GetReusableMembufMap(wk_ref->size_); + if (!re_wk_membuf_map.empty()) { + auto membuf_index = re_wk_membuf_map.begin()->second; + ReuseExistMembuf(wk_ref.get(), membuf_index, kWorkspaceMem); + } else { + AddNewMembufPtr(wk_ref.get(), kWorkspaceMem); + } + } +} + +void BestFitMemReuse::ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag) { + MS_EXCEPTION_IF_NULL(tensor_desc); + CheckMembufIndx(membuf_index); + auto membuf = membuf_ptr_list_[membuf_index]; + MS_EXCEPTION_IF_NULL(membuf); + // first to split && then update membuf_info + if (IsSplit(tensor_desc->size_, membuf->size_)) { + // split the membuf, and insert a new membuf after this membuf + SplitMembuf(tensor_desc, membuf_index); + } + // update membuf status, and set tensor offset + UpdateMembufInfo(tensor_desc, membuf.get(), flag); +} + +std::map BestFitMemReuse::GetReusableMembufMap(size_t tensor_size) { + std::map size_map; + for (size_t i = 0; i < membuf_ptr_list_.size(); ++i) { + auto membuf = membuf_ptr_list_[i]; + auto index = i; + bool is_membuf_ok = membuf->status_ == kUnused && membuf->size_ >= tensor_size; + if (is_membuf_ok && IsUsable(current_kernel_, membuf)) { + (void)size_map.insert(std::make_pair(membuf->size_, index)); + break; + } + } + return size_map; +} + +void BestFitMemReuse::UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag) { + MS_EXCEPTION_IF_NULL(tensor_desc); + MS_EXCEPTION_IF_NULL(membuf); + auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); + membuf->status_ = kReused; + membuf->index_ = real_index; + membuf->used_kernel_ = current_kernel_; + tensor_desc->offset_ = membuf->offset_; +} + +bool BestFitMemReuse::IsSplit(size_t tensor_size, size_t membuf_size) const { return tensor_size < membuf_size; } + +void BestFitMemReuse::SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index) { + MS_EXCEPTION_IF_NULL(tensor_desc); + CheckMembufIndx(membuf_index); + auto membuf = membuf_ptr_list_[membuf_index]; + MS_EXCEPTION_IF_NULL(membuf); + auto bias = membuf->size_ - tensor_desc->size_; + membuf->size_ = tensor_desc->size_; + // to check if spilt membuf can be merge + auto new_membuf = std::make_shared(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex, + membuf->type_, current_kernel_); + (void)membuf_ptr_list_.insert(membuf_ptr_list_.begin() + SizeToInt(membuf_index + 1), new_membuf); +} + +void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) { + MS_EXCEPTION_IF_NULL(tensor_desc); + size_t membuf_offset = 0; + if (!membuf_ptr_list_.empty()) { + membuf_offset = membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_; + } + auto membuf_size = tensor_desc->size_; + auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); + auto membuf = std::make_shared(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_); + membuf_ptr_list_.push_back(membuf); + tensor_desc->offset_ = membuf_offset; +} + +void BestFitMemReuse::UpdateNodeInputAndMembuf() { + // process node input tensor + for (const auto &tensor_idx : current_kernel_->GetInputRefIndexs()) { + size_t tensor_index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[tensor_index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + tensor_desc->ref_count_--; + if (tensor_desc->ref_count_ == 0) { + ReleaseMembuf(tensor_index, kDynamicMem); + } else if (tensor_desc->ref_count_ < 0) { + MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_ + << " check error"; + } + } +} + +void BestFitMemReuse::ReleaseNodeUnusedOutput() { + for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + size_t tensor_index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[tensor_index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + if (tensor_desc->ref_count_ == 0) { + ReleaseMembuf(tensor_index, kDynamicMem); + } else if (tensor_desc->ref_count_ < 0) { + MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_ + << " check error"; + } + } +} + +void BestFitMemReuse::ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr) { + for (auto &workspace_index : kernel_def_ptr->GetWorkspaceRefIndexs()) { + size_t index = GetWorkspaceIndex(workspace_index); + auto wk_tensor = wk_tensor_list_[index]; + wk_tensor->ref_count_--; + if (wk_tensor->ref_count_ == 0) { + ReleaseMembuf(index, kWorkspaceMem); + } else if (wk_tensor->ref_count_ < 0) { + MS_LOG(EXCEPTION) << "tensor: " << wk_tensor->index_ << " refcount: " << wk_tensor->ref_count_ << " check error"; + } + } +} + +void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) { + if (membuf_ptr_list_.empty()) { + return; + } + auto real_index = GetRealIndex(tensor_index, flag); + auto membuf_iter = std::find_if(membuf_ptr_list_.begin(), membuf_ptr_list_.end(), + [real_index](const MembufPtr &membuf) { return membuf->index_ == real_index; }); + if (membuf_iter == membuf_ptr_list_.end()) { + return; + } + auto membuf = (*membuf_iter); + MS_EXCEPTION_IF_NULL(membuf); + membuf->status_ = kUnused; + if (membuf_iter != membuf_ptr_list_.end() - 1) { + auto next_iter = membuf_iter + 1; + auto membuf_next = (*next_iter); + MS_EXCEPTION_IF_NULL(membuf_next); + if (membuf_next->status_ == kUnused) { + bool is_merge = IsUsable(current_kernel_, membuf_next); + if (is_merge) { + membuf->size_ += membuf_next->size_; + (void)membuf_ptr_list_.erase(next_iter); + } + } + } + if (membuf_iter != membuf_ptr_list_.begin()) { + auto prev_iter = membuf_iter - 1; + auto membuf_prev = (*prev_iter); + MS_EXCEPTION_IF_NULL(membuf_prev); + if (membuf_prev->status_ == kUnused) { + bool is_merge = IsUsable(current_kernel_, membuf_prev); + if (is_merge) { + membuf->size_ += membuf_prev->size_; + membuf->offset_ = membuf_prev->offset_; + (void)membuf_ptr_list_.erase(prev_iter); + } + } + } +} + +size_t BestFitMemReuse::AlignMemorySize(size_t size) const { + // memory size 512 align + return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize; +} + +size_t BestFitMemReuse::GetAllocatedSize() { + size_t AllocatedSize = kTotalSize; + if (membuf_ptr_list_.empty()) { + return AllocatedSize; + } + AllocatedSize = membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_; + MS_LOG(INFO) << "MemReuse Allocated Dynamic Size: " << AllocatedSize; + return AllocatedSize; +} + +bool BestFitMemReuse::IsRelease() { + // unable_used_node include the node type that output tensor cannot be released, + // even if its refcount is equal to zero. + std::unordered_set unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(), + prim::kPrimFusedBatchNorm->name(), + prim::kPrimFusedBatchNormGrad->name()}; + return unable_used_node.find(current_kernel_->kernel_name()) == unable_used_node.end(); +} + +size_t BestFitMemReuse::GetTensorIndex(int index) const { + if (index < 0 || IntToSize(index) >= tensor_ptr_list_.size()) { + MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); + MS_LOG(EXCEPTION) << "invalid tensor index"; + } + return IntToSize(index); +} + +size_t BestFitMemReuse::GetWorkspaceIndex(int index) const { + if (index < 0 || IntToSize(index) >= wk_tensor_list_.size()) { + MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); + MS_LOG(EXCEPTION) << "invalid tensor index"; + } + return IntToSize(index); +} + +int BestFitMemReuse::GetRealIndex(size_t index, int flag) const { + if (flag == kDynamicMem) { + return SizeToInt(index); + } else if (flag == kWorkspaceMem) { + return kWorkspaceIndexFactor * SizeToInt(index + 1); + } else { + MS_LOG(EXCEPTION) << "flag " << flag << " is invalid"; + } +} + +void BestFitMemReuse::CheckMembufIndx(size_t membuf_index) const { + if (membuf_index >= membuf_ptr_list_.size()) { + MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); + MS_LOG(EXCEPTION) << "invalid membuf index: " << membuf_index << ", real size: " << membuf_ptr_list_.size(); + } +} + +void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) { + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + InitMemReuseInfo(mem_reuse_util_ptr); + InitKernelDependence(); + KernelDefPtr pre_op = nullptr; +#ifdef MEM_REUSE_DEBUG + size_t op_num = 0; +#endif + for (const auto &op_def_ptr : op_ptr_list_) { + current_kernel_ = op_def_ptr; + // releas pre_op_def + if (pre_op != nullptr) { + ReleasePreNodeWorkspace(pre_op.get()); + } + MemReuseChecker::GetInstance().IsAddNewMembuf_ = false; + // process node output tensor + AssignNodeOutputOffset(); +#ifdef MEM_REUSE_DEBUG + if (MemReuseChecker::GetInstance().IsAddNewMembuf_) { + MemReuseChecker::GetInstance().SetAddNewMembuInfos(op_def_ptr.get(), membuf_ptr_list_, op_num); + } +#endif + // deal with current op'workspace + AssignNodeWorkspaceOffset(); + pre_op = op_def_ptr; + // update node input tensor refcount, and membuf list status + UpdateNodeInputAndMembuf(); + // check node output tensor which refcount is equal to zero + if (IsRelease()) { + ReleaseNodeUnusedOutput(); + } +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_); + ++op_num; +#endif + } +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().ExportMembufInfoIR(); + MemReuseChecker::GetInstance().ExportAddNewMmebufIR(); + MemReuseChecker::GetInstance().set_kernel_front_map(kernel_front_map_); + MemReuseChecker::GetInstance().ExportKernelDependence(); +#endif +} +} // namespace memreuse +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h new file mode 100644 index 0000000000..ef1cfd3e11 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h @@ -0,0 +1,159 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/optimizer/mem_reuse/kernel_refcount.h" +#include "backend/optimizer/mem_reuse/mem_reuse.h" + +namespace mindspore { +namespace memreuse { +static constexpr int kWorkspaceIndexFactor = -1000; +static constexpr int kDynamicMem = -1; +static constexpr int kWorkspaceMem = 1; +static constexpr size_t kTotalSize = 0; +enum Status { kUnused, kReused }; +enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE }; +class Membuf { + public: + Membuf() = default; + Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel) + : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {} + ~Membuf() = default; + // Memory block status flags + Status status_ = kUnused; + size_t size_{0}; + size_t offset_{0}; + // Store the tensor index stored in this memory block at a certain moment + int index_{0}; + MEMTYPE type_{NEW}; + KernelDefPtr used_kernel_; +}; +using MembufPtr = std::shared_ptr; + +class BestFitMemReuse { + public: + BestFitMemReuse() = default; + ~BestFitMemReuse() { membuf_ptr_list_.clear(); } + /** + * Init all information need by memory reuse + * @param mem_reuse_util_ptr, initialize in the memreuse.cc + */ + void InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr); + void CheckMembufIndx(size_t check_idx) const; + void AssignNodeWorkspaceOffset(); + void ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr); + /** + * Assign output tensor memory offset of current kernel + */ + void AssignNodeOutputOffset(); + /** + * Update input tensor's status of current kernel, and the status of membuf used by current kernel + */ + void UpdateNodeInputAndMembuf(); + /** + * Check whether to release the kernel output tensor which refcount is equal to zero + */ + void ReleaseNodeUnusedOutput(); + /** + * Reuse the exist membuf if possible + * @param tensor_desc, the output tensor of current kernel + * @param membuf_index, the index of membuf to be reused + * @param flag + */ + void ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag); + /** + * Get the membuf that can be reused + * @param tensor_size, the size of the tensor ready to assign memory offset + * @return membuf map, key: the membuf size, value: the membuf index + */ + std::map GetReusableMembufMap(size_t tensor_size); + /** + * Update the status of the reused memory block + * @param tensor_desc, the tensor ready to assign memory + * @param membuf, the membuf to be reused + * @param flag, distinguish dynamic memory and workspace + */ + void UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag); + // If the size of the memory block is greater than the size of the tensor, split the extra memory + void SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index); + // Determine if the memory block needs to be split + bool IsSplit(size_t tensor_size, size_t membuf_size) const; + // If there is no memory block that can be reused, add a new memory block at the end + void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag); + // Merge unused membuf + void ReleaseMembuf(size_t tensor_index, int flag); + // Memory address alignment 512 + size_t AlignMemorySize(size_t size) const; + int GetRealIndex(size_t index, int flag = kDynamicMem) const; + size_t GetTensorIndex(int index) const; + size_t GetWorkspaceIndex(int index) const; + // Memory reuse main program entry + void Reuse(const MemReuseUtil *mem_reuse_util_ptr); + // Get the total memory that needs to be applied eventually + size_t GetAllocatedSize(); + // return false, when the node output cannot be released + bool IsRelease(); + /** + * determine if the kernel_curr can reuse the output tensor add of kernel_prev + * @param kernel_curr, current kernel + * @param mem_buf, the membuf + * @return bool + */ + bool IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf); + /** + * init the dependence of all kernels in the graph + */ + void InitKernelDependence(); + // set tensor_def and op_def + void set_tensor_ptr_list(const std::vector &tensor_ptr_list) { + tensor_ptr_list_ = tensor_ptr_list; + } + void set_workspace_ptr_list(const std::vector &workspace_ptr_list) { + wk_tensor_list_ = workspace_ptr_list; + } + void set_op_ptr_list(const std::vector &op_ptr_list) { op_ptr_list_ = op_ptr_list; } + + private: + KernelDefPtr current_kernel_; + // Save all tensor information + std::vector tensor_ptr_list_; + std::vector wk_tensor_list_; + // Save all op information, including input and output tensor index + std::vector op_ptr_list_; + // Memory block information sequence, temporary variables + std::vector membuf_ptr_list_; + // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def + std::map> kernel_front_map_; + std::vector> stream_groups_; +}; +} // namespace memreuse +} // namespace mindspore +#endif // #define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc new file mode 100644 index 0000000000..b93bf42f9f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc @@ -0,0 +1,572 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/mem_reuse/mem_reuse_checker.h" +#include +#include +#include +#include + +namespace mindspore { +namespace memreuse { +MemReuseChecker &MemReuseChecker::GetInstance() { + static MemReuseChecker instance; + return instance; +} + +void MemReuseChecker::CheckSignalOps(const CNodePtr &c_node) { + std::string node_name = AnfAlgo::GetCNodeName(c_node); + if (node_name == kSend || node_name == kRecv) { + MS_LOG(INFO) << "MemReuseChecker check op_name of Send or Send"; + // get op's info && check + MS_LOG(INFO) << "op: " << node_name << " in_num: " << AnfAlgo::GetInputTensorNum(c_node) + << " out_num: " << AnfAlgo::GetOutputTensorNum(c_node); + } +} + +void MemReuseChecker::CheckWorkSpace(const std::vector &max_list) { + for (auto &ma : max_list) { + total_re_wkspe_size_checker_ += ma; + } +} + +void MemReuseChecker::CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx) { + auto key = c_node.get(); + auto iter = kernel_refs.find(key); + auto node_name = AnfAlgo::GetCNodeName(c_node); + if (iter == kernel_refs.end()) { + MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor, node: " << c_node->DebugString() + << " output index: " << output_idx; + } + if (output_idx >= iter->second.size()) { + MS_LOG(INFO) << "invalid cnode: " << c_node->fullname_with_scope().c_str(); + MS_LOG(EXCEPTION) << "The index: " << output_idx + << " is out of the size of kernel_output_refs_:" << iter->second.size(); + } +} + +int64_t MemReuseChecker::CalculOriInput(const KernelGraph *graph) const { + MS_EXCEPTION_IF_NULL(graph); + int64_t static_input_size = 0; + for (auto &item : graph->inputs()) { + if (!item->isa()) { + continue; + } + auto output_size = AnfAlgo::GetOutputTensorNum(item); + for (size_t index = 0; index < output_size; index++) { + TypeId ou_type = AnfAlgo::GetOutputDeviceDataType(item, index); + // parameter has not init by a cnode + if (ou_type == kTypeUnknown) { + ou_type = AnfAlgo::GetOutputInferDataType(item, index); + } + size_t type_size = GetTypeByte(TypeIdToType(ou_type)); + std::vector shape = AnfAlgo::GetOutputDeviceShape(item, index); + size_t tensor_size = + shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + auto checker_size = SizeToLong(tensor_size); + static_input_size += checker_size; + } + } + return static_input_size; +} + +int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const { + MS_EXCEPTION_IF_NULL(graph); + int64_t static_value_size = 0; + for (auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + auto &node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + auto tensor = node_value->cast(); + if (tensor == nullptr) { + continue; + } + size_t tensor_size = tensor->data().nbytes(); + auto checker_size = SizeToLong(tensor_size); + static_value_size += checker_size; + } + return static_value_size; +} + +int64_t MemReuseChecker::CalculOriStatic(KernelGraph *graph) const { + // cal static inputs + auto static_input_size = CalculOriInput(graph); + // do not calcul outpput size + auto statica_value_size = CalculOriValue(graph); + auto total_ori_static_size = static_input_size + statica_value_size; + return total_ori_static_size; +} + +int64_t MemReuseChecker::CalculOriDy(const KernelGraph *graph) const { + MS_EXCEPTION_IF_NULL(graph); + int64_t ori_dy_size = 0; + auto kerenls = graph->execution_order(); + for (auto &kernel : kerenls) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (auto &dy_size : kernel_mod->GetOutputSizeList()) { + auto checker_size = SizeToLong(dy_size); + ori_dy_size += checker_size; + } + } + return ori_dy_size; +} + +int64_t MemReuseChecker::CalculOriWk(const KernelGraph *graph) const { + MS_EXCEPTION_IF_NULL(graph); + int64_t ori_wk_size = 0; + auto kerenls = graph->execution_order(); + for (auto &kernel : kerenls) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (auto &wk_size : kernel_mod->GetWorkspaceSizeList()) { + auto checker_size = SizeToLong(wk_size); + ori_wk_size += checker_size; + } + } + return ori_wk_size; +} + +std::string MemReuseChecker::GetSplitName(const std::string &scope_name) const { + auto indx = scope_name.rfind(kSplitC); + if (indx == std::string::npos) { + return scope_name; + } else { + if (indx < scope_name.size() - 1) { + auto split_name = scope_name.substr(indx + 1); + return split_name; + } + return scope_name; + } +} + +void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, + const KernelDefPtrMaps &kernel_def_ptr_list, KernelGraph *graph) { + total_ori_static_size_ = CalculOriStatic(graph); + total_ori_input_size_ = CalculOriInput(graph); + total_ori_value_size_ = CalculOriValue(graph); + total_ori_dy_size_ = CalculOriDy(graph); + total_ori_wkspace_size_ = CalculOriWk(graph); + std::string graph_id = std::to_string(graph->graph_id()); + std::string filename = "./memreuse_" + graph_id + ".ir"; + std::ofstream ofs(filename); + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; + return; + } + ofs << "all_tensor_refs:\n"; + ofs << "index:" + << "\tsize:" + << "\trefcount:\n"; + for (auto &ref : total_refs_list) { + ofs << "%" << ref->index_ << "T" + << "\t" + << "#" << ref->size_ << "S" + << "\t" << ref->ref_count_ << "C" + << "\n"; + } + ofs << "kernel_def exc_order:\n"; + int def_idx = 0; + for (auto &def : kernel_def_ptr_list) { + ExportMemOpIr(def.get(), ofs, def_idx); + def_idx++; + } + ofs.close(); +} + +void MemReuseChecker::ExportKernelDependence() { + std::string filename = "./memreuse_dependence.ir"; + std::ofstream ofs(filename); + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; + return; + } + size_t i = 0; + for (const auto &kernel_front : kernel_front_map_) { + auto kernel = kernel_front.first; + auto front = kernel_front.second; + ofs << "[" << i++ << "] " << kernel->scope_full_name() << "\n"; + for (const auto &node : front) { + ofs << node->scope_full_name() << "\n"; + } + ofs << "\n\n"; + } + + ofs.close(); +} + +bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph) { + // set real graph output node to be special who's refcount equal kMaxRefCount + for (const auto &output : graph->outputs()) { + MS_EXCEPTION_IF_NULL(output); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { + if (output->isa()) { + auto cnode = output->cast(); + auto input_node = cnode->input(i + 1); + auto kernel_input_with_idx = AnfAlgo::VisitKernel(input_node, 0); + auto kernel_input = kernel_input_with_idx.first; + MS_EXCEPTION_IF_NULL(kernel_input); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel_input); + if (kernel_mod == nullptr) { + continue; + } + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + continue; + } + for (size_t j = 0; j < output_sizes.size(); ++j) { + if (!AnfAlgo::OutputAddrExist(kernel_input, j)) { + return false; + } + } + } + } + } + return true; +} + +void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) { + auto scope_name = def->scope_full_name(); + std::string split_name = GetSplitName(scope_name); + ofs << "$" << def_idx << "\t" << split_name << "\t"; + ofs << "inputs["; + for (auto &in : def->inputs_) { + for (auto &in_ref : in.second) { + ofs << "%" << in_ref->index_ << "T" + << ","; + } + } + ofs << "]"; + ofs << "\toutpus["; + for (auto &ou : def->outputs_) { + for (auto &ou_ref : ou.second) { + ofs << "%" << ou_ref->index_ << "T" + << ","; + } + } + ofs << "]"; + ofs << "\tstreamID[" + << "@" << def->stream_id() << "]\n"; +} + +void MemReuseChecker::ExportNormalTensorIR(std::ofstream &ofs) { + ofs << "all_tensor_refs:\n"; + ofs << "index:" + << "\tsize:" + << "\trefcount:\n"; + size_t ou_idx = 0; + for (auto &ou : nor_output_tensors_) { + ofs << "%" << ou_idx << "T" + << "\t" + << "#" << nor_tensor_sizes_[ou_idx] << "S" + << "\t"; + auto iter_ref = ptr_refs_.find(ou); + if (iter_ref != ptr_refs_.end()) { + ofs << iter_ref->second << "C" + << "\n"; + } else { + MS_LOG(EXCEPTION) << "can not find refs for output"; + } + ou_idx++; + } + ofs << "kernel_def exc_order:\n"; +} + +int MemReuseChecker::GetTensorIdx(const void *in) const { + auto iter = ptr_idx_.find(in); + if (iter == ptr_idx_.end()) { + return kInvalidIndex; + } else { + return SizeToInt(iter->second); + } +} + +void MemReuseChecker::ExportNormalOpIr(const std::vector &cnodes) { + std::ofstream ofs("./normal_mem.ir"); + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file failed!"; + return; + } + ExportNormalTensorIR(ofs); + size_t node_idx = 0; + for (const auto &node : cnodes) { + MS_EXCEPTION_IF_NULL(node); + ofs << "$" << node_idx << "\t" << GetSplitName(node->fullname_with_scope()) << "\t"; + std::vector in_idx; + auto iter = node_ins_.find(node.get()); + if (iter != node_ins_.end()) { + for (auto &in : iter->second) { + if (GetTensorIdx(in) != kInvalidIndex) { + in_idx.push_back(GetTensorIdx(in)); + } + } + } + std::vector ou_idx; + iter = node_ous_.find(node.get()); + if (iter != node_ous_.end()) { + for (auto &ou : iter->second) { + if (GetTensorIdx(ou) != kInvalidIndex) { + ou_idx.push_back(GetTensorIdx(ou)); + } + } + } + ofs << "inputs["; + for (auto idx : in_idx) { + bool has_in_ou = std::any_of(ou_idx.begin(), ou_idx.end(), [idx](int odx) { return idx == odx; }); + if (!has_in_ou) { + ofs << "%" << idx << "T,"; + } + } + ofs << "]\toutpus["; + for (auto odx : ou_idx) { + ofs << "%" << odx << "T,"; + } + ofs << "]\tstreamID[@" << AnfAlgo::GetStreamId(node) << "]\n"; + node_idx++; + } + ofs.close(); +} + +void MemReuseChecker::SetTesnorFromAndToInfo(const KernelDef *op_def) { + auto split_name = GetSplitName(op_def->scope_full_name()); + for (auto &in : op_def->inputs_) { + auto in_tensors = in.second; + for (auto &tensor : in_tensors) { + auto indx = tensor->index_; + tensor_to_[indx].push_back(split_name); + } + } + for (auto &ou : op_def->outputs_) { + auto ou_tensors = ou.second; + for (auto &tensor : ou_tensors) { + auto indx = tensor->index_; + tensor_from_[indx].push_back(split_name); + } + } +} + +void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) { + const auto &cnodes = graph->execution_order(); + for (const auto &node : cnodes) { + std::vector curr_ous; + for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) { + auto it = AnfAlgo::GetOutputAddr(node, i); + MS_EXCEPTION_IF_NULL(it); + auto ptr = it->GetPtr(); + nor_output_tensors_.push_back(ptr); + nor_tensor_sizes_.push_back(it->GetSize()); + curr_ous.push_back(it->GetPtr()); + } + (void)node_ous_.insert(std::make_pair(node.get(), curr_ous)); + std::vector curr_ins; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { + if (i + 1 >= node->inputs().size()) { + MS_LOG(EXCEPTION) << "Input index: " << i + << " is larger than input number: " << AnfAlgo::GetInputTensorNum(node); + } + auto real_input_index = AnfAlgo::GetRealInputIndex(node, i); + auto input = node->input(real_input_index + 1); + MS_EXCEPTION_IF_NULL(input); + auto kernel_with_index = AnfAlgo::VisitKernel(input, 0); + if (kernel_with_index.first->isa()) { + continue; + } + auto device_address = AnfAlgo::GetPrevNodeOutputAddr(node, real_input_index); + MS_EXCEPTION_IF_NULL(device_address); + nor_input_tensors_.push_back(device_address->GetPtr()); + curr_ins.push_back(device_address->GetPtr()); + } + (void)node_ins_.insert(std::make_pair(node.get(), curr_ins)); + } + size_t ou_idx = 0; + for (const auto &ou : nor_output_tensors_) { + (void)ptr_idx_.insert(std::make_pair(ou, ou_idx)); + (void)ptr_refs_.insert(std::make_pair(ou, 0)); + ou_idx++; + } + for (const auto &in : nor_input_tensors_) { + if (ptr_idx_.find(in) != ptr_idx_.end()) { + if (ptr_refs_.find(in) != ptr_refs_.end()) { + auto iter = ptr_refs_.find(in); + (iter->second)++; + } else { + MS_LOG(EXCEPTION) << "ptr_refs is not equal to ptr_idx"; + } + } + } + ExportNormalOpIr(cnodes); +} + +void MemReuseChecker::SetMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list) { + std::vector curr_mem_infos; + for (const auto &mem : membuf_ptr_list) { + auto mem_checker = + std::make_shared(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_); + curr_mem_infos.push_back(mem_checker); + } + membuf_all_infos_.push_back(curr_mem_infos); + auto split_name = GetSplitName(op_def->scope_full_name()); + all_split_names_.push_back(split_name); + SetTesnorFromAndToInfo(op_def); +} + +void MemReuseChecker::SetAddNewMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list, + size_t op_idx) { + std::vector add_new_curr_mem; + + for (const auto &mem : membuf_ptr_list) { + auto mem_checker = + std::make_shared(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_); + add_new_curr_mem.push_back(mem_checker); + } + add_new_mem_infos_.push_back(add_new_curr_mem); + auto split_name = GetSplitName(op_def->scope_full_name()); + add_new_names_.push_back(split_name); + add_new_op_indxs_.push_back(op_idx); + add_new_stream_ids_.push_back(op_def->stream_id()); +} + +void MemReuseChecker::ExportEachMembufInfo(std::ofstream &ofs) { + size_t i = 0; + std::vector each_node_used_size; + std::vector each_node_allocated_size; + for (const auto &curr_membuf_list : membuf_all_infos_) { + ofs << all_split_names_.at(i) << "\n"; + ++i; + ofs << "mem_num\t" + << "stream_id\t" + << "status\t" + << "tensor_idex\t" + << "mem_size\t" + << "mem_head\t" + << "mem_tail\t" + << "mem_type\t" + << "used_kernel\n"; + size_t curr_used = 0; + size_t curr_allocated = 0; + for (size_t j = 0; j < curr_membuf_list.size(); ++j) { + auto membuf = curr_membuf_list.at(j); + auto used_kernel = membuf->used_kernel_->scope_full_name(); + ofs << "&" << j << "\t" + << "streamID[@" << membuf->used_kernel_->stream_id() << "]" + << "\t" + << "#" << static_cast(membuf->status_) << "\t%" << membuf->index_ << "T" + << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t\t" << membuf->offset_ + membuf->size_ << "\t" + << "\t" << static_cast(membuf->type_) << "\t" << GetSplitName(used_kernel) << "\n"; + if (membuf->status_ == kReused) { + curr_used += membuf->size_; + } + } + if (!curr_membuf_list.empty()) { + curr_allocated = curr_membuf_list.back()->offset_ + curr_membuf_list.back()->size_; + } + each_node_used_size.push_back(curr_used); + each_node_allocated_size.push_back(curr_allocated); + ofs << "curr real used size: \t" << curr_used << "\n"; + ofs << "curr allocated size: \t" << curr_allocated << "\n"; + ofs << "\n\n"; + } + auto optimal_iter = std::max_element(each_node_used_size.begin(), each_node_used_size.end()); + ofs << "theoretical optimal size: " << *optimal_iter << "\n"; + ofs << "each node used size: \n"; + for (auto size : each_node_used_size) { + ofs << size << "\t"; + } + ofs << "\n\n"; + ofs << "each node allocated size: \n"; + for (auto size : each_node_allocated_size) { + ofs << size << "\t"; + } + ofs << "\n\n"; +} + +void MemReuseChecker::ExportMembufInfoIR() { + std::string ir_file_name = "./mem_buf_info.ir"; + std::ofstream ofs(ir_file_name); + int64_t total_reuse_size = 0; + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!"; + } + ofs << "Total static size:\t" << total_ori_static_size_ << "\n"; + ofs << "Graph inputs size:\t" << total_ori_input_size_ << "\n"; + ofs << "Value nodes size:\t" << total_ori_value_size_ << "\n"; + ofs << "Total dynamic size:\t" << total_ori_dy_size_ << "\n"; + ofs << "Total workspace size:\t" << total_ori_wkspace_size_ << "\n"; + // get last membuf_list + if (membuf_all_infos_.empty()) { + return; + } + auto last_membuf_list = membuf_all_infos_.back(); + for (const auto &membuf : last_membuf_list) { + auto checker_size = SizeToLong(membuf->size_); + total_reuse_size += checker_size; + } + ofs << "After reuse size:\t" << total_reuse_size << "\n\n"; + ExportEachMembufInfo(ofs); + ofs.close(); +} + +void MemReuseChecker::ExportAddNewMmebufIR() { + std::string ir_file_name = "./AddNewMembuf.ir"; + std::ofstream ofs(ir_file_name); + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!"; + } + auto check_idx = add_new_mem_infos_.size(); + if (check_idx == add_new_op_indxs_.size() && check_idx == add_new_names_.size() && + check_idx == add_new_stream_ids_.size()) { + size_t i = 0; + for (const auto &curr_membuf_list : add_new_mem_infos_) { + ofs << "op_idx:$" << add_new_op_indxs_.at(i) << "\t" << add_new_names_.at(i) << "\t"; + ofs << "streamID[@" << add_new_stream_ids_.at(i) << "]" + << "\n"; + i++; + ofs << "mem_num\t" + << "status\t" + << "tensor_idex\t" + << "mem_size\t" + << "mem_head\t" + << "mem_tail\t" + << "FromOp\t" + << "ToOp\n"; + for (size_t j = 0; j < curr_membuf_list.size(); ++j) { + auto membuf = curr_membuf_list.at(j); + ofs << "&" << j << "\t" + << "\t" + << "#" << static_cast(membuf->status_) << "\t%" << membuf->index_ << "T" + << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t"; + auto in_idx_iter = tensor_from_.find(membuf->index_); + if (in_idx_iter != tensor_from_.end()) { + for (auto &in_name : in_idx_iter->second) { + ofs << in_name << ","; + } + ofs << "\t"; + } + auto ou_idx_iter = tensor_to_.find(membuf->index_); + if (ou_idx_iter != tensor_to_.end()) { + for (auto &ou_name : ou_idx_iter->second) { + ofs << ou_name << ","; + } + ofs << "\n"; + } + } + ofs << "\n"; + } + } + ofs.close(); +} +} // namespace memreuse +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h new file mode 100644 index 0000000000..3c4a00a3ca --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ +#include +#include +#include +#include +#include +#include +#include "mindspore/core/ir/anf.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" +namespace mindspore { +namespace memreuse { +constexpr auto kSend = "Send"; +constexpr auto kRecv = "Recv"; +constexpr auto kSplitC = '/'; +class MemReuseChecker { + public: + bool IsAddNewMembuf_ = false; + static MemReuseChecker &GetInstance(); + MemReuseChecker(const MemReuseChecker &) = delete; + MemReuseChecker &operator=(const MemReuseChecker &) = delete; + void CheckSignalOps(const CNodePtr &c_node); + void CheckWorkSpace(const std::vector &max_list); + void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx); + bool CheckGraphOutputAssigned(const session::KernelGraph *graph); + void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list, + KernelGraph *graph); + int64_t CalculOriStatic(KernelGraph *graph) const; + int64_t CalculOriInput(const KernelGraph *graph) const; + int64_t CalculOriValue(KernelGraph *graph) const; + int64_t CalculOriDy(const KernelGraph *graph) const; + int64_t CalculOriWk(const KernelGraph *graph) const; + std::string GetSplitName(const std::string &scope_name) const; + int GetTensorIdx(const void *in) const; + void SetMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list); + void SetTesnorFromAndToInfo(const KernelDef *op_def); + void ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx); + void ExportNormalOpIr(const std::vector &cnodes); + void ExportNormalTensorIR(std::ofstream &ofs); + void CheckNormalIR(const session::KernelGraph *graph); + void ExportMembufInfoIR(); + void ExportEachMembufInfo(std::ofstream &ofs); + void SetAddNewMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list, size_t op_idx); + void ExportAddNewMmebufIR(); + void set_kernel_front_map(const std::map> &kernel_front_map) { + kernel_front_map_ = kernel_front_map; + } + void ExportKernelDependence(); + + private: + MemReuseChecker() = default; + ~MemReuseChecker() {} + size_t total_re_wkspe_size_checker_{0}; + std::vector> membuf_all_infos_; + std::vector nor_output_tensors_; + std::vector nor_tensor_sizes_; + std::vector nor_input_tensors_; + std::map ptr_idx_; + std::map ptr_refs_; + std::map> node_ins_; + std::map> node_ous_; + std::vector> add_new_mem_infos_; + std::vector add_new_names_; + std::vector add_new_op_indxs_; + std::vector add_new_stream_ids_; + std::vector all_split_names_; + std::map> tensor_from_; + std::map> tensor_to_; + std::map> kernel_front_map_; + int64_t total_ori_static_size_ = 0; + int64_t total_ori_input_size_ = 0; + int64_t total_ori_value_size_ = 0; + int64_t total_ori_dy_size_ = 0; + int64_t total_ori_wkspace_size_ = 0; +}; +} // namespace memreuse +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc new file mode 100644 index 0000000000..41bf5460c3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc @@ -0,0 +1,344 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/mem_reuse/mem_swap_manager.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace device { +namespace memswap { +void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + graph_manager_ = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(graph_manager_); + auto &kernels = kernel_graph->execution_order(); + for (const auto &kernel : kernels) { + if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { + execution_order_.push_back(kernel); + } + } + + size_t kernel_index = 0; + for (const auto &kernel : execution_order_) { + // parse topo order of kernel + (void)kernel_execution_info_.emplace(kernel.get(), kernel_index++); + // parse tensor info + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + + for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) { + TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx}; + ordered_tensors_.push_back(tensor_info); + } + } + + // parse topo order of user kernel + SaveUserKernelTopoOrder(); + + sort(ordered_tensors_.begin(), ordered_tensors_.end(), + [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); + + auto cur_tensor_size = ordered_tensors_.front().tensor_size_; + for (auto &tensor_info : ordered_tensors_) { + if (cur_tensor_size != tensor_info.tensor_size_) { + cur_tensor_size = tensor_info.tensor_size_; + tensor_size_num_++; + } + } + tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; + tensor_size_threshold_idx_ = 0; + + distance_threshold_ = kernel_index / kDistanceInitFactor; + mem_swap_initialized_ = true; + MS_EXCEPTION_IF_NULL(mem_copy_manager_); + mem_copy_manager_->Init(); +} + +bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { + MS_EXCEPTION_IF_NULL(kernel); + NodeUsersMap &user_map = graph_manager_->node_users(); + auto iter = user_map.find(kernel); + bool adjacent_with_communication_op = false; + if (iter != user_map.end()) { + AnfNodeIndexSet node_set = iter->second; + adjacent_with_communication_op = std::any_of( + node_set.begin(), node_set.end(), + [](const std::pair &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); + } + return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; +} + +void MemSwapManager::SaveUserKernelTopoOrder() { + NodeUsersMap &user_map = graph_manager_->node_users(); + for (const auto &kernel : execution_order_) { + auto iter = user_map.find(kernel); + if (iter == user_map.end()) { + continue; + } + AnfNodeIndexSet node_set = iter->second; + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + for (auto &node_pair : node_set) { + auto user_kernel = node_pair.first; + if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) { + continue; + } + + size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_; + auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1); + auto &output_idx = kernel_with_index.second; + if (kernel_with_index.first.get() != kernel.get()) { + MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort); + } + for (auto &node_user_pair : kernel_exec_info.node_users_map_) { + sort(node_user_pair.second.begin(), node_user_pair.second.end()); + } + } +} + +void MemSwapManager::AddSwapInfo() { + for (const auto &tensor : ordered_tensors_) { + size_t tensor_size = tensor.tensor_size_; + if (tensor_size < tensor_size_threshold_) { + break; + } + + size_t output_idx = tensor.output_idx_; + const AnfNodePtr &kernel = tensor.kernel_; + if (IsCommunicationRelevantOp(kernel)) { + continue; + } + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &node_users_map = kernel_exec_info.node_users_map_; + + auto iter = node_users_map.find(output_idx); + if (iter == node_users_map.end()) { + continue; + } + auto &node_users = iter->second; + bool need_swap = (node_users.size() == 1 && node_users[0] - kernel_exec_info.topo_order_ >= distance_threshold_) || + (node_users.size() > 1 && node_users[1] - node_users[0] >= distance_threshold_); + if (!need_swap) { + continue; + } + AddKernelNeedSwap(kernel, true); + HostAddress host_addr; + host_addr.size = tensor_size; + auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast(&host_addr.addr)); + if (!ret) { + MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed."; + } + kernel_exec_info.host_addrs_[output_idx] = host_addr; + MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx}; + if (node_users.size() > 1) { + AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info); + AddKernelTriggerSwap(execution_order_[node_users[0]], true); + } else { + AddKernelMemSwapInfo(kernel, mem_swap_out_info); + AddKernelTriggerSwap(kernel, true); + } + + size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1; + if (swap_in_order <= kernel_exec_info.topo_order_) { + MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + auto swap_in_kernel = execution_order_[swap_in_order]; + MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx}; + AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info); + AddKernelTriggerSwap(swap_in_kernel, true); + + host_addrs_list_.push_back(host_addr); + } +} + +void MemSwapManager::AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, + const HostAddress &host_address) const { + if (swap_kind == SwapKind::kDeviceToHost) { + mem_copy_manager_->AddMemSwapOutTask(device_address, host_address); + } else if (swap_kind == SwapKind::kHostToDevice) { + mem_copy_manager_->AddMemSwapInTask(device_address, host_address); + } +} + +bool MemSwapManager::SyncMemCopyStream(SwapKind swap_kind) const { + return mem_copy_manager_->SyncMemCopyStream(swap_kind); +} + +DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const { + if (swap_kind == SwapKind::kDeviceToHost) { + return mem_copy_manager_->UpdateSwapOutQueue(); + } else { + return mem_copy_manager_->UpdateSwapInQueue(); + } +} + +// retreat to find a workable swap scheme +bool MemSwapManager::RetreatSwapInfo() { + if (!trigger_swap_) { + trigger_swap_ = true; + } + if (swap_info_already_set_) { + ResetSwapInfo(); + if (distance_threshold_ >= kDistanceLowerBound) { + auto distance_decay_step = execution_order_.size() / kDistanceInitFactor / tensor_size_num_; + distance_threshold_ -= (distance_decay_step > 1 ? distance_decay_step : 1); + } + + while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { + ++tensor_size_threshold_idx_; + if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { + tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; + break; + } + } + + if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) { + MS_LOG(ERROR) << "Retreat swap info failed"; + return false; + } + } else { + swap_info_already_set_ = true; + } + AddSwapInfo(); + return true; +} + +KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const { + MS_EXCEPTION_IF_NULL(kernel); + auto iter = kernel_execution_info_.find(kernel.get()); + if (iter == kernel_execution_info_.end()) { + MS_LOG(EXCEPTION) << "Can not find execution info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return const_cast(iter->second); +} + +void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform) { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + kernel_exec_info.execution_perform_ = perform; +} + +void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + kernel_exec_info.trigger_swap_ = trigger_swap; +} + +void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + kernel_exec_info.need_swap_ = need_swap; +} + +void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, + const std::pair &perform) { + MS_EXCEPTION_IF_NULL(kernel); + kernel_swap_perform_[kernel.get()][output_idx] = perform; +} + +void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { + MS_EXCEPTION_IF_NULL(kernel); + mem_swap_info_[kernel.get()].push_back(mem_swap_info); +} + +float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.execution_perform_; +} + +bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.trigger_swap_; +} + +bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.need_swap_; +} + +const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const { + MS_EXCEPTION_IF_NULL(kernel); + auto iter_kernel = kernel_swap_perform_.find(kernel.get()); + if (iter_kernel == kernel_swap_perform_.end()) { + MS_LOG(EXCEPTION) << "Can not find swap performance data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + + auto &perform_map = iter_kernel->second; + auto iter_output = perform_map.find(output_idx); + if (iter_output == perform_map.end()) { + MS_LOG(EXCEPTION) << "Can not find swap performance data of output[" << output_idx << "] of op[" + << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return iter_output->second; +} + +const std::vector &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { + MS_EXCEPTION_IF_NULL(kernel); + auto iter = mem_swap_info_.find(kernel.get()); + if (iter == mem_swap_info_.end()) { + MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return iter->second; +} + +void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); } + +bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { + auto iter = swap_in_blacklist_.find(device_ptr); + return iter != swap_in_blacklist_.end(); +} + +const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &host_addrs = kernel_exec_info.host_addrs_; + auto iter = host_addrs.find(output_idx); + if (iter == host_addrs.end()) { + MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return iter->second; +} + +bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const { + return mem_copy_manager_->AllocHostPinnedMem(size, addr); +} + +void MemSwapManager::ReleaseHostPinnedMem() { + for (const auto &host_addr : host_addrs_list_) { + if (host_addr.addr) { + mem_copy_manager_->FreeHostPinnedMem(host_addr.addr); + } + } + host_addrs_list_.clear(); +} + +void MemSwapManager::ClearSwapQueue() const { mem_copy_manager_->ClearSwapQueue(); } + +void MemSwapManager::ResetSwapInfo() { + ClearSwapQueue(); + for (auto &kernel_exec_info_pair : kernel_execution_info_) { + auto &kernel_exec_info = kernel_exec_info_pair.second; + kernel_exec_info.trigger_swap_ = false; + kernel_exec_info.need_swap_ = false; + kernel_exec_info.host_addrs_.clear(); + } + ReleaseHostPinnedMem(); + swap_in_blacklist_.clear(); + mem_swap_info_.clear(); +} +} // namespace memswap +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h new file mode 100644 index 0000000000..d8620c8516 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include "backend/optimizer/mem_reuse/mem_copy_manager.h" + +using PerformPair = std::pair; +namespace mindspore { +namespace device { +namespace memswap { +class MemSwapManager { + public: + explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager) + : tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) { + mem_copy_manager_ = mem_copy_manager; + } + + MemSwapManager(const MemSwapManager &) = delete; + + MemSwapManager &operator=(const MemSwapManager &) = delete; + + ~MemSwapManager() = default; + + void Init(const mindspore::session::KernelGraph *kernel_graph); + + void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, + const HostAddress &host_address) const; + + bool SyncMemCopyStream(SwapKind swap_kind) const; + + DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const; + + // retreat to find a workable swap scheme + bool RetreatSwapInfo(); + + bool trigger_swap() const { return trigger_swap_; } + + bool mem_swap_init() const { return mem_swap_initialized_; } + + KernelExecutionInfo &SearchKernelExecutionInfo(const AnfNodePtr &kernel) const; + + void AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform); + + float QueryKernelExecutionPerform(const AnfNodePtr &kernel) const; + + void AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, const PerformPair &perform); + + const PerformPair &QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const; + + bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; + + bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const; + + const std::vector &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; + + void InsertSwapInBlackList(const void *device_ptr); + + bool FindInSwapInBlackList(const void *device_ptr) const; + + const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const; + + bool AllocHostPinnedMem(size_t size, void **addr) const; + + void ReleaseHostPinnedMem(); + + void ClearSwapQueue() const; + + private: + void AddSwapInfo(); + + void ResetSwapInfo(); + + void SaveUserKernelTopoOrder(); + + void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); + + void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap); + + void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); + + bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; + + std::vector execution_order_; + std::vector ordered_tensors_; + std::unordered_map kernel_execution_info_; + std::unordered_map> kernel_swap_perform_; + // trigger swap kernel key : MemSwapInfo of kernel need to be swapped + std::unordered_map> mem_swap_info_; + std::vector host_addrs_list_; + std::unordered_set swap_in_blacklist_; + + size_t tensor_size_threshold_; + size_t tensor_size_threshold_idx_; + size_t tensor_size_num_; + size_t distance_threshold_; + + MemCopyManagerPtr mem_copy_manager_{nullptr}; + FuncGraphManagerPtr graph_manager_{nullptr}; + bool mem_swap_initialized_{false}; + bool swap_info_already_set_{false}; + bool trigger_swap_{false}; + + static constexpr size_t kDistanceInitFactor = 3; + static constexpr size_t kDistanceLowerBound = 3; +}; +using MemSwapManagerPtr = std::shared_ptr; +} // namespace memswap +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc new file mode 100644 index 0000000000..900dd0d563 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/pass/add_atomic_clean.h" +#include +#include +#include +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "utils/graph_utils.h" +#include "utils/log_adapter.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +namespace { + +static std::vector g_output_idx; + +bool HasAtomic(const AnfNodePtr &input) { + if (IsPrimitiveCNode(input)) { + const auto &cnode = input->cast(); + const auto &prim = GetValueNode(cnode->input(0)); + return prim->HasAttr("atomic_add"); + } + return false; +} + +std::vector CalCleanSize(const CNodePtr &pre_node) { + MS_EXCEPTION_IF_NULL(pre_node); + std::vector clean_size_list; + // clean output + for (auto &index : g_output_idx) { + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(pre_node, index); + size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); + std::vector shape = AnfAlgo::GetOutputDeviceShape(pre_node, index); + auto size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + clean_size_list.push_back((size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize); + } + MS_LOG(DEBUG) << "Clear output size: " << clean_size_list.size() << ", pre_node: " << pre_node->fullname_with_scope(); + return clean_size_list; +} + +CNodePtr CreateTbeAtomicCleanNode(const std::shared_ptr &kernel_graph, + const mindspore::CNodePtr &pre_node) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(pre_node); + auto clean_zero_prim = std::make_shared(kAtomicAddrCleanOpName); + auto new_value_node = NewValueNode(clean_zero_prim); + std::vector inputs = {new_value_node}; + CNodePtr clean_zero = kernel_graph->NewCNode(inputs); + AbstractBasePtr abstract = std::make_shared(); + clean_zero->set_abstract(abstract); + auto builder = std::make_shared(); + builder->SetKernelType(KernelType::TBE_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clean_zero.get()); + auto clean_size = CalCleanSize(pre_node); + AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clean_zero); + AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(g_output_idx), clean_zero); + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clean_zero.get()); + return clean_zero; +} +} // namespace + +void AddAtomicClean(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + if (mng == nullptr) { + mng = Manage(kernel_graph, true); + kernel_graph->set_manager(mng); + } + auto &todos = kernel_graph->execution_order(); + for (auto iter = todos.cbegin(); iter != todos.end(); ++iter) { + auto node = *iter; + if (AnfAlgo::IsGraphKernel(node) && kernel_graph->nodes().contains(node)) { + auto fg = GetValueNode(node->input(kAnfPrimitiveIndex)); + MS_EXCEPTION_IF_NULL(fg); + auto input = fg->get_return()->input(1); + if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { + const auto &cnode = input->cast(); + for (size_t i = 0; i < cnode->inputs().size(); ++i) { + if (HasAtomic(cnode->input(i))) { + g_output_idx.push_back(i - 1); + } + } + } else if (HasAtomic(input)) { + g_output_idx.push_back(0); + } + + if (!g_output_idx.empty()) { + auto zero_node = CreateTbeAtomicCleanNode(kernel_graph, node); + auto depend = kernel_graph->NewCNode({NewValueNode(prim::kPrimDepend), node->input(1), zero_node}); + std::vector new_input = node->inputs(); + new_input[1] = depend; + auto new_cnode = std::make_shared(new_input, kernel_graph); + // Set abstract + new_cnode->set_abstract(node->abstract()); + // Set kernel info + new_cnode->set_kernel_info(node->kernel_info_ptr()); + mng->Replace(node, new_cnode); + g_output_idx.clear(); + } + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h new file mode 100644 index 0000000000..7e3fbdb472 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ + +#include +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +void AddAtomicClean(const std::shared_ptr &kernel_graph); +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc new file mode 100644 index 0000000000..133a7e764a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/common_subexpression_elimination.h" +#include +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace opt { +namespace { +bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(main); + MS_EXCEPTION_IF_NULL(node); + auto main_kernel_info = dynamic_cast(main->kernel_info()); + auto node_kernel_info = dynamic_cast(node->kernel_info()); + if (main_kernel_info == nullptr && node_kernel_info == nullptr) { + return true; + } + if (main_kernel_info != nullptr && node_kernel_info != nullptr) { + return *main_kernel_info == *node_kernel_info; + } + return false; +} +} // namespace + +bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { + MS_EXCEPTION_IF_NULL(main); + MS_EXCEPTION_IF_NULL(node); + + bool replace = false; + if (main->isa() && node->isa()) { + auto main_value = GetValueNode(main); + auto node_value = GetValueNode(node); + if (main_value->isa() && node_value->isa()) { + replace = false; + } else if (main_value->isa() && node_value->isa()) { + replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); + } else { + replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); + } + } else if (main->isa() && node->isa()) { + if (!CheckEqualKernelBuildInfo(main, node)) { + replace = false; + } else { + auto c_main = main->cast(); + MS_EXCEPTION_IF_NULL(c_main); + auto c_node = node->cast(); + MS_EXCEPTION_IF_NULL(c_node); + const auto &inp1 = c_main->inputs(); + const auto &inp2 = c_node->inputs(); + if (inp1.size() == inp2.size()) { + bool appsame = true; + for (size_t j = 0; j < inp1.size(); j++) { + MS_EXCEPTION_IF_NULL(inp1[j]); + MS_EXCEPTION_IF_NULL(inp2[j]); + if (!(*inp1[j] == *inp2[j])) { + appsame = false; + break; + } + } + replace = appsame; + } + } + } + return replace; +} + +bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto backend_cse = std::make_shared(); + return backend_cse->Cse(func_graph, func_graph->manager()); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h new file mode 100644 index 0000000000..bac870e59f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ +#include "backend/optimizer/common/pass.h" +#include "frontend/optimizer/cse.h" + +namespace mindspore { +namespace opt { +class CommonSubexpressionElimination : public Pass { + public: + CommonSubexpressionElimination() : Pass("cse") {} + ~CommonSubexpressionElimination() override = default; + bool Run(const FuncGraphPtr &func_graph) override; +}; + +class BackendCSE : public CSE { + public: + BackendCSE() = default; + ~BackendCSE() override = default; + bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc new file mode 100644 index 0000000000..3ba055880c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc @@ -0,0 +1,274 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/communication_op_fusion.h" + +#include +#include +#include + +#include "utils/graph_utils.h" +#include "frontend/operator/ops.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "frontend/parallel/context.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr auto kAttrDefaultGroup = "default_group"; +constexpr auto kAttrDefaultOp = "default_op"; + +kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index, + size_t end_index) { + if (end_index >= communication_op_info.communication_op_nodes.size()) { + MS_LOG(EXCEPTION) << "end index out of vector size"; + } + std::vector inputs_device_format; + std::vector outputs_device_format; + std::vector inputs_device_type; + std::vector outputs_device_type; + std::vector> outputs_shape; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + for (size_t idx = start_index; idx <= end_index; ++idx) { + auto cnode = communication_op_info.communication_op_nodes[idx]; + MS_EXCEPTION_IF_NULL(cnode); + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); + inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); + outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); + } + builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); + builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); + builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); + } + builder.SetInputsFormat(inputs_device_format); + builder.SetOutputsFormat(outputs_device_format); + builder.SetInputsDeviceType(inputs_device_type); + builder.SetOutputsDeviceType(outputs_device_type); + return builder.Build(); +} + +std::string GetFusionGroupKey(const AnfNodePtr &node) { + auto primitive = AnfAlgo::GetCNodePrimitive(node); + MS_EXCEPTION_IF_NULL(primitive); + ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion); + if (attr_fusion == nullptr) { + return ""; + } + int fusion = GetValue(attr_fusion); + if (fusion == 0) { + return ""; + } + std::string group = kAttrDefaultGroup; + ValuePtr attr_group = primitive->GetAttr(kAttrGroup); + if (attr_group != nullptr) { + group = GetValue(attr_group); + } + std::string op = kAttrDefaultOp; + ValuePtr attr_op = primitive->GetAttr(kAttrOp); + if (attr_op != nullptr) { + op = GetValue(attr_op); + } + return group + op + std::to_string(fusion); +} +} // namespace + +bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, + std::vector *segment_index, const std::string &group) const { + MS_EXCEPTION_IF_NULL(segment_num); + MS_EXCEPTION_IF_NULL(segment_index); + size_t communication_op_node_size = communication_op_info.communication_op_nodes.size(); + MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size; + + auto parallel_context = parallel::ParallelContext::GetInstance(); + MS_EXCEPTION_IF_NULL(parallel_context); + const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); + + size_t segments = 0; + if (split_indices.size() != 0) { + uint32_t last_index = 0; + for (size_t i = 0; i < split_indices.size(); ++i) { + uint32_t index = split_indices[i]; + if (index <= last_index || index >= communication_op_node_size) { + MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index; + } + segment_index->push_back(index); + last_index = index; + segments++; + } + if (last_index != communication_op_node_size - 1) { + segment_index->push_back(communication_op_node_size - 1); + segments++; + } + } else { + segments = groups_; + for (size_t i = 0; i < segments - 1; ++i) { + segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1); + } + segment_index->push_back(communication_op_node_size - 1); + } + + if (segments >= communication_op_node_size) { + MS_LOG(INFO) << "fusion not changed: segment_num=" << segments + << ", communication_op_node_size=" << communication_op_node_size; + return false; + } + if (segment_index->at(segments - 1) != communication_op_node_size - 1) { + MS_LOG(EXCEPTION) << "the last segment index is invalid."; + } + for (size_t i = 0; i < segments - 1; ++i) { + if (segment_index->at(i) > segment_index->at(i + 1)) { + MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ " + << i + 1 << "]=" << segment_index->at(i + 1); + } + } + *segment_num = segments; + return true; +} + +AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, + const CommunicationOpInfo &communication_op_info, + size_t start_index, size_t end_index) const { + MS_EXCEPTION_IF_NULL(func_graph); + auto prim = std::make_shared(op_name_); + MS_EXCEPTION_IF_NULL(prim); + std::vector fusion_inputs = {NewValueNode(prim)}; + // get all inputs of current segment + if (end_index >= communication_op_info.communication_op_nodes.size()) { + MS_LOG(EXCEPTION) << "end index out of vector size"; + } + for (size_t idx = start_index; idx <= end_index; ++idx) { + auto cnode = communication_op_info.communication_op_nodes[idx]; + MS_EXCEPTION_IF_NULL(cnode); + fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + } + AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs); + MS_EXCEPTION_IF_NULL(fused_node); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + fused_node->set_kernel_info(kernel_info); + AbstractBasePtrList abstract_list; + for (size_t idx = start_index; idx <= end_index; ++idx) { + auto cnode = communication_op_info.communication_op_nodes[idx]; + MS_EXCEPTION_IF_NULL(cnode); + AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node); + AnfAlgo::CopyNodeAttr("op", cnode, fused_node); + AnfAlgo::CopyNodeAttr("group", cnode, fused_node); + abstract_list.push_back(cnode->abstract()); + } + auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get()); + auto abstract_tuple = std::make_shared(abstract_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + fused_node->set_abstract(abstract_tuple); + return fused_node; +} + +bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, + size_t segment_num, const std::vector &segment_index) const { + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + bool changed = false; + size_t start_index = 0; + for (size_t segment_idx = 0; segment_idx < segment_num; ++segment_idx) { + size_t end_index = segment_index.at(segment_idx); + if (end_index - start_index < 1) { + start_index = end_index + 1; + continue; + } + AnfNodePtr new_communication_op = + CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index); + // replace old communication op with new communication op + for (auto idx = start_index; idx <= end_index; ++idx) { + std::vector tuple_getitem_input; + tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem)); + tuple_getitem_input.push_back(new_communication_op); + auto index = NewValueNode(SizeToInt(idx - start_index)); + MS_EXCEPTION_IF_NULL(index); + auto imm = std::make_shared(idx - start_index); + MS_EXCEPTION_IF_NULL(imm); + auto abstract_scalar = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_scalar); + index->set_abstract(abstract_scalar); + tuple_getitem_input.push_back(index); + AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input); + MS_EXCEPTION_IF_NULL(tuple_getitem); + auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx); + MS_EXCEPTION_IF_NULL(communication_op_node_item); + tuple_getitem->set_abstract(communication_op_node_item->abstract()); + if (!manager->Replace(communication_op_node_item, tuple_getitem)) { + MS_LOG(EXCEPTION) << "manager replace node failed"; + } + } + start_index = end_index + 1; + changed = true; + } + return changed; +} + +bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + const float input_grad_size_num = 0.0; + const float input_grad_time_num = 0.0; + // divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion + std::unordered_map candidate_groups; + std::vector node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == op_name_) { + std::string key = GetFusionGroupKey(node); + if (key.empty()) { + continue; + } + if (candidate_groups.find(key) == candidate_groups.end()) { + CommunicationOpInfo communication_op_info; + candidate_groups[key] = communication_op_info; + } + candidate_groups[key].communication_op_nodes.push_back(node->cast()); + candidate_groups[key].input_grad_size.push_back(input_grad_size_num); + candidate_groups[key].input_grad_time.push_back(input_grad_time_num); + } + } + // split candidate group to segments according to _group class member + bool changed = false; + for (auto &it : candidate_groups) { + if (it.second.communication_op_nodes.size() <= 1) { + continue; + } + auto first_node = it.second.communication_op_nodes[0]; + if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr(first_node, kAttrIndex) > 0) { + std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(), + [](const CNodePtr &a, const CNodePtr &b) { + return AnfAlgo::GetNodeAttr(a, kAttrIndex) < AnfAlgo::GetNodeAttr(b, kAttrIndex); + }); + } + size_t segment_num = 0; + std::vector segment_index; + if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) { + if (DoFusion(func_graph, it.second, segment_num, segment_index)) { + changed = true; + } + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h new file mode 100644 index 0000000000..0e7cf9762d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h @@ -0,0 +1,80 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ +#include +#include +#include + +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +struct CommunicationOpInfo { + std::vector communication_op_nodes; + std::vector input_grad_size; + std::vector input_grad_time; +}; + +class CommunicationOpFusion : public Pass { + public: + explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1) + : Pass(name), op_name_(std::move(op_name)), groups_(groups) {} + ~CommunicationOpFusion() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num, + const std::vector &segment_index) const; + AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, + const CommunicationOpInfo &communication_op_info, size_t start_index, + size_t end_index) const; + bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, + std::vector *segment_index, const std::string &group) const; + std::string op_name_; + size_t groups_ = 1; +}; + +class AllReduceFusion : public CommunicationOpFusion { + public: + explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {} + ~AllReduceFusion() override = default; +}; + +class AllGatherFusion : public CommunicationOpFusion { + public: + explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} + ~AllGatherFusion() override = default; +}; + +class BroadcastFusion : public CommunicationOpFusion { + public: + explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} + ~BroadcastFusion() override = default; +}; + +class ReduceScatterFusion : public CommunicationOpFusion { + public: + explicit ReduceScatterFusion(size_t groups = 1) + : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} + ~ReduceScatterFusion() override = default; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc new file mode 100644 index 0000000000..814ad9567c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/const_input_to_attr_registry.h" + +#include + +#include "utils/utils.h" +#include "utils/log_adapter.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { + Register(prim::kPrimCast->name(), {1}); + Register(prim::kPrimAvgPoolGrad->name(), {0}); + Register(prim::kPrimConv2DBackpropInput->name(), {2}); + Register(prim::kPrimConv2DBackpropFilter->name(), {2}); + Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); + Register(prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), {0}); + Register(prim::kPrimReshape->name(), {1}); + Register(prim::kPrimReduceMax->name(), {1}); + Register(prim::kPrimReduceMin->name(), {1}); + Register(prim::kPrimReduceSum->name(), {1}); + Register(prim::kPrimReduceMean->name(), {1}); + Register(prim::kPrimGatherV2->name(), {2}); + Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5}); + Register(prim::kPrimEmbeddingLookupCommGrad->name(), {1}); + Register(prim::kPrimSubscalar->name(), {1}); + Register(prim::kPrimTranspose->name(), {1}); + Register(prim::kPrimUnsortedSegmentSum->name(), {2}); + Register(prim::kPrimOneHot->name(), {1}); + Register(prim::kPrimConcat->name(), {0}); + Register(prim::kPrimCumSum->name(), {1}); + Register(prim::kPrimCumProd->name(), {1}); + Register(prim::kPrimReduceAll->name(), {1}); + Register(prim::kPrimUnsortedSegmentMin->name(), {2}); + Register(kSparseGatherV2, {2}); + Register(kUnsortedSegmentProdOpName, {2}); + Register(kSimpleMeanGradOpName, {1}); + Register(kMeanGradOpName, {1}); + Register(kSliceOpName, {1, 2}); + Register(kSliceGradOpName, {2, 3}); + Register(kTileOpName, {1}); + Register(kScatterNdOpName, {2}); + Register(kStridedSliceAssignOpName, {1, 2, 3}); + Register(kStridedSliceOpName, {1, 2, 3}); + Register(kFlattenGradOpName, {1}); + Register(kExpandDimsOpName, {1}); + Register(kSplitOpName, {0}); + Register(kErfOpName, {1}); + Register(kSparseApplyAdagradOpName, {2}); + Register(kResizeNearestNeighborGradOpName, {1}); + Register(kResizeNearestNeighborV2OpName, {1}); + Register(kResizeNearestNeighborV2GradOpName, {1}); + Register(kApplyRMSPropOpname, {5, 6, 7}); + Register(kResizeBilinearV2OpName, {1}); + Register(kReduceProdOpName, {1}); + Register(kCumprodOpName, {1}); + Register(kSpaceToBatchOpName, {1}); + Register(kBatchToSpaceOpName, {1}); + Register(kPadOpName, {1}); + Register(kPushOpName, {1}); +} + +ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() { + static ConstInputToAttrInfoRegistry instance; + return instance; +} + +void ConstInputToAttrInfoRegistry::Register(const ConstInputToAttrInfoRegister ®) { + auto op_name = reg.GetOpName(); + if (op_input_to_attr_map_.find(op_name) == op_input_to_attr_map_.end()) { + (void)op_input_to_attr_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " const2attr register successfully!"; + } +} + +void ConstInputToAttrInfoRegistry::Register(const std::string &op_name, + const std::unordered_set &input_attr_set) { + if (op_input_to_attr_map_.find(op_name) == op_input_to_attr_map_.end()) { + ConstInputToAttrInfoRegister reg(op_name); + (void)reg.SetConstInputToAttr(input_attr_set); + (void)op_input_to_attr_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " const2attr register successfully!"; + } +} + +bool ConstInputToAttrInfoRegistry::GetRegisterByOpName(const std::string &op_name, + ConstInputToAttrInfoRegister *reg) const { + if (op_input_to_attr_map_.find(op_name) != op_input_to_attr_map_.end()) { + *reg = op_input_to_attr_map_.at(op_name); + MS_LOG(DEBUG) << op_name << " const2attr find in registery."; + return true; + } + return false; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.h b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h similarity index 100% rename from mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.h rename to mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc new file mode 100644 index 0000000000..51d399bbcd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/context/ms_context.h" +#include "utils/utils.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t strides_index = 5; + +bool GetStridesValues(const CNodePtr &strided_slice_grad, ValuePtrList *strides_values) { + MS_EXCEPTION_IF_NULL(strided_slice_grad); + if (strided_slice_grad->size() < 6) { + MS_LOG(DEBUG) << "Op strided_slice_grad's inputs size less than 6, graph not changed"; + return false; + } + auto strides_input = strided_slice_grad->input(strides_index); + MS_EXCEPTION_IF_NULL(strides_input); + auto strides_value_node = strides_input->cast(); + if (strides_value_node == nullptr) { + MS_LOG(DEBUG) << "strides is not a value node."; + return false; + } + auto value = strides_value_node->value(); + if (value == nullptr) { + MS_LOG(DEBUG) << "strides has no value."; + return false; + } + auto value_tuple = value->cast(); + if (value_tuple == nullptr) { + MS_LOG(DEBUG) << "strides is not a value tuple."; + return false; + } + *strides_values = value_tuple->value(); + return true; +} + +bool CheckValues(const ValuePtrList &strides_values) { + if (strides_values.empty()) { + MS_LOG(DEBUG) << "strides_values is empty"; + return false; + } + for (auto &value : strides_values) { + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + auto scalar = value->cast(); + MS_EXCEPTION_IF_NULL(scalar); + if (!scalar->isa()) { + MS_LOG(DEBUG) << "strides value is not a Integer"; + return false; + } + if (GetValue(scalar) != 1) { + MS_LOG(DEBUG) << "StridedSliceGrad has no 1 value"; + return false; + } + } else { + MS_LOG(DEBUG) << "The value " << value << "of tuple is not a scalar"; + return false; + } + } + return true; +} + +bool CheckAttrs(const CNodePtr &strided_slice_grad) { + MS_EXCEPTION_IF_NULL(strided_slice_grad); + if (!AnfAlgo::HasNodeAttr(kAttrNewAxisMask, strided_slice_grad) || + !AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, strided_slice_grad)) { + MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not exist in cnode[" + strided_slice_grad->DebugString() + "]"; + return false; + } + auto new_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrNewAxisMask); + auto shrink_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrShrinkAxisMask); + if (new_axis_mask != 0 || shrink_axis_mask != 0) { + MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not equal 0"; + return false; + } + return true; +} +} // namespace + +const BaseRef ConstToAttrStridedSliceGradPass::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto strided_slice_grad_prim = std::make_shared(kStridedSliceGradOpName); + return VectorRef({strided_slice_grad_prim, Xs}); +} + +const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto strided_slice_grad = node->cast(); + MS_EXCEPTION_IF_NULL(strided_slice_grad); + + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + + if (ms_context->device_target() == kAscendDevice) { + if (!CheckAttrs(strided_slice_grad)) { + MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed"; + return nullptr; + } + + ValuePtrList strides_values; + if (!GetStridesValues(strided_slice_grad, &strides_values)) { + return nullptr; + } + + if (!CheckValues(strides_values)) { + MS_LOG(INFO) << "Check strides' values failed, graph not changed"; + return nullptr; + } + } + + ConstInputToAttr(strided_slice_grad, {1, 2, 3, 4}); + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h new file mode 100644 index 0000000000..83b44d5f51 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConstToAttrStridedSliceGradPass : public PatternProcessPass { + public: + explicit ConstToAttrStridedSliceGradPass(bool multigraph = true) + : PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {} + ~ConstToAttrStridedSliceGradPass() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc new file mode 100644 index 0000000000..f2e35351b4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/convert_const_input_to_attr.h" + +#include +#include +#include +#include + +#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/helper.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace opt { +const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { + return nullptr; + } + std::vector todos; + if (AnfAlgo::IsGraphKernel(node)) { + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + kernel::GetValidKernelNodes(sub_graph, &todos); + } else { + todos.push_back(node); + } + + for (auto &t : todos) { + CNodePtr cnode = t->cast(); + ConstInputToAttrInfoRegister reg; + if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) { + continue; + } + ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); + } + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h new file mode 100644 index 0000000000..e6def42fa1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ +#include +#include +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertConstInputToAttr : public PatternProcessPass { + public: + explicit ConvertConstInputToAttr(bool multigraph = true) + : PatternProcessPass("convert_const_input_to_attr", multigraph) {} + ~ConvertConstInputToAttr() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::unordered_map> op_input_attr_map_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc new file mode 100644 index 0000000000..f204841f3c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc @@ -0,0 +1,152 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/convert_const_input_to_tensor_input.h" + +#include +#include +#include + +#include "utils/graph_utils.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace opt { +namespace { +ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + ValueNodePtr new_value_node = std::make_shared(value_node->value()); + new_value_node->set_abstract(value_node->abstract()); + // create kernel_info fo new value node + auto kernel_info = std::make_shared(); + new_value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + // set value node initial device data type = infer data type + std::vector types; + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { + types.push_back(kTypeUnknown); + } + kernel_build_info_builder->SetOutputsDeviceType(types); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + return new_value_node; +} + +AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) { + MS_EXCEPTION_IF_NULL(input_node); + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + tensor::TensorPtr tensor_ptr = nullptr; + if (value->isa()) { + tensor_ptr = ScalarToTensor(value->cast()); + } else if (value->isa()) { + tensor_ptr = CreateTupleTensor(value->cast()); + } else { + MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple"; + } + if (tensor_ptr == nullptr) { + MS_LOG(WARNING) << "Create tensor failed"; + return nullptr; + } + auto tensor_input = std::make_shared(tensor_ptr); + MS_EXCEPTION_IF_NULL(tensor_input); + tensor_input->set_abstract(tensor_ptr->ToAbstract()); + if (kernel_graph != nullptr) { + tensor_input = kernel_graph->NewValueNode(tensor_input); + kernel_graph->AddValueNodeToGraph(tensor_input); + } else { + tensor_input = MakeValueNode(tensor_input); + } + tensor_input->set_scope(input_node->scope()); + return tensor_input; +} + +AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + std::vector new_inputs; + auto kernel_graph = func_graph->cast>(); + auto inputs = cnode->inputs(); + new_inputs.push_back(inputs[0]); + bool need_update = false; + // the first input is primitive node which is not the real input + for (size_t i = 0; i < inputs.size() - 1; ++i) { + auto input_node = inputs[i + 1]; + if (IsValueNode(input_node) || IsValueNode(input_node)) { + auto tensor_input = CreateTensorInput(kernel_graph, input_node); + if (tensor_input == nullptr) { + new_inputs.push_back(input_node); + continue; + } + new_inputs.push_back(tensor_input); + need_update = true; + } else { + new_inputs.push_back(input_node); + } + } + if (need_update) { + MS_EXCEPTION_IF_NULL(func_graph); + auto new_cnode = func_graph->NewCNode(new_inputs); + MS_EXCEPTION_IF_NULL(new_cnode); + new_cnode->set_abstract(cnode->abstract()); + new_cnode->set_scope(cnode->scope()); + AnfAlgo::CopyNodeAttrs(cnode, new_cnode); + return new_cnode; + } + return nullptr; +} + +AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) { + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + auto mng = sub_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + std::vector todo; + std::vector> graph_rets; + kernel::GetValidKernelNodes(sub_graph, &todo); + kernel::GetGraphRealOutput(sub_graph, &graph_rets); + + for (auto &t : todo) { + auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast()); + if (t_new_node != nullptr && t_new_node != t) { + (void)mng->Replace(t, t_new_node); + } + } + + return node; +} +} // namespace + +const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { + return nullptr; + } + if (AnfAlgo::IsGraphKernel(node)) { + return ProcessGraphKernelOp(node); + } else { + return ConstInputToTensorInput(func_graph, node->cast()); + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h new file mode 100644 index 0000000000..072652497a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertConstInputToTensorInput : public PatternProcessPass { + public: + explicit ConvertConstInputToTensorInput(bool multigraph = true) + : PatternProcessPass("convert_const_input_to_tensor_input", multigraph) {} + ~ConvertConstInputToTensorInput() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc new file mode 100644 index 0000000000..b96a7af8f3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h" + +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace opt { +namespace { +bool MakeValueNode(const AnfNodePtr &node) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return false; + } + + // create kernel_info fo new value node + auto kernel_info = std::make_shared(); + value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + // set value node initial device data type = infer data type + TypeId infer_data_type; + if (AnfAlgo::GetOutputTensorNum(value_node) == 0) { + infer_data_type = kTypeUnknown; + } else { + infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0); + } + kernel_build_info_builder->SetOutputsDeviceType(std::vector{infer_data_type}); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get()); + return true; +} + +void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node, + std::vector *plant_inputs, std::vector *dyn_input_sizes) { + MS_EXCEPTION_IF_NULL(plant_inputs); + MS_EXCEPTION_IF_NULL(dyn_input_sizes); + MS_EXCEPTION_IF_NULL(graph); + auto output_size = AnfAlgo::GetOutputTensorNum(input_node); + dyn_input_sizes->push_back(output_size); + std::vector convert_inputs; + auto kernel_graph = graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + if (input_node->isa()) { + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node); + } else { + for (size_t index = 0; index < output_size; ++index) { + auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)}, + {AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get()); + convert_inputs.emplace_back(tuple_get_item); + } + } + (void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs)); +} + +void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(cnode_ptr); + MS_EXCEPTION_IF_NULL(graph); + auto &ori_args = cnode_ptr->inputs(); + if (ori_args.size() < 1) { + return; + } + std::vector plant_inputs; + std::vector dyn_input_sizes; + plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]); + for (size_t i = 1; i < ori_args.size(); ++i) { + auto input_node = ori_args[i]; + if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) { + auto input_size = AnfAlgo::GetOutputTensorNum(input_node); + dyn_input_sizes.push_back(input_size); + auto cnode = input_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + for (size_t j = 1; j < inputs.size(); ++j) { + MS_EXCEPTION_IF_NULL(inputs[j]); + if (IsValueNode(inputs[j])) { + auto success = MakeValueNode(inputs[j]); + if (!success) { + MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString(); + } + } + plant_inputs.push_back(inputs[j]); + } + } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) { + ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes); + } else { + dyn_input_sizes.push_back(-1); + plant_inputs.push_back(input_node); + } + } + // If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs. + if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int s) { return s >= 0; })) { + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode_ptr); + cnode_ptr->set_inputs(plant_inputs); + } +} +} // namespace + +const BaseRef ConvertTupleInputToDynamicInput::DefinePattern() const { + VarPtr V = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + if (AnfAlgo::IsGraphKernel(node)) { + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + std::vector todos; + kernel::GetValidKernelNodes(sub_graph, &todos); + for (auto &t : todos) { + ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast()); + } + } else { + ConvertMakeTupleInputToPlantInputs(func_graph, node->cast()); + } + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h new file mode 100644 index 0000000000..63d2415dc5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ + +#include +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertTupleInputToDynamicInput : public PatternProcessPass { + public: + explicit ConvertTupleInputToDynamicInput(bool multigraph = true) + : PatternProcessPass("convert_tuple_input_to_dynamic_input", multigraph) {} + + ~ConvertTupleInputToDynamicInput() override = default; + + const BaseRef DefinePattern() const override; + + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc new file mode 100644 index 0000000000..34ba83ef17 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/convert_tuple_output_to_maketuple.h" + +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(cnode_ptr); + MS_EXCEPTION_IF_NULL(graph); + std::vector convert_inputs = {cnode_ptr->input(0)}; + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode_ptr); ++index) { + auto input_node = AnfAlgo::GetInputNode(cnode_ptr, index); + if (AnfAlgo::IsTupleOutput(input_node)) { + std::vector types; + std::vector> shapes; + std::vector make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; + for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { + make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); + types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); + shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); + } + auto make_tuple = graph->NewCNode(make_tuple_inputs_list); + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); + convert_inputs.emplace_back(make_tuple); + } else { + convert_inputs.push_back(input_node); + } + } + return graph->NewCNode(convert_inputs); +} +} // namespace + +const BaseRef ConvertTupleOutputToMaketuple::DefinePattern() const { + VarPtr V = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { + return nullptr; + } + if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { + return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); + })) { + return ConvertTupleInputToMakeTuple(func_graph, cnode); + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.h b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.h new file mode 100644 index 0000000000..9ff5ca91ed --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H +#define MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H +#include +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertTupleOutputToMaketuple : public PatternProcessPass { + public: + explicit ConvertTupleOutputToMaketuple(bool multigraph = true) + : PatternProcessPass("convert_tuple_output_to_maketuple", multigraph) {} + + ~ConvertTupleOutputToMaketuple() override = default; + + const BaseRef DefinePattern() const override; + + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H diff --git a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc new file mode 100644 index 0000000000..3ef912bcec --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc @@ -0,0 +1,190 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/pass/eliminate_redundant_op.h" +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "frontend/operator/ops.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace opt { +using KernelWithIndex = std::pair; +namespace { +CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector *pass_vector) { + MS_EXCEPTION_IF_NULL(pass_vector); + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::IsRealCNodeKernel(cnode)) { + pass_vector->push_back(make_pair(cnode, IntToSize(1))); + return cnode; + } + + auto input0 = cnode->input(0); + MS_EXCEPTION_IF_NULL(input0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + auto temp_node = cnode->input(index + IntToSize(1)); + MS_EXCEPTION_IF_NULL(temp_node); + pass_vector->push_back(make_pair(cnode, index + IntToSize(1))); + return GetRealPrevCNode(temp_node, 0, pass_vector); + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + auto input2 = cnode->input(2); + MS_EXCEPTION_IF_NULL(input2); + auto value_node = input2->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + pass_vector->push_back(make_pair(cnode, IntToSize(1))); + return GetRealPrevCNode(cnode->input(1), IntToSize(item_idx), pass_vector); + } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { + pass_vector->push_back(make_pair(cnode, IntToSize(1))); + return GetRealPrevCNode(cnode->input(1), 0, pass_vector); + } else { + return nullptr; + } +} + +bool TransOpEliminateCondition(const CNodePtr &, const CNodePtr &) { return true; } + +bool CastEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { + return HasSymmetricalKernelInfo(node1, node2); +} + +bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { + return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) && + AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0); +} + +const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode, + std::vector *pass_vector) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(pass_vector); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + bool has_depend_node = false; + bool has_node_used_more_than_once = false; + auto &users = manager->node_users(); + + auto pass_size = pass_vector->size(); + for (size_t idx = 1; idx <= pass_size - 1; ++idx) { + auto nd = (*pass_vector)[idx].first; + if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) || + AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { + has_depend_node = true; + } + if (users[nd].size() >= 2) { + has_node_used_more_than_once = true; + } + } + + // when no depend node and no node used more than once, no need to rebuild the pass nodes + if (!has_depend_node) { + return prev_cnode->input(1); + } else if (!has_node_used_more_than_once) { + (void)manager->Replace(prev_cnode, prev_cnode->input(1)); + return cnode->input(1); + } else { // rebuild the pass nodes + for (size_t idx = pass_size - 2; idx > 0; --idx) { + auto new_node = func_graph->NewCNode((*pass_vector)[idx].first->inputs()); + new_node->set_input((*pass_vector)[idx].second, + (*pass_vector)[idx + 1].first->input((*pass_vector)[idx + 1].second)); + (*pass_vector)[idx].first = new_node; + } + return (*pass_vector)[1].first; + } +} +} // namespace + +void EliminateRedundantOp::Init() { + (void)redundant_process_map_.emplace(std::pair( + kFour2FiveOpName, std::pair(kFive2FourOpName, TransOpEliminateCondition))); + (void)redundant_process_map_.emplace(std::pair( + kFive2FourOpName, std::pair(kFour2FiveOpName, TransOpEliminateCondition))); + (void)redundant_process_map_.emplace(std::pair( + prim::kPrimCast->name(), std::pair(prim::kPrimCast->name(), CastEliminateCondition))); + (void)redundant_process_map_.emplace(std::pair( + kTransDataOpName, std::pair(kTransDataOpName, TransDataOpEliminateCondition))); +} + +const AnfNodePtr EliminateRedundantOp::DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { + // match the first name + auto name1 = AnfAlgo::GetCNodeName(cnode); + auto it = redundant_process_map_.find(name1); + if (it == redundant_process_map_.end()) { + return nullptr; + } + std::vector pass_vector; + pass_vector.push_back(make_pair(cnode, 1)); + auto prev_cnode = GetRealPrevCNode(cnode->input(1), 0, &pass_vector); + if (prev_cnode == nullptr) { + return nullptr; + } + // match the second name + auto name2 = AnfAlgo::GetCNodeName(prev_cnode); + if (name2 != it->second.first) { + return nullptr; + } + // match condition + auto condition_func = it->second.second; + if (condition_func == nullptr) { + return nullptr; + } + if (!condition_func(cnode, prev_cnode)) { + return nullptr; + } + + return ProcessMatchedNodes(func_graph, cnode, prev_cnode, &pass_vector); +} + +const AnfNodePtr EliminateRedundantOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode == nullptr || func_graph == nullptr) { + return nullptr; + } + + if (AnfAlgo::IsGraphKernel(node)) { + // do eliminate for ops in graph kernel. + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + auto mng = sub_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + std::vector todo; + kernel::GetValidKernelNodes(sub_graph, &todo); + for (auto &t : todo) { + CNodePtr t_cnode = t->cast(); + MS_EXCEPTION_IF_NULL(t_cnode); + auto t_new_node = DoEliminate(sub_graph, t_cnode); + if (t_new_node != nullptr && t_new_node != t) { + (void)mng->Replace(t, t_new_node); + } + } + return node; + } + // do eliminate for single op. + return DoEliminate(func_graph, cnode); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h new file mode 100644 index 0000000000..2fb4715cff --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +using ConditionFunc = std::function; +using RedundantOpPair = std::pair; + +class EliminateRedundantOp : public PatternProcessPass { + public: + explicit EliminateRedundantOp(bool multigraph = true) : PatternProcessPass("eliminate_redundant_op", multigraph) { + Init(); + } + ~EliminateRedundantOp() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + void Init(); + const AnfNodePtr DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; + std::unordered_map redundant_process_map_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.cc new file mode 100644 index 0000000000..8c6cb4beb5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/pass/erase_visit_attr.h" +#include +#include +#include "backend/kernel_compiler/common_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef EraseVisitAttr::DefinePattern() const { + std::shared_ptr V = std::make_shared(Visited); + std::shared_ptr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { + if (node != nullptr && AnfAlgo::IsRealCNodeKernel(node)) { + if (AnfAlgo::IsGraphKernel(node)) { + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + std::vector todos; + kernel::GetValidKernelNodes(fg, &todos); + for (auto &t : todos) { + AnfAlgo::EraseNodeAttr(kAttrVisited, t); + } + } + AnfAlgo::EraseNodeAttr(kAttrVisited, node); + } else { + AnfAlgo::EraseNodeAttr(kAttrVisited, node); + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h new file mode 100644 index 0000000000..37b88a4e39 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class EraseVisitAttr : public PatternProcessPass { + public: + explicit EraseVisitAttr(bool multigraph = true) : PatternProcessPass("erase_visit_attr", multigraph) {} + ~EraseVisitAttr() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc new file mode 100644 index 0000000000..32655f1ec2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc @@ -0,0 +1,222 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/fuse_basic.h" +#include "backend/optimizer/pass/fuse_graph_kernel.h" + +#include +#include +#include +#include +#include +#include + +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "utils/graph_utils.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "vm/segment_runner.h" +#include "debug/draw.h" +#include "debug/anf_ir_dump.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +namespace { +std::vector get_fusable_basic_ops(bool is_before_kernel_select) { + std::vector fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, + prim::kPrimExpandDims}; + if (!is_before_kernel_select) { + fusable_basic_ops.push_back(prim::kPrimCast); + } + return fusable_basic_ops; +} + +IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, + const AnfNodePtr &node) { + if (cur_node == node) { + return FOLLOW; + } + if (!IsPrimitiveCNode(node)) { + return EXCLUDE; + } + + auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); + bool is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); + + return is_fusable ? FOLLOW : EXCLUDE; +} + +std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { + GraphKernelInfo info; + info.is_before_kernel_select = is_before_kernel_select; + // Search fusable nodes according input direction. + auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); + auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); + if (used_nodes.size() > 1) { + used_nodes = RemoveCircle(used_nodes, false); + } + TopoSortForNodeList(&used_nodes); + return used_nodes; +} + +void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { + AnfNodeSet outputs_set; + for (auto out : *outputs) { + outputs_set.insert(out); + } + + AnfNodePtrList vir_outputs; + std::unordered_map eqv; + auto fg_outputs = fg->output(); + if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { + auto cnode = fg_outputs->cast(); + for (size_t i = 1; i < cnode->size(); ++i) { + vir_outputs.push_back(cnode->input(i)); + } + } else { + vir_outputs.push_back(fg_outputs); + } + + if (vir_outputs.size() != outputs->size()) { + MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; + } + bool has_erase_outs = false; + size_t index = -1; + for (auto it = outputs->begin(); it != outputs->end();) { + index++; + auto out = *it; + eqv[out] = vir_outputs[index]; + auto users = mng->node_users()[out]; + bool is_only_control_depend_use = true; + std::vector control_depend_use_index; + std::vector control_depend_nodes; + AnfNodePtr use_out = nullptr; + for (auto &user : users) { + auto use_node = user.first; + if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { + is_only_control_depend_use = false; + continue; + } + if (outputs_set.count(use_node) != 0) { + use_out = use_node; + } + + if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { + control_depend_nodes.push_back(use_node->cast()); + control_depend_use_index.push_back(user.second); + } + } + + if (is_only_control_depend_use && !control_depend_nodes.empty()) { + MS_EXCEPTION_IF_NULL(use_out); + it = outputs->erase(it); + for (size_t i = 0; i < control_depend_nodes.size(); ++i) { + auto control_depend_node = control_depend_nodes[i]; + std::vector new_control_depend_inputs; + for (size_t j = 0; j < control_depend_node->size(); ++j) { + if (j == control_depend_use_index[i]) { + new_control_depend_inputs.push_back(use_out); + } else { + new_control_depend_inputs.push_back(control_depend_node->input(j)); + } + } + auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs); + mng->Replace(control_depend_node, new_control_depend); + has_erase_outs = true; + } + } else { + it++; + } + } + + if (!has_erase_outs) { + return; + } + + AnfNodePtr fg_new_output; + if (outputs->size() > 1) { + std::vector output_args; + output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); + (void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args), + [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); + // Set output for AnfGraph + fg_new_output = fg->NewCNode(output_args); + } else { + fg_new_output = eqv[(*outputs)[0]]; + } + fg->set_output(fg_new_output, true); +} + +void FuseBasic(const std::shared_ptr &kernel_graph, const std::vector &todos, + std::unordered_set *fused_ops, bool is_before_kernel_select) { + auto mng = kernel_graph->manager(); + for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { + auto node = (*iter)->cast(); + if (node == nullptr) { + continue; + } + if (fused_ops->count(node)) { + continue; + } + auto fusable_basic_ops = get_fusable_basic_ops(is_before_kernel_select); + bool is_basic_op = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); + if (!is_basic_op || !kernel_graph->nodes().contains(node)) { + continue; + } + + auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select); + if (fuse_nodes.size() <= 1) { + continue; + } + + FuncGraphPtr fg; + AnfNodePtrList inputs; + AnfNodePtrList outputs; + std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); + RemoveControlDependOut(fg, &outputs, mng); + auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select); + + ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); + + // Set graph kernel attr + std::string fuse_op_name = ""; + for (auto &fuse_node : fuse_nodes) { + fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_"; + } + fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); + fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); + } +} +} // namespace + +void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + if (mng == nullptr) { + mng = Manage(kernel_graph, true); + kernel_graph->set_manager(mng); + } + std::unordered_set fused_ops; + auto todos = TopoSort(kernel_graph->get_return()); + std::reverse(todos.begin(), todos.end()); + FuseBasic(kernel_graph, todos, &fused_ops, is_before_kernel_select); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h new file mode 100644 index 0000000000..9b3916fe28 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h @@ -0,0 +1,29 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc new file mode 100644 index 0000000000..e04110d8a0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc @@ -0,0 +1,562 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/fuse_graph_kernel.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "utils/graph_utils.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "vm/segment_runner.h" +#include "debug/draw.h" +#include "debug/anf_ir_dump.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +std::vector get_fusable_basic_ops(bool is_before_kernel_select) { + std::vector fusable_basic_ops = { + prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum, + prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt, + prim::kPrimReciprocal, prim::kPrimExpandDims, prim::kPrimLessEqual}; + if (!is_before_kernel_select) { + fusable_basic_ops.push_back(prim::kPrimCast); + } + return fusable_basic_ops; +} + +std::vector get_fusable_basic_ops_with_reduce(bool is_before_kernel_select) { + std::vector fusable_basic_ops_with_reduce; + if (!is_before_kernel_select) { + fusable_basic_ops_with_reduce.push_back(prim::kPrimCast); + } + return fusable_basic_ops_with_reduce; +} + +std::vector get_reduce_ops() { + std::vector reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin, + prim::kPrimReduceMax, prim::kPrimReduceAll}; + return reduce_ops; +} + +void GetGraphKernelInfo(const FuncGraphPtr fg, GraphKernelInfo *info) { + MS_EXCEPTION_IF_NULL(fg); + auto reduce_ops = get_reduce_ops(); + const auto &nodes = fg->nodes(); + info->op_type = ELEWISE; + info->cal_step = -1; + info->reduce_op_num = 0; + for (auto node : nodes) { + auto cnode = node->cast(); + if (cnode == nullptr) { + continue; + } + info->cal_step++; + auto prim = GetValueNode(cnode->input(0)); + if (prim != nullptr) { + bool is_reudce = std::any_of(reduce_ops.begin(), reduce_ops.end(), [&prim](const PrimitivePtr &op) { + return op->hash() == prim->hash() && op->name() == prim->name(); + }); + if (is_reudce) { + info->op_type = REDUCE; + info->reduce_op_num++; + } + } + } +} + +bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) { + auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); + auto fusable_basic_ops_with_reduce = get_fusable_basic_ops_with_reduce(info.is_before_kernel_select); + bool is_fusable = false; + if (info.op_type == REDUCE && + (info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) { + is_fusable = std::any_of(fusable_basic_ops_with_reduce.begin(), fusable_basic_ops_with_reduce.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); + } else { + is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); + } + + return is_fusable; +} + +IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, + const AnfNodePtr &node) { + if (cur_node == node) { + return FOLLOW; + } + if (!IsPrimitiveCNode(node)) { + return EXCLUDE; + } + + bool is_fusable = IsFuse(info, node); + return is_fusable ? FOLLOW : EXCLUDE; +} + +IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, + const AnfNodePtr &node) { + if (cur_node == node) { + return FOLLOW; + } + if (AnfAlgo::IsGraphKernel(node)) { + auto cnode = node->cast(); + auto fg = GetValueNode(cnode->input(kAnfPrimitiveIndex)); + auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + MS_EXCEPTION_IF_NULL(fg_attr_val); + auto fg_attr = GetValue(fg_attr_val); + if (fg_attr == kApplyMomentumOpName) { + return FOLLOW; + } + return EXCLUDE; + } + if (!IsPrimitiveCNode(node)) { + return EXCLUDE; + } + + bool is_fusable = IsFuse(info, node); + return is_fusable ? FOLLOW : EXCLUDE; +} + +bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &check_node, + std::set *cached_unconnected_set) { + if (!check_node->isa() || AnfAlgo::IsGraphKernel(check_node)) { + return false; + } + + auto cnode = check_node->cast(); + const auto &inputs = cnode->inputs(); + // there is a input not in fused_op_set, but the input depends on the fused_op_set + bool has_circle = false; + for (auto input : inputs) { + if (input->isa() && !fused_op_set.count(input)) { + std::set done; + std::vector todos = {input}; + while (!todos.empty()) { + auto node = todos.back(); + todos.pop_back(); + if (done.count(node) || cached_unconnected_set->count(node)) { + continue; + } + + done.insert(node); + if (fused_op_set.count(node)) { + has_circle = true; + break; + } + + if (node->isa()) { + auto cnode_ptr = node->cast(); + for (auto it : cnode_ptr->inputs()) { + if (it->isa()) { + todos.push_back(it); + } + } + } + } + + if (has_circle) { + return true; + } + cached_unconnected_set->insert(done.begin(), done.end()); + } + } + + return false; +} + +bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { + if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { + auto &inputs = out->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + real_outs->push_back(inputs[i]); + } + return true; + } + + if (AnfAlgo::GetCNodeFuncGraphPtr(out) != nullptr) { + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); + auto fg_out = fg->output(); + if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) { + auto inputs = fg_out->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + real_outs->push_back(inputs[i]); + } + return true; + } + } + return false; +} + +std::vector RemoveCircle(const std::vector &fused_op, bool is_backward) { + std::set cached_unconnected_set; + std::set fused_op_set(fused_op.begin(), fused_op.end()); + auto include = [&fused_op_set](const AnfNodePtr &node) { + if (fused_op_set.count(node)) { + return FOLLOW; + } + return EXCLUDE; + }; + for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { + bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set); + // delete the circle node and the node which depend on the circle node in fused op + if (has_circle) { + auto mng = (*iter)->func_graph()->manager(); + std::vector erase_nodes; + if (is_backward) { + erase_nodes = DeepUsersSearch(*iter, include, mng); + } else { + erase_nodes = DeepLinkedGraphSearch(*iter, include); + } + for (auto erase_node : erase_nodes) { + fused_op_set.erase(erase_node); + } + } + } + + std::vector res; + for (auto node : fused_op) { + if (fused_op_set.count(node)) { + res.push_back(node); + } + } + return res; +} + +void TopoSortForNodeList(std::vector *lst) { + if (lst->size() < 2) { + return; + } + + std::vector res; + std::set node_sets(lst->begin(), lst->end()); + std::map> ins; + std::map> outs; + std::queue q; + for (auto node : *lst) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (auto input : cnode->inputs()) { + if (!node_sets.count(input)) { + continue; + } + // out_degree + outs[input].insert(node); + // in_degree + ins[node].insert(input); + } + if (!ins.count(node)) { + ins[node] = {}; + } + } + + for (auto p : ins) { + if (p.second.size() == 0) { + q.push(p.first); + } + } + + while (!q.empty()) { + auto node = q.front(); + q.pop(); + res.push_back(node); + if (!outs.count(node)) { + continue; + } + for (auto out : outs[node]) { + if (!ins.count(out)) { + continue; + } + ins[out].erase(node); + if (ins[out].size() == 0) { + q.push(out); + } + } + } + + lst->assign(res.begin(), res.end()); +} + +std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { + auto func_graph = cnode->func_graph(); + auto graph_kernel_g = GetValueNode(cnode->input(0)); + GraphKernelInfo info; + info.is_before_kernel_select = is_before_kernel_select; + GetGraphKernelInfo(graph_kernel_g, &info); + auto mng = func_graph->manager(); + // Search fusable nodes according input direction. + auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); + auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); + std::reverse(used_nodes.begin(), used_nodes.end()); + // Search fusable nodes according output direction. + auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, info, std::placeholders::_1); + auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng); + + used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end()); + if (used_nodes.size() > 1) { + used_nodes = RemoveCircle(used_nodes); + } + TopoSortForNodeList(&used_nodes); + return used_nodes; +} + +AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) { + auto out_spec = node->abstract(); + if (out_spec->isa()) { + return out_spec->cast()->elements()[output_idx]; + } + return out_spec; +} + +AnfNodePtr CreateNewFuseCNode(const std::shared_ptr &kernel_graph, const FuncGraphPtr &fg, + const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, + bool is_before_kernel_select) { + auto func_node = NewValueNode(fg); + std::vector fn_inputs; + fn_inputs.push_back(func_node); + fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); + auto fuse_cnode = kernel_graph->NewCNode(fn_inputs); + // Set output abstract + if (outputs.size() > 1) { + std::vector out_specs; + for (size_t i = 0; i < outputs.size(); ++i) { + out_specs.push_back(outputs[i]->abstract()); + } + auto out_spec = std::make_shared(out_specs); + fuse_cnode->set_abstract(out_spec); + } else { + fuse_cnode->set_abstract(outputs[0]->abstract()); + } + // Set parameter abstract. + for (size_t i = 0; i < inputs.size(); ++i) { + auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); + auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); + fg->parameters()[i]->set_abstract(input_abs); + if (is_before_kernel_select) { + fg->parameters()[i]->set_kernel_info(std::make_shared()); + } + } + // Set kernel info. + if (!is_before_kernel_select) { + std::vector graph_input_format; + std::vector graph_input_type; + std::vector graph_output_format; + std::vector graph_output_type; + for (size_t i = 0; i < inputs.size(); ++i) { + auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); + auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); + graph_input_format.push_back(input_format); + auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); + graph_input_type.push_back(input_type); + auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); + fg->parameters()[i]->set_abstract(input_abs); + } + auto new_outputs = outputs; + if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) { + std::vector real_outs; + if (IsMakeTupleOut(outputs[0], &real_outs)) { + new_outputs = real_outs; + } + } + for (size_t i = 0; i < new_outputs.size(); ++i) { + auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0); + auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); + graph_output_format.push_back(output_format); + graph_output_type.push_back(output_type); + } + kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; + graph_info_builder.SetInputsFormat(graph_input_format); + graph_info_builder.SetInputsDeviceType(graph_input_type); + graph_info_builder.SetOutputsFormat(graph_output_format); + graph_info_builder.SetOutputsDeviceType(graph_output_type); + graph_info_builder.SetProcessor(kernel::Processor::AICORE); + graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); + graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); + auto graph_selected_info = graph_info_builder.Build(); + AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, fuse_cnode.get()); + } + return fuse_cnode; +} + +void ReplaceNewFuseCNode(const std::shared_ptr &kernel_graph, const AnfNodePtr &new_fuse_cnode, + const AnfNodePtrList &outputs) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + // single out + if (outputs.size() == 1) { + mng->Replace(outputs[0], new_fuse_cnode); + return; + } + + std::vector fn_inputs; + for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) { + AnfNodePtrList real_outs; + // not make tuple out, replace + if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) { + fn_inputs.clear(); + fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); + fn_inputs.push_back(new_fuse_cnode); + fn_inputs.push_back(NewValueNode(MakeValue(SizeToInt(out_idx)))); + auto new_out = kernel_graph->NewCNode(fn_inputs); + new_out->set_abstract(outputs[out_idx]->abstract()); + mng->Replace(outputs[out_idx], new_out); + continue; + } + + // the out is make tuple , modify the get_item node's value + auto users = mng->node_users()[outputs[out_idx]]; + for (auto &user : users) { + auto use_node = user.first; + if (use_node->isa() && (IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem))) { + auto get_item_cnode = use_node->cast(); + auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(value_input); + auto value_node = value_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + int new_item_idx = SizeToInt(out_idx) + item_idx; + fn_inputs.clear(); + fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); + fn_inputs.push_back(new_fuse_cnode); + fn_inputs.push_back(NewValueNode(new_item_idx)); + auto new_out = kernel_graph->NewCNode(fn_inputs); + new_out->set_abstract(get_item_cnode->abstract()); + mng->Replace(get_item_cnode, new_out); + } + } + } +} + +AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) { + AnfNodePtrList outs; + auto out_node = (*fg)->output(); + if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { + std::vector output_args; + auto out_cnode = out_node->cast(); + for (auto out : out_cnode->inputs()) { + if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { + auto inputs = out->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + output_args.push_back(inputs[i]); + } + } else { + output_args.push_back(out); + } + } + if (output_args.size() != out_cnode->inputs().size()) { + auto new_out = (*fg)->NewCNode(output_args); + (*mng)->Replace(out_node, new_out); + } + + for (size_t i = 1; i < output_args.size(); ++i) { + outs.push_back(output_args[i]); + } + return outs; + } + + outs.push_back(out_node); + return outs; +} + +AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) { + AnfNodePtrList res; + if (outs.size() <= 1) { + return outs; + } + + for (auto out : outs) { + AnfNodePtrList real_outs; + if (IsMakeTupleOut(out, &real_outs)) { + res.insert(res.end(), real_outs.begin(), real_outs.end()); + continue; + } + res.push_back(out); + } + return res; +} + +void FuseGraphKernel(const std::shared_ptr &kernel_graph, bool is_before_kernel_select) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + if (mng == nullptr) { + mng = Manage(kernel_graph, true); + kernel_graph->set_manager(mng); + } + auto &todos = kernel_graph->execution_order(); + for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { + auto node = *iter; + if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) { + continue; + } + + auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (fg_attr != nullptr) { + auto fg_name = GetValue(fg_attr); + if (graph_kernel_black_list.count(fg_name) != 0) { + continue; + } + } + + auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select); + if (fuse_nodes.size() <= 1) { + continue; + } + + FuncGraphPtr fg; + AnfNodePtrList inputs; + AnfNodePtrList outputs; + std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); + + // Remove nest make tuple in outs + auto expand_out = GetExpandOuts(outputs); + auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select); + + ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); + + // Inline origin graphkernel + auto cnodes = fg->GetOrderedCnodes(); + for (const auto &n : cnodes) { + if (!AnfAlgo::IsGraphKernel(n)) { + continue; + } + auto graph_kernel_g = GetValueNode(n->input(0)); + AnfNodePtrList ins; + ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); + auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); + mng->Replace(n, out); + } + + EliminateMakeTuple(&fg, &mng); + // Set graphkernel flag + auto ori_fg = GetValueNode(node->input(kAnfPrimitiveIndex)); + fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, ori_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h new file mode 100644 index 0000000000..e14661dfdf --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h @@ -0,0 +1,63 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +enum GraphKernelType { + ELEWISE = 0, // only contain elewise basic ops + REDUCE, // contain reduce ops + CUBE, // contain cube ops +}; +struct GraphKernelInfo { + GraphKernelType op_type = ELEWISE; + bool is_before_kernel_select = false; + int reduce_op_num = 0; + int cal_step = 0; +}; + +// when reduce graph kernel's cal step is greater than this number, not fuse +const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5; +// when reduce graph kernel contain reduce op num is greater than this number, not fuse +const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2; + +const std::set graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", + "LambNextMV", "LambUpdateWithLR"}; + +std::vector RemoveCircle(const std::vector &fused_op, bool is_backward = true); + +void TopoSortForNodeList(std::vector *lst); + +AnfNodePtr CreateNewFuseCNode(const std::shared_ptr &kernel_graph, const FuncGraphPtr &fg, + const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, + bool is_before_kernel_select); + +void ReplaceNewFuseCNode(const std::shared_ptr &kernel_graph, const AnfNodePtr &new_fuse_cnode, + const AnfNodePtrList &outputs); + +void FuseGraphKernel(const std::shared_ptr &kernel_graph, bool is_before_kernel_select = false); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc new file mode 100644 index 0000000000..a51a6bab42 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/getitem_tuple.h" + +#include +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +bool IsC(const BaseRef &n) { + MS_EXCEPTION_IF_NULL(n); + if (utils::isa(n)) { + AnfNodePtr in = utils::cast(n); + MS_EXCEPTION_IF_NULL(in); + return in->isa(); + } else { + return false; + } +} +} // namespace + +const BaseRef GetitemTuple::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VarPtr C = std::make_shared(IsC); + return VectorRef({prim::kPrimTupleGetItem, VectorRef({prim::kPrimMakeTuple, Xs}), C}); +} + +const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + CNodePtr tuple_getitem = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) { + MS_LOG(EXCEPTION) << "tuple getitem's input num is wrong"; + } + AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(make_tuple_anf); + AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + if (IsValueNode(index_node)) { + ValueNodePtr value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + CNodePtr make_tuple = make_tuple_anf->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + if (make_tuple->inputs().size() > IntToSize(index + 1)) { + auto ret = make_tuple->input(IntToSize(index + 1)); + MS_EXCEPTION_IF_NULL(ret); + return ret; + } + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h new file mode 100644 index 0000000000..9a25b924bd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class GetitemTuple : public PatternProcessPass { + public: + explicit GetitemTuple(bool multigraph = true) : PatternProcessPass("getitem_tuple", multigraph) {} + ~GetitemTuple() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc new file mode 100644 index 0000000000..710e130a85 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/pass/optimize_dependence.h" +#include +#include +#include +#include "backend/optimizer/common/helper.h" +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +constexpr auto kSingleInputIndex = 1; +namespace { +AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + string op_name = AnfAlgo::GetCNodeName(cnode); + // Currently we only eliminate transdata or cast nodes. + if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { + return nullptr; + } + CheckCNodeInputSize(cnode, kSingleInputIndex + 1); + return cnode->input(kSingleInputIndex); +} + +AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { + return nullptr; + } + std::vector new_make_tuple_inputs; + bool need_update = false; + for (const auto &input : cnode->inputs()) { + AnfNodePtr replace_input = GetReplaceNode(input); + // If replace input is not null, it will be the input of the TransData or Cast. + if (replace_input == nullptr) { + new_make_tuple_inputs.push_back(input); + continue; + } + new_make_tuple_inputs.push_back(replace_input); + need_update = true; + } + if (need_update) { + auto kernel_graph = func_graph->cast>(); + CNodePtr new_make_tuple = nullptr; + if (kernel_graph == nullptr) { + new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs); + } else { + new_make_tuple = kernel_graph->NewCNode(cnode); + } + MS_EXCEPTION_IF_NULL(new_make_tuple); + new_make_tuple->set_inputs(new_make_tuple_inputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(cnode, new_make_tuple); + return new_make_tuple; + } + return nullptr; +} +} // namespace + +const BaseRef OptimizeDependence::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); +} + +const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return nullptr; + } + auto node_name = AnfAlgo::GetCNodeName(node); + if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) { + return nullptr; + } + size_t index = 0; + auto depend_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(depend_cnode); + std::vector new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)}; + if (node_name == prim::kPrimDepend->name()) { + index = 1; + new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend)); + } + if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) { + MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got " + << AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString(); + } + auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode); + while (index < input_num) { + auto replace_node = GetConvertNode(func_graph, node, index); + MS_EXCEPTION_IF_NULL(replace_node); + new_depend_inputs.push_back(replace_node); + ++index; + } + auto kernel_graph = func_graph->cast>(); + CNodePtr new_depend = nullptr; + if (kernel_graph == nullptr) { + new_depend = func_graph->NewCNode(new_depend_inputs); + MS_EXCEPTION_IF_NULL(new_depend); + new_depend->set_abstract(node->abstract()); + new_depend->set_scope(node->scope()); + } else { + new_depend = kernel_graph->NewCNode(depend_cnode); + MS_EXCEPTION_IF_NULL(new_depend); + new_depend->set_inputs(new_depend_inputs); + } + return new_depend; +} + +const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, + const size_t index) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto depend_cnode = node->cast(); + auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); + MS_EXCEPTION_IF_NULL(replacing_node); + if (!replacing_node->isa()) { + return replacing_node; + } + auto replacing_cnode = replacing_node->cast(); + MS_EXCEPTION_IF_NULL(replacing_cnode); + // Deal with the make_tuple with TransData or Cast inputs. + auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode); + if (make_tuple_replace_node != nullptr) { + return make_tuple_replace_node; + } + AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); + if (replace_node == nullptr) { + MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); + return replacing_node; + } + return replace_node; +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h new file mode 100644 index 0000000000..8ddd4d662e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class OptimizeDependence : public PatternProcessPass { + public: + explicit OptimizeDependence(bool multigraph = true) : PatternProcessPass("optimize_dependence", multigraph) {} + ~OptimizeDependence() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc new file mode 100644 index 0000000000..cd34464cda --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/replace_node_by_proxy.h" +#include +#include +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace opt { +kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector inputs_device_format; + std::vector outputs_device_format; + std::vector inputs_device_type; + std::vector outputs_device_type; + std::vector> outputs_shape; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); + inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); + outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); + } + builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); + builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); + builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); + + builder.SetInputsFormat(inputs_device_format); + builder.SetOutputsFormat(outputs_device_format); + builder.SetInputsDeviceType(inputs_device_type); + builder.SetOutputsDeviceType(outputs_device_type); + return builder.Build(); +} + +bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector node_list = TopoSort(func_graph->get_return()); + for (auto node : node_list) { + if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { + CNodePtr cnode = node->cast(); + auto prim = std::make_shared(kEmbeddingLookupProxyOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector proxy_inputs = {NewValueNode(prim)}; + proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs); + MS_EXCEPTION_IF_NULL(proxy_node); + + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + proxy_node->set_kernel_info(kernel_info); + + AbstractBasePtrList abstract_list; + AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node); + AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node); + AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node); + abstract_list.push_back(cnode->abstract()); + auto abstract_tuple = std::make_shared(abstract_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + proxy_node->set_abstract(abstract_tuple); + + auto kernel_build_info = GenerateKernelBuildInfo(cnode); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get()); + + if (!manager->Replace(cnode, proxy_node)) { + MS_LOG(EXCEPTION) << "Replace node by proxy node failed."; + } + } + } + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h new file mode 100644 index 0000000000..382b08304f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ +#include +#include +#include + +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace opt { +class ReplaceNodeByProxy : public Pass { + public: + explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {} + ~ReplaceNodeByProxy() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ diff --git a/mindspore/ccsrc/backend/session/CMakeLists.txt b/mindspore/ccsrc/backend/session/CMakeLists.txt new file mode 100644 index 0000000000..b7b791ada9 --- /dev/null +++ b/mindspore/ccsrc/backend/session/CMakeLists.txt @@ -0,0 +1,32 @@ +file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "kernel_graph.cc" + "session_basic.cc" + "session_factory.cc" + "anf_runtime_algorithm.cc" +) + +if (ENABLE_GPU) + file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "gpu_session.cc" + ) + list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) +endif () + +if (ENABLE_CPU) + file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "cpu_session.cc" + ) + list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) +endif () + +if (ENABLE_D) + file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "ascend_session.cc" + "ascend_control_parser.cc" + "ascend_inference_session.cc" + ) + list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) +endif () + +set_property(SOURCE ${_SESSION_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_SESSION) +add_library(_mindspore_backend_session_obj OBJECT ${_SESSION_SRC_LIST}) diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc new file mode 100644 index 0000000000..38c040e6b1 --- /dev/null +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -0,0 +1,1151 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/anf_runtime_algorithm.h" +#include +#include +#include +#include +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "runtime/device/kernel_info.h" +#include "runtime/device/device_address.h" +#include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "common/utils.h" +#include "common/trans.h" + +namespace mindspore { +namespace session { +using abstract::AbstractTensor; +using abstract::AbstractTuple; +using device::KernelInfo; +using device::ascend::AscendDeviceAddress; +using kernel::KernelBuildInfoPtr; +using kernel::KernelMod; +using kernel::KernelModPtr; +namespace { +constexpr size_t kNopNodeInputSize = 2; +constexpr size_t kNopNodeRealInputIndex = 1; + +std::vector TransShapeToSizet(const abstract::ShapePtr &shape) { + MS_EXCEPTION_IF_NULL(shape); + std::vector shape_size_t; + std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize); + return shape_size_t; +} +} // namespace + +AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) { + MS_EXCEPTION_IF_NULL(tuple_get_item); + if (tuple_get_item->size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem); +} + +size_t AnfRuntimeAlgorithm::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) { + MS_EXCEPTION_IF_NULL(tuple_get_item); + if (tuple_get_item->size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(output_index_value_node); + auto value_node = output_index_value_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + return IntToSize(GetValue(value_node->value())); +} + +KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + if (anf_node->isa()) { + return std::make_pair(anf_node, 0); + } else if (anf_node->isa()) { + return std::make_pair(anf_node, 0); + } else if (anf_node->isa()) { + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input0 = cnode->input(0); + MS_EXCEPTION_IF_NULL(input0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + auto node = cnode->input(index + IntToSize(1)); + MS_EXCEPTION_IF_NULL(node); + return VisitKernel(node, 0); + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + if (cnode->inputs().size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(input2); + auto value_node = input2->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx)); + } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { + return VisitKernel(cnode->input(kRealInputIndexInDepend), 0); + } else { + return std::make_pair(anf_node, index); + } + } else { + MS_LOG(EXCEPTION) << "The input is invalid"; + } +} + +KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, int index, + bool visit_nop_node, + const std::vector &return_types) { + MS_EXCEPTION_IF_NULL(anf_node); + if (std::any_of(return_types.begin(), return_types.end(), [&anf_node](const PrimitivePtr &prim_type) -> bool { + return CheckPrimitiveType(anf_node, prim_type); + })) { + return KernelWithIndex(anf_node, index); + } + if (!anf_node->isa()) { + return KernelWithIndex(anf_node, 0); + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) { + auto item_with_index_tmp = VisitKernelWithReturnType(GetTupleGetItemRealInput(cnode), + GetTupleGetItemOutIndex(cnode), visit_nop_node, return_types); + if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) { + MS_EXCEPTION_IF_NULL(item_with_index_tmp.first); + auto make_tuple = item_with_index_tmp.first->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + const std::vector &make_tuple_inputs = make_tuple->inputs(); + size_t make_tuple_input_index = item_with_index_tmp.second + 1; + if (make_tuple_input_index >= make_tuple_inputs.size()) { + MS_LOG(EXCEPTION) << "Index[" << make_tuple_input_index << "] out of range[" << make_tuple_inputs.size() + << "]."; + } + return VisitKernelWithReturnType(make_tuple_inputs[make_tuple_input_index], 0, visit_nop_node, return_types); + } + return item_with_index_tmp; + } + if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) { + return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types); + } + if (opt::IsNopNode(cnode) && visit_nop_node) { + if (cnode->size() != kNopNodeInputSize) { + MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString(); + } + return VisitKernelWithReturnType(cnode->input(kNopNodeRealInputIndex), 0, visit_nop_node, return_types); + } + return KernelWithIndex(anf_node, index); +} + +std::vector AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node, + const std::vector &return_types) { + std::vector ret; + auto return_prim_type = return_types; + // if visited make_tuple should return back + return_prim_type.push_back(prim::kPrimMakeTuple); + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type); + if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { + MS_EXCEPTION_IF_NULL(item_with_index.first); + auto make_tuple = item_with_index.first->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + for (size_t i = 1; i < make_tuple->inputs().size(); i++) { + auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types); + (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret)); + } + return ret; + } + ret.push_back(item_with_index.first); + return ret; +} + +AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + return node->input(kAnfPrimitiveIndex); +} + +PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto attr_input = GetCNodePrimitiveNode(cnode); + MS_EXCEPTION_IF_NULL(attr_input); + auto value_node = attr_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + auto primitive = value->cast(); + return primitive; +} + +bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); +} + +FuncGraphPtr AnfRuntimeAlgorithm::GetCNodeFuncGraphPtr(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + auto value_node = attr_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + return value->cast(); +} + +std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto primitive = AnfAlgo::GetCNodePrimitive(node); + if (primitive != nullptr) { + return primitive->name(); + } + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(func_graph); + return func_graph->ToString(); + } + MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString(); +} + +std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + return node->DebugString(); +} + +void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString(); + } + // single op cnode. + auto primitive = AnfAlgo::GetCNodePrimitive(node); + if (primitive != nullptr) { + primitive->set_attr(key, value); + return; + } + // graph kernel cnode. + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + fg->set_attr(key, value); +} + +void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) { + CopyNodeAttr(key, key, from, to); +} + +void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, + const AnfNodePtr &to) { + MS_EXCEPTION_IF_NULL(from); + MS_EXCEPTION_IF_NULL(to); + if (!from->isa() || !to->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is " + << to->DebugString(); + } + auto from_primitive = AnfAlgo::GetCNodePrimitive(from); + MS_EXCEPTION_IF_NULL(from_primitive); + auto to_primitive = AnfAlgo::GetCNodePrimitive(to); + MS_EXCEPTION_IF_NULL(to_primitive); + to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key)); +} + +void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) { + MS_EXCEPTION_IF_NULL(from); + MS_EXCEPTION_IF_NULL(to); + if (!from->isa() || !to->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is " + << from->DebugString(); + } + auto from_primitive = AnfAlgo::GetCNodePrimitive(from); + MS_EXCEPTION_IF_NULL(from_primitive); + auto to_primitive = AnfAlgo::GetCNodePrimitive(to); + MS_EXCEPTION_IF_NULL(to_primitive); + (void)to_primitive->SetAttrs(from_primitive->attrs()); +} + +void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString(); + } + // single op cnode. + auto primitive = AnfAlgo::GetCNodePrimitive(node); + if (primitive != nullptr) { + primitive->EraseAttr(key); + return; + } + // graph kernel cnode. + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + fg->erase_flag(key); +} + +bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString(); + return false; + } + // single op cnode. + auto primitive = AnfAlgo::GetCNodePrimitive(node); + if (primitive != nullptr) { + return primitive->HasAttr(key); + } + // graph kernel cnode. + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + return fg->has_attr(key); +} + +size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString(); + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + size_t input_num = cnode->inputs().size(); + if (input_num == 0) { + MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"; + } + // exclude intputs[0],which is value_node storing attr,inputs left are real input + return input_num - 1; +} + +size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + TypePtr type = node->Type(); + if (type == nullptr) { + return 0; + } + if (type->isa()) { + auto tuple_type = type->cast(); + MS_EXCEPTION_IF_NULL(tuple_type); + return tuple_type->size(); + } else if (type->isa() || type->isa()) { + return 1; + } else if (type->isa()) { + return 0; + } else { + return 1; + } +} + +std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "Output index:" << output_idx + << " is out of the node output range :" << GetOutputTensorNum(node) << " #node [" + << node->DebugString() << "]"; + } + if (!AnfAlgo::IsRealKernel(node)) { + return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + auto format = build_info->GetOutputFormat(output_idx); + if (format == kernel::KernelBuildInfo::kInvalidFormat) { + MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" + << " has a invalid output format"; + } + return format; +} + +std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "Input index :" << input_idx + << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node [" + << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + GetPrevNodeOutputFormat(node, input_idx); + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + auto format = build_info->GetInputFormat(input_idx); + if (format == kernel::KernelBuildInfo::kInvalidFormat) { + MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" + << " has a invalid input format"; + } + return format; +} + +KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(anf_node); + if (!anf_node->isa()) { + MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (input_idx + 1 >= cnode->inputs().size()) { + MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); + } + auto node = cnode->input(input_idx + 1); + MS_EXCEPTION_IF_NULL(node); + return VisitKernel(node, 0); +} + +std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); + return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); +} + +std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); + return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second); +} + +std::vector AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + abstract::BaseShapePtr base_shape = node->Shape(); + MS_EXCEPTION_IF_NULL(base_shape); + if (base_shape->isa() && output_idx == 0) { + return TransShapeToSizet(base_shape->cast()); + } else if (base_shape->isa()) { + auto tuple_shape = base_shape->cast(); + MS_EXCEPTION_IF_NULL(tuple_shape); + if (output_idx >= tuple_shape->size()) { + MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size() + << "."; + } + auto b_shp = (*tuple_shape)[output_idx]; + if (b_shp->isa()) { + return TransShapeToSizet(b_shp->cast()); + } else if (b_shp->isa()) { + return std::vector(); + } else { + MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx + << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString(); + } + } else if (base_shape->isa()) { + return std::vector(); + } + MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is " + << base_shape->ToString(); +} + +std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); + return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second); +} + +std::vector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { + auto format = GetOutputFormat(node, output_idx); + auto infer_shape = GetOutputInferShape(node, output_idx); + if (infer_shape.empty()) { + return infer_shape; + } + // if format is default_format or NC1KHKWHWC0,device shape = original shape + if (trans::IsNeedPadding(format, infer_shape.size())) { + infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); + } + return trans::TransShapeToDevice(infer_shape, format); +} + +std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { + auto format = GetInputFormat(node, input_idx); + auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx); + if (infer_shape.empty()) { + return infer_shape; + } + // if format is default_format or NC1KHKWHWC0,device shape = original shape + if (trans::IsNeedPadding(format, infer_shape.size())) { + infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); + } + return trans::TransShapeToDevice(infer_shape, format); +} + +std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index:" << input_idx + << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node[" + << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputReshapeType(node, input_idx); + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + if (build_info->IsInputDefaultPadding()) { + return {}; + } + return build_info->GetInputReshapeType(input_idx); +} + +std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputReshapeType(node, output_idx); + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + if (build_info->IsOutputDefaultPadding()) { + return {}; + } + return build_info->GetOutputReshapeType(output_idx); +} + +TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + TypePtr type_ptr = node->Type(); + MS_EXCEPTION_IF_NULL(type_ptr); + if (type_ptr->isa() && output_idx == 0) { + auto tensor_ptr = type_ptr->cast(); + MS_EXCEPTION_IF_NULL(tensor_ptr); + TypePtr elem = tensor_ptr->element(); + MS_EXCEPTION_IF_NULL(elem); + return elem->type_id(); + } else if (type_ptr->isa()) { + auto tuple_ptr = type_ptr->cast(); + MS_EXCEPTION_IF_NULL(tuple_ptr); + if (output_idx >= tuple_ptr->size()) { + MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size(); + } + auto tuple_i = (*tuple_ptr)[output_idx]; + MS_EXCEPTION_IF_NULL(tuple_i); + if (tuple_i->isa()) { + auto tensor_ptr = tuple_i->cast(); + MS_EXCEPTION_IF_NULL(tensor_ptr); + TypePtr elem = tensor_ptr->element(); + MS_EXCEPTION_IF_NULL(elem); + return elem->type_id(); + } else if (tuple_i->isa()) { + return tuple_i->type_id(); + } else { + MS_LOG(WARNING) << "Not support type " << tuple_i->ToString(); + return tuple_i->type_id(); + } + } else if (type_ptr->isa()) { + return type_ptr->type_id(); + } + return type_ptr->type_id(); +} + +TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); + return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second); +} + +TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputDeviceDataType(node, output_idx); + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + auto dtype = build_info->GetOutputDeviceType(output_idx); + if (dtype == TypeId::kNumberTypeEnd) { + MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" + << " has a invalid dtype"; + } + return dtype; +} + +TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ " + << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputDeviceDataType(node, 0); + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + auto dtype = build_info->GetInputDeviceType(input_idx); + if (dtype == TypeId::kNumberTypeEnd) { + MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" + << " has a invalid dtype"; + } + return dtype; +} + +TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); + return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); +} + +// get output device addr of anf_node +const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, + bool visit_nop_node) { + MS_EXCEPTION_IF_NULL(node); + if (opt::IsNopNode(node) && visit_nop_node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() == kNopNodeInputSize) { + return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0); + } else { + MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; + } + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto addr = kernel_info->GetOutputAddr(output_idx); + if (addr == nullptr) { + MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() + << " output addr is not exist"; + } + return addr; +} + +DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, + bool visit_nop_node) { + MS_EXCEPTION_IF_NULL(node); + if (opt::IsNopNode(node) && visit_nop_node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() == kNopNodeInputSize) { + return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0); + } else { + MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; + } + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto addr = kernel_info->GetMutableOutputAddr(output_idx); + if (addr == nullptr) { + MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString() + << " output addr is not exist"; + } + return addr; +} + +// get output device addr of anf_node +bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->OutputAddrExist(output_idx); +} + +const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, + bool visit_nop_node) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); + return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); +} + +DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, + bool visit_nop_node) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); + return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); +} + +// set output device addr of anf_node +void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + if (!kernel_info->SetOutputAddr(addr, output_idx)) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; + } +} + +// set workspace device addr of anf_node +void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; + } +} + +// get workspace device addr of anf_node +DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto addr = kernel_info->GetWorkspaceAddr(output_idx); + if (addr == nullptr) { + MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() + << "] workspace addr is not exist"; + } + return addr; +} + +// set infer shapes and types of anf node +void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector &types, + const std::vector> &shapes, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + if (types.size() != shapes.size()) { + MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size(); + } + if (shapes.empty()) { + node->set_abstract(std::make_shared()); + } else if (shapes.size() == 1) { + // single output handle + std::vector shape_int; + std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt); + auto abstract = std::make_shared(TypeIdToType(types[0]), shape_int); + node->set_abstract(abstract); + } else { + // multiple output handle + std::vector abstract_list; + for (size_t i = 0; i < types.size(); ++i) { + std::vector shape_int; + std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt); + abstract_list.push_back(std::make_shared(TypeIdToType(types[i]), shape_int)); + } + auto abstract_tuple = std::make_shared(abstract_list); + node->set_abstract(abstract_tuple); + } +} +// copy an abstract of a node to another node +void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) { + to_node->set_abstract(from_node->abstract()); +} + +kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + // select_kernel_build_info() has checked whether return pointer is null + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + return build_info->op_pattern(); +} + +// get KernelBuildType of node, such as ATT,RT,FWK and so on +KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + // select_kernel_build_info() has checked whether return pointer is null + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + return build_info->kernel_type(); +} + +kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + return build_info->processor(); +} + +kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + return build_info->fusion_type(); +} + +// set select kernel_build_info +void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->set_select_kernel_build_info(select_kernel_build_info); +} + +// get select kernel_build_info +KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->GetMutableSelectKernelBuildInfo(); +} + +// get kernelMode +KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->MutableKernelMod(); +} + +// set kernel mod +void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + kernel_info->set_kernel_mod(kernel_mod); +} + +bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // parameter and value node is not a real kernel too + if (!node->isa()) { + return true; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString(); + } + auto input = cnode->inputs()[0]; + bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) || + IsPrimitive(input, prim::kPrimTensorSummary) || + IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || + IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || + IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || + IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); + return !is_virtual_node; +} + +bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // parameter and value node is not a real cnode kernel + if (!node->isa()) { + return false; + } + // return considered as a real node + if (CheckPrimitiveType(node, prim::kPrimReturn)) { + return true; + } + return IsRealKernel(node); +} + +bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // graph kernel should be a real cnode kernel. + if (!IsRealCNodeKernel(node)) { + return false; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input = cnode->input(kAnfPrimitiveIndex); + // graph kernel should has func_graph as first input. + if (!IsValueNode(input)) { + return false; + } + + auto func_graph = GetValueNode(input); + MS_EXCEPTION_IF_NULL(func_graph); + return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); +} + +bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) { + MS_EXCEPTION_IF_NULL(node); + return node->has_default(); +} + +void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + kernel_info->set_stream_id(stream_id); +} + +uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->stream_id(); +} + +void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + kernel_info->set_stream_distinction_label(stream_label); +} + +uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->stream_distinction_label(); +} + +void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + kernel_info->set_graph_id(graph_id); +} + +uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->graph_id(); +} + +bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) { + MS_EXCEPTION_IF_NULL(anf); + TypePtr type = anf->Type(); + MS_EXCEPTION_IF_NULL(type); + return type->isa(); +} + +AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) { + MS_EXCEPTION_IF_NULL(node); + auto get_input_index = index + 1; + if (index + 1 > node->inputs().size()) { + MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just" + << node->inputs().size(); + } + // input 0 is primitive node + return node->input(get_input_index); +} + +bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + return false; + } + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->is_feature_map(); +} + +bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) { + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map"; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input_node = cnode->input(input_index + 1); + return IsFeatureMapOutput(input_node); +} + +size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) { + MS_EXCEPTION_IF_NULL(anf_node); + static std::map> spec_node_list = { + {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}}, + {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}}, + {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}, + {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}}, + {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}}, + {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}, + {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}, + {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}, + {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}, + {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}, + {prim::kPrimApplyCenteredRMSProp->name(), + {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}}}}; + size_t ret = cur_index; + auto node_name = AnfAlgo::GetCNodeName(anf_node); + if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) { + auto find = spec_node_list.find(node_name); + if (find != spec_node_list.end()) { + ret = find->second[cur_index]; + MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name; + } + } + return ret; +} + +void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(input_node); + node->set_input(index + 1, input_node); +} + +bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto kernel_name = AnfAlgo::GetCNodeName(node); + if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName || + kernel_name == kReduceScatterOpName) { + return true; + } + return false; +} + +bool AnfRuntimeAlgorithm::IsGetNext(const NotNull &node) { + auto kernel_name = AnfAlgo::GetCNodeName(node); + return kernel_name == kGetNextOpName; +} + +FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto value_node = node->cast(); + if (value_node == nullptr) { + return nullptr; + } + auto value = value_node->value(); + if (value == nullptr) { + return nullptr; + } + auto func_graph = value->cast(); + return func_graph; +} + +std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) { + MS_EXCEPTION_IF_NULL(call_node); + if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared("call"))) { + MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node."; + } + auto input1 = call_node->input(1); + MS_EXCEPTION_IF_NULL(input1); + if (input1->isa()) { + auto value_node = input1->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto kernel_graph = value_node->value(); + MS_EXCEPTION_IF_NULL(kernel_graph); + return {kernel_graph->cast()}; + } else if (input1->isa() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { + auto switch_node = input1->cast(); + MS_EXCEPTION_IF_NULL(switch_node); + auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { + auto partial = switch_node->input(input_index); + MS_EXCEPTION_IF_NULL(partial); + if (IsValueNode(partial)) { + return GetValueNode(partial); + } + auto partial_cnode = partial->cast(); + MS_EXCEPTION_IF_NULL(partial_cnode); + auto graph_node = partial_cnode->input(1); + MS_EXCEPTION_IF_NULL(graph_node); + auto graph_value_node = graph_node->cast(); + MS_EXCEPTION_IF_NULL(graph_value_node); + auto graph_value = graph_value_node->value(); + MS_EXCEPTION_IF_NULL(graph_value); + auto child_graph = graph_value->cast(); + return child_graph; + }; + return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)}; + } + return {}; +} + +bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { + MS_EXCEPTION_IF_NULL(call_node); + if (!CheckPrimitiveType(call_node, prim::kPrimCall)) { + MS_LOG(EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString(); + } + auto input1 = call_node->input(1); + if (input1->isa()) { + return false; + } else if (input1->isa() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { + return true; + } + MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); +} + +bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (shape.empty()) { + return true; + } + return shape.size() == kShape1dDims && shape[0] == 1; +} + +bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (shape.empty()) { + return true; + } + return shape.size() == kShape1dDims && shape[0] == 1; +} + +void AnfRuntimeAlgorithm::ReorderExecList(NotNull *> node_list) { + std::vector all_opt_list; + std::vector non_opt_list; + + for (const auto &node : *node_list) { + MS_EXCEPTION_IF_NULL(node); + if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) { + all_opt_list.emplace_back(node); + } else { + non_opt_list.emplace_back(node); + } + } + node_list->clear(); + std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list)); + std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); +} + +TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto prim = AnfAlgo::GetCNodePrimitive(node); + if (prim == nullptr) { + return kTypeUnknown; + } + + TypeId except_type = kTypeUnknown; + if (prim->GetAttr(kAttrOutputPrecision) != nullptr) { + auto output_type_str = GetValue(prim->GetAttr(kAttrOutputPrecision)); + if (output_type_str == "float16") { + except_type = kNumberTypeFloat16; + } else if (output_type_str == "float32") { + except_type = kNumberTypeFloat32; + } else { + MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got " << output_type_str; + } + } + + return except_type; +} + +TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) { + if (!node->isa()) { + MS_LOG(EXCEPTION) << node->DebugString() << ", input node is not CNode."; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (input_idx + 1 >= cnode->inputs().size()) { + MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); + } + auto input_node = cnode->input(input_idx + 1); + MS_EXCEPTION_IF_NULL(input_node); + auto kernel_with_index = VisitKernel(input_node, 0); + if (!kernel_with_index.first->isa()) { + return kTypeUnknown; + } + return GetCNodeOutputPrecision(kernel_with_index.first); +} + +bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode."; + } + auto input = node->input(kAnfPrimitiveIndex); + return IsPrimitive(input, prim::kPrimLabelGoto) || IsPrimitive(input, prim::kPrimLabelSwitch); +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h new file mode 100644 index 0000000000..4fa3150e36 --- /dev/null +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -0,0 +1,216 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H +#define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H +#include +#include +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "ir/dtype.h" +#include "base/base.h" +#include "ir/primitive.h" +#include "runtime/device/device_address.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "frontend/operator/ops.h" +#include "utils/contract.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace session { +using AnfVisitFuncion = std::function; +using KernelWithIndex = std::pair; +using DeviceAddress = device::DeviceAddress; +using DeviceAddressPtr = device::DeviceAddressPtr; +class AnfRuntimeAlgorithm { + public: + // get real input node of tuple_get_item + static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item); + static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); + // get input_anf_node's real kernel by recurse + static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); + static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, int output_index, + bool visit_nop_node = false, + const std::vector &return_types = { + prim::kPrimMakeTuple}); + static std::vector GetAllOutput(const AnfNodePtr &node, + const std::vector &return_types = {}); + // get cnode primitive + static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node); + static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index); + static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); + // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple + static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type); + // get cnode primitive + static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node); + // get kernel_name of anf node + static std::string GetCNodeName(const AnfNodePtr &node); + // get detail info of anf node + static std::string GetNodeDebugString(const AnfNodePtr &node); + // get attr of anf node + template + static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + std::string node_debug_log = node->DebugString(); + MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str(); + } + // single op cnode. + if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) { + return GetValue(primitive->GetAttr(key)); + } + // graph kernel cnode. + auto fg = GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + return GetValue(fg->get_attr(key)); + } + static bool IsTupleOutput(const AnfNodePtr &anf); + // set attr of anf node + static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); + // set attr of key from 'from' node to 'to' node + static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to); + // set a new key for attr from 'from' node to 'to' node + static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, + const AnfNodePtr &to); + // set all attrs from 'from' node to 'to' node + static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to); + // check whether a cnode has the specified attr. + static bool HasNodeAttr(const std::string &key, const CNodePtr &node); + // delete attr of anf node + static void EraseNodeAttr(const std::string &key, AnfNodePtr node); + // get the num of input real_kernel(which can be build and run in device) + static size_t GetInputTensorNum(const AnfNodePtr &node); + // get the num of output real_kernel(which can be build and run in device) + static size_t GetOutputTensorNum(const AnfNodePtr &node); + // get output format select of anf node + static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); + // get input format select of anf node + static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); + // get prev node output width output index + static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx); + // get output format from prev node,input_index is the input index of current node related to prev node + static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); + // get reshape_type of from the output of input node. + static std::vector GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); + // get output shapes inferred by ME from input nodes. + static std::vector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx); + // get input shapes inferred by ME from input nodes. + static std::vector GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx); + // get output shapes which will built and run in device + static std::vector GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); + // get input shapes which will built and run in device + static std::vector GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); + // Get Input Padding Axis + static std::vector GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); + // Get Output Padding Axis + static std::vector GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); + // get output data type inferred by ME of anf node + static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); + // get output original data type from prev node,input_index is the input index of current node related to prev node + static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx); + // get output select data type of anf node + static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx); + // get input select data type of anf node + static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx); + // get output select data type from prev node,input_index is the input index of current node related to prev node + static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); + // get output device addr of anf_node + static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); + // get mutable output device addr of anf_node + static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); + // check whether output addr is exist or not + static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); + // get address from prev node,input_index is the input index of current node related to prev node + static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, + bool visit_nop_node = true); + static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, + bool visit_nop_node = true); + // set output device addr of anf_node + static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); + // set workspace device addr of anf_node + static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); + // get workspace device addr of anf_node + static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); + // set infer shapes and types of anf node + static void SetOutputInferTypeAndShape(const std::vector &types, + const std::vector> &shapes, AnfNode *node); + static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); + // get op pattern of the node + static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); + // get KernelBuildType of node ,such as ATT,RT,FWK and so on + static KernelType GetKernelType(const AnfNodePtr &node); + // get processor type:AICORE,AICPU... + static kernel::Processor GetProcessor(const AnfNodePtr &node); + // get fusion type:AICORE,AICPU... + static kernel::FusionType GetFusionType(const AnfNodePtr &node); + // set select kernel_build_info + static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node); + // get select kernel_build_info + static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node); + // get kernelMode + static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node); + // set kernel mod + static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node); + // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too + static bool IsRealKernel(const AnfNodePtr &node); + // checkout whether the anf node is a real kernel that is a cnode and can run on device + static bool IsRealCNodeKernel(const AnfNodePtr &node); + // checkout whether the anf node is a graph kernel. + static bool IsGraphKernel(const AnfNodePtr &node); + // check parameter is weight or data + static bool IsParameterWeight(const ParameterPtr &node); + // set stream id of kernel,which will be set in stream assign and be used in stream generate + static void SetStreamId(uint32_t stream_id, AnfNode *node); + // get stream id + static uint32_t GetStreamId(const AnfNodePtr &node); + // set stream distinction label to distinguish different ops in different streams + static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node); + // get stream distinction label + static uint32_t GetStreamDistinctionLabel(const AnfNode *node); + // set graph id + static void SetGraphId(uint32_t graph_id, AnfNode *node); + // get graph id + static uint32_t GetGraphId(const AnfNode *node); + static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); + // charge if the node's output is a feature map output + static bool IsFeatureMapOutput(const AnfNodePtr &node); + // charge if the node's input is from a feature map output + static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); + // get real input index for some tbe ops which input order is different between me and tbe impl + static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); + static bool IsCommunicationOp(const AnfNodePtr &node); + static bool IsGetNext(const NotNull &node); + static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); + static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); + static bool IsSwitchCall(const CNodePtr &call_node); + static bool IsScalarInput(const CNodePtr &cnode, size_t index); + static bool IsScalarOutput(const CNodePtr &cnode, size_t index); + static void ReorderExecList(NotNull *> node_list); + // get fix output precision of cnode. + static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); + // get fix output precision from prev node, input_idx is the input index of current node related to prev node. + static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); + static bool IsCondControlKernel(const CNodePtr &node); +}; +} // namespace session +using AnfAlgo = session::AnfRuntimeAlgorithm; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc new file mode 100644 index 0000000000..274b355679 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -0,0 +1,829 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/session/ascend_control_parser.h" +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/union_find_set.h" +#include "runtime/device/ascend/ascend_label_assign.h" + +static constexpr size_t kCNodePrim = 0; +static constexpr size_t kCNodeCallArg = 1; +static constexpr size_t kCNodeSwitchCond = 1; +static constexpr size_t kCNodeSwitchTrue = 2; +static constexpr size_t kCNodeSwitchFalse = 3; +static constexpr size_t kCNodeSwitchLength = 4; +static constexpr size_t kCNodePartialLength = 2; +static constexpr size_t kCNodePartialFunc = 1; +static constexpr size_t kCNodeSwitchLayerBranch = 2; +static constexpr size_t kCNodeSwitchLayerLength = 3; +static constexpr size_t kCNodeAssignTarget = 1; +static constexpr size_t kCNodeAssignSource = 2; + +namespace mindspore { +namespace session { +static void RecursiveReplaceNode(NotNull kg, NotNull main_parameter, + const std::set ¶meter_reuse_set, + const NotNull *> memo) { + if (parameter_reuse_set.empty()) { + MS_LOG(EXCEPTION) << "Parameter_reuse_set is empty."; + } + if (memo->find(kg.get()) != memo->end()) { + return; + } + memo->insert(kg.get()); + + for (auto ¶ : parameter_reuse_set) { + if (para == main_parameter.get()) { + continue; + } + MS_EXCEPTION_IF_NULL(para); + MS_LOG(INFO) << "In " << kg->ToString() << " replace " << para->DebugString() << " of graph " + << AnfAlgo::GetGraphId(para.get()) << " to " << main_parameter->DebugString() << " of graph " + << AnfAlgo::GetGraphId(main_parameter.get().get()); + kg->ReplaceNode(NOT_NULL(para), main_parameter); + } + + for (auto &child : kg->child_graph_order()) { + RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo); + } +} + +static AnfNodePtr GetMainParameter(NotNull root_kg, const AnfNodePtr &key, + const std::set ¶meter_reuse_set) { + AnfNodePtr main_parameter = key; + std::set root_inputs_set; + const auto &root_inputs_vector = root_kg->inputs(); + root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); + for (auto &node : parameter_reuse_set) { + if (root_inputs_set.find(node) != root_inputs_set.end()) { + main_parameter = node; + break; + } + } + return main_parameter; +} + +static void ReuseParameter(NotNull root_kg, + const std::vector> &link_list) { + // make union find set + UnionFindSet union_find_set; + for (auto &[param, arg] : link_list) { + union_find_set.Add(param); + union_find_set.Add(arg); + } + for (auto &[param, arg] : link_list) { + union_find_set.Union(param, arg); + } + auto parameter_reuse_sets = union_find_set.GetSets(); + + for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { + if (parameter_reuse_set.size() <= 1) { + continue; + } + auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set); + std::set memo; + RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo)); + } +} + +static CNodePtr GetNextRealKernel(const std::vector &list, size_t start) { + for (size_t i = start; i < list.size() - 1; ++i) { + if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { + return list[i]; + } + } + return nullptr; +} + +static void UpdateLabelIdToLabelSetMap(const std::vector &exec_order, + const NotNull *> label_id_to_label_set) { + for (auto &node : exec_order) { + MS_EXCEPTION_IF_NULL(node); + if (!IsPrimitiveCNode(node, prim::kPrimLabelSet)) { + continue; + } + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { + MS_LOG(EXCEPTION) << node->DebugString() << " has no attr kAttrLabelIndex"; + } + uint32_t label_id = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); + if (auto iter = label_id_to_label_set->find(label_id); iter != label_id_to_label_set->end()) { + MS_LOG(EXCEPTION) << "There are more than one node has same label id " << label_id + << ", node: " << iter->second->DebugString() << " and " << node->DebugString(); + } + (*label_id_to_label_set)[label_id] = node; + } +} + +static std::vector GetTargetLabelSetNodes(NotNull jump_node, + const std::map &label_id_to_label_set) { + std::vector target_label_list; + std::vector target_labelset_nodes; + if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelGoto)) { + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, jump_node)) { + MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kAttrLabelIndex"; + } + uint32_t label_id = AnfAlgo::GetNodeAttr(jump_node.get(), kAttrLabelIndex); + target_label_list.push_back(label_id); + } else if (IsPrimitiveCNode(jump_node.get(), prim::kPrimLabelSwitch)) { + if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, jump_node)) { + MS_LOG(EXCEPTION) << jump_node->DebugString() << " has no attr kPrimLabelSwitch"; + } + target_label_list = AnfAlgo::GetNodeAttr>(jump_node.get(), kAttrLabelSwitchList); + } else { + MS_LOG(EXCEPTION) << "Unknown type jump node " << jump_node->DebugString(); + } + + for (auto label_id : target_label_list) { + auto iter = label_id_to_label_set.find(label_id); + if (iter == label_id_to_label_set.end()) { + MS_LOG(EXCEPTION) << "Connot find LabelSet node has label id " << label_id; + } + target_labelset_nodes.push_back(iter->second); + } + return target_labelset_nodes; +} + +static void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull *> exec_order) { + MS_EXCEPTION_IF_NULL(node); + auto exec_iter = std::find(exec_order->begin(), exec_order->end(), node); + if (exec_iter == exec_order->end()) { + MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order."; + } + exec_order->erase(exec_iter); +} + +void AscendControlParser::LinkGraph(NotNull kg) { + std::set memo; + std::vector> link_list; + // Insert Assign + ChildGraphDataAssign(kg, NOT_NULL(&link_list), NOT_NULL(&memo)); + // Reuse Parameter + ReuseParameter(kg, link_list); + // replace call by label goto / label switch + memo.clear(); + (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); + // assign label resource + device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); +} + +void AscendControlParser::EraseParameter(NotNull root_graph, + const std::set &graph_list) { + std::vector exec_order = root_graph->execution_order(); + std::set search_list(exec_order.begin(), exec_order.end()); + std::set root_inputs(root_graph->inputs().begin(), root_graph->inputs().end()); + auto ref_map = root_graph->GetRefMap(); + ReferenceCounter parameter_count([](int32_t read, int32_t write) -> bool { return write == 1; }); + std::multimap> ref_multimap; + std::transform(ref_map.begin(), ref_map.end(), std::inserter(ref_multimap, ref_multimap.end()), + [](const std::pair, std::pair> &p) + -> std::pair> { + return {p.first.first, {p.first.second, p.second.first, p.second.second}}; + }); + std::set all_nodes; + std::map para_to_written_node; + for (auto &graph : graph_list) { + auto out = graph->get_return(); + MS_EXCEPTION_IF_NULL(out); + search_list.insert(out->cast()); + auto nodes = TopoSort(out); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode != nullptr) { + all_nodes.insert(cnode); + } + } + } + // prepare referance count + for (auto &node : search_list) { + MS_EXCEPTION_IF_NULL(node); + // if assign node + std::set refed_parameters; + for (auto [iter, end] = ref_multimap.equal_range(node); iter != end; ++iter) { + refed_parameters.insert(std::get<1>(iter->second)); + } + + for (auto &in : node->inputs()) { + auto visit_node = AnfAlgo::VisitKernelWithReturnType(in, 0).first; + if (!visit_node->isa() || root_inputs.find(visit_node) != root_inputs.end()) { + continue; + } + if (refed_parameters.find(visit_node) != refed_parameters.end()) { + parameter_count.AddWriteCount(visit_node, 1); + para_to_written_node[visit_node] = node; + } else { + parameter_count.AddReadCount(visit_node, 1); + } + } + } + + while (parameter_count.HasValidElem()) { + auto [para, read, written] = parameter_count.GetOneValidElem(); + MS_LOG(INFO) << para->DebugString() << " was read " << read << " times, written " << written << " times."; + auto assign_iter = para_to_written_node.find(para); + if (assign_iter == para_to_written_node.end()) { + MS_LOG(EXCEPTION) << "Cannot find assign node that write " << para->DebugString(); + } + auto &assign_node = assign_iter->second; + MS_EXCEPTION_IF_NULL(assign_node); + if (!IsPrimitiveCNode(assign_node, prim::kPrimAssign)) { + parameter_count.EraseElem(para); + continue; + } + MS_LOG(INFO) << "Erase " << assign_node->DebugString(5); + EraseNodeFromExecOrder(assign_node, NOT_NULL(&exec_order)); + + auto source = AnfAlgo::VisitKernelWithReturnType(assign_node->input(kCNodeAssignSource), 0).first; + parameter_count.AddReadCount(source, -1); + parameter_count.AddWriteCount(para, -1); + for (auto &node : all_nodes) { + for (size_t i = 0; i < node->size(); ++i) { + if (node->input(i) == para) { + MS_LOG_INFO << "Replace " << node->DebugString() << " input " << i << " by " << source->DebugString(); + node->set_input(i, source); + } + } + } + parameter_count.AddReadCount(source, 1); + parameter_count.AddReadCount(para, -1); + } + root_graph->set_execution_order(exec_order); +} + +void AscendControlParser::EraseLabel(NotNull root_graph) { + std::vector exec_order = root_graph->execution_order(); + ReferenceCounter label_count([](int32_t read, int32_t write) -> bool { return read <= 1; }); + std::map label_to_written_node; + std::map label_id_to_label_set; + UpdateLabelIdToLabelSetMap(exec_order, NOT_NULL(&label_id_to_label_set)); + CNodePtr last_node = nullptr; + for (auto &cur_node : exec_order) { + MS_EXCEPTION_IF_NULL(cur_node); + if (AnfAlgo::IsCondControlKernel(cur_node)) { + std::vector target_labelset_nodes = GetTargetLabelSetNodes(NOT_NULL(cur_node), label_id_to_label_set); + for (auto &label_set : target_labelset_nodes) { + label_count.AddReadCount(label_set, 1); + label_to_written_node[label_set] = cur_node; + } + } else if (IsPrimitiveCNode(cur_node, prim::kPrimLabelSet)) { + label_count.AddWriteCount(cur_node, 1); + if (last_node != nullptr && !AnfAlgo::IsCondControlKernel(last_node)) { + label_count.AddReadCount(cur_node, 1); + label_to_written_node[cur_node] = last_node; + } + } + last_node = cur_node; + } + + while (label_count.HasValidElem()) { + auto [label_set, read, written] = label_count.GetOneValidElem(); + MS_LOG(INFO) << label_set->DebugString() << " was read " << read << " times, written " << written << " times."; + auto iter = label_to_written_node.find(label_set); + if (read > 0 && iter == label_to_written_node.end()) { + MS_LOG(EXCEPTION) << "Cannot find node jump to " << label_set->DebugString(); + } + CNodePtr jump_node = read > 0 ? iter->second : nullptr; + if (jump_node == nullptr || IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto)) { + MS_LOG(INFO) << "Erase node " << label_set->DebugString(); + EraseNodeFromExecOrder(label_set, NOT_NULL(&exec_order)); + } + if (jump_node != nullptr && IsPrimitiveCNode(jump_node, prim::kPrimLabelGoto)) { + MS_LOG(INFO) << "Erase node " << jump_node->DebugString(); + EraseNodeFromExecOrder(jump_node, NOT_NULL(&exec_order)); + } + label_count.EraseElem(label_set); + } + + root_graph->set_execution_order(exec_order); +} + +void AscendControlParser::ExecutorValidate(NotNull root_graph) { + std::set memo; + (void)RecurseGraph(root_graph, NOT_NULL(&memo)); + EraseParameter(root_graph, memo); + EraseLabel(root_graph); +} + +std::vector>> AscendControlParser::ParseCallNode( + NotNull call_node) { + std::vector>> ret; + if (!IsPrimitiveCNode(call_node.get(), prim::kPrimCall)) { + MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " is not a call node."; + } + if (call_node->size() <= kCNodeCallArg) { + MS_LOG(EXCEPTION) << "Node " << call_node->DebugString() << " has invalid inputs size " << call_node->size(); + } + const std::vector &call_node_inputs = call_node->inputs(); + auto call_arg = call_node_inputs[kCNodeCallArg]; + MS_EXCEPTION_IF_NULL(call_arg); + if (IsValueNode(call_arg)) { + ret.emplace_back(GetValueNode(call_arg), + std::vector(call_node_inputs.begin() + kCNodeCallArg + 1, call_node_inputs.end())); + } else if (IsPrimitiveCNode(call_arg, prim::kPrimSwitch)) { + auto switch_cnode = call_arg->cast(); + MS_EXCEPTION_IF_NULL(switch_cnode); + const std::vector &switch_inputs = switch_cnode->inputs(); + if (switch_inputs.size() <= kCNodeSwitchCond) { + MS_LOG(EXCEPTION) << "Node " << switch_cnode->DebugString() << " has invalid inputs size " + << switch_inputs.size(); + } + for (auto iter = switch_inputs.begin() + kCNodeSwitchCond + 1; iter != switch_inputs.end(); ++iter) { + const auto &[target_graph, args] = ParsePartial(NOT_NULL(*iter)); + ret.emplace_back(target_graph, args); + } + } else { + MS_LOG(EXCEPTION) << "Unsupport call node: " << call_node->DebugString(5); + } + return ret; +} + +void AscendControlParser::ChildGraphDataAssign( + NotNull kg, const NotNull> *> link_list, + const NotNull *> memo) { + if (memo->find(kg) != memo->end()) { + return; + } + memo->insert(kg.get()); + + MS_LOG(INFO) << "Start link data for " << kg->ToString(); + const std::vector &nodes = kg->execution_order(); + + for (auto &node : nodes) { + if (!IsPrimitiveCNode(node, prim::kPrimCall)) { + continue; + } + + auto child_graph_list = ParseCallNode(NOT_NULL(node)); + for (auto &[child_graph, args] : child_graph_list) { + MS_EXCEPTION_IF_NULL(child_graph); + const std::vector ¶ms = child_graph->inputs(); + if (args.size() != params.size()) { + MS_LOG(EXCEPTION) << child_graph->ToString() << " needs " << params.size() << " inputs but call node " + << node->DebugString(5) << " gives " << args.size(); + } + for (size_t i = 0; i < args.size(); ++i) { + if (args[i]->isa() && memo->find(child_graph) == memo->end()) { + MS_LOG(INFO) << args[i]->DebugString() << " to " << params[i]->DebugString() + << " should be reused, continue."; + link_list->emplace_back(args[i], params[i]); + continue; + } + + InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i])); + } + } + } + kg->SetExecOrderByDefault(); + for (auto &child_graph : kg->child_graph_order()) { + ChildGraphDataAssign(NOT_NULL(child_graph), link_list, memo); + } +} + +NotNull AscendControlParser::GetStartLabel(NotNull kg, const CNodePtr &last_node, + const CNodePtr &last_label) { + CNodePtr start_label; + if (last_node != nullptr && last_label != nullptr) { + start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); + kg->set_start_label(start_label); + } else { + // no goto node will jump to start label of root graph, so return a fake label + start_label = std::make_shared(std::vector(), FuncGraphPtr(nullptr)); + } + return NOT_NULL(start_label); +} + +NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, + const CNodePtr &last_label, + const NotNull *> memo) { + MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); + + // 1. recursive condition + if (memo->find(kg) != memo->end()) { + MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); + return NOT_NULL(kg->get_start_label()); + } + memo->insert(kg.get()); + + // 2. args replace placeholder + LinkParentGraph(kg, last_node, last_label); + + // 3. topological sort + kg->SetExecOrderByDefault(); + const std::vector &nodes = kg->execution_order(); + // 4. insert first_label + CNodePtr start_label = GetStartLabel(kg, last_node, last_label); + + // 5. traverse + for (size_t i = 0; i < nodes.size(); ++i) { + auto &cnode = nodes[i]; + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() < kCNodePrim + 1) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex); + if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { + MS_LOG(DEBUG) << "Continue node " << cnode->DebugString(); + continue; + } + AnfNodePtr arg = cnode->input(kFirstDataInputIndex); + MS_EXCEPTION_IF_NULL(arg); + if (IsValueNode(arg)) { + RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); + } else if (!arg->isa()) { + MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); + } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitch)) { + auto arg_cnode = arg->cast(); + MS_EXCEPTION_IF_NULL(arg_cnode); + cnode->set_inputs(arg_cnode->inputs()); + RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); + } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitchLayer)) { + auto arg_cnode = arg->cast(); + MS_EXCEPTION_IF_NULL(arg_cnode); + cnode->set_inputs(arg_cnode->inputs()); + RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); + } + } + kg->SetExecOrderByDefault(); + MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); + return NOT_NULL(start_label); +} + +void AscendControlParser::InsertDependToGraph(NotNull kg, NotNull attch_node) { + auto return_node = kg->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), + return_node->input(kFirstDataInputIndex), attch_node.get()}; + auto depend_node = kg->NewCNode(inputs); + return_node->set_input(kFirstDataInputIndex, depend_node); +} + +void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull first_node, + NotNull second_node) { + MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() + << ", the second node is " << second_node->DebugString(); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimControlDepend->name())), + first_node, second_node}; + auto control_depend = kg->NewCNode(inputs); + InsertDependToGraph(kg, NOT_NULL(control_depend)); +} + +void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, + const CNodePtr &last_label) { + // if not entry graph, replace return with label_goto + if (from_graph_call_node != nullptr && last_label != nullptr) { + auto label_goto = + kg->NewCNode({std::make_shared(std::make_shared(kLabelGotoOpName)), last_label}); + MS_EXCEPTION_IF_NULL(label_goto); + MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString(); + kg->set_end_goto(label_goto); + } +} + +void AscendControlParser::RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + const NotNull *> memo) { + MS_LOG(INFO) << "Process call func " << cur_node->DebugString(); + + // 1 get kernel graph + const std::vector &origin_inputs = cur_node->inputs(); + if (kCNodeCallArg >= origin_inputs.size()) { + MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size(); + } + std::vector new_inputs = {std::make_shared(std::make_shared(kLabelGotoOpName))}; + if (!IsValueNode(origin_inputs[kCNodeCallArg])) { + MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; + return; + } + // 2 return label + auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node " + << cur_node->DebugString(); + // 3 add depend relationship + InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); + if (next_node != nullptr && next_node != kg->get_return()) { + InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); + } + auto call_kg = GetValueNode(origin_inputs[kCNodeCallArg]); + // 4 modify call op to goto op + cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]); + // 5 recurse sub graph + CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo); + new_inputs.push_back(sub_label); + cur_node->set_inputs(new_inputs); + cur_node->set_abstract(nullptr); + AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>({call_kg}), cur_node.get()); + MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); +} + +void AscendControlParser::RecurseSwitch(NotNull kg, NotNull cur_node, + const CNodePtr &next_node, const NotNull *> memo) { + MS_LOG(INFO) << "Process switch node " << cur_node->DebugString(); + + if (cur_node->size() < kCNodeSwitchLength) { + MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; + } + // 1 return label + auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + MS_EXCEPTION_IF_NULL(back_label); + MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node " + << cur_node->DebugString(); + // 2 add depend relationship + InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); + if (next_node != nullptr && next_node != kg->get_return()) { + InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); + } + // 3 recurse sub graph + const std::vector &origin_switch_inputs = cur_node->inputs(); + if (kCNodeSwitchCond >= origin_switch_inputs.size()) { + MS_LOG(EXCEPTION) << "The size of origin_switch_inputs is not more than " << kCNodeSwitchCond; + } + std::vector new_switch_inputs = { + std::make_shared(std::make_shared(kLabelSwitchOpName)), + origin_switch_inputs[kCNodeSwitchCond]}; + std::vector child_graphs; + for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { + // 3.1 branch kernel graph and args + KernelGraphPtr branch_fg; + std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + child_graphs.push_back(branch_fg); + // 3.2 recurse sub graph + CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); + new_switch_inputs.push_back(branch_label); + } + std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); + + cur_node->set_inputs(new_switch_inputs); + cur_node->set_abstract(nullptr); + AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>(child_graphs), cur_node.get()); + MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString(); +} + +void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull cur_node, + const CNodePtr &next_node, + const NotNull *> memo) { + MS_LOG(INFO) << "Process switch node " << cur_node->DebugString(); + + if (cur_node->size() < kCNodeSwitchLayerLength) { + MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; + } + + auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); + MS_EXCEPTION_IF_NULL(branch_tuple); + if (!branch_tuple->isa()) { + MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode"; + } + const std::vector &branch_partial = utils::cast(branch_tuple)->inputs(); + // 1 return label + auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + // 2 add depend relationship + InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); + if (next_node != nullptr && next_node != kg->get_return()) { + InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); + } + // 3 recurse sub graph + const std::vector &origin_switch_inputs = cur_node->inputs(); + if (kCNodeSwitchCond >= origin_switch_inputs.size()) { + MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << "."; + } + std::vector new_switch_inputs = { + std::make_shared(std::make_shared(kLabelSwitchOpName)), + origin_switch_inputs[kCNodeSwitchCond]}; + std::vector child_graphs; + for (size_t i = 0; i < branch_partial.size(); ++i) { + // 3.1 branch kernel graph and args + KernelGraphPtr branch_fg; + std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + child_graphs.push_back(branch_fg); + // 3.2 recurse sub graph + CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); + new_switch_inputs.push_back(branch_label); + } + new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); + cur_node->set_inputs(new_switch_inputs); + cur_node->set_abstract(nullptr); + AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>(child_graphs), cur_node.get()); + MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); +} + +std::tuple> AscendControlParser::ParsePartial(NotNull node) { + if (!node.get()->isa()) { + if (IsValueNode(node)) { + return {GetValueNode(node), {}}; + } + MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); + } + // 2.1 branch kernel graph and args + auto partial_cnode = utils::cast(node.get()); + MS_EXCEPTION_IF_NULL(partial_cnode); + if (partial_cnode->size() < kCNodePartialLength) { + MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; + } + + const auto &partial_inputs = partial_cnode->inputs(); + if (kCNodePartialFunc >= partial_inputs.size()) { + MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; + } + auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); + return {branch_kg, std::vector(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end())}; +} + +void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, const AnfNodePtr &jump_node, + NotNull from, NotNull to) { + std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); + std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); + MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]"; + if (from_outputs.size() != to_outputs.size()) { + MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" + << to_outputs.size() << "]"; + } + for (size_t i = 0; i < from_outputs.size(); i++) { + auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + const auto &from_graph_exe_order = from_graph->execution_order(); + std::vector real_exe_order(from_graph_exe_order.size()); + size_t real_exe_order_size = 0; + std::copy_if(from_graph_exe_order.begin(), from_graph_exe_order.end(), real_exe_order.begin(), + [&real_exe_order_size](const CNodePtr &node) -> bool { + return (IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimPartial)) + ? false + : (++real_exe_order_size, true); + }); + real_exe_order.resize(real_exe_order_size); + if (jump_node == nullptr) { + if (!real_exe_order.empty()) { + InsertControlDependToGraph(from_graph, NOT_NULL(*(real_exe_order.rbegin())), NOT_NULL(assign_node)); + } else { + InsertDependToGraph(from_graph, NOT_NULL(assign_node)); + } + continue; + } + + auto jump_node_iter = std::find(real_exe_order.begin(), real_exe_order.end(), jump_node); + if (jump_node_iter == real_exe_order.end()) { + MS_LOG(EXCEPTION) << "Cannot find jump node " << jump_node->DebugString() << " in graph " + << from_graph->ToString(); + } + // insert assign between jump_node -1 and jump_node + if (jump_node_iter != real_exe_order.begin()) { + InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); + } + InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); + } +} + +AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, + NotNull to) { + if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && + AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { + return nullptr; + } + if (from.get() == to.get()) { + return nullptr; + } + MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " + << to->DebugString(); + // config inputs of assign node + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimAssign->name())), to, from}; + // generate a new cnode + auto assign_node = kg->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(assign_node); + assign_node->set_abstract(to->abstract()); + return assign_node; +} + +std::vector AscendControlParser::RecurseGraph(NotNull graph, + const NotNull *> memo) { + MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start"; + if (memo->find(graph) != memo->end()) { + return {}; + } + memo->insert(graph.get()); + graph->SetExecOrderByDefault(); + std::vector cnodes = graph->execution_order(); + + auto end_label_goto = graph->get_end_goto(); + if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) { + cnodes.pop_back(); + } + AnfAlgo::ReorderExecList(NOT_NULL(&cnodes)); + if (end_label_goto != nullptr) { + cnodes.push_back(end_label_goto); + } + + std::vector execution_order; + uint32_t child_order_index = 0; + for (auto &node : cnodes) { + execution_order.push_back(node); + if (node == graph->get_end_goto()) { + continue; + } + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { + std::vector label_switch_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); + for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { + if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { + MS_LOG(EXCEPTION) << "Check label index fail"; + } + if (child_order_index >= graph->child_graph_order().size()) { + MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); + } + auto child_graph = graph->child_graph_order()[child_order_index++]; + auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); + execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); + } + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { + uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); + if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { + MS_LOG(EXCEPTION) << "Check label index fail"; + } + auto child_graph = graph->child_graph_order()[child_order_index++]; + auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); + execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); + } + } + graph->set_execution_order(execution_order); + graph->PrintGraphExecuteOrder(); + return execution_order; +} + +bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, + NotNull graph) { + const std::vector> &child_graph_order = graph->child_graph_order(); + // check index and child order size + if (child_graph_order.size() <= IntToSize(order_index)) { + MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " + << child_graph_order.size() << " goto index " << order_index; + } + auto child_graph = child_graph_order[order_index]; + MS_EXCEPTION_IF_NULL(child_graph); + + // get start_label_set_index of child graph + auto start_label_set = child_graph->get_start_label(); + uint32_t start_label_set_index = AnfAlgo::GetNodeAttr(start_label_set, kAttrLabelIndex); + if (label_index != start_label_set_index) { + MS_EXCEPTION_IF_NULL(cur_label); + MS_EXCEPTION_IF_NULL(start_label_set); + MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() + << " index " << start_label_set_index << " current child graph order : " << order_index; + return false; + } else { + return true; + } +} + +void AscendControlParser::ReferenceCounter::AddReadCount(const AnfNodePtr &key, int32_t num) { + auto iter = count_.find(key); + if (iter != count_.end()) { + iter->second.first += num; + } else { + count_[key] = {num, 0}; + } +} + +void AscendControlParser::ReferenceCounter::AddWriteCount(const AnfNodePtr &key, int32_t num) { + auto iter = count_.find(key); + if (iter != count_.end()) { + iter->second.second += num; + } else { + count_[key] = {0, num}; + } +} + +void AscendControlParser::ReferenceCounter::EraseElem(const AnfNodePtr &key) { count_.erase(key); } + +bool AscendControlParser::ReferenceCounter::HasValidElem() const { + auto it = std::find_if(count_.begin(), count_.end(), + [this](const std::pair> &p) -> bool { + auto &[read, written] = p.second; + return predicate_(read, written); + }); + return it != count_.end(); +} + +std::tuple AscendControlParser::ReferenceCounter::GetOneValidElem() const { + auto it = std::find_if(count_.begin(), count_.end(), + [this](const std::pair> &p) -> bool { + auto &[read, written] = p.second; + return predicate_(read, written); + }); + if (it == count_.end()) { + MS_LOG(EXCEPTION) << "No valid parameter."; + } + return {it->first, it->second.first, it->second.second}; +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h new file mode 100644 index 0000000000..ac24735139 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -0,0 +1,92 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H +#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H + +#include +#include +#include +#include +#include +#include +#include "backend/session/kernel_graph.h" +#include "utils/base_ref.h" +#include "utils/contract.h" +#include "utils/union_find_set.h" + +namespace mindspore { +namespace session { +class AscendControlParser { + public: + static void LinkGraph(NotNull kg); + + static void InsertDependToGraph(NotNull kg, NotNull attch_node); + static void InsertControlDependToGraph(NotNull kg, NotNull first_node, + NotNull second_node); + static void ExecutorValidate(NotNull root_graph); + static void InsertMultipleAssignToGraph(NotNull from_graph, const AnfNodePtr &jump_node, + NotNull from, NotNull to); + + private: + class ReferenceCounter; + + static void EraseParameter(NotNull root_graph, const std::set &graph_list); + static void EraseLabel(NotNull root_graph); + static void ChildGraphDataAssign(NotNull kg, + const NotNull> *> link_list, + const NotNull *> memo); + static NotNull GetStartLabel(NotNull kg, const CNodePtr &last_node, + const CNodePtr &last_label); + static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, + const CNodePtr &last_label, + const NotNull *> memo); + static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + const NotNull *> memo); + static void RecurseSwitch(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + const NotNull *> memo); + static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + const NotNull *> memo); + + static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, + const CNodePtr &last_label); + + static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); + static std::vector>> ParseCallNode(NotNull call_node); + static std::tuple> ParsePartial(NotNull node); + + // root graph order + static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, + NotNull graph); + static std::vector RecurseGraph(NotNull graph, + const NotNull *> memo); +}; +class AscendControlParser::ReferenceCounter { + public: + explicit ReferenceCounter(std::function func) : predicate_(func), count_() {} + void AddReadCount(const AnfNodePtr &key, int32_t num); + void AddWriteCount(const AnfNodePtr &key, int32_t num); + void EraseElem(const AnfNodePtr &key); + bool HasValidElem() const; + std::tuple GetOneValidElem() const; + + private: + std::function predicate_; + std::map> count_; +}; +} // namespace session +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.cc b/mindspore/ccsrc/backend/session/ascend_inference_session.cc new file mode 100644 index 0000000000..d251eb2039 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/ascend_inference_session.h" +#include "frontend/operator/ops.h" +#include "ir/tensor.h" +#include "ir/anf.h" +#include "ir/param_value.h" +#include "runtime/device/kernel_runtime.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "common/trans.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "utils/config_manager.h" +#include "utils/base_ref_extends.h" + +namespace mindspore { +namespace session { +void AscendInferenceSession::LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector inputs(inputs_const); + auto input_nodes = kernel_graph->inputs(); + + size_t no_weight_input = 0; + for (size_t i = 0; i < input_nodes.size(); ++i) { + tensor::TensorPtr tensor = nullptr; + if (!input_nodes[i]->isa()) { + MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; + continue; + } + auto pk_node = input_nodes[i]->cast(); + MS_EXCEPTION_IF_NULL(pk_node); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + MS_EXCEPTION_IF_NULL(device_address); + if (!AnfAlgo::IsParameterWeight(pk_node)) { + tensor = inputs[no_weight_input++]; + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; + } + } + } +} + +GraphId AscendInferenceSession::CompileGraph(NotNull func_graph) { + auto graph_id = AscendSession::CompileGraph(func_graph); + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + // load weight data to device + auto input_nodes = kernel_graph->inputs(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + if (!input_nodes[i]->isa()) { + MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; + continue; + } + auto pk_node = input_nodes[i]->cast(); + MS_EXCEPTION_IF_NULL(pk_node); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + MS_EXCEPTION_IF_NULL(device_address); + if (AnfAlgo::IsParameterWeight(pk_node)) { + const auto ¶m_value = pk_node->default_param(); + MS_EXCEPTION_IF_NULL(param_value); + auto tensor = std::dynamic_pointer_cast(param_value->value()); + MS_EXCEPTION_IF_NULL(tensor); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; + } + } + } + return graph_id; +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.h b/mindspore/ccsrc/backend/session/ascend_inference_session.h new file mode 100644 index 0000000000..5364ae8d4e --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H +#define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/session/ascend_session.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/session_factory.h" +#include "backend/session/ascend_control_parser.h" + +namespace mindspore { +namespace session { +class AscendInferenceSession : public AscendSession { + public: + AscendInferenceSession() = default; + ~AscendInferenceSession() = default; + void LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const; + GraphId CompileGraph(NotNull func_graph) override; +}; +MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc new file mode 100644 index 0000000000..75bc4e2d05 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -0,0 +1,1996 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/ascend_session.h" +#include +#include +#include +#include +#include +#include +#include "frontend/operator/ops.h" +#include "ir/tensor.h" +#include "ir/anf.h" +#include "common/trans.h" +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "runtime/device/ascend/kernel_build_ascend.h" +#include "runtime/device/ascend/ascend_kernel_runtime.h" +#include "runtime/device/ascend/ascend_device_address.h" +#include "backend/optimizer/ascend/ascend_backend_optimization.h" +#include "backend/optimizer/common/common_backend_optimization.h" +#include "runtime/device/kernel_adjust.h" +#include "runtime/device/ascend/ascend_stream_assign.h" +#include "runtime/device/ascend/ascend_label_assign.h" +#include "predict/predict.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/scalar.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" +#include "debug/draw.h" +#include "common/utils.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "utils/config_manager.h" +#include "utils/base_ref_extends.h" +#include "debug/tensor_load.h" + +namespace mindspore { +namespace session { +const size_t kInvalidIndex = SIZE_MAX; +constexpr size_t kReturnDataIndex = 1; +namespace { +void DumpGraphExeOrder(const std::vector &execution_order, const std::string &tag = "") { + MS_LOG(INFO) << "Dump execution_order size " << execution_order.size(); + MS_LOG(INFO) << "[index][stream_label][graph_id][node string]"; + int i = 0; + for (auto &cnode : execution_order) { + MS_EXCEPTION_IF_NULL(cnode); + MS_LOG(INFO) << "[ " << i << "]" + << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]" + << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]" + << "[" << cnode->DebugString() << "]"; + i++; + } + + std::stringstream buf; + buf << "================== execution order ==================\n"; + if (!tag.empty()) { + buf << tag << "\n"; + } + buf << "execution_order size: " << execution_order.size() << "\n"; + i = 0; + for (auto &cnode : execution_order) { + MS_EXCEPTION_IF_NULL(cnode); + buf << i << ":\n"; + buf << "\t" << cnode->DebugString() << "\n"; + buf << "\t" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "\n"; + buf << "\t" << AnfAlgo::GetGraphId(cnode.get()) << "\n"; + i++; + } + buf << "================== execution order ==================\n"; + // std::cout << buf.str() << std::endl; +} + +void DumpGraphInputArgs(const VectorRef &args) { + MS_LOG(INFO) << "Args size[%lu]" << args.size(); + for (size_t i = 0; i < args.size(); i++) { + if (utils::isa(args[i])) { + auto anf = utils::cast(args[i]); + MS_EXCEPTION_IF_NULL(anf); + MS_LOG(INFO) << "Parameter arg" << i << " = [%s]" << anf->DebugString(); + } else if (utils::isa(args[i])) { + auto value = utils::cast(args[i]); + MS_EXCEPTION_IF_NULL(value); + MS_LOG(INFO) << "Tensor arg" << i << " = " << value->ToString(); + } else { + MS_LOG(INFO) << "Unknown arg" << i << " = " << args[i].ToString(); + } + } +} + +void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) { + MS_EXCEPTION_IF_NULL(graph); + if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) { + graph->set_stream_distinction_label(label); + } +} + +std::vector GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) { + MS_EXCEPTION_IF_NULL(graph); + std::vector graph_inputs = graph->inputs(); + auto valid_inputs = graph->valid_inputs(); + size_t real_args_size = 0; + std::vector real_args = {}; + for (size_t i = 0; i < args.size(); i++) { + if (utils::isa(args[i])) { + auto tmp_args = AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem}); + for (auto &real_arg : tmp_args) { + auto anf_node = utils::cast(real_arg); + MS_EXCEPTION_IF_NULL(anf_node); + auto abstract = anf_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + // create multiple parameters if is a tuple output real kernel + if (abstract->isa() && + !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + real_args_size += tuple_abstract->size(); + continue; + } + real_args_size += 1; + real_args.push_back(real_arg); + } + } else { + real_args_size += 1; + real_args.push_back(args[i]); + } + } + if (graph_inputs.size() != valid_inputs.size()) { + MS_LOG(EXCEPTION) << "Graph_inputs.size(): " << graph_inputs.size() + << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; + } + if (real_args_size != graph_inputs.size()) { + for (size_t j = 0; j < valid_inputs.size(); j++) { + if (valid_inputs[j]) { + MS_LOG(INFO) << "Index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); + } + } + MS_LOG(WARNING) << "Real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() + << " not equal"; + } + return real_args; +} + +std::vector GetCNodes(const std::vector &anf_nodes) { + std::vector cnodes = {}; + size_t i = 0; + for (const auto &anf : anf_nodes) { + MS_LOG(INFO) << "Apply_list[" << i++ << "] = " << anf->DebugString(); + MS_EXCEPTION_IF_NULL(anf); + if (anf->isa()) { + cnodes.push_back(anf->cast()); + } + } + return cnodes; +} + +static std::vector> GetChildList(const std::vector &cnodes, + const std::set &cut_prims) { + size_t after_cut_index = 0; + std::vector> ret; + for (size_t i = 0; i < cnodes.size(); ++i) { + bool is_cut_node = false; + for (auto &prim : cut_prims) { + if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) { + is_cut_node = true; + break; + } + } + if (is_cut_node) { + // is call and not switch call,cut to 3 lists + if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) { + // if is not a call,cut to 2 lists + ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); + after_cut_index = i; + } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) { + ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); + ret.emplace_back(1, cnodes[i]); + after_cut_index = i + 1; + continue; + } + } + // get last child graph list + if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { + ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end()); + continue; + } + } + return ret; +} + +static void BindCallArgsWithParameter(const std::vector ¶meters, const std::vector &args, + const KernelGraphPtr &graph, KernelGraphPtr child_graph, + const NotNull *> memo) { + MS_EXCEPTION_IF_NULL(child_graph); + MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id(); + if (args.empty()) { + return; + } + if (parameters.size() != args.size()) { + MS_LOG(EXCEPTION) << "Graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() + << " and args size:" << args.size() << " not equal!"; + } + child_graph->SetExecOrderByDefault(); + for (size_t i = 0; i < parameters.size(); i++) { + MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ",args[" << i << "]" + << args[i]->DebugString(); + if (args[i] == parameters[i]) { + MS_LOG(INFO) << "Parameter and arg are same."; + continue; + } + child_graph->SetRealInput(parameters[i], args[i]); + if (memo->find(child_graph) != memo->end() || !args[i]->isa()) { + MS_LOG(INFO) << "Add unreused arg,graph:" << graph->graph_id(); + child_graph->AddUnreuseArgs(args[i], graph); + } + } +} + +// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of +// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] +static void UpdateRealInput(NotNull graph, bool split_flag, + const NotNull *> memo) { + MS_EXCEPTION_IF_NULL(memo.get()); + auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); + for (auto &call_node : call_nodes) { + MS_EXCEPTION_IF_NULL(call_node); + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); + if (child_graphs.size() == 1) { + MS_EXCEPTION_IF_NULL(child_graphs[0]); + std::vector real_args = + std::vector(call_node->inputs().begin() + 2, call_node->inputs().end()); + std::vector child_inputs = child_graphs[0]->inputs(); + BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo); + if (split_flag) { + call_node->set_inputs(std::vector(call_node->inputs().begin(), call_node->inputs().begin() + 2)); + } + } else if (child_graphs.size() == 2) { + auto get_partial_args = [&](size_t input_index) -> std::vector { + auto switch_node = call_node->input(1); + MS_EXCEPTION_IF_NULL(switch_node); + auto switch_cnode = switch_node->cast(); + MS_EXCEPTION_IF_NULL(switch_cnode); + auto partial = switch_cnode->input(input_index); + MS_EXCEPTION_IF_NULL(partial); + if (IsValueNode(partial)) { + return {}; + } + auto partial_cnode = partial->cast(); + MS_EXCEPTION_IF_NULL(partial_cnode); + auto ret = std::vector(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); + if (split_flag) { + partial_cnode->set_inputs( + std::vector(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); + } + return ret; + }; + BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo); + BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo); + } + } +} + +static void RecurseToUpdateCallRealInput(NotNull graph, + const NotNull *> memo) { + memo->insert(graph.get()); + MS_LOG(INFO) << "Start graph id:" << graph->graph_id(); + for (auto &child_graph : graph->child_graph_order()) { + if (memo->find(child_graph) != memo->end()) { + MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() + << ",parent graph:" << graph->parent_graph()->graph_id(); + continue; + } + RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo); + } + // this action should from bottom to top + graph->UpdateCallRealInput(); +} + +void InsertMakeTupleForOutput(NotNull root_graph) { + auto return_node = root_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->size() <= kReturnDataIndex) { + return; + } + auto make_tuple = root_graph->NewCNode( + {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), root_graph->output()}); + root_graph->set_output(make_tuple); +} +} // namespace + +GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + MS_LOG(INFO) << "Start"; + // construct graph, if successfully, graph_sum_ + 1 + auto graph = ConstructKernelGraph(lst, outputs); + auto graph_id = graph->graph_id(); + MS_LOG(INFO) << "Compile graph " << graph_id << " success"; + return graph_id; +} + +GraphId AscendSession::CompileGraph(NotNull func_graph) { + MS_LOG(INFO) << "Start"; + std::vector all_graphs; + auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); + BackendOptimization(all_graphs); + // empty graph dont entry to backend + if (root_graph->execution_order().empty()) { + MS_LOG(INFO) << root_graph->ToString() << " is empty graph."; + InsertMakeTupleForOutput(NOT_NULL(root_graph)); + root_graph->set_executable(false); + InitRuntimeResource(); + return root_graph->graph_id(); + } + // create parameter for multiple branch + std::set memo; + CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + // insert goto labels and label_sets + LinkChildGraphs(NOT_NULL(root_graph)); + // resource initialize + InitRuntimeResource(); + + IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + + SelectKernel(NOT_NULL(root_graph)); + memo.clear(); + + HardwareOptimize(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + + AssignStaticMemory(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + + UpdateRefOutputMap(NOT_NULL(root_graph), NOT_NULL(&memo)); + memo.clear(); + // add make_tuple to the output graph + InsertMakeTupleForOutput(NOT_NULL(root_graph)); + // root root_graph valiate,include genearte execute order and so on + RootGraphExecutorValidate(NOT_NULL(root_graph)); + // adjust kernel + AdjustKernel(root_graph); + // assign stream + AssignStream(NOT_NULL(root_graph)); + // insert profiling point + device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); + // build kernel + BuildKernel(root_graph); +#ifdef ENABLE_DEBUGGER + if (debugger_) { + debugger_->PreExecute(root_graph); + } +#endif + // alloc mem + MemoryAlloc(root_graph.get()); + // task generate + GenerateTaskInfo(root_graph); + // load task into device + LoadTask(root_graph); + DumpAllGraphs(all_graphs); + // return the root_graph id to backend + auto graph_id = root_graph->graph_id(); + return graph_id; +} + +void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto graph_order = GetGraphOrder(kernel_graph->graph_id()); + for (auto graph_id : graph_order) { + auto child_graph = GetGraph(graph_id); + if (child_graph == nullptr) { + continue; + } + if (child_graph->summary_node_exist()) { + kernel_graph->set_summary_node_exist(true); + return; + } + } + kernel_graph->set_summary_node_exist(false); +} + +void AscendSession::BuildGraph(GraphId graph_id) { + MS_LOG(INFO) << "Start"; + auto graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(graph); + // resource initialize + InitRuntimeResource(); + // multiple graph handle + if (graph_id == final_graph_id_) { + if (!graph->executable()) { + return; + } + // insert assigns to child graph + InsertAllAssigns(); + // insert switch and active to child graph + MergeSwitchCompile(); + SetFinalGraphSummaryFlag(graph); + // OptChildGraphs + auto graph_order = GetGraphOrder(final_graph_id_); + auto &graph_type = GetGraphOrderType(final_graph_id_); + for (size_t i = 0; i < graph_order.size(); i++) { + if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { + continue; + } + MS_LOG(INFO) << "Start build child graph " << graph_order[i]; + auto child_graph = GetGraph(graph_order[i]); + CompileChildGraph(child_graph); + } + GetSummaryNodes(graph.get()); + // merge child graph + MergeGraphExecOrder(); + } else { + auto single_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(single_graph); + CompileChildGraph(single_graph); + // set the distinction label of single graph + single_graph->set_stream_distinction_label(graph_id); + single_graph->UpdateExecuteKernelStreamLabel(); + } + // adjust execution order because merge child graph and other special operations + AdjustKernel(graph); + // Assign streams for control sink and hccl and so on + AssignStream(NOT_NULL(graph)); + + device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); + // build kernel if node is cnode + BuildKernel(graph); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); +#ifdef ENABLE_DEBUGGER + if (debugger_) { + debugger_->PreExecute(graph); + } +#endif + if (ms_context->precompile_only()) { + MS_LOG(INFO) << "Precompile only, stop in build kernel step"; + } else { + // alloc memory, including static memory and dynamic memory + MemoryAlloc(graph.get()); + // generate task info for task sink mode + GenerateTaskInfo(graph); + // load task info to device if it is sink mode + LoadTask(graph); + } + // sync the inital const tensor to device + SyncInitialTenosrToDevice(); + DumpAllGraphs({graph}); + MS_LOG(INFO) << "End"; +} + +void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { + MS_EXCEPTION_IF_NULL(child_graph); + MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString(); + opt::AscendBackendIRFusionOptimization(child_graph); + opt::AscendBackendFuseBasicOpt(child_graph, true); + opt::AscendBackendGraphKernelOpt(child_graph, true); + child_graph->SetExecOrderByDefault(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; + DumpIR(file_path, child_graph); + } + // select kernel build info + SelectKernel(*child_graph); + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; + DumpIR(file_path, child_graph); + } + // convert kernel Graph to model + predictmodel::StepConvertGraph(child_graph); + // optimize graph + HardwareOptimize(child_graph); + // assign static memory of parameters + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->AssignStaticMemoryInput(child_graph.get()); + runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); +} + +void AscendSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, + VectorRef *const outputs) { + MS_LOG(INFO) << "Start"; + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + // if none of child graph and no anf output exists + if (!kernel_graph->executable()) { + MS_LOG(INFO) << "No child graph has anf output"; + UpdateOutputs(kernel_graph, outputs, inputs); + return; + } + // load input data from user input + LoadInputData(kernel_graph, inputs); + // convert inputs to model + predictmodel::StepConvertWeight(inputs); + { + py::gil_scoped_release release; + // run task on device + ExecTask(kernel_graph); + } + // get result from device + UpdateOutputs(kernel_graph, outputs, inputs); + // summary + Summary(kernel_graph.get()); +#ifdef ENABLE_DEBUGGER + // load tensor from device for debugger + if (debugger_ && debugger_->debugger_enabled()) { + LoadTensor(kernel_graph); + } +#endif + // dump used for debug + Dump(kernel_graph); +#ifdef ENABLE_DEBUGGER + // debugger post-execution processing + if (debugger_) { + debugger_->PostExecute(); + } +#endif + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start"; + // data layout optimization + opt::RunOpAscendDataLayout(kernel_graph); + // mixed precision optimization + opt::AscendMixPrecision(kernel_graph); + MS_LOG(INFO) << "Finish"; +} + +void AscendSession::RunOpExecTask(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get()); + if (!ret_ok) { + MS_LOG(EXCEPTION) << "Run task error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { + if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { + return true; + } + + return false; +} + +void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) { + MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; + if (GraphCacheExist(graph_info)) { + MS_LOG(INFO) << "Build op " << op_run_info.op_name << " graph cache has existed !"; + return; + } + + // construct graph include one op + auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); + MS_EXCEPTION_IF_NULL(graph); + opt::RunOpAscendBackendIRFusionOptimization(graph); + // kernel select + SelectKernel(*graph); + // optimize + RunOpHardwareOptimize(graph); + // init runtime resource + InitRuntimeResource(); + // build kernel + RunOpAdjustKernel(graph); + BuildKernel(graph); + run_op_graphs_[graph_info] = graph; + MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !"; +} + +py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) { + auto graph = run_op_graphs_[graph_info]; + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; + // malloc mem + RunOpMemoryAlloc(input_tensors, graph.get()); + // load input data to device + LoadInputData(graph, input_tensors); + // run op + RunOpExecTask(graph); + // get output + VectorRef outputs; + UpdateOutputs(graph, &outputs, input_tensors); + // trans output to tuple + auto output_tensors = TransformBaseRefListToTuple(outputs); + if (!utils::isa(output_tensors) || + !py::isinstance(utils::cast(output_tensors).object_)) { + MS_LOG(EXCEPTION) << "The output tensors should be a tuple !"; + } + py::object tuple_obj = utils::cast(output_tensors).object_; + py::tuple tuple_tensors = py::cast(tuple_obj); + RunOpMemoryClear(graph.get()); + MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; + return tuple_tensors; +} + +// compile graph steps +void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + size_t raise_precision_count = 0; + size_t reduce_precision_count = 0; + for (const auto &cnode : kernel_graph.execution_order()) { + auto status = device::ascend::SelectKernelInfo(cnode); + if (status == device::ascend::kStatusRaisePrecision) { + raise_precision_count++; + } else if (status == device::ascend::kStatusReducePrecision) { + reduce_precision_count++; + } + MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); + } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kGraphMode) { + if (raise_precision_count > 0) { + MS_LOG(WARNING) << "There has " << raise_precision_count + << " node/nodes used raise precision to selected the kernel!"; + } + if (reduce_precision_count > 0) { + MS_LOG(WARNING) << "There has " << reduce_precision_count + << " node/nodes used reduce precision to selected the kernel!"; + } + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::InitRuntimeResource() { + MS_LOG(INFO) << "Start!"; + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + if (!runtime_instance->Init()) { + MS_LOG(EXCEPTION) << "Kernel runtime init error."; + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::HardwareOptimize(const std::shared_ptr &kernel_graph) const { + device::ascend::KernelPreBuild(kernel_graph.get()); + MS_LOG(INFO) << "HardwareOptimize start!"; + opt::AscendBackendOptimization(kernel_graph); + opt::AscendGraphKernelCommonProcess(kernel_graph); + opt::AscendBackendFuseBasicOpt(kernel_graph, false); + opt::AscendBackendAddAtomicClean(kernel_graph); + MS_EXCEPTION_IF_NULL(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + MS_LOG(INFO) << "HardwareOptimize Finish!"; +} + +void AscendSession::AdjustKernel(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + opt::HideNopNode(kernel_graph.get()); + // Insert CLearZero op + // prepare for next step from json get atomic info + BuildKernel(kernel_graph); + device::ascend::KernelBuildPreprocess(kernel_graph.get()); + device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "after_adjust_kernel.ir"; + DumpIR(file_path, kernel_graph); + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + opt::HideNopNode(kernel_graph.get()); + // Insert CLearZero op + // prepare for next step from json get atomic info + BuildKernel(kernel_graph); + device::ascend::KernelBuildPreprocess(kernel_graph.get()); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::AssignStream(NotNull kernel_graph) const { + MS_LOG(INFO) << "Start!"; + device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::BuildKernel(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); + auto ret = device::ascend::KernelBuild(kernel_graph.get()); + if (!ret) { + MS_LOG(EXCEPTION) << "Kernel build error."; + } + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "KernelBuild run in " << PRIu64 << " us " << cost; + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(kernel_graph); + opt::RemoveNopNode(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->AssignMemory(kernel_graph); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RunOpMemoryAlloc(const std::vector &input_tensors, + KernelGraph *kernel_graph) const { + MS_LOG(INFO) << "Start memory alloc!"; + MS_EXCEPTION_IF_NULL(kernel_graph); + opt::RemoveNopNode(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->RunOpClearMemory(kernel_graph); +} + +void AscendSession::GenerateTaskInfo(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + bool ret_ok = runtime_instance->GenTask(kernel_graph.get()); + if (!ret_ok) { + MS_LOG(EXCEPTION) << "Generate task error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::LoadTask(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + bool ret_ok = runtime_instance->LoadTask(kernel_graph.get()); + if (!ret_ok) { + MS_LOG(EXCEPTION) << "Load task error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::ExecTask(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + bool ret_ok = runtime_instance->Run(kernel_graph.get()); + if (!ret_ok) { + MS_LOG(EXCEPTION) << "run task error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::Dump(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + (void)runtime_instance->DumpData(kernel_graph.get()); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::DumpAllGraphs(const std::vector &all_graphs) { +#ifdef ENABLE_DUMP_IR + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + if (!save_graphs) { + return; + } + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + for (auto &graph : all_graphs) { + MS_EXCEPTION_IF_NULL(graph); + std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(graph->graph_id()) + ".ir"; + DumpIR(file_path, graph, true); + DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id())); + } +#endif +} + +void AscendSession::LoadTensor(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(kernel_graph); +#ifdef ENABLE_DEBUGGER + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + DebugServices *debug_services = debugger_->debug_services(); + TensorLoader *tensor_loader = debug_services->tensor_loader(); + // TensorData will be freed up here + tensor_loader->EmptyTensor(); + uint32_t iter_num = tensor_loader->GetIterNum(); + tensor_loader->set_iter_num(++iter_num); + (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get()); + tensor_loader->EmptyPrevTensor(); +#endif + MS_LOG(INFO) << "Finish!"; +} + +GraphId AscendSession::SetFinalGraphInput(const std::vector &args) { + MS_LOG(INFO) << "Start! Args size " << args.size(); + auto final_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(final_graph); + final_graph_id_ = final_graph->graph_id(); + MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success"; + // init private variables and bind them with final_graph_id + graph_execute_orders_[final_graph_id_] = std::vector(); + graph_order_types_[final_graph_id_] = std::vector(); + for (const auto ¶meter : args) { + MS_EXCEPTION_IF_NULL(parameter); + if (!parameter->isa()) { + MS_LOG(EXCEPTION) << parameter->DebugString() << " is not a parameter type!"; + } + AnfNodePtr parameter_backend = nullptr; + // if function return UINT_MAX,the parameter is not exist in child graph + auto parameter_belong_graph_id = GetGraphIdByNode(parameter); + if (parameter_belong_graph_id == kInvalidGraphId) { + parameter_backend = CreateNewParameterFromParameter(parameter, true, final_graph.get()); + final_graph->FrontBackendlMapAdd(parameter, parameter_backend); + MS_LOG(INFO) << "New parameter" << parameter->DebugString() << "in final_graph"; + } else { + // parametr is a parameter of child graph + auto graph = GetGraph(parameter_belong_graph_id); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Reuse parameter [" << parameter->DebugString() << "] of child graph [" + << parameter_belong_graph_id << "]"; + parameter_backend = graph->GetBackendAnfByFrontAnf(parameter); + // add parameter in backend to final graph inputs + auto final_graph_inputs = final_graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(final_graph_inputs); + final_graph_inputs->push_back(parameter_backend); + } + MS_EXCEPTION_IF_NULL(parameter_backend); + MS_LOG(INFO) << "Parameter backend " << parameter_backend->DebugString() << " belong_graph_id " + << AnfAlgo::GetGraphId(parameter_backend.get()); + } + MS_LOG(INFO) << "End final_graph_id " << final_graph_id_; + return final_graph_id_; +} + +void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph, + std::map> *summary) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(summary); + // if final graph have no child graph + auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); + if (graph_order_iter == graph_execute_orders_.end()) { + SessionBasic::GetSummaryNodes(graph); + auto summary_nodes = graph->summary_nodes(); + summary->insert(summary_nodes.begin(), summary_nodes.end()); + return; + } + // for every child graph, find summary nodes + auto graph_order = GetGraphOrder(graph->graph_id()); + for (size_t i = 0; i < graph_order.size(); i++) { + auto child_graph = GetGraph(graph_order[i]); + if (child_graph == nullptr) { + continue; + } + SessionBasic::GetSummaryNodes(child_graph.get()); + auto child_graph_summary = child_graph->summary_nodes(); + summary->insert(child_graph_summary.begin(), child_graph_summary.end()); + RecurseGetSummaryNodes(child_graph.get(), summary); + } + graph->set_summary_nodes(*summary); +} + +void AscendSession::GetSummaryNodes(KernelGraph *graph) { + MS_LOG(DEBUG) << "Update summary Start"; + MS_EXCEPTION_IF_NULL(graph); + auto summary_nodes = graph->summary_nodes(); + std::map> summary; + summary.insert(summary_nodes.begin(), summary_nodes.end()); + RecurseGetSummaryNodes(graph, &summary); + graph->set_summary_nodes(summary); + MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); +} + +AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { + auto fake_graph = GetGraph(fake_graph_id); + MS_EXCEPTION_IF_NULL(fake_graph); + auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0); + auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr { + auto parameter = fake_graph->NewParameter(); + MS_EXCEPTION_IF_NULL(parameter); + parameter->set_abstract(abstract); + auto new_parameter = fake_graph->NewParameter(parameter); + // Add new parameter to the graph input of fake_graph to sure that all parameters will be allocated memory. + auto graph_inputs = fake_graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + graph_inputs->push_back(new_parameter); + return new_parameter; + }; + auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr { + MS_EXCEPTION_IF_NULL(cnode); + auto abstract = cnode->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + // create multiple parameters if is a tuple output real kernel + if (abstract->isa()) { + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + MS_LOG(INFO) << "Tuple size [" << tuple_abstract->size() << "]"; + return create_parameter((*tuple_abstract)[output_idx]); + } + return create_parameter(cnode->abstract()); + }; + if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) { + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; + auto make_tuple = output_item_with_index.first->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + for (size_t i = 1; i < make_tuple->inputs().size(); i++) { + auto input = make_tuple->inputs()[i]; + make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input)); + } + return fake_graph->NewCNode(make_tuple_inputs); + } + return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second); +} + +void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) { + // get the backend anf node related to the output node of front + auto output_from_graph_id = GetGraphIdByNode(node); + auto output_from_graph = GetGraph(output_from_graph_id); + MS_EXCEPTION_IF_NULL(node); + MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id + << "] to final graph"; + MS_EXCEPTION_IF_NULL(output_from_graph); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + // if output is from final graph,it remarks no child graph exist + if (final_graph_id_ == output_from_graph_id) { + MS_LOG(INFO) << "No child graph,output is " << node->DebugString(); + final_graph->set_output(ConstructOutput({node}, final_graph)); + final_graph->set_executable(false); + return; + } + final_graph->set_output(output_from_graph->output()); +} + +void AscendSession::SetFinalGraphOutput(const ValuePtr &value) { + auto value_node = NewValueNode(value); + auto kernel_info = std::make_shared(); + value_node->set_kernel_info(kernel_info); + value_node->set_abstract(abstract::FromValue(value)); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); + final_graph->set_executable(false); + MS_EXCEPTION_IF_NULL(value); + MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]"; +} + +void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) { + for (auto &output : vec_output) { + if (utils::isa(output)) { + auto output_anf_node = utils::cast(output); + SetFinalGraphOutput(output_anf_node); + } else if (utils::isa(output)) { + auto value = utils::cast(output); + SetFinalGraphOutput(value); + } else { + MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); + } + } +} + +void AscendSession::SetFinalGraphOutput(const BaseRef &output) { + if (utils::isa(output)) { + auto output_anf_node = utils::cast(output); + SetFinalGraphOutput(output_anf_node); + } else if (utils::isa(output)) { + auto value = utils::cast(output); + SetFinalGraphOutput(value); + } else if (utils::isa(output)) { + auto vec_output = utils::cast(output); + SetFinalGraphOutput(vec_output); + } else { + MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); + } +} + +void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) { + MS_LOG(INFO) << "Start!"; + MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]"; + auto condition_graph = GetGraph(condition_graph_id); + MS_EXCEPTION_IF_NULL(condition_graph); + tensor::TensorPtr tensor = std::make_shared(kNumberTypeInt32, std::vector{1}); + int32_t *val = nullptr; + val = static_cast(tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + auto value_node = std::make_shared(tensor); + value_node->set_abstract(abstract::FromValue(tensor, false)); + auto counter_const = condition_graph->NewValueNode(value_node); + condition_graph->AddValueNodeToGraph(counter_const); + // create a new switch op + auto switch_primitive = std::make_shared("StreamSwitch"); + auto cond_output_it = condition_output_.find(condition_graph_id); + if (cond_output_it == condition_output_.end()) { + MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; + } + auto cond_output_kernel = + AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first; + MS_EXCEPTION_IF_NULL(cond_output_kernel); + std::vector inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; + CNodePtr switch_node = condition_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(switch_node); + switch_node->set_abstract(std::make_shared()); + AnfAlgo::SetGraphId(condition_graph_id, switch_node.get()); + // set attr: cond_ RT_GREATER + AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue(static_cast(RT_GREATER)), switch_node); + // set attr:data_type + AnfAlgo::SetNodeAttr(kAttrDataType, MakeValue(static_cast(RT_SWITCH_INT64)), switch_node); + // set attr:true branch graph id ,which is same to stream distinction label + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(true_graph_id), switch_node); + // append switch at the end of condition graph + auto return_node = condition_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + InsertControlDependToGraph(condition_graph_id, return_node->input(kReturnDataIndex), switch_node); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { + auto &graph_execute_order = GetGraphOrder(final_graph_id_); + auto &graph_order_type = GetGraphOrderType(final_graph_id_); + auto false_index = ExecOrderOfChildGraph(final_graph_id_, false_graph_id); + if (false_index == kInvalidIndex || false_index == 0) { + return; + } + for (int i = SizeToInt(false_index) - 1; i >= 0; i--) { + size_t graph_index = IntToSize(i); + if (graph_index >= graph_execute_order.size()) { + MS_LOG(EXCEPTION) << "Graph index[" << graph_index << "] out of range[" << graph_execute_order.size() << "]"; + } + if (graph_order_type[graph_index] == COMMON_GRAPH) { + auto true_last_id = graph_execute_order[graph_index]; + MS_LOG(INFO) << "The last graph of if true branch is " << true_last_id; + auto true_last = GetGraph(true_last_id); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + auto false_last = GetGraph(false_graph_id); + MS_EXCEPTION_IF_NULL(true_last); + MS_EXCEPTION_IF_NULL(false_last); + MS_LOG(INFO) << "The last graph of false branch is " << false_graph_id; + // create fake output + auto fake_output_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(fake_output_graph); + graph_execute_order.push_back(fake_output_graph->graph_id()); + graph_order_type.push_back(COMMON_GRAPH); + fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output())); + final_graph->set_output(fake_output_graph->output()); + InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output()); + InsertMultipleAssignToGraph(false_graph_id, false_last->output(), final_graph->output()); + // insert stream active for loop sink + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && + ConfigManager::GetInstance().iter_num() > 1) { + // insert active in true graph, another active will be inserted in kernel adjust + InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); + } + break; + } + } +} + +void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id, + const AnfNodePtr &output) { + if (switches_.find(cond_graph_id) != switches_.end()) { + MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; + return; + } + switches_[cond_graph_id] = std::pair(true_graph_id, false_graph_id); + condition_output_[cond_graph_id] = output; + MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; + // set the type of condition graph + auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); + auto &graph_order_type = GetGraphOrderType(final_graph_id_); + if (cond_graph_index >= graph_order_type.size()) { + MS_LOG(EXCEPTION) << "Cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size(); + } + graph_order_type[cond_graph_index] = CONDITION_GRAPH; + // update distinction label of false graph,update before merge to sure the distinction + if (false_graph_id != kInvalidGraphId) { + // false graph and condition in graph same stream + auto condition_graph = GetGraph(cond_graph_id); + MS_EXCEPTION_IF_NULL(condition_graph); + SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); + // if false graph is a condition graph and has been switch compiled before,it's false should be updated again + auto cond_it = switches_.find(false_graph_id); + while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) { + cond_graph_id = cond_it->first; + false_graph_id = cond_it->second.second; + condition_graph = GetGraph(cond_graph_id); + if (condition_graph == nullptr) { + continue; + } + SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); + cond_it = switches_.find(false_graph_id); + } + } +} // namespace session + +void AscendSession::MergeSwitchCompile() { + auto graph_execute_order = GetGraphOrder(final_graph_id_); + auto &graph_order_type = GetGraphOrderType(final_graph_id_); + for (auto switch_compile : switches_) { + auto cond_graph_id = switch_compile.first; + auto true_graph_id = switch_compile.second.first; + auto false_graph_id = switch_compile.second.second; + MS_LOG(INFO) << "Switch compile: " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; + auto condition_graph = GetGraph(cond_graph_id); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(condition_graph); + MS_EXCEPTION_IF_NULL(final_graph); + // insert switch to condition graph + InsertSwitchToGraph(cond_graph_id, true_graph_id); + auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); + auto prev_graph_id = kInvalidGraphId; + // if condition graph is the first graph and final graph has assign op,then the final graph is the common graph + if (cond_graph_index == 0 && !final_graph->execution_order().empty()) { + prev_graph_id = final_graph_id_; + // set the distinction label of final graph + SetStreamDistinctionLabel(final_graph, final_graph_id_, true); + // if condition graph is not the first graph + } else if ((cond_graph_index - 1 < graph_execute_order.size()) && + (graph_order_type[cond_graph_index - 1] == COMMON_GRAPH)) { + prev_graph_id = graph_execute_order[cond_graph_index - 1]; + } + // insert stream active to common graph + if (prev_graph_id != kInvalidGraphId) { + InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label()); + } + // if this is a 'if' condition + auto it = while_condition_graphs_.find(cond_graph_id); + if (it == while_condition_graphs_.end()) { + CopyOutputOfIf(false_graph_id); + } else { + // if it is a while,insert a stream active to true graph + GraphId from_graph = it->second; + InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label()); + } + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::InsertAllAssigns() { + std::vector> assigns; + for (auto assign : assigns_) { + auto front_anf = std::get<0>(assign); + auto to_graph_id = std::get<1>(assign); + auto input_idx = std::get<2>(assign); + auto to_graph = GetGraph(to_graph_id); + MS_EXCEPTION_IF_NULL(to_graph); + std::vector graph_inputs = to_graph->inputs(); + if (input_idx >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); + } + auto backend_parameter = graph_inputs[input_idx]; + assigns.emplace_back(std::pair(front_anf, backend_parameter)); + } + // erase the repeat assign + std::set> inserted_nodes; + for (auto &assign : assigns) { + auto front_anf = assign.first; + auto backend_parameter = assign.second; + auto from_graph_id = GetGraphIdByNode(front_anf); + auto from_graph = GetGraph(from_graph_id); + MS_EXCEPTION_IF_NULL(from_graph); + auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); + if (inserted_nodes.find(assign) == inserted_nodes.end()) { + InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter); + (void)inserted_nodes.insert(assign); + } + } +} + +// insert active to graph +void AscendSession::SetActive(GraphId from, GraphId to) { + if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) { + MS_LOG(WARNING) << "To " << to << " has been exits in map,from " << from << ",exist from " + << while_condition_graphs_[to]; + return; + } + MS_LOG(INFO) << "From " << from << " to " << to; + auto &graph_order = GetGraphOrder(final_graph_id_); + auto &graph_type = GetGraphOrderType(final_graph_id_); + std::vector graph_order_new; + std::vector graph_type_new; + for (size_t i = 0; i < graph_order.size(); i++) { + auto graph_id = graph_order[i]; + graph_order_new.push_back(graph_id); + graph_type_new.push_back(graph_type[i]); + if (from == graph_id) { + graph_order_new.push_back(kInvalidGraphId); + graph_type_new.push_back(BRANCH_END); + } + } + graph_order = graph_order_new; + graph_type = graph_type_new; + // set the graph type of condition graph + graph_type[ExecOrderOfChildGraph(final_graph_id_, to)] = CONDITION_GRAPH; + // record the condition graph into while condition set + while_condition_graphs_[to] = from; +} + +void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(front_anf); + auto from_graph_id = GetGraphIdByNode(front_anf); + auto from_graph = GetGraph(from_graph_id); + MS_EXCEPTION_IF_NULL(from_graph); + auto to_graph = GetGraph(to_graph_id); + MS_EXCEPTION_IF_NULL(to_graph); + std::vector graph_inputs = to_graph->inputs(); + if (input_idx >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); + } + auto backend_parameter = graph_inputs[input_idx]; + MS_EXCEPTION_IF_NULL(backend_parameter); + auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); + MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" + << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) + << "]"; + // a node should not assign to itself + if (backend_arg.get() == backend_parameter.get()) { + return; + } + // if arg is the the parameter of child graph,it is parameter of final graph too + if (front_anf->isa()) { + MS_EXCEPTION_IF_NULL(backend_arg); + MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString() + << "] will be replaced."; + to_graph->ReplaceNode(NOT_NULL(backend_parameter), NOT_NULL(backend_arg)); + return; + } + MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node" + << backend_parameter->DebugString() << "of graph " << to_graph_id; + assigns_.emplace_back(std::tuple(front_anf, to_graph_id, input_idx)); +} + +void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, + size_t input_idx) { + MS_LOG(INFO) << "Start!"; + std::pair graph_input_pair(to_graph_id, input_idx); + initial_tenosrs_[graph_input_pair] = front_tensor; + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::UpdateGraphOrder(GraphId to_graph_id) { + MS_LOG(INFO) << "To_graph_id " << to_graph_id; + auto &graph_order = GetGraphOrder(final_graph_id_); + auto &graph_type = GetGraphOrderType(final_graph_id_); + for (size_t i = 0; i < graph_order.size(); i++) { + if (graph_order[i] == to_graph_id) { + return; + } + } + // if graph is not in graph order,add it to graph order + SetStreamDistinctionLabel(GetGraph(to_graph_id), to_graph_id, false); + graph_order.push_back(to_graph_id); + graph_type.push_back(COMMON_GRAPH); + for (size_t i = 0; i < graph_order.size(); i++) { + MS_LOG(INFO) << "Index " << i << ",graph_id " << graph_order[i] << ",graph_type" << graph_type[i]; + } +} + +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto output_num = AnfAlgo::GetOutputTensorNum(node); + if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + return input_index + output_num; + } + auto valid_inputs = graph->valid_inputs(); + if (valid_inputs[input_index]) { + SetChildGraphParameter(node, graph->graph_id(), input_index); + } else { + MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString(); + } + return ++input_index; +} + +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(value); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString(); + } + SetChildGraphParameter(value->cast(), graph->graph_id(), input_index); + return ++input_index; +} + +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) { + auto index = input_index; + for (auto &arg : vec_args) { + if (utils::isa(arg)) { + // arg is a anf node + auto node = utils::cast(arg); + index = SetChildGraphInput(graph, node, input_index); + } else if (utils::isa(arg)) { + // arg is a tensor + auto value = utils::cast(arg); + index = SetChildGraphInput(graph, value, input_index); + } else { + MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString(); + } + } + return index; +} + +void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { + MS_LOG(INFO) << "Set input of graph " << g; + auto to_graph = GetGraph(g); + MS_EXCEPTION_IF_NULL(to_graph); + DumpGraphInputArgs(args); + UpdateGraphOrder(g); + auto &graph_inputs = to_graph->inputs(); + auto real_args = GetRealArgs(to_graph, args); + size_t input_index = 0; + for (size_t i = 0; i < real_args.size(); i++) { + if (input_index >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Input_index " << input_index << " out of range size " << graph_inputs.size(); + } + auto &real_arg = real_args[i]; + if (utils::isa(real_arg)) { + // arg is a anf node + auto node = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, node, input_index); + } else if (utils::isa(real_arg)) { + // arg is a tensor + auto value = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, value, input_index); + } else if (utils::isa(real_arg)) { + // arg is a VectorRef + auto vec_args = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, vec_args, input_index); + } else { + MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString(); + } + } + MS_LOG(INFO) << "Finish!"; +} + +GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const { + for (const auto &graph_item : graphs_) { + auto graph = graph_item.second; + MS_EXCEPTION_IF_NULL(graph); + // if front_anf is a parameter,the backend parameter may have two + if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) { + return graph_item.first; + } + } + MS_EXCEPTION_IF_NULL(front_anf); + MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph"; + return kInvalidGraphId; +} + +void AscendSession::MergeGraphExecOrder() { + MS_LOG(INFO) << "Start!"; + // merge graph order + auto &graph_order = GetGraphOrder(final_graph_id_); + auto &graph_type = GetGraphOrderType(final_graph_id_); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + if (graph_order.empty()) { + MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!"; + return; + } + if (graph_order.size() > 1) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!context_ptr->enable_task_sink()) { + MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!"; + } + } + // if first graph is common,the final graph has no label,then set the stream of final graph same with the first graph + SetStreamDistinctionLabel(final_graph, graph_order[0], false); + std::vector final_exec_order = final_graph->execution_order(); + KernelGraphPtr last_graph = nullptr; + for (size_t i = 0; i < graph_order.size(); i++) { + auto graph_id = graph_order[i]; + if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { + continue; + } + auto child_graph = GetGraph(graph_id); + last_graph = child_graph; + MS_EXCEPTION_IF_NULL(child_graph); + auto exec_order = child_graph->execution_order(); + MS_LOG(INFO) << "Merge graph,graph_id " << graph_id; + (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order), + [&](CNodePtr node) -> CNodePtr { + AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get()); + return node; + }); + // add all value nodes of child graphs to final graph + for (auto &value_node : child_graph->graph_value_nodes()) { + final_graph->AddValueNodeToGraph(value_node); + } + // copy ref map to final graph + auto child_ref_map = child_graph->GetRefMap(); + for (auto &item : child_ref_map) { + if (final_graph->IsInRefOutputMap(item.first)) { + MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; + } + final_graph->AddRefCorrespondPairs(item.first, item.second); + } + } + // set final_exec_order into final graph + MS_EXCEPTION_IF_NULL(final_graph); + DumpGraphExeOrder(final_exec_order); + final_graph->set_execution_order(final_exec_order); +} + +void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { + MS_EXCEPTION_IF_NULL(from); + MS_EXCEPTION_IF_NULL(to); + if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && + AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { + return; + } + if (from.get() == to.get()) { + return; + } + MS_LOG(INFO) << "Insert assign to graph " << graph_id << " from " << from->DebugString() << " to " + << to->DebugString(); + auto graph = graphs_[graph_id]; + MS_EXCEPTION_IF_NULL(graph); + // config inputs of assign node + std::vector inputs = {NewValueNode(std::make_shared("Assign")), to, from}; + // generate a new cnode + auto assign_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(assign_node); + assign_node->set_abstract(to->abstract()); + // append the assign at the end of from graph + InsertDependToGraph(graph_id, assign_node); +} + +void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { + std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); + std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); + MS_LOG(INFO) << "Insert assigns from [" << AnfAlgo::GetGraphId(from.get()) << "] to [" + << AnfAlgo::GetGraphId(to.get()) << "]"; + if (from_outputs.size() != to_outputs.size()) { + MS_LOG(INFO) << "From[" << from->DebugString(5) << "] to[" << to->DebugString(5) << "]"; + MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" + << to_outputs.size() << "]"; + } + for (size_t i = 0; i < from_outputs.size(); i++) { + InsertAssignToGraph(graph_id, from_outputs[i], to_outputs[i]); + } +} + +void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) { + MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream; + auto from_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(from_graph); + std::vector inputs = {NewValueNode(std::make_shared("StreamActive"))}; + auto active_node = from_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(active_node); + active_node->set_abstract(std::make_shared()); + // set the active stream id into the attr of active node + std::vector active_index_value = {}; + active_index_value.push_back(actived_stream); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_value), active_node); + // append the active node at the end of from graph + auto return_node = from_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + InsertControlDependToGraph(graph_id, return_node->input(kReturnDataIndex), active_node); +} + +void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { + AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node)); +} + +void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, + const AnfNodePtr &second_node) { + AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node), + NOT_NULL(second_node)); +} + +size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { + auto &graph_order = GetGraphOrder(final_graph); + for (size_t i = 0; i < graph_order.size(); i++) { + if (child_graph == graph_order[i]) { + return i; + } + } + return kInvalidIndex; +} + +std::vector &AscendSession::GetGraphOrder(GraphId final_graph_id) { + auto graph_order_iter = graph_execute_orders_.find(final_graph_id); + if (graph_order_iter == graph_execute_orders_.end()) { + MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no child graph"; + } + return graph_order_iter->second; +} + +// get graph order type vector by graph id +std::vector &AscendSession::GetGraphOrderType(GraphId final_graph_id) { + auto graph_type_iter = graph_order_types_.find(final_graph_id); + if (graph_type_iter == graph_order_types_.end()) { + MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no graph_order_types_"; + } + return graph_type_iter->second; +} + +void AscendSession::SyncInitialTenosrToDevice() { + for (auto &item : initial_tenosrs_) { + auto to_graph_id = item.first.first; + auto input_idx = item.first.second; + auto front_tensor = item.second; + auto to_graph = GetGraph(to_graph_id); + MS_EXCEPTION_IF_NULL(to_graph); + std::vector graph_inputs = to_graph->inputs(); + if (input_idx >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); + } + auto backend_parameter = graph_inputs[input_idx]; + // sync data from host to device + MS_EXCEPTION_IF_NULL(front_tensor); + size_t tensor_size = front_tensor->data().nbytes(); + auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); + MS_EXCEPTION_IF_NULL(addr); + if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size, + front_tensor->data_type(), front_tensor->data_c())) { + MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; + } + } +} + +static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector &list) { + // count the output of every anf node + std::set has_output_nodes; + for (auto &anf_node : list) { + MS_EXCEPTION_IF_NULL(anf_node); + for (auto &input : anf_node->inputs()) { + (void)has_output_nodes.insert(input); + } + } + + auto make_tuple_primitve = NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())); + std::vector make_tuple_inputs = {make_tuple_primitve}; + int output_idx = 0; + MS_EXCEPTION_IF_NULL(new_kernel_graph); + for (auto &anf_node : list) { + if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { + new_kernel_graph->set_return(anf_node); + } + if (has_output_nodes.find(anf_node) == has_output_nodes.end()) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString(); + make_tuple_inputs.push_back(anf_node); + } + } + if (new_kernel_graph->get_return() == nullptr) { + new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs)); + } +} + +std::vector AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, + const std::vector &list) { + MS_EXCEPTION_IF_NULL(new_kernel_graph); + MS_LOG(INFO) << "Start contruct splited kernel graph:" << new_kernel_graph->graph_id(); + MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); + std::vector call_node_inputs; + std::vector new_graph_inputs; + // create new parameter from cnode + for (auto &anf_node : list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto cnode = anf_node->cast(); + for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { + auto input = cnode->inputs()[input_idx]; + MS_EXCEPTION_IF_NULL(input); + AnfNodePtr new_parameter = nullptr; + // check whether input has been put into args of call, if mulptiple use of one parameter or cnode, only set one + // parameter in graph inputs and one arg in call node + auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input); + if (call_input_it != call_node_inputs.end()) { + cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]); + continue; + } + // value node consider move to new graph + if (input->isa()) { + cnode->set_input(input_idx, input); + continue; + } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { + // if is cnode and not in current child graph + new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); + cnode->set_input(input_idx, new_parameter); + } else { + // if is a cnode and in current graph + continue; + } + new_graph_inputs.push_back(new_parameter); + call_node_inputs.push_back(input); + } + } + // set graph inputs of new graph + auto graph_inputs = new_kernel_graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + graph_inputs->clear(); + std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs)); + + MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); + ConstructSplitedGraphOutput(new_kernel_graph, list); + MS_LOG(INFO) << "End"; + return call_node_inputs; +} + +void AscendSession::BackendOptimization(const std::vector &all_graphs) { + MS_LOG(INFO) << "Start BackendCommonOptimization"; + for (auto &graph : all_graphs) { + opt::BackendCommonOptimization(graph); + } + MS_LOG(INFO) << "End."; +} + +void AscendSession::SplitGraphs(NotNull root_graph) { + std::set memo; + // if output of graph is nullptr,no need insert maketuple at the end of graph + if (root_graph->output() == nullptr) { + return; + } + // if root graph output is a call node ,the root graph is condition graph of 'if' sentence + auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first; + if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) { + SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo)); + for (auto &child_graph : root_graph->child_graph_order()) { + RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo)); + } + } else { + RecurseSplitGraph(root_graph, NOT_NULL(&memo)); + } + memo.clear(); + // add maketuple to the end of the last child graph to suit old process + auto output_graph = root_graph->child_graph_order().empty() ? root_graph : root_graph->child_graph_order().back(); + auto make_tuple = output_graph->NewCNode( + {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), output_graph->output()}); + output_graph->set_output(make_tuple); + // replace the real input if the real input is a call + RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo)); +} + +AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull graph, + const std::vector &child_graph_list) { + // if child graph list only has a call ,then return the exist call + if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { + return child_graph_list[0]; + } + // create new child graph + auto child_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(child_graph); + // create new value node to bind child graph + auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph)); + std::vector new_call_input = {NewValueNode(std::make_shared(prim::kPrimCall->name())), + graph_value_node}; + // set the graph id of all node of child graph + for (auto &child_graph_node : child_graph_list) { + AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); + } + auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list); + std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input)); + auto new_call = graph->NewCNode(new_call_input); + AnfAlgo::SetNodeAttr("graph_id", MakeValue(graph->graph_id()), new_call); + return new_call; +} + +void AscendSession::SplitGraph(NotNull graph, const std::set &cut_prims, + const NotNull *> memo) { + MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); + bool split_flag = false; + auto apply_list = GetCNodes(TopoSort(graph->get_return())); + // update the root graph child graph order + graph->UpdateChildGraphOrder(); + // get child list from current graph + std::vector> child_graph_lists = GetChildList(apply_list, cut_prims); + if (child_graph_lists.size() > 1) { + std::list depend_input = {}; + for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { + auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]); + MS_EXCEPTION_IF_NULL(call_node); + // if call node is the last call of true graph,no need create child graph after that + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); + depend_input.push_front(call_node); + if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) { + break; + } + } + depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name())))); + auto depend = graph->NewCNode(std::vector(depend_input.begin(), depend_input.end())); + auto new_return_primitive = + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimReturn->name()))); + graph->set_return(graph->NewCNode({new_return_primitive, depend})); + AnfNodePtr pre_call_node = nullptr; + AnfNodePtr cur_call_node = nullptr; + auto iter = depend_input.begin(); + for (++iter; iter != depend_input.end(); ++iter) { + pre_call_node = cur_call_node; + cur_call_node = *iter; + if (pre_call_node != nullptr && cur_call_node != nullptr) { + AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node)); + } + } + split_flag = true; + } + graph->UpdateChildGraphOrder(); + UpdateRealInput(graph, split_flag, memo); + MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; +} + +void AscendSession::RecurseSplitGraph(NotNull graph, const NotNull *> memo) { + memo->insert(graph.get()); + SplitGraph(graph, {prim::kPrimCall}, memo); + for (auto &child_graph : graph->child_graph_order()) { + if (memo->find(child_graph) == memo->end()) { + RecurseSplitGraph(NOT_NULL(child_graph), memo); + } + } +} + +void AscendSession::LinkChildGraphs(NotNull graph) { AscendControlParser::LinkGraph(graph); } + +void AscendSession::RootGraphExecutorValidate(NotNull graph) { + AscendControlParser::ExecutorValidate(graph); +} + +void AscendSession::RecurseCompileGraph(NotNull graph, const NotNull *> memo) { + memo->insert(graph.get()); + CompileChildGraph(graph); + for (auto child_graph : graph->child_graph_order()) { + if (memo->find(child_graph) != memo->end()) { + continue; + } + RecurseCompileGraph(NOT_NULL(child_graph), memo); + // copy ref map to final graph + auto child_ref_map = child_graph->GetRefMap(); + for (auto &item : child_ref_map) { + if (graph->IsInRefOutputMap(item.first)) { + MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; + } + graph->AddRefCorrespondPairs(item.first, item.second); + } + } +} + +void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNull *> memo) { + if (memo->find(graph.get()) != memo->end()) { + return; + } + memo->insert(graph.get()); + + graph->UpdateChildGraphOrder(); + for (auto &child_graph : graph->child_graph_order()) { + CreateMultiBranchOutput(NOT_NULL(child_graph), memo); + } + + std::map need_replace_list; + auto node_list = GetCNodes(TopoSort(graph->get_return())); + for (auto &node : node_list) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { + // create a parameter to store the output of multiple branch and set the parameter as the condition graph's output + // auto multi_output_param = graph->NewParameter(); + auto origin_inputs = graph->inputs(); + auto output_param = CreateNewParameterFromCNode(node, true, graph.get().get()); + MS_EXCEPTION_IF_NULL(graph->MutableInputs()); + graph->MutableInputs()->operator=(origin_inputs); + graph->AddChildGraphResult(output_param); + + std::vector depend_inputs = { + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name()))), output_param, node}; + auto depend = graph->NewCNode(depend_inputs); + need_replace_list.emplace(node, depend); + MS_LOG(INFO) << "Create parameter " << output_param->DebugString() << " for call node " << node->DebugString() + << ", depend node is " << depend->DebugString(); + // insert assign in order to transfer child graph output to parameter + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); + for (auto &child_graph : child_graphs) { + MS_EXCEPTION_IF_NULL(child_graph); + if (child_graph->get_output_null()) { + continue; + } + auto graph_output = child_graph->output(); + AscendControlParser::InsertMultipleAssignToGraph(NOT_NULL(child_graph), nullptr, NOT_NULL(graph_output), + NOT_NULL(output_param)); + } + } + } + // searching for nodes' input to replace call by depend(parameter, call) + for (auto &node : node_list) { + for (size_t i = 0; i < node->size(); ++i) { + auto input = node->input(i); + auto iter = need_replace_list.find(input); + if (iter != need_replace_list.end()) { + node->set_input(i, iter->second); + } + } + } +} + +void AscendSession::IrFusionPass(const NotNull graph, NotNull *> memo) { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + + opt::AscendBackendIRFusionOptimization(graph); + opt::AscendBackendFuseBasicOpt(graph, true); + opt::AscendBackendGraphKernelOpt(graph, true); + graph->SetExecOrderByDefault(); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs) { + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + std::string file_path = + save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(graph->graph_id()) + ".ir"; + DumpIR(file_path, graph.get()); + } + + for (auto &child_graph : graph->child_graph_order()) { + IrFusionPass(NOT_NULL(child_graph), memo); + } +} + +void AscendSession::SelectKernel(NotNull root_graph) { + MS_LOG(INFO) << "Start select kernel."; + size_t raise_precision_count = 0; + size_t reduce_precision_count = 0; + + std::set memo; + (void)RecurseSelectKernelInfo(root_graph, NOT_NULL(&memo), &raise_precision_count, &reduce_precision_count); + memo.clear(); + + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kGraphMode) { + if (raise_precision_count > 0) { + MS_LOG(WARNING) << "There has " << raise_precision_count + << " node/nodes used raise precision to selected the kernel!"; + } + if (reduce_precision_count > 0) { + MS_LOG(WARNING) << "There has " << raise_precision_count + << " node/nodes used reduce precision to selected the kernel!"; + } + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RecurseSelectKernelInfo(NotNull graph, + NotNull *> const memo, + size_t *const raise_precision_count, + size_t *const reduce_precision_count) const { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + MS_LOG(INFO) << "Start to select kernel info in graph: " << graph->graph_id(); + + for (const auto &cnode : graph->execution_order()) { + if (AnfAlgo::IsCondControlKernel(cnode)) { + std::vector child_graphs; + if (AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)) { + child_graphs = AnfAlgo::GetNodeAttr>(cnode, kAttrChildGraph); + } + for (auto &child_graph : child_graphs) { + RecurseSelectKernelInfo(NOT_NULL(child_graph), memo, raise_precision_count, reduce_precision_count); + } + } + + auto status = device::ascend::SelectKernelInfo(cnode); + if (status == device::ascend::kStatusRaisePrecision) { + (*raise_precision_count)++; + } else if (status == device::ascend::kStatusReducePrecision) { + (*reduce_precision_count)++; + } + MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); + } + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs) { + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + std::string file_path = + save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(graph->graph_id()) + ".ir"; + DumpIR(file_path, graph.get()); + } + MS_LOG(INFO) << "Finish selecting kernel info in graph: " << graph->graph_id(); +} + +void AscendSession::HardwareOptimize(NotNull graph, + NotNull *> const memo) const { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + + MS_LOG(INFO) << "Start to do HardwareOptimize in graph: " << graph->graph_id(); + // convert kernel Graph to model + predictmodel::StepConvertGraph(graph.get()); + + HardwareOptimize(graph.get()); + for (auto &child_graph : graph->child_graph_order()) { + HardwareOptimize(NOT_NULL(child_graph), memo); + } + MS_LOG(INFO) << "Finish doing HardwareOptimize in graph: " << graph->graph_id(); +} + +void AscendSession::AssignStaticMemory(NotNull graph, + NotNull *> const memo) const { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + + MS_LOG(INFO) << "Start to assign static memory for parameter in graph: " << graph->graph_id(); + // assign static memory for parameters + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->AssignStaticMemoryInput(graph.get().get()); + runtime_instance->AssignStaticMemoryValueNode(graph.get().get()); + for (auto &child_graph : graph->child_graph_order()) { + AssignStaticMemory(NOT_NULL(child_graph), memo); + } + MS_LOG(INFO) << "Finish assigning static memory for parameter in graph: " << graph->graph_id(); +} + +void AscendSession::UpdateRefOutputMap(NotNull graph, + NotNull *> const memo) const { + if (memo->find(graph) != memo->end()) { + return; + } + memo->insert(graph.get()); + + for (auto &child_graph : graph->child_graph_order()) { + UpdateRefOutputMap(NOT_NULL(child_graph), memo); + // copy ref map to final graph + auto child_ref_map = child_graph->GetRefMap(); + for (auto &item : child_ref_map) { + if (graph->IsInRefOutputMap(item.first)) { + MS_LOG(WARNING) << "The ref pair <" << item.first.first->DebugString() << ", " << item.first.second + << "> is already in " << graph->ToString(); + continue; + } + graph->AddRefCorrespondPairs(item.first, item.second); + } + } +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h new file mode 100755 index 0000000000..11cb1c92d2 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -0,0 +1,184 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H +#define MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/session_factory.h" +#include "backend/session/ascend_control_parser.h" + +namespace mindspore { +namespace session { +enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 }; + +class AscendSession : public SessionBasic { + public: + AscendSession() { final_graph_id_ = kInvalidGraphId; } + ~AscendSession() override = default; + void Init(uint32_t device_id) override { + SessionBasic::Init(device_id); + context_ = std::make_shared(kAscendDevice, device_id); + } + GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + GraphId CompileGraph(NotNull func_graph) override; + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + void BuildGraph(GraphId) override; + void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) override; + py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) override; + + // set parameters of final graph + GraphId SetFinalGraphInput(const std::vector &args) override; + // set output of final graph + void SetFinalGraphOutput(const BaseRef &output) override; + // insert switch and set the relative active ops + void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; + // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter + void SetChildGraphInput(GraphId g, const VectorRef &args) override; + // get graph id in child graphs by ME front anf node pointer + GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; + // get graph id of final graph + GraphId GetFinalRunGraph() const override { return final_graph_id_; } + // insert active to graph + void SetActive(GraphId, GraphId) override; + // compile child graph when session have multiple child graphs + void CompileChildGraph(const KernelGraphPtr &child_graph); + void RecurseGetSummaryNodes(KernelGraph *graph, std::map> *summary); + void GetSummaryNodes(KernelGraph *graph); + + private: + void InitRuntimeResource(); + void SelectKernel(const KernelGraph &kernel_graph) const; + void HardwareOptimize(const std::shared_ptr &kernel_graph) const; + void AdjustKernel(const std::shared_ptr &kernel_graph) const; + void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; + void AssignStream(NotNull kernel_graph) const; + void BuildKernel(const std::shared_ptr &kernel_graph) const; + void MemoryAlloc(KernelGraph *kernel_graph) const; + void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; + void RunOpMemoryClear(const KernelGraph *kernel_graph) const; + void GenerateTaskInfo(const std::shared_ptr &kernel_graph) const; + void LoadTask(const std::shared_ptr &kernel_graph) const; + void ExecTask(const std::shared_ptr &kernel_graph) const; + void Dump(const std::shared_ptr &kernel_graph) const; + void DumpAllGraphs(const std::vector &all_graphs); + void LoadTensor(const std::shared_ptr &kernel_graph) const; + // below functions are used for run op + void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const; + void RunOpExecTask(const std::shared_ptr &kernel_graph) const; + + size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index); + size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index); + size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index); + + void SetFinalGraphOutput(const AnfNodePtr &node); + void SetFinalGraphOutput(const ValuePtr &value); + void SetFinalGraphOutput(const VectorRef &vec_output); + + void SplitGraph(NotNull graph, const std::set &cut_prims, + const NotNull *> memo); + // split graphs with recurse from root graph + void SplitGraphs(NotNull root_graph); + void BackendOptimization(const std::vector &all_graphs); + void LinkChildGraphs(NotNull graph); + void RootGraphExecutorValidate(NotNull graph); + std::vector ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, + const std::vector &list); + void RecurseCompileGraph(NotNull graph, const NotNull *> memo); + void RecurseSplitGraph(NotNull graph, const NotNull *> memo); + AnfNodePtr BindNewCallToNewGraph(NotNull graph, const std::vector &child_graph_list); + + // merge execution order list of child graphs + void MergeGraphExecOrder(); + // insert assion op to sync data bettween different graphs + void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); + // insert mutiple assigns to graph + void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); + // insert active op to graph + void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream); + // get execute index of graph + size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph); + // handle condition graph from vm + void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id); + // insert depend to graph, used to attch control nodes to graph + void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node); + // insert depend to graph, used to attch control nodes to graph + void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node); + // set child graph parameter if front arg is a anf + void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx); + // set child graph parameter if front arg is a tensor + void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx); + // update the execution order of all child graphs + void UpdateGraphOrder(GraphId to_graph); + // handle switch when merge + void MergeSwitchCompile(); + // get graph order vector by graph id + std::vector &GetGraphOrder(GraphId final_graph_id); + // get graph order type vector by graph id + std::vector &GetGraphOrderType(GraphId final_graph_id); + // copy output of if and else + void CopyOutputOfIf(GraphId false_graph_id); + // check if graph cache exist + bool GraphCacheExist(const GraphInfo &graph_info) const; + // insert all assign to child graph + void InsertAllAssigns(); + // create fake output of final graph + AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); + // sync intial tensors' data to device + void SyncInitialTenosrToDevice(); + void SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph); + // create parameter to receive data from multiple branch output + void CreateMultiBranchOutput(NotNull graph, NotNull *> memo); + void SelectKernel(NotNull root_graph); + void RecurseSelectKernelInfo(NotNull graph, NotNull *> const memo, + size_t *const raise_precision_count, size_t *const reduce_precision_count) const; + void IrFusionPass(const NotNull graph, NotNull *> memo); + void HardwareOptimize(const NotNull graph, NotNull *> memo) const; + void AssignStaticMemory(const NotNull graph, NotNull *> memo) const; + void UpdateRefOutputMap(const NotNull graph, NotNull *> memo) const; + + // member variables + // key is final_graph_id,value is child graph execute order of final graph + std::unordered_map> graph_execute_orders_; + // key is final_graph_id,value is the graph types of child graphs + std::unordered_map> graph_order_types_; + // record condition graph of while + std::unordered_map while_condition_graphs_; + // record all conditions + std::unordered_map> switches_; + std::unordered_map condition_output_; + // share parameters + std::vector> assigns_; + // initial tensors, these tensor will sync data to device before run graph + std::map, tensor::TensorPtr> initial_tenosrs_; + // final_graph_id is used in every root graph has it's own session situation + GraphId final_graph_id_; +}; +MS_REG_SESSION(kAscendDevice, AscendSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc new file mode 100644 index 0000000000..ca1c78d206 --- /dev/null +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/session/cpu_session.h" +#include +#include "ir/tensor.h" +#include "ir/anf.h" +#include "backend/kernel_compiler/kernel.h" +#include "common/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_runtime.h" +#include "predict/predict.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "runtime/device/cpu/kernel_select_cpu.h" +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif + +namespace mindspore { +namespace session { +ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + if (!anf->isa()) { + MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; + } + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + TraceManager::DebugTrace(std::make_shared(anf->debug_info())); + ParameterPtr new_parameter = graph->NewParameter(anf->cast()); + TraceManager::EndTrace(); + graph_inputs->push_back(new_parameter); + valid_inputs->push_back(valid_input); + return new_parameter; +} + +GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + auto graph_id = graph_sum_; + auto graph = ConstructKernelGraph(lst, outputs); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Set kernel info"; + SetKernelInfo(graph.get()); + predictmodel::StepConvertGraph(graph); + MS_LOG(INFO) << "Build kernel"; + BuildKernel(graph.get()); + MS_LOG(INFO) << "Assign kernel address"; + runtime_.AssignKernelAddress(graph.get()); + return graph_id; +} + +void CPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { + auto &kernel_graph = graphs_[graph_id]; + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_LOG(INFO) << "Bind input output address"; + std::vector need_sync_outputs; + runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs, &need_sync_outputs); + MS_LOG(INFO) << "Run graph start"; + predictmodel::StepConvertWeight(inputs); + auto execution_order = kernel_graph->execution_order(); + Reorder(&execution_order); + + bool enable_summary = summary_callback_ != nullptr; + kernel_graph->set_execution_order(execution_order); + NamedSummaryOutputs summary_outputs; + if (enable_summary) { + GetSummaryNodes(kernel_graph.get()); + summary_outputs = kernel_graph->summary_nodes(); + runtime_.IncreaseSummaryRefCount(summary_outputs); + } +#ifdef ENABLE_DEBUGGER + // debugger pre-execution processing + if (debugger_) { + debugger_->PreExecute(kernel_graph); + } +#endif + bool ret = runtime_.Run(kernel_graph.get()); + if (!ret) { + MS_LOG(EXCEPTION) << "Run graph failed"; + } + for (auto output : need_sync_outputs) { + (void)output->data_sync(); + } + + if (enable_summary) { + Summary(kernel_graph.get()); + runtime_.DecreaseSummaryRefCount(summary_outputs); + } + +#ifdef ENABLE_DEBUGGER + // debugger post-execution processing + if (debugger_) { + debugger_->PostExecute(); + } +#endif + MS_LOG(INFO) << "Run graph end"; +} + +void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto &kernel_nodes = kernel_graph->execution_order(); + for (const auto &kernel_node : kernel_nodes) { + MS_EXCEPTION_IF_NULL(kernel_node); + device::cpu::SetKernelInfo(kernel_node); + } +} + +void CPUSession::BuildKernel(const KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto &kernel_nodes = kernel_graph->execution_order(); + for (const auto &kernel_node : kernel_nodes) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + MS_LOG(INFO) << "Cpu building operator[" << kernel_name << "]."; + std::shared_ptr cpu_kernel = + kernel::CPUKernelFactory::GetInstance().Create(kernel_name, kernel_node); + if (cpu_kernel == nullptr) { + MS_LOG(EXCEPTION) << "Operator[" << kernel_name << "] is not support."; + } + cpu_kernel->Init(kernel_node); + AnfAlgo::SetKernelMod(cpu_kernel, kernel_node.get()); + MS_LOG(INFO) << "Cpu build success operator[" << kernel_name << "]."; + } +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h new file mode 100644 index 0000000000..b0dbd1cc2b --- /dev/null +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_CPU_SESSION_H +#define MINDSPORE_CCSRC_SESSION_CPU_SESSION_H +#include +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "runtime/device/cpu/cpu_kernel_runtime.h" +#include "backend/session/session_factory.h" +namespace mindspore { +namespace session { +class CPUSession : public SessionBasic { + public: + CPUSession() = default; + ~CPUSession() override = default; + void Init(uint32_t device_id) override { + SessionBasic::Init(device_id); + context_ = std::make_shared(kCPUDevice, device_id); + } + GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + + protected: + ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; + + private: + void SetKernelInfo(const KernelGraph *kernel_graph); + void BuildKernel(const KernelGraph *kernel_graph); + device::cpu::CPUKernelRuntime runtime_; +}; +MS_REG_SESSION(kCPUDevice, CPUSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_CPU_SESSION_H diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc new file mode 100644 index 0000000000..14e30c1a44 --- /dev/null +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -0,0 +1,268 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/gpu_session.h" +#include "runtime/device/gpu/kernel_info_setter.h" +#include "runtime/device/gpu/gpu_kernel_build.h" +#include "runtime/device/gpu/gpu_kernel_runtime.h" +#include "runtime/device/gpu/gpu_stream_assign.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/pass/communication_op_fusion.h" +#include "backend/optimizer/pass/getitem_tuple.h" +#include "backend/optimizer/gpu/adam_weight_decay_fusion.h" +#include "backend/optimizer/gpu/adam_fusion.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "predict/predict.h" +#include "common/utils.h" +#include "common/trans.h" +#include "utils/context/ms_context.h" +#include "utils/base_ref_extends.h" + +namespace mindspore { +namespace session { +namespace gpu { +using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; + +void GPUSession::SelectKernel(const std::shared_ptr &kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + for (const auto &kernel_node : kernel_graph->execution_order()) { + MS_EXCEPTION_IF_NULL(kernel_node); + device::gpu::SetKernelInfo(kernel_node); + } +} + +void GPUSession::StartKernelRT() const { + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + if (!runtime_instance->Init()) { + MS_LOG(EXCEPTION) << "GPU start kernel runtime failed"; + } +} + +void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void GPUSession::HardwareOptimize(const std::shared_ptr &kernel_graph) { + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void GPUSession::AssignStream(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + device::gpu::AssignGpuStream(kernel_graph); +} + +void GPUSession::BuildKernel(const std::shared_ptr &kernel_graph) const { + device::gpu::GpuBuild(kernel_graph); +} + +void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->AssignMemory(kernel_graph); +} + +void GPUSession::RunOpAllocateMemory(const std::vector &input_tensors, + KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); +} + +void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->RunOpClearMemory(kernel_graph); +} + +void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const { + std::vector inputs(inputs_const); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto input_nodes = kernel_graph->inputs(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + + for (size_t i = 0; i < inputs.size(); ++i) { + auto tensor = inputs[i]; + MS_EXCEPTION_IF_NULL(tensor); + auto input_node = input_nodes[i]; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { + auto pk_node = input_node->cast(); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + auto tensor_address = std::dynamic_pointer_cast(tensor->device_address()); + bool need_sync = false; + if (ms_context->enable_pynative_infer()) { + if (tensor_address == nullptr || tensor_address != device_address) { + need_sync = true; + } + } else if (tensor->is_dirty() || tensor_address == nullptr) { + need_sync = true; + } else if (tensor_address != device_address) { + if (tensor_address->DeviceType() == device_address->DeviceType()) { + AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get()); + } else { + need_sync = true; + } + } + if (need_sync) { + tensor->set_device_address(device_address); + MS_EXCEPTION_IF_NULL(device_address); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; + } + } + } + tensor->set_dirty(false); + } +} + +void GPUSession::Execute(const std::shared_ptr &kernel_graph) const { + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + if (!runtime_instance->Run(kernel_graph.get())) { + MS_LOG(EXCEPTION) << "GPU execute graph failed!"; + } +} + +GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + // Construct graph, if successfully, graph_sum_ + 1 + auto graph_id = graph_sum_; + auto graph = ConstructKernelGraph(lst, outputs); + MS_EXCEPTION_IF_NULL(graph); + // Optimize + Optimize(graph); + // Select kernel build info + SelectKernel(graph); + // Convert kernel Graph to model + predictmodel::StepConvertGraph(graph); + // Start gpu kernel runtime + StartKernelRT(); + // HardwareOptimize + HardwareOptimize(graph); + // Assign CUDA streams + AssignStream(graph); + // Hide NoOp from execution graph + opt::HideNopNode(graph.get()); + // Build kernel if node is cnode + BuildKernel(graph); + // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph + auto execution_order = graph->execution_order(); + Reorder(&execution_order); + graph->set_execution_order(execution_order); + // Get summary nodes. + GetSummaryNodes(graph.get()); + // Remove NoOp from execution graph + opt::RemoveNopNode(graph.get()); + // Set graph manager. + MS_EXCEPTION_IF_NULL(context_); + FuncGraphManagerPtr manager = MakeManager({graph}); + context_->AddManager(manager); + if (manager) { + manager->AddFuncGraph(graph); + graph->set_manager(manager); + } + // Alloc memory, including static memory and dynamic memory + AllocateMemory(graph.get()); + return graph_id; +} + +void GPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { + auto &kernel_graph = graphs_[graph_id]; + // Load input data from user input + LoadInputData(kernel_graph, inputs); + MS_EXCEPTION_IF_NULL(kernel_graph); + // Convert inputs to model + predictmodel::StepConvertWeight(inputs); + { + py::gil_scoped_release gil_release; + // Run graph on GPU + Execute(kernel_graph); + } + // Get result from GPU + UpdateOutputs(kernel_graph, outputs, inputs); + // Summary + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_gpu_summary()) { + Summary(kernel_graph.get()); + } +} + +void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) { + // Check if the graph cache exists. + if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { + return; + } + // Prepare the graph + auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); + MS_EXCEPTION_IF_NULL(kernel_graph); + SelectKernel(kernel_graph); + StartKernelRT(); + // Hide NoOp from execution graph + opt::HideNopNode(kernel_graph.get()); + BuildKernel(kernel_graph); + run_op_graphs_[graph_info] = kernel_graph; +} + +py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) { + auto kernel_graph = run_op_graphs_[graph_info]; + MS_EXCEPTION_IF_NULL(kernel_graph); + // Remove NoOp from execution graph + opt::RemoveNopNode(kernel_graph.get()); + RunOpAllocateMemory(input_tensors, kernel_graph.get()); + // Execute the computation + LoadInputData(kernel_graph, input_tensors); + Execute(kernel_graph); + // Fetch outputs + VectorRef outputs; + UpdateOutputs(kernel_graph, &outputs, input_tensors); + // Trans output to tuple + auto output_tensors = TransformBaseRefListToTuple(outputs); + if (!utils::isa(output_tensors) || + !py::isinstance(utils::cast(output_tensors).object_)) { + MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; + } + py::object tuple_obj = utils::cast(output_tensors).object_; + py::tuple tuple_tensors = py::cast(tuple_obj); + RunOpClearMemory(kernel_graph.get()); + return tuple_tensors; +} +} // namespace gpu +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h new file mode 100644 index 0000000000..7e07dfbcbd --- /dev/null +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_GPU_SESSION_H +#define MINDSPORE_CCSRC_SESSION_GPU_SESSION_H + +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/session_factory.h" +using KernelGraph = mindspore::session::KernelGraph; + +namespace mindspore { +namespace session { +namespace gpu { +class GPUSession : public SessionBasic { + public: + GPUSession() = default; + ~GPUSession() override = default; + + void Init(uint32_t device_id) override { + SessionBasic::Init(device_id); + context_ = std::make_shared(kGPUDevice, device_id); + } + + GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) override; + py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) override; + + private: + void SelectKernel(const std::shared_ptr &kernel_graph) const; + + void StartKernelRT() const; + + void Optimize(const std::shared_ptr &kernel_graph); + + void HardwareOptimize(const std::shared_ptr &kernel_graph); + + void AssignStream(const std::shared_ptr &kernel_graph); + + void BuildKernel(const std::shared_ptr &kernel_graph) const; + + void AllocateMemory(KernelGraph *kernel_graph) const; + + void RunOpAllocateMemory(const std::vector &input_tensors, KernelGraph *kernel_graph) const; + + void RunOpClearMemory(KernelGraph *kernel_graph) const; + + void LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const override; + + void Execute(const std::shared_ptr &kernel_graph) const; +}; +using GPUSessionPtr = std::shared_ptr; +MS_REG_SESSION(kGPUDevice, GPUSession); +} // namespace gpu +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_GPU_SESSION_H diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc new file mode 100644 index 0000000000..df810fe6ef --- /dev/null +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -0,0 +1,1023 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/kernel_graph.h" +#include +#include +#include +#include +#include "frontend/operator/ops.h" +#include "ir/param_value.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace session { +namespace { +constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; +constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; +void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, + std::unordered_set *visited_nodes) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(que); + MS_EXCEPTION_IF_NULL(visited_nodes); + if (visited_nodes->find(node) == visited_nodes->end()) { + que->push(node); + (void)visited_nodes->insert(node); + MS_LOG(DEBUG) << "Push que:" << node->DebugString(); + } +} + +std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { + auto item_with_index = + AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple}); + AnfNodePtr node = item_with_index.first; + MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { + auto outputs = AnfAlgo::GetAllOutput(node); + std::set memo; + std::vector new_output; + for (auto &output : outputs) { + if (memo.find(output) != memo.end()) { + continue; + } + memo.insert(output); + new_output.push_back(output); + } + if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) { + node = new_output[0]; + } + } + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { + return {node}; + } + std::vector real_inputs; + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast()); + for (const auto &child_graph : child_graphs) { + if (child_graph->get_output_null()) { + continue; + } + auto real_input = child_graph->output(); + auto child_real_inputs = GetCallRealOutputs(real_input); + std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs)); + } + return real_inputs; +} + +AnfNodePtr MakeValueNode(const AnfNodePtr &node) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return nullptr; + } + + ValueNodePtr new_value_node = std::make_shared(value_node->value()); + new_value_node->set_abstract(value_node->abstract()); + // create kernel_info fo new value node + auto kernel_info = std::make_shared(); + new_value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + // set value node initial device data type = infer data type + std::vector types; + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { + types.push_back(kTypeUnknown); + } + kernel_build_info_builder->SetOutputsDeviceType(types); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + return new_value_node; +} + +bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { + if (left == right) { + return true; + } + if (left == nullptr || right == nullptr) { + return false; + } + if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) { + return false; + } + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) { + return AnfAlgo::GetNodeAttr(left, kAttrLabelIndex) == + AnfAlgo::GetNodeAttr(right, kAttrLabelIndex); + } + return false; +} +} // namespace +std::vector KernelGraph::outputs() const { + auto graph_output = output(); + if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { + auto make_tuple = output()->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + auto &inputs = make_tuple->inputs(); + return std::vector(inputs.begin() + 1, inputs.end()); + } + return std::vector(1, graph_output); +} + +void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, + std::unordered_set *visited_nodes) { + MS_EXCEPTION_IF_NULL(visit_queue); + MS_EXCEPTION_IF_NULL(visited_nodes); + auto it = node_output_edges_.find(node); + if (it == node_output_edges_.end()) { + // value node and parameter has no input,no need to print log + if (node->isa()) { + MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; + } + return; + } + + // visit all reduce node first, then other nodes + std::vector active_nodes; + for (const auto &output_edge : it->second) { + auto next_node = output_edge.first; + MS_EXCEPTION_IF_NULL(next_node); + if (node_input_num_.find(next_node) == node_input_num_.end()) { + MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; + } + MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() + << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; + if (node_input_num_[next_node] < output_edge.second) { + MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node] + << ",depend edge:" << output_edge.second; + } + node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; + // allreduce first + if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { + (void)visited_nodes->insert(next_node); + if (AnfAlgo::IsCommunicationOp(next_node)) { + MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); + visit_queue->push(next_node); + } else { + active_nodes.emplace_back(next_node); + } + } + } + + for (auto &node : active_nodes) { + MS_EXCEPTION_IF_NULL(node); + MS_LOG(DEBUG) << "Visit node:" << node->DebugString(); + visit_queue->push(node); + } +} + +void KernelGraph::SetExecOrderByDefault() { + std::queue seed_nodes; + UpdateNodeEdgeList(&seed_nodes); + execution_order_.clear(); + std::unordered_set visited_nodes; + std::queue zero_input_nodes; + AnfNodePtr last_communication_node = nullptr; + std::queue communication_descendants; + while (!seed_nodes.empty() || last_communication_node != nullptr) { + // seed nodes first, then visit last all reduce node descendant + if (seed_nodes.empty()) { + VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); + last_communication_node = nullptr; + } else { + zero_input_nodes.push(seed_nodes.front()); + seed_nodes.pop(); + } + // all reduce node descendant first, then common queue + while (!zero_input_nodes.empty() || !communication_descendants.empty()) { + AnfNodePtr node = nullptr; + bool is_communication_descendant = false; + if (communication_descendants.empty()) { + node = zero_input_nodes.front(); + zero_input_nodes.pop(); + } else { + node = communication_descendants.front(); + communication_descendants.pop(); + is_communication_descendant = true; + } + // add execute node + MS_EXCEPTION_IF_NULL(node); + if (node->isa() && AnfAlgo::IsRealKernel(node)) { + execution_order_.push_back(node->cast()); + } + // for all reduce node, visit last all reduce node descendant + if (AnfAlgo::IsCommunicationOp(node)) { + if (last_communication_node != nullptr) { + VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); + } + last_communication_node = node; + } else if (is_communication_descendant) { + VisitNodeDescendants(node, &communication_descendants, &visited_nodes); + } else { + VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); + } + } + } + CheckLoop(); + // resort start label / end goto + std::vector re_order; + if (start_label_ != nullptr) { + re_order.push_back(start_label_); + } + for (auto &node : execution_order_) { + if (node == start_label_ || node == end_goto_) { + continue; + } + + if (IsSameLabel(node, end_goto_)) { + end_goto_ = node; + MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id(); + continue; + } + + if (IsSameLabel(node, start_label_)) { + start_label_ = node; + MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id(); + continue; + } + + re_order.push_back(node); + } + if (end_goto_ != nullptr) { + re_order.push_back(end_goto_); + } + execution_order_ = re_order; +} + +void KernelGraph::CheckLoop() { + std::map none_zero_nodes; + if (node_input_edges_.size() != node_input_num_.size()) { + MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size() + << "not equal to node_input_num_ size:" << node_input_num_.size(); + } + for (auto &it : node_input_num_) { + MS_EXCEPTION_IF_NULL(it.first); + string str; + auto node_input_it = node_input_edges_.find(it.first); + if (node_input_it == node_input_edges_.end()) { + MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]"; + } + for (const auto &input_edge : node_input_edges_[it.first]) { + MS_EXCEPTION_IF_NULL(input_edge.first); + str = str.append(input_edge.first->DebugString()).append("|"); + } + if (it.second != 0) { + MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second; + none_zero_nodes[it.first] = it.second; + } + } + // if don't consider control depend and loop exit,a exception will be throw + if (!none_zero_nodes.empty()) { + MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); + } +} + +CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { + auto cnode = FuncGraph::NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode); + cnode->set_abstract(std::make_shared()); + CreateKernelInfoFromNewParameter(cnode); + + auto kernel_info = std::make_shared(); + std::vector feature_map_input_indexs; + // if the node only has the primitive(such as getNext) or the node's input has a feature map input + // then the node's output is a feature map output + for (size_t index = 1; index < inputs.size(); ++index) { + auto node = inputs[index]; + if (AnfAlgo::IsFeatureMapOutput(node)) { + feature_map_input_indexs.push_back(index); + } + } + if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { + AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); + } + if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { + kernel_info->SetFeatureMapFlag(true); + } + if (AnfAlgo::IsRealCNodeKernel(cnode)) { + AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); + AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); + } + cnode->set_kernel_info(kernel_info); + AnfAlgo::SetGraphId(graph_id_, cnode.get()); + return cnode; +} + +void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { + if (!AnfAlgo::IsGraphKernel(cnode)) { + return; + } + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector node_list; + std::vector input_list; + std::vector output_list; + kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); + for (auto &anf_node : node_list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto kernel_info = std::make_shared(); + anf_node->set_kernel_info(kernel_info); + auto anf_cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(anf_cnode); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_cnode); ++i) { + auto input_node = anf_cnode->input(i + 1); + MS_EXCEPTION_IF_NULL(input_node); + if (IsValueNode(input_node)) { + auto new_input_node = MakeValueNode(input_node); + if (new_input_node != nullptr) { + anf_cnode->set_input(i + 1, new_input_node); + } + } + } + } + for (auto &anf_node : input_list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto kernel_info = std::make_shared(); + anf_node->set_kernel_info(kernel_info); + } +} + +CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto new_cnode = std::make_shared(*cnode); + // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map + if (BackendNodeExistInFrontBackendMap(cnode)) { + FrontBackendlMapUpdate(cnode, new_cnode); + } + AnfAlgo::SetGraphId(graph_id_, cnode.get()); + if (IsInternalOutput(cnode)) { + ReplaceInternalOutput(cnode, new_cnode); + } + return new_cnode; +} + +ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { + ParameterPtr new_parameter = add_parameter(); + MS_EXCEPTION_IF_NULL(new_parameter); + // create kernel_info form new parameter + auto kernel_info = std::make_shared(); + size_t output_tensor_num = 1; + // if use default parameter = nullptr,it remarks create a new parameter from no parameter + if (parameter == nullptr) { + new_parameter->set_abstract(std::make_shared()); + kernel_info->SetFeatureMapFlag(true); + } else { + // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter + new_parameter->set_abstract(parameter->abstract()); + new_parameter->set_name(parameter->name()); + if (AnfAlgo::IsParameterWeight(parameter)) { + new_parameter->set_default_param(parameter->default_param()); + kernel_info->SetFeatureMapFlag(false); + } else { + kernel_info->SetFeatureMapFlag(true); + } + } + new_parameter->set_kernel_info(kernel_info); + // create kernel_build_info for new parameter + auto kernel_build_info_builder = std::make_shared(); + // create init data type, + std::vector init_data_type = {}; + + TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); + init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); + + // set the format of parameter to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat(std::vector(output_tensor_num, kOpFormat_DEFAULT)); + // set parameter initaial device data type + kernel_build_info_builder->SetOutputsDeviceType(init_data_type); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get()); + AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); + return new_parameter; +} + +std::vector KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + auto node_value = value_node->value(); + auto output_size = AnfAlgo::GetOutputTensorNum(value_node); + std::vector convert_inputs; + if (!node_value->isa()) { + MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); + } + auto value_tuple = node_value->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + if (value_tuple->size() != output_size) { + MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size() + << " is not mathced with the value node's output size" << output_size; + } + for (size_t index = 0; index < value_tuple->value().size(); ++index) { + auto new_value_node = std::make_shared(value_tuple->value()[index]); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)}, + {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get()); + AddValueNodeToGraph(new_value_node); + auto kernel_info = std::make_shared(); + new_value_node->set_kernel_info(kernel_info); + kernel_info->SetFeatureMapFlag(false); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT}); + // set value node initial device data type = infer data type + kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown}); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); + AddValueNodeToGraph(new_value_node); + convert_inputs.emplace_back(new_value_node); + } + if (!RemoveValueNodeFromGraph(value_node)) { + MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); + } + return convert_inputs; +} + +ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + auto new_value_node = MakeValueNode(value_node)->cast(); + AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); + return new_value_node; +} + +const std::vector &KernelGraph::inputs() const { + MS_EXCEPTION_IF_NULL(inputs_); + return *inputs_; +} + +void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) { + MS_EXCEPTION_IF_NULL(front_anf); + MS_EXCEPTION_IF_NULL(backend_anf); + if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) { + MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_"; + } + if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) { + MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_"; + } + front_backend_anf_map_[front_anf] = backend_anf; + backend_front_anf_map_[backend_anf] = front_anf; +} + +void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) { + MS_EXCEPTION_IF_NULL(old_backend_anf); + MS_EXCEPTION_IF_NULL(new_backend_anf); + if (old_backend_anf == new_backend_anf) { + MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString(); + return; + } + if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) { + MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map"; + return; + } + if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) { + MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString(); + } + front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf; + backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf]; + // delete old kernel + (void)backend_front_anf_map_.erase(old_backend_anf); +} +// get kernel by anf +AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) { + if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) { + return nullptr; + } + return front_backend_anf_map_[front_anf]; +} + +bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) { + return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end(); +} + +ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) { + if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) { + return nullptr; + } + return tensor_to_value_node_map_[tensor]; +} + +void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(tensor); + MS_EXCEPTION_IF_NULL(value_node); + tensor_to_value_node_map_[tensor] = value_node; +} + +void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(input); + MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num; + auto output_depend_edge = std::pair(node, depend_edge_num); + // add output depend edge of input + auto output_it = node_output_edges_.find(input); + if (output_it == node_output_edges_.end()) { + node_output_edges_[input] = std::vector>{output_depend_edge}; + } else { + output_it->second.push_back(output_depend_edge); + } + // add input depend edge of output + auto input_depend_edge = std::pair(input, depend_edge_num); + auto input_it = node_input_edges_.find(node); + if (input_it == node_input_edges_.end()) { + node_input_edges_[node] = std::vector>{input_depend_edge}; + } else { + input_it->second.push_back(input_depend_edge); + } + // add node input depend num + auto depend_it = node_input_num_.find(node); + if (depend_it == node_input_num_.end()) { + node_input_num_[node] = depend_edge_num; + } else { + depend_it->second += depend_edge_num; + } +} + +std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto it = node_output_edges_.find(node); + if (it == node_output_edges_.end()) { + MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]"; + } + std::vector output_nodes; + auto trans = [](const std::pair &pair) -> AnfNodePtr { return pair.first; }; + (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans); + return output_nodes; +} + +// Find control_depend real input nodes. +void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, std::set *visited) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(result); + MS_EXCEPTION_IF_NULL(visited); + if (visited->find(anf_node) != visited->end()) { + MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; + return; + } + visited->insert(anf_node); + if (AnfAlgo::IsRealKernel(anf_node)) { + result->emplace_back(anf_node); + return; + } + if (!anf_node->isa()) { + return; + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString(); + } + auto input0 = cnode->input(0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + GetAllFatherRealNode(cnode->input(i), result, visited); + } + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + if (cnode->inputs().size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited); + } else if (IsPrimitive(input0, prim::kPrimDepend)) { + if (cnode->inputs().size() != kDependInputSize) { + MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!"; + } + GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited); + GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited); + } +} + +// update the depend relations of control depend +void KernelGraph::UpdateControlDependRelations(const std::vector &depends) { + for (const auto &node : depends) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { + MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend"; + } + auto prior_node = cnode->input(kControlDependPriorIndex); + auto depend_node = cnode->input(kControlDependBehindIndex); + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(depend_node); + std::vector prior_nodes = {prior_node}; + std::vector depend_nodes = {depend_node}; + int depend_mode = 0; + if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { + depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); + } + MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() + << "], depend_mode :" << depend_mode << "."; + if (prior_node->isa() && depend_mode == 1) { + prior_nodes = GetOutputNodes(prior_node); + } + if (depend_node->isa()) { + depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector{}; + } + + std::vector real_prior_nodes; + std::set prior_visited; + for (const auto &tmp : prior_nodes) { + GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); + } + + std::vector real_depend_nodes; + std::set depend_visited; + for (const auto &tmp : depend_nodes) { + GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); + } + + for (auto &first_node : real_prior_nodes) { + if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { + continue; + } + for (auto &second_node : real_depend_nodes) { + if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { + continue; + } + MS_EXCEPTION_IF_NULL(first_node); + MS_EXCEPTION_IF_NULL(second_node); + MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() + << ",second node:" << second_node->DebugString(); + AddDependEdge(second_node, first_node, 1); + } + } + } +} + +bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *que, + std::unordered_set *visited_nodes) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(que); + MS_EXCEPTION_IF_NULL(visited_nodes); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { + return false; + } + // set the control depend visited but don't push it into the que + if (visited_nodes->find(node) != visited_nodes->end()) { + return true; + } + (void)visited_nodes->insert(cnode); + // add a 0 depend num to keep the link relations to prepare for finding zero output nodes + auto prior_node = cnode->input(kControlDependPriorIndex); + auto depend_node = cnode->input(kControlDependBehindIndex); + for (const auto &input : cnode->inputs()) { + AddDependEdge(node, input, 0); + } + PushNoVisitedNode(depend_node, que, visited_nodes); + PushNoVisitedNode(prior_node, que, visited_nodes); + return true; +} + +void KernelGraph::UpdateNodeEdgeList(std::queue *seed_nodes) { + MS_EXCEPTION_IF_NULL(seed_nodes); + node_output_edges_.clear(); + node_input_num_.clear(); + node_input_edges_.clear(); + std::vector control_depends; + std::unordered_set visited_nodes; + std::queue que; + que.push(get_return()); + while (!que.empty()) { + auto node = que.front(); + que.pop(); + MS_EXCEPTION_IF_NULL(node); + if (node->isa() || node->isa()) { + seed_nodes->push(node); + continue; + } + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // handle data links + for (const auto &input : cnode->inputs()) { + size_t depend_edge_num = 1; + // handle control depend,all inputs of control depend has no depend edge + if (HandleControlDependNode(input, &que, &visited_nodes)) { + control_depends.push_back(input); + depend_edge_num = 0; + } + PushNoVisitedNode(input, &que, &visited_nodes); + AddDependEdge(node, input, depend_edge_num); + } + } + UpdateControlDependRelations(control_depends); +} + +void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); } + +bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; } + +AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const { + if (!IsInRefOutputMap(out_pair)) { + MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap"; + } + return ref_out_in_map_.at(out_pair); +} + +void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) { + if (IsInRefOutputMap(final_pair)) { + MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap"; + } + (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair)); +} + +bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) { + if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) { + (void)graph_value_nodes_.erase(value_node); + return true; + } + return false; +} + +void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull new_anf_node) { + MS_EXCEPTION_IF_NULL(inputs_); + { + std::queue seed_nodes; + UpdateNodeEdgeList(&seed_nodes); + } + auto it = node_output_edges_.find(old_anf_node); + if (it != node_output_edges_.end()) { + const auto &outputs = it->second; + for (auto &output_node : outputs) { + MS_EXCEPTION_IF_NULL(output_node.first); + auto output_cnode = output_node.first->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + auto &output_node_inputs = output_cnode->inputs(); + // don't replace node if it is a control edge => output_node.second == 0 + if (output_node.second == 0) { + continue; + } + for (size_t i = 1; i < output_node_inputs.size(); i++) { + if (output_node_inputs[i] == old_anf_node.get()) { + output_cnode->set_input(i, new_anf_node); + } + } + // update graph inputs + for (size_t i = 0; i < inputs_->size(); i++) { + if ((*inputs_)[i] == old_anf_node.get()) { + MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() + << ",new graph input:" << new_anf_node->DebugString(); + (*inputs_)[i] = new_anf_node.get(); + break; + } + } + } + // update front to backend map + FrontBackendlMapUpdate(old_anf_node, new_anf_node); + } + { + std::queue seed_nodes; + UpdateNodeEdgeList(&seed_nodes); + } + // update graph inputs in child graph + auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(), + [&old_anf_node](const std::pair> &n) -> bool { + return n.first == old_anf_node.get(); + }); + if (it_real_inputs != real_inputs_.end()) { + // erase old parameter in map + auto old_args = it_real_inputs->second; + real_inputs_.erase(it_real_inputs); + // insert new parameter to map + auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(), + [&new_anf_node](const std::pair> &n) -> bool { + return n.first == new_anf_node.get(); + }); + if (iter != real_inputs_.end()) { + MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited."; + iter->second = old_args; + } else { + real_inputs_.emplace_back(new_anf_node, old_args); + } + } +} + +void KernelGraph::UpdateExecuteKernelStreamLabel() { + for (auto &kernel : execution_order_) { + AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get()); + } +} + +std::vector> KernelGraph::GetLeafGraphOrder() { + std::vector> leaf_graph_order; + if (IsLeafGraph()) { + leaf_graph_order.push_back(shared_from_this()->cast()); + } else { + for (const auto &child_graph : child_graph_order_) { + MS_EXCEPTION_IF_NULL(child_graph); + auto child_leaf_graph_order = child_graph->GetLeafGraphOrder(); + std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order)); + } + } + return leaf_graph_order; +} + +bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); } + +std::vector KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const { + std::vector result; + for (const auto &anf : execution_order_) { + if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { + result.push_back(anf->cast()); + } + } + return result; +} + +void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) { + MS_EXCEPTION_IF_NULL(parameter); + MS_EXCEPTION_IF_NULL(arg); + MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString(); + MS_EXCEPTION_IF_NULL(parameter); + MS_EXCEPTION_IF_NULL(arg); + auto iter = std::find_if( + real_inputs_.begin(), real_inputs_.end(), + [¶meter](const std::pair> &n) -> bool { return n.first == parameter; }); + if (iter != real_inputs_.end()) { + auto &args = iter->second; + args.push_back(arg); + } else { + real_inputs_.emplace_back(parameter, std::vector(1, arg)); + } +} + +void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph) { + unreuse_args_[arg] = from_graph; +} + +void KernelGraph::UpdateCallRealInput() { + MS_LOG(INFO) << "Update graph id: " << graph_id_; + std::vector>> real_inputs_map; + for (auto &it : real_inputs_) { + auto parameter = it.first; + MS_EXCEPTION_IF_NULL(parameter); + auto real_inputs = it.second; + std::vector new_real_inputs; + for (auto &real_input : real_inputs) { + // if real input is a call node ,find the child graph output act as the new real input + auto tmp_real_input = GetCallRealOutputs(real_input); + std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs)); + // replace the call in unreuse_args_ + auto unreuse_arg_it = unreuse_args_.find(real_input); + if (unreuse_arg_it != unreuse_args_.end()) { + auto old_graph = unreuse_arg_it->second; + for (auto new_real_input : new_real_inputs) { + // if call reference graph output is parameter, it will be allowed to reuse + if (!new_real_input->isa()) { + unreuse_args_[new_real_input] = old_graph; + } + } + } + } + real_inputs_map.emplace_back(parameter, new_real_inputs); + } + real_inputs_ = real_inputs_map; +} + +void KernelGraph::PrintGraphExecuteOrder() const { + MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; + for (size_t i = 0; i < execution_order_.size(); i++) { + CNodePtr cur_cnode_ptr = execution_order_[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + std::string event_str; + std::string label_str; + if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) { + event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId)) + "]"; + } + + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) { + label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrLabelIndex)) + "]"; + } + + if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) { + auto label_list = AnfAlgo::GetNodeAttr>(cur_cnode_ptr, kAttrLabelSwitchList); + label_str = ", label_id["; + for (size_t j = 0; j < label_list.size(); ++j) { + label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]"); + } + } + + MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id[" + << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" + << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]" + << event_str << label_str; + } +} + +void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) { + if (front_node == nullptr || node == nullptr) { + MS_LOG(INFO) << "Front node or node is nullptr"; + return; + } + MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString(); + front_to_internal_outputs_map_[front_node] = node; + internal_outputs_to_front_map_[node] = front_node; +} + +void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) { + if (new_node == nullptr || node == nullptr) { + MS_LOG(INFO) << "New node or node is nullptr"; + return; + } + if (node == new_node) { + MS_LOG(INFO) << "New node and node is the same"; + return; + } + auto iter = internal_outputs_to_front_map_.find(node); + if (iter == internal_outputs_to_front_map_.end()) { + MS_LOG(INFO) << "Node is not internal output"; + return; + } + MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString(); + internal_outputs_to_front_map_[new_node] = iter->second; + front_to_internal_outputs_map_[iter->second] = new_node; + internal_outputs_to_front_map_.erase(iter); +} + +AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const { + auto iter = front_to_internal_outputs_map_.find(front_node); + if (iter != front_to_internal_outputs_map_.end()) { + return iter->second; + } + return nullptr; +} + +bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const { + if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) { + return true; + } + return false; +} + +AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const { + auto iter = internal_outputs_to_front_map_.find(node); + if (iter != internal_outputs_to_front_map_.end()) { + return iter->second; + } + return nullptr; +} + +void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) { + if (node == nullptr) { + return; + } + (void)final_output_kernels_.insert(node); +} + +bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { + if (node == nullptr) { + return false; + } + if (final_output_kernels_.find(node) != final_output_kernels_.end()) { + return true; + } + return false; +} + +void KernelGraph::UpdateChildGraphOrder() { + MS_LOG(INFO) << "Update " << ToString() << " child graph order."; + SetExecOrderByDefault(); + auto call_nodes = FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); + std::vector child_graph_order; + for (auto &call_node : call_nodes) { + MS_EXCEPTION_IF_NULL(call_node); + auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); + for (const auto &child_graph : call_child_graphs) { + MS_EXCEPTION_IF_NULL(child_graph); + if (child_graph != parent_graph_) { + auto shared_this = std::dynamic_pointer_cast(shared_from_this()); + MS_EXCEPTION_IF_NULL(shared_this); + child_graph->set_parent_graph(shared_this); + } + child_graph_order.push_back(child_graph); + } + } + for (size_t i = 0; i < child_graph_order.size(); ++i) { + MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; + } + child_graph_order_ = child_graph_order; +} + +std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } + +KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); } +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h new file mode 100644 index 0000000000..48df351120 --- /dev/null +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -0,0 +1,233 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H +#define MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "utils/graph_utils.h" +#include "utils/contract.h" +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace session { +using AnfWithOutIndex = std::pair; +class KernelGraph : public FuncGraph { + public: + KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false), current_epoch_(0) { + inputs_ = std::make_shared>(); + execution_order_ = {}; + executable_ = true; + summary_node_exist_ = false; + stream_distinction_label_ = kInvalidDistincLabel; + } + ~KernelGraph() override; + + MS_DECLARE_PARENT(KernelGraph, FuncGraph); + + const std::vector &inputs() const; + std::vector *MutableInputs() const { return inputs_.get(); } + std::vector outputs() const; + CNodePtr NewCNode(const std::vector &inputs) override; + void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); + CNodePtr NewCNode(const CNodePtr &cnode); + ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); + ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); + std::vector SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); + void set_execution_order(const std::vector &order) { execution_order_ = order; } + const std::vector &execution_order() const { return execution_order_; } + void SetExecOrderByDefault(); + uint32_t graph_id() const { return graph_id_; } + void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } + + // and a new front to backend anf relation to maop + void FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf); + // replace old backend anf with new backend anf + void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf); + // get backend anf by front anf + AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf); + // check backend node whether exist in map + bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf); + // get value node by tensor + ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor); + // add value node tensor relation map + void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node); + // get all value nodes of graph + const std::unordered_set graph_value_nodes() const { return graph_value_nodes_; } + // add value node to graph + void AddValueNodeToGraph(const ValueNodePtr &value_node); + // ref output is in map + bool IsInRefOutputMap(const AnfWithOutIndex &pair) const; + // get ref correspond pairs + AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const; + // add ref correspond pairs + void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair); + // get map + std::map GetRefMap() const { return ref_out_in_map_; } + // checkout whether loop exist in graph + void CheckLoop(); + // check whether graph is executable + bool executable() const { return executable_; } + // set executable of graph + void set_executable(bool executable) { executable_ = executable; } + // set summary_node of graph + void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; } + // check whether exist summary node in graph + bool summary_node_exist() const { return summary_node_exist_; } + // set invalid inputs for control sink + std::vector *MutableValidInputs() { return &valid_inputs_; } + std::vector valid_inputs() const { return valid_inputs_; } + // replace node in graph + void ReplaceNode(NotNull old_anf_node, NotNull new_anf_node); + // set stream label of graph + void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; } + // get stream label of graph + uint32_t stream_distinction_label() { return stream_distinction_label_; } + // refresh execute kernel stream label + void UpdateExecuteKernelStreamLabel(); + // calculate the leaf graph order of root graph + std::vector> GetLeafGraphOrder(); + // the child graph of current graph + const std::vector> &child_graph_order() const { return child_graph_order_; } + void set_child_graph_order(const std::vector> &order) { child_graph_order_ = order; } + // checkout whether current graph is leaf graph + bool IsLeafGraph() const; + + // set input_tensors pointer of control parameter + void set_input_ctrl_tensors(const std::shared_ptr> &input_tensors_ptr) { + input_ctrl_tensors_ = input_tensors_ptr; + } + // get input_tensors pointer of control parameter + std::shared_ptr> input_ctrl_tensors() const { return input_ctrl_tensors_; } + // get parent kernel graph + std::shared_ptr parent_graph() const { return parent_graph_; } + // set parent kernel graph + void set_parent_graph(const std::shared_ptr &parent_graph) { parent_graph_ = parent_graph; } + // find anf node in graph + std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; + // get real inputs + const std::vector>> &real_inputs() const { return real_inputs_; } + void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); + // mark unreused args + void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph); + const std::map> &unreuse_args() const { return unreuse_args_; } + // used to dump ir + std::string ToString() const override; + // update the real input if the node is a call + void UpdateCallRealInput(); + + void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } + CNodePtr get_start_label() { return start_label_; } + void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } + CNodePtr get_end_goto() { return end_goto_; } + bool get_output_null() { return null_output_; } + void set_output_null(bool is_output_null) { null_output_ = is_output_null; } + void PrintGraphExecuteOrder() const; + const std::map> &summary_nodes() const { return summary_nodes_; } + void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; } + void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node); + void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node); + AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; + bool IsInternalOutput(const AnfNodePtr &node) const; + AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const; + void AddFinalOutputKernel(const AnfNodePtr &node); + bool IsFinalOutputKernel(const AnfNodePtr &node) const; + uint32_t current_epoch() const { return current_epoch_; } + void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; } + void UpdateChildGraphOrder(); + const std::vector &child_graph_result() const { return child_graph_result_; } + void AddChildGraphResult(const AnfNodePtr ¶meter) { child_graph_result_.push_back(parameter); } + void set_child_graph_result(const std::vector &child_graph_result) { + child_graph_result_ = child_graph_result; + } + + private: + // remove value node form graph + bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); + void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, + std::unordered_set *visited_nodes); + // update node edge list + void UpdateNodeEdgeList(std::queue *seed_nodes); + // add node depend edge by data edge or control depend + void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); + // handle control depend + std::vector GetOutputNodes(const AnfNodePtr &node); + bool HandleControlDependNode(const AnfNodePtr &node, std::queue *que, + std::unordered_set *visited_nodes); + void UpdateControlDependRelations(const std::vector &depends); + + std::shared_ptr> inputs_; + std::vector child_graph_result_; + std::vector execution_order_; + uint32_t graph_id_; + uint32_t stream_distinction_label_; + + // record map bettween front anf and backend anf,use two map implement bidirectional map + std::unordered_map front_backend_anf_map_; + std::unordered_map backend_front_anf_map_; + // there may be a tensor from ME backend ,a value ndoe will be create according the tensor,map record + std::unordered_map tensor_to_value_node_map_; + // include all value nodes + std::unordered_set graph_value_nodes_; + std::unordered_map node_input_num_; + std::unordered_map>> node_input_edges_; + // record map between ref final output anf with index and ref origin input with index + std::map ref_out_in_map_; + std::unordered_map>> node_output_edges_; + std::map> summary_nodes_; + // graph needn't execute + bool executable_; + // exist summary node in graph + bool summary_node_exist_; + // valid inputs + std::vector valid_inputs_; + + // new members for control sink process + // all child grahs refers to partial node + std::map> node_to_child_graphs_; + // child graph execute order in root graph + std::vector> child_graph_order_; + + // input_tensors of control parameter + std::shared_ptr> input_ctrl_tensors_; + + // parameter graph + std::shared_ptr parent_graph_; + // record real parameters,inputs_ is the formal parameters + std::vector>> real_inputs_; + std::map> unreuse_args_; + + CNodePtr start_label_; + CNodePtr end_goto_; + bool null_output_; + std::unordered_map front_to_internal_outputs_map_; + std::unordered_map internal_outputs_to_front_map_; + std::set final_output_kernels_; + uint32_t current_epoch_; +}; +} // namespace session +using KernelGraphPtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H diff --git a/mindspore/ccsrc/backend/session/session.cc b/mindspore/ccsrc/backend/session/session.cc new file mode 100644 index 0000000000..95484a1113 --- /dev/null +++ b/mindspore/ccsrc/backend/session/session.cc @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "include/inference.h" +#include "backend/session/session.h" +#include "utils/load_onnx/anf_converter.h" +#include "backend/session/session_basic.h" +#include "backend/session/session_factory.h" +#include "utils/base_ref_utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#ifdef ENABLE_D +#include "utils/context/ms_context.h" +#include "backend/session/ascend_session.h" +#else +#include "backend/session/cpu_session.h" +#endif + +namespace py = pybind11; +namespace mindspore::inference { +std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device) { + try { + inference::Session::RegAllOp(); + auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); + return anf_graph; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference LoadModel failed"; + return nullptr; + } +} + +void ExitInference() { + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return; + } + if (!ms_context->CloseTsd()) { + MS_LOG(ERROR) << "Inference CloseTsd failed!"; + return; + } +} + +std::shared_ptr MSSession::CreateSession(const std::string &device, uint32_t device_id) { + try { + auto session = std::make_shared(); + auto ret = session->Init(device, device_id); + if (ret != 0) { + return nullptr; + } + return session; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CreatSession failed"; + return nullptr; + } +} + +void Session::RegAllOp() { + static std::mutex init_mutex; + static bool Initialized = false; + + std::lock_guard lock(init_mutex); + if (Initialized) { + return; + } + Initialized = true; + MsContext::GetInstance()->set_execution_mode(kGraphMode); + Py_Initialize(); + auto c_expression = PyImport_ImportModule("mindspore._c_expression"); + if (c_expression == nullptr) { + MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module."; + return; + } + PyObject *c_expression_dict = PyModule_GetDict(c_expression); + + PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); + if (op_info_loader_class == nullptr) { + MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression."; + return; + } + PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); + if (op_info_loader == nullptr) { + MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance."; + return; + } + PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); + if (op_info_loader_ins == nullptr) { + MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance."; + return; + } + auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); + if (all_ops_info_vector_addr_ul == nullptr) { + MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr."; + return; + } + auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); + auto all_ops_info = static_cast *>(all_ops_info_vector_addr); + for (auto op_info : *all_ops_info) { + kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); + } + all_ops_info->clear(); + delete all_ops_info; + Py_DECREF(op_info_loader); + Py_DECREF(op_info_loader_class); + Py_DECREF(c_expression_dict); + Py_DECREF(c_expression); + return; +} + +uint32_t Session::CompileGraph(std::shared_ptr funcGraphPtr) { + MS_ASSERT(session_impl_ != nullptr); + try { + auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); + py::gil_scoped_release gil_release; + return graph_id; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CompileGraph failed"; + return static_cast(-1); + } +} + +MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector> &inputs) { + try { + std::vector inTensors; + inTensors.resize(inputs.size()); + bool has_error = false; + std::transform(inputs.begin(), inputs.end(), inTensors.begin(), + [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { + if (tensor_ptr == nullptr) { + MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; + has_error = true; + return nullptr; + } + auto tensor = static_cast(tensor_ptr.get()); + if (tensor == nullptr) { + MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; + has_error = true; + return nullptr; + } + return tensor->tensor(); + }); + if (has_error) { + MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; + std::vector> multiTensor; + return multiTensor; + } + VectorRef outputs; + session_impl_->RunGraph(graph_id, inTensors, &outputs); + + return TransformVectorRefToMultiTensor(outputs); + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference Rungraph failed"; + return MultiTensor(); + } +} +namespace { +string AjustTargetName(const std::string &device) { + if (device == kAscendDevice) { + return std::string(kAscendDevice) + "Inference"; + } else { + MS_LOG(ERROR) << "Only support device Ascend right now"; + return ""; + } +} +} // namespace +int Session::Init(const std::string &device, uint32_t device_id) { + RegAllOp(); + auto ms_context = MsContext::GetInstance(); + ms_context->set_execution_mode(kGraphMode); + ms_context->set_device_id(device_id); + auto ajust_device = AjustTargetName(device); + if (ajust_device == "") { + return -1; + } + ms_context->set_device_target(device); + session_impl_ = session::SessionFactory::Get().Create(ajust_device); + if (session_impl_ == nullptr) { + MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; + return -1; + } + session_impl_->Init(device_id); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return -1; + } + if (!ms_context->OpenTsd()) { + MS_LOG(ERROR) << "Session init OpenTsd failed!"; + return -1; + } + return 0; +} + +Session::Session() = default; +} // namespace mindspore::inference diff --git a/mindspore/ccsrc/backend/session/session.h b/mindspore/ccsrc/backend/session/session.h new file mode 100644 index 0000000000..6ea9cfaa47 --- /dev/null +++ b/mindspore/ccsrc/backend/session/session.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H +#define MINDSPORE_CCSRC_SESSION_SESSION_H + +#include +#include +#include +#include +#include +#include + +#include "backend/session/session_basic.h" +#include "ir/anf.h" +#include "include/inference.h" + +namespace mindspore { +namespace inference { +class Session : public MSSession { + public: + Session(); + + uint32_t CompileGraph(std::shared_ptr funcGraphPtr) override; + + MultiTensor RunGraph(uint32_t graph_id, const std::vector> &inputs) override; + + int Init(const std::string &device, uint32_t device_id); + + static void RegAllOp(); + + private: + std::shared_ptr session_impl_ = nullptr; + std::vector graph_id_; +}; +} // namespace inference +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc new file mode 100644 index 0000000000..9755dfc7d0 --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -0,0 +1,1101 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/session_basic.h" +#include +#include +#include +#include +#include "pipeline/jit/parse/data_converter.h" +#include "ir/manager.h" +#include "ir/param_value.h" +#include "backend/kernel_compiler/common_utils.h" +#include "frontend/operator/ops.h" +#include "common/trans.h" +#include "utils/context/ms_context.h" +#include "utils/config_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/optimizer/common/common_backend_optimization.h" +#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/helper.h" +#include "common/utils.h" +#include "ir/dtype.h" +#include "ir/anf.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace session { +static std::shared_ptr> python_paras; +void ClearPythonParasMap() { python_paras = nullptr; } +namespace { +const int kSummaryGetItem = 2; + +ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) { + if (node == nullptr) { + return nullptr; + } + auto parameter = node->cast(); + if (parameter == nullptr || !parameter->has_default()) { + return nullptr; + } + return parameter->default_param(); +} + +BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(node); + MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; + // if node is a value node, no need sync addr from device to host + if (!AnfAlgo::OutputAddrExist(node, output_index)) { + if (node->isa()) { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + return value_node->value(); + } + if (node->isa()) { + for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) { + if (input_idx >= input_tensors.size()) { + MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size(); + } + if (graph.inputs()[input_idx] == node) { + return input_tensors[input_idx]; + } + } + MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; + } + } + // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) + auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); + MS_EXCEPTION_IF_NULL(address); + auto shape = AnfAlgo::GetOutputInferShape(node, output_index); + TypeId type_id = kNumberTypeFloat32; + type_id = AnfAlgo::GetOutputInferDataType(node, output_index); + std::vector temp_shape; + if (graph.IsInternalOutput(node)) { + temp_shape.emplace_back(1); + tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + tensor->set_device_address(address); + tensor->set_dirty(false); + return tensor; + } + (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); + tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + // if in paynative mode,data only copyed to host when user want to print data + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { + tensor->set_device_address(address); + tensor->set_dirty(false); + } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), + LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { + MS_LOG(INFO) << "Output sync device to host error!!!"; + tensor->set_dirty(false); + } + return tensor; +} + +BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(anf); + MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]"; + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); + MS_EXCEPTION_IF_NULL(item_with_index.first); + MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString(); + // special handle for maketuple + if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { + auto cnode = item_with_index.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + VectorRef ret; + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors); + ret.push_back(out); + } + return ret; + } + // if is graph return nothing ,the function should return a null anylist + size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first); + if (size == 0) { + return VectorRef(); + } + return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors); +} + +ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + auto value_node = anf->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + return nullptr; + } + auto new_value_node = graph->NewValueNode(value_node); + graph->FrontBackendlMapAdd(anf, new_value_node); + graph->AddValueNodeToGraph(new_value_node); + return new_value_node; +} + +size_t LoadCtrlInputTensor(const std::shared_ptr &graph, std::vector *inputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Load kInputCtrlTensors"; + auto inputs_params = graph->input_ctrl_tensors(); + if (inputs_params == nullptr) { + return 0; + } + if (inputs_params->size() < 2) { + MS_LOG(EXCEPTION) << "Illegal inputs_params size"; + } + auto tensor = (*inputs_params)[0]; + MS_EXCEPTION_IF_NULL(tensor); + auto *val = static_cast(tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + tensor->set_dirty(true); + // set loop_count to zero + MS_EXCEPTION_IF_NULL(inputs); + inputs->push_back(tensor); + + auto epoch_tensor = (*inputs_params)[1]; + MS_EXCEPTION_IF_NULL(epoch_tensor); + auto *epoch_val = static_cast(epoch_tensor->data_c()); + MS_EXCEPTION_IF_NULL(epoch_val); + *epoch_val = graph->current_epoch(); + epoch_tensor->set_dirty(true); + inputs->push_back(epoch_tensor); + MS_LOG(INFO) << "Load epoch_val:" << *epoch_val; + + graph->set_current_epoch(graph->current_epoch() + 1); + + return inputs_params->size(); +} + +ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr &graph, const tensor::TensorPtr &input_tensor) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input_tensor); + auto value_node = std::make_shared(input_tensor); + MS_EXCEPTION_IF_NULL(value_node); + // construct abstract of value node + auto type_of_tensor = input_tensor->Dtype(); + auto shape_of_tensor = input_tensor->shape(); + auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); + value_node->set_abstract(abstract); + // add value node to graph + auto input_value_node = graph->NewValueNode(value_node); + graph->AddValueNodeToGraph(input_value_node); + return input_value_node; +} + +ParameterPtr ConstructRunOpParameter(const std::shared_ptr &graph, const tensor::TensorPtr &input_tensor, + int tensor_mask) { + MS_EXCEPTION_IF_NULL(graph); + auto param = graph->NewParameter(); + MS_EXCEPTION_IF_NULL(param); + if (tensor_mask == kParameterWeightTensorMask) { + auto param_value_new = std::make_shared(); + param->set_default_param(param_value_new); + } + // set the kernel info of parameter + auto kernel_build_info_builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(input_tensor); + auto device_address = std::dynamic_pointer_cast(input_tensor->device_address()); + if (device_address == nullptr) { + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type(); + kernel_build_info_builder->SetOutputsDeviceType(std::vector{param_init_data_type}); + } else { + kernel_build_info_builder->SetOutputsFormat(std::vector{device_address->format()}); + kernel_build_info_builder->SetOutputsDeviceType(std::vector{device_address->type_id()}); + } + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); + // construct abstract of parameter + auto type_of_tensor = input_tensor->Dtype(); + auto shape_of_tensor = input_tensor->shape(); + auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); + param->set_abstract(abstract); + return param; +} + +void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { + MS_LOG(INFO) << "Graph outputs:"; + const size_t max_deep = 10; + if (recurse_level > max_deep) { + MS_LOG(INFO) << "Recurse too deep"; + return; + } + std::string tab_str; + for (size_t i = 0; i < recurse_level; i++) { + tab_str = tab_str.append(" "); + } + if (any.is()) { + (void)tab_str.append("{"); + MS_LOG(INFO) << tab_str; + auto any_list = any.cast(); + for (auto &it : any_list) { + DumpGraphOutput(it, recurse_level + 1); + } + (void)tab_str.append("}"); + MS_LOG(INFO) << tab_str; + } + (void)tab_str.append(any.ToString()); + MS_LOG(INFO) << tab_str; +} + +bool ExistSummaryNode(const KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto ret = graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + auto all_nodes = DeepLinkedGraphSearch(ret); + for (auto &n : all_nodes) { + if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || + IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { + return true; + } + } + return false; +} +} // namespace + +GraphId SessionBasic::graph_sum_ = 0; + +KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) { + auto it = graphs_.find(graph_id); + if (it == graphs_.end()) { + MS_LOG(WARNING) << "Can't find graph " << graph_id; + return nullptr; + } + return it->second; +} + +void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) { + auto graph_id = GetGraphIdByNode(out_node); + if (graph_id == kInvalidGraphId) { + return; + } + auto node_graph = GetGraph(graph_id); + if (node_graph == nullptr) { + return; + } + MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString(); + auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node); + if (ref_node == nullptr) { + MS_LOG(INFO) << "No corresponding internal output for output node"; + return; + } + auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0); + auto ref_real_node = real_kernel.first; + auto ref_real_node_index = real_kernel.second; + if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node) && + node_graph->IsFinalOutputKernel(ref_real_node)) { + auto kernel_info = ref_real_node->kernel_info(); + if (kernel_info == nullptr || !kernel_info->has_build_info()) { + MS_LOG(INFO) << "No kernel info"; + return; + } + auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index); + if (address == nullptr) { + MS_LOG(INFO) << "No kernel address"; + return; + } + auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index); + auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); + auto d_kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(d_kernel_info); + parameter->set_kernel_info(d_kernel_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsDeviceType({type}); + builder.SetOutputsFormat({format}); + d_kernel_info->set_select_kernel_build_info(builder.Build()); + AnfAlgo::SetOutputAddr(address, 0, parameter.get()); + } +} + +std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, + KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(graph); + std::vector parameters; + std::vector pre_graph_out = {node}; + // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive + if (!AnfAlgo::IsRealKernel(node)) { + pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); + } + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { + auto parameter = graph->NewParameter(); + MS_EXCEPTION_IF_NULL(parameter); + parameter->set_abstract(abstract); + auto new_parameter = graph->NewParameter(parameter); + parameters.push_back(new_parameter); + valid_inputs->push_back(valid_input); + graph_inputs->push_back(new_parameter); + }; + for (const auto &out_node : pre_graph_out) { + MS_EXCEPTION_IF_NULL(out_node); + auto abstract = out_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + // create multiple parameters if is a tuple output real kernel + if (abstract->isa() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; + for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { + create_parameter((*tuple_abstract)[output_idx]); + } + continue; + } + // create single parameter if is a abstract real kernel + create_parameter(out_node->abstract()); + InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]); + } + return parameters; +} + +ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, + KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + if (!anf->isa()) { + MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; + } + MS_EXCEPTION_IF_NULL(graph); + auto param_value = GetParamDefaultValue(anf); + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + ParameterPtr new_parameter = nullptr; + // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter + if (python_paras == nullptr) { + python_paras = std::make_shared>(); + } + auto iter = python_paras->find(param_value); + if (iter != python_paras->end()) { + new_parameter = iter->second; + } else { + TraceManager::DebugTrace(std::make_shared(anf->debug_info())); + new_parameter = graph->NewParameter(anf->cast()); + if (param_value != nullptr) { + (*python_paras)[param_value] = new_parameter; + } + TraceManager::EndTrace(); + } + graph_inputs->push_back(new_parameter); + valid_inputs->push_back(valid_input); + return new_parameter; +} + +AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; + auto parameters = CreateParameterFromTuple(anf, valid_input, graph); + if (parameters.empty()) { + MS_LOG(EXCEPTION) << "No parameter exist!!"; + } + if (parameters.size() == 1) { + return parameters[0]; + } + std::vector make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; + (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); + auto make_tuple = graph->NewCNode(make_tuple_input); + MS_EXCEPTION_IF_NULL(make_tuple); + MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; + return make_tuple; +} + +CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, + bool *from_other_graph, + std::unordered_map *other_graph_cnode) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(from_other_graph); + MS_EXCEPTION_IF_NULL(other_graph_cnode); + *from_other_graph = false; + // get primitive of old node + std::vector cnode_inputs; + auto prim = AnfAlgo::GetCNodePrimitive(cnode); + if (prim != nullptr) { + // push attr to inputs[0] of new cnode + cnode_inputs.push_back(std::make_shared(std::make_shared(*prim))); + } else { + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); + MS_EXCEPTION_IF_NULL(fg); + auto new_fg = BasicClone(fg); + cnode_inputs.push_back(std::make_shared(new_fg)); + } + auto origin_inputs = cnode->inputs(); + bool optimize_depend = false; + if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && + origin_inputs[kRealInputIndexInDepend]->isa()) { + optimize_depend = true; + } + // if has multiple depends,only select first depend as parameter + for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { + auto anf = origin_inputs[input_idx]; + MS_EXCEPTION_IF_NULL(anf); + // anf has been created before + if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + continue; + } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { + cnode_inputs.push_back((*other_graph_cnode)[anf]); + continue; + } else if (anf->isa() && !IsValueNode(anf)) { + // if input is a value node, + auto new_value_node = CreateNewValueNode(anf, graph); + if (new_value_node != nullptr) { + cnode_inputs.emplace_back(new_value_node); + } + continue; + } else if (anf->isa() && AnfAlgo::GetOutputTensorNum(anf) == 1) { + auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); + cnode_inputs.push_back(new_parameter); + if (GetGraphIdByNode(anf) == kInvalidGraphId) { + graph->FrontBackendlMapAdd(anf, new_parameter); + } else { + (*other_graph_cnode)[anf] = new_parameter; + } + continue; + } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { + cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); + continue; + } else { + *from_other_graph = true; + // the input node is a cnode from other graph + auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); + cnode_inputs.push_back(parameter_from_cnode); + (*other_graph_cnode)[anf] = parameter_from_cnode; + } + } + TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); + auto new_cnode = graph->NewCNode(cnode_inputs); + TraceManager::EndTrace(); + return new_cnode; +} + +CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(node_input); + MS_EXCEPTION_IF_NULL(graph); + // switch input generalizes partial + if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || + AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { + return node_input->cast(); + } + if (node_input->isa()) { + MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call."; + } + std::vector partial_inputs = {NewValueNode(std::make_shared(prim::kPrimPartial->name()))}; + if (node_input->isa() && IsValueNode(node_input)) { + partial_inputs.emplace_back(node_input); + auto partial_node = graph->NewCNode(partial_inputs); + return partial_node; + } + KernelGraphPtr kernel_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(kernel_graph); + kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); + partial_inputs.emplace_back(std::make_shared(kernel_graph)); + auto partial_node = graph->NewCNode(partial_inputs); + return partial_node; +} + +CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(graph); + auto node = anf_node->cast(); + MS_EXCEPTION_IF_NULL(node); + if (node->inputs().size() < kSwitchInputSize) { + MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; + } + auto primitive = NewValueNode(std::make_shared(prim::kPrimSwitch->name())); + std::vector switch_inputs = {primitive, node->input(1)}; + for (size_t index = 2; index < node->inputs().size(); index++) { + auto input = CreateSwitchInput(node->input(index), graph); + switch_inputs.emplace_back(input); + } + auto switch_node = graph->NewCNode(switch_inputs); + return switch_node; +} + +std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + // create primitive of cnode:call(partial or switch) + std::vector cnode_inputs = { + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); + if (cnode_input == nullptr) { + MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() + << ", but input[0] has not been created."; + } + // if the node is partial, insert the inputs of partial to the call + if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) { + auto partial_node = attr_input->cast(); + MS_EXCEPTION_IF_NULL(partial_node); + auto partial_inputs = partial_node->inputs(); + std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(), + std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node)); + return graph->GetBackendAnfByFrontAnf(node); + }); + return cnode_inputs; + } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { + auto switch_node = HandleSwitchInputs(cnode_input, graph); + cnode_inputs.emplace_back(switch_node); + return cnode_inputs; + } + MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; +} + +CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + std::vector cnode_inputs; + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + if (AnfAlgo::IsGraphKernel(cnode)) { + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); + MS_EXCEPTION_IF_NULL(fg); + auto new_fg = BasicClone(fg); + cnode_inputs.push_back(std::make_shared(new_fg)); + } else if (IsValueNode(attr_input)) { + // create primitive of cnode:call + cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; + // create a ValueNode as input of cnode:call + if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)); + } else { + auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph); + if (new_value_node != nullptr) { + cnode_inputs.emplace_back(new_value_node); + } + } + } else if (attr_input->isa()) { + cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); + } else { + // get primitive of old node + auto prim = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(prim); + // push attr to inputs[0] of new cnode + cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(*prim)))}; + } + + for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { + auto anf = cnode->input(input_idx); + MS_EXCEPTION_IF_NULL(anf); + // anf has been created before + if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + continue; + } else if (IsValueNode(anf)) { + continue; + } + MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; + } + TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); + auto new_cnode = graph->NewCNode(cnode_inputs); + TraceManager::EndTrace(); + return new_cnode; +} + +ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + auto value_node = anf->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); + MS_EXCEPTION_IF_NULL(sub_func_graph); + if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) { + MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph."; + } + auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph]; + + ValueNodePtr new_value_node = std::make_shared(sub_kernel_graph); + new_value_node->set_abstract(value_node->abstract()); + // create new kernel_info of new value_node + auto kernel_info = std::make_shared(); + kernel_info->SetFeatureMapFlag(false); + new_value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get()); + + graph->FrontBackendlMapAdd(anf, new_value_node); + + return new_value_node; +} + +ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + if (!anf->isa()) { + MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; + } + + auto param_value = GetParamDefaultValue(anf); + ParameterPtr new_parameter = nullptr; + if (python_paras == nullptr) { + python_paras = std::make_shared>(); + } + auto iter = python_paras->find(param_value); + if (iter != python_paras->end()) { + new_parameter = iter->second; + } else { + TraceManager::DebugTrace(std::make_shared(anf->debug_info())); + new_parameter = graph->NewParameter(anf->cast()); + if (param_value != nullptr) { + (*python_paras)[param_value] = new_parameter; + } + TraceManager::EndTrace(); + } + + return new_parameter; +} + +KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + std::unordered_map other_graph_cnode; + auto graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Create graph: " << graph->graph_id(); + size_t from_other_graph_depend_num = 0; + for (const auto &node : lst) { + MS_EXCEPTION_IF_NULL(node); + MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode"; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // create a new cnode object + bool from_other_graph = false; + // only first depend from other graph can create + bool valid_input = true; + if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { + valid_input = false; + } + auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) { + from_other_graph_depend_num++; + } + MS_EXCEPTION_IF_NULL(new_cnode); + new_cnode->set_abstract(cnode->abstract()); + new_cnode->set_scope(cnode->scope()); + // record map relations between anf from ME and new anf node used in backend + graph->FrontBackendlMapAdd(node, new_cnode); + } + // add a make_tuple at the end of graph as output + graph->set_output(ConstructOutput(outputs, graph)); + MS_EXCEPTION_IF_NULL(context_); + FuncGraphManagerPtr manager = MakeManager({graph}); + if (manager) { + manager->AddFuncGraph(graph); + graph->set_manager(manager); + } + graph->SetExecOrderByDefault(); + if (ExistSummaryNode(graph.get())) { + graph->set_summary_node_exist(true); + } + opt::BackendCommonOptimization(graph); + return graph; +} + +void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(graph); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // create a new cnode object + auto new_cnode = CreateNewCNode(cnode, graph.get()); + MS_EXCEPTION_IF_NULL(new_cnode); + new_cnode->set_abstract(cnode->abstract()); + new_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); + new_cnode->set_scope(cnode->scope()); + graph->FrontBackendlMapAdd(node, new_cnode); + if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) { + graph->set_return(new_cnode); + } +} +std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph, + std::vector *all_out_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(all_out_graph); + auto node_list = TopoSort(func_graph->get_return()); + auto graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(graph); + front_backend_graph_map_[func_graph] = graph; + MS_LOG(INFO) << "Create graph: " << graph->graph_id(); + + bool is_trace_back = false; + for (const auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); + if (node->isa()) { + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + auto new_parameter = CreateNewParameter(node, graph.get()); + graph_inputs->push_back(new_parameter); + graph->FrontBackendlMapAdd(node, new_parameter); + continue; + } else if (node->isa()) { + if (!IsValueNode(node)) { + // if input is a common value node, + (void)CreateNewValueNode(node, graph.get()); + } else { + // if input is a ValueNode + FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node); + if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) { + is_trace_back = true; + } else { + (void)ConstructKernelGraph(child_graph, all_out_graph); + } + (void)CreateValueNodeKernelGraph(node, graph.get()); + } + continue; + } else { + CreateCNodeKernelGraph(node, graph); + } + } + // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null. + graph->set_output_null(is_trace_back); + AddParameterToGraphInputs(func_graph->parameters(), graph.get()); + graph->SetExecOrderByDefault(); + if (ExistSummaryNode(graph.get())) { + graph->set_summary_node_exist(true); + } + all_out_graph->push_back(graph); + return graph; +} + +void SessionBasic::AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + graph_inputs->clear(); + for (auto ¶meter : parameters) { + MS_EXCEPTION_IF_NULL(parameter); + auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); + if (backend_parameter == nullptr) { + // for example "def f(x,y,z) {return x + y}", parameter z in unused + auto new_parameter = CreateNewParameter(parameter, graph); + graph_inputs->push_back(new_parameter); + MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); + continue; + } + MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString(); + graph_inputs->push_back(backend_parameter); + } +} + +// run graph steps +void SessionBasic::LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const { + std::vector inputs(inputs_const); + size_t input_ctrl_size = 2; + MS_EXCEPTION_IF_NULL(kernel_graph); + if (kernel_graph->input_ctrl_tensors()) { + input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); + } + auto input_nodes = kernel_graph->inputs(); + if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) { + MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() + << ", input_ctrl_size:" << input_ctrl_size; + } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + for (size_t i = 0; i < inputs.size(); ++i) { + auto tensor = inputs[i]; + MS_EXCEPTION_IF_NULL(tensor); + auto input_node = input_nodes[i]; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { + auto pk_node = input_node->cast(); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + bool need_sync = false; + if (ms_context->enable_pynative_infer()) { + if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { + need_sync = true; + } + } else { + if (tensor->is_dirty()) { + need_sync = true; + } else if (tensor->device_address() != device_address) { + (void)tensor->data_sync(); + need_sync = true; + } + } + if (need_sync) { + if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { + tensor->set_device_address(device_address); + } + MS_EXCEPTION_IF_NULL(device_address); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; + } + } + } + tensor->set_dirty(false); + } +} + +void SessionBasic::UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, + const std::vector &input_tensors) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(outputs); + auto anf_outputs = kernel_graph->outputs(); + for (auto &item : anf_outputs) { + MS_EXCEPTION_IF_NULL(item); + MS_LOG(INFO) << "Update output[" << item->DebugString() << "]"; + outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors)); + } +} + +void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { + MS_EXCEPTION_IF_NULL(callback); + summary_callback_ = callback; +} + +void SessionBasic::Reorder(std::vector *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } + +void SessionBasic::GetSummaryNodes(KernelGraph *graph) { + MS_LOG(DEBUG) << "Update summary Start"; + MS_EXCEPTION_IF_NULL(graph); + if (!graph->summary_node_exist()) { + return; + } + auto summary = graph->summary_nodes(); + auto apply_list = TopoSort(graph->get_return()); + for (auto &n : apply_list) { + MS_EXCEPTION_IF_NULL(n); + if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || + IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { + auto cnode = n->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() <= kSummaryGetItem) { + MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!"; + } + auto node = cnode->input(kSummaryGetItem); + MS_EXCEPTION_IF_NULL(node); + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + MS_EXCEPTION_IF_NULL(item_with_index.first); + if (!AnfAlgo::IsRealKernel(item_with_index.first)) { + MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); + } + summary[n->fullname_with_scope()] = item_with_index; + } + } + graph->set_summary_nodes(summary); + MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); +} + +void SessionBasic::Summary(KernelGraph *graph) { + if (summary_callback_ == nullptr) { + return; + } + MS_EXCEPTION_IF_NULL(graph); + bool exist_summary = graph->summary_node_exist(); + if (!exist_summary) { + return; + } + GetSummaryNodes(graph); + auto summary_outputs = graph->summary_nodes(); + std::map params_list; + // fetch outputs apply kernel in session & run callback functions + for (auto &output_item : summary_outputs) { + auto node = output_item.second.first; + size_t index = IntToSize(output_item.second.second); + auto address = AnfAlgo::GetOutputAddr(node, index); + auto shape = AnfAlgo::GetOutputInferShape(node, index); + TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); + std::vector temp_shape; + (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); + tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + MS_EXCEPTION_IF_NULL(address); + if (!address->GetPtr()) { + continue; + } + if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()), + tensor->data_type(), tensor->data_c())) { + MS_LOG(ERROR) << "Failed to sync output from device to host."; + } + tensor->set_dirty(false); + params_list[output_item.first] = tensor; + } + // call callback function here + summary_callback_(0, params_list); +} + +CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph) { + MS_EXCEPTION_IF_NULL(graph); + std::vector output_args; + for (const auto &output : outputs) { + MS_EXCEPTION_IF_NULL(output); + MS_LOG(INFO) << "Output:" << output->DebugString(); + } + auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { + auto backend_anf = graph->GetBackendAnfByFrontAnf(out); + if (backend_anf != nullptr) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->execution_mode() == kPynativeMode) { + return backend_anf; + } + auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); + auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); + MS_EXCEPTION_IF_NULL(out); + auto out_func_graph = out->func_graph(); + MS_EXCEPTION_IF_NULL(out_func_graph); + auto out_func_graph_manager = out_func_graph->manager(); + if (out_func_graph_manager == nullptr) { + return backend_anf; + } + auto node_users = out_func_graph_manager->node_users(); + auto users = node_users[out]; + bool internal_output = true; + std::string kernel_target = GetCNodeTarget(front_real_kernel.first); + for (auto user : users) { + if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { + internal_output = false; + break; + } + } + if (internal_output) { + MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); + graph->AddInternalOutput(out, backend_real_kernel.first); + } + return backend_anf; + } + MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; + }; + output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); + (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args), + [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); }); + return graph->NewCNode(output_args); +} + +void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr &graph) { + MS_LOG(INFO) << "Start!"; + std::vector make_tuple_inputs; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + MS_EXCEPTION_IF_NULL(graph); + if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) { + for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) { + auto idx = NewValueNode(SizeToInt(output_index)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(output_index); + idx->set_abstract(std::make_shared(imm)); + auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); + std::vector types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)}; + std::vector> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get()); + make_tuple_inputs.push_back(getitem); + } + } else { + make_tuple_inputs.push_back(cnode); + } + // create output + auto g_output = graph->NewCNode(make_tuple_inputs); + graph->set_output(g_output); + // set graph manager,which now is only used to get valuenodes and hardware optimizing + MS_EXCEPTION_IF_NULL(context_); + FuncGraphManagerPtr manager = context_->manager(); + if (manager != nullptr) { + manager->AddFuncGraph(graph); + graph->set_manager(manager); + } + MS_LOG(INFO) << "Finish!"; +} + +std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info, + const std::vector &input_tensors, + const std::vector &tensors_mask) { + auto graph = std::make_shared(); + std::vector inputs; + // set input[0] + PrimitivePtr op_prim = op_run_info.py_primitive; + MS_EXCEPTION_IF_NULL(op_prim); + inputs.push_back(std::make_shared(op_prim)); + // set input parameter + MS_LOG(INFO) << "Input tensor size: " << input_tensors.size(); + if (input_tensors.size() != tensors_mask.size()) { + MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size " + << tensors_mask.size(); + } + for (size_t i = 0; i < input_tensors.size(); ++i) { + if (tensors_mask[i] == kValueNodeTensorMask) { + auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]); + inputs.push_back(value_node); + continue; + } + auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]); + inputs.push_back(parameter); + auto mutable_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(mutable_inputs); + mutable_inputs->push_back(parameter); + } + // set execution order + auto cnode = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode); + // set abstract,which include inferred shapes and types + cnode->set_abstract(op_run_info.abstract); + // set execution order + std::vector exe_order = {cnode}; + graph->set_execution_order(exe_order); + // set output + CreateOutputNode(cnode, graph); + return graph; +} + +BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) { + if (utils::isa(base_ref)) { + auto ref_list = utils::cast(base_ref); + py::tuple output_tensors(ref_list.size()); + for (size_t i = 0; i < ref_list.size(); ++i) { + auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef + if (utils::isa(output)) { + auto tensor_ptr = utils::cast(output); + MS_EXCEPTION_IF_NULL(tensor_ptr); + output_tensors[i] = tensor_ptr; + } else if (utils::isa(output)) { + py::object obj = utils::cast(output).object_; + py::tuple tensor_tuple = py::cast(obj); + output_tensors[i] = tensor_tuple; + } else { + MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; + } + } + return output_tensors; // turn tuple to py::object and store in PyObjectRef + } else if (utils::isa(base_ref)) { + return base_ref; + } else { + MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; + } +} + +KernelGraphPtr SessionBasic::NewKernelGraph() { + auto graph = std::make_shared(); + graph->set_graph_id(graph_sum_); + graphs_[graph_sum_++] = graph; + return graph; +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h new file mode 100755 index 0000000000..c662e3978b --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -0,0 +1,160 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H +#define MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H + +#include +#include +#include +#include +#include +#include + +#include "utils/base_ref_extends.h" +#include "backend/session/session_context.h" +#include "backend/session/kernel_graph.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "utils/any.h" +#include "utils/contract.h" +#include "pipeline/pynative/pynative_execute.h" +#include "runtime/device/kernel_info.h" +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif + +namespace mindspore { +using GraphId = uint32_t; +using GraphInfo = std::string; +namespace session { +void ClearPythonParasMap(); +using CallBackFunc = uint32_t (*)(uint32_t graph_id, + const std::map ¶ms_list); +using AnyList = std::vector; +using AnyListPtr = std::shared_ptr; + +using OpRunInfo = pynative::OpExecInfo; +using OpRunInfoPtr = std::shared_ptr; + +class SessionBasic { + public: + SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) { +#ifdef ENABLE_DEBUGGER + debugger_ = nullptr; +#endif + } + + virtual void Init(uint32_t device_id) { device_id_ = device_id; } + + virtual ~SessionBasic() { summary_callback_ = nullptr; } + + virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; + virtual GraphId CompileGraph(NotNull func_graph) { return kInvalidGraphId; } + // build graph, used to handle multiple child graphs + virtual void BuildGraph(GraphId) {} + + virtual void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) = 0; + + virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector &input_tensors, + const std::vector &tensors_mask) {} + + virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector &input_tensors) { + return py::tuple(); + } + + virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); + + void CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph); + + std::shared_ptr ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); + std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph, + std::vector *all_out_graph); + + CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, + std::unordered_map *other_graph_cnode); + CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); + + CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); + CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); + std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); + + // set parameters of final graph + virtual GraphId SetFinalGraphInput(const std::vector &) { return kInvalidGraphId; } + // set output of final graph + virtual void SetFinalGraphOutput(const BaseRef &) {} + // insert switch and set the relative active ops + virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {} + // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter + virtual void SetChildGraphInput(GraphId, const VectorRef &) {} + // get graph id in child graphs by ME front anf node pointer + virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } + virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } + virtual void SetActive(GraphId, GraphId) {} + virtual void GetSummaryNodes(KernelGraph *graph); + +#ifdef ENABLE_DEBUGGER + // set debugger + void SetDebugger() { + debugger_ = Debugger::GetInstance(); + debugger_->Init(device_id_); + } +#endif + + protected: + // Get graph by graph id ,if not exist return null ptr + KernelGraphPtr GetGraph(GraphId graph_id); + virtual void LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const; + void UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, + const std::vector &input_tensors) const; + void Reorder(std::vector *node_list); + void Summary(KernelGraph *graph); + // create graph output for RunOp + void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr &graph); + CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph); + // create a single run op graph + std::shared_ptr ConstructSingleOpGraph(const OpRunInfo &op_run_info, + const std::vector &input_tensors, + const std::vector &tensors_mask); + // trans BaseRef list to py::tuple + BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); + // create a new kernel graph and update the graph sum + KernelGraphPtr NewKernelGraph(); + std::vector CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph); + virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); + ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); + ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); + AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); + void AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph); + void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); + + std::unordered_map> graphs_; + std::unordered_map> run_op_graphs_; + std::unordered_map front_backend_graph_map_; + std::shared_ptr context_; + CallBackFunc summary_callback_; + static GraphId graph_sum_; + uint32_t device_id_; +#ifdef ENABLE_DEBUGGER + std::shared_ptr debugger_; +#endif +}; + +using SessionPtr = std::shared_ptr; +using NamedSummaryOutputs = std::map>; +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/backend/session/session_context.cc b/mindspore/ccsrc/backend/session/session_context.cc new file mode 100644 index 0000000000..f5ec49c090 --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_context.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/session_context.h" +namespace mindspore { +namespace session { +std::shared_ptr Context::GetInstance() { + static std::shared_ptr context_singleton = std::make_shared(); + return context_singleton; +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/session_context.h b/mindspore/ccsrc/backend/session/session_context.h new file mode 100644 index 0000000000..22cc0c813a --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_context.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H +#define MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H +#include +#include +#include +#include +#include +#include + +#include "ir/tensor.h" +#include "pipeline/jit/resource.h" +#include "utils/context/ms_context.h" +namespace mindspore { +namespace session { +const char kInputCtrlTensors[] = "input_ctrl_tensors"; + +class Context : public pipeline::ResourceBase { + public: + explicit Context(std::string target = kAscendDevice, uint32_t device_id = 0) + : target_(std::move(target)), device_id_(device_id) {} + ~Context() override = default; + + uint32_t device_id() const { return device_id_; } + static std::shared_ptr GetInstance(); + void AddManager(const FuncGraphManagerPtr &m) { manager_list_.push_back(m); } + + private: + std::vector manager_list_; + std::string target_; + uint32_t device_id_; +}; +} // namespace session +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H diff --git a/mindspore/ccsrc/backend/session/session_factory.cc b/mindspore/ccsrc/backend/session/session_factory.cc new file mode 100644 index 0000000000..8a8f9a9cea --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_factory.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/session_factory.h" +#include +#include +#include +namespace mindspore { +namespace session { +SessionFactory &SessionFactory::Get() { + static SessionFactory instance; + return instance; +} + +void SessionFactory::Register(const std::string &device_name, SessionCreator &&session_creator) { + if (session_creators_.end() == session_creators_.find(device_name)) { + (void)session_creators_.emplace(device_name, session_creator); + } +} + +std::shared_ptr SessionFactory::Create(const std::string &device_name) { + auto iter = session_creators_.find(device_name); + if (session_creators_.end() != iter) { + MS_EXCEPTION_IF_NULL(iter->second); + return (iter->second)(); + } + return nullptr; +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/session_factory.h b/mindspore/ccsrc/backend/session/session_factory.h new file mode 100644 index 0000000000..054f03cf4b --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_factory.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ +#define MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ + +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "backend/session/session_basic.h" +namespace mindspore { +namespace session { +using SessionCreator = std::function()>; +class SessionFactory { + public: + static SessionFactory &Get(); + void Register(const std::string &device_name, SessionCreator &&session_creator); + std::shared_ptr Create(const std::string &device_name); + + private: + SessionFactory() = default; + ~SessionFactory() = default; + DISABLE_COPY_AND_ASSIGN(SessionFactory) + std::map session_creators_; +}; + +class SessionRegistrar { + public: + SessionRegistrar(const std::string &device_name, SessionCreator &&session_creator) { + SessionFactory::Get().Register(device_name, std::move(session_creator)); + } + ~SessionRegistrar() = default; +}; + +#define MS_REG_SESSION(DEVICE_NAME, SESSION_CLASS) \ + static const SessionRegistrar g_session_registrar__##DEVICE_NAME##_##_reg( \ + DEVICE_NAME, []() { return std::make_shared(); }); +} // namespace session +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ diff --git a/mindspore/ccsrc/common.h b/mindspore/ccsrc/common.h index 0928dcfcf6..6b882a15d4 100644 --- a/mindspore/ccsrc/common.h +++ b/mindspore/ccsrc/common.h @@ -23,13 +23,13 @@ #include "pybind11/pybind11.h" #include "pybind11/stl.h" -#include "pipeline/static_analysis/dshape.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/resolve.h" +#include "abstract/dshape.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/resolve.h" namespace py = pybind11; #endif // MINDSPORE_CCSRC_COMMON_H_ diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 9cf6eb3a5a..1841826ca9 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -18,9 +18,9 @@ #include #include #include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel.h" -#include "device/convert_tensor_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/convert_tensor_utils.h" #include "utils/convert_utils.h" #include "utils/log_adapter.h" #include "utils/utils.h" diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index a8fc7c8a00..286c76afd0 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -24,7 +24,7 @@ #include #include #include "ir/dtype.h" -#include "kernel/kernel.h" +#include "backend/kernel_compiler/kernel.h" #include "ir/dtype/type.h" namespace mindspore { diff --git a/mindspore/ccsrc/common/utils.h b/mindspore/ccsrc/common/utils.h index 8f6e8f7c0c..23d08f8f28 100644 --- a/mindspore/ccsrc/common/utils.h +++ b/mindspore/ccsrc/common/utils.h @@ -38,6 +38,14 @@ static inline std::string GetEnv(const std::string &envvar) { return std::string(value); } + +static inline int SetEnv(const char *envname, const char *envvar, int overwrite = 1) { +#if defined(_WIN32) + return 0; +#else + return ::setenv(envname, envvar, overwrite); +#endif +} } // namespace common } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt deleted file mode 100644 index 9238be93f2..0000000000 --- a/mindspore/ccsrc/dataset/CMakeLists.txt +++ /dev/null @@ -1,143 +0,0 @@ -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-reorder") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-switch") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sequence-point") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable") - -if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-uninitialized") -else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-maybe-uninitialized") -endif() -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") - -############################# Options ################################ -if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - add_definitions(-D _CRT_RAND_S) -endif () -if (ENABLE_GPUQUE) - add_definitions(-D ENABLE_GPUQUE) - message(STATUS "GPU queue is enabled") -endif () -if (ENABLE_TDTQUE) - add_definitions(-D ENABLE_TDTQUE) - message(STATUS "TDT queue is enabled") -endif () - -# conde coverage -# option(ENABLE_COVERAGE "Enable code coverage report" OFF) -# if (ENABLE_COVERAGE) -# include(${CMAKE_SOURCE_DIR}/cmake/CodeCoverage.cmake) -# append_coverage_compiler_flags() -# endif () - -########### Set up the include directories ########################### -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc) -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/device/ascend/platform) - -include_directories(${CMAKE_BINARY_DIR}) # for protobuf generated .h - -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/mindrecord/include) -###################################################################### - -####################### Flags ######################################## -# compile flags -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") - -################## Include sub-modules ############################### -add_subdirectory(util) -add_subdirectory(core) -add_subdirectory(kernels) -add_subdirectory(engine) -add_subdirectory(api) -add_subdirectory(text) -###################################################################### -add_dependencies(core utils) -add_dependencies(kernels-image core) -add_dependencies(kernels-data core) -add_dependencies(kernels core) -add_dependencies(engine-datasetops-source core) -add_dependencies(engine-datasetops-source-sampler core) -add_dependencies(engine-datasetops core) -add_dependencies(engine-opt core) -add_dependencies(engine-perf core) -add_dependencies(engine-gnn core) -add_dependencies(engine core) -add_dependencies(text core) -add_dependencies(text-kernels core) -add_dependencies(APItoPython core) -if (ENABLE_TDTQUE) - add_dependencies(engine-tdt core) -endif () -################### Create _c_dataengine Library ###################### -set(submodules - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - ) - -if (ENABLE_TDTQUE) - add_library(_c_dataengine SHARED ${submodules} $) -else () - add_library(_c_dataengine SHARED ${submodules}) -endif () - -set_target_properties(_c_dataengine PROPERTIES - PREFIX "${PYTHON_MODULE_PREFIX}" - SUFFIX "${PYTHON_MODULE_EXTENSION}" - ) - -###################################################################### - -################# Link with external libraries ######################## -target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar) -if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY}) -else() - set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n) - target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY}) -endif() -target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs - mindspore::opencv_imgproc mindspore::tinyxml2 ${ICU_LIB}) -if (ENABLE_GPUQUE) - target_link_libraries(_c_dataengine PRIVATE gpu_queue - ${CUDNN_PATH}/lib64/libcudnn.so - ${CUDA_PATH}/lib64/libcudart.so - ${CUDA_PATH}/lib64/stubs/libcuda.so) -endif () - -if (ENABLE_TDTQUE) - target_link_libraries(_c_dataengine PRIVATE ${TSDCLIENT}) -endif () - -add_dependencies(_c_dataengine _c_mindrecord) -if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - set(MINDRECORD_LINK_OBJECT ${CMAKE_BINARY_DIR}/mindspore/ccsrc/mindrecord/CMakeFiles/_c_mindrecord.dir/objects.a) - target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) -else() - target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) -endif() - -if (USE_GLOG) - target_link_libraries(_c_dataengine PRIVATE mindspore::glog) -else() - if (CMAKE_SYSTEM_NAME MATCHES "Linux") - target_link_options(_c_dataengine PRIVATE -Wl,-init,mindspore_log_init) - elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") - set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON) - endif () -endif() diff --git a/mindspore/ccsrc/dataset/api/CMakeLists.txt b/mindspore/ccsrc/dataset/api/CMakeLists.txt deleted file mode 100644 index 194aeed457..0000000000 --- a/mindspore/ccsrc/dataset/api/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(APItoPython OBJECT - de_pipeline.cc - python_bindings.cc - ) -target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS}) diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc deleted file mode 100644 index 78fcdb7dd4..0000000000 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ /dev/null @@ -1,1477 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/api/de_pipeline.h" - -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/bucket_batch_by_length_op.h" -#include "dataset/engine/datasetops/filter_op.h" -#include "dataset/engine/datasetops/source/celeba_op.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/kernels/py_func_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "mindrecord/include/shard_category.h" -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_shuffle.h" -#include "pybind11/stl.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *, std::shared_ptr *); - -static std::unordered_map g_parse_op_func_ = { - {kShuffle, &DEPipeline::ParseShuffleOp}, - {kMindrecord, &DEPipeline::ParseMindRecordOp}, - {kMap, &DEPipeline::ParseMapOp}, - {kFilter, &DEPipeline::ParseFilterOp}, - {kBatch, &DEPipeline::ParseBatchOp}, - {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, - {kBarrier, &DEPipeline::ParseBarrierOp}, - {kRepeat, &DEPipeline::ParseRepeatOp}, - {kSkip, &DEPipeline::ParseSkipOp}, - {kZip, &DEPipeline::ParseZipOp}, - {kConcat, &DEPipeline::ParseConcatOp}, - {kRename, &DEPipeline::ParseRenameOp}, - {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, - {kGenerator, &DEPipeline::ParseGeneratorOp}, - {kTfReader, &DEPipeline::ParseTFReaderOp}, - {kProject, &DEPipeline::ParseProjectOp}, - {kTake, &DEPipeline::ParseTakeOp}, - {kImageFolder, &DEPipeline::ParseImageFolderOp}, - {kMnist, &DEPipeline::ParseMnistOp}, - {kManifest, &DEPipeline::ParseManifestOp}, - {kVoc, &DEPipeline::ParseVOCOp}, - {kCoco, &DEPipeline::ParseCocoOp}, - {kCifar10, &DEPipeline::ParseCifar10Op}, - {kCifar100, &DEPipeline::ParseCifar100Op}, - {kCelebA, &DEPipeline::ParseCelebAOp}, - {kRandomData, &DEPipeline::ParseRandomDataOp}, - {kTextFile, &DEPipeline::ParseTextFileOp}, - {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, - {kClue, &DEPipeline::ParseClueOp}}; - -DEPipeline::DEPipeline() : iterator_(nullptr) { - try { - // One time init - (void)GlobalInit(); - - // Instantiate the execution tree - tree_ = std::make_shared(); - repeat_num_ = 1; - batch_size_ = 1; - num_rows_ = 0; - num_classes_ = 0; - temp_batch_size_ = 1; - temp_drop_remainder_ = false; - } catch (const std::exception &err) { - MS_LOG(ERROR) << "Dataset pipeline exception caught on init: " << err.what() << "."; - return; - } -} - -DEPipeline::~DEPipeline() { - { - // Release GIL before joining all threads - py::gil_scoped_release gil_release; - // Release tree - tree_.reset(); - } -} - -// Function to add a Node to the Execution Tree. -Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output) { - // For each operator, Parse through the list of arguments, then call the respective builder/constructor. - // Note that each call to the parse function may result in building more than one dataset operator. - // For example, one call to ParseNNNOp may result in multiple internal C nodes: - // nodeA - // | - // nodeB - // | - // nodeC - // However, the python side dataset is more abstract, and it does not know about the potential subtree that - // is being built here. Since the python api is hooking tree nodes together (parent/child hookups), the - // python side needs to know about nodeA and NodeC to be able to appropriately hook up parents and child - // to this subtee. - // Thus, it is required that both the top-most parent and bottom-most child are returned from the parse - // function. - DsOpPtr top = nullptr; - DsOpPtr bottom = nullptr; - auto iter = g_parse_op_func_.find(op_name); - if (iter != g_parse_op_func_.end()) { - pFunction func = iter->second; - RETURN_IF_NOT_OK((this->*func)(args, &top, &bottom)); - - if (top == nullptr) { - RETURN_STATUS_UNEXPECTED("An operator was parsed but it did not produce a C node."); - } - - // It is not required that the parse function always produces the bottom pointer. If it's still null, - // then set top and bottom to be the same operator - if (bottom == nullptr) bottom = top; - - // Pack these pointers into a py dict so that we can return both back to python. - (*output)["top"] = top; - (*output)["bottom"] = bottom; - } else { - RETURN_STATUS_UNEXPECTED("No such Op"); - } - // Associate current dataset op node with the tree. - RETURN_IF_NOT_OK(tree_->AssociateNode(top)); - return Status::OK(); -} -// Function to add a child and parent relationship. -Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op) { - // Link this relationship. - // Note parent node takes ownership of the child - return (parent_op->AddChild(child_op)); -} - -// Function to assign the node as root. -Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); } - -// Function to launch the tree execution. -Status DEPipeline::LaunchTreeExec() { - RETURN_IF_NOT_OK(tree_->Prepare()); - RETURN_IF_NOT_OK(tree_->Launch()); - iterator_ = std::make_unique(tree_); - if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator."); - return Status::OK(); -} - -void DEPipeline::PrintTree() { - for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) { - std::stringstream ss; - ss << *itr; - MS_LOG(DEBUG) << "Operator ID is " << itr->id() << ". Details: " << ss.str().c_str() << "."; - } -} - -Status DEPipeline::GetNextAsMap(py::dict *output) { - TensorMap row; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetNextAsMap(&row); - } - RETURN_IF_NOT_OK(s); - // Generate Python dict as return - for (auto el : row) { - (*output)[common::SafeCStr(el.first)] = el.second; - } - return Status::OK(); -} - -Status DEPipeline::GetNextAsList(py::list *output) { - TensorRow row; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->FetchNextTensorRow(&row); - } - RETURN_IF_NOT_OK(s); - // Generate Python list as return - for (auto el : row) { - output->append(el); - } - return Status::OK(); -} - -Status DEPipeline::GetOutputShapes(py::list *output) { - std::vector shapes; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetOutputShapes(&shapes); - } - RETURN_IF_NOT_OK(s); - for (auto el : shapes) { - py::list shape; - for (auto dim : el.AsVector()) { - shape.append(dim); - } - output->append(shape); - } - return Status::OK(); -} - -Status DEPipeline::GetOutputTypes(py::list *output) { - std::vector types; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetOutputTypes(&types); - } - RETURN_IF_NOT_OK(s); - for (auto el : types) { - output->append(el.AsNumpyType()); - } - return Status::OK(); -} - -int DEPipeline::GetDatasetSize() const { return num_rows_ / batch_size_; } - -int DEPipeline::GetBatchSize() const { return batch_size_; } - -int DEPipeline::GetRepeatCount() const { return repeat_num_; } - -float ToFloat(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -int ToInt(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -bool ToBool(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -std::string ToString(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -std::vector ToStringVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (!l.is_none()) - vector.push_back(py::str(l)); - else - vector.emplace_back(""); - } - return vector; -} - -std::set ToStringSet(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::set set; - for (auto l : list) { - if (!l.is_none()) { - (void)set.insert(py::str(l)); - } - } - return set; -} - -std::map ToStringMap(const py::handle handle) { - py::dict dict = py::reinterpret_borrow(handle); - std::map map; - for (auto p : dict) { - (void)map.insert(std::make_pair(ToString(p.first), ToInt(p.second))); - } - return map; -} - -std::vector ToIntVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (!l.is_none()) { - vector.push_back(ToInt(l)); - } - } - return vector; -} - -std::vector ToTypeVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (l.is_none()) { - vector.emplace_back(DataType()); - } else { - vector.push_back(l.cast()); - } - } - return vector; -} - -Status DEPipeline::SetBatchParameters(const py::dict &args) { - if (args["batch_size"].is_none()) { - std::string err_msg = "Error: batchSize is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - temp_batch_size_ = ToInt(args["batch_size"]); - CHECK_FAIL_RETURN_UNEXPECTED(temp_batch_size_ > 0, "Error: batchSize is invalid."); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "drop_remainder") { - temp_drop_remainder_ = ToBool(value); - } - } - } - - return Status::OK(); -} - -Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - if (!args["buffer_size"].is_none()) { - (void)builder->SetShuffleSize(ToInt(args["buffer_size"])); - } else { - std::string err_msg = "Error: Shuffle buffer size is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "reshuffle_each_epoch") { - (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"])); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, - std::vector> *operators, - int num_padded) { - auto sampler = py::reinterpret_borrow(handle); - auto create = sampler.attr("create_for_minddataset"); - auto op = create().cast>(); - std::stack> stack_ops; - while (op != nullptr) { - auto sampler_op = std::dynamic_pointer_cast(op); - if (sampler_op && num_padded > 0) { - sampler_op->SetNumPaddedSamples(num_padded); - stack_ops.push(sampler_op); - } else { - stack_ops.push(op); - } - op = op->GetChildOp(); - } - while (!stack_ops.empty()) { - operators->push_back(stack_ops.top()); - stack_ops.pop(); - } - return Status::OK(); -} - -Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_file"].is_none()) { - std::string err_msg = "Error: at least one of dataset_files is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - bool load_dataset = ToBool(args["load_dataset"]); - if (load_dataset == true) { - (void)builder->SetDatasetFile({ToString(args["dataset_file"])}); - } else { - (void)builder->SetDatasetFile(ToStringVector(args["dataset_file"])); - } - (void)builder->SetLoadDataset(load_dataset); - std::vector in_col_names; - if (!args["columns_list"].is_none()) { - in_col_names = ToStringVector(args["columns_list"]); - if (in_col_names.empty() || in_col_names[0].empty()) { - std::string err_msg = "Error: columns_list is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetColumnsToLoad(in_col_names); - } - - if (!args["padded_sample"].is_none()) { - (void)builder->SetPaddedSample(args["padded_sample"]); - (void)builder->SetNumToPadSamples(ToInt(args["num_padded"])); - } - std::vector> operators; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumMindRecordWorkers(ToInt(value)); - } else if (key == "block_reader" && ToBool(value) == true) { - (void)builder->SetBlockReader(); - } else if (key == "sampler") { - int num_padded = 0; - if (!args["num_padded"].is_none()) { - num_padded = ToInt(args["num_padded"]); - } - RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded)); - } - } - } - - if (!operators.empty()) { - (void)builder->SetOperators(operators); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - num_rows_ = op->num_rows(); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - MapOp::Builder map_builder; - std::vector> tensor_op_list; - std::vector project_columns; - - if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n"); - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "input_columns") { - std::vector in_col_names = ToStringVector(args["input_columns"]); - (void)map_builder.SetInColNames(in_col_names); - } else if (key == "output_columns") { - (void)map_builder.SetOutColNames(ToStringVector(value)); - } else if (key == "columns_order") { - project_columns = ToStringVector(value); - } else if (key == "num_parallel_workers") { - (void)map_builder.SetNumWorkers(ToInt(value)); - } else if (key == "prefetch_size") { - (void)map_builder.SetOpConnectorSize(ToInt(value)); - } else if (key == "operations") { - py::handle tensor_ops = args["operations"]; - // operation can be a list of TensorOps or a single TensorOp. - if (py::isinstance(tensor_ops)) { - for (auto op : tensor_ops) { - std::shared_ptr tensor_op; - if (py::isinstance(op)) { - tensor_op = op.cast>(); - } else if (py::isinstance(op)) { - tensor_op = std::make_shared(op.cast()); - } else { - RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); - } - tensor_op_list.push_back(tensor_op); - } - } - if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); - (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); - } else { - RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); - } - } - } - - std::shared_ptr map_op; - RETURN_IF_NOT_OK(map_builder.Build(&map_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(map_op)); - *top = map_op; - - // Add a project op over top of the map if the user wanted to reposition the columns - if (!project_columns.empty()) { - ProjectOp::Builder proj_builder(project_columns); - std::shared_ptr proj_op; - RETURN_IF_NOT_OK(proj_builder.Build(&proj_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(proj_op)); - RETURN_IF_NOT_OK(proj_op->AddChild(map_op)); - *top = proj_op; - *bottom = map_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - - if (args["predicate"].is_none()) { - RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n"); - } - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "predicate") { - py::handle op = args["predicate"]; - if (!py::isinstance(op)) { - RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); - } - py::function predicate_func = op.cast(); - (void)builder->SetPredicateFunc(std::move(predicate_func)); - } else if (key == "input_columns") { - std::vector in_col_names = ToStringVector(args["input_columns"]); - (void)builder->SetInColNames(in_col_names); - } else { - RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - repeat_num_ = ToInt(args["count"]); - std::shared_ptr op; - RETURN_IF_NOT_OK(RepeatOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "source") { - py::object obj = py::cast(&value); - if (!py::isinstance(obj)) { - std::string err_msg = "Error: generator is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetGeneratorFunction(obj.cast()); - } else if (key == "column_names") { - (void)builder->SetColumnNames(ToStringVector(value)); - } else if (key == "column_types") { - (void)builder->SetColumnTypes(ToTypeVector(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder; - if (py::isinstance(args["batch_size"])) { - batch_size_ = ToInt(args["batch_size"]); - CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid."); - builder = std::make_shared(ToInt(args["batch_size"])); - } else if (py::isinstance(args["batch_size"])) { - builder = std::make_shared(1); - (void)builder->SetBatchSizeFunc(args["batch_size"].cast()); - } else { - std::string err_msg = "Error: batch_size is neither an Integer nor a python function"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "drop_remainder") { - (void)builder->SetDrop(ToBool(value)); - } - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } - if (key == "per_batch_map") { - (void)builder->SetBatchMapFunc(value.cast()); - } - if (key == "input_columns") { - (void)builder->SetColumnsToMap(ToStringVector(value)); - } - if (key == "pad_info") { - PadInfo pad_info; - RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); - (void)builder->SetPaddingMap(pad_info, true); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", - "bucket_batch_sizes"}; - for (auto name : mandatory_arguments) { - if (args[name.c_str()].is_none()) { - std::string err_msg = "Error: " + name + " is not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - - std::shared_ptr builder = std::make_shared( - ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), - ToIntVector(args[mandatory_arguments[2].c_str()])); - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "length_dependent_columns") { - (void)builder->SetLengthDependentColumns(ToStringVector(value)); - } - if (key == "bucket_boundaries") { - (void)builder->SetBucketBoundaries(ToIntVector(value)); - } - if (key == "bucket_batch_sizes") { - (void)builder->SetBucketBatchSizes(ToIntVector(value)); - } - if (key == "element_length_function") { - (void)builder->SetElementLengthFunction(value.cast()); - } - if (key == "pad_info") { - PadInfo pad_info; - RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); - (void)builder->SetPadInfo(pad_info); - } - if (key == "pad_to_bucket_boundary") { - (void)builder->SetPadToBucketBoundary(ToBool(value)); - } - if (key == "drop_remainder") { - (void)builder->SetDropRemainder(ToBool(value)); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - // Right now barrier should only take num_rows_per_buffer = 1 - // The reason for this is because having it otherwise can lead to blocking issues - // See barrier_op.h for more details - (void)builder->SetRowsPerBuffer(1); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "condition_name") { - (void)builder->SetConditionName(ToString(value)); - } else if (key == "condition_func") { - (void)builder->SetConditionFunc(value.cast()); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - int32_t prefetch_size = 0; - if (args.contains("prefetch_size")) { - if (args["prefetch_size"].is_none()) { - prefetch_size = 16; - } else { - prefetch_size = ToInt(args["prefetch_size"]); - } - } - std::shared_ptr builder = std::make_shared(prefetch_size); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "queue_name") { - (void)builder->SetChannelName(ToString(value)); - } else if (key == "device_type") { - (void)builder->SetDeviceType(ToString(value)); - } else if (key == "device_id") { - (void)builder->SetDeviceId(ToInt(value)); - } else if (key == "num_batch") { - (void)builder->SetNumBatch(ToInt(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector in_col_names; - std::vector out_col_names; - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "input_columns") { - in_col_names = ToStringVector(value); - } else if (key == "output_columns") { - out_col_names = ToStringVector(value); - } - } - } - if (in_col_names.empty() || in_col_names[0].empty()) { - std::string err_msg = "Error: input_column_names is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (out_col_names.empty() || out_col_names[0].empty()) { - std::string err_msg = "Error: output_column_names is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetInColNames(in_col_names); - (void)builder->SetOutColNames(out_col_names); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - std::vector files_list; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetDatasetFilesList(files_list); - } else { - std::string err_msg = "Error: at least one of dataset_files or schema_file is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::vector columns_to_load; - bool schema_exists = false; - bool shuffle_required = false; - int64_t num_devices = 0; - int64_t total_rows = 0; - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "columns_list") { - columns_to_load = ToStringVector(value); - (void)builder->SetColumnsToLoad(columns_to_load); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "schema_file_path" || key == "schema_json_string") { - schema_exists = true; - } else if (key == "num_samples") { - total_rows = ToInt(value); - (void)builder->setTotalRows(total_rows); - } else if (key == "num_shards") { - num_devices = ToInt(value); - (void)builder->SetNumDevices(num_devices); - } else if (key == "shard_id") { - (void)builder->SetDeviceId(ToInt(value)); - } else if (key == "shard_equal_rows") { - (void)builder->SetShardEqualRows(ToBool(value)); - } - } - } - if (schema_exists) { - std::unique_ptr schema = std::make_unique(); - if (args.contains("schema_file_path")) { - RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); - } else { - RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); - } - (void)builder->SetDataSchema(std::move(schema)); - } - std::shared_ptr tf_op; - RETURN_IF_NOT_OK(builder->Build(&tf_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op)); - *top = tf_op; - - if (shuffle_required) { - const boolean estimate = true; - const int64_t workers = 8; - std::shared_ptr shuffle_op = nullptr; - int64_t shuffle_size = 0; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset via estimate and then compute the shuffle size - RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, files_list, workers, estimate)); - RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, total_rows, &shuffle_size)); - - // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller - RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, tf_op, &shuffle_op)); - *top = shuffle_op; - *bottom = tf_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["columns"].is_none()) { - std::string err_msg = "Error: columns is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::vector columns_to_project = ToStringVector(args["columns"]); - std::shared_ptr builder = std::make_shared(columns_to_project); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr builder = std::make_shared(); - (void)builder->SetImageFolderDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "extensions") { - (void)builder->SetExtensions(ToStringSet(value)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_file"].is_none()) { - std::string err_msg = "Error: No dataset files specified for manifest"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr builder = std::make_shared(); - (void)builder->SetManifestFile(ToString(args["dataset_file"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "usage") { - (void)builder->SetUsage(ToString(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["task"].is_none()) { - std::string err_msg = "Error: No task specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["mode"].is_none()) { - std::string err_msg = "Error: No mode specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - (void)builder->SetTask(ToString(args["task"])); - (void)builder->SetMode(ToString(args["mode"])); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - - return Status::OK(); -} - -Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["annotation_file"].is_none()) { - std::string err_msg = "Error: No annotation_file specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["task"].is_none()) { - std::string err_msg = "Error: No task specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - (void)builder->SetFile(ToString(args["annotation_file"])); - (void)builder->SetTask(ToString(args["task"])); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetCifarDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - - (void)builder->SetCifarType(true); - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetCifarDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - - (void)builder->SetCifarType(false); - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - RandomDataOp::Builder builder; - - if (args["num_samples"].is_none()) { - std::string err_msg = "Error: num_samples is a required argument"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::vector columns_to_load; - bool schema_exists = false; - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (key == "num_parallel_workers") { - (void)builder.SetNumWorkers(ToInt(value)); - } else if (key == "schema_file_path" || key == "schema_json_string") { - schema_exists = true; - } else if (key == "columns_list") { - columns_to_load = ToStringVector(value); - } else if (key == "num_samples") { - // This is not sampling here. The random data op needs to know how much data to - // generate. It does not currently support sampling. - (void)builder.SetTotalRows(ToInt(value)); - } - } - if (schema_exists) { - std::unique_ptr schema = std::make_unique(); - if (args.contains("schema_file_path")) { - RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); - } else { - RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); - } - (void)builder.SetDataSchema(std::move(schema)); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder.Build(&op)); - *top = op; - return Status::OK(); -} - -int32_t DEPipeline::GetNumClasses() const { return num_classes_; } - -Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); - } - - std::shared_ptr builder = std::make_shared(); - if (builder == nullptr) { - std::string err_msg = "Create celebaop builder failed"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); - } - (void)builder->SetCelebADir(ToString(args["dataset_dir"])); - for (const auto &arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "extensions") { - (void)builder->SetExtensions(ToStringSet(value)); - } else if (key == "dataset_type") { - (void)builder->SetDatasetType(ToString(value)); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - std::vector files_list; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetTextFilesList(files_list); - } else { - RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); - } - // Optional arguments - bool shuffle_required = false; - int64_t num_devices = 0; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "num_samples") { - (void)builder->SetTotalRows(ToInt(value)); - } else if (key == "num_shards") { - num_devices = ToInt(value); - (void)builder->SetNumDevices(num_devices); - } else if (key == "shard_id") { - (void)builder->SetDeviceId(ToInt(value)); - } - } - } - - std::shared_ptr txt_op; - RETURN_IF_NOT_OK(builder->Build(&txt_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); - *top = txt_op; - - if (shuffle_required) { - std::shared_ptr shuffle_op = nullptr; - int64_t shuffle_size = 0; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset and then compute the shuffle size - RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(files_list, &num_rows)); - RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); - - // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller - RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, txt_op, &shuffle_op)); - *top = shuffle_op; - *bottom = txt_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { - for (auto p : py::reinterpret_borrow(value)) { - if (!p.second.is_none()) { - auto tp = py::reinterpret_borrow(p.second); - CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)"); - TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); - std::shared_ptr pad_val = nullptr; - if (py::isinstance(tp[1])) { - std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]); - CHECK_FAIL_RETURN_UNEXPECTED( - Tensor::CreateTensor(&pad_val, std::vector{pad_val_string}, TensorShape::CreateScalar()), - "Cannot create pad_value Tensor"); - } else { - float pad_val_float = tp[1].is_none() ? 0 : ToFloat(tp[1]); - CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(), - DataType(DataType::DE_FLOAT32)), - "Cannot create pad_value Tensor"); - pad_val->SetItemAt({}, pad_val_float); - } - (void)pad_info->insert({ToString(p.first), {shape, pad_val}}); - } else { // tuple is None - (void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}}); - } - } - return Status::OK(); -} - -Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "freq_range") { - py::tuple tp = py::reinterpret_borrow(value); - if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow(tp[0])); - if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow(tp[1])); - } else if (key == "top_k") { - builder->SetTopK(py::reinterpret_borrow(value)); - } else if (key == "columns") { - (void)builder->SetColumnNames(ToStringVector(value)); - } else if (key == "vocab") { - (void)builder->SetVocab(value.cast>()); - } else if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "special_first") { - (void)builder->SetSpecialFirst(ToBool(value)); - } else if (key == "special_tokens") { - (void)builder->SetSpecialTokens(ToStringVector(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector files_list; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetClueFilesList(files_list); - } else { - RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); - } - // Optional arguments - bool shuffle_required = false; - int64_t num_devices = 0; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "num_samples") { - (void)builder->SetNumSamples(ToInt(value)); - } else if (key == "num_shards") { - num_devices = ToInt(value); - (void)builder->SetNumDevices(num_devices); - } else if (key == "shard_id") { - (void)builder->SetDeviceId(ToInt(value)); - } else if (key == "cols_to_keyword") { - std::map map_dict; - for (auto p : py::reinterpret_borrow(value)) { - if (!p.second.is_none()) { - map_dict.insert({ToString(p.first), ToString(p.second)}); - } else { - map_dict.insert({ToString(p.first), ToString(p.first)}); - } - } - (void)builder->SetColsKeyMap(map_dict); - } - } - } - - std::shared_ptr clue_op; - RETURN_IF_NOT_OK(builder->Build(&clue_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); - *top = clue_op; - - if (shuffle_required) { - std::shared_ptr shuffle_op = nullptr; - int64_t shuffle_size = 0; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset and then compute the shuffle size - RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(files_list, &num_rows)); - RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); - - // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller - RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, clue_op, &shuffle_op)); - *top = shuffle_op; - *bottom = clue_op; - } - - return Status::OK(); -} - -// Helper function to inject a shuffle operator over top of the current operation being built. -Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, - std::shared_ptr *shuffle_op) { - std::shared_ptr new_shuffle_op = nullptr; - ShuffleOp::Builder shuffle_builder; - - (void)shuffle_builder.SetShuffleSize(shuffle_size); - RETURN_IF_NOT_OK(shuffle_builder.Build(&new_shuffle_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(new_shuffle_op)); - RETURN_IF_NOT_OK(new_shuffle_op->AddChild(input_op)); - // We have now created: - // - // ShuffleOp - // | - // input_op - // - *shuffle_op = new_shuffle_op; - - return Status::OK(); -} - -// Common code for computing a default shuffle size -Status DEPipeline::ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, - int64_t *shuffle_size) { - const int64_t average_files_multiplier = 4; - const int64_t shuffle_max = 10000; - int64_t avg_rows_per_file = 0; - - // Adjust the num rows per shard if sharding was given - if (num_devices > 0) { - if (num_rows % num_devices == 0) { - num_rows = num_rows / num_devices; - } else { - num_rows = (num_rows / num_devices) + 1; - } - } - - // Cap based on total rows directive. Some ops do not have this and give value of 0. - if (total_rows > 0) { - num_rows = std::min(num_rows, total_rows); - } - - // get the average per file - avg_rows_per_file = num_rows / num_files; - - *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h deleted file mode 100644 index 7cfc73307c..0000000000 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_API_DE_PIPELINE_H_ -#define DATASET_API_DE_PIPELINE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "dataset/core/client.h" // DE client -#include "dataset/engine/dataset_iterator.h" -#include "dataset/util/status.h" -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { -using DsOpPtr = std::shared_ptr; - -// enum for the dataset operator names -enum OpName { - kShuffle, - kMindrecord, - kBatch, - kBucketBatch, - kBarrier, - kCache, - kRepeat, - kSkip, - kTake, - kZip, - kConcat, - kMap, - kFilter, - kDeviceQueue, - kGenerator, - kRename, - kTfReader, - kProject, - kImageFolder, - kMnist, - kManifest, - kVoc, - kCoco, - kCifar10, - kCifar100, - kCelebA, - kRandomData, - kTextFile, - kBuildVocab, - kClue -}; - -// The C++ binder class that we expose to the python script. -class DEPipeline { - public: - DEPipeline(); - - ~DEPipeline(); - - // Function to add a Node to the Execution Tree. - Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); - - // Function to add a child and parent relationship. - static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); - - // Function to assign the node as root. - Status AssignRootNode(const DsOpPtr &dataset_op); - - // Function to launch the tree execution. - Status LaunchTreeExec(); - - // Get a row of data as dictionary of column name to the value. - Status GetNextAsMap(py::dict *output); - - // Get a row of data as list. - Status GetNextAsList(py::list *output); - - Status GetOutputShapes(py::list *output); - - Status GetOutputTypes(py::list *output); - - int GetDatasetSize() const; - - int GetBatchSize() const; - - int GetRepeatCount() const; - - Status ParseShuffleOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status BuildMindrecordSamplerChain(const py::handle &handle, - std::vector> *operators, - int num_padded); - - Status ParseMapOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseFilterOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRepeatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseSkipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBatchOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom); - - Status ParseBarrierOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRenameOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTakeOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseZipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseConcatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseProjectOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseManifestOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseVOCOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCocoOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCifar10Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCifar100Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - void PrintTree(); - - int32_t GetNumClasses() const; - - Status ParseMnistOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status SetBatchParameters(const py::dict &args); - - Status ParseCelebAOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTextFileOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - private: - // Execution tree that links the dataset operators. - std::shared_ptr tree_; - - std::unique_ptr iterator_; - - static Status ParsePadInfo(py::handle value, PadInfo *pad_info); - - /// \brief Helper function to inject a shuffle operator over top of the current operation being built. - /// \param[in] shuffle_size The size to use in the shuffle buffer - /// \param[in] input_op The operator to build shuffle on top of - /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be - /// the shuffle operator - /// \return Status return code - Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, - std::shared_ptr *shuffle_op); - - /// \brief Helper function to compute the shuffle size - /// \param[in] num_files The number of files in the dataset - /// \param[in] num_devices The number of devices in the dataset - /// \param[in] num_rows The number of rows in the dataset - /// \param[in] total_rows An upper bound on the total rows in the dataset - /// \param[out] shuffle_size The resultant computed shuffle size - /// \return Status return code - Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, - int64_t *shuffle_size); - - int batch_size_; - int repeat_num_; - int num_rows_; - int num_classes_; - - int temp_batch_size_; - bool temp_drop_remainder_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_API_DE_PIPELINE_H_ diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc deleted file mode 100644 index ed3f993fb8..0000000000 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ /dev/null @@ -1,919 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "dataset/api/de_pipeline.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/mindrecord_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "dataset/engine/datasetops/source/sampler/python_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/engine/gnn/graph.h" -#include "dataset/engine/jagged_connector.h" -#include "dataset/kernels/data/concatenate_op.h" -#include "dataset/kernels/data/duplicate_op.h" -#include "dataset/kernels/data/fill_op.h" -#include "dataset/kernels/data/mask_op.h" -#include "dataset/kernels/data/one_hot_op.h" -#include "dataset/kernels/data/pad_end_op.h" -#include "dataset/kernels/data/slice_op.h" -#include "dataset/kernels/data/to_float16_op.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/kernels/image/bounding_box_augment_op.h" -#include "dataset/kernels/image/center_crop_op.h" -#include "dataset/kernels/image/cut_out_op.h" -#include "dataset/kernels/image/decode_op.h" -#include "dataset/kernels/image/hwc_to_chw_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/normalize_op.h" -#include "dataset/kernels/image/pad_op.h" -#include "dataset/kernels/image/random_color_adjust_op.h" -#include "dataset/kernels/image/random_crop_and_resize_op.h" -#include "dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" -#include "dataset/kernels/image/random_crop_decode_resize_op.h" -#include "dataset/kernels/image/random_crop_op.h" -#include "dataset/kernels/image/random_crop_with_bbox_op.h" -#include "dataset/kernels/image/random_horizontal_flip_bbox_op.h" -#include "dataset/kernels/image/random_horizontal_flip_op.h" -#include "dataset/kernels/image/random_resize_op.h" -#include "dataset/kernels/image/random_resize_with_bbox_op.h" -#include "dataset/kernels/image/random_rotation_op.h" -#include "dataset/kernels/image/random_vertical_flip_op.h" -#include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h" -#include "dataset/kernels/image/rescale_op.h" -#include "dataset/kernels/image/resize_bilinear_op.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/image/resize_with_bbox_op.h" -#include "dataset/kernels/image/uniform_aug_op.h" -#include "dataset/kernels/no_op.h" -#include "dataset/text/kernels/jieba_tokenizer_op.h" -#include "dataset/text/kernels/lookup_op.h" -#include "dataset/text/kernels/ngram_op.h" -#include "dataset/text/kernels/to_number_op.h" -#include "dataset/text/kernels/unicode_char_tokenizer_op.h" -#include "dataset/text/kernels/wordpiece_tokenizer_op.h" -#include "dataset/text/vocab.h" -#include "dataset/util/random.h" -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_pk_sample.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_sequential_sample.h" -#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "pybind11/stl_bind.h" - -#ifdef ENABLE_ICU4C -#include "dataset/text/kernels/basic_tokenizer_op.h" -#include "dataset/text/kernels/bert_tokenizer_op.h" -#include "dataset/text/kernels/case_fold_op.h" -#include "dataset/text/kernels/normalize_utf8_op.h" -#include "dataset/text/kernels/regex_replace_op.h" -#include "dataset/text/kernels/regex_tokenizer_op.h" -#include "dataset/text/kernels/unicode_script_tokenizer_op.h" -#include "dataset/text/kernels/whitespace_tokenizer_op.h" -#endif - -namespace py = pybind11; - -namespace mindspore { -namespace dataset { -#define THROW_IF_ERROR(s) \ - do { \ - Status rc = std::move(s); \ - if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ - } while (false) - -void bindDEPipeline(py::module *m) { - (void)py::class_(*m, "DEPipeline") - .def(py::init<>()) - .def( - "AddNodeToTree", - [](DEPipeline &de, const OpName &op_name, const py::dict &args) { - py::dict out; - THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); - return out; - }, - py::return_value_policy::reference) - .def_static("AddChildToParentNode", - [](const DsOpPtr &child_op, const DsOpPtr &parent_op) { - THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op)); - }) - .def("AssignRootNode", - [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) - .def("SetBatchParameters", - [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) - .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) - .def("GetNextAsMap", - [](DEPipeline &de) { - py::dict out; - THROW_IF_ERROR(de.GetNextAsMap(&out)); - return out; - }) - .def("GetNextAsList", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetNextAsList(&out)); - return out; - }) - .def("GetOutputShapes", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetOutputShapes(&out)); - return out; - }) - .def("GetOutputTypes", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetOutputTypes(&out)); - return out; - }) - .def("GetDatasetSize", &DEPipeline::GetDatasetSize) - .def("GetBatchSize", &DEPipeline::GetBatchSize) - .def("GetNumClasses", &DEPipeline::GetNumClasses) - .def("GetRepeatCount", &DEPipeline::GetRepeatCount); -} -void bindDatasetOps(py::module *m) { - (void)py::class_>(*m, "TFReaderOp") - .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) { - int64_t count = 0; - std::vector filenames; - for (auto l : files) { - !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back(""); - } - THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate)); - return count; - }); - - (void)py::class_>(*m, "CifarOp") - .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) { - int64_t count = 0; - THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count)); - return count; - }); - - (void)py::class_>(*m, "ImageFolderOp") - .def_static("get_num_rows_and_classes", [](const std::string &path) { - int64_t count = 0, num_classes = 0; - THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set{}, &count, &num_classes)); - return py::make_tuple(count, num_classes); - }); - - (void)py::class_>(*m, "MindRecordOp") - .def_static("get_num_rows", [](const std::vector &paths, bool load_dataset, const py::object &sampler, - const int64_t num_padded) { - int64_t count = 0; - std::shared_ptr op; - if (py::hasattr(sampler, "create_for_minddataset")) { - auto create = sampler.attr("create_for_minddataset"); - op = create().cast>(); - } - THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded)); - return count; - }); - - (void)py::class_>(*m, "ManifestOp") - .def_static("get_num_rows_and_classes", - [](const std::string &file, const py::dict &dict, const std::string &usage) { - int64_t count = 0, num_classes = 0; - THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes)); - return py::make_tuple(count, num_classes); - }) - .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) { - std::map output_class_indexing; - THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing)); - return output_class_indexing; - }); - - (void)py::class_>(*m, "MnistOp") - .def_static("get_num_rows", [](const std::string &dir) { - int64_t count = 0; - THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count)); - return count; - }); - - (void)py::class_>(*m, "TextFileOp") - .def_static("get_num_rows", [](const py::list &files) { - int64_t count = 0; - std::vector filenames; - for (auto file : files) { - !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); - } - THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); - return count; - }); - - (void)py::class_>(*m, "ClueOp") - .def_static("get_num_rows", [](const py::list &files) { - int64_t count = 0; - std::vector filenames; - for (auto file : files) { - file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file)); - } - THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count)); - return count; - }); - - (void)py::class_>(*m, "VOCOp") - .def_static("get_num_rows", - [](const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t numSamples) { - int64_t count = 0; - THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count)); - return count; - }) - .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, - const std::string &task_mode, const py::dict &dict) { - std::map output_class_indexing; - THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing)); - return output_class_indexing; - }); - (void)py::class_>(*m, "CocoOp") - .def_static("get_class_indexing", - [](const std::string &dir, const std::string &file, const std::string &task) { - std::vector>> output_class_indexing; - THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing)); - return output_class_indexing; - }) - .def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) { - int64_t count = 0; - THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count)); - return count; - }); -} -void bindTensor(py::module *m) { - (void)py::class_(*m, "GlobalContext") - .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference); - - (void)py::class_>(*m, "ConfigManager") - .def("__str__", &ConfigManager::ToString) - .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer) - .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) - .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) - .def("set_op_connector_size", &ConfigManager::set_op_connector_size) - .def("set_seed", &ConfigManager::set_seed) - .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) - .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer) - .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers) - .def("get_worker_connector_size", &ConfigManager::worker_connector_size) - .def("get_op_connector_size", &ConfigManager::op_connector_size) - .def("get_seed", &ConfigManager::seed) - .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) - .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); - - (void)py::class_>(*m, "Tensor", py::buffer_protocol()) - .def(py::init([](py::array arr) { - std::shared_ptr out; - THROW_IF_ERROR(Tensor::CreateTensor(&out, arr)); - return out; - })) - .def_buffer([](Tensor &tensor) { - py::buffer_info info; - THROW_IF_ERROR(Tensor::GetBufferInfo(tensor, &info)); - return info; - }) - .def("__str__", &Tensor::ToString) - .def("shape", &Tensor::shape) - .def("type", &Tensor::type) - .def("as_array", [](py::object &t) { - auto &tensor = py::cast(t); - if (tensor.type() == DataType::DE_STRING) { - py::array res; - tensor.GetDataAsNumpyStrings(&res); - return res; - } - py::buffer_info info; - THROW_IF_ERROR(Tensor::GetBufferInfo(tensor, &info)); - return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t); - }); - - (void)py::class_(*m, "TensorShape") - .def(py::init()) - .def("__str__", &TensorShape::ToString) - .def("as_list", &TensorShape::AsPyList) - .def("is_known", &TensorShape::known); - - (void)py::class_(*m, "DataType") - .def(py::init()) - .def(py::self == py::self) - .def("__str__", &DataType::ToString) - .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); -} - -void bindTensorOps1(py::module *m) { - (void)py::class_>(*m, "TensorOp") - .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); - - (void)py::class_>( - *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.") - .def(py::init(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), - py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); - - (void)py::class_>( - *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.") - .def(py::init(), py::arg("rescale"), py::arg("shift")); - - (void)py::class_>( - *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)") - .def(py::init(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth); - - (void)py::class_>( - *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); - - (void)py::class_>( - *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth, - py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation); - - (void)py::class_>( - *m, "RandomResizeWithBBoxOp", - "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth); - - (void)py::class_>( - *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") - .def(py::init>, int32_t>(), py::arg("operations"), - py::arg("NumOps") = UniformAugOp::kDefNumOps); - - (void)py::class_>( - *m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.") - .def(py::init, float>(), py::arg("transform"), - py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio); - - (void)py::class_>( - *m, "ResizeBilinearOp", - "Tensor operation to resize an image using " - "Bilinear mode. Takes height and width.") - .def(py::init(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth); - - (void)py::class_>(*m, "DecodeOp", - "Tensor operation to decode a jpg image") - .def(py::init<>()) - .def(py::init(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat); - - (void)py::class_>( - *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.") - .def(py::init(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability); - - (void)py::class_>( - *m, "RandomHorizontalFlipWithBBoxOp", - "Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.") - .def(py::init(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability); -} - -void bindTensorOps2(py::module *m) { - (void)py::class_>( - *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.") - .def(py::init(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability); - - (void)py::class_>( - *m, "RandomVerticalFlipWithBBoxOp", - "Tensor operation to randomly flip an image vertically" - " and adjust bounding boxes.") - .def(py::init(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability); - - (void)py::class_>(*m, "RandomCropOp", - "Gives random crop of specified size " - "Takes crop size") - .def(py::init(), - py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop, - py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft, - py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType, - py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR, - py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB); - (void)py::class_>(*m, "ChannelSwapOp").def(py::init<>()); - - (void)py::class_>(*m, "RandomCropWithBBoxOp", - "Gives random crop of given " - "size + adjusts bboxes " - "Takes crop size") - .def(py::init(), - py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop, - py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom, - py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft, - py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight, - py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType, - py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded, - py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG, - py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB); - - (void)py::class_>( - *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.") - .def(py::init()); - - (void)py::class_>( - *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") - .def(py::init>()); - - (void)py::class_>(*m, "SliceOp", "Tensor slice operation.") - .def(py::init()) - .def(py::init([](const py::list &py_list) { - std::vector c_list; - for (auto l : py_list) { - if (!l.is_none()) { - c_list.push_back(py::reinterpret_borrow(l)); - } - } - return std::make_shared(c_list); - })) - .def(py::init([](const py::tuple &py_slice) { - if (py_slice.size() != 3) { - THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); - } - Slice c_slice; - if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1]), - py::reinterpret_borrow(py_slice[2])); - } else if (py_slice[0].is_none() && py_slice[2].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[1])); - } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1])); - } - - if (!c_slice.valid()) { - THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); - } - return std::make_shared(c_slice); - })); - - (void)py::enum_(*m, "RelationalOp", py::arithmetic()) - .value("EQ", RelationalOp::kEqual) - .value("NE", RelationalOp::kNotEqual) - .value("LT", RelationalOp::kLess) - .value("LE", RelationalOp::kLessEqual) - .value("GT", RelationalOp::kGreater) - .value("GE", RelationalOp::kGreaterEqual) - .export_values(); - - (void)py::class_>(*m, "MaskOp", - "Tensor mask operation using relational comparator") - .def(py::init, DataType>()); - - (void)py::class_>(*m, "DuplicateOp", "Duplicate tensor.") - .def(py::init<>()); - - (void)py::class_>( - *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") - .def(py::init()); - - (void)py::class_>(*m, "ConcatenateOp", - "Tensor operation concatenate tensors.") - .def(py::init, std::shared_ptr>(), py::arg("axis"), - py::arg("prepend").none(true), py::arg("append").none(true)); - - (void)py::class_>( - *m, "RandomRotationOp", - "Tensor operation to apply RandomRotation." - "Takes a range for degrees and " - "optional parameters for rotation center and image expand") - .def(py::init(), - py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX, - py::arg("centerY") = RandomRotationOp::kDefCenterY, - py::arg("interpolation") = RandomRotationOp::kDefInterpolation, - py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR, - py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB); - - (void)py::class_>( - *m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.") - .def(py::init>()); -} - -void bindTensorOps3(py::module *m) { - (void)py::class_>( - *m, "RandomCropAndResizeOp", - "Tensor operation to randomly crop an image and resize to a given size." - "Takes output height and width and" - "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," - "interpolation mode, and max attempts to crop") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb, - py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation, - py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter); - - (void)py::class_>( - *m, "RandomCropAndResizeWithBBoxOp", - "Tensor operation to randomly crop an image (with BBoxes) and resize to a given size." - "Takes output height and width and" - "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," - "interpolation mode, and max attempts to crop") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb, - py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation, - py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter); - - (void)py::class_>( - *m, "RandomColorAdjustOp", - "Tensor operation to adjust an image's color randomly." - "Takes range for brightness, contrast, saturation, hue and") - .def(py::init(), py::arg("bright_factor_start"), - py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"), - py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"), - py::arg("hue_factor_end")); - - (void)py::class_>( - *m, "RandomResizeOp", - "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); - - (void)py::class_>( - *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. Takes height and width.") - .def(py::init(), py::arg("boxHeight"), - py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor, - py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG, - py::arg("fillB") = CutOutOp::kDefFillB); -} - -void bindTensorOps4(py::module *m) { - (void)py::class_>( - *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") - .def(py::init(), py::arg("data_type")) - .def(py::init(), py::arg("data_type")); - - (void)py::class_>(*m, "NoOp", - "TensorOp that does nothing, for testing purposes only.") - .def(py::init<>()); - - (void)py::class_>( - *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.") - .def(py::init<>()); - - (void)py::class_>( - *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb, - py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation, - py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter); - - (void)py::class_>( - *m, "PadOp", - "Pads image with specified color, default black, " - "Takes amount to pad for top, bottom, left, right of image, boarder type and color") - .def(py::init(), py::arg("padTop"), - py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType, - py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB); - (void)py::class_>(*m, "ToNumberOp", - "TensorOp to convert strings to numbers.") - .def(py::init(), py::arg("data_type")) - .def(py::init(), py::arg("data_type")); -} - -void bindTokenizerOps(py::module *m) { - (void)py::class_>(*m, "JiebaTokenizerOp", "") - .def(py::init(), py::arg("hmm_path"), py::arg("mp_path"), - py::arg("mode") = JiebaMode::kMix) - .def("add_word", - [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); }); - (void)py::class_>( - *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.") - .def(py::init<>()); - (void)py::class_>(*m, "LookupOp", - "Tensor operation to LookUp each word") - .def(py::init, WordIdType>(), py::arg("vocab"), py::arg("unknown")) - .def(py::init>(), py::arg("vocab")); - (void)py::class_>(*m, "NgramOp", "TensorOp performs ngram mapping") - .def(py::init &, int32_t, int32_t, const std::string &, const std::string &, - const std::string &>(), - py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"), - py::arg("separator")); - (void)py::class_>( - *m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.") - .def(py::init &, const std::string &, const int &, const std::string &>(), - py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), - py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, - py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken)); -} - -void bindDependIcuTokenizerOps(py::module *m) { -#ifdef ENABLE_ICU4C - (void)py::class_>( - *m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.") - .def(py::init<>()); - (void)py::class_>( - *m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.") - .def(py::init<>()) - .def(py::init(), py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace); - (void)py::class_>( - *m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor") - .def(py::init<>()); - (void)py::class_>( - *m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.") - .def(py::init<>()) - .def(py::init(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm); - (void)py::class_>( - *m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.") - .def(py::init(), py::arg("pattern"), py::arg("replace"), - py::arg("replace_all")); - (void)py::class_>( - *m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.") - .def(py::init(), py::arg("delim_pattern"), py::arg("keep_delim_pattern")); - (void)py::class_>( - *m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.") - .def(py::init(), py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, - py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, - py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, - py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken); - (void)py::class_>(*m, "BertTokenizerOp", - "Tokenizer used for Bert text process.") - .def(py::init &, const std::string &, const int &, const std::string &, bool, bool, - NormalizeForm, bool>(), - py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), - py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, - py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), - py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, - py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, - py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, - py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken); -#endif -} - -void bindSamplerOps(py::module *m) { - (void)py::class_>(*m, "Sampler") - .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) - .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) - .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) - .def("get_indices", - [](Sampler &self) { - py::array ret; - THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); - return ret; - }) - .def("add_child", - [](std::shared_ptr self, std::shared_ptr child) { THROW_IF_ERROR(self->AddChild(child)); }); - - (void)py::class_>(*m, "ShardOperator") - .def("add_child", [](std::shared_ptr self, - std::shared_ptr child) { self->SetChildOp(child); }); - - (void)py::class_>(*m, "DistributedSampler") - .def(py::init()); - - (void)py::class_>(*m, "PKSampler") - .def(py::init()); - - (void)py::class_>(*m, "RandomSampler") - .def(py::init()); - - (void)py::class_>(*m, "SequentialSampler") - .def(py::init()); - - (void)py::class_>(*m, "SubsetRandomSampler") - .def(py::init>()); - - (void)py::class_>( - *m, "MindrecordSubsetRandomSampler") - .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); - - (void)py::class_>( - *m, "MindrecordPkSampler") - .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { - if (shuffle == true) { - return std::make_shared(kColumn, kVal, std::numeric_limits::max(), - GetSeed()); - } else { - return std::make_shared(kColumn, kVal); - } - })); - - (void)py::class_>(*m, "MindrecordDistributedSampler") - .def(py::init()); - - (void)py::class_>( - *m, "MindrecordRandomSampler") - .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) { - return std::make_shared(GetSeed(), num_samples, replacement, reshuffle_each_epoch); - })); - - (void)py::class_>(*m, "MindrecordSequentialSampler") - .def(py::init([](int num_samples, int start_index) { - return std::make_shared(num_samples, start_index); - })); - - (void)py::class_>(*m, "WeightedRandomSampler") - .def(py::init, bool>()); - - (void)py::class_>(*m, "PythonSampler") - .def(py::init()); -} - -void bindInfoObjects(py::module *m) { - (void)py::class_(*m, "CBatchInfo") - .def(py::init()) - .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num) - .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); -} - -void bindVocabObjects(py::module *m) { - (void)py::class_>(*m, "Vocab") - .def(py::init<>()) - .def_static("from_list", - [](const py::list &words, const py::list &special_tokens, bool special_first) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v)); - return v; - }) - .def_static("from_file", - [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens, - bool special_first) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v)); - return v; - }) - .def_static("from_dict", [](const py::dict &words) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v)); - return v; - }); -} - -void bindGraphData(py::module *m) { - (void)py::class_>(*m, "Graph") - .def(py::init([](std::string dataset_file, int32_t num_workers) { - std::shared_ptr g_out = std::make_shared(dataset_file, num_workers); - THROW_IF_ERROR(g_out->Init()); - return g_out; - })) - .def("get_all_nodes", - [](gnn::Graph &g, gnn::NodeType node_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); - return out; - }) - .def("get_all_edges", - [](gnn::Graph &g, gnn::EdgeType edge_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); - return out; - }) - .def("get_nodes_from_edges", - [](gnn::Graph &g, std::vector edge_list) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); - return out; - }) - .def("get_all_neighbors", - [](gnn::Graph &g, std::vector node_list, gnn::NodeType neighbor_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); - return out; - }) - .def("get_sampled_neighbors", - [](gnn::Graph &g, std::vector node_list, std::vector neighbor_nums, - std::vector neighbor_types) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); - return out; - }) - .def("get_neg_sampled_neighbors", - [](gnn::Graph &g, std::vector node_list, gnn::NodeIdType neighbor_num, - gnn::NodeType neg_neighbor_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); - return out; - }) - .def("get_node_feature", - [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { - TensorRow out; - THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); - return out.getRow(); - }) - .def("graph_info", - [](gnn::Graph &g) { - py::dict out; - THROW_IF_ERROR(g.GraphInfo(&out)); - return out; - }) - .def("random_walk", [](gnn::Graph &g, std::vector node_list, std::vector meta_path, - float step_home_param, float step_away_param, gnn::NodeIdType default_node) { - std::shared_ptr out; - THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); - return out; - }); -} - -// This is where we externalize the C logic as python modules -PYBIND11_MODULE(_c_dataengine, m) { - m.doc() = "pybind11 for _c_dataengine"; - (void)py::class_>(m, "DatasetOp"); - - (void)py::enum_(m, "OpName", py::arithmetic()) - .value("SHUFFLE", OpName::kShuffle) - .value("BATCH", OpName::kBatch) - .value("BUCKETBATCH", OpName::kBucketBatch) - .value("BARRIER", OpName::kBarrier) - .value("MINDRECORD", OpName::kMindrecord) - .value("CACHE", OpName::kCache) - .value("REPEAT", OpName::kRepeat) - .value("SKIP", OpName::kSkip) - .value("TAKE", OpName::kTake) - .value("ZIP", OpName::kZip) - .value("CONCAT", OpName::kConcat) - .value("MAP", OpName::kMap) - .value("FILTER", OpName::kFilter) - .value("DEVICEQUEUE", OpName::kDeviceQueue) - .value("GENERATOR", OpName::kGenerator) - .export_values() - .value("RENAME", OpName::kRename) - .value("TFREADER", OpName::kTfReader) - .value("PROJECT", OpName::kProject) - .value("IMAGEFOLDER", OpName::kImageFolder) - .value("MNIST", OpName::kMnist) - .value("MANIFEST", OpName::kManifest) - .value("VOC", OpName::kVoc) - .value("COCO", OpName::kCoco) - .value("CIFAR10", OpName::kCifar10) - .value("CIFAR100", OpName::kCifar100) - .value("RANDOMDATA", OpName::kRandomData) - .value("BUILDVOCAB", OpName::kBuildVocab) - .value("CELEBA", OpName::kCelebA) - .value("TEXTFILE", OpName::kTextFile) - .value("CLUE", OpName::kClue); - - (void)py::enum_(m, "JiebaMode", py::arithmetic()) - .value("DE_JIEBA_MIX", JiebaMode::kMix) - .value("DE_JIEBA_MP", JiebaMode::kMp) - .value("DE_JIEBA_HMM", JiebaMode::kHmm) - .export_values(); - -#ifdef ENABLE_ICU4C - (void)py::enum_(m, "NormalizeForm", py::arithmetic()) - .value("DE_NORMALIZE_NONE", NormalizeForm::kNone) - .value("DE_NORMALIZE_NFC", NormalizeForm::kNfc) - .value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc) - .value("DE_NORMALIZE_NFD", NormalizeForm::kNfd) - .value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd) - .export_values(); -#endif - - (void)py::enum_(m, "InterpolationMode", py::arithmetic()) - .value("DE_INTER_LINEAR", InterpolationMode::kLinear) - .value("DE_INTER_CUBIC", InterpolationMode::kCubic) - .value("DE_INTER_AREA", InterpolationMode::kArea) - .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) - .export_values(); - - (void)py::enum_(m, "BorderType", py::arithmetic()) - .value("DE_BORDER_CONSTANT", BorderType::kConstant) - .value("DE_BORDER_EDGE", BorderType::kEdge) - .value("DE_BORDER_REFLECT", BorderType::kReflect) - .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric) - .export_values(); - bindDEPipeline(&m); - bindTensor(&m); - bindTensorOps1(&m); - bindTensorOps2(&m); - bindTensorOps3(&m); - bindTensorOps4(&m); - bindTokenizerOps(&m); - bindSamplerOps(&m); - bindDatasetOps(&m); - bindInfoObjects(&m); - bindVocabObjects(&m); - bindGraphData(&m); - bindDependIcuTokenizerOps(&m); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/CMakeLists.txt b/mindspore/ccsrc/dataset/core/CMakeLists.txt deleted file mode 100644 index 27b9f0e13b..0000000000 --- a/mindspore/ccsrc/dataset/core/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -ms_protobuf_generate(EXAMPLE_SRCS EXAMPLE_HDRS example.proto) -ms_protobuf_generate(FEATURE_SRCS FEATURE_HDRS feature.proto) -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(core OBJECT - ${EXAMPLE_SRCS} - ${FEATURE_SRCS} - client.cc - config_manager.cc - cv_tensor.cc - data_type.cc - global_context.cc - tensor.cc - tensor_row.cc - tensor_shape.cc - ) -add_dependencies(core mindspore::protobuf) -target_include_directories(core PRIVATE ${pybind11_INCLUDE_DIRS}) diff --git a/mindspore/ccsrc/dataset/core/client.cc b/mindspore/ccsrc/dataset/core/client.cc deleted file mode 100644 index 6247ddae7d..0000000000 --- a/mindspore/ccsrc/dataset/core/client.cc +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/client.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "dataset/util/services.h" -#include "dataset/util/sig_handler.h" - -namespace mindspore { -namespace dataset { -// This is a one-time global initializer which includes the call to instantiate singletons. -// It is external api call and not a member of the GlobalContext directly. -Status GlobalInit() { - // Bring up all the services (logger, task, bufferpool) - return (Services::CreateInstance()); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h deleted file mode 100644 index a10cb4596e..0000000000 --- a/mindspore/ccsrc/dataset/core/client.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_CLIENT_H_ -#define DATASET_CORE_CLIENT_H_ - -// client.h -// Include file for DE client functions - -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/barrier_op.h" -#include "dataset/engine/datasetops/batch_op.h" -#include "dataset/engine/datasetops/build_vocab_op.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/datasetops/device_queue_op.h" -#include "dataset/engine/datasetops/map_op.h" -#include "dataset/engine/datasetops/project_op.h" -#include "dataset/engine/datasetops/rename_op.h" -#include "dataset/engine/datasetops/filter_op.h" -#include "dataset/engine/datasetops/repeat_op.h" -#include "dataset/engine/datasetops/skip_op.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/datasetops/source/generator_op.h" -#include "dataset/engine/datasetops/source/mindrecord_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" -#include "dataset/engine/datasetops/take_op.h" -#include "dataset/engine/datasetops/zip_op.h" -#include "dataset/engine/datasetops/concat_op.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// This is a one-time global initializer that needs to be called at the -// start of any minddata applications. -extern Status GlobalInit(); -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CORE_CLIENT_H_ diff --git a/mindspore/ccsrc/dataset/core/config_manager.cc b/mindspore/ccsrc/dataset/core/config_manager.cc deleted file mode 100644 index 9291a8f832..0000000000 --- a/mindspore/ccsrc/dataset/core/config_manager.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/config_manager.h" - -#include -#include -#include - -#include "dataset/util/system_pool.h" - -namespace mindspore { -namespace dataset { -// A print method typically used for debugging -void ConfigManager::Print(std::ostream &out) const { - // Don't show the test/internal ones. Only display the main ones here. - // fyi, boolalpha tells the output stream to write "true" and "false" for bools - out << "\nClient config settings :" - << "\nDataCache Rows per buffer : " << rows_per_buffer_ - << "\nParallelOp workers : " << num_parallel_workers_ - << "\nParallelOp worker connector size : " << worker_connector_size_ - << "\nSize of each Connector : " << op_connector_size_ << std::endl; -} - -// Private helper function that taks a nlohmann json format and populates the settings -Status ConfigManager::FromJson(const nlohmann::json &j) { - set_rows_per_buffer(j.value("rowsPerBuffer", rows_per_buffer_)); - set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_)); - set_worker_connector_size(j.value("workerConnectorSize", worker_connector_size_)); - set_op_connector_size(j.value("opConnectorSize", op_connector_size_)); - set_seed(j.value("seed", seed_)); - set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); - return Status::OK(); -} - -// Loads a json file with the default settings and populates all the settings -Status ConfigManager::LoadFile(const std::string &settingsFile) { - Status rc; - if (!Path(settingsFile).Exists()) { - RETURN_STATUS_UNEXPECTED("File is not found."); - } - // Some settings are mandatory, others are not (with default). If a setting - // is optional it will set a default value if the config is missing from the file. - try { - std::ifstream in(settingsFile); - nlohmann::json js; - in >> js; - rc = FromJson(js); - } catch (const nlohmann::json::type_error &e) { - std::ostringstream ss; - ss << "Client file failed to load:\n" << e.what(); - std::string err_msg = ss.str(); - RETURN_STATUS_UNEXPECTED(err_msg); - } catch (const std::exception &err) { - RETURN_STATUS_UNEXPECTED("Client file failed to load."); - } - return rc; -} - -// Setter function -void ConfigManager::set_rows_per_buffer(int32_t rows_per_buffer) { rows_per_buffer_ = rows_per_buffer; } - -// Setter function -void ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) { - num_parallel_workers_ = num_parallel_workers; -} - -// Setter function -void ConfigManager::set_worker_connector_size(int32_t connector_size) { worker_connector_size_ = connector_size; } - -// Setter function -void ConfigManager::set_op_connector_size(int32_t connector_size) { op_connector_size_ = connector_size; } - -uint32_t ConfigManager::seed() const { return seed_; } - -void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; } - -void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/config_manager.h b/mindspore/ccsrc/dataset/core/config_manager.h deleted file mode 100644 index 807591daa1..0000000000 --- a/mindspore/ccsrc/dataset/core/config_manager.h +++ /dev/null @@ -1,137 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_CONFIG_MANAGER_H_ -#define DATASET_CORE_CONFIG_MANAGER_H_ - -#include -#include -#include - -#include - -#include "dataset/core/constants.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" - -// Config settings for the client-side -// example config file: -// { -// "rowsPerBuffer": 3 -// } -// - -namespace mindspore { -namespace dataset { -// The ConfigManager is a class for managing default values. When a user is constructing any objects -// in the framework, often they may choose to omit some settings instead of overriding them. -// This class manages some of the default values, for cases when the user does not manually specify -// those values. -class ConfigManager { - public: - ConfigManager() = default; - - // destructor - ~ConfigManager() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - void Print(std::ostream &out) const; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param cS - reference to the ConfigManager to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const ConfigManager &cS) { - cS.Print(out); - return out; - } - - // Another debug print helper. Converts the print info to a string for you. - // @return The string version of the debug print - std::string ToString() { - std::stringstream ss; - ss << *this; - return ss.str(); - } - - // Loads a json file with the default settings and populates all the settings - // @param settingsFile - A json file with a set of default settings - // @return Status error code - Status LoadFile(const std::string &settingsFile); - - // getter function - // @return The rows per buffer setting - int32_t rows_per_buffer() const { return rows_per_buffer_; } - - // getter function - // @return The number of workers setting - int32_t num_parallel_workers() const { return num_parallel_workers_; } - - // getter function - // @return The queue size of the operator's output connector - int32_t op_connector_size() const { return op_connector_size_; } - - // getter function - // @return The internal worker-to-master connector queue size - int32_t worker_connector_size() const { return worker_connector_size_; } - - // setter function - // @param rows_per_buffer - The setting to apply to the config - void set_rows_per_buffer(int32_t rows_per_buffer); - - // setter function - // @param num_parallel_workers - The setting to apply to the config - void set_num_parallel_workers(int32_t num_parallel_workers); - - // setter function - // @param connector_size - The setting to apply to the config - void set_worker_connector_size(int32_t connector_size); - - // setter function - // @param connector_size - The setting to apply to the config - void set_op_connector_size(int32_t connector_size); - - uint32_t seed() const; - - // setter function - // @param seed - The default seed to use - void set_seed(uint32_t seed); - - // setter function - // @param interval - The setting to apply to the config - void set_monitor_sampling_interval(uint32_t interval); - - // getter function - // @return The iterval of monitor sampling - int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; } - - private: - int32_t rows_per_buffer_{kCfgRowsPerBuffer}; - int32_t num_parallel_workers_{kCfgParallelWorkers}; - int32_t worker_connector_size_{kCfgWorkerConnectorSize}; - int32_t op_connector_size_{kCfgOpConnectorSize}; - uint32_t seed_{kCfgDefaultSeed}; - uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval}; - - // Private helper function that taks a nlohmann json format and populates the settings - // @param j - The json nlohmann json info - Status FromJson(const nlohmann::json &j); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CORE_CONFIG_MANAGER_H_ diff --git a/mindspore/ccsrc/dataset/core/constants.h b/mindspore/ccsrc/dataset/core/constants.h deleted file mode 100644 index 34d2f2583c..0000000000 --- a/mindspore/ccsrc/dataset/core/constants.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_CONSTANTS_H_ -#define DATASET_CORE_CONSTANTS_H_ - -#include -#include -#include - -namespace mindspore { -namespace dataset { -// Various type defines for convenience -using uchar = unsigned char; -using dsize_t = int64_t; - -// Possible dataset types for holding the data and client type -enum class DatasetType { kUnknown, kArrow, kTf }; - -// Possible flavours of Tensor implementations -enum class TensorImpl { kNone, kFlexible, kCv, kNP }; - -// convenience functions for 32bit int bitmask -inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } - -inline void BitSet(uint32_t *bits, uint32_t bitMask) { *bits |= bitMask; } - -inline void BitClear(uint32_t *bits, uint32_t bitMask) { *bits &= (~bitMask); } - -constexpr int32_t kDeMaxDim = std::numeric_limits::max(); // 2147483647 or 2^32 -1 -constexpr int32_t kDeMaxRank = std::numeric_limits::max(); - -constexpr uint32_t kCfgRowsPerBuffer = 1; -constexpr uint32_t kCfgParallelWorkers = 4; -constexpr uint32_t kCfgWorkerConnectorSize = 16; -constexpr uint32_t kCfgOpConnectorSize = 16; -constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; -constexpr uint32_t kCfgMonitorSamplingInterval = 10; - -// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) -constexpr uint8_t kCVInvalidType = 255; - -using connection_id_type = int64_t; -using row_id_type = int64_t; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CORE_CONSTANTS_H_ diff --git a/mindspore/ccsrc/dataset/core/cv_tensor.cc b/mindspore/ccsrc/dataset/core/cv_tensor.cc deleted file mode 100644 index 16921e8b2d..0000000000 --- a/mindspore/ccsrc/dataset/core/cv_tensor.cc +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/cv_tensor.h" - -#include -#include - -#include "dataset/core/constants.h" -#include "dataset/core/tensor.h" - -namespace mindspore { -namespace dataset { -CVTensor::CVTensor(const TensorShape &shape, const DataType &type) : Tensor(shape, type) { - (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); -} - -CVTensor::CVTensor(const TensorShape &shape, const DataType &type, const uchar *data) : Tensor(shape, type, data) { - (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); -} - -CVTensor::CVTensor(std::shared_ptr tensor) : Tensor(std::move(*tensor)) { - (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); -} - -std::pair, int> CVTensor::IsValidImage(const TensorShape &shape, const DataType &type) { - std::array size = {1, 1}; - if (shape.Rank() <= 2 || (shape.Rank() == 3 && shape[2] <= CV_CN_MAX)) { - uint8_t ch = 1; - if (shape.Rank() == 3) { - ch = static_cast(shape[2]); - } - if (shape.Rank() > 0) size[0] = static_cast(shape[0]); - if (shape.Rank() > 1) size[1] = static_cast(shape[1]); - if (type.AsCVType() == kCVInvalidType) return std::make_pair(size, -1); - - int cv_type = CV_MAKETYPE(type.AsCVType(), ch); - return std::make_pair(size, cv_type); - } - return std::make_pair(size, -1); -} - -std::shared_ptr CVTensor::AsCVTensor(std::shared_ptr t) { - std::shared_ptr cv_t = std::dynamic_pointer_cast(t); - if (cv_t != nullptr) { - return cv_t; - } else { - return std::make_shared(t); - } -} - -Status CVTensor::MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat) { - std::pair, int> cv_shape_type = IsValidImage(shape, type); - if (cv_shape_type.second == -1) { - std::vector sizes = shape.AsVector(); - std::vector sizes32(sizes.begin(), sizes.end()); // convert long to int for usage with OpenCV - if (static_cast(shape.Rank()) != shape.Rank()) { - RETURN_STATUS_UNEXPECTED("Error in creating CV mat. Wrong shape."); - } - - uint8_t cv_type = type.AsCVType(); - if (cv_type == kCVInvalidType) { - RETURN_STATUS_UNEXPECTED("Error in creating CV mat. Invalid type."); - } - *mat = cv::Mat(static_cast(shape.Rank()), &sizes32[0], cv_type, data); - } else { - *mat = cv::Mat(2, &(cv_shape_type.first[0]), cv_shape_type.second, data); - } - return Status::OK(); -} - -Status CVTensor::Reshape(const TensorShape &shape) { - RETURN_IF_NOT_OK(Tensor::Reshape(shape)); - RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_)); - return Status::OK(); -} - -Status CVTensor::ExpandDim(const dsize_t &axis) { - RETURN_IF_NOT_OK(Tensor::ExpandDim(axis)); - RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_)); - return Status::OK(); -} - -void CVTensor::Squeeze() { - Tensor::Squeeze(); - (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/cv_tensor.h b/mindspore/ccsrc/dataset/core/cv_tensor.h deleted file mode 100644 index 8c136f5f3c..0000000000 --- a/mindspore/ccsrc/dataset/core/cv_tensor.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_CV_TENSOR_H_ -#define DATASET_CORE_CV_TENSOR_H_ - -#include -#include -#include - -#include - -#include "./securec.h" - -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" - -namespace mindspore { -namespace dataset { -class CVTensor : public Tensor { - public: - // Create an empty CVTensor of shape `shape` and type `type`. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - CVTensor(const TensorShape &shape, const DataType &type); - - // Create a CVTensor from a given buffer, shape and type. - // @note This constructor allocates a new space in the memory and copies the buffer into it. - // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. - CVTensor(const TensorShape &shape, const DataType &type, const uchar *data); - - // Create a CVTensor from a given CV::Mat. - // @note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it. - // @param mat CV::Mat - explicit CVTensor(const cv::Mat &mat) - : CVTensor(TensorShape(mat.size, mat.type()), DataType::FromCVType(mat.type()), mat.data) {} - - ~CVTensor() = default; - - // Static function to cast a given Tensor as CVTensor. If the input tensor is already of type CVTensor, - // this function would be treated as a no-op. Fot other tensor types, a new CVTensor is created based on the data - // provided. The Passed Tensor will be invalidated. - // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. - // @param tensor - // @return CVTensor - static std::shared_ptr AsCVTensor(std::shared_ptr tensor); - - // Create a CVTensor from a given tensor. The input tensor will be invalidated (i.e., the shape and type will be - // set to unknown and the data buffer will point to null. - // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. - // @param tensor - explicit CVTensor(std::shared_ptr tensor); - - // Getter function for the CV::Mat - // @return - cv::Mat mat() const { return mat_; } - - // Static function to check if the passed information (shape and type) can be treated as a valid description - // of an image in OpenCV. Moreover, it returns OpenCV shape and type - // For example, if the shape is <512,512,3> and type is DE_UINT8, the output would be [512,512] and CV_8UC3. - // In case of invalid shape or type, the function will return pair - // @param shape TensorShape - // @param type DataType - // @return std::pair of OpenCV shape and type - std::pair, int> IsValidImage(const TensorShape &shape, const DataType &type); - - Status Reshape(const TensorShape &shape) override; - - Status ExpandDim(const dsize_t &axis) override; - - void Squeeze() override; - - Status Mat(const std::vector &index, cv::Mat *mat) { - uchar *start = nullptr; - TensorShape remaining({-1}); - RETURN_IF_NOT_OK(this->StartAddrOfIndex(index, &start, &remaining)); - RETURN_IF_NOT_OK(this->MatInit(start, remaining, type_, mat)); - return Status::OK(); - } - - private: - cv::Mat mat_; - - // Initialize CV::Mat with the data_, shape_ and type_ - Status MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat); -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_CV_TENSOR_H_ diff --git a/mindspore/ccsrc/dataset/core/data_type.cc b/mindspore/ccsrc/dataset/core/data_type.cc deleted file mode 100644 index bb10fae52f..0000000000 --- a/mindspore/ccsrc/dataset/core/data_type.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/data_type.h" - -#include "utils/log_adapter.h" - -#include "dataset/core/pybind_support.h" - -namespace mindspore { -namespace dataset { - -uint8_t DataType::SizeInBytes() const { - if (type_ < DataType::NUM_OF_TYPES) - return kTypeInfo[type_].sizeInBytes_; - else - return 0; -} - -py::dtype DataType::AsNumpyType() const { - if (type_ < DataType::NUM_OF_TYPES) - return py::dtype(kTypeInfo[type_].pybindType_); - else - return py::dtype("unknown"); -} - -uint8_t DataType::AsCVType() const { - uint8_t res = kCVInvalidType; - if (type_ < DataType::NUM_OF_TYPES) { - res = kTypeInfo[type_].cvType_; - } - - if (res == kCVInvalidType) { - MS_LOG(ERROR) << "Cannot convert to OpenCV type. Return invalid type!"; - } - - return res; -} // namespace dataset - -DataType DataType::FromCVType(int cv_type) { - auto depth = static_cast(cv_type) & static_cast(CV_MAT_DEPTH_MASK); - switch (depth) { - case CV_8S: - return DataType(DataType::DE_INT8); - case CV_8U: - return DataType(DataType::DE_UINT8); - case CV_16S: - return DataType(DataType::DE_INT16); - case CV_16U: - return DataType(DataType::DE_UINT16); - case CV_32S: - return DataType(DataType::DE_INT32); - case CV_16F: - return DataType(DataType::DE_FLOAT16); - case CV_32F: - return DataType(DataType::DE_FLOAT32); - case CV_64F: - return DataType(DataType::DE_FLOAT64); - default: - MS_LOG(ERROR) << "Cannot convert from OpenCV type, unknown CV type. Unknown data type is returned!"; - return DataType(DataType::DE_UNKNOWN); - } -} - -DataType::DataType(const std::string &type_str) { - if (type_str == "bool") - type_ = DE_BOOL; - else if (type_str == "int8") - type_ = DE_INT8; - else if (type_str == "uint8") - type_ = DE_UINT8; - else if (type_str == "int16") - type_ = DE_INT16; - else if (type_str == "uint16") - type_ = DE_UINT16; - else if (type_str == "int32") - type_ = DE_INT32; - else if (type_str == "uint32") - type_ = DE_UINT32; - else if (type_str == "int64") - type_ = DE_INT64; - else if (type_str == "uint64") - type_ = DE_UINT64; - else if (type_str == "float16") - type_ = DE_FLOAT16; - else if (type_str == "float32") - type_ = DE_FLOAT32; - else if (type_str == "float64") - type_ = DE_FLOAT64; - else if (type_str == "string") - type_ = DE_STRING; - else - type_ = DE_UNKNOWN; -} - -std::string DataType::ToString() const { - if (type_ < DataType::NUM_OF_TYPES) - return kTypeInfo[type_].name_; - else - return "unknown"; -} - -DataType DataType::FromNpArray(const py::array &arr) { - if (py::isinstance>(arr)) { - return DataType(DataType::DE_BOOL); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_INT8); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_UINT8); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_INT16); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_UINT16); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_INT32); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_UINT32); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_INT64); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_UINT64); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_FLOAT16); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_FLOAT32); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_FLOAT64); - } else if (arr.dtype().kind() == 'S' || arr.dtype().kind() == 'U') { - return DataType(DataType::DE_STRING); - } else { - MS_LOG(ERROR) << "Cannot convert from numpy type. Unknown data type is returned!"; - return DataType(DataType::DE_UNKNOWN); - } -} - -std::string DataType::GetPybindFormat() const { - std::string res; - if (type_ < DataType::NUM_OF_TYPES) { - res = kTypeInfo[type_].pybindFormatDescriptor_; - } - - if (res.empty()) { - MS_LOG(ERROR) << "Cannot convert from data type to pybind format descriptor!"; - } - return res; -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/data_type.h b/mindspore/ccsrc/dataset/core/data_type.h deleted file mode 100644 index a487f3300e..0000000000 --- a/mindspore/ccsrc/dataset/core/data_type.h +++ /dev/null @@ -1,326 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_DATA_TYPE_H_ -#define DATASET_CORE_DATA_TYPE_H_ - -#include - -#include - -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" - -#include "dataset/core/constants.h" -#include "dataset/core/pybind_support.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { - -// Class that represents basic data types in DataEngine. -class DataType { - public: - enum Type : uint8_t { - DE_UNKNOWN = 0, - DE_BOOL, - DE_INT8, - DE_UINT8, - DE_INT16, - DE_UINT16, - DE_INT32, - DE_UINT32, - DE_INT64, - DE_UINT64, - DE_FLOAT16, - DE_FLOAT32, - DE_FLOAT64, - DE_STRING, - NUM_OF_TYPES - }; - - struct TypeInfo { - const char *name_; // name to be represent the type while printing - const uint8_t sizeInBytes_; // number of bytes needed for this type - const char *pybindType_; // Python matching type, used in get_output_types - const std::string pybindFormatDescriptor_; // pybind format used for numpy types - const uint8_t cvType_; // OpenCv matching type - }; - - static inline const TypeInfo kTypeInfo[] = { - // name, sizeInBytes, pybindTypem formatDescriptor, openCV - {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN - {"bool", 1, "bool", py::format_descriptor::format(), CV_8U}, // DE_BOOL - {"int8", 1, "int8", py::format_descriptor::format(), CV_8S}, // DE_INT8 - {"uint8", 1, "uint8", py::format_descriptor::format(), CV_8U}, // DE_UINT8 - {"int16", 2, "int16", py::format_descriptor::format(), CV_16S}, // DE_INT16 - {"uint16", 2, "uint16", py::format_descriptor::format(), CV_16U}, // DE_UINT16 - {"int32", 4, "int32", py::format_descriptor::format(), CV_32S}, // DE_INT32 - {"uint32", 4, "uint32", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT32 - {"int64", 8, "int64", py::format_descriptor::format(), kCVInvalidType}, // DE_INT64 - {"uint64", 8, "uint64", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT64 - {"float16", 2, "float16", "e", CV_16F}, // DE_FLOAT16 - {"float32", 4, "float32", py::format_descriptor::format(), CV_32F}, // DE_FLOAT32 - {"float64", 8, "double", py::format_descriptor::format(), CV_64F}, // DE_FLOAT64 - {"string", 0, "bytes", "S", kCVInvalidType} // DE_STRING - }; - - // No arg constructor to create an unknown shape - DataType() : type_(DE_UNKNOWN) {} - - // Create a type from a given string - // @param type_str - explicit DataType(const std::string &type_str); - - // Default destructor - ~DataType() = default; - - // Create a type from a given enum - // @param d - constexpr explicit DataType(Type d) : type_(d) {} - - constexpr bool operator==(const DataType a) const { return type_ == a.type_; } - - constexpr bool operator==(const Type a) const { return type_ == a; } - - constexpr bool operator!=(const DataType a) const { return type_ != a.type_; } - - constexpr bool operator!=(const Type a) const { return type_ != a; } - - // Disable this usage `if(d)` where d is of type DataType - // @return - operator bool() = delete; - - // To be used in Switch/case - // @return - operator Type() const { return type_; } - - // The number of bytes needed to store one value of this type - // @return - uint8_t SizeInBytes() const; - - // Convert from DataType to OpenCV type - // @return - uint8_t AsCVType() const; - - // Convert from OpenCV type to DataType - // @param cv_type - // @return - static DataType FromCVType(int cv_type); - - // Returns a string representation of the type - // @return - std::string ToString() const; - - // returns true if the template type is the same as the Tensor type_ - // @tparam T - // @return true or false - template - bool IsCompatible() const { - return type_ == FromCType(); - } - - // returns true if the template type is the same as the Tensor type_ - // @tparam T - // @return true or false - template - bool IsLooselyCompatible() const; - - // << Stream output operator overload - // @notes This allows you to print the info using stream operators - // @param out - reference to the output stream being overloaded - // @param rO - reference to the DataType to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const DataType &so) { - out << so.ToString(); - return out; - } - - template - static DataType FromCType(); - - // Convert from DataType to Pybind type - // @return - py::dtype AsNumpyType() const; - - // Convert from NP type to DataType - // @param type - // @return - static DataType FromNpType(const py::dtype &type); - - // Convert from NP array to DataType - // @param py array - // @return - static DataType FromNpArray(const py::array &arr); - - // Get the buffer string format of the current type. Used in pybind buffer protocol. - // @return - std::string GetPybindFormat() const; - - bool IsSignedInt() const { - return type_ == DataType::DE_INT8 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT32 || - type_ == DataType::DE_INT64; - } - - bool IsUnsignedInt() const { - return type_ == DataType::DE_UINT8 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT32 || - type_ == DataType::DE_UINT64; - } - - bool IsInt() const { return IsSignedInt() || IsUnsignedInt(); } - - bool IsFloat() const { - return type_ == DataType::DE_FLOAT16 || type_ == DataType::DE_FLOAT32 || type_ == DataType::DE_FLOAT64; - } - - bool IsBool() const { return type_ == DataType::DE_BOOL; } - - bool IsNumeric() const { return type_ != DataType::DE_STRING; } - - Type value() const { return type_; } - - private: - Type type_; -}; - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_BOOL); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_FLOAT64); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_FLOAT32); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_FLOAT16); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_INT64); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_UINT64); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_INT32); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_UINT32); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_INT16); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_UINT16); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_INT8); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_UINT8); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_STRING); -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_BOOL; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_FLOAT64 || type_ == DataType::DE_FLOAT32; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_FLOAT32; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_FLOAT16; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_INT64 || type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || - type_ == DataType::DE_INT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_UINT64 || type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || - type_ == DataType::DE_UINT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_INT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_UINT8; -} -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_DATA_TYPE_H_ diff --git a/mindspore/ccsrc/dataset/core/global_context.cc b/mindspore/ccsrc/dataset/core/global_context.cc deleted file mode 100644 index 3de8e0fcd8..0000000000 --- a/mindspore/ccsrc/dataset/core/global_context.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/global_context.h" - -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/tensor.h" -#include "dataset/util/allocator.h" -#include "dataset/util/circular_pool.h" -#include "dataset/util/system_pool.h" - -namespace mindspore { -namespace dataset { -// Global static pointer for the singleton GlobalContext -std::unique_ptr GlobalContext::global_context_ = nullptr; -std::once_flag GlobalContext::init_instance_flag_; - -constexpr int GlobalContext::kArenaSize; -constexpr int GlobalContext::kMaxSize; -constexpr bool GlobalContext::kInitArena; - -// Singleton initializer -GlobalContext *GlobalContext::Instance() { - // If the single global context is not created yet, then create it. Otherwise the - // existing one is returned. - std::call_once(init_instance_flag_, []() { - global_context_.reset(new GlobalContext()); - Status rc = global_context_->Init(); - if (rc.IsError()) { - std::terminate(); - } - }); - return global_context_.get(); -} - -Status GlobalContext::Init() { - config_manager_ = std::make_shared(); - mem_pool_ = std::make_shared(); - // For testing we can use Dummy pool instead - - // Create some tensor allocators for the different types and hook them into the pool. - tensor_allocator_ = std::make_unique>(mem_pool_); - cv_tensor_allocator_ = std::make_unique>(mem_pool_); - int_allocator_ = std::make_unique(mem_pool_); - return Status::OK(); -} - -// A print method typically used for debugging -void GlobalContext::Print(std::ostream &out) const { - out << "GlobalContext contains the following default config: " << *config_manager_ << "\n"; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/global_context.h b/mindspore/ccsrc/dataset/core/global_context.h deleted file mode 100644 index ee0cbfbbe0..0000000000 --- a/mindspore/ccsrc/dataset/core/global_context.h +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_GLOBAL_CONTEXT_H_ -#define DATASET_CORE_GLOBAL_CONTEXT_H_ - -#include -#include - -#include "dataset/core/constants.h" -#include "dataset/util/allocator.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// forward declare -class MemoryPool; -class ConfigManager; -class Tensor; -class CVTensor; - -using TensorAlloc = Allocator; // An allocator for Tensors -using CVTensorAlloc = Allocator; // An allocator CVTensors -using IntAlloc = Allocator; - -class GlobalContext { - // some consts for pool config - static constexpr int kArenaSize = 128; - static constexpr int kMaxSize = -1; - static constexpr bool kInitArena = true; - - public: - // Singleton pattern. This method either: - // - creates the single version of the GlobalContext for the first time and returns it - // OR - // - returns the already existing single instance of the GlobalContext - // @return the single global context - static GlobalContext *Instance(); - - // Destructor - ~GlobalContext() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - void Print(std::ostream &out) const; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param g_c - reference to the GlobalContext to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const GlobalContext &g_c) { - g_c.Print(out); - return out; - } - - // Getter method - // @return the client config as raw const pointer - static std::shared_ptr config_manager() { return Instance()->config_manager_; } - - // Getter method - // @return the mem pool - std::shared_ptr mem_pool() const { return mem_pool_; } - - // Getter method - // @return the tensor allocator as raw pointer - const TensorAlloc *tensor_allocator() const { return tensor_allocator_.get(); } - - // Getter method - // @return the CVTensor allocator as raw pointer - const CVTensorAlloc *cv_tensor_allocator() const { return cv_tensor_allocator_.get(); } - - // Getter method - // @return the integer allocator as raw pointer - const IntAlloc *int_allocator() const { return int_allocator_.get(); } - - private: - // Constructor. - // @note Singleton. Instantiation flows through instance() - // @return This is a constructor. - GlobalContext() = default; - - Status Init(); - - static std::once_flag init_instance_flag_; - static std::unique_ptr global_context_; // The instance of the singleton (global) - std::shared_ptr mem_pool_; // A global memory pool - std::shared_ptr config_manager_; // The configs - std::unique_ptr tensor_allocator_; // An allocator for Tensors - std::unique_ptr cv_tensor_allocator_; // An allocator for CV Tensors - std::unique_ptr int_allocator_; // An allocator for ints -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CORE_GLOBAL_CONTEXT_H_ diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc deleted file mode 100644 index 8de3425c5b..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ /dev/null @@ -1,1013 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/tensor.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/constants.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/global_context.h" -#include "dataset/core/pybind_support.h" -#include "dataset/core/tensor_shape.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { -// Helper macros for printing tensor elements -#define CASE_PRINT(de_type, native_type) \ - case de_type: { \ - native_type o; \ - rc = GetItemAt(&o, index); \ - out << o; \ - break; \ - } - -#define CASE_PRINT_HEX(de_type, native_type) \ - case de_type: { \ - native_type o; \ - rc = GetItemAt(&o, index); \ - out << std::hex << std::setw(2) << std::setfill('0') << o << std::dec << std::setfill(' '); \ - break; \ - } - -Tensor::Tensor(const TensorShape &shape, const DataType &type) : shape_(shape), type_(type), data_(nullptr) { - // grab the mem pool from global context and create the allocator for char data area - std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); - data_allocator_ = std::make_unique>(global_pool); -} - -Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data) : Tensor(shape, type) { - if (type.IsNumeric()) { - // If the data pointer was given, then we can also populate the tensor with data - if (data != nullptr) { - // Given the shape/type of this tensor, compute the data size and copy in the input bytes. - int64_t byte_size = this->SizeInBytes(); - Status s = this->AllocateBuffer(byte_size); // Allocates data_ inside itself - if (s.IsOk() && data_ != nullptr) { - int ret_code = memcpy_s(data_, byte_size, data, byte_size); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor!"; - } - } else { - MS_LOG(ERROR) << "Failed to create memory for Tensor!"; - } - } - } else { - MS_LOG(ERROR) << "Type should be numeric to use this constructor."; - } -} - -Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length) - : Tensor(shape, type) { - // If the data pointer was given, then we can also populate the tensor with data - if (data != nullptr) { - // Allocates data_ inside itself - Status s = AllocateBuffer(length); - if (s.IsError()) { - MS_LOG(ERROR) << "Failed to create memory for Tensor!"; - } - if (data_ != nullptr) { - int ret_code = memcpy_s(data_, length, data, length); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor!"; - } - } - } -} - -Tensor::Tensor(Tensor &&other) noexcept - : shape_(other.shape()), - type_(other.type()), - data_(other.GetMutableBuffer()), - data_allocator_(std::move(other.data_allocator_)) { - other.Invalidate(); -} - -Tensor &Tensor::operator=(Tensor &&other) noexcept { - if (&other != this) { - shape_ = other.shape(); - type_ = other.type(); - data_ = other.GetMutableBuffer(); - data_end_ = other.data_end_; - data_allocator_ = std::move(other.data_allocator_); - other.Invalidate(); - } - return *this; -} - -Tensor::Tensor(const std::vector &strings, const TensorShape &shape) - : Tensor(TensorShape({static_cast(strings.size())}), DataType(DataType::DE_STRING)) { - auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; - dsize_t total_length = std::accumulate(strings.begin(), strings.end(), 0, length_sum); - - // total bytes needed = offset array + strings - // offset array needs to store one offset var per element + 1 extra to get the length of the last string. - // strings will be null-terminated --> need 1 extra byte per element - dsize_t num_bytes = (kOffsetSize + 1) * shape_.NumOfElements() + kOffsetSize + total_length; - - data_ = data_allocator_->allocate(num_bytes); - - auto offset_arr = reinterpret_cast(data_); - uchar *buf = GetStringsBuffer(); - - offset_t offset = buf - data_; // the first string will start here - uint32_t i = 0; - for (const auto &str : strings) { - // insert the start index of the string. - offset_arr[i++] = offset; - // total bytes are reduced by kOffsetSize - num_bytes -= kOffsetSize; - // insert actual string - int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); - if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; - // next string will be stored right after the current one. - offset = offset + str.length() + 1; - // total bytes are reduced by the length of the string - num_bytes -= str.length() + 1; - } - // store one more offset value so we can get the length of the last string - // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] - offset_arr[i] = offset; - - this->data_end_ = data_ + offset_arr[i]; - - MS_ASSERT(num_bytes == 0); - if (shape.known()) Tensor::Reshape(shape); -} -Tensor::Tensor(const dataengine::BytesList &bytes_list, const TensorShape &shape) - : Tensor(TensorShape({static_cast(bytes_list.value_size())}), DataType(DataType::DE_STRING)) { - // total bytes needed = offset array + strings - // offset array needs to store one offset var per element + 1 extra to get the length of the last string. - // strings will be null-terminated --> need 1 extra byte per element - dsize_t num_bytes = (kOffsetSize)*shape_.NumOfElements() + kOffsetSize + bytes_list.ByteSizeLong(); - - data_ = data_allocator_->allocate(num_bytes); - - auto offset_arr = reinterpret_cast(data_); - uchar *buf = GetStringsBuffer(); - - offset_t offset = buf - data_; // the first string will start here - uint32_t i = 0; - for (; i < bytes_list.value_size(); i++) { - const std::string &str = bytes_list.value(i); - // insert the start index of the string. - offset_arr[i] = offset; - // total bytes are reduced by kOffsetSize - num_bytes -= kOffsetSize; - // insert actual string - int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); - if (ret_code != 0) { - MS_LOG(ERROR) << "Cannot copy string into Tensor"; - } - // next string will be stored right after the current one. - offset = offset + str.length() + 1; - // total bytes are reduced by the length of the string - num_bytes -= str.length() + 1; - } - // store one more offset value so we can get the length of the last string - // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] - offset_arr[i] = offset; - - data_end_ = data_ + offset_arr[i]; - - MS_ASSERT(num_bytes == 0); - if (shape.known()) Tensor::Reshape(shape); -} -Status Tensor::CreateTensor(std::shared_ptr *ptr, TensorImpl tensor_impl, const TensorShape &shape, - DataType type, const unsigned char *data) { - if (!shape.known()) { - RETURN_STATUS_UNEXPECTED("Invalid shape."); - } - if (type == DataType::DE_UNKNOWN) { - RETURN_STATUS_UNEXPECTED("Invalid data type."); - } - - switch (tensor_impl) { - case TensorImpl::kFlexible: { - // The flex tensor is really just the base class tensor implementation - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, shape, type, data); - break; - } - case TensorImpl::kCv: { - const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); - *ptr = std::allocate_shared(*alloc, shape, type, data); - break; - } - default: { - std::string err_msg("Invalid tensor implementation type."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); // returns base-class shared_ptr -} - -Status Tensor::CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr) { - std::vector shape; - for (dsize_t i = 0; i < arr.ndim(); i++) { - shape.push_back(static_cast(arr.shape()[i])); - } - arr.resize({arr.size()}); // flatten the py::array so we can iterate once - std::vector strings; - - if (arr.dtype().kind() == 'U') { - std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast(s)); }); - } else { - std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast(s)); }); - } - - arr.resize(shape); // resize arr back to the original shape - - return CreateTensor(ptr, strings, TensorShape{shape}); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, py::array arr) { - if (DataType::FromNpArray(arr) == DataType::DE_STRING) { - return CreateTensorFromNumpyString(ptr, arr); - } - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, TensorShape({}), DataType(DataType::DE_UNKNOWN)); - - std::vector shape; - for (dsize_t i = 0; i < arr.ndim(); i++) { - shape.push_back(static_cast(arr.shape()[i])); - } - - (*ptr)->shape_ = TensorShape(shape); - (*ptr)->type_ = DataType::FromNpArray(arr); - if (!(*ptr)->shape_.known()) RETURN_STATUS_UNEXPECTED("Invalid shape."); - - if ((*ptr)->type_ == DataType::DE_UNKNOWN) RETURN_STATUS_UNEXPECTED("Invalid data type."); - - std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); - (*ptr)->data_allocator_ = std::make_unique>(global_pool); - int64_t byte_size = (*ptr)->SizeInBytes(); - RETURN_IF_NOT_OK((*ptr)->AllocateBuffer(byte_size)); - - unsigned char *data = static_cast(arr.request().ptr); - if ((*ptr)->data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); - } - - std::vector strides; - for (dsize_t i = 0; i < arr.ndim(); i++) { - strides.push_back(static_cast(arr.strides()[i])); - } - - // check if strides are contiguous - bool is_strided = false; - dsize_t count = (*ptr)->shape_.NumOfElements(); - for (size_t i = 0; i < shape.size(); i++) { - count /= shape[i]; - if (strides[i] != (*ptr)->type_.SizeInBytes() * count) { - is_strided = true; - break; - } - } - - if (is_strided) { - RETURN_IF_NOT_OK(CopyStridedArray((*ptr)->data_, data, shape, strides, (*ptr)->type_.SizeInBytes())); - } else { - int ret_code = memcpy_s((*ptr)->data_, byte_size, data, byte_size); - if (ret_code != 0) { - RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); - } - } - - return Status::OK(); // returns base-class shared_ptr -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::vector &strings, - const TensorShape &shape) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, strings, shape); - return Status::OK(); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, bytes_list, shape); - return Status::OK(); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::string &file_path) { - std::ifstream fs; - fs.open(file_path, std::ios::binary | std::ios::in); - CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + file_path); - int64_t num_bytes = fs.seekg(0, std::ios::end).tellg(); - CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Fail to find size of file"); - RETURN_IF_NOT_OK( - Tensor::CreateTensor(ptr, TensorImpl::kFlexible, TensorShape{num_bytes}, DataType(DataType::DE_UINT8))); - int64_t written_bytes = fs.read(reinterpret_cast((*ptr)->GetMutableBuffer()), num_bytes).gcount(); - CHECK_FAIL_RETURN_UNEXPECTED(written_bytes == num_bytes && fs.good(), "Error in writing to tensor"); - fs.close(); - return Status::OK(); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape, const DataType &type, dsize_t pad_size) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(ptr, TensorImpl::kFlexible, shape, type)); - - unsigned char *current_tensor_addr = (*ptr)->GetMutableBuffer(); - int64_t tensor_bytes_remaining = bytes_list.value_size() * pad_size; - - for (int i = 0; i < bytes_list.value_size(); i++) { - // read string data into tensor - const std::string ¤t_element = bytes_list.value(i); - int return_code = - memcpy_s(current_tensor_addr, tensor_bytes_remaining, common::SafeCStr(current_element), current_element.size()); - - CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when reading bytesList element into Tensor"); - - current_tensor_addr += current_element.size(); - tensor_bytes_remaining -= current_element.size(); - - // pad - int64_t chars_to_pad = pad_size - current_element.size(); - return_code = memset_s(current_tensor_addr, tensor_bytes_remaining, static_cast(' '), chars_to_pad); - CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when padding Tensor"); - - current_tensor_addr += chars_to_pad; - tensor_bytes_remaining -= chars_to_pad; - } - - return Status::OK(); -} - -// Memcpy the given strided array's used part to consecutive memory -// Consider a 3-d array -// A[(i * shape[1] + j) * shape[2] + k] = B[i][j][k] = C[i * strides[0] + j * strides[1] + k * strides[2]] -// Here we convert array C to array A, by memcpy index by index (Note that not all elements in C is copied) -Status Tensor::CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, - std::vector strides, uint8_t type_size) { - dsize_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - for (dsize_t i = 0; i < size; ++i) { - dsize_t offset = 0; - dsize_t count = i; - for (size_t j = 0; j < shape.size(); ++j) { - // convert 1d array's index to 3d array's index (A -> B) - dsize_t idx = count % shape[shape.size() - 1 - j]; - count /= shape[shape.size() - 1 - j]; - // calculate the raw data offset based on strides (B -> C) - offset += idx * strides[shape.size() - 1 - j]; - // once count = 0, the following idxes are all zero, skip them - if (count == 0) break; - } - // strides already consider byte size of the data type, but dst doesn't. - // dst[i] = dst + i * type_size = src + offset - int ret_code = memcpy_s(dst + i * type_size, type_size, src + offset, type_size); - if (ret_code != 0) { - RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); - } - } - return Status::OK(); -} - -// Name: Destructor -// Description: Destructor -Tensor::~Tensor() { - if (data_ != nullptr) { - if (data_allocator_ != nullptr) { - data_allocator_->deallocate(data_); - data_ = nullptr; - data_end_ = nullptr; - } else { - // If we didn't have an allocator, but data_ is not null then it must - // be a stand-alone tensor that used malloc directly. - free(data_); - data_ = nullptr; - data_end_ = nullptr; - } - } -} - -bool Tensor::operator==(const Tensor &rhs) const { - // 1. different shape 2. different type 3. one data_ is nullptr and the other is not - if (shape_ != rhs.shape() || type_ != rhs.type_ || (data_ == nullptr && rhs.data_ != nullptr) || - (data_ != nullptr && rhs.data_ == nullptr)) { - return false; - } - if (data_ == nullptr && rhs.data_ == nullptr) { - return true; - } - // use mem compare to compare the two data, size are already verified - return memcmp(data_, rhs.data_, SizeInBytes()) == 0; -} - -// Name: PrintItemAt() -// Description: A function that print the value as specified by its index -void Tensor::PrintItemAt(const std::vector &index, std::ostream &out) const { - Status rc; - MS_ASSERT(data_); - - switch (type_.value()) { - CASE_PRINT_HEX(DataType::DE_BOOL, bool); - - CASE_PRINT_HEX(DataType::DE_INT8, int8_t); - - CASE_PRINT_HEX(DataType::DE_UINT8, uint8_t); - - CASE_PRINT(DataType::DE_INT16, int16_t); - - CASE_PRINT(DataType::DE_UINT16, uint16_t); - - CASE_PRINT(DataType::DE_INT32, int32_t); - - CASE_PRINT(DataType::DE_UINT32, uint32_t); - - CASE_PRINT(DataType::DE_INT64, int64_t); - - CASE_PRINT(DataType::DE_UINT64, uint64_t); - - CASE_PRINT(DataType::DE_FLOAT16, float16); - - CASE_PRINT(DataType::DE_FLOAT32, float); - - CASE_PRINT(DataType::DE_FLOAT64, double); - - case DataType::DE_STRING: { - std::string_view o{""}; - GetItemAt(&o, index); - out << "\"" << o << "\""; - break; - } - default: { - out << "?"; - break; - } - } - if (rc.IsError()) { - out << rc.ToString(); - } -} - -// Name: PrintRecursive() -// Description: A function that prints Tensor recursively, first called by print -void Tensor::PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const { - if (cur_index.size() == shape_.Rank()) { - PrintItemAt(cur_index, out); - } else { - out << "["; - for (dsize_t i = 0; i < shape_[cur_dim]; i++) { - std::vector new_index = cur_index; - new_index.push_back(i); - PrintRecursive(out, cur_dim + 1, new_index); - if (i < shape_[cur_dim] - 1) { - out << ","; - } - } - out << "]"; - } -} - -// Name: Print() -// Description: A function that prints info about the tensor -void Tensor::Print(std::ostream &out) const { - out << "Tensor (shape: "; - out << shape_; - out << ", Type: " << type_ << ")\n"; - if (data_) { - PrintRecursive(out, 0, std::vector{}); - } else { - out << "[Data area is null]"; - } -} -Status Tensor::AllocateBuffer(const dsize_t &length) { - if (data_ == nullptr) { - if (data_allocator_ != nullptr) { - data_ = data_allocator_->allocate(length); - RETURN_UNEXPECTED_IF_NULL(data_); - data_end_ = data_ + length; - } else { - data_ = static_cast(malloc(length)); - data_end_ = data_ + length; - RETURN_UNEXPECTED_IF_NULL(data_); - } - } - return Status::OK(); -} -const unsigned char *Tensor::GetBuffer() const { - // This version cannot modify anything. data_ could possibly be null. - return data_; -} - -unsigned char *Tensor::GetMutableBuffer() { - if (!shape_.known() || type_ == DataType::DE_UNKNOWN) { - return nullptr; - } - // If the data area is already created, return the pointer to it - if (data_ != nullptr) { - return data_; - } else { - // If the data area is not created, then identify the memory size based - // on the shape and type and allocate it. - if (this->AllocateBuffer(this->SizeInBytes()).IsOk()) { - return data_; - } else { - return nullptr; - } - } -} - -Status Tensor::Reshape(const TensorShape &shape) { - if (shape.NumOfElements() == shape_.NumOfElements()) { - shape_ = shape; - return Status::OK(); - } else { - std::string err = "Cannot reshape, Number of elements do not match"; - RETURN_STATUS_UNEXPECTED(err); - } -} - -void Tensor::Invalidate() { - shape_ = TensorShape::CreateUnknownRankShape(); - type_ = DataType(DataType::DE_UNKNOWN); - data_ = nullptr; - data_end_ = nullptr; - data_allocator_ = nullptr; -} - -template -Status Tensor::GetItemPtr(T **ptr, const std::vector &index) const { - if (type_.IsCompatible()) { - if (data_ == nullptr) { - std::string err = "Data is not allocated yet"; - RETURN_STATUS_UNEXPECTED(err); - } - dsize_t flat_idx; - RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx)); - *ptr = reinterpret_cast(data_ + flat_idx * type_.SizeInBytes()); - - return Status::OK(); - } else { - std::string err = "data type not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } -} - -Status Tensor::GetItemPtr(uchar **ptr, const std::vector &index, offset_t *length) const { - if (type_ == DataType::DE_STRING) { - if (data_ == nullptr) { - std::string err = "Data is not allocated yet"; - RETURN_STATUS_UNEXPECTED(err); - } - dsize_t flat_idx; - RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx)); - offset_t length_temp = 0; - RETURN_IF_NOT_OK(GetStringAt(flat_idx, ptr, &length_temp)); - if (length != nullptr) *length = length_temp; - return Status::OK(); - } else { - std::string err = "data type not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } -} - -Status Tensor::StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining) { - if (type() == DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet."); - } - - dsize_t flat_ind; - std::vector t_shape = shape().AsVector(); - std::vector r(t_shape.begin() + ind.size(), t_shape.end()); - *remaining = TensorShape(r); - ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0); - - RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind)); - // check if GetBuffer() returns null, we should flag this as an error, this sanity check will only - // be true is the tensor failed to allocate memory. - if (GetMutableBuffer() == nullptr) { - RETURN_STATUS_UNEXPECTED("Invalid GetBuffer in Tensor, got nullptr"); - } - *start_addr_of_index = GetMutableBuffer() + flat_ind * this->type().SizeInBytes(); - return Status::OK(); -} - -Status Tensor::InsertTensor(const std::vector &ind, const std::shared_ptr &tensor) { - std::string err_msg; - err_msg += (this->type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; - err_msg += (!this->shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; - err_msg += (ind.size() + tensor->Rank() != this->Rank()) ? "[Tensor] incorrect index\n" : ""; - err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; - uchar *start_addr_of_ind = nullptr; - TensorShape remaining_shape({-1}); - err_msg += (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : ""; - err_msg += !(remaining_shape == tensor->shape()) ? "[Tensor] memory error\n" : ""; - if (!err_msg.empty()) { - MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - if (start_addr_of_ind != nullptr) { - int ret_code = - memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); - if (ret_code == 0) { - return Status::OK(); - } else { - err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; - MS_LOG(DEBUG) << "Tensor message: " << err_msg; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } else { - RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); - } - } -} - -Status Tensor::Concatenate(const std::vector &index, const std::shared_ptr &tensor) { - std::string err_msg; - err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : ""; - err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; - err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; - - err_msg += - (index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : ""; - err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; - uchar *start_addr_of_ind = nullptr; - - TensorShape remaining_shape = tensor->shape(); - StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape); - err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : ""; - - if (!err_msg.empty()) { - MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; - - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - int ret_code = - memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); - - if (ret_code == 0) { - return Status::OK(); - } else { - err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; - MS_LOG(DEBUG) << "Tensor message: " << err_msg; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } -} - -Status Tensor::ExpandDim(const dsize_t &axis) { - if (axis > Rank()) { - std::string err = "Axis is out of bound"; - RETURN_STATUS_UNEXPECTED(err); - } - if (axis == Rank()) { - shape_ = shape_.AppendDim(1); - } else { - shape_ = shape_.InsertDim(axis, 1); - } - return Status::OK(); -} - -std::vector Tensor::Strides() { - std::vector strides = shape_.Strides(); - uint8_t size = type_.SizeInBytes(); - std::transform(strides.begin(), strides.end(), strides.begin(), [&size](const auto &c) { return c * size; }); - return strides; -} - -Status Tensor::GetBufferInfo(Tensor &t, py::buffer_info *out) { - CHECK_FAIL_RETURN_UNEXPECTED(t.type().IsNumeric(), "Cannot use GetBufferInfo on tensor of strings."); - - std::string format_desc = t.type().GetPybindFormat(); - if (format_desc.empty()) { - RETURN_STATUS_UNEXPECTED("Cannot convert DE type tp pybind format"); - } - *out = py::buffer_info(t.GetMutableBuffer(), /* Pointer to buffer */ - t.type().SizeInBytes(), /* Size of one scalar */ - format_desc, /* Python struct-style format descriptor */ - t.Rank(), /* Number of dimensions */ - t.shape().AsVector(), /* Buffer dimensions */ - t.Strides()); - return Status::OK(); -} - -template -Status Tensor::GetItemAt(T *o, const std::vector &index) const { - if (data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); - } - if (!type_.IsLooselyCompatible()) { - std::string err = "Template type and Tensor type are not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } - if (type_.IsUnsignedInt()) { - RETURN_IF_NOT_OK(GetUnsignedIntAt(o, index)); - } else if (type_.IsSignedInt()) { - RETURN_IF_NOT_OK(GetSignedIntAt(o, index)); - } else if (type_.IsFloat()) { - RETURN_IF_NOT_OK(GetFloatAt(o, index)); - } else if (type_.IsBool()) { - bool *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - } else { - std::string err = "Tensor Type is unknown"; - RETURN_STATUS_UNEXPECTED(err); - } - return Status::OK(); -} - -Status Tensor::GetItemAt(std::string_view *o, const std::vector &index) const { - RETURN_UNEXPECTED_IF_NULL(data_); - RETURN_UNEXPECTED_IF_NULL(o); - CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string"); - - uchar *start = nullptr; - offset_t length = 0; - RETURN_IF_NOT_OK(GetItemPtr(&start, index, &length)); - std::string_view sv{reinterpret_cast(start)}; - o->swap(sv); - return Status::OK(); -} -// return data as numpy, should return status -Status Tensor::GetDataAsNumpy(py::array *data) { - RETURN_UNEXPECTED_IF_NULL(data_); - RETURN_UNEXPECTED_IF_NULL(data); - if (type_ == DataType::DE_BOOL) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_INT8) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_INT16) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_INT32) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_INT64) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_UINT8) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_UINT16) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_UINT32) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_UINT64) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_FLOAT16) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_FLOAT32) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_FLOAT64) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_STRING) { - GetDataAsNumpyStrings(data); - } else { - RETURN_STATUS_UNEXPECTED("Got unexpected type when returning numpy"); - } - return Status::OK(); -} -Status Tensor::GetDataAsNumpyStrings(py::array *data) { - auto itr = begin(); - uint64_t max = 0; - for (; itr != end(); itr++) { - max = std::max((*itr).length(), max); - } - // if all strings are empty, numpy stores a byte for each string |S1 - max = (max == 0 ? 1 : max); - uint64_t total_size = shape_.NumOfElements() * max; - char *tmp_data = reinterpret_cast(data_allocator_->allocate(total_size)); - if (tmp_data == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create temp array."); - int ret_code = memset_s(tmp_data, total_size, 0, total_size); - CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to initialize temp memory"); - - itr = begin(); - uint64_t i = 0; - for (; itr != end(); itr++, i++) { - if (!(*itr).empty()) { - ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length()); - CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data."); - } - } - auto strides = shape_.Strides(); - std::transform(strides.begin(), strides.end(), strides.begin(), [&max](const auto &s) { return s * max; }); - *data = py::array(py::dtype("S" + std::to_string(max)), shape_.AsVector(), strides, tmp_data); - data_allocator_->deallocate(reinterpret_cast(tmp_data)); - return Status::OK(); -} - -void Tensor::Squeeze() { shape_ = shape_.Squeeze(); } - -template -Status Tensor::GetUnsignedIntAt(T *o, const std::vector &index) const { - if (data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); - } - if (!type_.IsLooselyCompatible()) { - std::string err = "Template type and Tensor type are not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } - switch (type_.value()) { - case DataType::DE_UINT8: { - uint8_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_UINT16: { - uint16_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_UINT32: { - uint32_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_UINT64: { - uint64_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - default: - std::string err = "Tensor Type is not an unsigned Integer"; - RETURN_STATUS_UNEXPECTED(err); - } - return Status::OK(); -} - -template -Status Tensor::GetSignedIntAt(T *o, const std::vector &index) const { - if (data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); - } - if (!type_.IsLooselyCompatible()) { - std::string err = "Template type and Tensor type are not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } - switch (type_.value()) { - case DataType::DE_INT8: { - int8_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_INT16: { - int16_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_INT32: { - int32_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_INT64: { - int64_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - default: - std::string err = "Tensor Type is not a signed Integer"; - RETURN_STATUS_UNEXPECTED(err); - } - return Status::OK(); -} - -template -Status Tensor::GetFloatAt(T *o, const std::vector &index) const { - if (data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); - } - if (!type_.IsLooselyCompatible()) { - std::string err = "Template type and Tensor type are not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } - switch (type_.value()) { - case DataType::DE_FLOAT16: { - float16 *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_FLOAT32: { - float *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_FLOAT64: { - double *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - default: - std::string err = "Tensor Type is not a float/double"; - RETURN_STATUS_UNEXPECTED(err); - } - return Status::OK(); -} -Status Tensor::GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const { - CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not string"); - RETURN_UNEXPECTED_IF_NULL(data_); - RETURN_UNEXPECTED_IF_NULL(string_start); - RETURN_UNEXPECTED_IF_NULL(length); - auto *offset_ptr = reinterpret_cast(data_); // offsets starts here - offset_t start = offset_ptr[index]; - *string_start = data_ + start; - *length = offset_ptr[index + 1] - start - 1; // -1 to skip the \0 from the string length - return Status::OK(); -} -Status Tensor::CopyLastDimAt(const std::shared_ptr &src, const std::vector &index) { - CHECK_FAIL_RETURN_UNEXPECTED(src->type() == type_, "Source Tensor has a different type"); - CHECK_FAIL_RETURN_UNEXPECTED(index.back() == 0, "Last dim in index should be 0"); - - uint8_t type_size = type_.SizeInBytes(); - size_t len = std::min(src->shape()[-1], shape_[-1]) * type_size; - dsize_t src_flat_ind = 0, dst_flat_ind = 0; - RETURN_IF_NOT_OK(src->shape().ToFlatIndex(index, &src_flat_ind)); - RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &dst_flat_ind)); - - const unsigned char *src_addr = src->GetBuffer() + src_flat_ind * type_size; - unsigned char *dst_addr = GetMutableBuffer() + dst_flat_ind * type_size; - CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error"); - return Status::OK(); -} -Status Tensor::Slice(std::shared_ptr *out, const std::vector &indices) { - CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice work with rank 1 tensors only."); - CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty."); - if (type_.IsNumeric()) { - return SliceNumeric(out, indices); - } else { - return SliceString(out, indices); - } -} -Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vector &indices) { - RETURN_IF_NOT_OK( - CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast(indices.size())}), type_)); - (*out)->GetMutableBuffer(); - dsize_t out_index = 0; - dsize_t dim_length = shape_[0]; - dsize_t type_size = type_.SizeInBytes(); - dsize_t src_start = HandleNeg(indices[0], dim_length); - uchar *dst_addr = (*out)->data_; - dsize_t count = 1; - - for (dsize_t i = 0; i < indices.size(); i++) { - dsize_t cur_index = HandleNeg(indices[i], dim_length); - CHECK_FAIL_RETURN_UNEXPECTED( - cur_index >= 0 && cur_index < dim_length, - "Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")"); - if (i < indices.size() - 1) { - dsize_t next_index = HandleNeg(indices[i + 1], dim_length); - if (next_index == cur_index + 1) { - count++; - continue; - } - } - int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, - count * type_size); - CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric"); - out_index += count; - if (i < indices.size() - 1) { - src_start = HandleNeg(indices[i + 1], dim_length); // next index - } - count = 1; - } - return Status::OK(); -} -Status Tensor::SliceString(std::shared_ptr *out, const std::vector &indices) { - dsize_t dim_length = shape_[0]; - std::vector strings; - for (dsize_t index : indices) { - dsize_t cur_index = HandleNeg(index, dim_length); - CHECK_FAIL_RETURN_UNEXPECTED( - cur_index >= 0 && cur_index < dim_length, - "Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")"); - std::string_view sv; - GetItemAt(&sv, {cur_index}); - strings.emplace_back(sv); - } - return CreateTensor(out, strings); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/tensor.h b/mindspore/ccsrc/dataset/core/tensor.h deleted file mode 100644 index 9fed0bbc97..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor.h +++ /dev/null @@ -1,652 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_TENSOR_H_ -#define DATASET_CORE_TENSOR_H_ - -#include -#include -#include -#include -#include "./securec.h" -#include "utils/log_adapter.h" -#if defined(_WIN32) || defined(_WIN64) -#undef HAVE_STDDEF_H -#undef HAVE_STDLIB_H -#endif -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/util/allocator.h" -#include "dataset/util/status.h" -#include "proto/example.pb.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { -class Tensor; - -using CharAllocPtr = std::unique_ptr>; -using TensorAllocPtr = std::shared_ptr>; // An allocator shared_ptr for Tensors - -class Tensor { - public: - Tensor() = delete; - - // Create a new tensor, does not internally allocate storage. This constructor is protected, use CreateTensor. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - Tensor(const TensorShape &shape, const DataType &type); - - // Create a new tensor, allocates storage and copies in data. This constructor is protected, use CreateTensor. - // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data); - - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length); - - Tensor(const Tensor &other) = delete; - - Tensor &operator=(const Tensor &other) = delete; - - Tensor(Tensor &&other) noexcept; - - Tensor &operator=(Tensor &&other) noexcept; - - Status AllocateBuffer(const dsize_t &length); - - // type of offest values to store strings information - using offset_t = uint32_t; - // const of the size of the offset variable - static constexpr uint8_t kOffsetSize = sizeof(offset_t); - // Tensor base class which holds the data in an unsigned char* buffer. - - // Construct a scalar string Tensor - explicit Tensor(const std::string &str) : Tensor(std::vector{str}, TensorShape::CreateScalar()) {} - - // Construct a tensor from a list of strings. Reshape the tensor with `shape` if given, otherwise assume the shape is - // the size of the vector `strings`. - // The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. - // Thr offset array will store one extra value to find the length of the last string. - // OFFSET1, OFFSET2, ..., OFFSETn+1, STRING1, STRING2, ..., STRINGn - // The value of each offset is the start index of the corresponding string - // Offsets is of type offest_t - // strings will ne null-terminated - // example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) - // |----------------------------------------------------------------| - // | OFFSET ARRAY | STRINGS | - // | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | - // | 11 | 15 | 18 | abc\0 | de\0 | - // |----------------------------------------------------------------| - explicit Tensor(const std::vector &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // Same as Tensor(vector) but the input is protobuf bytelist - explicit Tensor(const dataengine::BytesList &bytes_list, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // A static factory method to create the given flavour of derived Tensor - // Returns the base class reference for the Tensor. - // @param ptr output argument to hold the created Tensor of given tensor_impl - // @param tensor_impl - which implementation of Tensor - // @param shape - shape of the tensor - // @param type - datatype of the tensor - // @param data - data to be copied to Tensor new allocation - // @return Status Code - static Status CreateTensor(std::shared_ptr *, TensorImpl tensor_impl, const TensorShape &shape, DataType type, - const unsigned char *data = nullptr); - - /// Create a copy of the input tensor - /// \param out [out] output tensor to be generated - /// \param in [in] orginal tensor to be copied - /// \return Status - static Status CreateTensor(std::shared_ptr *out, const std::shared_ptr &in) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *out = std::allocate_shared(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes()); - return Status::OK(); - } - - // A static factory method to create a Tensor from a given py::array. - // @param ptr output argument to hold the created Tensor - // @param arr py::array - // @return Status Code - static Status CreateTensor(std::shared_ptr *ptr, py::array arr); - - // Helper function to create a tensor from Numpy of strings - static Status CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr); - - // A static factory method to create a Tensor from a given list of strings. - // @param ptr output argument to hold the created Tensor - // @param strings elements of the tensor - // @param shape shape of the tensor - // @return Status Code - static Status CreateTensor(std::shared_ptr *ptr, const std::vector &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // create tensor from protobuf bytelist with strings - static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape); - - // A static factory method to create a Tensor from a given list of numbers. - // @param ptr output argument to hold the created Tensor - // @param items elements of the tensor - // @param shape shape of the tensor - // @return Status Code - template - static Status CreateTensor(std::shared_ptr *ptr, const std::vector &items, - const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) { - DataType type = DataType::FromCType(); - auto items_ptr = reinterpret_cast(&items[0]); - TensorShape shape = shape_req; - if (!shape.known()) { - shape = TensorShape({static_cast(items.size())}); - } - return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr); - } - - // A static factory method to create a Tensor from a given number. - // @param ptr output argument to hold the created Tensor - // @param item value - // @return Status Code - template - static Status CreateTensor(std::shared_ptr *ptr, const T &item) { - return CreateTensor(ptr, {item}, TensorShape::CreateScalar()); - } - // Create tensor from protobuf bytelist with uint8 or int8 types - static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape, const DataType &type, dsize_t pad_size); - - static Status CreateTensor(std::shared_ptr *ptr, const std::string &path); - - // Copy raw data of a array based on shape and strides to the destination pointer - // @param dst Pointer to the destination array where the content is to be copied - // @param src Pointer to the source of strided array to be copied - // @param shape - shape of the source array - // @param strides - strides of the source array - // @param type_size - number of bytes needed to store one array element's type - // @return Status Code - static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, - std::vector strides, uint8_t type_size); - - // Release the memory using the allocator - virtual ~Tensor(); - - // compare the tensor shape and data - bool operator==(const Tensor &rhs) const; - - bool operator!=(const Tensor &rhs) const { return !((*this) == rhs); } - - // Get item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return the item specified at index - template - Status GetItemAt(T *o, const std::vector &index) const; - - // Get string located at `index`. - // @param index vector - // @return return std::string_view specified at index - Status GetItemAt(std::string_view *o, const std::vector &index) const; - - template - Status GetUnsignedIntAt(T *o, const std::vector &index) const; - - template - Status GetSignedIntAt(T *o, const std::vector &index) const; - - template - Status GetFloatAt(T *o, const std::vector &index) const; - - // set item at location specified by index - // @tparam `T` - // @param index - // @param value of type `T` - template - Status SetItemAt(const std::vector &index, const T &value) { - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); - T *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *ptr = value; - return Status::OK(); - } - - // set string item at location specified by index - // @param index - // @param value of type std::string - Status SetItemAt(const std::vector &index, const std::string &value) { - RETURN_UNEXPECTED_IF_NULL(data_); - uchar *ptr = nullptr; - offset_t length = 0; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index, &length)); - if (value.length() != length) { - RETURN_STATUS_UNEXPECTED("Length of the new string does not match the item."); - } - memcpy_s(reinterpret_cast(ptr), length, value.c_str(), length); - - return Status::OK(); - } - // fill tensor with Zeros. Does not support strings. - Status Zero() { - CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use Zero on tensor of strings.."); - dsize_t size = SizeInBytes(); - CHECK_FAIL_RETURN_UNEXPECTED(memset_sp(GetMutableBuffer(), size, 0, size) == 0, - "Failed to fill tensor with zeroes."); - return Status::OK(); - } - - // Fill all elements in the Tensor with the given value of type `T`. Does not support strings. - // @tparam T - // @param value - template - Status Fill(const T &value) { - CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings."); - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); - int64_t cellSize = type_.SizeInBytes(); - if ((data_ != nullptr) && type_.IsCompatible()) { - for (dsize_t i = 0; i < Size(); i++) { - CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize) == 0, "memcpy err"); - } - return Status::OK(); - } else { - std::string err; - err += (data_ == nullptr) ? "data_ is nullptr \t" : ""; - err += type_.IsCompatible() ? "data type not compatible\t" : ""; - return Status(StatusCode::kUnexpectedError, err); - } - } - - // Getter function for shape - // @return - const TensorShape &shape() const { return shape_; } - - // Reshape the tensor. The given shape should have the same number of elements in the Tensor - // @param shape - virtual Status Reshape(const TensorShape &shape); - - // @return number of elements in this tensor - dsize_t Size() const { return shape().NumOfElements(); } - - // @return the number of bytes this tensor is needs - dsize_t SizeInBytes() const { - if (data_end_ == nullptr) return type_.SizeInBytes() * shape_.NumOfElements(); - return data_end_ - data_; - } - - // @return the rank of the tensor - dsize_t Rank() const { return shape().Rank(); } - - // Get the starting memory address as a constant for the data of the tensor. This potentially - // drives an allocation if the data area. - // @return const unsigned char* - const unsigned char *GetBuffer() const; - - // Getter of the type - // @return - DataType type() const { return type_; } - - // Provide stream operator for displaying it - // @param output stream - // @param so the Tensor object to be printed - // @return output stream - friend std::ostream &operator<<(std::ostream &out, const Tensor &so) { - so.Print(out); - return out; - } - - // Invalidate this Tensor by setting the type and shape to unknown and MData to null. - // Calling this method will make the Tensor and its data inaccessible, use it with caution. - void Invalidate(); - - // Copy input tensor into self at the location index. - // Index is a vector of axises which can be incomplete: - // Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. - // @param index - // @param input - // @return Status code - Status InsertTensor(const std::vector &index, const std::shared_ptr &input); - - // Find the address of the given index. Used in InsertTensor. - // Example: - // Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 - // @param index incomplete index - // @param output: startAddrofIndex - // @param output: remaining - // @return Status code - Status StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining); - - // Expand the shape of the Tensor with one extra dimension. - // For example, if the shape is <512,512,3>: - // *- ExpandDim(0) gives: <1,512,512,3> - // *- ExpandDim(1) gives: <512,1,512,3> - // *- ExpandDim(3) gives: <512,512,3,1> - // @param axis location of the dim - virtual Status ExpandDim(const dsize_t &axis); - - virtual void Squeeze(); - - /// Calculates the strides of the Tensor - /// Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) - /// The strides will be {6,2,1}. - /// Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) - /// The strides will be {24,8,4}. - /// @return vector of integers - std::vector Strides(); - - std::string ToString() { - std::stringstream ss; - this->Print(ss); - return ss.str(); - } - - // Handle negative indices. - static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } - - // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. - // Based on the type of tensor, SliceNumeric or SliceString will be called - // @param out Tensor - // @param indices vector of indices - // @return Status error code - Status Slice(std::shared_ptr *out, const std::vector &indices); - - // Slice numeric tensors. - Status SliceNumeric(std::shared_ptr *out, const std::vector &indices); - - // Slice string tensors - Status SliceString(std::shared_ptr *out, const std::vector &indices); - - // Constructs numpy array from input tensor - // @param data this data is the location of python data - // @return Status code - Status GetDataAsNumpy(py::array *data); - - Status GetDataAsNumpyStrings(py::array *data); - - static Status GetBufferInfo(Tensor &t, py::buffer_info *out); - - // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor - Status Concatenate(const std::vector &index, const std::shared_ptr &input); - - // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor - // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 - // @tparam T type of values in the Tensor Iterator - template - class TensorIterator { - public: - using iterator_category = std::random_access_iterator_tag; - using value_type = T; - using difference_type = ptrdiff_t; - using pointer = T *; - using reference = T &; - - explicit TensorIterator(uchar *ptr = nullptr) { ptr_ = reinterpret_cast(ptr); } - - TensorIterator(const TensorIterator &raw_iterator) { ptr_ = raw_iterator.ptr_; } - - ~TensorIterator() = default; - - TensorIterator &operator=(const TensorIterator &rhs) { - ptr_ = rhs.ptr_; - return *this; - } - - TensorIterator &operator=(T *rhs) { - ptr_ = rhs; - return *this; - } - - bool operator==(const TensorIterator &rhs) { return ptr_ == rhs.ptr_; } - - bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); } - - operator bool() const { return ptr_ != nullptr; } - - T &operator*() { return *ptr_; } - - const T &operator*() const { return *ptr_; } - - T *operator->() { return ptr_; } - - TensorIterator &operator+=(const ptrdiff_t &inc) { - ptr_ += inc; - return *this; - } - - TensorIterator &operator-=(const ptrdiff_t &inc) { - ptr_ -= inc; - return *this; - } - - TensorIterator &operator++() { - ++ptr_; - return *this; - } - - TensorIterator &operator--() { - --ptr_; - return *this; - } - - TensorIterator operator++(int) { - auto temp(*this); - ++ptr_; - return temp; - } - - TensorIterator operator--(int) { - auto temp(*this); - --ptr_; - return temp; - } - - TensorIterator operator+(const ptrdiff_t &inc) { - auto oldPtr = ptr_; - ptr_ += inc; - auto temp(*this); - ptr_ = oldPtr; - return temp; - } - - TensorIterator operator-(const ptrdiff_t &inc) { - auto oldPtr = ptr_; - ptr_ -= inc; - auto temp(*this); - ptr_ = oldPtr; - return temp; - } - - protected: - T *ptr_; - }; - - // Specialization of TensorIterator for strings. It returns std::string_view for every item. - // @tparam DUMMY, used to mbe able to specialize the inner class - template - class TensorIterator { - public: - using iterator_category = std::random_access_iterator_tag; - using value_type = std::string_view; - using difference_type = ptrdiff_t; - using pointer = std::string_view *; - using reference = std::string_view &; - - explicit TensorIterator(uchar *data = nullptr, dsize_t index = 0) { - data_ = reinterpret_cast(data); - index_ = index; - } - - TensorIterator(const TensorIterator &raw_iterator) { - data_ = raw_iterator.data_; - index_ = raw_iterator.index_; - } - - ~TensorIterator() = default; - - bool operator==(const TensorIterator &rhs) { return data_ == rhs.data_ && index_ == rhs.index_; } - - bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); } - - operator bool() const { return data_ != nullptr; } - - std::string_view operator*() const { - auto offset_ = reinterpret_cast(data_); - offset_t start = offset_[index_]; - return std::string_view{data_ + start}; - } - - TensorIterator &operator+=(const dsize_t &inc) { - index_ += inc; - return *this; - } - - TensorIterator &operator-=(const dsize_t &inc) { - index_ -= inc; - return *this; - } - - TensorIterator &operator++() { - ++index_; - return *this; - } - - TensorIterator &operator--() { - --index_; - return *this; - } - - TensorIterator operator++(int) { - auto temp(*this); - ++index_; - return temp; - } - - TensorIterator operator--(int) { - auto temp(*this); - --index_; - return temp; - } - - TensorIterator operator+(const dsize_t &inc) { - auto oldPtr = index_; - index_ += inc; - auto temp(*this); - index_ = oldPtr; - return temp; - } - - TensorIterator operator-(const dsize_t &inc) { - auto oldPtr = index_; - index_ -= inc; - auto temp(*this); - index_ = oldPtr; - return temp; - } - - protected: - dsize_t index_; - const char *data_; - }; - - // Return a TensorIterator that points to the start of the Tensor. - // It's the user responsibility to use the correct type that matches the Tensor type - // @tparam T The type of values in the Tensor - // @return TensorIterator - template - TensorIterator begin() { - AllocateBuffer(SizeInBytes()); - return TensorIterator(data_); - } - - // Return a linear iterator that points to the place after the last element of the Tensor. - // @tparam T The type of values in the Tensor - // @return TensorIterator - template - TensorIterator end() { - return TensorIterator(data_end_); - } - - // Copies the last dimension at `index` from Tensor `src` to this Tensor. - // @param src Tensor - // @param index vector to the start of the dimension. The last dim should be 0 - // @return Status - Status CopyLastDimAt(const std::shared_ptr &src, const std::vector &index); - - protected: - // Get the starting memory address for the data of the tensor. This potentially - // drives an allocation if the data is null. - // @return unsigned char* - unsigned char *GetMutableBuffer(); - - // A function that prints Tensor recursively, first called by print - // @param out - // @param cur_dim - // @param cur_index - void PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const; - - // A function that prints info about the tensor - // @param out output stream - void Print(std::ostream &out) const; - - // A function that print the value as specified by its index - // @param index vector representing the index - // @param out - void PrintItemAt(const std::vector &index, std::ostream &out) const; - - // Get pointer to item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return a pointer to the item specified at index of type `T` - template - Status GetItemPtr(T **, const std::vector &index) const; - - // Get pointer to string located at `index` and the length of string - // @param index vector - // @return return a pointer to the string specified at index and the length of the string - Status GetItemPtr(uchar **, const std::vector &index, offset_t *length = nullptr) const; - - // Given a flat index of an item string, return the start and length of the item - // @param index flat index of the item - // @return start address of the ths string - // @return length of the string - Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; - - // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the - // tensor's type is a string, otherwise undefined address would be returned. - // @return address of the first string of the tensor. - uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } - - // all access to shape_ should be via shape - TensorShape shape_; - // data type of tensor - DataType type_; - // pointer to the start of the physical data - unsigned char *data_; - // An allocator for data_ - CharAllocPtr data_allocator_; - // pointer to the end of the physical data - unsigned char *data_end_ = nullptr; -}; -template <> -inline Tensor::TensorIterator Tensor::end() { - return TensorIterator(data_, shape_.NumOfElements()); -} -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_TENSOR_H_ diff --git a/mindspore/ccsrc/dataset/core/tensor_row.cc b/mindspore/ccsrc/dataset/core/tensor_row.cc deleted file mode 100644 index 882f6728bf..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor_row.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "dataset/core/tensor_row.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { - -TensorRow::TensorRow() noexcept : id_(kDefaultRowId) {} - -TensorRow::TensorRow(size_type n, TensorRow::value_type t) noexcept : id_(kDefaultRowId), row_(n, t) {} - -TensorRow::TensorRow(const TensorRow::vector_type &v) : id_(kDefaultRowId), row_(v) {} - -TensorRow::TensorRow(row_id_type id, const std::initializer_list &lst) : id_(id), row_(lst) {} - -TensorRow::TensorRow(const TensorRow &tr) : id_(tr.id_), row_(tr.row_) {} - -TensorRow &TensorRow::operator=(const TensorRow &tr) { - if (this == &tr) { - return *this; - } - row_ = tr.row_; - id_ = tr.id_; - return *this; -} - -TensorRow &TensorRow::operator=(const std::initializer_list &lst) { - row_ = lst; - return *this; -} - -TensorRow::TensorRow(TensorRow::vector_type &&v) noexcept : id_(kDefaultRowId), row_(std::move(v)) {} - -TensorRow::TensorRow(row_id_type id, std::initializer_list &&lst) noexcept - : id_(id), row_(std::move(lst)) {} - -TensorRow::TensorRow(TensorRow &&tr) noexcept { - id_ = tr.id_; - row_ = std::move(tr.row_); -} - -TensorRow &TensorRow::operator=(TensorRow &&tr) noexcept { - if (this == &tr) { - return *this; - } - row_ = std::move(tr.row_); - id_ = tr.id_; - tr.id_ = kDefaultRowId; - return *this; -} - -TensorRow &TensorRow::operator=(std::initializer_list &&lst) noexcept { - row_ = std::move(lst); - return *this; -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/tensor_row.h b/mindspore/ccsrc/dataset/core/tensor_row.h deleted file mode 100644 index 49bc61657c..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor_row.h +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_CORE_TENSOR_ROW_H_ -#define DATASET_CORE_TENSOR_ROW_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" - -namespace mindspore { -namespace dataset { - -class TensorRow; // A set of Tensor pointers with an id -using TensorTable = std::vector; // The table of tensors is a vector of rows -using TensorQTable = std::deque; // A different flavour of tensor table, this one has queue functionality - -class TensorRow { - public: - static constexpr row_id_type kDefaultRowId = -1; // Default row id - - // Type definitions - using size_type = dsize_t; - using value_type = std::shared_ptr; - using reference = std::shared_ptr &; - using const_reference = const std::shared_ptr &; - using vector_type = std::vector>; - using iterator = std::vector>::iterator; - using const_iterator = std::vector>::const_iterator; - - TensorRow() noexcept; - - TensorRow(size_type n, value_type t) noexcept; - - // Copy Constructors - explicit TensorRow(const vector_type &v); - - TensorRow(row_id_type id, const std::initializer_list &lst); - - TensorRow(const TensorRow &tr); - - TensorRow &operator=(const TensorRow &tr); - - TensorRow &operator=(const std::initializer_list &lst); - - // Move Constructors - explicit TensorRow(vector_type &&v) noexcept; - - TensorRow(row_id_type id, std::initializer_list &&lst) noexcept; - - TensorRow(TensorRow &&tr) noexcept; - - TensorRow &operator=(TensorRow &&tr) noexcept; - - TensorRow &operator=(std::initializer_list &&lst) noexcept; - - // Destructor - ~TensorRow() = default; - - // Functions to fetch/set id/vector - row_id_type getId() const { return id_; } - - void setId(row_id_type id) { id_ = id; } - - const vector_type &getRow() const { return row_; } - - // Wrapper functions to support vector operations - void emplace_back(value_type t) { row_.emplace_back(t); } - - void push_back(value_type t) { row_.push_back(t); } - - void clear() noexcept { row_.clear(); } - - size_type size() const noexcept { return row_.size(); } - - void reserve(size_type size) { row_.reserve(size); } - - void resize(size_type size) { row_.resize(size); } - - bool empty() { return row_.empty(); } - - void insert(iterator position, iterator first, iterator last) { row_.insert(position, first, last); } - - // Wrapper functions to support vector element access - reference at(size_type index) { return row_.at(index); } - - const_reference at(size_type index) const { return row_.at(index); } - - reference front() { return row_.front(); } - - const_reference front() const { return row_.front(); } - - reference back() { return row_.back(); } - - const_reference back() const { return row_.back(); } - - reference operator[](size_type index) { return row_[index]; } - - const_reference operator[](size_type index) const { return row_[index]; } - - // Wrapper functions to support vector iteration - iterator begin() { return row_.begin(); } - - const_iterator begin() const { return row_.begin(); } - - iterator end() { return row_.end(); } - - const_iterator end() const { return row_.end(); } - - protected: - row_id_type id_; - std::vector> row_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_TENSOR_ROW_H_ diff --git a/mindspore/ccsrc/dataset/core/tensor_shape.cc b/mindspore/ccsrc/dataset/core/tensor_shape.cc deleted file mode 100644 index a0d6b9cd8d..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor_shape.cc +++ /dev/null @@ -1,231 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#define MAX_INTEGER_DTYPE 9223372036854775807 - -#include "dataset/core/tensor_shape.h" - -#include - -#include "common/utils.h" -#include "utils/log_adapter.h" -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { -constexpr dsize_t TensorShape::kDimUnknown; - -bool multi_ok(dsize_t x, dsize_t y) { - dsize_t p = x * y; - if (x == 0) { - return true; - } - return p / x == y; -} - -dsize_t TensorShape::NumOfElements() const { - if (!known()) { - return 0; - } - return strides_[0]; -} - -void TensorShape::Print(std::ostream &out) const { - if (!known() && raw_shape_.empty()) { - out << ""; - } else { - out << "<"; - for (auto i = 0; i < this->Rank(); i++) { - if (raw_shape_[i] == kDimUnknown) { - out << "*"; - } else { - out << raw_shape_[i]; - } - if (i != this->Rank() - 1) { - out << ","; - } - } - out << ">"; - } -} - -TensorShape::TensorShape(const std::initializer_list &list) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - AddListToShape(list); -} - -TensorShape::TensorShape(const std::vector &list) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - AddListToShape(list); -} - -TensorShape::TensorShape(const TensorShape &shape) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - AddListToShape(shape.AsVector()); - known_ = shape.known_; // override with the input shape in case of unknown-rank tensor shape. -} - -TensorShape::TensorShape(py::list l) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - std::vector list_c; - for (auto &i : l) { - if (!i.is_none()) { - list_c.push_back(i.cast()); - } else { - list_c.push_back(TensorShape::kDimUnknown); - } - } - AddListToShape(list_c); -} - -TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - for (int i = 0; i < cv_size.dims(); i++) { - raw_shape_.push_back(cv_size[i]); - } - auto channels = static_cast(1 + (type >> static_cast(CV_CN_SHIFT))); - if (channels != 1) { - raw_shape_.push_back(channels); - } - known_ = true; -} - -TensorShape TensorShape::CreateUnknownRankShape() { - TensorShape s({}); - s.known_ = false; - return s; -} - -TensorShape TensorShape::InsertDim(dsize_t axis, dsize_t dim) const { - std::vector tmp = AsVector(); - (void)tmp.insert(tmp.begin() + axis, dim); - return TensorShape(tmp); -} - -std::vector TensorShape::AsVector() const { - return std::vector(raw_shape_.begin(), raw_shape_.end()); -} - -bool TensorShape::IsValidIndex(const std::vector &index) const { - dsize_t s_rank = Rank(); - if (index.size() != s_rank) { - return false; - } - for (dsize_t i = 0; i < s_rank; i++) { - if (index[i] < 0 || raw_shape_[i] <= index[i]) { - return false; - } - } - return true; -} - -template -void TensorShape::AddListToShape(const T &list) { - raw_shape_.resize(list.size()); - strides_.resize(list.size() + 1); - strides_[list.size()] = 1; - known_ = true; - dsize_t size = 0; - auto itr = std::rbegin(list); // iterate over the list in reverse order - auto s = list.size() - 1; // to compute strides while adding dims - for (; itr != std::rend(list); itr++, s--) { - dsize_t dim = *itr; - if (dim > 0) { - if (strides_[s + 1] > std::numeric_limits::max() / dim) { - MS_LOG(ERROR) << "Invalid shape data, overflow occurred!"; - known_ = false; - raw_shape_.clear(); - return; - } - strides_[s] = dim * strides_[s + 1]; - } - if (dim < 0) { - known_ = false; - } - if (dim > kDeMaxDim) { - std::stringstream ss; - ss << "Invalid shape data, dim (" << size << ") is larger than the maximum dim size(" << kDeMaxDim << ")!"; - MS_LOG(ERROR) << ss.str().c_str(); - known_ = false; - raw_shape_.clear(); - return; - } - raw_shape_[s] = dim; - size++; - } - if (size > kDeMaxRank) { - std::stringstream ss; - ss << "Invalid shape data, rank (" << size << ") is larger than the maximum rank size(" << kDeMaxRank << ")."; - MS_LOG(ERROR) << ss.str().c_str(); - known_ = false; - raw_shape_.clear(); - return; - } -} - -TensorShape TensorShape::CreateUnknownShapeWithRank(dsize_t rank) { - TensorShape s({}); - for (dsize_t i = 0; i < rank; i++) { - s.raw_shape_.push_back(kDimUnknown); - } - s.known_ = false; - return s; -} - -TensorShape TensorShape::PrependDim(dsize_t dim) const { - if (Size() == 0) { - return TensorShape({dim}); - } - return InsertDim(0, dim); -} - -TensorShape TensorShape::AppendDim(dsize_t dim) const { - auto vec = AsVector(); - vec.push_back(dim); - return TensorShape(vec); -} - -py::list TensorShape::AsPyList() { - py::list list; - for (auto i : raw_shape_) { - list.append(i); - } - return list; -} - -TensorShape TensorShape::Squeeze() const { - std::vector new_shape; - for (auto s : AsVector()) { - if (s != 1) { - new_shape.push_back(s); - } - } - return TensorShape(new_shape); -} - -std::vector TensorShape::Strides() const { return std::vector{strides_.begin() + 1, strides_.end()}; } - -// Name: ToFlatIndex() -// Description: convert a vector style index to number, used to access memory internal use only -Status TensorShape::ToFlatIndex(const std::vector &index, dsize_t *flat_index) const { - *flat_index = 0; - for (size_t k = 0; k < index.size(); k++) { - *flat_index += index[k] * strides_[k + 1]; // skip the first element of strides_ which is numOfElements - } - CHECK_FAIL_RETURN_UNEXPECTED(*flat_index < NumOfElements(), "Not a valid index"); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/tensor_shape.h b/mindspore/ccsrc/dataset/core/tensor_shape.h deleted file mode 100644 index c83e43cd7d..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor_shape.h +++ /dev/null @@ -1,187 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_TENSOR_SHAPE_H_ -#define DATASET_CORE_TENSOR_SHAPE_H_ - -#include -#include -#include -#include -#include - -#include - -#include "pybind11/pybind11.h" - -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/util/allocator.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { -// Class that represents a shape of a Tensor. A shape can be: -// -# Known shape (mKnown = true) -// -# Scalar --> empty vector --> <> -// -# n-Dim --> not empty vector --> where di is >= 0\n -// Example: <1,2>, <1>, <1,13,10,11,1> -// -# Unknown shape (mKnown = false) -// -# Rank is unknown --> empty vector --> <> -// -# one or more dim is unknown --> not empty vector --> where di is unknown\n -// Example: <3,?> (the 1st dim is unknown)\n -// <2,?,?,?> (all dims but the 0th dim are unknown) -// TensorShape supports any dim > 0 and < 2^31-1 -class TensorShape { - public: - static constexpr dsize_t kDimUnknown = -1; // constant for an unknown dimension - - // Force the compiler to not create a no-arg constructor - TensorShape() = delete; - - // Create a Shape from an initialization list (e.g., TensorShape s = {2,2}). - // If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown - // @param list - explicit TensorShape(const std::initializer_list &list); - - // Create a Shape from a vector (e.g., TensorShape s = std::vector({2,2}) ). - // If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown - // @param list - explicit TensorShape(const std::vector &list); - - // Copy constructor - // @param shape - TensorShape(const TensorShape &shape); - - // construct a TensorShape via a python list - // @param py::list l - a list object from python - explicit TensorShape(py::list l); - - ~TensorShape() = default; - - // Create a scalar Shape (i.e., empty shape with mKnown = true) - // @return TensorShape - static TensorShape CreateScalar() { return TensorShape({}); } - - // Create a shape with an unknown rank. - // @return TensorShape - static TensorShape CreateUnknownRankShape(); - - // Create a shape with a known rank . - // @return TensorShape - static TensorShape CreateUnknownShapeWithRank(dsize_t rank); - - // Insert a new dim into a copy of the current shape. - // @param dim to be added - // @param axis the index where dim should be added - // @return New modified shape - TensorShape InsertDim(dsize_t axis, dsize_t dim) const; - - // Insert new dim at index 0. For example, <2,4> --> PrependDim(4) --> <4,2,4> - // @param dim - // @return - TensorShape PrependDim(dsize_t dim) const; - - // Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4> - // @param dim - // @return - TensorShape AppendDim(dsize_t dim) const; - - // Create a shape based on OpenCV shape and type - // @param cv_size - // @param type int that represent the type in OpenCV, example CV_8U, CV_64S - TensorShape(cv::MatSize cv_size, uint32_t type); - - dsize_t Size() const { return raw_shape_.size(); } - - dsize_t Rank() const { return raw_shape_.size(); } - - bool known() const { return known_; } - - bool empty() const { return raw_shape_.empty(); } - - dsize_t NumOfElements() const; - - bool operator==(const TensorShape &rhs) const { return known_ == rhs.known_ && raw_shape_ == rhs.raw_shape_; } - - bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); } - - dsize_t operator[](const dsize_t index) const { - if (index < 0) return raw_shape_[raw_shape_.size() + index]; - return raw_shape_[index]; - } - - // Return the Shape as a vector - // @return - std::vector AsVector() const; - - // Returns the class info as a string - // @return - std::string ToString() const { - std::stringstream ss; - ss << *this; - return ss.str(); - } - - // Actual print function used by operator<< - // @param out output string stream - void Print(std::ostream &out) const; - - // << Stream output operator overload - // @notes This allows you to print the info using stream operators - // @param out - reference to the output stream being overloaded - // @param rO - reference to the TensorShape to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const TensorShape &so) { - so.Print(out); - return out; - } - - py::list AsPyList(); - - // Checks if the given index is a valid index for this tensor. - // For example: Tensor<3,4> Index<1,1> is valid. But Index<4,1> or <1> are not. - // @param index - // @return bool - bool IsValidIndex(const std::vector &index) const; - - TensorShape Squeeze() const; - - std::vector Strides() const; - - // Returns the location of the item assuming row major memory layout. - // @param index - // @return - Status ToFlatIndex(const std::vector &index, dsize_t *flat_index) const; - - private: - // True if known and valid shape, false otherwise - bool known_; - // Vector to keep the dims of the shape. - std::vector raw_shape_; - // Vector to keep the strides of the shape. The size is rank+1 - std::vector strides_; - - // Internal utility function to iterate over a list, check if the dim is valid and then insert it into the shape. - // @tparam T list - // @param list Iterable list - // @return true if the shape is valid and no overflow would be generated when counting the number of elements. - // False otherwise. - template - void AddListToShape(const T &list); -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_TENSOR_SHAPE_H_ diff --git a/mindspore/ccsrc/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/CMakeLists.txt deleted file mode 100644 index 66f95d0926..0000000000 --- a/mindspore/ccsrc/dataset/engine/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -add_subdirectory(datasetops) -add_subdirectory(opt) -add_subdirectory(gnn) -add_subdirectory(perf) -if (ENABLE_TDTQUE) - add_subdirectory(tdt) -endif () - -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(engine OBJECT - execution_tree.cc - data_buffer.cc - data_schema.cc - dataset_iterator.cc - ) -target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) - -if (ENABLE_TDTQUE) - add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf) -else() - add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf) -endif () diff --git a/mindspore/ccsrc/dataset/engine/connector.h b/mindspore/ccsrc/dataset/engine/connector.h deleted file mode 100644 index bd66172be5..0000000000 --- a/mindspore/ccsrc/dataset/engine/connector.h +++ /dev/null @@ -1,211 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_CONNECTOR_H_ -#define DATASET_ENGINE_CONNECTOR_H_ - -#include -#include -#include -#include -#include "dataset/util/task_manager.h" -#include "dataset/util/queue.h" -#include "dataset/util/services.h" -#include "dataset/util/cond_var.h" - -namespace mindspore { -namespace dataset { -// Connector is a communication data structure between two group of threads that -// preserve the order. -// -// Example use case: -// An initial tasks-list of [1,2,3,4,5,6,7,8,9] with 5 threads getting/processing elements from that list, -// and pushing the processed elements to a Connector in any order whoever finishes processing first. -// If the consumer of the Connector is single threaded, when the consumer pop() the -// element from the Connector one by one, it will get [1,2,3,4,5,6,7,8,9]. -// -// Requirements: -// 1. Each thread in the group of consumer or producer threads must be assigned ids starting from 0. -// 2. If your multi-threads program is not reading from a Connector class but -// want to push to a Connector class, you must follow roundrobin element distribution, -// i.e., the thread-id0 must have the first element, thread-id1 has the second element, -// and so on; then each of this worker can push to the Connector class async in parallel. -// -// Blocking conditions: -// 1. Connector.push(int, T) can block when the internal queue it's trying to push is full. -// 2. Connector.pop(int) can block when -// - The internal queue it's trying to pop is empty. -// - The caller thread of pop() is not equal to the _expectConsumer. This is to enforce -// the ordering. -// -// Future improvement: -// 1. Fault tolerant: Right now, if one of the worker dies, the Connector will not work -// properly. -template -class Connector { - public: - // Name: Constructor - // Description: Initializing private members with the given input arguments. - // expect_consumer_ and pop_from_ is initialized to 0 as part of - // our requirements. We instantiate nProducers number of internal - // queues so that each producer thread can push to its queue without - // any sync overhead. - // Constructor of Connector - // Initializing private members with the given input arguments. - // _expectConsumer and _popFrom is initialized to 0 as part of - // our requirements. We instantiate nProducers number of internal - // queues so that each producer thread can push to its queue without - // any sync overhead. - // @param n_producers The number of threads producing data into this DbConnector. - // @param n_consumers The number of thread consuming data from this DbConnector. - // @param queue_capacity The number of element (DataBuffer) for each queue. - Connector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity) - : num_producers_(n_producers), num_consumers_(n_consumers) { - MS_LOG(DEBUG) << "A connector is created with " << n_producers << " producers and " << n_consumers << " consumers."; - my_name_ = Services::GetUniqueID(); - // We require the consumers to have ids sequentially from 0 to the num_consumers_-1, - // Otherwise a ordered list of consumer ids have to be passed here. (not implemented yet) - expect_consumer_ = 0; - - // Roundrobin pop starts from index 0 of the queues_. - pop_from_ = 0; - - // Initialize the queues_ to have num_producers_ number of queues. - // Each queue is a blocking queue and has the same queue_capacity. - queues_.Init(num_producers_, queue_capacity); - } - - // Destructor of Connector - virtual ~Connector() = default; - - // Get an element from the Connector. - // @not Call to pop() can block the caller thread, see the blocking condition at the top of this file. - // @param worker_id The id of a worker thread calling this method. - // @param result The address of an object where the popped element will be placed. - virtual Status Pop(int32_t worker_id, // The worker-id of the caller. See the requirement at the top of this file. - T *result) noexcept { - { - MS_ASSERT(worker_id < num_consumers_); - std::unique_lock lk(m_); - RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return expect_consumer_ == worker_id; })); - RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); - pop_from_ = (pop_from_ + 1) % num_producers_; - out_buffers_count_++; - expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; - } - - cv_.NotifyAll(); - return Status::OK(); - } - - // Add an element into the DbConnector without the overhead of synchronization. - // It may block when the internal queue is full. - // The element passed to this function will be copied into the internal queue. - // @param worker_id The id of a worker thread calling this method. - // @param el A const lvalue element to be passed/added/pushed. - Status Push(int32_t worker_id, const T &el) noexcept { - MS_ASSERT(worker_id < static_cast(queues_.size())); - MS_ASSERT(queues_[worker_id] != nullptr); - return (queues_[worker_id]->Add(el)); - } - - auto out_buffers_count() const { return out_buffers_count_.load(); } - - // Add an element into the DbConnector without the overhead of synchronization. - // It may block when the internal queue is full. - // The element passed to this function will be forwarded into the internal queue. - // @param worker_id The id of a worker thread calling this method. - // @param el An element to be passed/added/pushed. - virtual Status Push(int32_t worker_id, T &&el) noexcept { - MS_ASSERT(worker_id < static_cast(queues_.size())); - MS_ASSERT(queues_[worker_id] != nullptr); - return (queues_[worker_id]->Add(std::forward(el))); - } - - // Resets the internal index tracking of the queue so that it can be used again with new inputs, - // starting from the beginning. - void Reset() { - for (int i = 0; i < queues_.size(); ++i) { - queues_[i]->ResetQue(); - } - expect_consumer_ = 0; - pop_from_ = 0; - out_buffers_count_ = 0; - MS_LOG(DEBUG) << "Connector counters reset."; - } - - void Print(std::ostream &out, bool showAll) const { - out << "\n--------- Connector ------------" - << "\nConnector Name : " << my_name_ << "\nNumber of consumers : " << num_consumers_ - << "\nNumber of producers : " << num_producers_ << "\n"; - } - - friend std::ostream &operator<<(std::ostream &out, const Connector &con) { - con.print(out, false); - return out; - } - - // Get current size of connector. - int32_t size() const { - int32_t size = 0; - for (int32_t i = 0; i < queues_.size(); ++i) { - size += queues_[i]->size(); - } - return size; - } - - int32_t capacity() const { - int32_t capacity = 0; - for (int32_t i = 0; i < queues_.size(); ++i) { - capacity += queues_[i]->capacity(); - } - return capacity; - } - - // Register the internal resources with Task group for interruption service. - // @param vg - // @return - Status Register(TaskGroup *vg) { - Status rc = queues_.Register(vg); - if (rc.IsOk()) { - rc = cv_.Register(vg->GetIntrpService()); - } - return rc; - } - - protected: - std::string my_name_; - - // A list of Queues that are thread safe. - QueueList queues_; - - // The consumer that we allow to get the next data from pop() - int32_t expect_consumer_; - - // The index to the queues_ where the next data should be popped. - int32_t pop_from_; - - int32_t num_producers_; - int32_t num_consumers_; - - // Used in the Pop(), when a thread call pop() but it is not the expect_consumer_. - std::mutex m_; - CondVar cv_; - std::atomic out_buffers_count_ = 0; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_CONNECTOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/data_buffer.cc b/mindspore/ccsrc/dataset/engine/data_buffer.cc deleted file mode 100644 index 32a70c259f..0000000000 --- a/mindspore/ccsrc/dataset/engine/data_buffer.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/data_buffer.h" -#include "dataset/util/allocator.h" -#include "dataset/core/global_context.h" -#include "dataset/core/tensor.h" - -namespace mindspore { -namespace dataset { -// Name: Constructor #1 -// Description: This is the main constructor that is used for making a buffer -DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {} - -// Name: print() -// Description: A function that prints info about the DataBuffer (base class version) -void DataBuffer::Print(std::ostream &out, // In: The output stream to print to - bool show_all) const { // In: T/F if it should show everything - out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n"; - - // If the column counts are set then it means that data has been set into - // the tensor table. Display the tensor table here. - if (this->NumCols() > 0) { - out << "Tensor table:\n"; - for (int32_t row = 0; row < DataBuffer::NumRows(); ++row) { - out << "Row # : " << row << "\n"; - TensorRow currRow = (*tensor_table_)[row]; - for (int32_t col = 0; col < this->NumCols(); ++col) { - out << "Column #: " << col << "\n"; // Should add the column name here as well? - // Call the tensor display - out << *(currRow[col]) << "\n"; - } - } - } -} - -Status DataBuffer::Load() { - std::string err_msg = "Base class load called, but it does not have an implementation!"; - RETURN_STATUS_UNEXPECTED(err_msg); -} - -// Remove me!! Callers should fetch rows via pop -Status DataBuffer::GetTensor(std::shared_ptr *ptr, int32_t row_id, int32_t col_id) const { - if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) { - *ptr = (tensor_table_->at(row_id)).at(col_id); - } else { - std::string err_msg = - "indices for mTensorTable out of range: (" + std::to_string(row_id) + "," + std::to_string(col_id) + ")."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -// Remove me!! Callers should fetch rows via pop -Status DataBuffer::GetRow(int32_t row_id, TensorRow *ptr) const { - if (tensor_table_ && !tensor_table_->empty() && row_id < tensor_table_->size()) { - *ptr = tensor_table_->at(row_id); - } else { - std::string err_msg = "rowId for mTensorTable out of range: " + std::to_string(row_id); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - return Status::OK(); -} - -Status DataBuffer::PopRow(TensorRow *ptr) { - if (tensor_table_ && !tensor_table_->empty()) { - *ptr = std::move(tensor_table_->front()); - tensor_table_->pop_front(); - } - - return Status::OK(); -} - -Status DataBuffer::SliceOff(int64_t number_of_rows) { - while (number_of_rows > 0) { - tensor_table_->pop_back(); - number_of_rows--; - } - - return Status::OK(); -} - -// Destructor -DataBuffer::~DataBuffer() {} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/data_buffer.h b/mindspore/ccsrc/dataset/engine/data_buffer.h deleted file mode 100644 index 2ab0783519..0000000000 --- a/mindspore/ccsrc/dataset/engine/data_buffer.h +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATA_BUFFER_H_ -#define DATASET_ENGINE_DATA_BUFFER_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/status.h" -#include "dataset/core/constants.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_row.h" - -namespace mindspore { -namespace dataset { -// The DataBuffer class is a base class that will represent the data for n values based -// on a unique row id for each row of data. -// There can be different types of DataBuffers to abstract over how the data is stored -// in memory and acquired from storage. -// Each buffer holds a range of consecutive row id's. -class DataBuffer { - public: - // Buffer flags - enum BufferFlags : uint32_t { - kDeBFlagNone = 0, - kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg - kDeBFlagEOE = 1u << 1 // The buffer is an eoe end-of-epoch msg - }; - - // Name: Constructor #1 - // Description: This is the main constructor that is used for making a buffer - DataBuffer(int32_t id, BufferFlags flags); - - // Destructor - virtual ~DataBuffer(); - - // Name: print() - // Description: A function that prints info about the DataBuffer (base class version) - virtual void Print(std::ostream &out, // In: The output stream to print to - bool show_all) const; // In: T/F if it should show everything - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) { - cb.Print(out, false); - return out; - } - - // Name: load() - // Description: populates the DataBuffer with data based on it's id - virtual Status Load(); - - // Convenience getter functions for flag checking - bool eof() const { return (static_cast(buffer_flags_) & static_cast(kDeBFlagEOF)); } - - bool eoe() const { return (static_cast(buffer_flags_) & static_cast(kDeBFlagEOE)); } - - // Simple getter funcs - int32_t id() const { return buffer_id_; } - - void set_id(int32_t id) { buffer_id_ = id; } - - int32_t NumRows() const { return ((tensor_table_) ? tensor_table_->size() : 0); } - - int32_t NumCols() const { - return (tensor_table_ == nullptr || tensor_table_->empty()) ? 0 : tensor_table_->at(0).size(); - } - - BufferFlags buffer_flags() const { return buffer_flags_; } - - // Remove me!! Callers should fetch rows via pop - Status GetTensor(std::shared_ptr *, int32_t row_id, int32_t col_id) const; - - // Remove me!! Callers should drain rows via pop. - Status GetRow(int32_t row_id, TensorRow *) const; - - // Get a row from the TensorTable - Status PopRow(TensorRow *); - - Status SliceOff(int64_t number_of_rows); - - // Replacing mTensorTable, the unique_ptr assignment will release the old TensorTable. - void set_tensor_table(std::unique_ptr new_table) { tensor_table_ = std::move(new_table); } - - void set_flag(BufferFlags in_flag) { - buffer_flags_ = static_cast(static_cast(buffer_flags_) | static_cast(in_flag)); - } - - void Shuffle() {} // does nothing right now. possibly remove later - - protected: - int32_t buffer_id_; // An id for the buffer. - std::unique_ptr tensor_table_; // A table (row major) of Tensors - BufferFlags buffer_flags_; // bit mask for various buffer properties -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATA_BUFFER_H_ diff --git a/mindspore/ccsrc/dataset/engine/data_schema.cc b/mindspore/ccsrc/dataset/engine/data_schema.cc deleted file mode 100644 index 6c5f882bed..0000000000 --- a/mindspore/ccsrc/dataset/engine/data_schema.cc +++ /dev/null @@ -1,451 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/data_schema.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/util/status.h" -#include "dataset/core/tensor_shape.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -// A macro for converting an input string representing the column type to it's actual -// numeric column type. -#define STR_TO_TENSORIMPL(in_col_str, out_type) \ - do { \ - if (in_col_str == "cvmat") { \ - out_type = TensorImpl::kCv; \ - } else if (in_col_str == "flex") { \ - out_type = TensorImpl::kFlexible; \ - } else if (in_col_str == "np") { \ - out_type = TensorImpl::kNP; \ - } else { \ - out_type = TensorImpl::kNone; \ - } \ - } while (false) - -// Constructor 1: Simple constructor that leaves things uninitialized. -ColDescriptor::ColDescriptor() - : type_(DataType::DE_UNKNOWN), rank_(0), tensor_impl_(TensorImpl::kNone), tensor_shape_(nullptr) {} - -// Constructor 2: Main constructor -ColDescriptor::ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank, - const TensorShape *in_shape) - : type_(col_type), rank_(rank), tensor_impl_(tensor_impl), col_name_(col_name) { - // If a shape was provided, create unique pointer for it and copy construct it into - // our shape. Otherwise, set our shape to be empty. - if (in_shape != nullptr) { - // Create a shape and copy construct it into our column's shape. - tensor_shape_ = std::make_unique(*in_shape); - } else { - tensor_shape_ = nullptr; - } - // If the user input a shape, then the rank of the input shape needs to match - // the input rank - if (in_shape != nullptr && in_shape->known() && in_shape->Size() != rank_) { - rank_ = in_shape->Size(); - MS_LOG(WARNING) << "Rank does not match the number of dimensions in the provided shape." - << " Overriding rank with the number of dimensions in the provided shape."; - } -} - -// Explicit copy constructor is required -ColDescriptor::ColDescriptor(const ColDescriptor &in_cd) - : type_(in_cd.type_), rank_(in_cd.rank_), tensor_impl_(in_cd.tensor_impl_), col_name_(in_cd.col_name_) { - // If it has a tensor shape, make a copy of it with our own unique_ptr. - tensor_shape_ = in_cd.hasShape() ? std::make_unique(in_cd.shape()) : nullptr; -} - -// Assignment overload -ColDescriptor &ColDescriptor::operator=(const ColDescriptor &in_cd) { - if (&in_cd != this) { - type_ = in_cd.type_; - rank_ = in_cd.rank_; - tensor_impl_ = in_cd.tensor_impl_; - col_name_ = in_cd.col_name_; - // If it has a tensor shape, make a copy of it with our own unique_ptr. - tensor_shape_ = in_cd.hasShape() ? std::make_unique(in_cd.shape()) : nullptr; - } - return *this; -} - -// Destructor -ColDescriptor::~ColDescriptor() = default; - -// A print method typically used for debugging -void ColDescriptor::Print(std::ostream &out) const { - out << " Name : " << col_name_ << "\n Type : " << type_ << "\n Rank : " << rank_ - << "\n Shape : ("; - if (tensor_shape_) { - out << *tensor_shape_ << ")\n"; - } else { - out << "no shape provided)\n"; - } -} - -// Given a number of elements, this function will compute what the actual Tensor shape would be. -// If there is no starting TensorShape in this column, or if there is a shape but it contains -// an unknown dimension, then the output shape returned shall resolve dimensions as needed. -Status ColDescriptor::MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const { - if (out_shape == nullptr) { - RETURN_STATUS_UNEXPECTED("Unexpected null output shape argument."); - } - - // If the shape is not given in this column, then we assume the shape will be: {numElements} - if (tensor_shape_ == nullptr) { - if (this->rank() == 0 && num_elements == 1) { - *out_shape = TensorShape::CreateScalar(); - return Status::OK(); - } - *out_shape = TensorShape({num_elements}); - return Status::OK(); - } - - // Build the real TensorShape based on the requested shape and the number of elements in the data. - // If there are unknown dimensions, then the unknown dimension needs to be filled in. - // Example: requestedShape: {?,4,3}. - // If numElements is 24, then the output shape can be computed to: {2,4,3} - std::vector requested_shape = tensor_shape_->AsVector(); - int64_t num_elements_of_shape = 1; // init to 1 as a starting multiplier. - - // unknownDimPosition variable is overloaded to provide 2 meanings: - // 1) If it's set to DIM_UNKNOWN, then it provides a boolean knowledge to tell us if there are - // any unknown dimensions. i.e. if it's set to unknown, then there are no unknown dimensions. - // 2) If it's set to a numeric value, then this is the vector index position within the shape - // where the single unknown dimension can be found. - int64_t unknown_dim_position = TensorShape::kDimUnknown; // Assume there are no unknown dims to start - - for (int i = 0; i < requested_shape.size(); ++i) { - // If we already had an unknown dimension, then we cannot have a second unknown dimension. - // We only support the compute of a single unknown dim. - if (requested_shape[i] == TensorShape::kDimUnknown && unknown_dim_position != TensorShape::kDimUnknown) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Requested shape has more than one unknown dimension!"); - } - - // If the current dimension in the requested shape is a known value, then compute the number of - // elements so far. - if (requested_shape[i] != TensorShape::kDimUnknown) { - num_elements_of_shape *= requested_shape[i]; - } else { - // This dimension is unknown so track which dimension position has it. - unknown_dim_position = i; - } - } - - // Sanity check the the computed element counts divide evenly into the input element count - if (num_elements < num_elements_of_shape || num_elements_of_shape == 0 || num_elements % num_elements_of_shape != 0) { - RETURN_STATUS_UNEXPECTED("Requested shape has an invalid element count!"); - } - - // If there was any unknown dimensions, then update the requested shape to fill in the unknown - // dimension with the correct value. If there were no unknown dim's then the output shape will - // remain to be the same as the requested shape. - if (unknown_dim_position != TensorShape::kDimUnknown) { - requested_shape[unknown_dim_position] = (num_elements / num_elements_of_shape); - } - - // Any unknown dimension is filled in now. Set the output shape - *out_shape = TensorShape(requested_shape); - return Status::OK(); -} - -// getter function for the shape -TensorShape ColDescriptor::shape() const { - if (tensor_shape_ != nullptr) { - return *tensor_shape_; // copy construct a shape to return - } else { - return TensorShape::CreateUnknownRankShape(); // empty shape to return - } -} - -const char DataSchema::DEFAULT_DATA_SCHEMA_FILENAME[] = "datasetSchema.json"; - -// Constructor 1: Simple constructor that leaves things uninitialized. -DataSchema::DataSchema() : num_rows_(0) {} - -// Internal helper function. Parses the json schema file in any order and produces a schema that -// does not follow any particular order (json standard does not enforce any ordering protocol). -// This one produces a schema that contains all of the columns from the schema file. -Status DataSchema::AnyOrderLoad(nlohmann::json column_tree) { - // Iterate over the json file. Each parent json node is the column name, - // followed by the column properties in the child tree under the column. - // Outer loop here iterates over the parents (i.e. the column name) - if (!column_tree.is_array()) { - for (nlohmann::json::iterator it = column_tree.begin(); it != column_tree.end(); ++it) { - std::string col_name = it.key(); - nlohmann::json column_child_tree = it.value(); - RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, col_name)); - } - } else { - // Case where the schema is a list of columns not a dict - for (nlohmann::json::iterator it = column_tree.begin(); it != column_tree.end(); ++it) { - nlohmann::json column_child_tree = it.value(); - RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, "")); - } - } - return Status::OK(); -} - -// Internal helper function. For each input column name, perform a lookup to the json document to -// find the matching column. When the match is found, process that column to build the column -// descriptor and add to the schema in the order in which the input column names are given.id -Status DataSchema::ColumnOrderLoad(nlohmann::json column_tree, const std::vector &columns_to_load) { - if (!column_tree.is_array()) { - // the json file is dict (e.g., {image: ...}) - // Loop over the column name list - for (const auto &curr_col_name : columns_to_load) { - // Find the column in the json document - auto column_info = column_tree.find(common::SafeCStr(curr_col_name)); - if (column_info == column_tree.end()) { - RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); - } - // At this point, columnInfo.value() is the subtree in the json document that contains - // all of the data for a given column. This data will formulate our schema column. - const std::string &col_name = column_info.key(); - nlohmann::json column_child_tree = column_info.value(); - RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, col_name)); - } - } else { - // the json file is array (e.g., [name: image...]) - // Loop over the column name list - for (const auto &curr_col_name : columns_to_load) { - // Find the column in the json document - int32_t index = -1; - int32_t i = 0; - for (const auto &it_child : column_tree.items()) { - auto name = it_child.value().find("name"); - if (name == it_child.value().end()) { - RETURN_STATUS_UNEXPECTED("Name field is missing for this column."); - } - if (name.value() == curr_col_name) { - index = i; - break; - } - i++; - } - if (index == -1) { - RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); - } - nlohmann::json column_child_tree = column_tree[index]; - RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, curr_col_name)); - } - } - return Status::OK(); -} - -// Internal helper function for parsing shape info and building a vector for the shape construction. -static Status buildShape(const nlohmann::json &shapeVal, std::vector *outShape) { - if (outShape == nullptr) { - RETURN_STATUS_UNEXPECTED("null output shape"); - } - if (shapeVal.empty()) return Status::OK(); - - // Iterate over the integer list and add those values to the output shape tensor - auto items = shapeVal.items(); - using it_type = decltype(items.begin()); - (void)std::transform(items.begin(), items.end(), std::back_inserter(*outShape), [](it_type j) { return j.value(); }); - return Status::OK(); -} - -// Internal helper function. Given the json tree for a given column, load it into our schema. -Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name) { - int32_t rank_value = -1; - TensorImpl t_impl_value = TensorImpl::kFlexible; - std::string name, type_str; - std::vector tmp_shape = {}; - bool shape_field_exists = false; - // Iterate over this column's attributes. - // Manually iterating each of the child nodes/trees here so that we can provide our own error handling. - for (const auto &it_child : column_child_tree.items()) { - // Save the data for each of the attributes into variables. We'll use these to construct later. - if (it_child.key() == "name") { - name = it_child.value(); - } else if (it_child.key() == "type") { - type_str = it_child.value(); - } else if (it_child.key() == "rank") { - rank_value = it_child.value(); - } else if (it_child.key() == "t_impl") { - STR_TO_TENSORIMPL(it_child.value(), t_impl_value); - } else if (it_child.key() == "shape") { - shape_field_exists = true; - RETURN_IF_NOT_OK(buildShape(it_child.value(), &tmp_shape)); - } else { - std::string err_msg = "Unexpected column attribute " + it_child.key() + " for column " + col_name; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - if (!name.empty()) { - if (!col_name.empty() && col_name != name) { - std::string err_msg = - "json schema file for column " + col_name + " has column name that does not match columnsToLoad"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } else { - if (col_name.empty()) { - std::string err_msg = "json schema file for column " + col_name + " has invalid or missing column name."; - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - name = col_name; - } - } - // data type is mandatory field - if (type_str.empty()) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "json schema file for column " + col_name + " has invalid or missing column type."); - - // rank number is mandatory field - if (rank_value <= -1) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "json schema file for column " + col_name + " must define a positive rank value."); - - // Create the column descriptor for this column from the data we pulled from the json file - TensorShape col_shape = TensorShape(tmp_shape); - if (shape_field_exists) - (void)this->AddColumn(ColDescriptor(name, DataType(type_str), t_impl_value, rank_value, &col_shape)); - else - // Create a column descriptor that doesn't have a shape - (void)this->AddColumn(ColDescriptor(name, DataType(type_str), t_impl_value, rank_value)); - return Status::OK(); -} - -// Parses a schema json file and populates the columns and meta info. -Status DataSchema::LoadSchemaFile(const std::string &schema_file_path, - const std::vector &columns_to_load) { - try { - std::ifstream in(schema_file_path); - - nlohmann::json js; - in >> js; - RETURN_IF_NOT_OK(PreLoadExceptionCheck(js)); - try { - num_rows_ = js.at("numRows").get(); - } catch (nlohmann::json::out_of_range &e) { - num_rows_ = 0; - } catch (nlohmann::json::exception &e) { - RETURN_STATUS_UNEXPECTED("Unable to parse \"numRows\" from schema"); - } - nlohmann::json column_tree = js.at("columns"); - if (column_tree.empty()) { - RETURN_STATUS_UNEXPECTED("columns is null"); - } - if (columns_to_load.empty()) { - // Parse the json tree and load the schema's columns in whatever order that the json - // layout decides - RETURN_IF_NOT_OK(this->AnyOrderLoad(column_tree)); - } else { - RETURN_IF_NOT_OK(this->ColumnOrderLoad(column_tree, columns_to_load)); - } - } catch (const std::exception &err) { - // Catch any exception and convert to Status return code - RETURN_STATUS_UNEXPECTED("Schema file failed to load"); - } - return Status::OK(); -} - -// Parses a schema json string and populates the columns and meta info. -Status DataSchema::LoadSchemaString(const std::string &schema_json_string, - const std::vector &columns_to_load) { - try { - nlohmann::json js = nlohmann::json::parse(schema_json_string); - RETURN_IF_NOT_OK(PreLoadExceptionCheck(js)); - num_rows_ = js.value("numRows", 0); - nlohmann::json column_tree = js.at("columns"); - if (column_tree.empty()) { - RETURN_STATUS_UNEXPECTED("columns is null"); - } - if (columns_to_load.empty()) { - // Parse the json tree and load the schema's columns in whatever order that the json - // layout decides - RETURN_IF_NOT_OK(this->AnyOrderLoad(column_tree)); - } else { - RETURN_IF_NOT_OK(this->ColumnOrderLoad(column_tree, columns_to_load)); - } - } catch (const std::exception &err) { - // Catch any exception and convert to Status return code - RETURN_STATUS_UNEXPECTED("Schema file failed to load"); - } - return Status::OK(); -} - -// Destructor -DataSchema::~DataSchema() = default; - -// Getter for the ColDescriptor by index -const ColDescriptor &DataSchema::column(int32_t idx) const { - MS_ASSERT(idx < static_cast(col_descs_.size())); - return col_descs_[idx]; -} - -// A print method typically used for debugging -void DataSchema::Print(std::ostream &out) const { - out << "Dataset schema: ("; - for (const auto &col_desc : col_descs_) { - out << col_desc << "\n"; - } -} - -// Adds a column descriptor to the schema -Status DataSchema::AddColumn(const ColDescriptor &cd) { - // Sanity check there's not a duplicate name before adding the column - for (int32_t i = 0; i < col_descs_.size(); ++i) { - if (col_descs_[i].name() == cd.name()) { - std::ostringstream ss; - ss << "column name '" << cd.name() << "' already exists in schema."; - std::string err_msg = ss.str(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - col_descs_.push_back(cd); - return Status::OK(); -} - -// Internal helper function. Performs sanity checks on the json file setup. -Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) { - // Check if columns node exists. It is required for building schema from file. - if (js.find("columns") == js.end()) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "\"columns\" node is required in the schema json file."); - return Status::OK(); -} - -// Loops through all columns in the schema and returns a map with the column -// name to column index number. -Status DataSchema::GetColumnNameMap(std::unordered_map *out_column_name_map) { - if (out_column_name_map == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map."); - } - - for (int32_t i = 0; i < col_descs_.size(); ++i) { - if (col_descs_[i].name().empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Constructing column name map from schema, but found empty column name."); - } - (*out_column_name_map)[col_descs_[i].name()] = i; - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/data_schema.h b/mindspore/ccsrc/dataset/engine/data_schema.h deleted file mode 100644 index ce61b8952d..0000000000 --- a/mindspore/ccsrc/dataset/engine/data_schema.h +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATA_SCHEMA_H_ -#define DATASET_ENGINE_DATA_SCHEMA_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -/// \class ColDescriptor data_schema.h -/// \brief A simple class to provide meta info about a column. -class ColDescriptor { - public: - /// \brief Constructor 1: Simple constructor that leaves things uninitialized. - ColDescriptor(); - - /// \brief Constructor 2: Main constructor - /// \param[in] col_name - The name of the column - /// \param[in] col_type - The DE Datatype of the column - /// \param[in] tensor_impl - The (initial) type of tensor implementation for the column - /// \param[in] rank - The number of dimension of the data - /// \param[in] in_shape - option argument for input shape - ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank, - const TensorShape *in_shape = nullptr); - - /// \brief Explicit copy constructor is required - /// \param[in] in_cd - the source ColDescriptor - ColDescriptor(const ColDescriptor &in_cd); - - /// \brief Assignment overload - /// \param in_cd - the source ColDescriptor - ColDescriptor &operator=(const ColDescriptor &in_cd); - - /// \brief Destructor - ~ColDescriptor(); - - /// \brief A print method typically used for debugging - /// \param out - The output stream to write output to - void Print(std::ostream &out) const; - - /// \brief Given a number of elements, this function will compute what the actual Tensor shape would be. - /// If there is no starting TensorShape in this column, or if there is a shape but it contains - /// an unknown dimension, then the output shape returned shall resolve dimensions as needed. - /// \param[in] num_elements - The number of elements in the data for a Tensor - /// \param[inout] out_shape - The materialized output Tensor shape - /// \return Status - The error code return - Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; - - /// \brief << Stream output operator overload - /// This allows you to write the debug print info using stream operators - /// \param[in] out - reference to the output stream being overloaded - /// \param[in] cd - reference to the ColDescriptor to display - /// \return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const ColDescriptor &cd) { - cd.Print(out); - return out; - } - - /// \brief getter function - /// \return The column's DataType - DataType type() const { return type_; } - - /// \brief getter function - /// \return The column's rank - int32_t rank() const { return rank_; } - - /// \brief getter function - /// \return The column's name - std::string name() const { return col_name_; } - - /// \brief getter function - /// \return The column's shape - TensorShape shape() const; - - /// \brief getter function - /// \return TF if the column has an assigned fixed shape. - bool hasShape() const { return tensor_shape_ != nullptr; } - - /// \brief getter function - /// \return The column's tensor implementation type - TensorImpl tensorImpl() const { return tensor_impl_; } - - private: - DataType type_; // The columns type - int32_t rank_; // The rank for this column (number of dimensions) - TensorImpl tensor_impl_; // The initial flavour of the tensor for this column - std::unique_ptr tensor_shape_; // The fixed shape (if given by user) - std::string col_name_; // The name of the column -}; - -/// \class DataSchema data_schema.h -/// \brief A list of the columns. -class DataSchema { - public: - /// \brief Constructor - DataSchema(); - - /// \brief Destructor - ~DataSchema(); - - /// \brief Parses a schema json file and populates the columns and meta info. - /// \param[in] schema_file_path - the schema file that has the column's info to load - /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. - /// \return Status - The error code return - Status LoadSchemaFile(const std::string &schema_file_path, const std::vector &columns_to_load); - - /// \brief Parses a schema JSON string and populates the columns and meta info. - /// \param[in] schema_json_string - the schema file that has the column's info to load - /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. - /// \return Status - The error code return - Status LoadSchemaString(const std::string &schema_json_string, const std::vector &columns_to_load); - - /// \brief A print method typically used for debugging - /// \param[in] out - The output stream to write output to - void Print(std::ostream &out) const; - - /// \brief << Stream output operator overload. This allows you to write the debug print info using stream operators - /// \param[in] out - reference to the output stream being overloaded - /// \param[in] ds - reference to the DataSchema to display - /// \return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const DataSchema &ds) { - ds.Print(out); - return out; - } - - /// \brief Adds a column descriptor to the schema - /// \param[in] cd - The ColDescriptor to add - /// \return Status - The error code return - Status AddColumn(const ColDescriptor &cd); - - /// \brief getter - /// \return The reference to a ColDescriptor to get (const version) - const ColDescriptor &column(int32_t idx) const; - - /// \brief getter - /// \return The number of columns in the schema - int32_t NumColumns() const { return col_descs_.size(); } - - bool Empty() const { return NumColumns() == 0; } - - /// \brief getter - /// \return The number of rows read from schema - int64_t num_rows() const { return num_rows_; } - - static const char DEFAULT_DATA_SCHEMA_FILENAME[]; - - /// \brief Loops through all columns in the schema and returns a map with the column name to column index number. - /// \param[inout] out_column_name_map - The output map of columns names to column index - /// \return Status - The error code return - Status GetColumnNameMap(std::unordered_map *out_column_name_map); - - private: - /// \brief Internal helper function. Parses the json schema file in any order and produces a schema that - /// does not follow any particular order (json standard does not enforce any ordering protocol). - /// This one produces a schema that contains all of the columns from the schema file. - /// \param[in] column_tree - The nlohmann tree from the json file to parse - /// \return Status - The error code return - Status AnyOrderLoad(nlohmann::json column_tree); - - /// \brief Internal helper function. For each input column name, perform a lookup to the json document to - /// find the matching column. When the match is found, process that column to build the column - /// descriptor and add to the schema in the order in which the input column names are given. - /// \param[in] column_tree - The nlohmann tree from the json file to parse - /// \param[in] columns_to_load - list of strings for the columns to add to the schema - /// \return Status - The error code return - Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector &columns_to_load); - - /// \brief Internal helper function. Given the json tree for a given column, load it into our schema. - /// \param[in] columnTree - The nlohmann child tree for a given column to load. - /// \param[in] col_name - The string name of the column for that subtree. - /// \return Status - The error code return - Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); - - /// \brief Internal helper function. Performs sanity checks on the json file setup. - /// \param[in] js - The nlohmann tree for the schema file - /// \return Status - The error code return - Status PreLoadExceptionCheck(const nlohmann::json &js); - - std::vector col_descs_; // Vector of column descriptors - int64_t num_rows_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATA_SCHEMA_H_ diff --git a/mindspore/ccsrc/dataset/engine/dataset_iterator.cc b/mindspore/ccsrc/dataset/engine/dataset_iterator.cc deleted file mode 100644 index be333741b1..0000000000 --- a/mindspore/ccsrc/dataset/engine/dataset_iterator.cc +++ /dev/null @@ -1,268 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/dataset_iterator.h" -#include -#include -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/status.h" -#include "dataset/engine/datasetops/dataset_op.h" - -namespace mindspore { -namespace dataset { -// Constructor of the IteratorBase -IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false) {} - -IteratorBase::~IteratorBase() = default; - -// Fetches one row of data from the iterator as a column map. -Status IteratorBase::GetNextAsMap(TensorMap *out_map) { - if (out_map == nullptr) { - RETURN_STATUS_UNEXPECTED("Null output map in iterator!"); - } - - out_map->clear(); - - TensorRow curr_row; - RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row)); - - // Return empty map if there's no data - if (curr_row.empty()) { - return Status::OK(); - } - - // The column name mapping is needed to be able to produce the tensor map output. - // The column name mapping comes from the source operator that is producing the data into the iterator. - // To avoid having to fetch this for every time, we'll take a local copy of the column name id mapping - // and save in the iterator. We only have to do this once. All subsequent iterations use the same mapping. - if (col_name_id_map_.empty()) { - // Determine the column name map by calling the derived class method to retrieve the column - // name map - col_name_id_map_ = this->GetColumnNameMap(); - } - - // Populate the out map from the row and return it - for (auto colMap : col_name_id_map_) { - (*out_map)[colMap.first] = std::move(curr_row[colMap.second]); - } - - return Status::OK(); -} - -// Fetches one row of data from the iterator. -// The base class version simply performs error handling and returns empty row. Actual -// functionality exists in the derived versions of this function. -Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) { - if (out_row == nullptr) { - RETURN_STATUS_UNEXPECTED("Null output row in iterator!"); - } - - // clear the old tensor row - out_row->clear(); - - return Status::OK(); -} - -// Constructor of the DatasetIterator -DatasetIterator::DatasetIterator(std::shared_ptr exe_tree) - : IteratorBase(), - root_(exe_tree->root()), - tracing_(nullptr), - cur_batch_num_(0), - cur_connector_size_(0), - cur_connector_capacity_(0) { - std::shared_ptr node; - Status s = exe_tree->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node); - if (s.IsOk()) { - tracing_ = std::dynamic_pointer_cast(node); - } -} - -DatasetIterator::~DatasetIterator() = default; - -// Fetches one row of data from the iterator. Overrides the base class. This one fetches -// from the tree root node directly. -Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) { - // Common code init and error checking in the base class. - RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row)); - - // Once eof is handled, always return empty row. Class must be destroyed and recreated if you - // want to iterate again. - if (eof_handled_) { - return Status::OK(); - } - - // Check if we need to get a new DataBuffer to iterate. - if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { - if (tracing_ != nullptr) { - cur_connector_size_ = root_->ConnectorSize(); - cur_connector_capacity_ = root_->ConnectorCapacity(); - } - RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); - - // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually - // handle eoe and eof messages here. - // - // An eoe buffer means we have iterated fully to the end of the tree. - // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of - // all operators. - if (curr_buffer_->eoe()) { - MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row."; - - // Before returning the last empty vector, fetch the eof buffer which should be the last - // buffer, and then free it. - RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); - - if (!curr_buffer_->eof()) { - RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!"); - } - eof_handled_ = true; - curr_buffer_.reset(); // explicitly free the eof buffer - // Set tree to Finished state - root_->Tree()->SetFinished(); - - return Status::OK(); - } - - if (curr_buffer_->eof()) { - // An eof by itself, without being preceded by an eoe, is possible if a repeat operator - // exists below us in the stack. Repeat operator eats eoe's but eventually allows the - // flow of an eof up the pipeline by itself. - eof_handled_ = true; - curr_buffer_.reset(); // explicitly free the eof buffer - // Set tree to Finished state - root_->Tree()->SetFinished(); - return Status::OK(); - } - } - - // If we got this far, now it's time to pop that next row for return to caller - RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row)); - if (tracing_ != nullptr) { - cur_batch_num_++; - tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_); - } - return Status::OK(); -} - -Status DatasetIterator::GetOutputShapes(std::vector *out_shapes) { - if (out_shapes == nullptr) { - RETURN_STATUS_UNEXPECTED("Null output shape argument"); - } - if (device_queue_row_.empty()) { - RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); - } - for (auto ts : device_queue_row_) { - out_shapes->push_back(ts->shape()); - } - - return Status::OK(); -} - -Status DatasetIterator::GetOutputTypes(std::vector *out_types) { - if (out_types == nullptr) { - RETURN_STATUS_UNEXPECTED("Null output type argument"); - } - if (device_queue_row_.empty()) { - RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); - } - for (auto ts : device_queue_row_) { - out_types->push_back(ts->type()); - } - return Status::OK(); -} - -// Getter -std::unordered_map DatasetIterator::GetColumnNameMap() const { - return root_->column_name_id_map(); -} - -// Constructor of the ChildIterator -ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx) - : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {} - -ChildIterator::~ChildIterator() { current_op_ = nullptr; } - -// Fetches one row of data from the iterator. Overrides the base class. This one fetches -// only from the child/worker id as given from the constructor. -Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) { - // Common code init and error checking in the base class. - RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row)); - - // Once eof is handled, always return empty row. Class must be destroyed and recreated if you - // want to iterate again. - if (eof_handled_) { - return Status::OK(); - } - - // Check if we need to get a new DataBuffer to iterate. - if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { - RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); - - // Unlike the DatasetIterator, this child iterator does not quit after eoe. - // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the - // caller to decide what it wants to do next. - if (curr_buffer_->eoe()) { - MS_LOG(DEBUG) << "Child iterator picked up EOE."; - end_epoch_ = true; - return Status::OK(); - } - - if (curr_buffer_->eof()) { - MS_LOG(DEBUG) << "Child iterator picked up EOF."; - eof_handled_ = true; - return Status::OK(); - } - } - - // If we got this far, now it's time to pop that next row for return to caller - RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row)); - - return Status::OK(); -} - -// drain till the next eoe -Status ChildIterator::Drain() { - if (end_epoch_ == true) { - // Calling drain against a child that is already at it's eoe state will not result in any action. - // This allows you to do: - // - fetch until empty row - // - drain (will not actually drain because you are already at the end of the iteration) - // However, the next time after that, it will perform it's normal draining activities. - end_epoch_ = false; - MS_LOG(DEBUG) << "No operation drain, already at end of epoch."; - return Status::OK(); - } - MS_LOG(DEBUG) << "Child draining buffers until eoe."; - // else we drain until eoe or eof, eof here is for sanity check - while (!curr_buffer_->eoe() && !curr_buffer_->eof()) { - RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); - } - if (curr_buffer_->eof()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain."); - } - return Status::OK(); -} - -// Getter -std::unordered_map ChildIterator::GetColumnNameMap() const { - return current_op_->child(child_idx_)->column_name_id_map(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/dataset_iterator.h b/mindspore/ccsrc/dataset/engine/dataset_iterator.h deleted file mode 100644 index 4e40e77c74..0000000000 --- a/mindspore/ccsrc/dataset/engine/dataset_iterator.h +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASET_ITERATOR_H_ -#define DATASET_ENGINE_DATASET_ITERATOR_H_ - -#include -#include -#include -#include -#include "dataset/util/status.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/perf/dataset_iterator_tracing.h" - -namespace mindspore { -namespace dataset { -using TensorMap = std::unordered_map>; - -// forward declare -class ExecutionTree; - -class DataBuffer; - -// IteratorBase class is used to iterate data from an executionTree one row at a time. -// The base class provides the general interface, whereas derived classes provide slightly -// different implementations. -class IteratorBase { - public: - // Constructor of IteratorBase - IteratorBase(); - - // Destructor - virtual ~IteratorBase(); - - // Fetches one row of data from the iterator. - // the base class version simply performs error handling and returns empty row. Actual - // functionality exists in the derived versions of this function. - // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data - // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. - // @return Status - The error code return - // @note The position of a Tensor/column might be different from the initial column order - // in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change - // the column ordering. - virtual Status FetchNextTensorRow(TensorRow *out_row); - - // Fetches one row of data from the iterator as a column map. - // @return A unordered map from column name to shared pointer to Tensor. - Status GetNextAsMap(TensorMap *out_map); - - // Getter - // @return T/F if this iterator is completely done after getting an eof - bool eof_handled() const { return eof_handled_; } - - // Getter - // @return The string to column id mapping. - virtual std::unordered_map GetColumnNameMap() const = 0; - - protected: - std::unique_ptr curr_buffer_; // holds the current buffer - bool eof_handled_; // T/F if this op got an eof - std::unordered_map col_name_id_map_; -}; - -// The DatasetIterator derived class is for fetching rows off the end/root of the execution tree. -class DatasetIterator : public IteratorBase { - public: - // Constructor of the DatasetIterator - // @param exe_tree The execution tree we want to pull/iterate the data from using it's root node. - explicit DatasetIterator(std::shared_ptr exe_tree); - - // Destructor - ~DatasetIterator(); - - // Fetches one row of data from the iterator. Overrides the base class. This one fetches - // from the tree root node directly. - // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data - // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. - // @return Status - The error code return - Status FetchNextTensorRow(TensorRow *out_row) override; - - // Fetches the next tensor row into device row, and returns it's shape. - // @param out_shapes - A vector of tensor shapes (one shape per column) - // @return Status - The error code return - Status GetOutputShapes(std::vector *out_shapes); - - // Fetches the next tensor row into device row, and returns it's shape. - // @param outShapes - A vector of tensor shapes (one shape per column) - // @return Status - The error code return - Status GetOutputTypes(std::vector *out_types); - - // Getter - // @return The string to column id mapping. - std::unordered_map GetColumnNameMap() const override; - - private: - std::shared_ptr root_; // saves the root of the executionTree - TensorRow device_queue_row_; - std::shared_ptr tracing_; // trace profiling data - int32_t cur_batch_num_; // current batch number,used for profiling - int32_t cur_connector_size_; // current connector size of root op,used for profiling - int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling -}; - -// The ChildIterator derived class is for fetching rows from intermediate nodes of execution tree. -// This one should only be used by internal Dataset operators, rather than an end-user. -class ChildIterator : public IteratorBase { - public: - // Constructor of the DatasetIterator - // @param current_op - The parent op from which we'll fetch from it's children. - // @param worker_id - The worker id to use when fetching from the children. - // @param child_idx - The index to the child to fetch from. - ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx); - - // Destructor - ~ChildIterator(); - - // Fetches one row of data from the iterator. Overrides the base class. This one fetches - // only from the child/worker id as given from the constructor. - // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data - // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. - // @return Status - The error code return - Status FetchNextTensorRow(TensorRow *out_row) override; - - // This function drains buffer until next eoe has been received. - // It will be a no-op if the previous row returned is empty. - // @return Status - The error code return - Status Drain(); - - // Getter - // @return The string to column id mapping. - std::unordered_map GetColumnNameMap() const override; - - private: - DatasetOp *current_op_; // The parent operator. We consume from it's children. - int32_t child_idx_; // The specific child this iterator will fetch from. - int32_t worker_id_; // The worker id uses for fetching the child data. - bool end_epoch_; // the flag used when an empty row has been returned. -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASET_ITERATOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt deleted file mode 100644 index ed57421030..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -add_subdirectory(source) - -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(engine-datasetops OBJECT - dataset_op.cc - parallel_op.cc - pipeline_op.cc - barrier_op.cc - batch_op.cc - bucket_batch_by_length_op.cc - device_queue_op.cc - map_op.cc - project_op.cc - rename_op.cc - repeat_op.cc - skip_op.cc - take_op.cc - shuffle_op.cc - zip_op.cc - concat_op.cc - filter_op.cc - build_vocab_op.cc - ) - diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc deleted file mode 100644 index 6fc276a75e..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc +++ /dev/null @@ -1,242 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/barrier_op.h" -#include -#include -#include "dataset/core/constants.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -BarrierOp::Builder::Builder() { - // Some arguments to the BarrierOp constructor have a default argument that is taken - // from the client config. - // The user may choose to change these values for the construction of the BarrierOp by - // using the various builder set methods. - - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } - -Status BarrierOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, - builder_condition_func_); - return Status::OK(); -} - -// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions -BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, - py::function condition_func) - : PipelineOp(op_connector_size), - rows_per_buffer_(rows_per_buffer), - buffer_id_(0), - clean_up_(false), - eof_(false), - condition_name_(condition_name), - condition_function_(condition_func) {} - -// destructor -BarrierOp::~BarrierOp() {} - -// Entry point for Barrier, called by launch() -Status BarrierOp::operator()() { - // The children_num_ parameter needs to be put here - // Synchronize with TaskManager once the thread is created. - TaskManager::FindMe()->Post(); - - // create child iterator, right now this barrier is a pipeline operator - const int32_t worker_id = 0; - const int32_t child_idx = 0; - child_iterator_ = std::make_unique(this, worker_id, child_idx); - - // Loop until eof is true - while (!eof_) { - // Create new table to put the new tensor rows - std::unique_ptr curr_table = std::make_unique(); - RETURN_IF_NOT_OK(prepare(curr_table.get())); - - // If an eof got picked up during the above prepare, then we're done - if (eof_) { - break; - } - - // we have to output new buffer with possibly different buffer size, possibly one row - while (!clean_up_) { - // 1. If a previous loop iteration sent the current table out, then create a new one. - - if (curr_table == nullptr) { - curr_table = std::make_unique(); - } - - // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished - RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); - - // 3 create and update buffer and send it to the out connector - if (!curr_table->empty()) { - std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); - curr_buffer->set_tensor_table(std::move(curr_table)); - MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " - << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << "."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); - buffer_id_++; - } - } - - // 4 handle drain state. - if (clean_up_) { - MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; - // Send the eoe up. - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - } - } - // 5 handle eof - // propagate eof here. - MS_LOG(INFO) << "Barrier operator got EOF, propagating."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - return Status::OK(); -} - -// Handles preprocessing of the main loop, used when starting new epoch -Status BarrierOp::prepare(TensorQTable *const table) { - MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; - clean_up_ = false; - buffer_id_ = 0; - if (table == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); - } - // fill initial row - TensorRow new_row = {}; - // use iterator to get next row and invoke pyfunc wait - RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); - - // If the first row fetching resulted in eof, then we are done. - if (eof_) { - return Status::OK(); - } - if (new_row.empty()) { - // This epoch is empty - return Status::OK(); - } - // Pack this first row into our tensor table - // first row we also have to check if we should block - RETURN_IF_NOT_OK(blockCond()); - - table->push_back(std::move(new_row)); - - // the update code below shouldn't do anything bad if the column name already exists. - return Status::OK(); -} - -// fillBuffer always expects a new table to fill -Status BarrierOp::fillBuffer(TensorQTable *const table) { - if (table == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); - } - TensorRow new_row = {}; - while (table->size() < static_cast(rows_per_buffer_)) { - RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); - // Early exit the loop if we got empty row from any of our child iterations - if (new_row.empty()) { - return Status::OK(); - } - // else we got a row so pack it into the tensor table. - RETURN_IF_NOT_OK(blockCond()); - - table->push_back(std::move(new_row)); - } - return Status::OK(); -} - -// function executes a py_func and blocks until condition becomes true. -Status BarrierOp::blockCond() { - { - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - // we have condition name, however the flexibility is in python today - try { - // Invoke python function - py::object ret_py_obj = condition_function_(); - // Process the return value - if (!py::isinstance(ret_py_obj)) { - return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } - } - return Status::OK(); -} - -// fetches next Barrier buffer row -Status BarrierOp::getNextTensorRow(TensorRow *new_row) { - // iterate over all iterators and generate a row - RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); - // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row - if (new_row->empty()) { - // If we did not get a row from any of the children, then it's the end of an epoch and we can move - // to drain state. - MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; - clean_up_ = true; - // If we picked up an eof here, then we are completely done. - if ((child_iterator_)->eof_handled()) { - MS_LOG(INFO) << "Barrier operator iterator got EOF."; - eof_ = true; - } - return Status::OK(); - } - return Status::OK(); -} - -// A function that prints info about the Operator -void BarrierOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nCondition: " << condition_name_ << "\n\n"; - } -} - -// overwrite function and handle eof -Status BarrierOp::EofReceived(int32_t) { - MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; - return Status::OK(); -} - -// overwrite function and handle eoe -Status BarrierOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h deleted file mode 100644 index 379b8f146b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h +++ /dev/null @@ -1,169 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ - -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -// Forward declare -class DataBuffer; -class ExecutionTree; - -// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has -// been received. This signal is given from python layer. The current barrier design respects the -// rows per buffer design and will only output a buffer with rows once it has received rows per buffer -// signals from python. - -class BarrierOp : public PipelineOp { - public: - // The nested builder class inside of the BarrierOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @param int32_t op_connector_size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @param const std::string & condition_name - // @return Builder setter method returns reference to the builder. - Builder &SetConditionName(const std::string &condition_name) { - builder_condition_name_ = condition_name; - return *this; - } - - // Setter method. - // @param py::function condition_func - blocking condition function - // @return Builder setter method returns reference to the builder. - Builder &SetConditionFunc(py::function condition_func) { - builder_condition_func_ = condition_func; - return *this; - } - - // The builder "build" method creates the BarrierOp dataset Operator. - // @return shared_ptr to the new BarrierOp object - Status Build(std::shared_ptr *); - - private: - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::string builder_condition_name_; - py::function builder_condition_func_; - - Status SanityCheck() const; - }; - - // Constructor for BarrierOp - // @param rows_per_buffer - number of rows in output buffer - // @param op_connector_size - connector size - // @param condition_name - the condition name associated with this operator - // @param condition_func - the blocking condition check per row - // @note - currently rows_per_buffer should = 1 for barrier. - // The reason for this is having other values would complicate how the pipeline behaves with other operators - // One example of such case is having batch after barrier. Batch would be waiting for data and having - // rows per buffer in this case can result in hanging - BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, - py::function condition_func); - - // Destructor - ~BarrierOp(); - - Status EofReceived(int32_t) override; - - Status EoeReceived(int32_t) override; - - // Print function for Barrier - // @param out - output stream to print to - // @param show_all - if it should print everything - void Print(std::ostream &out, bool show_all) const override; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { - bo.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Handles preprocessing of the main loop, used when starting new epoch - // @param table - a table of tensors to be moved into a buffer - Status prepare(TensorQTable *const table); - - // This function calls takes a table repeatedly adds rows to it. - // @param table - a table of tensors to be moved into a buffer - Status fillBuffer(TensorQTable *const table); - - // Gets next tensor row and sets control signals - Status getNextTensorRow(TensorRow *new_row); - - // This function runs the wait function on condition - Status blockCond(); - - private: - // clean up variable to return imcomplete buffer - bool clean_up_; - // end of file state, we stop reading data and shut down - bool eof_; - // rows per buffer - int32_t rows_per_buffer_; - // buffer_id - int32_t buffer_id_; - // iterator to pull new rows, we only have one child - std::unique_ptr child_iterator_; - // condition name, to support multiple barriers - std::string condition_name_; - // Function pointer of blocking function - py::function condition_function_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc deleted file mode 100644 index 8bfa8c287c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc +++ /dev/null @@ -1,416 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/batch_op.h" - -#include -#include - -#include "common/utils.h" -#include "dataset/core/pybind_support.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/kernels/data/data_utils.h" - -using float16 = Eigen::half; - -namespace mindspore { -namespace dataset { -BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false), builder_pad_(false), builder_pad_map_({}) { - builder_batch_size_ = batch_size; - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status BatchOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_, - builder_num_workers_, builder_cols_to_map_, builder_batch_size_func_, - builder_batch_map_func_, builder_pad_map_); - return Status::OK(); -} - -Status BatchOp::Builder::SanityCheck() { - std::string err; - err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; - err += builder_batch_size_ <= 0 ? "batch size <= 0\n" : ""; - err += builder_num_workers_ <= 0 ? "batch num_parallel_workers <= 0\n" : ""; - return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); -} - -BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, - const std::vector &cols_to_map, py::function batch_size_func, py::function batch_map_func, - PadInfo pad_map) - : ParallelOp(num_workers, op_queue_size), - start_batch_size_(batch_size), - drop_(drop), - pad_(pad), - pyfunc_column_names_(cols_to_map), - batch_size_func_(batch_size_func), - batch_map_func_(batch_map_func), - pad_info_(pad_map) { - worker_queues_.Init(num_workers, op_queue_size); -} - -Status BatchOp::operator()() { - Status rc = LaunchThreadsAndInitOp(); - // Synchronize with TaskManager - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(rc); - int64_t epoch_num = 0, batch_num = 0, cnt = 0; - TensorRow new_row; - std::unique_ptr table = std::make_unique(); - child_iterator_ = std::make_unique(this, 0, 0); - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - int32_t cur_batch_size = 0; - RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(0, 0, 0))); - while (child_iterator_->eof_handled() == false) { - while (new_row.empty() == false) { - table->emplace_back(new_row); - // if # of rows is enough to make 1 batch (1 batch is buffer), send it to worker_queue - if (table->size() == static_cast(cur_batch_size)) { - RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack( - std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num)))); - table = std::make_unique(); - RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num))); - } - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - } - // Reminder logic, execute only when there is a remainder (table is non empty) and don't drop - if (drop_ == false && table->empty() == false) { - RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack( - std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num)))); - } - table = std::make_unique(); // this drops when drop == true - // end of the current epoch, batch_num should start from 0 again - batch_num = 0; - epoch_num++; - RETURN_IF_NOT_OK( - worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE)))); - RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num))); - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - } // end of eof_handled() == false - RETURN_IF_NOT_OK( - worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF)))); - // EOF received, send quit signal (an empty buffer) to all workers - for (int32_t ind = 0; ind < num_workers_; ind++) { - RETURN_IF_NOT_OK( - worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit)))); - } - return Status::OK(); -} - -void BatchOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [batch size: " << start_batch_size_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nStart batch size: " << start_batch_size_ << "\nDrop remainder: " << (drop_ ? "yes" : "no") << "\n\n"; - } -} - -Status BatchOp::BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, - dsize_t batch_size) { - if ((*src)->size() != batch_size) { - RETURN_STATUS_UNEXPECTED("[Internal Batch ERROR] Source table size does not match the batch_size"); - } - - if (batch_size == 1) { - TensorRow row = std::move((*src)->front()); - (*src)->pop_front(); - (*dest)->push_back(row); - for (const auto &tensor : (*dest)->front()) { - RETURN_IF_NOT_OK(tensor->ExpandDim(0)); - } - return Status::OK(); - } - - TensorRow batched_row; - auto num_columns = (*src)->front().size(); - for (size_t i = 0; i < num_columns; i++) { - std::shared_ptr first_tensor = (*src)->at(0).at(i); // first row, column i - TensorShape first_shape = first_tensor->shape(); - DataType first_type = first_tensor->type(); - TensorShape new_shape = first_shape.PrependDim(static_cast(batch_size)); - - std::shared_ptr new_tensor; - if (first_type.IsNumeric()) { // numeric tensor - RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, TensorImpl::kFlexible, new_shape, first_type)); - dsize_t j = 0; - for (auto row : **src) { - std::shared_ptr old_tensor = row.at(i); // row j, column i - if (old_tensor->shape() == first_shape) { // check the newly popped rows have the same dim as the first - RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor)); - } else { - RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + std::to_string(i)); - } - } - } else { // handle string column differently - std::vector strings; - for (dsize_t j = 0; j < batch_size; j++) { - std::shared_ptr old_tensor = (*src)->at(j).at(i); - for (auto itr = old_tensor->begin(); itr != old_tensor->end(); itr++) { - strings.emplace_back(*itr); - } - } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, strings, new_shape)); - } - batched_row.emplace_back(new_tensor); - } - - (*dest)->emplace_back(batched_row); - - return Status::OK(); -} - -Status BatchOp::WorkerEntry(int32_t workerId) { - TaskManager::FindMe()->Post(); - std::pair, CBatchInfo> table_pair; - RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair)); - while (table_pair.second.ctrl_ != batchCtrl::kQuit) { - if (table_pair.second.ctrl_ == batchCtrl::kEOE) { - RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - } else if (table_pair.second.ctrl_ == batchCtrl::kEOF) { - RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else if (table_pair.second.ctrl_ == batchCtrl::kNoCtrl) { - std::unique_ptr db = nullptr; - RETURN_IF_NOT_OK(MakeBatchedBuffer(std::move(table_pair), &db)); - RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::move(db))); - } - RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair)); - } - return Status::OK(); -} - -Status BatchOp::MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, - std::unique_ptr *db) { - RETURN_UNEXPECTED_IF_NULL(table_pair.first); - if (!pyfunc_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc - if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_)); // do padding if needed - (*db) = std::make_unique(table_pair.second.batch_num_, DataBuffer::kDeBFlagNone); - std::unique_ptr dest_table = std::make_unique(); - RETURN_IF_NOT_OK(BatchRows(&table_pair.first, &dest_table, table_pair.first->size())); - (*db)->set_tensor_table(std::move(dest_table)); - return Status::OK(); -} - -Status BatchOp::LaunchThreadsAndInitOp() { - RETURN_UNEXPECTED_IF_NULL(tree_); - RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1))); - return Status::OK(); -} - -Status BatchOp::EofReceived(int32_t) { return Status::OK(); } - -Status BatchOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -Status BatchOp::MapColumns(std::pair, CBatchInfo> *table_pair) { - TensorBatchTable input_table; - input_table.reserve(pyfunc_column_names_.size()); - for (std::string col_name : pyfunc_column_names_) { - if (column_name_id_map_.find(col_name) == column_name_id_map_.end()) { - RETURN_STATUS_UNEXPECTED("column : '" + col_name + "' does not exist\n"); - } - TensorBatch tensor_batch; - tensor_batch.reserve(table_pair->first->size()); - size_t col_idx = static_cast(column_name_id_map_[col_name]); - for (size_t row_idx = 0; row_idx < table_pair->first->size(); row_idx++) { - tensor_batch.push_back(std::move(table_pair->first->at(row_idx)[col_idx])); - } - input_table.push_back(std::move(tensor_batch)); - } - - // Perform batch map - TensorBatchTable output_table; - RETURN_IF_NOT_OK(InvokeBatchMapFunc(&input_table, &output_table, table_pair->second)); - - // Write back to TensorQTable - for (size_t input_idx = 0; input_idx < pyfunc_column_names_.size(); input_idx++) { - size_t col_idx = static_cast(column_name_id_map_[pyfunc_column_names_[input_idx]]); - size_t row_id = 0; - for (TensorRow &row : *(table_pair->first)) { - row[col_idx] = std::move(output_table[input_idx][row_id++]); - } - } - return Status::OK(); -} - -Status BatchOp::GetBatchSize(int32_t *batch_size, CBatchInfo info) { - if (batch_size_func_ != nullptr) { - RETURN_IF_NOT_OK(InvokeBatchSizeFunc(batch_size, info)); - } else { - (*batch_size) = start_batch_size_; - } - return Status::OK(); -} - -Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) { - { - // Acquire Python GIL - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - py::object size = batch_size_func_(info); - *batch_size = size.cast(); - if (*batch_size <= 0) { - return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0"); - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } catch (const py::cast_error &e) { - return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0"); - } - } - return Status(StatusCode::kOK, "Batch size func call succeed"); -} - -Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *output, CBatchInfo info) { - { - // Acquire Python GIL - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - // Prepare batch map call back parameters - py::tuple input_args(input->size() + 1); - for (size_t i = 0; i < input->size(); i++) { - std::vector np_batch; - for (std::shared_ptr t : input->at(i)) { - py::array np_array; - RETURN_IF_NOT_OK(t->GetDataAsNumpy(&np_array)); - np_batch.push_back(std::move(np_array)); - } - input_args[i] = np_batch; - } - input_args[input->size()] = info; - // Invoke batch map func - py::object ret_py_obj = batch_map_func_(*input_args); - // Parse batch map return value - py::tuple ret_tuple = py::cast(ret_py_obj); - if (ret_tuple.size() != pyfunc_column_names_.size() || !py::isinstance(ret_tuple)) { - return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple"); - } - for (size_t i = 0; i < ret_tuple.size(); i++) { - TensorBatch output_batch; - py::list output_list = py::cast(ret_tuple[i]); - for (size_t j = 0; j < output_list.size(); j++) { - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, py::cast(output_list[j]))); - output_batch.push_back(std::move(out)); - } - output->push_back(std::move(output_batch)); - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } catch (const py::cast_error &e) { - return Status(StatusCode::kPyFuncException, "Batch map function should return an tuple of list of numpy array"); - } - } - return Status(StatusCode::kOK); -} - -Status BatchOp::PadColumns(std::unique_ptr *table, const PadInfo &pad_info, - const std::unordered_map &column_name_id_map) { - RETURN_UNEXPECTED_IF_NULL(table); // placeholder for now, might need this in the future - CHECK_FAIL_RETURN_UNEXPECTED((*table)->front().size() == column_name_id_map.size(), "col_name_map mismatch"); - std::vector> pad_vals(column_name_id_map.size(), - 0); // value to pad each column's tensor with, default 0 - std::set pad_cols; - // padded_shape provided by user, maximum shapes of current batch of tensors - std::vector> pad_shapes(column_name_id_map.size()), max_shapes(column_name_id_map.size()); - RETURN_IF_NOT_OK(UnpackPadInfo(pad_info, column_name_id_map, &pad_cols, &pad_vals, &pad_shapes)); - - // init each shape in max_shape to {-1,-1...} init each unspecified shape in pad_shape to -1 as well - for (size_t col_id : pad_cols) { - max_shapes[col_id] = std::vector((*table)->front()[col_id]->Rank(), -1); - if (pad_shapes[col_id].empty()) pad_shapes[col_id] = max_shapes[col_id]; // fill pad shape with -1 - CHECK_FAIL_RETURN_UNEXPECTED(pad_shapes[col_id].size() == max_shapes[col_id].size(), "wrong rank in pad_shape"); - } - - // calculate maximum shape for each column that needs to be padded - for (const TensorRow &row : **table) { // iterator each row in a batch - for (size_t col_id : pad_cols) { // iterator each tensor in a row - CHECK_FAIL_RETURN_UNEXPECTED(row[col_id]->Rank() == max_shapes[col_id].size(), - "Tensor to be padded together need to have the same rank"); - for (size_t dim = 0; dim < row[col_id]->Rank(); dim++) { // pick the largest number in each dimension - max_shapes[col_id][dim] = std::max(max_shapes[col_id][dim], row[col_id]->shape()[dim]); - } - } - } - - // if user sets a dimension to -1 (None in python), use the max value for current dimension - for (size_t col_id : pad_cols) { - for (size_t dim = 0; dim < pad_shapes[col_id].size(); dim++) { - if (pad_shapes[col_id][dim] < 0) pad_shapes[col_id][dim] = max_shapes[col_id][dim]; - } - } - - // call pad on each tensor that needs to be padded - for (TensorRow &row : **table) { - for (size_t col_id : pad_cols) { - std::shared_ptr pad_tensor; - RETURN_IF_NOT_OK(PadEnd(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id])); - row[col_id] = pad_tensor; - } - } - return Status::OK(); -} - -Status BatchOp::UnpackPadInfo(const PadInfo &pad_info, - const std::unordered_map &column_name_id_map, - std::set *pad_cols, std::vector> *pad_vals, - std::vector> *pad_shapes) { - if (pad_info.empty()) { // if pad_info empty, pad every columns automatically - for (dsize_t col_id = 0; col_id < column_name_id_map.size(); col_id++) { - pad_cols->insert(col_id); - } - } else { - for (const auto &p : pad_info) { - auto location = column_name_id_map.find(p.first); - CHECK_FAIL_RETURN_UNEXPECTED(location != column_name_id_map.end(), "no column exists with name:" + p.first); - auto col_id = static_cast(location->second); - CHECK_FAIL_RETURN_UNEXPECTED(col_id < pad_vals->size() && col_id < pad_shapes->size(), "col_id out of bound"); - pad_cols->insert(col_id); - (*pad_vals)[col_id] = p.second.second; // set pad values - (*pad_shapes)[col_id] = p.second.first.AsVector(); // empty vector if shape is unknown - } - } - return Status::OK(); -} - -// Visitor accept method for NodePass -Status BatchOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h deleted file mode 100644 index 28df5e7e81..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h +++ /dev/null @@ -1,271 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DataBuffer; - -using TensorBatch = TensorRow; -using TensorBatchTable = std::vector; -using PadInfo = std::map>>; - -class BatchOp : public ParallelOp { - public: - class Builder { - public: - // Builder constructor for Batch, batch size needs to be specified - // @param int32_t batch_size - explicit Builder(int32_t batch_size); - - // Default destructor - ~Builder() = default; - - // set number of parallel Workers on batch - // @param int32_t num_workers - // @return Builder & reference to builder class object - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // set drop for batch op,default false - // @param bool drop - // @return Builder & reference to builder class object - Builder &SetDrop(bool drop) { - builder_drop_ = drop; - return *this; - } - - Builder &SetPaddingMap(const PadInfo &pad_map, bool pad = true) { - builder_pad_ = pad; - builder_pad_map_ = pad_map; - return *this; - } - - // set connector size for batch - // @param int32_t op_conn_size - // @return Builder & reference to builder class object - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = (op_connector_size == 0 ? builder_op_connector_size_ : op_connector_size); - return *this; - } - - // set columns to perform map on - // @param const std::vector & cols_to_map - name of columns to perform map on - // @return Builder & reference to builder class object - Builder &SetColumnsToMap(const std::vector &cols_to_map) { - builder_cols_to_map_ = cols_to_map; - return *this; - } - - // set columns to perform map on - // @param const std::vector & cols_to_map - name of columns to perform map on - // @return Builder & reference to builder class object - Builder &SetBatchMapFunc(py::function batch_map_func) { - builder_batch_map_func_ = batch_map_func; - return *this; - } - - // SetBatchSizeFunc, a function that calls to python after every batch is made - // @param py::function batch_size_func - python function to call, GIL required before calling - // @return Builder & reference to builder class object - Builder &SetBatchSizeFunc(py::function batch_size_func) { - builder_batch_size_func_ = batch_size_func; - return *this; - } - - // @param std::shared_ptr *ptr pointer to shared_ptr, actual return arg - // @return Status - The error code return - Status Build(std::shared_ptr *); - - private: - // Sanity check for builder class args - // @return Status - The error code return - Status SanityCheck(); - - bool builder_drop_; - bool builder_pad_; - int32_t builder_batch_size_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - std::vector builder_cols_to_map_; - PadInfo builder_pad_map_; - py::function builder_batch_size_func_; - py::function builder_batch_map_func_; - }; - - enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 }; - - // Parameters associate with one batch. - // This struct is used for both internal control and python callback. - // This struct is bound to python with read-only access. - struct CBatchInfo { - CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl) - : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {} - CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {} - CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {} - explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {} - int64_t epoch_num_; // i-th epoch. i starts from 0 - int64_t batch_num_; // i-th batch since the start of current epoch. i starts from 0 - int64_t total_batch_num_; // i-th batch since the start of first epoch. i starts from 0 - batchCtrl ctrl_; // No control=0, EOE=1, EOF=2, Quit=3 - const int64_t get_batch_num() const { return batch_num_; } - const int64_t get_epoch_num() const { return epoch_num_; } - }; - - // BatchOp constructor - // @param int32_t batch_size - // @param bool drop - // @param int32_t op_queue_size - // @param int32_t rows_per_buf - // @param int32_t num_workers - BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, - const std::vector &, py::function batch_size_func, py::function batch_map_func, PadInfo pad_map); - - // BatchOp destructor - ~BatchOp() {} - - // @param int32_t workerId - // @return Status - The error code return - Status EofReceived(int32_t) override; - - // @param int32_t workerId - // @return Status - The error code return - Status EoeReceived(int32_t) override; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param sO - reference to the BatchOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const BatchOp &bo) { - bo.Print(out, false); - return out; - } - - // Main loop of batch - // @return Status - The error code return - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "BatchOp"; } - - // batch the rows in src table then put it to dest table - // @param const std::unique_ptr *src - table that has the rows for batching - // @param const std::unique_ptr *dest - dest_table to hold batched rows - // @param int32_t size - batch_size - // @param const std::unordered_map& column_name_id_map - column names to index mapping - // @return Status - The error code return - static Status BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, - dsize_t batch_size); - - // @param table - // @param const PadInfo &pad_info pad info - // @param const std::unordered_map& column_name_id_map - column names to index mapping - // @return Status - The error code return - static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, - const std::unordered_map &column_name_id_map); - - private: - // Worker thread for doing the memcpy of batch - // @param int32_t param workerId - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Generate buffer with batched tensors - // @return Status - The error code return - Status MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, - std::unique_ptr *db); - // Function that calls pyfunc to perform map on batch - // @param (std::pair, batch_stats> *table_pair - contains un-batched tensor - // @return Status - The error code return - Status MapColumns(std::pair, CBatchInfo> *table_pair); - - // @param const PadInfo &pad_info pad info to unpack - // @param const std::unordered_map& column_name_id_map - column names to index mapping - // @param std::set *cols, col ids to perform pad on - // @param std::vector *vals, default padding value for each column - // @param std::vector> *shapes, padding shape specified by user - // @return Status - The error code return - static Status UnpackPadInfo(const PadInfo &pad_info, - const std::unordered_map &column_name_id_map, - std::set *pad_cols, std::vector> *pad_vals, - std::vector> *pad_shapes); - - // the number of thread pulling from the mOutConnector of the Op below - // @return int32_t, 1 - int32_t num_consumers() const override { return 1; } - - // get the batch size for next batch - // @return Status - The error code return - Status GetBatchSize(int32_t *batch_size, CBatchInfo info); - - // Do the initialization of all queues then start all worker threads - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - - // Invoke batch size function with current BatchInfo to generate batch size. - // @return Status - The error code return - Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info); - - // Invoke batch map function with current BatchInfo to generate tensors to batch. - // @return Status - The error code return - Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info); - - int32_t start_batch_size_; - bool drop_; // bool for whether to drop remainder or not - bool pad_; // bool for whether to perform padding on tensor - std::vector pyfunc_column_names_; // Name of the columns to perform map op on - PadInfo pad_info_; // column names to perform padding on - std::unique_ptr child_iterator_; // child iterator for fetching TensorRows 1 by 1 - QueueList, CBatchInfo>> worker_queues_; // internal queue for syncing worker - py::function batch_size_func_; // Function pointer of batch size function - py::function batch_map_func_; // Function pointer of per batch map function -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc deleted file mode 100644 index 5e143b700f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc +++ /dev/null @@ -1,240 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/bucket_batch_by_length_op.h" - -#include -#include -#include -#include -#include - -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "dataset/core/pybind_support.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/util/status.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { -BucketBatchByLengthOp::Builder::Builder(std::vector length_dependent_columns, - std::vector bucket_boundaries, std::vector bucket_batch_sizes) - : builder_length_dependent_columns_(length_dependent_columns), - builder_bucket_boundaries_(bucket_boundaries), - builder_bucket_batch_sizes_(bucket_batch_sizes), - builder_pad_info_({}), - builder_pad_to_bucket_boundary_(false), - builder_drop_remainder_(false) { - std::shared_ptr config_manager = GlobalContext::config_manager(); - builder_op_connector_size_ = config_manager->op_connector_size(); -} - -Status BucketBatchByLengthOp::Builder::SanityCheck() { - std::string error_message; - - if (builder_length_dependent_columns_.empty()) { - error_message += "At least 1 column must be specified for element length calculation.\n"; - } - - if (builder_bucket_boundaries_.empty()) { - error_message += "At least 1 bucket boundary must be specified.\n"; - } - - if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) { - error_message += "There must be exactly one bucket batch size specified for each bucket boundary.\n"; - } - - CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message); - - return Status::OK(); -} - -Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr *new_bucket_batch_by_length_op) { - RETURN_IF_NOT_OK(SanityCheck()); - - // insert 0 for the first bucket - builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0); - - *new_bucket_batch_by_length_op = std::make_shared( - builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_, - builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_, - builder_op_connector_size_); - - return Status::OK(); -} - -BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector length_dependent_columns, - std::vector bucket_boundaries, - std::vector bucket_batch_sizes, - py::function element_length_function, PadInfo pad_info, - bool pad_to_bucket_boundary, bool drop_remainder, - int32_t op_connector_size) - : PipelineOp(op_connector_size), - length_dependent_columns_(length_dependent_columns), - bucket_boundaries_(bucket_boundaries), - bucket_batch_sizes_(bucket_batch_sizes), - element_length_function_(element_length_function), - pad_info_(pad_info), - pad_to_bucket_boundary_(pad_to_bucket_boundary), - drop_remainder_(drop_remainder), - batch_count_(0) { - for (int i = 0; i < bucket_batch_sizes_.size(); i++) { - buckets_.push_back(std::make_unique()); - } -} - -Status BucketBatchByLengthOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } - -Status BucketBatchByLengthOp::operator()() { - TaskManager::FindMe()->Post(); - - TensorRow current_row; - child_iterator_ = std::make_unique(this, 0, 0); - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); - while (!child_iterator_->eof_handled()) { - while (!current_row.empty()) { - int32_t element_length; - RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row)); - - int bucket_index = bucket_boundaries_.size() - 1; - while (element_length < bucket_boundaries_[bucket_index]) { - bucket_index--; - } - - buckets_[bucket_index]->push_back(current_row); - - if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) { - RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index])); - } - - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); - } - - // got EOE, do what we need to do with remainders in each bucket - if (!drop_remainder_) { - for (int i = 0; i < bucket_boundaries_.size(); i++) { - if (!buckets_[i]->empty()) { - RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size())); - } - } - } - - // need to send EOE manually since we set state to idle in EoeRecieved() - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); - } - - return Status::OK(); -} - -Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) { - // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of - // the single column specified in length_dependent_columns_ - if (element_length_function_) { - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - size_t number_of_arguments = length_dependent_columns_.size(); - py::tuple input_arguments(number_of_arguments); - for (size_t i = 0; i < number_of_arguments; i++) { - py::array argument_value; - int32_t column_index = column_name_id_map_[length_dependent_columns_[i]]; - RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value)); - input_arguments[i] = argument_value; - } - - py::object length = element_length_function_(*input_arguments); - *out_element_length = length.cast(); - if (*out_element_length < 0) { - return Status(StatusCode::kPyFuncException, "Element length function should return a non negative integer."); - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } catch (const py::cast_error &e) { - return Status(StatusCode::kPyFuncException, "Count not cast output of element length function to int32_t."); - } - } else { - *out_element_length = element[0]->shape()[0]; - } - - return Status::OK(); -} - -Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) { - std::unique_ptr *bucket = &buckets_[bucket_index]; - - PadInfo pad_info_copy = pad_info_; - if (pad_to_bucket_boundary_) { - for (auto &pair : pad_info_copy) { - std::vector pad_shape = pair.second.first.AsVector(); - - for (size_t i = 0; i < pad_shape.size(); i++) { - if (pad_shape[i] == TensorShape::kDimUnknown) { - if (bucket_index + 1 >= bucket_boundaries_.size()) { - std::string error_message = "Requested to pad to bucket boundary, element falls in last bucket"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message); - } - - pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1; - } - } - - pair.second.first = TensorShape(pad_shape); - } - } - - // PadColumns will change the data in bucket - RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_)); - - std::unique_ptr batched_bucket = std::make_unique(); - RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size)); - (*bucket)->clear(); - - std::unique_ptr batched_buffer = std::make_unique(batch_count_, DataBuffer::kDeBFlagNone); - batched_buffer->set_tensor_table(std::move(batched_bucket)); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer))); - - batch_count_++; - - return Status::OK(); -} - -Status BucketBatchByLengthOp::Reset() { - batch_count_ = 0; - - for (int i = 0; i < buckets_.size(); i++) { - buckets_[i] = std::make_unique(); - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h deleted file mode 100644 index bf0bcb0e78..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/batch_op.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DataBuffer; - -class BucketBatchByLengthOp : public PipelineOp { - public: - class Builder { - public: - Builder(std::vector length_dependent_columns, std::vector bucket_boundaries, - std::vector bucket_batch_sizes); - - ~Builder() = default; - - Builder &SetLengthDependentColumns(std::vector length_dependent_columns) { - builder_length_dependent_columns_ = length_dependent_columns; - return *this; - } - - Builder &SetBucketBoundaries(std::vector bucket_boundaries) { - builder_bucket_boundaries_ = bucket_boundaries; - return *this; - } - - Builder &SetBucketBatchSizes(std::vector bucket_batch_sizes) { - builder_bucket_batch_sizes_ = bucket_batch_sizes; - return *this; - } - - Builder &SetElementLengthFunction(py::function element_length_function) { - builder_element_length_function_ = element_length_function; - return *this; - } - - Builder &SetPadInfo(PadInfo pad_info) { - builder_pad_info_ = pad_info; - return *this; - } - - Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) { - builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary; - return *this; - } - - Builder &SetDropRemainder(bool drop_remainder) { - builder_drop_remainder_ = drop_remainder; - return *this; - } - - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - Status Build(std::shared_ptr *new_bucket_batch_by_length_op); - - private: - Status SanityCheck(); - - std::vector builder_length_dependent_columns_; - std::vector builder_bucket_boundaries_; - std::vector builder_bucket_batch_sizes_; - py::function builder_element_length_function_; - PadInfo builder_pad_info_; - bool builder_pad_to_bucket_boundary_; - bool builder_drop_remainder_; - int32_t builder_op_connector_size_; - }; - - BucketBatchByLengthOp(std::vector length_dependent_columns, std::vector bucket_boundaries, - std::vector bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, - bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); - - // Destructor - ~BucketBatchByLengthOp() = default; - - // Might need to batch remaining buckets after receiving eoe, so override this method. - // @param int32_t workerId - // @return Status - The error code returned - Status EoeReceived(int32_t) override; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param sO - reference to the BucketBatchByLengthOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) { - bo.Print(out, false); - return out; - } - - // Main loop of batch - // @return Status - The error code returned - Status operator()() override; - - // Function that is called by ResetOp at the end of every epoch - // @return Status - The error code returned - Status Reset() override; - - private: - Status ObtainElementLength(int32_t *out_element_length, TensorRow element); - - Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size); - - std::vector length_dependent_columns_; - std::vector bucket_boundaries_; - std::vector bucket_batch_sizes_; - py::function element_length_function_; - PadInfo pad_info_; - bool pad_to_bucket_boundary_; - bool drop_remainder_; - - int32_t batch_count_; - std::unique_ptr child_iterator_; - std::vector> buckets_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc deleted file mode 100644 index ceb5058593..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/engine/datasetops/build_vocab_op.h" - -#include -#include -#include -#include -#include -#include "dataset/core/config_manager.h" - -namespace mindspore { -namespace dataset { - -BuildVocabOp::BuildVocabOp(std::shared_ptr vocab, std::vector col_names, - std::pair freq_r, int64_t top_k, const std::vector &tokens, - bool prepend, int32_t num_workers, int32_t op_conn_size) - : ParallelOp(num_workers, op_conn_size), - interval_(op_conn_size * num_workers), - vocab_(vocab), - col_names_(col_names), - freq_range_(freq_r), - top_k_(top_k), - special_tokens_(tokens), - special_first_(prepend) { - // init two queues for thread sync - distributor_queue_ = std::make_unique>(num_workers * op_conn_size); - collector_queue_ = - std::make_unique>>>(num_workers * op_conn_size); -} - -Status BuildVocabOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - TensorRow new_row; - RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row)); - std::unique_ptr> wrkr_map = - std::make_unique>(); - int32_t row_cnt = 0; - while (!new_row.empty()) { - for (int32_t col : col_ids_) { - CHECK_FAIL_RETURN_UNEXPECTED(!new_row[col]->type().IsNumeric(), "from_dataset only works on string columns"); - for (auto itr = new_row[col]->begin(); itr != new_row[col]->end(); itr++) { - (*wrkr_map)[std::string(*itr)] += 1; - } - } - row_cnt++; // row is processed by this point - if ((row_cnt % interval_ == 0) && ((row_cnt / interval_) % num_workers_ == worker_id) && (!wrkr_map->empty())) { - RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map))); - wrkr_map = std::make_unique>(); - } - RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row)); - } - // clean up - if (!wrkr_map->empty()) { - RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map))); - } - // empty map as quit signal - RETURN_IF_NOT_OK(collector_queue_->Add(std::make_unique>())); - return Status::OK(); -} - -Status BuildVocabOp::operator()() { - // launch the collector thread - RETURN_UNEXPECTED_IF_NULL(tree_); - RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks())); - // launch worker threads and collector thread - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&BuildVocabOp::WorkerEntry, this, std::placeholders::_1))); - RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("collector", std::bind(&BuildVocabOp::CollectorThread, this))); - TaskManager::FindMe()->Post(); - child_iterator_ = std::make_unique(this, 0, 0); - TensorRow new_row; - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - if (!col_names_.empty()) { - col_ids_.reserve(col_names_.size()); - for (std::string col : col_names_) { - auto itr = column_name_id_map_.find(col); - CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col + " column doesn't exist"); - col_ids_.push_back(itr->second); - } - } else { - col_ids_.reserve(column_name_id_map_.size()); - for (const auto &p : column_name_id_map_) { - col_ids_.push_back(p.second); - } - } - bool eoe_warning = false; // give out warning if receive more than 1 eoe - while (child_iterator_->eof_handled() == false) { - while (new_row.empty() == false) { - RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(new_row)); - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - } - CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)"); - eoe_warning = true; - } - - // tell all workers to quit - for (int32_t wrkr_id = 0; wrkr_id < num_workers_; wrkr_id++) { - RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(TensorRow())); - } - return Status::OK(); -} - -Status BuildVocabOp::CollectorThread() { - TaskManager::FindMe()->Post(); - int32_t num_quited_worker = 0; - std::unique_ptr> wrkr_map; - while (num_quited_worker != num_workers_) { - RETURN_IF_NOT_OK(collector_queue_->PopFront(&wrkr_map)); - RETURN_UNEXPECTED_IF_NULL(wrkr_map); - if (!wrkr_map->empty()) { - for (const auto &wd : *wrkr_map) word_cnt_[wd.first] += wd.second; - } else { - ++num_quited_worker; - } - } // all frequencies are obtained - CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty"); - std::vector words; - // make sure enough is reserved, this will become a partially sorted list eventually - words.reserve(wrkr_map->size()); - - for (auto it = word_cnt_.begin(); it != word_cnt_.end();) { - if (it->second >= freq_range_.first && it->second <= freq_range_.second) { - words.push_back(it->first); - it++; - } else { - it = word_cnt_.erase(it); - } - } - std::string err_msg; - - for (const std::string &sp_tk : special_tokens_) { - // if a special word exists in dataset, warn user about this - err_msg += (word_cnt_.find(sp_tk) != word_cnt_.end() ? sp_tk + "\t" : ""); - } - - CHECK_FAIL_RETURN_UNEXPECTED(err_msg.empty(), "These specials words are already in the dataset: " + err_msg + "."); - - int64_t num_words = std::min(static_cast(words.size()), top_k_); - if (num_words == 0) { - MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second - << ") vocab would be empty (except for special tokens)."; - } - - // this would take the top-k most frequent words - std::partial_sort(words.begin(), words.begin() + num_words, words.end(), - [this](const std::string &w1, const std::string &w2) { - int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2]; - return f1 == f2 ? w1 < w2 : f1 > f2; - }); - - if (special_first_) { - for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk); - } - - for (int64_t i = 0; i < num_words; i++) { - vocab_->append_word(words[i]); - } - - if (!special_first_) { - for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk); - } - - RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - // then use std::nth_element to partial sort - return Status::OK(); -} - -Status BuildVocabOp::Builder::Build(std::shared_ptr *op) { - CHECK_FAIL_RETURN_UNEXPECTED(builder_num_workers_ > 0, "builder num_workers need to be greater than 0"); - CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be positive number"); - CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0, - "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)"); - (*op) = std::make_shared( - builder_vocab_, builder_col_names_, std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_, - builder_speical_tokens_, builder_special_first_, builder_num_workers_, builder_connector_size_); - return Status::OK(); -} - -BuildVocabOp::Builder::Builder() - : builder_top_k_(std::numeric_limits::max()), - builder_min_freq_(0), - builder_max_freq_(std::numeric_limits::max()), - builder_special_first_(true) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_connector_size_ = cfg->op_connector_size(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h deleted file mode 100644 index bf358c48c6..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h +++ /dev/null @@ -1,174 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/text/vocab.h" -#include "dataset/util/queue.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class BuildVocabOp : public ParallelOp { - public: - class Builder { - public: - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method - // @param int32_t size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t size) { - builder_connector_size_ = size; - return *this; - } - - // Setter method - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param int64_t top_k - // @return Builder setter method returns reference to the builder. - Builder &SetTopK(int64_t top_k) { - builder_top_k_ = top_k; - return *this; - } - - // Setter method - // @param int64_t min_freq - // @return Builder setter method returns reference to the builder. - Builder &SetMinFreq(int64_t min_freq) { - builder_min_freq_ = min_freq; - return *this; - } - - // Setter method - // @param int64_t max_freq - // @return Builder setter method returns reference to the builder. - Builder &SetMaxFreq(int64_t max_freq) { - builder_max_freq_ = max_freq; - return *this; - } - - // set columns names - // @param const std::vector & col_names - name of columns to get words - // @return Builder & reference to builder class object - Builder &SetColumnNames(const std::vector &col_names) { - builder_col_names_ = col_names; - return *this; - } - - // set special tokens - // @param const std::vector & col_names - name of columns to get words - // @return Builder & reference to builder class object - Builder &SetSpecialTokens(const std::vector &tokens) { - builder_speical_tokens_ = tokens; - return *this; - } - - // set vocab object - Builder &SetVocab(std::shared_ptr vocab) { - builder_vocab_ = vocab; - return *this; - } - - // set special tokens first (or last) - Builder &SetSpecialFirst(bool prepend) { - builder_special_first_ = prepend; - return *this; - } - - // The builder "build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - int32_t builder_num_workers_; - int32_t builder_connector_size_; - int64_t builder_min_freq_; - int64_t builder_max_freq_; - bool builder_special_first_; - std::vector builder_col_names_; - std::vector builder_speical_tokens_; - std::shared_ptr builder_vocab_; - int64_t builder_top_k_; - }; - - BuildVocabOp(std::shared_ptr vocab, std::vector col_names, std::pair freq_range, - int64_t top_k, const std::vector &tokens, bool prepend, int32_t num_workers, - int32_t op_connector_size); - - ~BuildVocabOp() = default; - - Status WorkerEntry(int32_t worker_id) override; - - // collect the work product from each worker - Status CollectorThread(); - - Status EofReceived(int32_t) override { return Status::OK(); } - - Status EoeReceived(int32_t) override { return Status::OK(); } - - Status operator()() override; - - // Getter - // @return the number of workers - int32_t num_producers() const override { return 1; } - - // Getter - // @return the number of threads consuming from the previous Connector - int32_t num_consumers() const override { return 1; } - - Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); } - - private: - const int32_t interval_; - bool special_first_; - std::shared_ptr vocab_; - std::vector col_names_; - std::vector col_ids_; - std::vector special_tokens_; - // pair = {min_f, max_f} - // make sure that 0<= min_f < max_f <= int32_max in the builder - std::pair freq_range_; - - int64_t top_k_; // every thing means top_k_ == int32_max - std::unique_ptr child_iterator_; // child iterator for fetching TensorRows 1 by 1 - std::unique_ptr> distributor_queue_; // master thread assigns each worker TensorRow via this - std::unique_ptr>>> collector_queue_; - std::unordered_map word_cnt_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc deleted file mode 100644 index 4bada31e7e..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/concat_op.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -ConcatOp::Builder::Builder() { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -// The builder "build" method creates the final object. -Status ConcatOp::Builder::Build(std::shared_ptr *ptr) { - *ptr = std::make_shared(builder_op_connector_size_); - return Status::OK(); -} - -// Constructor of the ConcatOp. -ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {} - -// A function that prints info about the Operator -void ConcatOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this is summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nDatasets: " << children_num_ << "\n\n"; - } -} - -// Main entry point for Concat -Status ConcatOp::operator()() { - // The children_num_ parameter needs to be put here - children_num_ = static_cast(child_.size()); - - TaskManager::FindMe()->Post(); - std::unique_ptr buf; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - - int eof_count = 0; - while (eof_count != children_num_) { - for (int i = 0; i < children_num_; i++) { - // 1. Throw the eof buffer when meet it - if (buf->eof() || buf->eoe()) { - RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); - } - // 2. Do verification as for column name, column data type and rank of column data - RETURN_IF_NOT_OK(Verify(i, buf)); - - // 3. Put the data into output_connector - while (!buf->eoe() && !buf->eof()) { - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); - RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); - } - - // 4. Throw the eoe buffer when meet it - if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) { - RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); - } - // 5. Add eoe buffer after get buffer from all child - if (i == (children_num_ - 1)) { - auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - } - if (buf->eof()) { - eof_count++; - } - } - } - // 6. Add eof buffer in the end manually - MS_LOG(DEBUG) << "Add the eof buffer manualy in the end."; - auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - - return Status::OK(); -} - -Status ConcatOp::Verify(int32_t id, const std::unique_ptr &buf) { - TensorRow new_row; - buf->GetRow(0, &new_row); - - if (id == 0) { - // Obtain the data type and data rank in child[0] - for (auto item : new_row) { - data_type_.push_back(item->type()); - data_rank_.push_back(item->Rank()); - } - } else { - // Compare the data type and data rank with these in child[0] - int32_t index = 0; - for (auto item : new_row) { - if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) { - RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same with previous dataset."); - } - } - } - return Status::OK(); -} - -Status ConcatOp::PrepareNodePostAction() { - RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); - tree_->AddToEOEOpStack(shared_from_this()); - return Status::OK(); -} - -// We need to overwrite the super class ComputeColMap here because the number of children is more than 1. -Status ConcatOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - // Obtain columns_name_id_map from child_[0] - column_name_id_map_ = child_[0]->column_name_id_map(); - if (column_name_id_map_.empty()) { - RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!"); - } - // Verify all children have the same column name map - for (int32_t i = 0; i < child_.size(); ++i) { - if (child_[i]->column_name_id_map() != column_name_id_map_) { - RETURN_STATUS_UNEXPECTED("The column name or column order is not the same with previous dataset."); - } - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h deleted file mode 100644 index 4bcfdbf6c6..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ - -#include -#include -#include -#include -#include "dataset/engine/datasetops/pipeline_op.h" - -namespace mindspore { -namespace dataset { -class ConcatOp : public PipelineOp { - public: - // The nested builder class inside of the ConcatOp is used to help manage all of the arguments - // for constructing it. This Concat op is very simple though, so this builder is really just - // provided for a consistent look and feel for creators of Dataset operators overall. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // The builder "build" method creates the final object. - // @return shared_ptr to the new ConcatOp object - Status Build(std::shared_ptr *); - - private: - int32_t builder_op_connector_size_; - }; - - // Constructor of the ConcatOp. - // @note The builder class should be used to call it - // @param op_connector_size - connector size - explicit ConcatOp(int32_t op_connector_size); - - // Destructor - ~ConcatOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param ro - reference to the ConcatOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) { - ro.Print(out, false); - return out; - } - - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePostAction() override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ConcatOp"; } - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - private: - Status Verify(int32_t id, const std::unique_ptr &buf); - - int32_t children_num_; // The num of child of parent node. - std::unordered_map column_name_id_; // Mapping between col index and col name - std::vector data_type_; - std::vector data_rank_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc deleted file mode 100644 index 3e31f6c017..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc +++ /dev/null @@ -1,398 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/dataset_op.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/device_queue_op.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "utils/system/crc32c.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -// Constructor -DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler) - : oc_queue_size_(op_connector_size), - sampler_(sampler), - operator_id_(kInvalidOperatorId), - tree_(nullptr), - state_(OpState::kDeOpIdle), - op_ctrl_flags_(kDeOpNone), - out_connector_(nullptr) { - // The operator starts out with an invalid operator id. The only way to - // get it out of invalid state is to assign the operator to an execution tree. -} - -// Adds a operator to become our child. -Status DatasetOp::AddChild(std::shared_ptr child) { - if (std::dynamic_pointer_cast(child) != nullptr) { - std::string err_msg("DeviceQueueOp cannot be added as a child, DeviceQueueOp must be a root node"); - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (operator_id_ == kInvalidOperatorId) { - std::string err_msg( - "Cannot add child node. Tree node connections can only" - "be made if the node belongs to a tree."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // disallow relationships with other trees - if (tree_ != child->tree_) { - std::string err_msg( - "Cannot add child node. Tree node connections can only be made if both nodes belong to the same tree."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - child_.push_back(child); - child->AddParent(this); - return Status::OK(); -} - -Status DatasetOp::RemoveChild(std::shared_ptr child) { - if (operator_id_ == kInvalidOperatorId) { - std::string err_msg( - "Cannot remove child node. Tree node connections can only" - "be made if the node belongs to a tree."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // disallow relationships with other trees - if (tree_ != child->tree_) { - std::string err_msg( - "Cannot remove child node. Tree node connections can only be made if both nodes belong to the same tree."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end()); - child->RemoveParent(this); - return Status::OK(); -} - -Status DatasetOp::InsertAsParent(std::shared_ptr to_add) { - for (auto &prev_parent : this->parent_) { - RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this())); - RETURN_IF_NOT_OK(prev_parent->AddChild(to_add)); - } - RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this())); - if (tree_->root()->id() == this->id()) { - tree_->AssignRoot(to_add); - } - return Status::OK(); -} - -// Adds a parent operator to this operator -void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); } - -// Removes a parent operator from this operator -void DatasetOp::RemoveParent(const DatasetOp *parent) { - parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end()); -} - -// Removes this node from the tree and connects it's parent/child together -Status DatasetOp::Remove() { - if (parent_.size() > 1) { - std::string err_msg("No support for op removal if the operator has more than one parent"); - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (child_.size() > 1) { - std::string err_msg("No support for op removal if the operator has more than one child"); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Scenario's when removing node B: - // A -> B -> C - // A -> B - // B -> C - // - // If we remove B, then first take our child A and update it's parent to be C - // It's possible the parent is null if we are the root node being removed. - if (!child_.empty()) { - // If we have a parent, then assign chlid's parent to point to our parent. - if (!parent_.empty()) { - child_[0]->parent_[0] = parent_[0]; - } else { - // We don't have a parent, so we are the root node being removed. - // clear the parent list of our child so that it becomes the new root. - child_[0]->parent_.clear(); - tree_->AssignRoot(child_[0]); - } - } - - // Next, if we had a parent, then set it's child to be our child. - if (!parent_.empty()) { - // if we have a child, then set our parent to point to it - if (!child_.empty()) { - parent_[0]->child_[0] = child_[0]; - } else { - // We don't have a child, so clear the child list of the current - // parent because it will be empty once we are removed. - parent_[0]->child_.clear(); - } - } - - return Status::OK(); -} - -// Getter function to get a shared pointer to our childAdds a operator to become our child. -std::shared_ptr DatasetOp::child(int32_t child_index) const { - MS_ASSERT(child_index < static_cast(child_.size())); - // Return a shared pointer - return child_[child_index]; -} - -// Creates the connector within this operator -void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) { - MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers - << ". Consumer: " << num_consumers << "."; - if (oc_queue_size_ > 0) { - out_connector_ = std::make_unique(num_producers, // The number of producers - num_consumers, // Only one consumer (the training App) - oc_queue_size_); - } else { - // Some op's may choose not to have an output connector - MS_LOG(DEBUG) << "Bypassed connector creation for tree operator: " << operator_id_ << "."; - out_connector_ = nullptr; - } -} - -// A print method typically used for debugging. showAll of true will recursively descend to child prints -void DatasetOp::Print(std::ostream &out, bool show_all) const { - // When show_all is false, we display a 1 liner piece of text for the op. - // When show_all is true, we display more detailed output for the op. - // Derived printers should show their own header info, then call base class printer, followed by - // derived-specific items. - // For now, the base class doesn't have any summary info to show so it's a no-op in that case. - if (show_all) { - // The detailed display will show common base class info of the op. Allow the derived class to print - // it's own id and name though as the first line. - out << "\nNumber of children : " << child_.size(); - for (size_t i = 0; i < child_.size(); i++) { - out << "\n Child[" << i << "] id: " << child_[i]->id(); - } - out << "\nNumber of parents : " << parent_.size(); - for (size_t i = 0; i < parent_.size(); i++) { - out << "\n Parent[" << i << "] id: " << parent_[i]->id(); - } - out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex - << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' '); - if (sampler_) { - sampler_->Print(out, show_all); - } - } -} - -// Gets the next buffer from the given child -Status DatasetOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { -#if defined(_WIN32) || defined(_WIN64) - RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast(worker_id), p_buffer, retry_if_eoe)); -#else - std::unique_ptr next_buff; - // pop is a blocked call and will throw an interruption if the whole group shuts down. - RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast(worker_id), &next_buff, retry_if_eoe)); - - *p_buffer = std::move(next_buff); -#endif - return Status::OK(); -} - -// Gets the next buffer from the given child . This function also has built-in eoe and eof -// message handling so that child classes don't have to manually code pass-through logic when -// those messages are received. -Status DatasetOp::GetNextInput(std::unique_ptr *p_buffer, int32_t worker_id, int32_t child_index) { - if (child_.size() == 0) { - return this->GetNextBuffer(p_buffer, worker_id); - } - CHECK_FAIL_RETURN_UNEXPECTED(child_index < child_.size(), "Child index too big : " + std::to_string(child_index)); - std::shared_ptr child = child_[child_index]; - std::unique_ptr buf; - RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id)); - // Loop until non EOE is received - while (buf->eoe()) { - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - if (state_ == OpState::kDeOpIdle) { - *p_buffer = std::move(buf); - return Status::OK(); - } - RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id)); - } - // Check if the last buf is next eof - if (buf->eof()) { - RETURN_IF_NOT_OK(EofReceived(worker_id)); - } - *p_buffer = std::move(buf); - return Status::OK(); -} - -// Performs handling for when an eoe message is received. -// The base class implementation simply flows the eoe message to output. Derived classes -// may override if they need to perform special eoe handling. -Status DatasetOp::EoeReceived(int32_t worker_id) { - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - return (out_connector_->Add(static_cast(worker_id), std::move(eoe_buffer))); -} - -// Performs handling for when an eof message is received. -// The base class implementation simply flows the eof message to output. Derived classes -// may override if they need to perform special eof handling. -Status DatasetOp::EofReceived(int32_t worker_id) { - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - return (out_connector_->Add(static_cast(worker_id), std::move(eof_buffer))); -} - -// During tree prepare phase, operators may have specific pre-operations to perform depending on -// their role. -Status DatasetOp::PrepareNodePreAction() { - if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated); - return Status::OK(); -} -// During tree prepare phase, operators may have specific post-operations to perform depending on -// their role. -Status DatasetOp::PrepareNodePostAction() { - // If this op does not have any children and it is in a repeat path of the tree... - if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) { - // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator - // above us will consume them. - tree_->AddToEOEOpStack(shared_from_this()); - } - // Creating Connector object for each op. - // The consumer of the root node is assumed to be one thread. - // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion. - if (parent_.empty()) { - this->CreateConnector(num_producers(), 1); - } else { - this->CreateConnector(num_producers(), parent_[0]->num_consumers()); - } - if (out_connector_) { - RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks())); - } - RETURN_IF_NOT_OK(this->RegisterWorkerConnectors()); - - // Generate the column name map for the current op. - RETURN_IF_NOT_OK(this->ComputeColMap()); - - return Status::OK(); -} - -// Getter function. Base class does not have any special flags setting. -uint32_t DatasetOp::PrepareFlags() const { return ExecutionTree::kDePrepNone; } - -// Derived classes may implement the reset function if the operator is stateful and needs -// specific reset handling that is not contained in this common code version of the reset. -Status DatasetOp::Reset() { - state_ = OpState::kDeOpRunning; - return Status::OK(); -} - -// gives a string output for the column map for handy debug printing -std::string DatasetOp::ColumnNameMapAsString() const { - std::string outStr = "Column name id map: "; - for (auto &it : column_name_id_map_) { - outStr += (" " + it.first + ":" + std::to_string(it.second)); - } - return outStr; -} - -// Computing the assignment of the column name map. -// This just inherits the column map from its first child, can only be used if the number of children is 1. -// Operations changing the column map must overwrite this function. -Status DatasetOp::ComputeColMap() { - if (child_.size() > 1) { - RETURN_STATUS_UNEXPECTED("Assigning column name map from child only works for single-child operators."); - } - if (column_name_id_map_.empty()) { - column_name_id_map_ = child_[0]->column_name_id_map(); - if (column_name_id_map_.empty()) { - RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!"); - } - MS_LOG(DEBUG) << "Setting column map:\n" << DatasetOp::ColumnNameMapAsString(); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -Status DatasetOp::PreAccept(NodePass *p, bool *modified) { - // DatasetOp is the base class of visitor target pre-visit. - // This method will only be called if its derived class does not implement one. - return p->PreRunOnNode(shared_from_this(), modified); -} - -Status DatasetOp::Accept(NodePass *p, bool *modified) { - // DatasetOp is the base class of visitor target. - // This method will only be called if its derived class does not implement one. - return p->RunOnNode(shared_from_this(), modified); -} - -// A helper function with some common code that leaf nodes can use during -// prepare phase for checking if they need to assign a sampler to the cache. -Status DatasetOp::SaveSamplerForCache(bool random_access_op) { - // If we are a descendant under a cache op and we have a sampler, then save this sampler - // to a stack so that the cache can pick it up during it's processing above us. - if (sampler_) { - if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { - // use move semantic to set our sampler_ to null after the move. This is okay because a sampler is - // useless to a random data op. It was only being used as a temporary holding until the cache can - // be created - tree_->AddToSamplerStack(sampler_); - MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling."; - } else if (!random_access_op) { - // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf. - // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into. - RETURN_STATUS_UNEXPECTED( - "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree"); - } - } - - if (!random_access_op) { - // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache - // we can remove it now from the base. - sampler_.reset(); - } - - return Status::OK(); -} -uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) { - std::stringstream ss; - op->tree_->Print(ss, op); - std::string ss_str = ss.str(); - - // Filter out the Operator control flags field when generating the check sum - ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), ""); - - // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline - ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), ""); - ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), ""); - - // The Cache crc and Server cache id field is different when creating new cache_client and re-using the same - // cache_client later. So we filter out these two fields to allow cache sharing. - ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), ""); - ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), ""); - - uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length()); - return cache_crc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h deleted file mode 100644 index ab5cb90357..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h +++ /dev/null @@ -1,360 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ -#define DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ - -#include -#include -#include -#include -#include -#include "dataset/core/constants.h" -#include "dataset/engine/db_connector.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Forward declare -class ExecutionTree; - -class DataBuffer; - -class NodePass; - -class Sampler; - -/// \brief The base class DatasetOp is the main tree node. It is an abstract class, so -/// the actual implementation of the operators will be derived from here. -class DatasetOp : public std::enable_shared_from_this { - // Allow execution tree to access internal members - friend class ExecutionTree; - - public: - static constexpr int32_t kInvalidOperatorId = -1; - - // Flags that control operator runtime behaviours - enum OpControlFlags { - kDeOpNone = 0, - kDeOpRepeated = 1, // Operator is a leaf node in a repeat path - kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop - }; - - // Flags that control operator runtime behaviours - enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; - - /// Constructor - /// \param op_connector_size - The size for the output connector of this operator. - /// \param sampler - The sampler for the op - explicit DatasetOp(int32_t op_connector_size, std::shared_ptr sampler); - - /// Destructor - virtual ~DatasetOp() { tree_ = nullptr; } - - /// Adds a operator to become our child. - /// \param child - shared pointer to the child to add. - Status AddChild(std::shared_ptr child); - - /// Remove a operator from our children. - /// \param child - shared pointer to the child to remove. - Status RemoveChild(std::shared_ptr child); - - /// \brief Removes this node from the tree and connects it's parent/child together. - /// \return Status eerror code returned - Status Remove(); - - /// \brief Getter function to get a shared pointer to our child - /// \param child_index - An operator can have n children. Indicates choose which child to return. - std::shared_ptr child(int32_t child_index) const; - - /// \brief Inserts a operator as the parent current op. - /// Inserted op will become the sole parent of the current op. - /// The existing parent of the current op will be transferred to the inserted op. - Status InsertAsParent(std::shared_ptr to_add); - - /// \brief Creates the connector within this operator - /// \param num_producers - number of threads that write into this connector - /// \param num_consumers - number of threads that read from this connector - void CreateConnector(int32_t num_producers, int32_t num_consumers); - - /// \brief A print method typically used for debugging - /// \param out - The output stream to write output to - /// \param show_all - A bool to control if you want to show all info or just a summary - virtual void Print(std::ostream &out, bool show_all) const; - - /// \brief << Stream output operator overload - /// \notes This allows you to write the debug print info using stream operators - /// \param out - reference to the output stream being overloaded - /// \param dO - reference to the DatasetOp to display - /// \return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) { - dO.Print(out, false); - return out; - } - - /// \brief Class functor operator (). - /// DatasetOps operate by launching a thread (see ExecutionTree). - /// This pure virtual version makes the requirement that derived classes must provide a functor - /// that will execute their main runtime loop code. - /// \return Status - The error code return - virtual Status operator()() = 0; - - /// \brief Gets the next buffer from the given child - /// \notes See GetNextInput for similar function that has built-in message handling - /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) - /// \param worker_id - The worker id - /// \return Status - The error code return - virtual Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id) { - return GetNextBuffer(p_buffer, worker_id, false); - } - - /// \brief Gets the next buffer from the given child - /// \notes See GetNextInput for similar function that has built-in message handling - /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) - /// \return Status - The error code return - virtual Status GetNextBuffer(std::unique_ptr *p_buffer) { return GetNextBuffer(p_buffer, 0, false); } - - /// \brief Gets the next buffer from the given child - /// \notes See GetNextInput for similar function that has built-in message handling - /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) - /// \param worker_id - The worker id - /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. - /// \return Status - The error code return - virtual Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe); - - /// \brief Gets the next buffer from the given child . This function also has built-in eoe and eof - /// message handling so that child classes don't have to manually code pass-through logic when - /// those messages are received. - /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) - /// \param worker_id - The worker id - /// \return Status - The error code return - Status GetNextInput(std::unique_ptr *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); - - /// \brief Performs handling for when an eoe message is received. - /// The base class implementation simply flows the eoe message to output. Derived classes - /// may override if they need to perform special eoe handling. - /// \param worker_id - The worker id - /// \return Status - The error code return - virtual Status EoeReceived(int32_t worker_id); - - /// \brief Performs handling for when an eof message is received. - /// The base class implementation simply flows the eof message to output. Derived classes - /// may override if they need to perform special eof handling. - /// \param worker_id - The worker id - /// \return Status - The error code return - virtual Status EofReceived(int32_t worker_id); - - /// \brief Derived classes may implement the reset function if the operator is stateful and needs - /// specific reset handling that is not contained in this common code version of the reset - /// \return Status - The error code return - virtual Status Reset(); - - /// \brief This calls the reset function on this subtree in pre-order - /// \return Status - The error code return - virtual Status ResetSubtree() { - RETURN_IF_NOT_OK(Reset()); - for (const auto &c : child_) { - RETURN_IF_NOT_OK(c->ResetSubtree()); - } - return Status::OK(); - } - - /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on - /// their role. - /// \notes Derived versions of this function should always call it's superclass version first - /// before providing their own implementations. - virtual Status PrepareNodePreAction(); - - /// \brief During tree prepare phase, operators may have specific post-operations to perform depending on - /// their role. - /// \notes Derived versions of this function should always call it's superclass version first - /// before providing their own implementations. - virtual Status PrepareNodePostAction(); - - /// \brief Getter function - /// \return The operator id - int32_t id() const { return operator_id_; } - - /// \brief Getter function - /// \return The prepare flags - virtual uint32_t PrepareFlags() const; - - /// \brief Getter function - /// \return The number of workers in this op - virtual int32_t num_workers() const = 0; - - /// \brief Getter function - /// \return The number of threads consuming from previous op. - virtual int32_t num_consumers() const = 0; - - /// \brief Getter function - /// \return The number of threads producing to the output connector. - virtual int32_t num_producers() const = 0; - - /// \brief Getter function - /// \return T/F if this is an inlined operator - bool inlined() const { return (oc_queue_size_ == 0); } - - /// \brief Setter function - /// \return Sets the control flags - void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); } - - /// \brief Setter function - /// \return Sets the control flags - void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); } - - /// \brief Register the internal worker connectors. No op unless it is a parallel op - /// \return Status - virtual Status RegisterWorkerConnectors() { return Status::OK(); } - - /// \brief Getter for the column name mapping - /// \return The returned map - std::unordered_map column_name_id_map() const { return column_name_id_map_; } - - /// \brief Checks if the column name map has been set up yet for this op - /// \return - T/F if the operator has the map set up - bool HasColumnNameMap() const { return (column_name_id_map_.empty()); } - - /// \brief gives a string output for the column map for handy debug printing - /// \return - the column name map as a string - std::string ColumnNameMapAsString() const; - - /// \brief Getter function - /// \return connector size of current op - int32_t ConnectorSize() const { - if (!inlined()) { - return out_connector_->size(); - } - // Return child connector size for inlined op - return ChildOpConnectorSize(); - } - - /// \brief Counting number of buffer sent out by a connector - int64_t ConnectorOutBufferCount() const { - return out_connector_ == nullptr ? int64_t(-1) : static_cast(out_connector_->out_buffers_count()); - } - - /// \brief Getter function - /// \return connector size of current op - int32_t ConnectorCapacity() const { - if (!inlined()) { - return out_connector_->capacity(); - } - // Return child connector capacity for inlined op - return ChildOpConnectorCapacity(); - } - - /// \brief Getter function - /// \return connector size of child op - int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); } - - /// \brief Getter function - /// \return connector capacity of child op - int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); } - - /// \brief Children Getter - /// \return Vector of Children - std::vector> Children() const { return child_; } - - /// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up - /// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main - /// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it - /// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details. - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - virtual Status PreAccept(NodePass *p, bool *modified); - - /// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access. - /// Check "dataset/engine/opt/pass.h" for more details. - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - virtual Status Accept(NodePass *p, bool *modified); - - /// Op name getter - /// \return Name of the current Op - virtual std::string Name() const { return "DatasetOp"; } - - /// Execution Tree getter - /// \return Pointer to the ExecutionTree the current op belongs to, no ownership - ExecutionTree *Tree() { return tree_; } - - /// Getter for the sampler - /// \return Shared pointer to the sampler (may return nullptr) - std::shared_ptr sampler() { return sampler_; } - - /// Computes a CRC value for the operator - static uint32_t GenerateCRC(const std::shared_ptr &op); - - /// \brief A helper templated function for casting "this" pointer to shared_ptr - /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr - /// \return A shared_ptr casted to the derived class - template - std::shared_ptr shared_from_base() { - return std::static_pointer_cast(shared_from_this()); - } - - protected: - /// Adds a parent operator to this operator - /// \notes External callers do not have access to this function. - /// \param parent - The parent node to add - void AddParent(DatasetOp *parent); - - /// Removes a parent operator from this operator - /// \notes External callers do not have access to this function. - /// \param parent - The parent node to remove - void RemoveParent(const DatasetOp *parent); - - /// Compute the current op's column map using its child's column map. - /// Get called during the tree post-prepare phase in PrepareNodePostAction. - /// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1. - /// Operations changing the column map it inherits from the child must overwrite this function. - /// \return - Status - virtual Status ComputeColMap(); - - /// A helper function with some common code that leaf nodes can use during - /// pre/pare phase for checking if they need to assign a sampler to the cache. - /// \param random_access_op - indicate if this is a mappable random access leaf or not - /// \return - Status - Status SaveSamplerForCache(bool random_access_op); - - std::vector> child_; // Child nodes - std::vector parent_; // Parent nodes. No ownership - std::shared_ptr sampler_; // Some leaf ops might have a sampler - int32_t oc_queue_size_; // Capacity for each out_connector_ - int32_t operator_id_; // Generated id for the node - ExecutionTree *tree_; // Back pointer to our tree. - OpState state_; // The state of the operator, Running, Idle, Terminated - uint32_t op_ctrl_flags_; // Flags for the operator - std::unique_ptr out_connector_; // Output Connector - std::unordered_map column_name_id_map_; // Mapping between col index and col name - std::mutex column_name_map_mutex_; // For protecting shared access to the column map - - private: - /// Sets the operator id. - /// \notes No public interface. Only the class itself, or it's friend the execution tree can set - /// this - /// \param op_id - the Id value to set into the operator - void set_id(int32_t op_id) { operator_id_ = op_id; } - - /// Sets the tree into the op so that the operator has a back pointer to the tree. - /// \param tree - the tree to assign to the op. - void set_tree(ExecutionTree *tree) { tree_ = tree; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc deleted file mode 100644 index 0f1fefc0f0..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc +++ /dev/null @@ -1,320 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/device_queue_op.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/engine/perf/profiling.h" -#include "dataset/engine/perf/device_queue_tracing.h" -#include "dataset/util/status.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, - int32_t op_connector_size, int64_t num_batch) - : PipelineOp(op_connector_size), - channel_name_(channel_name), - device_type_(device_type), - device_id_(device_id), - prefetch_size_(prefetch_size), - num_batch_(num_batch) {} - -DeviceQueueOp::~DeviceQueueOp() {} - -#ifdef ENABLE_GPUQUE -void ReleaseData(void *addr) { - if (addr != nullptr) { - free(addr); - } -} -#endif - -DeviceQueueOp::Builder::Builder(int32_t prefetch_size) - : builder_prefetch_size_(prefetch_size), - builder_device_id_(0), - builder_device_type_(DeviceType::CPU), - builder_channel_name_(""), - builder_num_batch_(0) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status DeviceQueueOp::EoeReceived(int32_t worker_id) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -Status DeviceQueueOp::operator()() { - TaskManager::FindMe()->Post(); - - if (device_type_ == DeviceType::Ascend) { -#ifdef ENABLE_TDTQUE - RETURN_IF_NOT_OK(SendDataToAscend()); -#endif - } else if (device_type_ == DeviceType::GPU) { -#ifdef ENABLE_GPUQUE - RETURN_IF_NOT_OK(SendDataToGPU()); -#endif - } else if (device_type_ == DeviceType::CPU) { - RETURN_IF_NOT_OK(SendDataToCPU()); - } - - return Status::OK(); -} - -Status DeviceQueueOp::CheckExceptions(const std::unique_ptr &buffer) const { - // this method checks if the buffer meets the conditions to be sent to TDT - if (buffer->NumRows() != 0) { - TensorRow row; - buffer->GetRow(0, &row); - for (const auto &item : row) { - CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); - } - } - return Status::OK(); -} - -#ifdef ENABLE_TDTQUE -Status DeviceQueueOp::SendDataToAscend() { - MS_LOG(INFO) << "Device queue, sending data to Ascend."; - int64_t total_batch = 0; - bool is_break_loop = false; - double batch_start_time, end_time; - int32_t batch_cost, tdt_cost; - int32_t connector_size = 0; - int32_t connector_capacity; - std::shared_ptr profiling_node; - bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable(); - if (isProfilingEnable) { - std::shared_ptr node; - RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node)); - profiling_node = std::dynamic_pointer_cast(node); - batch_start_time = ProfilingTime::GetCurMilliSecond(); - connector_capacity = ChildOpConnectorCapacity(); - } - std::unique_ptr current_buffer; - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - - while (!current_buffer->eof() && !is_break_loop) { - while (!current_buffer->eoe() && !is_break_loop) { - RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); - TensorRow currRow; - for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) { - RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); - auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); - if (status == TdtStatus::FAILED) { - return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); - } - - if (isProfilingEnable) { - end_time = ProfilingTime::GetCurMilliSecond(); - // record push tdt time - profiling_node->Record(TIME, TDT_PUSH_TIME, total_batch + 1, tdt_cost); - batch_cost = (int32_t)(end_time - batch_start_time); - // record batch time - profiling_node->Record(TIME, BATCH_TIME, total_batch + 1, batch_cost); - // record pipeline time - profiling_node->Record(TIME, PIPELINE_TIME, total_batch + 1, batch_cost - tdt_cost); - batch_start_time = end_time; - // record connector depth - profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size); - } - total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - is_break_loop = true; - } - } - if (isProfilingEnable) { - connector_size = ChildOpConnectorSize(); - connector_capacity = ChildOpConnectorCapacity(); - } - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - } - if (isProfilingEnable) { - connector_size = ChildOpConnectorSize(); - connector_capacity = ChildOpConnectorCapacity(); - } - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - } - - tree_->SetFinished(); - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; - - return Status::OK(); -} -#endif - -#ifdef ENABLE_GPUQUE -Status DeviceQueueOp::SendDataToGPU() { - MS_LOG(INFO) << "Device queue, sending data to GPU."; - int64_t total_batch = 0; - bool is_break_loop = false; - bool is_open = false; - uint32_t handle = INVALID_HANDLE; - - std::unique_ptr current_buffer; - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - - while (!current_buffer->eof() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed()) { - while (!current_buffer->eoe() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed()) { - RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); - TensorRow curr_row; // batch data - for (int row_id = 0; - row_id < current_buffer->NumRows() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed(); row_id++) { - RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &curr_row)); - - std::vector data_size; - for (int i = 0; i < curr_row.size(); i++) { - data_size.push_back(static_cast(curr_row[i]->SizeInBytes())); - } - if (!is_open) { - handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, data_size, ReleaseData); - if (handle == INVALID_HANDLE) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "open failed"); - } - is_open = true; - } - RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle)); - total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - is_break_loop = true; - } - } - if (!TaskManager::FindMe()->Interrupted()) - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - else - is_break_loop = true; - } - if (!TaskManager::FindMe()->Interrupted()) - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - else - is_break_loop = true; - } - - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; - - GpuBufferMgr::GetInstance().Close(handle); - - GpuBufferMgr::GetInstance().CloseConfirm(); - - return Status::OK(); -} - -Status DeviceQueueOp::RetryPushGPUData(const std::vector &data_size, const TensorRow &curr_row, - uint32_t handle) { - std::vector items; - for (int i = 0; i < data_size.size(); i++) { - device::DataItemGpu data_item; - data_item.data_len_ = data_size[i]; - data_item.data_ptr_ = nullptr; - items.push_back(data_item); - } - - while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { - RETURN_IF_NOT_OK(MallocForGPUData(&items, curr_row)); - BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); - if (ret) { - for (int i = 0; i < items.size(); i++) { - free(items[i].data_ptr_); - } - if (ret == BlockQueueStatus_T::ERROR_INPUT) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); - } else { - MS_LOG(WARNING) << "Retry pushing data..."; - continue; - } - } else { - break; - } - } - return Status::OK(); -} - -Status DeviceQueueOp::MallocForGPUData(std::vector *items, const TensorRow &curr_row) { - int i = 0; - for (auto &sub_item : *items) { - sub_item.data_ptr_ = (unsigned char *)malloc(sub_item.data_len_); - if (sub_item.data_ptr_ == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memory malloc failed."); - } - (void)memset_s(sub_item.data_ptr_, sub_item.data_len_, 0, sub_item.data_len_); - const unsigned char *column_data = curr_row[i]->GetBuffer(); - if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data, - static_cast(curr_row[i++]->SizeInBytes())) != 0) { - MS_LOG(ERROR) << "memcpy_s failed!"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memcpy_s failed."); - } - } - - return Status::OK(); -} -#endif - -Status DeviceQueueOp::SendDataToCPU() { - MS_LOG(INFO) << "Device queue, sending data to CPU."; - int64_t total_batch = 0; - - std::unique_ptr child_iterator = std::make_unique(this, 0, 0); - while (!(child_iterator->eof_handled())) { - TensorRow curr_row; - RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&curr_row)); - - if (!curr_row.empty()) { - MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << "."; - MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << "."; - total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - break; - } - } - } - - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; - - return Status::OK(); -} - -void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; - } -} - -// Visitor accept method for NodePass -Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h deleted file mode 100644 index a854004593..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ - -#include -#include -#include - -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -#ifdef ENABLE_TDTQUE -#include "dataset/engine/tdt/tdt_plugin.h" -#endif - -#ifdef ENABLE_GPUQUE -#include "device/gpu/gpu_buffer_mgr.h" -using mindspore::device::BlockQueueStatus_T; -using mindspore::device::GpuBufferMgr; -#endif - -namespace mindspore { -namespace dataset { -class DeviceQueueOp : public PipelineOp { - public: - static const uint32_t INVALID_HANDLE = 0xffffffffUL; - static const uint32_t WAIT_TIME = 5; - - enum class DeviceType { Ascend = 0, GPU = 1, CPU = 2 }; - - // The nested builder class inside of the DeviceQueueOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - class Builder { - public: - explicit Builder(int32_t prefetch_size); - - // Default destructor - ~Builder() = default; - - Builder &SetPrefetchSize(int32_t prefetch_size) { - builder_prefetch_size_ = prefetch_size; - return *this; - } - - Builder &SetChannelName(const std::string &channel_name) { - builder_channel_name_ = channel_name; - return *this; - } - - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - Builder &SetDeviceType(const std::string &device_type) { - if (device_type == "Ascend") { - builder_device_type_ = DeviceType::Ascend; - } else if (device_type == "GPU") { - builder_device_type_ = DeviceType::GPU; - } else if (device_type == "CPU") { - builder_device_type_ = DeviceType::CPU; - } - return *this; - } - - Builder &SetDeviceId(int32_t device_id) { - builder_device_id_ = device_id; - return *this; - } - - Builder &SetNumBatch(int64_t num_batch) { - builder_num_batch_ = num_batch; - return *this; - } - - // Name: Build() - // Description: The final step for building a DeviceQueueOp via the Builder is - // to call this Build() method. It will instantiate the DeviceQueueOp - // and return it to caller as a shared pointer. - Status Build(std::shared_ptr *ptr) { - *ptr = std::make_shared(builder_channel_name_, builder_device_type_, builder_device_id_, - builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_); - return Status::OK(); - } - - private: - int32_t builder_prefetch_size_; - int32_t builder_device_id_; - DeviceType builder_device_type_; - std::string builder_channel_name_; - int64_t builder_num_batch_; - int32_t builder_op_connector_size_; - }; - - // Name: constructor - // Description - DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, - int32_t op_connector_size, int64_t num_batch); - - // Name: destructor - // Description - ~DeviceQueueOp(); - - Status EoeReceived(int32_t worker_id) override; - - const int32_t get_prefetch_size() { return prefetch_size_; } - - // Name: Print() - // Description: A function that prints info about the node - void Print(std::ostream &out, // In: The output stream to print to - bool show_all) const override; // In: T/F if it should print everything - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const DeviceQueueOp &to) { - to.Print(out, false); - return out; - } - - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "DeviceQueueOp"; } - - private: - // Name: checkExceptions(DataBuffer); - // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp - Status CheckExceptions(const std::unique_ptr &buffer) const; - -#ifdef ENABLE_TDTQUE - Status SendDataToAscend(); -#endif - -#ifdef ENABLE_GPUQUE - Status SendDataToGPU(); - Status RetryPushGPUData(const std::vector &data_size, const TensorRow &curr_row, uint32_t handle); - Status MallocForGPUData(std::vector *items, const TensorRow &curr_row); -#endif - - Status SendDataToCPU(); - std::string channel_name_; - DeviceType device_type_; - const int32_t device_id_; - const int32_t prefetch_size_; - const int64_t num_batch_; - -#ifdef ENABLE_TDTQUE - std::shared_ptr tdtInstancePtr; -#endif -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc deleted file mode 100644 index 81c93c6e1c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc +++ /dev/null @@ -1,267 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/filter_op.h" -#include -#include -#include -#include -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/kernels/tensor_op.h" -#include "utils/log_adapter.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { - -Status FilterOp::Builder::SanityCheck() { - std::string err; - err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; - err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : ""; - return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); -} - -FilterOp::Builder::Builder() { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status FilterOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_, - builder_predicate_func_); - return Status::OK(); -} - -FilterOp::FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, - py::function predicate_func) - : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {} - -Status FilterOp::operator()() { - // The operator class just starts off threads by calling the tree_ function. - RETURN_UNEXPECTED_IF_NULL(tree_); - filter_queues_.Init(num_workers_, oc_queue_size_); - RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks())); - Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)); - // Synchronize with TaskManager. - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(rc); - RETURN_IF_NOT_OK(Collector()); - return Status::OK(); -} - -Status FilterOp::EofReceived(int32_t) { return Status::OK(); } - -Status FilterOp::EoeReceived(int32_t) { return Status::OK(); } - -// Validating if each of the input_columns exists in the DataBuffer. -Status FilterOp::ValidateInColumns(const std::vector *input_columns) { - for (const auto &inCol : *input_columns) { - bool found = column_name_id_map_.find(inCol) != column_name_id_map_.end() ? true : false; - if (!found) { - std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); -} - -// A print method typically used for debugging. -void FilterOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nInput column names:"; - for (size_t i = 0; i < in_columns_.size(); i++) { - out << " " << in_columns_[i]; - } - out << "\n\n"; - } -} - -Status FilterOp::WorkerEntry(int32_t worker_id) { - // Handshake with TaskManager that thread creation is successful. - TaskManager::FindMe()->Post(); - std::unique_ptr in_buffer; - bool worker_stop = false; - while (worker_stop == false) { - // Getting a databuffer to work on. - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); - if (in_buffer->eoe()) { - filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); - continue; - } else if (in_buffer->eof()) { - filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); - worker_stop = true; - continue; - } - - RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_)); - - // if the databuffer was all filtered, it is marked as kFilterEmpty. - // if the databuffer was partially filtered, it is marked as kFilterPartial. - // if the databuffer was not filtered, it is marked as kFilterFull. - int32_t num_rows = in_buffer->NumRows(); - std::unique_ptr new_tensor_table; - RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table)); - - if (new_tensor_table->empty()) { - RETURN_IF_NOT_OK( - filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty))); - } else if (new_tensor_table->size() == num_rows) { - in_buffer->set_tensor_table(std::move(new_tensor_table)); - RETURN_IF_NOT_OK( - filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull))); - } else { // kFilterPartial - in_buffer->set_tensor_table(std::move(new_tensor_table)); - RETURN_IF_NOT_OK( - filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial))); - } - } - return Status::OK(); -} - -Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out) { - *out = std::make_unique(); - int32_t num_rows = in_buffer->NumRows(); - for (int32_t i = 0; i < num_rows; i++) { - TensorRow to_process; - TensorRow cur_row; - RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); - if (in_columns_.empty() == true) { - MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; - to_process = cur_row; - } else { - (void)std::transform( - in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process), - [&cur_row, this](const auto &it) -> std::shared_ptr { return cur_row[column_name_id_map_[it]]; }); - } - bool predicate = true; - RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate)); - if (predicate) { - (*out)->push_back(std::move(cur_row)); - } - } - return Status::OK(); -} - -// if the filtered DataBuffer is written directly to out_connector_, -// the thread fetching data will block in a queue. -// Collector function will reorder the DataBuffer in order. -// for example in two work queues: -// int filter_queues_: -// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof) -// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe) -// after reorder in out_connector_: -// queue1: DB(data2) DB(data4) DB(eof) -// queue2: DB(eoe) DB(eoe) -Status FilterOp::Collector() { - bool collector_stop = false; - uint64_t task_id_cnt = 0; - uint64_t out_id_cnt = 0; - std::pair, filterCtrl> in_pair; - while (collector_stop == false) { - uint32_t w_id = task_id_cnt % num_workers_; - RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); - if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || - in_pair.second == filterCtrl::kFilterEoe) { - uint32_t out_task_id = out_id_cnt % num_workers_; - RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); - out_id_cnt++; - task_id_cnt++; - } else if (in_pair.second == filterCtrl::kFilterEof) { - uint32_t out_task_id = out_id_cnt % num_workers_; - RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); - collector_stop = true; - } else { // kFilterEmpty - task_id_cnt++; - } - } - return Status::OK(); -} - -// Private function for checking the column legality. -Status FilterOp::CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns) { - int32_t num_rows = in_buf->NumRows(); - int32_t num_cols = in_buf->NumCols(); - if (num_rows == 0 || num_cols == 0) { - RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer."); - } - // Check if there is invalid column name in the inColumns. - RETURN_IF_NOT_OK(ValidateInColumns(input_columns)); - return Status::OK(); -} - -Status FilterOp::CheckInput(const TensorRow &input) const { - for (auto &item : input) { - if (item == nullptr) { - RETURN_STATUS_UNEXPECTED("input is null."); - } - } - return Status::OK(); -} - -Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) { - RETURN_IF_NOT_OK(CheckInput(input)); - // Acquire Python GIL. - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - // Transform input tensor vector into numpy array vector. - py::tuple input_args(input.size()); - for (size_t i = 0; i < input.size(); i++) { - py::array new_data; - RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); - input_args[i] = new_data; - } - // Invoke python function. - py::object ret_py_obj = predicate_func_(*input_args); - *out_predicate = ret_py_obj.cast(); - } catch (const py::error_already_set &e) { - std::stringstream ss; - ss << e.what() << std::endl; - ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool."; - return Status(StatusCode::kPyFuncException, ss.str()); - } - return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); -} - -// Visitor accept method for NodePass -Status FilterOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h deleted file mode 100644 index 36f70cb82f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h +++ /dev/null @@ -1,188 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ - -#include -#include -#include -#include -#include -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/queue.h" - -namespace mindspore { -namespace dataset { - -class FilterOp : public ParallelOp { - public: - // The nested builder class inside of the FilterOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args. - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetPredicateFunc(py::function func) { - builder_predicate_func_ = std::move(func); - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetInColNames(const std::vector &in_col_names) { - build_in_col_names_ = in_col_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - builder_op_connector_size_ = connector_size; - return *this; - } - - // The builder "build" method creates the final object. - // @param ptr The shared_ptr to the new FilterOp object. - // @return Status. - Status Build(std::shared_ptr *ptr); - - private: - // Sanity check for builder class args. - // @return Status - The error code return. - Status SanityCheck(); - std::vector build_in_col_names_; - py::function builder_predicate_func_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - }; - - enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 }; - - // Constructor of FilterOp - // @note The builder class should be used to call it. - // @param in_col_names A list of input column names,when it is empty the predicate will be - // applied all columns in the dataset. - // @param num_workers The number of worker threads. - // @param op_connector_size The size of each queue in the connector. - // @param predicate_func python callable which returns a boolean value. - FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, - py::function predicate_func); - - // Destructor - ~FilterOp() = default; - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will - // provide the master loop that drives the logic for performing the work. - // @return Status The error code return - Status operator()() override; - - // @param int32_t workerId. - // @return Status - The error code return. - Status EofReceived(int32_t) override; - - // @param int32_t workerId. - // @return Status - The error code return. - Status EoeReceived(int32_t) override; - - // A print method typically used for debugging. - // @param out The output stream to write output to. - // @param show_all A bool to control if you want to show all info or just a summary. - void Print(std::ostream &out, bool show_all) const override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "FilterOp"; } - - private: - // predicate_func python callable which returns a boolean value. - py::function predicate_func_; - - // Variable to store the column name that will feed to predicate function. - std::vector in_columns_; - - // Internal queue for filter. - QueueList, filterCtrl>> filter_queues_; - - // Private function for worker/thread to loop continuously. It comprises the main - // logic of FilterOp, getting the data from previous Op, validating user specified column names, - // applying predicate to each of the data, filter the data when predicate result is false. - // @param worker_id The id assigned to this thread/worker upon creation. - // @return Status The error code return. - Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ - - // Filter the data by predicate function . - // @param in_buffer input data buffer. - // @param to_proess_indices Indices of columns to be processed. - // @param out data buffer that are filtered by predicate. - // @return Status The error code return. - Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out); - - // Collector databuffer. - // @return Status The error code return. - Status Collector(); - - // @param input tensor vector. - // @return Status - The error code return. - Status CheckInput(const TensorRow &input) const; - - // Invoke python func. - // @param input tensor vector. - // @param the result of predicate. - // @return Status - The error code return. - Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); - - // Private function for validating if each of the user specified input column names - // exist in the DataBuffer. - // @param input_columns The vector of input column names used in the current thread. - // @return Status The error code return. - Status ValidateInColumns(const std::vector *input_columns); - - // Private function for checking the column legality - // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory - // and is not shared with other threads. - // @param[out] to_process_indices Indices of columns that will feed to predicate. - // @param input_columns The vector of input column names used in the current thread. - Status CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns); -}; - -} // namespace dataset -} // namespace mindspore -#endif diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc deleted file mode 100644 index 05a1ac7925..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ /dev/null @@ -1,373 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/map_op.h" -#include -#include -#include -#include -#include -#include "dataset/core/config_manager.h" - -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/kernels/tensor_op.h" -#include "utils/log_adapter.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -MapOp::Builder::Builder() : build_perf_mode_(true) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_workers_ = cfg->num_parallel_workers(); - build_op_connector_size_ = cfg->op_connector_size(); -} - -// Check if the required parameters are set by the builder. -Status MapOp::Builder::sanityCheck() const { - if (build_tensor_funcs_.empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Building a MapOp that has not provided any function/operation to apply"); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status MapOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(sanityCheck()); - *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), - std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_, - build_perf_mode_); - return Status::OK(); -} - -// Constructor of MapOp -MapOp::MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, - bool perf_mode) - : ParallelOp(num_workers, op_connector_size), - tfuncs_(std::move(tensor_funcs)), - in_columns_(in_col_names), - out_columns_(out_col_names), - perf_mode_(perf_mode) { - // If caller didn't specify the out_col_names, assume they are same as the in_columns. - if (out_columns_.empty() || out_columns_[0].empty()) { - out_columns_ = in_columns_; - } - MS_LOG(DEBUG) << "Performance Mode in map operator is " << perf_mode_ << "."; -} - -// The number of threads consuming data from previous op's output Connector. -int32_t MapOp::num_consumers() const { - // When Performance Mode is on, there is only one thread consuming from the previous Connector. - return perf_mode_ == true ? 1 : num_workers_; -} - -// A print method typically used for debugging -void MapOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nInput column names:"; - for (size_t i = 0; i < in_columns_.size(); i++) { - out << " " << in_columns_[i]; - } - out << "\n TensorOps:"; - for (size_t i = 0; i < tfuncs_.size(); i++) { - out << " " << *(tfuncs_[i].get()); - } - out << "\n\n"; - } -} - -// This class functor will provide the master loop that drives the logic for performing the work -Status MapOp::operator()() { - if (perf_mode_) { - // Create and register the local queues. - local_queues_.Init(num_workers_, oc_queue_size_); - Status rc = local_queues_.Register(tree_->AllTasks()); - if (rc.IsError()) { - TaskManager::FindMe()->Post(); - return rc; - } - } - - // The operator class just starts off threads by calling the tree_ function - Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); - // Synchronize with TaskManager - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(rc); - - if (perf_mode_) { - int64_t que_id = 0; - std::unique_ptr buff; - bool is_eof = false; - // Draining output connector of the previous op and distribute it to local queues. - // Stop when all worker threads are finished (received EOF). - while (!is_eof) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); - is_eof = buff->eof(); - RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff))); - que_id = (que_id + 1) % num_workers_; - } - } - - return Status::OK(); -} - -// Private function for worker/thread to loop continuously. It comprises the main -// logic of MapOp: getting the data from previous Op, validating user specified column names, -// applying a list of TensorOps to each of the data, process the results and then -// pushing them back to MapOp's output Connector to be fetched by the next Op. -Status MapOp::WorkerEntry(int32_t worker_id) { - // Handshake with TaskManager that thread creation is successful. - TaskManager::FindMe()->Post(); - std::unique_ptr in_buffer; - - // Getting a databuffer to work on. - // Perform the first fetch here outside of the loop. This allows us to execute one-time only - // initializations that happen after the first fetch. - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - - // Sanity check the databuffer. - // Special case: if there's more threads than buffers, some threads simply get the final control - // messages (eoe/eof), and so they will not perform the check. - if (!in_buffer->eoe() && !in_buffer->eof()) { - int32_t num_rows = in_buffer->NumRows(); - int32_t num_cols = in_buffer->NumCols(); - if (num_rows == 0 || num_cols == 0) { - RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer."); - } - } - - // Now that init work is done, drop into the main fetching loop. - // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself - // rather than use the base-class defaults. - while (true) { - // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work - // with Performance Mode design. - if (in_buffer->eoe()) { - // Calling base class EoeReceived to forward eoe buffer. - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - continue; - } else if (in_buffer->eof()) { - // Calling base class EofReceived to forward eof buffer. - RETURN_IF_NOT_OK(EofReceived(worker_id)); - break; - } - - std::unique_ptr new_tensor_table(std::make_unique()); - // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. - RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get())); - - // Replace the TensorTable in DataBuffer with the new one. - in_buffer->set_tensor_table(std::move(new_tensor_table)); - - // Push the buffer onto the connector for next operator to consume. - RETURN_IF_NOT_OK(out_connector_->Add(static_cast(worker_id), std::move(in_buffer))); - - // Fetch the next buffer and loop back to the top. - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - } - - return Status::OK(); -} - -Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table) { - // Getting number of rows and cols in this buffer. - int32_t num_rows = in_buffer->NumRows(); - int32_t num_cols = in_buffer->NumCols(); - - for (int32_t r = 0; r < num_rows; r++) { - // to_process : A vector of Tensors only holding cols in input_columns. - // result_row; : A vector of Tensors to hold the result after Compute(). - // cur_row : A vector of Tensors holding all the columns from DataBuffer. - TensorRow to_process, result_row, cur_row; - RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); - - // Populate the Tensor from the current row to be processed by TensorOp - for (const auto &idx : to_process_indices_) { - to_process.push_back(std::move(cur_row[idx])); - } - - // Looping over multiple TensorOps supplied in to MapOp. - // The assumption is that the result of one TensorOp matches the required input to the next TensorOp. - for (size_t i = 0; i < tfuncs_.size(); i++) { - // TensorOp can operate on single col or multiple cols. MapOp always call compute for multiple cols. - // TensorOp base class will call the single column Compute() depending on the ops. - // Note: The columns of the result_row is not preallocated, the compute function of each tensor op are - // required to resize/push back the result_row - RETURN_IF_NOT_OK(tfuncs_[i]->Compute(to_process, &result_row)); - - // Assign result_row to to_process for the next TensorOp processing, except for the last TensorOp in the list. - if (i + 1 < tfuncs_.size()) { - to_process = std::move(result_row); - } - } - - if (out_columns_.size() != result_row.size()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Result of a tensorOp doesn't match output column names"); - } - - if (in_columns_.size() == out_columns_.size()) { - for (size_t i = 0; i < result_row.size(); i++) { - cur_row[to_process_indices_[i]] = std::move(result_row[i]); - } - new_tensor_table->push_back(std::move(cur_row)); - } else { - // Add the columns we did not touch to the result_row. - for (int32_t i = 0; i < num_cols; i++) { - if (keep_input_columns_[i]) { - result_row.push_back(std::move(cur_row[i])); - } - } - - // Add this final result_row to our new TensorTable. - new_tensor_table->push_back(std::move(result_row)); - } - } - - return Status::OK(); -} - -Status MapOp::ComputeColMap() { - // If the map has not been set up yet in the base class, then set it up - if (column_name_id_map_.empty()) { - std::unordered_map current_name_id_map = child_[0]->column_name_id_map(); - // Initialize private variables - RETURN_IF_NOT_OK(InitPrivateVariable(¤t_name_id_map)); - // Create the final column name to index mapping in the base class field - CreateFinalColMap(¤t_name_id_map); - MS_LOG(DEBUG) << "Column name map for map op set: " << this->ColumnNameMapAsString(); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -// Validating if each of the input_columns exists in the DataBuffer. -Status MapOp::ValidateInColumns(const std::unordered_map &col_name_id_map) { - for (const auto &inCol : in_columns_) { - bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; - if (!found) { - std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); -} - -Status MapOp::InitPrivateVariable(std::unordered_map *col_name_id_map) { - // If input_columns is empty(), The col at index-0 will be picked. - if (in_columns_.empty()) { - for (const auto &pair : *col_name_id_map) { - if (pair.second == 0) { - MS_LOG(INFO) << "Input columns empty for map op, will apply to the first column in the current table."; - in_columns_.push_back(pair.first); - break; - } - } - - // If caller didn't specify the out_col_names, assume they are same as the input_columns. - // This was done in the constructor, but if input columns was empty to start we have to redo it here. - if (out_columns_.empty() || out_columns_[0].empty()) { - out_columns_ = in_columns_; - } - } - - // Before we continue, issue a sanity check to make sure the input columns from user and the incoming - // columns from child are correct - RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map)); - - // initialize keep_input_columns, true means to keep the column. - keep_input_columns_.resize(col_name_id_map->size(), true); - for (const auto &col_name : in_columns_) { - int32_t missed = (*col_name_id_map)[col_name]; - keep_input_columns_[missed] = false; - } - - // initialize to_process_indices. - for (const auto &col_name : in_columns_) { - to_process_indices_.push_back((*col_name_id_map)[col_name]); - } - return Status::OK(); -} - -// Create the final column name to index mapping and get indices of the columns this mapop does not use. -void MapOp::CreateFinalColMap(std::unordered_map *col_name_id_map) { - std::unordered_map final_col_name_id_map; - size_t num_cols = col_name_id_map->size(); - std::vector new_ids(num_cols); - if (in_columns_.size() == out_columns_.size()) { - for (size_t i = 0; i < in_columns_.size(); i++) { - int32_t loc = (*col_name_id_map)[in_columns_[i]]; - (void)col_name_id_map->erase(in_columns_[i]); - (*col_name_id_map)[out_columns_[i]] = loc; - } - - // Set the base class final column id map result - column_name_id_map_ = *col_name_id_map; - } else { - int32_t fill_idx = 0; - // First columns of the tables are occupied by the output columns from tensorOp. - for (const auto &col_name : out_columns_) { - final_col_name_id_map[col_name] = fill_idx++; - } - - // Creating new_ids mapping for the columns we keep. - for (size_t i = 0; i < num_cols; i++) { - if (keep_input_columns_[i]) { - new_ids[i] = fill_idx++; - } - } - - // Iterating through the old mapping to update the final mapping for the columns we kept. - std::string name; - for (const auto &pair : *col_name_id_map) { - name = pair.first; - int32_t old_id = pair.second; - if (keep_input_columns_[old_id]) { - final_col_name_id_map[name] = new_ids[old_id]; - } - } - - // Set the base class final column id map result - column_name_id_map_ = final_col_name_id_map; - } -} - -// Visitor accept method for NodePass -Status MapOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h deleted file mode 100644 index 371d865196..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h +++ /dev/null @@ -1,261 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_MAP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_MAP_OP_H_ - -#include -#include -#include -#include -#include -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/queue.h" - -namespace mindspore { -namespace dataset { -// Forward declare -class DataBuffer; -class ExecutionTree; - -// MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names. -// The column order behavior after MapOp is as follows. -// [Case 1] If the number of Input Columns == the number of Output Column, column ordering after MapOp -// is the same as the original column order where the Remainder Columns stay in the same position, -// and the Output Columns are placed the same position of the Input Columns. -// For example, initially if the dataset has column order |A, B, C, D, E|, -// and we apply MapOp() with Input Columns {B, C} and Output Columns {X, Y}. -// The column order after applying MapOp will be |A, X, Y, D, E|. -// Note that in this case, |X, Y| is the Output Columns and |A, D, E| which is the Remainder Columns stay in -// their original position, and column B is replaced by column X and column C is replace by column Y. -// [Case 2] If the number of Input Columns != the number of Output Column, column ordering after MapOp -// is Output Columns followed by Remainder Columns. -// For example, initially if the dataset has column order |A, B, C, D, E|, -// and we apply MapOp() with Input Columns {B, C, A} and Output Columns {X, Y}. -// The column order after applying MapOp will be |X, Y, D, E|. -// Note that in this case, |X, Y| is the Output Columns and |D, E| is the Remainder Columns, -// and the Input Columns are gone and replaced by the Output Columns. - -// Keywords: -// Input Columns : a vector of column names (string) passed to MapOp specifying the column names from which -// Tensors are taken and passed to the TensorOp Compute(). -// Output Columns : a vector of column names (string) passed to MapOp specifying what are the column names -// for the Tensors produced by TensorOp Compute(). -// Remainder Columns : columns that exist in the dataset but are not mentioned in Input Columns. -// These columns will not be passed to TensorOp Compute(), but will be appended to the end of the Output Columns. -class MapOp : public ParallelOp { - public: - // The nested builder class inside of the MapOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetInColNames(const std::vector &in_col_names) { - build_in_col_names_ = in_col_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOutColNames(const std::vector &out_col_names) { - build_out_col_names_ = out_col_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetTensorFuncs(std::vector> funcs) { - build_tensor_funcs_ = std::move(funcs); - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - build_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - build_op_connector_size_ = connector_size; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetPerformanceMode(bool perf_mode) { - build_perf_mode_ = perf_mode; - return *this; - } - - // The builder "build" method creates the final object. - // @param ptr The shared_ptr to the new MapOp object - // @return Status - Status Build(std::shared_ptr *ptr); - - private: - std::vector build_in_col_names_; - std::vector build_out_col_names_; - std::vector> build_tensor_funcs_; - int32_t build_num_workers_; - int32_t build_op_connector_size_; - bool build_perf_mode_; // Default true. - - // Check if the required parameters are set by the builder. - // @return Status The error code return - Status sanityCheck() const; - }; - - // Constructor of MapOp - // @note The builder class should be used to call it. - // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs). - // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs). - // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data. - // @param num_workers The number of worker threads. - // @param op_connector_size The size of each queue in the connector. - MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, - bool perf_mode); - - // Destructor - ~MapOp() = default; - - // A print method typically used for debugging - // @param out The output stream to write output to - // @param show_all A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out reference to the output stream being overloaded - // @param mo reference to the MapOp to display - // @return the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const MapOp &mo) { - mo.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status The error code return - Status operator()() override; - - // Getter - // @return the number of threads consuming data from previous op's output Connector. - int32_t num_consumers() const override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "MapOp"; } - - private: - // Local queues where worker threads can pop from. - // Popping directly from the Connector can block if the previous designated threads haven't pop. - // Setting the size of these queues to 0 is essentially the same as pulling directly from Connector. - QueueList> local_queues_; - - // Static variables to be ready by worker threads, no modification and readonly - const std::vector> tfuncs_; - - // Variable to store the column name that the tensorOps are consuming - std::vector in_columns_; - - // Variable to store the column name that the tensorOps are producing - std::vector out_columns_; - - // Boolean mapping, true means to keep the column. - std::vector keep_input_columns_; - - // Indices of the columns to process. - std::vector to_process_indices_; - - // Performance mode is when the main thread creates local queues, pulls databuffers from the previous - // op's Connector and distributes them to the local queues. Workers pull from the local queues. - // If this flag is false, each worker pulls directly from the Connector. This use less resources - // (thread and memory), but when the computation cost is heavy (e.g. DecodeOp) and fluctuating, it can - // cause additional blocking because pop calls to Connector from the threads are synchronized to enforce the order. - bool perf_mode_; - - // Private function for worker/thread to loop continuously. It comprises the main - // logic of MapOp: getting the data from previous Op, validating user specified column names, - // applying a list of TensorOps to each of the data, process the results and then - // pushing them back to MapOp's output Connector to be fetched by the next Op. - // @param worker_id The id assigned to this thread/worker upon creation. - // @return Status The error code return - Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ - - // Private helper function for getting the next buffer - // When PerformanceMode is enabled, workers pop from the local queue. - // Otherwise, workers pop from the first child output Connector. - // @param p_buffer - the buffer to return - // @return Status return code - Status FetchNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id) { - if (perf_mode_) { - RETURN_IF_NOT_OK(local_queues_[worker_id]->PopFront(p_buffer)); - } else { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id)); - } - return Status::OK(); - } - - // Private function for worker thread to perform TensorOp's compute function and get the result. - // @param in_buffer A raw pointer to the DataBuffer. A raw pointer is fine because this function doesn't manage memory - // and is not shared with other threads. - // @param[out] new_tensor_table A new Tensor Table to be populated in this function. - Status WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table); - - // Private function that create the final column name to index mapping and - // get indices of the columns this mapop does not use. - // @param col_name_id_map The column name to index mapping obtained from child operator - void CreateFinalColMap(std::unordered_map *col_name_id_map); - - // Validating if each of the input_columns exists in the DataBuffer. - // @param - the column map to check - // @return - status return code - Status ValidateInColumns(const std::unordered_map &col_name_id_map); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - // Private function for initializing private variables such as in_columns_, out_columns_. - // @return - Status - Status InitPrivateVariable(std::unordered_map *col_name_id_map); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_MAP_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc deleted file mode 100644 index 244861a6c8..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/parallel_op.h" - -#include -#include -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/core/config_manager.h" -#include "dataset/engine/db_connector.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// Constructor -ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler) - : DatasetOp(op_connector_size, sampler), - num_workers_(num_workers), - num_producers_(num_workers), - worker_connector_size_(1), - worker_connector_(nullptr) {} - -// Creates the internal worker connector for the parallel op if the derived class wants to use it -Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { - if (worker_connector_size == 0) { - RETURN_STATUS_UNEXPECTED("Worker connector size 0 is invalid."); - } - num_producers_ = 1; - worker_connector_size_ = worker_connector_size; - // Instantiate the worker connector. This is the internal connector, not the operators - // output connector. It has single master consuming from it (num producers is 1), and the number - // of workers is the defined count from the op. - worker_connector_ = std::make_unique(num_workers_, num_producers_, worker_connector_size); - - return Status::OK(); -} - -// A print method typically used for debugging -void ParallelOp::Print(std::ostream &out, bool show_all) const { - // Summary 1-liner print - if (!show_all) { - out << " [workers: " << num_workers_ << "]"; - // Call super class printer - DatasetOp::Print(out, show_all); - } else { - // Detailed print - DatasetOp::Print(out, show_all); - out << "\nNum workers: " << num_workers_; - } -} - -// Override base class reset to provide reset actions specific to the ParallelOp class. -Status ParallelOp::Reset() { - RETURN_IF_NOT_OK(DatasetOp::Reset()); // Perform any super class reset work - - // ParallelOp is abstract, but we do own the connector between workers and master - // (if the parallel op is configured for this). Reset that connector here. - if (worker_connector_) { - worker_connector_->Reset(); - } - - return Status::OK(); -} - -// Register the internal worker connectors -Status ParallelOp::RegisterWorkerConnectors() { - if (worker_connector_) { - return (worker_connector_->Register(tree_->AllTasks())); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h deleted file mode 100644 index f59d4bfc53..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h +++ /dev/null @@ -1,126 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ -#define DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ - -#include -#include -#include "dataset/core/constants.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// global const in our namespace -constexpr int32_t kEndOfActions = -1; - -// Forward declares -class DataBuffer; - -class DbConnector; - -// A ParallelOp provides a multi-threaded DatasetOp -class ParallelOp : public DatasetOp { - public: - // Constructor - // @param num_workers - // @param op_connector_size - size of the output connector for this operator - // @param sampler - The sampler for the op - ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler = nullptr); - - // Destructor - ~ParallelOp() = default; - - // Creates the internal worker connector for the parallel op if the derived class wants to use it. - // @notes This changes the number of producers of this op to 1, since it establishes a master/worker - // relationship within the op, making all production flow through a single master. - // @return Status - The error return code - Status CreateWorkerConnector(int32_t worker_connector_size); - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param pO - reference to the ParallelOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const ParallelOp &po) { - po.Print(out, false); - return out; - } - - // During tree prepare phase, operators may have specific pre-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - // @return Status - The error return code - Status PrepareNodePreAction() override { - // Run common code from super class before adding ParallelOp specific logic - return (DatasetOp::PrepareNodePreAction()); - } - - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - // @return Status - The error return code - Status PrepareNodePostAction() override { - // Run common code from super class before adding ParallelOp specific logic - return (DatasetOp::PrepareNodePostAction()); - } - - // Override base class reset to provide reset actions specific to the ParallelOp class. - // @return Status - The error code return - Status Reset() override; - - // Getter - // @return the number of workers - int32_t num_workers() const override { return num_workers_; } - - // Getter - // @return the number of threads consuming from the previous Connector - int32_t num_consumers() const override { return num_workers_; } - - // Getter - // @return the number of producers pushing to the output Connector - // @notes The number of producers is commonly the same as number of workers, except in the case - // when a worker connector is set up. In that case, there are n workers, and a single master - // such that only 1 thread is a producer rather than the n workers. - // @return the number of producers - int32_t num_producers() const override { return num_producers_; } - - // Register the internal worker connectors. - // @return Status - Status RegisterWorkerConnectors() override; - - protected: - // Interface for derived classes to implement. All derived classes must provide the entry - // function with the main execution loop for worker threads. - // @return Status - The error code return - virtual Status WorkerEntry(int32_t workerId) = 0; - - int32_t num_workers_; // The number of worker threads - int32_t num_producers_; // The number of threads pushing to the out_connector_ - int32_t worker_connector_size_; - std::unique_ptr worker_connector_; // The internal connector for worker threads -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc deleted file mode 100644 index 1d017a4d3e..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/pipeline_op.h" -#include -#include - -namespace mindspore { -namespace dataset { -// Constructor -PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr sampler) - : DatasetOp(op_connector_size, sampler) {} - -// A print method typically used for debugging -void PipelineOp::Print(std::ostream &out, bool show_all) const { - // Summary 1-liner print - if (!show_all) { - out << " [workers: "; - if (this->inlined()) { - out << "0 (inlined)]"; - } else { - out << "1]"; // Pipeline ops only have 1 worker - } - // Call super class printer - DatasetOp::Print(out, show_all); - } else { - // Detailed print - DatasetOp::Print(out, show_all); - out << "\nNum workers: "; - if (this->inlined()) { - out << "0 (inlined)"; - } else { - out << "1"; // Pipeline ops only have 1 worker - } - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h deleted file mode 100644 index cb3c76813b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ - -#include -#include -#include "dataset/engine/datasetops/dataset_op.h" - -namespace mindspore { -namespace dataset { -// forward declare -class ExecutionTree; - -class DataBuffer; - -class PipelineOp : public DatasetOp { - public: - // Constructor - // @param op_connector_size - size of the output connector - // @return Builder setter method returns reference to the builder. - // @param sampler - The sampler for the op - explicit PipelineOp(int32_t op_connector_size, std::shared_ptr sampler = nullptr); - - // Destructor - ~PipelineOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param po - reference to the PipelineOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const PipelineOp &po) { - po.Print(out, false); - return out; - } - - // Getter - // @return The number of workers inside this op. Pipeline ops only have a single worker. - int32_t num_workers() const override { return 1; } - - // Getter - // @return the number of threads consuming from the previous Connector - int32_t num_consumers() const override { return 1; } - - // Getter - // @return The number of threads that push data to the output connector - int32_t num_producers() const override { return 1; } - - // During tree prepare phase, operators may have specific pre-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePreAction() override { - // Run common code from super class before adding PipelineOp specific logic - return (DatasetOp::PrepareNodePreAction()); - } - - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePostAction() override { - // Run common code from super class before adding PipelineOp specific logic - return (DatasetOp::PrepareNodePostAction()); - } - - protected: - // ******************************************************************************* - // I'm predicting there will be common arguments or functionality for pipeline ops, - // just not sure yet what those are. perhaps this intermediate class between - // DatasetOp and the actual ops is not needed at all? - // For example, if there's no common code for all of the non-parallel ops, then - // they can just inherit from DatasetOp directly and we can put this class into the - // trash. -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc deleted file mode 100644 index 5ce4056024..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc +++ /dev/null @@ -1,159 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/engine/datasetops/project_op.h" -#include -#include -#include -#include -#include -#include -#include -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -ProjectOp::Builder::Builder(const std::vector &columns_to_project) - : builder_columns_to_project_(columns_to_project) {} - -Status ProjectOp::Builder::SanityCheck() const { - if (builder_columns_to_project_.empty()) { - std::string err_msg("Columns to project is empty."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -Status ProjectOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(builder_columns_to_project_); - return Status::OK(); -} - -ProjectOp::ProjectOp(const std::vector &columns_to_project) - : PipelineOp(0), columns_to_project_(columns_to_project) {} - -void ProjectOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nColumns that are projected:"; - for (size_t i = 0; i < columns_to_project_.size(); i++) { - out << "\n" << columns_to_project_[i]; - } - out << "\n\n"; - } -} - -// Gets a buffer from the child operator and projects the buffer. -Status ProjectOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id, retry_if_eoe)); - if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) { - RETURN_IF_NOT_OK(Project(p_buffer)); - } - return Status::OK(); -} - -Status ProjectOp::Project(std::unique_ptr *data_buffer) { - std::unique_ptr new_tensor_table = std::make_unique(); - while ((*data_buffer)->NumRows() > 0) { - TensorRow current_row; - RETURN_IF_NOT_OK((*data_buffer)->PopRow(¤t_row)); - TensorRow new_row; - (void)std::transform(projected_column_indices_.begin(), projected_column_indices_.end(), - std::back_inserter(new_row), [¤t_row](uint32_t x) { return current_row[x]; }); - new_tensor_table->push_back(new_row); - } - (*data_buffer)->set_tensor_table(std::move(new_tensor_table)); - return Status::OK(); -} - -// Class functor operator () override. -// Most dataset ops operate by launching a thread (see ExecutionTree). -// However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the -// functor since this op runs inlined inside another operator. The function is overloaded to -// ensure that it is not called by mistake (it will generate an error). -Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); } - -int32_t ProjectOp::num_consumers() const { - if (parent_.empty()) { - MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1."; - return 1; - } else if (parent_[0] == nullptr) { - MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0."; - return 0; - } else { - return parent_[0]->num_consumers(); - } -} - -int32_t ProjectOp::num_producers() const { - if (child_.empty() || child_[0] == nullptr) { - MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0."; - return 0; - } else { - return child_[0]->num_producers(); - } -} - -Status ProjectOp::EoeReceived(int32_t worker_id) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } - -// Visitor accept method for NodePass -Status ProjectOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -// Compute the column map and save it into our own column name map -// We cannot use the super class ComputeColMap here because we're making a modification of the -// map from the child map. -Status ProjectOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - std::unordered_map child_column_name_mapping = child_[0]->column_name_id_map(); - for (size_t i = 0; i < columns_to_project_.size(); i++) { - std::string ¤t_column = columns_to_project_[i]; - if (child_column_name_mapping.find(current_column) == child_column_name_mapping.end()) { - std::string err_msg = "ProjectOp: column " + current_column + " does not exist in child operator."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - // Setup the new column name mapping for ourself (base class field) - column_name_id_map_[current_column] = i; - projected_column_indices_.push_back(child_column_name_mapping[current_column]); - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/project_op.h b/mindspore/ccsrc/dataset/engine/datasetops/project_op.h deleted file mode 100644 index 628c1342ba..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/project_op.h +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ -#define DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ - -#include -#include -#include - -#include "dataset/engine/datasetops/pipeline_op.h" - -namespace mindspore { -namespace dataset { -class ProjectOp : public PipelineOp { - public: - // The nested builder class inside of the ProjectOp is used to help manage all of the arguments - // for constructing it. This repeat op is very simple though, so this builder is really just - // provided for a consistent look and feel for creators of Dataset operators overall. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @param columns_to_project - - // @return This is a constructor. - explicit Builder(const std::vector &columns_to_project); - - // Builder destructor. - ~Builder() = default; - - // The builder "build" method creates the final object. - // @return shared_ptr to the new ProjectOp object. - Status Build(std::shared_ptr *); - - private: - std::vector builder_columns_to_project_; - Status SanityCheck() const; - }; - - // Constructor of the ProjectOp. - // @param columnsToProject - - explicit ProjectOp(const std::vector &columns_to_project); - - // Destructor. - ~ProjectOp() = default; - - // A print method typically used for debugging. - // @param out - The output stream to write output to. - // @param show_all - A bool to control if you want to show all info or just a summary. - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload. - // @notes This allows you to write the debug print info using stream operators. - // @param out - reference to the output stream being overloaded. - // @param project_op - reference to the ProjectOp to display. - // @return - the output stream must be returned. - friend std::ostream &operator<<(std::ostream &out, const ProjectOp &project_op) { - project_op.Print(out, false); - return out; - } - - // Class functor operator () override. - // Most dataset ops operate by launching a thread (see ExecutionTree). - // However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the - // functor since this op runs inlined inside another operator. The function is overloaded to - // ensure that it is not called by mistake (it will generate an error). - // @return Status - The error code returned. - Status operator()() override; - - // Gets a buffer from the child node and projects that buffer. The caller is typically our parent node. - // @param p_buffer - output pointer to the projected buffer. - // @param worker_id - The worker id - Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) override; - - // Base-class override. Return the number of workers in the first parent. - // @param workerId - The worker id - int32_t num_consumers() const override; - - // Base-class override. Return the number of producers in the first child. - // @param workerId - The worker id - int32_t num_producers() const override; - - // Base-class override for special eoe handler. - // Inline operators must override this because there is no connector to push eoe onto. - // @return Status - The error code returned. - Status EoeReceived(int32_t worker_id) override; - - // Base-class override for special eof handler. - // Inline operators must override this because there is no connector to push eof onto. - // @return Status - The error code returned. - Status EofReceived(int32_t worker_id) override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ProjectOp"; } - - private: - std::vector columns_to_project_; - std::vector projected_column_indices_; - - Status Project(std::unique_ptr *data_buffer); - - // Computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc deleted file mode 100644 index 23cd29d295..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc +++ /dev/null @@ -1,182 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/rename_op.h" -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -// builds -RenameOp::Builder::Builder() { - // Some arguments to the RenameOp constructor have a default argument that is taken - // from the client config. - // The user may choose to change these values for the construction of the RenameOp by - // using the various builder set methods. - - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status RenameOp::Builder::SanityCheck() const { return Status::OK(); } - -// build method for RenameOp -Status RenameOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(builder_in_columns_, builder_out_columns_, builder_op_connector_size_); - return Status::OK(); -} - -// constructor -RenameOp::RenameOp(const std::vector &in_col_names, const std::vector &out_col_names, - int32_t op_connector_size) - : PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {} - -// destructor -RenameOp::~RenameOp() {} - -// main entry point for rename -Status RenameOp::operator()() { - TaskManager::FindMe()->Post(); - std::unique_ptr curr_buffer; - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - if (curr_buffer->buffer_flags() != DataBuffer::kDeBFlagNone) { - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); - std::string err_msg = "Rename first buffer got was control signal"; - // if 1st eoe or eof, pass it on then return - RETURN_STATUS_UNEXPECTED(err_msg); - } - - while (curr_buffer->eof() == false) { - while (curr_buffer->eoe() == false) { - // push the renamed input buffer - MS_LOG(DEBUG) << "Rename operator pushing next buffer."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - } // end of while eoe loop - - // we got eoe, now try again until we get eof - MS_LOG(DEBUG) << "Rename operator EOE Received."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - MS_LOG(DEBUG) << "Rename operator fetching buffer after EOE."; - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - } // end of while eof loop - - MS_LOG(DEBUG) << "Rename opeerator EOF Received."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - return Status::OK(); -} - -// Rename core functionality to compute the new column name id map. -// We need to overwrite the super class ComputeColMap here because we're making a modification of the -// map from the child map. -Status RenameOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - column_name_id_map_ = child_[0]->column_name_id_map(); - // iterate over my index in input vector, find the corresponding position - std::unordered_map new_col_name_id_map = {}; - // parameter for input check - size_t found = 0; - - // iterate over all the pairs and if there is a name match with rename, rename the column and add it to new map - // by doing it this way we recreate a new ColNameIdMap and allow for switching - for (const auto &pair : column_name_id_map_) { - std::string name = pair.first; - int32_t id = pair.second; - // find name - std::vector::iterator it; - it = std::find(in_columns_.begin(), in_columns_.end(), name); - // for c input checks here we have to count the number of times we find the stuff in in_columns_ - // because we iterate over the mInputList n times - if (it != in_columns_.end()) { - // found - found += 1; - int index = std::distance(in_columns_.begin(), it); - MS_LOG(DEBUG) << "Rename operator index found " << index << " value " << id << "."; - - new_col_name_id_map[out_columns_[index]] = id; - } else { - // not found - MS_LOG(DEBUG) << "Rename operator index not found: " << id << " is the column id."; - new_col_name_id_map[name] = id; - } - } - // only checks number of renamed columns have been found, this input check doesn't check everything - if (found != in_columns_.size()) { - MS_LOG(DEBUG) << "Rename operator column names found: " << found << " out of " << in_columns_.size() << "."; - std::string err_msg = "Renamed column doesn't exist in dataset"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Now, overwrite our column map with the new renamed columns/id's - column_name_id_map_ = new_col_name_id_map; - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -// prints rename -void RenameOp::Print(std::ostream &out, // In: The output stream to print to - bool show_all) const { // In: T/F if it should print everything - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nIn columns:"; - for (size_t i = 0; i < in_columns_.size(); ++i) { - out << "\n " << in_columns_[i]; - } - for (size_t i = 0; i < out_columns_.size(); ++i) { - out << "\n " << out_columns_[i]; - } - out << "\n\n"; - } -} - -Status RenameOp::EofReceived(int32_t) { - MS_LOG(DEBUG) << "Rename operator EOF received, do nothing now."; - return Status::OK(); -} - -Status RenameOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -// Visitor accept method for NodePass -Status RenameOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h deleted file mode 100644 index e209c075d6..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ -#define DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ - -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// forward declare -class DataBuffer; - -class RenameOp : public PipelineOp { - public: - // The nested builder class inside of the RenameOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetInColNames(const std::vector &in_col_names) { - builder_in_columns_ = in_col_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOutColNames(const std::vector &out_col_names) { - builder_out_columns_ = out_col_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // The builder "build" method creates the ZipOp dataset Operator. - // @return shared_ptr to the new RenameOp object - Status Build(std::shared_ptr *); - - private: - std::vector builder_in_columns_; - std::vector builder_out_columns_; - int32_t builder_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor for RenameOp - // @param in_col_names names of columns to rename - // @param out_col_names names of columns after rename - // @param op_connector_size connector size - RenameOp(const std::vector &in_col_names, // In: Col names to consume - const std::vector &out_col_names, // In: Col names to produce - int32_t op_connector_size); - - // Destructor - ~RenameOp(); - - Status EofReceived(int32_t) override; - - Status EoeReceived(int32_t) override; - - // Print function for Rename - // @param out output stream to print to - // @param show_all if it should print everything - void Print(std::ostream &out, bool show_all) const override; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const RenameOp &ro) { - ro.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "RenameOp"; } - - protected: - // Rename core functionality - // Computing the assignment of the new column name map. - // @return - Status - Status ComputeColMap() override; - - // Variable to store the input column names - std::vector in_columns_; - - // Variable to store the output column names - std::vector out_columns_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc deleted file mode 100644 index 4999dddd02..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include - -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/repeat_op.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {} - -Status RepeatOp::Builder::SanityCheck() const { - if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) { - std::string err_msg("Repeat count must be > 0 or -1."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status RepeatOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_repeats_); - return Status::OK(); -} - -// Constructor of the RepeatOp. -RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {} - -// Destructor -RepeatOp::~RepeatOp() {} - -// A print method typically used for debugging -void RepeatOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [repeats: " << max_repeats_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_ - << "\nLeaf Nodes in execution path:"; - if (!eoe_ops_.empty()) { - for (size_t i = 0; i < eoe_ops_.size(); i++) { - out << "\n Operator: " << eoe_ops_[i]->id(); - } - } else { - out << " None."; - } - out << "\n\n"; - } -} - -// Base-class override for executing specific RepeatOp configurations. This code will be called -// during the execution tree prepare phase when it is visiting this operator. -Status RepeatOp::PrepareNodePostAction() { - // Run any common code from super class first before adding our own specific logic - RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); - std::shared_ptr leaf_op = tree_->PopFromEOEOpStack(); - while (leaf_op != nullptr) { - // Track the leaf operators that are under this repeat op. - eoe_ops_.push_back(leaf_op); - leaf_op = tree_->PopFromEOEOpStack(); - } - // Push ourselves to the stack in case one of our ascendants is repeat too. - tree_->AddToEOEOpStack(shared_from_this()); - return Status::OK(); -} - -// Base-class override for setting specific RepeatOp configurations. This code will be called -// during the execution tree prepare phase BEFORE traversing down to child operators. -uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; } - -// This function returns the buffer that is at the top of our output connector. The caller is -// typically our parent node, when the parent is asking us to provide the next buffer of data. -// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get -// a buffer from our child. -// This function sets the `retryIfEoe` flag when popping from the child connector. This way, -// this function will retry to pop the connector again and will get the non-EOE buffer if any. -Status RepeatOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { - if (child_.empty()) { - RETURN_STATUS_UNEXPECTED("RepeatOp can't be the leaf node."); - } - - std::unique_ptr buf; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - // Loop until non EOE is received - while (buf->eoe()) { - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - if (state_ == OpState::kDeOpIdle) { - *p_buffer = std::move(buf); - return Status::OK(); - } - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - } - // Check if the last buf is next eof - if (buf->eof()) { - RETURN_IF_NOT_OK(EofReceived(worker_id)); - } - *p_buffer = std::move(buf); - return Status::OK(); -} - -// Base-class override for handling cases when an eoe is received. -Status RepeatOp::EoeReceived(int32_t worker_id) { - repeat_count_++; - MS_LOG(DEBUG) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << "."; - bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); - bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); - // If we've reached the requested repeat count, then flag the eoe nodes - // to tell them they've got one more epoch to perform. When they reach the end - // of the last epoch, they quit rather than loop again. This happens in two cases: - // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or, - // 2- We are not repeated - if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) { - for (auto &eoe_op : eoe_ops_) { - eoe_op->set_control_flag(kDeOpLastRepeat); - } - } - if (repeat_count_ == max_repeats_) { - repeat_count_ = 0; - state_ = OpState::kDeOpIdle; - return Status::OK(); - } - - // base-class ResetSubtree - return (DatasetOp::ResetSubtree()); -} - -// Class functor operator () override. -// Most dataset ops operate by launching a thread (see ExecutionTree). -// However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the -// functor since this op runs inlined inside another operator. The function is overloaded to -// ensure that it is not called by mistake (it will generate an error). -Status RepeatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RepeatOp is an inlined operator."); } - -// Base-class override for handling cases when an eof is received. -Status RepeatOp::EofReceived(int32_t worker_id) { - MS_LOG(DEBUG) << "Repeat operator EOF received, do nothing now."; - return Status::OK(); -} - -int32_t RepeatOp::num_consumers() const { - if (parent_.empty()) { - MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's root and returning 1."; - return 1; - } else if (parent_[0] == nullptr) { - MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0."; - return 0; - } else { - return parent_[0]->num_consumers(); - } -} - -int32_t RepeatOp::num_producers() const { - if (child_.empty() || child_[0] == nullptr) { - MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; - return 0; - } else { - return child_[0]->num_producers(); - } -} - -// Visitor accept method for NodePass -Status RepeatOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h deleted file mode 100644 index bba85c3bb5..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ -#define DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ - -#include -#include -#include -#include "dataset/engine/datasetops/pipeline_op.h" - -namespace mindspore { -namespace dataset { -class RepeatOp : public PipelineOp { - public: - static constexpr int32_t kInfiniteRepeat = -1; - - // The nested builder class inside of the RepeatOp is used to help manage all of the arguments - // for constructing it. This repeat op is very simple though, so this builder is really just - // provided for a consistent look and feel for creators of Dataset operators overall. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @param count - The number of repeats to do - // @return This is a constructor. - explicit Builder(int32_t count); - - // Default destructor - ~Builder() = default; - - // The builder "build" method creates the final object. - // @return shared_ptr to the new RepeatOp object - Status Build(std::shared_ptr *); - - private: - int32_t build_max_repeats_; - - Status SanityCheck() const; - }; - - // Constructor of the RepeatOp. - // @note The builder class should be used to call it - // @param count - The number of repeats to do - explicit RepeatOp(int32_t count); - - // Destructor - ~RepeatOp(); - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param ro - reference to the RepeatOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const RepeatOp &ro) { - ro.Print(out, false); - return out; - } - - // Class functor operator () override. - // Most dataset ops operate by launching a thread (see ExecutionTree). - // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the - // functor since this op runs inlined inside another operator. The function is overloaded to - // ensure that it is not called by mistake (it will generate an error). - // @return Status - The error code return - Status operator()() override; - - // Base-class override for setting specific RepeatOp configurations. This code will be called - // during the execution tree prepare phase BEFORE traversing down to child operators. - uint32_t PrepareFlags() const override; - - // Base-class override for executing specific RepeatOp configurations. This code will be called - // during the execution tree post-prepare phase when it is visiting this operator. - Status PrepareNodePostAction() override; - - // This function returns the buffer that is at the top of our output connector. The caller is - // typically our parent node, when the parent is asking us to provide the next buffer of data. - // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get - // a buffer from our child. - // @note This function sets the `retryIfEoe` flag when popping from the child connector. This way, - // this function will retry to pop the connector again and will get the non-EOE buffer if any. - // @param p_buffer - output pointer to the buffer that it will fetch. - // @param worker_id - The worker id - // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. - // @return Status - The error code return - Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) override; - - // Base-class override for handling cases when an eoe is received. - // @param worker_id - The worker id - Status EoeReceived(int32_t worker_id) override; - - // Base-class override for handling cases when an eof is received. - // @param worker_id - The worker id - Status EofReceived(int32_t worker_id) override; - - // Base-class override. Return the number of workers in the first parent. - // @param workerId - The worker id - int32_t num_consumers() const override; - - // Base-class override. Return the number of producers in the first child. - // @param workerId - The worker id - int32_t num_producers() const override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "RepeatOp"; } - - private: - int32_t max_repeats_; // The number of repeats that the user requested - int32_t repeat_count_; // A counter for the current number of executed repeats - std::vector> eoe_ops_; // List of operators that can generate EOE underneath this repeat. -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc deleted file mode 100644 index f86fcc602b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc +++ /dev/null @@ -1,304 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#if defined(_WIN32) || defined(_WIN64) -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -constexpr int32_t ShuffleOp::kShuffleStateInit; -constexpr int32_t ShuffleOp::kShuffleStateActive; -constexpr int32_t ShuffleOp::kShuffleStateDrain; - -// Builder constructor. Creates the builder object. -ShuffleOp::Builder::Builder() : build_shuffle_size_(0), build_reshuffle_each_epoch_(true) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_op_connector_size_ = cfg->op_connector_size(); - build_rows_per_buffer_ = cfg->rows_per_buffer(); - build_shuffle_seed_ = GetSeed(); -} - -Status ShuffleOp::Builder::SanityCheck() const { - if (build_shuffle_size_ < 2) { - RETURN_STATUS_UNEXPECTED("Shuffle buffer size must be greater than 1."); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status ShuffleOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_shuffle_size_, build_shuffle_seed_, build_op_connector_size_, - build_reshuffle_each_epoch_, build_rows_per_buffer_); - return Status::OK(); -} - -// Constructor of the ShuffleOp -ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch, - int32_t rows_per_buffer) - : PipelineOp(op_connector_size), - shuffle_size_(shuffle_size), - shuffle_seed_(shuffle_seed), - reshuffle_each_epoch_(reset_every_epoch), - rng_(shuffle_seed), - buffer_counter_(0), - rows_per_buffer_(rows_per_buffer), - shuffle_buffer_(std::make_unique()), - shuffle_last_row_idx_(0), - shuffle_buffer_state_(kShuffleStateInit) {} - -// Private function to re-init the shuffle op for another epoch. Shuffle op calls this by -// itself rather than waiting for the reset driven from operators above it in the pipeline. -Status ShuffleOp::SelfReset() { - MS_LOG(DEBUG) << "Shuffle operator performing a self-reset."; - // If reshuffle_each_epoch is false, then we always use the same seed for every - // epoch. - // If reshuffle_each_epoch is true, then the first epoch uses the given seed, - // and all subsequent epochs will then keep on using the rng_ without resetting it - if (!reshuffle_each_epoch_) { - rng_ = std::mt19937_64(shuffle_seed_); - } - - shuffle_buffer_ = std::make_unique(); - buffer_counter_ = 0; - shuffle_last_row_idx_ = 0; - shuffle_buffer_state_ = kShuffleStateInit; - return Status::OK(); -} - -// A print method typically used for debugging -void ShuffleOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [shuffle size: " << shuffle_size_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nShuffle size: " << shuffle_size_ << "\nRows per buffer: " << rows_per_buffer_ - << "\nShuffle buffer state: " << shuffle_buffer_state_ << "\nShuffle seed: " << shuffle_seed_ << "\n\n"; - } -} - -// Private function to add a new row to the shuffle buffer. -Status ShuffleOp::AddRowToShuffleBuffer(TensorRow new_shuffle_row) { - // If the last slot of our shuffle buffer was not the full size of the shuffle buffer then we are - // filling it during the initial fill codepath and thus growing it's size. In that case, we push - // back the new row to grow our shuffle buffer size by 1. - // If we are already at the full size, then we overwrite the last slot with our row (and the last - // slot better be empty because it should already have been swapped out during the random row - // selection that was done previously!) - if (shuffle_last_row_idx_ < (shuffle_size_ - 1)) { - shuffle_buffer_->push_back(std::move(new_shuffle_row)); - shuffle_last_row_idx_ = (shuffle_buffer_->size()) - 1; - } else { - if (!(*shuffle_buffer_)[shuffle_last_row_idx_].empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Last row of shuffle buffer should not be occupied!"); - } - (*shuffle_buffer_)[shuffle_last_row_idx_] = std::move(new_shuffle_row); - } - return Status::OK(); -} - -// Class functor operator () override. -// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will -// provide the master loop that drives the logic for performing the work -Status ShuffleOp::operator()() { - std::unique_ptr new_buffer_table; // A tensor table to be used for output. - - // Synchronize with TaskManager once the thread is launched. - TaskManager::FindMe()->Post(); - - // Shuffle op does not have workers, and only consumes from child 0. - // Create the child iterator to fetch our data from. - int32_t worker_id = 0; - int32_t child_idx = 0; - child_iterator_ = std::make_unique(this, worker_id, child_idx); - - // Main operator loop - while (true) { - // Do an initial populate of the shuffle buffer - RETURN_IF_NOT_OK(InitShuffleBuffer()); - - // This is our main loop exit condition, when the iterator has no more data completely. - if (child_iterator_->eof_handled()) { - break; - } - - // Next, enter into the main execution loop of the shuffle op. - // When the tail index position of our shuffle buffer goes negative it means that we've - // fully drained the data from the shuffle buffer and we're done. - while (shuffle_last_row_idx_ >= 0) { - // Step 1) - // Create an output tensor table if one is not created yet. - if (!new_buffer_table) { - new_buffer_table = std::make_unique(); - } - - // Step 2) - // Randomly select a slot from our shuffle buffer and copy that row into the output - // tensor table. We remove the data from the shuffle buffer, leaving that slot - // in the table as an empty vector - int64_t random_slot = rng_() % (shuffle_last_row_idx_ + 1); - new_buffer_table->push_back(std::move((*shuffle_buffer_)[random_slot])); - - // Step 3) - // If the output tensor table is at the requested size, then create a buffer for it - // and send this buffer on it's way up the pipeline. Special case is if this is the - // last row then we also send it. - if (new_buffer_table->size() == rows_per_buffer_ || shuffle_last_row_idx_ == 0) { - auto new_buffer = std::make_unique(buffer_counter_, DataBuffer::kDeBFlagNone); - new_buffer->set_tensor_table(std::move(new_buffer_table)); - buffer_counter_++; - MS_LOG(DEBUG) << "Shuffle operator sending a buffer to output."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(new_buffer))); - } - - // Step 4) - // Take the last row from shuffle buffer, and swap it into the row position that was - // just vacated. This makes the shuffle buffer contiguous, with an empty slot at the - // tail of the shuffle buffer. - if (random_slot != shuffle_last_row_idx_) { - (*shuffle_buffer_)[random_slot] = std::move((*shuffle_buffer_)[shuffle_last_row_idx_]); - } - - // Step 5) - // Refill the last slot of the shuffle buffer with the next row from input if we are in the - // active state. - // If we are in the draining state, we do not need to fetch another row to replace the one we - // just drained. - if (shuffle_buffer_state_ == kShuffleStateActive) { - TensorRow new_row; - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - - if (!new_row.empty()) { - RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); - } else { - shuffle_buffer_state_ = kShuffleStateDrain; - } - } - - // If we are draining, reposition (decrement) our tail index in the shuffle buffer since we - // just drained a row from it. - if (shuffle_buffer_state_ == kShuffleStateDrain) { - shuffle_last_row_idx_--; - } - } - - // Since we overloaded eoeReceived function, we are responsible to flow the EOE up the - // pipepline manually now that we are done draining the shuffle buffer - MS_LOG(DEBUG) << "Shuffle operator sending EOE."; - auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - - // Do not wait for any reset to be flown down from operators above us. - // Instead, manually update ourselves and then go reloop to start fetching from child operator - // right away. Any Reset() from the parent will still perform common reset actions. - RETURN_IF_NOT_OK(this->SelfReset()); - } - return Status::OK(); -} - -// Private function populate the shuffle buffer initially by fetching from the child output -// connector until the shuffle buffer is full (or there is no more data coming). -Status ShuffleOp::InitShuffleBuffer() { - MS_LOG(DEBUG) << "Shuffle operator initializing the shuffle buffer."; - - // The first phase of this operator is to read incoming buffers and then drain those - // rows from the buffers, putting them into our own local table of tensors (the shuffle - // buffer). - // This shuffle buffer initialization phase stops when we've either filled up the - // shuffle buffer to it's max size, or the dataset below us is not providing any more - // rows. - if (shuffle_buffer_state_ != kShuffleStateInit) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Invalid shuffle buffer state (SHUFFLE_STATE_INIT expected)"); - } - - // Before we drop into the fetching loop, call the fetch once for the first time - // to fill the first row and grab the first buffer. - TensorRow new_row; - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - - if (child_iterator_->eof_handled()) { - MS_LOG(DEBUG) << "Shuffle operator init picked up EOF. No more epochs."; - return Status::OK(); - } - - if (new_row.empty()) { - RETURN_STATUS_UNEXPECTED("Unable to fetch a single row for shuffle buffer."); - } - - // Now fill the rest of the shuffle buffer until we are unable to get the next row or we reached - // the desired shuffle buffer size. - while (!new_row.empty() && shuffle_buffer_->size() < static_cast(shuffle_size_ - 1)) { - // Add the previously fetched row - RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); - - // Fetch the next row - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - } - - // If we quit the loop due to being at the shuffle size, still need to add the last row here. - if (!new_row.empty()) { - RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); - shuffle_buffer_state_ = kShuffleStateActive; // Transition to the active state - } else { - // If init phase doesn't have more rows, then skip the active state and jump straight to the - // shuffle buffer draining state - shuffle_buffer_state_ = kShuffleStateDrain; - } - - MS_LOG(DEBUG) << "Shuffle operator finished intializing the shuffle buffer."; - return Status::OK(); -} - -Status ShuffleOp::EoeReceived(int32_t worker_id) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -// Visitor accept method for NodePass -Status ShuffleOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h deleted file mode 100644 index 14b1e4511e..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h +++ /dev/null @@ -1,204 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Forward declare -class ExecutionTree; - -class DbConnector; - -class DataBuffer; - -class ShuffleOp : public PipelineOp { - // Shuffle buffer state flags - // - // Shuffle buffer is in a state of being initialized - static constexpr int32_t kShuffleStateInit = 0; - - // Shuffle buffer is in a state of being actively drained from, but refilling as well - static constexpr int32_t kShuffleStateActive = 1; - - // Shuffle buffer is in a state of being drained - static constexpr int32_t kShuffleStateDrain = 2; - - public: - // The nested builder class inside of the ShuffleOp is used to help manage all of the arguments - // for constructing it. The shuffle op is fairly simple though, but the builder provides a - // consistent look and feel for creators of Dataset operators overall. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetShuffleSize(int32_t shuffle_size) { - build_shuffle_size_ = shuffle_size; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetShuffleSeed(uint32_t shuffle_seed) { - build_shuffle_seed_ = shuffle_seed; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - build_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetReshuffleEachEpoch(bool reshuffle_each_epoch) { - build_reshuffle_each_epoch_ = reshuffle_each_epoch; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - build_op_connector_size_ = op_connector_size; - return *this; - } - - // The builder "build" method creates the final object. - // @return shared_ptr to the new ShuffleOp object - Status Build(std::shared_ptr *); - - private: - // The builder saves all ShuffleOp construction arguments internally. - // The following are the arguments. - int32_t build_shuffle_size_; - uint32_t build_shuffle_seed_; - int32_t build_rows_per_buffer_; - bool build_reshuffle_each_epoch_; - int32_t build_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor of the ShuffleOp - // @note The builder class should be used to call it - // @param shuffle_size - The size for the shuffle buffer - // @param shuffle_seed - The seed to use for random number generation - // @param op_connector_size - The output connector queue size - // @param rows_per_buffer - The requested number of rows per buffer - ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch, - int32_t rows_per_buffer); - - // Destructor - ~ShuffleOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param so - reference to the ShuffleOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const ShuffleOp &so) { - so.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for special eoe handler. - // ShuffleOp must override this because it shall not perform default handling of eoe. Instead - // the ShuffleOp needs to manage actions related to the end of the epoch itself. - // @return Status - The error code return - Status EoeReceived(int32_t worker_id) override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ShuffleOp"; } - - private: - // Private function to add a new row to the shuffle buffer. - // @return Status - The error code return - Status AddRowToShuffleBuffer(TensorRow new_shuffle_row); - - // Private function to populate the shuffle buffer initially by fetching from the child output - // connector until the shuffle buffer is full (or there is no more data coming). - // @return Status - The error code return - Status InitShuffleBuffer(); - - // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by - // itself rather than waiting for the reset driven from operators above it in the pipeline. - // @return Status - The error code return - Status SelfReset(); - - int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows) - uint32_t shuffle_seed_; - bool reshuffle_each_epoch_; - // rng_ is seeded initially with shuffle_seed_. mt19937 is used for its large period. - // specifically mt19937_64 is used to generate larger random numbers to reduce bias when - // modding to fit within our desired range. we dont use a distribution - // (ie uniform_int_distribution) because we will need to create up to |dataset| instances - // of the distribution object in the common case of a perfect shuffle - std::mt19937_64 rng_; - int32_t buffer_counter_; // For creating new buffer id's - int32_t rows_per_buffer_; // Number of rows to pack into output buffer - // A single (potentially large) buffer of tensor rows for performing shuffling. - std::unique_ptr shuffle_buffer_; - int32_t shuffle_last_row_idx_; // Internal tracking of the last slot of our shuffle buffer - int32_t shuffle_buffer_state_; // State tracking for the shuffle buffer phases of work - - std::unique_ptr child_iterator_; // An iterator for fetching. -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc deleted file mode 100644 index f6b0fe689c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/skip_op.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status SkipOp::Builder::SanityCheck() const { - if (build_max_skips_ < 0) { - std::string err_msg("Skip count must be positive integer or 0."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status SkipOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_skips_, builder_op_connector_size_); - return Status::OK(); -} - -// Constructor of the SkipOp. -SkipOp::SkipOp(int32_t count, int32_t op_connector_size) - : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {} - -// Destructor -SkipOp::~SkipOp() {} - -// A print method typically used for debugging -void SkipOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [skips: " << max_skips_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nSkip count: " << skip_count_ << "\nMax skips: " << max_skips_ << "\n\n"; - } -} - -// Base-class override for handling cases when an eoe is received. -Status SkipOp::EoeReceived(int32_t worker_id) { - skip_count_ = 0; - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -// main entry point for skip -Status SkipOp::operator()() { - TaskManager::FindMe()->Post(); - std::unique_ptr curr_buffer; - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - - while (curr_buffer->eof() == false) { - // Reset count - skip_count_ = 0; - while (curr_buffer->eoe() == false) { - // Drop first count rows - while (skip_count_ < max_skips_) { - if (curr_buffer->eoe() || curr_buffer->eof()) { - break; - } - // Consider the rows of buffer more than one - TensorRow drop_row; - int row_num = curr_buffer->NumRows(); - int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; - skip_count_ += drop_num; - for (int i = 0; i < drop_num; i++) { - RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row)); - } - if (curr_buffer->NumRows() == 0) { - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - } - } - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - } - // we got eoe, now try again until we got eof - MS_LOG(DEBUG) << "Skip operator EOE Received."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - } - - MS_LOG(DEBUG) << "Skip operator EOF Received."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - return Status::OK(); -} - -// Base-class override for handling cases when an eof is received. -Status SkipOp::EofReceived(int32_t worker_id) { - MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; - return Status::OK(); -} - -// Visitor accept method for NodePass -Status SkipOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h deleted file mode 100644 index 4cb658b2a7..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ - -#include -#include -#include -#include "dataset/engine/datasetops/pipeline_op.h" - -namespace mindspore { -namespace dataset { -class SkipOp : public PipelineOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @param count - The number of skip to do - // @return This is a constructor. - explicit Builder(int32_t count); - - // Default destructor - ~Builder() = default; - - // The builder "build" method creates the final object. - // @return shared_ptr to the new SkipOp object - Status Build(std::shared_ptr *); - - private: - int32_t build_max_skips_; - int32_t builder_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor of the SkipOp. - // @note The builder class should be used to call it - // @param count - The number of skips to do - explicit SkipOp(int32_t count, int32_t op_connector_size); - - // Destructor - ~SkipOp(); - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for handling cases when an eoe is received. - // @param worker_id - The worker id - Status EoeReceived(int32_t worker_id) override; - - // Base-class override for handling cases when an eof is received. - // @param worker_id - The worker id - Status EofReceived(int32_t worker_id) override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "SkipOp"; } - - private: - int32_t max_skips_; // The number of skips that the user requested - int32_t skip_count_; // A counter for the current number of executed skips -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt deleted file mode 100644 index b78ddcd87b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_subdirectory(sampler) -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(engine-datasetops-source OBJECT - generator_op.cc - io_block.cc - mindrecord_op.cc - tf_reader_op.cc - image_folder_op.cc - mnist_op.cc - voc_op.cc - coco_op.cc - manifest_op.cc - cifar_op.cc - random_data_op.cc - celeba_op.cc - text_file_op.cc - clue_op.cc - ) \ No newline at end of file diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc deleted file mode 100644 index c7a4269a39..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc +++ /dev/null @@ -1,423 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "dataset/engine/datasetops/source/celeba_op.h" - -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/util/path.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/kernels/image/image_utils.h" - -namespace mindspore { -namespace dataset { -CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status CelebAOp::Builder::Build(std::shared_ptr *op) { - MS_LOG(DEBUG) << "Celeba dataset directory is " << builder_dir_.c_str() << "."; - MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << "."; - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - // label is like this:0 1 0 0 1...... - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - *op = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, - builder_op_connector_size_, builder_decode_, builder_dataset_type_, - builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_)); - if (*op == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null"); - } - - return Status::OK(); -} - -Status CelebAOp::Builder::SanityCheck() { - Path dir(builder_dir_); - std::string err_msg; - err_msg += dir.IsDirectory() ? "" : "CelebA path is invalid or not set\n"; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is smaller than 1\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, - bool decode, const std::string &dataset_type, const std::set &exts, - std::unique_ptr schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size, std::move(sampler)), - rows_per_buffer_(rows_per_buffer), - folder_path_(dir), - decode_(decode), - extensions_(exts), - data_schema_(std::move(schema)), - num_rows_in_attr_file_(0), - dataset_type_(dataset_type) { - attr_info_queue_ = std::make_unique>>(queue_size); - io_block_queues_.Init(num_workers_, queue_size); -} - -Status CelebAOp::LaunchThreadsAndInitOp() { - if (tree_ == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "tree_ not set"); - } - - RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); - - RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this))); - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(ParseImageAttrInfo()); - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - - return Status::OK(); -} - -Status CelebAOp::ParseAttrFile() { - TaskManager::FindMe()->Post(); - Path folder_path(folder_path_); - std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString()); - if (!attr_file.is_open()) { - return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "Celeba attr file does not exist"); - } - - const auto PushBackToQueue = [this](std::vector &vec, std::ifstream &attr_file, - std::ifstream &partition_file) { - Status s = attr_info_queue_->EmplaceBack(vec); - if (s.IsError()) { - CLOSE_FILE(attr_file, partition_file); - return s; - } - return Status::OK(); - }; - - std::string rows_num; - std::string attr_name; - (void)getline(attr_file, rows_num); - try { - num_rows_in_attr_file_ = static_cast(std::stoul(rows_num)); // First line is rows number in attr file - } catch (std::invalid_argument &e) { - RETURN_STATUS_UNEXPECTED("Conversion to ulong failed, invalid argument."); - } catch (std::out_of_range &e) { - RETURN_STATUS_UNEXPECTED("Conversion to ulong failed, out of range."); - } - - (void)getline(attr_file, attr_name); // Second line is attribute name,ignore it - std::string image_info; - std::vector image_infos; - image_infos.reserve(oc_queue_size_); - while (getline(attr_file, image_info)) { - if ((image_info.empty()) || (dataset_type_ != "all" && !CheckDatasetTypeValid())) { - continue; - } - image_infos.push_back(image_info); - if (image_info.size() % oc_queue_size_ == 0) { - RETURN_IF_NOT_OK(PushBackToQueue(image_infos, attr_file, partition_file_)); - image_infos.clear(); - } - } - if (!image_infos.empty()) { - RETURN_IF_NOT_OK(PushBackToQueue(image_infos, attr_file, partition_file_)); - } - std::vector end_indicator = std::vector(0); - RETURN_IF_NOT_OK(PushBackToQueue(end_indicator, attr_file, partition_file_)); // end indicator - CLOSE_FILE(attr_file, partition_file_); - return Status::OK(); -} - -bool CelebAOp::CheckDatasetTypeValid() { - if (!partition_file_.is_open()) { - Path folder_path(folder_path_); - partition_file_.open((folder_path / "list_eval_partition.txt").toString()); - if (!partition_file_.is_open()) { - MS_LOG(ERROR) << "Celeba partition file does not exist!"; - return false; - } - } - std::string line; - (void)getline(partition_file_, line); - std::vector vec = Split(line); - if (vec.size() != 2) { - return false; - } - int32_t type; - try { - type = std::stoi(vec[1]); - } catch (std::invalid_argument &e) { - MS_LOG(WARNING) << "Conversion to unsigned long failed, invalid argument, " << vec[0] << "."; - return false; - } catch (std::out_of_range &e) { - MS_LOG(WARNING) << "Conversion to unsigned long failed, out of range, " << vec[0] << "."; - return false; - } - // train:0, valid=1, test=2 - if (dataset_type_ == "train" && (type == 0)) { - return true; - } else if (dataset_type_ == "valid" && (type == 1)) { - return true; - } else if (dataset_type_ == "test" && (type == 2)) { - return true; - } - - return false; -} - -Status CelebAOp::ParseImageAttrInfo() { - std::vector image_infos; - bool needMoreData = true; - RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos)); - while (!image_infos.empty() && needMoreData) { - for (uint32_t index = 0; index < image_infos.size(); index++) { - std::string image_info = image_infos[index]; - std::vector split = Split(image_info); - std::pair> image_labels; - - Path path(folder_path_); - Path file_path = path / split[0]; - if (!extensions_.empty() && extensions_.find(file_path.Extension()) == extensions_.end()) { - MS_LOG(WARNING) << "Unsupported file found at " << file_path.toString().c_str() << ", its extension is " - << file_path.Extension().c_str() << "."; - continue; - } - image_labels.first = split[0]; - for (uint32_t label_index = 1; label_index < split.size(); label_index++) { - int32_t value; - try { - value = std::stoi(split[label_index]); - } catch (std::invalid_argument &e) { - RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument."); - } catch (std::out_of_range &e) { - RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); - } - image_labels.second.push_back(value); - } - - image_labels_vec_.push_back(image_labels); - } - - RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos)); - } - - num_rows_ = image_labels_vec_.size(); - if (num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API " - "validation first."); - } - MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << "."; - return Status::OK(); -} - -std::vector CelebAOp::Split(const std::string &line) { - std::string str = line; - std::string::size_type pos; - std::vector split; - str += " "; - int size = str.size(); - for (uint32_t index = 0; index < size;) { - pos = str.find(" ", index); - if (pos != index) { // skip space - std::string s = str.substr(index, pos - index); - split.push_back(s); - } - index = pos + 1; - } - - return split; -} - -// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work -Status CelebAOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr data_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer)); - RETURN_IF_NOT_OK(AddIOBlock(&data_buffer)); - return Status::OK(); -} - -Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { - int64_t buff_count = 0; - while (true) { - std::vector keys; - keys.reserve(rows_per_buffer_); - int64_t row_count = 0; - while (!(*data_buffer)->eoe()) { - TensorRow sample_row; - RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row)); - std::shared_ptr sample_ids = sample_row[0]; - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) >= num_rows_) { - MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << "."; - continue; - } - keys.push_back(*itr); - row_count++; - if (row_count % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buff_count++ % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - keys.clear(); - } - } - RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); - } - - if (!keys.empty()) { - RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - RETURN_IF_NOT_OK( - io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK( - io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset - RETURN_IF_NOT_OK( - io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); - } - } -} - -Status CelebAOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty()) { - return Status::OK(); // empty key is a quit signal for workers - } - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unexpected nullptr received in worker"); -} - -Status CelebAOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - for (const auto &key : keys) { - TensorRow row; - RETURN_IF_NOT_OK(LoadTensorRow(key, image_labels_vec_[key], &row)); - deq->push_back(std::move(row)); - } - - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -Status CelebAOp::LoadTensorRow(row_id_type row_id, const std::pair> &image_label, - TensorRow *row) { - std::shared_ptr image; - std::shared_ptr label; - - Path path(folder_path_); - Path image_path = path / image_label.first; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, image_path.toString())); - if (decode_ == true) { - Status rc = Decode(image, &image); - if (rc.IsError()) { - image = nullptr; - std::string err_msg = "Fail to decode image: " + image_path.toString(); - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); - } - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), - TensorShape({1, (uint32_t)image_label.second.size()}), - data_schema_->column(1).type())); - RETURN_IF_NOT_OK(label->Zero()); - for (uint32_t index = 0; index < image_label.second.size(); index++) { - if (image_label.second[index] == 1) { - label->SetItemAt({0, static_cast(index)}, 1); - } else { - label->SetItemAt({0, static_cast(index)}, 0); - } - } - label->Squeeze(); - - (*row) = TensorRow(row_id, {std::move(image), std::move(label)}); - return Status::OK(); -} - -void CelebAOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n"; - } -} - -// Reset Sampler and wakeup Master thread (functor) -Status CelebAOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - wp_.Set(); // wake up master thread after reset is done - return Status::OK(); -} - -Status CelebAOp::ComputeColMap() { - // Set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t index = 0; index < data_schema_->NumColumns(); index++) { - column_name_id_map_[data_schema_->column(index).name()] = index; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h deleted file mode 100644 index a6fa495a14..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h +++ /dev/null @@ -1,234 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ - -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H -#define DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H - -#include -#include -#include -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/util/queue.h" -#include "dataset/engine/datasetops/source/io_block.h" - -#define CLOSE_FILE(attr_file, pairition_file) \ - do { \ - attr_file.close(); \ - if (pairition_file.is_open()) { \ - pairition_file.close(); \ - } \ - } while (false) - -namespace mindspore { -namespace dataset { -class CelebAOp : public ParallelOp, RandomAccessOp { - public: - class Builder { - public: - // Constructor for Builder class of CelebAOp - // @return Builder setter method returns reference to the builder. - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method - // @param int32_t size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t size) { - builder_op_connector_size_ = size; - return *this; - } - - // Setter method - // @param std::set & exts, file extensions to be read - // @return Builder setter method returns reference to the builder. - Builder &SetExtensions(const std::set &exts) { - builder_extensions_ = exts; - return *this; - } - - // Setter method - // @param bool decode - // @return Builder setter method returns reference to the builder. - Builder &SetDecode(bool decode) { - builder_decode_ = decode; - return *this; - } - - // Setter method - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method - // @param const std::string &dir - // @return Builder setter method returns reference to the builder. - Builder &SetCelebADir(const std::string &dir) { - builder_dir_ = dir; - return *this; - } - - // Setter method - // @param const std::string dataset_type: type to be read - // @return Builder setter method returns reference to the builder. - Builder &SetDatasetType(const std::string &dataset_type) { - builder_dataset_type_ = dataset_type; - return *this; - } - // Check validity of input args - // @return - The error code return - Status SanityCheck(); - - // The builder "build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - bool builder_decode_; - std::string builder_dir_; - int32_t builder_num_workers_; - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::set builder_extensions_; - std::shared_ptr builder_sampler_; - std::unique_ptr builder_schema_; - std::string builder_dataset_type_; - }; - - // Constructor - // @param int32_t - num_workers - Num of workers reading images in parallel - // @param int32_t - rows_per_buffer Number of images (rows) in each buffer - // @param std::string - dir directory of celeba dataset - // @param int32_t queueSize - connector queue size - // @param std::unique_ptr sampler - sampler tells CelebAOp what to read - CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, - const std::string &dataset_type, const std::set &exts, std::unique_ptr schema, - std::shared_ptr sampler); - - ~CelebAOp() override = default; - - // Main Loop of CelebaOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t worker_id - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // Method in operator(), to fill IOBlockQueue - // @param std::unique_ptr sampler_buffer - to fill IOBlockQueue - // @return Status - The error code return - Status AddIOBlock(std::unique_ptr *data_buffer); - - // Op name getter - // @return Name of the current Op - std::string Name() const { return "CelebAOp"; } - - private: - // Called first when function is called - // @return - Status LaunchThreadsAndInitOp(); - - // Parse attribute file - // @return - Status ParseAttrFile(); - - // Parse each image line in attribute file - // @return - Status ParseImageAttrInfo(); - - // Split attribute info with space - // @param std::string - line - Line from att or partition file - // @return std::vector - string after split - std::vector Split(const std::string &line); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Load a tensor row according to a pair - // @param row_id_type row_id - id for this tensor row - // @param std::pair - > - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const std::pair> &image_label, - TensorRow *row); - - // Check if need read according to dataset type - // @return bool - if need read - bool CheckDatasetTypeValid(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t rows_per_buffer_; - std::string folder_path_; // directory of celeba folder - bool decode_; - std::set extensions_; // extensions allowed - std::unique_ptr data_schema_; - std::unique_ptr>> attr_info_queue_; - int64_t num_rows_in_attr_file_; // rows number specified in attr file - QueueList> io_block_queues_; - WaitPost wp_; - std::vector>> image_labels_vec_; - std::string dataset_type_; - std::ifstream partition_file_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc deleted file mode 100644 index 8dd615a8c1..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc +++ /dev/null @@ -1,465 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/cifar_op.h" - -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -constexpr uint32_t kCifarImageHeight = 32; -constexpr uint32_t kCifarImageWidth = 32; -constexpr uint32_t kCifarImageChannel = 3; -constexpr uint32_t kCifarBlockImageNum = 5; -constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel; - -CifarOp::Builder::Builder() : sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - num_workers_ = cfg->num_parallel_workers(); - rows_per_buffer_ = cfg->rows_per_buffer(); - op_connect_size_ = cfg->op_connector_size(); - cifar_type_ = kCifar10; -} - -Status CifarOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - sampler_ = std::make_shared(start_index, num_samples); - } - schema_ = std::make_unique(); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - if (cifar_type_ == kCifar10) { - RETURN_IF_NOT_OK( - schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - } else { - RETURN_IF_NOT_OK(schema_->AddColumn( - ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - TensorShape another_scalar = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(schema_->AddColumn( - ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar))); - } - - *ptr = std::make_shared(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, - std::move(schema_), std::move(sampler_)); - return Status::OK(); -} - -Status CifarOp::Builder::SanityCheck() { - Path dir(dir_); - std::string err_msg; - err_msg += dir.IsDirectory() == false ? "Cifar path is invalid or not set\n" : ""; - err_msg += num_workers_ <= 0 ? "Num of parallel workers is negative or 0\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, - int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_works, queue_size, std::move(sampler)), - cifar_type_(type), - rows_per_buffer_(rows_per_buf), - folder_path_(file_dir), - data_schema_(std::move(data_schema)), - row_cnt_(0), - buf_cnt_(0) { - constexpr uint64_t kUtilQueueSize = 512; - cifar_raw_data_block_ = std::make_unique>>(kUtilQueueSize); - io_block_queues_.Init(num_workers_, queue_size); -} - -// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work -Status CifarOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - while (true) { // each iterator is 1 epoch - std::vector keys; - keys.reserve(rows_per_buffer_); - while (sampler_buffer->eoe() == false) { - TensorRow sample_row; - RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); - std::shared_ptr sample_ids = sample_row[0]; - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { - keys.push_back(*itr); - row_cnt_++; - if ((*itr) >= num_rows_) continue; // index out of bound, skipping - if (row_cnt_ % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - keys.clear(); - } - } - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - if (keys.empty() == false) { - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - } -} - -Status CifarOp::LaunchThreadsAndInitOp() { - if (tree_ == nullptr) { - RETURN_STATUS_UNEXPECTED("tree_ not set"); - } - RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK( - tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this))); - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - // The order of the following 2 functions must not be changed! - RETURN_IF_NOT_OK(ParseCifarData()); // Parse cifar data and get num rows, blocking - RETURN_IF_NOT_OK(InitSampler()); // Pass numRows to Sampler - return Status::OK(); -} - -// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ -// IMPORTANT: 1 IOBlock produces 1 DataBuffer -Status CifarOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty() == true) { - return Status::OK(); // empty key is a quit signal for workers - } - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -// Load 1 TensorRow (image,label). 1 function call produces 1 TensorTow in a DataBuffer -Status CifarOp::LoadTensorRow(uint64_t index, TensorRow *trow) { - std::shared_ptr label; - std::shared_ptr fine_label; - std::shared_ptr ori_image = cifar_image_label_pairs_[index].first; - std::shared_ptr copy_image = - std::make_shared(ori_image->shape(), ori_image->type(), ori_image->GetBuffer()); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), - reinterpret_cast(&cifar_image_label_pairs_[index].second[0]))); - if (cifar_image_label_pairs_[index].second.size() > 1) { - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &fine_label, data_schema_->column(2).tensorImpl(), data_schema_->column(2).shape(), - data_schema_->column(2).type(), reinterpret_cast(&cifar_image_label_pairs_[index].second[1]))); - (*trow) = TensorRow(index, {copy_image, std::move(label), std::move(fine_label)}); - } else { - (*trow) = TensorRow(index, {copy_image, std::move(label)}); - } - - return Status::OK(); -} - -// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer -Status CifarOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - for (const int64_t &key : keys) { - TensorRow trow; - RETURN_IF_NOT_OK(LoadTensorRow(key, &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -void CifarOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows:" << num_rows_ << "\nCifar directory: " << folder_path_ << "\n\n"; - } -} - -// Reset Sampler and wakeup Master thread (functor) -Status CifarOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done - return Status::OK(); -} - -// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows -Status CifarOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); -} - -Status CifarOp::ReadCifarBlockDataAsync() { - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(GetCifarFiles()); - if (cifar_type_ == kCifar10) { - RETURN_IF_NOT_OK(ReadCifar10BlockData()); - } else { - RETURN_IF_NOT_OK(ReadCifar100BlockData()); - } - - return Status::OK(); -} - -Status CifarOp::ReadCifar10BlockData() { - constexpr uint32_t num_cifar10_records = 10000; - uint32_t block_size = (kCifarImageSize + 1) * kCifarBlockImageNum; // about 2M - std::vector image_data(block_size * sizeof(unsigned char), 0); - for (auto &file : cifar_files_) { - std::ifstream in(file, std::ios::binary); - if (!in.is_open()) { - std::string err_msg = file + " can not be opened."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - for (uint32_t index = 0; index < num_cifar10_records / kCifarBlockImageNum; ++index) { - (void)in.read(reinterpret_cast(&(image_data[0])), block_size * sizeof(unsigned char)); - if (in.fail()) { - RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file); - } - (void)cifar_raw_data_block_->EmplaceBack(image_data); - } - in.close(); - } - (void)cifar_raw_data_block_->EmplaceBack(std::vector()); // end block - - return Status::OK(); -} - -Status CifarOp::ReadCifar100BlockData() { - uint32_t num_cifar100_records = 0; // test:10000, train:50000 - uint32_t block_size = (kCifarImageSize + 2) * kCifarBlockImageNum; // about 2M - std::vector image_data(block_size * sizeof(unsigned char), 0); - for (auto &file : cifar_files_) { - int pos = file.find_last_of('/'); - if (pos == std::string::npos) { - RETURN_STATUS_UNEXPECTED("Invalid cifar100 file path"); - } - std::string file_name(file.substr(pos + 1)); - if (file_name.find("test") != std::string::npos) { - num_cifar100_records = 10000; - } else if (file_name.find("train") != std::string::npos) { - num_cifar100_records = 50000; - } else { - RETURN_STATUS_UNEXPECTED("Cifar 100 file not found!"); - } - - std::ifstream in(file, std::ios::binary); - if (!in.is_open()) { - RETURN_STATUS_UNEXPECTED(file + " can not be opened."); - } - - for (uint32_t index = 0; index < num_cifar100_records / kCifarBlockImageNum; index++) { - (void)in.read(reinterpret_cast(&(image_data[0])), block_size * sizeof(unsigned char)); - if (in.fail()) { - RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file); - } - (void)cifar_raw_data_block_->EmplaceBack(image_data); - } - in.close(); - } - (void)cifar_raw_data_block_->EmplaceBack(std::vector()); // block end - return Status::OK(); -} - -Status CifarOp::GetCifarFiles() { - // Initialize queue to hold the file names - const std::string kExtension = ".bin"; - Path dataset_directory(folder_path_); - auto dirIt = Path::DirIterator::OpenDirectory(&dataset_directory); - if (dirIt) { - while (dirIt->hasNext()) { - Path file = dirIt->next(); - std::string filename = file.toString(); - if (filename.find(kExtension) != std::string::npos) { - cifar_files_.push_back(filename); - MS_LOG(INFO) << "Cifar operator found file at " << filename << "."; - } - } - } else { - std::string err_msg = "Unable to open directory " + dataset_directory.toString(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::sort(cifar_files_.begin(), cifar_files_.end()); - return Status::OK(); -} - -Status CifarOp::ParseCifarData() { - std::vector block; - RETURN_IF_NOT_OK(cifar_raw_data_block_->PopFront(&block)); - uint32_t cur_block_index = 0; - while (!block.empty()) { - for (uint32_t index = 0; index < kCifarBlockImageNum; ++index) { - std::vector labels; - uint32_t label = block[cur_block_index++]; - labels.push_back(label); - if (cifar_type_ == kCifar100) { - uint32_t fine_label = block[cur_block_index++]; - labels.push_back(fine_label); - } - - std::shared_ptr image_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image_tensor, data_schema_->column(0).tensorImpl(), - TensorShape({kCifarImageHeight, kCifarImageWidth, kCifarImageChannel}), - data_schema_->column(0).type())); - auto itr = image_tensor->begin(); - uint32_t total_pix = kCifarImageHeight * kCifarImageWidth; - for (int pix = 0; pix < total_pix; ++pix) { - for (int ch = 0; ch < kCifarImageChannel; ++ch) { - *itr = block[cur_block_index + ch * total_pix + pix]; - itr++; - } - } - cur_block_index += total_pix * kCifarImageChannel; - cifar_image_label_pairs_.emplace_back(std::make_pair(image_tensor, labels)); - } - RETURN_IF_NOT_OK(cifar_raw_data_block_->PopFront(&block)); - cur_block_index = 0; - } - cifar_image_label_pairs_.shrink_to_fit(); - num_rows_ = cifar_image_label_pairs_.size(); - if (num_rows_ == 0) { - std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset"; - std::string err_msg = "There is no valid data matching the dataset API " + api + - ".Please check file path or dataset API validation first."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - cifar_raw_data_block_->Reset(); - return Status::OK(); -} - -// Derived from RandomAccessOp -Status CifarOp::GetClassIds(std::map> *cls_ids) const { - if (cls_ids == nullptr || !cls_ids->empty()) { - RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); - } - - for (uint64_t index = 0; index < cifar_image_label_pairs_.size(); ++index) { - uint32_t label = (cifar_image_label_pairs_[index].second)[0]; - (*cls_ids)[label].push_back(index); - } - - for (auto &pair : (*cls_ids)) { - pair.second.shrink_to_fit(); - } - return Status::OK(); -} - -Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) { - // the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block() - std::shared_ptr op; - *count = 0; - RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op)); - RETURN_IF_NOT_OK(op->GetCifarFiles()); - if (op->cifar_type_ == kCifar10) { - constexpr int64_t num_cifar10_records = 10000; - for (auto &file : op->cifar_files_) { - std::ifstream in(file, std::ios::binary); - if (!in.is_open()) { - std::string err_msg = file + " can not be opened."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - *count = *count + num_cifar10_records; - } - return Status::OK(); - } else { - int64_t num_cifar100_records = 0; - for (auto &file : op->cifar_files_) { - size_t pos = file.find_last_of('/'); - if (pos == std::string::npos) { - std::string err_msg = "Invalid cifar100 file path"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::string file_name; - if (file.size() > 0) - file_name = file.substr(pos + 1); - else - RETURN_STATUS_UNEXPECTED("Invalid string length!"); - if (file_name.find("test") != std::string::npos) { - num_cifar100_records = 10000; - } else if (file_name.find("train") != std::string::npos) { - num_cifar100_records = 50000; - } - std::ifstream in(file, std::ios::binary); - if (!in.is_open()) { - std::string err_msg = file + " can not be opened."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - *count = num_cifar100_records; - return Status::OK(); - } -} - -Status CifarOp::ComputeColMap() { - // set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (uint32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h deleted file mode 100644 index 917b23db94..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h +++ /dev/null @@ -1,230 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -class CifarOp : public ParallelOp, public RandomAccessOp { - public: - enum CifarType { kCifar10, kCifar100 }; - - class Builder { - public: - // Constructor for Builder class of CifarOp - // @return Builder setter method returns reference to the builder. - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method - // @param uint32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method - // @param uint32_t size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t size) { - op_connect_size_ = size; - return *this; - } - - // Setter method - // @param uint32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - sampler_ = std::move(sampler); - return *this; - } - - // Setter method - // @param const std::string & dir - // @return - Builder &SetCifarDir(const std::string &dir) { - dir_ = dir; - return *this; - } - - // Setter method - // @param const std::string & dir - // @return - Builder &SetCifarType(const bool cifar10) { - if (cifar10) { - cifar_type_ = kCifar10; - } else { - cifar_type_ = kCifar100; - } - return *this; - } - - // Check validity of input args - // @return - The error code return - Status SanityCheck(); - - // The builder "build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - std::string dir_; - int32_t num_workers_; - int32_t rows_per_buffer_; - int32_t op_connect_size_; - std::shared_ptr sampler_; - std::unique_ptr schema_; - CifarType cifar_type_; - }; - - // Constructor - // @param CifarType type - Cifar10 or Cifar100 - // @param uint32_t numWorks - Num of workers reading images in parallel - // @param uint32_t - rowsPerBuffer Number of images (rows) in each buffer - // @param std::string - dir directory of cifar dataset - // @param uint32_t - queueSize - connector queue size - // @param std::unique_ptr sampler - sampler tells ImageFolderOp what to read - CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler); - // Destructor. - ~CifarOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param uint32_t workerId - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of CifarOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // Function to count the number of samples in the CIFAR dataset - // @param dir path to the CIFAR directory - // @param isCIFAR10 true if CIFAR10 and false if CIFAR100 - // @param count output arg that will hold the actual dataset size - // @return - static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count); - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "CifarOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to a pair - // @param uint64_t index - index need to load - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(uint64_t index, TensorRow *row); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Read block data from cifar file - // @return - Status ReadCifarBlockDataAsync(); - - // Called first when function is called - // @return - Status LaunchThreadsAndInitOp(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Get cifar files in dir - // @return - Status GetCifarFiles(); - - // Read cifar10 data as block - // @return - Status ReadCifar10BlockData(); - - // Read cifar100 data as block - // @return - Status ReadCifar100BlockData(); - - // Parse cifar data - // @return - Status ParseCifarData(); - - // Method derived from RandomAccess Op, enable Sampler to get all ids for each calss - // @param (std::map> * map - key label, val all ids for this class - // @return Status - The error code return - Status GetClassIds(std::map> *cls_ids) const override; - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - CifarType cifar_type_; - int32_t rows_per_buffer_; - std::string folder_path_; - std::unique_ptr data_schema_; - int64_t row_cnt_; - int64_t buf_cnt_; - - WaitPost wp_; - QueueList> io_block_queues_; - std::unique_ptr>> cifar_raw_data_block_; - std::vector cifar_files_; - std::vector, std::vector>> cifar_image_label_pairs_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc deleted file mode 100644 index 9fceb6f333..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc +++ /dev/null @@ -1,555 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/clue_op.h" - -#include -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/util/task_manager.h" -#include "dataset/engine/jagged_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -ClueOp::Builder::Builder() - : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { - std::shared_ptr config_manager = GlobalContext::config_manager(); - builder_num_workers_ = config_manager->num_parallel_workers(); - builder_op_connector_size_ = config_manager->op_connector_size(); - builder_rows_per_buffer_ = config_manager->rows_per_buffer(); - builder_worker_connector_size_ = config_manager->worker_connector_size(); -} - -Status ClueOp::Builder::ValidateInputs() const { - std::string err; - err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; - err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : ""; - return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err); -} - -Status ClueOp::Builder::Build(std::shared_ptr *op) { - RETURN_IF_NOT_OK(ValidateInputs()); - - // Throttle the number of workers if we have more workers than files! - if (static_cast(builder_num_workers_) > builder_clue_files_list_.size()) { - builder_num_workers_ = builder_clue_files_list_.size(); - MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers."; - } - - ColKeyMap ck_map; - for (auto &p : builder_cols_to_keyword_) { - ck_map.insert({p.first, split(p.second, '/')}); - } - - std::shared_ptr clue_op = std::make_shared( - builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, - builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, - builder_device_id_); - RETURN_IF_NOT_OK(clue_op->Init()); - *op = std::move(clue_op); - - return Status::OK(); -} - -std::vector ClueOp::Builder::split(const std::string &s, char delim) { - std::vector res; - std::stringstream ss(s); - std::string item; - - while (getline(ss, item, delim)) { - res.push_back(item); - } - return res; -} - -ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, - ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_device, int32_t device_id) - : ParallelOp(num_workers, op_connector_size), - rows_per_buffer_(rows_per_buffer), - num_rows_per_shard_(0), - all_num_rows_(0), - num_samples_(num_samples), - filename_index_(std::make_unique()), - clue_files_list_(std::move(clue_files_list)), - load_jagged_connector_(true), - cols_to_keyword_(cols_to_keyword), - shuffle_files_(shuffle_files), - finished_reading_dataset_(false), - num_devices_(num_device), - device_id_(device_id), - load_io_block_queue_(true) { - worker_connector_size_ = worker_connector_size; -} - -Status ClueOp::Init() { - RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_)); - - int32_t safe_queue_size = static_cast(std::ceil(clue_files_list_.size() / num_workers_) + 1); - io_block_queues_.Init(num_workers_, safe_queue_size); - - RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); - jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); - - return Status::OK(); -} - -Status ClueOp::Reset() { - load_jagged_connector_ = true; - load_io_block_queue_ = true; - - RETURN_IF_NOT_OK(ParallelOp::Reset()); - NotifyToFillIOBlockQueue(); - return Status::OK(); -} - -Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { - TensorRow tRow(1, nullptr); - (*tensor_table)->push_back(std::move(tRow)); - - std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); - (**tensor_table)[row][0] = std::move(tensor); - return Status::OK(); -} - -Status ClueOp::GetValue(const nlohmann::json &js, std::vector key_chain, std::shared_ptr *t) { - nlohmann::json cursor = js; - for (int i = 0; i < key_chain.size(); i++) { - if (cursor.find(key_chain[i]) != cursor.end()) { - cursor = cursor[key_chain[i]]; - } else { - RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]); - } - } - std::string final_str = key_chain.back(); - switch (cursor.type()) { - case nlohmann::detail::value_t::string: - RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get()}, TensorShape::CreateScalar())); - break; - - case nlohmann::detail::value_t::number_integer: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); - (*t)->SetItemAt({0}, cursor.get()); - break; - case nlohmann::detail::value_t::number_unsigned: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); - (*t)->SetItemAt({0}, cursor.get()); - break; - case nlohmann::detail::value_t::number_float: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32))); - (*t)->SetItemAt({0}, cursor.get()); - break; - case nlohmann::detail::value_t::array: - RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get>()}, TensorShape::CreateScalar())); - break; - default: - break; - } - return Status::OK(); -} - -Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, - const int32_t worker_id) { - std::ifstream handle(file); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Failed to open file " + file); - } - - int64_t rows_each_buffer = 0; - int64_t rows_total = 0; - std::string line; - std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - std::unique_ptr tensor_table = std::make_unique(); - - while (getline(handle, line)) { - if (line.empty()) { - continue; - } - // If read to the end offset of this file, break. - if (rows_total >= end_offset) { - break; - } - // Skip line before start offset. - if (rows_total < start_offset) { - rows_total++; - continue; - } - - try { - nlohmann::json js = nlohmann::json::parse(line); - int cols_count = cols_to_keyword_.size(); - TensorRow tRow(cols_count, nullptr); - tensor_table->push_back(std::move(tRow)); - - int cout = 0; - for (auto &p : cols_to_keyword_) { - std::shared_ptr tensor; - RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor)); - (*tensor_table)[rows_each_buffer][cout] = std::move(tensor); - cout++; - } - } catch (const std::exception &err) { - // Catch any exception and convert to Status return code - RETURN_STATUS_UNEXPECTED("Failed to load json file"); - } - - // RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); - rows_each_buffer++; - rows_total++; - if (rows_each_buffer == rows_per_buffer_) { - cur_buffer->set_tensor_table(std::move(tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); - - cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - tensor_table = std::make_unique(); - rows_each_buffer = 0; - } - } - - if (rows_each_buffer > 0) { - cur_buffer->set_tensor_table(std::move(tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); - } - return Status::OK(); -} - -Status ClueOp::operator()() { - RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - - // launch one thread, responsible for filling IoBlockQueue - RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this))); - - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1))); - - // must be called after launching workers. - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); - NotifyToFillIOBlockQueue(); - - while (!finished_reading_dataset_) { - int64_t buffer_id = 0; - int32_t workers_done = 0; - int64_t rows_read = 0; - load_io_block_queue_ = true; - - while (workers_done < num_workers_) { - std::unique_ptr buffer; - RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); - if (buffer->eoe()) { - workers_done++; - } else if (num_samples_ == 0 || rows_read < num_samples_) { - if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { - int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); - RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); - } - rows_read += buffer->NumRows(); - buffer->set_id(buffer_id++); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); - } else { - // end of epoch - load_jagged_connector_ = false; - load_io_block_queue_ = false; - } - } - - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - finished_reading_dataset_ = true; - NotifyToFillIOBlockQueue(); - } else { - jagged_buffer_connector_->DoReset(); - buffer_id = 0; - } - } - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - - RETURN_IF_NOT_OK(PostEndOfData()); - return Status::OK(); -} - -Status ClueOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - std::unique_ptr io_block; - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - while (!io_block->eof()) { - if (!io_block->eoe()) { - if (load_jagged_connector_) { - std::string filename; - RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); - int64_t start_offset = io_block->GetStartOffset(); - int64_t end_offset = io_block->GetEndOffset(); - RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); - } - } else { - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); - } - - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - } - return Status::OK(); -} - -// A print method typically used for debugging -void ClueOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ - << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ - << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n"; - for (int i = 0; i < clue_files_list_.size(); ++i) { - out << " " << clue_files_list_[i]; - } - out << "\n\n"; - } -} - -// Pops an element from a queue in io_block_queues -Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); - - return Status::OK(); -} - -// Pushes an element to a queue in io_block_queues -Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); - - return Status::OK(); -} - -static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { - std::mt19937 rng(seed); - std::shuffle(i_keys->begin(), i_keys->end(), rng); -} - -Status ClueOp::WaitToFillIOBlockQueue() { - // must be called first if called by worker spanwed by taskgroup - TaskManager::FindMe()->Post(); - - std::vector i_keys; - if (shuffle_files_) { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - i_keys.push_back(it.key()); - } - } - uint32_t seed = 0; - while (true) { - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); - io_block_queue_wait_post_.Clear(); - - if (finished_reading_dataset_) { - break; - } - - if (shuffle_files_) { - ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); - } - RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); - } - return Status::OK(); -} - -Status ClueOp::FillIOBlockQueue(const std::vector &i_keys) { - int32_t queue_index = 0; - int64_t pre_count = 0; - int64_t start_offset = 0; - int64_t end_offset = 0; - bool finish = false; - while (!finish) { - std::vector> file_index; - if (!i_keys.empty()) { - for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { - { - if (!load_io_block_queue_) { - break; - } - } - file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); - } - } else { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - { - if (!load_io_block_queue_) { - break; - } - } - file_index.emplace_back(std::pair(it.value(), it.key())); - } - } - for (auto file_info : file_index) { - if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { - auto ioBlock = - std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); - RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); - queue_index = (queue_index + 1) % num_workers_; - } - - pre_count += filename_numrows_[file_info.first]; - } - - if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { - finish = false; - } else { - finish = true; - } - } - - RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); - return Status::OK(); -} - -void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } - -bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count) { - *start_offset = 0; - *end_offset = 0; - bool push = false; - int64_t start_index = device_id_ * num_rows_per_shard_; - if (device_id_ + 1 < 0) { - MS_LOG(ERROR) << "Device id is invalid"; - return false; - } - - int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; - if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { - *start_offset = start_index - pre_count; - push = true; - if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - if (pre_count >= start_index && pre_count < end_index) { - *start_offset = 0; - push = true; - if (pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - return push; -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker -// pops this control indicator, it will wait until the next epoch starts and then resume execution. -Status ClueOp::PostEndOfEpoch(int32_t queue_index) { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); - } - - return Status::OK(); -} - -Status ClueOp::CalculateNumRowsPerShard() { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - int64_t count = CountTotalRows(it.value()); - filename_numrows_[it.value()] = count; - all_num_rows_ += count; - } - if (all_num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API CLUEDataset. Please check file path or dataset API " - "validation first."); - } - - num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); - MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; - return Status::OK(); -} - -int64_t ClueOp::CountTotalRows(const std::string &file) { - std::ifstream handle(file); - if (!handle.is_open()) { - MS_LOG(ERROR) << "Failed to open file: " << file; - return 0; - } - - std::string line; - int64_t count = 0; - while (getline(handle, line)) { - if (!line.empty()) { - count++; - } - } - - return count; -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. -// When the worker pops this control indicator, it will shut itself down gracefully. -Status ClueOp::PostEndOfData() { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); - } - - return Status::OK(); -} - -Status ClueOp::CountAllFileRows(const std::vector &files, int64_t *count) { - std::shared_ptr op; - *count = 0; - RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op)); - for (auto file : files) { - *count += op->CountTotalRows(file); - } - return Status::OK(); -} - -Status ClueOp::ComputeColMap() { - // Set the column name mapping (base class field) - if (column_name_id_map_.empty()) { - int count = 0; - for (auto &p : cols_to_keyword_) { - column_name_id_map_[p.first] = count; - count++; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h deleted file mode 100644 index 487ed0d47f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h +++ /dev/null @@ -1,277 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "dataset/util/auto_index.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" - -namespace mindspore { -namespace dataset { -using StringIndex = AutoIndexObj; -using ColKeyMap = std::map>; - -class JaggedConnector; - -class ClueOp : public ParallelOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Checks if the inputs of the builder is valid. - // @return Status - the error code returned. - Status ValidateInputs() const; - - // Create the final object. - // @param op - dataset op. - // @return - the error code return. - Status Build(std::shared_ptr *op); - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumDevices(int64_t num_dev) { - builder_num_devices_ = num_dev; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDeviceId(int64_t dev_id) { - builder_device_id_ = dev_id; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetClueFilesList(const std::vector &files_list) { - builder_clue_files_list_ = files_list; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShuffleFiles(bool shuffle_files) { - builder_shuffle_files_ = shuffle_files; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumSamples(int64_t num_samples) { - builder_num_samples_ = num_samples; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetColsKeyMap(const std::map &cols_to_key) { - builder_cols_to_keyword_ = cols_to_key; - return *this; - } - - // Split string based on a character delimiter - // @return - the a string vector - std::vector split(const std::string &s, char delim); - - private: - int32_t builder_device_id_; - int32_t builder_num_devices_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - int64_t builder_rows_per_buffer_; - int64_t builder_num_samples_; - int32_t builder_worker_connector_size_; - std::vector builder_clue_files_list_; - bool builder_shuffle_files_; - std::map builder_cols_to_keyword_; - }; - - // Constructor of ClueOp - ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, - ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id); - - // Default destructor - ~ClueOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // Instantiates the internal queues and connectors - // @return Status - the error code returned - Status Init(); - - // Class functor operator () override. - // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - the error code returned. - Status operator()() override; - - // Overrides base class reset method. Cleans up any state info from it's previous execution - // reinitializes itself so that it can be executed again, as if it was just created. - // @return Status - the error code returned. - Status Reset() override; - - // Get total rows in files. - // @param files - all clue files. - // @param count - number of rows. - // @return Status - the error coed returned. - static Status CountAllFileRows(const std::vector &files, int64_t *count); - - // File names getter - // @return Vector of the input file names - std::vector FileNames() { return clue_files_list_; } - - private: - // The entry point for when workers are launched. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status WorkerEntry(int32_t worker_id) override; - - // Parses a single row and puts the data into a tensor table. - // @param line - the content of the row. - // @param tensor_table - the tensor table to put the parsed data in. - // @param row - the id of the row filled in the tensor table. - // @return Status - the error code returned. - Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); - - // Reads a clue file and loads the data into multiple buffers. - // @param file - the file to read. - // @param start_offset - the start offset of file. - // @param end_offset - the end offset of file. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, - const int32_t worker_id); - - // Pops an element from a queue in IOBlockQueue. - // @param index - the index of the queue to pop from. - // @param out_block - the popped element. - // @return Status - the error code returned. - Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); - - // Pushes an element to a queue in IOBlockQueue. - // @param index - the index of the queue to push to. - // @param io_block - the element to push onto the queue. - // @return Status - the error code returned. - Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); - - // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. - // @return Status - the error code returned. - Status WaitToFillIOBlockQueue(); - - // Fill the IOBlockQueue. - // @para i_keys - keys of file to fill to the IOBlockQueue - // @return Status - the error code returned. - Status FillIOBlockQueue(const std::vector &i_keys); - - // Notifies the thread which called FillIoBlockQueue to resume execution - void NotifyToFillIOBlockQueue(); - - // Select file and push it to the block queue. - // @param file_name - File name. - // @param start_file - If file contains the first sample of data. - // @param end_file - If file contains the end sample of data. - // @param pre_count - Total rows of previous files. - // @return Status - the error code returned. - bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count); - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker - // pops this control indicator, it will wait until the next epoch starts and then resume execution. - // @return Status - the error code returned. - Status PostEndOfEpoch(int32_t queue_index); - - // Calculate number of rows in each shard. - // @return Status - the error code returned. - Status CalculateNumRowsPerShard(); - - // Count number of rows in each file. - // @param filename - clue file name. - // @return int64_t - the total number of rows in file. - int64_t CountTotalRows(const std::string &file); - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. - // When the worker pops this control indicator, it will shut itself down gracefully. - // @return Status - the error code returned. - Status PostEndOfData(); - - // @return Status - the error code returned. - Status GetValue(const nlohmann::json &js, std::vector key_chain, std::shared_ptr *t); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t device_id_; - bool shuffle_files_; - bool finished_reading_dataset_; - int32_t num_devices_; - int64_t rows_per_buffer_; - bool load_io_block_queue_; - int64_t num_rows_per_shard_; - int64_t all_num_rows_; - int64_t num_samples_; - std::map filename_numrows_; - std::unique_ptr filename_index_; - std::vector clue_files_list_; - WaitPost io_block_queue_wait_post_; - std::unique_ptr jagged_buffer_connector_; - QueueList> io_block_queues_; - bool load_jagged_connector_; - ColKeyMap cols_to_keyword_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc deleted file mode 100644 index 92f6794769..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc +++ /dev/null @@ -1,639 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/coco_op.h" - -#include -#include -#include -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -const char kColumnImage[] = "image"; -const char kJsonImages[] = "images"; -const char kJsonImagesFileName[] = "file_name"; -const char kJsonId[] = "id"; -const char kJsonAnnotations[] = "annotations"; -const char kJsonAnnoSegmentation[] = "segmentation"; -const char kJsonAnnoCounts[] = "counts"; -const char kJsonAnnoSegmentsInfo[] = "segments_info"; -const char kJsonAnnoIscrowd[] = "iscrowd"; -const char kJsonAnnoBbox[] = "bbox"; -const char kJsonAnnoArea[] = "area"; -const char kJsonAnnoImageId[] = "image_id"; -const char kJsonAnnoNumKeypoints[] = "num_keypoints"; -const char kJsonAnnoKeypoints[] = "keypoints"; -const char kJsonAnnoCategoryId[] = "category_id"; -const char kJsonCategories[] = "categories"; -const char kJsonCategoriesIsthing[] = "isthing"; -const char kJsonCategoriesName[] = "name"; -const float kDefaultPadValue = -1.0; -const unsigned int kPadValueZero = 0; - -CocoOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); - builder_task_type_ = TaskType::Detection; -} - -Status CocoOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - switch (builder_task_type_) { - case TaskType::Detection: - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoCategoryId), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - break; - case TaskType::Stuff: - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoSegmentation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - break; - case TaskType::Keypoint: - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoKeypoints), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoNumKeypoints), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - break; - case TaskType::Panoptic: - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoCategoryId), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoArea), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - break; - default: - RETURN_STATUS_UNEXPECTED("Invalid task type"); - } - *ptr = std::make_shared(builder_task_type_, builder_dir_, builder_file_, builder_num_workers_, - builder_rows_per_buffer_, builder_op_connector_size_, builder_decode_, - std::move(builder_schema_), std::move(builder_sampler_)); - return Status::OK(); -} - -Status CocoOp::Builder::SanityCheck() { - Path dir(builder_dir_); - Path file(builder_file_); - std::string err_msg; - err_msg += dir.IsDirectory() == false ? "Coco image folder path is invalid or not set\n" : ""; - err_msg += file.Exists() == false ? "Coco annotation json path is invalid or not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, - int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, - std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size), - decode_(decode), - row_cnt_(0), - buf_cnt_(0), - task_type_(task_type), - image_folder_path_(image_folder_path), - annotation_path_(annotation_path), - rows_per_buffer_(rows_per_buffer), - sampler_(std::move(sampler)), - data_schema_(std::move(data_schema)) { - io_block_queues_.Init(num_workers_, queue_size); -} - -Status CocoOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) > num_rows_) continue; - keys->push_back(*itr); - row_cnt_++; - if (row_cnt_ % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); - keys->clear(); - } - } - return Status::OK(); -} - -Status CocoOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - while (true) { - std::vector keys; - keys.reserve(rows_per_buffer_); - while (sampler_buffer->eoe() == false) { - std::shared_ptr sample_ids; - RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); - if (sample_ids->type() != DataType(DataType::DE_INT64)) { - RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); - } - RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - if (keys.empty() == false) { - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - } -} - -void CocoOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows: " << num_rows_ << "\nCOCO Directory: " << image_folder_path_ << "\n\n"; - } -} - -Status CocoOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); - return Status::OK(); -} - -Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *trow) { - std::shared_ptr image, coordinate; - auto itr = coordinate_map_.find(image_id); - if (itr == coordinate_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - - std::string kImageFile = image_folder_path_ + image_id; - RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); - - auto bboxRow = itr->second; - std::vector bbox_row; - dsize_t bbox_row_num = static_cast(bboxRow.size()); - dsize_t bbox_column_num = 0; - for (auto bbox : bboxRow) { - if (static_cast(bbox.size()) > bbox_column_num) { - bbox_column_num = static_cast(bbox.size()); - } - } - - for (auto bbox : bboxRow) { - bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); - dsize_t pad_len = bbox_column_num - static_cast(bbox.size()); - if (pad_len > 0) { - for (dsize_t i = 0; i < pad_len; i++) { - bbox_row.push_back(kDefaultPadValue); - } - } - } - - std::vector bbox_dim = {bbox_row_num, bbox_column_num}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&coordinate, data_schema_->column(1).tensorImpl(), TensorShape(bbox_dim), - data_schema_->column(1).type(), - reinterpret_cast(&bbox_row[0]))); - if (task_type_ == TaskType::Detection) { - RETURN_IF_NOT_OK(LoadDetectionTensorRow(row_id, image_id, image, coordinate, trow)); - } else if (task_type_ == TaskType::Stuff || task_type_ == TaskType::Keypoint) { - RETURN_IF_NOT_OK(LoadSimpleTensorRow(row_id, image_id, image, coordinate, trow)); - } else if (task_type_ == TaskType::Panoptic) { - RETURN_IF_NOT_OK(LoadMixTensorRow(row_id, image_id, image, coordinate, trow)); - } else { - RETURN_STATUS_UNEXPECTED("Invalid task type."); - } - - return Status::OK(); -} - -// When task is Detection, user can get data with four columns: -// column ["image"] with datatype=uint8 -// column ["bbox"] with datatype=float32 -// column ["category_id"] with datatype=uint32 -// column ["iscrowd"] with datatype=uint32 -// By the way, column ["iscrowd"] is used for some testcases, like fasterRcnn. -// If "iscrowd" is not existed, user will get default value 0. -Status CocoOp::LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow) { - std::shared_ptr category_id, iscrowd; - std::vector category_id_row; - std::vector iscrowd_row; - auto itr_item = simple_item_map_.find(image_id); - if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - - std::vector annotation = itr_item->second; - for (int64_t i = 0; i < annotation.size(); i++) { - if (i % 2 == 0) { - category_id_row.push_back(annotation[i]); - } else if (i % 2 == 1) { - iscrowd_row.push_back(annotation[i]); - } - } - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), - data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); - - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), - data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); - (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd)}); - return Status::OK(); -} - -// When task is "Stuff"/"Keypoint", user can get data with three columns: -// column ["image"] with datatype=uint8 -// column ["segmentation"]/["keypoints"] with datatype=float32 -// column ["iscrowd"]/["num_keypoints"] with datatype=uint32 -Status CocoOp::LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow) { - std::shared_ptr item; - std::vector item_queue; - auto itr_item = simple_item_map_.find(image_id); - if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - - item_queue = itr_item->second; - std::vector bbox_dim = {static_cast(item_queue.size()), 1}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&item, data_schema_->column(2).tensorImpl(), TensorShape(bbox_dim), - data_schema_->column(2).type(), - reinterpret_cast(&item_queue[0]))); - (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(item)}); - return Status::OK(); -} - -// When task is "Panoptic", user can get data with five columns: -// column ["image"] with datatype=uint8 -// column ["bbox"] with datatype=float32 -// column ["category_id"] with datatype=uint32 -// column ["iscrowd"] with datatype=uint32 -// column ["area"] with datattype=uint32 -Status CocoOp::LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow) { - std::shared_ptr category_id, iscrowd, area; - std::vector category_id_row; - std::vector iscrowd_row; - std::vector area_row; - auto itr_item = simple_item_map_.find(image_id); - if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - - std::vector annotation = itr_item->second; - for (int64_t i = 0; i < annotation.size(); i++) { - if (i % 3 == 0) { - category_id_row.push_back(annotation[i]); - } else if (i % 3 == 1) { - iscrowd_row.push_back(annotation[i]); - } else if (i % 3 == 2) { - area_row.push_back(annotation[i]); - } - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), - data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); - - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), - data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); - - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &area, data_schema_->column(4).tensorImpl(), TensorShape({static_cast(area_row.size()), 1}), - data_schema_->column(4).type(), reinterpret_cast(&area_row[0]))); - (*trow) = TensorRow( - row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd), std::move(area)}); - return Status::OK(); -} - -Status CocoOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - TensorRow trow; - for (const int64_t &key : keys) { - RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_ids_[key], &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -Status CocoOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, (std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty() == true) return Status::OK(); - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -template -Status CocoOp::SearchNodeInJson(nlohmann::json input_tree, std::string node_name, T *output_node) { - auto node = input_tree.find(node_name); - if (node == input_tree.end()) RETURN_STATUS_UNEXPECTED("Invalid node found in json : " + node_name); - (*output_node) = *node; - return Status::OK(); -} - -Status CocoOp::ParseAnnotationIds() { - std::ifstream in(annotation_path_); - nlohmann::json js; - in >> js; - - std::vector image_que; - nlohmann::json image_list; - RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonImages), &image_list)); - RETURN_IF_NOT_OK(ImageColumnLoad(image_list, &image_que)); - if (task_type_ == TaskType::Detection || task_type_ == TaskType::Panoptic) { - nlohmann::json node_categories; - RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonCategories), &node_categories)); - RETURN_IF_NOT_OK(CategoriesColumnLoad(node_categories)); - } - nlohmann::json annotations_list; - RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonAnnotations), &annotations_list)); - for (auto annotation : annotations_list) { - int32_t image_id = 0, id = 0; - std::string file_name; - RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonAnnoImageId), &image_id)); - auto itr_file = image_index_.find(image_id); - if (itr_file == image_index_.end()) - RETURN_STATUS_UNEXPECTED("Invalid image id of annotations : " + std::to_string(image_id)); - file_name = itr_file->second; - switch (task_type_) { - case TaskType::Detection: - RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); - RETURN_IF_NOT_OK(DetectionColumnLoad(annotation, file_name, id)); - break; - case TaskType::Stuff: - RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); - RETURN_IF_NOT_OK(StuffColumnLoad(annotation, file_name, id)); - break; - case TaskType::Keypoint: - RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); - RETURN_IF_NOT_OK(KeypointColumnLoad(annotation, file_name, id)); - break; - case TaskType::Panoptic: - RETURN_IF_NOT_OK(PanopticColumnLoad(annotation, file_name, image_id)); - break; - default: - RETURN_STATUS_UNEXPECTED("Invalid task type"); - } - } - for (auto img : image_que) { - if (coordinate_map_.find(img) != coordinate_map_.end()) image_ids_.push_back(img); - } - num_rows_ = image_ids_.size(); - return Status::OK(); -} - -Status CocoOp::ImageColumnLoad(nlohmann::json image_tree, std::vector *image_vec) { - if (image_tree.size() == 0) { - RETURN_STATUS_UNEXPECTED("No images found in " + annotation_path_); - } - for (auto img : image_tree) { - std::string file_name; - int32_t id = 0; - RETURN_IF_NOT_OK(SearchNodeInJson(img, std::string(kJsonImagesFileName), &file_name)); - RETURN_IF_NOT_OK(SearchNodeInJson(img, std::string(kJsonId), &id)); - - image_index_[id] = file_name; - image_vec->push_back(file_name); - } - return Status::OK(); -} - -Status CocoOp::DetectionColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, - const int32_t &unique_id) { - std::vector bbox; - nlohmann::json node_bbox; - uint32_t category_id = 0, iscrowd = 0; - RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoBbox), &node_bbox)); - RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoCategoryId), &category_id)); - auto search_category = category_set_.find(category_id); - if (search_category == category_set_.end()) - RETURN_STATUS_UNEXPECTED("category_id can't find in categories where category_id: " + std::to_string(category_id)); - auto node_iscrowd = annotation_tree.find(kJsonAnnoIscrowd); - if (node_iscrowd != annotation_tree.end()) iscrowd = *node_iscrowd; - bbox.insert(bbox.end(), node_bbox.begin(), node_bbox.end()); - coordinate_map_[image_file].push_back(bbox); - simple_item_map_[image_file].push_back(category_id); - simple_item_map_[image_file].push_back(iscrowd); - return Status::OK(); -} - -Status CocoOp::StuffColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, - const int32_t &unique_id) { - uint32_t iscrowd = 0; - std::vector bbox; - RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoIscrowd), &iscrowd)); - simple_item_map_[image_file].push_back(iscrowd); - nlohmann::json segmentation; - RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoSegmentation), &segmentation)); - if (iscrowd == 0) { - for (auto item : segmentation) { - if (bbox.size() > 0) bbox.clear(); - bbox.insert(bbox.end(), item.begin(), item.end()); - coordinate_map_[image_file].push_back(bbox); - } - } else if (iscrowd == 1) { - nlohmann::json segmentation_count; - RETURN_IF_NOT_OK(SearchNodeInJson(segmentation, std::string(kJsonAnnoCounts), &segmentation_count)); - bbox.insert(bbox.end(), segmentation_count.begin(), segmentation_count.end()); - coordinate_map_[image_file].push_back(bbox); - } - return Status::OK(); -} - -Status CocoOp::KeypointColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, - const int32_t &unique_id) { - auto itr_num_keypoint = annotation_tree.find(kJsonAnnoNumKeypoints); - if (itr_num_keypoint == annotation_tree.end()) - RETURN_STATUS_UNEXPECTED("No num_keypoint found in annotations where id: " + std::to_string(unique_id)); - simple_item_map_[image_file].push_back(*itr_num_keypoint); - auto itr_keypoint = annotation_tree.find(kJsonAnnoKeypoints); - if (itr_keypoint == annotation_tree.end()) - RETURN_STATUS_UNEXPECTED("No keypoint found in annotations where id: " + std::to_string(unique_id)); - coordinate_map_[image_file].push_back(*itr_keypoint); - return Status::OK(); -} - -Status CocoOp::PanopticColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, - const int32_t &image_id) { - auto itr_segments = annotation_tree.find(kJsonAnnoSegmentsInfo); - if (itr_segments == annotation_tree.end()) - RETURN_STATUS_UNEXPECTED("No segments_info found in annotations where image_id: " + std::to_string(image_id)); - for (auto info : *itr_segments) { - std::vector bbox; - uint32_t category_id = 0; - auto itr_bbox = info.find(kJsonAnnoBbox); - if (itr_bbox == info.end()) - RETURN_STATUS_UNEXPECTED("No bbox found in segments_info where image_id: " + std::to_string(image_id)); - bbox.insert(bbox.end(), itr_bbox->begin(), itr_bbox->end()); - coordinate_map_[image_file].push_back(bbox); - - RETURN_IF_NOT_OK(SearchNodeInJson(info, std::string(kJsonAnnoCategoryId), &category_id)); - auto search_category = category_set_.find(category_id); - if (search_category == category_set_.end()) - RETURN_STATUS_UNEXPECTED("category_id can't find in categories where category_id: " + - std::to_string(category_id)); - auto itr_iscrowd = info.find(kJsonAnnoIscrowd); - if (itr_iscrowd == info.end()) - RETURN_STATUS_UNEXPECTED("No iscrowd found in segments_info where image_id: " + std::to_string(image_id)); - auto itr_area = info.find(kJsonAnnoArea); - if (itr_area == info.end()) - RETURN_STATUS_UNEXPECTED("No area found in segments_info where image_id: " + std::to_string(image_id)); - simple_item_map_[image_file].push_back(category_id); - simple_item_map_[image_file].push_back(*itr_iscrowd); - simple_item_map_[image_file].push_back(*itr_area); - } - return Status::OK(); -} - -Status CocoOp::CategoriesColumnLoad(nlohmann::json categories_tree) { - if (categories_tree.size() == 0) RETURN_STATUS_UNEXPECTED("No categories found in " + annotation_path_); - for (auto category : categories_tree) { - int32_t id = 0; - std::string name; - std::vector label_info; - auto itr_id = category.find(kJsonId); - if (itr_id == category.end()) RETURN_STATUS_UNEXPECTED("No id found in categories of " + annotation_path_); - id = *itr_id; - label_info.push_back(id); - category_set_.insert(id); - - auto itr_name = category.find(kJsonCategoriesName); - if (itr_name == category.end()) - RETURN_STATUS_UNEXPECTED("No name found in categories where id: " + std::to_string(id)); - name = *itr_name; - - if (task_type_ == TaskType::Panoptic) { - auto itr_isthing = category.find(kJsonCategoriesIsthing); - if (itr_isthing == category.end()) - RETURN_STATUS_UNEXPECTED("No isthing found in categories of " + annotation_path_); - label_info.push_back(*itr_isthing); - } - label_index_.emplace_back(std::make_pair(name, label_info)); - } - return Status::OK(); -} - -Status CocoOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); -} - -Status CocoOp::LaunchThreadsAndInitOp() { - if (tree_ == nullptr) { - RETURN_STATUS_UNEXPECTED("tree_ not set"); - } - RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(this->ParseAnnotationIds()); - RETURN_IF_NOT_OK(this->InitSampler()); - return Status::OK(); -} - -Status CocoOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); - - if (decode_ == true) { - Status rc = Decode(*tensor, tensor); - if (rc.IsError()) { - RETURN_STATUS_UNEXPECTED("fail to decode file: " + path); - } - } - return Status::OK(); -} - -Status CocoOp::CountTotalRows(const std::string &dir, const std::string &file, const std::string &task, - int64_t *count) { - std::shared_ptr op; - RETURN_IF_NOT_OK(Builder().SetDir(dir).SetFile(file).SetTask(task).Build(&op)); - RETURN_IF_NOT_OK(op->ParseAnnotationIds()); - *count = static_cast(op->image_ids_.size()); - return Status::OK(); -} - -Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file, const std::string &task, - std::vector>> *output_class_indexing) { - std::shared_ptr op; - RETURN_IF_NOT_OK(Builder().SetDir(dir).SetFile(file).SetTask(task).Build(&op)); - RETURN_IF_NOT_OK(op->ParseAnnotationIds()); - *output_class_indexing = op->label_index_; - return Status::OK(); -} - -Status CocoOp::ComputeColMap() { - // Set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h deleted file mode 100644 index 3791853798..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h +++ /dev/null @@ -1,334 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_COC0_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// Forward declares -template -class Queue; - -using CoordinateRow = std::vector>; - -class CocoOp : public ParallelOp, public RandomAccessOp { - public: - enum class TaskType { Detection = 0, Stuff = 1, Panoptic = 2, Keypoint = 3 }; - - class Builder { - public: - // Constructor for Builder class of ImageFolderOp - // @param uint32_t numWrks - number of parallel workers - // @param dir - directory folder got ImageNetFolder - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method. - // @param const std::string & build_dir - // @return Builder setter method returns reference to the builder. - Builder &SetDir(const std::string &build_dir) { - builder_dir_ = build_dir; - return *this; - } - - // Setter method. - // @param const std::string & build_file - // @return Builder setter method returns reference to the builder. - Builder &SetFile(const std::string &build_file) { - builder_file_ = build_file; - return *this; - } - - // Setter method. - // @param const std::string & task_type - // @return Builder setter method returns reference to the builder. - Builder &SetTask(const std::string &task_type) { - if (task_type == "Detection") { - builder_task_type_ = TaskType::Detection; - } else if (task_type == "Stuff") { - builder_task_type_ = TaskType::Stuff; - } else if (task_type == "Panoptic") { - builder_task_type_ = TaskType::Panoptic; - } else if (task_type == "Keypoint") { - builder_task_type_ = TaskType::Keypoint; - } - return *this; - } - - // Setter method. - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @param int32_t op_connector_size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method. - // @param bool do_decode - // @return Builder setter method returns reference to the builder. - Builder &SetDecode(bool do_decode) { - builder_decode_ = do_decode; - return *this; - } - - // Check validity of input args - // @return = The error code return - Status SanityCheck(); - - // The builder "Build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - bool builder_decode_; - std::string builder_dir_; - std::string builder_file_; - TaskType builder_task_type_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - int32_t builder_rows_per_buffer_; - std::shared_ptr builder_sampler_; - std::unique_ptr builder_schema_; - }; - - // Constructor - // @param TaskType task_type - task type of Coco - // @param std::string image_folder_path - image folder path of Coco - // @param std::string annotation_path - annotation json path of Coco - // @param int32_t num_workers - number of workers reading images in parallel - // @param int32_t rows_per_buffer - number of images (rows) in each buffer - // @param int32_t queue_size - connector queue size - // @param int64_t num_samples - number of samples to read - // @param bool decode - whether to decode images - // @param std::unique_ptr data_schema - the schema of the Coco dataset - // @param std::shared_ptr sampler - sampler tells CocoOp what to read - CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, - int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, - std::unique_ptr data_schema, std::shared_ptr sampler); - - // Destructor - ~CocoOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t workerId - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of CocoOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // @param const std::string &dir - Coco image dir path - // @param const std::string &file - Coco json file path - // @param const std::string &task - task mode of Coco task - // @param int64_t numSamples - samples number of CocoDataset - // @param int64_t *count - output rows number of CocoDataset - static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, - int64_t *count); - - // @param const std::string &dir - Coco image dir path - // @param const std::string &file - Coco json file path - // @param const std::string &task - task mode of Coco task - // @param int64_t numSamples - samples number of CocoDataset - // @param std::map *output_class_indexing - output class index of CocoDataset - static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, - std::vector>> *output_class_indexing); - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to image id - // @param row_id_type row_id - id for this tensor row - // @param std::string image_id - image id - // @param TensorRow row - image & target read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); - - // Load a tensor row with vector which a vector to a tensor - // @param row_id_type row_id - id for this tensor row - // @param const std::string &image_id - image is - // @param std::shared_ptr image - image tensor - // @param std::shared_ptr coordinate - coordinate tensor - // @param TensorRow row - image & target read into this tensor row - // @return Status - The error code return - Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow); - - // Load a tensor row with vector which a vector to a tensor - // @param row_id_type row_id - id for this tensor row - // @param const std::string &image_id - image is - // @param std::shared_ptr image - image tensor - // @param std::shared_ptr coordinate - coordinate tensor - // @param TensorRow row - image & target read into this tensor row - // @return Status - The error code return - Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow); - - // Load a tensor row with vector which a vector to multi-tensor - // @param row_id_type row_id - id for this tensor row - // @param const std::string &image_id - image is - // @param std::shared_ptr image - image tensor - // @param std::shared_ptr coordinate - coordinate tensor - // @param TensorRow row - image & target read into this tensor row - // @return Status - The error code return - Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow); - - // @param const std::string &path - path to the image file - // @param const ColDescriptor &col - contains tensor implementation and datatype - // @param std::shared_ptr tensor - return - // @return Status - The error code return - Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Read annotation from Annotation folder - // @return Status - The error code return - Status ParseAnnotationIds(); - - // @param const std::shared_ptr &sample_ids - sample ids of tensor - // @param std::vector *keys - image id - // @return Status - The error code return - Status TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); - - // Called first when function is called - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - - // Reset dataset state - // @return Status - The error code return - Status Reset() override; - - // @param nlohmann::json image_tree - image tree of json - // @param std::vector *image_vec - image id list of json - // @return Status - The error code return - Status ImageColumnLoad(nlohmann::json image_tree, std::vector *image_vec); - - // @param nlohmann::json categories_tree - categories tree of json - // return Status - The error code return - Status CategoriesColumnLoad(nlohmann::json categories_tree); - - // @param nlohmann::json categories_tree - categories tree of json - // @param const std::string &image_file - current image name in annotation - // @param const int32_t &id - current unique id of annotation - // @return Status - The error code return - Status DetectionColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); - - // @param nlohmann::json categories_tree - categories tree of json - // @param const std::string &image_file - current image name in annotation - // @param const int32_t &id - current unique id of annotation - // @return Status - The error code return - Status StuffColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); - - // @param nlohmann::json categories_tree - categories tree of json - // @param const std::string &image_file - current image name in annotation - // @param const int32_t &id - current unique id of annotation - // @return Status - The error code return - Status KeypointColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); - - // @param nlohmann::json categories_tree - categories tree of json - // @param const std::string &image_file - current image name in annotation - // @param const int32_t &image_id - current unique id of annotation - // @return Status - The error code return - Status PanopticColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &image_id); - - template - Status SearchNodeInJson(nlohmann::json input_tree, std::string node_name, T *output_node); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - bool decode_; - int64_t row_cnt_; - int64_t buf_cnt_; - std::string image_folder_path_; - std::string annotation_path_; - TaskType task_type_; - int32_t rows_per_buffer_; - std::shared_ptr sampler_; - std::unique_ptr data_schema_; - - WaitPost wp_; - std::vector image_ids_; - std::map image_index_; - QueueList> io_block_queues_; - std::vector>> label_index_; - std::map coordinate_map_; - std::map> simple_item_map_; - std::set category_set_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_Coco_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc deleted file mode 100644 index 36c221fc16..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc +++ /dev/null @@ -1,267 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/generator_op.h" -#include -#include "dataset/core/global_context.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/task_manager.h" -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -GeneratorOp::Builder::Builder() { - // Some arguments to the GeneratorOp constructor have a default argument that is taken - // from the client config. - build_buffer_size_ = kCfgRowsPerBuffer; - build_op_connector_size_ = kCfgOpConnectorSize; -} - -Status GeneratorOp::Builder::SanityCheck() { - // Update queue size to fit the prefetch requirement - MS_LOG(DEBUG) << "Generator operator sanity check, prefetch size is " << build_prefetch_size_ << "."; - if (build_prefetch_size_ > 0) { - build_op_connector_size_ = (build_prefetch_size_ + build_buffer_size_ - 1) / build_buffer_size_; - } - return Status::OK(); -} - -Status GeneratorOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_generator_function_, build_column_names_, build_column_types_, - build_prefetch_size_, build_buffer_size_, build_op_connector_size_); - return (*ptr)->Init(); -} - -GeneratorOp::GeneratorOp(py::function generator_function, std::vector column_names, - std::vector column_types, int32_t prefetch_size, int32_t buffer_size, - int32_t connector_size) - : PipelineOp(connector_size), - generator_function_(generator_function), - column_names_(column_names), - column_types_(column_types), - prefetch_size_(prefetch_size), - buffer_size_(buffer_size), - buffer_id_(0) {} - -GeneratorOp::~GeneratorOp() { this->Dealloc(); } - -void GeneratorOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nColumn names:\n"; - for (int i = 0; i < column_names_.size(); ++i) { - out << "\n " << column_names_[i]; - } - out << "\n\n"; - } -} - -void GeneratorOp::Dealloc() noexcept { - // Setup GIL state - PyGILState_STATE gstate; - gstate = PyGILState_Ensure(); - // GC the generator object within GIL - (void)generator_.dec_ref(); - // Release GIL - PyGILState_Release(gstate); -} - -// Reentrant init method. -Status GeneratorOp::Init() { - // Reset BufferID - buffer_id_ = 0; - Status ret; - { - // Acquire Python GIL - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - // Invoke the generatorFunction to get generator object - try { - generator_ = generator_function_(); - } catch (const py::error_already_set &e) { - ret = Status(StatusCode::kPyFuncException, e.what()); - } - } - return ret; -} - -Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) { - if (!py::isinstance(py_data)) { - return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator should return a tuple of numpy arrays."); - } - py::tuple py_row = py_data.cast(); - // Check if returned number of columns matches with column names - if (py_row.size() != column_names_.size()) { - return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, - "Generator should return same number of numpy arrays as specified in column names."); - } - // Iterate over two containers simultaneously for memory copy - for (int i = 0; i < py_row.size(); ++i) { - py::object ret_py_ele = py_row[i]; - if (!py::isinstance(ret_py_ele)) { - return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, - "Generator should return a tuple of numpy arrays."); - } - std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, ret_py_ele.cast())); - if ((!column_types_.empty()) && (column_types_[i] != DataType::DE_UNKNOWN) && - (column_types_[i] != tensor->type())) { - return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator type check failed."); - } - tensor_row->push_back(tensor); - } - return Status(StatusCode::kOK, ""); -} - -Status GeneratorOp::FillBuffer(TensorQTable *tt) { - for (int i = 0; i < buffer_size_; i++) { - TensorRow row; - RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row)); - tt->push_back(std::move(row)); - } - return Status::OK(); -} - -// Entry point for Generator, called by launch() -// Note that this function is very easy to break because of the Python GIL mechanism -// The master thread has the following workflow -// -// while !eof: -// Try: -// Prepare one data buffer GIL, Can throw -// Catch: -// Fetch Python Exception GIL -// Check if Exception is StopIteration (EOE) GIL -// Restore Python Exception GIL -// If not StopIteration: -// Return Status PyFuncException -// -// Push data buffer to connector Block -// -// if EOE -// Push EOE Block -// if more epoch: -// Block until next epoch Block -// else: -// Push EOF Block -// eof = true -// Return Status OK -// -// Note that any modification of this function need to guarantee: -// 1. All "Require GIL" operations are protected by GIL -// SegFault / Deadlock will occur if this condition is not fulfilled. -// 2. All "Block" operations are free from GIL, all block target are registered with tree. -// Deadlock will occur if this condition is not fulfilled -// 3. No Python GC should be triggered outside of GIL. -// SegFault will occur is this condition is not fulfilled -// -Status GeneratorOp::operator()() { - // Handshake with TaskManager to synchronize thread creation - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); - std::unique_ptr fetched_buffer; - bool eof = false; - while (!eof) { - // Create new buffer each iteration - fetched_buffer = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); - std::unique_ptr fetched_table = std::make_unique(); - bool eoe = false; - { - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - RETURN_IF_NOT_OK(FillBuffer(fetched_table.get())); - } catch (py::error_already_set &e) { - eoe = e.matches(PyExc_StopIteration); - // Restore exception to python - e.restore(); - // Pop up non StopIteration Python Exception - if (!eoe) { - return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what()); - } - } - } - if (fetched_table->size() > 0) { - fetched_buffer->set_tensor_table(std::move(fetched_table)); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); - } - if (eoe) { - // Push out EOE upon StopIteration exception from generator - MS_LOG(DEBUG) << "Generator operator sends out EOE."; - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - // If last repeat or not repeated, push out EOF and exit master loop - MS_LOG(DEBUG) << "Generator operator sends out EOF."; - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - MS_LOG(DEBUG) << "Generator operator main execution loop complete."; - eof = true; - } else { - // Waiting for repeatOp to start new epoch - // If Reset() is called first by repeat op, this wait() will return right away. - // If Reset() is not called yet, this wait() will block until reset. - RETURN_IF_NOT_OK(wp_.Wait()); - // Clear the status of the wait post - wp_.Clear(); - } - } - } - return Status::OK(); -} - -Status GeneratorOp::Reset() { - // Reset Op state - RETURN_IF_NOT_OK(this->Init()); - // Wake up master thread - wp_.Set(); - return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); -} - -// Visitor accept method for NodePass -Status GeneratorOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status GeneratorOp::ComputeColMap() { - // Setup column names map (base class field) - if (column_name_id_map_.empty()) { - for (int i = 0; i < column_names_.size(); ++i) { - column_name_id_map_[column_names_[i]] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h deleted file mode 100644 index 98dd2d70a1..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ - -#include -#include -#include -#include -#include -#include -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -#pragma GCC visibility push(hidden) - -class GeneratorOp : public PipelineOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetGeneratorFunction(py::function generator_function) { - build_generator_function_ = generator_function; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetColumnNames(const std::vector &column_names) { - build_column_names_ = column_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetColumnTypes(const std::vector &column_types) { - build_column_types_ = column_types; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetPrefetchSize(int32_t prefetch_size) { - build_prefetch_size_ = prefetch_size; - return *this; - } - - // The builder "build" method creates the final object. - // @return shared_ptr to the new GeneratorOp object - Status Build(std::shared_ptr *); - - private: - // The builder saves all GeneratorOp construction arguments internally. - // The following are the arguments. - py::function build_generator_function_; - std::vector build_column_names_; - std::vector build_column_types_; - - int32_t build_prefetch_size_ = 0; - int32_t build_buffer_size_; - int32_t build_op_connector_size_; - - Status SanityCheck(); - }; - - GeneratorOp(py::function generator_function, std::vector column_names, - std::vector column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size); - - ~GeneratorOp(); - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param generator_op - reference to the GeneratorOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) { - generator_op.Print(out, false); - return out; - } - - // Class functor operator () override. - // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work. - // @return Status - The error code return - Status operator()() override; - - // Overrides base class reset method. When an operator does a reset, it cleans up any state - // info from it's previous execution and then initializes itself so that it can be executed - // again. - // @return Status - The error code return - Status Reset() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "GeneratorOp"; } - - private: - py::function generator_function_; - std::vector column_names_; - std::vector column_types_; - int32_t prefetch_size_; - int32_t buffer_size_; - - py::object generator_; - int32_t buffer_id_; - - WaitPost wp_; - - Status Init(); - - void Dealloc() noexcept; - - Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row); - - Status FillBuffer(TensorQTable *tt); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; -}; - -#pragma GCC visibility pop -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc deleted file mode 100644 index 837eae1e3c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc +++ /dev/null @@ -1,429 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include -#include -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status ImageFolderOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); - *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, - builder_op_connector_size_, builder_recursive_, builder_decode_, - builder_extensions_, builder_labels_to_read_, std::move(builder_schema_), - std::move(builder_sampler_)); - return Status::OK(); -} - -Status ImageFolderOp::Builder::SanityCheck() { - Path dir(builder_dir_); - std::string err_msg; - err_msg += dir.IsDirectory() == false ? "ImageFolder path is invalid or not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, - bool recursive, bool do_decode, const std::set &exts, - const std::map &map, std::unique_ptr data_schema, - std::shared_ptr sampler) - : ParallelOp(num_wkrs, queue_size, std::move(sampler)), - rows_per_buffer_(rows_per_buffer), - folder_path_(file_dir), - recursive_(recursive), - decode_(do_decode), - extensions_(exts), - class_index_(map), - data_schema_(std::move(data_schema)), - row_cnt_(0), - buf_cnt_(0), - sampler_ind_(0), - dirname_offset_(0) { - folder_name_queue_ = std::make_unique>(num_wkrs * queue_size); - image_name_queue_ = std::make_unique>(num_wkrs * queue_size); - io_block_queues_.Init(num_workers_, queue_size); -} - -// Master thread that pulls the prescan worker's results. -// Keep collecting results until all prescan workers quit -// Then consolidate 2 level shuffles together into 1 giant vector -// calculate numRows then return -Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) { - std::vector v; - int64_t cnt = 0; - while (cnt != num_workers_) { // count number of end signals - FolderImagesPair p; - RETURN_IF_NOT_OK(image_name_queue_->PopFront(&p)); - if (p == nullptr) { - cnt++; - } else { - v.push_back(p); - } - } - std::sort(v.begin(), v.end(), - [](const FolderImagesPair &lhs, const FolderImagesPair &rhs) { return lhs->first < rhs->first; }); - // following loop puts the 2 level of shuffles together into 1 vector - for (size_t ind = 0; ind < v.size(); ++ind) { - while (v[ind]->second.empty() == false) { - MS_ASSERT(!(v[ind]->first.empty())); // make sure that v[ind]->first.substr(1) is not out of bound - v[ind]->second.front()->second = class_index_.empty() ? ind : class_index_[v[ind]->first.substr(1)]; - image_label_pairs_.push_back(v[ind]->second.front()); - v[ind]->second.pop(); - } - } - image_label_pairs_.shrink_to_fit(); - num_rows_ = image_label_pairs_.size(); - if (num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset " - "API validation first."); - } - // free memory of two queues used for pre-scan - folder_name_queue_->Reset(); - image_name_queue_->Reset(); - return Status::OK(); -} - -// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work -Status ImageFolderOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - while (true) { // each iterator is 1 epoch - std::vector keys; - keys.reserve(rows_per_buffer_); - while (sampler_buffer->eoe() == false) { - TensorRow sample_row; - RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); - std::shared_ptr sample_ids = sample_row[0]; - if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) >= num_rows_) continue; // index out of bound, skipping - keys.push_back(*itr); - row_cnt_++; - if (row_cnt_ % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK( - io_block_queues_[buf_cnt_++ % num_workers_]->Add(std::make_unique(keys, IOBlock::kDeIoBlockNone))); - keys.clear(); - } - } - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - if (keys.empty() == false) { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(keys, IOBlock::kDeIoBlockNone))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); - for (int32_t i = 0; i < num_workers_; ++i) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { // not the last repeat. Sleep master thread, wait for the wake-up from reset - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - } -} - -// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ -// IMPORTANT: 1 IOBlock produces 1 DataBuffer -Status ImageFolderOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty() == true) return Status::OK(); // empty key is a quit signal for workers - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer -Status ImageFolderOp::LoadTensorRow(row_id_type row_id, ImageLabelPair pairPtr, TensorRow *trow) { - std::shared_ptr image, label; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), - reinterpret_cast(&pairPtr->second))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, folder_path_ + (pairPtr->first))); - - if (decode_ == true) { - Status rc = Decode(image, &image); - if (rc.IsError()) { - std::string err = "Fail to decode image:" + folder_path_ + (pairPtr->first); - RETURN_STATUS_UNEXPECTED(err); - } - } - (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); - return Status::OK(); -} - -// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer -Status ImageFolderOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - TensorRow trow; - for (const int64_t &key : keys) { - RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_label_pairs_[key], &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -void ImageFolderOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows:" << num_rows_ << "\nImageFolder directory: " << folder_path_ << "\n\n"; - } -} - -// Reset Sampler and wakeup Master thread (functor) -Status ImageFolderOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done - return Status::OK(); -} - -// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows -Status ImageFolderOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); -} - -// Derived from RandomAccessOp -Status ImageFolderOp::GetClassIds(std::map> *cls_ids) const { - if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { - RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); - } - for (size_t i = 0; i < image_label_pairs_.size(); ++i) { - (*cls_ids)[image_label_pairs_[i]->second].push_back(i); - } - for (auto &pair : (*cls_ids)) { - pair.second.shrink_to_fit(); - } - return Status::OK(); -} - -// Worker Entry for pre-scanning all the folders and do the 1st level shuffle -// Worker pull a file name from mFoldernameQueue (which is a Queue), walks all the images under that foldername -// After walking is complete, sort all the file names (relative path to all jpeg files under the same directory ) -// (Sort is automatically conducted using a set which is implemented using a Red-Black Tree) -// Add the sorted filenames in to a queue. The make a pair (foldername, queue*), -// foldername is used for 2nd level sorting. -// FYI: 1st level sorting: sort all images under the same directory. -// FYI: 2nd level sorting: sort all folder names -// push this pair to mImagenameQueue (which is again a Queue) -Status ImageFolderOp::PrescanWorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - std::string folder_name; - RETURN_IF_NOT_OK(folder_name_queue_->PopFront(&folder_name)); - while (folder_name.empty() == false) { - Path folder(folder_path_ + folder_name); - std::shared_ptr dirItr = Path::DirIterator::OpenDirectory(&folder); - if (folder.Exists() == false || dirItr == nullptr) { - RETURN_STATUS_UNEXPECTED("Error unable to open: " + folder_name); - } - std::set imgs; // use this for ordering - while (dirItr->hasNext()) { - Path file = dirItr->next(); - if (extensions_.empty() || extensions_.find(file.Extension()) != extensions_.end()) { - (void)imgs.insert(file.toString().substr(dirname_offset_)); - } else { - MS_LOG(WARNING) << "Image folder operator unsupported file found: " << file.toString() - << ", extension: " << file.Extension() << "."; - } - } - FolderImagesPair p = std::make_shared>>(); - p->first = folder_name; - for (const std::string &img : imgs) { - p->second.push(std::make_shared>(img, 0)); - } - RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(p)); - RETURN_IF_NOT_OK(folder_name_queue_->PopFront(&folder_name)); - } - RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(nullptr)); // end signal - return Status::OK(); -} - -// This helper function recursively walks all foldernames, and send each foldername to mFoldernameQueue -// if mRecursive == false, don't go into folder of folders -Status ImageFolderOp::RecursiveWalkFolder(Path *dir) { - std::shared_ptr dir_itr = Path::DirIterator::OpenDirectory(dir); - RETURN_UNEXPECTED_IF_NULL(dir_itr); - while (dir_itr->hasNext()) { - Path subdir = dir_itr->next(); - if (subdir.IsDirectory()) { - if (class_index_.empty() || - class_index_.find(subdir.toString().substr(dirname_offset_ + 1)) != class_index_.end()) { - RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack(subdir.toString().substr(dirname_offset_))); - } - if (recursive_ == true) { - RETURN_IF_NOT_OK(RecursiveWalkFolder(&subdir)); - } - } - } - return Status::OK(); -} - -// A thread that calls RecursiveWalkFolder -Status ImageFolderOp::startAsyncWalk() { - TaskManager::FindMe()->Post(); - Path dir(folder_path_); - if (dir.Exists() == false || dir.IsDirectory() == false) { - RETURN_STATUS_UNEXPECTED("Error unable to open: " + folder_path_); - } - dirname_offset_ = folder_path_.length(); - RETURN_IF_NOT_OK(RecursiveWalkFolder(&dir)); - // send out num_workers_ end signal to mFoldernameQueue, 1 for each worker. - // Upon receiving end Signal, worker quits and set another end Signal to mImagenameQueue. - for (int32_t ind = 0; ind < num_workers_; ++ind) { - RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack("")); // end signal - } - return Status::OK(); -} - -Status ImageFolderOp::LaunchThreadsAndInitOp() { - RETURN_UNEXPECTED_IF_NULL(tree_); - // Registers QueueList and individual Queues for interrupt services - RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); - // The following code launch 3 threads group - // 1) A thread that walks all folders and push the folder names to a util:Queue mFoldernameQueue. - // 2) Workers that pull foldername from mFoldernameQueue, walk it and return the sorted images to mImagenameQueue - // 3) Launch main workers that load DataBuffers by reading all images - RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("walk dir", std::bind(&ImageFolderOp::startAsyncWalk, this))); - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1))); - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - // The order of the following 2 functions must not be changed! - RETURN_IF_NOT_OK(this->PrescanMasterEntry(folder_path_)); // Master thread of pre-scan workers, blocking - RETURN_IF_NOT_OK(this->InitSampler()); // pass numRows to Sampler - return Status::OK(); -} - -Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set &exts, int64_t *num_rows, - int64_t *num_classes, int64_t dev_id, int64_t num_dev) { - Path dir(path); - std::string err_msg = ""; - int64_t row_cnt = 0; - err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : ""; - err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : ""; - err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : ""; - if (err_msg.empty() == false) { - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::queue foldernames; - std::shared_ptr dir_itr = Path::DirIterator::OpenDirectory(&dir); - while (dir_itr->hasNext()) { - Path subdir = dir_itr->next(); - if (subdir.IsDirectory()) { - foldernames.push(subdir.toString()); - } - } - (*num_classes) = foldernames.size(); - while (foldernames.empty() == false) { - Path subdir(foldernames.front()); - dir_itr = Path::DirIterator::OpenDirectory(&subdir); - while (dir_itr->hasNext()) { - if (exts.empty() || exts.find(subdir.Extension()) != exts.end()) { - ++row_cnt; - } - } - foldernames.pop(); - } - (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1); - return Status::OK(); -} - -// Visitor accept method for NodePass -Status ImageFolderOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status ImageFolderOp::ComputeColMap() { - // Set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h deleted file mode 100644 index 6629fd6092..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h +++ /dev/null @@ -1,274 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// Forward declares -template -class Queue; - -using ImageLabelPair = std::shared_ptr>; -using FolderImagesPair = std::shared_ptr>>; - -class ImageFolderOp : public ParallelOp, public RandomAccessOp { - public: - class Builder { - public: - // Constructor for Builder class of ImageFolderOp - // @param int32_t numWrks - number of parallel workers - // @param dir - directory folder got ImageNetFolder - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method - // @param int32_t size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t size) { - builder_op_connector_size_ = size; - return *this; - } - - // Setter method - // @param std::set & exts, file extensions to be read - // @return Builder setter method returns reference to the builder. - Builder &SetExtensions(const std::set &exts) { - builder_extensions_ = exts; - return *this; - } - - // Setter method - // @paramconst std::map& map - a class name to label map - // @return - Builder &SetClassIndex(const std::map &map) { - builder_labels_to_read_ = map; - return *this; - } - - // Setter method - // @param bool do_decode - // @return Builder setter method returns reference to the builder. - Builder &SetDecode(bool do_decode) { - builder_decode_ = do_decode; - return *this; - } - - // Setter method - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method - // @param const std::string & dir - // @return - Builder &SetImageFolderDir(const std::string &dir) { - builder_dir_ = dir; - return *this; - } - - // Whether dir are walked recursively - // @param bool recursive - if set to false, only get dirs in top level dir - // @return - Builder &SetRecursive(bool recursive) { - builder_recursive_ = recursive; - return *this; - } - - // Check validity of input args - // @return - The error code return - Status SanityCheck(); - - // The builder "build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - bool builder_decode_; - bool builder_recursive_; - std::string builder_dir_; - int32_t builder_num_workers_; - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::set builder_extensions_; - std::shared_ptr builder_sampler_; - std::unique_ptr builder_schema_; - std::map builder_labels_to_read_; - }; - - // Constructor - // @param int32_t num_wkrs - Num of workers reading images in parallel - // @param int32_t - rows_per_buffer Number of images (rows) in each buffer - // @param std::string - dir directory of ImageNetFolder - // @param int32_t queue_size - connector queue size - // @param std::set exts - set of file extensions to read, if empty, read everything under the dir - // @param td::unique_ptr sampler - sampler tells ImageFolderOp what to read - ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, - bool do_decode, const std::set &exts, const std::map &map, - std::unique_ptr, std::shared_ptr sampler); - - // Destructor. - ~ImageFolderOp() = default; - - // Initialize ImageFOlderOp related var, calls the function to walk all files - // @param - std::string dir file directory to ImageNetFolder - // @return - The error code return - Status PrescanMasterEntry(const std::string &dir); - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t workerId - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t workerId - id of each worker - // @return Status - The error code return - Status PrescanWorkerEntry(int32_t worker_id); - - // Main Loop of ImageFolderOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // Method derived from RandomAccess Op, enable Sampler to get all ids for each class - // @param (std::map> * map - key label, val all ids for this class - // @return Status - The error code return - Status GetClassIds(std::map> *cls_ids) const override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // This function is a hack! It is to return the num_class and num_rows. The result - // returned by this function may not be consistent with what image_folder_op is going to return - // user this at your own risk! - static Status CountRowsAndClasses(const std::string &path, const std::set &exts, int64_t *num_rows, - int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1); - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ImageFolderOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to a pair - // @param row_id_type row_id - id for this tensor row - // @param ImageLabelPair pair - - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // @param std::string & dir - dir to walk all images - // @param int64_t * cnt - number of non folder files under the current dir - // @return - Status RecursiveWalkFolder(Path *dir); - - // start walking of all dirs - // @return - Status startAsyncWalk(); - - // Called first when function is called - // @return - Status LaunchThreadsAndInitOp(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t rows_per_buffer_; - std::string folder_path_; // directory of image folder - bool recursive_; - bool decode_; - std::set extensions_; // extensions allowed - std::map class_index_; - std::unique_ptr data_schema_; - int64_t row_cnt_; - int64_t buf_cnt_; - int64_t sampler_ind_; - int64_t dirname_offset_; - WaitPost wp_; - std::vector image_label_pairs_; - QueueList> io_block_queues_; // queues of IOBlocks - std::unique_ptr> folder_name_queue_; - std::unique_ptr> image_name_queue_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.cc deleted file mode 100644 index 0963f1a67a..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/io_block.h" - -#include -#include - -namespace mindspore { -namespace dataset { -// IOBlock Class // - -// Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. -IOBlock::IOBlock(int64_t inKey, IOBlockFlags io_block_flags) : index_keys_(1, inKey), io_block_flags_(io_block_flags) {} - -// Constructor of the IOBlock (2) -IOBlock::IOBlock(const std::vector &in_keys, IOBlockFlags io_block_flags) : io_block_flags_(io_block_flags) { - index_keys_.insert(index_keys_.end(), in_keys.begin(), in_keys.end()); -} - -// Constructor of the IOBlock (3). A special IOBlock that is used for control messaging. -IOBlock::IOBlock(IOBlockFlags io_block_flags) : io_block_flags_(io_block_flags) {} - -// Fetches the first key from this block -Status IOBlock::GetKey(int64_t *out_key) const { - if (out_key == nullptr || index_keys_.empty()) { - RETURN_STATUS_UNEXPECTED("Failed to get the key from IOBlock"); - } - *out_key = index_keys_[0]; - return Status::OK(); -} - -// Fetches the list of keys from this block. -Status IOBlock::GetKeys(std::vector *out_keys) const { - if (out_keys == nullptr) { - RETURN_STATUS_UNEXPECTED("Output arg for GetKeys is null"); - } - *out_keys = index_keys_; // vector copy assign - return Status::OK(); -} - -// FilenameBlock derived class // - -// Constructor of the FilenameBlock (1) -FilenameBlock::FilenameBlock(int64_t key, int64_t start_offset, int64_t end_offset, IOBlockFlags io_block_flags) - : IOBlock(key, io_block_flags), start_offset_(start_offset), end_offset_(end_offset) {} - -// Constructor of the FilenameBlock (2). A special IOBlock that is used for control messaging. -FilenameBlock::FilenameBlock(IOBlockFlags io_block_flags) - : IOBlock(io_block_flags), start_offset_(kInvalidOffset), end_offset_(kInvalidOffset) {} - -// Gets the filename from the block using the provided index container -Status FilenameBlock::GetFilename(std::string *out_filename, const AutoIndexObj &index) const { - if (out_filename == nullptr) { - RETURN_STATUS_UNEXPECTED("Failed to get filename from FilenameBlock"); - } - - // a FilenameBlock only has one key. Call base class method to fetch that key - int64_t fetched_key; - RETURN_IF_NOT_OK(IOBlock::GetKey(&fetched_key)); - - // Do an index lookup using that key to get the filename. - auto r = index.Search(fetched_key); - if (r.second) { - auto &it = r.first; - *out_filename = it.value(); - } else { - RETURN_STATUS_UNEXPECTED("Could not find filename from index"); - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.h b/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.h deleted file mode 100644 index 87b417f027..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.h +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ - -#include -#include - -#include "dataset/util/auto_index.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// The IOBlock class is used to describe a "unit of work" that a storage leaf operator worker thread -// is responsible for acting on. -// The IOBlocks and it's derived classes abstracts a key-store and key-lookup interface where each -// block contains 1 to n keys, and the keys are used in conjunction with an index to provide the meta -// information for satisfying an IO request. -class IOBlock { - public: - enum IOBlockFlags : uint32_t { - kDeIoBlockNone = 0, - kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch - kDeIoBlockFlagEof = 1u << 1 // end of IOBlocks for entire program - }; - - // Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. - // @param inKey - A single key to add into the block - // @param io_block_flags - The flag setting for the block - IOBlock(int64_t inKey, IOBlockFlags io_block_flags); - - // Constructor of the IOBlock (2). - // @param in_keys - A vector of keys to add into the block - // @param io_block_flags - The flag setting for the block - IOBlock(const std::vector &in_keys, IOBlockFlags io_block_flags); - - // Constructor of the IOBlock (3). A special IOBlock that is used for control messaging. - // @param io_block_flags - The flag setting for the block - explicit IOBlock(IOBlockFlags io_block_flags); - - // Destructor - virtual ~IOBlock() = default; - - // Fetches the first key from the block. - // @note Only useful if you know the block only has 1 key. - // @return A copy of the first key from the block - // @return Status - The error code return - Status GetKey(int64_t *out_key) const; - - // Fetches the list of keys from this block. - // @param out_keys - A copy of the vector of keys from the block. - // @return Status - The error code return - Status GetKeys(std::vector *out_keys) const; - - // Does this block have the eoe flag turned on? - // @return T/F if the IOBlock is eoe - bool eoe() const { return static_cast(io_block_flags_) & static_cast(kDeIoBlockFlagEoe); } - - // Does this block have the eof flag turned on? - // @return T/F if the IOBlock is eof - bool eof() const { return static_cast(io_block_flags_) & static_cast(kDeIoBlockFlagEof); } - - // Adds a key to this block - // @param key - The key to add to this block - void AddKey(int64_t key) { index_keys_.push_back(key); } - - protected: - std::vector index_keys_; // keys used for lookups to the meta info for the data - IOBlockFlags io_block_flags_; -}; // class IOBlock - -const int64_t kInvalidOffset = -1; - -// The Filename block derived class implements a style of IO block where each block contains only a -// single key that maps to a filename. -class FilenameBlock : public IOBlock { - public: - // Constructor of the FilenameBlock (1) - // @param key - The key identifier that can be used to find the data for this block - // @param start_offset - Start offset - // @param end_offset - End offset - // @param io_block_flags - The flag setting for the block - FilenameBlock(int64_t key, int64_t start_offset, int64_t end_offset, IOBlockFlags io_block_flags); - - // Constructor of the FilenameBlock (2). A special IOBlock that is used for control messaging. - // @param io_block_flags - The flag setting for the block - explicit FilenameBlock(IOBlockFlags io_block_flags); - - // Destructor - ~FilenameBlock() = default; - - // Gets the filename from the block using the provided index container - // @param out_filename - The filename to add to the block - // @param index - The index to perform lookup against - // @return Status - The error code return - Status GetFilename(std::string *out_filename, const AutoIndexObj &index) const; - - // Get the start offset of file - // @return int64_t - Start offset - int64_t GetStartOffset() const { return start_offset_; } - - // Get the end offset of the file - // @return int64_t - Start offset - int64_t GetEndOffset() const { return end_offset_; } - - private: - int64_t start_offset_; - int64_t end_offset_; -}; // class TFBlock -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc deleted file mode 100644 index e65da8707b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc +++ /dev/null @@ -1,431 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/manifest_op.h" - -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status ManifestOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_file_, - builder_op_connector_size_, builder_decode_, builder_labels_to_read_, - std::move(builder_schema_), std::move(builder_sampler_), builder_usage_); - return Status::OK(); -} - -Status ManifestOp::Builder::SanityCheck() { - std::string err_msg; - err_msg += builder_file_.empty() ? "Manifest file is not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers smaller than 1\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, - const std::map &class_index, std::unique_ptr data_schema, - std::shared_ptr sampler, std::string usage) - : ParallelOp(num_works, queue_size, std::move(sampler)), - rows_per_buffer_(rows_per_buffer), - io_block_pushed_(0), - row_cnt_(0), - sampler_ind_(0), - data_schema_(std::move(data_schema)), - file_(file), - class_index_(class_index), - decode_(decode), - usage_(usage), - buf_cnt_(0) { - io_block_queues_.Init(num_workers_, queue_size); - (void)std::transform(usage_.begin(), usage_.end(), usage_.begin(), ::tolower); -} - -// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work -Status ManifestOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - return AddIoBlock(&sampler_buffer); -} - -Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { - while (true) { // each iterator is 1 epoch - std::vector keys; - keys.reserve(rows_per_buffer_); - while (!(*sampler_buffer)->eoe()) { - TensorRow sample_row; - RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row)); - std::shared_ptr sample_ids = sample_row[0]; - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) >= num_rows_) continue; // index out of bound, skipping - keys.push_back(*itr); - row_cnt_++; - if (row_cnt_ % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - keys.clear(); - } - } - RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); - } - if (keys.empty() == false) { - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); - } - } -} - -Status ManifestOp::LaunchThreadsAndInitOp() { - if (tree_ == nullptr) { - RETURN_STATUS_UNEXPECTED("tree_ not set"); - } - RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); - - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&ManifestOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(ParseManifestFile()); - RETURN_IF_NOT_OK(CountDatasetInfo()); - RETURN_IF_NOT_OK(InitSampler()); - return Status::OK(); -} - -// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ -// IMPORTANT: 1 IOBlock produces 1 DataBuffer -Status ManifestOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty()) { - return Status::OK(); // empty key is a quit signal for workers - } - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer -Status ManifestOp::LoadTensorRow(row_id_type row_id, const std::pair> &data, - TensorRow *trow) { - std::shared_ptr image; - std::shared_ptr label; - std::vector label_index(data.second.size()); - (void)std::transform(data.second.begin(), data.second.end(), label_index.begin(), - [this](const std::string &label_name) { return label_index_[label_name]; }); - if (label_index.size() == 1) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), TensorShape({}), - data_schema_->column(1).type(), - reinterpret_cast(&label_index[0]))); - } else { - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &label, data_schema_->column(1).tensorImpl(), TensorShape(std::vector(1, label_index.size())), - data_schema_->column(1).type(), reinterpret_cast(&label_index[0]))); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data.first)); - if (decode_ == true) { - Status rc = Decode(image, &image); - if (rc.IsError()) { - std::string err = "Fail to decode image:" + data.first; - RETURN_STATUS_UNEXPECTED(err); - } - } - (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); - return Status::OK(); -} - -// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer -Status ManifestOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - for (const auto &key : keys) { - TensorRow trow; - RETURN_IF_NOT_OK(LoadTensorRow(key, image_labelname_[static_cast(key)], &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -void ManifestOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows:" << num_rows_ << "\nManifest file: " << file_ << "\n\n"; - } -} - -// Reset Sampler and wakeup Master thread (functor) -Status ManifestOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done - return Status::OK(); -} - -// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows -Status ManifestOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); -} - -// Derived from RandomAccessOp -Status ManifestOp::GetClassIds(std::map> *cls_ids) const { - if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) { - RETURN_STATUS_UNEXPECTED("Class indexing is invalid."); - } - - for (size_t i = 0; i < image_labelname_.size(); i++) { - size_t image_index = i; - for (size_t j = 0; j < image_labelname_[image_index].second.size(); j++) { - std::string label_name = (image_labelname_[image_index].second)[j]; - int32_t label_index = label_index_.at(label_name); - (*cls_ids)[label_index].emplace_back(image_index); - } - } - - for (auto &pair : (*cls_ids)) { - pair.second.shrink_to_fit(); - } - return Status::OK(); -} - -// Manifest file content -// {"source": "/path/to/image1.jpg", "usage":"train", annotation": ...} -// {"source": "/path/to/image2.jpg", "usage":"eval", "annotation": ...} -Status ManifestOp::ParseManifestFile() { - std::ifstream file_handle(file_); - if (!file_handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Manifest file " + file_ + " can not open."); - } - std::string line; - while (getline(file_handle, line)) { - try { - nlohmann::json js = nlohmann::json::parse(line); - std::string image_file_path = js.value("source", ""); - // If image is not JPEG/PNG/GIF/BMP, drop it - bool valid = false; - RETURN_IF_NOT_OK(CheckImageType(image_file_path, &valid)); - if (!valid) { - continue; - } - std::string usage = js.value("usage", ""); - (void)std::transform(usage.begin(), usage.end(), usage.begin(), ::tolower); - if (usage != usage_) { - continue; - } - std::vector labels; - nlohmann::json annotations = js.at("annotation"); - for (nlohmann::json::iterator it = annotations.begin(); it != annotations.end(); ++it) { - nlohmann::json annotation = it.value(); - std::string label_name = annotation.value("name", ""); - if (label_name == "") { - file_handle.close(); - RETURN_STATUS_UNEXPECTED("Label name is not found in manifest file for " + image_file_path); - } - if (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) { - if (label_index_.find(label_name) == label_index_.end()) { - label_index_[label_name] = 0; - } - labels.emplace_back(label_name); - } - } - if (!labels.empty()) { - image_labelname_.emplace_back(std::make_pair(image_file_path, labels)); - } - } catch (const std::exception &err) { - file_handle.close(); - RETURN_STATUS_UNEXPECTED("Parse manifest file failed"); - } - } - file_handle.close(); - - return Status::OK(); -} - -// Only support JPEG/PNG/GIF/BMP -Status ManifestOp::CheckImageType(const std::string &file_name, bool *valid) { - std::ifstream file_handle; - constexpr int read_num = 3; - *valid = false; - file_handle.open(file_name, std::ios::binary | std::ios::in); - if (!file_handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Can not open image file " + file_name); - } - unsigned char file_type[read_num]; - (void)file_handle.read(reinterpret_cast(file_type), read_num); - - if (file_handle.fail()) { - file_handle.close(); - RETURN_STATUS_UNEXPECTED("Read image file failed " + file_name); - } - file_handle.close(); - if (file_type[0] == 0xff && file_type[1] == 0xd8 && file_type[2] == 0xff) { - // Normal JPEGs start with \xff\xd8\xff\xe0 - // JPEG with EXIF stats with \xff\xd8\xff\xe1 - // Use \xff\xd8\xff to cover both. - *valid = true; - } else if (file_type[0] == 0x89 && file_type[1] == 0x50 && file_type[2] == 0x4e) { - // It's a PNG - *valid = true; - } else if (file_type[0] == 0x47 && file_type[1] == 0x49 && file_type[2] == 0x46) { - // It's a GIF - *valid = true; - } else if (file_type[0] == 0x42 && file_type[1] == 0x4d) { - // It's a BMP - *valid = true; - } - return Status::OK(); -} - -Status ManifestOp::CountDatasetInfo() { - int32_t index = 0; - for (auto &label : label_index_) { - label.second = class_index_.empty() ? index : class_index_[label.first]; - index++; - } - - num_rows_ = static_cast(image_labelname_.size()); - if (num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API " - "validation first."); - } - return Status::OK(); -} - -Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, - int64_t *count, int64_t *numClasses) { - // the logic of counting the number of samples is copied from ParseManifestFile() - std::map map; - for (auto p : dict) { - (void)map.insert(std::pair(py::reinterpret_borrow(p.first), - py::reinterpret_borrow(p.second))); - } - - std::shared_ptr op; - *count = 0; - RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op)); - RETURN_IF_NOT_OK(op->ParseManifestFile()); - *numClasses = static_cast(op->label_index_.size()); - *count = static_cast(op->image_labelname_.size()); - return Status::OK(); -} - -Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, - std::map *output_class_indexing) { - std::map input_class_indexing; - for (auto p : dict) { - (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), - py::reinterpret_borrow(p.second))); - } - - if (!input_class_indexing.empty()) { - *output_class_indexing = input_class_indexing; - } else { - std::shared_ptr op; - RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op)); - RETURN_IF_NOT_OK(op->ParseManifestFile()); - RETURN_IF_NOT_OK(op->CountDatasetInfo()); - uint32_t count = 0; - for (const auto label : op->label_index_) { - (*output_class_indexing).insert(std::make_pair(label.first, count)); - count++; - } - } - - return Status::OK(); -} - -Status ManifestOp::ComputeColMap() { - // Set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h deleted file mode 100644 index c180ea581d..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h +++ /dev/null @@ -1,244 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/queue.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -class ManifestOp : public ParallelOp, public RandomAccessOp { - public: - class Builder { - public: - // Constructor for Builder class of ManifestOp - Builder(); - - // Destructor - ~Builder() = default; - - // Setter method - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method - // @param int32_t size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t size) { - builder_op_connector_size_ = size; - return *this; - } - - // Setter method - // @param const std::map& map - a class name to label map - // @return - Builder &SetClassIndex(const std::map &map) { - builder_labels_to_read_ = map; - return *this; - } - - // Setter method - // @param bool do_decode - // @return Builder setter method returns reference to the builder. - Builder &SetDecode(bool do_decode) { - builder_decode_ = do_decode; - return *this; - } - - // Setter method - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method - // @param const std::string & dir - // @return Builder setter method returns reference to the builder. - Builder &SetManifestFile(const std::string &file) { - builder_file_ = file; - return *this; - } - - // Setter method - // @param const std::string & dir - // @return Builder setter method returns reference to the builder. - Builder &SetUsage(const std::string &usage) { - builder_usage_ = usage; - return *this; - } - - // Check validity of input args - // @return Status - The error code return - Status SanityCheck(); - - // The builder "build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - std::shared_ptr builder_sampler_; - bool builder_decode_; - - std::string builder_file_; - int32_t builder_num_workers_; - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::unique_ptr builder_schema_; - std::string builder_usage_; - std::map builder_labels_to_read_; - }; - - // Constructor - // @param int32_t num_works - Num of workers reading images in parallel - // @param int32_t - rows_per_buffer Number of images (rows) in each buffer - // @param std::string - file list of Manifest - // @param int32_t queue_size - connector queue size - // @param td::unique_ptr sampler - sampler tells ImageFolderOp what to read - ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, - const std::map &class_index, std::unique_ptr data_schema, - std::shared_ptr sampler, std::string usage); - // Destructor. - ~ManifestOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t worker_id - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of ManifestOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // Method derived from RandomAccess Op, enable Sampler to get all ids for each class - // @param (std::map> * map - key label, val all ids for this class - // @return Status - The error code return - Status GetClassIds(std::map> *cls_ids) const override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count, - int64_t *numClasses); - - // Get str-to-int mapping from label name to index - static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, - std::map *output_class_indexing); - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ManifestOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Method in operator(), to fill IOBlockQueue - // @param std::unique_ptr sampler_buffer - to fill IOBlockQueue - // @return Status - The error code return - Status AddIoBlock(std::unique_ptr *sampler_buffer); - - // Load a tensor row according to a pair - // @param row_id_type row_id - id for this tensor row - // @param std::pair> - > - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const std::pair> &data, - TensorRow *row); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Parse manifest file to get image path and label and so on. - // @return Status - The error code return - Status ParseManifestFile(); - - // Called first when function is called - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Check if image ia valid.Only support JPEG/PNG/GIF/BMP - // @return - Status CheckImageType(const std::string &file_name, bool *valid); - - // Count label index,num rows and num samples - // @return Status - The error code return - Status CountDatasetInfo(); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t rows_per_buffer_; - int64_t io_block_pushed_; - int64_t row_cnt_; - int64_t sampler_ind_; - std::unique_ptr data_schema_; - std::string file_; // file that store the information of images - std::map class_index_; - bool decode_; - std::string usage_; - int64_t buf_cnt_; - - WaitPost wp_; - QueueList> io_block_queues_; - std::map label_index_; - std::vector>> image_labelname_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc deleted file mode 100644 index 2b9d010ebb..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ /dev/null @@ -1,513 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/mindrecord_op.h" - -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -using mindrecord::kInt64Len; -using mindrecord::MSRStatus; -using mindrecord::Schema; -using mindrecord::ShardOperator; -using mindrecord::ShardReader; - -// Builder constructor. Creates the builder object. -MindRecordOp::Builder::Builder() : build_dataset_file_({}) { - // Some arguments to the MindRecordOp constructor have a default argument that is taken - // from the client config. - // The user may choose to change these values for the construction of the MindRecordOp by - // using the various builder set methods. - - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_mind_record_workers_ = kDefaultMindRecordWorkers; - build_rows_per_buffer_ = cfg->rows_per_buffer(); - build_op_connector_queue_size_ = cfg->op_connector_size(); - build_block_reader_ = false; - builder_num_workers_ = 0; - build_num_padded_ = 0; - build_sample_ = nullptr; -} - -// The builder "build" method creates the final object. -Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { - std::shared_ptr new_mind_record_op; - - if (build_dataset_file_.empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Building a MindRecordOp that has not provided a file."); - } - mindrecord::json sample_json; - if (build_num_padded_ > 0) { - sample_json = ToJson(build_sample_); - } - new_mind_record_op = std::make_shared( - build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_, - build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_, build_num_padded_, - sample_json, build_sample_bytes_); - - RETURN_IF_NOT_OK(new_mind_record_op->Init()); - *ptr = std::move(new_mind_record_op); - return Status::OK(); -} - -Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); } - -mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) { - if (obj.is_none()) { - return nullptr; - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { // also catch py::bytes - return obj.cast(); - } - if (py::isinstance(obj)) { - auto out = mindrecord::json::object(); - for (const py::handle &key : obj) { - if (py::isinstance(obj[key])) { - build_sample_bytes_[py::str(key).cast()] = obj[key].cast(); - } else { - out[py::str(key).cast()] = ToJson(obj[key]); - } - } - return out; - } - MS_LOG(ERROR) << "Python object convert to json failed, object is: " << py::cast(obj); - return mindrecord::json(); -} - -// Constructor of the MindRecordOp. -MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, - std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, - const std::vector &columns_to_load, - const std::vector> &operators, const bool &block_reader, - int64_t num_padded, const mindrecord::json &sample_json, - const std::map &sample_bytes) - : ParallelOp(num_mind_record_workers, op_connector_queue_size), - rows_per_buffer_(rows_per_buffer), - dataset_file_(dataset_file), - load_dataset_(load_dataset), - columns_to_load_(columns_to_load), - operators_(operators), - num_mind_record_workers_(num_mind_record_workers), - block_reader_(block_reader), - num_rows_(0), - buffers_needed_(0), - buf_cnt_(0), - ended_worker_(0), - buffer_water_mark_(0), - num_padded_(num_padded), - sample_json_(sample_json), - sample_bytes_(sample_bytes) { - io_blk_queues_.Init(num_workers_, op_connector_queue_size); - if (!block_reader_) return; - for (int32_t i = 0; i < num_workers_; ++i) { - block_buffer_.emplace_back(std::make_unique>(std::vector{})); - } -} - -// Private helper method to encapsulate some common construction/reset tasks -Status MindRecordOp::Init() { - shard_reader_ = std::make_unique(); - auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, - block_reader_, num_padded_); - - CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, - "MindRecordOp init failed. Error message: " + ErrnoToMessage(rc)); - - data_schema_ = std::make_unique(); - - std::vector col_names = shard_reader_->GetShardColumn()->GetColumnName(); - CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found"); - std::vector col_data_types = shard_reader_->GetShardColumn()->GeColumnDataType(); - std::vector> col_shapes = shard_reader_->GetShardColumn()->GetColumnShape(); - - bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything - std::map colname_to_ind; - for (uint32_t i = 0; i < col_names.size(); i++) { - std::string colname = col_names[i]; - ColDescriptor col_desc; - - TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown - std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]]; - DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"} - - if (col_data_types[i] == mindrecord::ColumnBytes) { // rank = 1 - col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1); - } else if (col_data_types[i] == mindrecord::ColumnString) { // rank = 0 - col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 0); - } else if (col_shapes[i].size() > 0) { - std::vector vec(col_shapes[i].size()); // temporary vector to hold shape - (void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin()); - t_shape = TensorShape(vec); - col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); - } else { // unknown shape - // create colDesc and add it to schema - col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); - } - - colname_to_ind[colname] = data_schema_->NumColumns(); - RETURN_IF_NOT_OK(data_schema_->AddColumn(col_desc)); - - if (load_all_cols) { - columns_to_load_.emplace_back(colname); - } - } - - if (!load_all_cols) { - std::unique_ptr tmp_schema = std::make_unique(); - for (std::string colname : columns_to_load_) { - CHECK_FAIL_RETURN_UNEXPECTED(colname_to_ind.find(colname) != colname_to_ind.end(), colname + ": doesn't exist"); - RETURN_IF_NOT_OK(tmp_schema->AddColumn(data_schema_->column(colname_to_ind[colname]))); - } - data_schema_ = std::move(tmp_schema); - } - - return Status::OK(); -} - -// Destructor -MindRecordOp::~MindRecordOp() {} - -// A print method typically used for debugging -void MindRecordOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\n Dataset file : "; - for (auto &file : dataset_file_) { - out << file << " "; - } - out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_ - << "\nNumber of buffers : " << buffers_needed_ - << "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n"; - } -} - -Status MindRecordOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe()) { - RETURN_IF_NOT_OK( - out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - continue; - } - if (io_block->eof()) { - RETURN_IF_NOT_OK( - out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - continue; - } - - // load data buffer - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty() == true) { - { - std::unique_lock lock(ended_worker_mutex_); - ended_worker_++; - if (ended_worker_ == num_workers_) shard_reader_->Close(); - } - return Status::OK(); // empty key is a quit signal for workers - } - - const uint64_t buffer_id = keys[0]; - std::unique_ptr fetched_buffer; - - // Get the next buffer. Push it up to the output connector. - if (buffer_id % LOG_INTERVAL == 0) { - MS_LOG(DEBUG) << "MindRecord operator consumed buffer " << buffer_id << " by worker " << worker_id << "."; - } - RETURN_IF_NOT_OK(GetBufferFromReader(&fetched_buffer, buffer_id, worker_id)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(fetched_buffer))); - if (!block_reader_) { - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - continue; - } - - // update block-reader buffer - block_buffer_[buffer_id % num_workers_]->clear(); - { - std::unique_lock lck(mtx_block_reader_); - if (buffer_id == buffer_water_mark_) { - buffer_water_mark_++; - while (block_set_.count(buffer_water_mark_) > 0) (void)block_set_.erase(buffer_water_mark_++); - } else { - (void)block_set_.insert(buffer_id); - } - } - cv_reader_.notify_one(); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, - int32_t worker_id) { - *fetched_buffer = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - std::unique_ptr tensor_table = std::make_unique(); - for (int32_t i = 0; i < rows_per_buffer_; ++i) { - ShardTuple tupled_buffer; - mindrecord::TaskType task_type = mindrecord::TaskType::kCommonTask; - if (block_reader_) { - if (i >= block_buffer_[buffer_id % num_workers_]->size()) break; - tupled_buffer = block_buffer_[buffer_id % num_workers_]->at(i); - } else { - int32_t row_id = buffer_id * rows_per_buffer_ + i; - auto rc = shard_reader_->GetNextById(row_id, worker_id); - task_type = rc.first; - tupled_buffer = rc.second; - if (task_type == mindrecord::TaskType::kPaddedTask) { - TensorRow tensor_row; - RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, {}, mindrecord::json(), task_type)); - tensor_table->push_back(std::move(tensor_row)); - } - if (tupled_buffer.empty()) break; - } - if (task_type == mindrecord::TaskType::kCommonTask) { - for (const auto &tupled_row : tupled_buffer) { - std::vector columns_blob = std::get<0>(tupled_row); - mindrecord::json columns_json = std::get<1>(tupled_row); - TensorRow tensor_row; - RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json, task_type)); - tensor_table->push_back(std::move(tensor_row)); - } - } - } - - // Replace the TensorTable in DataBuffer with the new one. - (*fetched_buffer)->set_tensor_table(std::move(tensor_table)); - return Status::OK(); -} - -Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, - const mindrecord::json &columns_json, const mindrecord::TaskType task_type) { - for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) { - auto column_name = columns_to_load_[i_col]; - - // Initialize column parameters - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0; - mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType; - uint64_t column_data_type_size = 1; - std::vector column_shape; - - // Get column data - auto shard_column = shard_reader_->GetShardColumn(); - if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { - auto rc = - shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape); - if (rc.first != MSRStatus::SUCCESS) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve data type."); - } - if (rc.second == mindrecord::ColumnInRaw) { - auto has_column = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes); - if (has_column == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve raw data from padding sample."); - } - } else if (rc.second == mindrecord::ColumnInBlob) { - if (sample_bytes_.find(column_name) == sample_bytes_.end()) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve blob data from padding sample."); - } - std::string ss(sample_bytes_[column_name]); - n_bytes = ss.size(); - data_ptr = std::make_unique(n_bytes); - std::copy(ss.begin(), ss.end(), data_ptr.get()); - } else { - RETURN_STATUS_UNEXPECTED("Retrieved data type is unknown."); - } - if (data == nullptr) { - data = reinterpret_cast(data_ptr.get()); - } - } else { - auto has_column = - shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, - &column_data_type, &column_data_type_size, &column_shape); - if (has_column == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader."); - } - } - - std::shared_ptr tensor; - const ColDescriptor &column = data_schema_->column(i_col); - DataType type = column.type(); - - // Set shape - auto num_elements = n_bytes / column_data_type_size; - if (type == DataType::DE_STRING) { - std::string s{data, data + n_bytes}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {s}, TensorShape::CreateScalar())); - } else if (column.hasShape()) { - auto new_shape = TensorShape(column.shape()); - RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast(num_elements), &new_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); - } else { - std::vector shapeDetails = {static_cast(num_elements)}; - auto new_shape = TensorShape(shapeDetails); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); - } - tensor_row->push_back(std::move(tensor)); - } - return Status::OK(); -} - -Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { - { - std::unique_lock lck(mtx_block_reader_); - cv_reader_.wait(lck, [buffer_id, this] { return buffer_id < buffer_water_mark_ + num_workers_; }); - } - for (int32_t i = 0; i < rows_per_buffer_; i++) { - // Block reader does NOT care about argument - auto rc = shard_reader_->GetNextById(i, i); - ShardTuple tuple_buffer = rc.second; - if (tuple_buffer.empty()) break; - block_buffer_[buffer_id % num_workers_]->push_back(std::move(tuple_buffer)); - } - return Status::OK(); -} - -// Class functor operator () override. -// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will -// provide the master loop that drives the logic for performing the work -// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work -Status MindRecordOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadAndInitOp()); - num_rows_ = shard_reader_->GetNumRows(); - // Compute how many buffers we would need to accomplish rowsPerBuffer - buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; - - while (true) { // each iterator is 1 epoch - for (int32_t i = 0; i < buffers_needed_; ++i) { - if (block_reader_) RETURN_IF_NOT_OK(FetchBlockBuffer(i)); - std::vector keys(1, i); - RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK(io_blk_queues_[i]->Add( - std::move(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone)))); - } - return Status::OK(); - } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset - RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - - // reset our buffer count and go to loop again. - RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); - shard_reader_wait_post_.Clear(); - } - } -} - -// Overrides base class reset method. When an operator does a reset, it cleans up any state -// info from it's previous execution and then initializes itself so that it can be executed -// again. -Status MindRecordOp::Reset() { - RETURN_IF_NOT_OK(ParallelOp::Reset()); // Call our super class reset first. - - if (block_reader_) { - shard_reader_->Reset(); - buffer_water_mark_ = 0; - } else { - shard_reader_->ShuffleTask(); - } - shard_reader_wait_post_.Set(); - - return Status::OK(); -} - -Status MindRecordOp::LaunchThreadAndInitOp() { - if (tree_ == nullptr) { - RETURN_STATUS_UNEXPECTED("tree_ not set"); - } - - RETURN_IF_NOT_OK(io_blk_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(shard_reader_wait_post_.Register(tree_->AllTasks())); - if (shard_reader_->Launch(!block_reader_) == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); - } - // Launch main workers that load DataBuffers by reading all images - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - return Status::OK(); -} - -Status MindRecordOp::CountTotalRows(const std::vector dataset_path, bool load_dataset, - const std::shared_ptr &op, int64_t *count, int64_t num_padded) { - std::unique_ptr shard_reader = std::make_unique(); - MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded); - if (rc == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); - } - return Status::OK(); -} - -// Visitor accept method for NodePass -Status MindRecordOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status MindRecordOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - for (int i = 0; i < static_cast(columns_to_load_.size()); i++) { - column_name_id_map_[columns_to_load_[i]] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h deleted file mode 100644 index af405a8f5b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ /dev/null @@ -1,276 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/util/queue.h" -#include "dataset/util/status.h" -#include "mindrecord/include/shard_column.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/common/shard_utils.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// Forward declares -template -class Queue; -class DataBuffer; - -using mindrecord::ShardOperator; -using mindrecord::ShardReader; -using ShardTuple = std::vector, mindrecord::json>>; // Row of data from ShardReader - -const int32_t LOG_INTERVAL = 19; - -class MindRecordOp : public ParallelOp { - public: - // The nested builder class inside of the MindRecordOp is used to help manage all of the arguments - // for constructing it. Use the builder by setting each argument with the provided set methods, - // and then finally call the build method to execute the actual construction. - class Builder { - public: - Builder(); - - ~Builder() = default; - - Status Build(std::shared_ptr *); - - Builder &SetRowsPerBuffer(int rows_per_buffer) { - build_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - Builder &SetNumMindRecordWorkers(int32_t num_mind_record_workers) { - build_num_mind_record_workers_ = num_mind_record_workers; - return *this; - } - - Builder &SetOpConnectorQueueSize(int32_t queue_size) { - build_op_connector_queue_size_ = queue_size; - return *this; - } - - Builder &SetDatasetFile(const std::vector &files) { - build_dataset_file_ = files; - return *this; - } - - Builder &SetColumnsToLoad(const std::vector &columns) { - build_columns_to_load_ = columns; - return *this; - } - - Builder &SetOperators(const std::vector> &operators) { - build_operators_ = operators; - return *this; - } - - Builder &SetBlockReader() { - build_block_reader_ = true; - return *this; - } - - Builder &SetLoadDataset(bool load_dataset) { - build_load_dataset_ = load_dataset; - return *this; - } - - Builder &SetNumToPadSamples(int64_t num_padded) { - build_num_padded_ = num_padded; - return *this; - } - - Builder &SetPaddedSample(const py::handle &sample) { - build_sample_ = sample; - return *this; - } - - Status SanityCheck() const; - - static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; } - - mindrecord::json ToJson(const py::handle &obj); - - private: - static constexpr int32_t kDefaultMindRecordWorkers = 4; - // The builder saves all MindRecordOp construction arguments internally. - // The following are the arguments. - int32_t build_num_mind_record_workers_; - int32_t builder_num_workers_; - int32_t build_rows_per_buffer_; - int32_t build_op_connector_queue_size_; - std::vector build_dataset_file_; - bool build_load_dataset_; - std::vector build_columns_to_load_; - std::vector> build_operators_; - bool build_block_reader_; - int64_t build_num_padded_; - py::handle build_sample_; - std::map build_sample_bytes_; - }; - - // Constructor of the MindRecordOp. - // @note The builder class should be used to call it - // @param num_mind_record_workers - The number of workers for the op (run by ShardReader) - // @param rows_per_buffer - The requested number of rows per buffer - // @param dataset_file - dataset files - // @param op_connector_queue_size - The output connector queue size - // @param columns_to_load - The list of columns to use (column name) - // @param operators - ShardOperators for Shuffle, Category, Sample - MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, - bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, - const std::vector> &operators, const bool &block_reader, - int64_t num_padded_, const mindrecord::json &sample_json, - const std::map &sample_bytes_); - - // Destructor - ~MindRecordOp() override; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param op - reference to the MindRecordOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const MindRecordOp &op) { - op.Print(out, false); - return out; - } - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t workerId - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Class functor operator () override. - // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work. - // @return Status - The error code return - Status operator()() override; - - // Called first when function is called - // @return - Status LaunchThreadAndInitOp(); - - // Overrides base class reset method. When an operator does a reset, it cleans up any state - // info from it's previous execution and then initializes itself so that it can be executed - // again. - // @return Status - The error code return - Status Reset() override; - - // Getter method - int32_t num_rows() const { return num_rows_; } - - static Status CountTotalRows(const std::vector dataset_path, bool load_dataset, - const std::shared_ptr &op, int64_t *count, int64_t num_padded); - - // Getter method - int32_t rows_per_buffer() const { return rows_per_buffer_; } - - // Getter method - std::vector dataset_file() const { return dataset_file_; } - - // Getter method - std::vector columns_to_load() const { return columns_to_load_; } - - bool block_reader() const { return block_reader_; } - - bool load_dataset() const { return load_dataset_; } - - Status Init(); - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "MindRecordOp"; } - - private: - Status GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, int32_t worker_id); - - // Parses a single cell and puts the data into a tensor - // @param tensor_row - the tensor row to put the parsed data in - // @param columns_blob - the blob data received from the reader - // @param columns_json - the data for fields received from the reader - Status LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, - const mindrecord::json &columns_json, const mindrecord::TaskType task_type); - - Status FetchBlockBuffer(const int32_t &buffer_id); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t rows_per_buffer_; // The number of requested rows per buffer. - std::vector dataset_file_; // dataset files - bool load_dataset_; // load dataset from single file or not - std::vector columns_to_load_; // Columns to load from dataset - std::vector> operators_; // ShardOperators to use - int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader - bool block_reader_; // block reader switch - int32_t buffers_needed_; // Counter for the buffers that were fetched - int64_t buf_cnt_; // Buffer counter - int32_t num_rows_; // One more than the last row id in the range for this cache - std::atomic ended_worker_; - std::atomic buffer_water_mark_; - - int64_t num_padded_; - mindrecord::json sample_json_; - std::map sample_bytes_; - - std::unique_ptr data_schema_; // Data schema for column typing - std::vector columns_blob_; // Blob Columns to load from dataset - std::vector columns_blob_index_; // Blob Columns to load from dataset - - std::unique_ptr shard_reader_; - WaitPost shard_reader_wait_post_; - QueueList> io_blk_queues_; - - // For block reader - std::mutex mtx_block_reader_; - std::condition_variable cv_reader_; - std::vector>> block_buffer_; - std::unordered_set block_set_; - - std::mutex ended_worker_mutex_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc deleted file mode 100644 index e98f8ae8c1..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc +++ /dev/null @@ -1,443 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/mnist_op.h" - -#include -#include -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -const int32_t kMnistImageFileMagicNumber = 2051; -const int32_t kMnistLabelFileMagicNumber = 2049; -const int32_t kMnistImageRows = 28; -const int32_t kMnistImageCols = 28; - -MnistOp::Builder::Builder() : builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status MnistOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, - builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_)); - return Status::OK(); -} - -Status MnistOp::Builder::SanityCheck() { - Path dir(builder_dir_); - std::string err_msg; - err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size, std::move(sampler)), - buf_cnt_(0), - row_cnt_(0), - folder_path_(folder_path), - rows_per_buffer_(rows_per_buffer), - data_schema_(std::move(data_schema)) { - io_block_queues_.Init(num_workers, queue_size); -} - -Status MnistOp::TraversalSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) >= num_rows_) continue; // index out of bound, skipping - keys->push_back(*itr); - row_cnt_++; - if (row_cnt_ % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); - keys->clear(); - } - } - return Status::OK(); -} - -// functor that contains the main logic of MNIST op -Status MnistOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - while (true) { // each iterator is 1 epoch - std::vector keys; - keys.reserve(rows_per_buffer_); - while (sampler_buffer->eoe() == false) { - std::shared_ptr sample_ids; - RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); - if (sample_ids->type() != DataType(DataType::DE_INT64)) { - RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't UINT64"); - } - RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys)); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - if (keys.empty() == false) { - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); - for (int32_t i = 0; i < num_workers_; ++i) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - } -} - -// contains the logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ -Status MnistOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr iOBlock; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); - while (iOBlock != nullptr) { - if (iOBlock->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (iOBlock->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(iOBlock->GetKeys(&keys)); - if (keys.empty() == true) return Status::OK(); // empty key is a quit signal for workers - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -// Load 1 TensorRow (image,label) using 1 MnistLabelPair. -Status MnistOp::LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *trow) { - std::shared_ptr image, label; - int32_t l = mnist_pair.second; - // make a copy of cached tensor - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), mnist_pair.first->shape(), - mnist_pair.first->type(), mnist_pair.first->GetBuffer())); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), reinterpret_cast(&l))); - (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); - return Status::OK(); -} - -// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer -Status MnistOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - TensorRow trow; - for (const int64_t &key : keys) { - RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_label_pairs_[key], &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -void MnistOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows:" << num_rows_ << "\nMNIST Directory: " << folder_path_ << "\n\n"; - } -} - -// Reset Sampler and wakeup Master thread (functor) -Status MnistOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done - return Status::OK(); -} - -// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows -Status MnistOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); -} - -// Derived from RandomAccessOp -Status MnistOp::GetClassIds(std::map> *cls_ids) const { - if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { - RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); - } - for (size_t i = 0; i < image_label_pairs_.size(); ++i) { - (*cls_ids)[image_label_pairs_[i].second].push_back(i); - } - for (auto &pair : (*cls_ids)) { - pair.second.shrink_to_fit(); - } - return Status::OK(); -} - -Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) { - uint32_t res = 0; - reader->read(reinterpret_cast(&res), 4); - if (reader->fail()) { - RETURN_STATUS_UNEXPECTED("Failed to read 4 bytes from file"); - } - *result = SwapEndian(res); - return Status::OK(); -} - -uint32_t MnistOp::SwapEndian(uint32_t val) const { - val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); - return (val << 16) | (val >> 16); -} - -Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) { - if (image_reader->is_open() == false) { - RETURN_STATUS_UNEXPECTED("Cannot open mnist image file: " + file_name); - } - int64_t image_len = image_reader->seekg(0, std::ios::end).tellg(); - (void)image_reader->seekg(0, std::ios::beg); - // The first 16 bytes of the image file are type, number, row and column - if (image_len < 16) { - RETURN_STATUS_UNEXPECTED("Mnist file is corrupted."); - } - uint32_t magic_number; - RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number)); - CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber, - "This is not the mnist image file: " + file_name); - - uint32_t num_items; - RETURN_IF_NOT_OK(ReadFromReader(image_reader, &num_items)); - uint32_t rows; - RETURN_IF_NOT_OK(ReadFromReader(image_reader, &rows)); - uint32_t cols; - RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols)); - // The image size of the Mnist dataset is fixed at [28,28] - if ((rows != kMnistImageRows) || (cols != kMnistImageCols)) { - RETURN_STATUS_UNEXPECTED("Wrong shape of image."); - } - if ((image_len - 16) != num_items * rows * cols) { - RETURN_STATUS_UNEXPECTED("Wrong number of image."); - } - *num_images = num_items; - return Status::OK(); -} - -Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) { - if (label_reader->is_open() == false) { - RETURN_STATUS_UNEXPECTED("Cannot open mnist label file: " + file_name); - } - int64_t label_len = label_reader->seekg(0, std::ios::end).tellg(); - (void)label_reader->seekg(0, std::ios::beg); - // The first 8 bytes of the image file are type and number - if (label_len < 8) { - RETURN_STATUS_UNEXPECTED("Mnist file is corrupted."); - } - uint32_t magic_number; - RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number)); - CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber, - "This is not the mnist label file: " + file_name); - uint32_t num_items; - RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items)); - if ((label_len - 8) != num_items) { - RETURN_STATUS_UNEXPECTED("Wrong number of labels!"); - } - *num_labels = num_items; - return Status::OK(); -} - -Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) { - uint32_t num_images, num_labels; - RETURN_IF_NOT_OK(CheckImage(image_names_[index], image_reader, &num_images)); - RETURN_IF_NOT_OK(CheckLabel(label_names_[index], label_reader, &num_labels)); - CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), "num_images != num_labels"); - // The image size of the Mnist dataset is fixed at [28,28] - int64_t size = kMnistImageRows * kMnistImageCols; - auto images_buf = std::make_unique(size * num_images); - auto labels_buf = std::make_unique(num_images); - if (images_buf == nullptr || labels_buf == nullptr) { - std::string err_msg = "Fail to allocate memory for MNIST Buffer."; - MS_LOG(ERROR) << err_msg.c_str(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)image_reader->read(images_buf.get(), size * num_images); - if (image_reader->fail()) { - RETURN_STATUS_UNEXPECTED("Fail to read:" + image_names_[index] + " size:" + std::to_string(size * num_images)); - } - (void)label_reader->read(labels_buf.get(), num_images); - if (label_reader->fail()) { - RETURN_STATUS_UNEXPECTED("Fail to read:" + label_names_[index] + " size: " + std::to_string(num_images)); - } - TensorShape img_tensor_shape = TensorShape({kMnistImageRows, kMnistImageCols, 1}); - for (int64_t j = 0; j != num_images; ++j) { - auto pixels = &images_buf[j * size]; - for (int64_t m = 0; m < size; ++m) { - pixels[m] = (pixels[m] == 0) ? 0 : 255; - } - std::shared_ptr image; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), img_tensor_shape, - data_schema_->column(0).type(), reinterpret_cast(pixels))); - image_label_pairs_.emplace_back(std::make_pair(image, labels_buf[j])); - } - return Status::OK(); -} - -Status MnistOp::ParseMnistData() { - for (size_t i = 0; i < image_names_.size(); ++i) { - std::ifstream image_reader, label_reader; - image_reader.open(image_names_[i], std::ios::binary); - label_reader.open(label_names_[i], std::ios::binary); - - Status s = ReadImageAndLabel(&image_reader, &label_reader, i); - // Close the readers - image_reader.close(); - label_reader.close(); - RETURN_IF_NOT_OK(s); - } - image_label_pairs_.shrink_to_fit(); - num_rows_ = image_label_pairs_.size(); - if (num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API " - "validation first."); - } - return Status::OK(); -} - -Status MnistOp::WalkAllFiles() { - const std::string kImageExtension = "idx3-ubyte"; - const std::string kLabelExtension = "idx1-ubyte"; - - Path dir(folder_path_); - auto dir_it = Path::DirIterator::OpenDirectory(&dir); - if (dir_it != nullptr) { - while (dir_it->hasNext()) { - Path file = dir_it->next(); - std::string filename = file.toString(); - if (filename.find(kImageExtension) != std::string::npos) { - image_names_.push_back(filename); - MS_LOG(INFO) << "Mnist operator found image file at " << filename << "."; - } else if (filename.find(kLabelExtension) != std::string::npos) { - label_names_.push_back(filename); - MS_LOG(INFO) << "Mnist Operator found label file at " << filename << "."; - } - } - } else { - MS_LOG(WARNING) << "Mnist operator unable to open directory " << dir.toString() << "."; - } - - std::sort(image_names_.begin(), image_names_.end()); - std::sort(label_names_.begin(), label_names_.end()); - - if (image_names_.size() != label_names_.size()) { - RETURN_STATUS_UNEXPECTED("num of images does not equal to num of labels"); - } - - return Status::OK(); -} - -Status MnistOp::LaunchThreadsAndInitOp() { - if (tree_ == nullptr) { - RETURN_STATUS_UNEXPECTED("tree_ not set"); - } - RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&MnistOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(this->WalkAllFiles()); - RETURN_IF_NOT_OK(this->ParseMnistData()); - RETURN_IF_NOT_OK(this->InitSampler()); // handle shake with sampler - return Status::OK(); -} - -Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) { - // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader() - std::shared_ptr op; - *count = 0; - RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op)); - - RETURN_IF_NOT_OK(op->WalkAllFiles()); - - for (size_t i = 0; i < op->image_names_.size(); ++i) { - std::ifstream image_reader; - image_reader.open(op->image_names_[i], std::ios::binary); - std::ifstream label_reader; - label_reader.open(op->label_names_[i], std::ios::binary); - - uint32_t num_images; - RETURN_IF_NOT_OK(op->CheckImage(op->image_names_[i], &image_reader, &num_images)); - uint32_t num_labels; - RETURN_IF_NOT_OK(op->CheckLabel(op->label_names_[i], &label_reader, &num_labels)); - CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), "num of images does not equal to num of labels"); - *count = *count + num_images; - - // Close the readers - image_reader.close(); - label_reader.close(); - } - - return Status::OK(); -} - -Status MnistOp::ComputeColMap() { - // set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h deleted file mode 100644 index 9bd6276a11..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h +++ /dev/null @@ -1,246 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// Forward declares -template -class Queue; - -using MnistLabelPair = std::pair, int32_t>; - -class MnistOp : public ParallelOp, public RandomAccessOp { - public: - class Builder { - public: - // Constructor for Builder class of MnistOp - // @param uint32_t numWrks - number of parallel workers - // @param dir - directory folder got ImageNetFolder - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method - // @param int32_t op_connector_size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method - // @param const std::string & dir - // @return - Builder &SetDir(const std::string &dir) { - builder_dir_ = dir; - return *this; - } - - // Check validity of input args - // @return - The error code return - Status SanityCheck(); - - // The builder "Build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - std::string builder_dir_; - int32_t builder_num_workers_; - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::shared_ptr builder_sampler_; - std::unique_ptr builder_schema_; - }; - - // Constructor - // @param int32_t num_workers - number of workers reading images in parallel - // @param int32_t rows_per_buffer - number of images (rows) in each buffer - // @param std::string folder_path - dir directory of mnist - // @param int32_t queue_size - connector queue size - // @param std::unique_ptr data_schema - the schema of the mnist dataset - // @param td::unique_ptr sampler - sampler tells MnistOp what to read - MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler); - - // Destructor. - ~MnistOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t worker_id - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of MnistOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // Method derived from RandomAccess Op, enable Sampler to get all ids for each class - // @param (std::map> * map - key label, val all ids for this class - // @return Status - The error code return - Status GetClassIds(std::map> *cls_ids) const override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // Function to count the number of samples in the MNIST dataset - // @param dir path to the MNIST directory - // @param count output arg that will hold the minimum of the actual dataset size and numSamples - // @return - static Status CountTotalRows(const std::string &dir, int64_t *count); - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "MnistOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to a pair - // @param row_id_type row_id - id for this tensor row - // @param ImageLabelPair pair - - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Iterate through all members in sampleIds and fill them into IOBlock. - // @param std::shared_ptr sample_ids - - // @param std::vector *keys - keys in ioblock - // @return Status - The error code return - Status TraversalSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); - - // Check image file stream. - // @param const std::string *file_name - image file name - // @param std::ifstream *image_reader - image file stream - // @param uint32_t num_images - returns the number of images - // @return Status - The error code return - Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); - - // Check label stream. - // @param const std::string &file_name - label file name - // @param std::ifstream *label_reader - label file stream - // @param uint32_t num_labels - returns the number of labels - // @return Status - The error code return - Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); - - // Read 4 bytes of data from a file stream. - // @param std::ifstream *reader - file stream to read - // @return uint32_t - read out data - Status ReadFromReader(std::ifstream *reader, uint32_t *result); - - // Swap endian - // @param uint32_t val - - // @return uint32_t - swap endian data - uint32_t SwapEndian(uint32_t val) const; - - // Read the specified number of images and labels from the file stream - // @param std::ifstream *image_reader - image file stream - // @param std::ifstream *label_reader - label file stream - // @param int64_t read_num - number of image to read - // @return Status - The error code return - Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); - - // Parse all mnist dataset files - // @return Status - The error code return - Status ParseMnistData(); - - // Read all files in the directory - // @return Status - The error code return - Status WalkAllFiles(); - - // Called first when function is called - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int64_t buf_cnt_; - int64_t row_cnt_; - WaitPost wp_; - std::string folder_path_; // directory of image folder - int32_t rows_per_buffer_; - std::unique_ptr data_schema_; - std::vector image_label_pairs_; - std::vector image_names_; - std::vector label_names_; - QueueList> io_block_queues_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc deleted file mode 100644 index 3a865d8d69..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc +++ /dev/null @@ -1,429 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/engine/datasetops/source/random_data_op.h" -#include -#include -#include "dataset/engine/execution_tree.h" -#include "dataset/core/config_manager.h" -#include "dataset/util/random.h" -#include "dataset/util/wait_post.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -RandomDataOp::Builder::Builder() - : builder_data_schema_(nullptr), - builder_num_workers_(0), - builder_op_connector_size_(0), - builder_rows_per_buffer_(0), - builder_total_rows_(0), - builder_sampler_(nullptr) { - // Some arguments to the RandomDataOp have a default argument that is taken from the config. - // The user may override these defaults by using the builder set methods. - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -// The build method that produces the instantiated RandomDataOp as a shared pointer -Status RandomDataOp::Builder::Build(std::shared_ptr *out_op) { - RETURN_IF_NOT_OK(SanityCheck()); - - *out_op = - std::make_shared(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, - builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_)); - - // If the user did not provide a schema, then we will ask the op to generate a pseudo-random - // schema. - // See details of generateSchema function to learn what type of schema it will create. - if ((*out_op)->data_schema_ == nullptr) { - RETURN_IF_NOT_OK((*out_op)->GenerateSchema()); - } - - return Status::OK(); -} - -// Check if the required parameters are set by the builder. -Status RandomDataOp::Builder::SanityCheck() const { - // There actually is no required arguments for the random data op at all. - // Some arguments are preset with global values from config, and if they are not given by the user - // then we create them randomly. Leaving this function here for consistency with other operators. - return Status::OK(); -} - -// Constructor for RandomDataOp -RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), - buffer_id_(0), - rows_per_buffer_(rows_per_buffer), - total_rows_(total_rows), - epoch_buffers_sent_(0), - guys_in_(0), - guys_out_(num_workers_), - eoe_worker_id_(0), - data_schema_(std::move(data_schema)) { - rand_gen_.seed(GetSeed()); // seed the random generator - // If total rows was not given, then randomly pick a number - if (total_rows_ == 0) { - total_rows_ = GenRandomInt(1, kMaxTotalRows); - } - // Everyone is already out from the sync area. - all_out_.Set(); -} - -// A print method typically used for debugging -void RandomDataOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [total rows: " << total_rows_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nTotal_rows: " << total_rows_ << "\nRows per buffer: " << rows_per_buffer_ << "\nSchema:\n" - << *data_schema_ << "\n\n"; - } -} - -// Helper function to produce a default/random schema if one didn't exist -Status RandomDataOp::GenerateSchema() { - if (data_schema_ != nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Generating a schema but one already exists!"); - } - - // To randomly create a schema, we need to choose: - // a) how many columns - // b) the type of each column - // c) the shape of each column (number of dimensions i.e. rank) - // d) the shape of each column (dimension values) - data_schema_ = std::make_unique(); - std::unique_ptr newShape; - std::unique_ptr newCol; - - // Loop over the number of chosen columns - int32_t numColumns = GenRandomInt(1, kMaxNumColumns); - for (int32_t i = 0; i < numColumns; i++) { - // For each column: - // - choose a datatype - // - generate a shape that randomly chooses the number of dimensions and the dimension values. - DataType::Type newType = static_cast(GenRandomInt(1, DataType::NUM_OF_TYPES - 2)); - int32_t rank = GenRandomInt(1, kMaxRank); - std::vector dims; - for (int32_t d = 0; d < rank; d++) { - // 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random - // 0 value to the unknown attribute if 0 is chosen - dsize_t dim_value = static_cast(GenRandomInt(0, kMaxDimValue)); - if (dim_value == 0) dim_value = TensorShape::kDimUnknown; - dims.push_back(dim_value); - } - newShape = std::make_unique(dims); - - // Create the column descriptor - std::string colName = "c" + std::to_string(i); - newCol = std::make_unique(colName, DataType(newType), TensorImpl::kFlexible, rank, newShape.get()); - - data_schema_->AddColumn(*newCol); - } - - return Status::OK(); -} - -// Class functor operator () override. -// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will -// provide the master loop that drives the logic for performing the work. -Status RandomDataOp::operator()() { - // First, compute how many buffers we'll need to satisfy the total row count. - // The only reason we do this is for the purpose of throttling worker count if needed. - int64_t buffers_needed = total_rows_ / rows_per_buffer_; - if (total_rows_ % rows_per_buffer_ != 0) { - buffers_needed++; - } - - // If the amount of workers we have exceeds the number of buffers to produce, then we'll have - // idle workers doing nothing. In that case, let's throttle the worker count. - if (num_workers_ > buffers_needed) { - MS_LOG(INFO) << "RandomDataOp throttling worker count from " << num_workers_ << "to " << buffers_needed; - num_workers_ = buffers_needed; - num_producers_ = num_workers_; - guys_out_ = num_workers_; - // The output connector was already created with a different worker count. We have to drop and recreate - // that connector. - DatasetOp::CreateConnector(num_producers_, num_workers_); - } - - // Assign the number of rows to each worker in a round robin fashion. - worker_max_rows_.reserve(num_workers_); - worker_rows_packed_.reserve(num_workers_); - // init the counts to zero to start. - for (int32_t w = 0; w < num_workers_; w++) { - worker_max_rows_.push_back(0); - worker_rows_packed_.push_back(0); - } - // then assign round robin row counts - int32_t currentWorker = 0; - for (int64_t r = 0; r < total_rows_; r++) { - worker_max_rows_[currentWorker]++; - currentWorker = (currentWorker + 1) % num_workers_; - } - - // Next, compute the total buffer count. This stat is needed during reset logic - for (int32_t w = 0; w < num_workers_; w++) { - int64_t worker_buffers = 0; - worker_buffers = worker_max_rows_[w] / rows_per_buffer_; - if (worker_max_rows_[w] % rows_per_buffer_ != 0) worker_buffers++; - epoch_buffers_sent_ += worker_buffers; - } - - // For the connector to work, we need to target the correct worker channel for the eoe. - // This will initialize it for the first one. reset() handles for the rest of the epochs. - eoe_worker_id_ = epoch_buffers_sent_ % num_workers_; - epoch_buffers_sent_++; // Add the eoe buffer to the count for subsequent epochs - - // RandomDataOp doesn't need the master thread to stay around. Kick off the workers and then master exits. - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&RandomDataOp::WorkerEntry, this, std::placeholders::_1))); - - // required task group setup after launching workers - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(epoch_sync_wait_post_.Register(tree_->AllTasks())); - - return Status::OK(); -} - -// Performs a synchronization between workers at the end of an epoch -Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " syncing at end of epoch"; - - // Sync on the guys_in counter - // We have to wait the last guy is out. - all_out_.Wait(); - // If we are not in a repeat loop, or that was the last repeat already, then setup our exit - // condition from the master loop. - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - *quitting = true; - } - - auto prev = guys_in_.fetch_add(1); - bool last_guy_in = (prev + 1) == num_workers_; - // If we are the last worker to hit this sync point, we have some extra tasks - if (last_guy_in) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker " - << eoe_worker_id_; - // Prepare for sync - all_out_.Clear(); - // Always flow eoe at the end - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(eoe_worker_id_, std::move(eoe_buffer))); - // If we're done then also flow the eof - if (*quitting) { - // The eof needs to be sent from the next sender in the round robin, so +1 - int32_t eof_worker_id = (eoe_worker_id_ + 1) % num_workers_; - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " has no more epochs. sending eof as worker " - << eof_worker_id; - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(eof_worker_id, std::move(eof_buffer))); - } - } - - // Wait for the reset to wake us up if we're not quitting - if (!(*quitting)) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait."; - RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); - prev = guys_out_.fetch_add(1); - bool last_guy_out = (prev + 1) == num_workers_; - // Last guy out will clear the wait post and set the row counts - if (last_guy_out) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " last guy out clearing wait post."; - epoch_sync_wait_post_.Clear(); - guys_in_ = 0; - all_out_.Set(); - } - } - - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " epoch sync complete."; - return Status::OK(); -} - -// The entry point code for when workers are launched -Status RandomDataOp::WorkerEntry(int32_t worker_id) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entry"; - - // handshake with the master first to tell it we're alive - TaskManager::FindMe()->Post(); - - bool quitting = false; - std::unique_ptr new_tensor_table = nullptr; - - // Loop until the quitting variable gets set to true - do { - // If we have not yet reached the row count for this worker then produce another record - if (worker_rows_packed_[worker_id] < worker_max_rows_[worker_id]) { - TensorRow new_row; - - // Start a new tensor table if needed - if (new_tensor_table == nullptr) { - new_tensor_table = std::make_unique(); - } - - // Create the data for the row - RETURN_IF_NOT_OK(CreateRandomRow(worker_id, &new_row)); - - // Add the row to our table - new_tensor_table->push_back(std::move(new_row)); - worker_rows_packed_[worker_id]++; - - // If the tensor table is at capacity then it's time to send it to output - if (new_tensor_table->size() == rows_per_buffer_) { - RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table))); - } - } else { - // We've reached the total row count for this worker, so it's time for epoch sync. - // There is likely some records built but not sent yet, so take care of those first - // (this buffer will be smaller than rows_per_buffer) - if (new_tensor_table != nullptr && new_tensor_table->size() > 0) { - RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table))); - } - - // Now, let's enter the epoch sync - RETURN_IF_NOT_OK(EpochSync(worker_id, &quitting)); - } - } while (!quitting); - - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is now quitting."; - - return Status::OK(); -} - -// A helper function to stuff the tensor table into a buffer and send it to output connector -Status RandomDataOp::PackAndSend(int32_t worker_id, std::unique_ptr in_table) { - auto new_buffer = std::make_unique(GetNextBufferId(), DataBuffer::kDeBFlagNone); - new_buffer->set_tensor_table(std::move(in_table)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(new_buffer))); - return Status::OK(); -} - -// A helper function to create random data for the row -Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) { - if (new_row == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Missing tensor row output"); - } - - // Create a tensor for each column, then add the tensor to the row - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - const ColDescriptor current_col = data_schema_->column(i); - std::vector current_shape = current_col.shape().AsVector(); - std::unique_ptr new_shape = nullptr; - std::unique_ptr buf = nullptr; - std::shared_ptr new_tensor = nullptr; - - // We need to resolve the shape to fill in any unknown dimensions with random - // values, then use that as our shape for this tensor. - for (int j = 0; j < current_shape.size(); ++j) { - if (current_shape[j] == TensorShape::kDimUnknown) { - current_shape[j] = static_cast(GenRandomInt(1, kMaxDimValue)); - } - } - - new_shape = std::make_unique(current_shape); - int64_t size_in_bytes = new_shape->NumOfElements() * current_col.type().SizeInBytes(); - - // Generate a random byte of data. This may cause some funny data for things like doubles,floats, bools - // however the random data op is not too concerned about the physical data itself. - std::uniform_int_distribution uniDist(0, 255); - uint8_t random_byte = uniDist(rand_gen_); - - // Now, create a chunk of memory for the entire tensor and copy this byte in repeatedly. - buf = std::make_unique(size_in_bytes); - int ret_code = memset_s(buf.get(), size_in_bytes, random_byte, size_in_bytes); - if (ret_code != 0) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to set random bytes for a tensor."); - } - - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&new_tensor, current_col.tensorImpl(), *new_shape, current_col.type(), buf.get())); - - // Add this tensor to the tensor row for output - (*new_row).push_back(std::move(new_tensor)); - } - return Status::OK(); -} - -// Overrides base class reset method. When an operator does a reset, it cleans up any state -// info from it's previous execution and then initializes itself so that it can be executed -// again. -Status RandomDataOp::Reset() { - MS_LOG(INFO) << "RandomDataOp resetting."; - - // Ensure all guys are in the waitpost - if (guys_in_ != num_workers_) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Issuing a reset, but some workers are missing from epochSync!"); - } - - // reset the row counters for all workers - for (int32_t w = 0; w < num_workers_; w++) { - worker_rows_packed_[w] = 0; - worker_max_rows_[w] = 0; - } - buffer_id_ = 0; - - // Re-assign round robin row counts, starting from the worker after the one that gave - // the eoe last time - int32_t currentWorker = (eoe_worker_id_ + 1) % num_workers_; - for (int64_t r = 0; r < total_rows_; r++) { - worker_max_rows_[currentWorker]++; - currentWorker = (currentWorker + 1) % num_workers_; - } - - // Compute which worker should get the eoe for the next epoch - eoe_worker_id_ = ((epoch_buffers_sent_ % num_workers_) + eoe_worker_id_) % num_workers_; - - // Wake up the workers to get them going again in a new epoch - guys_out_ = 0; - epoch_sync_wait_post_.Set(); - - return Status::OK(); -} - -Status RandomDataOp::ComputeColMap() { - // Extract the column name mapping from the schema and save it in the class. - if (column_name_id_map_.empty()) { - RETURN_IF_NOT_OK(data_schema_->GetColumnNameMap(&(column_name_id_map_))); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -// During tree prepare phase, operators may have specific post-operations to perform depending on -// their role. -Status RandomDataOp::PrepareNodePostAction() { - // Run common code from super class before adding RandomDataOp specific handling - RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); - // Specific handling for this op, we need to do cache op work to assign the sampler to the cache. - RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h deleted file mode 100644 index b2af27dda3..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h +++ /dev/null @@ -1,291 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ - -#include -#include -#include -#include -#include -#include -#include -#include "dataset/util/status.h" -#include "dataset/core/tensor.h" -#include "dataset/core/data_type.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// The RandomDataOp is a leaf node storage operator that generates random data based -// on the schema specifications. Typically, it's used for testing and demonstrating -// various dataset operator pipelines. It is not "real" data to train with. -// The data that is random created is just random and repeated bytes, there is no -// "meaning" behind what these bytes are. -class RandomDataOp : public ParallelOp { - public: - // Some constants to provide limits to random generation. - static constexpr int32_t kMaxNumColumns = 4; - static constexpr int32_t kMaxRank = 4; - static constexpr int32_t kMaxDimValue = 32; - static constexpr int32_t kMaxTotalRows = 1024; - - // A nested builder class to aid in the construction of a RandomDataOp - class Builder { - public: - /** - * Builder constructor. Creates the builder object. - * @note No default args. - * @return This is a constructor. - */ - Builder(); - - /** - * Default destructor - */ - ~Builder() = default; - - /** - * The build method that produces the instantiated RandomDataOp as a shared pointer - * @param out_op - The output RandomDataOperator that was constructed - * @return Status - The error code return - */ - Status Build(std::shared_ptr *out_op); - - /** - * Builder set method - * @param data_schema - A user-provided schema - * @return Builder - The modified builder by reference - */ - Builder &SetDataSchema(std::unique_ptr data_schema) { - builder_data_schema_ = std::move(data_schema); - return *this; - } - - /** - * Builder set method - * @param num_workers - The number of workers - * @return Builder - The modified builder by reference - */ - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - /** - * Builder set method - * @param op_connector_size - The size of the output connector - * @return Builder - The modified builder by reference - */ - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - /** - * Builder set method - * @param rows_per_buffer - The number of rows in each DataBuffer - * @return Builder - The modified builder by reference - */ - Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - /** - * Builder set method - * @param total_rows - The total number of rows in the dataset - * @return Builder - The modified builder by reference - */ - Builder &SetTotalRows(int64_t total_rows) { - builder_total_rows_ = total_rows; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - private: - /** - * Check if the required parameters are set by the builder. - * @return Status - The error code return - */ - Status SanityCheck() const; - - std::unique_ptr builder_data_schema_; - std::shared_ptr builder_sampler_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - int64_t builder_rows_per_buffer_; - int64_t builder_total_rows_; - }; // class Builder - - /** - * Constructor for RandomDataOp - * @note Private constructor. Must use builder to construct. - * @param num_workers - The number of workers - * @param op_connector_size - The size of the output connector - * @param rows_per_buffer - The number of rows in each DataBuffer - * @param data_schema - A user-provided schema - * @param total_rows - The total number of rows in the dataset - * @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes - * @return Builder - The modified builder by reference - */ - RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema, std::shared_ptr sampler); - - /** - * Destructor - */ - ~RandomDataOp() = default; - - /** - * A print method typically used for debugging - * @param out - The output stream to write output to - * @param show_all - A bool to control if you want to show all info or just a summary - */ - void Print(std::ostream &out, bool show_all) const override; - - /** - * << Stream output operator overload - * @notes This allows you to write the debug print info using stream operators - * @param out - reference to the output stream being overloaded - * @param so - reference to the ShuffleOp to display - * @return - the output stream must be returned - */ - friend std::ostream &operator<<(std::ostream &out, const RandomDataOp &op) { - op.Print(out, false); - return out; - } - - /** - * Class functor operator () override. - * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will - * provide the master loop that drives the logic for performing the work. - * @return Status - The error code return - */ - Status operator()() override; - - /** - * Overrides base class reset method. When an operator does a reset, it cleans up any state - * info from it's previous execution and then initializes itself so that it can be executed - * again. - * @return Status - The error code return - */ - Status Reset() override; - - /** - * Quick getter for total rows. - */ - int64_t GetTotalRows() const { return total_rows_; } - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "RandomDataOp"; } - - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePostAction() override; - - private: - /** - * The entry point code for when workers are launched - * @param worker_id - The worker id - * @return Status - The error code return - */ - Status WorkerEntry(int32_t worker_id) override; - - /** - * Helper function to produce a default/random schema if one didn't exist - @return Status - The error code return - */ - Status GenerateSchema(); - - /** - * Performs a synchronization between workers at the end of an epoch - * @param worker_id - The worker id - * @return Status - The error code return - */ - Status EpochSync(int32_t worker_id, bool *quitting); - - /** - * A helper function to stuff the tensor table into a buffer and send it to output connector - * @param worker_id - The worker id - * @param in_table - The tensor table to pack and send - * @return Status - The error code return - */ - Status PackAndSend(int32_t worker_id, std::unique_ptr in_table); - - /** - * A helper function to create random data for the row - * @param worker_id - The worker id - * @param new_row - The output row to produce - * @return Status - The error code return - */ - Status CreateRandomRow(int32_t worker_id, TensorRow *new_row); - - /** - * A quick inline for producing a random number between (and including) min/max - * @param min - minimum number that can be generated - * @param max - maximum number that can be generated - * @return - The generated random number - */ - inline int32_t GenRandomInt(int32_t min, int32_t max) { - std::uniform_int_distribution uniDist(min, max); - return uniDist(rand_gen_); - } - - /** - * A quick inline for producing the next buffer id in sequence, threadsafe - * @return - The next buffer id. - */ - inline int32_t GetNextBufferId() { - std::unique_lock lock(buffer_id_mutex_); - return ++buffer_id_; - } - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t buffer_id_; - int64_t rows_per_buffer_; - int64_t total_rows_; - int64_t epoch_buffers_sent_; - std::atomic guys_in_; - std::atomic guys_out_; - int32_t eoe_worker_id_; - std::unique_ptr data_schema_; - std::vector worker_max_rows_; - std::vector worker_rows_packed_; - std::mt19937 rand_gen_; - WaitPost epoch_sync_wait_post_; - WaitPost all_out_; - std::mutex buffer_id_mutex_; -}; // class RandomDataOp -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt deleted file mode 100644 index 5209d9ba4a..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(engine-datasetops-source-sampler OBJECT - distributed_sampler.cc - pk_sampler.cc - python_sampler.cc - random_sampler.cc - sampler.cc - sequential_sampler.cc - subset_random_sampler.cc - weighted_random_sampler.cc - ) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc deleted file mode 100644 index 9f4a9cf55c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc +++ /dev/null @@ -1,119 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" - -#include -#include - -#include "dataset/engine/data_buffer.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, - uint32_t seed) - : Sampler(num_samples, std::numeric_limits::max()), - cnt_(0), - seed_(seed == std::numeric_limits::max() ? GetSeed() : seed), - device_id_(dev_id), - num_devices_(num_dev), - shuffle_(shuffle) {} - -Status DistributedSampler::InitSampler() { - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. - if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n"); - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); - CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, - "fail to init DistributedSampler"); - rnd_.seed(seed_++); - samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) - samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_; - if (shuffle_ == true) { - shuffle_vec_.reserve(num_rows_); - for (int64_t i = 0; i < num_rows_; i++) { - shuffle_vec_.push_back(i); - } - std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_); - } - return Status::OK(); -} - -Status DistributedSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (cnt_ > samples_per_buffer_) { - RETURN_STATUS_UNEXPECTED("Distributed Sampler Error"); - } else if (cnt_ == samples_per_buffer_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - (*out_buffer) = std::make_unique(cnt_, DataBuffer::kDeBFlagNone); - std::shared_ptr sample_ids; - RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_)); - auto id_ptr = sample_ids->begin(); - while (cnt_ < samples_per_buffer_ && id_ptr != sample_ids->end()) { - int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_; - if (shuffle_) { - sampled_id = shuffle_vec_[static_cast(sampled_id)]; - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *id_ptr = sampled_id; - id_ptr++; - cnt_++; - } - TensorRow row(1, sample_ids); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - } - return Status::OK(); -} - -Status DistributedSampler::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); - cnt_ = 0; - - if (shuffle_ == true) { - rnd_.seed(seed_); - seed_++; - std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_); - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -void DistributedSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: DistributedSampler"; - if (show_all) { - Sampler::Print(out, show_all); - out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ - << "\nshuffle: " << shuffle_; - } -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h deleted file mode 100644 index 7083580c6c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ - -#include -#include -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -class DistributedSampler : public Sampler { - public: - // @param num_samples - // @param int64_t num_dev - // @param int64_t dev_id - // @param bool shuffle - DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, - uint32_t seed = std::numeric_limits::max()); - - // default destructor - ~DistributedSampler() = default; - - // @param std::unique_ptr * pBuffer - // @param int32_t workerId - // @return - The error code return - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // Init sampler, called by base class or python - Status InitSampler() override; - - // for next epoch of sampleIds - // @return - The error code return - Status ResetSampler() override; - - void Print(std::ostream &out, bool show_all) const override; - - private: - int64_t cnt_; // number of samples that have already been filled in to buffer - uint32_t seed_; - int64_t device_id_; - int64_t num_devices_; - bool shuffle_; - std::mt19937 rnd_; - std::vector shuffle_vec_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc deleted file mode 100644 index cd2cadb9ff..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include -#include -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), - shuffle_(shuffle), - seed_(GetSeed()), - next_id_(0), - samples_per_class_(val) {} - -Status PKSampler::InitSampler() { - labels_.reserve(label_to_ids_.size()); - for (const auto &pair : label_to_ids_) { - if (pair.second.empty() == false) { - labels_.push_back(pair.first); - } - } - rnd_.seed(seed_++); - - // The special handshake gives the list of classes and id's, but it did not set the num_rows_ to - // capture the total number of possible sample ids. - // Compute that here for this case to find the total number of samples that are available to return. - // (in this case, samples per class * total classes). - num_rows_ = samples_per_class_ * static_cast(labels_.size()); - - // The user may have chosen to sample less than the total amount. - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. - if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - - samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; - if (shuffle_ == true) { - std::shuffle(labels_.begin(), labels_.end(), rnd_); - } else { - std::sort(labels_.begin(), labels_.end()); - } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive"); - return Status::OK(); -} - -Status PKSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (next_id_ > num_samples_ || num_samples_ == 0) { - RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); - } else if (next_id_ == num_samples_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); - std::shared_ptr sample_ids; - int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; - RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); - auto id_ptr = sample_ids->begin(); - while (next_id_ < last_id && id_ptr != sample_ids->end()) { - int64_t cls_id = next_id_++ / samples_per_class_; - const std::vector &samples = label_to_ids_[labels_[cls_id]]; - int64_t rnd_ind = std::uniform_int_distribution(0, samples.size() - 1)(rnd_); - int64_t sampled_id = samples[rnd_ind]; - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *id_ptr = sampled_id; - id_ptr++; - } - - TensorRow row(1, sample_ids); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - } - return Status::OK(); -} - -Status PKSampler::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); - next_id_ = 0; - rnd_.seed(seed_++); - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { - RETURN_UNEXPECTED_IF_NULL(op); - RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); - RETURN_IF_NOT_OK(InitSampler()); - return Status::OK(); -} - -void PKSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: PKSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info if any - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h deleted file mode 100644 index cde8a75b5b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ - -#include -#include -#include -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -class PKSampler : public Sampler { // NOT YET FINISHED - public: - // @param num_samples - the number of samples to draw. value of 0 means to take the full amount - // @param int64_t val - // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 - // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle, - int64_t samples_per_buffer = std::numeric_limits::max()); - - // default destructor - ~PKSampler() = default; - - // @param std::unique_ptr *out_buffer) override; - - // first handshake between leaf source op and Sampler. This func will determine the amount of data - // in the dataset that we can sample from. - // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is - // @return - Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; - - // init sampler, to be called by python or Handshake - Status InitSampler() override; - - // for next epoch of sampleIds - // @return - The error code return - Status ResetSampler() override; - - // Printer for debugging purposes. - // @param out - output stream to write to - // @param show_all - bool to show detailed vs summary - void Print(std::ostream &out, bool show_all) const override; - - private: - bool shuffle_; - uint32_t seed_; - int64_t next_id_; - int64_t samples_per_class_; - std::mt19937 rnd_; - std::vector labels_; - std::map> label_to_ids_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc deleted file mode 100644 index d204c55ce9..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ /dev/null @@ -1,116 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/python_sampler.h" - -#include - -namespace mindspore { -namespace dataset { - -PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} - -Status PythonSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (need_to_reset_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - std::shared_ptr sample_ids; - { - py::gil_scoped_acquire gil_acquire; - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - py::object py_ret = py_sampler_instance.attr("_get_indices")(); - py::array np_sample_ids = py_ret.cast(); - Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor - - if (HasChildSampler()) { - for (auto it = sample_ids->begin(); it != sample_ids->end(); ++it) { - int64_t associated_child_id = 0; - RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, associated_child_id)); - *it = associated_child_id; - } - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } catch (const py::cast_error &e) { - return Status(StatusCode::kPyFuncException, "Python Sampler iterator should return integer index"); - } - } - TensorRow row(1, sample_ids); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - need_to_reset_ = true; - } - return Status::OK(); -} - -Status PythonSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. - if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - { - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - py_sampler_instance.attr("_handshake")(num_rows_, num_samples_); - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } - } - return Status::OK(); -} - -Status PythonSampler::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); - need_to_reset_ = false; - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - py_sampler_instance.attr("reset")(); - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -void PythonSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: PythonSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info if any - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h deleted file mode 100644 index 7d653b2087..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ - -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -class PythonSampler : public Sampler { - public: - // Constructor - // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the - // data from the dataset. - // @param py_sampler_instance - the python instance of the sampler - // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance, - int64_t samples_per_buffer = std::numeric_limits::max()); - - // Destructor. - ~PythonSampler() = default; - - // Initialize the sampler. - // @return Status - Status InitSampler() override; - - // for next epoch of sampleIds - // @return - The error code return - Status ResetSampler() override; - - // Op calls this to get next Buffer that contains all the sampleIds - // @param std::unique_ptr pBuffer - Buffer to be returned to corresponding Dataset Op - // @param int32_t workerId - not meant to be used - // @return - The error code return - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // Printer for debugging purposes. - // @param out - output stream to write to - // @param show_all - bool to show detailed vs summary - void Print(std::ostream &out, bool show_all) const override; - - private: - bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() - - py::object py_sampler_instance; // The handle to the py_sampler python object -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc deleted file mode 100644 index db0a96ea3a..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" - -#include -#include -#include -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), - seed_(GetSeed()), - replacement_(replacement), - next_id_(0), - reshuffle_each_epoch_(reshuffle_each_epoch), - dist(nullptr) {} - -Status RandomSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (next_id_ > num_samples_) { - RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); - } else if (next_id_ == num_samples_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); - - std::shared_ptr sampleIds; - int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_); - RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_)); - auto id_ptr = sampleIds->begin(); - - for (int64_t i = 0; i < (last_id - next_id_); i++) { - int64_t sampled_id = 0; - if (replacement_) { - sampled_id = (*dist)(rnd_); - } else { - sampled_id = shuffled_ids_[static_cast(i + next_id_)]; - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *(id_ptr + i) = sampled_id; - } - next_id_ = last_id; - TensorRow row(1, sampleIds); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - } - return Status::OK(); -} - -Status RandomSampler::InitSampler() { - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. - if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); - samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; - rnd_.seed(seed_); - - if (replacement_ == false) { - shuffled_ids_.reserve(num_rows_); - for (int64_t i = 0; i < num_rows_; i++) { - shuffled_ids_.push_back(i); - } - std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); - } else { - dist = std::make_unique>(0, num_rows_ - 1); - } - - return Status::OK(); -} - -Status RandomSampler::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); - next_id_ = 0; - - if (reshuffle_each_epoch_) { - seed_++; - } - - rnd_.seed(seed_); - - if (replacement_ == false && reshuffle_each_epoch_) { - std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -void RandomSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: RandomSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info if any - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h deleted file mode 100644 index b1c54eb98c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ - -#include -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -class RandomSampler : public Sampler { - public: - // Constructor - // @param int64_t num_samples - number samples to draw - // @param bool replacement - put he id back / or not after a sample - // @param reshuffle_each_epoch - T/F to reshuffle after epoch - // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, - int64_t samples_per_buffer = std::numeric_limits::max()); - - // Destructor. - ~RandomSampler() = default; - - // Op calls this to get next Buffer that contains all the sampleIds - // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp - // @param int32_t workerId - not meant to be used - // @return - The error code return - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // meant to be called by base class or python - Status InitSampler() override; - - // for next epoch of sampleIds - // @return - The error code return - Status ResetSampler() override; - - virtual void Print(std::ostream &out, bool show_all) const; - - private: - uint32_t seed_; - bool replacement_; - std::vector shuffled_ids_; // only used for NO REPLACEMENT - int64_t next_id_; - std::mt19937 rnd_; - std::unique_ptr> dist; - bool reshuffle_each_epoch_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc deleted file mode 100644 index 1584166dc3..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ /dev/null @@ -1,176 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -#include - -namespace mindspore { -namespace dataset { -Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { - // The sampler base class itself does not compute it's own num_rows_ value. - // Instead, this value is computed by the derived leaf op during it's own initialization - // after it has interacted with it's storage layers. - // Here, it is just a getter method to return the value. However, it is invalid if there is - // not a value set for this count, so generate a failure if that is the case. - if (num == nullptr || num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet."); - } - (*num) = num_rows_; - return Status::OK(); -} - -Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) - : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} - -Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { - std::shared_ptr child_sampler; - if (HasChildSampler()) { - child_sampler = std::dynamic_pointer_cast(child_[0]); - if (!child_sampler) { - std::string err_msg("Cannot handshake, child is not a sampler object."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Handshake and init child first. - RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); - } - - CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); - - // If there's a child sampler, set the row count to be it's sample count - if (HasChildSampler()) { - num_rows_ = child_sampler->num_samples_; - } else { - RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); - } - - // It's up to the derived class to check the validity of the two args - // Because some sampler only needs one of the arg (weighted_random_sampler) - RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback - - return Status::OK(); -} - -Status Sampler::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t num_elements) { - if (num_elements == 0) { - RETURN_STATUS_UNEXPECTED("num of Elements is 0"); - } - if (col_desc_ == nullptr) { - // a ColDescriptor for Tensor that holds SampleIds - col_desc_ = std::make_unique("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1); - } - TensorShape shape(std::vector(1, num_elements)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, col_desc_->tensorImpl(), shape, col_desc_->type())); - RETURN_IF_NOT_OK( - (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes())); // allocate memory in case user forgets! - return Status::OK(); -} - -void Sampler::Print(std::ostream &out, bool show_all) const { - // Sampler printing is usually only called in the show_all mode. - // Derived classes will display the name, then call back to this base - // for common info. - // No-op in the summary mode. - if (show_all) { - out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_; - } -} - -Status Sampler::GetAllIdsThenReset(py::array *data) { - std::unique_ptr db; - std::shared_ptr sample_ids; - TensorRow sample_row; - - // A call to derived class to get sample ids wrapped inside a buffer - RETURN_IF_NOT_OK(GetNextSample(&db)); - // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch - RETURN_IF_NOT_OK(db->GetRow(0, &sample_row)); - sample_ids = sample_row[0]; - - // check this buffer is not a ctrl buffer - CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); - { - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data)); - } catch (const std::runtime_error &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } - } - // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch - RETURN_IF_NOT_OK(GetNextSample(&db)); - CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); - // Reset Sampler since this is the end of the epoch - RETURN_IF_NOT_OK(ResetSampler()); - return Status::OK(); -} - -Status Sampler::SetNumSamples(int64_t num_samples) { - CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative"); - num_samples_ = num_samples; - return Status::OK(); -} - -Status Sampler::SetNumRowsInDataset(int64_t num_rows) { - CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0"); - num_rows_ = num_rows; - return Status::OK(); -} - -Status Sampler::AddChild(std::shared_ptr child) { - if (child == nullptr) { - return Status::OK(); - } - - // Only samplers can be added, not any other DatasetOp. - std::shared_ptr sampler = std::dynamic_pointer_cast(child); - if (!sampler) { - std::string err_msg("Cannot add child, child is not a sampler object."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Samplers can have at most 1 child. - if (!child_.empty()) { - std::string err_msg("Cannot add child sampler, this sampler already has a child."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - child_.push_back(child); - - // doesn't work, protected? - // child->AddParent(this); - return Status::OK(); -} - -bool Sampler::HasChildSampler() { return !child_.empty(); } - -Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { - if (child_ids_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); - } - - TensorRow sample_row; - RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row)); - std::shared_ptr sample_ids = sample_row[0]; - RETURN_IF_NOT_OK(sample_ids->GetItemAt(out_associated_id, {id})); - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h deleted file mode 100644 index 34c3cb7935..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h +++ /dev/null @@ -1,159 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/dataset_op.h" - -namespace mindspore { -namespace dataset { -// RandomAccessOp is a base class that all data-producing leaf operators -// must inherit from if those leaf operator wish to support sampling. -class RandomAccessOp { - public: - // Sampler get number of rows in the dataset - // @param int64_t num - return number of rows for this dataset - // @return - The error code return - Status GetNumRowsInDataset(int64_t *num_rows) const; - - // sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK - // @param std::map> * map - // @return - The error code return - virtual Status GetClassIds(std::map> *map) const { - RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK"); - } - - // default destructor - virtual ~RandomAccessOp() = default; - - protected: - // The amount of rows in the dataset itself. This is the before-sampling value, the - // total count of rows. A sampler may choose to sample less than this amount. - int64_t num_rows_; -}; - -class Sampler { - public: - // Constructor - // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 - // indicates that the sampler should produce the complete set of ids. - // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); - - Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {} - - // default destructor - ~Sampler() = default; - - // Get a list of sample ids. - // @note It is Sampler responsibility to make sure that the id is not out of bound. - // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp - // @param int32_t workerId - not meant to be used - // @return - The error code return - virtual Status GetNextSample(std::unique_ptr *out_buffer) = 0; - - // return all ids in one epoch as a numpy array, then call reset - Status GetAllIdsThenReset(py::array *data); - - // for next epoch of sampleIds - // @return - The error code return - virtual Status ResetSampler() = 0; - - // first handshake between leaf source op and Sampler. This func will determine the amount of data - // in the dataset that we can sample from. - // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is - // @return - virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op); - - // initialize sampler and perform checks on certain vars - virtual Status InitSampler() { return Status::OK(); } - - // setter for num samples - // @param num_samples - the number of samples to assign. - // @return status error code - Status SetNumSamples(int64_t num_samples); - - // setter for num or records in the dataset - // @param num_rows - the number of records - // @return status error code - Status SetNumRowsInDataset(int64_t num_rows); - - // Adds a sampler to become our child. - // @param std::shared_ptr - The sampler to add as a child. - // @return - The error code returned. - Status AddChild(std::shared_ptr child); - - // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler - // @param std::shared_ptr* sampleIds - // @param int64_t numElements - must be a non 0 number - // @return - The error code returned. - Status CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t num_elements); - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - virtual void Print(std::ostream &out, bool show_all) const; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param sampler - reference to teh sampler to print - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { - sampler.Print(out, false); - return out; - } - - // Checks if this sampler has a child sampler. - // @return - tre if there is a child sampler, false otherwise. - bool HasChildSampler(); - - // Uses id as an index for the list of ids generated by the child sampler, and gets the - // associated id. - // @param int64_t* out_associated_id - Out parameter, contains the associated id. - // @param int64_t id - The id used as an index to get the associated child id. - // @return - The error code returned. - Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); - - protected: - // Number of rows of data from the place this sampler is sampling from. If this sampler - // has a child sampler, num_rows_ is the number of ids the child sampler will - // output. Otherwise, num_rows_ is the number of rows in the dataset. - int64_t num_rows_; - - // The user may want to sample less than the full amount of data. num_samples_ reduces the number - // of id's returned as request by the user. Derived classes will choose how to sample the smaller - // amount. - int64_t num_samples_; - - int64_t samples_per_buffer_; - std::unique_ptr col_desc_; - std::vector> child_; // Child nodes - std::unique_ptr child_ids_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc deleted file mode 100644 index 28598da55f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" - -#include -#include - -namespace mindspore { -namespace dataset { -SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} - -Status SequentialSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (id_count_ > num_samples_) { - RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); - } else if (id_count_ == num_samples_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - (*out_buffer) = std::make_unique(current_id_, DataBuffer::kDeBFlagNone); - std::shared_ptr sampleIds; - - // Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for - // samples per buffer though. - int64_t remaining_ids = num_samples_ - id_count_; - int64_t num_elements = std::min(remaining_ids, samples_per_buffer_); - - RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements)); - auto idPtr = sampleIds->begin(); - for (int64_t i = 0; i < num_elements; i++) { - int64_t sampled_id = current_id_; - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *idPtr = sampled_id; - current_id_++; // Move the current id to the next one in the sequence - idPtr++; - } - - id_count_ += num_elements; // Count the packed ids towards our overall sample count - - TensorRow row(1, sampleIds); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - } - return Status::OK(); -} - -Status SequentialSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n"); - // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample - // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data. - int64_t available_row_count = num_rows_ - start_index_; - if (num_samples_ == 0 || num_samples_ > available_row_count) { - num_samples_ = available_row_count; - } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); - samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; - return Status::OK(); -} - -Status SequentialSampler::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); - current_id_ = start_index_; - id_count_ = 0; - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -void SequentialSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: SequentialSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info - out << "\nStart index: " << start_index_; - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h deleted file mode 100644 index 06f084fb7a..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ - -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -class SequentialSampler : public Sampler { - public: - // Constructor - // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the - // full amount of ids from the dataset - // @param start_index - The starting index value - // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit SequentialSampler(int64_t num_samples, int64_t start_index, - int64_t samples_per_buffer = std::numeric_limits::max()); - - // Destructor. - ~SequentialSampler() = default; - - // init sampler, called by python - Status InitSampler() override; - - // for next epoch of sampleIds - // @return - The error code return - Status ResetSampler() override; - - // Op calls this to get next Buffer that contains all the sampleIds - // @param std::unique_ptr pBuffer - Buffer to be returned to corresponding Dataset Op - // @param int32_t workerId - not meant to be used - // @return - The error code return - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // Printer for debugging purposes. - // @param out - output stream to write to - // @param show_all - bool to show detailed vs summary - void Print(std::ostream &out, bool show_all) const override; - - private: - int64_t current_id_; // The id sequencer. Each new id increments from this - int64_t start_index_; // The starting id. current_id_ begins from here. - int64_t id_count_; // An internal counter that tracks how many ids have been produced -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc deleted file mode 100644 index 08a623ed1b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" - -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -// Constructor. -SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector &indices, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} - -// Initialized this Sampler. -Status SubsetRandomSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); - - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // In this case, the id's are provided by the user. Cap the num_samples on the number of id's given. - if (num_samples_ == 0 || num_samples_ > static_cast(indices_.size())) { - num_samples_ = static_cast(indices_.size()); - } - // Initialize random generator with seed from config manager - rand_gen_.seed(GetSeed()); - - if (samples_per_buffer_ > num_samples_) { - samples_per_buffer_ = num_samples_; - } - - // num_samples_ could be smaller than the total number of input id's. - // We will shuffle the full set of id's, but only select the first num_samples_ of them later. - std::shuffle(indices_.begin(), indices_.end(), rand_gen_); - - return Status::OK(); -} - -// Reset the internal variable to the initial state. -Status SubsetRandomSampler::ResetSampler() { - // Reset the internal counters. - sample_id_ = 0; - buffer_id_ = 0; - - // Randomized the indices again. - rand_gen_.seed(GetSeed()); - std::shuffle(indices_.begin(), indices_.end(), rand_gen_); - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -// Get the sample ids. -Status SubsetRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { - // All samples have been drawn - if (sample_id_ == num_samples_) { - (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); - std::shared_ptr outputIds; - - int64_t last_id = sample_id_ + samples_per_buffer_; - // Handling the return all samples at once, and when last draw is not a full batch. - if (last_id > num_samples_) { - last_id = num_samples_; - } - - // Allocate tensor - RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_)); - - // Initialize tensor - auto id_ptr = outputIds->begin(); - while (sample_id_ < last_id) { - if (indices_[sample_id_] >= num_rows_) { - std::string err_msg = - "Generated id is bigger than numRows (out of bound). indices_: " + std::to_string(indices_[sample_id_]) + - " num_rows_: " + std::to_string(num_rows_); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - int64_t sampled_id = indices_[sample_id_]; - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *id_ptr = sampled_id; - id_ptr++; - sample_id_++; - } - - // Create a TensorTable from that single tensor and push into DataBuffer - (*out_buffer)->set_tensor_table(std::make_unique(1, TensorRow(1, outputIds))); - } - - return Status::OK(); -} - -void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: SubsetRandomSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info if any - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h deleted file mode 100644 index ffc7cb17bc..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ - -#include -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -// Randomly samples elements from a given list of indices, without replacement. -class SubsetRandomSampler : public Sampler { - public: - // Constructor. - // @param num_samples The number of samples to draw. 0 for the full amount. - // @param indices List of indices from where we will randomly draw samples. - // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). - // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. - explicit SubsetRandomSampler(int64_t num_samples, const std::vector &indices, - std::int64_t samples_per_buffer = std::numeric_limits::max()); - - // Destructor. - ~SubsetRandomSampler() = default; - - // Initialize the sampler. - // @return Status - Status InitSampler() override; - - // Reset the internal variable to the initial state and reshuffle the indices. - // @return Status - Status ResetSampler() override; - - // Get the sample ids. - // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. - // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // Printer for debugging purposes. - // @param out - output stream to write to - // @param show_all - bool to show detailed vs summary - void Print(std::ostream &out, bool show_all) const override; - - private: - // A list of indices (already randomized in constructor). - std::vector indices_; - - // Current sample id. - int64_t sample_id_; - - // Current buffer id. - int64_t buffer_id_; - - // A random number generator. - std::mt19937 rand_gen_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc deleted file mode 100644 index 6bf3d2d85e..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ /dev/null @@ -1,169 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" - -#include -#include -#include -#include -#include - -#include "dataset/core/global_context.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -// Constructor. -WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), - weights_(weights), - replacement_(replacement), - sample_id_(0), - buffer_id_(0) {} - -// Initialized this Sampler. -Status WeightedRandomSampler::InitSampler() { - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. - if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive"); - CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); - - // Initialize random generator with seed from config manager - rand_gen_.seed(GetSeed()); - - samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; - - if (!replacement_) { - exp_dist_ = std::make_unique>(1); - InitOnePassSampling(); - } else { - discrete_dist_ = std::make_unique>(weights_.begin(), weights_.end()); - } - - return Status::OK(); -} - -// Initialized the computation for generating weighted random numbers without replacement using onepass method. -void WeightedRandomSampler::InitOnePassSampling() { - exp_dist_->reset(); - onepass_ids_.clear(); - std::vector> val_idx; - for (size_t i = 0; i < weights_.size(); i++) { - val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i)); - } - - // Partial sort the first `numSamples` elements. - std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end()); - for (int64_t i = 0; i < num_samples_; i++) { - onepass_ids_.push_back(val_idx[i].second); - } -} - -// Reset the internal variable to the initial state and reshuffle the indices. -Status WeightedRandomSampler::ResetSampler() { - sample_id_ = 0; - buffer_id_ = 0; - rand_gen_.seed(GetSeed()); - if (!replacement_) { - InitOnePassSampling(); - } else { - discrete_dist_->reset(); - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -// Get the sample ids. -Status WeightedRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (weights_.size() > static_cast(num_rows_)) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); - } - - if (!replacement_ && (weights_.size() < static_cast(num_samples_))) { - RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); - } - - if (sample_id_ == num_samples_) { - (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); - std::shared_ptr outputIds; - - int64_t last_id = sample_id_ + samples_per_buffer_; - // Handling the return all samples at once, and when last draw is not a full batch. - if (last_id > num_samples_) { - last_id = num_samples_; - } - - // Allocate tensor. - RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_)); - - // Initialize tensor. - auto id_ptr = outputIds->begin(); - // Assign the data to tensor element. - while (sample_id_ < last_id) { - int64_t genId; - if (replacement_) { - genId = (*discrete_dist_)(rand_gen_); - } else { - // Draw sample without replacement. - genId = onepass_ids_.front(); - onepass_ids_.pop_front(); - } - - if (genId >= num_rows_) { - RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound)."); - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId)); - } - - *id_ptr = genId; - id_ptr++; - sample_id_++; - } - - // Create a TensorTable from that single tensor and push into DataBuffer - (*out_buffer)->set_tensor_table(std::make_unique(1, TensorRow(1, outputIds))); - } - - return Status::OK(); -} - -void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: WeightedRandomSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info if any - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h deleted file mode 100644 index 1fbe29ed80..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_ - -#include -#include -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -// Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights). -class WeightedRandomSampler : public Sampler { - public: - // Constructor. - // @param num_samples Number of samples to be drawn. - // @param weights A lift of sample weights. - // @param replacement Determine if samples are drawn with/without replacement. - // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). - // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. - WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, - int64_t samples_per_buffer = std::numeric_limits::max()); - - // Destructor. - ~WeightedRandomSampler() = default; - - // Initialize the sampler. - // @param op (Not used in this sampler) - // @return Status - Status InitSampler() override; - - // Reset the internal variable to the initial state and reshuffle the indices. - Status ResetSampler() override; - - // Get the sample ids. - // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. - // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // Printer for debugging purposes. - // @param out - output stream to write to - // @param show_all - bool to show detailed vs summary - void Print(std::ostream &out, bool show_all) const override; - - private: - // A list of weights for each sample. - std::vector weights_; - - // A flag indicating if samples are drawn with/without replacement. - bool replacement_; - - // Current sample id. - int64_t sample_id_; - - // Current buffer id. - int64_t buffer_id_; - - // Random engine and device - std::mt19937 rand_gen_; - - // Discrete distribution for generating weighted random numbers with replacement. - std::unique_ptr> discrete_dist_; - - // Exponential distribution for generating weighted random numbers without replacement. - // based on "Accelerating weighted random sampling without replacement" by Kirill Muller. - std::unique_ptr> exp_dist_; - - // Initialized the computation for generating weighted random numbers without replacement - // using onepass method. - void InitOnePassSampling(); - - // Store the random weighted ids generated by onepass method in `InitOnePassSampling` - std::deque onepass_ids_; -}; -} // namespace dataset -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc deleted file mode 100644 index 818b5ab3f4..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc +++ /dev/null @@ -1,498 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/core/config_manager.h" -#include "dataset/util/task_manager.h" -#include "dataset/util/wait_post.h" -#include "dataset/util/random.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -TextFileOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_total_rows_(0), - builder_shuffle_files_(false), - builder_sampler_(nullptr) { - std::shared_ptr config_manager = GlobalContext::config_manager(); - builder_num_workers_ = config_manager->num_parallel_workers(); - builder_op_connector_size_ = config_manager->op_connector_size(); - builder_rows_per_buffer_ = config_manager->rows_per_buffer(); - builder_worker_connector_size_ = config_manager->worker_connector_size(); -} - -Status TextFileOp::Builder::ValidateInputs() const { - std::string err_msg; - err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; - err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -Status TextFileOp::Builder::Build(std::shared_ptr *op) { - RETURN_IF_NOT_OK(ValidateInputs()); - - // Throttle the number of workers if we have more workers than files! - if (static_cast(builder_num_workers_) > builder_text_files_list_.size()) { - builder_num_workers_ = builder_text_files_list_.size(); - MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; - } - - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - - std::shared_ptr text_file_op = std::make_shared( - builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, - std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, - builder_num_devices_, builder_device_id_, std::move(builder_sampler_)); - RETURN_IF_NOT_OK(text_file_op->Init()); - *op = std::move(text_file_op); - - return Status::OK(); -} - -TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, - std::unique_ptr schema, std::vector text_files_list, - int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, - std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), - device_id_(device_id), - num_devices_(num_device), - rows_per_buffer_(rows_per_buffer), - total_rows_(total_rows), - text_files_list_(std::move(text_files_list)), - shuffle_files_(shuffle_files), - data_schema_(std::move(schema)), - all_num_rows_(0), - num_rows_per_shard_(0), - filename_index_(std::make_unique()), - finished_reading_dataset_(false), - load_io_block_queue_(true), - load_jagged_connector_(true) { - worker_connector_size_ = worker_connector_size; -} - -// A print method typically used for debugging -void TextFileOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ - << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") - << "\nText files list:\n"; - for (int i = 0; i < text_files_list_.size(); ++i) { - out << " " << text_files_list_[i]; - } - out << "\nData Schema:\n"; - out << *data_schema_ << "\n\n"; - } -} - -Status TextFileOp::Init() { - RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); - - int32_t safe_queue_size = static_cast(std::ceil(text_files_list_.size() / num_workers_) + 1); - io_block_queues_.Init(num_workers_, safe_queue_size); - - RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); - - jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); - return Status::OK(); -} - -Status TextFileOp::Reset() { - load_jagged_connector_ = true; - load_io_block_queue_ = true; - - RETURN_IF_NOT_OK(ParallelOp::Reset()); - NotifyToFillIOBlockQueue(); - return Status::OK(); -} - -Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { - TensorRow tRow(1, nullptr); - (*tensor_table)->push_back(std::move(tRow)); - - std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); - (**tensor_table)[row][0] = std::move(tensor); - return Status::OK(); -} - -Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, - const int32_t worker_id) { - std::ifstream handle(file); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Failed to open file " + file); - } - - int64_t rows_each_buffer = 0; - int64_t rows_total = 0; - std::string line; - std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - std::unique_ptr tensor_table = std::make_unique(); - - while (getline(handle, line)) { - if (line.empty()) { - continue; - } - // If read to the end offset of this file, break. - if (rows_total >= end_offset) { - break; - } - // Skip line before start offset. - if (rows_total < start_offset) { - rows_total++; - continue; - } - - RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); - rows_each_buffer++; - rows_total++; - if (rows_each_buffer == rows_per_buffer_) { - cur_buffer->set_tensor_table(std::move(tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); - - cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - tensor_table = std::make_unique(); - rows_each_buffer = 0; - } - } - - if (rows_each_buffer > 0) { - cur_buffer->set_tensor_table(std::move(tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); - } - - return Status::OK(); -} - -Status TextFileOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - - std::unique_ptr io_block; - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - while (!io_block->eof()) { - if (!io_block->eoe()) { - if (load_jagged_connector_) { - std::string filename; - RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); - int64_t start_offset = io_block->GetStartOffset(); - int64_t end_offset = io_block->GetEndOffset(); - RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); - } - } else { - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); - } - - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - } - return Status::OK(); -} - -// Pops an element from a queue in io_block_queues -Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); - - return Status::OK(); -} - -// Pushes an element to a queue in io_block_queues -Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); - - return Status::OK(); -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. -// When the worker pops this control indicator, it will shut itself down gracefully. -Status TextFileOp::PostEndOfData() { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); - } - - return Status::OK(); -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker -// pops this control indicator, it will wait until the next epoch starts and then resume execution. -Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); - } - - return Status::OK(); -} - -static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { - std::mt19937 rng(seed); - std::shuffle(i_keys->begin(), i_keys->end(), rng); -} - -bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count) { - *start_offset = 0; - *end_offset = 0; - bool push = false; - int64_t start_index = device_id_ * num_rows_per_shard_; - if (device_id_ + 1 < 0) { - MS_LOG(ERROR) << "Device id is invalid"; - return false; - } - - int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; - if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { - *start_offset = start_index - pre_count; - push = true; - if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - if (pre_count >= start_index && pre_count < end_index) { - *start_offset = 0; - push = true; - if (pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - return push; -} - -Status TextFileOp::FillIOBlockQueue(const std::vector &i_keys) { - int32_t queue_index = 0; - int64_t pre_count = 0; - int64_t start_offset = 0; - int64_t end_offset = 0; - bool finish = false; - while (!finish) { - std::vector> file_index; - if (!i_keys.empty()) { - for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { - { - if (!load_io_block_queue_) { - break; - } - } - file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); - } - } else { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - { - if (!load_io_block_queue_) { - break; - } - } - file_index.emplace_back(std::pair(it.value(), it.key())); - } - } - for (auto file_info : file_index) { - if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { - auto ioBlock = - std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); - RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); - queue_index = (queue_index + 1) % num_workers_; - } - - pre_count += filename_numrows_[file_info.first]; - } - - if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { - finish = false; - } else { - finish = true; - } - } - - RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); - return Status::OK(); -} - -Status TextFileOp::WaitToFillIOBlockQueue() { - // must be called first if called by worker spanwed by taskgroup - TaskManager::FindMe()->Post(); - - std::vector i_keys; - if (shuffle_files_) { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - i_keys.push_back(it.key()); - } - } - uint32_t seed = 0; - while (true) { - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); - io_block_queue_wait_post_.Clear(); - - if (finished_reading_dataset_) { - break; - } - - if (shuffle_files_) { - ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); - } - RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); - } - return Status::OK(); -} - -void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } - -Status TextFileOp::operator()() { - RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - - // launch one thread, responsible for filling IoBlockQueue - RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); - - // Read data from disk into buffers - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); - - // must be called after launching workers. - TaskManager::FindMe()->Post(); - - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); - NotifyToFillIOBlockQueue(); - while (!finished_reading_dataset_) { - int64_t buffer_id = 0; - int32_t workers_done = 0; - int64_t rows_read = 0; - load_io_block_queue_ = true; - - while (workers_done < num_workers_) { - std::unique_ptr buffer; - RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); - if (buffer->eoe()) { - workers_done++; - } else if (total_rows_ == 0 || rows_read < total_rows_) { - if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) { - int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read); - RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); - } - rows_read += buffer->NumRows(); - buffer->set_id(buffer_id++); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); - } else { - // end of epoch - load_jagged_connector_ = false; - load_io_block_queue_ = false; - } - } - - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - finished_reading_dataset_ = true; - NotifyToFillIOBlockQueue(); - } else { - jagged_buffer_connector_->DoReset(); - buffer_id = 0; - } - } - - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - - RETURN_IF_NOT_OK(PostEndOfData()); - - return Status::OK(); -} - -int64_t TextFileOp::CountTotalRows(const std::string &file) { - std::ifstream handle(file); - if (!handle.is_open()) { - MS_LOG(ERROR) << "Failed to open file: " << file; - return 0; - } - - std::string line; - int64_t count = 0; - while (getline(handle, line)) { - if (!line.empty()) { - count++; - } - } - - return count; -} - -Status TextFileOp::CalculateNumRowsPerShard() { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - int64_t count = CountTotalRows(it.value()); - filename_numrows_[it.value()] = count; - all_num_rows_ += count; - } - if (all_num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API TextFileDataset.Please check file path or dataset API " - "validation first."); - } - - num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); - MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; - return Status::OK(); -} - -Status TextFileOp::CountAllFileRows(const std::vector &files, int64_t *count) { - std::shared_ptr op; - *count = 0; - RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); - for (auto file : files) { - *count += op->CountTotalRows(file); - } - return Status::OK(); -} - -Status TextFileOp::ComputeColMap() { - // Set the column name mapping (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h deleted file mode 100644 index 5b787d4dad..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h +++ /dev/null @@ -1,289 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/util/auto_index.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/util/queue.h" -#include "dataset/util/wait_post.h" -#include "dataset/engine/jagged_connector.h" - -namespace mindspore { -namespace dataset { -using StringIndex = AutoIndexObj; - -class TextFileOp : public ParallelOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Checks if the inputs of the builder is valid. - // @return Status - the error code returned. - Status ValidateInputs() const; - - // Create the final object. - // @param op - dataset op. - // @return - the error code return. - Status Build(std::shared_ptr *op); - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumDevices(int64_t num_dev) { - builder_num_devices_ = num_dev; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDeviceId(int64_t dev_id) { - builder_device_id_ = dev_id; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetTextFilesList(const std::vector &files_list) { - builder_text_files_list_ = files_list; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShuffleFiles(bool shuffle_files) { - builder_shuffle_files_ = shuffle_files; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetTotalRows(int64_t total_rows) { - builder_total_rows_ = total_rows; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - private: - int32_t builder_device_id_; - int32_t builder_num_devices_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - int64_t builder_rows_per_buffer_; - int64_t builder_total_rows_; - int32_t builder_worker_connector_size_; - std::vector builder_text_files_list_; - bool builder_shuffle_files_; - std::unique_ptr builder_schema_; - std::shared_ptr builder_sampler_; - }; - - // Constructor of TextFileOp - // @note The builder class should be used to call this constructor. - // @param num_workers - number of worker threads reading data from tf_file files. - // @param rows_per_buffer - number of rows that a full buffer will contain. - // @param total_num_rows - number of rows to read - // @param dataset_files_list - list of filepaths for the dataset files. - // @param data_schema - the data schema object. - // @param op_connector_size - size of each queue in the connector that the child operator pulls from. - // @param columns_to_load - the names of the columns to load data from. - // @param shuffle_files - whether or not to shuffle the files before reading data. - // @param equal_rows_per_shard - whether or not to get equal rows for each process. - // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes - TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, - std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); - - // Default destructor - ~TextFileOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // Instantiates the internal queues and connectors - // @return Status - the error code returned - Status Init(); - - // Class functor operator () override. - // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - the error code returned. - Status operator()() override; - - // Overrides base class reset method. Cleans up any state info from it's previous execution - // reinitializes itself so that it can be executed again, as if it was just created. - // @return Status - the error code returned. - Status Reset() override; - - // Get total rows in files. - // @param files - all text files. - // @param count - number of rows. - // @return Status - the error coed returned. - static Status CountAllFileRows(const std::vector &files, int64_t *count); - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "TextFileOp"; } - - // File names getter - // @return Vector of the input file names - std::vector FileNames() { return text_files_list_; } - - private: - // The entry point for when workers are launched. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status WorkerEntry(int32_t worker_id) override; - - // Parses a single row and puts the data into a tensor table. - // @param line - the content of the row. - // @param tensor_table - the tensor table to put the parsed data in. - // @param row - the id of the row filled in the tensor table. - // @return Status - the error code returned. - Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); - - // Reads a text file and loads the data into multiple buffers. - // @param file - the file to read. - // @param start_offset - the start offset of file. - // @param end_offset - the end offset of file. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, - const int32_t worker_id); - - // Calculate number of rows in each shard. - // @return Status - the error code returned. - Status CalculateNumRowsPerShard(); - - // Count number of rows in each file. - // @param filename - text file name. - // @return int64_t - the total number of rows in file. - int64_t CountTotalRows(const std::string &file); - - // Notifies the thread which called FillIoBlockQueue to resume execution - void NotifyToFillIOBlockQueue(); - - // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. - // @return Status - the error code returned. - Status WaitToFillIOBlockQueue(); - - // Fill the IOBlockQueue. - // @para i_keys - keys of file to fill to the IOBlockQueue - // @return Status - the error code returned. - Status FillIOBlockQueue(const std::vector &i_keys); - - // Select file and push it to the block queue. - // @param file_name - File name. - // @param start_file - If file contains the first sample of data. - // @param end_file - If file contains the end sample of data. - // @param pre_count - Total rows of previous files. - // @return Status - the error code returned. - bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count); - - // Pops an element from a queue in IOBlockQueue. - // @param index - the index of the queue to pop from. - // @param out_block - the popped element. - // @return Status - the error code returned. - Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); - - // Pushes an element to a queue in IOBlockQueue. - // @param index - the index of the queue to push to. - // @param io_block - the element to push onto the queue. - // @return Status - the error code returned. - Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. - // When the worker pops this control indicator, it will shut itself down gracefully. - // @return Status - the error code returned. - Status PostEndOfData(); - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker - // pops this control indicator, it will wait until the next epoch starts and then resume execution. - // @return Status - the error code returned. - Status PostEndOfEpoch(int32_t queue_index); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t device_id_; - int32_t num_devices_; - int64_t rows_per_buffer_; - int64_t total_rows_; - std::vector text_files_list_; - bool shuffle_files_; - std::unique_ptr data_schema_; - int64_t all_num_rows_; - int64_t num_rows_per_shard_; - std::map filename_numrows_; - std::unique_ptr filename_index_; - QueueList> io_block_queues_; - WaitPost io_block_queue_wait_post_; - bool finished_reading_dataset_; - bool load_io_block_queue_; - bool load_jagged_connector_; - std::unique_ptr jagged_buffer_connector_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc deleted file mode 100644 index 48f13ff766..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ /dev/null @@ -1,1057 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/tf_reader_op.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "proto/example.pb.h" -#include "./securec.h" -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/connector.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/jagged_connector.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "dataset/util/task_manager.h" -#include "dataset/util/wait_post.h" -#include "utils/system/crc32c.h" - -namespace mindspore { -namespace dataset { -TFReaderOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_total_rows_(0), - builder_equal_rows_per_shard_(false), - builder_sampler_(nullptr) { - std::shared_ptr config_manager = GlobalContext::config_manager(); - builder_num_workers_ = config_manager->num_parallel_workers(); - builder_worker_connector_size_ = config_manager->worker_connector_size(); - builder_op_connector_size_ = config_manager->op_connector_size(); - builder_rows_per_buffer_ = config_manager->rows_per_buffer(); - builder_shuffle_files_ = false; - builder_data_schema_ = std::make_unique(); -} - -bool ValidateFirstRowCrc(const std::string &filename) { - std::ifstream reader; - reader.open(filename); - if (!reader) { - return false; - } - - // read data - int64_t record_length = 0; - (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); - - // read crc from file - uint32_t masked_crc = 0; - (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); - - // generate crc from data - uint32_t generated_crc = - system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); - - return masked_crc == generated_crc; -} - -Status TFReaderOp::Builder::ValidateInputs() const { - std::string err_msg; - - if (builder_num_workers_ <= 0) { - err_msg += "Number of parallel workers is smaller or equal to 0\n"; - } - - if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { - err_msg += "Wrong sharding configs\n"; - } - - std::vector invalid_files(builder_dataset_files_list_.size()); - auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), - [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); - invalid_files.resize(std::distance(invalid_files.begin(), it)); - - if (!invalid_files.empty()) { - err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; - - std::string accumulated_filenames = std::accumulate( - invalid_files.begin(), invalid_files.end(), std::string(""), - [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); - err_msg += accumulated_filenames; - } - - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) { - RETURN_IF_NOT_OK(ValidateInputs()); - - // Throttle the number of workers if we have more workers than files! - if (static_cast(builder_num_workers_) > builder_dataset_files_list_.size()) { - builder_num_workers_ = builder_dataset_files_list_.size(); - MS_LOG(WARNING) << "TFReader operator parallelism reduced to " << builder_num_workers_ << " workers."; - } - - std::shared_ptr new_tf_reader_op = std::make_shared( - builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, - builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, - builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_, - std::move(builder_sampler_)); - - RETURN_IF_NOT_OK(new_tf_reader_op->Init()); - *out_tf_reader_op = std::move(new_tf_reader_op); - return Status::OK(); -} - -TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, - int64_t total_num_rows, std::vector dataset_files_list, - std::unique_ptr data_schema, int32_t op_connector_size, - std::vector columns_to_load, bool shuffle_files, int32_t num_device, - int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), - device_id_(device_id), - num_devices_(num_device), - rows_per_buffer_(rows_per_buffer), - total_rows_(total_num_rows), - dataset_files_list_(std::move(dataset_files_list)), - columns_to_load_(std::move(columns_to_load)), - finished_reading_dataset_(false), - shuffle_files_(shuffle_files), - data_schema_(std::move(data_schema)), - filename_index_(std::make_unique()), - load_io_block_queue_(true), - load_jagged_connector_(true), - num_rows_(0), - num_rows_per_shard_(0), - equal_rows_per_shard_(equal_rows_per_shard) { - worker_connector_size_ = worker_connector_size; -} - -// A print method typically used for debugging -void TFReaderOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ - << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") - << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n"; - for (int i = 0; i < dataset_files_list_.size(); ++i) { - out << " " << dataset_files_list_[i]; - } - if (!columns_to_load_.empty()) { - out << "\nColumns to load:\n"; - for (int i = 0; i < columns_to_load_.size(); ++i) { - out << " " << columns_to_load_[i]; - } - } - out << "\nData Schema:\n"; - out << *data_schema_ << "\n\n"; - } -} - -Status TFReaderOp::Init() { - if (data_schema_->Empty()) { - RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_)); - } - - if (total_rows_ == 0) { - total_rows_ = data_schema_->num_rows(); - } - if (total_rows_ < 0) { - RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0"); - } - - // Build the index with our files such that each file corresponds to a key id. - RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_)); - - // The creation of the internal connector has been delayed until now, since we may have adjusted the - // number of workers. Now that the worker count is established, create the connector now in the - // parallel op base. - RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); - - jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); - - // temporary: make size large enough to hold all files + EOE to avoid hangs - int32_t safe_queue_size = static_cast(std::ceil(dataset_files_list_.size() / num_workers_)) + 1; - io_block_queues_.Init(num_workers_, safe_queue_size); - - return Status::OK(); -} - -Status TFReaderOp::CalculateNumRowsPerShard() { - if (!equal_rows_per_shard_) { - return Status::OK(); - } - - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - std::vector file(1, it.value()); - int64_t num = CountTotalRowsSectioned(file, 0, 1); - filename_numrows_[it.value()] = num; - num_rows_ += num; - } - num_rows_per_shard_ = static_cast(std::ceil(num_rows_ * 1.0 / num_devices_)); - if (num_rows_per_shard_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API TFRecordDataset.Please check file path or dataset API " - "validation first."); - } - return Status::OK(); -} -// Class functor operator () override. -// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will -// provide the master loop that drives the logic for performing the work -Status TFReaderOp::operator()() { - RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - - // launch one thread, responsible for filling mIOBlockQueue - RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::WaitToFillIOBlockQueue, this))); - - // launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading - // data from disk into buffers - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1))); - - // must be called after launching workers. workers can't be spawned after this post, - // so workers have to be kept alive until the end of the program - TaskManager::FindMe()->Post(); - - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); - - NotifyToFillIOBlockQueue(); - while (!finished_reading_dataset_) { - int64_t buffer_id = 0; - int32_t workers_done = 0; - int64_t rows_read = 0; - { - std::unique_lock lock(load_io_block_queue_mutex_); - load_io_block_queue_ = true; - } - - while (workers_done < num_workers_) { - std::unique_ptr fetched_buffer; - RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer)); - if (fetched_buffer->eoe()) { - workers_done++; - } else if (total_rows_ == 0 || rows_read < total_rows_) { - // we need to push a buffer - if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) { - // this is last buffer we need, and we only need a part of it - int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read); - RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove)); - } - - rows_read += fetched_buffer->NumRows(); - fetched_buffer->set_id(buffer_id); - buffer_id++; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); - } else { - // user specified number of rows they want, and we read enough rows - // - // IOBlockQueue thread needs to: - // -stop pushing stuff to IOBlockQueue - // -call PostEndOfEpoch (will send EOE) - // -wait for reset - // - // Worker threads need to: - // -stop reading the file they are currently reading and throw it away - // -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE) - // - // Master thread needs to: - // -tell IOBlockQueue thread to stop pushing - // -tell worker threads to stop reading the file tey are currently reading - // -keep pulling until EOE - - // don't think we need a lock for now - load_jagged_connector_ = false; - - std::unique_lock lock(load_io_block_queue_mutex_); - load_io_block_queue_ = false; - } - } - - // all workers finished reading for this epoch, and we have read all the data from all workers - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - finished_reading_dataset_ = true; - NotifyToFillIOBlockQueue(); - } else { - jagged_buffer_connector_->DoReset(); - buffer_id = 0; - } - } - - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - - RETURN_IF_NOT_OK(PostEndOfData()); - - return Status::OK(); -} - -// static local-only helper function -static void shuffleKeys(std::vector *i_keys, uint32_t seed) { - std::mt19937 rng(seed); - std::shuffle(i_keys->begin(), i_keys->end(), rng); -} - -// The entry point for when workers are launched. -Status TFReaderOp::WorkerEntry(int32_t worker_id) { - // must be called first if called by worker spawned by taskgroup - TaskManager::FindMe()->Post(); - - std::unique_ptr io_block; - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - - while (!io_block->eof()) { - if (!io_block->eoe()) { - if (load_jagged_connector_) { - std::string filename; - RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); - int64_t start_offset = io_block->GetStartOffset(); - int64_t end_offset = io_block->GetEndOffset(); - RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); - MS_LOG(DEBUG) << "TFReader operator worker " << worker_id << " loaded file " << filename << "."; - } - } else { - std::unique_ptr eoe_buffer = std::make_unique(1, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); - } - - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - } - - return Status::OK(); -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. -// When the worker pops this control indicator, it will shut itself down gracefully. -Status TFReaderOp::PostEndOfData() { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); - } - - return Status::OK(); -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker -// pops this control indicator, it will wait until the next epoch starts and then resume execution. -Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); - } - - return Status::OK(); -} - -bool TFReaderOp::NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count) { - *start_offset = 0; - *end_offset = 0; - bool push = false; - int64_t start_index = device_id_ * num_rows_per_shard_; - if (device_id_ + 1 < 0) { - MS_LOG(ERROR) << "Device id is invalid"; - return false; - } - int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; - - if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { - *start_offset = start_index - pre_count; - push = true; - if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - if (pre_count >= start_index && pre_count < end_index) { - *start_offset = 0; - push = true; - if (pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - return push; -} - -Status TFReaderOp::FillIOBlockShuffle(const std::vector &i_keys) { - int32_t queue_index = 0; - int32_t key_index = 0; - int64_t pre_count = 0; - int64_t start_offset = 0; - int64_t end_offset = 0; - bool finish = false; - bool end_of_epoch = false; - while (!finish) { - for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { - { - std::unique_lock lock(load_io_block_queue_mutex_); - if (load_io_block_queue_ == false) { - end_of_epoch = true; - break; - } - } - if (!equal_rows_per_shard_) { - if (key_index++ % num_devices_ == device_id_) { - auto ioBlock = std::make_unique(*it, kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone); - RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); - queue_index = (queue_index + 1) % num_workers_; - } - } else { - // Do an index lookup using that key to get the filename. - std::string file_name = (*filename_index_)[*it]; - if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { - auto ioBlock = std::make_unique(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone); - RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); - MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset; - queue_index = (queue_index + 1) % num_workers_; - } - - pre_count += filename_numrows_[file_name]; - } - } - if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_ && - !end_of_epoch) { - finish = false; - } else { - finish = true; - } - } - RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); - return Status::OK(); -} - -Status TFReaderOp::FillIOBlockNoShuffle() { - int32_t queue_index = 0; - int32_t key_index = 0; - int64_t pre_count = 0; - int64_t start_offset = 0; - int64_t end_offset = 0; - bool finish = false; - bool end_of_epoch = false; - while (!finish) { - // Iterate over all the keys and add one key to each block. - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - { - std::unique_lock lock(load_io_block_queue_mutex_); - if (load_io_block_queue_ == false) { - end_of_epoch = true; - break; - } - } - if (!equal_rows_per_shard_) { - if (key_index++ % num_devices_ == device_id_) { - auto ioBlock = - std::make_unique(it.key(), kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone); - RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); - queue_index = (queue_index + 1) % num_workers_; - } - } else { - std::string file_name = it.value(); - if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { - auto ioBlock = std::make_unique(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone); - RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); - queue_index = (queue_index + 1) % num_workers_; - } - - pre_count += filename_numrows_[file_name]; - } - } - if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_ && - !end_of_epoch) { - finish = false; - } else { - finish = true; - } - } - - RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); - return Status::OK(); -} - -// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. -Status TFReaderOp::WaitToFillIOBlockQueue() { - // must be called first if called by worker spawned by taskgroup - TaskManager::FindMe()->Post(); - - std::vector i_keys; - // Generate a vector of keys that we can shuffle - if (shuffle_files_) { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - i_keys.push_back(it.key()); - } - } - uint32_t seed = 0; - while (true) { - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); - io_block_queue_wait_post_.Clear(); - - if (finished_reading_dataset_) { - break; - } - - if (shuffle_files_) { - shuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); - RETURN_IF_NOT_OK(FillIOBlockShuffle(i_keys)); - } else { // shuffle_files_ == false - RETURN_IF_NOT_OK(FillIOBlockNoShuffle()); - } - } - - return Status::OK(); -} - -// Notifies the thread which called WaitToFillIOBlockQueue to resume execution. -void TFReaderOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } - -// Pops an element from a queue in io_block_queues -Status TFReaderOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); - - return Status::OK(); -} - -// Pushes an element to a queue in io_block_queues -Status TFReaderOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); - - return Status::OK(); -} - -// Reads a tf_file file and loads the data into multiple buffers. -Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, - const int32_t &worker_id) { - std::ifstream reader; - reader.open(filename); - if (!reader) { - RETURN_STATUS_UNEXPECTED("failed to open file: " + filename); - } - - int64_t rows_read = 0; - int64_t rows_total = 0; - std::unique_ptr current_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - std::unique_ptr new_tensor_table = std::make_unique(); - - while (reader.peek() != EOF) { - if (!load_jagged_connector_) { - break; - } - - // read length - int64_t record_length = 0; - (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); - - // ignore crc header - (void)reader.ignore(static_cast(sizeof(int32_t))); - - // read serialized Example - std::string serialized_example; - serialized_example.resize(record_length); - (void)reader.read(&serialized_example[0], static_cast(record_length)); - if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) { - dataengine::Example tf_file; - if (!tf_file.ParseFromString(serialized_example)) { - std::string errMsg = "parse tfrecord failed"; - RETURN_STATUS_UNEXPECTED(errMsg); - } - RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); - rows_read++; - } - - // ignore crc footer - (void)reader.ignore(static_cast(sizeof(int32_t))); - rows_total++; - - if (rows_read == rows_per_buffer_) { - current_buffer->set_tensor_table(std::move(new_tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(current_buffer))); - - current_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - new_tensor_table = std::make_unique(); - rows_read = 0; - } - } - - if (rows_read > 0) { - current_buffer->set_tensor_table(std::move(new_tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(current_buffer))); - } - - return Status::OK(); -} - -// Parses a single row and puts the data into a tensor table. -Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, std::unique_ptr *tensor_table, - int64_t row) { - int32_t num_columns = data_schema_->NumColumns(); - TensorRow newRow(num_columns, nullptr); - (*tensor_table)->push_back(std::move(newRow)); - - for (int32_t col = 0; col < num_columns; ++col) { - const ColDescriptor current_col = data_schema_->column(col); - const dataengine::Features &example_features = tf_file->features(); - const google::protobuf::Map &feature_map = example_features.feature(); - const dataengine::Feature &column_values_list = feature_map.at(current_col.name()); - RETURN_IF_NOT_OK(LoadFeature(tensor_table, column_values_list, current_col, row, col)); - } - - return Status::OK(); -} - -// Parses a single cell and puts the data into a tensor table. -Status TFReaderOp::LoadFeature(const std::unique_ptr *tensor_table, - const dataengine::Feature &column_values_list, const ColDescriptor ¤t_col, - int64_t row, int32_t col) { - const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case(); - std::unique_ptr float_array; // For staging data from protobuf deserialization - const unsigned char *data_ptr = nullptr; // Generic pointer used for populating the Tensor - - // This variable will point into the above staging variables. - // Also used for creating shape attributes. - int32_t num_elements = 0; - - // we build a tensor first a read directly into it if we need to cast - std::shared_ptr ts; - - // Depending on the type of data from the tf_file, we want to extract 2 things: - // 1) A pointer to the data as a const unsigned char * - // 2) The number of elements of the data - // After those are determined, we can then build the tensor to represent this data. - switch (column_list_type) { - case dataengine::Feature::KindCase::kBytesList: { - RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &num_elements, &ts)); - - break; - } - case dataengine::Feature::KindCase::kFloatList: { - RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array)); - - data_ptr = reinterpret_cast(float_array.get()); - - // only floatList needs to create the tensor here, other two lists read directly - // into the tensor - TensorShape current_shape = TensorShape::CreateUnknownRankShape(); - RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, ¤t_shape)); - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&ts, current_col.tensorImpl(), current_shape, current_col.type(), data_ptr)); - break; - } - case dataengine::Feature::KindCase::kInt64List: { - RETURN_IF_NOT_OK(LoadIntListSwitch(current_col, column_values_list, &num_elements, &ts)); - break; - } - case dataengine::Feature::KindCase::KIND_NOT_SET: { - std::string err_msg = "tf_file column list type enum is KIND_NOT_SET"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - default: { - std::string err_msg = "tf_file column list type enum does not match any known DE type"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - - (**tensor_table)[row][col] = std::move(ts); - - return Status::OK(); -} - -// Overrides base class reset method. Cleans up any state info from it's previous execution and -// reinitializes itself so that it can be executed again, as if it was just created. -Status TFReaderOp::Reset() { - // start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true - load_jagged_connector_ = true; - - { - std::unique_lock lock(load_io_block_queue_mutex_); - load_io_block_queue_ = true; - } - - RETURN_IF_NOT_OK(ParallelOp::Reset()); - NotifyToFillIOBlockQueue(); - - return Status::OK(); -} - -Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor) { - // kBytesList can map to the following DE types ONLY! - // DE_UINT8, DE_INT8 - // Must be single byte type for each element! - if (current_col.type() != DataType::DE_UINT8 && current_col.type() != DataType::DE_INT8 && - current_col.type() != DataType::DE_STRING) { - std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - const dataengine::BytesList &bytes_list = column_values_list.bytes_list(); - - *num_elements = bytes_list.value_size(); - - if (current_col.type() == DataType::DE_STRING) { - TensorShape shape = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, shape)); - return Status::OK(); - } - - uint64_t max_size = 0; - for (uint32_t i = 0; i < bytes_list.value_size(); ++i) max_size = std::max(max_size, bytes_list.value(i).size()); - - int64_t pad_size = max_size; - - // if user provides a shape in the form of [-1, d1, 2d, ... , dn], we need to pad to d1 * d2 * ... * dn - if (current_col.hasShape()) { - TensorShape cur_shape = current_col.shape(); - if (cur_shape.Size() >= 2 && cur_shape[0] == TensorShape::kDimUnknown) { - int64_t new_pad_size = 1; - for (int i = 1; i < cur_shape.Size(); ++i) { - if (cur_shape[i] == TensorShape::kDimUnknown) { - std::string err_msg = "More than one unknown dimension in the shape of column: " + current_col.name(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - new_pad_size *= cur_shape[i]; - } - pad_size = new_pad_size; - } - } - - // know how many elements there are and the total bytes, create tensor here: - TensorShape current_shape = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, ¤t_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, current_shape, current_col.type(), pad_size)); - - return Status::OK(); -} - -Status TFReaderOp::LoadFloatList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::unique_ptr *float_array) { - // KFloatList can only map to DE types: - // DE_FLOAT32 - if (current_col.type() != DataType::DE_FLOAT32) { - std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - const dataengine::FloatList &float_list = column_values_list.float_list(); - - // Identify how many values we have and then create a local array of these - // to deserialize into - *num_elements = float_list.value_size(); - *float_array = std::make_unique(*num_elements); - for (int i = 0; i < float_list.value_size(); ++i) { - (*float_array)[i] = float_list.value(i); - } - - return Status::OK(); -} - -// Determines which template type to use and calls LoadIntList -Status TFReaderOp::LoadIntListSwitch(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor) { - if (current_col.type() == DataType::DE_UINT64) { - RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); - } else if (current_col.type() == DataType::DE_INT64) { - RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); - } else if (current_col.type() == DataType::DE_UINT32) { - RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); - } else if (current_col.type() == DataType::DE_INT32) { - RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); - } else if (current_col.type() == DataType::DE_UINT16) { - RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); - } else if (current_col.type() == DataType::DE_INT16) { - RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); - } else if (current_col.type() == DataType::DE_UINT8) { - RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); - } else if (current_col.type() == DataType::DE_INT8) { - RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); - } else { - std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - return Status::OK(); -} - -// Reads values from a bytes list and casts the value to type T, must be an integral type -// compatible with int64_t -template -Status TFReaderOp::LoadIntList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor) { - if (!(current_col.type().IsInt())) { - std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - const dataengine::Int64List &int64_list = column_values_list.int64_list(); - - // Identify how many values we have and then create a local array of these - // to deserialize into - *num_elements = int64_list.value_size(); - - // know how many elements there are, create tensor here: - TensorShape current_shape = TensorShape::CreateUnknownRankShape(); - RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, ¤t_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, current_col.tensorImpl(), current_shape, current_col.type())); - - // Tensors are lazily allocated, this eagerly allocates memory for the tensor. - RETURN_IF_NOT_OK((*tensor)->AllocateBuffer((*tensor)->SizeInBytes())); - - int64_t i = 0; - auto it = (*tensor)->begin(); - for (; it != (*tensor)->end(); i++, ++it) { - T element = static_cast(int64_list.value(i)); - *it = element; - } - - return Status::OK(); -} - -Status TFReaderOp::CreateSchema(const std::string tf_file, std::vector columns_to_load) { - std::ifstream reader; - reader.open(tf_file); - - // read length - int64_t record_length = 0; - (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); - - // ignore crc header - (void)reader.ignore(static_cast(sizeof(int32_t))); - - // read serialized Example - std::string serialized_example; - serialized_example.resize(record_length); - (void)reader.read(&serialized_example[0], static_cast(record_length)); - - dataengine::Example example; - if (!example.ParseFromString(serialized_example)) RETURN_STATUS_UNEXPECTED("parse tf_file failed"); - - const dataengine::Features &example_features = example.features(); - const google::protobuf::Map &feature_map = example_features.feature(); - - if (columns_to_load.empty()) { - (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load), - [](const auto &it) -> std::string { return it.first; }); - std::sort(columns_to_load.begin(), columns_to_load.end()); - } - - for (const auto &curr_col_name : columns_to_load) { - auto it = feature_map.find(curr_col_name); - if (it == feature_map.end()) { - RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); - } - std::string column_name = it->first; - - std::string column_type; - - const dataengine::Feature &feature = it->second; - const dataengine::Feature::KindCase kind_case = feature.kind_case(); - switch (kind_case) { - case dataengine::Feature::KindCase::kBytesList: - column_type = "uint8"; - break; - - case dataengine::Feature::KindCase::kFloatList: - column_type = "float32"; - break; - - case dataengine::Feature::KindCase::kInt64List: - column_type = "int64"; - break; - - case dataengine::Feature::KindCase::KIND_NOT_SET: - RETURN_STATUS_UNEXPECTED("trying to make schema, tf_file column list type enum is KIND_NOT_SET"); - - default: - RETURN_STATUS_UNEXPECTED( - "trying to make schema, tf_file column list type enum does not match any known DE type"); - } - - RETURN_IF_NOT_OK( - data_schema_->AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1))); - } - - return Status::OK(); -} - -Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector &filenames, int64_t threads, - bool estimate) { - try { - if (threads > filenames.size()) { - threads = filenames.size(); - } - - std::vector> async_results; - - int64_t chunk_size = filenames.size() / threads; - int64_t remainder = filenames.size() % threads; - - int64_t begin = 0; - int64_t end = begin; - for (int i = 0; i < threads; i++) { - end += chunk_size; - if (remainder > 0) { - end++; - remainder--; - } - - if (estimate) { - // Parse a single file for each chunk with estimate mode on - async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, begin + 1)); - } else { - // Parse the whole chunk with estimate mode off - async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, end)); - } - - begin = end; - } - - int64_t total_rows = 0; - for (int i = 0; i < async_results.size(); i++) { - total_rows += async_results[i].get(); - } - - if (estimate) { - // Each thread only scans 1 file - // Estimated total rows = Average rows * total number of files - total_rows = total_rows / threads * filenames.size(); - } - - *out_total_rows = total_rows; - } catch (const std::exception &e) { - std::string err_msg = "Unexpected error occurred: "; - err_msg += e.what(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - return Status::OK(); -} - -int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector &filenames, int64_t begin, int64_t end) { - int64_t rows_read = 0; - for (int i = begin; i < end; i++) { - std::ifstream reader; - reader.open(filenames[i]); - if (!reader) { - MS_LOG(DEBUG) << "TFReader operator failed to open file " << filenames[i] << "."; - } - - while (reader.peek() != EOF) { - // read length - int64_t record_length = 0; - (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); - - // ignore crc header - (void)reader.ignore(static_cast(sizeof(int32_t))); - - // ignore tf_file contents - (void)reader.ignore(static_cast(record_length)); - - // ignore crc footer - (void)reader.ignore(static_cast(sizeof(int32_t))); - - rows_read++; - } - } - - return rows_read; -} - -// Visitor accept method for NodePass -Status TFReaderOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status TFReaderOp::ComputeColMap() { - // Construct the column name map for this operator (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -// During tree prepare phase, operators may have specific post-operations to perform depending on -// their role. -Status TFReaderOp::PrepareNodePostAction() { - // Run common code from super class before adding TFReaderOp specific handling - RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); - - // Specific handling for this op, we need to do cache op work so assign the sampler to the cache - // TF is a special case because it can support file-based sharding/shuffling, or, if there - // is a cache, then it can also do row-based sampler using the sampler on the cache. - // Thus, pass true for random access op flag when saving the sampler. This is a special case, - // since usually a non-mappable dataset would pass false here. - RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true)); - - // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into - // a simpler producer of all data (no shuffling or sharding or anything) - if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { - device_id_ = 0; - num_devices_ = 1; - total_rows_ = 0; - shuffle_files_ = false; - equal_rows_per_shard_ = false; - sampler_.reset(); // Normally SaveSampler code did this for us, but we passed in true above (See comment) - } else { - // This sanity check had been delayed until now in the prepare loop. - // If we are not in a cache path, then we can validate the the file-based sharding config. - // If we are in a cache path, there is no file-based sharding so the check is not correct in that - // situation. - if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast(num_devices_)) { - RETURN_STATUS_UNEXPECTED("Not enough tfrecord files provided\n"); - } - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h deleted file mode 100644 index 9226c4c6c5..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h +++ /dev/null @@ -1,415 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/util/wait_post.h" -#include "dataset/util/auto_index.h" -#include "dataset/util/status.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" - -namespace dataengine { -class Example; -class Feature; -class BytesList; -} // namespace dataengine - -namespace mindspore { -namespace dataset { -template -class Queue; - -template -class Connector; - -class JaggedConnector; -class FilenameBlock; - -using StringIndex = AutoIndexObj; - -class TFReaderOp : public ParallelOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Checks if the inputs of the builder is valid. - // @return Status - the error code returned. - Status ValidateInputs() const; - - Status Build(std::shared_ptr *out_tf_reader_op); - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDataSchema(std::unique_ptr data_schema) { - builder_data_schema_ = std::move(data_schema); - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetWorkerConnectorSize(int32_t size) { - builder_worker_connector_size_ = size; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumDevices(int64_t num_dev) { - builder_num_devices_ = num_dev; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDeviceId(int64_t dev_id) { - builder_device_id_ = dev_id; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &setTotalRows(int64_t total_rows) { - builder_total_rows_ = total_rows; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDatasetFilesList(const std::vector &dataset_files_list) { - builder_dataset_files_list_ = dataset_files_list; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetColumnsToLoad(const std::vector &columns_to_load) { - builder_columns_to_load_ = columns_to_load; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShuffleFiles(bool shuffle_files) { - builder_shuffle_files_ = shuffle_files; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShardEqualRows(bool shard_equal_rows) { - builder_equal_rows_per_shard_ = shard_equal_rows; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - private: - std::unique_ptr builder_data_schema_; - std::shared_ptr builder_sampler_; - int32_t builder_device_id_; - int32_t builder_num_devices_; - int32_t builder_num_workers_; - int32_t builder_worker_connector_size_; - int32_t builder_op_connector_size_; - int64_t builder_rows_per_buffer_; - int64_t builder_total_rows_; - std::vector builder_dataset_files_list_; - std::vector builder_columns_to_load_; - bool builder_shuffle_files_; - bool builder_equal_rows_per_shard_; - }; - - // Constructor of TFReaderOp (2) - // @note The builder class should be used to call this constructor. - // @param num_workers - number of worker threads reading data from tf_file files. - // @param worker_connector_size - size of each internal queue. - // @param rows_per_buffer - number of rows that a full buffer will contain. - // @param total_num_rows - Number of rows to read - // @param dataset_files_list - list of filepaths for the dataset files. - // @param data_schema - the data schema object. - // @param op_connector_size - size of each queue in the connector that the child operator pulls from. - // @param columns_to_load - the names of the columns to load data from. - // @param shuffle_files - whether or not to shuffle the files before reading data. - // @param equal_rows_per_shard - whether or not to get equal rows for each process. - // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes - TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, - std::vector dataset_files_list, std::unique_ptr data_schema, - int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, - int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler); - - // Default destructor - ~TFReaderOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // Instantiates the internal queues and connectors. - // @return Status - the error code returned. - Status Init(); - - // Class functor operator () override. - // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - the error code returned. - Status operator()() override; - - // Overrides base class reset method. Cleans up any state info from it's previous execution and - // reinitializes itself so that it can be executed again, as if it was just created. - // @return Status - the error code returned. - Status Reset() override; - - // Getter method - int64_t rows_per_buffer() const { return rows_per_buffer_; } - - // Reads all the provided tf_file files and counts the total number of rows. filenames will - // first be sectioned into equal parts, then sections are read in parallel. If threads is - // greater than the number of files, threads will be clamped to the number of files. - // @param out_total_tows - output parameter which contains the total number of rows - // @param filenames - a list of tf_file filenames. - // @param threads - number of threads to use to read the tf_file files. - // @param estimate - estimate mode, under this mode each threads will sample a single file from each chunk - // @return Status - the error code returned. - static Status CountTotalRows(int64_t *out_total_rows, const std::vector &filenames, int64_t threads = 1, - bool estimate = false); - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "TFReaderOp"; } - - // File names getter - // @return Vector of the input file names - std::vector FileNames() { return dataset_files_list_; } - - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePostAction() override; - - private: - // The entry point for when workers are launched. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status WorkerEntry(int32_t worker_id) override; - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. - // When the worker pops this control indicator, it will shut itself down gracefully. - // @return Status - the error code returned. - Status PostEndOfData(); - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker - // pops this control indicator, it will wait until the next epoch starts and then resume execution. - // @return Status - the error code returned. - Status PostEndOfEpoch(int32_t queue_index); - - // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. - // @return Status - the error code returned. - Status WaitToFillIOBlockQueue(); - - // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. - void NotifyToFillIOBlockQueue(); - - // Pops an element from a queue in IOBlockQueue. - // @param index - the index of the queue to pop from. - // @param out_block - the popped element. - // @return Status - the error code returned. - Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); - - // Pushes an element to a queue in IOBlockQueue. - // @param index - the index of the queue to push to. - // @param io_block - the element to push onto the queue. - // @return Status - the error code returned. - Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); - - // Reads a tf_file file and loads the data into multiple buffers. - // @param filename - the tf_file file to read. - // @param start_offset - the start offset of file. - // @param end_offset - the end offset of file. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, - const int32_t &worker_id); - - // Parses a single row and puts the data into a tensor table. - // @param tf_file - the row to be parsed. - // @param tensor_table - the tensor table to put the parsed data in. - // @param row - the id of the row filled in the tensor table. - // @return Status - the error code returned. - Status LoadExample(const dataengine::Example *tf_file, std::unique_ptr *tensor_table, int64_t row); - - // Parses a single cell and puts the data into a tensor table. - // @param tensor_table - the tensor table to put the parsed data in. - // @param column_values_list - the cell to parse. - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @return Status - the error code returned. - Status LoadFeature(const std::unique_ptr *tensor_table, const dataengine::Feature &column_values_list, - const ColDescriptor ¤t_col, int64_t row, int32_t col); - - // Reads values from a bytes list - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @param column_values_list - the cell that contains the bytes list to read from. - // @param elementStr - the string we read the value into. - // @return Status - the error code returned. - static Status LoadBytesList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor); - - // Reads values from a float list - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @param column_values_list - the cell that contains the float list to read from. - // @Param numElements - number of values in the float list. - // @param float_array - the array we read the values into. - // @return Status - the error code returned. - Status LoadFloatList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::unique_ptr *float_array); - - // Reads values from a bytes list and casts the value to type T, must be an integral - // type compatible with int64_t - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @param column_values_list - the cell that contains the int list to read from. - // @Param num_elements - number of values in the int list. - // @param tensor - the tensor we read the values into. - // @return Status - the error code returned. - template - Status LoadIntList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor); - - // Determines which template type to use and calls LoadIntList - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @param column_values_list - the cell that contains the int list to read from. - // @Param numElements - number of values in the int list. - // @param tensor - the tensor we read the values into. - // @return Status - the error code returned. - Status LoadIntListSwitch(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor); - - // Reads one row of data from a tf file and creates a schema based on that row - // @return Status - the error code returned. - Status CreateSchema(const std::string tf_file, std::vector columns_to_load); - - // Meant to be called async. Will read files in the range [begin, end) and return the total rows - // @param filenames - a list of tf data filenames. - // @param begin - index of first file to read. - // @param end - one greater than the index of the last file to read. - // @return int63_t - the total number of rows of files read. - static int64_t CountTotalRowsSectioned(const std::vector &filenames, const int64_t begin, - const int64_t end); - // Fill IO block queue if shuffle is true - // @param i_keys - shuffle keys. - // @return Status - the error code returned. - Status FillIOBlockShuffle(const std::vector &i_keys); - - /** - * Fill IO block queue if shuffle is false - * @param i_keys - shuffle keys. - * @return Status - the error code returned. - */ - Status FillIOBlockNoShuffle(); - - // Select file and push it to the block queue. - // @param file_name - File name. - // @param start_file - If file contains the first sample of data. - // @param end_file - If file contains the end sample of data. - // @param pre_count - Total rows of previous files. - // @return Status - the error code returned. - bool NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count); - - // Caculate number of rows in each shard. - // @return Status - the error code returned. - Status CalculateNumRowsPerShard(); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t device_id_; - int32_t num_devices_; - int64_t rows_per_buffer_; - int64_t total_rows_; - std::vector dataset_files_list_; - std::vector columns_to_load_; - bool finished_reading_dataset_; - bool shuffle_files_; - std::unique_ptr data_schema_; - std::unique_ptr filename_index_; - bool load_io_block_queue_; - bool load_jagged_connector_; - - std::unique_ptr jagged_buffer_connector_; - QueueList> io_block_queues_; - WaitPost io_block_queue_wait_post_; - std::mutex load_io_block_queue_mutex_; - std::map filename_numrows_; - int64_t num_rows_; - int64_t num_rows_per_shard_; - bool equal_rows_per_shard_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc deleted file mode 100644 index 958aa65b06..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc +++ /dev/null @@ -1,465 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/voc_op.h" - -#include -#include -#include -#include "./tinyxml2.h" -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" - -using tinyxml2::XMLDocument; -using tinyxml2::XMLElement; -using tinyxml2::XMLError; -namespace mindspore { -namespace dataset { -const char kColumnImage[] = "image"; -const char kColumnTarget[] = "target"; -const char kColumnAnnotation[] = "annotation"; -const char kJPEGImagesFolder[] = "/JPEGImages/"; -const char kSegmentationClassFolder[] = "/SegmentationClass/"; -const char kAnnotationsFolder[] = "/Annotations/"; -const char kImageSetsSegmentation[] = "/ImageSets/Segmentation/"; -const char kImageSetsMain[] = "/ImageSets/Main/"; -const char kImageExtension[] = ".jpg"; -const char kSegmentationExtension[] = ".png"; -const char kAnnotationExtension[] = ".xml"; -const char kImageSetsExtension[] = ".txt"; - -VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); - builder_task_type_ = TaskType::Segmentation; -} - -Status VOCOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - if (builder_task_type_ == TaskType::Segmentation) { - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - } else if (builder_task_type_ == TaskType::Detection) { - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - } - *ptr = std::make_shared(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, - builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, - builder_decode_, std::move(builder_schema_), std::move(builder_sampler_)); - return Status::OK(); -} - -Status VOCOp::Builder::SanityCheck() { - Path dir(builder_dir_); - std::string err_msg; - err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, - const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, - int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size, std::move(sampler)), - decode_(decode), - row_cnt_(0), - buf_cnt_(0), - task_type_(task_type), - task_mode_(task_mode), - folder_path_(folder_path), - class_index_(class_index), - rows_per_buffer_(rows_per_buffer), - data_schema_(std::move(data_schema)) { - io_block_queues_.Init(num_workers_, queue_size); -} - -Status VOCOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) > num_rows_) continue; - keys->push_back(*itr); - row_cnt_++; - if (row_cnt_ % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); - keys->clear(); - } - } - return Status::OK(); -} - -Status VOCOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - while (true) { - std::vector keys; - keys.reserve(rows_per_buffer_); - while (sampler_buffer->eoe() == false) { - std::shared_ptr sample_ids; - RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); - if (sample_ids->type() != DataType(DataType::DE_INT64)) { - RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); - } - RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - if (keys.empty() == false) { - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - } -} - -void VOCOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows: " << num_rows_ << "\nVOC Directory: " << folder_path_ << "\n\n"; - } -} - -Status VOCOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); - return Status::OK(); -} - -Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *trow) { - if (task_type_ == TaskType::Segmentation) { - std::shared_ptr image, target; - const std::string kImageFile = - folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); - const std::string kTargetFile = - folder_path_ + std::string(kSegmentationClassFolder) + image_id + std::string(kSegmentationExtension); - RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); - RETURN_IF_NOT_OK(ReadImageToTensor(kTargetFile, data_schema_->column(1), &target)); - (*trow) = TensorRow(row_id, {std::move(image), std::move(target)}); - } else if (task_type_ == TaskType::Detection) { - std::shared_ptr image, annotation; - const std::string kImageFile = - folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); - const std::string kAnnotationFile = - folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); - RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); - RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, data_schema_->column(1), &annotation)); - (*trow) = TensorRow(row_id, {std::move(image), std::move(annotation)}); - } - return Status::OK(); -} - -Status VOCOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - TensorRow trow; - for (const uint64_t &key : keys) { - RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_ids_[key], &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -Status VOCOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, (std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty() == true) return Status::OK(); - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -Status VOCOp::ParseImageIds() { - std::string image_sets_file; - if (task_type_ == TaskType::Segmentation) { - image_sets_file = - folder_path_ + std::string(kImageSetsSegmentation) + task_mode_ + std::string(kImageSetsExtension); - } else if (task_type_ == TaskType::Detection) { - image_sets_file = folder_path_ + std::string(kImageSetsMain) + task_mode_ + std::string(kImageSetsExtension); - } - std::ifstream in_file; - in_file.open(image_sets_file); - if (in_file.fail()) { - RETURN_STATUS_UNEXPECTED("Fail to open file: " + image_sets_file); - } - std::string id; - while (getline(in_file, id)) { - if (id.size() > 0 && id[id.size() - 1] == '\r') { - image_ids_.push_back(id.substr(0, id.size() - 1)); - } else { - image_ids_.push_back(id); - } - } - in_file.close(); - image_ids_.shrink_to_fit(); - num_rows_ = image_ids_.size(); - return Status::OK(); -} - -Status VOCOp::ParseAnnotationIds() { - std::vector new_image_ids; - for (auto id : image_ids_) { - const std::string kAnnotationName = - folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension); - RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName)); - if (label_map_.find(kAnnotationName) != label_map_.end()) { - new_image_ids.push_back(id); - } - } - - if (image_ids_.size() != new_image_ids.size()) { - image_ids_.clear(); - image_ids_.insert(image_ids_.end(), new_image_ids.begin(), new_image_ids.end()); - } - uint32_t count = 0; - for (auto &label : label_index_) { - label.second = count++; - } - - num_rows_ = image_ids_.size(); - return Status::OK(); -} - -Status VOCOp::ParseAnnotationBbox(const std::string &path) { - if (!Path(path).Exists()) { - RETURN_STATUS_UNEXPECTED("File is not found : " + path); - } - Bbox bbox; - XMLDocument doc; - XMLError e = doc.LoadFile(common::SafeCStr(path)); - if (e != XMLError::XML_SUCCESS) { - RETURN_STATUS_UNEXPECTED("Xml load failed"); - } - XMLElement *root = doc.RootElement(); - if (root == nullptr) { - RETURN_STATUS_UNEXPECTED("Xml load root element error"); - } - XMLElement *object = root->FirstChildElement("object"); - if (object == nullptr) { - RETURN_STATUS_UNEXPECTED("No object find in " + path); - } - while (object != nullptr) { - std::string label_name; - uint32_t xmin = 0, ymin = 0, xmax = 0, ymax = 0, truncated = 0, difficult = 0; - XMLElement *name_node = object->FirstChildElement("name"); - if (name_node != nullptr && name_node->GetText() != 0) label_name = name_node->GetText(); - XMLElement *truncated_node = object->FirstChildElement("truncated"); - if (truncated_node != nullptr) truncated = truncated_node->UnsignedText(); - XMLElement *difficult_node = object->FirstChildElement("difficult"); - if (difficult_node != nullptr) difficult = difficult_node->UnsignedText(); - - XMLElement *bbox_node = object->FirstChildElement("bndbox"); - if (bbox_node != nullptr) { - XMLElement *xmin_node = bbox_node->FirstChildElement("xmin"); - if (xmin_node != nullptr) xmin = xmin_node->UnsignedText(); - XMLElement *ymin_node = bbox_node->FirstChildElement("ymin"); - if (ymin_node != nullptr) ymin = ymin_node->UnsignedText(); - XMLElement *xmax_node = bbox_node->FirstChildElement("xmax"); - if (xmax_node != nullptr) xmax = xmax_node->UnsignedText(); - XMLElement *ymax_node = bbox_node->FirstChildElement("ymax"); - if (ymax_node != nullptr) ymax = ymax_node->UnsignedText(); - } else { - RETURN_STATUS_UNEXPECTED("bndbox dismatch in " + path); - } - if (label_name != "" && (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) && xmin > 0 && - ymin > 0 && xmax > xmin && ymax > ymin) { - std::vector bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult}; - bbox.emplace_back(std::make_pair(label_name, bbox_list)); - label_index_[label_name] = 0; - } - object = object->NextSiblingElement("object"); - } - if (bbox.size() > 0) label_map_[path] = bbox; - return Status::OK(); -} - -Status VOCOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); -} - -Status VOCOp::LaunchThreadsAndInitOp() { - if (tree_ == nullptr) { - RETURN_STATUS_UNEXPECTED("tree_ not set"); - } - RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&VOCOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(this->ParseImageIds()); - if (task_type_ == TaskType::Detection) { - RETURN_IF_NOT_OK(this->ParseAnnotationIds()); - } - RETURN_IF_NOT_OK(this->InitSampler()); - return Status::OK(); -} - -Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); - if (decode_ == true) { - Status rc = Decode(*tensor, tensor); - if (rc.IsError()) { - RETURN_STATUS_UNEXPECTED("fail to decode file: " + path); - } - } - return Status::OK(); -} - -Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, - std::shared_ptr *tensor) { - Bbox bbox_info = label_map_[path]; - std::vector bbox_row; - dsize_t bbox_column_num = 0, bbox_num = 0; - for (auto box : bbox_info) { - if (label_index_.find(box.first) != label_index_.end()) { - std::vector bbox; - if (class_index_.find(box.first) != class_index_.end()) { - bbox.emplace_back(class_index_[box.first]); - } else { - bbox.emplace_back(label_index_[box.first]); - } - bbox.insert(bbox.end(), box.second.begin(), box.second.end()); - bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); - if (bbox_column_num == 0) { - bbox_column_num = static_cast(bbox.size()); - } - bbox_num++; - } - } - - std::vector bbox_dim = {bbox_num, bbox_column_num}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, col.tensorImpl(), TensorShape(bbox_dim), col.type(), - reinterpret_cast(&bbox_row[0]))); - return Status::OK(); -} - -Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t *count) { - if (task_type == "Detection") { - std::map input_class_indexing; - for (auto p : dict) { - (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), - py::reinterpret_borrow(p.second))); - } - - std::shared_ptr op; - RETURN_IF_NOT_OK( - Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); - RETURN_IF_NOT_OK(op->ParseImageIds()); - RETURN_IF_NOT_OK(op->ParseAnnotationIds()); - *count = static_cast(op->image_ids_.size()); - } else if (task_type == "Segmentation") { - std::shared_ptr op; - RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op)); - RETURN_IF_NOT_OK(op->ParseImageIds()); - *count = static_cast(op->image_ids_.size()); - } - - return Status::OK(); -} - -Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, std::map *output_class_indexing) { - std::map input_class_indexing; - for (auto p : dict) { - (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), - py::reinterpret_borrow(p.second))); - } - - if (!input_class_indexing.empty()) { - *output_class_indexing = input_class_indexing; - } else { - std::shared_ptr op; - RETURN_IF_NOT_OK( - Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); - RETURN_IF_NOT_OK(op->ParseImageIds()); - RETURN_IF_NOT_OK(op->ParseAnnotationIds()); - for (const auto label : op->label_index_) { - (*output_class_indexing).insert(std::make_pair(label.first, label.second)); - } - } - - return Status::OK(); -} - -Status VOCOp::ComputeColMap() { - // Set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h deleted file mode 100644 index 89875341ca..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h +++ /dev/null @@ -1,288 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// Forward declares -template -class Queue; - -using Bbox = std::vector>>; - -class VOCOp : public ParallelOp, public RandomAccessOp { - public: - enum class TaskType { Segmentation = 0, Detection = 1 }; - - class Builder { - public: - // Constructor for Builder class of ImageFolderOp - // @param uint32_t numWrks - number of parallel workers - // @param dir - directory folder got ImageNetFolder - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method. - // @param const std::string & build_dir - // @return Builder setter method returns reference to the builder. - Builder &SetDir(const std::string &build_dir) { - builder_dir_ = build_dir; - return *this; - } - - // Setter method. - // @param const std::map &map - a class name to label map - // @return Builder setter method returns reference to the builder. - Builder &SetClassIndex(const std::map &map) { - builder_labels_to_read_ = map; - return *this; - } - - // Setter method. - // @param const std::string & task_type - // @return Builder setter method returns reference to the builder. - Builder &SetTask(const std::string &task_type) { - if (task_type == "Segmentation") { - builder_task_type_ = TaskType::Segmentation; - } else if (task_type == "Detection") { - builder_task_type_ = TaskType::Detection; - } - return *this; - } - - // Setter method. - // @param const std::string & task_mode - // @return Builder setter method returns reference to the builder. - Builder &SetMode(const std::string &task_mode) { - builder_task_mode_ = task_mode; - return *this; - } - - // Setter method. - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @param int32_t op_connector_size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method. - // @param bool do_decode - // @return Builder setter method returns reference to the builder. - Builder &SetDecode(bool do_decode) { - builder_decode_ = do_decode; - return *this; - } - - // Check validity of input args - // @return = The error code return - Status SanityCheck(); - - // The builder "Build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - bool builder_decode_; - std::string builder_dir_; - TaskType builder_task_type_; - std::string builder_task_mode_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - int32_t builder_rows_per_buffer_; - std::shared_ptr builder_sampler_; - std::unique_ptr builder_schema_; - std::map builder_labels_to_read_; - }; - - // Constructor - // @param TaskType task_type - task type of VOC - // @param std::string task_mode - task mode of VOC - // @param std::string folder_path - dir directory of VOC - // @param std::map class_index - input class-to-index of annotation - // @param int32_t num_workers - number of workers reading images in parallel - // @param int32_t rows_per_buffer - number of images (rows) in each buffer - // @param int32_t queue_size - connector queue size - // @param bool decode - whether to decode images - // @param std::unique_ptr data_schema - the schema of the VOC dataset - // @param std::shared_ptr sampler - sampler tells VOCOp what to read - VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, - const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, - int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler); - - // Destructor - ~VOCOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t workerId - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of VOCOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // @param const std::string &dir - VOC dir path - // @param const std::string &task_type - task type of reading voc job - // @param const std::string &task_mode - task mode of reading voc job - // @param const py::dict &dict - input dict of class index - // @param int64_t *count - output rows number of VOCDataset - static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t *count); - - // @param const std::string &dir - VOC dir path - // @param const std::string &task_type - task type of reading voc job - // @param const std::string &task_mode - task mode of reading voc job - // @param const py::dict &dict - input dict of class index - // @param int64_t numSamples - samples number of VOCDataset - // @param std::map *output_class_indexing - output class index of VOCDataset - static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, std::map *output_class_indexing); - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "VOCOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to image id - // @param row_id_type row_id - id for this tensor row - // @param std::string image_id - image id - // @param TensorRow row - image & target read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); - - // @param const std::string &path - path to the image file - // @param const ColDescriptor &col - contains tensor implementation and datatype - // @param std::shared_ptr tensor - return - // @return Status - The error code return - Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); - - // @param const std::string &path - path to the image file - // @param const ColDescriptor &col - contains tensor implementation and datatype - // @param std::shared_ptr tensor - return - // @return Status - The error code return - Status ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Read image list from ImageSets - // @return Status - The error code return - Status ParseImageIds(); - - // Read annotation from Annotation folder - // @return Status - The error code return - Status ParseAnnotationIds(); - - // @param const std::string &path - path to annotation xml - // @return Status - The error code return - Status ParseAnnotationBbox(const std::string &path); - - // @param const std::shared_ptr &sample_ids - sample ids of tensor - // @param std::vector *keys - image id - // @return Status - The error code return - Status TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); - - // Called first when function is called - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - - // Reset dataset state - // @return Status - The error code return - Status Reset() override; - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - bool decode_; - int64_t row_cnt_; - int64_t buf_cnt_; - std::string folder_path_; - TaskType task_type_; - std::string task_mode_; - int32_t rows_per_buffer_; - std::unique_ptr data_schema_; - - WaitPost wp_; - std::vector image_ids_; - QueueList> io_block_queues_; - std::map class_index_; - std::map label_index_; - std::map label_map_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc deleted file mode 100644 index 8bc449cdc9..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/take_op.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status TakeOp::Builder::SanityCheck() const { - if (build_max_takes_ <= 0) { - std::string err_msg("Take count must be greater than 0."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status TakeOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_takes_, builder_op_connector_size_); - return Status::OK(); -} - -// Constructor of the TakeOp. -TakeOp::TakeOp(int32_t count, int32_t op_connector_size) - : PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {} - -// A print method typically used for debugging -void TakeOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [takes: " << max_takes_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nTake count: " << take_count_ << "\nMax takes: " << max_takes_ << "\n\n"; - } -} - -// Main entry point for Take -Status TakeOp::operator()() { - TaskManager::FindMe()->Post(); - std::unique_ptr buf; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - - while (buf->eof() == false) { - if (take_count_ == max_takes_) { - // Do drain Operation - while (!buf->eoe() && !buf->eof()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - } - } - - // Loop until non EOE is received - if (buf->eoe()) { - take_count_ = 0; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - continue; - } - - // Get buffer and push back when take_count is still small - if (take_count_ < max_takes_) { - std::unique_ptr p_buffer; - RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer)); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer))); - } - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - } - - take_count_ = 0; - MS_LOG(DEBUG) << "Meet the end and push-back eof buffer."; - auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - return Status::OK(); -} - -// Function FillBuffer mainly prepare the buffer for returning -Status TakeOp::FillBuffer(std::unique_ptr *buffer, std::unique_ptr *data_buffer) { - int32_t buffer_size = (*buffer)->NumRows(); - if (take_count_ + buffer_size < max_takes_) { - *data_buffer = std::move(*buffer); - take_count_ = take_count_ + buffer_size; - } else { - MS_LOG(DEBUG) << "In last buffer: Push one buffer."; - std::unique_ptr new_tensor_table = std::make_unique(); - while (take_count_ < max_takes_) { - TensorRow new_row; - RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row)); - take_count_++; - new_tensor_table->push_back(new_row); - } - (*buffer)->set_tensor_table(std::move(new_tensor_table)); - *data_buffer = std::move(*buffer); - } - return Status::OK(); -} - -Status TakeOp::PrepareNodePostAction() { - RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction()); - tree_->AddToEOEOpStack(shared_from_this()); - return Status::OK(); -} - -// Visitor accept method for NodePass -Status TakeOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h deleted file mode 100644 index 9619a4409d..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ - -#include -#include -#include -#include "dataset/engine/datasetops/pipeline_op.h" - -namespace mindspore { -namespace dataset { -class TakeOp : public PipelineOp { - public: - // The nested builder class inside of the TakeOp is used to help manage all of the arguments - // for constructing it. This take op is very simple though, so this builder is really just - // provided for a consistent look and feel for creators of Dataset operators overall. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @param count - The number of takes to do - // @return This is a constructor. - explicit Builder(int32_t count); - - // Default destructor - ~Builder() = default; - - // The builder "build" method creates the final object. - // @return shared_ptr to the new TakeOp object - Status Build(std::shared_ptr *); - - private: - int32_t build_max_takes_; - int32_t builder_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor of the TakeOp. - // @note The builder class should be used to call it - // @param count - The number of takes to do - explicit TakeOp(int32_t count, int32_t op_connector_size); - - // Destructor - ~TakeOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param ro - reference to the TakeOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) { - ro.Print(out, false); - return out; - } - - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePostAction() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "TakeOp"; } - - private: - int32_t max_takes_; // The number of takes that the user requested - int32_t take_count_; // A counter for the current number of executed takes - - Status FillBuffer(std::unique_ptr *buffer, std::unique_ptr *data_buffer); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc deleted file mode 100644 index 70bce16a89..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc +++ /dev/null @@ -1,268 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/zip_op.h" -#include -#include -#include "dataset/core/constants.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -ZipOp::Builder::Builder() { - // Some arguments to the ZipOp constructor have a default argument that is taken - // from the client config. - // The user may choose to change these values for the construction of the ZipOp by - // using the various builder set methods. - - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status ZipOp::Builder::SanityCheck() const { return Status::OK(); } - -Status ZipOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_); - return Status::OK(); -} - -// Construct ZipOp here, local variables initialized in operator due to tree construction restrictions -ZipOp::ZipOp(int32_t rows_per_buffer, int32_t op_connector_size) - : PipelineOp(op_connector_size), - children_num_(0), - rows_per_buffer_(rows_per_buffer), - buffer_id_(0), - draining_(false), - eof_(false) {} - -// destructor -ZipOp::~ZipOp() {} - -// Entry point for Zip, called by launch() -Status ZipOp::operator()() { - // The children_num_ parameter needs to be put here - children_num_ = child_.size(); - // Synchronize with TaskManager once the thread is created. - TaskManager::FindMe()->Post(); - - // initialize the iterators - for (int32_t i = 0; i < children_num_; ++i) { - // magic number 0 since Zip is not a parallel Op - child_iterators_.push_back(std::make_unique(this, 0, i)); - } - - // Loop until eof is true - while (!eof_) { - // Create tensor table and prepare it by fetching and packing the first zipped row into it. - std::unique_ptr curr_table = std::make_unique(); - RETURN_IF_NOT_OK(prepare(curr_table.get())); - - // If an eof got picked up during the above prepare, then we're done - if (eof_) { - break; - } - while (!draining_) { - // 1. If a previous loop iteration sent the current table out, then create a new one. - if (curr_table == nullptr) { - curr_table = std::make_unique(); - } - - // 2 fill the table. Note: draining mode might get turned on if any of the child inputs were done - RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); - - // 3 create and update buffer and send it to the out connector - if (!curr_table->empty()) { - std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); - curr_buffer->set_tensor_table(std::move(curr_table)); - MS_LOG(DEBUG) << "Zip operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " - << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << "."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); - buffer_id_++; - } - } - - // 4 handle drain state. - if (draining_) { - MS_LOG(DEBUG) << "Zip operator is now draining child inputs."; - RETURN_IF_NOT_OK(drainPipeline()); - // Now that we have drained child inputs, send the eoe up. - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - } - } - - // 5 handle eof - // propagate eof here. - MS_LOG(DEBUG) << "Zip operator got EOF, propagating."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - return Status::OK(); -} - -// Handles preprocessing of the main loop, used when starting new epoch -Status ZipOp::prepare(TensorQTable *const table) { - MS_LOG(DEBUG) << "Zip operator prepares for new epoch."; - draining_ = false; - buffer_id_ = 0; - if (table == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase requires a tensor table."); - } - // fill initial row - TensorRow new_row; - RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); - - // If the first row fetching resulted in eof, then we are done. - if (eof_) { - return Status::OK(); - } - if (new_row.empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!"); - } - - // Pack this first row into our tensor table - table->push_back(std::move(new_row)); - - return Status::OK(); -} - -// fillBuffer always expects a new table to fill -Status ZipOp::fillBuffer(TensorQTable *const table) { - if (table == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp fillBuffer null table pointer."); - } - TensorRow new_row; - while (table->size() < static_cast(rows_per_buffer_)) { - RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); - // Early exit the loop if we got empty row from any of our child iterations - if (new_row.empty()) { - return Status::OK(); - } - // else we got a row so pack it into the tensor table. - table->push_back(std::move(new_row)); - } - return Status::OK(); -} - -// fetches next zip buffer row (merged row) -Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) { - // iterate over all iterators and generate a row - for (int32_t i = 0; i < children_num_; ++i) { - TensorRow new_row = {}; - RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row)); - // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row - if (new_row.empty()) { - // If we did not get a row from any of the children, then it's the end of an epoch and we can move - // to drain state. - MS_LOG(DEBUG) << "Zip operator child iterator produced empty row."; - draining_ = true; - new_zip_row->clear(); - // If we picked up an eof here, then we are completely done. - if ((child_iterators_[i])->eof_handled()) { - MS_LOG(DEBUG) << "Zip operator iterator got EOF."; - eof_ = true; - } - return Status::OK(); - } else { - MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << "."; - // if row isn't empty then we can append the fetched row with new_zip_row - new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end()); - } - } - MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << "."; - return Status::OK(); -} - -// drain end of epoch messages from iterator for this epoch -Status ZipOp::drainPipeline() { - // we don't need to drain if we reached eof - if (eof_) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "ZipOp draining should not be done if already at eof!"); - } - for (int32_t con = 0; con < children_num_; ++con) { - MS_LOG(DEBUG) << "Zip operator draining child at " << con << "."; - RETURN_IF_NOT_OK(child_iterators_[con]->Drain()); - } - // at this point all connectors don't contain end of epoch messages. next iteration should be clean - return Status::OK(); -} - -// A function that prints info about the Operator -void ZipOp::Print(std::ostream &out, // In: The output stream to print to - bool show_all) const { // In: T/F if it should print everything - // Always show the id and name as first line regardless if this is summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nDatasets: " << children_num_ << "\n\n"; - } -} - -// overwrite function and handle eof -Status ZipOp::EofReceived(int32_t) { - MS_LOG(DEBUG) << "Zip operator EOF received, do nothing now."; - return Status::OK(); -} - -// overwrite function and handle eoe -Status ZipOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -// Visitor accept method for NodePass -Status ZipOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status ZipOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - column_name_id_map_ = {}; - for (int32_t i = 0; i < child_.size(); ++i) { - // Initializing col_name_id_map from the child. - const std::unordered_map col_name_id_map = child_[i]->column_name_id_map(); - int32_t colsCurrent = column_name_id_map_.size(); - // the update code below shouldn't do anything bad if the column name already exists. - for (const auto &pair : col_name_id_map) { - std::string name = pair.first; - int32_t old_id = pair.second; - // check if name already exists in column name descriptor - if (column_name_id_map_.count(name) == 1) { - RETURN_STATUS_UNEXPECTED("key already exists when zipping datasets"); - } - column_name_id_map_[name] = old_id + colsCurrent; - } - } - MS_LOG(DEBUG) << "Setting column map:\n" << this->ColumnNameMapAsString(); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h deleted file mode 100644 index fad3c22eaa..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// forward declare -class DataBuffer; - -class ZipOp : public PipelineOp { - public: - // The nested builder class inside of the ZipOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - // NOTE: the rows per buffer with initial value 0 means to default to the number of rows from the first child - - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // The builder "build" method creates the ZipOp dataset Operator. - // @return shared_ptr to the new ZipOp object - Status Build(std::shared_ptr *); - - private: - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor for ZipOp - // @param rows_per_buffer - number of rows in output buffer - // @param op_connector_size - connector size - ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); - - // Destructor - ~ZipOp(); - - Status EofReceived(int32_t) override; - - Status EoeReceived(int32_t) override; - - // Print function for Zip - // @param out - output stream to print to - // @param show_all - if it should print everything - void Print(std::ostream &out, bool show_all) const override; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const ZipOp &zo) { - zo.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ZipOp"; } - - private: - // Handles preprocessing of the main loop, used when starting new epoch - Status prepare(TensorQTable *const table); - - // This function calls takes a table repeatedly adds rows to it. - // @param table a table of tensors to be moved into a buffer - Status fillBuffer(TensorQTable *const table); - - // Special handle case where an empty row has been received from child iterator - // @note - we need to drain eoe signals from all children connectors. - // @details - when this function is called, then we encountered eoe at child iterator - // we have to drain rows from other child iterators until we hit eoe from all other child iterators - Status drainPipeline(); - - // Merges 1 row from each childIterator together - // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty - // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true - // @details merge rows from iterator together. This is the main functionality for ZipOp - // this function takes one row and fills it with tensors from rows fetched - // from childIterators. - // @example: - // Zips multiple rows at a time, the output is store in newZipRow - // 1 a T - // \ | / - // 1, a, T - Status getNextTensorRow(TensorRow *const new_zip_row); - - // Computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t children_num_; - int32_t rows_per_buffer_; - int32_t buffer_id_; - bool draining_; - bool eof_; - std::vector> child_iterators_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/db_connector.h b/mindspore/ccsrc/dataset/engine/db_connector.h deleted file mode 100644 index 54909f51ba..0000000000 --- a/mindspore/ccsrc/dataset/engine/db_connector.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DB_CONNECTOR_H_ -#define DATASET_ENGINE_DB_CONNECTOR_H_ - -#include -#include -#include "dataset/engine/connector.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { -// DbConnector is a derived class from Connector with added logic to handle EOE and EOF. -// The Connector class itself is responsible to ensure deterministic order on every run. -class DbConnector : public Connector> { - public: - // Constructor of DbConnector - // @note DbConnector will create internal N number of blocking queues, where N = nProducers. - // See Connector.h for more details. - // @param n_producers The number of threads producing data into this DbConnector. - // @param n_consumers The number of thread consuming data from this DbConnector. - // @param queue_capacity The number of element (DataBuffer) for each internal queue. - DbConnector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity) - : Connector>(n_producers, n_consumers, queue_capacity), end_of_file_(false) {} - - // Destructor of DbConnector - ~DbConnector() = default; - - // Add a unique_ptr into the DbConnector. - // @note The caller of this add method should use std::move to pass the ownership to DbConnector. - // @param worker_id The id of a worker thread calling this method. - // @param el A rvalue reference to an element to be passed/added/pushed. - Status Add(int32_t worker_id, std::unique_ptr &&el) noexcept { - return (Connector>::Push(worker_id, std::move(el))); - } - - // Get a unique_ptr from the DbConnector. - // @note After the first EOF Buffer is encountered, subsequent pop()s will return EOF Buffer. - // This will provide/propagate the EOF to all consumer threads of this Connector. - // Thus, When the num_consumers < num_producers, there will be extra EOF messages in some of the internal queues - // and reset() must be called before reusing DbConnector. - // @param worker_id The id of a worker thread calling this method. - // @param result The address of a unique_ptr where the popped element will be placed. - // @param retry_if_eoe A flag to allow the same thread invoke pop() again if the current pop returns eoe buffer. - Status PopWithRetry(int32_t worker_id, std::unique_ptr *result, bool retry_if_eoe = false) noexcept { - if (result == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "[ERROR] nullptr detected when getting data from db connector"); - } else { - std::unique_lock lk(m_); - RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return (expect_consumer_ == worker_id) || end_of_file_; })); - // Once an EOF message is encountered this flag will be set and we can return early. - if (end_of_file_) { - *result = std::make_unique(0, DataBuffer::kDeBFlagEOF); - } else { - RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); - if (*result == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "[ERROR] nullptr detected when getting data from db connector"); - } - // Setting the internal flag once the first EOF is encountered. - if ((*result)->eof()) { - end_of_file_ = true; - } - pop_from_ = (pop_from_ + 1) % num_producers_; - } - // Do not increment expect_consumer_ when result is eoe and retry_if_eoe is set. - if (!((*result)->eoe() && retry_if_eoe)) { - expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; - } - } - out_buffers_count_++; - cv_.NotifyAll(); - return Status::OK(); - } - - private: - // A flag to indicate the end of stream has been encountered. - bool end_of_file_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DB_CONNECTOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc deleted file mode 100644 index 385722e257..0000000000 --- a/mindspore/ccsrc/dataset/engine/execution_tree.cc +++ /dev/null @@ -1,310 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/execution_tree.h" -#include -#include -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/util/task_manager.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/engine/opt/pre/removal_pass.h" -#include "dataset/engine/perf/profiling.h" -#include "dataset/engine/perf/monitor.h" - -namespace mindspore { -namespace dataset { -// Constructor -ExecutionTree::ExecutionTree() : id_count_(0) { - tg_ = std::make_unique(); - tree_state_ = kDeTStateInit; - prepare_flags_ = kDePrepNone; - perf_monitor_ = std::make_unique(this); - profiling_manager_ = std::make_unique(this); -} - -// Destructor -ExecutionTree::~ExecutionTree() { (void)tg_->ServiceStop(); } - -// Associates a DatasetOp with this tree. This assigns a valid node id to the operator and -// provides it with a link to the tree. A node cannot form any relationships (parent/child) with -// other nodes unless they are associated with the same tree. -Status ExecutionTree::AssociateNode(const std::shared_ptr &op) { - // If we are already a part of the tree, no-op - if (op->tree_ == this) { - return Status::OK(); - } - if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { - std::string err_msg = - "Invalid tree state for adding a node. Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected states: " + std::to_string(static_cast(kDeTStateInit)) + " or " + - std::to_string(static_cast(kDeTStateBuilding)); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Enter the building state if we were not already there - tree_state_ = kDeTStateBuilding; - - // Assign an id to the operator - op->set_id(id_count_); - id_count_++; - - // Assign our tree into the op so that each op has a link back to the tree - op->set_tree(this); - return Status::OK(); -} - -// Sets the root node of the tree -Status ExecutionTree::AssignRoot(const std::shared_ptr &op) { - // Tree must be in building state before we can assign root to it - if (tree_state_ != kDeTStateBuilding) { - std::string err_msg = - "Invalid tree state for assigning a root node. Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected state: " + std::to_string(static_cast(kDeTStateBuilding)); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // If they didn't already call AssociateNode for this node before calling AssignRoot, - // then do so now. - if (op->operator_id_ == DatasetOp::kInvalidOperatorId) { - RETURN_IF_NOT_OK(this->AssociateNode(op)); - } - - // Then add it as the root. - root_ = op; - - return Status::OK(); -} - -// A print method typically used for debugging -void ExecutionTree::Print(std::ostream &out, const std::shared_ptr &op) const { - out << "Execution tree summary:\n" - << "-----------------------\n"; - this->PrintNode(out, op == nullptr ? root_ : op, "", true, false); - out << "\nExecution tree operator details:\n" - << "--------------------------------\n"; - this->PrintNode(out, op == nullptr ? root_ : op, "", true, true); -} - -// A helper functions for doing the recursive printing -void ExecutionTree::PrintNode(std::ostream &out, const std::shared_ptr &dataset_op, std::string indent, - bool last, bool detailed) const { - // Decide which printer to use based on detailed arg. - if (!detailed) { - out << indent << "+- " << *dataset_op; - indent += (last ? " " : "| "); - } else { - dataset_op->Print(out, detailed); - } - - // Descend to children - for (int32_t i = 0; i < dataset_op->child_.size(); ++i) { - this->PrintNode(out, dataset_op->child_[i], indent, (i == (dataset_op->child_.size() - 1)), detailed); - } -} - -// Start the execution of the tree -Status ExecutionTree::Launch() { - // Tree must be built and prepared before it can be launched! - if (tree_state_ != kDeTStateReady) { - std::string err_msg = - "Invalid tree state for launching tree. Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected state: " + std::to_string(static_cast(kDeTStateReady)); - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::ostringstream ss; - ss << *this; - - // Profiling infrastructures need to be initialized before Op launching - if (profiling_manager_->IsProfilingEnable()) { - // Setup profiling manager - RETURN_IF_NOT_OK(profiling_manager_->Initialize()); - // Launch Monitor Thread - RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Monitor Thread launched", std::ref(*perf_monitor_))); - } - - MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str(); - for (auto itr = this->begin(); itr != this->end(); ++itr) { - // An inlined operator is one that has an output connector size of 0, and it does not - // require a thread to execute. Instead, the work of this operator is executed inlined - // from the tree node directly above it (or in the case of a root node, it runs from within - // the launching tree/user thread. Do not exec any thread for an inlined op. - itr->state_ = DatasetOp::OpState::kDeOpRunning; - if (!itr->inlined()) { - RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr))); - // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp - } - } - - tree_state_ = kDeTStateExecuting; - - return Status::OK(); -} - -// A function that traverse the tree in postorder then save the results in nodes -void ExecutionTree::Iterator::PostOrderTraverse(const std::shared_ptr &node) { - if (node == nullptr) { - return; - } - for (int32_t i = 0; i < node->child_.size(); ++i) { - PostOrderTraverse(node->child_[i]); - } - nodes_.push_back(node); -} - -ExecutionTree::Iterator::Iterator(const std::shared_ptr &root) : ind_(0) { - // post-order traverse the tree, if root is null, it return - PostOrderTraverse(root); - nodes_.emplace_back(nullptr); -} - -// Given the number of workers, launches the worker entry function for each. Essentially a -// wrapper for the TaskGroup handling that is stored inside the execution tree. -Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function func) { - // Launch the workers - for (int32_t i = 0; i < num_workers; ++i) { - RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i))); - } - return Status::OK(); -} - -// The driver of the prepare phase of the execution tree. -// Prepare phase consists of three sub phases -// -// 1. PrepareTreePreAction() -// Compulsory transformation/action pre optimization. -// For example, CacheOp Insertion -// -// 2. Optimize() -// Optimization transformation/action, optional -// For example, MapOp Fusion -// -// 3. PrepareTreePostAction() -// Compulsory transformation/action post optimization. -// For example, repeatOp inlining -// -// @return Status - The error code return -Status ExecutionTree::Prepare() { - // Pre optimization compulsory transformation - RETURN_IF_NOT_OK(this->PrepareTreePreAction()); - - // Optimization transformation - RETURN_IF_NOT_OK(this->Optimize()); - - // Post optimization compulsory transformation - RETURN_IF_NOT_OK(this->PrepareTreePostAction()); - - // Existing transformation implementation, will be removed later - RETURN_IF_NOT_OK(this->PrepareDeprecated()); - return Status::OK(); -} - -Status ExecutionTree::PrepareTreePreAction() { - bool modified = false; - std::vector> pre_actions; - // Construct pre actions - MS_LOG(INFO) << "Running pre pass"; - pre_actions.push_back(std::make_unique(RemovalPass())); - // Apply pre action passes - for (auto &pass : pre_actions) { - RETURN_IF_NOT_OK(pass->Run(this, &modified)); - } - return Status::OK(); -} - -Status ExecutionTree::PrepareTreePostAction() { - // The tree is ready to be prepared. - tree_state_ = kDeTStatePrepare; - return Status::OK(); -} - -Status ExecutionTree::Optimize() { - // auto pp = new PrinterPass(); - // bool modified = false; - // pp->Run(this, &modified); - return Status::OK(); -} - -// The driver of the prepare phase of the execution tree. The prepare phase will recursively -// walk the tree to perform modifications to the tree or specific nodes within the tree to get -// it ready for execution. -// -// This driver is deprecated. -Status ExecutionTree::PrepareDeprecated() { - // Tree must be in pending prepare state before we can assign root to it - if (tree_state_ != kDeTStatePrepare) { - std::string err_msg = - "Invalid tree state for preparing the tree. Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected state: " + std::to_string(static_cast(kDeTStatePrepare)); - RETURN_STATUS_UNEXPECTED(err_msg); - } - // Start the recursive prepare - RETURN_IF_NOT_OK(this->PrepareNode(root_)); - tree_state_ = kDeTStateReady; - return Status::OK(); -} - -// Recursive function used during prepare phase to visit a node and drive any pre- and post- -// node actions during a tree walk. -Status ExecutionTree::PrepareNode(const std::shared_ptr &dataset_op) { - // execute PreAction - RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction()); - - // Before going down into children, make any prepare flags updates based on this operator. - uint32_t op_prep_flags = dataset_op->PrepareFlags(); - BitSet(&prepare_flags_, op_prep_flags); - - // Now, descend to children - for (const auto &i : dataset_op->child_) { - RETURN_IF_NOT_OK(this->PrepareNode(i)); - } - - // No more children, now we execute any prepare actions before going back up the - // the tree on recursive function - RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction()); - - // Then clear the flags from this op now that we have prepared it. - BitClear(&prepare_flags_, op_prep_flags); - - return Status::OK(); -} - -// Adds an operator to the eoe operator stack during prepare phase. -void ExecutionTree::AddToEOEOpStack(std::shared_ptr dataset_op) { eoe_stack_.push(dataset_op); } - -// Pops an operator from the eoe operator stack during prepare phase. -std::shared_ptr ExecutionTree::PopFromEOEOpStack() { - std::shared_ptr top_op = nullptr; - if (!eoe_stack_.empty()) { - top_op = eoe_stack_.top(); - eoe_stack_.pop(); - } - return top_op; -} - -// Adds a sampler to the sampler stack during prepare phase. -void ExecutionTree::AddToSamplerStack(std::shared_ptr sampler) { sampler_stack_.push(sampler); } - -// Pops an operator from the sampler stack during prepare phase. -std::shared_ptr ExecutionTree::PopFromSamplerStack() { - std::shared_ptr top_sampler = nullptr; - if (!sampler_stack_.empty()) { - top_sampler = sampler_stack_.top(); - sampler_stack_.pop(); - } - return top_sampler; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.h b/mindspore/ccsrc/dataset/engine/execution_tree.h deleted file mode 100644 index 5ebfa539ad..0000000000 --- a/mindspore/ccsrc/dataset/engine/execution_tree.h +++ /dev/null @@ -1,257 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_EXECUTION_TREE_H_ -#define DATASET_ENGINE_EXECUTION_TREE_H_ - -#include -#include -#include -#include -#include -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/util/status.h" -#include "mindspore/ccsrc/dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -// Forward declares -class TaskGroup; -class DatasetOp; -class Monitor; - -class ExecutionTree { - public: - // Prepare flags used during tree prepare phase - enum PrepareFlags { - kDePrepNone = 0, - kDePrepRepeat = 1, // Processing a repeat operation - kDePrepCache = 2 // Processing a cache operation - }; - - // State flags for the lifecycle of the tree - enum TreeState { - kDeTStateInit = 0, // The freshly initialized state after construction - kDeTStateBuilding, // The tree is being built, nodes are being added - kDeTStatePrepare, // The tree has been assigned a root node and is pending prepare - kDeTStateReady, // The tree has been prepared and is ready to be launched - kDeTStateExecuting, // The tree has been launched and is executing - kDeTStateFinished // The tree has been drained, dataset iterator received EOF - }; - - class Iterator { - public: - // Constructor - // @param root The root node to start iterating from - explicit Iterator(const std::shared_ptr &root = nullptr); - - // Destructor - ~Iterator() {} - - Iterator &operator++() { - ++ind_; - return *this; - } // prefix ++ overload - Iterator operator++(int) { - Iterator it = *this; - it.ind_ = ind_; - ind_++; - return it; - } // post-fix ++ overload - Iterator &operator--() { - --ind_; - return *this; - } // prefix -- overload - Iterator operator--(int) { - Iterator it = *this; - it.ind_ = ind_; - ind_--; - return it; - } // post-fix -- overload - DatasetOp &operator*() { return *nodes_[ind_]; } // dereference operator - std::shared_ptr operator->() { return nodes_[ind_]; } - - // getter function - // @return Shared pointer to the current operator - std::shared_ptr get() { return nodes_[ind_]; } - - bool operator!=(const Iterator &rhs) { return nodes_[ind_] != rhs.nodes_[rhs.ind_]; } - - int32_t NumNodes() { return nodes_.size(); } - - private: - int32_t ind_; // the cur node our Iterator points to - std::vector> nodes_; // store the nodes in post order - void PostOrderTraverse(const std::shared_ptr &); - }; - - // Constructor - ExecutionTree(); - - // Destructor - ~ExecutionTree(); - - // Associates a DatasetOp with this tree. This assigns a valid node id to the operator and - // provides it with a link to the tree. A node cannot form any relationships (parent/child) with - // other nodes unless they are associated with the same tree. - // @param op - The operator to associate - // @return Status - The error code return - Status AssociateNode(const std::shared_ptr &op); - - // Sets the root node of the tree - // @param op - The operator to assign as root - // @return Status - The error code return - Status AssignRoot(const std::shared_ptr &op); - - // Start the execution of the tree - // @return Status - The error code return - Status Launch(); - - /// A print method typically used for debugging - /// \param out - The output stream to write output to - void Print(std::ostream &out, const std::shared_ptr &op = nullptr) const; - - // Returns an iterator positioned at the start - // @return Iterator - The iterator - ExecutionTree::Iterator begin(const std::shared_ptr &root = nullptr) const { - return Iterator(root == nullptr ? root_ : root); - } - - // Returns an iterator positioned at the end - // @return Iterator - The iterator - ExecutionTree::Iterator end() const { return Iterator(nullptr); } - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param exe_tree - reference to the execution tree to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, ExecutionTree &exe_tree) { - exe_tree.Print(out); - return out; - } - - // Given the number of workers, launches the worker entry function for each. Essentially a - // wrapper for the TaskGroup handling that is stored inside the execution tree. - // @param num_workers - The number of workers to launch - // @param func - The function entry point that workers will execute - // @return Status - The error code return - Status LaunchWorkers(int32_t num_workers, std::function func); - - // Getter method - // @return shared_ptr to the root operator - std::shared_ptr root() const { return root_; } - - // Getter method - // @return the prepare flags - uint32_t PrepareFlags() const { return prepare_flags_; } - - // The driver of the prepare phase of the execution tree. - // Prepare phase consists of three sub phases - // - // 1. PrepareTreePreAction() - // Compulsory transformation/action pre optimization. - // For example, CacheOp Insertion - // - // 2. Optimize() - // Optimization transformation/action, optional - // For example, MapOp Fusion - // - // 3. PrepareTreePostAction() - // Compulsory transformation/action post optimization. - // For example, repeatOp inlining - // - // @return Status - The error code return - Status Prepare(); - - // Compulsory transformation/action pre optimization. - // @return Status - The error code return - Status PrepareTreePreAction(); - - // Compulsory transformation/action post optimization. - // @return Status - The error code return - Status PrepareTreePostAction(); - - // Optimization transformation/action, optional. - // @return Status - The error code return - Status Optimize(); - - // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively - // walk the tree to perform modifications to the tree or specific nodes within the tree to get - // it ready for execution. - // @return Status - The error code return - Status PrepareDeprecated(); - - // Recursive function used during prepare phase to visit a node and drive any pre- and post- - // node actions during a tree walk. - // @param op - The dataset op to work on - // @return Status - The error code return - Status PrepareNode(const std::shared_ptr &dataset_op); - - /// Adds an operator to the eoe operator stack during prepare phase. - /// \param op - The dataset op to work add to eoe stack - /// \return Status - The error code return - void AddToEOEOpStack(std::shared_ptr dataset_op); - - /// Pops an operator from the eoe operator stack during prepare phase. - /// \return shared_ptr to the popped operator - std::shared_ptr PopFromEOEOpStack(); - - /// Adds a sampler to the sampler stack during prepare phase. - /// \param samplerop - The dataset op to work add to eoe stack - /// \return Status - The error code return - void AddToSamplerStack(std::shared_ptr sampler); - - /// Pops an operator from the sampler stack during prepare phase. - /// \return shared_ptr to the popped operator - std::shared_ptr PopFromSamplerStack(); - - // Return the pointer to the TaskGroup - // @return raw pointer to the TaskGroup - TaskGroup *AllTasks() const { return tg_.get(); } - - // Return if the ExecutionTree is finished (iterator receives EOF). - // @return Bool - true is ExecutionTree is finished - bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; } - - // Set the ExecutionTree to Finished state. - void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; } - - // Getter for profiling manager, no ownership - ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); } - - private: - // A helper functions for doing the recursive printing - // @param dataset_op - The dataset op to print - // @param indent - an indent string for aligning child levels in output - // @param last - an indicator if it's the last child or not - // @param detailed - should it display the detailed node output or the summary line - void PrintNode(std::ostream &out, const std::shared_ptr &dataset_op, std::string indent, bool last, - bool detailed) const; - - std::unique_ptr tg_; // Class for worker management - std::shared_ptr root_; // The root node of the tree - int32_t id_count_; // Counter for generating operator id's - uint32_t prepare_flags_; // Flags used during tree prepare - TreeState tree_state_; // Tracking the current tree state - std::unique_ptr perf_monitor_; // Performance Monitor - std::unique_ptr profiling_manager_; // Profiling manager - std::stack> eoe_stack_; // A stack used during prepare phase - std::stack> sampler_stack_; // A stack used during prepare phase -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_EXECUTION_TREE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/edge.h b/mindspore/ccsrc/dataset/engine/gnn/edge.h deleted file mode 100644 index 47314d97c2..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/edge.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_EDGE_H_ -#define DATASET_ENGINE_GNN_EDGE_H_ - -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/engine/gnn/feature.h" -#include "dataset/engine/gnn/node.h" - -namespace mindspore { -namespace dataset { -namespace gnn { -using EdgeType = int8_t; -using EdgeIdType = int32_t; - -class Edge { - public: - // Constructor - // @param EdgeIdType id - edge id - // @param EdgeType type - edge type - // @param std::shared_ptr src_node - source node - // @param std::shared_ptr dst_node - destination node - Edge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) - : id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {} - - virtual ~Edge() = default; - - // @return NodeIdType - Returned edge id - EdgeIdType id() const { return id_; } - - // @return NodeIdType - Returned edge type - EdgeType type() const { return type_; } - - // Get the feature of a edge - // @param FeatureType feature_type - type of feature - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; - - // Get nodes on the edge - // @param std::pair, std::shared_ptr> *out_node - Source and destination nodes returned - Status GetNode(std::pair, std::shared_ptr> *out_node) { - *out_node = std::make_pair(src_node_, dst_node_); - return Status::OK(); - } - - // Set node to edge - // @param const std::pair, std::shared_ptr> &in_node - - Status SetNode(const std::pair, std::shared_ptr> &in_node) { - src_node_ = in_node.first; - dst_node_ = in_node.second; - return Status::OK(); - } - - // Update feature of edge - // @param std::shared_ptr feature - - // @return Status - The error code return - virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; - - protected: - EdgeIdType id_; - EdgeType type_; - std::shared_ptr src_node_; - std::shared_ptr dst_node_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_EDGE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/feature.cc b/mindspore/ccsrc/dataset/engine/gnn/feature.cc deleted file mode 100644 index e457947821..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/feature.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/gnn/feature.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -Feature::Feature(FeatureType type_name, std::shared_ptr value) : type_name_(type_name), value_(value) {} - -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/feature.h b/mindspore/ccsrc/dataset/engine/gnn/feature.h deleted file mode 100644 index 7ce5967fbd..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/feature.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_FEATURE_H_ -#define DATASET_ENGINE_GNN_FEATURE_H_ - -#include - -#include "dataset/core/tensor.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -namespace gnn { -using FeatureType = int16_t; - -class Feature { - public: - // Constructor - // @param FeatureType type_name - feature type - // @param std::shared_ptr value - feature value - Feature(FeatureType type_name, std::shared_ptr value); - - ~Feature() = default; - - // Get feature value - // @return std::shared_ptr *out_value - feature value - const std::shared_ptr Value() const { return value_; } - - // @return NodeIdType - Returned feature type - FeatureType type() const { return type_name_; } - - private: - FeatureType type_name_; - std::shared_ptr value_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_FEATURE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc deleted file mode 100644 index a143bd4e38..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.cc +++ /dev/null @@ -1,614 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/gnn/graph.h" - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor_shape.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -Graph::Graph(std::string dataset_file, int32_t num_workers) - : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { - rnd_.seed(GetSeed()); - MS_LOG(INFO) << "num_workers:" << num_workers; -} - -Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr *out) { - auto itr = node_type_map_.find(node_type); - if (itr == node_type_map_.end()) { - std::string err_msg = "Invalid node type:" + std::to_string(node_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); - } - return Status::OK(); -} - -template -Status Graph::CreateTensorByVector(const std::vector> &data, DataType type, - std::shared_ptr *out) { - if (!type.IsCompatible()) { - RETURN_STATUS_UNEXPECTED("Data type not compatible"); - } - if (data.empty()) { - RETURN_STATUS_UNEXPECTED("Input data is empty"); - } - std::shared_ptr tensor; - size_t m = data.size(); - size_t n = data[0].size(); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &tensor, TensorImpl::kFlexible, TensorShape({static_cast(m), static_cast(n)}), type, nullptr)); - auto ptr = tensor->begin(); - for (const auto &id_m : data) { - CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); - for (const auto &id_n : id_m) { - *ptr = id_n; - ptr++; - } - } - tensor->Squeeze(); - *out = std::move(tensor); - return Status::OK(); -} - -template -Status Graph::ComplementVector(std::vector> *data, size_t max_size, T default_value) { - if (!data || data->empty()) { - RETURN_STATUS_UNEXPECTED("Input data is empty"); - } - for (std::vector &vec : *data) { - size_t size = vec.size(); - if (size > max_size) { - RETURN_STATUS_UNEXPECTED("The max_size parameter is abnormal"); - } else { - for (size_t i = 0; i < (max_size - size); ++i) { - vec.push_back(default_value); - } - } - } - return Status::OK(); -} - -Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { - auto itr = edge_type_map_.find(edge_type); - if (itr == edge_type_map_.end()) { - std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); - } - return Status::OK(); -} - -Status Graph::GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) { - if (edge_list.empty()) { - RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); - } - - std::vector> node_list; - node_list.reserve(edge_list.size()); - for (const auto &edge_id : edge_list) { - auto itr = edge_id_map_.find(edge_id); - if (itr == edge_id_map_.end()) { - std::string err_msg = "Invalid edge id:" + std::to_string(edge_id); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - std::pair, std::shared_ptr> nodes; - RETURN_IF_NOT_OK(itr->second->GetNode(&nodes)); - node_list.push_back({nodes.first->id(), nodes.second->id()}); - } - } - RETURN_IF_NOT_OK(CreateTensorByVector(node_list, DataType(DataType::DE_INT32), out)); - return Status::OK(); -} - -Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out) { - if (node_list.empty()) { - RETURN_STATUS_UNEXPECTED("Input node_list is empty."); - } - if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { - std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::vector> neighbors; - size_t max_neighbor_num = 0; - neighbors.resize(node_list.size()); - for (size_t i = 0; i < node_list.size(); ++i) { - std::shared_ptr node; - RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node)); - RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); - max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); - } - - RETURN_IF_NOT_OK(ComplementVector(&neighbors, max_neighbor_num, kDefaultNodeId)); - RETURN_IF_NOT_OK(CreateTensorByVector(neighbors, DataType(DataType::DE_INT32), out)); - - return Status::OK(); -} - -Status Graph::CheckSamplesNum(NodeIdType samples_num) { - NodeIdType all_nodes_number = - std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, - [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); - if ((samples_num < 1) || (samples_num > all_nodes_number)) { - std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) + - ", got " + std::to_string(samples_num); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -Status Graph::GetSampledNeighbors(const std::vector &node_list, - const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) { - CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); - CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), - "The sizes of neighbor_nums and neighbor_types are inconsistent."); - for (const auto &num : neighbor_nums) { - RETURN_IF_NOT_OK(CheckSamplesNum(num)); - } - for (const auto &type : neighbor_types) { - if (node_type_map_.find(type) == node_type_map_.end()) { - std::string err_msg = "Invalid neighbor type:" + std::to_string(type); - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - std::vector> neighbors_vec(node_list.size()); - for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { - std::shared_ptr input_node; - RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node)); - neighbors_vec[node_idx].emplace_back(node_list[node_idx]); - std::vector input_list = {node_list[node_idx]}; - for (size_t i = 0; i < neighbor_nums.size(); ++i) { - std::vector neighbors; - neighbors.reserve(input_list.size() * neighbor_nums[i]); - for (const auto &node_id : input_list) { - if (node_id == kDefaultNodeId) { - for (int32_t j = 0; j < neighbor_nums[i]; ++j) { - neighbors.emplace_back(kDefaultNodeId); - } - } else { - std::shared_ptr node; - RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); - std::vector out; - RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); - neighbors.insert(neighbors.end(), out.begin(), out.end()); - } - } - neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end()); - input_list = std::move(neighbors); - } - } - RETURN_IF_NOT_OK(CreateTensorByVector(neighbors_vec, DataType(DataType::DE_INT32), out)); - return Status::OK(); -} - -Status Graph::NegativeSample(const std::vector &data, const std::unordered_set &exclude_data, - int32_t samples_num, std::vector *out_samples) { - CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); - std::vector shuffled_id(data.size()); - std::iota(shuffled_id.begin(), shuffled_id.end(), 0); - std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); - for (const auto &index : shuffled_id) { - if (exclude_data.find(data[index]) != exclude_data.end()) { - continue; - } - out_samples->emplace_back(data[index]); - if (out_samples->size() >= samples_num) { - break; - } - } - return Status::OK(); -} - -Status Graph::GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, - NodeType neg_neighbor_type, std::shared_ptr *out) { - CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); - RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); - if (node_type_map_.find(neg_neighbor_type) == node_type_map_.end()) { - std::string err_msg = "Invalid neighbor type:" + std::to_string(neg_neighbor_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::vector> neighbors_vec; - neighbors_vec.resize(node_list.size()); - for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { - std::shared_ptr node; - RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); - std::vector neighbors; - RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); - std::unordered_set exclude_node; - std::transform(neighbors.begin(), neighbors.end(), - std::insert_iterator>(exclude_node, exclude_node.begin()), - [](const NodeIdType node) { return node; }); - auto itr = node_type_map_.find(neg_neighbor_type); - if (itr == node_type_map_.end()) { - std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - neighbors_vec[node_idx].emplace_back(node->id()); - if (itr->second.size() > exclude_node.size()) { - while (neighbors_vec[node_idx].size() < samples_num + 1) { - RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(), - &neighbors_vec[node_idx])); - } - } else { - MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() - << " neg_neighbor_type:" << neg_neighbor_type; - // If there are no negative neighbors, they are filled with kDefaultNodeId - for (int32_t i = 0; i < samples_num; ++i) { - neighbors_vec[node_idx].emplace_back(kDefaultNodeId); - } - } - } - } - RETURN_IF_NOT_OK(CreateTensorByVector(neighbors_vec, DataType(DataType::DE_INT32), out)); - return Status::OK(); -} - -Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, - float step_home_param, float step_away_param, NodeIdType default_node, - std::shared_ptr *out) { - RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); - std::vector> walks; - RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); - RETURN_IF_NOT_OK(CreateTensorByVector({walks}, DataType(DataType::DE_INT32), out)); - return Status::OK(); -} - -Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { - auto itr = default_feature_map_.find(feature_type); - if (itr == default_feature_map_.end()) { - std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - *out_feature = itr->second; - } - return Status::OK(); -} - -Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, - TensorRow *out) { - if (!nodes || nodes->Size() == 0) { - RETURN_STATUS_UNEXPECTED("Input nodes is empty"); - } - CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty"); - TensorRow tensors; - for (const auto &f_type : feature_types) { - std::shared_ptr default_feature; - // If no feature can be obtained, fill in the default value - RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); - - TensorShape shape(default_feature->Value()->shape()); - auto shape_vec = nodes->shape().AsVector(); - dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); - shape = shape.PrependDim(size); - std::shared_ptr fea_tensor; - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); - - dsize_t index = 0; - for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { - std::shared_ptr feature; - if (*node_itr == kDefaultNodeId) { - feature = default_feature; - } else { - std::shared_ptr node; - RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); - if (!node->GetFeatures(f_type, &feature).IsOk()) { - feature = default_feature; - } - } - RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); - index++; - } - - TensorShape reshape(nodes->shape()); - for (auto s : default_feature->Value()->shape().AsVector()) { - reshape = reshape.AppendDim(s); - } - RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); - fea_tensor->Squeeze(); - tensors.push_back(fea_tensor); - } - *out = std::move(tensors); - return Status::OK(); -} - -Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, - TensorRow *out) { - return Status::OK(); -} - -Status Graph::Init() { - RETURN_IF_NOT_OK(LoadNodeAndEdge()); - return Status::OK(); -} - -Status Graph::GetMetaInfo(MetaInfo *meta_info) { - meta_info->node_type.resize(node_type_map_.size()); - std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), - [](auto itr) { return itr.first; }); - std::sort(meta_info->node_type.begin(), meta_info->node_type.end()); - - meta_info->edge_type.resize(edge_type_map_.size()); - std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(), - [](auto itr) { return itr.first; }); - std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end()); - - for (const auto &node : node_type_map_) { - meta_info->node_num[node.first] = node.second.size(); - } - - for (const auto &edge : edge_type_map_) { - meta_info->edge_num[edge.first] = edge.second.size(); - } - - for (const auto &node_feature : node_feature_map_) { - for (auto type : node_feature.second) { - meta_info->node_feature_type.emplace_back(type); - } - } - std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); - auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); - meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end()); - - for (const auto &edge_feature : edge_feature_map_) { - for (const auto &type : edge_feature.second) { - meta_info->edge_feature_type.emplace_back(type); - } - } - std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); - auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); - meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end()); - return Status::OK(); -} - -Status Graph::GraphInfo(py::dict *out) { - MetaInfo meta_info; - RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); - (*out)["node_type"] = py::cast(meta_info.node_type); - (*out)["edge_type"] = py::cast(meta_info.edge_type); - (*out)["node_num"] = py::cast(meta_info.node_num); - (*out)["edge_num"] = py::cast(meta_info.edge_num); - (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); - (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); - return Status::OK(); -} - -Status Graph::LoadNodeAndEdge() { - GraphLoader gl(dataset_file_, num_workers_); - // ask graph_loader to load everything into memory - RETURN_IF_NOT_OK(gl.InitAndLoad()); - // get all maps - RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, - &node_feature_map_, &edge_feature_map_, &default_feature_map_)); - return Status::OK(); -} - -Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { - auto itr = node_id_map_.find(id); - if (itr == node_id_map_.end()) { - std::string err_msg = "Invalid node id:" + std::to_string(id); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - *node = itr->second; - } - return Status::OK(); -} - -Graph::RandomWalkBase::RandomWalkBase(Graph *graph) - : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} - -Status Graph::RandomWalkBase::Build(const std::vector &node_list, const std::vector &meta_path, - float step_home_param, float step_away_param, const NodeIdType default_node, - int32_t num_walks, int32_t num_workers) { - node_list_ = node_list; - if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { - std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) + - ". The size of input path is " + std::to_string(meta_path.size()); - RETURN_STATUS_UNEXPECTED(err_msg); - } - meta_path_ = meta_path; - if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { - std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + - std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) + - ", step_away_param: " + std::to_string(step_away_param); - RETURN_STATUS_UNEXPECTED(err_msg); - } - step_home_param_ = step_home_param; - step_away_param_ = step_away_param; - default_node_ = default_node; - num_walks_ = num_walks; - num_workers_ = num_workers; - return Status::OK(); -} - -Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path) { - // Simulate a random walk starting from start node. - auto walk = std::vector(1, start_node); // walk is an vector - // walk simulate - while (walk.size() - 1 < meta_path_.size()) { - // current nodE - auto cur_node_id = walk.back(); - std::shared_ptr cur_node; - RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node)); - - // current neighbors - std::vector cur_neighbors; - RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true)); - std::sort(cur_neighbors.begin(), cur_neighbors.end()); - - // break if no neighbors - if (cur_neighbors.empty()) { - break; - } - - // walk by the fist node, then by the previous 2 nodes - std::shared_ptr stochastic_index; - if (walk.size() == 1) { - RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index)); - } else { - NodeIdType prev_node_id = walk[walk.size() - 2]; - RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index)); - } - NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)]; - walk.push_back(next_node_id); - } - - while (walk.size() - 1 < meta_path_.size()) { - walk.push_back(default_node_); - } - - *walk_path = std::move(walk); - return Status::OK(); -} - -Status Graph::RandomWalkBase::SimulateWalk(std::vector> *walks) { - // Repeatedly simulate random walks from each node - std::vector permutation(node_list_.size()); - std::iota(permutation.begin(), permutation.end(), 0); - for (int32_t i = 0; i < num_walks_; i++) { - unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); - std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed)); - for (const auto &i_perm : permutation) { - std::vector walk; - RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk)); - walks->push_back(walk); - } - } - return Status::OK(); -} - -Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, - std::shared_ptr *node_probability) { - // Generate alias nodes - std::shared_ptr node; - graph_->GetNodeByNodeId(node_id, &node); - std::vector neighbors; - RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true)); - std::sort(neighbors.begin(), neighbors.end()); - auto non_normalized_probability = std::vector(neighbors.size(), 1.0); - *node_probability = - std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); - return Status::OK(); -} - -Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, - std::shared_ptr *edge_probability) { - // Get the alias edge setup lists for a given edge. - std::shared_ptr src_node; - graph_->GetNodeByNodeId(src, &src_node); - std::vector src_neighbors; - RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true)); - - std::shared_ptr dst_node; - graph_->GetNodeByNodeId(dst, &dst_node); - std::vector dst_neighbors; - RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true)); - - std::sort(dst_neighbors.begin(), dst_neighbors.end()); - std::vector non_normalized_probability; - for (const auto &dst_nbr : dst_neighbors) { - if (dst_nbr == src) { - non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] - continue; - } - auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr); - if (it != src_neighbors.end()) { - // stay close, this node connect both src and dst - non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight'] - } else { - // step far away - non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] - } - } - - *edge_probability = - std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); - return Status::OK(); -} - -StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector &probability) { - uint32_t K = probability.size(); - std::vector switch_to_large_index(K, 0); - std::vector weight(K, .0); - std::vector smaller; - std::vector larger; - auto random_device = GetRandomDevice(); - std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon); - float accumulate_threshold = 0.0; - for (uint32_t i = 0; i < K; i++) { - float threshold_one = distribution(random_device); - accumulate_threshold += threshold_one; - weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold; - weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i); - } - - while ((!smaller.empty()) && (!larger.empty())) { - uint32_t small = smaller.back(); - smaller.pop_back(); - uint32_t large = larger.back(); - larger.pop_back(); - switch_to_large_index[small] = large; - weight[large] = weight[large] + weight[small] - 1.0; - weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large); - } - return StochasticIndex(switch_to_large_index, weight); -} - -uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { - auto switch_to_large_index = stochastic_index.first; - auto weight = stochastic_index.second; - const uint32_t size_of_index = switch_to_large_index.size(); - - auto random_device = GetRandomDevice(); - std::uniform_real_distribution<> distribution(0.0, 1.0); - - // Generate random integer between [0, K) - uint32_t random_idx = std::floor(distribution(random_device) * size_of_index); - - if (distribution(random_device) < weight[random_idx]) { - return random_idx; - } - return switch_to_large_index[random_idx]; -} - -template -std::vector Graph::RandomWalkBase::Normalize(const std::vector &non_normalized_probability) { - float sum_probability = - 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); - if (sum_probability < kGnnEpsilon) { - sum_probability = 1.0; - } - std::vector normalized_probability; - std::transform(non_normalized_probability.begin(), non_normalized_probability.end(), - std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; }); - return normalized_probability; -} -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h deleted file mode 100644 index 344a6c6bf2..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.h +++ /dev/null @@ -1,250 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_GRAPH_H_ -#define DATASET_ENGINE_GNN_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_row.h" -#include "dataset/engine/gnn/graph_loader.h" -#include "dataset/engine/gnn/feature.h" -#include "dataset/engine/gnn/node.h" -#include "dataset/engine/gnn/edge.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -const float kGnnEpsilon = 0.0001; -const uint32_t kMaxNumWalks = 80; -using StochasticIndex = std::pair, std::vector>; - -struct MetaInfo { - std::vector node_type; - std::vector edge_type; - std::map node_num; - std::map edge_num; - std::vector node_feature_type; - std::vector edge_feature_type; -}; - -class Graph { - public: - // Constructor - // @param std::string dataset_file - - // @param int32_t num_workers - number of parallel threads - Graph(std::string dataset_file, int32_t num_workers); - - ~Graph() = default; - - // Get all nodes from the graph. - // @param NodeType node_type - type of node - // @param std::shared_ptr *out - Returned nodes id - // @return Status - The error code return - Status GetAllNodes(NodeType node_type, std::shared_ptr *out); - - // Get all edges from the graph. - // @param NodeType edge_type - type of edge - // @param std::shared_ptr *out - Returned edge ids - // @return Status - The error code return - Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out); - - // Get the node id from the edge. - // @param std::vector edge_list - List of edges - // @param std::shared_ptr *out - Returned node ids - // @return Status - The error code return - Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out); - - // All neighbors of the acquisition node. - // @param std::vector node_list - List of nodes - // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported - // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is - // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors - // is not enough, fill in tensor as -1. - // @return Status - The error code return - Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out); - - // Get sampled neighbors. - // @param std::vector node_list - List of nodes - // @param std::vector neighbor_nums - Number of neighbors sampled per hop - // @param std::vector neighbor_types - Neighbor type sampled per hop - // @param std::shared_ptr *out - Returned neighbor's id. - // @return Status - The error code return - Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out); - - // Get negative sampled neighbors. - // @param std::vector node_list - List of nodes - // @param NodeIdType samples_num - Number of neighbors sampled - // @param NodeType neg_neighbor_type - The type of negative neighbor. - // @param std::shared_ptr *out - Returned negative neighbor's id. - // @return Status - The error code return - Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, - NodeType neg_neighbor_type, std::shared_ptr *out); - - // Node2vec random walk. - // @param std::vector node_list - List of nodes - // @param std::vector meta_path - node type of each step - // @param float step_home_param - return hyper parameter in node2vec algorithm - // @param float step_away_param - inout hyper parameter in node2vec algorithm - // @param NodeIdType default_node - default node id - // @param std::shared_ptr *out - Returned nodes id in walk path - // @return Status - The error code return - Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, - float step_home_param, float step_away_param, NodeIdType default_node, - std::shared_ptr *out); - - // Get the feature of a node - // @param std::shared_ptr nodes - List of nodes - // @param std::vector feature_types - Types of features, An error will be reported if the feature type - // does not exist. - // @param TensorRow *out - Returned features - // @return Status - The error code return - Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, - TensorRow *out); - - // Get the feature of a edge - // @param std::shared_ptr edget - List of edges - // @param std::vector feature_types - Types of features, An error will be reported if the feature type - // does not exist. - // @param Tensor *out - Returned features - // @return Status - The error code return - Status GetEdgeFeature(const std::shared_ptr &edget, const std::vector &feature_types, - TensorRow *out); - - // Get meta information of graph - // @param MetaInfo *meta_info - Returned meta information - // @return Status - The error code return - Status GetMetaInfo(MetaInfo *meta_info); - - // Return meta information to python layer - Status GraphInfo(py::dict *out); - - Status Init(); - - private: - class RandomWalkBase { - public: - explicit RandomWalkBase(Graph *graph); - - Status Build(const std::vector &node_list, const std::vector &meta_path, - float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, - int32_t num_walks = 1, int32_t num_workers = 1); - - ~RandomWalkBase() = default; - - Status SimulateWalk(std::vector> *walks); - - private: - Status Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path); - - Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, - std::shared_ptr *node_probability); - - Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, - std::shared_ptr *edge_probability); - - static StochasticIndex GenerateProbability(const std::vector &probability); - - static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index); - - template - std::vector Normalize(const std::vector &non_normalized_probability); - - Graph *graph_; - std::vector node_list_; - std::vector meta_path_; - float step_home_param_; // Return hyper parameter. Default is 1.0 - float step_away_param_; // Inout hyper parameter. Default is 1.0 - NodeIdType default_node_; - - int32_t num_walks_; // Number of walks per source. Default is 10 - int32_t num_workers_; // The number of worker threads. Default is 1 - }; - - // Load graph data from mindrecord file - // @return Status - The error code return - Status LoadNodeAndEdge(); - - // Create Tensor By Vector - // @param std::vector> &data - - // @param DataType type - - // @param std::shared_ptr *out - - // @return Status - The error code return - template - Status CreateTensorByVector(const std::vector> &data, DataType type, std::shared_ptr *out); - - // Complete vector - // @param std::vector> *data - To be completed vector - // @param size_t max_size - The size of the completed vector - // @param T default_value - Filled default - // @return Status - The error code return - template - Status ComplementVector(std::vector> *data, size_t max_size, T default_value); - - // Get the default feature of a node - // @param FeatureType feature_type - - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); - - // Find node object using node id - // @param NodeIdType id - - // @param std::shared_ptr *node - Returned node object - // @return Status - The error code return - Status GetNodeByNodeId(NodeIdType id, std::shared_ptr *node); - - // Negative sampling - // @param std::vector &input_data - The data set to be sampled - // @param std::unordered_set &exclude_data - Data to be excluded - // @param int32_t samples_num - - // @param std::vector *out_samples - Sampling results returned - // @return Status - The error code return - Status NegativeSample(const std::vector &input_data, const std::unordered_set &exclude_data, - int32_t samples_num, std::vector *out_samples); - - Status CheckSamplesNum(NodeIdType samples_num); - - std::string dataset_file_; - int32_t num_workers_; // The number of worker threads - std::mt19937 rnd_; - RandomWalkBase random_walk_; - - std::unordered_map> node_type_map_; - std::unordered_map> node_id_map_; - - std::unordered_map> edge_type_map_; - std::unordered_map> edge_id_map_; - - std::unordered_map> node_feature_map_; - std::unordered_map> edge_feature_map_; - - std::unordered_map> default_feature_map_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_GRAPH_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc deleted file mode 100644 index 6504d088bf..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc +++ /dev/null @@ -1,254 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "dataset/engine/gnn/graph_loader.h" -#include "mindspore/ccsrc/mindrecord/include/shard_error.h" -#include "dataset/engine/gnn/local_edge.h" -#include "dataset/engine/gnn/local_node.h" -#include "dataset/util/task_manager.h" - -using ShardTuple = std::vector, mindspore::mindrecord::json>>; - -namespace mindspore { -namespace dataset { -namespace gnn { - -using mindrecord::MSRStatus; - -GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) - : mr_path_(mr_filepath), - num_workers_(num_workers), - row_id_(0), - shard_reader_(nullptr), - keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} - -Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, - EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, - EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) { - for (std::deque> &dq : n_deques_) { - while (dq.empty() == false) { - std::shared_ptr node_ptr = dq.front(); - n_id_map->insert({node_ptr->id(), node_ptr}); - (*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); - dq.pop_front(); - } - } - - for (std::deque> &dq : e_deques_) { - while (dq.empty() == false) { - std::shared_ptr edge_ptr = dq.front(); - std::pair, std::shared_ptr> p; - RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); - auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); - CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); - CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); - RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); - RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); - e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ - (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); - dq.pop_front(); - } - } - - for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); - for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); - - MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map); - return Status::OK(); -} - -Status GraphLoader::InitAndLoad() { - CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n"); - CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n"); - n_deques_.resize(num_workers_); - e_deques_.resize(num_workers_); - n_feature_maps_.resize(num_workers_); - e_feature_maps_.resize(num_workers_); - default_feature_maps_.resize(num_workers_); - TaskGroup vg; - - shard_reader_ = std::make_unique(); - CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, - "Fail to open" + mr_path_); - CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); - CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); - - mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; - for (const std::string &key : keys_) { - if (schema.find(key) == schema.end()) { - RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); - } - } - - // launching worker threads - for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { - RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); - } - // wait for threads to finish and check its return code - vg.join_all(Task::WaitFlag::kBlocking); - RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny()); - return Status::OK(); -} - -Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrecord::json &col_jsn, - std::shared_ptr *node, NodeFeatureMap *feature_map, - DefaultFeatureMap *default_feature) { - NodeIdType node_id = col_jsn["first_id"]; - NodeType node_type = static_cast(col_jsn["type"]); - (*node) = std::make_shared(node_id, node_type); - std::vector indices; - RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); - - for (int32_t ind : indices) { - std::shared_ptr tensor; - RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); - RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared(ind, tensor))); - (*feature_map)[node_type].insert(ind); - if ((*default_feature)[ind] == nullptr) { - std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); - RETURN_IF_NOT_OK(zero_tensor->Zero()); - (*default_feature)[ind] = std::make_shared(ind, zero_tensor); - } - } - return Status::OK(); -} - -Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrecord::json &col_jsn, - std::shared_ptr *edge, EdgeFeatureMap *feature_map, - DefaultFeatureMap *default_feature) { - EdgeIdType edge_id = col_jsn["first_id"]; - EdgeType edge_type = static_cast(col_jsn["type"]); - NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; - std::shared_ptr src = std::make_shared(src_id, -1); - std::shared_ptr dst = std::make_shared(dst_id, -1); - (*edge) = std::make_shared(edge_id, edge_type, src, dst); - std::vector indices; - RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); - for (int32_t ind : indices) { - std::shared_ptr tensor; - RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); - RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared(ind, tensor))); - (*feature_map)[edge_type].insert(ind); - if ((*default_feature)[ind] == nullptr) { - std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); - RETURN_IF_NOT_OK(zero_tensor->Zero()); - (*default_feature)[ind] = std::make_shared(ind, zero_tensor); - } - } - return Status::OK(); -} - -Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector &col_blob, - const mindrecord::json &col_jsn, std::shared_ptr *tensor) { - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0, col_type_size = 1; - mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; - std::vector column_shape; - MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( - key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); - if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible, - std::move(TensorShape({static_cast(n_bytes / col_type_size)})), - std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data)); - return Status::OK(); -} - -Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector &col_blob, - const mindrecord::json &col_jsn, std::vector *indices) { - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0, col_type_size = 1; - mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; - std::vector column_shape; - MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( - key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); - - if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); - - for (int i = 0; i < n_bytes; i += col_type_size) { - int32_t feature_ind = -1; - if (col_type == mindrecord::ColumnInt32) { - feature_ind = *(reinterpret_cast(data + i)); - } else if (col_type == mindrecord::ColumnInt64) { - feature_ind = *(reinterpret_cast(data + i)); - } else { - RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); - } - if (feature_ind >= 0) indices->push_back(feature_ind); - } - return Status::OK(); -} - -Status GraphLoader::WorkerEntry(int32_t worker_id) { - // Handshake - TaskManager::FindMe()->Post(); - auto ret = shard_reader_->GetNextById(row_id_++, worker_id); - ShardTuple rows = ret.second; - while (rows.empty() == false) { - RETURN_IF_INTERRUPTED(); - for (const auto &tupled_row : rows) { - std::vector col_blob = std::get<0>(tupled_row); - mindrecord::json col_jsn = std::get<1>(tupled_row); - std::string attr = col_jsn["attribute"]; - if (attr == "n") { - std::shared_ptr node_ptr; - RETURN_IF_NOT_OK( - LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id])); - n_deques_[worker_id].emplace_back(node_ptr); - } else if (attr == "e") { - std::shared_ptr edge_ptr; - RETURN_IF_NOT_OK( - LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id])); - e_deques_[worker_id].emplace_back(edge_ptr); - } else { - MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; - } - } - auto rc = shard_reader_->GetNextById(row_id_++, worker_id); - rows = rc.second; - } - return Status::OK(); -} - -void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, - DefaultFeatureMap *default_feature_map) { - for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { - for (auto &m : n_feature_maps_[wkr_id]) { - for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); - } - for (auto &m : e_feature_maps_[wkr_id]) { - for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); - } - for (auto &m : default_feature_maps_[wkr_id]) { - (*default_feature_map)[m.first] = m.second; - } - } - n_feature_maps_.clear(); - e_feature_maps_.clear(); -} - -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h deleted file mode 100644 index 0ad54bae6d..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_GRAPH_LOADER_H_ -#define DATASET_ENGINE_GNN_GRAPH_LOADER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/gnn/feature.h" -#include "dataset/engine/gnn/graph.h" -#include "dataset/engine/gnn/node.h" -#include "dataset/engine/gnn/edge.h" -#include "dataset/util/status.h" -#include "mindrecord/include/shard_reader.h" -namespace mindspore { -namespace dataset { -namespace gnn { - -using mindrecord::ShardReader; -using NodeIdMap = std::unordered_map>; -using EdgeIdMap = std::unordered_map>; -using NodeTypeMap = std::unordered_map>; -using EdgeTypeMap = std::unordered_map>; -using NodeFeatureMap = std::unordered_map>; -using EdgeFeatureMap = std::unordered_map>; -using DefaultFeatureMap = std::unordered_map>; - -// this class interfaces with the underlying storage format (mindrecord) -// it returns raw nodes and edges via GetNodesAndEdges -// it is then the responsibility of graph to construct itself based on the nodes and edges -// if needed, this class could become a base where each derived class handles a specific storage format -class GraphLoader { - public: - explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); - - ~GraphLoader() = default; - // Init mindrecord and load everything into memory multi-threaded - // @return Status - the status code - Status InitAndLoad(); - - // this function will query mindrecord and construct all nodes and edges - // nodes and edges are added to map without any connection. That's because there nodes and edges are read in - // random order. src_node and dst_node in Edge are node_id only with -1 as type. - // features attached to each node and edge are expected to be filled correctly - Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, - DefaultFeatureMap *); - - private: - // - // worker thread that reads mindrecord file - // @param int32_t worker_id - id of each worker - // @return Status - the status code - Status WorkerEntry(int32_t worker_id); - - // Load a node based on 1 row of mindrecord, returns a shared_ptr - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::shared_ptr *node - return value - // @param NodeFeatureMap *feature_map - - // @param DefaultFeatureMap *default_feature - - // @return Status - the status code - Status LoadNode(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *node, - NodeFeatureMap *feature_map, DefaultFeatureMap *default_feature); - - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::shared_ptr *edge - return value, the edge ptr, edge is not yet connected - // @param FeatureMap *feature_map - // @param DefaultFeatureMap *default_feature - - // @return Status - the status code - Status LoadEdge(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *edge, - EdgeFeatureMap *feature_map, DefaultFeatureMap *default_feature); - - // @param std::string key - column name - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::vector *ind - return value, list of feature index in int32_t - // @return Status - the status code - Status LoadFeatureIndex(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, - std::vector *ind); - - // @param std::string &key - column name - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::shared_ptr *tensor - return value feature tensor - // @return Status - the status code - Status LoadFeatureTensor(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, - std::shared_ptr *tensor); - - // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 - void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultFeatureMap *); - - const int32_t num_workers_; - std::atomic_int row_id_; - std::string mr_path_; - std::unique_ptr shard_reader_; - std::vector>> n_deques_; - std::vector>> e_deques_; - std::vector n_feature_maps_; - std::vector e_feature_maps_; - std::vector default_feature_maps_; - const std::vector keys_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_GRAPH_LOADER_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc b/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc deleted file mode 100644 index 7465b689d5..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/gnn/local_edge.h" - -#include - -namespace mindspore { -namespace dataset { -namespace gnn { - -LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) - : Edge(id, type, src_node, dst_node) {} - -Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { - auto itr = features_.find(feature_type); - if (itr != features_.end()) { - *out_feature = itr->second; - return Status::OK(); - } else { - std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } -} - -Status LocalEdge::UpdateFeature(const std::shared_ptr &feature) { - auto itr = features_.find(feature->type()); - if (itr != features_.end()) { - RETURN_STATUS_UNEXPECTED("Feature already exists"); - } else { - features_[feature->type()] = feature; - return Status::OK(); - } -} -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_edge.h b/mindspore/ccsrc/dataset/engine/gnn/local_edge.h deleted file mode 100644 index a34fc00373..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/local_edge.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_LOCAL_EDGE_H_ -#define DATASET_ENGINE_GNN_LOCAL_EDGE_H_ - -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/engine/gnn/edge.h" -#include "dataset/engine/gnn/feature.h" -#include "dataset/engine/gnn/node.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -class LocalEdge : public Edge { - public: - // Constructor - // @param EdgeIdType id - edge id - // @param EdgeType type - edge type - // @param std::shared_ptr src_node - source node - // @param std::shared_ptr dst_node - destination node - LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node); - - ~LocalEdge() = default; - - // Get the feature of a edge - // @param FeatureType feature_type - type of feature - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; - - // Update feature of edge - // @param std::shared_ptr feature - - // @return Status - The error code return - Status UpdateFeature(const std::shared_ptr &feature) override; - - private: - std::unordered_map> features_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_LOCAL_EDGE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc deleted file mode 100644 index c829f8e8ca..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/gnn/local_node.h" - -#include -#include -#include - -#include "dataset/engine/gnn/edge.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } - -Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { - auto itr = features_.find(feature_type); - if (itr != features_.end()) { - *out_feature = itr->second; - return Status::OK(); - } else { - std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } -} - -Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, bool exclude_itself) { - std::vector neighbors; - auto itr = neighbor_nodes_.find(neighbor_type); - if (itr != neighbor_nodes_.end()) { - if (exclude_itself) { - neighbors.resize(itr->second.size()); - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), - [](const std::shared_ptr node) { return node->id(); }); - } else { - neighbors.resize(itr->second.size() + 1); - neighbors[0] = id_; - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, - [](const std::shared_ptr node) { return node->id(); }); - } - } else { - MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; - if (!exclude_itself) { - neighbors.emplace_back(id_); - } - } - *out_neighbors = std::move(neighbors); - return Status::OK(); -} - -Status LocalNode::GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, - std::vector *out) { - std::vector shuffled_id(neighbors.size()); - std::iota(shuffled_id.begin(), shuffled_id.end(), 0); - std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); - int32_t num = std::min(samples_num, static_cast(neighbors.size())); - for (int32_t i = 0; i < num; ++i) { - out->emplace_back(neighbors[shuffled_id[i]]->id()); - } - return Status::OK(); -} - -Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, - std::vector *out_neighbors) { - std::vector neighbors; - neighbors.reserve(samples_num); - auto itr = neighbor_nodes_.find(neighbor_type); - if (itr != neighbor_nodes_.end()) { - while (neighbors.size() < samples_num) { - RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); - } - } else { - MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; - // If there are no neighbors, they are filled with kDefaultNodeId - for (int32_t i = 0; i < samples_num; ++i) { - neighbors.emplace_back(kDefaultNodeId); - } - } - *out_neighbors = std::move(neighbors); - return Status::OK(); -} - -Status LocalNode::AddNeighbor(const std::shared_ptr &node) { - auto itr = neighbor_nodes_.find(node->type()); - if (itr != neighbor_nodes_.end()) { - itr->second.push_back(node); - } else { - neighbor_nodes_[node->type()] = {node}; - } - return Status::OK(); -} - -Status LocalNode::UpdateFeature(const std::shared_ptr &feature) { - auto itr = features_.find(feature->type()); - if (itr != features_.end()) { - RETURN_STATUS_UNEXPECTED("Feature already exists"); - } else { - features_[feature->type()] = feature; - return Status::OK(); - } -} - -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/dataset/engine/gnn/local_node.h deleted file mode 100644 index bc069d073f..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_LOCAL_NODE_H_ -#define DATASET_ENGINE_GNN_LOCAL_NODE_H_ - -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/engine/gnn/node.h" -#include "dataset/engine/gnn/feature.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -class LocalNode : public Node { - public: - // Constructor - // @param NodeIdType id - node id - // @param NodeType type - node type - LocalNode(NodeIdType id, NodeType type); - - ~LocalNode() = default; - - // Get the feature of a node - // @param FeatureType feature_type - type of feature - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; - - // Get the all neighbors of a node - // @param NodeType neighbor_type - type of neighbor - // @param std::vector *out_neighbors - Returned neighbors id - // @return Status - The error code return - Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, - bool exclude_itself = false) override; - - // Get the sampled neighbors of a node - // @param NodeType neighbor_type - type of neighbor - // @param int32_t samples_num - Number of neighbors to be acquired - // @param std::vector *out_neighbors - Returned neighbors id - // @return Status - The error code return - Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, - std::vector *out_neighbors) override; - - // Add neighbor of node - // @param std::shared_ptr node - - // @return Status - The error code return - Status AddNeighbor(const std::shared_ptr &node) override; - - // Update feature of node - // @param std::shared_ptr feature - - // @return Status - The error code return - Status UpdateFeature(const std::shared_ptr &feature) override; - - private: - Status GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, - std::vector *out); - - std::mt19937 rnd_; - std::unordered_map> features_; - std::unordered_map>> neighbor_nodes_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_LOCAL_NODE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/node.h b/mindspore/ccsrc/dataset/engine/gnn/node.h deleted file mode 100644 index 282f856797..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/node.h +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_NODE_H_ -#define DATASET_ENGINE_GNN_NODE_H_ - -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/engine/gnn/feature.h" - -namespace mindspore { -namespace dataset { -namespace gnn { -using NodeType = int8_t; -using NodeIdType = int32_t; - -constexpr NodeIdType kDefaultNodeId = -1; - -class Node { - public: - // Constructor - // @param NodeIdType id - node id - // @param NodeType type - node type - Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} - - virtual ~Node() = default; - - // @return NodeIdType - Returned node id - NodeIdType id() const { return id_; } - - // @return NodeIdType - Returned node type - NodeType type() const { return type_; } - - // Get the feature of a node - // @param FeatureType feature_type - type of feature - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; - - // Get the all neighbors of a node - // @param NodeType neighbor_type - type of neighbor - // @param std::vector *out_neighbors - Returned neighbors id - // @return Status - The error code return - virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, - bool exclude_itself = false) = 0; - - // Get the sampled neighbors of a node - // @param NodeType neighbor_type - type of neighbor - // @param int32_t samples_num - Number of neighbors to be acquired - // @param std::vector *out_neighbors - Returned neighbors id - // @return Status - The error code return - virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, - std::vector *out_neighbors) = 0; - - // Add neighbor of node - // @param std::shared_ptr node - - // @return Status - The error code return - virtual Status AddNeighbor(const std::shared_ptr &node) = 0; - - // Update feature of node - // @param std::shared_ptr feature - - // @return Status - The error code return - virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; - - protected: - NodeIdType id_; - NodeType type_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_NODE_H_ diff --git a/mindspore/ccsrc/dataset/engine/jagged_connector.h b/mindspore/ccsrc/dataset/engine/jagged_connector.h deleted file mode 100644 index 2058c542a8..0000000000 --- a/mindspore/ccsrc/dataset/engine/jagged_connector.h +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_JAGGED_CONNECTOR_H_ -#define DATASET_ENGINE_JAGGED_CONNECTOR_H_ - -#include -#include -#include -#include -#include "dataset/engine/connector.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/util/status.h" -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { -class JaggedConnector : public Connector> { - public: - JaggedConnector(int32_t num_producers, int32_t num_consumers, int32_t queue_capacity) - : Connector>(num_producers, num_consumers, queue_capacity) { - for (int i = 0; i < num_producers; i++) { - is_queue_finished_.push_back(false); - } - } - - ~JaggedConnector() = default; - - Status Add(int32_t worker_d, std::unique_ptr &&element) noexcept { - return Connector>::Push(worker_d, std::move(element)); - } - - Status Pop(int32_t worker_id, std::unique_ptr *result) noexcept override { - { - MS_ASSERT(worker_id < num_consumers_); - std::unique_lock lock(m_); - RETURN_IF_NOT_OK(cv_.Wait(&lock, [this, worker_id]() { return expect_consumer_ == worker_id; })); - if (is_queue_finished_[pop_from_]) { - std::string errMsg = "ERROR: popping from a finished queue in JaggedConnector"; - RETURN_STATUS_UNEXPECTED(errMsg); - } - - RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); - if ((*result)->eoe()) { - is_queue_finished_[pop_from_] = true; - } - - for (int offset = 1; offset <= num_producers_; offset++) { - int32_t nextQueueIndex = (pop_from_ + offset) % num_producers_; - if (is_queue_finished_[nextQueueIndex] == false) { - pop_from_ = nextQueueIndex; - break; - } - } - - expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; - } - - cv_.NotifyAll(); - return Status::OK(); - } - - void DoReset() { - for (int i = 0; i < is_queue_finished_.size(); i++) { - is_queue_finished_[i] = false; - } - - Connector>::Reset(); - } - - private: - std::vector is_queue_finished_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_JAGGED_CONNECTOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt deleted file mode 100644 index 080d968cfc..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(engine-opt OBJECT - pass.cc - pre/removal_nodes.cc - pre/removal_pass.cc - util/printer_pass.cc - ) diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.cc b/mindspore/ccsrc/dataset/engine/opt/pass.cc deleted file mode 100644 index 27769f056b..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pass.cc +++ /dev/null @@ -1,164 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/engine/opt/pass.h" -#include "dataset/engine/datasetops/batch_op.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/datasetops/device_queue_op.h" -#include "dataset/engine/datasetops/map_op.h" -#include "dataset/engine/datasetops/project_op.h" -#include "dataset/engine/datasetops/rename_op.h" -#include "dataset/engine/datasetops/filter_op.h" -#include "dataset/engine/datasetops/repeat_op.h" -#include "dataset/engine/datasetops/skip_op.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/datasetops/source/generator_op.h" -#include "dataset/engine/datasetops/source/mindrecord_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/take_op.h" -#include "dataset/engine/datasetops/zip_op.h" - -namespace mindspore { -namespace dataset { - -// Driver method for TreePass -Status TreePass::Run(ExecutionTree *tree, bool *modified) { - if (tree == nullptr || modified == nullptr) { - return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); - } - return this->RunOnTree(tree, modified); -} - -// Driver method for NodePass -Status NodePass::Run(ExecutionTree *tree, bool *modified) { - if (tree == nullptr || modified == nullptr) { - return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); - } - std::shared_ptr root = tree->root(); - if (traversalOrder_ == Order::DFS) { - // DFS - return DFSNodeVisit(root, modified); - } else if (traversalOrder_ == Order::BFS) { - // BFS - return BFSNodeVisit(root, modified); - } - return Status::OK(); -} - -// Helper function to perform DFS visit -Status NodePass::DFSNodeVisit(std::shared_ptr node, bool *modified) { - RETURN_IF_NOT_OK(node->PreAccept(this, modified)); - for (const auto &c : node->Children()) { - RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); - } - return node->Accept(this, modified); -} - -// Helper function to perform BFS visit -Status NodePass::BFSNodeVisit(std::shared_ptr root, bool *modified) { - // Initialize bfs queue with root - std::queue> bfsQueue; - bfsQueue.push(root); - - // BFS loop - while (!bfsQueue.empty()) { - // Pop the front of the bfs queue - auto curNode = bfsQueue.front(); - bfsQueue.pop(); - - // Run node pass - RETURN_IF_NOT_OK(curNode->Accept(this, modified)); - - // Push children into bfs queue - for (const auto &c : curNode->Children()) { - bfsQueue.push(c); - } - } - return Status::OK(); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.h b/mindspore/ccsrc/dataset/engine/opt/pass.h deleted file mode 100644 index 129c2fab37..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pass.h +++ /dev/null @@ -1,159 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_H_ - -#include -#include - -#include "dataset/engine/execution_tree.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class BatchOp; - -class MapOp; - -class ProjectOp; - -class RenameOp; - -class FilterOp; - -class SkipOp; - -class ShuffleOp; - -class GeneratorOp; - -class MindRecordOp; - -class TFReaderOp; - -class TakeOp; - -class ZipOp; - -class DeviceQueueOp; - -class ImageFolderOp; - -// The base class Pass is the basic unit of tree transformation. -// The actual implementation of the passes will be derived from here. -class Pass : public std::enable_shared_from_this { - public: - // Run the transformation pass against the execution tree. - // @param tree - Pointer to the execution tree to be transformed. - // @param modified - Pointer to the modified flag, - virtual Status Run(ExecutionTree *tree, bool *modified) = 0; -}; - -// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. -class TreePass : public Pass { - public: - /// \brief Run the transformation pass against the execution tree. - /// \param[inout] tree Pointer to the execution tree to be transformed. - /// \param[inout] modified Indicate if the tree was modified - Status Run(ExecutionTree *tree, bool *modified) final; - - /// \brief Derived classes may implement the runOnTree function to implement tree transformation. - /// "modified" flag needs to be set to true if tree is modified during the pass execution. - /// \param[inout] tree The tree to operate on. - /// \param[inout] Indicate of the tree was modified. - /// \return Status The error code return - virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } -}; - -// NodePass is a basic Pass class which performs transformation on Node visiting. -// NodePass implements Visitor design pattern. -class NodePass : public Pass { - public: - // Tree traversal order - enum Order { DFS, BFS }; - - // Constructor - // Default DFS traversal - explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } - - ~NodePass() = default; - - /// \brief Run the transformation pass against the execution tree - /// \param[inout] tree Pointer to the execution tree to be transformed - /// \param[inout] modified Indicator if the tree was changed - Status Run(ExecutionTree *tree, bool *modified) final; - - /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down - /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution - /// \param[in] node The node being visited - /// \param[out] modified Indicator if the node was changed at all - /// \return Status The error code return - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } - - /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation - /// "modified" flag needs to be set to true if tree is modified during the pass execution - /// \param[in] node The node being visited - /// \param[out] modified Indicator if the node was changed at all. - /// \return Status The error code return - virtual Status RunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } - - // Visit methods to be overridden. - // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode - // of its own type and override "Accept" from DatasetOp. - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - private: - // Helper function to perform DFS visit - Status DFSNodeVisit(std::shared_ptr node, bool *modified); - - // Helper function to perform BFS visit - Status BFSNodeVisit(std::shared_ptr root, bool *modified); - - // Tree traversal order of the NodePass - Order traversalOrder_; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc deleted file mode 100644 index 831a2a76ba..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/engine/opt/pre/removal_nodes.h" -#include "dataset/engine/opt/pre/removal_pass.h" -#include "dataset/engine/datasetops/shuffle_op.h" - -namespace mindspore { -namespace dataset { - -RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} - -// Perform ShuffleOp removal check. -Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - // If we are in a cache descendant tree, then this shuffle op needs to be removed - if (is_caching_) { - MS_LOG(DEBUG) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; - if (removal_pass_) { - removal_pass_->AddToRemovalList(std::static_pointer_cast(node)); - } else { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); - } - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h deleted file mode 100644 index 11ef37d80c..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ - -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class RemovalPass; - -/// \class RemovalNodes removal_nodes.h -/// \brief This is a NodePass who's job is to identify which nodes should be removed. -/// It works in conjunction with the removal_pass. -class RemovalNodes : public NodePass { - public: - /// \brief Constructor - /// \param[in] removal_pass Raw pointer back to controlling tree pass - explicit RemovalNodes(RemovalPass *removal_pass); - - /// \brief Perform ShuffleOp removal check - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - private: - bool is_caching_; - RemovalPass *removal_pass_; // Back pointer to the owning removal pass -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc deleted file mode 100644 index 31ec31234f..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "dataset/engine/opt/pre/removal_nodes.h" -#include "dataset/engine/opt/pre/removal_pass.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { - -// constructor -RemovalPass::RemovalPass() {} - -// Runs a removal_nodes pass first to find out which nodes to remove, then removes them. -Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { - // Create the removal node pass which can identify which nodes need to be removed. - std::unique_ptr removal_nodes = std::make_unique(this); - RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); - - // Then, execute the removal of any nodes that were set up for removal - for (auto node : removal_nodes_) { - node->Remove(); - } - return Status::OK(); -} - -// Adds an operator to the list of operators to be removed -void RemovalPass::AddToRemovalList(std::shared_ptr dataset_op) { removal_nodes_.push_back(dataset_op); } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h deleted file mode 100644 index 6523ca69b2..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ - -#include -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class DatasetOp; - -/// \class RemovalPass removal_pass.h -/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which -/// nodes should be removed, and then removes them. -class RemovalPass : public TreePass { - public: - /// \brief Constructor - RemovalPass(); - - /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. - /// \param[inout] tree The tree to operate on. - /// \param[inout] Indicate of the tree was modified. - /// \return Status The error code return - Status RunOnTree(ExecutionTree *tree, bool *modified) override; - - /// \brief Adds an operator to the list of operators to be removed - /// \param[in] dataset_op The operator to add to the removal list - void AddToRemovalList(std::shared_ptr dataset_op); - - private: - std::vector> removal_nodes_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc b/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc deleted file mode 100644 index 852bc018b2..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/engine/opt/util/printer_pass.h" - -namespace mindspore { -namespace dataset { - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting DatasetOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting BatchOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting MapOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting ProjectOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting RenameOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting FilterOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting SkipOp" << '\n'; - return Status::OK(); -} -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting ShuffleOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting GeneratorOp" << '\n'; - return Status::OK(); -} -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting MindRecordOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting TFReaderOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting TakeOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting ZipOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting DeviceQueueOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - std::cout << "Visiting ImageFolderOp" << '\n'; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h b/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h deleted file mode 100644 index fa04a88277..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H -#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H - -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class PrinterPass : public NodePass { - public: - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_size.cc b/mindspore/ccsrc/dataset/engine/perf/connector_size.cc deleted file mode 100644 index 0bd2754075..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/connector_size.cc +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/perf/connector_size.h" -#include -#include -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/path.h" - -using json = nlohmann::json; -namespace mindspore { -namespace dataset { -using Qrow = std::vector; - -// Sample action -Status ConnectorSize::Sample() { - Qrow cur_row; - std::transform(tree_->begin(), tree_->end(), std::back_inserter(cur_row), - [](DatasetOp &op) { return op.ConnectorSize(); }); - // Push new row of sample - sample_table_.push_back(cur_row); - return Status::OK(); -} - -// JSON serializer helper function -json ConnectorSize::ParseOpInfo(const DatasetOp &node, const std::vector &size) { - auto children = node.Children(); - std::vector children_id; - std::transform(children.begin(), children.end(), std::back_inserter(children_id), - [](std::shared_ptr op) -> int32_t { return op->id(); }); - json json_node; - json_node["op_id"] = node.id(); - json_node["op_type"] = node.Name(); - json_node["num_workers"] = node.num_workers(); - json metrics; - // DeviceQueueOp is a special op,it is not inlined but its output queue is invalid. - // So we should not output its queue size. - if (!node.inlined() && node.Name() != "DeviceQueueOp") { - metrics["output_queue"] = {{"size", size}, {"length", node.ConnectorCapacity()}}; - } - json_node["metrics"] = metrics; - if (!children_id.empty()) { - json_node["children"] = children_id; - } - - return json_node; -} - -// Save profiling data to file -Status ConnectorSize::SaveToFile() { - std::ofstream os(file_path_, std::ios::trunc); - uint32_t idx = 0; - json output; - std::shared_ptr cfg = GlobalContext::config_manager(); - output["sampling_interval"] = cfg->monitor_sampling_interval(); - // Traverse the ExecutionTree for JSON node generation - for (auto &node : *tree_) { - std::vector cur_queue_size; - std::transform(sample_table_.begin(), sample_table_.end(), std::back_inserter(cur_queue_size), - [&](const ConnectorSizeSample &sample) { return sample[idx]; }); - json json_node = ParseOpInfo(node, cur_queue_size); - output["op_info"].push_back(json_node); - idx++; - } - os << output; - return Status::OK(); -} -Status ConnectorSize::Init(const std::string &dir_path, const std::string &device_id) { - file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + device_id + ".json")).toString(); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_size.h b/mindspore/ccsrc/dataset/engine/perf/connector_size.h deleted file mode 100644 index 2584289fb4..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/connector_size.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CONNECTOR_SIZE_H -#define DATASET_CONNECTOR_SIZE_H - -#include -#include -#include -#include "dataset/engine/perf/profiling.h" -#include "dataset/engine/datasetops/dataset_op.h" - -using json = nlohmann::json; - -namespace mindspore { -namespace dataset { -class ExecutionTree; - -// Connector size sampling samples the output connector size of each op in the pipeline. -// It support JSON serialization for external usage. -class ConnectorSize : public Sampling { - // Connecto size sampling data is stored as a 2D vector - // op_0 ... op_m - // sample_0 size_0_0 ... size_m_0 - // ... ... ... ... - // sample_n size_0_m ... size_m_n - // - // A circular buffer will be implemented in the future to make this table more flexible. - using ConnectorSizeSample = std::vector; - using ConnectorSizeSampleTable = std::vector; - - public: - explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {} - - ~ConnectorSize() override = default; - - // Driver function for connector size sampling. - // This function samples the connector size of every nodes within the ExecutionTree - Status Sample() override; - - std::string Name() const override { return kConnectorSizeSamplingName; } - - // Save sampling data to file - // @return Status - The error code return - Status SaveToFile() override; - - Status Init(const std::string &dir_path, const std::string &device_id) override; - - // Parse op infomation and transform to json format - json ParseOpInfo(const DatasetOp &node, const std::vector &size); - - private: - ExecutionTree *tree_ = nullptr; // ExecutionTree pointer - ConnectorSizeSampleTable sample_table_; // Dataset structure to store all samples of connector size sampling -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CONNECTOR_SIZE_H diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_throughput.cc b/mindspore/ccsrc/dataset/engine/perf/connector_throughput.cc deleted file mode 100644 index 4fd59de390..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/connector_throughput.cc +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include "dataset/engine/perf/connector_throughput.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/path.h" - -namespace mindspore { -namespace dataset { - -// temporary helper -int ConnectorThroughput::InitNodes() { - auto it = (*tree_).begin(); - return it.NumNodes(); -} -// Sample action -Status ConnectorThroughput::Sample() { - std::vector out_buffer_count_row(n_nodes_); - std::vector throughput_row(n_nodes_); - TimePoint cur_time; // initialised inside the loop, used outside the loop to update prev sample time. - auto col = 0; - for (const auto &node : *tree_) { - auto cur_out_buffer_count = node.ConnectorOutBufferCount(); - out_buffer_count_row[col] = cur_out_buffer_count; - auto sz = timestamps_.size(); - cur_time = std::chrono::steady_clock::now(); - auto _dt = std::chrono::duration_cast(timestamps_[0][sz - 1] - timestamps_[0][sz - 2]); - auto dt = std::chrono::duration(_dt).count(); - auto prev_out_buffer_count = out_buffer_count_table_[col][out_buffer_count_table_.size() - 1]; - if (dt != 0) { - auto thr = (cur_out_buffer_count - prev_out_buffer_count) / (1000 * dt); - throughput_row[col] = thr; - } else { - throughput_row[col] = -1; - } - col++; - } - std::vector v = {cur_time}; // temporary fix - timestamps_.AddSample(v); - // Push new row of sample - out_buffer_count_table_.AddSample(out_buffer_count_row); - throughput_.AddSample(throughput_row); - return Status::OK(); -} - -json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector &thr) { - auto children = node.Children(); - std::vector children_id; - std::transform(children.begin(), children.end(), std::back_inserter(children_id), - [](std::shared_ptr op) -> int32_t { return op->id(); }); - json json_node; - json_node["op_id"] = node.id(); - json_node["op_type"] = node.Name(); - json_node["num_workers"] = node.num_workers(); - json metrics; - metrics["output_queue"] = {{"throughput", thr}}; - - json_node["metrics"] = metrics; - if (!children_id.empty()) { - json_node["children"] = children_id; - } - - return json_node; -} - -// Save profiling data to file -Status ConnectorThroughput::SaveToFile() { - std::ofstream os(file_path_); - json output; - output["sampling_interval"] = 10; - // Traverse the ExecutionTree for JSON node generation - int col = 0; - for (auto &node : *tree_) { - std::vector throughput; - for (auto i = 0; i < throughput_.size(); i++) { - throughput.push_back(throughput_[col][i]); - } - json json_node = ParseOpInfo(node, throughput); - output["op_info"].push_back(json_node); - col++; - } - os << output; - return Status::OK(); -} -Status ConnectorThroughput::Init(const std::string &dir_path, const std::string &device_id) { - file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + Name() + "_" + device_id + ".json")).toString(); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_throughput.h b/mindspore/ccsrc/dataset/engine/perf/connector_throughput.h deleted file mode 100644 index e873eb8315..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/connector_throughput.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_CONNECTOR_THROUGHPUT_H -#define DATASET_CONNECTOR_THROUGHPUT_H - -#include -#include -#include -#include -#include -#include "dataset/engine/perf/profiling.h" -#include "dataset/engine/perf/perf_data.h" -#include "dataset/engine/perf/cyclic_array.h" -#include "dataset/engine/datasetops/dataset_op.h" - -using json = nlohmann::json; -namespace mindspore { -namespace dataset { -class ExecutionTree; - -// Connector throughput samples the output connector size of each op in the pipeline. -// For the description of the data structure see perf_buffer.h -// It support JSON serialization for external usage. -class ConnectorThroughput : public Sampling { - using OutBufferCount = PerfData>; - using Throughput = PerfData>; - using TimePoint = std::chrono::time_point; - using TimeStamps = PerfData>; - - public: - explicit ConnectorThroughput(ExecutionTree *tree, int64_t max_rows = 1000000) - : tree_(tree), - max_rows_(max_rows), - n_nodes_(InitNodes()), - out_buffer_count_table_(OutBufferCount(max_rows_, n_nodes_)), - throughput_(Throughput(max_rows_, n_nodes_)), - timestamps_(TimeStamps(max_rows_, 1)) { - timestamps_.AddSample(std::vector(1)); - out_buffer_count_table_.AddSample(std::vector(n_nodes_)); - } - // Driver function for connector size sampling. - // This function samples the connector size of every nodes within the ExecutionTree - Status Sample() override; - - /* Status TestPrint() override { - std::ofstream os("performance_monitor.txt"); - if (throughput_.size() == 0) { - os << "data is empty" << std::endl; - return Status::OK(); - } - for (int i = 0; i < throughput_.size(); i++) { - for (int j = 0; j < n_nodes_; j++) { - os << throughput_[j][i] << " "; - } - os << std::endl; - } - return Status::OK(); - };*/ - - // Traverse the tree nodes and count them - int InitNodes(); - - std::string Name() const override { return name_; }; - - // Save sampling data to file - // @return Status - The error code return - Status SaveToFile() override; - - Status Init(const std::string &dir_path, const std::string &device_id); - - json ParseOpInfo(const DatasetOp &node, const std::vector &thr); - - private: - ExecutionTree *tree_ = nullptr; // ExecutionTree pointer - int64_t max_rows_; - int32_t n_nodes_; - OutBufferCount out_buffer_count_table_; - Throughput throughput_; - TimeStamps timestamps_; - std::string name_ = kConnectorThroughputSamplingName; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CONNECTOR_THROUGHPUT_H diff --git a/mindspore/ccsrc/dataset/engine/perf/cyclic_array.h b/mindspore/ccsrc/dataset/engine/perf/cyclic_array.h deleted file mode 100644 index fa60b401c5..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/cyclic_array.h +++ /dev/null @@ -1,197 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_CYCLIC_ARRAY_H -#define DATASET_CYCLIC_ARRAY_H - -#include -#include -#include -#include -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { - -/// \class CyclicArray "include/cyclic_array.h -/// \brief This is a container with a contiguous memory layout that pnly keeps N last entries, -/// when the number of entries exceeds the capacity -/// Must be preallocated -template -class CyclicArray { - public: - using value_type = T; - class Iterator { - // Add operator[] and make fully compliant with random access iterator - // and add a const iterator - // add resize(), empty() - public: - using iterator_category = std::random_access_iterator_tag; - using value_type = CyclicArray::value_type; - using difference_type = std::ptrdiff_t; - using pointer = CyclicArray::value_type *; - using reference = CyclicArray::value_type &; - - Iterator() = default; - - Iterator(dsize_t idx, pointer ptr, dsize_t capacity, dsize_t head) - : cur_idx_(idx), ptr_(ptr), capacity_(capacity), head_(head) {} - - Iterator(const Iterator &rhs) = default; - - ~Iterator() = default; - - Iterator &operator++() { - cur_idx_ = (cur_idx_ + 1) % (capacity_ + 1); - return *this; - } - - Iterator operator++(int) { - Iterator tmp(*this); - cur_idx_ = (cur_idx_ + 1) % (capacity_ + 1); - return tmp; - } - - Iterator &operator--() { - cur_idx_ = (cur_idx_ + capacity_) % (capacity_ + 1); - return *this; - } - - Iterator operator--(int) { - Iterator tmp(*this); - cur_idx_ = (cur_idx_ + capacity_) % (capacity_ + 1); - return tmp; - } - - Iterator operator+(dsize_t x) { return Iterator((cur_idx_ + x) % (capacity_ + 1), ptr_, capacity_, head_); } - - Iterator operator-(dsize_t x) { - return Iterator((cur_idx_ + (capacity_ + 1 - x)) % (capacity_ + 1), ptr_, capacity_, head_); - } - - bool operator<(const Iterator &rhs) { - return (head_ + cur_idx_) % (capacity_ + 1) < (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); - } - - bool operator>(const Iterator &rhs) { - return (head_ + cur_idx_) % (capacity_ + 1) > (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); - } - - bool operator>=(const Iterator &rhs) { - return (head_ + cur_idx_) % (capacity_ + 1) >= (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); - } - - bool operator<=(const Iterator &rhs) { - return (head_ + cur_idx_) % (capacity_ + 1) <= (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); - } - - difference_type operator-(const Iterator &rhs) { - return (cur_idx_ - rhs.cur_idx_ + capacity_ + 1) % (capacity_ + 1); - } - - reference operator*() { return ptr_[cur_idx_]; } - - pointer operator->() { return &(ptr_[cur_idx_]); } - - bool operator==(const Iterator &rhs) { return cur_idx_ == rhs.cur_idx_; } - - bool operator!=(const Iterator &rhs) { return cur_idx_ != rhs.cur_idx_; } - - private: - dsize_t cur_idx_; - pointer ptr_; - dsize_t capacity_; - dsize_t head_; - }; - - /// \brief Default constructor - CyclicArray() : buf_(nullptr), head_(0), tail_(0), size_(0), capacity_(0) {} - - /// \brief Constructor - /// \param[in] capacity - explicit CyclicArray(dsize_t capacity) - : buf_(std::make_unique(capacity + 1)), head_(0), tail_(0), size_(0), capacity_(capacity) {} - - CyclicArray(const CyclicArray &rhs) - : buf_(std::make_unique(rhs.capacity_ + 1)), - head_(rhs.head_), - tail_(rhs.tail_), - size_(rhs.size_), - capacity_(rhs.capacity_) { - std::copy(rhs.begin(), rhs.end(), begin()); - } - - CyclicArray(CyclicArray &&rhs) = default; - - ~CyclicArray() = default; - - /// \brief Iterator begin() - Iterator begin() { return Iterator(head_, buf_.get(), capacity_, head_); } - - /// \brief Iterator end() - Iterator end() { return Iterator(tail_, buf_.get(), capacity_, head_); } - - // not really const. - Iterator begin() const { return Iterator(head_, buf_.get(), capacity_, head_); } - - Iterator end() const { return Iterator(tail_, buf_.get(), capacity_, head_); } - - /// \brief clear the array. Does not deallocate memory, capacity remains the same - void clear() { - head_ = 0; - tail_ = 0; - size_ = 0; - } - - /// \brief returns current size - dsize_t size() { return size_; } - - /// \brief returns capacity - dsize_t capacity() { return capacity_; } - - /// \brief pushes a value - /// \param[in] val value - void push_back(T val) { - buf_[tail_] = val; - if (size_ >= capacity_) { - (tail_ != capacity_) ? tail_++ : tail_ = 0; - (head_ != capacity_) ? head_++ : head_ = 0; - } else { - tail_++; - size_++; - } - } - - /// \brief returns const reference to an element of the array - /// \param[in] idx index of the element - /// \param[out] const T& reference to an element of the array - const T &operator[](dsize_t idx) const { return buf_[(head_ + idx) % (capacity_ + 1)]; } - - /// \brief returns non-const reference to an element of the array - /// \param[in] idx index of the element - /// \param[out] T& reference to an element of the array - T &operator[](dsize_t idx) { return buf_[(head_ + idx) % (capacity_ + 1)]; } - - private: - std::unique_ptr buf_; - dsize_t head_; - dsize_t tail_; - dsize_t size_; - dsize_t capacity_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CYCLIC_ARRAY_H diff --git a/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.cc b/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.cc deleted file mode 100644 index 99b0c2d7e0..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include "dataset/engine/perf/dataset_iterator_tracing.h" -#include "dataset/util/path.h" - -namespace mindspore { -namespace dataset { - -Status DatasetIteratorTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, - const int32_t value) { - // Format: "type extra-info batch-num value" - // type: 0: time, 1: connector size - // extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time - // if type is 1 - connector capacity - // batch-num: batch number - // value: if type is 0 - value is time(ms) - // if type is 1 - value is connector size - // Examples: - // 0 0 20 10 - The 20th batch took 10ms to get data from pipeline. - // 1 64 20 5 - Connector size is 5 when get the 20th batch.Connector capacity is 64. - std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " + - std::to_string(value); - value_.emplace_back(data); - return Status::OK(); -} - -Status DatasetIteratorTracing::SaveToFile() { - if (value_.empty()) { - return Status::OK(); - } - - std::ofstream handle(file_path_, std::ios::trunc); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Profiling file can not be opened."); - } - for (auto value : value_) { - handle << value << "\n"; - } - handle.close(); - - return Status::OK(); -} - -Status DatasetIteratorTracing::Init(const std::string &dir_path, const std::string &device_id) { - file_path_ = (Path(dir_path) / Path("dataset_iterator_profiling_" + device_id + ".txt")).toString(); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h b/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h deleted file mode 100644 index 129863c6d1..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_DATASET_ITERATOR_TRACING_H -#define MINDSPORE_DATASET_ITERATOR_TRACING_H - -#include -#include -#include "dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -class DatasetIteratorTracing : public Tracing { - public: - // Constructor - DatasetIteratorTracing() = default; - - // Destructor - ~DatasetIteratorTracing() override = default; - - // Record tracing data - // @return Status - The error code return - Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); - - std::string Name() const override { return kDatasetIteratorTracingName; }; - - // Save tracing data to file - // @return Status - The error code return - Status SaveToFile() override; - - Status Init(const std::string &dir_path, const std::string &device_id) override; - - private: - std::vector value_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_DATASET_ITERATOR_TRACING_H diff --git a/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.cc b/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.cc deleted file mode 100644 index 204a83e3fb..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "dataset/engine/perf/device_queue_tracing.h" -#include "dataset/util/path.h" -namespace mindspore { -namespace dataset { - -Status DeviceQueueTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, - const int32_t value) { - // Format: "type extra-info batch-num value" - // type: 0: time, 1: connector size - // extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time - // if type is 1 - connector capacity - // batch-num: batch number - // value: if type is 0 - value is time(ms) - // if type is 1 - value is connector size - // Examples: - // 0 0 20 10 - The 20th batch took 10ms to get data from pipeline. - // 1 64 20 5 - Connector size is 5 when get the 20th batch.Connector capacity is 64. - std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " + - std::to_string(value); - value_.emplace_back(data); - return Status::OK(); -} - -Status DeviceQueueTracing::SaveToFile() { - if (value_.empty()) { - return Status::OK(); - } - - std::ofstream handle(file_path_, std::ios::trunc); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Profiling file can not be opened."); - } - for (auto value : value_) { - handle << value << "\n"; - } - handle.close(); - - return Status::OK(); -} - -Status DeviceQueueTracing::Init(const std::string &dir_path, const std::string &device_id) { - file_path_ = (Path(dir_path) / Path("device_queue_profiling_" + device_id + ".txt")).toString(); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h b/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h deleted file mode 100644 index 13ef7121c1..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_DEVICE_QUEUE_TRACING_H -#define MINDSPORE_DEVICE_QUEUE_TRACING_H - -#include -#include -#include "dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -class DeviceQueueTracing : public Tracing { - public: - // Constructor - DeviceQueueTracing() = default; - - // Destructor - ~DeviceQueueTracing() override = default; - - // Record tracing data - // @return Status - The error code return - Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); - - std::string Name() const override { return kDeviceQueueTracingName; }; - - // Save tracing data to file - // @return Status - The error code return - Status SaveToFile() override; - - Status Init(const std::string &dir_path, const std::string &device_id) override; - - private: - std::vector value_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_DEVICE_QUEUE_TRACING_H diff --git a/mindspore/ccsrc/dataset/engine/perf/monitor.cc b/mindspore/ccsrc/dataset/engine/perf/monitor.cc deleted file mode 100644 index 8a0d682b81..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/monitor.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/core/config_manager.h" -#include "dataset/engine/perf/monitor.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { - -Monitor::Monitor(ExecutionTree *tree) : tree_(tree) { - std::shared_ptr cfg = GlobalContext::config_manager(); - sampling_interval_ = cfg->monitor_sampling_interval(); - max_samples_ = 0; - cur_row_ = 0; -} -Status Monitor::operator()() { - // Register this thread with TaskManager to receive proper interrupt signal. - TaskManager::FindMe()->Post(); - - // Keep sampling if - // 1) Monitor Task is not interrupted by TaskManager AND - // 2) Iterator has not received EOF - while (!this_thread::is_interrupted() && !(tree_->isFinished())) { - for (auto &node : tree_->GetProfilingManager()->GetSamplingNodes()) { - RETURN_IF_NOT_OK(node.second->Sample()); - std::this_thread::sleep_for(std::chrono::milliseconds(sampling_interval_)); - } - } - - // Output all profiling data upon request. - tree_->GetProfilingManager()->SaveProfilingData(); - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/monitor.h b/mindspore/ccsrc/dataset/engine/perf/monitor.h deleted file mode 100644 index 8b4245db8e..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/monitor.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MONITOR_H -#define MINDSPORE_MONITOR_H - -#include -#include -#include -#include "dataset/util/status.h" -#include "dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -class ExecutionTree; -class Monitor { - public: - // Monitor object constructor - - explicit Monitor(ExecutionTree *tree); - - Monitor() = default; - - ~Monitor() = default; - - // Functor for Perf Monitor main loop. - // This function will be the entry point of mindspore::Dataset::Task - Status operator()(); - - int64_t GetSamplingInterval() { return sampling_interval_; } - - private: - int64_t cur_row_; - int64_t max_samples_; - int64_t sampling_interval_; - ExecutionTree *tree_; - std::vector> sampling_list_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_MONITOR_H diff --git a/mindspore/ccsrc/dataset/engine/perf/perf_data.h b/mindspore/ccsrc/dataset/engine/perf/perf_data.h deleted file mode 100644 index a201d705ea..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/perf_data.h +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_PERF_DATA_H -#define DATASET_PERF_DATA_H - -#include -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { - -// PerfData is a convenience class to record and store the data produced by Monitor -// and represents a 2D column major table with every column storing samples -// for an operator. The number of rows equals to the number of samples, -// the number of columns equals to the number of operators. -// The capacity is determined on construction and cannot be changed. -// ColumnType can be std::vector or CyclicArray. In case of the latter data can be added -// indefinitely without the risk of overflowing otherwise the capacity must not be exceeded. -// Given PerfData pd(n_rows, n_cols) an element in the column i and row j can be accessed as -// pd[i][j] - -template -class PerfData { - public: - PerfData() = default; - ~PerfData() = default; - PerfData(dsize_t max_rows, dsize_t n_cols) : counter_(0), max_rows_(max_rows), n_cols_(n_cols) { - for (auto i = 0; i < n_cols_; i++) { - data_.push_back(ColumnType(max_rows_)); - } - } - PerfData(const PerfData &rhs) = default; - PerfData(PerfData &&rhs) = default; - - // Adds a row of data - // T must be any container working with range based loops - template - void AddSample(const T &row) { - auto i = 0; - for (const auto &e : row) { - data_[i++].push_back(e); - } - counter_++; - } - - // Fetches a row of data by copy - template - auto Row(dsize_t idx) { - std::vector row(n_cols_); - for (auto i = 0; i < n_cols_; i++) { - row[i] = data_[i][idx]; - } - return row; - } - - // returns a column of data - ColumnType &operator[](size_t idx) { return data_[idx]; } - - const ColumnType &operator[](size_t idx) const { return data_[idx]; } - - dsize_t size() { return counter_ < max_rows_ ? counter_ : max_rows_; } - - dsize_t capacity() { return max_rows_; } - - private: - std::vector data_; - dsize_t counter_; - dsize_t max_rows_; - int n_cols_; -}; - -} // namespace dataset -} // namespace mindspore -#endif // DATASET_PERF_DATA_H diff --git a/mindspore/ccsrc/dataset/engine/perf/profiling.cc b/mindspore/ccsrc/dataset/engine/perf/profiling.cc deleted file mode 100644 index 66f27c46ba..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/profiling.cc +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/perf/profiling.h" -#include -#include -#include -#include "common/utils.h" -#include "dataset/util/path.h" -#include "dataset/engine/perf/monitor.h" -#include "dataset/engine/perf/device_queue_tracing.h" -#include "dataset/engine/perf/connector_size.h" -#include "dataset/engine/perf/connector_throughput.h" -#include "dataset/engine/perf/dataset_iterator_tracing.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { - -bool ProfilingManager::IsProfilingEnable() const { - auto profiling = common::GetEnv("PROFILING_MODE"); - if (profiling.empty() || profiling != "true") { - return false; - } - return true; -} - -Status ProfilingManager::Initialize() { - // Register nodes based on config - std::string dir = common::GetEnv("MINDDATA_PROFILING_DIR"); - if (dir.empty()) { - RETURN_STATUS_UNEXPECTED("Profiling dir is not set."); - } - char real_path[PATH_MAX] = {0}; - if (dir.size() >= PATH_MAX) { - RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); - } -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) { - RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); - } -#else - if (realpath(common::SafeCStr(dir), real_path) == nullptr) { - RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); - } -#endif - dir_path_ = real_path; - - // If DEVICE_ID is not set,defult value is 0 - device_id_ = common::GetEnv("DEVICE_ID"); - if (device_id_.empty()) { - device_id_ = "0"; - } - - // Register all profiling node. - // device_queue node is used for graph mode - std::shared_ptr device_queue_tracing = std::make_shared(); - RETURN_IF_NOT_OK(RegisterTracingNode(device_queue_tracing)); - // dataset_iterator node is used for graph mode - std::shared_ptr dataset_iterator_tracing = std::make_shared(); - RETURN_IF_NOT_OK(RegisterTracingNode(dataset_iterator_tracing)); - - std::shared_ptr connector_size_sampling = std::make_shared(tree_); - RETURN_IF_NOT_OK(RegisterSamplingNode(connector_size_sampling)); - - std::shared_ptr connector_thr_sampling = std::make_shared(tree_); - RETURN_IF_NOT_OK(RegisterSamplingNode(connector_thr_sampling)); - return Status::OK(); -} - -// Profiling node registration -Status ProfilingManager::RegisterTracingNode(std::shared_ptr node) { - // Check if node with the same name has already been registered. - auto exist = tracing_nodes_.find(node->Name()); - if (exist != tracing_nodes_.end()) { - return Status(StatusCode::kProfilingError, "Profiling node already exist: " + node->Name()); - } - // Register the node with its name as key. - RETURN_IF_NOT_OK(node->Init(dir_path_, device_id_)); - tracing_nodes_[node->Name()] = node; - return Status::OK(); -} - -// Profiling node getter -Status ProfilingManager::GetTracingNode(const std::string &name, std::shared_ptr *node) { - // Check if node with the same name has already been registered. - auto exist = tracing_nodes_.find(name); - if (exist == tracing_nodes_.end()) { - return Status(StatusCode::kProfilingError, "Profiling node does not exist: " + name); - } - // Fetch node. - *node = tracing_nodes_[name]; - return Status::OK(); -} - -// Profiling node registration -Status ProfilingManager::RegisterSamplingNode(std::shared_ptr node) { - // Check if node with the same name has already been registered. - auto exist = sampling_nodes_.find(node->Name()); - if (exist != sampling_nodes_.end()) { - return Status(StatusCode::kProfilingError, "Profiling node already exist: " + node->Name()); - } - // Register the node with its name as key. - RETURN_IF_NOT_OK(node->Init(dir_path_, device_id_)); - sampling_nodes_[node->Name()] = node; - return Status::OK(); -} - -// Profiling node getter -Status ProfilingManager::GetSamplingNode(const std::string &name, std::shared_ptr *node) { - // Check if node with the same name has already been registered. - auto exist = sampling_nodes_.find(name); - if (exist == sampling_nodes_.end()) { - return Status(StatusCode::kProfilingError, "Profiling node does not exist: " + name); - } - // Fetch node. - *node = sampling_nodes_[name]; - return Status::OK(); -} - -Status ProfilingManager::SaveProfilingData() { - if (!IsProfilingEnable()) { - return Status::OK(); - } - MS_LOG(INFO) << "Start to save profiling data."; - for (auto node : tracing_nodes_) { - RETURN_IF_NOT_OK(node.second->SaveToFile()); - } - for (auto node : sampling_nodes_) { - RETURN_IF_NOT_OK(node.second->SaveToFile()); - } - MS_LOG(INFO) << "Save profiling data end."; - return Status::OK(); -} - -int64_t ProfilingTime::GetCurMilliSecond() { - // because cpplint does not allow using namespace - using std::chrono::duration_cast; - using std::chrono::milliseconds; - using std::chrono::steady_clock; - return duration_cast(steady_clock::now().time_since_epoch()).count(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/profiling.h b/mindspore/ccsrc/dataset/engine/perf/profiling.h deleted file mode 100644 index e38c2d5e54..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/profiling.h +++ /dev/null @@ -1,144 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_PROFILE_H_ -#define DATASET_UTIL_PROFILE_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class Monitor; -class ExecutionTree; - -const char kDeviceQueueTracingName[] = "Device_Queue_Tracing"; -const char kDatasetIteratorTracingName[] = "Dataset_Iterator_Tracing"; -const char kConnectorSizeSamplingName[] = "Connector_Size_Sampling"; -const char kConnectorThroughputSamplingName[] = "Connector_Throughput_Sampling"; - -// Profiling is a class of basic unit of profiling action -// This base class encapsulate the serialization output logic -class Profiling : std::enable_shared_from_this { - public: - // Constructor - Profiling() = default; - - // Destructor - virtual ~Profiling() = default; - - virtual Status Init(const std::string &dir_path, const std::string &device_id) = 0; - - // Default serialization file generator - virtual Status SaveToFile() = 0; - - // Profiling name - virtual std::string Name() const = 0; - - protected: - std::string file_path_; -}; - -// Sampling is a class of profiling which generate samples periodically. -class Sampling : public Profiling { - public: - // Sampling action function. This function will be invoked by performance monitor thread. - virtual Status Sample() = 0; - // virtual Status TestPrint() = 0; - virtual ~Sampling() = default; -}; - -// Tracing is class of profiling which record samples upon request. -class Tracing : public Profiling { - // Tracing does not define a fixed interface to provide flexible on data recording. -}; - -// ProfilingManager is a class manages all profiling infrastructure -// It serves the following purposes: -// 1) Fetch profiling configs from global contexts -// 2) Setup all profiling node based on config -// 3) Provide access of profiling nodes for profiling actions -// 4) Manage profiling data serialization process -class ProfilingManager { - public: - explicit ProfilingManager(ExecutionTree *tree) : tree_(tree) {} - - ~ProfilingManager() = default; - - Status Initialize(); - - // Save profile data to file - // @return Status - The error code return - Status SaveProfilingData(); - - // Sampling node getter - // @param name - The name of the requested node - // @param node - Pointer to the shared pointer for the Sampling node - // @return Status - The error code return - Status GetSamplingNode(const std::string &name, std::shared_ptr *node); - - // Tracing node getter - // @param name - The name of the requested node - // @param node - Pointer to the shared pointer for the Tracing node - // @return Status - The error code return - Status GetTracingNode(const std::string &name, std::shared_ptr *node); - - // If profiling is enabled. - bool IsProfilingEnable() const; - - const std::unordered_map> &GetSamplingNodes() { return sampling_nodes_; } - - private: - std::unordered_map> tracing_nodes_; - - std::unordered_map> sampling_nodes_; - - // Register profile node to tree - // @param node - Profiling node - // @return Status - The error code return - Status RegisterTracingNode(std::shared_ptr node); - - // Register profile node to tree - // @param node - Profiling node - // @return Status - The error code return - Status RegisterSamplingNode(std::shared_ptr node); - - ExecutionTree *tree_ = nullptr; // ExecutionTree pointer - std::string dir_path_; // where to create profiling file - std::string device_id_; // used when create profiling file,filename_deviceid.suffix -}; - -enum ProfilingType { TIME, CONNECTOR_DEPTH }; - -enum ProfilingTimeSubType { - PIPELINE_TIME, - TDT_PUSH_TIME, - BATCH_TIME, - INVALID_TIME, -}; - -class ProfilingTime { - public: - static int64_t GetCurMilliSecond(); -}; - -} // namespace dataset -} // namespace mindspore -#endif diff --git a/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc deleted file mode 100644 index ca9f2176f5..0000000000 --- a/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/tdt/tdt_plugin.h" -#include "common/utils.h" -#include "utils/log_adapter.h" -#include "dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -static std::shared_ptr instance_ptr_ = nullptr; - -std::shared_ptr TdtPlugin::GetInstance() { - if (instance_ptr_ == nullptr) { - instance_ptr_ = std::shared_ptr(new TdtPlugin); - } - return instance_ptr_; -} - -TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) { - MS_LOG(DEBUG) << "TDT channel name is " << channel_name << "."; - std::vector items; - double start_time; - auto ret = translate(ts_row, items); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "TDT converting tensor failed!"; - return FAILED; - } - if (profiling) { - start_time = ProfilingTime::GetCurMilliSecond(); - } - if (tdt::TdtHostPushData(channel_name, items) != 0) { - MS_LOG(ERROR) << "TDT pushing data failed!"; - return FAILED; - } - if (profiling) { - double end_time = ProfilingTime::GetCurMilliSecond(); - time = (int32_t)(end_time - start_time); - } - return SUCCESS; -} - -TdtStatus TdtPlugin::getTdtType(DataType d_type, std::string &datatype) { - switch (d_type.value()) { - case DataType::DE_BOOL: - datatype = "bool"; - break; - case DataType::DE_INT8: - datatype = "int8"; - break; - case DataType::DE_UINT8: - datatype = "uint8"; - break; - case DataType::DE_INT16: - datatype = "int16"; - break; - case DataType::DE_UINT16: - datatype = "uint16"; - break; - case DataType::DE_INT32: - datatype = "int32"; - break; - case DataType::DE_UINT32: - datatype = "uint32"; - break; - case DataType::DE_FLOAT16: - datatype = "float16"; - break; - case DataType::DE_FLOAT32: - datatype = "float32"; - break; - case DataType::DE_FLOAT64: - datatype = "float64"; - break; - case DataType::DE_INT64: - datatype = "int64"; - break; - case DataType::DE_UINT64: - datatype = "uint64"; - break; - default: - return FAILED; - } - return SUCCESS; -} - -TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector &items) { - if (ts_row.size() == 0) { - MS_LOG(ERROR) << "TDT the size of row is zero."; - return SUCCESS; - } - for (auto ts : ts_row) { - std::string datatype; - TdtStatus status = getTdtType(ts->type(), datatype); - if (status != SUCCESS) { - return status; - } - TensorShape tsShape = ts->shape(); - std::string dataShapes = "["; - for (auto dim : tsShape.AsVector()) { - (void)dataShapes.append(std::to_string(dim)).append(","); - } - dataShapes.pop_back(); - (void)dataShapes.append("]"); - DataItem data_item; - data_item.dataType_ = tdt::TDT_TENSOR; - data_item.tensorShape_ = dataShapes; - data_item.tensorType_ = datatype; - data_item.dataLen_ = ts->SizeInBytes(); - data_item.dataPtr_ = - std::shared_ptr(reinterpret_cast(&(*ts->begin())), [](const void *elem) {}); - items.emplace_back(data_item); - MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is " - << ts->Size() << "."; - } - return SUCCESS; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h b/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h deleted file mode 100644 index 304b205b81..0000000000 --- a/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_TDT_TDT_PLUGIN_H_ -#define DATASET_ENGINE_TDT_TDT_PLUGIN_H_ - -#include -#include -#include -#include -#include -#include -#include "tdt/tdt_host_interface.h" - -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_row.h" - -namespace mindspore { -namespace dataset { -enum TdtStatus { SUCCESS, FAILED }; - -using tdt::DataItem; - -class TdtPlugin { - public: - static std::shared_ptr GetInstance(); - - TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time); - - private: - TdtPlugin() {} - - TdtStatus getTdtType(DataType d_type, std::string &datatype); - - TdtStatus translate(const TensorRow &ts_row, std::vector &items); - - void *tdt_handle_ = nullptr; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_TDT_TDT_PLUGIN_H_ diff --git a/mindspore/ccsrc/dataset/kernels/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/CMakeLists.txt deleted file mode 100644 index 2ebdd15e3c..0000000000 --- a/mindspore/ccsrc/dataset/kernels/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -add_subdirectory(image) -add_subdirectory(data) -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(kernels OBJECT - py_func_op.cc - tensor_op.cc) -target_include_directories(kernels PRIVATE ${pybind11_INCLUDE_DIRS}) diff --git a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc b/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc deleted file mode 100644 index 87115fd3ce..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/data/concatenate_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -Status ConcatenateOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - RETURN_IF_NOT_OK(Concatenate(input, output, axis_, prepend_, append_)); - return Status::OK(); -} - -Status ConcatenateOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - - std::vector inputs_copy; - inputs_copy.push_back(inputs[0].Squeeze()); - - CHECK_FAIL_RETURN_UNEXPECTED(inputs.at(0).Rank() == 1, "Only 1D input tensors supported"); - - outputs.clear(); - dsize_t output_shape = 0; - output_shape = output_shape + inputs.at(0).NumOfElements(); - if (prepend_ != nullptr) { - CHECK_FAIL_RETURN_UNEXPECTED(prepend_->shape().Rank() == 1, "Only 1D prepend tensors supported"); - output_shape = output_shape + prepend_->shape().NumOfElements(); - } - if (append_ != nullptr) { - CHECK_FAIL_RETURN_UNEXPECTED(append_->shape().Rank() == 1, "Only 1D append tensors supported"); - output_shape = output_shape + append_->shape().NumOfElements(); - } - - outputs.emplace_back(std::vector{output_shape}); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h b/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h deleted file mode 100644 index 4e4c7ad4e0..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_ -#define DATASET_KERNELS_DATA_CONCATENATE_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -class ConcatenateOp : public TensorOp { - public: - /// Constructor to ConcatenateOp. - /// @param int8_t axis - axis to concatenate tensors along. - /// @param std::shared_ptr prepend - prepend tensor. - /// @param std::shared_ptr append -append tensor. - explicit ConcatenateOp(int8_t axis, std::shared_ptr prepend, std::shared_ptr append) - : axis_(axis), prepend_(prepend), append_(append) {} - - ~ConcatenateOp() override = default; - - /// Print method to see which tensor Op this is. - /// @param std::ostream &out - output stream object. - void Print(std::ostream &out) const override { out << "ConcatenateOp"; } - - /// Compute method allowing multiple tensors as inputs - /// @param TensorRow &input - input tensor rows - /// @param TensorRow *output - output tensor rows - Status Compute(const TensorRow &input, TensorRow *output) override; - - /// Compute tensor output shape - /// @param std::vector &inputs - vector of input tensor shapes - /// @param std::vector &inputs, std::vector &outputs) override; - - /// Number of inputs the tensor operation accepts - uint32_t NumInput() override { return 0; } - - private: - int8_t axis_; - std::shared_ptr prepend_; - std::shared_ptr append_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CONCATENATE_OP_H diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc deleted file mode 100644 index 40eba1edf6..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc +++ /dev/null @@ -1,649 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/data/data_utils.h" - -#include -#include -#include -#include - -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#include "dataset/core/pybind_support.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status OneHotEncodingUnsigned(const std::shared_ptr &input, std::shared_ptr *output, - dsize_t num_classes, int64_t index) { - uint64_t class_idx; - if (input->Rank() == 0) { - RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {})); - } else { - RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {index})); - } - if (class_idx >= static_cast(num_classes)) { - RETURN_STATUS_UNEXPECTED("One_hot index values are not in range"); - } - if (input->type() == DataType::DE_UINT64) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_UINT32) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_UINT16) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_UINT8) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else { - RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input."); - } - return Status::OK(); -} - -Status OneHotEncodingSigned(const std::shared_ptr &input, std::shared_ptr *output, dsize_t num_classes, - int64_t index) { - int64_t class_idx; - if (input->Rank() == 0) { - RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {})); - } else { - RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {index})); - } - if (class_idx >= static_cast(num_classes)) { - RETURN_STATUS_UNEXPECTED("One_hot index values are not in range"); - } - if (input->type() == DataType::DE_INT64) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_INT32) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_INT16) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_INT8) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else { - RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input."); - } - return Status::OK(); -} - -Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *output, dsize_t num_classes) { - input->Squeeze(); - - if (input->Rank() > 1) { // We expect the input to be int he first dimension - RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors."); - } - if (!input->type().IsInt()) { - RETURN_STATUS_UNEXPECTED("One hot does not support input of this type."); - } - try { - dsize_t num_elements = 1; - if (input->Rank() == 1) num_elements = input->shape()[0]; - TensorShape out_shape({num_elements, num_classes}); - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type())); - RETURN_IF_NOT_OK(out->Zero()); - for (dsize_t i = 0; i < num_elements; ++i) { - if (input->type().IsUnsignedInt()) { - RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i)); - } else { - RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i)); - } - } - out->Squeeze(); - *output = out; - return Status::OK(); - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp"); - } -} - -Status Fill(const std::shared_ptr input, std::shared_ptr *output, std::shared_ptr fill_value) { - CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)), - "Types do not match"); - - CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar"); - - std::shared_ptr out; - - const DataType &to = input->type(); - std::unique_ptr op(new TypeCastOp(to)); - - std::shared_ptr fill_output; - RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output)); - - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type())); - - switch (input->type().value()) { - case DataType::DE_BOOL: { - bool value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_INT8: { - int8_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_UINT8: { - uint8_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_UINT16: { - uint16_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_INT16: { - int16_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_UINT32: { - uint32_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_INT32: { - int32_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_UINT64: { - uint64_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_INT64: { - int64_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_FLOAT16: { - int64_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_FLOAT32: { - float value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_FLOAT64: { - double value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_STRING: { - std::vector strings; - std::string_view fill_string_view; - RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {})); - std::string fill_string = std::string(fill_string_view); - for (int i = 0; i < input->shape().NumOfElements(); i++) { - strings.emplace_back(fill_string); - } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape())); - break; - } - case DataType::DE_UNKNOWN: { - RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type."); - break; - } - } - - *output = out; - return Status::OK(); -} -template -void Cast(const std::shared_ptr &input, std::shared_ptr *output) { - auto in_itr = input->begin(); - auto out_itr = (*output)->begin(); - auto out_end = (*output)->end(); - - for (; out_itr != out_end; static_cast(in_itr++), static_cast(out_itr++)) - *out_itr = static_cast(*in_itr); -} - -template -void CastFrom(const std::shared_ptr &input, std::shared_ptr *output) { - switch ((*output)->type().value()) { - case DataType::DE_BOOL: - Cast(input, output); - break; - case DataType::DE_INT8: - Cast(input, output); - break; - case DataType::DE_UINT8: - Cast(input, output); - break; - case DataType::DE_INT16: - Cast(input, output); - break; - case DataType::DE_UINT16: - Cast(input, output); - break; - case DataType::DE_INT32: - Cast(input, output); - break; - case DataType::DE_UINT32: - Cast(input, output); - break; - case DataType::DE_INT64: - Cast(input, output); - break; - case DataType::DE_UINT64: - Cast(input, output); - break; - case DataType::DE_FLOAT16: - Cast(input, output); - break; - case DataType::DE_FLOAT32: - Cast(input, output); - break; - case DataType::DE_FLOAT64: - Cast(input, output); - break; - case DataType::DE_UNKNOWN: - MS_LOG(ERROR) << "Unknown data type."; - break; - } -} - -// Type cast operator -Status TypeCast(const std::shared_ptr &input, std::shared_ptr *output, const DataType &data_type) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type)); - - RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); - switch (input->type().value()) { - case DataType::DE_BOOL: - CastFrom(input, output); - break; - case DataType::DE_INT8: - CastFrom(input, output); - break; - case DataType::DE_UINT8: - CastFrom(input, output); - break; - case DataType::DE_INT16: - CastFrom(input, output); - break; - case DataType::DE_UINT16: - CastFrom(input, output); - break; - case DataType::DE_INT32: - CastFrom(input, output); - break; - case DataType::DE_UINT32: - CastFrom(input, output); - break; - case DataType::DE_INT64: - CastFrom(input, output); - break; - case DataType::DE_UINT64: - CastFrom(input, output); - break; - case DataType::DE_FLOAT16: - CastFrom(input, output); - break; - case DataType::DE_FLOAT32: - CastFrom(input, output); - break; - case DataType::DE_FLOAT64: - CastFrom(input, output); - break; - case DataType::DE_UNKNOWN: - // sanity check, unreachable code. - RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type."); - } - return Status::OK(); -} - -Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { - // initiate new tensor for type cast - DataType new_type = DataType("float16"); - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type)); - RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); - - auto in_itr = input->begin(); - auto out_itr = (*output)->begin(); - auto out_end = (*output)->end(); - - for (; out_itr != out_end; in_itr++, out_itr++) { - float element = *in_itr; - float float16_max = static_cast(std::numeric_limits::max()); - float float16_min = static_cast(std::numeric_limits::lowest()); - if (element > float16_max || element < float16_min) { - RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" + - std::to_string(float16_max) + ", " + std::to_string(float16_min) + "]."); - } - - *out_itr = Eigen::half(*in_itr); - } - - return Status::OK(); -} - -Status PadEnd(const std::shared_ptr &src, std::shared_ptr *dst, const std::vector &pad_shape, - const std::shared_ptr &pad_val) { - if (pad_val == nullptr) { - if (src->type().IsNumeric()) { - return PadEndNumeric(src, dst, pad_shape, 0); - } else { - return PadEndString(src, dst, pad_shape, ""); - } - } - CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(), - "Source and pad_value tensors are not of the same type."); - if (pad_val->type().IsNumeric()) { - std::shared_ptr float_pad_value; - RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32))); - float val = 0; - RETURN_IF_NOT_OK(float_pad_value->GetItemAt(&val, {})); - return PadEndNumeric(src, dst, pad_shape, val); - } - std::string_view val; - RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {})); - return PadEndString(src, dst, pad_shape, std::string(val)); -} - -Status PadEndNumeric(const std::shared_ptr &src, std::shared_ptr *dst, - const std::vector &pad_shape, float pad_val) { - CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr"); - if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) { - (*dst) = src; // if no padding, copy the pointer - } else { - CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); - RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type())); - auto tensor_type = src->type().value(); - if (pad_val == 0) { // if pad with zero, don't care what type it is - RETURN_IF_NOT_OK((*dst)->Zero()); - } else if (tensor_type == DataType::DE_INT8) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_BOOL) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_UINT8) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_INT16) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_FLOAT16) { - RETURN_IF_NOT_OK((*dst)->Fill(static_cast(pad_val))); - } else if (tensor_type == DataType::DE_UINT16) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_INT32) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_UINT32) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_INT64) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_UINT64) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_FLOAT32) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else if (tensor_type == DataType::DE_FLOAT64) { - RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); - } else { - RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type"); - } - std::vector cur_ind(src->Rank(), 0); - RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0)); - } - return Status::OK(); -} -Status PadEndNumericHelper(const std::shared_ptr &src, std::shared_ptr dst, - std::vector cur_ind, size_t cur_dim) { - if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data - dst->CopyLastDimAt(src, cur_ind); - } else { // not the last dimension, keep doing recursion - dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); - for (dsize_t i = 0; i < min_ind; i++) { - cur_ind[cur_dim] = i; - RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1)); - } - } - return Status::OK(); -} - -Status PadEndString(const std::shared_ptr &src, std::shared_ptr *dst, - const std::vector &pad_shape, const std::string &pad_val) { - CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr"); - if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) { - (*dst) = src; // if no padding, copy the pointer - } else { - CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); - std::vector cur_ind(src->Rank(), 0); - std::vector strings; - RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape))); - } - return Status::OK(); -} - -Status PadEndStringHelper(const std::shared_ptr &src, std::vector *dst, - const TensorShape &dst_shape, std::vector cur_ind, size_t cur_dim, - const std::string &pad_value) { - if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data - dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]); - for (dsize_t i = 0; i < min_ind; i++) { - cur_ind[cur_dim] = i; - std::string_view item; - RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind)); - dst->emplace_back(item); - } - for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) { - dst->emplace_back(pad_value); - } - - } else { // not the last dimension, keep doing recursion - dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]); - for (dsize_t i = 0; i < min_ind; i++) { - cur_ind[cur_dim] = i; - RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value)); - } - dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim]; - for (dsize_t i = 0; i < count; i++) { - dst->emplace_back(pad_value); - } - } - return Status::OK(); -} - -template -Status MaskHelper(const std::shared_ptr &input, const std::shared_ptr &output, - const std::shared_ptr &value_tensor, RelationalOp op) { - T value; - RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {})); - auto in_itr = input->begin(); - auto out_itr = output->begin(); - for (; in_itr != input->end(); in_itr++, out_itr++) { - switch (op) { - case RelationalOp::kEqual: - *out_itr = (*in_itr == value); - break; - case RelationalOp::kNotEqual: - *out_itr = (*in_itr != value); - break; - case RelationalOp::kGreater: - *out_itr = (*in_itr > value); - break; - case RelationalOp::kGreaterEqual: - *out_itr = (*in_itr >= value); - break; - case RelationalOp::kLess: - *out_itr = (*in_itr < value); - break; - case RelationalOp::kLessEqual: - *out_itr = (*in_itr <= value); - break; - default: - RETURN_STATUS_UNEXPECTED("Unknown relational operator."); - } - } - return Status::OK(); -} - -Status Mask(const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &value, - RelationalOp op) { - CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(), - "Cannot convert constant value to the type of the input tensor."); - CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar"); - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL))); - - std::unique_ptr value_cast_op(new TypeCastOp(input->type())); - std::shared_ptr casted_value; - if (input->type().IsNumeric()) { - RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value)); - } else { - casted_value = value; - } - - switch (input->type().value()) { - case DataType::DE_BOOL: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_INT8: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_UINT8: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_UINT16: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_INT16: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_UINT32: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_INT32: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_UINT64: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_INT64: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_FLOAT16: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_FLOAT32: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_FLOAT64: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_STRING: - RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); - break; - case DataType::DE_UNKNOWN: - RETURN_STATUS_UNEXPECTED("Unsupported input type."); - break; - } - return Status::OK(); -} - -Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr prepend, - std::shared_ptr append) { - CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported"); - CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported"); - - axis = Tensor::HandleNeg(axis, input[0]->shape().Rank()); - CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported"); - - std::shared_ptr out; - if (prepend != nullptr) { - CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported"); - RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0])); - } else { - out = input[0]; - } - for (dsize_t i = 1; i < input.size(); i++) { - std::shared_ptr out_t; - CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported"); - RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i])); - out = out_t; - } - std::shared_ptr out_t; - if (append != nullptr) { - CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported"); - RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append)); - } else { - out_t = out; - } - output->push_back(out_t); - - return Status::OK(); -} - -Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr *output, int8_t axis, - std::shared_ptr append) { - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match"); - - TensorShape t({}); - - for (dsize_t i = 0; i < input->shape().Rank(); i++) { - if (i != axis) { - t = t.AppendDim(input->shape()[i]); - } else { - dsize_t new_shape = input->shape()[i] + append->shape()[i]; - - t = t.AppendDim(new_shape); - } - } - std::shared_ptr out; - - if (input->type().IsNumeric()) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type())); - - RETURN_IF_NOT_OK(out->Concatenate({0}, input)); - RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append)); - *output = out; - } else { - std::vector strings; - - auto itr = input->begin(); - for (; itr != input->end(); itr++) { - strings.emplace_back(*itr); - } - itr = append->begin(); - for (; itr != append->end(); itr++) { - strings.emplace_back(*itr); - } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t)); - - *output = out; - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/dataset/kernels/data/data_utils.h deleted file mode 100644 index 6034e2a0eb..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/data_utils.h +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_DATA_UTILS_H_ -#define DATASET_KERNELS_DATA_DATA_UTILS_H_ - -#include -#include -#include -#include "dataset/core/constants.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_row.h" - -namespace mindspore { -namespace dataset { -// Returns Onehot encoding of the input tensor. -// Example: if input=2 and numClasses=3, the output is [0 0 1]. -// @param input: Tensor has type DE_UINT64, the non-one hot values are stored -// along the first dimensions or rows.. -// If the rank of input is not 1 or the type is not DE_UINT64, -// then it will fail. -// @param output: Tensor. The shape of the output tensor is -// and the type is same as input. -// @param num_classes: Number of classes to. -Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *output, dsize_t num_classes); - -Status OneHotEncodingUnsigned(const std::shared_ptr &input, std::shared_ptr *output, - dsize_t num_classes, int64_t index); - -Status OneHotEncodingSigned(const std::shared_ptr &input, std::shared_ptr *output, dsize_t num_classes, - int64_t index); - -// Returns a tensor of shape input filled with the passed fill_value -// @param input Tensor -// @param output Tensor. The shape and type of the output tensor is same as input -// @param fill_value Tensor. A scalar tensor used to fill the output tensor - -Status Fill(const std::shared_ptr input, std::shared_ptr *output, std::shared_ptr fill_value); - -// Returns a type changed input tensor. -// Example: if input tensor is float64, the output will the specified dataType. See DataTypes.cpp -// @param input Tensor -// @param output Tensor. The shape of the output tensor is same as input with the type changed. -// @param data_type: type of data to cast data to -// @note: this operation will do a memcpy and if the value is truncated then precision will be lost - -template -void CastFrom(const std::shared_ptr &input, std::shared_ptr *output); - -template -void Cast(const std::shared_ptr &input, std::shared_ptr *output); - -Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output); - -Status TypeCast(const std::shared_ptr &input, std::shared_ptr *output, const DataType &data_type); - -// Pad input tensor according pad_shape, need to have same rank. -// Based on the type of the input tensor, PadEndNumeric/String will be called. -// @param std::shared_ptr src - tensor to pad from -// @param std::shared_ptr *dst - return tensor padded -// @param std::vector pad_shape - shape to pad to -// @param std::shared_ptr pad_val - value to pad with in Tensor format, -// @return - The error code return -Status PadEnd(const std::shared_ptr &src, std::shared_ptr *dst, const std::vector &pad_shape, - const std::shared_ptr &pad_val); - -// Pad input numeric tensor according pad_shape, need to have same rank. -// @param std::shared_ptr src - tensor to pad from -// @param std::shared_ptr *dst - return tensor padded -// @param std::vector pad_shape - shape to pad to -// @param float pad_val - value to pad with -// @return - The error code return -Status PadEndNumeric(const std::shared_ptr &src, std::shared_ptr *dst, - const std::vector &pad_shape, float pad_val); - -// recursive helper function for padding numric tensors. This function could be very expensive if called on a -// multi-dimensional tensor it is only meant to be called by PadEndNumeric. -// @tparam T - type of tensor and fill value -// @param std::shared_ptr src - Tensor to pad from -// @param std::shared_ptr* dst - Tensor to pad to, return value -// @param std::vector cur_ind - recursion helper -// @param T pad_val - value to pad tensor with -// @param size_t cur_dim - recursion helper -// @return Status - The error code return -Status PadEndNumericHelper(const std::shared_ptr &src, std::shared_ptr dst, - std::vector cur_ind, size_t cur_dim = 0); - -// Pad input string tensor according pad_shape, need to have same rank. -// @param std::shared_ptr src - tensor to pad from -// @param std::shared_ptr *dst - return tensor padded -// @param std::vector pad_shape - shape to pad to -// @param std::string pad_val - value to pad with -// @return - The error code return -Status PadEndString(const std::shared_ptr &src, std::shared_ptr *dst, - const std::vector &pad_shape, const std::string &pad_val); - -// recursive helper function for padding string tensors. This function could be very expensive if called on a -// multi-dimensional tensor it is only meant to be called by PadEndString. -// @tparam T - type of tensor and fill value -// @param std::shared_ptr src - Tensor to pad from -// @param std::shared_ptr* dst - Tensor to pad to, return value -// @param std::vector cur_ind - recursion helperas text -// @param std::string pad_val - value to pad tensor with -// @param size_t cur_dim - recursion helper -// @return Status - The error code return -Status PadEndStringHelper(const std::shared_ptr &src, std::vector *dst, - const TensorShape &dst_shape, std::vector cur_ind, size_t cur_dim, - const std::string &pad_value); - -enum class RelationalOp { - kEqual = 0, // == - kNotEqual, // != - kLess, // < - kLessEqual, // <= - kGreater, // > - kGreaterEqual, // >= -}; - -/// Helper method that masks the input tensor -/// @tparam T type of the tensor -/// @param input[in] input tensor -/// @param output[out] output tensor -/// @param value_tensor[in] scalar tensor value to compared with -/// @param op[in] RelationalOp enum -/// @return Status ok/error -template -Status MaskHelper(const std::shared_ptr &input, const std::shared_ptr &output, - const std::shared_ptr &value_tensor, RelationalOp op); - -/// Mask the input tensor -/// @param input[in] input tensor -/// @param output[out] output tensor -/// @param value[in] scalar tensor value to compared with -/// @param op[in] RelationalOp enum -/// @return Status ok/error -Status Mask(const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &value, - RelationalOp op); - -Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr prepend, - std::shared_ptr append); - -// helper for concat, always append to the input, and pass that to the output -Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr *output, int8_t axis, - std::shared_ptr append); - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_DATA_DATA_UTILS_H_ diff --git a/mindspore/ccsrc/dataset/kernels/data/duplicate_op.cc b/mindspore/ccsrc/dataset/kernels/data/duplicate_op.cc deleted file mode 100644 index 959516a4aa..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/duplicate_op.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/data/duplicate_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -Status DuplicateOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, input[0])); - output->push_back(input[0]); - output->push_back(out); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/duplicate_op.h b/mindspore/ccsrc/dataset/kernels/data/duplicate_op.h deleted file mode 100644 index 4c9d6d36c9..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/duplicate_op.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_DUPLICATE_OP_H_ -#define DATASET_KERNELS_DATA_DUPLICATE_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -class DuplicateOp : public TensorOp { - public: - DuplicateOp() = default; - - ~DuplicateOp() override = default; - - void Print(std::ostream &out) const override { out << "DuplicateOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - uint32_t NumOutput() override { return 2; } -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DUPLICATE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/data/fill_op.cc b/mindspore/ccsrc/dataset/kernels/data/fill_op.cc deleted file mode 100644 index 63895d3a95..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/fill_op.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/data/fill_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -Status FillOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - Status s = Fill(input, output, fill_value_); - return s; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/fill_op.h b/mindspore/ccsrc/dataset/kernels/data/fill_op.h deleted file mode 100644 index 03f59f3e67..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/fill_op.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_DATA_FILL_OP_H_ -#define DATASET_KERNELS_DATA_FILL_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class FillOp : public TensorOp { - public: - explicit FillOp(std::shared_ptr value) : fill_value_(value) {} - - ~FillOp() override = default; - void Print(std::ostream &out) const override { out << "FillOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - std::shared_ptr fill_value_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_FILL_OP_H diff --git a/mindspore/ccsrc/dataset/kernels/data/mask_op.cc b/mindspore/ccsrc/dataset/kernels/data/mask_op.cc deleted file mode 100644 index 2cfeb7e36f..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/mask_op.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/data/mask_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -Status MaskOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - std::shared_ptr temp_output; - CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), "Cannot generate a string mask. Type should be numeric."); - - RETURN_IF_NOT_OK(Mask(input, &temp_output, value_, op_)); - - // cast the output to the the required type. Skip casting if type_ is bool. - if (type_ != DataType::DE_BOOL) { - RETURN_IF_NOT_OK(cast_->Compute(temp_output, output)); - } else { - *output = std::move(temp_output); - } - - return Status::OK(); -} - -Status MaskOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = type_; - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/mask_op.h b/mindspore/ccsrc/dataset/kernels/data/mask_op.h deleted file mode 100644 index 0affe543bb..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/mask_op.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_MASK_OP_H_ -#define DATASET_KERNELS_DATA_MASK_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/kernels/data/data_utils.h" - -namespace mindspore { -namespace dataset { - -class MaskOp : public TensorOp { - public: - MaskOp(RelationalOp op, std::shared_ptr value, DataType type = DataType(DataType::DE_BOOL)) - : op_(op), value_(std::move(value)), type_(type), cast_(new TypeCastOp(type)) {} - - ~MaskOp() override = default; - - void Print(std::ostream &out) const override { out << "MaskOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - private: - RelationalOp op_; - std::shared_ptr value_; - DataType type_; - std::unique_ptr cast_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DATA_MASK_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/data/one_hot_op.cc b/mindspore/ccsrc/dataset/kernels/data/one_hot_op.cc deleted file mode 100644 index 65d1a183b3..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/one_hot_op.cc +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/data/one_hot_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -Status OneHotOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - Status s = OneHotEncoding(input, output, num_classes_); - return s; -} - -Status OneHotOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - std::vector inputs_copy; - inputs_copy.push_back(inputs[0].Squeeze()); - if (inputs_copy[0].Rank() == 0) outputs.emplace_back(std::vector{num_classes_}); - if (inputs_copy[0].Rank() == 1) outputs.emplace_back(std::vector{inputs_copy[0][0], num_classes_}); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/one_hot_op.h b/mindspore/ccsrc/dataset/kernels/data/one_hot_op.h deleted file mode 100644 index 80494dc5c0..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/one_hot_op.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_ONE_HOT_OP_H_ -#define DATASET_KERNELS_DATA_ONE_HOT_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class OneHotOp : public TensorOp { - public: - explicit OneHotOp(int num_classes) : num_classes_(num_classes) {} - - ~OneHotOp() override = default; - - void Print(std::ostream &out) const override { out << "OneHotOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - private: - int num_classes_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc b/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc deleted file mode 100644 index 5b3b4cbe16..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/data/pad_end_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -Status PadEndOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - Status s = PadEnd(input, output, output_shape_.AsVector(), pad_val_); - return s; -} - -Status PadEndOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - for (auto s : inputs) { - outputs.emplace_back(TensorShape(output_shape_.AsVector())); - } - CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "Input has a wrong shape"); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.h b/mindspore/ccsrc/dataset/kernels/data/pad_end_op.h deleted file mode 100644 index c6bc0c430e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_PAD_END_OP_H_ -#define DATASET_KERNELS_DATA_PAD_END_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class PadEndOp : public TensorOp { - public: - explicit PadEndOp(const TensorShape &pad_shape, const std::shared_ptr &pad_value) - : output_shape_(pad_shape), pad_val_(pad_value) {} - - ~PadEndOp() override = default; - - void Print(std::ostream &out) const override { out << "PadEndOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - private: - TensorShape output_shape_; - std::shared_ptr pad_val_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DATA_PAD_END_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc b/mindspore/ccsrc/dataset/kernels/data/slice_op.cc deleted file mode 100644 index 2eebf26e84..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/data/slice_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -Status SliceOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SliceOp supports 1D Tensors only for now."); - - // if `all` flag is true, output is just the input. - if (all_) { - *output = input; - return Status::OK(); - } - - // if slice object was provided, indices should be empty. Generate indices from the slice object. - if (slice_.valid() && indices_.empty()) { - dsize_t len = input->shape()[0]; - std::vector indices = slice_.Indices(len); - return input->Slice(output, indices); - } - - // if indices are not empty, slices should be invalid, use indices_ to slice - if (!indices_.empty() && !slice_.valid()) { - return input->Slice(output, indices_); - } - RETURN_STATUS_UNEXPECTED("The indexing parameters are invalid"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.h b/mindspore/ccsrc/dataset/kernels/data/slice_op.h deleted file mode 100644 index 0a24ae171e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/slice_op.h +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_SLICE_OP_H_ -#define DATASET_KERNELS_DATA_SLICE_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class Slice { - public: - Slice() : start_(0), stop_(0), step_(0) {} - Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {} - Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {} - explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {} - - ~Slice() = default; - - std::vector Indices(dsize_t length) { - std::vector indices; - dsize_t index = std::min(Tensor::HandleNeg(start_, length), length); - dsize_t end_index = std::min(Tensor::HandleNeg(stop_, length), length); - if (step_ > 0) { - for (; index < end_index; index += step_) { - indices.push_back(index); - } - } else { - for (; index > end_index; index += step_) { - indices.push_back(index); - } - } - return indices; - } - - bool valid() { return !(start_ == 0 && stop_ == 0 && step_ == 0); } - - dsize_t start_; - dsize_t stop_; - dsize_t step_; -}; - -class SliceOp : public TensorOp { - public: - explicit SliceOp(std::vector indices) : indices_(std::move(indices)) {} - explicit SliceOp(Slice slice) : slice_(slice) {} - explicit SliceOp(bool all) : all_(all) {} - - ~SliceOp() override = default; - - void Print(std::ostream &out) const override { out << "SliceOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - // only on of the following will be valid - // given indices to slice the Tensor. Empty vector if invalid. - std::vector indices_; - // Slice object. All start, stop and step are 0 if invalid. - Slice slice_; - // Flag to read all indcies in the dim. - bool all_ = false; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DATA_SLICE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/data/to_float16_op.cc b/mindspore/ccsrc/dataset/kernels/data/to_float16_op.cc deleted file mode 100644 index 1cd79456e0..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/to_float16_op.cc +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/data/to_float16_op.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -Status ToFloat16Op::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - return ToFloat16(input, output); -} -Status ToFloat16Op::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = DataType(DataType::DE_FLOAT16); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/to_float16_op.h b/mindspore/ccsrc/dataset/kernels/data/to_float16_op.h deleted file mode 100644 index 3fca50bf07..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/to_float16_op.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDDATA_TOFLOAT16OP_H -#define MINDDATA_TOFLOAT16OP_H - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class ToFloat16Op : public TensorOp { - public: - ToFloat16Op() = default; - - ~ToFloat16Op() override = default; - - // Overrides the base class compute function - // Calls the ToFloat16 function in ImageUtils, this function takes an input tensor - // and transforms its data to float16, the output memory is manipulated to contain the result - // @return Status - The error code return - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - void Print(std::ostream &out) const override { out << "ToFloat16Op"; } - - Status OutputType(const std::vector &inputs, std::vector &outputs) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDDATA_TOFLOAT16OP_H diff --git a/mindspore/ccsrc/dataset/kernels/data/type_cast_op.cc b/mindspore/ccsrc/dataset/kernels/data/type_cast_op.cc deleted file mode 100644 index 74c84a668a..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/type_cast_op.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/data/type_cast_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -TypeCastOp::TypeCastOp(const DataType &new_type) : type_(new_type) {} - -TypeCastOp::TypeCastOp(const std::string &data_type) { type_ = DataType(data_type); } - -Status TypeCastOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - return TypeCast(input, output, type_); -} -Status TypeCastOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = type_; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/type_cast_op.h b/mindspore/ccsrc/dataset/kernels/data/type_cast_op.h deleted file mode 100644 index 1b3f2c3290..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/type_cast_op.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ -#define DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class TypeCastOp : public TensorOp { - public: - // Constructor for TypecastOp - // @param data_type datatype to cast to - explicit TypeCastOp(const DataType &data_type); - - // Constructor for TypecastOp - // @param data_type datatype to cast to - explicit TypeCastOp(const std::string &data_type); - - ~TypeCastOp() override = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - void Print(std::ostream &out) const override { out << "TypeCastOp"; } - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - private: - DataType type_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt deleted file mode 100644 index fef698912c..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(kernels-image OBJECT - center_crop_op.cc - cut_out_op.cc - decode_op.cc - hwc_to_chw_op.cc - image_utils.cc - normalize_op.cc - pad_op.cc - random_color_adjust_op.cc - random_crop_decode_resize_op.cc - random_crop_and_resize_with_bbox_op.cc - random_crop_and_resize_op.cc - random_crop_op.cc - random_crop_with_bbox_op.cc - random_horizontal_flip_op.cc - random_horizontal_flip_bbox_op.cc - bounding_box_augment_op.cc - random_resize_op.cc - random_rotation_op.cc - random_vertical_flip_op.cc - random_vertical_flip_with_bbox_op.cc - rescale_op.cc - resize_bilinear_op.cc - resize_op.cc - uniform_aug_op.cc - resize_with_bbox_op.cc - random_resize_with_bbox_op.cc - ) diff --git a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc b/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc deleted file mode 100644 index 04e00d878d..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "dataset/kernels/image/bounding_box_augment_op.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/core/cv_tensor.h" - -namespace mindspore { -namespace dataset { -const float BoundingBoxAugmentOp::kDefRatio = 0.3; - -BoundingBoxAugmentOp::BoundingBoxAugmentOp(std::shared_ptr transform, float ratio) - : ratio_(ratio), transform_(std::move(transform)) { - rnd_.seed(GetSeed()); -} - -Status BoundingBoxAugmentOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); // check if bounding boxes are valid - uint32_t num_of_boxes = input[1]->shape()[0]; - uint32_t num_to_aug = num_of_boxes * ratio_; // cast to int - std::vector boxes(num_of_boxes); - std::vector selected_boxes; - for (uint32_t i = 0; i < num_of_boxes; i++) boxes[i] = i; - // sample bboxes according to ratio picked by user - std::sample(boxes.begin(), boxes.end(), std::back_inserter(selected_boxes), num_to_aug, rnd_); - std::shared_ptr crop_out; - std::shared_ptr res_out; - std::shared_ptr input_restore = CVTensor::AsCVTensor(input[0]); - - for (uint32_t i = 0; i < num_to_aug; i++) { - uint32_t min_x = 0; - uint32_t min_y = 0; - uint32_t b_w = 0; - uint32_t b_h = 0; - // get the required items - input[1]->GetItemAt(&min_x, {selected_boxes[i], 0}); - input[1]->GetItemAt(&min_y, {selected_boxes[i], 1}); - input[1]->GetItemAt(&b_w, {selected_boxes[i], 2}); - input[1]->GetItemAt(&b_h, {selected_boxes[i], 3}); - Crop(input_restore, &crop_out, min_x, min_y, b_w, b_h); - // transform the cropped bbox region - transform_->Compute(crop_out, &res_out); - // place the transformed region back in the restored input - std::shared_ptr res_img = CVTensor::AsCVTensor(res_out); - // check if transformed crop is out of bounds of the box - if (res_img->mat().cols > b_w || res_img->mat().rows > b_h || res_img->mat().cols < b_w || - res_img->mat().rows < b_h) { - // if so, resize to fit in the box - std::shared_ptr resize_op = std::make_shared(b_h, b_w); - resize_op->Compute(std::static_pointer_cast(res_img), &res_out); - res_img = CVTensor::AsCVTensor(res_out); - } - res_img->mat().copyTo(input_restore->mat()(cv::Rect(min_x, min_y, res_img->mat().cols, res_img->mat().rows))); - } - (*output).push_back(std::move(std::static_pointer_cast(input_restore))); - (*output).push_back(input[1]); - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h b/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h deleted file mode 100644 index 6c106f75dc..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ -#define DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ - -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -class BoundingBoxAugmentOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefRatio; - - // Constructor for BoundingBoxAugmentOp - // @param std::shared_ptr transform transform: C++ opration to apply on select bounding boxes - // @param float ratio: ratio of bounding boxes to have the transform applied on - BoundingBoxAugmentOp(std::shared_ptr transform, float ratio); - - ~BoundingBoxAugmentOp() override = default; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const BoundingBoxAugmentOp &so) { - so.Print(out); - return out; - } - - void Print(std::ostream &out) const override { out << "BoundingBoxAugmentOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - private: - float ratio_; - std::mt19937 rnd_; - std::shared_ptr transform_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/center_crop_op.cc b/mindspore/ccsrc/dataset/kernels/image/center_crop_op.cc deleted file mode 100644 index a5129e9c71..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/center_crop_op.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/center_crop_op.h" -#include -#include "common/utils.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t CenterCropOp::kDefWidth = 0; - -Status CenterCropOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - std::string err_msg; - dsize_t rank = input->shape().Rank(); - err_msg += (rank < 2 || rank > 3) ? "Rank received::" + std::to_string(rank) + " Expected: 2 or 3 \t" : ""; - err_msg += (crop_het_ <= 0 || crop_wid_ <= 0) ? "crop size needs to be positive integers\t" : ""; - - if (err_msg.length() != 0) RETURN_STATUS_UNEXPECTED(common::SafeCStr(err_msg)); - - int32_t top = crop_het_ - input->shape()[0]; // number of pixels to pad (top and bottom) - int32_t left = crop_wid_ - input->shape()[1]; - std::shared_ptr pad_image; - if (top > 0 && left > 0) { // padding only - return Pad(input, output, top / 2 + top % 2, top / 2, left / 2 + left % 2, left / 2, BorderType::kConstant); - } else if (top > 0) { - RETURN_IF_NOT_OK(Pad(input, &pad_image, top / 2 + top % 2, top / 2, 0, 0, BorderType::kConstant)); - return Crop(pad_image, output, (static_cast(pad_image->shape()[1]) - crop_wid_) / 2, - (static_cast(pad_image->shape()[0]) - crop_het_) / 2, crop_wid_, crop_het_); - } else if (left > 0) { - RETURN_IF_NOT_OK(Pad(input, &pad_image, 0, 0, left / 2 + left % 2, left / 2, BorderType::kConstant)); - return Crop(pad_image, output, (static_cast(pad_image->shape()[1]) - crop_wid_) / 2, - (static_cast(pad_image->shape()[0]) - crop_het_) / 2, crop_wid_, crop_het_); - } - return Crop(input, output, (input->shape()[1] - crop_wid_) / 2, (input->shape()[0] - crop_het_) / 2, crop_wid_, - crop_het_); -} - -void CenterCropOp::Print(std::ostream &out) const { - out << "CenterCropOp: " - << "cropWidth: " << crop_wid_ << "cropHeight: " << crop_het_ << "\n"; -} -Status CenterCropOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out = TensorShape{crop_het_, crop_wid_}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h deleted file mode 100644 index eb8e71ba7c..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ -#define DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class CenterCropOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefWidth; - - explicit CenterCropOp(int32_t het, int32_t wid = kDefWidth) : crop_het_(het), crop_wid_(wid == 0 ? het : wid) {} - - ~CenterCropOp() override = default; - - void Print(std::ostream &out) const override; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - private: - int32_t crop_het_; - int32_t crop_wid_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc deleted file mode 100644 index 74d9df5d6b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "dataset/kernels/image/cut_out_op.h" - -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const bool CutOutOp::kDefRandomColor = false; -const uint8_t CutOutOp::kDefFillR = 0; -const uint8_t CutOutOp::kDefFillG = 0; -const uint8_t CutOutOp::kDefFillB = 0; - -// constructor -CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r, - uint8_t fill_g, uint8_t fill_b) - : rnd_(GetSeed()), - box_height_(box_height), - box_width_(box_width), - num_patches_(num_patches), - random_color_(random_color), - fill_r_(fill_r), - fill_g_(fill_g), - fill_b_(fill_b) {} - -// main function call for cut out -Status CutOutOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - std::shared_ptr inputCV = CVTensor::AsCVTensor(input); - // cut out will clip the erasing area if the box is near the edge of the image and the boxes are black - RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_, - fill_g_, fill_b_)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h deleted file mode 100644 index 2198f23e44..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#ifndef DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ -#define DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class CutOutOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const bool kDefRandomColor; - static const uint8_t kDefFillR; - static const uint8_t kDefFillG; - static const uint8_t kDefFillB; - - // Constructor for CutOutOp - // @param box_height box height - // @param box_width box_width - // @param num_patches how many patches to erase from image - // @param random_color boolean value to indicate fill patch with random color - // @param fill_r R value for the color to fill patch with - // @param fill_g G value for the color to fill patch with - // @param fill_b B value for the color to fill patch with - // @note maybe using unsigned long int isn't the best here according to our coding rules - CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color = kDefRandomColor, - uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); - - ~CutOutOp() override = default; - - void Print(std::ostream &out) const override { - out << "CutOut:: box_height: " << box_height_ << " box_width: " << box_width_ << " num_patches: " << num_patches_; - } - - // Overrides the base class compute function - // Calls the erase function in ImageUtils, this function takes an input tensor - // and overwrites some of its data using openCV, the output memory is manipulated to contain the result - // @return Status - The error code return - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - std::mt19937 rnd_; - int32_t box_height_; - int32_t box_width_; - int32_t num_patches_; - bool random_color_; - uint8_t fill_r_; - uint8_t fill_g_; - uint8_t fill_b_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/decode_op.cc b/mindspore/ccsrc/dataset/kernels/image/decode_op.cc deleted file mode 100644 index ef6cf88b3b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/decode_op.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/decode_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const bool DecodeOp::kDefRgbFormat = true; - -DecodeOp::DecodeOp(bool is_rgb_format) : is_rgb_format_(is_rgb_format) { - if (is_rgb_format_) { // RGB colour mode - MS_LOG(DEBUG) << "Decode colour mode is RGB."; - } else { - MS_LOG(DEBUG) << "Decode colour mode is BGR."; - } -} - -Status DecodeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (is_rgb_format_) { // RGB colour mode - return Decode(input, output); - } else { // BGR colour mode - RETURN_STATUS_UNEXPECTED("Decode BGR is deprecated"); - } -} -Status DecodeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out({-1, -1, 3}); // we don't know what is output image size, but we know it should be 3 channels - if (inputs[0].Rank() == 1) outputs.emplace_back(out); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} - -Status DecodeOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = DataType(DataType::DE_UINT8); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/dataset/kernels/image/decode_op.h deleted file mode 100644 index 6e7180958a..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/decode_op.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_DECODE_OP_H_ -#define DATASET_KERNELS_IMAGE_DECODE_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DecodeOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const bool kDefRgbFormat; - - explicit DecodeOp(bool is_rgb_format = true); - - ~DecodeOp() = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - void Print(std::ostream &out) const override { out << "DecodeOp"; } - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - private: - bool is_rgb_format_ = true; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_DECODE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.cc b/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.cc deleted file mode 100644 index 8ed2229cd1..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/hwc_to_chw_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status HwcToChwOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // input.shape == HWC - // output.shape == CHW - return HwcToChw(input, output); -} -Status HwcToChwOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape in = inputs[0]; - TensorShape out = TensorShape{in[2], in[0], in[1]}; - if (inputs[0].Rank() == 3) outputs.emplace_back(out); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.h b/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.h deleted file mode 100644 index 825ffa4443..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ -#define DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class HwcToChwOp : public TensorOp { - public: - void Print(std::ostream &out) const override { out << "HwcToChw"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/dataset/kernels/image/image_utils.cc deleted file mode 100644 index ded9a8db11..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/image_utils.cc +++ /dev/null @@ -1,835 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/image_utils.h" -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "dataset/core/constants.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/util/random.h" - -#define MAX_INT_PRECISION 16777216 // float int precision is 16777216 -namespace mindspore { -namespace dataset { -int GetCVInterpolationMode(InterpolationMode mode) { - switch (mode) { - case InterpolationMode::kLinear: - return static_cast(cv::InterpolationFlags::INTER_LINEAR); - case InterpolationMode::kCubic: - return static_cast(cv::InterpolationFlags::INTER_CUBIC); - case InterpolationMode::kArea: - return static_cast(cv::InterpolationFlags::INTER_AREA); - case InterpolationMode::kNearestNeighbour: - return static_cast(cv::InterpolationFlags::INTER_NEAREST); - default: - return static_cast(cv::InterpolationFlags::INTER_LINEAR); - } -} - -int GetCVBorderType(BorderType type) { - switch (type) { - case BorderType::kConstant: - return static_cast(cv::BorderTypes::BORDER_CONSTANT); - case BorderType::kEdge: - return static_cast(cv::BorderTypes::BORDER_REPLICATE); - case BorderType::kReflect: - return static_cast(cv::BorderTypes::BORDER_REFLECT101); - case BorderType::kSymmetric: - return static_cast(cv::BorderTypes::BORDER_REFLECT); - default: - return static_cast(cv::BorderTypes::BORDER_CONSTANT); - } -} - -Status Flip(std::shared_ptr input, std::shared_ptr *output, int flip_code) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); - - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - RETURN_IF_NOT_OK(output_cv->AllocateBuffer(output_cv->SizeInBytes())); - - if (input_cv->mat().data) { - try { - cv::flip(input_cv->mat(), output_cv->mat(), flip_code); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in flip op."); - } - } else { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor, the input data is null"); - } -} - -Status HorizontalFlip(std::shared_ptr input, std::shared_ptr *output) { - return Flip(std::move(input), output, 1); -} - -Status VerticalFlip(std::shared_ptr input, std::shared_ptr *output) { - return Flip(std::move(input), output, 0); -} - -Status Resize(const std::shared_ptr &input, std::shared_ptr *output, int32_t output_height, - int32_t output_width, double fx, double fy, InterpolationMode mode) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { - RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of or "); - } - cv::Mat in_image = input_cv->mat(); - // resize image too large or too small - if (output_height == 0 || output_height > in_image.rows * 1000 || output_width == 0 || - output_width > in_image.cols * 1000) { - std::string err_msg = - "The resizing width or height 1) is too big, it's up to " - "1000 times the original image; 2) can not be 0."; - return Status(StatusCode::kShapeMisMatch, err_msg); - } - try { - TensorShape shape{output_height, output_width}; - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - auto cv_mode = GetCVInterpolationMode(mode); - cv::resize(in_image, output_cv->mat(), cv::Size(output_width, output_height), fx, fy, cv_mode); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in image resize."); - } -} - -bool HasJpegMagic(const std::shared_ptr &input) { - const unsigned char *kJpegMagic = (unsigned char *)"\xFF\xD8\xFF"; - constexpr size_t kJpegMagicLen = 3; - return input->SizeInBytes() >= kJpegMagicLen && memcmp(input->GetBuffer(), kJpegMagic, kJpegMagicLen) == 0; -} - -Status Decode(const std::shared_ptr &input, std::shared_ptr *output) { - if (HasJpegMagic(input)) { - return JpegCropAndDecode(input, output); - } else { - return DecodeCv(input, output); - } -} - -Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *output) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - try { - cv::Mat img_mat = cv::imdecode(input_cv->mat(), cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION); - if (img_mat.data == nullptr) { - std::string err = "Error in decoding\t"; - RETURN_STATUS_UNEXPECTED(err); - } - cv::cvtColor(img_mat, img_mat, static_cast(cv::COLOR_BGR2RGB)); - std::shared_ptr output_cv = std::make_shared(img_mat); - RETURN_UNEXPECTED_IF_NULL(output_cv); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in image Decode"); - } -} - -static void JpegInitSource(j_decompress_ptr cinfo) {} - -static boolean JpegFillInputBuffer(j_decompress_ptr cinfo) { - if (cinfo->src->bytes_in_buffer == 0) { - ERREXIT(cinfo, JERR_INPUT_EMPTY); - return FALSE; - } - return TRUE; -} - -static void JpegTermSource(j_decompress_ptr cinfo) {} - -static void JpegSkipInputData(j_decompress_ptr cinfo, int64_t jump) { - if (jump < 0) { - return; - } - if (static_cast(jump) > cinfo->src->bytes_in_buffer) { - cinfo->src->bytes_in_buffer = 0; - return; - } else { - cinfo->src->bytes_in_buffer -= jump; - cinfo->src->next_input_byte += jump; - } -} - -void JpegSetSource(j_decompress_ptr cinfo, const void *data, int64_t datasize) { - cinfo->src = static_cast( - (*cinfo->mem->alloc_small)(reinterpret_cast(cinfo), JPOOL_PERMANENT, sizeof(struct jpeg_source_mgr))); - cinfo->src->init_source = JpegInitSource; - cinfo->src->fill_input_buffer = JpegFillInputBuffer; -#if defined(_WIN32) || defined(_WIN64) - cinfo->src->skip_input_data = reinterpret_cast(JpegSkipInputData); -#else - cinfo->src->skip_input_data = JpegSkipInputData; -#endif - cinfo->src->resync_to_restart = jpeg_resync_to_restart; - cinfo->src->term_source = JpegTermSource; - cinfo->src->bytes_in_buffer = datasize; - cinfo->src->next_input_byte = static_cast(data); -} - -static Status JpegReadScanlines(jpeg_decompress_struct *const cinfo, int max_scanlines_to_read, JSAMPLE *buffer, - int buffer_size, int crop_w, int crop_w_aligned, int offset, int stride) { - // scanlines will be read to this buffer first, must have the number - // of components equal to the number of components in the image - int64_t scanline_size = crop_w_aligned * cinfo->output_components; - std::vector scanline(scanline_size); - JSAMPLE *scanline_ptr = &scanline[0]; - while (cinfo->output_scanline < static_cast(max_scanlines_to_read)) { - int num_lines_read = jpeg_read_scanlines(cinfo, &scanline_ptr, 1); - if (cinfo->out_color_space == JCS_CMYK && num_lines_read > 0) { - for (int i = 0; i < crop_w; ++i) { - int cmyk_pixel = 4 * i + offset; - const int c = scanline_ptr[cmyk_pixel]; - const int m = scanline_ptr[cmyk_pixel + 1]; - const int y = scanline_ptr[cmyk_pixel + 2]; - const int k = scanline_ptr[cmyk_pixel + 3]; - int r, g, b; - if (cinfo->saw_Adobe_marker) { - r = (k * c) / 255; - g = (k * m) / 255; - b = (k * y) / 255; - } else { - r = (255 - c) * (255 - k) / 255; - g = (255 - m) * (255 - k) / 255; - b = (255 - y) * (255 - k) / 255; - } - buffer[3 * i + 0] = r; - buffer[3 * i + 1] = g; - buffer[3 * i + 2] = b; - } - } else if (num_lines_read > 0) { - int copy_status = memcpy_s(buffer, buffer_size, scanline_ptr + offset, stride); - if (copy_status != 0) { - jpeg_destroy_decompress(cinfo); - RETURN_STATUS_UNEXPECTED("memcpy failed"); - } - } else { - jpeg_destroy_decompress(cinfo); - std::string err_msg = "failed to read scanline"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - buffer += stride; - buffer_size = buffer_size - stride; - } - return Status::OK(); -} - -static Status JpegSetColorSpace(jpeg_decompress_struct *cinfo) { - switch (cinfo->num_components) { - case 1: - // we want to output 3 components if it's grayscale - cinfo->out_color_space = JCS_RGB; - return Status::OK(); - case 3: - cinfo->out_color_space = JCS_RGB; - return Status::OK(); - case 4: - // Need to manually convert to RGB - cinfo->out_color_space = JCS_CMYK; - return Status::OK(); - default: - jpeg_destroy_decompress(cinfo); - std::string err_msg = "wrong number of components"; - RETURN_STATUS_UNEXPECTED(err_msg); - } -} - -void JpegErrorExitCustom(j_common_ptr cinfo) { - char jpeg_last_error_msg[JMSG_LENGTH_MAX]; - (*(cinfo->err->format_message))(cinfo, jpeg_last_error_msg); - throw std::runtime_error(jpeg_last_error_msg); -} - -Status JpegCropAndDecode(const std::shared_ptr &input, std::shared_ptr *output, int crop_x, int crop_y, - int crop_w, int crop_h) { - struct jpeg_decompress_struct cinfo; - auto DestroyDecompressAndReturnError = [&cinfo](const std::string &err) { - jpeg_destroy_decompress(&cinfo); - RETURN_STATUS_UNEXPECTED(err); - }; - struct JpegErrorManagerCustom jerr; - cinfo.err = jpeg_std_error(&jerr.pub); - jerr.pub.error_exit = JpegErrorExitCustom; - try { - jpeg_create_decompress(&cinfo); - JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes()); - (void)jpeg_read_header(&cinfo, TRUE); - RETURN_IF_NOT_OK(JpegSetColorSpace(&cinfo)); - jpeg_calc_output_dimensions(&cinfo); - } catch (std::runtime_error &e) { - return DestroyDecompressAndReturnError(e.what()); - } - if (crop_x == 0 && crop_y == 0 && crop_w == 0 && crop_h == 0) { - crop_w = cinfo.output_width; - crop_h = cinfo.output_height; - } else if (crop_w == 0 || static_cast(crop_w + crop_x) > cinfo.output_width || crop_h == 0 || - static_cast(crop_h + crop_y) > cinfo.output_height) { - return DestroyDecompressAndReturnError("Crop window is not valid"); - } - const int mcu_size = cinfo.min_DCT_scaled_size; - unsigned int crop_x_aligned = (crop_x / mcu_size) * mcu_size; - unsigned int crop_w_aligned = crop_w + crop_x - crop_x_aligned; - try { - (void)jpeg_start_decompress(&cinfo); - jpeg_crop_scanline(&cinfo, &crop_x_aligned, &crop_w_aligned); - } catch (std::runtime_error &e) { - return DestroyDecompressAndReturnError(e.what()); - } - JDIMENSION skipped_scanlines = jpeg_skip_scanlines(&cinfo, crop_y); - // three number of output components, always convert to RGB and output - constexpr int kOutNumComponents = 3; - TensorShape ts = TensorShape({crop_h, crop_w, kOutNumComponents}); - auto output_tensor = std::make_shared(ts, DataType(DataType::DE_UINT8)); - const int buffer_size = output_tensor->SizeInBytes(); - JSAMPLE *buffer = static_cast(reinterpret_cast(&(*output_tensor->begin()))); - const int max_scanlines_to_read = skipped_scanlines + crop_h; - // stride refers to output tensor, which has 3 components at most - const int stride = crop_w * kOutNumComponents; - // offset is calculated for scanlines read from the image, therefore - // has the same number of components as the image - const int offset = (crop_x - crop_x_aligned) * cinfo.output_components; - RETURN_IF_NOT_OK( - JpegReadScanlines(&cinfo, max_scanlines_to_read, buffer, buffer_size, crop_w, crop_w_aligned, offset, stride)); - *output = output_tensor; - jpeg_destroy_decompress(&cinfo); - return Status::OK(); -} - -Status Rescale(const std::shared_ptr &input, std::shared_ptr *output, float rescale, float shift) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - cv::Mat input_image = input_cv->mat(); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); - RETURN_UNEXPECTED_IF_NULL(output_cv); - try { - input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift); - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in image rescale"); - } - return Status::OK(); -} - -Status Crop(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, int w, int h) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { - RETURN_STATUS_UNEXPECTED("Shape not or "); - } - try { - TensorShape shape{h, w}; - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::Rect roi(x, y, w, h); - (input_cv->mat())(roi).copyTo(output_cv->mat()); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in crop."); - } -} - -Status HwcToChw(std::shared_ptr input, std::shared_ptr *output) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - if (input_cv->Rank() == 2) { - // If input tensor is 2D, we assume we have hw dimensions - *output = input; - return Status::OK(); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->shape().Size() < 2 || input_cv->shape().Size() > 3 || - (input_cv->shape().Size() == 3 && num_channels != 3 && num_channels != 1)) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3 nor 1"); - } - cv::Mat output_img; - - int height = input_cv->shape()[0]; - int width = input_cv->shape()[1]; - - auto output_cv = std::make_unique(TensorShape{num_channels, height, width}, input_cv->type()); - for (int i = 0; i < num_channels; ++i) { - cv::Mat mat; - RETURN_IF_NOT_OK(output_cv->Mat({i}, &mat)); - cv::extractChannel(input_cv->mat(), mat, i); - } - *output = std::move(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in ChannelSwap."); - } -} - -Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); - int num_channels = input_cv->shape()[2]; - if (input_cv->shape().Size() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast(cv::COLOR_BGR2RGB)); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in ChangeMode."); - } -} - -Status CropAndResize(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, - int crop_height, int crop_width, int target_height, int target_width, InterpolationMode mode) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { - RETURN_STATUS_UNEXPECTED("Shape not or "); - } - // image too large or too small - if (crop_height == 0 || crop_width == 0 || target_height == 0 || target_height > crop_height * 1000 || - target_width == 0 || target_height > crop_width * 1000) { - std::string err_msg = - "The resizing width or height 1) is too big, it's up to " - "1000 times the original image; 2) can not be 0."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - cv::Rect roi(x, y, crop_width, crop_height); - auto cv_mode = GetCVInterpolationMode(mode); - cv::Mat cv_in = input_cv->mat(); - TensorShape shape{target_height, target_width}; - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr cvt_out = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(cvt_out); - cv::resize(cv_in(roi), cvt_out->mat(), cv::Size(target_width, target_height), 0, 0, cv_mode); - *output = std::static_pointer_cast(cvt_out); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in CropAndResize."); - } -} - -Status Rotate(const std::shared_ptr &input, std::shared_ptr *output, float fx, float fy, float degree, - InterpolationMode interpolation, bool expand, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - cv::Mat input_img = input_cv->mat(); - if (input_img.cols > (MAX_INT_PRECISION * 2) || input_img.rows > (MAX_INT_PRECISION * 2)) { - RETURN_STATUS_UNEXPECTED("Image too large center not precise"); - } - // default to center of image - if (fx == -1 && fy == -1) { - fx = (input_img.cols - 1) / 2.0; - fy = (input_img.rows - 1) / 2.0; - } - cv::Mat output_img; - cv::Scalar fill_color = cv::Scalar(fill_b, fill_g, fill_r); - // maybe don't use uint32 for image dimension here - cv::Point2f pc(fx, fy); - cv::Mat rot = cv::getRotationMatrix2D(pc, degree, 1.0); - std::shared_ptr output_cv; - if (!expand) { - // this case means that the shape doesn't change, size stays the same - // We may not need this memcpy if it is in place. - output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - // using inter_nearest to comply with python default - cv::warpAffine(input_img, output_cv->mat(), rot, input_img.size(), GetCVInterpolationMode(interpolation), - cv::BORDER_CONSTANT, fill_color); - } else { - // we resize here since the shape changes - // create a new bounding box with the rotate - cv::Rect2f bbox = cv::RotatedRect(cv::Point2f(), input_img.size(), degree).boundingRect2f(); - rot.at(0, 2) += bbox.width / 2.0 - input_img.cols / 2.0; - rot.at(1, 2) += bbox.height / 2.0 - input_img.rows / 2.0; - // use memcpy and don't compute the new shape since openCV has a rounding problem - cv::warpAffine(input_img, output_img, rot, bbox.size(), GetCVInterpolationMode(interpolation), - cv::BORDER_CONSTANT, fill_color); - output_cv = std::make_shared(output_img); - RETURN_UNEXPECTED_IF_NULL(output_cv); - } - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in image rotation"); - } - return Status::OK(); -} - -Status Normalize(const std::shared_ptr &input, std::shared_ptr *output, - const std::shared_ptr &mean, const std::shared_ptr &std) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!(input_cv->mat().data && input_cv->Rank() == 3)) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - cv::Mat in_image = input_cv->mat(); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); - RETURN_UNEXPECTED_IF_NULL(output_cv); - mean->Squeeze(); - if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) { - std::string err_msg = "Mean tensor should be of size 3 and type float."; - return Status(StatusCode::kShapeMisMatch, err_msg); - } - std->Squeeze(); - if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) { - std::string err_msg = "Std tensor should be of size 3 and type float."; - return Status(StatusCode::kShapeMisMatch, err_msg); - } - try { - // NOTE: We are assuming the input image is in RGB and the mean - // and std are in RGB - cv::Mat rgb[3]; - cv::split(in_image, rgb); - for (uint8_t i = 0; i < 3; i++) { - float mean_c, std_c; - RETURN_IF_NOT_OK(mean->GetItemAt(&mean_c, {i})); - RETURN_IF_NOT_OK(std->GetItemAt(&std_c, {i})); - rgb[i].convertTo(rgb[i], CV_32F, 1.0 / std_c, (-mean_c / std_c)); - } - cv::merge(rgb, 3, output_cv->mat()); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in Normalize"); - } -} - -Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - cv::Mat input_img = input_cv->mat(); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - output_cv->mat() = input_img * alpha; - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in adjust brightness"); - } - return Status::OK(); -} - -Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - cv::Mat input_img = input_cv->mat(); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - cv::Mat gray, output_img; - cv::cvtColor(input_img, gray, CV_RGB2GRAY); - int mean_img = static_cast(cv::mean(gray).val[0] + 0.5); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - output_img = cv::Mat::zeros(input_img.rows, input_img.cols, CV_8UC1); - output_img = output_img + mean_img; - cv::cvtColor(output_img, output_img, CV_GRAY2RGB); - output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha; - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in adjust contrast"); - } - return Status::OK(); -} - -Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - cv::Mat input_img = input_cv->mat(); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::Mat output_img = output_cv->mat(); - cv::Mat gray; - cv::cvtColor(input_img, gray, CV_RGB2GRAY); - cv::cvtColor(gray, output_img, CV_GRAY2RGB); - output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha; - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in adjust saturation"); - } - return Status::OK(); -} - -Status AdjustHue(const std::shared_ptr &input, std::shared_ptr *output, const float &hue) { - if (hue > 0.5 || hue < -0.5) { - MS_LOG(ERROR) << "Hue factor is not in [-0.5, 0.5]."; - RETURN_STATUS_UNEXPECTED("hue_factor is not in [-0.5, 0.5]."); - } - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - cv::Mat input_img = input_cv->mat(); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::Mat output_img; - cv::cvtColor(input_img, output_img, CV_RGB2HSV_FULL); - for (int y = 0; y < output_img.cols; y++) { - for (int x = 0; x < output_img.rows; x++) { - uint8_t cur1 = output_img.at(cv::Point(y, x))[0]; - uint8_t h_hue = 0; - h_hue = static_cast(hue * 255); - cur1 += h_hue; - output_img.at(cv::Point(y, x))[0] = cur1; - } - } - cv::cvtColor(output_img, output_cv->mat(), CV_HSV2RGB_FULL); - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in adjust hue"); - } - return Status::OK(); -} - -Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, - int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, - uint8_t fill_g, uint8_t fill_b) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - int num_channels = input_cv->shape()[2]; - if (input_cv->mat().data == nullptr || input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); - } - cv::Mat input_img = input_cv->mat(); - int32_t image_h = input_cv->shape()[0]; - int32_t image_w = input_cv->shape()[1]; - // check if erase size is bigger than image itself - if (box_height > image_h || box_width > image_w) { - RETURN_STATUS_UNEXPECTED("input box size too large for image erase"); - } - - // for random color - std::normal_distribution normal_distribution(0, 1); - std::uniform_int_distribution height_distribution_bound(0, image_h - box_height); - std::uniform_int_distribution width_distribution_bound(0, image_w - box_width); - std::uniform_int_distribution height_distribution_unbound(0, image_h + box_height); - std::uniform_int_distribution width_distribution_unbound(0, image_w + box_width); - // core logic - // update values based on random erasing or cutout - - for (int32_t i = 0; i < num_patches; i++) { - // rows in cv mat refers to the height of the cropped box - // we determine h_start and w_start using two different distributions as erasing is used by two different - // image augmentations. The bounds are also different in each case. - int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height); - int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width); - - int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width; - int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height; - // check for starting range >= 0, here the start range is checked after for cut out, for random erasing - // w_start and h_start will never be less than 0. - h_start = (h_start < 0) ? 0 : h_start; - w_start = (w_start < 0) ? 0 : w_start; - for (int y = w_start; y < max_width; y++) { - for (int x = h_start; x < max_height; x++) { - if (random_color) { - // fill each box with a random value - input_img.at(cv::Point(y, x))[0] = static_cast(normal_distribution(*rnd)); - input_img.at(cv::Point(y, x))[1] = static_cast(normal_distribution(*rnd)); - input_img.at(cv::Point(y, x))[2] = static_cast(normal_distribution(*rnd)); - } else { - input_img.at(cv::Point(y, x))[0] = fill_r; - input_img.at(cv::Point(y, x))[1] = fill_g; - input_img.at(cv::Point(y, x))[2] = fill_b; - } - } - } - } - *output = std::static_pointer_cast(input); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in erasing"); - } -} - -Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, - const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, - uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { - try { - // input image - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - // get the border type in openCV - auto b_type = GetCVBorderType(border_types); - // output image - cv::Mat out_image; - if (b_type == cv::BORDER_CONSTANT) { - cv::Scalar fill_color = cv::Scalar(fill_b, fill_g, fill_r); - cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type, fill_color); - } else { - cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type); - } - std::shared_ptr output_cv = std::make_shared(out_image); - RETURN_UNEXPECTED_IF_NULL(output_cv); - // pad the dimension if shape information is only 2 dimensional, this is grayscale - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() == 3 && num_channels == 1 && output_cv->Rank() == 2) output_cv->ExpandDim(2); - *output = std::static_pointer_cast(output_cv); - - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in pad"); - } -} -// -------- BBOX OPERATIONS -------- // -Status UpdateBBoxesForCrop(std::shared_ptr *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax, - int CB_Ymax) { - // PASS LIST, COUNT OF BOUNDING BOXES - // Also PAss X/Y Min/Max of image cropped region - normally obtained from 'GetCropBox' functions - uint32_t bb_Xmin_t, bb_Ymin_t, bb_Xmax_t, bb_Ymax_t; - - std::vector correct_ind; - std::vector copyVals; - dsize_t bboxDim = (*bboxList)->shape()[1]; - bool retFlag = false; // true unless overlap found - for (int i = 0; i < *bboxCount; i++) { - int bb_Xmin, bb_Xmax, bb_Ymin, bb_Ymax; - RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&bb_Xmin_t, {i, 0})); - RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&bb_Ymin_t, {i, 1})); - RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&bb_Xmax_t, {i, 2})); - RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&bb_Ymax_t, {i, 3})); - bb_Xmin = bb_Xmin_t; - bb_Ymin = bb_Ymin_t; - bb_Xmax = bb_Xmax_t; - bb_Ymax = bb_Ymax_t; - bb_Xmax = bb_Xmin + bb_Xmax; - bb_Ymax = bb_Ymin + bb_Ymax; - // check for image / BB overlap - if (((bb_Xmin > CB_Xmax) || (bb_Ymin > CB_Ymax)) || ((bb_Xmax < CB_Xmin) || (bb_Ymax < CB_Ymin))) { - continue; // no overlap found - } - // Update this bbox and select it to move to the final output tensor - correct_ind.push_back(i); - // adjust BBox corners by bringing into new CropBox if beyond - // Also reseting/adjusting for boxes to lie within CropBox instead of Image - subtract CropBox Xmin/YMin - bb_Xmin = bb_Xmin - (std::min(0, (bb_Xmin - CB_Xmin)) + CB_Xmin); - bb_Xmax = bb_Xmax - (std::max(0, (bb_Xmax - CB_Xmax)) + CB_Xmin); - bb_Ymin = bb_Ymin - (std::min(0, (bb_Ymin - CB_Ymin)) + CB_Ymin); - bb_Ymax = bb_Ymax - (std::max(0, (bb_Ymax - CB_Ymax)) + CB_Ymin); - // reset min values and calculate width/height from Box corners - RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, static_cast(bb_Xmin))); - RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, static_cast(bb_Ymin))); - RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 2}, static_cast(bb_Xmax - bb_Xmin))); - RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 3}, static_cast(bb_Ymax - bb_Ymin))); - } - // create new tensor and copy over bboxes still valid to the image - // bboxes outside of new cropped region are ignored - empty tensor returned in case of none - *bboxCount = correct_ind.size(); - uint32_t temp; - for (auto slice : correct_ind) { // for every index in the loop - for (int ix = 0; ix < bboxDim; ix++) { - RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&temp, {slice, ix})); - copyVals.push_back(temp); - } - } - std::shared_ptr retV; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&retV, copyVals, TensorShape({static_cast(*bboxCount), bboxDim}))); - (*bboxList) = retV; // reset pointer - return Status::OK(); -} - -Status PadBBoxes(const std::shared_ptr *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left) { - for (int i = 0; i < bboxCount; i++) { - uint32_t xMin, yMin; - RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&xMin, {i, 0})); - RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&yMin, {i, 1})); - xMin += static_cast(pad_left); // should not be negative - yMin += static_cast(pad_top); - RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, xMin)); - RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, yMin)); - } - return Status::OK(); -} - -Status UpdateBBoxesForResize(const std::shared_ptr &bboxList, const size_t &bboxCount, int32_t target_width_, - int32_t target_height_, int orig_width, int orig_height) { - uint32_t bb_Xmin, bb_Ymin, bb_Xwidth, bb_Ywidth; - // cast to float to preseve fractional - double W_aspRatio = (target_width_ * 1.0) / (orig_width * 1.0); - double H_aspRatio = (target_height_ * 1.0) / (orig_height * 1.0); - for (int i = 0; i < bboxCount; i++) { - // for each bounding box - RETURN_IF_NOT_OK(bboxList->GetUnsignedIntAt(&bb_Xmin, {i, 0})); - RETURN_IF_NOT_OK(bboxList->GetUnsignedIntAt(&bb_Ymin, {i, 1})); - RETURN_IF_NOT_OK(bboxList->GetUnsignedIntAt(&bb_Xwidth, {i, 2})); - RETURN_IF_NOT_OK(bboxList->GetUnsignedIntAt(&bb_Ywidth, {i, 3})); - // update positions and widths - bb_Xmin = bb_Xmin * W_aspRatio; - bb_Ymin = bb_Ymin * H_aspRatio; - bb_Xwidth = bb_Xwidth * W_aspRatio; - bb_Ywidth = bb_Ywidth * H_aspRatio; - // reset bounding box values - RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 0}, bb_Xmin)); - RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 1}, bb_Ymin)); - RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 2}, bb_Xwidth)); - RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 3}, bb_Ywidth)); - } - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/dataset/kernels/image/image_utils.h deleted file mode 100644 index 57ffce6a12..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/image_utils.h +++ /dev/null @@ -1,263 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ -#define DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ - -#include - -#include -#include -#include -#include -#if defined(_WIN32) || defined(_WIN64) -#undef HAVE_STDDEF_H -#undef HAVE_STDLIB_H -#endif -#include "./jpeglib.h" -#include "./jerror.h" -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 }; - -enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 }; - -void JpegErrorExitCustom(j_common_ptr cinfo); - -struct JpegErrorManagerCustom { - // "public" fields - struct jpeg_error_mgr pub; - // for return to caller - jmp_buf setjmp_buffer; -}; - -// Returns the interpolation mode in openCV format -// @param mode: interpolation mode in DE format -int GetCVInterpolationMode(InterpolationMode mode); - -// Returns the openCV equivalent of the border type used for padding. -// @param type -// @return -int GetCVBorderType(BorderType type); - -// Returns flipped image -// @param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param flip_code: 1 for Horizontal (around y-axis), 0 for Vertical (around x-axis), -1 for both -// The flipping happens in place. -Status Flip(std::shared_ptr input, std::shared_ptr *output, int flip_code); - -// Returns Horizontally flipped image -// @param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// The flipping happens in place. -Status HorizontalFlip(std::shared_ptr input, std::shared_ptr *output); - -// Returns Vertically flipped image -// @param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// The flipping happens in place. -Status VerticalFlip(std::shared_ptr input, std::shared_ptr *output); - -// Returns Resized image. -// @param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param output_height: height of output -// @param output_width: width of output -// @param fx: horizontal scale -// @param fy: vertical scale -// @param InterpolationMode: the interpolation mode -// @param output: Resized image of shape or -// and same type as input -Status Resize(const std::shared_ptr &input, std::shared_ptr *output, int32_t output_height, - int32_t output_width, double fx = 0.0, double fy = 0.0, - InterpolationMode mode = InterpolationMode::kLinear); - -// Returns Decoded image -// Supported images: -// BMP JPEG JPG PNG TIFF -// supported by opencv, if user need more image analysis capabilities, please compile opencv particularlly. -// @param input: CVTensor containing the not decoded image 1D bytes -// @param output: Decoded image Tensor of shape and type DE_UINT8. Pixel order is RGB -Status Decode(const std::shared_ptr &input, std::shared_ptr *output); - -Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *output); - -bool HasJpegMagic(const std::shared_ptr &input); - -void JpegSetSource(j_decompress_ptr c_info, const void *data, int64_t data_size); - -Status JpegCropAndDecode(const std::shared_ptr &input, std::shared_ptr *output, int x = 0, int y = 0, - int w = 0, int h = 0); -// Returns Rescaled image -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param rescale: rescale parameter -// @param shift: shift parameter -// @param output: Rescaled image Tensor of same input shape and type DE_FLOAT32 -Status Rescale(const std::shared_ptr &input, std::shared_ptr *output, float rescale, float shift); - -// Returns cropped ROI of an image -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param x: starting horizontal position of ROI -// @param y: starting vertical position of ROI -// @param w: width of the ROI -// @param h: height of the ROI -// @param output: Cropped image Tensor of shape or and same input type. -Status Crop(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, int w, int h); - -// Swaps the channels in the image, i.e. converts HWC to CHW -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param output: Tensor of shape or and same input type. -Status HwcToChw(std::shared_ptr input, std::shared_ptr *output); - -// Swap the red and blue pixels (RGB <-> BGR) -// @param input: Tensor of shape and any OpenCv compatible type, see CVTensor. -// @param output: Swapped image of same shape and type -Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output); - -// Crops and resizes the image -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param x: horizontal start point -// @param y: vertical start point -// @param crop_height: height of the cropped ROI -// @param crop_width: width of the cropped ROI -// @param target_width: width of the final resized image -// @param target_height: height of the final resized image -// @param InterpolationMode: the interpolation used in resize operation -// @param output: Tensor of shape or -// and same type as input -Status CropAndResize(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, - int crop_height, int crop_width, int target_height, int target_width, InterpolationMode mode); - -// Returns rotated image -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param fx: rotation center x coordinate -// @param fy: rotation center y coordinate -// @param degree: degree to rotate -// @param expand: if reshape is necessary -// @param output: rotated image of same input type. -Status Rotate(const std::shared_ptr &input, std::shared_ptr *output, float fx, float fy, float degree, - InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, bool expand = false, - uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); - -// Returns Normalized image -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order -// @param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order -// @param output: Normalized image Tensor of same input shape and type DE_FLOAT32 -Status Normalize(const std::shared_ptr &input, std::shared_ptr *output, - const std::shared_ptr &mean, const std::shared_ptr &std); - -// Returns image with adjusted brightness. -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param alpha: Alpha value to adjust brightness by. Should be a positive number. -// If user input one value in python, the range is [1 - value, 1 + value]. -// This will output original image multiplied by alpha. 0 gives a black image, 1 gives the -// original image while 2 increases the brightness by a factor of 2. -// @param output: Adjusted image of same shape and type. -Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); - -// Returns image with adjusted contrast. -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param alpha: Alpha value to adjust contrast by. Should be a positive number. -// If user input one value in python, the range is [1 - value, 1 + value]. -// 0 gives a solid gray image, 1 gives the original image while 2 increases -// the contrast by a factor of 2. -// @param output: Adjusted image of same shape and type. -Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); - -// Returns image with adjusted saturation. -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param alpha: Alpha value to adjust saturation by. Should be a positive number. -// If user input one value in python, the range is [1 - value, 1 + value]. -// 0 will give a black and white image, 1 will give the original image while -// 2 will enhance the saturation by a factor of 2. -// @param output: Adjusted image of same shape and type. -Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); - -// Returns image with adjusted hue. -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param hue: Hue value to adjust by, should be within range [-0.5, 0.5]. 0.5 and - 0.5 will reverse the hue channel -// completely. -// If user input one value in python, the range is [-value, value]. -// @param output: Adjusted image of same shape and type. -Status AdjustHue(const std::shared_ptr &input, std::shared_ptr *output, const float &hue); - -// Masks out a random section from the image with set dimension -// @param input: input Tensor -// @param output: cutOut Tensor -// @param box_height: height of the cropped box -// @param box_width: width of the cropped box -// @param num_patches: number of boxes to cut out from the image -// @param bounded: boolean flag to toggle between random erasing and cutout -// @param random_color: whether or not random fill value should be used -// @param fill_r: red fill value for erase -// @param fill_g: green fill value for erase -// @param fill_b: blue fill value for erase. -Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, - int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, - uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); - -// Pads the input image and puts the padded image in the output -// @param input: input Tensor -// @param output: padded Tensor -// @param pad_top: amount of padding done in top -// @param pad_bottom: amount of padding done in bottom -// @param pad_left: amount of padding done in left -// @param pad_right: amount of padding done in right -// @param border_types: the interpolation to be done in the border -// @param fill_r: red fill value for pad -// @param fill_g: green fill value for pad -// @param fill_b: blue fill value for pad. -Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, - const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, - uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); - -// -------- BBOX OPERATIONS -------- // -// Updates and checks bounding boxes for new cropped region of image -// @param bboxList: A tensor contaning bounding box tensors -// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop -// @param CB_Xmin: Image's CropBox Xmin coordinate -// @param CB_Xmin: Image's CropBox Ymin coordinate -// @param CB_Xmax: Image's CropBox Xmax coordinate - (Xmin + width) -// @param CB_Xmax: Image's CropBox Ymax coordinate - (Ymin + height) -Status UpdateBBoxesForCrop(std::shared_ptr *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax, - int CB_Ymax); - -// Updates bounding boxes with required Top and Left padding -// Top and Left padding amounts required to adjust bboxs min X,Y values according to padding 'push' -// Top/Left since images 0,0 coordinate is taken from top left -// @param bboxList: A tensor contaning bounding box tensors -// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop -// @param pad_top: Total amount of padding applied to image top -// @param pad_left: Total amount of padding applied to image left side -Status PadBBoxes(const std::shared_ptr *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left); - -// Updates bounding boxes for an Image Resize Operation - Takes in set of valid BBoxes -// For e.g those that remain after a crop -// @param bboxList: A tensor contaning bounding box tensors -// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop -// @param bboxList: A tensor contaning bounding box tensors -// @param target_width_: required width of image post resize -// @param target_width_: required height of image post resize -// @param orig_width: current width of image pre resize -// @param orig_height: current height of image pre resize -Status UpdateBBoxesForResize(const std::shared_ptr &bboxList, const size_t &bboxCount, int32_t target_width_, - int32_t target_height_, int orig_width, int orig_height); - -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/normalize_op.cc b/mindspore/ccsrc/dataset/kernels/image/normalize_op.cc deleted file mode 100644 index 638eaad264..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/normalize_op.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/normalize_op.h" - -#include - -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -NormalizeOp::NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b) { - int size[] = {3}; - cv::Mat mean_cv(1, size, CV_32F); - mean_cv.at(0) = mean_r; - mean_cv.at(1) = mean_g; - mean_cv.at(2) = mean_b; - mean_ = std::make_shared(mean_cv); - mean_->Squeeze(); - - cv::Mat std_cv(1, size, CV_32F); - std_cv.at(0) = std_r; - std_cv.at(1) = std_g; - std_cv.at(2) = std_b; - std_ = std::make_shared(std_cv); - std_->Squeeze(); -} - -Status NormalizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - // Doing the normalization - return Normalize(input, output, mean_, std_); -} - -void NormalizeOp::Print(std::ostream &out) const { - out << "NormalizeOp, mean: " << mean_->mat().at(0) << ", " << mean_->mat().at(1) << ", " - << mean_->mat().at(2) << "std: " << std_->mat().at(0) << ", " << std_->mat().at(1) << ", " - << std_->mat().at(2) << std::endl; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/normalize_op.h b/mindspore/ccsrc/dataset/kernels/image/normalize_op.h deleted file mode 100644 index 7aa6fa69bd..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/normalize_op.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ - -#include - -#include "dataset/core/cv_tensor.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class NormalizeOp : public TensorOp { - public: - NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b); - - ~NormalizeOp() override = default; - - void Print(std::ostream &out) const override; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - std::shared_ptr mean_; - std::shared_ptr std_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/pad_op.cc b/mindspore/ccsrc/dataset/kernels/image/pad_op.cc deleted file mode 100644 index b4d9c2bbf0..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/pad_op.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/pad_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const BorderType PadOp::kDefBorderType = BorderType::kConstant; -const uint8_t PadOp::kDefFillR = 0; -const uint8_t PadOp::kDefFillG = 0; -const uint8_t PadOp::kDefFillB = 0; - -PadOp::PadOp(int32_t pad_top, int32_t pad_bottom, int32_t pad_left, int32_t pad_right, BorderType border_types, - uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) - : pad_top_(pad_top), - pad_bottom_(pad_bottom), - pad_left_(pad_left), - pad_right_(pad_right), - boarder_type_(border_types), - fill_r_(fill_r), - fill_g_(fill_g), - fill_b_(fill_b) {} - -Status PadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - return Pad(input, output, pad_top_, pad_bottom_, pad_left_, pad_right_, boarder_type_, fill_r_, fill_g_, fill_b_); -} - -Status PadOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out({-1, -1, 3}); // we don't know what is output image size, but we know it should be 3 channels - if (inputs[0].Rank() == 1) outputs.emplace_back(out); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/pad_op.h b/mindspore/ccsrc/dataset/kernels/image/pad_op.h deleted file mode 100644 index 76d99d0162..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/pad_op.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_PAD_OP_H_ -#define DATASET_KERNELS_IMAGE_PAD_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class PadOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const BorderType kDefBorderType; - static const uint8_t kDefFillR; - static const uint8_t kDefFillG; - static const uint8_t kDefFillB; - - // Constructor for PadOp. - // @param pad_top number of pixels to pad the top of image with. - // @param pad_bottom number of pixels to pad the bottom of the image with. - // @param pad_left number of pixels to pad the left of the image with. - // @param pad_right number of pixels to pad the right of the image with. - // @param border_types BorderType enum, the type of boarders that we are using. - // @param fill_r R value for the color to pad with. - // @param fill_g G value for the color to pad with. - // @param fill_b B value for the color to pad with. - PadOp(int32_t pad_top, int32_t pad_bottom, int32_t pad_left, int32_t pad_right, BorderType border_types, - uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); - - ~PadOp() override = default; - - void Print(std::ostream &out) const override { out << "PadOp: "; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - private: - int32_t pad_top_; - int32_t pad_bottom_; - int32_t pad_left_; - int32_t pad_right_; - BorderType boarder_type_; - uint8_t fill_r_; - uint8_t fill_g_; - uint8_t fill_b_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_PAD_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.cc deleted file mode 100644 index e420f86e9a..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_color_adjust_op.h" - -#include - -#include "dataset/core/config_manager.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_factor, float s_contrast_factor, - float e_contrast_factor, float s_saturation_factor, float e_saturation_factor, - float s_hue_factor, float e_hue_factor) - : bright_factor_start_(s_bright_factor), - bright_factor_end_(e_bright_factor), - contrast_factor_start_(s_contrast_factor), - contrast_factor_end_(e_contrast_factor), - saturation_factor_start_(s_saturation_factor), - saturation_factor_end_(e_saturation_factor), - hue_factor_start_(s_hue_factor), - hue_factor_end_(e_hue_factor) { - rnd_.seed(GetSeed()); -} - -Status RandomColorAdjustOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - - // randomly select an augmentation to apply to the input image until all the transformations run - std::vector params_vector = {"brightness", "contrast", "saturation", "hue"}; - - std::shuffle(params_vector.begin(), params_vector.end(), rnd_); - - *output = std::static_pointer_cast(input); - // determine if certain augmentation needs to be executed: - for (const auto ¶m : params_vector) { - // case switch - if (param == "brightness") { - if (CmpFloat(bright_factor_start_, bright_factor_end_) && CmpFloat(bright_factor_start_, 1.0f)) { - MS_LOG(DEBUG) << "Not running brightness."; - } else { - // adjust the brightness of an image - float random_factor = std::uniform_real_distribution(bright_factor_start_, bright_factor_end_)(rnd_); - RETURN_IF_NOT_OK(AdjustBrightness(*output, output, random_factor)); - } - } else if (param == "contrast") { - if (CmpFloat(contrast_factor_start_, contrast_factor_end_) && CmpFloat(contrast_factor_start_, 1.0f)) { - MS_LOG(DEBUG) << "Not running contrast."; - } else { - float random_factor = std::uniform_real_distribution(contrast_factor_start_, contrast_factor_end_)(rnd_); - RETURN_IF_NOT_OK(AdjustContrast(*output, output, random_factor)); - } - } else if (param == "saturation") { - // adjust the Saturation of an image - if (CmpFloat(saturation_factor_start_, saturation_factor_end_) && CmpFloat(saturation_factor_start_, 1.0f)) { - MS_LOG(DEBUG) << "Not running saturation."; - } else { - float random_factor = - std::uniform_real_distribution(saturation_factor_start_, saturation_factor_end_)(rnd_); - RETURN_IF_NOT_OK(AdjustSaturation(*output, output, random_factor)); - } - } else if (param == "hue") { - if (CmpFloat(hue_factor_start_, hue_factor_end_) && CmpFloat(hue_factor_start_, 0.0f)) { - MS_LOG(DEBUG) << "Not running hue."; - } else { - // adjust the Hue of an image - float random_factor = std::uniform_real_distribution(hue_factor_start_, hue_factor_end_)(rnd_); - RETURN_IF_NOT_OK(AdjustHue(*output, output, random_factor)); - } - } - } - // now after we do all the transformations, the last one is fine - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.h b/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.h deleted file mode 100644 index 74d1ec450b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomColorAdjustOp : public TensorOp { - public: - static const uint32_t kDefSeed; - - // Constructor for RandomColorAdjustOp. - // @param s_bright_factor brightness change range start value. - // @param e_bright_factor brightness change range end value. - // @param s_contrast_factor contrast change range start value. - // @param e_contrast_factor contrast change range start value. - // @param s_saturation_factor saturation change range end value. - // @param e_saturation_factor saturation change range end value. - // @param s_hue_factor hue change factor start value, this should be greater than -0.5. - // @param e_hue_factor hue change factor start value, this should be less than 0.5. - // @param seed optional seed to pass in to the constructor. - // @details the randomly chosen degree is uniformly distributed. - RandomColorAdjustOp(float s_bright_factor, float e_bright_factor, float s_contrast_factor, float e_contrast_factor, - float s_saturation_factor, float e_saturation_factor, float s_hue_factor, float e_hue_factor); - - ~RandomColorAdjustOp() override = default; - - // Print function for RandomJitter. - // @param out output stream to print to. - void Print(std::ostream &out) const override { out << "RandomColorAdjustOp: "; } - - // Overrides the base class compute function. - // Calls multiple transform functions in ImageUtils, this function takes an input tensor. - // and transforms its data using openCV, the output memory is manipulated to contain the result. - // @return Status - The error code return. - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - std::mt19937 rnd_; - float bright_factor_start_; - float bright_factor_end_; - float contrast_factor_start_; - float contrast_factor_end_; - float saturation_factor_start_; - float saturation_factor_end_; - float hue_factor_start_; - float hue_factor_end_; - // Compare two floating point variables. Return true if they are same / very close. - inline bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f) const { - return (std::fabs(a - b) < epsilon); - } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc deleted file mode 100644 index c5b5f20c63..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_crop_and_resize_op.h" -#include - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float RandomCropAndResizeOp::kDefScaleLb = 0.08; -const float RandomCropAndResizeOp::kDefScaleUb = 1.0; -const float RandomCropAndResizeOp::kDefAspectLb = 0.75; -const float RandomCropAndResizeOp::kDefAspectUb = 1.333333; -const InterpolationMode RandomCropAndResizeOp::kDefInterpolation = InterpolationMode::kLinear; -const int32_t RandomCropAndResizeOp::kDefMaxIter = 10; - -RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t target_width, float scale_lb, - float scale_ub, float aspect_lb, float aspect_ub, - InterpolationMode interpolation, int32_t max_iter) - : target_height_(target_height), - target_width_(target_width), - rnd_scale_(scale_lb, scale_ub), - rnd_aspect_(log(aspect_lb), log(aspect_ub)), - interpolation_(interpolation), - aspect_lb_(aspect_lb), - aspect_ub_(aspect_ub), - max_iter_(max_iter) { - rnd_.seed(GetSeed()); -} - -Status RandomCropAndResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); - - int h_in = input->shape()[0]; - int w_in = input->shape()[1]; - int x = 0; - int y = 0; - int crop_height = 0; - int crop_width = 0; - (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); - return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); -} -Status RandomCropAndResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out = TensorShape{target_height_, target_width_}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { - *crop_width = w_in; - *crop_height = h_in; - CHECK_FAIL_RETURN_UNEXPECTED(w_in != 0, "Width is 0"); - CHECK_FAIL_RETURN_UNEXPECTED(h_in != 0, "Height is 0"); - CHECK_FAIL_RETURN_UNEXPECTED(aspect_lb_ > 0, "Aspect lower bound must be greater than zero"); - for (int32_t i = 0; i < max_iter_; i++) { - double const sample_scale = rnd_scale_(rnd_); - // In case of non-symmetrical aspect ratios, use uniform distribution on a logarithmic sample_scale. - // Note rnd_aspect_ is already a random distribution of the input aspect ratio in logarithmic sample_scale. - double const sample_aspect = exp(rnd_aspect_(rnd_)); - - *crop_width = static_cast(std::round(std::sqrt(h_in * w_in * sample_scale * sample_aspect))); - *crop_height = static_cast(std::round(*crop_width / sample_aspect)); - if (*crop_width <= w_in && *crop_height <= h_in) { - std::uniform_int_distribution<> rd_x(0, w_in - *crop_width); - std::uniform_int_distribution<> rd_y(0, h_in - *crop_height); - *x = rd_x(rnd_); - *y = rd_y(rnd_); - return Status::OK(); - } - } - double const img_aspect = static_cast(w_in) / h_in; - if (img_aspect < aspect_lb_) { - *crop_width = w_in; - *crop_height = static_cast(std::round(*crop_width / static_cast(aspect_lb_))); - } else { - if (img_aspect > aspect_ub_) { - *crop_height = h_in; - *crop_width = static_cast(std::round(*crop_height * static_cast(aspect_ub_))); - } else { - *crop_width = w_in; - *crop_height = h_in; - } - } - *x = static_cast(std::round((w_in - *crop_width) / 2.0)); - *y = static_cast(std::round((h_in - *crop_height) / 2.0)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.h deleted file mode 100644 index db805a9374..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomCropAndResizeOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefScaleLb; - static const float kDefScaleUb; - static const float kDefAspectLb; - static const float kDefAspectUb; - static const InterpolationMode kDefInterpolation; - static const int32_t kDefMaxIter; - - RandomCropAndResizeOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, - float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb, - InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter); - - ~RandomCropAndResizeOp() override = default; - - void Print(std::ostream &out) const override { - out << "RandomCropAndResize: " << target_height_ << " " << target_width_; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - Status GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width); - - protected: - int32_t target_height_; - int32_t target_width_; - std::uniform_real_distribution rnd_scale_; - std::uniform_real_distribution rnd_aspect_; - std::mt19937 rnd_; - InterpolationMode interpolation_; - int32_t max_iter_; - double aspect_lb_; - double aspect_ub_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc deleted file mode 100644 index fbaf2c9326..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" - -namespace mindspore { -namespace dataset { - -Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); - - output->resize(2); - (*output)[1] = std::move(input[1]); // move boxes over to output - - size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor - int h_in = input[0]->shape()[0]; - int w_in = input[0]->shape()[1]; - int x = 0; - int y = 0; - int crop_height = 0; - int crop_width = 0; - - RETURN_IF_NOT_OK(RandomCropAndResizeOp::GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width)); - - int maxX = x + crop_width; // max dims of selected CropBox on image - int maxY = y + crop_height; - - RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &bboxCount, x, y, maxX, maxY)); // IMAGE_UTIL - RETURN_IF_NOT_OK(CropAndResize(input[0], &(*output)[0], x, y, crop_height, crop_width, target_height_, target_width_, - interpolation_)); - - RETURN_IF_NOT_OK( - UpdateBBoxesForResize((*output)[1], bboxCount, target_width_, target_height_, crop_width, crop_height)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h deleted file mode 100644 index 9675d43933..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ - -#include "dataset/kernels/image/random_crop_and_resize_op.h" - -namespace mindspore { -namespace dataset { - -class RandomCropAndResizeWithBBoxOp : public RandomCropAndResizeOp { - public: - // Constructor for RandomCropAndResizeWithBBoxOp, with default value and passing to base class constructor - RandomCropAndResizeWithBBoxOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, - float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, - float aspect_ub = kDefAspectUb, InterpolationMode interpolation = kDefInterpolation, - int32_t max_iter = kDefMaxIter) - : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub, interpolation, - max_iter) {} - - ~RandomCropAndResizeWithBBoxOp() override = default; - - void Print(std::ostream &out) const override { - out << "RandomCropAndResizeWithBBox: " << RandomCropAndResizeOp::target_height_ << " " - << RandomCropAndResizeOp::target_width_; - } - - Status Compute(const TensorRow &input, TensorRow *output) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.cc deleted file mode 100644 index 74aa91ea7e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_crop_decode_resize_op.h" -#include -#include "dataset/kernels/image/image_utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/kernels/image/decode_op.h" - -namespace mindspore { -namespace dataset { -RandomCropDecodeResizeOp::RandomCropDecodeResizeOp(int32_t target_height, int32_t target_width, float scale_lb, - float scale_ub, float aspect_lb, float aspect_ub, - InterpolationMode interpolation, int32_t max_iter) - : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub, interpolation, - max_iter) {} - -Status RandomCropDecodeResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - if (input == nullptr) { - RETURN_STATUS_UNEXPECTED("input tensor is null"); - } - if (!HasJpegMagic(input)) { - DecodeOp op(true); - std::shared_ptr decoded; - RETURN_IF_NOT_OK(op.Compute(input, &decoded)); - return RandomCropAndResizeOp::Compute(decoded, output); - } else { - struct jpeg_decompress_struct cinfo {}; - struct JpegErrorManagerCustom jerr {}; - cinfo.err = jpeg_std_error(&jerr.pub); - jerr.pub.error_exit = JpegErrorExitCustom; - try { - jpeg_create_decompress(&cinfo); - JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes()); - (void)jpeg_read_header(&cinfo, TRUE); - jpeg_calc_output_dimensions(&cinfo); - } catch (std::runtime_error &e) { - jpeg_destroy_decompress(&cinfo); - RETURN_STATUS_UNEXPECTED(e.what()); - } - int h_in = cinfo.output_height; - int w_in = cinfo.output_width; - jpeg_destroy_decompress(&cinfo); - - int x = 0; - int y = 0; - int crop_height = 0; - int crop_width = 0; - (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); - - std::shared_ptr decoded; - RETURN_IF_NOT_OK(JpegCropAndDecode(input, &decoded, x, y, crop_width, crop_height)); - return Resize(decoded, output, target_height_, target_width_, 0.0, 0.0, interpolation_); - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.h deleted file mode 100644 index 9566169946..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ - -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/random_crop_and_resize_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomCropDecodeResizeOp : public RandomCropAndResizeOp { - public: - RandomCropDecodeResizeOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, - float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb, - InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter); - - ~RandomCropDecodeResizeOp() override = default; - - void Print(std::ostream &out) const override { - out << "RandomCropDecodeResize: " << RandomCropAndResizeOp::target_height_ << " " - << RandomCropAndResizeOp::target_width_; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_op.cc deleted file mode 100644 index 110d769f26..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_crop_op.h" -#include -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t RandomCropOp::kDefPadTop = 0; -const int32_t RandomCropOp::kDefPadBottom = 0; -const int32_t RandomCropOp::kDefPadLeft = 0; -const int32_t RandomCropOp::kDefPadRight = 0; -const BorderType RandomCropOp::kDefBorderType = BorderType::kConstant; -const bool RandomCropOp::kDefPadIfNeeded = false; -const uint8_t RandomCropOp::kDefFillR = 0; -const uint8_t RandomCropOp::kDefFillG = 0; -const uint8_t RandomCropOp::kDefFillB = 0; - -RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_top, int32_t pad_bottom, - int32_t pad_left, int32_t pad_right, BorderType border_types, bool pad_if_needed, - uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) - : crop_height_(crop_height), - crop_width_(crop_width), - pad_top_(pad_top), - pad_bottom_(pad_bottom), - pad_left_(pad_left), - pad_right_(pad_right), - pad_if_needed_(pad_if_needed), - border_type_(border_types), - fill_r_(fill_r), - fill_g_(fill_g), - fill_b_(fill_b) { - rnd_.seed(GetSeed()); -} - -Status RandomCropOp::ImagePadding(const std::shared_ptr &input, std::shared_ptr *pad_image, - int32_t *t_pad_top, int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right, - int32_t *padded_image_w, int32_t *padded_image_h, bool *crop_further) { - *t_pad_top = pad_top_; - *t_pad_bottom = pad_bottom_; - *t_pad_left = pad_left_; - *t_pad_right = pad_right_; - - RETURN_IF_NOT_OK( - Pad(input, pad_image, pad_top_, pad_bottom_, pad_left_, pad_right_, border_type_, fill_r_, fill_g_, fill_b_)); - CHECK_FAIL_RETURN_UNEXPECTED((*pad_image)->shape().Size() >= 2, "Abnormal shape"); - - *padded_image_h = (*pad_image)->shape()[0]; - *padded_image_w = (*pad_image)->shape()[1]; - - if (*padded_image_h == crop_height_ && *padded_image_w == crop_width_) { - *crop_further = false; // no need for further crop - return Status::OK(); - } else if (pad_if_needed_) { - // check the dimensions of the image for padding, if we do need padding, then we change the pad values - if (*padded_image_h < crop_height_) { - RETURN_IF_NOT_OK(Pad(*pad_image, pad_image, crop_height_ - *padded_image_h, crop_height_ - *padded_image_h, 0, 0, - border_type_, fill_r_, fill_g_, fill_b_)); - - // update pad total above/below - t_pad_top += (crop_height_ - *padded_image_h); - t_pad_bottom += (crop_height_ - *padded_image_h); - } - if (*padded_image_w < crop_width_) { - RETURN_IF_NOT_OK(Pad(*pad_image, pad_image, 0, 0, crop_width_ - *padded_image_w, crop_width_ - *padded_image_w, - border_type_, fill_r_, fill_g_, fill_b_)); - // update pad total left/right - t_pad_left += (crop_width_ - *padded_image_w); - t_pad_right += (crop_width_ - *padded_image_w); - } - *padded_image_h = (*pad_image)->shape()[0]; - *padded_image_w = (*pad_image)->shape()[1]; - } - - if (*padded_image_h < crop_height_ || *padded_image_w < crop_width_ || crop_height_ == 0 || crop_width_ == 0) { - return Status(StatusCode::kShapeMisMatch, __LINE__, __FILE__, - "Crop size is greater than the image dimensions or is zero."); - } - return Status::OK(); -} - -void RandomCropOp::GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h) { - // GenCropPoints for cropping - *x = std::uniform_int_distribution(0, padded_image_w - crop_width_)(rnd_); - *y = std::uniform_int_distribution(0, padded_image_h - crop_height_)(rnd_); -} - -Status RandomCropOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - - // Apply padding first then crop - std::shared_ptr pad_image; - int32_t t_pad_top, t_pad_bottom, t_pad_left, t_pad_right; - int32_t padded_image_w; - int32_t padded_image_h; - bool crop_further = true; // whether image needs further cropping based on new size & requirements - - RETURN_IF_NOT_OK( // error code sent back directly - ImagePadding(input, &pad_image, &t_pad_top, &t_pad_bottom, &t_pad_left, &t_pad_right, &padded_image_w, - &padded_image_h, &crop_further)); - if (!crop_further) { - *output = pad_image; - return Status::OK(); - } - - int x, y; - GenRandomXY(&x, &y, padded_image_w, padded_image_h); - return Crop(pad_image, output, x, y, crop_width_, crop_height_); -} - -Status RandomCropOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out = TensorShape{crop_height_, crop_width_}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_op.h deleted file mode 100644 index cd43ec1efb..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_op.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomCropOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefPadTop; - static const int32_t kDefPadBottom; - static const int32_t kDefPadLeft; - static const int32_t kDefPadRight; - static const BorderType kDefBorderType; - static const bool kDefPadIfNeeded; - static const uint8_t kDefFillR; - static const uint8_t kDefFillG; - static const uint8_t kDefFillB; - - RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_top = kDefPadTop, - int32_t pad_bottom = kDefPadBottom, int32_t pad_left = kDefPadLeft, int32_t pad_right = kDefPadRight, - BorderType border_types = kDefBorderType, bool pad_if_needed = kDefPadIfNeeded, - uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); - - ~RandomCropOp() override = default; - - void Print(std::ostream &out) const override { out << "RandomCropOp: " << crop_height_ << " " << crop_width_; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - // Function breaks out the compute function's image padding functionality and makes available to other Ops - // Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op - // @param input: Input is the original Image - // @param pad_image: Pointer to new Padded image - // @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required - // @param t_pad_bottom: Total bottom Padding - Based on input and value calculated in function if required - // @param t_pad_left: Total left Padding - Based on input and value calculated in function if required - // @param t_pad_right: Total right Padding - Based on input and value calculated in function if required - // @param padded_image_w: Final Width of the 'pad_image' - // @param padded_image_h: Final Height of the 'pad_image' - // @param crop_further: Whether image required cropping after padding - False if new padded image matches required - // dimensions - Status ImagePadding(const std::shared_ptr &input, std::shared_ptr *pad_image, int32_t *t_pad_top, - int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right, int32_t *padded_image_w, - int32_t *padded_image_h, bool *crop_further); - - // Function breaks X,Y generation functionality out of original compute function and makes available to other Ops - void GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h); - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - protected: - int32_t crop_height_ = 0; - int32_t crop_width_ = 0; - - private: - int32_t pad_top_ = 0; - int32_t pad_bottom_ = 0; - int32_t pad_left_ = 0; - int32_t pad_right_ = 0; - bool pad_if_needed_ = false; - BorderType border_type_; - uint8_t fill_r_ = 0; - uint8_t fill_g_ = 0; - uint8_t fill_b_ = 0; - std::mt19937 rnd_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc deleted file mode 100644 index c873307afd..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "dataset/kernels/image/random_crop_with_bbox_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - - std::shared_ptr pad_image; - int32_t t_pad_top, t_pad_bottom, t_pad_left, t_pad_right; - size_t boxCount = input[1]->shape()[0]; // number of rows - - int32_t padded_image_h; - int32_t padded_image_w; - - output->resize(2); - (*output)[1] = std::move(input[1]); // since some boxes may be removed - - bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches - RETURN_IF_NOT_OK( // Error passed back to caller - RandomCropOp::ImagePadding(input[0], &pad_image, &t_pad_top, &t_pad_bottom, &t_pad_left, &t_pad_right, - &padded_image_w, &padded_image_h, &crop_further)); - - // update bounding boxes with new values based on relevant image padding - if (t_pad_left || t_pad_bottom) { - RETURN_IF_NOT_OK(PadBBoxes(&(*output)[1], boxCount, t_pad_left, t_pad_top)); - } - if (!crop_further) { - // no further cropping required - (*output)[0] = pad_image; - (*output)[1] = std::move(input[1]); - return Status::OK(); - } - - int x, y; - RandomCropOp::GenRandomXY(&x, &y, padded_image_w, padded_image_h); - int maxX = x + RandomCropOp::crop_width_; // max dims of selected CropBox on image - int maxY = y + RandomCropOp::crop_height_; - RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &boxCount, x, y, maxX, maxY)); - return Crop(pad_image, &(*output)[0], x, y, RandomCropOp::crop_width_, RandomCropOp::crop_height_); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.h deleted file mode 100644 index 88a58d3557..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ - -#include -#include - -#include "dataset/kernels/image/random_crop_op.h" - -namespace mindspore { -namespace dataset { -class RandomCropWithBBoxOp : public RandomCropOp { - public: - // Constructor for RandomCropWithBBoxOp, with default value and passing to base class constructor - RandomCropWithBBoxOp(int32_t crop_height, int32_t crop_width, int32_t pad_top = kDefPadTop, - int32_t pad_bottom = kDefPadBottom, int32_t pad_left = kDefPadLeft, - int32_t pad_right = kDefPadRight, BorderType border_types = kDefBorderType, - bool pad_if_needed = kDefPadIfNeeded, uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, - uint8_t fill_b = kDefFillB) - : RandomCropOp(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, border_types, pad_if_needed, - fill_r, fill_g, fill_b) {} - - ~RandomCropWithBBoxOp() override = default; - - void Print(std::ostream &out) const override { - out << "RandomCropWithBBoxOp: " << RandomCropOp::crop_height_ << " " << RandomCropOp::crop_width_; - } - - Status Compute(const TensorRow &input, TensorRow *output) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc deleted file mode 100644 index 5a5c632e81..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "dataset/kernels/image/random_horizontal_flip_bbox_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/pybind_support.h" - -namespace mindspore { -namespace dataset { -const float RandomHorizontalFlipWithBBoxOp::kDefProbability = 0.5; - -Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - if (distribution_(rnd_)) { - // To test bounding boxes algorithm, create random bboxes from image dims - size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes - float img_center = (input[0]->shape()[1] / 2.); // get the center of the image - - for (int i = 0; i < num_of_boxes; i++) { - uint32_t b_w = 0; // bounding box width - uint32_t min_x = 0; - // get the required items - input[1]->GetItemAt(&min_x, {i, 0}); - input[1]->GetItemAt(&b_w, {i, 2}); - // do the flip - float diff = img_center - min_x; // get distance from min_x to center - uint32_t refl_min_x = diff + img_center; // get reflection of min_x - uint32_t new_min_x = refl_min_x - b_w; // subtract from the reflected min_x to get the new one - input[1]->SetItemAt({i, 0}, new_min_x); - } - (*output).push_back(nullptr); - (*output).push_back(nullptr); - // move input to output pointer of bounding boxes - (*output)[1] = std::move(input[1]); - // perform HorizontalFlip on the image - std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input[0])); - return HorizontalFlip(std::static_pointer_cast(input_cv), &(*output)[0]); - } - *output = input; - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.h deleted file mode 100644 index 06c96e11ae..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_bbox_op.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ - -#include -#include -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl_bind.h" - -namespace mindspore { -namespace dataset { -class RandomHorizontalFlipWithBBoxOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefProbability; - - explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { - rnd_.seed(GetSeed()); - } - - ~RandomHorizontalFlipWithBBoxOp() override = default; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const RandomHorizontalFlipWithBBoxOp &so) { - so.Print(out); - return out; - } - - void Print(std::ostream &out) const override { out << "RandomHorizontalFlipWithBBoxOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - private: - std::mt19937 rnd_; - std::bernoulli_distribution distribution_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.cc deleted file mode 100644 index ae76e1bf59..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_horizontal_flip_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float RandomHorizontalFlipOp::kDefProbability = 0.5; - -Status RandomHorizontalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (distribution_(rnd_)) { - return HorizontalFlip(input, output); - } - *output = input; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.h b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.h deleted file mode 100644 index efea124533..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomHorizontalFlipOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefProbability; - - explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) { - rnd_.seed(GetSeed()); - } - - ~RandomHorizontalFlipOp() override = default; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const RandomHorizontalFlipOp &so) { - so.Print(out); - return out; - } - - void Print(std::ostream &out) const override { out << "RandomHorizontalFlipOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - std::mt19937 rnd_; - std::bernoulli_distribution distribution_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_resize_op.cc deleted file mode 100644 index c14224a930..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_resize_op.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_resize_op.h" - -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t RandomResizeOp::kDefTargetWidth = 0; - -Status RandomResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - // Randomly selects from the following four interpolation methods - // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area - interpolation_ = static_cast(distribution_(random_generator_)); - return ResizeOp::Compute(input, output); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_op.h b/mindspore/ccsrc/dataset/kernels/image/random_resize_op.h deleted file mode 100644 index af23803d4c..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_resize_op.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomResizeOp : public ResizeOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefTargetWidth; - - explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) { - random_generator_.seed(GetSeed()); - } - - ~RandomResizeOp() = default; - - // Description: A function that prints info about the node - void Print(std::ostream &out) const override { - out << "RandomResizeOp: " << ResizeOp::size1_ << " " << ResizeOp::size2_; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - std::mt19937 random_generator_; - std::uniform_int_distribution distribution_{0, 3}; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc deleted file mode 100644 index de69c02e39..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/image/random_resize_with_bbox_op.h" -#include "dataset/kernels/image/resize_with_bbox_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t RandomResizeWithBBoxOp::kDefTargetWidth = 0; - -Status RandomResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - // Randomly selects from the following four interpolation methods - // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area - interpolation_ = static_cast(distribution_(random_generator_)); - RETURN_IF_NOT_OK(ResizeWithBBoxOp::Compute(input, output)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h deleted file mode 100644 index 4a7614525f..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H -#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/image/resize_with_bbox_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefTargetWidth; - explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { - random_generator_.seed(GetSeed()); - } - - ~RandomResizeWithBBoxOp() = default; - - // Description: A function that prints info about the node - void Print(std::ostream &out) const override { - out << "RandomResizeWithBBoxOp: " << ResizeWithBBoxOp::size1_ << " " << ResizeWithBBoxOp::size2_; - } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - private: - std::mt19937 random_generator_; - std::uniform_int_distribution distribution_{0, 3}; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.cc deleted file mode 100644 index 65e024865b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.cc +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_rotation_op.h" - -#include - -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float RandomRotationOp::kDefCenterX = -1; -const float RandomRotationOp::kDefCenterY = -1; -const InterpolationMode RandomRotationOp::kDefInterpolation = InterpolationMode::kNearestNeighbour; -const bool RandomRotationOp::kDefExpand = false; -const uint8_t RandomRotationOp::kDefFillR = 0; -const uint8_t RandomRotationOp::kDefFillG = 0; -const uint8_t RandomRotationOp::kDefFillB = 0; - -// constructor -RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float center_x, float center_y, - InterpolationMode interpolation, bool expand, uint8_t fill_r, uint8_t fill_g, - uint8_t fill_b) - : degree_start_(start_degree), - degree_end_(end_degree), - center_x_(center_x), - center_y_(center_y), - interpolation_(interpolation), - expand_(expand), - fill_r_(fill_r), - fill_g_(fill_g), - fill_b_(fill_b) { - rnd_.seed(GetSeed()); -} - -// main function call for random rotation : Generate the random degrees -Status RandomRotationOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - float random_double = distribution_(rnd_); - // get the degree rotation range, mod by 360 because full rotation doesn't affect - // the way this op works (uniform distribution) - // assumption here is that mDegreesEnd > mDegreeStart so we always get positive number - // Note: the range technically is greater than 360 degrees, but will be halved - float degree_range = (degree_end_ - degree_start_) / 2; - float mid = (degree_end_ + degree_start_) / 2; - float degree = mid + random_double * degree_range; - - return Rotate(input, output, center_x_, center_y_, degree, interpolation_, expand_, fill_r_, fill_g_, fill_b_); -} -Status RandomRotationOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - int32_t outputH = -1, outputW = -1; - // if expand_, then we cannot know the shape. We need the input image to find the output shape --> set it to - // <-1,-1[,3]> - if (!expand_) { - outputH = inputs[0][0]; - outputW = inputs[0][1]; - } - TensorShape out = TensorShape{outputH, outputW}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.h b/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.h deleted file mode 100644 index d30cd24288..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.h +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/kernels/image/image_utils.h" - -namespace mindspore { -namespace dataset { -class RandomRotationOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefCenterX; - static const float kDefCenterY; - static const InterpolationMode kDefInterpolation; - static const bool kDefExpand; - static const uint8_t kDefFillR; - static const uint8_t kDefFillG; - static const uint8_t kDefFillB; - - // Constructor for RandomRotationOp - // @param startDegree starting range for random degree - // @param endDegree ending range for random degree - // @param centerX x coordinate for center of image rotation - // @param centerY y coordinate for center of image rotation - // @param interpolation DE interpolation mode for rotation - // @param expand option for the output image shape to change - // @param fill_r R value for the color to pad with - // @param fill_g G value for the color to pad with - // @param fill_b B value for the color to pad with - // @details the randomly chosen degree is uniformly distributed - // @details the output shape, if changed, will contain the entire rotated image - // @note maybe using unsigned long int isn't the best here according to our coding rules - RandomRotationOp(float start_degree, float end_degree, float center_x = kDefCenterX, float center_y = kDefCenterY, - InterpolationMode interpolation = kDefInterpolation, bool expand = kDefExpand, - uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); - - ~RandomRotationOp() override = default; - - // Print function for RandomRotation - // @param out output stream to print to - void Print(std::ostream &out) const override { out << "RandomRotationOp: "; } - - // Overrides the base class compute function - // Calls the rotate function in ImageUtils, this function takes an input tensor - // and transforms its data using openCV, the output memory is manipulated to contain the result - // @return Status - The error code return - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - private: - float degree_start_; - float degree_end_; - float center_x_; - float center_y_; - InterpolationMode interpolation_; - bool expand_; - uint8_t fill_r_; - uint8_t fill_g_; - uint8_t fill_b_; - std::uniform_real_distribution distribution_{-1.0, 1.0}; - std::mt19937 rnd_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.cc deleted file mode 100644 index 096923a9ec..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/image/random_vertical_flip_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float RandomVerticalFlipOp::kDefProbability = 0.5; - -Status RandomVerticalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (distribution_(rnd_)) { - return VerticalFlip(input, output); - } - *output = input; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.h b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.h deleted file mode 100644 index 18693bc0eb..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -class RandomVerticalFlipOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefProbability; - - explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) { - rnd_.seed(GetSeed()); - } - - ~RandomVerticalFlipOp() override = default; - - void Print(std::ostream &out) const override { out << "RandomVerticalFlipOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - std::mt19937 rnd_; - std::bernoulli_distribution distribution_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc deleted file mode 100644 index ffea851eac..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "dataset/util/status.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h" - -namespace mindspore { -namespace dataset { -const float RandomVerticalFlipWithBBoxOp::kDefProbability = 0.5; -Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - - if (distribution_(rnd_)) { - dsize_t imHeight = input[0]->shape()[0]; - size_t boxCount = input[1]->shape()[0]; // number of rows in tensor - - // one time allocation -> updated in the loop - // type defined based on VOC test dataset - for (int i = 0; i < boxCount; i++) { - uint32_t boxCorner_y = 0; - uint32_t boxHeight = 0; - uint32_t newBoxCorner_y = 0; - RETURN_IF_NOT_OK(input[1]->GetUnsignedIntAt(&boxCorner_y, {i, 1})); // get min y of bbox - RETURN_IF_NOT_OK(input[1]->GetUnsignedIntAt(&boxHeight, {i, 3})); // get height of bbox - - // subtract (curCorner + height) from (max) for new Corner position - newBoxCorner_y = (imHeight - 1) - ((boxCorner_y + boxHeight) - 1); - RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); - } - - output->resize(2); - (*output)[1] = std::move(input[1]); - - return VerticalFlip(input[0], &(*output)[0]); - } - *output = input; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.h deleted file mode 100644 index 4764cc2b75..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -class RandomVerticalFlipWithBBoxOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefProbability; - // Constructor for RandomVerticalFlipWithBBoxOp - // @param probability: Probablity of Image flipping, 0.5 by default - explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { - rnd_.seed(GetSeed()); - } - - ~RandomVerticalFlipWithBBoxOp() override = default; - - void Print(std::ostream &out) const override { out << "RandomVerticalFlipWithBBoxOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - private: - std::mt19937 rnd_; - std::bernoulli_distribution distribution_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/rescale_op.cc b/mindspore/ccsrc/dataset/kernels/image/rescale_op.cc deleted file mode 100644 index fd1807991c..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/rescale_op.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/rescale_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status RescaleOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - return Rescale(input, output, rescale_, shift_); -} -Status RescaleOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = DataType(DataType::DE_FLOAT32); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/rescale_op.h b/mindspore/ccsrc/dataset/kernels/image/rescale_op.h deleted file mode 100644 index 8aee75b0c1..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/rescale_op.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RESCALE_OP_H_ -#define DATASET_KERNELS_IMAGE_RESCALE_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RescaleOp : public TensorOp { - public: - RescaleOp(float rescale_ratio, float shift_ratio) : rescale_(rescale_ratio), shift_(shift_ratio) {} - - ~RescaleOp() override = default; - - void Print(std::ostream &out) const override { - out << "RescaleOp: shift: " << shift_ << ", Rescale: " << rescale_ << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - private: - float rescale_; - float shift_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RESCALE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.cc b/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.cc deleted file mode 100644 index 658caac6a5..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.cc +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/resize_bilinear_op.h" -#include - -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t ResizeBilinearOp::kDefWidth = 0; - -void ResizeBilinearOp::Print(std::ostream &out) const { out << "ResizeBilinearOp: "; } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.h b/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.h deleted file mode 100644 index c8c2a5185b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ -#define DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ - -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class ResizeBilinearOp : public ResizeOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefWidth; - - // Name: constructor - // Resizes the image to the output specified size using Bilinear interpolation. - // If only one value is provided, the it will resize the smaller size and maintains - // the aspect ratio. - // @param size1: the first size of output. If only this parameter is provided - // the smaller dimension will be resized to this and then the other dimension changes - // such that the aspect ratio is maintained. - // @param size2: the second size of output. If this is also provided, the output size - // will be (size1, size2) - explicit ResizeBilinearOp(int32_t size1, int32_t size2 = kDefWidth) - : ResizeOp(size1, size2, ResizeOp::kDefInterpolation) {} - - // Name: Destructor - // Description: Destructor - ~ResizeBilinearOp() = default; - - // Name: Print() - // Description: A function that prints info about the node - void Print(std::ostream &out) const override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/resize_op.cc deleted file mode 100644 index 7c0252188e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_op.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/resize_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t ResizeOp::kDefWidth = 0; -const InterpolationMode ResizeOp::kDefInterpolation = InterpolationMode::kLinear; - -Status ResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape size " + std::to_string(input->shape().Size()) + - " of input tensor is invalid"); - int32_t output_h, output_w = 0; - int32_t input_h = static_cast(input->shape()[0]); - int32_t input_w = static_cast(input->shape()[1]); - if (size2_ == 0) { - if (input_h < input_w) { - CHECK_FAIL_RETURN_UNEXPECTED(input_h != 0, "The input height is 0"); - output_h = size1_; - output_w = static_cast(std::lround(static_cast(input_w) / input_h * output_h)); - } else { - CHECK_FAIL_RETURN_UNEXPECTED(input_w != 0, "The input width is 0"); - output_w = size1_; - output_h = static_cast(std::lround(static_cast(input_h) / input_w * output_w)); - } - } else { - output_h = size1_; - output_w = size2_; - } - return Resize(input, output, output_h, output_w, 0, 0, interpolation_); -} - -Status ResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - int32_t outputH = -1, outputW = -1; - // if size2_ == 0, then we cannot know the shape. We need the input image to find the output shape --> set it to - // <-1,-1[,3]> - if (size2_ != 0) { - outputH = size1_; - outputW = size2_; - } - TensorShape out = TensorShape{outputH, outputW}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_op.h b/mindspore/ccsrc/dataset/kernels/image/resize_op.h deleted file mode 100644 index 5a35a6076c..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_op.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RESIZE_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class ResizeOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefWidth; - static const InterpolationMode kDefInterpolation; - - // Resizes the image to the output specified size. If only one value is provided, - // the it will resize the smaller size and maintains the aspect ratio. - // @param size1: the first size of output. If only this parameter is provided - // the smaller dimension will be resized to this and then the other dimension changes - // such that the aspect ratio is maintained. - // @param size2: the second size of output. If this is also provided, the output size - // will be (size1, size2) - // @param InterpolationMode: the interpolation mode being used. - explicit ResizeOp(int32_t size1, int32_t size2 = kDefWidth, InterpolationMode mInterpolation = kDefInterpolation) - : size1_(size1), size2_(size2), interpolation_(mInterpolation) {} - - ~ResizeOp() override = default; - - void Print(std::ostream &out) const override { out << "ResizeOp: " << size1_ << " " << size2_; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - protected: - int32_t size1_; - int32_t size2_; - InterpolationMode interpolation_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc deleted file mode 100644 index 8a633d5678..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/image/resize_with_bbox_op.h" -#include -#include -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/pybind_support.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -Status ResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - - int32_t input_h = input[0]->shape()[0]; - int32_t input_w = input[0]->shape()[1]; - - output->resize(2); - (*output)[1] = std::move(input[1]); // move boxes over to output - - std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input[0])); - - RETURN_IF_NOT_OK(ResizeOp::Compute(std::static_pointer_cast(input_cv), &(*output)[0])); - - int32_t output_h = (*output)[0]->shape()[0]; // output height if ResizeWithBBox - int32_t output_w = (*output)[0]->shape()[1]; // output width if ResizeWithBBox - - size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor - RETURN_IF_NOT_OK(UpdateBBoxesForResize((*output)[1], bboxCount, output_w, output_h, input_w, input_h)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h deleted file mode 100644 index 17bdd01ef1..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H -#define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/kernels/image/resize_op.h" - -namespace mindspore { -namespace dataset { -class ResizeWithBBoxOp : public ResizeOp { - public: - // Constructor for ResizeWithBBoxOp, with default value and passing to base class constructor - explicit ResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefWidth, - InterpolationMode mInterpolation = kDefInterpolation) - : ResizeOp(size_1, size_2, mInterpolation) {} - - ~ResizeWithBBoxOp() override = default; - - void Print(std::ostream &out) const override { out << "ResizeWithBBoxOp: " << size1_ << " " << size2_; } - - Status Compute(const TensorRow &input, TensorRow *output) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc deleted file mode 100644 index 7889b3b157..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include -#include "dataset/kernels/image/uniform_aug_op.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -const int UniformAugOp::kDefNumOps = 2; - -UniformAugOp::UniformAugOp(std::vector> op_list, int32_t num_ops) - : tensor_op_list_(op_list), num_ops_(num_ops) { - rnd_.seed(GetSeed()); -} - -// compute method to apply uniformly random selected augmentations from a list -Status UniformAugOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - - // randomly select ops to be applied - std::vector> selected_tensor_ops; - std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_); - - bool first = true; - for (const auto &tensor_op : selected_tensor_ops) { - // Do NOT apply the op, if second random generator returned zero - if (std::uniform_int_distribution(0, 1)(rnd_)) { - continue; - } - // apply C++ ops (note: python OPs are not accepted) - if (first) { - RETURN_IF_NOT_OK(tensor_op->Compute(input, output)); - first = false; - } else { - RETURN_IF_NOT_OK(tensor_op->Compute(std::move(*output), output)); - } - } - - // The case where no tensor op is applied. - if (output->empty()) { - *output = input; - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h deleted file mode 100644 index 824898ba2d..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ -#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class UniformAugOp : public TensorOp { - public: - // Default number of Operations to be applied - static const int kDefNumOps; - - // Constructor for UniformAugOp - // @param std::vector> op_list: list of candidate C++ operations - // @param int32_t num_ops: number of augemtation operations to applied - UniformAugOp(std::vector> op_list, int32_t num_ops); - - // Destructor - ~UniformAugOp() override = default; - - void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; } - - // Overrides the base class compute function - // @return Status - The error code return - Status Compute(const TensorRow &input, TensorRow *output) override; - - private: - int32_t num_ops_; - std::vector> tensor_op_list_; - std::mt19937 rnd_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/no_op.h b/mindspore/ccsrc/dataset/kernels/no_op.h deleted file mode 100644 index bfbdf43b36..0000000000 --- a/mindspore/ccsrc/dataset/kernels/no_op.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_NO_OP_H_ -#define DATASET_KERNELS_NO_OP_H_ - -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class NoOp : public TensorOp { - public: - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override { - *output = input; - return Status::OK(); - } - - void Print(std::ostream &out) const override { out << "NoOp"; }; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_NO_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/py_func_op.cc b/mindspore/ccsrc/dataset/kernels/py_func_op.cc deleted file mode 100644 index 0a6a1452b5..0000000000 --- a/mindspore/ccsrc/dataset/kernels/py_func_op.cc +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/py_func_op.h" - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - Status ret = Status(StatusCode::kOK, "PyFunc Call Succeed"); - { - // Acquire Python GIL - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - ret = Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - goto ComputeReturn; - } - try { - // Transform input tensor vector into numpy array vector - py::tuple input_args(input.size()); - for (size_t i = 0; i < input.size(); i++) { - py::array new_data; - RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); - // possible memcpy here - input_args[i] = new_data; - } - // Invoke python function - py::object ret_py_obj = this->py_func_ptr_(*input_args); - // Process the return value - if (py::isinstance(ret_py_obj)) { - // In case of a n-1 mapping, the return value will be a numpy array - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast())); - output->push_back(out); - } else if (py::isinstance(ret_py_obj)) { - // In case of a n-m mapping, the return value will be a tuple of numpy arrays - py::tuple ret_py_tuple = ret_py_obj.cast(); - // Iterate over two containers simultaneously for memory copy - for (size_t i = 0; i < ret_py_tuple.size(); i++) { - py::object ret_py_ele = ret_py_tuple[i]; - if (!py::isinstance(ret_py_ele)) { - goto ShapeMisMatch; - } - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast())); - output->push_back(out); - } - } else { - goto ShapeMisMatch; - } - } catch (const py::error_already_set &e) { - ret = Status(StatusCode::kPyFuncException, e.what()); - } - } - -ComputeReturn: - return ret; - -ShapeMisMatch: - ret = Status(StatusCode::kShapeMisMatch, "PyFunc should return a numpy array or a numpy array tuple"); - goto ComputeReturn; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/py_func_op.h b/mindspore/ccsrc/dataset/kernels/py_func_op.h deleted file mode 100644 index a50aceafbb..0000000000 --- a/mindspore/ccsrc/dataset/kernels/py_func_op.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_PY_FUNC_OP_H_ -#define DATASET_KERNELS_PY_FUNC_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class __attribute__((visibility("hidden"))) PyFuncOp : public TensorOp { - public: - explicit PyFuncOp(py::function func) : py_func_ptr_(std::move(func)) {} - - ~PyFuncOp() override = default; - - uint32_t NumInput() override { return 0; } - uint32_t NumOutput() override { return 0; } - - // Compute function for n-n mapping. - Status Compute(const TensorRow &input, TensorRow *output) override; - - private: - py::function py_func_ptr_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_PY_FUNC_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/tensor_op.cc b/mindspore/ccsrc/dataset/kernels/tensor_op.cc deleted file mode 100644 index 92aef8dc9e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/tensor_op.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/tensor_op.h" -#include -#include -#include -#include - -namespace mindspore { -namespace dataset { -// Name: Compute() -// Description: This Compute() take 1 Tensor and produce 1 Tensor. -// The derived class should override this function otherwise error. -Status TensorOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (!OneToOne()) { - return Status(StatusCode::kUnexpectedError, "Wrong Compute() function is called. This is not 1-1 TensorOp."); - } else { - return Status(StatusCode::kUnexpectedError, - "Is this TensorOp 1-1? If yes, please implement this Compute() in the derived class."); - } -} - -// Name: Compute() -// Description: This Compute() take multiple Tensors from different columns and produce multiple Tensors too. -// The derived class should override this function otherwise error. -Status TensorOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - if (OneToOne()) { - output->resize(1); - return Compute(input[0], &(*output)[0]); - } - - return Status(StatusCode::kUnexpectedError, - "Is this TensorOp oneToOne? If no, please implement this Compute() in the derived class."); -} - -void TensorOp::Print(std::ostream &out) const { out << "TensorOp" << std::endl; } - -Status TensorOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - if (inputs.size() != NumInput()) - return Status(StatusCode::kUnexpectedError, - "The size of the input argument vector does not match the number of inputs"); - outputs = inputs; - return Status::OK(); -} - -Status TensorOp::OutputType(const std::vector &inputs, std::vector &outputs) { - if (inputs.size() != NumInput()) - return Status(StatusCode::kUnexpectedError, - "The size of the input argument vector does not match the number of inputs"); - outputs = inputs; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/tensor_op.h b/mindspore/ccsrc/dataset/kernels/tensor_op.h deleted file mode 100644 index 9aae50d6b0..0000000000 --- a/mindspore/ccsrc/dataset/kernels/tensor_op.h +++ /dev/null @@ -1,148 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_TENSOR_OP_H_ -#define DATASET_KERNELS_TENSOR_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_row.h" -#include "dataset/util/status.h" - -#define IO_CHECK(input, output) \ - do { \ - if (input == nullptr || output == nullptr) { \ - RETURN_STATUS_UNEXPECTED("input or output is null."); \ - } \ - } while (false) - -#define IO_CHECK_VECTOR(input, output) \ - do { \ - if (output == nullptr) { \ - RETURN_STATUS_UNEXPECTED("output is null."); \ - } \ - for (auto &_i : input) { \ - if (_i == nullptr) { \ - RETURN_STATUS_UNEXPECTED("input is null."); \ - } \ - } \ - } while (false) - -#define BOUNDING_BOX_CHECK(input) \ - do { \ - if (input.size() != 2) { \ - return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ - "Requires Image and Bounding Boxes, likely missed bounding boxes."); \ - } \ - if (input[1]->shape().Size() < 2) { \ - return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ - "Bounding boxes shape should have at least two dimensions."); \ - } \ - uint32_t num_of_features = input[1]->shape()[1]; \ - if (num_of_features < 4) { \ - return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ - "Bounding boxes should be have at least 4 features."); \ - } \ - uint32_t num_of_boxes = input[1]->shape()[0]; \ - uint32_t img_h = input[0]->shape()[0]; \ - uint32_t img_w = input[0]->shape()[1]; \ - for (uint32_t i = 0; i < num_of_boxes; i++) { \ - uint32_t min_x = 0; \ - uint32_t min_y = 0; \ - uint32_t b_w = 0; \ - uint32_t b_h = 0; \ - input[1]->GetItemAt(&min_x, {i, 0}); \ - input[1]->GetItemAt(&min_y, {i, 1}); \ - input[1]->GetItemAt(&b_w, {i, 2}); \ - input[1]->GetItemAt(&b_h, {i, 3}); \ - if ((min_x + b_w > img_w) || (min_y + b_h > img_h)) { \ - return Status(StatusCode::kBoundingBoxOutOfBounds, __LINE__, __FILE__, \ - "At least one of the bounding boxes is out of bounds of the image."); \ - } \ - if (static_cast(min_x) < 0 || static_cast(min_y) < 0) { \ - return Status(StatusCode::kBoundingBoxOutOfBounds, __LINE__, __FILE__, \ - "At least one of the bounding boxes has negative min_x or min_y."); \ - } \ - } \ - } while (false) - -namespace mindspore { -namespace dataset { -// A class that does a computation on a Tensor -class TensorOp { - public: - TensorOp() = default; - - virtual ~TensorOp() = default; - - // A function that prints info about the tensor operation - // @param out - virtual void Print(std::ostream &out) const; - - // Provide stream operator for displaying it - // @param output stream - // @param so the TensorOp object to be printed - // @return output stream - friend std::ostream &operator<<(std::ostream &out, const TensorOp &so) { - so.Print(out); - return out; - } - - // Perform an operation on one Tensor and produce one Tensor. This is for 1-to-1 column MapOp - // @param input shares the ownership of the Tensor (increase the ref count). - // @param output the address to a shared_ptr where the result will be placed. - // @return Status - virtual Status Compute(const std::shared_ptr &input, std::shared_ptr *output); - - // Perform an operation on Tensors from multiple columns, and produce multiple Tensors. - // This is for m-to-n column MapOp. - // @param input is a vector of shared_ptr to Tensor (pass by const reference). - // @param output is the address to an empty vector of shared_ptr to Tensor. - // @return Status - virtual Status Compute(const TensorRow &input, TensorRow *output); - - // Returns true oif the TensorOp takes one input and returns one output. - // @return true/false - bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } - - // Function to determine the number of inputs the TensorOp can take. 0: means undefined. - // @return uint32_t - virtual uint32_t NumInput() { return 1; } - - // Function to determine the number of output the TensorOp generates. 0: means undefined. - // @return uint32_t - virtual uint32_t NumOutput() { return 1; } - - // Function to determine the shapes of the output tensor given the input tensors' shapes. - // If a subclass did not override this function, it means that the shape does not change. - // @param inputs in: vector of the shapes of the input tensors. - // @param outputs out: vector of the shapes of the output tensors to be filled. - // @return Status - virtual Status OutputShape(const std::vector &inputs, std::vector &outputs); - - // Function to determine the types of the output tensor given the input tensor's types. - // If a subclass did not override this function, it means that the type does not change. - // @param inputs in: vector of the types of the input tensors. - // @param outputs out: vector of the types of the output tensors to be filled. - // @return Status - virtual Status OutputType(const std::vector &inputs, std::vector &outputs); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_TENSOR_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc deleted file mode 100644 index 3512a4b2d7..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc +++ /dev/null @@ -1,167 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/basic_tokenizer_op.h" -#include -#include -#include -#include -#include -#include - -#include "unicode/errorcode.h" -#include "unicode/normalizer2.h" -#include "unicode/utypes.h" - -namespace mindspore { -namespace dataset { -const bool BasicTokenizerOp::kDefLowerCase = false; -const bool BasicTokenizerOp::kDefKeepWhitespace = false; -const NormalizeForm BasicTokenizerOp::kDefNormalizationForm = NormalizeForm::kNone; -const bool BasicTokenizerOp::kDefPreserveUnusedToken = true; -const char BasicTokenizerOp::kCommonPattern[] = - "[!-/]" - "|[:-@]" - "|[\\[-`]" - "|[{-~]" - "|[\\p{P}]" - "|[\\x{4E00}-\\x{9FFF}]" - "|[\\x{3400}-\\x{4DBF}]" - "|[\\x{20000}-\\x{2A6DF}]" - "|[\\x{2A700}-\\x{2B73F}]" - "|[\\x{2B740}-\\x{2B81F}]" - "|[\\x{2B820}-\\x{2CEAF}]" - "|[\\x{F900}-\\x{FAFF}]" - "|[\\x{2F800}-\\x{2FA1F}]"; -const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|\\[unused\\d+\\]|"; -const std::unordered_set BasicTokenizerOp::kUnusedWords{"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"}; -BasicTokenizerOp::BasicTokenizerOp(bool lower_case, bool keep_whitespace, NormalizeForm normalization_form, - bool preserve_unused_token) - : lower_case_(lower_case), - keep_whitespace_(keep_whitespace), - preserve_unused_token_(preserve_unused_token), - case_fold_(std::make_unique()), - nfd_normalize_(std::make_unique(NormalizeForm::kNfd)), - normalization_form_(normalization_form), - common_normalize_(std::make_unique(normalization_form)), - replace_accent_chars_(std::make_unique("\\p{Mn}", "")), - replace_control_chars_(std::make_unique("\\p{Cc}|\\p{Cf}", " ")) { - std::string delim_pattern = std::string("\\s+|") + kCommonPattern; - std::string keep_delim_pattern; - if (keep_whitespace_) { - keep_delim_pattern = delim_pattern; - } else { - keep_delim_pattern = kCommonPattern; - } - if (preserve_unused_token_) { - keep_delim_pattern = kUnusedPattern + keep_delim_pattern; - delim_pattern = kUnusedPattern + delim_pattern; - } - regex_tokenizer_ = std::make_unique(delim_pattern, keep_delim_pattern); -} - -Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text, - const std::unordered_set &unused_words, - std::string *outupt) { - icu::ErrorCode error; - const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); - outupt->clear(); - - // 1. get start and end offsets of not case fold strs - std::queue> offsets; // offsets of not used words - int start = -1; - int len = 0; - for (int i = 0; i < text.length(); i++) { - if (text[i] == '[') { - start = i; - ++len; - } else if (text[i] == ']' && start >= 0) { - ++len; - std::string word(text.substr(start, len)); - if (unused_words.find(word) != unused_words.end()) { - offsets.push(std::make_pair(start, start + len - 1)); - } - start = -1; - len = 0; - } else if (start >= 0) { - ++len; - } - } - - // 2. Do not apply case fold on `unused_words` - start = 0; - for (int i = 0; i < text.length();) { - std::string_view process_text; - std::string preserve_token; - if (offsets.empty()) { - i = text.length(); - process_text = text.substr(start, i - start); - } else { - preserve_token = text.substr(offsets.front().first, offsets.front().second - offsets.front().first + 1); - process_text = text.substr(start, offsets.front().first - start); - i = offsets.front().second + 1; - offsets.pop(); - } - std::string temp; - icu::StringByteSink sink(&temp); - nfkc_case_fold->normalizeUTF8(0, icu::StringPiece(process_text.data(), process_text.size()), sink, nullptr, error); - *outupt += temp + preserve_token; - } - return Status::OK(); -} - -Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr &input, - std::shared_ptr *output) { - IO_CHECK(input, output); - std::vector strs(input->Size()); - int i = 0; - for (auto iter = input->begin(); iter != input->end(); iter++) { - RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++])); - } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); -} - -Status BasicTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::shared_ptr cur_input; - std::shared_ptr processed_tensor; - if (lower_case_) { - if (!preserve_unused_token_) { - // to lower case - RETURN_IF_NOT_OK(case_fold_->Compute(input, &processed_tensor)); - } else { - // to lower case except words in kUnusedWords - RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(input, &processed_tensor)); - } - cur_input = processed_tensor; - // strip accent characters - RETURN_IF_NOT_OK(nfd_normalize_->Compute(cur_input, &processed_tensor)); - cur_input = processed_tensor; - RETURN_IF_NOT_OK(replace_accent_chars_->Compute(cur_input, &processed_tensor)); - } else { - RETURN_IF_NOT_OK(common_normalize_->Compute(input, &processed_tensor)); - } - // strip control characters - cur_input = processed_tensor; - RETURN_IF_NOT_OK(replace_control_chars_->Compute(cur_input, &processed_tensor)); - return regex_tokenizer_->Compute(processed_tensor, output); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h deleted file mode 100644 index 01827a0ba4..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/text/kernels/case_fold_op.h" -#include "dataset/text/kernels/normalize_utf8_op.h" -#include "dataset/text/kernels/regex_replace_op.h" -#include "dataset/text/kernels/regex_tokenizer_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class BasicTokenizerOp : public TensorOp { - public: - static const bool kDefLowerCase; - static const bool kDefKeepWhitespace; - static const NormalizeForm kDefNormalizationForm; - static const bool kDefPreserveUnusedToken; - explicit BasicTokenizerOp(bool lower_case = kDefLowerCase, bool keep_whitespace = kDefKeepWhitespace, - NormalizeForm normalization_form = kDefNormalizationForm, - bool preserve_unused_token = kDefPreserveUnusedToken); - - ~BasicTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "BasicTokenizerOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - protected: - Status CaseFoldWithoutUnusedWords(const std::string_view &text, const std::unordered_set &unused_words, - std::string *outupt); - Status CaseFoldWithoutUnusedWords(const std::shared_ptr &input, std::shared_ptr *output); - - private: - static const char kCommonPattern[]; - static const char kUnusedPattern[]; - static const std::unordered_set kUnusedWords; - bool lower_case_; - bool keep_whitespace_; - NormalizeForm normalization_form_; - bool preserve_unused_token_; - std::unique_ptr case_fold_; - std::unique_ptr nfd_normalize_; - std::unique_ptr common_normalize_; - std::unique_ptr replace_accent_chars_; - std::unique_ptr replace_control_chars_; - std::unique_ptr regex_tokenizer_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.cc deleted file mode 100644 index 2b68a5accb..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.cc +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/bert_tokenizer_op.h" -namespace mindspore { -namespace dataset { -Status BertTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - std::shared_ptr basic_tensor; - RETURN_IF_NOT_OK(basic_tokenizer_.Compute(input, &basic_tensor)); - RETURN_IF_NOT_OK(wordpiece_tokenizer_.Compute(basic_tensor, output)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h deleted file mode 100644 index 660fdc7ba5..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/text/kernels/basic_tokenizer_op.h" -#include "dataset/text/kernels/wordpiece_tokenizer_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class BertTokenizerOp : public TensorOp { - public: - explicit BertTokenizerOp(const std::shared_ptr &vocab, - const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator, - const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken, - const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken, - bool lower_case = BasicTokenizerOp::kDefLowerCase, - bool keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace, - NormalizeForm normalization_form = BasicTokenizerOp::kDefNormalizationForm, - bool preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken) - : wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token), - basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token) {} - - ~BertTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "BertTokenizerOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - WordpieceTokenizerOp wordpiece_tokenizer_; - BasicTokenizerOp basic_tokenizer_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/case_fold_op.cc b/mindspore/ccsrc/dataset/text/kernels/case_fold_op.cc deleted file mode 100644 index d935608efd..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/case_fold_op.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/case_fold_op.h" -#include -#include -#include -#include -#include - -#include "unicode/errorcode.h" -#include "unicode/normalizer2.h" -#include "unicode/utypes.h" - -namespace mindspore { -namespace dataset { - -Status CaseFoldOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - icu::ErrorCode error; - const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); - std::vector strs(input->Size()); - int i = 0; - for (auto iter = input->begin(); iter != input->end(); iter++) { - icu::StringByteSink sink(&strs[i++]); - nfkc_case_fold->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); - } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/case_fold_op.h b/mindspore/ccsrc/dataset/text/kernels/case_fold_op.h deleted file mode 100644 index d1b5ba53f1..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/case_fold_op.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ -#define DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class CaseFoldOp : public TensorOp { - public: - CaseFoldOp() {} - - ~CaseFoldOp() override = default; - - void Print(std::ostream &out) const override { out << "CaseFoldOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.cc deleted file mode 100644 index de1d915fbb..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/jieba_tokenizer_op.h" - -#include -#include -#include -#include "dataset/util/path.h" - -namespace mindspore { -namespace dataset { - -JiebaTokenizerOp::JiebaTokenizerOp(const std::string &hmm_path, const std::string &dict_path, JiebaMode mode) - : jieba_mode_(mode), hmm_model_path_(hmm_path), mp_dict_path_(dict_path) { - jieba_parser_ = std::make_unique(mp_dict_path_, hmm_model_path_, ""); -} - -Status JiebaTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - RETURN_UNEXPECTED_IF_NULL(jieba_parser_); - - if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor"); - } - - std::string_view sentence_v; - RETURN_IF_NOT_OK(input->GetItemAt(&sentence_v, {})); - std::string sentence{sentence_v}; - std::vector words; - if (sentence == "") { - words.push_back(""); - } else { - if (jieba_mode_ == JiebaMode::kMp) { - jieba_parser_->CutSmall(sentence, words, MAX_WORD_LENGTH); - } else if (jieba_mode_ == JiebaMode::kHmm) { - jieba_parser_->CutHMM(sentence, words); - } else { // Mix - jieba_parser_->Cut(sentence, words, true); - } - } - *output = std::make_shared(words, TensorShape({(dsize_t)words.size()})); - return Status::OK(); -} - -Status JiebaTokenizerOp::AddWord(const std::string &word, int freq) { - RETURN_UNEXPECTED_IF_NULL(jieba_parser_); - if (jieba_parser_->InsertUserWord(word, freq, "") == false) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "add word error"); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.h deleted file mode 100644 index 41736e4fdb..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_ -#define DATASET_ENGINE_TEXT_JIEBA_OP_H_ - -#include -#include - -#include "cppjieba/Jieba.hpp" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 }; - -class JiebaTokenizerOp : public TensorOp { - public: - // deffault constant for Jieba MPSegment algorithm. - static constexpr size_t MAX_WORD_LENGTH = 512; - // Constructor for JiebaTokenizerOp. - // @param hmm_path HMM model file. - // @param mp_path MP model file. - // @mode tokenization mode [Default "MIX"], "MP" model will tokenize with MPSegment algorithm, "HMM" mode will - // tokenize with Hiddel Markov Model Segment algorithm, "MIx" model will tokenize with a mix of MPSegment and - // HMMSegment algorithm. - JiebaTokenizerOp(const std::string &hmm_path, const std::string &mp_path, JiebaMode mode = JiebaMode::kMix); - ~JiebaTokenizerOp() override = default; - - void Print(std::ostream &out) const override { - out << "JiebaTokenizerOp: " << jieba_mode_ << "hmm_model_path_ " << hmm_model_path_ << "mp_dict_path_" - << mp_dict_path_; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - // @word the word to be added to the JiebaTokenizer. - // @freq [Default 0] the frequency fo the word to be added. - // @tag [Default ""] the tag of the word to be added. - Status AddWord(const std::string &word, int freq = 0); - - protected: - std::string hmm_model_path_; - std::string mp_dict_path_; - std::unique_ptr jieba_parser_; - JiebaMode jieba_mode_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_TEXT_JIEBA_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/lookup_op.cc b/mindspore/ccsrc/dataset/text/kernels/lookup_op.cc deleted file mode 100644 index 07cf7aef5c..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/lookup_op.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/lookup_op.h" - -#include - -namespace mindspore { -namespace dataset { - -LookupOp::LookupOp(std::shared_ptr vocab, WordIdType default_id) - : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {} - -Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - RETURN_UNEXPECTED_IF_NULL(vocab_); - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor"); - std::vector word_ids; - word_ids.reserve(input->Size()); - for (auto itr = input->begin(); itr != input->end(); itr++) { - word_ids.push_back(vocab_->Lookup(std::string(*itr), default_id_)); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_, - reinterpret_cast(word_ids.data()))); - return Status::OK(); -} -Status LookupOp::OutputType(const std::vector &inputs, std::vector &outputs) { - CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match"); - CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type"); - outputs[0] = type_; - return Status::OK(); -} - -void LookupOp::Print(std::ostream &out) const { - out << "LookupOp: " - << "type: " << type_ << "\n default lookup id: " << default_id_ << "\n"; -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/lookup_op.h b/mindspore/ccsrc/dataset/text/kernels/lookup_op.h deleted file mode 100644 index dad99c3241..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/lookup_op.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_TEXT_KERNELS_LOOKUP_OP_H_ -#define DATASET_TEXT_KERNELS_LOOKUP_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/text/vocab.h" - -namespace mindspore { -namespace dataset { -class LookupOp : public TensorOp { - public: - // constructor for lookup, takes in a vocab object - // @param std::shared_ptr vocab - - // @param WordIdType default_id, id to lookup if a word is not in vocab - explicit LookupOp(std::shared_ptr vocab, WordIdType default_id = 1); - - ~LookupOp() = default; - - // perform actual lookup on each tensor - // @param const std::shared_ptr &input - // @param std::shared_ptr *output - // @return error code - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - // print method - // @param std::ostream out - void Print(std::ostream &out) const override; - - // @param std::vector &inputs - - // @param std::vector &outputs - - // @return error code - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - private: - std::shared_ptr vocab_; - WordIdType default_id_; - DataType type_; // type of tensor after lookup -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TEXT_KERNELS_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/ngram_op.cc b/mindspore/ccsrc/dataset/text/kernels/ngram_op.cc deleted file mode 100644 index bbe449a89a..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/ngram_op.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/text/kernels/ngram_op.h" - -#include -#include -#include -#include - -namespace mindspore { -namespace dataset { - -NgramOp::NgramOp(const std::vector &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad, - const std::string &r_pad, const std::string &separator) - : ngrams_(ngrams), - l_len_(l_len), - r_len_(r_len), - l_pad_with_sp_(l_pad + separator), - r_pad_with_sp_(r_pad + separator), - separator_(separator) {} - -Status NgramOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING && input->Rank() == 1, "Not a 1-D str Tensor"); - std::vector offsets; // offsets for each str - std::vector res; // holds the result of ngrams - std::string str_buffer; // concat all pad tokens with string interleaved with separators - res.reserve(input->shape().NumOfElements()); // this should be more than enough - offsets.reserve(1 + l_len_ + r_len_ + input->shape().NumOfElements()); - str_buffer.reserve(l_pad_with_sp_.size() * l_len_ + r_pad_with_sp_.size() * r_len_ + input->SizeInBytes()); - offsets.push_back(str_buffer.size()); // insert 0 as the starting pos - for (int i = 0; i < l_len_; i++) offsets.push_back((str_buffer += l_pad_with_sp_).size()); - - for (auto itr = input->begin(); itr != input->end(); itr++) { - str_buffer += (*itr); - str_buffer += separator_; - offsets.push_back(str_buffer.size()); - } - - for (int i = 0; i < r_len_; i++) offsets.push_back((str_buffer += r_pad_with_sp_).size()); - - for (auto n : ngrams_) { - CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "n gram needs to be a positive number.\n"); - int32_t start_ind = l_len_ - std::min(l_len_, n - 1); - int32_t end_ind = offsets.size() - r_len_ + std::min(r_len_, n - 1); - if (end_ind - start_ind <= n) { - res.emplace_back(std::string()); // push back empty string - } else { - CHECK_FAIL_RETURN_UNEXPECTED(end_ind - n >= 0, "Incorrect loop condition"); - - for (int i = start_ind; i < end_ind - n; i++) { - res.emplace_back(str_buffer.substr(offsets[i], offsets[i + n] - offsets[i] - separator_.size())); - } - } - } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, res, TensorShape({static_cast(res.size())}))); - return Status::OK(); -} - -void NgramOp::Print(std::ostream &out) const { - out << "NgramOp: " - << "left pad width: " << l_len_ << " left pad token with separator: " << l_pad_with_sp_ << "\n" - << "right pad width: " << r_len_ << " right pad token with separator: " << r_pad_with_sp_ << "\n" - << "separator: " << separator_ << "\n"; -} - -Status NgramOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n"); - CHECK_FAIL_RETURN_UNEXPECTED(inputs[0].Rank() == 1, "ngram only works with 1-dim data\n"); - dsize_t num_elements = ngrams_.size(); - for (int32_t n : ngrams_) { - // here since rank == 1, NumOfElements == shape[0]. add padding length to string - int32_t len_with_padding = inputs[0].NumOfElements() + std::min(n - 1, l_len_) + std::min(n - 1, r_len_); - // if len_with_padding - n < 0, this would return an empty string - num_elements += std::max(len_with_padding - n, 0); - } - outputs.emplace_back(TensorShape({num_elements})); - CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n"); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/ngram_op.h b/mindspore/ccsrc/dataset/text/kernels/ngram_op.h deleted file mode 100644 index 3d2c547f79..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/ngram_op.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_TEXT_KERNELS_NGRAM_OP_H_ -#define DATASET_TEXT_KERNELS_NGRAM_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -namespace py = pybind11; - -class NgramOp : public TensorOp { - public: - // Constructor of Ngram model - // @param const std::vector &ngrams - // @param int32_tl_len - padding length on the left - // @param int32_t r_len - padding length on the right - // @param const std::string &l_pad - padding token on the left - // @param const std::string &r_pad - padding token on the right - // @param const std::string &separator - use to join strings - NgramOp(const std::vector &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad, - const std::string &r_pad, const std::string &separator); - - // perform ngram model on each tensor - // @param const std::shared_ptr &input - // @param std::shared_ptr *output - // @return error code - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - // destructor - ~NgramOp() override = default; - - // @param std::vector &inputs - shape of input tensors - // @param std::vector &outputs - shape of output tensors - // @return error code - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - // print arg for debugging - // @param std::ostream &out - void Print(std::ostream &out) const override; - - private: - std::vector ngrams_; // list of n grams - int32_t l_len_; // left padding length - int32_t r_len_; // right padding length - std::string l_pad_with_sp_; // left padding appended with separator - std::string r_pad_with_sp_; // right padding appended with separator - std::string separator_; // separator -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TEXT_KERNELS_NGRAM_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.cc b/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.cc deleted file mode 100644 index b902286576..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/normalize_utf8_op.h" -#include -#include -#include -#include -#include - -#include "unicode/errorcode.h" -#include "unicode/normalizer2.h" -#include "unicode/utypes.h" - -namespace mindspore { -namespace dataset { -const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; -Status NormalizeUTF8Op::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - icu::ErrorCode error; - const icu::Normalizer2 *normalize = nullptr; - switch (normalize_form_) { - case NormalizeForm::kNone: { - *output = input; - return Status::OK(); - } - case NormalizeForm::kNfc: { - normalize = icu::Normalizer2::getNFCInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFCInstance failed"); - break; - } - case NormalizeForm::kNfkc: { - normalize = icu::Normalizer2::getNFKCInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCInstance failed"); - break; - } - case NormalizeForm::kNfd: { - normalize = icu::Normalizer2::getNFDInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFDInstance failed"); - break; - } - case NormalizeForm::kNfkd: { - normalize = icu::Normalizer2::getNFKDInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKDInstance failed"); - break; - } - default: { - RETURN_STATUS_UNEXPECTED("unexpected normalize form"); - break; - } - } - std::vector strs(input->Size()); - int i = 0; - for (auto iter = input->begin(); iter != input->end(); iter++) { - icu::StringByteSink sink(&strs[i++]); - normalize->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); - } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.h b/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.h deleted file mode 100644 index 5033f2355f..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ -#define DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -enum class NormalizeForm { - kNone = 0, - kNfc, - kNfkc, - kNfd, - kNfkd, -}; - -class NormalizeUTF8Op : public TensorOp { - public: - static const NormalizeForm kDefNormalizeForm; - explicit NormalizeUTF8Op(NormalizeForm normalize_form = kDefNormalizeForm) : normalize_form_(normalize_form) {} - - ~NormalizeUTF8Op() override = default; - - void Print(std::ostream &out) const override { out << "NormalizeUTF8Op"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - NormalizeForm normalize_form_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.cc b/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.cc deleted file mode 100644 index 1ce2c5ea61..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/regex_replace_op.h" -#include -#include -#include -#include -#include - -namespace mindspore { -namespace dataset { - -Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, - std::string *out) const { - CHECK_FAIL_RETURN_UNEXPECTED((matcher != nullptr && out != nullptr), "Input is null"); - UErrorCode icu_error = U_ZERO_ERROR; - icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(text); - matcher->reset(unicode_text); - icu::UnicodeString unicode_out; - if (replace_all_) { - unicode_out = matcher->replaceAll(replace_, icu_error); - } else { - unicode_out = matcher->replaceFirst(replace_, icu_error); - } - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "RegexReplace failed"); - unicode_out.toUTF8String(*out); - return Status::OK(); -} - -Status RegexReplaceOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - UErrorCode icu_error = U_ZERO_ERROR; - icu::RegexMatcher matcher(pattern_, 0, icu_error); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern"); - std::vector strs(input->Size()); - int i = 0; - for (auto iter = input->begin(); iter != input->end(); iter++) { - RETURN_IF_NOT_OK(RegexReplace(&matcher, *iter, &strs[i])); - } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.h b/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.h deleted file mode 100644 index 30fae13241..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ -#define DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ -#include -#include - -#include "unicode/regex.h" -#include "unicode/errorcode.h" -#include "unicode/utypes.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class RegexReplaceOp : public TensorOp { - public: - RegexReplaceOp(const std::string &pattern, const std::string &replace, bool replace_all = true) - : pattern_(icu::UnicodeString::fromUTF8(pattern)), - replace_(icu::UnicodeString::fromUTF8(replace)), - replace_all_(replace_all) {} - - ~RegexReplaceOp() override = default; - - void Print(std::ostream &out) const override { out << "RegexReplaceOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - protected: - Status RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, std::string *out) const; - - private: - const icu::UnicodeString pattern_; - const icu::UnicodeString replace_; - const bool replace_all_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.cc deleted file mode 100644 index 34c06f28ea..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.cc +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/regex_tokenizer_op.h" -#include -#include -#include -#include -#include - -namespace mindspore { -namespace dataset { -Status RegexTokenizerOp::GetUnicodeSubstr(const icu::UnicodeString &input, int start, int len, std::string *out_utf8, - icu::UnicodeString *out_unicode) const { - CHECK_FAIL_RETURN_UNEXPECTED((out_utf8 != nullptr || out_unicode != nullptr), "Wrong input"); - int total_len = input.length(); - int end = start + len; - CHECK_FAIL_RETURN_UNEXPECTED((start >= 0 && len > 0 && end <= total_len), "Out of range"); - icu::UnicodeString temp; - input.extract(start, len, temp); - if (out_utf8 != nullptr) { - temp.toUTF8String(*out_utf8); - } - if (out_unicode != nullptr) { - *out_unicode = temp; - } - return Status::OK(); -} - -Status RegexTokenizerOp::GetRegexTokens(const std::string &text, std::vector *out_tokens) const { - UErrorCode status = U_ZERO_ERROR; - out_tokens->clear(); - icu::RegexMatcher token_matcher(delim_pattern_, 0, status); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Create icu RegexMatcher failed, you may input one error pattern"); - icu::RegexMatcher delim_matcher(keep_delim_pattern_, 0, status); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Create icu RegexMatcher failed, you may input one error pattern"); - - icu::UnicodeString utext(icu::UnicodeString::fromUTF8(text)); - token_matcher.reset(utext); - - int token_start_index = 0; - status = U_ZERO_ERROR; - while (token_matcher.find(status) && U_SUCCESS(status)) { - int deli_start_index = token_matcher.start(status); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Get RegexMatcher matched start index failed"); - int deli_end_index = token_matcher.end(status); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Get RegexMatcher matched start index failed"); - - // Add non-empty token - int token_len = deli_start_index - token_start_index; - if (token_len > 0) { - std::string token; - RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, token_start_index, token_len, &token)); - out_tokens->emplace_back(std::move(token)); - } - - int delim_len = deli_end_index - deli_start_index; - if (keep_delim_ && delim_len > 0) { - icu::UnicodeString delim_str; - std::string delim_utf8_str; - RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, deli_start_index, delim_len, &delim_utf8_str, &delim_str)); - delim_matcher.reset(delim_str); - if (delim_matcher.matches(status) && U_SUCCESS(status)) { - out_tokens->emplace_back(std::move(delim_utf8_str)); - } - } - token_start_index = deli_end_index; - } - - if (token_start_index < utext.length()) { - std::string temp; - RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, token_start_index, utext.length() - token_start_index, &temp)); - out_tokens->emplace_back(std::move(temp)); - } - return Status::OK(); -} - -Status RegexTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::string_view text; - RETURN_IF_NOT_OK(input->GetItemAt(&text, {})); - std::vector tokens; - RETURN_IF_NOT_OK(GetRegexTokens(std::string(text.data(), text.size()), &tokens)); - *output = std::make_shared(std::move(tokens), TensorShape({(dsize_t)tokens.size()})); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.h deleted file mode 100644 index bcf02a4a11..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_REGEX_TOKENIZER_OP_H_ -#define DATASET_TEXT_REGEX_TOKENIZER_OP_H_ -#include -#include -#include - -#include "unicode/regex.h" -#include "unicode/errorcode.h" -#include "unicode/utypes.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class RegexTokenizerOp : public TensorOp { - public: - RegexTokenizerOp(const std::string &delim_pattern, const std::string &keep_delim_pattern) - : delim_pattern_(icu::UnicodeString::fromUTF8(delim_pattern)), - keep_delim_pattern_(icu::UnicodeString::fromUTF8(keep_delim_pattern)), - keep_delim_(!keep_delim_pattern.empty()) {} - - ~RegexTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "RegexTokenizerOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - protected: - Status GetUnicodeSubstr(const icu::UnicodeString &input, int start, int len, std::string *out_utf8, - icu::UnicodeString *out_unicode = nullptr) const; - Status GetRegexTokens(const std::string &text, std::vector *out_tokens) const; - - private: - const icu::UnicodeString delim_pattern_; - const icu::UnicodeString keep_delim_pattern_; - const bool keep_delim_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_REGEX_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc b/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc deleted file mode 100644 index 1368684daf..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc +++ /dev/null @@ -1,241 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/text/kernels/to_number_op.h" - -#include -#include -#include -#include -#include -#include - -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {} - -ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(DataType(cast_to_type)) {} - -Status ToNumberOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tenosrs should have type string."); - - switch (cast_to_type_.value()) { - case DataType::DE_INT8: - RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); - break; - case DataType::DE_INT16: - RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); - break; - case DataType::DE_INT32: - RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); - break; - case DataType::DE_INT64: - RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); - break; - case DataType::DE_UINT8: - RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); - break; - case DataType::DE_UINT16: - RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); - break; - case DataType::DE_UINT32: - RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); - break; - case DataType::DE_UINT64: - RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); - break; - case DataType::DE_FLOAT16: - RETURN_IF_NOT_OK(this->ToFloat16(input, output)); - break; - case DataType::DE_FLOAT32: - RETURN_IF_NOT_OK(ToFloat(input, output)); - break; - case DataType::DE_FLOAT64: - RETURN_IF_NOT_OK(ToDouble(input, output)); - break; - } - - return Status::OK(); -} - -void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << '\n'; } - -Status ToNumberOp::OutputShape(const std::vector &input_shapes, std::vector &output_shapes) { - (void)std::copy(input_shapes.begin(), input_shapes.end(), std::back_inserter(output_shapes)); - return Status::OK(); -} - -template -Status ToNumberOp::ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { - std::vector casted; - - for (auto it = input->begin(); it != input->end(); ++it) { - bool is_cast_out_of_range = false; - int64_t result = 0; - - try { - result = std::stoll(std::string(*it)); - } catch (const std::out_of_range &) { - is_cast_out_of_range = true; - } catch (const std::invalid_argument &) { - RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number."); - } - - if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { - std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + - cast_to_type_.ToString() + ". The valid range is: [" + - std::to_string(std::numeric_limits::min()) + ", " + - std::to_string(std::numeric_limits::max()) + "]."; - - RETURN_STATUS_UNEXPECTED(error_message); - } - - T casted_result = static_cast(result); - casted.push_back(casted_result); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); - return Status::OK(); -} - -template -Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { - std::vector casted; - - for (auto it = input->begin(); it != input->end(); ++it) { - bool is_cast_out_of_range = false; - uint64_t result = 0; - - // If there is a - at the start of the string, it is considered by us to - // be out of bounds. If the - is somewhere else in the string, it is - // deemed invalid by std::stoull and will throw std::invalid_argument - for (int i = 0; i < (*it).size(); i++) { - if ((*it)[i] == '-') { - is_cast_out_of_range = true; - break; - } - } - - try { - result = std::stoull(std::string(*it)); - } catch (const std::out_of_range &) { - is_cast_out_of_range = true; - } catch (const std::invalid_argument &) { - RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); - } - - if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { - std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + - cast_to_type_.ToString() + ". The valid range is: [" + - std::to_string(std::numeric_limits::min()) + ", " + - std::to_string(std::numeric_limits::max()) + "]."; - - RETURN_STATUS_UNEXPECTED(error_message); - } - - T casted_result = static_cast(result); - casted.push_back(casted_result); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); - return Status::OK(); -} - -Status ToNumberOp::ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { - // special case, float16 does not exist in c++, no native support for - // casting, so cast to float first then use this method, which use Eigen. - std::shared_ptr temp; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32"))); - RETURN_IF_NOT_OK(ToFloat(input, &temp)); - RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output)); - return Status::OK(); -} - -Status ToNumberOp::ToFloat(const std::shared_ptr &input, std::shared_ptr *output) { - std::vector casted; - - for (auto it = input->begin(); it != input->end(); ++it) { - bool is_cast_out_of_range = false; - float result = 0; - - try { - result = std::stof(std::string(*it)); - } catch (const std::out_of_range &) { - is_cast_out_of_range = true; - } catch (const std::invalid_argument &) { - RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); - } - - if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || - is_cast_out_of_range) { - std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + - cast_to_type_.ToString() + ". The valid range is: [" + - std::to_string(std::numeric_limits::lowest()) + ", " + - std::to_string(std::numeric_limits::max()) + "]."; - - RETURN_STATUS_UNEXPECTED(error_message); - } - - float casted_result = static_cast(result); - casted.push_back(casted_result); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); - return Status::OK(); -} - -Status ToNumberOp::ToDouble(const std::shared_ptr &input, std::shared_ptr *output) { - std::vector casted; - - for (auto it = input->begin(); it != input->end(); ++it) { - bool is_cast_out_of_range = false; - double result = 0; - - try { - result = std::stod(std::string(*it)); - } catch (const std::out_of_range &) { - is_cast_out_of_range = true; - } catch (const std::invalid_argument &) { - RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); - } - - if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || - is_cast_out_of_range) { - std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + - cast_to_type_.ToString() + ". The valid range is: [" + - std::to_string(std::numeric_limits::lowest()) + ", " + - std::to_string(std::numeric_limits::max()) + "]."; - - RETURN_STATUS_UNEXPECTED(error_message); - } - - double casted_result = static_cast(result); - casted.push_back(casted_result); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/to_number_op.h b/mindspore/ccsrc/dataset/text/kernels/to_number_op.h deleted file mode 100644 index 1346ce2f47..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/to_number_op.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ -#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ - -#include -#include -#include - -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class ToNumberOp : public TensorOp { - public: - // Constructor of ToNumberOp - // @param const DataType &cast_to_type - the type to convert string inputs to. - explicit ToNumberOp(const DataType &cast_to_type); - - // Constructor of ToNumberOp - // @param const std::string &cast_to_type - the type in string form to convert string inputs to. - explicit ToNumberOp(const std::string &cast_to_type); - - ~ToNumberOp() override = default; - - // Perform numeric conversion on each string in each tensor. - // @param const std::shared_ptr &input - // @param std::shared_ptr *output - // @return error code - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - // For each input shape, find the output shape - // @param std::vector &inputs - shape of input tensors - // @param std::vector &outputs - shape of output tensors - // @return error code - Status OutputShape(const std::vector &input_shapes, std::vector &output_shapes) override; - - // print arg for debugging - // @param std::ostream &out - void Print(std::ostream &out) const override; - - private: - template - Status ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); - - template - Status ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); - - Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output); - - Status ToFloat(const std::shared_ptr &input, std::shared_ptr *output); - - Status ToDouble(const std::shared_ptr &input, std::shared_ptr *output); - - DataType cast_to_type_; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc b/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc deleted file mode 100644 index 136d5006df..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/text/kernels/truncate_sequence_pair_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/kernels/data/slice_op.h" - -namespace mindspore { -namespace dataset { - -Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 2, "Number of inputs should be two."); - std::shared_ptr seq1 = input[0]; - std::shared_ptr seq2 = input[1]; - CHECK_FAIL_RETURN_UNEXPECTED(seq1->shape().Rank() == 1 && seq2->shape().Rank() == 1, - "Both sequences should be of rank 1"); - dsize_t length1 = seq1->shape()[0]; - dsize_t length2 = seq2->shape()[0]; - dsize_t outLength1 = length1; - dsize_t outLength2 = length2; - - dsize_t total = length1 + length2; - while (total > max_length_) { - if (outLength1 > outLength2) - outLength1--; - else - outLength2--; - total--; - } - std::shared_ptr outSeq1; - if (length1 != outLength1) { - std::unique_ptr slice1(new SliceOp(Slice(outLength1 - length1))); - RETURN_IF_NOT_OK(slice1->Compute(seq1, &outSeq1)); - } else { - outSeq1 = std::move(seq1); - } - - std::shared_ptr outSeq2; - if (length2 != outLength2) { - std::unique_ptr slice2(new SliceOp(Slice(outLength2 - length2))); - RETURN_IF_NOT_OK(slice2->Compute(seq2, &outSeq2)); - } else { - outSeq2 = std::move(seq2); - } - output->push_back(outSeq1); - output->push_back(outSeq2); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h b/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h deleted file mode 100644 index e8be6802a8..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ -#define DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/kernels/data/data_utils.h" - -namespace mindspore { -namespace dataset { - -class TruncateSequencePairOp : public TensorOp { - public: - explicit TruncateSequencePairOp(dsize_t length) : max_length_(length) {} - - ~TruncateSequencePairOp() override = default; - - void Print(std::ostream &out) const override { out << "TruncateSequencePairOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - private: - dsize_t max_length_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.cc deleted file mode 100644 index 063bf21630..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/unicode_char_tokenizer_op.h" -#include -#include -#include -#include - -#include "cppjieba/Unicode.hpp" - -using cppjieba::DecodeRunesInString; -using cppjieba::RuneStrArray; - -namespace mindspore { -namespace dataset { - -Status UnicodeCharTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::string_view str; - RETURN_IF_NOT_OK(input->GetItemAt(&str, {})); - - RuneStrArray runes; - if (!DecodeRunesInString(str.data(), str.size(), runes)) { - RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); - } - std::vector splits(runes.size()); - for (size_t i = 0; i < runes.size(); i++) { - splits[i] = str.substr(runes[i].offset, runes[i].len); - } - if (splits.empty()) { - splits.emplace_back(""); - } - *output = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.h deleted file mode 100644 index 01a84eca8b..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class UnicodeCharTokenizerOp : public TensorOp { - public: - UnicodeCharTokenizerOp() {} - - ~UnicodeCharTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; -}; - -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.cc deleted file mode 100644 index 97a4f1333d..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/unicode_script_tokenizer_op.h" -#include -#include -#include -#include -#include - -#include "cppjieba/Unicode.hpp" -#include "unicode/errorcode.h" -#include "unicode/uchar.h" -#include "unicode/uscript.h" - -using cppjieba::DecodeRunesInString; -using cppjieba::RuneStrArray; - -namespace mindspore { -namespace dataset { - -const bool UnicodeScriptTokenizerOp::kDefKeepWhitespace = false; - -Status UnicodeScriptTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::string_view str; - RETURN_IF_NOT_OK(input->GetItemAt(&str, {})); - RuneStrArray runes; - if (!DecodeRunesInString(str.data(), str.size(), runes)) { - RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); - } - - UScriptCode last_script = USCRIPT_INVALID_CODE; - icu::ErrorCode status; - int start = 0; - int len = 0; - std::vector splits; - - bool was_space = false; - for (size_t i = 0; i < runes.size(); i++) { - bool is_space = u_isUWhiteSpace(runes[i].rune); - UScriptCode script = uscript_getScript(runes[i].rune, status); - if (status.isFailure()) { - status.reset(); - script = USCRIPT_INVALID_CODE; - } - // 1) Seperate UTF-8 strings of different UScriptCode values - // (such as: "Chinese中国" should be splited to ["Chinese", "中国"]) - // 2) Seperate whitespace and non-whitespace UTF-8 strings - // (such as: " ." should be split to [" ", "."]) - if (len > 0 && (script != last_script || is_space != was_space)) { - // 3) If keep_whitespace_ is false, all the whitespace characters will be discard - if (keep_whitespace_ || !was_space) { - std::string temp(str.substr(start, len)); - splits.emplace_back(std::move(temp)); - } - start = runes[i].offset; - len = runes[i].len; - } else { - len += runes[i].len; - } - last_script = script; - was_space = is_space; - } - - if (len > 0 && (keep_whitespace_ || !was_space)) { - std::string temp(str.substr(start, len)); - splits.emplace_back(std::move(temp)); - } - // 4) If the input is empty scalar string, the output will be 1-D empty string. - if (splits.empty()) { - splits.emplace_back(""); - } - *output = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.h deleted file mode 100644 index a77b0b3fa3..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class UnicodeScriptTokenizerOp : public TensorOp { - public: - static const bool kDefKeepWhitespace; - - explicit UnicodeScriptTokenizerOp(bool keep_whitespace = kDefKeepWhitespace) : keep_whitespace_(keep_whitespace) {} - - ~UnicodeScriptTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "UnicodeScriptTokenizerOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - private: - bool keep_whitespace_; // If or not keep whitespace tokens -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc deleted file mode 100644 index 35f3f8d0e2..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/whitespace_tokenizer_op.h" -#include -#include -#include -#include -#include - -#include "cppjieba/Unicode.hpp" -#include "unicode/errorcode.h" -#include "unicode/uchar.h" -#include "unicode/uscript.h" - -using cppjieba::DecodeRunesInString; -using cppjieba::RuneStrArray; - -namespace mindspore { -namespace dataset { -Status WhitespaceTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (input->Rank() != 0 || input->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::string_view str; - RETURN_IF_NOT_OK(input->GetItemAt(&str, {})); - - RuneStrArray runes; - if (!DecodeRunesInString(str.data(), str.size(), runes)) { - RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); - } - std::vector splits; - int start = 0; - int len = 0; - for (size_t i = 0; i < runes.size(); i++) { - if (u_isUWhiteSpace(runes[i].rune)) { - if (len > 0) { - std::string temp(str.substr(start, len)); - splits.emplace_back(std::move(temp)); - len = 0; - } - } else { - if (len == 0) { - start = runes[i].offset; - } - len += runes[i].len; - } - } - if (len > 0) { - std::string temp(str.substr(start, len)); - splits.emplace_back(std::move(temp)); - } - if (splits.empty()) { - splits.emplace_back(""); - } - *output = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.h deleted file mode 100644 index 6d0bab0bea..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class WhitespaceTokenizerOp : public TensorOp { - public: - WhitespaceTokenizerOp() {} - - ~WhitespaceTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "WhitespaceTokenizerOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.cc deleted file mode 100644 index e488c527cd..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/text/kernels/wordpiece_tokenizer_op.h" -#include -#include - -namespace mindspore { -namespace dataset { - -const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##"; -const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100; -const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]"; - -WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator, - const int &max_bytes_per_token, const std::string &unknown_token) - : vocab_(vocab), - suffix_indicator_(suffix_indicator), - max_bytes_per_token_(max_bytes_per_token), - unknown_token_(unknown_token) {} - -Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, - bool *out_found, int *out_end) const { - CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && start < input_token.size(), "Out of range"); - *out_found = false; - for (int i = runes.size() - 1; i >= 0; i--) { - *out_end = runes[i].offset + runes[i].len; - int len = *out_end - start; - std::string word = input_token.substr(start, len); - if (start > 0) { - word = suffix_indicator_ + word; - } - WordIdType default_id = -1; - if (vocab_->Lookup(word, default_id) != default_id) { - *out_found = true; - break; - } - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::FoundNoToken(const std::string &input_token, std::vector *out_tokens) const { - out_tokens->clear(); - if (unknown_token_.empty()) { - out_tokens->emplace_back(input_token); - } else { - out_tokens->emplace_back(unknown_token_); - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::AddSubword(const std::string &input_token, const int start, const int end, - std::vector *out_tokens) const { - CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end > start && end <= input_token.size(), "Out of range"); - std::string subword = input_token.substr(start, end - start); - if (start > 0) { - subword = suffix_indicator_ + subword; - } - out_tokens->emplace_back(subword); - return Status::OK(); -} - -Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, std::vector *out_tokens) const { - if (input_token.size() > max_bytes_per_token_) { - return FoundNoToken(input_token, out_tokens); - } - RuneStrArray runes; - if (!DecodeRunesInString(input_token.data(), input_token.size(), runes)) { - RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); - } - int end = 0; - for (int start = 0; start < input_token.size();) { - bool found = false; - RETURN_IF_NOT_OK(LookupWord(input_token, runes, start, &found, &end)); - if (found) { - RETURN_IF_NOT_OK(AddSubword(input_token, start, end, out_tokens)); - start = end; - } else { - return FoundNoToken(input_token, out_tokens); - } - } - return Status::OK(); -} - -Status WordpieceTokenizerOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (input->Rank() > 1 || input->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor"); - } - std::vector out_tokens; - for (auto iter = input->begin(); iter != input->end(); iter++) { - std::vector temp_tokens; - RETURN_IF_NOT_OK(GetTokens(std::string(*iter), &temp_tokens)); - out_tokens.insert(out_tokens.end(), temp_tokens.begin(), temp_tokens.end()); - } - if (out_tokens.empty()) { - out_tokens.emplace_back(""); - } - *output = std::make_shared(out_tokens, TensorShape({(dsize_t)out_tokens.size()})); - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.h deleted file mode 100644 index c9a75025c6..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ -#include -#include -#include -#include - -#include "cppjieba/Unicode.hpp" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/text/vocab.h" -#include "dataset/util/status.h" - -using cppjieba::DecodeRunesInString; -using cppjieba::RuneStrArray; -namespace mindspore { -namespace dataset { - -class WordpieceTokenizerOp : public TensorOp { - public: - static const char kDefSuffixIndicator[]; - static const int kDefMaxBytesPerToken; - static const char kDefUnknownToken[]; - WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator = kDefSuffixIndicator, - const int &max_bytes_per_token = kDefMaxBytesPerToken, - const std::string &unknown_token = kDefUnknownToken); - - ~WordpieceTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "WordpieceTokenizerOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - protected: - Status AddSubword(const std::string &input_token, const int start, const int end, - std::vector *out_token) const; - Status FoundNoToken(const std::string &input_token, std::vector *out_tokens) const; - Status LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, bool *out_found, - int *out_end) const; - Status GetTokens(const std::string &input_token, std::vector *out_tokens) const; - - private: - const std::shared_ptr vocab_; - const std::string suffix_indicator_; - const int max_bytes_per_token_; - const std::string unknown_token_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/vocab.cc b/mindspore/ccsrc/dataset/text/vocab.cc deleted file mode 100644 index 100dc9d655..0000000000 --- a/mindspore/ccsrc/dataset/text/vocab.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include - -#include "dataset/text/vocab.h" - -namespace mindspore { -namespace dataset { -Vocab::Vocab(std::unordered_map word2id) { word2id_ = std::move(word2id); } - -WordIdType Vocab::Lookup(const WordType &word, WordIdType default_id) const { - auto itr = word2id_.find(word); - return itr == word2id_.end() ? default_id : itr->second; -} - -Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, - std::shared_ptr *vocab) { - // check of duplication on both words and special_tokens will be performed in python - // special_tokens and words both need to be unique, and shouldn't overlap - std::unordered_map word2id; - // if special is added in front, normal words id will start from number of special tokens - WordIdType word_id = prepend_special ? static_cast(special_tokens.size()) : 0; - - for (auto word : words) { - word2id[py::str(word)] = word_id++; - } - - word_id = prepend_special ? 0 : word2id.size(); - - for (auto special_token : special_tokens) { - word2id[py::str(special_token)] = word_id++; - } - - *vocab = std::make_shared(std::move(word2id)); - return Status::OK(); -} - -Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, - const py::list &special_tokens, bool prepend_special, std::shared_ptr *vocab) { - // python validator checks special_tokens doesn't contain any duplicate words - std::unordered_set specials; - // used to check that words in file don't contain any special token that already exists - for (auto word : special_tokens) { - specials.insert(py::str(word)); - } - WordIdType word_id = prepend_special ? static_cast(special_tokens.size()) : 0; - std::unordered_map word2id; - std::fstream handle(path, std::ios::in); - CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path); - std::string word; - while (std::getline(handle, word)) { - if (!delimiter.empty()) { - // if delimiter is not found, find_first_of would return std::string::npos which is -1 - word = word.substr(0, word.find_first_of(delimiter)); - } - CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word + "."); - CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(), word + " is already in special_tokens."); - word2id[word] = word_id++; - // break if enough row is read, if vocab_size is smaller than 0 - if (word2id.size() == vocab_size) break; - } - - word_id = prepend_special ? 0 : word2id.size(); - - for (auto special_token : special_tokens) { - word2id[py::str(special_token)] = word_id++; - } - - *vocab = std::make_shared(std::move(word2id)); - return Status::OK(); -} - -Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr *vocab) { - std::unordered_map word2id; - for (auto p : words) { - word2id[py::str(p.first)] = py::reinterpret_borrow(p.second); - } - *vocab = std::make_shared(std::move(word2id)); - return Status::OK(); -} - -void Vocab::append_word(const std::string &word) { - if (word2id_.find(word) == word2id_.end()) { - word2id_[word] = word2id_.size(); - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/vocab.h b/mindspore/ccsrc/dataset/text/vocab.h deleted file mode 100644 index fc21c380a2..0000000000 --- a/mindspore/ccsrc/dataset/text/vocab.h +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_TEXT_VOCAB_H_ -#define DATASET_TEXT_VOCAB_H_ - -#include -#include -#include -#include - -#include "dataset/util/status.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace mindspore { -namespace dataset { -namespace py = pybind11; - -using WordIdType = int32_t; -using WordType = std::string; - -class Vocab { - public: - // Build a vocab from a python dictionary key is each word ,id needs to start from 2, no duplicate and continuous - // @param const py::dict &words - a dictionary containing word, word id pair. - // @param std::shared_ptr *vocab - return value, vocab object - // @return error code - static Status BuildFromPyDict(const py::dict &words, std::shared_ptr *vocab); - - // Build a vocab from a python list, id will be assigned automatically, start from 2 - // @param const py::list &words - a list of string, used to build vocab, id starts from 2 - // @param std::shared_ptr *vocab - return value, vocab object - // @return error code - static Status BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, - std::shared_ptr *vocab); - - // Build a vocab from reading a vocab file, id are automatically assigned, start from 2 - // @param std::string &path - path to vocab file , each line is assumed to contain 1 word - // @param std::string &delimiter - delimiter to break each line with - // @param int32_t vocab_size - number of words to read from file - // @param std::shared_ptr *vocab - return value, vocab object - // @return error code - static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, - const py::list &special_tokens, bool prepend_special, std::shared_ptr *vocab); - - // Lookup the id of a word, if word doesn't exist in vocab, return default_id - // @param const WordType word - word to look up - // @param WordIdType default_id - word id to return to user when its not in the vocab - // @return WordIdType, word_id - WordIdType Lookup(const WordType &word, WordIdType default_id) const; - - // reverse lookup, lookup the word based on its id - // @param WordIdType id - word id to lookup to - // @return WordType the word - WordType Lookup(WordIdType id); - - // constructor, shouldn't be called directly, can't be private due to std::make_unique() - // @param std::unordered_map map - sanitized word2id map - explicit Vocab(std::unordered_map map); - - Vocab() = default; - - // add one word to vocab, increment it's index automatically - // @param std::string & word - word to be added will skip if word already exists - void append_word(const std::string &word); - - // destructor - ~Vocab() = default; - - private: - std::unordered_map word2id_; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TEXT_VOCAB_H_ diff --git a/mindspore/ccsrc/dataset/util/allocator.h b/mindspore/ccsrc/dataset/util/allocator.h deleted file mode 100644 index 50a9cadbe3..0000000000 --- a/mindspore/ccsrc/dataset/util/allocator.h +++ /dev/null @@ -1,177 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_ALLOCATOR_H_ -#define DATASET_UTIL_ALLOCATOR_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/memory_pool.h" - -namespace mindspore { -namespace dataset { -// The following conforms to the requirements of -// std::allocator. Do not rename/change any needed -// requirements, e.g. function names, typedef etc. -template -class Allocator { - public: - template - friend class Allocator; - - using value_type = T; - using pointer = T *; - using const_pointer = const T *; - using reference = T &; - using const_reference = const T &; - using size_type = uint64_t; - - template - struct rebind { - using other = Allocator; - }; - - using propagate_on_container_copy_assignment = std::true_type; - using propagate_on_container_move_assignment = std::true_type; - using propagate_on_container_swap = std::true_type; - - explicit Allocator(const std::shared_ptr &b) : pool_(b) {} - - ~Allocator() = default; - - template - explicit Allocator(Allocator const &rhs) : pool_(rhs.pool_) {} - - template - bool operator==(Allocator const &rhs) const { - return pool_ == rhs.pool_; - } - - template - bool operator!=(Allocator const &rhs) const { - return pool_ != rhs.pool_; - } - - pointer allocate(std::size_t n) { - void *p; - Status rc = pool_->Allocate(n * sizeof(T), &p); - if (rc.IsOk()) { - return reinterpret_cast(p); - } else if (rc.IsOutofMemory()) { - throw std::bad_alloc(); - } else { - throw std::exception(); - } - } - - void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); } - - size_type max_size() { return pool_->get_max_size(); } - - private: - std::shared_ptr pool_; -}; -/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will -/// be released when the object goes out of scope \tparam T The type of object to be allocated \tparam C Allocator. -/// Default to std::allocator -template > -class MemGuard { - public: - using allocator = C; - MemGuard() : n_(0) {} - explicit MemGuard(allocator a) : n_(0), alloc_(a) {} - // There is no copy constructor nor assignment operator because the memory is solely owned by this object. - MemGuard(const MemGuard &) = delete; - MemGuard &operator=(const MemGuard &) = delete; - // On the other hand, We can support move constructor - MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} - MemGuard &operator=(MemGuard &&lhs) noexcept { - if (this != &lhs) { - this->deallocate(); - n_ = lhs.n_; - alloc_ = std::move(lhs.alloc_); - ptr_ = std::move(lhs.ptr_); - } - return *this; - } - /// \brief Explicitly deallocate the memory if allocated - void deallocate() { - if (ptr_) { - auto *p = ptr_.release(); - if (!std::is_arithmetic::value && std::is_destructible::value) { - for (auto i = 0; i < n_; ++i) { - p[i].~T(); - } - } - alloc_.deallocate(p, n_); - n_ = 0; - } - } - /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is - /// allocated. - /// \param n Number of objects of type T to be allocated - /// \tparam Args Extra arguments pass to the constructor of T - template - Status allocate(size_t n, Args &&... args) noexcept { - try { - deallocate(); - if (n > 0) { - T *data = alloc_.allocate(n); - if (!std::is_arithmetic::value) { - for (auto i = 0; i < n; i++) { - std::allocator_traits::construct(alloc_, &(data[i]), std::forward(args)...); - } - } - ptr_ = std::unique_ptr(data); - n_ = n; - } - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory); - } catch (std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - return Status::OK(); - } - ~MemGuard() noexcept { deallocate(); } - /// \brief Getter function - /// \return The pointer to the memory allocated - T *GetPointer() const { return ptr_.get(); } - /// \brief Getter function - /// \return The pointer to the memory allocated - T *GetMutablePointer() { return ptr_.get(); } - /// \brief Overload [] operator to access a particular element - /// \param x index to the element. Must be less than number of element allocated. - /// \return pointer to the x-th element - T *operator[](size_t x) { return GetMutablePointer() + x; } - /// \brief Overload [] operator to access a particular element - /// \param x index to the element. Must be less than number of element allocated. - /// \return pointer to the x-th element - T *operator[](size_t x) const { return GetPointer() + x; } - /// \brief Return how many bytes are allocated in total - /// \return Number of bytes allocated in total - size_t GetSizeInBytes() const { return n_ * sizeof(T); } - - private: - allocator alloc_; - std::unique_ptr> ptr_; - size_t n_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/dataset/util/arena.cc b/mindspore/ccsrc/dataset/util/arena.cc deleted file mode 100644 index af4f522678..0000000000 --- a/mindspore/ccsrc/dataset/util/arena.cc +++ /dev/null @@ -1,256 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/arena.h" -#include -#include -#include "dataset/util/system_pool.h" -#include "./securec.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -struct MemHdr { - uint32_t sig; - uint64_t addr; - uint64_t blk_size; - MemHdr(uint64_t a, uint64_t sz) : sig(0xDEADBEEF), addr(a), blk_size(sz) {} - static void setHdr(void *p, uint64_t addr, uint64_t sz) { new (p) MemHdr(addr, sz); } - static void getHdr(void *p, MemHdr *hdr) { - auto *tmp = reinterpret_cast(p); - *hdr = *tmp; - } -}; -Status Arena::Init() { - RETURN_IF_NOT_OK(DeMalloc(size_in_MB_ * 1048576L, &ptr_, false)); - // Divide the memory into blocks. Ignore the last partial block. - uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ; - MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << "."; - tr_.Insert(0, num_blks); - return Status::OK(); -} - -Status Arena::Allocate(size_t n, void **p) { - if (n == 0) { - *p = nullptr; - return Status::OK(); - } - std::unique_lock lck(mux_); - // Round up n to 1K block - uint64_t req_size = static_cast(n) + ARENA_WALL_OVERHEAD_SZ; - if (req_size > this->get_max_size()) { - return Status(StatusCode::kOutOfMemory); - } - uint64_t reqBlk = SizeToBlk(req_size); - // Do a first fit search - auto blk = tr_.Top(); - if (blk.second && reqBlk <= blk.first.priority) { - uint64_t addr = blk.first.key; - uint64_t size = blk.first.priority; - // Trim to the required size and return the rest to the tree. - tr_.Pop(); - if (size > reqBlk) { - tr_.Insert(addr + reqBlk, size - reqBlk); - } - lck.unlock(); - char *q = static_cast(ptr_) + addr * ARENA_BLK_SZ; - MemHdr::setHdr(q, addr, reqBlk); - *p = get_user_addr(q); - } else { - return Status(StatusCode::kOutOfMemory); - } - return Status::OK(); -} - -void Arena::Deallocate(void *p) { - auto *q = get_base_addr(p); - MemHdr hdr(0, 0); - MemHdr::getHdr(q, &hdr); - MS_ASSERT(hdr.sig == 0xDEADBEEF); - // We are going to insert a free block back to the treap. But first, check if we can combine - // with the free blocks before and after to form a bigger block. - std::unique_lock lck(mux_); - // Query if we have a free block after us. - auto nextBlk = tr_.Search(hdr.addr + hdr.blk_size); - if (nextBlk.second) { - // Form a bigger block - hdr.blk_size += nextBlk.first.priority; - tr_.DeleteKey(nextBlk.first.key); - } - // Next find a block in front of us. - auto result = FindPrevBlk(hdr.addr); - if (result.second) { - // We can combine with this block - hdr.addr = result.first.first; - hdr.blk_size += result.first.second; - tr_.DeleteKey(result.first.first); - } - // Now we can insert the free node - tr_.Insert(hdr.addr, hdr.blk_size); -} - -Status Arena::Reallocate(void **pp, size_t old_sz, size_t new_sz) { - MS_ASSERT(pp); - MS_ASSERT(*pp); - uint64_t actual_size = static_cast(new_sz) + ARENA_WALL_OVERHEAD_SZ; - if (actual_size > this->get_max_size()) { - RETURN_STATUS_UNEXPECTED("Request size too big : " + std::to_string(new_sz)); - } - uint64_t req_blk = SizeToBlk(actual_size); - char *oldAddr = reinterpret_cast(*pp); - auto *oldHdr = get_base_addr(oldAddr); - MemHdr hdr(0, 0); - MemHdr::getHdr(oldHdr, &hdr); - MS_ASSERT(hdr.sig == 0xDEADBEEF); - std::unique_lock lck(mux_); - if (hdr.blk_size > req_blk) { - // Refresh the header with the new smaller size. - MemHdr::setHdr(oldHdr, hdr.addr, req_blk); - // Return the unused memory back to the tree. Unlike allocate, we we need to merge with the block after us. - auto next_blk = tr_.Search(hdr.addr + hdr.blk_size); - if (next_blk.second) { - hdr.blk_size += next_blk.first.priority; - tr_.DeleteKey(next_blk.first.key); - } - tr_.Insert(hdr.addr + req_blk, hdr.blk_size - req_blk); - } else if (hdr.blk_size < req_blk) { - uint64_t addr = hdr.addr; - // Attempt a block enlarge. No guarantee it is always successful. - bool success = BlockEnlarge(&addr, hdr.blk_size, req_blk); - if (success) { - auto *newHdr = static_cast(ptr_) + addr * ARENA_BLK_SZ; - MemHdr::setHdr(newHdr, addr, req_blk); - if (addr != hdr.addr) { - errno_t err = - memmove_s(get_user_addr(newHdr), (req_blk * ARENA_BLK_SZ) - ARENA_WALL_OVERHEAD_SZ, oldAddr, old_sz); - if (err) { - RETURN_STATUS_UNEXPECTED("Error from memmove: " + std::to_string(err)); - } - } - *pp = get_user_addr(newHdr); - return Status::OK(); - } - // If we reach here, allocate a new block and simply move the content from the old to the new place. - // Unlock since allocate will grab the lock again. - lck.unlock(); - return FreeAndAlloc(pp, old_sz, new_sz); - } - return Status::OK(); -} - -std::ostream &operator<<(std::ostream &os, const Arena &s) { - for (auto &it : s.tr_) { - os << "Address : " << it.key << ". Size : " << it.priority << "\n"; - } - return os; -} - -Arena::Arena(size_t val_in_MB) : ptr_(nullptr), size_in_MB_(val_in_MB), size_in_bytes_(val_in_MB * 1048576L) {} - -Status Arena::CreateArena(std::shared_ptr *p_ba, size_t val_in_MB) { - if (p_ba == nullptr) { - RETURN_STATUS_UNEXPECTED("p_ba is null"); - } - Status rc; - auto ba = new (std::nothrow) Arena(val_in_MB); - if (ba == nullptr) { - return Status(StatusCode::kOutOfMemory); - } - rc = ba->Init(); - if (rc.IsOk()) { - (*p_ba).reset(ba); - } else { - delete ba; - } - return rc; -} - -int Arena::PercentFree() const { - uint64_t sz = 0; - for (auto &it : tr_) { - sz += it.priority; - } - double ratio = static_cast(sz * ARENA_BLK_SZ) / static_cast(size_in_bytes_); - return static_cast(ratio * 100.0); -} - -uint64_t Arena::get_max_size() const { return (size_in_bytes_ - ARENA_WALL_OVERHEAD_SZ); } - -std::pair, bool> Arena::FindPrevBlk(uint64_t addr) { - for (auto &it : tr_) { - if (it.key + it.priority == addr) { - return std::make_pair(std::make_pair(it.key, it.priority), true); - } else if (it.key > addr) { - break; - } - } - return std::make_pair(std::make_pair(0, 0), false); -} - -bool Arena::BlockEnlarge(uint64_t *addr, uint64_t old_sz, uint64_t new_sz) { - uint64_t size = old_sz; - // The logic is very much identical to Deallocate. We will see if we can combine with the blocks before and after. - auto next_blk = tr_.Search(*addr + old_sz); - if (next_blk.second) { - size += next_blk.first.priority; - if (size >= new_sz) { - // In this case, we can just enlarge the block without doing any moving. - tr_.DeleteKey(next_blk.first.key); - // Return unused back to the tree. - if (size > new_sz) { - tr_.Insert(*addr + new_sz, size - new_sz); - } - } - return true; - } - // If we still get here, we have to look at the block before us. - auto result = FindPrevBlk(*addr); - if (result.second) { - // We can combine with this block together with the next block (if any) - size += result.first.second; - *addr = result.first.first; - if (size >= new_sz) { - // We can combine with this block together with the next block (if any) - tr_.DeleteKey(*addr); - if (next_blk.second) { - tr_.DeleteKey(next_blk.first.key); - } - // Return unused back to the tree. - if (size > new_sz) { - tr_.Insert(*addr + new_sz, size - new_sz); - } - return true; - } - } - return false; -} - -Status Arena::FreeAndAlloc(void **pp, size_t old_sz, size_t new_sz) { - MS_ASSERT(pp); - MS_ASSERT(*pp); - void *p = nullptr; - void *q = *pp; - RETURN_IF_NOT_OK(Allocate(new_sz, &p)); - errno_t err = memmove_s(p, new_sz, q, old_sz); - if (err) { - RETURN_STATUS_UNEXPECTED("Error from memmove: " + std::to_string(err)); - } - *pp = p; - // Free the old one. - Deallocate(q); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/arena.h b/mindspore/ccsrc/dataset/util/arena.h deleted file mode 100644 index 8c5d1e1093..0000000000 --- a/mindspore/ccsrc/dataset/util/arena.h +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_ARENA_H_ -#define DATASET_UTIL_ARENA_H_ - -#include -#include -#include -#include "dataset/util/memory_pool.h" -#include "dataset/util/treap.h" - -#define ARENA_LOG_BLK_SZ (6u) -#define ARENA_BLK_SZ (static_cast(1u << ARENA_LOG_BLK_SZ)) -#define ARENA_WALL_OVERHEAD_SZ 32 -namespace mindspore { -namespace dataset { -// This is a memory arena based on a treap data structure. -// The constructor of the Arena takes the size of the initial memory size (in MB). -// Internally we divide the memory into multiple blocks. Each block is 64 bytes. -// The treap contains all the free blocks with the relative memory address as key -// and the size of the block as priority. -// -// Initially the treap has only one root which is the whole memory piece. -// -// For memory suballocation, we pop the root node of the treap which contains the largest free block. -// We allocate what we need and return the rest back to the treap. We search for the first fit instead -// of the best fit so to give us a constant time in memory allocation. -// -// When a block of memory is freed. It is joined with the blocks before and after (if they are available) to -// form a bigger block. -class Arena : public MemoryPool { - public: - Arena(const Arena &) = delete; - - Arena &operator=(const Arena &) = delete; - - ~Arena() override { - if (ptr_ != nullptr) { - free(ptr_); - ptr_ = nullptr; - } - } - - Status Allocate(size_t n, void **p) override; - - Status Reallocate(void **, size_t old_sz, size_t new_sz) override; - - void Deallocate(void *) override; - - uint64_t get_max_size() const override; - - static uint64_t SizeToBlk(uint64_t sz) { - uint64_t req_blk = sz / ARENA_BLK_SZ; - if (sz % ARENA_BLK_SZ) { - ++req_blk; - } - return req_blk; - } - - int PercentFree() const override; - - const void *get_base_addr() const { return ptr_; } - - friend std::ostream &operator<<(std::ostream &os, const Arena &s); - - static Status CreateArena(std::shared_ptr *p_ba, size_t val_in_MB = 4096); - - private: - std::mutex mux_; - Treap tr_; - void *ptr_; - size_t size_in_MB_; - size_t size_in_bytes_; - - explicit Arena(size_t val_in_MB = 4096); - - std::pair, bool> FindPrevBlk(uint64_t addr); - - Status Init(); - - bool BlockEnlarge(uint64_t *addr, uint64_t old_sz, uint64_t new_sz); - - Status FreeAndAlloc(void **pp, size_t old_sz, size_t new_sz); - - void *get_user_addr(void *base_addr) const { return reinterpret_cast(base_addr) + ARENA_WALL_OVERHEAD_SZ; } - - void *get_base_addr(void *user_addr) const { return reinterpret_cast(user_addr) - ARENA_WALL_OVERHEAD_SZ; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_ARENA_H_ diff --git a/mindspore/ccsrc/dataset/util/auto_index.h b/mindspore/ccsrc/dataset/util/auto_index.h deleted file mode 100644 index 5c43ecfd80..0000000000 --- a/mindspore/ccsrc/dataset/util/auto_index.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_AUTO_INDEX_H_ -#define DATASET_UTIL_AUTO_INDEX_H_ - -#include -#include -#include -#include - -#include "dataset/util/btree.h" -#include "dataset/util/system_pool.h" - -namespace mindspore { -namespace dataset { -/// This is a B+ tree with generated int64_t value as key. -/// Use minKey() function to query the min key. -/// Use maxKey() function to query the max key. -/// @tparam T -template > -class AutoIndexObj : public BPlusTree { - public: - using my_tree = BPlusTree; - using key_type = typename my_tree::key_type; - using value_type = typename my_tree::value_type; - - AutoIndexObj() : my_tree::BPlusTree(), inx_(kMinKey) {} - - explicit AutoIndexObj(const Allocator &alloc) : my_tree::BPlusTree(alloc), inx_(kMinKey) {} - - ~AutoIndexObj() = default; - - // Insert an object into the tree. - // @param val - // @return - Status insert(const value_type &val, key_type *key = nullptr) { - key_type my_inx = inx_.fetch_add(1); - if (key != nullptr) { - *key = my_inx; - } - return my_tree::DoInsert(my_inx, val); - } - - Status insert(std::unique_ptr &&val, key_type *key = nullptr) { - key_type my_inx = inx_.fetch_add(1); - if (key) { - *key = my_inx; - } - return my_tree::DoInsert(my_inx, std::move(val)); - } - - // Insert a vector of objects into the tree. - // @param v - // @return - Status insert(std::vector v) { - uint64_t num_ele = v.size(); - if (num_ele > 0) { - // reserve a range of keys rather than getting it one by one. - key_type my_inx = inx_.fetch_add(num_ele); - for (uint64_t i = 0; i < num_ele; i++) { - RETURN_IF_NOT_OK(my_tree::DoInsert(my_inx + i, v.at(i))); - } - } - return Status::OK(); - } - - // @return the minimum key - key_type min_key() const { - auto it = this->cbegin(); - return it.key(); - } - - // @return the maximum key - key_type max_key() const { - auto it = this->cend(); - --it; - return it.key(); - } - - private: - static constexpr key_type kMinKey = 0; - std::atomic inx_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_AUTO_INDEX_H_ diff --git a/mindspore/ccsrc/dataset/util/btree.h b/mindspore/ccsrc/dataset/util/btree.h deleted file mode 100644 index ccf642e366..0000000000 --- a/mindspore/ccsrc/dataset/util/btree.h +++ /dev/null @@ -1,459 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_INDEX_H_ -#define DATASET_UTIL_INDEX_H_ - -#include -#include -#include -#include -#include -#include -#include "./securec.h" -#include "dataset/util/allocator.h" -#include "dataset/util/list.h" -#include "dataset/util/lock.h" -#include "dataset/util/memory_pool.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Default traits for a B+ tree -struct BPlusTreeTraits { - // This determines the limit of number of keys in a node. - using slot_type = uint16_t; - // Number of slots in each leaf of the tree. - static constexpr slot_type kLeafSlots = 256; - // Number of slots in each inner node of the tree - static constexpr slot_type kInnerSlots = 128; -}; - -/// Implementation of B+ tree -/// @tparam K -- the type of key -/// @tparam V -- the type of value -/// @tparam A -- allocator -/// @tparam C -- comparison class -/// @tparam T -- trait -template , typename C = std::less, - typename T = BPlusTreeTraits> -class BPlusTree { - public: - enum class IndexRc : char { - kOk = 0, - kDuplicateKey = 1, - kSlotFull = 2, - kKeyNotFound = 3, - kNullPointer = 4, - kOutOfMemory = 5, - kRetry = 6, - kUnexpectedError = 127 - }; -#define RETURN_IF_BAD_RC(_s) \ - do { \ - IndexRc __rc = (_s); \ - if (__rc != IndexRc::kOk) { \ - return __rc; \ - } \ - } while (false) - - Status IndexRc2Status(IndexRc rc) { - if (rc == IndexRc::kOk) { - return Status(StatusCode::kOK); - } else if (rc == IndexRc::kOutOfMemory) { - return Status(StatusCode::kOutOfMemory); - } else if (rc == IndexRc::kDuplicateKey) { - return Status(StatusCode::kDuplicateKey); - } else { - RETURN_STATUS_UNEXPECTED(std::to_string(static_cast(rc))); - } - } - - using key_type = K; - using value_type = V; - using key_compare = C; - using slot_type = typename T::slot_type; - using traits = T; - using value_allocator = A; - using key_allocator = typename value_allocator::template rebind::other; - using slot_allocator = typename value_allocator::template rebind::other; - - BPlusTree(); - - explicit BPlusTree(const Allocator &alloc); - - ~BPlusTree() noexcept; - - BPlusTree(const BPlusTree &) = delete; - - BPlusTree(BPlusTree &&) = delete; - - BPlusTree &operator=(const BPlusTree &) = delete; - - BPlusTree &operator=(BPlusTree &&) = delete; - - key_compare key_comp() const { return key_less_; } - - size_t size() const { return stats_.size_; } - - bool empty() const { return (size() == 0); } - - /// @param key - /// @param value - /// @return - Status DoInsert(const key_type &key, const value_type &value); - Status DoInsert(const key_type &key, std::unique_ptr &&value); - - // Update a new value for a given key. - std::unique_ptr DoUpdate(const key_type &key, const value_type &new_value); - std::unique_ptr DoUpdate(const key_type &key, std::unique_ptr &&new_value); - - // Statistics - struct tree_stats { - std::atomic size_; - uint32_t leaves_; - uint32_t inner_nodes_; - uint32_t level_; - - tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {} - }; - - private: - // Abstract class of a node (leaf or inner) - class BaseNode { - public: - friend class BPlusTree; - - virtual bool is_leafnode() const = 0; - - virtual bool is_full() const = 0; - - explicit BaseNode(const value_allocator &alloc) : alloc_(alloc) {} - - virtual ~BaseNode() = default; - - protected: - mutable RWLock rw_lock_; - value_allocator alloc_; - - private: - Node lru_; - }; - - // This control block keeps track of all the nodes we traverse on insert. - // To maximize concurrency, internal nodes are latched S. If a node split - // is required, we must releases all the latches and redo it again and change - // the latch mode from S to X. - struct LockPathCB { - enum class LockMode : char { kShared = 0, kExclusive = 1, kNone = 2 }; - - struct path { - BaseNode *node_; - bool locked_; - - path() : node_(nullptr), locked_(false) {} - - path(BaseNode *p, LockMode lockmode) : node_(p), locked_(false) { - if (lockmode == LockMode::kExclusive) { - p->rw_lock_.LockExclusive(); - locked_ = true; - } else if (lockmode == LockMode::kShared) { - p->rw_lock_.LockShared(); - locked_ = true; - } - } - }; - - LockPathCB(BPlusTree *tree, bool retryWithXlock) : self_(tree), latch_shared_(true) { - if (retryWithXlock) { - latch_shared_ = false; - } - if (latch_shared_) { - tree->rw_lock_.LockShared(); - } else { - tree->rw_lock_.LockExclusive(); - } - } - - ~LockPathCB() noexcept { - // Make sure all locks are released. - while (!paths_.empty()) { - path p = paths_.back(); - paths_.pop_back(); - if (p.locked_) { - p.node_->rw_lock_.Unlock(); - } - } - self_->rw_lock_.Unlock(); - self_ = nullptr; - } - - void LockNode(BaseNode *p, LockMode locktype) { paths_.emplace_back(p, locktype); } - - void UnlockMyParents(BaseNode *me) { - path p = paths_.front(); - while (p.node_ != me) { - if (p.locked_) { - p.node_->rw_lock_.Unlock(); - } - paths_.pop_front(); - p = paths_.front(); - } - } - - BPlusTree *self_; - std::deque paths_; - bool latch_shared_; - }; - - // Definition of inner node which fans to either inner node or leaf node. - class InnerNode : public BaseNode { - public: - friend class BPlusTree; - - using alloc_type = typename value_allocator::template rebind::other; - - bool is_leafnode() const override { return false; } - - bool is_full() const override { return (slotuse_ == traits::kInnerSlots); } - - IndexRc Sort(); - - // 50/50 split - IndexRc Split(InnerNode *to, key_type *split_key); - - IndexRc InsertIntoSlot(slot_type slot, const key_type &key, BaseNode *ptr); - - explicit InnerNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} - - ~InnerNode() = default; - - slot_type slot_dir_[traits::kInnerSlots] = {0}; - key_type keys_[traits::kInnerSlots] = {0}; - BaseNode *data_[traits::kInnerSlots + 1] = {nullptr}; - slot_type slotuse_; - }; - - // Definition of a leaf node which contains the key/value pair - class LeafNode : public BaseNode { - public: - friend class BPlusTree; - - using alloc_type = typename value_allocator::template rebind::other; - Node link_; - - bool is_leafnode() const override { return true; } - - bool is_full() const override { return (slotuse_ == traits::kLeafSlots); } - - IndexRc Sort(); - - // 50/50 split - IndexRc Split(LeafNode *to); - - IndexRc InsertIntoSlot(LockPathCB *insCB, slot_type slot, const key_type &key, std::unique_ptr &&value); - - explicit LeafNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} - - ~LeafNode() = default; - - slot_type slot_dir_[traits::kLeafSlots] = {0}; - key_type keys_[traits::kLeafSlots] = {0}; - std::unique_ptr data_[traits::kLeafSlots]; - slot_type slotuse_; - }; - - mutable RWLock rw_lock_; - value_allocator alloc_; - // All the leaf nodes. Used by the iterator to traverse all the key/values. - List leaf_nodes_; - // All the nodes (inner + leaf). Used by the destructor to free the memory of all the nodes. - List all_; - // Pointer to the root of the tree. - BaseNode *root_; - // Key comparison object - key_compare key_less_; - // Stat - tree_stats stats_; - - bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); } - - bool EqualOrLessThan(const key_type &a, const key_type &b) const { return !key_less_(b, a); } - - bool Equal(const key_type &a, const key_type &b) const { return !key_less_(a, b) && !key_less_(b, a); } - - IndexRc AllocateInner(InnerNode **p); - - IndexRc AllocateLeaf(LeafNode **p); - - template - slot_type FindSlot(const node_type *node, const key_type &key, bool *duplicate = nullptr) const { - slot_type lo = 0; - while (lo < node->slotuse_ && key_comp()(node->keys_[node->slot_dir_[lo]], key)) { - ++lo; - } - bool keymatch = (lo < node->slotuse_ && Equal(key, node->keys_[node->slot_dir_[lo]])); - if (keymatch && !node->is_leafnode()) { - // For an inner node and we match a key during search, we should look into the next slot. - ++lo; - } - if (duplicate != nullptr) { - *duplicate = keymatch; - } - return lo; - } - - IndexRc LeafInsertKeyValue(LockPathCB *ins_cb, LeafNode *node, const key_type &key, - std::unique_ptr &&value, key_type *split_key, LeafNode **split_node); - - IndexRc InnerInsertKeyChild(InnerNode *node, const key_type &key, BaseNode *ptr, key_type *split_key, - InnerNode **split_node); - - inline BaseNode *FindBranch(InnerNode *inner, slot_type slot) const { - BaseNode *child = nullptr; - if (slot == 0) { - child = inner->data_[0]; - } else { - child = inner->data_[inner->slot_dir_[slot - 1] + 1]; - } - return child; - } - - IndexRc InsertKeyValue(LockPathCB *ins_cb, BaseNode *n, const key_type &key, std::unique_ptr &&value, - key_type *split_key, BaseNode **split_node); - - IndexRc Locate(RWLock *parent_lock, bool forUpdate, BaseNode *top, const key_type &key, LeafNode **ln, - slot_type *s) const; - - public: - class Iterator : public std::iterator { - public: - using reference = BPlusTree::value_type &; - using pointer = BPlusTree::value_type *; - - explicit Iterator(BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} - - Iterator(LeafNode *leaf, slot_type slot, bool locked = false) : cur_(leaf), slot_(slot), locked_(locked) {} - - ~Iterator(); - - explicit Iterator(const Iterator &); - - Iterator &operator=(const Iterator &lhs); - - Iterator(Iterator &&); - - Iterator &operator=(Iterator &&lhs); - - pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } - - reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - - const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } - - value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - - // Prefix++ - Iterator &operator++(); - - // Postfix++ - Iterator operator++(int); - - // Prefix-- - Iterator &operator--(); - - // Postfix-- - Iterator operator--(int); - - bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } - bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } - - private: - typename BPlusTree::LeafNode *cur_; - slot_type slot_; - bool locked_; - }; - - class ConstIterator : public std::iterator { - public: - using reference = BPlusTree::value_type &; - using pointer = BPlusTree::value_type *; - - explicit ConstIterator(const BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} - - ~ConstIterator(); - - ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) - : cur_(leaf), slot_(slot), locked_(locked) {} - - explicit ConstIterator(const ConstIterator &); - - ConstIterator &operator=(const ConstIterator &lhs); - - ConstIterator(ConstIterator &&); - - ConstIterator &operator=(ConstIterator &&lhs); - - pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } - - reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - - const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } - - value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - - // Prefix++ - ConstIterator &operator++(); - - // Postfix++ - ConstIterator operator++(int); - - // Prefix-- - ConstIterator &operator--(); - - // Postfix-- - ConstIterator operator--(int); - - bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } - bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } - - private: - const typename BPlusTree::LeafNode *cur_; - slot_type slot_; - bool locked_; - }; - - Iterator begin(); - Iterator end(); - - ConstIterator begin() const; - ConstIterator end() const; - - ConstIterator cbegin() const; - ConstIterator cend() const; - - // Locate the entry with key - std::pair Search(const key_type &key) const; - std::pair Search(const key_type &key); - - value_type operator[](key_type key); -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_INDEX_H_ - -#include "btree_impl.tpp" -#include "btree_iterator.tpp" diff --git a/mindspore/ccsrc/dataset/util/buddy.cc b/mindspore/ccsrc/dataset/util/buddy.cc deleted file mode 100644 index 540fa993d6..0000000000 --- a/mindspore/ccsrc/dataset/util/buddy.cc +++ /dev/null @@ -1,388 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/buddy.h" -#include -#include -#include "dataset/util/memory_pool.h" -#include "dataset/util/system_pool.h" -#include "utils/log_adapter.h" -#include "./securec.h" - -inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); } - -inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); } - -inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; } - -inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; } - -inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; } - -namespace mindspore { -namespace dataset { -Status BuddySpace::Init() { - if (log_min_ < 0) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "log_min must be positive : " + std::to_string(log_min_)); - } - if (num_lvl_ < 3 || num_lvl_ > 18) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_)); - } - min_ = BitLeftShift(1, log_min_); - max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1); - size_t offset_1 = sizeof(rel_addr_t) * num_lvl_; - size_t offset_2 = sizeof(int) * num_lvl_ + offset_1; - size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2; - RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true)); - hint_ = reinterpret_cast(ptr_); - count_ = reinterpret_cast((reinterpret_cast(ptr_) + offset_1)); - map_ = reinterpret_cast(ptr_) + offset_2; - count_[num_lvl_ - 1] = 1; - map_[0] = BitOr(MORE_BIT, num_lvl_ - 3); - return Status::OK(); -} - -Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept { - std::lock_guard lock(mutex_); - addr_t addr = AllocNoLock(sz, desc); - if (addr != NOSPACE) { - *p = addr; - return Status::OK(); - } else { - return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore."); - } -} - -addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept { - MS_ASSERT(sz <= max_); - uint32_t reqSize = SizeToBlock(sz); - rel_addr_t rel_addr = AllocBuddySeg(reqSize); - if (rel_addr != static_cast(NOSPACE)) { - (void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor)); - desc->sig = static_cast(0xDEADBEEF); - desc->addr = rel_addr; - desc->req_size = reqSize; - desc->blk_size = NextPowerOf2(reqSize); - return static_cast(rel_addr * min_); - } else { - return NOSPACE; - } -} - -void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) { - MS_ASSERT(desc->sig == 0XDEADBEEF); - rel_addr_t rel_addr = desc->addr; - size_t blk_size = desc->blk_size; - size_t req_size = desc->req_size; - FreeBuddySeg(rel_addr, blk_size, req_size); -} - -void BuddySpace::Free(const BSpaceDescriptor *desc) { - std::lock_guard lock(mutex_); - return FreeNoLock(desc); -} - -std::ostream &operator<<(std::ostream &os, const BuddySpace &s) { - os << "1 unit = " << s.GetMinSize() << "\n" - << "Size of buddy space = " << s.GetMaxSize() << "\n" - << "Number of levels = " << s.num_lvl_ << "\n\n" - << "Percent free = " << s.PercentFree() << "\n" - << "Dumping count array : " - << "\n"; - for (int i = 0; i < s.num_lvl_; i++) { - os << "[" << i << "] = " << s.count_[i] << " "; - if (((i + 1) % 4) == 0) { - os << "\n"; - } - } - os << "\n"; - os << "Dumping allocation info:" - << "\n"; - auto max_addr = static_cast(BitLeftShift(1, s.num_lvl_ - 1)); - rel_addr_t addr = 0; - while (addr < max_addr) { - size_t sz = 0; - BuddySpace::STATE st; - s.GetBuddySegState(addr, &sz, &st); - os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : " - << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unkonwn")) - << "\n"; - addr += sz; - } - return os; -} - -void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const { - char byte; - int pos; - int offset; - uint64_t val = 0; - int shift; - pos = BitRightShift(rel_addr, 2); - offset = rel_addr % 4; - shift = offset * 2; - byte = map_[pos]; - switch (offset) { - case 0: - val = byte; - break; - case 1: - case 3: - if (offset == 1) { - val = BitLeftShift(BitAnd(byte, 0x30), shift); - } else { - val = BitLeftShift(BitAnd(byte, 0x03), shift); - } - break; - case 2: - val = BitLeftShift(BitAnd(byte, 0x0F), shift); - break; - } - if (BitAnd(val, ONE_BIT)) { - *rel_sz = 1; - } else if (BitAnd(val, TWO_BIT)) { - *rel_sz = 2; - } else if (BitAnd(val, MORE_BIT)) { - log_t lg = BitAnd(val, 0x0F); - *rel_sz = BitLeftShift(1, lg + 2); - } else { - *st = STATE::kEmpty; - return; - } - *st = BitAnd(val, ALLOC_BIT) ? STATE::kAlloc : STATE::kFree; -} - -void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) { - int clr; - int mask; - int pos; - int offset; - int val = 0; - int shift; - auto log_sz = static_cast(Log2(rel_sz)); - pos = BitRightShift(rel_addr, 2); - offset = rel_addr % 4; - shift = offset * 2; - if (rel_sz == 1) { - val = ONE_BIT; - mask = 0xC0; - } else if (rel_sz == 2) { - val = TWO_BIT; - mask = 0xF0; - } else { - val = BitOr(log_sz - 2, MORE_BIT); - mask = 0xFF; - } - if (st == STATE::kAlloc) { - val = BitOr(val, ALLOC_BIT); - } else if (st == STATE::kFree) { - val = BitAnd(val, ~(static_cast(ALLOC_BIT))); - } else if (st == STATE::kEmpty) { - val = 0; - } - clr = static_cast(~(BitRightShift(mask, shift))); - map_[pos] = static_cast(BitAnd(map_[pos], clr)); - map_[pos] = static_cast(BitOr(map_[pos], BitRightShift(val, shift))); - if (st == STATE::kAlloc) { - count_[log_sz]--; - } else if (st == STATE::kFree) { - count_[log_sz]++; - if (rel_addr < hint_[log_sz]) { - hint_[log_sz] = rel_addr; - } - } -} - -void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) { - while (blk_sz < BitLeftShift(1, num_lvl_)) { - rel_addr_t buddy = BitEx(addr, blk_sz); - size_t sz = 0; - STATE st; - GetBuddySegState(buddy, &sz, &st); - if (st == STATE::kFree && sz == blk_sz) { - auto log_sz = static_cast(Log2(blk_sz)); - rel_addr_t left = (buddy < addr) ? buddy : addr; - rel_addr_t right = left + blk_sz; - MS_ASSERT(count_[log_sz] >= 2); - count_[log_sz] -= 2; - SetBuddySegState(right, blk_sz, STATE::kEmpty); - SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree); - for (int i = 0; i < log_sz; i++) { - if (hint_[i] == right) { - hint_[i] = left; - } - } - addr = left; - blk_sz <<= 1u; - } else { - break; - } - } -} - -void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { - MS_ASSERT(ask_sz < blk_sz); - uint32_t inx = Log2(blk_sz); - size_t remaining_sz = ask_sz; - for (int i = inx; i > 0; i--) { - size_t b_size = BitLeftShift(1, i); - size_t half_sz = BitRightShift(b_size, 1); - count_[i]--; - SetBuddySegState(addr, half_sz, STATE::kFree); - SetBuddySegState(addr + half_sz, half_sz, STATE::kFree); - if (remaining_sz >= half_sz) { - SetBuddySegState(addr, half_sz, STATE::kAlloc); - remaining_sz -= half_sz; - if (remaining_sz == 0) { - break; - } - addr += half_sz; - } - } -} - -void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { - MS_ASSERT(ask_sz < blk_sz); - uint32_t inx = Log2(blk_sz); - size_t remaining_sz = ask_sz; - for (int i = inx; i > 0; i--) { - size_t b_size = BitLeftShift(1, i); - size_t half_sz = BitRightShift(b_size, 1); - if (remaining_sz >= half_sz) { -#ifdef DEBUG - { - size_t sz = 0; - STATE st; - GetBuddySegState(addr, &sz, &st); - MS_ASSERT(sz == half_sz && st == STATE::kAlloc); - } -#endif - SetBuddySegState(addr, half_sz, STATE::kFree); - remaining_sz -= half_sz; - if (remaining_sz == 0) { - JoinBuddySeg(addr, half_sz); - break; - } - addr += half_sz; - } - } -} - -rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept { - uint32_t blk_size = NextPowerOf2(req_size); - int start_inx = static_cast(Log2(blk_size)); - bool found = false; - rel_addr_t ask_addr = 0; - auto max_addr = static_cast(BitLeftShift(1, num_lvl_ - 1)); - STATE st; - size_t sz = 0; - for (int i = start_inx; !found && i < num_lvl_; i++) { - MS_ASSERT(count_[i] >= 0); - if (count_[i] == 0) { - continue; - } - auto blk_sz = static_cast(BitLeftShift(1, i)); - ask_addr = hint_[i]; - while (ask_addr < max_addr && !found) { - GetBuddySegState(ask_addr, &sz, &st); - if (st == STATE::kFree && sz == blk_sz) { - found = true; - } else { - MS_ASSERT(st != STATE::kEmpty); - ask_addr += ((sz > blk_sz) ? sz : blk_sz); - } - } - } - if (found) { - if (sz > req_size) { - TrimBuddySeg(ask_addr, sz, req_size); - } else { - SetBuddySegState(ask_addr, sz, STATE::kAlloc); - hint_[start_inx] = ask_addr; - } - return ask_addr; - } else { - return static_cast(NOSPACE); - } -} - -void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) { - if (req_size == blk_size) { -#ifdef DEBUG - { - size_t sz = 0; - STATE st; - GetBuddySegState(addr, &sz, &st); - } -#endif - SetBuddySegState(addr, blk_size, STATE::kFree); - JoinBuddySeg(addr, blk_size); - } else { - UnTrimBuddySeg(addr, blk_size, req_size); - } -} - -int BuddySpace::PercentFree() const { - uint64_t total_free_sz = 0; - uint64_t max_sz_in_unit = BitLeftShift(1, num_lvl_ - 1); - // Go through the count array without lock - for (int i = 0; i < num_lvl_; i++) { - int cnt = count_[i]; - if (cnt == 0) { - continue; - } - uint64_t blk_sz = BitLeftShift(1, i); - total_free_sz += (blk_sz * cnt); - } - return static_cast(static_cast(total_free_sz) / static_cast(max_sz_in_unit) * 100); -} - -BuddySpace::BuddySpace(int log_min, int num_lvl) - : hint_(nullptr), - count_(nullptr), - map_(nullptr), - log_min_(log_min), - num_lvl_(num_lvl), - min_(0), - max_(0), - ptr_(nullptr) {} - -BuddySpace::~BuddySpace() { - if (ptr_ != nullptr) { - free(ptr_); - } - hint_ = nullptr; - count_ = nullptr; - map_ = nullptr; -} - -Status BuddySpace::CreateBuddySpace(std::unique_ptr *out_bs, int log_min, int num_lvl) { - Status rc; - auto bs = new (std::nothrow) BuddySpace(log_min, num_lvl); - if (bs == nullptr) { - return Status(StatusCode::kOutOfMemory); - } - rc = bs->Init(); - if (rc.IsOk()) { - (*out_bs).reset(bs); - } else { - delete bs; - } - return rc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/buddy.h b/mindspore/ccsrc/dataset/util/buddy.h deleted file mode 100644 index 08c05cbbdb..0000000000 --- a/mindspore/ccsrc/dataset/util/buddy.h +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_BUDDY_H_ -#define DATASET_UTIL_BUDDY_H_ - -#include -#include -#include -#include -#include -#include -#include "dataset/util/status.h" - -using addr_t = int64_t; -using rel_addr_t = int32_t; -using log_t = int; -#define ALLOC_BIT 0x80 -#define ONE_BIT 0x40 -#define TWO_BIT 0x20 -#define MORE_BIT 0x10 -#define NOSPACE ((addr_t)(-1)) -namespace mindspore { -namespace dataset { -struct BSpaceDescriptor { - int32_t sig; - rel_addr_t addr; - size_t req_size; - size_t blk_size; -}; - -class BuddySpace { - public: - // C++11 feature. Change STATE into a type safe class with - // the keyword. Don't take out the keyword 'class' - enum class STATE { kFree, kAlloc, kEmpty }; - - BuddySpace(const BuddySpace &) = delete; - - BuddySpace &operator=(const BuddySpace &) = delete; - - virtual ~BuddySpace(); - - Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept; - - void Free(const BSpaceDescriptor *desc); - - uint64_t GetMinSize() const { return min_; } - - uint64_t GetMaxSize() const { return max_; } - - int PercentFree() const; - - friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s); - - static uint64_t NextPowerOf2(uint64_t n) { - if (n <= 1) { - return 1; - } - n = n - 1; - while (n & (n - 1)) { - n = n & (n - 1); - } - return n << 1; - } - - static uint32_t Log2(uint64_t n) { - uint32_t cnt = 0; - while (n >>= 1) { - cnt++; - } - return cnt; - } - - static Status CreateBuddySpace(std::unique_ptr *out_bs, int log_min = 15, int num_lvl = 18); - - private: - rel_addr_t *hint_; - int *count_; - char *map_; - int log_min_; - int num_lvl_; - uint64_t min_; - uint64_t max_; - void *ptr_; - std::mutex mutex_; - - explicit BuddySpace(int log_min = 15, int num_lvl = 18); - - Status Init(); - - addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept; - - void FreeNoLock(const BSpaceDescriptor *desc); - - uint32_t SizeToBlock(const uint64_t sz) const { - uint32_t reqSize = (sz / min_); - if (sz % min_) { - reqSize++; - } - return reqSize; - } - - void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const; - - void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st); - - void JoinBuddySeg(rel_addr_t addr, size_t blk_sz); - - void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); - - void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); - - rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept; - - void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_BUDDY_H_ diff --git a/mindspore/ccsrc/dataset/util/cache_pool.cc b/mindspore/ccsrc/dataset/util/cache_pool.cc deleted file mode 100644 index 92504cd063..0000000000 --- a/mindspore/ccsrc/dataset/util/cache_pool.cc +++ /dev/null @@ -1,202 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "common/utils.h" -#include "dataset/util/cache_pool.h" -#include "dataset/util/services.h" - -namespace mindspore { -namespace dataset { -CachePool::CachePool(const value_allocator &alloc, const std::string &root) - : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} - -Status CachePool::DoServiceStart() { - tree_ = std::make_shared(); - // If we are given a disk path, set up the StorageManager - if (!root_.toString().empty()) { - Path spill = GetSpillPath(); - RETURN_IF_NOT_OK(spill.CreateDirectories()); - sm_ = std::make_shared(spill); - RETURN_IF_NOT_OK(sm_->ServiceStart()); - MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); - } - return Status::OK(); -} -Status CachePool::DoServiceStop() { - Status rc; - Status rc2; - if (sm_ != nullptr) { - rc = sm_->ServiceStop(); - if (rc.IsError()) { - rc2 = rc; - } - } - sm_.reset(); - for (auto &bl : *tree_) { - if (bl.ptr != nullptr) { - alloc_.deallocate(bl.ptr, bl.sz); - } - } - tree_.reset(); - if (!root_.toString().empty()) { - Path spill = GetSpillPath(); - auto it = Path::DirIterator::OpenDirectory(&spill); - while (it->hasNext()) { - rc = it->next().Remove(); - if (rc.IsError() && rc2.IsOk()) { - rc2 = rc; - } - } - rc = spill.Remove(); - if (rc.IsError() && rc2.IsOk()) { - rc2 = rc; - } - } - return rc2; -} -CachePool::~CachePool() noexcept { (void)ServiceStop(); } -Status CachePool::Insert(const std::vector &buf, CachePool::key_type *key) { - DataLocator bl; - Status rc; - size_t sz = 0; - // We will consolidate all the slices into one piece. - for (auto &v : buf) { - sz += v.GetSize(); - } - bl.sz = sz; - try { - bl.ptr = alloc_.allocate(sz); - // We will do a piecewise copy. - WritableSlice dest(bl.ptr, bl.sz); - size_t pos = 0; - for (auto &v : buf) { - WritableSlice out(dest, pos); - rc = WritableSlice::Copy(&out, v); - if (rc.IsError()) { - break; - } - pos += v.GetSize(); - } - if (rc.IsError()) { - alloc_.deallocate(bl.ptr, sz); - bl.ptr = nullptr; - return rc; - } - } catch (std::bad_alloc &e) { - if (sm_ != nullptr) { - RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); - // We have an assumption 0 is not a valid key from the design of AutoIndexObj. - // Make sure it is not 0. - if (bl.storage_key == 0) { - RETURN_STATUS_UNEXPECTED("Key 0 is returned which is unexpected"); - } - } else { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } - } - rc = tree_->insert(bl, key); - if (rc.IsError() && bl.ptr != nullptr) { - alloc_.deallocate(bl.ptr, sz); - } - return rc; -} -Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { - RETURN_UNEXPECTED_IF_NULL(dest); - auto r = tree_->Search(key); - if (r.second) { - auto &it = r.first; - if (it->ptr != nullptr) { - ReadableSlice src(it->ptr, it->sz); - RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src)); - } else if (sm_ != nullptr) { - size_t expectedLength = 0; - RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength)); - if (expectedLength != it->sz) { - MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "." - << " Internal key: " << key << "\n"; - RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); - } - } - if (bytesRead != nullptr) { - *bytesRead = it->sz; - } - } else { - RETURN_STATUS_UNEXPECTED("Key not found"); - } - return Status::OK(); -} -const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } -Path CachePool::GetSpillPath() const { - auto spill = Path(root_) / subfolder_; - return spill; -} -CachePool::CacheStat CachePool::GetStat() const { - CacheStat cs{0}; - for (auto &it : *tree_) { - if (it.ptr != nullptr) { - ++cs.num_mem_cached; - } else { - ++cs.num_disk_cached; - } - } - return cs; -} -Status CachePool::Spill(CachePool::DataLocator *dl) { - if (sm_ == nullptr) { - RETURN_STATUS_UNEXPECTED("No disk storage to spill"); - } - RETURN_UNEXPECTED_IF_NULL(dl); - RETURN_UNEXPECTED_IF_NULL(dl->ptr); - if (dl->storage_key == 0) { - ReadableSlice data(dl->ptr, dl->sz); - RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); - } - alloc_.deallocate(dl->ptr, dl->sz); - dl->ptr = nullptr; - return Status::OK(); -} -Status CachePool::Locate(CachePool::DataLocator *dl) { - RETURN_UNEXPECTED_IF_NULL(dl); - if (dl->ptr == nullptr) { - if (sm_ == nullptr) { - RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); - } - try { - dl->ptr = alloc_.allocate(dl->sz); - WritableSlice dest(dl->ptr, dl->sz); - Status rc = Read(dl->storage_key, &dest); - if (rc.IsError()) { - alloc_.deallocate(dl->ptr, dl->sz); - dl->ptr = nullptr; - return rc; - } - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } - } - return Status::OK(); -} -size_t CachePool::GetSize(CachePool::key_type key) const { - auto r = tree_->Search(key); - if (r.second) { - auto &it = r.first; - return it->sz; - } else { - return 0; - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/cache_pool.h b/mindspore/ccsrc/dataset/util/cache_pool.h deleted file mode 100644 index d35617d0e4..0000000000 --- a/mindspore/ccsrc/dataset/util/cache_pool.h +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_CACHE_POOL_H_ -#define DATASET_UTIL_CACHE_POOL_H_ - -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/service.h" -#include "dataset/util/slice.h" -#include "dataset/util/storage_manager.h" -#include "dataset/util/auto_index.h" - -namespace mindspore { -namespace dataset { -/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of -/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to -/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to -/// restore the buffer. -/// \see ReadableSlice -class CachePool : public Service { - public: - using base_type = uint8_t; - using pointer = base_type *; - using const_pointer = const base_type *; - using reference = base_type &; - using const_reference = const base_type &; - using value_allocator = Allocator; - - // An internal class to locate the whereabouts of a backed up buffer which can be either in - class DataLocator { - public: - DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} - ~DataLocator() = default; - DataLocator(const DataLocator &other) = default; - DataLocator &operator=(const DataLocator &other) = default; - DataLocator(DataLocator &&other) noexcept { - ptr = other.ptr; - sz = other.sz; - storage_key = other.storage_key; - other.ptr = nullptr; - other.sz = 0; - other.storage_key = 0; - } - DataLocator &operator=(DataLocator &&other) noexcept { - if (&other != this) { - ptr = other.ptr; - sz = other.sz; - storage_key = other.storage_key; - other.ptr = nullptr; - other.sz = 0; - other.storage_key = 0; - } - return *this; - } - pointer ptr; - size_t sz; - StorageManager::key_type storage_key; - }; - - using data_index = AutoIndexObj; - using key_type = data_index::key_type; - using bl_alloc_type = typename value_allocator::template rebind::other; - - /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and - /// how many elements are spilled to disk. - struct CacheStat { - int64_t num_mem_cached; - int64_t num_disk_cached; - }; - - /// \brief Constructor - /// \param alloc Allocator to allocate memory from - /// \param root Optional disk folder to spill - explicit CachePool(const value_allocator &alloc, const std::string &root = ""); - - CachePool(const CachePool &) = delete; - CachePool(CachePool &&) = delete; - CachePool &operator=(const CachePool &) = delete; - CachePool &operator=(CachePool &&) = delete; - ~CachePool() noexcept; - - Status DoServiceStart() override; - Status DoServiceStop() override; - - Path GetSpillPath() const; - - /// \brief Insert a sequence of ReadableSlice objects into the pool. - /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. - /// \param[in] buf A sequence of ReadableSlice objects. - /// \param[out] key Generated key - /// \return Error code - Status Insert(const std::vector &buf, key_type *key); - /// \brief Restore a cached buffer (from memory or disk) - /// \param[in] key A previous key returned from Insert - /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice - /// \param[out] bytesRead Optional. Number of bytes read. - /// \return Error code - Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; - - Status Spill(DataLocator *dl); - - Status Locate(DataLocator *dl); - - size_t GetSize(key_type key) const; - - /// \brief Get statistics. - /// \return CacheStat object - CacheStat GetStat() const; - - const value_allocator &get_allocator() const; - - std::string MyName() const { return subfolder_; } - - private: - value_allocator alloc_; - Path root_; - const std::string subfolder_; - std::shared_ptr sm_; - std::shared_ptr tree_; -}; -} // namespace dataset -} // namespace mindspore -#endif diff --git a/mindspore/ccsrc/dataset/util/circular_pool.cc b/mindspore/ccsrc/dataset/util/circular_pool.cc deleted file mode 100644 index 0c68dab81b..0000000000 --- a/mindspore/ccsrc/dataset/util/circular_pool.cc +++ /dev/null @@ -1,222 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/circular_pool.h" - -#include -#include -#include -#include "./securec.h" -#include "dataset/util/system_pool.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -Status CircularPool::AddOneArena() { - Status rc; - std::shared_ptr b; - RETURN_IF_NOT_OK(Arena::CreateArena(&b, arena_size_)); - tail_ = b.get(); - cur_size_in_mb_ += arena_size_; - mem_segments_.push_back(std::move(b)); - return Status::OK(); -} - -ListOfArenas::iterator CircularPool::CircularIterator::Next() { - ListOfArenas::iterator it = dp_->mem_segments_.begin(); - uint32_t size = dp_->mem_segments_.size(); - // This is what we return - it += cur_; - // Prepare for the next round - cur_++; - if (cur_ == size) { - if (start_ == 0) { - has_next_ = false; - } else { - wrap_ = true; - cur_ = 0; - } - } else if (cur_ == start_) { - has_next_ = false; - } - return it; -} - -bool CircularPool::CircularIterator::has_next() const { return has_next_; } - -void CircularPool::CircularIterator::Reset() { - wrap_ = false; - has_next_ = false; - if (!dp_->mem_segments_.empty()) { - // Find the buddy arena that corresponds to the tail. - cur_tail_ = dp_->tail_; - auto list_end = dp_->mem_segments_.end(); - auto it = std::find_if(dp_->mem_segments_.begin(), list_end, - [this](const std::shared_ptr &b) { return b.get() == cur_tail_; }); - MS_ASSERT(it != list_end); - start_ = std::distance(dp_->mem_segments_.begin(), it); - cur_ = start_; - has_next_ = true; - } -} - -CircularPool::CircularIterator::CircularIterator(CircularPool *dp) : dp_(dp) { Reset(); } - -Status CircularPool::Allocate(size_t n, void **p) { - if (p == nullptr) { - RETURN_STATUS_UNEXPECTED("p is null"); - } - Status rc; - void *ptr = nullptr; - do { - SharedLock lock_s(&rw_lock_); - int prevSzInMB = cur_size_in_mb_; - bool move_tail = false; - CircularIterator cirIt(this); - while (cirIt.has_next()) { - auto it = cirIt.Next(); - Arena *ba = it->get(); - // If we are asked to move forward the tail - if (move_tail) { - Arena *expected = cirIt.cur_tail_; - (void)atomic_compare_exchange_weak(&tail_, &expected, ba); - move_tail = false; - } - rc = ba->Allocate(n, &ptr); - if (rc.IsOk()) { - *p = ptr; - break; - } else if (rc.IsOutofMemory()) { - // Make the next arena a new tail and continue. - move_tail = true; - } else { - return rc; - } - } - - // Handle the case we have done one round robin search. - if (ptr == nullptr) { - // If we have room to expand. - if (unlimited_ || cur_size_in_mb_ < max_size_in_mb_) { - // lock in exclusively mode. - lock_s.Upgrade(); - // Check again if someone has already expanded. - if (cur_size_in_mb_ == prevSzInMB) { - RETURN_IF_NOT_OK(AddOneArena()); - } - // Re-acquire the shared lock and try again - lock_s.Downgrade(); - } else { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } - } - } while (ptr == nullptr); - return rc; -} - -void CircularPool::Deallocate(void *p) { - // Lock in the chain in shared mode and find out which - // segment it comes from - SharedLock lock(&rw_lock_); - auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { - char *q = reinterpret_cast(p); - char *base = const_cast(reinterpret_cast(b->get_base_addr())); - return (q > base && q < base + b->get_max_size()); - }); - lock.Unlock(); - it->get()->Deallocate(p); -} - -Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) { - // Lock in the chain in shared mode and find out which - // segment it comes from - if (pp == nullptr) { - RETURN_STATUS_UNEXPECTED("pp is null"); - } - void *p = *pp; - SharedLock lock(&rw_lock_); - auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { - char *q = reinterpret_cast(p); - char *base = const_cast(reinterpret_cast(b->get_base_addr())); - return (q > base && q < base + b->get_max_size()); - }); - lock.Unlock(); - MS_ASSERT(it != mem_segments_.end()); - Arena *ba = it->get(); - Status rc = ba->Reallocate(pp, old_sz, new_sz); - if (rc.IsOutofMemory()) { - // The current arena has no room for the bigger size. - // Allocate free space from another arena and copy - // the content over. - void *q = nullptr; - rc = this->Allocate(new_sz, &q); - RETURN_IF_NOT_OK(rc); - errno_t err = memcpy_s(q, new_sz, p, old_sz); - if (err) { - this->Deallocate(q); - RETURN_STATUS_UNEXPECTED(std::to_string(err)); - } - *pp = q; - ba->Deallocate(p); - } - return Status::OK(); -} - -uint64_t CircularPool::get_max_size() const { return mem_segments_.front()->get_max_size(); } - -int CircularPool::PercentFree() const { - int percent_free = 0; - int num_arena = 0; - for (auto const &p : mem_segments_) { - percent_free += p->PercentFree(); - num_arena++; - } - if (num_arena) { - return percent_free / num_arena; - } else { - return 100; - } -} - -CircularPool::CircularPool(int max_size_in_gb, int arena_size) - : unlimited_(max_size_in_gb <= 0), - max_size_in_mb_(unlimited_ ? std::numeric_limits::max() : max_size_in_gb * 1024), - arena_size_(arena_size), - cur_size_in_mb_(0) {} - -Status CircularPool::CreateCircularPool(std::shared_ptr *out_pool, int max_size_in_gb, int arena_size, - bool createOneArena) { - Status rc; - if (out_pool == nullptr) { - RETURN_STATUS_UNEXPECTED("pPool is null"); - } - auto pool = new (std::nothrow) CircularPool(max_size_in_gb, arena_size); - if (pool == nullptr) { - return Status(StatusCode::kOutOfMemory); - } - if (createOneArena) { - rc = pool->AddOneArena(); - } - if (rc.IsOk()) { - (*out_pool).reset(pool); - } else { - delete pool; - } - return rc; -} - -CircularPool::~CircularPool() = default; -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/circular_pool.h b/mindspore/ccsrc/dataset/util/circular_pool.h deleted file mode 100644 index 3c52659799..0000000000 --- a/mindspore/ccsrc/dataset/util/circular_pool.h +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_CIRCULAR_POOL_H_ -#define DATASET_UTIL_CIRCULAR_POOL_H_ - -#include -#include -#include -#include "dataset/util/memory_pool.h" -#include "dataset/util/arena.h" -#include "dataset/util/lock.h" - -namespace mindspore { -namespace dataset { -using ListOfArenas = std::vector>; - -// This is a dynamic memory pool built on top of memory -// segment each of which is 4G in size. Initially we start -// with one segment, and gradually add segments (not -// guaranteed contiguous) until we reach 32G in size. There -// is an assumption about this kind of memory pool. Allocated -// memory is not held for the whole duration of the pool and -// will be released soon. Based on this assumption, memory is -// obtained from the tail while allocated memory is returned -// to the head of the pool. -class CircularPool : public MemoryPool { - public: - class CircularIterator { - friend class CircularPool; - - public: - explicit CircularIterator(CircularPool *dp); - - ~CircularIterator() = default; - - bool has_next() const; - - ListOfArenas::iterator Next(); - - void Reset(); - - private: - CircularPool *dp_; - Arena *cur_tail_{}; - uint32_t start_{}; - uint32_t cur_{}; - bool wrap_{}; - bool has_next_{}; - }; - - CircularPool(const CircularPool &) = delete; - - CircularPool &operator=(const CircularPool &) = delete; - - ~CircularPool() override; - - Status Allocate(size_t n, void **) override; - - Status Reallocate(void **, size_t old_size, size_t new_size) override; - - void Deallocate(void *) override; - - uint64_t get_max_size() const override; - - int PercentFree() const override; - - friend std::ostream &operator<<(std::ostream &os, const CircularPool &s) { - int i = 0; - for (auto it = s.mem_segments_.begin(); it != s.mem_segments_.end(); ++it, ++i) { - os << "Dumping segment " << i << "\n" << *(it->get()); - } - return os; - } - - static Status CreateCircularPool(std::shared_ptr *out_pool, int max_size_in_gb = -1, - int arena_size = 4096, bool create_one_arena = false); - - private: - ListOfArenas mem_segments_; - std::atomic tail_{}; - bool unlimited_; - int max_size_in_mb_; - int arena_size_; - int cur_size_in_mb_; - RWLock rw_lock_; - - // We can take negative or 0 as input which means unlimited. - CircularPool(int max_size_in_gb, int arena_size); - - Status AddOneArena(); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_CIRCULAR_POOL_H_ diff --git a/mindspore/ccsrc/dataset/util/cond_var.cc b/mindspore/ccsrc/dataset/util/cond_var.cc deleted file mode 100644 index 8b1099fb71..0000000000 --- a/mindspore/ccsrc/dataset/util/cond_var.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/cond_var.h" -#include -#include -#include "dataset/util/services.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {} - -Status CondVar::Wait(std::unique_lock *lck, const std::function &pred) { - try { - if (svc_ != nullptr) { - // If this cv registers with a global resource tracking, then wait unconditionally. - auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); }; - cv_.wait(*lck, f); - // If we are interrupted, override the return value if this is the master thread. - // Master thread is being interrupted mostly because of some thread is reporting error. - RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus())); - } else { - // Otherwise we wake up once a while to check for interrupt (for this thread). - auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); }; - while (!f()) { - (void)cv_.wait_for(*lck, std::chrono::milliseconds(1)); - } - RETURN_IF_INTERRUPTED(); - } - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - return Status::OK(); -} - -CondVar::~CondVar() noexcept { - if (svc_ != nullptr) { - (void)svc_->Deregister(my_name_); - svc_ = nullptr; - } -} - -void CondVar::NotifyOne() noexcept { cv_.notify_one(); } - -void CondVar::NotifyAll() noexcept { cv_.notify_all(); } - -Status CondVar::Register(std::shared_ptr svc) { - Status rc = svc->Register(my_name_, this); - if (rc.IsOk()) { - svc_ = svc; - } - return rc; -} - -void CondVar::Interrupt() { - IntrpResource::Interrupt(); - cv_.notify_all(); -} - -std::string CondVar::my_name() const { return my_name_; } - -Status CondVar::Deregister() { - if (svc_) { - Status rc = svc_->Deregister(my_name_); - svc_ = nullptr; - return rc; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/cond_var.h b/mindspore/ccsrc/dataset/util/cond_var.h deleted file mode 100644 index b23dcd566e..0000000000 --- a/mindspore/ccsrc/dataset/util/cond_var.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_COND_VAR_H_ -#define DATASET_UTIL_COND_VAR_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/intrp_resource.h" -#include "dataset/util/intrp_service.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class CondVar : public IntrpResource { - public: - CondVar(); - - ~CondVar() noexcept; - - Status Wait(std::unique_lock *lck, const std::function &pred); - - void Interrupt() override; - - void NotifyOne() noexcept; - - void NotifyAll() noexcept; - - Status Register(std::shared_ptr svc); - - std::string my_name() const; - - Status Deregister(); - - protected: - std::condition_variable cv_; - std::shared_ptr svc_; - - private: - std::string my_name_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_COND_VAR_H_ diff --git a/mindspore/ccsrc/dataset/util/intrp_resource.h b/mindspore/ccsrc/dataset/util/intrp_resource.h deleted file mode 100644 index 52024cb90a..0000000000 --- a/mindspore/ccsrc/dataset/util/intrp_resource.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_INTRP_RESOURCE_H_ -#define DATASET_UTIL_INTRP_RESOURCE_H_ - -#include -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class IntrpResource { - public: - enum class State : int { kRunning, kInterrupted }; - - IntrpResource() : st_(State::kRunning) {} - - virtual ~IntrpResource() = default; - - virtual void Interrupt() { st_ = State::kInterrupted; } - - virtual void ResetIntrpState() { st_ = State::kRunning; } - - State CurState() const { return st_; } - - bool Interrupted() const { return CurState() == State::kInterrupted; } - - virtual Status GetInterruptStatus() const { - if (Interrupted()) { - return Status(StatusCode::kInterrupted); - } - return Status::OK(); - } - - protected: - std::atomic st_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_INTRP_RESOURCE_H_ diff --git a/mindspore/ccsrc/dataset/util/intrp_service.cc b/mindspore/ccsrc/dataset/util/intrp_service.cc deleted file mode 100644 index da8dde992c..0000000000 --- a/mindspore/ccsrc/dataset/util/intrp_service.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/intrp_service.h" -#include -#include "common/utils.h" -#include "dataset/util/services.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -IntrpService::IntrpService() : high_water_mark_(0) { (void)ServiceStart(); } - -IntrpService::~IntrpService() noexcept { - MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << "."; - if (!all_intrp_resources_.empty()) { - try { - InterruptAll(); - } catch (const std::exception &e) { - // Ignore all error as we can't throw in the destructor. - } - } - (void)ServiceStop(); -} - -Status IntrpService::Register(const std::string &name, IntrpResource *res) { - SharedLock stateLck(&state_lock_); - // Now double check the state - if (ServiceState() != STATE::kRunning) { - return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "Interrupt service is shutting down"); - } else { - std::lock_guard lck(mutex_); - try { - std::ostringstream ss; - ss << this_thread::get_id(); - MS_LOG(DEBUG) << "Register resource with name " << name << ". Thread ID " << ss.str() << "."; - auto it = all_intrp_resources_.emplace(name, res); - if (it.second == false) { - return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, name); - } - high_water_mark_++; - } catch (std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - } - return Status::OK(); -} - -Status IntrpService::Deregister(const std::string &name) noexcept { - std::lock_guard lck(mutex_); - try { - std::ostringstream ss; - ss << this_thread::get_id(); - MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << "."; - auto n = all_intrp_resources_.erase(name); - if (n == 0) { - MS_LOG(INFO) << "Key " << name << " not found."; - } - } catch (std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - return Status::OK(); -} - -void IntrpService::InterruptAll() noexcept { - std::lock_guard lck(mutex_); - for (auto const &it : all_intrp_resources_) { - std::string kName = it.first; - try { - it.second->Interrupt(); - } catch (const std::exception &e) { - // continue the clean up. - } - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/intrp_service.h b/mindspore/ccsrc/dataset/util/intrp_service.h deleted file mode 100644 index de1d5eb753..0000000000 --- a/mindspore/ccsrc/dataset/util/intrp_service.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_INTRP_SERVICE_H_ -#define DATASET_UTIL_INTRP_SERVICE_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/intrp_resource.h" -#include "dataset/util/service.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -using SvcAllocator = Allocator>; - -class IntrpService : public Service { - public: - IntrpService(); - - ~IntrpService() noexcept override; - - IntrpService(const IntrpService &) = delete; - - IntrpService &operator=(const IntrpService &) = delete; - - Status Register(const std::string &name, IntrpResource *res); - - Status Deregister(const std::string &name) noexcept; - - void InterruptAll() noexcept; - - Status DoServiceStart() override { return Status::OK(); } - - Status DoServiceStop() override { return Status::OK(); } - - private: - int high_water_mark_; - std::mutex mutex_; - std::map all_intrp_resources_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_INTRP_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/util/lock.cc b/mindspore/ccsrc/dataset/util/lock.cc deleted file mode 100644 index bde9d84005..0000000000 --- a/mindspore/ccsrc/dataset/util/lock.cc +++ /dev/null @@ -1,185 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/lock.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -void SpinLock::Lock() { - while (true) { - int expected = kUnlocked; - if (val_.compare_exchange_weak(expected, kLocked)) { - break; - } - } -} - -bool SpinLock::TryLock() { - int expected = kUnlocked; - return val_.compare_exchange_strong(expected, kLocked); -} - -void SpinLock::Unlock() noexcept { val_.store(kUnlocked); } - -void RWLock::LockShared() { - std::unique_lock lck(mtx_); - waiting_readers_ += 1; - read_cv_.wait(lck, [this]() { return (waiting_writers_ == 0 && status_ >= 0); }); - waiting_readers_ -= 1; - status_ += 1; -} - -void RWLock::Unlock() noexcept { - std::unique_lock lck(mtx_); - if (status_ == -1) { - // I am the writer. By definition, no other writer nor reader. - status_ = 0; - } else if (status_ > 0) { - // One less reader - status_ -= 1; - } - // Wake up writer only if there is no reader. - if (waiting_writers_ > 0) { - if (status_ == 0) { - write_cv_.notify_one(); - } - } else { - read_cv_.notify_all(); - } -} - -void RWLock::Upgrade() { - std::unique_lock lck(mtx_); - MS_ASSERT(status_); - if (status_ == -1) { - // I am a writer already. - return; - } else if (status_ == 1) { - // If I am the only reader. Just change the status. - status_ = -1; - return; - } else { - // In all other cases, let of the shared lock and relock in exclusive. - lck.unlock(); - this->Unlock(); - this->LockExclusive(); - } -} - -void RWLock::Downgrade() { - std::unique_lock lck(mtx_); - MS_ASSERT(status_); - if (status_ == -1) { - // If there are no other writers waiting, just change the status - if (waiting_writers_ == 0) { - status_ = 1; - } else { - // Otherwise just unlock and relock in shared - lck.unlock(); - this->Unlock(); - this->LockShared(); - } - } else if (status_ > 0) { - return; - } -} - -SharedLock::SharedLock(RWLock *rw) : rw_(rw), ownlock_(false) { - rw_->LockShared(); - ownlock_ = true; -} - -SharedLock::~SharedLock() { - if (ownlock_) { - rw_->Unlock(); - ownlock_ = false; - } - rw_ = nullptr; -} - -void SharedLock::Unlock() { - MS_ASSERT(ownlock_ == true); - rw_->Unlock(); - ownlock_ = false; -} - -void SharedLock::Lock() { - MS_ASSERT(ownlock_ == false); - rw_->LockShared(); - ownlock_ = true; -} - -void SharedLock::Upgrade() { - MS_ASSERT(ownlock_ == true); - rw_->Upgrade(); -} - -void SharedLock::Downgrade() { - MS_ASSERT(ownlock_ == true); - rw_->Downgrade(); -} - -UniqueLock::UniqueLock(RWLock *rw) : rw_(rw), ownlock_(false) { - rw_->LockExclusive(); - ownlock_ = true; -} - -UniqueLock::~UniqueLock() { - if (ownlock_) { - rw_->Unlock(); - ownlock_ = false; - } - rw_ = nullptr; -} - -void UniqueLock::Unlock() { - MS_ASSERT(ownlock_ == true); - rw_->Unlock(); - ownlock_ = false; -} - -void UniqueLock::Lock() { - MS_ASSERT(ownlock_ == false); - rw_->LockExclusive(); - ownlock_ = true; -} - -LockGuard::LockGuard(SpinLock *lock) : lck_(lock), own_lock_(false) { - lck_->Lock(); - own_lock_ = true; -} - -LockGuard::~LockGuard() { - if (own_lock_) { - lck_->Unlock(); - own_lock_ = false; - } - lck_ = nullptr; -} - -void LockGuard::Unlock() { - MS_ASSERT(own_lock_); - lck_->Unlock(); - own_lock_ = false; -} - -void LockGuard::Lock() { - MS_ASSERT(own_lock_ == false); - lck_->Lock(); - own_lock_ = true; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/memory_pool.cc b/mindspore/ccsrc/dataset/util/memory_pool.cc deleted file mode 100644 index 5d66b4bd6d..0000000000 --- a/mindspore/ccsrc/dataset/util/memory_pool.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/memory_pool.h" -#include "./securec.h" - -namespace mindspore { -namespace dataset { -Status DeMalloc(std::size_t s, void **p, bool init_to_zero = false) { - if (p == nullptr) { - RETURN_STATUS_UNEXPECTED("p is null"); - } - void *q = ::malloc(s); - if (q == nullptr) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } else { - *p = q; - if (init_to_zero) { - (void)memset_s(q, s, 0, s); - } - return Status::OK(); - } -} -} // namespace dataset -} // namespace mindspore - -void *operator new(std::size_t s, mindspore::dataset::Status *rc, std::shared_ptr b) { - void *ptr = nullptr; - *rc = b->Allocate(s, &ptr); - return ptr; -} - -void *operator new[](std::size_t s, mindspore::dataset::Status *rc, std::shared_ptr b) { - void *ptr = nullptr; - *rc = b->Allocate(s, &ptr); - return ptr; -} - -void operator delete(void *p, std::shared_ptr b) { - if (p != nullptr) b->Deallocate(p); -} - -void operator delete[](void *p, std::shared_ptr b) { - if (p != nullptr) b->Deallocate(p); -} diff --git a/mindspore/ccsrc/dataset/util/memory_pool.h b/mindspore/ccsrc/dataset/util/memory_pool.h deleted file mode 100644 index ee1da3bda1..0000000000 --- a/mindspore/ccsrc/dataset/util/memory_pool.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_MEMORY_POOL_H_ -#define DATASET_UTIL_MEMORY_POOL_H_ - -#include -#include -#include -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Abstract class of a memory pool -class MemoryPool { - public: - // Allocate a block of size n - virtual Status Allocate(size_t, void **) = 0; - - // Enlarge or shrink a block from oldSz to newSz - virtual Status Reallocate(void **, size_t old_sz, size_t new_sz) = 0; - - // Free a pointer - virtual void Deallocate(void *) = 0; - - // What is the maximum size I can allocate ? - virtual uint64_t get_max_size() const = 0; - - virtual int PercentFree() const = 0; - - // Destructor - virtual ~MemoryPool() {} -}; - -Status DeMalloc(std::size_t s, void **p, bool); -} // namespace dataset -} // namespace mindspore - -void *operator new(std::size_t, mindspore::dataset::Status *, std::shared_ptr); - -void *operator new[](std::size_t, mindspore::dataset::Status *, std::shared_ptr); - -void operator delete(void *, std::shared_ptr); - -void operator delete[](void *, std::shared_ptr); - -#endif // DATASET_UTIL_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/dataset/util/path.cc b/mindspore/ccsrc/dataset/util/path.cc deleted file mode 100644 index cdd2343799..0000000000 --- a/mindspore/ccsrc/dataset/util/path.cc +++ /dev/null @@ -1,340 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/path.h" - -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -#if defined(_WIN32) || defined(_WIN64) -char Path::separator_ = '\\'; -#else -char Path::separator_ = '/'; -#endif - -Path::Path(const std::string &s) : path_(s) {} - -Path::Path(const char *p) : path_(p) {} - -Path::Path(const Path &p) : path_(p.path_) {} - -Path &Path::operator=(const Path &p) { - if (&p != this) { - this->path_ = p.path_; - } - return *this; -} - -Path &Path::operator=(Path &&p) noexcept { - if (&p != this) { - this->path_ = std::move(p.path_); - } - return *this; -} - -Path::Path(Path &&p) noexcept { this->path_ = std::move(p.path_); } - -Path Path::operator+(const Path &p) { - std::string q = path_ + p.toString(); - return Path(q); -} - -Path Path::operator+(const std::string &p) { - std::string q = path_ + p; - return Path(q); -} - -Path Path::operator+(const char *p) { - std::string q = path_ + p; - return Path(q); -} - -Path &Path::operator+=(const Path &rhs) { - path_ += rhs.toString(); - return *this; -} - -Path &Path::operator+=(const std::string &p) { - path_ += p; - return *this; -} - -Path &Path::operator+=(const char *p) { - path_ += p; - return *this; -} - -Path Path::operator/(const Path &p) { - std::string q = path_ + separator_ + p.toString(); - return Path(q); -} - -Path Path::operator/(const std::string &p) { - std::string q = path_ + separator_ + p; - return Path(q); -} - -Path Path::operator/(const char *p) { - std::string q = path_ + separator_ + p; - return Path(q); -} - -std::string Path::Extension() const { - std::size_t found = path_.find_last_of('.'); - if (found != std::string::npos) { - return path_.substr(found); - } else { - return std::string(""); - } -} - -bool Path::Exists() { - struct stat sb; - int rc = stat(common::SafeCStr(path_), &sb); - if (rc == -1 && errno != ENOENT) { - MS_LOG(WARNING) << "Unable to query the status of " << path_ << ". Errno = " << errno << "."; - } - return (rc == 0); -} - -bool Path::IsDirectory() { - struct stat sb; - int rc = stat(common::SafeCStr(path_), &sb); - if (rc == 0) { - return S_ISDIR(sb.st_mode); - } else { - return false; - } -} - -Status Path::CreateDirectory() { - if (!Exists()) { -#if defined(_WIN32) || defined(_WIN64) - int rc = mkdir(common::SafeCStr(path_)); -#else - int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR); -#endif - if (rc) { - std::ostringstream oss; - oss << "Unable to create directory " << path_ << ". Errno = " << errno; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - return Status::OK(); - } else { - if (IsDirectory()) { - return Status::OK(); - } else { - std::ostringstream oss; - oss << "Unable to create directory " << path_ << ". It exists but is not a directory"; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - } -} - -std::string Path::ParentPath() { - std::string r(""); - std::size_t found = path_.find_last_of(separator_); - if (found != std::string::npos) { - if (found == 0) { - r += separator_; - } else { - r = std::string(path_.substr(0, found)); - } - } - return r; -} - -Status Path::CreateDirectories() { - if (IsDirectory()) { - MS_LOG(DEBUG) << "Directory " << toString() << " already exists."; - return Status::OK(); - } else { - MS_LOG(DEBUG) << "Creating directory " << toString() << "."; - std::string parent = ParentPath(); - if (!parent.empty()) { - if (Path(parent).CreateDirectories()) { - return CreateDirectory(); - } - } else { - return CreateDirectory(); - } - } - return Status::OK(); -} - -Status Path::Remove() { - if (Exists()) { - if (IsDirectory()) { - errno_t err = rmdir(common::SafeCStr(path_)); - if (err == -1) { - std::ostringstream oss; - oss << "Unable to delete directory " << path_ << ". Errno = " << errno; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - } else { - errno_t err = unlink(common::SafeCStr(path_)); - if (err == -1) { - std::ostringstream oss; - oss << "Unable to delete file " << path_ << ". Errno = " << errno; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - } - } - return Status::OK(); -} - -Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); } - -Status Path::OpenFile(int *file_descriptor, bool create) { - int fd; - if (file_descriptor == nullptr) { - RETURN_STATUS_UNEXPECTED("null pointer"); - } - if (IsDirectory()) { - std::ostringstream oss; - oss << "Unable to create file " << path_ << " which is a directory."; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - // Convert to canonical form. - if (strlen(common::SafeCStr(path_)) > PATH_MAX) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - char canonical_path[PATH_MAX + 1] = {0x00}; -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) { -#else - if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) { -#endif - if (errno == ENOENT && create) { - // File doesn't exist and we are to create it. Let's break it down. - auto file_part = Basename(); - auto parent_part = ParentPath(); -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) { -#else - if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) { -#endif - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - auto cur_inx = strlen(canonical_path); - if ((cur_inx + file_part.length() + 1) > PATH_MAX) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - canonical_path[cur_inx++] = separator_; - if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) != - EOK) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - } else { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - } - if (create) { - fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP); - } else { - fd = open(canonical_path, O_RDWR); - } - if (fd == -1) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - *file_descriptor = fd; - return Status::OK(); -} - -Status Path::CloseFile(int fd) const { - if (close(fd) < 0) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - return Status::OK(); -} - -Status Path::TruncateFile(int fd) const { - int rc; - rc = ftruncate(fd, 0); - if (rc == 0) { - return Status::OK(); - } else { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } -} - -std::string Path::Basename() { - std::size_t found = path_.find_last_of(separator_); - if (found != std::string::npos) { - return path_.substr(found + 1); - } else { - return path_; - } -} - -std::shared_ptr Path::DirIterator::OpenDirectory(Path *f) { - auto it = new (std::nothrow) DirIterator(f); - - if (it == nullptr) { - return nullptr; - } - - if (it->dp_) { - return std::shared_ptr(it); - } else { - delete it; - return nullptr; - } -} - -Path::DirIterator::~DirIterator() { - if (dp_) { - (void)closedir(dp_); - } - dp_ = nullptr; - dir_ = nullptr; - entry_ = nullptr; -} - -Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) { - MS_LOG(DEBUG) << "Open directory " << f->toString() << "."; - dp_ = opendir(f->toString().c_str()); -} - -bool Path::DirIterator::hasNext() { - do { - entry_ = readdir(dp_); - if (entry_) { - if (strcmp(entry_->d_name, ".") == 0 || strcmp(entry_->d_name, "..") == 0) { - continue; - } - } - break; - } while (true); - return (entry_ != nullptr); -} - -Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); } - -std::ostream &operator<<(std::ostream &os, const Path &s) { - os << s.path_; - return os; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/path.h b/mindspore/ccsrc/dataset/util/path.h deleted file mode 100644 index fbf65b8c23..0000000000 --- a/mindspore/ccsrc/dataset/util/path.h +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_PATH_H_ -#define DATASET_UTIL_PATH_H_ - -#include -#include -#include - -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class Path { - public: - class DirIterator { - public: - static std::shared_ptr OpenDirectory(Path *f); - - ~DirIterator(); - - bool hasNext(); - - Path next(); - - private: - explicit DirIterator(Path *f); - - Path *dir_; - DIR *dp_; - struct dirent *entry_; - }; - - explicit Path(const std::string &); - - explicit Path(const char *); - - ~Path() = default; - - Path(const Path &); - - Path &operator=(const Path &); - - Path(Path &&) noexcept; - - Path &operator=(Path &&) noexcept; - - std::string toString() const { return path_; } - - Path operator+(const Path &); - - Path operator+(const std::string &); - - Path operator+(const char *); - - Path &operator+=(const Path &rhs); - - Path &operator+=(const std::string &); - - Path &operator+=(const char *); - - Path operator/(const Path &); - - Path operator/(const std::string &); - - Path operator/(const char *); - - bool Exists(); - - bool IsDirectory(); - - Status CreateDirectory(); - - Status CreateDirectories(); - - std::string Extension() const; - - std::string ParentPath(); - - Status Remove(); - - Status CreateFile(int *fd); - - Status OpenFile(int *fd, bool create = false); - - Status CloseFile(int fd) const; - - Status TruncateFile(int fd) const; - - std::string Basename(); - - friend std::ostream &operator<<(std::ostream &os, const Path &s); - - private: - static char separator_; - std::string path_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_PATH_H_ diff --git a/mindspore/ccsrc/dataset/util/queue.h b/mindspore/ccsrc/dataset/util/queue.h deleted file mode 100644 index 7fca93d944..0000000000 --- a/mindspore/ccsrc/dataset/util/queue.h +++ /dev/null @@ -1,253 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_QUEUE_H_ -#define DATASET_UTIL_QUEUE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "utils/log_adapter.h" -#include "dataset/util/allocator.h" -#include "dataset/util/services.h" -#include "dataset/util/cond_var.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -template -struct is_shared_ptr : public std::false_type {}; - -template -struct is_shared_ptr> : public std::true_type {}; - -template -struct is_unique_ptr : public std::false_type {}; - -template -struct is_unique_ptr> : public std::true_type {}; - -// A simple thread safe queue using a fixed size array -template -class Queue { - public: - using value_type = T; - using pointer = T *; - using const_pointer = const T *; - using reference = T &; - using const_reference = const T &; - - void Init() { - if (sz_ > 0) { - // We allocate a block of memory and then call the default constructor for each slot. Maybe simpler to call - // new[] but we want to control where the memory is allocated from. - arr_ = alloc_.allocate(sz_); - for (uint64_t i = 0; i < sz_; i++) { - std::allocator_traits>::construct(alloc_, &(arr_[i])); - } - } - } - - explicit Queue(int sz) - : sz_(sz), - arr_(nullptr), - head_(0), - tail_(0), - my_name_(Services::GetUniqueID()), - alloc_(Services::GetInstance().GetServiceMemPool()) { - Init(); - MS_LOG(DEBUG) << "Create Q with uuid " << my_name_ << " of size " << sz_ << "."; - } - - virtual ~Queue() { - ResetQue(); - if (arr_) { - // Simply free the pointer. Since there is nothing in the queue. We don't want to invoke the destructor - // of T in each slot. - alloc_.deallocate(arr_); - arr_ = nullptr; - } - } - - int size() const { - int v = tail_ - head_; - return (v >= 0) ? v : 0; - } - - int capacity() const { return sz_; } - - bool empty() const { return head_ == tail_; } - - void Reset() { ResetQue(); } - - // Producer - Status Add(const_reference ele) noexcept { - std::unique_lock _lock(mux_); - // Block when full - Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); - if (rc.IsOk()) { - uint32_t k = tail_++ % sz_; - arr_[k] = ele; - empty_cv_.NotifyAll(); - _lock.unlock(); - } else { - empty_cv_.Interrupt(); - } - return rc; - } - - Status Add(T &&ele) noexcept { - std::unique_lock _lock(mux_); - // Block when full - Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); - if (rc.IsOk()) { - uint32_t k = tail_++ % sz_; - arr_[k] = std::forward(ele); - empty_cv_.NotifyAll(); - _lock.unlock(); - } else { - empty_cv_.Interrupt(); - } - return rc; - } - - template - Status EmplaceBack(Ts &&... args) noexcept { - std::unique_lock _lock(mux_); - // Block when full - Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); - if (rc.IsOk()) { - uint32_t k = tail_++ % sz_; - new (&(arr_[k])) T(std::forward(args)...); - empty_cv_.NotifyAll(); - _lock.unlock(); - } else { - empty_cv_.Interrupt(); - } - return rc; - } - - // Consumer - Status PopFront(pointer p) { - std::unique_lock _lock(mux_); - // Block when empty - Status rc = empty_cv_.Wait(&_lock, [this]() -> bool { return !empty(); }); - if (rc.IsOk()) { - uint32_t k = head_++ % sz_; - *p = std::move(arr_[k]); - if (std::is_destructible::value) { - // std::move above only changes arr_[k] from rvalue to lvalue. - // The real implementation of move constructor depends on T. - // It may be compiler generated or user defined. But either case - // the result of arr_[k] is still a valid object of type T, and - // we will not keep any extra copy in the queue. - arr_[k].~T(); - // For gcc 9, an extra fix is needed here to clear the memory content - // of arr_[k] because this slot can be reused by another Add which can - // do another std::move. We have seen SEGV here in this case. - std::allocator_traits>::construct(alloc_, &(arr_[k])); - } - full_cv_.NotifyAll(); - _lock.unlock(); - } else { - full_cv_.Interrupt(); - } - return rc; - } - - void ResetQue() noexcept { - std::unique_lock _lock(mux_); - // If there are elements in the queue, invoke its destructor one by one. - if (!empty() && std::is_destructible::value) { - for (uint64_t i = head_; i < tail_; i++) { - uint32_t k = i % sz_; - arr_[k].~T(); - } - } - empty_cv_.ResetIntrpState(); - full_cv_.ResetIntrpState(); - head_ = 0; - tail_ = 0; - } - - Status Register(TaskGroup *vg) { - Status rc1 = empty_cv_.Register(vg->GetIntrpService()); - Status rc2 = full_cv_.Register(vg->GetIntrpService()); - if (rc1.IsOk()) { - return rc2; - } else { - return rc1; - } - } - - private: - uint64_t sz_; - pointer arr_; - uint64_t head_; - uint64_t tail_; - std::string my_name_; - std::mutex mux_; - CondVar empty_cv_; - CondVar full_cv_; - Allocator alloc_; -}; - -// A container of queues with [] operator accessors. Basically this is a wrapper over of a vector of queues -// to help abstract/simplify code that is maintaining multiple queues. -template -class QueueList { - public: - QueueList() {} - - void Init(int num_queues, int capacity) { - queue_list_.reserve(num_queues); - for (int i = 0; i < num_queues; i++) { - queue_list_.emplace_back(std::make_unique>(capacity)); - } - } - - Status Register(TaskGroup *vg) { - if (vg == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Null task group during QueueList registration."); - } - for (int i = 0; i < queue_list_.size(); ++i) { - RETURN_IF_NOT_OK(queue_list_[i]->Register(vg)); - } - return Status::OK(); - } - - int size() const { return queue_list_.size(); } - - std::unique_ptr> &operator[](const int index) { return queue_list_[index]; } - - const std::unique_ptr> &operator[](const int index) const { return queue_list_[index]; } - - ~QueueList() = default; - - private: - // Queue contains non-copyable objects, so it cannot be added to a vector due to the vector - // requirement that objects must have copy semantics. To resolve this, we use a vector of unique - // pointers. This allows us to provide dynamic creation of queues in a container. - std::vector>> queue_list_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_QUEUE_H_ diff --git a/mindspore/ccsrc/dataset/util/random.h b/mindspore/ccsrc/dataset/util/random.h deleted file mode 100644 index 957a4214a8..0000000000 --- a/mindspore/ccsrc/dataset/util/random.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_RANDOM_H_ -#define DATASET_UTIL_RANDOM_H_ - -#if defined(_WIN32) || defined(_WIN64) -#include -#endif -#include -#include -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -inline std::mt19937 GetRandomDevice() { -#if defined(_WIN32) || defined(_WIN64) - unsigned int number; - rand_s(&number); - std::mt19937 random_device{static_cast(number)}; -#else - int i = 0; - while (i < 5) { - try { - std::mt19937 random_device{std::random_device("/dev/urandom")()}; - return random_device; - } catch (const std::exception &e) { - MS_LOG(WARNING) << "Get std::random_device failed, retry: " << i << ", error: " << e.what(); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - i++; - } - } - std::mt19937 random_device{std::random_device("/dev/urandom")()}; -#endif - return random_device; -} - -inline uint32_t GetNewSeed() { - std::mt19937 random_device = GetRandomDevice(); - std::uniform_int_distribution distribution(0, std::numeric_limits::max()); - return distribution(random_device); -} - -inline uint32_t GetSeed() { - uint32_t seed = GlobalContext::config_manager()->seed(); - if (seed == std::mt19937::default_seed) { - seed = GetNewSeed(); - } - return seed; -} - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_RANDOM_H_ diff --git a/mindspore/ccsrc/dataset/util/semaphore.cc b/mindspore/ccsrc/dataset/util/semaphore.cc deleted file mode 100644 index 36ddf5511d..0000000000 --- a/mindspore/ccsrc/dataset/util/semaphore.cc +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/semaphore.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -Status Semaphore::P() { - std::unique_lock lck(mutex_); - RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; })); - --value_; - return Status::OK(); -} -void Semaphore::V() { - std::unique_lock lck(mutex_); - ++value_; - wait_cond_.NotifyOne(); -} -int Semaphore::Peek() { - std::unique_lock lck(mutex_); - return value_; -} -Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } -Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } -void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/semaphore.h b/mindspore/ccsrc/dataset/util/semaphore.h deleted file mode 100644 index 07b9e83e7f..0000000000 --- a/mindspore/ccsrc/dataset/util/semaphore.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SEMAPHORE_H_ -#define DATASET_UTIL_SEMAPHORE_H_ - -#include "dataset/util/cond_var.h" - -namespace mindspore { -namespace dataset { -class TaskGroup; - -/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be -/// blocked if the count is 0 (zero). V increments the internal count and wake up one of the waiters. -class Semaphore { - public: - /// \brief Constructor - /// \param init Initial value of the internal counter. - explicit Semaphore(int init) : value_(init) {} - - virtual ~Semaphore() {} - /// \brief Decrement the internal counter. Will be blocked if the value is 0. - /// \return Error code. Can get interrupt. - Status P(); - /// \brief Increment the internal counter. Wakeup on of the watiers if any. - void V(); - /// \brief Peek the internal value - /// \return The internal value - int Peek(); - Status Register(TaskGroup *vg); - Status Deregister(); - void ResetIntrpState(); - - private: - int value_; - - std::mutex mutex_; - CondVar wait_cond_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_SEMAPHORE_H_ diff --git a/mindspore/ccsrc/dataset/util/service.cc b/mindspore/ccsrc/dataset/util/service.cc deleted file mode 100644 index c89f7287f6..0000000000 --- a/mindspore/ccsrc/dataset/util/service.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/service.h" -#include - -namespace mindspore { -namespace dataset { -Status Service::ServiceStart() { - do { - UniqueLock lck(&state_lock_); - // No-op if it is already up or some other thread is - // in the process of bring it up. - if (state_ == STATE::kRunning || state_ == STATE::kStartInProg) { - return Status::OK(); - } - // If a stop is in progress, we line up after it - // is done. - if (state_ == STATE::kStopInProg) { - std::this_thread::yield(); - } else { - state_ = STATE::kStartInProg; - // At this point, we will let go of the lock. This allow others to proceed. - lck.Unlock(); - RETURN_IF_NOT_OK(DoServiceStart()); - // Lock again to change state. - lck.Lock(); - state_ = STATE::kRunning; - return Status::OK(); - } - } while (true); -} - -Status Service::ServiceStop() noexcept { - do { - UniqueLock lck(&state_lock_); - // No-op if it is already stopped or some other thread is - // in the process of shutting it down - if (state_ == STATE::kStopped || state_ == STATE::kStopInProg) { - return Status::OK(); - } - // If a start is in progress, we line up after it - // is done. - if (state_ == STATE::kStartInProg) { - std::this_thread::yield(); - } else { - state_ = STATE::kStopInProg; - // At this point, we will let go of the lock. This allows others to proceed. - lck.Unlock(); - RETURN_IF_NOT_OK(DoServiceStop()); - // Lock again to change state. - lck.Lock(); - state_ = STATE::kStopped; - return Status::OK(); - } - } while (true); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/service.h b/mindspore/ccsrc/dataset/util/service.h deleted file mode 100644 index 1113fc1d14..0000000000 --- a/mindspore/ccsrc/dataset/util/service.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SERVICE_H_ -#define DATASET_UTIL_SERVICE_H_ - -#include -#include "dataset/util/lock.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class Service { - public: - enum class STATE : int { kStartInProg = 1, kRunning, kStopInProg, kStopped }; - - Service() : state_(STATE::kStopped) {} - - Service(const Service &) = delete; - - Service &operator=(const Service &) = delete; - - virtual ~Service() {} - - STATE ServiceState() const { return state_; } - - virtual Status DoServiceStart() = 0; - - virtual Status DoServiceStop() = 0; - - Status ServiceStart(); - - Status ServiceStop() noexcept; - - protected: - STATE state_; - RWLock state_lock_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/util/services.cc b/mindspore/ccsrc/dataset/util/services.cc deleted file mode 100644 index 6516deea41..0000000000 --- a/mindspore/ccsrc/dataset/util/services.cc +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/services.h" - -#include -#if !defined(_WIN32) && !defined(_WIN64) -#include -#else -#include -#endif -#include -#include "dataset/util/circular_pool.h" -#include "dataset/util/random.h" -#include "dataset/util/task_manager.h" - -#define SLOT_TASK_MGR 0 -namespace mindspore { -namespace dataset { -std::unique_ptr Services::instance_ = nullptr; -std::once_flag Services::init_instance_flag_; - -#if !defined(_WIN32) && !defined(_WIN64) -std::string Services::GetUserName() { - char user[LOGIN_NAME_MAX]; - (void)getlogin_r(user, sizeof(user)); - return std::string(user); -} - -std::string Services::GetHostName() { - char host[LOGIN_NAME_MAX]; - (void)gethostname(host, sizeof(host)); - return std::string(host); -} - -int Services::GetLWP() { return syscall(SYS_gettid); } -#endif - -std::string Services::GetUniqueID() { - const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789"; - std::mt19937 gen = GetRandomDevice(); - std::uniform_int_distribution dist(0, kStr.size() - 1); - char buffer[UNIQUEID_LEN]; - for (int i = 0; i < UNIQUEID_LEN; i++) { - buffer[i] = kStr[dist(gen)]; - } - return std::string(buffer, UNIQUEID_LEN); -} - -TaskManager &Services::getTaskMgrInstance() { - Services &sm = GetInstance(); - return *(static_cast(sm.sa_[SLOT_TASK_MGR])); -} - -Status Services::CreateAllInstances() { - // In order, TaskMgr, BufferMgr - Status rc; - sa_[SLOT_TASK_MGR] = new (&rc, pool_) TaskManager(); - RETURN_IF_NOT_OK(rc); - rc = sa_[SLOT_TASK_MGR]->ServiceStart(); - return rc; -} - -Services::Services() : pool_(nullptr), sa_{nullptr} { - Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M - if (rc.IsError()) { - std::terminate(); - } -} - -Services::~Services() noexcept { - try { - // In reverse order - TaskManager *tm = static_cast(sa_[SLOT_TASK_MGR]); - if (tm) { - (void)tm->ServiceStop(); - tm->~TaskManager(); - pool_->Deallocate(tm); - } - } catch (const std::exception &e) { - // Do nothing. - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/services.h b/mindspore/ccsrc/dataset/util/services.h deleted file mode 100644 index e19f44dccc..0000000000 --- a/mindspore/ccsrc/dataset/util/services.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SERVICES_H_ -#define DATASET_UTIL_SERVICES_H_ - -#include -#include -#include -#include "dataset/util/memory_pool.h" -#include "dataset/util/allocator.h" -#include "dataset/util/service.h" - -#define UNIQUEID_LEN 36 -namespace mindspore { -namespace dataset { -class TaskManager; - -class Services { - public: - static Status CreateInstance() { - std::call_once(init_instance_flag_, [&]() -> Status { - instance_.reset(new Services()); - return (instance_->CreateAllInstances()); - }); - - if (instance_ == nullptr) { - instance_.reset(new Services()); - return (instance_->CreateAllInstances()); - } - - return Status::OK(); - } - - static Services &GetInstance() { - if (instance_ == nullptr) { - if (!CreateInstance()) { - std::terminate(); - } - } - return *instance_; - } - - Services(const Services &) = delete; - - Services &operator=(const Services &) = delete; - - ~Services() noexcept; - - static TaskManager &getTaskMgrInstance(); - - std::shared_ptr GetServiceMemPool() { return pool_; } - -#if !defined(_WIN32) && !defined(_WIN64) - static std::string GetUserName(); - - static std::string GetHostName(); - - static int GetLWP(); -#endif - - static std::string GetUniqueID(); - - template - static Allocator GetAllocator() { - return Allocator(Services::GetInstance().GetServiceMemPool()); - } - - private: - static std::once_flag init_instance_flag_; - static std::unique_ptr instance_; - // A small pool used for small objects that last until the - // Services Manager shuts down. Used by all sub-services. - std::shared_ptr pool_; - // We use pointers here instead of unique_ptr because we - // want to have ultimate control on the order of - // construction and destruction. - static constexpr int kNumServices_ = 1; - Service *sa_[kNumServices_]; - - Services(); - - Status CreateAllInstances(); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_SERVICES_H_ diff --git a/mindspore/ccsrc/dataset/util/sig_handler.cc b/mindspore/ccsrc/dataset/util/sig_handler.cc deleted file mode 100644 index 644a633066..0000000000 --- a/mindspore/ccsrc/dataset/util/sig_handler.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/sig_handler.h" -#include -#include -#if !defined(_WIN32) && !defined(_WIN64) -#include -#endif -#include -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// Register the custom signal handlers -#if !defined(_WIN32) && !defined(_WIN64) -void RegisterHandlers() { - struct sigaction new_int_action; - - // For the interrupt handler, we do not use SA_RESETHAND so this handler remains in play - // permanently, do not use the OS default handler for it. - new_int_action.sa_sigaction = &IntHandler; - (void)sigemptyset(&new_int_action.sa_mask); - new_int_action.sa_flags = SA_RESTART | SA_SIGINFO; - (void)sigaction(SIGINT, &new_int_action, nullptr); -} - -extern void IntHandler(int sig_num, // The signal that was raised - siginfo_t *sig_info, // The siginfo structure. - void *context) { // context info - // Wake up the watchdog which is designed as async-signal-safe. - TaskManager::WakeUpWatchDog(); -} -#endif -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/slice.cc b/mindspore/ccsrc/dataset/util/slice.cc deleted file mode 100644 index f1798b4f44..0000000000 --- a/mindspore/ccsrc/dataset/util/slice.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "dataset/util/slice.h" - -namespace mindspore { -namespace dataset { -WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) { - mutable_data_ = static_cast(src.mutable_data_) + offset; -} -WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset) - : WritableSlice(src, offset, src.GetSize() - offset) {} -Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) { - RETURN_UNEXPECTED_IF_NULL(dest); - RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer()); - if (dest->GetSize() <= 0) { - RETURN_STATUS_UNEXPECTED("Destination length is non-positive"); - } - auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize()); - if (err) { - RETURN_STATUS_UNEXPECTED(std::to_string(err)); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/slice.h b/mindspore/ccsrc/dataset/util/slice.h deleted file mode 100644 index 127df23cfa..0000000000 --- a/mindspore/ccsrc/dataset/util/slice.h +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SLICE_H_ -#define DATASET_UTIL_SLICE_H_ - -#include -#include -#include -#include "./securec.h" -#include "dataset/util/allocator.h" -#include "dataset/util/status.h" -namespace mindspore { -namespace dataset { -/// \brief A ReadableSlice wraps a const pointer in memory and its size. -/// \see WritableSlice for a non-const version -/// -class ReadableSlice { - public: - ReadableSlice() : ptr_(nullptr), sz_(0) {} - ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {} - ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) { - ptr_ = static_cast(src.GetPointer()) + offset; - sz_ = len; - } - ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {} - ReadableSlice(const ReadableSlice &lhs) { - ptr_ = lhs.ptr_; - sz_ = lhs.sz_; - } - ReadableSlice &operator=(const ReadableSlice &lhs) { - if (this != &lhs) { - ptr_ = lhs.ptr_; - sz_ = lhs.sz_; - } - return *this; - } - ReadableSlice(ReadableSlice &&lhs) noexcept { - if (this != &lhs) { - ptr_ = lhs.ptr_; - sz_ = lhs.sz_; - lhs.ptr_ = nullptr; - lhs.sz_ = 0; - } - } - ReadableSlice &operator=(ReadableSlice &&lhs) noexcept { - if (this != &lhs) { - ptr_ = lhs.ptr_; - sz_ = lhs.sz_; - lhs.ptr_ = nullptr; - lhs.sz_ = 0; - } - return *this; - } - /// \brief Getter function - /// \return Const version of the pointer - const void *GetPointer() const { return ptr_; } - /// \brief Getter function - /// \return Size of the slice - size_t GetSize() const { return sz_; } - bool empty() const { return ptr_ == nullptr; } - - private: - const void *ptr_; - size_t sz_; -}; -/// \brief A WritableSlice inherits from ReadableSlice to allow -/// one to write to the address pointed to by the pointer. -/// -class WritableSlice : public ReadableSlice { - public: - friend class StorageContainer; - /// \brief Default constructor - WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} - /// \brief This form of a constructor takes a pointer and its size. - WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {} - WritableSlice(const WritableSlice &src, off64_t offset, size_t len); - WritableSlice(const WritableSlice &src, off64_t offset); - WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; } - WritableSlice &operator=(const WritableSlice &lhs) { - if (this != &lhs) { - mutable_data_ = lhs.mutable_data_; - ReadableSlice::operator=(lhs); - } - return *this; - } - WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) { - if (this != &lhs) { - mutable_data_ = lhs.mutable_data_; - lhs.mutable_data_ = nullptr; - } - } - WritableSlice &operator=(WritableSlice &&lhs) noexcept { - if (this != &lhs) { - mutable_data_ = lhs.mutable_data_; - lhs.mutable_data_ = nullptr; - ReadableSlice::operator=(std::move(lhs)); - } - return *this; - } - /// \brief Copy the content from one slice onto another. - static Status Copy(WritableSlice *dest, const ReadableSlice &src); - - private: - void *mutable_data_; - void *GetMutablePointer() { return mutable_data_; } -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_SLICE_H_ diff --git a/mindspore/ccsrc/dataset/util/status.cc b/mindspore/ccsrc/dataset/util/status.cc deleted file mode 100644 index 27e9dfbc83..0000000000 --- a/mindspore/ccsrc/dataset/util/status.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/status.h" -#include -#include "common/utils.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -std::string CodeAsString(const StatusCode c) { - const char *s = nullptr; - if (c == StatusCode::kOK) { - // Optimize the most frequent case - return std::string("OK"); - } else { - switch (c) { - case StatusCode::kOutOfMemory: - s = "Out of memory"; - break; - case StatusCode::kInterrupted: - s = "Interrupted system call"; - break; - case StatusCode::kShapeMisMatch: - s = "Shape is incorrect."; - break; - case StatusCode::kNoSpace: - s = "No space left on device"; - break; - case StatusCode::kPyFuncException: - s = "Exception thrown from PyFunc"; - break; - case StatusCode::kDuplicateKey: - s = "Duplicate key"; - break; - case StatusCode::kProfilingError: - s = "Error encountered while profiling"; - break; - case StatusCode::kUnexpectedError: - default: - s = "Unexpected error"; - break; - } - } - return std::string(s); -} - -Status::Status(StatusCode c) noexcept : code_(c), err_msg_(std::move(CodeAsString(c))) {} - -Status::Status() noexcept : code_(StatusCode::kOK), err_msg_("") {} - -Status::~Status() noexcept {} - -Status::Status(const Status &s) : code_(s.code_), err_msg_(s.err_msg_) {} - -Status &Status::operator=(const Status &s) { - if (this == &s) { - return *this; - } - code_ = s.code_; - err_msg_ = s.err_msg_; - return *this; -} - -Status::Status(Status &&s) noexcept { - code_ = s.code_; - s.code_ = StatusCode::kOK; - err_msg_ = std::move(s.err_msg_); -} - -Status &Status::operator=(Status &&s) noexcept { - if (this == &s) { - return *this; - } - code_ = s.code_; - s.code_ = StatusCode::kOK; - err_msg_ = std::move(s.err_msg_); - return *this; -} - -Status::Status(const StatusCode code, const std::string &msg) : code_(code), err_msg_(msg) {} - -Status::Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra) { - code_ = code; - std::ostringstream ss; - ss << "Thread ID " << this_thread::get_id() << " " << CodeAsString(code) << ". "; - if (!extra.empty()) { - ss << extra; - } - ss << "\n"; - ss << "Line of code : " << line_of_code << "\n"; - if (file_name != nullptr) { - ss << "File : " << file_name << "\n"; - } - err_msg_ = ss.str(); - MS_LOG(INFO) << err_msg_; -} - -std::ostream &operator<<(std::ostream &os, const Status &s) { - os << s.ToString(); - return os; -} - -std::string Status::ToString() const { return err_msg_; } - -StatusCode Status::get_code() const { return code_; } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_container.cc b/mindspore/ccsrc/dataset/util/storage_container.cc deleted file mode 100644 index 3a4c13e2d9..0000000000 --- a/mindspore/ccsrc/dataset/util/storage_container.cc +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/storage_container.h" - -#include -#include -#include -#include -#include "common/utils.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -Status StorageContainer::Create() { - RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_)); - RETURN_IF_NOT_OK(cont_.CreateFile(&fd_)); - is_open_ = true; - MS_LOG(INFO) << "Container " << cont_ << " created"; - return Status::OK(); -} - -Status StorageContainer::Open() noexcept { - std::lock_guard lck(mutex_); - // Check again - if (!is_open_) { - RETURN_IF_NOT_OK(cont_.OpenFile(&fd_)); - is_open_ = true; - } - return Status::OK(); -} - -Status StorageContainer::Close() noexcept { - if (is_open_) { - std::lock_guard lck(mutex_); - // Check again - if (is_open_) { - RETURN_IF_NOT_OK(cont_.CloseFile(fd_)); - is_open_ = false; - fd_ = -1; - } - } - return Status::OK(); -} - -Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept { - MS_ASSERT(is_open_); - RETURN_UNEXPECTED_IF_NULL(dest); - auto sz = dest->GetSize(); -#if defined(_WIN32) || defined(_WIN64) - // Doesn't seem there is any pread64 on mingw. - // So we will do a seek and then a read under - // a protection of mutex. - std::lock_guard lck(mutex_); - auto seek_err = lseek(fd_, offset, SEEK_SET); - if (seek_err < 0) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - auto r_sz = read(fd_, dest->GetMutablePointer(), sz); -#else - auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset); -#endif - if (r_sz != sz) { - errno_t err = (r_sz == 0) ? EOF : errno; - RETURN_STATUS_UNEXPECTED(strerror(err)); - } - return Status::OK(); -} - -Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept { - MS_ASSERT(is_open_); - auto sz = dest.GetSize(); -#if defined(_WIN32) || defined(_WIN64) - // Doesn't seem there is any pwrite64 on mingw. - // So we will do a seek and then a read under - // a protection of mutex. - std::lock_guard lck(mutex_); - auto seek_err = lseek(fd_, offset, SEEK_SET); - if (seek_err < 0) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - auto r_sz = write(fd_, dest.GetPointer(), sz); -#else - auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset); -#endif - if (r_sz != sz) { - errno_t err = (r_sz == 0) ? EOF : errno; - RETURN_STATUS_UNEXPECTED(strerror(err)); - } - return Status::OK(); -} - -Status StorageContainer::Insert(const std::vector &buf, off64_t *offset) noexcept { - size_t sz = 0; - for (auto &v : buf) { - sz += v.GetSize(); - } - if (sz == 0) { - RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); - } - if (sz > bs_->GetMaxSize()) { - RETURN_STATUS_UNEXPECTED("Request size too big"); - } - BSpaceDescriptor bspd{0}; - addr_t addr = 0; - RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr)); - *offset = static_cast(addr); - // We will do piecewise copy of the data to disk. - for (auto &v : buf) { - RETURN_IF_NOT_OK(Write(v, addr)); - addr += v.GetSize(); - } - return Status::OK(); -} - -Status StorageContainer::Truncate() const noexcept { - if (is_open_) { - RETURN_IF_NOT_OK(cont_.TruncateFile(fd_)); - MS_LOG(INFO) << "Container " << cont_ << " truncated"; - } - return Status::OK(); -} - -StorageContainer::~StorageContainer() noexcept { - (void)Truncate(); - (void)Close(); -} - -std::ostream &operator<<(std::ostream &os, const StorageContainer &s) { - os << "File path : " << s.cont_ << "\n" << *(s.bs_.get()); - return os; -} - -Status StorageContainer::CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path) { - Status rc; - auto sc = new (std::nothrow) StorageContainer(path); - if (sc == nullptr) { - return Status(StatusCode::kOutOfMemory); - } - rc = sc->Create(); - if (rc.IsOk()) { - (*out_sc).reset(sc); - } else { - delete sc; - } - return rc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_container.h b/mindspore/ccsrc/dataset/util/storage_container.h deleted file mode 100644 index 07e41bd66a..0000000000 --- a/mindspore/ccsrc/dataset/util/storage_container.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_ -#define DATASET_UTIL_STORAGE_CONTAINER_H_ - -#include -#include -#include -#include -#include -#include -#include "dataset/util/system_pool.h" -#include "dataset/util/buddy.h" -#include "dataset/util/path.h" -#include "dataset/util/slice.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class StorageManager; - -class StorageContainer { - public: - friend class StorageManager; - - ~StorageContainer() noexcept; - - StorageContainer(const StorageContainer &) = delete; - - StorageContainer &operator=(const StorageContainer &) = delete; - - friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s); - - Status Open() noexcept; - - Status Close() noexcept; - - Status Insert(const std::vector &buf, off64_t *offset) noexcept; - - Status Write(const ReadableSlice &dest, off64_t offset) const noexcept; - - Status Read(WritableSlice *dest, off64_t offset) const noexcept; - - Status Truncate() const noexcept; - - bool IsOpen() const { return is_open_; } - - static Status CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path); - - private: - mutable std::mutex mutex_; - Path cont_; - int fd_; - bool is_open_; - std::unique_ptr bs_; - - // Use the default value of BuddySpace - // which can map upto 4G of space. - explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {} - - Status Create(); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_STORAGE_CONTAINER_H_ diff --git a/mindspore/ccsrc/dataset/util/storage_manager.cc b/mindspore/ccsrc/dataset/util/storage_manager.cc deleted file mode 100644 index 1d958576ba..0000000000 --- a/mindspore/ccsrc/dataset/util/storage_manager.cc +++ /dev/null @@ -1,166 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/storage_manager.h" - -#include -#include -#include -#include -#include "common/utils.h" -#include "dataset/util/path.h" -#include "dataset/util/services.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) { - std::ostringstream oss; - oss << prefix << std::setfill('0') << std::setw(5) << file_id; - return oss.str(); -} - -std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) { - std::string base_name = GetBaseName(prefix, file_id); - return (base_name + "." + suffix); -} - -Status StorageManager::AddOneContainer() { - const std::string kPrefix = "IMG"; - const std::string kSuffix = "LB"; - Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix); - std::shared_ptr sc; - RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString())); - containers_.push_back(sc); - file_id_++; - return Status::OK(); -} - -Status StorageManager::DoServiceStart() { - containers_.reserve(1000); - if (root_.IsDirectory()) { - RETURN_IF_NOT_OK(AddOneContainer()); - } else { - RETURN_STATUS_UNEXPECTED("Not a directory"); - } - return Status::OK(); -} - -Status StorageManager::Write(key_type *key, const std::vector &buf) { - RETURN_UNEXPECTED_IF_NULL(key); - size_t sz = 0; - for (auto &v : buf) { - sz += v.GetSize(); - } - if (sz == 0) { - RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); - } - std::shared_ptr cont; - key_type out_key; - value_type out_value; - bool create_new_container = false; - do { - SharedLock lock_s(&rw_lock_); - size_t num_containers = containers_.size(); - if (create_new_container) { - // Upgrade to exclusvie lock. - lock_s.Upgrade(); - create_new_container = false; - // Check again if someone has already added a - // new container after we got the x lock - if (containers_.size() == num_containers) { - RETURN_IF_NOT_OK(AddOneContainer()); - } - // Refresh how many containers there are. - num_containers = containers_.size(); - // Downgrade back to shared lock - lock_s.Downgrade(); - } - if (num_containers == 0) { - RETURN_STATUS_UNEXPECTED("num_containers is zero"); - } - // Go to the last container to insert. - cont = containers_.at(num_containers - 1); - off64_t offset; - Status rc = cont->Insert(buf, &offset); - if (rc.IsNoSpace()) { - create_new_container = true; - } else if (rc.IsOk()) { - out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); - RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); - *key = out_key; - break; - } else { - return rc; - } - } while (true); - return Status::OK(); -} - -Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const { - RETURN_UNEXPECTED_IF_NULL(dest); - auto r = index_.Search(key); - if (r.second) { - auto &it = r.first; - value_type v = *it; - int container_inx = v.first; - off_t offset = v.second.first; - size_t sz = v.second.second; - if (dest->GetSize() < sz) { - std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) + - " but length = " + std::to_string(dest->GetSize()); - RETURN_STATUS_UNEXPECTED(errMsg); - } - if (bytesRead != nullptr) { - *bytesRead = sz; - } - auto cont = containers_.at(container_inx); - RETURN_IF_NOT_OK(cont->Read(dest, offset)); - } else { - RETURN_STATUS_UNEXPECTED("Key not found"); - } - return Status::OK(); -} - -Status StorageManager::DoServiceStop() noexcept { - Status rc; - Status rc1; - for (auto const &p : containers_) { - // The destructor of StorageContainer is not called automatically until the use - // count drops to 0. But it is not always the case. We will do it ourselves. - rc = p.get()->Truncate(); - if (rc.IsError()) { - rc1 = rc; - } - } - containers_.clear(); - file_id_ = 0; - return rc1; -} - -StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {} - -StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); } - -std::ostream &operator<<(std::ostream &os, const StorageManager &s) { - os << "Dumping all containers ..." - << "\n"; - for (auto const &p : s.containers_) { - os << *(p.get()); - } - return os; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_manager.h b/mindspore/ccsrc/dataset/util/storage_manager.h deleted file mode 100644 index 075ac713d2..0000000000 --- a/mindspore/ccsrc/dataset/util/storage_manager.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_STORAGE_MANAGER_H_ -#define DATASET_UTIL_STORAGE_MANAGER_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/auto_index.h" -#include "dataset/util/lock.h" -#include "dataset/util/memory_pool.h" -#include "dataset/util/path.h" -#include "dataset/util/service.h" -#include "dataset/util/slice.h" -#include "dataset/util/storage_container.h" - -using ListOfContainers = std::vector>; -namespace mindspore { -namespace dataset { -class StorageManager : public Service { - public: - using storage_index = AutoIndexObj>>; - using key_type = storage_index::key_type; - using value_type = storage_index::value_type; - - explicit StorageManager(const Path &); - - ~StorageManager() override; - - StorageManager(const StorageManager &) = delete; - - StorageManager &operator=(const StorageManager &) = delete; - - Status Write(key_type *out_key, const std::vector &buf); - - Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const; - - Status DoServiceStart() override; - - Status DoServiceStop() noexcept override; - - friend std::ostream &operator<<(std::ostream &os, const StorageManager &s); - - private: - Path root_; - ListOfContainers containers_; - int file_id_; - RWLock rw_lock_; - storage_index index_; - - std::string GetBaseName(const std::string &prefix, int32_t file_id); - - std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix); - - Status AddOneContainer(); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_STORAGE_MANAGER_H_ diff --git a/mindspore/ccsrc/dataset/util/system_pool.h b/mindspore/ccsrc/dataset/util/system_pool.h deleted file mode 100644 index 286e30a615..0000000000 --- a/mindspore/ccsrc/dataset/util/system_pool.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SYSTEM_POOL_H_ -#define DATASET_UTIL_SYSTEM_POOL_H_ - -#include -#include -#include -#include -#include -#include "./securec.h" -#include "dataset/util/allocator.h" -#include "dataset/util/memory_pool.h" - -namespace mindspore { -namespace dataset { -// This class demonstrate how to implement a simple MemoryPool -// for minddata/dataset using malloc/free/realloc. We need to -// implement 4 virtual functions. Other MemoryPool -// implementation, e.g., are BuddyArena and CircularPool. All -// these MemoryPool can be used together with Allocator.h for -// C++ STL containers. -class SystemPool : public MemoryPool { - public: - ~SystemPool() override {} - - Status Allocate(size_t n, void **pp) override { return DeMalloc(n, pp, false); } - - void Deallocate(void *p) override { free(p); } - - Status Reallocate(void **p, size_t old_sz, size_t new_sz) override { - if (old_sz >= new_sz) { - // Do nothing if we shrink. - return Status::OK(); - } else { - void *ptr = *p; - void *q = nullptr; - RETURN_IF_NOT_OK(DeMalloc(new_sz, &q, false)); - errno_t err = memcpy_s(q, new_sz, ptr, old_sz); - if (err) { - free(q); - RETURN_STATUS_UNEXPECTED(std::to_string(err)); - } - free(ptr); - *p = q; - return Status::OK(); - } - } - - uint64_t get_max_size() const override { return std::numeric_limits::max(); } - - int PercentFree() const override { return 100; } - - template - static Allocator GetAllocator() { - return Allocator(std::make_shared()); - } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_SYSTEM_POOL_H_ diff --git a/mindspore/ccsrc/dataset/util/task.cc b/mindspore/ccsrc/dataset/util/task.cc deleted file mode 100644 index 93db55d5f9..0000000000 --- a/mindspore/ccsrc/dataset/util/task.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/task.h" -#include "common/utils.h" -#include "dataset/util/task_manager.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -thread_local Task *gMyTask = nullptr; - -void Task::operator()() { -#if !defined(_WIN32) && !defined(_WIN64) - gMyTask = this; -#endif - id_ = this_thread::get_id(); - std::stringstream ss; - ss << id_; - MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started."; - try { - // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set - // the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can - // get the thread id. - TaskGroup *vg = MyTaskGroup(); - rc_ = vg->GetIntrpService()->Register(ss.str(), this); - if (rc_.IsOk()) { - // Now we can run the given task. - rc_ = fnc_obj_(); - } - // Some error codes are ignored, e.g. interrupt. Others we just shutdown the group. - if (rc_.IsError() && !rc_.IsInterrupted()) { - ShutdownGroup(); - } - } catch (const std::bad_alloc &e) { - rc_ = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what()); - ShutdownGroup(); - } catch (const std::exception &e) { - rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); - ShutdownGroup(); - } -} - -void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine. - { - std::lock_guard lk(mux_); - caught_severe_exception_ = true; - } - TaskGroup *vg = MyTaskGroup(); - // If multiple threads hit severe errors in the same group. Keep the first one and - // discard the rest. - if (vg->rc_.IsOk()) { - std::unique_lock rcLock(vg->rc_mux_); - // Check again after we get the lock - if (vg->rc_.IsOk()) { - vg->rc_ = rc_; - rcLock.unlock(); - TaskManager::InterruptMaster(rc_); - TaskManager::InterruptGroup(*this); - } - } -} - -Status Task::GetTaskErrorIfAny() const { - std::lock_guard lk(mux_); - if (caught_severe_exception_) { - return rc_; - } else { - return Status::OK(); - } -} - -Task::Task(const std::string &myName, const std::function &f) - : my_name_(myName), - rc_(), - fnc_obj_(f), - task_group_(nullptr), - is_master_(false), - running_(false), - caught_severe_exception_(false) { - IntrpResource::ResetIntrpState(); - wp_.ResetIntrpState(); - wp_.Clear(); -} - -Status Task::Run() { - Status rc; - if (running_ == false) { - try { - thrd_ = std::async(std::launch::async, std::ref(*this)); - running_ = true; - caught_severe_exception_ = false; - } catch (const std::exception &e) { - rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); - } - } - return rc; -} - -Status Task::Join(WaitFlag blocking) { - if (running_) { - RETURN_UNEXPECTED_IF_NULL(MyTaskGroup()); - auto interrupt_svc = MyTaskGroup()->GetIntrpService(); - try { - if (blocking == WaitFlag::kBlocking) { - // If we are asked to wait, then wait - thrd_.get(); - } else if (blocking == WaitFlag::kNonBlocking) { - // There is a race condition in the global resource tracking such that a thread can miss the - // interrupt and becomes blocked on a conditional variable forever. As a result, calling - // join() will not come back. We need some timeout version of join such that if the thread - // doesn't come back in a reasonable of time, we will send the interrupt again. - while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { - // We can't tell which conditional_variable this thread is waiting on. So we may need - // to interrupt everything one more time. - MS_LOG(INFO) << "Some threads not responding. Interrupt again"; - interrupt_svc->InterruptAll(); - } - } else { - RETURN_STATUS_UNEXPECTED("Unknown WaitFlag"); - } - std::stringstream ss; - ss << get_id(); - MS_LOG(DEBUG) << MyName() << " Thread ID " << ss.str() << " Stopped."; - running_ = false; - RETURN_IF_NOT_OK(wp_.Deregister()); - RETURN_IF_NOT_OK(interrupt_svc->Deregister(ss.str())); - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - } - return Status::OK(); -} - -TaskGroup *Task::MyTaskGroup() { return task_group_; } - -void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; } - -Task::~Task() { task_group_ = nullptr; } -Status Task::OverrideInterruptRc(const Status &rc) { - if (rc.IsInterrupted() && this_thread::is_master_thread()) { - // If we are interrupted, override the return value if this is the master thread. - // Master thread is being interrupted mostly because of some thread is reporting error. - return TaskManager::GetMasterThreadRc(); - } - return rc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/task.h b/mindspore/ccsrc/dataset/util/task.h deleted file mode 100644 index 49eb16b182..0000000000 --- a/mindspore/ccsrc/dataset/util/task.h +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_TASK_H_ -#define DATASET_UTIL_TASK_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "dataset/util/intrp_resource.h" -#include "dataset/util/list.h" -#include "dataset/util/memory_pool.h" -#include "dataset/util/services.h" -#include "dataset/util/wait_post.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -class TaskManager; - -class Task : public IntrpResource { - public: - friend class TaskManager; - friend class TaskGroup; - - enum class WaitFlag : int { kBlocking, kNonBlocking }; - - Task(const std::string &myName, const std::function &f); - - // Future objects are not copyable. - Task(const Task &) = delete; - - ~Task() override; - - Task &operator=(const Task &) = delete; - - // Move constructor and Assignment are not supported. - // Too many things in this class. - Task(Task &&) = delete; - - Task &operator=(Task &&) = delete; - - Status GetTaskErrorIfAny() const; - - void ChangeName(const std::string &newName) { my_name_ = newName; } - - // To execute the _fncObj - void operator()(); - - Node node; - Node group; - Node free; - - // Run the task - Status Run(); - - Status Join(WaitFlag wf = WaitFlag::kBlocking); - - bool Running() const { return running_; } - - bool CaughtSevereException() const { return caught_severe_exception_; } - - bool IsMasterThread() const { return is_master_; } - - std::thread::id get_id() { return id_; } - - std::string MyName() { return my_name_; } - - // An operator used by std::find - bool operator==(const Task &other) const { return (this == &other); } - - bool operator!=(const Task &other) const { return !(*this == other); } - - void Post() { wp_.Set(); } - - Status Wait() { return (wp_.Wait()); } - - static Status OverrideInterruptRc(const Status &rc); - - private: - mutable std::mutex mux_; - std::string my_name_; - Status rc_; - WaitPost wp_; - // Task need to provide definition for this function. It - // will be called by thread function. - std::function fnc_obj_; - // Misc fields used by TaskManager. - TaskGroup *task_group_; - std::future thrd_; - std::thread::id id_; - bool is_master_; - volatile bool running_; - volatile bool caught_severe_exception_; - - void ShutdownGroup(); - TaskGroup *MyTaskGroup(); - void set_task_group(TaskGroup *vg); -}; - -extern thread_local Task *gMyTask; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_TASK_H_ diff --git a/mindspore/ccsrc/dataset/util/task_manager.cc b/mindspore/ccsrc/dataset/util/task_manager.cc deleted file mode 100644 index 3965e35564..0000000000 --- a/mindspore/ccsrc/dataset/util/task_manager.cc +++ /dev/null @@ -1,353 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include "./securec.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// This takes the same parameter as Task constructor. -Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function &f, TaskGroup *vg, - Task **task) { - // We need to block destructor coming otherwise we will deadlock. We will grab the - // stateLock in shared allowing CreateAsyncTask to run concurrently. - SharedLock stateLck(&state_lock_); - // Now double check the state - if (ServiceState() == STATE::kStopInProg || ServiceState() == STATE::kStopped) { - return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "TaskManager is shutting down"); - } - RETURN_IF_NOT_OK(GetFreeTask(my_name, f, task)); - if (vg == nullptr) { - RETURN_STATUS_UNEXPECTED("TaskGroup is null"); - } - // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set - // the TaskGroup pointer. We will do the set here before we call run(). The run() will do the registration. - (*task)->set_task_group(vg); - // Link to the master lru list. - { - UniqueLock lck(&lru_lock_); - lru_.Append(*task); - } - // Link to the group list as well before we spawn. - { - UniqueLock lck(&vg->rw_lock_); - vg->grp_list_.Append(*task); - } - // Track all the TaskGroup. Used for control-c - { - LockGuard lck(&tg_lock_); - this->grp_list_.insert(vg); - } - RETURN_IF_NOT_OK((*task)->wp_.Register(vg)); - RETURN_IF_NOT_OK((*task)->Run()); - // Wait for the thread to initialize successfully. - RETURN_IF_NOT_OK((*task)->Wait()); - return Status::OK(); -} - -Status TaskManager::join_all() { - Status rc; - Status rc2; - SharedLock lck(&lru_lock_); - for (Task &tk : lru_) { - rc = tk.Join(); - if (rc.IsError()) { - rc2 = rc; - } - } - return rc2; -} - -void TaskManager::interrupt_all() noexcept { - global_interrupt_ = 1; - LockGuard lck(&tg_lock_); - for (TaskGroup *vg : grp_list_) { - auto svc = vg->GetIntrpService(); - if (svc) { - // Stop the interrupt service. No new request is accepted. - svc->ServiceStop(); - svc->InterruptAll(); - } - } - master_->Interrupt(); -} - -Task *TaskManager::FindMe() { -#if !defined(_WIN32) && !defined(_WIN64) - return gMyTask; -#else - TaskManager &tm = TaskManager::GetInstance(); - SharedLock lock(&tm.lru_lock_); - auto id = this_thread::get_id(); - auto tk = std::find_if(tm.lru_.begin(), tm.lru_.end(), [id](const Task &tk) { return tk.id_ == id; }); - if (tk != tm.lru_.end()) { - return &(*tk); - } - // If we get here, either I am the watchdog or the master thread. - if (tm.master_->id_ == id) { - return tm.master_.get(); - } else if (tm.watchdog_ != nullptr && tm.watchdog_->id_ == id) { - return tm.watchdog_; - } - MS_LOG(ERROR) << "Task not found."; - return nullptr; -#endif -} - -TaskManager::TaskManager() try : global_interrupt_(0), - lru_(&Task::node), - free_lst_(&Task::free), - watchdog_grp_(nullptr), - watchdog_(nullptr) { - auto alloc = Services::GetAllocator(); - // Create a dummy Task for the master thread (this thread) - master_ = std::allocate_shared(alloc, "master", []() -> Status { return Status::OK(); }); - master_->id_ = this_thread::get_id(); - master_->running_ = true; - master_->is_master_ = true; -#if !defined(_WIN32) && !defined(_WIN64) - gMyTask = master_.get(); - // Initialize the semaphore for the watchdog - errno_t rc = sem_init(&sem_, 0, 0); - if (rc == -1) { - MS_LOG(ERROR) << "Unable to initialize a semaphore. Errno = " << rc << "."; - std::terminate(); - } -#endif -} catch (const std::exception &e) { - MS_LOG(ERROR) << "MindData initialization failed: " << e.what() << "."; - std::terminate(); -} - -TaskManager::~TaskManager() { - if (watchdog_) { - WakeUpWatchDog(); - watchdog_->Join(); - // watchdog_grp_ and watchdog_ pointers come from Services::GetInstance().GetServiceMemPool() which we will free it - // on shutdown. So no need to free these pointers one by one. - watchdog_grp_ = nullptr; - watchdog_ = nullptr; - } -#if !defined(_WIN32) && !defined(_WIN64) - (void)sem_destroy(&sem_); -#endif -} - -Status TaskManager::DoServiceStart() { - MS_LOG(INFO) << "Starting Task Manager."; -#if !defined(_WIN32) && !defined(_WIN64) - // Create a watchdog for control-c - std::shared_ptr mp = Services::GetInstance().GetServiceMemPool(); - // A dummy group just for the watchdog. We aren't really using it. But most code assumes a thread must - // belong to a group. - auto f = std::bind(&TaskManager::WatchDog, this); - Status rc; - watchdog_grp_ = new (&rc, mp) TaskGroup(); - RETURN_IF_NOT_OK(rc); - rc = watchdog_grp_->CreateAsyncTask("Watchdog", f, &watchdog_); - if (rc.IsError()) { - ::operator delete(watchdog_grp_, mp); - watchdog_grp_ = nullptr; - return rc; - } - grp_list_.erase(watchdog_grp_); - lru_.Remove(watchdog_); -#endif - return Status::OK(); -} - -Status TaskManager::DoServiceStop() { - WakeUpWatchDog(); - interrupt_all(); - return Status::OK(); -} - -Status TaskManager::WatchDog() { - TaskManager::FindMe()->Post(); -#if !defined(_WIN32) && !defined(_WIN64) - errno_t err = sem_wait(&sem_); - if (err == -1) { - RETURN_STATUS_UNEXPECTED("Errno = " + std::to_string(errno)); - } - // We are woken up by control-c and we are going to stop all threads that are running. - // In addition, we also want to prevent new thread from creating. This can be done - // easily by calling the parent function. - RETURN_IF_NOT_OK(ServiceStop()); -#endif - return Status::OK(); -} - -// Follow the group link and interrupt other -// Task in the same group. It is used by -// Watchdog only. -void TaskManager::InterruptGroup(Task &curTk) { - TaskGroup *vg = curTk.MyTaskGroup(); - vg->interrupt_all(); -} - -void TaskManager::InterruptMaster(const Status &rc) { - TaskManager &tm = TaskManager::GetInstance(); - std::shared_ptr master = tm.master_; - std::lock_guard lck(master->mux_); - master->Interrupt(); - if (rc.IsError() && master->rc_.IsOk()) { - master->rc_ = rc; - master->caught_severe_exception_ = true; - } -} - -Status TaskManager::GetMasterThreadRc() { - TaskManager &tm = TaskManager::GetInstance(); - std::shared_ptr master = tm.master_; - Status rc = tm.master_->GetTaskErrorIfAny(); - if (rc.IsError()) { - // Reset the state once we retrieve the value. - std::lock_guard lck(master->mux_); - master->rc_ = Status::OK(); - master->caught_severe_exception_ = false; - master->ResetIntrpState(); - } - return rc; -} - -void TaskManager::ReturnFreeTask(Task *p) noexcept { - // Take it out from lru_ if any - { - UniqueLock lck(&lru_lock_); - auto it = std::find(lru_.begin(), lru_.end(), *p); - if (it != lru_.end()) { - lru_.Remove(p); - } - } - // We need to deallocate the string resources associated with the Task class - // before we cache its memory for future use. - p->~Task(); - // Put it back into free list - { - LockGuard lck(&free_lock_); - free_lst_.Append(p); - } -} - -Status TaskManager::GetFreeTask(const std::string &my_name, const std::function &f, Task **p) { - if (p == nullptr) { - RETURN_STATUS_UNEXPECTED("p is null"); - } - Task *q = nullptr; - // First try the free list - { - LockGuard lck(&free_lock_); - if (free_lst_.count > 0) { - q = free_lst_.head; - free_lst_.Remove(q); - } - } - if (q) { - new (q) Task(my_name, f); - } else { - std::shared_ptr mp = Services::GetInstance().GetServiceMemPool(); - Status rc; - q = new (&rc, mp) Task(my_name, f); - RETURN_IF_NOT_OK(rc); - } - *p = q; - return Status::OK(); -} - -Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::function &f, Task **ppTask) { - auto pMytask = TaskManager::FindMe(); - // We need to block ~TaskGroup coming otherwise we will deadlock. We will grab the - // stateLock in shared allowing CreateAsyncTask to run concurrently. - SharedLock state_lck(&state_lock_); - // Now double check the state - if (ServiceState() != STATE::kRunning) { - return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "Taskgroup is shutting down"); - } - TaskManager &dm = TaskManager::GetInstance(); - Task *pTask = nullptr; - // If the group is already in error, early exit too. - // We can't hold the rc_mux_ throughout because the thread spawned by CreateAsyncTask may hit error which - // will try to shutdown the group and grab the rc_mux_ and we will deadlock. - { - std::unique_lock rcLock(rc_mux_); - if (rc_.IsError()) { - return pMytask->IsMasterThread() ? rc_ : Status(StatusCode::kInterrupted); - } - } - RETURN_IF_NOT_OK(dm.CreateAsyncTask(my_name, f, this, &pTask)); - if (ppTask) { - *ppTask = pTask; - } - return Status::OK(); -} - -void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); } - -Status TaskGroup::join_all(Task::WaitFlag wf) { - Status rc; - Status rc2; - SharedLock lck(&rw_lock_); - for (Task &tk : grp_list_) { - rc = tk.Join(wf); - if (rc.IsError()) { - rc2 = rc; - } - } - return rc2; -} - -Status TaskGroup::DoServiceStop() { - intrp_svc_->ServiceStop(); - interrupt_all(); - return (join_all(Task::WaitFlag::kNonBlocking)); -} - -TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) { - auto alloc = Services::GetAllocator(); - intrp_svc_ = std::allocate_shared(alloc); - (void)Service::ServiceStart(); -} - -TaskGroup::~TaskGroup() { - (void)Service::ServiceStop(); - // The TaskGroup is going out of scope, and we can return the Task list to the free list. - Task *cur = grp_list_.head; - TaskManager &tm = TaskManager::GetInstance(); - while (cur) { - Task *next = cur->group.next; - grp_list_.Remove(cur); - tm.ReturnFreeTask(cur); - cur = next; - } - { - LockGuard lck(&tm.tg_lock_); - (void)tm.grp_list_.erase(this); - } -} - -Status TaskGroup::GetTaskErrorIfAny() { - SharedLock lck(&rw_lock_); - for (Task &tk : grp_list_) { - RETURN_IF_NOT_OK(tk.GetTaskErrorIfAny()); - } - return Status::OK(); -} - -std::shared_ptr TaskGroup::GetIntrpService() { return intrp_svc_; } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/task_manager.h b/mindspore/ccsrc/dataset/util/task_manager.h deleted file mode 100644 index 5961c9000e..0000000000 --- a/mindspore/ccsrc/dataset/util/task_manager.h +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_TASK_MANAGER_H_ -#define DATASET_UTIL_TASK_MANAGER_H_ - -#if !defined(_WIN32) && !defined(_WIN64) -#include -#include // for sig_atomic_t -#endif -#include -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/intrp_service.h" -#include "dataset/util/lock.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" -#include "dataset/util/task.h" - -namespace mindspore { -namespace dataset { -namespace thread { -using id = std::thread::id; -} // namespace thread - -namespace this_thread { -inline thread::id get_id() { return std::this_thread::get_id(); } -} // namespace this_thread - -class TaskManager : public Service { - public: - friend class Services; - - friend class TaskGroup; - - ~TaskManager() override; - - TaskManager(const TaskManager &) = delete; - - TaskManager &operator=(const TaskManager &) = delete; - - static TaskManager &GetInstance() noexcept { return Services::getTaskMgrInstance(); } - - Status DoServiceStart() override; - - Status DoServiceStop() override; - - // A public global interrupt flag for signal handlers - volatile sig_atomic_t global_interrupt_; - - // API - // This takes the same parameter as Task constructor. Take a look - // of the test-thread.cc for usage. - Status CreateAsyncTask(const std::string &my_name, const std::function &f, TaskGroup *vg, Task **); - - // Same usage as boot thread group - Status join_all(); - - void interrupt_all() noexcept; - - // Locate a particular Task. - static Task *FindMe(); - - static void InterruptGroup(Task &); - - static Status GetMasterThreadRc(); - - static void InterruptMaster(const Status &rc = Status::OK()); - - static void WakeUpWatchDog() { -#if !defined(_WIN32) && !defined(_WIN64) - TaskManager &tm = TaskManager::GetInstance(); - (void)sem_post(&tm.sem_); -#endif - } - - void ReturnFreeTask(Task *p) noexcept; - - Status GetFreeTask(const std::string &my_name, const std::function &f, Task **p); - - Status WatchDog(); - - private: - RWLock lru_lock_; - SpinLock free_lock_; - SpinLock tg_lock_; - std::shared_ptr master_; - List lru_; - List free_lst_; -#if !defined(_WIN32) && !defined(_WIN64) - sem_t sem_; -#endif - TaskGroup *watchdog_grp_; - std::set grp_list_; - Task *watchdog_; - - TaskManager(); -}; - -// A group of related tasks. -class TaskGroup : public Service { - public: - friend class Task; - friend class TaskManager; - - Status CreateAsyncTask(const std::string &my_name, const std::function &f, Task **pTask = nullptr); - - void interrupt_all() noexcept; - - Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking); - - int size() const noexcept { return grp_list_.count; } - - Status DoServiceStart() override { return Status::OK(); } - - Status DoServiceStop() override; - - TaskGroup(); - - ~TaskGroup() override; - - Status GetTaskErrorIfAny(); - - std::shared_ptr GetIntrpService(); - - private: - Status rc_; - // Can't use rw_lock_ as we will lead to deadlatch. Create another mutex to serialize access to rc_. - std::mutex rc_mux_; - RWLock rw_lock_; - List grp_list_; - std::shared_ptr intrp_svc_; -}; - -namespace this_thread { -inline bool is_interrupted() { - TaskManager &tm = TaskManager::GetInstance(); - if (tm.global_interrupt_ == 1) { - return true; - } - Task *my_task = TaskManager::FindMe(); - return my_task->Interrupted(); -} - -inline bool is_master_thread() { - Task *my_task = TaskManager::FindMe(); - return my_task->IsMasterThread(); -} - -inline Status GetInterruptStatus() { - Task *my_task = TaskManager::FindMe(); - return my_task->GetInterruptStatus(); -} -} // namespace this_thread - -#define RETURN_IF_INTERRUPTED() \ - do { \ - if (mindspore::dataset::this_thread::is_interrupted()) { \ - return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \ - } \ - } while (false) - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_TASK_MANAGER_H_ diff --git a/mindspore/ccsrc/dataset/util/wait_post.cc b/mindspore/ccsrc/dataset/util/wait_post.cc deleted file mode 100644 index 204f203d9a..0000000000 --- a/mindspore/ccsrc/dataset/util/wait_post.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/wait_post.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -WaitPost::WaitPost() : value_(0) {} - -Status WaitPost::Wait() { - std::unique_lock lck(mutex_); - return (wait_cond_.Wait(&lck, [this]() { return value_ != 0; })); -} - -void WaitPost::Set() { - std::unique_lock lck(mutex_); - value_ = 1; - wait_cond_.NotifyAll(); -} - -void WaitPost::Clear() { - std::unique_lock lck(mutex_); - value_ = 0; -} - -Status WaitPost::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } - -void WaitPost::ResetIntrpState() { wait_cond_.ResetIntrpState(); } - -Status WaitPost::Deregister() { return wait_cond_.Deregister(); } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/wait_post.h b/mindspore/ccsrc/dataset/util/wait_post.h deleted file mode 100644 index 4e60995bd9..0000000000 --- a/mindspore/ccsrc/dataset/util/wait_post.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_WAIT_POST_H_ -#define DATASET_UTIL_WAIT_POST_H_ - -#include -#include "dataset/util/cond_var.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class TaskGroup; - -class WaitPost { - public: - WaitPost(); - - ~WaitPost() = default; - - Status Wait(); - - void Set(); - - void Clear(); - - Status Register(TaskGroup *vg); - - Status Deregister(); - - void ResetIntrpState(); - - private: - std::mutex mutex_; - CondVar wait_cond_; - int value_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_WAIT_POST_H_ diff --git a/mindspore/ccsrc/debug/CMakeLists.txt b/mindspore/ccsrc/debug/CMakeLists.txt index ba0c5e07ac..37ffcceeaf 100644 --- a/mindspore/ccsrc/debug/CMakeLists.txt +++ b/mindspore/ccsrc/debug/CMakeLists.txt @@ -19,6 +19,15 @@ if (ENABLE_DEBUGGER) ) endif (ENABLE_DEBUGGER) +if (ENABLE_D) + list(APPEND _DEBUG_SRC_LIST + "${CMAKE_CURRENT_SOURCE_DIR}/common.cc" + ) + if (ENABLE_DATA_DUMP) + list(APPEND _DEBUG_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/data_dump_parser.cc") + endif(ENABLE_DATA_DUMP) +endif() + if (ENABLE_DUMP_E2E) list(APPEND _DEBUG_SRC_LIST "${CMAKE_CURRENT_SOURCE_DIR}/e2e_dump.cc") endif (ENABLE_DUMP_E2E) diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc index fc32e0fb5f..42d372cefb 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/debug/anf_ir_dump.cc @@ -24,9 +24,9 @@ #include "ir/primitive.h" #include "ir/func_graph.h" -#include "device/kernel_info.h" +#include "runtime/device/kernel_info.h" #include "utils/graph_utils.h" -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" namespace mindspore { const std::string ToShortString(const TypeId &typeId) { @@ -128,7 +128,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr return; } auto kernel_info = node->kernel_info(); - if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { + if (kernel_info == nullptr || !kernel_info->has_build_info()) { return; } @@ -179,7 +179,7 @@ void DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMa // print parameters' type and shape PrintNodeOutputType(buffer, p); auto kernel_info = p->kernel_info(); - if (kernel_info != nullptr && kernel_info->select_kernel_build_info() != nullptr) { + if (kernel_info != nullptr && kernel_info->has_build_info()) { buffer << " : "; auto type = AnfAlgo::GetOutputDeviceDataType(p, 0); auto format = AnfAlgo::GetOutputFormat(p, 0); diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index c797b8efea..273a6f6458 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -26,19 +26,19 @@ #include "utils/graph_utils.h" #include "utils/symbolic.h" #include "ir/meta_func_graph.h" -#include "ir/param_value_py.h" +#include "ir/param_value.h" #include "ir/tensor_py.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/resolve.h" -#include "operator/composite/composite.h" -#include "operator/composite/map.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/resolve.h" +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/composite/map.h" #include "utils/ordered_map.h" #include "utils/ordered_set.h" #include "utils/utils.h" #include "debug/trace.h" #include "debug/label.h" #include "utils/context/ms_context.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" using mindspore::tensor::TensorPy; @@ -485,8 +485,8 @@ void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vectorhas_default()) { - auto param_value = std::dynamic_pointer_cast(param_ptr->default_param()); - ofs << " = @" << DumpObject(param_value->value(), "D"); + auto param_value = param_ptr->default_param(); + ofs << " = @" << DumpObject(py::cast(param_value), "D"); } // output comment @@ -1667,7 +1667,7 @@ class IrParser { // load parameter default value from serialized file py::object default_obj = LoadObject(lexer_.GetTokenText()); - auto param_value_new = std::make_shared(default_obj); + auto param_value_new = py::cast(default_obj); param->set_default_param(param_value_new); tok = lexer_.GetNextToken(); diff --git a/mindspore/ccsrc/debug/anf_ir_utils.h b/mindspore/ccsrc/debug/anf_ir_utils.h index 4503692eb9..ed5e3b8a5d 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.h +++ b/mindspore/ccsrc/debug/anf_ir_utils.h @@ -28,9 +28,9 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "ir/meta_func_graph.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/resolve.h" -#include "operator/composite/composite.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/resolve.h" +#include "frontend/operator/composite/composite.h" #include "utils/symbolic.h" #include "utils/ordered_map.h" #include "utils/ordered_set.h" diff --git a/mindspore/ccsrc/debug/common.cc b/mindspore/ccsrc/debug/common.cc new file mode 100644 index 0000000000..6caf7e2c39 --- /dev/null +++ b/mindspore/ccsrc/debug/common.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "debug/common.h" + +#include +#include +#include "utils/system/env.h" +#include "utils/system/file_system.h" +#include "utils/log_adapter.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +std::optional Common::GetRealPath(const std::string &input_path) { + std::string out_path; + auto path_split_pos = input_path.find_last_of('/'); + if (path_split_pos == std::string::npos) { + path_split_pos = input_path.find_last_of('\\'); + } + // get real path + char real_path[PATH_MAX] = {0}; + if (path_split_pos != std::string::npos) { + std::string prefix_path = input_path.substr(0, path_split_pos); + if (prefix_path.length() >= PATH_MAX) { + MS_LOG(ERROR) << "Prefix path is too longer!"; + return std::nullopt; + } + std::string last_path = input_path.substr(path_split_pos, input_path.length() - path_split_pos); + auto ret = CreateNotExistDirs(prefix_path); + if (!ret) { + MS_LOG(ERROR) << "CreateNotExistDirs Failed!"; + return std::nullopt; + } + + if (nullptr == realpath(prefix_path.c_str(), real_path)) { + MS_LOG(ERROR) << "dir " << prefix_path << " does not exit."; + return std::nullopt; + } + out_path = std::string(real_path) + last_path; + } + + if (path_split_pos == std::string::npos) { + if (input_path.length() >= PATH_MAX) { + MS_LOG(ERROR) << "Prefix path is too longer!"; + return std::nullopt; + } + if (nullptr == realpath(input_path.c_str(), real_path)) { + MS_LOG(ERROR) << "File " << input_path << " does not exit, it will be created."; + } + out_path = std::string(real_path); + } + return out_path; +} + +bool Common::CreateNotExistDirs(const std::string &path) { + std::shared_ptr fs = system::Env::GetFileSystem(); + MS_EXCEPTION_IF_NULL(fs); + char temp_path[PATH_MAX] = {0}; + if (path.length() > PATH_MAX) { + MS_LOG(ERROR) << "Path lens is max than " << PATH_MAX; + return false; + } + for (uint32_t i = 0; i < path.length(); i++) { + temp_path[i] = path[i]; + if (temp_path[i] == '\\' || temp_path[i] == '/') { + if (i != 0) { + char tmp_char = temp_path[i]; + temp_path[i] = '\0'; + std::string path_handle(temp_path); + if (!fs->FileExist(temp_path)) { + MS_LOG(INFO) << "Dir " << path_handle << " does not exit, creating..."; + if (!fs->CreateDir(temp_path)) { + MS_LOG(ERROR) << "Create " << path_handle << " dir error"; + return false; + } + } + temp_path[i] = tmp_char; + } + } + } + + if (!fs->FileExist(path)) { + MS_LOG(INFO) << "Dir " << path << " does not exit, creating..."; + if (!fs->CreateDir(path)) { + MS_LOG(ERROR) << "Create " << path << " dir error"; + return false; + } + } + return true; +} + +std::optional Common::GetConfigFile(const std::string &env) { + if (env.empty()) { + MS_LOG(EXCEPTION) << "Invalid env"; + } + auto config_path_str = std::getenv(env.c_str()); + if (config_path_str == nullptr) { + MS_LOG(ERROR) << "Please export env:" << env; + return {}; + } + MS_LOG(INFO) << "Async Dump Getenv env:" << env << "=" << config_path_str; + + std::string dump_config_file(config_path_str); + std::shared_ptr fs = system::Env::GetFileSystem(); + MS_EXCEPTION_IF_NULL(fs); + if (!fs->FileExist(dump_config_file)) { + MS_LOG(ERROR) << dump_config_file << " not exist."; + return {}; + } + return dump_config_file; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/debug/common.h b/mindspore/ccsrc/debug/common.h new file mode 100644 index 0000000000..8d4a6cb467 --- /dev/null +++ b/mindspore/ccsrc/debug/common.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEBUG_COMMON_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEBUG_COMMON_H_ + +#include +#include +#include "utils/contract.h" + +namespace mindspore { +class Common { + public: + Common() = default; + ~Common() = default; + static std::optional GetRealPath(const std::string &input_path); + static std::optional GetConfigFile(const std::string &env); + + private: + static bool CreateNotExistDirs(const std::string &path); +}; +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEBUG_COMMON_H_ diff --git a/mindspore/ccsrc/debug/data_dump_parser.cc b/mindspore/ccsrc/debug/data_dump_parser.cc new file mode 100644 index 0000000000..259ec388d3 --- /dev/null +++ b/mindspore/ccsrc/debug/data_dump_parser.cc @@ -0,0 +1,152 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "debug/data_dump_parser.h" + +#include +#include "utils/context/ms_context.h" +#include "debug/common.h" + +constexpr auto kDataDumpConfigPtah = "DATA_DUMP_CONFIG_PATH"; +constexpr auto kEnableDataDump = "ENABLE_DATA_DUMP"; +constexpr auto kDataDumpPath = "DATA_DUMP_PATH"; +namespace mindspore { +void DataDumpParser::ResetParam() { + enable_ = false; + net_name_.clear(); + dump_mode_ = 0; + dump_step_ = 0; + kernel_set_.clear(); +} + +bool DataDumpParser::DumpEnabled() const { + auto enable_dump = std::getenv(kEnableDataDump); + if (!enable_dump) { + MS_LOG(WARNING) << "[DataDump] enable dump is null. Please export ENABLE_DATA_DUMP"; + return false; + } + + auto enabled = std::atoi(enable_dump); + if (enabled != 1) { + MS_LOG(WARNING) << "[DataDump] Please export ENABLE_DATA_DUMP=1"; + return false; + } + + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + if (context->execution_mode() == kPynativeMode) { + MS_LOG(EXCEPTION) << "[DataDump] PyNative mode not support data dump"; + } + return true; +} + +std::optional DataDumpParser::GetDumpPath() const { + auto dump_path = std::getenv(kDataDumpPath); + if (!dump_path) { + MS_LOG(ERROR) << "[DataDump] dump path is null. Please export DATA_DUMP_PATH"; + return {}; + } + std::string dump_path_str(dump_path); + return dump_path_str; +} + +void DataDumpParser::ParseDumpConfig() { + std::lock_guard guard(lock_); + MS_LOG(INFO) << "[DataDump] parse start"; + if (!DumpEnabled()) { + MS_LOG(INFO) << "[DataDump] dump not enable"; + return; + } + + ResetParam(); + + auto dump_config_file = Common::GetConfigFile(kDataDumpConfigPtah); + if (!dump_config_file.has_value()) { + MS_LOG(EXCEPTION) << "[DataDump] Get config file failed"; + } + + std::ifstream json_file(dump_config_file.value()); + if (!json_file.is_open()) { + MS_LOG(EXCEPTION) << "[DataDump] " << dump_config_file.value() << " open failed."; + } + + nlohmann::json j; + json_file >> j; + if (j.find("DumpSettings") == j.end()) { + MS_LOG(EXCEPTION) << "[DataDump] DumpSettings is not exist."; + } + + nlohmann::json dump_settings = j.at("DumpSettings"); + // convert json to string + std::stringstream ss; + ss << dump_settings; + std::string cfg = ss.str(); + MS_LOG(INFO) << "[DataDump] Async dump settings Json: " << cfg; + if (!IsConfigExist(dump_settings)) { + MS_LOG(EXCEPTION) << "[DataDump] Async dump json invalid"; + } + + if (!ParseDumpSetting(dump_settings)) { + MS_LOG(EXCEPTION) << "[DataDump] Parse dump json failed"; + } +} + +bool DataDumpParser::NeedDump(const std::string &op_full_name) const { + if (!DumpEnabled()) { + return false; + } + if (dump_mode_ == 0) { + return true; + } + auto iter = kernel_set_.find(op_full_name); + return iter != kernel_set_.end(); +} + +bool DataDumpParser::IsConfigExist(const nlohmann::json &dump_settings) const { + if (dump_settings.find("mode") == dump_settings.end() || dump_settings.find("net_name") == dump_settings.end() || + dump_settings.find("iteration") == dump_settings.end() || dump_settings.find("kernels") == dump_settings.end()) { + MS_LOG(ERROR) << "[DataDump] DumpSettings keys are not exist."; + return false; + } + return true; +} + +bool DataDumpParser::ParseDumpSetting(const nlohmann::json &dump_settings) { + auto mode = dump_settings.at("mode"); + auto net_name = dump_settings.at("net_name"); + auto iteration = dump_settings.at("iteration"); + auto kernels = dump_settings.at("kernels"); + if (!(mode.is_number() && net_name.is_string() && iteration.is_number() && kernels.is_array())) { + MS_LOG(ERROR) << "[DataDump] Element's type in Dump config json is invalid."; + enable_ = false; + return false; + } + + enable_ = true; + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + dump_mode_ = mode; + net_name_ = net_name; + dump_step_ = iteration; + for (const auto &kernel : kernels) { + auto kernel_str = kernel.dump(); + kernel_str.erase(std::remove(kernel_str.begin(), kernel_str.end(), '\"'), kernel_str.end()); + MS_LOG(INFO) << "[DataDump] Need dump kernel:" << kernel_str; + kernel_set_.insert(kernel_str); + } + return true; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/debug/data_dump_parser.h b/mindspore/ccsrc/debug/data_dump_parser.h new file mode 100644 index 0000000000..751c61dd1a --- /dev/null +++ b/mindspore/ccsrc/debug/data_dump_parser.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ + +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "common/utils.h" + +namespace mindspore { +class DataDumpParser { + public: + static DataDumpParser &GetInstance() { + static DataDumpParser instance; + return instance; + } + void ParseDumpConfig(); + bool NeedDump(const std::string &op_full_name) const; + bool DumpEnabled() const; + std::optional GetDumpPath() const; + bool enable() const { return enable_; } + const std::string &net_name() const { return net_name_; } + uint32_t dump_mode() const { return dump_mode_; } + uint32_t dump_step() const { return dump_step_; } + const std::set &kernel_set() const { return kernel_set_; } + + private: + DataDumpParser() = default; + virtual ~DataDumpParser() = default; + DISABLE_COPY_AND_ASSIGN(DataDumpParser); + + void ResetParam(); + bool IsConfigExist(const nlohmann::json &dump_settings) const; + bool ParseDumpSetting(const nlohmann::json &dump_settings); + + std::mutex lock_; + bool enable_{false}; + std::string net_name_; + uint32_t dump_mode_{0}; + uint32_t dump_step_{0}; + std::set kernel_set_; +}; +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEBUG_ASYNC_DUMP_JSON_PARE_H_ diff --git a/mindspore/ccsrc/debug/debug_services.cc b/mindspore/ccsrc/debug/debug_services.cc index cb883eef51..cc6c5c53ad 100644 --- a/mindspore/ccsrc/debug/debug_services.cc +++ b/mindspore/ccsrc/debug/debug_services.cc @@ -37,8 +37,8 @@ DebugServices &DebugServices::operator=(const DebugServices &other) { DebugServices::~DebugServices() { delete tensor_loader_; } -void DebugServices::add_watchpoint(unsigned int id, unsigned int watch_condition, - const std::vector> &check_node_list) { +void DebugServices::AddWatchpoint(unsigned int id, unsigned int watch_condition, + const std::vector> &check_node_list) { std::lock_guard lg(lock_); watchpoint_t watchpoint_item; @@ -57,14 +57,14 @@ void DebugServices::add_watchpoint(unsigned int id, unsigned int watch_condition watchpoint_table[id] = watchpoint_item; } -void DebugServices::remove_watchpoint(unsigned int id) { +void DebugServices::RemoveWatchpoint(unsigned int id) { std::lock_guard lg(lock_); watchpoint_table.erase(id); } -void DebugServices::check_watchpoints(std::vector *name, std::vector *slot, - std::vector *data_ptr, std::vector *data_size, - std::vector *condition, std::vector *wacthpoint_id) { +void DebugServices::CheckWatchpoints(std::vector *name, std::vector *slot, + std::vector *data_ptr, std::vector *data_size, + std::vector *condition, std::vector *wacthpoint_id) { std::lock_guard lg(lock_); std::vector> tensor_list = tensor_loader_->GetTensor(); @@ -171,9 +171,9 @@ void DebugServices::check_watchpoints(std::vector *name, std::vecto } } -void DebugServices::read_nodes_tensors(std::vector name, std::vector *ret_name, - std::vector *data_ptr, std::vector *data_size, - std::vector *dtype, std::vector> *shape) { +void DebugServices::ReadNodesTensors(std::vector name, std::vector *ret_name, + std::vector *data_ptr, std::vector *data_size, + std::vector *dtype, std::vector> *shape) { std::vector>> result_list; tensor_loader_->SearchTensors(name, &result_list); @@ -189,6 +189,28 @@ void DebugServices::read_nodes_tensors(std::vector name, std::vecto } } -TensorLoader *DebugServices::get_tensor_loader() const { return tensor_loader_; } +bool DebugServices::IsWatchPoint(std::string kernel_name, + std::unordered_map watchpoint_table) { + bool ret = false; + for (auto w_table_item : watchpoint_table) { + auto check_node_list = std::get<1>(w_table_item).check_node_list; + for (auto check_node : check_node_list) { + std::string w_name = std::get<0>(check_node); + bool w_type = std::get<1>(check_node); + if ((w_type == true && + ((kernel_name.find(w_name) != string::npos && kernel_name.rfind(w_name, 0) == 0) || w_name == "*")) || + (w_type == false && kernel_name == w_name)) { + ret = true; + return ret; + } + } + } + return ret; +} + +TensorLoader *DebugServices::tensor_loader() const { return tensor_loader_; } +std::unordered_map DebugServices::GetWatchpointTable() { + return watchpoint_table; +} } // namespace mindspore diff --git a/mindspore/ccsrc/debug/debug_services.h b/mindspore/ccsrc/debug/debug_services.h index b2fd41cd68..41400af1d5 100644 --- a/mindspore/ccsrc/debug/debug_services.h +++ b/mindspore/ccsrc/debug/debug_services.h @@ -37,22 +37,6 @@ class DebugServices { ~DebugServices(); - void add_watchpoint(unsigned int id, unsigned int watch_condition, - const std::vector> &check_node_list); - - void remove_watchpoint(unsigned int id); - - void check_watchpoints(std::vector *name, std::vector *slot, std::vector *data_ptr, - std::vector *data_size, std::vector *condition, - std::vector *wacthpoint_id); - - void read_nodes_tensors(std::vector name, std::vector *ret_name, - std::vector *data_ptr, std::vector *data_size, - std::vector *dtype, std::vector> *shape); - - TensorLoader *get_tensor_loader() const; - - private: typedef struct condition_no_param { bool enabled = false; } condition_no_param_t; @@ -84,6 +68,26 @@ class DebugServices { std::vector> check_node_list; } watchpoint_t; + void AddWatchpoint(unsigned int id, unsigned int watch_condition, + const std::vector> &check_node_list); + + void RemoveWatchpoint(unsigned int id); + + void CheckWatchpoints(std::vector *name, std::vector *slot, std::vector *data_ptr, + std::vector *data_size, std::vector *condition, + std::vector *wacthpoint_id); + + void ReadNodesTensors(std::vector name, std::vector *ret_name, + std::vector *data_ptr, std::vector *data_size, + std::vector *dtype, std::vector> *shape); + + bool IsWatchPoint(std::string kernel_name, std::unordered_map watchpoint_table); + + TensorLoader *tensor_loader() const; + + std::unordered_map GetWatchpointTable(); + + private: std::mutex lock_; std::unordered_map watchpoint_table; diff --git a/mindspore/ccsrc/debug/debugger/debug_graph.proto b/mindspore/ccsrc/debug/debugger/debug_graph.proto index 042360fac3..0930791ac0 100644 --- a/mindspore/ccsrc/debug/debugger/debug_graph.proto +++ b/mindspore/ccsrc/debug/debugger/debug_graph.proto @@ -313,4 +313,10 @@ message TensorProto { // If the tensor content transferring is finished. optional bool finished = 6; + + // The iteration of the tensor. Supported: "prev" or leave empty. + optional string iter = 7; + + // If the tensor name should be truncated. + optional bool truncate = 8; } \ No newline at end of file diff --git a/mindspore/ccsrc/debug/debugger/debugger.cc b/mindspore/ccsrc/debug/debugger/debugger.cc index ea147a929f..dd89e17e2d 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.cc +++ b/mindspore/ccsrc/debug/debugger/debugger.cc @@ -19,8 +19,8 @@ #include #include #include "debug/debugger/debugger.h" -#include "pipeline/pipeline.h" -#include "session/anf_runtime_algorithm.h" +#include "pipeline/jit/pipeline.h" +#include "backend/session/anf_runtime_algorithm.h" using debugger::EventReply; using debugger::GraphProto; @@ -43,7 +43,8 @@ Debugger::Debugger() device_id_(0), num_step_(0), debugger_enabled_(false), - is_dataset_graph_(false) {} + is_dataset_graph_(false), + partial_memory_(false) {} void Debugger::Init(const uint32_t device_id) { // access lock for public method @@ -57,6 +58,7 @@ void Debugger::EnableDebugger() { // reset some of the class members num_step_ = 0; debugger_enabled_ = false; + partial_memory_ = false; grpc_client_ = nullptr; debug_services_ = nullptr; @@ -72,7 +74,8 @@ void Debugger::EnableDebugger() { MS_LOG(WARNING) << "Not enabling debugger. Set environment variable ENABLE_MS_DEBUGGER=1 to enable debugger."; return; } - // configure host + + // configure grpc host const char *env_host_str = std::getenv("MS_DEBUGGER_HOST"); std::string host; if (env_host_str != nullptr) { @@ -82,7 +85,7 @@ void Debugger::EnableDebugger() { MS_LOG(WARNING) << "Environment variable MS_DEBUGGER_HOST doesn't exist. Using default debugger host: localhost"; host = "localhost"; } - // configure port + // configure grpc port const char *env_port_str = std::getenv("MS_DEBUGGER_PORT"); std::string port; if (env_port_str != nullptr) { @@ -93,6 +96,27 @@ void Debugger::EnableDebugger() { port = "50051"; } + // configure partial memory reuse + const char *env_partial_mem_str = std::getenv("MS_DEBUGGER_PARTIAL_MEM"); + if (env_partial_mem_str != nullptr) { + MS_LOG(INFO) << "Getenv MS_DEBUGGER_PARTIAL_MEM: " << env_partial_mem_str; + if (std::strcmp(env_partial_mem_str, "1") == 0) { + partial_memory_ = true; + } + } + // switch memory reuse on or off + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + context_ptr->set_enable_mem_reuse(partial_memory_); + // print some message about memory reuse to user + if (partial_memory_) { + MS_LOG(WARNING) << "Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first " + "step. 2. Tensor values are only available for nodes that are watched by any watchpoint."; + } else { + MS_LOG(WARNING) << "Memory Reuse is disabled. Set environment variable MS_DEBUGGER_PARTIAL_MEM=1 to reduce memory " + "usage for large models."; + } + // initialize grpc client grpc_client_ = std::make_unique(host, port); debug_services_ = std::make_unique(); @@ -106,6 +130,7 @@ void Debugger::Reset() { num_step_ = 0; debugger_enabled_ = false; is_dataset_graph_ = false; + partial_memory_ = false; graph_ptr_ = nullptr; grpc_client_ = nullptr; debug_services_ = nullptr; @@ -178,7 +203,7 @@ void Debugger::CheckDatasetGraph() { is_dataset_graph_ = false; } -GraphProto Debugger::GetGraphProto() { +GraphProto Debugger::GetGraphProto() const { // convert kernel graph to debugger modelproto ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_); return model.graph(); @@ -261,12 +286,9 @@ void Debugger::CommandLoop() { MS_LOG(INFO) << "node name: " << node.node_name(); MS_LOG(INFO) << "node type: " << node.node_type(); } - WatchCondition recieved_condition = GetWatchcondition(reply); - MS_LOG(INFO) << "condition: " << recieved_condition.condition(); - int32_t id = GetWatchpointID(reply); - MS_LOG(INFO) << "id: " << id; - bool delete_ = GetWatchpointDelete(reply); - MS_LOG(INFO) << "delete: " << delete_; + MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition(); + MS_LOG(INFO) << "id: " << GetWatchpointID(reply); + MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply); } MS_LOG(INFO) << "Setting watchpoint"; if (GetWatchpointDelete(reply)) { @@ -284,15 +306,20 @@ void Debugger::CommandLoop() { MS_LOG(INFO) << "tensor node name: " << tensor.node_name(); MS_LOG(INFO) << "tensor slot: " << tensor.slot(); MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha; + MS_LOG(INFO) << "tensor iter: " << tensor.iter(); + MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha; } } MS_LOG(INFO) << "Sending tensors"; std::list tensors = LoadTensors(GetTensors(reply)); { + // print view cmd reply for (auto tensor : tensors) { MS_LOG(INFO) << "tensor node name: " << tensor.node_name(); MS_LOG(INFO) << "tensor slot: " << tensor.slot(); MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha; + MS_LOG(INFO) << "tensor iter: " << tensor.iter(); + MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha; MS_LOG(INFO) << "tensor dims: "; for (auto dim : tensor.dims()) { MS_LOG(INFO) << dim << ","; @@ -309,81 +336,18 @@ void Debugger::CommandLoop() { } } -DebuggerCommand Debugger::GetCommand(const EventReply &reply) { - DebuggerCommand cmd = DebuggerCommand::kUnknownCMD; - switch (reply.cmd_case()) { - case debugger::EventReply::CmdCase::kExit: - cmd = DebuggerCommand::kExitCMD; - break; - case debugger::EventReply::CmdCase::kRunCmd: - cmd = DebuggerCommand::kRunCMD; - break; - case debugger::EventReply::CmdCase::kSetCmd: - cmd = DebuggerCommand::kSetCMD; - break; - case debugger::EventReply::CmdCase::kViewCmd: - cmd = DebuggerCommand::kViewCMD; - break; - default: - MS_LOG(ERROR) << "Error: UnknownCMD"; - break; - } - return cmd; -} - -ProtoVector Debugger::GetWatchnodes(const EventReply &reply) { - if (!reply.has_set_cmd()) { - MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector()."; - return ProtoVector(); - } - return reply.set_cmd().watch_nodes(); -} - -WatchCondition Debugger::GetWatchcondition(const EventReply &reply) { - if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) { - MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition()."; - return WatchCondition(); - } - return reply.set_cmd().watch_condition(); -} - -int32_t Debugger::GetWatchpointID(const EventReply &reply) { - if (!reply.has_set_cmd()) { - MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0."; - return 0; - } - return reply.set_cmd().id(); -} - -bool Debugger::GetWatchpointDelete(const EventReply &reply) { - if (!reply.has_set_cmd()) { - MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false."; - return false; - } - return reply.set_cmd().delete_(); -} - -ProtoVector Debugger::GetTensors(const EventReply &reply) { - if (!reply.has_view_cmd()) { - MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector()."; - return ProtoVector(); - } - return reply.view_cmd().tensors(); -} - void Debugger::SetWatchpoint(const ProtoVector &nodes, const WatchCondition &condition, const int32_t id) { std::vector> check_node_list; std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list), [](WatchNode node) -> std::tuple { return make_tuple(node.node_name(), node.node_type() == "scope"); }); - - debug_services_->add_watchpoint(id, condition.condition(), check_node_list); + debug_services_->AddWatchpoint(id, condition.condition(), check_node_list); } -void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->remove_watchpoint(id); } +void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->RemoveWatchpoint(id); } -std::list Debugger::LoadTensors(const ProtoVector &tensors) { +std::list Debugger::LoadTensors(const ProtoVector &tensors) const { std::vector name; std::vector ret_name; std::vector data_ptr; @@ -391,38 +355,42 @@ std::list Debugger::LoadTensors(const ProtoVector &ten std::vector dtype; std::vector> shape; - std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), - [](TensorProto tensor) -> std::string { return tensor.node_name() + ":" + tensor.slot(); }); + std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName); - debug_services_->read_nodes_tensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape); + // ret_name will contain tensor names that are found in TensorLoader + // items in ret_name will be in the same order with tensors if found + debug_services_->ReadNodesTensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape); std::list tensor_list; unsigned int result_index = 0; - TensorProto tensor_item; - for (auto tensor : tensors) { + TensorProto tensor_item; tensor_item.set_node_name(tensor.node_name()); tensor_item.set_slot(tensor.slot()); + tensor_item.set_iter(tensor.iter()); + tensor_item.set_truncate(tensor.truncate()); + tensor_item.clear_tensor_content(); + tensor_item.clear_data_type(); + tensor_item.clear_dims(); + // always set finished to true before big tensor splitting is supported tensor_item.set_finished(true); // return empty tensor if didn't find the requested tensor - if (result_index >= ret_name.size() || ret_name[result_index] != tensor.node_name() + ":" + tensor.slot()) { + if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) { tensor_list.push_back(tensor_item); continue; } tensor_item.set_tensor_content(data_ptr[result_index], data_size[result_index]); tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index])); - tensor_item.clear_dims(); for (auto &elem : shape[result_index]) { tensor_item.add_dims(elem); } + // add tensor to result list and increment result_index to check next item in ret_name tensor_list.push_back(tensor_item); - result_index++; } - return tensor_list; } @@ -432,7 +400,7 @@ void Debugger::Exit() { std::exit(EXIT_FAILURE); } -std::list Debugger::CheckWatchpoints() { +std::list Debugger::CheckWatchpoints() const { std::vector name; std::vector slot; std::vector data_ptr; @@ -440,33 +408,24 @@ std::list Debugger::CheckWatchpoints() { std::vector condition; std::vector watchpoint_id; - debug_services_->check_watchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id); - - std::list points; - + debug_services_->CheckWatchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id); + std::list hits; for (unsigned int i = 0; i < name.size(); i++) { - TensorProto *tensor_item; - tensor_item = new TensorProto(); + WatchpointHit hit; + hit.set_id(watchpoint_id[i]); + + // here TensorProto act as a tensor indicator, not sending tensor content + TensorProto *tensor_item = hit.mutable_tensor(); tensor_item->set_node_name(name[i]); tensor_item->set_slot(slot[i]); - tensor_item->set_tensor_content(data_ptr[i], data_size[i]); - - // finished in TensorProto will always be true before we implement big tensor splitting tensor_item->set_finished(true); - WatchCondition *condition_item; - condition_item = new WatchCondition(); + WatchCondition *condition_item = hit.mutable_watch_condition(); condition_item->set_condition(debugger::WatchCondition_Condition(condition[i])); - WatchpointHit point; - point.set_allocated_tensor(tensor_item); - point.set_allocated_watch_condition(condition_item); - point.set_id(watchpoint_id[i]); - - points.push_back(point); + hits.push_back(hit); } - - return points; + return hits; } void Debugger::SendWatchpointsAndSuspend(const std::list &points) { @@ -481,8 +440,83 @@ void Debugger::SendWatchpointsAndSuspend(const std::list &points) CommandLoop(); } -DebugServices *Debugger::get_debug_services() { return debug_services_.get(); } +DebugServices *Debugger::debug_services() const { return debug_services_.get(); } + +bool Debugger::debugger_enabled() const { return debugger_enabled_; } + +DebuggerCommand GetCommand(const EventReply &reply) { + DebuggerCommand cmd = DebuggerCommand::kUnknownCMD; + switch (reply.cmd_case()) { + case debugger::EventReply::CmdCase::kExit: + cmd = DebuggerCommand::kExitCMD; + break; + case debugger::EventReply::CmdCase::kRunCmd: + cmd = DebuggerCommand::kRunCMD; + break; + case debugger::EventReply::CmdCase::kSetCmd: + cmd = DebuggerCommand::kSetCMD; + break; + case debugger::EventReply::CmdCase::kViewCmd: + cmd = DebuggerCommand::kViewCMD; + break; + default: + MS_LOG(ERROR) << "Error: UnknownCMD"; + break; + } + return cmd; +} + +ProtoVector GetWatchnodes(const EventReply &reply) { + if (!reply.has_set_cmd()) { + MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector()."; + return ProtoVector(); + } + return reply.set_cmd().watch_nodes(); +} + +WatchCondition GetWatchcondition(const EventReply &reply) { + if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) { + MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition()."; + return WatchCondition(); + } + return reply.set_cmd().watch_condition(); +} + +int32_t GetWatchpointID(const EventReply &reply) { + if (!reply.has_set_cmd()) { + MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0."; + return 0; + } + return reply.set_cmd().id(); +} + +bool GetWatchpointDelete(const EventReply &reply) { + if (!reply.has_set_cmd()) { + MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false."; + return false; + } + return reply.set_cmd().delete_(); +} + +ProtoVector GetTensors(const EventReply &reply) { + if (!reply.has_view_cmd()) { + MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector()."; + return ProtoVector(); + } + return reply.view_cmd().tensors(); +} + +std::string GetTensorFullName(const TensorProto &tensor) { + string node_name = tensor.node_name(); + if (tensor.truncate()) { + // scopes in node name are seperated by '/' + // use the name without scope if truncate is true + std::size_t found = node_name.find_last_of("/"); + node_name = node_name.substr(found + 1); + } + return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter()); +} -bool Debugger::debugger_enabled() { return debugger_enabled_; } +bool Debugger::partial_memory() { return partial_memory_; } } // namespace mindspore diff --git a/mindspore/ccsrc/debug/debugger/debugger.h b/mindspore/ccsrc/debug/debugger/debugger.h index 6ce7d03625..5a3965d7cc 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.h +++ b/mindspore/ccsrc/debug/debugger/debugger.h @@ -19,7 +19,7 @@ #include #include #include -#include "session/kernel_graph.h" +#include "backend/session/kernel_graph.h" #include "debug/debugger/grpc_client.h" #include "debug/debug_services.h" @@ -72,9 +72,11 @@ class Debugger : public std::enable_shared_from_this { // suspend the execution after a debug_op void PostDebugOp(); - DebugServices *get_debug_services(); + DebugServices *debug_services() const; - bool debugger_enabled(); + bool debugger_enabled() const; + + bool partial_memory(); private: // private constructor for singleton @@ -92,7 +94,7 @@ class Debugger : public std::enable_shared_from_this { void CheckDatasetGraph(); // serialize graph and get proto - GraphProto GetGraphProto(); + GraphProto GetGraphProto() const; // send graph and enter command wait loop void SendGraphAndSuspend(const GraphProto &graph_proto); @@ -102,16 +104,6 @@ class Debugger : public std::enable_shared_from_this { // break if RunCMD void CommandLoop(); - // process reply and command type - DebuggerCommand GetCommand(const EventReply &reply); - - // parse other data out of EventReply - ProtoVector GetWatchnodes(const EventReply &reply); - WatchCondition GetWatchcondition(const EventReply &reply); - int32_t GetWatchpointID(const EventReply &reply); - bool GetWatchpointDelete(const EventReply &reply); - ProtoVector GetTensors(const EventReply &reply); - // set what nodes and conditions to watch void SetWatchpoint(const ProtoVector &nodes, const WatchCondition &condition, const int32_t id); @@ -119,14 +111,14 @@ class Debugger : public std::enable_shared_from_this { void RemoveWatchpoint(const int32_t id); // load tensor for view command - std::list LoadTensors(const ProtoVector &tensors); + std::list LoadTensors(const ProtoVector &tensors) const; // terminate training process void Exit(); // analyze tensors and check watchpoint conditions // return names of tensors and what condition they hit - std::list CheckWatchpoints(); + std::list CheckWatchpoints() const; // send watchpoints that hit and enter command wait loop void SendWatchpointsAndSuspend(const std::list &points); @@ -139,6 +131,7 @@ class Debugger : public std::enable_shared_from_this { int32_t num_step_; bool debugger_enabled_; bool is_dataset_graph_; + bool partial_memory_; std::mutex access_lock_; // singleton @@ -155,5 +148,18 @@ ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph); // for getting proto DataType from Type of Tensor DataType GetDebuggerNumberDataType(const TypePtr &type); +// process reply and command type +DebuggerCommand GetCommand(const EventReply &reply); + +// parse other data out of EventReply +ProtoVector GetWatchnodes(const EventReply &reply); +WatchCondition GetWatchcondition(const EventReply &reply); +int32_t GetWatchpointID(const EventReply &reply); +bool GetWatchpointDelete(const EventReply &reply); +ProtoVector GetTensors(const EventReply &reply); + +// get the full name of a tensor, which is the name used in TensorLoader +std::string GetTensorFullName(const TensorProto &tensor); + } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_ diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index 573452eac0..ff8132fb28 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -25,11 +25,11 @@ #include "pybind11/pybind11.h" #include "ir/meta_func_graph.h" -#include "ir/param_value_py.h" +#include "ir/param_value.h" #include "ir/primitive.h" #include "utils/graph_utils.h" #include "utils/utils.h" -#include "operator/composite/composite.h" +#include "frontend/operator/composite/composite.h" #include "ir/tensor.h" namespace py = pybind11; @@ -321,18 +321,9 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { buffer_ << parameter->ToString(); auto param = parameter->cast(); if (param->has_default()) { - auto param_value = std::dynamic_pointer_cast(param->default_param()); - auto py_p = param_value->value(); - if (py::hasattr(py_p, "default_input")) { - py_p = py_p.attr("default_input"); - std::vector shape; - if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) { - auto m_tensor = py_p.cast>(); - shape = m_tensor->shape(); - } else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) { - auto m_tensor = py_p.cast>(); - shape = m_tensor->shape(); - } + auto tensor = param->default_param()->value(); + if (tensor) { + auto &shape = tensor->shape(); std::ostringstream shape_str; std::copy(shape.begin(), shape.end(), std::ostream_iterator(shape_str, ",")); buffer_ << "[" << shape_str.str() << "]"; diff --git a/mindspore/ccsrc/debug/draw.h b/mindspore/ccsrc/debug/draw.h index 7804c6e94a..cb670fe0f6 100644 --- a/mindspore/ccsrc/debug/draw.h +++ b/mindspore/ccsrc/debug/draw.h @@ -22,7 +22,7 @@ #include #include "ir/anf.h" #include "utils/any.h" -#include "pipeline/parse/resolve.h" +#include "pipeline/jit/parse/resolve.h" namespace mindspore { namespace draw { diff --git a/mindspore/ccsrc/debug/dump_proto.cc b/mindspore/ccsrc/debug/dump_proto.cc index 99440537c7..35cdfafe26 100644 --- a/mindspore/ccsrc/debug/dump_proto.cc +++ b/mindspore/ccsrc/debug/dump_proto.cc @@ -453,6 +453,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr & GetOpNodeTypeAndAttrs(func_graph, op, node_proto); node_proto->set_name(std::to_string(apply_idx)); node_proto->set_scope(node->scope()->name()); + node_proto->set_full_name(node->fullname_with_scope()); // process OP inputs for (size_t i = 1; i < inputs.size(); ++i) { diff --git a/mindspore/ccsrc/debug/e2e_dump.cc b/mindspore/ccsrc/debug/e2e_dump.cc index 78a331fc27..9037a6d00b 100644 --- a/mindspore/ccsrc/debug/e2e_dump.cc +++ b/mindspore/ccsrc/debug/e2e_dump.cc @@ -17,12 +17,14 @@ #include #include #include +#include #include #include "utils/log_adapter.h" #include "utils/system/file_system.h" #include "utils/system/env.h" #include "utils/convert_utils.h" #include "utils/context/ms_context.h" +#include "debug/common.h" using json = nlohmann::json; @@ -158,100 +160,19 @@ bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) return false; } - std::string realpath; - bool ret = GetRealPath(filename, &realpath); - if (!ret) { + auto realpath = Common::GetRealPath(filename); + if (!realpath.has_value()) { MS_LOG(ERROR) << "Get real path failed."; return false; } std::ofstream fd; - fd.open(realpath, std::ios::binary | std::ios::out); + fd.open(realpath.value(), std::ios::binary | std::ios::out); if (!fd.is_open()) { - MS_LOG(ERROR) << "Open file " << realpath << " fail."; + MS_LOG(ERROR) << "Open file " << realpath.value() << " fail."; return false; } (void)fd.write(reinterpret_cast(data), SizeToLong(len)); fd.close(); return true; } - -bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) { - MS_EXCEPTION_IF_NULL(outpath); - auto path_split_pos = inpath.find_last_of('/'); - if (path_split_pos == std::string::npos) { - path_split_pos = inpath.find_last_of('\\'); - } - // get real path - char real_path[PATH_MAX] = {0}; - if (path_split_pos != std::string::npos) { - std::string prefix_path = inpath.substr(0, path_split_pos); - if (prefix_path.length() >= PATH_MAX) { - MS_LOG(ERROR) << "Prefix path is too longer!"; - return false; - } - std::string last_path = inpath.substr(path_split_pos, inpath.length() - path_split_pos); - auto ret = CreateNotExistDirs(prefix_path); - if (ret == false) { - MS_LOG(ERROR) << "CreateNotExistDirs Failed!"; - return false; - } - - if (nullptr == realpath(prefix_path.c_str(), real_path)) { - MS_LOG(ERROR) << "dir " << prefix_path << " does not exit."; - return false; - } - *outpath = std::string(real_path) + last_path; - } - - if (path_split_pos == std::string::npos) { - if (inpath.length() >= PATH_MAX) { - MS_LOG(ERROR) << "Prefix path is too longer!"; - return false; - } - if (nullptr == realpath(inpath.c_str(), real_path)) { - MS_LOG(ERROR) << "File " << inpath << " does not exit, it will be created."; - } - *outpath = std::string(real_path); - } - - return true; -} - -bool Dump::CreateNotExistDirs(const std::string &path) { - std::shared_ptr fs = system::Env::GetFileSystem(); - MS_EXCEPTION_IF_NULL(fs); - char temp_path[PATH_MAX] = {0}; - if (path.length() > PATH_MAX) { - MS_LOG(ERROR) << "Path lens is max than " << PATH_MAX; - return false; - } - for (uint32_t i = 0; i < path.length(); i++) { - temp_path[i] = path[i]; - if (temp_path[i] == '\\' || temp_path[i] == '/') { - if (i != 0) { - char tmp_char = temp_path[i]; - temp_path[i] = '\0'; - std::string path_handle(temp_path); - if (!fs->FileExist(temp_path)) { - MS_LOG(INFO) << "Dir " << path_handle << " does not exit, creating..."; - if (!fs->CreateDir(temp_path)) { - MS_LOG(ERROR) << "Create " << path_handle << " dir error"; - return false; - } - } - temp_path[i] = tmp_char; - } - } - } - - if (!fs->FileExist(path)) { - MS_LOG(INFO) << "Dir " << path << " does not exit, creating..."; - if (!fs->CreateDir(path)) { - MS_LOG(ERROR) << "Create " << path << " dir error"; - return false; - } - } - - return true; -} } // namespace mindspore diff --git a/mindspore/ccsrc/debug/e2e_dump.h b/mindspore/ccsrc/debug/e2e_dump.h index 4c3e8308da..acde1626cb 100644 --- a/mindspore/ccsrc/debug/e2e_dump.h +++ b/mindspore/ccsrc/debug/e2e_dump.h @@ -59,10 +59,6 @@ class Dump { uint32_t cur_iter_; std::vector dump_kernels_; - static bool GetRealPath(const std::string &inpath, std::string *outpath); - - static bool CreateNotExistDirs(const std::string &path); - private: bool ParseDumpConfig(const std::string &dump_config_file); bool IsConfigExist(const nlohmann::json &dumpSettings); diff --git a/mindspore/ccsrc/debug/info.h b/mindspore/ccsrc/debug/info.h index c09c6031b3..39475a4606 100644 --- a/mindspore/ccsrc/debug/info.h +++ b/mindspore/ccsrc/debug/info.h @@ -24,7 +24,7 @@ #include #include -#include "ir/base.h" +#include "base/base.h" #include "debug/trace_info.h" namespace mindspore { diff --git a/mindspore/ccsrc/debug/tensor_data.h b/mindspore/ccsrc/debug/tensor_data.h index 9704d69089..00af203208 100644 --- a/mindspore/ccsrc/debug/tensor_data.h +++ b/mindspore/ccsrc/debug/tensor_data.h @@ -51,25 +51,13 @@ class TensorData { int GetExecutionOrder() { return this->execution_order; } - int SetExecutionOrder(int execution_order) { - this->execution_order = execution_order; - return true; - } + void SetExecutionOrder(int execution_order) { this->execution_order = execution_order; } - int SetName(const std::string &name) { - this->name = name; - return true; - } + void SetName(const std::string &name) { this->name = name; } - bool SetTensor(mindspore::tensor::TensorPtr out_tensor) { - this->tensor_ptr = out_tensor; - return true; - } + void SetTensor(mindspore::tensor::TensorPtr out_tensor) { this->tensor_ptr = out_tensor; } - bool SetSlot(size_t slot) { - this->slot = slot; - return true; - } + void SetSlot(size_t slot) { this->slot = slot; } }; } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_TENSOR_DATA_H_ diff --git a/mindspore/ccsrc/debug/tensor_load.h b/mindspore/ccsrc/debug/tensor_load.h index 6c3ea67a78..ae0e89aae2 100644 --- a/mindspore/ccsrc/debug/tensor_load.h +++ b/mindspore/ccsrc/debug/tensor_load.h @@ -19,17 +19,28 @@ #include #include #include +#include #include #include +#include #include "debug/tensor_data.h" namespace mindspore { class TensorLoader { public: TensorLoader() : iter_num(-1) {} - ~TensorLoader() {} + ~TensorLoader() { EmptyTensor(); } - bool LoadNewTensor(std::shared_ptr tensor) { + bool LoadNewTensor(std::shared_ptr tensor, bool keep_prev) { + std::lock_guard lg(lock_); + if (keep_prev) { + // add prev step tensor into current step map with ":prev" suffix + auto handle = prev_tensor_list_map.extract(tensor->GetName()); + if (!handle.empty()) { + handle.key() = tensor->GetName() + ":prev"; + tensor_list_map.insert(std::move(handle)); + } + } tensor_list.push_back(tensor); tensor_list_map.insert({tensor->GetName(), tensor}); return true; @@ -52,18 +63,23 @@ class TensorLoader { } } - bool EmptyTensor() { - tensor_list_map.clear(); + void EmptyTensor() { + std::lock_guard lg(lock_); + prev_tensor_list_map.clear(); + tensor_list_map.swap(prev_tensor_list_map); tensor_list.clear(); - return true; } + void EmptyPrevTensor() { prev_tensor_list_map.clear(); } + void set_iter_num(uint32_t iter_num) { this->iter_num = iter_num; } private: std::vector> tensor_list; std::map> tensor_list_map; + std::map> prev_tensor_list_map; uint32_t iter_num; + std::mutex lock_; }; } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_TENSOR_LOAD_H_ diff --git a/mindspore/ccsrc/debug/trace.cc b/mindspore/ccsrc/debug/trace.cc index e12a7b1209..b8d3f0a7c7 100644 --- a/mindspore/ccsrc/debug/trace.cc +++ b/mindspore/ccsrc/debug/trace.cc @@ -29,10 +29,10 @@ #include "ir/meta_func_graph.h" #include "utils/graph_utils.h" -#include "operator/composite/composite.h" +#include "frontend/operator/composite/composite.h" #include "ir/tensor.h" #include "debug/anf_ir_utils.h" -#include "pipeline/static_analysis/evaluator.h" +#include "pipeline/jit/static_analysis/evaluator.h" namespace mindspore { // namespace to support debug trace infomation diff --git a/mindspore/ccsrc/debug/trace.h b/mindspore/ccsrc/debug/trace.h index 9583997e93..7cf45abe30 100644 --- a/mindspore/ccsrc/debug/trace.h +++ b/mindspore/ccsrc/debug/trace.h @@ -27,7 +27,7 @@ #include "debug/info.h" #include "ir/anf.h" #include "ir/func_graph.h" -#include "pipeline/static_analysis/static_analysis.h" +#include "pipeline/jit/static_analysis/static_analysis.h" #include "utils/any.h" namespace mindspore { diff --git a/mindspore/ccsrc/debug/trace_info.h b/mindspore/ccsrc/debug/trace_info.h index cf4f0c080a..62908cb449 100644 --- a/mindspore/ccsrc/debug/trace_info.h +++ b/mindspore/ccsrc/debug/trace_info.h @@ -24,7 +24,7 @@ #include #include -#include "ir/base.h" +#include "base/base.h" namespace mindspore { class TraceInfo; diff --git a/mindspore/ccsrc/device/CMakeLists.txt b/mindspore/ccsrc/device/CMakeLists.txt deleted file mode 100644 index 652c04d4cd..0000000000 --- a/mindspore/ccsrc/device/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" - "kernel_info.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" -) - -if (ENABLE_GPU) - list(APPEND DEVICE_SRC_LIST "gpu/distribution/collective_init.cc") -else () - list(APPEND DEVICE_SRC_LIST "gpu/distribution/collective_fake_init.cc") -endif () - -if (ENABLE_D) - file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc") -endif () - -if (ENABLE_CPU) - file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") - list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc") -endif () - -if (ENABLE_MPI) - # _ms_mpi - file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc") - set_property(SOURCE ${MPI_SRC_LIST} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - add_library(mpi_adapter SHARED ${MPI_SRC_LIST}) - target_link_libraries(mpi_adapter PRIVATE mindspore::ompi) - - set_property(SOURCE "gpu/mpi/mpi_initializer.cc" - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc") - target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi) -endif () - -# gpu -if (ENABLE_GPU) - file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu") - - set(GPU_QUEUE_SRCS "gpu/blocking_queue.cc" "gpu/gpu_buffer_mgr.cc") - set(GPU_COLLECTIVE_SRCS "gpu/distribution/collective_wrapper.cc" - "gpu/distribution/mpi_wrapper.cc" - "gpu/distribution/nccl_wrapper.cc") - - # gpu_queue - list(REMOVE_ITEM CUDA_SRC_LIST ${GPU_QUEUE_SRCS}) - set_property(SOURCE ${GPU_QUEUE_SRCS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - add_library(gpu_queue SHARED ${GPU_QUEUE_SRCS}) - target_link_libraries(gpu_queue ${CMAKE_THREAD_LIBS_INIT} ${CUDA_PATH}/lib64/libcudart.so) - - list(REMOVE_ITEM CUDA_SRC_LIST "gpu/mpi/mpi_initializer.cc" ${GPU_COLLECTIVE_SRCS}) - - if (ENABLE_MPI) - include(ExternalProject) - # gpu_collective - set_property(SOURCE ${GPU_COLLECTIVE_SRCS} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS}) - target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl) - endif () - - # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST}) -endif () - -set_property(SOURCE ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) -add_library(_mindspore_device_obj OBJECT ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST}) diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc deleted file mode 100644 index c4b8717fa5..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ /dev/null @@ -1,405 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/ascend/ascend_device_address.h" -#include -#include -#include -#include -#include "runtime/mem.h" -#include "device/kernel_runtime_manager.h" -#include "device/convert_tensor_utils.h" -#include "ir/dtype/type.h" -#include "ir/tensor.h" -#include "kernel/common_utils.h" -#include "utils/utils.h" -#include "common/utils.h" -#include "common/trans.h" -#ifdef ENABLE_DUMP_E2E -#include "debug/e2e_dump.h" -#endif -#ifdef ENABLE_DEBUGGER -#include "debug/tensor_load.h" -#endif - -namespace mindspore { -namespace device { -namespace ascend { -const int FLOAT_LEN = sizeof(float); -const int FLOAT16_LEN = 2; // sizeof(float16); -const std::set kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, - kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, - kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; - -void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) { - auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); - if (ret_rt_memcpy != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed"; - } -} - -bool FloatToHalfAndSyncHostToDevice(void *dst, size_t dst_size, const void *src, size_t src_size) { - auto elem_num = src_size / FLOAT_LEN; - if (elem_num != (dst_size / FLOAT16_LEN)) { - MS_EXCEPTION(ArgumentError) << "FloatToHalf failed. size not match src_size[" << src_size << "], dst_size[" - << dst_size << "]"; - } - std::vector half_data(elem_num); - FloatToHalf(half_data.data(), src, elem_num); - SyncMemory(dst, half_data.data(), dst_size, RT_MEMCPY_HOST_TO_DEVICE); - return true; -} - -bool Float64ToFloatAndSyncHostToDevice(void *dst, size_t dst_size, const void *src, size_t src_size) { - if (src_size / 2 != dst_size) { - MS_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]"; - } - size_t elem_num = dst_size / sizeof(float); - auto host_tmp = std::vector(elem_num); - DoubleToFloat(host_tmp.data(), src, elem_num); - SyncMemory(dst, host_tmp.data(), dst_size, RT_MEMCPY_HOST_TO_DEVICE); - return true; -} - -bool SyncDeviceToHostAndHalfToFloat(void *dst, size_t dst_size, const void *src, size_t src_size) { - auto elem_num = src_size / FLOAT16_LEN; - if (elem_num != (dst_size / FLOAT_LEN)) { - MS_EXCEPTION(ArgumentError) << "HalfToFloat failed. size not match src_size[" << src_size << "], dst_size[" - << dst_size << "]"; - } - std::vector half_data(elem_num); - SyncMemory(half_data.data(), src, src_size, RT_MEMCPY_DEVICE_TO_HOST); - HalfToFloat(dst, half_data.data(), elem_num); - return true; -} - -bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *src, size_t src_size) { - if (src_size != dst_size / 2) { - MS_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]"; - } - size_t elem_num = src_size / sizeof(float); - auto host_tmp = std::vector(elem_num); - SyncMemory(host_tmp.data(), src, src_size, RT_MEMCPY_DEVICE_TO_HOST); - FloatToDouble(dst, host_tmp.data(), elem_num); - return true; -} - -void AscendDeviceAddress::SyncStream() const { - MS_LOG(INFO) << "Start!"; - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() != kPynativeMode) { - MS_LOG(INFO) << "Finish!"; - return; - } - auto device_id = ms_context->device_id(); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); - MS_EXCEPTION_IF_NULL(runtime_instance); - auto ret = runtime_instance->SyncStream(); - if (!ret) { - MS_LOG(EXCEPTION) << "Sync stream error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t size, mindspore::TypeId type, - void *host_ptr) const { - MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) - << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; - SyncStream(); - bool sync_ok = false; - std::vector host_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - if (host_shape.empty()) { - host_shape.emplace_back(1); - } - if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { - if (type_id_ == type) { - SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); - sync_ok = true; - } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { - sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); - } else { - auto shape_size = trans::ShapeSize(host_shape); - auto host = std::vector(size_); - SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; - sync_ok = trans::TransDataType(type_args, host_ptr); - if (!sync_ok) { - MS_LOG(ERROR) << "trans data type failed."; - return false; - } - } - } else { - auto iter = kOpNeedTransFormat.find(format_); - if (iter != kOpNeedTransFormat.end()) { - sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr); - } else { - MS_LOG(INFO) << "Can not find format transfer for :" << format_; - } - } - if (!sync_ok) { - MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) - << ", host_type:" << TypeIdLabel(type); - return false; - } - return sync_ok; -} - -bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, - mindspore::TypeId type, void *host_ptr) const { - MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) - << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; - bool sync_ok = false; - auto host_tmp = std::vector(size_); - SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - std::vector host_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - std::vector device_shape; - if (host_shape.empty()) { - host_shape.emplace_back(1); - } - if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { - device_shape = trans::TransShapeToDevice(host_shape, format_); - } else { - if (host_shape_.empty()) { - host_shape = trans::PaddingShapeTo4d(host_shape); - } else { - host_shape.clear(); - (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize); - } - - device_shape = trans::TransShapeToDevice(host_shape, format_); - } - if (type_id_ != type) { - const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, - host_shape, device_shape, type_id_}; - auto host = std::vector(size_); - sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; - sync_ok = trans::TransDataType(type_args, host_ptr); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - } else { - const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, - host_shape, device_shape, type_id_}; - sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - } - return sync_ok; -} - -bool AscendDeviceAddress::SyncHostToDevice(const std::vector &shape, size_t size, mindspore::TypeId type, - const void *host_ptr) const { - MS_LOG(INFO) << "SyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) - << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; - SyncStream(); - bool sync_ok = false; - std::vector host_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - if (host_shape.empty()) { - host_shape.emplace_back(1); - } - if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { - if (type_id_ == type) { - SyncMemory(ptr_, host_ptr, size_, RT_MEMCPY_HOST_TO_DEVICE); - sync_ok = true; - } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { - sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); - } else { - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; - auto host_tmp = std::vector(size_); - sync_ok = trans::TransDataType(type_args, host_tmp.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans data type failed."; - return false; - } - SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); - } - } else { - auto iter = kOpNeedTransFormat.find(format_); - if (iter != kOpNeedTransFormat.end()) { - sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); - } else { - MS_LOG(INFO) << "Can not find format transfer for :" << format_; - } - } - if (!sync_ok) { - MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) - << ", host_type:" << TypeIdLabel(type); - return false; - } - return sync_ok; -} - -bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, - mindspore::TypeId type, const void *host_ptr) const { - bool sync_ok = false; - MS_LOG(INFO) << "ConvertFormatAndSyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) - << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; - std::vector host_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - if (host_shape.empty()) { - host_shape.emplace_back(1); - } - std::vector device_shape; - if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { - device_shape = trans::TransShapeToDevice(host_shape, format_); - } else { - host_shape = trans::PaddingShapeTo4d(host_shape); - device_shape = trans::TransShapeToDevice(host_shape, format_); - } - if (type_id_ != type) { - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; - auto host_tmp = std::vector(size_); - sync_ok = trans::TransDataType(type_args, host_tmp.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans datatype failed."; - return false; - } - const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, - host_shape, device_shape, type_id_}; - auto dst_tmp = std::vector(size_); - sync_ok = trans::TransFormat(format_args, dst_tmp.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); - } else { - const trans::FormatArgs format_args{host_ptr, size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; - auto host_tmp = std::vector(size_); - sync_ok = trans::TransFormat(format_args, host_tmp.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); - } - return sync_ok; -} - -AscendDeviceAddress::~AscendDeviceAddress() { - if (ptr_ == nullptr) { - return; - } - if (from_mem_pool_) { - AscendMemoryPool::GetInstance().FreeTensorMem(ptr_); - ptr_ = nullptr; - } -} - -#ifdef ENABLE_DUMP_E2E -bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &filepath, const std::string &host_fmt, - const std::vector &host_shape, TypeId host_type) const { - bool ret = false; - if (filepath.empty()) { - MS_LOG(ERROR) << "Dump file path is null!"; - return ret; - } - std::string shape = "shape"; - if (host_shape.size()) { - for (auto &value : host_shape) { - shape = shape + '_' + std::to_string(value); - } - } else { - shape = shape + "_0"; - } - std::string file_extension = ".bin"; - if (trans_flag) { - std::string path = filepath + '_' + shape + '_' + TypeIdLabel(host_type) + '_' + host_fmt + file_extension; - MS_LOG(INFO) << "E2E Dump path is " << path; - mindspore::tensor::TensorPtr out_tensor = std::make_shared(host_type, host_shape); - size_t host_size = out_tensor->data().nbytes(); - ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); - if (!ret) { - MS_LOG(ERROR) << "Copy device mem to host failed"; - return ret; - } - ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(), host_size); - } else { - auto host_tmp = std::vector(size_); - auto ret_rt_memcpy = rtMemcpy(host_tmp.data(), size_, ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - if (ret_rt_memcpy != RT_ERROR_NONE) { - MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; - } - std::string path = - filepath + '_' + shape + '_' + TypeIdToType(type_id_)->ToString() + '_' + format_ + file_extension; - MS_LOG(INFO) << "E2E Dump path is " << path; - ret = mindspore::Dump::DumpToFile(path, host_tmp.data(), size_); - } - - return ret; -} -#endif - -#ifdef ENABLE_DEBUGGER -bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order, - const std::string &host_fmt, const std::vector &host_shape, - TypeId host_type, size_t slot, Debugger *debugger) const { - bool ret = false; - - DebugServices *debug_services = debugger->get_debug_services(); - TensorLoader *tensor_loader = debug_services->get_tensor_loader(); - - if (trans_flag) { - MS_LOG(INFO) << "E2E tensor name is " << tensor_name; - mindspore::tensor::TensorPtr out_tensor = std::make_shared(host_type, host_shape); - size_t host_size = out_tensor->data().nbytes(); - ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); - if (!ret) { - MS_LOG(ERROR) << "Copy device mem to host failed"; - return ret; - } - auto tensor_data = std::make_shared(); - tensor_data->SetName(tensor_name); - tensor_data->SetExecutionOrder(execution_order); - tensor_data->SetTensor(out_tensor); - tensor_data->SetSlot(slot); - ret = tensor_loader->LoadNewTensor(tensor_data); - } else { - mindspore::tensor::TensorPtr out_tensor = std::make_shared(type_id_, host_shape); - size_t host_size = out_tensor->data().nbytes(); - auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST); - - auto tensor_data = std::make_shared(); - tensor_data->SetName(tensor_name); - tensor_data->SetExecutionOrder(execution_order); - tensor_data->SetTensor(out_tensor); - tensor_data->SetSlot(slot); - ret = tensor_loader->LoadNewTensor(tensor_data); - if (ret_rt_memcpy != RT_ERROR_NONE) { - MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; - } - MS_LOG(INFO) << "E2E tensor name is " << tensor_name; - } - return ret; -} -#endif -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.h b/mindspore/ccsrc/device/ascend/ascend_device_address.h deleted file mode 100644 index 16b9f7817a..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ - -#include -#include -#include -#include "device/device_address.h" -#include "device/ascend/ascend_memory_pool.h" -#include "ir/dtype.h" - -namespace mindspore { -#ifdef ENABLE_DEBUGGER -class Debugger; -#endif -namespace device { -namespace ascend { -class AscendDeviceAddress : public DeviceAddress { - public: - explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} - explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id) - : DeviceAddress(ptr, size, format, type_id) {} - ~AscendDeviceAddress() override; - bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; - DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } -#ifdef ENABLE_DUMP_E2E - bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, - const std::vector &host_shape, TypeId host_type) const; -#endif -#ifdef ENABLE_DEBUGGER - bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt, - const std::vector &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const; -#endif - - private: - bool SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const; - bool ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, TypeId type, - const void *host_ptr) const; - void SyncStream() const; -}; -using AscendDeviceAddressPtr = std::shared_ptr; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc deleted file mode 100644 index efdcb98755..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ /dev/null @@ -1,675 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#define PATH_MAX 0x3ffff -#include "device/ascend/ascend_kernel_runtime.h" -#include -#include -#include -#include -#include -#include -#include "device/ascend/ascend_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "utils/context/ms_context.h" -#include "utils/mpi/mpi_config.h" -#include "device/ascend/profiling/profiling_manager.h" -#include "hccl/hcom.h" -#include "common/trans.h" -#include "runtime/context.h" -#include "device/ascend/ascend_label_assign.h" -#include "device/ascend/ascend_stream_assign.h" -#include "device/ascend/ascend_memory_pool.h" -#include "framework/ge_runtime/model_runner.h" -#include "device/ascend/tasksink/task_generator.h" -#include "session/anf_runtime_algorithm.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "kernel/tbe/tbe_utils.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "pre_activate/mem_reuse/mem_reuse_checker.h" -#include "device/ascend/ascend_memory_manager.h" -#include "debug/tensor_load.h" - -using mindspore::device::ascend::ProfilingManager; -using mindspore::device::ascend::ProfilingUtils; -using mindspore::device::ascend::tasksink::TaskGenerator; -using mindspore::kernel::tbe::TbeUtils; -using std::vector; - -namespace mindspore { -namespace device { -namespace ascend { -static const size_t PRAMATER_OUTPUT_INDEX = 0; -namespace { -std::string GetRankId() { - std::string rank_id_str; -#ifdef ENABLE_MPI - auto mpi_config_ptr = MpiConfig::GetInstance(); - MS_EXCEPTION_IF_NULL(mpi_config_ptr); - if (mpi_config_ptr->enable_mpi()) { - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - int rank_id = mpi_instance->GetRankId(); - const char *offset = std::getenv("RANK_OFFSET"); - if (offset != nullptr) { - try { - int rank_offset = std::stoi(offset); - rank_id += rank_offset; - } catch (std::invalid_argument) { - MS_LOG(EXCEPTION) << "Call stoi invalid argument:" << offset; - } catch (std::out_of_range) { - MS_LOG(EXCEPTION) << "Call stoi out_of_range:" << offset; - } - } - rank_id_str = std::to_string(rank_id); - } else { - rank_id_str = std::getenv("RANK_ID"); - } -#else - rank_id_str = std::getenv("RANK_ID"); -#endif - if (rank_id_str.empty()) { - MS_LOG(ERROR) << "Get hccl rankid failed, please set env RANK_ID"; - } - return rank_id_str; -} -} // namespace - -AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } - -void AscendKernelRuntime::ClearGraphModelMap() { - for (auto &iter : graph_model_map_) { - MS_LOG(INFO) << "Ge UnloadModel " << iter.first; - auto ret = ge::model_runner::ModelRunner::Instance().UnloadModel(iter.first); - if (!ret) { - MS_LOG(ERROR) << "UnloadModel failed"; - } - } -} - -void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { - MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; - auto iter = graph_model_map_.find(graph_id); - if (iter == graph_model_map_.end()) { - MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; - return; - } - MS_LOG(DEBUG) << "Ge UnloadModel " << iter->first; - auto ret = ge::model_runner::ModelRunner::Instance().UnloadModel(iter->first); - if (!ret) { - MS_LOG(ERROR) << "UnloadModel failed"; - } - graph_model_map_.erase(iter); -} - -bool AscendKernelRuntime::NeedDestroyHccl() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->enable_hccl()) { - MS_LOG(INFO) << "Hccl is not enabled"; - return false; - } - // Note: make sure hcom_connectivity_detection api never be used. - return true; -} - -void AscendKernelRuntime::ReleaseDeviceRes() { - MS_LOG(INFO) << "Ascend finalize start"; - // release ge runtime - ClearGraphModelMap(); - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - auto ret = rtSetDevice(context_ptr->device_id()); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast(ret) << "]"; - } - - if (mem_manager_ != nullptr) { - mem_manager_->FreeDeviceMemory(); - } - - (void)DestroyHccl(); - (void)ResetDevice(); - (void)ProfilingManager::GetInstance().StopProfiling(); - MS_LOG(INFO) << "Ascend finalize end"; -} - -bool AscendKernelRuntime::Init() { - if (initialized_) { - return true; - } - bool ret = false; -#ifdef ENABLE_DUMP_E2E - ret = SetDumpConf(); - if (!ret) { - MS_LOG(INFO) << "No dump conf to set!"; - } -#endif - - // Start up profiling before rtSetDevice - ret = ProfilingManager::GetInstance().StartupProfiling(device_id_); - if (!ret) { - MS_EXCEPTION(DeviceProcessError) << "StartupProfiling failed."; - } - - ret = InitDevice(); - if (!ret) { - return ret; - } - mem_manager_ = std::make_shared(); - MS_EXCEPTION_IF_NULL(mem_manager_); - mem_manager_->MallocDeviceMemory(); - - initialized_ = true; - return ret; -} - -#ifdef ENABLE_DUMP_E2E -namespace { -void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(dump_conf); - bool trans_flag = dump_conf->trans_flag(); - const auto &apply_kernels = graph->execution_order(); - for (const auto &node : apply_kernels) { - MS_EXCEPTION_IF_NULL(node); - auto node_name = AnfAlgo::GetCNodeName(node); - std::string kernel_name = node->fullname_with_scope(); - if (!dump_conf->IsKernelNeedDump(kernel_name)) { - continue; - } - const std::string strsrc = "/"; - const std::string strdst = "--"; - std::string::size_type pos = 0; - std::string::size_type srclen = strsrc.size(); - std::string::size_type dstlen = strdst.size(); - while ((pos = kernel_name.find(strsrc, pos)) != std::string::npos) { - kernel_name.replace(pos, srclen, strdst); - pos += dstlen; - } - auto output_size = AnfAlgo::GetOutputTensorNum(node); - for (size_t j = 0; j < output_size; ++j) { - auto addr = AnfAlgo::GetOutputAddr(node, j); - std::vector int_shapes; - if (trans_flag) { - int_shapes = trans::GetRuntimePaddingShape(node, j); - } else { - auto shape = AnfAlgo::GetOutputDeviceShape(node, j); - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), - [](size_t inner_item) { return SizeToInt(inner_item); }); - } - auto type = AnfAlgo::GetOutputInferDataType(node, j); - auto format = kOpFormat_DEFAULT; - string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j); - auto ascend_addr = dynamic_cast(addr); - auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type); - if (!ret) { - MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath - << ", host_format:" << format << ".!"; - } - } - } -} - -void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(dump_conf); - bool trans_flag = dump_conf->trans_flag(); - const auto ¶meters = graph->inputs(); - for (auto &item : parameters) { - if (!item->isa()) { - continue; - } - std::string parameter_name = item->fullname_with_scope(); - if (!dump_conf->IsKernelNeedDump(parameter_name)) { - continue; - } - auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); - std::vector int_shapes; - if (trans_flag) { - int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); - } else { - auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX); - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), - [](size_t inner_item) { return SizeToInt(inner_item); }); - } - auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); - auto format = kOpFormat_DEFAULT; - string filepath = dump_path + '/' + parameter_name + '_' + "output_0"; - auto ascend_addr = dynamic_cast(addr); - auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type); - if (!ret) { - MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath - << ", host_format:" << format << ".!"; - } - } -} -} // namespace -#endif - -bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); -#ifdef ENABLE_DUMP_E2E - MS_LOG(INFO) << "Start dump step"; - DumpConfPtr dump_conf = GetDumpConf(); - MS_EXCEPTION_IF_NULL(dump_conf); - dump_conf->UpdataCurIter(); - bool dump_flag = dump_conf->dump_enable(); - if (!dump_flag) { - MS_LOG(INFO) << "Dump flag is disable, pass dump step"; - return true; - } - uint32_t cur_iter = dump_conf->cur_iter(); - if (dump_conf->dump_iter() != 0) { - if (cur_iter != dump_conf->dump_iter()) { - return true; - } - } - MS_LOG(INFO) << "Cur iter is " << cur_iter; - std::string net_name = dump_conf->dump_net_name(); - std::string iterator = to_string(cur_iter); - std::string dump_path = dump_conf->dump_path(); - if (dump_path.back() == '/') { - dump_path = dump_path + net_name + '/' + iterator; - } else { - dump_path = dump_path + '/' + net_name + '/' + iterator; - } - // dump output - DumpOutput(graph, dump_path, dump_conf); - // dump parameters - DumpParameters(graph, dump_path, dump_conf); -#endif - return true; -} - -#ifdef ENABLE_DEBUGGER -namespace { -void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) { - MS_EXCEPTION_IF_NULL(graph); - bool trans_flag = false; - const auto &apply_kernels = graph->execution_order(); - // for kernels, execution order starts from 1 - int exec_order = 1; - for (const auto &node : apply_kernels) { - MS_EXCEPTION_IF_NULL(node); - auto node_name = AnfAlgo::GetCNodeName(node); - std::string kernel_name = node->fullname_with_scope(); - auto output_size = AnfAlgo::GetOutputTensorNum(node); - for (size_t j = 0; j < output_size; ++j) { - auto addr = AnfAlgo::GetOutputAddr(node, j); - auto type = AnfAlgo::GetOutputInferDataType(node, j); - auto format = kOpFormat_DEFAULT; - string tensor_name = kernel_name + ':' + std::to_string(j); - auto ascend_addr = dynamic_cast(addr); - std::vector int_shapes; - if (trans_flag) { - int_shapes = trans::GetRuntimePaddingShape(node, j); - } else { - auto shape = AnfAlgo::GetOutputDeviceShape(node, j); - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), - [](size_t inner_item) { return SizeToInt(inner_item); }); - } - auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger); - if (!ret) { - MS_LOG(ERROR) << "LoadMemToHost: flag:" << trans_flag << ", tensor_name:" << tensor_name - << ", host_format:" << format << ".!"; - } - } - exec_order = exec_order + 1; - } -} - -void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) { - MS_EXCEPTION_IF_NULL(graph); - bool trans_flag = false; - const auto ¶meters = graph->inputs(); - // for parameters, set its execution order to be 0; - int exec_order = 0; - for (auto &item : parameters) { - if (!item->isa()) { - continue; - } - std::string parameter_name = item->fullname_with_scope(); - auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); - auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); - auto format = kOpFormat_DEFAULT; - string tensor_name = parameter_name + ':' + "0"; - auto ascend_addr = dynamic_cast(addr); - std::vector int_shapes; - if (trans_flag) { - int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); - } else { - auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX); - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), - [](size_t inner_item) { return SizeToInt(inner_item); }); - } - auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger); - if (!ret) { - MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", path:" << tensor_name - << ", host_format:" << format << ".!"; - } - } -} -} // namespace -#endif - -bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { - MS_EXCEPTION_IF_NULL(graph); -#ifdef ENABLE_DEBUGGER - MS_LOG(INFO) << "Start load step"; - uint32_t cur_iter = 0; - MS_LOG(INFO) << "Cur iter is " << cur_iter; - // load output - LoadOutput(graph, debugger); - // load parameters - LoadParameters(graph, debugger); -#endif - return true; -} - -bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { - if (AnfAlgo::OutputAddrExist(kernel, index)) { - auto address = AnfAlgo::GetOutputAddr(kernel, index); - MS_EXCEPTION_IF_NULL(address); - return address->DeviceType() == DeviceAddressType::kAscend; - } - return false; -} - -DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) { - return std::make_shared(device_ptr, device_size, format, type_id); -} - -bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { - if (graph == nullptr) { - MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; - } - MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool is_task_sink = context_ptr->enable_task_sink(); - if (!is_task_sink) { - return true; - } -#ifdef MEM_REUSE_DEBUG - if (!context_ptr->enable_mem_reuse()) { - // Get normal graph ir for memreuse - mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); - } -#endif - vector> task_info_list; - auto anf_node_list = graph->execution_order(); - TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id()); - // Store the task_info_list - auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list)); - if (!insert_ret.second) { - MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; - } - // Graph may have no compute node, such TensorAddGrad. - if (task_info_list.empty()) { - MS_LOG(WARNING) << "Graph " << graph->graph_id() << " have no compute node"; - return true; - } - AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); - // the streams' flag not HEAD_STREAM - std::vector wait_active_stream_list; - assign_instance.GetWaitStreams(&wait_active_stream_list); - std::vector force_copy_stream_list; - assign_instance.GetHcomStreams(&force_copy_stream_list); - MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() - << ", total event num:" << resource_manager.get_cur_event_num() - << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) - << ", wait_active_stream_list size:" << wait_active_stream_list.size() - << ", force_copy_stream_list size:" << force_copy_stream_list.size(); - std::vector> empty_list; - std::shared_ptr model = std::make_shared( - task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), - resource_manager.get_cur_event_num(), 0); - auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); - if (!ret.second) { - MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; - } - MS_LOG(INFO) << "TaskGenerator GetTaskInfo end..."; - return true; -} - -bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { - if (graph == nullptr) { - MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; - } - MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool is_task_sink = context_ptr->enable_task_sink(); - if (!is_task_sink) { - return true; - } - - if (GraphWithEmptyTaskList(graph)) { - MS_LOG(WARNING) << "LoadTask end, task list is empty"; - return true; - } - - auto model_iter = graph_model_map_.find(graph->graph_id()); - if (model_iter == graph_model_map_.end()) { - MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph LoadTask without GenTask."; - return false; - } - - std::shared_ptr listener; - MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first; - bool status = ge::model_runner::ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, - model_iter->second, listener); - if (!status) { - MS_LOG(EXCEPTION) << "Load Task Failed"; - } - if (ProfilingManager::GetInstance().IsProfiling()) { - auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(model_iter->first); - auto stream_ids = ge::model_runner::ModelRunner::Instance().GetStreamIdList(model_iter->first); - ProfilingUtils::ReportProfilingData(task_ids, stream_ids, NOT_NULL(graph)); - } - return true; -} - -void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) { - auto task_ids = ge::model_runner::ModelRunner::Instance().GetTaskIdList(graph_id); - auto graph_task_names = ProfilingUtils::graph_kernel_name(); - auto iter = graph_task_names.find(graph_id); - if (iter != graph_task_names.end()) { - const auto &task_names = iter->second; - if (task_ids.size() != task_names.size()) { - MS_LOG(WARNING) << "Task_ids and task_names size not match"; - return; - } - for (size_t i = 0; i < task_ids.size(); ++i) { - MS_LOG(INFO) << "Task_id:" << task_ids[i] << " task_name:" << task_names[i]; - } - } -} - -bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id(); - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - ge::InputData input_tensors = ge::InputData(); - ge::OutputData *output_tensors = nullptr; - if (GraphWithEmptyTaskList(graph)) { - MS_LOG(WARNING) << "RunTask end, no task info found"; - return true; - } - - if (!CheckGraphIdValid(graph->graph_id())) { - MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph RunTask without GenTask."; - return false; - } - - bool status = ge::model_runner::ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors); - if (!status) { - MS_LOG(ERROR) << "Run task failed"; - DebugTaskIdName(graph->graph_id()); - return false; - } - return true; -} - -bool AscendKernelRuntime::SyncStream() { - if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { // o for switch stream - MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; - return false; - } - return true; -} - -bool AscendKernelRuntime::InitDevice() { - int device_count = 0; - auto ret = rtGetDeviceCount(&device_count); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast(ret) << "]"; - } - - ret = rtSetDevice(device_id_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast(ret) << "]"; - } - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr == nullptr) { - MS_LOG(ERROR) << "Get MsContext instance failed"; - return false; - } - if (context_ptr->enable_hccl()) { - if (!HcclInit()) { - MS_LOG(ERROR) << "HcclInit init failed"; - return false; - } - } - - ret = rtCtxCreate(&rt_context_, 0, device_id_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast(ret) << "]"; - } - - ret = rtCtxSetCurrent(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; - } - - ret = rtStreamCreate(&stream_, 0); - if (ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]"; - } - - return true; -} - -bool AscendKernelRuntime::ResetDevice() { - auto ret = rtCtxSetCurrent(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Call rtCtxSetCurrent failed"; - return false; - } - - if (stream_ != nullptr) { - ret = rtStreamDestroy(stream_); - if (ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]"; - } - stream_ = nullptr; - } - - if (rt_context_ != nullptr) { - ret = rtCtxDestroy(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]"; - } - rt_context_ = nullptr; - } - return true; -} - -bool AscendKernelRuntime::HcclInit() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->IsTsdOpened()) { - MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open"; - } - MS_LOG(INFO) << "Do hcom init"; - auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH"); - if (config_path_str == nullptr) { - config_path_str = std::getenv("RANK_TABLE_FILE"); - if (config_path_str == nullptr) { - MS_LOG(ERROR) << "Get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE"; - return false; - } - } - if (strlen(config_path_str) > PATH_MAX) { - MS_LOG(ERROR) << "File path oversize"; - return false; - } - std::string rank_id_str = GetRankId(); - auto full_path = realpath(config_path_str, nullptr); - if (full_path == nullptr) { - MS_LOG(ERROR) << "File path " << config_path_str << " does not exist"; - return false; - } - MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; - hcclResult_t res = hcom_init(full_path, rank_id_str.c_str()); - free(full_path); - if (res != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast(res); - return false; - } - return true; -} - -bool AscendKernelRuntime::DestroyHccl() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!NeedDestroyHccl()) { - MS_LOG(INFO) << "Hccl is not enable, no need to close."; - return true; - } - hcclResult_t res = hcom_destroy(); - if (res != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Hccl destroy failed"; - return false; - } - MS_LOG(INFO) << "Hccl destroy successful, status = " << res << "."; - context_ptr->set_enable_hccl(false); - return true; -} - -bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const { - auto iter = task_map_.find(graph->graph_id()); - if (iter == task_map_.end()) { - MS_LOG(EXCEPTION) << "Unknown graph ptr"; - } - return iter->second.empty(); -} - -bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const { - return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h deleted file mode 100644 index 69ba8b295a..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "runtime/context.h" -#include "framework/ge_runtime/davinci_model.h" -#include "device/kernel_runtime_manager.h" -#include "session/session_basic.h" - -using ge::model_runner::TaskInfo; -using std::unordered_map; -using std::vector; -namespace mindspore { -namespace device { -namespace ascend { -class AscendKernelRuntime : public KernelRuntime { - public: - AscendKernelRuntime() = default; - ~AscendKernelRuntime() override; - bool Init() override; - bool DumpData(session::KernelGraph *graph) override; - bool LoadData(session::KernelGraph *graph, Debugger *debugger) override; - bool GenTask(const session::KernelGraph *graph) override; - bool RunTask(const session::KernelGraph *graph) override; - bool LoadTask(const session::KernelGraph *graph) override; - void ClearGraphRuntimeResource(uint32_t graph_id) override; - bool SyncStream() override; - - protected: - DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) override; - bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override; - - private: - bool InitDevice(); - bool ResetDevice(); - bool HcclInit(); - bool NeedDestroyHccl(); - bool DestroyHccl(); - - void ClearGraphModelMap(); - void ReleaseDeviceRes() override; - bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const; - bool CheckGraphIdValid(GraphId graph_id) const; - static void DebugTaskIdName(GraphId graph_id); - - rtContext_t rt_context_{nullptr}; - bool initialized_{false}; - unordered_map>> task_map_; - unordered_map> graph_model_map_; -}; - -MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime); -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc deleted file mode 100644 index 2db81a1725..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "device/ascend/ascend_label_assign.h" -#include "session/anf_runtime_algorithm.h" - -static constexpr uint32_t kLabelGotoLabelId = 1; -static constexpr uint32_t kLabelSwitchLabelId = 2; - -namespace mindspore { -namespace device { -namespace ascend { -static void UpdateLabelGoto(NotNull node) { - if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { - return; - } - if (node->size() <= kLabelGotoLabelId) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); - } - - auto input = node->input(kLabelGotoLabelId); - uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); - AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(goto_label_id), node.get()); - MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; - node->set_inputs({node->input(0)}); -} - -static void UpdateLabelSwitch(NotNull node) { - if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { - return; - } - if (node->size() <= kLabelGotoLabelId) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); - } - std::vector label_list; - for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) { - auto input = node->input(i); - if (!input->isa() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) { - break; - } - - uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); - label_list.push_back(goto_label_id); - MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; - } - AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue>(label_list), node.get()); - node->set_inputs({node->input(kAnfPrimitiveIndex), node->input(kFirstDataInputIndex)}); -} - -static void AssignLabelForLabelSet(NotNull> graph, NotNull label_id, - NotNull> *> memo) { - if (memo->find(graph.get()) != memo->end()) { - return; - } - memo->insert(graph.get()); - - MS_LOG(INFO) << "Assign label for " << graph->ToString(); - graph->SetExecOrderByDefault(); - auto nodes = graph->execution_order(); - - for (auto &node : nodes) { - if (!node->isa()) { - continue; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::string node_name = AnfAlgo::GetCNodeName(node); - if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { - AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(*label_id), node); - MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id; - ++(*label_id); - } - } - - for (auto &cg : graph->child_graph_order()) { - AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo); - } -} - -static void AssignLabelForGotoSwitch(NotNull> graph, - NotNull> *> memo) { - if (memo->find(graph.get()) != memo->end()) { - return; - } - memo->insert(graph.get()); - - MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); - - auto nodes = graph->execution_order(); - auto end_goto = graph->get_end_goto(); - if (end_goto != nullptr) { - nodes.push_back(end_goto); - } - for (auto &node : nodes) { - if (!node->isa()) { - continue; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::string node_name = AnfAlgo::GetCNodeName(node); - if (node_name == kLabelGotoOpName) { - UpdateLabelGoto(NOT_NULL(cnode)); - cnode->set_abstract(nullptr); - } - - if (node_name == kLabelSwitchOpName) { - UpdateLabelSwitch(NOT_NULL(cnode)); - } - } - for (auto &cg : graph->child_graph_order()) { - AssignLabelForGotoSwitch(NOT_NULL(cg), memo); - } - graph->SetExecOrderByDefault(); -} - -void AscendLabelAssign::AssignLabel(NotNull> graph) { - MS_LOG(INFO) << "Assign label start."; - std::set> memo; - uint32_t label_id = 0; - AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo)); - memo.clear(); - { - std::lock_guard lock(label_num_mutex_); - label_num_[graph.get().get()] = label_id; - } - AssignLabelForGotoSwitch(graph, NOT_NULL(&memo)); - MS_LOG(INFO) << "Assign label end."; -} - -uint32_t AscendLabelAssign::GetLabelNum(NotNull graph) { - std::lock_guard lock(label_num_mutex_); - auto iter = label_num_.find(graph.get()); - if (iter == label_num_.end()) { - MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 0."; - return 0; - } - return iter->second; -} - -uint32_t AscendLabelAssign::GetLabelNum(NotNull> graph) { - return GetLabelNum(NOT_NULL(graph.get().get())); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.h b/mindspore/ccsrc/device/ascend/ascend_label_assign.h deleted file mode 100644 index 98055576eb..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ - -#include -#include -#include "session/kernel_graph.h" -#include "utils/contract.h" - -namespace mindspore { -namespace device { -namespace ascend { -class AscendLabelAssign { - public: - static AscendLabelAssign &GetInstance() { - static AscendLabelAssign instance; // Guaranteed to be destroyed. - return instance; - } - - AscendLabelAssign(const AscendLabelAssign &) = delete; - AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; - - void AssignLabel(NotNull> graph); - uint32_t GetLabelNum(NotNull graph); - uint32_t GetLabelNum(NotNull> graph); - - private: - AscendLabelAssign() = default; - ~AscendLabelAssign() = default; - - std::map label_num_; - std::mutex label_num_mutex_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc deleted file mode 100644 index 42c611c3af..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "device/ascend/ascend_memory_manager.h" -#include "device/ascend/ascend_memory_pool.h" -#include "utils/context/ms_context.h" -#include "runtime/mem.h" -namespace mindspore { -namespace device { -namespace ascend { -constexpr uint64_t kAscendDeviceMemGB = 26; -constexpr uint64_t kAscendMemPoolGB = 4; -constexpr uint64_t kMemSizeGB = 30; -constexpr uint64_t kMaxMemSizeGB = 30; -constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB); -constexpr uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << kMemSizeGB); - -void AscendMemoryManager::MallocDeviceMemory() { - auto context_mem = GetDeviceMemSizeFromContext(); - device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem; - static_mem_offset_ = device_mem_size_; - auto ret = rtMalloc(reinterpret_cast(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]"; - } - - if (context_mem == 0) { - device_mem_pool_size_ = kAscendMemPoolSize; - ret = rtMalloc(reinterpret_cast(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]"; - } - AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_); - AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_); - } -} - -uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - auto variable_memory_max_size = context->variable_memory_max_size(); - if (variable_memory_max_size == "0") { - return 0; - } - MS_LOG(INFO) << "context variable_memory_max_size:" << variable_memory_max_size; - auto pos = variable_memory_max_size.find('*'); - if (pos == std::string::npos) { - MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size"; - } - auto gb_str = variable_memory_max_size.substr(0, pos); - auto gb_var = std::stoull(gb_str); - MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var; - if (gb_var > kMaxMemSizeGB || gb_var == 0) { - MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB"; - } - return gb_var << kMemSizeGB; -} - -void AscendMemoryManager::FreeDeviceMemory() { - if (device_mem_base_ != nullptr) { - auto ret = rtFree(device_mem_base_); - if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "rtFree mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]"; - } - device_mem_base_ = nullptr; - } - if (device_mem_pool_base_ != nullptr) { - auto ret = rtFree(device_mem_pool_base_); - if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "rtFree mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]"; - } - device_mem_pool_base_ = nullptr; - } -} - -void *AscendMemoryManager::MallocMemFromMemPool(size_t size) { - return AscendMemoryPool::GetInstance().AllocTensorMem(size); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/device/ascend/ascend_memory_manager.h deleted file mode 100644 index 7fdd8f553e..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ -#include "device/memory_manager.h" -namespace mindspore { -namespace device { -namespace ascend { -class AscendMemoryManager : public MemoryManager { - public: - AscendMemoryManager() = default; - ~AscendMemoryManager() override = default; - - void MallocDeviceMemory() override; - void FreeDeviceMemory() override; - void *MallocMemFromMemPool(size_t size) override; - - private: - uint8_t *device_mem_pool_base_{nullptr}; - uint64_t device_mem_pool_size_{0}; - - uint64_t GetDeviceMemSizeFromContext(); -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc deleted file mode 100644 index 69c6dca576..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/ascend_memory_pool.h" -#include "device/ascend/ascend_kernel_runtime.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -namespace ascend { -size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { - if (has_malloc_) { - MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; - } - if (size == 0 || size > free_mem_size_) { - MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory !"; - } - *addr = device_mem_pool_base_; - if (*addr == nullptr) { - MS_LOG(EXCEPTION) << "Device memory pool base is nullptr, failed to alloc memory pool memory!"; - } - has_malloc_ = true; - free_mem_size_ -= size; - return size; -} - -bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { - MS_EXCEPTION_IF_NULL(addr); - has_malloc_ = false; - free_mem_size_ = total_mem_size_; - return true; -} - -size_t AscendMemoryPool::AlignMemorySize(size_t size) const { - if (size == 0) { - return DYNAMIC_MEM_ALIGN_SIZE; - } - return ((size + DYNAMIC_MEM_ALIGN_SIZE + 31) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE; -} - -size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } - -void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { - MS_EXCEPTION_IF_NULL(device_mem_pool_base); - device_mem_pool_base_ = device_mem_pool_base; -} - -size_t AscendMemoryPool::free_mem_size() { return free_mem_size_; } - -size_t AscendMemoryPool::total_mem_size() { return total_mem_size_; } -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h deleted file mode 100644 index 7fa3ebc23e..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ - -#include -#include "pre_activate/mem_reuse/mem_dynamic_allocator.h" - -namespace mindspore { -namespace device { -namespace ascend { -class AscendMemoryPool : public DynamicMemPoolBestFit { - public: - ~AscendMemoryPool() override = default; - AscendMemoryPool(const AscendMemoryPool &) = delete; - AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; - - size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; - bool FreeDeviceMem(const DeviceMemPtr &addr) override; - void set_device_mem_pool_base(uint8_t *device_mem_pool_base); - void set_device_mem_pool_size(uint64_t device_mem_pool_size) { - device_mem_pool_size_ = device_mem_pool_size; - free_mem_size_ = device_mem_pool_size_; - total_mem_size_ = free_mem_size_; - } - size_t free_mem_size() override; - size_t total_mem_size() override; - - static AscendMemoryPool &GetInstance() { - static AscendMemoryPool instance; - return instance; - } - - protected: - // The real size by memory alloc aligned. - size_t AlignMemorySize(size_t size) const override; - // Get the minimum memory unit size using for dynamic extend. - size_t mem_alloc_unit_size() const override; - - private: - AscendMemoryPool() = default; - bool has_malloc_{false}; - uint8_t *device_mem_pool_base_{nullptr}; - uint64_t device_mem_pool_size_{0}; - size_t free_mem_size_{0}; - size_t total_mem_size_{0}; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc deleted file mode 100644 index 736d6203e9..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ /dev/null @@ -1,916 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/ascend_stream_assign.h" - -#include -#include - -#include "ir/manager.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_adjust.h" -#include "predict/generator/utils/ir_model_util.h" -#include "pre_activate/common/helper.h" -#include "utils/utils.h" - -namespace mindspore { -namespace device { -namespace ascend { -const uint32_t kHcomMaxTask = 5; -const uint32_t kCommonMaxTask = 350; - -void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) { - if (IsTaskSink()) { - Reset(); - ReorderIndependentOrders(graph_ptr); - AssignAllNodesStream(graph_ptr); - UpdateAtomicAddrCleanStreamId(graph_ptr); - InsertStreamActive(graph_ptr); - InsertEventForHcomParallel(graph_ptr); - InsertEventForIndependentParallel(graph_ptr); - GetNeedActiveStreams(graph_ptr); - graph_ptr->PrintGraphExecuteOrder(); - CheckResourceAssign(graph_ptr); - MS_LOG(INFO) << "After finish stream assign"; - - // Get info for D Model - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num()); - generator::IRModelUtil::GetInstance().set_stream_num(resource_manager.get_cur_stream_num()); - // Init to 1,temporarily - generator::IRModelUtil::GetInstance().set_batch_num(1); - } -} - -// section 1 -void AscendStreamAssign::ReorderIndependentOrders(const NotNull &graph_ptr) { - std::vector exe_orders; - std::vector independents; - std::vector others; - - auto cnode_ptr_list = graph_ptr->execution_order(); - MS_LOG(INFO) << "Before reorder, graph orders size:" << cnode_ptr_list.size(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - auto cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (IsIndependentNode(cur_cnode_ptr)) { - independents.emplace_back(cur_cnode_ptr); - } else { - others.emplace_back(cur_cnode_ptr); - } - } - - if (others.empty() || independents.empty()) { - MS_LOG(INFO) << "Independent or others is empty, no need reorder"; - return; - } - - std::set processed; - for (size_t i = 0; i < others.size(); i++) { - auto begin = others.begin() + i; - auto end = begin + 1; - bool flag = false; - for (size_t j = 0; j < independents.size(); j++) { - auto cur_independent = independents[j]; - auto it = std::find(processed.begin(), processed.end(), cur_independent.get()); - if (it != processed.end()) { - continue; - } - - auto res = FindTargetOp(begin, end, cur_independent); - if (res != end) { - flag = true; - exe_orders.emplace_back(cur_independent); - exe_orders.emplace_back(*begin); - processed.emplace(cur_independent.get()); - break; - } - } - - if (!flag) { - exe_orders.emplace_back(*begin); - } - } - - MS_LOG(INFO) << "After reorder, graph orders size:" << exe_orders.size(); - if (processed.size() != independents.size()) { - MS_LOG(WARNING) << "Processed independent nodes size is not equal to exiting independent nodes size"; - return; - } - - graph_ptr->set_execution_order(exe_orders); -} - -// section 2 -void AscendStreamAssign::AssignAllNodesStream(const NotNull &graph_ptr) { - auto cnode_ptr_list = graph_ptr->execution_order(); - bool exit_independent = false; - bool exit_hcom = false; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - // node has been assigned stream before - if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { - continue; - } - - if (IsHcom(cur_cnode_ptr)) { - exit_hcom = true; - continue; - } - - if (IsIndependentNode(cur_cnode_ptr)) { - exit_independent = true; - continue; - } - - AssignCommonStreamId(cur_cnode_ptr); - } - MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num(); - - if (exit_hcom) { - uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - // node has been assigned stream before - if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { - continue; - } - - if (IsHcom(cur_cnode_ptr)) { - AssignHcomStreamId(cur_cnode_ptr); - } - } - MS_LOG(INFO) << "Hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size(); - } - - if (exit_independent) { - uint32_t first_independ = resource_manager.ApplyNewStream(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { - continue; - } - if (IsIndependentNode(cur_cnode_ptr)) { - AssignIndependentStreamId(cur_cnode_ptr); - } - } - MS_LOG(INFO) << "Independ start from:" << first_independ << ", stream nums:" << independent_stream_map_.size(); - } - - MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num(); -} - -void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t cur_common_stream_id = 0; - uint32_t cur_stream_num = resource_manager.get_cur_stream_num(); - if (cur_stream_num == 0) { - cur_common_stream_id = resource_manager.ApplyNewStream(); - } else { - cur_common_stream_id = resource_manager.GetCurAllocStreamId(); - } - - auto it = common_stream_map_.find(cur_common_stream_id); - if (it == common_stream_map_.end()) { - AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get()); - common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1)); - } else { - if (it->second < kCommonMaxTask) { - AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); - it->second++; - } else { - cur_common_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get()); - common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1)); - } - } -} - -void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t cur_hcom_stream_id = resource_manager.GetCurAllocStreamId(); - auto it = hcom_stream_map_.find(cur_hcom_stream_id); - if (it == hcom_stream_map_.end()) { - AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); - hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); - } else { - if (it->second < kHcomMaxTask) { - AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); - it->second++; - } else { - cur_hcom_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); - hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); - } - } -} - -void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t cur_independent_id = resource_manager.GetCurAllocStreamId(); - auto it = independent_stream_map_.find(cur_independent_id); - if (it == independent_stream_map_.end()) { - AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); - independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); - } else { - if (it->second < kCommonMaxTask) { - AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); - it->second++; - } else { - cur_independent_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); - independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); - } - } -} - -bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { - MS_EXCEPTION_IF_NULL(node_ptr); - if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) { - return false; - } - - if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { - MS_LOG(INFO) << "GetNext should not be independent node"; - return false; - } - - uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr); - if (input_nums == 0) { - MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero"; - return true; - } - - auto inputs = node_ptr->inputs(); - for (size_t i = 1; i < inputs.size(); i++) { - if (!inputs[i]->isa()) { - return false; - } - } - MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node"; - return true; -} - -// section 3: -void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr) { - MS_LOG(INFO) << "Start"; - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - // update AtomicAddrClean stream same witch the next node - if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) { - AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get()); - } - } - MS_LOG(INFO) << "End"; -} - -// section 4 -void AscendStreamAssign::InsertStreamActive(const NotNull &graph_ptr) { - MS_LOG(INFO) << "Start"; - GetProcessedStream(graph_ptr); - std::vector update_cnode_list; - CNodePtr cur_cnode_ptr = nullptr; - CNodePtr pre_cnode_ptr = nullptr; - uint32_t pre_stream_id = UINT32_MAX; - - bool independent_flag = !(independent_stream_map_.empty()); - bool hcom_flag = !(hcom_stream_map_.empty()); - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (IsIndependentNode(cur_cnode_ptr)) { - update_cnode_list.emplace_back(cur_cnode_ptr); - continue; - } - - if (IsHcom(cur_cnode_ptr)) { - update_cnode_list.emplace_back(cur_cnode_ptr); - continue; - } - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - bool processed = IsProcessedStream(cur_stream_id); - // 1)inner stream assign, need insert active op - if (!processed) { - MS_LOG(INFO) << "Common stream active info:" << pre_stream_id << "->active" << cur_stream_id; - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); - // 1.set stream id - AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); - // 2.set active stream ids - std::vector active_index_list{cur_stream_id}; - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); - update_cnode_list.emplace_back(active_ptr); - } - - if ((independent_flag || hcom_flag) && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) { - MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; - UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list); - } else { - update_cnode_list.emplace_back(cur_cnode_ptr); - } - - processed_streams_.emplace(cur_stream_id); - pre_stream_id = cur_stream_id; - pre_cnode_ptr = cur_cnode_ptr; - } - graph_ptr->set_execution_order(update_cnode_list); - MS_LOG(INFO) << "End"; -} - -void AscendStreamAssign::GetProcessedStream(const NotNull &graph_ptr) { - // 0 stream is activated at first - processed_streams_.emplace(0); - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - auto cur_cnode_ptr = cnode_ptr_list[i]; - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { - auto true_stream_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrTrueBranchStream); - processed_streams_.emplace(true_stream_id); - - if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { - continue; - } - auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); - if (need_active) { - processed_streams_.emplace(cur_stream_id); - } - } - } - for (const auto &item : processed_streams_) { - MS_LOG(INFO) << "Before active:" << item << " is been processed"; - } -} - -void AscendStreamAssign::UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, - vector *orders) { - orders->emplace_back(switch_ptr); - if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) { - return; - } - - auto need_active = AnfAlgo::GetNodeAttr(switch_ptr, kStreamNeedActivedFirst); - if (!need_active) { - return; - } - - MS_EXCEPTION_IF_NULL(switch_ptr); - auto true_stream_id = AnfAlgo::GetNodeAttr(switch_ptr, kAttrTrueBranchStream); - MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) - << "; active stream id:" << true_stream_id; - - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); - AnfAlgo::SetStreamId(true_stream_id, active_ptr.get()); - vector active_ids; - // active indepdent stream - for (const auto &item : independent_stream_map_) { - active_ids.emplace_back(item.first); - } - // active hcom stream - for (const auto &item : hcom_stream_map_) { - active_ids.emplace_back(item.first); - } - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_ids), active_ptr); - - // update processed stream - independent_stream_activated_ = true; - for (const auto &item : independent_stream_map_) { - processed_streams_.emplace(item.first); - } - - hcom_stream_activated_ = true; - for (const auto &item : hcom_stream_map_) { - processed_streams_.emplace(item.first); - } - - orders->emplace_back(active_ptr); -} - -bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { - auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id); - if (it != processed_streams_.end()) { - return true; - } - return false; -} - -// section5 -void AscendStreamAssign::InsertEventForHcomParallel(const NotNull &graph_ptr) { - MS_LOG(INFO) << "Start"; - InsertEventCommonDependHcom(graph_ptr); - InsertEventHcomDependCommon(graph_ptr); - InsertEventHcomDependHcom(graph_ptr); - MS_LOG(INFO) << "End"; -} - -void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - vector cnodes = cnode_ptr_list; - uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - auto it = cnodes.begin(); - while (it != cnodes.end() && (it + 1) != cnodes.end()) { - MS_EXCEPTION_IF_NULL(*it); - MS_EXCEPTION_IF_NULL(*(it + 1)); - if (IsHcom(*it) && !IsHcom(*(it + 1))) { - CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); - it = cnodes.insert(it + 1, send_cnode_ptr); - - auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); - if (target == cnodes.end()) { - MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope() - << ", can't find target for insert recv op, no insert send/recv"; - it = cnodes.erase(it); - continue; - } - - if (IsHcom(*target)) { - it = cnodes.erase(it); - continue; - } - - // deal recv op - uint32_t stream_id = AnfAlgo::GetStreamId(*target); - CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); - (void)cnodes.insert(target, recv_cnode_ptr); - cur_event_id = resource_manager.ApplyNewEvent(); - } - ++it; - } - // one event allocated additional, should delete - resource_manager.DeleteEvent(); - graph_ptr->set_execution_order(cnodes); - MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num(); -} - -void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - vector cnodes; - CNodePtr cur_cnode_ptr = nullptr; - uint32_t pre_stream_id = UINT32_MAX; - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (i == 0) { - cnodes.emplace_back(cur_cnode_ptr); - pre_stream_id = cur_stream_id; - continue; - } - - if (!IsHcom(cur_cnode_ptr)) { - cnodes.emplace_back(cur_cnode_ptr); - pre_stream_id = cur_stream_id; - continue; - } - - if (cur_stream_id == pre_stream_id) { - cnodes.emplace_back(cur_cnode_ptr); - pre_stream_id = cur_stream_id; - continue; - } - - if (!IsHcom(cnode_ptr_list[i - 1])) { - uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id); - cnodes.emplace_back(send); - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); - cnodes.emplace_back(recv); - cnodes.emplace_back(cur_cnode_ptr); - } - pre_stream_id = cur_stream_id; - } - - graph_ptr->set_execution_order(cnodes); - MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); -} - -void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - uint32_t first_hcom_stream = kInvalidStreamId; - uint32_t last_hcom_stream = kInvalidStreamId; - // key: stream id, value:hcom index - std::map> hcom_index; - for (size_t i = 0; i < cnode_ptr_list.size(); i++) { - auto cur_cnode = cnode_ptr_list[i]; - if (!IsHcom(cur_cnode)) { - continue; - } - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); - auto it = hcom_index.find(cur_stream_id); - if (it != hcom_index.end()) { - hcom_index[cur_stream_id].emplace_back(i); - } else { - hcom_index[cur_stream_id] = {i}; - } - - // record first hcom stream id - if (first_hcom_stream == kInvalidStreamId) { - first_hcom_stream = cur_stream_id; - } - - // record last hcom stream id - if (cur_stream_id != last_hcom_stream) { - last_hcom_stream = cur_stream_id; - } - } - - if (hcom_index.size() < 2) { - MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; - return; - } - InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream); - MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); -} - -void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &graph_ptr, - const map> &hcom_index, - uint32_t first_hcom_stream, uint32_t last_hcom_stream) { - vector orders; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back(); - size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front(); - std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders)); - for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) { - auto cur_cnode = cnode_ptr_list[i]; - if (!IsSatisfiedHcom(hcom_index, cur_cnode, i)) { - orders.emplace_back(cur_cnode); - continue; - } - auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode); - if (i == first_stream_last_index) { - // first fusion hcom - orders.emplace_back(cur_cnode); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(send); - } else if (i == last_stream_first_index) { - // last fusion hcom - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(recv); - orders.emplace_back(cur_cnode); - } else { - auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size(); - if (cur_stream_hcom_size == 1) { - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(recv); - cur_event_id = resource_manager.ApplyNewEvent(); - orders.emplace_back(cur_cnode); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(send); - } else { - // current stream, first hcom:add recv op - if (i == hcom_index.at(cur_hcom_stream_id).front()) { - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(recv); - cur_event_id = resource_manager.ApplyNewEvent(); - orders.emplace_back(cur_cnode); - } else if (i == hcom_index.at(cur_hcom_stream_id).back()) { - // current stream, last hcom:add send op - orders.emplace_back(cur_cnode); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(send); - } else { - // current stream, not first and last op - orders.emplace_back(cur_cnode); - } - } - } - } - std::copy(cnode_ptr_list.begin() + last_stream_first_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); - graph_ptr->set_execution_order(orders); -} - -bool AscendStreamAssign::IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, - size_t index) { - MS_EXCEPTION_IF_NULL(node_ptr); - auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr); - auto it = hcom_index.find(cur_hcom_stream_id); - if (it == hcom_index.end()) { - return false; - } - auto iter = std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index); - if (iter == hcom_index.at(cur_hcom_stream_id).end()) { - return false; - } - return true; -} - -// section6 -void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull &graph_ptr) { - MS_LOG(INFO) << "Start"; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - vector cnodes = cnode_ptr_list; - uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - auto it = cnodes.begin(); - while (it != cnodes.end()) { - MS_EXCEPTION_IF_NULL(*it); - if (IsIndependentNode(*it)) { - MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]"; - CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); - it = cnodes.insert(it + 1, send_cnode_ptr); - - auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); - if (target == cnodes.end()) { - MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope() - << "] can't find target for insert recv op, no insert send/recv"; - it = cnodes.erase(it); - continue; - } - - // deal recv op - uint32_t stream_id = AnfAlgo::GetStreamId(*target); - CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); - (void)cnodes.insert(target, recv_cnode_ptr); - cur_event_id = resource_manager.ApplyNewEvent(); - } - ++it; - } - // one event allocated additional, should delete - resource_manager.DeleteEvent(); - graph_ptr->set_execution_order(cnodes); - MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num(); - MS_LOG(INFO) << "End"; -} - -// section7 -void AscendStreamAssign::GetNeedActiveStreams(const NotNull &graph_ptr) { - CNodePtr cur_cnode_ptr = nullptr; - auto cnode_ptr_list = graph_ptr->execution_order(); - // 1)first stream 0 should be actived first; - need_first_active_streams_.emplace_back(0); - - // 2)stream witch kStreamNeedActivedFirst attr should be actived; - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { - continue; - } - - auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); - if (need_active) { - auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first"; - need_first_active_streams_.push_back(stream_id); - } - } - - // 3)independent stream:if has not been activate, push to need active vector - if (!independent_stream_activated_) { - for (auto &item : independent_stream_map_) { - need_first_active_streams_.emplace_back(item.first); - } - } - - // 4)hcom stream:if has not been activate, push to need active vector - if (!hcom_stream_activated_) { - for (auto &item : hcom_stream_map_) { - need_first_active_streams_.emplace_back(item.first); - } - } -} - -// section8 -void AscendStreamAssign::CheckResourceAssign(const NotNull &graph_ptr) { - CheckStreamAssign(graph_ptr); - CheckEventAssign(graph_ptr); -} - -void AscendStreamAssign::CheckStreamAssign(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - std::set streams; - uint32_t max_stream = 0; - uint32_t min_stream = kInvalidStreamId; - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - if (stream_id == kInvalidStreamId) { - MS_LOG(EXCEPTION) << "Node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "had not been assigned stream"; - } - - (void)streams.emplace(stream_id); - if (stream_id > max_stream) { - max_stream = stream_id; - } - if (stream_id < min_stream) { - min_stream = stream_id; - } - } - - // check stream assign - if (!streams.empty()) { - if (min_stream != 0) { - MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream; - } - uint32_t assigned_stream_num = resource_manager.get_cur_stream_num(); - if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) { - MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream - << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size(); - } - } -} - -void AscendStreamAssign::CheckEventAssign(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - std::map> event_map; - uint32_t max_event_id = 0; - uint32_t min_event_id = kInvalidEventId; - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - auto name = AnfAlgo::GetCNodeName(cur_cnode_ptr); - if (name == kSendOpName || name == kRecvOpName) { - uint32_t event_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId); - if (event_id > max_event_id) { - max_event_id = event_id; - } - - if (event_id < min_event_id) { - min_event_id = event_id; - } - auto it = event_map.find(event_id); - if (it == event_map.end()) { - event_map[event_id] = {cur_cnode_ptr}; - } else { - event_map[event_id].emplace_back(cur_cnode_ptr); - } - } - } - // check event assign - if (!event_map.empty()) { - if (min_event_id != 0) { - MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id; - } - uint32_t assigned_event_num = resource_manager.get_cur_event_num(); - if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) { - MS_LOG(EXCEPTION) << "Event should be consecutive"; - } - for (const auto &item : event_map) { - if (item.second.size() != 2) { - MS_LOG(EXCEPTION) << "Send/recv should be in pair and share one event id"; - } - auto first_name = AnfAlgo::GetCNodeName(item.second[0]); - auto second_name = AnfAlgo::GetCNodeName(item.second[1]); - if (!(first_name == kSendOpName && second_name == kRecvOpName)) { - MS_LOG(EXCEPTION) << "Send should be before recv"; - } - } - } -} - -// section9 -CNodePtr AscendStreamAssign::CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, - uint32_t stream_id) { - auto send_op = std::make_shared(kSendOpName); - MS_EXCEPTION_IF_NULL(send_op); - auto send_apply = std::make_shared(send_op); - MS_EXCEPTION_IF_NULL(send_apply); - std::vector send_input_list = {send_apply}; - CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); - MS_EXCEPTION_IF_NULL(send_node_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - send_node_ptr->set_abstract(abstract_none); - AnfAlgo::SetStreamId(stream_id, send_node_ptr.get()); - return send_node_ptr; -} - -CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, - uint32_t stream_id) { - auto recv_op = std::make_shared(kRecvOpName); - MS_EXCEPTION_IF_NULL(recv_op); - auto recv_apply = std::make_shared(recv_op); - MS_EXCEPTION_IF_NULL(recv_apply); - std::vector recv_input_list = {recv_apply}; - CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); - MS_EXCEPTION_IF_NULL(recv_node_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); - AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get()); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - recv_node_ptr->set_abstract(abstract_none); - return recv_node_ptr; -} - -vector::iterator AscendStreamAssign::FindTargetOp(vector::iterator begin, - vector::iterator end, const CNodePtr &node) { - while (begin != end) { - auto inputs = (*begin)->inputs(); - for (size_t i = 1; i < inputs.size(); i++) { - auto input = inputs[i]; - if (opt::IsNopNode(input)) { - CNodePtr cnode = input->cast(); - auto new_inputs = cnode->inputs(); - for (size_t j = 1; j < new_inputs.size(); j++) { - auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0); - if (node == new_real_input.first) { - MS_LOG(INFO) << "Nop node find target op[" << (*begin)->DebugString() << "]"; - return begin; - } - } - } else { - auto real_input = AnfAlgo::VisitKernel(input, 0); - if (node == real_input.first) { - MS_LOG(INFO) << "Find target op[" << (*begin)->DebugString() << "]"; - return begin; - } - } - } - ++begin; - } - return end; -} - -bool AscendStreamAssign::IsTaskSink() { - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->enable_task_sink()) { - MS_LOG(INFO) << "Task sink mode is not enable"; - return false; - } else { - MS_LOG(INFO) << "Task sink mode is enable"; - return true; - } -} - -void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { - MS_EXCEPTION_IF_NULL(wait_active_stream_list); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t total_stream_num = resource_manager.get_cur_stream_num(); - if (total_stream_num == 0) { - MS_LOG(INFO) << "The total_common_stream_num is zero"; - return; - } - - // common stream:active first common stream - for (uint32_t i = 0; i < total_stream_num; i++) { - auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); - if (it == need_first_active_streams_.end()) { - MS_LOG(INFO) << "Wait common stream id = " << i; - wait_active_stream_list->push_back(i); - } - } -} - -bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) { - MS_EXCEPTION_IF_NULL(apply_kernel); - return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL; -} - -void AscendStreamAssign::GetHcomStreams(std::vector *streams) { - MS_EXCEPTION_IF_NULL(streams); - for (const auto &item : hcom_stream_map_) { - streams->emplace_back(item.first); - } -} - -void AscendStreamAssign::Reset() { - independent_stream_activated_ = false; - hcom_stream_activated_ = false; - independent_stream_map_.clear(); - hcom_stream_map_.clear(); - common_stream_map_.clear(); - processed_streams_.clear(); - need_first_active_streams_.clear(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h deleted file mode 100644 index 625ab6ad6e..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "runtime/base.h" -#include "runtime/rt_model.h" -#include "runtime/stream.h" -#include "session/kernel_graph.h" -#include "utils/contract.h" - -namespace mindspore { -namespace device { -namespace ascend { -using std::map; -using std::shared_ptr; -using std::unordered_map; -using std::unordered_set; -using std::vector; -const uint32_t kInvalidStreamId = UINT32_MAX; -const uint32_t kInvalidEventId = UINT32_MAX; -class AscendResourceMng { - public: - static AscendResourceMng &GetInstance() { - static AscendResourceMng instance; - return instance; - } - - void ResetResource() { - cur_stream_num_ = 0; - cur_event_num_ = 0; - } - uint32_t ApplyNewStream() { - if (!cur_stream_num_) { - uint32_t cur_stream_id = cur_stream_num_; - cur_stream_num_++; - return cur_stream_id; - } - uint32_t cur_stream_id = cur_stream_num_; - cur_stream_num_++; - return cur_stream_id; - } - uint32_t ApplyNewEvent() { - if (!cur_event_num_) { - uint32_t cur_event_id = cur_event_num_; - cur_event_num_++; - return cur_event_id; - } - uint32_t cur_event_id = cur_event_num_; - cur_event_num_++; - return cur_event_id; - } - - void DeleteEvent() { - if (!cur_event_num_) { - MS_LOG(WARNING) << "total event num is 0, no event to delete"; - } else { - --cur_event_num_; - } - } - uint32_t get_cur_stream_num() { return cur_stream_num_; } - uint32_t GetCurAllocStreamId() { - if (!cur_stream_num_) { - MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; - } - return cur_stream_num_ - 1; - } - uint32_t get_cur_event_num() { return cur_event_num_; } - - private: - uint32_t cur_stream_num_{0}; - uint32_t cur_event_num_{0}; -}; - -class AscendStreamAssign { - public: - static AscendStreamAssign &GetInstance() { - static AscendStreamAssign instance; // Guaranteed to be destroyed. - return instance; - } - - AscendStreamAssign(const AscendStreamAssign &) = delete; - AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; - - void AssignStream(const NotNull &graph_ptr); - void GetHcomStreams(std::vector *streams); - void GetWaitStreams(vector *wait_active_stream_list); - CNodePtr CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); - CNodePtr CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); - - private: - AscendStreamAssign() = default; - ~AscendStreamAssign() = default; - void Reset(); - void CheckResourceAssign(const NotNull &graph_ptr); - void CheckStreamAssign(const NotNull &graph_ptr); - void CheckEventAssign(const NotNull &graph_ptr); - void AssignAllNodesStream(const NotNull &graph_ptr); - void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); - void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr); - void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); - void UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr); - void FindHcomParallelStreams(const NotNull &graph_ptr); - void InsertStreamActive(const NotNull &graph_ptr); - void UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, - vector *orders); - void InsertEventForIndependentParallel(const NotNull &graph_ptr); - void InsertEventForHcomParallel(const NotNull &graph_ptr); - void InsertEventCommonDependHcom(const NotNull &graph_ptr); - void InsertEventHcomDependCommon(const NotNull &graph_ptr); - void InsertEventHcomDependHcom(const NotNull &graph_ptr); - void InsertEventBetweenHcom(const NotNull &graph_ptr, const map> &hcom_index, - uint32_t first_hcom_stream, uint32_t last_hcom_stream); - bool IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, size_t index); - - void GetProcessedStream(const NotNull &graph_ptr); - void GetNeedActiveStreams(const NotNull &graph_ptr); - void ReorderIndependentOrders(const NotNull &graph_ptr); - - bool IsTaskSink(); - bool IsHcom(const CNodePtr &cur_cnode_ptr); - bool IsIndependentNode(const CNodePtr &node_ptr); - bool IsProcessedStream(uint32_t stream_id); - vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, - const CNodePtr &node); - void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); - - bool independent_stream_activated_{false}; - bool hcom_stream_activated_{false}; - std::map independent_stream_map_{}; - std::map hcom_stream_map_{}; - std::map common_stream_map_{}; - std::set processed_streams_{}; - std::vector need_first_active_streams_{}; - // new policy end -}; -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_build_ascend.cc deleted file mode 100644 index bd0b436344..0000000000 --- a/mindspore/ccsrc/device/ascend/kernel_build_ascend.cc +++ /dev/null @@ -1,286 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/kernel_build_ascend.h" - -#include -#include -#include -#include - -#include "device/ascend/kernel_select_ascend.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "kernel/tbe/tbe_kernel_build.h" -#include "kernel/tbe/tbe_kernel_parallel_build.h" -#include "kernel/akg/ascend/akg_ascend_kernel_build.h" -#include "kernel/aicpu/aicpu_kernel_build.h" -#include "kernel/hccl/hccl_kernel_build.h" -#include "kernel/rts/rt_kernel_build.h" -#include "kernel/tbe/tbe_utils.h" -#include "kernel/common_utils.h" -#include "operator/ops.h" -#include "session/anf_runtime_algorithm.h" -#include "./common.h" - -namespace mindspore { -namespace device { -namespace ascend { -using mindspore::kernel::tbe::TbeUtils; -using std::make_shared; -static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) { - kernel::KernelModPtr kernel_mod_ptr = nullptr; - KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); - switch (kernel_type) { - case KernelType::AICPU_KERNEL: { - kernel_mod_ptr = kernel::AicpuOpBuild(anf_node); - break; - } - case KernelType::RT_KERNEL: { - kernel_mod_ptr = kernel::RtOpBuild(anf_node); - break; - } - case KernelType::HCCL_KERNEL: { - kernel_mod_ptr = kernel::HcclOpBuild(anf_node); - break; - } - default: { - MS_LOG(EXCEPTION) << "node [" << anf_node->DebugString() << "] Unsupported kernel_type:" << kernel_type; - } - } - return kernel_mod_ptr; -} - -static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - std::vector tbe_nodes; - for (const auto &anf_node : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(anf_node); - if (!AnfAlgo::IsRealKernel(anf_node)) { - continue; - } - KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); - switch (kernel_type) { - case KernelType::TBE_KERNEL: { - if (AnfAlgo::GetKernelMod(anf_node) == nullptr && - AnfAlgo::GetFusionType(anf_node) == kernel::FusionType::DYNAMIC) { - tbe_nodes.push_back(anf_node); - } - break; - } - default: { - break; - } - } - } - bool ret = kernel::TbeOpParallelPreBuild(tbe_nodes); - return ret; -} - -static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - std::vector tbe_nodes; - std::vector akg_nodes; - std::vector other_nodes; - for (const auto &anf_node : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(anf_node); - if (!AnfAlgo::IsRealKernel(anf_node)) { - continue; - } - KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); - switch (kernel_type) { - case KernelType::TBE_KERNEL: { - if (AnfAlgo::GetKernelMod(anf_node) == nullptr) { - tbe_nodes.push_back(anf_node); - } - break; - } - case KernelType::AKG_KERNEL: { - akg_nodes.push_back(anf_node); - break; - } - default: { - other_nodes.push_back(anf_node); - break; - } - } - } - bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes); - bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes); - auto bin_map = kernel::tbe::KernelMeta::GetInstance(); - (void)bin_map->ReadIndex(kernel::kCceKernelMeta); - for (const auto &anf_node : other_nodes) { - kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); - } - return tbe_ret && akg_ret; -} - -static std::vector CalCleanZerosSize(const CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(pre_node); - auto kernel_mod = AnfAlgo::GetKernelMod(pre_node); - MS_EXCEPTION_IF_NULL(kernel_mod); - std::vector clean_size_list; - // clean output - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { - auto output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); - auto output_men_size = kernel_mod->GetOutputSizeList(); - for (auto index : output_indexs) { - auto clean_item = (output_men_size.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; - clean_size_list.emplace_back(clean_item); - } - } - // clean workspace - if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { - auto workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); - auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList(); - for (const auto &index : workspace_indexs) { - auto clean_item = (workspace_men_sizes.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; - clean_size_list.emplace_back(clean_item); - } - } - MS_LOG(INFO) << "clear output size:" << clean_size_list.size() << ",pre_node:" << pre_node->fullname_with_scope(); - return clean_size_list; -} - -static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph, - const mindspore::CNodePtr &pre_node, std::vector *new_nodes) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(pre_node); - MS_EXCEPTION_IF_NULL(new_nodes); - auto clear_zero_prim = std::make_shared(kAtomicAddrCleanOpName); - MS_EXCEPTION_IF_NULL(clear_zero_prim); - auto new_value_node = NewValueNode(clear_zero_prim); - MS_EXCEPTION_IF_NULL(new_value_node); - std::vector inputs = {new_value_node}; - inputs.push_back(pre_node); - CNodePtr clear_zero = kernel_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(clear_zero); - AbstractBasePtr abstract = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract); - clear_zero->set_abstract(abstract); - auto builder = std::make_shared(); - builder->SetKernelType(KernelType::TBE_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get()); - auto clean_size = CalCleanZerosSize(pre_node); - AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero); - AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get()); - new_nodes->push_back(clear_zero); -} - -static bool IsAtomicNode(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto parameters_indexs = kernel_mod->GenParameters(); - if (parameters_indexs.empty()) { - return false; - } - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size(); - size_t param_num = parameters_indexs.size(); - size_t total_num = input_num + workspace_num + output_num; - MS_LOG(INFO) << "parameters size: " << param_num << ", input & workspace & output num: " << total_num; - size_t pad_index = param_num; - for (; pad_index < total_num; ++pad_index) { - parameters_indexs.emplace_back(0); - } - // process input - for (size_t j = 0; j < input_num; ++j) { - if (parameters_indexs.at(j) == 1) { - MS_LOG(EXCEPTION) << "Atomic addr clean does't support clean input address, input index: " << j; - } - } - // process output - std::vector output_indexs = {}; - for (size_t i = 0; i < output_num; ++i) { - auto param_output = parameters_indexs.at(input_num + workspace_num + i); - if (param_output == 1) { - output_indexs.emplace_back(i); - MS_LOG(INFO) << "Atomic clear output index: " << i; - } - } - if (!output_indexs.empty()) { - AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node); - } - // process workspace - std::vector workspace_indexs = {}; - for (size_t k = 0; k < workspace_num; ++k) { - auto param_workspace = parameters_indexs.at(input_num + k); - if (param_workspace == 1) { - workspace_indexs.emplace_back(k); - MS_LOG(INFO) << "Atomic clear workspace index: " << k; - } - } - if (!workspace_indexs.empty()) { - AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexs), kernel_node); - } - return !(workspace_indexs.empty() && output_indexs.empty()); -} - -bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - bool ret = device::ascend::KernelPreBuildParallelCompile(kernel_graph_ptr); - return ret; -} - -bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - TbeUtils::LoadCache(); - bool ret; - ret = device::ascend::KernelBuildParallelCompile(kernel_graph_ptr); - return ret; -} - -void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector new_nodes; - for (const auto &anf_node : kernel_graph->execution_order()) { - std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); - if (apply_function_name == prim::kPrimMaxPoolGrad->name() && - AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { - auto clear_zero_prim = std::make_shared(kClearZeroOpName); - MS_EXCEPTION_IF_NULL(clear_zero_prim); - auto new_value_node = NewValueNode(clear_zero_prim); - MS_EXCEPTION_IF_NULL(new_value_node); - std::vector inputs = {new_value_node}; - inputs.push_back(anf_node); - CNodePtr clear_zero = kernel_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(clear_zero); - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - clear_zero->set_kernel_info(kernel_info); - AbstractBasePtr abstract = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract); - AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector({"x"})), clear_zero); - SelectKernelInfo(clear_zero); - // set the distinction label of clear same with anf - AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get()); - new_nodes.push_back(clear_zero); - } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) { - if (IsAtomicNode(anf_node)) { - AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); - } - } - new_nodes.push_back(anf_node); - } - kernel_graph->set_execution_order(new_nodes); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/kernel_build_ascend.h b/mindspore/ccsrc/device/ascend/kernel_build_ascend.h deleted file mode 100644 index d987b6ce7a..0000000000 --- a/mindspore/ccsrc/device/ascend/kernel_build_ascend.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ - -#include "session/kernel_graph.h" - -namespace mindspore { -namespace device { -namespace ascend { -/** - * @brief kernel pre build for ascend. - */ -bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr); -/** - * @brief kernel build for ascend. - */ -bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr); -/** - * @brief preporcess of kernel build for ascend, e.g. inserting clear_zero node for maxpool, bn. - * Must DO these changes just before kernel build, and after all of other optimizations on AnfGraph - */ -void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph); -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc deleted file mode 100644 index cde79a18f7..0000000000 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ /dev/null @@ -1,584 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/kernel_select_ascend.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "debug/anf_ir_dump.h" -#include "operator/ops.h" -#include "ir/func_graph.h" -#include "utils/context/ms_context.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "kernel/common_utils.h" -#include "kernel/kernel_query.h" -#include "kernel/oplib/oplib.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace { -const float kWegihtBaseScore = 1; -const float kFeatureMapBaseScore = 10; -constexpr auto kPriChoosenFormat = "pri_format"; -enum MatchCountPriority : int { - MATCH_COUNT_PRIORITY_BEGIN = 0, - MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, - MATCH_FORMAT_COUNT, - MATCH_SPECIAL_FORMAT_COUNT, - MATCH_DEFAULT_FORMAT_COUNT, - MATCH_OUTPUT_DTYPE_COUNT, - MATCH_COUNT_PRIORITY_END -}; - -const int kUnSupportMixedDataTypeIndex = -1; - -bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { - MS_EXCEPTION_IF_NULL(cnode); - // Check input data type - for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { - TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); - if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { - return false; - } - } - // Check output data type - for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { - if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) { - return false; - } - } - return true; -} - -string GetPriorityMatchFormat(const CNodePtr &cnode) { - string priority_matched_format = kOpFormat_NC1HWC0; - bool is_init = false; - bool need_change_nd = false; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { - auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); - if (AnfAlgo::IsFeatureMapInput(cnode, index) && - kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) { - priority_matched_format = !is_init ? pre_output_format : priority_matched_format; - is_init = true; - } - // feature map has two or more special format; - if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) { - priority_matched_format = kOpFormat_DEFAULT; - } - auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); - need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); - } - if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) { - priority_matched_format = kOpFormat_DEFAULT; - } - AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); - return priority_matched_format; -} -/** - * Compare two vector by priority, select a better vector, like compare two num, first compare highest num location, - * if equal then next num location - * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3] - */ -bool PriorityChooseItem(const std::vector &cur_item, std::vector *best_item) { - MS_EXCEPTION_IF_NULL(best_item); - if (cur_item.size() != best_item->size()) { - MS_LOG(ERROR) << "Item size should be same!"; - return false; - } - // Update the best_item by comparing the cur_item and best_item - for (size_t i = 0; i < cur_item.size(); i++) { - if (cur_item[i] > best_item->at(i)) { - *best_item = cur_item; - return true; - } else if (cur_item[i] == best_item->at(i)) { - continue; - } else { - return false; - } - } - return false; -} - -void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr &kernel_node, - std::vector *const cur_kernelinfo_match_counts) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts); - if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) { - MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END; - } - auto pri_match_format = GetPriorityMatchFormat(kernel_node); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - auto input_anf_node = kernel_node->input(input_index + 1); - // we do not take ValueNode into consideration in graph kernel. - if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) { - if (input_anf_node->isa() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) { - continue; - } - } - auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore; - if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { - (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score; - } - // we match output fix precision first. - auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index); - if (prev_device_type == kTypeUnknown) { - prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); - } - if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) { - (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score; - } - if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { - (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; - } - if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) { - (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score; - } - } - - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - // cal count of same output dtype between abstract and kernel info - if (kernel_build_info.GetOutputDeviceType(output_index) == - AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) { - (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1; - } - } -} - -void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector *support_index) { - MS_EXCEPTION_IF_NULL(support_index); - int index = kUnSupportMixedDataTypeIndex; - switch (data_type) { - case kNumberTypeFloat16: - index = 0; - break; - case kNumberTypeFloat32: - case kNumberTypeFloat: - index = 1; - break; - default: - break; - } - support_index->push_back(index); -} - -void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index, - std::vector *support_datatype_index, std::vector *support_datatype) { - MS_EXCEPTION_IF_NULL(support_datatype); - auto data_type = kernel_build_info.GetInputDeviceType(input_index); - support_datatype->push_back(data_type); - AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); -} - -void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index, - std::vector *support_datatype_index, std::vector *support_datatype) { - MS_EXCEPTION_IF_NULL(support_datatype); - auto data_type = kernel_build_info.GetOutputDeviceType(output_index); - support_datatype->push_back(data_type); - AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); -} - -void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, - std::vector *node_mix_precision_datatype_index, - std::vector *node_mix_precision_datatype) { - AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); - MS_EXCEPTION_IF_NULL(cur_input); - MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); - TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); - AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); - node_mix_precision_datatype->push_back(input_origin_type); -} - -void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, - std::vector *node_mix_precision_datatype_index, - std::vector *node_mix_precision_datatype) { - MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); - auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index); - AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index); - node_mix_precision_datatype->push_back(output_origin_type); -} - -void CheckDataTypeInputs(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatypes, - std::map> *kernel_match_datatype_idx) { - if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) { - MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size " - << node_mix_precision_datatype.size(); - } - MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); - if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) { - MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size " - << kernel_support_datatypes.size(); - } -} - -bool RaiseDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatypes, - std::map> *kernel_match_datatype_idx) { - MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); - CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, - kernel_match_datatype_idx); - for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { - if (node_mix_precision_datatype[i] == kTypeUnknown) { - continue; - } - auto iter = kernel_match_datatype_idx->begin(); - while (iter != kernel_match_datatype_idx->end()) { - if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { - auto find_iter = kernel_support_datatypes.find(iter->first); - if (find_iter == kernel_support_datatypes.end()) { - MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; - } - if (i >= find_iter->second.size()) { - MS_LOG(EXCEPTION) << "Node index " << i << "kernel datatype size " << find_iter->second.size(); - } - if (node_mix_precision_datatype[i] != find_iter->second[i]) { - iter = kernel_match_datatype_idx->erase(iter); - } else { - ++iter; - } - continue; - } - auto datatype_indexes = iter->second; - if (i >= datatype_indexes.size()) { - MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size(); - } - if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) { - iter = kernel_match_datatype_idx->erase(iter); - } else { - ++iter; - } - } - } - return !kernel_match_datatype_idx->empty(); -} - -bool CanDataTypeReduce(const std::vector &datatype_indexes, int check_index, - const std::vector &node_mix_precision_datatype_index) { - auto check_index_tmp = IntToSize(check_index); - if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) { - return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex && - datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index]; - } - MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range"; -} - -bool RaiseOrReduceDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatypes, - std::map> *kernel_match_datatype_idx) { - MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); - CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, - kernel_match_datatype_idx); - for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { - if (node_mix_precision_datatype[i] == kTypeUnknown) { - continue; - } - auto iter = kernel_match_datatype_idx->begin(); - while (iter != kernel_match_datatype_idx->end()) { - if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { - auto find_iter = kernel_support_datatypes.find(iter->first); - if (find_iter == kernel_support_datatypes.end()) { - MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; - } - if (i >= find_iter->second.size()) { - MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size(); - } - if (node_mix_precision_datatype[i] != find_iter->second[i]) { - iter = kernel_match_datatype_idx->erase(iter); - } else { - ++iter; - } - continue; - } - auto datatype_indexes = iter->second; - if (i >= datatype_indexes.size()) { - MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); - } - if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) { - iter = kernel_match_datatype_idx->erase(iter); - } else { - ++iter; - } - } - } - return !kernel_match_datatype_idx->empty(); -} - -void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info, - std::vector *support_indexes, std::vector *node_mix_precision_datatype, - std::vector *support_datatypes, - std::vector *node_mix_precision_datatype_index) { - MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); - bool add_node_datatype_flag = false; - if (node_mix_precision_datatype->empty()) { - add_node_datatype_flag = true; - } - for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { - AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes); - if (add_node_datatype_flag) { - AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype); - } - } - // Check output data type - for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { - AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes); - if (add_node_datatype_flag) { - AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype); - } - } -} - -void PrecisionReduce(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatype, - std::map> *kernel_match_datatype_idx, bool *precision_reduce) { - MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(precision_reduce); - std::map> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx; - // raise precision - bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, - kernel_support_datatype, kernel_match_datatype_idx); - if (selected_ret) { - *precision_reduce = false; - return; - } - if (context_ptr->enable_reduce_precision()) { - selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, - kernel_support_datatype, &kernel_match_datatype_idx_copy); - } - if (selected_ret) { - *precision_reduce = true; - *kernel_match_datatype_idx = kernel_match_datatype_idx_copy; - } -} - -void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, - const std::shared_ptr &selected_kernel_build_info, - bool precision_reduce) { - MS_EXCEPTION_IF_NULL(selected_kernel_build_info); - MS_EXCEPTION_IF_NULL(cnode); - std::ostringstream buffer; - buffer << cnode->DebugString(); - if (precision_reduce) { - buffer << " Reduce precision, node datatype: \n"; - } else { - buffer << " Raise precision, node datatype: \n"; - } - PrintInputAndOutputInferType(buffer, cnode); - buffer << ", select kernel:" << selected_kernel_build_info->ToString(); - MS_LOG(INFO) << buffer.str(); -} - -std::shared_ptr ChooseMatchedKernelInfo( - const CNodePtr &kernel_node, const std::vector> &kernel_info_list) { - if (kernel_info_list.empty()) { - return nullptr; - } - std::vector most_match_counts = {-1, -1, -1, -1, -1}; - size_t selected_index = 0; - for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { - std::vector cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; - auto kernel_info_ptr = kernel_info_list[info_index]; - MS_EXCEPTION_IF_NULL(kernel_info_ptr); - UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); - // Currently the selection policy is the match format count first, and then is datatype counts. - if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) { - selected_index = SizeToInt(info_index); - } - } - return kernel_info_list[selected_index]; -} - -std::vector> FilteredKernelInfoByDtype( - const CNodePtr &cnode, const std::vector> &kernel_info_list) { - std::vector> result; - for (const auto &kernel_build_info : kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_build_info); - if (!MatchInferOutputDataType(cnode, *kernel_build_info)) { - continue; - } - result.push_back(kernel_build_info); - } - return result; -} - -std::vector> FilterRaisedOrReducePrecisionMatchedKernelInfo( - const CNodePtr &cnode, const std::vector> &kernel_info_list, - bool *precision_reduce) { - std::vector> filtered_kernel_info_list; - std::map> kernel_match_datatype_idx; - std::map> kernel_support_datatype; - std::vector node_mix_precision_datatype_index; - std::vector node_mix_precision_datatype; - for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { - std::vector support_indexes; - std::vector support_datatypes; - MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); - AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype, - &support_datatypes, &node_mix_precision_datatype_index); - kernel_match_datatype_idx[info_index] = support_indexes; - kernel_support_datatype[info_index] = support_datatypes; - } - PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, - &kernel_match_datatype_idx, precision_reduce); - std::transform( - kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list), - [&](const std::pair> &matched_idx) -> std::shared_ptr { - return kernel_info_list[matched_idx.first]; - }); - return filtered_kernel_info_list; -} -} // namespace - -void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); - MS_EXCEPTION_IF_NULL(input_kernel_node); - auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); - MS_EXCEPTION_IF_NULL(input_with_index.first); - auto real_input_node = input_with_index.first; - if (real_input_node->isa()) { - continue; - } - if (real_input_node->isa() && !AnfAlgo::IsParameterWeight(real_input_node->cast())) { - continue; - } - auto builder = std::make_shared(); - if (IsValueNode(input_kernel_node) && - AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { - std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; - builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; - builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); - continue; - } - // we set special device info of a input tensor. - bool is_ref = false; - auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); - if (op_info != nullptr) { - is_ref = op_info->is_ref(); - } - MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); - if (MsContext::GetInstance()->execution_mode() == kPynativeMode && - AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { - continue; - } - if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { - std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; - builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; - builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); - } - } -} - -KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, - const std::vector> &kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_node); - KernelSelectStatus select_status = kNoMatched; - bool precision_reduce = false; - std::shared_ptr selected_kernel_info = nullptr; - // Matched kernel info - // Filter kernel info matched with me infered type - auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list); - if (!filtered_kernel_info_list.empty()) { - selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); - select_status = kStatusAllMatched; - } else { - // selected kernel info using raised precision or reduce precision - filtered_kernel_info_list = - FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); - selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); - if (selected_kernel_info == nullptr) { - return select_status; - } else { - PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); - select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; - } - } - // Set kernel info to the anfnode - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); - // Set format and data type for input tensor. - SetTensorDeviceInfo(*selected_kernel_info, kernel_node); - return select_status; -} - -KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { - std::vector> kernel_info_list; - std::vector> aicpu_kernel_info_list; - MS_EXCEPTION_IF_NULL(kernel_node); - if (AnfAlgo::IsGraphKernel(kernel_node)) { - auto func_graph = GetValueNode(kernel_node->input(kAnfPrimitiveIndex)); - MS_EXCEPTION_IF_NULL(func_graph); - SelectGraphKernelInfo(kernel_node, func_graph); - return kStatusAllMatched; - } - kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type); - auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); - // If aicore not find valid kernel info reloading aicpu kernel info list to find it - if (select_status == kNoMatched) { - MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() - << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; - kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list); - select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list); - AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); - } - // The kernel info not finded both in the aicpu kernel list & aicore kernel list - if (select_status == kNoMatched) { - std::ostringstream buffer; - PrintInputAndOutputInferType(buffer, kernel_node); - MS_LOG(WARNING) << ">>> Candidates kernel info list:"; - for (size_t index = 0; index < kernel_info_list.size(); ++index) { - MS_LOG(WARNING) << "Kernel [" << index << "] :" << kernel_info_list[index]->ToString(); - } - for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) { - MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index) - << "] :" << aicpu_kernel_info_list[index]->ToString(); - } - if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) { - auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); - // Set format and data type for input tensor. - SetTensorDeviceInfo(*selected_kernel_info, kernel_node); - } else { - MS_LOG(WARNING) << " <<<"; - MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() - << "] cannot find valid kernel info, not supported the type:" << buffer.str() - << ", please refer to the supported dtypes in candidates kernel info list"; - } - } - return select_status; -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.h b/mindspore/ccsrc/device/ascend/kernel_select_ascend.h deleted file mode 100644 index 7b7a7b9fb9..0000000000 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ -#include "ir/anf.h" -#include "kernel/kernel_build_info.h" -namespace mindspore { -namespace device { -namespace ascend { -enum KernelSelectStatus { - kNoMatched = -1, - kStatusAllMatched = 0, - kStatusReducePrecision = 1, - kStatusRaisePrecision = 2, -}; -KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, - KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); -void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node); -void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph); -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ diff --git a/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc b/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc deleted file mode 100644 index db31460d31..0000000000 --- a/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc +++ /dev/null @@ -1,531 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/kernel_select_ascend.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "ir/func_graph.h" -#include "kernel/common_utils.h" -#include "kernel/kernel_query.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace { -// sort format according the number of occurrences. -bool cmp_format_num(const std::pair &a, const std::pair &b) { - if (a.second != b.second) { - return a.second > b.second; - } else if (a.first == kOpFormat_DEFAULT) { - return a.second + 1 > b.second; - } else if (b.first == kOpFormat_DEFAULT) { - return a.second > b.second + 1; - } - return a.second > b.second; -} - -TypeId GetPrimitivePrecision(const CNodePtr &cnode) { - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - - TypeId except_type = kTypeUnknown; - if (primitive->GetAttr(kAttrFixPrecision) != nullptr) { - auto strExceptDtype = GetValue(primitive->GetAttr(kAttrFixPrecision)); - if (strExceptDtype == "float16") { - except_type = kNumberTypeFloat16; - } else if (strExceptDtype == "float32") { - except_type = kNumberTypeFloat32; - } else { - MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got" << strExceptDtype; - } - } - - return except_type; -} -} // namespace - -void ResetKernelBuildInfo(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t input_index = 0; input_index < input_num; ++input_index) { - auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); - MS_EXCEPTION_IF_NULL(input_kernel_node); - auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); - if (!kernel::IsWeightBoundary(kernel_with_index.first)) { - continue; - } - // reset format and dtype. - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - builder.SetOutputsDeviceType(std::vector{kTypeUnknown}); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get()); - } -} - -void UpdateKernelInfo(const std::vector &node_list) { - for (size_t i = 0; i < node_list.size(); ++i) { - // select nodes in subgraph. - auto anf_node = node_list[i]; - MS_EXCEPTION_IF_NULL(anf_node); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto fix_precision_type = GetPrimitivePrecision(cnode); - if (fix_precision_type != kTypeUnknown) { - std::vector> kernel_info_list; - kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL); - - for (size_t index = 0; index < kernel_info_list.size(); ++index) - // only math the first input - if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type && - kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) && - AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) { - auto selected_kernel_info_ptr = kernel_info_list[index]; - ResetKernelBuildInfo(cnode); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get()); - SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode); - break; - } - } - } -} - -bool CanConvertDefaultShapeToNZ(const std::vector &shape) { - for (size_t i = 1; i <= shape.size(); ++i) { - if (i > 2) { - break; - } - if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) { - return false; - } - } - return true; -} - -std::vector DefaultToFracNZAxis(const std::vector &ori_shape, const std::vector &axis) { - std::vector frac_nz_axis = axis; - auto shape_len = ori_shape.size(); - for (size_t i = 0; i < axis.size(); ++i) { - auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len; - if (axis_idx == shape_len - 1) { - frac_nz_axis[i] = axis_idx - 1; - frac_nz_axis.push_back(axis_idx + 2); - } else if (axis_idx == shape_len - 2) { - frac_nz_axis[i] = axis_idx + 1; - frac_nz_axis.push_back(axis_idx + 2); - } else { - frac_nz_axis[i] = axis_idx; - } - } - return frac_nz_axis; -} - -std::vector GetReducedFracNZShape(const std::vector &ori_shape, const std::vector &axis, - bool keep_dims) { - std::vector result; - std::set positive_idx; - for (const auto &a : axis) { - positive_idx.insert(a >= 0 ? a : ori_shape.size() + a); - } - for (size_t i = 0; i < ori_shape.size(); ++i) { - if (positive_idx.count(i) == 0) { - result.push_back(ori_shape[i]); - } else if (keep_dims) { - result.push_back(1); - } - } - return result; -} - -void UpdateFracNZReduceOp(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0); - if (input_format == kOpFormat_FRAC_NZ) { - // Clone primitive to modify it - auto prim = GetCNodePrimitive(cnode); - auto new_prim = std::make_shared(*prim); - auto new_prim_node = NewValueNode(new_prim); - cnode->set_input(0, new_prim_node); - - auto axis_value = new_prim->GetAttr(kAttrAxis); - std::vector default_axis; - if (axis_value->isa()) { - auto value_list = dyn_cast(axis_value); - for (const auto &item : value_list->value()) { - if (item->isa()) { - default_axis.push_back(GetValue(item)); - } - } - } else if (axis_value->isa()) { - auto value_tuple = dyn_cast(axis_value); - for (const auto &item : value_tuple->value()) { - if (item->isa()) { - default_axis.push_back(GetValue(item)); - } - } - } else { - MS_LOG(ERROR) << "Axis attr type is not correct!"; - } - auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); - std::vector frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis); - AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue>(frac_nz_axis), cnode); - auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0); - if (output_shape.size() == 1) { - AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue(true), cnode); - } - } -} - -void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(default_format); - MS_EXCEPTION_IF_NULL(use_same_format); - std::unordered_map all_input_formats; - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; ++i) { - auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_kernel_node); - if (!input_kernel_node->isa()) { - ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)]; - continue; - } - auto para = input_kernel_node->cast(); - if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { - ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)]; - continue; - } - *use_same_format = false; - } - - if (all_input_formats.empty()) { - // all inputs are parameter. - *default_format = kOpFormat_NC1HWC0; - } else { - std::vector> pairs; - for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) { - pairs.push_back(std::make_pair(iter->first, iter->second)); - } - - std::sort(pairs.begin(), pairs.end(), cmp_format_num); - *default_format = pairs.begin()->first; - } - - for (size_t i = 0; i < input_num; ++i) { - auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_kernel_node); - if (!input_kernel_node->isa() || - AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) { - continue; - } - auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0); - if (weight_infer_shape.size() < 2 && *default_format == kOpFormat_FRAC_NZ) { - *default_format = kOpFormat_DEFAULT; - *use_same_format = true; - break; - } - } -} - -void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector &input_list, - const std::string &default_format, bool use_same_format, - std::vector *graph_input_format, std::vector *graph_input_type) { - MS_EXCEPTION_IF_NULL(graph_input_format); - MS_EXCEPTION_IF_NULL(graph_input_type); - // We set same format to all inputs of graph kernel subgraph, and process this latter. - // We set dtype to inputs of graph kernel subgraph same as infer dtypes. - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; ++i) { - auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_kernel_node); - if (use_same_format) { - bool can_convert = true; - if (default_format == kOpFormat_FRAC_NZ) { - auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); - if (!CanConvertDefaultShapeToNZ(infer_shape)) { - MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead"; - can_convert = false; - } - } - if (can_convert) { - graph_input_format->push_back(default_format); - } else { - graph_input_format->push_back(kOpFormat_DEFAULT); - } - graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i)); - continue; - } - - if (!input_kernel_node->isa()) { - // subgraph parameter from output of other nodes. - graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)); - graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i)); - continue; - } - - auto para = input_kernel_node->cast(); - MS_EXCEPTION_IF_NULL(para); - if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { - // parameter already selected. - graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0)); - graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0)); - continue; - } - - // weight parameter. - graph_input_format->push_back(default_format); - graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)); - } - - for (size_t i = 0; i < input_num; ++i) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - std::vector outputs_format = {(*graph_input_format)[i]}; - std::vector outputs_device_type = {(*graph_input_type)[i]}; - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_device_type); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); - } -} - -void UpdateEquivFormat(const std::vector> &output_index, - const std::vector &node_list, const FuncGraphPtr &func_graph, - const FuncGraphManagerPtr &mng) { - MS_EXCEPTION_IF_NULL(mng); - for (size_t i = 0; i < node_list.size(); ++i) { - // select nodes in subgraph. - auto anf_node = node_list[i]; - MS_EXCEPTION_IF_NULL(anf_node); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - cnode->set_kernel_info(std::make_shared()); - SelectKernelInfo(cnode, KernelType::AKG_KERNEL); - // Update ReduceSum - if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) { - continue; - } - UpdateFracNZReduceOp(cnode); - // If ReduceSum's output is 1d and not Default format, convert it to Default format - auto out_format = AnfAlgo::GetOutputFormat(cnode, 0); - if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) { - continue; - } - auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0); - // Insert EquivFormat node, then select kernel info again - std::vector trans_inputs; - trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat)); - trans_inputs.push_back(cnode); - CNodePtr trans_node = func_graph->NewCNode(trans_inputs); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)}, - {AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get()); - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue>({"x"}), trans_node); - - if (trans_node->kernel_info() == nullptr) { - trans_node->set_kernel_info(std::make_shared()); - } - SelectKernelInfo(trans_node, KernelType::AKG_KERNEL); - mng->Replace(cnode, trans_node); - } -} - -void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector &input_list, - const FuncGraphManagerPtr &mng, const std::string &default_format, - std::vector *graph_input_format, std::vector *graph_input_type, - std::vector *need_update) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(mng); - MS_EXCEPTION_IF_NULL(graph_input_format); - MS_EXCEPTION_IF_NULL(graph_input_type); - MS_EXCEPTION_IF_NULL(need_update); - // check graph input format and dtype use inner ops. - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (graph_input_format->size() != input_num || graph_input_type->size() != input_num || - need_update->size() != input_num) { - MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() - << "], [" << graph_input_format->size() << "] != [" << input_num << "]"; - } - auto &node_users = mng->node_users(); - for (size_t i = 0; i < input_num; ++i) { - auto &input = input_list[i]; - auto iter = node_users.find(input); - if (iter == node_users.end() || iter->second.empty()) { - continue; - } - for (auto &node_user : iter->second) { - if (node_user.first->kernel_info() == nullptr || - node_user.first->kernel_info()->select_kernel_build_info() == nullptr) { - // maybe not a real kernel. - continue; - } - auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1)); - if (user_format != (*graph_input_format)[i]) { - MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" - << kernel_node->DebugString() - << "] selected different format. we use defult: " << default_format; - (*graph_input_format)[i] = default_format; - (*need_update)[i] = true; - } - - if (kernel_node->input(i + 1)->isa() || - AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) { - continue; - } - - TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); - MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" - << kernel_node->DebugString() - << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); - (*graph_input_type)[i] = default_dtype; - (*need_update)[i] = true; - } - } -} - -void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector &node_list, - const std::vector &input_list, const std::vector &need_update, - const std::vector &graph_input_format, - const std::vector &graph_input_type) { - MS_EXCEPTION_IF_NULL(kernel_node); - // update graph input format and dtype use inner ops. - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (graph_input_format.size() != input_num || graph_input_type.size() != input_num || - need_update.size() != input_num) { - MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() - << "], [" << graph_input_format.size() << "] != [" << input_num << "]"; - } - for (size_t i = 0; i < input_num; ++i) { - if (!need_update[i]) { - continue; - } - - MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString() - << "] to: " << graph_input_format[i]; - MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString() - << "] to: " << TypeIdLabel(graph_input_type[i]); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - std::vector outputs_format = {graph_input_format[i]}; - std::vector outputs_device_type = {graph_input_type[i]}; - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_device_type); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); - } - - ResetKernelBuildInfo(kernel_node); - // select nodes in subgraph again. - for (size_t i = 0; i < node_list.size(); ++i) { - auto anf_node = node_list[i]; - MS_EXCEPTION_IF_NULL(anf_node); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode); - for (size_t j = 0; j < cnode_input_num; ++j) { - auto input_node = cnode->input(j + 1); - MS_EXCEPTION_IF_NULL(input_node); - if (!IsValueNode(input_node)) { - continue; - } - // reset format and dtype of const tensor. - builder.SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - builder.SetOutputsDeviceType(std::vector{kTypeUnknown}); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get()); - } - SelectKernelInfo(node_list[i]->cast(), KernelType::AKG_KERNEL); - } -} - -void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector> &output_index, - const std::vector &graph_input_format, - const std::vector &graph_input_type) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector graph_output_format; - std::vector graph_output_type; - for (size_t i = 0; i < output_index.size(); ++i) { - auto const &output = output_index[i]; - graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second)); - TypeId output_type(kTypeUnknown); - if (output.first->isa()) { - output_type = AnfAlgo::GetCNodeOutputPrecision(output.first); - } - if (output_type == kTypeUnknown) { - output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second); - } - graph_output_type.push_back(output_type); - } - - kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; - graph_info_builder.SetInputsFormat(graph_input_format); - graph_info_builder.SetInputsDeviceType(graph_input_type); - graph_info_builder.SetOutputsFormat(graph_output_format); - graph_info_builder.SetOutputsDeviceType(graph_output_type); - graph_info_builder.SetProcessor(kernel::Processor::AICORE); - graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); - graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); - auto graph_selected_info = graph_info_builder.Build(); - MS_EXCEPTION_IF_NULL(graph_selected_info); - AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get()); - SetTensorDeviceInfo(*graph_selected_info, kernel_node); -} - -void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(func_graph); - - // collect input info of funcgraph - std::vector node_list; - std::vector input_list; - std::vector output_list; - kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); - if (input_list.size() != kernel_node->inputs().size() - 1) { - MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode[" - << kernel_node->DebugString() << "], [%" << input_list.size() << "] != [" - << kernel_node->inputs().size() << "]"; - } - - std::string default_format; - bool use_same_format = true; - GetDefaultFormat(kernel_node, &default_format, &use_same_format); - MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format - << "] for ParameterWeight."; - - std::vector graph_input_format; - std::vector graph_input_type; - UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, - &graph_input_type); - - auto mng = func_graph->manager(); - if (mng == nullptr) { - mng = Manage(func_graph, true); - } - auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list); - UpdateEquivFormat(output_index, node_list, func_graph, mng); - node_list.clear(); - input_list.clear(); - output_list.clear(); - kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); - - // update graph input format and dtype use inner ops. - std::vector need_update(AnfAlgo::GetInputTensorNum(kernel_node), false); - CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type, - &need_update); - UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type); - - // set fix_precision for kernel when the me prim has fix_precision attr - UpdateKernelInfo(node_list); - - output_index = kernel::GetOutputIndex(node_list, input_list, output_list); - SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.cc b/mindspore/ccsrc/device/ascend/profiling/plugin_impl.cc deleted file mode 100644 index 7790107aa1..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/ascend/profiling/plugin_impl.h" -#include -#include "utils/log_adapter.h" -using std::string; - -namespace mindspore { -namespace device { -namespace ascend { -Reporter *PluginImpl::reporter_ = nullptr; - -PluginImpl::PluginImpl(const std::string &module) : module_(module) { MS_LOG(INFO) << "Create PluginImpl."; } - -int PluginImpl::Init(const Reporter *reporter) { - MS_LOG(INFO) << "PluginImpl init"; - MS_EXCEPTION_IF_NULL(reporter); - reporter_ = const_cast(reporter); - return 0; -} - -int PluginImpl::UnInit() { - MS_LOG(INFO) << " PluginImpl Uninit "; - reporter_ = nullptr; - return 0; -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc deleted file mode 100644 index a393409334..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/ascend/profiling/profiling_engine_impl.h" -#include "utils/log_adapter.h" -#include "device/ascend/profiling/plugin_impl.h" - -namespace mindspore { -namespace device { -namespace ascend { -PluginIntf *ProfilingEngineImpl::CreatePlugin() { - MS_LOG(INFO) << "Create Plugin."; - return new (std::nothrow) PluginImpl("Framework"); -} - -int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { - if (plugin != nullptr) { - delete plugin; - plugin = nullptr; - } - return 0; -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc deleted file mode 100644 index a2fe5b852d..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc +++ /dev/null @@ -1,207 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/profiling/profiling_manager.h" -#include -#include -#include "securec/include/securec.h" -#include "./prof_mgr_core.h" -#include "device/ascend/profiling/plugin_impl.h" -#include "device/ascend/profiling/profiling_engine_impl.h" -#include "utils/log_adapter.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "utils/convert_utils.h" -#include "runtime/base.h" - -namespace mindspore { -namespace device { -namespace ascend { -ProfilingManager &ProfilingManager::GetInstance() { - static ProfilingManager inst; - return inst; -} - -ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { - engine_0_ = std::make_shared(); -} - -uint64_t ProfilingManager::GetJobId() const { - const char *job_id = std::getenv("JOB_ID"); - return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); -} - -bool ProfilingManager::ReportProfilingData(const map &op_taskId_map) const { - if (!IsProfiling()) { - MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; - return false; - } - if (op_taskId_map.empty()) { - MS_LOG(WARNING) << "op_taskId_map is empty."; - return false; - } - auto reporter = PluginImpl::GetPluginReporter(); - if (reporter == nullptr) { - MS_LOG(ERROR) << "No profiling data report!"; - return false; - } - MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); - - Msprof::Engine::ReporterData reporter_data = {}; - for (const auto &iter : op_taskId_map) { - auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; - reporter_data.deviceId = UintToInt(device_id_); - reporter_data.data = (unsigned char *)(const_cast(data.c_str())); - reporter_data.dataLen = data.size(); - auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); - if (ret != 0) { - MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; - return false; - } - ret = reporter->Report(&reporter_data); - if (ret != 0) { - MS_LOG(ERROR) << "reporter data fail, errorno(" << ret << ")"; - return false; - } - } - return true; -} - -static std::vector Split(const std::string &str, const char delim) { - std::vector elems; - - if (str.empty()) { - elems.emplace_back(""); - return elems; - } - - std::stringstream ss(str); - std::string item; - - while (getline(ss, item, delim)) { - elems.push_back(item); - } - auto str_size = str.size(); - if (str_size > 0 && str[str_size - 1] == delim) { - elems.emplace_back(""); - } - - return elems; -} - -bool ProfilingManager::StartupProfiling(uint32_t device_id) { - auto is_profiling = IsProfiling(); - if (!is_profiling) { - MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; - return true; - } - device_id_ = device_id; - // register Framework to profiling - int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); - if (result != 0) { - MS_LOG(ERROR) << "Register profiling Engine failed."; - return false; - } - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - const string prof_options_str = context->profiling_options(); - std::vector opts = Split(prof_options_str, ':'); - if (opts.empty()) { - MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!"; - return true; - } - // current one docker only use one device` - nlohmann::json p_device; - // JOBID - auto job_id = GetJobId(); - p_device["jobID"] = std::to_string(job_id); - // device_id - p_device["deviceID"] = std::to_string(device_id); - // features:'training_trace', 'task_trace' etc - nlohmann::json features; - for (std::vector::size_type i = 0; i < opts.size(); i++) { - nlohmann::json f; - f["name"] = opts[i]; - features[i] = f; - } - p_device["features"] = features; - // only one device, but sProfMgrStartUp API require for device list - nlohmann::json devices; - devices[0] = p_device; - nlohmann::json startCfg; - startCfg["startCfg"] = devices; - - if (!ProfStartUp(NOT_NULL(&startCfg))) { - MS_LOG(ERROR) << "ProfMgrStartUp failed."; - return false; - } - return true; -} - -bool ProfilingManager::ProfStartUp(NotNull startCfg) { - // convert json to string - std::stringstream ss; - ss << *startCfg; - std::string cfg = ss.str(); - MS_LOG(INFO) << "profiling config " << cfg; - auto ret = rtProfilerStart(); - if (ret != RT_ERROR_NONE) { - MS_LOG(INFO) << "Call rtProfilerStart failed, ret:" << ret; - return false; - } - - // call profiling startup API - ProfMgrCfg prof_cfg = {cfg}; - prof_handle_ = ProfMgrStartUp(&prof_cfg); - if (prof_handle_ == nullptr) { - MS_LOG(ERROR) << "Startup profiling failed."; - return false; - } - return true; -} - -bool ProfilingManager::StopProfiling() { - MS_LOG(INFO) << "StopProfiling"; - if (!IsProfiling()) { - MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; - return true; - } - Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - if (reporter != nullptr) { - MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); - } - - auto rt_ret = rtProfilerStop(); - if (rt_ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Call rtProfilerStop failed"; - return false; - } - - if (prof_handle_ != nullptr) { - int result = ProfMgrStop(prof_handle_); - if (result != 0) { - MS_LOG(ERROR) << "ProfMgr stop return fail:" << result << "."; - prof_handle_ = nullptr; - return false; - } - prof_handle_ = nullptr; - } - - return true; -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc deleted file mode 100644 index 17ac4c4530..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc +++ /dev/null @@ -1,367 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/profiling/reporter/graph_desc_reporter.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "kernel/kernel.h" -#include "device/ascend/profiling/profiling_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "utils/utils.h" -#include "device/ascend/profiling/reporter/task_desc_reporter.h" -#include "utils/context/ms_context.h" -#include "device/ascend/profiling/reporter/point_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -constexpr uint32_t kMaxProfilingNodeNum = 100; -constexpr char kCustomNode[] = "PROFILING_CUSTOM_"; -constexpr char kFpStartNode[] = "PROFILING_FP_START"; -constexpr char kBpEndNode[] = "PROFILING_BP_END"; -constexpr char kIterEndNode[] = "PROFILING_ITER_END"; -// PROFILING_CUSTOM_LOGID_START 3 -constexpr uint64_t kProfilingFpStartLogId = 1; -constexpr uint64_t kProfilingBpEndLogId = 2; -constexpr uint64_t kProfilingIterEndLogId = 255; -std::map> ProfilingUtils::graph_profiling_cnode_; -std::map> ProfilingUtils::graph_kernel_name_; -std::map>> ProfilingUtils::graph_point_; -uint32_t ProfilingUtils::custom_node_index_ = 1; - -ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNull graph_ptr) { - MS_LOG(INFO) << "get env start"; - custom_node_index_ = 1; - auto &cnode_exec_order = graph_ptr->execution_order(); - ProfilingTraceInfo profiling_trace; - profiling_trace.trace_begin = GetTraceBegin(cnode_exec_order); - profiling_trace.trace_bp_end = GetTraceBpEnd(cnode_exec_order); - profiling_trace.trace_netoutput = GetTraceNetoutput(cnode_exec_order); - - for (uint32_t i = 1; i <= kMaxProfilingNodeNum; ++i) { - std::string env_str = std::string(kCustomNode) + std::to_string(i); - const char *node_full_name = std::getenv(env_str.c_str()); - if (node_full_name == nullptr) { - break; - } - MS_LOG(INFO) << "Get profiling node:" << node_full_name; - profiling_trace.trace_custom_node.insert(node_full_name); - } - MS_LOG(INFO) << "get env end"; - GetTraceHccl(cnode_exec_order, NOT_NULL(&profiling_trace)); - - MS_LOG(INFO) << "[profiling]trace_begin:" << profiling_trace.trace_begin - << " trace_bp_end:" << profiling_trace.trace_bp_end - << " trace_netoutput:" << profiling_trace.trace_netoutput; - return profiling_trace; -} - -void ProfilingUtils::GetTraceHccl(const std::vector &cnode_exec_order, - NotNull profiling_trace) { - for (const auto &node : cnode_exec_order) { - if (AnfAlgo::IsCommunicationOp(node)) { - MS_EXCEPTION_IF_NULL(node); - profiling_trace->trace_custom_node.insert(node->fullname_with_scope()); - MS_LOG(INFO) << "[profiling]Get hccl node:" << node->fullname_with_scope(); - } - } -} - -std::string ProfilingUtils::GetTraceBegin(const std::vector &cnode_exec_order) { - const char *trace_begin = std::getenv(kFpStartNode); - if (trace_begin != nullptr) { - return std::string(trace_begin); - } - - std::string fp_start_str; - std::set getnext_outputs; - GetCNodeOutputRealNode(kGetNextOpName, cnode_exec_order, NOT_NULL(&getnext_outputs)); - if (getnext_outputs.empty()) { - auto first_node = cnode_exec_order.front(); - MS_EXCEPTION_IF_NULL(first_node); - fp_start_str = first_node->fullname_with_scope(); - } else { - for (auto &cnode : cnode_exec_order) { - if (getnext_outputs.count(cnode->fullname_with_scope()) != 0) { - fp_start_str = cnode->fullname_with_scope(); - break; - } - } - } - return fp_start_str; -} - -void ProfilingUtils::GetCNodeOutputRealNode(const std::string &node_name, const std::vector &cnode_exec_order, - NotNull *> getnext_outputs) { - for (const auto &cnode : cnode_exec_order) { - MS_EXCEPTION_IF_NULL(cnode); - for (const auto &input : cnode->inputs()) { - auto prev_cnode = AnfAlgo::VisitKernel(input, 0); - if (!prev_cnode.first->isa()) { - continue; - } - if (AnfAlgo::GetCNodeName(prev_cnode.first) == node_name) { - getnext_outputs->insert(cnode->fullname_with_scope()); - MS_LOG(INFO) << "Find GetNext Output CNode:" << cnode->fullname_with_scope(); - } - } - } - if (getnext_outputs->empty()) { - MS_LOG(WARNING) << "GetNext not found"; - } -} - -std::string ProfilingUtils::GetTraceBpEnd(const std::vector &cnode_exec_order) { - const char *trace_bp_end = std::getenv(kBpEndNode); - - if (trace_bp_end != nullptr) { - return std::string(trace_bp_end); - } - std::string bp_end_str; - // Contain hccl kernel - auto iter = cnode_exec_order.rbegin(); - while (iter != cnode_exec_order.rend()) { - if (AnfAlgo::IsCommunicationOp(*iter)) { - // store communication op input nodes' name - std::set ar_input_node_names; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(*iter); ++i) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(*iter, i); - auto input_node = input_node_with_index.first; - ar_input_node_names.insert(input_node->fullname_with_scope()); - } - // start from previous node - ++iter; - // find input names in previous node - while (iter != cnode_exec_order.rend()) { - if (ar_input_node_names.find((*iter)->fullname_with_scope()) != ar_input_node_names.end()) { - bp_end_str = (*iter)->fullname_with_scope(); - break; - } - ++iter; - } - break; - } - ++iter; - } - - if (bp_end_str.empty()) { - bp_end_str = GetGraphLastTbeKernelName(cnode_exec_order); - } - return bp_end_str; -} - -std::string ProfilingUtils::GetGraphLastTbeKernelName(const std::vector &cnode_exec_order) { - std::string last_tbe_kernel_name; - // find last tbe_kernel - for (auto iter = cnode_exec_order.rbegin(); iter != cnode_exec_order.rend(); ++iter) { - if (AnfAlgo::GetKernelType(*iter) == TBE_KERNEL) { - last_tbe_kernel_name = (*iter)->fullname_with_scope(); - break; - } - } - if (last_tbe_kernel_name.empty()) { - MS_LOG(WARNING) << "tbe kernel not found in graph"; - } - return last_tbe_kernel_name; -} - -std::string ProfilingUtils::GetTraceNetoutput(const std::vector &cnode_exec_order) { - const char *trace_netoutput = std::getenv(kIterEndNode); - return trace_netoutput == nullptr ? GetGraphLastTbeKernelName(cnode_exec_order) : std::string(trace_netoutput); -} - -NotNull ProfilingUtils::CreateProfilingCNode(const ProfilingContent &profiling_content, - NotNull graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); - selected_kernel_builder.SetInputsDeviceType({TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); - selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - abstract::AbstractBasePtr type_none_abstract = std::make_shared(); - auto primitive = std::make_shared(ProfilingUtils::kProfiling); - std::vector inputs; - inputs.emplace_back(NewValueNode(primitive)); - CNodePtr cnode_ptr = graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(cnode_ptr); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), cnode_ptr.get()); - cnode_ptr->set_abstract(type_none_abstract); - // set attr - ValuePtr notify_value = MakeValue(profiling_content.notify); - ValuePtr trace_id_value = MakeValue(profiling_content.profiler_trace_id); - ValuePtr flags_value = MakeValue(profiling_content.flags); - AnfAlgo::SetNodeAttr(ProfilingUtils::kNotify, notify_value, cnode_ptr); - AnfAlgo::SetNodeAttr(ProfilingUtils::kProfilerTraceId, trace_id_value, cnode_ptr); - AnfAlgo::SetNodeAttr(ProfilingUtils::kFlags, flags_value, cnode_ptr); - return NOT_NULL(cnode_ptr); -} - -void ProfilingUtils::SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id) { - std::shared_ptr prof_desc_ptr = std::make_shared(node_name, point_id); - auto iter = graph_point_.find(graph_id); - if (iter == graph_point_.end()) { - std::vector> tmp_vect = {prof_desc_ptr}; - graph_point_.insert({graph_id, tmp_vect}); - } else { - iter->second.emplace_back(prof_desc_ptr); - } -} - -void ProfilingUtils::ProfilingTraceFpStart(const mindspore::AnfNodePtr &anf_node, - const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list) { - if (profiling_trace_info.trace_begin == anf_node->fullname_with_scope()) { - MS_LOG(INFO) << "Profiling Match FpStart:" << profiling_trace_info.trace_begin; - ProfilingTraceJobId(anf_node, graph_ptr, kernel_list); - ProfilingContent fp_profiling_content = {false, kProfilingFpStartLogId, 0}; - auto fp_profiling_node = CreateProfilingCNodeWithStream(anf_node, fp_profiling_content, graph_ptr); - kernel_list->emplace_back(fp_profiling_node); - // insert ProfDesc - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingFpStartLogId); - } -} - -void ProfilingUtils::ProfilingTraceJobId(const AnfNodePtr &anf_node, NotNull graph_ptr, - NotNull *> kernel_list) { - MS_LOG(INFO) << "Profiling Match start"; - auto job_id = ProfilingManager::GetInstance().GetJobId(); - ProfilingContent job_profiling_context = {false, job_id, 0}; - auto job_profiling_node = CreateProfilingCNodeWithStream(anf_node, job_profiling_context, graph_ptr); - kernel_list->emplace_back(job_profiling_node); -} - -CNodePtr ProfilingUtils::CreateProfilingCNodeWithStream(const mindspore::AnfNodePtr &anf_node, - const ProfilingContent &profiling_content, - NotNull graph_ptr) { - CNodePtr profiling_node = CreateProfilingCNode(profiling_content, graph_ptr); - AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), profiling_node.get()); - AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(anf_node), profiling_node.get()); - return profiling_node; -} - -void ProfilingUtils::ProfilingCustomOp(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto iter = profiling_trace_info.trace_custom_node.find(anf_node->fullname_with_scope()); - if (iter == profiling_trace_info.trace_custom_node.end()) { - return; - } - MS_LOG(INFO) << "Profiling Match CustomOp:" << anf_node->fullname_with_scope(); - // custom op profiling job start from 3. - auto custom_point_id = 2 * custom_node_index_ + 1; - ProfilingContent front_profiling_content = {false, custom_point_id, 0}; - CNodePtr front_node = CreateProfilingCNodeWithStream(anf_node, front_profiling_content, graph_ptr); - kernel_list->insert(kernel_list->end() - 1, front_node); - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id); - - ProfilingContent back_profiling_content = {false, custom_point_id + 1, 0}; - CNodePtr back_node = CreateProfilingCNodeWithStream(anf_node, back_profiling_content, graph_ptr); - kernel_list->insert(kernel_list->end(), back_node); - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id + 1); - ++custom_node_index_; -} - -void ProfilingUtils::ProfilingTraceBpEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list) { - MS_EXCEPTION_IF_NULL(anf_node); - if (profiling_trace_info.trace_bp_end == anf_node->fullname_with_scope()) { - MS_LOG(INFO) << "Profiling Match BpEnd:" << profiling_trace_info.trace_bp_end; - ProfilingContent bp_end_profiling_content = {false, kProfilingBpEndLogId, 0}; - CNodePtr bp_end_node = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); - kernel_list->emplace_back(bp_end_node); - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingBpEndLogId); - } -} - -void ProfilingUtils::ProfilingTraceEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto full_scope_name = anf_node->fullname_with_scope(); - if (profiling_trace_info.trace_netoutput == full_scope_name) { - MS_LOG(INFO) << "Profiling Match IterEnd:" << profiling_trace_info.trace_netoutput; - ProfilingContent bp_end_profiling_content = {true, kProfilingIterEndLogId, 0}; - CNodePtr bp_kernel_ptr = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); - kernel_list->emplace_back(bp_kernel_ptr); - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingIterEndLogId); - } -} - -void ProfilingUtils::SetGraphKernelName(uint32_t graph_id, const std::vector &kernel_names) { - auto ret = graph_kernel_name_.try_emplace(graph_id, kernel_names); - if (!ret.second) { - MS_LOG(ERROR) << "[profiling]graph " << graph_id << " kernel names already exist"; - } -} - -void ProfilingUtils::SetGraphProfilingCNode(uint32_t graph_id, const std::vector &profiling_cnode_list) { - auto ret = graph_profiling_cnode_.try_emplace(graph_id, profiling_cnode_list); - if (!ret.second) { - MS_LOG(ERROR) << "[profiling]graph " << graph_id << " profiling cnode list already exist"; - } -} - -bool ProfilingUtils::ValidComputeGraph(NotNull graph_ptr) { - for (const auto &node : graph_ptr->execution_order()) { - if (AnfAlgo::GetKernelType(node) == TBE_KERNEL) { - return true; - } - } - return false; -} - -void ProfilingUtils::ReportProfilingData(const std::vector &task_ids, const std::vector &stream_ids, - NotNull graph) { - if (!ValidComputeGraph(graph)) { - MS_LOG(WARNING) << "Not a valid compute graph:" << graph->graph_id(); - return; - } - - auto ret = graph_profiling_cnode_.find(graph->graph_id()); - if (ret == graph_profiling_cnode_.end()) { - MS_LOG(ERROR) << "Graph id not found"; - return; - } - - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second); - task_reporter.set_task_ids(task_ids); - task_reporter.set_stream_ids(stream_ids); - task_reporter.ReportData(); - - GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second); - graph_profiling_cnode_.erase(ret); - graph_reporter.ReportData(); - - // Report profiling point - auto point_iter = graph_point_.find(graph->graph_id()); - if (point_iter == graph_point_.end()) { - MS_LOG(ERROR) << "Graph id not found in graph_point"; - return; - } - PointReporter point_reporter(context->device_id(), "vm.point"); - for (const auto &point : point_iter->second) { - point_reporter.AddReportData(point); - } - point_reporter.ReportData(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h deleted file mode 100644 index a3c7739447..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include "session/kernel_graph.h" -#include "utils/contract.h" -#include "device/ascend/profiling/reporter/profiling_desc.h" - -namespace mindspore { -namespace device { -namespace ascend { -struct ProfilingTraceInfo { - // execute order's first execute op(like: Cast or Four2Five ...), except tdt op(GetNext ...) - std::string trace_begin; - // get first net_output(apply kernel) from graph outputs: fp ->net_output<- bp - std::string trace_bp_end; - // execute order's end execute (like: Conv2DBackpropFilter) - std::string trace_netoutput; - - // profiling specific op, such as AllReduce; - std::set trace_custom_node; - - // 1. insert profiling_trace_begin if profiling_trace_bp_end is not empty. - // 2. op lanuch get task info with callback func. - // 3. insert profiling_trace_bp_end. - // 4. insert profiling_trace_net_output if profiling_trace_bp_end is not empty. - - bool IsValid() const { return !(trace_begin.empty() || trace_netoutput.empty()); } -}; - -struct ProfilingContent { - // true -send data from device to host and finish profiling - bool notify; - uint64_t profiler_trace_id; - uint32_t flags; -}; - -class ProfilingUtils { - public: - ProfilingUtils() = default; - ~ProfilingUtils() = default; - - // Insert job_id profiling node and fp_start profiling node. - // Job_id is got from envs, which shound be a number greater than 255 - // Fp_start node should been inserted in the start of a network, and the log_id is hard code to 1. - static void ProfilingTraceFpStart(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list); - - static void ProfilingTraceJobId(const AnfNodePtr &anf_node, NotNull graph_ptr, - NotNull *> kernel_list); - - // Insert net output profiling node, which tells the device to stop profiling. - // The notify in struct ProfilingContent should be 'true', which tells the device to send data to host. - static void ProfilingTraceEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list); - - // Insert bp_end profiling node, which should been inserted after the last backpropagation CNode in the network. - static void ProfilingTraceBpEnd(const mindspore::AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list); - - // Mapping graph id and the kernels' name in the graph - static void SetGraphProfilingCNode(uint32_t graph_id, const std::vector &profiling_cnode_list); - - static void SetGraphKernelName(uint32_t graph_id, const std::vector &kernel_names); - - // Mapping task_id and kernel name for device to generate the time cost of specific kernel. - // Device calculate the time cost of the task which is marked by task id. - // But we need data of (kernel name , time cost) - static void ReportProfilingData(const std::vector &task_ids, const std::vector &stream_ids, - NotNull graph); - - // Get profiling trace point from envs. - // export PROFILING_FP_START='full name of the first cnode to execute' - // export PROFILING_BP_END='full name of the last backpropagation cnode to execute' - // export PROFILING_ITER_END='full name of last cnode in graph to execute' - // And other cnode, like AllReduce, export PROFILING_CUSTOM_1='full name of AllReduce cnode' - // GetNext, export PROFIFLING_CUSTOM_2='full name fo GetNext cnode' - // The variable i in PROFILING_CUSTOM_i should start from 1 without interruption. - static ProfilingTraceInfo GetProfilingTraceFromEnv(NotNull graph_ptr); - - // Insert two profiling trace points, one in front and one behind - static void ProfilingCustomOp(const mindspore::AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list); - - static std::map> graph_kernel_name() { return graph_kernel_name_; } - - inline static constexpr char kProfiling[] = "Profiling"; - inline static constexpr char kNotify[] = "notify"; - inline static constexpr char kProfilerTraceId[] = "profiler_trace_id"; - inline static constexpr char kFlags[] = "flags"; - - private: - static NotNull CreateProfilingCNode(const ProfilingContent &profiling_content, - NotNull graph_ptr); - static CNodePtr CreateProfilingCNodeWithStream(const AnfNodePtr &anf_node, const ProfilingContent &profiling_content, - NotNull graph_ptr); - static std::string GetTraceBegin(const std::vector &cnode_exec_order); - static std::string GetTraceBpEnd(const std::vector &cnode_exec_order); - static std::string GetTraceNetoutput(const std::vector &cnode_exec_order); - static std::string GetGraphLastTbeKernelName(const std::vector &cnode_exec_order); - static void GetTraceHccl(const std::vector &cnode_exec_order, - NotNull profiling_trace); - static void GetCNodeOutputRealNode(const std::string &node_name, const std::vector &cnode_exec_order, - NotNull *> getnext_outputs); - - static bool ValidComputeGraph(NotNull graph_ptr); - static void SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id); - - // graph id --> (kernel name list) - static std::map> graph_profiling_cnode_; - static std::map> graph_kernel_name_; - static std::map>> graph_point_; - static uint32_t custom_node_index_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.cc deleted file mode 100644 index cf80c07ca9..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "device/ascend/profiling/reporter/desc_reporter.h" -#include "device/ascend/profiling/plugin_impl.h" -#include "utils/log_adapter.h" - -constexpr size_t kReportMaxLen = 2048; - -namespace mindspore { -namespace device { -namespace ascend { -DescReporter::~DescReporter() = default; - -void DescReporter::ReportByLine(const std::string &data, const std::string &file_name) const { - auto reporter = PluginImpl::GetPluginReporter(); - MS_EXCEPTION_IF_NULL(reporter); - - auto tot_size = data.size(); - size_t cur_size = 0; - while (cur_size < tot_size) { - size_t remain_size = tot_size - cur_size; - size_t report_size = std::min(remain_size, kReportMaxLen); - - Msprof::Engine::ReporterData report_data{}; - report_data.deviceId = device_id_; - report_data.dataLen = report_size; - report_data.data = (unsigned char *)data.c_str() + cur_size; - auto ret = memcpy_s(report_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, file_name.c_str(), file_name.length()); - if (ret != 0) { - MS_LOG(EXCEPTION) << "Memcpy_s report data tag failed"; - } - auto report_ret = reporter->Report(&report_data); - if (report_ret != 0) { - MS_LOG(EXCEPTION) << "Report data failed"; - } - if (report_size == 0) { - MS_LOG(WARNING) << "Report_size is 0"; - break; - } - cur_size += report_size; - } -} - -void DescReporter::ReportAllLine() { - for (const auto &desc : prof_desc_list_) { - auto data = desc->ToString(); - ReportByLine(data, file_name_); - } -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.h deleted file mode 100644 index c8e1b3ed62..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ - -#include -#include -#include -#include -#include "toolchain/prof_reporter.h" -#include "device/ascend/profiling/reporter/profiling_desc.h" -#include "utils/contract.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace device { -namespace ascend { -class DescReporter { - public: - virtual ~DescReporter() = 0; - DescReporter(int device_id, std::string file_name) : device_id_(device_id), file_name_(std::move(file_name)) {} - - virtual void ReportData() = 0; - - protected: - void ReportByLine(const std::string &data, const std::string &file_name) const; - void ReportAllLine(); - - int device_id_; - std::string file_name_; - std::vector> prof_desc_list_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.cc deleted file mode 100644 index 1f2d1570bb..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "device/ascend/profiling/reporter/graph_desc_reporter.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace ascend { -void GraphDescReporter::ReportData() { - for (const auto &node : cnode_list_) { - if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { - MS_LOG(WARNING) << "Skip non tbe kernel"; - continue; - } - std::vector input_data_list; - std::vector output_data_list; - MS_EXCEPTION_IF_NULL(node); - auto op_name = node->fullname_with_scope(); - auto op_type = AnfAlgo::GetCNodeName(node); - auto input_size = AnfAlgo::GetInputTensorNum(node); - for (size_t i = 0; i < input_size; ++i) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); - auto input_node = input_node_with_index.first; - auto input_index = input_node_with_index.second; - DataElement element{}; - element.index_ = i; - element.data_type_ = AnfAlgo::GetOutputDeviceDataType(input_node, input_index); - element.data_format_ = AnfAlgo::GetOutputFormat(input_node, input_index); - element.data_shape_ = AnfAlgo::GetOutputDeviceShape(input_node, input_index); - input_data_list.emplace_back(element); - } - - auto output_size = AnfAlgo::GetOutputTensorNum(node); - for (size_t i = 0; i < output_size; ++i) { - DataElement element{}; - element.index_ = i; - element.data_type_ = AnfAlgo::GetOutputDeviceDataType(node, i); - element.data_format_ = AnfAlgo::GetOutputFormat(node, i); - element.data_shape_ = AnfAlgo::GetOutputDeviceShape(node, i); - output_data_list.emplace_back(element); - } - - auto graph_desc = std::make_shared(op_name, op_type, input_data_list, output_data_list); - prof_desc_list_.emplace_back(graph_desc); - } - ReportAllLine(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h deleted file mode 100644 index 10f78092f2..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ - -#include -#include -#include -#include "device/ascend/profiling/reporter/desc_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -class GraphDescReporter : public DescReporter { - public: - GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector cnode_list) - : DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {} - ~GraphDescReporter() override = default; - void ReportData() override; - - private: - std::vector cnode_list_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.cc deleted file mode 100644 index 0024ab9c22..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/profiling/reporter/point_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -void PointReporter::ReportData() { ReportAllLine(); } - -void PointReporter::AddReportData(const std::shared_ptr &prof_desc) { - prof_desc_list_.emplace_back(prof_desc); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.h deleted file mode 100644 index ae12672df6..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ - -#include -#include -#include "device/ascend/profiling/reporter/desc_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -class PointReporter : public DescReporter { - public: - PointReporter(uint32_t device_id, const std::string &file_name) : DescReporter(device_id, file_name) {} - ~PointReporter() override = default; - void ReportData() override; - void AddReportData(const std::shared_ptr &prof_desc); -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.cc deleted file mode 100644 index 082cb81e42..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.cc +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include "device/ascend/profiling/reporter/profiling_desc.h" - -namespace mindspore { -namespace device { -namespace ascend { -std::string TaskDesc::ToString() { - std::string out = op_name_; - out.append(" ") - .append(std::to_string(block_dim_)) - .append(" ") - .append(std::to_string(task_id_)) - .append(" ") - .append(std::to_string(stream_id_)) - .append("\n"); - return out; -} - -std::string GraphDesc::ToString() { - std::string desc; - desc.append("op_name:").append(op_name_).append(" op_type:").append(op_type_); - int input_id = 0; - for (const auto &element : input_data_list_) { - desc.append(" input_id:") - .append(std::to_string(input_id++)) - .append(" input_format:") - .append(element.data_format_) - .append(" input_data_type:") - .append(std::to_string(element.data_type_)) - .append(" input_shape:") - .append(DataShapeToString(element.data_shape_)); - } - - input_id = 0; - for (const auto &element : output_data_list_) { - desc.append(" output_id:") - .append(std::to_string(input_id++)) - .append(" output_format:") - .append(element.data_format_) - .append(" output_data_type:") - .append(std::to_string(element.data_type_)) - .append(" output_shape:") - .append((DataShapeToString(element.data_shape_))); - } - - desc.append("\n"); - - return desc; -} - -std::string PointDesc::ToString() { - std::string desc; - desc.append(std::to_string(point_id_)).append(" ").append(op_name_).append("\n"); - return desc; -} - -std::string GraphDesc::DataShapeToString(const std::vector &shape) { - std::ostringstream oss; - oss << "\""; - if (!shape.empty()) { - std::copy(shape.begin(), shape.end() - 1, std::ostream_iterator(oss, ",")); - oss << shape.back(); - } - oss << "\""; - return oss.str(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.cc deleted file mode 100644 index 0bd66e31ef..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "device/ascend/profiling/reporter/task_desc_reporter.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/ascend_kernel_mod.h" - -namespace mindspore { -namespace device { -namespace ascend { -void TaskDescReporter::ReportData() { - MS_LOG(INFO) << "cnode_list.size()=" << cnode_list_.size() << " task_ids_.size()=" << task_ids_.size(); - if (cnode_list_.size() != task_ids_.size()) { - MS_LOG(ERROR) << "cnode list size not equal task ids size"; - return; - } - - size_t task_index = 0; - for (const auto &node : cnode_list_) { - if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { - MS_LOG(WARNING) << "Skip non tbe kernel"; - ++task_index; - continue; - } - auto kernel_mod = AnfAlgo::GetKernelMod(node); - auto ascend_kernel_mod = dynamic_cast(kernel_mod); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(ascend_kernel_mod); - // Check task_id and stream_id valid - CheckStreamTaskValid(task_index, task_index); - auto desc_ptr = std::make_shared(node->fullname_with_scope(), task_ids_[task_index], - ascend_kernel_mod->block_dim(), stream_ids_[task_index]); - prof_desc_list_.emplace_back(desc_ptr); - ++task_index; - } - ReportAllLine(); -} - -void TaskDescReporter::CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id) { - if (task_id >= task_ids_.size() || stream_id >= stream_ids_.size()) { - MS_LOG(EXCEPTION) << "Index invalid. task_id:" << task_id << ", task_ids.size:" << task_ids_.size() - << ", stream_id:" << stream_id << ", stream_ids.size:" << stream_ids_.size(); - } -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.h deleted file mode 100644 index 087c691a5f..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ - -#include -#include -#include -#include "device/ascend/profiling/reporter/desc_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -class TaskDescReporter : public DescReporter { - public: - TaskDescReporter(int device_id, const std::string &file_name, std::vector cnode_list) - : DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {} - ~TaskDescReporter() override = default; - void ReportData() override; - void set_task_ids(const std::vector &task_ids) { task_ids_ = task_ids; } - void set_stream_ids(const std::vector &stream_ids) { stream_ids_ = stream_ids; } - - private: - std::vector task_ids_; - std::vector stream_ids_; - void CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id); - std::vector cnode_list_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc b/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc deleted file mode 100644 index 3faeefb820..0000000000 --- a/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/tasksink/runtime_utils.h" - -#include - -#include "hccl/hcom.h" -#include "utils/log_adapter.h" -#include "utils/utils.h" - -constexpr auto kHcomBroadcast = "hcom_broadcast_"; -constexpr auto kHcomAllGather = "hcom_all_gather_"; -constexpr auto kHcomAllReduce = "hcom_all_reduce_"; -constexpr auto kHcomReduceScatter = "hcom_reduce_scatter_"; -constexpr auto kUnderline = "_"; -namespace mindspore { -namespace device { -namespace ascend { -namespace tasksink { -bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { - hcclResult_t ret = hcom_bind_model(model, stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast(ret); - return false; - } - return true; -} - -bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { - hcclResult_t ret = hcom_unbind_model(model); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast(ret); - return false; - } - return true; -} - -bool RuntimeUtils::HcomDistribute(const std::shared_ptr &task_info, rtStream_t stream) { - MS_LOG(INFO) << "hccl distribute start"; - MS_EXCEPTION_IF_NULL(task_info); - hcclResult_t ret; - static uint32_t task_counter = 0; - auto hccl_group = task_info->group(); - if (task_info->hccl_type() == kBroadcastOpName) { - // call hcom broadcast interface to run op - const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0); - ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast(task_info->count()), - static_cast(task_info->data_type()), static_cast(task_info->root_id()), - hccl_group.c_str(), stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast(ret); - return false; - } - } else if (task_info->hccl_type() == kAllGatherOpName) { - // call hcom allgather interface to run op - const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0); - ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), - hccl_group.c_str(), stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret; - return false; - } - } else if (task_info->hccl_type() == kAllReduceOpName) { - // call hcom allreduce interface to run op - const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0); - ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), - static_cast(task_info->op_type()), hccl_group.c_str(), stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret; - return false; - } - } else if (task_info->hccl_type() == kReduceScatterOpName) { - // call hcom reducescatter interface to run op - const string tag_reduce_scatter = - kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0); - ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), - static_cast(task_info->op_type()), hccl_group.c_str(), stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret; - return false; - } - } - return true; -} -} // namespace tasksink -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc deleted file mode 100644 index e026459ae9..0000000000 --- a/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc +++ /dev/null @@ -1,199 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/tasksink/task_generator.h" - -#include -#include "kernel/task_stream.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "device/ascend/profiling/profiling_manager.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace tasksink { -bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *task_info_list, - uint32_t graph_id) { - MS_LOG(INFO) << "GenTasks start..."; - MS_EXCEPTION_IF_NULL(task_info_list); - // Traverse graph applykernel list and run - if (!LaunchAllKernel(anf_node_list, task_info_list, graph_id)) { - MS_LOG(ERROR) << "LaunchAllKernel failed"; - return false; - } - MS_LOG(INFO) << "GenTasks end..."; - return true; -} - -void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { - MS_EXCEPTION_IF_NULL(anf_node_ptr); - MS_EXCEPTION_IF_NULL(kernel_inputs); - // akg process - // set atomic clean addr - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, anf_node_ptr)) { - auto clean_output_indexs = AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicOutputIndexs); - auto graph = anf_node_ptr->func_graph(); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto node_users = manager->node_users(); - if (node_users[anf_node_ptr].empty()) { - MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty."; - } - auto depend_node = node_users[anf_node_ptr].pop().first; - if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) { - MS_LOG(EXCEPTION) << "Checking Depend node failed"; - } - if (node_users[depend_node].empty()) { - MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty."; - } - auto post_node = node_users[depend_node].pop().first; - for (auto index : clean_output_indexs) { - auto device_address = AnfAlgo::GetOutputAddr(post_node, index); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - input->size = device_address->size_; - kernel_inputs->push_back(input); - } - MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size(); - } -} - -void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { - MS_EXCEPTION_IF_NULL(anf_node_ptr); - MS_EXCEPTION_IF_NULL(kernel_inputs); - if (anf_node_ptr->inputs().size() != 2) { - LaunchAddrCleanAkgKernel(anf_node_ptr, kernel_inputs); - return; - } - MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); - auto pre_node = (anf_node_ptr->inputs()[1])->cast(); - // set clean output addr - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { - auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); - for (auto index : clean_output_indexs) { - auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(input->addr); - input->size = device_address->size_; - kernel_inputs->push_back(input); - } - MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); - } - // set clean workspace address - if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { - auto clean_workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); - for (const auto &index : clean_workspace_indexs) { - auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(workspace->addr); - workspace->size = device_address->size_; - kernel_inputs->push_back(workspace); - } - } - auto clear_mems = AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicAddMemSize); - if (kernel_inputs->size() != clear_mems.size()) { - MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size,kerenl_inputs size:" - << kernel_inputs->size() << ",clean mem size" << clear_mems.size(); - } -} - -bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id, - std::vector *task_info_list) { - MS_EXCEPTION_IF_NULL(task_info_list); - MS_EXCEPTION_IF_NULL(anf_node_ptr); - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr); - MS_EXCEPTION_IF_NULL(kernel_mod); - if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) { - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) { - auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i); - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index); - AddressPtr input = std::make_shared
(); - input->addr = device_address->ptr_; - input->size = device_address->size_; - kernel_inputs.push_back(input); - } - - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node_ptr); ++i) { - auto it = AnfAlgo::GetOutputAddr(anf_node_ptr, i); - AddressPtr output = std::make_shared
(); - output->addr = it->ptr_; - output->size = it->size_; - kernel_outputs.push_back(output); - } - - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetWorkspaceAddr(anf_node_ptr, i); - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - workspace->size = device_address->size_; - kernel_workspaces.push_back(workspace); - } - } else { - LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs); - } - - auto ascend_kernel_mod = dynamic_cast(kernel_mod); - MS_EXCEPTION_IF_NULL(ascend_kernel_mod); - std::vector task_info_ptrs = - ascend_kernel_mod->GenTask(kernel_inputs, kernel_workspaces, kernel_outputs, stream_id); - task_info_list->insert(task_info_list->end(), task_info_ptrs.begin(), task_info_ptrs.end()); - return true; -} - -bool TaskGenerator::LaunchAllKernel(const std::vector &anf_node_list, - std::vector *task_info_list, uint32_t graph_id) { - uint32_t current_op_index = 0; - std::vector profiling_cnode_list; - std::vector kernel_name_list; - for (const auto &anf_node_ptr : anf_node_list) { - size_t old_size = task_info_list->size(); - uint32_t stream_id = AnfAlgo::GetStreamId(anf_node_ptr); - MS_EXCEPTION_IF_NULL(anf_node_ptr); - MS_LOG(INFO) << "Task gen launch begin, current_op_idx:" << current_op_index - << " name:" << anf_node_ptr->fullname_with_scope() << ", stream id:" << stream_id; - if (!LaunchKernel(anf_node_ptr, stream_id, task_info_list)) { - MS_LOG(ERROR) << "LaunchKernel failed."; - return false; - } - for (size_t i = old_size; i < task_info_list->size(); ++i) { - profiling_cnode_list.emplace_back(anf_node_ptr); - kernel_name_list.emplace_back(anf_node_ptr->fullname_with_scope()); - } - current_op_index++; - } - - ProfilingUtils::SetGraphKernelName(graph_id, kernel_name_list); - if (ProfilingManager::GetInstance().IsProfiling()) { - ProfilingUtils::SetGraphProfilingCNode(graph_id, profiling_cnode_list); - } - return true; -} -} // namespace tasksink -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/tasksink/task_generator.h b/mindspore/ccsrc/device/ascend/tasksink/task_generator.h deleted file mode 100644 index ecd5889b04..0000000000 --- a/mindspore/ccsrc/device/ascend/tasksink/task_generator.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ - -#include -#include -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "ir/anf.h" -#include "kernel/ascend_kernel_mod.h" -#include "framework/ge_runtime/task_info.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace tasksink { -using mindspore::kernel::Address; -using mindspore::kernel::AddressPtr; -using AddressPtrList = std::vector; -using ge::model_runner::TaskInfo; -using TaskInfoPtr = std::shared_ptr; -class TaskGenerator { - public: - TaskGenerator() = default; - ~TaskGenerator() = default; - TaskGenerator(const TaskGenerator &in) = delete; - TaskGenerator &operator=(const TaskGenerator &in) = delete; - - static bool GenTasks(const std::vector &anf_node_list, std::vector *task_info_list, - uint32_t graph_id); - - private: - static void LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs); - static void LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs); - static bool LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id, std::vector *task_info_list); - static bool LaunchAllKernel(const std::vector &anf_node_list, std::vector *task_info_list, - uint32_t graph_id); -}; -} // namespace tasksink -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ diff --git a/mindspore/ccsrc/device/convert_tensor_utils.cc b/mindspore/ccsrc/device/convert_tensor_utils.cc deleted file mode 100644 index bac72727c2..0000000000 --- a/mindspore/ccsrc/device/convert_tensor_utils.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/convert_tensor_utils.h" -#include -namespace mindspore { -namespace device { -void HalfToFloat(void *dst, const void *src, size_t elem_num) { - auto half_data = static_cast(src); - auto float_data = static_cast(dst); - for (size_t i = 0; i < elem_num; ++i) { - float tmp = Eigen::half_impl::half_to_float(half_data[i]); - float_data[i] = tmp; - } -} - -void FloatToHalf(void *dst, const void *src, size_t elem_num) { - auto float_data = static_cast(src); - auto half_data = static_cast(dst); - for (size_t i = 0; i < elem_num; ++i) { - half_data[i] = Eigen::half(float_data[i]); - } -} - -void DoubleToFloat(void *dst, const void *src, size_t elem_num) { - auto double_data = static_cast(src); - auto float_data = static_cast(dst); - for (size_t i = 0; i < elem_num; ++i) { - float_data[i] = static_cast(double_data[i]); - } -} - -void FloatToDouble(void *dst, const void *src, size_t elem_num) { - auto float_data = static_cast(src); - auto double_data = static_cast(dst); - for (size_t i = 0; i < elem_num; ++i) { - double_data[i] = static_cast(float_data[i]); - } -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_device_address.cc b/mindspore/ccsrc/device/cpu/cpu_device_address.cc deleted file mode 100644 index 09ab0da12b..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_device_address.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/cpu/cpu_device_address.h" -#include -#include "device/convert_tensor_utils.h" - -namespace mindspore { -namespace device { -namespace cpu { -bool CPUDeviceAddress::SyncDeviceToHost(const std::vector & /*shape*/, size_t size, TypeId type, - void *host_ptr) const { - if (ptr_ == nullptr) { - MS_LOG(ERROR) << "The pointer ptr_ is null!"; - return false; - } - - if (host_ptr == ptr_) { - MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; - return true; - } - - if (type == type_id_) { - auto ret_code = memcpy_s(host_ptr, size, ptr_, size_); - if (ret_code != EOK) { - MS_LOG(ERROR) << "Failed to copy tensor!"; - return false; - } - } else if (type == kNumberTypeFloat16) { - FloatToHalf(host_ptr, ptr_, size / 2); - } else if (type == kNumberTypeFloat64) { - FloatToDouble(host_ptr, ptr_, size / sizeof(double)); - } else { - MS_LOG(ERROR) << "Types not match. Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type) - << "!"; - return false; - } - return true; -} - -bool CPUDeviceAddress::SyncHostToDevice(const std::vector & /*shape*/, size_t size, TypeId type, - const void *host_ptr) const { - if (type == kNumberTypeFloat16) { - HalfToFloat(ptr_, host_ptr, size / 2); - } else if (type == kNumberTypeFloat64) { - DoubleToFloat(ptr_, host_ptr, size / sizeof(double)); - } - return true; -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_device_address.h b/mindspore/ccsrc/device/cpu/cpu_device_address.h deleted file mode 100644 index a041567f47..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_device_address.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ - -#include -#include -#include "device/device_address.h" - -namespace mindspore { -namespace device { -namespace cpu { -class CPUDeviceAddress : public DeviceAddress { - public: - CPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} - - CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) - : DeviceAddress(ptr, size, format, type_id) {} - - ~CPUDeviceAddress() override = default; - - bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; - DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; } -}; -} // namespace cpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc deleted file mode 100644 index f46d10ed82..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc +++ /dev/null @@ -1,324 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/cpu/cpu_kernel_runtime.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "utils/context/ms_context.h" -#include "utils/config_manager.h" -#include "utils/profile.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "session/session_basic.h" -#include "operator/ops.h" - -namespace mindspore { -namespace device { -namespace cpu { -const size_t INIT_NODE_REF = 1; -namespace { -TypeId GetCPUSupportOutputTypeId(const TypeId type_id) { - TypeId support_type_id = type_id; - if (type_id == kNumberTypeUInt32) { - support_type_id = kNumberTypeInt32; - } - if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 || - type_id == kNumberTypeFloat64) { - support_type_id = kNumberTypeFloat32; - } - if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) { - MS_LOG(EXCEPTION) << "Check output type failed."; - } - return support_type_id; -} -} // namespace - -void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) { - AssignValueNodeAddress(kernel_graph); - AssignInputNodeAddress(kernel_graph); - AssignKernelOutputAddress(kernel_graph); - resource_manager_.MemPlan(kernel_graph); - resource_manager_.MemMalloc(kernel_graph); -} - -void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - size_t type_size = sizeof(float); - for (auto &item_node : kernel_graph->graph_value_nodes()) { - MS_EXCEPTION_IF_NULL(item_node); - if (item_node->isa()) { - auto value_node = item_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto node_value = value_node->value(); - MS_EXCEPTION_IF_NULL(node_value); - if (!node_value->isa()) { - continue; - } - auto tensor = node_value->cast(); - MS_EXCEPTION_IF_NULL(tensor); - std::vector data_shape = tensor->shape(); - size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); - DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32); - MS_EXCEPTION_IF_NULL(address); - if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { - address->ptr_ = tensor->data_c(); - } else { - address->ptr_ = resource_manager_.MemMalloc(tensor_size); - if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "Value node sync host to device failed!"; - } - } - address->ref_count_ = INIT_NODE_REF; - AnfAlgo::SetOutputAddr(address, 0, item_node.get()); - } - } -} - -void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - size_t type_size = sizeof(float); - for (auto &item : kernel_graph->inputs()) { - MS_EXCEPTION_IF_NULL(item); - if (item->isa()) { - auto output_num = AnfAlgo::GetOutputTensorNum(item); - for (size_t index = 0; index < output_num; index++) { - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); - std::vector fmt_shape = AnfAlgo::GetOutputDeviceShape(item, index); - size_t tensor_size = - fmt_shape.empty() ? type_size - : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies()); - auto format = AnfAlgo::GetOutputFormat(item, index); - auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id); - AnfAlgo::SetOutputAddr(address, index, item.get()); - } - } - } -} - -void CPUKernelRuntime::AssignKernelOutputAddress(const session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto kernels = kernel_graph->execution_order(); - for (auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - auto output_format = AnfAlgo::GetOutputFormat(kernel, i); - auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); - AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i, - kernel.get()); - } - auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); - for (size_t i = 0; i < workspace_sizes.size(); ++i) { - AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32), - i, kernel.get()); - } - } -} - -DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) { - return std::make_shared(device_ptr, device_size, format, type_id); -} - -tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index, - std::set *bound_addresses, - std::vector *need_sync_outputs) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(bound_addresses); - MS_EXCEPTION_IF_NULL(need_sync_outputs); - size_t output_size = AnfAlgo::GetOutputTensorNum(node); - if (index >= output_size) { - MS_LOG(EXCEPTION) << "Invalid input index " << index; - } - auto address = AnfAlgo::GetMutableOutputAddr(node, index); - MS_EXCEPTION_IF_NULL(address); - auto shape = AnfAlgo::GetOutputInferShape(node, index); - std::vector temp_shape; - (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); - TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); - type_id = GetCPUSupportOutputTypeId(type_id); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); - MS_EXCEPTION_IF_NULL(tensor); - if (bound_addresses->find(address) != bound_addresses->end()) { - tensor->set_device_address(address); - need_sync_outputs->emplace_back(tensor); - } else { - address->ptr_ = tensor->data_c(); - address->ref_count_ = INIT_NODE_REF; - (void)bound_addresses->insert(address); - } - tensor->set_dirty(false); - return tensor; -} - -BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, - const std::unordered_map &input_map, - std::set *bound_addresses, - std::vector *need_sync_outputs) { - auto &input_node = kernel_with_index.first; - auto index = kernel_with_index.second; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa()) { - auto node = input_node->cast(); - MS_EXCEPTION_IF_NULL(node); - if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimMakeTuple->name()) { - VectorRef ret; - for (size_t i = 1; i < node->inputs().size(); i++) { - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0); - auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs); - ret.push_back(out); - } - return ret; - } - return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs); - } else if (input_node->isa() || input_node->isa()) { - auto iter = input_map.find(input_node.get()); - if (iter != input_map.end()) { - return iter->second; - } - } - return BaseRef(); -} - -void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, - const std::vector &inputs, VectorRef *outputs, - std::vector *need_sync_outputs) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(outputs); - // bind input ptr - auto &input_nodes = kernel_graph->inputs(); - if (input_nodes.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; - } - std::unordered_map input_map; - size_t input_idx = 0; - for (auto &item : input_nodes) { - MS_EXCEPTION_IF_NULL(item); - input_map[item.get()] = inputs[input_idx]; - if (item->isa()) { - auto address = AnfAlgo::GetMutableOutputAddr(item, 0); - auto tensor = inputs[input_idx]; - auto tensor_address = tensor->device_address(); - MS_EXCEPTION_IF_NULL(address); - MS_EXCEPTION_IF_NULL(tensor); - if (tensor_address != nullptr && tensor_address != address) { - (void)tensor->data_sync(); - } - std::vector data_shape = tensor->shape(); - size_t tensor_size = - std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies()); - if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { - address->ptr_ = tensor->data_c(); - } else { - address->ptr_ = resource_manager_.MemMalloc(tensor_size); - if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; - } - tensor->set_dirty(true); - } - address->ref_count_ = INIT_NODE_REF; - tensor->set_device_address(address); - } - input_idx++; - } - // new output and bind ptr - std::set bound_addresses; - auto output_nodes = kernel_graph->outputs(); - for (const auto &item : output_nodes) { - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); - auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs); - outputs->push_back(std::move(out)); - } -} - -void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector *input_list) { - MS_EXCEPTION_IF_NULL(address); - MS_EXCEPTION_IF_NULL(input_list); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - if (address->ptr_ == nullptr) { - address->ptr_ = resource_manager_.MemMalloc(address->size_); - } - MS_EXCEPTION_IF_NULL(address->ptr_); - input->addr = address->ptr_; - input->size = address->size_; - input_list->push_back(input); -} - -void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { - resource_manager_.IncreaseSummaryRefCount(summary_outputs); -} - -void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { - resource_manager_.DecreaseSummaryRefCount(summary_outputs); -} - -bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - resource_manager_.IncreaseAddressRefCount(kernel_graph); - - auto kernels = kernel_graph->execution_order(); - for (const auto &kernel : kernels) { -#ifdef ENABLE_PROFILE - double start_time = GetTime(); -#endif - std::vector kernel_inputs; - std::vector kernel_workspaces; - std::vector kernel_outputs; - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i).get(); - MS_EXCEPTION_IF_NULL(device_address); - AddRuntimeAddress(device_address, &kernel_inputs); - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); - for (size_t i = 0; i < output_num; ++i) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i).get(); - MS_EXCEPTION_IF_NULL(device_address); - AddRuntimeAddress(device_address, &kernel_outputs); - } - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(device_address); - AddRuntimeAddress(device_address, &kernel_workspaces); - } - auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, 0); - resource_manager_.DecreaseAddressRefCount(kernel); - if (!ret) { - MS_LOG(EXCEPTION) << "Launch kernel failed."; - } -#ifdef ENABLE_PROFILE - double cost_time = GetTime() - start_time; - MS_LOG(INFO) << "cpu kernel: " << kernel->fullname_with_scope() << " costs " << cost_time * 1e6 << " us"; -#endif - } - return true; -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h deleted file mode 100644 index 354d2922c2..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ - -#include -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "session/kernel_graph.h" -#include "session/session_basic.h" -#include "device/cpu/cpu_resource_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/any.h" -namespace mindspore { -namespace device { -namespace cpu { -class CPUKernelRuntime : public KernelRuntime { - public: - CPUKernelRuntime() = default; - ~CPUKernelRuntime() override = default; - - bool Init() override { return true; } - bool Run(session::KernelGraph *graph) override; - void AssignKernelAddress(session::KernelGraph *kernel_graph); - void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector &inputs, - VectorRef *outputs, std::vector *need_sync_outputs); - void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); - void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); - - protected: - bool SyncStream() override { return true; }; - DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) override; - - private: - tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index, - std::set *bound_addresses, - std::vector *need_sync_outputs); - - BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, - const std::unordered_map &input_map, - std::set *bound_addresses, - std::vector *need_sync_outputs); - void AssignValueNodeAddress(session::KernelGraph *kernel_graph); - void AssignInputNodeAddress(const session::KernelGraph *kernel_graph); - void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph); - void AddRuntimeAddress(DeviceAddress *address, std::vector *input_list); - CPUResourceManager resource_manager_; -}; -} // namespace cpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc b/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc deleted file mode 100644 index c69ef35305..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc +++ /dev/null @@ -1,174 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/cpu/cpu_resource_manager.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace cpu { -CPUResourceManager::~CPUResourceManager() { MemFree(); } - -void CPUResourceManager::MemFree() { - if (mem_ptr_ != nullptr) { - free(mem_ptr_); - mem_ptr_ = nullptr; - mem_size_ = 0; - } - - for (auto &&iter : dynamic_mem_) { - free(iter.first); - } - dynamic_mem_.clear(); -} - -void CPUResourceManager::MemPlan(const session::KernelGraph *graph) { - mem_plan_.MemPlan(graph); - size_t graph_mem_size = mem_plan_.GetGraphMemSize(graph); - if (graph_mem_size > mem_size_) { - MemFree(); - mem_ptr_ = reinterpret_cast(malloc(graph_mem_size)); - if (mem_ptr_ != nullptr) { - mem_size_ = graph_mem_size; - dynamic_malloc_ = false; - } else { - MS_LOG(INFO) << "Switch to dynamic malloc"; - dynamic_malloc_ = true; - } - } -} - -void CPUResourceManager::MemMalloc(const session::KernelGraph *graph) { - if (dynamic_malloc_) { - return; - } - mem_plan_.MemAssign(graph, mem_ptr_); -} - -void *CPUResourceManager::MemMalloc(size_t mem_size) { - void *ptr = malloc(mem_size); - if (ptr != nullptr) { - memset_s(ptr, mem_size, 0, mem_size); - dynamic_mem_[ptr] = mem_size; - return ptr; - } else { - MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size; - } -} - -void CPUResourceManager::MemFree(void *ptr) { - auto iter = dynamic_mem_.find(ptr); - if (iter != dynamic_mem_.end()) { - (void)dynamic_mem_.erase(iter); - free(ptr); - } -} - -void CPUResourceManager::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { - if (!dynamic_malloc_) { - return; - } - - if (summary_outputs.empty()) { - return; - } - - for (auto &output_item : summary_outputs) { - auto node = output_item.second.first; - size_t index = IntToSize(output_item.second.second); - auto address = AnfAlgo::GetMutableOutputAddr(node, index); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_++; - } -} - -void CPUResourceManager::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { - if (!dynamic_malloc_) { - return; - } - - if (summary_outputs.empty()) { - return; - } - - for (auto &output_item : summary_outputs) { - auto node = output_item.second.first; - size_t index = IntToSize(output_item.second.second); - auto address = AnfAlgo::GetMutableOutputAddr(node, index); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_--; - if (address->ref_count_ == 0 && address->ptr_ != nullptr) { - MemFree(address->ptr_); - address->ptr_ = nullptr; - } - } -} - -void CPUResourceManager::IncreaseAddressRefCount(const session::KernelGraph *graph) { - if (!dynamic_malloc_) { - return; - } - MS_EXCEPTION_IF_NULL(graph); - auto kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_++; - } - - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_++; - } - } -} - -void CPUResourceManager::DecreaseAddressRefCount(const AnfNodePtr &kernel) { - if (!dynamic_malloc_) { - return; - } - MS_EXCEPTION_IF_NULL(kernel); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_--; - if (address->ref_count_ == 0 && address->ptr_ != nullptr) { - MemFree(address->ptr_); - address->ptr_ = nullptr; - } - } - - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_--; - if (address->ref_count_ == 0 && address->ptr_ != nullptr) { - MemFree(address->ptr_); - address->ptr_ = nullptr; - } - } -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_resource_manager.h b/mindspore/ccsrc/device/cpu/cpu_resource_manager.h deleted file mode 100644 index d130241464..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_resource_manager.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ - -#include -#include -#include "session/kernel_graph.h" -#include "session/session_basic.h" -#include "device/device_address.h" -#include "device/cpu/cpu_simple_mem_plan.h" -namespace mindspore { -namespace device { -namespace cpu { -class CPUResourceManager { - public: - CPUResourceManager() = default; - ~CPUResourceManager(); - - void MemPlan(const session::KernelGraph *graph); - void MemMalloc(const session::KernelGraph *graph); - void IncreaseAddressRefCount(const session::KernelGraph *graph); - void DecreaseAddressRefCount(const AnfNodePtr &kernel); - void *MemMalloc(size_t mem_size); - void MemFree(void *ptr); - void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); - void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); - - private: - void MemFree(); - CPUSimpleMemPlan mem_plan_; - - size_t mem_size_{0}; - uint8_t *mem_ptr_{nullptr}; - bool dynamic_malloc_{false}; - std::unordered_map dynamic_mem_; -}; -} // namespace cpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ diff --git a/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc b/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc deleted file mode 100644 index e6cb6ee53a..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/cpu/cpu_simple_mem_plan.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace cpu { -void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - size_t total_mem_size = 0; - auto kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); - MS_EXCEPTION_IF_NULL(kernel_with_index.first); - if (kernel_with_index.first->isa()) { - continue; - } - auto address = AnfAlgo::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, true); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - total_mem_size += address->size_; - } - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); - for (size_t i = 0; i < output_num; ++i) { - auto address = AnfAlgo::GetOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - total_mem_size += address->size_; - } - } - - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - total_mem_size += address->size_; - } - } - } - graph_mem_size_[graph] = total_mem_size; -} - -size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) const { - auto iter = graph_mem_size_.find(graph); - if (iter != graph_mem_size_.end()) { - return iter->second; - } - return 0; -} - -void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(base_ptr); - uint8_t *mem_ptr = base_ptr; - auto kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); - MS_EXCEPTION_IF_NULL(kernel_with_index.first); - if (kernel_with_index.first->isa()) { - continue; - } - auto address = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, true); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - address->ptr_ = mem_ptr; - mem_ptr = mem_ptr + address->size_; - } - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); - for (size_t i = 0; i < output_num; ++i) { - auto address = AnfAlgo::GetMutableOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - address->ptr_ = mem_ptr; - mem_ptr = mem_ptr + address->size_; - } - } - - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - address->ptr_ = mem_ptr; - mem_ptr = mem_ptr + address->size_; - } - } - } -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h b/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h deleted file mode 100644 index 7633ef3f45..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ - -#include -#include -#include "session/kernel_graph.h" -#include "device/device_address.h" - -namespace mindspore { -namespace device { -namespace cpu { -class CPUSimpleMemPlan { - public: - CPUSimpleMemPlan() = default; - ~CPUSimpleMemPlan() = default; - - void MemPlan(const session::KernelGraph *graph); - void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr); - size_t GetGraphMemSize(const session::KernelGraph *graph) const; - - private: - std::unordered_map graph_mem_size_; -}; -} // namespace cpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ diff --git a/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc deleted file mode 100644 index 9d72bcab89..0000000000 --- a/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/cpu/kernel_select_cpu.h" - -#include -#include -#include - -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace device { -namespace cpu { -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; -using mindspore::kernel::KernelBuildInfo; -namespace { -bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) { - auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa() || input_node->isa()) { - return true; - } - return false; -} - -void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector &input_not_cnode_indexes, - const CNodePtr kernel_node) { - for (auto &input_index : input_not_cnode_indexes) { - auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_node); - std::vector output_types; - output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first); - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - builder->SetOutputsFormat({kOpFormat_DEFAULT}); - builder->SetOutputsDeviceType(output_types); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get()); - } -} - -void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector *input_formats, - std::vector *input_types, std::vector *input_no_cnode_indexes) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t input_index = 0; input_index < input_num; ++input_index) { - TypeId dtype = kTypeUnknown; - if (IsInputNotCNode(kernel_node, input_index)) { - input_no_cnode_indexes->emplace_back(input_index); - dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); - } else { - dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); - } - input_formats->emplace_back(kOpFormat_DEFAULT); - input_types->emplace_back(dtype); - } -} - -void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr, - std::vector *output_formats, std::vector *output_types) { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t output_index = 0; output_index < output_num; ++output_index) { - output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second); - auto dtype = kernel_attr.GetOutputAttr(output_index).first; - output_types->emplace_back(dtype); - } -} - -bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector &input_formats, - const std::vector &input_types, - const std::vector &input_not_cnode_indexes) { - if (kernel_attr.GetInputSize() != input_types.size()) { - MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); - return false; - } - auto input_num = input_types.size(); - for (size_t i = 0; i < input_num; ++i) { - bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), - [i](size_t index) { return index == i; }); - bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); - if (have_cnode_input && is_not_cnode_idx) { - continue; - } - if (kernel_attr.GetInputAttr(i).first != input_types[i]) { - MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first - << ", actual input dtype:" << input_types[i]; - return false; - } - if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { - MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second - << ", actual input format:" << input_formats[i]; - return false; - } - } - return true; -} - -void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { - MS_EXCEPTION_IF_NULL(kernel_attr); - TypeId input_dtype = kernel_attr->GetInputAttr(0).first; - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 1; i < input_num; ++i) { - kernel_attr->AddInputAttr(input_dtype); - } - - TypeId output_dtype = kernel_attr->GetOutputAttr(0).first; - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t i = 1; i < output_num; ++i) { - kernel_attr->AddOutputAttr(output_dtype); - } -} -} // namespace - -void SetKernelInfo(const CNodePtr &kernel_node) { - std::vector input_formats; - std::vector input_types; - std::vector input_not_cnode_indexes; - std::vector output_formats; - std::vector output_types; - - MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node); - GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); - - auto kernel_attrs = - kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); - - for (size_t index = 0; index < kernel_attrs.size(); ++index) { - auto kernel_attr = kernel_attrs[index]; - if (kernel_attr.GetAllSame()) { - ExpandKernelAttr(kernel_node, &kernel_attr); - } - if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (kernel_attr.GetOutputSize() != output_num) { - MS_LOG(DEBUG) << "Output num is not equal!"; - continue; - } - MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; - GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); - UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); - for (auto &input_index : input_not_cnode_indexes) { - input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; - } - break; - } - } - - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - builder->SetInputsFormat(input_formats); - builder->SetInputsDeviceType(input_types); - builder->SetOutputsFormat(output_formats); - builder->SetOutputsDeviceType(output_types); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc b/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc deleted file mode 100644 index 9b06c0a40a..0000000000 --- a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc +++ /dev/null @@ -1,277 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/cpu/mpi/mpi_adapter.h" -#ifdef ENABLE_MPI -#include -#include -#include "pybind11/pybind11.h" -#endif // ENABLE_MPI -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -namespace cpu { -std::shared_ptr MPIAdapter::instance_ = nullptr; -std::shared_ptr MPIAdapter::Instance() { - if (instance_ == nullptr) { - MS_LOG(DEBUG) << "Create new mpi adapter instance."; - instance_.reset(new (std::nothrow) MPIAdapter()); - } - return instance_; -} - -#ifdef ENABLE_MPI - -#define RAISE_EXCEPTION(message) \ - { \ - std::ostringstream oss; \ - oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message; \ - pybind11::pybind11_fail(oss.str()); \ - } - -#define RAISE_EXCEPTION_WITH_PARAM(message, param) \ - { \ - std::ostringstream oss; \ - oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \ - pybind11::pybind11_fail(oss.str()); \ - } - -namespace { -MPI_Op GetMpiOp(const std::string &op_type) { - if (op_type == "sum") { - return MPI_SUM; - } else if (op_type == "max") { - return MPI_MAX; - } else if (op_type == "min") { - return MPI_MIN; - } else if (op_type == "prod") { - return MPI_PROD; - } - - RAISE_EXCEPTION_WITH_PARAM("unsupport op_type: ", op_type); - return MPI_SUM; -} - -int GetScatterIndex(int rankid, const std::vector &ranks_group) { - int scatter_index = -1; - for (size_t i = 0; i < ranks_group.size(); ++i) { - if (ranks_group[i] == rankid) { - scatter_index = static_cast(i); - break; - } - } - if (scatter_index == -1) { - RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rankid); - } - return scatter_index; -} -} // namespace - -MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); } - -MPIAdapter::~MPIAdapter() { - int finalized; - MPI_Finalized(&finalized); - if (finalized != 0) { - return; - } - - for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) { - MPI_Group_free(&iter->second); - } - ranks_group_.clear(); - if (comm_group_world_ != MPI_GROUP_NULL) { - MPI_Group_free(&comm_group_world_); - comm_group_world_ = MPI_GROUP_NULL; - } - MPI_Finalize(); -} - -void MPIAdapter::Init() { - static bool init = false; - if (init) { - return; - } - - int init_flag = 0; - if (MPI_Initialized(&init_flag) != MPI_SUCCESS) { - RAISE_EXCEPTION("Check mpi initialized fail!"); - } - if (init_flag == 0) { - auto ret = MPI_Init(nullptr, nullptr); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION("Failed to init mpi!"); - } - } - - MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_); - if (comm_group_world_ == MPI_GROUP_NULL) { - RAISE_EXCEPTION("comm_group_world_ init fail!"); - } - auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION("Failed to init mpi rank id!"); - } - - ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_) - } - init = true; -} - -MPI_Group MPIAdapter::AddGroup(const std::vector &ranks) { - if (ranks.size() > static_cast(rank_size_) || ranks.empty()) { - RAISE_EXCEPTION_WITH_PARAM("input rank size:", ranks.size()); - } - - if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) { - RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rank_id_); - } - std::lock_guard lock(group_mutex_); - auto iter = ranks_group_.find(ranks); - if (iter != ranks_group_.end()) { - return iter->second; - } - const auto ranks_size = ranks.size(); - std::vector ranks_input(ranks_size, 0); - for (size_t i = 0; i < ranks_size; ++i) { - ranks_input[i] = ranks[i]; - } - - MPI_Group group = MPI_GROUP_NULL; - MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group); - if (group == MPI_GROUP_NULL) { - RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_) - } - - ranks_group_[ranks] = group; - return group; -} - -bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector &ranks_group, size_t data_num, - const std::string &op_type) { - if (ranks_group.empty()) { - RAISE_EXCEPTION("input rank group is empty!"); - return false; - } - - auto group = AddGroup(ranks_group); - if (group == MPI_GROUP_NULL) { - RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_) - } - MPI_Comm comm; - MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); - if (comm == MPI_COMM_NULL) { - RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_); - } - std::vector receive_count(ranks_group.size(), 0); - for (size_t i = 0; i < ranks_group.size(); ++i) { - receive_count[i] = data_num; - } - - auto op = GetMpiOp(op_type); - auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm); - bool result = true; - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret); - result = false; - } - - ret = MPI_Comm_free(&comm); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret); - } - return result; -} - -bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t input_data_num, - size_t output_size, const std::string &op_type, float *output) { - int scatter_index = GetScatterIndex(rank_id_, ranks_group); - auto group = AddGroup(ranks_group); - if (group == MPI_GROUP_NULL) { - RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_); - } - MPI_Comm comm; - MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); - if (comm == MPI_COMM_NULL) { - RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_); - } - - MPI_Win window; - auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! ret = ", ret); - } - MPI_Win_fence(0, window); - for (size_t i = 0; i < ranks_group.size(); ++i) { - int remote_rank = ranks_group[i]; - if (rank_id_ == remote_rank) { - continue; - } - auto op = GetMpiOp(op_type); - ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num, - input_data_num, MPI_FLOAT, op, window); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret); - } - } - MPI_Win_fence(0, window); - if (output != nullptr) { - auto data_size = input_data_num * sizeof(float); - if (output_size < data_size) { - std::ostringstream exception_msg; - exception_msg << "output buffer size " << output_size << " < input size " << data_size; - RAISE_EXCEPTION(exception_msg.str()) - } - auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size); - if (copy_ret != 0) { - RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret); - } - } - MPI_Win_free(&window); - MPI_Comm_free(&comm); - return true; -} - -bool MPIAdapter::AllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num) { - if (ranks_group.empty()) { - RAISE_EXCEPTION("input rank group is empty!"); - return false; - } - auto group = AddGroup(ranks_group); - if (group == MPI_GROUP_NULL) { - RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_); - } - MPI_Comm comm; - MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); - if (comm == MPI_COMM_NULL) { - RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_); - } - auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi allgater fail!ret = ", ret); - } - ret = MPI_Comm_free(&comm); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret); - } - return true; -} -#endif // ENABLE_MPI -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/device_address.h b/mindspore/ccsrc/device/device_address.h deleted file mode 100644 index 0447cc2539..0000000000 --- a/mindspore/ccsrc/device/device_address.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_DEVICE_TENSOR_H -#define MINDSPORE_DEVICE_TENSOR_H - -#include -#include -#include -#include "ir/dtype.h" - -using std::string; - -namespace mindspore { -namespace device { -namespace cpu { -class CPUSimpleMemPlan; -class CPUResourceManager; -class CPUKernelRuntime; -} // namespace cpu -namespace ascend { -class AscendKernelRuntime; -class AscendMemoryManager; -namespace tasksink { -class TaskGenerator; -} // namespace tasksink -} // namespace ascend -namespace gpu { -class GPUKernelRuntime; -class GPUMemoryManager; -} // namespace gpu -} // namespace device -} // namespace mindspore - -namespace mindspore { -namespace device { -enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice }; -enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU }; - -class DeviceAddress { - public: - explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {} - explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) - : ptr_(ptr), size_(size), format_(format), type_id_(type_id) {} - virtual ~DeviceAddress() { ptr_ = nullptr; } - virtual bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const = 0; - virtual bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, - const void *host_ptr) const = 0; - const void *GetPtr() const { return ptr_; } - size_t GetSize() const { return size_; } - std::string format() const { return format_; } - TypeId type_id() const { return type_id_; } - void set_host_shape(const std::vector &shape) { host_shape_ = shape; } - virtual void set_status(DeviceAddressStatus status) {} - virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } - virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } - - protected: - const void *ptr() const { return ptr_; } - size_t size() const { return size_; } - void set_ptr(void *ptr) { ptr_ = ptr; } - void *ptr_{nullptr}; - size_t size_{0}; - size_t ref_count_{0}; - string format_{"DefaultFormat"}; - TypeId type_id_{kNumberTypeFloat16}; - bool from_mem_pool_{false}; - std::vector host_shape_{}; - friend class KernelRuntime; - friend class MemoryManager; - friend class mindspore::device::ascend::tasksink::TaskGenerator; - friend class mindspore::device::cpu::CPUSimpleMemPlan; - friend class mindspore::device::cpu::CPUResourceManager; - friend class mindspore::device::cpu::CPUKernelRuntime; - friend class mindspore::device::gpu::GPUKernelRuntime; - friend class mindspore::device::gpu::GPUMemoryManager; - friend class mindspore::device::ascend::AscendKernelRuntime; - friend class mindspore::device::ascend::AscendMemoryManager; -}; - -using DeviceAddressPtr = std::shared_ptr; -using DeviceAddressPtrList = std::vector; -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_DEVICE_TENSOR_H diff --git a/mindspore/ccsrc/device/gpu/blocking_queue.cc b/mindspore/ccsrc/device/gpu/blocking_queue.cc deleted file mode 100644 index 3b5e75f551..0000000000 --- a/mindspore/ccsrc/device/gpu/blocking_queue.cc +++ /dev/null @@ -1,143 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/blocking_queue.h" -#include -#include "device/gpu/gpu_common.h" -#include "common/utils.h" - -namespace mindspore { -namespace device { -GpuQueue::GpuQueue(void *addr, const std::vector &shape, const size_t &capacity) - : buffer_(addr), head_(0), tail_(0), shape_(shape), len_(0), capacity_(capacity), stream_(0), node_info_(nullptr) { - CHECK_CUDA_RET_WITH_ERROR(cudaStreamCreate(&stream_), "Cuda Create Stream Failed"); - node_info_ = std::make_unique(capacity); - for (auto item : shape) { - len_ += item; - } -} - -GpuQueue::~GpuQueue() { buffer_ = nullptr; } - -BlockQueueStatus_T GpuQueue::Push(const std::vector &data) { - int offset = 0; - for (size_t i = 0; i < data.size(); i++) { - auto item = data[i]; - if (item.data_ptr_ == nullptr || item.data_len_ != shape_[i]) { - MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_; - return ERROR_INPUT; - } - - void *addr = reinterpret_cast(buffer_) + tail_ * len_ + offset; - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(addr, item.data_ptr_, item.data_len_, cudaMemcpyHostToDevice, stream_), - "Cuda Memcpy Error"); - - offset += item.data_len_; - } - - node_info_[tail_].event_.reset(new cudaEvent_t()); - CHECK_CUDA_RET_WITH_ERROR(cudaEventCreate(&(*(node_info_[tail_].event_))), "Cuda Create Event Failed"); - node_info_[tail_].data_ = data; - tail_ = (tail_ + 1) % (capacity_); - return SUCCESS; -} - -BlockQueueStatus_T GpuQueue::Front(void **addr, size_t *len) const { - CHECK_CUDA_RET_WITH_ERROR(cudaEventSynchronize(*(node_info_[head_].event_)), "Cuda Event Syn Failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaEventDestroy(*(node_info_[head_].event_)), "Cuda Destroy Event Failed"); - *addr = (unsigned char *)buffer_ + head_ * len_; - *len = len_; - - for (auto item : node_info_[head_].data_) { - host_release_(item.data_ptr_); - } - return SUCCESS; -} - -BlockQueueStatus_T GpuQueue::Pop() { - head_ = (head_ + 1) % (capacity_); - return SUCCESS; -} - -bool GpuQueue::Destroy() { - if (stream_ != nullptr) { - auto ret = cudaStreamDestroy(stream_); - if (ret == cudaSuccess) { - return true; - } else { - return false; - } - } else { - return true; - } -} - -BlockQueueStatus_T BlockingQueue::Create(void *addr, const std::vector &shape, const size_t &capacity) { - if (addr == nullptr) { - MS_LOG(ERROR) << "addr is nullptr"; - return INTERNAL_ERROR; - } - queue_ = std::make_shared(addr, shape, capacity); - return SUCCESS; -} - -void BlockingQueue::RegisterRelease(const std::function &func) { queue_->RegisterRelease(func); } - -BlockQueueStatus_T BlockingQueue::Push(const std::vector &data, unsigned int timeout_in_sec) { - std::unique_lock locker(mutex_); - if (queue_->IsFull()) { - if (not_full_cond_.wait_for(locker, std::chrono::seconds(timeout_in_sec)) == std::cv_status::timeout) { - return TIMEOUT; - } - } - auto ret = queue_->Push(data); - if (ret) { - return ret; - } - not_empty_cond_.notify_one(); - return SUCCESS; -} - -BlockQueueStatus_T BlockingQueue::Front(void **addr, size_t *len) { - std::unique_lock locker(mutex_); - bool timeout = not_empty_cond_.wait_for(locker, std::chrono::seconds(30), [this] { return !queue_->IsEmpty(); }); - if (!timeout) { - return TIMEOUT; - } - - return queue_->Front(addr, len); -} - -BlockQueueStatus_T BlockingQueue::Pop() { - std::unique_lock locker(mutex_); - not_empty_cond_.wait(locker, [this] { return !queue_->IsEmpty(); }); - auto ret = queue_->Pop(); - if (ret) { - return ret; - } - not_full_cond_.notify_one(); - return SUCCESS; -} - -bool BlockingQueue::Destroy() { - if (queue_ != nullptr) { - return queue_->Destroy(); - } else { - return true; - } -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/cuda_common.h b/mindspore/ccsrc/device/gpu/cuda_common.h deleted file mode 100644 index b79ba8bc28..0000000000 --- a/mindspore/ccsrc/device/gpu/cuda_common.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ - -#include -#include "device/gpu/gpu_device_manager.h" - -namespace mindspore { -namespace device { -namespace gpu { -class CudaCommon { - public: - inline int threads_num() const { return threads_per_block_; } - inline int major_sm() const { return major_sm_; } - inline int blocks_num(const int total_threads) const { - return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_); - } - - static CudaCommon &GetInstance() { - static CudaCommon instance; - return instance; - } - - private: - CudaCommon() { - uint32_t device_id = GPUDeviceManager::GetInstance().cur_device_id(); - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, device_id); - threads_per_block_ = prop.maxThreadsPerBlock; - max_blocks_ = prop.multiProcessorCount; - major_sm_ = prop.major; - } - ~CudaCommon() = default; - CudaCommon(const CudaCommon &) = delete; - CudaCommon &operator=(const CudaCommon &) = delete; - - int max_blocks_; - int threads_per_block_; - int major_sm_; -}; -#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) -#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() -#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() -#define MINIUM_SM 6 -#define RECOMMEND_SM 7 -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ diff --git a/mindspore/ccsrc/device/gpu/cuda_driver.cc b/mindspore/ccsrc/device/gpu/cuda_driver.cc deleted file mode 100644 index 0dee53df64..0000000000 --- a/mindspore/ccsrc/device/gpu/cuda_driver.cc +++ /dev/null @@ -1,231 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/cuda_driver.h" -#include -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace device { -namespace gpu { -size_t CudaDriver::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { - size_t retreat_count = 0; - auto ret = cudaMalloc(reinterpret_cast(addr), size); - // If free memory is not enough, then retry with mem_malloc_retry_rate_. - while (ret == cudaErrorMemoryAllocation) { - size = FloatToSize(size * mem_malloc_retry_rate_); - size = (size / mem_malloc_align_size_) * mem_malloc_align_size_; - ret = cudaMalloc(reinterpret_cast(addr), size); - retreat_count++; - if (retreat_count > mem_malloc_retry_conut_max_) { - break; - } - } - - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMalloc failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return 0; - } - return size; -} - -bool CudaDriver::FreeDeviceMem(const DeviceMemPtr &addr) { - auto ret = cudaFree(addr); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaFree failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -size_t CudaDriver::AllocHostPinnedMem(size_t size, void **addr) { - if (size == 0) { - MS_LOG(EXCEPTION) << "The memory allocate size is 0"; - } - auto ret = cudaHostAlloc(addr, size, cudaHostAllocDefault); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaHostAlloc failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return 0; - } - return size; -} - -void CudaDriver::FreeHostPinnedMem(void *addr) { - if (addr) { - auto ret = cudaFreeHost(addr); - if (ret != cudaSuccess) { - MS_LOG(EXCEPTION) << "cudaFreeHost failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - } - } -} - -bool CudaDriver::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) { - auto ret = cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemcpy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) { - auto ret = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemcpy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, DeviceStream stream) { - auto ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, (cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemcpyAsync failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, - DeviceStream stream) { - auto ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, (cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemcpyAsync failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -size_t CudaDriver::total_mem_size() { - size_t free; - size_t total; - auto ret = cudaMemGetInfo(&free, &total); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemGetInfo failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return 0; - } - return total; -} - -size_t CudaDriver::free_mem_size() { - size_t free; - size_t total; - auto ret = cudaMemGetInfo(&free, &total); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemGetInfo failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return 0; - } - - return free; -} - -bool CudaDriver::CreateStream(DeviceStream *stream) { - auto ret = cudaStreamCreateWithFlags(reinterpret_cast(stream), cudaStreamNonBlocking); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamCreate failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::DestroyStream(const DeviceStream &stream) { - auto ret = cudaStreamDestroy((cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamDestroy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::SyncStream(const DeviceStream &stream) { - auto ret = cudaStreamSynchronize((cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamSynchronize failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::CreateEvent(DeviceEvent *event, unsigned int flag) { - auto ret = cudaEventCreateWithFlags(reinterpret_cast(event), flag); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaEventCreateWithFlags failed, ret[" << static_cast(ret) << "], " - << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::DestroyEvent(const DeviceEvent &event) { - auto ret = cudaEventDestroy((cudaEvent_t)event); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaEventDestroy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::RecordEvent(DeviceEvent event, DeviceStream stream) { - auto ret = cudaEventRecord((cudaEvent_t)event, (cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaEventRecord failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::SyncEvent(const DeviceEvent &event) { - auto ret = cudaEventSynchronize((cudaEvent_t)event); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaEventSynchronize failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::QueryEvent(const DeviceEvent &event) { - auto ret = cudaEventQuery((cudaEvent_t)event); - if (ret == cudaSuccess) { - return true; - } else if (ret == cudaErrorNotReady) { - return false; - } else { - MS_LOG(ERROR) << "cudaEventQuery failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } -} - -int CudaDriver::device_count() { - int dev_count; - auto ret = cudaGetDeviceCount(&dev_count); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaGetDeviceCount failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - } - return dev_count; -} - -bool CudaDriver::set_current_device(int index) { - auto ret = cudaSetDevice(index); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaSetDevice failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_common.h b/mindspore/ccsrc/device/gpu/distribution/collective_common.h deleted file mode 100644 index f9564a0c74..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/collective_common.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ - -#include -#include "pybind11/pybind11.h" - -namespace mindspore { -namespace device { -namespace gpu { -#define MAX_HOSTNAME_LEN 1024 -#define CHECK_RET(expression, result, message) \ - { \ - auto ret = (expression); \ - if (ret != result) { \ - std::ostringstream oss; \ - oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << " | GPU collective Error " << message \ - << " | Error Number " << ret; \ - pybind11::pybind11_fail(oss.str()); \ - } \ - } -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_fake_init.cc b/mindspore/ccsrc/device/gpu/distribution/collective_fake_init.cc deleted file mode 100644 index 06497a2e82..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/collective_fake_init.cc +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/distribution/collective_fake_init.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -namespace gpu { -void CollectiveFakeInitializer::InitCollective() { MS_LOG(EXCEPTION) << "build without enable gpu!"; } - -void CollectiveFakeInitializer::FinalizeCollective() { MS_LOG(EXCEPTION) << "build without enable gpu!"; } -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc b/mindspore/ccsrc/device/gpu/distribution/collective_init.cc deleted file mode 100644 index d7ab95bbe8..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/distribution/collective_init.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -namespace gpu { -CollectiveInitializer &CollectiveInitializer::instance() { - static CollectiveInitializer instance = {}; - return instance; -} - -bool CollectiveInitializer::collective_inited() const { return collective_inited_; } - -const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } - -void CollectiveInitializer::InitCollective() { - void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); - if (handle == nullptr) { - MS_LOG(EXCEPTION) - << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " - "installed.\n2.nccl is not " - "installed or found.\n3.mpi is not installed or found"; - } - auto mpi_init_funcptr = reinterpret_cast(dlsym(handle, "InitMPI")); - MS_EXCEPTION_IF_NULL(mpi_init_funcptr); - (*mpi_init_funcptr)(); - - CollectiveInitializer::instance().collective_inited_ = true; - CollectiveInitializer::instance().collective_handle_ = handle; -} - -void CollectiveInitializer::FinalizeCollective() { - if (CollectiveInitializer::instance().collective_handle_ != nullptr) { - if (dlclose(CollectiveInitializer::instance().collective_handle_) != 0) { - MS_LOG(EXCEPTION) << "Closing libgpu_collective.so handle failed."; - } - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_init.h b/mindspore/ccsrc/device/gpu/distribution/collective_init.h deleted file mode 100644 index 424abcf470..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/collective_init.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ - -#include - -namespace mindspore { -namespace device { -namespace gpu { -using InitMPI = void (*)(); -using InitNCCLComm = void (*)(); -using GetLocalRankId = int (*)(); - -class CollectiveInitializer { - public: - CollectiveInitializer(CollectiveInitializer const &) = delete; - CollectiveInitializer &operator=(const CollectiveInitializer &) = delete; - static CollectiveInitializer &instance(); - bool collective_inited() const; - const void *collective_handle() const; - static void InitCollective(); - static void FinalizeCollective(); - - private: - CollectiveInitializer() : collective_inited_(false) {} - ~CollectiveInitializer() = default; - - bool collective_inited_; - void *collective_handle_{nullptr}; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_wrapper.cc b/mindspore/ccsrc/device/gpu/distribution/collective_wrapper.cc deleted file mode 100644 index 5fb0f74849..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/collective_wrapper.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include "device/gpu/distribution/mpi_wrapper.h" -#include "device/gpu/distribution/nccl_wrapper.h" - -#ifndef EXPORT_WRAPPER -#define EXPORT_WRAPPER __attribute__((visibility("default"))) -#endif - -using MPIWrapper = mindspore::device::gpu::MPIWrapper; -using NCCLWrapper = mindspore::device::gpu::NCCLWrapper; - -extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); } - -extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); } - -extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); } - -extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, ncclRedOp_t reduce_type, - cudaStream_t stream) { - return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream); -} - -extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, cudaStream_t stream) { - return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream); -} - -extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, ncclRedOp_t reduce_type, - cudaStream_t stream) { - return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream); -} diff --git a/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.cc b/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.cc deleted file mode 100644 index 46b574c575..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.cc +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/distribution/mpi_wrapper.h" - -#include -#include -#include "device/gpu/distribution/nccl_wrapper.h" - -namespace mindspore { -namespace device { -namespace gpu { -MPIWrapper::MPIWrapper() : rank_id_(0), rank_size_(0), local_rank_id_(0) { Init(); } - -MPIWrapper::~MPIWrapper() { - int finalized; - MPI_Finalized(&finalized); - if (finalized == 0) { - MPI_Finalize(); - } -} - -MPIWrapper &MPIWrapper::instance() { - static MPIWrapper instance; - return instance; -} - -int MPIWrapper::local_rank_id() const { return local_rank_id_; } - -void MPIWrapper::Init() { - int initialized; - CHECK_RET(MPI_Initialized(&initialized), MPI_SUCCESS, "Failed to check mpi initialization status."); - - if (initialized == 0) { - MPI_Init(nullptr, nullptr); - } - CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id."); - CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size."); - NCCLWrapper::instance().set_rank(rank_id_, rank_size_); - AssignLocalRankId(); - - ncclUniqueId unique_id; - if (rank_id_ == 0) { - unique_id = NCCLWrapper::instance().nccl_unique_id(); - } - CHECK_RET(MPI_Bcast(reinterpret_cast(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD), - MPI_SUCCESS, "Failed to broadcast nccl unique id."); - NCCLWrapper::instance().set_nccl_unique_id(unique_id); - return; -} - -void MPIWrapper::AssignLocalRankId() { - char host_name[MAX_HOSTNAME_LEN] = {0}; - CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), 0, "Getting host name failed."); - size_t host_hash = std::hash()(host_name); - - const int kRankSize = rank_size_; - size_t all_host_hashs[kRankSize]; - all_host_hashs[rank_id_] = host_hash; - CHECK_RET(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs, sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD), - MPI_SUCCESS, "MPI_Allgather host hashs failed."); - for (int global_rank = 0; global_rank < kRankSize; global_rank++) { - if (global_rank == rank_id_) { - break; - } - if (all_host_hashs[global_rank] == all_host_hashs[rank_id_]) { - local_rank_id_++; - } - } - return; -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.h b/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.h deleted file mode 100644 index 6dfedea922..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ - -#include -#include -#include -#include -#include -#include "device/gpu/distribution/collective_common.h" - -namespace mindspore { -namespace device { -namespace gpu { -class MPIWrapper { - public: - MPIWrapper(MPIWrapper const &) = delete; - MPIWrapper &operator=(const MPIWrapper &) = delete; - static MPIWrapper &instance(); - int local_rank_id() const; - - private: - MPIWrapper(); - ~MPIWrapper(); - void Init(); - void AssignLocalRankId(); - - int rank_id_; - int rank_size_; - int local_rank_id_; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ diff --git a/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.cc b/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.cc deleted file mode 100644 index aa4756a69f..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/distribution/nccl_wrapper.h" - -namespace mindspore { -namespace device { -namespace gpu { -NCCLWrapper &NCCLWrapper::instance() { - static NCCLWrapper instance; - return instance; -} - -ncclUniqueId NCCLWrapper::nccl_unique_id() const { - ncclUniqueId unique_id; - CHECK_RET(ncclGetUniqueId(&unique_id), ncclSuccess, "Failed to create nccl unique id."); - return unique_id; -} - -void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; } - -void NCCLWrapper::set_rank(int rank_id, int rank_size) { - rank_id_ = rank_id; - rank_size_ = rank_size; -} - -void NCCLWrapper::InitNCCLComm() { - CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess, - "Failed to init nccl communicator."); -} - -ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, - ncclRedOp_t reduce_type, cudaStream_t stream) { - return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, comm_, stream); -} - -ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, - cudaStream_t stream) { - return ncclAllGather(input_addr, output_addr, count, data_type, comm_, stream); -} - -ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream) { - return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, comm_, stream); -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.h b/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.h deleted file mode 100644 index 5df1e63bb8..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ - -#include -#include -#include -#include "device/gpu/distribution/collective_common.h" - -namespace mindspore { -namespace device { -namespace gpu { -class NCCLWrapper { - public: - NCCLWrapper(NCCLWrapper const &) = delete; - NCCLWrapper &operator=(const NCCLWrapper &) = delete; - static NCCLWrapper &instance(); - ncclUniqueId nccl_unique_id() const; - void set_nccl_unique_id(ncclUniqueId unique_id); - void set_rank(int rank_id, int rank_size); - void InitNCCLComm(); - ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, - ncclRedOp_t op, cudaStream_t stream); - ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, - cudaStream_t stream); - ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, - ncclRedOp_t op, cudaStream_t stream); - - private: - NCCLWrapper() : rank_id_(-1), rank_size_(0) {} - ~NCCLWrapper() = default; - - private: - int rank_id_; - int rank_size_; - ncclUniqueId unique_id_; - ncclComm_t comm_; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.cc b/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.cc deleted file mode 100644 index 621ba557e5..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.cc +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_buffer_mgr.h" -#include -#include -#include "utils/log_adapter.h" -#include "common/utils.h" - -namespace mindspore { -namespace device { -unsigned int HandleMgr::AllocHandle() { - for (size_t i = 0; i < MAX_HANDLE_NUM; ++i) { - if (!handle_list_[i]) { - handle_list_[i] = true; - return (unsigned int)i; - } - } - return INVALID_HANDLE; -} - -void HandleMgr::FreeHandle(unsigned int handle_id) { - if (handle_id >= MAX_HANDLE_NUM) { - return; - } - handle_list_[handle_id] = false; -} - -GpuBufferMgr &GpuBufferMgr::GetInstance() noexcept { - static GpuBufferMgr instance; - return instance; -} - -BlockQueueStatus_T GpuBufferMgr::Create(unsigned int device_id, const std::string &channel_name, void *addr, - const std::vector &shape, const size_t &capacity) { - std::string name = std::to_string(device_id) + std::string("_") + channel_name; - if (name_queue_map_.count(name)) { - MS_LOG(ERROR) << "Queue not exist " << name; - return QUEUE_NOT_EXIST; - } - std::shared_ptr queue = std::make_shared(); - BlockQueueStatus_T rt = queue->Create(addr, shape, capacity); - if (rt != SUCCESS) { - return rt; - } - (void)name_queue_map_.insert(std::make_pair(name, queue)); - init_ = true; - return SUCCESS; -} - -unsigned int GpuBufferMgr::Open(unsigned int device_id, const std::string &channel_name, - const std::vector &shape, const std::function func) { - set_device(); - std::string name = std::to_string(device_id) + std::string("_") + channel_name; - if (!name_queue_map_.count(name)) { - MS_LOG(ERROR) << "Queue not exist " << name; - return HandleMgr::INVALID_HANDLE; - } - unsigned int handle = handle_mgr_.AllocHandle(); - if (handle == HandleMgr::INVALID_HANDLE) { - MS_LOG(ERROR) << "handle is invalid"; - return HandleMgr::INVALID_HANDLE; - } - (void)handle_queue_map_.insert(std::make_pair(handle, name_queue_map_[name])); - name_queue_map_[name]->RegisterRelease(func); - open_by_dataset_++; - return handle; -} - -unsigned int GpuBufferMgr::Open(unsigned int device_id, const std::string &channel_name, - const std::vector &shape) { - set_device(); - std::string name = std::to_string(device_id) + std::string("_") + channel_name; - if (!name_queue_map_.count(name)) { - MS_LOG(ERROR) << "Queue not exist " << name; - return HandleMgr::INVALID_HANDLE; - } - unsigned int handle = handle_mgr_.AllocHandle(); - if (handle == HandleMgr::INVALID_HANDLE) { - MS_LOG(ERROR) << "handle is invalid"; - return HandleMgr::INVALID_HANDLE; - } - (void)handle_queue_map_.insert(std::make_pair(handle, name_queue_map_[name])); - return handle; -} - -void GpuBufferMgr::set_device_id(int device_id) { cur_dev_id_ = device_id; } - -void GpuBufferMgr::set_device() const { - auto ret = cudaSetDevice(cur_dev_id_); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaSetDevice, ret[" << static_cast(ret) << "]"; - } -} - -BlockQueueStatus_T GpuBufferMgr::Push(unsigned int handle, const std::vector &data, - unsigned int timeout_in_sec) { - auto iter = handle_queue_map_.find(handle); - if (iter == handle_queue_map_.end()) { - return HANDLE_NOT_EXIST; - } - return iter->second->Push(data, timeout_in_sec); -} - -BlockQueueStatus_T GpuBufferMgr::Front(unsigned int handle, void **addr, size_t *len) { - auto iter = handle_queue_map_.find(handle); - if (iter == handle_queue_map_.end()) { - return HANDLE_NOT_EXIST; - } - return iter->second->Front(addr, len); -} - -BlockQueueStatus_T GpuBufferMgr::Pop(unsigned int handle) { - auto iter = handle_queue_map_.find(handle); - if (iter == handle_queue_map_.end()) { - return HANDLE_NOT_EXIST; - } - return iter->second->Pop(); -} - -void GpuBufferMgr::Close(unsigned int handle) noexcept { - if (!handle_queue_map_.count(handle)) { - return; - } - (void)handle_queue_map_.erase(handle); - handle_mgr_.FreeHandle(handle); - return; -} - -bool GpuBufferMgr::IsInit() const { return init_; } - -bool GpuBufferMgr::IsClosed() const { return closed_; } - -bool GpuBufferMgr::Destroy() { - for (auto iter = name_queue_map_.begin(); iter != name_queue_map_.end(); ++iter) { - std::shared_ptr queue = iter->second; - if (queue != nullptr) { - if (!queue->Destroy()) { - return false; - } - queue.reset(); - } - } - name_queue_map_.clear(); - return true; -} - -inline bool GpuBufferMgr::isCreated(unsigned int device_id, const std::string &channel_name) { - std::string name = std::to_string(device_id) + std::string("_") + channel_name; - if (name_queue_map_.count(name) != 0) { - return true; - } - return false; -} - -bool GpuBufferMgr::CloseNotify() { - bool result = true; - // lock scope - { - std::lock_guard lk(close_mutex_); - // set closed_ to be true, all the dataset retry can be jumped out of the while - closed_ = true; - } - - // wati for the dataset threads' ack - for (int i = 0; i < open_by_dataset_; i++) { - if (sema.Wait() == false) { - MS_LOG(ERROR) << "time out of receiving signals"; - result = false; - } - MS_LOG(DEBUG) << "receive one signal (" << i + 1 << "/" << open_by_dataset_ << ")"; - } - return result; -} - -void GpuBufferMgr::CloseConfirm() { sema.Signal(); } -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.h b/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.h deleted file mode 100644 index 5ce4a2cbdc..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.h +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "device/gpu/blocking_queue.h" - -#define EXPORT __attribute__((visibility("default"))) - -namespace mindspore { -namespace device { -static const unsigned int MAX_WAIT_TIME_IN_SEC = 60; - -class Semaphore { - public: - explicit Semaphore(int count = 0) : count_(count) {} - - inline void Signal() { - std::unique_lock lock(mutex_); - ++count_; - cv_.notify_one(); - } - - inline bool Wait() { - std::unique_lock lock(mutex_); - while (count_ == 0) { - if (cv_.wait_for(lock, std::chrono::seconds(MAX_WAIT_TIME_IN_SEC)) == std::cv_status::timeout) { - return false; - } - } - --count_; - return true; - } - - private: - std::mutex mutex_; - std::condition_variable cv_; - int count_; -}; - -class HandleMgr { - public: - static const unsigned int MAX_HANDLE_NUM = 32; - static const unsigned int INVALID_HANDLE = 0xffffffffUL; - - unsigned int AllocHandle(); - void FreeHandle(unsigned int); - - private: - bool handle_list_[MAX_HANDLE_NUM]; -}; - -class GpuBufferMgr { - public: - EXPORT GpuBufferMgr() : cur_dev_id_(0), init_(false), closed_(false), open_by_dataset_(0) {} - - EXPORT virtual ~GpuBufferMgr() = default; - - EXPORT static GpuBufferMgr &GetInstance() noexcept; - - EXPORT BlockQueueStatus_T Create(unsigned int device_id, const std::string &channel_name, void *addr, - const std::vector &shape, const size_t &capacity); - - // call for Push thread - EXPORT unsigned int Open(unsigned int device_id, const std::string &channel_name, const std::vector &shape, - std::function func); - - // call for Front/Pop thread - EXPORT unsigned int Open(unsigned int device_id, const std::string &channel_name, const std::vector &shape); - - EXPORT BlockQueueStatus_T Push(unsigned int handle, const std::vector &data, - unsigned int timeout_in_sec); - EXPORT BlockQueueStatus_T Front(unsigned int handle, void **addr, size_t *len); - EXPORT BlockQueueStatus_T Pop(unsigned int handle); - - EXPORT void set_device_id(int device_id); - - EXPORT void Close(unsigned int handle) noexcept; - - EXPORT bool IsInit() const; - - EXPORT bool IsClosed() const; - - EXPORT bool Destroy(); - - // call for Release GPU Resources - EXPORT bool CloseNotify(); - - // call for dataset send thread - EXPORT void CloseConfirm(); - - private: - void set_device() const; - - int cur_dev_id_; - bool init_; - bool closed_; - std::mutex mutex_; - std::mutex close_mutex_; - // how many queues opened by dataset - int open_by_dataset_; - Semaphore sema; - - HandleMgr handle_mgr_; - - std::map> handle_queue_map_; - std::map> name_queue_map_; - - inline bool isCreated(unsigned int device_id, const std::string &channel_name); - - GpuBufferMgr(const GpuBufferMgr &) = delete; - GpuBufferMgr &operator=(const GpuBufferMgr &) = delete; -}; -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_device_address.cc b/mindspore/ccsrc/device/gpu/gpu_device_address.cc deleted file mode 100644 index 401eb9f34e..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_device_address.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_device_address.h" -#include -#include "device/gpu/gpu_device_manager.h" -#include "utils/log_adapter.h" -#include "device/gpu/gpu_memory_allocator.h" - -namespace mindspore { -namespace device { -namespace gpu { -bool GPUDeviceAddress::SyncDeviceToHost(const std::vector &, size_t size, TypeId, void *host_ptr) const { - MS_EXCEPTION_IF_NULL(host_ptr); - auto &stream = GPUDeviceManager::GetInstance().default_stream(); - MS_EXCEPTION_IF_NULL(stream); - auto ret = GPUDeviceManager::GetInstance().SyncStream(stream); - if (!ret) { - MS_LOG(ERROR) << "SyncStream failed"; - return ret; - } - if (size != size_) { - MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_; - return true; - } - return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size_); -} - -bool GPUDeviceAddress::SyncHostToDevice(const std::vector &, size_t, TypeId, const void *host_ptr) const { - MS_EXCEPTION_IF_NULL(host_ptr); - auto &stream = GPUDeviceManager::GetInstance().default_stream(); - MS_EXCEPTION_IF_NULL(stream); - if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(ptr_, host_ptr, size_, stream)) { - MS_LOG(ERROR) << "CopyHostMemToDeviceAsync failed"; - return false; - } - return GPUDeviceManager::GetInstance().SyncStream(stream); -} - -GPUDeviceAddress::~GPUDeviceAddress() { - if (ptr_ == nullptr) { - return; - } - if (from_mem_pool_) { - GPUMemoryAllocator::GetInstance().FreeTensorMem(ptr_); - ptr_ = nullptr; - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_device_address.h b/mindspore/ccsrc/device/gpu/gpu_device_address.h deleted file mode 100644 index 4074cb6ce9..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_device_address.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ - -#include -#include -#include "device/device_address.h" - -namespace mindspore { -namespace device { -namespace gpu { -class GPUDeviceAddress : public DeviceAddress { - public: - GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} - GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) - : DeviceAddress(ptr, size, format, type_id) {} - ~GPUDeviceAddress() override; - - bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; - void set_status(DeviceAddressStatus status) { status_ = status; } - DeviceAddressStatus status() const { return status_; } - DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; } - - private: - DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc deleted file mode 100644 index 9f5f37c606..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_device_manager.h" -#include "device/gpu/gpu_common.h" -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" -#include "device/gpu/gpu_buffer_mgr.h" - -namespace mindspore { -namespace device { -namespace gpu { -void GPUDeviceManager::InitDevice() { - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::set_current_device(SizeToInt(cur_dev_id_)), "Failed to set current device id"); - CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetStream(cudnn_handle_, reinterpret_cast(default_stream())), - "Failed to set stream for cuDNN handle."); - CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle."); - CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast(default_stream())), - "Failed to set stream for cuBLAS handle."); - CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator") -} - -void GPUDeviceManager::ReleaseDevice() { - for (DeviceStream stream : gpu_streams_) { - if (stream != nullptr) { - CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream."); - } - } - if (cudnn_handle_ != nullptr) { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle"); - } - if (cublas_handle_ != nullptr) { - CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle."); - } - CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); -} - -bool GPUDeviceManager::CreateStream(DeviceStream *stream) { - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); - gpu_streams_.emplace_back(*stream); - return true; -} - -const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } - -int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } - -bool GPUDeviceManager::set_cur_device_id(uint32_t device_id) { - if (!dev_id_init_) { - dev_id_init_ = true; - cur_dev_id_ = device_id; - mindspore::device::GpuBufferMgr::GetInstance().set_device_id(UintToInt(device_id)); - return true; - } else { - MS_LOG(ERROR) << "Device already been set."; - return false; - } -} - -uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } - -bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } - -const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } - -const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } - -bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } - -bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { - return CudaDriver::CopyDeviceMemToHost(dst, src, size); -} - -bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { - return CudaDriver::CopyHostMemToDevice(dst, src, size); -} - -bool GPUDeviceManager::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, - DeviceStream stream) const { - return CudaDriver::CopyDeviceMemToHostAsync(dst, src, size, stream); -} - -bool GPUDeviceManager::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, - DeviceStream stream) const { - return CudaDriver::CopyHostMemToDeviceAsync(dst, src, size, stream); -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/device/gpu/gpu_device_manager.h deleted file mode 100644 index b6b630181e..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ - -#include -#include -#include -#include -#include "device/gpu/cuda_driver.h" -#include "device/gpu/gpu_memory_allocator.h" - -namespace mindspore { -namespace device { -namespace gpu { -class GPUDeviceManager { - public: - void InitDevice(); - void ReleaseDevice(); - - int device_count() const; - bool set_cur_device_id(uint32_t device_id); - uint32_t cur_device_id() const; - bool is_device_id_init() const; - - bool CreateStream(DeviceStream *stream); - bool SyncStream(const DeviceStream &stream) const; - const DeviceStream &default_stream() const; - - const cudnnHandle_t &GetCudnnHandle() const; - const cublasHandle_t &GetCublasHandle() const; - - bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; - bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; - - bool CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, DeviceStream stream) const; - bool CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, DeviceStream stream) const; - - static GPUDeviceManager &GetInstance() { - static GPUDeviceManager instance; - return instance; - } - - private: - GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} - ~GPUDeviceManager() = default; - GPUDeviceManager(const GPUDeviceManager &) = delete; - GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; - - // default CUDA stream used for all the kernels. - DeviceStream default_stream_{nullptr}; - - // all gpu CUDA streams including default_stream_. - std::vector gpu_streams_; - - // handle used for cuDNN kernels. - cudnnHandle_t cudnn_handle_{nullptr}; - - // handle used for cuBLAS kernels. - cublasHandle_t cublas_handle_{nullptr}; - - bool dev_id_init_; - uint32_t cur_dev_id_; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_build.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_build.cc deleted file mode 100644 index 19d2284510..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_build.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/gpu/gpu_kernel_build.h" -#include -#include "kernel/kernel.h" -#include "kernel/akg/akg_kernel_build.h" -#include "kernel/akg/gpu/akg_gpu_kernel_build.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "operator/ops.h" -#include "session/anf_runtime_algorithm.h" -namespace mindspore { -namespace device { -namespace gpu { -void GpuBuild(const KernelGraphPtr &kernel_graph) { - kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); - MS_EXCEPTION_IF_NULL(bin_map); - bin_map->Initialize(); - MS_EXCEPTION_IF_NULL(kernel_graph); - auto kernels = kernel_graph->execution_order(); - for (const auto &kernel : kernels) { - std::string kernel_name = session::AnfRuntimeAlgorithm::GetCNodeName(kernel); - if (kernel_name == prim::kPrimTupleGetItem->name() || kernel_name == prim::kPrimMakeTuple->name() || - kernel_name == prim::kPrimDepend->name() || kernel_name == prim::kPrimStateSetItem->name()) { - continue; - } - - if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) { - auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); - if (!gpu_kernel_ptr) { - MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; - } - session::AnfRuntimeAlgorithm::SetKernelMod(gpu_kernel_ptr, kernel.get()); - } else { - auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel); - if (!gpu_kernel_ptr) { - MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel_name << "] failed"; - } - if (!gpu_kernel_ptr->Init(kernel)) { - MS_LOG(EXCEPTION) << "Initialize gpu kernel op[" << kernel_name << "] failed."; - } - session::AnfRuntimeAlgorithm::SetKernelMod((kernel::KernelModPtr)gpu_kernel_ptr, kernel.get()); - } - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_build.h b/mindspore/ccsrc/device/gpu/gpu_kernel_build.h deleted file mode 100644 index 5770e4d3b1..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_build.h +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ - -#include -#include "session/kernel_graph.h" -namespace mindspore { -namespace device { -namespace gpu { -void GpuBuild(const std::shared_ptr &kernel_graph); -} // namespace gpu -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc deleted file mode 100644 index ad0e093d7f..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ /dev/null @@ -1,611 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_kernel_runtime.h" -#include "device/gpu/gpu_device_address.h" -#include "device/gpu/cuda_driver.h" -#include "device/gpu/gpu_buffer_mgr.h" -#include "device/gpu/gpu_device_manager.h" -#include "device/gpu/gpu_memory_allocator.h" -#include "device/gpu/distribution/collective_init.h" -#include "utils/convert_utils.h" -#include "utils/context/ms_context.h" -#include "device/kernel_runtime_manager.h" -#include "device/gpu/gpu_common.h" -#include "common/utils.h" -#include "device/gpu/gpu_memory_manager.h" -#include "kernel/common_utils.h" -#include "device/gpu/gpu_memory_copy_manager.h" - -namespace mindspore { -namespace device { -namespace gpu { -using mindspore::device::memswap::MemSwapManager; -using mindspore::device::memswap::SwapKind; -bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } - -bool GPUKernelRuntime::Init() { - if (device_init_ == true) { - GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); - return true; - } - auto ret = InitDevice(); - if (!ret) { - MS_LOG(ERROR) << "InitDevice error."; - return ret; - } - mem_manager_ = std::make_shared(); - MS_EXCEPTION_IF_NULL(mem_manager_); - mem_manager_->MallocDeviceMemory(); - const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); - bool collective_inited = CollectiveInitializer::instance().collective_inited(); - if (collective_inited && collective_handle_ != nullptr) { - auto init_nccl_comm_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "InitNCCLComm")); - MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr); - (*init_nccl_comm_funcptr)(); - } - device_init_ = true; - return ret; -} - -DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) { - return std::make_shared(device_ptr, device_size, format, type_id); -} - -bool GPUKernelRuntime::InitDevice() { - if (GPUDeviceManager::GetInstance().device_count() <= 0) { - MS_LOG(ERROR) << "No GPU device found."; - return false; - } - const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); - bool collective_inited = CollectiveInitializer::instance().collective_inited(); - if (collective_inited && collective_handle_ != nullptr) { - auto get_local_rank_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); - MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); - device_id_ = IntToUint((*get_local_rank_funcptr)()); - } - if (!GPUDeviceManager::GetInstance().is_device_id_init()) { - if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_id_)) { - MS_LOG(ERROR) << "Failed to set current device to " << SizeToInt(device_id_); - return false; - } - } - GPUDeviceManager::GetInstance().InitDevice(); - stream_ = GPUDeviceManager::GetInstance().default_stream(); - if (stream_ == nullptr) { - MS_LOG(ERROR) << "No default CUDA stream found."; - return false; - } - return true; -} - -void GPUKernelRuntime::ReleaseDeviceRes() { - // For dataset mode. - if (GpuBufferMgr::GetInstance().IsInit()) { - if (!GpuBufferMgr::GetInstance().IsClosed()) { - if (!GpuBufferMgr::GetInstance().CloseNotify()) { - MS_LOG(EXCEPTION) << "Could not close gpu data queue."; - } - } - CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue."); - } - - // Destroy remaining memory swap events and free host memory. - for (auto &item : mem_swap_map_) { - auto &mem_swap_manager = item.second; - MS_EXCEPTION_IF_NULL(mem_swap_manager); - if (mem_swap_manager->trigger_swap()) { - mem_swap_manager->ClearSwapQueue(); - mem_swap_manager->ReleaseHostPinnedMem(); - } - } - - GPUDeviceManager::GetInstance().ReleaseDevice(); - if (mem_manager_ != nullptr) { - mem_manager_->FreeDeviceMemory(); - } - - kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); - MS_EXCEPTION_IF_NULL(bin_map); - bin_map->RemoveKernelCache(); -} - -void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(mem_manager_); - mem_manager_->ResetDynamicMemory(); - AssignStaticMemoryInput(graph); - AssignStaticMemoryValueNode(graph); - bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); - if (is_enable_dynamic_mem) { - // Use the dynamic memory pool. - InitKernelRefCount(graph); - InitKernelOutputAddress(graph); - } else { - AssignDynamicMemory(graph); - } -} - -bool GPUKernelRuntime::Run(session::KernelGraph *graph) { - bool ret = true; - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); - bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); - auto iter = mem_swap_map_.find(graph); - if (iter == mem_swap_map_.end()) { - GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared(); - iter = mem_swap_map_.emplace(graph, std::make_shared(gpu_mem_copy_manager)).first; - } - mem_swap_manager_ = iter->second; - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - if (is_enable_dynamic_mem && !is_enable_pynative_infer) { - while (!LaunchKernelDynamic(graph)) { - ClearKernelOutputAddress(graph); - if (!mem_swap_manager_->mem_swap_init()) { - mem_swap_manager_->Init(graph); - } - if (!mem_swap_manager_->RetreatSwapInfo()) { - return false; - } - } - } else { - ret = LaunchKernel(graph); - } - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(DEBUG) << "GPU kernel runtime run graph in " << cost << " us"; - return ret; -} - -void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - // Init the kernel reference count. - if (!mem_reuse_util_ptr->InitDynamicKernelRef(graph)) { - MS_LOG(EXCEPTION) << "Init kernel reference count failed"; - } - mem_reuse_util_ptr->SetKernelDefMap(); - mem_reuse_util_ptr->SetReuseRefCount(); - // Can't free the device address of graph output, so set the reference count of graph output specially. - mem_reuse_util_ptr->SetGraphOutputRefCount(); - // Can't free the device address of summary nodes, so set the reference count of summary nodes specially. - mem_reuse_util_ptr->SetSummaryNodesRefCount(); - auto graph_id = graph->graph_id(); - mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr; -} - -void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto &kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - if (AnfAlgo::OutputAddrExist(kernel, i)) { - continue; - } - std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); - auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); - auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); - AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); - } - } -} - -void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto &kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - if (!AnfAlgo::OutputAddrExist(kernel, i)) { - continue; - } - - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); - if (device_address->ptr_) { - mem_manager_->FreeMemFromMemPool(device_address); - } - device_address->set_status(DeviceAddressStatus::kInDevice); - } - } -} - -bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - auto graph_id = graph->graph_id(); - auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - // Reset the reference count. - mem_reuse_util_ptr->ResetDynamicUsedRefCount(); - // The inputs and outputs memory of communication kernel need be continuous, so separate processing. - AllocCommunicationOpDynamicRes(graph); - - auto &kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); - if (!ret) { - return false; - } - if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { - MS_LOG(EXCEPTION) << "Launch kernel failed."; - } - FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id); - - if (mem_swap_manager_->trigger_swap() && mem_swap_manager_->QueryKernelTriggerSwap(kernel)) { - CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); - if (!AddMemSwapTask(kernel)) { - return false; - } - } - - if (mem_swap_manager_->trigger_swap()) { - mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); - } - } - - CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); - if (mem_swap_manager_->trigger_swap()) { - mem_swap_manager_->ClearSwapQueue(); - } - return true; -} - -bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); - for (auto &mem_swap_info : mem_swap_info_list) { - auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); - const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; - auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); - - if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { - mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); - } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { - auto status = device_address->status(); - if (status == DeviceAddressStatus::kInDeviceToHost) { - mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); - device_address->set_status(DeviceAddressStatus::kInDevice); - } else if (status == DeviceAddressStatus::kInHost) { - if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) { - return false; - } - if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) { - mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address); - } - } - } - } - return true; -} - -bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) { - MS_EXCEPTION_IF_NULL(mem_manager_); - auto ret = mem_manager_->MallocMemFromMemPool(device_address, size); - if (!ret) { - if (!mem_swap_manager_->trigger_swap()) { - return false; - } - - mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); - while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { - if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { - device_address_swap_out->set_status(DeviceAddressStatus::kInHost); - mem_manager_->FreeMemFromMemPool(device_address_swap_out); - } - } - - ret = mem_manager_->MallocMemFromMemPool(device_address, size); - if (!ret) { - return false; - } - } - return true; -} - -void *GPUKernelRuntime::AttemptMallocMem(size_t size) { - MS_EXCEPTION_IF_NULL(mem_manager_); - auto device_ptr = mem_manager_->MallocMemFromMemPool(size); - if (!device_ptr) { - if (!mem_swap_manager_->trigger_swap()) { - return nullptr; - } - - mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); - while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { - if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { - device_address_swap_out->set_status(DeviceAddressStatus::kInHost); - mem_manager_->FreeMemFromMemPool(device_address_swap_out); - } - } - - device_ptr = mem_manager_->MallocMemFromMemPool(size); - if (!device_ptr) { - return nullptr; - } - } - return device_ptr; -} - -bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, - AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) { - if (!AllocKernelInputDynamicRes(kernel, kernel_inputs)) { - return false; - } - if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs)) { - return false; - } - if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces)) { - return false; - } - return true; -} - -bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_inputs); - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); - MS_EXCEPTION_IF_NULL(device_address); - if (mem_swap_manager_->trigger_swap()) { - while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { - device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); - } - - auto status = device_address->status(); - switch (status) { - case DeviceAddressStatus::kInDevice: - break; - case DeviceAddressStatus::kInHost: - break; - case DeviceAddressStatus::kInDeviceToHost: { - mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); - device_address->set_status(DeviceAddressStatus::kInDevice); - break; - } - case DeviceAddressStatus::kInHostToDevice: { - while (device_address->status() != DeviceAddressStatus::kInDevice) { - while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { - device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); - } - } - break; - } - default: - MS_LOG(ERROR) << "Invaild device address status"; - return false; - } - } - MS_EXCEPTION_IF_NULL(device_address->ptr_); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - input->size = device_address->size_; - kernel_inputs->emplace_back(input); - } - return true; -} - -bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_outputs) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_outputs); - MS_EXCEPTION_IF_NULL(mem_manager_); - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - if (mem_swap_manager_->trigger_swap()) { - while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { - if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { - device_address_swap_out->set_status(DeviceAddressStatus::kInHost); - mem_manager_->FreeMemFromMemPool(device_address_swap_out); - } - } - } - auto output_sizes = kernel_mod.GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); - MS_EXCEPTION_IF_NULL(device_address); - if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) { - return false; - } - kernel::AddressPtr output = std::make_shared(); - MS_EXCEPTION_IF_NULL(output); - output->addr = device_address->ptr_; - output->size = output_sizes[i]; - kernel_outputs->emplace_back(output); - } - return true; -} - -bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_workspaces) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_workspaces); - auto workspace_sizes = kernel_mod.GetWorkspaceSizeList(); - for (size_t i = 0; i < workspace_sizes.size(); ++i) { - if (workspace_sizes[i] == 0) { - kernel_workspaces->emplace_back(nullptr); - continue; - } - auto device_ptr = AttemptMallocMem(workspace_sizes[i]); - if (!device_ptr) { - return false; - } - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_ptr; - workspace->size = workspace_sizes[i]; - kernel_workspaces->emplace_back(workspace); - } - return true; -} - -void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto &kernels = graph->execution_order(); - for (auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - if (AnfAlgo::IsCommunicationOp(kernel)) { - AllocCommunicationOpInputDynamicRes(kernel); - AllocCommunicationOpOutputDynamicRes(kernel); - } - } -} - -void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(kernel); - bool is_need_alloc_memory = false; - bool is_need_free_memory = false; - size_t total_size = 0; - std::vector size_list; - DeviceAddressPtrList addr_list; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); - MS_EXCEPTION_IF_NULL(device_address); - if (device_address->ptr_ == nullptr) { - is_need_alloc_memory = true; - } else { - is_need_free_memory = true; - } - total_size += device_address->size_; - size_list.emplace_back(device_address->size_); - addr_list.emplace_back(device_address); - } - AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list); -} - -void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(kernel); - bool is_need_alloc_memory = false; - bool is_need_free_memory = false; - size_t total_size = 0; - std::vector size_list; - DeviceAddressPtrList addr_list; - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); - MS_EXCEPTION_IF_NULL(device_address); - if (device_address->ptr_ == nullptr) { - is_need_alloc_memory = true; - } else { - is_need_free_memory = true; - } - total_size += output_sizes[i]; - size_list.emplace_back(output_sizes[i]); - addr_list.emplace_back(device_address); - } - AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list); -} - -void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, - const DeviceAddressPtrList addr_list, size_t total_size, - std::vector size_list) { - MS_EXCEPTION_IF_NULL(mem_manager_); - if (!is_need_alloc_memory) { - return; - } - if (is_need_free_memory) { - for (const auto &iter : addr_list) { - MS_EXCEPTION_IF_NULL(iter); - // Free the inputs/outputs of communication kernel which are not released. - if (iter->ptr_ != nullptr) { - mem_manager_->FreeMemFromMemPool(iter); - } - } - } - auto ret = mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } -} - -void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, - const AddressPtrList &kernel_workspaces, uint32_t graph_id) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - auto cnode = kernel->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::IsCommunicationOp(kernel)) { - return; - } - // Free the input of kernel by reference count. - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i); - if (kernel_ref_count_ptr == nullptr) { - continue; - } - kernel_ref_count_ptr->ref_count_dynamic_use_--; - if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) { - MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; - } - if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); - mem_manager_->FreeMemFromMemPool(device_address); - device_address->set_status(DeviceAddressStatus::kInDevice); - } - } - // Free the output of kernel, if output has no reference. - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) { - auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i); - if (kernel_ref_count_ptr == nullptr) { - continue; - } - if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); - mem_manager_->FreeMemFromMemPool(device_address); - device_address->set_status(DeviceAddressStatus::kInDevice); - } - } - // Free the workspace of kernel. - for (size_t i = 0; i < kernel_workspaces.size(); ++i) { - auto workspace = kernel_workspaces[i]; - if (workspace != nullptr) { - MS_EXCEPTION_IF_NULL(workspace->addr); - mem_manager_->FreeMemFromMemPool(workspace->addr); - workspace->addr = nullptr; - } - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h deleted file mode 100644 index ea3ab17160..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ - -#include -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "device/kernel_runtime_manager.h" -#include "pre_activate/mem_reuse/mem_swap_manager.h" - -namespace mindspore { -namespace device { -namespace gpu { -using mindspore::device::memswap::MemSwapManagerPtr; -class GPUKernelRuntime : public KernelRuntime { - public: - GPUKernelRuntime() = default; - ~GPUKernelRuntime() override = default; - bool Init() override; - void ReleaseDeviceRes() override; - void AssignMemory(session::KernelGraph *graph) override; - bool Run(session::KernelGraph *graph) override; - - protected: - DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) override; - bool SyncStream() override; - - private: - GPUKernelRuntime(const GPUKernelRuntime &); - GPUKernelRuntime &operator=(const GPUKernelRuntime &); - bool InitDevice(); - bool device_init_{false}; - - // The related functions and members for using dynamic memory pool. - void InitKernelRefCount(const session::KernelGraph *graph); - void InitKernelOutputAddress(const session::KernelGraph *graph); - void ClearKernelOutputAddress(const session::KernelGraph *graph); - bool LaunchKernelDynamic(const session::KernelGraph *graph); - bool AddMemSwapTask(const AnfNodePtr &kernel); - bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); - void *AttemptMallocMem(size_t size); - bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, - AddressPtrList *kernel_outputs); - bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs); - bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_outputs); - bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces); - void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); - void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); - void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); - void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, - const DeviceAddressPtrList addr_list, size_t total_size, - std::vector size_list); - void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, - uint32_t graph_id); - std::unordered_map mem_reuse_util_map_; - std::unordered_map mem_swap_map_; - MemSwapManagerPtr mem_swap_manager_{nullptr}; -}; -MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); -} // namespace gpu -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc deleted file mode 100644 index 9137945661..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "device/gpu/gpu_memory_allocator.h" -#include "device/gpu/cuda_driver.h" -#include "utils/log_adapter.h" -#include "utils/context/ms_context.h" -#include "utils/convert_utils_base.h" - -namespace mindspore { -namespace device { -namespace gpu { -bool GPUMemoryAllocator::Init() { - size_t total_size = total_mem_size(); - size_t free_size = CudaDriver::free_mem_size(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - limited_device_memory_ = context_ptr->max_device_memory(); - available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024); - if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) { - MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size - << ", set max available memory size " << available_device_memory_ << "."; - } else { - MS_LOG(EXCEPTION) << "GPU device memory error, total memory size " << total_size << ", current free memory size " - << free_size << ", set max available memory size " << available_device_memory_ << "."; - } - return true; -} - -void GPUMemoryAllocator::CheckMaxDeviceMemory() const { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - auto max_device_memory = context_ptr->max_device_memory(); - // Currently not support modifying the max device memory. - if (limited_device_memory_ != max_device_memory) { - MS_LOG(EXCEPTION) - << "Can't change context param max_device_memory in runtime, currently effective max_device_memory(" - << limited_device_memory_ << "GB), set new max_device_memory(" << max_device_memory << "GB) failed."; - } -} - -bool GPUMemoryAllocator::Finalize() { - if (buffer_q_addr_ != nullptr) { - if (!CudaDriver::FreeDeviceMem(buffer_q_addr_)) { - MS_LOG(ERROR) << "Could not free buffer queue memory."; - return false; - } - } - return true; -} - -bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { - auto alloc_size = AllocDeviceMem(size, addr); - buffer_q_addr_ = *addr; - // Buffer queue needs to ensure that the alloc_size and size is equal. - return (alloc_size == size) ? true : false; -} - -size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { - if (size == 0) { - MS_LOG(EXCEPTION) << "The memory alloc size is 0."; - } - auto free_size = free_mem_size(); - if (size > free_size) { - MS_LOG(EXCEPTION) << "Memory not enough: current free memory size[" << free_size - << "] is smaller than required size[" << size << "]."; - } - - auto alloc_size = CudaDriver::AllocDeviceMem(size, addr); - if (alloc_size == 0) { - MS_LOG(EXCEPTION) << "Alloc device memory[" << size << "] failed."; - } - total_used_device_memory_ += alloc_size; - available_device_memory_ -= alloc_size; - MS_LOG(INFO) << "Current free memory size[" << free_size - alloc_size << "], current alloc size[" << alloc_size - << "], total used size[" << total_used_device_memory_ << "]."; - return alloc_size; -} - -bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } - -size_t GPUMemoryAllocator::free_mem_size() { return std::min(CudaDriver::free_mem_size(), available_device_memory_); } - -size_t GPUMemoryAllocator::total_mem_size() { return CudaDriver::total_mem_size(); } -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h deleted file mode 100644 index 90d7791057..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ - -#include -#include "device/gpu/cuda_driver.h" -#include "pre_activate/mem_reuse/mem_dynamic_allocator.h" - -namespace mindspore { -namespace device { -namespace gpu { -class GPUMemoryAllocator : public DynamicMemPoolBestFit { - public: - ~GPUMemoryAllocator() override = default; - bool Init(); - void CheckMaxDeviceMemory() const; - bool Finalize(); - bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); - - size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; - bool FreeDeviceMem(const DeviceMemPtr &addr) override; - size_t free_mem_size() override; - size_t total_mem_size() override; - - static GPUMemoryAllocator &GetInstance() { - static GPUMemoryAllocator instance; - return instance; - } - - private: - GPUMemoryAllocator() = default; - GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; - GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; - - // Used to track address of data buffer queue. - DeviceMemPtr buffer_q_addr_{nullptr}; - - float limited_device_memory_{0.0}; - size_t total_used_device_memory_{0}; - size_t available_device_memory_{0}; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.cc b/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.cc deleted file mode 100644 index 80206f309d..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_memory_copy_manager.h" -#include "device/gpu/gpu_common.h" -#include "device/gpu/gpu_device_manager.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace gpu { -void GPUMemCopyManager::Init() { - CHECK_OP_RET_WITH_EXCEPT(GPUDeviceManager::GetInstance().CreateStream(&swap_out_stream_), - "Failed to create CUDA stream of memory swap out."); - CHECK_OP_RET_WITH_EXCEPT(GPUDeviceManager::GetInstance().CreateStream(&swap_in_stream_), - "Failed to create CUDA stream of memory swap in."); -} - -void GPUMemCopyManager::AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(host_addr.addr); - DeviceEvent event = nullptr; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); - DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); - MS_EXCEPTION_IF_NULL(device_ptr); - device_address->set_status(DeviceAddressStatus::kInDeviceToHost); - - CHECK_OP_RET_WITH_EXCEPT( - CudaDriver::CopyDeviceMemToHostAsync(host_addr.addr, device_ptr, host_addr.size, swap_out_stream_), - "Failed to copy device memory to host."); - - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_out_stream_), - "Failed to record CUDA event to swap out stream."); - swap_out_queue_.emplace(device_address, event); -} - -void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(host_addr.addr); - DeviceEvent event = nullptr; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); - DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); - MS_EXCEPTION_IF_NULL(device_ptr); - device_address->set_status(DeviceAddressStatus::kInHostToDevice); - - CHECK_OP_RET_WITH_EXCEPT( - CudaDriver::CopyHostMemToDeviceAsync(device_ptr, host_addr.addr, host_addr.size, swap_in_stream_), - "Failed to copy host memory to device."); - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_in_stream_), - "Failed to record CUDA event to swap in stream."); - swap_in_queue_.emplace(device_address, event); -} - -bool GPUMemCopyManager::SyncMemCopyStream(SwapKind swap_kind) { - if (swap_kind == SwapKind::kDeviceToHost) { - return GPUDeviceManager::GetInstance().SyncStream(swap_out_stream_); - } else { - return GPUDeviceManager::GetInstance().SyncStream(swap_in_stream_); - } -} - -DeviceAddressPtr GPUMemCopyManager::UpdateSwapOutQueue() { - if (swap_out_queue_.empty()) { - return nullptr; - } - auto &task = swap_out_queue_.front(); - auto device_address = task.first; - auto &event = task.second; - bool finish_swap = CudaDriver::QueryEvent(event); - if (!finish_swap) { - return nullptr; - } - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap out."); - swap_out_queue_.pop(); - return device_address; -} - -DeviceAddressPtr GPUMemCopyManager::UpdateSwapInQueue() { - if (swap_in_queue_.empty()) { - return nullptr; - } - auto &task = swap_in_queue_.front(); - auto device_address = task.first; - auto &event = task.second; - bool finish_swap = CudaDriver::QueryEvent(event); - if (!finish_swap) { - return nullptr; - } - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap in."); - swap_in_queue_.pop(); - return device_address; -} - -bool GPUMemCopyManager::AllocHostPinnedMem(size_t size, void **addr) const { - auto alloc_size = CudaDriver::AllocHostPinnedMem(size, addr); - return alloc_size == size; -} - -void GPUMemCopyManager::FreeHostPinnedMem(void *addr) const { CudaDriver::FreeHostPinnedMem(addr); } - -void GPUMemCopyManager::ClearSwapQueue() { - CHECK_OP_RET_WITH_EXCEPT(SyncMemCopyStream(SwapKind::kDeviceToHost), "Failed to sync swap out stream"); - CHECK_OP_RET_WITH_EXCEPT(SyncMemCopyStream(SwapKind::kHostToDevice), "Failed to sync swap in stream"); - - while (!swap_out_queue_.empty()) { - auto &event = swap_out_queue_.front().second; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap out."); - swap_out_queue_.pop(); - } - while (!swap_in_queue_.empty()) { - auto &event = swap_in_queue_.front().second; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap in."); - swap_in_queue_.pop(); - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.h b/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.h deleted file mode 100644 index 36ff273015..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ - -#include -#include -#include -#include "pre_activate/mem_reuse/mem_copy_manager.h" -#include "device/device_address.h" -#include "device/gpu/cuda_driver.h" -#include "kernel/kernel.h" - -namespace mindspore { -namespace device { -namespace gpu { -using mindspore::device::memswap::MemCopyManager; -using mindspore::device::memswap::SwapKind; -class GPUMemCopyManager : public MemCopyManager { - public: - GPUMemCopyManager() = default; - - ~GPUMemCopyManager() override = default; - - void Init() override; - - void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; - - void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; - - bool SyncMemCopyStream(SwapKind swap_kind) override; - - DeviceAddressPtr UpdateSwapOutQueue() override; - - DeviceAddressPtr UpdateSwapInQueue() override; - - bool AllocHostPinnedMem(size_t size, void **addr) const override; - - void FreeHostPinnedMem(void *addr) const override; - - void ClearSwapQueue() override; - - private: - DeviceStream swap_out_stream_{nullptr}; - DeviceStream swap_in_stream_{nullptr}; - std::queue> swap_out_queue_; - std::queue> swap_in_queue_; -}; -using GPUMemCopyManagerPtr = std::shared_ptr; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc deleted file mode 100644 index 9a63921add..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_memory_manager.h" -#include "device/gpu/gpu_memory_allocator.h" -#include "utils/context/ms_context.h" -#include "utils/convert_utils.h" -namespace mindspore { -namespace device { -namespace gpu { -void *GPUMemoryManager::MallocMemFromMemPool(size_t size) { - return GPUMemoryAllocator::GetInstance().AllocTensorMem(size); -} - -void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { - GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); -} - -std::vector GPUMemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { - return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); -} - -void GPUMemoryManager::MallocDeviceMemory() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - // If use the dynamic memory pool, then alloc the first memory block to init. - if (context_ptr->enable_dynamic_mem_pool()) { - auto device_addr = MallocMemFromMemPool(1); - if (!device_addr) { - MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; - } - } else { - // Need to reserve 20% space for dynamic memory - const float init_gpu_mem_ratio = 0.8; - size_t mem_size = FloatToSize(GPUMemoryAllocator::GetInstance().free_mem_size() * init_gpu_mem_ratio); - auto alloc_size = - GPUMemoryAllocator::GetInstance().AllocDeviceMem(mem_size, reinterpret_cast(&device_mem_base_)); - device_mem_size_ = alloc_size; - static_mem_offset_ = device_mem_size_; - } -} - -void GPUMemoryManager::FreeDeviceMemory() { - if (device_mem_base_ != nullptr) { - if (!GPUMemoryAllocator::GetInstance().FreeDeviceMem(device_mem_base_)) { - MS_LOG(EXCEPTION) << "Could not free gpu device memory."; - } - } - GPUMemoryAllocator::GetInstance().ReleaseDeviceRes(); -} - -uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_dynamic_mem_pool()) { - auto device_ptr = MallocMemFromMemPool(size); - MS_EXCEPTION_IF_NULL(device_ptr); - return AddressOffset(device_ptr, 0); - } - - auto align_size = GetCommonAlignSize(size); - if (static_mem_offset_ < align_size) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - auto offset = static_mem_offset_ - align_size; - if (dynamic_mem_offset_ > offset) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - total_static_size_ += align_size; - static_mem_offset_ = offset; - return device_mem_base_ + offset; -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h b/mindspore/ccsrc/device/gpu/gpu_memory_manager.h deleted file mode 100644 index c79fb9cc22..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ -#include -#include "device/memory_manager.h" -namespace mindspore { -namespace device { -namespace gpu { -class GPUMemoryManager : public MemoryManager { - public: - GPUMemoryManager() = default; - virtual ~GPUMemoryManager() = default; - - void MallocDeviceMemory() override; - void FreeDeviceMemory() override; - - void *MallocMemFromMemPool(size_t size) override; - void FreeMemFromMemPool(void *device_ptr) override; - std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); - - protected: - uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc b/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc deleted file mode 100644 index 42cdcf29ec..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc +++ /dev/null @@ -1,193 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_stream_assign.h" -#include -#include -#include -#include -#include "device/gpu/gpu_common.h" -#include "device/gpu/kernel_info_setter.h" -#include "device/gpu/gpu_device_manager.h" - -namespace mindspore { -namespace device { -namespace gpu { -void AssignGpuStream(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector allreduce_kernels; - auto execution_kernels = kernel_graph->execution_order(); - for (auto kernel_node : execution_kernels) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == kAllReduceOpName) { - allreduce_kernels.emplace_back(kernel_node); - } else { - DeviceStream compute_stream = GPUDeviceManager::GetInstance().default_stream(); - MS_EXCEPTION_IF_NULL(compute_stream); - AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(compute_stream)), kernel_node); - } - } - if (allreduce_kernels.size() > 1) { - // Assign multiple streams only when there're multiple AllReduce nodes. - std::vector send_recv_pairs; - if (FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs)) { - DeviceStream comm_stream = nullptr; - GPUDeviceManager::GetInstance().CreateStream(&comm_stream); - std::transform( - allreduce_kernels.begin(), allreduce_kernels.end(), allreduce_kernels.begin(), [&](CNodePtr allreduce_kernel) { - AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(comm_stream)), allreduce_kernel); - return allreduce_kernel; - }); - InsertStreamSwitchNode(kernel_graph, send_recv_pairs); - } else { - return; - } - } -} - -bool FindAllReduceStreamSwitchPos(const std::shared_ptr &kernel_graph, - std::vector *send_recv_pairs) { - auto execution_kernels = kernel_graph->execution_order(); - std::vector::iterator iter, iter_begin; - iter = iter_begin = execution_kernels.begin(); - std::vector::iterator iter_end = execution_kernels.end(); - for (; iter != execution_kernels.end(); ++iter) { - std::string kernel_name = AnfAlgo::GetCNodeName(*iter); - if (kernel_name == kAllReduceOpName) { - // Find AllReduce node's last input node. - std::vector::iterator mock_send_node_iter = - FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch); - if (mock_send_node_iter == iter + 1) { - MS_LOG(WARNING) << "Can't find send node place before AllReduce node."; - continue; - } - SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter, - IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)}; - send_recv_pairs->push_back(pair1); - // Find node which uses AllReduce as input[0]. - std::vector::iterator mock_recv_node_iter = - FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch); - if (mock_recv_node_iter == iter_end) { - MS_LOG(WARNING) << "Can't find recv node place after AllReduce node."; - return false; - } - SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1), - IntToSize(mock_recv_node_iter - iter_begin)}; - send_recv_pairs->push_back(pair2); - } - } - return true; -} - -std::vector::iterator FindSendNodePos(std::vector::iterator begin, - std::vector::iterator end, const CNodePtr mock_recv_node, - StreamSwitchType stream_switch_type) { - MS_EXCEPTION_IF_NULL(mock_recv_node); - if (stream_switch_type == kAllReduceStreamSwitch) { - for (auto iter = begin; iter != end; iter++) { - if (*(iter + 1) == mock_recv_node) { - return iter; - } - } - } - return end; -} - -std::vector::iterator FindRecvNodePos(std::vector::iterator begin, - std::vector::iterator end, const CNodePtr mock_send_node, - StreamSwitchType stream_switch_type) { - MS_EXCEPTION_IF_NULL(mock_send_node); - for (auto iter = begin; iter != end; iter++) { - auto node = *iter; - if (stream_switch_type == kAllReduceStreamSwitch) { - for (auto input : node->inputs()) { - if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) { - return iter; - } - } - } - } - return end; -} - -void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, - const std::vector &send_recv_pairs) { - std::set ordered_stream_switch_nodes; - for (SendRecvPair pair : send_recv_pairs) { - StreamSwitchType stream_switch_type = pair.stream_switch_type; - CNodePtr mock_send_node = pair.mock_send_node; - CNodePtr mock_recv_node = pair.mock_recv_node; - size_t send_node_offset = pair.send_node_offset; - size_t recv_node_offset = pair.recv_node_offset; - CNodePtr send_node = nullptr; - CNodePtr recv_node = nullptr; - // Step 1: generate Send and Recv CNodes. - if (stream_switch_type == kAllReduceStreamSwitch) { - if (!GenSendRecvCNodesForAllReduce(kernel_graph, mock_send_node, mock_recv_node, &send_node, &recv_node)) { - MS_LOG(EXCEPTION) << "Generating CNodes for send and recv failed. Stream switch type: kAllReduceStreamSwitch"; - } - } - // Step 2: sort send and recv CNodes by offset. - ordered_stream_switch_nodes.insert({send_node_offset, send_node}); - ordered_stream_switch_nodes.insert({recv_node_offset, recv_node}); - } - // Step 3: insert stream switch CNodes into execution kernel list. - auto execution_kernels = kernel_graph->execution_order(); - for (auto node = ordered_stream_switch_nodes.rbegin(); node != ordered_stream_switch_nodes.rend(); node++) { - execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode); - } - kernel_graph->set_execution_order(execution_kernels); -} - -bool GenSendRecvCNodesForAllReduce(const std::shared_ptr &kernel_graph, - const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node, - CNodePtr *recv_node) { - *send_node = CreateStreamSwitchNode(kernel_graph, kSendOpName); - MS_EXCEPTION_IF_NULL(*send_node); - *recv_node = CreateStreamSwitchNode(kernel_graph, kRecvOpName); - MS_EXCEPTION_IF_NULL(*recv_node); - - cudaEvent_t event = nullptr; - CHECK_CUDA_RET_WITH_EXCEPT(cudaEventCreate(&event, cudaEventDisableTiming), "Creating cuda event failed."); - AnfAlgo::SetNodeAttr(kAttrRecordEvent, MakeValue(reinterpret_cast(event)), *send_node); - AnfAlgo::SetNodeAttr(kAttrWaitEvent, MakeValue(reinterpret_cast(event)), *recv_node); - - uintptr_t send_stream = AnfAlgo::GetNodeAttr(mock_send_node, kAttrStreamId); - AnfAlgo::SetNodeAttr(kAttrRecordEventStream, MakeValue(send_stream), *send_node); - uintptr_t recv_stream = AnfAlgo::GetNodeAttr(mock_recv_node, kAttrStreamId); - AnfAlgo::SetNodeAttr(kAttrWaitEventStream, MakeValue(recv_stream), *recv_node); - return true; -} - -CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name) { - auto op = std::make_shared(name); - MS_EXCEPTION_IF_NULL(op); - auto apply = std::make_shared(op); - MS_EXCEPTION_IF_NULL(apply); - std::vector input_list = {apply}; - CNodePtr node = kernel_graph->NewCNode(input_list); - MS_EXCEPTION_IF_NULL(node); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), node.get()); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - node->set_abstract(abstract_none); - SetKernelInfo(node); - return node; -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_stream_assign.h b/mindspore/ccsrc/device/gpu/gpu_stream_assign.h deleted file mode 100644 index f8041878b2..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_stream_assign.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ - -#include -#include -#include -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace gpu { -enum StreamSwitchType { kAllReduceStreamSwitch, kStreamSwitchInvalidType = 255 }; -struct SendRecvPair { - StreamSwitchType stream_switch_type; - CNodePtr mock_send_node; - CNodePtr mock_recv_node; - size_t send_node_offset; - size_t recv_node_offset; -}; -struct StreamSwitchNode { - size_t offset; - CNodePtr cnode; - bool operator<(const StreamSwitchNode &n) const { - if (offset < n.offset) { - return true; - } else if (offset == n.offset) { - return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false; - } else { - return false; - } - } -}; -void AssignGpuStream(const std::shared_ptr &kernel_graph); -bool FindAllReduceStreamSwitchPos(const std::shared_ptr &kernel_graph, - std::vector *send_recv_pairs); -// Find Send node position according to "mock" recv node. -// "mock" recv node is a gpu kernel node after a real Recv node, e.g. AllReduce node. -std::vector::iterator FindSendNodePos(std::vector::iterator begin, - std::vector::iterator end, const CNodePtr mock_recv_node, - StreamSwitchType stream_switch_type); -// Find Recv node position according to "mock" send node. -// "mock" send node is a gpu kernel node before a real send node, e.g. AllReduce node. -std::vector::iterator FindRecvNodePos(std::vector::iterator begin, - std::vector::iterator end, const CNodePtr mock_send_node, - StreamSwitchType stream_switch_type); -void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, - const std::vector &send_recv_pairs); -bool GenSendRecvCNodesForAllReduce(const std::shared_ptr &kernel_graph, - const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node, - CNodePtr *recv_node); -CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name); -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc deleted file mode 100644 index 42e76e2483..0000000000 --- a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc +++ /dev/null @@ -1,211 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/kernel_info_setter.h" -#include -#include -#include "kernel/kernel.h" -#include "utils/utils.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/kernel_build_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/common_utils.h" -#include "common/utils.h" -#include "kernel/oplib/oplib.h" -#include "kernel/oplib/opinfo.h" - -namespace mindspore { -namespace device { -namespace gpu { -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; -using mindspore::kernel::KernelBuildInfo; -namespace { -bool CheckKernelInfo(const std::shared_ptr &alternative_kernel_info, - const std::shared_ptr &selected_kernel_info) { - MS_EXCEPTION_IF_NULL(selected_kernel_info); - MS_EXCEPTION_IF_NULL(alternative_kernel_info); - size_t selected_input_num = selected_kernel_info->GetInputNum(); - size_t alternative_input_num = alternative_kernel_info->GetInputNum(); - if (selected_input_num != alternative_input_num) { - return false; - } - for (size_t i = 0; i < selected_input_num; i++) { - if (selected_kernel_info->GetInputFormat(i) != alternative_kernel_info->GetInputFormat(i)) { - return false; - } - if (selected_kernel_info->GetInputDeviceType(i) != alternative_kernel_info->GetInputDeviceType(i)) { - return false; - } - } - - size_t selected_output_num = selected_kernel_info->GetOutputNum(); - size_t alternative_output_num = alternative_kernel_info->GetOutputNum(); - if (selected_output_num != alternative_output_num) { - return false; - } - for (size_t i = 0; i < selected_output_num; i++) { - if (selected_kernel_info->GetOutputFormat(i) != alternative_kernel_info->GetOutputFormat(i)) { - return false; - } - if (selected_kernel_info->GetOutputDeviceType(i) != alternative_kernel_info->GetOutputDeviceType(i)) { - return false; - } - } - return true; -} - -std::string SupportedTypeList(const CNodePtr &kernel_node) { - std::string supported_type_lists = - kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); - if (!supported_type_lists.empty()) { - return supported_type_lists; - } - std::vector> kernel_info_list; - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG); - if (op_info_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Unsupported op [" << op_name << "]"; - } - (void)ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list); - for (size_t i = 0; i < kernel_info_list.size(); i++) { - auto supported_akg_type = kernel_info_list[i]->GetAllInputDeviceTypes(); - auto supported_akg_type_out = kernel_info_list[i]->GetAllOutputDeviceTypes(); - std::string supported_akg_type_list = "in["; - for (auto type : supported_akg_type) { - supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); - } - supported_type_lists = supported_type_lists + supported_akg_type_list + "], out["; - for (auto type : supported_akg_type_out) { - supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); - } - supported_type_lists += "]; "; - } - return supported_type_lists; -} - -bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr &selected_kernel_info) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(selected_kernel_info); - std::vector> kernel_info_list; - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG); - if (op_info_ptr == nullptr) { - MS_LOG(ERROR) << "Not find op[" << op_name << "] in akg"; - return false; - } - if (!ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list)) { - MS_LOG(EXCEPTION) << "Parsed metadata of op[" << op_name << "] failed."; - } - if (kernel_info_list.empty()) { - MS_LOG(EXCEPTION) << "Akg dose not has metadata of op[" << op_name << "]."; - } - - bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(), - [&](const std::shared_ptr &alternative_kernel_info) { - return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); - }); - if (!match) { - MS_LOG(ERROR) << "Not find op[" << op_name << "] in akg"; - return false; - } - return true; -} - -void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - auto input_kernel_node = kernel_node->input(input_index + 1); - MS_EXCEPTION_IF_NULL(input_kernel_node); - if (!input_kernel_node->isa()) { - continue; - } - std::shared_ptr builder = - std::make_shared(); - - auto param = input_kernel_node->cast(); - MS_EXCEPTION_IF_NULL(param); - if (!AnfAlgo::IsParameterWeight(param)) { - std::vector output_format = {kOpFormat_DEFAULT}; - builder->SetOutputsFormat(output_format); - std::vector output_type = {AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)}; - builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); - continue; - } - if ((AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) || - (AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) { - std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; - builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; - builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); - } - } -} -} // namespace - -void SetKernelInfo(const CNodePtr &kernel_node) { - std::vector inputs_format; - std::vector inputs_type; - std::shared_ptr builder = - std::make_shared(); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - inputs_format.emplace_back(kOpFormat_DEFAULT); - inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); - } - builder->SetInputsFormat(inputs_format); - builder->SetInputsDeviceType(inputs_type); - std::vector outputs_format; - std::vector outputs_type; - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - outputs_format.emplace_back(kOpFormat_DEFAULT); - outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); - } - builder->SetOutputsFormat(outputs_format); - builder->SetOutputsDeviceType(outputs_type); - - bool result = - kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); - KernelType kernel_type = UNKNOWN_KERNEL_TYPE; - - if (!result) { - result = SelectAkgKernel(kernel_node, builder->Build()); - kernel_type = AKG_KERNEL; - } - - if (!result) { - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - std::string build_type = "in ["; - std::for_each(std::begin(inputs_type), std::end(inputs_type), - [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); - build_type += "] out ["; - std::for_each(std::begin(outputs_type), std::end(outputs_type), - [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); - build_type += "]"; - auto supported_type_lists = SupportedTypeList(kernel_node); - MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name - << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists - << ", but get " << build_type; - } - builder->SetKernelType(kernel_type); - builder->SetProcessor(kernel::Processor::CUDA); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); - SetTensorDeviceInfo(*(builder->Build()), kernel_node); -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/mpi/mpi_initializer.cc b/mindspore/ccsrc/device/gpu/mpi/mpi_initializer.cc deleted file mode 100644 index bcad74e5b5..0000000000 --- a/mindspore/ccsrc/device/gpu/mpi/mpi_initializer.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/mpi/mpi_initializer.h" - -#include -#include -#include - -namespace mindspore { -namespace device { -namespace gpu { -MPIInitializer::MPIInitializer() { - int init_flag = 0; - if (MPI_Initialized(&init_flag) != MPI_SUCCESS) { - return; - } - if (init_flag == 0) { - auto ret = MPI_Init(nullptr, nullptr); - if (ret != MPI_SUCCESS) { - return; - } - } - MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_); - MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); -} - -MPIInitializer::~MPIInitializer() { - int finalized_flag = 0; - (void)MPI_Finalized(&finalized_flag); - if (finalized_flag == 0) { - (void)MPI_Finalize(); - } -} - -MPIInitializer &MPIInitializer::GetInstance() { - static MPIInitializer instance; - return instance; -} - -int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; } - -int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; } - -PYBIND11_MODULE(_ms_mpi, mpi_initializer) { - mpi_initializer.doc() = "mindspore mpi python wrapper"; - mpi_initializer.def("get_rank_id", &MPIInitializer::get_rank_id, "get rank id"); - mpi_initializer.def("get_rank_size", &MPIInitializer::get_rank_size, "get rank size"); -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc deleted file mode 100644 index fd0a8eb967..0000000000 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ /dev/null @@ -1,573 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/kernel_adjust.h" - -#include -#include -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "utils/context/ms_context.h" -#include "common/trans.h" -#include "utils/config_manager.h" -#include "common/utils.h" -#include "kernel/kernel_build_info.h" -#include "utils/utils.h" -#include "device/ascend/profiling/profiling_manager.h" -#include "device/ascend/kernel_select_ascend.h" -#include "runtime/base.h" -#include "device/ascend/ascend_stream_assign.h" -namespace mindspore { -namespace device { -using device::ascend::ProfilingUtils; -void KernelAdjust::ReorderGetNext(const std::shared_ptr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - const std::vector &origin_cnode_list = kernel_graph_ptr->execution_order(); - std::vector getnext_list; - std::vector other_list; - for (const auto &cnode : origin_cnode_list) { - if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { - getnext_list.emplace_back(cnode); - } else { - other_list.emplace_back(cnode); - } - } - std::vector new_order_list; - new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end()); - new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end()); - kernel_graph_ptr->set_execution_order(new_order_list); -} - -bool KernelAdjust::NeedInsertSwitch() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && - ConfigManager::GetInstance().iter_num() > 1); -} - -CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, - uint32_t event_id) { - MS_EXCEPTION_IF_NULL(graph_ptr); - auto send_op = std::make_shared(kSendOpName); - MS_EXCEPTION_IF_NULL(send_op); - auto send_apply = std::make_shared(send_op); - MS_EXCEPTION_IF_NULL(send_apply); - std::vector send_input_list = {send_apply}; - CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); - MS_EXCEPTION_IF_NULL(send_node_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - send_node_ptr->set_abstract(abstract_none); - return send_node_ptr; -} - -CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, - uint32_t event_id) { - MS_EXCEPTION_IF_NULL(graph_ptr); - auto recv_op = std::make_shared(kRecvOpName); - MS_EXCEPTION_IF_NULL(recv_op); - auto recv_apply = std::make_shared(recv_op); - MS_EXCEPTION_IF_NULL(recv_apply); - std::vector recv_input_list = {recv_apply}; - CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); - MS_EXCEPTION_IF_NULL(recv_node_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - recv_node_ptr->set_abstract(abstract_none); - return recv_node_ptr; -} - -void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { - device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); - resource_manager.ResetResource(); - if (!NeedInsertSwitch()) { - return; - } - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX; - ReorderGetNext(kernel_graph_ptr); - std::map switch_loop_input; - CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); - - std::vector *mute_inputs = kernel_graph_ptr->MutableInputs(); - MS_EXCEPTION_IF_NULL(mute_inputs); - mute_inputs->push_back(switch_loop_input[kLoopCountParamName]); - mute_inputs->push_back(switch_loop_input[kIterLoopParamName]); - mute_inputs->push_back(switch_loop_input[kZeroParamName]); - mute_inputs->push_back(switch_loop_input[kOneParamName]); - for (const auto &input : kernel_graph_ptr->inputs()) { - MS_EXCEPTION_IF_NULL(input); - if (input->isa()) { - ParameterPtr param_ptr = input->cast(); - if (param_ptr == nullptr) { - MS_EXCEPTION(NotSupportError) << "Cast to parameter point failed !"; - } - } - } - - const std::vector &orders = kernel_graph_ptr->execution_order(); - if (orders.empty()) { - MS_LOG(EXCEPTION) << "graph execution order is empty"; - } - - std::vector exec_order; - std::vector getnext_active_streams; - std::vector fpbp_active_streams; - CNodePtr getnext_cnode; - uint32_t eos_done_event_id = UINT32_MAX; - - // getnext loop process - // getnext loop stream switch op - CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(getnext_switch_app); - uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); - exec_order.push_back(getnext_switch_app); - - // getnext op - uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); - size_t i = 0; - for (; i < orders.size(); i++) { - auto node = orders[i]; - exec_order.push_back(node); - AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); - if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { - getnext_cnode = node; - break; - } - } - - // update getnext loop stream switch true_branch_stream attr - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); - - // getnext loop fpbp start send - uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); - CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id); - AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get()); - exec_order.push_back(fpbp_start_send); - - if (eos_mode) { - // getnext loop eos start send - uint32_t eos_start_event_id = resource_manager.ApplyNewEvent(); - CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id); - AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get()); - exec_order.push_back(eos_start_send); - - // End Of Sequence loop process - // eos loop stream switch - CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(eos_switch_app); - uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); - AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), eos_switch_app); - exec_order.push_back(eos_switch_app); - - // eos loop eos start recv - CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id); - uint32_t eos_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get()); - exec_order.push_back(eos_start_recv); - - // update eos loop stream switch true_branch_stream attr - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(eos_stream_id), eos_switch_app); - - // EndOfSequence op - CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); - MS_EXCEPTION_IF_NULL(end_of_sequence_op); - AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get()); - exec_order.push_back(end_of_sequence_op); - - // eos loop eos done send - eos_done_event_id = resource_manager.ApplyNewEvent(); - CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id); - AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get()); - exec_order.push_back(eos_done_send); - - // eos loop stream active - fpbp_active_streams.push_back(eos_switch_stream_id); - } - - // fpbp loop process - // fpbp loop stream switch - CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(fpbp_switch_app); - uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); - AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), fpbp_switch_app); - exec_order.push_back(fpbp_switch_app); - - // fpbp loop fpbp start recv - CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id); - uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get()); - exec_order.push_back(fpbp_start_recv); - - // update fpbp loop stream switch true_branch_stream attr - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(fpbp_stream_id), fpbp_switch_app); - - // fpbp loop AssignAdd - CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(assign_add_one); - AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get()); - exec_order.push_back(assign_add_one); - - // fpbp memcpy - std::vector memcpy_list; - std::vector other_list; - CNodePtr cur_cnode = nullptr; - for (size_t idx = i + 1; idx < orders.size(); idx++) { - cur_cnode = orders[idx]; - if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { - memcpy_list.emplace_back(cur_cnode); - } else { - other_list.emplace_back(cur_cnode); - } - } - - (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); - - // fpbp loop eos done recv - if (eos_mode) { - CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id); - AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get()); - exec_order.push_back(eos_done_recv); - } - - // stream active to activate getnext loop - CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(getnext_active_app); - getnext_active_streams.push_back(getnext_switch_stream_id); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(getnext_active_streams), - getnext_active_app); - exec_order.push_back(getnext_active_app); - - // fpbp loop other ops - (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); - - // stream active to activate fpbp loop and eos loop - CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(fpbp_active_app); - fpbp_active_streams.push_back(fpbp_switch_stream_id); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(fpbp_active_streams), fpbp_active_app); - exec_order.push_back(fpbp_active_app); - - kernel_graph_ptr->set_execution_order(exec_order); -} - -void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, - std::map *switch_loop_input) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(switch_loop_input); - std::vector shp = {1}; - tensor::TensorPtr tensor_ptr = std::make_shared(kInt32->type_id(), shp); - MS_EXCEPTION_IF_NULL(tensor_ptr); - mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract(); - if (paremeter_abstract_ptr == nullptr) { - MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!"; - } - - ParameterPtr loop_count = std::make_shared(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(loop_count); - loop_count->set_name(kLoopCountParamName); - loop_count->set_abstract(paremeter_abstract_ptr); - ParameterPtr loop_count_new = kernel_graph_ptr->NewParameter(loop_count); - - (*switch_loop_input)[kLoopCountParamName] = loop_count_new; - - ParameterPtr iter_loop = std::make_shared(kernel_graph_ptr); - iter_loop->set_name(kIterLoopParamName); - iter_loop->set_abstract(paremeter_abstract_ptr); - ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop); - (*switch_loop_input)[kIterLoopParamName] = iter_loop_new; - - ParameterPtr zero = std::make_shared(kernel_graph_ptr); - zero->set_name(kZeroParamName); - zero->set_abstract(paremeter_abstract_ptr); - ParameterPtr zero_new = kernel_graph_ptr->NewParameter(zero); - (*switch_loop_input)[kZeroParamName] = zero_new; - - ParameterPtr one = std::make_shared(kernel_graph_ptr); - one->set_name(kOneParamName); - one->set_abstract(paremeter_abstract_ptr); - ParameterPtr one_new = kernel_graph_ptr->NewParameter(one); - (*switch_loop_input)[kOneParamName] = one_new; -} - -kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBuilder( - const std::vector &formats, const std::vector &type_ids) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetInputsFormat(formats); - selected_kernel_builder.SetInputsDeviceType(type_ids); - - selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); - selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - return selected_kernel_builder; -} - -CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, - const std::map &switch_loop_input) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - auto typeNone_abstract = std::make_shared(); - auto stream_switch = std::make_shared(kStreamSwitchOpName); - std::vector inputs; - inputs.push_back(NewValueNode(stream_switch)); - inputs.push_back(switch_loop_input.at(kLoopCountParamName)); - inputs.push_back(switch_loop_input.at(kIterLoopParamName)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_switch_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_switch_app.get()); - stream_switch_app->set_abstract(typeNone_abstract); - // set attr: cond_ RT_LESS - int condition = static_cast(RT_LESS); - ValuePtr cond = MakeValue(condition); - AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); - // set attr:data_type - int data_type = static_cast(RT_SWITCH_INT64); - ValuePtr dt = MakeValue(data_type); - AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); - // set distinction label and graph id - return stream_switch_app; -} - -CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_others = std::make_shared(kStreamActiveOpName); - std::vector inputs; - inputs.push_back(NewValueNode(stream_active_others)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_active_others_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); - stream_active_others_app->set_abstract(typeNone_abstract); - return stream_active_others_app; -} - -CNodePtr KernelAdjust::CreatTupleGetItemNode(const std::shared_ptr &kernel_graph_ptr, - const CNodePtr &node, size_t output_idx) { - auto idx = NewValueNode(SizeToInt(output_idx)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(SizeToInt(output_idx)); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); - MS_EXCEPTION_IF_NULL(tuple_getitem); - tuple_getitem->set_scope(node->scope()); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); - AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); - return tuple_getitem; -} - -CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, - const CNodePtr &getnext_cnode) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT}); - selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8}); - - selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); - selected_kernel_builder.SetProcessor(kernel::Processor::AICPU); - selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL); - - selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); - selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8}); - // EndOfSequence - auto end_of_sequence = std::make_shared(kEndOfSequence); - std::vector inputs; - inputs.push_back(NewValueNode(end_of_sequence)); - // GetNext output 0 is EndOfSequence's input - auto tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0); - inputs.push_back(tuple_get_item); - CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(end_of_sequence_node); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get()); - std::vector input_names = {"x"}; - ValuePtr input_names_v = MakeValue(input_names); - AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node); - std::vector output_names = {"y"}; - ValuePtr output_names_v = MakeValue(output_names); - AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node); - end_of_sequence_node->set_abstract(tuple_get_item->abstract()); - return end_of_sequence_node; -} - -CNodePtr KernelAdjust::CreateStreamAssignAddnOP( - const std::shared_ptr &kernel_graph_ptr, - const std::map &switch_loop_input) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); - selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32}); - // AssignAdd - auto assign_add = std::make_shared(kAssignAddOpName); - std::vector inputs; - inputs.push_back(NewValueNode(assign_add)); - inputs.push_back(switch_loop_input.at(kLoopCountParamName)); - inputs.push_back(switch_loop_input.at(kOneParamName)); - CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(assign_add_one); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_add_one.get()); - std::vector input_names = {"ref", "value"}; - std::vector output_names = {"output"}; - ValuePtr input_names_v = MakeValue(input_names); - ValuePtr output_names_v = MakeValue(output_names); - AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one); - AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one); - selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); - MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); - assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); - return assign_add_one; -} - -bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr) { - if (!NeedInsertSwitch()) { - return true; - } - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - auto input_nodes = kernel_graph_ptr->inputs(); - std::vector inputs; - LoadSwitchInputs(&inputs); - std::shared_ptr> inputsPtr = std::make_shared>(inputs); - kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr); - size_t input_ctrl_size = inputs.size(); - // inputs_node:include four ctrl nodes in the back. such as:conv,loop_cnt, ites_loop, zero, one. - // deal four ctrl nodes. - for (size_t i = 0; i < inputs.size(); ++i) { - auto tensor = inputs[i]; - size_t deal_index = input_nodes.size() - input_ctrl_size + i; - if (deal_index >= input_nodes.size()) { - MS_LOG(EXCEPTION) << "deal_index[" << deal_index << "] out of range"; - } - auto input_node = input_nodes[deal_index]; - bool need_sync = false; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa()) { - auto pk_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(tensor); - MS_EXCEPTION_IF_NULL(pk_node); - if (tensor->is_dirty() || !pk_node->has_default()) { - need_sync = true; - } - } - if (need_sync) { - auto pk_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(pk_node); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - MS_EXCEPTION_IF_NULL(device_address); - tensor->set_device_address(device_address); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(INFO) << "SyncHostToDevice failed."; - return false; - } - } - tensor->set_dirty(false); - } - return true; -} - -void KernelAdjust::LoadSwitchInputs(std::vector *inputs) { - MS_LOG(INFO) << "---------------- LoadSwitchInputs---"; - MS_EXCEPTION_IF_NULL(inputs); - std::vector shp = {1}; - tensor::TensorPtr loop_count_tensor = std::make_shared(kInt32->type_id(), shp); - MS_EXCEPTION_IF_NULL(loop_count_tensor); - int32_t *val = nullptr; - val = static_cast(loop_count_tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = 0; - inputs->push_back(loop_count_tensor); - - tensor::TensorPtr iter_loop_tensor = std::make_shared(kInt32->type_id(), shp); - MS_EXCEPTION_IF_NULL(iter_loop_tensor); - val = static_cast(iter_loop_tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num())); - MS_LOG(INFO) << "iter_loop_tensor = " << *val; - inputs->push_back(iter_loop_tensor); - - tensor::TensorPtr zero_tensor = std::make_shared(kInt32->type_id(), shp); - MS_EXCEPTION_IF_NULL(zero_tensor); - val = static_cast(zero_tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = 0; - inputs->push_back(zero_tensor); - - tensor::TensorPtr one_tensor = std::make_shared(kInt32->type_id(), shp); - MS_EXCEPTION_IF_NULL(one_tensor); - val = static_cast(one_tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = 1; - inputs->push_back(one_tensor); - MS_LOG(INFO) << "---------------- LoadSwitchInputs End--"; -} - -void KernelAdjust::Profiling(NotNull kernel_graph_ptr) { - if (!ascend::ProfilingManager::GetInstance().IsProfiling()) { - MS_LOG(INFO) << "No need to profiling"; - return; - } - ProfilingTraceInfo profiling_trace_info = ProfilingUtils::GetProfilingTraceFromEnv(kernel_graph_ptr); - if (!profiling_trace_info.IsValid()) { - MS_LOG(WARNING) << "[profiling] no profiling node found!"; - return; - } - InsertProfilingKernel(profiling_trace_info, kernel_graph_ptr); -} - -void KernelAdjust::InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, - NotNull kernel_graph_ptr) { - MS_LOG(INFO) << "[profiling] Insert profiling kernel start"; - if (!profiling_trace_info.IsValid()) { - MS_LOG(WARNING) << "Profiling trace point not found"; - return; - } - std::vector new_cnode_list; - std::vector cnode_ptr_list = kernel_graph_ptr->execution_order(); - if (cnode_ptr_list.empty()) { - MS_LOG(ERROR) << "No CNode in graph"; - return; - } - for (const auto &cnode_ptr : cnode_ptr_list) { - ProfilingUtils::ProfilingTraceFpStart(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); - new_cnode_list.emplace_back(cnode_ptr); - ProfilingUtils::ProfilingCustomOp(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); - ProfilingUtils::ProfilingTraceBpEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); - ProfilingUtils::ProfilingTraceEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); - } - kernel_graph_ptr->set_execution_order(new_cnode_list); -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h deleted file mode 100644 index bf3ba2acb2..0000000000 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ - -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "session/kernel_graph.h" -#include "kernel/kernel_build_info.h" -#include "session/session_context.h" -#include "ir/tensor.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "device/kernel_info.h" - -using mindspore::device::ascend::ProfilingTraceInfo; -using mindspore::device::ascend::ProfilingUtils; -namespace mindspore { -constexpr auto kLoopCountParamName = "loop_count"; -constexpr auto kIterLoopParamName = "iter_loop"; -constexpr auto kZeroParamName = "zero"; -constexpr auto kOneParamName = "one"; -constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; -constexpr uint32_t kSecondStreamSwitchLabel = 2; - -namespace device { -class KernelAdjust { - public: - static KernelAdjust &GetInstance() { - static KernelAdjust instance; - return instance; - } - - void InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr); - bool StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr); - void Profiling(NotNull kernel_graph_ptr); - static bool NeedInsertSwitch(); - CNodePtr CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr); - - private: - KernelAdjust() = default; - ~KernelAdjust() = default; - - void ReorderGetNext(const std::shared_ptr &kernel_graph_ptr); - CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); - CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); - void CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, - std::map *switch_loop_input); - CNodePtr CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, - const std::map &switch_loop_input); - CNodePtr CreatTupleGetItemNode(const std::shared_ptr &kernel_graph_ptr, const CNodePtr &node, - size_t output_idx); - CNodePtr CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, - const CNodePtr &getnext_cnode); - CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr &kernel_graph_ptr, - const std::map &switch_loop_input); - kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector &formats, - const std::vector &type_ids); - void LoadSwitchInputs(std::vector *inputs); - void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, - NotNull kernel_graph_ptr); -}; -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ diff --git a/mindspore/ccsrc/device/kernel_info.cc b/mindspore/ccsrc/device/kernel_info.cc deleted file mode 100644 index 59c9b0f411..0000000000 --- a/mindspore/ccsrc/device/kernel_info.cc +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/kernel_info.h" - -namespace mindspore { -namespace device { -const kernel::KernelBuildInfo *KernelInfo::select_kernel_build_info() const { return select_kernel_build_info_.get(); } - -kernel::KernelBuildInfoPtr KernelInfo::GetMutableSelectKernelBuildInfo() const { return select_kernel_build_info_; } - -const DeviceAddress *KernelInfo::GetOutputAddr(size_t index) const { - if (index >= output_address_list_.size()) { - MS_LOG(ERROR) << "Index [" << index << "] out of range"; - return nullptr; - } - return output_address_list_[index].get(); -} - -DeviceAddressPtr KernelInfo::GetMutableOutputAddr(size_t index) const { - if (index >= output_address_list_.size()) { - MS_LOG(ERROR) << "Index [" << index << "] out of range"; - return nullptr; - } - return output_address_list_[index]; -} - -bool KernelInfo::OutputAddrExist(size_t index) const { - if (index >= output_address_list_.size()) { - return false; - } - return output_address_list_[index] != nullptr; -} - -bool KernelInfo::SetOutputAddr(const DeviceAddressPtr &output_address, size_t index) { - // parameter and valuenode - if (kernel_mod_ == nullptr && index >= output_address_list_.size()) { - for (size_t i = output_address_list_.size(); i <= index; i++) { - output_address_list_.emplace_back(nullptr); - } - } else if (output_address_list_.empty()) { - // set cnode - for (size_t i = 0; i < kernel_mod_->GetOutputSizeList().size(); i++) { - output_address_list_.emplace_back(nullptr); - } - } - if (index >= output_address_list_.size()) { - MS_LOG(ERROR) << "Index [" << index << "] out of range"; - return false; - } - output_address_list_[index] = output_address; - return true; -} - -DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const { - if (index >= workspace_address_list_.size()) { - MS_LOG(ERROR) << "Index [" << index << "] out of range"; - return nullptr; - } - return workspace_address_list_[index].get(); -} - -bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { - if (workspace_address_list_.empty()) { - // parameter and valuenode - if (kernel_mod_ == nullptr) { - workspace_address_list_.emplace_back(nullptr); - } else { - // set cnode - for (size_t i = 0; i < kernel_mod_->GetWorkspaceSizeList().size(); i++) { - workspace_address_list_.emplace_back(nullptr); - } - } - } - if (index >= workspace_address_list_.size()) { - MS_LOG(ERROR) << "Index" << index << " out of range"; - return false; - } - workspace_address_list_[index] = output_address; - return true; -} - -void KernelInfo::set_kernel_mod(const kernel::KernelModPtr &kernel_mod) { kernel_mod_ = kernel_mod; } - -kernel::KernelMod *KernelInfo::MutableKernelMod() const { return kernel_mod_.get(); } - -const kernel::KernelMod *KernelInfo::kernel_mod() const { return kernel_mod_.get(); } - -bool KernelInfo::operator==(const KernelInfo &other) const { - if (stream_id_ != other.stream_id_ || stream_distinction_label_ != other.stream_distinction_label_ || - graph_id_ != other.graph_id_) { - return false; - } - if ((select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ == nullptr) || - (select_kernel_build_info_ == nullptr && other.select_kernel_build_info_ != nullptr)) { - return false; - } - if (select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ != nullptr) { - if (!(*select_kernel_build_info_ == *(other.select_kernel_build_info_))) { - return false; - } - } - // Currently we only check whether both the kernel_mod_ are initialized or uninitialized. - if ((kernel_mod_ == nullptr && other.kernel_mod_ != nullptr) || - (kernel_mod_ != nullptr && other.kernel_mod_ == nullptr)) { - return false; - } - // Currently we only check whether both the sizes are equal of output_address_list_ and workspace_address_list_ or - // not. We can complete this check in the future. - if (output_address_list_.size() != other.output_address_list_.size() || - workspace_address_list_.size() != other.workspace_address_list_.size()) { - return false; - } - return true; -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_info.h b/mindspore/ccsrc/device/kernel_info.h deleted file mode 100644 index 84cfaa0fa3..0000000000 --- a/mindspore/ccsrc/device/kernel_info.h +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_DEVICE_KERNEL_INFO_H_ -#define MINDSPORE_DEVICE_KERNEL_INFO_H_ - -#include -#include -#include "kernel/kernel_build_info.h" -#include "device/ascend/ascend_device_address.h" -#include "kernel/kernel.h" - -namespace mindspore { -const uint32_t kInvalidGraphId = UINT32_MAX; -const uint32_t kInvalidDistincLabel = UINT32_MAX; -namespace device { -class KernelInfo { - public: - KernelInfo() { - kernel_mod_ = nullptr; - is_feature_map_ = false; - select_kernel_build_info_ = nullptr; - output_address_list_ = {}; - workspace_address_list_ = {}; - stream_id_ = UINT32_MAX; - stream_distinction_label_ = kInvalidDistincLabel; - graph_id_ = kInvalidGraphId; - } - virtual ~KernelInfo() = default; - - const kernel::KernelBuildInfo *select_kernel_build_info() const; - kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const; - void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { - select_kernel_build_info_ = select_kernel_build_info; - } - void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; } - const DeviceAddress *GetOutputAddr(size_t index) const; - DeviceAddressPtr GetMutableOutputAddr(size_t index) const; - bool OutputAddrExist(size_t index) const; - bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); - DeviceAddress *GetWorkspaceAddr(size_t index) const; - bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); - void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); - kernel::KernelMod *MutableKernelMod() const; - const kernel::KernelMod *kernel_mod() const; - uint32_t stream_id() const { return stream_id_; } - void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; } - uint32_t stream_distinction_label() const { return stream_distinction_label_; } - void set_stream_distinction_label(uint32_t stream_distinction_label) { - stream_distinction_label_ = stream_distinction_label; - } - void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } - uint32_t graph_id() const { return graph_id_; } - bool operator==(const KernelInfo &other) const; - bool is_feature_map() const { return is_feature_map_; } - - private: - bool is_feature_map_; - kernel::KernelBuildInfoPtr select_kernel_build_info_; - std::vector> output_address_list_; - std::vector> workspace_address_list_; - kernel::KernelModPtr kernel_mod_; - // stream_id_ is the index of stream object vector - uint32_t stream_id_; - // stream_distinction_label_ is used mark different op in different stream - uint32_t stream_distinction_label_; - // record which graph the node belong to - uint32_t graph_id_; -}; -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_DEVICE_KERNEL_INFO_H_ diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc deleted file mode 100644 index 27cf1dfc92..0000000000 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ /dev/null @@ -1,762 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/kernel_runtime.h" -#include -#include -#include -#include -#include "common/utils.h" -#include "common/trans.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" -#include "pipeline/parse/python_adapter.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/common_utils.h" -#include "kernel/oplib/oplib.h" -#include "ir/value.h" -using mindspore::kernel::Address; -using mindspore::kernel::AddressPtr; - -namespace mindspore { -namespace device { -KernelRuntime::~KernelRuntime() { -#ifdef ENABLE_DUMP_E2E - dump_conf_ptr_ = nullptr; -#endif -} - -bool KernelRuntime::Run(session::KernelGraph *graph) { - bool ret = false; - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif - bool is_task_sink = context_ptr->enable_task_sink(); - if (is_task_sink) { - ret = RunTask(graph); - } else { - ret = LaunchKernel(graph); - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Call MS Run Success in " << cost << " us"; -#endif - return ret; -} - -// for D to impl -bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { - if (graph != nullptr) { - return true; - } - return false; -} - -// for D to impl -bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { - if (graph != nullptr) { - return true; - } - return false; -} - -// for D to impl -bool KernelRuntime::GenTask(const session::KernelGraph *graph) { - if (graph != nullptr) { - return true; - } - return false; -} - -bool KernelRuntime::LoadTask(const session::KernelGraph *graph) { - if (graph != nullptr) { - return true; - } - return false; -} - -// for D to impl -bool KernelRuntime::RunTask(const session::KernelGraph *graph) { - if (graph != nullptr) { - return true; - } - return false; -} - -bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { - MS_EXCEPTION_IF_NULL(kernel); - if (AnfAlgo::OutputAddrExist(kernel, index)) { - return true; - } - return false; -} - -size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { - MS_EXCEPTION_IF_NULL(node); - if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { - MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" - << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; - } - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); - if (output_type_id == kTypeUnknown) { - output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); - } - size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); - std::vector shape = AnfAlgo::GetOutputDeviceShape(node, output_index); - auto format = AnfAlgo::GetOutputFormat(node, output_index); - if (shape.empty() && format != kOpFormat_DEFAULT) { - shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index)); - shape = trans::TransShapeToDevice(shape, format); - } - // scalar's output shape is a empty vector - size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - return tensor_size; -} - -void KernelRuntime::AssignMemory(session::KernelGraph *graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(mem_manager_); - mem_manager_->ResetDynamicMemory(); - AssignStaticMemory(graph); - AssignDynamicMemory(graph); - UpdateRefNodeOutputMem(graph); -} - -void KernelRuntime::RunOpAssignMemory(const std::vector &input_tensors, - session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - RunOpAssignInputMemory(input_tensors, graph); - AssignStaticMemoryValueNode(graph); - for (const auto &cnode : graph->execution_order()) { - RunOpAssignOutputMemory(cnode); - RunOpAssignWorkSpaceMemory(cnode); - } - UpdateRefNodeOutputMem(graph); -} - -void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - // clear input parameter memory resource - for (const auto &input_node : graph->inputs()) { - MS_EXCEPTION_IF_NULL(input_node); - AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get()); - } - // clear input value node memory resource - for (const auto &value_node : graph->graph_value_nodes()) { - MS_EXCEPTION_IF_NULL(value_node); - AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get()); - } - for (const auto &cnode : graph->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode); - // clear output memory resource - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { - AnfAlgo::SetOutputAddr(nullptr, index, cnode.get()); - } - // clear workspace memory resource - auto kernel_mod = AnfAlgo::GetKernelMod(cnode); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto workspace_lists = kernel_mod->GetWorkspaceSizeList(); - for (size_t index = 0; index < workspace_lists.size(); ++index) { - AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get()); - } - } -} - -void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) { - AssignStaticMemoryInput(graph); - AssignStaticMemoryValueNode(graph); - AssignStaticMemoryOutput(graph); -} - -void KernelRuntime::RunOpAssignInputMemory(const std::vector &input_tensors, - const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mem_manager_); - if (input_tensors.size() != graph->inputs().size()) { - MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() - << " should be equal to graph input parameter size " << graph->inputs().size(); - } - - for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) { - auto item = graph->inputs()[input_index]; - MS_EXCEPTION_IF_NULL(item); - if (!item->isa()) { - continue; - } - auto output_size = AnfAlgo::GetOutputTensorNum(item); - for (size_t index = 0; index < output_size; index++) { - MS_EXCEPTION_IF_NULL(input_tensors[input_index]); - if (input_tensors[input_index]->device_address().get() != nullptr) { - AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get()); - continue; - } - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); - if (output_type_id == kTypeUnknown) { - output_type_id = AnfAlgo::GetOutputInferDataType(item, index); - } - auto tensor_size = CountNodeDeviceMemorySize(item, index); - auto device_address = - CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } - AnfAlgo::SetOutputAddr(device_address, index, item.get()); - } - } -} - -void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - return; - } - - for (size_t i = 0; i < output_sizes.size(); ++i) { - if (AnfAlgo::OutputAddrExist(kernel, i)) { - continue; - } - if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); - continue; - } - std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); - auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); - auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); - device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i)); - MS_EXCEPTION_IF_NULL(device_address); - auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } - AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); - } -} - -void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - if (kernel->isa()) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto workspace_lists = kernel_mod->GetWorkspaceSizeList(); - for (size_t i = 0; i < workspace_lists.size(); ++i) { - auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown); - MS_EXCEPTION_IF_NULL(device_address); - auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } - AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); - } - } -} - -void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto graph_inputs = graph->inputs(); - auto graph_valid_input = graph->valid_inputs(); - std::vector need_alloc_nodes; - for (size_t i = 0; i < graph_inputs.size(); ++i) { - auto item = graph_inputs[i]; - MS_EXCEPTION_IF_NULL(item); - if (i < graph_valid_input.size() && !graph_valid_input[i]) { - continue; - } - - if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { - auto outs = AnfAlgo::GetAllOutput(item); - for (auto &out : outs) { - MS_EXCEPTION_IF_NULL(out); - if (!out->isa()) { - continue; - } - if (NodeOutputDeviceAddressExist(out, 0)) { - continue; - } - need_alloc_nodes.push_back(out); - } - } - if (!item->isa()) { - continue; - } - if (NodeOutputDeviceAddressExist(item, 0)) { - continue; - } - need_alloc_nodes.push_back(item); - } - - for (auto &item : need_alloc_nodes) { - auto output_size = AnfAlgo::GetOutputTensorNum(item); - for (size_t index = 0; index < output_size; index++) { - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); - // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown - if (output_type_id == kTypeUnknown) { - MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; - output_type_id = AnfAlgo::GetOutputInferDataType(item, index); - } - auto tensor_size = CountNodeDeviceMemorySize(item, index); - auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); - auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); - AnfAlgo::SetOutputAddr(address, index, item.get()); - } - } -} - -void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); - std::vector non_communication_op; - // Assign Communicate Op Memory firstly. - for (const auto &node : nodes) { - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - MS_EXCEPTION_IF_NULL(item_with_index.first); - if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { - continue; - } - graph->AddFinalOutputKernel(item_with_index.first); - if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { - AssignCommunicationNodeMem(kStaticMem, item_with_index.first); - } else { - non_communication_op.emplace_back(item_with_index); - } - } - - for (const auto &item_with_index : non_communication_op) { - AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second)); - } -} - -void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto &kernels = graph->execution_order(); - for (auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - MS_LOG(INFO) << "This kernel has no output size."; - continue; - } - for (size_t i = 0; i < output_sizes.size(); ++i) { - session::AnfWithOutIndex out_pair(kernel, i); - if (graph->IsInRefOutputMap(out_pair)) { - auto origin_pair = graph->GetRefCorrespondOutput(out_pair); - MS_EXCEPTION_IF_NULL(origin_pair.first); - auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second); - MS_EXCEPTION_IF_NULL(origin_node_output_addr); - auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i); - if (origin_node_output_addr.get() != cur_node_output_addr.get()) { - MS_LOG(INFO) << "REF address is not same, ref node output need address update"; - MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is " - << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i; - AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get()); - } - } - } - } -} - -void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) { - AssignCommunicationNodeInputMem(node); - AssignCommunicationNodeOutputMem(flag, node); -} - -void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto kernel_mod = AnfAlgo::GetKernelMod(node); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size."; - return; - } - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - size_t total_size = 0; - size_t output_index = 0; - std::vector align_size_list; - for (uint64_t mem_size : output_sizes) { - if (AnfAlgo::OutputAddrExist(node, output_index++)) { - MS_LOG(INFO) << "communication op addr exist"; - continue; - } - if (context_ptr->enable_hccl()) { - mem_size = mem_manager_->GetCommonAlignSize(mem_size); - } - total_size += mem_size; - align_size_list.emplace_back(mem_size); - } - uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size); - for (size_t j = 0; j < align_size_list.size(); ++j) { - std::string output_format = AnfAlgo::GetOutputFormat(node, j); - auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j); - auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type); - AnfAlgo::SetOutputAddr(address, j, node.get()); - output_ptr += align_size_list[j]; - } -} - -DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) { - MS_EXCEPTION_IF_NULL(anf_node); - auto kernel_mod = AnfAlgo::GetKernelMod(anf_node); - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.size() <= index) { - MS_LOG(EXCEPTION) << "Previous node output size < node index"; - } - std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index); - auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index); - auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type); - AnfAlgo::SetOutputAddr(address, index, anf_node.get()); - return address; -} - -void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(mem_manager_); - size_t total_size = 0; - std::vector> addr_size; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); - auto input_node = input_node_with_index.first; - DeviceAddressPtr address = nullptr; - if (input_node->isa()) { - address = PreAssignCNodeMemory(input_node, input_node_with_index.second); - } else { - MS_LOG(EXCEPTION) << "Communication node inputs only support CNode"; - } - MS_EXCEPTION_IF_NULL(address); - auto mem_size = mem_manager_->GetCommonAlignSize(address->size()); - total_size += mem_size; - addr_size.emplace_back(address.get(), mem_size); - } - uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size); - for (const auto &iter : addr_size) { - MS_EXCEPTION_IF_NULL(iter.first); - iter.first->set_ptr(input_ptr); - input_ptr += iter.second; - } -} - -void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(mem_manager_); - if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { - MS_LOG(INFO) << "GetNext disable mem_reuse"; - flag = kDynamicMem; - } - auto kernel_mod = AnfAlgo::GetKernelMod(node); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size."; - return; - } - for (size_t i = 0; i < output_sizes.size(); ++i) { - if ((kGetAllOuts != index) && (SizeToInt(i) != index)) { - continue; - } - if (NodeOutputDeviceAddressExist(node, i)) { - MS_LOG(INFO) << "Already malloc index:" << i; - continue; - } - auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]); - if (ptr == nullptr) { - // reused ptr, no need alloc, continue; - continue; - } - std::string output_format = AnfAlgo::GetOutputFormat(node, i); - auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); - auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type); - device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i)); - AnfAlgo::SetOutputAddr(device_address, i, node.get()); - } -} - -void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, - size_t output_idx) { - MS_EXCEPTION_IF_NULL(value_node); - MS_EXCEPTION_IF_NULL(node_value); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - auto tensor = node_value->cast(); - if (tensor == nullptr) { - MS_LOG(WARNING) << "Tensor is null"; - return; - } - size_t tensor_size = tensor->data().nbytes(); - auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); - if (output_type_id == kTypeUnknown) { - output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); - } - auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); - DeviceAddressPtr address = nullptr; - if (ms_context->enable_pynative_infer()) { - address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); - MS_EXCEPTION_IF_NULL(address); - if (!mem_manager_->MallocMemFromMemPool(address, node_size)) { - MS_LOG(EXCEPTION) << "Malloc value node device memory failed !"; - } - } else { - auto ptr = mem_manager_->MallocMem(kStaticMem, node_size); - address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id); - MS_EXCEPTION_IF_NULL(address); - } - AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); - if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), - tensor->data_c())) { - MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is" - << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is " - << AnfAlgo::GetOutputInferDataType(value_node, output_idx); - } -} - -void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - for (auto &value_node : graph->graph_value_nodes()) { - MS_EXCEPTION_IF_NULL(value_node); - if (NodeOutputDeviceAddressExist(value_node, 0)) { - MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist"; - continue; - } - auto &node_value = value_node->value(); - MS_EXCEPTION_IF_NULL(node_value); - if (node_value->isa()) { - AssignValueNodeTensor(value_node, node_value, 0); - } else if (node_value->isa()) { - auto value = GetValue(node_value); - size_t tensor_size = value.size(); - DeviceAddressPtr address = nullptr; - if (ms_context->enable_pynative_infer()) { - address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); - MS_EXCEPTION_IF_NULL(address); - if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) { - MS_LOG(EXCEPTION) << "Malloc value node device memory failed !"; - } - } else { - auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); - address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); - MS_EXCEPTION_IF_NULL(address); - } - AnfAlgo::SetOutputAddr(address, 0, value_node.get()); - std::vector shape = {1, SizeToInt(tensor_size)}; - if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) { - MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!"; - } - } - } -} - -void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); - auto mem_flag = kDynamicMem; - if (is_enable_mem_reuse) { - mem_manager_->MallocReusedDynamicMem(graph); - mem_flag = kReuseDynamicMem; - } - auto &execution_nodes = graph->execution_order(); - std::vector compute_nodes; - // communication nodes first - for (auto &node : execution_nodes) { - if (AnfAlgo::IsCommunicationOp(node)) { - // skip if the memory is already alocated - AssignCommunicationNodeMem(mem_flag, node); - } else { - compute_nodes.emplace_back(node); - } - } - - // then compute nodes - for (auto &node : compute_nodes) { - AssignNodeOutputMem(mem_flag, node, kGetAllOuts); - AssignWorkSpaceMem(mem_flag, node); - } -} - -void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto kernel_mod = AnfAlgo::GetKernelMod(node); - MS_EXCEPTION_IF_NULL(kernel_mod); - size_t index = 0; - for (auto &size : kernel_mod->GetWorkspaceSizeList()) { - auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size); - AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); - index++; - } -} - -void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, - AddressPtrList *kernel_outputs) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_inputs); - MS_EXCEPTION_IF_NULL(kernel_workspaces); - MS_EXCEPTION_IF_NULL(kernel_outputs); - auto cnode = kernel->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { - return GenAddrCleanLaunchArgs(cnode, kernel_inputs); - } - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); - MS_EXCEPTION_IF_NULL(device_address); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(input->addr); - input->size = device_address->size_; - kernel_inputs->emplace_back(input); - } - - for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetOutputAddr(kernel, i); - kernel::AddressPtr output = std::make_shared(); - MS_EXCEPTION_IF_NULL(output); - output->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(output->addr); - output->size = device_address->size_; - kernel_outputs->emplace_back(output); - } - - for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(workspace->addr); - workspace->size = device_address->size_; - kernel_workspaces->emplace_back(workspace); - } -} - -void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) { - if (cnode->inputs().size() != 2) { - MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2."; - } - MS_EXCEPTION_IF_NULL(cnode->inputs()[1]); - auto pre_node = (cnode->inputs()[1])->cast(); - // set clean output address - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { - auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); - for (auto index : clean_output_indexs) { - auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(input->addr); - input->size = device_address->size_; - kernel_inputs->emplace_back(input); - } - MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); - } - // set clean workspace address - if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { - auto clean_workspaces_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); - for (const auto &index : clean_workspaces_indexs) { - auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(workspace->addr); - workspace->size = device_address->size_; - kernel_inputs->emplace_back(workspace); - } - } -} - -bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { - auto &kernels = graph.execution_order(); - for (const auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); - auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); - if (!ret) { - MS_LOG(ERROR) << "Launch kernel failed."; - return false; - } - } - return true; -} - -bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - if (!LaunchKernelMod(*graph)) { - MS_LOG(ERROR) << "LaunchKernelMod failed!"; - return false; - } - return true; -} - -void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { - MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; -} - -#ifdef ENABLE_DUMP_E2E -bool KernelRuntime::SetDumpConf() { - dump_conf_ptr_ = std::make_shared(); - MS_EXCEPTION_IF_NULL(dump_conf_ptr_); - bool ret = dump_conf_ptr_->SetDumpConfFromJsonFile(); - return ret; -} - -DumpConfPtr KernelRuntime::GetDumpConf() { return dump_conf_ptr_; } -#endif -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h deleted file mode 100644 index 8c6a5eb19b..0000000000 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ -#include -#include -#include -#include - -#include "device/device_address.h" -#include "ir/tensor.h" -#include "predict/generator/utils/ir_model_util.h" -#ifdef ENABLE_DUMP_E2E -#include "debug/e2e_dump.h" -#endif -#ifdef ENABLE_DEBUGGER -#include "debug/debugger/debugger.h" -#endif -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel.h" -#include "utils/context/ms_context.h" -#include "device/memory_manager.h" - -using mindspore::tensor::Tensor; -using std::vector; -using TensorPtr = std::shared_ptr; -using mindspore::kernel::AddressPtr; -using AddressPtrList = std::vector; - -namespace mindspore { -#ifndef ENABLE_DEBUGGER -class Debugger; -#endif -namespace device { -class KernelRuntime { - public: - KernelRuntime() = default; - virtual ~KernelRuntime(); - virtual bool Init() = 0; - virtual void AssignMemory(session::KernelGraph *graph); - void RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph); - void RunOpClearMemory(const session::KernelGraph *graph); - virtual bool Run(session::KernelGraph *graph); - virtual bool DumpData(session::KernelGraph *graph); - virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger); - virtual bool RunTask(const session::KernelGraph *graph); - virtual bool GenTask(const session::KernelGraph *graph); - bool LaunchKernel(const session::KernelGraph *graph); - virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); - virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); - virtual void ClearGraphRuntimeResource(uint32_t graph_id); - virtual bool SyncStream() = 0; - -#ifdef ENABLE_DUMP_E2E - DumpConfPtr GetDumpConf(); -#endif - virtual bool LoadTask(const session::KernelGraph *graph); - // for GPU and D to impl - virtual void ReleaseDeviceRes() {} - void set_device_id(uint32_t device_id) { device_id_ = device_id; } - - protected: - virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) = 0; - virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index); - void AssignStaticMemory(session::KernelGraph *graph); - void AssignDynamicMemory(session::KernelGraph *graph); - void ReuseAssignDynamicMemory(session::KernelGraph *graph); - void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); - void AssignWorkSpaceMem(int flag, const AnfNodePtr &node); - void AssignReuseWorkSpaceMem(const AnfNodePtr &node); - - void UpdateRefNodeOutputMem(const session::KernelGraph *graph); - - void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node); - void AssignCommunicationNodeInputMem(const AnfNodePtr &node); - void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node); -#ifdef ENABLE_DUMP_E2E - bool SetDumpConf(); -#endif - - private: - void AssignStaticMemoryOutput(session::KernelGraph *graph); - void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); - bool LaunchKernelMod(const session::KernelGraph &graph); - void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); - size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); - void RunOpAssignInputMemory(const std::vector &input_tensors, const session::KernelGraph *graph); - void RunOpAssignOutputMemory(const AnfNodePtr &kernel); - void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); - void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); - DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); - - protected: - uint32_t device_id_{0}; -#ifdef ENABLE_DUMP_E2E - DumpConfPtr dump_conf_ptr_; -#endif - void *stream_ = nullptr; - std::shared_ptr mem_manager_{nullptr}; -}; -using KernelRuntimePtr = std::shared_ptr; -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/device/kernel_runtime_manager.cc b/mindspore/ccsrc/device/kernel_runtime_manager.cc deleted file mode 100644 index 29d74762b4..0000000000 --- a/mindspore/ccsrc/device/kernel_runtime_manager.cc +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/kernel_runtime_manager.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -void KernelRuntimeManager::ClearRuntimeResource() { - std::lock_guard guard(lock_); - for (auto &iter : runtime_map_) { - MS_LOG(INFO) << "Release device " << iter.first; - MS_EXCEPTION_IF_NULL(iter.second); - iter.second->ReleaseDeviceRes(); - } - runtime_map_.clear(); -} - -void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) { - std::lock_guard guard(lock_); - for (auto &iter : runtime_map_) { - MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource"; - if (!iter.second) { - MS_LOG(ERROR) << "Kernel runtime is nullptr"; - continue; - } - iter.second->ClearGraphRuntimeResource(graph_id); - } -} - -void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { - if (runtime_creators_.find(device_name) == runtime_creators_.end()) { - (void)runtime_creators_.emplace(device_name, runtime_creator); - } -} - -KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) { - std::string runtime_key = device_name + "_" + std::to_string(device_id); - auto runtime_iter = runtime_map_.find(runtime_key); - if (runtime_iter != runtime_map_.end()) { - return runtime_iter->second.get(); - } else if (runtime_map_.size() > 0) { - auto cur_runtime_key = runtime_map_.begin()->first; - auto find_pos = cur_runtime_key.rfind('_'); - if (find_pos != std::string::npos) { - if (cur_runtime_key.size() > find_pos + 1) { - auto cur_device_id = cur_runtime_key.substr(find_pos + 1); - MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id - << ", set device id: " << device_id << " failed"; - } else { - MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: " - << device_id << " failed"; - } - } - } - return GetKernelRuntime(device_name, device_id); -} - -KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) { - std::lock_guard guard(lock_); - std::string runtime_key = device_name + "_" + std::to_string(device_id); - auto runtime_iter = runtime_map_.find(runtime_key); - if (runtime_iter != runtime_map_.end()) { - return runtime_iter->second.get(); - } - std::shared_ptr kernel_runtime; - auto creator_iter = runtime_creators_.find(device_name); - if (creator_iter != runtime_creators_.end()) { - MS_EXCEPTION_IF_NULL(creator_iter->second); - kernel_runtime = (creator_iter->second)(); - kernel_runtime->set_device_id(device_id); - MS_EXCEPTION_IF_NULL(kernel_runtime); - runtime_map_[runtime_key] = kernel_runtime; - } else { - MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id; - } - - return kernel_runtime.get(); -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_runtime_manager.h b/mindspore/ccsrc/device/kernel_runtime_manager.h deleted file mode 100644 index 89b45ff5f8..0000000000 --- a/mindspore/ccsrc/device/kernel_runtime_manager.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "device/kernel_runtime.h" -namespace mindspore { -namespace device { -using KernelRuntimeCreator = std::function()>; - -class KernelRuntimeManager { - public: - static KernelRuntimeManager &Instance() { - static KernelRuntimeManager instance; - return instance; - } - void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator); - KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id); - KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id); - void ClearRuntimeResource(); - void ClearGraphResource(uint32_t graph_id); - - private: - KernelRuntimeManager() = default; - ~KernelRuntimeManager() = default; - DISABLE_COPY_AND_ASSIGN(KernelRuntimeManager); - std::map > runtime_map_; - std::map runtime_creators_; - std::mutex lock_; -}; - -class KernelRuntimeRegistrar { - public: - KernelRuntimeRegistrar(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { - KernelRuntimeManager::Instance().Register(device_name, std::move(runtime_creator)); - } - ~KernelRuntimeRegistrar() = default; -}; - -#define MS_REG_KERNEL_RUNTIME(DEVICE_NAME, RUNTIME_CLASS) \ - static const KernelRuntimeRegistrar g_kernel_runtime_##DEVICE_NAME##_reg( \ - DEVICE_NAME, []() { return std::make_shared(); }); -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ diff --git a/mindspore/ccsrc/device/memory_manager.cc b/mindspore/ccsrc/device/memory_manager.cc deleted file mode 100644 index 5efbcd8a36..0000000000 --- a/mindspore/ccsrc/device/memory_manager.cc +++ /dev/null @@ -1,203 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/memory_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/context/ms_context.h" -using mindspore::memreuse::BestFitMemReuse; -using mindspore::memreuse::MemReuseUtilPtr; -namespace mindspore { -namespace device { -size_t MemoryManager::GetCommonAlignSize(size_t input_size) const { - return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; -} - -size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const { - return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize; -} - -void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - // set all infos - mem_reuse_util_ptr->SetAllInfo(graph); - auto bestfit_mem_reuse = std::make_shared(); - MS_EXCEPTION_IF_NULL(bestfit_mem_reuse); - bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get()); - size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize(); - MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]"; - mem_reuse_util_ptr_ = mem_reuse_util_ptr; - auto base_ptr = MallocDynamicMem(total_allocated_size, false); - mem_reuse_util_ptr_->set_mem_base(base_ptr); -} - -uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { - MS_EXCEPTION_IF_NULL(node); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - uint8_t *ptr = nullptr; - if (AnfAlgo::IsCommunicationOp(node)) { - bool communication_mem = false; - if (context_ptr->enable_hccl()) { - communication_mem = true; - } - if (flag == kStaticMem) { - ptr = MallocStaticMem(size, communication_mem); - } else { - ptr = MallocDynamicMem(size, communication_mem); - } - return ptr; - } - - if (flag == kStaticMem) { - ptr = MallocStaticMem(size, false); - } else if (flag == kDynamicMem) { - ptr = MallocDynamicMem(size, false); - } else if (flag == kReuseDynamicMem) { - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); - ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); - } - return ptr; -} - -uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { - if (flag == kReuseDynamicMem) { - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); - return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index); - } - return MallocDynamicMem(size, false); -} - -uint8_t *MemoryManager::MallocMem(int flag, size_t size) { - uint8_t *ptr = nullptr; - if (flag == kStaticMem) { - ptr = MallocStaticMem(size, false); - } else if (flag == kDynamicMem) { - ptr = MallocDynamicMem(size, false); - } - return ptr; -} - -uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem) { - size_t align_size = 0; - if (communication_mem) { - align_size = GetCommunicationAlignSize(size); - } else { - align_size = GetCommonAlignSize(size); - } - if (static_mem_offset_ < align_size) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - total_static_size_ += align_size; - auto offset = static_mem_offset_ - align_size; - if (dynamic_mem_offset_ > offset) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - static_mem_offset_ = offset; - if (communication_mem) { - return device_mem_base_ + offset + kMemAlignSize; - } else { - return device_mem_base_ + offset; - } -} - -uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { - size_t align_size = 0; - if (communication_mem) { - align_size = GetCommunicationAlignSize(size); - } else { - align_size = GetCommonAlignSize(size); - } - uint64_t offset = dynamic_mem_offset_; - auto new_offset = dynamic_mem_offset_ + align_size; - if (new_offset > static_mem_offset_) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - total_dynamic_size_ += align_size; - dynamic_mem_offset_ = new_offset; - - if (communication_mem) { - return device_mem_base_ + offset + kMemAlignSize; - } else { - return device_mem_base_ + offset; - } -} - -bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) { - auto device_ptr = MallocMemFromMemPool(size); - if (!device_ptr) { - return false; - } - address->ptr_ = device_ptr; - address->from_mem_pool_ = true; - return true; -} - -void *MemoryManager::MallocMemFromMemPool(size_t size) { - if (size == 0) { - MS_LOG(ERROR) << "MallocMemFromMemPool size is 0."; - } - return nullptr; -} - -void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) { - MS_EXCEPTION_IF_NULL(address); - MS_EXCEPTION_IF_NULL(address->ptr_); - FreeMemFromMemPool(address->ptr_); - address->ptr_ = nullptr; -} - -void MemoryManager::FreeMemFromMemPool(void *device_ptr) { - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; - } -} - -bool MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, - std::vector size_list) { - auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); - if (device_ptr_list.size() == 0) { - return false; - } - if (addr_list.size() != device_ptr_list.size()) { - MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; - } - for (size_t i = 0; i < addr_list.size(); i++) { - MS_EXCEPTION_IF_NULL(device_ptr_list[i]); - MS_EXCEPTION_IF_NULL(addr_list[i]); - addr_list[i]->ptr_ = device_ptr_list[i]; - addr_list[i]->from_mem_pool_ = true; - } - return true; -} - -std::vector MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { - if (total_size == 0) { - MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0."; - } - std::vector device_ptr_list; - device_ptr_list.emplace_back(nullptr); - return device_ptr_list; -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/memory_manager.h b/mindspore/ccsrc/device/memory_manager.h deleted file mode 100644 index be250e0f3f..0000000000 --- a/mindspore/ccsrc/device/memory_manager.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ -#include -#include -#include "pre_activate/mem_reuse/mem_reuse.h" -#include "pre_activate/mem_reuse/mem_reuse_allocator.h" -namespace mindspore { -namespace device { -const int kStaticMem = 0; -const int kDynamicMem = 1; -const int kReuseDynamicMem = 2; -const int kGetAllOuts = -1; -const uint64_t kMemAlignSize = 512; -using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr; - -class MemoryManager { - public: - MemoryManager() = default; - virtual ~MemoryManager() = default; - - virtual void MallocDeviceMemory() = 0; - virtual void FreeDeviceMemory() = 0; - void ResetDynamicMemory() { - total_dynamic_size_ = 0; - dynamic_mem_offset_ = 0; - } - - void MallocReusedDynamicMem(session::KernelGraph *graph); - uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size); - uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size); - virtual uint8_t *MallocMem(int flag, size_t size); - - virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); - virtual void *MallocMemFromMemPool(size_t size); - virtual void FreeMemFromMemPool(const DeviceAddressPtr address); - virtual void FreeMemFromMemPool(void *device_ptr); - virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, - std::vector size_list); - virtual std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); - - size_t GetCommonAlignSize(size_t input_size) const; - size_t GetCommunicationAlignSize(size_t input_size) const; - - protected: - virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem); - virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem); - uint8_t *device_mem_base_{nullptr}; - uint64_t device_mem_size_{0}; - uint64_t dynamic_mem_offset_{0}; - uint64_t static_mem_offset_{0}; - size_t total_static_size_ = 0; - size_t total_dynamic_size_ = 0; - MemReuseUtilPtr mem_reuse_util_ptr_{nullptr}; -}; -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/operator/CMakeLists.txt b/mindspore/ccsrc/frontend/operator/CMakeLists.txt new file mode 100644 index 0000000000..0b6dd77c69 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _OPERATOR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_OPERATOR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) +add_library(_mindspore_frontend_operator_obj OBJECT ${_OPERATOR_SRC_FILES}) diff --git a/mindspore/ccsrc/frontend/operator/cc_implementations.cc b/mindspore/ccsrc/frontend/operator/cc_implementations.cc new file mode 100644 index 0000000000..3ec3455be7 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/cc_implementations.cc @@ -0,0 +1,432 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/cc_implementations.h" +#include +#include +#include +#include +#include +#include "utils/misc.h" +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" +#include "common/utils.h" + +namespace mindspore { +// namespace to support primitive operators definition +namespace prim { +enum class DataType { kInt, kFloat, kDouble, kUnknown }; + +// Whether has a T type data in AnyPtrList. +template +bool HasType(const AnyPtrList &list) { + bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is(); }); + return ret; +} + +DataType InferType(const AnyPtrList &list) { + if (HasType(list)) { + return DataType::kDouble; + } else if (HasType(list)) { + return DataType::kFloat; + } else if (HasType(list)) { + return DataType::kInt; + } + return DataType::kUnknown; +} + +enum OpType { ADD, SUB, MUL, DIV, MOD }; + +template +bool IsSignedIntOverflow(T x, T y, OpType opType) { + auto max = std::numeric_limits::max(); + auto min = std::numeric_limits::min(); + + if (opType == OpType::ADD) { + return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x); + } + + if (opType == OpType::SUB) { + return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x); + } + + if (opType == OpType::MUL) { + return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || + (x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x); + } + + if (opType == OpType::DIV || opType == OpType::MOD) { + return x == min && static_cast(y) == -1; + } + + MS_LOG(EXCEPTION) << "Unsupported operation type."; +} + +template +T InnerScalarAdd(T x, T y) { + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::ADD)) { + MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + return x + y; +} + +template +T InnerScalarSub(T x, T y) { + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::SUB)) { + MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + return x - y; +} + +template +T InnerScalarMul(T x, T y) { + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::MUL)) { + MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + return x * y; +} + +template +float InnerScalarDiv(T x, T y) { + if (y == 0) { + MS_LOG(EXCEPTION) << "Divisor could not be zero"; + } + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::DIV)) { + MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + return static_cast(x) / static_cast(y); +} + +template +T InnerScalarFloordiv(T x, T y) { + auto ret = std::floor(InnerScalarDiv(x, y)); + if (std::is_integral::value) { + return static_cast(ret); + } + return ret; +} + +template +T InnerScalarMod(T x, T y) { + if (y == 0) { + MS_LOG(EXCEPTION) << "Could not mod to zero."; + } + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::MOD)) { + MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + if (std::is_integral::value) { + return static_cast(x) % static_cast(y); + } + int x_int = std::floor(x); + int y_int = std::ceil(y); + int max = x_int / y_int; + float ret = x - y * max; + return ret; +} + +template +T InnerScalarPow(T x, U y) { + return std::pow(x, y); +} + +template +bool InnerScalarEq(T x, U y) { + double error = static_cast(x) - static_cast(y); + error = fabs(error); + return error < DBL_EPSILON; +} + +template +bool InnerScalarLt(T x, U y) { + return x < y; +} + +template +bool InnerScalarGt(T x, U y) { + return x > y; +} + +template +bool InnerScalarNe(T x, U y) { + return !InnerScalarEq(x, y); +} + +template +bool InnerScalarLe(T x, U y) { + return x <= y; +} + +template +bool InnerScalarGe(T x, U y) { + return x >= y; +} + +#define SCALAR_OP(op_t) \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ + do { \ + if (list.size() < 2) { \ + MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ + } \ + ValuePtr x = list[0]; \ + ValuePtr y = list[1]; \ + MS_EXCEPTION_IF_NULL(x); \ + MS_EXCEPTION_IF_NULL(y); \ + if (x->isa() && y->isa()) { \ + double sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + float sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + int sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + float sum = InnerScalar##op_t(IntToFloat(GetValue(x)), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + float sum = InnerScalar##op_t(GetValue(x), IntToFloat(GetValue(y))); \ + return MakeValue(sum); \ + } \ + MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ + << ", y: " << y->ToString(); \ + } while (0); \ + } + +SCALAR_OP(Add) +SCALAR_OP(Sub) +SCALAR_OP(Mul) +SCALAR_OP(Div) +SCALAR_OP(Mod) +SCALAR_OP(Pow) +SCALAR_OP(Floordiv) + +#define LOGIC_OP(op_t) \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ + if (list.size() < 2) { \ + MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ + } \ + ValuePtr x = list[0]; \ + ValuePtr y = list[1]; \ + MS_EXCEPTION_IF_NULL(x); \ + MS_EXCEPTION_IF_NULL(y); \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ + << ", y: " << y->ToString() << "."; \ + } + +LOGIC_OP(Eq) +LOGIC_OP(Lt) +LOGIC_OP(Gt) +LOGIC_OP(Ne) +LOGIC_OP(Le) +LOGIC_OP(Ge) + +ValuePtr ScalarUAdd(const ValuePtrList &list) { + if (list.size() != 1) { + MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); + } + ValuePtr x = list[0]; + MS_EXCEPTION_IF_NULL(x); + return x; +} + +ValuePtr ScalarUSub(const ValuePtrList &list) { + if (list.size() != 1) { + MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); + } + ValuePtr x = list[0]; + MS_EXCEPTION_IF_NULL(x); + + if (x->isa()) { + int32_t sum = -1 * GetValue(x); + return MakeValue(sum); + } + if (x->isa()) { + float sum = -1.0f * GetValue(x); + return MakeValue(sum); + } + + MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; +} + +ValuePtr ScalarLog(const ValuePtrList &list) { + if (list.empty()) { + MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; + } + ValuePtr x = list[0]; + MS_EXCEPTION_IF_NULL(x); + + if (x->isa()) { + double v = log(GetValue(x)); + return MakeValue(v); + } + if (x->isa()) { + auto v = static_cast(log(GetValue(x))); + return MakeValue(v); + } + + MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); +} + +ValuePtr BoolNot(const ValuePtrList &list) { + if (list.empty()) { + MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; + } + ValuePtr x = list[0]; + MS_EXCEPTION_IF_NULL(x); + bool convert = false; + + if (ValueToBool(x, &convert)) { + auto res = !convert; + return MakeValue(res); + } + + MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); +} + +ValuePtr BoolAnd(const ValuePtrList &list) { + if (list.size() < 2) { + MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; + } + ValuePtr x = list[0]; + ValuePtr y = list[1]; + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(y); + bool x_b = false; + bool y_b = false; + + if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { + auto res = x_b && y_b; + return MakeValue(res); + } + + MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; +} + +ValuePtr BoolOr(const ValuePtrList &list) { + if (list.size() < 2) { + MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; + } + ValuePtr x = list[0]; + ValuePtr y = list[1]; + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(y); + bool x_b = false; + bool y_b = false; + + if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { + auto res = x_b || y_b; + return MakeValue(res); + } + + MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; +} + +ValuePtr BoolEq(const ValuePtrList &list) { + if (list.size() < 2) { + MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; + } + ValuePtr x = list[0]; + ValuePtr y = list[1]; + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(y); + bool x_b = false; + bool y_b = false; + + if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { + auto res = x_b == y_b; + return MakeValue(res); + } + + MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << "."; +} + +std::vector BroadcastShape_(std::vector shpx, std::vector shpy) { + int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); + if (dlen < 0) { + for (int i = 0; i < -dlen; ++i) { + (void)shpx.insert(shpx.begin(), 1); + } + } else if (dlen > 0) { + for (int i = 0; i < dlen; i++) { + (void)shpy.insert(shpy.begin(), 1); + } + } + if (shpx.size() != shpy.size()) { + MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; + } + std::vector shp; + for (size_t i = 0; i < shpx.size(); i++) { + auto a = shpx[i]; + auto b = shpy[i]; + if (a == 1) { + shp.push_back(b); + } else if (b == 1) { + shp.push_back(a); + } else if (a == -1) { + shp.push_back(b); + } else if (b == -1) { + shp.push_back(a); + } else if (a == b) { + shp.push_back(a); + } else { + return std::vector(); + } + } + return shp; +} +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/cc_implementations.h b/mindspore/ccsrc/frontend/operator/cc_implementations.h similarity index 100% rename from mindspore/ccsrc/operator/cc_implementations.h rename to mindspore/ccsrc/frontend/operator/cc_implementations.h diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc new file mode 100644 index 0000000000..7d2573e50a --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -0,0 +1,971 @@ + +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/composite/composite.h" +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "frontend/operator/cc_implementations.h" +#include "frontend/optimizer/opt.h" +#include "utils/symbolic.h" +#include "pybind_api/api_register.h" +#include "./common.h" +#include "ir/signature.h" +#include "debug/trace.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using AbstractTensor = mindspore::abstract::AbstractTensor; +using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; + +using mindspore::abstract::AbstractAttribute; +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractClass; +using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractDictionaryPtr; +using mindspore::abstract::AbstractEllipsis; +using mindspore::abstract::AbstractEllipsisPtr; +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractFunctionPtr; +using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractNone; +using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractSlice; +using mindspore::abstract::AbstractTuple; + +ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul}, + {"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod}, + {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt}, + {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe}, + {"__ge__", kPrimScalarGe}}; + +const MetaFuncGraphPtr kTail = std::make_shared("tail"); + +// copy from python API: reduce. +// Apply a function of two arguments cumulatively to the items of a sequence, +// from left to right, so as to reduce the sequence to a single value.For example, +// reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). +AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { + std::shared_ptr ret; + size_t size = list.size(); + if (size < 2) { + MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; + } + + AnyPtrList input; + input.push_back(list[0]); + input.push_back(list[1]); + ret = std::make_shared(func(input)); + + for (size_t i = 2; i < size; ++i) { + input.clear(); + input.push_back(ret); + input.push_back(list[i]); + ret = std::make_shared(func(input)); + } + + return ret; +} + +AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector &list) { + size_t size = list.size(); + if (size < 2) { + MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; + } + + std::vector input; + input.push_back(list[0]); + input.push_back(list[1]); + AnfNodePtr ret = func(input); + + for (size_t i = 2; i < size; ++i) { + input.clear(); + input.push_back(ret); + input.push_back(list[i]); + ret = func(input); + } + + return ret; +} + +ValuePtr kCompositeHyperMap = std::make_shared(); + +void HyperMap::Init() { + if (fn_leaf_) { + name_ = "hyper_map[" + fn_leaf_->name() + "]"; + } + signatures_ = + // def hypermap(func:read, *args:ref): + std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, + {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); +} + +HyperMap::HyperMap(const std::shared_ptr &fn_leaf) + : MetaFuncGraph("hyper_map"), + fn_leaf_(fn_leaf), + broadcast_(false), + nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { + Init(); +} + +HyperMap::HyperMap(const HyperMap &h) + : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { + Init(); +} + +AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map) { + MS_EXCEPTION_IF_NULL(func_graph); + std::vector inputs; + if (fn_arg != nullptr) { + inputs.push_back(fn_arg); + } else { + inputs.push_back(NewValueNode(fn_leaf_)); + } + + (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), + [](const std::pair &item) { return item.first; }); + return func_graph->NewCNode(inputs); +} + +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(type); + + std::size_t size = type->elements().size(); + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { + auto lhs = std::static_pointer_cast(item.second); + MS_EXCEPTION_IF_NULL(lhs); + return lhs->elements().size() != size; + }); + if (is_not_same) { + MS_LOG(EXCEPTION) << "List in HyperMap should have same length"; + } + + // cannot use shared_from_base() also known as this, as it will make a reference cycle on + // hypermap and graph generated, it will cause memory leak. + auto fn_rec = NewValueNode(std::make_shared(*this)); + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeList)); + + for (int i = 0; i < SizeToInt(size); ++i) { + std::vector inputs2; + inputs2.push_back(fn_rec); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), + [&func_graph, i](const std::pair &item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(type); + + std::size_t size = type->elements().size(); + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { + auto lhs = std::static_pointer_cast(item.second); + MS_EXCEPTION_IF_NULL(lhs); + return lhs->elements().size() != size; + }); + if (is_not_same) { + MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length"; + } + + // cannot use shared_from_base() also known as this, as it will make a reference cycle on + // hypermap and graph generated, it will cause memory leak. + auto fn_rec = NewValueNode(std::make_shared(*this)); + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + for (int i = 0; i < SizeToInt(size); ++i) { + std::vector inputs2; + inputs2.push_back(fn_rec); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { + MS_EXCEPTION_IF_NULL(type); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); + inputs.push_back(NewValueNode(type)); + + // cannot use shared_from_base() also known as this, as it will make a reference cycle on + // hypermap and graph generated, it will cause memory leak. + auto fn_rec = NewValueNode(std::make_shared(*this)); + std::size_t attrSize = type->GetAttributes().size(); + for (std::size_t i = 0; i < attrSize; ++i) { + std::vector inputs2; + inputs2.push_back(fn_rec); + if (fn_arg) { + inputs2.push_back(fn_arg); + } + + int j = 0; + for (auto item : arg_map) { + inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); + j++; + } + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { + bool found = false; + TypeId id = kObjectTypeEnd; + std::pair pair; + for (auto &item : arg_map) { + pair = item; + id = item.second->type_id(); + if (nonleaf_.count(id)) { + found = true; + break; + } + } + + if (found) { + // In a nonleaf situation, all arguments must have the same generic. + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair &item) { + if (item.first != pair.first) { + return item.second->type_id() != pair.second->type_id(); + } + return false; + }); + if (is_not_same) { + std::ostringstream oss; + oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" + << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; + int idx = 0; + for (auto &item : arg_map) { + oss << ++idx << ": " << item.second->ToString() << "\n"; + } + MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); + } + } + + switch (id) { + case kObjectTypeList: { + auto type = std::static_pointer_cast(pair.second); + return FullMake(type, func_graph, fn_arg, arg_map); + } + case kObjectTypeTuple: { + auto type = std::static_pointer_cast(pair.second); + return FullMake(type, func_graph, fn_arg, arg_map); + } + case kObjectTypeClass: { + auto type = std::static_pointer_cast(pair.second); + return FullMake(type, func_graph, fn_arg, arg_map); + } + default: + return FullMake(pair.second, func_graph, fn_arg, arg_map); + } +} + +ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { + TypePtr type_tensor = std::make_shared(); + bool flag = std::any_of( + args_spec_list.begin(), args_spec_list.end(), + [type_tensor](const std::pair &item) { return IsSubType(item.second, type_tensor); }); + if (flag && broadcast_) { + ArgsPairList ret; + for (auto &item : args_spec_list) { + if (!IsSubType(item.second, type_tensor)) { + TypePtr type_tensor_ele = std::make_shared(item.second); + ret.push_back( + std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele)); + } else { + ret.push_back(std::make_pair(item.first, item.second)); + } + } + return ret; + } + return args_spec_list; +} + +FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { + FuncGraphPtr ptrGraph = std::make_shared(); + ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); + ptrGraph->debug_info()->set_name("hyper_map"); + + AnfNodePtr ptrFnArg = nullptr; + std::size_t i = 0; + ArgsPairList argmap; + ArgsPairList argmap2; + if (fn_leaf_ == nullptr) { + ptrFnArg = ptrGraph->add_parameter(); + i = 1; + } + + std::size_t size = args_spec_list.size(); + for (; i < size; ++i) { + argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); + } + + argmap2 = Harmonize(ptrGraph, argmap); + ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2)); + return ptrGraph; +} + +abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { + if (fn_leaf_ == nullptr) { + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + // Assert that hypermap's function param does not contain free variables + if (args_spec_list[0]->isa()) { + auto graph_func = dyn_cast(args_spec_list[0]); + auto func_graph = graph_func->func_graph(); + if (func_graph->parent() != nullptr) { + MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet."; + } + } + } + + AbstractBasePtrList broadened; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(arg); + return arg->Broaden(); + }); + return broadened; +} + +REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { + (void)py::class_>(*m, "HyperMap_") + .def(py::init>(), py::arg("leaf")) + .def(py::init<>()); + })); + +FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { + MS_EXCEPTION_IF_NULL(a_tuple); + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ret->debug_info()->set_name("tail"); + AnfNodePtr ptrTup = ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeTuple)); + + int tuple_size = SizeToInt(a_tuple->size()); + for (int i = 1; i < tuple_size; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)})); + } + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { + MS_EXCEPTION_IF_NULL(a_list); + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ret->debug_info()->set_name("tail"); + AnfNodePtr ptrList = ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeList)); + + int list_size = SizeToInt(a_list->size()); + for (int i = 1; i < list_size; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)})); + } + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; + } + + AbstractBasePtr a = args_spec_list[0]; + abstract::AbstractTuplePtr a_tuple = dyn_cast(a); + if (a_tuple != nullptr) { + return GenerateTupleFuncGraph(a_tuple); + } + + abstract::AbstractListPtr a_list = dyn_cast(a); + if (a_list != nullptr) { + return GenerateListFuncGraph(a_list); + } + + MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString(); +} + +REGISTER_PYBIND_DEFINE( + Tail_, ([](const py::module *m) { + (void)py::class_>(*m, "Tail_").def(py::init()); + })); + +FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + int tuple_size = SizeToInt(args_spec_list.size()); + + std::ostringstream ss; + ss << "▶make_tuple_" << tuple_size; + FuncGraphPtr fg = std::make_shared(); + fg->debug_info()->set_name(ss.str()); + + std::vector params; + params.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (int i = 0; i < tuple_size; ++i) { + params.push_back(fg->add_parameter()); + } + + // make fprob first result, maketuple's forward result. + AnfNodePtr out = fg->NewCNode(params); + + // make fprob second result, maketuple's backward function. + FuncGraphPtr b = std::make_shared(); + + ss.clear(); + ss << "◀make_tuple_" << tuple_size; + b->debug_info()->set_name(ss.str()); + AnfNodePtr dout = b->add_parameter(); + + std::vector grads; + grads.push_back(NewValueNode(prim::kPrimMakeTuple)); + grads.push_back(NewValueNode(newenv)); + for (int i = 0; i < tuple_size; ++i) { + grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); + } + + b->set_flag(FUNC_GRAPH_FLAG_CORE, true); + b->set_output(b->NewCNode(grads)); + + fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); + fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); + (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); + return fg; +} + +GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) + : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { + if (get_by_list) { + signatures_ = + // def grad(func:read, weight_list:ref): + std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, + {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}}); + } +} + +FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, + const std::vector ¶ms_list, const std::vector &args, + bool applyJ) { + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + + auto weights_node = weights; + if (weights == nullptr && !args.empty()) { + weights_node = ret->NewCNode(args); + } + + ValueNodePtr opsJ = NewValueNode(prim::kPrimJ); + ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem); + + std::vector inputs; + if (applyJ) { + inputs.push_back(opsJ); + inputs.push_back(node); + node = ret->NewCNode(inputs); + } + + std::vector params; + for (size_t i = 0; i < params_list.size(); ++i) { + params.push_back(ret->add_parameter()); + } + + inputs.clear(); + inputs.push_back(node); + (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); + AnfNodePtr cnode = ret->NewCNode(inputs); + + inputs.clear(); + inputs.push_back(opsTupleItem); + inputs.push_back(cnode); + inputs.push_back(NewValueNode(0)); + auto out = ret->NewCNode(inputs); + + inputs.clear(); + inputs.push_back(opsTupleItem); + inputs.push_back(cnode); + inputs.push_back(NewValueNode(1)); + AnfNodePtr ptrBprop = ret->NewCNode(inputs); + + doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem); + return ret; +} + +void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, + ValueNodePtr opsTupleItem) { + MS_EXCEPTION_IF_NULL(func_graph); + + AnfNodePtr ptrBPropArg = nullptr; + if (sens_param_) { + ptrBPropArg = func_graph->add_parameter(); + } else { + auto ones_like = prim::GetPythonOps("ones_like"); + ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out}); + } + + AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg}); + + CNodePtr fv_bprop = nullptr; + if (get_by_list_) { + // python code: grads = hyper_map(F.partial(env_get, env), weights) + AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)}); + AnfNodePtr partial_env_get = + func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); + MetaFuncGraphPtr hyper_map = std::make_shared(); + fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights}); + } + + CNodePtr inputs_bprop = nullptr; + if (get_all_) { + inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp}); + } + + // Gradients wrt inputs and parameters + if (fv_bprop != nullptr && inputs_bprop != nullptr) { + func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); + return; + } + + // Gradients wrt parameters + if (fv_bprop != nullptr) { + func_graph->set_output(fv_bprop); + return; + } + + // Gradients wrt inputs + if (inputs_bprop != nullptr) { + func_graph->set_output(inputs_bprop); + return; + } + + // Gradients wrt first input. + // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input + func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)})); +} + +// Generate the graph. +FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.size() < 1) { + MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " + << args_spec_list.size() << "."; + } + + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + AbstractFunctionPtr fn = dyn_cast(args_spec_list[0]); + if (fn == nullptr) { + MS_LOG(EXCEPTION) << "GradOperation arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); + } + + // Waiting for implementation. + auto real_fn = dyn_cast(fn); + MS_EXCEPTION_IF_NULL(real_fn); + + FuncGraphPtr ptrGraph = real_fn->func_graph(); + MS_EXCEPTION_IF_NULL(ptrGraph); + TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); + FuncGraphPtr dfBuilder = std::make_shared(); + TraceManager::EndTrace(); + auto nparam = ptrGraph->parameters().size(); + + std::ostringstream ss; + ss << "grad{" << nparam << "}"; + dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); + dfBuilder->debug_info()->set_name(ss.str()); + ParameterPtr param_graph = dfBuilder->add_parameter(); + + AnfNodePtr weights = nullptr; + if (get_by_list_) { + weights = dfBuilder->add_parameter(); + } + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimJ)); + inputs.push_back(param_graph); + auto jf = dfBuilder->NewCNode(inputs); + // df is checked in GetGrad + TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); + auto df = GetGrad(jf, weights, ptrGraph->parameters()); + TraceManager::EndTrace(); + dfBuilder->set_output(NewValueNode(df)); + + return dfBuilder; +} + +REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { + (void)py::class_>( + *m, "GradOperation_") + .def(py::init(), py::arg("fn")) + .def(py::init(), py::arg("fn"), py::arg("get_all"), + py::arg("get_by_list"), py::arg("sens_param")); + })); + +// Generate the ListMap func graph. +FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + size_t args_num = args_spec_list.size(); + // args: fn, list1, list2, ... + if (args_num < 2) { + MS_LOG(EXCEPTION) << "list_map takes at least two arguments"; + } + + for (size_t i = 1; i < args_num; ++i) { + if (typeid(args_spec_list[i]) != typeid(AbstractBase)) { + // The function currently not be use + MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'"; + } + } + + FuncGraphPtr fg_ptr = std::make_shared(); + fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); + fg_ptr->debug_info()->set_name("list_map"); + AnfNodePtr fn = fg_ptr->add_parameter(); + + std::vector lists; + for (size_t i = 1; i < args_num; ++i) { + lists.push_back(fg_ptr->add_parameter()); + } + + std::vector iters; + (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item}); + }); + + std::vector nexts; + (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); + }); + + std::vector values; + (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item}); + }); + + (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)}); + }); + + (void)values.insert(values.begin(), fn); + AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); + AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph}); + + FuncGraphPtr fgnext_ptr = std::make_shared(); + fgnext_ptr->debug_info()->set_name("body"); + + FuncGraphPtr fgcond_ptr = std::make_shared(); + fgcond_ptr->debug_info()->set_name("cond"); + + MakeCond(lists, fgnext_ptr, fgcond_ptr); + MakeNext(lists, fgcond_ptr, fgnext_ptr); + + CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); + + auto inputs = output_cnode->inputs(); + (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); + output_cnode->set_inputs(inputs); + + fg_ptr->set_output(output_cnode); + return fg_ptr; +} + +void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr &fgnext_ptr, + const FuncGraphPtr &fg_ptr) { + MS_EXCEPTION_IF_NULL(fg_ptr); + + AnfNodePtr fn = fg_ptr->add_parameter(); + AnfNodePtr resl = fg_ptr->add_parameter(); + + std::vector iters; + (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), + [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); }); + + std::vector hasnexts; + (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item}); + }); + + // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) + FuncGraphPtr fgtrue_ptr = std::make_shared(); + fgtrue_ptr->debug_info()->set_name("ftrue"); + fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); + + CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); + auto inputs = fgtrue_output_cnode->inputs(); + (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); + fgtrue_output_cnode->set_inputs(inputs); + fgtrue_ptr->set_output(fgtrue_output_cnode); + + FuncGraphPtr fgfalse_ptr = std::make_shared(); + fgfalse_ptr->debug_info()->set_name("ffalse"); + fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); + fgfalse_ptr->set_output(resl); + + AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), + NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)}); + fgtrue_ptr->set_output(output_cnode); +} + +void ListMap::MakeNext(const std::vector &lists, const FuncGraphPtr &fgcond_ptr, + const FuncGraphPtr &fg_ptr) { + MS_EXCEPTION_IF_NULL(fg_ptr); + AnfNodePtr fn = fg_ptr->add_parameter(); + + std::vector iters; + (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), + [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); }); + + std::vector nexts; + (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); + }); + + std::vector values; + (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, nullptr}); + }); + + iters.clear(); + (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)}); + }); + + (void)values.insert(values.begin(), fn); + AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); + AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph}); + CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); + + auto inputs = output_cnode->inputs(); + (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); + output_cnode->set_inputs(inputs); + fg_ptr->set_output(output_cnode); +} + +FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // args: tuple1, tuple2 + abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); + AbstractBasePtr abs_a = args_spec_list[0]; + AbstractBasePtr abs_b = args_spec_list[1]; + + abstract::AbstractTuplePtr a_tuple = dyn_cast(abs_a); + abstract::AbstractTuplePtr b_tuple = dyn_cast(abs_b); + if (a_tuple == nullptr || b_tuple == nullptr) { + MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", " + << args_spec_list[1]->ToString(); + } + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + AnfNodePtr p_tup_a = ret->add_parameter(); + AnfNodePtr p_tup_b = ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeTuple)); + + int tuple_size = SizeToInt(a_tuple->size()); + for (int i = 0; i < tuple_size; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)})); + } + + tuple_size = SizeToInt(b_tuple->size()); + for (int i = 0; i < tuple_size; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)})); + } + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { + MS_EXCEPTION_IF_NULL(scalar); + return GetValue(scalar->BuildValue()); +} + +bool CheckIndexInRange(int index, int min, int max) { return (index >= min && index <= max); } + +int GetPositiveIndex(int index, int length) { + if (index < 0) { + index += length; + } + return index; +} + +int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { + MS_EXCEPTION_IF_NULL(member); + + if (member->isa()) { + return GetArgScalarValue(dyn_cast(member), member_name); + } + + if (member->isa()) { + return default_value; + } + + MS_LOG(EXCEPTION) << member_name << " should be a AbstractScalar or AbstractNone, but got " << member->ToString(); +} + +void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index, + int *stop_index, int *step_value) { + MS_EXCEPTION_IF_NULL(tuple); + MS_EXCEPTION_IF_NULL(slice); + MS_EXCEPTION_IF_NULL(start_index); + MS_EXCEPTION_IF_NULL(stop_index); + MS_EXCEPTION_IF_NULL(step_value); + + const std::string start_name("Slice start index"); + const std::string stop_name("Slice stop index"); + const std::string step_name("Slice step value"); + + int tuple_size = SizeToInt(tuple->size()); + int start_default = 0; + int stop_default = tuple_size; + int step_default = 1; + + *step_value = CheckSliceMember(slice->step(), step_default, step_name); + if (*step_value == 0) { + MS_LOG(EXCEPTION) << "TupleSlice require the step value could not be 0, but got 0."; + } + + if (*step_value < 0) { + start_default = tuple_size - 1; + stop_default = -1; + } + + *start_index = CheckSliceMember(slice->start(), start_default, start_name); + *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name); + if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) || + !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) { + MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index + << " out of range, tuple size " << tuple_size << "."; + } + + *start_index = GetPositiveIndex(*start_index, tuple_size); + if (!slice->stop()->isa()) { + *stop_index = GetPositiveIndex(*stop_index, tuple_size); + } +} + +FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // slice a tuple + // args: tuple, start index, end index, step + const std::string op_name("TupleSlice"); + abstract::CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr tuple = abstract::CheckArg(op_name, args_spec_list, 0); + AbstractSlicePtr slice = abstract::CheckArg(op_name, args_spec_list, 1); + + int start_index; + int stop_index; + int step_value; + GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value); + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + AnfNodePtr p_tuple = ret->add_parameter(); + (void)ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeTuple)); + if (step_value > 0) { + for (int index = start_index; index < stop_index; index = index + step_value) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); + } + } else { + for (int index = start_index; index > stop_index; index = index + step_value) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); + } + } + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // select indexed item + // args: tuple of items, index + const std::string op_name = std::string("TupleGetItemTensor"); + abstract::CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr branches_abs = abstract::CheckArg(op_name, args_spec_list, 0); + AbstractBasePtrList branches = branches_abs->elements(); + if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa()) { + FuncGraphPtr ret_graph = std::make_shared(); + ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + AnfNodePtr functions = ret_graph->add_parameter(); + auto index = ret_graph->add_parameter(); + + ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); + return ret_graph; + } + + MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << "."; +} + +REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { + (void)py::class_>(*m, "TupleAdd_") + .def(py::init()); + })); + +REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { + (void)py::class_>(*m, "TupleSlice_") + .def(py::init()); + })); + +REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) { + (void)py::class_>( + *m, "TupleGetItemTensor_") + .def(py::init()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h new file mode 100644 index 0000000000..3821192dba --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/composite.h @@ -0,0 +1,192 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "frontend/operator/composite/zip_operation.h" +#include "frontend/operator/composite/list_append_operation.h" +#include "frontend/operator/composite/do_signature.h" +#include "frontend/operator/composite/unpack_call.h" +#include "frontend/operator/composite/multitype_funcgraph.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using AbstractSlicePtr = abstract::AbstractSlicePtr; +using AbstractScalarPtr = abstract::AbstractScalarPtr; +using AbstractTensorPtr = abstract::AbstractTensorPtr; +using ElemwiseMap = std::unordered_map; +using ArgsPairList = std::vector>; + +class HyperMap : public MetaFuncGraph { + public: + explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr); + HyperMap(const HyperMap &h); + void Init(); + HyperMap &operator=(const HyperMap &h) { + if (this != &h) { + fn_leaf_ = h.fn_leaf_; + broadcast_ = h.broadcast_; + nonleaf_ = h.nonleaf_; + if (fn_leaf_) { + name_ = "hyper_map[" + fn_leaf_->name() + "]"; + } + } + return *this; + } + ~HyperMap() override = default; + MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) + + abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; + MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } + + private: + AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); + ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); + + MultitypeFuncGraphPtr fn_leaf_; + bool broadcast_; + std::set nonleaf_; +}; +using HyperMapPtr = std::shared_ptr; + +class HyperMapPy : public HyperMap { + public: + explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {} + ~HyperMapPy() override = default; + MS_DECLARE_PARENT(HyperMapPy, HyperMap) +}; +using HyperMapPyPtr = std::shared_ptr; + +extern ValuePtr kCompositeHyperMap; + +class Tail : public MetaFuncGraph { + public: + explicit Tail(const std::string &name) : MetaFuncGraph(name) {} + ~Tail() override = default; + MS_DECLARE_PARENT(Tail, MetaFuncGraph) + + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); + FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); + + friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } +}; +using TailPtr = std::shared_ptr; + +class MakeTupleGradient : public MetaFuncGraph { + public: + explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} + ~MakeTupleGradient() override = default; + MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } +}; +using MakeTupleGradientPtr = std::shared_ptr; + +class GradOperation : public MetaFuncGraph { + public: + explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, + bool sens_param = false); + ~GradOperation() override = default; + MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) + + FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector &ptrParams, + const std::vector &args = {}, bool applyJ = false); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + bool sens_param() const { return sens_param_; } + bool get_all_; + bool get_by_list_; + bool sens_param_; + + private: + void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, + ValueNodePtr opsTupleItem); +}; +using GradOperationPtr = std::shared_ptr; + +class ListMap { + public: + explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } + ~ListMap() = default; + void MakeCond(const std::vector &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); + void MakeNext(const std::vector &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); + + private: + std::string name_; + std::map, FuncGraphPtr> cache_; +}; + +class TupleAdd : public MetaFuncGraph { + public: + explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} + ~TupleAdd() override = default; + MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } +}; +using TupleAddPtr = std::shared_ptr; + +class TupleSlice : public MetaFuncGraph { + public: + explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} + ~TupleSlice() override = default; + MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } +}; +using TupleSlicePtr = std::shared_ptr; + +class TupleGetItemTensor : public MetaFuncGraph { + public: + explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} + ~TupleGetItemTensor() override = default; + MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) { + return lhs.name_ == rhs.name_; + } +}; +using TupleGetItemTensorPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc new file mode 100644 index 0000000000..50be3c5b29 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc @@ -0,0 +1,338 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/composite/do_signature.h" +#include +#include + +#include "abstract/abstract_value.h" +#include "ir/anf.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "frontend/operator/cc_implementations.h" +#include "frontend/optimizer/opt.h" +#include "utils/symbolic.h" +#include "./common.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +const std::map type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3}, + {kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6}, + {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}}; +namespace { +const std::vector &GetSignature(const ValuePtr &function) { + static const auto empty = std::vector(); + if (function->isa() && function->cast()->has_signature()) { + return function->cast()->signatures(); + } else if (function->isa()) { + return function->cast()->signatures(); + } + return empty; +} + +void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, + const std::vector &signature, bool has_var, std::vector *const op_inputs) { + std::size_t sig_size = signature.size(); + auto positional_size = sig_size; + if (has_var) { + positional_size = sig_size - 1; + } + if (args_spec_list.size() < positional_size) { + for (size_t i = args_spec_list.size(); i < sig_size; ++i) { + auto default_value = signature[i].default_value; + if (default_value == nullptr) { + MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; + } else { + (*op_inputs).push_back(NewValueNode(default_value)); + } + } + } +} + +void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) { + *max_type_id = type_id; + *max_type_number = type_number; +} + +bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, + TypeId *arg_type = nullptr) { + if (arg_value->isa()) { + if (is_write) { + arg_value = arg_value->cast()->ref_origin(); + } else { + arg_value = arg_value->cast()->ref(); + } + } + if (arg_value->isa()) { + auto tensor = arg_value->cast(); + auto tensor_type = tensor->element()->BuildType(); + MS_EXCEPTION_IF_NULL(tensor_type); + *arg_type_id = tensor_type->type_id(); + if (arg_type != nullptr) { + *arg_type = kObjectTypeTensorType; + } + return true; + } + if (arg_value->isa()) { + auto scalar = arg_value->cast(); + auto scalar_type = scalar->BuildType(); + MS_EXCEPTION_IF_NULL(scalar_type); + *arg_type_id = scalar_type->type_id(); + if (arg_type != nullptr) { + *arg_type = kObjectTypeNumber; + } + return true; + } + return false; +} + +TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indices, + const std::set &write_indices) { + TypeId max_type_id = kTypeUnknown; + size_t max_type_number = 0; + bool has_int8 = false; + bool has_scalar_int32 = false; + bool has_scalar_float32 = false; + for (const auto &index : indices) { + TypeId arg_type_id = kTypeUnknown; + TypeId arg_type = kTypeUnknown; + auto is_write = (write_indices.find(index) != write_indices.end()); + if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { + continue; + } + if (arg_type != kObjectTypeTensorType) { + if (arg_type_id == kNumberTypeInt32) { + has_scalar_int32 = true; + } else if (arg_type_id == kNumberTypeFloat32) { + has_scalar_float32 = true; + } + continue; + } + auto it = type_map.find(arg_type_id); + if (it == type_map.end()) { + continue; + } + if (arg_type_id == kNumberTypeInt8) { + has_int8 = true; + } + if (max_type_id == kTypeUnknown) { + SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); + continue; + } + if (it->second > max_type_number) { + SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); + } + } + + if (max_type_id == kNumberTypeUInt8 && has_int8 == true) { + max_type_id = kNumberTypeInt16; + } + // if bool is the max type, see if there is scalar input + // if so, it means that max is bool tensor, use scalar type instead. + // for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2]) + if (max_type_id == kNumberTypeBool) { + if (has_scalar_int32) { + max_type_id = kNumberTypeInt32; + } + if (has_scalar_float32) { + max_type_id = kNumberTypeFloat32; + } + } + return max_type_id; +} + +// Get the largest type of index in the same SignatureEnumDType of arguments. +std::map GetMaxDtype(const std::vector &dtypes, + const abstract::AbstractBasePtrList &args_spec_list, + const std::set &write_indices) { + // record index for signature.dtypes of the same type + // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} + std::map> type_indices; + for (size_t i = 0; i < dtypes.size(); ++i) { + auto it = type_indices.find(dtypes[i]); + if (it == type_indices.end()) { + (void)type_indices.insert(std::make_pair(dtypes[i], std::vector{i})); + } else { + it->second.push_back(i); + } + } + std::map dst_type; + for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) { + auto type = it->first; + auto indices = it->second; + // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. + if (indices.size() < 2) { + continue; + } + bool has_tensor = false; + for (const auto &index : indices) { + AbstractBasePtr arg_value = args_spec_list[index]; + if (arg_value->isa()) { + arg_value = arg_value->cast()->ref(); + } + if (arg_value->isa()) { + has_tensor = true; + break; + } + } + if (!has_tensor) { + (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); + continue; + } + (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); + } + return dst_type; +} + +AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGraphPtr &graph) { + auto prim_cast_class = prim::GetPythonOps("Cast", "mindspore.ops.operations"); + MS_EXCEPTION_IF_NULL(prim_cast_class); + auto dtype_node = NewValueNode(TypeIdToType(type_id)); + auto cast_node = NewCNode({NewValueNode(prim_cast_class)}, graph); + return NewCNode({cast_node, param, dtype_node}, graph); +} + +void DoAutoCast(const std::string &func_name, const std::vector &signature, + const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, + std::vector *const op_inputs, const std::set &write_indices) { + std::vector dtypes; + (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), + [](const Signature &sig) { return sig.dtype; }); + int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); + if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { + return; + } + // Stat the index of the arguments with the largest type in the same SignatureEnumDType. + std::map dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); + // Identify which arg requires auto cast + for (size_t i = 0; i < args_spec_list.size(); ++i) { + auto it = dst_type.find(dtypes[i]); + if (it == dst_type.end() || it->second == kTypeUnknown) { + continue; + } + auto rw_it = write_indices.find(i); + auto is_write = (rw_it != write_indices.end()); + + TypeId arg_type_id = kTypeUnknown; + AbstractBasePtr arg_value = args_spec_list[i]; + (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); + auto it_map = type_name_map.find(arg_type_id); + if (it_map == type_name_map.end()) { + continue; + } + if (is_write) { + if (arg_type_id != it->second) { + auto it_name_map = type_name_map.find(it->second); + if (it_name_map == type_name_map.end()) { + continue; + } + RaiseExceptionForConvertRefDtype(func_name, it_map->second, it_name_map->second); + } + continue; + } + if (arg_value->isa() && arg_type_id == it->second) { + continue; + } + (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); + } +} + +AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const std::vector ¶ms_list) { + // args: original inputs + auto &signature = GetSignature(function); + std::size_t sig_size = signature.size(); + auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); + if (sig_size > 0) { + if (has_var) { + if (sig_size - 1 > args_spec_list.size()) { + MS_LOG(EXCEPTION) << "Function " << func_name + << "'s input length less than PositionalKeyword Signature length."; + } + } else if (args_spec_list.size() > sig_size) { + MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; + } + } + std::vector op_inputs; + std::set write_indices; + op_inputs.push_back(NewValueNode(function)); + // Assume, the write input of op is always the first input. We check if any write op, + // and add cast op on other inputs to keep the same type with assigned parameter. + for (size_t i = 0; i < args_spec_list.size(); ++i) { + AnfNodePtr param = params_list[i]; + if (args_spec_list[i] == nullptr) { + op_inputs.push_back(param); + continue; + } + SignatureEnumRW sig = SignatureEnumRW::kRWDefault; + // If sig_size is 0 use defalut. + if (sig_size > 0 && i < sig_size) { + sig = signature[i].rw; + } else if (has_var && i >= sig_size) { + sig = signature[sig_size - 1].rw; + } + + TypePtr type = args_spec_list[i]->GetTypeTrack(); + if (type && type->type_id() == kObjectTypeRef) { + if (sig == SignatureEnumRW::kRWRead) { + param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); + } else if (sig == SignatureEnumRW::kRWWrite) { + param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); + write_indices.insert(i); + } + // If sig is SignatureEnumRW::kRWRef, not do anything. + } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { + MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; + } + op_inputs.push_back(param); + } + // process default + ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); + DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); + return func_graph->NewCNode(op_inputs); +} +} // namespace + +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) { + auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); + return new_cnode; +} + +FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + FuncGraphPtr func_graph = std::make_shared(); + + for (size_t i = 0; i < args_spec_list.size(); ++i) { + (void)func_graph->add_parameter(); + } + auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters()); + func_graph->set_output(new_cnode); + func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + return func_graph; +} + +void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type, + const std::string &target_type) { + MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n" + << "the type of writable argument is '" << ref_type << "', " + << "but the largest type in the same SignatureEumDtype is '" << target_type + << "'. The writable arg type is not equal to the largest type, " + << "so can not cast automatically."; +} +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.h b/mindspore/ccsrc/frontend/operator/composite/do_signature.h new file mode 100644 index 0000000000..9139be806a --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" +#include "common/utils.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +class DoSignatureMetaFuncGraph : public MetaFuncGraph { + public: + explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) + : MetaFuncGraph("S-" + name), function_(function) {} + + ~DoSignatureMetaFuncGraph() override = default; + + MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) + + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override; + const ValuePtr function() const { return function_; } + + friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { + return &lhs == &rhs; + } + + private: + ValuePtr function_; +}; +using RWSignaturePtr = std::shared_ptr; + +extern const std::map type_map; + +void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type, + const std::string &target_type); + +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/list_append_operation.cc b/mindspore/ccsrc/frontend/operator/composite/list_append_operation.cc new file mode 100644 index 0000000000..3dfe2e23d0 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/list_append_operation.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/composite/list_append_operation.h" + +#include +#include +#include + +#include "abstract/param_validator.h" +#include "frontend/optimizer/opt.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { + abstract::CheckArgsSize("ListAppend", args_list, 2); + + AbstractBasePtr arg0 = args_list[0]; + abstract::AbstractListPtr arg0_list = dyn_cast(arg0); + MS_EXCEPTION_IF_NULL(arg0_list); + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ret->debug_info()->set_name("append"); + AnfNodePtr arg0_node = ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeList)); + size_t arg0_length = arg0_list->size(); + for (size_t i = 0; i < arg0_length; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(SizeToInt(i))})); + } + AnfNodePtr arg1_node = ret->add_parameter(); + elems.push_back(arg1_node); + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { + (void)py::class_>(*m, "ListAppend_") + .def(py::init()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.h b/mindspore/ccsrc/frontend/operator/composite/list_append_operation.h similarity index 100% rename from mindspore/ccsrc/operator/composite/list_append_operation.h rename to mindspore/ccsrc/frontend/operator/composite/list_append_operation.h diff --git a/mindspore/ccsrc/frontend/operator/composite/map.cc b/mindspore/ccsrc/frontend/operator/composite/map.cc new file mode 100644 index 0000000000..a5f674187b --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/map.cc @@ -0,0 +1,292 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/composite/map.h" +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/dshape.h" +#include "pybind_api/api_register.h" +#include "debug/trace.h" +#include "frontend/operator/ops.h" +#include "./common.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; + +AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) { + MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n"; + MS_EXCEPTION_IF_NULL(func_graph); + std::vector inputs; + if (fn_arg != nullptr) { + inputs.emplace_back(fn_arg); + } else { + inputs.emplace_back(NewValueNode(fn_leaf_)); + } + inputs.insert(inputs.end(), args.begin(), args.end()); + return func_graph->NewCNode(inputs); +} + +FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { + // Generate func for leaf nodes + FuncGraphPtr ptrGraph = std::make_shared(); + ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); + ptrGraph->debug_info()->set_name("map"); + AnfNodePtr ptrFnArg = nullptr; + if (fn_leaf_ == nullptr) { + ptrFnArg = ptrGraph->add_parameter(); + } + AnfNodePtrList args; + for (size_t i = 0; i < args_size; ++i) { + args.emplace_back(ptrGraph->add_parameter()); + } + ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args)); + return ptrGraph; +} + +AnfNodePtr Map::FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(type); + + std::size_t size = type->elements().size(); + bool is_not_same = + std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { + auto lhs = std::dynamic_pointer_cast(item.second); + MS_EXCEPTION_IF_NULL(lhs); + return lhs->elements().size() != size; + }); + if (is_not_same) { + MS_LOG(EXCEPTION) << "List in Map should have same length"; + } + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeList)); + + for (int i = 0; i < SizeToInt(size); ++i) { + MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target"; + auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); + auto fn = NewValueNode(ptrGraph); + + std::vector inputs2; + inputs2.push_back(fn); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), + [&func_graph, i](const std::pair &item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr Map::FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(type); + + std::size_t size = type->elements().size(); + bool is_not_same = + std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { + auto lhs = std::dynamic_pointer_cast(item.second); + MS_EXCEPTION_IF_NULL(lhs); + return lhs->elements().size() != size; + }); + if (is_not_same) { + MS_LOG(EXCEPTION) << "tuple in Map should have same length"; + } + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + for (int i = 0; i < SizeToInt(size); ++i) { + MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs"; + auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); + auto fn = NewValueNode(ptrGraph); + + std::vector inputs2; + inputs2.push_back(fn); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), + [&func_graph, &i](std::pair item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + MS_EXCEPTION_IF_NULL(type); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); + inputs.push_back(NewValueNode(type)); + + std::size_t attrSize = type->GetAttributes().size(); + for (std::size_t i = 0; i < attrSize; ++i) { + MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs"; + auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); + auto fn = NewValueNode(ptrGraph); + + std::vector inputs2; + inputs2.push_back(fn); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + int j = 0; + for (auto item : arg_pairs) { + inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); + j++; + } + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + if (arg_pairs.empty()) { + MS_EXCEPTION(TypeError) << "map() must have at least two arguments"; + } + bool found = false; + TypeId id = kObjectTypeEnd; + std::pair pair; + for (auto &item : arg_pairs) { + pair = item; + MS_LOG(DEBUG) << "Map " << pair.second->ToString(); + id = item.second->type_id(); + if (nonleaf_.count(id)) { + found = true; + break; + } + } + + if (found) { + // In a nonleaf situation, all arguments must have the same generic. + bool is_not_same = + std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair &item) { + if (item.first != pair.first) { + return item.second->type_id() != pair.second->type_id(); + } + return false; + }); + if (is_not_same) { + std::ostringstream oss; + oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n" + << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; + int idx = 0; + for (auto &item : arg_pairs) { + oss << ++idx << ": " << item.second->ToString() << "\n"; + } + MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n" + << oss.str() << pair.second->ToString() << "\n"; + } + } + + switch (id) { + case kObjectTypeList: { + auto type = std::static_pointer_cast(pair.second); + return FullMakeList(type, func_graph, fn_arg, arg_pairs); + } + case kObjectTypeTuple: { + auto type = std::static_pointer_cast(pair.second); + return FullMakeTuple(type, func_graph, fn_arg, arg_pairs); + } + case kObjectTypeClass: { + auto type = std::static_pointer_cast(pair.second); + return FullMakeClass(type, func_graph, fn_arg, arg_pairs); + } + default: + MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class " + << ", but got " << pair.second->ToString(); + } +} + +FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { + FuncGraphPtr ptrGraph = std::make_shared(); + ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); + ptrGraph->debug_info()->set_name("map"); + + AnfNodePtr ptrFnArg = nullptr; + std::size_t i = 0; + if (fn_leaf_ == nullptr) { + ptrFnArg = ptrGraph->add_parameter(); + i = 1; + } + ArgsPairList arg_pairs; + std::size_t size = args_spec_list.size(); + for (; i < size; ++i) { + MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString(); + arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); + } + + ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs)); + return ptrGraph; +} + +abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { + if (fn_leaf_ == nullptr) { + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + // Assert that map's function param does not contain free variables + if (args_spec_list[0]->isa()) { + auto graph_func = dyn_cast(args_spec_list[0]); + auto func_graph = graph_func->func_graph(); + if (func_graph->parent() != nullptr) { + MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet."; + } + } + } + + AbstractBasePtrList broadened; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(arg); + return arg->Broaden(); + }); + return broadened; +} + +REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) { + (void)py::class_>(*m, "Map_") + .def(py::init>(), py::arg("leaf")) + .def(py::init<>()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/map.h b/mindspore/ccsrc/frontend/operator/composite/map.h new file mode 100644 index 0000000000..428014f9c4 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/map.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ + +#include +#include +#include +#include + +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" +#include "frontend/operator/composite/multitype_funcgraph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using ArgsPairList = std::vector>; + +class Map : public MetaFuncGraph { + public: + explicit Map(const std::shared_ptr &fn_leaf = nullptr) + : MetaFuncGraph("map"), + fn_leaf_(fn_leaf), + broadcast_(false), + nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { + Init(); + } + Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { + Init(); + } + Map &operator=(const Map &h) { + if (this != &h) { + fn_leaf_ = h.fn_leaf_; + broadcast_ = h.broadcast_; + nonleaf_ = h.nonleaf_; + if (fn_leaf_) { + name_ = "map[" + fn_leaf_->name() + "]"; + } + } + return *this; + } + ~Map() override = default; + MS_DECLARE_PARENT(Map, MetaFuncGraph) + abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; + MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } + + private: + FuncGraphPtr GenerateLeafFunc(const size_t &args_size); + AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); + AnfNodePtr FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); + void Init() { + if (fn_leaf_ != nullptr) { + name_ = "map[" + fn_leaf_->name() + "]"; + } + signatures_ = + // def map(func:read, *args:ref): + std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, + {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); + } + + MultitypeFuncGraphPtr fn_leaf_; + bool broadcast_; + std::set nonleaf_; +}; +using MapPtr = std::shared_ptr; +class MapPy : public Map { + public: + explicit MapPy(const std::shared_ptr &fn_leaf = nullptr) : Map(fn_leaf) {} + ~MapPy() override = default; + MS_DECLARE_PARENT(MapPy, Map) +}; +using MapPyPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc new file mode 100644 index 0000000000..ba0d3d9ebb --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc @@ -0,0 +1,198 @@ + +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/composite/multitype_funcgraph.h" +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "frontend/operator/cc_implementations.h" +#include "frontend/optimizer/opt.h" +#include "utils/context/ms_context.h" +#include "utils/symbolic.h" +#include "pybind_api/api_register.h" +#include "./common.h" +#include "ir/signature.h" +#include "debug/trace.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { + fn_cache_.clear(); + signatures_ = std::vector({// def multitype(*args:ref): + {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); +} + +void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { + MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; + auto fn = fn_cache_.find(types); + if (fn != fn_cache_.end()) { + MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; + } + fn_cache_[types] = s_fn; +} + +void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { + MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; + auto fn = fn_cache_.find(types); + if (fn != fn_cache_.end()) { + MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; + } + fn_cache_py_[types] = py_fn; +} + +void MultitypeFuncGraph::Register(const std::vector &types_name, const py::function &py_fn) { + TypePtrList types; + for (auto &type_name : types_name) { + auto type_ptr = StringToType(type_name); + if (type_ptr == nullptr) { + MS_LOG(EXCEPTION) << type_name << " convert from string error "; + } + types.push_back(type_ptr); + } + Register(types, py_fn); +} + +void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { + std::vector types_name; + for (size_t it = 0; it < tuple.size(); ++it) { + py::object name_py = tuple[it]; + if (py::isinstance(name_py)) { + types_name.push_back(name_py.cast()); + continue; + } + MS_LOG(EXCEPTION) << "Register must be string"; + } + Register(types_name, py_fn); +} +static TypePtr UnwrapRef(const TypePtr &type) { + if (type->isa()) { + return type->cast()->subtype(); + } + return type; +} + +// Return Exact match if exists, else return non ambiguous sub class match +// Return py::none() if matching is ambiguous +const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { + // Exact match + for (auto &item : fn_cache_py_) { + TypePtrList sign = item.first; + if (sign.size() != types.size()) { + continue; + } + auto match = true; + for (size_t i = 0; i < sign.size(); ++i) { + if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { + match = false; + break; + } + } + if (!match) { + continue; + } + return item.second; + } + return py::none(); +} + +FuncGraphPtr GenerateStubFunc(const TypePtrList &types) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (!enable_sparse) { + return nullptr; + } + + std::vector parameters; + ParameterPtr undetermined_param = nullptr; + auto stub = std::make_shared(); + for (size_t i = 0; i < types.size(); ++i) { + auto param = stub->add_parameter(); + parameters.push_back(param); + if (types[i]->type_id() == kObjectTypeUndeterminedType) { + undetermined_param = param; + } + } + if (undetermined_param != nullptr) { + std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; + for (size_t i = 0; i < types.size(); ++i) { + if (types[i]->type_id() == kObjectTypeFunction) { + std::vector call_prim{parameters[i], undetermined_param}; + inputs.push_back(stub->NewCNode(call_prim)); + } else { + inputs.push_back(parameters[i]); + } + } + auto stub_output = stub->NewCNode(inputs); + stub->set_output(stub_output); + stub->set_stub(true); + return stub; + } + return nullptr; +} + +FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { + auto py_fn = SignMatch(types); + std::ostringstream buffer; + buffer << types; + if (py_fn != py::none()) { + FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); + if (func_graph == nullptr) { + MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); + } + MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString(); + return func_graph; + } + auto stub = GenerateStubFunc(types); + if (stub != nullptr) { + MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString(); + return stub; + } + std::ostringstream oss; + oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ + << "`, corresponding location info:\n"; + int idx = 0; + for (auto &item : fn_cache_py_) { + FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); + if (func_graph == nullptr) { + MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; + continue; + } + oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; + } + MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n" + << oss.str(); +} + +REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { + (void)py::class_>( + *m, "MultitypeFuncGraph_") + .def(py::init()) + .def("register_fn", &MultitypeFuncGraph::PyRegister); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h new file mode 100644 index 0000000000..2139a0e9d1 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h @@ -0,0 +1,65 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +class MultitypeFuncGraph : public MetaFuncGraph { + public: + explicit MultitypeFuncGraph(const std::string &name); + ~MultitypeFuncGraph() override = default; + MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) + + using specialize_fn = FuncGraph *(*)(TypePtrList); + // Register a method which specialize based on types vectors; + virtual void Register(const TypePtrList &types, specialize_fn s_fn); + virtual void Register(const TypePtrList &types, const py::function &py_fn); + virtual void Register(const std::vector &types_name, const py::function &py_fn); + virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); + + FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; + size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } + const std::unordered_map &GetPyFunctions() const { + return fn_cache_py_; + } + + private: + const py::function SignMatch(const TypePtrList &types); + std::unordered_map fn_cache_; + std::unordered_map fn_cache_py_; +}; +using MultitypeFuncGraphPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc b/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc new file mode 100644 index 0000000000..2c9e0b538f --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/composite/unpack_call.h" +#include +#include + +#include "./common.h" +#include "abstract/abstract_value.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "frontend/operator/cc_implementations.h" +#include "ir/anf.h" +#include "frontend/optimizer/opt.h" +#include "utils/symbolic.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using mindspore::abstract::AbstractAttribute; +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractDictionaryPtr; +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractKeywordArg; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; + +FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // slice a tensor + // args: tensor, slice or slice tuple + const std::string op_name = std::string("UnpackCall"); + size_t arg_length = args_spec_list.size(); + if (arg_length < 2) { + MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; + } + + (void)abstract::CheckArg(op_name, args_spec_list, 0); + auto ret_graph = std::make_shared(); + ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + + AnfNodePtr fnNode = ret_graph->add_parameter(); + std::vector elems; + elems.push_back(fnNode); + for (size_t index = 1; index < arg_length; index++) { + MS_EXCEPTION_IF_NULL(args_spec_list[index]); + if (args_spec_list[index]->isa()) { + auto arg_tuple = args_spec_list[index]->cast(); + AnfNodePtr para_tuple = ret_graph->add_parameter(); + for (size_t i = 0; i < arg_tuple->size(); ++i) { + elems.push_back( + ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))})); + } + } else if (args_spec_list[index]->isa()) { + AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast(); + AnfNodePtr para_dict = ret_graph->add_parameter(); + auto dict_elems = arg_dict->elements(); + (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), + [ret_graph, para_dict](const AbstractAttribute &item) { + auto dict_get_item = ret_graph->NewCNode( + {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); + return ret_graph->NewCNode( + {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item}); + }); + } else { + MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got " + << args_spec_list[index]->ToString(); + } + } + ret_graph->set_output(ret_graph->NewCNode(elems)); + return ret_graph; +} + +REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { + (void)py::class_>(*m, "UnpackCall_") + .def(py::init()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/unpack_call.h b/mindspore/ccsrc/frontend/operator/composite/unpack_call.h new file mode 100644 index 0000000000..79c2600f36 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/unpack_call.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" +#include "common/utils.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +// Expand the tuple and dict parameters generated when parsing the function call, +// and generate positional parameters and key-value pairs for function. +class UnpackCall : public MetaFuncGraph { + public: + explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} + ~UnpackCall() override = default; + MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } +}; +using UnpackCallPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/zip_operation.cc b/mindspore/ccsrc/frontend/operator/composite/zip_operation.cc new file mode 100644 index 0000000000..9e2b6d28b2 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/zip_operation.cc @@ -0,0 +1,92 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/composite/zip_operation.h" +#include + +#include "abstract/abstract_value.h" +#include "ir/anf.h" +#include "abstract/dshape.h" +#include "frontend/operator/cc_implementations.h" +#include "frontend/optimizer/opt.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractSequeue; +using mindspore::abstract::AbstractSequeuePtr; +using mindspore::abstract::AbstractTuple; + +FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // zip operation: + // input: tuple arguments + // output: tuple of items of input iterated on every input + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << "For 'zip', there is at least one input."; + } + + auto is_all_sequeue = + std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool { + MS_EXCEPTION_IF_NULL(abs); + return abs->isa(); + }); + if (!is_all_sequeue) { + MS_LOG(EXCEPTION) << "For 'zip', all inputs must be sequence."; + } + + auto min_abs = std::min_element( + args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &x, const AbstractBasePtr &y) { + return (x->cast()->size() < y->cast()->size()); + }); + FuncGraphPtr ret_graph = std::make_shared(); + ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + for (size_t idx = 0; idx < args_spec_list.size(); idx++) { + (void)ret_graph->add_parameter(); + } + + // generate tuple output of ziped arguments input + std::vector make_tuple_nodes; + make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t idx = 0; idx < (*min_abs)->cast()->size(); idx++) { + std::vector make_tuple_zip_nodes; + make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); + std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl"; + ValuePtr op = prim::GetPythonOps("getitem", module_name); + for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); arg_idx++) { + std::vector tuple_get_item_nodes{NewValueNode(op), ret_graph->parameters()[arg_idx], + NewValueNode(SizeToInt(idx))}; + auto tuple_get_item_op = ret_graph->NewCNode(tuple_get_item_nodes); + make_tuple_zip_nodes.push_back(tuple_get_item_op); + } + auto make_tuple_zip_op = ret_graph->NewCNode(make_tuple_zip_nodes); + make_tuple_nodes.push_back(make_tuple_zip_op); + } + ret_graph->set_output(ret_graph->NewCNode(make_tuple_nodes)); + return ret_graph; +} + +REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) { + (void)py::class_>(*m, + "ZipOperation_") + .def(py::init()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/zip_operation.h b/mindspore/ccsrc/frontend/operator/composite/zip_operation.h new file mode 100644 index 0000000000..96697cb472 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/zip_operation.h @@ -0,0 +1,59 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using AbstractBasePtr = abstract::AbstractBasePtr; +using AbstractBasePtrList = abstract::AbstractBasePtrList; +using AbstractTuplePtr = abstract::AbstractTuplePtr; + +class ZipOperation : public MetaFuncGraph { + public: + explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} + ~ZipOperation() override = default; + MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { + os << op.name_; + return os; + } + friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } +}; +using ZipOperationPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ diff --git a/mindspore/ccsrc/frontend/operator/ops.cc b/mindspore/ccsrc/frontend/operator/ops.cc new file mode 100755 index 0000000000..5c7672ee3c --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/ops.cc @@ -0,0 +1,288 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/ops.h" +#include +#include + +namespace mindspore { +// namespace to support primitive operators +namespace prim { +// Arithmetic +const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); +const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); +const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); +const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); +const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); +const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); +const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); +const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); +const PrimitivePtr kPrimScalarFloor = std::make_shared("scalar_floor"); +const PrimitivePtr kPrimScalarUadd = std::make_shared("scalar_uadd"); +const PrimitivePtr kPrimScalarUsub = std::make_shared("scalar_usub"); +const PrimitivePtr kPrimScalarExp = std::make_shared("scalar_exp"); +const PrimitivePtr kPrimScalarLog = std::make_shared("scalar_log"); +const PrimitivePtr kPrimScalarSin = std::make_shared("scalar_sin"); +const PrimitivePtr kPrimScalarCos = std::make_shared("scalar_cos"); +const PrimitivePtr kPrimScalarTan = std::make_shared("scalar_tan"); + +// Comparisons +const PrimitivePtr kPrimScalarEq = std::make_shared("scalar_eq"); +const PrimitivePtr kPrimScalarLt = std::make_shared("scalar_lt"); +const PrimitivePtr kPrimScalarGt = std::make_shared("scalar_gt"); +const PrimitivePtr kPrimScalarNe = std::make_shared("scalar_ne"); +const PrimitivePtr kPrimScalarLe = std::make_shared("scalar_le"); +const PrimitivePtr kPrimScalarGe = std::make_shared("scalar_ge"); +const PrimitivePtr kPrimBoolNot = std::make_shared("bool_not"); +const PrimitivePtr kPrimBoolAnd = std::make_shared("bool_and"); +const PrimitivePtr kPrimBoolOr = std::make_shared("bool_or"); +const PrimitivePtr kPrimBoolEq = std::make_shared("bool_eq"); +const PrimitivePtr kPrimGreater = std::make_shared("Greater"); +const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); +const PrimitivePtr kPrimLess = std::make_shared("Less"); +const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); +const PrimitivePtr kPrimEqual = std::make_shared("Equal"); +const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); + +// Type introspection +const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); +const PrimitivePtr kPrimHasType = std::make_shared("hastype"); + +// Statements +const PrimitivePtr kPrimSwitch = std::make_shared("switch"); +const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); +const PrimitivePtr kPrimReturn = std::make_shared("return"); +const PrimitivePtr kPrimAssign = std::make_shared("Assign"); +const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); +const PrimitivePtr kPrimAssignSub = std::make_shared("AssignSub"); +const PrimitivePtr kPrimSelect = std::make_shared("Select"); +const PrimitivePtr kPrimCall = std::make_shared("call"); + +const PrimitivePtr kPrimDistribute = std::make_shared("distribute"); +const PrimitivePtr kPrimDot = std::make_shared("dot"); +const PrimitivePtr kPrimIm2Col = std::make_shared("im2col"); +const PrimitivePtr kPrimCol2Im = std::make_shared("col2im"); +const PrimitivePtr kPrimIm2ColV1 = std::make_shared("im2col_v1"); +const PrimitivePtr kPrimCol2ImV1 = std::make_shared("col2im_v1"); + +const PrimitivePtr kPrimResolve = std::make_shared("resolve"); +const PrimitivePtr kPrimEmbed = std::make_shared("embed"); +const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); +const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); + +const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); +const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); +const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); + +// Structure +const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); +const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); +const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); +const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); +const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); +const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); +const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); +const PrimitivePtr kPrimMakeSlice = std::make_shared("make_slice"); +const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); +const PrimitivePtr kPrimTupleGetItem = std::make_shared("tuple_getitem"); +const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); +const PrimitivePtr kPrimArrayGetItem = std::make_shared("array_getitem"); +const PrimitivePtr kPrimTupleSetItem = std::make_shared("tuple_setitem"); +const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); +const PrimitivePtr kPrimArraySetItem = std::make_shared("array_setitem"); +const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); +const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); +const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); +const PrimitivePtr kPrimGetAttr = std::make_shared("getattr"); +const PrimitivePtr kPrimTupleLen = std::make_shared("tuple_len"); +const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); +const PrimitivePtr kPrimListLen = std::make_shared("list_len"); +const PrimitivePtr kPrimArrayLen = std::make_shared("array_len"); +const PrimitivePtr kPrimListMap = std::make_shared("list_map"); +const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); +const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); + +const PrimitivePtr kPrimTileShape = std::make_shared("tile_shape"); +const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); +const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); +const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); +const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); +const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared("generate_shape_index"); +const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared("generate_inverse_index"); +const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); +const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); +const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); +const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); + +// Arrays +const PrimitivePtr kPrimScalarToArray = std::make_shared("scalar_to_array"); +const PrimitivePtr kPrimArrayToScalar = std::make_shared("array_to_scalar"); +const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); +const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); +const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); +const PrimitivePtr kPrimShape = std::make_shared("Shape"); +const PrimitivePtr kPrimCast = std::make_shared("Cast"); +const PrimitivePtr kPrimConcat = std::make_shared("Concat"); +const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); +const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); +const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); +const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); +const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); +const PrimitivePtr kPrimSize = std::make_shared("Size"); +const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); +const PrimitivePtr kPrimPack = std::make_shared("Pack"); +const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); +const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); +const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); +const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); +const PrimitivePtr kPrimTile = std::make_shared("Tile"); +const PrimitivePtr kPrimAddN = std::make_shared("AddN"); +const PrimitivePtr KPrimTransData = std::make_shared("TransData"); +const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); +const PrimitivePtr kPrimPad = std::make_shared("Pad"); +const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); + +// Maths +const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); +const PrimitivePtr kPrimMatMul = std::make_shared("MatMul"); +const PrimitivePtr kPrimBatchMatMul = std::make_shared("BatchMatMul"); +const PrimitivePtr kPrimMaximumGrad = std::make_shared("MaximumGrad"); +const PrimitivePtr kPrimMinimumGrad = std::make_shared("MinimumGrad"); +const PrimitivePtr kPrimReduceMean = std::make_shared("ReduceMean"); +const PrimitivePtr kPrimReduceSum = std::make_shared("ReduceSum"); +const PrimitivePtr kPrimReduceAll = std::make_shared("ReduceAll"); +const PrimitivePtr kPrimReduceMax = std::make_shared("ReduceMax"); +const PrimitivePtr kPrimReduceMin = std::make_shared("ReduceMin"); +const PrimitivePtr kPrimNeg = std::make_shared("Neg"); +const PrimitivePtr kPrimSub = std::make_shared("Sub"); +const PrimitivePtr kPrimMul = std::make_shared("Mul"); +const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); +const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); +const PrimitivePtr kPrimSquare = std::make_shared("Square"); +const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); +const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); +const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); +const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); +const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); +const PrimitivePtr kPrimPow = std::make_shared("Pow"); +const PrimitivePtr kPrimRealDiv = std::make_shared("RealDiv"); +const PrimitivePtr kPrimSqrt = std::make_shared("Sqrt"); +const PrimitivePtr kPrimReciprocal = std::make_shared("Reciprocal"); +const PrimitivePtr kPrimExpandDims = std::make_shared("ExpandDims"); + +// NN +const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); +const PrimitivePtr kPrimSoftmax = std::make_shared("Softmax"); +const PrimitivePtr kPrimLogSoftmax = std::make_shared("LogSoftmax"); +const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared("LogSoftmaxGrad"); +const PrimitivePtr kPrimTanh = std::make_shared("Tanh"); +const PrimitivePtr kPrimTanhGrad = std::make_shared("TanhGrad"); +const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); +const PrimitivePtr kPrimPoolingGrad = std::make_shared("PoolingGrad"); +const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); +const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); +const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("ApplyCenteredRMSProp"); +const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); +const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); +const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D"); +const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared("FusedBatchNormGrad"); +const PrimitivePtr kPrimBatchNorm = std::make_shared("BatchNorm"); +const PrimitivePtr kPrimBatchNormGrad = std::make_shared("BatchNormGrad"); +const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad"); +const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared("Conv2DBackpropInput"); +const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv2DBackpropFilter"); +const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); +const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = + std::make_shared("DepthwiseConv2dNativeBackpropFilter"); +const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = + std::make_shared("DepthwiseConv2dNativeBackpropInput"); +const PrimitivePtr kPrimBiasAddGrad = std::make_shared("BiasAddGrad"); +const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared("SoftmaxCrossEntropyWithLogits"); +const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = + std::make_shared("SparseSoftmaxCrossEntropyWithLogits"); +const PrimitivePtr kPrimMomentum = std::make_shared("Momentum"); +const PrimitivePtr kPrimApplyMomentum = std::make_shared("ApplyMomentum"); +const PrimitivePtr kPrimLayerNorm = std::make_shared("LayerNorm"); +const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGrad"); +const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared("LayerNormXBackprop"); +const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); +const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); +const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); +const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); +const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); +const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); +const PrimitivePtr kPrimRelu = std::make_shared("ReLU"); +const PrimitivePtr kPrimReluV2 = std::make_shared("ReLUV2"); +const PrimitivePtr kPrimZerosLike = std::make_shared("ZerosLike"); +const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); +const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); +const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); +const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); +const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); + +// Other miscellaneous +const PrimitivePtr kPrimIdentity = std::make_shared("identity"); +const PrimitivePtr kPrimPartial = std::make_shared("Partial"); +const PrimitivePtr kPrimJ = std::make_shared("J"); +const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); +const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); +const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); +const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); +const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); +const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); +const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); +const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); +const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); +const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); +const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); +const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); +const PrimitivePtr kPrimPrint = std::make_shared("Print"); + +const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); +const PrimitivePtr kPrimDepend = std::make_shared("Depend"); +const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); + +const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); +const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); +const PrimitivePtr kPrimIs_ = std::make_shared("is_"); +const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); +const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); +const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_dict"); +const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); +const PrimitivePtr kPrimIsConsant = std::make_shared("is_constant"); +const PrimitivePtr kPrimEquivFormat = std::make_shared("EquivFormat"); + +// Comm ops +const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); +const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); +const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); +const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); + +// Debug ops +const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSummary"); +const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary"); +const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); +const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); +const PrimitivePtr kPrimDebug = std::make_shared("Debug"); + +// IndexedSlices +const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared("MakeIndexedSlices"); +const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared("IndexedSlicesGetValues"); +const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared("IndexedSlicesGetIndices"); +const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared("IndexedSlicesGetDenseShape"); +const PrimitivePtr kPrimIsIndexedSlices = std::make_shared("IsIndexedSlices"); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h new file mode 100755 index 0000000000..0dea045a6e --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/ops.h @@ -0,0 +1,336 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_OPS_H_ +#define MINDSPORE_CCSRC_OPERATOR_OPS_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/primitive.h" + +namespace mindspore { +// namespace to support primitive operators +namespace prim { +ValuePtr GetPythonOps(const std::string &op_name, + const std::string &module_name = "mindspore._extends.parse.standard_method", + bool use_signature = false); + +// Arithmetic +extern const PrimitivePtr kPrimScalarAdd; +extern const PrimitivePtr kPrimScalarSub; +extern const PrimitivePtr kPrimScalarMul; +extern const PrimitivePtr kPrimScalarDiv; +extern const PrimitivePtr kPrimScalarFloordiv; +extern const PrimitivePtr kPrimScalarMod; +extern const PrimitivePtr kPrimScalarPow; +extern const PrimitivePtr kPrimScalarTrunc; +extern const PrimitivePtr kPrimScalarFloor; +extern const PrimitivePtr kPrimScalarUadd; +extern const PrimitivePtr kPrimScalarUsub; +extern const PrimitivePtr kPrimScalarExp; +extern const PrimitivePtr kPrimScalarLog; +extern const PrimitivePtr kPrimScalarSin; +extern const PrimitivePtr kPrimScalarCos; +extern const PrimitivePtr kPrimScalarTan; + +// Comparisons +extern const PrimitivePtr kPrimScalarEq; +extern const PrimitivePtr kPrimScalarLt; +extern const PrimitivePtr kPrimScalarGt; +extern const PrimitivePtr kPrimScalarNe; +extern const PrimitivePtr kPrimScalarLe; +extern const PrimitivePtr kPrimScalarGe; +extern const PrimitivePtr kPrimBoolNot; +extern const PrimitivePtr kPrimBoolAnd; +extern const PrimitivePtr kPrimBoolOr; +extern const PrimitivePtr kPrimBoolEq; +extern const PrimitivePtr kPrimGreater; +extern const PrimitivePtr kPrimGreaterEqual; +extern const PrimitivePtr kPrimLess; +extern const PrimitivePtr kPrimLessEqual; +extern const PrimitivePtr kPrimEqual; +extern const PrimitivePtr kPrimNotEqual; + +// Type introspection +extern const PrimitivePtr kPrimTypeOf; +extern const PrimitivePtr kPrimHasType; + +// Statements +extern const PrimitivePtr kPrimSwitch; +extern const PrimitivePtr kPrimSwitchLayer; +extern const PrimitivePtr kPrimReturn; +extern const PrimitivePtr kPrimAssign; +extern const PrimitivePtr kPrimAssignAdd; +extern const PrimitivePtr kPrimAssignSub; +extern const PrimitivePtr kPrimSelect; +extern const PrimitivePtr kPrimCall; + +extern const PrimitivePtr kPrimDistribute; +extern const PrimitivePtr kPrimDot; +extern const PrimitivePtr kPrimIm2Col; +extern const PrimitivePtr kPrimCol2Im; +extern const PrimitivePtr kPrimIm2ColV1; +extern const PrimitivePtr kPrimCol2ImV1; + +extern const PrimitivePtr kPrimResolve; +extern const PrimitivePtr kPrimEmbed; +extern const PrimitivePtr kPrimRefToEmbed; +extern const PrimitivePtr kPrimCreateInstance; + +extern const PrimitivePtr kPrimLabelGoto; +extern const PrimitivePtr kPrimLabelSwitch; +extern const PrimitivePtr kPrimLabelSet; + +// Structure +extern const PrimitivePtr kPrimStringEqual; +extern const PrimitivePtr kPrimStringConcat; +extern const PrimitivePtr kPrimMakeTuple; +extern const PrimitivePtr kPrimMakeList; +extern const PrimitivePtr kPrimMakeDict; +extern const PrimitivePtr kPrimMakeKeywordArg; +extern const PrimitivePtr kPrimExtractKeywordArg; +extern const PrimitivePtr kPrimMakeSlice; +extern const PrimitivePtr kPrimMakeRecord; +extern const PrimitivePtr kPrimTupleGetItem; +extern const PrimitivePtr kPrimListGetItem; +extern const PrimitivePtr kPrimArrayGetItem; +extern const PrimitivePtr kPrimTupleSetItem; +extern const PrimitivePtr kPrimListSetItem; +extern const PrimitivePtr kPrimArraySetItem; +extern const PrimitivePtr kPrimDictGetItem; +extern const PrimitivePtr kPrimDictSetItem; +extern const PrimitivePtr kPrimListAppend; +extern const PrimitivePtr kPrimGetAttr; +extern const PrimitivePtr kPrimTupleLen; +extern const PrimitivePtr kPrimDictLen; +extern const PrimitivePtr kPrimListLen; +extern const PrimitivePtr kPrimArrayLen; +extern const PrimitivePtr kPrimListMap; +extern const PrimitivePtr kPrimListReduce; +extern const PrimitivePtr kPrimTupleReversed; +extern const PrimitivePtr kPrimTileShape; +extern const PrimitivePtr kPrimReducedShape; +extern const PrimitivePtr kPrimTupleDiv; +extern const PrimitivePtr kPrimTupleToArray; +extern const PrimitivePtr kPrimShapeMul; +extern const PrimitivePtr kPrimGenerateShapeIndex; +extern const PrimitivePtr kPrimGenerateInverseIndex; +extern const PrimitivePtr kPrimTupleEqual; +extern const PrimitivePtr kPrimListEqual; +extern const PrimitivePtr kPrimMakeRange; +extern const PrimitivePtr kPrimStopGradient; + +// Arrays +extern const PrimitivePtr kPrimScalarToArray; +extern const PrimitivePtr kPrimArrayToScalar; +extern const PrimitivePtr kPrimBroadcastShape; +extern const PrimitivePtr kPrimArrayMap; +extern const PrimitivePtr kPrimArrayReduce; +extern const PrimitivePtr kPrimShape; +extern const PrimitivePtr kPrimCast; +extern const PrimitivePtr kPrimConcat; +extern const PrimitivePtr kPrimSqueeze; +extern const PrimitivePtr kPrimTranspose; +extern const PrimitivePtr kPrimGatherV2; +extern const PrimitivePtr kPrimEmbeddingLookup; +extern const PrimitivePtr kPrimEmbeddingLookupCommGrad; +extern const PrimitivePtr kPrimSize; +extern const PrimitivePtr kPrimArgMax; +extern const PrimitivePtr kPrimPack; +extern const PrimitivePtr kPrimUnpack; +extern const PrimitivePtr kPrimUnsortedSegmentMin; +extern const PrimitivePtr kPrimUnsortedSegmentSum; +extern const PrimitivePtr kPrimConcatOffset; +extern const PrimitivePtr kPrimReshape; +extern const PrimitivePtr kPrimTile; +extern const PrimitivePtr kPrimAddN; +extern const PrimitivePtr KPrimTransData; +extern const PrimitivePtr kPrimNMSWithMask; +extern const PrimitivePtr kPrimPad; +extern const PrimitivePtr kPrimArgMaxWithValue; +extern const PrimitivePtr kPrimRealDiv; +extern const PrimitivePtr kPrimSqrt; +extern const PrimitivePtr kPrimReciprocal; +extern const PrimitivePtr kPrimExpandDims; + +// Maths +extern const PrimitivePtr kPrimTensorAdd; +extern const PrimitivePtr kPrimMatMul; +extern const PrimitivePtr kPrimBatchMatMul; +extern const PrimitivePtr kPrimMaximumGrad; +extern const PrimitivePtr kPrimMinimumGrad; +extern const PrimitivePtr kPrimReduceMean; +extern const PrimitivePtr kPrimReduceSum; +extern const PrimitivePtr kPrimReduceAll; +extern const PrimitivePtr kPrimReduceMax; +extern const PrimitivePtr kPrimReduceMin; +extern const PrimitivePtr kPrimNeg; +extern const PrimitivePtr kPrimSub; +extern const PrimitivePtr kPrimMul; +extern const PrimitivePtr kPrimRealDiv; +extern const PrimitivePtr kPrimMinimum; +extern const PrimitivePtr kPrimMaximum; +extern const PrimitivePtr kPrimSquare; +extern const PrimitivePtr kPrimSqrt; +extern const PrimitivePtr kPrimEqual; +extern const PrimitivePtr kPrimLess; +extern const PrimitivePtr kPrimLessEqual; +extern const PrimitivePtr kPrimCumSum; +extern const PrimitivePtr kPrimCumProd; +extern const PrimitivePtr kPrimSubscalar; +extern const PrimitivePtr kPrimInplaceAdd; +extern const PrimitivePtr kPrimInplaceSub; +extern const PrimitivePtr kPrimPow; + +// NN +extern const PrimitivePtr kPrimFlatten; +extern const PrimitivePtr kPrimSoftmax; +extern const PrimitivePtr kPrimLogSoftmax; +extern const PrimitivePtr kPrimLogSoftmaxGrad; +extern const PrimitivePtr kPrimApplyCenteredRMSProp; +extern const PrimitivePtr kPrimTanh; +extern const PrimitivePtr kPrimTanhGrad; +extern const PrimitivePtr kPrimPooling; +extern const PrimitivePtr kPrimPoolingGrad; +extern const PrimitivePtr kPrimFusedBatchNorm; +extern const PrimitivePtr kPrimBatchNorm; +extern const PrimitivePtr kPrimBatchNormGrad; +extern const PrimitivePtr kPrimConv2D; +extern const PrimitivePtr kPrimMaxPool; +extern const PrimitivePtr kPrimMaxPoolGrad; +extern const PrimitivePtr kPrimAvgPoolGrad; +extern const PrimitivePtr kPrimFusedBatchNormGrad; +extern const PrimitivePtr kPrimReluGrad; +extern const PrimitivePtr kPrimConv2DBackpropInput; +extern const PrimitivePtr kPrimConv2DBackpropFilter; +extern const PrimitivePtr kPrimDepthwiseConv2dNative; +extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter; +extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput; + +extern const PrimitivePtr kPrimBiasAddGrad; +extern const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits; +extern const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits; +extern const PrimitivePtr kPrimMomentum; +extern const PrimitivePtr kPrimApplyMomentum; +extern const PrimitivePtr kPrimLayerNorm; +extern const PrimitivePtr kPrimLayerNormGrad; +extern const PrimitivePtr kPrimLayerNormXBackprop; +extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop; +extern const PrimitivePtr kPrimDropoutGenMask; +extern const PrimitivePtr kPrimDropoutDoMask; +extern const PrimitivePtr kPrimOneHot; +extern const PrimitivePtr kPrimGelu; +extern const PrimitivePtr kPrimGeluGrad; +extern const PrimitivePtr kPrimRelu; +extern const PrimitivePtr kPrimReluV2; +extern const PrimitivePtr kPrimActivation; +extern const PrimitivePtr kPrimZerosLike; +extern const PrimitivePtr kPrimFakeBprop; +extern const PrimitivePtr kPrimBpropCut; +extern const PrimitivePtr kPrimFakeQuantPerLayer; +extern const PrimitivePtr kPrimFakeQuantPerChannel; +extern const PrimitivePtr kPrimApplyRMSProp; + +// Other Miscellaneous +extern const PrimitivePtr kPrimIdentity; +extern const PrimitivePtr kPrimPartial; +extern const PrimitivePtr kPrimJ; +extern const PrimitivePtr kPrimEnvSetItem; +extern const PrimitivePtr kPrimEnvGetItem; +extern const PrimitivePtr kPrimEnvAdd; +extern const PrimitivePtr kPrimMakeRefKey; +extern const PrimitivePtr kPrimMakeRef; +extern const PrimitivePtr kPrimGetRefKey; +extern const PrimitivePtr kPrimGetRefValue; +extern const PrimitivePtr kPrimGetRefOrigin; +extern const PrimitivePtr kPrimInsertGradientOf; +extern const PrimitivePtr kPrimHookBackward; +extern const PrimitivePtr kPrimPrintShapeType; +extern const PrimitivePtr kPrimPrint; +extern const PrimitivePtr kPrimSameTypeShape; +extern const PrimitivePtr kPrimCheckBprop; +extern const PrimitivePtr kPrimDepend; +extern const PrimitivePtr kPrimStateSetItem; +extern const PrimitivePtr kPrimScalarSummary; +extern const PrimitivePtr kPrimImageSummary; +extern const PrimitivePtr kPrimTensorSummary; +extern const PrimitivePtr kPrimHistogramSummary; +extern const PrimitivePtr kPrimBroadcastGradientArgs; +extern const PrimitivePtr kPrimControlDepend; +extern const PrimitivePtr kPrimIs_; +extern const PrimitivePtr kPrimIsNot; +extern const PrimitivePtr kPrimInDict; +extern const PrimitivePtr kPrimNotInDict; +extern const PrimitivePtr kPrimMixedPrecisionCast; +extern const PrimitivePtr kPrimIsConsant; +extern const PrimitivePtr kPrimEquivFormat; +extern const PrimitivePtr kPrimDebug; + +// Comm ops +extern const PrimitivePtr kPrimAllReduce; +extern const PrimitivePtr kPrimMirror; +extern const PrimitivePtr kPrimVirtualDiv; +extern const PrimitivePtr kPrimVirtualDataset; + +// IndexedSlices +extern const PrimitivePtr kPrimMakeIndexedSlices; +extern const PrimitivePtr kPrimIndexedSlicesGetValues; +extern const PrimitivePtr kPrimIndexedSlicesGetIndices; +extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape; +extern const PrimitivePtr kPrimIsIndexedSlices; + +// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll +const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; +// max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it +// will be sunk(i.e. not unrolled) +const int MAX_FOR_LOOP_COUNT = 600; + +class DoSignaturePrimitive : public Primitive { + public: + explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) + : Primitive("S-Prim-" + name), function_(function) {} + + ~DoSignaturePrimitive() override = default; + + MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) + + const ValuePtr function() const { return function_; } + + private: + ValuePtr function_; +}; +using DoSignaturePrimitivePtr = std::shared_ptr; + +class UnpackGraphPrimitive : public Primitive { + public: + explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) + : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} + ~UnpackGraphPrimitive() override = default; + MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) + bool with_sens_in_args() const { return with_sens_in_args_; } + bool need_unpack_args() const { return need_unpack_args_; } + + private: + bool with_sens_in_args_; + bool need_unpack_args_; +}; +using UnpackGraphPrimitivePtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_OPS_H_ diff --git a/mindspore/ccsrc/frontend/operator/ops_extends.cc b/mindspore/ccsrc/frontend/operator/ops_extends.cc new file mode 100755 index 0000000000..c406682c3e --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/ops_extends.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/ops.h" +#include +#include +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/data_converter.h" + +namespace mindspore { +// namespace to support primitive operators +namespace prim { +ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) { + py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); + ValuePtr node = nullptr; + bool succ = parse::ConvertData(obj, &node, use_signature); + if (!succ) { + MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail"; + } + return node; +} +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_arrays.cc b/mindspore/ccsrc/frontend/operator/prim_arrays.cc new file mode 100644 index 0000000000..caaf1d1b2a --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_arrays.cc @@ -0,0 +1,170 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" +#include "abstract/utils.h" +#include "frontend/operator/cc_implementations.h" +#include "abstract/param_validator.h" + +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a scalar. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractScalarPtr arg = CheckArg(op_name, args_spec_list, 0); + return std::make_shared(arg, std::make_shared()); +} + +AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor with 0 shape. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto arg = CheckArg(op_name, args_spec_list, 0); + auto a_shp = arg->shape(); + if (!a_shp->shape().empty()) { + MS_LOG(EXCEPTION) << "array_to_scalar requires zero size shape."; + } + return arg->element(); +} + +AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tuples. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto xs = CheckArg(op_name, args_spec_list, 0); + auto ys = CheckArg(op_name, args_spec_list, 1); + + auto value_tuple_x = xs->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(value_tuple_x); + auto shp_tuple_x = value_tuple_x->value(); + std::vector shp_x; + (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x), + [](const ValuePtr &e) -> int { return GetValue(e); }); + + auto value_tuple_y = ys->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(value_tuple_y); + auto shp_tuple_y = value_tuple_y->value(); + std::vector shp_y; + (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y), + [](const ValuePtr &e) -> int { return GetValue(e); }); + + std::vector res = prim::BroadcastShape_(shp_x, shp_y); + if (res.empty()) { + MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," + << args_spec_list[1]->ToString(); + } + + AbstractBasePtrList elems; + (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int n) -> AbstractBasePtr { + return std::make_shared(std::make_shared(n), kInt32); + }); + + return std::make_shared(elems); +} + +AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTensorPtr arg = CheckArg(op_name, args_spec_list, 0); + MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString(); + + AbstractBasePtrList values; + auto shp = arg->shape(); + for (int entry : shp->shape()) { + auto entry_v = MakeValue(entry); + values.push_back(std::make_shared(entry_v, entry_v->type())); + } + return std::make_shared(values); +} + +AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor and a tuple. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto arg = CheckArg(op_name, args_spec_list, 0); + auto multiples = CheckArg(op_name, args_spec_list, 1); + + ShapePtr input_shape = arg->shape(); + (void)CheckTensorDType(arg, {kInt16, kFloat16, kInt32, kFloat32}, "Input 0 of Tile should be %s"); + + auto mul_shp_value = multiples->BuildValue(); + if (mul_shp_value->isa()) { + MS_LOG(EXCEPTION) << "shape's data field can't be anything: " << args_spec_list[1]->ToString(); + } + + std::vector mul_shp; + auto value_tuple_mul = mul_shp_value->cast(); + auto mul_shp_data = value_tuple_mul->value(); + (void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp), + [](const ValuePtr &e) -> int { return GetValue(e); }); + if (input_shape->shape().size() != mul_shp_data.size()) { + MS_LOG(EXCEPTION) << "Tile requires input and multiples size equal, while the input size is " + << input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << "."; + } + + std::vector result_shp; + for (size_t i = 0; i < mul_shp_data.size(); ++i) { + result_shp.push_back(input_shape->shape()[i] * mul_shp[i]); + } + return std::make_shared(arg->element(), std::make_shared(result_shp)); +} + +AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple of tensor. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto arg = CheckArg(op_name, args_spec_list, 0); + if (arg->elements().empty()) { + MS_LOG(EXCEPTION) << "Arg elements is empty."; + } + + size_t tuple_len = arg->elements().size(); + AbstractTensorPtr tensor_base = CheckArg(op_name, arg->elements(), 0); + int rank_base = SizeToInt(tensor_base->shape()->shape().size()); + + ValuePtr axis = primitive->GetAttr("axis"); + // Axis value should be in [-(rank_base + 1), rank_base). + int axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base); + // If axis is negative, add offset(rank_base + 1) to turn it to positive. + axis_value = GetPositiveAxis(axis_value, IntToSize(rank_base + 1)); + + for (size_t i = 1; i < tuple_len; ++i) { + AbstractTensorPtr tensor = CheckArg(op_name, arg->elements(), i); + (void)CheckDtypeSame(op_name, tensor_base, tensor); + (void)CheckShapeSame(op_name, tensor_base, tensor); + } + + primitive->set_attr("N", MakeValue(SizeToInt(tuple_len))); + primitive->set_attr("T", tensor_base->element()->BuildType()); + + AbstractTensorPtr ret = dyn_cast(tensor_base->Broaden()); + MS_EXCEPTION_IF_NULL(ret); + auto shape = ret->shape()->shape(); + (void)shape.insert(shape.begin() + axis_value, tuple_len); + ret->set_shape(std::make_shared(shape)); + return ret; +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_debug.cc b/mindspore/ccsrc/frontend/operator/prim_debug.cc new file mode 100644 index 0000000000..718dadf5c1 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_debug.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/param_validator.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" +#include "abstract/utils.h" +#include "utils/symbolic.h" + +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor(value) + const std::string op_name = primitive->name(); + + CheckArgsSize(op_name, args_spec_list, 1); + auto tensor_value = CheckArg(op_name, args_spec_list, 0); + + int tensor_rank = SizeToInt(tensor_value->shape()->shape().size()); + if (tensor_rank == 0) { + MS_LOG(EXCEPTION) << op_name << " summary evaluator second arg should be an tensor, but got a scalar, rank is 0"; + } + + return std::make_shared(AbstractBasePtrList({tensor_value->Broaden()})); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_maths.cc b/mindspore/ccsrc/frontend/operator/prim_maths.cc new file mode 100644 index 0000000000..e4543a3821 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_maths.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" +#include "abstract/utils.h" +#include "abstract/param_validator.h" +#include "common/utils.h" + +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: three tensors. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + auto input_x = CheckArg(op_name, args_spec_list, 0); + auto input_y = CheckArg(op_name, args_spec_list, 1); + auto dout = CheckArg(op_name, args_spec_list, 2); + (void)CheckTensorsDTypeSame({input_x, input_y, dout}, {kInt, kUInt, kFloat}, + op_name + "evaluator three inputs should be %s"); + + AbstractBasePtr dx = input_x->Broaden(); + AbstractBasePtr dy = input_y->Broaden(); + + return std::make_shared(AbstractBasePtrList({dx, dy})); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_nn.cc b/mindspore/ccsrc/frontend/operator/prim_nn.cc new file mode 100644 index 0000000000..96c86d815d --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_nn.cc @@ -0,0 +1,432 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" +#include "abstract/utils.h" +#include "abstract/param_validator.h" + +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTensorPtr input_tensor = CheckArg(op_name, args_spec_list, 0); + (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input 0 of Pooling should be %s"); + + ShapePtr input_shape = dyn_cast(input_tensor->GetShapeTrack()); // NCHW + MS_EXCEPTION_IF_NULL(input_shape); + if (input_shape->shape().size() != 4) { + MS_LOG(EXCEPTION) << "Pooling input should be a 4-D tensor."; + } + int h_input = input_shape->shape()[2]; + int w_input = input_shape->shape()[3]; + + int window = primitive->GetAttr("window")->cast()->value(); + int stride = primitive->GetAttr("stride")->cast()->value(); + int padding = primitive->GetAttr("pad")->cast()->value(); + int nan_opt = primitive->GetAttr("nan_opt")->cast()->value(); + int data_mode = primitive->GetAttr("data_mode")->cast()->value(); + int ceil_mode = primitive->GetAttr("ceil_mode")->cast()->value(); + + if (stride <= 0) { + MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should greater then 0"; + } + if (nan_opt != 0) { + MS_LOG(EXCEPTION) << "Invalid nan_opt value: " << nan_opt << ", should be 0"; + } + if (data_mode != 1) { + MS_LOG(EXCEPTION) << "Invalid data_mode value: " << data_mode << ", should be 1"; + } + if (ceil_mode != 0) { + MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0"; + } + + std::set available_pad_mode{"pad", "same", "valid"}; + auto pad_mode_ptr = primitive->GetAttr("pad_mode"); + if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa()) { + auto pad_mode = pad_mode_ptr->cast()->value(); + if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) { + MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid"; + } + if (pad_mode == "valid") { + padding = 0; + } else if (pad_mode == "same") { + padding = (window - 1) / 2; + } + } + + std::set available_mode{"max", "avg"}; + auto mode_ptr = primitive->GetAttr("mode"); + if ((mode_ptr != nullptr) && mode_ptr->isa()) { + auto mode = mode_ptr->cast()->value(); + if (available_mode.find(mode) == available_mode.end()) { + MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << "."; + } + } + + int h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1; + int w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1; + std::vector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out}; + AbstractBasePtr ret = input_tensor->Broaden(); + ret->set_shape(std::make_shared(shape_out)); + return ret; +} + +AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: three tensors(y, dy, x). + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + auto out_y = CheckArg(op_name, args_spec_list, 0); + auto d_out = CheckArg(op_name, args_spec_list, 1); + auto input_x = CheckArg(op_name, args_spec_list, 2); + (void)CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat}, + op_name + "evaluator three inputs should be %s"); + + AbstractBasePtr ret = d_out->Broaden(); + auto x_shape = dyn_cast(args_spec_list[2]->GetShapeTrack()); + MS_EXCEPTION_IF_NULL(x_shape); + + ret->set_shape(x_shape); + return ret; +} + +void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { + // check dimension, x > 1, others equal 1 + const std::string op_name = primitive->name(); + for (std::size_t i = 0; i < args_spec_list.size(); ++i) { + AbstractTensorPtr arg = CheckArg(op_name, args_spec_list, i); + ShapePtr arg_shape = dyn_cast(arg->GetShapeTrack()); + if (arg_shape == nullptr) { + MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString(); + } + + if (i == 0) { + if (arg_shape->shape().size() < 2) { + MS_LOG(EXCEPTION) << op_name << " shape of args[" << i + << "] should be TensorShape with dimension greater than 1, but shape: " + << arg_shape->ToString(); + } + continue; + } + + if (arg_shape->shape().size() != 1) { + MS_LOG(EXCEPTION) << op_name << " shape of args[" << i + << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString(); + } + } +} + +AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: five tensors(x, gamma, beta, mean, variance). + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 5); + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + MS_LOG(DEBUG) << "InferImplFusedBatchNorm args0:" << args_spec_list[0]->ToString() + << ", arg1:" << args_spec_list[1]->ToString(); + FusedBatchNormCheckDim(primitive, args_spec_list); + + auto input = args_spec_list[0]; + auto input_shape = dyn_cast(input->GetShapeTrack()); + MS_EXCEPTION_IF_NULL(input_shape); + const auto &input_shape_list = input_shape->shape(); + if (input_shape_list.size() < 2) { + MS_LOG(EXCEPTION) << "Input shape size should >= 2."; + } + + for (size_t i = 1; i < args_spec_list.size(); ++i) { + auto arg_shape = dyn_cast(args_spec_list[i]->GetShapeTrack()); + MS_EXCEPTION_IF_NULL(arg_shape); + const auto &arg_shape_list = arg_shape->shape(); + if (arg_shape_list.size() < 1) { + MS_LOG(EXCEPTION) << "Arg shape size should >= 1."; + } + if (arg_shape_list[0] != input_shape_list[1]) { + MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0] + << ") should match the second dimension of tensor" + " param[0](which is " + << input_shape_list[1] << ")."; + } + } + auto input_tensor = CheckArg(op_name, args_spec_list, 0); + (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param 0 of FusedBatchNorm should be %s"); + + AbstractTensorPtrList tensorPtrList = std::vector(); + for (size_t i = 1; i < args_spec_list.size(); ++i) { + auto param = CheckArg(op_name, args_spec_list, i); + tensorPtrList.push_back(param); + } + (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, "param 1 to 4 of FusedBatchNorm should be %s"); + + // check validity; + auto epsilon_value = primitive->GetAttr("epsilon"); + auto momentum_value = primitive->GetAttr("momentum"); + MS_EXCEPTION_IF_NULL(epsilon_value); + MS_EXCEPTION_IF_NULL(momentum_value); + if (!epsilon_value->isa() || !momentum_value->isa()) { + MS_LOG(EXCEPTION) << "expect epsilon and momentum be float, but: epsilon: " << epsilon_value->ToString() + << ", momentum: " << momentum_value->ToString(); + } + + auto epsilon = epsilon_value->cast()->value(); + auto momentum = momentum_value->cast()->value(); + + if (epsilon > 1.0f || epsilon <= 0.0f) { + MS_LOG(EXCEPTION) << "expect epsilon is greater than 0 and less or equal than 1, but epsilon: " << epsilon; + } + if (momentum > 1.0f || momentum < 0.0f) { + MS_LOG(EXCEPTION) << "expect momentum is great or equal than 0 and less or equal than 1, but epsilon: " << momentum; + } + + // Outputs: y, running_mean, running_variance, save_mean, save_inv_variance. + AbstractBasePtr y = input->Broaden(); + AbstractBasePtr other = args_spec_list[1]->Broaden(); + MS_LOG(DEBUG) << "output y: " << y->ToString() << ", other: " << other->ToString(); + + AbstractBasePtrList elements = {y, other, other, other, other}; + return std::make_shared(elements); +} + +AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). + MS_EXCEPTION_IF_NULL(args_spec_list[1]); + MS_EXCEPTION_IF_NULL(args_spec_list[2]); + MS_EXCEPTION_IF_NULL(args_spec_list[3]); + + CheckArgsSize(primitive->name(), args_spec_list, 5); + auto dx = args_spec_list[1]->Broaden(); + auto dscale = args_spec_list[2]->Broaden(); + auto dbias = args_spec_list[3]->Broaden(); + + AbstractBasePtrList rets = {dx, dscale, dbias}; + return std::make_shared(rets); +} + +AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors(y_backprop, x). + CheckArgsSize(primitive->name(), args_spec_list, 2); + return args_spec_list[1]->Broaden(); +} + +AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: three tensors(doutput, input, filters). + CheckArgsSize(primitive->name(), args_spec_list, 3); + return args_spec_list[1]->Broaden(); +} + +AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: three tensors(inputs, filter, doutput). + CheckArgsSize(primitive->name(), args_spec_list, 3); + return args_spec_list[2]->Broaden(); +} + +AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: at least one tensor(y_backprop) + // Outputs: dbias + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << primitive->name() << " evaluator at least has 1 parameters, while the input size is " + << args_spec_list.size() << "."; + } + + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + ShapePtr shape_y = dyn_cast(args_spec_list[0]->GetShapeTrack()); + MS_EXCEPTION_IF_NULL(shape_y); + std::vector y_dims = shape_y->shape(); + if (y_dims.size() < 2) { + MS_LOG(EXCEPTION) << primitive->name() << " input y backprop, dim should >= 2, while " << y_dims.size() << "."; + } + std::vector bias_dims = {y_dims[1]}; + ShapePtr ret_shape = std::make_shared(bias_dims); + AbstractBasePtr ret = args_spec_list[0]->Broaden(); + ret->set_shape(ret_shape); + return ret; +} + +AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor. + CheckArgsSize(primitive->name(), args_spec_list, 1); + return args_spec_list[0]->Broaden(); +} + +AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor. + CheckArgsSize(primitive->name(), args_spec_list, 1); + return args_spec_list[0]->Broaden(); +} + +AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor. + CheckArgsSize(primitive->name(), args_spec_list, 1); + return args_spec_list[0]->Broaden(); +} + +AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor. + AbstractBasePtrList args_list; + for (size_t i = 0; i < args_spec_list.size() - 2; i++) { + args_list.push_back(args_spec_list[i]->Broaden()); + } + return std::make_shared(args_list); +} + +AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: three tensors(x, gamma, beta). + // outputs: y, mean, variance + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + auto input_x = CheckArg(op_name, args_spec_list, 0); + auto input_shape = input_x->shape(); + auto const &input_shape_list = input_shape->shape(); + const size_t input_rank = input_shape_list.size(); + if (input_rank == 0) { + MS_LOG(EXCEPTION) << "input_rank should not be zero"; + } + + // begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1 + ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis"); + int begin_norm_axis = CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1); + + ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis"); + int begin_params_axis = CheckAxis(op_name, bpa_ptr, -1, SizeToInt(input_rank) - 1); + begin_params_axis = GetPositiveAxis(begin_params_axis, input_rank); + + // the beta and gama shape should be x_shape[begin_params_axis:] + auto tensor = CheckArg(op_name, args_spec_list, 0); + auto gamma = CheckArg(op_name, args_spec_list, 1); + auto beta = CheckArg(op_name, args_spec_list, 2); + (void)CheckTensorDType(tensor, {kFloat16, kFloat32}, "input 0 of LayerNorm should be %s"); + (void)CheckTensorDType(gamma, {kFloat16, kFloat32}, "input 1 of LayerNorm should be %s"); + (void)CheckTensorDType(beta, {kFloat16, kFloat32}, "input 2 of LayerNorm should be %s"); + auto gamma_shape = dyn_cast(gamma->BuildShape()); + auto beta_shape = dyn_cast(beta->BuildShape()); + MS_EXCEPTION_IF_NULL(gamma_shape); + MS_EXCEPTION_IF_NULL(beta_shape); + + auto const &gamma_shape_list = gamma_shape->shape(); + auto const &beta_shape_list = beta_shape->shape(); + if (gamma_shape_list.empty() || beta_shape_list.empty()) { + MS_LOG(EXCEPTION) << "LayerNorm evaluator gamma or beta is a AbstractScalar that is not support."; + } + + size_t begin_params_axis_u = IntToSize(begin_params_axis); + if ((begin_params_axis_u > input_shape_list.size()) || + (gamma_shape_list.size() + begin_params_axis_u < input_shape_list.size()) || + (beta_shape_list.size() + begin_params_axis_u < input_shape_list.size())) { + MS_LOG(EXCEPTION) << "Gamma and beta shape get wrong size."; + } + for (size_t i = begin_params_axis_u; i < input_shape_list.size(); ++i) { + size_t gamma_beta_shape_dim = i - begin_params_axis_u; + if ((gamma_shape_list[gamma_beta_shape_dim] != input_shape_list[i]) || + (beta_shape_list[gamma_beta_shape_dim] != input_shape_list[i])) { + MS_LOG(EXCEPTION) << "Gamma or beta shape not match input shape, input_shape=" << input_shape->ToString() + << ", gamma_shape=" << gamma_shape->ToString() << ", beta_shape=" << beta_shape->ToString(); + } + } + + auto mean_var_shape_value = input_shape->shape(); + if (begin_norm_axis == -1) { + mean_var_shape_value[input_rank - 1] = 1; + } else { + for (size_t i = begin_norm_axis; i < input_rank; ++i) { + mean_var_shape_value[i] = 1; + } + } + + auto mean = input_x->Broaden(); + mean->set_shape(std::make_shared(mean_var_shape_value)); + auto var = input_x->Broaden(); + var->set_shape(std::make_shared(mean_var_shape_value)); + + AbstractBasePtrList args_list({input_x->Broaden(), mean, var}); + return std::make_shared(args_list); +} + +AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: five tensors(y_backprob, x, variance, mean, gamma). + // Outputs: x_backprob, gamma_backprob, beta_backprob + CheckArgsSize(primitive->name(), args_spec_list, 5); + + auto x_backprob = args_spec_list[0]->Broaden(); + auto gamma_backprob = args_spec_list[4]->Broaden(); + auto beta_backprob = args_spec_list[4]->Broaden(); + + AbstractBasePtrList args_list({x_backprob, gamma_backprob, beta_backprob}); + return std::make_shared(args_list); +} + +AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple and a tensor. + // Outputs: mask. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr x_shape = CheckArg(op_name, args_spec_list, 0); + AbstractTensorPtr keep_prob = CheckArg(op_name, args_spec_list, 1); + + TypePtr prob_type = keep_prob->element()->BuildType(); + if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) { + MS_LOG(EXCEPTION) << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString() + << "."; + } + + auto x_shape_data = x_shape->elements(); + int count = 1; + for (std::size_t i = 0; i < x_shape->size(); ++i) { + auto value_track = x_shape_data[i]->GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + if (!value_track->isa()) { + MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int32, but " << value_track->ToString() << "."; + } + + int e_value = GetValue(value_track); + if (e_value <= 0) { + MS_LOG(EXCEPTION) << "DropOutGenMask product of x_shape should be > 0"; + } + if (std::numeric_limits::max() / count / e_value < 1) { + MS_LOG(EXCEPTION) << "integer multiply integer overflow"; + } + count = count * e_value; + } + + // convert to bytes(8 bits) mask, using round up + int n128s = count / 128; + if ((count % 128) != 0) { + n128s++; + } + int bytes_count = n128s * 16; + std::vector shape_y{bytes_count}; + + primitive->set_attr("T", kInt32); + return std::make_shared(std::make_shared(kAnyValue, kUInt8), + std::make_shared(std::vector{shape_y})); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_others.cc b/mindspore/ccsrc/frontend/operator/prim_others.cc new file mode 100644 index 0000000000..530ad6a10c --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_others.cc @@ -0,0 +1,410 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "ir/dtype.h" +#include "common/utils.h" +#include "frontend/operator/ops.h" +#include "abstract/param_validator.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "abstract/utils.h" +#include "utils/context/ms_context.h" +#include "utils/symbolic.h" + +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // An object of a subclass of AbstractBase + CheckArgsSize(primitive->name(), args_spec_list, 1); + return args_spec_list[0]; +} + +AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // args: An object of AbstractFunction. + CheckArgsSize(primitive->name(), args_spec_list, 1); + MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString(); + + AbstractFunctionPtr x = dyn_cast(args_spec_list[0]); + if (x == nullptr) { + return std::make_shared(args_spec_list[0]); + } + + AbstractFuncAtomPtrList jv; + auto build_jv = [&jv](const AbstractFuncAtomPtr &func) { + auto j_closure = std::make_shared(func); + jv.push_back(j_closure); + }; + x->Visit(build_jv); + + return AbstractFunction::MakeAbstractFunction(jv); +} + +AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(primitive); + // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). + CheckArgsSize(primitive->name(), args_spec_list, 3); + auto key = args_spec_list[1]; + auto dflt = args_spec_list[2]; + TypePtr type = key->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(type); + if (type->type_id() != kObjectTypeSymbolicKeyType) { + MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); + } + + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (enable_sparse && dflt->isa()) { + auto dflt_tensor = dflt->cast(); + return std::make_shared(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); + } + + if (!key->GetValueTrack()->isa()) { + return dflt; + } + ValuePtr key_value_ptr = key->GetValueTrack(); + MS_EXCEPTION_IF_NULL(key_value_ptr); + auto key_value_track = key_value_ptr->cast(); + auto expected = key_value_track->abstract(); + MS_EXCEPTION_IF_NULL(expected); + (void)expected->Join(dflt); + return expected; +} + +AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). + CheckArgsSize(primitive->name(), args_spec_list, 3); + + auto key = args_spec_list[1]; + ValuePtr key_value_ptr = key->GetValueTrack(); + MS_EXCEPTION_IF_NULL(key_value_ptr); + auto key_value_track = key_value_ptr->cast(); + if (key_value_track == nullptr) { + MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: " + << key_value_ptr->ToString(); + } + auto expected = key_value_track->abstract(); + MS_EXCEPTION_IF_NULL(expected); + return std::make_shared(kAnyValue, std::make_shared()); +} + +AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). + CheckArgsSize(primitive->name(), args_spec_list, 2); + return std::make_shared(kAnyValue, std::make_shared()); +} + +AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) { + ValuePtr name_value = prim->GetAttr("tag"); + auto name = name_value->cast(); + if (name == nullptr) { + MS_LOG(EXCEPTION) << "MakeRefKey attr tag sould be a String " << name_value->ToString() << "."; + } + auto refkey = std::make_shared(name->value()); + if (refkey == nullptr) { + MS_LOG(EXCEPTION) << "MakeRefKey std::make_shared failed"; + } + return refkey->ToAbstract(); +} + +AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + // arguments: key, value, original value + if (args_spec_list.size() != 3) { + MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() + << "."; + } + TypePtr type = args_spec_list[0]->GetTypeTrack(); + if (type->type_id() != kObjectTypeRefKey) { + MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); + } + return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); +} + +AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + // arguments: value + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "get_ref_key requires 1 parameters, while the input size is " << args_spec_list.size() << "."; + } + TypePtr type = args_spec_list[0]->GetTypeTrack(); + if (type->type_id() != kObjectTypeRef) { + MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString(); + } + return args_spec_list[0]->cast()->ref(); +} + +AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + // arguments: value + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size() + << "."; + } + TypePtr type = args_spec_list[0]->GetTypeTrack(); + if (type->type_id() != kObjectTypeRef) { + MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); + } + return args_spec_list[0]->cast()->ref(); +} + +AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + // arguments: value + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size() + << "."; + } + TypePtr type = args_spec_list[0]->GetTypeTrack(); + if (type->type_id() != kObjectTypeRef) { + MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); + } + return args_spec_list[0]->cast()->ref_origin(); +} + +AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // args: Two objects of a subclass of AbstractBase, key and value. + CheckArgsSize(primitive->name(), args_spec_list, 2); + + TypePtr type = args_spec_list[0]->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(type); + if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) { + MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString(); + } + return std::make_shared(kAnyValue, kBool); +} + +AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0"; + } + auto depends = args_spec_list[0]->Broaden(); + return depends; +} + +bool CompareShape(const std::vector &x_shape, const std::vector &y_shape) { + if (x_shape.size() != y_shape.size()) { + return false; + } + + for (size_t i = 0; i < x_shape.size(); ++i) { + if (GetValue(x_shape[i]) != GetValue(y_shape[i])) { + return false; + } + } + + return true; +} + +enum State { + SAME, + X_ONE, + Y_ONE, +}; + +void ComputeReduceIndex(const std::vector &reverse_x, const std::vector &reverse_y, + std::vector *grad_x_reduce_idx, std::vector *grad_y_reduce_idy) { + const size_t n = reverse_x.size(); + for (size_t i = 0; i < n; ++i) { + State curr; + const int32_t x_i = reverse_x[i]; + const int32_t y_i = reverse_y[i]; + const int reduce_idx = SizeToInt(n - 1 - i); + if (x_i == y_i) { + curr = SAME; + } else if (x_i == 1) { + grad_x_reduce_idx->push_back(reduce_idx); + curr = X_ONE; + } else if (y_i == 1) { + grad_y_reduce_idy->push_back(reduce_idx); + curr = Y_ONE; + } else { + MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs"; + } + if (curr == SAME && x_i == 1) { + grad_x_reduce_idx->push_back(reduce_idx); + grad_y_reduce_idy->push_back(reduce_idx); + continue; + } + } + + std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); + std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); +} + +AbstractBasePtr BroadcastGradientArgsDiff(const std::vector &x_shape, const std::vector &y_shape) { + std::vector reverse_x; + std::vector reverse_y; + + (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x), + [](const ValuePtr &v) { return v->cast()->value(); }); + (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y), + [](const ValuePtr &v) { return v->cast()->value(); }); + + if (reverse_x.size() > reverse_y.size()) { + reverse_y.resize(reverse_x.size(), 1); + } else { + reverse_x.resize(reverse_y.size(), 1); + } + + std::vector grad_x_reduce_idx; + std::vector grad_y_reduce_idy; + ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); + + AbstractBasePtrList abs_list_x; + AbstractBasePtrList abs_list_y; + (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x), + [](int v) { return abstract::FromValue(v); }); + (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y), + [](int v) { return abstract::FromValue(v); }); + auto x_reduce_idx = std::make_shared(abs_list_x); + auto y_reduce_idx = std::make_shared(abs_list_y); + AbstractBasePtrList elem_list; + elem_list.push_back(x_reduce_idx); + elem_list.push_back(y_reduce_idx); + + return std::make_shared(elem_list); +} + +AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // this primitive get the index that need to reduce + // input: x's shape and y's shape, inputs should be tuple + // output: tuple of x and y 's reduce index, reduce index should be a tuple + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto arg_x = CheckArg(op_name, args_spec_list, 0); + auto arg_y = CheckArg(op_name, args_spec_list, 1); + + ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(arg_x_value); + + ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(arg_y_value); + + const std::vector x_shape = arg_x_value->value(); + const std::vector y_shape = arg_y_value->value(); + bool is_same_shape = CompareShape(x_shape, y_shape); + // if it is the same shape , do not need reduce , return empty tuple + if (is_same_shape) { + AbstractBasePtrList empty_list; + auto x_reduce_idx = std::make_shared(empty_list); + auto y_reduce_idx = std::make_shared(empty_list); + + AbstractBasePtrList elem_list; + elem_list.push_back(x_reduce_idx); + elem_list.push_back(y_reduce_idx); + + return std::make_shared(elem_list); + } + + return BroadcastGradientArgsDiff(x_shape, y_shape); +} + +AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // args: Two objects of a subclass of AbstractBase + CheckArgsSize(primitive->name(), args_spec_list, 2); + auto arg_src = args_spec_list[0]; + auto arg_dst = args_spec_list[1]; + // control depend can not setup tuple of ops to tuple of ops dependency relation + if (arg_src->isa() && arg_dst->isa()) { + auto src_size = arg_src->cast()->size(); + auto dst_size = arg_src->cast()->size(); + if (src_size > 1 && dst_size > 1) { + MS_LOG(EXCEPTION) << "Control depend can not setup operator dependcy relationship from tuple from tuple"; + } + } + return std::make_shared(kAnyValue, kBool); +} + +AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors and a tuple. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + auto indices = CheckArg(op_name, args_spec_list, 0); + auto values = CheckArg(op_name, args_spec_list, 1); + auto dense_shape = CheckArg(op_name, args_spec_list, 2); + + auto dense_shape_value = dense_shape->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(dense_shape_value); + auto shp = dense_shape_value->value(); + std::vector dense_shape_vec; + (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), + [](const ValuePtr &e) -> int { + auto elem = GetValue(e); + return elem; + }); + auto ret = std::make_shared(values->element()->BuildType(), dense_shape_vec); + ret->set_indices(indices); + ret->set_values(values); + ret->set_dense_shape(dense_shape); + return ret; +} + +AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors and a tuple. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto indexed_slices = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(indexed_slices->values()); + return indexed_slices->values(); +} + +AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors and a tuple. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto indexed_slices = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(indexed_slices->indices()); + return indexed_slices->indices(); +} + +AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors and a tuple. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto indexed_slices = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape()); + return indexed_slices->dense_shape(); +} + +AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + bool ret = false; + if (args_spec_list[0]->isa()) { + ret = true; + } + MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString(); + return std::make_shared(ret); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_statement.cc b/mindspore/ccsrc/frontend/operator/prim_statement.cc new file mode 100644 index 0000000000..bb421bdf8a --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_statement.cc @@ -0,0 +1,249 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/param_validator.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" +#include "abstract/utils.h" +#include "utils/symbolic.h" + +namespace mindspore { +namespace abstract { +AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a pointer to an AbstractBase object + if (args_spec_list.size() != 1) { + MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? " + "while the input size is " + << args_spec_list.size() << "."; + } + AbstractBasePtr abs_base = args_spec_list[0]; + return abs_base; +} + +AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a pointer to an AbstractBase object + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size() + << "."; + } + AbstractBasePtr abs_base = args_spec_list[0]; + MS_EXCEPTION_IF_NULL(abs_base); + TypePtr type = abs_base->BuildType(); + return std::make_shared(type); +} + +AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a pointer to an AbstractBase object and a pointer to a Type + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTypePtr abs_type = CheckArg(op_name, args_spec_list, 1); + + auto mode_v = abs_type->GetValueTrack(); + MS_EXCEPTION_IF_NULL(mode_v); + if (!mode_v->isa()) { + MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed."; + } + + TypePtr mode_t = mode_v->cast(); + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + bool v = IsSubtype(args_spec_list[0], mode_t); + return std::make_shared(std::make_shared(v), kBool); +} + +AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTensorPtr input_x = CheckArg(op_name, args_spec_list, 0); + AbstractTensorPtr input_y = CheckArg(op_name, args_spec_list, 1); + + ShapePtr x_shp = input_x->shape(); + auto x_shp_value = x_shp->shape(); + ShapePtr y_shp = input_y->shape(); + auto y_shp_value = y_shp->shape(); + // Should be matrix which shape size is 2. + if (x_shp_value.size() != 2 || y_shp_value.size() != 2) { + MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are " + << x_shp_value.size() << ", " << y_shp_value.size() << " "; + } + if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) { + MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}"; + } + + auto x_element = input_x->element(); + MS_EXCEPTION_IF_NULL(x_element); + (void)x_element->Join(input_y->element()); + auto param = {x_shp_value[0], y_shp_value[1]}; + + return std::make_shared(input_x->element(), std::make_shared(param)); +} + +AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim, + const AbstractBasePtrList &args_spec_list) { + // Inputs: condition, true branch, false branch + if (args_spec_list.size() != 3) { + MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size() + << "."; + } + + auto cond = args_spec_list[0]; + auto tb = args_spec_list[1]; + auto fb = args_spec_list[2]; + MS_EXCEPTION_IF_NULL(cond); + + auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG); + if (unroll_flag != nullptr && GetValue(unroll_flag) == 0) { + return tb->Join(fb); + } + + ValuePtr v = cond->GetValueTrack(); + MS_EXCEPTION_IF_NULL(v); + // for tensor as condition, keeps both true and false branch. + if (v->isa() || cond->isa()) { + MS_EXCEPTION_IF_NULL(tb); + return tb->Join(fb); + } + + if (v->isa()) { + if (v->cast()->IsOne()) { + return tb; + } else { + return fb; + } + } + + MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString(); +} + +AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: index, branch + const std::string op_name = primitive->name(); + abstract::CheckArgsSize(op_name, args_spec_list, 2); + (void)CheckArg(op_name, args_spec_list, 0); + AbstractTuplePtr branches_abs = CheckArg(op_name, args_spec_list, 1); + AbstractBasePtrList branches = branches_abs->elements(); + const size_t maximum_layer_num = 1000; + if (branches.size() < 0 || branches.size() > maximum_layer_num) { + MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got " + << branches.size() << " branches."; + } + + for (size_t i = 0; i < branches.size(); i++) { + MS_EXCEPTION_IF_NULL(branches[i]); + if (!branches[i]->isa()) { + MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got " + << branches[i]->ToString() << " as the " << i << "th element."; + } + } + + auto b = branches[0]; + for (size_t i = 1; i < branches.size(); i++) { + b = b->Join(branches[i]); + } + return b; +} + +std::vector GetSupportedTargetValue() { + std::vector list = {kNone, MakeValue(false), MakeValue(true)}; + return list; +} + +bool SupportedIsTargetValue(const ValuePtr t) { + auto list = GetSupportedTargetValue(); + auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; }); + return match; +} + +AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: x is t + // Inputs: x, t + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + ValuePtr t = args_spec_list[1]->BuildValue(); + if (!SupportedIsTargetValue(t)) { + MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString() + << " for statement is, supported list is:None, False, True "; + } + ValuePtr x = args_spec_list[0]->BuildValue(); + + return std::make_shared(*t == *x); +} + +AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: x is not t + // Inputs: x, t + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + ValuePtr t = args_spec_list[1]->BuildValue(); + if (!SupportedIsTargetValue(t)) { + MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString() + << " for statement is not, supported list is:None, False, True "; + } + ValuePtr x = args_spec_list[0]->BuildValue(); + + return std::make_shared(!(*t == *x)); +} + +bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto key = CheckArg(op_name, args_spec_list, 0); + auto dict = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + auto key_str = GetValue(key_value); + std::vector dict_elems = dict->elements(); + auto it = std::find_if(dict_elems.begin(), dict_elems.end(), + [key_str](const AbstractAttribute &item) { return item.first == key_str; }); + return it != dict_elems.end(); +} + +AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: x in t + // Inputs: x, t + return std::make_shared(IsInDict(primitive, args_spec_list)); +} + +AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: x not in t + // Inputs: x, t + return std::make_shared(!IsInDict(primitive, args_spec_list)); +} + +AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: isconstant(x) + // Inputs: x + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1"; + } + ValuePtr v = args_spec_list[0]->BuildValue(); + return std::make_shared(!v->isa()); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_structures.cc b/mindspore/ccsrc/frontend/operator/prim_structures.cc new file mode 100644 index 0000000000..b602b07a0c --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_structures.cc @@ -0,0 +1,712 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/prim.h" +#include "abstract/utils.h" +#include "abstract/param_validator.h" +#include "frontend/operator/ops.h" +#include "utils/convert_utils.h" +#include "ir/tensor_py.h" + +using mindspore::tensor::TensorPy; + +namespace mindspore { +namespace abstract { + +AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two scalars whose value is a string. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); + + ValuePtr value_x = scalar_x->BuildValue(); + ValuePtr value_y = scalar_y->BuildValue(); + if (!value_x->isa() || !value_y->isa()) { + MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() + << ", param1: " << value_y->ToString(); + } + + bool ret = (value_x->cast()->value() == value_y->cast()->value()); + return std::make_shared(ret); +} + +AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two scalars whose value is a string. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); + + ValuePtr value_x = scalar_x->BuildValue(); + ValuePtr value_y = scalar_y->BuildValue(); + if (!value_x->isa() || !value_y->isa()) { + MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() + << ", param1: " << value_y->ToString(); + } + + std::string ret = (value_x->cast()->value() + value_y->cast()->value()); + return std::make_shared(ret); +} + +AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + return std::make_shared(args_spec_list); +} + +AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + return std::make_shared(args_spec_list); +} + +AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tuples. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr keys = CheckArg(op_name, args_spec_list, 0); + AbstractTuplePtr values = CheckArg(op_name, args_spec_list, 1); + + size_t keys_size = keys->size(); + if (values->size() != keys_size) { + MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size"; + } + + std::vector key_value; + AbstractScalarPtr key; + AbstractBasePtrList key_list = keys->elements(); + AbstractBasePtrList value_list = values->elements(); + for (size_t index = 0; index < keys_size; index++) { + key = CheckArg(op_name + "key", key_list, index); + ValuePtr keyPtr = key->BuildValue(); + MS_EXCEPTION_IF_NULL(keyPtr); + if (!keyPtr->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); + } + std::string key_string = GetValue(keyPtr); + key_value.emplace_back(key_string, value_list[index]); + } + return std::make_shared(key_value); +} + +AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a string and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); + + ValuePtr keyPtr = key->BuildValue(); + if (!keyPtr->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); + } + std::string key_string = GetValue(keyPtr); + return std::make_shared(key_string, args_spec_list[1]); +} + +AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a string and a keyword. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); + AbstractKeywordArgPtr kwarg = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + std::string key_input = GetValue(key_value); + std::string key_actual = kwarg->get_key(); + if (key_actual != key_input) { + MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is " + << key_input << ", AbstractKeywordArg' key is " << key_actual; + } + return kwarg->get_arg(); +} + +AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: three scalars whose value is an int32 number. + CheckArgsSize(primitive->name(), args_spec_list, 3); + size_t args_size = args_spec_list.size(); + for (size_t index = 0; index < args_size; index++) { + MS_EXCEPTION_IF_NULL(args_spec_list[index]); + if (!args_spec_list[index]->isa() && !args_spec_list[index]->isa()) { + MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; + } + if (args_spec_list[index]->isa() && + !dyn_cast(args_spec_list[index])->BuildValue()->isa()) { + MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number."; + } + } + // Slice: start, end, step + return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); +} + +// Eval the return type of make_record +AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: at lease two objects of a subclass of AbstractBase. + if (args_spec_list.size() < 2) { + MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is " + << args_spec_list.size() << "."; + } + + // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + TypePtr type = args_spec_list[0]->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(type); + if (type->type_id() != kMetaTypeTypeType) { + MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType"; + } + + ValuePtr value_track = args_spec_list[0]->GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + TypePtr type_ptr = value_track->cast(); + if (type_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString(); + } + + auto cls = dyn_cast(type_ptr); + MS_EXCEPTION_IF_NULL(cls); + ClassAttrVector attributes = cls->GetAttributes(); + CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1); + + std::vector abs_attributes; + for (size_t i = 0; i < attributes.size(); i++) { + AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]); + abs_attributes.push_back(elem); + } + + return std::make_shared(cls->tag(), abs_attributes, cls->methods()); +} + +template +AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple or list and a scalar whose value is an int32 number. + CheckArgsSize(op_name, args_spec_list, 2); + auto queue = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); + + ValuePtr index_value = index->BuildValue(); + if (!index_value->isa()) { + // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element + // and continue + if (dyn_cast(queue->elements()[0]) != nullptr) { + return std::make_shared(queue->elements()[0]->BuildType()); + } + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " + << index_value->ToString(); + } + int idx_v = GetValue(index_value); + std::size_t nelems = queue->elements().size(); + if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " + << SizeToInt(nelems) << "), but got " << idx_v << "."; + } + + std::size_t uidx_v = 0; + if (idx_v >= 0) { + uidx_v = IntToSize(idx_v); + } else { + uidx_v = IntToSize(idx_v + SizeToInt(nelems)); + } + return queue->elements()[uidx_v]; +} + +template +AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase. + CheckArgsSize(op_name, args_spec_list, 3); + auto queue = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); + + ValuePtr index_value = index->BuildValue(); + if (!index_value->isa()) { + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " + << index_value->ToString(); + } + int idx_v = GetValue(index_value); + if (idx_v < 0) { + MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v + << "."; + } + + size_t uidx_v = IntToSize(idx_v); + AbstractBasePtrList elements = queue->elements(); + std::size_t nelems = elements.size(); + if (uidx_v >= nelems) { + MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 + << "."; + } + elements[uidx_v] = args_spec_list[2]; + return std::make_shared(elements); +} + +AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListGetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListGetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListSetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListSetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a dict and a scalar whose value is a string. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + auto key_str = GetValue(key_value); + std::vector dict_elems = dict->elements(); + auto it = std::find_if(dict_elems.begin(), dict_elems.end(), + [key_str](const AbstractAttribute &item) { return item.first == key_str; }); + + if (it == dict_elems.end()) { + MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); + } + return it->second; +} + +AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + std::string key_str = GetValue(key_value); + std::vector dict_elems = dict->elements(); + auto it = std::find_if(dict_elems.begin(), dict_elems.end(), + [key_str](AbstractAttribute &item) { return item.first == key_str; }); + + MS_EXCEPTION_IF_NULL(args_spec_list[2]); + auto new_ele = std::make_pair(key_str, args_spec_list[2]); + if (it != dict_elems.end()) { + int index = it - dict_elems.begin(); + dict_elems[IntToSize(index)] = new_ele; + } else { + dict_elems.push_back(new_ele); + } + return std::make_shared(dict_elems); +} + +AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a list and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractListPtr list = CheckArg(op_name, args_spec_list, 0); + (void)AbstractJoin(list->elements()); + return list; +} + +template +AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple or list or dict. + CheckArgsSize(op_name, args_spec_list, 1); + auto arg = CheckArg(op_name, args_spec_list, 0); + return std::make_shared(SizeToInt(arg->size())); +} + +AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + return std::make_shared(kAnyValue, kInt32); +} + +AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: fn, list1, list2, ... + MS_EXCEPTION_IF_NULL(engine); + if (args_spec_list.size() <= 1) { + MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << "."; + } + AbstractFunctionPtr fn = CheckArg(primitive->name(), args_spec_list, 0); + // check args from 1. + CheckArgsSpec(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end())); + + AbstractBasePtrList subargs; + for (std::size_t i = 1; i < args_spec_list.size(); i++) { + AbstractListPtr l_ptr = dyn_cast(args_spec_list[i]); + if (l_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list."; + } + subargs.push_back(AbstractJoin(l_ptr->elements())); + } + EvalResultPtr engin_exc = engine->Execute(fn, subargs); + AbstractBasePtrList result; + for (std::size_t i = 1; i < args_spec_list.size(); i++) { + result.push_back(engin_exc->abstract()); + } + return std::make_shared(result); +} + +AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a fn, a list and an object of a subclass of a AbstractBase. + MS_EXCEPTION_IF_NULL(engine); + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + AbstractFunctionPtr fn = CheckArg(op_name, args_spec_list, 0); + AbstractListPtr lst = CheckArg(op_name, args_spec_list, 1); + AbstractBasePtr dflt = args_spec_list[2]; + + AbstractBasePtr list_type = AbstractJoin(lst->elements()); + auto result1 = engine->Execute(fn, lst->elements()); + auto result2 = engine->Execute(fn, {dflt, list_type}); + MS_EXCEPTION_IF_NULL(result1->abstract()); + MS_EXCEPTION_IF_NULL(result2->abstract()); + return result1->abstract()->Join(result2->abstract()); +} + +AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTuplePtr input = CheckArg(op_name, args_spec_list, 0); + + auto tuple_elements = input->elements(); + AbstractBasePtrList elem_list; + (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list), + [](const AbstractBasePtr &elem) { return elem->Clone(); }); + return std::make_shared(elem_list); +} + +AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, + const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) { + size_t x_rank = x_shape->size(); + std::set axis_set; + auto axis_data = axis_value_ptr->value(); + if (axis_data.empty()) { + int size = 1; + AbstractBasePtrList values(x_rank, std::make_shared(size)); + return std::make_shared(values); + } + + for (auto &elem : axis_data) { + int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1); + (void)axis_set.insert(e_value); + } + + auto x_shp_data = x_shp_value->cast()->value(); + if (x_shp_data.size() < x_rank) { + MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank; + } + AbstractBasePtrList values; + for (size_t i = 0; i < x_rank; i++) { + if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) { + auto axis_v = MakeValue(1); + values.push_back(std::make_shared(axis_v, axis_v->type())); + } else { + int dim_value = x_shp_data[i]->cast()->value(); + auto dim = MakeValue(dim_value); + values.push_back(std::make_shared(dim, dim->type())); + } + } + + return std::make_shared(values); +} + +AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: x_shape, axis + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr shape_x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(args_spec_list[1]); + + auto x_shp_value = shape_x->BuildValue(); + if (x_shp_value->isa()) { + MS_LOG(EXCEPTION) << op_name + << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); + } + + // Axis can be scalar, tuple or None + AbstractTuplePtr axis = nullptr; + if (args_spec_list[1]->isa()) { + MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar"; + AbstractBasePtrList axis_list = {dyn_cast(args_spec_list[1])}; + axis = std::make_shared(axis_list); + } else if (args_spec_list[1]->isa()) { + MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple"; + axis = args_spec_list[1]->cast(); + } else { + MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got " + << args_spec_list[1]->ToString(); + } + + auto axis_value = axis->BuildValue(); + if (axis_value->isa()) { + MS_LOG(EXCEPTION) << op_name + << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); + } + auto axis_value_ptr = axis_value->cast(); + MS_EXCEPTION_IF_NULL(axis_value_ptr); + + return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive); +} + +AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tuples. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr shape_x = CheckArg(op_name, args_spec_list, 0); + AbstractTuplePtr div_shp = CheckArg(op_name, args_spec_list, 1); + MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString(); + + auto div_shp_value = div_shp->BuildValue(); + if (div_shp_value->isa()) { + MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString(); + } + + auto shpx_value = shape_x->BuildValue(); + if (shpx_value->isa()) { + MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString(); + } + + if (div_shp->size() != shape_x->size()) { + MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size() + << ", shapex: " << shape_x->size() << "."; + } + + auto shpx_data = shpx_value->cast()->value(); + auto div_shp_data = div_shp_value->cast()->value(); + AbstractBasePtrList values; + + for (size_t i = 0; i < div_shp_data.size(); i++) { + if (div_shp_data[i]->cast() == nullptr) { + MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString(); + } + int shapex_value = GetValue(shpx_data[i]); + int div_value = GetValue(div_shp_data[i]); + MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value; + if (div_value == 0) { + MS_LOG(EXCEPTION) << "error: division value should not be 0!"; + } + if ((shapex_value % div_value) != 0) { + MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value; + } + + int result = shapex_value / div_value; + auto result_v = MakeValue(result); + values.push_back(std::make_shared(result_v, result_v->type())); + } + + return std::make_shared(values); +} + +AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTuplePtr input = CheckArg(op_name, args_spec_list, 0); + + py::tuple data_tuple = ValuePtrToPyData(input->BuildValue()); + py::array data = py::array(data_tuple); + auto tensor = TensorPy::MakeTensor(data); + auto ret = tensor->ToAbstract(); + ret->set_value(tensor); + MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString(); + return ret; +} + +AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple + // example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6 + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTuplePtr shape_x = CheckArg(op_name, args_spec_list, 0); + + auto shpx_value = shape_x->BuildValue(); + if (shpx_value->isa()) { + MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString(); + } + + auto shpx_data = shpx_value->cast()->value(); + + int result = 1; + for (size_t i = 0; i < shpx_data.size(); i++) { + int value = GetValue(shpx_data[i]); + result = IntMulWithOverflowCheck(result, value); + } + + auto result_v = MakeValue(result); + MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString(); + return std::make_shared(result_v, result_v->type()); +} + +template +AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: two tuples or two lists. + CheckArgsSize(op_name, args_spec_list, 2); + auto input_x = CheckArg(op_name, args_spec_list, 0); + auto input_y = CheckArg(op_name, args_spec_list, 1); + + ValuePtr x_value = input_x->BuildValue(); + ValuePtr y_value = input_y->BuildValue(); + return std::make_shared(*x_value == *y_value); +} + +AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferImplTupleOrListEqual(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferImplTupleOrListEqual(primitive->name(), args_spec_list); +} + +struct SlideInfo { + int start; + int step; + int stop; +}; + +void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) { + int arg1 = 0; + int arg2 = 0; + if (!args_spec_list.empty()) { + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + auto arg_value = args_spec_list[0]->BuildValue(); + if (!arg_value->isa()) { + MS_LOG(EXCEPTION) << "Only supported input an int32 number."; + } + arg1 = GetValue(arg_value); + } + + if (args_spec_list.size() >= 2) { + MS_EXCEPTION_IF_NULL(args_spec_list[1]); + auto arg_value = args_spec_list[1]->BuildValue(); + if (!arg_value->isa()) { + MS_LOG(EXCEPTION) << "Only supported input an int32 number."; + } + arg2 = GetValue(arg_value); + } + + if (args_spec_list.size() == 3) { + MS_EXCEPTION_IF_NULL(args_spec_list[2]); + auto arg_value = args_spec_list[2]->BuildValue(); + if (!arg_value->isa()) { + MS_LOG(EXCEPTION) << "Only supported input an int32 number."; + } + slide->step = GetValue(arg_value); + slide->start = arg1; + slide->stop = arg2; + } + + if (args_spec_list.size() == 2) { + slide->start = arg1; + slide->stop = arg2; + } + + if (args_spec_list.size() == 1) { + slide->stop = arg1; + } +} + +AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << "Cannot make range from empty input."; + } + + if (args_spec_list.size() > 3) { + MS_LOG(EXCEPTION) << "Error args size of make range operational."; + } + + SlideInfo slide = {0, 1, 0}; + CalcSlidePara(args_spec_list, &slide); + + if (slide.step == 0) { + MS_LOG(EXCEPTION) << "Error, step value is 0."; + } + + AbstractBasePtrList args; + if (slide.start <= slide.stop) { + if (slide.step <= 0) { + MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; + } + for (int i = slide.start; i < slide.stop; i += slide.step) { + args.push_back(abstract::FromValue(i)); + } + } else { + if (slide.step >= 0) { + MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; + } + for (int i = slide.start; i > slide.stop; i += slide.step) { + args.push_back(abstract::FromValue(i)); + } + } + + return std::make_shared(args); +} + +AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor + CheckArgsSize(primitive->name(), args_spec_list, 1); + return args_spec_list[0]->Clone(); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_to_function.cc b/mindspore/ccsrc/frontend/operator/prim_to_function.cc new file mode 100644 index 0000000000..7b9592e80e --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_to_function.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/prim_to_function.h" +#include +#include +#include + +namespace mindspore { +// namespace to support prim related definition +namespace prim { + +PrimToFunction::PrimToFunction() + : prim_func_type_map_({// ONE_ARG prim + {"bool_not", kPrimTypeOneArg}, + {"scalar_cos", kPrimTypeOneArg}, + {"scalar_exp", kPrimTypeOneArg}, + {"scalar_floor", kPrimTypeOneArg}, + {"scalar_log", kPrimTypeOneArg}, + {"scalar_sin", kPrimTypeOneArg}, + {"scalar_tan", kPrimTypeOneArg}, + {"scalar_trunc", kPrimTypeOneArg}, + {"typeof", kPrimTypeOneArg}, + {"scalar_uadd", kPrimTypeOneArg}, + {"scalar_usub", kPrimTypeOneArg}, + // TWO_ARGS prim + {"scalar_add", kPrimTypeTwoArgs}, + {"bool_and", kPrimTypeTwoArgs}, + {"bool_eq", kPrimTypeTwoArgs}, + {"bool_or", kPrimTypeTwoArgs}, + {"scalar_div", kPrimTypeTwoArgs}, + {"scalar_eq", kPrimTypeTwoArgs}, + {"scalar_ge", kPrimTypeTwoArgs}, + {"scalar_gt", kPrimTypeTwoArgs}, + {"scalar_le", kPrimTypeTwoArgs}, + {"scalar_lt", kPrimTypeTwoArgs}, + {"scalar_ne", kPrimTypeTwoArgs}, + {"scalar_mod", kPrimTypeTwoArgs}, + {"scalar_mul", kPrimTypeTwoArgs}, + {"scalar_pow", kPrimTypeTwoArgs}, + {"scalar_sub", kPrimTypeTwoArgs}, + {"scalar_floordiv", kPrimTypeTwoArgs}}) {} + +bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { + bool result = false; + + if (func != nullptr) { + int args_num = GetPrimType(prim); + std::vector one_arg{std::make_shared()}; + std::vector two_args{std::make_shared(), std::make_shared()}; + TypePtr retval = std::make_shared(); + result = true; + switch (args_num) { + case kPrimTypeOneArg: + *func = Function(one_arg, retval).DeepCopy()->cast(); + break; + case kPrimTypeTwoArgs: + *func = Function(two_args, retval).DeepCopy()->cast(); + break; + default: + result = false; + break; + } + } + + return result; +} + +int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { + MS_EXCEPTION_IF_NULL(prim); + int prim_type = static_cast(kPrimTypeUnknown); + + auto value = prim_func_type_map_.find(prim->name()); + if (value != prim_func_type_map_.end()) { + prim_type = value->second; + } + return prim_type; +} +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_to_function.h b/mindspore/ccsrc/frontend/operator/prim_to_function.h similarity index 100% rename from mindspore/ccsrc/operator/prim_to_function.h rename to mindspore/ccsrc/frontend/operator/prim_to_function.h diff --git a/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt b/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt new file mode 100644 index 0000000000..14fda83052 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _OPTIMIZER_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_OPTIMIZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_OPTIMIZER) +add_library(_mindspore_frontend_optimizer_obj OBJECT ${_OPTIMIZER_SRC_FILES}) diff --git a/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc new file mode 100644 index 0000000000..60ccf28df4 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/ad/adjoint.h" + +#include +#include + +#include "ir/anf.h" +#include "frontend/optimizer/ad/dfunctor.h" + +namespace mindspore { +namespace ad { +Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) + : primal_(primal), caller_(caller), dout_(nullptr) { + if (k != nullptr) { + k_ = k; + MS_LOG(DEBUG) << "Add adjoint for " << primal->ToString() << " " << k_->ToString(); + } else { + // Init k hole in a recursive case. + auto k_hole = std::make_shared("k_hole"); + (void)k_hole->AddAttr("info", MakeValue(primal->ToString())); + k_ = NewValueNode(k_hole); + MS_LOG(DEBUG) << "Add hole for " << primal->ToString() << " " << k_->ToString(); + } + + dout_hole_ = caller_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), k_}); + RegisterKUser(dout_hole_->cast(), 1); +} + +AnfNodePtr Adjoint::k() { return k_; } + +void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } + +void Adjoint::UpdateK(const AnfNodePtr &new_k) { + MS_EXCEPTION_IF_NULL(new_k); + MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); + // In recursive case, it needs update. + for (auto &user : k_user_) { + MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" + << new_k->ToString(); + if (user.first->input(user.second) != k_) { + MS_LOG(EXCEPTION) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k " + << new_k->ToString() << ", user relation is set wrongly"; + } + user.first->set_input(user.second, new_k); + } + k_ = new_k; +} + +AnfNodePtr Adjoint::primal() { return primal_; } + +AnfNodePtr Adjoint::dout() { return dout_hole_; } + +void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { + dout_user_.emplace_back(std::make_pair(user, index)); +} + +void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { + if (dout_ != nullptr) { + MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); + auto add = prim::GetPythonOps("hyper_add"); + dout_ = caller_->NewCNode({NewValueNode(add), dout_, dout_factor}); + return; + } + dout_ = dout_factor; +} + +void Adjoint::CallDoutHole() { + if (dout_ != nullptr) { + for (auto &user : dout_user_) { + MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " + << dout_->ToString(); + if (user.first->input(user.second) != dout_hole_) { + MS_LOG(EXCEPTION) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " + << dout_->ToString() << ", user relation is set wrongly"; + } + user.first->set_input(user.second, dout_); + } + } +} +} // namespace ad +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h new file mode 100644 index 0000000000..37986e6810 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ + +#include +#include +#include + +#include "ir/anf.h" +#include "frontend/optimizer/opt.h" + +namespace mindspore { +namespace ad { +class Adjoint { + public: + Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); + ~Adjoint() = default; + AnfNodePtr primal(); + AnfNodePtr k(); + void UpdateK(const AnfNodePtr &k); + void RegisterKUser(const CNodePtr &user, size_t index); + AnfNodePtr dout(); + void AccumulateDout(const AnfNodePtr &dout_factor); + void RegisterDoutUser(const CNodePtr &user, size_t index); + void CallDoutHole(); + + private: + AnfNodePtr primal_; + FuncGraphPtr caller_; + // For ```def f(x): return expr```, The representation graph k is ```def kf(kx): return expr, bprop{expr}```. + AnfNodePtr k_; + std::vector> k_user_; + AnfNodePtr dout_; + AnfNodePtr dout_hole_; + std::vector> dout_user_; +}; + +using AdjointPtr = std::shared_ptr; +} // namespace ad +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc new file mode 100644 index 0000000000..b314b22f81 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -0,0 +1,617 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/ad/dfunctor.h" + +#include +#include +#include + +#include "ir/anf.h" +#include "ir/meta_func_graph.h" +#include "debug/info.h" +#include "ir/func_graph_cloner.h" +#include "ir/manager.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/parse/parse.h" +#include "frontend/optimizer/ad/adjoint.h" +#include "frontend/optimizer/opt.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/composite.h" +#include "utils/symbolic.h" +#include "utils/context/ms_context.h" +#include "./common.h" + +namespace mindspore { +namespace ad { +std::unordered_map DFunctor::func_graph_to_functor_; +std::unordered_map DFunctor::anfnode_to_adjoin_definition_; +FuncGraphSet DFunctor::scope_; + +DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources) + : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { + TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info())); + k_graph_ = std::make_shared(); + if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + std::string grad_op_name = GetValue(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); + } + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info())); + tape_ = std::make_shared(); + // Add "_Grad" postfix + if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + std::string grad_op_name = GetValue(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad"; + tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); + } + TraceManager::EndTrace(); + + dout_ = tape_->add_parameter(); +} + +void DFunctor::Init(bool is_top) { + func_graph_to_functor_[primal_graph_] = shared_from_this(); + is_top_ = is_top; + if (is_top) { + scope_ = primal_graph_->scope(); + } +} + +void DFunctor::Clear() { + func_graph_to_functor_.clear(); + anfnode_to_adjoin_definition_.clear(); + scope_.clear(); +} + +void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { + auto fv_adjoint = anfnode_to_adjoin_.find(fv); + if (fv_adjoint == anfnode_to_adjoin_.end()) { + MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() + << " " << fv->ToString() << "."; + fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); + if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) { + MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv " + << fv->func_graph()->ToString() << " " << fv->ToString() << "."; + auto parent_adjoint = FindAdjoint(fv); + AdjointPtr adjoint = nullptr; + if (parent_adjoint != nullptr) { + adjoint = std::make_shared(fv, parent_adjoint->k(), tape_); + } else { + MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole " + << fv->func_graph()->ToString() << " " << fv->ToString() << "."; + adjoint = std::make_shared(fv, nullptr, tape_); + } + anfnode_to_adjoin_indirect_fv_[fv] = adjoint; + fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); + } + } + auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()}); + fv_adjoint->second->RegisterKUser(node, 1); + auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()}); + fv_adjoint->second->RegisterKUser(default_val, 1); + auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, node, default_val}); + MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv " + << fv->func_graph()->ToString() << " " << fv->ToString() << "."; + MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << node->ToString() << "."; + fv_adjoint->second->AccumulateDout(dfv); +} + +void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) { + // Take switch_layer as a set of candidate functions. + auto input = cnode_morph->input(2); + if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { + MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << "."; + } + auto tuple_graphs = input->cast(); + for (size_t i = 1; i < tuple_graphs->size(); ++i) { + auto graph = tuple_graphs->input(i); + if (!IsValueNode(graph)) { + MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString() + << " as the " << i << "th element."; + } + auto func_graph = GetValueNode(graph); + auto functor = func_graph_to_functor_.find(func_graph); + if (functor == func_graph_to_functor_.end()) { + MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] " + << func_graph->ToString() << "."; + } + // Consider direct and indirect fvs. + for (auto fv : func_graph->free_variables_nodes()) { + BackPropagateFv(fv, env); + } + for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { + MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " " + << indirect_fv.first->ToString() << "."; + BackPropagateFv(indirect_fv.first, env); + } + } +} + +void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) { + auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)}); + // Call with delimited continuation dout. + auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()}); + node_adjoint->RegisterDoutUser(bprop_app, 1); + // Special case for switch_layer + if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) { + auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)}); + BackPropagateSwitchLayer(cnode_morph, din); + return; + } + for (size_t i = 0; i < cnode_morph->size(); i++) { + auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))}); + auto input = cnode_morph->input(i); + // Backprop sens wrt fvs. + if (IsValueNode(input)) { + auto func_graph = GetValueNode(input); + auto functor = func_graph_to_functor_.find(func_graph); + if (functor == func_graph_to_functor_.end()) { + MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] " + << func_graph->ToString() << "."; + } + // Consider direct and indirect fvs. + for (auto fv : func_graph->free_variables_nodes()) { + BackPropagateFv(fv, din); + } + for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { + MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " " + << indirect_fv.first->ToString() << "."; + BackPropagateFv(indirect_fv.first, din); + } + continue; + } + // Backprop sens wrt inputs. + auto input_adjoint = anfnode_to_adjoin_.find(input); + if (input_adjoint == anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << "."; + } + input_adjoint->second->AccumulateDout(din); + } +} + +// Map a morphism. +AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { + // MapMorphism All type except CNode should already be mapped by MapObject. + if (!morph->isa()) { + return nullptr; + } + ScopeGuard scope_guard(morph->scope()); + auto cnode_morph = morph->cast(); + + std::vector inputs; + std::vector param_adjoints; + for (size_t i = 0; i < cnode_morph->size(); i++) { + auto node = cnode_morph->input(i); + auto node_adjoint_iter = anfnode_to_adjoin_.find(node); + AdjointPtr node_adjoint = nullptr; + AnfNodePtr k = nullptr; + if (node_adjoint_iter != anfnode_to_adjoin_.end()) { + node_adjoint = node_adjoint_iter->second; + } else { + // Input might be a CNode that needs to be handled before hand. + node_adjoint = MapMorphism(node); + } + MS_EXCEPTION_IF_NULL(node_adjoint); + k = node_adjoint->k(); + if (k == nullptr) { + MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; + } + inputs.push_back(k); + param_adjoints.push_back(node_adjoint); + } + TraceManager::DebugTrace(std::make_shared(cnode_morph->debug_info())); + auto k_app = k_graph_->NewCNode(inputs); + TraceManager::EndTrace(); + for (size_t i = 0; i < param_adjoints.size(); ++i) { + param_adjoints[i]->RegisterKUser(k_app, i); + } + + // Do forward computation + auto foward_app = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(0)}); + // K:: cnode -> forward_app + auto node_adjoint = std::make_shared(morph, foward_app, tape_); + UpdateAdjoint(node_adjoint); + anfnode_to_adjoin_[morph] = node_adjoint; + if (cnode_morph->stop_gradient()) { + MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped."; + return node_adjoint; + } + + // Do sens backpropagation + BackPropagate(cnode_morph, k_app, node_adjoint); + MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << "."; + return node_adjoint; +} + +bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) { + // Do not care about non-CNode + if (!node->isa()) { + return false; + } + // Do not care about kPrimReturn + if (IsPrimitiveCNode(node, prim::kPrimReturn)) { + return false; + } + auto &users = primal_graph_->manager()->node_users()[node]; + // Do not care about isolated morphisms + if (users.empty()) { + return false; + } + // Not free if it's used by some node in primal_graph + bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) { + auto &user = kv.first; + return user->func_graph() == primal_graph_; + }); + return !nonfree; +} + +void DFunctor::MapFreeMorphism() { + // Handle cnode not attached to output, that might be refered in other functions. + for (auto &node : primal_graph_->nodes()) { + if (!IsFreeMorphism(node)) { + continue; + } + MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << "."; + (void)MapMorphism(node); + } +} + +AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { + AnfNodePtr new_grad_fv = grad_fv; + // Add grads wrt fv. + const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); + for (auto &fv : free_variables_nodes) { + auto fv_adjoint = anfnode_to_adjoin_.find(fv); + if (fv_adjoint == anfnode_to_adjoin_.end()) { + MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << "."; + } + auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()}); + fv_adjoint->second->RegisterKUser(node, 1); + auto sens = fv_adjoint->second->dout(); + new_grad_fv = tape_->NewCNode({ + NewValueNode(prim::kPrimEnvSetItem), + new_grad_fv, + node, + sens, + }); + fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast(), 3); + MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " " + << fv->ToString() << " " << primal_graph_->ToString() << "."; + } + return new_grad_fv; +} + +AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) { + AnfNodePtr new_grad_fv = grad_fv; + // Add indirect fv bprop. + for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) { + MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " " + << primal_graph_->ToString() << "."; + auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()}); + fv_adjoint.second->RegisterKUser(node, 1); + auto sens = fv_adjoint.second->dout(); + new_grad_fv = tape_->NewCNode({ + NewValueNode(prim::kPrimEnvSetItem), + new_grad_fv, + node, + sens, + }); + fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast(), 3); + MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to " + << new_grad_fv->ToString() << "."; + } + return new_grad_fv; +} + +void DFunctor::MapMorphism() { + // Set stop_gradient before MapMorphism. + BroadCastStopFlag(); + + // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent + MapFreeMorphism(); + // Handle morphism from output. + (void)MapMorphism(primal_graph_->output()); + + // Construct K for primal_graph_ + auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output()); + // Attach dout_ parameter to output_adjoint. + output_adjoint->second->AccumulateDout(dout_); + + // Set output for tape closure. + auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv))); + + std::vector inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv}; + // Add grads wrt inputs. + std::vector param_adjoints; + for (auto ¶m : primal_graph_->parameters()) { + auto param_adjoint = anfnode_to_adjoin_.find(param); + inputs.push_back(param_adjoint->second->dout()); + param_adjoints.push_back(param_adjoint->second); + } + auto tape_output = tape_->NewCNode(inputs); + for (size_t i = 0; i < param_adjoints.size(); ++i) { + param_adjoints[i]->RegisterDoutUser(tape_output, i + 2); + } + tape_->set_output(tape_output); + // Set output for k_graph_, K:: cnode->forward_app. + auto forward_app = output_adjoint->second->k(); + auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)}); + output_adjoint->second->RegisterKUser(output, 1); + k_graph_->set_output(output); + (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_))); + (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_))); +} + +FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { + // K user defined cell bprop. + auto bprop = primal->transforms().find("bprop"); + if (bprop != primal->transforms().end()) { + FuncGraphPtr bprop_graph = bprop->second.func_graph(); + resources_->manager()->AddFuncGraph(bprop_graph); + + if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) { + MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope " + << primal->output()->scope()->name() << " does not support Parameter data type."; + } + auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph); + if (fg == nullptr) { + MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope " + << primal->output()->scope()->name() << "."; + } + + // Cache the grad func + (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg))); + (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal))); + // Reset defer_inline to enable successive inlining + primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); + + auto functor = std::make_shared(primal, resources_); + functor->Init(); + functor->k_graph_ = fg; + + return fg; + } + return nullptr; +} + +// MapToK(func) +AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { + auto f = func_graph_to_functor_.find(primal); + if (f != func_graph_to_functor_.end()) { + MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << "."; + return NewValueNode(f->second->k_graph_); + } + + auto k_user_defined = KUserDefined(primal); + if (k_user_defined != nullptr) { + MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << "."; + return NewValueNode(k_user_defined); + } + + auto functor = std::make_shared(primal, resources_); + functor->Init(); + functor->MapObject(); + functor->MapMorphism(); + + MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << "."; + return NewValueNode(functor->k_graph_); +} + +// Construct representation graph for given node. +AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { + ScopeGuard scope_guard(primal->scope()); + // MapToK(prim) + if (IsValueNode(primal)) { + auto value_node = primal->cast(); + auto prim = GetValueNode(value_node); + if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { + MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; + need_cut_ = true; + } + auto k_prim = g_k_prims.KPrimitive(value_node, resources_); + if (k_prim != nullptr) { + return NewValueNode(k_prim); + } + // When failed to find k_prim, try k_meta. + auto k_meta = g_k_prims.KMetaFuncGraph(prim); + if (k_meta != nullptr) { + return NewValueNode(k_meta); + } + } + + // MapToK(func) + if (IsValueNode(primal)) { + auto func_graph = GetValueNode(primal); + auto k_func = MapToK(func_graph); + return k_func; + } + + if (primal->isa()) { + TraceManager::DebugTrace(std::make_shared(primal->debug_info())); + auto ret = k_graph_->add_parameter(); + TraceManager::EndTrace(); + return ret; + } + + if (!primal->isa()) { + MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode."; + } + return primal; +} + +bool DFunctor::IsInScope(const AnfNodePtr &node) { + return std::any_of(scope_.begin(), scope_.end(), + [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; }); +} + +void DFunctor::MapFvObject() { + // Map free variable. + const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); + for (auto &node : free_variables_nodes) { + ScopeGuard scope_guard(node->scope()); + MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << "."; + // Find fv's K from parent. + AdjointPtr adjoint = nullptr; + auto parent_adjoint = FindAdjoint(node); + if (parent_adjoint != nullptr) { + adjoint = std::make_shared(node, parent_adjoint->k(), tape_); + } else { + if (is_top_ || node->isa() || !IsInScope(node)) { + // Out of ad scope, add adjoint for free variables. + adjoint = std::make_shared(node, node, tape_); + UpdateAdjoint(adjoint); + } else { + MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << "."; + adjoint = std::make_shared(node, nullptr, tape_); + } + } + if (adjoint == nullptr) { + MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << "."; + } + anfnode_to_adjoin_[node] = adjoint; + } +} + +void DFunctor::MapParamObject() { + // Map parameter. + for (auto &p : primal_graph_->parameters()) { + ScopeGuard scope_guard(p->scope()); + MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << "."; + auto adjoint = std::make_shared(p, MapToK(p), tape_); + UpdateAdjoint(adjoint); + anfnode_to_adjoin_[p] = adjoint; + } +} + +void DFunctor::MapValueObject() { + // Map ValueNode. + auto manager = resources_->manager(); + auto &value_nodes = primal_graph_->value_nodes(); + for (const auto &value_pair : value_nodes) { + auto node = value_pair.first; + auto parent_adjoint = FindAdjoint(node); + if (parent_adjoint != nullptr) { + auto adjoint = std::make_shared(node, parent_adjoint->k(), tape_); + anfnode_to_adjoin_[node] = adjoint; + continue; + } + // Skip Return. + if (IsValueNode(node) && GetValueNode(node) == prim::kPrimReturn) { + continue; + } + MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << "."; + auto adjoint = std::make_shared(node, MapToK(node), tape_); + UpdateAdjoint(adjoint); + anfnode_to_adjoin_[node] = adjoint; + } +} + +// Skip morphism. +void DFunctor::MapObject() { + // The order does not matter + MapFvObject(); + MapParamObject(); + MapValueObject(); +} + +void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) { + auto primal = adjoint_definition->primal(); + if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) { + MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " " + << primal->ToString() << "."; + } + anfnode_to_adjoin_definition_[primal] = adjoint_definition; + // Update k hole for primal. + for (auto &f : func_graph_to_functor_) { + auto adjoint = f.second->anfnode_to_adjoin_.find(primal); + if (adjoint != f.second->anfnode_to_adjoin_.end()) { + adjoint->second->UpdateK(adjoint_definition->k()); + } + adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal); + if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) { + adjoint->second->UpdateK(adjoint_definition->k()); + } + } +} + +AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) { + auto adjoint = anfnode_to_adjoin_definition_.find(primal); + if (adjoint != anfnode_to_adjoin_definition_.end()) { + MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << "."; + return adjoint->second; + } + MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << "."; + return nullptr; +} + +void DFunctor::CallDoutHoleOnTape() { + if (!is_top_) { + return; + } + + // Call dout hole of all adjoint. + for (auto &f : func_graph_to_functor_) { + for (auto &adjoint : f.second->anfnode_to_adjoin_) { + adjoint.second->CallDoutHole(); + } + for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) { + adjoint.second->CallDoutHole(); + } + } +} +FuncGraphPtr DFunctor::k_graph() { + CallDoutHoleOnTape(); + return k_graph_; +} + +void DFunctor::BroadCastStopFlag() { + // As stop set expanding, all directly or indirectly stopped CNode will be cut off + while (need_cut_) { + need_cut_ = false; + for (auto &node : primal_graph_->nodes()) { + if (node->isa()) { + auto cnode = node->cast(); + if (!cnode->stop_gradient()) { + // Cut off the cnode only when it's not referred any more + if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || AllReferencesStopped(cnode)) { + MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << "."; + cnode->set_stop_gradient(true); + // The stop set changed, more cut required + need_cut_ = true; + } + } + } + } + } +} + +bool DFunctor::AllReferencesStopped(const CNodePtr &node) { + auto &users = primal_graph_->manager()->node_users()[node]; + // Only care about stop_gradient caused cutting + if (users.empty()) { + return false; + } + for (auto &kv : users) { + auto &user = kv.first; + if (!user->isa() || !user->cast()->stop_gradient()) { + return false; + } + } + return true; +} +} // namespace ad +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h new file mode 100644 index 0000000000..9ee93334e8 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -0,0 +1,210 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/meta_func_graph.h" +#include "ir/func_graph_cloner.h" +#include "pipeline/jit/resource.h" +#include "frontend/optimizer/ad/adjoint.h" +#include "frontend/operator/ops.h" +#include "debug/trace.h" + +namespace mindspore { +namespace ad { +struct PrimitiveTotalEqual { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { + MS_EXCEPTION_IF_NULL(t1); + MS_EXCEPTION_IF_NULL(t2); + return *t1 == *t2; + } +}; + +using Registry = std::unordered_map; +class KPrim; +extern KPrim g_k_prims; +class DFunctor; +using DFunctorPtr = std::shared_ptr; + +// D Functor's rules to map closure object and morphisms. +class DFunctor : public std::enable_shared_from_this { + public: + DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources); + ~DFunctor() = default; + // Map object in D category to K category. + void MapObject(); + // Map morphism in D category to K category. + void MapMorphism(); + FuncGraphPtr k_graph(); + // Construct user defined k object. + FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); + // Register functor objects to form a global view. + void Init(bool is_top = false); + bool IsInScope(const AnfNodePtr &node); + + // Clear resources. + static void Clear(); + + private: + // Map one morphism. + AdjointPtr MapMorphism(const AnfNodePtr &morph); + bool IsFreeMorphism(const AnfNodePtr &node); + // Map morphism that's not attached to output. + void MapFreeMorphism(); + void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din); + void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env); + void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); + AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); + AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); + // Map Anfnode object from D category to K category. + AnfNodePtr MapToK(const AnfNodePtr &primal); + // Map FuncGraph object from D category to K category. + AnfNodePtr MapToK(const FuncGraphPtr &primal); + // MapObject impls. + void MapFvObject(); + void MapValueObject(); + void MapParamObject(); + // Find adjoint with its primary k. + AdjointPtr FindAdjoint(const AnfNodePtr &primal); + // Broadcast stop flags. + void BroadCastStopFlag(); + bool AllReferencesStopped(const CNodePtr &node); + // Update k hole with adjoint_definition, only applied in recursive case. + void UpdateAdjoint(const AdjointPtr &adjoint_definition); + void CallDoutHoleOnTape(); + + std::unordered_map anfnode_to_adjoin_; + // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. + std::unordered_map anfnode_to_adjoin_indirect_fv_; + FuncGraphPtr primal_graph_; + // K object for primal_graph_; + FuncGraphPtr k_graph_; + // The Backprop part of k_graph_. + FuncGraphPtr tape_; + // Dout parameter for primal_graph_. + AnfNodePtr dout_; + pipeline::ResourceBasePtr resources_; + // Cut off stopped objects in category D. + bool need_cut_; + bool is_top_; + static std::unordered_map> func_graph_to_functor_; + static std::unordered_map anfnode_to_adjoin_definition_; + static FuncGraphSet scope_; +}; + +// D Functor's rules to map primitive object. +class KPrim { + public: + KPrim() = default; + ~KPrim() = default; + + FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); + MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); + FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); + + void clear() { + bprop_registry_meta_.clear(); + bprop_registry_.clear(); + } + + private: + FuncGraphPtr GetBprop(const PrimitivePtr &prim); + FuncGraphPtr GetFprop(const PrimitivePtr &prim); + FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); + FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); + // Given a bprop rule, do the K mapping. + template + FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g); + AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); + void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, + std::vector *const transf_args); + void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check); + + Registry bprop_registry_; + std::unordered_map bprop_registry_meta_; +}; + +template +FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { + MS_EXCEPTION_IF_NULL(primal); + MS_EXCEPTION_IF_NULL(bprop_fg); + CheckBprop(bprop_fg, primal->ToString()); + + auto debug_info = std::make_shared(); + debug_info->set_name(primal->ToString()); + + auto cloned_bprop_fg = BasicClone(bprop_fg); + MS_EXCEPTION_IF_NULL(cloned_bprop_fg); + + cloned_bprop_fg->debug_info()->set_name(""); + cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared(debug_info)); + + AnfNodePtr bout = BuildOutput(cloned_bprop_fg); + cloned_bprop_fg->set_output(bout); + + TraceManager::DebugTrace(std::make_shared(debug_info)); + auto outer = std::make_shared(); + (void)outer->transforms().emplace("primal", FuncGraphTransform(primal)); + outer->set_output(NewValueNode(kNone)); + TraceManager::EndTrace(); + + auto mng = Manage({cloned_bprop_fg, outer}, false); + + // Make sure (out, dout) provided. + if (cloned_bprop_fg->parameters().size() < 2) { + MS_LOG(EXCEPTION) << "Primitive or Cell " << primal->ToString() + << " bprop requires out and dout at least, but only got " << cloned_bprop_fg->parameters().size() + << " params. NodeInfo: " << trace::GetDebugInfo(cloned_bprop_fg->debug_info()); + } + + // In a bprop definition, the last two param should be out and dout. + auto dout = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 1]; + auto out_param = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 2]; + std::vector transf_args; + TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); + + TraceManager::DebugTrace(std::make_shared(dout->debug_info())); + (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); + auto out_value = outer->NewCNode(transf_args); + TraceManager::EndTrace(); + + (void)mng->Replace(out_param, out_value); + + TraceManager::DebugTrace(std::make_shared(out_param->debug_info())); + auto new_dout = cloned_bprop_fg->add_parameter(); + (void)mng->Replace(dout, new_dout); + // We remove all parameters except new_dout. + std::vector newBpropParams = {new_dout}; + cloned_bprop_fg->set_parameters(newBpropParams); + TraceManager::EndTrace(); + + outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); + return BasicClone(outer); +} +} // namespace ad +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc new file mode 100644 index 0000000000..ef2d7d400a --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/ad/grad.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "ir/func_graph_cloner.h" +#include "utils/context/ms_context.h" +#include "utils/symbolic.h" +#include "utils/graph_utils.h" + +namespace mindspore { +namespace ad { +FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) { + MS_EXCEPTION_IF_NULL(func_graph); + auto gradkv = func_graph->transforms().find("grad"); + if (gradkv != func_graph->transforms().end()) { + return gradkv->second.func_graph(); + } + + auto manager_ptr = resources->manager(); + MS_EXCEPTION_IF_NULL(manager_ptr); + manager_ptr->AddFuncGraph(func_graph); + + auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { + if (MsContext::GetInstance()->is_multi_graph_sink()) { + if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + } + } + }; + + auto f = std::make_shared(func_graph, resources); + auto user_defined = f->KUserDefined(func_graph); + if (user_defined != nullptr) { + multi_graph_sink(user_defined); + if (is_top) { + DFunctor::Clear(); + } + return user_defined; + } + f->Init(is_top); + f->MapObject(); + f->MapMorphism(); + auto ret = f->k_graph(); + if (is_top) { + DFunctor::Clear(); + } + + multi_graph_sink(ret); + return ret; +} + +FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { + auto fg = g_k_prims.KPrimitive(value_node, resources); + if (fg == nullptr) { + return nullptr; + } + return BasicClone(fg); +} + +MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &) { + MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim); + return fg; +} + +void CleanRes() { DFunctor::Clear(); } +} // namespace ad +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.h b/mindspore/ccsrc/frontend/optimizer/ad/grad.h new file mode 100644 index 0000000000..ee9ab79ffb --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ + +#include +#include + +#include "ir/anf.h" +#include "ir/meta_func_graph.h" +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace ad { +using ResourcePtr = std::shared_ptr; + +FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true); +FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); +MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &); +void CleanRes(); +} // namespace ad +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc new file mode 100644 index 0000000000..5ca2ca6c43 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -0,0 +1,291 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/primitive_py.h" +#include "ir/meta_func_graph.h" +#include "ir/func_graph_cloner.h" +#include "ir/manager.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/parse/parse.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "frontend/optimizer/opt.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/composite.h" +#include "utils/symbolic.h" +#include "utils/primitive_utils.h" +#include "utils/context/ms_context.h" +#include "debug/info.h" +#include "debug/trace.h" + +#include "./common.h" + +namespace mindspore { +namespace ad { +using PatternListType = std::initializer_list; +KPrim g_k_prims; + +FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { + // Set a child scope named "grad'PrimitiveName'" for the bprop function, + // and add "Gradients" to the front. + static const std::string gradients_scope = "Gradients/"; + static const std::string grad_op_child_scope_prefix = "/grad"; + MS_EXCEPTION_IF_NULL(prim); + auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + + grad_op_child_scope_prefix + prim->name()); + ScopeGuard scope_guard(scope); + py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast()->GetBpropFunction(); + if (fn == nullptr || py::isinstance(fn)) { + MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; + return nullptr; + } + FuncGraphPtr func_graph = parse::ParsePythonCode(fn); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; + return nullptr; + } + return func_graph; +} + +FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) { + static const std::string ad_module = "mindspore.ops._grad.grad_implementations"; + std::string func_name = "_fprop_" + prim->name(); + py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name); + auto func_graph = parse::ParsePythonCode(fn); + MS_EXCEPTION_IF_NULL(func_graph); + return BasicClone(func_graph); +} + +MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { + MS_EXCEPTION_IF_NULL(prim); + + auto iter = bprop_registry_meta_.find(prim); + if (iter != bprop_registry_meta_.end()) { + return iter->second; + } + + if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { + MetaFuncGraphPtr meta = std::make_shared("make_tuple_gradient"); + bprop_registry_meta_[prim::kPrimMakeTuple] = meta; + return meta; + } + + MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; +} + +FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { + if (!IsValueNode(value_node)) { + MS_LOG(EXCEPTION) << "Primitive node is not valid."; + } + + auto prim = GetValueNode(value_node); + if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) { + auto fprop = GetFprop(prim); + fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer)); + return fprop; + } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { + return nullptr; + } + + FuncGraphPtr bprop_fg = nullptr; + if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { + bprop_fg = BpropCut(value_node, resources); + } else { + auto iter = bprop_registry_.find(prim); + if (iter != bprop_registry_.end()) { + bprop_fg = iter->second; + } + + if (bprop_fg == nullptr) { + bprop_fg = GetBprop(prim); + if (bprop_fg != nullptr) { + // Set bprop_g graph cache + bprop_registry_[prim] = bprop_fg; + } else { + bprop_fg = FakeBprop(value_node, resources); + } + } + } + + auto expanded_fg = BpropToK(prim, bprop_fg); + if (expanded_fg == nullptr) { + MS_LOG(EXCEPTION) << "Failed convert " << prim->name() + << " prim bprop function to J expanded func graph. NodeInfo: " + << trace::GetDebugInfo(bprop_fg->debug_info()); + } + + return expanded_fg; +} + +AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg) { + // bprop_fg has been checked in caller + if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) { + // Set bprop output as (env, dx, dy, dz, ...) + auto cbprop = bprop_fg->output()->cast(); + auto &inputs = cbprop->inputs(); + + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + args.push_back(NewValueNode(newenv)); + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + return NewCNode(args, bprop_fg); + } + + // Set bprop output as (env, dx) + std::string model_name("mindspore.ops.composite.multitype_ops.add_impl"); + std::string python_ops("_tuple_add"); + auto tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); + return NewCNode({NewValueNode(prim::GetPythonOps(python_ops, model_name)), tuple, bprop_fg->output()}, bprop_fg); +} + +void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, + std::vector *const transf_args) { + MS_EXCEPTION_IF_NULL(mng); + // bprop_fg has been checked in caller + // transform except the last 2 parameters: out, dout. + for (size_t i = 0; i < bprop_fg->parameters().size() - 2; ++i) { + auto p = bprop_fg->parameters()[i]; + MS_EXCEPTION_IF_NULL(p); + + TraceManager::DebugTrace(std::make_shared(p->debug_info())); + auto transf_p = outer->add_parameter(); + TraceManager::EndTrace(); + + (void)mng->Replace(p, transf_p); + transf_args->push_back(transf_p); + } +} + +void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool check_bprop_flag = context->check_bprop_flag(); + // Skip checking if check_bprop not set + if (!check_bprop_flag) { + return; + } + + // bprop_fg has been checked in caller + auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops"); + MS_EXCEPTION_IF_NULL(check_bprop_class); + auto check_bprop = + bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared(prim_to_check))}); + + std::vector inputs; + inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2); + AnfNodePtr params = bprop_fg->NewCNode(inputs); + + inputs.clear(); + inputs.push_back(check_bprop); + inputs.push_back(bprop_fg->output()); + inputs.push_back(params); + AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); + bprop_fg->set_output(bprop_out); +} + +FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { + MS_EXCEPTION_IF_NULL(bprop_fg); + auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); + auto expanded_fg = BpropToK(fprop_fg, bprop_fg); + if (expanded_fg == nullptr) { + MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() + << " Cell bprop function to K expanded func graph. NodeInfo: " + << trace::GetDebugInfo(fprop_fg->debug_info()); + } + return expanded_fg; +} + +FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { + auto prim = GetValueNode(value_node); + MS_EXCEPTION_IF_NULL(prim); + auto &node_users = resources->manager()->node_users(); + + auto &users = node_users[value_node]; + auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { + return IsPrimitiveCNode(user.first, prim); + }); + if (cnode == users.end()) { + MS_LOG(EXCEPTION) << "Fail to find cnode."; + } + auto inputs_num = cnode->first->cast()->size() - 1; + + auto func_graph = std::make_shared(); + std::vector outputs; + + auto bprop_cut = std::make_shared("bprop_cut", py::object()); + bprop_cut->CopyHookFunction(prim); + + auto cell_id = GetValue(prim->GetAttr("cell_id")); + if (cell_id != "") { + (void)bprop_cut->AddAttr("cell_hook", MakeValue(true)); + (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id)); + } + + outputs.push_back(NewValueNode(bprop_cut)); + for (size_t i = 0; i < inputs_num; ++i) { + auto param = func_graph->add_parameter(); + outputs.push_back(param); + } + auto p1 = func_graph->add_parameter(); + auto p2 = func_graph->add_parameter(); + outputs.push_back(p1); + outputs.push_back(p2); + + func_graph->set_output(func_graph->NewCNode(outputs)); + return func_graph; +} + +FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { + auto prim = value_node->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + auto &node_users = resources->manager()->node_users(); + + auto &users = node_users[value_node]; + auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { + return IsPrimitiveCNode(user.first, prim); + }); + if (cnode == users.end()) { + MS_LOG(EXCEPTION) << "Fail to find cnode."; + } + auto inputs_num = cnode->first->cast()->inputs().size() - 1; + + auto func_graph = std::make_shared(); + std::vector outputs; + outputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + auto fake_bprop = std::make_shared("fake_bprop"); + (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined.")); + + for (size_t i = 0; i < inputs_num; ++i) { + // Mock params for inputs + auto param = func_graph->add_parameter(); + // Mock derivatives for each inputs + outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param})); + } + // mock params for out and dout + (void)func_graph->add_parameter(); + (void)func_graph->add_parameter(); + func_graph->set_output(func_graph->NewCNode(outputs)); + return func_graph; +} +} // namespace ad +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc new file mode 100644 index 0000000000..e35760ceaf --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/clean.cc @@ -0,0 +1,538 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/clean.h" +#include +#include +#include +#include +#include +#include "./common.h" +#include "debug/trace.h" +#include "frontend/operator/composite/composite.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { +using mindspore::abstract::AbstractAttribute; +using mindspore::abstract::AbstractClass; +using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractJTagged; +using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractUndetermined; + +static AbstractBasePtr Reabs(const AbstractBasePtr &t) { + if (t == nullptr) { + return nullptr; + } + + if (t->isa()) { + auto abs_class = dyn_cast(t); + AbstractBasePtrList baselist; + auto attributes = abs_class->attributes(); + (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), + [](const AbstractAttribute &item) { return item.second; }); + return std::make_shared(baselist); + } + if (t->isa()) { + auto abs_dict = dyn_cast(t); + AbstractBasePtrList baselist; + auto elements = abs_dict->elements(); + (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), + [](const AbstractAttribute &item) { return item.second; }); + return std::make_shared(baselist); + } + if (t->isa()) { + auto abs_list = dyn_cast(t); + return std::make_shared(abs_list->elements()); + } + + return nullptr; +} + +AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + const auto &inputs = node->inputs(); + // Inputs should be [getattr, data, attribute] + MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); + + AnfNodePtr data = inputs[1]; + AnfNodePtr cons = inputs[2]; + MS_EXCEPTION_IF_NULL(data); + MS_EXCEPTION_IF_NULL(cons); + + auto dt = data->abstract(); + if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) { + return nullptr; + } + + if (!dt->isa()) { + MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; + } + + auto cons_is_str = IsValueNode(cons); + auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; + + auto ct = dyn_cast(dt); + const auto &cmap = ct->attributes(); + int count = 0; + for (auto &item : cmap) { + if (cons_is_str && item.first == cons_str) { + break; + } + count++; + } + + auto idx_c = NewValueNode(count); + AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); + idx_c->set_abstract(aptr); + + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); +} + +AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + // Inputs should be [dict_getitem, dict, item] + const auto &inputs = node->inputs(); + MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); + + AnfNodePtr data = inputs[1]; + AnfNodePtr cons = inputs[2]; + MS_EXCEPTION_IF_NULL(data); + MS_EXCEPTION_IF_NULL(cons); + + auto dt = data->abstract(); + MS_EXCEPTION_IF_NULL(dt); + if (!dt->isa()) { + MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name(); + } + auto cons_is_str = IsValueNode(cons); + auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; + + auto ct = dyn_cast(dt); + const auto &cmap = ct->elements(); + int count = 0; + for (auto &item : cmap) { + if (cons_is_str && item.first == cons_str) { + break; + } + count++; + } + + auto idx_c = NewValueNode(count); + AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); + idx_c->set_abstract(aptr); + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); +} + +AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + // Inputs should be [dict_setitem, dict, item, value] + const auto &inputs = node->inputs(); + MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs."); + + AnfNodePtr data = inputs[1]; + AnfNodePtr cons = inputs[2]; + AnfNodePtr item_value = inputs[3]; + MS_EXCEPTION_IF_NULL(data); + MS_EXCEPTION_IF_NULL(cons); + + auto dt = data->abstract(); + MS_EXCEPTION_IF_NULL(dt); + if (!dt->isa()) { + MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name(); + } + auto cons_is_str = IsValueNode(cons); + auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; + + auto ct = dyn_cast(dt); + const auto &cmap = ct->elements(); + int count = 0; + for (auto &item : cmap) { + if (cons_is_str && item.first == cons_str) { + break; + } + count++; + } + if (IntToSize(count) >= cmap.size()) { + // for dictionary set, if the key does not exist, we should create a new item + auto tuple_add_op = std::make_shared("tuple_add"); + auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value}); + return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item}); + } + auto idx_c = NewValueNode(count); + AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); + idx_c->set_abstract(aptr); + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value}); +} + +AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + std::vector inputs; + inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr; + (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end()); + return node->func_graph()->NewCNode(inputs); +} + +AnfNodePtr ErasePartialNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + const auto &inputs = node->inputs(); + // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; + MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); + + std::vector args(inputs.begin() + 2, inputs.end()); + auto oper = inputs[1]; + if (IsPrimitive(oper, prim::kPrimMakeRecord)) { + if (args.size() == 1) { + return NewValueNode(prim::kPrimMakeTuple); + } + + if (args.size() > 1) { + std::vector new_inputs; + new_inputs.emplace_back(NewValueNode(prim::kPrimPartial)); + new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end()); + + MS_EXCEPTION_IF_NULL(node->func_graph()); + return node->func_graph()->NewCNode(new_inputs); + } + } + return nullptr; +} + +AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + std::vector inputs; + inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items; + (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end()); + return node->func_graph()->NewCNode(inputs); +} + +AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + const auto &inputs = node->inputs(); + // Inputs should be [list_getitem, list, item] + if (inputs.size() < 3) { + MS_LOG(EXCEPTION) << "Node's input number < 3."; + } + + AnfNodePtr data = inputs[1]; + AnfNodePtr cons = inputs[2]; + MS_EXCEPTION_IF_NULL(data); + MS_EXCEPTION_IF_NULL(cons); + + auto cons_node = cons->cast(); + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); +} + +AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + const auto &inputs = node->inputs(); + // Inputs should be [list_setitem, list, index, item] + if (inputs.size() < 4) { + MS_LOG(EXCEPTION) << "Node's input number < 4."; + } + + AnfNodePtr data = inputs[1]; + AnfNodePtr cons = inputs[2]; + AnfNodePtr value = inputs[3]; + + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); +} + +AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + const auto &inputs = node->inputs(); + MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); + return inputs[2]; +} + +AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + const auto &inputs = node->inputs(); + // Inputs should be [make_keyword_arg, key, value] + MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); + return inputs[2]; +} + +AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + const auto &inputs = node->inputs(); + // Inputs should be [extract_keyword_arg, arg, key] + MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); + return inputs[2]; +} + +ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) { + const int DEPTH_MAX = 5; + if (depth > DEPTH_MAX) { + MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; + } + std::vector elements; + for (const auto &it : value_list->value()) { + ValuePtr value = nullptr; + if (it->isa()) { + value = ConvertValueListToValueTuple(it->cast(), depth + 1); + } else { + value = it; + } + elements.push_back(value); + } + return std::make_shared(elements); +} + +AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + ValuePtr value = node->value(); + auto value_list = value->cast(); + MS_EXCEPTION_IF_NULL(value_list); + int depth = 0; + return std::make_shared(ConvertValueListToValueTuple(value_list, depth)); +} + +// Convert class to Tuple +// Convert getattr to getitem +// Convert make_record to make_tuple +bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(root); + + bool changed = false; + + // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var + AnfNodeSet all_node = manager->all_nodes(); + for (auto &node : all_node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + AnfNodePtr new_node = nullptr; + if (IsValueNode(node)) { + new_node = NewValueNode(prim::kPrimMakeTuple); + } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) { + new_node = ConvertGetAttrToTupleGetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) { + new_node = ConvertMakeRecordToMakeTuple(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) { + new_node = ErasePartialNode(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) { + new_node = ConvertDictGetItemToTupleGetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) { + new_node = ConvertDictSetItemToTupleSetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) { + new_node = EraseMakeDictNode(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) { + new_node = EraseMakeKeywordArgNode(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) { + new_node = EraseExtractKeywordArg(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) { + new_node = ConvertMakeListToMakeTuple(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) { + new_node = ConvertListGetItemToTupleGetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) { + new_node = ConvertListSetItemToTupleSetItem(cnode); + } else if (IsValueNode(node)) { + new_node = ConvertValueListNodeToValueTupleNode(node->cast()); + } + + if (new_node != nullptr) { + new_node->set_abstract(node->abstract()); + MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString(); + (void)manager->Replace(node, new_node); + changed = true; + } + } + + for (auto &node : manager->all_nodes()) { + auto ret = Reabs(node->abstract()); + if (ret) { + MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with " + << ret->ToString(); + node->set_abstract(ret); + changed = true; + } + } + return changed; +} + +// expand tuples in graph parameters +static std::vector ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph, + const std::vector ¶ms) { + MS_EXCEPTION_IF_NULL(mng); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector new_params; + for (const auto ¶m : params) { + MS_EXCEPTION_IF_NULL(param); + auto param_abs = param->abstract(); + MS_EXCEPTION_IF_NULL(param_abs); + + if (param_abs->isa()) { + MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info()); + } + + if (!param_abs->isa()) { + new_params.emplace_back(param); + continue; + } + + std::vector new_param; + std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; + auto abs_tuple = dyn_cast(param_abs); + for (auto &elem : abs_tuple->elements()) { + auto np = std::make_shared(func_graph); + np->set_abstract(elem); + new_param.emplace_back(np); + } + (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end()); + auto new_tuple = func_graph->NewCNode(inputs); + (void)mng->Replace(param, new_tuple); + + auto expand_param = ExpandTuplesP(mng, func_graph, new_param); + (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end()); + } + return new_params; +} + +// expand tuples in graph applies +static std::vector ExpandTuplesC(const FuncGraphPtr &graph, const std::vector &inputs) { + MS_EXCEPTION_IF_NULL(graph); + + std::vector new_inputs; + for (const auto &input : inputs) { + MS_EXCEPTION_IF_NULL(input); + + auto input_abs = input->abstract(); + MS_EXCEPTION_IF_NULL(input_abs); + + if (input_abs->isa()) { + auto abstract_tag = dyn_cast(input_abs); + if (abstract_tag->element()->isa()) { + MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info()); + } + } + + if (!input_abs->isa()) { + new_inputs.emplace_back(input); + continue; + } + + int idx = 0; + std::vector new_input; + auto abs_tuple = dyn_cast(input_abs); + for (auto &elem : abs_tuple->elements()) { + auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); + AbstractBasePtr aptr = std::make_shared(std::make_shared(idx)); + c_node->input(2)->set_abstract(aptr); + c_node->set_abstract(elem); + new_input.emplace_back(c_node); + idx++; + } + + auto expand_tuple = ExpandTuplesC(graph, new_input); + (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end()); + } + + return new_inputs; +} + +// remove most uses of tuples from the graph parameters & apply inputs +// tuples that are returned will be kept +// tuples in CNode's inputs: AbstractTuple (a, b ,c) --> +// CNode("tuple_getitem", (a,b,c), 0) +// CNode("tuple_getitem", (a,b,c), 1) +// CNode("tuple_getitem", (a,b,c), 2) +// tuples in Graph's parameters: AbstractTuple (a, b, c) --> +// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) +// cppcheck-suppress unusedFunction +void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(root); + + // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var + AnfNodeSet all_node = manager->all_nodes(); + for (auto &node : all_node) { + auto cnode = node->cast(); + if (cnode == nullptr) { + continue; + } + + const auto &inputs = cnode->inputs(); + + // Bypass the first input in inputs as it's fn. + if (!IsValueNode(inputs[0])) { + std::vector expand_inputs; + (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end()); + + auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs); + if (new_inputs != expand_inputs) { + std::vector cnode_inputs{inputs[0]}; + (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end()); + + MS_EXCEPTION_IF_NULL(node->func_graph()); + auto new_node = node->func_graph()->NewCNode(cnode_inputs); + new_node->set_abstract(node->abstract()); + + (void)manager->Replace(node, new_node); + } + // Bypass the first 2 inputs in inputs as it's [partial, fn]. + } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode(inputs[1])) { + std::vector expand_inputs; + (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end()); + + auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs); + if (new_inputs != expand_inputs) { + std::vector cnode_inputs{inputs[0], inputs[1]}; + (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end()); + + MS_EXCEPTION_IF_NULL(cnode->func_graph()); + auto new_node = cnode->func_graph()->NewCNode(cnode_inputs); + new_node->set_abstract(cnode->abstract()); + + (void)manager->Replace(node, new_node); + } + } + } + + FuncGraphSet all_graph = manager->func_graphs(); + for (auto &func_graph : all_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); + manager->SetParameters(func_graph, expand_p); + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/clean.h b/mindspore/ccsrc/frontend/optimizer/clean.h new file mode 100644 index 0000000000..54faabaa63 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/clean.h @@ -0,0 +1,43 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ + +#include +#include "ir/anf.h" +#include "frontend/operator/ops.h" +#include "utils/any.h" +#include "ir/manager.h" +#include "abstract/dshape.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { + +// Remove the class type from graphs +bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); + +// Remove most uses of tuples from the graph +// tuples that are returned will be kept +void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); + +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/control_depend.cc b/mindspore/ccsrc/frontend/optimizer/control_depend.cc new file mode 100644 index 0000000000..8cc9bdb7f4 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/control_depend.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/control_depend.h" + +#include +#include +#include +#include +#include + +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +std::vector DoControlDepend(const FuncGraphPtr &graph, const CNodePtr &return_node, + const std::vector &effect_index, const std::vector &cnodes) { + std::vector depend_nodes{NewValueNode(prim::kPrimDepend), return_node->input(1)}; + std::vector make_tuple{NewValueNode(prim::kPrimMakeTuple)}; + size_t effect_size = effect_index.size(); + for (size_t i = 0; i < effect_size; i++) { + size_t pre_index = 0; + if (i > 0) { + pre_index = effect_index[i - 1] + 1; + } + size_t this_index = effect_index[i]; + size_t last_index = cnodes.size() - 2; + if (i < effect_size - 1) { + last_index = effect_index[i + 1]; + } + + if (this_index > pre_index) { + std::vector pre_segment; + for (size_t k = pre_index; k < this_index; k++) { + // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. + if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || + IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { + continue; + } + pre_segment.push_back(cnodes[k]); + } + auto roots = FindRoots(pre_segment); + for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { + AnfNodePtr control_depend = + graph->NewCNode({NewValueNode(prim::kPrimControlDepend), *iter, cnodes[this_index]}); + make_tuple.push_back(control_depend); + } + } + if (last_index > this_index) { + std::vector last_segment; + for (size_t k = this_index + 1; k <= last_index; k++) { + // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. + if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || + IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { + continue; + } + last_segment.push_back(cnodes[k]); + } + auto leaves = FindLeaves(last_segment); + for (auto iter = leaves->begin(); iter != leaves->end(); (void)iter++) { + AnfNodePtr control_depend = + graph->NewCNode({NewValueNode(prim::kPrimControlDepend), cnodes[this_index], *iter}); + make_tuple.push_back(control_depend); + } + } + } + depend_nodes.push_back(graph->NewCNode(make_tuple)); + return depend_nodes; +} + +void AddControlDepend(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + std::list orders = graph->GetOrderedCnodes(); + std::vector cnodes(orders.begin(), orders.end()); + size_t cnodes_size = cnodes.size(); + // get effect index of cnodes + std::vector effect_index{}; + for (size_t i = 0; i < cnodes_size; i++) { + if (graph->HasEffect(cnodes[i])) { + effect_index.push_back(i); + } + } + if (effect_index.empty()) { + return; + } + AnfNodePtr last_node = cnodes[cnodes_size - 1]; + CNodePtr return_node; + if (last_node->isa()) { + return_node = last_node->cast(); + } + MS_EXCEPTION_IF_NULL(return_node); + if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) { + MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode."; + } + if (return_node->inputs().size() < 2) { + MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2."; + } + + auto depend_node_inputs = DoControlDepend(graph, return_node, effect_index, cnodes); + auto depend_cnode = graph->NewCNode(depend_node_inputs); + depend_cnode->set_abstract(depend_cnode->input(1)->abstract()); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (!manager->Replace(return_node->input(1), depend_cnode)) { + MS_LOG(EXCEPTION) << "Depend replace node failed"; + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/control_depend.h b/mindspore/ccsrc/frontend/optimizer/control_depend.h similarity index 100% rename from mindspore/ccsrc/optimizer/control_depend.h rename to mindspore/ccsrc/frontend/optimizer/control_depend.h diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc new file mode 100644 index 0000000000..4d968d6d74 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/cse.cc @@ -0,0 +1,231 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/cse.h" +#include +#include +#include +#include "./common.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractFunctionPtr; + +BasePtr AbsOf(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto node_abs = node->abstract(); + // in testcase: TestOptOpt.CSE, node->abstract() is null; + if (node_abs == nullptr) { + return kAnyValue; + } + + return node_abs; +} + +bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { + bool changed = false; + for (FuncGraphPtr fg : manager->func_graphs()) { + MS_EXCEPTION_IF_NULL(fg); + std::vector order_group; + std::unordered_map> groups; + std::unordered_map hashes; + + std::vector toposet = TopoSort(fg->get_return()); + for (auto node : toposet) { + MS_EXCEPTION_IF_NULL(node); + if (hashes.find(node) != hashes.end()) { + continue; + } + + std::size_t h = 0; + if (node->isa()) { + ValueNodePtr value_node = node->cast(); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + h = hash_combine(value->hash(), (AbsOf(value_node)->hash())); + } else if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + size_t init = 0; + h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) { + return hash_combine(hash, hashes[node_in]); + }); + } else if (node->isa()) { + h = node->hash(); + } else { + MS_LOG(ERROR) << "Unknow node type"; + } + + hashes[node] = h; + if (groups.find(h) == groups.end()) { + std::vector innervec({node}); + groups[h] = innervec; + order_group.emplace_back(h); + } else { + groups[h].push_back(node); + } + } + + changed = DoReplace(manager, order_group, &groups) || changed; + } + + return changed; +} +// The op like print, summary, or the op do not has true output, and always as a depend node input. +static bool HasSideEffect(const AnfNodePtr &node) { + auto prim = GetCNodePrimitive(node); + if (prim == nullptr) { + return false; + } + auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT); + if (side_effect_v != nullptr && side_effect_v->isa()) { + return GetValue(side_effect_v); + } + return false; +} +// If true do not merge the node. +bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { + bool has_random_effect = false; + auto prim_main = GetCNodePrimitive(main); + auto prim_node = GetCNodePrimitive(node); + // if has random effect, when generate by different op (not same object), do not merge. + if (prim_main != nullptr) { + if (prim_main == prim_node) { + return false; + } + auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); + if (effect_val != nullptr && effect_val->isa()) { + has_random_effect = GetValue(effect_val); + } + } + return has_random_effect; +} + +bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { + MS_EXCEPTION_IF_NULL(main); + MS_EXCEPTION_IF_NULL(node); + + if (main->isa() && node->isa()) { + auto main_value = GetValueNode(main); + auto node_value = GetValueNode(node); + return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); + } else if (main->isa() && node->isa()) { + auto c_main = main->cast(); + auto c_node = node->cast(); + // When appsame is true, check if has side effect, do not merge. + if (check_side_effect && HasSideEffect(main)) { + return false; + } + const auto &inp1 = c_main->inputs(); + const auto &inp2 = c_node->inputs(); + if (inp1.size() != inp2.size()) { + return false; + } + for (size_t j = 0; j < inp1.size(); j++) { + auto inp1_j = inp1[j]; + auto inp2_j = inp2[j]; + MS_EXCEPTION_IF_NULL(inp1_j); + MS_EXCEPTION_IF_NULL(inp2_j); + if (!(*inp1_j == *inp2_j)) { + // Handle the case of two different Tensor, but with the same value + if (IsValueNode(inp1_j) && IsValueNode(inp2_j)) { + auto tensor1 = GetValueNode(inp1_j); + auto tensor2 = GetValueNode(inp2_j); + if (tensor1->ValueEqual(*tensor2)) { + continue; + } + } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) { + // When the same side effect node as another two nodes' inputs, we still merge the node. + // Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the + // node. + if (CheckReplace(inp1_j, inp2_j, false)) { + continue; + } + } + return false; + } + } + // When appsame is true, check if has random effect do not merge + if (CheckRandomEffect(c_main, c_node)) { + return false; + } + return true; + } + // a parameter node. + return false; +} + +bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, + std::unordered_map> *groups) const { + bool changes = false; + std::set clear_set; + for (auto &h : order_group) { + std::vector &group = (*groups)[h]; + // If there are more than 2 node in that group, they may be same common expression can be eliminated. + if (group.size() > 1) { + for (size_t k = 0; k < group.size() - 1; k++) { + AnfNodePtr main = group[k]; + MS_EXCEPTION_IF_NULL(main); + + // When all node in group has been replaced + // or a valuenode node, skip compare in group + if ((k + 1 + clear_set.size() == group.size()) || (k > 0 && main->isa())) { + break; + } + + // skip node has been replaced + if (clear_set.find(k) != clear_set.end()) { + continue; + } + + // Compare with rest elements in this group. + for (size_t i = k + 1; i < group.size(); i++) { + auto node = group[i]; + MS_EXCEPTION_IF_NULL(node); + + if (clear_set.find(i) != clear_set.end()) { + continue; + } + if (main->func_graph() != node->func_graph()) { + continue; + } + if (CheckReplace(node, main)) { + changes = true; + (void)manager->Replace(node, main); + (void)clear_set.insert(i); + } + } + } + clear_set.clear(); + } + } + + return changes; +} + +bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const { + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(root); + + return BuildOrderGroupAndDoReplace(manager); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/cse.h b/mindspore/ccsrc/frontend/optimizer/cse.h new file mode 100644 index 0000000000..140f592715 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/cse.h @@ -0,0 +1,61 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/manager.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { + +// Common subexpression elimination. +class CSE { + public: + explicit CSE(bool report_changes = true) : report_changes_(report_changes) {} + virtual ~CSE() = default; + + bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { + bool chg = Cse(root, optimizer->resource()->manager()); + return chg && report_changes_; + } + + virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; + + virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; + + bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const; + + private: + bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const; + bool DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, + std::unordered_map> *groups) const; + bool report_changes_; +}; + +BasePtr AbsOf(const AnfNodePtr &node); +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc new file mode 100644 index 0000000000..c157777040 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/graph_kernel_reuse.h" +#include +#include +#include +#include "./common.h" +#include "utils/graph_utils.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { + +bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) { + if (a->abstract() && b->abstract()) { + auto a_type = a->abstract()->GetTypeTrack(); + auto b_type = b->abstract()->GetTypeTrack(); + + if (a_type != b_type) { + return false; + } + + auto a_shape = a->abstract()->GetShapeTrack(); + auto b_shape = b->abstract()->GetShapeTrack(); + if (a_shape != nullptr && a_shape == b_shape) { + return true; + } + + if (a_shape != nullptr && b_shape != nullptr && a_shape->isa() && + b_shape->isa()) { + return a_shape->cast()->shape() == b_shape->cast()->shape(); + } + } + return false; +} + +bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { + bool changed = false; + auto fgs = manager->func_graphs(); + for (FuncGraphPtr &fg : fgs) { + if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + continue; + } + std::string key = GetValue(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) { + if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) { + FuncGraphPtr new_fg = nullptr; + for (auto &cfg : graph_kernel_ops[key]) { + // If two graphs have different size then continue + auto fg_topos = TopoSort(fg->get_return()); + auto cfg_topos = TopoSort(cfg->get_return()); + if (fg_topos.size() != cfg_topos.size()) { + continue; + } + + // Compare const tensor + bool has_same = true; + for (size_t i = 0; i < fg_topos.size(); ++i) { + if (IsValueNode(fg_topos[i])) { + if (!IsValueNode(cfg_topos[i])) { + has_same = false; + break; + } + + auto tensor1 = GetValueNode(fg_topos[i]); + auto tensor2 = GetValueNode(cfg_topos[i]); + if (!tensor1->ValueEqual(*tensor2)) { + has_same = false; + break; + } + } + } + + if (!has_same) { + continue; + } + + auto fg_input = fg->parameters(); + auto cfg_input = cfg->parameters(); + if (fg_input.size() != cfg_input.size()) { + continue; + } + // Compare input + for (size_t i = 0; i < fg_input.size(); ++i) { + if (!CompareNode(fg_input[i], cfg_input[i])) { + has_same = false; + break; + } + } + if (!has_same) { + continue; + } + + // Compare output + if (!CompareNode(fg->output(), cfg->output())) { + continue; + } + + // Find reusable fg + new_fg = cfg; + break; + } + + if (new_fg != nullptr) { + // Replace current fg with existing fg + auto users = fg->func_graph_cnodes_index(); + for (auto &iter : users) { + auto cnode = iter.first->first->cast(); + auto new_input = cnode->inputs(); + auto main_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(main_graph); + if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) { + new_input[1] = NewValueNode(new_fg); + } else { + new_input[0] = NewValueNode(new_fg); + } + auto new_cnode = main_graph->NewCNode(new_input); + manager->Replace(iter.first->first, new_cnode); + changed = true; + } + + } else { + // Add current fg to map + graph_kernel_ops[key].push_back(fg); + } + } + } else { + graph_kernel_ops[key] = {fg}; + } + } + + return changed; +} + +bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) { + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(root); + + return DoReplace(manager); +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h new file mode 100644 index 0000000000..a79ef3ce6d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H +#define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H + +#include +#include +#include +#include "mindspore/ccsrc/backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { + +// Common subexpression elimination. +class GraphKernelReuse { + public: + GraphKernelReuse() : count(0) {} + virtual ~GraphKernelReuse() = default; + + bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { + bool chg = ReuseGraphKernel(root, optimizer->resource()->manager()); + return chg; + } + + bool CompareNode(const AnfNodePtr a, const AnfNodePtr other); + bool DoReplace(const FuncGraphManagerPtr manager); + + bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager); + + private: + std::unordered_map> graph_kernel_ops; + int count; +}; + +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc new file mode 100644 index 0000000000..efc3795a4c --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/arithmetic_simplify.h" +#include "frontend/optimizer/irpass/branch_culling.h" +#include "frontend/optimizer/irpass/cast_eliminate.h" +#include "frontend/optimizer/irpass/convert.h" +#include "frontend/optimizer/irpass/env_item_eliminate.h" +#include "frontend/optimizer/irpass/grad_var_prepare.h" +#include "frontend/optimizer/irpass/gradient_eliminate.h" +#include "frontend/optimizer/irpass/inline.h" +#include "frontend/optimizer/irpass/incorporate_call.h" +#include "frontend/optimizer/irpass/incorporate_getitem.h" +#include "frontend/optimizer/irpass/item_tuple_eliminate.h" +#include "frontend/optimizer/irpass/mark_interface_fusion.h" +#include "frontend/optimizer/irpass/merge_addn.h" +#include "frontend/optimizer/irpass/minmax_grad.h" +#include "frontend/optimizer/irpass/param_replace.h" +#include "frontend/optimizer/irpass/partial_eliminate.h" +#include "frontend/optimizer/irpass/reduce_eliminate.h" +#include "frontend/optimizer/irpass/ref_eliminate.h" +#include "frontend/optimizer/irpass/reshape_eliminate.h" +#include "frontend/optimizer/irpass/special_op_eliminate.h" +#include "frontend/optimizer/irpass/specialize_transform.h" +#include "frontend/optimizer/irpass/symbol_resolver.h" +#include "frontend/optimizer/irpass/tile_eliminate.h" +#include "frontend/optimizer/irpass/transpose_eliminate.h" +#include "frontend/optimizer/opt.h" +#include "frontend/optimizer/irpass/indexed_slices_eliminate.h" + +namespace mindspore { +namespace opt { +namespace irpass { +OptimizeIRPassLib::OptimizeIRPassLib() { + arithmetic_simplify_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify", + {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, + prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); + arithmetic_simplify2_ = + MakeSubstitution(std::make_shared(), "arithmetic_simplify2", {prim::kPrimMul}); + special_op_eliminate_ = + MakeSubstitution(std::make_shared(), "special_op_eliminate", + {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, + prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv}); + zero_like_fill_zero_ = + MakeSubstitution(std::make_shared(), "zero_like_fill_zero", prim::kPrimZerosLike); + adjust_all_reduce_mul_add_ = + MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); + + // ops eliminate + item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", + {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); + tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); + cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); + reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); + transpose_eliminate_ = + MakeSubstitution(std::make_shared(), "transpose_eliminate", prim::kPrimTranspose); + reduce_eliminate_ = MakeSubstitution( + std::make_shared(), "reduce_eliminate", + {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); + partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); + same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); + check_bprop_eliminate_ = + MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); + reset_defer_inline_ = + MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); + depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); + + // Env Item Eliminate + env_get_item_eliminate_ = + MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); + new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_ = + MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), + "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); + + // Ref eliminate + make_ref_eliminate_ = + MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); + get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", + {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); + get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", + {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); + + replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", + IsValueNode, opt::FORCE_RENORM); + replace_old_param_ = MakeSubstitution(std::make_shared(), "replace_old_param", IsParam); + // Gradient transforms + expand_jprim_ = MakeSubstitution(std::make_shared(), "expand_jprim", prim::kPrimJ); + minmaximum_grad_ = MakeSubstitution(std::make_shared(), "minmaximum_grad", prim::kPrimTupleGetItem); + + // branch culling + switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); + float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), + "float_tuple_getitem_switch", prim::kPrimTupleGetItem); + float_env_getitem_switch_ = + MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); + convert_switch_replacement_ = + MakeSubstitution(std::make_shared(), "convert_switch_replacement", IsCNodeDup); + + // Addn + merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); + addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); + + // inline + inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); + replace_applicator_ = + MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); + specialize_transform_ = + MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); + + // Incorporation + incorporate_getitem_set_ = + MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); + incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared(), + "incorporate_getitem_from_param", IsCNodeGraphKernel); + incorporate_call_ = MakeSubstitution(std::make_shared(), "incorporate_call", IsCNodeDup); + incorporate_call_switch_ = + MakeSubstitution(std::make_shared(), "incorporate_call_switch", IsCNodeDup); + + // Virtual Dataset + virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), + "virtual_dataset_eliminate", prim::kPrimVirtualDataset); + + // Convert + print_tuple_wrapper_ = + MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); + + // Unused parameter eliminate + unused_parameter_eliminate_ = + MakeSubstitution(std::make_shared(), "unused_parameter_eliminate", IsCNodeGraphKernel); + unused_output_eliminate_ = + MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); + + // AddN eliminate + addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); + + // Mark interface fusion + mark_interface_fusion_ = + MakeSubstitution(std::make_shared(), "mark_interface_fusion", prim::kPrimSelect); + + // IndexedSlices Eliminate + indexed_slices_eliminate_ = MakeSubstitution( + std::make_shared(), "indexed_slices_eliminate", + {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); +} + +ResolveIRPassLib::ResolveIRPassLib() { + resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); + resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); +} + +InferenceOptPrepareLib::InferenceOptPrepareLib() { + grad_var_prepare_ = MakeSubstitution(std::make_shared(), "grad_var_prepare", IsCNode); +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h new file mode 100644 index 0000000000..4af8c0789d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -0,0 +1,192 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ + +#include + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/opt.h" +#include "ir/visitor.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// the collection of irpass for optimie action +class OptimizeIRPassLib { + public: + OptimizeIRPassLib(); + ~OptimizeIRPassLib() = default; + + SubstitutionPtr arithmetic_simplify_; + SubstitutionPtr arithmetic_simplify2_; + SubstitutionPtr special_op_eliminate_; + SubstitutionPtr zero_like_fill_zero_; + SubstitutionPtr adjust_all_reduce_mul_add_; + + // ops eliminate + SubstitutionPtr item_tuple_eliminate_; + SubstitutionPtr tile_eliminate_; + SubstitutionPtr cast_eliminate_; + SubstitutionPtr reshape_eliminate_; + SubstitutionPtr transpose_eliminate_; + SubstitutionPtr reduce_eliminate_; + SubstitutionPtr partial_eliminate_; + SubstitutionPtr same_eliminate_; + SubstitutionPtr check_bprop_eliminate_; + SubstitutionPtr reset_defer_inline_; + SubstitutionPtr depend_value_elim_; + + // Env Item Eliminate + SubstitutionPtr env_get_item_eliminate_; + SubstitutionPtr new_env_get_item_; + SubstitutionPtr incorporate_env_getitem_; + SubstitutionPtr incorporate_env_getitem_switch_; + + // Ref eliminate + SubstitutionPtr make_ref_eliminate_; + SubstitutionPtr get_ref_param_eliminate_; + SubstitutionPtr get_make_ref_eliminate_; + SubstitutionPtr replace_refkey_by_param_; + SubstitutionPtr replace_old_param_; + + // Branch culling + SubstitutionPtr switch_simplify_; + SubstitutionPtr float_tuple_getitem_switch_; + SubstitutionPtr float_env_getitem_switch_; + SubstitutionPtr convert_switch_replacement_; + + // AddN + SubstitutionPtr merge_addn_; + SubstitutionPtr addn_zero_filter_; + + // Gradient irpasses + SubstitutionPtr expand_jprim_; + SubstitutionPtr minmaximum_grad_; + + // inline + SubstitutionPtr inline_; + SubstitutionPtr replace_applicator_; + SubstitutionPtr specialize_transform_; + + // Incorporation + SubstitutionPtr incorporate_getitem_set_; + SubstitutionPtr incorporate_getitem_from_param_; + SubstitutionPtr incorporate_call_; + SubstitutionPtr incorporate_call_switch_; + + // virtual dataset + SubstitutionPtr virtual_dataset_eliminate_; + + // Convert + SubstitutionPtr print_tuple_wrapper_; + + // Unused parameter eliminate + SubstitutionPtr unused_parameter_eliminate_; + SubstitutionPtr unused_output_eliminate_; + + // AddN eliminate + SubstitutionPtr addn_eliminate_; + + // Fusion + SubstitutionPtr mark_interface_fusion_; + + // IndexedSlices Eliminate + SubstitutionPtr indexed_slices_eliminate_; +}; + +// the collection of irpass for resolve action +class ResolveIRPassLib { + public: + ResolveIRPassLib(); + ~ResolveIRPassLib() = default; + + SubstitutionPtr resolver_resolve_; + SubstitutionPtr resolver_getattr_; +}; + +class InferenceOptPrepareLib { + public: + InferenceOptPrepareLib(); + ~InferenceOptPrepareLib() = default; + SubstitutionPtr grad_var_prepare_; +}; + +// predicate functions +inline bool IsNode(const AnfNodePtr &) { return true; } + +inline bool IsCNode(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} + +inline bool IsVNode(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} + +inline bool IsParam(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} + +// Check if CNode Input 0 is Func Graph +inline bool IsCNodeGraph(const AnfNodePtr &node) { + if (node == nullptr || !node->isa()) { + return false; + } + + auto inp0 = node->cast()->input(0); + return IsValueNode(inp0); +} + +// Check if CNode Input 0 is Func Graph of graph kernel. +inline bool IsCNodeGraphKernel(const AnfNodePtr &node) { + if (node == nullptr || !node->isa()) { + return false; + } + + auto inp0 = node->cast()->input(0); + if (IsValueNode(inp0)) { + auto fg = GetValueNode(inp0); + if (fg == nullptr) { + return false; + } + return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + } + return false; +} + +// Check if CNode Input 0 is CNode +inline bool IsCNodeDup(const AnfNodePtr &node) { + if (node == nullptr || !node->isa()) { + return false; + } + + auto inp0 = node->cast()->input(0); + return (inp0 != nullptr) && inp0->isa(); +} +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc new file mode 100644 index 0000000000..83f7fae582 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -0,0 +1,680 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "frontend/optimizer/irpass/arithmetic_simplify.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/prim_eliminate.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} +// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} +AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimScalarMul)(node); + + if (is_zero_) { + return NewValueNode(zero_); + } + if (is_one_) { + return x_; + } + return nullptr; +} + +void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { + if (is_one_ || node->isa()) { + x_ = node; + return; + } + + AnfVisitor::Visit(node); + if (!is_one_) { + x_ = node; + } +} + +void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (*value == *zero_) { + is_zero_ = true; + } else if (*value == *one_) { + is_one_ = true; + } +} + +void MultiplyByZeroOrOne::Reset() { + x_ = nullptr; + is_one_ = false; + is_zero_ = false; +} + +// Support class used for checking if all values of a Tensor are equal `check_value_` +// Supported data types: double, float/float32, int/int32 +bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > FLT_EPSILON) { + return false; + } + } + return true; + } else if (tensor_type == TypeId::kNumberTypeFloat64) { + double *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > DBL_EPSILON) { + return false; + } + } + return true; + } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { + int *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] != check_value_) { + return false; + } + } + return true; + } + // input Data Types is not supported + return false; +} + +bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { + return false; + } + return IsTensorConstant(value); +} + +void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { + if (!node->isa()) { + return nullptr; + } + + auto value = node->cast()->value(); + + if (!value->isa()) { + return nullptr; + } + + tensor::TensorPtr tensor_ptr = dyn_cast(value); + return tensor_ptr->data_c(); +} + +// Make a new tensor (when possible) with the same shape as of `node` +// If x is nullptr then fill new tensor will "0" +// If x is a tensor with empty shape then fill new tensor with the single value of x +// If x is a tensor with same shape as `node` then return x as result +AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { + if ((node->abstract() == nullptr) || !node->abstract()->isa()) { + return nullptr; + } + + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + if (x == nullptr) { + std::memset(data, 0, mem_size); + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + // x is not nullptr + if (x->isa()) { + if ((x->abstract() == nullptr) || !x->abstract()->isa()) { + return nullptr; + } + auto x_abstract = x->abstract()->cast(); + std::vector x_shape = x_abstract->shape()->shape(); + + if (x_shape != tensor_shape) { + return nullptr; + } + return x; + } + + if (!x->isa()) { + return nullptr; + } + auto x_value = x->cast()->value(); + if (!x_value->isa()) { + return nullptr; + } + + auto x_tensor_ptr = dyn_cast(x_value); + + if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { + return nullptr; + } + char *source_data = reinterpret_cast(GetPointerToTensorData(x)); + if (x_tensor_ptr->DataSize() == 1) { + for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { + memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); + } + } else { + memcpy(data, source_data, mem_size); + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; +} + +// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} +AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_zero_) { + if (x_->func_graph() != node->func_graph()) { + return nullptr; + } + return NewTensorFilledWithData(node); + } + return nullptr; +} + +void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { + if (is_zero_) { + x_ = node; + return; + } + + if (IsParam(node)) { + x_ = node; + return; + } + + if (IsCNode(node)) { + CNodePtr cnode = node->cast(); + if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { + is_zero_ = true; + return; + } + x_ = node; + return; + } + auto value = node->cast()->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = node; +} + +void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = vnode; +} +void TensorMultiplyByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} +AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_one_) { + return NewTensorFilledWithData(node, x_); + } + return nullptr; +} + +void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { + if (is_one_) { + x_ = node; + return; + } + + if (IsParam(node) || IsCNode(node)) { + x_ = node; + return; + } + + auto value = node->cast()->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = node; +} + +void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = vnode; +} +void TensorMultiplyByOne::Reset() { + x_ = nullptr; + is_one_ = false; +} + +// {prim::kPrimScalarAdd, X, 0} +// {prim::kPrimScalarAdd, 0, X} +AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimScalarAdd)(node); + + if (is_zero_) { + return x_; + } + return nullptr; +} + +void AddByZero::Visit(const AnfNodePtr &node) { + if (node->isa() && + ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { + is_zero_ = true; + return; + } + + x_ = node; +} + +void AddByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, +// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} +AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimTensorAdd)(node); + + if (is_zero_) { + return x_; + } + return nullptr; +} + +void TensorAddByZero::Visit(const AnfNodePtr &node) { + if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { + is_zero_ = true; + return; + } + + x_ = node; +} + +void TensorAddByZero::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } +} + +void TensorAddByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} +AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { + return nullptr; + } + + // {PrimMomentum, {...}, Y, Z, Xs} + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { + return nullptr; + } + auto y = inputs[2]; + auto z = inputs[3]; + + // {kPrimZerosLike, X} + if (inputs[1]->cast()->size() != 2) { + return nullptr; + } + + // {prim::kPrimMakeTuple, Z, Y} + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); +} + +// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> +// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} +// Support function to multiply two constant tensors: partially support broadcasting shapes +template +void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, + void **out_data, int out_data_size) { + T *data_1 = reinterpret_cast(in_data_1); + T *data_2 = reinterpret_cast(in_data_2); + T *data_out = new T[out_data_size]; + + if (in_data_1_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[i]; + } + } + if (in_data_2_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[i]; + } + } + *out_data = reinterpret_cast(data_out); + return; +} + +AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, + const AnfNodePtr &node_3) { + if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || + (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { + return nullptr; + } + + auto value_1 = GetValueNode(vnode_1); + auto value_2 = GetValueNode(vnode_2); + + if (!value_1->isa() || !value_2->isa()) { + return nullptr; + } + + auto tensor_ptr_1 = dyn_cast(value_1); + auto tensor_ptr_2 = dyn_cast(value_2); + + auto tensor_1_abstract = vnode_1->abstract()->cast(); + auto tensor_2_abstract = vnode_1->abstract()->cast(); + auto tensor_3_abstract = node_3->abstract()->cast(); + + TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); + TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); + TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); + + if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || + (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { + return nullptr; + } + + std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); + + int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); + + if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { + return nullptr; + } + if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { + return nullptr; + } + + void *data_out; + + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), + &data_out, data_out_size); + } else { + if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + // Un-support data types + return nullptr; + } + } + } + + auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); + size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + memcpy(data, data_out, mem_size); + + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; +} + +AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + // {prim::kPrimMul, Tensor1, {...}} + AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); + if (vnode_ == nullptr || c_p_node_ == nullptr) { + return nullptr; + } + + if (!IsCNode(c_p_node_)) { + return nullptr; + } + + auto tensor1 = vnode_; + auto mul = c_p_node_->cast(); + + Reset(); + // {prim::kPrimMul, Tensor2, {...}} + AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); + if (vnode_ == nullptr || c_p_node_ == nullptr) { + return nullptr; + } + auto tensor2 = vnode_; + auto c_p_node = c_p_node_; + + auto PrimMul = GetValueNode(mul->input(0)); + auto fg = node->func_graph(); + + auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); + if (new_mul_tensor == nullptr) { + auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); + return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); + } + return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); +} + +void ConstantDuplicateMul::Visit(const AnfNodePtr &node) { + if (IsValueNode(node)) { + vnode_ = node; + } + + if (IsCNode(node) || IsParam(node)) { + c_p_node_ = node; + } +} + +void ConstantDuplicateMul::Reset() { + vnode_ = nullptr; + c_p_node_ = nullptr; +} + +AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (!IsValueNode(inputs[2])) { + return nullptr; + } + auto scalar = GetValueNode(inputs[2]); + if (scalar->isa() && GetValue(scalar) == 1.0) { + return inputs[1]; + } else if (scalar->isa() && GetValue(scalar) == 1) { + return inputs[1]; + } + return nullptr; +} + +// grad = AllReduce(grad) / worker_number +// grad = grad + weight * decy +// -> +// grad = grad + weight * decy +// grad = AllReduce(grad) / worker_number +// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> +// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} +AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + // {prim::kPrimAddN, Zs} + if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { + return nullptr; + } + auto addn = node->cast(); + if (addn->size() != 2) { + return nullptr; + } + AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); + if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { + return nullptr; + } + auto addn_maketuple = addn->input(1); + + auto fg = all_reduce_fg_; + // addn inputs cross the graph, make the inputs same as allreduce node. + if (z_->isa() && fg != z_->func_graph()) { + auto cnode_z = z_->cast(); + z_ = NewCNode(cnode_z->inputs(), fg); + } + + auto addn_op_node = addn->input(0); + auto make_tuple_op_node = addn->input(1)->cast()->input(0); + + AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); + AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); + AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); + AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); + ProcessDependEdge(fg, addn_maketuple, all_reduce); + return mul; +} + +void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, + const AnfNodePtr &new_node) { + // If has dynamic loss scale. + auto &users_map = fg->manager()->node_users(); + auto it = users_map.find(mul_cnode_); + if (it != users_map.end()) { + auto users = it->second; + for (auto &user_pair : users) { + auto node = user_pair.first; + if (node != addn_maketuple) { + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + fg->manager()->SetEdge(node, user_pair.second, new_node); + } + } + } + } +} + +void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) { + if (level_ == 0) { + level_ = 1; + is_reduce_match_ = false; + // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} + AnfVisitor::Match(prim::kPrimMul)(node); + level_ = 0; + if (is_reduce_match_) { + mul_ = node->cast()->input(0); + mul_cnode_ = node->cast(); + y_ = tmp_; + } else { + z_ = node; + } + } + + if (level_ == 1) { + // {prim::kPrimAllReduce, X} + if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { + auto cnode = node->cast(); + if (cnode->size() > 1) { + all_reduce_ = cnode->input(0); + x_ = cnode->input(1); + is_reduce_match_ = true; + all_reduce_fg_ = cnode->func_graph(); + } + } else { + tmp_ = node; + } + } +} + +void AdjustAllReduceMulAdd::Reset() { + level_ = 0; + is_reduce_match_ = false; + x_ = nullptr; + y_ = nullptr; + z_ = nullptr; + tmp_ = nullptr; + all_reduce_fg_ = nullptr; +} + +AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; +} + +AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h new file mode 100644 index 0000000000..3088231396 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h @@ -0,0 +1,259 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ + +#include +#include +#include + +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/prim_eliminate.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} +// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} +class MultiplyByZeroOrOne : public AnfVisitor { + public: + MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} + ~MultiplyByZeroOrOne() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}, is_one_{false}; + ValuePtr zero_, one_; + AnfNodePtr x_{nullptr}; +}; + +// Support class used for checking if all values of a Tensor are equal `check_value_` +// Supported data types: double, float/float32, int/int32 +class CheckTensorConstant { + public: + explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} + ~CheckTensorConstant() = default; + + bool IsTensorConstant(const ValuePtr &value); + bool IsTensorScalarConstant(const ValuePtr &value); + + private: + int check_value_; +}; + +class TensorMultiplyBase : public AnfVisitor { + protected: + void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); + + // Make a new tensor (when possible) with the same shape as of `node` + // If x is nullptr then fill new tensor will "0" + // If x is a tensor with empty shape then fill new tensor with the single value of x + // If x is a tensor with same shape as `node` then return x as result + AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); + + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} +class TensorMultiplyByZero : public TensorMultiplyBase { + public: + TensorMultiplyByZero() : zero_(MakeValue(0)) {} + ~TensorMultiplyByZero() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}; + ValuePtr zero_; +}; + +// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} +class TensorMultiplyByOne : public TensorMultiplyBase { + public: + TensorMultiplyByOne() {} + ~TensorMultiplyByOne() override = default; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_one_{false}; +}; + +// {prim::kPrimScalarAdd, X, 0} +// {prim::kPrimScalarAdd, 0, X} +class AddByZero : public AnfVisitor { + public: + AddByZero() : zero_(MakeValue(0)) {} + ~AddByZero() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Reset(); + + private: + bool is_zero_{false}; + ValuePtr zero_; + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, +// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} +class TensorAddByZero : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}; + AnfNodePtr x_{nullptr}; +}; + +// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} +class OptUpdateZeroTensor : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; + +// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> +// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} +class ConstantDuplicateMul : public AnfVisitor { + public: + // Support function to multiply two constant tensors: partially support broadcasting shapes + template + void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, + int out_data_size); + + AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Reset(); + + private: + AnfNodePtr vnode_; + AnfNodePtr c_p_node_; +}; + +class PowerOneEliminate : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; + +// grad = AllReduce(grad) / worker_number +// grad = grad + weight * decy +// -> +// grad = grad + weight * decy +// grad = AllReduce(grad) / worker_number + +// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> +// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} +class AdjustAllReduceMulAdd : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); + void Visit(const AnfNodePtr &node) override; + void Reset(); + + private: + int level_{0}; + bool is_reduce_match_{false}; + AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; + AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; + FuncGraphPtr all_reduce_fg_{nullptr}; +}; + +class ArithmeticSimplify : public OptimizerCaller { + public: + ArithmeticSimplify() + : multiply_by_zero_or_one_(std::make_shared()), + tensor_multiply_by_one_(std::make_shared()), + add_by_zero_(std::make_shared()), + tensor_add_by_zero_(std::make_shared()), + identity_(std::make_shared(prim::kPrimIdentity)), + opt_update_zero_tensor_(std::make_shared()), + constant_duplicate_mul_(std::make_shared()), + power_one_(std::make_shared()) { + eliminaters_.emplace_back(multiply_by_zero_or_one_); + eliminaters_.emplace_back(tensor_multiply_by_one_); + eliminaters_.emplace_back(add_by_zero_); + eliminaters_.emplace_back(tensor_add_by_zero_); + eliminaters_.emplace_back(identity_); + eliminaters_.emplace_back(opt_update_zero_tensor_); + eliminaters_.emplace_back(constant_duplicate_mul_); + eliminaters_.emplace_back(power_one_); + } + ~ArithmeticSimplify() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; + + private: + OptimizerCallerPtr multiply_by_zero_or_one_; + OptimizerCallerPtr tensor_multiply_by_one_; + OptimizerCallerPtr add_by_zero_; + OptimizerCallerPtr tensor_add_by_zero_; + OptimizerCallerPtr identity_; + OptimizerCallerPtr opt_update_zero_tensor_; + OptimizerCallerPtr constant_duplicate_mul_; + OptimizerCallerPtr power_one_; + + std::vector eliminaters_{}; +}; + +// Arithmetic Simplifications should be done after step_parallel. +// eg: Mul(0, weight) where weight is a parameter will be simplified to a constant tensor +// with shape(weight), but after step_parallel, shape of weight may be changed, so the +// shape of the constant tensor should also be changed. So this pass is seperated from +// ArithmeticSimplify and deferred until step_parallel. +class ArithmeticSimplify2 : public OptimizerCaller { + public: + ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { + eliminaters_.emplace_back(tensor_multiply_by_zero_); + } + ~ArithmeticSimplify2() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; + + private: + OptimizerCallerPtr tensor_multiply_by_zero_; + std::vector eliminaters_{}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc new file mode 100644 index 0000000000..dc580f6b63 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc @@ -0,0 +1,584 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/branch_culling.h" + +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +AnfNodePtr GenerateSwitchNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data, + int switch_idx) { + auto switch_node = prim::GetPythonOps("geswitch", "mindspore.ops.functional")->cast(); + std::vector switch_nodes{NewValueNode(switch_node), data, cond}; + auto switch_apply = graph->NewCNode(switch_nodes); + std::vector tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), switch_apply, + NewValueNode(MakeValue(switch_idx))}; + return graph->NewCNode(tuple_getitem_nodes); +} + +AnfNodePtr GenerateSwitchTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { + return GenerateSwitchNode(graph, cond, data, 1); +} + +AnfNodePtr GenerateSwitchFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { + return GenerateSwitchNode(graph, cond, data, 0); +} + +bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { + // The CNode inputs of the following Primitive with index in std::vector should not be guarded by geswitch + // node because it is attribute or ge specific reason. + // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be + // converted to switch guarded. + std::vector>> white_list({{prim::kPrimApplyMomentum, {1, 2}}, + {prim::kPrimMomentum, {2, 3}}, + {prim::kPrimStateSetItem, {1}}, + {prim::kPrimTupleGetItem, {2}}, + {prim::kPrimEnvGetItem, {1}}, + {prim::kPrimEnvSetItem, {1}}, + {prim::kPrimReduceSum, {2}}, + {prim::kPrimReduceMean, {2}}, + {prim::kPrimReduceAll, {2}}, + {prim::kPrimCast, {2}}, + {prim::kPrimTranspose, {2}}, + {prim::kPrimOneHot, {2}}, + {prim::kPrimGatherV2, {3}}, + {prim::kPrimReshape, {2}}, + {prim::kPrimAssign, {1}}, + {prim::kPrimAssignAdd, {1}}, + {prim::kPrimAssignSub, {1}}, + {prim::kPrimTensorSummary, {1}}, + {prim::kPrimImageSummary, {1}}, + {prim::kPrimScalarSummary, {1}}, + {prim::kPrimApplyRMSProp, {6, 7, 8}}, + {prim::kPrimCumSum, {2}}, + {prim::kPrimTile, {2}}, + {prim::kPrimExpandDims, {2}}, + {prim::kPrimHistogramSummary, {1}}}); + for (auto &item : white_list) { + auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { + return IsPrimitiveCNode(node, item.first) && idx == index; + }); + if (matched) { + return true; + } + } + + std::vector adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend}; + for (auto &item : adapter_convert_ops) { + if (IsPrimitiveCNode(node, item)) { + return true; + } + } + return false; +} + +using NodeInputReplMap = std::unordered_map, AnfNodePtr, PairHasher>; +// replace the nodes which should be changed +void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector> nodes_changed, + std::unordered_map repl_node, NodeInputReplMap repl_node_inputs, + const FuncGraphPtr &func_graph) { + for (auto &node_pair : nodes_changed) { + CNodePtr old_node = node_pair.first; + CNodePtr new_node = node_pair.second; + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + for (size_t i = 0; i < old_node->size(); i++) { + auto input = old_node->input(i); + if (repl_node.count(input) != 0) { + new_node->add_input(repl_node[input]); + } else if (repl_node_inputs.count(std::pair(old_node, i)) != 0) { + new_node->add_input(repl_node_inputs[std::pair(old_node, i)]); + } else { + new_node->add_input(input); + } + } + } + + for (auto &item : repl_node) { + if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) { + func_graph->set_output(item.second->cast()->input(1)); + } else if (!manager->Replace(item.first, item.second)) { + MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2) + << " to new: " << item.second->DebugString(2); + } + } +} + +// trace the node that should add switch and replace them with new nodes in the graph +FuncGraphPtr TransformGraphCondBranchNodes( + const FuncGraphPtr &graph, const AnfNodePtr &cond, + const std::function &generate_func) { + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + // record the node that has been changed + std::vector> nodes_changed; + // record the node to be replaced + std::unordered_map repl_node; + // record the node input to be replaced + NodeInputReplMap repl_node_inputs; + const AnfNodeSet &nodes = graph->nodes(); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto inputs = node->cast()->inputs(); + bool should_replace = false; + // if the apply input does not belong to graph, insert a switch node + for (size_t index = 0; index < inputs.size(); index++) { + auto input_node = inputs[index]; + MS_EXCEPTION_IF_NULL(input_node); + // for some ops input should not guard it with switch + if (InConvertWhiteList(node, index)) { + continue; + } + + // If the input for node is not the graph belonged, or it is an ValueNode. + // Bypass the Primitive node which is inputs[0]. + if ((index >= 1 && inputs[index]->func_graph() != nullptr && inputs[index]->func_graph() != graph) || + ((index >= 1 && inputs[index]->isa()))) { + input_node = generate_func(graph, cond, inputs[index]); + repl_node_inputs[std::pair(node, index)] = input_node; + should_replace = true; + } + if (input_node == nullptr) { + MS_LOG(EXCEPTION) << "generate switch node failed"; + } + } + if (should_replace) { + auto new_node = graph->NewCNode(); + repl_node[node] = new_node; + nodes_changed.emplace_back(node->cast(), new_node); + } + } + RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph); + return graph; +} + +struct SharedOp { + tensor::TensorPtr const_data; + CNodePtr square_ops[2]; + CNodePtr merge_ops[2]; +} MergeNetOutput; + +inline tensor::TensorPtr GetConstData() { return MergeNetOutput.const_data; } +inline void SetConstData(const tensor::TensorPtr &const_value) { MergeNetOutput.const_data = const_value; } + +inline CNodePtr GetSquareOp(int switch_idx) { return MergeNetOutput.square_ops[switch_idx]; } +inline void SetSquareOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.square_ops[switch_idx] = op; } + +inline CNodePtr GetMergeOp(int switch_idx) { return MergeNetOutput.merge_ops[switch_idx]; } +inline void SetMergeOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.merge_ops[switch_idx] = op; } + +inline void ResetSharedOp() { + SetConstData(nullptr); + SetSquareOp(0, nullptr); + SetSquareOp(1, nullptr); + SetMergeOp(0, nullptr); + SetMergeOp(1, nullptr); +} + +tensor::TensorPtr ConstData() { + std::vector shp = {1}; + tensor::TensorPtr const_data = std::make_shared(kInt32->type_id(), shp); + auto *val = static_cast(const_data->data_c()); + *val = 0; + return const_data; +} + +CNodePtr SquareOp(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx, + const tensor::TensorPtr &const_data) { + auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast(); + // for the depended node , add two const data to merge the flow ,one for depended node with same switch, + // the other use the opposite + auto ctrl_data = NewValueNode(const_data); + auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); + + std::vector square_nodes{NewValueNode(PrimSquare), ctrl_node}; + auto square_op = graph->NewCNode(square_nodes); + + return square_op; +} + +CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx, + const tensor::TensorPtr &const_data, const CNodePtr &square_op) { + // for the depended node , add two const data to merge the flow ,one for depended node with same switch, + // the other use the opposite + auto oppsite_ctrl_data = NewValueNode(const_data); + auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); + + std::vector merge_nodes; + auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); + merge_nodes.push_back(NewValueNode(PrimMerge)); + std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; + merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); + auto merge_op = graph->NewCNode(merge_nodes); + + return merge_op; +} + +// construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) +// control_depend(output_node, square_op) +AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node, + int switch_idx) { + tensor::TensorPtr const_data = GetConstData(); + if (const_data == nullptr) { + const_data = ConstData(); + SetConstData(const_data); + } + + CNodePtr square_op = GetSquareOp(switch_idx); + if (square_op == nullptr) { + square_op = SquareOp(graph, cond, switch_idx, const_data); + SetSquareOp(switch_idx, square_op); + } + + CNodePtr merge_op = GetMergeOp(switch_idx); + if (merge_op == nullptr) { + merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op); + SetMergeOp(switch_idx, merge_op); + } + + std::vector control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op}; + auto control_depend_op = graph->NewCNode(control_depend_nodes); + + std::vector depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op}; + auto depend_op = graph->NewCNode(depend_nodes); + + return depend_op; +} + +// construct a merge output and add dependency with the netoutput node from control_depend +// we need to reserve the control_depend node, besides the generated merge node and control_depend node +CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, + const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst, + int switch_idx) { + auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); + auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast(); + std::vector shp = {1}; + tensor::TensorPtr const_data = std::make_shared(kInt32->type_id(), shp); + auto *val = static_cast(const_data->data_c()); + *val = 0; + // for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same + // switch the other use the opposite + auto ctrl_data = NewValueNode(const_data); + auto oppsite_ctrl_data = NewValueNode(const_data); + auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); + auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); + + std::vector square_nodes{NewValueNode(PrimSquare), ctrl_node}; + auto square_op = graph->NewCNode(square_nodes); + + std::vector merge_nodes; + merge_nodes.push_back(NewValueNode(PrimMerge)); + std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; + merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); + auto merge_output = graph->NewCNode(merge_nodes); + + std::vector control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op}; + auto cond_dep_output = graph->NewCNode(control_depend_nodes); + + std::vector depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output, + cond_dep_output}; + return graph->NewCNode(depended_make_tuple_nodes); +} + +// generate switch nodes for true graph node inputs +AnfNodePtr GenerateSwitchDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { + // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch + return GenerateSwitchDependNode(graph, cond, data, 1); +} + +// generate switch nodes for false graph node inputs +AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { + // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch + return GenerateSwitchDependNode(graph, cond, data, 0); +} + +// generate switch nodes for true graph node inputs +CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, + const AnfNodePtr &con_input, const AnfNodePtr &output) { + // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch + return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1); +} + +// generate switch nodes for false graph node inputs +CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, + const AnfNodePtr &con_input, const AnfNodePtr &output) { + // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch + return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0); +} + +// to judge if the node used in ControlDepend is a net output node +bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { + auto uses = manager->node_users()[node]; + bool is_output_node = true; + for (auto &item : uses) { + if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { + continue; + } + is_output_node = false; + break; + } + return is_output_node; +} + +// generate node for Depended MakeTuple +void GenerateReplNodeForDependMakeTuple( + const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, + const std::shared_ptr> &repl_node, + const std::function &generate_func, + const std::function &gen_ctl_depd_func) { + MS_EXCEPTION_IF_NULL(graph->manager()); + + auto make_tuple_inputs = depended_node->cast()->inputs(); + const size_t make_tuple_begin_idx = 1; + std::vector new_make_tuple_nodes; + bool replace_make_tuple = false; + new_make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t idx = make_tuple_begin_idx; idx < make_tuple_inputs.size(); idx++) { + auto depended_tuple_input_node = make_tuple_inputs[idx]; + if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimDepend)) { + new_make_tuple_nodes.push_back(depended_tuple_input_node); + continue; + } + if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimControlDepend)) { + // only when the control depend input is not square op (the op to use as merge output) + auto control_inputs = depended_tuple_input_node->cast()->inputs(); + if (control_inputs.size() != 3) { + MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); + } + // control inputs: primitive, src, dst + auto dst_node = control_inputs[2]; + if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { + auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node); + MS_EXCEPTION_IF_NULL(gen_node); + auto tuple_inputs = gen_node->inputs(); + // add depended tuple inputs to new_make_tuple directly + for (size_t i = 1; i < tuple_inputs.size(); i++) { + new_make_tuple_nodes.push_back(tuple_inputs[i]); + } + } + replace_make_tuple = true; + continue; + } + + if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { + auto gen_node = generate_func(graph, cond, depended_tuple_input_node); + new_make_tuple_nodes.push_back(gen_node); + replace_make_tuple = true; + continue; + } + + MS_LOG(WARNING) << "depended node being used by others, "; + } + if (replace_make_tuple) { + auto make_tuple_op = graph->NewCNode(new_make_tuple_nodes); + (*repl_node)[depended_node] = make_tuple_op; + } +} + +// generate a replace depend node for a single network output node +void GenerateRepDepend( + const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond, + const std::shared_ptr> &repl_node, + const std::function &generate_func, + const std::function &gen_ctl_depd_func) { + auto inputs = node->inputs(); + if (inputs.size() != 3) { + MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node]."; + } + + std::vector new_depened_inputs; + // Inputs should be [depend, actual_value, depended_node] + auto depended_node = inputs[2]; + new_depened_inputs.push_back(inputs[0]); + new_depened_inputs.push_back(inputs[1]); + // depended node should be make_tuple or a single depended node + if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) { + GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func); + } else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) { + // only when the control depend input is not square op (the op to use as merge output) + auto control_inputs = depended_node->cast()->inputs(); + // control inputs: primitive, src, dst + if (control_inputs.size() != 3) { + MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); + } + auto dst_node = control_inputs[2]; + if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { + auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node); + (*repl_node)[depended_node] = gen_node; + } + } else { + // Check if there is only single user for depend_node. + if (graph->manager()->node_users()[depended_node].size() == 1) { + auto gen_node = generate_func(graph, cond, depended_node); + (*repl_node)[depended_node] = gen_node; + } else { + MS_LOG(WARNING) << "depended node being used by others"; + } + } +} + +// generate depend node for netoutput node, to resolve the stream synchronize problem of ge +// traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const) +// and add control_depend of graph output node and square node. +FuncGraphPtr TransformGraphDependNode( + const FuncGraphPtr &graph, const AnfNodePtr &cond, + const std::function &gen_depend_func, + const std::function &gen_ctl_depd_func) { + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + ResetSharedOp(); + std::shared_ptr> repl_node = + std::make_shared>(); // record the node to be replaced + const AnfNodeSet &nodes = graph->nodes(); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + auto cnode = node->cast(); + if (cnode->size() != 3) { + MS_LOG(EXCEPTION) << "Dependnode input size != 3"; + } + auto depended_node = cnode->input(2); + MS_EXCEPTION_IF_NULL(depended_node); + if (!depended_node->isa()) { + continue; + } + if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { + continue; + } + GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func); + } + } + ResetSharedOp(); + + for (auto &item : *repl_node) { + if (!manager->Replace(item.first, item.second)) { + MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed"; + } + } + + return graph; +} + +FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { + (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); + return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode); +} + +FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { + (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); + return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode); +} + +// judge if the true and false graph output is compatible(they shall have same tuple size) +bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const AbstractBasePtr &false_branch_abs) { + MS_EXCEPTION_IF_NULL(true_branch_abs); + MS_EXCEPTION_IF_NULL(false_branch_abs); + + if (true_branch_abs->isa() && false_branch_abs->isa()) { + abstract::AbstractTuplePtr true_branch_tuple = true_branch_abs->cast(); + abstract::AbstractTuplePtr false_branch_tuple = false_branch_abs->cast(); + if (true_branch_tuple->elements().size() != false_branch_tuple->elements().size()) { + MS_LOG(ERROR) << "true branch size:" << true_branch_tuple->elements().size() + << ", not equal to false banch size:" << false_branch_tuple->elements().size() << " "; + return false; + } + bool all_compatible = true; + for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { + all_compatible = + all_compatible && GraphOutputCompatible(true_branch_tuple->elements()[i], false_branch_tuple->elements()[i]); + } + return all_compatible; + } + TypePtr true_branch_type = true_branch_abs->BuildType(); + TypePtr false_branch_type = false_branch_abs->BuildType(); + MS_LOG(DEBUG) << "branch output Type equal?" << (*true_branch_type == *false_branch_type) + << " true:" << true_branch_type->ToString() << " false:" << false_branch_type->ToString(); + return (*true_branch_type == *false_branch_type); +} + +AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, + const AbstractBasePtr &true_graph_output_abs, + const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph, + const AnfNodePtr &cond) { + MS_EXCEPTION_IF_NULL(true_graph_output_abs); + MS_EXCEPTION_IF_NULL(false_graph_output_abs); + MS_EXCEPTION_IF_NULL(cond); + MS_EXCEPTION_IF_NULL(switch_graph); + auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); + MS_EXCEPTION_IF_NULL(PrimMerge); + + if (!true_graph_output_abs->isa()) { + std::vector merge_nodes; + merge_nodes.push_back(NewValueNode(PrimMerge)); + std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node}; + merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes)); + std::vector tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), + switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))}; + return switch_graph->NewCNode(tuple_getitem_nodes); + } else { + abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast(); + abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast(); + + std::vector make_tuple_nodes; + make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { + std::vector true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node, + NewValueNode(MakeValue(SizeToInt(i)))}; + auto true_node = switch_graph->NewCNode(true_getitem_nodes); + std::vector false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node, + NewValueNode(MakeValue(SizeToInt(i)))}; + auto false_node = switch_graph->NewCNode(false_getitem_nodes); + + auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i], + false_branch_tuple->elements()[i], switch_graph, cond); + make_tuple_nodes.push_back(merge_node); + } + return switch_graph->NewCNode(make_tuple_nodes); + } +} + +AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, + const AbstractBasePtr &true_graph_output_abs, + const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, + const FuncGraphPtr &switch_graph) { + if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) { + MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString() + << ", false:" << false_graph_output_abs->ToString(); + } + return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, + switch_graph, cond); +} +} // namespace internal +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h new file mode 100644 index 0000000000..b3f3fe4733 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ + +#include +#include + +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/pattern_matcher.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimSwitch, true, X, Y} +// {prim::kPrimSwitch, false, X, Y} +class SwitchSimplify : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode cond, true_br, false_br; + auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto cond_value_ = GetValue(GetValueNode(cond.GetNode(node))); + if (cond_value_) { + return true_br.GetNode(node); + } + return false_br.GetNode(node); + }; + + MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda, + cond.CheckFunc(IsValueNode, node)); + + return nullptr; + } +}; + +// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => +// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} +class FloatTupleGetItemSwitch : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode cond, true_br, false_br, x; + MATCH_REPLACE_IF(node, + PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), + PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x), + PPrimitive(prim::kPrimTupleGetItem, false_br, x)), + x.CheckFunc(IsVNode, node)); + return nullptr; + } +}; + +// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => +// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} +class FloatEnvGetItemSwitch : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode cond, true_br, false_br, x, x2; + MATCH_REPLACE(node, + PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), + PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), + PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2))); + + return nullptr; + } +}; + +namespace internal { +FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); +FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); +AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, + const AbstractBasePtr &true_graph_output_abs, + const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, + const FuncGraphPtr &func_graph); +} // namespace internal + +// {{prim::kPrimSwitch, X, G1, G2}, Xs} +class ConvertSwitchReplacement : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto cnode_ = node->cast(); + if (cnode_->size() < 1) { + return nullptr; + } + + auto node_ = cnode_->input(0); + + PatternNode cond, true_br, false_br; + + auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto g1_ = GetValueNode(true_br.GetNode(node_)); + auto g2_ = GetValueNode(false_br.GetNode(node_)); + auto x_ = cond.GetNode(node_); + + // for switch replace method, only graphs without graph inside can be replaced + for (auto &item : g1_->value_nodes()) { + auto value_node = item.first; + if (IsValueNode(value_node)) { + return nullptr; + } + } + + for (auto &item : g2_->value_nodes()) { + auto value_node = item.first; + if (IsValueNode(value_node)) { + return nullptr; + } + } + + auto true_output = g1_->output()->abstract(); + auto false_output = g2_->output()->abstract(); + auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_); + auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); + + std::vector params; + auto fg = node_->func_graph(); + auto cloned_g1 = InlineClone(trans_g1, fg, params); + auto cloned_g2 = InlineClone(trans_g2, fg, params); + auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); + + return nnode; + }; + + MATCH_REPLACE_LAMBDA_IF( + node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, + true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); + + return nullptr; + } +}; + +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc new file mode 100644 index 0000000000..ddb84806e1 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/cast_eliminate.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "ir/func_graph.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/python_adapter.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimCast, X, T} +AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node); + + // check pattern match + if (tgt_ == nullptr) { + return nullptr; + } + + // src type check + auto src_type = src_->Type(); + if (src_type == nullptr || !src_type->isa()) { + return nullptr; + } + + src_type = src_type->cast()->element(); + + // tgt type check + auto tgt_type = GetValueNode(tgt_); + if (tgt_type->isa()) { + tgt_type = tgt_type->cast()->element(); + } + + if (src_type->type_id() == tgt_type->type_id()) { + return src_; + } + + return nullptr; +} + +void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { + if (src_ == nullptr) { + src_ = node; + } else { + tgt_ = node; + } +} + +// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} +AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node); + + if (x_ != nullptr && t_ != nullptr) { + auto cast_op = parse::python_adapter::GetPyFn("mindspore.ops.operations", "Cast")(); + ValuePtr cast = parse::data_converter::PyDataToValue(cast_op); + auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph()); + cnode->set_abstract(node->abstract()); + return cnode; + } + return nullptr; +} + +void TwoCastEliminater::Visit(const AnfNodePtr &node) { + if (IsPrimitiveCNode(node, prim::kPrimCast)) { + auto cnode = node->cast(); + // {prim::kPrimCast, X, Y} + if (cnode->size() != 3) { + return; + } + x_ = cnode->input(1); + } else { + t_ = node; + } +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h new file mode 100644 index 0000000000..d5222d4310 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ + +#include "ir/visitor.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimCast, X, T} +class CastSameTypeEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + void Visit(const AnfNodePtr &node) override; + void Reset() { + src_ = nullptr; + tgt_ = nullptr; + } + + private: + AnfNodePtr src_{nullptr}, tgt_{nullptr}; +}; + +// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} +class TwoCastEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + void Visit(const AnfNodePtr &node) override; + void Reset() { + x_ = nullptr; + t_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}, t_{nullptr}; +}; + +class CastEliminater : public OptimizerCaller { + public: + CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} + ~CastEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + auto new_node = cast_same_type_eliminater_(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + + new_node = two_cast_eliminater_(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + + return nullptr; + } + + private: + CastSameTypeEliminater cast_same_type_eliminater_; + TwoCastEliminater two_cast_eliminater_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/convert.h b/mindspore/ccsrc/frontend/optimizer/irpass/convert.h new file mode 100644 index 0000000000..d887874203 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/convert.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ + +#include + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimPrint, Xs} -> {prim::kPrimPrint, {prim::kPrinMakeTuple, Xs}} +class PrintTupleWrapper : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimPrint)) { + return nullptr; + } + + // already be {prim::kPrimPrint, {prim::kPrinMakeTuple, Xs}} + auto cnode = node->cast(); + if (cnode->size() == 2 && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) { + return nullptr; + } + + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + + // {prim::kPrimPrint, Xs} + auto &inputs = cnode->inputs(); + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + + // {prim::kPrinMakeTuple, Xs} + auto fg = node->func_graph(); + auto tuple = NewCNode(args, fg); + auto print = GetValueNode(cnode->input(0)); + return NewCNode({NewValueNode(print), tuple}, fg); + } +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h new file mode 100644 index 0000000000..14fd8743ff --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h @@ -0,0 +1,364 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ + +#include +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "utils/symbolic.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +class EnvGetitemTransform { + public: + EnvGetitemTransform() : cache_() {} + ~EnvGetitemTransform() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { + if (cache_.find(fg) == cache_.end()) { + cache_[fg] = {}; + } + + auto &cache = cache_[fg]; + auto hash_key = std::make_pair(key, default_node); + if (cache.find(hash_key) == cache.end()) { + std::ostringstream ss("env", std::ostringstream::app); + if (key->node() != nullptr) { + ss << key->node()->ToString(); + } + + auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); + auto env = new_fg->output(); + while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { + // {prim::kPrimEnvSetItem, env, symbolickey, value} + auto &inputs = env->cast()->inputs(); + if (inputs.size() != 4 || !IsValueNode(inputs[2])) { + MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance."; + } + + env = inputs[1]; + auto value = inputs[3]; + auto key2 = GetValueNode(inputs[2]); + if (*key2 == *key) { + new_fg->set_output(value); + cache[hash_key] = new_fg; + cache_[fg] = cache; + return new_fg; + } + } + new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); + cache[hash_key] = new_fg; + } + + return cache[hash_key]; + } + + private: + std::unordered_map, FuncGraphPtr, PairHasher>> + cache_; +}; +} // namespace internal + +// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y +class NewEnvGetItem : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + auto gety = [this](const AnfNodePtr &node) -> bool { + this->y_ = node; + return true; + }; + + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode, IsVNode, gety})(node); + if (env_ != nullptr && env_->Len() == 0) { + return y_; + } + return nullptr; + } + + void Visit(const ValueNodePtr &vnode) override { + if (env_ == nullptr) { + env_ = GetValueNode(vnode); + } + } + + void Reset() { + y_ = nullptr; + env_ = nullptr; + } + + private: + AnfNodePtr y_{nullptr}; + EnvInstancePtr env_{nullptr}; +}; + +// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> +// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}} +class AddEnvGetItem : public AnfVisitor { + public: + AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} + ~AddEnvGetItem() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + auto IsAddCNode = [](const AnfNodePtr &node) -> bool { + return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast()->size() == 3; + }; + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node); + + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + + // {prim::kPrimEnvGetItem, {...}, C, Z} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto c = cnode->input(2); + auto z = cnode->input(3); + + // {prim::kPrimEnvAdd, X, Y} + auto x = inp1->input(1); + auto y = inp1->input(2); + + auto fg = node->func_graph(); + auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z}); + auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z}); + + return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz}); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; + ValuePtr PrimHyperAdd_; +}; + +// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z} +class EnvGetSetItem : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + auto IsSetCNode = [](const AnfNodePtr &node) -> bool { + if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) { + return false; + } + + // {prim::kPrimEnvSetItem, X, C1, Y} + auto &inputs = node->cast()->inputs(); + if (inputs.size() != 4) { + return false; + } + + return IsValueNode(inputs[2]); + }; + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode, IsNode})(node); + + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + + // {prim::kPrimEnvGetItem, {...}, C2, Z} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto key2 = cnode->input(2); + auto c2 = GetValueNode(key2); + auto default_v = cnode->input(3); + + // {prim::kPrimEnvSetItem, X, C1, Y} + auto env = inp1->input(1); + auto c1 = GetValueNode(inp1->input(2)); + auto last_set = inp1->input(3); + + if (*c1 == *c2) { + return last_set; + } + + while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { + // {prim::kPrimEnvSetItem, env, symbolickey, value} + auto &inputs = env->cast()->inputs(); + if (inputs.size() != 4 || !IsValueNode(inputs[2])) { + MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance."; + } + + env = inputs[1]; + last_set = inputs[3]; + auto symbolic_c1 = GetValueNode(inputs[2]); + if (*symbolic_c1 == *c2) { + return last_set; + } + } + + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v}); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; +}; + +class EnvGetItemEliminater : public OptimizerCaller { + public: + EnvGetItemEliminater() + : new_env_get_item_(std::make_shared()), + add_env_get_item_(std::make_shared()), + env_get_set_item_(std::make_shared()) { + eliminaters_.emplace_back(new_env_get_item_); + eliminaters_.emplace_back(add_env_get_item_); + eliminaters_.emplace_back(env_get_set_item_); + } + ~EnvGetItemEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; + std::vector eliminaters_{}; +}; + +// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} +class IncorporateEnvGetitem : public AnfVisitor { + public: + IncorporateEnvGetitem() : env_get_item_transform_() {} + ~IncorporateEnvGetitem() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + auto IsGCNode = [](const AnfNodePtr &node) -> bool { + auto cnode = node->cast(); + if (cnode == nullptr || cnode->size() < 1) { + return false; + } + return IsValueNode(cnode->input(0)); + }; + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode, IsNode})(node); + + if (!is_match_) { + return nullptr; + } + + // {prim::kPrimEnvGetItem, {...}, C, Y} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto key = GetValueNode(cnode->input(2)); + auto default_v = cnode->input(3); + + // {G, Xs} + auto inputs = inp1->inputs(); + auto fg = GetValueNode(inputs[0]); + auto new_fg = env_get_item_transform_(fg, key, default_v); + + std::vector args; + args.push_back(NewValueNode(new_fg)); + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + + return node->func_graph()->NewCNode(args); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; + internal::EnvGetitemTransform env_get_item_transform_; +}; + +// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} +class IncorporateEnvGetitemSwitch : public AnfVisitor { + public: + IncorporateEnvGetitemSwitch() : env_get_item_transform_() {} + ~IncorporateEnvGetitemSwitch() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + auto IsSwNode = [](const AnfNodePtr &node) -> bool { + auto cnode = node->cast(); + if (cnode == nullptr || cnode->size() < 1) { + return false; + } + + return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch); + }; + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode, IsNode})(node); + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + + // {prim::kPrimEnvGetItem, {...}, C, Y} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto key = GetValueNode(cnode->input(2)); + auto default_v = cnode->input(3); + + // {{prim::kPrimSwitch, X, G1, G2}, Xs} + auto inputs = inp1->inputs(); + is_match_ = false; + AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs[0]); + if (!is_match_) { + return nullptr; + } + + // {prim::kPrimSwitch, X, G1, G2} + auto sw = inputs[0]->cast(); + auto x = sw->input(1); + auto g1 = GetValueNode(sw->input(2)); + auto g2 = GetValueNode(sw->input(3)); + auto new_g1 = env_get_item_transform_(g1, key, default_v); + auto new_g2 = env_get_item_transform_(g2, key, default_v); + + auto fg = node->func_graph(); + auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)}); + + std::vector args{new_sw}; + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + + return fg->NewCNode(args); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; + internal::EnvGetitemTransform env_get_item_transform_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc new file mode 100644 index 0000000000..44c1b62fa5 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/grad_var_prepare.h" +#include +#include +#include +#include + +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +namespace irpass { +static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, FuncGraphPtr func_graph, + AnfNodePtr func_node, bool is_unpack, bool sens_param) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(func_node); + std::vector nodes; + AnfNodePtr unpack_graph_node = nullptr; + if (is_unpack) { + auto unpack_graph = std::make_shared("unpack_graph", sens_param, true); + nodes.push_back(NewValueNode(unpack_graph)); + nodes.push_back(func_node); + // {unpackcall, {GradOperation, ...}, args...} + std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), + [](const AnfNodePtr &node) { return node; }); + unpack_graph_node = func_graph->NewCNode(nodes); + } else { + auto unpack_graph = std::make_shared("unpack_graph", sens_param, false); + nodes.push_back(NewValueNode(unpack_graph)); + nodes.push_back(func_node); + // {{GradOperation, ...}, args...} + std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), + [](const AnfNodePtr &node) { return node; }); + unpack_graph_node = func_graph->NewCNode(nodes); + } + return unpack_graph_node; +} + +// get metagraph of value node +MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { + ValuePtr value; + if (IsValueNode(node)) { + value = GetValueNode(node)->cast()->function(); + } else { + value = GetValueNode(node); + } + if (value == nullptr) { + return nullptr; + } + return value->cast(); +} + +// check if node is a specific metafuncgraph op +bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { + if (node != nullptr) { + auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); + if (meta_func_graph_ptr == nullptr) { + return false; + } + + if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) { + return true; + } + } + return false; +} + +// {{GradOperation, g, w}, Ys} +// {UnPackCall, {GradOperation, g, w}, Ys} +AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + // {{...}, Ys} + auto inputs_y = node->cast()->inputs(); + std::vector inputs_x; + if (IsCNode(inputs_y[0])) { + inputs_x = inputs_y[0]->cast()->inputs(); + } else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) { + inputs_x = inputs_y[1]->cast()->inputs(); + } else { + return nullptr; + } + + // {{...}, Xs} + if (inputs_x.size() < 2) { + return nullptr; + } + + // {GradOperation, g, w} or {GradOperation, g} + if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) { + return nullptr; + } + + auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]); + if (meta_func == nullptr) { + return nullptr; + } + auto grad_op_ptr = meta_func->cast(); + auto func_node = inputs_x[1]; + if (!IsValueNode(func_node)) { + return nullptr; + } + + AnfNodePtr unpack_graph_node = + GenerateUnpackGraphNode(inputs_y, node->cast()->func_graph(), func_node, + IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param()); + // constuct new grad_opration + inputs_x[1] = unpack_graph_node; + auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x); + if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) { + inputs_y[1] = grad_op_cnode; + } else { + inputs_y[0] = grad_op_cnode; + } + auto cnode = node->func_graph()->NewCNode(inputs_y); + return cnode; +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h new file mode 100644 index 0000000000..f6992a87c6 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ + +#include +#include +#include +#include + +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {{GradOperation, g, w}, Ys} +// {UnPackCall, {GradOperation, g, w}, Ys} +class GradVarPrepare : public AnfVisitor { + public: + GradVarPrepare() + : grad_op_(std::make_shared("grad")), + unpack_op_(std::make_shared("unpack_call")) {} + ~GradVarPrepare() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + private: + MetaFuncGraphPtr grad_op_; + MetaFuncGraphPtr unpack_op_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc new file mode 100644 index 0000000000..0d98cffa37 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/gradient_eliminate.h" + +#include + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { + ScopeGuard scope_guard(vnode->scope()); + + auto newg = ad::Kprim(vnode, resource); + if (newg != nullptr) { + return NewValueNode(newg); + } + + // when find in J failed, try in Jmeta + auto prim = GetValueNode(vnode); + MetaFuncGraphPtr meta = ad::Kmeta(prim, resource); + if (meta != nullptr) { + return NewValueNode(meta); + } + + return nullptr; +} + +bool CheckIfEmbedJFuncGraph(const FuncGraphPtr func_graph) { + // if func graph also contain J FuncGraph, then ignore this funcgraph. ExpandJ innermost graph first; + auto func_graph_manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(func_graph_manager); + return func_graph_manager->func_graph_j_total(func_graph); +} + +AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { + if (IsValueNode(vnode)) { + ScopeGuard scope_guard(vnode->scope()); + + auto func_graph = GetValueNode(vnode); + MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString(); + + // high_order_grad begin; + // if graph also contain J Graph, then ignore this graph. ExpandJ innermost graph first; + if (CheckIfEmbedJFuncGraph(func_graph)) { + MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J(funcgraph), will expandJ later"; + return nullptr; + } + // high_order_grad end; + + MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now"; + auto newfg = ad::Grad(func_graph, resource); + return NewValueNode(newfg); + } + + if (IsValueNode(vnode)) { + return ExpandJPrimitive(vnode, resource); + } + + return nullptr; +} +} // namespace internal +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h new file mode 100644 index 0000000000..82312d9e37 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ + +#include +#include +#include + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" +#include "ir/visitor.h" +#include "common/utils.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/ad/grad.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource); +} // namespace internal + +// {prim::kPrimJ, C} +class ExpandJPrim : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + x_ = nullptr; + AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node); + if (x_ != nullptr) { + TraceManager::DebugTrace(std::make_shared(node->debug_info())); + auto j_node = internal::ExpandJ(x_, optimizer->resource()); + TraceManager::EndTrace(); + return j_node; + } + return nullptr; + } + + void Visit(const ValueNodePtr &node) override { x_ = node; } + + private: + ValueNodePtr x_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h new file mode 100644 index 0000000000..2f6404458f --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ + +#include +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +class CallOutputTransform { + public: + CallOutputTransform() : cache_() {} + ~CallOutputTransform() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs) { + if (cache_.find(fg) == cache_.end()) { + cache_[fg] = {}; + } + + auto &cache = cache_[fg]; + if (cache.find(nargs) == cache.end()) { + FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("call")); + + std::vector new_items; + new_items.push_back(new_fg->output()); + for (size_t i = 0; i < nargs; i++) { + new_items.push_back(new_fg->add_parameter()); + } + new_fg->set_output(new_fg->NewCNode(new_items)); + + cache[nargs] = new_fg; + } + return cache[nargs]; + } + + private: + std::unordered_map> cache_; +}; +} // namespace internal + +// {{G, Xs}, Ys} +class IncorporateCall : public AnfVisitor { + public: + IncorporateCall() : call_output_transform_() {} + ~IncorporateCall() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs[0] == nullptr || !inputs[0]->isa()) { + return nullptr; + } + + AnfVisitor::Visit(inputs[0]); + if (fg_ == nullptr) { + return nullptr; + } + + auto xs_size = Xs_.size(); + auto ys_size = inputs.size() - 1; + auto new_fg = call_output_transform_(fg_, ys_size); + + std::vector args; + args.push_back(NewValueNode(new_fg)); + + if (xs_size > 0) { + (void)args.insert(args.end(), Xs_.begin(), Xs_.end()); + } + + if (ys_size > 0) { + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + } + + return node->func_graph()->NewCNode(args); + } + + void Visit(const CNodePtr &cnode) override { + // {G, Xs} + if (cnode->size() < 1 || !IsValueNode(cnode->input(0))) { + return; + } + + auto &inputs = cnode->inputs(); + fg_ = GetValueNode(inputs[0]); + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); + } + + void Reset() { + Xs_.clear(); + fg_ = nullptr; + } + + private: + FuncGraphPtr fg_; + std::vector Xs_{}; + internal::CallOutputTransform call_output_transform_; +}; + +// {{{prim::kPrimSwitch, X, G1, G2}, Xs}, Ys} +class IncorporateCallSwitch : public AnfVisitor { + public: + IncorporateCallSwitch() : call_output_transform_() {} + ~IncorporateCallSwitch() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + // {{...}, Ys} + auto &inputs = node->cast()->inputs(); + if (inputs[0] == nullptr || !inputs[0]->isa()) { + return nullptr; + } + + // {{...}, Xs} + auto &inputs_x = inputs[0]->cast()->inputs(); + if (inputs_x[0] == nullptr || !inputs_x[0]->isa()) { + return nullptr; + } + + // {prim::kPrimSwitch, X, G1, G2} + AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs_x[0]); + if (g2_ == nullptr) { + return nullptr; + } + + auto fg = node->func_graph(); + auto xs_size = inputs_x.size() - 1; + auto ys_size = inputs.size() - 1; + auto new_g1 = call_output_transform_(g1_, ys_size); + auto new_g2 = call_output_transform_(g2_, ys_size); + auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); + + std::vector args{sw_node}; + if (xs_size > 0) { + (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end()); + } + if (ys_size > 0) { + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + } + + return fg->NewCNode(args); + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + return; + } + AnfVisitor::Visit(node); + } + + void Visit(const ValueNodePtr &vnode) override { + auto g = GetValueNode(vnode); + if (g1_ == nullptr) { + g1_ = g; + } else { + g2_ = g; + } + } + + void Reset() { + x_ = nullptr; + g1_ = nullptr; + g2_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}; + FuncGraphPtr g1_{nullptr}, g2_{nullptr}; + internal::CallOutputTransform call_output_transform_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h new file mode 100644 index 0000000000..828e205e4f --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -0,0 +1,416 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ + +#include +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +class GetitemTransform { + public: + GetitemTransform() : cache_() {} + ~GetitemTransform() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &fg, int idx) { + if (cache_.find(fg) == cache_.end()) { + cache_[fg] = {}; + } + + auto &cache = cache_[fg]; + if (cache.find(idx) == cache.end()) { + std::ostringstream ss("tp", std::ostringstream::app); + ss << idx; + + auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); + auto output = new_fg->output(); + if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + auto cnode = output->cast(); + auto ids = IntToSize(idx + 1); + // Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1. + if (ids >= cnode->size()) { + MS_LOG(EXCEPTION) << "index " << ids << " is out of inputs length " << cnode->size(); + } + new_fg->set_output(cnode->input(ids)); + } else { + new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(idx)})); + } + + cache[idx] = new_fg; + } + return cache[idx]; + } + + private: + std::unordered_map> cache_; +}; +} // namespace internal + +// {prim::kPrimTupleGetItem, {G, Xs}, C} +class IncorporateGetitem : public AnfVisitor { + public: + IncorporateGetitem() : getitem_transform_() {} + ~IncorporateGetitem() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) { + return nullptr; + } + + if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + // If graph kernel has muti output, do not split. + // some graph kernel output has EnvInstance node or DeadCode node should split. + auto output = fg_->output(); + if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + auto output_cnode = output->cast(); + auto outputs = output_cnode->inputs(); + int real_output_cnt = 0; + for (size_t i = 1; i < outputs.size(); ++i) { + if (IsCNode(outputs[i]) || IsValueNode(outputs[i]) || IsParam(outputs[i])) { + real_output_cnt++; + if (real_output_cnt > 1) { + return nullptr; + } + } + } + } + } + + auto new_fg = getitem_transform_(fg_, idx_); + (void)args_.insert(args_.begin(), NewValueNode(new_fg)); + return node->func_graph()->NewCNode(args_); + } + + void Visit(const CNodePtr &cnode) override { + if (cnode->size() == 0 || !IsValueNode(cnode->input(0))) { + return; + } + + auto &inputs = cnode->inputs(); + fg_ = GetValueNode(inputs[0]); + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); + } + + void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } + + void Reset() { + idx_ = -1; + fg_ = nullptr; + args_.clear(); + } + + private: + int idx_{-1}; + FuncGraphPtr fg_{nullptr}; + std::vector args_{}; + internal::GetitemTransform getitem_transform_; +}; + +class IncorporateGetitemFromParam : public AnfVisitor { + public: + void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr ¶m, size_t input_idx) { + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto &node_users = mng->node_users(); + if (node_users.find(param) == node_users.end() || node_users[param].empty()) { + args_.push_back(cnode->input(input_idx + 1)); + return; + } + + for (auto &user : node_users[param]) { + if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { + // we do not process this case. + args_.push_back(cnode->input(input_idx + 1)); + return; + } + } + + // update new args. + if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) { + // case 1 + replace_parameters_[input_idx] = true; + need_update_ = true; + auto make_tuple_cnode = cnode->input(input_idx + 1)->cast(); + auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs(); + inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1; + args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end()); + } else { + // case 2 + auto prev_cnode = cnode->input(input_idx + 1)->cast(); + auto prev_fg = GetValueNode(prev_cnode->input(0)); + auto fg_output = prev_fg->output(); + if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) { + MS_LOG(ERROR) << "The return of: " << prev_fg->ToString() + << " should be a make tuple, but got: " << fg_output->DebugString(); + return; + } + replace_parameters_[input_idx] = true; + need_update_ = true; + auto make_tuple_cnode = fg_output->cast(); + inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1; + for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) { + auto new_getitem = + func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))}); + auto aptr = std::make_shared(std::make_shared(SizeToInt(output_i))); + new_getitem->input(2)->set_abstract(aptr); + new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract()); + args_.push_back(new_getitem); + } + } + } + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (node->func_graph() == nullptr) { + return nullptr; + } + + Reset(); + + auto cnode = node->cast(); + if (cnode == nullptr) { + return nullptr; + } + auto &inputs = cnode->inputs(); + auto fg = GetValueNode(inputs[0]); + if (fg == nullptr) { + return nullptr; + } + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto parameters = fg->parameters(); + if (parameters.size() != inputs.size() - 1) { + return nullptr; + } + replace_parameters_ = std::vector(parameters.size(), false); + inputs_num_ = std::vector(parameters.size(), 1); + auto node_fg = node->func_graph(); + + for (size_t i = 1; i < inputs.size(); ++i) { + if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) { + Process(node_fg, cnode, parameters[i - 1], i - 1); + } else { + args_.push_back(inputs[i]); + } + } + + if (!need_update_) { + return nullptr; + } + + FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); + mng->AddFuncGraph(new_fg); + + auto node_users = mng->node_users(); + std::vector new_fg_parameters = new_fg->parameters(); + std::vector new_parameters; + size_t curr_input_idx{0}; + for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) { + if (!replace_parameters_[param_i]) { + if (parameters[param_i]->abstract() != nullptr) { + new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract()); + } + new_parameters.push_back(new_fg_parameters[param_i]); + curr_input_idx++; + continue; + } + + // make a new parameter. + for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) { + auto new_param = std::make_shared(new_fg); + new_param->set_abstract(args_.at(curr_input_idx)->abstract()); + + // update users of new parameter. + for (auto &user : node_users[new_fg_parameters[param_i]]) { + idx_ = -1; + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode})(user.first); + if (idx_ == -1) { + MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString() + << " must be tuple getitem here, but got: " << user.first->DebugString(); + return nullptr; + } + + if (input_i == IntToSize(idx_)) { + for (auto &sub_user : node_users[user.first]) { + auto sub_user_cnode = sub_user.first->cast(); + MS_EXCEPTION_IF_NULL(sub_user_cnode); + sub_user_cnode->set_input(sub_user.second, new_param); + (void)mng->Replace(sub_user.first, sub_user_cnode); + } + } + } + + // (void)mng->Replace(new_fg_parameters[param_i], new_param); + new_parameters.push_back(new_param); + curr_input_idx++; + } + } + + mng->SetParameters(new_fg, new_parameters); + (void)args_.insert(args_.begin(), NewValueNode(new_fg)); + auto new_call = node_fg->NewCNode(args_); + new_call->set_abstract(node->abstract()); + return new_call; + } + + void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } + + void Visit(const CNodePtr &cnode) override {} + + void Reset() { + replace_parameters_.clear(); + args_.clear(); + inputs_num_.clear(); + need_update_ = false; + idx_ = -1; + } + + private: + std::vector replace_parameters_{}; + std::vector args_{}; + std::vector inputs_num_{}; + bool need_update_{false}; + int idx_{-1}; +}; + +// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} +class IncorporateGetitemSwitch : public AnfVisitor { + public: + IncorporateGetitemSwitch() : getitem_transform_() {} + ~IncorporateGetitemSwitch() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + is_in_get_ = true; + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + is_in_get_ = false; + + auto fg = node->func_graph(); + if (idx_ == -1 || switch_ == nullptr || fg == nullptr) { + return nullptr; + } + + is_in_switch_ = true; + AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(switch_); + is_in_switch_ = false; + + if (g2_ == nullptr) { + return nullptr; + } + + auto new_g1 = getitem_transform_(g1_, idx_); + auto new_g2 = getitem_transform_(g2_, idx_); + auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); + (void)args_.insert(args_.begin(), sw_node); + + return fg->NewCNode(args_); + } + + void Visit(const AnfNodePtr &node) override { + if (is_in_switch_ && x_ == nullptr) { + x_ = node; + return; + } + AnfVisitor::Visit(node); + } + + void Visit(const CNodePtr &cnode) override { + if (is_in_get_ && cnode->size() != 0) { + auto &inputs = cnode->inputs(); + switch_ = inputs[0]; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (is_in_get_) { + idx_ = GetValue(vnode->value()); + } + + if (is_in_switch_) { + auto g = GetValueNode(vnode); + if (g1_ == nullptr) { + g1_ = g; + } else { + g2_ = g; + } + } + } + + void Reset() { + x_ = nullptr; + g1_ = nullptr; + g2_ = nullptr; + switch_ = nullptr; + args_.clear(); + is_in_get_ = false; + is_in_switch_ = false; + } + + private: + int idx_{-1}; + AnfNodePtr switch_{nullptr}, x_{nullptr}; + FuncGraphPtr g1_{nullptr}, g2_{nullptr}; + bool is_in_get_{false}, is_in_switch_{false}; + std::vector args_{}; + internal::GetitemTransform getitem_transform_; +}; + +class IncorporateGetitemSet : public OptimizerCaller { + public: + IncorporateGetitemSet() + : incorporate_getitem_(std::make_shared()), + incorporate_getitem_switch_(std::make_shared()) { + eliminaters_.emplace_back(incorporate_getitem_); + eliminaters_.emplace_back(incorporate_getitem_switch_); + } + ~IncorporateGetitemSet() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; + std::vector eliminaters_{}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h new file mode 100644 index 0000000000..dfe345fe01 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ + +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}} +// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}} +// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}} +class IndexedSlicesEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(1); + } + AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(2); + } + AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node); + + if (is_match_) { + return tuple_->input(3); + } + return nullptr; + } + + void Visit(const CNodePtr &cnode) override { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) { + tuple_ = cnode; + is_match_ = true; + } + } + + void Reset() { + tuple_ = nullptr; + is_match_ = false; + } + + private: + bool is_match_{false}; + CNodePtr tuple_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h new file mode 100644 index 0000000000..8cafb268b4 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -0,0 +1,204 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +class ReplaceApplicator : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsValueNode(node)) { + return nullptr; + } + + auto fg = GetValueNode(node); + if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { + return nullptr; + } + + auto out = fg->output(); + MS_EXCEPTION_IF_NULL(out); + if (!out->isa()) { + return nullptr; + } + + auto &inputs = out->cast()->inputs(); + auto params = fg->parameters(); + + // Exclude first elements of inputs which is fn. + auto input_size = inputs.size(); + auto param_size = params.size(); + if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size && + std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) { + auto inner = inputs[0]; + if (IsValueNode(inner) || + (IsValueNode(inner) && GetValueNode(inner)->parent() == nullptr)) { + return inner; + } + } + + return nullptr; + } +}; + +using CriterionFuncType = std::function; + +bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { + auto n_cnode = fg->nodes().size() - fg->parameters().size(); + // There is at least one CNode(return, other_node). + return n_cnode <= 2; +} + +bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { + auto &cnodes = fg->func_graph_cnodes_index(); + int n_use = + std::accumulate(cnodes.begin(), cnodes.end(), 0, + [](int sum, const std::pair &item) { return sum + item.second; }); + return n_use == 1; +} + +bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node->func_graph()); + return node->func_graph()->has_flag("inline_inside"); +} + +bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } + +bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } + +// {G, Xs} +class InlinerBase : public AnfVisitor { + public: + explicit InlinerBase(std::vector> criterions) : criterions_(criterions) {} + ~InlinerBase() override = default; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa()) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 1 || !IsValueNode(inputs[0])) { + return nullptr; + } + + // G + auto fg = GetValueNode(inputs[0]); + if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { + return nullptr; + } + // Do not inline GraphKernel to Cell. + if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + // If the GraphKernel only contains a return node, we make it inlined. + if (fg->nodes().size() - fg->parameters().size() > 1) { + return nullptr; + } + } + + Reset(); + bool is_match = false; + for (auto &criterion : criterions_) { + if (!criterion.first(fg, node)) { + continue; + } + + if (criterion.second && IsRecursive(fg)) { + continue; + } + + is_match = true; + break; + } + + if (!is_match) { + return nullptr; + } + + std::vector params; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params)); + + if (IsUniqueUse(fg, nullptr)) { + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + ReplaceParams(mng, params, fg); + auto out_node = fg->output(); + mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); + return out_node; + } + + return InlineClone(fg, node->func_graph(), params, inputs[0]->scope()); + } + + void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector &new_params, + const FuncGraphPtr &fg) { + auto params = fg->parameters(); + auto old_size = params.size(); + if (old_size != new_params.size()) { + MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size() + << fg->output()->DebugString(10); + } + for (size_t i = 0; i < old_size; i++) { + (void)mng->Replace(params[i], new_params[i]); + } + } + + bool IsRecursive(const FuncGraphPtr &fg) { + if (!is_checked_) { + is_checked_ = true; + is_recursive_ = fg->recursive(); + } + return is_recursive_; + } + + void Reset() { + is_checked_ = false; + is_recursive_ = false; + } + + private: + bool is_checked_{false}, is_recursive_{false}; + std::vector> criterions_; +}; + +class Inliner : public InlinerBase { + public: + Inliner() + : InlinerBase({ + {IsUniqueUse, true}, + {IsTrivial, false}, + {IsInside, false}, + {IsCore, false}, + {NoCriterion, true}, + }) {} + ~Inliner() override = default; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h new file mode 100644 index 0000000000..acd6844ee7 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ + +#include +#include +#include + +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// (a, b, c, ...)[0] => a +// (a, b, c, ...)[1] => b +// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} +class GetitemEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); + + if (is_match_) { + return tuple_->input(id_); + } + return nullptr; + } + + void Visit(const CNodePtr &cnode) override { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + tuple_ = cnode; + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (tuple_ != nullptr && IsValueNode(vnode)) { + id_ = IntToSize(GetValue(vnode->value()) + 1); + if (tuple_->size() > id_) { + is_match_ = true; + } + } + } + + void Reset() { + id_ = 0; + tuple_ = nullptr; + is_match_ = false; + } + + private: + bool is_match_{false}; + size_t id_{0}; + CNodePtr tuple_{nullptr}; +}; + +// (a, b, c, ...)[0] => a +// (a, b, c, ...)[1] => b +// {prim::kPrimTupleGetItem, C1, C} +class GetitemConstEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node); + + if (is_match_) { + return NewValueNode((*tuple_)[id_]); + } + return nullptr; + } + + void Visit(const ValueNodePtr &vnode) override { + if (IsValueNode(vnode)) { + tuple_ = GetValueNode(vnode); + } + if (tuple_ != nullptr && IsValueNode(vnode)) { + id_ = IntToSize(GetValue(vnode->value())); + if (tuple_->size() > id_) { + is_match_ = true; + } + } + } + + void Reset() { + id_ = 0; + tuple_ = nullptr; + is_match_ = false; + } + + private: + bool is_match_{false}; + size_t id_{0}; + ValueTuplePtr tuple_{nullptr}; +}; + +// setitem((a, b, c, ...), 0, z) => (z, b, c, ...) +// setitem((a, b, c, ...), 1, z) => (a, z, c, ...) +// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} +class SetitemEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); + + auto fg = node->func_graph(); + if (fg != nullptr && z_ != nullptr) { + args_[id_] = z_; + return fg->NewCNode(args_); + } + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (is_match_) { + z_ = node; + return; + } + + AnfVisitor::Visit(node); + } + + void Visit(const CNodePtr &cnode) override { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + auto &inputs = cnode->inputs(); + (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(args_)); + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (args_.size() > 0 && IsValueNode(vnode)) { + id_ = IntToSize(GetValue(vnode->value()) + 1); + if (id_ < args_.size()) { + is_match_ = true; + } + } + } + + void Reset() { + id_ = 0; + z_ = nullptr; + is_match_ = false; + args_.clear(); + } + + private: + bool is_match_{false}; + size_t id_{0}; + AnfNodePtr z_{nullptr}; + std::vector args_{}; +}; + +// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} +class GetSetitemEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); + + auto fg = node->func_graph(); + if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { + if (key1_ == key2_) { + return last_; + } + return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_, c2_}); + } + return nullptr; + } + + void Visit(const CNodePtr &cnode) override { + if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) { + if (cnode->size() < 4) { + return; + } + + tuple_ = cnode->input(1); + last_ = cnode->input(3); + + // key of setitem + is_in_set_ = true; + AnfVisitor::Visit(cnode->input(2)); + is_in_set_ = false; + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (IsValueNode(vnode)) { + auto key = GetValue(vnode->value()); + if (is_in_set_) { + key1_ = key; + } else { + c2_ = vnode; + key2_ = key; + } + } + } + + void Reset() { + key1_ = -1; + key2_ = -1; + c2_ = nullptr; + last_ = nullptr; + tuple_ = nullptr; + is_in_set_ = false; + } + + private: + bool is_in_set_{false}; + int key1_{-1}, key2_{-1}; + AnfNodePtr tuple_{nullptr}, last_{nullptr}, c2_{nullptr}; +}; + +// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> +// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} +class GetitemDependReorder : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + if (x_ == nullptr) { + return nullptr; + } + + auto fg = node->func_graph(); + auto item_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), x_, c_}, fg); + return NewCNode({NewValueNode(prim::kPrimDepend), item_node, y_}, fg); + } + + void Visit(const CNodePtr &cnode) override { + // {prim::kPrimDepend, X, Y} + if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && cnode->size() == 3) { + x_ = cnode->input(1); + y_ = cnode->input(2); + } + } + + void Visit(const ValueNodePtr &vnode) override { c_ = vnode; } + + void Reset() { + x_ = nullptr; + y_ = nullptr; + c_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; +}; + +class ItemTupleEliminater : public OptimizerCaller { + public: + ItemTupleEliminater() + : get_item_eliminater_(std::make_shared()), + get_item_const_eliminater_(std::make_shared()), + set_item_eliminater_(std::make_shared()), + get_set_item_eliminater_(std::make_shared()), + get_item_depend_reorder_(std::make_shared()) { + eliminaters_.emplace_back(get_item_eliminater_); + eliminaters_.emplace_back(get_item_const_eliminater_); + eliminaters_.emplace_back(set_item_eliminater_); + eliminaters_.emplace_back(get_set_item_eliminater_); + eliminaters_.emplace_back(get_item_depend_reorder_); + } + ~ItemTupleEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, + get_item_depend_reorder_; + std::vector eliminaters_{}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h b/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h new file mode 100644 index 0000000000..8d3839bd9e --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "utils/graph_utils.h" +#include "frontend/operator/composite/composite.h" + +namespace mindspore { +namespace opt { +namespace irpass { + +static int count = 0; + +std::string GetFusionNumber() { + std::stringstream ss; + ss << std::setw(4) << std::setfill('0') << count; + std::string num = ss.str(); + ++count; + + return "_" + num; +} + +// Mark CNodes which can be merged in kernel build +class MarkInterfaceFusion : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) { + auto cnode = node->cast(); + auto condition = cnode->input(1); + std::string cmp; + std::unordered_map cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"}, + {"LessEqual", "LE"}, {"Less", "LT"}, + {"Equal", "EQ"}, {"NotEqual", "NE"}}; + if (IsPrimitiveCNode(condition)) { + auto prim_name = GetCNodeFuncName(condition->cast()); + if (cmp_list.count(prim_name) != 0) { + // Mark Select and compare node + cmp = cmp_list[prim_name]; + auto cnt = GetFusionNumber(); + AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition); + AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) { + AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i)); + } + } + } + } + } + return nullptr; + } + + void Visit(const AnfNodePtr &) override {} + + private: + AnfNodePtr y_{nullptr}; +}; + +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h new file mode 100644 index 0000000000..a3cf6e2231 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h @@ -0,0 +1,320 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} -> +// {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}} +// {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} -> +// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}} +class MergeAddN : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + Reset(); + optimizer_ = optimizer; + is_outer_ = true; + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + + auto cnode = node->cast(); + auto addn = NewValueNode(GetValueNode(cnode->input(0))); + + // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs} + (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple)); + auto fg = node->func_graph(); + auto make_node = fg->NewCNode(args_); + + return fg->NewCNode({addn, make_node}); + } + + void Visit(const CNodePtr &cnode) override { + if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + return; + } + + auto &inputs = cnode->inputs(); + + if (is_outer_) { + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_)); + + is_outer_ = false; + is_inner_ = true; + + // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys} + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]); + if (is_match_) { + if (!is_unique(inputs[1])) { + is_match_ = false; + return; + } + (void)Ys_.erase(Ys_.begin()); + (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); + (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); + return; + } + + // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}} + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back()); + if (is_match_) { + if (!is_unique(inputs.back())) { + is_match_ = false; + return; + } + Ys_.pop_back(); + (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); + (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); + return; + } + + return; + } + + if (is_inner_) { + is_match_ = true; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); + } + } + + bool is_unique(const AnfNodePtr &node) { + auto mng = optimizer_->resource()->manager(); + auto &node_users = mng->node_users(); + if (node_users.find(node) == node_users.end()) { + return false; + } + + size_t n_use = node_users[node].size(); + return n_use == 1; + } + + void Reset() { + Xs_.clear(); + Ys_.clear(); + args_.clear(); + is_inner_ = false; + is_outer_ = false; + is_match_ = false; + } + + private: + OptimizerPtr optimizer_{nullptr}; + std::vector Xs_{}, Ys_{}, args_{}; + bool is_inner_{false}, is_outer_{false}, is_match_{false}; +}; + +// {PrimAddN, {kPrimMakeTuple, Xs}} +class AddNZeroFilter : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); + + if (filtered_Xs_.empty() || node->func_graph() == nullptr) { + return nullptr; + } + + // if only two node in filtered_nodes, {make_tuple, x}. return x. + if (filtered_Xs_.size() == 2) { + return filtered_Xs_[1]; + } + + // if only one node in filtered_nodes, all node is zerolike, return one of the input. + if (filtered_Xs_.size() == 1 && Xs_.size() > 0) { + return Xs_[0]; + } + + if (!has_zero_like_) { + return nullptr; + } + + auto cnode = node->cast(); + auto addn = NewValueNode(GetValueNode(cnode->input(0))); + auto fg = node->func_graph(); + auto make_tuple = fg->NewCNode(filtered_Xs_); + return fg->NewCNode({addn, make_tuple}); + } + + void Visit(const CNodePtr &cnode) override { + if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + return; + } + + auto &inputs = cnode->inputs(); + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); + + // {kPrimMakeTuple, X1, X2, ...} + filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (auto &x : Xs_) { + if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) { + filtered_Xs_.push_back(x); + } else { + has_zero_like_ = true; + } + } + } + + void Reset() { + Xs_.clear(); + filtered_Xs_.clear(); + has_zero_like_ = false; + } + + private: + std::vector filtered_Xs_{}, Xs_{}; + bool has_zero_like_{false}; +}; + +// {PrimAddN, {kPrimMakeTuple, Xs}} +// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd. +// case0: AddN(inputs)(inputs size < 2) -> error +// case1: AddN(inputs)(all inputs is ValueNode) -> error +// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor) +// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input) +// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) +class AddNEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + auto fg = GetValueNode(inputs[0]); + MS_EXCEPTION_IF_NULL(fg); + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + if (fg->recursive()) { + return nullptr; + } + + auto new_fg = TransformableClone(fg, std::make_shared("fg")); + mng->AddFuncGraph(new_fg); + need_update_ = false; + bool changed; + do { + changed = Process(new_fg); + } while (changed); + + if (!need_update_) { + return nullptr; + } else { + auto new_sx = inputs; + new_sx[0] = NewValueNode(new_fg); + return node->func_graph()->NewCNode(new_sx); + } + } + + bool Process(const FuncGraphPtr &func_graph) { + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto nodes = TopoSort(func_graph->output()); + bool changed = false; + + for (size_t i = 0; i < nodes.size(); ++i) { + auto node = nodes[i]; + if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { + continue; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto &tuple_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(tuple_input); + auto tuple_input_cnode = tuple_input->cast(); + MS_EXCEPTION_IF_NULL(tuple_input_cnode); + auto &tuple_inputs = tuple_input_cnode->inputs(); + if (tuple_inputs.size() < 3) { + // case0: inputs size < 2, error + MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2); + } + + int valuenode_num = + std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) { + if (IsValueNode(node)) { + return accumulator + 1; + } else { + return accumulator; + } + }); + if (IntToSize(valuenode_num) == tuple_inputs.size()) { + // case1: all inputs is ValueNode, error + MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2); + } + + if (tuple_inputs.size() == 3) { + // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) + MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); + ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); + std::vector new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], + tuple_inputs[2]}; + mng->Replace(node, func_graph->NewCNode(new_xs)); + changed = true; + continue; + } + + auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(), + [](const AnfNodePtr &node) { return IsValueNode(node); }); + if (first_valuenode == tuple_inputs.end()) { + // no ValueNode input found. + continue; + } else { + // case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) + std::vector make_tuple_new_xs{ + NewValueNode(prim::kPrimMakeTuple), + }; + std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(), + [&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) { + if (node != *first_valuenode) { + make_tuple_new_xs.push_back(node); + } + }); + ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); + auto new_addn = func_graph->NewCNode( + {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); + ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); + auto new_add = + func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); + (void)mng->Replace(node, new_add); + changed = true; + continue; + } + } + + need_update_ = need_update_ || changed; + return changed; + } + + private: + bool need_update_{false}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h b/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h new file mode 100644 index 0000000000..658a287234 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ + +#include +#include + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +// check if node is MinimumGrad() or MaximumGrad() +bool IsOriginMaxMinGrad(const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimMaximumGrad) && !IsPrimitiveCNode(node, prim::kPrimMinimumGrad)) { + return false; + } + + auto cnode = node->cast(); + auto prim = GetValueNode(cnode->input(0)); + auto x_v = prim->GetAttr("grad_x"); + auto y_v = prim->GetAttr("grad_y"); + if (x_v == nullptr || y_v == nullptr || !x_v->isa() || !y_v->isa()) { + return false; + } + + bool x = GetValue(x_v); + bool y = GetValue(y_v); + return x && y; +} +} // namespace internal + +// {prim::kPrimTupleGetItem, {target_grad, Xs}, C} +class MinMaximumGrad : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {internal::IsOriginMaxMinGrad, IsValueNode})(node); + if (grad_ == nullptr || idx_ < 0 || idx_ > 1 || node->func_graph() == nullptr) { + return nullptr; + } + + // check single use + auto mng = optimizer->resource()->manager(); + auto &users = mng->node_users(); + if (users.find(grad_) == users.end() || users[grad_].size() != 1) { + return nullptr; + } + + // {target_grad, Xs} + auto &inputs = grad_->inputs(); + auto prim = GetValueNode(inputs[0]); + + auto new_prim = std::make_shared(prim->name()); + new_prim->set_attr("grad_x", MakeValue(true)); + new_prim->set_attr("grad_y", MakeValue(true)); + + if (idx_ == 0) { + new_prim->set_attr("grad_y", MakeValue(false)); + } + if (idx_ == 1) { + new_prim->set_attr("grad_x", MakeValue(false)); + } + + std::vector args; + args.push_back(NewValueNode(new_prim)); + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + + auto fg = node->func_graph(); + auto tuple = fg->NewCNode(args); + + return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple, NewValueNode(MakeValue(idx_))}); + } + + void Visit(const CNodePtr &cnode) override { grad_ = cnode; } + + void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } + + void Reset() { + idx_ = -1; + grad_ = nullptr; + } + + private: + int idx_{-1}; + CNodePtr grad_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h b/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h new file mode 100644 index 0000000000..999376e528 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ + +#include + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/parse.h" + +namespace mindspore { +namespace opt { +namespace irpass { +class ReplaceOldParam : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + if (!IsParam(node)) { + return nullptr; + } + auto resource = std::dynamic_pointer_cast(optimizer->resource()); + MS_EXCEPTION_IF_NULL(resource); + + auto top_graph = resource->func_graph(); // parse::Parser::GetTopFuncGraph(); + MS_EXCEPTION_IF_NULL(top_graph); + + auto param_node = node->cast(); + if (!param_node->has_default() || node->func_graph() == top_graph) { + return nullptr; + } + auto para_name = param_node->name(); + for (const auto &tnode : top_graph->parameters()) { + auto para = tnode->cast(); + if (para != nullptr && para->name() == para_name) { + return para; + } + } + return nullptr; + } +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h new file mode 100644 index 0000000000..32fc5abc7d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} +class PartialEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + Xs_.clear(); + auto &inputs = node->cast()->inputs(); + Visit(inputs[0]); + + if (Xs_.size() == 0) { + return nullptr; + } + + // {X, Xs, Ys} + std::vector args{}; + (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args)); + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); + TraceManager::DebugTrace(std::make_shared(node->debug_info())); + auto new_node = node->func_graph()->NewCNode(args); + TraceManager::EndTrace(); + return new_node; + } + + void Visit(const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { + return; + } + + auto &inputs = node->cast()->inputs(); + // {prim::kPrimPartial, X, Xs} + if (inputs.size() < 2) { + return; + } + + // fill Xs + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); + } + + private: + std::vector Xs_{}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h new file mode 100644 index 0000000000..d8c96825c9 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" +#include "ir/visitor.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim, X} +class PrimEliminater : public AnfVisitor { + public: + explicit PrimEliminater(const PrimitivePtr &prim) : prim_(prim) {} + ~PrimEliminater() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + x_ = nullptr; + AnfVisitor::Match(prim_, {IsNode})(node); + return x_; + } + + void Visit(const AnfNodePtr &node) override { x_ = node; } + + private: + AnfNodePtr x_{nullptr}; + PrimitivePtr prim_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h new file mode 100644 index 0000000000..78b7d3f4f1 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "abstract/dshape.h" + +namespace mindspore { +namespace opt { +namespace irpass { +using abstract::Shape; +using abstract::ShapePtr; + +// {ReduceLike, X, axis} +class ReduceOneEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + PrimitivePtr prim; + if (IsPrimitiveCNode(node, prim::kPrimReduceMean) || IsPrimitiveCNode(node, prim::kPrimReduceAll) || + IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimReduceMax) || + IsPrimitiveCNode(node, prim::kPrimReduceMin)) { + prim = GetValueNode(node->cast()->input(0)); + AnfVisitor::Match(prim, {IsNode, IsVNode})(node); + if (!is_axis_one_) { + return nullptr; + } + + // consider keep_dims + auto keep_dims = prim->GetAttr("keep_dims"); + auto is_keep_dims = GetValue(keep_dims); + // {_Reduce, X, axis} -> X + if (is_keep_dims) { + return x_; + } + + // {_Reduce, Tensor} + if (is_tensor_) { + return nullptr; + } + + // {_Reduce, X, axis} -> {Reshape, X, new_shape} + std::vector elements; + for (size_t i = 0; i < x_shape_.size(); i++) { + auto iter = find(axis_.begin(), axis_.end(), i); + if (iter == axis_.end()) { + ValuePtr s = MakeValue(x_shape_[i]); + elements.push_back(s); + } + } + auto new_shape = std::make_shared(elements); + auto reshape_op = prim::GetPythonOps("reshape", "mindspore.ops.functional")->cast(); + return node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)}); + } + + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (!IsVNode(node) && x_ == nullptr) { + if (IsValueNode(node)) { + is_tensor_ = true; + } + // get X's shape + auto x_shape_abs = node->abstract(); + if (x_shape_abs != nullptr) { + auto x_track = x_shape_abs->GetShapeTrack()->cast(); + if (x_track == nullptr) { + return; + } + auto x_shape = x_track->shape(); + (void)std::copy(x_shape.begin(), x_shape.end(), std::back_inserter(x_shape_)); + x_ = node; + } + return; + } + + // check axis + AnfVisitor::Visit(node); + } + + void Visit(const ValueNodePtr &vnode) override { + if (x_shape_.empty()) { + return; + } + + // axis : int + if (IsValueNode(vnode)) { + auto idx = GetValue(vnode->value()); + // axis could be negative + if (idx < 0) { + idx += SizeToInt(x_shape_.size()); + } + if (SizeToInt(x_shape_.size()) > idx && x_shape_[IntToSize(idx)] == 1) { + is_axis_one_ = true; + axis_.push_back(idx); + } + return; + } + + // axis : tuple(int), default () + if (IsValueNode(vnode)) { + auto axis = GetValue>(vnode->value()); + if (axis.empty()) { + return; + } + + auto cmp = std::all_of(axis.cbegin(), axis.cend(), [this](int idx) { + // axis could be negative + if (idx < 0) { + idx += SizeToInt(x_shape_.size()); + } + return SizeToInt(this->x_shape_.size()) > idx && this->x_shape_[IntToSize(idx)] == 1; + }); + if (cmp) { + is_axis_one_ = true; + (void)std::copy(axis.begin(), axis.end(), std::back_inserter(axis_)); + } + } + } + + void Reset() { + axis_.clear(); + x_shape_.clear(); + x_ = nullptr; + is_axis_one_ = false; + is_tensor_ = false; + } + + private: + bool is_axis_one_{false}, is_tensor_{false}; + std::vector axis_{}, x_shape_{}; + AnfNodePtr x_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h new file mode 100644 index 0000000000..86eb4e761d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ + +#include + +#include "ir/pattern_matcher.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimMakeRef, X, Y, Z} -> Y +class MakeRefEliminater : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x, y, z; + MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y); + return nullptr; + } +}; + +// {prim::kPrimGetRefValue, Parameter} -> Parameter +// {prim::kPrimGetRefOrigin, Parameter} -> Parameter +class GetRefParamEliminater : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node)); + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node)); + return nullptr; + } +}; + +// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X +// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y +// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z +class GetMakeRefEliminater : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x, y, z; + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); + + return nullptr; + } +}; + +// IsValueNode +class ReplaceRefkeyByParam : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr { + auto refkey = GetValueNode(node); + auto resource = std::dynamic_pointer_cast(optimizer->resource()); + MS_EXCEPTION_IF_NULL(resource); + + auto top_graph = resource->func_graph(); + MS_EXCEPTION_IF_NULL(top_graph); + + for (const auto &tnode : top_graph->parameters()) { + auto para = tnode->cast(); + if (para != nullptr && para->name() == refkey->tag()) { + return para; + } + } + return nullptr; + }; + PatternNode x; + MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode, node)); + return nullptr; + } +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h new file mode 100644 index 0000000000..27d4bdad3d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ + +#include + +#include "ir/func_graph.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "abstract/dshape.h" + +namespace mindspore { +namespace opt { +namespace irpass { +using abstract::Shape; +using abstract::ShapePtr; + +// {reshape_op, X, Shape} +class ReshapeSameShapeEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimReshape, {IsNode, IsVNode})(node); + + // check pattern match + if (shape_ == nullptr) { + return nullptr; + } + + auto src_shape_abs = x_->abstract(); + if (src_shape_abs == nullptr) { + return nullptr; + } + + auto src_shape = src_shape_abs->GetShapeTrack(); + auto tgt_shape_abs = node->abstract(); + if (tgt_shape_abs == nullptr) { + return nullptr; + } + auto tgt_shape = tgt_shape_abs->GetShapeTrack(); + if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa() && tgt_shape->isa()) { + auto elements = tgt_shape->cast(); + auto shape = src_shape->cast(); + if (shape->shape() == elements->shape()) { + return x_; + } + } + + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } else { + shape_ = node; + } + } + + void Reset() { + x_ = nullptr; + shape_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}, shape_{nullptr}; +}; + +// {PrimReshape, {PrimReshape, X, Y}, Shape} +class TwoReshapeEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimReshape, {IsCNode, IsNode})(node); + + auto fg = node->func_graph(); + if (fg != nullptr && x_ != nullptr && shape_ != nullptr) { + auto new_node = fg->NewCNode({NewValueNode(prim_), x_, shape_}); + new_node->set_abstract(node->abstract()); + return new_node; + } + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (IsPrimitiveCNode(node, prim::kPrimReshape)) { + auto &inputs = node->cast()->inputs(); + // {PrimReshape, X, Y} + if (inputs.size() != 3) { + return; + } + prim_ = GetValueNode(inputs[0]); + x_ = inputs[1]; + } else { + shape_ = node; + } + } + + void Reset() { + prim_ = nullptr; + x_ = nullptr; + shape_ = nullptr; + } + + private: + PrimitivePtr prim_{nullptr}; + AnfNodePtr x_{nullptr}, shape_{nullptr}; +}; + +class ReshapeEliminater : public OptimizerCaller { + public: + ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} + ~ReshapeEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + auto new_node = reshape_same_shape_eliminater_(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + + new_node = two_reshape_eliminater_(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + + return nullptr; + } + + private: + ReshapeSameShapeEliminater reshape_same_shape_eliminater_; + TwoReshapeEliminater two_reshape_eliminater_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h new file mode 100644 index 0000000000..01efa85e8d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -0,0 +1,210 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ + +#include +#include +#include +#include + +#include "ir/optimizer_caller.h" +#include "ir/pattern_matcher.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/prim_eliminate.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +class SpecialOpEliminater : public OptimizerCaller { + public: + SpecialOpEliminater() + : insert_gradient_of_(std::make_shared(prim::kPrimInsertGradientOf)), + stop_gradient_(std::make_shared(prim::kPrimStopGradient)), + hook_backward_(std::make_shared(prim::kPrimHookBackward)), + print_shape_type_(std::make_shared(prim::kPrimPrintShapeType)), + get_ref_value_(std::make_shared(prim::kPrimGetRefValue)), + mirror_(std::make_shared(prim::kPrimMirror)), + virtual_div_(std::make_shared(prim::kPrimVirtualDiv)) { + eliminaters_.emplace_back(insert_gradient_of_); + eliminaters_.emplace_back(stop_gradient_); + eliminaters_.emplace_back(hook_backward_); + eliminaters_.emplace_back(print_shape_type_); + eliminaters_.emplace_back(get_ref_value_); + eliminaters_.emplace_back(mirror_); + eliminaters_.emplace_back(virtual_div_); + } + ~SpecialOpEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, + virtual_div_; + std::vector eliminaters_{}; +}; + +// {PrimVirtualDataset, X} -> X +// {PrimVirtualDataset, Xs} -> {prim::kPrimMakeTuple, Xs} +class VirtualDatasetEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimVirtualDataset) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 1) { + return nullptr; + } + + std::vector args; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); + if (args.size() == 1) { + return args.front(); + } + + (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple)); + + return node->func_graph()->NewCNode(args); + } + + void Visit(const AnfNodePtr &) override {} +}; + +// {prim::kPrimSameTypeShape, X, Y} -> X +class SameEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + x_ = nullptr; + AnfVisitor::Match(prim::kPrimSameTypeShape, {IsNode, IsNode})(node); + return x_; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } + } + + private: + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimCheckBprop, X, Y} -> X +class CheckBpropEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + x_ = nullptr; + AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node); + return x_; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } + } + + private: + AnfNodePtr x_{nullptr}; +}; + +// Reset defer_inline flag +class ResetDeferInline : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (IsValueNode(node)) { + auto fg = GetValueNode(node); + fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); + } + return nullptr; + } +}; + +// {PrimZerosLike, Y} -> +// {PrimFill, {PrimDType, Y}, {PrimShape, Y}, 0} +class ZeroLikeFillZero : public AnfVisitor { + public: + ZeroLikeFillZero() + : PrimFill_(prim::GetPythonOps("fill", "mindspore.ops.functional")->cast()), + PrimShape_(prim::GetPythonOps("shape", "mindspore.ops.functional")->cast()), + PrimDType_(prim::GetPythonOps("dtype", "mindspore.ops.functional")->cast()) {} + ~ZeroLikeFillZero() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + y_ = nullptr; + AnfVisitor::Match(prim::kPrimZerosLike, {IsNode})(node); + if (y_ == nullptr || node->func_graph() == nullptr) { + return nullptr; + } + if ((y_->abstract() == nullptr) || !y_->abstract()->isa()) { + auto fg = node->func_graph(); + auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); + auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); + return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); + } + + abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast(); + + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + tensor::TensorPtr new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + (void)memset_s(data, mem_size, 0, mem_size); + + auto new_cnode = NewValueNode(new_tensor_ptr); + new_cnode->set_abstract(new_tensor_ptr->ToAbstract()); + + return new_cnode; + } + + void Visit(const AnfNodePtr &node) override { y_ = node; } + + private: + AnfNodePtr y_{nullptr}; + PrimitivePtr PrimFill_, PrimShape_, PrimDType_; +}; + +// {prim::kPrimDepend, X, ValueCond}->X +class DependValueElim : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x, cond; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node))); + return nullptr; + } +}; + +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h b/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h new file mode 100644 index 0000000000..d8a15f6d83 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ + +#include +#include +#include +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/manager.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +class SpecializeTransform { + public: + SpecializeTransform() : cache_() {} + ~SpecializeTransform() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector graph_args, + std::vector prim_args, std::vector value_args) { + if (cache_.count(func_graph) == 0) { + cache_[func_graph] = {}; + } + + auto &cache = cache_[func_graph]; + auto key = std::make_pair(graph_args, prim_args); + if (cache.count(key) == 0) { + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + + FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared("sp")); + mng->AddFuncGraph(new_fg); + + std::vector params = new_fg->parameters(); + std::vector new_params; + size_t n = graph_args.size(); + for (size_t i = 0; i < n; i++) { + if (graph_args[i] != nullptr) { + auto arg = NewValueNode(graph_args[i]); + (void)mng->Replace(params[i], arg); + continue; + } + if (prim_args[i] != nullptr) { + auto arg = NewValueNode(prim_args[i]); + (void)mng->Replace(params[i], arg); + continue; + } + if (value_args[i] != nullptr) { + auto &const_tensor = *value_args[i]; + auto const_tensor_ptr = std::make_shared(const_tensor); + AnfNodePtr arg = NewValueNode(const_tensor_ptr); + (void)mng->Replace(params[i], arg); + continue; + } + new_params.push_back(params[i]); + } + + mng->SetParameters(new_fg, new_params); + cache[key] = new_fg; + } + return cache[key]; + } + + private: + std::unordered_map, std::vector>, FuncGraphPtr>> + cache_; +}; +} // namespace internal + +// {G, Xs} +class SpecializeOnGraphArguments : public AnfVisitor { + public: + SpecializeOnGraphArguments() : specialize_transform_() {} + ~SpecializeOnGraphArguments() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (!IsValueNode(inputs[0])) { + return nullptr; + } + + auto inp0_fg = GetValueNode(inputs[0]); + if (inp0_fg->recursive()) { + return nullptr; + } + + std::vector graph_args; + std::vector prim_args; + std::vector value_node_args; + std::vector new_xs; + bool hasVNode = false; + for (size_t i = 1; i < inputs.size(); i++) { + if (IsValueNode(inputs[i])) { + auto fg_vnode = GetValueNode(inputs[i]); + graph_args.push_back(fg_vnode); + prim_args.emplace_back(nullptr); + value_node_args.emplace_back(nullptr); + hasVNode = true; + } else if (IsValueNode(inputs[i])) { + auto p_vnode = GetValueNode(inputs[i]); + graph_args.emplace_back(nullptr); + prim_args.push_back(p_vnode); + value_node_args.emplace_back(nullptr); + hasVNode = true; + } else if (IsValueNode(inputs[i])) { + tensor::TensorPtr t_vnode = GetValueNode(inputs[i]); + graph_args.emplace_back(nullptr); + prim_args.emplace_back(nullptr); + value_node_args.emplace_back(t_vnode); + hasVNode = true; + } else { + graph_args.emplace_back(nullptr); + prim_args.emplace_back(nullptr); + value_node_args.emplace_back(nullptr); + new_xs.push_back(inputs[i]); + } + } + + if (!hasVNode) { + return nullptr; + } + + auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); + (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); + + return node->func_graph()->NewCNode(new_xs); + } + + private: + internal::SpecializeTransform specialize_transform_; +}; + +// Eliminate unused parameters. +// {G, Xs} +class UnusedParasEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + auto fg = GetValueNode(inputs[0]); + MS_EXCEPTION_IF_NULL(fg); + + std::vector parameters = fg->parameters(); + size_t size = parameters.size(); + if (size != inputs.size() - 1) { + return nullptr; + } + + std::vector new_xs; + std::vector keep_parameters; + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto &node_users = mng->node_users(); + bool has_unused_para = false; + for (size_t i = 0; i < size; ++i) { + auto iter = node_users.find(parameters[i]); + if (iter != node_users.end() && !iter->second.empty()) { + keep_parameters.push_back(true); + new_xs.push_back(inputs[i + 1]); + continue; + } + keep_parameters.push_back(false); + has_unused_para = true; + } + + if (!has_unused_para) { + return nullptr; + } + FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); + mng->AddFuncGraph(new_fg); + + std::vector new_fg_parameters = new_fg->parameters(); + std::vector new_parameters; + for (size_t i = 0; i < size; i++) { + if (keep_parameters[i]) { + if (parameters[i]->abstract() != nullptr) { + new_fg_parameters[i]->set_abstract(parameters[i]->abstract()); + } + new_parameters.push_back(new_fg_parameters[i]); + } + } + mng->SetParameters(new_fg, new_parameters); + + (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); + return node->func_graph()->NewCNode(new_xs); + } +}; + +// Eliminate unused outputs. +// {G, Xs} +class UnusedOutputEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + auto fg = GetValueNode(inputs[0]); + MS_EXCEPTION_IF_NULL(fg); + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + if (fg->recursive()) { + return nullptr; + } + + auto new_fg = TransformableClone(fg, std::make_shared("fg")); + mng->AddFuncGraph(new_fg); + auto new_fg_output = new_fg->output(); + if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) { + return nullptr; + } + + auto output_cnode = new_fg_output->cast(); + auto &node_users = mng->node_users(); + if (node_users.count(node) == 0 || node_users[node].empty()) { + return nullptr; + } + std::unordered_set used_output_idx; + std::vector> all_users; + for (auto &node_user : node_users[node]) { + if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) { + return nullptr; + } + auto user_cnode = node_user.first->cast(); + size_t used_idx = GetValue(user_cnode->input(2)->cast()->value()); + used_output_idx.insert(used_idx); + all_users.push_back(std::make_pair(node_user.first, used_idx)); + } + + if (used_output_idx.size() >= output_cnode->inputs().size() - 1) { + // all output has users. + return nullptr; + } + + if (used_output_idx.empty()) { + // we do not process this case. + return nullptr; + } else if (used_output_idx.size() == 1) { + // after eliminate, only one output left. + new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1)); + // update users. + for (auto &ret_user : all_users) { + (void)mng->Replace(ret_user.first, node); + } + } else { + // after eliminate, create new multi output. + std::vector new_output_inputs{output_cnode->input(0)}; + std::unordered_map new_idx_map; + for (auto idx : used_output_idx) { + new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1); + new_output_inputs.push_back(output_cnode->input(idx + 1)); + } + new_fg->set_output(new_fg->NewCNode(new_output_inputs)); + // update users. + for (auto &ret_user : all_users) { + auto ret_user_cnode = ret_user.first->cast(); + ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second])); + } + } + + auto new_sx = inputs; + new_sx[0] = NewValueNode(new_fg); + return node->func_graph()->NewCNode(new_sx); + } +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h new file mode 100644 index 0000000000..de9e533550 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ + +#include +#include + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/python_adapter.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimResolve, Ns, Sym} +class ResolverResolve : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimResolve, {IsVNode, IsVNode})(node); + if (sym_ != nullptr) { + return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node); + } + return nullptr; + } + + void Visit(const ValueNodePtr &vnode) override { + if (IsValueNode(vnode)) { + ns_ = GetValueNode(vnode); + } else if (ns_ != nullptr && IsValueNode(vnode)) { + sym_ = GetValueNode(vnode); + } + } + + void Reset() { + ns_ = nullptr; + sym_ = nullptr; + } + + private: + parse::NameSpacePtr ns_{nullptr}; + parse::SymbolPtr sym_{nullptr}; +}; + +// {prim::kPrimGetAttr, Ns, Str} +class ResolverGetattr : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimGetAttr, {IsVNode, IsVNode})(node); + if (sym_ != nullptr) { + return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node); + } + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (IsValueNode(node)) { + ns_ = GetValueNode(node); + } else if (ns_ != nullptr && IsValueNode(node)) { + auto str = GetValue(GetValueNode(node)); + sym_ = std::make_shared(str); + } + } + + void Reset() { + ns_ = nullptr; + sym_ = nullptr; + } + + private: + parse::NameSpacePtr ns_{nullptr}; + parse::SymbolPtr sym_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h new file mode 100644 index 0000000000..f561e04c10 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ + +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// check if node is value tuple and all one. e.g. (1, 1, 1) +// {PrimTile, X, MultiOne} +class TileMultiplyByOne : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTile, {IsNode, IsVNode})(node); + + // check pattern match + if (tuple_ == nullptr) { + return nullptr; + } + + auto value = GetValueNode(tuple_); + auto elements = GetValue>(value); + if (elements.empty()) { + return nullptr; + } + + auto cmp = std::all_of(elements.cbegin(), elements.cend(), [](int i) { return i == 1; }); + if (cmp) { + return x_; + } + + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } else { + tuple_ = node; + } + } + + void Reset() { + x_ = nullptr; + tuple_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}, tuple_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h new file mode 100644 index 0000000000..70b8898462 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ + +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// check if node is value tuple and ascends one by one from zero. e.g., (0, 1, 2, 3) +// {PrimTranspose, X, AscendingNums} +class TransposeSameIOEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTranspose, {IsNode, IsVNode})(node); + + // check pattern match + if (tuple_ == nullptr) { + return nullptr; + } + + auto value = GetValueNode(tuple_); + auto elements = GetValue>(value); + if (elements.empty()) { + return nullptr; + } + + int j = 0; + bool cmp = std::all_of(elements.cbegin(), elements.cend(), [&j](int i) { return i == j++; }); + // same IO settings, eliminate this transpose + if (cmp) { + return x_; + } + + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } else { + tuple_ = node; + } + } + + void Reset() { + x_ = nullptr; + tuple_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}, tuple_{nullptr}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc new file mode 100644 index 0000000000..44917106fa --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/opt.cc @@ -0,0 +1,241 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/opt.h" + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/manager.h" +#include "frontend/optimizer/optimizer.h" +#include "utils/log_adapter.h" +#include "utils/ordered_set.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, + const RenormAction &renorm_action) { + auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; + return std::make_shared(transform, name, fn, renorm_action); +} + +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, + const std::vector &prims, const RenormAction &renorm_action) { + auto fn = [prims](const AnfNodePtr &node) -> bool { + if (!node->isa()) { + return false; + } + + auto cnode = node->cast(); + auto inp0 = cnode->input(0); + auto prim0 = GetValueNode(inp0); + if (prim0 == nullptr) { + return false; + } + + auto hash = prim0->Hash(); + auto const &name = prim0->name(); + for (auto &prim : prims) { + if (hash == prim->Hash() && name == prim->name()) { + return true; + } + } + return false; + }; + + return std::make_shared(transform, name, fn, renorm_action); +} + +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, + const PredicateFuncType &predicate, const RenormAction &renorm_action) { + return std::make_shared(transform, name, predicate, renorm_action); +} + +AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { +#ifdef ENABLE_PROFILE + double t = GetTime(); +#endif + AnfNodePtr result = (*transform_)(optimizer, node); +#ifdef ENABLE_PROFILE + if (optimizer != nullptr) { + auto time = GetTime(); + MsProfile::StatTime("substitution." + name_, time - t); + if (result != nullptr) { + MsProfile::StatTime("match." + name_, time - t); + } + } +#endif + if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { + if ((renorm_action_ == FORCE_RENORM) || (result->abstract() == nullptr)) { + optimizer->set_is_untyped_generated(); + } + } + + return result; +} + +static bool isTraversable(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + if (node->isa() || node->isa()) { + return true; + } + if (IsValueNode(node) || IsValueNode(node)) { + return true; + } + return false; +} + +bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, + const SubstitutionPtr &transform) const { +#ifdef ENABLE_PROFILE + double start = GetTime(); +#endif + FuncGraphManagerPtr manager = optimizer->manager(); + auto seen = NewSeenGeneration(); + // 1024 is for the initial capacity of deque + std::deque todo(1024); + todo.clear(); + todo.push_back(root_node); + bool changes = false; + + auto &all_nodes = manager->all_nodes(); + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + + // check whether this node has been matched. + if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { + continue; + } + node->seen_ = seen; + + // select nodes that this transform can be applied. + bool is_match = transform->predicate_(node); + + // apply transform on this node + bool change = false; + if (is_match) { + auto ret = (*transform)(optimizer, node); + if (ret != nullptr && ret != node) { + change = true; + changes = true; +#ifdef ENABLE_PROFILE + double t = GetTime(); +#endif + (void)manager->Replace(node, ret); +#ifdef ENABLE_PROFILE + MsProfile::StatTime("replace." + transform->name_, GetTime() - t); +#endif + node = ret; + } + } + + // find success, and add them to todo list + if (IsValueNode(node)) { + todo.push_back(GetValueNode(node)->output()); + } + + if (node->isa()) { + auto &inputs = node->cast()->inputs(); + (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); + } + + auto &node_users = manager->node_users(); + if (change && node_users.find(node) != node_users.end()) { + for (auto &use : node_users[node]) { + auto use_node = use.first; + if (use_node == nullptr) { + continue; + } + todo.push_back(use_node); + if (use_node->seen_ == seen) { + use_node->seen_--; + } + } + } + } + +#ifdef ENABLE_PROFILE + MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start); +#endif + return changes; +} + +bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { + MS_EXCEPTION_IF_NULL(optimizer); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = optimizer->manager(); + manager->AddFuncGraph(func_graph); + + // for transform status counting + size_t space = 0; + std::unordered_map> status; + if (optimizer->is_on_debug_) { + for (size_t i = 0; i < list_.size(); i++) { + status[list_[i]->name_ + std::to_string(i)] = {}; + } + } + + bool loop = false; + bool changes = false; + + do { + loop = false; + for (size_t i = 0; i < list_.size(); i++) { + auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]); + changes = changes || change; + loop = loop || change; + + // record the status of each transform + if (optimizer->is_on_debug_) { + status[list_[i]->name_ + std::to_string(i)].push_back(change); + space = std::max(list_[i]->name_.size(), space); + } + } + + if (is_once_) { + break; + } + } while (loop); + + // display the status of each transform + if (optimizer->is_on_debug_) { + std::stringstream ss; + ss << std::endl + << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name + << std::endl; + for (size_t i = 0; i < list_.size(); i++) { + auto name = list_[i]->name_; + ss << std::left << std::setw(space + 4) << name << "\t"; + for (auto change : status[name + std::to_string(i)]) { + ss << change << " "; + } + ss << std::endl; + } + MS_LOG(DEBUG) << ss.str(); + } + + return changes; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/opt.h b/mindspore/ccsrc/frontend/optimizer/opt.h new file mode 100644 index 0000000000..f440cc71dc --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/opt.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ + +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/optimizer_caller.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { + +// Define the interaction mode between an Optimize pass and Renormalize pass +// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed +// CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executted +enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; + +class Substitution { + public: + OptimizerCallerPtr transform_; + std::string name_; + PredicateFuncType predicate_{nullptr}; + // an enum to mark this Substitution relation to renormalize pass + RenormAction renorm_action_; + Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, + const RenormAction &renorm_action) + : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} + ~Substitution() = default; + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); +}; + +using SubstitutionPtr = std::shared_ptr; + +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, + const RenormAction &action_renorm = CHECK_RENORM); +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, + const std::vector &prims, + const RenormAction &action_renorm = CHECK_RENORM); +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, + const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); + +class SubstitutionList { + public: + explicit SubstitutionList(const std::vector &patterns, bool is_once = false) + : list_(patterns), is_once_(is_once) {} + ~SubstitutionList() = default; + + bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; + + private: + bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const; + std::vector list_; + // a flag to mark this list of Substitution can only be executed only once + bool is_once_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/optimizer.h b/mindspore/ccsrc/frontend/optimizer/optimizer.h new file mode 100644 index 0000000000..a1f11e74d0 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/optimizer.h @@ -0,0 +1,242 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "debug/draw.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" +#include "debug/trace.h" +#include "frontend/optimizer/opt.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/action.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +using OptimizeGraphFunc = std::function; + +class OptPassConfig { + public: + explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} + explicit OptPassConfig(const std::vector &list, bool is_once = false) + : list_(list), is_once_(is_once) {} + OptPassConfig(const std::initializer_list &list, bool is_once = false) + : list_(list), is_once_(is_once) {} + ~OptPassConfig() = default; + + const std::vector &list() const { return list_; } + const OptimizeGraphFunc &func() const { return func_; } + + static OptPassConfig Renormalize() { return OptPassConfig(); } + const bool is_renormalize() const { return is_renormalize_; } + + const bool is_once() const { return is_once_; } + + private: + OptPassConfig() : is_renormalize_(true) {} + + OptimizeGraphFunc func_; + std::vector list_; + bool is_renormalize_{false}; + bool is_once_{false}; +}; + +class OptPass { + public: + explicit OptPass(const OptimizeGraphFunc &func) : pass_func_(func) {} + ~OptPass() = default; + + bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { + return pass_func_(func_graph, optimizer); + } + + static OptPass Renormalize() { return OptPass(); } + const bool is_renormalize() const { return is_renormalize_; } + + private: + OptPass() : is_renormalize_(true) {} + + OptimizeGraphFunc pass_func_; + bool is_renormalize_{false}; +}; +using OptPassGroupMap = std::vector>; + +class Optimizer : public std::enable_shared_from_this { + public: + Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) + : name_(name), + resource_(resource_ptr), + run_only_once_(false), + is_watch_renormalize_(false), + is_enable_(true), + is_untyped_generated_(false) {} + virtual ~Optimizer() = default; + + void Init(const OptPassGroupMap &passes, bool run_only_once) { + run_only_once_ = run_only_once; + is_watch_renormalize_ = false; + is_untyped_generated_ = false; + is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG); + + for (auto &iter : passes) { + const std::string &name = iter.first; + pass_names_.push_back(name); + + const OptPassConfig &config = iter.second; + if (config.is_renormalize()) { + passes_.push_back(OptPass::Renormalize()); + continue; + } + + if (config.list().size() > 0) { + OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once()); + passes_.push_back(OptPass(func)); + continue; + } + + passes_.push_back(OptPass(config.func())); + } + + if (passes_.size() == 1) { + run_only_once_ = true; + } + } + + static std::shared_ptr MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, + const OptPassGroupMap &passes, bool run_only_once = false, + bool watch_renormalize = false) { + OptimizerPtr optimizer = std::make_shared(name, resource_ptr); + optimizer->Init(passes, run_only_once); + if (watch_renormalize) { + optimizer->enable_watch_renormalize(); + } + return optimizer; + } + + FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { + if (!is_enable_) { + return func_graph; + } + // Optimizer step counter; + int counter = 1; + bool changes = true; + + while (changes) { + changes = false; + auto run_runc = [&counter, &func_graph, &changes, use_profile, this]() { + for (size_t i = 0; i < passes_.size(); ++i) { + const OptPass &opt = passes_[i]; + CurPass_ = {counter, pass_names_[i]}; + auto opt_func = [&func_graph, &changes, &opt, this]() { + if (opt.is_renormalize()) { + auto resource_ptr = std::dynamic_pointer_cast(resource_); + if (resource_ptr != nullptr) { + // StepParallel may replace the AbstractValue of the parameters of func_graph, + // So generate the args_spec from parameters. + abstract::AbstractBasePtrList maybe_new_args_spec; + if (is_watch_renormalize_) { + if (is_untyped_generated_) { + std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), + std::back_inserter(maybe_new_args_spec), + [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); + func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); + clear_is_untyped_generated(); + } else { + MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False."; + } + } else { + std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), + std::back_inserter(maybe_new_args_spec), + [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); + func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); + } + } + } else if (opt(func_graph, shared_from_this())) { + changes = true; + } + }; + use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); + if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) { + MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; + auto fg_name = + "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; + func_graph->DumpFuncGraph(fg_name); + DumpIR(fg_name + ".ir", func_graph); + ExportIR(fg_name + ".dat", "", func_graph); + MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; + } + } + }; + use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter)) run_runc) : run_runc(); + counter++; + + if (run_only_once_) { + break; + } + } + return func_graph; + } + + pipeline::ResourceBasePtr resource() const { return resource_; } + FuncGraphManagerPtr manager() const { + if (resource_ != nullptr) { + return resource_->manager(); + } + MS_LOG(EXCEPTION) << "No ResourceBase exists."; + } + + const std::string name() const { return name_; } + + void set_is_untyped_generated() { is_untyped_generated_ = true; } + void clear_is_untyped_generated() { is_untyped_generated_ = false; } + + void enable_watch_renormalize() { is_watch_renormalize_ = true; } + void disable_watch_renormalize() { is_watch_renormalize_ = false; } + bool is_watch_renormalize() { return is_watch_renormalize_; } + void set_enable(bool enable) { is_enable_ = enable; } + + struct { + int counter; + std::string name; + } CurPass_; + + bool is_on_debug_{false}; + + private: + const std::string name_; + pipeline::ResourceBasePtr resource_; + std::vector passes_; + std::vector pass_names_; + bool run_only_once_; + bool is_watch_renormalize_; + bool is_enable_; + bool is_untyped_generated_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/pass_group.cc b/mindspore/ccsrc/frontend/optimizer/pass_group.cc new file mode 100644 index 0000000000..3619396215 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/pass_group.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "frontend/optimizer/pass_group.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +void PassGroup::AddPass(const PythonPassPtr &pass) { + if (pass != nullptr) { + passes_.push_back(pass); + } +} + +bool PassGroup::DeletePass(const std::string &pass_name) { + for (auto iter = passes_.begin(); iter != passes_.end(); iter++) { + if ((*iter)->name() == pass_name) { + *iter = nullptr; + passes_.erase(iter); + return true; + } + } + return false; +} + +bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { + if (func_graph == nullptr) { + return false; + } + bool changed = false; + for (const auto &pass : passes) { + if (pass != nullptr) { + if (pass->Run(func_graph)) { + changed = true; + } + } + } + return changed; +} + +bool PassGroup::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + // run all passes + bool change = true; + while (change) { + change = Run(func_graph, passes_); + changed = change || changed; + if (run_only_once_) { + break; + } + } + return changed; +} + +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/pass_group.h b/mindspore/ccsrc/frontend/optimizer/pass_group.h new file mode 100644 index 0000000000..08fa8018d6 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/pass_group.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ + +#include +#include +#include +#include + +#include "frontend/optimizer/py_pass.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +class PassGroup { + public: + explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false) + : name_(name), passes_{}, run_only_once_(run_only_once) {} + virtual ~PassGroup() = default; + // Add graph pass, the pass object will be freed when pass manager freed. + void AddPass(const PythonPassPtr &pass); + // Delete graph pass before the pass manager is freed. + bool DeletePass(const std::string &pass_name); + // Run passes added in pass manager on the input graph + // @param [inout] graph The graph to be optimized + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph) const; + // Run the given graph passes on the input graph + // @param [inout] graph The graph to be optimized + // @param [in] passes The given graph passes + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; + std::string name() const { return name_; } + + private: + const std::string name_; + std::vector passes_; + bool run_only_once_; +}; +using PassGroupPtr = std::shared_ptr; +} // namespace python_pass +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc new file mode 100644 index 0000000000..c1bf40fcbb --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -0,0 +1,237 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "frontend/optimizer/py_pass.h" +#include +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/manager.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +namespace internal { +std::string GetNodeRepr(AnfNodePtr node) { + if (node != nullptr) { + if (node->isa()) { + std::string repr = "("; + auto const &inputs = node->cast()->inputs(); + for (auto &input : inputs) { + repr += " "; + repr += GetNodeRepr(input); + repr += " "; + } + repr += ")"; + return repr; + } + if (node->isa()) { + return GetValueNode(node)->ToString(); + } + return node->ToString(); + } + return ""; +} + +void ResolveFuncGraph_(const FuncGraphPtr &fg) { + auto manager = Manage(fg, false); + parse::python_adapter::set_use_signature_in_resolve(false); + parse::ResolveAll(manager); + parse::python_adapter::set_use_signature_in_resolve(true); +} + +bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { + if (node == nullptr) { + return false; + } + MS_EXCEPTION_IF_NULL(pattern); + if (pattern->isa()) { + if (!node->isa()) { + return false; + } + if (GetNodeRepr(pattern) == GetNodeRepr(node)) { + // add to equiv_ptr + equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); + return true; + } + return false; + } else if (pattern->isa()) { + MS_LOG(DEBUG) << pattern->ToString() + "\n"; + // add to equiv_ptr + equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); + return true; + } else if (pattern->isa()) { + // match every single sub ANode + if (!node->isa()) { + return false; + } + auto pattern_inputs = pattern->cast()->inputs(); + auto node_inputs = node->cast()->inputs(); + if (pattern_inputs.size() != node_inputs.size()) { + return false; + } + for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); + p_item++, node_item++) { + auto res = Match(*p_item, *node_item, equiv_ptr); + if (!res) { + return false; + } + } + return true; + } + MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; +} + +AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, + const NodeEquivPtr &equiv_ptr) { + if (cur_raw_dst_node_->isa()) { + auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); + if (sub_pair != equiv_ptr->end()) { + return sub_pair->second; + } + MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; + } else if (cur_raw_dst_node_->isa()) { + // check primitive ValueNode + auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast()->value()->ToString()); + if (sub_pair != equiv_ptr->end()) { + return sub_pair->second; + } + return cur_raw_dst_node_; + } else if (cur_raw_dst_node_->isa()) { + std::vector new_inputs; + auto inputs = cur_raw_dst_node_->cast()->inputs(); + for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { + auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); + new_inputs.push_back(subed); + } + return func_graph->NewCNode(new_inputs); + } + MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); +} + +bool isTraversable(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + if (node->isa() || node->isa()) { + return true; + } + if (IsValueNode(node) || IsValueNode(node)) { + return true; + } + return false; +} +} // namespace internal + +void PythonPass::Build(const py::function &src, const py::function &dst) { + // 1. get FuncGraph from py::function + auto src_fg_ = parse::ParsePythonCode(src); + auto dst_fg_ = parse::ParsePythonCode(dst); + if (src_fg_ == nullptr || dst_fg_ == nullptr) { + MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; + } + // 2. Resolve + internal::ResolveFuncGraph_(src_fg_); + internal::ResolveFuncGraph_(dst_fg_); + // 3. from FuncGraphPtr to ValueNode + src_node_ = src_fg_->output(); + dst_node_ = dst_fg_->output(); +} + +PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, + bool multigraph) + : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { + Build(src, dst); +} + +AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + auto equiv_ptr = std::make_shared(); + bool is_a_match = internal::Match(src_node_, node, equiv_ptr); + if (is_a_match) { + auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); + MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; + return new_node; + } + return nullptr; +} + +bool PythonPass::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(func_graph); + auto seen = NewSeenGeneration(); + // 1024 is for the initial capacity of deque + std::deque todo(1024); + todo.push_back(func_graph->output()); + bool changes = false; + + auto &all_nodes = manager->all_nodes(); + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + + // check whether this node has been matched. + if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { + continue; + } + node->seen_ = seen; + + // select nodes that this transform can be applied. + AnfNodePtr new_node = Run(func_graph, node); + bool change = (new_node != nullptr); + if (new_node != nullptr && new_node != node) { + (void)manager->Replace(node, new_node); + } else if (new_node == nullptr) { + new_node = node; + } + if (run_only_once_) { + return change; + } + + // find success, and add them to todo list + if (IsValueNode(node)) { + todo.push_back(GetValueNode(node)->output()); + } + + if (node->isa()) { + auto &inputs = node->cast()->inputs(); + (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); + } + + auto &node_users = manager->node_users(); + if (change && node_users.find(node) != node_users.end()) { + for (auto &use : node_users[node]) { + auto use_node = use.first; + if (use_node == nullptr) { + continue; + } + todo.push_back(use_node); + if (use_node->seen_ == seen) { + use_node->seen_--; + } + } + } + } + return changes; +} +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/py_pass.h b/mindspore/ccsrc/frontend/optimizer/py_pass.h similarity index 100% rename from mindspore/ccsrc/optimizer/py_pass.h rename to mindspore/ccsrc/frontend/optimizer/py_pass.h diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc new file mode 100644 index 0000000000..86d7067d1c --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "frontend/optimizer/py_pass_manager.h" + +#include +#include +#include +#include + +#include "ir/manager.h" +#include "frontend/optimizer/pass_group.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +PyPassManagerPtr PyPassManager::global_instance = nullptr; +std::unordered_map PyPassManager::phase_to_group_; + +PassGroupPtr PyPassManager::GetPassGroup(Phase phase) { + auto pm = phase_to_group_.find(phase); + if (pm == phase_to_group_.end()) { + return nullptr; + } + return pm->second; +} + +PyPassManagerPtr PyPassManager::GetInstance() { + if (global_instance == nullptr) { + global_instance = std::shared_ptr(new (std::nothrow) PyPassManager()); + } + return global_instance; +} + +PyPassManager::PyPassManager() { + phase_to_group_[Phase::RESOLVE] = std::make_shared(); + phase_to_group_[Phase::OPT] = std::make_shared(); +} + +void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, + Phase phase, bool run_only_once, bool multigraph) { + auto cur_pm = GetPassGroup(phase); + MS_EXCEPTION_IF_NULL(cur_pm); + PythonPassPtr new_pass = std::make_shared(pass_name, pattern, target, run_only_once, multigraph); + cur_pm->AddPass(new_pass); +} + +void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { + auto cur_pm = GetPassGroup(phase); + MS_EXCEPTION_IF_NULL(cur_pm); + if (!cur_pm->DeletePass(pass_name)) { + MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; + } +} + +void PyPassManager::ClearRes() { + MS_LOG(INFO) << "Clear PyPassManager resources!"; + global_instance = nullptr; + phase_to_group_.clear(); +} + +REGISTER_PYBIND_DEFINE( + PyPassManager_, ([](const py::module *m) { + (void)py::enum_(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT); + (void)py::class_>(*m, "PyPassManager_") + .def(py::init([]() { return PyPassManager::GetInstance(); })) + .def("registe", &PyPassManager::Registe, "Registe python pass") + .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass"); + })); +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h new file mode 100644 index 0000000000..84868862a7 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ + +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/primitive_py.h" +#include "utils/graph_utils.h" +#include "common/utils.h" + +#include "pipeline/jit/parse/resolve.h" +#include "frontend/optimizer/py_pass.h" +#include "frontend/optimizer/pass_group.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +class PyPassManager; +using PyPassManagerPtr = std::shared_ptr; + +enum Phase { RESOLVE, OPT }; + +class PyPassManager { + protected: + PyPassManager(); + static PyPassManagerPtr global_instance; + + public: + // Singletons should not be cloneable and assignable + PyPassManager(const PyPassManager &other) = delete; + void operator=(const PyPassManager &) = delete; + // Access the only global instance + static PyPassManagerPtr GetInstance(); + virtual ~PyPassManager() = default; + void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, + Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); + void Unregiste(const std::string &pass_name, Phase phase); + PassGroupPtr GetPassGroup(Phase phase); + void ClearRes(); + + private: + static std::unordered_map phase_to_group_; +}; +} // namespace python_pass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/CMakeLists.txt b/mindspore/ccsrc/frontend/parallel/CMakeLists.txt new file mode 100644 index 0000000000..d2a099cf41 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/CMakeLists.txt @@ -0,0 +1,8 @@ +file(GLOB_RECURSE _PARALLEL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/util.cc" "ps/scheduler.cc" "ps/optimizer_info.cc" "ps/optimizer_info_builder.cc") +if (ENABLE_DUMP_PROTO) + list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") +endif () + +set_property(SOURCE ${_PARALLEL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARALLEL) +add_library(_mindspore_frontend_parallel_obj OBJECT ${_PARALLEL_SRC_FILES}) diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc new file mode 100644 index 0000000000..70ae5a7d20 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc @@ -0,0 +1,435 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h" +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "frontend/parallel/costmodel_context.h" +#include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/step_parallel.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +std::unordered_set FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) { + if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { + MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " + << MAX_RECURSIVE_CALL_TIMES; + } + MS_EXCEPTION_IF_NULL(para); + MS_EXCEPTION_IF_NULL(para->func_graph()); + FuncGraphManagerPtr manager = para->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto node_set = manager->node_users()[para]; + std::unordered_set cnode_set; + for (auto &node_pair : node_set) { + auto cnode = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + continue; + } + auto node_prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == DEPEND && node_pair.second != 1) { + continue; + } + if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { + (void)cnode_set.emplace(cnode); + } else { + auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); + for (auto &cnode_sub : cnode_set_sub) { + (void)cnode_set.emplace(cnode_sub); + } + } + } + return cnode_set; +} + +Status AllreduceFusion::AddNodeToGraph() { + const auto ¶meters = root_graph_->parameters(); + for (auto ¶meter : parameters) { + if (!ParameterRequireGrad(parameter)) { + continue; + } + auto cnode_set = FindCNodesWithPara(parameter); + if (cnode_set.empty()) { + continue; + } + for (auto &cnode : cnode_set) { + MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); + if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { + MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); + return FAILED; + } + } + } + return SUCCESS; +} + +CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { + if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { + MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " + << MAX_RECURSIVE_CALL_TIMES; + } + MS_EXCEPTION_IF_NULL(from); + std::unordered_map cnode_dist; + if (!from->isa()) { + return cnode_dist; + } + auto cnode = from->cast(); + if (!IsValueNode(cnode->input(0))) { + return cnode_dist; + } + + MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) + << " operator_info: " << (cnode->operator_info() != nullptr); + + if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode(); + MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; + + if (allreduce_graph_.NodeInGraph(cnode)) { + cnode_dist[cnode] = cost; + return cnode_dist; + } else { + auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); + for (auto &ele : cnode_dist_next) { + cnode_dist[ele.first] = cost + ele.second; + } + } + } else { + auto cnode_dist_next = FindNextCNodes(cnode); + for (auto &ele : cnode_dist_next) { + cnode_dist[ele.first] = ele.second; + } + } + return cnode_dist; +} + +CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { + if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { + MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " + << MAX_RECURSIVE_CALL_TIMES; + } + const auto &from_inputs = from->inputs(); + std::unordered_map dist_map; + MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; + for (auto &input_node : from_inputs) { + auto cnode_dist = FindCNode(input_node, recursive_times + 1); + for (auto &ele : cnode_dist) { + (void)dist_map.emplace(ele); + } + } + return dist_map; +} + +Status AllreduceFusion::AddEdgeToGraph() { + std::unordered_map cnode_state_map; + const auto &cnodes = allreduce_graph_.cnode_set(); + for (auto &cnode : cnodes) { + cnode_state_map[cnode] = 0; + } + const auto &head_cnode = allreduce_graph_.head_cnode(); + std::queue cnode_queue; + cnode_queue.emplace(head_cnode); + cnode_state_map[head_cnode] = 1; + + while (!cnode_queue.empty()) { + const auto cur_cnode = cnode_queue.front(); + cnode_queue.pop(); + cnode_state_map[cur_cnode] = 2; + auto next = FindNextCNodes(cur_cnode); + for (auto &ele : next) { + auto &cnode = ele.first; + auto &dist = ele.second; + if (cnode_state_map[cnode] == 0) { + cnode_queue.emplace(cnode); + cnode_state_map[cnode] = 1; + } + if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) { + MS_LOG(ERROR) << "AddEdge error"; + return FAILED; + } + MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist; + } + } + return SUCCESS; +} + +std::vector FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) { + if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { + MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is " + << MAX_RECURSIVE_CALL_TIMES; + } + MS_EXCEPTION_IF_NULL(para); + MS_EXCEPTION_IF_NULL(para->func_graph()); + FuncGraphManagerPtr manager = para->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[para]; + std::vector cnode_list; + for (auto &node_pair : node_set) { + auto cnode = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + continue; + } + auto node_prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == CAST) { + auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1); + if (mirror_cnodes.empty()) { + MS_LOG(WARNING) << "mirror node after cast not found"; + continue; + } + if (mirror_cnodes.size() > 1) { + MS_LOG(EXCEPTION) << "mirror node after cast number is not 1"; + } + cnode_list.emplace_back(mirror_cnodes[0]); + } + if (node_prim->name() == MIRROR_OPERATOR) { + cnode_list.emplace_back(cnode); + } + } + return cnode_list; +} + +void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) { + MS_EXCEPTION_IF_NULL(mirror_cnode); + MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; + auto node_prim = GetValueNode(mirror_cnode->input(0)); + auto old_value_ptr = node_prim->GetAttr(FUSION); + if (old_value_ptr != nullptr) { + if (old_value_ptr->isa()) { + int32_t old_value = old_value_ptr->cast()->value(); + if (old_value < fusion) { + return; + } + } + } + (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared(fusion))); + (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared(parameter_name))); +} + +Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int32_t fusion) { + auto mirror_cnodes = FindMirror(para); + if (mirror_cnodes.empty()) { + MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; + return SUCCESS; + } + if (mirror_cnodes.size() > 2) { + for (auto &mirror_cnode : mirror_cnodes) { + MS_EXCEPTION_IF_NULL(mirror_cnode); + MS_LOG(INFO) << mirror_cnode->DebugString(); + } + MS_EXCEPTION_IF_NULL(para); + MS_LOG(ERROR) << para->ToString() << " FindMirror is more than 2. " << mirror_cnodes.size() + << "Mirror CNode found."; + return FAILED; + } + for (auto &mirror_cnode : mirror_cnodes) { + auto parameter_name = ParameterName(para); + SetMirrorFusion(mirror_cnode, fusion, parameter_name); + } + return SUCCESS; +} + +Status FindMirrorAndSetFusion(const std::vector ¶s, int32_t fusion) { + for (auto ¶m_node : paras) { + if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { + MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; + return FAILED; + } + } + return SUCCESS; +} + +Status AllreduceFusion::SetFusion(const std::vector &cost_map) { + if (cost_map.size() < 2) { + MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); + return FAILED; + } + int32_t fusion = 1; + for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) { + auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter); + if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { + MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; + return FAILED; + } + fusion++; + } + return SUCCESS; +} + +std::vector AllreduceFusion::GenerateCostMap(int32_t fusion_times, double tail_percent) const { + double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1); + MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset; + std::vector cost_map; + double begin = 0; + for (auto i = 0; i < fusion_times - 1; i++) { + cost_map.push_back(begin); + begin += offset; + } + cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent)); + cost_map.push_back(allreduce_graph_.max()); + MS_LOG(DEBUG) << "cost_map = " << cost_map; + return cost_map; +} + +Status AllreduceFusion::SetFusionByBackwardCompTime() { + auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times(); + if (fusion_times < 2) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion"; + return SUCCESS; + } + auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent(); + if (tail_percent < 0 || tail_percent >= 1) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent + << ". Bypass ProcessAllreduceFusion"; + return SUCCESS; + } + const auto cost_map = GenerateCostMap(fusion_times, tail_percent); + MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed."; + if (SetFusion(cost_map) != SUCCESS) { + MS_LOG(ERROR) << "SetFusion failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed."; + return SUCCESS; +} + +Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() { + tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time(); + if (tail_time_ <= 0) { + MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion"; + return FAILED; + } + allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time(); + if (allreduce_inherent_time_ <= 0) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ + << ". Bypass ProcessAllreduceFusion"; + return FAILED; + } + if (tail_time_ <= allreduce_inherent_time_) { + MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ + << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ + << ".tail_time is not more than allreduce_inherent_time. Bypass ProcessAllreduceFusion"; + return FAILED; + } + allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth(); + if (allreduce_bandwidth_ <= 0) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_ + << ". Bypass ProcessAllreduceFusion"; + return FAILED; + } + computation_time_parameter_ = + CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter(); + if (computation_time_parameter_ <= 0) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_ + << ". Bypass ProcessAllreduceFusion"; + return FAILED; + } + return SUCCESS; +} + +Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() { + if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) { + MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!"; + return FAILED; + } + allreduce_graph_.SortArnode(); + if (allreduce_graph_.RemoveExtraParas() != SUCCESS) { + MS_LOG(ERROR) << "RemoveExtraParas failed!"; + return FAILED; + } + double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_; + double to_cost = allreduce_graph_.max(); + int32_t fusion = 1; + while (to_cost != 0) { + MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size; + auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size); + MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second; + auto paras = node_cost_pair.first; + if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { + MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; + return FAILED; + } + fusion++; + para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) / + allreduce_bandwidth_; + to_cost = node_cost_pair.second; + } + MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed."; + return SUCCESS; +} + +Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { + if (algorithm == 1) { + return SetFusionByBackwardCompTime(); + } + return SetFusionByBackwardCompAndAllreduceTime(); +} + +Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { + if (ret == nullptr) { + MS_LOG(ERROR) << "ret is nullptr."; + return FAILED; + } + auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm(); + if (algorithm < 1 || algorithm > 2) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". Bypass ProcessAllreduceFusion"; + return SUCCESS; + } + ret_ = ret; + root_graph_ = ret_->func_graph(); + MS_EXCEPTION_IF_NULL(root_graph_); + auto graph_set = ForwardGraph(root_graph_); + if (graph_set.size() > 1) { + MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now."; + return SUCCESS; + } + auto forward_graph = *(graph_set.begin()); + MS_EXCEPTION_IF_NULL(forward_graph); + forward_ret_ = forward_graph->get_return(); + MS_EXCEPTION_IF_NULL(forward_ret_); + + if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) { + MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed."; + if (AddNodeToGraph() != SUCCESS) { + MS_LOG(ERROR) << "AddNodeToGraph failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed."; + if (AddEdgeToGraph() != SUCCESS) { + MS_LOG(ERROR) << "AddNodeToGraph failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed."; + if (SetFusionByAlgorithm(algorithm) != SUCCESS) { + MS_LOG(ERROR) << "SetFusionByAlgorithm failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h new file mode 100644 index 0000000000..7383c477a6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ +#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ + +#include +#include +#include "ir/anf.h" +#include "frontend/parallel/allreduce_fusion/allreduce_graph.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +using CNodeCostMap = std::unordered_map; + +constexpr int32_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM = 0; +constexpr int32_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES = 0; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT = 0.1; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME = 0.1; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0.1; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1; + +constexpr char FUSION[] = "fusion"; +constexpr char PARAMETER[] = "parameter"; +const uint32_t MAX_RECURSIVE_CALL_TIMES = 100; +class AllreduceFusion { + public: + AllreduceFusion() + : allreduce_graph_(), + ret_(nullptr), + forward_ret_(nullptr), + root_graph_(nullptr), + tail_time_(0), + allreduce_inherent_time_(0), + allreduce_bandwidth_(0), + computation_time_parameter_(0) {} + virtual ~AllreduceFusion() = default; + Status ProcessAllreduceFusion(const CNodePtr &ret); + + private: + Status AddNodeToGraph(); + CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; + CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; + Status AddEdgeToGraph(); + std::vector GenerateCostMap(int32_t fusion_times, double tail_percent) const; + Status SetFusion(const std::vector &cost_map); + Status SetFusionByAlgorithm(int32_t algorithm); + Status SetFusionByBackwardCompTime(); + Status SetFusionByBackwardCompAndAllreduceTime(); + Status GetSetFusionByBackwardCompAndAllreduceTimeParams(); + + AllreduceGraph allreduce_graph_; + CNodePtr ret_; + CNodePtr forward_ret_; + FuncGraphPtr root_graph_; + double tail_time_; + double allreduce_inherent_time_; + double allreduce_bandwidth_; + double computation_time_parameter_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.cc new file mode 100644 index 0000000000..ca47b0fa97 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.cc @@ -0,0 +1,209 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/allreduce_fusion/allreduce_graph.h" +#include +#include +#include "ir/anf.h" +#include "frontend/parallel/allreduce_fusion/allreduce_node.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) { + AllreduceNodePtr arnode; + auto cnode_emplace_return = cnode_set_.emplace(node); + if (!cnode_emplace_return.second) { + MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!"; + auto cnode_arnode_pair = cnode_arnode_map_.find(node); + if (cnode_arnode_pair == cnode_arnode_map_.end()) { + MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!"; + } + arnode = cnode_arnode_pair->second; + } else { + arnode = std::make_shared(AllreduceNode()); + } + + if (arnode->Init(node) != SUCCESS) { + MS_LOG(ERROR) << "AllreduceNode Init failed"; + return FAILED; + } + if (arnode->AddPara(para) != SUCCESS) { + MS_LOG(ERROR) << "AllreduceNode AddPara failed"; + return FAILED; + } + cnode_arnode_map_[node] = arnode; + + auto arnode_emplace_return = arnode_set_.insert(arnode); + if (!arnode_emplace_return.second) { + MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!"; + } + cnode_emplace_return = para_cnodeset_map_[para].emplace(node); + if (!cnode_emplace_return.second) { + MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope() + << "'s cnodeset!"; + } + auto para_emplace_return = cnode_paraset_map_[node].emplace(para); + if (!para_emplace_return.second) { + MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString() + << "'s paraset!"; + } + return SUCCESS; +} + +Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { + auto from_arnode_iter = cnode_arnode_map_.find(from); + if (from_arnode_iter == cnode_arnode_map_.end()) { + MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; + PrintCNodeSet(); + return FAILED; + } + auto to_arnode_iter = cnode_arnode_map_.find(to); + if (to_arnode_iter == cnode_arnode_map_.end()) { + MS_LOG(ERROR) << "cnode to: " << to->DebugString() << "has not been added"; + PrintCNodeSet(); + return FAILED; + } + auto from_arnode = from_arnode_iter->second; + auto to_arnode = to_arnode_iter->second; + if (from_arnode->AddNext(to_arnode) != SUCCESS) { + MS_LOG(ERROR) << "from_arnode AddNext failed"; + return FAILED; + } + if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) { + MS_LOG(ERROR) << "to_arnode AddPrev failed"; + return FAILED; + } + max_ = std::max(max_, to_arnode->depend_feat_size()); + MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString(); + MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size() + << ", to depend_feat_size: " << to_arnode->depend_feat_size(); + return SUCCESS; +} + +bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { + auto cnode_iter = cnode_set_.find(node); + return !(cnode_iter == cnode_set_.end()); +} + +std::vector AllreduceGraph::GetParaByCost(double from, double to) { + std::vector nodes; + for (auto &cnode_arnode : cnode_arnode_map_) { + MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() + << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() + << " curr_para_size: " << cnode_arnode.second->curr_para_size(); + if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) { + (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(), + cnode_paraset_map_[cnode_arnode.first].end()); + } + } + return nodes; +} + +std::pair, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) { + std::vector nodes; + double cur_para_size = 0; + double from = to; + for (auto &arnode : arnode_vec_) { + if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { + continue; + } + if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) { + return std::make_pair(nodes, from); + } + (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end()); + cur_para_size += arnode.curr_para_size(); + from = arnode.depend_feat_size(); + } + MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size + << " cur_para_size: " << cur_para_size << " from: " << from; + return std::make_pair(nodes, from); +} + +void AllreduceGraph::PrintCNodeSet() const { + MS_LOG(INFO) << "CNodeSet:"; + for (auto &cnode : cnode_set_) { + MS_LOG(INFO) << cnode->DebugString(); + } +} + +void AllreduceGraph::PrintAllredueGraphInfo() const { + MS_LOG(INFO) << "max: " << max_; + for (auto &cnode_arnode : cnode_arnode_map_) { + MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); + MS_LOG(INFO) << "arnode info: "; + cnode_arnode.second->ToString(); + } +} + +void AllreduceGraph::PrintArnodeVec() const { + MS_LOG(INFO) << "ArnodeVec:"; + for (auto &arnode : arnode_vec_) { + arnode.ToString(); + } +} + +void AllreduceGraph::PrintArnodeSet() const { + MS_LOG(INFO) << "ArnodeSet:"; + for (auto &arnode : arnode_set_) { + arnode->ToString(); + } +} + +void AllreduceGraph::SortArnode() { + arnode_vec_.clear(); + for (auto &node : arnode_set_) { + arnode_vec_.emplace_back(*node); + } + std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); +} + +Status AllreduceGraph::RemoveExtraParas() { + std::unordered_set para_map; + for (auto &node : arnode_vec_) { + for (auto ¶ : node.paras()) { + auto emplac_result = para_map.emplace(para); + if (!emplac_result.second) { + MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; + if (node.RemovePara(para) != SUCCESS) { + MS_LOG(ERROR) << "remove para failed"; + return FAILED; + } + } + } + } + return SUCCESS; +} + +Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { + auto arnode = std::make_shared(AllreduceNode()); + if (arnode->Init(node) != SUCCESS) { + MS_LOG(ERROR) << "AllreduceNode Init failed"; + } + head_cnode_ = node; + cnode_arnode_map_[node] = arnode; + auto arnode_emplace_return = arnode_set_.insert(arnode); + if (!arnode_emplace_return.second) { + MS_LOG(WARNING) << "node: " << node->DebugString() << "'s arnode has already been added!"; + } + auto cnode_emplace_return = cnode_set_.emplace(node); + if (!cnode_emplace_return.second) { + MS_LOG(WARNING) << "node: " << node->DebugString() << " has already been added!"; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h new file mode 100644 index 0000000000..a47039f070 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ +#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "frontend/parallel/allreduce_fusion/allreduce_node.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +class AllreduceGraph { + public: + AllreduceGraph() + : head_cnode_(nullptr), + arnode_set_(), + arnode_vec_(), + cnode_set_(), + para_cnode_map_(), + para_cnodeset_map_(), + cnode_paraset_map_(), + cnode_arnode_map_(), + max_(0) {} + virtual ~AllreduceGraph() = default; + Status AddNode(const CNodePtr &node, const AnfNodePtr ¶); + Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); + bool NodeInGraph(const CNodePtr &node) const; + std::vector GetParaByCost(double from, double to); + // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is + // over para_size. + // Return the parameter AnfNodePtr vector corresponding to these AllreduceNodes and the smallest depend_feat_size. + // If the sum of left AllreduceNode's parameter size is less than para_size, the returned depend_feat_size must be 0. + std::pair, double> GetParaByParaSize(double to, double para_size); + // If one parameter is used by multiple AllreduceNode, parameter belong to the last node for backward computation + // is saved by the corresponding AllreduceNode, parameters belong to other AllreduceNode are removed. + // Called during precise optimization, not implemented temporarily. + void SortArnode(); + Status RemoveExtraParas(); + void PrintCNodeSet() const; + void PrintAllredueGraphInfo() const; + void PrintArnodeVec() const; + void PrintArnodeSet() const; + const std::unordered_set &cnode_set() const { return cnode_set_; } + CNodePtr head_cnode() const { return head_cnode_; } + Status set_head_cnode(const CNodePtr &node); + double max() const { return max_; } + + private: + CNodePtr head_cnode_; + std::set arnode_set_; + std::vector arnode_vec_; + std::unordered_set cnode_set_; + // If One ParameterPtr is used by multiple CNode, the last node for backward computation is saved. + std::unordered_map> para_cnode_map_; + // One ParameterPtr may be used by multiple CNode + std::unordered_map> para_cnodeset_map_; + // Multiple Parameter may be inputs to the same CNode + std::unordered_map> cnode_paraset_map_; + std::unordered_map cnode_arnode_map_; + double max_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc new file mode 100644 index 0000000000..1c478887df --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/allreduce_fusion/allreduce_node.h" +#include +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status AllreduceNode::AddNext(const AllreduceNodePtr &next_node) { + if (next_node == nullptr) { + MS_LOG(ERROR) << "next_node is nullptr!"; + return FAILED; + } + next_.emplace_back(next_node); + return SUCCESS; +} + +Status AllreduceNode::AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max) { + if (prev_node == nullptr) { + MS_LOG(ERROR) << "next_node is nullptr!"; + return FAILED; + } + if (dist <= 0) { + MS_LOG(ERROR) << "dist must be positive! dist: " << dist; + return FAILED; + } + prev_.emplace_back(prev_node); + double add_dist = prev_node->depend_feat_size() + dist; + depend_feat_size_ += add_dist; + if (depend_feat_size_ > *max) { + *max = depend_feat_size_; + } + std::queue next_queue; + for (auto &next : next_) { + next_queue.push(next); + } + while (!next_queue.empty()) { + auto ele = next_queue.front(); + ele->AddDependFeatSize(add_dist); + if (ele->depend_feat_size() > *max) { + *max = ele->depend_feat_size(); + } + for (auto &next : ele->next()) { + next_queue.push(next); + } + next_queue.pop(); + } + return SUCCESS; +} + +Status AllreduceNode::Init(const CNodePtr &cnode_ptr) { + if (cnode_ptr == nullptr) { + MS_LOG(ERROR) << "cnode_ptr is nullptr!"; + return FAILED; + } + cnode_ptr_ = cnode_ptr; + return SUCCESS; +} + +Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { + if (node_ptr == nullptr) { + MS_LOG(ERROR) << "node_ptr is nullptr!"; + return FAILED; + } + if (!node_ptr->isa()) { + MS_LOG(ERROR) << "node_ptr is not a ParameterPtr!"; + return FAILED; + } + auto para_ptr = node_ptr->cast(); + MS_EXCEPTION_IF_NULL(para_ptr); + auto layout_ptr = para_ptr->tensor_layout(); + if (layout_ptr == nullptr) { + MS_LOG(ERROR) << "layout_ptr is nullptr!"; + return FAILED; + } + auto emplace_return = paras_.emplace(node_ptr); + if (emplace_return.second) { + double para_size = static_cast(layout_ptr->slice_shape().size()); + curr_para_size_ += para_size; + para_size_map_[node_ptr] = para_size; + } else { + MS_LOG(INFO) << "node already exist!"; + } + return SUCCESS; +} + +Status AllreduceNode::RemovePara(const AnfNodePtr &node_ptr) { + if (node_ptr == nullptr) { + MS_LOG(ERROR) << "node_ptr is nullptr!"; + return FAILED; + } + auto erase_num = paras_.erase(node_ptr); + if (erase_num == 0) { + MS_LOG(ERROR) << "para not find!"; + return FAILED; + } + curr_para_size_ -= para_size_map_[node_ptr]; + return SUCCESS; +} + +void AllreduceNode::ToString() const { + MS_LOG(INFO) << "cnode: " << cnode_ptr_->DebugString() << "para size: " << paras_.size(); + for (auto ¶ : paras_) { + MS_LOG(INFO) << "para name: " << para->fullname_with_scope() << " size: " << para_size_map_.at(para); + } + MS_LOG(INFO) << "depend_feat_size: " << depend_feat_size_ << " curr_para_size: " << curr_para_size_; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h new file mode 100644 index 0000000000..6538381f27 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ +#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +class AllreduceNode; +using AllreduceNodePtr = std::shared_ptr; + +class AllreduceNode { + public: + AllreduceNode() + : cnode_ptr_(nullptr), prev_(), next_(), paras_(), para_size_map_(), curr_para_size_(0), depend_feat_size_(0) {} + Status Init(const CNodePtr &cnode_ptr); + Status AddPara(const AnfNodePtr &node_ptr); + Status RemovePara(const AnfNodePtr &node_ptr); + const std::unordered_set ¶s() const { return paras_; } + double curr_para_size() const { return curr_para_size_; } + virtual ~AllreduceNode() = default; + // Add previous node + // prev_node is the previous to be added + // max is the current max depend_feat_size of the AllreduceGraph + Status AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max); + Status AddNext(const AllreduceNodePtr &next_node); + double depend_feat_size() const { return depend_feat_size_; } + void AddDependFeatSize(double add_dist) { depend_feat_size_ += add_dist; } + const std::vector &next() const { return next_; } + void ToString() const; + bool operator<(const AllreduceNode &node) const { return depend_feat_size_ < node.depend_feat_size(); } + bool operator>(const AllreduceNode &node) const { return depend_feat_size_ > node.depend_feat_size(); } + + private: + CNodePtr cnode_ptr_; + std::vector prev_; + std::vector next_; + std::unordered_set paras_; + std::unordered_map para_size_map_; + double curr_para_size_; + double depend_feat_size_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc new file mode 100644 index 0000000000..b669fa7782 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" +#include +#include +#include "frontend/optimizer/optimizer.h" +#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/graph_util/graph_info.h" +#include "frontend/parallel/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(optimizer); + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); + bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion(); + // assume no change to graph + bool changes = false; + // control whether use model_parallel mode + if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || + (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) { + return changes; + } +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); +#endif + MS_LOG(INFO) << "Now entering allreduce fusion"; + DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN)); + + pipeline::ResourceBasePtr res = optimizer->resource(); + MS_EXCEPTION_IF_NULL(res); + + FuncGraphManagerPtr manager = res->manager(); + MS_EXCEPTION_IF_NULL(manager); + CNodePtr ret = root->get_return(); + MS_EXCEPTION_IF_NULL(ret); + + AllreduceFusion allreduce_fusion; + if (allreduce_fusion.ProcessAllreduceFusion(ret) != SUCCESS) { + MS_LOG(EXCEPTION) << "ProcessAllreduceFusion failed"; + } + + DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); + + // allreduce fusion only run once + root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true); + res->results()[pipeline::kStepParallelGraph] = root; +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); + time += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << time << " us"; +#endif + return changes; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h new file mode 100644 index 0000000000..2612e71984 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ +#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ + +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace parallel { +constexpr char ALLREDUCE_FUSION_RUN_ONCE_ONLY[] = "allreduce_fusion_run_once_only"; +constexpr char ALLREDUCE_FUSION_BEGIN[] = "allreduce_fusion_begin"; +constexpr char ALLREDUCE_FUSION_END[] = "allreduce_fusion_end"; + +bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc new file mode 100644 index 0000000000..531a5cd7f6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/costmodel.h" +#include +#include +#include +#include "frontend/parallel/auto_parallel/graph_costmodel.h" + +namespace mindspore { +namespace parallel { +void Simplify(CostPtrList *clist_ptrs) { + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs); + } else { + // inference phase + SimplifyForDecreasingCommunicationForward(clist_ptrs); + } +} +void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { + // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method + // excludes the cost with greater computation_cost_ and greater communication_forward. + // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} + if (!COST_MODEL_SIMPLIFY_CALCULATION) { + return; + } + MS_EXCEPTION_IF_NULL(clist_ptrs); + std::vector id(clist_ptrs->size()); + std::iota(id.begin(), id.end(), size_t(0)); + std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { + return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; + }); + CostPtrList ret; + for (size_t i = 0; i < clist_ptrs->size(); ++i) { + if ((ret.size() == size_t(0)) || + (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) { + ret.emplace_back(std::move(clist_ptrs->at(id[i]))); + } + } + *clist_ptrs = std::move(ret); +} + +void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { + // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing + // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. + if (!COST_MODEL_SIMPLIFY_CALCULATION) { + return; + } + MS_EXCEPTION_IF_NULL(clist_ptrs); + std::vector id(clist_ptrs->size()); + std::iota(id.begin(), id.end(), size_t(0)); + std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { + return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; + }); + CostPtrList ret; + for (size_t i = 0; i < clist_ptrs->size(); ++i) { + if ((ret.size() == size_t(0)) || + (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) { + ret.emplace_back(std::move(clist_ptrs->at(id[i]))); + } + } + *clist_ptrs = std::move(ret); +} + +void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) { + MS_EXCEPTION_IF_NULL(origin_cost); + if (is_redistribution) { + // Redistribution cost + if ((origin_cost->communication_redis_forward_ > EPS) && + (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { + origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST; + } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) { + origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS; + } + if ((origin_cost->communication_redis_backward_ > EPS) && + (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { + origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST; + } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) { + origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS; + } + origin_cost->communication_cost_ = + origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_; + origin_cost->communication_without_parameter_ = origin_cost->communication_cost_; + origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_; + } else { + // Operator cost + double backward = 0.0; + if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) { + backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_; + } + // forward cost + if ((origin_cost->communication_without_parameter_ > EPS) && + (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) { + origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST; + } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) { + origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS; + } + // total + if (origin_cost->communication_cost_ > EPS) { + origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward; + } + if (origin_cost->communication_with_partial_para_ > EPS) { + origin_cost->communication_with_partial_para_ = + origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward; + } + } +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h new file mode 100644 index 0000000000..cc4508681b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h @@ -0,0 +1,311 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ +#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ + +#include +#include +#include +#include +#include +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_info.h" + +namespace mindspore { +namespace parallel { +struct Decision; +using OperatorName = std::string; +using Attr = std::pair; +using Param = std::pair, int32_t>; +using OperatorParams = std::vector; +using OperatorAttrs = std::vector; +// OutPutInfo.fist: true if the operator's output is a tuple +// OutPutInfo.second: elements number of the tuple output. Only meaningful if OutPutInfo.fist is true. +using OutPutInfo = std::pair; +using OutPutInfoVector = std::vector; +using OperatorArgs = std::pair; +using Operator = std::pair; +using OperatorVector = std::vector; +using RedistributionOpListPtr = std::shared_ptr>; + +struct Cost { + Cost(); + Cost(double computation, double commuication, const std::shared_ptr &decision_ = nullptr) + : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { + memory_with_reuse_ = 0.0; + communication_without_parameter_ = 0.0; + communication_with_partial_para_ = 0.0; + communication_redis_forward_ = 0.0; + communication_redis_backward_ = 0.0; + communication_forward_ = 0.0; + } + // 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase + double memory_with_reuse_; + // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated + // by ONLY forward phase + double computation_cost_; + // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution) + double communication_cost_; + // communication_without_parameter_ = communication_cost_ - (backward communication from operators) + double communication_without_parameter_; + // communication_with_partial_para_ = + // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ ) + double communication_with_partial_para_; + // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution. + double communication_forward_; + double communication_redis_forward_; + double communication_redis_backward_; + std::shared_ptr decision_ptr_; +}; + +using CostPtr = std::shared_ptr; +using CostPtrList = std::vector>; + +class StrategyWithCost { + public: + StrategyWithCost(StrategyPtr strategy, std::vector inputs_, std::vector outputs_) + : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} + + StrategyWithCost(const StrategyWithCost &swc) = delete; + StrategyWithCost(StrategyWithCost &&swc) + : strategy_ptr(swc.strategy_ptr), + inputs_ptr(swc.inputs_ptr), + outputs_ptr(swc.outputs_ptr), + cost_list(swc.cost_list) {} + ~StrategyWithCost() = default; + + StrategyPtr strategy_ptr; + std::vector inputs_ptr; + std::vector outputs_ptr; + CostPtrList cost_list; +}; + +enum DecisionType { + OP_ELIMINATION, + EDGE_ELIMINATION, + MERGE_ELIMINATION, + CONTRACT_ELIMINATION, + TRIANGLE_ELIMINATION, + STAR_ELIMINATION, + FINAL_TYPE, + FINAL_SINGLE +}; + +struct Decision : public Base { + ~Decision() override = default; + DecisionType type_; +}; + +// 'OpEliminationDecision' is for the Operator Elimination in DP algorithm: u --> v --> w ==> u --> w. +// This data structure records the strategy 'op_strategy_' for v, the edge cost 'left_cost_' for 'u --> v', the +// operator cost 'middle_cost_' for v, and the edge cost 'right_cost_' for 'v --> w' +struct OpEliminationDecision : public Decision { + OpEliminationDecision(StrategyPtr op_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost) + : op_strategy_(std::move(op_stra)), + left_cost_(std::move(l_cost)), + middle_cost_(std::move(m_cost)), + right_cost_(std::move(r_cost)) { + type_ = DecisionType::OP_ELIMINATION; + } + + StrategyPtr op_strategy_; + CostPtr left_cost_; + CostPtr middle_cost_; + CostPtr right_cost_; + MS_DECLARE_PARENT(OpEliminationDecision, Decision); +}; + +/* 'EdgeEliminationDecision' is for the Edge Elimination in DP algorithm: + ____ + / \ + u v ==> u --> v, which replace the multi-edges by a single edge. + \____/ + This data structure records the cost list for all edges 'edges_cost_list_' + */ +struct EdgeEliminationDecision : public Decision { + explicit EdgeEliminationDecision(CostPtrList cost_list) : edges_cost_list_(std::move(cost_list)) { + type_ = DecisionType::EDGE_ELIMINATION; + } + + CostPtrList edges_cost_list_; + MS_DECLARE_PARENT(EdgeEliminationDecision, Decision); +}; + +// 'MergeEliminationDecision' is for the Merge Elimination in DP algorithm: +// w +// | +// | ==> u --> v +// u --> v In the original graph, v has two alive incoming edges, w has one alive outgoing edge, +// and w has zero alive incoming edges. After the Merge Elimination, the result graph contains only 'u -- >v'. +// This data structure records the strategy 'merged_op_strategy_' for operator 'w', +// the cost 'merged_op_cost_' for operator 'w', and the edge cost 'edge_cost_' for 'w --> v'. +struct MergeEliminationDecision : public Decision { + MergeEliminationDecision(StrategyPtr op_stra, CostPtr op_cost, CostPtr edge_c, StrategyPtr tar_op_stra, + CostPtr target_op_c) + : merged_op_strategy_(std::move(op_stra)), + merged_op_cost_(std::move(op_cost)), + edge_cost_(std::move(edge_c)), + target_op_strategy_(std::move(tar_op_stra)), + target_op_cost_(std::move(target_op_c)) { + type_ = DecisionType::MERGE_ELIMINATION; + } + + StrategyPtr merged_op_strategy_; + CostPtr merged_op_cost_; + CostPtr edge_cost_; + StrategyPtr target_op_strategy_; + CostPtr target_op_cost_; + MS_DECLARE_PARENT(MergeEliminationDecision, Decision); +}; + +// 'ContractEliminationDecision' is for the Contract Elimination in DP algorithm: +// u --> v +// | +// | ==> u --> w +// w In the original graph, u has two alive outgoing edges, v has one alive incoming edge, +// and v has zero outgoing edge. After the Contract Elimination, the result graph contains only 'u --> w'. +// This data structure records the strategy 'contracted_op_strategy_' for operator 'v', the cost for +// operator 'contracted_op_cost_', and the edge cost for 'edge_cost_'. +struct ContractEliminationDecision : public Decision { + ContractEliminationDecision(StrategyPtr contra_stra, CostPtr contra_op_cost, CostPtr edge_cost, + StrategyPtr target_stra, CostPtr tar_cost) + : contracted_op_strategy_(std::move(contra_stra)), + contracted_op_cost_(std::move(contra_op_cost)), + edge_cost_(std::move(edge_cost)), + target_op_strategy_(std::move(target_stra)), + target_cost_(std::move(tar_cost)) { + type_ = DecisionType::CONTRACT_ELIMINATION; + } + + StrategyPtr contracted_op_strategy_; + CostPtr contracted_op_cost_; + CostPtr edge_cost_; + StrategyPtr target_op_strategy_; + CostPtr target_cost_; + MS_DECLARE_PARENT(ContractEliminationDecision, Decision); +}; + +/* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm: + * + * u + * / \ + * / \ + * v --- w ==> v --- w In the original graph, u has 2 outgoing edges, v has 1 outgoing edge, + * and w has 2 incoming edges, u can be eliminated into v. + * 'eliminated_op_strategy_' is for u, 'eliminated_op_cost_' is for u, 'eliminated_left_edge_' is for edge u --> v, + * 'eliminated_right_edge_' is for edge u --> w. + */ +struct TriangleEliminationDecision : public Decision { + TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost, + StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra) + : eliminated_op_strategy_(std::move(elimi_stra)), + eliminated_op_cost_(std::move(elimi_op_cost)), + left_edge_cost_(std::move(l_edge_cost)), + right_edge_cost_(std::move(r_edge_cost)), + left_node_strategy_(std::move(left_stra)), + left_node_cost_(std::move(l_node_cost)), + right_node_strategy_(std::move(right_stra)) { + type_ = DecisionType::TRIANGLE_ELIMINATION; + } + + StrategyPtr eliminated_op_strategy_; + CostPtr eliminated_op_cost_; + CostPtr left_edge_cost_; + CostPtr right_edge_cost_; + StrategyPtr left_node_strategy_; + CostPtr left_node_cost_; + StrategyPtr right_node_strategy_; + MS_DECLARE_PARENT(TriangleEliminationDecision, Decision); +}; + +/* 'StarEliminationDecision' is for the Star Elimination in DP algorithm: + * + * v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges. + * In addition, v and w have other complicated connections, resulting in v and w can not be performed other + * eliminations. After the StarElimination, u is merged into v, and the resulting graph is splitted into multiple + * connected components. + * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. + */ +struct StarEliminationDecision : public Decision { + StarEliminationDecision(StrategyPtr elimi_op_stra, CostPtr elimi_op_cost, CostPtrList succ_edges_clist, + std::vector succ_ops_stra_list, CostPtrList succ_ops_clist) + : eliminated_op_strategy_(std::move(elimi_op_stra)), + eliminated_op_cost_(std::move(elimi_op_cost)), + succ_edges_cost_list_(std::move(succ_edges_clist)), + succ_ops_stra_list_(std::move(succ_ops_stra_list)), + succ_ops_cost_list_(std::move(succ_ops_clist)) { + type_ = DecisionType::STAR_ELIMINATION; + } + + StrategyPtr eliminated_op_strategy_; + CostPtr eliminated_op_cost_; + CostPtrList succ_edges_cost_list_; + std::vector succ_ops_stra_list_; + CostPtrList succ_ops_cost_list_; + MS_DECLARE_PARENT(StarEliminationDecision, Decision); +}; + +// This data structure records the decision for the graph which contains two nodes: u --> v. This includes +// the strategy 'u_strategy_' for 'u', the strategy 'v_strategy_' for 'v', the cost 'left_cost_' for 'u'. +struct FinalDecision : public Decision { + FinalDecision(StrategyPtr u_stra, StrategyPtr v_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost) + : u_strategy_(std::move(u_stra)), + v_strategy_(std::move(v_stra)), + left_cost_(std::move(l_cost)), + middle_cost_(std::move(m_cost)), + right_cost_(std::move(r_cost)) { + type_ = DecisionType::FINAL_TYPE; + } + + StrategyPtr u_strategy_; + StrategyPtr v_strategy_; + CostPtr left_cost_; + CostPtr middle_cost_; + CostPtr right_cost_; + MS_DECLARE_PARENT(FinalDecision, Decision); +}; + +// This data structure records the final decision for the graph containing a single node: u. This includes +// the strategy 'u_strategy_' for 'u', the cost 'u_cost_' for 'u'. +struct FinalSingleDecision : public Decision { + FinalSingleDecision(StrategyPtr u_stra, CostPtr u_cost) : u_strategy_(std::move(u_stra)), u_cost_(std::move(u_cost)) { + type_ = DecisionType::FINAL_SINGLE; + } + + StrategyPtr u_strategy_; + CostPtr u_cost_; + MS_DECLARE_PARENT(FinalSingleDecision, Decision); +}; + +using DecisionPtr = std::shared_ptr; +using OpEliminationDecisionPtr = std::shared_ptr; +using EdgeEliminationDecisionPtr = std::shared_ptr; +using MergeEliminationDecisionPtr = std::shared_ptr; +using ContractEliminationDecisionPtr = std::shared_ptr; +using TriangleEliminationDecisionPtr = std::shared_ptr; +using StarEliminationDecisionPtr = std::shared_ptr; +using FinalDecisionPtr = std::shared_ptr; +using FinalSingleDecisionPtr = std::shared_ptr; + +void Simplify(CostPtrList *clist); +void SimplifyForDecreasingCommunicationForward(CostPtrList *clist); +void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist); +void RefineForPracticalCost(const CostPtr &, bool is_redistribution); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc new file mode 100644 index 0000000000..9408596111 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc @@ -0,0 +1,226 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h" + +#include +#include +#include + +namespace mindspore { +namespace parallel { +Status GetStrategy(const CostGraphPtr &graph) { + MS_LOG(INFO) << "Searching strategies begins."; + MS_EXCEPTION_IF_NULL(graph); + std::vector eliminations; + bool flag = true; + + // Phase 1: Shrink the CostGraph using 6 operations, and record them in the order. + // Note: the checking and applying of the 6 operations MUST in current order. + while (flag) { + flag = false; + auto node = graph->CheckOpElimination(); + if (node != nullptr) { + // Applying the Operator Elimination + flag = true; + auto l_edge = node->GetAlivePrevEdges()[0]; + auto r_edge = node->GetAliveSuccEdges()[0]; + auto n_edge = graph->EliminationOp(node); + auto elimi = std::make_shared(n_edge, l_edge, node, r_edge); + eliminations.emplace_back(std::move(elimi)); + } + auto edges = graph->CheckEdgeElimination(); + if ((!flag) && (!edges.empty())) { + // Applying the Edge Elimination + flag = true; + auto n_edge = graph->EliminationEdges(edges); + auto elimi = std::make_shared(n_edge, edges); + eliminations.emplace_back(std::move(elimi)); + } + auto merge_node = graph->CheckMergeElimination(); + if ((!flag) && (merge_node != nullptr)) { + // Applying the Merge Elimination + flag = true; + auto succ_edge = merge_node->GetAliveSuccEdges()[0]; + auto target_node = graph->EliminationMerge(merge_node); + auto elimi = std::make_shared(merge_node, succ_edge, target_node); + eliminations.emplace_back(std::move(elimi)); + } + auto contracted_node = graph->CheckContractElimination(); + if ((!flag) && (contracted_node != nullptr)) { + // Applying the Contract Elimination + flag = true; + auto prev_edge = contracted_node->GetAlivePrevEdges()[0]; + auto target_node = graph->EliminationContract(contracted_node); + auto elimi = std::make_shared(target_node, prev_edge, contracted_node); + eliminations.emplace_back(std::move(elimi)); + } + auto triangle_pair = graph->CheckTriangleElimination(); + if ((!flag) && (triangle_pair.first != nullptr)) { + // Applying the Triangle Elimination + flag = true; + auto eliminated_node = triangle_pair.first; + auto l_r_edge = triangle_pair.second; + + auto left_node = l_r_edge->prev_operator(); + auto left_edge = eliminated_node->GetAliveSuccEdges()[0]; + auto right_edge = eliminated_node->GetAliveSuccEdges()[1]; + MS_EXCEPTION_IF_NULL(left_edge); + if (left_edge->next_operator() != left_node) { + auto tmp = left_edge; + left_edge = right_edge; + right_edge = tmp; + } + auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); + auto right_node = l_r_edge->next_operator(); + auto elimi = + std::make_shared(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); + eliminations.emplace_back(std::move(elimi)); + } + auto star_center = graph->CheckStarElimination(); + if ((!flag) && (star_center != nullptr)) { + // Applying the Star Elimination + flag = true; + auto succ_edges = graph->EliminationStar(star_center); + std::vector succ_nodes; + for (size_t i = 0; i < succ_edges.size(); ++i) { + MS_EXCEPTION_IF_NULL(succ_edges[i]); + succ_nodes.push_back(succ_edges[i]->next_operator()); + } + auto elimi = std::make_shared(star_center, succ_edges, succ_nodes); + eliminations.emplace_back(std::move(elimi)); + } + } + + // Phase 2: Search the cost_list in the final graph, and determine the optimal one + if (graph->SearchStrategy() != SUCCESS) { + MS_LOG(ERROR) << "Searching strategy for the final failed."; + return FAILED; + } + + // Phase 3: Recover the original CostGraph, the determine strategy for each operator + if (RecoverStrategy(eliminations) == SUCCESS) { + MS_LOG(INFO) << "Searching strategies ends."; + return SUCCESS; + } else { + MS_LOG(EXCEPTION) << "Searching strategies failed."; + } +} + +Status RecoverStrategy(std::vector eliminations) { + std::vector::reverse_iterator rit; + + for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) { + if ((*rit)->isa()) { + auto elimination = (*rit)->cast(); + auto e = elimination->new_edge_; + auto w = elimination->op_; + MS_EXCEPTION_IF_NULL(e); + MS_EXCEPTION_IF_NULL(w); + auto left_edge = elimination->left_edge_; + auto right_edge = elimination->right_edge_; + MS_EXCEPTION_IF_NULL(left_edge); + MS_EXCEPTION_IF_NULL(right_edge); + auto decision = e->selected_cost()->decision_ptr_->cast(); + w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_); + left_edge->set_selected_cost(decision->left_cost_); + right_edge->set_selected_cost(decision->right_cost_); + MS_LOG(INFO) << "Recover opElimination succeeded."; + } else if ((*rit)->isa()) { + auto elimination = (*rit)->cast(); + auto new_edge = elimination->new_edge_; + MS_EXCEPTION_IF_NULL(new_edge); + auto &edges = elimination->edges_; + auto decision = new_edge->selected_cost()->decision_ptr_->cast(); + for (size_t j = 0; j < edges.size(); ++j) { + MS_EXCEPTION_IF_NULL(edges[j]); + edges[j]->set_selected_cost(decision->edges_cost_list_[j]); + } + MS_LOG(INFO) << "Recover edgeElimination succeeded."; + } else if ((*rit)->isa()) { + auto elimination = (*rit)->cast(); + auto target_node = elimination->target_node_; + MS_EXCEPTION_IF_NULL(target_node); + auto merged_node = elimination->merged_node_; + MS_EXCEPTION_IF_NULL(merged_node); + auto merged_edge = elimination->dir_edge_; + MS_EXCEPTION_IF_NULL(merged_edge); + MS_EXCEPTION_IF_NULL(target_node->selected_cost()); + MS_EXCEPTION_IF_NULL(target_node->selected_cost()->decision_ptr_); + auto decision = target_node->selected_cost()->decision_ptr_->cast(); + merged_node->SetSelectedStrategyAndCost(decision->merged_op_strategy_, decision->merged_op_cost_); + merged_edge->set_selected_cost(decision->edge_cost_); + target_node->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_op_cost_); + + MS_LOG(INFO) << "Recover mergeElimination succeeded."; + } else if ((*rit)->isa()) { + auto elimination = (*rit)->cast(); + auto target_node = elimination->target_node_; + auto contracted_node = elimination->contracted_node_; + auto contracted_edge = elimination->dir_edge_; + auto decision = target_node->selected_cost()->decision_ptr_->cast(); + + contracted_node->SetSelectedStrategyAndCost(decision->contracted_op_strategy_, decision->contracted_op_cost_); + contracted_edge->set_selected_cost(decision->edge_cost_); + target_node->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_cost_); + MS_LOG(INFO) << "Recover contractElimination succeeded."; + } else if ((*rit)->isa()) { + auto elimination = (*rit)->cast(); + auto left_node = elimination->left_node_; + auto left_edge = elimination->left_edge_; + auto eliminated_node = elimination->eliminated_node_; + auto right_edge = elimination->right_edge_; + auto right_node = elimination->right_node_; + auto decision = left_node->selected_cost()->decision_ptr_->cast(); + + eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); + left_edge->set_selected_cost(decision->left_edge_cost_); + right_edge->set_selected_cost(decision->right_edge_cost_); + // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy. + left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); + right_node->CheckSelectedStrategy(decision->right_node_strategy_); + MS_LOG(INFO) << "Recover triangleElimination succeeded."; + } else if ((*rit)->isa()) { + auto elimination = (*rit)->cast(); + auto merged_node = elimination->eliminated_node_; + auto succ_edges = elimination->succ_edges_; + auto succ_nodes = elimination->succ_ops_; + // decision is hided in succ_nodes[0] + auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast(); + + merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); + for (size_t i = 0; i < succ_edges.size(); ++i) { + succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]); + } + MS_EXCEPTION_IF_NULL(succ_nodes[0]); + MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]); + MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); + // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy. + succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); + for (size_t k = 1; k < succ_nodes.size(); ++k) { + succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); + } + MS_LOG(INFO) << "Recover starElimination succeeded."; + } else { + MS_LOG(ERROR) << "Unknown Elimination type."; + return FAILED; + } + } + + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h new file mode 100644 index 0000000000..812f375f0b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h @@ -0,0 +1,152 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ +#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ + +#include +#include +#include +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" + +namespace mindspore { +namespace parallel { +// There are 3 meta phases of the Dynamic Programming (DP) algorithm. The input is a CostGraph, and the goal +// is to compute the strategy for each operator in the CostGraph. +// +// Phase 1: Shrink the CostGraph using 6 operations, and record them in the order +// Using for operations: Operator Elimination, Edge Elimination, Merge Elimination, and Contract Elimination, +// each connected component in the CostGraph can be shrunk in to the final graph: u --> v. See the +// interpretation of 6 operations in costmodel.h. +// Phase 2: Search the cost_list in the final graph, and determine the optimal one +// Create the cost_list for the final graph, and choose the optimal one: one the minimum quantity +// COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost +// Phase 3: Recover the original CostGraph, the determine strategy for each operator +// After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying +// the 4 operations in the reverse order in the Phase 1. Because each operation decision contains the strategy, +// the operators' strategies can be all determined. + +struct Elimination : public Base { + enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR }; + Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {} + + EdgePtr new_edge_; + EliminationType type_; +}; + +// Operator Elimination +struct OpElimination : public Elimination { + OpElimination(EdgePtr n_edge, EdgePtr l_edge, OperatorInfoPtr op_info, EdgePtr r_edge) + : Elimination(std::move(n_edge), Elimination::EliminationType::OPERA), + left_edge_(std::move(l_edge)), + op_(std::move(op_info)), + right_edge_(std::move(r_edge)) {} + + EdgePtr left_edge_; + OperatorInfoPtr op_; + EdgePtr right_edge_; + MS_DECLARE_PARENT(OpElimination, Elimination); +}; + +// Edge Elimination +struct EdgeElimination : public Elimination { + EdgeElimination(const EdgePtr &n_edge, std::vector eds) + : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {} + + std::vector edges_; + MS_DECLARE_PARENT(EdgeElimination, Elimination); +}; + +// Merge Elimination +struct MergeElimination : public Elimination { + MergeElimination(OperatorInfoPtr u_info, EdgePtr merged_target_edge, OperatorInfoPtr v_info) + : Elimination(nullptr, Elimination::EliminationType::MERGE), + merged_node_(std::move(u_info)), + dir_edge_(std::move(merged_target_edge)), + target_node_(std::move(v_info)) {} + + OperatorInfoPtr merged_node_; + EdgePtr dir_edge_; + OperatorInfoPtr target_node_; + MS_DECLARE_PARENT(MergeElimination, Elimination); +}; + +// Contract Elimination +struct ContractElimination : public Elimination { + ContractElimination(OperatorInfoPtr tar_info, EdgePtr tar_con_edge, OperatorInfoPtr con_info) + : Elimination(nullptr, Elimination::EliminationType::CONTRACT), + contracted_node_(std::move(con_info)), + dir_edge_(std::move(tar_con_edge)), + target_node_(std::move(tar_info)) {} + + OperatorInfoPtr contracted_node_; + EdgePtr dir_edge_; + OperatorInfoPtr target_node_; + MS_DECLARE_PARENT(ContractElimination, Elimination); +}; + +// Triangle Elimination +struct TriangleElimination : public Elimination { + TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, + OperatorInfoPtr r_node) + : Elimination(nullptr, Elimination::EliminationType::TRIANGLE), + eliminated_node_(std::move(elim_node)), + left_edge_(std::move(l_edge)), + left_node_(std::move(l_node)), + right_edge_(std::move(r_edge)), + right_node_(std::move(r_node)) {} + + OperatorInfoPtr eliminated_node_; + EdgePtr left_edge_; + OperatorInfoPtr left_node_; + EdgePtr right_edge_; + OperatorInfoPtr right_node_; + MS_DECLARE_PARENT(TriangleElimination, Elimination); +}; + +// Star Elimination +struct StarElimination : public Elimination { + StarElimination(OperatorInfoPtr elimi_node, std::vector s_edges, std::vector s_ops) + : Elimination(nullptr, Elimination::EliminationType::STAR), + eliminated_node_(std::move(elimi_node)), + succ_edges_(std::move(s_edges)), + succ_ops_(std::move(s_ops)) {} + + OperatorInfoPtr eliminated_node_; + std::vector succ_edges_; + std::vector succ_ops_; + MS_DECLARE_PARENT(StarElimination, Elimination); +}; + +using EliminationPtr = std::shared_ptr; +using OpEliminationPtr = std::shared_ptr; +using EdgeEliminationPtr = std::shared_ptr; +using MergeEliminationPtr = std::shared_ptr; +using ContractEliminationPtr = std::shared_ptr; +using TriangleEliminationPtr = std::shared_ptr; +using StarEliminationPtr = std::shared_ptr; + +// Phase 1 and Phase 2 +Status GetStrategy(const CostGraphPtr &graph); + +// Phase 3 +Status RecoverStrategy(std::vector eliminations); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc new file mode 100644 index 0000000000..e3f1de7207 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc @@ -0,0 +1,324 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/edge_costmodel.h" + +#include +#include +#include +#include +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Status Edge::InitEdgeCost() { + bool has_available_cost = false; + for (auto &swc : prev_op_->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(swc); + pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr)); + } + for (auto &swc : next_op_->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(swc); + next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr)); + } + if (is_identity_edge) { + for (auto &target_output : pre_op_output_) { + auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); + auto target_output_str = target_output.first; + for (auto &target_input : next_op_input_) { + auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); + auto target_input_str = target_input.first; + if (target_output_lyt == target_input_lyt) { + CostPtrKey ck = {target_output_str, target_input_str}; + CostPtr cost = std::make_shared(0.0, 0.0); + MS_EXCEPTION_IF_NULL(cost); + cost->communication_without_parameter_ = 0.0; + cost->communication_with_partial_para_ = 0.0; + CostPtrList cl; + cl.push_back(cost); + (void)cost_map_.emplace(std::make_pair(ck, cl)); + has_available_cost = true; + } + } + } + } else { + for (auto &target_output : pre_op_output_) { + auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); + auto target_output_str = target_output.first; + auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_]; + auto type = prev_op_->outputs_type()[prev_op_output_index_]; + for (auto &target_input : next_op_input_) { + auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); + auto target_input_str = target_input.first; + CostPtr cost; + if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) { + MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed"; + } + MS_EXCEPTION_IF_NULL(cost); + MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_ + << ", communication_cost: " << cost->communication_cost_ + << ", communication_without_parameter_: " << cost->communication_without_parameter_ + << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << "."; + // refine communication cost calculation for practice + RefineForPracticalCost(cost, true); + cost->communication_forward_ = cost->communication_redis_forward_; + CostPtrKey ck = {target_output_str, target_input_str}; + CostPtrList cl; + cl.push_back(cost); + (void)cost_map_.emplace(std::make_pair(ck, cl)); + has_available_cost = true; + } + } + } + if (!has_available_cost) { + if (FULLY_USE_DEVICES) { + MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ + << " failed, it may be caused by setting 'fully_use_devices' true. Try to set " + "'fully_use_devices' false."; + } else if (ELEMENTWISE_OP_STRA_FOLLOW) { + MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ + << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. " + "Try to set 'elementwise_op_strategy_follow' false."; + } + if (edge_name_.find(RESHAPE) != std::string::npos) { + MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ + << " failed, it may be caused by setting different strategies for operators following Reshape. " + "Try to fix that."; + } + MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed."; + } + return Status::SUCCESS; +} + +Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t type_length, TypePtr type, CostPtr *cost) { + MS_EXCEPTION_IF_NULL(prev_op_); + MS_EXCEPTION_IF_NULL(cost); + RankList dev_list = prev_op_->global_device_list(); + TensorRedistribution tensor_redistribution(false); + + // Init TensorRedistribution + if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; + } + + if (tensor_redistribution.ComputeCost() == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; + } + + double comm_cost = tensor_redistribution.comm_cost(); + double forward_comm_cost = tensor_redistribution.forward_comm_cost(); + double backward_comm_cost = tensor_redistribution.backward_comm_cost(); + double computation_cost = tensor_redistribution.computation_cost(); + double mem_cost = tensor_redistribution.memory_cost(); + + // Now AllGather, ReduceScatter, AlltoAll don't support bool type + MS_EXCEPTION_IF_NULL(type); + if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) { + computation_cost = INF; + comm_cost = INF; + MS_LOG(WARNING) << "Communication Operators don't support bool dtype!"; + } + *cost = std::make_shared(type_length * computation_cost, type_length * comm_cost); + (*cost)->communication_without_parameter_ = type_length * comm_cost; + (*cost)->communication_with_partial_para_ = + (*cost)->communication_without_parameter_ + + COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_); + (*cost)->communication_redis_forward_ = type_length * forward_comm_cost; + (*cost)->communication_redis_backward_ = type_length * backward_comm_cost; + (*cost)->memory_with_reuse_ = mem_cost; + return Status::SUCCESS; +} + +CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) { + CostPtrKey ck = {output_str, input_str}; + CostPtrList result; + if (cost_map_.find(ck) != cost_map_.end()) { + return cost_map_.at(ck); + } + return result; +} + +CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector &edges, + const StrategyPtr &input_st_ptr) { + std::function LocalGetCostList = [&](const EdgePtr &edge) { + MS_EXCEPTION_IF_NULL(edge); + return edge->GetCostList(output_st_ptr, input_st_ptr); + }; + CostPtrList result; + std::vector all_cost_list; + all_cost_list.resize(edges.size()); + (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList); + + CostPtrList selected_cost_list(all_cost_list.size(), nullptr); + std::function recursive = + [&](size_t k, double computation, double memory, double communication, double communication_without_para, + double communication_forward) { + if (k == edges.size()) { + auto decision = std::make_shared(selected_cost_list); + CostPtr new_cost = std::make_shared(computation, communication); + MS_EXCEPTION_IF_NULL(new_cost); + new_cost->communication_without_parameter_ = communication_without_para; + new_cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ = memory; + new_cost->communication_forward_ = communication_forward; + new_cost->decision_ptr_ = decision; + result.push_back(new_cost); + return; + } + for (auto &c : all_cost_list[k]) { + MS_EXCEPTION_IF_NULL(c); + selected_cost_list[k] = c; + recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_, + communication + c->communication_cost_, + communication_without_para + c->communication_without_parameter_, + communication_forward + c->communication_forward_); + } + }; + recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0); + Simplify(&result); + return result; +} + +void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector &edges, OperatorInfoPtr) { + bool valid = false; + for (const auto &output_pair : pre_op_output_) { + StrategyPtr output_st_ptr = output_pair.first; + for (const auto &input_pair : next_op_input_) { + StrategyPtr input_st_ptr = input_pair.first; + CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr); + CostPtrKey key = {output_st_ptr, input_st_ptr}; + cost_map_[key] = clist; + if ((!valid) && (!clist.empty())) { + valid = true; + } + } + } + if (!valid) { + MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed."; + } +} + +void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list) { + for (auto &left_cost : left_cost_list) { + MS_EXCEPTION_IF_NULL(left_cost); + for (auto &middle_cost : middle_cost_list) { + MS_EXCEPTION_IF_NULL(middle_cost); + for (auto &right_cost : right_cost_list) { + MS_EXCEPTION_IF_NULL(right_cost); + double computation = + left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; + double communication = + left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_; + double communication_forward = + left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_; + double communication_without_para = left_cost->communication_without_parameter_ + + middle_cost->communication_without_parameter_ + + right_cost->communication_without_parameter_; + double memory_cost = + left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_; + + auto decision = std::make_shared(op_strategy, left_cost, middle_cost, right_cost); + auto cost = std::make_shared(computation, communication, decision); + MS_EXCEPTION_IF_NULL(cost); + cost->communication_without_parameter_ = communication_without_para; + cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + cost->memory_with_reuse_ = memory_cost; + cost->communication_forward_ = communication_forward; + ret_cost_list->emplace_back(std::move(cost)); + } + } + } +} + +CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr, + const OperatorInfoPtr &op, const EdgePtr &e2, + const StrategyPtr &input_st_ptr) { + MS_EXCEPTION_IF_NULL(op); + MS_EXCEPTION_IF_NULL(e1); + MS_EXCEPTION_IF_NULL(e2); + CostPtrList result; + for (const auto &op_strategy : op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(op_strategy); + auto middle_strategy = op_strategy->strategy_ptr; + CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy), + op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result); + } + Simplify(&result); + return result; +} + +void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) { + bool valid = false; + for (const auto &output_pair : pre_op_output_) { + StrategyPtr output_st_ptr = output_pair.first; + for (const auto &input_pair : next_op_input_) { + StrategyPtr input_st_ptr = input_pair.first; + + CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr); + CostPtrKey key = {output_st_ptr, input_st_ptr}; + cost_map_[key] = clist; + if ((!valid) && (!clist.empty())) { + valid = true; + } + } + } + if (!valid) { + MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed."; + } +} + +Status Edge::CalculateMemoryCost() { + if (is_output_parameter_involve_ == -1) { + MS_LOG(ERROR) << "is_output_parameter_involve_ is unset."; + return FAILED; + } + if (is_output_parameter_involve_ == 0) { + // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is + // unnecessary to keep them in memory. + for (auto &cost_kv : cost_map_) { + auto &cost_v = cost_kv.second; + if (!cost_v.empty()) { + cost_v[0]->memory_with_reuse_ = 0; + } + } + } + + return SUCCESS; +} + +Status Edge::CalculateMemoryCostForInference() { + // Currently, memory cost is NOT calculated for redistribution + if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) { + MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_; + return FAILED; + } + for (auto &cost_kv : cost_map_) { + auto &cost_v = cost_kv.second; + if (!cost_v.empty()) { + cost_v[0]->memory_with_reuse_ = 0; + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h new file mode 100644 index 0000000000..3fffd1b86d --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h @@ -0,0 +1,171 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ +#define PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ + +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/tensor_layout/tensor_info.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" + +namespace mindspore { +namespace parallel { +using CostPtrKey = std::pair; +using OperatorInfoPtr = std::shared_ptr; +using EdgePtr = std::shared_ptr; + +class Edge { + // An 'Edge' connects two Operators in the CostGraph. + public: + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com) + : edge_name_(edge_name), + prev_op_(prev_op), + next_op_(next_op), + prev_op_output_index_(output_index_), + next_op_input_index_(input_index_), + is_combined_(is_com) { + is_identity_edge = false; + } + + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com, const bool &is_iden) + : edge_name_(edge_name), + prev_op_(prev_op), + next_op_(next_op), + prev_op_output_index_(output_index_), + next_op_input_index_(input_index_), + is_combined_(is_com), + is_identity_edge(is_iden) {} + + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const std::vector &output_indexs_, + const std::vector &input_indexs_, const bool &is_com) + : edge_name_(edge_name), + prev_op_(prev_op), + next_op_(next_op), + pre_op_output_indexs_(output_indexs_), + next_op_input_indexs_(input_indexs_), + is_combined_(is_com) { + prev_op_output_index_ = 0; + next_op_input_index_ = 0; + is_identity_edge = false; + } + + ~Edge() = default; + std::shared_ptr prev_operator() const { return prev_op_; } + std::shared_ptr next_operator() const { return next_op_; } + std::string edge_name() const { return edge_name_; } + // Init cost_map_: for each output layout and input layout, calculate the cost + Status InitEdgeCost(); + // For two operators u--->v, given the output tensor layout of u, + // and the input tensor layout of v, return the redistribution cost, + // and the op_list to carry out the redistribution. + Status GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t, TypePtr type, CostPtr *cost); + + void set_pre_op_output(const std::vector, std::vector>> &output_set) { + pre_op_output_ = output_set; + } + void set_next_op_input(const std::vector, std::vector>> &input_set) { + next_op_input_ = input_set; + } + + // Given a pair of output strategy and input strategy, return the corresponding costlist + CostPtrList GetCostList(StrategyPtr output_str, StrategyPtr input_str); + + std::vector, std::vector>> prev_op_output() const { + return pre_op_output_; + } + std::vector, std::vector>> next_op_input() const { + return next_op_input_; + } + + bool is_combined() const { return is_combined_; } + size_t prev_op_output_index() const { return prev_op_output_index_; } + size_t next_op_input_index() const { return next_op_input_index_; } + std::vector prev_op_output_indexs() const { return pre_op_output_indexs_; } + std::vector next_op_input_indexs() const { return next_op_input_indexs_; } + + CostPtrList CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, + const std::vector> &edges, + const StrategyPtr &input_st_ptr); + // In the Edge Elimination operation in DP algorithm, 'edges' is replaced by a new edge. This method is used to + // set cost for this new edge + void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector> &edges, + std::shared_ptr v); + void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list); + + CostPtrList CreateOpEliminationCostList(const std::shared_ptr &e1, const StrategyPtr &output_st_ptr, + const std::shared_ptr &op, const std::shared_ptr &e2, + const StrategyPtr &input_st_ptr); + // In the Operation Elimination operation in DP algorithm, 'op', 'e1' and 'e2' are replaced by a new edge. + // This method is used to set cost for this new edge + void OpEliminationSetNewCost(const std::shared_ptr &e1, const std::shared_ptr &op, + const std::shared_ptr &e2); + + void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } + const CostPtr &selected_cost() const { return selected_cost_; } + void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } + // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving + // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory + // at the end of forward phase. + Status CalculateMemoryCost(); + // In the inference phase, + Status CalculateMemoryCostForInference(); + void mark_output_critical() { is_output_critical_ = 1; } + + private: + std::string edge_name_; + std::shared_ptr prev_op_, next_op_; + std::map cost_map_; + // pre_op_output_ + std::vector, std::vector>> pre_op_output_; + std::vector, std::vector>> next_op_input_; + // the index of outputs of prev_op, and the index of inputs of next_op + size_t prev_op_output_index_, next_op_input_index_; + + // pre_op_output_indexs_ and next_op_input_indexs_ store the indexs of inputs and outputs if is_combined = true + std::vector pre_op_output_indexs_; + std::vector next_op_input_indexs_; + // is this edge constructed by combining multiple edges? If is is, then is_combined = true, else is_combined = false + bool is_combined_; + // When a Parameter in the ANF graph being used by multiple operators, we include the Parameter in the costgraph by + // replace the Parameter by a TmpIdentity operator, and connecting this TmpIdentity operator with subsequent + // operators. The resulting edges are different from those normal edges, thus this Bool variable distinguishes them. + // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor. + bool is_identity_edge; + CostPtr selected_cost_; + // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator + // is parameter-involved + int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved + // In the inference phase, this is used to mark whether the output of the previous operator is critical. + int is_output_critical_ = 0; +}; +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc new file mode 100644 index 0000000000..1c1fc3a700 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc @@ -0,0 +1,1677 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/ops_info/reshape_info.h" +#include "frontend/parallel/step_auto_parallel.h" + +namespace mindspore { +namespace parallel { +CostGraphPtr entire_costgraph = nullptr; +size_t TOTAL_OPS = 0; +double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA; +bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; +double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY; +double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; +double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST; +double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS; +bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; +size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; +bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; +bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; +bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; +int32_t RUN_PHASE = DEFAULT_RUN_PHASE; + +void CostGraph::SetDeviceMemoryAndCostParameter() { + MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); + + // DEVICE_MEMORY_CAPACITY + auto device_memory = CostModelContext::GetInstance()->device_memory_capacity(); + if (device_memory <= 0) { + MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; + } + dev_memory_ = device_memory; + DEVICE_MEMORY_CAPACITY = device_memory; + MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << "."; + + // COST_MODEL_ALPHA + auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); + if (alpha <= 0) { + MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; + } + costmodel_alpha_ = alpha; + MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; + + // COST_MODEL_BETA + auto beta = CostModelContext::GetInstance()->costmodel_beta(); + if (beta <= 0) { + MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; + } + costmodel_beta_ = beta; + MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; + + // COST_MODEL_GAMMA + auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); + if ((gamma < 0) || (gamma > 1)) { + MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; + } + COST_MODEL_GAMMA = gamma; + MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << "."; + + // COST_MODEL_SIMPLIFY_CALCULATION + auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal(); + COST_MODEL_SIMPLIFY_CALCULATION = simplify; + if (COST_MODEL_SIMPLIFY_CALCULATION) { + MS_LOG(INFO) << "costmodel_simplify_cal: true."; + } else { + MS_LOG(INFO) << "costmodel_simplify_cal: false."; + } + + // COST_MODEL_COMMUNI_THRESHOLD + auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold(); + if (communi_threshold < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; + } + COST_MODEL_COMMUNI_THRESHOLD = communi_threshold; + MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << "."; + + // COST_MODEL_COMMUNI_CONST + auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const(); + if (communi_const < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; + } + COST_MODEL_COMMUNI_CONST = communi_const; + MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << "."; + + // COST_MODEL_COMMUNI_BIAS + auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias(); + if (communi_bias < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; + } + COST_MODEL_COMMUNI_BIAS = communi_bias; + MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << "."; + + // TENSOR_SLICE_ALIGNMENT_ENABLE + auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable(); + TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable; + if (TENSOR_SLICE_ALIGNMENT_ENABLE) { + MS_LOG(INFO) << "tensor_slice_align_enable: true."; + } else { + MS_LOG(INFO) << "tensor_slice_align_enable: false."; + } + + // TENSOR_SLICE_ALIGNMENT_SIZE + auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size(); + if (align_size == 0) { + MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; + } + TENSOR_SLICE_ALIGNMENT_SIZE = align_size; + MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << "."; + + // FULLY_USE_DEVICES + auto fully_devices = CostModelContext::GetInstance()->fully_use_device(); + FULLY_USE_DEVICES = fully_devices; + if (FULLY_USE_DEVICES) { + MS_LOG(INFO) << "fully_use_devices: true."; + } else { + MS_LOG(INFO) << "fully_use_devices: false."; + } + + // ELEMENTWISE_OP_STRA_FOLLOW + auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); + ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow; + if (ELEMENTWISE_OP_STRA_FOLLOW) { + MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; + } else { + MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; + } + + // MULTI_SUBGRAPHS + auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs(); + MULTI_SUBGRAPHS = multi_subgraphs; + if (MULTI_SUBGRAPHS) { + MS_LOG(INFO) << "multi_subgraphs: true."; + } else { + MS_LOG(INFO) << "multi_subgraphs: false."; + } + + // RUN_PHASE + auto phase = CostModelContext::GetInstance()->run_phase(); + if (phase != 0 && phase != 1) { + MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; + } + RUN_PHASE = phase; + MS_LOG(INFO) << "run_phase: " << RUN_PHASE << "."; +} + +void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { + for (auto it = ops_.begin(); it != ops_.end();) { + if ((*it) == op) { + it = ops_.erase(it); + } else { + ++it; + } + } +} + +bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) { + struct IsInGraph { + const OperatorInfoPtr test_; + explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {} + bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); } + }; + return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test)); +} + +void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) { + std::vector curr_edges(edges_[{u_node, v_node}]); + curr_edges.push_back(edge); + edges_[{u_node, v_node}] = curr_edges; + + std::vector curr_out_edges(out_edges_[u_node]); + curr_out_edges.push_back(edge); + out_edges_[u_node] = curr_out_edges; + + std::vector curr_in_edges(in_edges_[v_node]); + curr_in_edges.push_back(edge); + in_edges_[v_node] = curr_in_edges; +} + +bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) { + for (auto &edge_pair : edges_) { + auto edges = edge_pair.second; + for (auto &edge : edges) { + MS_EXCEPTION_IF_NULL(edge); + bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) && + (edge->next_op_input_index() == input_index); + if (bool_result) { + return true; + } + } + } + return false; +} + +std::vector> CostGraph::ConstructConnectedComponents( + std::vector alive_ops) { + std::map visited; + + for (auto &op : alive_ops) { + visited[op] = false; + } + + MS_LOG(INFO) << "visited: " << visited.size() << "."; + for (auto &op : alive_ops) { + if ((!visited[op]) && op->is_alive()) { + std::shared_ptr new_component = std::make_shared(); + MS_EXCEPTION_IF_NULL(new_component); + new_component->SetDeviceMemoryAndCostParameter(); + DFS(op, &visited, new_component); + connected_compoents_.push_back(new_component); + } + } + return connected_compoents_; +} + +void CostGraph::DFS(const OperatorInfoPtr ¤t_op, std::map *visited, + const std::shared_ptr &component) { + MS_EXCEPTION_IF_NULL(visited); + MS_EXCEPTION_IF_NULL(component); + visited->at(current_op) = true; + component->AddOperator(current_op); + + for (auto &edge : current_op->succ_edges()) { + bool bool_test = (visited->find(edge->next_operator()) != visited->end()) && + (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive(); + if (bool_test) { + component->AddEdge(current_op, edge->next_operator(), edge); + DFS(edge->next_operator(), visited, component); + } + } + + for (auto &edge : current_op->prev_edges()) { + bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) && + (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive(); + if (bool_test) { + component->AddEdge(edge->prev_operator(), current_op, edge); + DFS(edge->prev_operator(), visited, component); + } + } +} + +// Create final cost list for the graph: u --> v +CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr &e, + const OperatorInfoPtr &v) { + MS_EXCEPTION_IF_NULL(u); + MS_EXCEPTION_IF_NULL(v); + MS_EXCEPTION_IF_NULL(e); + CostPtrList ret; + for (const auto &u_strategy : u->GetStrategyCost()) { + for (const auto &v_strategy : v->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(u_strategy); + MS_EXCEPTION_IF_NULL(v_strategy); + auto u_strategy_ptr = u_strategy->strategy_ptr; + auto v_strategy_ptr = v_strategy->strategy_ptr; + CostPtrList clist1 = u_strategy->cost_list; + CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr); + CostPtrList clist3 = v_strategy->cost_list; + for (const auto &cost1 : clist1) { + for (const auto &cost2 : clist2) { + for (const auto &cost3 : clist3) { + MS_EXCEPTION_IF_NULL(cost1); + MS_EXCEPTION_IF_NULL(cost2); + MS_EXCEPTION_IF_NULL(cost3); + double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_; + double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_; + double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_; + double communication_forward = + cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_; + double communication_without_para = cost1->communication_without_parameter_ + + cost2->communication_without_parameter_ + + cost3->communication_without_parameter_; + auto decision = + std::make_shared(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3); + auto cost = std::make_shared(computation, communication, decision); + MS_EXCEPTION_IF_NULL(cost); + cost->communication_without_parameter_ = communication_without_para; + cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + cost->memory_with_reuse_ = memory; + cost->communication_forward_ = communication_forward; + ret.push_back(cost); + } + } + } + } + } + + Simplify(&ret); + return ret; +} + +// Create final cost list for the graph containing a signle node: u +CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { + MS_EXCEPTION_IF_NULL(u); + CostPtrList ret; + for (const auto &u_strategy : u->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(u_strategy); + auto u_strategy_ptr = u_strategy->strategy_ptr; + CostPtrList clist1 = u_strategy->cost_list; + for (const auto &cost1 : clist1) { + MS_EXCEPTION_IF_NULL(cost1); + auto decision = std::make_shared(u_strategy_ptr, cost1); + auto new_cost = std::make_shared(cost1->computation_cost_, cost1->communication_cost_, decision); + MS_EXCEPTION_IF_NULL(new_cost); + new_cost->communication_without_parameter_ = cost1->communication_without_parameter_; + new_cost->communication_with_partial_para_ = + cost1->communication_without_parameter_ + + COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_); + new_cost->memory_with_reuse_ = cost1->memory_with_reuse_; + new_cost->communication_forward_ = cost1->communication_forward_; + ret.push_back(new_cost); + } + } + + Simplify(&ret); + return ret; +} + +CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) { + // Select the cost with minimum inference time. Currently, the inference time is modeled as = + // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_ + if (cost_list.empty()) { + MS_LOG(ERROR) << "Final cost list is null."; + return nullptr; + } + CostPtrList after_mem_filter; + double minimum_memory = DBL_MAX; + // Filter out the valid costs. + for (auto &a_cost : cost_list) { + if (a_cost->memory_with_reuse_ <= memory) { + after_mem_filter.emplace_back(std::move(a_cost)); + } else if (a_cost->memory_with_reuse_ < minimum_memory) { + minimum_memory = a_cost->memory_with_reuse_; + } + } + if (after_mem_filter.empty()) { + MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory + << ", the memory capacity is: " << memory << "."; + return nullptr; + } + // Init the returned value with first cost. + CostPtr ret = after_mem_filter[0]; + + double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_; + MS_LOG(INFO) << "Cost 0: " + << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ + << ", communication_forward_: " << ret->communication_forward_ + << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ + << ", communication_cost_: " << ret->communication_cost_ + << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; + MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; + for (size_t i = 1; i < after_mem_filter.size(); ++i) { + MS_EXCEPTION_IF_NULL(after_mem_filter[i]); + MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ + << ", computation_cost_: " << after_mem_filter[i]->computation_cost_ + << ", communication_forward_: " << after_mem_filter[i]->communication_forward_ + << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_ + << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ + << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ + << "."; + auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + + costmodel_beta_ * after_mem_filter[i]->communication_forward_; + MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; + if (minimum > tmp) { + minimum = tmp; + ret = after_mem_filter[i]; + MS_LOG(INFO) << "Selected: " << i; + } + } + return ret; +} + +CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) { + // Select the cost with minimum training time. Currently, the training time is modeled as = + // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_ + if (cost_list.empty()) { + MS_LOG(ERROR) << "Final cost list is null."; + return nullptr; + } + CostPtrList after_mem_filter; + double minimum_memory = DBL_MAX; + // Filter out the valid costs. + for (auto &a_cost : cost_list) { + if (a_cost->memory_with_reuse_ <= memory) { + after_mem_filter.emplace_back(std::move(a_cost)); + } else if (a_cost->memory_with_reuse_ < minimum_memory) { + minimum_memory = a_cost->memory_with_reuse_; + } + } + if (after_mem_filter.empty()) { + MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory + << ", the memory capacity is: " << memory << "."; + return nullptr; + } + // Init the returned value with first cost. + CostPtr ret = after_mem_filter[0]; + + double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_; + MS_LOG(INFO) << "Cost 0: " + << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ + << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ + << ", communication_cost_: " << ret->communication_cost_ + << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; + MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; + for (size_t i = 1; i < after_mem_filter.size(); ++i) { + MS_EXCEPTION_IF_NULL(after_mem_filter[i]); + MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ + << ", computation_cost_: " << after_mem_filter[i]->computation_cost_ + << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_ + << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ + << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ + << "."; + auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + + costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_; + MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; + if (minimum > tmp) { + minimum = tmp; + ret = after_mem_filter[i]; + MS_LOG(INFO) << "Selected: " << i; + } + } + return ret; +} + +CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_cost_list, + double available_memory) { + CostPtrList selected_cost_list(all_cost_list.size(), nullptr); + double minimum = DBL_MAX, total_memory = 0.0; + CostPtrList ret(all_cost_list.size(), nullptr); + // Check whether valid costs exist. + for (size_t i = 0; i < all_cost_list.size(); ++i) { + if (all_cost_list[i][0] == nullptr) { + MS_LOG(ERROR) << "The cost list " << i << " is empty."; + return ret; + } else { + double memory_i_cost = DBL_MAX; + for (size_t j = 0; j < all_cost_list[i].size(); ++j) { + if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) { + memory_i_cost = all_cost_list[i][j]->memory_with_reuse_; + } + } + total_memory += memory_i_cost; + } + } + if (total_memory >= available_memory) { + MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory + << ", minimum strategy cost: " << total_memory << "."; + return selected_cost_list; + } + + std::function recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive, + &available_memory, this](size_t k) { + if (k == all_cost_list.size()) { + double tmp_memory = 0.0, tmp_minimum = 0.0; + for (size_t i = 0; i < selected_cost_list.size(); ++i) { + MS_EXCEPTION_IF_NULL(selected_cost_list[i]); + tmp_memory += selected_cost_list[i]->memory_with_reuse_; + tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ + + costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_; + } + MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum + << "."; + if (tmp_memory < available_memory && tmp_minimum < minimum) { + ret = selected_cost_list; + minimum = tmp_minimum; + MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << "."; + } + return; + } + + MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << "."; + for (auto &c : all_cost_list[k]) { + selected_cost_list[k] = c; + recursive(k + 1); + } + }; + recursive(0); + return ret; +} + +Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector &alive_ops) { + MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph."; + auto connected_components = ConstructConnectedComponents(alive_ops); + MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; + std::vector all_list; + for (size_t j = 0; j < connected_components.size(); ++j) { + auto one_component = connected_components[j]; + MS_EXCEPTION_IF_NULL(one_component); + if (one_component->GetOperators().size() == 1) { + MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; + auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); + all_list.push_back(cost_list); + } else if (one_component->GetOperators().size() == 2) { + MS_LOG(INFO) << "There are 2 operators in a component in the final graph."; + OperatorInfoPtr u, v; + auto first_op = one_component->GetOperators()[0]; + auto second_op = one_component->GetOperators()[1]; + MS_EXCEPTION_IF_NULL(first_op); + MS_EXCEPTION_IF_NULL(second_op); + if (!first_op->GetAliveSuccEdges().empty() && + first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) { + u = first_op; + v = second_op; + } else if (!second_op->GetAliveSuccEdges().empty() && + second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) { + u = second_op; + v = first_op; + } else { + MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size() + << ", " << second_op->GetAliveSuccEdges().size() << "."; + } + MS_EXCEPTION_IF_NULL(u); + auto e = u->GetAliveSuccEdges()[0]; + auto cost_list = one_component->CreateFinalCostList(u, e, v); + all_list.push_back(cost_list); + } else { + MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size() + << " operators in a component in the final graph."; + } + } + // + auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); + for (size_t k = 0; k < selected_cost_list.size(); ++k) { + auto selected_cost = selected_cost_list[k]; + if (selected_cost == nullptr) { + MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; + return FAILED; + } + MS_EXCEPTION_IF_NULL(connected_components[k]); + if (connected_components[k]->GetOperators().size() == 1) { + auto u = connected_components[k]->GetOperators()[0]; + auto decision = selected_cost->decision_ptr_->cast(); + u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); + MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; + } else if (connected_components[k]->GetOperators().size() == 2) { + OperatorInfoPtr u = nullptr, v = nullptr; + auto first_op = connected_components[k]->GetOperators()[0]; + auto second_op = connected_components[k]->GetOperators()[1]; + MS_EXCEPTION_IF_NULL(first_op); + MS_EXCEPTION_IF_NULL(second_op); + if (!first_op->GetAliveSuccEdges().empty() && + first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) { + u = first_op; + v = second_op; + } else if (!second_op->GetAliveSuccEdges().empty() && + second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) { + u = second_op; + v = first_op; + } + MS_EXCEPTION_IF_NULL(u); + auto e = u->GetAliveSuccEdges()[0]; + MS_EXCEPTION_IF_NULL(v); + MS_EXCEPTION_IF_NULL(e); + MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); + auto decision = selected_cost->decision_ptr_->cast(); + MS_EXCEPTION_IF_NULL(decision); + u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); + v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); + e->set_selected_cost(decision->middle_cost_); + MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; + } + } + return SUCCESS; +} + +// searching the strategy for the final eliminated graph +Status CostGraph::SearchStrategy() { + MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began."; + std::vector alive_ops; + (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) { + MS_EXCEPTION_IF_NULL(op); + if (op->is_alive()) { + alive_ops.push_back(op); + } + }); + + if (alive_ops.size() > 2) { + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + return SearchStrategyForMultiNodeFinalGraph(alive_ops); + } else { + // inference phase + MS_LOG(EXCEPTION) + << "Currently, searching strategy for the multi-node final graph in inference phase is not supported."; + } + } else if (alive_ops.size() == 1) { + MS_LOG(INFO) << "There are 1 single node in the final graph."; + OperatorInfoPtr u = alive_ops[0]; + auto cost_list = CreateFinalSingleCostList(u); + CostPtr cost = nullptr; + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); + } else { + // inference phase + cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_); + } + if (cost == nullptr) { + MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; + return FAILED; + } + MS_EXCEPTION_IF_NULL(u); + MS_EXCEPTION_IF_NULL(cost->decision_ptr_); + auto decision = cost->decision_ptr_->cast(); + MS_EXCEPTION_IF_NULL(decision); + u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); + MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; + return SUCCESS; + } else { + // In this case, the final graph should contains exactly 2 nodes. + if (alive_ops.empty()) { + MS_LOG(INFO) << "0 Operator in the final graph."; + return SUCCESS; + } + OperatorInfoPtr u, v; + MS_EXCEPTION_IF_NULL(alive_ops[0]); + MS_EXCEPTION_IF_NULL(alive_ops[1]); + if (!alive_ops[0]->GetAliveSuccEdges().empty() && + alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) { + u = alive_ops[0]; + v = alive_ops[1]; + } else if (!alive_ops[1]->GetAliveSuccEdges().empty() && + alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) { + u = alive_ops[1]; + v = alive_ops[0]; + } else { + if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) { + MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size() + << ", " << alive_ops[1]->GetAliveSuccEdges().size() << "."; + } else { + // In this case, the final graph consists of two single nodes + MS_LOG(INFO) << "There are 2 single nodes in the final graph."; + std::vector all_list; + auto connected_components = ConstructConnectedComponents(alive_ops); + MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; + for (size_t i = 0; i < connected_components.size(); ++i) { + MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; + auto one_component = connected_components[i]; + MS_EXCEPTION_IF_NULL(one_component); + auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); + all_list.push_back(cost_list); + } + CostPtrList selected_cost_list; + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); + } else { + // inference phase + MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference " + "phase is not supported."; + } + for (size_t k = 0; k < selected_cost_list.size(); ++k) { + auto selected_cost = selected_cost_list[k]; + if (selected_cost == nullptr) { + MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; + return FAILED; + } + MS_EXCEPTION_IF_NULL(connected_components[k]); + auto one_operator = connected_components[k]->GetOperators()[0]; + MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); + auto decision = selected_cost->decision_ptr_->cast(); + MS_EXCEPTION_IF_NULL(decision); + one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); + MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; + } + + return SUCCESS; + } + } + MS_LOG(INFO) << "There are 2 nodes in the final graph."; + // In this case, the finale graph is exactly of the form: u --> v + MS_EXCEPTION_IF_NULL(u); + MS_EXCEPTION_IF_NULL(v); + auto e = u->GetAliveSuccEdges()[0]; + MS_EXCEPTION_IF_NULL(e); + auto cost_list = CreateFinalCostList(u, e, v); + CostPtr cost = nullptr; + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); + } else { + MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference " + "phase is not supported."; + } + if (cost == nullptr) { + MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; + return FAILED; + } + MS_EXCEPTION_IF_NULL(cost->decision_ptr_); + auto decision = cost->decision_ptr_->cast(); + MS_EXCEPTION_IF_NULL(decision); + u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); + v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); + e->set_selected_cost(decision->middle_cost_); + MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; + return SUCCESS; + } +} + +// Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated +// return the v and the edge u --> v +OperatorInfoPtr CostGraph::CheckOpElimination() const { + for (auto &op : ops_) { + bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1; + if (bool_test) { + if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) { + return op; + } + } + } + return nullptr; +} + +// Check the graph whether an EdgeElimination can be performed +std::vector> CostGraph::CheckEdgeElimination() const { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + if (!op->is_alive()) continue; + std::map count; + for (auto &edge : op->GetAliveSuccEdges()) { + MS_EXCEPTION_IF_NULL(edge); + auto v = edge->next_operator(); + count[v.get()]++; + } + for (auto &pair : count) { + auto *op_ptr = pair.first; + int op_count = pair.second; + if (op_count > 1) { + std::vector> ret; + for (auto &edge : op->GetAliveSuccEdges()) { + MS_EXCEPTION_IF_NULL(edge); + if (edge->next_operator().get() == op_ptr) { + ret.push_back(edge); + } + } + return ret; + } + } + } + return {}; +} + +// Check the graph whether a MergeElimination can be performed +OperatorInfoPtr CostGraph::CheckMergeElimination() const { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1; + if (bool_test) { + auto next_op = op->GetAliveSuccEdges()[0]->next_operator(); + MS_EXCEPTION_IF_NULL(next_op); + if (!next_op->GetAlivePrevEdges().empty()) { + return op; + } + } + } + return nullptr; +} + +// Check the graph whether a ContractElimination can be performed +OperatorInfoPtr CostGraph::CheckContractElimination() const { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty(); + if (bool_test) { + auto edge = op->GetAlivePrevEdges()[0]; + MS_EXCEPTION_IF_NULL(edge); + auto prev_op = edge->prev_operator(); + MS_EXCEPTION_IF_NULL(prev_op); + if (!prev_op->GetAliveSuccEdges().empty()) { + return op; + } + } + } + return nullptr; +} + +// Check the graph whether a TriangleElimination can be performed +std::pair> CostGraph::CheckTriangleElimination() const { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2); + if (bool_test) { + auto edge1 = op->GetAliveSuccEdges()[0]; + auto edge2 = op->GetAliveSuccEdges()[1]; + MS_EXCEPTION_IF_NULL(edge1); + MS_EXCEPTION_IF_NULL(edge2); + auto first_op = edge1->next_operator(); + auto second_op = edge2->next_operator(); + MS_EXCEPTION_IF_NULL(first_op); + for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) { + if (first_op_succ_edge->next_operator() == second_op) { + return {op, first_op_succ_edge}; + } + } + MS_EXCEPTION_IF_NULL(second_op); + for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) { + if (second_op_succ_edge->next_operator() == first_op) { + return {op, second_op_succ_edge}; + } + } + } + } + return {nullptr, nullptr}; +} + +// Check the graph whether a StarElimination can be performed. +// NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. +OperatorInfoPtr CostGraph::CheckStarElimination() const { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1); + if (bool_test) { + return op; + } + } + return nullptr; +} + +// This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace +// 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge. +std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr &op) { + // in this case, the operators are organised in the form of u-->op-->v, and the goal + // is to eliminate 'op'. + MS_EXCEPTION_IF_NULL(op); + MS_LOG(INFO) << "Now eliminating node: " << op->name() << "."; + auto edge_u_op = op->GetAlivePrevEdges()[0]; + auto edge_op_v = op->GetAliveSuccEdges()[0]; + MS_EXCEPTION_IF_NULL(edge_u_op); + MS_EXCEPTION_IF_NULL(edge_op_v); + auto u = edge_u_op->prev_operator(); + auto v = edge_op_v->next_operator(); + std::vector output_indexs, input_indexs; + size_t output_index, input_index; + MS_EXCEPTION_IF_NULL(u); + MS_EXCEPTION_IF_NULL(v); + std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); + std::shared_ptr new_edge; + if (edge_u_op->is_combined()) { + output_indexs = edge_u_op->prev_op_output_indexs(); + } else { + output_index = edge_u_op->prev_op_output_index(); + output_indexs.push_back(output_index); + } + if (edge_op_v->is_combined()) { + input_indexs = edge_op_v->next_op_input_indexs(); + } else { + input_index = edge_op_v->next_op_input_index(); + input_indexs.push_back(input_index); + } + + if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) { + new_edge = std::make_shared(new_edge_name, u, v, output_index, input_index, false); + } else { + new_edge = std::make_shared(new_edge_name, u, v, output_indexs, input_indexs, true); + } + MS_EXCEPTION_IF_NULL(new_edge); + new_edge->set_pre_op_output(edge_u_op->prev_op_output()); + new_edge->set_next_op_input(edge_op_v->next_op_input()); + new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v); + u->ReplaceSuccEdge(op, new_edge); + v->ReplacePreEdge(op, new_edge); + op->SetNotAlive(); + MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded."; + return new_edge; +} + +// This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges', +// and sets new costlist for the new edge. +std::shared_ptr CostGraph::EliminationEdges(const std::vector> &edges) { + MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges."; + MS_EXCEPTION_IF_NULL(edges[0]); + auto u = edges[0]->prev_operator(); + auto v = edges[0]->next_operator(); + MS_EXCEPTION_IF_NULL(u); + MS_EXCEPTION_IF_NULL(v); + std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); + std::vector output_indexs, input_indexs; + + for (auto &edge : edges) { + MS_EXCEPTION_IF_NULL(edge); + if (edge->is_combined()) { + auto from_output_indexs = edge->prev_op_output_indexs(); + auto from_input_indexs = edge->next_op_input_indexs(); + (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs)); + (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs)); + } else { + output_indexs.push_back(edge->prev_op_output_index()); + input_indexs.push_back(edge->next_op_input_index()); + } + } + + std::shared_ptr new_edge = std::make_shared(new_edge_name, u, v, output_indexs, input_indexs, true); + MS_EXCEPTION_IF_NULL(new_edge); + new_edge->set_pre_op_output(edges[0]->prev_op_output()); + new_edge->set_next_op_input(edges[0]->next_op_input()); + + new_edge->EdgeEliminationSetNewCost(u, edges, v); + + u->ReplaceSuccEdges(v, new_edge); + v->ReplacePreEdges(u, new_edge); + MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded."; + return new_edge; +} + +// Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' +// for this contract under the strategy 'op_strategy' +void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, + const CostPtrList &tar_cost_list, + CostPtrList *const tar_cost_list_new) { + for (size_t i = 0; i < op_cost_list.size(); ++i) { + auto &op_cost = op_cost_list[i]; + MS_EXCEPTION_IF_NULL(op_cost); + for (size_t j = 0; j < edge_cost_list.size(); ++j) { + auto &edge_cost = edge_cost_list[j]; + MS_EXCEPTION_IF_NULL(edge_cost); + for (size_t k = 0; k < tar_cost_list.size(); ++k) { + auto &tar_cost = tar_cost_list[k]; + MS_EXCEPTION_IF_NULL(tar_cost); + double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; + double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; + double communication = + op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; + double communication_forward = + op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_; + double communication_without_para = op_cost->communication_without_parameter_ + + edge_cost->communication_without_parameter_ + + tar_cost->communication_without_parameter_; + + auto decision = + std::make_shared(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); + auto new_cost = std::make_shared(computation, communication, decision); + MS_EXCEPTION_IF_NULL(new_cost); + new_cost->communication_without_parameter_ = communication_without_para; + new_cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ = memory; + new_cost->communication_forward_ = communication_forward; + MS_EXCEPTION_IF_NULL(tar_cost_list_new); + tar_cost_list_new->emplace_back(std::move(new_cost)); + } + } + } +} + +// This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the +// target_op +OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { + MS_EXCEPTION_IF_NULL(op); + auto target_op = op->GetAliveSuccEdges()[0]->next_operator(); + auto edge_ptr = op->GetAliveSuccEdges()[0]; + MS_EXCEPTION_IF_NULL(target_op); + MS_EXCEPTION_IF_NULL(edge_ptr); + MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << "."; + bool valid = false; + + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(tar_stra_cost); + auto tar_stra = tar_stra_cost->strategy_ptr; + auto tar_clist_origin = tar_stra_cost->cost_list; + CostPtrList tar_clist_new; + + for (auto &op_stra_cost : op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(op_stra_cost); + auto op_stra = op_stra_cost->strategy_ptr; + auto op_clist = op_stra_cost->cost_list; + auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra); + + CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); + } + Simplify(&tar_clist_new); + // Set the new costlist w.r.t the strategy + tar_stra_cost->cost_list = tar_clist_new; + if ((!valid) && (!tar_clist_new.empty())) { + valid = true; + } + } + + if (!valid) { + MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed."; + } + op->SetNotAlive(); + MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded."; + return target_op; +} + +// Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' +// for this contract under the strategy 'contract_op_stra' +void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra, + const CostPtrList &contract_op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr target_op_stra, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) { + for (size_t i = 0; i < contract_op_cost_list.size(); ++i) { + auto &contract_op_cost = contract_op_cost_list[i]; + MS_EXCEPTION_IF_NULL(contract_op_cost); + for (size_t j = 0; j < edge_cost_list.size(); ++j) { + auto &edge_cost = edge_cost_list[j]; + MS_EXCEPTION_IF_NULL(edge_cost); + for (size_t k = 0; k < tar_cost_list.size(); ++k) { + auto &tar_cost = tar_cost_list[k]; + MS_EXCEPTION_IF_NULL(tar_cost); + double computation = + contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; + double memory = + contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; + double communication = + contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; + double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ + + tar_cost->communication_forward_; + double communication_without_para = contract_op_cost->communication_without_parameter_ + + edge_cost->communication_without_parameter_ + + tar_cost->communication_without_parameter_; + + auto decision = std::make_shared(contract_op_stra, contract_op_cost, edge_cost, + target_op_stra, tar_cost); + auto new_cost = std::make_shared(computation, communication, decision); + new_cost->communication_without_parameter_ = communication_without_para; + new_cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ = memory; + new_cost->communication_forward_ = communication_forward; + tar_cost_list_new->emplace_back(std::move(new_cost)); + } + } + } +} + +// This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the +// target_op +OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { + MS_EXCEPTION_IF_NULL(op); + auto target_op = op->GetAlivePrevEdges()[0]->prev_operator(); + auto edge_ptr = op->GetAlivePrevEdges()[0]; + MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << "."; + bool valid = false; + + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(tar_stra_cost); + auto tar_stra = tar_stra_cost->strategy_ptr; + auto tar_clist_origin = tar_stra_cost->cost_list; + CostPtrList tar_clist_new; + + for (auto &op_stra_cost : op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(op_stra_cost); + auto op_stra = op_stra_cost->strategy_ptr; + auto op_clist = op_stra_cost->cost_list; + auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra); + + CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); + } + Simplify(&tar_clist_new); + // Set the new costlist w.r.t the strategy + tar_stra_cost->cost_list = tar_clist_new; + if ((!valid) && (!tar_clist_new.empty())) { + valid = true; + } + } + if (!valid) { + MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed."; + } + op->SetNotAlive(); + MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded."; + return target_op; +} + +void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, + StrategyPtr right_op_stra, const CostPtr &right_op_cost, + const CostPtrList &elimi_op_clist, + const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { + MS_EXCEPTION_IF_NULL(right_edge_cost); + MS_EXCEPTION_IF_NULL(right_op_cost); + MS_EXCEPTION_IF_NULL(left_node_clist_new); + for (auto &elimi_op_cost : elimi_op_clist) { + MS_EXCEPTION_IF_NULL(elimi_op_cost); + for (auto &left_edge_cost : left_edge_clist) { + MS_EXCEPTION_IF_NULL(left_edge_cost); + for (auto &left_node_cost : left_node_clist_origin) { + MS_EXCEPTION_IF_NULL(left_node_cost); + double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + + left_node_cost->computation_cost_ + right_edge_cost->computation_cost_; + double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ + + left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_; + double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ + + left_node_cost->communication_cost_ + right_edge_cost->communication_cost_; + double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ + + left_node_cost->communication_forward_ + right_edge_cost->communication_forward_; + double new_commu_without = + elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + + left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; + + auto decision = std::make_shared( + elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra); + auto new_cost = std::make_shared(new_computation, new_commu_cost, decision); + new_cost->communication_without_parameter_ = new_commu_without; + new_cost->communication_with_partial_para_ = + new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); + new_cost->memory_with_reuse_ = new_memory; + new_cost->communication_forward_ = new_commu_forward; + left_node_clist_new->emplace_back(std::move(new_cost)); + } + } + } +} + +void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist, + const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra, + const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra, + const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { + MS_EXCEPTION_IF_NULL(elimi_op); + for (auto &right_node_cost : right_node_clist) { + MS_EXCEPTION_IF_NULL(right_node_cost); + for (auto &right_edge_cost : right_edge_clist) { + MS_EXCEPTION_IF_NULL(right_edge_cost); + CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, + elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, + left_node_clist_new); + } + } +} + +OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, + const std::shared_ptr &edge_left_right) { + MS_EXCEPTION_IF_NULL(edge_left_right); + MS_EXCEPTION_IF_NULL(elimi_op); + MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << "."; + auto left_node = edge_left_right->prev_operator(); + auto right_node = edge_left_right->next_operator(); + auto left_edge = elimi_op->GetAliveSuccEdges()[0]; + auto right_edge = elimi_op->GetAliveSuccEdges()[1]; + MS_EXCEPTION_IF_NULL(left_node); + MS_EXCEPTION_IF_NULL(right_node); + MS_EXCEPTION_IF_NULL(left_edge); + MS_EXCEPTION_IF_NULL(right_edge); + MS_LOG(INFO) << "The left operator is: " << left_node->name() << "."; + MS_LOG(INFO) << "The right operator is: " << right_node->name() << "."; + + if (left_edge->next_operator() != left_node) { + auto tmp = left_edge; + left_edge = right_edge; + right_edge = tmp; + } + bool valid = false; + + for (auto &left_node_stra_cost : left_node->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(left_node_stra_cost); + auto left_node_stra = left_node_stra_cost->strategy_ptr; + auto left_node_clist_origin = left_node_stra_cost->cost_list; + CostPtrList left_node_clist_new; + + for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(elimi_op_stra_cost); + auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr; + auto elimi_op_clist = elimi_op_stra_cost->cost_list; + auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra); + + for (auto &right_node_stra_cost : right_node->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(right_node_stra_cost); + auto right_node_stra = right_node_stra_cost->strategy_ptr; + auto right_node_clist = right_node_stra_cost->cost_list; + auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra); + + CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra, + right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin, + &left_node_clist_new); + } + } + Simplify(&left_node_clist_new); + // Set the new costlist w.r.t the strategy + left_node_stra_cost->cost_list = left_node_clist_new; + if ((!valid) && (!left_node_clist_new.empty())) { + valid = true; + } + } + + if (!valid) { + MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed."; + } + elimi_op->SetNotAlive(); + MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded."; + return left_node; +} + +void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, + std::vector succ_nodes_stras, + CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs, + CostPtrList *first_succ_node_clist_new) { + for (auto &first_succ_node_cost : first_succ_node_clist) { + for (auto &first_succ_edge_cost : first_succ_edge_clist) { + for (auto &merged_node_cost : merged_op_clist) { + MS_EXCEPTION_IF_NULL(merged_node_cost); + succ_nodes_stras[0] = first_succ_node_stra; + succ_edges_costs[0] = first_succ_edge_cost; + succ_nodes_costs[0] = first_succ_node_cost; + + double computation_cost = merged_node_cost->computation_cost_, + memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, + commu_without = merged_node_cost->communication_without_parameter_, + commu_forward = merged_node_cost->communication_forward_; + for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { + MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); + if (i == 0) { + computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; + memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_; + commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; + commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_; + commu_without += succ_edges_costs[i]->communication_without_parameter_ + + succ_nodes_costs[i]->communication_without_parameter_; + } else { + computation_cost += succ_edges_costs[i]->computation_cost_; + memory_cost += succ_edges_costs[i]->memory_with_reuse_; + commu_cost += succ_edges_costs[i]->communication_cost_; + commu_forward += succ_edges_costs[i]->communication_forward_; + commu_without += succ_edges_costs[i]->communication_without_parameter_; + } + } + + auto decision = std::make_shared(merged_op_stra, merged_node_cost, succ_edges_costs, + succ_nodes_stras, succ_nodes_costs); + auto new_cost = std::make_shared(computation_cost, commu_cost, decision); + new_cost->communication_without_parameter_ = commu_without; + new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); + new_cost->memory_with_reuse_ = memory_cost; + new_cost->communication_forward_ = commu_forward; + first_succ_node_clist_new->emplace_back(std::move(new_cost)); + } + } + } +} + +void CostGraph::CreateStarEliminationCostList(std::vector> &succ_edges, + const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, + CostPtrList *first_succ_node_clist_new) { + std::vector succ_nodes_stras(succ_edges.size(), nullptr); + CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr); + std::function recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist, + &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs, + &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive, + this](size_t k) { + if (k == succ_edges.size()) { + CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, + merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs, + succ_nodes_costs, first_succ_node_clist_new); + return; + } + MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size() + << ", first_succ_edge_clist: " << first_succ_edge_clist.size() + << ", merged_op_clist: " << merged_op_clist.size() + << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << "."; + auto succ_edge = succ_edges[k]; + MS_EXCEPTION_IF_NULL(succ_edge); + auto succ_node = succ_edge->next_operator(); + MS_EXCEPTION_IF_NULL(succ_node); + for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(succ_node_stra_cost); + auto succ_node_stra = succ_node_stra_cost->strategy_ptr; + auto succ_node_clist = succ_node_stra_cost->cost_list; + auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra); + + for (auto &succ_node_cost : succ_node_clist) { + MS_EXCEPTION_IF_NULL(succ_node_cost); + for (auto &succ_edge_cost : succ_edge_clist) { + MS_EXCEPTION_IF_NULL(succ_edge_cost); + succ_nodes_stras[k] = succ_node_stra; + succ_edges_costs[k] = succ_edge_cost; + succ_nodes_costs[k] = succ_node_cost; + recursive(k + 1); + } + } + } + }; + + recursive(1); +} + +std::vector> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) { + MS_EXCEPTION_IF_NULL(merged_op); + auto succ_edges = merged_op->GetAliveSuccEdges(); + MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << "."; + for (auto &succ_edge : succ_edges) { + MS_EXCEPTION_IF_NULL(succ_edge->next_operator()); + MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << "."; + } + + MS_EXCEPTION_IF_NULL(succ_edges[0]); + auto first_succ_node = succ_edges[0]->next_operator(); + auto first_succ_edge = succ_edges[0]; + bool valid = false; + + // 'merged_op' is merged into first_node + MS_EXCEPTION_IF_NULL(first_succ_node); + for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost); + auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr; + auto first_succ_node_clist = first_succ_node_stra_cost->cost_list; + CostPtrList first_succ_node_clist_new; + + for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(merged_op_stra_cost); + auto merged_op_stra = merged_op_stra_cost->strategy_ptr; + auto merged_op_clist = merged_op_stra_cost->cost_list; + auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra); + + CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, + merged_op_stra, merged_op_clist, &first_succ_node_clist_new); + } + Simplify(&first_succ_node_clist_new); + // Set the new costlist w.r.t the strategy + first_succ_node_stra_cost->cost_list = first_succ_node_clist_new; + if ((!valid) && (!first_succ_node_clist_new.empty())) { + valid = true; + } + } + + if (!valid) { + MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed."; + } + + merged_op->SetNotAlive(); + MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded."; + return succ_edges; +} + +size_t CostGraph::GetNumEdges() const { + size_t sum = 0; + for (const auto &kv : edges_) { + auto &edges = kv.second; + sum += edges.size(); + } + return sum; +} +Status CostGraph::InitSelectedStrategy() { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + if (op->name().find(RESHAPEINFO) != std::string::npos) { + continue; + } + auto result = op->InitSelectedStrategy(op->selected_strategy()); + if (result != SUCCESS) { + return result; + } + } + // reshape init should be apply after the init of it's previous node and next node. + for (size_t i = 0; i < ops_.size(); ++i) { + if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) { + auto reshape_info = std::dynamic_pointer_cast(ops_[i]); + auto in_edges = GetOriginalPrevEdges(ops_[i]); + auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr edge) { + return edge->prev_operator()->name() == reshape_info->pre_operator_name(); + }); + auto out_edges = GetOriginalNextEdges(ops_[i]); + auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr edge) { + return edge->next_operator()->name() == reshape_info->next_operator_name(); + }); + if (pre_iter != in_edges.end()) { + MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name(); + int32_t pre_index = reshape_info->pre_operator_index(); + TensorInfo pre_info; + if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) { + pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index]; + } else { + pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index]; + } + reshape_info->SetInputLayout(pre_info.tensor_layout()); + Dimensions stra = pre_info.InferStrategy(); + if (stra.empty()) { + MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; + } + std::vector stra_inputs = {stra}; + StrategyPtr reshape_stra = + std::make_shared((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); + reshape_info->set_strategy(reshape_stra); + } + if (next_iter != out_edges.end()) { + MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name(); + int32_t next_index = reshape_info->next_operator_index(); + reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout()); + } + if (reshape_info->Init(nullptr) != SUCCESS) { + return FAILED; + } + } + } + return SUCCESS; +} + +Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved(); + if ((output_parameter != 0) && (output_parameter != 1)) { + MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed."; + return FAILED; + } + } + return SUCCESS; +} + +void CostGraph::DFSForTopoOrder(const OperatorInfoPtr ¤t_op, std::map *visited, + std::vector *topo_order) { + MS_EXCEPTION_IF_NULL(current_op); + MS_EXCEPTION_IF_NULL(visited); + MS_EXCEPTION_IF_NULL(topo_order); + + visited->at(current_op) = true; + for (const auto &s_edge : current_op->succ_edges()) { + if (!visited->at(s_edge->next_operator())) { + DFSForTopoOrder(s_edge->next_operator(), visited, topo_order); + } + } + topo_order->push_back(current_op); +} + +// Compute a topological order of the costgraph +void CostGraph::TopologyOrder(std::vector *topo_order) { + std::map visited; + for (auto &op : ops_) { + visited[op] = false; + } + + for (auto &op : ops_) { + if (!visited[op]) { + DFSForTopoOrder(op, &visited, topo_order); + } + } +} +void CostGraph::MarkCriticalOpsAndEdges(const std::map &candidate_ops) { + for (auto &op : ops_) { + auto search = candidate_ops.find(op); + if (search != candidate_ops.end()) { + // Mark the critical operators + op->mark_output_critical(); + // Mark the successive edges + for (auto &s_edge : op->succ_edges()) { + s_edge->mark_output_critical(); + } + } else { + op->mark_output_not_critical(); + } + } +} + +Status CostGraph::DetermineCriticalOps(const std::vector &topo_order) { + if (topo_order.size() == 0) { + MS_LOG(ERROR) << "0 operator in costgraph."; + return FAILED; + } + auto &first_op = topo_order[0]; + if (first_op->prev_edges().size() > 0) { + MS_LOG(ERROR) << "The first operator in the first of topological order of " + "costgraph should have 0 incoming edge, but has " + << first_op->prev_edges() << "edges."; + return FAILED; + } + // The 'curr_memory_state' records , where remaining_output_cnt is the number + // of the output of OperatorInfo that currently has not been used + std::map curr_memory_state; + (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size()))); + std::map max_memory_state = curr_memory_state; + // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has + // not been used + double curr_memory_size = first_op->GetOutputsTotalSize(); + double max_memory_size = curr_memory_size; + + for (size_t finished = 1; finished < topo_order.size(); ++finished) { + // Produce + (void)curr_memory_state.emplace( + std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size()))); + curr_memory_size += topo_order[finished]->GetOutputsTotalSize(); + // Consume + for (const auto &prev_edge : topo_order[finished]->prev_edges()) { + const auto &prev_op = prev_edge->prev_operator(); + curr_memory_state[prev_op]--; + } + for (const auto &prev_edge : topo_order[finished]->prev_edges()) { + const auto &prev_op = prev_edge->prev_operator(); + if (curr_memory_state[prev_op] < 0) { + MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op]; + return FAILED; + } else if (curr_memory_state[prev_op] == 0) { + curr_memory_state.erase(prev_op); + curr_memory_size -= prev_op->GetOutputsTotalSize(); + } + } + + if (curr_memory_size < 0) { + MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size; + } + // Modify the max + if (curr_memory_size > max_memory_size) { + max_memory_size = curr_memory_size; + max_memory_state = curr_memory_state; + } + } + // Mark those critical operators + MarkCriticalOpsAndEdges(max_memory_state); + return SUCCESS; +} + +Status CostGraph::ComputeOpsAndEdgesOutputCritical() { + // Two steps to do: + // 1. Compute a topological order of the costgraph + // 2. Determine and mark the operators (and necessary edges) that are critical + std::vector topo_order; + TopologyOrder(&topo_order); + std::reverse(std::begin(topo_order), std::end(topo_order)); + + if (DetermineCriticalOps(topo_order) != SUCCESS) { + MS_LOG(ERROR) << "Determining critical operators failed."; + return FAILED; + } + return SUCCESS; +} + +Status CostGraph::CalculateOpsMemoryCost() { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + if (op->CalculateMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; + return FAILED; + } + } + return SUCCESS; +} + +Status CostGraph::CalculateOpsMemoryCostForInference() { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + if (op->CalculateMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; + return FAILED; + } + } + return SUCCESS; +} + +Status CostGraph::CalculateEdgesMemoryCost() { + for (auto &edge_pair : edges_) { + const auto &edges = edge_pair.second; + for (auto &one_edge : edges) { + if (one_edge->CalculateMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; + return FAILED; + } + } + } + return SUCCESS; +} + +Status CostGraph::CalculateEdgesMemoryCostForInference() { + for (auto &edge_pair : edges_) { + const auto &edges = edge_pair.second; + for (auto &one_edge : edges) { + if (one_edge->CalculateMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; + return FAILED; + } + } + } + return SUCCESS; +} + +OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { + for (auto one_op : ops_) { + if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { + if (one_op->refkey_parameter_name() == p_name) { + return one_op; + } + } + } + return nullptr; +} +Status CostGraph::CorrectOpsMemoryCost() { + for (auto &one_op : ops_) { + if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) { + if (one_op->GetAliveSuccEdges().size() > 1) { + // Filter out the case when the TmpIdentity being used by multiple operators + std::map output_count; + for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) { + auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index(); + output_count[output_index]++; + } + for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) { + auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index(); + if (output_count[output_index] <= 1) { + continue; + } + auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator(); + MS_EXCEPTION_IF_NULL(next_op); + auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index(); + if (next_op->CorrectMemoryCost(input_index) != SUCCESS) { + MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name() + << ", the output_index: " << output_index << ", the input_index: " << input_index << "."; + return FAILED; + } + output_count[output_index]--; + } + } + } + } + return SUCCESS; +} + +Status CostGraph::CalculateMemoryCost() { + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { + // Calculate operators' memory usage + if (CalculateOpsMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed."; + return FAILED; + } + // Calculate edges' memory usage + if (CalculateEdgesMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed."; + return FAILED; + } + // Correct memory usage caused by TmpIdentity + if (CorrectOpsMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed."; + return FAILED; + } + } else { + MS_LOG(ERROR) << "Computing operators' parameter_involved failed."; + return FAILED; + } + } else { + // inference phase + if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) { + // Calculate operators' memory usage + if (CalculateOpsMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; + return FAILED; + } + // Calculate edges's memory usage + if (CalculateEdgesMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; + return FAILED; + } + } else { + MS_LOG(ERROR) << "Computing operators' critical flag failed."; + return FAILED; + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h new file mode 100644 index 0000000000..87f13e3383 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h @@ -0,0 +1,238 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ +#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ + +#include +#include +#include +#include +#include +#include "mindspore/ccsrc/common.h" +#include "common/utils.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" +#include "frontend/parallel/costmodel_context.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/ops_info/tmp_identity_info.h" + +namespace mindspore { +namespace parallel { +#define OPERATOR_TO_OPERATOR_CONNECTOR "-" +#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) +#define DEFAULT_COST_MODEL_ALPHA 1.0 +#define DEFAULT_COST_MODEL_BETA 400.0 +#define DEFAULT_COST_MODEL_GAMMA 0.001 +#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true +#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 +#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0 +#define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0 +#define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false +#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 +#define DEFAULT_FULLY_USE_DEVICES true +#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false +#define DEFAULT_IS_MULTI_SUBGRAPHS false +#define DEFAULT_RUN_PHASE 0 +#define TRAINING_PHASE 0 +#define INFERENCE_PHASE 1 + +class CostGraph; +using CostGraphPtr = std::shared_ptr; +extern CostGraphPtr entire_costgraph; +extern size_t TOTAL_OPS; +extern double COST_MODEL_GAMMA; +extern bool COST_MODEL_SIMPLIFY_CALCULATION; +extern double DEVICE_MEMORY_CAPACITY; +extern double COST_MODEL_COMMUNI_THRESHOLD; +extern double COST_MODEL_COMMUNI_CONST; +extern double COST_MODEL_COMMUNI_BIAS; +extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; +extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; +extern bool FULLY_USE_DEVICES; +extern bool ELEMENTWISE_OP_STRA_FOLLOW; +extern bool MULTI_SUBGRAPHS; +extern int32_t RUN_PHASE; + +class CostGraph { + // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have + // output-input dependency relationship. + public: + CostGraph() { + dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY; + costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; + costmodel_beta_ = DEFAULT_COST_MODEL_BETA; + } + ~CostGraph() = default; + void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } + OperatorInfoPtr FindOperatorByIndex(size_t index) { + if (index >= ops_.size()) { + MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << "."; + return nullptr; + } + return ops_[index]; + } + void RemoveOperator(const OperatorInfoPtr &op); + bool IsOperatorInCostGraph(const OperatorInfoPtr &op); + // the edge is in the form: u --> v + void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge); + std::vector> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; } + std::vector> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; } + // An edge is uniquely identified by its name, and its output index and input index. + bool IsEdgeInCostGraph(const std::string &, size_t, size_t); + + void SetDeviceMemoryAndCostParameter(); + + std::vector> ConstructConnectedComponents(std::vector); + void DFS(const OperatorInfoPtr ¤t_op, std::map *visited, + const std::shared_ptr &component); + + CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); + CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); + CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory); + CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); + CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_costlist, double memory); + Status SearchStrategyForMultiNodeFinalGraph(const std::vector &); + std::vector> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { + return edges_[{u_node, v_node}]; + } + double GetDeviceMemory() const { return dev_memory_; } + + // Search the cost_list in the final graph, and determine the optimal one + Status SearchStrategy(); + + // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated + OperatorInfoPtr CheckOpElimination() const; + // Given a graph which contains the following subgraph where there are multiple edges between u and v, these edges + // can be eliminated into one + std::vector CheckEdgeElimination() const; + // Given a graph which contains the following subgraph: + // u + // | + // w --- v --- x + // where u has 0 incoming edge, u has 1 outgoing edge, and v has > 1 incoming edges, u can be merged into v. + // u is returned. + OperatorInfoPtr CheckMergeElimination() const; + // Given a graph which contains the following subgraph: + // u + // | + // v --- x + // where v has 2 outgoing edges, and u has 1 incoming edges and no outgoing edges. In this case, u can be contracted + // into v. u is returned. + OperatorInfoPtr CheckContractElimination() const; + /* Given a graph which contains the following subgraph: + * u + * / \ + * / \ + * v --- w + * where u has 2 outgoing edges, v has 1 outgoing edge, and w has 2 incoming edges, u can be eliminated into v. + * The returned value includes u and the edge >. + */ + std::pair CheckTriangleElimination() const; + /* Given a graph which contains the following subgraph: + * v <--- u ---> w + * where u has 0 incoming edges, and multiple outgoing edges. In addition, v and w have other complicated connections, + * resulting in v and w can not be performed ContractElimination. u is returned. + * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. + */ + OperatorInfoPtr CheckStarElimination() const; + // Applying Operator Elimination in DP algorithm + EdgePtr EliminationOp(const OperatorInfoPtr &op); + // Applying Edge Elimination in DP algorithm + EdgePtr EliminationEdges(const std::vector &edges); + // Applying Merge Elimination in DP algorithm + OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op); + void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new); + // Applying Contract Elimination in DP algorithm + OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op); + void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr, + const CostPtrList &, CostPtrList *); + + // Applying Triangle Elimination in DP algorithm. return the left_node + OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right); + void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &, + const StrategyPtr &, const StrategyPtr &, const StrategyPtr &, + const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *); + // Given the relevant costlist, create the TriangleElimination cost + void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &, + const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *); + + // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op + // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. + std::vector EliminationStar(const OperatorInfoPtr &op); + void CreateStarEliminationCostList(std::vector &, const StrategyPtr &, const CostPtrList &, + const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *); + void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, + const StrategyPtr &, const CostPtrList &, std::vector, + CostPtrList &, CostPtrList &, CostPtrList *); + // Calculate memory cost for training phase or inference phase. + Status CalculateMemoryCost(); + // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then + // the memory cost can be resused. This is used to calculate memory in the training phase. + Status CalculateOpsMemoryCost(); + // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then + // the memory cost can be reused. This is used to calculate memory in the training phase. + Status CalculateEdgesMemoryCost(); + // Calculate memory cost of operators in the inference phase. + Status CalculateOpsMemoryCostForInference(); + // Calculate memory cost of edges in the inference phase. + Status CalculateEdgesMemoryCostForInference(); + Status ComputeOpsAndEdgesParameterInvolved(); + // Compute for each operator whether the output is critical. + Status ComputeOpsAndEdgesOutputCritical(); + + std::vector GetOperators() const { return ops_; } + size_t GetNumEdges() const; + Status InitSelectedStrategy(); + OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; + // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only + // once (instead of multiple times), this method is used to correct this. + Status CorrectOpsMemoryCost(); + // Needed by rec_parser + void add_inputs_tensor_name(const std::vector &inputs_tensor_name) { + inputs_tensor_name_list_.push_back(inputs_tensor_name); + } + const std::vector> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; } + void add_tuple_getitem(const std::pair &tuple_getitem) { + auto ret = tuple_getitem_list_.insert(tuple_getitem); + if (ret.second == false) { + MS_LOG(EXCEPTION) << "The insert item is already exist."; + } + } + const std::map get_tuple_getitem_list() const { return tuple_getitem_list_; } + + private: + void TopologyOrder(std::vector *); + void DFSForTopoOrder(const OperatorInfoPtr &, std::map *, std::vector *); + Status DetermineCriticalOps(const std::vector &); + void MarkCriticalOpsAndEdges(const std::map &); + // Needed by rec_parser + std::vector> inputs_tensor_name_list_; + std::map tuple_getitem_list_; + double dev_memory_; + double costmodel_alpha_; + double costmodel_beta_; + std::vector ops_; + std::map, std::vector> edges_; + std::vector> connected_compoents_; + std::map> out_edges_; + std::map> in_edges_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc new file mode 100644 index 0000000000..aaf3fdff3c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc @@ -0,0 +1,892 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/operator_costmodel.h" + +#include +#include +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +void OperatorCost::set_is_parameter(const std::vector &is_parameter) { is_parameter_ = is_parameter; } + +void OperatorCost::set_is_parameter_involve(const std::vector &is_parameter_inv) { + is_parameter_involve_ = is_parameter_inv; +} + +void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; } + +void OperatorCost::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { + inputs_type_lengths_ = input_lengths; + outputs_type_lengths_ = output_lengths; +} + +void OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; } + +double OperatorCost::GetMemoryCost(const std::vector &inputs, + const std::vector &outputs) const { + double result = 0.0; + if (output_parameter_involve_ == 1) { + // When this operator has multiple outputs, they all contributes to the memory. + for (size_t i = 0; i < outputs.size(); ++i) { + result += ListProduct(outputs[i].slice_shape()) * static_cast(outputs_type_lengths_[i]); + } + bool is_any_para_inv = + std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; }); + if (is_any_para_inv) { + for (size_t i = 0; i < inputs.size(); ++i) { + if (is_parameter_[i]) { + result += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); + } else if (inputs_related_ && (!is_parameter_involve_[i])) { + // When the inputs of this operator are related, and they are not parameter-involved, then they are included + // in the memory cost. + result += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); + } + } + } + } + + return result; +} + +double OperatorCost::GetMemoryCostForInference(const std::vector &, + const std::vector &outputs) const { + double result = 0.0; + if (is_outputs_critical_ == -1) { + MS_LOG(EXCEPTION) << "The critical flag is not set."; + } + if (is_outputs_critical_ == 1) { + for (size_t i = 0; i < outputs.size(); ++i) { + result += ListProduct(outputs[i].slice_shape()) * static_cast(outputs_type_lengths_[i]); + } + } + return result; +} + +// return the per device communication cost in the forward phase. +double MatMulCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const { + TensorInfo input0 = inputs[0]; + TensorInfo output0 = outputs[0]; + Shape input0_shape = input0.shape(); + Shape input0_slice_shape = input0.slice_shape(); + if (input0_shape[input0_shape.size() - 1] == input0_slice_shape[input0_slice_shape.size() - 1]) { + // If the reduced dimension has not been partitioned, then there is no communication cost. + return 0.0; + } else { + // Else, the communication cost is the size (number of bytes) of a slice of output tensor. + return ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + } +} + +// return the per device communication cost in the forward phase. +double MatMulCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not + // fully utilize all devices + double result = 0.0; + if (is_parameter_[1]) { + TensorInfo input1 = inputs[1]; // tensor B + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double MatMulCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { + // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) + double result = 0.0; + TensorInfo output0 = outputs[0]; + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + Shape input0_shape = inputs[0].shape(); + if (input0_shape[input0_shape.size() - 1] != input0_slice_shape[input0_slice_shape.size() - 1]) { + // If the reduced dimension has been partitioned, then there is no communication cost. + result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + } + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double MatMulCost::GetBackwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) + double result = 0.0; + if (is_parameter_[1]) { + TensorInfo input1 = inputs[1]; // tensor B + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + + return result; +} + +// Return the per device communication cost in the forward phase. +double ActivationCost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { + // ReLU is the element-wise operator, thus it does not need communication in the forward phase + return 0.0; +} + +// Return the per device communication cost in the backward phase. +double ActivationCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input1 = inputs[0]; + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double ActivationCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + TensorInfo input0_info = inputs[0]; + Shape input0_slice_shape = input0_info.slice_shape(); + return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double ActivationCost::GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const { + return 0.0; +} + +// Return the per device communication cost in the forward phase. +double SoftmaxCost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { + // In the forward phase, the communication cost = 0 + return 0.0; +} + +// Return the per device communication cost in the backward phase. +double SoftmaxCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input1 = inputs[0]; + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double SoftmaxCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + // In the forward phase, the computation cost = slice(A) + TensorInfo input0 = inputs[0]; + Shape input0_slice_shape = input0.slice_shape(); + return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double SoftmaxCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { + return 0.0; +} + +// return the per device communication cost in the forward phase. +double TmpIdentityCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { + // Identity is the element-wise operator, thus it does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double TmpIdentityCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { + // Identity is the element-wise operator, thus it does not need communication in the backward phase + return 0.0; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double TmpIdentityCost::GetForwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { + return 0.0; +} + +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes +// this operator uses +double TmpIdentityCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, + int32_t) const { + return 0.0; +} + +// Return the per device PEAK memory cost contributed by this operator in a training iteration. +double TmpIdentityCost::GetMemoryCost(const std::vector &, const std::vector &) const { + return 0.0; +} + +double BatchParallelCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, + int32_t) const { + double cost = 0.0; + for (size_t i = 0; i < inputs.size(); ++i) { + cost += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); + } + return cost; +} + +double BatchParallelCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, + int32_t) const { + return 0.0; +} + +double BatchParallelCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t j = 0; j < inputs.size(); ++j) { + if (!is_parameter_[j]) { + continue; + } + TensorInfo input_a_tensor_info = inputs[j]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + + return result; +} +// return the per device communication cost in the forward phase. +double PReLUCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { + // prelu does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double PReLUCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[1]) { + TensorInfo input1 = inputs[1]; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double PReLUCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + return result; +} + +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes +// this operator uses +double PReLUCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, + int32_t stage_id) const { + // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) + double result = 0.0; + if (is_parameter_[1]) { + TensorInfo input1 = inputs[1]; // tensor B + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// return the per device communication cost in the forward phase. +double OneHotCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { + // onehot does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double OneHotCost::GetBackwardCommCost(const std::vector &, const std::vector &, + int32_t) const { + // onehot does not need communication in the backward phase + return 0.0; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double OneHotCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + // In onehot's forward phase, the computation cost = slice(A) + Shape input0_slice_shape = inputs[0].slice_shape(); + return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); +} + +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes +// this operator uses +double OneHotCost::GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const { + return 0.0; +} + +// return the per device communication cost in the forward phase. +double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { + // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { + // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase + return 0.0; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + return result; +} + +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes +// this operator uses +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { + return 0.0; +} + +// return the per device communication cost in the forward phase. +double ReshapeCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); + TensorRedistribution tensor_redistribution(false, true); + if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; + } + if (tensor_redistribution.ComputeCost() == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; + } + return (inputs_type_lengths_[0] * tensor_redistribution.comm_cost()); +} + +// return the per device communication cost in the backward phase. +double ReshapeCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input1 = inputs[0]; + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double ReshapeCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); + TensorRedistribution tensor_redistribution(false, true); + if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; + } + if (tensor_redistribution.ComputeCost() == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; + } + return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost()); +} + +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes +// this operator uses +double ReshapeCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { + return 0.0; +} + +double ArithmeticCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + double result; + result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + + ListProduct(inputs[1].slice_shape()) * static_cast(inputs_type_lengths_[1]); + return result; +} + +double ArithmeticCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + if (is_parameter_[0]) { + TensorInfo input_a_tensor_info = inputs[0]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + + if (is_parameter_[1]) { + TensorInfo input_b_tensor_info = inputs[1]; + Shape input_b_shape = input_b_tensor_info.shape(); + Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_b_shape.size(); ++i) { + used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + return result; +} + +double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + if (is_parameter_[0]) { + TensorInfo input_a_tensor_info = inputs[0]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + + if (is_parameter_[1]) { + TensorInfo input_b_tensor_info = inputs[1]; + Shape input_b_shape = input_b_tensor_info.shape(); + Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_b_shape.size(); ++i) { + used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + + return result; +} + +bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int32_t stage_id) { + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + auto strategy0 = shape[0] / slice_shape[0]; + + return (total_device_num == IntToSize(strategy0)); +} + +double ReduceMethodCost::GetForwardCommCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + TensorInfo input0 = inputs[0]; + TensorInfo output0 = outputs[0]; + Shape input0_shape = input0.shape(); + Shape input0_slice_shape = input0.slice_shape(); + if (cross_batch_ && IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { + return result; + } + std::vector dim_list = input0.reduce_dim(); + std::vector::iterator pos; + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { + return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; + }); + if (pos != dim_list.end()) { + result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + } + + return result; +} + +double ReduceMethodCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input_tensor_info = inputs[0]; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + Shape input_shape = input_tensor_info.shape(); + Shape input_slice_shape = input_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_shape.size(); ++i) { + used_device_num *= input_shape[i] / input_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + + return result; +} + +double ReduceMethodCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + TensorInfo input0 = inputs[0]; + TensorInfo output0 = outputs[0]; + std::vector dim_list = input0.reduce_dim(); + Shape input0_slice_shape = input0.slice_shape(); + Shape input0_shape = input0.shape(); + if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { + std::vector::iterator pos; + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { + return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; + }); + if (pos != dim_list.end()) { + result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + } + } + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); + + return result; +} + +double ReduceMeanCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + TensorInfo input0 = inputs[0]; + TensorInfo output0 = outputs[0]; + std::vector dim_list = input0.reduce_dim(); + Shape input0_slice_shape = input0.slice_shape(); + Shape input0_shape = input0.shape(); + if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { + std::vector::iterator pos; + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { + return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; + }); + if (pos != dim_list.end()) { + result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]) * 2.0; + } + } + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); + + return result; +} + +double DropOutCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + if (inputs.empty()) { + return 0.0; + } + TensorInfo input0 = inputs[0]; + Shape input0_slice_shape = input0.slice_shape(); + return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * DROPOUT_COST_RATE; +} + +// return the per device communication cost in the forward phase. +double GatherV2Cost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { + // GatherV2Cost does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double GatherV2Cost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t j = 0; j < inputs.size(); ++j) { + if (!is_parameter_[j]) { + continue; + } + TensorInfo input_a_tensor_info = inputs[j]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + + return result; +} + +double GatherV2Cost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + return result; +} + +double GatherV2Cost::GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const { + return 0.0; +} + +double LayerNormCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost"; + } + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost"; + } + + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t index = 0; index < inputs.size(); ++index) { + if (is_parameter_[index]) { + TensorInfo tensor_info = inputs[index]; + Shape shape = tensor_info.shape(); + Shape slice_shape = tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (slice_shape[i] == 0) { + MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape); + } + used_device_num *= shape[i] / slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(slice_shape) * static_cast(inputs_type_lengths_[index]); + } + } + } + return result; +} + +double LayerNormCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + double result = 0.0; + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost"; + } + + for (size_t index = 0; index < inputs.size(); ++index) { + TensorInfo tensor_info = inputs[index]; + Shape slice_shape = tensor_info.slice_shape(); + result += ListProduct(slice_shape) * static_cast(inputs_type_lengths_[index]); + } + return result; +} + +double GatherV2PCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + double result = 0.0; + if (outputs_type_lengths_.size() != outputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; + } + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + return result; + } + + // split axis + auto param_shape = inputs[0].slice_shape(); + auto index_shape = inputs[1].slice_shape(); + Shape reducescatter_shape = index_shape; + if (param_shape.size() == 2) { + reducescatter_shape.push_back(param_shape.at(1 - axis_)); + } + result += ListProduct(reducescatter_shape) * static_cast(outputs_type_lengths_[0]); + return result; +} + +double GatherV2PCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t j = 0; j < inputs.size(); ++j) { + if (!is_parameter_[j]) { + continue; + } + TensorInfo input_a_tensor_info = inputs[j]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + return result; +} + +double GatherV2PCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; + } + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } else { + // split axis + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1; + } + + return result; +} + +double GatherV2PCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { + double result = 0.0; + Shape input1_slice_shape = inputs[1].slice_shape(); + Shape output0_slice_shape = outputs[0].slice_shape(); + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + result += ListProduct(output0_slice_shape) * static_cast(inputs_type_lengths_[0]); + } else { + // split axis + result += ListProduct(output0_slice_shape) * static_cast(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3; + } + + return result; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h new file mode 100644 index 0000000000..dda597bd1f --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h @@ -0,0 +1,656 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ +#define PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ + +#include +#include +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_info.h" + +namespace mindspore { +namespace parallel { +#define MAXIMUM_INPUT_NUMBER 100 +#define DEFAULT_DATA_TYPE_LENGTH 4 +#define DROPOUT_COST_RATE 1.125 // the DropoutGenMask need 12.5% memory +#define GATHERV2_COST_WEIGHT0 3 +#define GATHERV2_COST_WEIGHT1 7 +#define GATHERV2_COST_WEIGHT2 2 +#define GATHERV2_COST_WEIGHT3 6 + +class OperatorCost; +using OperatorCostPtr = std::shared_ptr; + +template +double ListProduct(std::vector vec) { + double result = 1; + for (size_t i = 0; i < vec.size(); ++i) { + result *= vec[i]; + } + return result; +} +// NOTE: Currently, the returned value in each method is bytes of memory size, which is calculated by the number of +// entries timing the length of each entry's data type +class OperatorCost { + public: + explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) { + // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked + for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) { + is_parameter_.push_back(false); + is_parameter_involve_.push_back(false); + inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); + outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); + } + } + OperatorCost() : inputs_related_(false) { + // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked + for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) { + is_parameter_.push_back(false); + is_parameter_involve_.push_back(false); + inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); + outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); + } + } + virtual ~OperatorCost() = default; + + void set_is_parameter(const std::vector &is_parameter); + void set_is_parameter_involve(const std::vector &); + void set_output_parameter_involve(int); + void set_output_critical(int); + void SetInputAndOutputTypeLength(const std::vector &input_lengths, const std::vector &output_lengths); + std::vector inputs_type_lengths() const { return inputs_type_lengths_; } + std::vector outputs_type_lengths() const { return outputs_type_lengths_; } + + // per device communication cost + virtual double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const = 0; + virtual double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const = 0; + virtual double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const = 0; + // per device computation cost + virtual double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const = 0; + virtual double GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const = 0; + virtual double GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const = 0; + // per device PEAK memory cost in a training iteration + // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), + // plus necessary inputs. + virtual double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const; + // per device memory cost in a inference phase + double GetMemoryCostForInference(const std::vector &, const std::vector &) const; + + protected: + // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of + // pre-operator that has parameters as input. + std::vector is_parameter_involve_; + int output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved + // Whether the inputs are related or not? For example, TensorAdd's two inputs are independent (not related), while + // Mul's two inputs are dependent (related). + bool inputs_related_; + // for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter + std::vector is_parameter_; + // for each input and output, the followings record the number of bytes of each element + std::vector inputs_type_lengths_; + std::vector outputs_type_lengths_; + // Whether the output is critical, which means that this output is included in calculating peak memory cost + // in the inference phase. + int is_outputs_critical_ = -1; +}; + +using OperatorCostPtr = std::shared_ptr; + +class MatMulCost : public OperatorCost { + public: + explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + MatMulCost() : OperatorCost(true) {} + ~MatMulCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using MatMulCostPtr = std::shared_ptr; + +class ActivationCost : public OperatorCost { + public: + explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ActivationCost() : OperatorCost(false) {} + ~ActivationCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using ActivationCostPtr = std::shared_ptr; +using TransposeCost = ActivationCost; +using TransposeCostPtr = std::shared_ptr; + +class SoftmaxCost : public OperatorCost { + public: + explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + SoftmaxCost() : OperatorCost(false) {} + ~SoftmaxCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const override; +}; +using SoftmaxCostPtr = std::shared_ptr; + +class TmpIdentityCost : public OperatorCost { + public: + explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + TmpIdentityCost() : OperatorCost(false) {} + ~TmpIdentityCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + // per device PEAK memory cost in a training iteration + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override; +}; +using TmpIdentityCostPtr = std::shared_ptr; + +class BatchParallelCost : public OperatorCost { + public: + explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + BatchParallelCost() : OperatorCost(false) {} + ~BatchParallelCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using BatchParallelCostPtr = std::shared_ptr; + +class VirtualDatasetCost : public OperatorCost { + public: + explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + VirtualDatasetCost() : OperatorCost(false) {} + ~VirtualDatasetCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } + // per device PEAK memory cost in a training iteration + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override { + return 0.0; + } +}; +using VirtualDatasetCostPtr = std::shared_ptr; + +class GeneratorBaseCost : public OperatorCost { + public: + explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + GeneratorBaseCost() : OperatorCost(false) {} + ~GeneratorBaseCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + // Inputs vector is empty for generator ops. + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } + // Generator ops don't have backward steps. + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } +}; +using GeneratorBaseCostPtr = std::shared_ptr; + +class PReLUCost : public OperatorCost { + public: + explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + PReLUCost() : OperatorCost(true) {} + ~PReLUCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using PReLUCostPtr = std::shared_ptr; + +class OneHotCost : public OperatorCost { + public: + explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + OneHotCost() : OperatorCost(true) {} + ~OneHotCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using OneHotCostPtr = std::shared_ptr; + +class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { + public: + explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {} + ~SoftmaxCrossEntropyWithLogitsCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; + +class ReshapeCost : public OperatorCost { + public: + explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ReshapeCost() : OperatorCost(true) {} + + ~ReshapeCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using ReshapeCostPtr = std::shared_ptr; + +class ArithmeticCost : public OperatorCost { + public: + explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ArithmeticCost() : OperatorCost(false) {} + ~ArithmeticCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using ArithmeticCostPtr = std::shared_ptr; +using BiasAddCost = ArithmeticCost; +using BiasAddCostPtr = std::shared_ptr; + +class ReduceMethodCost : public OperatorCost { + public: + explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ReduceMethodCost() : OperatorCost(true) {} + ~ReduceMethodCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } + void set_cross_batch(bool cb) { cross_batch_ = cb; } + + protected: + bool cross_batch_ = false; +}; +using ReduceMethodCostPtr = std::shared_ptr; + +class ReduceMeanCost : public ReduceMethodCost { + public: + explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {} + ReduceMeanCost() : ReduceMethodCost(true) {} + ~ReduceMeanCost() override = default; + + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using ReduceMeanCostPtr = std::shared_ptr; + +class GetNextCost : public OperatorCost { + public: + explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + GetNextCost() : OperatorCost(false) {} + ~GetNextCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + // Inputs vector is empty for generator ops. + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } + // Generator ops don't have backward steps. + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } +}; +using GetNextCostPtr = std::shared_ptr; + +class DropOutCost : public OperatorCost { + public: + explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + DropOutCost() : OperatorCost(true) {} + ~DropOutCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override; + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } +}; + +using DropOutCostPtr = std::shared_ptr; + +class LayerNormCost : public OperatorCost { + public: + explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + LayerNormCost() : OperatorCost(true) {} + ~LayerNormCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override; + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } +}; + +using DropOutCostPtr = std::shared_ptr; + +class GatherV2Cost : public OperatorCost { + public: + explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + GatherV2Cost() : OperatorCost(true) {} + ~GatherV2Cost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const override; +}; + +using GatherV2CostPtr = std::shared_ptr; + +class GatherV2PCost : public OperatorCost { + public: + explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {} + GatherV2PCost() : OperatorCost(true), axis_(0) {} + ~GatherV2PCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const override; + void set_axis(int32_t axis) { axis_ = axis; } + void set_strategy(const Shape &strategy) { strategy_ = strategy; } + + protected: + int32_t axis_; + Shape strategy_; +}; + +using GatherV2PCostPtr = std::shared_ptr; +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.cc new file mode 100644 index 0000000000..0a7e6c59d4 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.cc @@ -0,0 +1,750 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/rec_core/rec_cost.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" + +namespace mindspore { +namespace parallel { + +// Compute redistributed cost +double CostRedis(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const std::vector> &mode, const Graph &graph) { + // Store value of cost redist + double cost_redis = 0; + + // Number of current strategies. + size_t num_strategy = node_name_to_strategy.size(); + + // Number of node-in and node-out + size_t num_node_in = node.node_in.size(); + size_t num_node_out = node.node_out.size(); + + // Set tensor edge value with original tensor shape and cutting times. + double input_tensor = node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n * + node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c * + node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h * + node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w; + + double output_tensor = node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n * + node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c * + node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h * + node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w; + + // For each strategy candidate. + for (size_t i_strategy = 0; i_strategy < num_strategy; i_strategy++) { + // Find its forward nodes + for (size_t i_node = 0; i_node < num_node_in; i_node++) { + if (graph.nodes[node.node_in[i_node]].name == node_name_to_strategy[i_strategy].first) { + bool is_search_forward = true; + cost_redis += + CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, input_tensor, is_search_forward); + } + } + + // Find its backward nodes + for (size_t i_node = 0; i_node < num_node_out; i_node++) { + if (graph.nodes[node.node_out[i_node]].name == node_name_to_strategy[i_strategy].first) { + bool is_search_forward = false; + cost_redis += + CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, output_tensor, is_search_forward); + } + } + } + + return cost_redis; +} + +double CostRedisWithAdjacentNode(const std::vector> &node_name_to_strategy, + const std::vector> &mode, size_t i_strategy, size_t i_node, + double tensor_size, bool search_forward) { + double new_redis_cost = 0; + int counter = 0; + + if (search_forward) { + if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_n) != + static_cast(1 / mode[i_node][0])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_c) != + static_cast(1 / mode[i_node][1])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_h) != + static_cast(1 / mode[i_node][2])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_w) != + static_cast(1 / mode[i_node][3])) { + counter += 1; + } + } else { + if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_n) != + static_cast(1 / mode[2][0])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_c) != + static_cast(1 / mode[2][1])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_h) != + static_cast(1 / mode[2][2])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_w) != + static_cast(1 / mode[2][3])) { + counter += 1; + } + } + + if (counter >= 2) { + new_redis_cost = tensor_size / 4.0; + } else if (counter == 0 || counter == 1) { + new_redis_cost = 0; + } else { + MS_LOG(EXCEPTION) << "Failure: CostRedis failed."; + } + + return new_redis_cost; +} + +// Get optimal strategy for MatMul +StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph) { + int edge_i = + static_cast(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h); + int edge_j = + static_cast(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w); + int edge_k = + static_cast(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w); + + std::vector cost_op; + std::vector> mode; + + if (edge_i < 2 || edge_i % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrConcatDimI(edge_j, edge_k) + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}}, + graph)); + } + + if (edge_j < 2 || edge_j % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrConcatDimJ(edge_i, edge_k) + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}}, + graph)); + } + + if (edge_k < 2 || edge_k % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrReduceDimK(edge_i, edge_j) + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}}, + graph)); + } + + return ChoseStr(cost_op, node.apply.str); +} + +// Get weight for MatMul +double CostMatMul::GetMinCostIn(const OperatorRec &op) { + int edge_i = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); + int edge_j = static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w); + int edge_k = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); + + std::vector cost_in; + cost_in.push_back(StrConcatDimI(edge_j, edge_k)); + cost_in.push_back(StrConcatDimJ(edge_i, edge_k)); + cost_in.push_back(StrReduceDimK(edge_i, edge_j)); + + return *min_element(cost_in.begin(), cost_in.end()); +} + +// Chose strategy for MatMul +StrategyRec CostMatMul::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_i_; + break; + + case 1: + str.inputTensor[1].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_j_; + break; + + case 2: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_k_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure:CostMatMul failed."; + } + + return str; +} + +// Get optimal strategy for Conv +StrategyRec CostConvolution::GetOptimalStr( + const Graph::NodeType &node, const std::vector> &node_name_to_strategy, + const Graph &graph, bool channel_partition) { + const OperatorRec &op = node.apply; + + int input_tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); + int input_tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); + int input_tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); + int input_tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); + + int tensor_in = input_tensor_h * input_tensor_w * input_tensor_n * input_tensor_c; + + int tensor_filter_h = static_cast(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h); + int tensor_filter_w = static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w); + int tensor_filter_n = static_cast(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n); + int tensor_filter_c = static_cast(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c); + + int tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c; + + int output_tensor_h = static_cast(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h); + int output_tensor_w = static_cast(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w); + int output_tensor_n = static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n); + int output_tensor_c = static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); + + int tensor_out = output_tensor_h * output_tensor_w * output_tensor_n * output_tensor_c; + + std::vector cost_op; + cost_op.reserve(7); + std::vector> mode; + + if (input_tensor_n < 2 || input_tensor_n % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy, + mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); + } + + cost_op.push_back(DOUBLE_MAX); + cost_op.push_back(DOUBLE_MAX); + + if (channel_partition == false || tensor_filter < 2 || tensor_filter % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrDimK(tensor_in) + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 1, 1}, {0.5, 1, 1, 1}, {1, 0.5, 1, 1}}, graph)); + } + + cost_op.push_back(DOUBLE_MAX); + cost_op.push_back(DOUBLE_MAX); + + if (channel_partition == false || tensor_filter_c < 2 || tensor_filter_c % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrDimQ(tensor_out) + CostRedis(node, node_name_to_strategy, + mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 1, 1, 1}}, graph)); + } + + return ChoseStr(cost_op, node.apply.str); +} + +// Get weight for Conv +double CostConvolution::GetMinCostIn(const Graph::NodeType &node) { + const OperatorRec &op = node.apply; + + int tensor_in = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) * + static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) * + static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) * + static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); + int tensor_filter = static_cast(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h) * + static_cast(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) * + static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) * + static_cast(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c); + int tensor_out = static_cast(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) * + static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) * + static_cast(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) * + static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); + + std::vector cost_in; + cost_in.push_back(StrDimB(tensor_filter)); + cost_in.push_back(StrDimI(tensor_in, tensor_filter)); + cost_in.push_back(StrDimJ(tensor_in, tensor_filter)); + cost_in.push_back(StrDimK(tensor_in)); + cost_in.push_back(StrDimDI(tensor_in, tensor_out)); + cost_in.push_back(StrDimDJ(tensor_in, tensor_out)); + cost_in.push_back(StrDimQ(tensor_out)); + + return *min_element(cost_in.begin(), cost_in.end()); +} + +// Chose strategy for Conv +StrategyRec CostConvolution::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_b_; + break; + + case 1: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_i_; + break; + + case 2: + str.inputTensor[0].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_j_; + break; + + case 3: + str.inputTensor[1].str_n /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_k_; + break; + + case 4: + str.inputTensor[1].str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_di_; + break; + + case 5: + str.inputTensor[1].str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_dj_; + break; + + case 6: + str.inputTensor[0].str_c /= 2.0; + str.inputTensor[1].str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_q_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostConvolution failed."; + } + return str; +} + +// Get optimal strategy for Pooling +StrategyRec CostPooling::GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph) { + int tensor_n = static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n); + int tensor_c = static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); + + std::vector cost_op; + std::vector> mode; + + if (tensor_n < 2 || tensor_n % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); + } + + if (tensor_c < 2 || tensor_c % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph)); + } + + cost_op.push_back(DOUBLE_MAX); + cost_op.push_back(DOUBLE_MAX); + + return ChoseStr(cost_op, node.apply.str); +} + +// Chose strategy for Pooling +StrategyRec CostPooling::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostPooling failed."; + } + return str; +} + +// Chose strategy for Add +StrategyRec CostTensorAdd::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.inputTensor[1].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.inputTensor[1].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.inputTensor[1].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostAdd failed."; + } + return str; +} + +// Get optimal strategy for Reshape +StrategyRec CostReshape::GetOptimalStr(const Graph::NodeType &node) const { return ChoseStr(node.apply.str); } + +StrategyRec CostReshape::ChoseStr(StrategyRec str) const { return str; } + +// Chose strategy for BiasAdd +StrategyRec CostBiasAdd::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed."; + } + return str; +} + +// Get optimal strategy for Common OPs +StrategyRec CostCommon::GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph) { + const OperatorRec &op = node.apply; + int tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); + int tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); + int tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); + int tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); + + std::vector cost_op; + std::vector> mode; + + if (tensor_n < 2 || tensor_n % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); + } + + if (tensor_c < 2 || tensor_c % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph)); + } + + if (tensor_h < 2 || tensor_h % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 0.5, 1}, {1, 1, 0.5, 1}, {1, 1, 0.5, 1}}, graph)); + } + + if (tensor_w < 2 || tensor_w % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 1, 0.5}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}}, graph)); + } + + return ChoseStr(cost_op, node.apply.str); +} + +// Chose strategy for Common op +StrategyRec CostCommon::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: Common failed."; + } + return str; +} + +// Get optimal strategy for BatchParallel OPs +StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) { + const OperatorRec &op = node.apply; + int tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); + int tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); + int tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); + int tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); + + std::vector cost_op; + + if (tensor_n < 2 || tensor_n % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_); + } + + if (tensor_c < 2 || tensor_c % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_); + } + + if (tensor_h < 2 || tensor_h % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_); + } + + if (tensor_w < 2 || tensor_w % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_); + } + + return ChoseStr(cost_op, node.apply.str); +} + +// Chose strategy for BatchParallel op +StrategyRec CostBatchParallel::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed."; + } + return str; +} + +// Chose strategy for CostSoftmaxCrossEntropyWithLogits +StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.inputTensor[1].str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.inputTensor[1].str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.inputTensor[1].str_h /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed."; + } + return str; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.h new file mode 100644 index 0000000000..563bf4598a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.h @@ -0,0 +1,233 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_COST_H_ +#define PARALLEL_AUTO_PARALLEL_REC_COST_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h" + +namespace mindspore { +namespace parallel { +#define DOUBLE_MAX (std::numeric_limits::max)() + +double CostRedis(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const std::vector> &mode, const Graph &graph); + +double CostRedisWithAdjacentNode(const std::vector> &node_name_to_strategy, + const std::vector> &mode, size_t i_strategy, size_t i_node, + double tensor_size, bool is_search_forward); + +// class CostMatMul is used to compute the cost of MatMul operator. +class CostMatMul { + public: + StrategyRec GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph); + + double GetMinCostIn(const OperatorRec &op); + + private: + double StrConcatDimI(int32_t a, int32_t b) { + cost_in_i_ = (static_cast(a) * static_cast(b)) / 2.0; + + return cost_in_i_; + } + + double StrConcatDimJ(int32_t a, int32_t b) { + cost_in_j_ = (static_cast(a) * static_cast(b)) / 2.0; + + return cost_in_j_; + } + + double StrReduceDimK(int32_t a, int32_t b) { + cost_in_k_ = (static_cast(a) * static_cast(b)) / 2.0; + + return cost_in_k_; + } + + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_i_ = 0; + + double cost_in_j_ = 0; + + double cost_in_k_ = 0; +}; // class CostMatMul is used to compute the cost of MatMul operator. + +// class CostConvolution is used to compute the cost of Conv operator. +class CostConvolution { + public: + StrategyRec GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph, bool channel_partition); + + double GetMinCostIn(const Graph::NodeType &node); + + private: + double StrDimB(int32_t TensorFilter) { + cost_in_b_ = static_cast((TensorFilter) / 2.0); + + return cost_in_b_; + } + + double StrDimI(int32_t TensorIn, int32_t TensorFilter) { + cost_in_i_ = static_cast((TensorIn + TensorFilter) / 2.0); + + return cost_in_i_; + } + + double StrDimJ(int32_t TensorIn, int32_t TensorFilter) { + cost_in_j_ = static_cast((TensorIn + TensorFilter) / 2.0); + + return cost_in_j_; + } + + double StrDimK(int32_t TensorIn) { + cost_in_k_ = static_cast((TensorIn) / 2.0); + + return cost_in_k_; + } + + double StrDimDI(int32_t TensorIn, int32_t TensorOut) { + cost_in_di_ = static_cast((TensorIn + TensorOut) / 2.0); + + return cost_in_di_; + } + + double StrDimDJ(int32_t TensorIn, int32_t TensorOut) { + cost_in_dj_ = static_cast((TensorIn + TensorOut) / 2.0); + + return cost_in_dj_; + } + + double StrDimQ(int32_t TensorOut) { + cost_in_q_ = static_cast((TensorOut) / 2.0); + + return cost_in_q_; + } + + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_b_ = 0; + + double cost_in_i_ = 0; + + double cost_in_j_ = 0; + + double cost_in_k_ = 0; + + double cost_in_di_ = 0; + + double cost_in_dj_ = 0; + + double cost_in_q_ = 0; +}; // class CostConvolution is used to compute the cost of Conv operator. + +// class CostPooling is used to compute the cost of Pooling operator. +class CostPooling { + public: + StrategyRec GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph); + + double GetMinCostIn() const { return cost_in_; } + + private: + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_ = 0; +}; // class CostPooling is used to compute the cost of Pooling operator. + +// class CostReshape is used to compute the cost of Reshape operator. +class CostReshape { + public: + StrategyRec GetOptimalStr(const Graph::NodeType &node) const; + + double GetMinCostIn() const { return cost_in_; } + + private: + StrategyRec ChoseStr(StrategyRec str) const; + + double cost_in_ = 0; +}; // class CostReshape is used to compute the cost of Reshape operator. + +// class CostCommon is used to compute the cost of an element-wise operator +class CostCommon { + public: + virtual StrategyRec GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph); + + virtual double GetMinCostIn() const { return cost_in_; } + + protected: + virtual StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_ = 0; +}; // class CostCommon is used to compute the cost of an element-wise operator + +// class CostBiasAdd is used to compute the cost of the addition between a tensor and a bias +class CostBiasAdd : public CostCommon { + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); +}; +// class CostAdd is used to compute the cost of Add operator. +class CostTensorAdd : public CostCommon { + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); +}; + +// all the following operation are element-wise and have the same cost +class CostReLU : public CostCommon {}; +class CostLog : public CostCommon {}; +class CostExp : public CostCommon {}; +class CostAdd : public CostCommon {}; +class CostSub : public CostCommon {}; +class CostMul : public CostCommon {}; +class CostDiv : public CostCommon {}; +class CostSqueeze : public CostCommon {}; +class CostCast : public CostCommon {}; + +// class BatchParallel is used to compute the cost of BatchParallel operator. +class CostBatchParallel { + public: + virtual StrategyRec GetOptimalStr(const Graph::NodeType &node); + + virtual double GetMaxCostIn() const { return DOUBLE_MAX; } + + protected: + virtual StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_ = 0; +}; // class BatchParallel is used to compute the cost of BatchParallel operator. + +class CostBatchNorm : public CostBatchParallel {}; +class CostOneHot : public CostBatchParallel {}; +class CostPRelu : public CostBatchParallel {}; +class CostSoftmax : public CostBatchParallel {}; + +class CostSoftmaxCrossEntropyWithLogits : public CostBatchParallel { + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); +}; +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_REC_COST_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc new file mode 100644 index 0000000000..68b776155a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -0,0 +1,837 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h" + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, + const std::shared_ptr>> &eli_list, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(eli_list); + MS_EXCEPTION_IF_NULL(index_list); + GeneratePartitionedOperatorStrategy(graph, ops, index_list); + std::shared_ptr> no_stra_op_list(new std::vector); + for (size_t i = 0; i < eli_list->size(); i++) { + no_stra_op_list->push_back(eli_list->at(i)[0]); + } + GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); + GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); + GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); +} + +std::vector> PrepareMatMul(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + std::vector> strategies; + auto attrs = ops[iter_ops]->attrs(); + bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); + bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); + + // HCCL does not support multi-dimension partition, and the hardware does not support excessive + // number of EVENT, so we temporarily disable matmul's multi-dimension partition function. + const auto max_cut = 1.0 / g_device_manager->DeviceNum(); + if (graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h != max_cut && + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w != max_cut) { + graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = 1.0; + graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_w = 1.0; + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_h = 1.0; + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; + + auto shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; + if (transpose_a) { + shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[1]; + } + auto shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[1]; + if (transpose_b) { + shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[0]; + } + + bool already_cut = false; + if (shape_1 >= shape_4) { + if (shape_1 % g_device_manager->DeviceNum() == 0) { + graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut; + already_cut = true; + } + if (!already_cut && shape_4 % g_device_manager->DeviceNum() == 0) { + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut; + already_cut = true; + } + } else { + if (shape_4 % g_device_manager->DeviceNum() == 0) { + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut; + already_cut = true; + } + if (!already_cut && shape_1 % g_device_manager->DeviceNum() == 0) { + graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut; + already_cut = true; + } + } + + if (!already_cut) { + MS_LOG(EXCEPTION) << "Failure: MatMul's shape is invalid."; + } + } + + for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { + std::vector s; + if (transpose_a && (iter_op_inputs == 0)) { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + } else if (transpose_b && (iter_op_inputs == 1)) { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + } else { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + } + strategies.push_back(s); + } + return strategies; +} + +std::vector> PrepareBiasAdd(const std::shared_ptr> &s) { + std::vector> strategies; + strategies.push_back(*s); + std::vector s_biasadd; + s_biasadd.push_back(s->at(1)); + strategies.push_back(s_biasadd); + return strategies; +} + +std::vector> PrepareOneHot(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + std::vector> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); + + int32_t axis = -1; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter != ops[iter_ops]->attrs().end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis = iter->second->cast()->value(); + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int."; + } + } + if (axis == -1) { + strategies[0][0] = strategies[0][1]; + strategies[0][1] = 1; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; + } + + std::vector s_empty = {}; + strategies.push_back(s_empty); + strategies.push_back(s_empty); + return strategies; +} + +std::vector> PrepareGatherV2(const std::vector> &ops, + const size_t iter_ops, std::vector s) { + std::vector> strategies; + + auto axis_input = GetValue(ops[iter_ops]->input_value().at(2)); + if (axis_input < 0) { + axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); + } + int32_t axis = axis_input; + if (axis >= SizeToInt(s.size())) { + MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; + } + s[axis] = 1; + strategies.push_back(s); + + auto pos = ops[iter_ops]->name().find("Info"); + auto name = ops[iter_ops]->name().substr(0, pos); + if (name == "GatherV2") { + return strategies; + } + + std::vector s_indices; + for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { + s_indices.push_back(1); + } + strategies.push_back(s_indices); + + return strategies; +} + +std::vector> PrepareL2Normalize(const std::vector> &ops, + const size_t iter_ops, std::vector s) { + int32_t axis = 0; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter != ops[iter_ops]->attrs().end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis = iter->second->cast()->value(); + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " : The value of axis is not int."; + } + } + + int32_t axis_index = axis; + if (axis < 0) { + size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); + axis_index = static_cast(input_dim) + axis; + } + + s[IntToSize(axis_index)] = 1; + + std::vector> strategies; + strategies.push_back(s); + return strategies; +} + +std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + + StrategyPtr origin_strategy = ops[iter_ops]->strategy(); + std::vector> strategies; + for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { + if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { + MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; + } + + size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); + std::vector s; + if (output_size == 4) { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + } else if (output_size == 2) { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + } else if (output_size == 1) { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + } else if (output_size == 0) { + s = {}; + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted."; + } + strategies.push_back(s); + } + return strategies; +} + +std::vector> MakeDataParallelStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + + StrategyPtr origin_strategy = ops[iter_ops]->strategy(); + std::vector> strategies; + size_t max_device_num = g_device_manager->DeviceNum(); + size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; + for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { + if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { + MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; + } + + std::vector s; + size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); + for (size_t dim = 0; dim < input_size; dim++) { + if (input_size == 1 || input_size == 2 || input_size == 4) { + if (dim == 0) { + s.push_back(std::min(max_device_num, target_tensor_batch)); + } else { + s.push_back(1); + } + } else if (input_size == 0) { + s = {}; + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; + } + } + strategies.push_back(s); + } + + graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; + if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { + graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch); + } + + return strategies; +} + +std::vector> PrepareStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + MS_EXCEPTION_IF_NULL(ops[iter_ops]); + + auto type = ops[iter_ops]->type(); + auto idx = DictOpType.find(type); + if (idx == DictOpType.end()) { + return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); + } + + if (type == MATMUL) { + return PrepareMatMul(graph, ops, iter_graph, iter_ops); + } else if (type == ONEHOT) { + return PrepareOneHot(graph, ops, iter_graph, iter_ops); + } else { + return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); + } +} + +void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const std::shared_ptr> &index_list) { + for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { + std::vector> strategies; + size_t iter_graph = index_list->at(iter_ops); + if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) { + strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); + } + StrategyPtr sp = std::make_shared(0, strategies); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); + } +} + +size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, + const size_t iter_ops) { + size_t incoming_op_index = SIZE_MAX; + for (size_t i = 1; i < input_tensor_names[iter_ops].size(); i++) { + for (size_t j = 0; j < input_tensor_names.size(); j++) { + if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) { + incoming_op_index = j; + break; + } + } + if (incoming_op_index != SIZE_MAX) { + break; + } + } + return incoming_op_index; +} + +std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_ops, const size_t iter_graph) { + std::vector s; + for (auto input : ops[iter_ops]->inputs_tensor_info()) { + auto input_stra_dim = input.shape().size(); + if (input_stra_dim == 0) { + continue; + } + if (input_stra_dim == 1) { + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); + } else if (input_stra_dim == 2) { + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); + } else if (input_stra_dim == 4) { + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n); + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; + } + break; + } + return s; +} + +std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t incoming_op_index) { + std::vector s; + if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || + ops[incoming_op_index]->type() == TRANSPOSE) { + return s; + } + auto strategy = ops[incoming_op_index]->selected_strategy(); + if (strategy->GetInputNumber() == 0) { + return s; + } + + for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) { + if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) { + continue; + } + for (size_t j = 0; j < ops[incoming_op_index]->inputs_tensor_info()[i].shape().size(); ++j) { + s.push_back(strategy->GetInputDim()[i][j]); + } + break; + } + return s; +} + +std::vector GetAxisList(const std::vector> &ops, const int iter_ops) { + std::vector axis_list; + auto axis_param = ops[iter_ops]->attrs().find(AXIS)->second; + std::vector elements; + if (axis_param->isa()) { + elements = axis_param->cast()->value(); + } else if (axis_param->isa()) { + elements = axis_param->cast()->value(); + } else { + MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl; + } + + for (auto &element : elements) { + if (!element->isa()) { + MS_LOG(EXCEPTION) << "Failure: Dimension indexes is not Int32." << std::endl; + } + auto axis = element->cast()->value(); + axis_list.push_back(axis); + } + return axis_list; +} + +std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s) { + std::vector s_Squeeze; + std::vector stra_dim_list; + for (size_t i = 0; i < s.size(); i++) { + stra_dim_list.push_back(i); + } + + auto axis_list = GetAxisList(ops, incoming_op_index); + for (auto axis : axis_list) { + auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis); + if (it == stra_dim_list.end()) { + MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; + } + if (ops[incoming_op_index]->inputs_tensor_info()[0].shape()[axis] != 1) { + MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1." << std::endl; + } + stra_dim_list.erase(it); + } + + for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) { + s_Squeeze.push_back(s[stra_dim_list[i]]); + } + return s_Squeeze; +} + +bool GetKeepDims(const std::vector> &ops, const size_t iter_ops) { + bool keepdims = false; + auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS); + if (keep_dims_iter == ops[iter_ops]->attrs().end()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims."; + } + MS_EXCEPTION_IF_NULL(keep_dims_iter->second); + if (!keep_dims_iter->second->isa()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool."; + } + keepdims = keep_dims_iter->second->cast()->value(); + return keepdims; +} + +std::vector GetDimList(const std::vector> &ops, const size_t iter_ops) { + std::vector dim_list; + bool keep_dims = GetKeepDims(ops, iter_ops); + if (keep_dims != false) { + return dim_list; + } + auto input_value = ops[iter_ops]->input_value(); + auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); + if (input_value.back()->isa()) { + auto attr_axis = GetValue>(input_value.back()); + if (attr_axis.empty()) { + MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl; + } + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } else if (input_value.back()->isa()) { + int axis = GetValue(input_value.back()); + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl; + } + return dim_list; +} + +std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s) { + std::vector s_Reduce; + std::vector axis_list; + for (size_t i = 0; i < s.size(); i++) { + axis_list.push_back(i); + } + + auto dim_list = GetDimList(ops, incoming_op_index); + for (auto axis : dim_list) { + auto it = find(axis_list.begin(), axis_list.end(), axis); + if (it == axis_list.end()) { + MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; + } + axis_list.erase(it); + } + + for (size_t i = 0; i < (size_t)axis_list.size(); i++) { + s_Reduce.push_back(s[axis_list[i]]); + } + return s_Reduce; +} + +std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops) { + std::vector dim_list; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter == ops[iter_ops]->attrs().end()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis."; + } + auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + auto attr_axis = GetValue>(iter->second); + if (attr_axis.empty()) { + for (size_t i = 0; i < input_dim; ++i) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } + } else if (iter->second->isa()) { + int axis = GetValue(iter->second); + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Axis type is invalid."; + } + return dim_list; +} + +std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s) { + bool keepdims = GetKeepDims(ops, incoming_op_index); + if (keepdims) { + return s; + } + + std::vector s_Arg; + std::vector axis_list; + for (size_t i = 0; i < s.size(); i++) { + axis_list.push_back(i); + } + + auto dim_list = GetDimListFromAttrs(ops, incoming_op_index); + for (auto axis : dim_list) { + auto it = find(axis_list.begin(), axis_list.end(), axis); + if (it == axis_list.end()) { + MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; + } + axis_list.erase(it); + } + + for (size_t i = 0; i < (size_t)axis_list.size(); i++) { + s_Arg.push_back(s[axis_list[i]]); + } + return s_Arg; +} + +std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t incoming_op_index) { + std::vector s; + s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index); + if (s.size() != 0) { + if (ops[incoming_op_index]->type() == SQUEEZE) { + s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s); + } + if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX || + ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { + s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); + } + if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) { + s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s); + } + } + return s; +} + +std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, + const size_t iter_ops, + std::vector basic_stra) { + std::vector s_empty = {}; + std::vector> stra; + MS_EXCEPTION_IF_NULL(ops[iter_ops]); + + if (basic_stra.size() == 0) { + for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); + iter_op_inputs++) { + stra.push_back(basic_stra); + } + return stra; + } + + auto s_ptr = std::make_shared>(basic_stra); + if (ops[iter_ops]->type() == BIAS_ADD) { + return PrepareBiasAdd(s_ptr); + } + if (ops[iter_ops]->type() == GATHERV2) { + return PrepareGatherV2(ops, iter_ops, basic_stra); + } + if (ops[iter_ops]->type() == L2_NORMALIZE) { + return PrepareL2Normalize(ops, iter_ops, basic_stra); + } + + for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); + iter_op_inputs++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) { + stra.push_back(s_empty); + continue; + } + + std::vector tmp_stra = basic_stra; + bool modified = false; + for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) { + tmp_stra[j] = 1; + modified = true; + } + } + if (modified) { + stra.push_back(tmp_stra); + } else { + stra.push_back(basic_stra); + } + } + return stra; +} + +void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, + const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list) { + if (no_stra_op_list->size() == 0) { + return; + } + std::vector no_stra_op_list_bis; + + for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { + size_t iter_ops = no_stra_op_list->at(iter_list - 1); + std::vector> stra; + std::vector s; + size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops); + if (incoming_op_index != SIZE_MAX) { + auto iter_graph = index_list->at(incoming_op_index); + if (iter_graph != SIZE_MAX) { + s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph); + } else { + s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index); + } + } + + if (s.size() == 0) { + no_stra_op_list_bis.push_back(iter_ops); + } else { + stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); + } + + StrategyPtr sp = std::make_shared(0, stra); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); + } + + no_stra_op_list->clear(); + for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) { + no_stra_op_list->push_back(no_stra_op_list_bis[i]); + } +} + +std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, + const size_t iter_ops, std::vector s) { + std::vector s_Squeeze; + auto axis_list = GetAxisList(ops, iter_ops); + size_t s_index = 0; + size_t axis_list_index = 0; + for (size_t i = 0; i < (size_t)(s.size() + axis_list.size()); i++) { + if (i == (size_t)axis_list[axis_list_index]) { + s_Squeeze.push_back(1); + axis_list_index++; + } else { + s_Squeeze.push_back(s[s_index]); + s_index++; + } + } + + size_t cut = 1; + for (size_t i = 0; i < s_Squeeze.size(); i++) { + cut *= s_Squeeze[i]; + } + if (cut != g_device_manager->DeviceNum()) { + s_Squeeze.clear(); + } + + return s_Squeeze; +} + +std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, + const std::vector> &input_tensor_names, + const size_t iter_ops) { + std::vector s; + if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || + ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || + ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE || + ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) { + return s; + } + + bool found = false; + size_t outgoing_op_index = SIZE_MAX; + size_t iter_op_inputs = SIZE_MAX; + for (size_t i = 0; i < input_tensor_names.size(); i++) { + for (size_t j = 1; j < input_tensor_names[i].size(); j++) { + if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] && + ops[i]->selected_strategy()->GetInputNumber() != 0) { + outgoing_op_index = i; + iter_op_inputs = j - 1; + found = true; + break; + } + } + if (found) { + break; + } + } + + if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) { + for (size_t k = 0; k < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); ++k) { + s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]); + } + } + return s; +} + +void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &no_stra_op_list) { + if (no_stra_op_list->size() == 0) { + return; + } + std::vector no_stra_op_list_bis; + + for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { + auto iter_ops = no_stra_op_list->at(iter_list - 1); + std::vector> stra; + std::vector s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops); + + if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) { + s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s); + } + if (s.size() != 0) { + stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); + } else { + no_stra_op_list_bis.push_back(iter_ops); + } + + StrategyPtr sp = std::make_shared(0, stra); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); + } + + no_stra_op_list->clear(); + for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) { + no_stra_op_list->push_back(no_stra_op_list_bis[i]); + } +} + +void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list) { + if (no_stra_op_list->size() == 0) { + return; + } + + size_t no_stra_op_list_size = no_stra_op_list->size(); + do { + no_stra_op_list_size = no_stra_op_list->size(); + GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); + GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); + } while (no_stra_op_list_size > no_stra_op_list->size()); + + for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) { + auto iter_ops = no_stra_op_list->at(iter_list); + std::vector> stra; + std::vector s; + + size_t max_dim_num = 0; + for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() > max_dim_num) { + max_dim_num = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); + } + } + for (size_t i = 0; i < max_dim_num; i++) { + s.push_back(1); + } + + stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); + StrategyPtr sp = std::make_shared(0, stra); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); + } +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h new file mode 100644 index 0000000000..9acd05e0a9 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ +#define PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace mindspore { +namespace parallel { +void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, + const std::shared_ptr>> &eli_list, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list); +std::vector> PrepareMatMul(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +std::vector> PrepareBiasAdd(const std::shared_ptr> &s); +std::vector> PrepareOneHot(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +std::vector> PrepareGatherV2(const std::vector> &ops, + const size_t iter_ops, std::vector s); +std::vector> PrepareL2Normalize(const std::vector> &ops, + const size_t iter_ops, std::vector s); +std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +std::vector> MakeDataParallelStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +std::vector> PrepareStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const std::shared_ptr> &index_list); +size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, + const size_t iter_ops); +std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_ops, const size_t iter_graph); +std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t incoming_op_index); +std::vector GetAxisList(const std::vector> &ops, const int iter_ops); +std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s); +bool GetKeepDims(const std::vector> &ops, const size_t iter_ops); +std::vector GetDimList(const std::vector> &ops, const size_t iter_ops); +std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s); +std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops); +std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s); +std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t incoming_op_index); +std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, + const size_t iter_ops, + std::vector basic_stra); +void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, + const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list); +std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, + const size_t iter_ops, std::vector s); +std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, + const std::vector> &input_tensor_names, + const size_t iter_ops); +void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &no_stra_op_list); +void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list); +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_graph.h new file mode 100644 index 0000000000..15b8220016 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_graph.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ +#define PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ + +#include +#include +#include + +#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" + +namespace mindspore { +namespace parallel { +enum OperatorType { + kRecUnkownType, + kRecMatMul, + kRecConvolution, + kRecPooling, + kRecElmWiseOp, + kRecReLU, + kRecBatchNorm, + kRecReshape, + kRecBiasAdd, + kRecSoftmax, + kRecSparseSoftmaxCrossEntropyWithLogits, + kRecSoftmaxCrossEntropyWithLogits, + kRecOneHot, + kRecLog, + kRecExp, + kRecAdd, + kRecSub, + kRecMul, + kRecDiv, + kRecSqueeze, + kRecCast, + kRecReduce, + kRecPReLU, + kRecGatherV2, + kRecArgWithValue +}; + +enum InfoType { kApplication, kConstant }; + +struct OperatorRec { + OperatorType op_type; + TensorParam arguments[MAX_INPUT_NUM]; + StrategyRec str; +}; + +// Define simplified dataflow Graph for partitioning +class Graph { + public: + struct NodeType { + std::string name; + // Nodes that point to this node + std::vector node_in; + // Nodes that point from this node + std::vector node_out; + std::vector node_in_aux; + // Node Type Info: Application or Constant. Defined in enum . + InfoType info; + // Operator info. Defined in struct . + OperatorRec apply; + // Tensor info. Defined in tensor.h struct . + TensorParam tensor_parm; + }; + + std::vector nodes; // Nodes of the graph. Pubic. +}; // Define simplified dataflow Graph for partitioning +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc new file mode 100644 index 0000000000..a393c825df --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -0,0 +1,264 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace mindspore { +namespace parallel { +const TensorParam MakeTensor(int n, int c, int h, int w) { + TensorParam new_tensor; + new_tensor.tensor_type = kFloat32; + new_tensor.tensor_shape.shape_n = n; + new_tensor.tensor_shape.shape_c = c; + new_tensor.tensor_shape.shape_h = h; + new_tensor.tensor_shape.shape_w = w; + const TensorParam &tensor = new_tensor; + return tensor; +} + +Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops) { + Graph::NodeType NewOp; + NewOp.name = ops[iter_ops]->name(); + NewOp.info = InfoType::kApplication; + + auto op_type = ops[iter_ops]->type(); + auto idx = DictOpType.find(op_type); + if (idx == DictOpType.end()) { + NewOp.apply.op_type = OperatorType::kRecUnkownType; + MS_LOG(INFO) << "Unknown operator type."; + } else { + NewOp.apply.op_type = DictOpType.at(op_type); + } + + if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { + NewOp.tensor_parm = MakeTensor( + ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], + ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { + NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0], + ops[iter_ops]->outputs_tensor_info()[0].shape()[1]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(ERROR) << "Tensor's shape is unknown."; + } + + NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp); + return NewOp; +} + +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType NewTensor) { + for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); + iter_input_tensors++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { + NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) { + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(ERROR) << "Tensor's shape is unknown."; + } + } + return NewTensor.apply; +} + +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, + const size_t iter_input_tensors, Graph::NodeType NewTensor) { + if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { + auto attrs = ops[iter_ops]->attrs(); + bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); + bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); + if (transpose_a && (iter_input_tensors == 0)) { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); + } else if (transpose_b && (iter_input_tensors == 1)) { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); + } else { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); + } + } else { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); + } + return NewTensor.apply.arguments[iter_input_tensors]; +} + +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names) { + std::shared_ptr graph(new Graph); + if (ops.size() > SIZE_MAX / 2) { + MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2; + } + + for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) { + Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops); + graph->nodes.push_back(NewOp); + } + MakeEdge(input_tensor_names, graph); + + return graph; +} + +void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph) { + for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { + for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { + size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); + if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) { + graph->nodes[iter_i].node_in.push_back(head_node_index); + graph->nodes[head_node_index].node_out.push_back(iter_i); + } + } + } +} + +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_name, + const std::string &input_name) { + for (size_t index = 0; index < input_tensor_name.size(); index++) { + if (input_tensor_name[index][0] == input_name) { + return index; + } + } + MS_LOG(INFO) << "Get index failed, using SIZE_MAX insted"; + return SIZE_MAX; +} + +void Eliminate_Aux(const size_t node_index, const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list) { + std::vector eli; + eli.push_back(node_index); + for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) { + eli.push_back(graph->nodes[node_index].node_out[i]); + } + eli_list->push_back(eli); + + for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) { + auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out; + auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index); + if (it != incoming_outputs->end()) { + it = incoming_outputs->erase(it); + incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end()); + } + } + + for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) { + auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out; + auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index); + if (it != aux_incoming_outputs->end()) { + it = aux_incoming_outputs->erase(it); + aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), + graph->nodes[node_index].node_out.end()); + } + } + + for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { + auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in; + auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index); + if (it != outgoing_inputs->end()) { + if (graph->nodes[node_index].node_in.size() > 0) { + outgoing_inputs->at(std::distance(outgoing_inputs->begin(), it)) = graph->nodes[node_index].node_in[0]; + for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]); + } + for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) { + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back( + graph->nodes[node_index].node_in_aux[j]); + } + } else { + outgoing_inputs->erase(it); + } + } + } +} + +std::shared_ptr EliminateGraph(const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list, + const std::shared_ptr> &index_list) { + MS_EXCEPTION_IF_NULL(graph); + for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { + auto type = graph->nodes[node_index].apply.op_type; + if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) { + Eliminate_Aux(node_index, graph, eli_list); + } + } + index_list->reserve(graph->nodes.size()); + for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { + index_list->push_back(i); + } + for (size_t i = 0; i < (size_t)eli_list->size(); i++) { + if (eli_list->at(i)[0] >= index_list->size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + index_list->at(eli_list->at(i)[0]) = SIZE_MAX; + for (size_t j = eli_list->at(i)[0] + 1; j < (size_t)index_list->size(); j++) { + index_list->at(j)--; + } + } + std::shared_ptr new_graph(new Graph); + for (size_t i = 0; i < graph->nodes.size(); i++) { + if (index_list->at(i) > SIZE_MAX / 2) { + continue; + } + new_graph->nodes.push_back(graph->nodes[i]); + auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; + for (size_t j = node_in->size(); j > 0; j--) { + bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX); + if (IsEliminated) { + node_in->erase(node_in->begin() + j - 1); + } else { + node_in->at(j - 1) = index_list->at(node_in->at(j - 1)); + } + } + auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; + for (size_t j = node_out->size(); j > 0; j--) { + bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX); + if (IsEliminated) { + node_out->erase(node_out->begin() + j - 1); + } else { + node_out->at(j - 1) = index_list->at(node_out->at(j - 1)); + } + } + } + return new_graph; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h new file mode 100644 index 0000000000..4d0c02f5fe --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ +#define PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ + +#include +#include +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace mindspore { +namespace parallel { +static const std::set ElementWiseOpType = { + OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, + OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, + OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, + OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue}; + +const std::map DictOpType{ + {MATMUL, OperatorType::kRecMatMul}, + {CONV2D, OperatorType::kRecConvolution}, + {MAXPOOL, OperatorType::kRecPooling}, + {MAXPOOLV2, OperatorType::kRecPooling}, + {SIMPLE_MEAN, OperatorType::kRecPooling}, + {RESHAPE, OperatorType::kRecReshape}, + {BIAS_ADD, OperatorType::kRecBiasAdd}, + {BATCH_NORM, OperatorType::kRecBatchNorm}, + {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm}, + {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits}, + {ONEHOT, OperatorType::kRecOneHot}, + {SQUEEZE, OperatorType::kRecSqueeze}, + {CAST, OperatorType::kRecCast}, + {REDUCE_SUM, OperatorType::kRecReduce}, + {REDUCE_MAX, OperatorType::kRecReduce}, + {REDUCE_MIN, OperatorType::kRecReduce}, + {REDUCE_MEAN, OperatorType::kRecReduce}, + {GATHERV2, OperatorType::kRecGatherV2}, + {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue}, + {ARGMINWITHVALUE, OperatorType::kRecArgWithValue}, + + {RELU, OperatorType::kRecReLU}, + {"ReLU6", OperatorType::kRecReLU}, + {"ReLUV2", OperatorType::kRecReLU}, + {SIGMOID, OperatorType::kRecReLU}, + {SIGMOID_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecReLU}, + {"HSigmoid", OperatorType::kRecReLU}, + {GELU, OperatorType::kRecReLU}, + {TANH, OperatorType::kRecReLU}, + + {PRELU, OperatorType::kRecPReLU}, + + {TRANSPOSE, OperatorType::kRecElmWiseOp}, + {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, + {TENSOR_ADD, OperatorType::kRecElmWiseOp}, + {SUB, OperatorType::kRecElmWiseOp}, + {MUL, OperatorType::kRecElmWiseOp}, + {DIV, OperatorType::kRecElmWiseOp}, + {REAL_DIV, OperatorType::kRecElmWiseOp}, + {SOFTMAX, OperatorType::kRecSoftmax}, + {LOG_SOFTMAX, OperatorType::kRecSoftmax}, + {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits}, + {SQRT, OperatorType::kRecElmWiseOp}, + {NEG, OperatorType::kRecElmWiseOp}, + {POW, OperatorType::kRecElmWiseOp}, + {EXP, OperatorType::kRecElmWiseOp}, + {LOG, OperatorType::kRecElmWiseOp}, + {COS, OperatorType::kRecElmWiseOp}, + {ACOS, OperatorType::kRecElmWiseOp}, + {LOGICALNOT, OperatorType::kRecElmWiseOp}, + {"LogicalAnd", OperatorType::kRecElmWiseOp}, + {"LogicalOr", OperatorType::kRecElmWiseOp}, + {SQUARE, OperatorType::kRecElmWiseOp}, + {"Abs", OperatorType::kRecElmWiseOp}, + {"Acosh", OperatorType::kRecElmWiseOp}, + {"AddN", OperatorType::kRecElmWiseOp}, + {"AccumulateNV2", OperatorType::kRecElmWiseOp}, + {"Atan2", OperatorType::kRecElmWiseOp}, + {"Erf", OperatorType::kRecElmWiseOp}, + {"Floor", OperatorType::kRecElmWiseOp}, + {FLOORDIV, OperatorType::kRecElmWiseOp}, + {"FloorMod", OperatorType::kRecElmWiseOp}, + {GREATER, OperatorType::kRecElmWiseOp}, + {"GreaterEqual", OperatorType::kRecElmWiseOp}, + {"HSwish", OperatorType::kRecElmWiseOp}, + {"Less", OperatorType::kRecElmWiseOp}, + {"LessEqual", OperatorType::kRecElmWiseOp}, + {MAXIMUM, OperatorType::kRecElmWiseOp}, + {MINIMUM, OperatorType::kRecElmWiseOp}, + {EQUAL, OperatorType::kRecElmWiseOp}, + {NOT_EQUAL, OperatorType::kRecElmWiseOp}, + {"Reciprocal", OperatorType::kRecElmWiseOp}, + {"Round", OperatorType::kRecElmWiseOp}, + {"Rsqrt", OperatorType::kRecElmWiseOp}, + {"Sign", OperatorType::kRecElmWiseOp}, + {"Sin", OperatorType::kRecElmWiseOp}, + {ASSIGN, OperatorType::kRecElmWiseOp}, + {ASSIGN_SUB, OperatorType::kRecElmWiseOp}, + {"AssignAdd", OperatorType::kRecElmWiseOp}}; + +const TensorParam MakeTensor(int n, int c, int h, int w); + +Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops); + +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType NewTensor); + +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, + const size_t iter_input_tensor, Graph::NodeType NewTensor); + +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names); + +void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph); + +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_names, + const std::string &input_name); + +void Eliminate_Aux(const size_t node_index, const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list); + +std::shared_ptr EliminateGraph(const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list, + const std::shared_ptr> &index_list); +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc new file mode 100644 index 0000000000..97d230a49f --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc @@ -0,0 +1,310 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +// Get the target node's weight for sorting. +double GetWeights(const Graph::NodeType &node) { + const OperatorRec &op = node.apply; + + if (op.op_type == OperatorType::kRecMatMul) { + // For MatMul + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMinCostIn(op); + } else if (op.op_type == OperatorType::kRecConvolution) { + // For Convolution + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMinCostIn(node); + } else if (op.op_type == OperatorType::kRecPooling) { + // For Pooling + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMinCostIn(); + } else if (op.op_type == OperatorType::kRecElmWiseOp) { + // For TensorAdd + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMinCostIn(); + } else if (op.op_type == OperatorType::kRecReLU) { + // For Activation + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMinCostIn(); + } else if (op.op_type == OperatorType::kRecReshape) { + // For Reshape + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMinCostIn(); + } else if (op.op_type == OperatorType::kRecBiasAdd) { + // For BiasAdd + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMinCostIn(); + } else if (op.op_type == OperatorType::kRecLog || op.op_type == OperatorType::kRecExp || + op.op_type == OperatorType::kRecAdd || op.op_type == OperatorType::kRecSub || + op.op_type == OperatorType::kRecMul || op.op_type == OperatorType::kRecDiv || + op.op_type == OperatorType::kRecSqueeze || op.op_type == OperatorType::kRecCast) { + // For element-wise op + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMinCostIn(); + } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot || + op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax || + op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits || + op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { + // For BatchParallel op + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetMaxCostIn(); + } else if (op.op_type == OperatorType::kRecUnkownType) { + // For Unkown type + return 0.0; + } else { + MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed."; + } +} + +// Sort all the nodes by their weights +std::vector SortByWeight(const std::shared_ptr &graph) { + MS_EXCEPTION_IF_NULL(graph); + + std::vector> weight_to_node_index; + std::vector node_index_by_weights; + + // Get node's weight. + for (size_t i = 0; i < graph->nodes.size(); i++) { + if (graph->nodes[i].info == kApplication) { + const Graph::NodeType &node_ptr = graph->nodes[i]; + double weight = GetWeights(node_ptr); + size_t index = i; + weight_to_node_index.push_back(std::make_pair(weight, index)); + } + } + + // Ordering ops aka nodes of the graph + std::sort(weight_to_node_index.begin(), weight_to_node_index.end()); + + // Store the result in node_index_by_weights. + uint64_t size = weight_to_node_index.size(); + for (uint64_t i = 1; i <= size; i++) { + node_index_by_weights.push_back(weight_to_node_index[size - i].second); + } + + return node_index_by_weights; +} + +// Get optimal strategy to partition the target node +StrategyRec PartitionNode(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const std::shared_ptr &graph) { + bool enable_conv_chw_partition = false; + MS_EXCEPTION_IF_NULL(graph); + + if (node.apply.op_type == OperatorType::kRecMatMul) { + // For MatMul + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); + } else if (node.apply.op_type == OperatorType::kRecConvolution) { + // For Convolution + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, enable_conv_chw_partition); + } else if (node.apply.op_type == OperatorType::kRecPooling) { + // For Pooling + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); + } else if (node.apply.op_type == OperatorType::kRecElmWiseOp) { + // For TensorAdd + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); + } else if (node.apply.op_type == OperatorType::kRecReLU) { + // For Activation + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); + } else if (node.apply.op_type == OperatorType::kRecReshape) { + // For Reshape + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetOptimalStr(node); + } else if (node.apply.op_type == OperatorType::kRecBiasAdd) { + // For BiasAdd + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); + } else if (node.apply.op_type == OperatorType::kRecLog || node.apply.op_type == OperatorType::kRecExp || + node.apply.op_type == OperatorType::kRecAdd || node.apply.op_type == OperatorType::kRecSub || + node.apply.op_type == OperatorType::kRecMul || node.apply.op_type == OperatorType::kRecDiv || + node.apply.op_type == OperatorType::kRecSqueeze || node.apply.op_type == OperatorType::kRecCast) { + // For element-wise op + auto cost_ptr = std::make_shared(); + + return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); + } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot || + node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax || + node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) { + // For BatchParallel type + auto cost_ptr = std::make_shared(); + return cost_ptr->GetOptimalStr(node); + } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { + // For SoftmaxCrossEntropyWithLogits type + auto cost_ptr = std::make_shared(); + return cost_ptr->GetOptimalStr(node); + } else if (node.apply.op_type == OperatorType::kRecUnkownType) { + // For Unkown type + StrategyRec default_strategy; + return default_strategy; + } else { + MS_LOG(EXCEPTION) << "Failure: Partition Operator failed."; + } +} + +// Parttion graph into all devices. +Status PartitionForAllDevices(const size_t num_device, const double device_memory, + const std::shared_ptr &graph) { + if (num_device < 1) { + MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << "."; + } + + if (num_device > 1024) { + MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be larger than 1024."; + } + + MS_EXCEPTION_IF_NULL(graph); + + // Comopute iter times + int iter_times = static_cast(log2(num_device)); + + // N-cuts loop + for (int loop = 0; loop < iter_times; loop++) { + // Sort by weights + std::vector reorder_node_list = SortByWeight(graph); + + // get total node number + size_t iter_nodes = reorder_node_list.size(); + + // temp vector to map nodename to its strategy. + std::vector> node_name_to_strategy; + + // Loop for all the nodes + for (size_t i_node = 0; i_node < iter_nodes; i_node++) { + // get current node's index + size_t index = reorder_node_list[i_node]; + + Graph::NodeType &node_ptr = graph->nodes[index]; + + // Serch optimal strategy to cut this operator. And store the result optimal strategy in graph. + graph->nodes[index].apply.str = PartitionNode(node_ptr, node_name_to_strategy, graph); + + // Apply OP Strategy to Tensor Strategy. + graph->nodes[index] = ApplyStrToTensor(node_ptr); + + // Note down the node name and its strategy in this loop. + auto node_name_to_str = + std::pair(graph->nodes[index].name, graph->nodes[index].apply.str); + node_name_to_strategy.push_back(node_name_to_str); + } + } + + if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) { + return FAILED; + } else { + return SUCCESS; + } +} + +// Apply OP Strategy to Tensor Strategy +Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) { + // Set Node's tensor_parm + Node.tensor_parm.tensor_str.str_n = Node.apply.str.outputTensor.str_n; + Node.tensor_parm.tensor_str.str_c = Node.apply.str.outputTensor.str_c; + Node.tensor_parm.tensor_str.str_h = Node.apply.str.outputTensor.str_h; + Node.tensor_parm.tensor_str.str_w = Node.apply.str.outputTensor.str_w; + + // Set input tensors' tersor_parm + for (int i = 0; i < 2; i++) { + Node.apply.arguments[i].tensor_str.str_n = Node.apply.str.inputTensor[i].str_n; + Node.apply.arguments[i].tensor_str.str_c = Node.apply.str.inputTensor[i].str_c; + Node.apply.arguments[i].tensor_str.str_h = Node.apply.str.inputTensor[i].str_h; + Node.apply.arguments[i].tensor_str.str_w = Node.apply.str.inputTensor[i].str_w; + } + return Node; +} + +Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr &graph) { + MS_EXCEPTION_IF_NULL(graph); + if (num_device == 0) { + MS_LOG(EXCEPTION) << "Failure: device number is 0."; + } + + uint64_t iter_nodes = graph->nodes.size(); + double used_memory = 0.0; + + for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) { + if (graph->nodes[i_node].info == 0) { + Graph::NodeType &Node = graph->nodes[i_node]; + for (int index = 0; index < 2; index++) { + used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n * + Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c * + Node.apply.arguments[index].tensor_str.str_h * Node.apply.arguments[index].tensor_shape.shape_h * + Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w * + GetDataTypeSize(Node.apply.arguments[index].tensor_type); + } + } + } + + if (device_memory < (used_memory / num_device)) { + MS_LOG(EXCEPTION) << "Failure: Out of memory!"; + return FAILED; + } else { + return SUCCESS; + } +} + +size_t GetDataTypeSize(const TensorType &type) { + switch (type) { + case kInt8: + return sizeof(int); + case kFloat16: + return sizeof(float) / 2; + case kFloat32: + return sizeof(float); + case kDouble64: + return sizeof(double); + default: + MS_LOG(EXCEPTION) << "GetDataTypeSize Failed. Unexpected type"; + } +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.h new file mode 100644 index 0000000000..528163e4d3 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_ +#define PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/rec_core/rec_cost.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +std::vector SortByWeight(const std::shared_ptr &graph); + +double GetWeights(const Graph::NodeType &node); + +StrategyRec PartitionNode(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const std::shared_ptr &graph); + +Status PartitionForAllDevices(const size_t num_device, const double device_memory, const std::shared_ptr &graph); + +Graph::NodeType ApplyStrToTensor(Graph::NodeType Node); + +Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr &graph); + +size_t GetDataTypeSize(const TensorType &type); +} // namespace parallel +} // namespace mindspore + +#endif // PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_strategy.h similarity index 100% rename from mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_strategy.h rename to mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_strategy.h diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_tensor.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_tensor.h new file mode 100644 index 0000000000..315c52c867 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_tensor.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ +#define PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ + +#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h" + +namespace mindspore { +namespace parallel { +enum TensorType { kInt8, kFloat16, kFloat32, kDouble64 }; + +struct Shape4D { + int32_t shape_n = 1; + int32_t shape_c = 1; + int32_t shape_h = 1; + int32_t shape_w = 1; +}; + +struct TensorParam { + TensorType tensor_type = kFloat32; // default as float. + Shape4D tensor_shape; + TensorStr4D tensor_str; +}; +} // namespace parallel +} // namespace mindspore + +#endif // PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc new file mode 100644 index 0000000000..7164660be0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/context.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "frontend/parallel/device_manager.h" + +namespace mindspore { +namespace parallel { +static std::map> param_shapes; + +std::vector PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, + AUTO_PARALLEL}; +std::vector STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING}; + +std::shared_ptr ParallelContext::inst_context_ = nullptr; + +std::shared_ptr ParallelContext::GetInstance() { + if (inst_context_ == nullptr) { + inst_context_.reset(new (std::nothrow) ParallelContext()); + } + return inst_context_; +} + +ParallelContext::ParallelContext() { Reset(); } + +void ParallelContext::Reset() { + mirror_mean_ = false; + full_batch_ = false; + cast_before_mirror_ = true; + loss_repeated_mean_ = true; + device_num_ = 1; + global_rank_ = 0; + communication_backend_ = HCCL_BACKEND; + device_num_is_set_ = false; + global_rank_is_set_ = false; + parallel_mode_ = STAND_ALONE; + parameter_broadcast_ = false; + parameter_broadcast_is_set_ = false; + enable_all_reduce_fusion_ = false; + strategy_ckpt_load_file_ = ""; + strategy_ckpt_save_file_ = ""; + enable_parallel_optimizer_ = false; +} + +void ParallelContext::set_device_num(int32_t device_num) { + device_num_ = device_num; + device_num_is_set_ = true; +} + +void ParallelContext::set_global_rank(int32_t global_rank) { + global_rank_ = global_rank; + global_rank_is_set_ = true; +} + +void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; } + +void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } + +void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; } + +void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } + +void ParallelContext::set_communication_backend(const std::string &communication_backend) { + communication_backend_ = communication_backend; +} + +bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) { + auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); + if (iter == PARALLEL_MODE_LIST.end()) { + MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode; + return false; + } + parallel_mode_ = parallel_mode; + return true; +} + +bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) { + auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode); + if (iter == STRATEGY_SEARCH_MODE_LIST.end()) { + MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode; + return false; + } + strategy_search_mode_ = strategy_search_mode; + return true; +} + +void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) { + parameter_broadcast_ = parameter_broadcast; + parameter_broadcast_is_set_ = true; +} + +void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) { + strategy_ckpt_load_file_ = strategy_ckpt_load_file; +} + +void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) { + strategy_ckpt_save_file_ = strategy_ckpt_save_file; +} + +void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group) { + all_reduce_fusion_split_indices_[group] = indices; +} + +const std::vector ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const { + auto iter = all_reduce_fusion_split_indices_.find(group); + if (iter != all_reduce_fusion_split_indices_.end()) { + return iter->second; + } + return {}; +} + +void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group) { + all_reduce_fusion_split_sizes_[group] = sizes; +} + +const std::vector ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const { + auto iter = all_reduce_fusion_split_sizes_.find(group); + if (iter != all_reduce_fusion_split_sizes_.end()) { + return iter->second; + } + return {}; +} + +// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode +void ParallelParameterContextInit(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { + return; + } + param_shapes.clear(); +} + +// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode +void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + AbstractBasePtr ptr) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(param_node); + MS_EXCEPTION_IF_NULL(ptr); + if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) || + func_graph->has_flag(TRAINING)) { + return; + } + + auto iter = param_shapes.find(param_node->name()); + if (iter == param_shapes.end()) { + MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); + return; + } + std::vector shape = iter->second; + std::shared_ptr base_shape = std::make_shared(shape); + ptr->set_shape(base_shape); + MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; +} + +// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode +void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + const AbstractBasePtr &ptr) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(param_node); + MS_EXCEPTION_IF_NULL(ptr); + if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { + return; + } + + std::vector shape = dyn_cast(ptr->GetShapeTrack())->shape(); + auto ret = param_shapes.try_emplace(param_node->name(), shape); + if (!ret.second) { + MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed"; + return; + } + + MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h new file mode 100644 index 0000000000..1bb40d5c29 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -0,0 +1,142 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ +#define MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/ops_info/ops_utils.h" +#include "frontend/parallel/status.h" +#include "utils/convert_utils.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "debug/info.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace parallel { +constexpr char STAND_ALONE[] = "stand_alone"; +constexpr char DATA_PARALLEL[] = "data_parallel"; +constexpr char HYBRID_PARALLEL[] = "hybrid_parallel"; +constexpr char AUTO_PARALLEL[] = "auto_parallel"; +constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel"; + +constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming"; +constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; + +constexpr char TRAINING[] = "training"; + +class ParallelContext { + public: + ~ParallelContext() = default; + ParallelContext(const ParallelContext &) = delete; + ParallelContext &operator=(const ParallelContext &) = delete; + + static std::shared_ptr GetInstance(); + + void set_mirror_mean(bool mirror_mean); + bool mirror_mean() const { return mirror_mean_; } + + void set_full_batch(bool full_batch); + bool full_batch() const { return full_batch_; } + + void set_cast_before_mirror(bool cast_before_mirror); + bool cast_before_mirror() const { return cast_before_mirror_; } + + void set_loss_repeated_mean(bool loss_repeated_mean); + bool loss_repeated_mean() const { return loss_repeated_mean_; } + + void set_device_num(int32_t device_num); + int32_t device_num() const { return device_num_; } + + void set_global_rank(int32_t global_rank); + int32_t global_rank() const { return global_rank_; } + + void set_communication_backend(const std::string &communication_backend); + std::string communication_backend() const { return communication_backend_; } + + bool set_parallel_mode(const std::string ¶llel_mode); + std::string parallel_mode() const { return parallel_mode_; } + + bool set_strategy_search_mode(const std::string &strategy_search_mode); + std::string strategy_search_mode() const { return strategy_search_mode_; } + + void set_parameter_broadcast(bool parameter_broadcast); + bool parameter_broadcast() const { return parameter_broadcast_; } + + bool device_num_is_set() const { return device_num_is_set_; } + bool global_rank_is_set() const { return global_rank_is_set_; } + bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; } + + void SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group); + const std::vector GetAllReduceFusionSplitIndices(const std::string &group) const; + void SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group); + const std::vector GetAllReduceFusionSplitSizes(const std::string &group) const; + void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) { + enable_all_reduce_fusion_ = enable_all_reduce_fusion; + } + bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; } + + void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file); + std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; } + void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); + std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } + + void set_enable_parallel_optimizer(bool enable_parallel_optimizer) { + enable_parallel_optimizer_ = enable_parallel_optimizer; + } + bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } + + void Reset(); + + private: + ParallelContext(); + static std::shared_ptr inst_context_; + bool mirror_mean_; + bool full_batch_; + bool cast_before_mirror_; + bool loss_repeated_mean_; + int32_t device_num_; + int32_t global_rank_; + std::string communication_backend_; + std::string parallel_mode_; + std::string strategy_search_mode_; + bool parameter_broadcast_; + bool device_num_is_set_; + bool global_rank_is_set_; + bool parameter_broadcast_is_set_; + bool enable_all_reduce_fusion_; + std::map> all_reduce_fusion_split_indices_; + std::map> all_reduce_fusion_split_sizes_; + std::string strategy_ckpt_load_file_; + std::string strategy_ckpt_save_file_; + bool enable_parallel_optimizer_; +}; + +void ParallelParameterContextInit(const FuncGraphPtr &func_graph); +void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + AbstractBasePtr ptr); +void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + const AbstractBasePtr &ptr); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc new file mode 100644 index 0000000000..67d087eabd --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/costmodel_context.h" + +#include + +#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" + +namespace mindspore { +namespace parallel { +std::shared_ptr CostModelContext::cm_context_inst_ = nullptr; + +std::shared_ptr CostModelContext::GetInstance() { + if (cm_context_inst_ == nullptr) { + MS_LOG(INFO) << "Create costmodel_context"; + cm_context_inst_.reset(new (std::nothrow) CostModelContext()); + } + return cm_context_inst_; +} + +CostModelContext::CostModelContext() { + ResetCostModel(); + ResetAlgoParameters(); +} + +void CostModelContext::ResetCostModel() { + device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY; + costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; + costmodel_beta_ = DEFAULT_COST_MODEL_BETA; + costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA; + costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; + costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; + costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; + is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; + run_phase_ = DEFAULT_RUN_PHASE; + costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; + costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; + costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; + costmodel_allreduce_fusion_tail_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME; + costmodel_allreduce_fusion_allreduce_inherent_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME; + costmodel_allreduce_fusion_allreduce_bandwidth_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH; + costmodel_allreduce_fusion_computation_time_parameter_ = + DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER; +} + +void CostModelContext::ResetAlgoParameters() { + costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; + tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; + tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; + fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; + elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; +} + +void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; } + +void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; } + +void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; } + +void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; } + +void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; } + +void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { + costmodel_communi_threshold_ = cm_communi_th; +} + +void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { + costmodel_communi_const_ = cm_communi_const; +} + +void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; } + +void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; } +void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) { + costmodel_allreduce_fusion_algorithm_ = algorithm; +} + +void CostModelContext::set_costmodel_allreduce_fusion_times(int32_t allreduce_fusion_times) { + costmodel_allreduce_fusion_times_ = allreduce_fusion_times; +} + +void CostModelContext::set_costmodel_allreduce_fusion_tail_percent(double tail_percent) { + costmodel_allreduce_fusion_tail_percent_ = tail_percent; +} + +void CostModelContext::set_costmodel_allreduce_fusion_tail_time(double tail_time) { + costmodel_allreduce_fusion_tail_time_ = tail_time; +} + +void CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time) { + costmodel_allreduce_fusion_allreduce_inherent_time_ = allreduce_inherent_time; +} + +void CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth) { + costmodel_allreduce_fusion_allreduce_bandwidth_ = allreduce_bandwidth; +} + +void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter) { + costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; +} + +void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; } + +void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { + tensor_slice_alignment_size_ = ts_align_size; +} + +void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; } + +void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { + elementwise_stra_follow_ = elementwise_follow; +} + +void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/costmodel_context.h b/mindspore/ccsrc/frontend/parallel/costmodel_context.h similarity index 100% rename from mindspore/ccsrc/parallel/costmodel_context.h rename to mindspore/ccsrc/frontend/parallel/costmodel_context.h diff --git a/mindspore/ccsrc/frontend/parallel/device.h b/mindspore/ccsrc/frontend/parallel/device.h new file mode 100644 index 0000000000..c9633623d2 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/device.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ +#define MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ + +#include +#include +#include + +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +class Device { + // This class abstract the 'device' information, used in Parallel module. + public: + Device() : rank_(0) { name_.clear(); } + explicit Device(int32_t rank) : rank_(rank) { name_.clear(); } + Device(std::string name, int32_t rank) : name_(std::move(name)), rank_(rank) {} + ~Device() = default; + std::string name() const { return name_; } + int32_t rank() const { return rank_; } + + private: + std::string name_; + int32_t rank_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.cc b/mindspore/ccsrc/frontend/parallel/device_manager.cc new file mode 100644 index 0000000000..d3657afdb8 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/device_manager.cc @@ -0,0 +1,374 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/device_manager.h" + +#include +#include +#include +#include +#include +#include + +#include "frontend/parallel/step_parallel.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +DeviceManagerPtr g_device_manager = nullptr; + +Stage::Stage(const std::vector &devices, int num, int rank) + : devices_(devices), number_(num), rank_(rank) { + gm_ = GroupManager(); +} + +// NOTE: '-1' indicates ERROR +int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); } + +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) { + if (device_num <= 0) { + MS_LOG(ERROR) << "'device_num' must be positive."; + return false; + } + if (global_rank < 0) { + MS_LOG(ERROR) << "'global_rank' must be nonnegative."; + return false; + } + if (device_num > MAX_DEVICE_NUM) { + MS_LOG(ERROR) << "'device_num' must be no more than " << MAX_DEVICE_NUM << "."; + return false; + } + // 'device_num_converted' must be the power of 2 + if ((IntToUint(device_num) & IntToUint(device_num - 1)) != 0) { + MS_LOG(ERROR) << "'device_num' must be the power of 2."; + return false; + } + if (global_rank >= device_num) { + MS_LOG(ERROR) << "'global_rank' must be less than 'device_num'."; + return false; + } + if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) { + MS_LOG(ERROR) << "Invalid backend: " << backend; + return false; + } + + RankList devices, stage_map; + for (int i = 0; i < device_num; ++i) { + devices.push_back(i); + } + + stage_map.push_back(device_num); + g_device_manager = std::make_shared(); + if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) { + MS_LOG(INFO) << "Device initialization succeeds."; + return true; + } else { + MS_LOG(ERROR) << "Device initialization fails."; + return false; + } +} + +void CheckGlobalDeviceManager() { + if (g_device_manager == nullptr) { + MS_LOG(EXCEPTION) << "Device information has not been set!"; + } +} + +int32_t GetListMemberByIndex(size_t index, const RankList &devices) { + size_t i = 0; + int32_t result = 0; + if ((devices.empty()) || (index >= devices.size())) { + MS_LOG(EXCEPTION) << "Index is out of the list scope"; + } + auto it = devices.begin(); + for (; it != devices.end(); ++it) { + if (i == index) { + result = *it; + break; + } + ++i; + } + return result; +} + +std::shared_ptr GetListMemberByIndex(size_t index, const std::vector> &device_list) { + size_t i = 0; + std::shared_ptr result; + if ((device_list.empty()) || (index >= device_list.size())) { + MS_LOG(EXCEPTION) << "Index is out of the list scope"; + } + auto it = device_list.begin(); + for (; it != device_list.end(); ++it) { + if (i == index) { + result = *it; + break; + } + ++i; + } + return result; +} + +// E.g. devices = [4, 5, 2, 1, 7, 8, 10], stage_map = [4, 3], +// therefore the stage_devices_ = [[4, 5, 2, 1], [7, 8, 10]]. +Status DeviceManager::Init(const RankList &devices, int32_t global_device_rank, const RankList &stage_map, + const std::string &backend) { + auto dev_it = devices.begin(); + auto stage_it = stage_map.begin(); + int32_t sum = 0; + + if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) { + MS_LOG(ERROR) << "Invalid backend: " << backend; + return Status::FAILED; + } + + for (; stage_it != stage_map.end(); ++stage_it) { + sum += (*stage_it); + } + if (IntToSize(sum) != devices.size()) { + MS_LOG(ERROR) << "The number of 'devices' in the list is not equal to the mentioned " + << "size of 'stage_map'"; + return Status::FAILED; + } + + for (; dev_it != devices.end(); ++dev_it) { + std::shared_ptr one = std::make_shared(*dev_it); + devices_.push_back(one); + } + + size_t global_index = 0; + for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) { + int num_device = *stage_it; + if (num_device > MAX_DEVICE_NUM) { + MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM; + return Status::FAILED; + } + if (num_device <= 0) { + MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive"; + return Status::FAILED; + } + RankList curr_dev_list; + for (int i = 0; i < num_device; ++i) { + curr_dev_list.push_back(GetListMemberByIndex(global_index, devices)); + global_index++; + } + stage_devices_.push_back(curr_dev_list); + } + + global_index = 0; + for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) { + int num_device = *stage_it; + if (num_device > MAX_DEVICE_NUM) { + MS_LOG(ERROR) << "The number of 'devices' in a stage must be less than " << MAX_DEVICE_NUM; + return Status::FAILED; + } + if (num_device <= 0) { + MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive"; + return Status::FAILED; + } + std::vector curr_dev_list; + for (int i = 0; i < num_device; ++i) { + curr_dev_list.push_back(*GetListMemberByIndex(global_index, devices_)); + global_index++; + } + std::shared_ptr new_stage = std::make_shared(curr_dev_list); + stages_.push_back(new_stage); + } + + std::shared_ptr dev = std::make_shared(global_device_rank); + device_ = dev; + set_global_rank(global_device_rank); + backend_ = backend; + + if (backend == HCCL_BACKEND) { + gm_.set_world_group(HCCL_WORLD_GROUP); + } else if (backend_ == NCCL_BACKEND) { + gm_.set_world_group(NCCL_WORLD_GROUP); + } else { + gm_.set_world_group(UNDEFINED_WORLD_GROUP); + } + MS_LOG(INFO) << "The device num: " << devices.size() << "rank id: " << global_device_rank + << "the backend: " << backend; + return Status::SUCCESS; +} + +std::shared_ptr DeviceManager::GetStageById(int32_t stage_id) { + std::shared_ptr res; + if (IntToSize(stage_id) >= stages_.size()) { + MS_LOG(ERROR) << "the 'stage_id': " << stage_id << ", is out of the scope of 'stage_devices_': " << stages_.size(); + return res; + } + int32_t index = 0; + for (auto &stage : stages_) { + if (index == stage_id) return stage; + index++; + } + return res; +} + +RankList DeviceManager::GetDeviceListByStageId(int32_t stage_id) const { + if (IntToSize(stage_id) >= stage_devices_.size()) + MS_LOG(ERROR) << "the 'stage_id': " << stage_id + << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); + RankList res; + int32_t index = 0; + for (auto &stage : stage_devices_) { + if (index == stage_id) { + return stage; + } + index++; + } + return res; +} + +RankList DeviceManager::global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const { + RankList res; + if (split_num <= 0) { + return res; + } + if (IntToSize(stage_id) >= stage_devices_.size()) { + MS_LOG(ERROR) << "the 'stage_id': " << stage_id + << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); + return res; + } + + RankList global_list = GetDeviceListByStageId(stage_id); + if (global_list.size() % IntToSize(split_num)) { + MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") can not be divisible by split num: " << stage_id; + return res; + } + + std::vector dev_list; + (void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list)); + + size_t index = 0; + size_t slice_size = dev_list.size() / IntToSize(split_num); + for (int32_t i = 0; i < split_num; ++i) { + bool found = false; + index = slice_size * IntToSize(i); + for (size_t j = 0; j < slice_size; ++j) { + if (dev_list[index + j] == rank) { + found = true; + break; + } + } + + if (found) { + break; + } + } + + for (size_t k = 0; k < slice_size; ++k) { + res.push_back(dev_list[index + k]); + } + return res; +} + +Device DeviceManager::CreateNewDeviceByRank(int32_t rank) const { return Device(rank); } + +std::vector DeviceManager::CreateDeviceListByRankList(RankList ranks) { + std::vector dev_list; + for (auto &rank : ranks) { + Device one = CreateNewDeviceByRank(rank); + dev_list.push_back(one); + } + return dev_list; +} + +DeviceManager &DeviceManager::GetInstance() { + static DeviceManager instance = DeviceManager(); + return instance; +} + +std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) { + std::string tmp = "WORLD_GROUP"; + if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) { + return tmp; + } + auto iter = group_to_rank_.find(hash_name); + if (iter == group_to_rank_.end()) { + MS_LOG(WARNING) << "Can not find the rank list name by hash name: " << hash_name; + return tmp; + } + return iter->second; +} + +std::string HashName(const std::string &origin_name) { return std::to_string(std::hash{}(origin_name)); } + +// Group name is generated using the increasing ranks of the devices. +// E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name +// is '0-1-3-5-7'. +std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) { + std::string rank_list_name; + std::vector::iterator it; + std::sort(ranks.begin(), ranks.end()); // sorted in increasing order + for (it = ranks.begin(); it != ranks.end(); ++it) { + if (it == ranks.begin()) { + rank_list_name = std::to_string(*it); + } else { + rank_list_name += "-" + std::to_string(*it); + } + } + + // hash rank-list-name and add ranks' size as prefix + std::string group_hash_name = HashName(rank_list_name); + std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name; + + if (rank_to_group_.find(rank_list_name) == rank_to_group_.end()) { + if (group_to_rank_.find(group_name) == group_to_rank_.end()) { + rank_to_group_[rank_list_name] = group_name; + group_to_rank_[group_name] = rank_list_name; + MS_LOG(INFO) << "The rank list name is " << rank_list_name << "nd group name is " << group_name; + } else { + MS_LOG(EXCEPTION) << "Hash collision, the current rank list: " << rank_list_name + << "the old rank list:" << group_to_rank_.find(group_name)->second + << "the group name: " << group_name; + } + } + return group_name; +} + +// Create the group with the given devices and the given name. The GroupManager +// gm_ will create a new group only if there does not exit a group with the same +// name. Otherwise, let the pointer g point to that group. +Group DeviceManager::CreateGroup(const std::string &group_name, + const std::vector &devices) { + if ((world_group() == NCCL_WORLD_GROUP) && (devices.size() != devices_.size())) { + MS_LOG(EXCEPTION) << "Do not support sub group for nccl"; + } + Group g; + (void)gm_.CreateGroup(group_name, devices, &g); + return g; +} + +// Create the group with only the given devices' ranks. +Group DeviceManager::CreateGroup(const RankList &dev_ranks) { + std::unordered_set rank_set(dev_ranks.begin(), dev_ranks.end()); + if (dev_ranks.size() != rank_set.size()) { + MS_LOG(EXCEPTION) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list"; + } + + std::string group_name = GenerateGroupNameByRanks(dev_ranks); + auto dev_list = CreateDeviceListByRankList(dev_ranks); + return CreateGroup(group_name, dev_list); +} + +void DeviceManager::Clear() { + devices_.clear(); + stage_devices_.clear(); + gm_.Clear(); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h new file mode 100644 index 0000000000..654acd9dff --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/device_manager.h @@ -0,0 +1,130 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ +#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "frontend/parallel/device.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/group_manager.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/strategy.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace parallel { +#define MAX_DEVICE_NUM 1024 + +constexpr char HCCL_BACKEND[] = "hccl"; +constexpr char NCCL_BACKEND[] = "nccl"; +constexpr char UNDEFINED_BACKEND[] = "undefined_backend"; + +class DeviceManager; +using DeviceManagerPtr = std::shared_ptr; +// 'g_device_manager' is the globally unique manager to manage the devices. +extern DeviceManagerPtr g_device_manager; + +class Stage { + // This class is used in pipeline-parallelization. Available devices are partitioned into multiple stages. + // Currently, the function of pipeline-parallelization and this class are NOT implemented. + public: + explicit Stage(std::vector devices) : devices_(std::move(devices)), number_(0), rank_(0) { + gm_ = GroupManager(); + } + Stage(const std::vector &devices, int num, int rank); + ~Stage() = default; + + int GetStageNum() const { return number_; } + size_t GetDevicesNum() const { return devices_.size(); } + std::vector GetDevicesList() { return devices_; } + int global_rank(Group *g) const; + + private: + std::vector devices_; + int number_; + int32_t rank_; + GroupManager gm_; +}; + +// This method is used for initializing the global DeviceManager 'g_device_manager', +// arguments including 'device_num' and 'global_rank' +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend); + +void CheckGlobalDeviceManager(); + +std::string HashName(const std::string &rank_list_name); + +class DeviceManager { + // This class is used to manage the abstract devices, including group-related and stage-related management. + public: + DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(0) { gm_ = GroupManager(); } + ~DeviceManager() = default; + + Status Init(const RankList &devices, int32_t local_device, const RankList &stage_map, const std::string &backend); + + static DeviceManager &GetInstance(); + RankList GetDeviceListByStageId(int32_t stage_id) const; + RankList global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const; + + Device CreateNewDeviceByRank(int32_t rank) const; + std::vector CreateDeviceListByRankList(RankList ranks); + + std::string GenerateGroupNameByRanks(RankList dev_ranks); + Group CreateGroup(const std::string &group_name, const std::vector &devices); + Group CreateGroup(const RankList &dev_ranks); + std::shared_ptr GetStageById(int32_t stage_id); + + size_t DeviceNum() const { return devices_.size(); } + + int32_t GetStageNum() const { return static_cast(stage_devices_.size()); } + + int32_t global_rank() const { return global_rank_; } + std::string backend() const { return backend_; } + void set_global_rank(int32_t global_rank) { global_rank_ = global_rank; } + void Clear(); + std::string world_group() const { return gm_.world_group(); } + std::string FindRankListNameByHashName(const std::string &hash_name); + + private: + std::vector> devices_; + // each stage has a list of devices + std::vector> stage_devices_; + std::shared_ptr device_; + std::vector> stages_; + GroupManager gm_; + std::string backend_; + + // bimap: + std::map rank_to_group_; // the key is rank list, value is hash name + std::map group_to_rank_; // the key is hash name, value is rank list + + int32_t local_rank_; + int32_t global_rank_; + int32_t stage_num_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/device_matrix.cc b/mindspore/ccsrc/frontend/parallel/device_matrix.cc new file mode 100644 index 0000000000..9cc85d9701 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/device_matrix.cc @@ -0,0 +1,170 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/device_matrix.h" + +#include +#include +#include +#include +#include +#include + +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +DeviceMatrix::DeviceMatrix(int32_t rank, RankList dev_list, Shape dev_shape) + : rank_(rank), dev_list_(std::move(dev_list)), dev_shape_(std::move(dev_shape)) { + if (!std::any_of(dev_list_.begin(), dev_list_.end(), [rank](int32_t a) { return a == rank; })) { + MS_LOG(EXCEPTION) << "Rank " << rank << " is not in the current stage!"; + } + int32_t total = std::accumulate(dev_shape_.begin(), dev_shape_.end(), 1, std::multiplies()); + if (IntToSize(total) != dev_list_.size()) { + MS_LOG(EXCEPTION) << "Device shape does not match the size of the device list!"; + } +} + +Status DeviceMatrix::CreateGroupList() { + size_t size = dev_shape_.size(); + RankList group; + for (size_t i = 0; i < size; i++) { + Status status = GetDevicesAlongDim(SizeToUint(i), &group); + group_list_.push_back(group); + if (status == Status::FAILED) { + return Status::FAILED; + } + } + return Status::SUCCESS; +} + +Status DeviceMatrix::GetDevicesAlongDim(const uint32_t &dim, RankList *devices) { + if (dim >= dev_shape_.size()) { + MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!"; + } + if (dev_shape_[dim] == 1) { + *devices = {rank_}; + return Status::SUCCESS; + } + + RankList group; + std::vector local_group_list; + + // lower than dim + int32_t step = 1; + for (uint32_t i = dim + 1; i < dev_shape_.size(); i++) { + step = step * dev_shape_[i]; + } + int32_t num = *dev_list_.begin(); + for (int32_t i = 0; i < dev_shape_[dim]; i++) { + group.push_back(num); + num += step; + } + + for (int32_t i = 0; i < step; i++) { + local_group_list.push_back(group); + (void)std::for_each(group.begin(), group.end(), [](int32_t &a) { a++; }); + } + + // higher than dim + step = step * dev_shape_[dim]; + int32_t len = SizeToInt(dev_list_.size()) / step; + + // search rank + int32_t target = rank_; + for (int32_t i = 0; i < len; i++) { + for (RankList &temp : local_group_list) { + if (std::any_of(temp.begin(), temp.end(), [target](int32_t a) { return a == target; })) { + *devices = temp; + return Status::SUCCESS; + } + (void)std::for_each(temp.begin(), temp.end(), [step](int32_t &a) { a = a + step; }); + } + } + MS_LOG(ERROR) << "Can't find groups for rank" << rank_ << " in device list!"; + return Status::FAILED; +} + +Shape ConvertRankToCoordinate(int32_t rank, const Shape &dev_shape) { + Shape dev_coordinate; + for (size_t i = 0; i < dev_shape.size(); ++i) { + int32_t size = dev_shape[dev_shape.size() - i - 1]; + if (size == 0) { + MS_LOG(EXCEPTION) << "Invalid dev shape: " << ShapeToString(dev_shape); + } else { + int32_t index = rank % size; + (void)dev_coordinate.insert(dev_coordinate.begin(), index); + rank = rank / size; + } + } + return dev_coordinate; +} + +Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) { + for (auto &element : tensor_map) { + // -1 means the corresponding dimension is not split. + if (element == MAP_NONE) { + continue; + } else if ((element < 0) || (IntToSize(element) >= dev_shape_.size())) { + MS_LOG(ERROR) << "create group by tensor map: the tensor map is invalid"; + return FAILED; + } + } + + Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_); + for (auto &tmp_rank : dev_list_) { + Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_); + bool matched = true; + for (auto &map : tensor_map) { + if (map == MAP_NONE) { + continue; + } + size_t index = dev_shape_.size() - IntToSize(map) - 1; + if (current_rank_coordinate[index] != tmp_rank_coordinate[index]) { + matched = false; + break; + } + } + if (matched) { + rank_list->push_back(tmp_rank); + } + } + + return SUCCESS; +} + +std::string ShapeToString(const Shape &shape) { + std::string str = "["; + for (size_t i = 0; i < shape.size(); ++i) { + str += std::to_string(shape[i]); + if (i < shape.size() - 1) { + str += ", "; + } + } + return str + "]"; +} + +std::string ListToString(const std::vector &list) { + std::string str = "["; + for (auto &element : list) { + str += std::to_string(element) + ", "; + } + return str + "]"; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/device_matrix.h b/mindspore/ccsrc/frontend/parallel/device_matrix.h new file mode 100644 index 0000000000..f1e7acec39 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/device_matrix.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ +#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ + +#include +#include +#include + +#include "frontend/parallel/status.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace parallel { +using RankList = std::vector; +using Shape = std::vector; + +class DeviceMatrix { + public: + DeviceMatrix(int32_t rank, RankList devices, Shape dev_shape); + DeviceMatrix() = default; + ~DeviceMatrix() = default; + std::vector group_list() const { return group_list_; } + Status CreateGroupList(); + Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list); + Status GetDevicesAlongDim(const uint32_t &dim, RankList *devices); + + private: + int32_t rank_ = -1; + RankList dev_list_; + // From low dim to high dim. eg: [D0 D1 D2 D3] + Shape dev_shape_; + std::vector group_list_; +}; + +std::string ShapeToString(const Shape &shape); +std::string ListToString(const std::vector &list); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h new file mode 100644 index 0000000000..3ba40fade9 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h @@ -0,0 +1,139 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ +#define MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/ops_info/ops_info_head_files.h" +#include "frontend/parallel/step_parallel.h" + +namespace mindspore { +namespace parallel { +#define REGISTER(className) \ + OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \ + return std::make_shared(name, in, out, attrs); \ + } \ + RegisterAction className##Register(#className, (CreatFn)objectCreator##className); + +typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out, + const PrimitiveAttrs &attrs); + +class DynCreator { + public: + ~DynCreator() = default; + + // creat static singleton dyn_creator instance + static DynCreator &Instance() { + static DynCreator fac = DynCreator(); + return fac; + } + // register + void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } + // creator + OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, + const PrimitiveAttrs &attrs, size_t count) { + std::string op_name = name + std::to_string(count); + auto iter = Function_map_.find(name); + if (iter == Function_map_.end()) { + MS_LOG(INFO) << name << " is not register yet"; + return nullptr; + } + return iter->second(op_name, shape_in, shape_out, attrs); + } + + private: + DynCreator() = default; + std::map Function_map_; +}; + +class RegisterAction { + public: + RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { + DynCreator::Instance().Regist(name, creatfn); + } + ~RegisterAction() = default; + + private: + std::string name_; +}; + +// operator register +REGISTER(MatMulInfo); +REGISTER(GeluInfo); +REGISTER(VirtualDatasetInfo); +REGISTER(BatchParallelInfo); +REGISTER(TanhInfo); +REGISTER(SoftmaxInfo); +REGISTER(LogSoftmaxInfo); +REGISTER(ActivationInfo); +REGISTER(SoftmaxCrossEntropyWithLogitsInfo); +REGISTER(SubInfo); +REGISTER(TensorAddInfo); +REGISTER(BiasAddInfo); +REGISTER(MulInfo); +REGISTER(DivInfo); +REGISTER(RealDivInfo); +REGISTER(PowInfo); +REGISTER(ExpInfo); +REGISTER(OneHotInfo); +REGISTER(EqualInfo); +REGISTER(NotEqualInfo); +REGISTER(LogInfo); +REGISTER(CosInfo); +REGISTER(ACosInfo); +REGISTER(LogicalNotInfo); +REGISTER(L2NormalizeInfo); +REGISTER(LayerNormInfo); +REGISTER(ReduceMaxInfo); +REGISTER(ArgMaxWithValueInfo); +REGISTER(ArgMinWithValueInfo); +REGISTER(ReduceMeanInfo); +REGISTER(ReduceSumInfo); +REGISTER(ReduceMinInfo); +REGISTER(TransposeInfo); +REGISTER(PReLUInfo); +REGISTER(DropoutDoMaskInfo); +REGISTER(ReshapeInfo); +REGISTER(FloorDivInfo); +REGISTER(MaximumInfo); +REGISTER(MinimumInfo); +REGISTER(CastInfo); +REGISTER(GreaterInfo); +REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo); +REGISTER(AssignSubInfo); +REGISTER(ReLUInfo); +REGISTER(GatherV2Info); +REGISTER(SparseGatherV2Info); +REGISTER(SqrtInfo); +REGISTER(SigmoidInfo); +REGISTER(GetNextInfo); +REGISTER(NegInfo); +REGISTER(BatchMatMulInfo); +REGISTER(ExpandDimsInfo); +REGISTER(SqueezeInfo); +REGISTER(SigmoidCrossEntropyWithLogitsInfo); +REGISTER(SquareInfo); +REGISTER(GatherV2PInfo); +REGISTER(EmbeddingLookupInfo); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc new file mode 100644 index 0000000000..30c25e5f26 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc @@ -0,0 +1,175 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/graph_util/generate_graph.h" + +#include +#include +#include +#include + +using mindspore::tensor::Tensor; + +namespace mindspore { +namespace parallel { +std::string GetOpPythonPath(const OperatorName &op_name) { + // almost all ops are defined in two main paths + const std::string ops_module = OP_PATH; + const std::string inner_ops_module = INNER_OP_PATH; + py::module mod = py::module::import(common::SafeCStr(ops_module)); + py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module)); + if (!py::hasattr(mod, common::SafeCStr(op_name))) { + if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) { + MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name; + } + return inner_ops_module; + } + return ops_module; +} + +ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) { + std::string op_path = GetOpPythonPath(op_name); + py::module mod = py::module::import(common::SafeCStr(op_path)); + if (!py::hasattr(mod, common::SafeCStr(op_name))) { + MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name; + return nullptr; + } + std::vector arg_list; + (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list), + [](const Attr &attr) { return ValuePtrToPyData(attr.second); }); + py::object obj = + parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list); + ValuePtr op_instance = nullptr; + bool succ = parse::ConvertData(obj, &op_instance); + if (!succ) { + MS_LOG(ERROR) << "Failure:get Python op " << op_path << " from " << op_name << " fail"; + return nullptr; + } + return op_instance; +} + +AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) { + auto value_node = NewValueNode(value_ptr); + MS_EXCEPTION_IF_NULL(value_node); + return value_node->cast(); +} + +static std::unordered_map int_tensor_map = {}; +AnfNodePtr CreateInt32Tensor(int32_t value) { + auto it = int_tensor_map.find(value); + if (it != int_tensor_map.end()) { + return it->second; + } + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(py::int_(value), kInt32); + ValuePtr value_ptr = MakeValue(tensor_ptr); + auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr); + int_tensor_map[value] = anf_node_ptr; + return anf_node_ptr; +} + +AnfNodePtr CreatTypeInt(int32_t value) { + ValuePtr value_ptr = MakeValue(std::make_shared(value)); + return ValuePtrToAnfNodePtr(value_ptr); +} + +AnfNodePtr CreatInt32Imm(int32_t value) { + ValuePtr value_ptr = MakeValue(std::make_shared(value)); + return ValuePtrToAnfNodePtr(value_ptr); +} + +std::string GetInstanceNameByCNode(const CNodePtr &cnode) { + PrimitivePtr prim = GetValueNode(cnode->input(0)); + if (!prim) { + MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr."; + } + std::string instance_name = prim->instance_name(); + return HashInstanceName(instance_name); +} + +std::string HashInstanceName(const std::string &name) { + auto using_hash_name = common::GetEnv(USING_HASH_NAME); + std::string instance_name; + if ((using_hash_name.empty()) || (using_hash_name == "on")) { + instance_name = HashName(name); + } else { + instance_name = name; + } + return instance_name; +} + +Status GenerateGraph::Init(const CNodePtr &cnode) { + if (!cnode) { + MS_LOG(ERROR) << "Init:cnode is nullptr"; + return FAILED; + } + cnode_ = cnode; + func_graph_ = cnode->func_graph(); + if (!func_graph_) { + MS_LOG(ERROR) << "Init:func_graph_ is nullptr"; + return FAILED; + } + manager_ = func_graph_->manager(); + if (!manager_) { + MS_LOG(ERROR) << "Init:manager_ is nullptr"; + return FAILED; + } + scope_ = cnode_->scope(); + if (!scope_) { + MS_LOG(ERROR) << "Init:scope_ is nullptr"; + return FAILED; + } + virtual_input_node_ = std::make_shared(nullptr); + virtual_input_node_->set_scope(scope_); + instance_name_base_ = GetInstanceNameByCNode(cnode_); + name_idx_ = 0; + return SUCCESS; +} + +AnfNodePtr GenerateGraph::PushBack(const std::vector &inputs) { + CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode + MS_EXCEPTION_IF_NULL(cnode); + cnode->set_scope(scope_); + if (inputs.size() < 2) { + MS_LOG(EXCEPTION) << "inputs.size() must be more than 1"; + } + (void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[0] + auto new_anf_node_ptr = cnode->cast(); + MS_EXCEPTION_IF_NULL(new_anf_node_ptr); + return new_anf_node_ptr; +} + +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) { + name_idx_++; + ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_)); + if (pyop_instance == nullptr) { + MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed"; + } + auto value_node = NewValueNode(pyop_instance); + return value_node->cast(); +} + +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) { + name_idx_++; + OperatorAttrs attrs; + ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_)); + if (pyop_instance == nullptr) { + MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed"; + } + auto value_node = NewValueNode(pyop_instance); + return value_node->cast(); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h new file mode 100644 index 0000000000..b3ef54a22e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ +#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ + +#include +#include +#include +#include +#include +#include + +#include "./common.h" +#include "frontend/optimizer/opt.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +#define USING_HASH_NAME "USING_HASH_NAME" +// Get the operator's path where the operator has be defined +std::string GetOpPythonPath(const OperatorName &op_name); + +// Init python operator Instance +ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name); + +AnfNodePtr CreatTypeInt(int32_t value); +AnfNodePtr CreatInt32Imm(int32_t value); +AnfNodePtr CreateInt32Tensor(int32_t value); +AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr); +std::string HashInstanceName(const std::string &name); + +class GenerateGraph { + public: + GenerateGraph() : name_idx_(0) {} + Status Init(const CNodePtr &cnode); + ~GenerateGraph() = default; + AnfNodePtr virtual_input_node() { return virtual_input_node_; } + AnfNodePtr NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs); + AnfNodePtr NewOpInst(const OperatorName &op_name); + AnfNodePtr PushBack(const std::vector &inputs); + + private: + CNodePtr cnode_; + FuncGraphManagerPtr manager_; + ScopePtr scope_; + FuncGraphPtr func_graph_; + AnfNodePtr virtual_input_node_; + std::string instance_name_base_; + int64_t name_idx_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc new file mode 100644 index 0000000000..21298697f4 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/graph_util/get_parallel_info.h" + +#include +#include +#include +#include + +#include "common/utils.h" +#include "ir/func_graph.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/graph_util/graph_info.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" + +namespace mindspore { +namespace parallel { +py::dict GetParameterLayout(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + py::dict dict; + std::vector graph_params = graph->parameters(); + + for (auto para : graph_params) { + std::string name = std::static_pointer_cast(para)->name(); + std::shared_ptr tensor_layout = std::static_pointer_cast(para)->tensor_layout(); + if (tensor_layout == nullptr) { + MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; + } else { + auto device_arrangement = tensor_layout->device_arrangement().array(); + auto tensor_map = tensor_layout->tensor_map().array(); + auto slice_shape = tensor_layout->slice_shape().array(); + std::vector> layout = {device_arrangement, tensor_map, slice_shape}; + dict[py::str(name)] = layout; + MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); + } + } + return dict; +} + +py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + py::dict dict; + auto ret = graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + auto nodes = DeepScopedGraphSearch(ret); + + for (auto node : nodes) { + if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto distributed_operation_info = cnode->operator_info(); + if (distributed_operation_info != nullptr) { + auto strategyPtr = distributed_operation_info->strategy(); + if (strategyPtr != nullptr) { + auto strategy = strategyPtr->GetInputDim(); + auto name = cnode->fullname_with_scope(); + dict[py::str(name)] = strategy; + } + } + } + } + return dict; +} + +py::dict GetAllreduceFusion(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + py::dict dict; + auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE); + + for (auto prim : allreduce_prim_list) { + auto name_ptr = prim->GetAttr("parameter"); + auto fusion_ptr = prim->GetAttr("fusion"); + if (fusion_ptr == nullptr) { + MS_LOG(EXCEPTION) << "fusion_ptr is nullptr"; + } else if (name_ptr == nullptr) { + continue; + } + if (!name_ptr->isa()) { + MS_LOG(EXCEPTION) << "name is not StringImm"; + } + auto name = name_ptr->cast()->value(); + if (!fusion_ptr->isa()) { + MS_LOG(EXCEPTION) << "fusion is not Int32Imm"; + } + int32_t fusion = fusion_ptr->cast()->value(); + dict[py::str(name)] = fusion; + } + return dict; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h similarity index 100% rename from mindspore/ccsrc/parallel/graph_util/get_parallel_info.h rename to mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc new file mode 100644 index 0000000000..45a88c3a23 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/graph_util/graph_info.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" +#include "debug/draw.h" +#include "ir/func_graph.h" +#include "utils/context/ms_context.h" +#include "utils/graph_utils.h" + +namespace mindspore { +namespace parallel { +std::vector FindPrimtive(const FuncGraphPtr &graph, const std::string &name) { + AnfNodePtr ret = graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + std::vector all_nodes = DeepScopedGraphSearch(ret); + std::vector prim_list; + for (auto &node : all_nodes) { + if (!IsValueNode(node)) { + continue; + } + ValueNodePtr prim_node_anf = node->cast(); + MS_EXCEPTION_IF_NULL(prim_node_anf); + PrimitivePtr node_prim = prim_node_anf->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == name) { + prim_list.emplace_back(node_prim); + } + } + return prim_list; +} + +void DumpGraph(const FuncGraphPtr &root, const std::string &name) { + if (MsContext::GetInstance()->save_graphs_flag()) { + draw::Draw(name + ".dot", root); + DumpIR(name + ".ir", root); + ExportIR(name + ".dat", "0", root); + } +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.h similarity index 100% rename from mindspore/ccsrc/parallel/graph_util/graph_info.h rename to mindspore/ccsrc/frontend/parallel/graph_util/graph_info.h diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc new file mode 100644 index 0000000000..e50df2818b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/graph_util/node_info.h" + +#include + +#include "ir/anf.h" +#include "ir/param_value.h" +#include "pipeline/jit/parse/python_adapter.h" + +namespace mindspore { +namespace parallel { +std::string ParameterName(const AnfNodePtr &node_ptr) { + auto para_ptr = node_ptr->cast(); + MS_EXCEPTION_IF_NULL(para_ptr); + return para_ptr->name(); +} + +bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { + auto para_ptr = node_ptr->cast(); + if (para_ptr == nullptr) { + return false; + } + if (!para_ptr->has_default()) { + return false; + } + return para_ptr->default_param()->requires_grad(); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h new file mode 100644 index 0000000000..6037c466cd --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ + +#include +#include "base/base.h" + +namespace mindspore { +namespace parallel { +std::string ParameterName(const AnfNodePtr &node_ptr); + +bool ParameterRequireGrad(const AnfNodePtr &node_ptr); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.cc b/mindspore/ccsrc/frontend/parallel/group_manager.cc new file mode 100644 index 0000000000..8929af7b0b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/group_manager.cc @@ -0,0 +1,178 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/group_manager.h" + +#include +#include + +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "utils/comm_manager.h" + +namespace mindspore { +namespace parallel { +Group::Group() { + name_.clear(); + devices_.clear(); +} + +Status Group::Init(const std::string &name, const std::vector &devices) { + this->name_ = name; + this->devices_ = devices; + return Status::SUCCESS; +} + +std::vector Group::GetDevicesList() const { return devices_; } + +bool Group::IsInThisGroup(int32_t device_rank) { + for (auto &device : devices_) { + if (device.rank() == device_rank) { + return true; + } + } + return false; +} + +// Get the position of the device in the group +Status Group::GetIndex(size_t *index) { + size_t pos = 0; + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + for (auto &device : devices_) { + if (device.rank() == rank) { + *index = pos; + return Status::SUCCESS; + } else { + pos++; + } + } + MS_LOG(ERROR) << "Could not find device rank " << rank << "in this group!"; + return Status::FAILED; +} + +GroupManager::GroupManager() { groups_.clear(); } + +Status GroupManager::CreateGroup(const std::string &group_name, const std::vector &devices, + mindspore::parallel::Group *const group) { + // it is simple to use size to determine whether it is a world group + uint32_t world_size = 0; + if (world_group_ != NCCL_WORLD_GROUP) { + (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size); + } + + if ((world_group_ == NCCL_WORLD_GROUP) || (devices.size() == world_size)) { + auto it = groups_.find(world_group_); + if (it == groups_.end()) { + (void)group->Init(world_group_, devices); + groups_[world_group_] = *group; + } else { + *group = it->second; + } + MS_LOG(INFO) << "It is world group " << world_group_ << ", no need to create it."; + return Status::SUCCESS; + } + + auto it = groups_.find(group_name); + // If there already exits a group with the desired 'name', + // let the pointer point to the group. + if (it != groups_.end()) { + *group = it->second; + return Status::SUCCESS; + } else { + (void)group->Init(group_name, devices); + groups_[group_name] = *group; + + vector ranks; + (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks), + [](const Device dev) { return (uint32_t)dev.rank(); }); + // Create group through the CommManager interface + bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks); + if (!ret) { + MS_LOG(ERROR) << "Create group failed, group name is " << group_name; + return Status::FAILED; + } + + MS_LOG(INFO) << "Create group success, group name is " << group_name; + return Status::SUCCESS; + } +} + +Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) { + std::string name = (*group).name(); + auto it = groups_.find(name); + if (it == groups_.end()) { + MS_LOG(ERROR) << "Could not find group name :" << name; + return Status::FAILED; + } + (void)groups_.erase(it); + bool ret = CommManager::GetInstance().DestroyGroup(name); + if (!ret) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +Status GroupManager::DestroyAllGroups() { + for (auto &it : groups_) { + std::string name = it.first; + bool ret = CommManager::GetInstance().DestroyGroup(name); + if (!ret) { + return Status::FAILED; + } + } + groups_.clear(); + return Status::SUCCESS; +} + +Status GroupManager::GetRankID(const std::string &name, unsigned int *const rank_id) { + auto it = groups_.find(name); + if (it == groups_.end()) { + MS_LOG(ERROR) << "Could not find group name :" << name; + return Status::FAILED; + } + bool ret = CommManager::GetInstance().GetRankID(name, rank_id); + if (!ret) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +Status GroupManager::GetRankSize(const std::string &name, unsigned int *const rank_size) { + auto it = groups_.find(name); + if (it == groups_.end()) { + MS_LOG(ERROR) << "Could not find group name :" << name; + return Status::FAILED; + } + bool ret = CommManager::GetInstance().GetRankSize(name, rank_size); + if (!ret) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Group **group) { + auto it = groups_.find(name); + if (it == groups_.end()) { + return Status::FAILED; + } + *group = &it->second; + return Status::SUCCESS; +} + +void GroupManager::Clear() { (void)DestroyAllGroups(); } +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.h b/mindspore/ccsrc/frontend/parallel/group_manager.h new file mode 100644 index 0000000000..b9cf9663b0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/group_manager.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ +#define MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/device.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +constexpr char HCCL_WORLD_GROUP[] = "hccl_world_group"; +constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; +constexpr char UNDEFINED_WORLD_GROUP[] = "undefined_world_group"; + +// Devices that need communication should in the same group. These classes are used to +// create and destroy group among devices. +class Group { + public: + Group(); + ~Group() = default; + Status Init(const std::string &name, const std::vector &devices); + std::vector GetDevicesList() const; + std::string name() const { return name_; } + bool IsInThisGroup(int32_t device_rank); + Status GetIndex(size_t *index); + size_t GetDevNum() const { return devices_.size(); } + + private: + std::string name_; + std::vector devices_; +}; + +class GroupManager { + public: + GroupManager(); + ~GroupManager() = default; + + Status CreateGroup(const std::string &name, const std::vector &devices, Group *group); + Status DestroyGroup(Group *group); + Status DestroyAllGroups(); + Status GetRankID(const std::string &name, unsigned int *rank_id); + Status GetRankSize(const std::string &name, unsigned int *rank_size); + Status FindGroup(const std::string &name, Group **group); + std::string world_group() const { return world_group_; } + void set_world_group(const std::string &name) { world_group_ = name; } + void Clear(); + + private: + // the key is group name (name_) + std::map groups_; + std::string world_group_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/node_check.cc b/mindspore/ccsrc/frontend/parallel/node_check.cc new file mode 100644 index 0000000000..de29417a4d --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/node_check.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/node_check.h" + +#include +#include + +#include "frontend/parallel/ops_info/ops_utils.h" + +namespace mindspore { +namespace parallel { +const std::set BLACK_LIST = {TUPLE_GETITEM, + MAKE_TUPLE, + J, + LIST_GETITEM, + ARRAY_GETITEM, + TUPLE_SETITEM, + DEPEND, + LIST_SETITEM, + ARRAY_SETITEM, + DICT_GETITEM, + LIST_APPEND, + LIST_MAP, + LIST_REDUCE, + TUPLE_REVERSED, + TILE_SHAPE, + TUPLE_DIV, + TUPLE_TO_ARRAY, + MAKE_LIST, + MAKE_DICT, + MAKE_SLICE, + MAKE_RECORD, + STRING_EQUAL, + VIRTUALLOSS, + RETURN, + ENV_GETITEM, + IDENTITY, + PARTIAL, + ENVSETITEM, + ENVGETITEM, + ENVADD, + MAKEREFKEY, + MAKEREF, + GETREFKEY, + GETREFVALUE, + GETREFORIGIN, + DOT, + IM2COL, + COL2IM, + IM2COLV1, + STATESETITEM, + SCALARSUMMARY, + IMAGESUMMARY, + TENSORSUMMARY, + DEBUG, + HISTOGRAMSUMMARY, + COL2IMV1, + RESOLVE, + BROADCASTGRADIENTARGS, + INVERTPERMUTATION, + CONTROLDEPEND, + DROPOUT_GEN_MASK, + EMBED, + CREATINSTANCE, + ZEROSLIKE, + ASSIGN, + REF_TO_EMBED, + STOP_GRADIENT}; + +bool IsInBlackList(const PrimitivePtr &prim) { + MS_EXCEPTION_IF_NULL(prim); + return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end()); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/node_check.h b/mindspore/ccsrc/frontend/parallel/node_check.h similarity index 100% rename from mindspore/ccsrc/parallel/node_check.h rename to mindspore/ccsrc/frontend/parallel/node_check.h diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc new file mode 100644 index 0000000000..35cac1480c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc @@ -0,0 +1,705 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/activation_info.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status Activation::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + + return SUCCESS; +} + +Status ActivationInfo::GetAttrs() { + if (attrs_.size() < ACTIVATION_ATTR_SIZE) { + MS_LOG(ERROR) << name_ << " : The size of attrs small than 1."; + return FAILED; + } + + if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" + << outputs_shape_.size() << "is wrong."; + return FAILED; + } + + auto iter = attrs_.find(ACTIVATION_TYPE); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + std::string val = iter->second->cast()->value(); + if ((val != RELU_TYPE) && (val != RELU6_TYPE) && (val != SIGMOID_TYPE)) { + MS_LOG(ERROR) << name_ << " : Activation type is wrong."; + return FAILED; + } + } else { + MS_LOG(ERROR) << name_ << " : The value of activation_type is not string."; + return FAILED; + } + } + + return SUCCESS; +} + +Status ActivationOther::GetAttrs() { + if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" + << outputs_shape_.size() << "is wrong."; + return FAILED; + } + return SUCCESS; +} + +Status Activation::GenerateStrategies(int32_t stage_id) { + if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" + << outputs_shape_.size() << "is wrong."; + return FAILED; + } + + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status Softmax::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + Dimensions input_strategy = stra.at(0); + + for (auto &element : axis_) { + int32_t axis_index = element; + if (element < 0) { + size_t input_dim = inputs_shape_.at(0).size(); + axis_index = static_cast(input_dim) + element; + } + + int32_t axis_strategy = input_strategy.at(IntToSize(axis_index)); + // Dimension corresponding to axis is un-splittable + if (axis_strategy != MIN_SLICE_NUM) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1"; + } else { + MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1"; + } + return FAILED; + } + } + + return SUCCESS; +} + +Status Softmax::GetAttrs() { + if (attrs_.size() < SOFTMAX_ATTR_SIZE) { + MS_LOG(ERROR) << name_ << " : The size of attrs small than 1."; + return FAILED; + } + + auto iter = attrs_.find(AXIS); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { // the axis is a number + int32_t axis_element = iter->second->cast()->value(); + axis_.push_back(axis_element); + MS_LOG(INFO) << name_ << " : The axis is int, value is " << axis_element; + } else if (iter->second->isa()) { // the axis is a tuple + ValueTuplePtr value_tuple = iter->second->cast(); + if (value_tuple == nullptr) { + MS_LOG(ERROR) << name_ << " : The value_tuple is nullptr."; + return FAILED; + } + std::vector value_vector = value_tuple->value(); + (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_), + [](const ValuePtr &value) { return static_cast(GetValue(value)); }); + if (axis_.empty()) { + MS_LOG(ERROR) << name_ << " : The axis tuple is empty."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : The axis is tuple, value is " << ShapeToString(axis_); + } else { + MS_LOG(ERROR) << name_ << " : The value of axis is not int or tuple int."; + return FAILED; + } + } + + if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; + return FAILED; + } + + // for example: tensor dimension is 4, then axis range [-4, 3] + int32_t dim = SizeToInt(inputs_shape_.at(0).size()); + auto it = + std::find_if(axis_.begin(), axis_.end(), [dim](int32_t element) { return ((element >= dim) || (element < -dim)); }); + if (it != axis_.end()) { + MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << -dim << ", " << dim - 1 << "]."; + return FAILED; + } + + return SUCCESS; +} + +Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status Softmax::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << " : GetAttrs failed."; + return FAILED; + } + if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; + return FAILED; + } + + is_auto_parallel_ = true; + Shape input0_split; + (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); + for (auto &element : axis_) { + int32_t axis_index = element; + if (element < 0) { + size_t input_dim = inputs_shape_.at(0).size(); + axis_index = static_cast(input_dim) + element; + } + input0_split[IntToSize(axis_index)] = 0; + } + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status ActivationBase::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + + dev_matrix_shape_ = input_strategy; + + return SUCCESS; +} + +Status ActivationBase::InferMirrorOps() { + mirror_ops_.clear(); + + Shape tensor_map = inputs_tensor_map_[0]; + std::vector group; + if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group failed."; + return FAILED; + } + + OperatorVector mirror_op; + if (group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror ops is empty."; + return SUCCESS; + } else { + mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + mirror_ops_.push_back(mirror_op); + std::string group_name = group[0].name(); + MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; + } + + return SUCCESS; +} + +Status ActivationBase::InferForwardCommunication() { + // do nothing + return SUCCESS; +} + +Status ActivationBase::InferTensorMap() { + std::vector tensor_map_index; + size_t size = inputs_shape_.at(0).size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back((int32_t)(size - i - 1)); + } + + inputs_tensor_map_.push_back(tensor_map_index); + outputs_tensor_map_.push_back(tensor_map_index); + return SUCCESS; +} + +Status ActivationBase::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = {inputs_strategy.at(0)}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + + TensorLayout input_tensor_layout; + if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { + return FAILED; + } + + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(input_tensor_info); // the same as input + + return SUCCESS; +} + +Status ActivationBase::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +Status CastInfo::InferMirrorOps() { + mirror_ops_.clear(); + + Shape tensor_map = inputs_tensor_map_[0]; + std::vector group; + if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group failed."; + return FAILED; + } + + OperatorVector mirror_op; + OperatorVector op_for_value; + if (group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror ops is empty."; + return SUCCESS; + } else { + mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + mirror_ops_.push_back(mirror_op); + mirror_ops_.push_back(op_for_value); + std::string group_name = group[0].name(); + MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; + } + + return SUCCESS; +} + +Status ExpandDimsInfo::GetAttrs() { + if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) { + MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size(); + return FAILED; + } + + if (!input_value_.back()->isa()) { + MS_LOG(ERROR) << name_ << ": The type of axis is not int"; + return FAILED; + } + + int32_t axis = GetValue(input_value_.back()); + + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + + int32_t dim = SizeToInt(inputs_shape_[0].size()); + if ((axis > dim) || (axis < -dim - 1)) { + MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim - 1 << ", " << dim << "]"; + return FAILED; + } + + if (axis < 0) { + positive_axis_ = dim + axis + 1; + } else { + positive_axis_ = axis; + } + MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_; + return SUCCESS; +} + +Status ExpandDimsInfo::InferTensorMap() { + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + + // for example: if the dimension of input is 3, and the axis is 2, + // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0] + std::vector input_tensor_map, output_tensor_map; + size_t size = inputs_shape_[0].size(); + for (size_t i = 0; i < size; ++i) { + input_tensor_map.push_back(SizeToInt(size - i - 1)); + } + + inputs_tensor_map_.push_back(input_tensor_map); + + output_tensor_map = input_tensor_map; + if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(size))) { + MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_; + return FAILED; + } + (void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP); + outputs_tensor_map_.push_back(output_tensor_map); + + MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map) + << ", and the tensor map of output is " << ShapeToString(output_tensor_map); + return SUCCESS; +} + +Status ExpandDimsInfo::InferTensorStrategy() { + if (strategy_ == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null"; + return FAILED; + } + + inputs_strategy_ = strategy_->GetInputDim(); + if (inputs_strategy_.empty()) { + MS_LOG(ERROR) << name_ << ": The strategy is empty"; + return FAILED; + } + + Shape output_strategy = inputs_strategy_[0]; + if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(output_strategy.size()))) { + MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_; + return FAILED; + } + (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY); + outputs_strategy_ = {output_strategy}; + return SUCCESS; +} + +Status ExpandDimsInfo::InferTensorInfo() { + if (inputs_shape_.empty() || outputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty"; + return FAILED; + } + + if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty"; + return FAILED; + } + + Shape input_shape = inputs_shape_[0]; + Shape output_shape = outputs_shape_[0]; + + // infer slice shape + if (InferTensorStrategy() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer tensor strategy failed"; + return FAILED; + } + Shapes inputs_slice_shape, outputs_slice_shape; + if (InferSliceShape(inputs_strategy_, outputs_strategy_, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer slice shape failed"; + return FAILED; + } + + if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) { + MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty"; + return FAILED; + } + + Shape input_slice_shape = inputs_slice_shape[0]; + Shape output_slice_shape = outputs_slice_shape[0]; + + TensorLayout input_tensor_layout, output_tensor_layout; + if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed"; + return FAILED; + } + + if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed"; + return FAILED; + } + + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status ExpandDimsInfo::InferMirrorOps() { + mirror_ops_.clear(); + + if (inputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty"; + return FAILED; + } + + std::vector group; + if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Create group failed"; + return FAILED; + } + + if (group.empty()) { + MS_LOG(INFO) << name_ << ": No need to create mirror ops"; + return SUCCESS; + } + + OperatorVector mirror_op, placeholder_op; + mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + mirror_ops_.push_back(mirror_op); + mirror_ops_.push_back(placeholder_op); + MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name(); + return SUCCESS; +} + +Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) { + std::vector axis; + auto axis_list = value_tuple->value(); + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + Shape input_shape = inputs_shape_.at(0); + size_t input_size = input_shape.size(); + // if axis tuple is empty, we should exclude the axis that the corresponding slice shape is 1. + if (axis_list.empty()) { + for (size_t i = 0; i < input_size; ++i) { + if (input_shape[i] == 1) { + axis.push_back(i); + } + } + axis_ = MakeValue(axis)->cast(); + return SUCCESS; + } + + // convert negative axis to positive. + for (auto &dim : axis_list) { + if (!dim->isa()) { + MS_LOG(ERROR) << name_ << ": The type of axis is not int"; + return FAILED; + } + int32_t dim_value = GetValue(dim); + int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value; + axis.push_back(positive_value); + } + axis_ = MakeValue(axis)->cast(); + return SUCCESS; +} + +Status SqueezeInfo::GetAttrs() { + auto iter = attrs_.find(AXIS); + if (iter == attrs_.end()) { + MS_LOG(ERROR) << name_ << ": Can't find axis attribute."; + return FAILED; + } + MS_EXCEPTION_IF_NULL(iter->second); + auto value_tuple = iter->second->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + InferAxis(value_tuple); + attrs_[AXIS] = axis_; + return SUCCESS; +} + +Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) { + Attr attr = std::make_pair(AXIS, axis_); + OperatorAttrs attrs = {attr}; + OperatorParams params; + OperatorArgs args = std::make_pair(attrs, params); + replace_op_ = {std::make_pair(SQUEEZE, args)}; + return SUCCESS; +} + +Status SqueezeInfo::InferTensorMap() { + // for example: if the shape of input is [32, 32, 1], and the axis is (2, ), + // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1] + std::vector input_tensor_map, output_tensor_map; + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + size_t size = inputs_shape_[0].size(); + std::vector axis = GetValue>(axis_); + for (size_t i = 0; i < size; ++i) { + size_t index = size - i - 1; + auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i)); + if (iter == axis.end()) { + output_tensor_map.push_back(SizeToInt(index)); + } + input_tensor_map.push_back(SizeToInt(index)); + } + inputs_tensor_map_.push_back(input_tensor_map); + outputs_tensor_map_.push_back(output_tensor_map); + MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map) + << ", and the tensor map of output is " << ShapeToString(output_tensor_map); + + return SUCCESS; +} + +Status SqueezeInfo::InferTensorInfo() { + if (inputs_shape_.empty() || outputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty"; + return FAILED; + } + + if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty"; + return FAILED; + } + + Shape input_shape = inputs_shape_[0]; + Shape output_shape = outputs_shape_[0]; + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Dimensions output_strategy; + std::vector axis = GetValue>(axis_); + for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { + auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i)); + if (iter == axis.end()) { + output_strategy.push_back(inputs_strategy[0].at(i)); + } + } + Strategys outputs_strategy = {output_strategy}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer slice shape failed"; + return FAILED; + } + + if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) { + MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty"; + return FAILED; + } + + Shape input_slice_shape = inputs_slice_shape[0]; + Shape output_slice_shape = outputs_slice_shape[0]; + + // infer tensor layout + TensorLayout input_tensor_layout, output_tensor_layout; + if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed"; + return FAILED; + } + + if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed"; + return FAILED; + } + + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status SqueezeInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + } + + if (InferReplaceOps(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer replace ops failed"; + } + + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h new file mode 100644 index 0000000000..a74707efbe --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h @@ -0,0 +1,224 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class ActivationBase : public OperatorInfo { + public: + ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} + ~ActivationBase() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + protected: + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; +}; + +class Activation : public ActivationBase { + public: + Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~Activation() override = default; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; +}; + +class ActivationInfo : public Activation { + public: + ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : Activation(name, inputs_shape, outputs_shape, attrs) {} + ~ActivationInfo() override = default; + + protected: + Status GetAttrs() override; // activation_type: relu, relu6, sigmoid +}; + +class ActivationOther : public Activation { + public: + ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : Activation(name, inputs_shape, outputs_shape, attrs) {} + ~ActivationOther() override = default; + + protected: + Status GetAttrs() override; +}; + +class GeluInfo : public ActivationOther { + public: + GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~GeluInfo() override = default; +}; + +class TanhInfo : public ActivationOther { + public: + TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~TanhInfo() override = default; +}; + +class Softmax : public ActivationBase { + public: + explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~Softmax() override = default; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + + private: + std::vector axis_; +}; + +class SoftmaxInfo : public Softmax { + public: + SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : Softmax(name, inputs_shape, outputs_shape, attrs) {} + ~SoftmaxInfo() override = default; +}; + +class LogSoftmaxInfo : public Softmax { + public: + LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : Softmax(name, inputs_shape, outputs_shape, attrs) {} + ~LogSoftmaxInfo() override = default; +}; + +class ReLUInfo : public ActivationOther { + public: + ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~ReLUInfo() override = default; +}; + +class CastInfo : public ActivationOther { + public: + CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~CastInfo() override = default; + + protected: + Status InferMirrorOps() override; +}; + +class SqrtInfo : public ActivationOther { + public: + SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~SqrtInfo() override = default; +}; + +class NegInfo : public ActivationOther { + public: + NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~NegInfo() override = default; +}; + +class ExpandDimsInfo : public ActivationOther { + public: + ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~ExpandDimsInfo() override = default; + + protected: + Status GetAttrs() override; + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferMirrorOps() override; + Status InferTensorStrategy(); + + private: + int32_t positive_axis_ = -1; + Strategys inputs_strategy_; + Strategys outputs_strategy_; +}; + +class SqueezeInfo : public ActivationOther { + public: + SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~SqueezeInfo() override = default; + + protected: + Status InferAxis(const ValueTuplePtr &value_tuple); + Status GetAttrs() override; + Status InferReplaceOps(const StrategyPtr &strategy); + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status Init(const StrategyPtr &strategy) override; + + private: + ValueTuplePtr axis_; +}; + +class SquareInfo : public ActivationOther { + public: + SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~SquareInfo() override = default; +}; + +class SigmoidInfo : public ActivationOther { + public: + SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~SigmoidInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc new file mode 100644 index 0000000000..1dd9c899ca --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc @@ -0,0 +1,363 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/arithmetic_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) { + size_t insert_num = bigger_size_shape.size() - smaller_size_shape.size(); + for (size_t num = 0; num < insert_num; ++num) { + (void)smaller_size_shape.insert(smaller_size_shape.begin(), 1); + } + return smaller_size_shape; +} + +Shapes ArithmeticBase::InferExpendShape() { + Shape input_a_shape = inputs_shape_.at(0); + Shape input_b_shape = inputs_shape_.at(1); + Shapes input_shapes; + size_t input_a_size = input_a_shape.size(); + size_t input_b_size = input_b_shape.size(); + if (input_a_size > input_b_size) { + input_shapes.push_back(input_a_shape); + input_shapes.push_back(ExpendShape(input_a_shape, input_b_shape)); + } else if (input_a_size < input_b_size) { + input_shapes.push_back(ExpendShape(input_b_shape, input_a_shape)); + input_shapes.push_back(input_b_shape); + } else { + input_shapes.push_back(input_a_shape); + input_shapes.push_back(input_b_shape); + } + return input_shapes; +} + +std::vector ExpendStrategy(const StrategyPtr &strategy) { + std::vector expend_strategy; + std::vector stra = strategy->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + Dimensions sub_b_strategy = stra.at(1); + size_t input_a_size = sub_a_strategy.size(); + size_t input_b_size = sub_b_strategy.size(); + if (input_a_size > input_b_size) { + expend_strategy.push_back(sub_a_strategy); + expend_strategy.push_back(ExpendShape(sub_a_strategy, sub_b_strategy)); + } else if (input_a_size < input_b_size) { + expend_strategy.push_back(ExpendShape(sub_b_strategy, sub_a_strategy)); + expend_strategy.push_back(sub_b_strategy); + } else { + expend_strategy = stra; + } + return expend_strategy; +} + +Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + Shapes input_shapes = InferExpendShape(); + std::vector expend_strategy = ExpendStrategy(strategy); + Dimensions sub_a_strategy = expend_strategy.at(0); + Dimensions sub_b_strategy = expend_strategy.at(1); + Shape input_a_shape = input_shapes.at(0); + Shape input_b_shape = input_shapes.at(1); + + for (size_t i = 0; i < input_a_shape.size(); ++i) { + if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != 1) && (input_b_shape[i] != 1)) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + } + return SUCCESS; +} + +Status ArithmeticBase::InferDevMatrixShape() { + std::vector expend_strategy = ExpendStrategy(strategy_); + Dimensions sub_a_strategy = expend_strategy.at(0); + Dimensions sub_b_strategy = expend_strategy.at(1); + Shape dev_shape; + for (size_t i = 0; i < sub_a_strategy.size(); ++i) { + if (sub_a_strategy[i] != sub_b_strategy[i]) { + dev_shape.push_back(sub_a_strategy[i] * sub_b_strategy[i]); + } else { + dev_shape.push_back(sub_a_strategy[i]); + } + } + dev_matrix_shape_ = dev_shape; + + return SUCCESS; +} + +TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) { + TensorMap tensor_map_index; + for (size_t i = 0; i < strategy.size(); ++i) { + if (strategy[i] == dev_matrix_shape[i]) { + tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(strategy.size())) - i)); + } else { + tensor_map_index.push_back(-1); + } + } + return tensor_map_index; +} + +TensorMap SetTensorMap(const Shape &strategy_expend, const Shape &dev_matrix_shape, const Shape &strategy) { + TensorMap expend_map = SetExpendTensorMap(strategy_expend, dev_matrix_shape); + size_t dev_matrix_size = dev_matrix_shape.size(); + size_t strategy_size = strategy.size(); + if (dev_matrix_size != strategy_size) { + (void)expend_map.erase(expend_map.begin(), + expend_map.begin() + static_cast(dev_matrix_size - strategy_size)); + } + return expend_map; +} + +void ArithmeticBase::ReComputeBatchSplitFlagList() { + Shapes expend_shapes = InferExpendShape(); + Shape expend_a_shape = expend_shapes.at(0); + Shape expend_b_shape = expend_shapes.at(1); + if (expend_a_shape.size() != expend_b_shape.size()) { + MS_LOG(EXCEPTION) << name_ << " : Recompute batch split flag list is wrong."; + } + if (expend_a_shape.empty()) { + split_flag_list_[0] = false; + split_flag_list_[1] = false; + return; + } + (expend_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false); + (expend_b_shape.at(0) != 1) ? (split_flag_list_[1] = true) : (split_flag_list_[1] = false); +} + +Status ArithmeticBase::InferTensorMap() { + std::vector tensor_map_index; + std::vector expend_strategy = ExpendStrategy(strategy_); + Dimensions sub_a_expend_strategy = expend_strategy.at(0); + Dimensions sub_b_expend_strategy = expend_strategy.at(1); + Strategys stra = strategy_->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + Dimensions sub_b_strategy = stra.at(1); + for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { + tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_expend_strategy.size())) - i)); + } + + Shape dev_shape; + for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { + if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) { + dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]); + } else { + dev_shape.push_back(sub_a_expend_strategy[i]); + } + } + inputs_tensor_map_.push_back(SetTensorMap(sub_a_expend_strategy, dev_shape, sub_a_strategy)); + inputs_tensor_map_.push_back(SetTensorMap(sub_b_expend_strategy, dev_shape, sub_b_strategy)); + outputs_tensor_map_.push_back(tensor_map_index); + + return SUCCESS; +} + +Status ArithmeticBase::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_a_tensor_map = inputs_tensor_map_.at(0); + Shape input_b_tensor_map = inputs_tensor_map_.at(1); + std::vector input_a_group, input_b_group; + if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input a failed."; + return FAILED; + } + if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input b failed."; + return FAILED; + } + + OperatorVector op_for_input_a, op_for_input_b; + if (input_a_group.empty() && input_b_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror group is empty."; + return SUCCESS; + } + if (!input_a_group.empty()) { + op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); + } + if (!input_b_group.empty()) { + op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name(); + } + mirror_ops_.push_back(op_for_input_a); + mirror_ops_.push_back(op_for_input_b); + + return SUCCESS; +} + +Status ArithmeticBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, + const Shape &dev_matrix_array) { + if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { + MS_LOG(ERROR) << name_ << " : The layout is null."; + return FAILED; + } + TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0); + TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1); + TensorMap out_tensor_map_array = outputs_tensor_map_.at(0); + Shape input_a_shape_array = inputs_shape_.at(0); + Shape input_b_shape_array = inputs_shape_.at(1); + Shape out_shape_array = outputs_shape_.at(0); + + TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout; + if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed."; + return FAILED; + } + if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed."; + return FAILED; + } + if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed."; + return FAILED; + } + inputs_layout->push_back(input_a_tensor_layout); + inputs_layout->push_back(input_b_tensor_layout); + outputs_layout->push_back(out_tensor_layout); + + return SUCCESS; +} + +Status ArithmeticBase::InferTensorInfo() { + // infer tensor shape + Shape input_a_shape = inputs_shape_.at(0); + Shape input_b_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + std::vector expend_strategy = ExpendStrategy(strategy_); + Dimensions sub_a_expend_strategy = expend_strategy.at(0); + Dimensions sub_b_expend_strategy = expend_strategy.at(1); + Strategys inputs_strategy = strategy_->GetInputDim(); + Shape dev_shape; + for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { + if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) { + dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]); + } else { + dev_shape.push_back(sub_a_expend_strategy[i]); + } + } + Strategys outputs_strategy = {dev_shape}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_a_slice_shape = inputs_slice_shape.at(0); + Shape input_b_slice_shape = inputs_slice_shape.at(1); + Shape output_slice_shape = outputs_slice_shape.at(0); + + // infer tensor layout + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer tensor layout failed."; + return FAILED; + } + + TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape); + TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape); + TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape); + + inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a + inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b + outputs_tensor_info_.push_back(out_tensor_info); // output + + return SUCCESS; +} + +Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status ArithmeticBase::GenerateStrategies(int32_t stage_id) { + Shape input0_split(inputs_shape_[0].size(), 1); + Shape input1_split(inputs_shape_[1].size(), 1); + Shapes splittable_inputs = {input0_split, input1_split}; + + std::vector sp_vector; + is_auto_parallel_ = true; + if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies with broadcast failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; + + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status ArithmeticBase::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h new file mode 100644 index 0000000000..1d347e4ec1 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h @@ -0,0 +1,135 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class ArithmeticBase : public OperatorInfo { + public: + ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} + ~ArithmeticBase() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status GetAttrs() override { return SUCCESS; } + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); + Shapes InferExpendShape(); +}; + +class SubInfo : public ArithmeticBase { + public: + SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~SubInfo() override = default; +}; + +class TensorAddInfo : public ArithmeticBase { + public: + TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~TensorAddInfo() override = default; +}; + +class MulInfo : public ArithmeticBase { + public: + MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~MulInfo() override = default; +}; + +class DivInfo : public ArithmeticBase { + public: + DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~DivInfo() override = default; +}; + +class RealDivInfo : public ArithmeticBase { + public: + RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~RealDivInfo() override = default; +}; + +class FloorDivInfo : public ArithmeticBase { + public: + FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~FloorDivInfo() override = default; +}; + +class PowInfo : public ArithmeticBase { + public: + PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~PowInfo() override = default; +}; + +class GreaterInfo : public ArithmeticBase { + public: + GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~GreaterInfo() override = default; +}; + +class AssignSubInfo : public ArithmeticBase { + public: + AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~AssignSubInfo() override = default; +}; + +// All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. +class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { + public: + SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~SigmoidCrossEntropyWithLogitsInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc new file mode 100644 index 0000000000..64aceb90f6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/batch_parallel_info.h" + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" + +namespace mindspore { +namespace parallel { +Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + + int32_t stage = strategy->GetInputStage(); + CheckGlobalDeviceManager(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); + dev_num_ = dev_num; + + size_t strategy_size = strategy->GetInputNumber(); + std::vector stra = strategy->GetInputDim(); + for (size_t i = 0; i < strategy_size; ++i) { + Shape sub_strategy = stra.at(i); + size_t strategy_len = sub_strategy.size(); + bool flag = false; + for (size_t j = 0; j < strategy_len; ++j) { + int32_t strategy_value = sub_strategy.at(j); + if (strategy_value > 1) { + if (flag || strategy_value != dev_num_) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : It is not a valid data parallel strategy."; + } else { + MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy."; + } + return FAILED; + } + flag = true; + } + } + } + return SUCCESS; +} + +Status BatchParallelInfo::InferDevMatrixShape() { + dev_matrix_shape_.push_back(dev_num_); + return SUCCESS; +} + +Status BatchParallelInfo::InferMirrorOps() { + mirror_ops_.clear(); + if (g_device_manager->DeviceNum() == 1) { + MS_LOG(INFO) << name_ << " : The device num is 1, no need to create mirror ops."; + return SUCCESS; + } + + MS_LOG(INFO) << name_ << " : Batch parallel input number " << strategy_->GetInputNumber(); + for (size_t i = 0; i < input_value_.size(); i++) { + MS_EXCEPTION_IF_NULL(g_device_manager); + OperatorVector op_vec = CreateMirrorOps(g_device_manager->world_group(), g_device_manager->DeviceNum()); + mirror_ops_.push_back(op_vec); + } + return SUCCESS; +} + +Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; } + +Status BatchParallelInfo::InferTensorMap() { + if (strategy_->GetInputDim()[0][0] != dev_num_) { + MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy."; + return FAILED; + } + for (size_t i = 0; i < inputs_shape_.size(); i++) { + std::vector tensor_map_index; + for (size_t j = 0; j < inputs_shape_[i].size(); ++j) { + if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) { + tensor_map_index.push_back(0); + } else { + tensor_map_index.push_back(MAP_NONE); + } + } + inputs_tensor_map_.push_back(tensor_map_index); + } + for (size_t i = 0; i < outputs_shape_.size(); i++) { + std::vector tensor_map_index; + for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { + if (i == 0 && j == 0) { + tensor_map_index.push_back(0); + } else { + tensor_map_index.push_back(MAP_NONE); + } + } + outputs_tensor_map_.push_back(tensor_map_index); + } + return SUCCESS; +} + +Strategys BatchParallelInfo::GetOutputsStrategy() { + Strategys outputs_strategy; + + for (size_t i = 0; i < outputs_shape_.size(); ++i) { + std::vector strategy; + for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { + if (i == 0 && j == 0) { + strategy.push_back(dev_num_); + } else { + strategy.push_back(1); + } + } + outputs_strategy.push_back(strategy); + } + + return outputs_strategy; +} + +Status BatchParallelInfo::InferTensorInfo() { + for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { + MS_LOG(INFO) << name_ << " : The input size is " << strategy_->GetInputNumber(); + TensorLayout tensor_layout_in; + if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) { + return FAILED; + } + TensorInfo tensor_info_in(tensor_layout_in); + inputs_tensor_info_.push_back(tensor_info_in); + } + for (size_t i = 0; i < outputs_shape_.size(); i++) { + TensorLayout tensor_layout_out; + if (tensor_layout_out.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(i), outputs_shape_.at(i)) != + SUCCESS) { + return FAILED; + } + TensorInfo tensor_info_out(tensor_layout_out); + outputs_tensor_info_.push_back(tensor_info_out); + } + return SUCCESS; +} + +Status BatchParallelInfo::GetAttrs() { return SUCCESS; } + +Status BatchParallelInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { + CheckGlobalDeviceManager(); + is_auto_parallel_ = true; + size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + StrategyPtr sp; + std::vector strategy; + for (size_t i = 0; i < inputs_shape_.size(); i++) { + Shape temp(inputs_shape_[i].size(), 1); + if (split_flag_list_[i]) { + temp[0] = SizeToInt(total_dev_num); + } + strategy.push_back(temp); + } + sp = std::make_shared(stage_id, strategy); + + if (SetCostUnderStrategy(sp) == SUCCESS) { + MS_LOG(INFO) << name_ << " : Successfully generated batch-parallel-strategy."; + PrintStrategy(sp); + } else { + MS_LOG(ERROR) << name_ << " : Generating batch-parallel-strategy failed."; + return FAILED; + } + return SUCCESS; +} + +void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() { + for (size_t i = 0; i < inputs_shape_.size(); i++) { + split_flag_list_[i] = true; + } +} + +Status BatchParallelInfo::InferAsLossDivisor() { + as_loss_divisor_ = 1; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h new file mode 100644 index 0000000000..0ba30c385a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ + +#include +#include +#include +#include +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class BatchParallelInfo : public OperatorInfo { + public: + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), + dev_num_(1) {} + + ~BatchParallelInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + Strategys GetOutputsStrategy(); + Status InferAsLossDivisor() override; + + private: + int32_t dev_num_; +}; + +class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { + public: + SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, + const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; + void ReComputeBatchSplitFlagList() override; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc new file mode 100644 index 0000000000..e8b3afba16 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc @@ -0,0 +1,261 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/bias_add_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + std::vector stra = strategy->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + Dimensions sub_b_strategy = stra.at(1); + int32_t channel_a_strategy = sub_a_strategy.at(1); + int32_t channel_b_strategy = sub_b_strategy.at(0); + if (channel_a_strategy != channel_b_strategy) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + return SUCCESS; +} + +Status BiasAddInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + dev_matrix_shape_ = sub_a_strategy; + return SUCCESS; +} + +void BiasAddInfo::ReComputeBatchSplitFlagList() { + split_flag_list_[0] = true; + split_flag_list_[1] = false; +} + +Status BiasAddInfo::InferTensorMap() { + TensorMap sub_a_tensor_map; + TensorMap sub_b_tensor_map; + std::vector stra = strategy_->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + size_t sub_a_strategy_size = sub_a_strategy.size(); + for (size_t i = 0; i < sub_a_strategy_size; ++i) { + sub_a_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - i)); + } + sub_b_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - 1)); + + inputs_tensor_map_.push_back(sub_a_tensor_map); + inputs_tensor_map_.push_back(sub_b_tensor_map); + outputs_tensor_map_.push_back(sub_a_tensor_map); + + return SUCCESS; +} + +Status BiasAddInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_a_tensor_map = inputs_tensor_map_.at(0); + Shape input_b_tensor_map = inputs_tensor_map_.at(1); + std::vector input_a_group, input_b_group; + if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input a failed."; + return FAILED; + } + if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input b failed."; + return FAILED; + } + + OperatorVector op_for_input_a, op_for_input_b; + if (input_a_group.empty() && input_b_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror group is empty."; + return SUCCESS; + } + if (!input_a_group.empty()) { + op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); + } + if (!input_b_group.empty()) { + op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name(); + } + mirror_ops_.push_back(op_for_input_a); + mirror_ops_.push_back(op_for_input_b); + + return SUCCESS; +} + +Status BiasAddInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, + const Shape &dev_matrix_array) { + if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { + MS_LOG(ERROR) << name_ << " : The layout is null."; + return FAILED; + } + TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0); + TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1); + TensorMap out_tensor_map_array = outputs_tensor_map_.at(0); + Shape input_a_shape_array = inputs_shape_.at(0); + Shape input_b_shape_array = inputs_shape_.at(1); + Shape out_shape_array = outputs_shape_.at(0); + + TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout; + if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed."; + return FAILED; + } + if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed."; + return FAILED; + } + if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed."; + return FAILED; + } + inputs_layout->push_back(input_a_tensor_layout); + inputs_layout->push_back(input_b_tensor_layout); + outputs_layout->push_back(out_tensor_layout); + + return SUCCESS; +} + +Status BiasAddInfo::InferTensorInfo() { + // infer tensor shape + Shape input_a_shape = inputs_shape_.at(0); + Shape input_b_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = {inputs_strategy.at(0)}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_a_slice_shape = inputs_slice_shape.at(0); + Shape input_b_slice_shape = inputs_slice_shape.at(1); + Shape output_slice_shape = outputs_slice_shape.at(0); + + // infer tensor layout + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer tensor layout failed."; + return FAILED; + } + + TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape); + TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape); + TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape); + + inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a + inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b + outputs_tensor_info_.push_back(out_tensor_info); // output + + return SUCCESS; +} + +Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split, input0_split}; + + std::vector sp_vector; + is_auto_parallel_ = true; + Shapes tmp_inputs_shape = {inputs_shape_[0], inputs_shape_[0]}; + Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]}; + if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, &sp_vector) != + SUCCESS) { + return FAILED; + } + MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; + + for (auto &sp : sp_vector) { + std::vector tmp_strategy; + Dimensions input0_strategy = sp->GetInputDim()[0]; + tmp_strategy.push_back(input0_strategy); // input0 + + Dimensions input1_strategy = {input0_strategy.at(1)}; + + // reset the strategy + tmp_strategy.push_back(input1_strategy); // input1 + sp->ResetInputs(tmp_strategy); + } + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status BiasAddInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status BiasAddInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h new file mode 100644 index 0000000000..3ede65a3ba --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ + +#include + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class BiasAddInfo : public OperatorInfo { + public: + BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~BiasAddInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status GetAttrs() override { return SUCCESS; } + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h new file mode 100644 index 0000000000..2829889846 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ + +#include +#include +#include +#include +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class EqualInfo : public ArithmeticBase { + public: + EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~EqualInfo() override = default; +}; + +class NotEqualInfo : public ArithmeticBase { + public: + NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~NotEqualInfo() override = default; +}; + +class MaximumInfo : public ArithmeticBase { + public: + MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~MaximumInfo() override = default; +}; + +class MinimumInfo : public ArithmeticBase { + public: + MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~MinimumInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc new file mode 100644 index 0000000000..3b411ccb0e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc @@ -0,0 +1,323 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/dropout_do_mask_info.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "pipeline/jit/resource.h" +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +static int32_t SEED_NUM = 1; + +Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null"; + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + if (stra.size() != 1) { + MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1"; + return FAILED; + } + + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + + // only check the input[0] + Shapes input_shape = {inputs_shape_[0]}; + if (CheckStrategyValue(strategy, input_shape, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy"; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + } + return FAILED; + } + return SUCCESS; +} + +Status DropoutDoMaskInfo::InferDevMatrixShape() { + if (strategy_ == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null"; + return FAILED; + } + + std::vector strategy = strategy_->GetInputDim(); + if (strategy.empty()) { + MS_LOG(ERROR) << name_ << ": The strategy is empty"; + return FAILED; + } + + dev_matrix_shape_ = strategy[0]; + return SUCCESS; +} + +Status DropoutDoMaskInfo::InferTensorMap() { + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + + std::vector tensor_map_index; + size_t size = inputs_shape_[0].size(); + // if the dimension of input is 4, and tensor_map_index is [3, 2, 1, 0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back(SizeToInt(size - i - 1)); + } + + // the input[1] do not need tensor map + inputs_tensor_map_.push_back(tensor_map_index); // input_0 + outputs_tensor_map_.push_back(tensor_map_index); // output + return SUCCESS; +} + +Status DropoutDoMaskInfo::InferTensorInfo() { + if (inputs_shape_.size() != 3) { + MS_LOG(ERROR) << name_ << ": Invalid inputs shape size " << inputs_shape_.size(); + return FAILED; + } + + if (strategy_ == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null"; + return FAILED; + } + + Shape input_0_shape = inputs_shape_[0]; + + if (inputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty"; + return FAILED; + } + + TensorLayout input_0_tensor_layout; + if (input_0_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_0_shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init tensor layout failed"; + return FAILED; + } + + TensorInfo input_0_tensor_info(input_0_tensor_layout); + + // input_1 do not need tensor info + inputs_tensor_info_.push_back(input_0_tensor_info); // input_0 + outputs_tensor_info_.push_back(input_0_tensor_info); // output + return SUCCESS; +} + +Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { + if (inputs_shape_.empty()) { + MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; + return FAILED; + } + + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + Shapes used_inputs_shape = {inputs_shape_[0]}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, used_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Generate strategies failed"; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +std::shared_ptr>> DropoutDoMaskInfo::GenerateBatchStrategies() { + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + Dimensions strategy(inputs_shape_[0].size() - 1, 1); + (void)strategy.insert(strategy.begin(), SizeToInt(dev_num)); + std::vector strategy_v = {strategy}; + return std::make_shared>>(strategy_v); +} + +Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { + MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; + } + + AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX); + MS_EXCEPTION_IF_NULL(dropout_gen_mask); + if (!dropout_gen_mask->isa()) { + MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode"; + } + + auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); + if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { + MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; + } + if (!IsValueNode(dropout_gen_mask_cnode->input(0))) { + MS_LOG(EXCEPTION) << "The input[0] of dropout gen mask cnode is not primitive"; + } + + ValueNodePtr value_node = dropout_gen_mask_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(value_node); + PrimitivePtr prim = value_node->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() != DROPOUT_GEN_MASK) { + MS_LOG(EXCEPTION) << "The primitive name is not DropoutGenMask"; + } + return prim; +} + +void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { + MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; + } + + AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX); + MS_EXCEPTION_IF_NULL(dropout_gen_mask); + if (!dropout_gen_mask->isa()) { + MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode."; + } + + auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); + if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { + MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; + } + + if (!IsValueNode(dropout_gen_mask_cnode->input(1))) { + MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple."; + } + + FuncGraphPtr func_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr."; + } + + ValuePtr new_shape = MakeValue(input_slice_shape); + AnfNodePtr val = NewValueNode(new_shape); + (void)manager->Replace(dropout_gen_mask_cnode->input(1), val); +} + +// DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is +// split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape +// of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation +// and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask. +std::vector DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { + std::vector replace_ops; + MS_EXCEPTION_IF_NULL(cnode); + PrimitivePtr prim = GetDropoutGenMaskPrim(cnode); + MS_EXCEPTION_IF_NULL(prim); + + if (inputs_tensor_info_.empty()) { + MS_LOG(EXCEPTION) << "The tensor info of dropout do mask is empty"; + } + + if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { + MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; + } + + if (!cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX)->isa()) { + MS_LOG(EXCEPTION) << "The keep prob of dropout do mask is not value node"; + } + + ValuePtr keep_prob = GetValueNode(cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX)); + MS_EXCEPTION_IF_NULL(keep_prob); + auto attr = prim->attrs(); + if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) { + MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1"; + } + + Shape input_slice_shape = inputs_tensor_info_[0].slice_shape(); + int32_t seed_0 = GetValue(attr[SEED0]); + int32_t seed_1 = GetValue(attr[SEED1]); + if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) { + seed_0 = SEED_NUM; + seed_1 = SEED_NUM; + SEED_NUM++; + } else { + SetGenMaskShape(cnode, input_slice_shape); + MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape); + return replace_ops; + } + + ValuePtr new_shape = MakeValue(input_slice_shape); + Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0)); + Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1)); + OperatorAttrs attrs = {attr_0, attr_1}; + Attr param_0 = std::make_pair(SHAPE, new_shape); + Attr param_1 = std::make_pair(KEEP_PROB, keep_prob); + OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)}; + OperatorArgs args = std::make_pair(attrs, params); + Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)}; + replace_ops.push_back(replace_op); + return replace_ops; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h new file mode 100644 index 0000000000..ea7d590071 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class DropoutDoMaskInfo : public OperatorInfo { + public: + DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~DropoutDoMaskInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + std::shared_ptr>> GenerateBatchStrategies() override; + std::vector GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorMap() override; + Status GetAttrs() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; +}; + +using DropoutDoMaskInfoPtr = std::shared_ptr; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h new file mode 100644 index 0000000000..e25da9e743 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ + +#include +#include +#include +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class ExpInfo : public ActivationOther { + public: + ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~ExpInfo() override = default; +}; + +class LogInfo : public ActivationOther { + public: + LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~LogInfo() override = default; +}; + +class CosInfo : public ActivationOther { + public: + CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~CosInfo() override = default; +}; + +class ACosInfo : public ActivationOther { + public: + ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~ACosInfo() override = default; +}; + +class LogicalNotInfo : public ActivationOther { + public: + LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~LogicalNotInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc new file mode 100644 index 0000000000..4e6e947f68 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc @@ -0,0 +1,350 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/gather_v2_info.h" + +#include +#include +#include + +#include "ir/tensor.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/strategy.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status GatherV2Info::GetAttrs() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size(); + return FAILED; + } + if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) { + MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size(); + return FAILED; + } + // the second input is the index tensor + + // the third input is the axis, is a ValueNode + if (input_value_.at(2) == nullptr) { + MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; + return FAILED; + } + + if (inputs_shape_.at(0).size() == 0) { + MS_LOG(ERROR) << name_ << ": input can not be a scalar!"; + return FAILED; + } + int axis = GetValue(input_value_.at(2)); + if (axis >= SizeToInt(inputs_shape_.at(0).size()) || axis < 0 - SizeToInt(inputs_shape_.at(0).size())) { + MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", " + << inputs_shape_.at(0).size() << ")."; + } + if (axis < 0) { + axis += SizeToInt(inputs_shape_[0].size()); + } + axis_ = axis; + + index_size_ = inputs_shape_.at(1).size(); + + return SUCCESS; +} + +Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " + << outputs_shape_.size(); + return FAILED; + } + // Only strategy of the first input should be set. + if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + axis_strategy_ = strategy->GetInputDim().at(0).at(axis_); + if (index_size_ != 1 && axis_strategy_ != 1) { + MS_LOG(ERROR) << name_ + << ": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy " + "corresponding to axis must be 1, but is " + << axis_strategy_; + return FAILED; + } + if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) { + MS_LOG(ERROR) << name_ + << ": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to " + "axis. The first dimension of index is " + << inputs_shape_.at(1).at(0) << " strategy corresponding to axis is " << axis_strategy_; + return FAILED; + } + return SUCCESS; +} + +Status GatherV2Info::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + dev_matrix_shape_ = stra.at(0); + return SUCCESS; +} + +// If index is a scalar, output dimension is input dimension minus 1; +// If index is a n dimension tensor, output dimension is input dimension plus (n - 1). +// Tensor map dimension is equal to the corresponding input and output dimension. +// If index's dimension is more than 1, we insert -1 for the output tensor map. +Status GatherV2Info::InferTensorMap() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " + << outputs_shape_.size(); + return FAILED; + } + std::vector tensor_map_in; + std::vector tensor_map_out; + size_t size = inputs_shape_.at(0).size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_in.push_back(SizeToInt(size - i - 1)); + tensor_map_out.push_back(SizeToInt(size - i - 1)); + } + + if (index_size_ == 0) { + (void)tensor_map_out.erase(tensor_map_out.begin() + axis_); + } else if (index_size_ > 1) { + (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1); + } + if (tensor_map_out.size() != outputs_shape_.at(0).size()) { + MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size() + << " output size is " << outputs_shape_.at(0).size(); + return FAILED; + } + + std::vector tensor_map_in_index; + if (index_size_ >= 1) { + tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1)); + } + for (size_t i = 1; i < index_size_; ++i) { + tensor_map_in_index.push_back(-1); + } + inputs_tensor_map_.emplace_back(std::move(tensor_map_in)); + inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index)); + outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); + return SUCCESS; +} + +Status GatherV2Info::InferTensorInfo() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " + << outputs_shape_.size(); + return FAILED; + } + if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_tensor_map_.size(); + return FAILED; + } + if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " + << outputs_tensor_map_.size(); + return FAILED; + } + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape input_index_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + + TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || + (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) { + return FAILED; + } + + TensorInfo input_tensor_info(input_tensor_layout); + TensorInfo input_index_info(input_index_layout); + TensorInfo output_tensor_info(output_tensor_layout); + + inputs_tensor_info_.push_back(input_tensor_info); + inputs_tensor_info_.push_back(input_index_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +OperatorVector CreateSubOp(int32_t sub_value) { + OperatorVector ops; + OperatorName operator_name = SUB; + OperatorAttrs operator_attrs; + + std::vector tensor_data = {sub_value}; + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(tensor_data, kInt32); + ValuePtr op_param_value = MakeValue(tensor_ptr); + + Attr op1_param = std::make_pair("", op_param_value); + OperatorParams operator_param = {std::make_pair(op1_param, 2)}; + + OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); + Operator op = std::make_pair(operator_name, operator_args); + ops.push_back(op); + return ops; +} + +Status GatherV2Info::InferTensorSubOps() { + sub_ops_.clear(); + if ((index_size_ == 0) || (axis_strategy_ == 1)) { + return SUCCESS; + } + int32_t mod_n = 1; + for (size_t i = IntToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) { + mod_n *= dev_matrix_shape_.at(i); + } + if ((axis_ >= SizeToInt(dev_matrix_shape_.size())) || axis_ < 0) { + MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ")."; + } + int32_t mod_p = mod_n * dev_matrix_shape_.at(axis_); + int32_t rank = g_device_manager->global_rank(); + int32_t mod_rank = rank % mod_p; + mod_rank = static_cast(mod_rank / mod_n); + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if ((axis_ >= SizeToInt(inputs_shape_.at(0).size())) || axis_ < 0) { + MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ")."; + } + int32_t sub_value = static_cast(inputs_shape_.at(0).at(axis_) / dev_matrix_shape_.at(axis_)) * mod_rank; + + OperatorVector sub_op; + sub_ops_.emplace_back(std::move(sub_op)); + sub_op = CreateSubOp(sub_value); + sub_ops_.emplace_back(std::move(sub_op)); + return SUCCESS; +} + +Status GatherV2Info::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + Status status = InferTensorSubOps(); + if (status != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed."; + return status; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status GatherV2Info::GenerateStrategies(int32_t stage_id) { + if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" + << outputs_shape_.size() << "is wrong."; + return FAILED; + } + + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +std::shared_ptr>> GatherV2Info::GenerateBatchStrategies() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + } + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + if (GetAttrs() != SUCCESS) { + MS_LOG(EXCEPTION) << "GetAttrs failed!"; + } + + Dimensions strategy; + if (index_size_ != 1) { + strategy.push_back(1); + } else { + strategy.push_back(SizeToInt(dev_num)); + } + for (size_t i = 1; i < inputs_shape_[0].size(); i++) { + strategy.push_back(1); + } + std::vector strategy_v = {strategy}; + return std::make_shared>>(strategy_v); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h new file mode 100644 index 0000000000..b3dc0fab87 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +constexpr size_t GATHER_V2_INPUTS_SIZE = 2; +constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1; +constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; +// We now supported limited parallel strategies. +// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of +// the input. +// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. +class GatherV2Info : public OperatorInfo { + public: + GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), + axis_(-1), + index_size_(0), + axis_strategy_(1) {} + ~GatherV2Info() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + std::shared_ptr>> GenerateBatchStrategies() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + + private: + Status InferTensorSubOps(); + + int32_t axis_; + size_t index_size_; + int32_t axis_strategy_; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc new file mode 100644 index 0000000000..eb3c9900f8 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -0,0 +1,636 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/gather_v2_p_info.h" + +#include +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/graph_util/generate_graph.h" + +namespace mindspore { +namespace parallel { +Status GatherV2PInfo::GetAttrs() { + // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis. + if (target_ != CPU) { + if (input_value_.at(2) == nullptr) { + MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; + return FAILED; + } + auto axis = GetValue(input_value_.at(2)); + // if axis is negative then convert it to positive + auto params_shape = inputs_shape_.at(0); + if (params_shape.size() == 0) { + MS_LOG(ERROR) << name_ << ": params can not be a scalar!"; + return FAILED; + } + if (axis < 0) { + axis += SizeToInt(inputs_shape_[0].size()); + } + axis_ = axis; + } + + auto target_iter = attrs_.find(TARGET); + if (target_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(target_iter->second); + if (target_iter->second->isa()) { + target_ = target_iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of target is not a string."; + } + } + auto manual_split_iter = attrs_.find("manual_split"); + if (manual_split_iter != attrs_.end()) { + param_split_shapes_.clear(); + manual_split_ = true; + auto var = manual_split_iter->second->cast(); + MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString(); + + if (var->size() > 0) { + std::vector elements = var->value(); + for (auto &ele : elements) { + if (ele->isa()) { + auto value_tuple = ele->cast(); + std::vector value_vector = value_tuple->value(); + if (value_vector.size() != 2) { + MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2."; + return FAILED; + } + param_split_shapes_.push_back(static_cast(GetValue(value_vector[0]))); + index_offsets_.push_back(static_cast(GetValue(value_vector[1]))); + } else { + MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue"; + return FAILED; + } + } + + if (param_split_shapes_.empty()) { + MS_LOG(ERROR) << "Failed to extract param split strategy."; + return FAILED; + } + } + } + + return SUCCESS; +} + +Status GatherV2PInfo::CheckManualSplit() { + auto param_shape = inputs_shape_.at(0); + int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, + [](int32_t s, int32_t shape) { return s + shape; }); + if (split_shape_sum < param_shape.at(0)) { + MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape."; + return FAILED; + } + + if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) { + MS_LOG(ERROR) << "Failure: Index offset must not less than 0."; + return FAILED; + } + + return SUCCESS; +} + +Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + // param slice shape need 32Byte aligned + auto param_shape = inputs_shape_.at(0); + auto param_strategy = strategy->GetInputDim().at(0); + auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1); + if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) { + MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned."; + return FAILED; + } + + // only support 1-dim and 2-dim param + if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) { + MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size(); + return FAILED; + } + + // don't support scalar index + if (inputs_shape_.at(1).size() == 0) { + MS_LOG(DEBUG) << name_ << ": Don't support scalar index."; + return FAILED; + } + + // axis=0, index_shape(0)%param_strategy(0) must be 0 + Shape index_shape = inputs_shape_.at(1); + if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) { + MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0)."; + return FAILED; + } + + if (manual_split_) { + if (CheckManualSplit() != SUCCESS) { + return FAILED; + } + // when using manual_split, no need to check belowings. + return SUCCESS; + } + + // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 + if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) { + MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; + return FAILED; + } + + // param_strategy(axis) != 1, index can't be splited + auto index_strategy = strategy->GetInputDim().at(1); + auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); + if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) { + MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited."; + return FAILED; + } + + // param_strategy(axis) != 1, Don't support repeated calc + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); + if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc."; + return FAILED; + } + + return SUCCESS; +} + +Status GatherV2PInfo::InferMirrorOps() { + // There is no mirror operators for manual split + if (manual_split_) { + return SUCCESS; + } + + mirror_ops_.clear(); + Shape input_a_tensor_map = inputs_tensor_map_.at(0); + std::vector input_a_group; + if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input a failed."; + return FAILED; + } + + OperatorVector op_for_input_a, op_for_input_b, op_for_axis; + if (input_a_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror group is empty."; + return SUCCESS; + } else { + op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); + } + + mirror_ops_.push_back(op_for_input_a); + mirror_ops_.push_back(op_for_input_b); + mirror_ops_.push_back(op_for_axis); + + return SUCCESS; +} + +Status GatherV2PInfo::InferDevMatrixShape() { + dev_matrix_shape_.clear(); + out_dev_matrix_shape_.clear(); + // infer input dev_matrix_shape + auto param_strategy = strategy_->GetInputDim().at(0); + auto index_strategy = strategy_->GetInputDim().at(1); + + if (manual_split_) { + dev_matrix_shape_ = param_strategy; + out_dev_matrix_shape_ = dev_matrix_shape_; + return SUCCESS; + } + + dev_matrix_shape_ = param_strategy; + + // param_strategy(axis)!=1, + if (param_strategy.at(IntToSize(axis_)) != 1) { + std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end()); + } else { + dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end()); + } + + // infer out dev_matrix_shape + // axis!=0, split axis + if (axis_ != 0 && param_strategy.at(IntToSize(axis_)) != 1) { + out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(IntToSize(axis_))); + for (size_t i = 1; i < param_strategy.size(); ++i) { + if (i == IntToSize(axis_)) { + out_dev_matrix_shape_.push_back(1); + } else { + out_dev_matrix_shape_.push_back(param_strategy.at(i)); + } + } + } else { + out_dev_matrix_shape_ = dev_matrix_shape_; + } + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); + auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); + if (param_product * index_product < SizeToInt(dev_num)) { + out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), SizeToInt(dev_num / (param_product * index_product))); + } + + return SUCCESS; +} + +Status GatherV2PInfo::InferTensorMap() { + if (manual_split_) { + inputs_tensor_map_.push_back({1, 0}); + inputs_tensor_map_.push_back({-1, 1}); + outputs_tensor_map_.push_back({-1, 1, 0}); + return SUCCESS; + } + // infer input tensor map + // param_strategy(axis) != 1 + size_t param_size = inputs_shape_.at(0).size(); + size_t index_size = inputs_shape_.at(1).size(); + size_t total_size = param_size + index_size; + std::vector tensor_map_index; + std::vector tensor_map_params; + auto param_strategy = strategy_->GetInputDim().at(0); + if (param_strategy.at(IntToSize(axis_)) != 1) { + tensor_map_index.insert(tensor_map_index.begin(), index_size, -1); + for (size_t i = 0; i < param_size; ++i) { + tensor_map_params.push_back(SizeToInt(i)); + } + } else { + // param_strategy(axis) == 1 + for (size_t i = 0; i < param_size; ++i) { + tensor_map_params.push_back(SizeToInt(total_size - i - 1)); + } + for (size_t i = 0; i < index_size; ++i) { + tensor_map_index.push_back(SizeToInt(index_size - i - 1)); + } + } + + // infer output tensor map + std::vector tensor_map_out; + if (param_strategy.at(IntToSize(axis_)) == 1) { + // param_strategy(axis) == 1 + for (size_t i = 0; i < param_size; ++i) { + if (i == IntToSize(axis_)) { + for (size_t j = 0; j < index_size; ++j) { + tensor_map_out.push_back(SizeToInt(index_size - j - 1)); + } + } else { + tensor_map_out.push_back(SizeToInt(total_size - i - 1)); + } + } + } else { + // param_strategy(axis) != 1 + if (axis_ == 0) { + tensor_map_out.insert(tensor_map_out.end(), 0); + tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); + for (size_t i = 1; i < param_size; ++i) { + tensor_map_out.push_back(i); + } + } else { + for (size_t i = 0; i < param_size; ++i) { + if (i == IntToSize(axis_)) { + tensor_map_out.insert(tensor_map_out.end(), index_size, -1); + } else { + tensor_map_out.push_back(SizeToInt(param_size - i - 1)); + } + } + } + } + + inputs_tensor_map_.emplace_back(std::move(tensor_map_params)); + inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); + outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); + return SUCCESS; +} + +Status GatherV2PInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape input_index_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + int32_t rank = g_device_manager->global_rank(); + // infer tensor layout + TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; + if (manual_split_) { + input_shape[0] = param_split_shapes_[rank / dev_matrix_shape_[1]]; + input_shape[0] = input_shape[0] * dev_matrix_shape_[0]; + } + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || + (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != + SUCCESS)) { + return FAILED; + } + // infer tensor info + TensorInfo input_tensor_info(input_tensor_layout); + TensorInfo input_index_info(input_index_layout); + TensorInfo output_tensor_info(output_tensor_layout); + + Shape slice_shape = input_tensor_info.slice_shape(); + MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); + inputs_tensor_info_.push_back(input_index_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status GatherV2PInfo::InferBias() { + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + auto input_shape = inputs_shape_.at(0); + auto params_strategy = strategy_->GetInputDim().at(0); + // axis don't split + if (params_strategy.at(axis_) == 1) { + bias_ = 0; + return SUCCESS; + } + // params_size=1, axis=0 + if ((input_shape.size() == 1) && (axis_ == 0)) { + slice_size_ = input_shape.at(0) / params_strategy.at(0); + bias_ = rank * slice_size_; + return SUCCESS; + } + // params_size=2, axis=0 + if ((input_shape.size() == 2) && (axis_ == 0)) { + slice_size_ = input_shape.at(0) / params_strategy.at(0); + bias_ = rank / params_strategy.at(1) * slice_size_; + return SUCCESS; + } + // params_size=2, axis=1 + if ((input_shape.size() == 2) && (axis_ == 1)) { + slice_size_ = input_shape.at(1) / params_strategy.at(1); + bias_ = rank % params_strategy.at(1) * slice_size_; + return SUCCESS; + } + MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_; + return FAILED; +} + +Status GatherV2PInfo::InferOffset() { + CheckGlobalDeviceManager(); + size_t rank = g_device_manager->global_rank(); + if (rank < index_offsets_.size()) { + index_offset_ = index_offsets_.at(rank); + MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_; + return SUCCESS; + } + + MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is" << index_offsets_.size(); + return FAILED; +} + +Status GatherV2PInfo::InferGroup() { + auto param_strategy = strategy_->GetInputDim().at(0); + size_t dim = IntToSize(axis_); + if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) { + dim = (axis_ + 1) % 2; + } + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + int32_t rank = g_device_manager->global_rank(); + RankList dev_list = g_device_manager->GetDeviceListByStageId(0); + DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_); + RankList group_devices; + if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Create group failed."; + return FAILED; + } + if (group_devices.size() == 1) { + MS_LOG(INFO) << "the group is empty"; + return SUCCESS; + } + + group_ = g_device_manager->CreateGroup(group_devices); + return SUCCESS; +} + +std::vector GetRankFromGroup(const Group &group) { + std::vector rank_list; + auto device_list = group.GetDevicesList(); + for (auto &device : device_list) { + rank_list.insert(rank_list.end(), device.rank() % 8); + } + return rank_list; +} + +Status GatherV2PInfo::InferForwardCommunication() { + forward_op_.clear(); + auto param_strategy = strategy_->GetInputDim().at(0); + // don't split axis or target is not CPU, no need forward communication + if (target_ != CPU || param_strategy.at(IntToSize(axis_)) == 1) { + return SUCCESS; + } + // split axis + OperatorName operator_name; + if (InferGroup() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Group failed."; + return FAILED; + } + Attr attr_group; + operator_name = REDUCE_SCATTER; + if (InferGroup() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Group failed."; + return FAILED; + } + attr_group = std::make_pair(GROUP, MakeValue(group_.name())); + Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); + OperatorAttrs attrs = {attr_op, attr_group}; + OperatorParams params; + OperatorArgs args = std::make_pair(attrs, params); + Operator op = std::make_pair(operator_name, args); + + forward_op_.push_back(op); + return SUCCESS; +} + +Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + GenerateGraph gen_g = GenerateGraph(); + if (gen_g.Init(cnode) != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + if (manual_split_) { + if (InferOffset() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Bias failed."; + return FAILED; + } + auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)}); + auto gather_v2 = + gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub, CreatInt32Imm(axis_)}); + std::vector> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, gather_v2)); + return SUCCESS; + } + if (InferBias() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Bias failed."; + return FAILED; + } + auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)}); + auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); + auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)}); + auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum}); + auto gather_v2 = + gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)}); + auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2}); + auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype}); + auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)}); + auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims}); + // don't need expandim,if param_size = 1, + if (inputs_shape_.at(0).size() == 1) { + mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast}); + } + if (InferGroup() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Group failed."; + return FAILED; + } + Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); + Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); + OperatorAttrs attrs = {attr_op, attr_group}; + auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); + std::vector> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, reduce_scatter)); + + return SUCCESS; +} + +ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { + if (manual_split_) { + if (ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; + return nullptr; + } + return replace_graph_; + } + + auto param_strategy = strategy_->GetInputDim().at(0); + // target_ == CPU, no need to raplace graph + if (target_ == CPU) { + return nullptr; + } + if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; + return nullptr; + } + return replace_graph_; +} + +Status GatherV2PInfo::ComputeReplaceOp() { + if (InferBias() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer offset failed."; + return FAILED; + } + OperatorName op_name = EMBEDDING_LOOKUP; + OperatorAttrs attrs; + Attr param_offset = std::make_pair("offset", MakeValue(bias_)); + OperatorParams params = {std::make_pair(param_offset, 3)}; + OperatorArgs args = std::make_pair(attrs, params); + Operator op = std::make_pair(op_name, args); + replace_op_.push_back(op); + + return SUCCESS; +} + +Status GatherV2PInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + // only target_ == CPU, we need to replace op + if (target_ == CPU && ComputeReplaceOp() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed."; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + auto param_strategy = strategy_->GetInputDim().at(0); + // cost model set axis and strategy + auto gatherv2_2cost = std::dynamic_pointer_cast(operator_cost()); + gatherv2_2cost->set_axis(axis_); + gatherv2_2cost->set_strategy(param_strategy); + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shape input1_split(inputs_shape_[1].size(), 1); + Shapes splittable_inputs = {input0_split, input1_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +std::shared_ptr>> GatherV2PInfo::GenerateBatchStrategies() { + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + Dimensions param_strategy(inputs_shape_[0].size(), 1); + Dimensions index_strategy; + index_strategy.push_back(SizeToInt(dev_num)); + for (size_t i = 1; i < inputs_shape_[1].size(); i++) { + index_strategy.push_back(1); + } + std::vector strategy_v = {param_strategy, index_strategy}; + return std::make_shared>>(strategy_v); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h new file mode 100644 index 0000000000..eb26c616d0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class GatherV2PInfo : public OperatorInfo { + public: + GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), + axis_(0), + bias_(0), + index_offset_(0), + slice_size_(0) {} + ~GatherV2PInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; + std::shared_ptr>> GenerateBatchStrategies() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + + private: + Status ComputeReplaceGraph(const CNodePtr &cnode); + Status CheckManualSplit(); + Status ComputeReplaceOp(); + Status InferBias(); + Status InferOffset(); + Status InferGroup(); + + int32_t axis_; + std::string target_ = DEVICE; + std::string replace_op_name_ = GATHERV2; + int32_t bias_; + int32_t index_offset_; + int32_t slice_size_; + Shape out_dev_matrix_shape_; + Group group_; + bool manual_split_ = false; + std::vector param_split_shapes_; + std::vector index_offsets_; +}; + +class SparseGatherV2Info : public GatherV2PInfo { + public: + SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} + ~SparseGatherV2Info() override = default; + + private: + std::string replace_op_name_ = SPARSE_GATHERV2; +}; + +class EmbeddingLookupInfo : public GatherV2PInfo { + public: + EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} + ~EmbeddingLookupInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc new file mode 100644 index 0000000000..3606732156 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc @@ -0,0 +1,269 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/get_next_info.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Status GetNextInfo::InferTensorMap() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + + for (auto shp : shapes_) { + TensorMap out_tensor_map; + for (size_t i = 0; i < shp.size(); ++i) { + if (full_batch) { + out_tensor_map.push_back(MAP_NONE); + } else { + out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1)); + } + } + outputs_tensor_map_.push_back(out_tensor_map); + } + return SUCCESS; +} + +Status GetNextInfo::InferTensorLayout(TensorLayouts *outputs_layout) { + if (outputs_layout == nullptr) { + MS_LOG(ERROR) << name_ << " : The layout is null."; + return FAILED; + } + for (size_t i = 0; i < outputs_shape_.size(); ++i) { + TensorLayout output_layout; + if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) { + return FAILED; + } + outputs_layout->push_back(output_layout); + } + return SUCCESS; +} + +Strategys GetNextInfo::GetOutputStrategy() { + Strategys outputs_strategy; + for (auto shp : shapes_) { + Dimensions out_strategy; + out_strategy.push_back(dev_num_); + for (size_t i = 1; i < shp.size(); ++i) { + out_strategy.push_back(1); + } + outputs_strategy.push_back(out_strategy); + } + return outputs_strategy; +} + +Status GetNextInfo::InferTensorInfo() { + TensorLayouts outputs_layout; + if (InferTensorLayout(&outputs_layout) != SUCCESS) { + return FAILED; + } + for (size_t i = 0; i < outputs_shape_.size(); ++i) { + TensorInfo output_tensor_info(outputs_layout[i]); + outputs_tensor_info_.push_back(output_tensor_info); + } + return SUCCESS; +} + +Status GetNextInfo::InferDevMatrixShape() { + size_t max_shape_length = 0; + for (auto shp : shapes_) { + if (max_shape_length < shp.size()) { + max_shape_length = shp.size(); + } + } + if (max_shape_length == 0) { + MS_LOG(ERROR) << name_ << " : shape is 0"; + } + dev_matrix_shape_.push_back(dev_num_); + for (size_t i = 1; i < max_shape_length; ++i) { + dev_matrix_shape_.push_back(1); + } + return SUCCESS; +} + +Status GetNextInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed"; + return FAILED; + } + if (InferReplaceOps(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer replace Ops failed"; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init success"; + return SUCCESS; +} + +Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { + std::vector stras = strategy->GetInputDim(); + for (Dimensions stra : stras) { + if (stra.size() != 0) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + } + int32_t stage = strategy->GetInputStage(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); + dev_num_ = dev_num; + return SUCCESS; +} + +Status GetNextInfo::GetAttrTypes() { + auto iter = attrs_.find(TYPES); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + auto iter_cast = iter->second->cast(); + MS_EXCEPTION_IF_NULL(iter_cast); + auto types = iter_cast->value(); + for (auto &type : types) { + MS_EXCEPTION_IF_NULL(type); + types_.push_back(type->ToString()); + } + } else if (iter->second->isa()) { + auto iter_cast = iter->second->cast(); + MS_EXCEPTION_IF_NULL(iter_cast); + auto types = iter_cast->value(); + for (auto &type : types) { + MS_EXCEPTION_IF_NULL(type); + types_.push_back(type->ToString()); + } + } else { + MS_LOG(ERROR) << name_ << " : The value of types is not list."; + return FAILED; + } + } + return SUCCESS; +} + +Status GetNextInfo::GetAttrShapes() { + shapes_ = outputs_shape_; + if (shapes_.size() == 0) { + MS_LOG(ERROR) << name_ << " : Shape is None."; + return FAILED; + } + return SUCCESS; +} + +Status GetNextInfo::GetAttrOutPutNum() { + auto iter = attrs_.find(GETNEXT_NUM); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + output_num_ = iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of output_num is not int."; + return FAILED; + } + } + return SUCCESS; +} + +Status GetNextInfo::GetAttrs() { + if (GetAttrTypes() == FAILED || GetAttrShapes() == FAILED || GetAttrOutPutNum() == FAILED) { + return FAILED; + } + if (types_.size() != IntToSize(output_num_) || shapes_.size() != IntToSize(output_num_) || output_num_ == 0) { + MS_LOG(ERROR) << name_ << " : The output_num is not equal to shapes size."; + return FAILED; + } + return SUCCESS; +} + +Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + + Shapes out_shapes = outputs_shape_; + for (size_t i = 0; i < out_shapes.size(); ++i) { + if (dev_num_ <= 0) { + MS_LOG(ERROR) << name_ << " : The dev num is 0."; + return FAILED; + } + if (out_shapes[i][0] % dev_num_ != 0) { + MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num."; + return FAILED; + } + if (!full_batch) { + out_shapes[i][0] = out_shapes[i][0] / dev_num_; + } + } + ValuePtr new_shapes = MakeValue(out_shapes); + Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]); + Attr attr_shapes = std::make_pair(SHAPES, new_shapes); + Attr attr_num = std::make_pair(GETNEXT_NUM, attrs_[GETNEXT_NUM]); + Attr attr_shared_name = std::make_pair(SHARED_NAME, attrs_[SHARED_NAME]); + OperatorAttrs attrs = {attr_types, attr_shapes, attr_num, attr_shared_name}; + OperatorParams params; + OperatorArgs args = std::make_pair(attrs, params); + replace_op_ = {std::make_pair(GET_NEXT, args)}; + return SUCCESS; +} + +Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +Status GetNextInfo::GenerateStrategies(int32_t stage_id) { + is_auto_parallel_ = true; + std::vector stra; + StrategyPtr sp = std::make_shared(stage_id, stra); + if (SetCostUnderStrategy(sp) == SUCCESS) { + MS_LOG(INFO) << name_ << " : Successfully generated strategy."; + PrintStrategy(sp); + } else { + MS_LOG(ERROR) << name_ << " : Generating strategy failed."; + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h new file mode 100644 index 0000000000..36e7a0fcb3 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class GetNextInfo : public OperatorInfo { + public: + GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~GetNextInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *outputs_layout); + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferReplaceOps(const StrategyPtr &strategy); + Status GetAttrTypes(); + Status GetAttrShapes(); + Status GetAttrOutPutNum(); + Strategys GetOutputStrategy(); + Status InferAsLossDivisor() override { return SUCCESS; } + + private: + int32_t dev_num_ = 1; + std::vector types_; + Shapes shapes_; + int32_t output_num_ = 0; + std::string shared_name_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc new file mode 100644 index 0000000000..126fdcf84e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/l2_normalize_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(INFO) << name_ << " : Init success."; + } + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + Dimensions input_strategy = stra.at(0); + int32_t axis_index = axis_; + if (axis_ < 0) { + size_t input_dim = inputs_shape_.at(0).size(); + axis_index = static_cast(input_dim) + axis_; + } + + if (input_strategy[IntToSize(axis_index)] != 1) { + MS_LOG(ERROR) << name_ << " : The dim " << axis_index << " of input strategy must be 1."; + return FAILED; + } + + return SUCCESS; +} + +Status L2NormalizeInfo::GetAttrs() { + auto iter = attrs_.find(AXIS); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis_ = iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of axis is not int."; + return FAILED; + } + } + + return SUCCESS; +} + +Status L2NormalizeInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = inputs_tensor_map_.at(0); + std::vector input_group; + if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group failed."; + return FAILED; + } + + OperatorVector op_for_weight; + if (input_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror ops is empty."; + return SUCCESS; + } else { + op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); + mirror_ops_.push_back(op_for_weight); + MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group is " << input_group[0].name(); + } + + return SUCCESS; +} + +Status L2NormalizeInfo::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << " : GetAttrs failed."; + return FAILED; + } + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size() - 1, 1); + int32_t axis_index = axis_; + if (axis_ < 0) { + size_t input_dim = inputs_shape_.at(0).size(); + axis_index = static_cast(input_dim) + axis_; + } + (void)input0_split.insert(input0_split.begin() + axis_index, 0); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h new file mode 100644 index 0000000000..c74dde4b4b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class L2NormalizeInfo : public Activation { + public: + L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : Activation(name, inputs_shape, outputs_shape, attrs) {} + ~L2NormalizeInfo() override = default; + Status GenerateStrategies(int32_t stage_id) override; + + protected: + Status GetAttrs() override; + Status InferMirrorOps() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + + private: + int32_t axis_ = 0; // Default value = 0 +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc new file mode 100644 index 0000000000..62d7c6d61e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc @@ -0,0 +1,324 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/layer_norm_info.h" +#include +#include +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +Status LayerNormInfo::GetAttrs() { + auto iter = attrs_.find(BEGIN_NORM_AXIS); + if (iter == attrs_.end()) { + MS_LOG(ERROR) << name_ << ": Can not find the attr of begin norm axis"; + return FAILED; + } + if ((iter->second == nullptr) || !iter->second->isa()) { + MS_LOG(ERROR) << name_ << ": The axis type is not int"; + return FAILED; + } + + int32_t dim = SizeToInt(input_shape_.size()); + auto axis = GetValue(iter->second); + if ((axis >= dim) || (axis < -dim)) { + MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim << ", " << dim - 1 << "]"; + return FAILED; + } + + if (axis < 0) { + axis = axis + dim; + } + begin_norm_axis_ = IntToSize(axis); + return SUCCESS; +} + +Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + std::vector stra = strategy->GetInputDim(); + if (stra.size() != LAYER_NORM_INPUT_SIZE) { + MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size(); + return FAILED; + } + + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Invalid strategy value"; + return FAILED; + } + + Dimensions input_strategy = stra[LAYER_NORM_INPUT_INDEX]; + Dimensions gamma_strategy = stra[LAYER_NORM_GAMMA_INDEX]; + Dimensions beta_strategy = stra[LAYER_NORM_BETA_INDEX]; + if (begin_norm_axis_ >= input_strategy.size()) { + MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_; + return FAILED; + } + // check input strategy + for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) { + if (input_strategy[i] != NO_SPLIT_STRATEGY) { + MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy); + return FAILED; + } + } + + // check gamma and beta strategy + if ((gamma_strategy.size() > input_strategy.size()) || (beta_strategy.size() > input_strategy.size())) { + MS_LOG(ERROR) << name_ << " : The strategy size of gamma or beta is lager than input strategy"; + return FAILED; + } + + size_t gamma_diff = input_strategy.size() - gamma_strategy.size(); + for (size_t j = 0; j < gamma_strategy.size(); ++j) { + if (gamma_strategy[j] != input_strategy[gamma_diff + j]) { + MS_LOG(ERROR) << name_ << ": Invalid gamma strategy " << ShapeToString(gamma_strategy); + return FAILED; + } + } + + size_t beta_diff = input_strategy.size() - beta_strategy.size(); + for (size_t k = 0; k < beta_strategy.size(); ++k) { + if (beta_strategy[k] != input_strategy[beta_diff + k]) { + MS_LOG(ERROR) << name_ << ": Invalid beta strategy " << ShapeToString(beta_strategy); + return FAILED; + } + } + return SUCCESS; +} + +Status LayerNormInfo::InferDevMatrixShape() { + if (strategy_ == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null"; + return FAILED; + } + std::vector stra = strategy_->GetInputDim(); + if (stra.empty()) { + MS_LOG(ERROR) << name_ << ": The strategy is empty"; + return FAILED; + } + dev_matrix_shape_ = stra[0]; + return SUCCESS; +} + +Status LayerNormInfo::CreateTensorMap(size_t input_index) { + if (inputs_shape_.size() <= input_index) { + MS_LOG(ERROR) << name_ << ": Invalid index" << input_index; + return FAILED; + } + Shape shape = inputs_shape_[input_index]; + Shape tensor_map; + for (size_t i = 0; i < shape.size(); ++i) { + tensor_map.push_back(SizeToInt(shape.size() - i - 1)); + } + inputs_tensor_map_.push_back(tensor_map); + outputs_tensor_map_.push_back(tensor_map); + return SUCCESS; +} + +Status LayerNormInfo::InferTensorMap() { + if ((CreateTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || + (CreateTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) { + MS_LOG(ERROR) << name_ << ": Create tensor map failed"; + return FAILED; + } + return SUCCESS; +} + +Status LayerNormInfo::CreateMirrorOp(size_t input_index) { + if (inputs_tensor_map_.size() <= input_index) { + MS_LOG(ERROR) << name_ << ": Invalid index " << input_index; + return FAILED; + } + Shape tensor_map = inputs_tensor_map_[input_index]; + std::vector group; + if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input " << input_index << " failed"; + return FAILED; + } + OperatorVector mirror_op; + if (!group.empty()) { + mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input " << input_index << " success, group is " + << group[0].name(); + } + mirror_ops_.push_back(mirror_op); + return SUCCESS; +} + +Status LayerNormInfo::InferMirrorOps() { + if ((CreateMirrorOp(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateMirrorOp(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || + (CreateMirrorOp(LAYER_NORM_BETA_INDEX) != SUCCESS)) { + MS_LOG(ERROR) << name_ << ": Create mirror op failed"; + return FAILED; + } + return SUCCESS; +} + +Status LayerNormInfo::CreateTensorInfo(size_t input_index) { + if ((inputs_shape_.size() <= input_index) || (inputs_tensor_map_.size() <= input_index)) { + MS_LOG(ERROR) << name_ << ": Invalid input index" << input_index; + return FAILED; + } + Shape tensor_map = inputs_tensor_map_[input_index]; + Shape shape = inputs_shape_[input_index]; + TensorLayout tensor_layout; + if (tensor_layout.InitFromVector(dev_matrix_shape_, tensor_map, shape) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init tensor layout for input " << input_index << " failed"; + return FAILED; + } + + TensorInfo tensor_info(tensor_layout); + inputs_tensor_info_.push_back(tensor_info); + outputs_tensor_info_.push_back(tensor_info); + return SUCCESS; +} + +Status LayerNormInfo::InferTensorInfo() { + if ((CreateTensorInfo(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorInfo(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || + (CreateTensorInfo(LAYER_NORM_BETA_INDEX) != SUCCESS)) { + MS_LOG(ERROR) << name_ << ": Create tensor info failed"; + return FAILED; + } + return SUCCESS; +} + +Status LayerNormInfo::InferAsLossDivisor() { + if (outputs_tensor_map_.size() != LAYER_NORM_INPUT_SIZE) { + MS_LOG(ERROR) << name_ << ": The size of outputs tensor map " << outputs_tensor_map_.size() << " is error"; + return FAILED; + } + as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); + MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_) + << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0]) + << ", as_loss_divisor_ is " << as_loss_divisor_; + return SUCCESS; +} + +Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Set cost failed"; + return FAILED; + } + return SUCCESS; +} + +Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector &sp_vector) { + if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) { + MS_LOG(ERROR) << name_ << ": The dimension of gamma or beta is lager than input"; + return FAILED; + } + + size_t gamma_diff = input_shape_.size() - gamma_shape_.size(); + size_t beta_diff = input_shape_.size() - beta_shape_.size(); + for (auto &sp : sp_vector) { + if ((sp == nullptr) || sp->GetInputDim().empty()) { + MS_LOG(ERROR) << name_ << ": Invalid strategy"; + return FAILED; + } + std::vector tmp_strategy; + Dimensions input_strategy = sp->GetInputDim()[0]; + Dimensions gamma_strategy = input_strategy; + (void)gamma_strategy.erase(gamma_strategy.begin(), + gamma_strategy.begin() + static_cast(gamma_diff)); + Dimensions beta_strategy = input_strategy; + (void)beta_strategy.erase(beta_strategy.begin(), beta_strategy.begin() + static_cast(beta_diff)); + + // reset the strategy + tmp_strategy.push_back(input_strategy); + tmp_strategy.push_back(gamma_strategy); + tmp_strategy.push_back(beta_strategy); + sp->ResetInputs(tmp_strategy); + } + return SUCCESS; +} + +Status LayerNormInfo::GenerateStrategies(int32_t stage_id) { + if (InitShapes() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init shapes failed"; + return FAILED; + } + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Get attrs failed"; + return FAILED; + } + Shape input_split(input_shape_.size(), SPLIT_FLAG); + if (begin_norm_axis_ >= input_split.size()) { + MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_; + return FAILED; + } + + // Can not split the dimensions from begin norm axis + for (size_t i = begin_norm_axis_; i < input_split.size(); ++i) { + input_split[i] = NO_SPLIT_FLAG; + } + + // Generate strategy for input + Shapes splittable_inputs = {input_split}; + Shapes tmp_inputs_shape = {input_shape_}; + std::vector sp_vector; + is_auto_parallel_ = true; + if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Generate input strategy failed"; + return FAILED; + } + + // Generate the strategies for gamma and beta + if (GenerateGammaAndBetaStrategies(sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Generate gamma and beta strategies failed"; + return FAILED; + } + + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(DEBUG) << name_ << ": Successfully generated " << success << " strategy"; + } + } + return SUCCESS; +} + +Status LayerNormInfo::InitShapes() { + if (inputs_shape_.size() != LAYER_NORM_INPUT_SIZE) { + MS_LOG(ERROR) << name_ << ": Invalid inputs size"; + return FAILED; + } + input_shape_ = inputs_shape_[LAYER_NORM_INPUT_INDEX]; + gamma_shape_ = inputs_shape_[LAYER_NORM_GAMMA_INDEX]; + beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX]; + return SUCCESS; +} + +Status LayerNormInfo::Init(const StrategyPtr &strategy) { + if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy)) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed"; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success"; + return SUCCESS; +} + +Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) { + if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) { + MS_LOG(ERROR) << name_ << ": Init for cost model failed"; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success"; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h new file mode 100644 index 0000000000..9ee11bb215 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ + +#include +#include +#include +#include +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +constexpr size_t LAYER_NORM_INPUT_SIZE = 3; +constexpr size_t LAYER_NORM_INPUT_INDEX = 0; +constexpr size_t LAYER_NORM_GAMMA_INDEX = 1; +constexpr size_t LAYER_NORM_BETA_INDEX = 2; +constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis"; + +// The dimensions of input tensor starting from begin norm axis cannot be split. Other dimensions can be split +// arbitrarily. Gamma and beta should match input to meet the broadcast requirements of mul and add. +class LayerNormInfo : public OperatorInfo { + public: + LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(true)), + begin_norm_axis_(0) {} + ~LayerNormInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + + protected: + Status GetAttrs() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferAsLossDivisor() override; + Status CreateTensorMap(size_t input_index); + Status CreateTensorInfo(size_t input_index); + Status CreateMirrorOp(size_t input_index); + Status GenerateGammaAndBetaStrategies(const std::vector &sp_vector); + Status InitShapes(); + + private: + size_t begin_norm_axis_; + Shape input_shape_; + Shape gamma_shape_; + Shape beta_shape_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc new file mode 100644 index 0000000000..889f204fb0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc @@ -0,0 +1,232 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/loss_info.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + Dimensions input_strategy = stra.at(0); + Dimensions label_strategy = stra.at(1); + if (input_strategy != label_strategy) { + MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; + return FAILED; + } + + int32_t axis_index = axis_; + if (axis_ < 0) { + size_t input_dim = inputs_shape_.at(0).size(); + axis_index = static_cast(input_dim) + axis_; + } + + int32_t input_axis_strategy = input_strategy.at(IntToSize(axis_index)); + int32_t label_axis_strategy = label_strategy.at(IntToSize(axis_index)); + // Dimension corresponding to axis is un-splittable + if ((input_axis_strategy != MIN_SLICE_NUM) && (label_axis_strategy != MIN_SLICE_NUM)) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ + << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy + << ", label: " << label_axis_strategy; + } else { + MS_LOG(ERROR) << name_ + << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy + << ", label: " << label_axis_strategy; + } + return FAILED; + } + + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() { + if ((inputs_shape_.size() != SoftmaxCrossEntropyWithLogitsInputsSize) || + (outputs_shape_.size() != SoftmaxCrossEntropyWithLogitsOutputsSize)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; + return FAILED; + } + + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + dev_matrix_shape_ = input_strategy; + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() { + std::vector tensor_map_index; + size_t size = inputs_shape_[0].size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back((int32_t)(size - i - 1)); + } + + std::vector first_output_tensor_map = {tensor_map_index[0]}; + inputs_tensor_map_.push_back(tensor_map_index); // input + inputs_tensor_map_.push_back(tensor_map_index); // label + outputs_tensor_map_.push_back(first_output_tensor_map); // output-0 + outputs_tensor_map_.push_back(tensor_map_index); // output-1 + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape first_output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = {{inputs_strategy[0][0]}, inputs_strategy.at(0)}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + Shape first_output_slice_shape = outputs_slice_shape.at(0); + + TensorMap input_tensor_map = inputs_tensor_map_.at(0); + TensorMap first_output_tensor_map = outputs_tensor_map_.at(0); + + TensorLayout input_tensor_layout, first_output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, input_tensor_map, input_shape) != SUCCESS) || + (first_output_tensor_layout.InitFromVector(dev_matrix_shape_, first_output_tensor_map, first_output_shape) != + SUCCESS)) { + return FAILED; + } + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + TensorInfo first_output_tensor_info(first_output_tensor_layout, first_output_shape, first_output_slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); // input + inputs_tensor_info_.push_back(input_tensor_info); // label + outputs_tensor_info_.push_back(first_output_tensor_info); // output-0 + outputs_tensor_info_.push_back(input_tensor_info); // output-1 + + return SUCCESS; +} + +// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload the function. +Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() { + if (outputs_tensor_map_.size() != 2) { + MS_LOG(ERROR) << name_ << " : The size of outputs tensor map " << outputs_tensor_map_.size() << " is error."; + return FAILED; + } + as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[1]); + MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_) + << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[1]) << ", as_loss_divisor_ is " + << as_loss_divisor_; + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +void SoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() { + for (size_t i = 0; i < inputs_shape_.size(); ++i) { + split_flag_list_[i] = true; + } +} + +Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << " : GetAttrs failed."; + return FAILED; + } + int32_t axis_index = axis_; + if (axis_ < 0) { + size_t input_dim = inputs_shape_[0].size(); + axis_index = static_cast(input_dim) + axis_; + } + is_auto_parallel_ = true; + + Shape input0_split; + (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); + input0_split[IntToSize(axis_index)] = 0; + Shapes splittable_inputs = {input0_split, input0_split}; + std::vector sp_vector; + if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies failed."; + return FAILED; + } + + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + PrintStrategy(strategy); + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h new file mode 100644 index 0000000000..7e5478bedf --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h @@ -0,0 +1,67 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +// infer shape: +// input_0 : [a, b], input_1 : [a, b] +// output_0 : [a], output_1: [a, b] +class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { + public: + SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, + std::make_shared(false)) {} + ~SoftmaxCrossEntropyWithLogitsInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + // There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload + // the InferAsLossDivisor. + Status InferAsLossDivisor() override; + + private: + int32_t axis_ = -1; // default -1 +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc new file mode 100644 index 0000000000..60a3d60b39 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc @@ -0,0 +1,647 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/matmul_info.h" + +#include +#include +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b, + Shape *dev_matrix_shape) { + MS_EXCEPTION_IF_NULL(dev_matrix_shape); + size_t mat_a_size = mat_a_strategy.size(); + size_t mat_b_size = mat_b_strategy.size(); + if (mat_a_size >= mat_b_size) { + // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32] + // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) + + // [2],[4] in the example above + for (size_t i = 0; i < SECOND_FROM_END(mat_a_size); ++i) { + dev_matrix_shape->push_back(mat_a_strategy.at(i)); + } + } else { + // for example: mat_a_strategy:[8,16], mat_b_strategy:[2,4,16,32] + // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) + + // [2],[4] in the example above + for (size_t i = 0; i < SECOND_FROM_END(mat_b_size); ++i) { + dev_matrix_shape->push_back(mat_b_strategy.at(i)); + } + } + + // [8],[16] in the example above + dev_matrix_shape->push_back(mat_a_strategy.at(SECOND_FROM_END(mat_a_size))); + dev_matrix_shape->push_back(mat_a_strategy.back()); + + // [32] in the example above + if (!transpose_b) { + dev_matrix_shape->push_back(mat_b_strategy.back()); + } else { + dev_matrix_shape->push_back(mat_b_strategy.at(SECOND_FROM_END(mat_b_size))); + } +} + +Status MatMulBase::GetAttrs() { + if (attrs_.size() < MATMUL_ATTRS_SIZE) { + MS_LOG(ERROR) << name_ << " : The size of attrs small than 2."; + return FAILED; + } + + auto transpose_a_iter = attrs_.find(TRANSPOSE_A); + if (transpose_a_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(transpose_a_iter->second); + if (transpose_a_iter->second->isa()) { + transpose_a_ = transpose_a_iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool."; + return FAILED; + } + } + + auto transpose_b_iter = attrs_.find(TRANSPOSE_B); + if (transpose_b_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(transpose_b_iter->second); + if (transpose_b_iter->second->isa()) { + transpose_b_ = transpose_b_iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool."; + return FAILED; + } + } + + auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER); + if (forward_reduce_scatter_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second); + if (forward_reduce_scatter_iter->second->isa()) { + forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool."; + return FAILED; + } + } + + // infer inputs dimension size + if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; + return FAILED; + } + mat_a_dimension_ = inputs_shape_.at(0).size(); + mat_b_dimension_ = inputs_shape_.at(1).size(); + + return SUCCESS; +} + +Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) { + size_t long_size = long_strategy.size(); + size_t short_size = short_strategy.size(); + if (long_size < short_size) { + MS_LOG(ERROR) << "Size error, the size of long strategy is " << long_size << ", the size of short strategy is " + << short_size; + return FAILED; + } + + size_t len_diff = long_size - short_size; + for (size_t j = 0; j < SECOND_FROM_END(short_size); ++j) { + if (long_strategy.at(len_diff + j) != short_strategy.at(j)) { + MS_LOG(ERROR) << "Strategies of relevant dimensions are not equal, long strategy is " + << ShapeToString(long_strategy) << ", short strategy is " << ShapeToString(short_strategy); + return FAILED; + } + } + + return SUCCESS; +} + +Status MatMul::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + Dimensions mat_a_strategy = stra.at(0); + Dimensions mat_b_strategy = stra.at(1); + + size_t mat_a_size = mat_a_strategy.size(); + size_t mat_b_size = mat_b_strategy.size(); + if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong."; + } else { + MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong."; + } + return FAILED; + } + + // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32] + // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) + // [16] in the example above + if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) { + MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; + return FAILED; + } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) { + MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; + return FAILED; + } + + if (mat_a_size >= mat_b_size) { + if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; + return FAILED; + } + } else { + if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; + return FAILED; + } + } + + if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) { + MS_LOG(WARNING) << name_ + << ": The dimension of mat a and mat b must be 2 in forward reduce scatter mode, " + "setting the forward reduce scatter mode to false here"; + forward_reduce_scatter_ = false; + } + + return SUCCESS; +} + +Status MatMulBase::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions mat_a_strategy = stra.at(0); + Dimensions mat_b_strategy = stra.at(1); + + SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_); + return SUCCESS; +} + +// all-reduce weight's grad +Status MatMulBase::InferMirrorOps() { + mirror_ops_.clear(); + + Shape mat_b_tensor_map = inputs_tensor_map_[1]; + std::vector mat_b_group; + if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) { + return FAILED; + } + + OperatorVector op_for_inputs; // op_for_inputs is empty + OperatorVector op_for_weight; + + if (mat_b_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror ops is empty."; + return SUCCESS; + } else { + op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum()); + mirror_ops_.push_back(op_for_inputs); + mirror_ops_.push_back(op_for_weight); + MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name(); + } + + return SUCCESS; +} + +Status MatMulBase::InferForwardCommunication() { + forward_op_.clear(); + size_t dimension = dev_matrix_shape_.size(); + size_t relevant_dimension_index = SECOND_FROM_END(dimension); + // Relevant dimension is not split and all reduce is not required + if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) { + MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; + return SUCCESS; + } + + std::vector group_list; + if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed."; + return FAILED; + } else if (group_list.empty()) { + MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; + return SUCCESS; + } + + Operator op; + if (forward_reduce_scatter_) { + op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name()); + } else { + op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name()); + } + + forward_op_.push_back(op); + MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name(); + return SUCCESS; +} + +Status MatMulBase::InferTensorMap() { + size_t size = dev_matrix_shape_.size(); + if (repeated_calc_num_ > 1) { + // move the first dimension(repeated_calc_num_), just for the convenience of tensor-map's calculation + size = dev_matrix_shape_.size() - 1; + } + + std::vector tensor_map_index; + // such as 5: tensor_map_index [4,3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); + } + + // infer output tensor map: [4,3,2,0], delete the second-from-end element + TensorMap output_tensor_map = tensor_map_index; + (void)output_tensor_map.erase(output_tensor_map.begin() + static_cast(SECOND_FROM_END(size))); + + // infer mat_a tensor map + // for example: mat_a_dimension is 4, mat_a tensor map:[4,3,2,1] + TensorMap mat_a_tensor_map = tensor_map_index; + // delete last one element + mat_a_tensor_map.pop_back(); + // delete the first (dev_matrix_size - 1 - mat_a_dimension) elements + (void)mat_a_tensor_map.erase( + mat_a_tensor_map.begin(), + mat_a_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_a_dimension_)); + + // infer mat_b tensor map + TensorMap mat_b_tensor_map = tensor_map_index; + // delete the third-to-last element + (void)mat_b_tensor_map.erase(mat_b_tensor_map.begin() + static_cast(THIRD_FROM_END(size))); + // delete the first (dev_matrix_size - 1 - mat_b_dimension) elements + (void)mat_b_tensor_map.erase( + mat_b_tensor_map.begin(), + mat_b_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_b_dimension_)); + if (transpose_b_) { + // swap the last two elements + int32_t last_value = mat_b_tensor_map.back(); + mat_b_tensor_map.pop_back(); + (void)mat_b_tensor_map.insert( + mat_b_tensor_map.begin() + static_cast(LAST_INDEX(mat_b_tensor_map.size())), last_value); + } + + if (forward_reduce_scatter_) { + if (dev_matrix_shape_.size() != 3) { + MS_LOG(WARNING) << name_ + << ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, " + "setting the forward reduce scatter mode to false here"; + forward_reduce_scatter_ = false; + } else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) { + MS_LOG(WARNING) << name_ + << ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in " + "forward reduce scatter mode, setting the forward reduce scatter mode to false here"; + forward_reduce_scatter_ = false; + } else { + // the forward reduce scatter only support that the dimension of output is 2 + output_tensor_map = {1, 0}; + } + } + + inputs_tensor_map_.push_back(mat_a_tensor_map); + inputs_tensor_map_.push_back(mat_b_tensor_map); + outputs_tensor_map_.push_back(output_tensor_map); + return SUCCESS; +} + +Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + Shape output_dev_matrix_shape; + if (forward_reduce_scatter_) { + if (dev_matrix_shape_.size() != 3) { + MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode"; + return FAILED; + } + output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]}; + } else { + output_dev_matrix_shape = dev_matrix_shape_; + } + + TensorLayout mat_a_layout, mat_b_layout, output_layout; + if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || + (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || + (output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) { + return FAILED; + } + + inputs_layout->push_back(mat_a_layout); + inputs_layout->push_back(mat_b_layout); + outputs_layout->push_back(output_layout); + return SUCCESS; +} + +Status MatMulBase::InferTensorInfo() { + // infer tensor layout + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + + TensorLayout mat_a_layout = inputs_layout.at(0); + TensorLayout mat_b_layout = inputs_layout.at(1); + TensorLayout output_layout = outputs_layout.at(0); + TensorInfo mat_a_tensor_info(mat_a_layout); + TensorInfo mat_b_tensor_info(mat_b_layout); + TensorInfo output_tensor_info(output_layout); + + inputs_tensor_info_.push_back(mat_a_tensor_info); + inputs_tensor_info_.push_back(mat_b_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status MatMulBase::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + + if (forward_reduce_scatter_) { + virtual_div_op_.clear(); + MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated calculation, clear the virtual div op"; + } + + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) { + if (input->size() < 2) { + MS_LOG(ERROR) << name_ << " : The size of inputs small than 2."; + return FAILED; + } + auto last_1st_value = input->at(input->size() - 1); + auto last_2nd_value = input->at(input->size() - 2); + input->pop_back(); + input->pop_back(); + input->push_back(last_1st_value); + input->push_back(last_2nd_value); + return SUCCESS; +} + +Status MatMulBase::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << " : GetAttrs failed."; + return FAILED; + } + CheckGlobalDeviceManager(); + std::vector dev_list = g_device_manager->GetDeviceListByStageId(stage_id); + size_t dev_num = dev_list.size(); + Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1]; + if (transpose_a_) { + if (SwapLastTwoElements(&input0_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + } + if (transpose_b_) { + if (SwapLastTwoElements(&input1_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + } + // The shape of input0 (input1) + // E.g., input0 = [100, 200, 300], input1 = [300, 400] + + // Combining the input0_shape and input1_shape + // E.g., combined_shape = [100, 200, 300, 400] + is_auto_parallel_ = true; + size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size(); + Dimensions combined_partitions; + Shape combined_shape; + // In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2 + if (input0_shape.size() >= input1_shape.size()) { + combined_shape = input0_shape; + combined_shape.push_back(input1_shape[input1_shape.size() - 1]); + } else { + combined_shape = input1_shape; + combined_shape.push_back(input0_shape[input0_shape.size() - 2]); + } + std::function recursive = [&stage_id, &dev_num, &combined_partitions, &combined_shape, + &input1_shape_size, &recursive, &input0_shape_size, + this](uint32_t current_index, size_t n) { + // Finishing the recursive steps, if the strategy is valid, then calculate the cost + // for this operator under the strategy. + if (current_index == combined_shape.size()) { + StrategyPtr sp; + if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) == + FAILED) { + return; + } + if (this->SetCostUnderStrategy(sp) == FAILED) { + MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed."; + return; + } + } else { + MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size + << ", input1_shape_size: " << input1_shape_size; + for (uint32_t i = 1; i <= n; i *= 2) { + if (n % i == 0 && IntToSize(combined_shape[current_index]) % i == 0) { + combined_partitions.push_back(i); + recursive(current_index + 1, n / i); + combined_partitions.pop_back(); + } + } + } + }; + recursive(0, dev_num); + if (strategy_cost_.empty()) { + MS_LOG(EXCEPTION) << name_ << " : No available strategy."; + } + return Status::SUCCESS; +} + +Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, + mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, + size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { + int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); + if (!FULLY_USE_DEVICES) { + if (IntToSize(product) > dev_num) { + return FAILED; + } + } else { + if (IntToSize(product) != dev_num) { + return FAILED; + } + } + Dimensions input0_partitions, input1_partitions; + if (input0_shape_size >= input1_shape_size) { + for (size_t i = 0; i < input0_shape_size; ++i) { + input0_partitions.push_back(combined_partitions[i]); + } + if (input1_shape_size == 2) { + input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]); + input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]); + } else { + // input1_shape.size() > 2 + for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) { + if (j == combined_partitions.size() - 3) { + continue; + } + input1_partitions.push_back(combined_partitions[j]); + } + } + } else { + for (size_t i = 0; i < input1_shape_size; ++i) { + input1_partitions.push_back(combined_partitions[i]); + } + for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) { + input0_partitions.push_back(combined_partitions[j]); + } + input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]); + input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]); + } + if (transpose_a_) { + if (SwapLastTwoElements(&input0_partitions) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + } + if (transpose_b_) { + if (SwapLastTwoElements(&input1_partitions) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + } + std::vector stras; + stras.push_back(input0_partitions); + stras.push_back(input1_partitions); + (*sp) = std::make_shared(stage_id, stras); + + return SUCCESS; +} + +void MatMulBase::InitTensorInfoForCost(std::vector *relica_inputs_tensor_vector) { + TensorLayout tly; + if (transpose_a_) { + Shape replica_input0_shape(inputs_tensor_info_[0].shape()); + Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape()); + if (SwapLastTwoElements(&replica_input0_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + + TensorInfo replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape); + relica_inputs_tensor_vector->push_back(replica_input0_info); + } else { + relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]); + } + if (transpose_b_) { + Shape replica_input1_shape(inputs_tensor_info_[1].shape()); + Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape()); + if (SwapLastTwoElements(&replica_input1_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + + TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape); + relica_inputs_tensor_vector->push_back(replica_input1_info); + } else { + relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]); + } +} + +Status MatMulBase::CheckForTensorSliceValid() const { + if (!TENSOR_SLICE_ALIGNMENT_ENABLE) { + return SUCCESS; + } + if (inputs_tensor_info_.empty()) { + return FAILED; + } + for (auto &one_input_tensor : inputs_tensor_info_) { + auto slice_shape = one_input_tensor.slice_shape(); + if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || + (IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { + return FAILED; + } + } + return SUCCESS; +} + +Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { + if (InitForCostModel(strategy) == FAILED) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Initialization under the strategy failed."; + } + return FAILED; + } + PrintStrategy(strategy); + // Check whether the tensor slice of input_tensor_info is valid or not + if (CheckForTensorSliceValid() != SUCCESS) { + MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy."; + return FAILED; + } + // Here, a replicated inputs_ is constructed for the transposed TensorInfo. + std::vector relica_inputs_tensor_vector; + InitTensorInfoForCost(&relica_inputs_tensor_vector); + + int32_t stage_id = strategy->GetInputStage(); + // Here, we use the origin outputs_, because we only use the slice size of the output tensor. + // It does not matter whether the output tensor is transposed or not. + double computation_cost = + operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); + double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); + std::shared_ptr result = std::make_shared(computation_cost, communication_cost); + result->communication_without_parameter_ = + operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); + result->communication_with_partial_para_ = + result->communication_without_parameter_ + + COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); + + // Breaking ties for preferring data parallelization + BreakingTiesForPerferringDataParallel(strategy, result); + MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_ + << ", communication_cost: " << result->communication_cost_ + << ", communication_without_parameter_: " << result->communication_without_parameter_ + << ", communication_with_partial_para_: " << result->communication_with_partial_para_; + // refine communication cost calculation for practice + RefineForPracticalCost(result, false); + result->communication_forward_ = result->communication_without_parameter_; + + std::shared_ptr swc = + std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); + swc->cost_list.push_back(result); + strategy_cost_.emplace_back(swc); + + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h new file mode 100644 index 0000000000..d4e144c2b6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h @@ -0,0 +1,96 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ + +#include +#include +#include +#include + +#include "common/utils.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class MatMulBase : public OperatorInfo { + public: + MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~MatMulBase() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + // Generate all strategies and the corresponding cost for this MatMul operator + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size, + size_t input1_shape_size, StrategyPtr *sp); + + Status SwapLastTwoElements(Shape *shape); + + protected: + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + void InitTensorInfoForCost(std::vector *); + Status CheckForTensorSliceValid() const; + Status GetAttrs() override; + + bool transpose_a_ = false; + bool transpose_b_ = false; + bool forward_reduce_scatter_ = false; + size_t mat_a_dimension_ = 0; + size_t mat_b_dimension_ = 0; +}; + +class MatMul : public MatMulBase { + public: + MatMul(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : MatMulBase(name, inputs_shape, outputs_shape, attrs) {} + ~MatMul() override = default; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; +}; + +class MatMulInfo : public MatMul { + public: + MatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : MatMul(name, inputs_shape, outputs_shape, attrs) {} + ~MatMulInfo() override = default; +}; + +class BatchMatMulInfo : public MatMul { + public: + BatchMatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : MatMul(name, inputs_shape, outputs_shape, attrs) {} + ~BatchMatMulInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc new file mode 100644 index 0000000000..15acb085f5 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc @@ -0,0 +1,311 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/onehot_info.h" + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/strategy.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status OneHotInfo::GetAttrs() { + auto iter = attrs_.find(AXIS); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis_value_ptr_ = iter->second; + axis_ = iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << ": The value of axis is not int."; + return FAILED; + } + } + + if (inputs_shape_[0].size() != 1) { + MS_LOG(ERROR) << name_ << ": Input's shape only support 1-D now."; + return FAILED; + } + + if ((axis_ > 1) || (axis_ < -1)) { + MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range[-1, 1]."; + return FAILED; + } + return SUCCESS; +} + +Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { + if (inputs_shape_.size() != 3) { + MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != 1) { + MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size(); + return FAILED; + } + if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}, + is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + return SUCCESS; +} + +Status OneHotInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + + // Now input only support 1-D tensor, so the output is a 2-D tensor + // If input is a vector of length features, the output shape will be: + // [features, depth] if axis == -1 (or axis == 1) + // [depth, features] if axis == 0 + if (axis_ == 0) { + dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable + dev_matrix_shape_.push_back(input_strategy[0]); // the features is splittable + } else { + dev_matrix_shape_.push_back(input_strategy[0]); // the features is splittable + dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable + } + + return SUCCESS; +} + +Status OneHotInfo::InferTensorMap() { + std::vector input_tensor_map_index, output_tensor_map_index; + size_t size = outputs_shape_[0].size(); + // such as 2: tensor_map_index [1,0] + if (axis_ == 0) { + for (size_t i = 0; i < size; ++i) { + output_tensor_map_index.push_back((int32_t)(i)); + } + } else { + for (size_t i = 0; i < size; ++i) { + output_tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); + } + } + outputs_tensor_map_.push_back(output_tensor_map_index); + + // Now input only support 1-D tensor + input_tensor_map_index.push_back(1); + + inputs_tensor_map_.push_back(input_tensor_map_index); + return SUCCESS; +} + +// axis = -1 +// (0,(1,16),(),())reid dev_matrix=(1,16) map_in=(1) map_out=(1,0) +// (0,(16,1),(),())data parallel dev_matrix=(16,1) map_in=(1) map_out=(1,0) +// (0,(2,8),(),())16 devices two machines,model parallel among devices in the same machine,data parallel between +// machines dev_matrix=(2,8) map_in=(1) map_out=(1,0) (0, (2,4),(),())16 devices dev_matrix=(2,4,2) map_in=(1) +// map_out=(1,0) +// axis = 0 +// (0, (16,1),(),())reid dev_matrix=(1,16) map_in=(1) map_out=(0,1) +// (0, (1,16),(),())data parallel dev_matrix=(16,1) map_in=(1) map_out=(0,1) +// (0, (8,2),(),())16 devices two machines,model parallel among devices in the same machine,data parallel between +// machines dev_matrix=(2,8) map_in=(1) map_out=(0,1) (0,(4,2),(),())16 devices dev_matrix=(2,4,2) map_in=(1) +// map_out=(0,1) +Status OneHotInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape output_shape = outputs_shape_.at(0); + + TensorLayout input_tensor_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { + return FAILED; + } + + TensorInfo input_tensor_info(input_tensor_layout); + TensorInfo output_tensor_info(output_tensor_layout); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + + return SUCCESS; +} + +Status OneHotInfo::ExtractInputInfo() { + CheckGlobalDeviceManager(); + rank_ = g_device_manager->global_rank(); + mod_rank_ = rank_ % dev_matrix_shape_.back(); + if (!cnode_) { + MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr"; + return FAILED; + } + if (cnode_->inputs().size() != 5) { + MS_LOG(ERROR) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive, real input size is " + << cnode_->inputs().size(); + return FAILED; + } + if (input_value_.size() != 4) { + MS_LOG(ERROR) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive, and input value size " + "must be 4, real size is " + << input_value_.size(); + return FAILED; + } + auto value_ptr = input_value_.at(1); + if (value_ptr == nullptr) { + MS_LOG(WARNING) << "Input 2 of cnode is not a value node, its type is " << cnode_->input(2)->type_name(); + return FAILED; + } + + if (value_ptr->isa()) { + total_class_number_ = value_ptr->cast()->value(); + } else { + MS_LOG(ERROR) << "OneHot Primitive depth type must be int"; + return FAILED; + } + classes_each_device_ = total_class_number_ / dev_matrix_shape_.back(); + + return SUCCESS; +} + +Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + if (dev_matrix_shape_.back() == 1) { + replace_graph_ = nullptr; + return SUCCESS; + } + if (ExtractInputInfo() != SUCCESS) { + MS_LOG(ERROR) << "ExtractInputInfo failed"; + return FAILED; + } + GenerateGraph gen_g = GenerateGraph(); + Status status = gen_g.Init(cnode); + if (status != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + + auto floor_div = + gen_g.PushBack({gen_g.NewOpInst(FLOORDIV), gen_g.virtual_input_node(), CreateInt32Tensor(classes_each_device_)}); + auto mul1 = gen_g.PushBack({gen_g.NewOpInst(MUL), floor_div, CreateInt32Tensor(classes_each_device_)}); + auto sub1 = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), mul1}); + auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)}); + auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)}); + auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast}); + auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(TENSOR_ADD), mul2, CreateInt32Tensor(1)}); + auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add}); + auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)}); + Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_); + OperatorAttrs attrs_onehot = {attr_onehot_axis}; + auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt32Imm(classes_each_device_), + cnode->input(3), cnode->input(4)}); + std::vector> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, onehot)); + + return SUCCESS; +} + +ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { + if (ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; + return nullptr; + } + return replace_graph_; +} + +Status OneHotInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + Status status = ComputeReplaceGraph(cnode_); + if (status != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; + return status; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status OneHotInfo::GenerateStrategies(int32_t stage_id) { + Shapes splittable_inputs = {{1, 1}, {}, {}}; + std::vector sp_vector; + if (inputs_shape_.size() != 3) { + MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != 1) { + MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size(); + return FAILED; + } + is_auto_parallel_ = true; + if (GenerateStrategiesForIndependentInputs(stage_id, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}, + splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategies failed."; + return FAILED; + } + + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + + return SUCCESS; +} + +Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +std::shared_ptr>> OneHotInfo::GenerateBatchStrategies() { + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + Dimensions strategy = {SizeToInt(dev_num), 1}; + Dimensions empty_strategy; + std::vector strategy_v = {strategy, empty_strategy, empty_strategy}; + return std::make_shared>>(strategy_v); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h new file mode 100644 index 0000000000..dfd7e6cbaf --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class OneHotInfo : public OperatorInfo { + public: + OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~OneHotInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; + std::shared_ptr>> GenerateBatchStrategies() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status ExtractInputInfo(); + + private: + Status ComputeReplaceGraph(const CNodePtr &cnode); + + int axis_ = -1; + int32_t rank_ = 0; + int32_t total_class_number_ = 1; + int32_t classes_each_device_ = 1; + ValuePtr axis_value_ptr_; + int32_t mod_rank_ = 0; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc new file mode 100644 index 0000000000..3dd47b1de6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -0,0 +1,1334 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/operator_info.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ir/dtype.h" +#include "ir/tensor.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/context.h" +#include "utils/context/ms_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) { + if (strategy == nullptr) { + MS_LOG(ERROR) << "The strategy is null."; + return FAILED; + } + + size_t strategy_size = strategy->GetInputNumber(); + size_t inputs_shape_size = inputs_shape.size(); + if (strategy_size != inputs_shape_size) { + if (is_auto_parallel) { + MS_LOG(DEBUG) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; + } else { + MS_LOG(ERROR) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; + } + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + for (size_t i = 0; i < strategy_size; ++i) { + Shape sub_strategy = stra.at(i); + Shape sub_input_shape = inputs_shape.at(i); + size_t strategy_len = sub_strategy.size(); + size_t inputs_len = sub_input_shape.size(); + if (strategy_len != inputs_len) { + if (is_auto_parallel) { + MS_LOG(DEBUG) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len + << ", index: " << i; + } else { + MS_LOG(ERROR) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len + << ", index: " << i; + } + return FAILED; + } + + for (size_t j = 0; j < strategy_len; ++j) { + int32_t strategy_value = sub_strategy.at(j); + if (strategy_value < MIN_SLICE_NUM) { + if (is_auto_parallel) { + MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value; + } else { + MS_LOG(ERROR) << "Invalid strategy value: " << strategy_value; + } + return FAILED; + } + + if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) { + if (is_auto_parallel) { + MS_LOG(DEBUG) << "Invalid Strategy value it is not the power of 2, " << strategy_value; + } else { + MS_LOG(ERROR) << "Invalid Strategy value it is not the power of 2, " << strategy_value; + } + return FAILED; + } + + int32_t shape_value = sub_input_shape.at(j); + if ((shape_value % strategy_value) != 0) { + if (is_auto_parallel) { + MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; + } else { + MS_LOG(ERROR) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; + } + return FAILED; + } + } + } + + return SUCCESS; +} + +void OperatorInfo::ResetQueueMember() { + inputs_tensor_info_.clear(); + outputs_tensor_info_.clear(); + inputs_tensor_map_.clear(); + outputs_tensor_map_.clear(); + dev_matrix_shape_.clear(); + forward_op_.clear(); + mirror_ops_.clear(); + sub_ops_.clear(); + replace_op_.clear(); + replace_op_info_.clear(); + virtual_div_op_.clear(); + global_device_list_.clear(); +} + +Status OperatorInfo::InferAttrs() { + if (infer_attrs_completed_) { + return SUCCESS; + } + + if (GetAttrs() != SUCCESS) { + return FAILED; + } + infer_attrs_completed_ = true; + return SUCCESS; +} + +void OperatorInfo::SetDeviceListByStrategy() { + int32_t stage = strategy_->GetInputStage(); + CheckGlobalDeviceManager(); + global_device_list_ = g_device_manager->GetDeviceListByStageId(stage); +} + +Status OperatorInfo::InferRepeatedCalcInfo() { + int32_t g_dev_list_size = SizeToInt(global_device_list_.size()); + int32_t dev_matrix_size = + std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); + if (dev_matrix_size == 0) { + MS_LOG(ERROR) << name_ << ": The dev matrix size is 0"; + return FAILED; + } + + if (g_dev_list_size == dev_matrix_size) { + repeated_calc_num_ = 1; + } else if (g_dev_list_size % dev_matrix_size == 0) { + repeated_calc_num_ = g_dev_list_size / dev_matrix_size; + } else { + MS_LOG(ERROR) << name_ << ": Dev list size " << g_dev_list_size << " can not be divisible by dev matrix size " + << dev_matrix_size; + return FAILED; + } + + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + int32_t stage = strategy_->GetInputStage(); + local_device_list_ = g_device_manager->global_device_list(stage, rank, repeated_calc_num_); + + return SUCCESS; +} + +// if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix, +// only use for infer tensor layout +void OperatorInfo::SetRepeatedCalcDevMatrix() { + if (repeated_calc_num_ <= 1) { + return; + } + + (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_); +} + +// use for loss repeated calculation +Operator CreateVirtualDivOp(int32_t div_num) { + OperatorName operator_name = VIRTUAL_DIV; + ValuePtr attr0_value = MakeValue(div_num); + Attr attr0 = std::make_pair(DIVISOR, attr0_value); + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + + OperatorParams operator_param; + OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_arg); + return op; +} + +// use for forward all reduce +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) { + OperatorName operator_name = ALL_REDUCE; + ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM + ValuePtr attr1_value = MakeValue(group); // group + Attr attr0 = std::make_pair(OP, attr0_value); + Attr attr1 = std::make_pair(GROUP, attr1_value); + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + operator_attrs.push_back(attr1); + + OperatorParams operator_param; + OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_arg); + MS_LOG(INFO) << "Create all reduce op success, the reduce_op is " << reduce_op << ", the group is " << group; + return op; +} + +Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) { + OperatorName operator_name = REDUCE_SCATTER; + ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM + ValuePtr attr1_value = MakeValue(group); // group + Attr attr0 = std::make_pair(OP, attr0_value); + Attr attr1 = std::make_pair(GROUP, attr1_value); + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + operator_attrs.push_back(attr1); + + OperatorParams operator_param; + OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_arg); + MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group; + return op; +} + +// use for get tensor slice +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { + Shape tensor_map = tensor_layout.tensor_map().array(); + Shape dev_matrix_shape = tensor_layout.device_arrangement().array(); + OperatorName operator_name = GET_TENSOR_SLICE; + + OperatorAttrs attrs; + ValuePtr dev_mat_value = MakeValue(dev_matrix_shape); + Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2); + ValuePtr tensor_map_value = MakeValue(tensor_map); + Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3); + OperatorParams params = {dev_mat_param, tensor_map_param}; + OperatorArgs operator_arg = std::make_pair(attrs, params); + + Operator op = std::make_pair(operator_name, operator_arg); + MS_LOG(INFO) << "Create get tensor slice op success, the dev mat and tensor map is " + << ShapeToString(dev_matrix_shape) << ", " << ShapeToString(tensor_map); + return op; +} + +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { + if ((dev_num == 0) || (dev_num == 1)) { + MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num; + } + OperatorVector op_for_weight; + bool mean_flag = ParallelContext::GetInstance()->mirror_mean(); + + OperatorName operator_name = MIRROR_OPERATOR; + ValuePtr attr0_value = MakeValue(group_name); + ValuePtr attr1_value = MakeValue(SizeToInt(dev_num)); + ValuePtr attr2_value = MakeValue(mean_flag); + + Attr attr0 = std::make_pair(GROUP, attr0_value); + Attr attr1 = std::make_pair(DEV_NUM, attr1_value); + Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value); + + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + operator_attrs.push_back(attr1); + operator_attrs.push_back(attr2); + + OperatorParams operator_param; + OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_args); + + op_for_weight.push_back(op); + MS_LOG(INFO) << "The group name is " << group_name << ", the dev num is " << dev_num << ", the mean flag is " + << mean_flag; + return op_for_weight; +} + +Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group) { + if (group == nullptr) { + MS_LOG(ERROR) << "The group is null."; + return FAILED; + } + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_); + RankList group_devices; + if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) { + return FAILED; + } + + if (group_devices.size() == 1) { + MS_LOG(INFO) << "The dev size is 1, no need to create group."; + return SUCCESS; + } + + Group g = g_device_manager->CreateGroup(group_devices); + group->push_back(g); + return SUCCESS; +} + +Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector *group) { + if (group == nullptr) { + MS_LOG(ERROR) << "The group is null."; + return FAILED; + } + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_); + RankList group_devices; + if (dev_matrix.GetDevicesAlongDim(SizeToUint(axis), &group_devices) != SUCCESS) { + return FAILED; + } + + if (group_devices.size() == 1) { + MS_LOG(INFO) << "The dev size is 1, no need to create group."; + return SUCCESS; + } + + Group g = g_device_manager->CreateGroup(group_devices); + group->push_back(g); + return SUCCESS; +} + +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { + Shape slice_shape; + if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { + MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; + return slice_shape; + } + for (size_t i = 0; i < strategy.size(); ++i) { + slice_shape.push_back(tensor_shape.at(i) / strategy.at(i)); + } + return slice_shape; +} + +Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) { + if (slice_shapes == nullptr) { + MS_LOG(ERROR) << "The slice_shapes is null."; + return FAILED; + } + if (strategys.size() != shapes.size()) { + MS_LOG(ERROR) << "Strategy size " << strategys.size() << " not equal to shape size " << shapes.size(); + return FAILED; + } + + for (size_t i = 0; i < strategys.size(); ++i) { + if (strategys.at(i).size() != shapes.at(i).size()) { + MS_LOG(ERROR) << "Strategy dimension " << strategys.at(i).size() << " not equal to shape dimension " + << shapes.at(i).size(); + slice_shapes->clear(); + return FAILED; + } + + for (size_t j = 0; j < shapes.at(i).size(); ++j) { + if (strategys.at(i).at(j) <= 0) { + MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategys[i]) + << " the element is less than or equal to 0."; + slice_shapes->clear(); + return FAILED; + } + if (shapes.at(i).at(j) % strategys.at(i).at(j) != 0) { + MS_LOG(ERROR) << "Shape cannot be divisible by strategy, " << shapes.at(i).at(j) << " : " + << strategys.at(i).at(j); + slice_shapes->clear(); + return FAILED; + } + } + Shape slice_shape = GetSliceShape(shapes.at(i), strategys.at(i)); + slice_shapes->push_back(slice_shape); + } + + return SUCCESS; +} + +Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) { + if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) { + MS_LOG(ERROR) << "The slice_shape is null."; + return FAILED; + } + + if (InferSliceShapeByStrategy(inputs_strategy, inputs_shape_, inputs_slice_shape) != SUCCESS) { + MS_LOG(ERROR) << "Infer inputs slice shape error."; + return FAILED; + } + + if (InferSliceShapeByStrategy(outputs_strategy, outputs_shape_, outputs_slice_shape) != SUCCESS) { + MS_LOG(ERROR) << "Infer outputs slice shape error."; + inputs_slice_shape->clear(); + return FAILED; + } + + return SUCCESS; +} + +// method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1 +Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null."; + return FAILED; + } + + if (InferAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferAttrs failed."; + return FAILED; + } + + // must be after InferAttrs() + if (CheckStrategy(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": CheckStrategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; + } + return FAILED; + } + + // need to clear queues before Init(), + // because Init() may be called multiple times by cost model + ResetQueueMember(); + + strategy_ = strategy; + SetDeviceListByStrategy(); + + if (InferDevMatrixShape() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; + return FAILED; + } + + used_devices_ = std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); + + // must be after InferDevMatrixShape + if (InferRepeatedCalcInfo() != SUCCESS) { + MS_LOG(ERROR) << ": InferRepeatedCalcInfo failed."; + return FAILED; + } + + // if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix for layout + SetRepeatedCalcDevMatrix(); + + if (InferTensorMap() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; + return FAILED; + } + + if (InferTensorInfo() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; + return FAILED; + } + + return SUCCESS; +} + +// method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape +Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null."; + return FAILED; + } + + if (InferAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferAttrs failed."; + return FAILED; + } + + // must be after InferAttrs() + if (CheckStrategy(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; + return FAILED; + } + + // need to clear queues before Init(), + // because Init() may be called multiple times by cost model + ResetQueueMember(); + + strategy_ = strategy; + SetDeviceListByStrategy(); + + if (InferDevMatrixShape() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; + return FAILED; + } + + // must be after InferDevMatrixShape + if (InferRepeatedCalcInfo() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed."; + return FAILED; + } + + if (InferTensorMap() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; + return FAILED; + } + + if (InferTensorInfo() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; + return FAILED; + } + + return SUCCESS; +} + +Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null."; + return FAILED; + } + + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + return FAILED; + } + + if (InferForwardCommunication() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; + return FAILED; + } + + if (InferMirrorOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; + return FAILED; + } + + if (InferVirtualDivOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; + return FAILED; + } + + return SUCCESS; +} + +Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null."; + return FAILED; + } + + if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { + return FAILED; + } + + if (InferForwardCommunication() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; + return FAILED; + } + + if (InferMirrorOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; + return FAILED; + } + + if (InferVirtualDivOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; + return FAILED; + } + + return SUCCESS; +} + +std::vector> OperatorInfo::GetAliveSuccEdges() { + std::vector> ret; + for (auto &edge : succ_edges_) { + if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) { + ret.push_back(edge); + } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) { + // CAST is ordered in front of L2NORMALIZE + ret.push_back(edge); + } + } + for (auto &edge : succ_edges_) { + if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) && + (edge->next_operator()->name().find(CAST) == std::string::npos)) { + ret.push_back(edge); + } + } + return ret; +} + +std::vector> OperatorInfo::GetAlivePrevEdges() { + std::vector> ret; + for (auto &edge : prev_edges_) { + if (edge->prev_operator()->is_alive()) { + ret.push_back(edge); + } + } + return ret; +} + +void OperatorInfo::ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { + if (op == nullptr) { + MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null."; + return; + } + for (auto &edge : prev_edges_) { + if (edge->prev_operator() == op) { + edge = new_edge; + return; + } + } + MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; +} + +void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { + if (op == nullptr) { + MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null."; + return; + } + for (auto &edge : succ_edges_) { + if (edge->next_operator() == op) { + edge = new_edge; + return; + } + } + MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; +} + +void OperatorInfo::ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { + if (op == nullptr) { + MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null."; + return; + } + std::vector> new_pre_edges; + for (auto &edge : prev_edges_) { + if (edge->prev_operator() != op) { + new_pre_edges.push_back(edge); + } + } + new_pre_edges.push_back(new_edge); + prev_edges_ = new_pre_edges; +} + +void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { + if (op == nullptr) { + MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null"; + return; + } + std::vector> new_succ_edges; + for (auto &edge : succ_edges_) { + if (edge->next_operator() != op) { + new_succ_edges.push_back(edge); + } + } + new_succ_edges.push_back(new_edge); + succ_edges_ = new_succ_edges; +} + +std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( + const Shapes &shapes, const std::vector &split_flag_list) { + if (shapes.size() != split_flag_list.size()) { + MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : " + << shapes.size(); + return nullptr; + } + CheckGlobalDeviceManager(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); + std::vector> strategy_v; + for (size_t i = 0; i != shapes.size(); i++) { + if (shapes[i].empty()) { + MS_LOG(INFO) << "Elements of shapes is empty."; + std::vector empty_element; + strategy_v.push_back(empty_element); + } else { + std::vector element(shapes[i].size(), 1); + if (split_flag_list[i]) { + element[0] = dev_num; + } + strategy_v.push_back(element); + } + } + return std::make_shared>>(strategy_v); +} + +void OperatorInfo::ReComputeBatchSplitFlagList() { + if (!inputs_shape_.empty()) { + split_flag_list_[0] = true; + } +} + +void OperatorInfo::ComputeBatchSplitFlagList() { + split_flag_list_.clear(); + for (auto iter = inputs_shape_.begin(); iter != inputs_shape_.end(); ++iter) { + split_flag_list_.push_back(false); + } + ReComputeBatchSplitFlagList(); +} + +// This is a common method for checking whether the generated stragegy has the correct number of devuces. +Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) { + if (sp == nullptr) { + MS_LOG(ERROR) << "The strategy is null."; + return FAILED; + } + int32_t product = 1; + + for (auto &input_partition : inputs_partitions) { + product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); + } + if (!FULLY_USE_DEVICES) { + if (IntToSize(product) > dev_num) { + return FAILED; + } + } else { + if ((product != 1) && (IntToSize(product) != dev_num)) { + return FAILED; + } + } + std::vector stras(inputs_partitions); + (*sp) = std::make_shared(stage_id, stras); + return SUCCESS; +} + +std::shared_ptr>> OperatorInfo::GenerateBatchStrategies() { + ComputeBatchSplitFlagList(); + return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); +} + +void PrintStrategy(const StrategyPtr &strategy) { + if (strategy == nullptr) { + return; + } + std::string all_strategy = ""; + for (size_t i = 0; i < strategy->GetInputNumber(); ++i) { + all_strategy += "["; + for (size_t j = 0; j < strategy->GetInputDim()[i].size(); ++j) { + all_strategy += std::to_string(strategy->GetInputDim()[i][j]); + if (j != strategy->GetInputDim()[i].size() - 1) { + all_strategy += ", "; + } + } + all_strategy += "]"; + if (i != strategy->GetInputNumber() - 1) { + all_strategy += ", "; + } + } + MS_LOG(INFO) << "The strategy is: " << all_strategy; +} + +// generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d]) +Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *const sp_vector) { + if (sp_vector == nullptr) { + MS_LOG(ERROR) << "The sp_vector is null."; + return FAILED; + } + + if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) { + MS_LOG(ERROR) << "The inputs size is wrong."; + return FAILED; + } + + if ((inputs_shape[0].size() != inputs_shape[1].size()) || + (splittable_inputs[0].size() != splittable_inputs[1].size())) { + MS_LOG(ERROR) << "The size of two inputs are not equal."; + return FAILED; + } + + Shapes input0_shape = {inputs_shape[0]}; + Shapes input0_splittable = {splittable_inputs[0]}; + if (GenerateStrategiesForIndependentInputs(stage_id, input0_shape, input0_splittable, sp_vector) != SUCCESS) { + return FAILED; + } + + for (auto &sp : *sp_vector) { + sp->ExpandInputDimFromOneToTwo(); + } + + return SUCCESS; +} + +// generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast +// such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) +Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { + if (sp_vector == nullptr) { + MS_LOG(ERROR) << "The sp_vector is null."; + return FAILED; + } + + if (inputs_shape[0].size() >= inputs_shape[1].size()) { + MS_LOG(ERROR) << "Invalid inputs shape."; + return FAILED; + } + + // first, generate strategy for input0 the same as input1 + Shapes tmp_inputs_shape = {inputs_shape[1], inputs_shape[1]}; + Shapes tmp_splittable_inputs = {splittable_inputs[1], splittable_inputs[1]}; + if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; + return FAILED; + } + + // second, get the correct strategy for input0 + for (auto &sp : *sp_vector) { + std::vector tmp_strategy; + Dimensions input0_strategy = sp->GetInputDim()[0]; + size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); + + // erase the unnecessary part + (void)input0_strategy.erase(input0_strategy.begin(), + input0_strategy.begin() + static_cast(size_diff)); + + // handel the case likes ([1, c, d], [a, b, c, d]) + for (size_t i = 0; i < inputs_shape[0].size(); ++i) { + if (inputs_shape[0][i] == 1) { + input0_strategy[i] = 1; + } else { + break; + } + } + + // reset the strategy + tmp_strategy.push_back(input0_strategy); // input0 + tmp_strategy.push_back(sp->GetInputDim()[1]); // input1 + sp->ResetInputs(tmp_strategy); + } + return SUCCESS; +} + +// generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast +// such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) +Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *const sp_vector) { + if (sp_vector == nullptr) { + MS_LOG(ERROR) << "The sp_vector is null."; + return FAILED; + } + + if (inputs_shape[0].size() <= inputs_shape[1].size()) { + MS_LOG(ERROR) << "Invalid inputs shape."; + return FAILED; + } + + // first, generate strategy for input1 the same as input0 + Shapes tmp_inputs_shape = {inputs_shape[0], inputs_shape[0]}; + Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]}; + if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; + return FAILED; + } + + // second, get the correct strategy for input1 + for (auto &sp : *sp_vector) { + std::vector tmp_strategy; + tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 + + Dimensions input1_strategy = sp->GetInputDim()[1]; + size_t size_diff = inputs_shape[0].size() - inputs_shape[1].size(); + + // erase the unnecessary part + (void)input1_strategy.erase(input1_strategy.begin(), + input1_strategy.begin() + static_cast(size_diff)); + + // handel the case likes ([a, b, c, d], [1, c, d]) + for (size_t i = 0; i < inputs_shape[1].size(); ++i) { + if (inputs_shape[1][i] == 1) { + input1_strategy[i] = 1; + } else { + break; + } + } + + // reset the strategy + tmp_strategy.push_back(input1_strategy); // input1 + sp->ResetInputs(tmp_strategy); + } + return SUCCESS; +} + +// generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast +// such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) +Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { + if (sp_vector == nullptr) { + MS_LOG(ERROR) << "The sp_vector is null."; + return FAILED; + } + + if (inputs_shape[0].size() != inputs_shape[1].size()) { + MS_LOG(ERROR) << "Invalid inputs shape."; + return FAILED; + } + + // step1: ([a, 1], [1, b]) -> [a, b] + Shape max_shape, splittable_vector; + for (size_t i = 0; i < inputs_shape[0].size(); ++i) { + if (inputs_shape[0][i] >= inputs_shape[1][i]) { + max_shape.push_back(inputs_shape[0][i]); + splittable_vector.push_back(splittable_inputs[0][i]); + } else { + max_shape.push_back(inputs_shape[1][i]); + splittable_vector.push_back(splittable_inputs[1][i]); + } + } + + // step2: ([a, 1], [1, b]) -> generate strategy for ([a, b], [a, b]) + Shapes tmp_inputs_shape = {max_shape, max_shape}; + Shapes tmp_splittable_inputs = {splittable_vector, splittable_vector}; + if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; + return FAILED; + } + + // step3: reset the strategy if the dimension is 1 + for (auto &sp : *sp_vector) { + Dimensions input0_strategy = sp->GetInputDim()[0]; + Dimensions input1_strategy = sp->GetInputDim()[1]; + for (size_t i = 0; i < inputs_shape[0].size(); ++i) { + if (inputs_shape[0][i] == 1) { + input0_strategy[i] = 1; + } + + if (inputs_shape[1][i] == 1) { + input1_strategy[i] = 1; + } + } + sp->ResetInputs({input0_strategy, input1_strategy}); + } + + return SUCCESS; +} + +// 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that +// the corresponding dimension is unsplittable, '1' in 'splittable_inputs' means that the corresponding +// dimension is splittable. 'inputs_partitions' is the result of partitions. +// NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring +// specific dimensions in inputs have the identical partition should have individual implementation. +Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, + std::vector *const sp_vector) { + if (sp_vector == nullptr) { + MS_LOG(ERROR) << "The sp_vector is null."; + return FAILED; + } + if (splittable_inputs.size() != inputs_shape.size()) { + MS_LOG(ERROR) << "Splittable_inputs do not have the same input number of inputs shape, " << splittable_inputs.size() + << " : " << inputs_shape.size(); + return FAILED; + } + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + Shape combined_inputs_shape, combined_splittable_inputs, combined_partitions; + for (size_t j = 0; j < inputs_shape.size(); ++j) { + (void)combined_inputs_shape.insert(combined_inputs_shape.end(), inputs_shape[j].begin(), inputs_shape[j].end()); + (void)combined_splittable_inputs.insert(combined_splittable_inputs.end(), splittable_inputs[j].begin(), + splittable_inputs[j].end()); + } + std::function recursive = [&stage_id, &dev_num, &sp_vector, &combined_inputs_shape, + &combined_splittable_inputs, &combined_partitions, &recursive, + &inputs_shape](uint32_t current_index, size_t n) { + if (current_index == combined_inputs_shape.size()) { + MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size(); + Shapes inputs_partitions; + size_t global_index = 0; + for (auto &shape : inputs_shape) { + Shape tmp_partition; + for (size_t j = 0; j < shape.size(); ++j) { + tmp_partition.push_back(combined_partitions[global_index]); + global_index++; + } + inputs_partitions.push_back(tmp_partition); + } + StrategyPtr sp; + if (PrepareStrategyBase(stage_id, dev_num, inputs_partitions, &sp) == SUCCESS) { + sp_vector->push_back(sp); + } + return; + } else { + MS_LOG(DEBUG) << "The value of sp_vector size is " << sp_vector->size(); + if (combined_splittable_inputs[current_index] == 0) { + combined_partitions.push_back(MIN_SLICE_NUM); + recursive(current_index + 1, n / MIN_SLICE_NUM); + combined_partitions.pop_back(); + } else if (combined_splittable_inputs[current_index] == 1) { + for (uint32_t i = 1; i <= n; i *= 2) { + if (n % i == 0 && IntToSize(combined_inputs_shape[current_index]) % i == 0) { + combined_partitions.push_back(i); + recursive(current_index + 1, n / i); + combined_partitions.pop_back(); + } + } + } + } + }; + recursive(0, dev_num); + if (sp_vector->empty()) { + MS_LOG(EXCEPTION) << "No available strategy for current OperatorInfo."; + } + return SUCCESS; +} + +// generate strategies for that have two inputs, and input0 or input1 maybe broadcast, +// and the corresponding dimensions that are not broadcast are all relevant dimensions +// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) +// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) +// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { + if (sp_vector == nullptr) { + MS_LOG(ERROR) << "The sp_vector is null."; + return FAILED; + } + + if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) { + MS_LOG(ERROR) << "The inputs' size is wrong."; + return FAILED; + } + + if (inputs_shape[0] == inputs_shape[1]) { + // element wise operation([a, b, c, d], [a, b, c, d]), so input0's strategy is equal to input1's strategy + if (GenerateStrategiesForTwoEqualInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; + return FAILED; + } + MS_LOG(INFO) << "GenerateStrategiesForTwoEqualInputs success."; + } else if (inputs_shape[0].empty() || inputs_shape[1].empty()) { + // ([a, b, c, d], []) or ([], [a, b, c, d]) + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "Generate strategies for scalar case failed."; + return FAILED; + } + MS_LOG(INFO) << "Generate strategies for scalar case success."; + } else if (inputs_shape[0].size() > inputs_shape[1].size()) { + // ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) + if (GenerateStrategiesForBroadcastRight(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForBroadcastRight failed."; + return FAILED; + } + MS_LOG(INFO) << "GenerateStrategiesForBroadcastRight success."; + } else if (inputs_shape[0].size() < inputs_shape[1].size()) { + // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) + if (GenerateStrategiesForBroadcastLeft(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForBroadcastLeft failed."; + return FAILED; + } + MS_LOG(INFO) << "GenerateStrategiesForBroadcastLeft success."; + } else { // same size, but different value + // ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) + if (GenerateStrategiesForBroadcastBoth(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForBroadcastBoth failed."; + return FAILED; + } + MS_LOG(INFO) << "GenerateStrategiesForBroadcastBoth success."; + } + return SUCCESS; +} + +Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { + if (InitForCostModel(strategy) == FAILED) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Initialization under the strategy failed."; + } + return FAILED; + } + int32_t stage_id = strategy->GetInputStage(); + double computation_cost = + operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + std::shared_ptr result = std::make_shared(computation_cost, communication_cost); + result->communication_without_parameter_ = + operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + result->communication_with_partial_para_ = + result->communication_without_parameter_ + + COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); + + // Breaking ties for preferring data parallelization + BreakingTiesForPerferringDataParallel(strategy, result); + // refine communication cost calculation for practice + RefineForPracticalCost(result, false); + result->communication_forward_ = result->communication_without_parameter_; + + std::shared_ptr swc = + std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); + swc->cost_list.push_back(result); + strategy_cost_.emplace_back(swc); + + return SUCCESS; +} + +int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { + if (is_output_parameter_involve_ != -1) { + return is_output_parameter_involve_; + } + is_parameter_involve_ = is_parameter_; + const auto &prev_edges = this->GetAlivePrevEdges(); + for (auto &p_edge : prev_edges) { + auto input_index = p_edge->next_op_input_index(); + auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved(); + if (input_index >= is_parameter_involve_.size()) { + MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size() + << ", but got wrong input_index: " << input_index; + } + if (prev_op_para == 0) { + is_parameter_involve_[input_index] = false; + } else if (prev_op_para == 1) { + is_parameter_involve_[input_index] = true; + } else { + MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index; + } + p_edge->set_parameter_involve(prev_op_para); + } + if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) { + // If anyone of the input is a parameter_involved, the output is parameter_involved. + is_output_parameter_involve_ = 1; + } else { + is_output_parameter_involve_ = 0; + } + + return is_output_parameter_involve_; +} + +Status OperatorInfo::set_is_parameter(const std::vector &is_parameter) { + if (is_parameter.size() != inputs_shape_.size()) { + MS_LOG(ERROR) << "Is_parameter: " << is_parameter.size() + << " do not have the same number of inputs_shape_: " << inputs_shape_.size(); + return FAILED; + } + is_parameter_ = is_parameter; + operator_cost()->set_is_parameter(is_parameter); + return SUCCESS; +} + +Status OperatorInfo::CalculateMemoryCost() { + // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to + // calculate memory cost. + if (is_parameter_involve_.size() != is_parameter_.size()) { + MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'."; + return FAILED; + } + operator_cost()->set_is_parameter_involve(is_parameter_involve_); + operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); + // Set the memory cost in the 'strategy_cost_' + for (auto &swc : strategy_cost_) { + auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); + swc->cost_list[0]->memory_with_reuse_ = mem_cost; + } + return SUCCESS; +} + +Status OperatorInfo::CalculateMemoryCostForInference() { + // First, set the 'is_outputs_critical_' flag into OperatorCost. + if (is_output_critical_ == -1) { + MS_LOG(EXCEPTION) << "The critical flag is not set."; + return FAILED; + } + operator_cost()->set_output_critical(is_output_critical_); + // Set the memory cost in the 'strategy_cost_' + for (auto &swc : strategy_cost_) { + auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr); + swc->cost_list[0]->memory_with_reuse_ = mem_cost; + } + return SUCCESS; +} + +Status OperatorInfo::CorrectMemoryCost(size_t input_index) { + for (auto &swc : strategy_cost_) { + double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * + static_cast(operator_cost()->inputs_type_lengths()[input_index]); + swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost; + if (swc->cost_list[0]->memory_with_reuse_ < 0) { + MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_ + << ", the parameter memory cost is: " << parameter_mem_cost; + return FAILED; + } + } + return SUCCESS; +} + +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) { + int32_t ret = -1; + + // The number of repetitions is equal to the number of all devices divided by the number of devices use for + // tensor map. + int32_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies()); + for (auto &element : tensor_map) { + // -1 means the corresponding dimension is not split. + if (element == MAP_NONE) { + continue; + } else if ((element < 0) || (IntToSize(element) >= dev_matrix_shape.size())) { + MS_LOG(ERROR) << "Invalid tensor map: " << ShapeToString(tensor_map) << ", the dev matrix shape is " + << ShapeToString(dev_matrix_shape); + return ret; + } else { + size_t index = dev_matrix_shape.size() - IntToSize(element) - 1; + if (dev_matrix_shape[index] <= 0) { + MS_LOG(ERROR) << "Invalid dev matrix shape: " << ShapeToString(dev_matrix_shape); + return ret; + } + device_num /= dev_matrix_shape[index]; + } + } + + return device_num; +} + +Status OperatorInfo::InferAsLossDivisor() { + if (!ParallelContext::GetInstance()->loss_repeated_mean()) { + as_loss_divisor_ = 1; + return SUCCESS; + } + + if (outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty."; + return FAILED; + } + + if (outputs_tensor_map_.size() > 1) { + MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_map_.size() + << ", need to override this function "; + return FAILED; + } + + if (outputs_tensor_map_[0].empty()) { + as_loss_divisor_ = SizeToInt(global_device_list_.size()); + MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor."; + return SUCCESS; + } + + as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); + MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(dev_matrix_shape_) + << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is " + << as_loss_divisor_; + return SUCCESS; +} + +// If the operator is used as a loss, a div node is inserted for the grad of all its inputs. +Status OperatorInfo::InferVirtualDivOps() { + if (InferAsLossDivisor() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed."; + return FAILED; + } + + if (as_loss_divisor_ <= 0) { + MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_; + return FAILED; + } else if (as_loss_divisor_ == 1) { + MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op."; + return SUCCESS; + } + + virtual_div_op_.clear(); + // if loss is repeated calculation, insert div op + Operator op = CreateVirtualDivOp(as_loss_divisor_); + virtual_div_op_.push_back(op); + return SUCCESS; +} + +Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { + if (input_lengths.size() != inputs_shape_.size()) { + MS_LOG(ERROR) << "Input_lengths: " << input_lengths.size() + << " do not have the same number of inputs shape: " << inputs_shape_.size(); + return FAILED; + } + if (output_lengths.size() != outputs_shape_.size()) { + MS_LOG(ERROR) << "Output_lengths: " << output_lengths.size() + << " do not have the same number of outputs shape: " << outputs_shape_.size(); + return FAILED; + } + inputs_type_lengths_ = input_lengths; + outputs_type_lengths_ = output_lengths; + operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths); + return SUCCESS; +} + +double OperatorInfo::GetOutputsTotalSize() { + if (is_calculated_outputs_size_) { + return outputs_total_size_; + } + if (outputs_type_lengths_.size() != outputs_shape_.size()) { + MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size() + << " do not have the same number of outputs shape: " << outputs_shape_.size(); + } + double sum = 0.0; + for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) { + auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast(1.0), + std::multiplies()); + sum += size * static_cast(outputs_type_lengths_[i]); + } + is_calculated_outputs_size_ = true; + outputs_total_size_ = sum; + return outputs_total_size_; +} + +Status OperatorInfo::set_outputs_type(const std::vector &outputs_type) { + if (outputs_type.size() != outputs_shape_.size()) { + MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() + << " do not have the same number of outputs shape: " << outputs_shape_.size(); + return FAILED; + } + outputs_type_ = outputs_type; + return SUCCESS; +} + +void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) { + if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) { + CheckGlobalDeviceManager(); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); + if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) { + if (cost->computation_cost_ > 1.0) { + cost->computation_cost_ -= 1.0; + } + if (cost->communication_cost_ > 1.0) { + cost->communication_cost_ -= 1.0; + } + if (cost->communication_with_partial_para_ > 1.0) { + cost->communication_with_partial_para_ -= 1.0; + } + if (cost->communication_without_parameter_ > 1.0) { + cost->communication_without_parameter_ -= 1.0; + } + } + } +} + +double OperatorInfo::GetForwardMemoryCostFromCNode() { + return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); +} + +void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) { + MS_EXCEPTION_IF_NULL(s_strategy); + if (!s_strategy->IsEqual(selected_strategy_)) { + MS_LOG(INFO) << name() << "'s strategy may cause suboptimal, the determined strategy:"; + PrintStrategy(selected_strategy_); + MS_LOG(INFO) << "The minimal strategy:"; + PrintStrategy(s_strategy); + } +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h new file mode 100644 index 0000000000..8641c47491 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -0,0 +1,289 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "base/base.h" +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/group_manager.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_info.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +using ForwardOp = OperatorVector; +using MirrorOps = std::vector; +using Ops = std::vector; +using VirtualDivOp = OperatorVector; +using TensorMaps = std::vector>; +using TensorLayouts = std::vector; +using different_type = std::vector::difference_type; +using PrimitiveAttrs = std::unordered_map; +using Strategys = std::vector; +using ReplaceGraphPtr = std::shared_ptr>, AnfNodePtr>>; + +class Edge; + +class OperatorInfo { + public: + OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs, OperatorCostPtr cost) + : name_(std::move(name)), + inputs_shape_(std::move(inputs_shape)), + outputs_shape_(std::move(outputs_shape)), + attrs_(std::move(attrs)), + is_alive_(true), + operator_cost_(cost), + outputs_type_() { + std::vector not_parameteter(inputs_shape_.size(), false); + is_parameter_ = not_parameteter; + refkey_parameter_name_ = ""; + } + + virtual ~OperatorInfo() = default; + + Status set_is_parameter(const std::vector &is_parameter); + Status SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths); + double GetOutputsTotalSize(); + // Set outputs dtype. + // If only one output, outputs_type.size() is 1. + // If output is tuple, outputs_type.size() is greater than 1. + Status set_outputs_type(const std::vector &outputs_type); + const std::vector &outputs_type() const { return outputs_type_; } + virtual Status Init(const StrategyPtr &strategy) = 0; + virtual Status InitForCostModel(const StrategyPtr &strategy) = 0; // only init the necessary parts + + // Given the stage_id (which indicates the number of devices), + // generate all strategies for this operator + virtual Status GenerateStrategies(int32_t stage_id) = 0; + const OperatorCostPtr &operator_cost() const { return operator_cost_; } + void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; } + virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; + + virtual std::shared_ptr>> GenerateBatchStrategies(); + virtual void ReComputeBatchSplitFlagList(); + void ComputeBatchSplitFlagList(); + + double GetForwardMemoryCostFromCNode(); + // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy + // is checked + Status SetCostUnderStrategyBase(const StrategyPtr &strategy); + std::vector> GetStrategyCost() { return strategy_cost_; } + // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving + // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory + // at the end of forward phase. + Status CalculateMemoryCost(); + // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated + // by the output + Status CalculateMemoryCostForInference(); + int ComputeOpAndPrevEdgeParameterInvolved(); + + ForwardOp forward_op() const { return forward_op_; } + ForwardOp replace_op() const { return replace_op_; } + OutPutInfoVector replace_op_info() const { return replace_op_info_; } + virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; } + MirrorOps mirror_ops() const { return mirror_ops_; } + Ops sub_ops() const { return sub_ops_; } + VirtualDivOp virtual_div_op() const { return virtual_div_op_; } + Shape dev_matrix_shape() const { return dev_matrix_shape_; } + std::vector inputs_tensor_info() const { return inputs_tensor_info_; } + std::vector outputs_tensor_info() const { return outputs_tensor_info_; } + std::vector> strategy_cost() const { return strategy_cost_; } + const std::string &name() const { return name_; } + void set_name(const std::string &name) { name_ = name; } + RankList global_device_list() const { return global_device_list_; } + + void AddSuccEdge(const std::shared_ptr &e) { succ_edges_.push_back(e); } + void AddPrevEdge(const std::shared_ptr &e) { prev_edges_.push_back(e); } + std::vector> succ_edges() const { return succ_edges_; } + std::vector> prev_edges() const { return prev_edges_; } + std::vector> GetAliveSuccEdges(); + std::vector> GetAlivePrevEdges(); + void ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); + std::vector GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); } + void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) { + selected_strategy_ = s_strategy; + selected_cost_ = cost; + } + StrategyPtr selected_strategy() const { return selected_strategy_; } + CostPtr selected_cost() const { return selected_cost_; } + void CheckSelectedStrategy(const StrategyPtr &); + Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); } + void set_input_value(const std::vector &input_value) { input_value_ = input_value; } + const std::vector &input_value() const { return input_value_; } + void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; } + void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } + bool is_alive() const { return is_alive_; } + void SetNotAlive() { is_alive_ = false; } + StrategyPtr strategy() const { return strategy_; } + void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; } + void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); } + const std::string &refkey_parameter_name() const { return refkey_parameter_name_; } + // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated + // multiple times. This method is to correct this, and makes the cost is calulated only once. + Status CorrectMemoryCost(size_t input_index); + int is_output_parameter_involve() const { return is_output_parameter_involve_; } + int is_output_critical() const { return is_output_critical_; } + void mark_output_critical() { is_output_critical_ = 1; } + void mark_output_not_critical() { is_output_critical_ = 0; } + int used_devices() const { return used_devices_; } + // needed by rec_parser + void set_type(const std::string &type) { type_ = type; } + const std::string &type() const { return type_; } + const std::unordered_map &attrs() const { return attrs_; } + + protected: + // needed by rec_parser + std::string type_; + virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; + virtual Status InferTensorMap() = 0; + virtual Status InferForwardCommunication() = 0; + virtual Status InferMirrorOps() = 0; + virtual Status GetAttrs() = 0; + virtual Status InferTensorInfo() = 0; + virtual Status InferDevMatrixShape() = 0; + void SetDeviceListByStrategy(); + void SetRepeatedCalcDevMatrix(); + Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group); + Status CreateGroupByDim(size_t axis, std::vector *group); + Status InferAttrs(); + void ResetQueueMember(); + Status InitWithAutoRepeatCalc(const StrategyPtr &strategy); + Status InitWithManualRepeatCalc(const StrategyPtr &strategy); + Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy); + Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy); + Status InferRepeatedCalcInfo(); + Status InferVirtualDivOps(); + + // Calculate the number of repeated calculations for the output by the number of devices and the output tensor map. + // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output + // is used for grad and overload the function. If the output is a scalar, need to override the function too. + virtual Status InferAsLossDivisor(); + Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape); + void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &); + + std::string name_; + Shapes inputs_shape_; + Shapes outputs_shape_; + std::unordered_map attrs_; + std::vector input_value_; + TypePtr outputs_dtype_; + + StrategyPtr strategy_; + std::vector inputs_tensor_info_; + std::vector outputs_tensor_info_; + Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num as the first dimension + int32_t repeated_calc_num_ = 1; + int32_t as_loss_divisor_ = 1; + TensorMaps inputs_tensor_map_; + TensorMaps outputs_tensor_map_; + ForwardOp forward_op_; + Ops sub_ops_; + ForwardOp replace_op_; + OutPutInfoVector replace_op_info_; + ReplaceGraphPtr replace_graph_; + MirrorOps mirror_ops_; + VirtualDivOp virtual_div_op_; + RankList global_device_list_; // the size of global_device_list equal to the size of stageID + RankList local_device_list_; // the size equal to global_device_list_.size() / repeated_calc_num_ + bool infer_attrs_completed_ = false; + + bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel + // 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected. + std::vector corrected_input_indices_; + // Given a parallization strategy, there is a cost. + std::vector> strategy_cost_; + // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter + std::vector is_parameter_; + // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of + // pre-operator that has parameters as input. + std::vector is_parameter_involve_; + // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating + // peak memory cost in the training phase. + // -1: unset; 0: not parameter_involved; 1: parameter_involved + int is_output_parameter_involve_ = -1; + // Whether this output is critical, which means that this output is included in calculating peak memory cost + // in the inference phase. + // -1 : unset; 0: not critical; 1: critical + int is_output_critical_ = -1; + double outputs_total_size_ = 0.0; + bool is_calculated_outputs_size_ = false; + // for each input and output, the followings record the number of bytes of each element + std::vector inputs_type_lengths_; + std::vector outputs_type_lengths_; + std::vector> prev_edges_; + std::vector> succ_edges_; + StrategyPtr selected_strategy_; + // Used in DP algorithm + bool is_alive_; + CostPtr selected_cost_; + std::vector split_flag_list_; + std::string refkey_parameter_name_; + CNodePtr cnode_; + int32_t used_devices_ = -1; + + private: + OperatorCostPtr operator_cost_; + std::vector outputs_type_; +}; + +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy); +Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool); +Operator CreateVirtualDivOp(int32_t div_num); +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); +Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group); +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); +std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( + const Shapes &shapes, const std::vector &split_flag_list); + +void PrintStrategy(const StrategyPtr &strategy); +// generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d]) +Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *sp_vector); +// generate strategies for that have two inputs, and input0 or input1 maybe broadcast, +// and the corresponding dimensions that are not broadcast are all relevant dimensions +// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) +// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) +// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *sp_vector); + +Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h new file mode 100644 index 0000000000..bc732ed234 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_ + +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/ops_info/batch_parallel_info.h" +#include "frontend/parallel/ops_info/bias_add_info.h" +#include "frontend/parallel/ops_info/comparison_function_info.h" +#include "frontend/parallel/ops_info/dropout_do_mask_info.h" +#include "frontend/parallel/ops_info/elementary_function_info.h" +#include "frontend/parallel/ops_info/gather_v2_info.h" +#include "frontend/parallel/ops_info/get_next_info.h" +#include "frontend/parallel/ops_info/l2_normalize_info.h" +#include "frontend/parallel/ops_info/layer_norm_info.h" +#include "frontend/parallel/ops_info/loss_info.h" +#include "frontend/parallel/ops_info/matmul_info.h" +#include "frontend/parallel/ops_info/onehot_info.h" +#include "frontend/parallel/ops_info/prelu_info.h" +#include "frontend/parallel/ops_info/reduce_method_info.h" +#include "frontend/parallel/ops_info/reshape_info.h" +#include "frontend/parallel/ops_info/transpose_info.h" +#include "frontend/parallel/ops_info/virtual_dataset_info.h" +#include "frontend/parallel/ops_info/gather_v2_p_info.h" + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h new file mode 100644 index 0000000000..79dfb56693 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -0,0 +1,296 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ + +namespace mindspore { +namespace parallel { +constexpr size_t PRELU_INPUTS_SIZE = 2; +constexpr size_t PRELU_OUTPUTS_SIZE = 1; +constexpr size_t PRELU_SECOND_INPUT_SIZE = 1; +constexpr int32_t PRELU_CHANNEL_INDEX = 1; +constexpr int32_t PRELU_CHANNEL_STRATEGY = 1; +constexpr int32_t NO_SPLIT_MAP = -1; +constexpr int32_t NO_SPLIT_STRATEGY = 1; +constexpr int32_t SPLIT_FLAG = 1; +constexpr int32_t NO_SPLIT_FLAG = 0; +constexpr size_t MATMUL_ATTRS_SIZE = 2; +constexpr size_t MATMUL_INPUTS_SIZE = 2; +constexpr size_t MATMUL_OUTPUTS_SIZE = 1; +constexpr size_t ACTIVATION_ATTR_SIZE = 1; +constexpr size_t SOFTMAX_ATTR_SIZE = 1; +constexpr size_t ACTIVATION_INPUTS_SIZE = 1; +constexpr size_t ACTIVATION_OUTPUTS_SIZE = 1; +constexpr size_t EXPANDDIMS_INPUT_SIZE = 2; +constexpr size_t DROPOUT_DO_MASK_CNODE_INPUT_SIZE = 4; +constexpr size_t DROPOUT_GEN_MASK_CNODE_INPUT_SIZE = 3; +constexpr size_t DROPOUT_GEN_MASK_INDEX = 2; +constexpr size_t DROPOUT_DO_MASK_KEEP_PROB_INDEX = 3; +constexpr size_t SoftmaxCrossEntropyWithLogitsAttrSize = 1; +constexpr size_t SoftmaxCrossEntropyWithLogitsInputsSize = 2; +constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2; +constexpr double EPS = 1e-6; +constexpr double INF = 1e20; + +constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only"; +constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only"; +constexpr char CHECK_SET_STRATEGY_VALID_ONCE_ONLY[] = "check_set_strategy_valid_once_only"; +constexpr char STRATEGY[] = "strategy"; +constexpr char GEN_STRATEGY[] = "gen_strategy"; +constexpr char REDUCE_OP_SUM[] = "sum"; +constexpr char REDUCE_OP_MAX[] = "max"; +constexpr char REDUCE_OP_MIN[] = "min"; +constexpr char OP_PATH[] = "mindspore.ops.operations"; +constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops"; +constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils"; +constexpr char GET_OP_FUNCTION[] = "_get_python_op"; +constexpr char KEEP_DIMS[] = "keep_dims"; +constexpr char CROSS_BATCH[] = "cross_batch"; +constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin"; +constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; +constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; +constexpr char REQUIRES_GRAD[] = "requires_grad"; +constexpr char PARAM_NAME[] = "name"; +constexpr char RESHAPEINFO[] = "ReshapeInfo"; + +constexpr char RELU_TYPE[] = "relu"; +constexpr char RELU6_TYPE[] = "relu6"; +constexpr char SIGMOID_TYPE[] = "sigmoid"; +constexpr char OP[] = "op"; +constexpr char IDENTITY_INFO[] = "identity_info"; +constexpr char DIVISOR[] = "divisor"; +constexpr char NONE[] = "None"; +constexpr char DEPEND[] = "Depend"; +constexpr char BATCH_PARALLEL[] = "BatchParallel"; + +constexpr char ACTIVATION_TYPE[] = "activation_type"; +constexpr char TARGET[] = "primitive_target"; +constexpr char CPU[] = "CPU"; +constexpr char TRANSPOSE_A[] = "transpose_a"; +constexpr char TRANSPOSE_B[] = "transpose_b"; +constexpr char SHAPE[] = "shape"; +constexpr char BEGIN_MASK[] = "begin_mask"; +constexpr char END_MASK[] = "end_mask"; +constexpr char ELLIPSIS_MASK[] = "ellipsis_mask"; +constexpr char NEW_AXIS_MASK[] = "new_axis_mask"; +constexpr char SHRINK_AXIS_MASK[] = "shrink_axis_mask"; +constexpr char BEGIN[] = "begin"; +constexpr char END[] = "end"; +constexpr char STRIDES[] = "strides"; +constexpr char GROUP[] = "group"; +constexpr char AXIS[] = "axis"; +constexpr char OUTPUT_NUM[] = "output_num"; +constexpr char SPLIT_COUNT[] = "split_count"; +constexpr char SPLIT_DIM[] = "split_dim"; +constexpr char CONCAT_DIM[] = "concat_dim"; +constexpr char FORWARD[] = "forward"; +constexpr char BACKWARD[] = "backward"; +constexpr char REDISTRIBUTION[] = "redistribution"; +constexpr char REPLACE[] = "replace"; +constexpr char CONNSYMBOL[] = "/"; +constexpr char INSTANCE_NAME[] = "instance_name"; +constexpr char SPLIT_SENS[] = "split_sens"; +constexpr char SPLIT_TENSOR[] = "split_tensor"; +constexpr char DEV_MAT[] = "dev_mat"; +constexpr char TENSOR_MAP[] = "tensor_map"; +constexpr char SEED0[] = "Seed0"; +constexpr char SEED1[] = "Seed1"; +constexpr char KEEP_PROB[] = "keep_prob"; +constexpr char SRC[] = "src"; +constexpr char CLONE_INFO[] = "clone_info"; +constexpr char CLONED[] = "cloned"; +constexpr char BE_CLONED[] = "be_cloned"; +constexpr char CLONED_INDEX[] = "cloned_index"; +constexpr char BE_CLONED_INDEX[] = "be_cloned_index"; +constexpr char GROUP_RANKS[] = "group_ranks"; +constexpr char IS_IN_FORWARD[] = "is_in_forward"; +constexpr char DEFAULT_INPUT[] = "default_input"; +constexpr char DTYPE[] = "DType"; +constexpr char DEV_NUM[] = "dev_num"; +constexpr char MEAN_FLAG[] = "mean_flag"; +constexpr char TYPES[] = "types"; +constexpr char SHAPES[] = "shapes"; +constexpr char GETNEXT_NUM[] = "output_num"; +constexpr char SHARED_NAME[] = "shared_name"; +constexpr char MIRROR_OP[] = "mirror_op"; +constexpr char FORWARD_OP[] = "forward_op"; +constexpr char REDISTRIBUTION_OP[] = "redistribution_op"; +constexpr char DARA_PARALLEL[] = "data_parallel"; +constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter"; +constexpr char OPTIMIZER_SUB_STRING[] = "optimizer"; +constexpr char DEVICE[] = "Device"; + +// Operator +constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; +constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice"; +constexpr char SPLIT[] = "Split"; +constexpr char ALL_TO_ALL[] = "_AlltoAll"; +constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis"; +constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis"; +constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; +constexpr char ALL_REDUCE[] = "AllReduce"; +constexpr char MIRROR_OPERATOR[] = "_MirrorOperator"; +constexpr char STRIDED_SLICE[] = "StridedSlice"; +constexpr char ALL_GATHER[] = "AllGather"; +constexpr char REDUCE_SCATTER[] = "ReduceScatter"; +constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter"; +constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup"; +constexpr char CONCAT[] = "Concat"; +constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits"; +constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits"; +constexpr char MATMUL[] = "MatMul"; +constexpr char GELU[] = "Gelu"; +constexpr char TANH[] = "Tanh"; +constexpr char SOFTMAX[] = "Softmax"; +constexpr char LOG_SOFTMAX[] = "LogSoftmax"; +constexpr char ACTIVATION[] = "Activation"; +constexpr char PRELU[] = "PReLU"; +constexpr char FLOORDIV[] = "FloorDiv"; +constexpr char MAXPOOL[] = "MaxPool"; +constexpr char MAXPOOLV2[] = "MaxPoolV2"; +constexpr char L2_NORMALIZE[] = "L2Normalize"; +constexpr char TRANSPOSE[] = "Transpose"; +constexpr char RESHAPE[] = "Reshape"; +constexpr char TENSOR_ADD[] = "TensorAdd"; +constexpr char BIAS_ADD[] = "BiasAdd"; +constexpr char SUB[] = "Sub"; +constexpr char MUL[] = "Mul"; +constexpr char DIV[] = "Div"; +constexpr char REAL_DIV[] = "RealDiv"; +constexpr char ASSIGN_SUB[] = "AssignSub"; +constexpr char GREATER[] = "Greater"; +constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset"; +constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo"; +constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits"; +constexpr char RELU[] = "ReLU"; +constexpr char ONEHOT[] = "OneHot"; +constexpr char DROPOUT_DO_MASK[] = "DropoutDoMask"; +constexpr char DROPOUT_GEN_MASK[] = "DropoutGenMask"; +constexpr char REDUCE_MAX[] = "ReduceMax"; +constexpr char REDUCE_MIN[] = "ReduceMin"; +constexpr char REDUCE_SUM[] = "ReduceSum"; +constexpr char REDUCE_MEAN[] = "ReduceMean"; +constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue"; +constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue"; +constexpr char CONV2D[] = "Conv2D"; +constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm"; +constexpr char BATCH_NORM[] = "BatchNorm"; +constexpr char LAYER_NORM[] = "LayerNorm"; +constexpr char POOLING[] = "Pooling"; +constexpr char CAST[] = "Cast"; +constexpr char MAX_POOL_WITH_ARGMAX[] = "MaxPoolWithArgmax"; +constexpr char SIMPLE_MEAN[] = "SimpleMean"; +constexpr char FLATTEN[] = "Flatten"; +constexpr char J[] = "J"; +constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info"; +constexpr char COS[] = "Cos"; +constexpr char ACOS[] = "ACos"; +constexpr char EXP[] = "Exp"; +constexpr char LOG[] = "Log"; +constexpr char SIGMOID[] = "Sigmoid"; +constexpr char POW[] = "Pow"; +constexpr char MAXIMUM[] = "Maximum"; +constexpr char MINIMUM[] = "Minimum"; +constexpr char EQUAL[] = "Equal"; +constexpr char NOT_EQUAL[] = "NotEqual"; +constexpr char LOGICALNOT[] = "LogicalNot"; +constexpr char GATHERV2[] = "GatherV2"; +constexpr char SPARSE_GATHERV2[] = "SparseGatherV2"; +constexpr char STRIDEDSLICE[] = "StridedSlice"; +constexpr char BROADCAST[] = "Broadcast"; +constexpr char SQRT[] = "Sqrt"; +constexpr char ASSIGN[] = "Assign"; +constexpr char GET_NEXT[] = "GetNext"; +constexpr char SQUEEZE[] = "Squeeze"; +constexpr char NEG[] = "Neg"; +constexpr char BATCH_MATMUL[] = "BatchMatMul"; +constexpr char EXPAND_DIMS[] = "ExpandDims"; +constexpr char SQUARE[] = "Square"; +constexpr char BATCHMATMUL[] = "BatchMatMul"; +constexpr char TOPK[] = "TopK"; +constexpr char IN_TOPK[] = "InTopK"; +constexpr char PACK[] = "Pack"; +constexpr char GATHER_ND[] = "GatherNd"; +constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; +constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; +constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; +constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; +constexpr char ADD[] = "Add"; + +// Parallel don't care +constexpr char TUPLE_GETITEM[] = "tuple_getitem"; +constexpr char STRING_EQUAL[] = "string_equal"; +constexpr char MAKE_TUPLE[] = "make_tuple"; +constexpr char MAKE_LIST[] = "make_list"; +constexpr char MAKE_DICT[] = "make_dict"; +constexpr char MAKE_SLICE[] = "make_slice"; +constexpr char MAKE_RECORD[] = "make_record"; +constexpr char LIST_GETITEM[] = "list_getitem"; +constexpr char ARRAY_GETITEM[] = "array_getitem"; +constexpr char TUPLE_SETITEM[] = "tuple_setitem"; +constexpr char LIST_SETITEM[] = "list_setitem"; +constexpr char ARRAY_SETITEM[] = "array_setitem"; +constexpr char DICT_GETITEM[] = "dict_getitem"; +constexpr char LIST_APPEND[] = "list_append"; +constexpr char LIST_MAP[] = "list_map"; +constexpr char LIST_REDUCE[] = "list_reduce"; +constexpr char TUPLE_REVERSED[] = "tuple_reversed"; +constexpr char TILE_SHAPE[] = "tile_shape"; +constexpr char REDUCED_SHAPE[] = "reduced_shape"; +constexpr char TUPLE_DIV[] = "tuple_div"; +constexpr char TUPLE_TO_ARRAY[] = "tuple_to_array"; +constexpr char VIRTUALLOSS[] = "VirtualLoss"; +constexpr char RETURN[] = "return"; +constexpr char ENV_GETITEM[] = "env_getitem"; +constexpr char IDENTITY[] = "identity"; +constexpr char PARTIAL[] = "partial"; +constexpr char ENVSETITEM[] = "env_setitem"; +constexpr char ENVGETITEM[] = "env_getitem"; +constexpr char ENVADD[] = "env_add"; +constexpr char MAKEREFKEY[] = "MakeRefKey"; +constexpr char MAKEREF[] = "make_ref"; +constexpr char GETREFKEY[] = "get_ref_key"; +constexpr char GETREFVALUE[] = "get_ref_value"; +constexpr char GETREFORIGIN[] = "get_ref_origin"; +constexpr char STATESETITEM[] = "state_setitem"; +constexpr char SCALARSUMMARY[] = "ScalarSummary"; +constexpr char IMAGESUMMARY[] = "ImageSummary"; +constexpr char TENSORSUMMARY[] = "TensorSummary"; +constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary"; +constexpr char DEBUG[] = "Debug"; +constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; +constexpr char INVERTPERMUTATION[] = "InvertPermutation"; +constexpr char CONTROLDEPEND[] = "ControlDepend"; +constexpr char DOT[] = "dot"; +constexpr char IM2COL[] = "im2col"; +constexpr char COL2IM[] = "col2im"; +constexpr char IM2COLV1[] = "im2col_v1"; +constexpr char COL2IMV1[] = "col2im_v1"; +constexpr char RESOLVE[] = "resolve"; +constexpr char EMBED[] = "embed"; +constexpr char CREATINSTANCE[] = "create_instance"; +constexpr char ZEROSLIKE[] = "ZerosLike"; +constexpr char REF_TO_EMBED[] = "RefToEmbed"; +constexpr char STOP_GRADIENT[] = "stop_gradient"; + +constexpr size_t LAST_INDEX(size_t s) { return s - 1; } +constexpr size_t SECOND_FROM_END(size_t s) { return s - 2; } +constexpr size_t THIRD_FROM_END(size_t s) { return s - 3; } +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc new file mode 100644 index 0000000000..57b35b69f7 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc @@ -0,0 +1,253 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/prelu_info.h" + +#include +#include +#include + +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +/* + * prelu has 2 input + * A: A float tensor of shape [NCHW] representing the output of the preview layer. + * w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input. + * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 + */ +Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + std::vector stra = strategy->GetInputDim(); + if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy size."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy size."; + } + return FAILED; + } + if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid channel strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid channel strategy."; + } + return FAILED; + } + return SUCCESS; +} + +/* + * device matrix is same with the strategy matrix + */ +Status PReLUInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + input_strategy_ = input_strategy; + dev_matrix_shape_ = input_strategy; + return SUCCESS; +} + +Status PReLUInfo::InferMirrorOps() { + Shape param_tensor_map = inputs_tensor_map_[1]; + std::vector param_group; + if (CreateGroupByTensorMap(param_tensor_map, ¶m_group) != SUCCESS) { + return FAILED; + } else if (param_group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror ops is empty."; + return SUCCESS; + } + OperatorVector op_for_param; + op_for_param = CreateMirrorOps(param_group[0].name(), param_group[0].GetDevNum()); + // op_for_inputs is empty + OperatorVector op_for_inputs; + mirror_ops_.push_back(op_for_inputs); + mirror_ops_.push_back(op_for_param); + std::string group_name = param_group[0].name(); + MS_LOG(INFO) << name_ << ": The mirror ops group is " << group_name; + return SUCCESS; +} + +Status PReLUInfo::InferForwardCommunication() { return SUCCESS; } + +/* + * the output tensor map is the same as the input tensor map + */ +Status PReLUInfo::InferTensorMap() { + TensorMap input_tensor_map; + // such as 4: input_tensor_map [3,2,1,0] + for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { + input_tensor_map.push_back((int32_t)(inputs_shape_[0].size() - i - 1)); + } + + TensorMap param_tensor_map; + if (inputs_shape_[1][0] == 1) { + param_tensor_map.push_back(-1); + } else { + param_tensor_map.push_back(input_tensor_map.at(1)); + } + inputs_tensor_map_.push_back(input_tensor_map); + inputs_tensor_map_.push_back(param_tensor_map); + outputs_tensor_map_.push_back(input_tensor_map); + return SUCCESS; +} + +Dimensions PReLUInfo::GetOutputStrategy() { + Dimensions output_strategy = input_strategy_; + return output_strategy; +} + +Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + if (inputs_layout == nullptr || outputs_layout == nullptr) { + MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; + return FAILED; + } + TensorLayout input_layout, param_layout, output_layout; + if ((input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || + (param_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || + (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) { + return FAILED; + } + inputs_layout->push_back(input_layout); + inputs_layout->push_back(param_layout); + outputs_layout->push_back(output_layout); + return SUCCESS; +} + +Status PReLUInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape param_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Dimensions output_strategy = GetOutputStrategy(); + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = {output_strategy}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + Shape param_slice_shape = inputs_slice_shape.at(1); + Shape output_slice_shape = outputs_slice_shape.at(0); + + // infer tensor layout + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + + TensorLayout input_layout = inputs_layout.at(0); + TensorLayout param_layout = inputs_layout.at(1); + TensorLayout output_layout = outputs_layout.at(0); + TensorInfo input_tensor_info(input_layout, input_shape, input_slice_shape); + TensorInfo param_tensor_info(param_layout, param_shape, param_slice_shape); + TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); + inputs_tensor_info_.push_back(param_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status PReLUInfo::GetAttrs() { + if ((inputs_shape_.size() != PRELU_INPUTS_SIZE) || (outputs_shape_.size() != PRELU_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size " + << outputs_shape_.size() << " is wrong."; + return FAILED; + } + return SUCCESS; +} + +Status PReLUInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status PReLUInfo::GenerateStrategies(int32_t stage_id) { + if (inputs_shape_.size() != PRELU_INPUTS_SIZE) { + return FAILED; + } + if (inputs_shape_[1].size() != PRELU_SECOND_INPUT_SIZE) { + return FAILED; + } + is_auto_parallel_ = true; + Shape input0_split; + input0_split.emplace_back(1); + input0_split.emplace_back(0); + (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 2, 1); + Shape input1_split(inputs_shape_[1].size(), 0); + Shapes splittable_inputs = {input0_split, input1_split}; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h new file mode 100644 index 0000000000..e6e5e23bac --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +/* + * parallel class for PReLU Primitive + */ +class PReLUInfo : public OperatorInfo { + public: + PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~PReLUInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + Status GetAttrs() override; + Dimensions GetOutputStrategy(); + + private: + Dimensions input_strategy_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc new file mode 100644 index 0000000000..0488dceeca --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc @@ -0,0 +1,571 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/reduce_method_info.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + return SUCCESS; +} + +Status ReduceMethod::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + + dev_matrix_shape_ = input_strategy; + + return SUCCESS; +} + +std::vector ReduceMethod::reduce_dim() { + std::vector dim_list; + if (input_value_.size() < 2) { + MS_LOG(EXCEPTION) << name_ << ": Input value size is smaller than 2."; + } + if (input_value_.back() == nullptr) { + MS_LOG(EXCEPTION) << name_ << ": Input value is nullptr."; + } + MS_ASSERT(inputs_shape_.size() == 1); + auto input_dim = inputs_shape_.at(0).size(); + if (input_value_.back()->isa()) { + auto attr_axis = GetValue>(input_value_.back()); + // axis is (), reduce all dim + if (attr_axis.empty()) { + for (size_t i = 0; i < input_dim; ++i) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } + } else if (input_value_.back()->isa()) { + int axis = GetValue(input_value_.back()); + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Axis type is invalid."; + } + + return dim_list; +} + +Status ReduceMethod::GetAttrs() { + // get attr cross_batch and keep_dims + auto keep_dims_iter = attrs_.find(KEEP_DIMS); + if (keep_dims_iter == attrs_.end()) { + MS_LOG(ERROR) << name_ << ": Don't have attr keep_dims."; + return FAILED; + } + + if (keep_dims_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(keep_dims_iter->second); + if (!keep_dims_iter->second->isa()) { + MS_LOG(ERROR) << name_ << ": Keep_dims is not a bool."; + return FAILED; + } + keepdims_ = keep_dims_iter->second->cast()->value(); + } + + auto cross_batch_iter = attrs_.find(CROSS_BATCH); + if (cross_batch_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(cross_batch_iter->second); + if (!cross_batch_iter->second->isa()) { + MS_LOG(ERROR) << name_ << ": cross_batch is not a bool."; + return FAILED; + } + cross_batch_ = cross_batch_iter->second->cast()->value(); + } + auto reducemethodcost = std::dynamic_pointer_cast(operator_cost()); + if (reducemethodcost == nullptr) { + MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!"; + return FAILED; + } + reducemethodcost->set_cross_batch(cross_batch_); + return SUCCESS; +} + +Status ReduceMethod::InferTensorMap() { + std::vector tensor_map_index, dim_list, output_tensor_map; + size_t size = inputs_shape_.at(0).size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back((int32_t)(size - 1 - i)); + } + dim_list = reduce_dim(); + for (size_t i = 0; i < size; ++i) { + if (find(dim_list.begin(), dim_list.end(), SizeToInt(i)) != dim_list.end()) { + if (keepdims_) { + output_tensor_map.push_back(-1); + } else { + continue; + } + } else { + output_tensor_map.push_back(tensor_map_index[i]); + } + } + inputs_tensor_map_.push_back(tensor_map_index); + outputs_tensor_map_.push_back(output_tensor_map); + + return SUCCESS; +} + +bool IsDataParallelStrategy(const Dimensions &strategy) { + CheckGlobalDeviceManager(); + size_t total_dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + if (strategy.empty()) { + MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty"; + } + + return (IntToSize(strategy[0]) == total_dev_num); +} + +Status ReduceMethod::InferForwardCommunication() { + Dimensions stra = strategy_->GetInputDim().at(0); + if (cross_batch_ && IsDataParallelStrategy(stra)) { + MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; + return SUCCESS; + } + if (cross_batch_) { + MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; + return SUCCESS; + } + forward_op_.clear(); + std::vector dim_list = reduce_dim(); + size_t size = stra.size(); + // judge if the reduce dim is partitioned. + Shape group_creat_map; + if (dev_matrix_shape_.size() > size) { + group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); + } + for (size_t index = 0; index < size; ++index) { + auto pos = + std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; }); + if (pos != dim_list.end() && stra[index] != 1) { + continue; + } + group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); + } + std::vector forward_group; + if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; + return FAILED; + } + if (!forward_group.empty()) { + Operator op = CreateAllReduceOp(reduce_method_, forward_group[0].name()); + forward_op_.push_back(op); + std::string group_name = forward_group[0].name(); + MS_LOG(INFO) << name_ << ": Forward communication group is " << group_name; + } + + return SUCCESS; +} + +ForwardOp CreatReduceMeanForwardOp(const std::vector &forward_group, const TypePtr &dtype) { + // Creat AllReduceSum op + Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name()); + std::string group_name = forward_group[0].name(); + MS_LOG(INFO) << "The group of forward all reduce is " << group_name; + + // Creat RealDiv op + OperatorName operator1_name = REAL_DIV; + std::vector device_list = forward_group[0].GetDevicesList(); + auto divisor = static_cast(device_list.size()); + std::vector tensor_data = {divisor}; + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(tensor_data, dtype); + ValuePtr op1_param_value = MakeValue(tensor_ptr); + Attr op1_param = std::make_pair("divisor", op1_param_value); + OperatorParams operator1_params = {std::make_pair(op1_param, 2)}; + OperatorAttrs operator1_attrs; + OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params); + Operator op1 = std::make_pair(operator1_name, operator1_args); + ForwardOp forward_op = {op0, op1}; + + std::string dtype_name = dtype->ToString(); + MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name; + return forward_op; +} + +Status ReduceMeanInfo::InferForwardCommunication() { + Dimensions stra = strategy_->GetInputDim().at(0); + if (cross_batch_ && IsDataParallelStrategy(stra)) { + MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; + return SUCCESS; + } + forward_op_.clear(); + std::vector dim_list = reduce_dim(); + size_t size = stra.size(); + // judge if the reduce dim is partitioned. + Shape group_creat_map; + if (dev_matrix_shape_.size() > size) { + group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); + } + for (size_t index = 0; index < size; ++index) { + auto pos = + std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; }); + if (pos != dim_list.end() && stra[index] != 1) { + continue; + } + group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); + } + std::vector forward_group; + if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; + return FAILED; + } + if (!forward_group.empty()) { + if ((outputs_dtype_ == nullptr) || !outputs_dtype_->isa()) { + MS_LOG(ERROR) << name_ << ": The dtype of output is not Array"; + return FAILED; + } + + auto element_type = outputs_dtype_->cast()->element(); + forward_op_ = CreatReduceMeanForwardOp(forward_group, element_type); + } + + return SUCCESS; +} + +Status ReduceMethod::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = inputs_tensor_map_.at(0); + std::vector input_group; + if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " Infer MirrorOps failed."; + return FAILED; + } + + OperatorVector op_for_weight; + OperatorVector op_for_reduce_axis; // helper node + if (input_group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror ops is empty."; + return SUCCESS; + } else { + op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); + mirror_ops_.push_back(op_for_weight); + mirror_ops_.push_back(op_for_reduce_axis); + std::string group_name = input_group[0].name(); + MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success, the group is " << group_name; + } + + return SUCCESS; +} + +Status ArgMaxWithValueInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = inputs_tensor_map_.at(0); + std::vector input_group; + if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed."; + return FAILED; + } + + OperatorVector op_for_weight; + if (input_group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror ops is empty."; + return SUCCESS; + } else { + op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); + mirror_ops_.push_back(op_for_weight); + MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success."; + } + + return SUCCESS; +} + +Dimensions ReduceMethod::InferOutputStrategy() { + std::vector dim_list = reduce_dim(); + Dimensions output_strategy; + Dimensions stra = strategy_->GetInputDim().at(0); + // if keepdims_ is true,then output strategy is same with input. + for (size_t i = 0; i < stra.size(); ++i) { + if (find(dim_list.begin(), dim_list.end(), SizeToInt(i)) != dim_list.end()) { + if (keepdims_) { + output_strategy.push_back(1); + } + } else { + output_strategy.push_back(stra[i]); + } + } + return output_strategy; +} + +Status ReduceMethod::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Dimensions output_strategy = InferOutputStrategy(); + + Strategys outputs_strategy = {output_strategy}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + Shape output_slice_shape = outputs_slice_shape.at(0); + + TensorLayout input_tensor_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { + return FAILED; + } + + std::vector dim_list = reduce_dim(); + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); + input_tensor_info.set_reduce_dim(dim_list); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + + return SUCCESS; +} + +Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status ReduceMethod::GenerateStrategies(int32_t stage_id) { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " + << outputs_shape_.size(); + return FAILED; + } + + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + is_auto_parallel_ = true; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status ReduceMethod::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + + return SUCCESS; +} + +Status ReduceMethod::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed"; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed"; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success"; + return SUCCESS; +} + +std::vector ArgMaxWithValueInfo::reduce_dim() { + std::vector dim_list; + auto iter = attrs_.find(AXIS); + if (iter == attrs_.end()) { + MS_LOG(EXCEPTION) << name_ << ": Don't have attr axis."; + } + + MS_ASSERT(inputs_shape_.size() == 1); + auto input_dim = inputs_shape_.at(0).size(); + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + auto attr_axis = GetValue>(iter->second); + if (attr_axis.empty()) { + for (size_t i = 0; i < input_dim; ++i) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } + } else if (iter->second->isa()) { + int axis = GetValue(iter->second); + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Axis type is invalid."; + } + + return dim_list; +} + +Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) { + if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; + } else { + MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; + } + return FAILED; + } + std::vector dim_list = reduce_dim(); + MS_ASSERT(dim_list.size() == 1); + + std::vector stra = strategy->GetInputDim(); + MS_ASSERT(stra.size() == 1); + Shape input_strategy = stra.at(0); + MS_ASSERT(dim_list.at(0) < input_strategy.size()); + if (input_strategy.at(IntToSize(dim_list.at(0))) != 1) { + MS_LOG(WARNING) + << name_ + << " CheckStrategy for ArgMaxWithValueInfo, the strategy corresponding to axis is not one, real strategy " + "is " + << input_strategy.at(IntToSize(dim_list.at(0))) + << ", the output index may be not compatible with the stand alone Primitive"; + } + return SUCCESS; +} + +Status ArgMaxWithValueInfo::InferTensorMap() { + if (ReduceMethod::InferTensorMap() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed"; + return FAILED; + } + MS_ASSERT(outputs_tensor_map_.size() == 1); + outputs_tensor_map_.push_back(outputs_tensor_map_[0]); + return SUCCESS; +} + +Status ArgMaxWithValueInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Dimensions output_strategy = InferOutputStrategy(); + + Strategys outputs_strategy = {output_strategy, output_strategy}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + Shape output_slice_shape = outputs_slice_shape.at(0); + + TensorLayout input_tensor_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { + return FAILED; + } + + std::vector dim_list = reduce_dim(); + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); + input_tensor_info.set_reduce_dim(dim_list); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status ArgMaxWithValueInfo::InferAsLossDivisor() { + if (outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty."; + return FAILED; + } + + MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer"; + if (outputs_tensor_map_[0].empty()) { + as_loss_divisor_ = SizeToInt(global_device_list_.size()); + MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor."; + return SUCCESS; + } + + as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); + + std::string dev_matrix_shape_str = ShapeToString(dev_matrix_shape_); + std::string output_tensor_map_str = ShapeToString(outputs_tensor_map_[0]); + MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and loss divisor is " << dev_matrix_shape_str + << ", " << output_tensor_map_str << ", " << as_loss_divisor_; + return SUCCESS; +} + +Status ArgMaxWithValueInfo::GenerateStrategies(int32_t stage_id) { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 2)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " + << outputs_shape_.size(); + return FAILED; + } + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + is_auto_parallel_ = true; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated strategy " << success; + PrintStrategy(sp); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h new file mode 100644 index 0000000000..ed9ab0721d --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h @@ -0,0 +1,141 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ + +#include +#include +#include +#include + +#include "ir/tensor.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class ReduceMethod : public OperatorInfo { + public: + ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~ReduceMethod() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + std::string reduce_method_; + bool keepdims_ = false; + bool cross_batch_ = false; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Dimensions InferOutputStrategy(); + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferMirrorOps() override; + virtual std::vector reduce_dim(); + Status InferForwardCommunication() override; + Status InferDevMatrixShape() override; +}; + +class ReduceMaxInfo : public ReduceMethod { + public: + ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_MAX; + } + + ~ReduceMaxInfo() override = default; +}; + +class ArgMaxWithValueInfo : public ReduceMethod { + public: + ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_MAX; + } + + ~ArgMaxWithValueInfo() override = default; + + Status GenerateStrategies(int32_t stage_id) override; + + protected: + std::vector reduce_dim() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferAsLossDivisor() override; +}; + +class ArgMinWithValueInfo : public ArgMaxWithValueInfo { + public: + ArgMinWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArgMaxWithValueInfo(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_MIN; + } + + ~ArgMinWithValueInfo() override = default; +}; + +class ReduceMeanInfo : public ReduceMethod { + public: + ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + set_cost(std::make_shared()); + } + + ~ReduceMeanInfo() override = default; + + protected: + Status InferForwardCommunication() override; +}; + +class ReduceSumInfo : public ReduceMethod { + public: + ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_SUM; + } + + ~ReduceSumInfo() override = default; +}; + +class ReduceMinInfo : public ReduceMethod { + public: + ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_MIN; + } + + ~ReduceMinInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc new file mode 100644 index 0000000000..fb62c1d02c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc @@ -0,0 +1,507 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/reshape_info.h" + +#include +#include + +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + size_t strategy_size = strategy->GetInputNumber(); + if (strategy_size != 1) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy size " << strategy_size; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy size " << strategy_size; + } + return FAILED; + } + return SUCCESS; +} + +/* + * support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of + * device matrix + * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) + */ +Status ReshapeInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + input_strategy_ = stra.at(0); + dev_matrix_shape_.push_back(input_strategy_[0]); + return SUCCESS; +} + +/* + * there is no Parameter for Reshape Primitive, so no need to do allreduce + */ +Status ReshapeInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = input_layout_.tensor_map().array(); + std::vector input_group; + if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed."; + return FAILED; + } + + OperatorVector op_for_input; + if (input_group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror ops is empty."; + return SUCCESS; + } + if (!input_group.empty()) { + op_for_input = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); + std::string group_name = input_group[0].name(); + MS_LOG(INFO) << name_ << ": Create the mirror ops for input_a success, group is " << group_name; + } + mirror_ops_.push_back(op_for_input); + OperatorVector op_for_input_empty; + mirror_ops_.push_back(op_for_input_empty); + + return SUCCESS; +} + +/* + * there is no reduction dimension for forward computation of Reshape Primitive, so no need to do allreduce + */ +Status ReshapeInfo::InferForwardCommunication() { return SUCCESS; } + +/* + * get shape input of Reshape Primitive + * the result is saved in parameter_input_v_ + * not support -1 + */ +Status ReshapeInfo::GetParameterInput() { + if (input_value_[1] == nullptr) { + MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr."; + return FAILED; + } + std::vector elements; + ValueTuplePtr dim_tuple = input_value_[1]->cast(); + if (dim_tuple == nullptr) { + MS_LOG(ERROR) << name_ << ": Input_value_[1] must be ValueTuplePtr."; + return FAILED; + } + elements = dim_tuple->value(); + if (elements.size() != outputs_shape_[0].size()) { + MS_LOG(ERROR) << name_ << ": Elements size must equal to outputs shape[0] size."; + return FAILED; + } + + for (auto &element : elements) { + MS_EXCEPTION_IF_NULL(element); + if (element->isa()) { + int32_t axis = element->cast()->value(); + parameter_input_v_.push_back(axis); + } else { + MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; + return FAILED; + } + } + return SUCCESS; +} + +Status ReshapeInfo::ComputeReplaceOp() { + RankList dev_list = global_device_list(); + TensorRedistribution tensor_redistribution(!is_generating_costs_, true); + if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) { + if (is_generating_costs_) { + MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed."; + } else { + MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; + } + return FAILED; + } + MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); + MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString(); + MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); + RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); + if (redistribution_oplist_ptr == nullptr) { + if (is_generating_costs_) { + MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed."; + } else { + MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; + } + return FAILED; + } + replace_op_ = redistribution_oplist_ptr->first; + replace_op_info_ = redistribution_oplist_ptr->second; + MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size(); + return SUCCESS; +} + +/* + * the first dimension of input tensor map and output tensor map is set to the last dimension of device arrangement, + * all other dimension is set to None + * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) + */ +Status ReshapeInfo::InferTensorMap() { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": inputs shape and outputs shape size must be 1. inputs shape and outputs shape are " + << inputs_shape_.size() << " and " << outputs_shape_.size(); + return FAILED; + } + + std::vector tensor_map_index_input; + tensor_map_index_input.push_back(0); + + for (size_t j = 1; j < inputs_shape_[0].size(); ++j) { + tensor_map_index_input.push_back(MAP_NONE); + } + inputs_tensor_map_.push_back(tensor_map_index_input); + + std::vector tensor_map_index_output; + tensor_map_index_output.push_back(0); + + for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { + tensor_map_index_output.push_back(MAP_NONE); + } + outputs_tensor_map_.push_back(tensor_map_index_output); + return SUCCESS; +} + +/* + * the output tensor strategy is the same as input tensor strategy + * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) + */ +Strategys ReshapeInfo::GetOutputsStrategy() { + Strategys outputs_strategy; + std::vector strategy; + strategy.push_back(input_strategy_[0]); + for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { + strategy.push_back(1); + } + outputs_strategy.push_back(strategy); + return outputs_strategy; +} + +Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + if (inputs_layout == nullptr || outputs_layout == nullptr) { + MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; + return FAILED; + } + Arrangement dev_matrix; + Status status = dev_matrix.Init(dev_matrix_shape_); + if (status != Status::SUCCESS) { + return status; + } + // infer input tensor info + Shape shape_array_in = inputs_shape_.at(0); + TensorMap tensor_map_array_in = inputs_tensor_map_.at(0); + TensorLayout tensor_layout_in; + Map tensor_map_in; + status = tensor_map_in.Init(tensor_map_array_in); + if (status != Status::SUCCESS) { + return status; + } + Arrangement shape_in; + status = shape_in.Init(shape_array_in); + if (status != Status::SUCCESS) { + return status; + } + (void)tensor_layout_in.Init(dev_matrix, tensor_map_in, shape_in); + inputs_layout->push_back(tensor_layout_in); + // infer output tensor info + Shape shape_array_out = outputs_shape_.at(0); + + TensorMap tensor_map_array_out = outputs_tensor_map_.at(0); + TensorLayout tensor_layout_out; + Map tensor_map_out; + status = tensor_map_out.Init(tensor_map_array_out); + if (status != Status::SUCCESS) { + return status; + } + Arrangement shape_out; + status = shape_out.Init(shape_array_out); + if (status != Status::SUCCESS) { + return status; + } + (void)tensor_layout_out.Init(dev_matrix, tensor_map_out, shape_out); + outputs_layout->push_back(tensor_layout_out); + + input_layout_ = tensor_layout_in; + output_layout_ = tensor_layout_out; + return SUCCESS; +} + +Status ReshapeInfo::InferTensorInfo() { + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = GetOutputsStrategy(); + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + TensorLayout tensor_layout_in = inputs_layout.at(0); + TensorLayout tensor_layout_out = outputs_layout.at(0); + Shape shape_array_in = inputs_shape_.at(0); + Shape slice_shape_in = inputs_slice_shape.at(0); + Shape shape_array_out = outputs_shape_.at(0); + Shape slice_shape_out = outputs_slice_shape.at(0); + TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in); + TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out); + inputs_tensor_info_.push_back(tensor_info_in); + outputs_tensor_info_.push_back(tensor_info_out); + return SUCCESS; +} + +void ReshapeInfo::InferTensorInfoByLayout() { + TensorInfo tensor_info_in(input_layout_); + TensorInfo tensor_info_out(output_layout_); + inputs_tensor_info_.push_back(tensor_info_in); + outputs_tensor_info_.push_back(tensor_info_out); +} + +/* + * compute parameter_input_v_ during this method + */ +Status ReshapeInfo::GetAttrs() { return GetParameterInput(); } + +void ReshapeInfo::device_number(const StrategyPtr &strategy) { + int32_t stage = 0; + if (strategy != nullptr) { + stage = strategy->GetInputStage(); + } + CheckGlobalDeviceManager(); + global_device_list_ = g_device_manager->GetDeviceListByStageId(stage); + dev_num_ = SizeToInt(global_device_list_.size()); + MS_ASSERT(dev_num_ > 0); +} + +Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) { + std::vector tensor_map_index; + for (size_t i = 0; i < shape.size(); i++) { + tensor_map_index.push_back(MAP_NONE); + } + Status status = layout->InitFromVector({dev_num_}, tensor_map_index, shape); + if (status != Status::SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferDefaultLayout failed."; + return status; + } + return Status::SUCCESS; +} + +Status ReshapeInfo::Init(const StrategyPtr &strategy) { + ResetQueueMember(); + device_number(strategy); + if (strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + } else { + if (!input_layout_set_flag_) { + MS_ASSERT(inputs_shape_.size() == 1); + Status status = InferDefaultLayout(inputs_shape_.at(0), &input_layout_); + if (status != SUCCESS) { + MS_LOG(ERROR) << name_ << ": infer input default layout failed."; + return status; + } + } + if (!output_layout_set_flag_) { + MS_ASSERT(output_layout_.size() == 1); + Status status = InferDefaultLayout(outputs_shape_.at(0), &output_layout_); + if (status != SUCCESS) { + MS_LOG(ERROR) << name_ << ": infer output default layout failed."; + return status; + } + } + inputs_tensor_map_.push_back(input_layout_.tensor_map().array()); + outputs_tensor_map_.push_back(output_layout_.tensor_map().array()); + InferTensorInfoByLayout(); + // change dev_matrix_shape_ to input_layout_ device_arrangement before InferMirrorOps + dev_matrix_shape_ = input_layout_.device_arrangement().array(); + if (InferMirrorOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; + return FAILED; + } + // change dev_matrix_shape_ to output_layout_ device_arrangement before InferVirtualDivOps + dev_matrix_shape_ = output_layout_.device_arrangement().array(); + if (InferVirtualDivOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; + return FAILED; + } + } + Status status = ComputeReplaceOp(); + if (status != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed."; + return status; + } + return SUCCESS; +} + +Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +void ReshapeInfo::SetCostForReshapeWithParameter() { + size_t success = 0; + for (auto &sp : sp_vector_) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } +} + +void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) { + MS_EXCEPTION_IF_NULL(strategy); + int32_t stage_id = strategy->GetInputStage(); + double computation_cost = + operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + std::shared_ptr result = std::make_shared(computation_cost, communication_cost); + result->communication_without_parameter_ = + operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + result->communication_with_partial_para_ = + result->communication_without_parameter_ + + COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); + + // Breaking ties for preferring data parallelization + BreakingTiesForPerferringDataParallel(strategy, result); + // refine communication cost calculation for practice + RefineForPracticalCost(result, false); + + std::shared_ptr swc = + std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); + swc->cost_list.push_back(result); + strategy_cost_.emplace_back(swc); +} + +Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GetAttrs failed."; + return FAILED; + } + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " + << outputs_shape_.size(); + return FAILED; + } + is_auto_parallel_ = true; + Shape input0_split; + (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + // strategy used only in the input node is parameter, + // in other case, use the input node's output_layout as input_layout. + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; + return FAILED; + } + return SUCCESS; +} + +Status ReshapeInfo::GenetateStrategyCosts(const std::vector> &pre_stra_costs, + const std::vector> &next_stra_costs, + int32_t out_index, int32_t in_index, bool is_prev_param) { + is_generating_costs_ = true; + for (auto pre_stra_cost : pre_stra_costs) { + std::vector pre_out_tensor_infos; + if (is_prev_param) { + pre_out_tensor_infos = pre_stra_cost->inputs_ptr; + } else { + pre_out_tensor_infos = pre_stra_cost->outputs_ptr; + } + if (pre_out_tensor_infos.size() <= IntToSize(out_index)) { + MS_LOG(ERROR) << "out_index is out of range of the tensor_infos in setting reshape's input_layout"; + return FAILED; + } + TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; + SetInputLayout(pre_out_tensor_info.tensor_layout()); + // infer pre_node output strategy from output_layout. + Dimensions stra = pre_out_tensor_info.InferStrategy(); + if (stra.empty()) { + MS_LOG(ERROR) << "Infer strategy by tensor_info failed"; + return FAILED; + } + std::vector stra_inputs = {stra}; + StrategyPtr reshape_stra = std::make_shared(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); + if (next_stra_costs.empty()) { + if (Init(nullptr) == FAILED) { + MS_LOG(ERROR) << "Failure:operator reshape init failed"; + return FAILED; + } + SetCostForReshape(reshape_stra); + continue; + } + for (auto next_stra_cost : next_stra_costs) { + std::vector next_in_tensor_infos = next_stra_cost->inputs_ptr; + if (next_in_tensor_infos.size() <= IntToSize(in_index)) { + MS_LOG(ERROR) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; + return FAILED; + } + TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; + SetOutputLayout(next_in_tensor_info.tensor_layout()); + if (Init(nullptr) == FAILED) { + MS_LOG(DEBUG) << "Failure:operator reshape init failed"; + continue; + } + SetCostForReshape(reshape_stra); + } + } + is_generating_costs_ = false; + if (strategy_cost_.empty()) { + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h new file mode 100644 index 0000000000..2463b440f8 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h @@ -0,0 +1,107 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ + +#include + +#include +#include +#include +#include + +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +/* + * parallel class for Reshape Primitive + */ +class ReshapeInfo : public OperatorInfo { + public: + ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), + dev_num_(0), + pre_operator_index_(0), + next_operator_index_(0), + input_layout_set_flag_(false), + output_layout_set_flag_(false) {} + ~ReshapeInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + void SetInputLayout(const TensorLayout &input_layout) { + input_layout_ = input_layout; + input_layout_set_flag_ = true; + } + void SetOutputLayout(const TensorLayout &output_layout) { + output_layout_ = output_layout; + output_layout_set_flag_ = true; + } + void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy); + void SetCostForReshapeWithParameter(); + void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; } + void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } + void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; } + void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; } + Status GenetateStrategyCosts(const std::vector> &pre_stra_costs, + const std::vector> &next_stra_costs, int32_t out_index, + int32_t in_index, bool is_prev_param); + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + std::string pre_operator_name() const { return pre_operator_name_; } + std::string next_operator_name() const { return next_operator_name_; } + int32_t pre_operator_index() const { return pre_operator_index_; } + int32_t next_operator_index() const { return next_operator_index_; } + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + Status GetAttrs() override; + Strategys GetOutputsStrategy(); + + private: + Status GetParameterInput(); + Status ComputeReplaceOp(); + void InferTensorInfoByLayout(); + void device_number(const StrategyPtr &strategy); + Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); + + int32_t dev_num_; + int32_t pre_operator_index_; + int32_t next_operator_index_; + std::vector parameter_input_v_; + std::vector sp_vector_; + Dimensions input_strategy_; + TensorLayout input_layout_; + TensorLayout output_layout_; + bool input_layout_set_flag_; + bool output_layout_set_flag_; + bool is_generating_costs_; + std::string pre_operator_name_; + std::string next_operator_name_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc new file mode 100644 index 0000000000..ed6eaa89f1 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc @@ -0,0 +1,147 @@ +/** +#include "utils/log_adapter.h" + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/tmp_identity_info.h" + +#include +#include + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": invalid strategy."; + } + return FAILED; + } + return SUCCESS; +} + +Status TmpIdentityInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + dev_matrix_shape_ = input_strategy; + return SUCCESS; +} + +Status TmpIdentityInfo::InferTensorMap() { + std::vector tensor_map_index; + size_t size = inputs_shape_[0].size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back((int32_t)(size - 1 - i)); + } + + inputs_tensor_map_.push_back(tensor_map_index); + outputs_tensor_map_.push_back(tensor_map_index); + return SUCCESS; +} + +Status TmpIdentityInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = {inputs_strategy.at(0)}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + + TensorLayout input_tensor_layout; + if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { + return FAILED; + } + + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(input_tensor_info); // the same as input + + return SUCCESS; +} + +Status TmpIdentityInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " + << outputs_shape_.size(); + return FAILED; + } + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h new file mode 100644 index 0000000000..7f73f81180 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ + +#include +#include +#include + +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class TmpIdentityInfo : public OperatorInfo { + // This operator is only used for the case of a parameter tensor being used by multiple operators, where we + // consider this parameter tensor as TmpIdentityInfo operator. TmpIdentityInfo operator tasks as input a tensor, + // and outputs the same tensor. After the transformation, subsequent operators can share the output tensor. + public: + TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs, + const std::string &name = IDENTITY_INFO) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~TmpIdentityInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override { return SUCCESS; } + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc new file mode 100644 index 0000000000..b6bb875abc --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc @@ -0,0 +1,247 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/transpose_info.h" + +#include +#include + +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + return SUCCESS; +} + +Status TransposeInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + input_strategy_ = stra.at(0); + for (auto &iter : input_strategy_) { + dev_matrix_shape_.push_back(iter); + } + return SUCCESS; +} + +// there is no Parameter for Transpose Primitive, so no need to do all reduce +Status TransposeInfo::InferMirrorOps() { return SUCCESS; } + +// there is no reduction dimension for forward computation of Transpose Primitive, so no need to do all reduce +Status TransposeInfo::InferForwardCommunication() { return SUCCESS; } + +/* + * get perm input of Transpose Primitive + * perm is a permutation of the dimensions of input + * the result is saved in axis_v_ + */ +Status TransposeInfo::ComputeAxis() { + if (input_value_[1] == nullptr) { + MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr."; + return FAILED; + } + std::vector elements; + ValueTuplePtr dim_tuple = input_value_[1]->cast(); + if (dim_tuple == nullptr) { + MS_LOG(ERROR) << name_ << ": input_value_[1] must be ValueTuplePtr."; + return FAILED; + } + elements = dim_tuple->value(); + if (elements.size() != inputs_shape_[0].size()) { + MS_LOG(ERROR) << name_ << ": elements size must equal to inputs shape 0 size."; + return FAILED; + } + axis_v_.clear(); + for (auto &element : elements) { + MS_EXCEPTION_IF_NULL(element); + if (element->isa()) { + int32_t axis = element->cast()->value(); + axis_v_.push_back(axis); + } else { + MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; + return FAILED; + } + } + + for (int32_t i = 0; i < SizeToInt(axis_v_.size()); i++) { + auto iter = std::find(axis_v_.begin(), axis_v_.end(), i); + if (iter == axis_v_.end()) { + MS_LOG(ERROR) << name_ << ": axis_v_ must be a permutation."; + } + } + return SUCCESS; +} + +// the output tensor map is the permutation of input tensor map, the permutation is axis_v +Status TransposeInfo::InferTensorMap() { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": inputs_shape_ and outputs_shape_ size must be 1, inputs shape and outputs shape is " + << inputs_shape_.size() << ", " << outputs_shape_.size(); + return FAILED; + } + + std::vector tensor_map_index_input; + for (size_t j = 0; j < inputs_shape_[0].size(); ++j) { + tensor_map_index_input.push_back(SizeToInt(inputs_shape_[0].size() - j - 1)); + } + inputs_tensor_map_.push_back(tensor_map_index_input); + + std::vector tensor_map_index_output = tensor_map_index_input; + for (uint32_t i = 0; i < tensor_map_index_output.size(); i++) { + tensor_map_index_output[i] = tensor_map_index_input[IntToUint(axis_v_[i])]; + } + outputs_tensor_map_.push_back(tensor_map_index_output); + return SUCCESS; +} + +// the output tensor strategy is the permutation of input tensor strategy, the permutation is axis_v +Strategys TransposeInfo::GetOutputsStrategy() { + Strategys outputs_strategy; + std::vector strategy = input_strategy_; + for (uint32_t i = 0; i < strategy.size(); i++) { + strategy[i] = input_strategy_[IntToUint(axis_v_[i])]; + } + outputs_strategy.push_back(strategy); + return outputs_strategy; +} + +Status TransposeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { + MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; + return FAILED; + } + Shape shape_in = inputs_shape_.at(0); + TensorMap tensor_map_in = inputs_tensor_map_.at(0); + Shape shape_out = outputs_shape_.at(0); + TensorMap tensor_map_out = outputs_tensor_map_.at(0); + + TensorLayout tensor_layout_in, tensor_layout_out; + if ((tensor_layout_in.InitFromVector(dev_matrix_shape_, tensor_map_in, shape_in) != SUCCESS) || + (tensor_layout_out.InitFromVector(dev_matrix_shape_, tensor_map_out, shape_out) != SUCCESS)) { + return FAILED; + } + + inputs_layout->push_back(tensor_layout_in); + outputs_layout->push_back(tensor_layout_out); + return SUCCESS; +} + +Status TransposeInfo::InferTensorInfo() { + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = GetOutputsStrategy(); + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + TensorLayout tensor_layout_in = inputs_layout.at(0); + TensorLayout tensor_layout_out = outputs_layout.at(0); + Shape shape_array_in = inputs_shape_.at(0); + Shape slice_shape_in = inputs_slice_shape.at(0); + Shape shape_array_out = outputs_shape_.at(0); + Shape slice_shape_out = outputs_slice_shape.at(0); + TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in); + TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out); + inputs_tensor_info_.push_back(tensor_info_in); + outputs_tensor_info_.push_back(tensor_info_out); + return SUCCESS; +} + +// compute axis_v_ during this method +Status TransposeInfo::GetAttrs() { return ComputeAxis(); } + +Status TransposeInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status TransposeInfo::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GetAttrs failed."; + return FAILED; + } + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " + << outputs_shape_.size(); + return FAILED; + } + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << "strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h new file mode 100644 index 0000000000..d3b62dc234 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +/* + * parallel class for Transpose Primitive + */ +class TransposeInfo : public OperatorInfo { + public: + TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~TransposeInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + Status GetAttrs() override; + Strategys GetOutputsStrategy(); + + private: + Status ComputeAxis(); + std::vector axis_v_; + Dimensions input_strategy_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc new file mode 100644 index 0000000000..3b89d7c84c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc @@ -0,0 +1,229 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/virtual_dataset_info.h" + +#include +#include +#include + +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + if (stra.size() < 1) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Strategy size must be larger than 1."; + } else { + MS_LOG(ERROR) << name_ << ": Strategy size must be larger than 1."; + } + return FAILED; + } + if (stra.size() == 1) { + MS_LOG(WARNING) << name_ << ": Strategy size is 1."; + return SUCCESS; + } + Dimensions strategy_first = stra.at(1); + for (auto iter_strategy = stra.begin() + 1; iter_strategy != stra.end(); ++iter_strategy) { + if (iter_strategy->empty()) { + MS_LOG(ERROR) << name_ << ": iter_strategy size is zero."; + } + if (strategy_first.at(0) != *(iter_strategy->begin())) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": The first dimension of each strategy must be the same."; + } else { + MS_LOG(ERROR) << name_ << ": The first dimension of each strategy must be the same."; + } + return FAILED; + } + + for (auto iter_element = iter_strategy->begin() + 1; iter_element != iter_strategy->end(); ++iter_element) { + if (*iter_element != 1) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": All dimension except the first dimension of each strategy must be 1."; + } else { + MS_LOG(ERROR) << name_ << ": All dimension except the first dimension of each strategy must be 1."; + } + return FAILED; + } + } + } + return SUCCESS; +} + +Status VirtualDatasetInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions strategy_first = stra.at(0); + int32_t stage = strategy_->GetInputStage(); + CheckGlobalDeviceManager(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); + int32_t batch_split_num = strategy_first.at(0); + dev_matrix_shape_.push_back(batch_split_num); + if (dev_num > batch_split_num) { + dev_matrix_shape_.push_back(dev_num / batch_split_num); + } + + return SUCCESS; +} + +Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; } + +Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; } + +Status VirtualDatasetInfo::InferTensorMap() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + + for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { + std::vector tensor_map_index; + if (full_batch) { + tensor_map_index.push_back(MAP_NONE); + } else { + tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); + } + for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) { + tensor_map_index.push_back(MAP_NONE); + } + inputs_tensor_map_.push_back(tensor_map_index); + outputs_tensor_map_.push_back(tensor_map_index); + } + return SUCCESS; +} + +Status VirtualDatasetInfo::InferTensorInfo() { + for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { + MS_LOG(INFO) << name_ << ": InferTensorInfo " << i << ", size " << strategy_->GetInputNumber(); + TensorLayout tensor_layout_in; + if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) { + return FAILED; + } + TensorInfo tensor_info_in(tensor_layout_in); + inputs_tensor_info_.push_back(tensor_info_in); + outputs_tensor_info_.push_back(tensor_info_in); + } + return SUCCESS; +} + +Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; } + +Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) { + if (InitWithManualRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + return SUCCESS; +} + +Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { + for (size_t i = 0; i < inputs_shape_.size(); i++) { + split_flag_list_[i] = true; + } +} + +Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + size_t total_dev_num; + + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GetAttrs failed"; + return FAILED; + } + + CheckGlobalDeviceManager(); + is_auto_parallel_ = true; + if (full_batch) { + total_dev_num = 1; + } else { + total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + } + StrategyPtr sp; + std::vector strategy; + for (auto &shape : inputs_shape_) { + Shape temp; + temp.emplace_back(SizeToInt(total_dev_num)); + (void)temp.insert(temp.end(), shape.size() - 1, 1); + strategy.push_back(temp); + } + sp = std::make_shared(stage_id, strategy); + + if (SetCostUnderStrategy(sp) == SUCCESS) { + if (full_batch) { + MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy."; + } else { + MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy."; + } + PrintStrategy(sp); + } else { + if (full_batch) { + MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +Status VirtualDatasetInfo::InferAsLossDivisor() { + // no need to insert div op + as_loss_divisor_ = 1; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h new file mode 100644 index 0000000000..fe54954be0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_OPS_INFO_DATASET_INFO_H_ +#define PARALLEL_OPS_INFO_DATASET_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class VirtualDatasetInfo : public OperatorInfo { + public: + VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~VirtualDatasetInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + Status InferAsLossDivisor() override; +}; +} // namespace parallel +} // namespace mindspore + +#endif // PARALLEL_OPS_INFO_VIRTUAL_DATASET_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/common.h b/mindspore/ccsrc/frontend/parallel/ps/common.h new file mode 100644 index 0000000000..5e136c816f --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/common.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_COMMON_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_COMMON_H_ + +#include +#include +#include +#include "ps/ps.h" + +namespace mindspore { +namespace parallel { +namespace ps { +constexpr char kEnvCommType[] = "MS_COMM_TYPE"; +constexpr char kEnvInterface[] = "MS_INTERFACE"; +constexpr char kEnvPServerNum[] = "MS_SERVER_NUM"; +constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM"; +constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST"; +constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT"; + +constexpr char kEnvRole[] = "MS_ROLE"; +constexpr char kEnvRoleOfPServer[] = "MS_PSERVER"; +constexpr char kEnvRoleOfWorker[] = "MS_WORKER"; +constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; + +constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE"; +constexpr char kDmlcInterface[] = "DMLC_INTERFACE"; +constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER"; +constexpr char kDmlcWorkerNum[] = "DMLC_NUM_WORKER"; +constexpr char kDmlcRole[] = "DMLC_ROLE"; +constexpr char kDmlcSchedulerHost[] = "DMLC_PS_ROOT_URI"; +constexpr char kDmlcSchedulerPort[] = "DMLC_PS_ROOT_PORT"; + +constexpr char kCommTypeOfIBVerbs[] = "ibverbs"; +constexpr char kCommTypeOfTCP[] = "zmq"; +constexpr char kRoleOfPServer[] = "server"; +constexpr char kRoleOfWorker[] = "worker"; +constexpr char kRoleOfScheduler[] = "scheduler"; + +constexpr char kLearningRate[] = "learning_rate"; +constexpr char kMomentum[] = "momentum"; + +constexpr char kApplyMomentum[] = "ApplyMomentum"; +constexpr char kSparseAdam[] = "Adam"; +constexpr char kSparseFtrl[] = "Ftrl"; + +constexpr int kInitWeightsCmd = 10; +constexpr int kInitWeightToOptimIdCmd = 11; +constexpr int kInitOptimInputsShapeCmd = 12; +constexpr int kInitEmbeddingsCmd = 20; +constexpr int kEmbeddingLookupCmd = 30; + +constexpr size_t kInvalidKey = UINT64_MAX; + +using Key = ::ps::Key; +using Keys = ::ps::SArray; +using Values = ::ps::SArray; +using ValuesPtr = std::shared_ptr; +using Weight = ::ps::SArray; +using Grad = ::ps::SArray; +using LookupIds = ::ps::SArray; +using Lengths = ::ps::SArray; +using WeightPtr = std::shared_ptr; +using GradPtr = std::shared_ptr; +// using EmbeddingTable = std::unordered_map; +// using EmbeddingTable = ::ps::SArray; +// using EmbeddingTablePtr = std::shared_ptr; +using InputsShape = std::vector>>; +using InputsShapePtr = std::shared_ptr>>>; +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_COMMON_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc new file mode 100644 index 0000000000..e16c713e3c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ps/optimizer_info.h" +#include + +namespace mindspore { +namespace parallel { +namespace ps { +void OptimizerInfo::AddWorkspace(const AddressPtr &workspace) { workspaces_.push_back(workspace); } + +const std::vector &OptimizerInfo::inputs() { return inputs_; } + +const std::vector &OptimizerInfo::workspaces() { return workspaces_; } + +const std::vector &OptimizerInfo::outputs() { return outputs_; } + +bool OptimizerInfo::IsSparse() const { return false; } + +size_t OptimizerInfo::grad_index() { return 0; } + +size_t OptimizerInfo::indices_index() { return 0; } + +void OptimizerInfo::UpdateWeight(const WeightPtr &weight) { + AddressPtr weight_addr = std::make_shared(); + weight_addr->addr = weight->data(); + weight_addr->size = weight->size(); + inputs_[0] = weight_addr; +} + +void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { + float *accum_grad_data = reinterpret_cast(gradient()->addr); + size_t size = gradient()->size / sizeof(float); + size_t grad_index = this->grad_index(); + size_t grad_offset = 0; + for (size_t i = 0; i < grad_index; i++) { + grad_offset += lengths[i]; + } + float *grad_data = values.data() + grad_offset; + CHECK_EQ(size, static_cast(lengths[grad_index])); + + for (size_t i = 0; i < size; i++) { + accum_grad_data[i] += grad_data[i]; + } +} + +void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { + // Append grad data to the end + float *accum_grad_data = reinterpret_cast(gradient()->addr); + + size_t grad_index = this->grad_index(); + size_t grad_offset = 0; + for (size_t i = 0; i < grad_index; i++) { + grad_offset += lengths[i]; + } + float *incr_grad_data = values.data() + grad_offset; + size_t incr_grad_size = lengths[grad_index] * sizeof(float); + + auto ret = memcpy_s(accum_grad_data + grads_offset_, incr_grad_size, incr_grad_data, incr_grad_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + grads_offset_ += incr_grad_size; + gradient()->size += incr_grad_size; + + // Append indice data to the end + int *accum_indices_data = reinterpret_cast(indices()->addr); + + size_t indices_index = this->indices_index(); + size_t indice_offset = 0; + for (size_t i = 0; i < indices_index; i++) { + indice_offset += lengths[i]; + } + int *incr_indice_data = reinterpret_cast(values.data() + indice_offset); + size_t incr_indice_size = lengths[indices_index] * sizeof(float); + + auto ret2 = memcpy_s(accum_indices_data + indices_offset_, incr_indice_size, incr_indice_data, incr_indice_size); + if (ret2 != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; + } + indices_offset_ += incr_indice_size; + indices()->size += incr_indice_size; +} + +void SparseOptimInfo::Reset() { + auto &gradient = this->gradient(); + gradient->size = 0; + auto &indices = this->indices(); + indices->size = 0; + grads_offset_ = 0; + indices_offset_ = 0; +} + +MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate, + const AddressPtr &learning_rate, const AddressPtr &gradient, + const AddressPtr &momentum) { + inputs_.push_back(weight); + inputs_.push_back(accumulate); + inputs_.push_back(learning_rate); + inputs_.push_back(gradient); + inputs_.push_back(momentum); +} + +const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } + +const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; } + +SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, + const AddressPtr &beta1_power, const AddressPtr &beta2_power, + const AddressPtr &learning_rate, const AddressPtr &beta1, + const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, + const AddressPtr &indices, size_t grads_offset, size_t indices_offset) { + inputs_.push_back(weight); + inputs_.push_back(m); + inputs_.push_back(v); + inputs_.push_back(beta1_power); + inputs_.push_back(beta2_power); + inputs_.push_back(learning_rate); + inputs_.push_back(beta1); + inputs_.push_back(beta2); + inputs_.push_back(epsilon); + inputs_.push_back(grad); + inputs_.push_back(indices); + grads_offset_ = grads_offset; + indices_offset_ = indices_offset; +} + +void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { + void *data_ptr = values.data(); + AddressPtr beta1_power = inputs_[3]; + size_t size = values.size() * sizeof(float); + auto ret = memcpy_s(beta1_power->addr, size, data_ptr, size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } +} + +const AddressPtr &SparseAdamOptimInfo::gradient() { return inputs_[9]; } + +const AddressPtr &SparseAdamOptimInfo::indices() { return inputs_[10]; } + +bool SparseAdamOptimInfo::IsSparse() const { return true; } + +size_t SparseAdamOptimInfo::grad_index() { return 6; } + +size_t SparseAdamOptimInfo::indices_index() { return 7; } + +SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, + const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, + size_t indices_offset) { + inputs_.push_back(weight); + inputs_.push_back(accum); + inputs_.push_back(linear); + inputs_.push_back(grad); + inputs_.push_back(indices); + grads_offset_ = grads_offset; + indices_offset_ = indices_offset; +} + +const AddressPtr &SparseFtrlOptimInfo::gradient() { return inputs_[3]; } + +const AddressPtr &SparseFtrlOptimInfo::indices() { return inputs_[4]; } + +bool SparseFtrlOptimInfo::IsSparse() const { return true; } + +size_t SparseFtrlOptimInfo::grad_index() { return 0; } + +size_t SparseFtrlOptimInfo::indices_index() { return 1; } +} // namespace ps +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h new file mode 100644 index 0000000000..bb9a64acdb --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ + +#include +#include "backend/kernel_compiler/kernel.h" +#include "frontend/parallel/ps/common.h" + +namespace mindspore { +namespace parallel { +namespace ps { +using mindspore::kernel::AddressPtr; +class OptimizerInfo { + public: + OptimizerInfo() = default; + virtual ~OptimizerInfo() = default; + + virtual void Update(const Values &values, const Lengths &lengths) {} + virtual void UpdateWeight(const WeightPtr &weight); + virtual void Accumulate(const Values &values, const Lengths &lengths) = 0; + virtual void Reset() {} + void AddWorkspace(const AddressPtr &workspace); + + virtual const AddressPtr &gradient() = 0; + virtual const AddressPtr &indices() = 0; + const std::vector &inputs(); + const std::vector &workspaces(); + const std::vector &outputs(); + + virtual bool IsSparse() const; + virtual size_t grad_index(); + virtual size_t indices_index(); + + protected: + std::vector inputs_; + std::vector workspaces_; + std::vector outputs_; +}; + +class DenseOptimInfo : public OptimizerInfo { + public: + DenseOptimInfo() = default; + ~DenseOptimInfo() override = default; + + void Accumulate(const Values &values, const Lengths &lens) override; +}; + +class SparseOptimInfo : public OptimizerInfo { + public: + SparseOptimInfo() = default; + ~SparseOptimInfo() override = default; + + void Accumulate(const Values &values, const Lengths &lens) override; + void Reset() override; + + protected: + size_t grads_offset_{0}; + size_t indices_offset_{0}; +}; + +class MomentumOptimInfo : public DenseOptimInfo { + public: + MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate, const AddressPtr &learning_rate, + const AddressPtr &gradient, const AddressPtr &momentum); + ~MomentumOptimInfo() override = default; + + const AddressPtr &gradient(); + const AddressPtr &indices(); +}; + +class SparseAdamOptimInfo : public SparseOptimInfo { + public: + SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power, + const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1, + const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, + const AddressPtr &indices, size_t grads_offset, size_t indices_offset); + ~SparseAdamOptimInfo() override = default; + + void Update(const Values &values, const Lengths &lens) override; + const AddressPtr &gradient(); + const AddressPtr &indices(); + bool IsSparse() const override; + size_t grad_index() override; + size_t indices_index() override; +}; + +class SparseFtrlOptimInfo : public SparseOptimInfo { + public: + SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, + const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, size_t indices_offset); + ~SparseFtrlOptimInfo() override = default; + + const AddressPtr &gradient(); + const AddressPtr &indices(); + bool IsSparse() const override; + size_t grad_index() override; + size_t indices_index() override; +}; +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc new file mode 100644 index 0000000000..159a50793e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ps/optimizer_info_builder.h" +#include +#include +#include + +namespace mindspore { +namespace parallel { +namespace ps { +OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr &pserver_kernel, + const WeightPtr &weight, const Keys &keys, const Values &values, + const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) { + OptimizerInfo *optim_info = BuildInputs(weight, keys, values, lens, inputs_shape, worker_num); + std::vector ws_sizes = pserver_kernel->workspace_sizes(); + BuildWorkspaces(optim_info, ws_sizes, worker_num); + BuildOutputs(optim_info, worker_num); + return optim_info; +} + +void OptimizerInfoBuilder::BuildWorkspaces(OptimizerInfo *info, const std::vector &ws_sizes, + size_t worker_num) { + for (size_t i = 0; i < ws_sizes.size(); i++) { + size_t size = ws_sizes[i]; + AddressPtr workspace = std::make_shared(); + workspace->addr = new float[size]; + workspace->size = size; + info->AddWorkspace(workspace); + } +} + +OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, + const Lengths &lens, const InputsShapePtr &inputs_shape, + size_t worker_num) { + AddressPtr weight_addr = std::make_shared(); + weight_addr->addr = weight->data(); + weight_addr->size = weight->size(); + void *data_ptr = values.data(); + AddressPtr accumulate = std::make_shared(); + accumulate->addr = new float[weight->size()]; + accumulate->size = weight->size(); + AddressPtr learning_rate = std::make_shared(); + learning_rate->addr = data_ptr; + learning_rate->size = lens[0]; + AddressPtr gradient = std::make_shared(); + gradient->addr = reinterpret_cast(learning_rate->addr) + lens[0]; + gradient->size = lens[1]; + AddressPtr momentum = std::make_shared(); + momentum->addr = reinterpret_cast(gradient->addr) + lens[1]; + momentum->size = lens[2]; + + return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum); +} + +OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, + const Lengths &lens, const InputsShapePtr &inputs_shape, + size_t worker_num) { + AddressPtr weight_addr = std::make_shared(); + weight_addr->addr = weight->data(); + weight_addr->size = weight->size(); + AddressPtr m = std::make_shared(); + m->addr = new float[weight->size()]; + m->size = weight->size() * sizeof(float); + AddressPtr v = std::make_shared(); + v->addr = new float[weight->size()]; + v->size = weight->size() * sizeof(float); + + void *data_ptr = values.data(); + void *copy_data_ptr = new float[values.size()]; + auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + AddressPtr beta1_power = std::make_shared(); + beta1_power->addr = copy_data_ptr; + beta1_power->size = lens[0] * sizeof(float); + AddressPtr beta2_power = std::make_shared(); + beta2_power->addr = reinterpret_cast(beta1_power->addr) + lens[0]; + beta2_power->size = lens[1] * sizeof(float); + + AddressPtr learning_rate = std::make_shared(); + learning_rate->addr = reinterpret_cast(beta2_power->addr) + lens[1]; + learning_rate->size = lens[2] * sizeof(float); + + AddressPtr beta1 = std::make_shared(); + beta1->addr = reinterpret_cast(learning_rate->addr) + lens[2]; + beta1->size = lens[3] * sizeof(float); + + AddressPtr beta2 = std::make_shared(); + beta2->addr = reinterpret_cast(beta1->addr) + lens[3]; + beta2->size = lens[4] * sizeof(float); + + AddressPtr epsilon = std::make_shared(); + epsilon->addr = reinterpret_cast(beta2->addr) + lens[4]; + epsilon->size = lens[5] * sizeof(float); + + const std::shared_ptr> &grad_shape = (*inputs_shape)[9]; + size_t total_grad_size = + std::accumulate((*grad_shape).begin(), (*grad_shape).end(), sizeof(float), std::multiplies()); + AddressPtr grad = std::make_shared(); + grad->addr = new float[total_grad_size * worker_num]; + auto ret2 = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast(epsilon->addr) + lens[5], + lens[6] * sizeof(float)); + if (ret2 != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; + } + grad->size = lens[6] * sizeof(float); + + const std::shared_ptr> &indices_shape = (*inputs_shape)[10]; + size_t total_indice_size = + std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(float), std::multiplies()); + AddressPtr indices = std::make_shared(); + indices->addr = new float[total_indice_size * worker_num]; + auto ret3 = memcpy_s(indices->addr, lens[7] * sizeof(float), + reinterpret_cast(epsilon->addr) + lens[5] + lens[6], lens[7] * sizeof(float)); + if (ret3 != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret3 << ")"; + } + indices->size = lens[7] * sizeof(float); + + return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, + grad, indices, total_grad_size, total_indice_size); +} + +OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, + const Lengths &lens, const InputsShapePtr &inputs_shape, + size_t worker_num) { + AddressPtr weight_addr = std::make_shared(); + weight_addr->addr = weight->data(); + weight_addr->size = weight->size(); + AddressPtr accum = std::make_shared(); + accum->addr = new float[weight->size()]; + accum->size = weight->size() * sizeof(float); + for (size_t i = 0; i < weight->size(); i++) { + float *tmp = reinterpret_cast(accum->addr); + tmp[i] = 1.0; + } + AddressPtr linear = std::make_shared(); + linear->addr = new float[weight->size()]; + memcpy_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); + linear->size = weight->size() * sizeof(float); + + const std::shared_ptr> &grad_shape = (*inputs_shape)[3]; + size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies()); + AddressPtr grad = std::make_shared(); + grad->addr = new float[total_grad_size * worker_num]; + auto ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + grad->size = lens[0] * sizeof(float); + + const std::shared_ptr> &indices_shape = (*inputs_shape)[4]; + size_t total_indice_size = + std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies()); + AddressPtr indices = std::make_shared(); + indices->addr = new float[total_indice_size * worker_num]; + auto ret2 = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast(values.data()) + lens[0], + lens[1] * sizeof(float)); + if (ret2 != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; + } + indices->size = lens[1] * sizeof(float); + + return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, total_grad_size, total_indice_size); +} +} // namespace ps +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h new file mode 100644 index 0000000000..c5aae32921 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ + +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/ps/pserver_kernel.h" +#include "frontend/parallel/ps/optimizer_info.h" + +namespace mindspore { +namespace parallel { +namespace ps { +using mindspore::kernel::KernelMod; +using mindspore::kernel::ps::PServerKernel; +class OptimizerInfoBuilder { + public: + OptimizerInfoBuilder() = default; + virtual ~OptimizerInfoBuilder() = default; + + OptimizerInfo *Build(const std::shared_ptr &pserver_kernel, const WeightPtr &weight, const Keys &keys, + const Values &values, const Lengths &lens, const InputsShapePtr &inputs_shape, + size_t worker_num); + + virtual OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, + const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) = 0; + + virtual void BuildWorkspaces(OptimizerInfo *info, const std::vector &ws_sizes, size_t worker_num); + virtual void BuildOutputs(OptimizerInfo *info, size_t worker_num) {} +}; + +class MomentumOptimInfoBuilder : public OptimizerInfoBuilder { + public: + OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, + const InputsShapePtr &inputs_shape, size_t worker_num) override; +}; + +class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder { + public: + OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, + const InputsShapePtr &inputs_shpae, size_t worker_num) override; +}; + +class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder { + public: + OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, + const InputsShapePtr &inputs_shpae, size_t worker_num) override; +}; +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h new file mode 100755 index 0000000000..1afb4c9fa6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -0,0 +1,559 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/session_factory.h" +#include "frontend/parallel/ps/common.h" +#include "frontend/parallel/ps/optimizer_info.h" +#include "frontend/parallel/ps/optimizer_info_builder.h" +#include "frontend/parallel/ps/util.h" +#include "runtime/device/cpu/kernel_select_cpu.h" +#include "utils/context/ms_context.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "backend/kernel_compiler/ps/sparse_apply_adam_ps_kernel.h" +#include "backend/kernel_compiler/ps/sparse_apply_ftrl_ps_kernel.h" +#include "backend/kernel_compiler/ps/apply_momentum_ps_kernel.h" +#include "backend/kernel_compiler/ps/embedding_look_up_ps_kernel.h" + +namespace mindspore { +namespace parallel { +namespace ps { +using mindspore::kernel::ps::PServerKernel; +template +class ParameterServer { + public: + static ParameterServer &GetInstance() { + static ParameterServer instance; + return instance; + } + + void Run(const FuncGraphPtr &func_graph); + + private: + ParameterServer() + : pserver_num_(0), + worker_num_(0), + rank_id_(0), + grad_accum_count_(0), + ps_(new ::ps::KVServer(0)), + handler_(nullptr), + func_graph_(nullptr), + kernel_graph_(nullptr), + sess_(nullptr), + thread_(nullptr) {} + ~ParameterServer() = default; + ParameterServer(const ParameterServer &) = delete; + ParameterServer &operator=(const ParameterServer &) = delete; + + struct ServerHandler { + explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} + void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server); + void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data); + void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleInitWeights(const ::ps::KVPairs &req_data); + void HandleInitWeightToOptimId(const ::ps::KVPairs &req_data); + void HandleInitInputsShape(const ::ps::KVPairs &req_data); + void HandleInitEmbeddings(const ::ps::KVPairs &req_data); + void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + ParameterServer *ps_; + }; + + bool Init(const FuncGraphPtr &func_graph); + void InitOptimInfoBuilders(); + void InitWeightKeyToOptims(const Key &key, const int &optim_id); + void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); + void InitWeight(const Key &key, const WeightPtr &weight); + void InitGrad(const Key &key, const GradPtr &grad); + void InitEmbeddingTable(const Key &key, + const std::shared_ptr>>> &shapes); + void UpdateWeights(); + void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); + WeightPtr weight(const Key &key); + void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); + int SumOfShapes(const std::vector &shapes) const; + size_t PreComputeCapacity(const Keys &keys, const Lengths &lens); + bool ReadyForUpdateWeights(); + bool ReadyForAccumGrads(); + void ResetGradAccumCount(); + + size_t pserver_num_; + size_t worker_num_; + size_t rank_id_; + size_t grad_accum_count_; + std::unique_ptr<::ps::KVServer> ps_; + std::unique_ptr handler_; + FuncGraphPtr func_graph_; + std::shared_ptr kernel_graph_; + std::shared_ptr sess_; + + std::unordered_map> optimizers_; + std::unordered_map optim_inputs_shape_; + std::unordered_map> optim_infos_; + std::unordered_map> optim_info_builders_; + std::unordered_map weight_key_to_optims_; + std::unordered_map weights_; + std::unordered_map grads_; + std::unordered_map grads_accum_counter_; + // std::unordered_map embeddings_; + std::unordered_map> embedding_lookup_ops_; + std::unordered_map embedding_row_lens_; + + T learning_rate_; + T momentum_; + + std::mutex mutex_; + std::condition_variable apply_grads_cv_; + std::condition_variable accum_grads_cv_; + + std::unique_ptr thread_; + + friend struct ServerHandler; +}; + +class FuncGraph; +template +void ParameterServer::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVServer *server) { + ::ps::KVPairs res; + if (req_meta.cmd == kInitWeightsCmd) { + MS_LOG(ERROR) << "handle init weights cmd" << std::endl; + HandleInitWeights(req_data); + } else if (req_meta.cmd == kInitWeightToOptimIdCmd) { + MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl; + HandleInitWeightToOptimId(req_data); + } else if (req_meta.cmd == kInitOptimInputsShapeCmd) { + MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl; + HandleInitInputsShape(req_data); + } else if (req_meta.cmd == kInitEmbeddingsCmd) { + MS_LOG(ERROR) << "handle init embedding cmd" << std::endl; + HandleInitEmbeddings(req_data); + } else if (req_meta.cmd == kEmbeddingLookupCmd) { + MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl; + HandleEmbeddingLookup(req_meta, req_data, &res); + } else if (req_meta.push) { + MS_LOG(ERROR) << "handle push req cmd" << std::endl; + HandlePushReq(req_meta, req_data); + } else { + MS_LOG(ERROR) << "handle pull req cmd" << std::endl; + HandlePullReq(req_meta, req_data, &res); + } + server->Response(req_meta, res); +} + +template +void ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data) { + ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens); +} + +template +void ParameterServer::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { + res->keys = req_data.keys; + ::ps::Key key = req_data.keys[0]; + res->vals = *(ps_->weight(key)); +} + +template +void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVPairs &req_data) { + size_t key_num = req_data.keys.size(); + T *data_ptr = req_data.vals.data(); + size_t pos = 0; + for (size_t i = 0; i < key_num; i++) { + Key key = req_data.keys[i]; + size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i]; + + WeightPtr weight_ptr = std::make_shared<::ps::SArray>(); + weight_ptr->CopyFrom(data_ptr + pos, data_len); + ps_->InitWeight(key, weight_ptr); + + GradPtr grad_ptr = std::make_shared<::ps::SArray>(data_len, 0); + ps_->InitGrad(key, grad_ptr); + pos += data_len; + } +} + +template +void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs &req_data) { + size_t key_num = req_data.keys.size(); + for (size_t i = 0; i < key_num; i++) { + Key key = req_data.keys[i]; + T val = req_data.vals[i]; + ps_->InitWeightKeyToOptims(key, val); + } +} + +template +void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs &req_data) { + ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens); +} + +template +void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs &req_data) { + std::shared_ptr>>> shapes = + std::make_shared>>>(); + std::shared_ptr> input_shape = std::make_shared>(); + std::shared_ptr> indices_shape = std::make_shared>(); + std::shared_ptr> output_shape = std::make_shared>(); + shapes->push_back(input_shape); + shapes->push_back(indices_shape); + shapes->push_back(output_shape); + + const Key &key = req_data.keys[0]; + const Lengths &lens = req_data.lens; + size_t index = 0; + for (int i = 0; i < lens[0]; i++) { + input_shape->push_back(static_cast(req_data.vals[index++])); + } + for (int j = 0; j < lens[1]; j++) { + indices_shape->push_back(static_cast(req_data.vals[index++])); + } + for (int k = 0; k < lens[2]; k++) { + output_shape->push_back(static_cast(req_data.vals[index++])); + } + ps_->InitEmbeddingTable(key, shapes); +} + +template +void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { + const Key &key = req_data.keys[0]; + ps_->DoEmbeddingLookup(key, req_data.vals, res); + for (size_t i = 0; i < req_data.vals.size(); i++) { + res->keys->push_back(req_data.vals[i]); + } +} + +template +bool ParameterServer::Init(const FuncGraphPtr &func_graph) { + const char *server_num = getenv(kEnvPServerNum); + const char *worker_num = getenv(kEnvWorkerNum); + if (server_num != nullptr) { + pserver_num_ = *server_num - '0'; + } + if (worker_num != nullptr) { + worker_num_ = *worker_num - '0'; + } + func_graph_ = func_graph; + rank_id_ = ::ps::MyRank(); + handler_.reset(new ServerHandler(this)); + + InitOptimInfoBuilders(); + + ps_->set_request_handle(*handler_); + thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); + return true; +} + +template +void ParameterServer::InitOptimInfoBuilders() { + std::shared_ptr momentum_info_builder = std::make_shared(); + std::shared_ptr sparse_adam_info_builder = std::make_shared(); + std::shared_ptr sparse_ftrl_info_builder = std::make_shared(); + optim_info_builders_[kApplyMomentum] = momentum_info_builder; + optim_info_builders_[kSparseAdam] = sparse_adam_info_builder; + optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder; +} + +template +void ParameterServer::InitWeightKeyToOptims(const Key &key, const int &optim_id) { + if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(key) == "") { + return; + } + weight_key_to_optims_[key] = Util::optimizer_name(optim_id); +} + +template +void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) { + InputsShapePtr inputs_shape = std::make_shared(); + int val_idx = 0; + const Key &key = keys[0]; + + if (optim_inputs_shape_.count(key) == 0) { + optim_inputs_shape_[key] = inputs_shape; + } + for (size_t i = 0; i < keys.size(); i++) { + auto shape = std::make_shared>(); + inputs_shape->push_back(shape); + + int len = lengths[i]; + for (int j = 0; j < len; j++) { + shape->push_back(values[val_idx++]); + } + } + if (weight_key_to_optims_.count(key) > 0) { + const std::string &optim_name = weight_key_to_optims_[key]; + if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) { + if (optim_name == kSparseAdam) { + std::shared_ptr optimizer = + std::make_shared(rank_id_, pserver_num_); + optimizer->InitKernel(optim_inputs_shape_[key]); + optimizers_[optim_name] = optimizer; + } else if (optim_name == kApplyMomentum) { + std::shared_ptr optimizer = + std::make_shared(rank_id_, pserver_num_); + optimizer->InitKernel(optim_inputs_shape_[key]); + optimizers_[optim_name] = optimizer; + } else if (optim_name == kSparseFtrl) { + std::shared_ptr optimizer = + std::make_shared(rank_id_, pserver_num_); + optimizer->InitKernel(optim_inputs_shape_[key]); + optimizers_[optim_name] = optimizer; + } + } + } +} + +template +void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) { + if (weights_.count(key) == 0) { + weights_[key] = weight; + } +} + +template +void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) { + if (grads_.count(key) == 0) { + grads_[key] = grad; + grads_accum_counter_[key] = 0; + } +} + +template +void ParameterServer::InitEmbeddingTable( + const Key &key, const std::shared_ptr>>> &shapes) { + // Init embedding lookup kernel + std::shared_ptr lookup = std::make_shared(rank_id_, pserver_num_); + lookup->InitKernel(shapes); + embedding_lookup_ops_[key] = lookup; + + // Init embedding weight + const std::vector &input_shapes = lookup->input_sizes(); + size_t total_dims = 1; + for (auto shape : input_shapes) { + total_dims *= shape; + } + WeightPtr embedding = std::make_shared(total_dims, 0.01); + weights_[key] = embedding; + + grads_accum_counter_[key] = 0; +} + +template +void ParameterServer::UpdateWeights() { + while (true) { + std::unique_lock lock(mutex_); + apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); }); + + for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { + Key key = iter->first; + WeightPtr weight_ptr = iter->second; + + std::shared_ptr optimizer = nullptr; + if (weight_key_to_optims_.count(key) > 0) { + const std::string &optim_name = weight_key_to_optims_[key]; + optimizer = optimizers_[optim_name]; + } + MS_EXCEPTION_IF_NULL(optimizer); + + std::shared_ptr optim_info = optim_infos_[key]; + if (optim_info == nullptr) { + continue; + } + const WeightPtr &weight = weights_[key]; + optim_info->UpdateWeight(weight); + const std::vector &inputs = optim_info->inputs(); + const std::vector &workspaces = optim_info->workspaces(); + const std::vector &outputs = optim_info->outputs(); + + optimizer->Execute(inputs, workspaces, outputs); + optim_info->Reset(); + } + ResetGradAccumCount(); + accum_grads_cv_.notify_all(); + } +} + +template +void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) { + std::unique_lock lock(mutex_); + accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); }); + + const Key &key = keys[0]; + std::shared_ptr optim_info = optim_infos_[key]; + + // Create or update the optimizer info + if (optim_info == nullptr) { + const std::shared_ptr &builder = optim_info_builders_[weight_key_to_optims_[key]]; + std::shared_ptr pserver_kernel = optimizers_[weight_key_to_optims_[key]]; + if (pserver_kernel == nullptr) { + MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; + } + MS_EXCEPTION_IF_NULL(pserver_kernel); + OptimizerInfo *optim = + builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_); + optim_info.reset(optim); + optim_infos_[key] = optim_info; + } else { + optim_info->Update(values, lengths); + } + MS_EXCEPTION_IF_NULL(optim_info); + + optim_info->Accumulate(values, lengths); + + grads_accum_counter_[key] += 1; + if (grads_accum_counter_[key] == worker_num_) { + grad_accum_count_++; + } + if (ReadyForUpdateWeights()) { + apply_grads_cv_.notify_one(); + } +} + +template +WeightPtr ParameterServer::weight(const Key &key) { + std::unique_lock lock(mutex_); + + if (weights_.count(key) == 0) { + MS_LOG(ERROR) << "Invalid weight key " << key; + return nullptr; + } + WeightPtr weight_ptr = weights_[key]; + WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray>(weight_ptr->size(), 0); + copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size()); + return copy_weight_ptr; +} + +template +void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res) { + std::unique_lock lock(mutex_); + if (weights_.count(key) == 0) { + MS_LOG(ERROR) << "Invalid embedding table key " << key; + return; + } + if (embedding_lookup_ops_.count(key) == 0) { + MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; + return; + } + WeightPtr table_ptr = weights_[key]; + std::shared_ptr table_lookup_op = embedding_lookup_ops_[key]; + + // Update shapes of lookup operator + std::shared_ptr>>> shapes = + std::make_shared>>>(); + std::shared_ptr> indices_shape = std::make_shared>(); + indices_shape->emplace_back(lookup_ids.size()); + shapes->push_back(indices_shape); + table_lookup_op->ReInit(shapes); + + const std::vector output_shapes = table_lookup_op->output_sizes(); + std::vector inputs; + AddressPtr embedding_table = std::make_shared(); + AddressPtr indices = std::make_shared(); + inputs.push_back(embedding_table); + inputs.push_back(indices); + embedding_table->addr = table_ptr->data(); + embedding_table->size = table_ptr->size() * sizeof(T); + indices->addr = lookup_ids.data(); + indices->size = lookup_ids.size() * sizeof(T); + + std::vector workspaces; + std::vector outputs; + AddressPtr output = std::make_shared(); + std::shared_ptr addr = std::make_shared(output_shapes[0] / sizeof(T), 0); + + output->addr = addr->data(); + output->size = output_shapes[0]; + outputs.push_back(output); + + table_lookup_op->Execute(inputs, workspaces, outputs); + res->vals = *addr; + res->lens.push_back(res.vals.size()); +} + +template +int ParameterServer::SumOfShapes(const std::vector &shapes) const { + int sum = 1; + for (auto shape : shapes) { + sum *= shape; + } + return sum; +} + +template +size_t ParameterServer::PreComputeCapacity(const Keys &keys, const Lengths &lens) { + size_t capacity = 0; + for (size_t i = 0; i < keys.size(); i++) { + Key key = keys[i]; + if (embedding_row_lens_.count(key) > 0) { + capacity += embedding_row_lens_[key] * lens[i]; + } else { + MS_LOG(ERROR) << "Invalid embedding lookup id " << key; + } + } + return capacity; +} + +template +inline bool ParameterServer::ReadyForUpdateWeights() { + return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); +} + +template +inline bool ParameterServer::ReadyForAccumGrads() { + return grad_accum_count_ < weights_.size(); +} + +template +inline void ParameterServer::ResetGradAccumCount() { + grad_accum_count_ = 0; + for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) { + grads_accum_counter_[iter->first] = 0; + } +} + +template +void ParameterServer::Run(const FuncGraphPtr &func_graph) { + ::ps::Start(0); + if (!::ps::IsServer()) { + std::cout << "This is not ther Server" << std::endl; + return; + } + Init(func_graph); + thread_->join(); +} +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc new file mode 100755 index 0000000000..274b7259b0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ps/scheduler.h" +#include +#include "ps/ps.h" + +namespace mindspore { +namespace parallel { +namespace ps { +void Scheduler::Run() { + ::ps::Start(0); + while (true) { + sleep(1); + } +} +} // namespace ps +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ps/scheduler.h b/mindspore/ccsrc/frontend/parallel/ps/scheduler.h new file mode 100755 index 0000000000..e656bcfd22 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/scheduler.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_SCHEDULER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_SCHEDULER_H_ +namespace mindspore { +namespace parallel { +namespace ps { +class Scheduler { + public: + static Scheduler &GetInstance() { + static Scheduler instance; + return instance; + } + + void Run(); + + private: + Scheduler() = default; + ~Scheduler() = default; + Scheduler(const Scheduler &) = delete; + Scheduler &operator=(const Scheduler &) = delete; +}; +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_SCHEDULER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.cc b/mindspore/ccsrc/frontend/parallel/ps/util.cc new file mode 100644 index 0000000000..fc63e88901 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/util.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ps/util.h" +#include +#include "frontend/parallel/ps/common.h" +#include "common/utils.h" + +namespace mindspore { +namespace parallel { +namespace ps { +std::unordered_map Util::optimizer_to_ids{ + {kApplyMomentum, 0}, + {kSparseAdam, 1}, + {kSparseFtrl, 2}, +}; + +std::unordered_map Util::id_to_optimizers{ + {0, kApplyMomentum}, + {1, kSparseAdam}, + {2, kSparseFtrl}, +}; +bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); } + +bool Util::IsRoleOfWorker() { + auto role = common::GetEnv(kEnvRole); + if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) { + return true; + } else { + return false; + } +} + +bool Util::IsRoleOfPServer() { + auto role = common::GetEnv(kEnvRole); + if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) { + return true; + } else { + return false; + } +} + +bool Util::IsRoleOfScheduler() { + auto role = common::GetEnv(kEnvRole); + if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) { + return true; + } else { + return false; + } +} + +void Util::SetInternalEnvVar() { + if (IsParamServerMode()) { + auto comm_type = common::GetEnv(kEnvCommType); + if (comm_type.size() > 0) { + (void)common::SetEnv(kDmlcCommType, comm_type.c_str()); + } + auto interface = common::GetEnv(kEnvInterface); + if (interface.size() > 0) { + (void)common::SetEnv(kDmlcInterface, interface.c_str()); + } + auto server_num = common::GetEnv(kEnvPServerNum); + if (server_num.size() > 0) { + (void)common::SetEnv(kDmlcPServerNum, server_num.c_str()); + } + auto worker_num = common::GetEnv(kEnvWorkerNum); + if (worker_num.size() > 0) { + (void)common::SetEnv(kDmlcWorkerNum, worker_num.c_str()); + } + if (IsRoleOfScheduler()) { + (void)common::SetEnv(kDmlcRole, kRoleOfScheduler); + } else if (IsRoleOfPServer()) { + (void)common::SetEnv(kDmlcRole, kRoleOfPServer); + } else if (IsRoleOfWorker()) { + (void)common::SetEnv(kDmlcRole, kRoleOfWorker); + } + auto scheduler_host = common::GetEnv(kEnvSchedulerHost); + if (scheduler_host.size() > 0) { + (void)common::SetEnv(kDmlcSchedulerHost, scheduler_host.c_str()); + } + auto scheduler_port = common::GetEnv(kEnvSchedulerPort); + if (scheduler_port.size() > 0) { + (void)common::SetEnv(kDmlcSchedulerPort, scheduler_port.c_str()); + } + } +} + +int Util::optimizer_id(std::string name) { + if (optimizer_to_ids.count(name) > 0) { + return optimizer_to_ids[name]; + } + return -1; +} + +std::string Util::optimizer_name(int id) { + if (id_to_optimizers.count(id) > 0) { + return id_to_optimizers[id]; + } + return ""; +} + +bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; } + +int Util::LocalShard(int first_dim, int rank_id, int server_num) { + int shard_size = std::round((static_cast(first_dim)) / server_num); + int remain_size = first_dim % server_num; + if (remain_size == 0 || rank_id < server_num - 1) { + return shard_size; + } else { + return first_dim - (shard_size * (server_num - 1)); + } +} +} // namespace ps +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.h b/mindspore/ccsrc/frontend/parallel/ps/util.h new file mode 100644 index 0000000000..8947ad36de --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/util.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ + +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace parallel { +namespace ps { +class Util { + public: + static bool IsParamServerMode(); + static bool IsRoleOfWorker(); + static bool IsRoleOfPServer(); + static bool IsRoleOfScheduler(); + static void SetInternalEnvVar(); + static int optimizer_id(std::string name); + static std::string optimizer_name(int id); + static bool is_optimizer(std::string name); + static int LocalShard(int first_dim, int rank_id, int server_num); + + private: + static std::unordered_map optimizer_to_ids; + static std::unordered_map id_to_optimizers; +}; +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h new file mode 100644 index 0000000000..9ecbc28fc5 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -0,0 +1,259 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ + +#include +#include +#include +#include +#include +#include "ps/ps.h" +#include "utils/log_adapter.h" +#include "frontend/parallel/ps/util.h" +#include "frontend/parallel/ps/common.h" +#include "frontend/parallel/ps/worker_proxy.h" + +namespace mindspore { +namespace parallel { +namespace ps { +template +class Worker { + public: + static Worker &GetInstance() { + static Worker instance; + return instance; + } + + void Run(); + void Push(const std::vector &keys, std::vector addrs, const std::vector &sizes); + void Pull(const size_t key, void *dev_addr, const size_t size); + size_t SetParamKey(const std::string ¶m_name); + void SetKeyOptimId(size_t key, const std::string &optimizer_name); + void SetOptimInputShapes(size_t key, const std::vector &shape); + void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); + void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const std::vector &sizes); + void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size); + void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd); + + private: + Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {} + ~Worker() { ::ps::Finalize(0, true); } + Worker(const Worker &) = delete; + Worker &operator=(const Worker &) = delete; + + bool IsKeyInit(const size_t key); + size_t GetParamKey(const std::string ¶m_name); + void InitPSOptimId(const size_t param_key); + void InitPSOptimInputShapes(const size_t key); + void InitPSParamData(const std::vector &keys, void *origin_addr, size_t size); + static void EmbeddingLookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &ranges, + std::vector>> *sliced) {} + + std::shared_ptr> kv_worker_; + bool running_; + size_t key_cnt_; + std::map param_to_key_; + std::map init_keys_; + std::map key_to_optimId_; + std::map>> key_to_optim_shapes_; +}; + +template +void Worker::Run() { + if (running_) { + MS_LOG(INFO) << "'Worker is already running."; + return; + } + + ::ps::Start(0); + if (!::ps::IsWorker()) { + MS_LOG(EXCEPTION) << "The role is not worker."; + } + kv_worker_ = std::make_shared>(0, 0, 1); + running_ = true; +} + +template +void Worker::Push(const std::vector &keys, std::vector addrs, const std::vector &sizes) { + size_t total_size = 0; + for (auto size : sizes) { + total_size += size; + } + ::ps::SArray total_buffer(total_size, 0); + size_t offset = 0; + for (size_t i = 0; i < sizes.size(); i++) { + memcpy(total_buffer.data() + offset / sizeof(T), addrs[i], sizes[i] * sizeof(T)); + offset += sizes[i] * sizeof(T); + } + kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes)); +} + +template +void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { + ::ps::SArray variables(size / sizeof(T), 0); + kv_worker_->Wait(kv_worker_->ZPull({key}, &variables)); + memcpy(dev_addr, variables.data(), size); +} + +template +void Worker::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd) { + kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, &lookup_result, cmd); +} + +template +void Worker::InitPSParamData(const std::vector &keys, void *origin_addr, size_t size) { + ::ps::SArray addr(reinterpret_cast(origin_addr), size / sizeof(T)); + ::ps::SArray<::ps::Key> key(keys); + ::ps::SArray lens; + lens.push_back(addr.size()); + kv_worker_->Wait(kv_worker_->ZPush(key, addr, lens, kInitWeightsCmd)); + init_keys_[key[0]] = true; +} + +template +void Worker::SetOptimInputShapes(size_t key, const std::vector &shape) { + if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) { + key_to_optim_shapes_[key] = {shape}; + } else { + key_to_optim_shapes_[key].push_back(shape); + } +} + +template +void Worker::InitPSOptimInputShapes(const size_t key) { + ::ps::SArray<::ps::Key> keys; + ::ps::SArray shape_len; + ::ps::SArray all_shape; + std::vector> shapes = key_to_optim_shapes_[key]; + for (auto shape : shapes) { + keys.push_back(key); + if (shape.size() == 0) { + shape_len.push_back(1); + all_shape.push_back(1); + } else { + shape_len.push_back(SizeToInt(shape.size())); + for (auto dim : shape) { + all_shape.push_back(static_cast(dim)); + } + } + } + MS_LOG(ERROR) << "keys:" << keys; + MS_LOG(ERROR) << "shape_len:" << shape_len; + MS_LOG(ERROR) << "all_shape:" << all_shape; + if (!init_keys_[key]) { + init_keys_[key] = true; + } + kv_worker_->PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd); +} + +template +bool Worker::IsKeyInit(const size_t key) { + if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) { + return false; + } + return true; +} + +template +size_t Worker::SetParamKey(const std::string ¶m_name) { + size_t key = UINT64_MAX; + if (param_to_key_.count(param_name)) { + key = param_to_key_[param_name]; + MS_LOG(INFO) << param_name << " key is already set: key value is " << key; + } else { + key = key_cnt_++; + param_to_key_[param_name] = key; + MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name; + } + return key; +} + +template +size_t Worker::GetParamKey(const std::string ¶m_name) { + size_t key = kInvalidKey; + if (param_to_key_.find(param_name) != param_to_key_.end()) { + key = param_to_key_[param_name]; + MS_LOG(ERROR) << "Get key of parameter " << param_name << " key is " << key; + } + return key; +} + +template +void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) { + key_to_optimId_[key] = Util::optimizer_id(optimizer_name); +} + +template +void Worker::InitPSOptimId(const size_t param_key) { + if (key_to_optimId_.count(param_key) == 0) { + MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key; + } + int optim_id = key_to_optimId_[param_key]; + + ::ps::SArray<::ps::Key> keys = {param_key}; + ::ps::SArray optim_id_vals = {static_cast(optim_id)}; + ::ps::SArray optim_id_lens = {optim_id_vals.size()}; + kv_worker_->PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd); +} + +template +void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, + const std::vector &sizes) { + bool has_init = IsKeyInit(keys[0]); + if (has_init) { + MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; + return; + } + ::ps::SArray shapes_val; + for (auto dim : shapes) { + shapes_val.push_back(static_cast(dim)); + } + kv_worker_->Wait(kv_worker_->InitEmbeddingTable(::ps::SArray<::ps::Key>(keys), shapes_val, ::ps::SArray(sizes))); +} + +template +// Initialize parameters and optimizer kernels of Parameter Server. +void Worker::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size) { + size_t param_key = GetParamKey(param_name); + if (param_key == kInvalidKey) { + MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned."; + return; + } + bool init = IsKeyInit(param_key); + if (!init) { + MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name; + // No need to push embedding table data to Parameter Server. + if (param_name.find("embedding_table") == std::string::npos && param_name.find("wide_w") == std::string::npos) { + InitPSParamData({param_key}, param_data, param_size); + } + InitPSOptimId(param_key); + InitPSOptimInputShapes(param_key); + } +} + +template +void Worker::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { + kv_worker_->AddEmbeddingTable(key, row_count); +} + +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h new file mode 100644 index 0000000000..a0f58d39a4 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h @@ -0,0 +1,311 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ + +#include +#include +#include +#include +#include +#include "ps/ps.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace parallel { +namespace ps { +template +class WorkerProxy : public ::ps::KVWorker { + public: + using Worker = ::ps::KVWorker; + using Callback = std::function; + using SlicedKVs = std::vector>>; + using Slicer = + std::function &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>; + using ::ps::SimpleApp::obj_; + explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) { + using _1 = std::placeholders::_1; + using _2 = std::placeholders::_2; + using _3 = std::placeholders::_3; + lookup_customer_ = std::unique_ptr<::ps::Customer>( + new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy::ProcessLookupResult, this, _1))); + lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3); + init_embedding_slicer_ = std::bind(&WorkerProxy::EmbeddingTableInitSlicer, this, _1, _2, _3); + push_slicer_ = std::bind(&WorkerProxy::PushSlicer, this, _1, _2, _3); + broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3); + } + ~WorkerProxy() override = default; + + void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); + void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *outs, int cmd = 0, const Callback &cb = nullptr, + int priority = 0); + int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens = {}, const Callback &cb = nullptr, int priority = 0); + void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, + int cmd = 0, int priority = 0); + + private: + template + int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *vals, int cmd, + const Callback &cb); + void LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void ProcessLookupResult(const ::ps::Message &msg); + void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, + const Slicer &slicer); + + std::unique_ptr<::ps::Customer> lookup_customer_; + std::unordered_map<::ps::Key, std::shared_ptr>> embedding_table_ranges_; + std::unordered_map>> lookup_results_; + std::mutex mutex_; + Slicer lookup_slicer_; + Slicer init_embedding_slicer_; + Slicer push_slicer_; + Slicer broadcast_slicer_; + std::unordered_map lookup_callbacks_; +}; + +template +void WorkerProxy::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { + uint64_t begin = 0; + uint64_t end = 0; + int server_num = ::ps::NumServers(); + for (int i = 0; i < server_num; i++) { + int local_row_cnt = Util::LocalShard(row_count, i, server_num); + if (i == 0) { + end = local_row_cnt - 1; + } else { + begin = end + 1; + end += local_row_cnt; + } + ::ps::Range range(begin, end); + if (embedding_table_ranges_.count(key) == 0) { + embedding_table_ranges_[key] = std::make_shared>(); + } + embedding_table_ranges_[key]->push_back(range); + } +} + +template +void WorkerProxy::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *outs, int cmd, const Callback &cb, + int priority) { + int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = lookup_ids; + kvs.lens = lens; + kvs.priority = priority; + Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_); + lookup_customer_->WaitRequest(ts); +} + +template +int WorkerProxy::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens, const Callback &cb, int priority) { + int ts = obj_->NewRequest(::ps::kServerGroup); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = vals; + kvs.lens = lens; + kvs.priority = priority; + Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_); + return ts; +} + +template +void WorkerProxy::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens, int cmd, int priority) { + int ts = obj_->NewRequest(::ps::kServerGroup); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = vals; + kvs.lens = lens; + kvs.priority = priority; + Send(obj_, ts, true, false, cmd, kvs, push_slicer_); + obj_->WaitRequest(ts); +} + +template +template +int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + C *lookup_result, int cmd, const Callback &cb) { + int ts = lookup_customer_->NewRequest(::ps::kServerGroup); + const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable { + mutex_.lock(); + auto &kvs = lookup_results_[ts]; + mutex_.unlock(); + + size_t total_len = 0; + const auto &s = kvs[0]; + for (size_t i = 0; i < s.lens.size(); i++) { + total_len += s.lens[i]; + } + lookup_result->resize(total_len, 0); + T *result_addr = lookup_result->data(); + + for (const auto &s : kvs) { + size_t offset = 0; + for (size_t i = 0; i < s.vals.size(); i++) { + result_addr[offset++] += s.vals[i]; + } + } + + mutex_.lock(); + lookup_results_.erase(ts); + mutex_.unlock(); + if (cb) cb(); + }; + lookup_callbacks_[ts] = callback; + return ts; +} + +template +void WorkerProxy::LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + int *data = send.lens.data(); + size_t size = send.lens.size(); + std::vector lookup_ids(data, data + size); + std::sort(lookup_ids.begin(), lookup_ids.end()); + + const Key &key = send.keys[0]; + const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); + sliced->resize(ranges.size()); + + size_t index = 0; + for (size_t i = 0; i < ranges.size(); i++) { + const ::ps::Range &range = ranges[i]; + const auto &begin = range.begin(); + const auto &end = range.end(); + auto &kvs = sliced->at(i).second; + + auto lookup_id = static_cast(lookup_ids[index]); + while (lookup_id >= begin && lookup_id <= end) { + kvs.vals.push_back(lookup_id); + if (++index >= lookup_ids.size()) { + break; + } + lookup_id = static_cast(lookup_ids[index]); + } + kvs.keys.push_back(key); + kvs.lens.push_back(kvs.vals.size()); + + if (kvs.vals.size() == 0) { + sliced->at(i).first = false; + } else { + sliced->at(i).first = true; + } + } +} + +template +void WorkerProxy::EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + const Key &key = send.keys[0]; + const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); + sliced->resize(ranges.size()); + for (size_t i = 0; i < ranges.size(); i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + auto server_num = ::ps::Postoffice::Get()->num_servers(); + sliced->resize(server_num); + for (int i = 0; i < server_num; i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + auto server_num = ::ps::Postoffice::Get()->num_servers(); + sliced->resize(server_num); + for (int i = 0; i < server_num; i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::ProcessLookupResult(const ::ps::Message &msg) { + int ts = msg.meta.timestamp; + if (msg.meta.pull) { + CHECK_GE(msg.data.size(), (size_t)2); + ::ps::KVPairs kvs; + kvs.keys = msg.data[0]; + kvs.vals = msg.data[1]; + if (msg.data.size() > (size_t)2) { + kvs.lens = msg.data[2]; + } + mutex_.lock(); + lookup_results_[ts].push_back(kvs); + mutex_.unlock(); + } + if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) { + const auto &cb = lookup_callbacks_[ts]; + cb(); + lookup_callbacks_.erase(ts); + } +} + +template +void WorkerProxy::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, + const ::ps::KVPairs &kvs, const Slicer &slicer) { + SlicedKVs sliced; + slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); + + for (size_t i = 0; i < sliced.size(); i++) { + const auto &s = sliced[i]; + if (!s.first) continue; + ::ps::Message msg; + msg.meta.app_id = customer->app_id(); + msg.meta.customer_id = customer->customer_id(); + msg.meta.request = true; + msg.meta.push = push; + msg.meta.pull = pull; + msg.meta.head = cmd; + msg.meta.timestamp = timestamp; + msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i); + msg.meta.priority = kvs.priority; + const auto &kvs = s.second; + if (kvs.keys.size()) { + msg.AddData(kvs.keys); + msg.AddData(kvs.vals); + if (kvs.lens.size()) { + msg.AddData(kvs.lens); + } + } + ::ps::Postoffice::Get()->van()->Send(msg); + } +} +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ diff --git a/mindspore/ccsrc/parallel/status.h b/mindspore/ccsrc/frontend/parallel/status.h similarity index 100% rename from mindspore/ccsrc/parallel/status.h rename to mindspore/ccsrc/frontend/parallel/status.h diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc new file mode 100644 index 0000000000..8d54eb454a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -0,0 +1,1187 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/step_auto_parallel.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/param_value.h" +#include "ir/tensor.h" +#include "frontend/optimizer/opt.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/ops_info/tmp_identity_info.h" +#include "frontend/parallel/ops_info/reshape_info.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/pipeline.h" + +namespace mindspore { +namespace parallel { +bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); + // assume no change to graph + bool changes = false; + // control whether use model_parallel mode + if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) || + root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) { + return changes; + } + // check whether strategy_search_mode is valid + std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode(); + if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) { + // Setting searching mode: dynanic programming as default. + strategy_search_mode = DYNAMIC_PROGRAMMING; + MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default"; + } + + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); + + if (MsContext::GetInstance()->save_graphs_flag()) { + draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root); + } + MS_LOG(INFO) << "Now entering step auto parallel"; + TOTAL_OPS = 0; + AnfNodePtr ret = root->get_return(); + std::vector all_nodes = DeepScopedGraphSearch(ret); + + if (ParallelInit() != SUCCESS) { + MS_LOG(EXCEPTION) << "Parallel init failed"; + } + + // mark the forward cnodes, parallel only care these nodes + MarkForwardCNode(root); + + if (FindCommunicationOp(all_nodes)) { + MS_LOG(EXCEPTION) << "The graph contain communication op"; + } + + // search parallelization strategy + if (strategy_search_mode == DYNAMIC_PROGRAMMING) { + if (ParallelStrategySearch(all_nodes, root) != SUCCESS) { + MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode"; + } + } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) { + if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) { + MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode"; + } + } else { + MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected"; + } + + (void)gettimeofday(&end_time, nullptr); + uint64_t time = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + time += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us"; + + root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true); + return changes; +} + +// Given the node, return whether each input is a parameter or a output of a operator. +// The returned boolean vector should be the same order of the inputs, thus its implementation +// is closely consistent with ExtractShape() in step_parallel.cc +std::vector ExtractInputParameterByNode(const CNodePtr &node) { + std::vector is_parameter; + std::vector node_inputs{node->inputs()}; + for (size_t i = 1; i < node_inputs.size(); ++i) { + auto input = node_inputs[i]; + + if (input->isa()) { + auto input_parameter = input->cast(); + if (input_parameter->has_default()) { + bool requires_grad = input_parameter->default_param()->requires_grad(); + is_parameter.push_back(requires_grad); + } else { + is_parameter.push_back(false); + } + } else if (input->isa() || IsValueNode(input) || IsValueNode(input)) { + is_parameter.push_back(false); + } + } + return is_parameter; +} + +// Given the type, return the number of bytes to represent this type +size_t GetLengthOfDataType(const TypePtr &type) { + switch (type->type_id()) { + case kNumberTypeBool: + return sizeof(bool); + case kNumberTypeInt8: + return sizeof(int8_t); + case kNumberTypeInt16: + return sizeof(int16_t); + case kNumberTypeInt32: + return sizeof(int32_t); + case kNumberTypeInt64: + return sizeof(int64_t); + case kNumberTypeUInt8: + return sizeof(uint8_t); + case kNumberTypeUInt16: + return sizeof(uint16_t); + case kNumberTypeUInt32: + return sizeof(uint32_t); + case kNumberTypeUInt64: + return sizeof(uint64_t); + case kNumberTypeFloat16: + return sizeof(float) / 2; + case kNumberTypeFloat32: + return sizeof(float); + case kNumberTypeFloat64: + return sizeof(double); + case kNumberTypeInt: + return sizeof(int); + case kNumberTypeUInt: + return sizeof(unsigned int); + case kNumberTypeFloat: + return sizeof(float); + default: + MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name(); + } +} + +size_t GetInputsTypeLen(const AnfNodePtr &input) { + MS_EXCEPTION_IF_NULL(input); + if (!input->isa() && !input->isa() && !IsValueNode(input)) { + MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor"; + } + + size_t input_type_len = 0; + auto type = input->Type(); + MS_EXCEPTION_IF_NULL(type); + if (type->isa()) { + auto input_element_type = type->cast()->element(); + input_type_len = GetLengthOfDataType(input_element_type); + } else { + MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); + } + return input_type_len; +} + +std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + std::vector inputs_type_len; + std::vector node_inputs{node->inputs()}; + + // extract input element length + for (auto &input : node_inputs) { + if (IsValueNode(input)) { + auto func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector parameters = FindParameterByRefKeyNode(input, func_graph); + if (parameters.size() != 1) { + MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; + } + inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); + } else if (input->isa() || input->isa() || IsValueNode(input)) { + // extract input shape from parameter and apply node + inputs_type_len.push_back(GetInputsTypeLen(input)); + } + } + return inputs_type_len; +} + +std::vector ExtractOutputTypeByNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + std::vector outputs_type; + // extract output element type + auto primary_output_type = node->Type(); + MS_EXCEPTION_IF_NULL(primary_output_type); + if (primary_output_type->isa()) { + // in this case, the output is a tuple + auto tuple_output_type = primary_output_type->cast(); + auto elements = tuple_output_type->elements(); + for (auto &ele : elements) { + if (ele->isa()) { + auto ele_element_type = ele->cast()->element(); + outputs_type.push_back(ele_element_type); + } else { + MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); + } + } + } else { + // in this case, the output is a single tensor + if (primary_output_type->isa()) { + auto element_type = primary_output_type->cast()->element(); + outputs_type.push_back(element_type); + } else { + MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); + } + } + return outputs_type; +} + +bool IsElementWiseOperator(const std::string &op_name) { + static const std::set elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, + SQRT, CAST, POW, EXP, LOG, COS, + ACOS, LOGICALNOT, NEG, SQUARE, SIGMOID}; + auto iter = elementwise_op.find(op_name); + return (iter != elementwise_op.end()); +} + +bool IsSplittableOperator(const std::string &op_name) { + // clang-format off + static const std::set splittable_op = + {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, + FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, + REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, + MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, + LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, + STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, + SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS}; + // clang-format on + + auto iter = splittable_op.find(op_name); + return (iter != splittable_op.end()); +} + +bool IsAutoParallelCareNode(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + ValueNodePtr prim_node = cnode->input(0)->cast(); + if (prim_node == nullptr) { + return false; + } + PrimitivePtr prim = GetValueNode(prim_node); + if (prim == nullptr) { + return false; + } + bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name()); + if (bool_result) { + MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name(); + } else if (prim->name() == CAST) { + if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) { + // Do not care CASTs from optimizer + return false; + } + return true; + } + return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name()); +} + +OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { + MS_EXCEPTION_IF_NULL(prim); + MS_EXCEPTION_IF_NULL(cnode); + auto attrs = prim->attrs(); + std::vector shape_list = ExtractShape(cnode); + if (shape_list.empty()) { + MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape"; + } + // Create an OperatorInfo instance + OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list); + MS_EXCEPTION_IF_NULL(operator_info); + // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not) + std::vector parameter_info = ExtractInputParameterByNode(cnode); + if (operator_info->set_is_parameter(parameter_info) != SUCCESS) { + MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name(); + return nullptr; + } + // Set the data type for inputs and outputs of this OperatorInfo + auto inputs_type_length = ExtractInputTypeLengthByNode(cnode); + auto outputs_type = ExtractOutputTypeByNode(cnode); + std::vector outputs_type_length; + outputs_type_length.reserve(outputs_type.size()); + std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length), + GetLengthOfDataType); + if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) { + MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name(); + return nullptr; + } + if (operator_info->set_outputs_type(outputs_type) != SUCCESS) { + MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name(); + return nullptr; + } + // When the 'inputs' contains numerical values for some operators, these values should be extracted from + // ANF graph + auto &inputs = cnode->inputs(); + std::vector input_value; + for (size_t index = 1; index < inputs.size(); ++index) { + if (inputs[index]->isa()) { + input_value.push_back(GetValueNode(inputs[index])); + } else { + input_value.emplace_back(nullptr); + } + } + operator_info->set_input_value(input_value); + operator_info->set_outputs_dtype(cnode->Type()); + operator_info->set_cnode(cnode); + // key of strategy map + std::string strategy_key_name = NodeParameterName(cnode); + bool load_strategy_from_ckpt = + StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); + // If no strategy has been configured for this operator, then candidate strategies are generated for + // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy. + // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint . + if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) { + // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for + // BatchParallelInfo operator + operator_info->ComputeBatchSplitFlagList(); + if (operator_info->GenerateStrategies(0) != SUCCESS) { + MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed."; + return nullptr; + } + } else { + // In this case, the configured strategy should be extracted to help setting cost + StrategyPtr strategyPtr; + if (load_strategy_from_ckpt) { + strategyPtr = (*stra_map)[strategy_key_name]; + } else { + strategyPtr = parallel::ExtractStrategy(attrs); + } + if (strategyPtr != nullptr) { + if (prim->name() == RESHAPE) { + MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; + } + // Set cost for this configured strategy + if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { + MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; + } else if (FULLY_USE_DEVICES) { + // If configured to fully use devices, then checking for the user-specified strategy + int32_t used_devices = operator_info->used_devices(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); + // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel + if (used_devices == 1) { + return operator_info; + } + // 'used_devices == -1' means that 'used_devices_' is not set + if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) { + MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, " + << "but the specified strategy uses device: " << used_devices + << ", total devices: " << total_device_num; + } + } + } + } + return operator_info; +} + +// Using CNode's UniqueIds to construct nodes +Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &) { + MS_LOG(INFO) << "Constructing nodes for cost graph begins."; + entire_costgraph = std::make_shared(); + entire_costgraph->SetDeviceMemoryAndCostParameter(); + // The map from CNode's UniqueId to its operatorInfo + std::map from_cnode_to_info; + // extract strategy from checkpoint for multi-train + StrategyMap stra_map; + if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { + if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { + MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; + } + } + // Step 1 + for (auto &node : all_nodes) { + // NOTE: we only care about splittable Primitive operators + auto cnode = node->cast(); + bool bool_result = (cnode == nullptr) || (!IsValueNode(cnode->input(0))); + if (bool_result) { + continue; + } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + if (!IsAutoParallelCareNode(cnode)) { + // Needed by rec_parser + if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) { + auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node); + if (prev_cnode != nullptr) { + entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId())); + } + } + continue; + } + PrimitivePtr prim = GetValueNode(prim_anf_node); + MS_EXCEPTION_IF_NULL(prim); + + auto search_cnode = from_cnode_to_info.find(cnode->UniqueId()); + if (search_cnode == from_cnode_to_info.end()) { + auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); + if (operator_info == nullptr) { + return FAILED; + } + // Needed by rec_parser + operator_info->set_type(prim->name()); + std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); + + entire_costgraph->AddOperator(operator_info); + (void)cnode->set_operator_info(operator_info); + MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); + (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); + // Needed by rec_parser + entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); + } else { + // Two CNODEs' UniqueIds should not be equal + MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name(); + } + } + + MS_LOG(INFO) << "Constructing nodes for cost graph ends."; + return SUCCESS; +} + +// Using CNode's UniqueIdThroughCopys to construct nodes +Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &) { + MS_LOG(INFO) << "Constructing nodes for cost graph begins."; + entire_costgraph = std::make_shared(); + entire_costgraph->SetDeviceMemoryAndCostParameter(); + // The map from CNode's UniqueIdThroughCopy to its operatorInfo + std::map from_cnode_to_info; + // extract strategy from checkpoint for multi-train + StrategyMap stra_map; + if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { + if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { + MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; + } + } + for (auto &node : all_nodes) { + // NOTE: we only care about splittable Primitive operators + auto cnode = node->cast(); + bool bool_result = (cnode == nullptr) || (!IsValueNode(cnode->input(0))); + if (bool_result) { + continue; + } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + if (!IsAutoParallelCareNode(cnode)) { + // Needed by rec_parser + if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) { + auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node); + if (prev_cnode != nullptr) { + entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId())); + } + } + continue; + } + PrimitivePtr prim = GetValueNode(prim_anf_node); + + // Find the operatorInfo if it exists + auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy()); + if (search_cnode == from_cnode_to_info.end()) { + // In this case, the corresponding OperatorInfo is not created, create the new one. + auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); + if (operator_info == nullptr) { + return FAILED; + } + // Needed by rec_parser + operator_info->set_type(prim->name()); + std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); + + entire_costgraph->AddOperator(operator_info); + (void)cnode->set_operator_info(operator_info); + MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); + (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); + // Needed by rec_parser + entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); + } else { + auto current_op_ptr = search_cnode->second; + if (current_op_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; + } else { + bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && + (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && + (current_op_ptr->name().find(prim->name()) == std::string::npos); + if (is_find_wrong) { + MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() + << " does not match the Prim: " << prim->name(); + } + (void)cnode->set_operator_info(current_op_ptr); + MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() + << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() + << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); + } + } + } + + MS_LOG(INFO) << "Constructing nodes for cost graph ends."; + return SUCCESS; +} + +void ConstructCostGraphEdges(const std::vector &all_nodes) { + // Step 2 + MS_LOG(INFO) << "Constructing edges for cost graph begins."; + for (auto &node : all_nodes) { + auto cnode = node->cast(); + bool bool_result_cnode = (cnode == nullptr) || !IsValueNode(cnode->input(0)); + if (bool_result_cnode) { + continue; + } + auto &inputs = cnode->inputs(); + ValueNodePtr prim_anf_node = inputs[0]->cast(); + if (!IsAutoParallelCareNode(cnode)) { + continue; + } + PrimitivePtr prim = GetValueNode(prim_anf_node); + size_t edge_count = 0; + + for (size_t i = 1; i < inputs.size(); ++i) { + auto prev_cnode = inputs[i]->cast(); + bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); + if (bool_result_prev_cnode) { + continue; + } + ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast(); + PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast(); + size_t output_index = 0; + + bool bool_result = + (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); + while (bool_result) { + if (IsAutoParallelCareNode(prev_cnode)) { + std::string edge_name = + prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name(); + // If the edge between these two operators already has been added, then the edge will not be added again. + if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { + break; + } + EdgePtr edge_ptr; + MS_LOG(INFO) << "Creating edge: " << edge_name; + + bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || + (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); + if (follow_strategy) { + // Redistribution in not allowed on the edge. + // Elementwise operators have the same strategy as their previous operators. + edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), + output_index, i - 1, false, true); + } else { + edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), + output_index, i - 1, false); + } + + // Init costs for this edge + if (edge_ptr->InitEdgeCost() != SUCCESS) { + MS_LOG(EXCEPTION) << "Edge cost initialization failed"; + } + cnode->operator_info()->AddPrevEdge(edge_ptr); + prev_cnode->operator_info()->AddSuccEdge(edge_ptr); + entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr); + MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and " + << cnode->operator_info()->name(); + edge_count++; + + break; + } else if (prev_prim->name() == TUPLE_GETITEM) { + // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before + // this 'tuple_getitem' + MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator."; + output_index = IntToSize(GetValue(GetValueNode(prev_cnode->input(2)))); + prev_cnode = prev_cnode->input(1)->cast(); + bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); + if (bool_result_tuple) { + break; + } + prev_prim_anf_node = prev_cnode->input(0)->cast(); + prev_prim = prev_prim_anf_node->value()->cast(); + if (!IsAutoParallelCareNode(prev_cnode)) { + MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name(); + } + MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, " + << "and creating an edge between the Operator before " + << "'tuple_getitem' and the Operator after 'tuple_getitem'."; + } else if (prev_prim->name() == DEPEND) { + // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before + // this 'depend' + MS_LOG(INFO) << "Jumping the 'depend' operator."; + prev_cnode = prev_cnode->input(1)->cast(); + bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); + if (bool_result_depend) { + break; + } + prev_prim_anf_node = prev_cnode->input(0)->cast(); + prev_prim = prev_prim_anf_node->value()->cast(); + MS_LOG(INFO) << "Jumped the 'depend' operator, " + << "and creating an edge between the Operator before " + << "'depend' and the Operator after 'depend'."; + } + bool_result = + (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); + } + } + MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); + } + + MS_LOG(INFO) << "Constructing edges for cost graph ends."; +} + +std::pair> CNodeWithRefKeys(const AnfNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector refkeys; + if (cnode->isa()) { + auto cnode_ptr = cnode->cast(); + auto inputs = cnode_ptr->inputs(); + for (auto &one_input : inputs) { + if (IsValueNode(one_input)) { + refkeys.push_back(one_input); + } + } + if (refkeys.size() >= 1) { + return std::make_pair(cnode, refkeys); + } + } + return {nullptr, refkeys}; +} + +void AugmentCostGraph(const std::vector &all_nodes) { + // Step 3 + for (auto &node : all_nodes) { + auto cnode_with_refkeys = CNodeWithRefKeys(node); + if ((!node->isa()) && (cnode_with_refkeys.first == nullptr)) { + continue; + } + std::string parameter_name; + AnfNodePtr target_parameter = nullptr; + AnfNodeIndexSet target_set; + + if (cnode_with_refkeys.first != nullptr) { + // Dealing with the RefKey case + auto refkeys = cnode_with_refkeys.second; + auto cnode = cnode_with_refkeys.first; + + auto cnode_ptr = cnode->cast(); + if (cnode_ptr == nullptr || !IsValueNode(cnode_ptr->input(0))) { + continue; + } + if (!IsAutoParallelCareNode(cnode_ptr)) { + continue; + } + + if (refkeys.size() > 1) { + MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys."; + } + MS_EXCEPTION_IF_NULL(cnode->func_graph()); + auto cnode_func_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager()); + + // Find the RefKey being used + auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]]; + for (auto &candidate : candidate_set_by_refkey) { + auto candidate_node = candidate.first; + auto c = candidate_node->cast(); + if (c == nullptr || !IsValueNode(c->input(0))) { + continue; + } + if (!IsAutoParallelCareNode(c)) { + continue; + } + target_set.add(candidate); + } + + // Find the corresponding Parameter being used + std::vector parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph); + if (parameters.size() != 1) { + MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; + } + parameter_name = parameters[0]->cast()->name(); + target_parameter = parameters[0]; + auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]]; + for (auto &candidate : candidate_set_by_para) { + auto candidate_node = candidate.first; + auto c = candidate_node->cast(); + if (c == nullptr || !IsValueNode(c->input(0))) { + continue; + } + if (!IsAutoParallelCareNode(c)) { + continue; + } + (void)target_set.insert(candidate); + } + } else if (node->isa()) { + // Dealing with the Parameter case + MS_EXCEPTION_IF_NULL(node->func_graph()); + MS_EXCEPTION_IF_NULL(node->func_graph()->manager()); + auto candidate_set = node->func_graph()->manager()->node_users()[node]; + for (auto &candidate : candidate_set) { + auto candidate_node = candidate.first; + auto c = candidate_node->cast(); + if (c == nullptr || !IsValueNode(c->input(0))) { + continue; + } + if (!IsAutoParallelCareNode(c)) { + continue; + } + (void)target_set.insert(candidate); + } + // In this case, node is a Parameter + parameter_name = node->cast()->name(); + target_parameter = node; + } + if (target_set.size() <= 1) { + continue; + } + + // Rule out the case when a Parameter being used by a Operator, but the Operator appears in multiple CNODEs + std::set target_without_duplicate; + for (auto &target : target_set) { + auto target_cnode = target.first->cast(); + auto input_index = target.second; + (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name()); + } + if (target_without_duplicate.size() <= 1) { + continue; + } + + // Here, it is sure that this Parameter (RefKey) is being used by multiple Operators. + OperatorInfoPtr tmp_identity_ptr; + bool new_identity = false; + std::string tmp_identity_name; + auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name); + if (returned_identity != nullptr) { + // In this case, the TmpIdentityInfo instance has already been created + new_identity = false; + tmp_identity_ptr = returned_identity; + tmp_identity_name = tmp_identity_ptr->name(); + } else { + // In the case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created. + new_identity = true; + // 1) extract input shape from this Parameter + MS_EXCEPTION_IF_NULL(target_parameter); + AbstractBasePtr abstract = target_parameter->abstract(); + if (abstract == nullptr) { + MS_LOG(EXCEPTION) << "Failure: abstract is nullptr"; + } + auto input_shape = dyn_cast(abstract->GetShapeTrack()); + if (input_shape == nullptr) { + MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr"; + } + std::vector shape_int = input_shape->shape(); + Shape shape; + (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape), + [](int sub_shape) { return static_cast(sub_shape); }); + Shapes inputs_shape = {shape}; + Shapes outputs_shape = {shape}; + // 2) init the attr + std::unordered_map attr = {}; + + // Create the TmpIdentity instance + tmp_identity_ptr = std::make_shared(inputs_shape, outputs_shape, attr); + tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS)); + TOTAL_OPS++; + tmp_identity_ptr->set_refkey_parameter_name(parameter_name); + // Set the parameter and type lengths for inputs and outputs + std::vector is_parameter; + auto casted_target_parameter = target_parameter->cast(); + MS_EXCEPTION_IF_NULL(casted_target_parameter); + if (casted_target_parameter->has_default()) { + bool requires_grad = casted_target_parameter->default_param()->requires_grad(); + is_parameter.push_back(requires_grad); + } else { + is_parameter.push_back(false); + } + if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) { + MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed"; + } + auto node_type = target_parameter->Type(); + if (node_type->isa()) { + auto input_element_type = node_type->cast()->element(); + std::vector type_length = {GetLengthOfDataType(input_element_type)}; + if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) { + MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed"; + } + } else { + MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name(); + } + + // Generate strategies for this TmpIdentityInfo instance; + if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) { + MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name(); + } + } + // A flag recording whether new edges have been created or not + bool add_identity_edge = false; + + // Create edges between this TmpIdentityInfo instance and subsequent Operator instances + for (auto &target : target_set) { + auto target_cnode = target.first->cast(); + auto prim = GetValueNode(target_cnode->input(0)); + auto input_index = target.second; + + std::string edge_name = + std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name(); + // If the edge between these two operators already has been added, then the edge will not be added again. + if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) { + continue; + } + std::shared_ptr edge_ptr = std::make_shared( + edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true); + + if (edge_ptr->InitEdgeCost() != SUCCESS) { + MS_LOG(EXCEPTION) << "Edge cost initialization failed"; + } + target_cnode->operator_info()->AddPrevEdge(edge_ptr); + tmp_identity_ptr->AddSuccEdge(edge_ptr); + entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr); + MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and " + << target_cnode->operator_info()->name(); + add_identity_edge = true; + } + if (new_identity && add_identity_edge) { + // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied + entire_costgraph->AddOperator(tmp_identity_ptr); + } + } +} + +bool FindReshape(const CNodePtr &cnode) { + if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { + return false; + } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { + return false; + } + PrimitivePtr prim = GetValueNode(prim_anf_node); + MS_EXCEPTION_IF_NULL(prim); + OperatorInfoPtr operator_info = cnode->operator_info(); + if (operator_info == nullptr) { + MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; + } + if (prim->name() != RESHAPE) { + return false; + } + return true; +} + +// find previous node, then obtain its strategy_cost_ vector to get its layout vector. +bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) { + // if previous node is a parameter, handle it in the outsize. + if (node->isa()) { + return false; + } + if (!node->isa()) { + return false; + } + CNodePtr cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + return false; + } + if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + *pre_operator_info = cnode->operator_info(); + *out_index = 0; + return true; + } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + PrimitivePtr prim = prim_anf_node->value()->cast(); + if (prim->name() == TUPLE_GETITEM) { + *out_index = GetTupleGetItemIndex(cnode); + // find tuple_get_item's previous node + auto pre_node = cnode->input(1); + if (!pre_node->isa()) { + MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; + } + CNodePtr pre_cnode = pre_node->cast(); + if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) { + *pre_operator_info = pre_cnode->operator_info(); + return true; + } + return false; + } + for (size_t index = 0; index < cnode->inputs().size(); ++index) { + if (prim->name() == DEPEND && index != 1) { + continue; + } + if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) { + continue; + } + return true; + } + MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error"; + return false; +} + +// find next node, then obtain its strategy_cost_ vector to get its layout vector. +// if reshape's output connect to several primitive, return the first layout found +bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(cnode->func_graph()); + FuncGraphManagerPtr manager = cnode->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[cnode]; + for (auto &node_pair : node_set) { + CNodePtr use_apply = node_pair.first->cast(); + if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { + continue; + } + ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr node_prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); + if (node_prim->name() == DEPEND && node_pair.second != 1) { + continue; + } + if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { + MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); + *next_operator_info = use_apply->operator_info(); + *in_index = node_pair.second - 1; + return true; + } + MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) + << " " << (use_apply->operator_info() != nullptr); + + if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { + return true; + } + } + return false; +} + +void ReshapeCostCompute(const std::vector &all_nodes) { + for (auto node : all_nodes) { + auto cnode = node->cast(); + if (!FindReshape(cnode)) { + continue; + } + MS_ASSERT(cnode->inputs().size() == 3); + // get previous node's strategy_cost_ + auto pre_node = cnode->input(1); + int32_t out_index = 0; + OperatorInfoPtr pre_operator_info; + std::vector> pre_stra_costs; + if (pre_node->isa()) { + OperatorInfoPtr operator_info = cnode->operator_info(); + auto reshape_info = std::dynamic_pointer_cast(operator_info); + reshape_info->SetCostForReshapeWithParameter(); + pre_operator_info = reshape_info; + pre_stra_costs = reshape_info->strategy_cost(); + } else { + if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) { + MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed"; + } + pre_stra_costs = pre_operator_info->strategy_cost(); + } + // get next node's strategy_cost_ + int32_t in_index = 0; + OperatorInfoPtr next_operator_info; + std::vector> next_stra_costs; + bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index); + if (!find_next_node) { + MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed"; + } + // set input_layout and output_layout for reshape. + // init reshape and set cost for each input_layout and output_layout. + OperatorInfoPtr operator_info = cnode->operator_info(); + auto reshape_info = std::dynamic_pointer_cast(operator_info); + reshape_info->set_pre_operator_name(pre_operator_info->name()); + reshape_info->set_pre_operator_index(out_index); + if (find_next_node) { + next_stra_costs = next_operator_info->strategy_cost(); + reshape_info->set_next_operator_name(next_operator_info->name()); + reshape_info->set_next_operator_index(in_index); + } + bool is_prev_param = pre_node->isa(); + if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) != + SUCCESS) { + MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!"; + } + } +} + +Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root) { + // There are 4 meta-steps to determine the parallelization strategy for the ANF graph. + // Step 1: Traverse the ANF graph, and create NODEs for costgraph: + // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies + // for each OperatorInfo; + // Step 1.1: Deal with 'Reshape': + // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's + // layout as its output layout. + // Step 2: Traverse the ANF graph, and create EDGES for costgraph: + // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies + // for each edge, based on the strategies of two OperatorInfos; + // Step 3: Augment the costgraph: + // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity + // operator for this Parameter, and add an edge for the use of this Parameter by each + // subsequent operator; + // Step 3.1: Calculate memory usage: + // note the memory usage calculation is different in training phase and inference phase. + // Step 4: Run the Dynamic Programming algorithm: + // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge + // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input + // tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm + // runs on each of them. + // + // OUTPUT: the determined strategy for each operator. + + // Step 1 + if (CostModelContext::GetInstance()->is_multi_subgraphs()) { + if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { + MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " + << entire_costgraph->GetOperators().size() << " operators."; + } else { + MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; + } + } else { + if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { + MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " + << entire_costgraph->GetOperators().size() << " operators."; + } else { + MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; + } + } + // Step 1.1 + ReshapeCostCompute(all_nodes); + // Step 2 + ConstructCostGraphEdges(all_nodes); + MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() + << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; + + // Step 3: Augment the costgraph. + AugmentCostGraph(all_nodes); + MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() + << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; + + // Step 3.1: Calculate the memory usage + if (entire_costgraph->CalculateMemoryCost() != SUCCESS) { + MS_LOG(EXCEPTION) << "Calculating memory cost failed."; + } + + // Step 4: run DP algorithm on the costgraph. + if (GetStrategy(entire_costgraph) != SUCCESS) { + MS_LOG(ERROR) << "Strategy search for cost-graph fails"; + return FAILED; + } + MS_LOG(INFO) << "Searching strategy succeeded."; + + if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { + MS_LOG(INFO) << "Init selected strategy succeeded."; + } else { + MS_LOG(EXCEPTION) << "Init selected strategy failed."; + } + + // print the selected strategy + for (auto &op : entire_costgraph->GetOperators()) { + StrategyPtr s_strategy = op->selected_strategy(); + MS_LOG(INFO) << op->name() << " : The strategy is:"; + PrintStrategy(s_strategy); + } + + return SUCCESS; +} + +std::vector> RecInputTensorNames(const std::map::iterator &it, + std::vector> input_tensor_names) { + for (size_t j = 0; j < input_tensor_names.size(); j++) { + for (size_t k = 0; k < input_tensor_names[j].size(); k++) { + if (it->first == input_tensor_names[j][k]) { + input_tensor_names[j][k] = it->second; + break; + } + } + } + return input_tensor_names; +} + +CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) { + PrimitivePtr prim = GetValueNode(prim_anf_node); + if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) { + auto prev_cnode = cnode->input(1)->cast(); + if (prev_cnode == nullptr || !IsValueNode(prev_cnode->input(0))) { + return nullptr; + } + auto prev_prim = prev_cnode->input(0)->cast()->value()->cast(); + while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) { + prev_cnode = prev_cnode->input(1)->cast(); + if (prev_cnode == nullptr || !IsValueNode(prev_cnode->input(0))) { + return nullptr; + } + prev_prim = prev_cnode->input(0)->cast()->value()->cast(); + } + return prev_cnode; + } + return nullptr; +} + +Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root) { + if (CostModelContext::GetInstance()->is_multi_subgraphs()) { + if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { + MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " + << entire_costgraph->GetOperators().size() << " operators."; + } else { + MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; + } + } else { + if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { + MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " + << entire_costgraph->GetOperators().size() << " operators."; + } else { + MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; + } + } + ReshapeCostCompute(all_nodes); + + auto ops = entire_costgraph->GetOperators(); + std::vector> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list(); + auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list(); + for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) { + input_tensor_names = RecInputTensorNames(it++, input_tensor_names); + } + std::shared_ptr graph = ParseGraph(ops, input_tensor_names); + + std::shared_ptr>> eli_list(new std::vector>); + std::shared_ptr> index_list(new std::vector); + graph = EliminateGraph(graph, eli_list, index_list); + + size_t num_device = g_device_manager->DeviceNum(); + double device_memory = entire_costgraph->GetDeviceMemory(); + if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) { + MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; + } else { + MS_LOG(ERROR) << "PartitionForAllDevices failed."; + return FAILED; + } + + GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list); + + if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { + MS_LOG(INFO) << "Init selected strategy succeeded."; + } else { + MS_LOG(ERROR) << "Init selected strategy failed."; + return FAILED; + } + + // print the selected strategy + for (auto &op : entire_costgraph->GetOperators()) { + StrategyPtr s_strategy = op->selected_strategy(); + MS_LOG(INFO) << op->name() << " : The strategy is:"; + PrintStrategy(s_strategy); + } + + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h new file mode 100644 index 0000000000..f87d49b736 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_STEP_AUTO_PARALLEL_H_ +#define PARALLEL_STEP_AUTO_PARALLEL_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "frontend/optimizer/opt.h" +#include "frontend/parallel/status.h" +#include "pipeline/jit/pipeline.h" + +namespace mindspore { +namespace parallel { +bool IsSplittableOperator(const std::string &); + +bool IsAutoParallelCareNode(const CNodePtr &); + +// main step of Auto-parallel +bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); + +size_t GetLengthOfDataType(const TypePtr &type); + +std::vector ExtractInputParameterByNode(const CNodePtr &node); + +std::vector ExtractInputTypeLengthByNode(const CNodePtr &node); + +std::vector ExtractOutputTypeByNode(const CNodePtr &node); + +Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &root); + +Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &root); + +void ConstructCostGraphEdges(const std::vector &all_nodes); + +void AugmentCostGraph(const std::vector &all_nodes); + +Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root); + +Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root); + +std::vector> RecInputTensorNames(const std::map::iterator &it, + std::vector> input_tensor_names); + +CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node); +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_STEP_AUTO_PARALLEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc new file mode 100644 index 0000000000..6b9cfd9d37 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -0,0 +1,2368 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/step_parallel.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "ir/tensor.h" +#include "ir/param_value.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/dynamic_creator.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/graph_util/graph_info.h" +#include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/node_check.h" +#include "frontend/parallel/ops_info/matmul_info.h" +#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" +#include "utils/comm_manager.h" +#include "utils/symbolic.h" +#include "pipeline/jit/static_analysis/prim.h" + +using mindspore::tensor::Tensor; + +namespace mindspore { +namespace parallel { +static const std::set COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; +static const std::set INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; +// g_RefMap, for CNode B input i is a RefKey[Parameter C], +// it will be one item in map with key: C, and value: (B, i) +static std::map> g_RefMap; + +void SetCommunicationOpGroupLabel(std::vector new_node_input) { + if (new_node_input.empty()) { + return; + } + + ValueNodePtr prim_anf_node = new_node_input[0]->cast(); + PrimitivePtr prim = GetValueNode(prim_anf_node); + MS_EXCEPTION_IF_NULL(prim); + + auto attrs = prim->attrs(); + auto iter = attrs.find(GROUP); + if (iter != attrs.end()) { + auto value = iter->second; + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + std::string hash_name = value->cast()->value(); + MS_EXCEPTION_IF_NULL(g_device_manager); + std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name); + (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name)); + } + } +} + +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { + MS_EXCEPTION_IF_NULL(node); + OperatorArgs arg_forward = op.second; + ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name); + MS_EXCEPTION_IF_NULL(pyop_instance); + OperatorParams params = arg_forward.second; + + std::vector new_node_input = {NewValueNode(pyop_instance), node}; + if (!params.empty()) { + for (auto ¶m : params) { + AnfNodePtr val = NewValueNode(param.first.second); + MS_EXCEPTION_IF_NULL(val); + int32_t position = param.second; + (void)new_node_input.insert(new_node_input.begin() + position, val); + } + } + + // if the op have 'group' attr, set the rank list name for the op + SetCommunicationOpGroupLabel(new_node_input); + return new_node_input; +} + +void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node, + const FuncGraphPtr &func_graph, const std::string &instance_name) { + // insert new node before the node + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + ScopePtr scope = node->scope(); + MS_EXCEPTION_IF_NULL(scope); + std::vector node_input = CreateInput(op, pre_node, instance_name); + CNodePtr new_node = func_graph->NewCNode(node_input); + MS_EXCEPTION_IF_NULL(new_node); + if (instance_name.find(SPLIT_SENS) == std::string::npos) { + new_node->set_in_forward_flag(true); // mark forward flag + } + auto new_node_value = node_input[0]->cast(); + MS_EXCEPTION_IF_NULL(new_node_value); + PrimitivePtr new_node_prim = new_node_value->value()->cast(); + new_node_prim->set_instance_name(instance_name); + new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); + new_node->set_scope(scope); + node_input[0]->set_scope(scope); + manager->SetEdge(node, SizeToInt(index), new_node); +} + +std::string CreateInstanceName(const CNodePtr &node, size_t index) { + MS_EXCEPTION_IF_NULL(node); + if (!IsValueNode(node->input(0))) { + MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive"; + } + std::string name_base = node->fullname_with_scope(); + std::string name = name_base + "_" + std::to_string(index); + std::string instance_name = HashInstanceName(name); + return instance_name; +} + +void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // step1:get graph manager distribute_operator + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto uses_set = manager->node_users()[node]; + CNodePtr node_to_insert = node; + for (auto &uses_pair : uses_set) { + auto uses_cnode = uses_pair.first->cast(); + MS_EXCEPTION_IF_NULL(uses_cnode); + if (!IsValueNode(uses_cnode->input(0))) { + break; + } + PrimitivePtr value_node_prim = GetValueNode(uses_cnode->input(0)); + MS_EXCEPTION_IF_NULL(value_node_prim); + if (value_node_prim->name() == TUPLE_GETITEM) { + if (uses_set.size() > 1) { + MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size(); + } + node_to_insert = uses_cnode; + } + } + MS_EXCEPTION_IF_NULL(node_to_insert); + std::reverse(forward_op.begin(), forward_op.end()); + + // step2:traverse op_list and insert node + for (size_t index = 0; index < forward_op.size(); ++index) { + std::string instance_name_base = FORWARD_OP; + std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index); + std::vector forward_input = CreateInput(forward_op[index], node_to_insert, instance_name); + CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to creat anfnode + MS_EXCEPTION_IF_NULL(forward_node); + ScopePtr scope = node->scope(); + MS_EXCEPTION_IF_NULL(scope); + forward_node->set_scope(scope); + forward_node->set_in_forward_flag(true); + forward_input[0]->set_scope(scope); + (void)manager->Replace(node_to_insert, forward_node); // using Replace function to insert node + } +} + +CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint32_t num, const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(prev); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector make_tuple_inputs; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (uint32_t i = 0; i < num; i++) { + std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), prev, + CreatInt32Imm(UintToInt(i))}; + auto tuple_get_item = func_graph->NewCNode(tuple_get_item_inputs); + MS_EXCEPTION_IF_NULL(tuple_get_item); + make_tuple_inputs.push_back(tuple_get_item); + } + auto make_tuple = func_graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(prev, make_tuple); + return make_tuple; +} + +void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, + const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(pre_node); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if ((redistribution_oplist_ptr->first).size() != (redistribution_oplist_ptr->second).size()) { + MS_LOG(EXCEPTION) << "size of OperatorVector and OutPutInfoVector must be the same!"; + } + for (size_t index = 0; index < (redistribution_oplist_ptr->first).size(); ++index) { + if (pos >= SizeToInt(node->inputs().size())) { + MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size"; + } + // Creat new node + AnfNodePtr target_node = node->input(IntToSize(pos)); + MS_EXCEPTION_IF_NULL(target_node); + // Creat instance_name + auto op = (redistribution_oplist_ptr->first)[index]; + std::string op_name = (redistribution_oplist_ptr->first)[index].first; + std::string instance_name_base = REDISTRIBUTION_OP; + std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name; + InsertNode(op, node, IntToSize(pos), target_node, func_graph, instance_name); + if ((redistribution_oplist_ptr->second)[index].first) { + target_node = node->input(IntToSize(pos)); + MS_EXCEPTION_IF_NULL(target_node); + (void)InsertMakeTuple(target_node, (redistribution_oplist_ptr->second)[index].second, func_graph); + } + } +} + +void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int pos, + const std::string &instance_name) { + if (func_graph == nullptr) { + MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name; + } + + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (pos >= SizeToInt(node->inputs().size())) { + MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is " + << instance_name; + } + // Creat new node + AnfNodePtr pre_node = node->input(IntToSize(pos)); + MS_EXCEPTION_IF_NULL(pre_node); + InsertNode(op, node, IntToSize(pos), pre_node, func_graph, instance_name); +} + +TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim, + const OperatorInfoPtr &distribute_operator) { + TensorInfo tensorinfo_in; + if (middle_prim->name() == TUPLE_GETITEM) { + auto value_node = middle_node->input(2)->cast(); + MS_EXCEPTION_IF_NULL(value_node); + size_t index_s = IntToSize(GetValue(value_node->value())); + if (index_s >= distribute_operator->outputs_tensor_info().size()) { + MS_LOG(EXCEPTION) << "The index out of range, index: " << index_s + << ", vector size: " << distribute_operator->outputs_tensor_info().size(); + } + tensorinfo_in = distribute_operator->outputs_tensor_info()[index_s]; + } else { + if (distribute_operator->outputs_tensor_info().empty()) { + MS_LOG(EXCEPTION) << "The outputs tensor info is empty"; + } + tensorinfo_in = distribute_operator->outputs_tensor_info()[0]; + } + return tensorinfo_in.tensor_layout(); +} + +OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!IsParallelCareNode(node)) { + return nullptr; + } + OperatorInfoPtr distribute_operator = node->operator_info(); + if (distribute_operator == nullptr) { + MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; + } + return distribute_operator; +} + +void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, + const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, + const CNodePtr &pre_node) { + FuncGraphPtr func_graph = middle_node->func_graph(); + if (func_graph == nullptr) { + MS_LOG(EXCEPTION) << "Redistribution:get graph failed"; + } + CNodePtr next_node = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(next_node); + auto middle_value = middle_node->input(0)->cast(); + MS_EXCEPTION_IF_NULL(middle_value); + PrimitivePtr middle_prim = middle_value->value()->cast(); + MS_EXCEPTION_IF_NULL(middle_prim); + OperatorInfoPtr next_distribute_operator = GetDistributeOperator(next_node); + if (next_distribute_operator == nullptr) { + MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed"; + } + RankList dev_list = distribute_operator->global_device_list(); + std::string next_prim_name = GetValueNode(next_node->input(0))->name(); + MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name; + MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString(); + // extract tensor layout in and out + if (distribute_operator->outputs_tensor_info().empty()) { + MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty"; + } + + if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) { + MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is " + << next_distribute_operator->inputs_tensor_info().size(); + } + TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)]; + TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); + TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator); + if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) { + MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name; + MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " + << next_node->ToString(); + DumpGraph(func_graph, "redistribution_error"); + MS_LOG(EXCEPTION) << "Failure:tensor_redistribution init failed"; + } + RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); + if (redistribution_oplist_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Failure:InferTensorRedistribution failed"; + } + MS_LOG(DEBUG) << "Redistribution size " << redistribution_oplist_ptr->first.size(); + if (!redistribution_oplist_ptr->first.empty()) { + // insert node before next node + InsertRedistribution(redistribution_oplist_ptr, next_node, func_graph, node_pair.second, pre_node); + } +} + +bool StrategyFound(std::unordered_map attrs) { + auto iter = attrs.find(STRATEGY); + return !((iter == attrs.end()) || (iter->second->type_name() == NONE)); +} + +bool HasStrategy(const FuncGraphPtr &root) { + AnfNodePtr ret = root->get_return(); + MS_EXCEPTION_IF_NULL(ret); + std::vector all_nodes = DeepScopedGraphSearch(ret); + + for (auto &node : all_nodes) { + auto cnode = node->cast(); + if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { + continue; + } + + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + PrimitivePtr prim = GetValueNode(prim_anf_node); + auto attrs = prim->attrs(); + if (StrategyFound(attrs)) { + return true; + } + } + + return false; +} + +bool IsCommunicationOp(const PrimitivePtr &prim) { + MS_EXCEPTION_IF_NULL(prim); + return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()); +} + +bool FindCommunicationOp(const std::vector &all_nodes) { + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + continue; + } + ValueNodePtr prim_value_node = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_value_node); + PrimitivePtr prim = GetValueNode(prim_value_node); + MS_EXCEPTION_IF_NULL(prim); + + if (IsCommunicationOp(prim) && cnode->in_forward_flag()) { + MS_EXCEPTION_IF_NULL(prim_value_node->scope()); + MS_LOG(INFO) << "The graph contain communication op: " << prim->name() << ", scope name is " + << prim_value_node->scope()->name(); + return true; + } + } + return false; +} + +bool IsParallelCareNode(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + ValueNodePtr prim_node = cnode->input(0)->cast(); + if (prim_node == nullptr) { + return false; + } + PrimitivePtr prim = prim_node->value()->cast(); + if (prim == nullptr) { + return false; + } + if (IsInBlackList(prim)) { + MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); + return false; + } + // get_next is not in the forward graph, we need mark the get_next as the forward node + if (prim->name() == GET_NEXT) { + return true; + } + if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) { + return false; + } + + return cnode->in_forward_flag(); +} + +void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, + const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) { + MS_EXCEPTION_IF_NULL(node->func_graph()); + FuncGraphManagerPtr manager = node->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[node]; + CNodePtr insert_node_new; + if (IsValueNode(node->input(0))) { + auto current_value = node->input(0)->cast(); + MS_EXCEPTION_IF_NULL(current_value); + PrimitivePtr current_prim = current_value->value()->cast(); + MS_EXCEPTION_IF_NULL(current_prim); + insert_node_new = ((current_prim->name() == TUPLE_GETITEM) ? node : insert_node); + } else { + insert_node_new = insert_node; + } + MS_EXCEPTION_IF_NULL(insert_node_new); + for (auto &node_pair : node_set) { + CNodePtr use_cnode = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(use_cnode); + if (!IsValueNode(use_cnode->input(0))) { + StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node); + } else { + ValueNodePtr prim_anf_node = use_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr node_prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == DEPEND && node_pair.second != 1) { + continue; + } + if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) { + Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, + pre_node); + } else { + StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node); + } + } + } +} + +void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(next_node); + OperatorInfoPtr op_info = next_node->operator_info(); + MS_EXCEPTION_IF_NULL(op_info); + + // If the shape of tensor is [] or [1], no need to split it. + Shapes shapes = GetNodeShape(node); + if (shapes.size() != 1) { + MS_LOG(EXCEPTION) << "Split tensor for " << op_info->name() + << ": GetNodeShape for tensor_node, output size is not 1"; + } + Shape shape = shapes[0]; + std::string shape_str = ShapeToString(shape); + if (shape.empty() || ((shape.size() == 1) && (shape[0] == 1))) { + MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape is " << shape_str + << ", no need to split it."; + return; + } + + MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape of tensor is " << shape_str; + + // extract tensor layout + if (IntToSize(index - 1) >= op_info->inputs_tensor_info().size()) { + MS_LOG(EXCEPTION) << "The index is out of range, index is " << index - 1 << ", vector size is " + << op_info->inputs_tensor_info().size(); + } + TensorInfo tensor_info = op_info->inputs_tensor_info()[IntToSize(index - 1)]; + TensorLayout tensor_layout = tensor_info.tensor_layout(); + + // Use _GetTensorSlice operator to split the tensor + FuncGraphPtr func_graph = next_node->func_graph(); // only cnode can get the graph + MS_EXCEPTION_IF_NULL(func_graph); + Operator op = CreateGetTensorSliceOp(tensor_layout); + InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR); + if (!op_info->sub_ops().empty()) { + auto sub_ops = op_info->sub_ops(); + for (size_t i = 0; i < sub_ops.size(); i++) { + if (!sub_ops.at(i).empty()) { + InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB); + } + } + } +} + +void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[node]; + for (auto &node_pair : node_set) { + CNodePtr use_cnode = node_pair.first->cast(); + if (use_cnode == nullptr || !IsValueNode(use_cnode->input(0))) { + continue; + } + ValueNodePtr prim_anf_node = use_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(use_cnode_prim); + if (use_cnode_prim->name() == DEPEND && node_pair.second != 1) { + continue; + } + if (IsParallelCareNode(use_cnode)) { + SplitTensor(node, use_cnode, node_pair.second); + } + } +} + +std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, + const CNodePtr &node) { + OperatorArgs arg_replace_op = replace_op.second; + ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name); + if (pyop_instance == nullptr) { + MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed"; + } + OperatorParams params = arg_replace_op.second; + if (node->inputs().size() < 2) { + // GetNext operator dose not has input + if (node->inputs().size() == 1) { + return {NewValueNode(pyop_instance)}; + } + MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2"; + } + std::vector replace_input = {NewValueNode(pyop_instance), node->input(1)}; + auto prim = GetValueNode(node->input(0)); + if (prim->name() == EMBEDDING_LOOKUP) { + replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; + } + if (!params.empty()) { + Param param_first = *(params.begin()); + int32_t first_position = param_first.second; + if (first_position == 1) { + replace_input.pop_back(); + } + for (auto ¶m : params) { + AnfNodePtr val = NewValueNode(param.first.second); + if (val == nullptr) { + MS_LOG(EXCEPTION) << "Failure:val is nullptr"; + } + int32_t position = param.second; + (void)replace_input.insert(replace_input.begin() + position, val); + } + } + + return replace_input; +} + +void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; + } + std::string instance_name = CreateInstanceName(node, 0); + std::vector replace_input; + replace_input = ReplaceOpInput(replace_op, instance_name, node); + CNodePtr replace_node = func_graph->NewCNode(replace_input); + MS_EXCEPTION_IF_NULL(replace_node); + ScopePtr scope = node->scope(); + MS_EXCEPTION_IF_NULL(scope); + replace_node->set_scope(scope); + replace_node->set_in_forward_flag(true); + replace_input[0]->set_scope(scope); + (void)manager->Replace(node, replace_node); +} + +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { + // step1:get graph manager distribute_operator + OperatorInfoPtr distribute_operator = node->operator_info(); + if (distribute_operator == nullptr) { + MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; + } + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; + } + // step2:traverse op_list and insert node + std::reverse(replace_op.begin(), replace_op.end()); + auto replace_op_info = distribute_operator->replace_op_info(); + std::reverse(replace_op_info.begin(), replace_op_info.end()); + if (!replace_op_info.empty() && replace_op_info.size() != replace_op.size()) { + MS_LOG(EXCEPTION) << "replace_op_info is not empty and size not equal to replace_op!"; + } + bool replace_op_info_flag = !replace_op_info.empty(); + for (size_t index = 0; index < replace_op.size(); ++index) { + std::string instance_name = CreateInstanceName(node, index); + std::vector replace_input; + if (index != replace_op.size() - 1) { + replace_input = CreateInput(replace_op[index], node, instance_name); + } else { + replace_input = ReplaceOpInput(replace_op[index], instance_name, node); + } + CNodePtr replace_node = func_graph->NewCNode(replace_input); + MS_EXCEPTION_IF_NULL(replace_node); + ScopePtr scope = node->scope(); + MS_EXCEPTION_IF_NULL(scope); + replace_node->set_scope(scope); + PrimitivePtr prim = GetValueNode(replace_node->input(0)); + if (prim->name() == EMBEDDING_LOOKUP) { + auto attrs = prim->attrs(); + attrs[TARGET] = MakeValue(CPU); + (void)prim->SetAttrs(attrs); + } + if (index == replace_op.size() - 1) { + (void)replace_node->set_operator_info(node->operator_info()); + } + replace_node->set_in_forward_flag(true); + replace_input[0]->set_scope(scope); + if (replace_op_info_flag && replace_op_info[index].first) { + auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph); + (void)manager->Replace(node, new_cnode); // using Replace function to insert node + } else { + (void)manager->Replace(node, replace_node); // using Replace function to insert node + } + } + MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name(); +} + +bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { + ValueNodePtr anf_node = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(anf_node); + PrimitivePtr prim = anf_node->value()->cast(); + return (prim->name() == name); +} + +void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(replace_graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(replace_graph->second); + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; + } + for (auto &replace_input : replace_graph->first) { + auto pre_node = node->input(IntToSize(replace_input.second)); + manager->SetEdge(replace_input.first, 1, pre_node); + } + // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called + auto replace_output = replace_graph->second; + MS_EXCEPTION_IF_NULL(replace_output); + (void)manager->Replace(node, replace_output); +} + +int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() != 3) { + MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3"; + } + + if (!cnode->input(2)->isa()) { + MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node"; + } + + ValuePtr tuple_index_value = GetValueNode(cnode->input(2)); + MS_EXCEPTION_IF_NULL(tuple_index_value); + if (!tuple_index_value->isa()) { + MS_LOG(EXCEPTION) << "The index of tuple getitem is not int32"; + } + return tuple_index_value->cast()->value(); +} + +// Judge whether the node is a loss, and if there are multiple outputs, +// get which output is a grad according to the tuple getitem. +// Currently, it is not supported that the sens is a tuple. +LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { + MS_EXCEPTION_IF_NULL(loss_node); + FuncGraphPtr sub_graph = loss_node->func_graph(); + MS_EXCEPTION_IF_NULL(sub_graph); + CNodePtr return_node = sub_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->inputs().size() < 2) { + MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; + } + AnfNodePtr pre_node = return_node->input(1); + MS_EXCEPTION_IF_NULL(pre_node); + + LossNodeInfo node_info; + + // return -> cast + auto pre_cnode = pre_node->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + auto pre_prim = GetValueNode(pre_cnode->input(0)); + if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + pre_node = pre_cnode->input(1); + } + + // return -> loss + if (pre_node == loss_node) { + node_info.has_tuple_getitem = false; + node_info.dout_index = 0; + return node_info; + } + + // return -> tuple_getitem -> loss + auto cnode = pre_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto current_value = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(current_value); + PrimitivePtr current_prim = current_value->value()->cast(); + MS_EXCEPTION_IF_NULL(current_prim); + // size of common cnode is larger than 1 + if (cnode->inputs().size() < 2) { + MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2"; + } + + if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) { + // size of tuple_getitem cnode is 3 + auto tuple_index = GetTupleGetItemIndex(cnode); + node_info.has_tuple_getitem = true; + node_info.dout_index = tuple_index; + return node_info; + } + + MS_LOG(EXCEPTION) << "Invalid loss"; +} + +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + size_t node_size = node->inputs().size(); + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (size_t index = 1; index < node_size; ++index) { + AnfNodePtr input = node->input(index); + MS_EXCEPTION_IF_NULL(input); + if (!input->isa() && !input->isa()) { // if it is not a tensor, continue + MS_LOG(INFO) << "insert div op: the index " << index << " is not tensor, skip"; + continue; + } + + for (size_t pos = 0; pos < virtual_div_op.size(); ++pos) { + std::string instance_name = CreateInstanceName(node, pos); + InsertNode(virtual_div_op[pos], node, index, node->input(index), func_graph, instance_name); + } + MS_LOG(INFO) << "insert div op for input index " << index << " of node"; + } +} + +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { + if (!node->isa() && !node->isa() && !node->isa()) { + return std::make_pair(nullptr, false); + } else if (node->isa()) { + return std::make_pair(node, false); + } else if (node->isa()) { + if (IsValueNode(node)) { + std::vector param_v = FindParameterByRefKeyNode(node, func_graph); + if (param_v.size() != 1) { + MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is " + << param_v.size(); + } + return std::make_pair(node, true); + } + return std::make_pair(nullptr, false); + } else { + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + for (size_t index = 0; index < cnode->inputs().size(); ++index) { + if (!FindParameter(cnode->input(index), func_graph).first) { + continue; + } + return FindParameter(cnode->input(index), func_graph); + } + } else { + if (IsParallelCareNode(cnode)) { + return std::make_pair(nullptr, false); + } else { + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + for (size_t index = 0; index < cnode->inputs().size(); ++index) { + PrimitivePtr prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == DEPEND && index != 1) { + continue; + } + if (!FindParameter(cnode->input(index), func_graph).first) { + continue; + } + return FindParameter(cnode->input(index), func_graph); + } + } + } + } + return std::make_pair(nullptr, false); +} + +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(anode); + MS_EXCEPTION_IF_NULL(anode->func_graph()); + FuncGraphManagerPtr manager = anode->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[anode]; + bool result = false; + CNodePtr cnode_return = nullptr; + for (auto &node_pair : node_set) { + CNodePtr use_apply = node_pair.first->cast(); + if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { + continue; + } + ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr node_prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == name && node_pair.second == 1) { + if (use_apply->func_graph() == func_graph) { + result = true; + cnode_return = use_apply; + MS_LOG(INFO) << "Find Primitive " << name << " in the same func_graph"; + continue; + } + MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph"; + } + } + return std::make_pair(result, cnode_return); +} + +bool IsCastBeforMirror(const CNodePtr &node, size_t index) { + // only if cast_before_mirror is true, pre node is cast and type is not float32 return true + if (!ParallelContext::GetInstance()->cast_before_mirror()) { + return false; + } + auto pre_node = node->input(index); + MS_EXCEPTION_IF_NULL(pre_node); + auto cnode = pre_node->cast(); + if (cnode == nullptr || !IsValueNode(cnode->input(0))) { + return false; + } + auto pre_value_node = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(pre_value_node); + auto pre_prim = pre_value_node->value()->cast(); + MS_EXCEPTION_IF_NULL(pre_prim); + if (pre_prim->name() != CAST) { + return false; + } + auto node_type = pre_node->Type(); + MS_EXCEPTION_IF_NULL(node_type); + if (!node_type->isa()) { + MS_LOG(EXCEPTION) << "Unknown type."; + } + auto input_element_type = node_type->cast()->element(); + MS_EXCEPTION_IF_NULL(input_element_type); + auto type_id = input_element_type->type_id(); + + return (type_id != kNumberTypeFloat32); +} + +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + size_t node_size = node->inputs().size(); + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (mirror_ops.size() != node_size - 1) { + MS_LOG(EXCEPTION) << "Failure:Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() + << ", node_size is " << node_size; + } + for (size_t index = 1; index < node_size; ++index) { + OperatorVector backward_op = mirror_ops[index - 1]; + if (backward_op.empty()) { + continue; + } + std::pair param_node_pair = FindParameter(node->input(index), func_graph); + if (!param_node_pair.first) { + continue; + } + // not a RefKey + if (!param_node_pair.second) { + auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph); + // if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead + if (next_cnode.first) { + MS_EXCEPTION_IF_NULL(next_cnode.second); + manager->SetEdge(node, SizeToInt(index), next_cnode.second); + continue; + } + } + // if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp + // only one MirrorOp in backward_op + if (backward_op.size() != 1) { + MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size(); + } + std::string instance_name = MIRROR_OP; + if (IsCastBeforMirror(node, index)) { + for (auto &op : backward_op) { + // insert new node before the node + CNodePtr cnode = node->input(index)->cast(); + MS_EXCEPTION_IF_NULL(cnode); + AnfNodePtr pre_node = cnode->input(1); + InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); + } + } else { + for (auto &op : backward_op) { + AnfNodePtr pre_node = node->input(index); + InsertNode(op, node, index, pre_node, func_graph, instance_name); + } + } + } +} + +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, + const std::vector> &sens_loss_pairs) { + MS_EXCEPTION_IF_NULL(distribute_operator); + MS_EXCEPTION_IF_NULL(node); + + bool is_loss_cnode = + std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(), + [node](const std::pair &element) { return element.second == node; }); + + MirrorOps mirror_ops = distribute_operator->mirror_ops(); + VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op(); + // insert mirror op + if (!mirror_ops.empty()) { + MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); + InsertMirrorOps(mirror_ops, node); + } + // insert virtual div op + if (!virtual_div_op.empty() && is_loss_cnode) { + MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name(); + InsertVirtualDivOp(virtual_div_op, node); + } +} + +std::string GetDisOpName(const std::string &prim_name) { + std::string op_name = prim_name; + if (!prim_name.empty() && (prim_name[0] == '_')) { + op_name = prim_name.substr(1); + } + return op_name + "Info"; +} + +OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs, + const std::vector &shape_list) { + if (shape_list.size() != 2) { + MS_LOG(ERROR) << "The size of shape list is not 2"; + return nullptr; + } + if (name.length() == 0) { + MS_LOG(EXCEPTION) << "Length of name is zero!"; + } + std::string distribute_opname = GetDisOpName(name); + if (name == GATHERV2) { + distribute_opname = name + "PInfo"; + auto data_parallel_iter = attrs.find(DATA_PARALLEL); + if (data_parallel_iter != attrs.end()) { + MS_EXCEPTION_IF_NULL(data_parallel_iter->second); + if (!data_parallel_iter->second->isa()) { + MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool."; + } + bool data_parallel = data_parallel_iter->second->cast()->value(); + if (data_parallel) { + distribute_opname = name + "Info"; + } + } + } + OperatorInfoPtr operator_ = + (OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); + if (operator_ == nullptr) { + MS_LOG(INFO) << "Creat " << name << " failed"; + return nullptr; + } + std::string origin_name = operator_->name(); + operator_->set_name(origin_name + std::to_string(TOTAL_OPS)); + MS_LOG(INFO) << "Successfully created operator " << origin_name; + ++TOTAL_OPS; + return operator_; +} + +OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + const std::vector &shape_list) { + MS_EXCEPTION_IF_NULL(prim); + OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list); + if (operator_ == nullptr) { + MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel"; + operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list); + MS_EXCEPTION_IF_NULL(operator_); + } + return operator_; +} + +OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + std::vector shape_list) { + OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); + for (size_t i = 0; i < shape_list[0].size(); ++i) { + MS_LOG(INFO) << "No: " << i << " input's shape: " << ShapeToString(shape_list[0][i]); + } + return operator_; +} + +StrategyPtr ExtractStrategy(std::unordered_map attrs) { + ValueTuplePtr var = attrs[STRATEGY]->cast(); + StrategyPtr strategyPtr; + MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString(); + if (var == nullptr) { + MS_LOG(EXCEPTION) << "Strategy value is nullptr"; + } + if (var->size() > 0) { + std::vector elements = var->value(); + std::vector strategy; + for (uint32_t index = 0; index < elements.size(); ++index) { + Dimensions dim; + if (elements[index]->isa()) { + ValueTuplePtr value_tuple = elements[index]->cast(); + std::vector value_vector = value_tuple->value(); + (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim), + [](const ValuePtr &value) { return static_cast(GetValue(value)); }); + strategy.push_back(dim); + } else { + MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue"; + } + } + if (strategy.empty()) { + MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy"; + } + strategyPtr = NewStrategy(0, strategy); + } + + return strategyPtr; +} + +Shapes GetNodeShape(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + Shapes shapes; + BaseShapePtr base_shape_ptr = node->Shape(); + if (node->isa()) { + auto cnode = node->cast(); + if (IsValueNode(cnode->input(0))) { + PrimitivePtr prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == MAKEREF) { + AnfNodePtr ref_node = cnode->input(1); + auto func_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(ref_node); + MS_EXCEPTION_IF_NULL(func_graph); + return GetRefKeyNodeShape(ref_node, func_graph); + } + } + if (cnode->input(0)->isa()) { + if (cnode->inputs().size() < 2) { + MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is samller than 2"; + } + base_shape_ptr = cnode->input(1)->Shape(); + } + } + if (base_shape_ptr == nullptr) { + MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is " + << node->fullname_with_scope(); + } + auto tuple_shape_ptr = dyn_cast(base_shape_ptr); + if (tuple_shape_ptr != nullptr) { + auto tuple_shape = tuple_shape_ptr->shape(); + for (auto &shape : tuple_shape) { + auto each_shape = dyn_cast(shape); + MS_EXCEPTION_IF_NULL(each_shape); + shapes.push_back(each_shape->shape()); + } + } else { + auto shape_ptr = dyn_cast(base_shape_ptr); + MS_EXCEPTION_IF_NULL(shape_ptr); + shapes.push_back(shape_ptr->shape()); + } + return shapes; +} + +std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector parameters; + if (!IsValueNode(node)) { + MS_LOG(ERROR) << "The node is not a ref key"; + return parameters; + } + + auto ref_key = GetValueNode(node); + MS_EXCEPTION_IF_NULL(ref_key); + auto name = ref_key->tag(); + + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto roots = manager->roots(); + if (roots.size() != 1) { + MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1"; + return parameters; + } + + FuncGraphPtr root_g = roots.back(); + MS_EXCEPTION_IF_NULL(root_g); + for (auto ¶m_node : root_g->parameters()) { + auto param = param_node->cast(); + if (param && (name == param->name())) { + parameters.push_back(param_node); + MS_LOG(INFO) << "The name of ref key is: " << name; + return parameters; + } + } + + MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter"; + return parameters; +} + +Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector parameters = FindParameterByRefKeyNode(node, func_graph); + if (parameters.size() != 1) { + MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; + } + + Shapes input_shapes; + input_shapes = GetNodeShape(parameters[0]); + if (input_shapes.size() != 1) { + MS_LOG(EXCEPTION) << "Get input shape failed"; + } + + MS_LOG(INFO) << "The parameter shape is " << ShapeToString(input_shapes[0]); + return input_shapes; +} + +std::vector ExtractShape(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + Shapes shape_inputs, shape_outputs; + std::vector shape_all; + std::vector all_inputs = node->inputs(); + std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; + + size_t inputs_size = all_inputs.size(); + for (size_t i = 1; i < inputs_size; ++i) { + Shapes input_shapes; + AnfNodePtr input = all_inputs[i]; + if (IsValueNode(input)) { + auto func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector parameters = FindParameterByRefKeyNode(input, func_graph); + if (parameters.size() != 1) { + MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; + } + std::pair node_pair = std::make_pair(node, SizeToInt(i)); + g_RefMap[parameters[0]] = node_pair; + input_shapes = GetRefKeyNodeShape(input, func_graph); + } else if (IsValueNode(input) || input->isa() || input->isa()) { + input_shapes = GetNodeShape(input); + } else { + continue; + } + if (input_shapes.size() != 1) { + MS_LOG(EXCEPTION) << "ExtractShape:Get input shape failed"; + } + shape_inputs.push_back(input_shapes[0]); + } + shape_all.push_back(shape_inputs); + // extract out shape + shape_outputs = GetNodeShape(node); + shape_all.push_back(shape_outputs); + return shape_all; +} + +std::pair FindParallelCareNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[node]; + for (auto &node_pair : node_set) { + CNodePtr cnode = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + continue; + } + ValueNodePtr prim_node_anf = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_node_anf); + PrimitivePtr node_prim = prim_node_anf->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == DEPEND && node_pair.second != 1) { + continue; + } + if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { + return node_pair; + } else if (FindParallelCareNode(node_pair.first).first != nullptr) { + return FindParallelCareNode(node_pair.first); + } + } + return std::make_pair(nullptr, 0); +} + +std::pair FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(parameter); + FuncGraphManagerPtr manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + std::pair prim_anf_node_pair = FindParallelCareNode(parameter); + if (prim_anf_node_pair.first != nullptr) { + return prim_anf_node_pair; + } else { + AnfNodeIndexSet param_sub_set = manager->node_users()[parameter]; + for (auto ¶m_pair : param_sub_set) { + CNodePtr graph_cnode = param_pair.first->cast(); + if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa()) { + continue; + } + CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast(); + if (!IsValueNode(graph_cnode_inp0->input(1))) { + continue; + } + FuncGraphPtr graph_sub = GetValueNode(graph_cnode_inp0->input(1)); + auto parameters = graph_sub->parameters(); + if (IntToSize(param_pair.second - 1) >= parameters.size()) { + MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is " + << parameters.size(); + } + std::pair res = FindSubGraph(graph_sub, parameters[IntToSize(param_pair.second - 1)]); + if (res.first != nullptr) { + return res; + } + } + } + return std::make_pair(nullptr, 0); +} + +void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res) { + MS_EXCEPTION_IF_NULL(parameter); + AbstractBasePtr abstract = parameter->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); + CNodePtr cnode = res.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + OperatorInfoPtr distribute_operator = cnode->operator_info(); + if (distribute_operator == nullptr) { + MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; + } + + if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) { + MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is " + << distribute_operator->inputs_tensor_info().size(); + } + TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)]; + Shape slice_shape = tensorinfo_in.slice_shape(); + MS_LOG(DEBUG) << "SetParallelShape slice_shape " << parameter->ToString() << " shape " + << MakeValue(slice_shape)->ToString(); + std::shared_ptr parallel_shape = std::make_shared(slice_shape); + MS_EXCEPTION_IF_NULL(parallel_shape); + // Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis. + auto cloned_abstract = abstract->Clone(); + MS_EXCEPTION_IF_NULL(cloned_abstract); + cloned_abstract->set_shape(parallel_shape); + parameter->set_abstract(cloned_abstract); + TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); + ParameterPtr parameter_ptr = parameter->cast(); + MS_EXCEPTION_IF_NULL(parameter_ptr); + parameter_ptr->set_tensor_layout(std::make_shared(tensor_layout)); +} + +void CoverSliceShape(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + auto parameters = root->parameters(); + for (auto ¶meter : parameters) { + MS_EXCEPTION_IF_NULL(parameter->Shape()); + auto iter = g_RefMap.find(parameter); + if (iter != g_RefMap.end()) { + SetParallelShape(parameter, g_RefMap[parameter]); + continue; + } + std::pair res = FindSubGraph(root, parameter); + if (res.first == nullptr) { + MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape"; + } else { + SetParallelShape(parameter, res); + MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); + } + } + g_RefMap.clear(); +} + +bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_node) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(parameter_node); + FuncGraphManagerPtr manager = root->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto cloned_parameter = parameter_node->cast(); + MS_EXCEPTION_IF_NULL(cloned_parameter); + + // find the clone parameter + if (!cloned_parameter->has_default()) { + return false; + } + + bool cloned = cloned_parameter->default_param()->cloned(); + if (!cloned) { + return false; + } + + MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; + return true; +} + +void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + for (auto &cloned_parameter_node : root->parameters()) { + MS_EXCEPTION_IF_NULL(cloned_parameter_node); + auto cloned_parameter = cloned_parameter_node->cast(); + MS_EXCEPTION_IF_NULL(cloned_parameter); + + if (!ParameterIsCloned(root, cloned_parameter_node)) { + continue; + } + + // get the cloned index + int32_t cloned_index = cloned_parameter->default_param()->cloned_index(); + + // find the be cloned parameter + bool found_be_cloned_parameter = false; + ParameterPtr cloned_from_parameter = nullptr; + AnfNodePtr cloned_from_node = nullptr; + for (auto &be_cloned_parameter_node : root->parameters()) { + MS_EXCEPTION_IF_NULL(be_cloned_parameter_node); + auto be_cloned_parameter = be_cloned_parameter_node->cast(); + MS_EXCEPTION_IF_NULL(be_cloned_parameter); + if (!be_cloned_parameter->has_default()) { + continue; + } + + const auto ¶m_value_cloned = be_cloned_parameter->default_param(); + if (!param_value_cloned->be_cloned()) { + continue; + } + + // get the be cloned index + auto &be_cloned_index = param_value_cloned->be_cloned_index(); + if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) { + found_be_cloned_parameter = true; + cloned_from_parameter = be_cloned_parameter; + cloned_from_node = be_cloned_parameter_node; + } + } + + if (found_be_cloned_parameter) { + // set the shape and tensor layout for cloned parameter + cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout()); + MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); + MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); + auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); + MS_EXCEPTION_IF_NULL(cloned_abstract); + cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); + cloned_parameter_node->set_abstract(cloned_abstract); + MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() + << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name() + << ", clone index is: " << cloned_index; + } else { + MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is " + << cloned_index << ", but not found the be cloned parameter"; + } + } + std::string env = common::GetEnv("SLICE_ENV"); + if (!env.empty()) { + MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env; + } +} + +void SetVirtualDatasetStrategy(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + + PrimitivePtr prim = GetValueNode(node->input(0)); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == VIRTUAL_DATA_SET) { + CheckGlobalDeviceManager(); + int32_t dev_num; + if (full_batch) { + dev_num = 1; + } else { + dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); + } + auto attrs_temp = prim->attrs(); + std::vector shape_list = ExtractShape(node); + if (shape_list.empty()) { + MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; + } + std::vector elements; + for (size_t i = 0; i < shape_list[0].size(); i++) { + if (shape_list[0][i].empty()) { + MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero"; + } + std::vector input_strategy = {dev_num}; + for (size_t j = 1; j < shape_list[0][i].size(); j++) { + input_strategy.push_back(1); + } + elements.push_back(MakeValue(input_strategy)); + } + ValueTuplePtr strategy = std::make_shared(elements); + attrs_temp[STRATEGY] = strategy; + (void)prim->SetAttrs(attrs_temp); + } +} + +void ExtractInformation(const std::vector &all_nodes) { + // load strategy map from checkpoint + StrategyMap stra_map; + if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { + if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { + MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; + } + } + for (auto &node : all_nodes) { + auto cnode = node->cast(); + if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { + continue; + } + SetVirtualDatasetStrategy(cnode); + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + PrimitivePtr prim = GetValueNode(prim_anf_node); + auto attrs = prim->attrs(); + MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name(); + if (IsParallelCareNode(cnode)) { + std::vector shape_list = ExtractShape(cnode); + if (shape_list.empty()) { + MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; + } + OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); + if (operator_ == nullptr) { + MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed"; + } + auto &inputs = cnode->inputs(); + std::vector input_value; + for (size_t index = 1; index < inputs.size(); ++index) { + if (inputs[index]->isa()) { + input_value.push_back(GetValueNode(inputs[index])); + } else { + input_value.emplace_back(nullptr); + } + } + StrategyPtr strategyPtr = nullptr; + (*operator_).set_input_value(input_value); + (*operator_).set_outputs_dtype(cnode->Type()); + (*operator_).set_cnode(cnode); + if (prim->name() == RESHAPE) { + (void)cnode->set_operator_info(operator_); + continue; + } + // load strategy checkpoint + // key of strategy map + std::string strategy_key_name = NodeParameterName(cnode); + bool load_strategy_from_ckpt = + StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); + if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { + MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() + << " is empty, using batch parallel"; + std::shared_ptr> strategy_v_ptr = operator_->GenerateBatchStrategies(); + if (strategy_v_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed"; + } + std::vector elements; + for (size_t i = 0; i < strategy_v_ptr->size(); i++) { + elements.push_back(MakeValue((*strategy_v_ptr)[i])); + } + ValueTuplePtr strategy = std::make_shared(elements); + // display the strategy generated by batch parallel + attrs[GEN_STRATEGY] = strategy; + (void)prim->SetAttrs(attrs); + MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is " + << attrs[GEN_STRATEGY]->ToString(); + strategyPtr = NewStrategy(0, *strategy_v_ptr); + } else if (load_strategy_from_ckpt) { + strategyPtr = stra_map[strategy_key_name]; + } else { + strategyPtr = ExtractStrategy(attrs); + } + if (strategyPtr != nullptr) { + if (operator_->Init(strategyPtr) == FAILED) { + MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; + } + (void)cnode->set_operator_info(operator_); + } else { + MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; + } + } + } +} + +TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair) { + CNodePtr cnode = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); + MS_EXCEPTION_IF_NULL(distribute_operator); + int index = node_pair.second; + if (index > SizeToInt(distribute_operator->inputs_tensor_info().size())) { + MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is " << index - 1 << ", the vector size is " + << distribute_operator->inputs_tensor_info().size(); + } + TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(index - 1)]; + TensorLayout tensorlayout_in = tensorinfo_in.tensor_layout(); + return tensorlayout_in; +} + +// if reshape's output connect to several primitive, return the first layout found +std::shared_ptr FindNextLayout(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(cnode->func_graph()); + FuncGraphManagerPtr manager = cnode->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[cnode]; + for (auto &node_pair : node_set) { + CNodePtr use_apply = node_pair.first->cast(); + if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { + continue; + } + ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr node_prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); + if (node_prim->name() == DEPEND && node_pair.second != 1) { + continue; + } + if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { + MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); + auto layout = GetInputLayoutFromCNode(node_pair); + return std::make_shared(layout); + } + MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) + << " " << (use_apply->operator_info() != nullptr); + + auto layout_ptr = FindNextLayout(use_apply); + if (layout_ptr) { + return layout_ptr; + } + } + MS_LOG(WARNING) << "FindNextLayout return nullptr, if reshape is not the last primitive, there must be some error"; + return nullptr; +} + +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) { + MS_EXCEPTION_IF_NULL(cnode); + OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); + MS_EXCEPTION_IF_NULL(distribute_operator); + if (distribute_operator->outputs_tensor_info().size() < output_index) { + MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size() + << ", must be less than output_index " << output_index; + } + TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index]; + TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); + return std::make_shared(tensorlayout_out); +} + +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) { + if (!node->isa()) { + return nullptr; + } + CNodePtr cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + return nullptr; + } + if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); + if (!layout_ptr) { + MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; + } + return layout_ptr; + } + return nullptr; +} + +std::shared_ptr CreateParameterLayout(const AnfNodePtr &node) { + // Create DataParallel tensor layout for parameter(support WideDeep). + CheckGlobalDeviceManager(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); + TensorLayout input_tensor_layout; + // create input_shape + Shapes inputs_shape = GetNodeShape(node); + Shape input_shape_array = inputs_shape[0]; + if (input_shape_array.empty()) { + MS_LOG(EXCEPTION) << "Don't support reshape a scalar parameter."; + } + // create tensor_map + size_t shape_size = input_shape_array.size(); + TensorMap input_tensor_map_array(SizeToInt(shape_size) - 1, -1); + input_tensor_map_array.insert(input_tensor_map_array.begin(), 0); + // create dev_matrix + Shape dev_matrix_array = {dev_num}; + if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) { + MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed."; + } + return std::make_shared(input_tensor_layout); +} + +std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { + if (node->isa()) { + return CreateParameterLayout(node); + } + if (!node->isa()) { + return nullptr; + } + CNodePtr cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + return nullptr; + } + if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); + if (!layout_ptr) { + MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; + } + return layout_ptr; + } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + PrimitivePtr prim = prim_anf_node->value()->cast(); + if (prim->name() == TUPLE_GETITEM) { + auto tuple_index = GetTupleGetItemIndex(cnode); + auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), IntToSize(tuple_index)); + if (!layout_ptr) { + MS_LOG(EXCEPTION) + << " Failure:FindPrevLayout failed, tuple_getitem before reshape, but there does not exit a parallel care node " + "before tuple_getitem!"; + } + return layout_ptr; + } + for (size_t index = 0; index < cnode->inputs().size(); ++index) { + if (prim->name() == DEPEND && index != 1) { + continue; + } + auto layout_ptr = FindPrevLayout(cnode->inputs()[index]); + if (!layout_ptr) { + continue; + } + return layout_ptr; + } + MS_LOG(WARNING) << "FindPrevLayout return nullptr, if reshape is not the first primitive, there must be some error"; + return nullptr; +} + +void ReshapeInit(const std::vector &all_nodes) { + for (auto &node : all_nodes) { + auto cnode = node->cast(); + if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { + continue; + } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { + continue; + } + PrimitivePtr prim = GetValueNode(prim_anf_node); + MS_EXCEPTION_IF_NULL(prim); + OperatorInfoPtr operator_info = cnode->operator_info(); + if (operator_info == nullptr) { + MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; + } + if (prim->name() != RESHAPE) { + continue; + } + auto attrs = prim->attrs(); + if (StrategyFound(attrs)) { + MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; + } + MS_ASSERT(cnode->inputs().size() == 3); + auto prev_layout_ptr = FindPrevLayout(cnode->input(1)); + if (prev_layout_ptr) { + auto reshape_info_ptr = std::dynamic_pointer_cast(operator_info); + reshape_info_ptr->SetInputLayout(*prev_layout_ptr); + } + auto next_layout_ptr = FindNextLayout(cnode); + if (next_layout_ptr) { + auto reshape_info_ptr = std::dynamic_pointer_cast(operator_info); + reshape_info_ptr->SetOutputLayout(*next_layout_ptr); + } + if (operator_info->Init(nullptr) == FAILED) { + MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed"; + } + } +} + +CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + CNodePtr return_node = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->size() < 2) { + MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; + } + AnfNodePtr pre_node = return_node->input(1); + MS_EXCEPTION_IF_NULL(pre_node); + + auto pre_cnode = pre_node->cast(); + if (pre_cnode == nullptr) { + return nullptr; + } + + auto current_prim = GetValueNode(pre_cnode->input(0)); + // return -> cast + if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + pre_cnode = pre_cnode->input(1)->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + current_prim = GetValueNode(pre_cnode->input(0)); + } + + // notice: the GetNext op has not input + if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { + MS_LOG(INFO) << "The loss is: " << current_prim->name(); + return pre_cnode; + } + + // size of common cnode is larger than 1 + if (pre_cnode->size() < 2) { + MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; + } + + // return -> tuple_getitem -> loss + if (current_prim->name() == TUPLE_GETITEM) { + AnfNodePtr pre_pre_node = pre_cnode->input(1); + MS_EXCEPTION_IF_NULL(pre_pre_node); + + auto pre_pre_cnode = pre_pre_node->cast(); + auto value = pre_pre_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(value); + PrimitivePtr prim = value->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + MS_LOG(DEBUG) << "The loss name is " << prim->name(); + return pre_pre_cnode; + } + + // return -> make_tuple + if (current_prim->name() == MAKE_TUPLE) { + MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; + } + + // return -> loss + MS_LOG(DEBUG) << "The loss name is " << current_prim->name(); + return pre_cnode; +} + +TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { + TensorLayouts ret; + MS_EXCEPTION_IF_NULL(loss_cnode); + AnfNodePtr node = loss_cnode->cast(); + MS_EXCEPTION_IF_NULL(node); + + LossNodeInfo node_info = GetLossNodeInfo(node); + ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) { + MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now"; + return ret; + } + + OperatorInfoPtr operator_info = loss_cnode->operator_info(); + MS_EXCEPTION_IF_NULL(operator_info); + TensorInfo loss_grad_tensor_info; + size_t op_output_size = operator_info->outputs_tensor_info().size(); + MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is " + << node_info.has_tuple_getitem << ", the output size is " << op_output_size << ", the dout_index is " + << node_info.dout_index; + + if ((op_output_size == 0) || (op_output_size <= IntToSize(node_info.dout_index))) { + MS_LOG(EXCEPTION) << "The index is " << node_info.dout_index << ", but the size of outputs is " << op_output_size; + } + + if (!node_info.has_tuple_getitem && (op_output_size > 1)) { + MS_LOG(EXCEPTION) << "Currently, it is not supported that the sens is a tuple."; + } + + loss_grad_tensor_info = operator_info->outputs_tensor_info()[IntToSize(node_info.dout_index)]; + ret.push_back(loss_grad_tensor_info.tensor_layout()); + return ret; +} + +void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) { + MS_EXCEPTION_IF_NULL(grad_sens_node); + if (grad_sens_node->size() <= 1) { + MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2"; + } + AnfNodePtr sens_tensor_node = grad_sens_node->input(1); + MS_EXCEPTION_IF_NULL(sens_tensor_node); + Shapes sens_shapes = GetNodeShape(sens_tensor_node); + if (sens_shapes.size() != 1) { + MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1"; + } + // If the shape of sens tensor is [] or [1], no need to split it. + Shape sens_shape = sens_shapes[0]; + if (sens_shape.empty() || ((sens_shape.size() == 1) && (sens_shape[0] == 1))) { + if (sens_tensor_node->isa()) { + auto sens_tensor_param = sens_tensor_node->cast(); + MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); + sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); + } + MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; + return; + } + auto loss_shape = loss_grad_layout.tensor_shape().array(); + if (loss_shape != sens_shape) { + MS_LOG(EXCEPTION) << "The shape of sens is not equal to loss output, it is unsupported now. Sens shape is " + << ShapeToString(sens_shape) << ", loss shape is " << ShapeToString(loss_shape); + } + MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", split it."; + + if (!IsValueNode(sens_tensor_node)) { + if (sens_tensor_node->isa()) { + MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); + AbstractBasePtr abstract = sens_tensor_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + auto slice_shape = loss_grad_layout.slice_shape().array(); + std::shared_ptr parallel_shape = std::make_shared(slice_shape); + MS_EXCEPTION_IF_NULL(parallel_shape); + auto cloned_abstract = abstract->Clone(); + MS_EXCEPTION_IF_NULL(cloned_abstract); + cloned_abstract->set_shape(parallel_shape); + sens_tensor_node->set_abstract(cloned_abstract); + auto sens_tensor_param = sens_tensor_node->cast(); + sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); + return; + } + MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; + } + + // Use _GetTensorSlice operator to split the sens tensor + FuncGraphPtr func_graph = grad_sens_node->func_graph(); // only cnode can get the graph + MS_EXCEPTION_IF_NULL(func_graph); + Operator op = CreateGetTensorSliceOp(loss_grad_layout); + InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS); +} + +void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(distribute_operator); + MS_EXCEPTION_IF_NULL(cnode); + OperatorVector forward_op = distribute_operator->forward_op(); + if (!forward_op.empty()) { + MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name(); + ForwardCommunication(forward_op, cnode); + } +} + +void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(distribute_operator); + MS_EXCEPTION_IF_NULL(cnode); + // StepReplaceOp + OperatorVector replace_op = distribute_operator->replace_op(); + if (!replace_op.empty()) { + MS_LOG(INFO) << "StepReplaceOp " << cnode->ToString(); + StepReplaceOp(replace_op, cnode); + } + + // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore. + ReplaceGraphPtr replace_graph = distribute_operator->replace_graph(cnode); + if (!replace_op.empty() && replace_graph) { + MS_LOG(EXCEPTION) << "Only one of replace_op or replace_op can be used"; + } + if (replace_graph) { + MS_LOG(INFO) << "StepReplaceGraph " << cnode->ToString(); + StepReplaceGraph(replace_graph, cnode); + } +} + +void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(distribute_operator); + MS_EXCEPTION_IF_NULL(cnode); + + std::string op_name = distribute_operator->name(); + if (op_name.find(DROPOUT_DO_MASK) == std::string::npos) { + return; + } + + DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast(distribute_operator); + MS_EXCEPTION_IF_NULL(dropout_do_mask); + std::vector replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode); + if (replace_op.empty()) { + MS_LOG(DEBUG) << "No need to replace dropout_gen_mask"; + return; + } + if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { + MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; + } + ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); +} + +void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { + HandleDropoutNode(distribute_operator, cnode); +} + +std::set FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { + // J->CNode->Graph + std::set graph_set; + for (auto &node : root_all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + + auto cnode = node->cast(); + if ((cnode->size() < 2) || !IsValueNode(cnode->input(0))) { + continue; + } + auto expect_j_prim = GetValueNode(cnode->input(0)); + if (expect_j_prim->name() != J) { + continue; + } + if (IsValueNode(cnode->input(1))) { + auto graph = GetValueNode(cnode->input(1)); + MS_LOG(DEBUG) << "Find the forward graph success"; + graph_set.insert(graph); + } + } + return graph_set; +} + +void StepSplitSens(const std::pair &sens_loss_pair) { + CNodePtr sens_node = sens_loss_pair.first; + CNodePtr loss_node = sens_loss_pair.second; + auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node); + if (!loss_grad_layout.empty()) { + SplitSens(sens_node, loss_grad_layout[0]); + } +} + +// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) +std::vector> GetSensLossPairs(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + std::vector> sens_loss_pairs; + for (auto &node : root->nodes()) { + if (!node->isa()) { + continue; + } + + // cnode(sens)-->cnode(tuple_getitem) + auto sens_cnode = node->cast(); + AnfNodePtr expect_tuple_getitem = sens_cnode->input(0); + MS_EXCEPTION_IF_NULL(expect_tuple_getitem); + if (!expect_tuple_getitem->isa()) { + continue; + } + + auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast(); + if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) { + continue; + } + + // cnode(sens)-->cnode(tuple_getitem)-->cnode + AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1); + MS_EXCEPTION_IF_NULL(expect_anonymous); + if (!expect_anonymous->isa()) { + continue; + } + + // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) + auto expect_anonymous_cnode = expect_anonymous->cast(); + AnfNodePtr expect_j = expect_anonymous_cnode->input(0); + MS_EXCEPTION_IF_NULL(expect_j); + if (!expect_j->isa()) { + continue; + } + auto expect_j_cnode = expect_j->cast(); + if (!IsSomePrimitive(expect_j_cnode, J)) { + continue; + } + + if (!IsValueNode(expect_j_cnode->input(1))) { + MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph."; + } + auto func_graph = GetValueNode(expect_j_cnode->input(1)); + auto loss_cnode = FindLossCNode(func_graph); + if (loss_cnode == nullptr) { + MS_LOG(WARNING) << "Can not find the loss cnode"; + continue; + } + std::pair sens_loss_pair = std::make_pair(sens_cnode, loss_cnode); + sens_loss_pairs.push_back(sens_loss_pair); + } + return sens_loss_pairs; +} + +void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, + const FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(manager); + TensorRedistribution tensor_redistribution; + + std::vector> sens_loss_pairs = GetSensLossPairs(root); + bool has_backward = !sens_loss_pairs.empty(); + // split sens must before inserting the operators. + for (auto &pair : sens_loss_pairs) { + // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it. + // If the type of sens node is not Tensor, it is unsupported now, do nothing default. + StepSplitSens(pair); + } + + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + continue; + } + OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); + if (distribute_operator == nullptr) { + continue; + } + + // insert forward ops + InsertForwardOps(distribute_operator, cnode); + + // insert redistribution ops + StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode); + + // insert backward ops + if (has_backward) { + BackwardCommunication(distribute_operator, cnode, sens_loss_pairs); + } + + HandleSpecialNode(distribute_operator, cnode); + } else if (IsValueNode(node)) { + StepSplitTensor(node, manager); + } + } + + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + continue; + } + OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); + if (distribute_operator == nullptr) { + continue; + } + // StepReplace + StepReplace(distribute_operator, cnode); + } + } +} + +namespace { +void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(node); + auto symbolic_key = GetValueNode(node); + MS_EXCEPTION_IF_NULL(symbolic_key); + auto all_upstream_node = root->manager()->node_users()[node]; + for (auto &upstream_node : all_upstream_node) { + FuncGraphPtr fg = upstream_node.first->func_graph(); + if (symbolic_key->node()->isa()) { + for (auto ¶m : root->parameters()) { + if (*param == *symbolic_key->node()) { + AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param}); + MS_EXCEPTION_IF_NULL(reverted_node); + MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString(); + (void)fg->manager()->Replace(node, reverted_node); + MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString(); + } + } + } + } +} +} // namespace + +void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector &all_nodes) { + MS_EXCEPTION_IF_NULL(root); + for (auto &node : all_nodes) { + // revert back SymbolicKeyInstance to embed() primitive + if (IsValueNode(node)) { + RevertSymbolicKeyInstance(root, node); + continue; + } + } +} + +std::string NodeParameterName(const CNodePtr &node) { + std::vector node_inputs{node->inputs()}; + for (auto input : node_inputs) { + if (input->isa()) { + auto input_parameter = input->cast(); + if (input_parameter->has_default()) { + const auto ¶m_value = input_parameter->default_param(); + if (param_value->requires_grad()) { + return param_value->name(); + } + } + } + } + return ""; +} + +void CheckpointStrategy(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; + StrategyMap stra_map; + auto ret = func_graph->get_return(); + auto all_nodes = DeepScopedGraphSearch(ret); + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { + continue; + } + std::string param_name = NodeParameterName(cnode); + if (param_name.empty()) { + continue; + } + PrimitivePtr prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(prim); + OperatorInfoPtr operator_info = cnode->operator_info(); + if (operator_info) { + if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { + continue; + } + StrategyPtr strategyPtr = operator_info->strategy(); + MS_EXCEPTION_IF_NULL(node->scope()); + stra_map[param_name] = strategyPtr; + } + } + if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { + MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; + } +} + +void SetForwardFlag(const std::vector &all_nodes) { + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + continue; + } + + // CNode is globally unique. + MS_LOG(DEBUG) << "Set forward flag " << cnode->DebugString() << "."; + cnode->set_in_forward_flag(true); + } +} + +void SetForwardFlag(const AnfNodeSet &all_nodes) { + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + continue; + } + + // CNode is globally unique. + cnode->set_in_forward_flag(true); + } +} + +std::set ForwardGraph(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + const auto &all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); + return graph_set; +} + +std::vector FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { + MS_EXCEPTION_IF_NULL(graph); + std::vector root_forward_nodes; + auto loss_cnode = FindLossCNode(graph); + if (loss_cnode == nullptr) { + MS_LOG(WARNING) << "Can not find the loss cnode"; + return root_forward_nodes; + } + + auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy(); + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + auto root_node_id = node->UniqueIdThroughCopy(); + if (loss_cnode_id == root_node_id) { + root_forward_nodes = DeepLinkedGraphSearch(cnode); + break; + } + } + return root_forward_nodes; +} + +void MarkForwardCNode(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + auto all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); + + if (graph_set.empty()) { + MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph"; + SetForwardFlag(all_nodes); + } else { + for (auto &func_graph : graph_set) { + MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); + auto return_node = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); + SetForwardFlag(all_dfs_nodes); + auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes); + if (root_forward_nodes.empty()) { + continue; + } + // Mark forward flag for the nodes in root graph. + SetForwardFlag(root_forward_nodes); + } + } +} + +Status ParallelInit() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + int32_t device_num = ParallelContext::GetInstance()->device_num(); + int32_t global_rank = ParallelContext::GetInstance()->global_rank(); + std::string backend = ParallelContext::GetInstance()->communication_backend(); + std::string world_group; + + if (backend == HCCL_BACKEND) { + world_group = HCCL_WORLD_GROUP; + } else if (backend == NCCL_BACKEND) { + world_group = NCCL_WORLD_GROUP; + } else { + MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; + } + + uint32_t world_rank_size = 0; + if (!ParallelContext::GetInstance()->device_num_is_set()) { + if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) { + MS_LOG(EXCEPTION) << "Get rank size failed"; + } + device_num = UintToInt(world_rank_size); + MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num; + } + + uint32_t rank_id = 0; + if (!ParallelContext::GetInstance()->global_rank_is_set()) { + if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { + MS_LOG(EXCEPTION) << "Get rank id failed"; + } + global_rank = UintToInt(rank_id); + MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; + } + + if (!InitDevice(device_num, global_rank, backend)) { + MS_LOG(ERROR) << "Init device failed"; + return FAILED; + } + + MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank + << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean() + << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror(); + return SUCCESS; +} + +bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(optimizer); + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); + // assume no change to graph + bool changes = false; + // control whether use model_parallel mode + if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || + (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { + if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { + if (HasStrategy(root)) { + MS_LOG(INFO) << "Strategies ignored in " << parallel_mode + << ", set_strategy() only valid in [semi_]auto_parallel."; + } + root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); + } + + return changes; + } + + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); + + MS_LOG(INFO) << "Now entering step parallel"; + DumpGraph(root, std::string(STEP_PARALLEL_BEGIN)); + + pipeline::ResourceBasePtr res = optimizer->resource(); + MS_EXCEPTION_IF_NULL(res); + + FuncGraphManagerPtr manager = res->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodePtr ret = root->get_return(); + MS_EXCEPTION_IF_NULL(ret); + std::vector all_nodes = DeepScopedGraphSearch(ret); + std::reverse(all_nodes.begin(), all_nodes.end()); + if (parallel_mode != AUTO_PARALLEL) { + TOTAL_OPS = 0; + if (ParallelInit() != SUCCESS) { + MS_LOG(EXCEPTION) << "Parallel init failed"; + } + + // mark the forward cnodes, parallel only care these nodes + MarkForwardCNode(root); + + if (FindCommunicationOp(all_nodes)) { + MS_LOG(EXCEPTION) << "The graph contain communication op"; + } + + // extract shape and strategy, set operator_info + ExtractInformation(all_nodes); + ReshapeInit(all_nodes); + } + // save strategy as checkpoint for multi-train + if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { + CheckpointStrategy(root); + } + + HandleSymbolicKeyInstance(root, all_nodes); + + // cover Parallel shape + CoverSliceShape(root); + + // set the shape for optimizer's clone tensor + SetClonedTensorShapeForOptimizer(root); + + // ForwardCommunication BackwardCommunication TensorRedistribution + ParallelCommunication(root, all_nodes, manager); + + DumpGraph(root, std::string(STEP_PARALLEL_END)); + + // step parallel only run once + root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true); + res->results()[pipeline::kStepParallelGraph] = root; + + // in auto parallel mode, no need to check if stategies set + root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); + + (void)gettimeofday(&end_time, nullptr); + uint64_t time = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + time += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us"; + return changes; +} + +// Needed by rec_parser +std::vector ExtractInputsTensorName(const CNodePtr &node) { + std::vector name_inputs; + std::vector all_inputs = node->inputs(); + std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; + + std::string node_id = node->UniqueId(); + name_inputs.push_back(node_id); + for (auto &input : node_inputs) { + std::string name = input->UniqueId(); + name_inputs.push_back(name); + } + + return name_inputs; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h new file mode 100644 index 0000000000..f9fe67ea6b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -0,0 +1,155 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ +#define MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "./common.h" +#include "frontend/optimizer/opt.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +using OperatorInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace parallel { +const uint64_t kUSecondInSecond = 1000000; + +struct LossNodeInfo { + bool has_tuple_getitem = false; + int dout_index = 0; // now don't support the sens is a tuple +}; + +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name); +std::string CreateInstanceName(const CNodePtr &node, size_t index); +void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node); + +void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, + const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node); + +TensorLayout GetTensorInLayout(const CNodePtr &pre_node, const PrimitivePtr &pre_prim, + const OperatorInfoPtr &distribute_operator_pre); + +OperatorInfoPtr GetDistributeOperator(const CNodePtr &node); + +void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, + const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, + const CNodePtr &pre_node); + +bool StrategyFound(std::unordered_map attrs); + +bool IsParallelCareNode(const CNodePtr &cnode); + +void MarkForwardCNode(const FuncGraphPtr &root); + +bool FindCommunicationOp(const std::vector &all_nodes); + +void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, + const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node); + +std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, + const CNodePtr &node); + +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node); + +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node); + +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph); + +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); + +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); + +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, + const std::vector> &sens_loss_pairs); + +// Generate and init parallel operator +OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + const std::vector &shape_list); + +// Generate without initing parallel operator +OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + std::vector shape_list); + +// Extract strategy from attr +StrategyPtr ExtractStrategy(std::unordered_map attrs); + +Shapes GetNodeShape(const AnfNodePtr &node); + +std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph); + +// Extract shape from anfnode +std::vector ExtractShape(const CNodePtr &node); + +std::pair FindParallelCareNode(const AnfNodePtr &node); + +// Find finally sub graph +std::pair FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr ¶meter); + +// Set distribute shape for parameters abstract +void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res); + +// change parameters'shape in resource +void CoverSliceShape(const FuncGraphPtr &root); + +void SetVirtualDatasetStrategy(const CNodePtr &node); + +// Creat parallel operator for primitive node(has strategy) +void ExtractInformation(const std::vector &all_nodes); + +TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair); + +std::shared_ptr FindNextLayout(const CNodePtr &node); + +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index); + +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index); + +std::shared_ptr FindPrevLayout(const AnfNodePtr &node); + +void ReshapeInit(const std::vector &all_nodes); + +// Add node for whole graph +void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, + const FuncGraphManagerPtr &manager); + +std::string NodeParameterName(const CNodePtr &node); + +void CheckpointStrategy(const FuncGraphPtr &func_graph); + +// main step of Parallel +bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); + +int32_t GetTupleGetItemIndex(const CNodePtr &cnode); + +Status ParallelInit(); + +std::vector ExtractInputsTensorName(const CNodePtr &node); + +std::set ForwardGraph(const FuncGraphPtr &root); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/strategy.h b/mindspore/ccsrc/frontend/parallel/strategy.h new file mode 100644 index 0000000000..ca01164a6a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/strategy.h @@ -0,0 +1,74 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ +#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +#define MIN_SLICE_NUM 1 + +using Dimensions = std::vector; + +class Strategy; +using StrategyPtr = std::shared_ptr; + +class Strategy { + public: + Strategy(int32_t stage, std::vector inputs) : stage_(stage), inputs_(std::move(inputs)) {} + ~Strategy() = default; + size_t GetInputNumber() const { return inputs_.size(); } + std::vector GetInputDim() const { return inputs_; } + int32_t GetInputStage() const { return stage_; } + void ExpandInputDimFromOneToTwo() { + if (inputs_.size() == 1) { + inputs_.push_back(inputs_[0]); + } + } + void ResetInputs(const std::vector &input) { inputs_ = input; } + + bool IsEqual(const StrategyPtr &another_stra) { + if (another_stra == nullptr) { + return false; + } + if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) { + return false; + } + return true; + } + + private: + const int32_t stage_; + + // The size of Dimensions must equal to inputs_ tensor dimension. + std::vector inputs_; +}; + +inline StrategyPtr NewStrategy(const int32_t stage, const std::vector &inputs) { + return std::make_shared(stage, inputs); +} +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc new file mode 100644 index 0000000000..bf7c4e29ab --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" + +#include +#include +#include + +#include "common/utils.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" +#include "proto/node_strategy.pb.h" + +namespace mindspore { +namespace parallel { +StrategyCheckpoint &StrategyCheckpoint::GetInstance() { + static StrategyCheckpoint instance = StrategyCheckpoint(); + if (ParallelContext::GetInstance() != nullptr) { + instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file(); + instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty(); + instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file(); + instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty(); + } + return instance; +} + +bool StrategyCheckpoint::CheckPointExit(const std::string path) const { + std::ifstream fin(path); + if (fin) { + return true; + } + return false; +} + +Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { + if (strategy_map == nullptr) { + MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr"; + } + if (!CheckPointExit(load_file_)) { + MS_LOG(EXCEPTION) << "CheckPoint file is not found"; + } + straspb::ParallelStrategyMap parallel_strategy_map; + std::fstream input(load_file_, std::ios::in | std::ios::binary); + if (!parallel_strategy_map.ParseFromIstream(&input)) { + MS_LOG(ERROR) << "Load strategy file failed"; + return FAILED; + } + size_t node_num = IntToSize(parallel_strategy_map.parallel_strategy_item_size()); + for (size_t i = 0; i < node_num; i++) { + straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i)); + std::string node_name = parallel_strategy_item.node_name(); + straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys(); + auto stage = (int32_t)parallel_strategys.stage(); + size_t strategys_num = IntToSize(parallel_strategys.parallel_strategy_size()); + std::vector> strategy_inputs; + for (size_t j = 0; j < strategys_num; j++) { + straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j)); + std::vector dimension; + size_t dim_num = IntToSize(parallel_strategy.dim_size()); + for (size_t k = 0; k < dim_num; k++) { + dimension.push_back(parallel_strategy.dim(SizeToInt(k))); + } + strategy_inputs.push_back(dimension); + } + + StrategyPtr strategy = NewStrategy(stage, strategy_inputs); + (*strategy_map)[node_name] = strategy; + current_stage_ = (int32_t)parallel_strategy_map.current_stage(); + } + return SUCCESS; +} + +Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { + straspb::ParallelStrategyMap parallel_strategy_map; + parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); + for (auto &node_stra : strategy_map) { + straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); + MS_EXCEPTION_IF_NULL(parallel_strategy_item); + parallel_strategy_item->set_node_name(node_stra.first); + straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); + MS_EXCEPTION_IF_NULL(parallel_strategys); + MS_EXCEPTION_IF_NULL(node_stra.second); + parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); + for (auto &dims : node_stra.second->GetInputDim()) { + straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); + MS_EXCEPTION_IF_NULL(parallel_strategy); + for (auto dim : dims) { + parallel_strategy->add_dim(IntToUint(dim)); + } + } + } + std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); + if (!parallel_strategy_map.SerializeToOstream(&output)) { + MS_LOG(ERROR) << "Save strategy file failed"; + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h new file mode 100644 index 0000000000..67cbb92ee2 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ +#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ + +#include +#include +#include "frontend/parallel/ops_info/ops_utils.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/context.h" + +namespace mindspore { +namespace parallel { +using StrategyMap = std::unordered_map; +class StrategyCheckpoint { + public: + StrategyCheckpoint() { + current_stage_ = 0; + load_file_ = ""; + load_checkpoint_on_ = false; + save_file_ = ""; + save_checkpoint_on_ = false; + } + ~StrategyCheckpoint() = default; + + Status Load(StrategyMap *strategy_map); + Status Save(const StrategyMap &strategy_map); + + static StrategyCheckpoint &GetInstance(); + bool LoadCheckPointOn() const { return load_checkpoint_on_; } + bool SaveCheckPointOn() const { return save_checkpoint_on_; } + + private: + std::string load_file_; + std::string save_file_; + bool load_checkpoint_on_; + bool save_checkpoint_on_; + bool CheckPointExit(const std::string path) const; + int32_t current_stage_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc new file mode 100644 index 0000000000..cff3d53a88 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc @@ -0,0 +1,248 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/arrangement.h" +#include +#include +#include +#include "common/utils.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/shape_util.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status Arrangement::Init(const std::vector &array) { + Status status = Array::Init(array); + if (status != Status::SUCCESS) { + return Status::FAILED; + } + if (!IsValidArrangement()) { + MS_LOG(ERROR) << "invalid arrangement " << this->ToString(); + return Status::FAILED; + } + ComputeSize(); + return Status::SUCCESS; +} + +bool Arrangement::IsValidArrangement() { + return !std::any_of(array_.begin(), array_.end(), [](int32_t value) { return value <= 0; }); +} + +void Arrangement::ComputeSize() { + size_ = 1; + for (auto &value : array_) { + size_ *= value; + } +} + +/* + * if GetDimSize() = 0, return [] + * if value <= array_[0], return [value] + * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]], + * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1], + * if value > size_, return [] + */ +std::vector Arrangement::GetFrontElementByValue(int32_t value) const { + std::vector out; + if (GetDimSize() == 0) { + return out; + } + if (value <= size_) { + int32_t size = 1; + uint32_t shape_list_idx = 0; + while (size < value) { + size *= array_[shape_list_idx]; + if (size <= value) { + out.push_back(array_[shape_list_idx]); + } else { + if (size == 0) { + MS_LOG(ERROR) << "The size is 0"; + out.clear(); + return out; + } + out.push_back(value * array_[shape_list_idx] / size); + } + shape_list_idx++; + } + } + return out; +} + +std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft( + const std::vector &expand_list) const { + if (expand_list.size() != GetDimSize()) { + return nullptr; + } + std::vector new_shape; + for (uint32_t i = 0; i < expand_list.size(); i++) { + std::vector expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i)); + if (expand_shape.empty()) { + new_shape.push_back(GetDimByIdx(i)); + } else { + (void)new_shape.insert(new_shape.end(), expand_shape.begin(), expand_shape.end()); + } + } + Arrangement arrangement_new; + (void)arrangement_new.Init(new_shape); + return std::make_shared(arrangement_new); +} + +/* + * example: + * expand_shape = [4, 2, 2, 2] + * array_ = [8, 4], + * arrangement_list = [[4, 2], [2, 2]] + */ +std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const { + int32_t size = 1; + uint32_t ind = 0; + std::vector arrangement_list; + std::vector shape; + for (uint32_t i = 0; i < expand_shape.GetDimSize(); i++) { + size *= expand_shape.GetDimByIdx(i); + if (size > GetDimByIdx(ind)) { + MS_LOG(ERROR) << "invalid expand_shape"; + return nullptr; + } else if (size < GetDimByIdx(ind)) { + shape.push_back(expand_shape.GetDimByIdx(i)); + continue; + } else { + shape.push_back(expand_shape.GetDimByIdx(i)); + Arrangement arrangement; + (void)arrangement.Init(shape); + arrangement_list.push_back(arrangement); + shape.clear(); + ind++; + size = 1; + } + } + if (ind != GetDimSize()) { + MS_LOG(ERROR) << "invalid expand_shape"; + return nullptr; + } + auto arrangement_new = std::make_shared>(arrangement_list); + return arrangement_new; +} + +std::shared_ptr, Arrangement>> Arrangement::GetExpandShapeListPair( + const Arrangement &expand_shape) const { + std::shared_ptr> expand_shape_list_ptr = GetExpandShapeList(expand_shape); + if (expand_shape_list_ptr == nullptr) { + return nullptr; + } + std::vector expand_num_list_shape; + (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(), + std::back_inserter(expand_num_list_shape), + [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); }); + Arrangement expand_num_list; + Status status = expand_num_list.Init(expand_num_list_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list); + return std::make_shared, Arrangement>>(out_value); +} + +std::vector Arrangement::ComputeReverseAccumulateSumInReverseOrder() const { + std::vector shape_accum; + int32_t size = 0; + for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) { + shape_accum.push_back(size); + size += *iter; + } + return shape_accum; +} + +std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLeft( + const std::vector &expand_list) const { + if (expand_list.size() != GetDimSize()) { + return nullptr; + } + std::vector new_shape; + for (uint32_t i = 0; i < expand_list.size(); i++) { + if (expand_list[i].GetDimSize() >= 1) { + int32_t size = 1; + for (uint32_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) { + new_shape.push_back(expand_list[i].GetDimByIdx(k)); + size *= expand_list[i].GetDimByIdx(k); + } + new_shape.push_back(GetDimByIdx(i) / size); + } else { + new_shape.push_back(GetDimByIdx(i)); + } + } + Arrangement arrangement_new; + (void)arrangement_new.Init(new_shape); + return std::make_shared(arrangement_new); +} + +std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement &in2) const { + std::vector in1_accum; + Status status = ShapeToAccumulateProduct(array_, &in1_accum); + if (status != Status::SUCCESS) { + return nullptr; + } + std::vector in2_accum; + status = ShapeToAccumulateProduct(in2.array(), &in2_accum); + if (status != Status::SUCCESS) { + return nullptr; + } + std::vector out_accum; + status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum); + if (status != Status::SUCCESS) { + return nullptr; + } + std::vector out_shape; + status = AccumulateProductToShape(out_accum, &out_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + Arrangement out; + status = out.Init(out_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(out); +} + +std::vector Arrangement::GetSqueezeIdx() const { + std::vector out; + for (size_t i = 0; i < GetDimSize(); i++) { + if (GetDimByIdx(SizeToUint(i)) == 1) { + out.push_back(i); + } + } + return out; +} + +Arrangement Arrangement::GetSqueezeArrangement() const { + std::vector out_shape(array_.size()); + auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int32_t value) { return value != 1; }); + out_shape.resize(LongToSize(std::distance(out_shape.begin(), it))); + + // if all elements are 1, out_shape = {1} + if (out_shape.empty()) { + MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation"; + out_shape.push_back(1); + } + Arrangement out; + (void)out.Init(out_shape); + return out; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h new file mode 100644 index 0000000000..ab807fb20a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ + +#include +#include +#include +#include +#include +#include +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/array.h" + +namespace mindspore { +namespace parallel { +class Arrangement : public Array { + public: + Arrangement() : size_(1) {} + ~Arrangement() override = default; + Status Init(const std::vector &array) override; + int32_t size() const { return size_; } + std::vector GetFrontElementByValue(int32_t value) const; + std::shared_ptr> GetExpandShapeList(const Arrangement &expand_shape) const; + std::vector ComputeReverseAccumulateSumInReverseOrder() const; + std::shared_ptr GetExpandedShapeByExpandListReserveLeft( + const std::vector &expand_list) const; + std::shared_ptr GetExpandedShapeByExpandListRemoveLeft( + const std::vector &expand_list) const; + std::shared_ptr, Arrangement>> GetExpandShapeListPair( + const Arrangement &expand_shape) const; + std::shared_ptr GetUnifiedShape(const Arrangement &in2) const; + std::vector GetSqueezeIdx() const; + Arrangement GetSqueezeArrangement() const; + + private: + bool IsValidArrangement(); + void ComputeSize(); + int32_t size_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc new file mode 100644 index 0000000000..4e1f467793 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/array.h" +#include +#include "frontend/parallel/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +std::string Array::ToString() const { + std::ostringstream buffer; + buffer << "[ "; + for (auto &element : array_) { + buffer << std::to_string(element) + " "; + } + buffer << "]"; + return buffer.str(); +} + +Status Array::Init(const std::vector &array) { + array_ = array; + return IsvalidArray() ? Status::SUCCESS : Status::FAILED; +} + +bool Array::IsvalidArray() const { return true; } + +int32_t Array::GetDimByIdx(uint32_t idx) const { + size_t mod_idx = idx; + if (idx >= GetDimSize()) { + MS_LOG(EXCEPTION) << "idx is " << idx << ", but array size is " << GetDimSize(); + } + return array_[mod_idx]; +} + +int32_t Array::GetDimByReverseIdx(uint32_t idx) const { + size_t mod_idx = idx; + if (idx >= GetDimSize()) { + MS_LOG(EXCEPTION) << "idx is " << idx << " but array size is " << GetDimSize(); + } + return array_[GetDimSize() - 1 - mod_idx]; +} + +bool Array::operator==(const Array &shape) const { + if (GetDimSize() != shape.GetDimSize()) { + return false; + } + for (uint32_t i = 0; i < GetDimSize(); i++) { + if (GetDimByIdx(i) != shape.GetDimByIdx(i)) { + return false; + } + } + return true; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h new file mode 100644 index 0000000000..13b3982a18 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ + +#include +#include +#include +#include +#include +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +class Array { + public: + Array() = default; + virtual ~Array() = default; + std::string ToString() const; + virtual Status Init(const std::vector &array); + bool IsvalidArray() const; + std::vector array() const { return array_; } + size_t GetDimSize() const { return array_.size(); } + int32_t GetDimByIdx(uint32_t idx) const; + int32_t GetDimByReverseIdx(uint32_t idx) const; + bool operator==(const Array &a1) const; + + protected: + std::vector array_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc new file mode 100644 index 0000000000..9395d3df89 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc @@ -0,0 +1,254 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/construct_operator.h" + +#include +#include + +namespace mindspore { +namespace parallel { +Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) { + dev_size_ = dev_matrix_shape.size(); + dev_matrix_shape_ = dev_matrix_shape; + dev_list_ = dev_list; + return Status::SUCCESS; +} + +Status ConstructOperator::ReshapeOP(Shape shape) { + int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies()); + if (prod != prod_expect) { + ValuePtr ptr = MakeValue(shape); + MS_EXCEPTION_IF_NULL(ptr); + MS_LOG(ERROR) << "Invalid tensor shape " << ptr->ToString() << "when construct Reshape operator!"; + return Status::INVALID_ARGUMENT; + } + OperatorAttrs attrs; + ValuePtr param_value = MakeValue(shape); + Attr param = std::make_pair(SHAPE, param_value); + OperatorParams params = {std::make_pair(param, 2)}; + OperatorArgs args = std::make_pair(attrs, params); + op_ = std::make_pair(RESHAPE, args); + return Status::SUCCESS; +} + +Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &end, const Shape &strides) { + ValuePtr attr_value = MakeValue(value); + Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value); + Attr attr_end_mask = std::make_pair(END_MASK, attr_value); + Attr attr_ellipsis_mask = std::make_pair(ELLIPSIS_MASK, attr_value); + Attr attr_new_axis_mask = std::make_pair(NEW_AXIS_MASK, attr_value); + Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value); + OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask}; + + ValuePtr param_begin_value = MakeValue(begin); + Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2); + ValuePtr param_end_value = MakeValue(end); + Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3); + + ValuePtr param_strides_value = MakeValue(strides); + Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4); + OperatorParams params = {param_begin, param_end, param_strides}; + OperatorArgs op_args = std::make_pair(attrs, params); + + return std::make_pair(STRIDED_SLICE, op_args); +} + +Status ConstructOperator::StridedSliceOP(Args args) { + if (args.size() < 3) { + MS_LOG(ERROR) << "args size should not be less than 3!"; + return Status::FAILED; + } + int32_t split_count = args[0]; + if (split_count <= 0) { + MS_LOG(ERROR) << "split_count should not be less than 0!"; + return Status::FAILED; + } + int32_t split_dim = args[1]; + int32_t dev_dim = args[2]; + std::vector group_list; + + if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { + MS_LOG(ERROR) << "stride slice op: create group failed"; + return FAILED; + } else if (group_list.empty()) { // this group only has one device, don't need do StridedSlice + MS_LOG(INFO) << "no need stride slice op"; + return SUCCESS; + } + + Group group = group_list[0]; + size_t rank; + if (group.GetIndex(&rank) == Status::FAILED) { + return Status::FAILED; + } + size_t size = tensor_shape_.size(); + Shape begin(size); + Shape end(size); + Shape strides(size, 1); + size_t index = 0; + for (auto num : tensor_shape_) { + if (index != IntToSize(split_dim)) { + begin[index] = 0; + end[index] = num; + } else { + if (num % split_count != 0) { + MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim + << "! when construct StridedSlice operator"; + return Status::INVALID_ARGUMENT; + } + int32_t count = num / split_count; + begin[index] = SizeToInt(rank) * count; + end[index] = (SizeToInt(rank) + 1) * count; + } + index++; + } + + op_ = CreateStridedSliceOp(DEFAULT, begin, end, strides); + + return Status::SUCCESS; +} + +Status ConstructOperator::AllGatherOP(int32_t dev_dim) { + if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { + MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AllGather operator!"; + return Status::INVALID_ARGUMENT; + } + + std::vector group_list; + if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { + MS_LOG(ERROR) << "AllGather op: create group failed"; + return FAILED; + } else if (group_list.empty()) { // this group only has one device, don't need do allgather + MS_LOG(INFO) << "no need all gather op"; + return SUCCESS; + } + + std::string group_name = group_list[0].name(); + ValuePtr attr_value = MakeValue(group_name); + Attr attr = std::make_pair(GROUP, attr_value); + OperatorAttrs attrs = {attr}; + OperatorParams params; + OperatorArgs args = std::make_pair(attrs, params); + op_ = std::make_pair(ALL_GATHER, args); + return Status::SUCCESS; +} + +Status ConstructOperator::ConcatOP(int32_t concat_dim) { + if (IntToSize(concat_dim) >= tensor_shape_.size()) { + MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when construct Concat operator!"; + return Status::INVALID_ARGUMENT; + } + ValuePtr attr_value = MakeValue(concat_dim); + Attr attr = std::make_pair(AXIS, attr_value); + OperatorAttrs attrs = {attr}; + OperatorParams params; + OperatorArgs args = std::make_pair(attrs, params); + op_ = std::make_pair(CONCAT, args); + return Status::SUCCESS; +} + +Status ConstructOperator::SplitOP(int32_t split_count) { + if (split_count <= 0) { + MS_LOG(ERROR) << "Invalid split count when construct Split operator!"; + return Status::FAILED; + } + OperatorAttrs attrs; + ValuePtr attr_value_axis = MakeValue(DEFAULT); + Attr attr_axis = std::make_pair(AXIS, attr_value_axis); + ValuePtr attr_value_split = MakeValue(split_count); + Attr attr_split = std::make_pair(OUTPUT_NUM, attr_value_split); + attrs = {attr_axis, attr_split}; + OperatorParams params; + OperatorArgs args = std::make_pair(attrs, params); + op_ = std::make_pair(SPLIT, args); + return Status::SUCCESS; +} + +Status ConstructOperator::AlltoAllOP(Args args) { + if (args.size() < 4) { + MS_LOG(ERROR) << "args size should not be less than 4!"; + return Status::FAILED; + } + int32_t split_count = args[0]; + int32_t split_dim = args[1]; + int32_t concat_dim = args[2]; + int32_t dev_dim = args[3]; + if (split_count <= 0) { + MS_LOG(ERROR) << "Invalid split count when construct AlltoAll operator!"; + return Status::FAILED; + } + if (tensor_shape_[IntToSize(split_dim)] % split_count != 0) { + MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim + << "when construct AlltoAll operator!"; + return Status::INVALID_ARGUMENT; + } + if (IntToSize(concat_dim) >= tensor_shape_.size()) { + MS_LOG(ERROR) << "Invalid split count " << split_count << " when construct AlltoAll operator!"; + return Status::INVALID_ARGUMENT; + } + if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { + MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AlltoAll operator!"; + return Status::INVALID_ARGUMENT; + } + + std::vector group_list; + if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { + MS_LOG(ERROR) << "AlltoAll op: create group failed"; + return FAILED; + } else if (group_list.empty()) { // this group only has one device, don't need do alltoall + MS_LOG(INFO) << "no need all to all op"; + return SUCCESS; + } + + std::string group_name = group_list[0].name(); + ValuePtr attr_value_group = MakeValue(group_name); + Attr attr_group = std::make_pair(GROUP, attr_value_group); + ValuePtr attr_value_split_count = MakeValue(split_count); + Attr attr_split_count = std::make_pair(SPLIT_COUNT, attr_value_split_count); + ValuePtr attr_value_split_dim = MakeValue(split_dim); + Attr attr_split_dim = std::make_pair(SPLIT_DIM, attr_value_split_dim); + ValuePtr attr_value_concat_dim = MakeValue(concat_dim); + Attr attr_concat_dim = std::make_pair(CONCAT_DIM, attr_value_concat_dim); + OperatorAttrs attrs = {attr_split_count, attr_split_dim, attr_concat_dim, attr_group}; + OperatorParams params; + OperatorArgs op_args = std::make_pair(attrs, params); + op_ = std::make_pair(ALL_TO_ALL, op_args); + return Status::SUCCESS; +} + +Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector *group) { + MS_EXCEPTION_IF_NULL(group); + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + int32_t rank = g_device_manager->global_rank(); + DeviceMatrix dev_matrix(rank, dev_list_, dev_matrix_shape_); + RankList group_devices; + if (dev_matrix.GetDevicesAlongDim(SizeToUint(axis), &group_devices) != SUCCESS) { + return FAILED; + } + // this group only has one device, don't need create the group + if (group_devices.size() == 1) { + MS_LOG(INFO) << "the group is empty"; + return SUCCESS; + } + + Group g = g_device_manager->CreateGroup(group_devices); + group->push_back(g); + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h new file mode 100644 index 0000000000..b06d70af36 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +using Args = std::vector; + +class ConstructOperator { + public: + const int32_t DEFAULT = 0; + ConstructOperator() : dev_size_(0) {} + ~ConstructOperator() = default; + Status Init(const RankList &dev_list, const Shape &dev_matrix_shape); + Status ReshapeOP(Shape shape); + Status StridedSliceOP(Args args); + Status AllGatherOP(int32_t dev_dim); + Status SplitOP(int32_t split_count); + Status ConcatOP(int32_t concat_dim); + Status AlltoAllOP(Args args); + Operator GetOperator() const { return op_; } + void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; } + + private: + Operator op_; + size_t dev_size_; + Shape tensor_shape_; + RankList dev_list_; + Shape dev_matrix_shape_; + Status CreateGroupByDim(size_t axis, std::vector *group); +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc new file mode 100644 index 0000000000..d5d34a484f --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/layout_transfer.h" +#include "common/utils.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +std::string LayoutTransfer::ToString() const { + std::ostringstream buffer; + buffer << std::endl << std::string("from_in_ tensor layout:" + from_in_.ToString()); + buffer << std::endl << std::string("to_in_ tensor layout:" + to_in_.ToString()); + return buffer.str(); +} + +LayoutTransfer::~LayoutTransfer() = default; + +Status LayoutTransfer::Init(const TensorLayout &from_in, const TensorLayout &to_in) { + from_in_ = from_in; + to_in_ = to_in; + MS_LOG(DEBUG) << "LayoutTransfer " << this->ToString(); + Status status = CheckValidTransfer(); + return status; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h new file mode 100644 index 0000000000..01c56fc7cf --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ + +#include +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" + +namespace mindspore { +namespace parallel { +class LayoutTransfer { + public: + LayoutTransfer() = default; + virtual ~LayoutTransfer() = 0; + std::string ToString() const; + Status Init(const TensorLayout &from_in, const TensorLayout &to_in); + TensorLayout from_in() const { return from_in_; } + TensorLayout to_in() const { return to_in_; } + + protected: + bool IsSameTensorShape() const { return from_in_.IsSameTensorShape(to_in_); } + bool IsSameDeviceArrangement() const { return from_in_.IsSameDeviceArrangement(to_in_); } + + TensorLayout from_in_; + TensorLayout to_in_; + + private: + virtual Status CheckValidTransfer() = 0; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc new file mode 100644 index 0000000000..184f0c7530 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc @@ -0,0 +1,171 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/map.h" +#include +#include +#include +#include "common/utils.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/shape_util.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status Map::Init(const std::vector &array) { + Status status = Array::Init(array); + if (status != Status::SUCCESS) { + return Status::FAILED; + } + if (!IsValidMap()) { + MS_LOG(ERROR) << "invalid map " << this->ToString(); + return Status::FAILED; + } + return Status::SUCCESS; +} + +bool Map::IsValidMap() { + if (std::any_of(array_.begin(), array_.end(), [](int32_t value) { return ((value < 0) && (value != MAP_NONE)); })) { + return false; + } + // check that all none -1 value in array_ is different + std::vector sorted_array = array_; + std::sort(sorted_array.begin(), sorted_array.end()); + int32_t value = MAP_NONE; + for (auto &element : sorted_array) { + if (element == MAP_NONE) { + continue; + } + if (element == value) { + return false; + } + value = element; + } + return true; +} + +int32_t Map::GetMaxItem() const { + if (!array_.empty()) { + return *std::max_element(array_.begin(), array_.end()); + } else { + return MAP_NONE; + } +} + +int32_t Map::GetIndexByValue(int32_t value) const { + auto iter = find(array_.begin(), array_.end(), value); + if (iter != array_.end()) { + return static_cast(std::distance(array_.begin(), iter)); + } else { + return MAP_NONE; + } +} + +/* + * expand.size() should be equal to array_.size() + */ +std::shared_ptr Map::ExpandMapByNone(const Arrangement &expand_num_list) const { + if (expand_num_list.GetDimSize() != GetDimSize()) { + return nullptr; + } + std::vector new_shape; + for (uint32_t i = 0; i != GetDimSize(); i++) { + if (GetDimByIdx(i) == MAP_NONE) { + for (int32_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) { + new_shape.push_back(MAP_NONE); + } + } else { + new_shape.push_back(GetDimByIdx(i)); + int32_t j = 1; + while (j < expand_num_list.GetDimByIdx(i)) { + new_shape.push_back(MAP_NONE); + j++; + } + } + } + auto map_new = std::make_shared(); + (void)map_new->Init(new_shape); + return map_new; +} + +/* + * expand.size() should be equal to array_.size() + */ +std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const { + if (GetMaxItem() >= static_cast(expand_num_list.GetDimSize())) { + return nullptr; + } + std::vector new_shape; + for (uint32_t i = 0; i < GetDimSize(); i++) { + if (GetDimByIdx(i) == MAP_NONE) { + new_shape.push_back(MAP_NONE); + } else { + int32_t start_map = + expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast(GetDimByIdx(i))]; + for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast(GetDimByIdx(i))) - 1; k >= 0; k--) { + new_shape.push_back(k + start_map); + } + } + } + auto map_new = std::make_shared(); + (void)map_new->Init(new_shape); + return map_new; +} + +std::shared_ptr> Map::ReMapVector(const std::vector &input_vector) const { + if (GetMaxItem() >= static_cast(input_vector.size())) { + return nullptr; + } + std::vector out; + Arrangement empty_arrangement; + for (uint32_t i = 0; i < GetDimSize(); i++) { + if (GetDimByIdx(i) == MAP_NONE) { + out.push_back(empty_arrangement); + } else { + out.push_back(input_vector[IntToUint(SizeToInt(input_vector.size()) - 1 - GetDimByIdx(i))]); + } + } + return std::make_shared>(out); +} + +bool Map::CheckNoneByIdxList(std::vector idx_list) const { + for (auto &value : idx_list) { + if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) { + return false; + } + } + return true; +} + +Map Map::SqueezeMapByIdxList(std::vector idx_list) const { + std::vector out_shape; + for (size_t i = 0; i < GetDimSize(); i++) { + auto it = std::find(idx_list.begin(), idx_list.end(), i); + if (it == idx_list.end()) { + out_shape.push_back(GetDimByIdx(SizeToUint(i))); + } + } + if (out_shape.empty()) { + MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation"; + out_shape.push_back(MAP_NONE); + } + Map out; + (void)out.Init(out_shape); + return out; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h new file mode 100644 index 0000000000..3d299d4b90 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ + +#include +#include +#include +#include +#include +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/arrangement.h" +#include "frontend/parallel/tensor_layout/array.h" + +namespace mindspore { +namespace parallel { +constexpr int32_t MAP_NONE = -1; + +class Map : public Array { + public: + Map() = default; + ~Map() override = default; + Status Init(const std::vector &array) override; + int32_t GetMaxItem() const; + int32_t GetIndexByValue(int32_t value) const; + std::shared_ptr ExpandMapByNone(const Arrangement &expand_num_list) const; + std::shared_ptr ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const; + std::shared_ptr> ReMapVector(const std::vector &input_vector) const; + bool CheckNoneByIdxList(std::vector idx_list) const; + Map SqueezeMapByIdxList(std::vector idx_list) const; + + private: + bool IsValidMap(); +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc new file mode 100644 index 0000000000..a5a488d807 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" +#include "frontend/parallel/tensor_layout/shape_util.h" + +namespace mindspore { +namespace parallel { +Status RedistributionLayoutTransfer::CheckValidTransfer() { return Status::SUCCESS; } + +/* + * unify device arrangement between in_layout and out_layout + * after this function is called, + * in_step1_layout.device_arrangement and out_step1_layout.device_arrangement will be the same + */ +std::shared_ptr RedistributionLayoutTransfer::UnifyDeviceArrangement() const { + Arrangement in_arrangement; + Arrangement out_arrangement; + in_arrangement = from_in_.device_arrangement(); + out_arrangement = to_in_.device_arrangement(); + std::shared_ptr unify_arrangement_ptr = in_arrangement.GetUnifiedShape(out_arrangement); + if (unify_arrangement_ptr == nullptr) { + return nullptr; + } + std::shared_ptr from_out_ptr = from_in_.ExpandDeviceArrangement(*unify_arrangement_ptr); + if (from_out_ptr == nullptr) { + return nullptr; + } + std::shared_ptr to_out_ptr = to_in_.ExpandDeviceArrangement(*unify_arrangement_ptr); + if (to_out_ptr == nullptr) { + return nullptr; + } + ReshapeLayoutTransfer out; + Status status = out.Init(*from_out_ptr, *to_out_ptr); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(out); +} + +/* + * unify tensor shape between in_step1_layout.tensor_shape and out_step1_layout.tensor_shape + * after this function is called, + * in_step2_layout.tensor_shape and out_step2_layout.tensor_shape will be the same + */ +std::shared_ptr RedistributionLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { + std::shared_ptr unified_device_arrangement_ptr = UnifyDeviceArrangement(); + if (unified_device_arrangement_ptr == nullptr) { + return nullptr; + } + return unified_device_arrangement_ptr->UnifyDeviceArrangementAndTensorShape(); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h new file mode 100644 index 0000000000..0347b6423a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ + +#include +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/layout_transfer.h" +#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" + +namespace mindspore { +namespace parallel { +class RedistributionLayoutTransfer : public LayoutTransfer { + public: + RedistributionLayoutTransfer() = default; + ~RedistributionLayoutTransfer() override = default; + std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; + + private: + Status CheckValidTransfer() override; + std::shared_ptr UnifyDeviceArrangement() const; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc new file mode 100644 index 0000000000..6ac24418b7 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc @@ -0,0 +1,289 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h" + +#include + +#include "frontend/parallel/device_manager.h" + +namespace mindspore { +namespace parallel { +Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, + RankList dev_list, bool is_cost_model) { + in_tensor_map_ = tensor_layout.tensor_map(); + dev_mat_ = tensor_layout.device_arrangement(); + + if (in_tensor_map_.GetDimSize() == 0 || out_tensor_map.GetDimSize() != in_tensor_map_.GetDimSize()) { + MS_LOG(ERROR) << "Invalid input when initialize RedistributionOperatorInfer!"; + return Status::FAILED; + } + + cur_tensor_layout_ = tensor_layout; + out_tensor_map_ = out_tensor_map; + dev_list_ = std::move(dev_list); + + operator_list_.clear(); + operator_vector_.clear(); + output_info_vector_.clear(); + + if (constructor_.Init(dev_list_, dev_mat_.array()) != Status::SUCCESS) { + MS_LOG(ERROR) << "Init constructor failed"; + return Status::FAILED; + } + constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); + + size_t key = 0; + std::vector map = in_tensor_map_.array(); + for (int32_t item : map) { + map_[key++] = item; + } + + is_cost_model_ = is_cost_model; + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::InferRedistributionOperator() { + while (!map_.empty()) { + size_t len_global = operator_list_.size(); + + while (!map_.empty()) { + size_t len_split_by_axis = operator_list_.size(); + // split_by_axis operation + if (InferSplitByAxis() == Status::FAILED) { + return Status::FAILED; + } + // permute_by_axis operation + while (!map_.empty()) { + size_t len_permute_by_axis = operator_list_.size(); + if (InferPermuteByAxis() == Status::FAILED) { + return Status::FAILED; + } + if (len_permute_by_axis == operator_list_.size()) break; + } + if (len_split_by_axis == operator_list_.size()) break; + } + // concat_by_axis operation + if (InferConcatByAxis() == Status::FAILED) { + return Status::FAILED; + } + // break loop structure with concat_by_axis + if (len_global == operator_list_.size() && !map_.empty()) { + size_t index = map_.begin()->first; + int32_t in_dim = map_[index]; + map_[index] = NONE; + Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; + if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { + return Status::FAILED; + } + } + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::InferSplitByAxis() { + for (auto iter = map_.begin(); iter != map_.end();) { + uint32_t index = iter->first; + int32_t in_dim = iter->second; + int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + if (in_dim == out_dim) { + (void)map_.erase(iter++); + continue; + } + if (in_dim == NONE && + !std::any_of(map_.begin(), map_.end(), + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { + Args args = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; + if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) { + MS_LOG(ERROR) << "Insert SplitByAxis Error!"; + return Status::FAILED; + } + (void)map_.erase(iter++); + } else { + (void)++iter; + } + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::InferPermuteByAxis() { + for (auto iter = map_.begin(); iter != map_.end();) { + uint32_t index = iter->first; + int32_t in_dim = map_[index]; + int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + if (in_dim == out_dim) { + (void)map_.erase(iter++); + continue; + } + if (in_dim == NONE && + std::any_of(map_.begin(), map_.end(), + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { + int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); + int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); + if (is_cost_model_) { + int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); + Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, + dev_num}; + if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { + MS_LOG(ERROR) << "Insert PermuteByAxis Error!"; + return Status::FAILED; + } + } else { + Args args_allconcat = {cat_dim, out_dim, dev_num}; + Args args_allsplit = {dev_num, UintToInt(index), out_dim}; + if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { + MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; + return Status::FAILED; + } + if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { + MS_LOG(ERROR) << "Insert SplitByAxis Error!"; + return Status::FAILED; + } + } + (void)map_.erase(iter++); + map_[IntToSize(cat_dim)] = NONE; + } else { + (void)++iter; + } + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::InferConcatByAxis() { + for (auto iter = map_.begin(); iter != map_.end();) { + uint32_t index = iter->first; + int32_t in_dim = map_[index]; + int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) { + Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; + if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { + MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; + return Status::FAILED; + } + if (out_dim == NONE) { + (void)map_.erase(iter++); + } else { + map_[index] = NONE; + (void)++iter; + } + } else { + (void)++iter; + } + } + return Status::SUCCESS; +} + +// Transfer communicative operators into primitives and insert them into vector +Status RedistributionOperatorInfer::InsertOperator(OperatorName name, Args args) { + OperatorR op = std::make_pair(name, args); + OperatorC op_cost = std::make_pair(op, cur_tensor_layout_.slice_shape().array()); + operator_list_.push_back(op_cost); + if (construct_op_flag_) { + if (name == SPLIT_BY_AXIS) { + if (TransferSplitByAxis(args) == Status::FAILED) { + return Status::FAILED; + } + } else if (name == PERMUTE_BY_AXIS) { + if (TransferPermuteByAxis(args) == Status::FAILED) { + return Status::FAILED; + } + } else { + if (TransferConcatByAxis(args) == Status::FAILED) { + return Status::FAILED; + } + } + constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::TransferSplitByAxis(Args args) { + if (args.size() < 3) { + MS_LOG(ERROR) << "args size should not be less than 3!"; + return Status::FAILED; + } + uint32_t index = IntToUint(args[1]); + if (constructor_.StridedSliceOP(args) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(false, 0)); + } + if (cur_tensor_layout_.UpdateTensorMap(index, args[2]) == Status::FAILED) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) { + if (args.size() < 3) { + MS_LOG(ERROR) << "args size should not be less than 3!"; + return Status::FAILED; + } + if (constructor_.AlltoAllOP(args) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(false, 0)); + } + uint32_t index = IntToUint(args[1]); + int32_t val = args[2]; + int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + + if (cur_tensor_layout_.UpdateTensorMap(IntToUint(val), NONE) == Status::FAILED) { + return Status::FAILED; + } + if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) { + if (args.size() < 3) { + MS_LOG(ERROR) << "args size should not be less than 3!"; + return Status::FAILED; + } + int32_t tensor_dim = args[0]; + int32_t dev_dim = args[1]; + int32_t split_count = args[2]; + if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(false, 0)); + } + if (tensor_dim != 0) { + if (constructor_.SplitOP(split_count) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(true, split_count)); + } + if (constructor_.ConcatOP(tensor_dim) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(false, 0)); + } + } + if (cur_tensor_layout_.UpdateTensorMap(IntToUint(tensor_dim), NONE) == Status::FAILED) { + return Status::FAILED; + } + return Status::SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h new file mode 100644 index 0000000000..66cdb3f925 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/tensor_layout/construct_operator.h" +#include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h" +#include "utils/convert_utils.h" +namespace mindspore { +namespace parallel { +using DeviceArrangement = std::vector; +using TensorMap = std::vector; +using TensorShape = std::vector; +using RedistributionOperatorMap = std::unordered_map; +using OperatorR = std::pair; +using OperatorC = std::pair; +using OperatorList = std::vector; + +class RedistributionOperatorInfer { + public: + const int NONE = -1; + explicit RedistributionOperatorInfer(bool construct_op_flag = true) + : construct_op_flag_(construct_op_flag), is_cost_model_(false) {} + Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, + bool is_cost_model = false); + ~RedistributionOperatorInfer() = default; + OperatorList operator_list() const { return operator_list_; } + OperatorVector operator_vector() const { return operator_vector_; } + OutPutInfoVector output_info_vector() const { return output_info_vector_; } + Status InferRedistributionOperator(); + + private: + Status InferSplitByAxis(); + Status InferPermuteByAxis(); + Status InferConcatByAxis(); + Status TransferSplitByAxis(Args args); + Status TransferPermuteByAxis(Args args); + Status TransferConcatByAxis(Args args); + Status InsertOperator(OperatorName name, Args args); + + OperatorList operator_list_; + OperatorVector operator_vector_; + OutPutInfoVector output_info_vector_; + Arrangement dev_mat_; + RedistributionOperatorMap map_; + Map in_tensor_map_; + Map out_tensor_map_; + TensorLayout cur_tensor_layout_; + ConstructOperator constructor_; + RankList dev_list_; + bool construct_op_flag_; + bool is_cost_model_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc new file mode 100644 index 0000000000..98f7cf78fa --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/shape_util.h" + +namespace mindspore { +namespace parallel { +Status ReshapeLayoutTransfer::CheckValidTransfer() { + if (!IsSameDeviceArrangement()) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +std::shared_ptr ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { + bool is_unified = IsSameTensorShape(); + std::shared_ptr out_layout_ptr = std::make_shared(*this); + if (out_layout_ptr == nullptr) { + return nullptr; + } + while (!is_unified) { + std::shared_ptr temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo(); + if (temp_layout_ptr == nullptr) { + return nullptr; + } + out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom(); + if (out_layout_ptr == nullptr) { + return nullptr; + } + is_unified = out_layout_ptr->IsSameTensorShape(); + } + return out_layout_ptr; +} + +std::shared_ptr ReshapeLayoutTransfer::ExtendFromTensorShapeByTo() const { + std::shared_ptr out_ptr = std::make_shared(*this); + bool is_expanded = FromTensorShapeCanBeExpandByTo(); + while (!is_expanded) { + out_ptr = out_ptr->ExtendFromTensorShapeByExpandedTensorShape(); + if (out_ptr == nullptr) { + return nullptr; + } + is_expanded = out_ptr->FromTensorShapeCanBeExpandByTo(); + } + return out_ptr; +} + +std::shared_ptr ReshapeLayoutTransfer::ExtendToTensorShapeByFrom() const { + std::shared_ptr out_ptr = std::make_shared(*this); + bool is_expanded = ToTensorShapeCanBeExpandByFrom(); + while (!is_expanded) { + out_ptr = out_ptr->ExtendToTensorShapeByExpandedTensorShape(); + if (out_ptr == nullptr) { + return nullptr; + } + is_expanded = out_ptr->ToTensorShapeCanBeExpandByFrom(); + } + return out_ptr; +} + +bool ReshapeLayoutTransfer::FromTensorShapeCanBeExpandByTo() const { + return from_in_.TensorShapeCanBeExpanded(to_in_.tensor_shape()); +} + +bool ReshapeLayoutTransfer::ToTensorShapeCanBeExpandByFrom() const { + return to_in_.TensorShapeCanBeExpanded(from_in_.tensor_shape()); +} + +std::shared_ptr ReshapeLayoutTransfer::ExtendFromTensorShapeByExpandedTensorShape() const { + std::shared_ptr expanded_shape_ptr = ComputeExpandedFromTensorShapeByTo(); + if (expanded_shape_ptr == nullptr) { + return nullptr; + } + return ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); +} + +std::shared_ptr ReshapeLayoutTransfer::ExtendToTensorShapeByExpandedTensorShape() const { + std::shared_ptr exchanged_from_and_to_ptr = ExchangeFromAndTo(); + if (exchanged_from_and_to_ptr == nullptr) { + return nullptr; + } + std::shared_ptr expanded_shape_ptr = exchanged_from_and_to_ptr->ComputeExpandedFromTensorShapeByTo(); + if (expanded_shape_ptr == nullptr) { + return nullptr; + } + std::shared_ptr exchanged_out = + exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); + if (exchanged_out == nullptr) { + return nullptr; + } + return exchanged_out->ExchangeFromAndTo(); +} + +std::shared_ptr ReshapeLayoutTransfer::ExchangeFromAndTo() const { + ReshapeLayoutTransfer out; + Status status = out.Init(to_in_, from_in_); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(out); +} + +std::shared_ptr ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement( + const Arrangement &expand_shape) const { + std::shared_ptr extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape); + if (extend_tensor_shape_from_ptr == nullptr) { + return nullptr; + } + Arrangement unified_device_arrangement = extend_tensor_shape_from_ptr->device_arrangement(); + std::shared_ptr extend_device_arrangement_to_ptr = + to_in_.ExpandDeviceArrangement(unified_device_arrangement); + if (extend_device_arrangement_to_ptr == nullptr) { + return nullptr; + } + ReshapeLayoutTransfer out; + Status status = out.Init(*extend_tensor_shape_from_ptr, *extend_device_arrangement_to_ptr); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(out); +} + +std::shared_ptr ReshapeLayoutTransfer::ComputeExpandedFromTensorShapeByTo() const { + return from_in_.ComputeExpandedTensorShape(to_in_.tensor_shape()); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h new file mode 100644 index 0000000000..f9ebe9e32b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ + +#include +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/layout_transfer.h" + +namespace mindspore { +namespace parallel { +class ReshapeLayoutTransfer : public LayoutTransfer { + public: + ReshapeLayoutTransfer() = default; + ~ReshapeLayoutTransfer() override = default; + std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; + std::shared_ptr ExtendFromTensorShapeByTo() const; + std::shared_ptr ExtendToTensorShapeByFrom() const; + std::shared_ptr ExtendFromTensorShapeByExpandedTensorShape() const; + std::shared_ptr ExtendToTensorShapeByExpandedTensorShape() const; + std::shared_ptr ExpandFromTensorShapeAndExpandToDeviceArrangement( + const Arrangement &expand_shape) const; + std::shared_ptr ExchangeFromAndTo() const; + + private: + Status CheckValidTransfer() override; + std::shared_ptr ComputeExpandedFromTensorShapeByTo() const; + bool FromTensorShapeCanBeExpandByTo() const; + bool ToTensorShapeCanBeExpandByFrom() const; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc new file mode 100644 index 0000000000..83282d16b3 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc @@ -0,0 +1,263 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/shape_util.h" +#include +#include "frontend/parallel/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +/* + * example: + * shape = [2, 8, 32] + * shape_accum = [2, 2 * 8, 2 * 8 * 32] + */ +Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum) { + MS_EXCEPTION_IF_NULL(shape_accum); + shape_accum->clear(); + int64_t size = 1; + for (auto iter = shape.begin(); iter < shape.end(); ++iter) { + size *= *iter; + if (size <= 0) { + MS_LOG(ERROR) << "element of shape should not be zero"; + return Status::FAILED; + } + shape_accum->push_back(size); + } + return Status::SUCCESS; +} + +/* + * example: + * shape = [2, 8, 32] + * shape_accum = [2 * 8 * 32, 8 * 32, 32] + * + */ +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum) { + MS_EXCEPTION_IF_NULL(shape_accum); + shape_accum->clear(); + int64_t size = 1; + for (auto iter = shape.end() - 1; iter >= shape.begin(); --iter) { + size *= *iter; + if (size <= 0) { + MS_LOG(ERROR) << "element of shape should not be zero"; + return Status::FAILED; + } + (void)shape_accum->insert(shape_accum->begin(), size); + } + return Status::SUCCESS; +} + +/* + * example: + * shape_accum = [2, 2 * 8, 2 * 8 * 32] + * shape = [2, 8, 32] + * + */ +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape) { + MS_EXCEPTION_IF_NULL(shape); + shape->clear(); + int64_t value = 1; + for (auto iter = shape_accum.begin(); iter < shape_accum.end(); ++iter) { + if ((*iter) == 0) { + MS_LOG(ERROR) << "element of shape_accum should not be zero"; + return Status::FAILED; + } + if ((*iter) % value != 0) { + MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; + return Status::FAILED; + } + shape->push_back(static_cast((*iter) / value)); + value = (*iter); + } + return Status::SUCCESS; +} + +/* + * example: + * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * shape = [2, 8, 32] + */ +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape) { + MS_EXCEPTION_IF_NULL(shape); + shape->clear(); + int64_t value = 1; + for (auto iter = shape_accum_reverse.end() - 1; iter >= shape_accum_reverse.begin(); --iter) { + if (*iter == 0) { + MS_LOG(ERROR) << "element of shape_accum should not be zero"; + return Status::FAILED; + } + if ((*iter) % value != 0) { + MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; + return Status::FAILED; + } + (void)shape->insert(shape->begin(), static_cast((*iter) / value)); + value = *iter; + } + return Status::SUCCESS; +} + +/* + * example1: + * in1 = [2, 8] + * in2 = [4, 8] + * *out = [2, 4, 8] + * + * example2: + * in1 = [2, 4, 16] + * in2 = [8, 16] + * *out = [2, 4, 8, 16] + */ +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum) { + MS_EXCEPTION_IF_NULL(out_accum); + out_accum->clear(); + auto in1_iter = in1_accum.begin(); + auto in2_iter = in2_accum.begin(); + while ((in1_iter < in1_accum.end()) || (in2_iter < in2_accum.end())) { + if ((*in1_iter <= 0) || (*in2_iter <= 0)) { + MS_LOG(ERROR) << "element of in1 and in2 must be larger than zero"; + return Status::FAILED; + } + if (*in1_iter < *in2_iter) { + out_accum->push_back(*in1_iter); + ++in1_iter; + continue; + } else if (*in1_iter == *in2_iter) { + out_accum->push_back(*in1_iter); + ++in1_iter; + ++in2_iter; + } else { + out_accum->push_back(*in2_iter); + ++in2_iter; + } + } + if ((in1_iter != in1_accum.end()) || (in2_iter != in2_accum.end())) { + MS_LOG(ERROR) << "last element of in1 and in2 must be equal"; + return Status::FAILED; + } + return Status::SUCCESS; +} + +/* + * example: + * in1 = [8, 4] + * in2 = [2, 16] + * out = [2, 4, 4] + */ +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out) { + MS_EXCEPTION_IF_NULL(out); + std::vector in1_accum; + Status status = ShapeToAccumulateProduct(in1, &in1_accum); + if (status != Status::SUCCESS) { + return status; + } + std::vector in2_accum; + status = ShapeToAccumulateProduct(in2, &in2_accum); + if (status != Status::SUCCESS) { + return status; + } + std::vector out_accum; + status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum); + if (status != Status::SUCCESS) { + return status; + } + status = AccumulateProductToShape(out_accum, out); + if (status != Status::SUCCESS) { + return status; + } + return status; +} + +/* + * example1: + * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * expand_accum_reverse = [2 * 8 * 32, 32, 8] + * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8] + * + * example2: + * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] + * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] + */ +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse) { + MS_EXCEPTION_IF_NULL(out_accum_reverse); + out_accum_reverse->clear(); + auto in_riter = in_accum_reverse.rbegin(); + auto expand_riter = expand_accum_reverse.rbegin(); + while (expand_riter != expand_accum_reverse.rend()) { + if (in_riter == in_accum_reverse.rend()) { + MS_LOG(ERROR) << "invalid ExpandAccumProd inputs"; + return Status::FAILED; + } + if (*in_riter > *expand_riter) { + (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter); + ++expand_riter; + } else if (*in_riter == *expand_riter) { + (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter); + ++in_riter; + ++expand_riter; + } else { + (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter); + ++in_riter; + } + } + while (in_riter != in_accum_reverse.rend()) { + (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter); + ++in_riter; + } + return Status::SUCCESS; +} + +/* + * example1: + * in = [2, 8, 32] + * expand = [16, 4, 8] + * out = [2, 8, 4, 8] + * + * example2: + * in = [2, 8, 32] + * expand = [2, 4, 8] + * out = [2, 4, 2, 4, 8] + */ +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out) { + MS_EXCEPTION_IF_NULL(out); + std::vector in_accum_reverse; + Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse); + if (status != Status::SUCCESS) { + return status; + } + std::vector expand_accum_reverse; + status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse); + if (status != Status::SUCCESS) { + return status; + } + std::vector out_accum_reverse; + status = ExpandAccumulateProduct(in_accum_reverse, expand_accum_reverse, &out_accum_reverse); + if (status != Status::SUCCESS) { + return status; + } + status = AccumulateProductReverseToShape(out_accum_reverse, out); + if (status != Status::SUCCESS) { + return status; + } + return status; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h new file mode 100644 index 0000000000..49dd39ffd6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h @@ -0,0 +1,172 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +/* + * compute the accumulating product of all the values in shape from left to right, + * the accumulating results are saved in shape_accum from left to right + * + * given a shape = [d_n-1, d_n-2, ..., d_0](d_i > 0, i=0,1,...,n-1, elements of shape must be larger than zero), + * then *shape_accum = [d_n-1, d_n-1 * d_n-2, d_n-1 * d_n-2 * d_n-3, ..., d_n-1 * d_n-2 * ... *d_0] + * + * example: + * shape = [2, 8, 32] + * shape_accum = [2, 2 * 8, 2 * 8 * 32] + * + */ +Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum); + +/* + * compute the accumulating product of all the values in shape from right to left, + * the accumulating results are saved in shape_accum from right to left + * + * given a shape = [d_n-1, d_n-2, ..., d_0](d_i > 0, i=0,1,...,n-1, elements of shape must be larger than zero), + * then *shape_accum = [d_n-1 * d_n-2 * ... *d_0, d_n-2 * d_n-3 * ... *d_0, ..., d_0] + * + * example: + * shape = [2, 8, 32] + * shape_accum = [2 * 8 * 32, 8 * 32, 32] + * + */ +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum); + +/* + * compute the original shape from the accumulating product shape_accum, + * elements of shape_accum is saved from left to right, + * given shape_accum = [accum_n-1, accum_n-2, accum_n-3, ..., accum_0] + * (accum_i > 0, i=0,1,...,n-1, elements of shape_accum must be larger than zero), + * (accum_i-1 % accum_i == 0, i=1,...,n-1) + * then *shape = [accum_n-2/accum_n-1, accum_n-3/accum_n-2, ..., accum_0/accum_1] + * + * example: + * shape_accum = [2, 2 * 8, 2 * 8 * 32] + * shape = [2, 8, 32] + * + */ +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape); + +/* + * compute the original shape from the accumulating product shape_accum, + * elements of shape_accum is saved from right to left, + * given shape_accum_reverse = [accum_n-1, accum_n-2, accum_n-3, ..., accum_0] + * (accum_i > 0, i=0,1,...,n-1, elements of shape_accum must be larger than zero), + * (accum_i % accum_i-1 == 0, i=1,...,n-1) + * then *shape = [accum_n-1/accum_n-2, accum_n-2/accum_n-1, ..., accum_1/accum_0] + * + * example: + * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * shape = [2, 8, 32] + * + */ +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape); + +/* + * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, + * results are saved in out. + * i.e. *out_accum = in1_accum U in2_accum + * elements of out are saved in increasing order + * + * example1: + * in1_accum = [2, 8] + * in2_accum = [4, 8] + * out_accum = [2, 4, 8] + * + * example2: + * in1_accum = [2, 4, 16] + * in2_accum = [8, 16] + * out_accum = [2, 4, 8, 16] + */ +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum); + +/* + * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] + * size = din1_n-1 * din1n-2 * ... * din1_0 = din2_m-1 * din2_m-2 * ... * din2_0 + * find *out = [dout_k-1, dout_k-2, ..., dout_0], s.t. dout_k-1 * dout_k-2 * ... * dout_0 = size and + * suppose in1_accum, in2_accum, and *out_accum is the ShapeToAccumulateProduct result of in1, in2, and *out + * then for each din1_i in in1_accum, din1_i is in *out_accumulate, + * for each din2_i in in2_accum, din2_i is in *out_accumulate + * + * example: + * in1 = [8, 4] + * in2 = [2, 16] + * out = [2, 4, 4] + */ +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out); + +/* + * given two accumulate product in reverse order of in and expand, + * in_accum_reverse = [din_n-1, din_n-2, ..., din_0] and expand_pos_reverse = [dexp_n-1, dexp_n-2, ..., dexp_0], + * i.e. in_accum_reverse is the ShapeToAccumulateProductReverse result of a shape in, + * expand_accum_reverse is the ShapeToAccumulateProductReverse result of a shape expand, + * compute the accumulate product in reverse order out_accum_reverse = [dout_k-1, dout_k-2, ..., dout_0], + * s.t. elements in out_accum_reverse are union of elements in in_accum_reverse and expand_accum_reverse + * (out_accum_reverse = in_accum_reverse U expand_accum_reverse), and + * out_accum_reverse is the ShapeToAccumulateProductReverse result of shape expand, + * i.e. dout_i > 0, i=0,1,...,k-1, elements of out_accum_reverse must be larger than zero, + * dout_i-1 % dout_i == 0, i=1,...,k-1 + * + * example1: + * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * expand_accum_reverse = [2 * 8 * 32, 32, 8] + * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8] + * + * example2: + * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] + * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] + */ +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse); + +/* + * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], + * compute the expended shape out = [dout_k-1, dout_k-2, ..., dout_0], + * s.t. dout_k-1 * dout_k-2 * ...* dout_0 = din_n-1 * din_n-2 * ... * d_0 + * suppose in_accum_reverse is the ShapeToAccumulateProductReverse result of in, + * expand_accum_reverse is the ShapeToAccumulateProductReverse result of expand, + * out_accum_reverse is the ShapeToAccumulateProductReverse result of out, + * then out_accum_reverse is the union of in_accum_reverse and expand_accum_reverse + * (out_accum_reverse = in_accum_reverse U expand_accum_reverse) + * + * example1: + * in = [2, 8, 32] + * expand = [16, 4, 8] + * out = [2, 8, 4, 8] + * + * example2: + * in = [2, 8, 32] + * expand = [2, 4, 8] + * out = [2, 4, 2, 4, 8] + */ +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h new file mode 100644 index 0000000000..fc78b1f59c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h @@ -0,0 +1,71 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" + +namespace mindspore { +namespace parallel { +using Shapes = std::vector; + +class TensorInfo { + public: + TensorInfo(const TensorLayout &tensor_layout, Shape shape, Shape slice_shape) + : tensor_layout_(tensor_layout), shape_(std::move(shape)), slice_shape_(std::move(slice_shape)) {} + explicit TensorInfo(const TensorLayout &tensor_layout) : tensor_layout_(tensor_layout) { + shape_ = tensor_layout.tensor_shape().array(); + slice_shape_ = tensor_layout.slice_shape().array(); + } + // trivial default constructor will not initialize c language types. + TensorInfo() = default; + ~TensorInfo() = default; + TensorLayout tensor_layout() const { return tensor_layout_; } + Shape slice_shape() const { return slice_shape_; } + Shape shape() const { return shape_; } + void set_reduce_dim(const std::vector &dim) { reduce_dim_ = dim; } + std::vector reduce_dim() const { return reduce_dim_; } + Dimensions InferStrategy() const { + Dimensions stra; + for (size_t i = 0; i < shape_.size(); ++i) { + if ((slice_shape_[i] == 0) || (shape_[i] % slice_shape_[i] != 0)) { + return stra; + } + int32_t dim = (int32_t)(shape_[i] / slice_shape_[i]); + stra.push_back(dim); + } + return stra; + } + + private: + TensorLayout tensor_layout_; + Shape shape_; + Shape slice_shape_; + // reduce method's reduce dim + std::vector reduce_dim_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc new file mode 100644 index 0000000000..b9c6cc78de --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc @@ -0,0 +1,394 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include +#include +#include "common/utils.h" +#include "ir/value.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/array.h" +#include "frontend/parallel/tensor_layout/shape_util.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +std::string TensorLayout::ToString() const { return StandardToString() + OriginToString(); } + +std::string TensorLayout::StandardToString() const { + std::ostringstream buffer; + buffer << std::endl << std::string("device arrangement = " + device_arrangement_.ToString()); + buffer << std::endl << std::string("tensor map = " + tensor_map_.ToString()); + buffer << std::endl << std::string("tensor shape = " + tensor_shape_.ToString()); + return buffer.str(); +} + +std::string TensorLayout::OriginToString() const { + std::ostringstream buffer; + buffer << std::endl << std::string("device arrangement origin = " + device_arrangement_origin_.ToString()); + buffer << std::endl << std::string("tensor map origin = " + tensor_map_origin_.ToString()); + buffer << std::endl << std::string("tensor shape origin = " + tensor_shape_origin_.ToString()); + return buffer.str(); +} + +Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map, + const Arrangement &tensor_shape) { + device_arrangement_origin_ = device_arrangement; + tensor_map_origin_ = tensor_map; + tensor_shape_origin_ = tensor_shape; + device_arrangement_ = device_arrangement; + tensor_map_ = tensor_map; + tensor_shape_ = tensor_shape; + if (IsValidTensorLayout()) { + MS_LOG(DEBUG) << "valid origin tensor layout " << this->OriginToString(); + RemoveElementEqualToOneInDeviceArrangement(); + MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString(); + return Status::SUCCESS; + } else { + MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString(); + return Status::FAILED; + } +} + +Status TensorLayout::InitFromVector(const std::vector &device_arrangement, + const std::vector &tensor_map, const std::vector &tensor_shape) { + if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { + return FAILED; + } + if (tensor_map_origin_.Init(tensor_map) != SUCCESS) { + return FAILED; + } + if (tensor_shape_origin_.Init(tensor_shape) != SUCCESS) { + return FAILED; + } + if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) { + return FAILED; + } + return SUCCESS; +} + +bool TensorLayout::IsValidTensorLayout() const { + if (tensor_map_origin_.GetMaxItem() >= static_cast(device_arrangement_origin_.GetDimSize())) { + MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!"; + return false; + } + if (tensor_map_origin_.GetDimSize() != tensor_shape_origin_.GetDimSize()) { + MS_LOG(ERROR) << "tensor_map_origin_ size must be equal to tensor_shape_origin_ size!"; + return false; + } + if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) { + MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!"; + return false; + } + return true; +} + +bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const { + for (uint32_t i = 0; i < tensor_map_.GetDimSize(); i++) { + if (tensor_map_.GetDimByIdx(i) != -1) { + int32_t divisor = GetSliceNumByTensorDimensionIndex(i); + if (divisor == 0) { + MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0"; + return false; + } + if (tensor_shape_.GetDimByIdx(i) % divisor != 0) { + return false; + } + } + } + return true; +} + +void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { + std::vector device_arrangement_shape; + std::vector tensor_map_shape = tensor_map_origin_.array(); + uint32_t dev_num = SizeToUint(device_arrangement_origin_.GetDimSize()); + int32_t dev_num_left = SizeToInt(device_arrangement_origin_.GetDimSize()); + for (uint32_t i = 0; i < dev_num; i++) { + if (device_arrangement_origin_.GetDimByIdx(i) == 1) { + int32_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast(dev_num - 1 - i)); + if (idx != -1) { + tensor_map_shape[static_cast(idx)] = -1; + } + for (auto &value : tensor_map_shape) { + if (value >= dev_num_left - 1 - static_cast(i)) { + value--; + } + } + continue; + } + device_arrangement_shape.push_back(device_arrangement_origin_.GetDimByIdx(i)); + } + (void)device_arrangement_.Init(device_arrangement_shape); + (void)tensor_map_.Init(tensor_map_shape); + tensor_shape_ = tensor_shape_origin_; +} + +// if idx is not in tensor_map, return -1 +int32_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const { + return tensor_map_.GetIndexByValue(idx); +} + +// tensor_map_.GetDimByIdx(idx) should not be -1 +int32_t TensorLayout::GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const { + return static_cast(device_arrangement_.GetDimSize()) - 1 - tensor_map_.GetDimByIdx(idx); +} + +// tensor_map_.GetDimByIdx(idx) should not be -1 +int32_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint32_t idx) const { + return device_arrangement_.GetDimByIdx(static_cast(GetSliceDeviceDimensionByTensorDimensionIndex(idx))); +} + +std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const { + std::shared_ptr expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape); + if (expanded_arrangement_ptr == nullptr) { + return nullptr; + } + std::shared_ptr temp_tensor_layout_ptr = ExpandDeviceArrangement(*expanded_arrangement_ptr); + if (temp_tensor_layout_ptr == nullptr) { + return nullptr; + } + return temp_tensor_layout_ptr->ExpandTensorShapeWithoutExtendDeviceArrangement(expanded_shape); +} + +/* + * example1: + * in_device_arrangement = [8, 4], + * in_tensor_map = [1, 0], + * in_tensor_shape = [512, 1024], + * out_tensor_shape = [128, 4, 2, 512], + * => + * out_device_arrangement = [8, 2, 2] + */ +std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const { + std::shared_ptr> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape); + if (expand_list_ptr == nullptr) { + return nullptr; + } + std::vector re_map_expand_list; + Arrangement empty_arrangement; + for (int32_t i = static_cast(device_arrangement_.GetDimSize()) - 1; i >= 0; i--) { + if (tensor_map_.GetIndexByValue(i) < 0) { + re_map_expand_list.push_back(empty_arrangement); + } else { + re_map_expand_list.push_back((*expand_list_ptr)[IntToUint(tensor_map_.GetIndexByValue(i))]); + } + } + std::shared_ptr new_arrangement_ptr = + device_arrangement_.GetExpandedShapeByExpandListRemoveLeft(re_map_expand_list); + return new_arrangement_ptr; +} + +/* + * example1: + * in_device_arrangement = [8, 4], + * in_tensor_map = [1, 0], + * in_tensor_shape = [512, 1024], + * out_tensor_shape = [8, 64, 4, 256] + * => + * out_device_arrangement = [8, 4], + * out_tensor_map = [1, -1, 0, -1], + */ +std::shared_ptr TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement( + const Arrangement &expanded_shape) const { + std::shared_ptr, Arrangement>> expand_list_pair_ptr = + tensor_shape_.GetExpandShapeListPair(expanded_shape); + if (expand_list_pair_ptr == nullptr) { + return nullptr; + } + std::shared_ptr tensor_map_new_ptr = tensor_map_.ExpandMapByNone(expand_list_pair_ptr->second); + if (tensor_map_new_ptr == nullptr) { + return nullptr; + } + TensorLayout tensor_layout_new; + Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(tensor_layout_new); +} + +/* + * example1: + * in_device_arrangement = [8, 4], + * in_tensor_map = [1, 0], + * in_tensor_shape = [512, 1024], + * out_device_arrangement = [4, 2, 2, 2] + * => + * out_tensor_map = [3, 2, 1, 0], + * out_tensor_shape = [4, 128, 2, 512] + * + * example2: + * in_device_arrangement = [8, 4], + * in_tensor_map = [0, 1], + * in_tensor_shape = [512, 1024], + * out_device_arrangement = [4, 2, 2, 2] + * => + * out_tensor_map = [1, 0, 3, 2], + * out_tensor_shape = [2, 256, 4, 256] + * + * example3: + * in_device_arrangement = [8, 4], + * in_tensor_map = [1, -1], + * in_tensor_shape = [512, 1024], + * out_device_arrangement = [4, 2, 2, 2] + * => + * out_tensor_map = [3, 2, -1], + * out_tensor_shape = [4, 128, 1024] + * + * example4: + * in_device_arrangement = [8, 4], + * in_tensor_map = [0, 1], + * in_tensor_shape = [512, 1024], + * out_device_arrangement = [4, 2, 4] + * => + * out_tensor_map = [0, 2, 1], + * out_tensor_shape = [512, 4, 256] + */ +std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const { + std::shared_ptr, Arrangement>> expand_list_pair_ptr = + device_arrangement_.GetExpandShapeListPair(expanded_arrangement); + if (expand_list_pair_ptr == nullptr) { + return nullptr; + } + std::shared_ptr tensor_map_new_ptr = tensor_map_.ExpandMapByDecreaseNumber(expand_list_pair_ptr->second); + if (tensor_map_new_ptr == nullptr) { + return nullptr; + } + std::shared_ptr> re_map_shape_list_ptr = + tensor_map_.ReMapVector(expand_list_pair_ptr->first); + if (re_map_shape_list_ptr == nullptr) { + return nullptr; + } + std::shared_ptr tensor_shape_new_ptr = + tensor_shape_.GetExpandedShapeByExpandListReserveLeft(*re_map_shape_list_ptr); + if (tensor_shape_new_ptr == nullptr) { + return nullptr; + } + TensorLayout tensor_layout_new; + Status status = tensor_layout_new.Init(expanded_arrangement, *tensor_map_new_ptr, *tensor_shape_new_ptr); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(tensor_layout_new); +} + +bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { + std::vector in_expand_shape_shape; + Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); + if (status != Status::SUCCESS) { + return false; + } + return (in_expand_shape_shape == tensor_shape_.array()); +} + +std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { + std::vector in_expand_shape_shape; + Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + Arrangement expanded_shape; + status = expanded_shape.Init(in_expand_shape_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(expanded_shape); +} + +Arrangement TensorLayout::slice_shape() const { + std::vector shape; + for (uint32_t index = 0; index < tensor_map_.GetDimSize(); index++) { + int32_t dim = tensor_map_.GetDimByIdx(index); + int32_t num = tensor_shape_.GetDimByIdx(index); + if (dim == -1) { + shape.push_back(num); + } else { + int32_t divisor = device_arrangement_.GetDimByReverseIdx(IntToUint(dim)); + shape.push_back(num / divisor); + } + } + Arrangement new_tensor_shape; + if (new_tensor_shape.Init(shape) == Status::FAILED) { + ValuePtr ptr = MakeValue(shape); + MS_LOG(EXCEPTION) << "Can't get slice shape when initialize a new shape " << ptr->ToString(); + } else { + return new_tensor_shape; + } +} + +Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { + if (index >= tensor_map_.GetDimSize()) { + MS_LOG(ERROR) << "Index is out of the size of the tensor map!"; + return Status::FAILED; + } + auto shape = tensor_map_.array(); + shape[index] = value; + if (tensor_map_.Init(shape) == Status::FAILED) { + MS_LOG(ERROR) << "Update tensor map failed!"; + return Status::FAILED; + } + return Status::SUCCESS; +} + +bool TensorLayout::operator==(const TensorLayout &t1) const { + return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); +} + +/* + * remove elements equal to 1 in tensor_shape, if all elements are 1, squeeze the tensor_shape to [ 1 ] + * example 1: + * original tensor layout: + * device arrangement = [ 8 ] + * tensor map = [ 0 -1 -1 -1 ] + * tensor shape = [ 128 64 1 1 ] + * return tensor layout: + * device arrangement = [ 8 ] + * tensor map = [ 0 -1 ] + * tensor shape = [ 128 64 ] + * + * example 2: + * device arrangement = [ 8 ] + * tensor map = [ -1 -1 -1 -1 ] + * tensor shape = [ 1 1 1 1 ] + * return tensor layout: + * device arrangement = [ 8 ] + * tensor map = [ -1 ] + * tensor shape = [ 1 ] + */ +TensorLayout TensorLayout::SqueezeShape() const { + TensorLayout out; + Map out_map; + Arrangement out_shape; + if (tensor_shape_.size() == 1) { + (void)out_map.Init({MAP_NONE}); + (void)out_shape.Init({1}); + (void)out.Init(device_arrangement_, out_map, out_shape); + return out; + } + std::vector squeeze_list = tensor_shape_.GetSqueezeIdx(); + if (!tensor_map_.CheckNoneByIdxList(squeeze_list)) { + MS_LOG(ERROR) << "CheckNoneByIdxList failed, this may not happen under current situation"; + return *this; + } + out_shape = tensor_shape_.GetSqueezeArrangement(); + out_map = tensor_map_.SqueezeMapByIdxList(squeeze_list); + (void)out.Init(device_arrangement_, out_map, out_shape); + return out; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h new file mode 100644 index 0000000000..a9fdc9610c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h @@ -0,0 +1,99 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ + +#include +#include +#include +#include +#include +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/arrangement.h" +#include "frontend/parallel/tensor_layout/map.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace parallel { +class TensorLayout { + public: + TensorLayout() = default; + ~TensorLayout() = default; + std::string ToString() const; + std::string StandardToString() const; + std::string OriginToString() const; + Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); + Status InitFromVector(const std::vector &device_arrangement, const std::vector &tensor_map, + const std::vector &tensor_shape); + + Arrangement device_arrangement() const { return device_arrangement_; } + + Map tensor_map() const { return tensor_map_; } + + Arrangement tensor_shape() const { return tensor_shape_; } + + Map origin_tensor_map() const { return tensor_map_origin_; } + + std::shared_ptr ExpandTensorShape(const Arrangement &expanded_shape) const; + + std::shared_ptr ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const; + + bool IsSameTensorShape(const TensorLayout &tensor_layout) const { + return (tensor_shape_ == tensor_layout.tensor_shape()); + } + + bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const { + return (device_arrangement_ == tensor_layout.device_arrangement()); + } + + bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } + + bool operator==(const TensorLayout &t1) const; + + bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const; + + std::shared_ptr ComputeExpandedTensorShape(const Arrangement &expand_shape) const; + + Arrangement slice_shape() const; + + Status UpdateTensorMap(uint32_t index, int32_t value); + + TensorLayout SqueezeShape() const; + + private: + std::shared_ptr ExpandTensorShapeWithoutExtendDeviceArrangement( + const Arrangement &expanded_shape) const; + std::shared_ptr ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const; + bool IsValidTensorLayout() const; + void RemoveElementEqualToOneInDeviceArrangement(); + int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const; + int32_t GetSliceNumByTensorDimensionIndex(uint32_t idx) const; + bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const; + int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const; + + Arrangement device_arrangement_origin_; + Map tensor_map_origin_; + Arrangement tensor_shape_origin_; + Arrangement device_arrangement_; + Map tensor_map_; + Arrangement tensor_shape_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc new file mode 100644 index 0000000000..43bb330787 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc @@ -0,0 +1,209 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include +#include +#include +#include "common/utils.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/shape_util.h" + +namespace mindspore { +namespace parallel { +Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) { + from_origin_ = from; + to_origin_ = to; + if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) { + MS_LOG(ERROR) << "from shape size must be equal to to shape size!"; + MS_LOG(ERROR) << "reshape from_origin_ " << from_origin_.ToString(); + MS_LOG(ERROR) << "reshape to_origin_ " << to_origin_.ToString(); + return Status::FAILED; + } + + dev_list_ = dev_list; + from_ = from_origin_.SqueezeShape(); + to_ = to_origin_.SqueezeShape(); + return Status::SUCCESS; +} + +RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) { + // Step 1: Match device arrangement between from_ and to_ + RedistributionLayoutTransfer layout_transfer; + Status status = layout_transfer.Init(from_, to_); + if (status != Status::SUCCESS) { + return nullptr; + } + std::shared_ptr ptr = layout_transfer.UnifyDeviceArrangementAndTensorShape(); + if (ptr == nullptr) { + MS_LOG(ERROR) << "Infer tensor layout return nullptr!"; + return nullptr; + } + TensorLayout from_layout = ptr->from_in(); + TensorLayout to_layout = ptr->to_in(); + MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString(); + MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString(); + MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString(); + MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString(); + MS_LOG(DEBUG) << "reshape from_ " << from_.ToString(); + MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); + // Step 2: Infer redistribution and insert operators + RedistributionOperatorInfer operator_infer(construct_op_flag_); + if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { + MS_LOG(ERROR) << "Init operatorInfer failed!"; + return nullptr; + } + OperatorVector operator_vector; + OutPutInfoVector output_info_vector; + if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) { + MS_LOG(ERROR) << "Infer redistribution failed!"; + return nullptr; + } else { + operator_vector = operator_infer.operator_vector(); + output_info_vector = operator_infer.output_info_vector(); + operator_list_ = operator_infer.operator_list(); + } + + // Step 3: Infer reshape and insert operators + if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) { + MS_LOG(ERROR) << "Construct Reshape operator failed!"; + return nullptr; + } + + return std::make_shared>( + std::make_pair(operator_vector, output_info_vector)); +} + +Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, + OutPutInfoVector *const output_info_vector) { + MS_EXCEPTION_IF_NULL(operator_vector); + MS_EXCEPTION_IF_NULL(output_info_vector); + ConstructOperator constructor; + if (operator_list_.empty()) { + if (from_origin_.slice_shape().array() != to_origin_.slice_shape().array() || keep_reshape_) { + reshape_flag_ = true; + constructor.UpdateTensorShape(from_origin_.slice_shape().array()); + Arrangement shape = to_origin_.slice_shape(); + MS_LOG(DEBUG) << "reshape " << shape.ToString(); + if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { + return Status::FAILED; + } else { + (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator()); + (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0)); + } + } + return Status::SUCCESS; + } + + if (from_origin_.slice_shape().array() != from_layout.slice_shape().array()) { + reshape_flag_ = true; + constructor.UpdateTensorShape(from_origin_.slice_shape().array()); + Arrangement shape = from_layout.slice_shape(); + MS_LOG(DEBUG) << "reshape " << shape.ToString(); + if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { + return Status::FAILED; + } else { + (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator()); + (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0)); + } + } + + if (to_origin_.slice_shape().array() != to_layout.slice_shape().array()) { + reshape_flag_ = true; + constructor.UpdateTensorShape(to_layout.slice_shape().array()); + Arrangement shape = to_origin_.slice_shape(); + MS_LOG(DEBUG) << "step_parallel to reshape " << shape.ToString(); + if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { + return Status::FAILED; + } else { + (void)operator_vector->insert(operator_vector->end(), constructor.GetOperator()); + (void)output_info_vector->insert(output_info_vector->end(), std::make_pair(false, 0)); + } + } + return Status::SUCCESS; +} + +Status TensorRedistribution::ComputeCost() { + RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true); + if (redistribution_oplist_ptr == nullptr) { + MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; + return Status::FAILED; + } + // Compute redistribution communication cost and computation cost + for (auto &op_cost : operator_list_) { + OperatorR op = op_cost.first; + Shape slice_shape = op_cost.second; + double prod = + std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast(1.0), std::multiplies()); + std::string str = op.first; + if (str == PERMUTE_BY_AXIS) { + // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost. + // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape + forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; + backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; + comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR; + int32_t concat_dim = op.second[2]; + if (concat_dim == 0) { + // memory cost = all_gather + computation_cost_ += prod; + memory_cost_ += prod; + } else { + // memory cost = all_gather + split + concat + int32_t dev_num = op.second[4]; + computation_cost_ += (prod + prod * dev_num + prod * dev_num); + memory_cost_ += (prod * dev_num + prod * dev_num + prod); + } + } else if (str == CONCAT_BY_AXIS) { + // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape + // computation cost = before_slice_shape + if (op.second.size() < 3) { + MS_LOG(ERROR) << "op.second size should not be less than 3!"; + return Status::FAILED; + } + double dev_num = op.second[2]; + // here, communication cost = all_gather + reduce_scatter + forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; + backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; + comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; + int32_t concat_dim = op.second[0]; + if (concat_dim == 0) { + // computation cost = all_gather + computation_cost_ += prod; + memory_cost_ += prod * dev_num; + } else { + // computation cost = all_gather + split + concat + computation_cost_ += (prod + prod * dev_num + prod * dev_num); + memory_cost_ += (prod * dev_num + prod * dev_num + prod); + } + } else { + // There is only computation cost in SplitByAxis. + // computation cost = before_slice_shape + computation_cost_ += prod; + // This addtion may be erroneous + memory_cost_ += prod; + } + } + if (reshape_flag()) { + Shape prev_slice_shape = from_.slice_shape().array(); + double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies()); + computation_cost_ += 2.0 * prev_prod; + memory_cost_ += 2.0 * prev_prod; + } + return Status::SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h new file mode 100644 index 0000000000..df4bd1570f --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h @@ -0,0 +1,90 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ + +#include +#include +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/construct_operator.h" +#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" + +namespace mindspore { +namespace parallel { +constexpr double ALLTOALL_SCALE_FACTOR = 2.0; +constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5; +class TensorRedistribution { + public: + explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) + : reshape_flag_(false), + comm_cost_(0.0), + forward_comm_cost_(0.0), + backward_comm_cost_(0.0), + computation_cost_(0.0), + memory_cost_(0.0), + construct_op_flag_(construct_op_flag), + keep_reshape_(keep_reshape) {} + Status Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list); + ~TensorRedistribution() = default; + RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); + OperatorList operator_list() const { return operator_list_; } + bool reshape_flag() const { return reshape_flag_; } + Status ComputeCost(); + double comm_cost() const { return comm_cost_; } + double computation_cost() const { return computation_cost_; } + double forward_comm_cost() const { return forward_comm_cost_; } + double backward_comm_cost() const { return backward_comm_cost_; } + double memory_cost() const { return memory_cost_; } + + private: + Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector); + + TensorLayout from_origin_; + TensorLayout to_origin_; + TensorLayout from_; + TensorLayout to_; + RankList dev_list_; + OperatorList operator_list_; + bool reshape_flag_; + // communication cost, which is the sum of forward communication cost and backward communication cost + double comm_cost_; + // forward communication cost + double forward_comm_cost_; + // backward communication cost + double backward_comm_cost_; + // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the + // inputs. This is calculated ONLY for forward phase. + double computation_cost_; + // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is + // calculated by the outputs. + double memory_cost_; + bool construct_op_flag_; + bool keep_reshape_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ diff --git a/mindspore/ccsrc/gvar/typeid_manager.cc b/mindspore/ccsrc/gvar/typeid_manager.cc index f40052411a..bc74f3a0df 100644 --- a/mindspore/ccsrc/gvar/typeid_manager.cc +++ b/mindspore/ccsrc/gvar/typeid_manager.cc @@ -20,7 +20,7 @@ #include #include -#include "ir/base.h" +#include "base/base.h" namespace mindspore { diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc deleted file mode 100644 index 4c1d2bf50d..0000000000 --- a/mindspore/ccsrc/ir/anf.cc +++ /dev/null @@ -1,221 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/anf.h" - -#include -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/primitive_base.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" - -namespace mindspore { -// namespace to support intermediate representation definition -CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) - : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} - -// Check if CNode is an apply with the specific Primitive. -bool CNode::IsApply(const PrimitivePtr &value) const { - if (value == nullptr) { - return false; - } - - if (inputs_.size() != 0 && IsValueNode(inputs_[0])) { - PrimitivePtr fn_value = GetValueNode(inputs_[0]); - if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) { - return true; - } - } - - return false; -} - -void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } - -std::string CNode::DebugString(int recursive_level) const { - std::ostringstream buffer; - if (recursive_level > 0) { - if (func_graph() != nullptr) { - buffer << func_graph()->ToString() << ":"; - } - buffer << ToString() << "{"; - bool is_first_node = true; - int idx = 0; - for (auto &node : inputs_) { - MS_EXCEPTION_IF_NULL(node); - if (is_first_node) { - is_first_node = false; - } else { - buffer << ", "; - } - buffer << "[" << idx << "]: " << node->DebugString(recursive_level - 1); - idx++; - } - buffer << "}"; - } else { - buffer << ToString(); - } - return buffer.str(); -} - -std::string ValueNode::ToString() const { - MS_EXCEPTION_IF_NULL(value_); - if (value_->isa()) { - return value_->cast()->ToString(); - } - std::ostringstream buffer; - buffer << AnfNode::ToString(); - buffer << "(" << value_->ToString() << ")"; - return buffer.str(); -} - -std::string ValueNode::DebugString(int) const { - MS_EXCEPTION_IF_NULL(value_); - std::ostringstream buffer; - buffer << "ValueNode<" << value_->type_name() << "> " << value_->ToString(); - return buffer.str(); -} - -std::string ValueNode::fullname_with_scope() { - if (!fullname_with_scope_.empty()) { - return fullname_with_scope_; - } - - MS_EXCEPTION_IF_NULL(scope()); - fullname_with_scope_ = scope()->name() + "/" + "data-" + id_generator::get_id(shared_from_base()); - return fullname_with_scope_; -} - -bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode == nullptr) { - return false; - } - if (value != nullptr) { - return cnode->IsApply(value); - } - const auto &prim = GetValueNode(cnode->input(0)); - return prim != nullptr; -} - -PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { - if (node == nullptr) { - return nullptr; - } - auto cnode = node->cast(); - if (cnode != nullptr) { - if (cnode->size() > 0) { - auto prim = GetValueNode(cnode->input(0)); - return prim; - } - } - return nullptr; -} - -std::string GetCNodeFuncName(const CNodePtr cnode) { - if (cnode->inputs().empty()) { - return ""; - } - - AnfNodePtr valuenode = cnode->input(0); - if (valuenode->isa()) { - auto value = GetValueNode(valuenode); - // check whether the valuenode is primitive - if (value->isa()) { - return value->cast()->name(); - } - return value->ToString(); - } - return ""; -} - -bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { - if (IsValueNode(node)) { - PrimitivePtr fn_value = GetValueNode(node); - MS_EXCEPTION_IF_NULL(value); - if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) { - return true; - } - } - return false; -} - -size_t NewSeenGeneration() { - static size_t seen_generation = 0; - return ++seen_generation; -} - -namespace id_generator { -static std::unordered_map node_ids; -std::string get_id(const AnfNodePtr &node) { - auto type_name = node->type_name(); - if (node_ids.find(type_name) == node_ids.end()) { - node_ids[type_name] = 0; - } else { - node_ids[type_name]++; - } - return std::to_string(node_ids[type_name]); -} - -void reset_id() { node_ids.clear(); } -} // namespace id_generator - -std::string GetCNodeTarget(const AnfNodePtr &node) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->device_target(); - if (!node->isa()) { - return default_target; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto attr_input = cnode->input(0); - if (attr_input == nullptr) { - return default_target; - } - auto value_node = attr_input->cast(); - if (value_node == nullptr) { - return default_target; - } - auto value = value_node->value(); - if (value == nullptr) { - return default_target; - } - if (!value->isa()) { - return default_target; - } - auto primitive = value->cast(); - auto att_target = primitive->GetAttr("primitive_target"); - if (att_target != nullptr) { - if (!att_target->isa()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; - } - auto target = GetValue(att_target); - if (kTargetSet.find(target) == kTargetSet.end()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; - } - return target; - } - return default_target; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/anf.h b/mindspore/ccsrc/ir/anf.h deleted file mode 100644 index 8a44627885..0000000000 --- a/mindspore/ccsrc/ir/anf.h +++ /dev/null @@ -1,454 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_ANF_H_ -#define MINDSPORE_CCSRC_IR_ANF_H_ - -#include -#include -#include -#include -#include -#include - -#include "ir/base.h" -#include "debug/info.h" -#include "ir/scope.h" - -// A MindSpore ANF IR defined here. -// with BNF followed: -// ::= Scalar | Named | Tensor | Var | -// Prim | MetaFuncGraph | FuncGraph | Type| -// Shape | Param -// ::= ( ...) -// ::= | -// ANode: Atomic Node -// CNode: Complex Node -namespace mindspore { -namespace parallel { -class TensorLayout; -class OperatorInfo; -} // namespace parallel -using OperatorInfoPtr = std::shared_ptr; - -namespace abstract { -class BaseShape; -class AbstractBase; -} // namespace abstract -using BaseShapePtr = std::shared_ptr; -using AbstractBasePtr = std::shared_ptr; -using AbstractBasePtrList = std::vector; - -class ValueNode; -using ValueNodePtr = std::shared_ptr; -class CNode; -using CNodePtr = std::shared_ptr; - -class FuncGraph; -using FuncGraphSet = OrderedSet; -using FuncGraphPtrList = std::vector; - -class Primitive; -using PrimitivePtr = std::shared_ptr; - -class BaseRef; - -class Var; -using VarPtr = std::shared_ptr; - -namespace device { -class KernelInfo; -} // namespace device -using KernelInfoDevice = device::KernelInfo; -using KernelInfoDevicePtr = std::shared_ptr; - -class AnfVisitor; - -class ParamValue { - public: - ParamValue() = default; - virtual ~ParamValue() = default; -}; -using ParamValuePtr = std::shared_ptr; - -// AnfNode is the basic class of the IR definition derived from Base. -// Only two types of nodes are derived: CNode and ANode. -// Methods: -// func_graph: return FuncGraph that this AnfNode belongs to. -// scope: return the scope namespace of this AnfNode. Set it using set_scope. -// abstract: return the cached inferred abstract value. It contains type, shape -// value. Set New cache using set_abstract. -// intermediate_abstract: return the cached inferring abstract value. -// Type/Shape: return the related info of this AnfNode. When this AnfNode is an -// input of other CNodes, you can get the related info by this method. -// debug_info: return the information retrived from parser. Set it using set_debug_info. -// fullname_with_scope: return the detailed debug info. -class AnfNode : public Base { - public: - explicit AnfNode(const FuncGraphPtr &func_graph) - : func_graph_(FuncGraphWeakPtr(func_graph)), - abstract_(nullptr), - intermediate_abstract_(nullptr), - debug_info_(std::make_shared()), - fullname_with_scope_(""), - hash_(std::hash()), - kernel_info_(nullptr) { - scope_ = ScopeManager::GetInstance().GetCurrentScope(); - } - - ~AnfNode() override = default; - MS_DECLARE_PARENT(AnfNode, Base); - - virtual void accept(AnfVisitor *) {} - FuncGraphPtr func_graph() const { return func_graph_.lock(); } - - void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } - - ScopePtr scope() { return scope_; } - void set_scope(const ScopePtr &scope) { scope_ = scope; } - - const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); } - KernelInfoDevice *kernel_info() { return kernel_info_.get(); } - const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; } - void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; } - - AbstractBasePtr abstract() const { return abstract_; } - void set_abstract(const AbstractBasePtr &abs) { abstract_ = abs; } - - AbstractBasePtr intermediate_abstract() { return intermediate_abstract_; } - void set_intermediate_abstract(const AbstractBasePtr &abs) { intermediate_abstract_ = abs; } - - NodeDebugInfoPtr debug_info() { - MS_EXCEPTION_IF_NULL(debug_info_); - if (debug_info_->get_node() == nullptr) { - debug_info_->set_node(shared_from_base()); - } - return debug_info_; - } - void set_debug_info(const NodeDebugInfoPtr &debug_info) { - debug_info_ = debug_info; - if (debug_info_->get_node() == nullptr) { - debug_info_->set_node(shared_from_base()); - } - } - - TypePtr Type() const; - BaseShapePtr Shape() const; - - std::size_t hash() const override { return this->hash_(this); } - virtual std::string fullname_with_scope() { return ""; } - - virtual std::string DebugString(int recursive_level = 1) const { return ToString(); } - virtual std::string DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); } - std::string ToString() const override; - void dump() const override { std::cout << DebugString() << std::endl; } - std::string UniqueId() { return std::to_string(debug_info()->unique_id()); } - std::string UniqueIdThroughCopy() { return std::to_string(debug_info()->unique_id_through_copy()); } - virtual bool operator==(const AnfNode &other) const { return &other == this; } - friend std::ostream &operator<<(std::ostream &os, const AnfNode &node) { - os << node.ToString(); - return os; - } - size_t seen_{0}; - - protected: - // Hold a weak ref to Graph as Graph also hold ref to AnfNode. - // Otherwise, func_graph_ and AnfNode will make a reference cycle. - FuncGraphWeakPtr func_graph_; - AbstractBasePtr abstract_; - AbstractBasePtr intermediate_abstract_; - NodeDebugInfoPtr debug_info_; - std::string fullname_with_scope_; - - private: - std::hash hash_; - ScopePtr scope_; - KernelInfoDevicePtr kernel_info_; -}; - -// CNode represents the complex node with a set of arguments. -// Fields: -// inputs_: represents all of the inputs for this CNode. -// Using input(i) to get the index i input. -// Using inputs() to get all the inputs as a vector. -// Using add_input(input) to append a new input for a CNode. -// Using set_input(i, input) to change some input of these inputs. -// Using set_inputs(inputs) to refresh all of the inputs of a CNode. -// func_graph_as_var_: used in opt pattern matching to match a real FuncGraph. -// stop_gradient_: a flag used to stop gradient. -// Using stop_gradient() to get this flag, mainly used in ad. -// Using set_stop_gradient() to set this flag. -class CNode : public AnfNode { - public: - CNode(const std::vector &inputs, const FuncGraphPtr &func_graph); - CNode(const std::vector &inputs, const VarPtr &func_graph_as_var) - : AnfNode(nullptr), inputs_(inputs), func_graph_as_var_(func_graph_as_var), stop_gradient_(false) {} - - ~CNode() override = default; - MS_DECLARE_PARENT(CNode, AnfNode); - - void accept(AnfVisitor *v) override; - // check whether this cnode has some primitive value as the first input. - bool IsApply(const PrimitivePtr &) const; - - const size_t size() const { return inputs_.size(); } - const AnfNodePtr input(size_t i) const { return inputs_[i]; } - const std::vector &inputs() const { return inputs_; } - void add_input(const AnfNodePtr &input) { inputs_.push_back(input); } - void set_input(size_t i, const AnfNodePtr &input); - void set_inputs(const std::vector &inputs) { inputs_ = inputs; } - - bool stop_gradient() const { return stop_gradient_; } - void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; } - - std::string fullname_with_scope() override; - void set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; } - std::string DebugString(int recursive_level = 1) const override; - std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } - - OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info); - OperatorInfoPtr operator_info() { return operator_info_; } - - void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; } - bool in_forward_flag() const { return in_forward_flag_; } - - VarPtr func_graph_as_var() const { return func_graph_as_var_; } - - private: - std::vector inputs_; - VarPtr func_graph_as_var_; - bool stop_gradient_; - OperatorInfoPtr operator_info_ = nullptr; - bool in_forward_flag_ = false; -}; - -// ANode represents the atomic node. It's derived Parameter and ValueNode. -class ANode : public AnfNode { - public: - ANode() : AnfNode(nullptr) {} - explicit ANode(const FuncGraphPtr &func_graph) : AnfNode(func_graph) {} - virtual ~ANode() = default; - - MS_DECLARE_PARENT(ANode, AnfNode); -}; - -// Parameter represents the parameter inputs of a function. They have no value. -// Attributes: -// default_param_value_: used to hold the inputting tensor of the model. -class Parameter : public ANode { - public: - explicit Parameter(const FuncGraphPtr &func_graph) - : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {} - ~Parameter() override = default; - MS_DECLARE_PARENT(Parameter, ANode); - - void accept(AnfVisitor *v) override; - - std::string name() const { return name_; } - void set_name(const std::string &name) { name_ = name; } - std::string fullname_with_scope() override { return name(); }; - - bool has_default() const { return has_default_; } - void set_default_param(ParamValuePtr param) { - default_param_ = param; - has_default_ = true; - } - ParamValuePtr default_param() const { return default_param_; } - - std::shared_ptr tensor_layout() const { return tensor_layout_; } - void set_tensor_layout(const std::shared_ptr &tensor_layout) { - tensor_layout_ = tensor_layout; - } - - bool operator==(const AnfNode &other) const override { - if (!other.isa()) { - return false; - } - auto p = static_cast(other); - if (name_.length() > 0 && p.name_.length() > 0) { - return p.name_ == name_; - } - return shared_from_this() == other.shared_from_this(); - } - - private: - std::string name_; - bool has_default_; - ParamValuePtr default_param_; - std::shared_ptr tensor_layout_; -}; -using ParameterPtr = std::shared_ptr; - -// Value is used to represent the atomic expression mentioned in BNF. -// It mainly be stored in ValueNode. Value and ValueNode is related definition. -class Value : public Base { - public: - Value() = default; - explicit Value(const TypePtr t) : type_(t) {} - Value(const Value &other) : Base(other) { this->type_ = other.type_; } - ~Value() override = default; - MS_DECLARE_PARENT(Value, Base) - - TypePtr type() const { return type_; } - virtual abstract::AbstractBasePtr ToAbstract() { MS_LOG(EXCEPTION) << "ToAbstract error"; } - - virtual bool operator==(const Value &rhs) const = 0; - virtual Value &operator=(const Value &other) { - if (&other == this) { - return *this; - } - this->type_ = other.type_; - return *this; - } - - protected: - TypePtr type_{nullptr}; -}; -using ValuePtr = std::shared_ptr; -using ValuePtrList = std::vector; - -// ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode -// does not belong to any particular function graph. -class ValueNode : public ANode { - public: - explicit ValueNode(const ValuePtr &value) : value_(value) {} - ~ValueNode() override = default; - MS_DECLARE_PARENT(ValueNode, ANode); - - void accept(AnfVisitor *v) override; - const ValuePtr &value() const { return value_; } - std::string fullname_with_scope() override; - - std::string ToString() const override; - std::string DebugString(int recursive_level = 1) const override; - std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } - - bool operator==(const AnfNode &other) const override { - if (!other.isa()) { - return false; - } - auto v = static_cast(other); - return *v.value() == *value(); - } - friend std::ostream &operator<<(std::ostream &os, const ValueNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - os << node->ToString(); - return os; - } - - private: - ValuePtr value_; -}; - -template -struct ImmTraits {}; - -#define IMM_TRAITS(typeimm, prototype) \ - template <> \ - struct ImmTraits { \ - using type = typeimm; \ - }; - -inline ValuePtr MakeValue(const ValuePtr &value) { return value; } - -template ::type::element_type> -inline ValuePtr MakeValue(S v) { - return std::make_shared(v); -} - -template ::type> -static S GetValue(const ValuePtr &value) { - MS_EXCEPTION_IF_NULL(value); - - U imm = value->cast(); - if (imm == nullptr) { - MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name(); - } - return imm->value(); -} - -template ::value && std::is_base_of::value, - S>::type * = nullptr> -static S GetValue(const ValuePtr &value) { - MS_EXCEPTION_IF_NULL(value); - S v = value->cast(); - if (v == nullptr) { - MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name(); - } - return v; -} - -std::string GetCNodeFuncName(CNodePtr cnode); - -// used to check whether an AnfNode is a cnode with a kind of Primitive as first input -bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr); - -// used to get PrimitivePtr from a cnode first input -PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); - -// used to check whether an AnfNode is a valuenode having some Primitive value -bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value); - -// used to check whether a ValueNode has some kind of value -template -static bool IsValueNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto anode = node->cast(); - if (anode != nullptr) { - auto value = anode->value(); - if (value == nullptr) { - MS_LOG(EXCEPTION) << "Const value is nullptr."; - } - return value->isa(); - } - return false; -} - -inline ValuePtr GetValueNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return nullptr; - } - return node->cast()->value(); -} - -template ::value && std::is_base_of::value, - S>::type * = nullptr> -inline S GetValueNode(const AnfNodePtr &node) { - auto value = GetValueNode(node); - if (value == nullptr) { - return nullptr; - } - auto s = value->cast(); - return s; -} - -size_t NewSeenGeneration(); - -namespace id_generator { -std::string get_id(const AnfNodePtr &node); -void reset_id(); -} // namespace id_generator -using TaggedNodeMap = std::unordered_map; -using TaggedGraph = std::pair; -std::string GetCNodeTarget(const AnfNodePtr &node); -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_ANF_H_ diff --git a/mindspore/ccsrc/ir/anf_extends.cc b/mindspore/ccsrc/ir/anf_extends.cc deleted file mode 100644 index 432ffdb606..0000000000 --- a/mindspore/ccsrc/ir/anf_extends.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/anf.h" - -#include -#include -#include -#include - -#include "ir/visitor.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "operator/ops.h" -#include "parallel/ops_info/ops_utils.h" -#include "debug/label.h" - -namespace mindspore { -// namespace to support intermediate representation definition -// Methods of AnfNode -TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); } -BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } - -std::string AnfNode::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); -} - -OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { - if (operator_info_ != nullptr) { - MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() - << ", using the new one: " << operator_info->name(); - auto old_ptr = operator_info_; - operator_info_ = operator_info; - return old_ptr; - } - operator_info_ = operator_info; - return nullptr; -} - -std::string CNode::fullname_with_scope() { - // if full name is set, return its name immediately - if (!fullname_with_scope_.empty()) { - return fullname_with_scope_; - } - - if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || - IsApply(prim::kPrimHistogramSummary)) { - std::string tag = GetValue(GetValueNode(input(1))); - std::string name; - if (IsApply(prim::kPrimScalarSummary)) { - name = tag + "[:Scalar]"; - } else if (IsApply(prim::kPrimImageSummary)) { - name = tag + "[:Image]"; - } else if (IsApply(prim::kPrimHistogramSummary)) { - name = tag + "[:Histogram]"; - } else { - name = tag + "[:Tensor]"; - } - fullname_with_scope_ = name; - } else { - // cnode input 0 should be primitive ptr or funcgraph ptr - auto value_ptr = input(0)->cast(); - if (value_ptr == nullptr) { - MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; - fullname_with_scope_ = id_generator::get_id(shared_from_base()); - return fullname_with_scope_; - } - auto input_value = value_ptr->value(); - if (input_value == nullptr) { - MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr."; - fullname_with_scope_ = id_generator::get_id(shared_from_base()); - return fullname_with_scope_; - } - - auto prim = input_value->cast(); - MS_EXCEPTION_IF_NULL(scope()); - fullname_with_scope_ = scope()->name() + "/"; - if (prim != nullptr) { - fullname_with_scope_ += prim->name(); - } else { - auto func_graph = input_value->cast(); - MS_EXCEPTION_IF_NULL(func_graph); - auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (fg_flag != nullptr) { - auto fg_name = GetValue(fg_flag); - fullname_with_scope_ += "GraphKernel_" + fg_name; - } else { - fullname_with_scope_ += func_graph->ToString(); - } - } - fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base()); - } - - return fullname_with_scope_; -} - -void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/base.cc b/mindspore/ccsrc/ir/base.cc deleted file mode 100644 index 7a03269ad8..0000000000 --- a/mindspore/ccsrc/ir/base.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/base.h" -#include -#include -#include - -namespace mindspore { -const bool Base::IsFromTypeId(uint32_t tid) const { - static const uint32_t node_id = GetTypeId(typeid(Base).name()); - return tid == node_id; -} - -uint32_t Base::GetTypeId(const char *const type_name) { - TypeIdManager *t = TypeIdManager::Get(); - std::lock_guard(t->mutex); - auto it = t->map.find(type_name); - if (it != t->map.end()) { - return it->second; - } - uint32_t tid = ++(t->type_counter); - t->map[type_name] = tid; - return tid; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/base.h b/mindspore/ccsrc/ir/base.h deleted file mode 100644 index 7dc4145837..0000000000 --- a/mindspore/ccsrc/ir/base.h +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_BASE_H_ -#define MINDSPORE_CCSRC_IR_BASE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "utils/visible.h" -#include "utils/log_adapter.h" -#include "utils/ordered_set.h" -#include "utils/ordered_map.h" - -namespace mindspore { -template -struct is_shared_ptr : public std::false_type {}; -template -struct is_shared_ptr> : public std::true_type {}; - -class Base : public std::enable_shared_from_this { - public: - constexpr Base() = default; - Base(const Base &other) : std::enable_shared_from_this(other) {} - virtual bool operator==(const Base &rhs) { - if (this == &rhs) { - return true; - } - return false; - } - - virtual Base &operator=(const Base &) { return *this; } - virtual ~Base() = default; - virtual std::size_t hash() const { return tid(); } - virtual std::string ToString() const { return type_name(); } - virtual void dump() const { std::cout << ToString() << std::endl; } - - virtual std::string DumpText() const { return ToString(); } - - virtual const bool IsFromTypeId(uint32_t tid) const; - virtual std::string type_name() const { return "Base"; } - static uint32_t GetTypeId(const char *const type_key); - virtual uint32_t tid() const { - static const uint32_t tid = GetTypeId(typeid(Base).name()); - return tid; - } - - template ::value && std::is_base_of::value, T>::type * = nullptr> - inline bool isa() const { - static const uint32_t tid = GetTypeId(typeid(T).name()); - return this->IsFromTypeId(tid); - } - - template ::value, typename T::element_type>::type> - inline T cast() { - if (isa()) { - return std::static_pointer_cast(shared_from_this()); - } else { - return nullptr; - } - } - - protected: - template - std::shared_ptr shared_from_base() { - return std::static_pointer_cast(shared_from_this()); - } -}; - -using BasePtr = std::shared_ptr; -using BaseWeakPtr = std::weak_ptr; - -template -inline T *cast(U *source) { - if (source != nullptr && source->template isa()) { - return static_cast(source); - } else { - return nullptr; - } -} - -template < - typename T, typename U, - typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> -inline std::shared_ptr dyn_cast(const std::shared_ptr r) { - if (r != nullptr && r->template isa()) { - return std::static_pointer_cast(r); - } else { - return std::shared_ptr(); - } -} - -#define MS_DECLARE_PARENT(current_t, parent_t) \ - uint32_t tid() const override { \ - static const uint32_t tid = GetTypeId(typeid(current_t).name()); \ - return tid; \ - } \ - const bool IsFromTypeId(uint32_t from_tid) const override { \ - static const uint32_t tid = Base::GetTypeId(typeid(current_t).name()); \ - if (tid == from_tid) { \ - return true; \ - } \ - return parent_t::IsFromTypeId(from_tid); \ - } \ - std::string type_name() const override { return #current_t; } - -class Type; -using TypePtr = std::shared_ptr; - -class AnfNode; -using AnfNodePtr = std::shared_ptr; -using AnfNodePtrList = std::vector; -using AnfNodeSet = OrderedSet; - -namespace abstract { -class AbstractBase; -using AbstractBasePtr = std::shared_ptr; -using AbstractAttribute = std::pair; -class AnalysisContext; -using AnalysisContextPtr = std::shared_ptr; -} // namespace abstract - -struct MS_EXPORT TypeIdManager { - std::mutex mutex; - std::atomic type_counter{0}; - std::unordered_map map; - static TypeIdManager *Get(); - TypeIdManager() : mutex(), type_counter(0), map() {} -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_BASE_H_ diff --git a/mindspore/ccsrc/ir/dtype.h b/mindspore/ccsrc/ir/dtype.h deleted file mode 100644 index f10c56e659..0000000000 --- a/mindspore/ccsrc/ir/dtype.h +++ /dev/null @@ -1,335 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_DTYPE_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/base.h" -#include "ir/named.h" - -#include "ir/dtype/type.h" -#include "ir/dtype/ref.h" -#include "ir/dtype/number.h" -#include "ir/dtype/container.h" -#include "ir/dtype/empty.h" - -/* namespace to support intermediate representation definition */ -namespace mindspore { -// Only few type supported now. -TypePtr TypeIdToType(TypeId id); - -class String : public Object { - public: - String() : Object(kObjectTypeString, false) {} - ~String() override = default; - MS_DECLARE_PARENT(String, Object) - - TypeId generic_type_id() const override { return kObjectTypeString; } - - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToString() const override { return std::string("String"); } - std::string ToReprString() const override { return "string"; } - std::string DumpText() const override { return "String"; } -}; -using StringPtr = std::shared_ptr; - -class Keyword : public Object { - public: - Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} - Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} - - ~Keyword() override = default; - MS_DECLARE_PARENT(Keyword, Object) - - TypeId generic_type_id() const override { return kObjectTypeKeyword; } - TypePtr DeepCopy() const override; - - std::string ToString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - std::string GetKey() const { return key_; } - TypePtr GetValue() const { return value_; } - - private: - std::string key_; - TypePtr value_; -}; -using KeywordPtr = std::shared_ptr; - -class Slice : public Object { - public: - Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} - Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step) - : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} - - ~Slice() override = default; - MS_DECLARE_PARENT(Slice, Object) - - TypeId generic_type_id() const override { return kObjectTypeSlice; } - TypePtr DeepCopy() const override; - - std::string ToString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - TypePtr get_start() const { return start_; } - TypePtr get_stop() const { return stop_; } - TypePtr get_step() const { return step_; } - - private: - TypePtr start_; - TypePtr stop_; - TypePtr step_; -}; -using SlicePtr = std::shared_ptr; - -class UndeterminedType : public Object { - public: - UndeterminedType() : Object(kObjectTypeUndeterminedType) {} - explicit UndeterminedType(const TypePtr &ele) - : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} - ~UndeterminedType() override = default; - MS_DECLARE_PARENT(UndeterminedType, Object) - - TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } - const TypePtr element() const { return element_type_; } - void set_element(const TypePtr &element_type) { element_type_ = element_type; } - - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string ToReprString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - protected: - TypePtr element_type_; -}; -using MetaTensorTypePtr = std::shared_ptr; - -class TensorType : public Object { - public: - TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} - explicit TensorType(const TypePtr &ele) - : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} - ~TensorType() override = default; - MS_DECLARE_PARENT(TensorType, Object) - - TypeId generic_type_id() const override { return kObjectTypeTensorType; } - const TypePtr element() const { return element_type_; } - void set_element(const TypePtr &element_type) { element_type_ = element_type; } - - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string ToReprString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - private: - TypePtr element_type_; -}; -using TensorTypePtr = std::shared_ptr; - -class IndexedSlicesType : public Object { - public: - IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {} - explicit IndexedSlicesType(const TypePtr &ele) - : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {} - ~IndexedSlicesType() override = default; - MS_DECLARE_PARENT(IndexedSlicesType, Object) - - TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; } - const TypePtr element() const { return element_type_; } - void set_element(const TypePtr &element_type) { element_type_ = element_type; } - - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string ToReprString() const override; - std::string DumpText() const override; - bool operator==(const Type &other) const override; - - private: - TypePtr element_type_; -}; -using IndexedSlicesTypePtr = std::shared_ptr; - -class Function : public Object { - public: - Function(); - Function(const std::vector &args, const TypePtr retval); - ~Function() override = default; - MS_DECLARE_PARENT(Function, Object) - - TypeId generic_type_id() const override { return kObjectTypeFunction; } - - // Add temporarily for return abstraction to avoid type checking. - bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } - const std::vector &args() const { return args_; } - const TypePtr &retval() const { return retval_; } - - TypePtr DeepCopy() const override; - bool operator==(const Type &other) const override; - std::string ToString() const override; - std::string ToReprString() const override { return "function"; } - - private: - std::vector args_; - TypePtr retval_; -}; -using FunctionPtr = std::shared_ptr; - -class JTagged : public Object { - public: - JTagged() : Object(kObjectTypeJTagged) {} - explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} - ~JTagged() override = default; - MS_DECLARE_PARENT(JTagged, Object) - - TypeId generic_type_id() const override { return kObjectTypeJTagged; } - - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string DumpText() const override; - - private: - TypePtr subtype_; -}; -using JTaggedPtr = std::shared_ptr; - -class SymbolicKeyType : public Object { - public: - SymbolicKeyType() : Object(kObjectTypeSymbolicKeyType) {} - ~SymbolicKeyType() override = default; - MS_DECLARE_PARENT(SymbolicKeyType, Object) - - TypeId generic_type_id() const override { return kObjectTypeSymbolicKeyType; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToReprString() const override { return "symbolic_key"; } - std::string DumpText() const override { return "SymType"; } -}; - -class EnvType : public Object { - public: - EnvType() : Object(kObjectTypeEnvType) {} - ~EnvType() override = default; - MS_DECLARE_PARENT(EnvType, Object) - - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToReprString() const override { return "env_type"; } - std::string DumpText() const override { return "EnvType"; } -}; -using EnvTypePtr = std::shared_ptr; - -class TypeType : public Type { - public: - TypeType() : Type(kMetaTypeTypeType) {} - ~TypeType() override = default; - MS_DECLARE_PARENT(TypeType, Type) - - TypeId generic_type_id() const override { return kMetaTypeTypeType; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToReprString() const override { return "type_type"; } - std::string DumpText() const override { return "TypeType"; } -}; -using TypeTypePtr = std::shared_ptr; - -class Problem : public Type { - public: - Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} - explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {} - ~Problem() override = default; - MS_DECLARE_PARENT(Problem, Type) - - TypeId generic_type_id() const override { return kMetaTypeProblem; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToString() const override { return kind_.name(); } - std::string DumpText() const override { return "ProblemType"; } - - friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr problem); - - private: - Named kind_; -}; -using ProblemPtr = std::shared_ptr; - -class External : public Type { - public: - External() : Type(kMetaTypeExternal) {} - ~External() override = default; - MS_DECLARE_PARENT(External, Type) - - TypeId generic_type_id() const override { return kMetaTypeExternal; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string DumpText() const override { return "ExternalType"; } - - private: - TypePtr kind; -}; -using ExternalPtr = std::shared_ptr; - -// helper template -template -TypePtr Clone(const T &t) { - return t.Clone(); -} - -TypePtr StringToType(const std::string &type_name); - -// Judge whether x is predicate or is a subclass of predicate. -bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); - -bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type); - -// Whether t1 is identity or a subclass of t2. -bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); - -struct TypeHasher { - std::size_t operator()(TypePtr const &type) const; -}; -struct TypeListHasher { - std::size_t operator()(const TypePtrList &type_list) const; -}; -struct TypeEqual { - bool operator()(TypePtr const &t1, TypePtr const &t2) const; -}; -struct TypeListEqual { - bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const; -}; - -extern const TypePtr kTypeExternal; -extern const TypePtr kTypeEnv; -extern const TypePtr kTypeType; -extern const TypePtr kString; -extern const TypePtr kList; -extern const TypePtr kTuple; -extern const TypePtr kDict; -extern const TypePtr kSlice; -extern const TypePtr kKeyword; -extern const TypePtr kTensorType; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_DTYPE_H_ diff --git a/mindspore/ccsrc/ir/dtype/container.h b/mindspore/ccsrc/ir/dtype/container.h deleted file mode 100644 index 0612d24c4d..0000000000 --- a/mindspore/ccsrc/ir/dtype/container.h +++ /dev/null @@ -1,150 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/base.h" -#include "ir/named.h" -#include "ir/dtype/type.h" - -namespace mindspore { -// TypeRefKey type - -// List -class List : public Object { - public: - List() : Object(kObjectTypeList) {} - List(const std::initializer_list &objs) - : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} - // Shadow copy; - explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {} - ~List() override {} - MS_DECLARE_PARENT(List, Object) - - const TypePtr operator[](size_t dim) const; - TypeId generic_type_id() const override { return kObjectTypeList; } - TypePtr DeepCopy() const override; - - bool operator==(const Type &other) const override; - std::size_t size() const { return elements_.size(); } - TypePtrList elements() const { return elements_; } - std::string ToString() const override; - std::string ToReprString() const override { return "list_"; } - std::string DumpText() const override; - - private: - TypePtrList elements_; -}; -using ListPtr = std::shared_ptr; - -using ClassAttrVector = std::vector>; - -class Class : public Object { - public: - Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} - Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map &methods); - ~Class() override {} - MS_DECLARE_PARENT(Class, Object) - - TypeId generic_type_id() const override { return kObjectTypeClass; } - - bool operator==(const Type &other) const override; - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string DumpText() const override; - void set_value(const std::unordered_map &v) { attributes_value_ = v; } - - Named tag() { return tag_; } - std::unordered_map GetValue() { return attributes_value_; } - std::unordered_map methods() { return methods_; } - ClassAttrVector &GetAttributes() { return attributes_; } - - ClassAttrVector attributes_; - - private: - Named tag_; - std::unordered_map methods_; - // For AbstractClass build value - std::unordered_map attributes_value_; -}; -using ClassPtr = std::shared_ptr; - -class Tuple : public Object { - public: - Tuple() : Object(kObjectTypeTuple) {} - // usage : Tuple t = {std::make_shared(), std::make_shared(32)}; - Tuple(const std::initializer_list &objs) - : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} - - // Shadow copy - explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} - - ~Tuple() override {} - MS_DECLARE_PARENT(Tuple, Object) - - TypeId generic_type_id() const override { return kObjectTypeTuple; } - TypePtr DeepCopy() const override; - - std::string ToString() const override; - std::string ToReprString() const override { return "tuple_"; } - std::string DumpText() const override; - const TypePtr operator[](size_t dim) const; - bool operator==(const Type &other) const override; - - TypePtrList elements() const { return elements_; } - std::size_t size() const { return elements_.size(); } - - private: - TypePtrList elements_; -}; -using TuplePtr = std::shared_ptr; - -class Dictionary : public Object { - public: - Dictionary() : Object(kObjectTypeDictionary) {} - explicit Dictionary(const std::vector> &key_values) - : Object(kObjectTypeDictionary, false), key_values_(key_values) {} - - ~Dictionary() override {} - MS_DECLARE_PARENT(Dictionary, Object) - - TypeId generic_type_id() const override { return kObjectTypeDictionary; } - - bool operator==(const Type &other) const override; - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string DumpText() const override; - - private: - std::vector> key_values_; -}; -using DictionaryPtr = std::shared_ptr; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ diff --git a/mindspore/ccsrc/ir/dtype/empty.h b/mindspore/ccsrc/ir/dtype/empty.h deleted file mode 100644 index e3b46ec7d9..0000000000 --- a/mindspore/ccsrc/ir/dtype/empty.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/base.h" -#include "ir/named.h" -#include "ir/dtype/type.h" - -namespace mindspore { -class TypeAnything : public Type { - public: - TypeAnything() : Type(kMetaTypeAnything) {} - ~TypeAnything() override {} - MS_DECLARE_PARENT(TypeAnything, Type) - - TypeId generic_type_id() const override { return kMetaTypeAnything; } - TypePtr DeepCopy() const override; - std::string DumpText() const override { return "AnythingType"; } -}; -using TypeAnythingPtr = std::shared_ptr; - -class TypeNone : public Type { - public: - TypeNone() : Type(kMetaTypeNone) {} - ~TypeNone() override {} - MS_DECLARE_PARENT(TypeNone, Type) - - TypeId generic_type_id() const override { return kMetaTypeNone; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToReprString() const override { return "type_none"; } - std::string DumpText() const override { return "NoneType"; } -}; -using TypeNonePtr = std::shared_ptr; - -class TypeNull : public Type { - public: - TypeNull() : Type(kMetaTypeNull) {} - ~TypeNull() override {} - MS_DECLARE_PARENT(TypeNull, Type) - - TypeId generic_type_id() const override { return kMetaTypeNull; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string DumpText() const override { return "NullType"; } -}; -using TypeNullPtr = std::shared_ptr; - -class TypeEllipsis : public Type { - public: - TypeEllipsis() : Type(kMetaTypeEllipsis) {} - ~TypeEllipsis() override {} - MS_DECLARE_PARENT(TypeEllipsis, Type) - - TypeId generic_type_id() const override { return kMetaTypeEllipsis; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToReprString() const override { return "Ellipsis"; } - std::string DumpText() const override { return "Ellipsis"; } -}; -using TypeEllipsisPtr = std::shared_ptr; - -extern const TypePtr kTypeNone; -extern const TypePtr kTypeNull; -extern const TypePtr kTypeEllipsis; -extern const TypePtr kAnyType; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ diff --git a/mindspore/ccsrc/ir/dtype/number.h b/mindspore/ccsrc/ir/dtype/number.h deleted file mode 100644 index f8a746f8d6..0000000000 --- a/mindspore/ccsrc/ir/dtype/number.h +++ /dev/null @@ -1,154 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/base.h" -#include "ir/named.h" -#include "ir/dtype/type.h" - -namespace mindspore { -// Number, abstract class. -class Number : public Object { - public: - Number() : Object(kObjectTypeNumber), number_type_(kObjectTypeNumber), nbits_(0) {} - Number(const TypeId number_type, const int nbits, bool is_generic = true) - : Object(kObjectTypeNumber, is_generic), number_type_(number_type), nbits_(nbits) {} - ~Number() override = default; - MS_DECLARE_PARENT(Number, Object) - - int nbits() const { return nbits_; } - - TypeId number_type() const override { return number_type_; } - TypeId type_id() const override { return number_type_; } - TypeId generic_type_id() const override { return kObjectTypeNumber; } - - bool operator==(const Type &other) const override; - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToString() const override { return "Number"; } - std::string ToReprString() const override { return "number"; } - std::string DumpText() const override { return "Number"; } - std::string GetTypeName(const std::string &type_name) const { - std::ostringstream oss; - oss << type_name; - if (nbits() != 0) { - oss << nbits(); - } - return oss.str(); - } - - private: - const TypeId number_type_; - const int nbits_; -}; - -// Bool -class Bool : public Number { - public: - Bool() : Number(kNumberTypeBool, 8) {} - ~Bool() override = default; - MS_DECLARE_PARENT(Bool, Number) - - TypeId generic_type_id() const override { return kNumberTypeBool; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToString() const override { return "Bool"; } - std::string ToReprString() const override { return "bool"; } - std::string DumpText() const override { return "Bool"; } -}; - -// Int -class Int : public Number { - public: - Int() : Number(kNumberTypeInt, 0) {} - explicit Int(const int nbits); - ~Int() override = default; - MS_DECLARE_PARENT(Int, Number) - TypeId generic_type_id() const override { return kNumberTypeInt; } - TypePtr DeepCopy() const override { return std::make_shared(nbits()); } - std::string ToString() const override { return GetTypeName("Int"); } - std::string ToReprString() const override { return nbits() == 0 ? "int_" : GetTypeName("int"); } - std::string DumpText() const override { - return nbits() == 0 ? std::string("Int") : std::string("I") + std::to_string(nbits()); - } -}; - -// UInt -class UInt : public Number { - public: - UInt() : Number(kNumberTypeUInt, 0) {} - explicit UInt(const int nbits); - TypeId generic_type_id() const override { return kNumberTypeUInt; } - - ~UInt() override {} - MS_DECLARE_PARENT(UInt, Number) - - TypePtr DeepCopy() const override { return std::make_shared(nbits()); } - std::string ToString() const override { return GetTypeName("UInt"); } - std::string ToReprString() const override { return GetTypeName("uint"); } - std::string DumpText() const override { - return nbits() == 0 ? std::string("UInt") : std::string("U") + std::to_string(nbits()); - } -}; - -// Float -class Float : public Number { - public: - Float() : Number(kNumberTypeFloat, 0) {} - explicit Float(const int nbits); - ~Float() override {} - MS_DECLARE_PARENT(Float, Number) - - TypeId generic_type_id() const override { return kNumberTypeFloat; } - TypePtr DeepCopy() const override { return std::make_shared(nbits()); } - std::string ToString() const override { return GetTypeName("Float"); } - std::string ToReprString() const override { return nbits() == 0 ? "float_" : GetTypeName("float"); } - std::string DumpText() const override { - return nbits() == 0 ? std::string("Float") : std::string("F") + std::to_string(nbits()); - } -}; - -extern const TypePtr kBool; -extern const TypePtr kInt8; -extern const TypePtr kInt16; -extern const TypePtr kInt32; -extern const TypePtr kInt64; -extern const TypePtr kUInt8; -extern const TypePtr kUInt16; -extern const TypePtr kUInt32; -extern const TypePtr kUInt64; -extern const TypePtr kFloat16; -extern const TypePtr kFloat32; -extern const TypePtr kFloat64; -extern const TypePtr kInt; -extern const TypePtr kUInt; -extern const TypePtr kFloat; -extern const TypePtr kNumber; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ diff --git a/mindspore/ccsrc/ir/dtype/ref.h b/mindspore/ccsrc/ir/dtype/ref.h deleted file mode 100644 index 7d8159289f..0000000000 --- a/mindspore/ccsrc/ir/dtype/ref.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_DTYPE_REF_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_REF_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/base.h" -#include "ir/named.h" -#include "ir/dtype/type.h" - -namespace mindspore { -// TypeRefKey type -class RefKeyType : public Object { - public: - RefKeyType() : Object(kObjectTypeRefKey) {} - ~RefKeyType() override {} - MS_DECLARE_PARENT(RefKeyType, Object) - - TypeId generic_type_id() const override { return kObjectTypeRefKey; } - TypePtr DeepCopy() const override { return std::make_shared(); } - std::string ToReprString() const override { return "type_refkey"; } - std::string DumpText() const override { return "RefKeyType"; } -}; - -// TypeRef type -class RefType : public Object { - public: - RefType() : Object(kObjectTypeRef) {} - RefType(const TypePtr &subtype, const TypePtr &subtype_origin) - : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} - ~RefType() override {} - MS_DECLARE_PARENT(RefType, Object) - - TypePtr subtype() const { return subtype_; } - TypeId generic_type_id() const override { return kObjectTypeRef; } - TypePtr DeepCopy() const override; - std::string ToString() const override; - std::string DumpText() const override; - - private: - TypePtr subtype_; - TypePtr subtype_origin_; -}; -using RefTypePtr = std::shared_ptr; - -extern const TypePtr kRefKeyType; -extern const TypePtr kRefType; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_DTYPE_REF_H_ diff --git a/mindspore/ccsrc/ir/dtype/type.h b/mindspore/ccsrc/ir/dtype/type.h deleted file mode 100644 index cba0d17fce..0000000000 --- a/mindspore/ccsrc/ir/dtype/type.h +++ /dev/null @@ -1,127 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/base.h" -#include "ir/named.h" -#include "ir/dtype/type_id.h" - -namespace mindspore { - -TypeId IntBitsToTypeId(const int nbits); -TypeId UIntBitsToTypeId(const int nbits); -TypeId FloatBitsToTypeId(const int nbits); -const char *TypeIdLabel(const TypeId &v); -TypeId NormalizeTypeId(const TypeId type_id); -bool IsSameObjectType(const Type &lhs, const Type &rhs); -size_t GetTypeByte(const TypePtr &type_ptr); - -// Base class for all types -// forward declaration. - -class Type : public Value { - public: - Type() : meta_type_(kMetaTypeType), is_generic_(true) {} - explicit Type(TypeId t, bool is_generic = true) : meta_type_(t), is_generic_(is_generic) {} - ~Type() override = default; - MS_DECLARE_PARENT(Type, Value) - - bool operator==(const Value &other) const override; - TypeId meta_type() const { return meta_type_; } - - virtual TypeId type_id() const { return meta_type_; } - virtual TypeId generic_type_id() const { return kMetaTypeType; } - - virtual bool operator!=(const Type &other) const { return !(*this == other); } - virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); } - virtual bool equal(const TypePtr other) const { return *this == *other; } - - virtual TypeId object_type() const { return kTypeUnknown; } - virtual TypeId parent_type() const { return kTypeUnknown; } - virtual TypeId number_type() const { return kTypeUnknown; } - virtual TypePtr DeepCopy() const = 0; - virtual TypePtr Clone() const { return DeepCopy(); } - - std::size_t hash() const override { return std::hash{}(static_cast(type_id())); } - - std::string ToString() const override { return TypeIdLabel(meta_type_); } - virtual std::string ToReprString() const { return ToString(); } - std::string ReprString() const { return "mindspore." + ToReprString(); } - void dump() const override { std::cout << ToString() << std::endl; } - bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } - bool IsGeneric() const { return is_generic_; } - abstract::AbstractBasePtr ToAbstract() override; - friend std::ostream &operator<<(std::ostream &os, const Type &type); - friend std::ostream &operator<<(std::ostream &os, const TypePtr type); - - const bool parse_info_ = true; - - private: - TypeId meta_type_; - bool is_generic_; -}; - -using TypePtrList = std::vector; - -// -// Base class for normal objects -// -class Object : public Type { - public: - Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {} - explicit Object(const TypeId object_type, bool is_generic = true) - : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {} - explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true) - : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {} - ~Object() override = default; - MS_DECLARE_PARENT(Object, Type) - - TypeId object_type() const override { return object_type_; } - TypeId parent_type() const override { return parent_type_; } - TypeId type_id() const override { return object_type_; } - TypeId generic_type_id() const override { return kMetaTypeObject; } - bool equal(const TypePtr other) const override; - std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } - - friend std::ostream &operator<<(std::ostream &os, const Object &obj); - friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr obj); - - private: - const TypeId object_type_; - const TypeId parent_type_; -}; - -std::ostream &operator<<(std::ostream &os, const TypePtrList &types); -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ diff --git a/mindspore/ccsrc/ir/dtype/type_extends.cc b/mindspore/ccsrc/ir/dtype/type_extends.cc deleted file mode 100644 index a77a6a9cba..0000000000 --- a/mindspore/ccsrc/ir/dtype/type_extends.cc +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/dtype/type.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -abstract::AbstractBasePtr Type::ToAbstract() { - auto ptr = std::make_shared(shared_from_base()); - return ptr; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/dtype/type_id.h b/mindspore/ccsrc/ir/dtype/type_id.h deleted file mode 100644 index a711779e91..0000000000 --- a/mindspore/ccsrc/ir/dtype/type_id.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ -#define MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ - -#include -#include - -namespace mindspore { -// -// Supported meta type -// -enum TypeId : int { - kTypeUnknown = 0, - kMetaTypeBegin = kTypeUnknown, - kMetaTypeType, // Type - kMetaTypeAnything, - kMetaTypeObject, - kMetaTypeTypeType, // TypeType - kMetaTypeProblem, - kMetaTypeExternal, - kMetaTypeNone, - kMetaTypeNull, - kMetaTypeEllipsis, - kMetaTypeEnd, - // - // Object types - // - kObjectTypeBegin = kMetaTypeEnd, - kObjectTypeNumber, - kObjectTypeString, - kObjectTypeList, - kObjectTypeTuple, - kObjectTypeSlice, - kObjectTypeKeyword, - kObjectTypeTensorType, - kObjectTypeIndexedSlicesType, - kObjectTypeUndeterminedType, - kObjectTypeClass, - kObjectTypeDictionary, - kObjectTypeFunction, - kObjectTypeJTagged, - kObjectTypeSymbolicKeyType, - kObjectTypeEnvType, - kObjectTypeRefKey, - kObjectTypeRef, - kObjectTypeEnd, - // - // Number Types - // - kNumberTypeBegin = kObjectTypeEnd, - kNumberTypeBool, - kNumberTypeInt, - kNumberTypeInt8, - kNumberTypeInt16, - kNumberTypeInt32, - kNumberTypeInt64, - kNumberTypeUInt, - kNumberTypeUInt8, - kNumberTypeUInt16, - kNumberTypeUInt32, - kNumberTypeUInt64, - kNumberTypeFloat, - kNumberTypeFloat16, - kNumberTypeFloat32, - kNumberTypeFloat64, - kNumberTypeEnd -}; -// -// TypeId name map -// -const std::unordered_map type_name_map = { - {kNumberTypeBool, "Bool"}, {kNumberTypeInt8, "Int8"}, {kNumberTypeUInt8, "UInt8"}, - {kNumberTypeInt16, "Int16"}, {kNumberTypeInt32, "Int32"}, {kNumberTypeInt64, "Int64"}, - {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat32, "Float32"}, {kNumberTypeFloat64, "Float64"}}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ diff --git a/mindspore/ccsrc/ir/dtype_extends.cc b/mindspore/ccsrc/ir/dtype_extends.cc deleted file mode 100644 index 732872cb4f..0000000000 --- a/mindspore/ccsrc/ir/dtype_extends.cc +++ /dev/null @@ -1,568 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/dtype.h" -#include -#include -#include -#include "utils/log_adapter.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" - -namespace mindspore { -TypePtr TypeAnything::DeepCopy() const { return kAnyType; } - -std::size_t TypeHasher::operator()(TypePtr const &type) const { - MS_EXCEPTION_IF_NULL(type); - std::size_t hash = std::hash()(type->type_id()); - return hash; -} - -std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { - std::size_t hash_sum = 0; - for (auto &type : type_list) { - auto type_id = static_cast(type->type_id()); - hash_sum = hash_combine(hash_sum, type_id); - } - return hash_sum; -} - -bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { - MS_EXCEPTION_IF_NULL(t1); - MS_EXCEPTION_IF_NULL(t2); - return t1->type_id() == t2->type_id(); -} - -bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { - if (lhs.size() != rhs.size()) { - return false; - } - std::size_t size = lhs.size(); - for (std::size_t i = 0; i < size; ++i) { - MS_EXCEPTION_IF_NULL(lhs[i]); - MS_EXCEPTION_IF_NULL(rhs[i]); - if (*lhs[i] != *rhs[i]) { - return false; - } - } - return true; -} - -TypePtr TypeIdToType(TypeId id) { - switch (id) { - case kNumberTypeFloat16: - return kFloat16; - case kNumberTypeFloat: - case kNumberTypeFloat32: - return kFloat32; - case kNumberTypeFloat64: - return kFloat64; - case kNumberTypeInt8: - return kInt8; - case kNumberTypeInt16: - return kInt16; - case kNumberTypeInt32: - return kInt32; - case kNumberTypeInt64: - return kInt64; - case kNumberTypeUInt8: - return kUInt8; - case kNumberTypeUInt16: - return kUInt16; - case kNumberTypeUInt32: - return kUInt32; - case kNumberTypeUInt64: - return kUInt64; - case kNumberTypeBool: - return kBool; - case kMetaTypeExternal: - return kTypeExternal; - case kMetaTypeAnything: - return kAnyType; - case kMetaTypeNone: - return kTypeNone; - case kMetaTypeNull: - return kTypeNull; - case kMetaTypeEllipsis: - return kTypeEllipsis; - case kObjectTypeEnvType: - return kTypeEnv; - case kObjectTypeRefKey: - return kRefKeyType; - case kObjectTypeRef: - return kRefType; - case kMetaTypeTypeType: - return kTypeType; - case kObjectTypeString: - return kString; - case kObjectTypeList: - return kList; - case kObjectTypeTuple: - return kTuple; - case kObjectTypeDictionary: - return kDict; - case kObjectTypeSlice: - return kSlice; - case kObjectTypeKeyword: - return kKeyword; - case kTypeUnknown: - return kTypeNone; - default: - MS_LOG(EXCEPTION) << "Not support the type: " << id; - } -} - -namespace { -template -TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { - TypePtr type = nullptr; - if (type_name == num_type_name) { - type = std::make_shared(); - } else { - try { - if (num_type_name.size() >= type_name.size()) { - MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name - << ")"; - } - auto bits = std::stoi(type_name.substr(num_type_name.size())); - type = std::make_shared(bits); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what(); - } - } - return type; -} - -std::vector StringToVectorOfType(const std::string &type_names) { - std::vector types; - if (type_names.length() == 0) { - return types; - } - std::string::size_type start = 0; - std::string::size_type end = type_names.find_first_of(','); - while (end != std::string::npos) { - types.push_back(StringToType(type_names.substr(start, end))); - // Skip ',' to find the next element. - start = end + 1; - end = type_names.find_first_of(',', start); - } - if (start >= type_names.size()) { - MS_LOG(EXCEPTION) << "Type name is empty string."; - } - types.push_back(StringToType(type_names.substr(start))); - return types; -} - -TypePtr TensorStrToType(const std::string &type_name) { - TypePtr type = nullptr; - if (type_name == "Tensor") { - type = std::make_shared(); - } else { - try { - auto start = type_name.find_first_of('[') + 1; - auto end = type_name.find_last_of(']'); - if (start >= type_name.size()) { - return nullptr; - } - auto element_str = type_name.substr(start, end - start); - auto element_type = StringToType(element_str); - if (element_type == nullptr) { - return nullptr; - } - type = std::make_shared(element_type); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); - } - } - - return type; -} - -TypePtr IndexedSlicesStrToType(const std::string &type_name) { - if (type_name == "IndexedSlices") { - return std::make_shared(); - } - auto start = type_name.find_first_of('[') + 1; - auto end = type_name.find_last_of(']'); - if (start >= type_name.size()) { - return nullptr; - } - auto element_str = type_name.substr(start, end - start); - auto element_type = StringToType(element_str); - if (element_type == nullptr) { - return nullptr; - } - return std::make_shared(element_type); -} - -TypePtr UndeterminedStrToType(const std::string &type_name) { - if (type_name == "Undetermined") { - return std::make_shared(); - } - auto start = type_name.find_first_of('[') + 1; - auto end = type_name.find_last_of(']'); - if (start >= type_name.size()) { - return nullptr; - } - auto element_str = type_name.substr(start, end - start); - auto element_type = StringToType(element_str); - if (element_type == nullptr) { - return nullptr; - } - return std::make_shared(element_type); -} - -TypePtr ListStrToType(const std::string &type_name) { - TypePtr type = nullptr; - if (type_name == "List") { - type = std::make_shared(); - } else { - try { - auto start = type_name.find_first_of('[') + 1; - auto end = type_name.find_last_of(']'); - if (start >= type_name.size()) { - return nullptr; - } - std::string element_strs = type_name.substr(start, end - start); - std::vector element_types = StringToVectorOfType(element_strs); - bool wrong = - std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); - if (wrong) { - return nullptr; - } - type = std::make_shared(element_types); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); - } - } - - return type; -} - -TypePtr TupleStrToType(const std::string &type_name) { - TypePtr type = nullptr; - if (type_name == "Tuple") { - type = std::make_shared(); - } else { - try { - size_t start = type_name.find_first_of('[') + 1; - size_t end = type_name.find_last_of(']'); - if (start >= type_name.size()) { - return nullptr; - } - std::string element_strs = type_name.substr(start, end - start); - std::vector element_types = StringToVectorOfType(element_strs); - bool wrong = - std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); - if (wrong) { - return nullptr; - } - type = std::make_shared(element_types); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); - } - } - return type; -} - -TypePtr FunctionStrToType(const std::string &type_name) { - TypePtr type = nullptr; - - if (type_name == "Function") { - type = std::make_shared(); - } else { - try { - // format: [(para1, para2, para3, ...) retval] - size_t start = type_name.find_first_of('[') + 1; - size_t end = type_name.find_last_of(']'); - if (start >= type_name.size()) { - return nullptr; - } - std::string str_all = type_name.substr(start, end - start); - size_t start_a = str_all.find_first_of('(') + 1; - size_t end_a = str_all.find_last_of(')'); - if (start_a >= str_all.size()) { - return nullptr; - } - std::string str_args = str_all.substr(start_a, end_a - start_a); - // bypass " " between ")" and retval - start = end_a + 2; - if (start >= str_all.size()) { - return nullptr; - } - std::string str_retval = str_all.substr(start); - - std::vector args_type = StringToVectorOfType(str_args); - TypePtr retval = StringToType(str_retval); - bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); - if (retval == nullptr || wrong) { - return nullptr; - } - type = std::make_shared(args_type, retval); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); - } - } - return type; -} -} // namespace - -TypePtr StringToType(const std::string &type_name) { - TypePtr type = nullptr; - if (type_name.compare("None") == 0) { - type = std::make_shared(); - } else if (type_name.compare("Ellipsis") == 0) { - type = std::make_shared(); - } else if (type_name.compare("TypeType") == 0) { - type = std::make_shared(); - } else if (type_name.compare("SymbolicKeyType") == 0) { - type = std::make_shared(); - } else if (type_name.compare("RefKeyType") == 0) { - type = std::make_shared(); - } else if (type_name.compare("EnvType") == 0) { - type = std::make_shared(); - } else if (type_name.compare("Number") == 0) { - type = std::make_shared(); - } else if (type_name.compare("Bool") == 0) { - type = std::make_shared(); - } else if (type_name.compare(0, strlen("Int"), "Int") == 0) { - type = StringToNumberType(type_name, "Int"); - } else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) { - type = StringToNumberType(type_name, "UInt"); - } else if (type_name.compare(0, strlen("Float"), "Float") == 0) { - type = StringToNumberType(type_name, "Float"); - } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) { - type = TensorStrToType(type_name); - } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) { - type = UndeterminedStrToType(type_name); - } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) { - type = IndexedSlicesStrToType(type_name); - } else if (type_name.compare(0, strlen("List"), "List") == 0) { - type = ListStrToType(type_name); - } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { - type = TupleStrToType(type_name); - } else if (type_name.compare("Slice") == 0) { - type = std::make_shared(); - } else if (type_name.compare("Dictionary") == 0) { - type = std::make_shared(); - } else if (type_name.compare("String") == 0) { - type = std::make_shared(); - } else if (type_name.compare("Problem") == 0) { - type = std::make_shared(); - } else if (type_name.compare(0, strlen("Function"), "Function") == 0) { - type = FunctionStrToType(type_name); - } else { - // - unsupported to convert - // Class - // SymbolicType - // JTagged - // Anything - // External - // Problem - MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!"; - } - return type; -} - -bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) { - if (x == nullptr || base_type == nullptr) { - MS_LOG(ERROR) << "Type is nullptr."; - return false; - } - if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { - return false; - } - if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) { - return true; - } - return false; -} - -bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { - if (x == nullptr || base_type == nullptr) { - MS_LOG(ERROR) << "Type is nullptr."; - return false; - } - if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { - return false; - } else if (!(base_type->IsGeneric())) { - return *(base_type) == *(x); - } else if (base_type->type_id() == x->type_id()) { - return true; - } else if (base_type->type_id() == x->generic_type_id()) { - return true; - } else if (base_type->type_id() == x->object_type()) { - return true; - } else if (base_type->type_id() == x->meta_type()) { - return true; - } else { - return false; - } -} - -bool IsSubType(TypePtr const &t1, TypePtr const &t2) { - MS_EXCEPTION_IF_NULL(t1); - if (t1->type_id() == kTypeUnknown) { - return false; - } else if (t2 != nullptr) { - return IsIdentidityOrSubclass(t1, t2); - } else { - return true; - } -} - -REGISTER_PYBIND_DEFINE( - typing, ([](py::module *const m) { - auto m_sub = m->def_submodule("typing", "submodule for dtype"); - py::enum_(m_sub, "TypeId"); - (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); - (void)m_sub.def("load_type", &TypeIdToType, "load type"); - (void)m_sub.def( - "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); - (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); - (void)py::class_>(m_sub, "Type") - .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) - .def("__eq__", - [](const TypePtr &t1, const TypePtr &t2) { - if (t1 != nullptr && t2 != nullptr) { - return *t1 == *t2; - } - return false; - }) - .def("__hash__", &Type::hash) - .def("__str__", &Type::ToString) - .def("__repr__", &Type::ReprString) - .def("__deepcopy__", [](const TypePtr &t, py::dict) { - if (t == nullptr) { - return static_cast(nullptr); - } - return t->DeepCopy(); - }); - (void)py::class_>(m_sub, "Number").def(py::init()); - (void)py::class_>(m_sub, "Bool") - .def(py::init()) - .def(py::pickle( - [](const Bool &) { // __getstate__ - return py::make_tuple(); - }, - [](const py::tuple &) { // __setstate__ - return std::make_shared(); - })); - (void)py::class_>(m_sub, "Int") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const Int &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - Int data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "UInt") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const UInt &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - UInt data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "Float") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const Float &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - Float data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "List") - .def(py::init()) - .def(py::init>(), py::arg("elements")); - (void)py::class_>(m_sub, "Tuple") - .def(py::init()) - .def(py::init>(), py::arg("elements")); - (void)py::class_>(m_sub, "TensorType") - .def(py::init()) - .def(py::init(), py::arg("element")) - .def("element_type", &TensorType::element) - .def(py::pickle( - [](const TensorType &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - TensorType data(TypeIdToType(TypeId(static_cast(t[0].cast())))); - return data; - })); - (void)py::class_>(m_sub, "IndexedSlicesType") - .def(py::init()); - (void)py::class_>(m_sub, "UndeterminedType") - .def(py::init()); - (void)py::class_>(m_sub, "Function") - .def(py::init()) - .def(py::init, TypePtr>(), py::arg("args"), py::arg("retval")); - (void)py::class_>(m_sub, "Class").def(py::init()); - (void)py::class_>(m_sub, "SymbolicKeyType").def(py::init()); - (void)py::class_>(m_sub, "EnvType").def(py::init()); - (void)py::class_>(m_sub, "TypeNone").def(py::init()); - (void)py::class_>(m_sub, "TypeType").def(py::init()); - (void)py::class_>(m_sub, "String").def(py::init()); - (void)py::class_>(m_sub, "RefKeyType").def(py::init()); - (void)py::class_>(m_sub, "RefType").def(py::init()); - (void)py::class_>(m_sub, "TypeAnything").def(py::init()); - (void)py::class_>(m_sub, "Slice").def(py::init()); - (void)py::class_>(m_sub, "TypeEllipsis").def(py::init()); - })); - -const TypePtr kTypeExternal = std::make_shared(); -const TypePtr kTypeEnv = std::make_shared(); -const TypePtr kTypeType = std::make_shared(); -const TypePtr kTensorType = std::make_shared(); -const TypePtr kIndexedSlicesType = std::make_shared(); -const TypePtr kUndeterminedType = std::make_shared(); -const TypePtr kString = std::make_shared(); -const TypePtr kList = std::make_shared(); -const TypePtr kTuple = std::make_shared(); -const TypePtr kDict = std::make_shared(); -const TypePtr kSlice = std::make_shared(); -const TypePtr kKeyword = std::make_shared(); -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc deleted file mode 100644 index 4e01e9003f..0000000000 --- a/mindspore/ccsrc/ir/func_graph.cc +++ /dev/null @@ -1,628 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/func_graph.h" - -#include -#include -#include - -#include "debug/trace.h" -#include "ir/manager.h" -#include "operator/ops.h" -#include "pybind_api/export_flags.h" -#include "utils/ordered_set.h" -#include "utils/convert_utils_base.h" - -namespace mindspore { -/* - * Methods of Graph - */ -FuncGraph::FuncGraph() - : attrs_(), - transforms_(), - parameter_default_value_(), - seen_(0), - parameters_(), - has_vararg_(false), - has_kwarg_(false), - kwonlyargs_count_(0), - hyper_param_count_(0), - is_generated_(false), - return_(nullptr), - manager_(std::weak_ptr()) { - debug_info_ = std::make_shared(); -} - -AnfNodePtr FuncGraph::output() const { - // If return value is set, return should have two inputs. - if (return_ != nullptr && return_->inputs().size() == 2) { - return return_->input(1); - } else { - // If not set yet, return nullptr. - return nullptr; - } -} - -ParameterPtr FuncGraph::add_parameter() { - FuncGraphPtr this_func_graph = shared_from_base(); - ParameterPtr p = std::make_shared(this_func_graph); - add_parameter(p); - return p; -} - -void FuncGraph::add_parameter(const ParameterPtr &p) { - if (manager_.lock()) { - manager_.lock()->AddParameter(shared_from_base(), p); - } else { - parameters_.push_back(p); - } -} - -ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { - FuncGraphPtr this_graph = shared_from_base(); - ParameterPtr p = std::make_shared(this_graph); - p->set_name(name); - p->debug_info()->set_name(name); - - if (manager_.lock()) { - manager_.lock()->AddParameter(shared_from_base(), p); - } else { - parameters_.push_back(p); - } - hyper_param_count_++; - return p; -} - -bool FuncGraph::has_flag(const std::string &key) { - auto iter = attrs_.find(key); - if (iter != attrs_.cend()) { - if (iter->second->isa()) { - return GetValue(iter->second); - } - MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function."; - } - return false; -} - -bool FuncGraph::has_attr(const std::string &key) { - auto iter = attrs_.find(key); - return !(iter == attrs_.cend()); -} - -ValuePtr FuncGraph::get_attr(const std::string &key) { - auto iter = attrs_.find(key); - return iter == attrs_.cend() ? nullptr : iter->second; -} - -CNodePtr FuncGraph::NewCNode(const std::vector &inputs) { - CNodePtr cnode = std::make_shared(inputs, shared_from_base()); - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - order_.push_back(cnode); - MS_LOG(INFO) << "Graph: " << ToString() << ", push back " << cnode->DebugString() << " in order."; - } - return cnode; -} - -CNodePtr FuncGraph::NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope) { - CNodePtr app = NewCNode(inputs); - app->set_scope(scope); - return app; -} - -void FuncGraph::DumpCNodeList() { - MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; - for (const auto &cnode : order_) { - MS_LOG(INFO) << cnode->DebugString(); - } -} - -std::string FuncGraph::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); -} - -GraphDebugInfoPtr FuncGraph::debug_info() { - MS_EXCEPTION_IF_NULL(this->debug_info_); - if (this->debug_info_->get_graph() == nullptr) { - this->debug_info_->set_graph(shared_from_base()); - } - return this->debug_info_; -} - -const AnfNodeSet &FuncGraph::nodes() { return nodes_; } - -void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); } - -void FuncGraph::ClearNodes() { nodes_.clear(); } - -void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); } - -void FuncGraph::DropNode(AnfNodePtr node) { - nodes_.erase(node); - auto graph = node->func_graph(); - // Remove the node from order list. - if (graph) { - graph->EraseUnusedNodeInOrder(node); - } -} - -const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } - -void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) { - auto &others = source->value_nodes(); - for (auto it = others.begin(); it != others.end(); it++) { - AddValueNode(it->first, it->second); - } -} - -void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } - -void FuncGraph::AddValueNode(AnfNodePtr node, int count) { - if (value_nodes_.count(node) == 0) { - value_nodes_[node] = count; - } else { - value_nodes_[node] += count; - } -} - -void FuncGraph::DropValueNode(AnfNodePtr node) { - if (value_nodes_.count(node) != 0) { - if (value_nodes_[node] == 1) { - (void)value_nodes_.erase(node); - } else { - value_nodes_[node]--; - if (value_nodes_[node] < 0) { - MS_LOG(EXCEPTION) << "Count of ValueNode '" << node - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } -} - -const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } - -void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) { - auto &others = source->free_variables(); - for (auto it = others.begin(); it != others.end(); it++) { - if (it->first->func_graph().get() != this) { - (void)AddFreeVariable(it->first, it->second); - } - } -} - -void FuncGraph::ClearFreeVariables() { free_variables_.clear(); } - -bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) { - if (free_variables_.count(node) == 0) { - free_variables_[node] = count; - return true; - } else { - free_variables_[node] += count; - return false; - } -} - -bool FuncGraph::DropFreeVariable(AnfNodePtr node) { - if (free_variables_.count(node) != 0) { - if (free_variables_[node] == 1) { - (void)free_variables_.erase(node); - return true; - } else { - free_variables_[node]--; - if (free_variables_[node] < 0) { - MS_LOG(EXCEPTION) << "Count of free variable '" << node - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } - return false; -} - -const BaseRefCounterMap &FuncGraph::free_variables_total() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &fv_total = mng->free_variables_total(); - return fv_total[shared_from_base()]; -} - -std::vector FuncGraph::free_variables_nodes() { - std::vector nodes; - const auto &fv_total = this->free_variables_total(); - for (auto &p : fv_total) { - auto key = p.first; - if (utils::isa(key)) { - nodes.push_back(utils::cast(key)); - } - } - - return nodes; -} - -std::vector FuncGraph::free_variables_func_graphs() { - std::vector func_graphs; - const auto &fv_total = this->free_variables_total(); - for (auto &p : fv_total) { - auto key = p.first; - if (utils::isa(key)) { - func_graphs.push_back(utils::cast(key)); - } - } - - return func_graphs; -} - -const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; } - -void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) { - auto &others = source->func_graphs_used(); - for (auto it = others.begin(); it != others.end(); it++) { - (void)AddFuncGraphUsed(it->first, it->second); - } - func_graphs_used_.erase(source); -} - -void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); } - -bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) { - if (func_graphs_used_.count(fg) == 0) { - func_graphs_used_[fg] = count; - return true; - } else { - func_graphs_used_[fg] += count; - return false; - } -} - -bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) { - if (func_graphs_used_.count(fg) != 0) { - if (func_graphs_used_[fg] == 1) { - (void)func_graphs_used_.erase(fg); - return true; - } else { - func_graphs_used_[fg]--; - if (func_graphs_used_[fg] < 0) { - MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } - return false; -} - -const FuncGraphSet &FuncGraph::func_graphs_used_total() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &used = mng->func_graphs_used_total(shared_from_base()); - return used; -} - -const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } - -void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) { - auto &others = source->func_graph_cnodes_index(); - for (auto it = others.begin(); it != others.end(); it++) { - // Ignore the user graph who may own itself. - auto fg = it->first->first->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - if (fg.get() != this) { - AddFuncGraphCNodeIndex(it->first, it->second); - } - } -} - -void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); } - -void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) { - if (func_graph_cnodes_index_.count(pair) == 0) { - func_graph_cnodes_index_[pair] = count; - } else { - func_graph_cnodes_index_[pair] += count; - } -} - -void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { - if (func_graph_cnodes_index_.count(pair) != 0) { - if (func_graph_cnodes_index_[pair] == 1) { - (void)func_graph_cnodes_index_.erase(pair); - } else { - func_graph_cnodes_index_[pair]--; - if (func_graph_cnodes_index_[pair] < 0) { - MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } -} - -const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; } - -void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) { - auto &others = source->j_func_graphs(); - for (auto it = others.begin(); it != others.end(); it++) { - AddJFuncGraph(it->first, it->second); - } -} - -void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); } - -void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) { - if (j_func_graphs_.count(fg) == 0) { - j_func_graphs_[fg] = count; - } else { - j_func_graphs_[fg] += count; - } -} - -void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) { - if (j_func_graphs_.count(fg) != 0) { - if (j_func_graphs_[fg] == 1) { - (void)j_func_graphs_.erase(fg); - } else { - j_func_graphs_[fg]--; - if (j_func_graphs_[fg] < 0) { - MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } -} - -FuncGraphPtr FuncGraph::parent() { - // report the bug early. - if (manager_.lock() == nullptr) { - MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() - << " NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->parent(shared_from_base()); -} - -const FuncGraphSet &FuncGraph::children() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->children(shared_from_base()); -} - -const FuncGraphSet &FuncGraph::scope() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->scopes(shared_from_base()); -} - -bool FuncGraph::recursive() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->recursive(shared_from_base()); -} - -std::shared_ptr> FuncGraph::recursive_graphs() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->recursive_graphs(shared_from_base()); -} - -AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { - auto itr = this->parameter_default_value_.find(name); - if (itr == parameter_default_value_.end()) { - return nullptr; - } - auto default_value = itr->second; - if (default_value == nullptr) { - MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist"; - } - if (IsValueNode(default_value)) { - return nullptr; - } - return default_value; -} - -// set the default values -void FuncGraph::SetDefaultValues(const std::vector &name_list, const std::vector &value_list) { - auto all_is_null = - std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode(node); }); - if (value_list.empty()) { - all_is_null = true; - } - for (size_t i = 0; i < name_list.size(); ++i) { - if (!all_is_null) { - this->parameter_default_value_[name_list[i]] = value_list[i]; - } - } -} - -void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } - -size_t FuncGraph::GetDefaultValueCount() { - int null_count = - std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), - [](const std::pair &pair) { return IsValueNode(pair.second); }); - return parameter_default_value_.size() - IntToSize(null_count); -} - -AnfNodePtr FuncGraph::GetVariableArgParameter() { - if (!has_vararg_) { - return nullptr; - } - - if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 2) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 2]; - } - - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 1]; -} - -std::string FuncGraph::GetVariableArgName() { - if (!has_vararg_) { - return ""; - } - - if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 2) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 2]->cast()->name(); - } - - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast()->name(); -} - -AnfNodePtr FuncGraph::GetVariableKwargParameter() { - if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 1]; - } - return nullptr; -} - -std::string FuncGraph::GetVariableKwargName() { - if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast()->name(); - } - return ""; -} - -int FuncGraph::GetPositionalArgsCount() const { - int count = SizeToInt(parameters_.size()); - if (has_kwarg_) { - count--; - } - if (has_vararg_) { - count--; - } - return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); -} - -AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { - for (size_t i = 0; i < parameters_.size(); ++i) { - MS_EXCEPTION_IF_NULL(parameters_[i]); - auto param_cast = parameters_[i]->cast(); - MS_EXCEPTION_IF_NULL(param_cast); - if (param_cast->name() == name) { - return parameters_[i]; - } - } - return nullptr; -} - -void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } - -std::list FuncGraph::GetOrderedCnodes() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - MS_LOG(DEBUG) << "Return ordered cnodes."; - return order_; - } else { - auto this_ptr = shared_from_base(); - auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1); - auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1); - - std::list cnodes; - auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); - for (const auto &node : nodes) { - auto cnode = dyn_cast(node); - if (cnode) { - cnodes.push_back(cnode); - } - } - return cnodes; - } -} - -void FuncGraph::EraseUnusedNodeInOrder() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - auto mng = manager_.lock(); - if (mng) { - auto &all_nodes = nodes(); - // Erase unused cnode. - for (auto it = order_.begin(); it != order_.end();) { - if (all_nodes.count(*it)) { - (void)it++; - } else { - MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; - it = order_.erase(it); - } - } - } - } -} - -void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { - if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa()) { - order_.remove(n->cast()); - MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; - } -} - -void FuncGraph::CheckOrder() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - MS_LOG(DEBUG) << "Check graph " << ToString(); - for (auto it = order_.begin(); it != order_.end(); (void)it++) { - for (const auto &input_node : (*it)->inputs()) { - if (input_node && input_node->isa() && input_node->func_graph() == shared_from_base()) { - // Need to reorder the wrong order node. - auto found = std::find(order_.begin(), it, input_node); - if (found == it) { - DumpCNodeList(); - MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString() - << " doesn't obey the input dependency, " - << "as input " << input_node->DebugString() << " is not ahead of itself."; - } - } - } - } - auto mng = manager_.lock(); - if (mng != nullptr) { - const auto &all_nodes = nodes(); - if (all_nodes.size() != (order_.size() + parameters_.size())) { - DumpCNodeList(); - MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " - << all_nodes.size() - parameters_.size() << "."; - } - } - MS_LOG(DEBUG) << "Check order okay."; - } -} - -size_t NewFgSeenGeneration() { - static size_t fg_seen_generation = 0; - return ++fg_seen_generation; -} - -const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared("FuncGraph"); -const char kFuncGraphFlagUndetermined[] = "Undeterminate"; -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h deleted file mode 100644 index b1be892a53..0000000000 --- a/mindspore/ccsrc/ir/func_graph.h +++ /dev/null @@ -1,420 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ -#define MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/manager.h" -#include "utils/ordered_set.h" -#include "utils/ordered_map.h" -#include "utils/base_ref.h" - -namespace mindspore { -using BaseRefCounterMap = OrderedMap; -using FuncGraphCounterMap = OrderedMap; - -struct CNodeIndexHasher { - std::size_t operator()(const CNodeIndexPairPtr pair) const { - MS_EXCEPTION_IF_NULL(pair); - MS_EXCEPTION_IF_NULL(pair->first); - return hash_combine(pair->first->hash(), std::hash()(pair->second)); - } -}; - -struct CNodeIndexEqual { - bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { - if (lhs == nullptr || rhs == nullptr) { - return false; - } - if (lhs == rhs) { - return true; - } - if (lhs->first != rhs->first) { - return false; - } - if (lhs->second != rhs->second) { - return false; - } - return true; - } -}; - -template , class CounterEqual = std::equal_to> -using CounterOrderedMap = OrderedMap; -using AnfNodeCounterMap = CounterOrderedMap; -using CNodeIndexCounterMap = CounterOrderedMap; - -using FuncGraphMap = OrderedMap; - -const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; -const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; -const char FUNC_GRAPH_FLAG_CORE[] = "core"; -const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; -const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; - -namespace abstract { -class AbstractKeywordArg; -using AbstractKeywordArgPtr = std::shared_ptr; -class AbstractFunction; -using AbstractFunctionPtr = std::shared_ptr; -} // namespace abstract - -// ANF transform class -// either a primitive or a func_graph -class FuncGraphTransform { - public: - enum Type { kGtPrimitive, kGtFuncGraph }; - - explicit FuncGraphTransform(const PrimitivePtr prim, const FuncGraphPtr func_graph = nullptr) - : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {} - - explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_) - : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {} - - FuncGraphTransform(const FuncGraphTransform &t) : prim_(t.prim_), func_graph_(t.func_graph_) {} - - ~FuncGraphTransform() = default; - - Type type() const { - if (IsFuncGraph()) { - return kGtFuncGraph; - } else { - return kGtPrimitive; - } - } - - bool IsPrimitive() const { return (func_graph_.lock() == nullptr); } - bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); } - FuncGraphPtr func_graph() const { return func_graph_.lock(); } - PrimitivePtr primitive() const { return prim_; } - - FuncGraphTransform &operator=(const FuncGraphTransform &t) { - if (this != &t) { - prim_ = t.prim_; - func_graph_ = t.func_graph_; - } - return *this; - } - - private: - PrimitivePtr prim_; - // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here. - // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in - // FPropRemapper::FinalizeGraph(). - FuncGraphWeakPtr func_graph_; - static const PrimitivePtr func_graph_prim_; -}; - -class FuncGraphBase : public Value { - public: - FuncGraphBase() = default; - - ~FuncGraphBase() override = default; - MS_DECLARE_PARENT(FuncGraphBase, Value); -}; - -extern const char kFuncGraphFlagUndetermined[]; - -class FuncGraph : public FuncGraphBase { - public: - FuncGraph(); - - ~FuncGraph() override = default; - MS_DECLARE_PARENT(FuncGraph, FuncGraphBase); - - // get the graph's abstract - abstract::AbstractFunctionPtr abstract(); - abstract::AbstractBasePtr MakeAbstractClosure(const abstract::AnalysisContextPtr &context); - - // return the graph's output, or nullptr if not yet deduced - AnfNodePtr output() const; - void set_output(const AnfNodePtr &value, bool force_new_ret = false); - - const std::vector ¶meters() const { return parameters_; } - virtual ParameterPtr add_parameter(); - void add_parameter(const ParameterPtr &p); - void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); } - void set_parameters(const std::vector ¶ms) { parameters_ = params; } - // add a weight parameter with specific name - ParameterPtr AddWeightParameter(const std::string &name); - - // create a cnode with given inputs, bound to this graph - virtual CNodePtr NewCNode(const std::vector &inputs = std::vector()); - - // create a cnode with given inputs, bound to this graph, and set to specific scope - CNodePtr NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope); - - // Functions for handling variable argument, keyword-only arguments and variable keyword argument - AnfNodePtr GetDefaultValueByName(const std::string &name); - void set_param_default_value(const std::string &name, const AnfNodePtr &node) { - parameter_default_value_[name] = node; - } - void SetDefaultValues(const std::vector &name_list, const std::vector &value_list); - void ClearDefaultValues(); - size_t GetDefaultValueCount(); - std::map ¶meter_default_value() { return parameter_default_value_; } - void set_has_vararg(bool has_) { has_vararg_ = has_; } - bool has_vararg() const { return has_vararg_; } - AnfNodePtr GetVariableArgParameter(); - std::string GetVariableArgName(); - void set_has_kwarg(bool has_) { has_kwarg_ = has_; } - bool has_kwarg() const { return has_kwarg_; } - void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; } - int kwonlyargs_count() const { return kwonlyargs_count_; } - AnfNodePtr GetVariableKwargParameter(); - std::string GetVariableKwargName(); - void set_hyper_param_count(size_t count) { hyper_param_count_ = count; } - size_t hyper_param_count() const { return hyper_param_count_; } - int GetPositionalArgsCount() const; - AnfNodePtr GetParameterByName(const std::string &name); - bool NeedGenerate(const std::vector &kwarg_list); - FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list); - void set_is_generate(bool generated) { is_generated_ = generated; } - bool is_generated() const { return is_generated_; } - - std::unordered_map &attrs() { return attrs_; } - void set_attrs(const std::unordered_map &attrs) { - for (auto &attr : attrs) { - attrs_[attr.first] = attr.second; - } - } - bool has_flag(const std::string &key); - void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); } - void erase_flag(const std::string &key) { (void)attrs_.erase(key); } - - bool has_attr(const std::string &key); - ValuePtr get_attr(const std::string &key); - void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } - - std::unordered_map &transforms() { return transforms_; } - void set_transforms(const std::unordered_map &transforms) { - transforms_ = transforms; - } - - CNodePtr get_return() const { return return_; } - void set_return(const CNodePtr &cnode) { return_ = cnode; } - - FuncGraphManagerPtr manager() const { return manager_.lock(); } - void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr(m); } - - std::string ToString() const override; - GraphDebugInfoPtr debug_info(); - void set_debug_info(const GraphDebugInfoPtr &info) { - if (info == nullptr) { - MS_LOG(EXCEPTION) << "Graph set null debug info"; - } - this->debug_info_ = info; - } - - // get all nodes belonging to this func graph - const AnfNodeSet &nodes(); - void CopyNodes(const FuncGraphPtr &source); - void ClearNodes(); - void AddNode(AnfNodePtr node); - void DropNode(AnfNodePtr node); - - // get all value_nodes belonging to this func graph - const AnfNodeCounterMap &value_nodes(); - void CopyValueNodes(const FuncGraphPtr &source); - void ClearValueNodes(); - void AddValueNode(AnfNodePtr node, int count = 1); - void DropValueNode(AnfNodePtr node); - - // get all free vars directly used in this func graph - const AnfNodeCounterMap &free_variables(); - void CopyFreeVariables(const FuncGraphPtr &source); - void ClearFreeVariables(); - bool AddFreeVariable(AnfNodePtr node, int count = 1); - bool DropFreeVariable(AnfNodePtr node); - - // get all vars required by this func graph - const BaseRefCounterMap &free_variables_total(); - - // Return the set of graphs free_variables_total belong to. - std::vector free_variables_nodes(); - - // get all vars that are func graphs - std::vector free_variables_func_graphs(); - - // get all value nodes of func graph directly used by this func graph - const FuncGraphCounterMap &func_graphs_used(); - void CopyFuncGraphsUsed(const FuncGraphPtr &source); - void ClearFuncGraphsUsed(); - bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1); - bool DropFuncGraphUsed(FuncGraphPtr fg); - - // get all value nodes of J func graph directly used by this func graph - const FuncGraphCounterMap &j_func_graphs(); - void CopyJFuncGraphs(const FuncGraphPtr &source); - void ClearJFuncGraphs(); - void AddJFuncGraph(FuncGraphPtr fg, int count = 1); - void DropJFuncGraph(FuncGraphPtr fg); - - // get all func graphs nested used by this func graph - const FuncGraphSet &func_graphs_used_total(); - - // get all user value nodes of this func graph, by CNode and its input's index - const CNodeIndexCounterMap &func_graph_cnodes_index(); - void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source); - void ClearFuncGraphCNodesIndex(); - void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1); - void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node); - - // Return the parent of this graph. - FuncGraphPtr parent(); - - // Return the children of this graph. - const FuncGraphSet &children(); - - // Return the scope of this graph, scope have graph self but children not have. - const FuncGraphSet &scope(); - - // Return whether this graph is recursive - bool recursive(); - - // Return graphs which forms a recursive loop - std::shared_ptr> recursive_graphs(); - - std::size_t hash() const override { return std::hash{}(this); } - - void DumpFuncGraph(const std::string &path = "./func_graph.dot"); - - bool operator==(const Value &other) const override { - if (other.isa()) { - return &other == this; - } else { - return false; - } - } - void GenerateVarParams(const FuncGraphPtr &specialized_graph, std::vector *specialized_parameter_list, - std::unordered_map *repl_nodes, int variable_args_count, - int pos_args_input_count); - - void GenerateKwParams(const FuncGraphPtr &specialized_graph, std::vector *specialized_parameter_list, - const std::vector &kwarg_list, - std::unordered_map *repl_nodes); - - void GenerateDefaultValue(const FuncGraphPtr &specialized_graph, - const std::vector &specialized_parameter_list, - std::unordered_map *repl_nodes); - - const std::vector ¶mter_obj_nodes() const { return paramter_obj_nodes_; } - void add_parameter_obj_node(const AnfNodePtr &p); - - std::unordered_map &make_ref_params() { return make_ref_params_; } - - std::unordered_map attrs_; - std::unordered_map transforms_; - // parameter default value - std::map parameter_default_value_; - std::unordered_map make_ref_params_; - size_t seen_; - - std::list GetOrderedCnodes(); - void EraseUnusedNodeInOrder(const AnfNodePtr &n); - void EraseUnusedNodeInOrder(); - void CheckOrder(); - void DumpCNodeList(); - void ReleaseFullOrderToEffectOrder(); - void SetEffectDepends(const std::vector &depend_inputs); - bool HasEffect(const CNodePtr &cnode); - - private: - // graph is manipulated by manager and others - friend FuncGraphManager; - - // all nodes of the function - AnfNodeSet nodes_; - - // all value nodes of the function - AnfNodeCounterMap value_nodes_; - - // all func graph value nodes of the function - FuncGraphCounterMap func_graphs_used_; - - // all free variables of the function - AnfNodeCounterMap free_variables_; - - // all value nodes calling J in the function - FuncGraphCounterMap j_func_graphs_; - - // all user value nodes of this func graph, recording by CNode and its input's index - CNodeIndexCounterMap func_graph_cnodes_index_; - - // parameters of this function - std::vector parameters_; - std::vector paramter_obj_nodes_; - - // whether there is a *args and **kwargs, and count kwonlyargs'number - bool has_vararg_; - bool has_kwarg_; - int kwonlyargs_count_; - // the hyper param is placed on the top graph, - // and positioned in the end of the param list, so we record the number to trace the position - size_t hyper_param_count_; - // the argument input list for the graph used to generate this graph - bool is_generated_; - - // the cnode that calls 'return' primitive - // we use shared pointer to manage it. - CNodePtr return_; - - // back-ref to its manager - // hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph. - // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles. - // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs. - // In some ut test cases, they may use local FuncGraphManager in function which - // generating the func graph, when go outside of that function, func graph will have no - // FuncGraphManager. In that special case, Manage() should be called to make the func graph - // managed. - std::weak_ptr manager_; - - GraphDebugInfoPtr debug_info_; - void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, - std::unordered_map *repl_nodes, - const std::vector &kwarg_keys_tuple_nodes, - const std::vector &kwarg_values_tuple_nodes); - - // CNode order which relates to origin code order - std::list order_; -}; - -inline CNodePtr NewCNode(const std::vector &inputs, const FuncGraphPtr &fg) { - MS_EXCEPTION_IF_NULL(fg); - return fg->NewCNode(inputs); -} - -size_t NewFgSeenGeneration(); - -// Find the root cnodes of a segment of cnodes. -std::shared_ptr> FindRoots(const std::vector &segment); -// Find the leaf cnodes of a segment of cnodes. -std::shared_ptr> FindLeaves(const std::vector &segment); -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc deleted file mode 100644 index 4a0c69d99a..0000000000 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ /dev/null @@ -1,650 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/func_graph_cloner.h" - -#include - -#include "ir/manager.h" -#include "ir/param_value_py.h" -#include "operator/ops.h" -#include "utils/convert_utils_base.h" -#include "utils/log_adapter.h" -#include "utils/profile.h" -#include "utils/context/ms_context.h" - -// namespace to support intermediate representation definition -namespace mindspore { -Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, - bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) - : clone_all_valuenodes_(clone_all_valuenodes), - clone_all_child_graphs_(clone_all_child_graphs), - clone_all_used_graphs_(clone_all_used_graphs), - relation_(relation), - target_relation_(target_relation == nullptr ? relation : target_relation) { - for (auto &func_graph : func_graphs) { - AddClone(func_graph); - } - scope_ = kDefaultScope; - type_ = kBasic; -} - -void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, - const AnfNodePtrList ¶ms, CloneType type) { - if (func_graph != nullptr) { - todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); - type_ = type; - } -} - -void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { - MS_EXCEPTION_IF_NULL(node); - if (repl_node_.find(node) != repl_node_.end() || node->isa()) { - return; - } - if (node->isa()) { - CloneParameter(node, target); - } else if (node->isa()) { - CloneCNode(node, target); - } -} - -void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(target); - TraceManager::DebugTrace(node->debug_info(), relation_); - auto new_param = (is_add) ? target->add_parameter() : std::make_shared(target); - auto old_param = node->cast(); - new_param->set_abstract(old_param->abstract()); - new_param->set_name(old_param->name()); - if (old_param->has_default()) { - auto param_value = std::dynamic_pointer_cast(old_param->default_param()); - auto param_value_new = std::make_shared(param_value->value()); - new_param->set_default_param(param_value_new); - } - ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); - new_param->set_scope(scope); - repl_node_[node] = new_param; - TraceManager::EndTrace(); -} - -void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(target); - TraceManager::DebugTrace(node->debug_info(), relation_); - CNodePtr new_node = std::make_shared(AnfNodePtrList{}, target); - auto old_node = node->cast(); - new_node->set_abstract(old_node->abstract()); - ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); - new_node->set_scope(scope); - new_node->set_kernel_info(old_node->kernel_info_ptr()); - repl_node_[old_node] = new_node; - nodes_.emplace_back(old_node, new_node); - TraceManager::EndTrace(); -} - -void Cloner::CloneValueNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - TraceManager::DebugTrace(node->debug_info(), relation_); - ValueNodePtr new_const = NewValueNode(GetValueNode(node)); - ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); - new_const->set_scope(scope); - new_const->set_abstract(node->abstract()); - repl_node_[node] = new_const; - TraceManager::EndTrace(); -} - -void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(target); - TraceManager::DebugTrace(node->debug_info(), relation_); - ValueNodePtr new_const = NewValueNode(target); - ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); - new_const->set_scope(scope); - new_const->set_abstract(node->abstract()); - repl_node_[node] = new_const; - TraceManager::EndTrace(); -} - -void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(manager_); - if (!clone_all_valuenodes_) { - return; - } - auto &value_nodes = func_graph->value_nodes(); - for (auto &value_node : value_nodes) { - auto old_node = value_node.first; - MS_EXCEPTION_IF_NULL(old_node); - if (repl_node_.count(old_node) == 0) { - CloneValueNode(old_node); - } - } -} - -void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(manager_); - if (!clone_all_child_graphs_) { - return; - } - auto &scopes = manager_->scopes(func_graph); - for (auto &graph : scopes) { - if (graph != func_graph) { - todo_.push_back({graph, nullptr, {}}); - } - } -} - -void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(manager_); - if (!clone_all_used_graphs_) { - return; - } - auto &used = func_graph->func_graphs_used(); - for (auto &fg : used) { - todo_.push_back({fg.first, nullptr, {}}); - } -} - -void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - for (auto &item : func_graph->parameter_default_value()) { - auto nodes = DeepLinkedGraphSearch(item.second); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - CloneNode(node, target_func_graph); - } else if (node->isa()) { - CloneValueNode(node); - } - } - } -} - -void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - MS_EXCEPTION_IF_NULL(manager_); - auto return_node = repl_node_[func_graph->get_return()]->cast(); - if (return_node == nullptr) { - MS_LOG(EXCEPTION) << "Can't find replicate node for return."; - } - target_func_graph->set_return(return_node); - - auto &cnodes = func_graph->func_graph_cnodes_index(); - for (auto &cnode : cnodes) { - auto parent = cnode.first->first->cast(); - auto valuenode = parent->input(cnode.first->second); - CloneValueNode(valuenode, target_func_graph); - } -} - -void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { - MS_EXCEPTION_IF_NULL(func_graph); - auto &old_params = func_graph->parameters(); - if (old_params.size() != params.size()) { - MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; - return; - } - for (size_t i = 0; i < old_params.size(); ++i) { - repl_node_[old_params[i]] = params[i]; - } -} - -void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); - *target_func_graph = std::make_shared(); - (*target_func_graph)->set_attrs(func_graph->attrs()); - (*target_func_graph)->set_transforms(func_graph->transforms()); - (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); - (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); - (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count()); - (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); - (*target_func_graph)->set_is_generate(func_graph->is_generated()); - TraceManager::EndTrace(); -} - -void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - auto ¶ms = func_graph->parameters(); - for (auto ¶m : params) { - CloneParameter(param, target_func_graph, true); - } - repl_func_graph_[func_graph] = target_func_graph; -} - -void Cloner::GenParameters(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto &free_vars = manager_->free_variables_total(); - auto iter = free_vars.find(func_graph); - if (iter == free_vars.end()) { - return; - } - - for (auto &fv_map : iter->second) { - auto &free_var = fv_map.first; - if (utils::isa(free_var)) { - repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast(free_var))); - } - } -} - -void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { - param->set_abstract(node->abstract()); - if (node->isa()) { - ParameterPtr old_param = dyn_cast(node); - if (old_param->has_default()) { - auto param_value = std::dynamic_pointer_cast(old_param->default_param()); - auto param_value_new = std::make_shared(param_value->value()); - param->set_default_param(param_value_new); - } - param->set_name(old_param->name()); - } -} - -ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { - TraceManager::DebugTrace(std::make_shared(node->debug_info())); - ParameterPtr param = std::make_shared(func_graph); - TraceManager::EndTrace(); - CloneParameter(param, node); - if (is_add) { - func_graph->add_parameter(param); - } - repl_node_[param] = node; - repl_map_node_[func_graph][node] = param; - return param; -} - -void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, - AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { - AnfNodePtrList parameters; - std::unordered_set old_params; - for (auto ¶m : func_graph->parameters()) { - auto iter = repl_node_.find(param); - if (iter != repl_node_.end()) { - (void)old_params.insert(iter->second); - parameters.push_back(param); - } else { - parameters.push_back(AddParameter(func_graph, param, false)); - (void)old_params.insert(param); - } - } - AnfNodePtr new_param = nullptr; - for (auto ¶m : params) { - auto old_param = repl_node_[param]; - if (old_param->isa() && old_param->func_graph() == func_graph) { - repl_node_[old_param] = old_param; - repl_map_node_[func_graph][old_param] = old_param; - input_params->push_back(old_param); - continue; - } - if (old_params.find(old_param) != old_params.end()) { - new_param = repl_map_node_[func_graph][old_param]; - input_params->push_back(new_param); - continue; - } - new_param = AddParameter(func_graph, old_param, false); - parameters.push_back(new_param); - lift_params->push_back(new_param); - input_params->push_back(new_param); - } - func_graph->set_parameters(parameters); -} - -void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, - const AnfNodePtrList ¶ms) { - AnfNodePtr node = nullptr; - auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; - auto iter = repl_func_graph.find(func_graph); - if (iter == repl_func_graph.end()) { - node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); - repl_func_graph[func_graph] = node; - } else { - node = iter->second; - } - if (node == nullptr || !node->isa()) { - return; - } - auto cnode = node->cast(); - auto inputs = cnode->inputs(); - (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); - cnode->set_inputs(inputs); - OrderParameters(func_graph, inputs); -} - -void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { - std::unordered_set old_params; - for (auto ¶m : func_graph->parameters()) { - (void)old_params.insert(repl_node_[param]); - } - std::unordered_set new_params; - AnfNodePtrList parameters; - // Ignore the 1st and 2nd param of inputs(such as. partial graph) - for (size_t i = 2; i < inputs.size(); ++i) { - auto input = inputs[i]; - auto param = repl_node_[input]; - if (old_params.find(param) != old_params.end()) { - auto new_param = repl_map_node_[func_graph][param]; - parameters.push_back(new_param); - (void)new_params.insert(new_param); - } - } - for (auto ¶m : func_graph->parameters()) { - if (new_params.find(param) == new_params.end()) { - parameters.push_back(param); - } - } - func_graph->set_parameters(parameters); -} - -void Cloner::SetEdges(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - for (auto &node : func_graph->nodes()) { - if (node == nullptr) { - continue; - } - // Only cnode needed to be handled - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - for (size_t i = 0; i < inputs.size(); i++) { - auto &input = inputs[i]; - if (IsValueNode(input)) { - auto graph = GetValueNode(input); - auto &repl_func_graph = repl_map_func_graph_[func_graph]; - if (repl_func_graph.find(graph) != repl_func_graph.end()) { - transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); - } - } else { - auto &repl_node = repl_map_node_[func_graph]; - if (repl_node.find(input) != repl_node.end()) { - transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); - } - } - } - } -} - -void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, - const AnfNodePtrList ¶ms) { - AnfNodePtrList lift_params; - AnfNodePtrList input_params; - AddParameters(func_graph_user, params, &lift_params, &input_params); - AddInputs(func_graph_user, func_graph, input_params); - if (lift_params.empty()) { - return; - } - for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { - LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); - } -} - -void Cloner::Lift() { - for (auto &func_graph_params : repl_func_graph_params_) { - auto &func_graph = func_graph_params.first; - auto ¶ms = func_graph_params.second; - for (auto &cnode : func_graph->func_graph_cnodes_index()) { - LiftParameters(cnode.first->first->func_graph(), func_graph, params); - } - } -} - -void Cloner::LiftParameters() { - MS_EXCEPTION_IF_NULL(manager_); - transaction_ = manager_->Transact(); - const FuncGraphSet &func_graphs = manager_->func_graphs(); - for (auto &func_graph : func_graphs) { - GenParameters(func_graph); - } - Lift(); - for (auto &func_graph : func_graphs) { - SetEdges(func_graph); - } - transaction_.Commit(); -} - -bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { - MS_EXCEPTION_IF_NULL(func_graph); - // Make sure only inline once - if (status_.count(func_graph) != 0) { - if (is_inline == status_[func_graph]) { - return false; - } - if (clone_all_used_graphs_) { - MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False."; - return false; - } - } - return true; -} - -void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - MS_EXCEPTION_IF_NULL(manager_); - const AnfNodeSet &nodes = func_graph->nodes(); - for (auto &node : nodes) { - CloneNode(node, target_func_graph); - } -} - -void Cloner::Run() { - if (todo_.empty()) { - return; - } - - if (type_ < kLifting) { - // Basic and Inline Clone - FuncGraphPtrList func_graphs; - (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), - [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); - manager_ = Manage(func_graphs, false); - CloneNodes(); - LinkEdges(); - SetDefaults(); - } else { - // Lifting Clone - CloneInfo item = todo_.back(); - manager_ = Manage(item.origin); - LiftParameters(); - } -} - -void Cloner::CloneNodes() { - while (!todo_.empty()) { - CloneInfo item = todo_.back(); - todo_.pop_back(); - - bool is_inline = (item.target != nullptr); - FuncGraphPtr func_graph = item.origin; - FuncGraphPtr target_func_graph = item.target; - (void)graph_set_.insert(func_graph); - - if (!CheckStatus(func_graph, is_inline)) { - continue; - } - - if (is_inline) { - InlineCloneParameters(func_graph, item.params); - CloneAllNodes(func_graph, target_func_graph); - } else { - SetFuncGraphInfo(func_graph, &target_func_graph); - CloneParameters(func_graph, target_func_graph); - CloneAllNodes(func_graph, target_func_graph); - CloneFuncGraphValueNodes(func_graph, target_func_graph); - CloneFuncGraphDefaultValues(func_graph, target_func_graph); - } - - CloneValueNodes(func_graph); - AddChildGraphs(func_graph); - AddTotalGraphs(func_graph); - status_[func_graph] = is_inline; - } -} - -void Cloner::LinkEdges() { - for (auto &node_pair : nodes_) { - CNodePtr old_node = node_pair.first; - CNodePtr new_node = node_pair.second; - MS_EXCEPTION_IF_NULL(old_node); - MS_EXCEPTION_IF_NULL(new_node); - for (auto &input : old_node->inputs()) { - auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; - new_node->add_input(new_input); - } - } -} - -// For the graphs cloned, update its default value map to the cloned nodes -void Cloner::SetDefaults() { - for (auto &item : graph_set_) { - MS_EXCEPTION_IF_NULL(item); - if (repl_func_graph_.count(item) != 0) { - for (auto ¶m_def : item->parameter_default_value()) { - MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); - if (repl_node_.count(param_def.second) != 0) { - repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); - } else { - repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second); - } - } - } - } -} - -AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { - MS_EXCEPTION_IF_NULL(root); - if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { - MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; - } - CloneNode(root, repl_func_graph_[root->func_graph()]); - auto iter = repl_node_.find(root); - if (iter != repl_node_.end()) { - return iter->second; - } - MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; -} - -AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { -#ifdef ENABLE_PROFILE - double time = GetTime(); -#endif - Run(); -#ifdef ENABLE_PROFILE - MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - time); -#endif - return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); -} - -FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { -#ifdef ENABLE_PROFILE - double time = GetTime(); -#endif - Run(); -#ifdef ENABLE_PROFILE - MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - time); -#endif - return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]); -} - -FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - Cloner cloner({func_graph}, false, true, true, std::make_shared(), nullptr); - return cloner[func_graph]; -} - -AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, - const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - Cloner cloner({}, false); - if (scope != nullptr) { - cloner.set_scope(scope); - } - cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline); - return cloner[func_graph->output()]; -} - -FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - Cloner cloner({}, false); - cloner.AddClone(func_graph, nullptr, {}, kLifting); - return cloner[func_graph]; -} - -ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphPtrList func_graphs = {func_graph}; - ClonerPtr cloner = - std::make_shared(func_graphs, false, false, false, std::make_shared(), relation); -#ifdef ENABLE_PROFILE - double time = GetTime(); -#endif - cloner->Run(); -#ifdef ENABLE_PROFILE - MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time); -#endif - return cloner; -} - -FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { - MS_EXCEPTION_IF_NULL(func_graph); - TraceManager::DebugTrace(func_graph->debug_info(), relation); - auto new_func_graph = std::make_shared(); - TraceManager::EndTrace(); - - auto ¶meters = func_graph->parameters(); - (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { - MS_EXCEPTION_IF_NULL(param); - TraceManager::DebugTrace(std::make_shared(param->debug_info())); - (void)new_func_graph->add_parameter(); - TraceManager::EndTrace(); - }); - - Cloner cloner = Cloner(); - cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters()); - AnfNodePtr output = cloner[func_graph->output()]; - new_func_graph->set_output(output); - new_func_graph->set_has_vararg(func_graph->has_vararg()); - new_func_graph->set_has_kwarg(func_graph->has_kwarg()); - new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); - new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); - new_func_graph->set_is_generate(func_graph->is_generated()); - for (auto &item : func_graph->parameter_default_value()) { - new_func_graph->set_param_default_value(item.first, cloner[item.second]); - } - - if (MsContext::GetInstance()->is_multi_graph_sink()) { - if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { - new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); - } - } - - if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - } - - return new_func_graph; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph_extends.cc b/mindspore/ccsrc/ir/func_graph_extends.cc deleted file mode 100644 index ad7aa6ee0c..0000000000 --- a/mindspore/ccsrc/ir/func_graph_extends.cc +++ /dev/null @@ -1,422 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/func_graph.h" - -#include -#include -#include - -#include "ir/manager.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" -#include "utils/ordered_set.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "pipeline/static_analysis/abstract_function.h" - -#include "debug/anf_ir_dump.h" -#include "debug/trace.h" -#include "debug/draw.h" -#include "debug/label.h" - -namespace mindspore { -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractFunctionPtr; -using mindspore::abstract::AnalysisContextPtr; -using mindspore::abstract::PrimitiveAbstractClosure; -using mindspore::abstract::VirtualAbstractClosure; - -AbstractFunctionPtr FuncGraph::abstract() { - AbstractBasePtrList args_spec_list; - - for (auto &p : parameters_) { - MS_EXCEPTION_IF_NULL(p); - if (p->abstract() == nullptr) { - MS_LOG(ERROR) << "Error!!"; - return nullptr; - } - args_spec_list.push_back(p->abstract()); - } - - if (nullptr == output()) { - MS_LOG(ERROR) << "Error func graph no output"; - return nullptr; - } - - return std::make_shared(args_spec_list, output()->abstract()); -} - -abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) { - AnalysisContextPtr temp_context = context; - if (temp_context == nullptr) { - temp_context = abstract::AnalysisContext::DummyContext(); - } - return std::make_shared(shared_from_base(), temp_context); -} - -void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { - if (force_new_ret || return_ == nullptr) { - std::vector params({NewValueNode(prim::kPrimReturn), value}); - FuncGraphPtr this_graph = shared_from_base(); - return_ = this_graph->NewCNode(params); - } else { - if (manager_.lock()) { - manager_.lock()->SetEdge(return_, 1, value); - } else { - return_->set_input(1, value); - } - } - - return_->set_abstract(value->abstract()); - - AnfNodePtr input0 = return_->input(0); - - PrimitivePtr return_prim = prim::kPrimReturn; - auto f = std::make_shared(return_prim, input0); - input0->set_abstract(f); -} - -void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base()); } - -void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, - std::vector *specialized_parameter_list, - std::unordered_map *repl_nodes, int variable_args_count, - int pos_args_input_count) { - // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple - if (specialized_graph->has_vararg()) { - TraceManager::DebugTrace( - std::make_shared(specialized_graph->GetVariableArgParameter()->debug_info())); - std::vector var_param_tuple_nodes; - var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - - if (variable_args_count < 0) { - MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count - << " were given."; - } - // for python variable argument input , there is no upper limit - for (int i = 0; i < variable_args_count; ++i) { - ParameterPtr p = std::make_shared(specialized_graph); - std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i); - p->set_name(param_name); - MS_EXCEPTION_IF_NULL(p->debug_info()); - p->debug_info()->set_name(param_name); - var_param_tuple_nodes.push_back(p); - MS_EXCEPTION_IF_NULL(specialized_parameter_list); - specialized_parameter_list->push_back(p); - } - auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes); - (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param); - TraceManager::EndTrace(); - } else if (variable_args_count > 0) { - MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount() - << " positional arguments, but " << pos_args_input_count << " were given."; - } -} - -void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, - std::vector *specialized_parameter_list, - const std::vector &kwarg_list, - std::unordered_map *repl_nodes) { - std::vector kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; - std::vector kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; - - for (const auto &kwarg : kwarg_list) { - MS_EXCEPTION_IF_NULL(kwarg); - std::string kw_param_name = kwarg->get_key(); - MS_EXCEPTION_IF_NULL(specialized_graph); - AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name); - // if not find correspoding parameter node - if (param_node == nullptr) { - if (!has_kwarg()) { - MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name; - } else { - ParameterPtr p = std::make_shared(specialized_graph); - std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; - MS_EXCEPTION_IF_NULL(specialized_parameter_list); - auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), - [param_name](const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto param = node->cast(); - return param != nullptr && param->name() == param_name; - }); - if (find_kw_arg_in_list) { - MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name; - } - p->set_name(param_name); - p->debug_info()->set_name(param_name); - kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name)); - auto extract_node = - specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), p}); - kwarg_values_tuple_nodes.push_back(extract_node); - specialized_parameter_list->push_back(p); - } - } else { - auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); - // multiply values found given for parameter - if (node_itr != specialized_parameter_list->end()) { - MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name; - } else { - specialized_parameter_list->push_back(param_node); - auto extract_node = specialized_graph->NewCNode( - {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); - (void)repl_nodes->emplace(param_node, extract_node); - } - } - } - - GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); -} - -void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, - std::unordered_map *repl_nodes, - const std::vector &kwarg_keys_tuple_nodes, - const std::vector &kwarg_values_tuple_nodes) { - if (has_kwarg()) { - MS_EXCEPTION_IF_NULL(specialized_graph); - TraceManager::DebugTrace( - std::make_shared(specialized_graph->GetVariableKwargParameter()->debug_info())); - auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes); - auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes); - auto make_dict_node = - specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values}); - MS_EXCEPTION_IF_NULL(repl_nodes); - (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node); - TraceManager::EndTrace(); - } -} - -bool FuncGraph::NeedGenerate(const std::vector &kwarg_list) { - // if the function does not have any vararg/kwarg/kwonly/default value/kw args input - // return the original graph - if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { - return false; - } - - // if the graph is generated for specific input, do not need to generate again - if (is_generated()) { - return false; - } - return true; -} - -void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, - const std::vector &specialized_parameter_list, - std::unordered_map *repl_nodes) { - MS_EXCEPTION_IF_NULL(specialized_graph); - for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { - auto param_node = specialized_graph->parameters()[i]; - MS_EXCEPTION_IF_NULL(param_node); - auto param_name = param_node->cast()->name(); - auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node); - if (node_itr != specialized_parameter_list.end()) { - continue; - } - if (param_name == specialized_graph->GetVariableArgName() || - param_name == specialized_graph->GetVariableKwargName()) { - continue; - } - auto default_value = specialized_graph->GetDefaultValueByName(param_name); - if (default_value == nullptr) { - MS_LOG(EXCEPTION) << "Miss argument input for parameter:" << param_name; - } - MS_EXCEPTION_IF_NULL(repl_nodes); - (void)repl_nodes->emplace(param_node, default_value); - } -} - -FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { - std::vector kwarg_list; - size_t arguments_count = args_spec_list.size(); - for (const auto &arg : args_spec_list) { - // if it is a keyword argument - MS_EXCEPTION_IF_NULL(arg); - if (arg->isa()) { - kwarg_list.push_back(dyn_cast(arg)); - } - } - if (!NeedGenerate(kwarg_list)) { - return shared_from_base(); - } - FuncGraphPtr specialized_graph = BasicClone(shared_from_base()); - size_t kwarg_count = kwarg_list.size(); - int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count()); - int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); - int variable_args_count = pos_args_input_count - pos_args_count; - std::vector specialized_parameter_list; - std::unordered_map repl_nodes; - // the parameters that has arg input, copy from original parameters - for (size_t i = 0; i < IntToSize(pos_args_count); ++i) { - specialized_parameter_list.push_back(specialized_graph->parameters()[i]); - } - - GenerateVarParams(specialized_graph, &specialized_parameter_list, &repl_nodes, variable_args_count, - pos_args_input_count); - - GenerateKwParams(specialized_graph, &specialized_parameter_list, kwarg_list, &repl_nodes); - - GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes); - - // append hyper parameter to specialized_parameter_list - MS_EXCEPTION_IF_NULL(specialized_graph); - auto params = specialized_graph->parameters(); - (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), - std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); - - std::shared_ptr manager = mindspore::Manage(specialized_graph, false); - auto tr = manager->Transact(); - for (auto &node_pair : repl_nodes) { - MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" - << node_pair.second->DebugString(); - (void)tr.Replace(node_pair.first, node_pair.second); - } - tr.SetParameters(specialized_graph, specialized_parameter_list); - tr.Commit(); - specialized_graph->set_has_kwarg(false); - specialized_graph->set_has_vararg(false); - specialized_graph->set_kwonlyargs_count(0); - specialized_graph->ClearDefaultValues(); - specialized_graph->set_is_generate(true); - return specialized_graph; -} - -const char kPrimHasEffect[] = "_side_effect_flag"; - -bool FuncGraph::HasEffect(const CNodePtr &cnode) { - auto prim = GetCNodePrimitive(cnode); - if (prim != nullptr && prim->isa()) { - auto do_sig = prim->cast(); - auto prim_val = do_sig->function(); - if (prim_val != nullptr && prim_val->isa()) { - prim = prim_val->cast(); - } else { - prim = nullptr; - } - } - if (prim != nullptr) { - auto effect_val = prim->GetAttr(kPrimHasEffect); - if (effect_val && effect_val->isa()) { - auto effect_bool = GetValue(effect_val); - return effect_bool; - } - } - return false; -} - -std::shared_ptr> FindRoots(const std::vector &segment) { - std::shared_ptr> roots = std::make_shared>(segment); - for (const auto &node : segment) { - if (roots->size() == 1) { - return roots; - } - auto input_size = node->size(); - for (size_t i = 0; i < input_size; i++) { - auto in_node = node->input(i); - auto in_cnode = in_node->cast(); - if (in_cnode != nullptr) { - (void)roots->erase(in_cnode); - } - } - } - return roots; -} - -std::shared_ptr> FindLeaves(const std::vector &segment) { - std::shared_ptr> nodes = std::make_shared>(segment); - for (const auto &node : segment) { - if (nodes->size() == 1) { - return nodes; - } - if (IsPrimitiveCNode(node, prim::kPrimSwitch)) { - (void)nodes->erase(node); - continue; - } - auto input_size = node->size(); - for (size_t i = 0; i < input_size; i++) { - auto in_node = node->input(i); - if (!in_node->isa()) { - continue; - } - auto in_cnode = in_node->cast(); - if (in_cnode != nullptr) { - if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) { - (void)nodes->erase(node); - break; - } - } - } - } - return nodes; -} - -void FuncGraph::ReleaseFullOrderToEffectOrder() { - MS_LOG(DEBUG) << "Flag has_effect " << has_flag(GRAPH_FLAG_HAS_EFFECT) << "."; - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - std::list depends_order; - std::vector segment; - for (const auto &cnode : order_) { - if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { - continue; - } - if (HasEffect(cnode)) { - MS_LOG(DEBUG) << "Meet a effect node " << cnode->DebugString() << "."; - if (segment.size() > 0) { - auto roots = FindRoots(segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - depends_order.push_back(*iter); - } - } - segment.clear(); - depends_order.push_back(cnode); - } else { - MS_LOG(DEBUG) << "Meet a general node " << cnode->DebugString() << "."; - segment.push_back(cnode); - } - } - if (segment.size() > 1) { - auto roots = FindRoots(segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - depends_order.push_back(*iter); - } - } - std::vector depend_inputs; - auto old_ret = output(); - for (auto iter = depends_order.rbegin(); iter != depends_order.rend(); (void)iter++) { - if (*iter != old_ret) { - depend_inputs.push_back(*iter); - } - } - set_flag(GRAPH_FLAG_HAS_EFFECT, false); - set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); - if (!depend_inputs.empty()) { - SetEffectDepends(depend_inputs); - } - } -} - -void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { - auto old_ret = output(); - std::vector inputs{NewValueNode(prim::kPrimDepend), old_ret}; - (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); - auto new_ret = NewCNode(inputs); - auto mng = manager(); - if (mng) { - (void)mng->Replace(old_ret, new_ret); - } else { - return_->set_input(1, new_ret); - } -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/lite/param_value_lite.h b/mindspore/ccsrc/ir/lite/param_value_lite.h deleted file mode 100644 index 2b249cfa4f..0000000000 --- a/mindspore/ccsrc/ir/lite/param_value_lite.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ -#define MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ - -#include - -#include "ir/anf.h" - -namespace mindspore { -class ParamValueLite : public ParamValue { - public: - ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} - virtual ~ParamValueLite() = default; - - size_t tensor_size() const { return tensor_size_; } - void set_tensor_size(size_t size) { tensor_size_ = size; } - - void *tensor_addr() const { return tensor_addr_; } - void set_tensor_addr(void *addr) { tensor_addr_ = addr; } - - private: - void *tensor_addr_; - size_t tensor_size_; -}; - -using ParamValueLitePtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc deleted file mode 100644 index cf56500aea..0000000000 --- a/mindspore/ccsrc/ir/manager.cc +++ /dev/null @@ -1,914 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/manager.h" - -#include -#include -#include - -#include "debug/trace_base.h" -#include "ir/func_graph.h" -#include "utils/profile.h" -#include "utils/convert_utils_base.h" -#include "operator/ops.h" - -namespace mindspore { - -FuncGraphManagerPtr MakeManager(const std::vector &func_graphs, bool manage) { - auto m = std::make_shared(func_graphs, manage); - m->Init(); - return m; -} - -FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage) { - FuncGraphManagerPtr m = nullptr; - bool root = false; - - for (auto &fg : func_graphs) { - if (fg == nullptr) { - continue; - } - if (fg->manager() != nullptr) { - m = fg->manager(); - break; - } - } - - if (m == nullptr) { - std::vector tmp; - m = MakeManager(tmp, manage); - root = true; - } - - for (auto &fg : func_graphs) { - if (fg == nullptr) { - continue; - } - m->AddFuncGraph(fg, root); - } - return m; -} - -FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { - std::vector func_graphs = {func_graph}; - return Manage(func_graphs, manage); -} - -FuncGraphManager::FuncGraphManager(const std::vector &roots, bool manage) - : roots_(roots), is_manage_(manage) { - Reset(); -} - -void FuncGraphManager::Reset() { - func_graphs_ = FuncGraphSet(); - all_nodes_ = AnfNodeSet(); - node_users_ = NodeUsersMap(); - - signals_ = std::make_shared(); - - func_graph_parents_total_ = std::make_shared(this); - func_graph_parent_ = std::make_shared(this); - children_ = std::make_shared(this); - scopes_ = std::make_shared(this); - free_variables_total_ = std::make_shared(this); - func_graphs_used_total_ = std::make_shared(this); - recursive_ = std::make_shared(this); - j_total_ = std::make_shared(this); - - limit_ = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); -} - -void FuncGraphManager::Init() { - auto roots = roots_; - roots_ = FuncGraphSet(); - - for (auto &fg : roots) { - AddFuncGraph(fg, true); - } -} - -FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); - func_graph_parents_total_->Recompute(fg); - MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString(); - return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; -} - -FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(func_graph_parent_); - MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); - func_graph_parent_->Recompute(fg); - if (func_graph_parent_->parent_analysis().count(fg) == 0) { - MS_LOG(WARNING) << "This func graph is not in manager:" << fg->ToString(); - return nullptr; - } - MS_LOG(DEBUG) << "End parents func graph " << fg->ToString(); - return func_graph_parent_->parent_analysis()[fg]; -} - -FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(children_); - MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); - children_->Recompute(fg); - return children_->children_analysis()[fg]; -} - -FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(scopes_); - MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); - scopes_->Recompute(fg); - MS_LOG(DEBUG) << "End scopes func graph:" << fg->ToString(); - return scopes_->scope_analysis()[fg]; -} - -FVTotalMap &FuncGraphManager::free_variables_total() const { - MS_EXCEPTION_IF_NULL(free_variables_total_); - free_variables_total_->Recompute(); - return free_variables_total_->fv_total_analysis(); -} - -FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(func_graphs_used_total_); - func_graphs_used_total_->Recompute(fg); - return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; -} - -bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - recursive_->Recompute(fg); - if (recursive_->recursive_analysis().count(fg) == 0) { - MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); - return false; - } - return recursive_->recursive_analysis()[fg]; -} - -std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - if (recursive(fg)) { - if (!recursive_->recursive_map().count(fg)) { - auto trace = std::list(); - recursive_->CheckRecursiveGraphs(fg, &trace); - } - if (recursive_->recursive_map().count(fg) == 0) { - MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); - return nullptr; - } - return recursive_->recursive_map()[fg]; - } else { - return nullptr; - } -} - -bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(j_total_); - MS_EXCEPTION_IF_NULL(fg); - j_total_->Recompute(fg); - if (j_total_->j_total_analysis().count(fg) == 0) { - MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); - return false; - } - return j_total_->j_total_analysis()[fg]; -} - -// add a func graph to this manager, optionally as a root func graph. -void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { - MS_EXCEPTION_IF_NULL(func_graph); - if (is_root) { - roots_.add(func_graph); - } - if (func_graphs_.contains(func_graph)) { - return; - } - AddIntoManaged(func_graph); - std::vector para = func_graph->parameters(); - AcquireNodes(para); - std::vector return_vec({func_graph->get_return()}); - AcquireNodes(return_vec); -} - -// clear the all information in manager -void FuncGraphManager::Clear() { - func_graphs_.clear(); - all_nodes_.clear(); - node_users_.clear(); - roots_.clear(); - - signals_->InvalidateComputer(); -} - -void FuncGraphManager::KeepRoots(const std::vector &func_graphs) { - MS_LOG(DEBUG) << "Start keep roots"; - bool root_exist = false; - for (auto &item : func_graphs) { - if (roots_.contains(item)) { - root_exist = true; - break; - } - } - - // if the new_root in roots_, we add new_root first, then calculate the func_graphs - // relation to new_root, remove the func_graphs not relation to new_root - // if the new_root not in roots_, we clear the all func_graphs in manager - // then add the new_root - if (root_exist || func_graphs.empty()) { - FuncGraphSet roots(func_graphs); - if (roots.empty()) { - roots = roots_; - } else { - roots_.clear(); - for (auto &item : roots) { - AddFuncGraph(item, true); - } - } - - FuncGraphSet keep; - for (auto &item : roots) { - MS_LOG(DEBUG) << "roots: " << item->ToString(); - keep.update(func_graphs_used_total(item)); -#ifdef DEBUG - for (auto &k : keep) { - MS_LOG(DEBUG) << "keep: " << k->ToString(); - } -#endif - } - MaybeDropFuncGraphs(func_graphs_ - keep, true); - } else { - Clear(); - FuncGraphSet roots(func_graphs); - for (auto &item : roots) { - AddFuncGraph(item, true); - } - } -} - -void FuncGraphManager::RemoveRoots() { - MS_LOG(DEBUG) << "Start remove roots"; - roots_.clear(); - MaybeDropFuncGraphs(func_graphs_, true); -} - -void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { - MS_EXCEPTION_IF_NULL(fg); - if (is_manage_) { - if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { - MS_LOG(WARNING) << "A func graph can only have one manager."; - } - FuncGraphManagerPtr this_manager = shared_from_this(); - fg->set_manager(this_manager); - } - func_graphs_.add(fg); -} - -void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { - FuncGraphSet todo(func_graphs); - std::set dropped; - // int count = 0; - while (!todo.empty()) { - FuncGraphPtr func_graph = todo.pop(); - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString(); - if (roots_.contains(func_graph)) { - MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); - continue; - } - auto &users_cnode_index = func_graph->func_graph_cnodes_index(); - if (!users_cnode_index.empty() && !ignore_users) { - MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); - continue; - } - if (dropped.find(func_graph) != dropped.end()) { - MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString(); - continue; - } - (void)dropped.insert(func_graph); - std::vector return_vec = {func_graph->get_return()}; - todo.update(MaybeDropNodes(return_vec)); - } - for (auto &fg : dropped) { - MS_EXCEPTION_IF_NULL(fg); - all_nodes_.difference_update(fg->parameters()); - (void)func_graphs_.erase(fg); - if (fg->manager().get() == this) { - fg->set_manager(nullptr); - } - MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString(); - } -} - -void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(inp); - if (direction == kDecEdge) { - MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); - auto &users_node = node_users_[inp]; - if (!users_node.contains(make_pair(node, index))) { - return; - } - (void)users_node.erase(make_pair(node, index)); - DropEdge(node, index, inp); - } else { - MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); - if (IsValueNode(inp)) { - MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); - AddFuncGraph(GetValueNode(inp)); - } - auto &users_node = node_users_[inp]; - users_node.add(make_pair(node, index)); - AddEdge(node, index, inp); - } -} - -void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - int index = 0; - for (auto &inp : cnode->inputs()) { - ProcessEdge(cnode, index, inp, direction); - ++index; - } - } -} - -IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { - if (all_nodes_.contains(node)) { - return EXCLUDE; - } else { - return FOLLOW; - } -} - -void FuncGraphManager::AcquireNodes(const std::vector &nodes) { - AnfNodeSet acq; - for (auto &node : nodes) { - AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit_)); - - all_nodes_.update(new_nodes); - acq.update(new_nodes); - } - - for (auto &node : acq) { - MS_EXCEPTION_IF_NULL(node); - auto fg = node->func_graph(); - if (fg != nullptr) { - fg->AddNode(node); - } - ProcessInputs(node, kIncEdge); - } -} - -FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { - AnfNodeSet nodes_ordered(nodes); - FuncGraphSetPtr func_graphs_to_check = std::make_shared(); - while (!nodes_ordered.empty()) { - AnfNodePtr node = nodes_ordered.pop(); - MS_EXCEPTION_IF_NULL(node); - if (!all_nodes_.contains(node)) { - continue; - } - AnfNodeIndexSet &users = node_users_[node]; - - std::vector parameters; - if (!users.empty() || - (node->isa() && parameters.end() != std::find(parameters.begin(), parameters.end(), node))) { - continue; - } - if (IsValueNode(node)) { - auto fg = GetValueNode(node); - func_graphs_to_check->add(fg); - MS_LOG(DEBUG) << "Set value of node " << node->DebugString() << " from func graph " << fg->ToString() - << " to null"; - } - ProcessInputs(node, kDecEdge); - (void)all_nodes_.erase(node); - if (node->func_graph() != nullptr) { - node->func_graph()->DropNode(node); - } - - if (node->isa()) { - auto cnode = node->cast(); - nodes_ordered.update(cnode->inputs()); - } - (void)node_users_.erase(node); - } - return func_graphs_to_check; -} - -void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters) { - auto tr = Transact(); - tr.SetParameters(fg, parameters); - tr.Commit(); -} - -void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) { - auto tr = Transact(); - tr.AddParameter(fg, parameter); - tr.Commit(); -} - -bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { - auto tr = Transact(); - bool success = tr.Replace(old_node, new_node); - if (success) { - tr.Commit(); - } - return success; -} - -void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { - auto tr = Transact(); - tr.SetEdge(node, index, value); - tr.Commit(); -} - -void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { - AnfNodePtr source_return = source->get_return(); - AnfNodePtr source_output = source->output(); - AnfNodePtr source_prim = source_return->cast()->input(0); - - int index = 0; - (void)node_users_[source_prim].erase(make_pair(source_return, index)); - DropEdge(source_return, index, source_prim); - index = 1; - (void)node_users_[source_output].erase(make_pair(source_return, index)); - DropEdge(source_return, index, source_output); - (void)all_nodes_.erase(source_return); - (void)node_users_.erase(source_return); - source->DropNode(source_return); - for (auto &node : source->nodes()) { - node->set_func_graph(target); - if (node->scope() == kDefaultScope) { - node->set_scope(scope); - } - } - - MoveAllNodes(source, target); - all_nodes_.difference_update(source->parameters()); - (void)func_graphs_.erase(source); - if (source->manager().get() == this) { - source->set_manager(nullptr); - } -} - -void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { - auto fg = node->func_graph(); - if (input->isa()) { - fg->AddValueNode(input); - if (IsValueNode(input)) { - auto used = GetValueNode(input); - used->AddFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); - if (fg->AddFuncGraphUsed(used)) { - signals_->InvalidateComputer(); - } - if (IsPrimitiveCNode(node, prim::kPrimJ)) { - fg->AddJFuncGraph(used); - } - } - } else if (fg != nullptr && fg != input->func_graph()) { - if (fg->AddFreeVariable(input)) { - signals_->InvalidateComputer(); - } - } -} - -void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { - auto fg = node->func_graph(); - if (input->isa()) { - fg->DropValueNode(input); - if (IsValueNode(input)) { - auto used = GetValueNode(input); - used->DropFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); - if (fg->DropFuncGraphUsed(used)) { - signals_->InvalidateComputer(); - } - if (IsPrimitiveCNode(node, prim::kPrimJ)) { - fg->DropJFuncGraph(used); - } - } - } else if (fg != nullptr && fg != input->func_graph()) { - if (fg->DropFreeVariable(input)) { - signals_->InvalidateComputer(); - } - } -} - -void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { - target->CopyNodes(source); - target->CopyValueNodes(source); - target->CopyFuncGraphCNodesIndex(source); - target->CopyFreeVariables(source); - target->CopyFuncGraphsUsed(source); - target->CopyJFuncGraphs(source); - signals_->InvalidateComputer(); - source->ClearNodes(); - source->ClearValueNodes(); - source->ClearFuncGraphCNodesIndex(); - source->ClearFreeVariables(); - source->ClearFuncGraphsUsed(); - source->ClearJFuncGraphs(); -} - -FuncGraphTransaction FuncGraphManager::Transact() { - auto tr = FuncGraphTransaction(this); - return tr; -} - -void FuncGraphManager::ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, - EdgeTupleCounter *rm_edges, Counter *adds, Counter *rms) { - for (auto &iter : changes) { - auto operation = iter.op; - auto args = iter.args; - switch (operation) { - case Change::kTxSetEdge: { - auto edge = args.cast(); - auto old_node = edge.root_node->input(edge.index); - (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; - (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; - (*rms)[old_node] += 1; - (*adds)[edge.new_node] += 1; - edge.root_node->set_input(edge.index, edge.new_node); - } break; - case Change::kTxSetParams: { - auto param = args.cast(); - MS_EXCEPTION_IF_NULL(param.func_graph); - auto old_parameters = param.func_graph->parameters(); - for (auto &p : param.params) { - (*adds)[p] += 1; - } - for (auto &p : old_parameters) { - (*rms)[p] += 1; - } - param.func_graph->set_parameters(param.params); - } break; - case Change::kTxAddParam: { - auto param = args.cast(); - MS_EXCEPTION_IF_NULL(param.func_graph); - (*adds)[param.param] += 1; - auto param_node = param.param->cast(); - param.func_graph->append_parameter(param_node); - } break; - default: - break; - } - } -} - -void FuncGraphManager::CommitChanges(const std::vector &changes) { - EdgeTupleCounter add_edges; - EdgeTupleCounter rm_edges; - Counter adds; - Counter rms; - ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); - - auto sub_edges = add_edges - rm_edges; - for (auto &iter : sub_edges) { - auto root_node = iter.first.first; - int index = iter.first.second.first; - auto new_node = iter.first.second.second; - ProcessEdge(root_node, index, new_node, kIncEdge); - } - - auto sub_nodes = adds - rms; - std::vector nodes; - (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), - [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); - - AcquireNodes(nodes); - - auto sub_edges_reverse = rm_edges - add_edges; - for (auto &iter : sub_edges_reverse) { - auto root_node = iter.first.first; - int index = iter.first.second.first; - auto old_node = iter.first.second.second; - ProcessEdge(root_node, index, old_node, kDecEdge); - } - - auto sub_nodes_reverse = rms - adds; - std::vector nodes_reverse; - - (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), - [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); - - auto drop_func_graphs = MaybeDropNodes(nodes_reverse); - MaybeDropFuncGraphs(*drop_func_graphs); -} - -void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector ¶ms) { - changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); -} - -void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) { - changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param}); -} - -bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { - MS_EXCEPTION_IF_NULL(old_node); - MS_EXCEPTION_IF_NULL(new_node); - FuncGraphPtr old_func_graph = old_node->func_graph(); - if (old_func_graph != nullptr && old_func_graph->get_return() == old_node) { - MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString(); - return false; - } - auto users = manager_->node_users()[old_node]; - for (auto &node : users) { - SetEdge(node.first, node.second, new_node); - } - - return true; -} - -void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { - if (k < 0) { - MS_LOG(EXCEPTION) << "Invalid value k = " << k; - } - MS_EXCEPTION_IF_NULL(src_node); - auto cnode = src_node->cast(); - if (cnode == nullptr) { - MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed."; - } - changes_.emplace_back(Change::kTxSetEdge, ArgsOfSetEdge{cnode, v, IntToSize(k)}); -} - -void FuncGraphTransaction::Commit() { - std::vector changes; - changes_.swap(changes); - manager_->CommitChanges(changes); -} - -DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) { - MS_EXCEPTION_IF_NULL(manager_); - manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); - validate_ = false; -} - -void DepComputer::Recompute() { - if (!validate_) { - RealRecompute(); - validate_ = true; - } -} - -void DepComputer::Recompute(const FuncGraphPtr &fg) { - if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { - RealRecompute(fg); - func_graphs_validate_[fg] = true; - } -} - -FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) { - if (fg->seen_ == seen_num) { - return std::make_shared(); - } - FuncGraphSetPtr parents = std::make_shared(); - - // Append all the fvs in fg. - auto &fvs = fg->free_variables(); - for (auto fv : fvs) { - parents->add(fv.first->func_graph()); - } - - // Search the fv in fg's child func graph. - auto &fgs = fg->func_graphs_used(); - for (auto &item : fgs) { - fg->seen_ = seen_num; - auto gt = item.first; - parents->update(SeekParents(gt, seen_num)); - } - (void)parents->erase(fg); - return parents; -} - -void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { - MS_EXCEPTION_IF_NULL(fg); - func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration())); -} - -bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { - auto l1 = lhs.second.size(); - auto l2 = rhs.second.size(); - return l1 < l2; -} - -void ParentComputer::RealRecompute(FuncGraphPtr fg) { - this->parent_analysis_[fg] = nullptr; - // Note: must be a copy other than reference as it is modified thereafter. - auto deps = this->manager_->func_graph_parents_total(fg); - - if (deps.empty()) { - this->parent_analysis_[fg] = nullptr; - return; - } else if (deps.size() == 1) { - this->parent_analysis_[fg] = deps.pop(); - return; - } else { - // return nearest parent as parent - FuncGraphSet deps_copy(deps); - for (auto &dep : deps) { - auto parent_deps = this->manager_->func_graph_parents_total(dep); - for (auto &p_d : parent_deps) { - if (deps_copy.count(p_d)) { - (void)deps_copy.erase(p_d); - } - } - if (deps_copy.size() == 1) { - this->parent_analysis_[fg] = deps_copy.pop(); - return; - } - } - } -} - -void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { - MS_EXCEPTION_IF_NULL(manager_); - auto used_fg_total = manager_->func_graphs_used_total(fg); - for (auto &used_fg : used_fg_total) { - if (manager_->parent(used_fg) == fg) { - children_analysis_[fg].add(used_fg); - } - } -} - -void ScopeComputer::RealRecompute(FuncGraphPtr fg) { - MS_EXCEPTION_IF_NULL(manager_); - auto &children = manager_->children(fg); - - scope_analysis_[fg] = FuncGraphSet(); - scope_analysis_[fg].add(fg); - for (auto &child : children) { - scope_analysis_[fg].add(child); - } -} - -void FVTotalComputer::RealRecompute() { - auto manager = DepComputer::manager_; - MS_EXCEPTION_IF_NULL(manager); - - for (auto &fg : manager->func_graphs()) { - fv_total_analysis_[fg] = OrderedMap(); - } - - for (auto &fg : manager->func_graphs()) { - // add all free variable nodes - AnfNodeCounterMap items = fg->free_variables(); - for (auto &iter : items) { - auto curr = fg; - while (curr != nullptr) { - fv_total_analysis_[curr][iter.first] = iter.second; - curr = manager->parent(curr); - if (curr != nullptr) { - const AnfNodeSet &all_nodes = curr->nodes(); - if (all_nodes.contains(iter.first)) { - break; - } - } - } - } - - // add all FGs of free variables - auto &used = fg->func_graphs_used(); - for (auto &iter : used) { - auto p = manager->parent(iter.first); - if (p == nullptr) { - continue; - } - auto curr = fg; - while (curr != p) { - fv_total_analysis_[curr][iter.first] = iter.second; - curr = manager->parent(curr); - } - } - } -} - -void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { - MS_EXCEPTION_IF_NULL(manager_); - std::vector todo; - std::vector todo_new; - - todo.push_back(fg); - while (!todo.empty()) { - todo_new.clear(); - for (auto > : todo) { - for (auto &item : gt->func_graphs_used()) { - auto used_fg = item.first; - if (used_fg == fg) { - func_graph_used_total_analysis_[fg].add(used_fg); - continue; - } - if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) { - todo_new.push_back(used_fg); - } - MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString(); - func_graph_used_total_analysis_[fg].add(used_fg); - } - } - todo = todo_new; - } -} - -bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { - MS_EXCEPTION_IF_NULL(manager); - std::vector todo; - std::vector todo_new; - todo.push_back(fg); - FuncGraphSet used_total; - while (!todo.empty()) { - todo_new.clear(); - for (auto > : todo) { - for (auto &item : gt->func_graphs_used()) { - auto used_g = item.first; - if (used_g == fg) { - return true; - } - if (used_total.count(used_g) == 0) { - todo_new.push_back(used_g); - } - used_total.add(used_g); - } - } - todo = todo_new; - } - return false; -} - -void RecursiveComputer::RealRecompute(FuncGraphPtr fg) { - this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); -} - -void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list *trace) { - MS_EXCEPTION_IF_NULL(trace); - auto res = std::find(trace->begin(), trace->end(), fg); - // find recursive - if (res != trace->end()) { - auto recur_ptr = std::make_shared>(res, trace->end()); - for (auto iter = res; iter != trace->end(); (void)iter++) { - MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString(); - recursive_map_[*iter] = recur_ptr; - } - } else { - trace->push_back(fg); - auto &items = fg->func_graphs_used(); - for (auto iter = items.begin(); iter != items.end(); (void)iter++) { - CheckRecursiveGraphs(iter->first, trace); - } - trace->pop_back(); - if (!recursive_map_.count(fg)) { - recursive_map_[fg] = nullptr; - } - } -} - -bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { - if (fg->seen_ == seen_num) { - MS_LOG(DEBUG) << fg->ToString() << " had been checked"; - return false; - } - auto &j_fgs = fg->j_func_graphs(); - if (!j_fgs.empty()) { - // check g1->J(fg)->g2->g cycle; - auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair iter) { - return iter.first->seen_ != seen_num; - }); - if (contains_j != j_fgs.end()) { - MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; - return true; - } - } - fg->seen_ = seen_num; - - // check if func graphs used contains J(func_graph); - for (auto &item : fg->func_graphs_used()) { - auto used_g = item.first; - if (SeekJ(used_g, seen_num)) { - MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; - return true; - } - } - MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph)"; - return false; -} - -void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { - this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration()); -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/meta_func_graph.cc b/mindspore/ccsrc/ir/meta_func_graph.cc deleted file mode 100644 index 3b2704613a..0000000000 --- a/mindspore/ccsrc/ir/meta_func_graph.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/meta_func_graph.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "pipeline/static_analysis/abstract_function.h" - -// namespace to support intermediate representation definition -namespace mindspore { -abstract::AbstractBasePtr MetaFuncGraph::MakeAbstractClosure(const AnfNodePtr &anf_node) { - abstract::MetaFuncGraphAbstractClosurePtr meta_func_graph_fn; - if (anf_node == nullptr) { - meta_func_graph_fn = std::make_shared(shared_from_base()); - } else { - meta_func_graph_fn = - std::make_shared(shared_from_base(), anf_node->scope()); - } - return meta_func_graph_fn; -} - -FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) { - TypePtrList types; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), - [](const AbstractBasePtr &arg) -> TypePtr { - MS_EXCEPTION_IF_NULL(arg); - return arg->BuildType(); - }); - // filter unsafe characters in log print since name_ is from outside - auto iter = cache_.find(types); - if (iter == cache_.end()) { - FuncGraphPtr fg = GenerateFromTypes(types); - MS_EXCEPTION_IF_NULL(fg); - MS_LOG(INFO) << "MetaFuncgraph: cache miss for types: " << mindspore::ToString(args_spec_list) - << ", g: " << fg->ToString(); - cache_[types] = fg; - return fg; - } else { - MS_LOG(DEBUG) << "MetaFuncgraph: cache hit for types: " << mindspore::ToString(args_spec_list) - << ", g: " << iter->second->ToString(); - return iter->second; - } -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/meta_func_graph.h b/mindspore/ccsrc/ir/meta_func_graph.h deleted file mode 100644 index f63f812f9e..0000000000 --- a/mindspore/ccsrc/ir/meta_func_graph.h +++ /dev/null @@ -1,95 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ -#define MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ - -#include -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" - -#include "ir/dtype.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/signature.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace py = pybind11; - -namespace mindspore { -// namespace to support intermediate representation definition -// Graph generator. -// Can be called with a pipeline's resources and a list of argument types to -// generate a graph corresponding to these types. -class MetaFuncGraph : public FuncGraphBase { - public: - explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); } - - ~MetaFuncGraph() override = default; - - MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); - abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node); - // Return normalized versions of the arguments. - // By default, this returns args unchanged. - virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { - return args_spec_list; - } - - const std::vector &signatures() const { return signatures_; } - void set_signatures(const std::vector &signatures) { signatures_ = signatures; } - // Generate a Graph for the given abstract arguments. - virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list); - - // Generate a Graph for this type signature. - virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) { - MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; - } - - std::string name() { return name_; } - std::string ToString() const override { return name_; } - std::size_t hash() const override { return tid(); } - - virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; } - bool operator==(const Value &other) const override { - if (other.isa()) { - return &other == this; - } else { - return false; - } - } - const bool parse_info_ = true; - - protected: - template - std::shared_ptr shared_from_base() { - return std::static_pointer_cast(shared_from_this()); - } - std::string name_; - std::vector signatures_; - std::unordered_map cache_; -}; - -using MetaFuncGraphPtr = std::shared_ptr; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ diff --git a/mindspore/ccsrc/ir/meta_tensor.h b/mindspore/ccsrc/ir/meta_tensor.h deleted file mode 100644 index a8c07d6992..0000000000 --- a/mindspore/ccsrc/ir/meta_tensor.h +++ /dev/null @@ -1,195 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_META_TENSOR_H_ -#define MINDSPORE_CCSRC_IR_META_TENSOR_H_ - -#include -#include -#include -#include - -#include "ir/base.h" -#include "ir/dtype.h" -#include "utils/convert_utils.h" -#include "utils/hashing.h" - -// brief mindspore namespace. -// -// mindspore namespace is the top level namespace of MindSpore project. -// Other namespace should be a sub namespace of mindspore namespace in the ME project. -namespace mindspore { - -// brief mindspore::tensor namespace -// -// A sub namespace in ME to support tensor related definition. -namespace tensor { - -// brief Device info of Tensor -// -// Includes the format and data type of a tensor. -struct DeviceInfo { - explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr) - : format_(std::move(format)), data_type_(std::move(data_type)) {} - std::string format_ = "DefaultFormat"; - TypePtr data_type_ = nullptr; -}; - -// brief Metadata of Tensor -// -// Includes the metadata information of a tensor, such as data type, shape -// and so on. But it does not contain values of a tensor. -class MetaTensor : public Value { - public: - // Construction - MetaTensor(); - - // brief Constructs a meta tensor of a tensor having data_type data and shape. - // - // The constructed MetaTensor is not a Tensor, but it has the data type and shape - // information of a Tensor. The following codes will create a 2x3 float - // param data_type The data type of the tensor. - // param shape The shape of the tensor. - MetaTensor(const TypeId data_type, const std::vector &shape); - - MetaTensor(const TypePtr &type_ptr, const std::vector &shape); - // brief Constructs a MetaTensor object from an existing MetaTensor instance. - // - // The constructed MetaTensor object will have the same data type and shape as the - // meta_tensor. - // - // param meta_tensor An existing MetaTensor object. - MetaTensor(const MetaTensor &meta_tensor); - ~MetaTensor() override = default; - MS_DECLARE_PARENT(MetaTensor, Value) - - // brief Overloads operator = for MetaTensor. - // - // The constructed MetaTensor object has the same type and shape with meta_tensor. - // - // param meta_tensor An existing MetaTensor object. - virtual MetaTensor &operator=(const MetaTensor &meta_tensor); - - // brief Compares two MetaTensor objects. - // - // The constructed MetaTensor object has the same type and shape with meta_tensor. - // - // param meta_tensor The MetaTensor object to be compared. - // return true: If having same type and shape, return true, or return false. - virtual bool operator==(const MetaTensor &meta_tensor) const; - - // brief Returns the data type of the tensor in its MetaTensor. - // - // All the types are defined in "ir/dtype.h". - TypePtr Dtype() const; - abstract::AbstractBasePtr ToAbstract() override; - TypeId data_type() const { return data_type_; } - std::string ToString() const override; - std::string DumpText() const override; - // brief Sets the data type of a tensor in its MetaTensor. - // - // param data_type The data type of the tensor to be set. - virtual TypeId set_data_type(const TypeId data_type) { - data_type_ = data_type; - return data_type_; - } - virtual TypePtr SetDtype(const TypePtr type_ptr); - // brief Get tensor's shape. - // - // The shape of a tensor is stored in a vector. Each - // element of the vector represents the size of a dimension of the tensor. - // The order of each element in the vector is as same as the the dimension's - // order it represents. - // - // return A const vector which represents the shape of the tensor. - const std::vector &shape() const { return shape_; } - - // brief Sets the shape of a tensor. - // - // The shape of a tensor is stored in a vector. Each - // element of the vector represents the size of a dimension of the tensor. - // The order of each element in the vector is as same as the the dimension's - // order it represents. - // - // param shape The shape of the tensor. - // return The shape's size. - size_t set_shape(const std::vector &shape) { - this->shape_ = shape; - return shape_.size(); - } - - // Get tensor's device info. - DeviceInfo device_info() const { return device_info_; } - - // Set tensor's device info. - void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } - - void SetDeviceInfo(const std::string &format, const TypePtr &data_type); - - // Get the size of a given dimension by its index number. - int DimensionSize(size_t index) const; - - // Get total number of elements in a tensor. - int ElementsNum() const; - - std::size_t hash() const override { - std::size_t hash_value = std::hash{}(SizeToInt(data_type_)); - hash_value = hash_combine(hash_value, std::hash{}(shape_.size())); - // hash all elements may costly, so only take at most 4 elements into account based on - // some experiments. - for (size_t i = 0; (i < shape_.size()) && (i < 4); ++i) { - hash_value = hash_combine(hash_value, (std::hash{}(shape_[i]))); - } - return hash_value; - } - bool operator==(const Value &other) const override { - if (other.isa()) { - auto other_ = static_cast(other); - return *this == other_; - } else { - return false; - } - } - const bool parse_info_ = true; - - protected: - // brief Data type of the tensor. - // - // All support data type is in Number Types of [TypeId], - // including [kNumberTypeBool], [kNumberTypeInt], - // [kNumberTypeUInt32], [kNumberTypeFloat32] and [kNumberTypeFloat64]. - TypeId data_type_; - - // brief Shape of the tensor. - // - // A std::vector container is used to store the shape of a tensor. - // Each element of the vector represents the size of a dimension of the tensor. - // The order of each element in the vector is as same as the the dimension's - // order it represents. If the dimension size is not set, its value will be -1. - std::vector shape_; - - // brief Device info of Tensor - // - // Includes the format and data type of a tensor on device. - DeviceInfo device_info_; -}; - -using MetaTensorPtr = std::shared_ptr; - -} // namespace tensor -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_META_TENSOR_H_ diff --git a/mindspore/ccsrc/ir/meta_tensor_extends.cc b/mindspore/ccsrc/ir/meta_tensor_extends.cc deleted file mode 100644 index 87f1db95e5..0000000000 --- a/mindspore/ccsrc/ir/meta_tensor_extends.cc +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/meta_tensor.h" - -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -namespace tensor { -abstract::AbstractBasePtr MetaTensor::ToAbstract() { - auto tens = shared_from_base(); - auto dtype = tens->Dtype(); - if (!IsSubType(dtype, kNumber)) { - MS_LOG(EXCEPTION) << "Expect MetaTensor type kNumber but got: " << dtype->ToString() << "."; - } - auto tensor_shape = tens->shape(); - auto abs_tensor = std::make_shared(dtype, tensor_shape); - abs_tensor->set_value(shared_from_base()); - return abs_tensor; -} -} // namespace tensor -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/named.cc b/mindspore/ccsrc/ir/named.cc deleted file mode 100644 index 9e1a7968b8..0000000000 --- a/mindspore/ccsrc/ir/named.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/named.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -bool Named::operator==(const Value &other) const { - if (other.isa()) { - auto other_named = static_cast(other); - return *this == other_named; - } else { - return false; - } -} - -abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared(); } -const NamedPtr kNone = std::make_shared(); - -abstract::AbstractBasePtr Null::ToAbstract() { return std::make_shared(); } -const NamedPtr kNull = std::make_shared(); - -abstract::AbstractBasePtr Ellipsis::ToAbstract() { return std::make_shared(); } -const NamedPtr kEllipsis = std::make_shared(); -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/param_value_py.h b/mindspore/ccsrc/ir/param_value_py.h deleted file mode 100644 index a03e34ac6e..0000000000 --- a/mindspore/ccsrc/ir/param_value_py.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_ -#define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_ - -#include - -#include "ir/anf.h" -#include "pybind11/pybind11.h" - -namespace mindspore { -namespace py = pybind11; - -class ParamValuePy : public ParamValue { - public: - ParamValuePy() : value_(py::none()) {} - explicit ParamValuePy(const py::object &value) : value_(value) {} - ~ParamValuePy() override = default; - - py::object value() { return value_; } - void set_value(const py::object &obj) { value_ = obj; } - - private: - py::object value_; -}; - -using ParamValuePyPtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_ diff --git a/mindspore/ccsrc/ir/pattern_matcher.h b/mindspore/ccsrc/ir/pattern_matcher.h deleted file mode 100644 index 6605b9ce4c..0000000000 --- a/mindspore/ccsrc/ir/pattern_matcher.h +++ /dev/null @@ -1,310 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ -#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ - -#include -#include - -#include "ir/anf.h" -#include "operator/ops.h" - -namespace mindspore { - -/// -/// Base class for all recognizable patterns. -/// We implement an Expression Template approach using static polymorphism based on -/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect -/// to the use of virtual functions without the costs..." as described in: -/// https://en.wikipedia.org/wiki/Expression_templates and -/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern -/// The TryCapture function tries to capture the pattern with the given node. -/// The GetNode function builds a new node using the captured values. -/// - -template -class PBase { - public: - bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) { - return func(get_object().GetNode(node)); - } - - const T &get_object() const { return *static_cast(this); } - - template - bool TryCapture(const TN &value) const { - get_object().Reset(); - return get_object().TryCapture_(value); - } - - using Internal = T; -}; - -template -class PIsEqual { - public: - bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } -}; - -template -class PatternNode : public PBase > { - public: - T GetNode(const AnfNodePtr &node) const { - if (!captured_) { - MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode."; - } - return captured_node_; - } - - bool TryCapture_(const T &node) const { - if (!captured_) { - captured_node_ = node; - captured_ = true; - return true; - } - return PIsEqual()(captured_node_, node); - } - - void Reset() const { captured_ = false; } - using Internal = const PatternNode &; - - protected: - mutable T captured_node_; - mutable bool captured_{false}; -}; - -template -class PBinOperation : public PBase > { - public: - PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {} - - AnfNodePtr GetNode(const AnfNodePtr &node) const { - AnfNodePtr lhs = x_.GetNode(node->func_graph()); - AnfNodePtr rhs = y_.GetNode(node->func_graph()); - AnfNodePtrList list = {prim_->cast(), lhs, rhs}; - return NewCNode(list, node->func_graph()); - } - - bool TryCapture_(const AnfNodePtr &node) const { - if (IsPrimitiveCNode(node, prim_)) { - auto cnode = node->cast(); - auto inputs = cnode->inputs(); - if (inputs.size() == 3) { - // Binary Prim assumes only two inputs - if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { - return false; - } - return true; - } - } - return false; - } - - void Reset() const { - x_.Reset(); - y_.Reset(); - } - - private: - const PrimitivePtr prim_; - typename T::Internal x_; - typename T2::Internal y_; -}; - -/// -/// Helper functions to apply a pattern function on all elements of a tuple -/// -namespace tuple_utils { -template -struct apply_func_tuple_item { - template - static void apply(Func *func, const TTuple &tuple) { - (*func)(Index, std::get(tuple)); - apply_func_tuple_item<(Index + 1) == std::tuple_size::value, (Index + 1), Func>::apply(func, tuple); - } -}; - -template -struct apply_func_tuple_item { - template - static void apply(Func *func, const TTuple &tuple) {} -}; - -template -inline void apply_func_tuple(Func *func, const TTuple &tuple) { - apply_func_tuple_item::value == 0, 0, Func>::apply(func, tuple); -} - -struct PTupleResetCapture { - template - void operator()(size_t i, const T &pattern) const { - pattern.Reset(); - } -}; - -struct PTupleCapture { - explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {} - - template - void operator()(size_t i, const TPattern &pattern) { - // Check if the first node is a Primitive - if (i == 0 && tuple_[i]->isa()) { - auto prim = tuple_[i]->cast(); - if (tuple_[i] != pattern.GetNode(tuple_[i])) { - captured_ = false; - } - } else { - captured_ = captured_ && pattern.TryCapture_(tuple_[i]); - } - } - - const AnfNodePtrList tuple_; - bool captured_{true}; -}; - -struct PTupleGetNode { - explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {} - - template - void operator()(size_t, const TPattern &pattern) { - args_.push_back(pattern.GetNode(node_)); - } - - const AnfNodePtr &node_; - std::vector args_; -}; -} // namespace tuple_utils - -template -class PCNode : public PBase > { - public: - explicit PCNode(const TArgs &... args) : args_(args...) {} - - AnfNodePtr GetNode(const AnfNodePtr &node) const { - tuple_utils::PTupleGetNode get_node(node); - tuple_utils::apply_func_tuple(&get_node, args_); - return NewCNode(get_node.args_, node->func_graph()); - } - - bool TryCapture_(const AnfNodePtr &node) const { - if (node->isa()) { - auto cnode = node->cast(); - auto inputs = cnode->inputs(); - if (inputs.size() != sizeof...(TArgs)) { - return false; - } - tuple_utils::PTupleCapture capture_func(inputs); - tuple_utils::apply_func_tuple(&capture_func, args_); - return capture_func.captured_; - } - - return false; - } - - void Reset() const { - tuple_utils::PTupleResetCapture reset; - tuple_utils::apply_func_tuple(&reset, args_); - } - - private: - std::tuple args_; -}; - -template -class PPrimitive : public PBase > { - public: - explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {} - - AnfNodePtr GetNode(const AnfNodePtr &node) const { - tuple_utils::PTupleGetNode get_node(node); - tuple_utils::apply_func_tuple(&get_node, args_); - auto prim_cnode = get_node.args_; - prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_)); - return NewCNode(prim_cnode, node->func_graph()); - } - - bool TryCapture_(const AnfNodePtr &node) const { - if (IsPrimitiveCNode(node, prim_)) { - auto cnode = node->cast(); - auto inputs = cnode->inputs(); - if ((inputs.size() - 1) != sizeof...(TArgs)) { - return false; - } - - AnfNodePtrList rest(inputs.begin() + 1, inputs.end()); - tuple_utils::PTupleCapture capture_func(rest); - tuple_utils::apply_func_tuple(&capture_func, args_); - - return capture_func.captured_; - } - - return false; - } - - void Reset() const { - tuple_utils::PTupleResetCapture reset; - tuple_utils::apply_func_tuple(&reset, args_); - } - - private: - const PrimitivePtr prim_; - std::tuple args_; -}; - -// Macro for binary operation functions -#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \ - template \ - inline PBinOperation Operator(const PBase &x, const PBase &y) { \ - return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \ - } - -// Arithmetic operations -BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd); -BIN_OPERATION_PATTERN(operator*, prim::kPrimMul); - -// Macros for match and replace -#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ - if ((CaptureNode).TryCapture(OrigNode)) { \ - return (ReplaceWith).GetNode(OrigNode); \ - } - -#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ - if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ - return (ReplaceWith).GetNode(OrigNode); \ - } - -#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ - if ((CaptureNode).TryCapture(OrigNode)) { \ - if ((Condition)) { \ - return (ReplaceWith).GetNode(OrigNode); \ - } \ - return (ElseNode).GetNode(OrigNode); \ - } - -#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ - if ((CaptureNode).TryCapture(OrigNode)) { \ - return (Lambda)(); \ - } - -#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ - if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ - return (Lambda)(); \ - } - -} // namespace mindspore - -#endif // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc deleted file mode 100644 index 6ec27c2567..0000000000 --- a/mindspore/ccsrc/ir/primitive.cc +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/primitive.h" -#include -#include -#include "ir/signature.h" -#include "operator/ops.h" -#include "./common.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/data_converter.h" -#include "pybind11/pytypes.h" -#include "utils/convert_utils_base.h" -#include "utils/primitive_utils.h" - -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" - -namespace mindspore { -void PrimitivePy::set_signatures( - std::vector> signatures) { - signatures_.clear(); - for (auto &signature : signatures) { - std::string name; - SignatureEnumRW rw; - SignatureEnumKind kind; - py::object default_value; - SignatureEnumDType dtype; - std::tie(name, rw, kind, default_value, dtype) = signature; - signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype)); - } - set_has_signature(true); -} - -py::function PrimitivePy::GetBpropFunction() { - static const char *const get_bprop_func_name = "get_bprop"; - if (py::hasattr(python_obj_, get_bprop_func_name)) { - py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); - return fn; - } else { - auto fn = GetBpropFunctionByObj(python_obj_); - return fn; - } -} - -py::function PrimitivePy::GetComputeFunction() { - static const char *const compute_func_name = "vm_impl"; - - if (py::hasattr(python_obj_, compute_func_name)) { - MS_LOG(INFO) << name() << " compute_func_name"; - py::function fn = python_obj_.attr(compute_func_name).cast(); - return fn; - } - - static const std::string vm_module = "mindspore.ops.vm_impl_registry"; - static const std::string get_vm_impl_fn = "get_vm_impl_fn"; - MS_LOG(INFO) << name() << ": get_vm_impl_fn"; - py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn); - py::function vm_fn = get_fn(python_obj_); - - if (py::isinstance(vm_fn)) { - MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast(); - vm_fn = mindspore::GetComputeFunction(Primitive::name()); - } - return vm_fn; -} - -void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { - std::string attr_name = name; - ValuePtr converted_ret = nullptr; - if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; - } - bool converted = parse::ConvertData(obj, &converted_ret); - if (!converted) { - MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); - } - (void)this->AddAttr(attr_name, converted_ret); -} - -py::dict PrimitivePy::GetAttrDict() { - py::dict attr_dict; - for (auto &attr : attrs_) { - attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); - } - return attr_dict; -} - -REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { - (void)py::enum_(*m, "prim_type", py::arithmetic()) - .value("unknown", PrimType::kPrimTypeUnknown) - .value("builtin", PrimType::kPrimTypeBuiltIn) - .value("py_infer_shape", PrimType::kPrimTypePyInferShape) - .value("user_custom", PrimType::kPrimTypeUserCustom); - (void)py::class_>(*m, "Primitive_") - .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) - .def(py::init()) - .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") - .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") - .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") - .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") - .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") - .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); - })); -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h deleted file mode 100644 index 257302c0c4..0000000000 --- a/mindspore/ccsrc/ir/primitive.h +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_H_ -#define MINDSPORE_CCSRC_IR_PRIMITIVE_H_ - -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/abstract_value.h" -#include "utils/misc.h" -#include "utils/log_adapter.h" -#include "ir/primitive_base.h" -#include "ir/signature.h" -#include "parallel/ops_info/operator_info.h" - -namespace mindspore { -class PrimitivePy : public Primitive { - public: - PrimitivePy(const py::str &name, const py::object &python_obj) - : Primitive(name, false), python_obj_(python_obj), signatures_() {} - ~PrimitivePy() override = default; - MS_DECLARE_PARENT(PrimitivePy, Primitive); - py::function GetBpropFunction(); - py::function GetComputeFunction(); - - void set_signatures( - std::vector> - signatures); - - const std::vector &signatures() const { return signatures_; } - - void AddPyAttr(const py::str &name, const py::object &obj); - - py::dict GetAttrDict(); - void set_hook(const py::function &hook) { hook_ = hook; } - py::function hook() const { return hook_; } - - const bool parse_info_ = true; - const py::object &GetPyObj() const { return python_obj_; } - bool is_tuple_input_ = false; - - private: - py::object python_obj_; - py::function hook_; - std::vector signatures_; -}; - -using PrimitivePyPtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ diff --git a/mindspore/ccsrc/ir/primitive_base.cc b/mindspore/ccsrc/ir/primitive_base.cc deleted file mode 100644 index 864427fe13..0000000000 --- a/mindspore/ccsrc/ir/primitive_base.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/primitive_base.h" - -#include - -namespace mindspore { -bool Primitive::operator==(const Value &other) const { - if (other.isa()) { - auto other_prim = static_cast(other); - return *this == other_prim; - } else { - return false; - } -} - -bool Primitive::operator==(const Primitive &other) const { - if (name() != other.name()) { - return false; - } - if (attrs_.size() != other.attrs_.size()) { - return false; - } - auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { - if (item.second == nullptr) { - return false; - } - auto iter = other.attrs_.find(item.first); - if (iter == other.attrs_.end()) { - return false; - } - return *item.second == *iter->second; - }); - return all; -} - -std::string Primitive::GetAttrsText() const { - if (attrs_.empty()) { - return ""; - } - - std::ostringstream oss; - oss << "["; - bool is_first = true; - for (auto &attr : attrs_) { - if (is_first) { - is_first = false; - } else { - oss << ", "; - } - oss << attr.first << "=" << attr.second->DumpText(); - } - oss << "]"; - - return oss.str(); -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/primitive_base.h b/mindspore/ccsrc/ir/primitive_base.h deleted file mode 100644 index b34c43d00e..0000000000 --- a/mindspore/ccsrc/ir/primitive_base.h +++ /dev/null @@ -1,150 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ -#define MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ - -#include -#include -#include -#include -#include - -#include "ir/dtype/type.h" -#include "pybind11/pybind11.h" - -namespace py = pybind11; - -namespace mindspore { -// Supported meta type -enum PrimType { - kPrimTypeUnknown = 0, - kPrimTypeBegin = kTypeUnknown, - kPrimTypeBuiltIn, // Built-in primitive operator - kPrimTypePyInferShape, // Primitive operator defined by custom - kPrimTypePyInferTensor, // Primitive operator defined by custom - kPrimTypeUserCustom -}; - -class Primitive : public Named { - public: - explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) - : Named(name), - is_base_(is_base), - has_signature_(false), - prim_type_(prim_type), - record_evaluate_add_attr_(false) {} - - Primitive(const Primitive &prim) - : Named(prim), - attrs_(prim.attrs_), - instance_name_(prim.instance_name_), - is_base_(prim.is_base_), - has_signature_(prim.has_signature_), - prim_type_(prim.prim_type_), - record_evaluate_add_attr_(false) {} - - MS_DECLARE_PARENT(Primitive, Named); - - abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); - std::string ToString() const override { return name(); } - void BeginRecordAddAttr() { - evaluate_added_attrs_.clear(); - record_evaluate_add_attr_ = true; - } - void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } - Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { - attrs_[name] = attr; - if (record_evaluate_add_attr_) { - evaluate_added_attrs_[name] = attr; - } - return *this; - } - - Primitive &SetAttrs(const std::unordered_map &attrs) { - for (auto &attr : attrs) { - attrs_[attr.first] = attr.second; - } - return *this; - } - - void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } - void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } - - ValuePtr GetAttr(const std::string &attrName) const { - auto iter = attrs_.find(attrName); - return iter == attrs_.cend() ? nullptr : iter->second; - } - - const std::unordered_map &attrs() const { return attrs_; } - const std::unordered_map &evaluate_added_attrs() const { return evaluate_added_attrs_; } - - // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. - bool HasAttr() const { return !attrs_.empty(); } - bool HasAttr(const std::string &attrName) const { - auto iter = attrs_.find(attrName); - return !(iter == attrs_.cend()); - } - void set_prim_type(const PrimType t) { prim_type_ = t; } - void set_instance_name(const std::string s) { instance_name_ = s; } - bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } - bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } - bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } - - PrimType prim_type() const { return prim_type_; } - std::string instance_name() const { return instance_name_; } - std::string GetAttrsText() const; - bool operator==(const Value &other) const override; - bool operator==(const Primitive &other) const; - ~Primitive() override = default; - - void set_has_signature(bool has_signature) { has_signature_ = has_signature; } - bool has_signature() const { return has_signature_; } - bool is_base() const { return is_base_; } - - protected: - std::unordered_map attrs_; - std::unordered_map evaluate_added_attrs_; - - private: - std::string instance_name_; - bool is_base_; - bool has_signature_; - PrimType prim_type_; - bool record_evaluate_add_attr_; -}; - -inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { - os << *p; - return os; -} - -struct PrimitiveEqual { - bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { - MS_EXCEPTION_IF_NULL(t1); - MS_EXCEPTION_IF_NULL(t2); - return t1->name() == t2->name(); - } -}; - -struct PrimitiveHasher { - std::size_t operator()(PrimitivePtr const &prim) const { - MS_EXCEPTION_IF_NULL(prim); - return prim->Hash(); - } -}; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ diff --git a/mindspore/ccsrc/ir/primitive_base_extends.cc b/mindspore/ccsrc/ir/primitive_base_extends.cc deleted file mode 100644 index 64bdafa4d1..0000000000 --- a/mindspore/ccsrc/ir/primitive_base_extends.cc +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/primitive_base.h" -#include "pipeline/static_analysis/abstract_function.h" - -namespace mindspore { -abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { - auto prim_func = std::make_shared(shared_from_base(), anf_node); - return prim_func; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/scalar.h b/mindspore/ccsrc/ir/scalar.h deleted file mode 100644 index e8e29fb2f9..0000000000 --- a/mindspore/ccsrc/ir/scalar.h +++ /dev/null @@ -1,362 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_SCALAR_H_ -#define MINDSPORE_CCSRC_IR_SCALAR_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/base.h" -#include "ir/dtype.h" -#include "ir/dtype/number.h" - -using std::fabs; - -namespace mindspore { -class Scalar : public Value { - public: - Scalar() = default; - explicit Scalar(const TypePtr t) : Value(t) {} - ~Scalar() override = default; - MS_DECLARE_PARENT(Scalar, Value) - virtual bool IsZero() = 0; - virtual bool IsOne() = 0; - abstract::AbstractBasePtr ToAbstract() override; - - protected: - std::size_t hash_ = 0; -}; -using ScalarPtr = std::shared_ptr; - -class BoolImm : public Scalar { - public: - explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash{}(v_); } - ~BoolImm() override = default; - MS_DECLARE_PARENT(BoolImm, Scalar) - std::size_t hash() const override { return hash_; } - bool value() const { return v_; } - bool IsZero() override { return v_ == false; } - bool IsOne() override { return v_ == true; } - bool operator==(const Value &other) const override; - bool operator==(const BoolImm &other) const; - std::string ToString() const override { - if (v_) { - return "true"; - } else { - return "false"; - } - } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "Bool(" << v_ << ")"; - return oss.str(); - } - - private: - bool v_; -}; -using BoolImmPtr = std::shared_ptr; -IMM_TRAITS(BoolImmPtr, bool) - -class IntergerImm : public Scalar { - public: - IntergerImm() = default; - explicit IntergerImm(const TypePtr &t) : Scalar(t) {} - ~IntergerImm() override = default; - MS_DECLARE_PARENT(IntergerImm, Scalar) -}; - -class Int8Imm : public IntergerImm { - public: - Int8Imm() : IntergerImm(kInt8), v_(0) {} - explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash{}(v_); } - ~Int8Imm() override = default; - MS_DECLARE_PARENT(Int8Imm, IntergerImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return v_ == 0; } - bool IsOne() override { return v_ == 1; } - int8_t value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const Int8Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "I8(" << v_ << ")"; - return oss.str(); - } - - private: - int8_t v_; -}; -using Int8ImmPtr = std::shared_ptr; -IMM_TRAITS(Int8ImmPtr, int8_t) - -class Int16Imm : public IntergerImm { - public: - Int16Imm() : IntergerImm(kInt16), v_(0) {} - explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash{}(v_); } - ~Int16Imm() override = default; - MS_DECLARE_PARENT(Int16Imm, IntergerImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return v_ == 0; } - bool IsOne() override { return v_ == 1; } - int16_t value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const Int16Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "I16(" << v_ << ")"; - return oss.str(); - } - - private: - int16_t v_; -}; -using Int16ImmPtr = std::shared_ptr; -IMM_TRAITS(Int16ImmPtr, int16_t) - -class Int32Imm : public IntergerImm { - public: - Int32Imm() : IntergerImm(kInt32), v_(0) {} - explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash{}(v_); } - ~Int32Imm() override = default; - MS_DECLARE_PARENT(Int32Imm, IntergerImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return v_ == 0; } - bool IsOne() override { return v_ == 1; } - int32_t value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const Int32Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "I32(" << v_ << ")"; - return oss.str(); - } - - private: - int32_t v_; -}; -using Int32ImmPtr = std::shared_ptr; -IMM_TRAITS(Int32ImmPtr, int32_t) - -class Int64Imm : public IntergerImm { - public: - Int64Imm() : IntergerImm(kInt64), v_(0) {} - explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash{}(v_); } - ~Int64Imm() override = default; - MS_DECLARE_PARENT(Int64Imm, IntergerImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return v_ == 0; } - bool IsOne() override { return v_ == 1; } - int64_t value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const Int64Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "I64(" << v_ << ")"; - return oss.str(); - } - - private: - int64_t v_; -}; -using Int64ImmPtr = std::shared_ptr; -IMM_TRAITS(Int64ImmPtr, int64_t) - -class UInt8Imm : public IntergerImm { - public: - UInt8Imm() : IntergerImm(kUInt8), v_(0) {} - explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash{}(v_); } - ~UInt8Imm() override = default; - MS_DECLARE_PARENT(UInt8Imm, IntergerImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return v_ == 0; } - bool IsOne() override { return v_ == 1; } - uint8_t value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const UInt8Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "U8(" << v_ << ")"; - return oss.str(); - } - - private: - uint8_t v_; -}; -using UInt8ImmPtr = std::shared_ptr; -IMM_TRAITS(UInt8ImmPtr, uint8_t); - -class UInt16Imm : public IntergerImm { - public: - UInt16Imm() : IntergerImm(kUInt16), v_(0) {} - explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash{}(v_); } - ~UInt16Imm() override = default; - MS_DECLARE_PARENT(UInt16Imm, IntergerImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return v_ == 0; } - bool IsOne() override { return v_ == 1; } - uint16_t value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const UInt16Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "U16(" << v_ << ")"; - return oss.str(); - } - - private: - uint16_t v_; -}; -using UInt16ImmPtr = std::shared_ptr; -IMM_TRAITS(UInt16ImmPtr, uint16_t); - -class UInt32Imm : public IntergerImm { - public: - UInt32Imm() : IntergerImm(kUInt32), v_(0) {} - explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash{}(v_); } - ~UInt32Imm() override = default; - MS_DECLARE_PARENT(UInt32Imm, IntergerImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return v_ == 0; } - bool IsOne() override { return v_ == 1; } - uint32_t value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const UInt32Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "U32(" << v_ << ")"; - return oss.str(); - } - - private: - uint32_t v_; -}; -using UInt32ImmPtr = std::shared_ptr; -IMM_TRAITS(UInt32ImmPtr, uint32_t); - -class UInt64Imm : public IntergerImm { - public: - UInt64Imm() : IntergerImm(kUInt64), v_(0) {} - explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash{}(v); } - ~UInt64Imm() override = default; - MS_DECLARE_PARENT(UInt64Imm, IntergerImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return v_ == 0; } - bool IsOne() override { return v_ == 1; } - uint64_t value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const UInt64Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "U64(" << v_ << ")"; - return oss.str(); - } - - private: - uint64_t v_; -}; -using UInt64ImmPtr = std::shared_ptr; -IMM_TRAITS(UInt64ImmPtr, uint64_t); - -class FloatImm : public Scalar { - public: - FloatImm() = default; - explicit FloatImm(const TypePtr &t) : Scalar(t) {} - ~FloatImm() override = default; - MS_DECLARE_PARENT(FloatImm, Scalar) -}; -using FloatImmPtr = std::shared_ptr; - -class FP32Imm : public FloatImm { - public: - FP32Imm() : FloatImm(kFloat32), v_(0.0) {} - explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash{}(v_); } - ~FP32Imm() override = default; - MS_DECLARE_PARENT(FP32Imm, FloatImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } - bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } - float value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const FP32Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "F32(" << v_ << ")"; - return oss.str(); - } - - private: - float v_; -}; -using FP32ImmPtr = std::shared_ptr; -IMM_TRAITS(FP32ImmPtr, float) - -class FP64Imm : public FloatImm { - public: - FP64Imm() : FloatImm(kFloat64), v_(0.0) {} - explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash{}(v_); } - ~FP64Imm() override = default; - MS_DECLARE_PARENT(FP64Imm, FloatImm) - std::size_t hash() const override { return hash_; } - bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } - bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } - double value() const { return v_; } - bool operator==(const Value &other) const override; - bool operator==(const FP64Imm &other) const; - std::string ToString() const override { return std::to_string(v_); } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "F64(" << v_ << ")"; - return oss.str(); - } - - private: - double v_; -}; -using FP64ImmPtr = std::shared_ptr; -IMM_TRAITS(FP64ImmPtr, double) - -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_SCALAR_H_ diff --git a/mindspore/ccsrc/ir/signature.cc b/mindspore/ccsrc/ir/signature.cc deleted file mode 100644 index 8f312d5b98..0000000000 --- a/mindspore/ccsrc/ir/signature.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/signature.h" - -#include "pybind11/operators.h" -#include "pybind_api/api_register.h" -#include "pipeline/parse/data_converter.h" - -namespace mindspore { -Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, - const py::object &arg_default, const SignatureEnumDType &arg_dtype) - : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { - if (py::isinstance(arg_default) && - py::cast(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { - default_value = nullptr; - } else { - default_value = parse::data_converter::PyDataToValue(arg_default); - } -} - -Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) - : name(arg_name), - rw(rw_tag), - kind(arg_kind), - default_value(nullptr), - dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} - -REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { - (void)py::enum_(*m, "signature_rw", py::arithmetic()) - .value("RW_READ", SignatureEnumRW::kRWRead) - .value("RW_WRITE", SignatureEnumRW::kRWWrite) - .value("RW_REF", SignatureEnumRW::kRWRef) - .value("RW_EMPTY_DEFAULT_VALUE", SignatureEnumRW::kRWEmptyDefaultValue); - (void)py::enum_(*m, "signature_kind", py::arithmetic()) - .value("KIND_POSITIONAL_KEYWORD", SignatureEnumKind::kKindPositionalKeyword) - .value("KIND_VAR_POSITIONAL", SignatureEnumKind::kKindVarPositional) - .value("KIND_KEYWORD_ONLY", SignatureEnumKind::kKindKeywordOnly) - .value("KIND_VAR_KEYWARD", SignatureEnumKind::kKindVarKeyword) - .value("KIND_EMPTY_DEFAULT_VALUE", SignatureEnumKind::kKindEmptyDefaultValue); - (void)py::enum_(*m, "signature_dtype", py::arithmetic()) - .value("T", SignatureEnumDType::kDType) - .value("T1", SignatureEnumDType::kDType1) - .value("T2", SignatureEnumDType::kDType2) - .value("T3", SignatureEnumDType::kDType3) - .value("T4", SignatureEnumDType::kDType4) - .value("T5", SignatureEnumDType::kDType5) - .value("T6", SignatureEnumDType::kDType6) - .value("T7", SignatureEnumDType::kDType7) - .value("T8", SignatureEnumDType::kDType8) - .value("T9", SignatureEnumDType::kDType9) - .value("T_EMPTY_DEFAULT_VALUE", SignatureEnumDType::kDTypeEmptyDefaultValue); - })); -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/signature.h b/mindspore/ccsrc/ir/signature.h deleted file mode 100644 index 48be7e0f31..0000000000 --- a/mindspore/ccsrc/ir/signature.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_SIGNATURE_H_ -#define MINDSPORE_CCSRC_IR_SIGNATURE_H_ -#include -#include - -#include "pybind11/operators.h" -#include "ir/value.h" - -namespace py = pybind11; - -namespace mindspore { -// Input signature, support type -enum SignatureEnumRW { - // describe the arguments action on read and write - kRWRead = 0, // use the value of the input - kRWWrite, // use the key of the input - kRWRef, // use the ref of the input - kRWEmptyDefaultValue, - kRWDefault = kRWRead -}; -enum SignatureEnumKind { - kKindPositionalKeyword = 0, // use value of the input start from this arg - kKindVarPositional, // use key of the input start from this arg - kKindKeywordOnly, - kKindVarKeyword, // use ref of the input start from this arg - kKindEmptyDefaultValue, - kKindDefault = kKindPositionalKeyword -}; -enum SignatureEnumDType { - kDType = 0, - kDType1, - kDType2, - kDType3, - kDType4, - kDType5, - kDType6, - kDType7, - kDType8, - kDType9, - kDTypeEmptyDefaultValue -}; -struct Signature { - std::string name; - SignatureEnumRW rw; - SignatureEnumKind kind; - ValuePtr default_value; // nullptr for no default value - SignatureEnumDType dtype; - Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, - const py::object &arg_default, const SignatureEnumDType &arg_dtype); - Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind); -}; -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_SIGNATURE_H_ diff --git a/mindspore/ccsrc/ir/tensor.cc b/mindspore/ccsrc/ir/tensor.cc deleted file mode 100644 index c06ba2a820..0000000000 --- a/mindspore/ccsrc/ir/tensor.cc +++ /dev/null @@ -1,393 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/tensor.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "device/device_address.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -namespace tensor { - -static std::string MakeId() { - // Use atomic to make id generator thread safe. - static std::atomic last_id{1}; - return "T" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed)); -} - -static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) { - return data_type ? data_type->type_id() : defaultTypeId; -} - -static size_t SizeOf(const std::vector &shape) { - return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); -} - -template -std::vector CopyData(const std::vector &shape, void *data, TypeId data_type) { - const size_t count = SizeOf(shape); - switch (data_type) { - case kNumberTypeBool: - case kNumberTypeUInt8: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeInt8: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeInt16: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeInt32: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeInt64: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeUInt16: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeUInt32: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeUInt64: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeFloat16: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeFloat32: { - const float *buf = static_cast(data); - return std::vector(buf, buf + count); - } - case kNumberTypeFloat64: { - auto buf = static_cast(data); - return std::vector(buf, buf + count); - } - default: - break; - } - MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; -} - -template -std::vector CopyData(const std::vector &shape, void *data, size_t data_len) { - size_t size = SizeOf(shape); - if (size * sizeof(T) != data_len) { - MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T) - << " item size " << sizeof(T); - } - auto buf = static_cast(data); - return {buf, buf + size}; -} - -// Tensor data implementation. -template -class TensorDataImpl : public TensorData { - public: - explicit TensorDataImpl(const std::vector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {} - - TensorDataImpl(const std::vector &shape, void *data, size_t data_len) - : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData(shape, data, data_len)) {} - - TensorDataImpl(const std::vector &shape, void *data, TypeId data_type) - : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData(shape, data, data_type)) {} - - template - TensorDataImpl(const std::vector &shape, InputIt first, InputIt last) - : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(first, last) {} - - template - TensorDataImpl(const std::vector &shape, Scalar scalar) - : ndim_(shape.size()), data_size_(SizeOf(shape)), data_({static_cast(scalar)}) {} - - ssize_t size() const override { return static_cast(data_size_); } - - ssize_t itemsize() const override { return static_cast(sizeof(T)); } - - ssize_t nbytes() const override { return size() * itemsize(); } - - ssize_t ndim() const override { return static_cast(ndim_); } - - void *data() override { - static std::vector empty_data(1); - if (data_size_ == 0) { - // Prevent null pointer for empty shape. - return empty_data.data(); - } - if (data_.empty()) { - // Lazy allocation. - data_.resize(data_size_); - } - return data_.data(); - } - - bool equals(const TensorData &other) const override { - auto ptr = dynamic_cast *>(&other); - if (ptr) { - return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) && (data_ == ptr->data_)); - } - return false; - } - - std::string ToString() const override { - std::ostringstream ss; - ss << '['; - for (auto value : data_) { - ss << value << ','; - } - ss << ']'; - return ss.str(); - } - - private: - size_t ndim_{0}; - size_t data_size_{0}; - std::vector data_; -}; - -template -TensorDataPtr MakeTensorData(TypeId data_type, const std::vector &shape, Args... args) { - switch (data_type) { - case kNumberTypeBool: - case kNumberTypeUInt8: - return std::make_shared>(shape, args...); - case kNumberTypeInt8: - return std::make_shared>(shape, args...); - case kNumberTypeInt16: - return std::make_shared>(shape, args...); - case kNumberTypeInt32: - return std::make_shared>(shape, args...); - case kNumberTypeInt64: - return std::make_shared>(shape, args...); - case kNumberTypeUInt16: - return std::make_shared>(shape, args...); - case kNumberTypeUInt32: - return std::make_shared>(shape, args...); - case kNumberTypeUInt64: - return std::make_shared>(shape, args...); - case kNumberTypeFloat16: - return std::make_shared>(shape, args...); - case kNumberTypeFloat32: - return std::make_shared>(shape, args...); - case kNumberTypeFloat64: - return std::make_shared>(shape, args...); - default: - break; - } - MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; -} - -Tensor::Tensor(const Tensor &tensor) - : MetaTensor(tensor), - init_flag_(tensor.init_flag_), - data_(tensor.data_), - dirty_(tensor.dirty_), - id_(tensor.id_), - device_address_(tensor.device_address_) {} - -Tensor::Tensor(const Tensor &tensor, TypeId data_type) - : MetaTensor(data_type, tensor.shape_), - init_flag_(tensor.init_flag_), - data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)), - dirty_(tensor.dirty_), - id_(tensor.id_), - device_address_(tensor.device_address_) {} - -Tensor::Tensor(TypeId data_type, const std::vector &shape, TensorDataPtr data) - : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {} - -Tensor::Tensor(TypeId data_type, const std::vector &shape) - : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {} - -Tensor::Tensor(TypeId data_type, const std::vector &shape, void *data, size_t data_len) - : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {} - -Tensor::Tensor(TypeId data_type, const std::vector &shape, void *data, TypeId src_data_type) - : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {} - -Tensor::Tensor(const std::vector &input, const TypePtr &data_type) - : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast(input.size())}), - data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())), - id_(MakeId()) {} - -Tensor::Tensor(const std::vector &input, const TypePtr &data_type) - : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast(input.size())}), - data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())), - id_(MakeId()) {} - -Tensor::Tensor(int64_t input, const TypePtr &data_type) - : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}), - data_(MakeTensorData(data_type_, {}, input)), - id_(MakeId()) {} - -Tensor::Tensor(double input, const TypePtr &data_type) - : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {}), - data_(MakeTensorData(data_type_, {}, input)), - id_(MakeId()) {} - -bool Tensor::operator==(const Tensor &tensor) const { - return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_)); -} - -bool Tensor::ValueEqual(const Tensor &tensor) const { - return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_))); -} -// assgin value to this tensor -Tensor &Tensor::AssignValue(const Tensor &tensor) { - if (this != &tensor) { - MetaTensor::operator=(tensor); - dirty_ = tensor.is_dirty(); - device_address_ = tensor.device_address(); - data_ = tensor.data_; - id_ = tensor.id(); - } - return *this; -} -abstract::AbstractBasePtr Tensor::ToAbstract() { - auto tens = shared_from_base(); - auto dtype = tens->Dtype(); - if (!IsSubType(dtype, kNumber)) { - MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << "."; - } - auto tensor_shape = tens->shape(); - auto abs_tensor = std::make_shared(dtype, tensor_shape); - abs_tensor->set_value(shared_from_base()); - return abs_tensor; -} - -std::string Tensor::GetShapeAndDataTypeInfo() const { - std::ostringstream buf; - buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString(); - return buf.str(); -} - -std::string Tensor::ToString() const { - const int small_tensor_size = 30; - std::ostringstream buf; - buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString(); - // only print small tensor - if (DataSize() < small_tensor_size) { - buf << "val:" << data().ToString(); - } - return buf.str(); -} - -std::string Tensor::ToStringRepr() const { - std::ostringstream buf; - auto type_ptr = this->Dtype(); - MS_EXCEPTION_IF_NULL(type_ptr); - buf << "Tensor shape:[" << shape() << "]" << type_ptr->ToString(); - buf << "\nval:" << data().ToString(); - return buf.str(); -} - -void Tensor::data_sync() const { - if (device_address_ != nullptr) { - if (!device_address_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { - MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy."; - } - } -} - -TypeId Tensor::set_data_type(const TypeId data_type) { - if (data_type != data_type_) { - data_ = MakeTensorData(data_type, shape_, data_->data(), data_type_); - return MetaTensor::set_data_type(data_type); - } - return data_type; -} -} // namespace tensor - -namespace inference { -MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector &shape) { - return new Tensor(data_type, shape); -} - -Tensor::Tensor(TypeId data_type, const std::vector &shape) { - this->tensor_impl_ = std::make_shared(data_type, shape); -} - -Tensor::Tensor(std::shared_ptr tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); } - -TypeId Tensor::data_type() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->data_type(); -} - -TypeId Tensor::set_data_type(TypeId data_type) { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->set_data_type(data_type); -} - -std::vector Tensor::shape() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->shape(); -} - -size_t Tensor::set_shape(const std::vector &shape) { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->set_shape(shape); -} - -int Tensor::DimensionSize(size_t index) const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->DimensionSize(index); -} - -int Tensor::ElementsNum() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->ElementsNum(); -} - -std::size_t Tensor::hash() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->hash(); -} - -std::shared_ptr Tensor::tensor() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_; -} - -size_t Tensor::Size() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->data().nbytes(); -} - -void *Tensor::MutableData() const { - MS_ASSERT(this->tensor_impl_ != nullptr); - return this->tensor_impl_->data_c(); -} - -} // namespace inference -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/tensor.h b/mindspore/ccsrc/ir/tensor.h deleted file mode 100644 index 5be8a063c1..0000000000 --- a/mindspore/ccsrc/ir/tensor.h +++ /dev/null @@ -1,279 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_TENSOR_H_ -#define MINDSPORE_CCSRC_IR_TENSOR_H_ - -#include -#include -#include -#include - -#include "Eigen/Core" -#include "device/device_address.h" -#include "ir/meta_tensor.h" -#include "include/ms_tensor.h" -#include "utils/log_adapter.h" - -using float16 = Eigen::half; - -using mindspore::device::DeviceAddress; -using DeviceAddressPtr = std::shared_ptr; -// brief mindspore namespace. -// -// mindspore namespace is the top level namespace of MindSpore project. -// Other namespace should be a sub namespace of mindspore namespace in the ME project. -namespace mindspore { -// brief mindspore::tensor namespace -// -// A sub namespace in ME to support tensor related definition. -namespace tensor { -// Tensor data interface. -class TensorData { - public: - /// Total number of elements. - virtual ssize_t size() const = 0; - /// Byte size of a single element. - virtual ssize_t itemsize() const = 0; - /// Total number of bytes. - virtual ssize_t nbytes() const = 0; - /// Number of dimensions. - virtual ssize_t ndim() const = 0; - /// Data pointer. - virtual void *data() = 0; - /// Is data equals. - virtual bool equals(const TensorData &other) const = 0; - /// To string. - virtual std::string ToString() const = 0; -}; - -using TensorDataPtr = std::shared_ptr; - -// Tensor entity class -class Tensor : public MetaTensor { - public: - abstract::AbstractBasePtr ToAbstract() override; - - // brief Create tensor from another tensor, data is shared. - // - // param tensor [Tensor] The input tensor. - explicit Tensor(const Tensor &tensor); - - // brief Create tensor with given data type from another tensor. - // - // param tensor [Tensor] The input tensor. - // param data_type [TypeId] The new tensor data type. - Tensor(const Tensor &tensor, TypeId data_type); - - // brief Create tensor with the given shared tensor data. - // - // param data_type [TypeId] Data type of the tensor. - // param shape The shape represented by std::vector of the tensor. - // param data The shared tensor data. - Tensor(TypeId data_type, const std::vector &shape, TensorDataPtr data); - - // brief Create an all zero tensor. - // - // param data_type [TypeId] Data type of the tensor. - // param shape The shape represented by std::vector of the tensor. - Tensor(TypeId data_type, const std::vector &shape); - - // brief Create a tensor with input data buffer. - // - // param data_type [TypeId] Data type of the tensor. - // param shape The shape represented by std::vector of the tensor. - // param data The input data to be copied into tensor. - // param data_len The length of data in bytes. - Tensor(TypeId data_type, const std::vector &shape, void *data, size_t data_len); - - // brief Create a tensor with input data buffer and given source data type. - // - // param data_type [TypeId] Data type of the tensor. - // param shape The shape represented by std::vector of the tensor. - // param data The input data to be copied into tensor. - // param src_data_type The source data type. - Tensor(TypeId data_type, const std::vector &shape, void *data, TypeId src_data_type); - - // brief Create 1 dimension tensor from an int vector. - // - // param input [std::vector] the data for tensor - // param data_type [TypeId] data type - explicit Tensor(const std::vector &input, const TypePtr &data_type = nullptr); - - // brief Create 1 dimension tensor from a float vector. - // - // param input [std::vector] the data for tensor - // param data_type [TypeId] data type - explicit Tensor(const std::vector &input, const TypePtr &data_type = nullptr); - - // brief Create 0 dimension tensor from an int scalar. - // - // param input [int64] the data for tensor - // param data_type [TypeId] data type - explicit Tensor(int64_t input, const TypePtr &data_type = nullptr); - - // brief Create 0 dimension tensor from a float scalar. - // - // param input [double] the data for tensor - // param data_type [TypeId] data type - explicit Tensor(double input, const TypePtr &data_type = nullptr); - - ~Tensor() override = default; - - MS_DECLARE_PARENT(Tensor, MetaTensor); - - // brief Compares two Tensor objects. - // - // Compare two tensor objects to see if they have same data type, shape and data address. - // - // param tensor The Tensor object to be compared. - // return true: If having same type, shape and data address, return true, or return false. - bool operator==(const Tensor &tensor) const; - - // It is different from 'operator==' which just compare shape/type/address, - // it do real value comparison. - bool ValueEqual(const Tensor &tensor) const; - - // assgin value to this tensor - Tensor &AssignValue(const Tensor &tensor); - - bool operator==(const Value &other) const override { - if (other.isa()) { - auto &other_ = static_cast(other); - return *this == other_; - } - return false; - } - - // brief Gets tensor's dimension - // - // return The number of dimensions of the tensor data. - int DataDim() const { return static_cast(data().ndim()); } - - // brief Getting tensor data size - // - // return The total number of elements of the tensor data. - int DataSize() const { return static_cast(data().size()); } - - // brief Get the data type fo the tensor for C++ - // - // return [int] The tensor's data type will be cast to int to return. - int data_type_c() const { return static_cast(data_type_); } - - // brief Get the tensor's shape for C++ - // - // return [std::vector] - std::vector shape_c(void) const { return shape(); } - - // brief Get Tensor data pointer for c++ type - // - // param writable true if writable, false if read only - // return The pointer to the object - void *data_c() { return data().data(); } - - // brief Get Tensor data byte-size for c++ type - // - // return byte size of Tensor data - size_t Size() const { return data().nbytes(); } - - void *data_c() const { return data_->data(); } - - // brief Sync data with device. - void data_sync() const; - - // brief Get the internal data object. - // - // return The reference to internal data object. - TensorData &data() { return *data_; } - - // brief Get the internal data shared pointer. - // - // return The reference to internal data object. - const TensorDataPtr &data_ptr() const { return data_; } - - // brief Get the internal data object. - // - // return The reference to internal data object. - const TensorData &data() const { return *data_; } - - TypeId set_data_type(const TypeId data_type) override; - - std::string GetShapeAndDataTypeInfo() const; - - std::string ToString() const override; - - std::string ToStringRepr() const; - - bool is_init() { return init_flag_; } - void set_init_flag(bool flag) { init_flag_ = flag; } - - bool is_dirty() const { return dirty_; } - void set_dirty(const bool dirty) { dirty_ = dirty; } - - DeviceAddressPtr device_address() const { return device_address_; } - void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } - - std::string id() const { return id_; } - - const bool parse_info_ = true; - - private: - bool init_flag_{false}; - TensorDataPtr data_{nullptr}; - bool dirty_{true}; - std::string id_{""}; - DeviceAddressPtr device_address_{nullptr}; -}; -using TensorPtr = std::shared_ptr; -using TensorPtrList = std::vector>; -} // namespace tensor - -namespace inference { -class Tensor : public MSTensor { - public: - Tensor(TypeId data_type, const std::vector &shape); - - explicit Tensor(std::shared_ptr tensor_ptr); - - ~Tensor() = default; - - TypeId data_type() const override; - - TypeId set_data_type(const TypeId data_type) override; - - std::vector shape() const override; - - size_t set_shape(const std::vector &shape) override; - - int DimensionSize(size_t index) const override; - - int ElementsNum() const override; - - std::size_t hash() const override; - - std::shared_ptr tensor() const; - - size_t Size() const override; - - void *MutableData() const override; - - protected: - std::shared_ptr tensor_impl_; -}; -} // namespace inference -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_TENSOR_H_ diff --git a/mindspore/ccsrc/ir/tensor_py.cc b/mindspore/ccsrc/ir/tensor_py.cc deleted file mode 100644 index 11a000cef7..0000000000 --- a/mindspore/ccsrc/ir/tensor_py.cc +++ /dev/null @@ -1,390 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/tensor_py.h" - -#include -#include -#include -#include -#include - -#include "device/device_address.h" -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -namespace tensor { - -static TypeId GetDataType(const py::buffer_info &buf) { - if (buf.format.size() == 1) { - switch (buf.format.front()) { - case 'e': - case 'f': - case 'd': - switch (buf.itemsize) { - case 2: - return TypeId::kNumberTypeFloat16; - case 4: - return TypeId::kNumberTypeFloat32; - case 8: - return TypeId::kNumberTypeFloat64; - } - break; - case 'b': - case 'h': - case 'i': - case 'l': - case 'q': - switch (buf.itemsize) { - case 1: - return TypeId::kNumberTypeInt8; - case 2: - return TypeId::kNumberTypeInt16; - case 4: - return TypeId::kNumberTypeInt32; - case 8: - return TypeId::kNumberTypeInt64; - } - break; - case 'B': - case 'H': - case 'I': - case 'L': - case 'Q': - switch (buf.itemsize) { - case 1: - return TypeId::kNumberTypeUInt8; - case 2: - return TypeId::kNumberTypeUInt16; - case 4: - return TypeId::kNumberTypeUInt32; - case 8: - return TypeId::kNumberTypeUInt64; - } - break; - case '?': - return TypeId::kNumberTypeBool; - } - } - MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize; - return TypeId::kTypeUnknown; -} - -static std::string GetPyTypeFormat(TypeId data_type) { - switch (data_type) { - case TypeId::kNumberTypeFloat16: - return "e"; - case TypeId::kNumberTypeFloat32: - return py::format_descriptor::format(); - case TypeId::kNumberTypeFloat64: - return py::format_descriptor::format(); - case TypeId::kNumberTypeUInt8: - return py::format_descriptor::format(); - case TypeId::kNumberTypeUInt16: - return py::format_descriptor::format(); - case TypeId::kNumberTypeUInt32: - return py::format_descriptor::format(); - case TypeId::kNumberTypeUInt64: - return py::format_descriptor::format(); - case TypeId::kNumberTypeInt8: - return py::format_descriptor::format(); - case TypeId::kNumberTypeInt16: - return py::format_descriptor::format(); - case TypeId::kNumberTypeInt32: - return py::format_descriptor::format(); - case TypeId::kNumberTypeInt64: - return py::format_descriptor::format(); - case TypeId::kNumberTypeBool: - return py::format_descriptor::format(); - default: - MS_LOG(WARNING) << "Unsupported DataType " << data_type << "."; - return ""; - } -} - -static bool IsCContiguous(const py::array &input) { - auto flags = static_cast(input.flags()); - return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0; -} - -TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) { - // Get input buffer info. - py::buffer_info buf = input.request(); - // Check data types. - auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown; - auto buf_type = GetDataType(buf); - if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) { - MS_LOG(EXCEPTION) << "Unsupported tensor type!"; - } - // Use buf type as data type if type_ptr not set. - if (data_type == TypeId::kTypeUnknown) { - data_type = buf_type; - } - // Convert input array to C contiguous if need. - std::unique_ptr tmp_buf; - if (!IsCContiguous(input)) { - Py_buffer pybuf; - if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) { - MS_LOG(EXCEPTION) << "Failed to get buffer from the input!"; - } - tmp_buf = std::make_unique(pybuf.len); - if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) { - MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer."; - } - PyBuffer_Release(&pybuf); - buf.ptr = tmp_buf.get(); - } - // Get tensor shape. - std::vector shape(buf.shape.begin(), buf.shape.end()); - if (data_type == buf_type) { - // Use memory copy if input data type is same as the required type. - return std::make_shared(data_type, shape, buf.ptr, buf.size * buf.itemsize); - } - // Create tensor with data type converted. - return std::make_shared(data_type, shape, buf.ptr, buf_type); -} - -static std::vector GetStrides(const std::vector &shape, ssize_t item_size) { - std::vector strides; - strides.reserve(shape.size()); - const auto ndim = shape.size(); - for (size_t i = 0; i < ndim; ++i) { - auto stride = item_size; - for (size_t j = i + 1; j < ndim; ++j) { - stride *= shape[j]; - } - strides.push_back(stride); - } - return strides; -} - -static py::buffer_info GetPyBufferInfo(const Tensor &tensor) { - std::vector shape(tensor.shape().begin(), tensor.shape().end()); - std::vector strides = GetStrides(shape, tensor.data().itemsize()); - return py::buffer_info{ - tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides}; -} - -py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) { - auto &shape = tensor.shape(); - py::tuple dims(shape.size()); - for (size_t i = 0; i < dims.size(); ++i) { - dims[i] = py::int_(shape[i]); - } - return dims; -} - -py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { - tensor.data_sync(); - auto info = GetPyBufferInfo(tensor); - py::object self = py::cast(&tensor); - return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); -} - -py::array TensorPy::AsNumpy(const Tensor &tensor) { - auto info = GetPyBufferInfo(tensor); - py::object self = py::cast(&tensor); - return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); -} - -static std::vector GetShapeFromTuple(const py::tuple &tuple) { - std::vector shape; - const size_t size = tuple.size(); - shape.reserve(tuple.size()); - for (size_t i = 0; i < size; ++i) { - shape.push_back(py::int_(tuple[i])); - } - return shape; -} - -REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { - // Define python Tensor class. - // dtype should define before Tensor, because Tensor init depend dtype - (void)py::class_>(*m, "Tensor") - .def(py::init([](const Tensor &tensor) { return std::make_shared(tensor); }), - py::arg("input")) - .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) { - TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown; - if (data_type == kTypeUnknown || tensor.data_type() == data_type) { - return std::make_shared(tensor); - } - return std::make_shared(tensor, data_type); - }), - py::arg("input"), py::arg("dtype")) - .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) { - auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64; - return std::make_shared(data_type, GetShapeFromTuple(shape)); - }), - py::arg("dtype"), py::arg("shape")) - .def(py::init([](const py::array &input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(input, type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def(py::init([](py::float_ input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(py::array(input), type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def(py::init([](py::int_ input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(py::array(input), type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def(py::init([](py::list input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(py::array(input), type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def(py::init([](py::tuple input, const TypePtr &type_ptr) { - return TensorPy::MakeTensor(py::array(input), type_ptr); - }), - py::arg("input"), py::arg("dtype") = nullptr) - .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_) - .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter( - Get the tensor's data type. - - Returns: - type, the data type of tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) - >>> data.dtype - Int32 - )mydelimiter") - .def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter( - Get the tensor's shape. - - Returns: - tuple[int], the shape of tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((3, 3))) - >>> data.shape() - (3, 3) - )mydelimiter") - .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter( - Convert tensor to numpy.ndarray. - - Returns: - numpy.ndarray. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> array = data.asnumpy() - >>> array - array([[1., 1., 1.], - [1., 1., 1.]]) - )mydelimiter") - .def("size", &Tensor::DataSize, R"mydelimiter( - Get tensor's data size. - - Returns: - int, the size of tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> data.size() - 6 - )mydelimiter") - .def("is_init", &Tensor::is_init, R"mydelimiter( - Get tensor init_flag. - - Returns: - bool, whether the tensor init. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> data.is_init() - False - )mydelimiter") - .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter( - Set tensor init_flag. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> data.set_init_flag(True) - )mydelimiter") - .def("dim", &Tensor::DataDim, R"mydelimiter( - Get tensor's data dimension. - - Returns: - int, the dimension of tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((2, 3))) - >>> data.dim() - 2 - )mydelimiter") - .def("assign_value", &Tensor::AssignValue, R"mydelimiter( - Assign another tensor value to this. - - Arg: - value (:class:`mindspore.tensor`): The value tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) - >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32)) - >>> data.assign_value(data2) - >>> data.shape - (2, 2) - )mydelimiter") - .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( - Set the tensor's data type. - - Arg: - dtype (:class:`mindspore.dtype`): The type of output tensor. - - Examples: - >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) - >>> data.set_dtype(mindspore.int32) - mindspore.int32 - )mydelimiter") - .def("__str__", &Tensor::ToString) - .def("__repr__", &Tensor::ToStringRepr) - .def(py::pickle( - [](const Tensor &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(TensorPy::AsNumpy(t)); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - return TensorPy::MakeTensor(t[0].cast()); - })); - // Define python MetaTensor class. - (void)py::class_>(*m, "MetaTensor") - .def(py::init>(), py::arg("dtype"), py::arg("shape")) - .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_) - .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") - .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") - .def(py::pickle( - [](const MetaTensor &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(static_cast(t.data_type()), t.shape()); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 2) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - MetaTensor tensor(TypeId(t[0].cast()), t[1].cast>()); - return tensor; - })); - })); - -} // namespace tensor -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/tensor_py.h b/mindspore/ccsrc/ir/tensor_py.h deleted file mode 100644 index 18ee547071..0000000000 --- a/mindspore/ccsrc/ir/tensor_py.h +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_TENSOR_PY_H_ -#define MINDSPORE_CCSRC_IR_TENSOR_PY_H_ - -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "pybind11/numpy.h" - -#include "ir/tensor.h" - -namespace py = pybind11; - -namespace pybind11 { -namespace detail { -// Similar to enums in `pybind11/numpy.h`. Determined by doing: -// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' -constexpr int NPY_FLOAT16 = 23; - -template -struct npy_scalar_caster { - PYBIND11_TYPE_CASTER(T, _("PleaseOverride")); - using Array = array_t; - - bool load(handle src, bool convert) { - // Taken from Eigen casters. Permits either scalar dtype or scalar array. - handle type = dtype::of().attr("type"); - if (!convert && !isinstance(src) && !isinstance(src, type)) return false; - - Array tmp = Array::ensure(src); - if (tmp && tmp.size() == 1 && tmp.ndim() == 0) { - this->value = *tmp.data(); - return true; - } - - return false; - } - - static handle cast(T src, return_value_policy, handle) { - Array tmp({1}); - tmp.mutable_at(0) = src; - tmp.resize({}); - - // You could also just return the array if you want a scalar array. - object scalar = tmp[tuple()]; - return scalar.release(); - } -}; - -template <> -struct npy_format_descriptor { - static constexpr auto name = "float16"; - static pybind11::dtype dtype() { - handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); - return reinterpret_borrow(ptr); - } - virtual ~npy_format_descriptor() {} -}; - -template <> -struct type_caster : public npy_scalar_caster { - static constexpr auto name = "float16"; -}; -} // namespace detail -} // namespace pybind11 - -using mindspore::device::DeviceAddress; -using DeviceAddressPtr = std::shared_ptr; -// brief mindspore namespace. -// -// mindspore namespace is the top level namespace of Mindsporeession project. -// Other namespace should be a sub namespace of mindspore namespace in the ME project. -namespace mindspore { -// brief mindspore::tensor namespace -// -// A sub namespace in ME to support tensor related definition. -namespace tensor { -// Tensor python wrapper and adapter class. -class TensorPy { - public: - // brief Create Tensor from a numpy array object. - // - // param input [py::array] Data value of the tensor. - // param data_type [TypeId] Data type of the tensor. - static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr); - - static py::array SyncAsNumpy(const Tensor &tensor); - - static py::array AsNumpy(const Tensor &tensor); - - static py::tuple GetPyTupleShape(const Tensor &tensor); -}; - -} // namespace tensor -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_TENSOR_PY_H_ diff --git a/mindspore/ccsrc/ir/value.h b/mindspore/ccsrc/ir/value.h deleted file mode 100644 index ea9bb47ffe..0000000000 --- a/mindspore/ccsrc/ir/value.h +++ /dev/null @@ -1,306 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_VALUE_H_ -#define MINDSPORE_CCSRC_IR_VALUE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/base.h" -#include "ir/anf.h" -#include "ir/dtype.h" -#include "ir/scalar.h" -#include "ir/dtype/ref.h" -#include "utils/hashing.h" -#include "common/utils.h" - -namespace mindspore { -class ValueSequeue : public Value { - public: - explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) { - TypePtrList t_list; - (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) { - MS_EXCEPTION_IF_NULL(ele); - return ele->type(); - }); - TypePtr t = std::make_shared(t_list); - type_ = t; - } - ValueSequeue(const std::initializer_list &elements) : elements_(elements.begin(), elements.end()) { - TypePtrList t_list; - (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), - [](const ValuePtr &ele) { return ele->type(); }); - TypePtr t = std::make_shared(t_list); - type_ = t; - } - ~ValueSequeue() override = default; - MS_DECLARE_PARENT(ValueSequeue, Value) - std::size_t hash() const override { return hash_combine(tid(), std::hash{}(elements_.size())); } - std::size_t size() const { return elements_.size(); } - bool erase(size_t idx); - const ValuePtr operator[](const std::size_t &dim) const; - const ValuePtrList &value() const { return elements_; } - bool operator==(const Value &other) const override; - bool operator==(const ValueSequeue &other) const; - std::string ToString() const override; - std::string DumpText() const override; - - protected: - ValuePtrList elements_; -}; -using ValueSequeuePtr = std::shared_ptr; - -class ValueTuple : public ValueSequeue { - public: - explicit ValueTuple(const std::vector &elements) : ValueSequeue(elements) {} - ValueTuple(const std::initializer_list &elements) : ValueSequeue(elements) {} - ~ValueTuple() override = default; - MS_DECLARE_PARENT(ValueTuple, ValueSequeue) - abstract::AbstractBasePtr ToAbstract() override; - - std::string DumpText() const override { return "(" + ValueSequeue::DumpText() + ")"; } - std::string ToString() const override { return "(" + ValueSequeue::ToString() + ")"; } -}; -using ValueTuplePtr = std::shared_ptr; - -class ValueList : public ValueSequeue { - public: - explicit ValueList(const std::vector &elements) : ValueSequeue(elements) {} - ValueList(const std::initializer_list &elements) : ValueSequeue(elements) {} - ~ValueList() override = default; - MS_DECLARE_PARENT(ValueList, ValueSequeue) - abstract::AbstractBasePtr ToAbstract() override; - - std::string DumpText() const override { return "[" + ValueSequeue::DumpText() + "]"; } - std::string ToString() const override { return "[" + ValueSequeue::ToString() + "]"; } -}; -using ValueListPtr = std::shared_ptr; - -inline ValuePtr MakeValue(const std::vector &v) { return std::make_shared(v); } -inline ValuePtr MakeValue(std::initializer_list v) { return std::make_shared(v); } - -template -struct is_vector : public std::false_type {}; -template -struct is_vector> : public std::true_type {}; - -template ::value, typename T::value_type>::type> -ValuePtr MakeValue(const T &vec) { - std::vector list; - (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); - return std::make_shared(list); -} - -class ValueSlice : public Value { - public: - ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step) - : start_(start), stop_(stop), step_(step) {} - ~ValueSlice() override = default; - MS_DECLARE_PARENT(ValueSlice, Value) - std::size_t hash() const override; - bool operator==(const Value &other) const override; - bool operator==(const ValueSlice &other) const; - - std::string ToString() const override; - - abstract::AbstractBasePtr ToAbstract() override; - std::string DumpText() const override { return ToString(); } - ValuePtr start() const { return start_; } - ValuePtr stop() const { return stop_; } - ValuePtr step() const { return step_; } - - private: - ValuePtr start_; - ValuePtr stop_; - ValuePtr step_; -}; -using ValueSlicePtr = std::shared_ptr; - -class KeywordArg : public Value { - public: - KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {} - ~KeywordArg() override = default; - MS_DECLARE_PARENT(KeywordArg, Value) - std::size_t hash() const override; - ValuePtr get_value() const { return value_; } - bool operator==(const Value &other) const override; - bool operator==(const KeywordArg &other) const; - - std::string ToString() const override; - - abstract::AbstractBasePtr ToAbstract() override; - std::string DumpText() const override { return ToString(); } - - private: - std::string key_; - ValuePtr value_; -}; -using KeywordArgPtr = std::shared_ptr; - -class ValueDictionary : public Value { - public: - explicit ValueDictionary(const std::vector> &key_values) : key_values_(key_values) {} - ~ValueDictionary() override = default; - MS_DECLARE_PARENT(ValueDictionary, Value) - std::size_t hash() const override { return hash_combine(tid(), std::hash{}(key_values_.size())); } - std::size_t size() const { return key_values_.size(); } - const ValuePtr operator[](const std::string &key) const; - const std::vector> &value() const { return key_values_; } - bool operator==(const Value &other) const override; - bool operator==(const ValueDictionary &other) const; - - std::string ToString() const override { - std::ostringstream buffer; - std::vector keys; - std::vector values; - for (const auto &kv : key_values_) { - keys.push_back(kv.first); - values.push_back(kv.second); - } - buffer << "(Dict: " - << " keys:("; - for (const auto &key : keys) { - buffer << key << ", "; - } - buffer << ") values:("; - for (const auto &value : values) { - MS_EXCEPTION_IF_NULL(value); - buffer << value->DumpText() << ", "; - } - buffer << ")"; - return buffer.str(); - } - abstract::AbstractBasePtr ToAbstract() override; - std::string DumpText() const override { return ToString(); } - - private: - std::vector> key_values_; -}; -using ValueDictionaryPtr = std::shared_ptr; - -class StringImm : public Value { - public: - explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash{}(str_)) {} - - ~StringImm() override = default; - MS_DECLARE_PARENT(StringImm, Value) - std::size_t hash() const override { return hash_; } - const std::string &value() const { return str_; } - bool operator==(const Value &other) const override; - bool operator==(const StringImm &other) const; - abstract::AbstractBasePtr ToAbstract() override; - std::string ToString() const override { return str_; } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "\"" << str_ << "\""; - return oss.str(); - } - - private: - std::string str_; - std::size_t hash_ = 0; -}; -using StringImmPtr = std::shared_ptr; -IMM_TRAITS(StringImmPtr, std::string) -IMM_TRAITS(StringImmPtr, const char *) - -class RefKey : public Value { - public: - explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} - - ~RefKey() override = default; - MS_DECLARE_PARENT(RefKey, Value) - std::size_t hash() const override { return hash_; } - const std::string &tag() const { return tag_; } - bool operator==(const Value &other) const override; - bool operator==(const RefKey &other) const; - abstract::AbstractBasePtr ToAbstract() override; - std::string ToString() const override { return "RefKey[" + tag_ + "]"; } - - std::string DumpText() const override { - std::ostringstream oss; - oss << "RefKey[\"" << tag_ << "\"]"; - return oss.str(); - } - - private: - std::string tag_; - std::size_t hash_ = 0; -}; -using RefKeyPtr = std::shared_ptr; - -class AnyValue : public Value { - public: - AnyValue() = default; - ~AnyValue() override = default; - MS_DECLARE_PARENT(AnyValue, Value) - std::size_t hash() const override { return tid(); } - bool operator==(const Value &other) const override; - abstract::AbstractBasePtr ToAbstract() override; -}; -extern const ValuePtr kAnyValue; - -template <> -inline const char *GetValue(const ValuePtr &value) { - if (value == nullptr) { - MS_LOG(EXCEPTION) << "Value is nullptr"; - } - auto imm = value->cast(); - if (imm == nullptr) { - MS_LOG(EXCEPTION) << "GetValue:" << value->ToString() << ", Type:" << value->type_name(); - } - return common::SafeCStr(imm->value()); -} - -template ::type, - typename U = typename std::enable_if::value, typename S::value_type>::type> -std::vector GetValue(const ValuePtr &value) { - if (value == nullptr) { - MS_LOG(EXCEPTION) << "Value is nullptr"; - } - - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Error GetValue for value: " << value->ToString() << ", type: vector<" << typeid(U).name() - << ">"; - } - std::vector rets; - const std::vector &vals = value->cast()->value(); - (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), - [](const ValuePtr &v) { return GetValue(v); }); - return rets; -} - -inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared(t); } - -template ::value>::type> -inline ValueNodePtr NewValueNode(const std::shared_ptr &x) { - return NewValueNode(MakeValue(x)); -} - -template ::value>::type> -inline ValueNodePtr NewValueNode(const T &x) { - return NewValueNode(MakeValue(x)); -} -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_IR_VALUE_H_ diff --git a/mindspore/ccsrc/ir/value_extends.cc b/mindspore/ccsrc/ir/value_extends.cc deleted file mode 100644 index 8eb34d0eeb..0000000000 --- a/mindspore/ccsrc/ir/value_extends.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/value.h" -#include -#include -#include -#include - -#include "pybind_api/api_register.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -using ContextPtr = abstract::AnalysisContextPtr; - -abstract::AbstractBasePtr Scalar::ToAbstract() { - return std::make_shared(shared_from_base()); -} - -abstract::AbstractBasePtr StringImm::ToAbstract() { - return std::make_shared(shared_from_base(), std::make_shared()); -} - -abstract::AbstractBasePtr RefKey::ToAbstract() { - auto refkey = std::make_shared(); - refkey->set_value(shared_from_base()); - return refkey; -} - -abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared(); } - -abstract::AbstractBasePtr ValueTuple::ToAbstract() { - abstract::AbstractBasePtrList a_list; - (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { - MS_EXCEPTION_IF_NULL(ele); - return ele->ToAbstract(); - }); - return std::make_shared(a_list); -} - -abstract::AbstractBasePtr ValueList::ToAbstract() { - abstract::AbstractBasePtrList a_list; - (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { - MS_EXCEPTION_IF_NULL(ele); - return ele->ToAbstract(); - }); - return std::make_shared(a_list); -} - -abstract::AbstractBasePtr ValueSlice::ToAbstract() { - MS_EXCEPTION_IF_NULL(start_); - MS_EXCEPTION_IF_NULL(stop_); - MS_EXCEPTION_IF_NULL(step_); - abstract::AbstractBasePtr start = start_->ToAbstract(); - abstract::AbstractBasePtr end = stop_->ToAbstract(); - abstract::AbstractBasePtr step = step_->ToAbstract(); - return std::make_shared(start, end, step); -} - -abstract::AbstractBasePtr KeywordArg::ToAbstract() { - MS_EXCEPTION_IF_NULL(value_); - abstract::AbstractBasePtr argument = value_->ToAbstract(); - return std::make_shared(key_, argument); -} - -abstract::AbstractBasePtr ValueDictionary::ToAbstract() { - std::vector> kv; - (void)std::transform( - key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const std::pair &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); - return std::make_shared(kv); -} - -REGISTER_PYBIND_DEFINE( - RefKey, ([](const py::module *m) { - (void)py::class_>(*m, "RefKey").def(py::init(), py::arg("tag")); - })); -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/CMakeLists.txt b/mindspore/ccsrc/kernel/CMakeLists.txt deleted file mode 100644 index ceea6b1a99..0000000000 --- a/mindspore/ccsrc/kernel/CMakeLists.txt +++ /dev/null @@ -1,58 +0,0 @@ -file(GLOB_RECURSE KERNEL_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "kernel_build_info.cc" - "kash/*.cc" - "common_utils.cc" - "oplib/*.cc" -) - -if (ENABLE_D) - file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "kernel_query.cc" - "kernel_fusion.cc" - "akg/ascend/*.cc" - "akg/akg_kernel_build.cc" - "akg/akg_kernel_attrs_process.cc" - "akg/akg_kernel_metadata.cc" - "tbe/*.cc" - "aicpu/*.cc" - "rts/*.cc" - "hccl/*.cc" - ) - add_compile_definitions(ENABLE_D) -endif () - -if (ENABLE_CPU) - file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "cpu/*.cc" - ) - - if (NOT ENABLE_MPI) - list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc") - list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc") - list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_comm_grad_cpu_kernel.cc") - endif () -endif () - -if (ENABLE_GPU) - file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "gpu/*.cu" - "akg/gpu/*.cc" - "akg/akg_kernel_build.cc" - "akg/akg_kernel_attrs_process.cc" - ) - - file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc") - list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_gpu_kernel.cc") - - if (ENABLE_MPI) - include(ExternalProject) - file(GLOB_RECURSE GPU_NCCL_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/nccl/*.cc") - list(APPEND GPU_SRC_LIST ${GPU_NCCL_LIST}) - endif () - - # add_library(_mindspore_kernel_cuda_obj OBJECT ${CUDA_SRC_LIST}) -endif() - -set_property(SOURCE ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_KERNEL) -add_library(_mindspore_kernel_obj OBJECT ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST}) diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.h b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.h deleted file mode 100644 index a3c24ae49e..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.h +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ -#include -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.cc deleted file mode 100644 index 3670a2d76f..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/aicpu/aicpu_kernel_metadata.h" -#include -#include -#include "kernel/oplib/oplib.h" -#include "kernel/common_utils.h" -#include "kernel/aicpu/aicpu_util.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { - MS_LOG(INFO) << "AicpuMetadataInfo."; - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - if (op_name == kInitDataSetQueue) { - op_name = kInitData; - } - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); - if (op_info_ptr == nullptr) { - MS_LOG(DEBUG) << "Aicpu does not have op [" << op_name << "]"; - return; - } - // For compatibility with the current framework - if (op_name == kPrint || op_name == kGetNext || op_name == kPack) { - std::vector inputs_format{}; - std::vector inputs_type{}; - if (op_name == kPrint || op_name == kPack) { - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - inputs_format.emplace_back(kOpFormat_DEFAULT); - inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); - } - } - std::vector outputs_format; - std::vector outputs_type; - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - outputs_format.emplace_back(kOpFormat_DEFAULT); - outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); - } - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - builder.SetInputsFormat(inputs_format); - builder.SetInputsDeviceType(inputs_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetProcessor(AICPU); - builder.SetKernelType(AICPU_KERNEL); - builder.SetFusionType(OPAQUE); - kernel_info_list->push_back(builder.Build()); - return; - } - if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) { - MS_LOG(WARNING) << "Aicpu parsed metadata op [" << op_name << "] failed"; - return; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.h b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.h deleted file mode 100644 index 74e667856e..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ - -#include -#include -#include -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc deleted file mode 100644 index 2213f176cc..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc +++ /dev/null @@ -1,154 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/aicpu/aicpu_kernel_mod.h" - -#include -#include -#include -#include - -#include "runtime/mem.h" -#include "runtime/rt.h" -#include "kernel/aicpu/aicpu_kernel_build.h" -#include "utils/convert_utils.h" -#include "kernel/aicpu/aicpu_util.h" - -using AicpuTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -constexpr auto AICPU_OPS_SO_NAME = "libaicpu_kernels.so"; - -AicpuOpKernelMod::AicpuOpKernelMod() : anf_node_(nullptr) {} - -AicpuOpKernelMod::~AicpuOpKernelMod() { - args_.clear(); - inputList_.clear(); - outputList_.clear(); - anf_node_ = nullptr; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); -} - -void AicpuOpKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } -const std::vector &AicpuOpKernelMod::GetInputSizeList() const { return input_size_list_; } -void AicpuOpKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } -const std::vector &AicpuOpKernelMod::GetOutputSizeList() const { return output_size_list_; } -void AicpuOpKernelMod::SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } -const std::vector &AicpuOpKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } -void AicpuOpKernelMod::SetInputList(const std::vector &inputList) { inputList_ = inputList; } -void AicpuOpKernelMod::SetOutputList(const std::vector &outputList) { outputList_ = outputList; } -void AicpuOpKernelMod::SetNodeDef(const std::string &nodeDef) { (void)node_def_str_.assign(nodeDef); } -void AicpuOpKernelMod::SetNodeName(const std::string &node_name) { node_name_ = node_name; } -void AicpuOpKernelMod::SetAnfNode(const mindspore::AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - anf_node_ = anf_node; -} - -void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector &inputs, - const std::vector &outputs) { - MS_LOG(INFO) << "CreateCpuKernelInfoOffline start"; - - node_so_ = AICPU_OPS_SO_NAME; - - // InputOutputAddr - vector io_addrs; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(io_addrs), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(io_addrs), - [](const AddressPtr &output) -> void * { return output->addr; }); - - auto io_addrs_num = io_addrs.size(); - // calculate paramLen: AicpuParamHead.len + ioAddrsSize + notifyId.len + customizedAttr.len - auto param_len = sizeof(AicpuParamHead); - - // get input and output addrs size, no need to check overflow - auto io_addrs_size = io_addrs_num * sizeof(uint64_t); - // refresh paramLen, no need to check overflow - param_len += io_addrs_size; - - auto node_def_len = node_def_str_.length(); - param_len += node_def_len; - - // Create taskArgs: AicpuParamHead + ioAddrs + notifyId + customizedAttr - AicpuParamHead paramHead = {static_cast(param_len), static_cast(io_addrs_num)}; - args_.clear(); - (void)args_.append(reinterpret_cast(¶mHead), sizeof(AicpuParamHead)); - // TaskArgs append ioAddrs - if (io_addrs_size != 0) { - (void)args_.append(reinterpret_cast(io_addrs.data()), io_addrs_size); - } - - // When it's aicpu customized ops, taskArgs should append customized attr - if (node_def_len != 0) { - (void)args_.append(reinterpret_cast(node_def_str_.data()), node_def_len); - } - - MS_LOG(INFO) << "CreateCpuKernelInfoOffline end"; -} - -bool AicpuOpKernelMod::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - if (stream_ptr == nullptr) { - MS_LOG(ERROR) << "stream_ptr should not be nullptr."; - return false; - } - - CreateCpuKernelInfo(inputs, outputs); - if (node_name_ == kTopK) { - node_name_ = kTopKV2; - } - MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_ - << ", args_size:" << args_.length(); - if (rtCpuKernelLaunch(reinterpret_cast(node_so_.c_str()), - reinterpret_cast(node_name_.c_str()), 1, - reinterpret_cast(args_.data()), static_cast(args_.length()), nullptr, - stream_ptr) != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Aicpu op launch failed!"; - - return false; - } - return true; -} - -std::vector AicpuOpKernelMod::GenTask(const std::vector &inputs, - const std::vector &, - const std::vector &outputs, uint32_t stream_id) { - MS_LOG(INFO) << "AicpuOpKernelMod GenTask start"; - - stream_id_ = stream_id; - node_so_ = AICPU_OPS_SO_NAME; - std::vector input_data_addrs; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), - [](const AddressPtr &input) -> void * { return input->addr; }); - - std::vector output_data_addrs; - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), - [](const AddressPtr &output) -> void * { return output->addr; }); - - if (node_name_ == kTopK) { - node_name_ = kTopKV2; - } - AicpuTaskInfoPtr task_info_ptr = make_shared( - stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs); - - MS_LOG(INFO) << "AicpuOpKernelMod GenTask end"; - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.h b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.h deleted file mode 100644 index 3ee9bd2a15..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/aicpu/aicpu_util.h" -namespace mindspore { -namespace kernel { -class AicpuOpKernelMod : public AscendKernelMod { - public: - AicpuOpKernelMod(); - ~AicpuOpKernelMod() override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - void SetInputList(const std::vector &inputList); - void SetOutputList(const std::vector &outputList); - void SetAnfNode(const AnfNodePtr &anf_node); - void SetNodeDef(const std::string &nodeDef); - void SetNodeName(const std::string &node_name); - - /** - * @brief Build AICPU Engine kernel structure, and allocate device memory for offline task generate - * @return SUCCESS - * @return FAIL - * - */ - void CreateCpuKernelInfo(const std::vector &inputs, const std::vector &outputs); - - void SetInputSizeList(const std::vector &size_list); - void SetOutputSizeList(const std::vector &size_list); - void SetWorkspaceSizeList(const std::vector &size_list); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - - private: - std::string args_; - std::string node_def_str_; - std::string node_name_; - std::string node_so_; - std::vector inputList_; - std::vector outputList_; - AnfNodePtr anf_node_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -using AicpuOpKernelModPtr = std::shared_ptr; -using AicputOpKernelModPtrList = std::vector; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc deleted file mode 100644 index a617f56f8f..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/aicpu/aicpu_util.h" -#include -#include -#include "proto/types.pb.h" -#include "runtime/mem.h" -#include "runtime/rt.h" -#include "utils/convert_utils.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -static std::map MS_PROTO_DATA_TYPE_MAP = { - {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN}, - {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL}, - {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32}, - {mindspore::TypeId::kNumberTypeInt8, mindspore::DataType::MS_INT8}, - {mindspore::TypeId::kNumberTypeInt16, mindspore::DataType::MS_INT16}, - {mindspore::TypeId::kNumberTypeInt32, mindspore::DataType::MS_INT32}, - {mindspore::TypeId::kNumberTypeInt64, mindspore::DataType::MS_INT64}, - {mindspore::TypeId::kNumberTypeUInt, mindspore::DataType::MS_UINT32}, - {mindspore::TypeId::kNumberTypeUInt8, mindspore::DataType::MS_UINT8}, - {mindspore::TypeId::kNumberTypeUInt16, mindspore::DataType::MS_UINT16}, - {mindspore::TypeId::kNumberTypeUInt32, mindspore::DataType::MS_UINT32}, - {mindspore::TypeId::kNumberTypeUInt64, mindspore::DataType::MS_UINT64}, - {mindspore::TypeId::kNumberTypeFloat16, mindspore::DataType::MS_FLOAT16}, - {mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32}, - {mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32}, - {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64}, -}; - -int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) { - auto iter = MS_PROTO_DATA_TYPE_MAP.find(ms_type); - if (iter != MS_PROTO_DATA_TYPE_MAP.end()) { - return MS_PROTO_DATA_TYPE_MAP[ms_type]; - } else { - MS_LOG(ERROR) << "UnSupported ms_type value" << static_cast(ms_type); - return -1; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.cc b/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.cc deleted file mode 100644 index 018fbe4f2a..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.cc +++ /dev/null @@ -1,180 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/akg/akg_kernel_attrs_process.h" - -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace kernel { -void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // The x and output are akg op input and output param. - std::vector input_names = {"x"}; - std::vector output_names = {"output"}; - AnfAlgo::SetNodeAttr("input_names", MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr("output_names", MakeValue(output_names), anf_node); - - TypeId dst_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); - std::string dst_type; - if (dst_type_id == kFloat32->type_id()) { - dst_type = "float32"; - } else if (dst_type_id == kFloat16->type_id()) { - dst_type = "float16"; - } - AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); -} - -void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector input_names = {"x"}; - std::vector output_names = {"output"}; - AnfAlgo::SetNodeAttr("input_names", MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr("output_names", MakeValue(output_names), anf_node); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(anf_node, 0); - if (origin_shape.size() != kShape4dDims) { - MS_LOG(EXCEPTION) << "The dim of origin_shape is not equal to 4, but it's dim is " << origin_shape.size() << "."; - } - std::vector shape_transform; - (void)std::transform(origin_shape.begin(), origin_shape.end(), std::back_inserter(shape_transform), - [](const int &origin_shape) { return static_cast(origin_shape); }); - AnfAlgo::SetNodeAttr("shape4d", MakeValue(shape_transform), anf_node); - AnfAlgo::SetNodeAttr("output_format", MakeValue(kOpFormat_NCHW), anf_node); - - TypeId dst_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); - std::string dst_type; - if (dst_type_id == kFloat32->type_id()) { - dst_type = "float32"; - } else if (dst_type_id == kFloat16->type_id()) { - dst_type = "float16"; - } - AnfAlgo::SetNodeAttr("dstType", MakeValue(dst_type), anf_node); -} - -void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // The x and output are akg op input and output param. - std::vector input_names = {"x", "dst_type"}; - std::vector output_names = {"output"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); - - std::string dst_type; - TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); - if (output_type == kFloat32->type_id()) { - dst_type = "float32"; - } else if (output_type == kFloat16->type_id()) { - dst_type = "float16"; - } else if (output_type == kInt32->type_id()) { - dst_type = "int32"; - } else { - MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString(); - } - AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); -} - -void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector input_names{"dy", "data", "mean"}; - std::vector output_names{"dgamma_red_hw", "dbeta_red_hw", "data_minus_mean"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); -} - -void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node) { - const size_t kBNGrad2InputSize = 5; - MS_EXCEPTION_IF_NULL(anf_node); - std::vector input_names{"dgamma_red_hw", "dbeta_red_hw", "variance", "gamma"}; - std::vector output_names{"bn_scale", "bn_bias", "rs", "dgamma_dx", "dbeta_dx"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() < kBNGrad2InputSize) { - MS_LOG(EXCEPTION) << "The inputs size of BNGrad2 is less then " << kBNGrad2InputSize; - } - auto input1 = cnode->input(1); - MS_EXCEPTION_IF_NULL(input1); - auto tuple_getitem = input1->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - if (tuple_getitem->inputs().size() < kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The inputs size of tuple_getitem is less then " << kTupleGetItemInputSize; - } - auto bn_grad1 = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); - std::vector data_shape = AnfAlgo::GetInputDeviceShape(bn_grad1, 0); - AnfAlgo::SetNodeAttr(kAttrDataShape, MakeValue(opt::Convert2Int(data_shape)), anf_node); -} - -void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector input_names{"dy", "rs", "dgamma_dx", "dbeta_dx", "data_minus_mean"}; - std::vector output_names{"dx"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); -} - -void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // Set attr for fused_bn1 - std::vector fused_bn1_input_names{"data"}; - std::vector fused_bn1_output_names{"mean", "var_part"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn1_input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn1_output_names), anf_node); -} - -void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // Set attr for fused_bn2 - std::vector fused_bn2_input_names{"mean", "var_part", "running_mean", "running_var"}; - std::vector fused_bn2_output_names{"variance", "running_mean", "running_variance"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn2_input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn2_output_names), anf_node); -} - -void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // Set attr for fused_bn3 - std::vector fused_bn3_input_names{"data", "mean", "variance", "gamma", "beta"}; - std::vector fused_bn3_output_names{"y"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn3_input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn3_output_names), anf_node); -} - -void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector conv_bn1_output_names{"data", "var_part", "mean"}; - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(conv_bn1_output_names), anf_node); -} - -void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector bn2_add_relu_input_names{"data", "var_part", "mean", "other_branch_data", - "gamma", "beta", "running_mean", "running_var"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_add_relu_input_names), anf_node); - std::vector bn2_add_relu_output_names{"output", "running_mean", "running_variance", "save_inv_variance"}; - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_add_relu_output_names), anf_node); -} - -void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector bn2_input_names{"data", "var_part", "mean", "gamma", "beta", "running_mean", "running_var"}; - std::vector bn2_output_names{"y", "running_mean", "running_variance", "save_inv_variance"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.h b/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.h deleted file mode 100644 index 9d15d4f9e9..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H -#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H - -#include -#include -#include -#include -#include "ir/anf.h" -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace kernel { -void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node); -void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node); -void SetAkgAttrsForCast(const AnfNodePtr &anf_node); -void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node); -void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node); -void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node); -void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node); -void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node); -void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node); -void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node); -void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node); -void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node); - -const std::unordered_map> kAkgKernelAttrsProcessMap = { - {kFour2FiveOpName, SetAkgAttrsForFour2Five}, - {kFive2FourOpName, SetAkgAttrsForFive2Four}, - {"Cast", SetAkgAttrsForCast}, - {kBNGrad1OpName, SetAkgAttrsForBNGrad1}, - {kBNGrad2OpName, SetAkgAttrsForBNGrad2}, - {kBNGrad3OpName, SetAkgAttrsForBNGrad3}, - {kFusedBN1OpName, SetAkgAttrsForFusedBN1}, - {kFusedBN2OpName, SetAkgAttrsForFusedBN2}, - {kFusedBN3OpName, SetAkgAttrsForFusedBN3}, - {kConvBN1OpName, SetAkgAttrsForConvBN1}, - {kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu}, - {kBN2ReLUOpName, SetAkgAttrsForBN2Relu}, -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc b/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc deleted file mode 100644 index 0e8d93d47f..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc +++ /dev/null @@ -1,623 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/akg/akg_kernel_build.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "utils/convert_utils.h" -#include "utils/any.h" -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/akg/akg_kernel_attrs_process.h" - -namespace mindspore { -namespace kernel { -constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; -constexpr int32_t ARGS_SIZE = 1; -constexpr auto kCompileWithJsonFunc = "compilewithjson"; - -// json key -constexpr auto kOpDesc = "op_desc"; -constexpr auto kInputDesc = "input_desc"; -constexpr auto kShape = "shape"; -constexpr auto kDataType = "data_type"; -constexpr auto kOutputDesc = "output_desc"; -constexpr auto kName = "name"; -constexpr auto kTensorName = "tensor_name"; -constexpr auto kValue = "value"; -constexpr auto KDynInputSizes = "dyn_input_sizes"; -constexpr auto KInputNames = "input_names"; -constexpr auto KInput = "input"; -constexpr auto KDtype = "dtype"; -namespace { -template -std::string Vector2Str(const std::vector &inputs) { - if (!inputs.empty()) { - std::ostringstream oss; - (void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator(oss, ", ")); - oss << inputs.back(); - return oss.str(); - } - return ""; -} -} // namespace - -std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) { - char *pChar = nullptr; - std::string str_res; - if (PyObj == nullptr) { - MS_LOG(ERROR) << "Input parameter is nullptr."; - return str_res; - } - PyObject *strArgs = PyObject_Str(PyObj); - if (strArgs != nullptr) { - (void)PyArg_Parse(strArgs, "s", &pChar); - } - if (pChar == nullptr) { - MS_LOG(ERROR) << "pChar is nullptr."; - return str_res; - } - str_res = pChar; - return str_res; -} - -std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, - const std::pair &position) { - if (node_json.count(tag) == 0) { - MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "]."; - return ""; - } - - auto const &tag_desc = node_json[tag]; - nlohmann::json first_index; - if (tag == kOutputDesc) { - first_index = tag_desc; - } else if (!tag_desc.is_array() || tag_desc.size() <= position.first) { - MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "]."; - return ""; - } else { - first_index = tag_desc[position.first]; - } - - if (!first_index.is_array() || first_index.size() <= position.second) { - MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "]."; - return ""; - } - auto const &second_index = first_index[position.second]; - if (second_index.count(kTensorName) == 0) { - MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "]."; - return ""; - } - - return second_index[kTensorName]; -} - -void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair &position, - nlohmann::json *const node_json) { - MS_EXCEPTION_IF_NULL(node_json); - if (node_json->count(tag) == 0) { - MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "]."; - return; - } - - nlohmann::json *tag_desc = &((*node_json)[tag]); - nlohmann::json *first_index; - if (tag == kOutputDesc) { - first_index = tag_desc; - } else if (!tag_desc->is_array() || tag_desc->size() <= position.first) { - MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << "]."; - return; - } else { - first_index = &((*tag_desc)[position.first]); - } - - if (!first_index->is_array() || first_index->size() <= position.second) { - MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "]."; - return; - } - nlohmann::json *second_index = &((*first_index)[position.second]); - if (second_index->count(kTensorName) == 0) { - MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "]."; - return; - } - (*second_index)[kTensorName] = new_name; - return; -} - -int AkgKernelBuild::op_cnt_ = 0; -std::mutex AkgKernelBuild::op_cnt_mtx_; - -std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string device; - switch (AnfAlgo::GetProcessor(anf_node)) { - case Processor::AICORE: - device = kProcessorAiCore; - break; - - case Processor::AICPU: - device = kProcessorAiCpu; - break; - - case Processor::CUDA: - device = kProcessorCuda; - break; - - default: - MS_LOG(ERROR) << "Unknown processor type."; - break; - } - - return device; -} - -bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, - std::vector *const output_size) { - if (input_size == nullptr || output_size == nullptr) { - MS_LOG(ERROR) << "input size or output size is nullptr"; - return false; - } - input_size->clear(); - output_size->clear(); - - for (size_t i = 0; i < node_json[kInputDesc].size(); i++) { - for (size_t m = 0; m < node_json[kInputDesc][i].size(); m++) { - std::string dtype = node_json[kInputDesc][i][m][kDataType]; - size_t nbyte = GetDtypeNbyte(dtype); - size_t size_i = std::accumulate(node_json[kInputDesc][i][m][kShape].begin(), - node_json[kInputDesc][i][m][kShape].end(), nbyte, std::multiplies()); - input_size->push_back(size_i); - } - } - - for (size_t i = 0; i < node_json[kOutputDesc].size(); i++) { - std::string dtype = node_json[kOutputDesc][i][kDataType]; - size_t nbyte = GetDtypeNbyte(dtype); - size_t size_i = std::accumulate(node_json[kOutputDesc][i][kShape].begin(), node_json[kOutputDesc][i][kShape].end(), - nbyte, std::multiplies()); - output_size->push_back(size_i); - } - - return true; -} - -int AkgKernelBuild::GetOpCntInc() { - op_cnt_mtx_.lock(); - int cnt = op_cnt_++; - op_cnt_mtx_.unlock(); - return cnt; -} - -bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(inputs_json); - - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); - if (op_info == nullptr) { - MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op_info is nullptr"; - return false; - } - - std::vector> inputs_ptr = op_info->inputs_ptr(); - if (inputs_ptr.empty()) { - MS_LOG(INFO) << "Apply kernel [" << op_name << "] regist info has no input info"; - return true; - } - auto op_info_input_num = inputs_ptr.size(); - - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. - std::vector dyn_input_sizes; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - size_t real_input_index = 0; - std::vector input_list; - for (size_t i = 0; i < op_info_input_num; i++) { - size_t input_tensor_num; - std::shared_ptr input_ptr = inputs_ptr[i]; - std::string op_input_name; - if (input_ptr == nullptr) { - MS_LOG(ERROR) << "Apply kernel [" << op_name << "] regist input[" << i << "] is nullptr"; - return false; - } - - op_input_name = input_ptr->name(); - if (dyn_input_sizes.empty()) { - input_tensor_num = 1; - } else { - input_tensor_num = IntToSize(dyn_input_sizes[i]); - } - - input_list.clear(); - for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { - // dtype : float16 - auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index); - std::string dtype = TypeId2String(type_id); - if (dtype.empty()) { - MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. "; - return false; - } - nlohmann::json input_desc_json; - input_desc_json[kDataType] = dtype; - input_desc_json[kName] = op_input_name; - input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); - auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); - if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && - GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { - MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2) - << "] as const tensor, shape: [" << Vector2Str(input_shape) - << "], value: " << input_desc_json[kValue]; - - input_shape.clear(); - } - if (input_shape.empty()) { - input_shape.push_back(1); - } - input_desc_json[kShape] = input_shape; - input_list.emplace_back(input_desc_json); - real_input_index++; - } - inputs_json->emplace_back(input_list); - } - return true; -} - -bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(outputs_json); - size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); - auto outputs = op_info_ptr->outputs_ptr(); - for (size_t i = 0; i < output_tensor_num; i++) { - nlohmann::json output_json; - auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i); - std::string dtype = TypeId2String(type_id); - if (dtype.empty()) { - MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. "; - return false; - } - - std::string output_name = outputs[i]->name(); - output_json[kDataType] = dtype; - output_json[kName] = output_name; - output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc()); - output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i); - outputs_json->push_back(output_json); - } - return true; -} - -void GetJson(const AnfNodePtr &anf_node, const std::vector &dyn_input_sizes, - const std::shared_ptr &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_attr); - MS_EXCEPTION_IF_NULL(attr_json); - std::string type = op_attr->type(); - if (type == "int") { - (*attr_json)[kValue] = GetValue(attr_value); - } else if (type == "str") { - (*attr_json)[kValue] = GetValue(attr_value); - } else if (type == "bool") { - (*attr_json)[kValue] = GetValue(attr_value); - } else if (type == "float") { - (*attr_json)[kValue] = GetValue(attr_value); - } else if (type == "listInt") { - (*attr_json)[kValue] = GetValue>(attr_value); - } else if (type == "listStr") { - std::vector data_format; - if (op_attr->name() == kArgDataformat) { - size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); - for (size_t format_i = 0; format_i < tensor_args_num; format_i++) { - auto input_format = AnfAlgo::GetInputFormat(anf_node, format_i); - data_format.push_back(input_format); - } - } else { - data_format = GetValue>(attr_value); - } - (*attr_json)[kValue] = data_format; - } else { - MS_LOG(WARNING) << "attr type:" << type; - } -} - -bool AkgKernelBuild::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, - const std::shared_ptr &op_info, nlohmann::json *const attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(attrs_json); - MS_EXCEPTION_IF_NULL(op_info); - std::vector> attrs = op_info->attrs_ptr(); - if (attrs.empty()) { - MS_LOG(INFO) << "Apply kernel [" << op_name << "] op info attrs is empty"; - return true; - } - std::vector> inputs = op_info->inputs_ptr(); - - std::vector dyn_input_sizes; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - if (inputs.empty()) { - MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op info inputs is empty"; - return false; - } - - // create input name list for atch "x_shape" in att with "x" in primitive. - std::map op_info_shape_name; - for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) { - std::string input_name = inputs[op_info_input_i]->name(); - std::string x_shape_name = input_name + "_shape"; - (void)op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name)); - } - - for (const auto &op_attr : attrs) { - nlohmann::json attr_json; - ValuePtr attr_value = primitive->GetAttr(op_attr->name()); - if (attr_value == nullptr && op_attr->name() != kArgDataformat) { - if (op_attr->param_type() == "required") { - // match "x_shape" in att with "x" in primitive. - std::string attr_name = op_attr->name(); - auto find_item = std::find_if( - op_info_shape_name.begin(), op_info_shape_name.end(), - [attr_name](const std::map::value_type item) { return item.second == attr_name; }); - if (find_item != op_info_shape_name.end()) { - if (!dyn_input_sizes.empty()) { - if (find_item->first >= dyn_input_sizes.size() - 1) { - MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first - << " is out of range:" << dyn_input_sizes.size() - 1 << "."; - return false; - } - size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0)); - for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) { - attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx); - attr_json[kName] = op_attr->name(); - attrs_json->push_back(attr_json); - tensor_idx++; - } - } else { - attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first); - attr_json[kName] = op_attr->name(); - attrs_json->push_back(attr_json); - } - } else { - MS_LOG(ERROR) << "op [" << op_name << "] should have attr :" << op_attr->name(); - return false; - } - } - continue; - } - - GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value); - - attr_json[kName] = op_attr->name(); - attrs_json->push_back(attr_json); - } - return true; -} - -bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, - nlohmann::json *const node_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(node_json); - int op_cnt = GetOpCntInc(); - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); - MS_EXCEPTION_IF_NULL(op_info_ptr); - - // get basic params from currentNodeOpDesc - (*node_json)[kName] = op_name; - (*node_json)["impl_path"] = op_info_ptr->impl_path(); - (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); - (*node_json)["composite"] = false; - - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - ValuePtr input_names_v = primitive->GetAttr(KInputNames); - if (input_names_v == nullptr) { - MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; - return false; - } - std::vector prim_input_names = GetValue>(input_names_v); - std::string inputs_name; - for (const auto &prim_input_name : prim_input_names) { - (void)inputs_name.append("_input_").append(prim_input_name).append("_"); - } - - // input desc - nlohmann::json inputs_json; - if (!CreateInputDescJson(anf_node, &inputs_json)) { - MS_LOG(ERROR) << "Create input desc json failed, op[" << op_name << "]."; - return false; - } - (*node_json)[kInputDesc] = inputs_json; - MS_LOG(INFO) << "Akg create input desc json success."; - std::string inputs_shape = "inputs_shape_"; - for (auto &i : inputs_json) { - for (auto &m : i) { - std::string data_type = m[kDataType]; - (void)inputs_shape.append("_").append(data_type).append("_"); - for (auto &j : m[kShape]) { - size_t n = j; - (void)inputs_shape.append(std::to_string(n)).append("_"); - } - } - } - - // output desc - nlohmann::json outputs_json; - if (!CreateOutputDescJson(anf_node, &outputs_json)) { - MS_LOG(ERROR) << "Create output desc json failed, op[" << op_name << "]."; - return false; - } - - (*node_json)[kOutputDesc] = outputs_json; - MS_LOG(INFO) << "Akg create output desc json success."; - std::string outputs_shape = "outputs_shape_"; - for (auto &i : outputs_json) { - std::string data_type = i[kDataType]; - (void)outputs_shape.append("_").append(data_type).append("_"); - for (auto &j : i[kShape]) { - size_t m = j; - (void)outputs_shape.append(std::to_string(m)).append("_"); - } - } - - // attribute desc - nlohmann::json attrs_json; - if (!CreateAttrDescJson(anf_node, op_name, op_info_ptr, &attrs_json)) { - MS_LOG(ERROR) << "Create attr desc json failed, op[" << op_name << "]."; - return false; - } - (*node_json)["attr"] = attrs_json; - std::string json_str = node_json->dump(); - size_t hash_id = std::hash()(json_str); - json_name_ = op_name + "_"; - (void)json_name_.append(std::to_string(hash_id)); - MS_LOG(INFO) << "full scope name is : " << anf_node->fullname_with_scope() << ", json info name is : " << json_name_; - json_info_ = json_str; - (*node_json)["id"] = op_cnt; - (*node_json)["op"] = json_name_; - MS_LOG(INFO) << "Akg create node desc json success."; - return true; -} - -KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - auto processor = AkgKernelBuild::GetProcessor(anf_node); - auto cached_kernel_pack = SearchCache(json_name_, processor); - if (cached_kernel_pack != nullptr) { - MS_LOG(INFO) << "Use cached kernel, json_name_[" << json_name_ << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - return cached_kernel_pack; - } - - PyObject *pModule = nullptr; - PyObject *pFunc = nullptr; - PyObject *pArg = nullptr; - PyObject *pRes = nullptr; - - pModule = PyImport_ImportModule(kAkgModule); - if (pModule == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kAkgModule << "]."; - return nullptr; - } - - pFunc = PyObject_GetAttrString(pModule, kCompileWithJsonFunc); - pArg = PyTuple_New(ARGS_SIZE); - (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", node_json.c_str())); - - (void)alarm(AUTODIFF_COMPILE_OVERTIME); - pRes = PyEval_CallObject(pFunc, pArg); - (void)alarm(0); - if (pRes == nullptr) { - MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(pArg) << ")."; - return nullptr; - } - if (PyObject_IsTrue(pRes) != 1) { - MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(pArg) << ")."; - return nullptr; - } - - auto new_kernel_pack = InsertCache(json_name_, processor); - kernel::SaveJsonInfo(json_name_, json_info_); - if (new_kernel_pack == nullptr) { - MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name_ << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - return nullptr; - } - return new_kernel_pack; -} - -KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vector *const input_size, - std::vector *const output_size) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto it = kAkgKernelAttrsProcessMap.find(op_name); - if (it != kAkgKernelAttrsProcessMap.end()) { - it->second(anf_node); - } - MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; - nlohmann::json node_json; - if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { - MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; - } - - std::string json_str = node_json.dump(); - auto kernel_pack = OpBuild(json_str, anf_node); - if (kernel_pack == nullptr) { - MS_LOG(ERROR) << "Akg build failed op[" << op_name << "], json:" << json_str; - return nullptr; - } - - if (!GetIOSize(node_json, input_size, output_size)) { - MS_LOG(ERROR) << "Cal mem size failed."; - return nullptr; - } - MS_LOG(INFO) << "Akg compile success, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) - << "]"; - return kernel_pack; -} - -size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(anf_node); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" - << cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]"; - } - - auto input_node = cnode->input(input_idx + 1); - if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) { - size_t index = input_tensor_idx_.size(); - input_tensor_idx_[input_node] = index; - } - - return input_tensor_idx_[input_node]; -} - -size_t AkgKernelBuild::GetOutputTensorIdxInc() { - size_t idx = output_tensor_idx_++; - return idx; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_build.h b/mindspore/ccsrc/kernel/akg/akg_kernel_build.h deleted file mode 100644 index 15fa03f45b..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_build.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ -#include -#include -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "ir/dtype.h" -#include -#include "kernel/common_utils.h" -#include "kernel/oplib/oplib.h" - -namespace mindspore { -namespace kernel { -class AkgKernelBuild { - public: - AkgKernelBuild() { - input_tensor_idx_ = {}; - output_tensor_idx_ = 0; - } - ~AkgKernelBuild() = default; - - KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector *const input_size, - std::vector *const output_size); - static std::string GetProcessor(const AnfNodePtr &anf_node); - static std::string PyObjectToStr(PyObject *const PyObj); - - protected: - bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); - bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); - bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, - const std::shared_ptr &op_info, nlohmann::json *const attrs_json); - KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node); - int GetOpCntInc(); - size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx); - size_t GetOutputTensorIdxInc(); - bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, - nlohmann::json *const node_json); - - static int op_cnt_; - // lock for variable fusionOpCnt in singleton mode - static std::mutex op_cnt_mtx_; - std::string json_name_; - std::string json_info_; - std::unordered_map input_tensor_idx_; - size_t output_tensor_idx_; -}; - -bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, - std::vector *const output_size); -void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair &position, - nlohmann::json *const node_json); -std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, - const std::pair &position); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.cc b/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.cc deleted file mode 100644 index 3515add1e0..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/akg/akg_kernel_metadata.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -void AkgMetadataInfo(const CNodePtr &kernel_node, - std::vector> *const kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - for (size_t i = 0; i < support_devices.size(); i++) { - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); - if (op_info_ptr == nullptr) { - continue; - } - - if (!ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) { - MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed."; - } else { - MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "]."; - break; - } - } - - if (kernel_info_list->empty()) { - MS_LOG(WARNING) << "Akg dose not has metadata of op[" << op_name << "]."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.h b/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.h deleted file mode 100644 index 5e329f0080..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ - -#include -#include -#include -#include -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ diff --git a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.cc b/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.cc deleted file mode 100644 index 7200a91ac0..0000000000 --- a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.cc +++ /dev/null @@ -1,422 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/akg/ascend/akg_ascend_kernel_build.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/dtype.h" -#include "ir/func_graph.h" -#include "kernel/kernel.h" -#include "kernel/common_utils.h" -#include "kernel/tbe/tbe_utils.h" -#include "kernel/akg/ascend/akg_ascend_kernel_mod.h" -#include "kernel/akg/akg_kernel_attrs_process.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -constexpr int32_t PARALLEL_ARGS_SIZE = 3; -constexpr int32_t PROCESS_NUM = 16; -constexpr int32_t TIME_OUT = 300; - -constexpr auto kOpDesc = "op_desc"; -constexpr auto kShape = "shape"; -constexpr auto kDataType = "data_type"; -constexpr auto kInputDesc = "input_desc"; -constexpr auto kOutputDesc = "output_desc"; -constexpr auto kTensorName = "tensor_name"; -constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel"; -constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"; -namespace { -void UpdateTensorNameInJson(const std::vector &anf_nodes, - std::map *node_json_map) { - for (auto const &anf_node : anf_nodes) { - std::vector dyn_input_sizes; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - bool is_dynamic_input = !dyn_input_sizes.empty(); - size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); - size_t real_input_index = 0; - for (size_t i = 0; i < input_num; ++i) { - size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1; - for (size_t j = 0; j < input_tensor_num; ++j) { - auto tmp_input = GetKernelInput(anf_node, real_input_index); - std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kInputDesc, std::make_pair(i, j)); - if (node_json_map->find(tmp_input.first) != node_json_map->end()) { - std::string new_tensor_name = - GetTensorName((*node_json_map)[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second)); - SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node])); - MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of [" - << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output [" - << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "]."; - } else { - MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of [" - << anf_node->fullname_with_scope() << "] is out input."; - } - real_input_index++; - } - } - } -} - -nlohmann::json GetInputsJson(const std::vector &anf_nodes, const std::vector &input_list, - std::map *node_json_map) { - nlohmann::json inputs_json; - auto input_index = GetInputIndex(anf_nodes, input_list); - for (size_t i = 0; i < input_index.size(); ++i) { - auto tmp_input = input_index[i]; - auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first); - std::string dtype = TypeId2String(type_id); - nlohmann::json input_desc_json; - input_desc_json[kTensorName] = GetTensorName((*node_json_map)[tmp_input.first], kInputDesc, tmp_input.second); - input_desc_json[kDataType] = dtype; - input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first); - inputs_json.emplace_back(std::vector{input_desc_json}); - } - - return inputs_json; -} - -nlohmann::json GetOutputsJson(const std::vector &anf_nodes, const std::vector &input_list, - const std::vector &output_list, const nlohmann::json &inputs_json, - std::map *node_json_map) { - nlohmann::json outputs_json; - auto output_index = GetOutputIndex(anf_nodes, input_list, output_list); - for (size_t i = 0; i < output_index.size(); ++i) { - auto tmp_output = output_index[i]; - bool found = false; - nlohmann::json output_desc_json; - for (size_t input_i = 0; input_i < input_list.size(); ++input_i) { - if (tmp_output.first == input_list[input_i]) { - output_desc_json = inputs_json[input_i][0]; - found = true; - break; - } - } - if (!found) { - auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second); - std::string dtype = TypeId2String(type_id); - output_desc_json[kTensorName] = - GetTensorName((*node_json_map)[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second)); - output_desc_json[kDataType] = dtype; - auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second); - if (output_shape.empty()) { - output_shape.push_back(1); - } - output_desc_json[kShape] = output_shape; - } - outputs_json.emplace_back(output_desc_json); - } - - return outputs_json; -} - -std::pair, std::vector>> PreProcessJsonForBuild( - const std::vector> &build_args) { - // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess. - std::vector jsons; - std::vector> repeat_nodes; - std::unordered_set json_name_set; - for (const auto &[builder, anf_node] : build_args) { - MS_EXCEPTION_IF_NULL(anf_node); - auto json_name = builder.json_name(); - MS_LOG(DEBUG) << "Akg start compile op: " << json_name; - auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); - if (cached_kernel_pack != nullptr) { - MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - auto kernel_mod_ptr = std::make_shared(cached_kernel_pack); - kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); - kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); - AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); - continue; - } - - if (json_name_set.count(json_name) != 0) { - repeat_nodes.push_back({builder, anf_node}); - continue; - } - json_name_set.insert(json_name); - auto node_json = builder.kernel_json(); - kernel::SaveJsonInfo(json_name, node_json); - jsons.push_back(node_json); - } - - return std::make_pair(jsons, repeat_nodes); -} - -bool PostProcessAfterCompile(const std::vector> &build_args, - const std::vector> &repeat_nodes) { - for (const auto &[builder, anf_node] : build_args) { - auto json_name = builder.json_name(); - auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); - if (new_kernel_pack == nullptr) { - MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - return false; - } - auto kernel_mod_ptr = std::make_shared(new_kernel_pack); - kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); - kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); - AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); - MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!"; - } - - for (const auto &[builder, anf_node] : repeat_nodes) { - auto node_json = builder.kernel_json(); - auto json_name = builder.json_name(); - auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); - if (cached_kernel_pack == nullptr) { - return false; - } - MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - auto kernel_mod_ptr = std::make_shared(cached_kernel_pack); - kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); - kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); - AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); - } - - return true; -} -} // namespace - -bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; - auto it = kAkgKernelAttrsProcessMap.find(op_name); - if (it != kAkgKernelAttrsProcessMap.end()) { - it->second(anf_node); - } - MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; - nlohmann::json node_json; - if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { - MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; - } - - kernel_json_ = node_json.dump(); - - if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) { - MS_LOG(ERROR) << "Cal mem size failed."; - return false; - } - - return true; -} - -bool AkgAscendKernelBuilder::GenJsonAndPreprocess4Fused(const std::vector &anf_nodes, - std::map *node_json_map) { - for (auto const &anf_node : anf_nodes) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (!AnfAlgo::IsRealKernel(anf_node)) { - MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "]."; - return false; - } - auto it = kAkgKernelAttrsProcessMap.find(op_name); - if (it != kAkgKernelAttrsProcessMap.end()) { - it->second(anf_node); - } - - nlohmann::json node_json; - if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { - MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed."; - return false; - } - // No need for composite op. - node_json.erase("id"); - node_json.erase("op"); - node_json.erase("composite"); - - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - - if (primitive->GetAttr("fusion") != nullptr) { - node_json["fusion"] = primitive->GetAttr("fusion")->ToString(); - } - - (*node_json_map)[anf_node] = node_json; - } - return true; -} - -bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector &anf_nodes, - const std::vector &input_list, - const std::vector &output_list) { - if (anf_nodes.empty() || input_list.empty()) { - MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size() - << "]."; - return false; - } - MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list [" - << input_list.size() << "]."; - - std::map node_json_map; - if (!GenJsonAndPreprocess4Fused(anf_nodes, &node_json_map)) { - return false; - } - - UpdateTensorNameInJson(anf_nodes, &node_json_map); - - nlohmann::json fused_node_json; - std::vector node_json_desc; - std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc), - [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; }); - fused_node_json[kOpDesc] = node_json_desc; - fused_node_json[kInputDesc] = GetInputsJson(anf_nodes, input_list, &node_json_map); - fused_node_json[kOutputDesc] = - GetOutputsJson(anf_nodes, input_list, output_list, fused_node_json[kInputDesc], &node_json_map); - - size_t hash_id = std::hash()(fused_node_json.dump()); - json_name_ = "Fused_"; - auto fg = anf_nodes[0]->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (attr_val != nullptr) { - auto fg_attr = GetValue(attr_val); - (void)json_name_.append(fg_attr).append("_"); - } - (void)json_name_.append(std::to_string(hash_id)); - fused_node_json["composite_graph"] = fg->ToString(); - fused_node_json["op"] = json_name_; - fused_node_json["platform"] = "AKG"; - fused_node_json["process"] = "aicore"; - fused_node_json["composite"] = true; - - kernel_json_ = fused_node_json.dump(); - - if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) { - MS_LOG(ERROR) << "Cal mem size failed."; - return false; - } - - return true; -} - -void GenParallelCompileFuncArgs(const std::vector &kernel_jsons, PyObject **p_args) { - MS_EXCEPTION_IF_NULL(p_args); - *p_args = PyTuple_New(PARALLEL_ARGS_SIZE); - - PyObject *arg1 = PyList_New(kernel_jsons.size()); - for (int i = 0; i < PyList_Size(arg1); ++i) { - PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str())); - } - PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM); - PyObject *arg3 = Py_BuildValue("i", TIME_OUT); - - (void)PyTuple_SetItem(*p_args, 0, arg1); - (void)PyTuple_SetItem(*p_args, 1, arg2); - (void)PyTuple_SetItem(*p_args, 2, arg3); -} - -bool AkgOpParallelBuild(const std::vector> &build_args) { - auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args); - if (jsons.empty()) { - return true; - } - - // Try to call python method to compile nodes parallely. - PyObject *p_module = nullptr; - PyObject *p_func = nullptr; - PyObject *p_arg = nullptr; - PyObject *p_res = nullptr; - - p_module = PyImport_ImportModule(kMultiProcModule); - if (p_module == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "]."; - return false; - } - - p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); - GenParallelCompileFuncArgs(jsons, &p_arg); - MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size() - << " Akg kernels parallelly."; - p_res = PyEval_CallObject(p_func, p_arg); - if (p_res == nullptr) { - PyErr_Print(); - MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; - return false; - } - if (PyObject_IsTrue(p_res) != 1) { - PyErr_Print(); - MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; - return false; - } - - if (!PostProcessAfterCompile(build_args, repeat_nodes)) { - return false; - } - - return true; -} - -bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes) { - std::vector> json_and_node; - for (const auto &anf_node : anf_nodes) { - MS_EXCEPTION_IF_NULL(anf_node); - AkgAscendKernelBuilder akg_cce_kernel_builder; - KernelPackPtr kernel_pack = nullptr; - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::IsGraphKernel(cnode)) { - auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - auto mng = func_graph->manager(); - if (mng == nullptr) { - mng = Manage(func_graph, true); - func_graph->set_manager(mng); - } - MS_EXCEPTION_IF_NULL(func_graph); - std::vector node_list; - std::vector input_list; - std::vector output_list; - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]"; - GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); - if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) { - MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "]."; - } - } else { - if (!akg_cce_kernel_builder.CollectJson(anf_node)) { - MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "]."; - } - } - json_and_node.push_back({akg_cce_kernel_builder, anf_node}); - } - - if (json_and_node.empty()) { - MS_LOG(DEBUG) << "There is no kernel needed to be compiled."; - return true; - } - - return AkgOpParallelBuild(json_and_node); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.h b/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.h deleted file mode 100644 index 01752911ed..0000000000 --- a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "kernel/kernel.h" -#include "kernel/akg/akg_kernel_build.h" - -namespace mindspore { -namespace kernel { -class AkgAscendKernelBuilder : public AkgKernelBuild { - public: - AkgAscendKernelBuilder() = default; - ~AkgAscendKernelBuilder() = default; - - bool CollectJson(const AnfNodePtr &anf_node); - bool CollectFusedJson(const std::vector &anf_nodes, const std::vector &input_list, - const std::vector &output_list); - std::string json_name() const { return json_name_; } - std::string kernel_json() const { return kernel_json_; } - const std::vector &input_size_list() const { return input_size_list_; } - const std::vector &output_size_list() const { return output_size_list_; } - - private: - bool GenJsonAndPreprocess4Fused(const std::vector &anf_nodes, - std::map *node_json_map); - - std::string kernel_json_; - std::vector input_size_list_; - std::vector output_size_list_; -}; - -bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.cc b/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.cc deleted file mode 100644 index 69fc82aad3..0000000000 --- a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/akg/ascend/akg_ascend_kernel_mod.h" -#include -#include -#include -#include -#include -#include -#include -#include "nlohmann/json.hpp" -#include "runtime/rt.h" -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace kernel { -using std::fstream; -using std::map; -using std::mutex; -using std::string; -using TbeTaskInfoPtr = std::shared_ptr; -using tbe::KernelManager; -constexpr uint32_t DEFAULT_BLOCK_DIM = 1; -/** - * @brief infotable contain func_stub\blockdim\kernel file buffer - */ -AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} - -void AkgKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } - -void AkgKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } - -void AkgKernelMod::SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } - -const std::vector &AkgKernelMod::GetInputSizeList() const { return input_size_list_; } - -const std::vector &AkgKernelMod::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool AkgKernelMod::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - if (stream_ptr == nullptr) { - MS_LOG(ERROR) << "stream_ptr should not be nullptr."; - return false; - } - - if (kernel_pack_ == nullptr) { - MS_LOG(ERROR) << "kernel pack should not be nullptr."; - return false; - } - - uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. - auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); - if (func_stub == 0) { - MS_LOG(ERROR) << "GenFuncStub failed."; - return false; - } - - // pack all addresses into a vector. - std::vector runtime_args; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args), - [](const AddressPtr &output) -> void * { return output->addr; }); - - rtL2Ctrl_t *l2ctrl = nullptr; - auto stream = reinterpret_cast(stream_ptr); - if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast(func_stub), block_dim, runtime_args.data(), - SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) { - MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; - return false; - } - - return true; -} - -std::vector AkgKernelMod::GenTask(const std::vector &inputs, const std::vector &, - const std::vector &outputs, uint32_t stream_id) { - if (kernel_pack_ == nullptr) { - MS_LOG(EXCEPTION) << "kernel pack should not be nullptr."; - } - - std::vector args; - const uint32_t args_size = 0; - std::vector sm_desc; - void *binary = nullptr; - const uint32_t binary_size = 0; - std::vector meta_data; - std::vector input_data_addrs; - std::vector output_data_addrs; - std::vector workspace_addrs; - - // pack all addresses into a vector. - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), - [](const AddressPtr &output) -> void * { return output->addr; }); - - uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. - auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); - if (func_stub == 0) { - MS_LOG(EXCEPTION) << "GenFuncStub failed."; - } - - std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); - - MS_LOG(DEBUG) << "The block_dim is:" << block_dim; - - TbeTaskInfoPtr task_info_ptr = make_shared( - stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, input_data_addrs, - output_data_addrs, workspace_addrs); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.h b/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.h deleted file mode 100644 index 18d342f629..0000000000 --- a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/tbe/tbe_utils.h" - -namespace mindspore { -namespace kernel { -class AkgKernelMod : public AscendKernelMod { - public: - explicit AkgKernelMod(const KernelPackPtr &kernel_pack); - ~AkgKernelMod() final {} - - void SetInputSizeList(const std::vector &size_list); - void SetOutputSizeList(const std::vector &size_list); - void SetWorkspaceSizeList(const std::vector &size_list); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - KernelPackPtr kernel_pack_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -using AkgKernelModPtr = std::shared_ptr; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.cc b/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.cc deleted file mode 100644 index 534e355802..0000000000 --- a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.cc +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/akg/gpu/akg_gpu_kernel_build.h" -#include -#include -#include "kernel/kernel.h" -#include "kernel/akg/akg_kernel_build.h" -#include "kernel/akg/gpu/akg_gpu_kernel_mod.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - AkgKernelBuild akg_kernel_build; - - std::vector input_size_list; - std::vector output_size_list; - KernelPackPtr kernel_pack = akg_kernel_build.BuildByJson(anf_node, &input_size_list, &output_size_list); - MS_EXCEPTION_IF_NULL(kernel_pack); - - auto kernel_mod_ptr = std::make_shared(kernel_pack); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - kernel_mod_ptr->SetInputSizeList(input_size_list); - kernel_mod_ptr->SetOutputSizeList(output_size_list); - return kernel_mod_ptr; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.h b/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.h deleted file mode 100644 index 3a1145140f..0000000000 --- a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.h +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ -#include "kernel/kernel.h" -#include "ir/base.h" - -namespace mindspore { -namespace kernel { -KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.cc b/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.cc deleted file mode 100644 index 64590cd9b8..0000000000 --- a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.cc +++ /dev/null @@ -1,116 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/akg/gpu/akg_gpu_kernel_mod.h" -#include -#include -#include "nlohmann/json.hpp" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -using std::fstream; -using std::string; -using std::vector; - -GpuKernelManagerPtr GpuKernelMod::kernelmanager_ = std::make_shared(); -GpuKernelManager::GpuKernelManager() {} - -CUresult GpuKernelManager::GetFunction(const KernelPackPtr &kernel_pack, bool force_reload, - vector *thread_info, CUfunction *func) { - if (kernel_pack->GetJson() == nullptr || kernel_pack->GetJson()->contents == nullptr || - kernel_pack->GetKernel() == nullptr || kernel_pack->GetKernel()->contents == nullptr) { - MS_LOG(ERROR) << "GPU:Invalid kernel pack, json or kernel is nullptr."; - return CUDA_ERROR_INVALID_IMAGE; - } - auto js = nlohmann::json::parse(kernel_pack->GetJson()->contents, - kernel_pack->GetJson()->contents + kernel_pack->GetJson()->len); - string fn = js["kernelName"]; - if (!force_reload) { - auto iter = infotable_.find(fn); - if (iter != infotable_.end()) { - auto kernelmeta = iter->second; - *thread_info = kernelmeta->thread_info_; - *func = kernelmeta->func_addr_; - return CUDA_SUCCESS; - } - } - thread_info->emplace_back(js["blockIdx.x"]); - thread_info->emplace_back(js["blockIdx.y"]); - thread_info->emplace_back(js["blockIdx.z"]); - thread_info->emplace_back(js["threadIdx.x"]); - thread_info->emplace_back(js["threadIdx.y"]); - thread_info->emplace_back(js["threadIdx.z"]); - CUmodule module; - CUresult result = cuModuleLoadData(&module, kernel_pack->GetKernel()->contents); - if (result != CUDA_SUCCESS) { - MS_LOG(ERROR) << "cuModuleLoadData failed."; - return result; - } - result = cuModuleGetFunction(func, module, fn.c_str()); - if (result != CUDA_SUCCESS) { - MS_LOG(ERROR) << "cuModuleGetFunction failed."; - return result; - } - infotable_[fn] = std::make_shared(*func, module, *thread_info); - return result; -} - -GpuKernelMod::GpuKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} - -void GpuKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } - -void GpuKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } - -const std::vector &GpuKernelMod::GetInputSizeList() const { return input_size_list_; } - -const std::vector &GpuKernelMod::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &GpuKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool GpuKernelMod::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - if (stream_ptr == 0) { - MS_LOG(ERROR) << "stream_ptr should not be nullptr."; - return false; - } - if (kernel_pack_ == nullptr) { - MS_LOG(ERROR) << "kernel pack should not be nullptr."; - return false; - } - vector thread_info; - CUfunction kernel_addr; - CUresult result = kernelmanager_->GetFunction(kernel_pack_, false, &thread_info, &kernel_addr); - if (result != CUDA_SUCCESS) { - MS_LOG(ERROR) << "GetFunction failed."; - return false; - } - std::vector runtimeargs; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs), - [](const AddressPtr &input) -> void * { return reinterpret_cast(&(input->addr)); }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs), - [](const AddressPtr &output) -> void * { return reinterpret_cast(&(output->addr)); }); - result = cuLaunchKernel(kernel_addr, thread_info[0], thread_info[1], thread_info[2], thread_info[3], thread_info[4], - thread_info[5], 0, reinterpret_cast(stream_ptr), - reinterpret_cast(&runtimeargs[0]), 0); - if (result != CUDA_SUCCESS) { - MS_LOG(ERROR) << "Launch Kernel failed."; - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.h b/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.h deleted file mode 100644 index df9cb069f7..0000000000 --- a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ -#include -#include -#include -#include -#include -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -struct GpuKernelMeta { - CUfunction func_addr_; - CUmodule module_; - std::vector thread_info_; - GpuKernelMeta(CUfunction funcAddr, CUmodule module, const std::vector &thread_info) - : func_addr_(funcAddr), module_(module), thread_info_(thread_info) {} -}; -using GpuKernelMetaPtr = std::shared_ptr; - -class GpuKernelManager { - public: - GpuKernelManager(); - virtual ~GpuKernelManager() { - for (auto iter = infotable_.begin(); iter != infotable_.end(); ++iter) { - CUresult ret = cuModuleUnload(iter->second->module_); - if (ret != CUDA_SUCCESS && ret != CUDA_ERROR_DEINITIALIZED) { - MS_LOG(ERROR) << "Unload GPU Module failed."; - } - } - } - CUresult GetFunction(const KernelPackPtr &kernel_pack, bool force_reload, std::vector *thread_info, - CUfunction *func); - - private: - std::unordered_map infotable_; -}; -using GpuKernelManagerPtr = std::shared_ptr; - -class GpuKernelMod : public KernelMod { - public: - explicit GpuKernelMod(const KernelPackPtr &kernel_pack); - virtual ~GpuKernelMod() {} - - void SetInputSizeList(const std::vector &size_list); - void SetOutputSizeList(const std::vector &size_list); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - static GpuKernelManagerPtr kernelmanager_; - - private: - KernelPackPtr kernel_pack_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -using GpuKernelModPtr = std::shared_ptr; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/ascend_kernel_mod.h b/mindspore/ccsrc/kernel/ascend_kernel_mod.h deleted file mode 100644 index 0aee881f7d..0000000000 --- a/mindspore/ccsrc/kernel/ascend_kernel_mod.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ - -#include -#include -#include "framework/ge_runtime/task_info.h" -#include "kernel/kernel.h" - -using TaskInfoPtr = std::shared_ptr; -namespace mindspore { -namespace kernel { -class AscendKernelMod : public KernelMod { - public: - virtual std::vector GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t) = 0; - uint32_t block_dim() { return block_dim_; } - uint32_t stream_id() { return stream_id_; } - - protected: - uint32_t block_dim_{1}; - uint32_t stream_id_{0}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/common_utils.cc b/mindspore/ccsrc/kernel/common_utils.cc deleted file mode 100644 index ab4f59e549..0000000000 --- a/mindspore/ccsrc/kernel/common_utils.cc +++ /dev/null @@ -1,896 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/common_utils.h" -#include -#include -#include -#include -#include -#include -#include "nlohmann/json.hpp" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "ir/manager.h" -#include "ir/meta_tensor.h" -#include "ir/func_graph.h" -#include "operator/ops.h" -#include "utils/graph_utils.h" - -namespace mindspore { -namespace kernel { -const std::unordered_map type_id_maps = { - {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, - {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, - {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, - {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, - {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, - {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, - {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, - {"bool", TypeId::kNumberTypeBool}, -}; - -const std::map type_id_str_map = { - {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, - {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, - {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, - {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, - {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, - {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, - {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, - {TypeId::kNumberTypeBool, "bool"}, -}; - -const std::unordered_map dtype_shortdtype_map_ = { - {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, - {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, -}; - -const std::unordered_map dtype_nbyte_map = { - {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, - {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, - {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, - {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, -}; - -const std::unordered_map fusion_type_maps = { - {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, - {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, -}; - -void KernelMeta::Initialize() { - kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; - // remove old kernel cache - RemoveKernelCache(); - -#if defined(_WIN32) || defined(_WIN64) - auto ret = mkdir(kernel_meta_path_.c_str()); -#else - auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU); -#endif - if (ret != 0) { - MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later"; - } - initialized_ = true; -} - -void KernelMeta::RemoveKernelCache() { - DIR *dir = opendir(kernel_meta_path_.c_str()); - if (dir == nullptr) { - return; - } - struct dirent *entry; - while ((entry = readdir(dir)) != nullptr) { - std::string kernel_file = entry->d_name; - std::string kernel_file_realpath = kernel_meta_path_ + kernel_file; - (void)remove(kernel_file_realpath.c_str()); - } - (void)closedir(dir); - (void)rmdir(kernel_meta_path_.c_str()); -} - -std::string KernelMeta::Search(const std::string &kernel_name) const { - if (!initialized_) { - return ""; - } - - auto iter = kernel_meta_map_.find(kernel_name); - if (iter == kernel_meta_map_.end()) { - return ""; - } else { - return iter->second; - } -} - -bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) { - if (!initialized_) { - return false; - } - kernel_meta_map_[kernel_name] = kernel_json; - return true; -} - -bool CheckCache(const std::string &kernel_name) { - // check cache. - KernelMeta *bin_map = KernelMeta::GetInstance(); - if (bin_map == nullptr) { - MS_LOG(DEBUG) << "kernel cache is invalid."; - return false; - } - std::string kernel_json = bin_map->Search(kernel_name); - bool ret = (!kernel_json.empty()); - if (ret) { - MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed."; - } else { - MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed."; - } - return ret; -} - -KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) { - // search cache. - KernelMeta *bin_map = KernelMeta::GetInstance(); - if (bin_map == nullptr) { - MS_LOG(DEBUG) << "kernel cache is invalid."; - return nullptr; - } - - std::string kernel_json = bin_map->Search(kernel_name); - if (!kernel_json.empty()) { - KernelPackPtr kernel_pack = std::make_shared(); - // just a tmp solution. - if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) { - MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "]."; - return nullptr; - } else { - return kernel_pack; - } - } else { - MS_LOG(INFO) << "cache kernel not found[" << kernel_name << "]."; - return nullptr; - } -} - -KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) { - MS_LOG(INFO) << "kernel name:" << kernel_name << ", processr:" << processor; - KernelMeta *bin_map = KernelMeta::GetInstance(); - std::string kernel_json; - if (processor == kProcessorAiCore || processor == kProcessorAiCpu) { - kernel_json = kCceKernelMeta; - } else { - kernel_json = bin_map->GetKernelMetaPath(); - } - (void)kernel_json.append(kernel_name).append(kJsonSuffix); - KernelPackPtr kernel_pack = std::make_shared(); - if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) { - MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "]."; - return nullptr; - } - - if (bin_map == nullptr) { - MS_LOG(DEBUG) << "kernel cache is invalid."; - return nullptr; - } - if (bin_map->Insert(kernel_name, kernel_json)) { - MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernelname[" << kernel_name << "]."; - } - return kernel_pack; -} - -TypeId DtypeToTypeId(const std::string &dtypes) { - auto iter = type_id_maps.find(dtypes); - if (iter != type_id_maps.end()) { - return iter->second; - } else { - MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes; - } -} - -std::string TypeId2String(TypeId type_id) { - auto iter = type_id_str_map.find(type_id); - if (iter == type_id_str_map.end()) { - return std::string(TypeIdLabel(type_id)); - } - return iter->second; -} - -std::string Dtype2ShortType(const std::string &dtypes) { - auto iter = dtype_shortdtype_map_.find(dtypes); - if (iter != dtype_shortdtype_map_.end()) { - return iter->second; - } else { - MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes; - } -} - -size_t GetDtypeNbyte(const std::string &dtypes) { - auto iter = dtype_nbyte_map.find(dtypes); - if (iter != dtype_nbyte_map.end()) { - return iter->second; - } else { - MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes; - } -} - -bool SetInputKernelBuilderInfo(const std::vector> &inputs, size_t real_input_num, - size_t builder_idex, const std::vector &dyn_input_sizes, - const std::shared_ptr &builder) { - MS_EXCEPTION_IF_NULL(builder); - - std::vector inputs_device_type; - std::vector inputs_format; - size_t dyn_input_idx = 0; - size_t kernel_info_index = 0; - MS_EXCEPTION_IF_NULL(inputs[0]); - size_t kernel_info_cnt = inputs[0]->dtypes().size(); - - for (const auto &input : inputs) { - MS_EXCEPTION_IF_NULL(input); - std::string param_type = input->param_type(); - std::vector dtypes = input->dtypes(); - std::vector formats = input->formats(); - if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { - MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size."; - return false; - } - - if (param_type == "dynamic") { - if (dyn_input_sizes.empty()) { - MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; - return false; - } - - for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) { - kernel_info_index++; - auto type_id = DtypeToTypeId(dtypes[builder_idex]); - inputs_device_type.push_back(type_id); - inputs_format.push_back(formats[builder_idex]); - } - dyn_input_idx++; - } else if (param_type == "required") { - kernel_info_index++; - auto type_id = DtypeToTypeId(dtypes[builder_idex]); - inputs_device_type.push_back(type_id); - inputs_format.push_back(formats[builder_idex]); - } else { - if (kernel_info_index < real_input_num) { - MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index; - kernel_info_index++; - auto type_id = DtypeToTypeId(dtypes[builder_idex]); - inputs_device_type.push_back(type_id); - inputs_format.push_back(formats[builder_idex]); - } - } - } - - builder->SetInputsDeviceType(inputs_device_type); - builder->SetInputsFormat(inputs_format); - return true; -} - -bool SetOutputKernelBuilderInfo(const std::vector> &outputs, size_t builder_idex, - const size_t &real_output_num, - const std::shared_ptr &builder) { - // not now but in the next we need to support dynamic output case - MS_EXCEPTION_IF_NULL(builder); - - size_t output_idx = 0; - std::vector outputs_device_type; - std::vector outputs_format; - MS_EXCEPTION_IF_NULL(outputs[0]); - size_t kernel_info_cnt = outputs[0]->dtypes().size(); - - for (const auto &output : outputs) { - MS_EXCEPTION_IF_NULL(output); - if (output_idx >= real_output_num) { - MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!"; - continue; - } - size_t output_num = 0; - if (output->param_type() == "dynamic") { - if (outputs.size() > 1) { - MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!"; - } - output_num = real_output_num; - } else if (output->param_type() == "required") { - output_num = 1; - } else { - if (output_idx < real_output_num) { - MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; - output_num = 1; - } - } - - for (size_t i = 0; i < output_num; i++) { - std::vector dtypes = output->dtypes(); - std::vector formats = output->formats(); - if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { - MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size."; - return false; - } - auto type_id = DtypeToTypeId(dtypes[builder_idex]); - outputs_device_type.push_back(type_id); - outputs_format.push_back(formats[builder_idex]); - output_idx++; - } - } - - builder->SetOutputsFormat(outputs_format); - builder->SetOutputsDeviceType(outputs_device_type); - return true; -} - -void SetKernelBuildInfo(const std::shared_ptr &builder, Processor processor, - const std::shared_ptr &op_info_ptr) { - MS_EXCEPTION_IF_NULL(builder); - MS_EXCEPTION_IF_NULL(op_info_ptr); - - auto imply_type = op_info_ptr->imply_type(); - builder->SetProcessor(processor); - std::string fusion_type = op_info_ptr->fusion_type(); - auto iter = fusion_type_maps.find(fusion_type); - if (iter != fusion_type_maps.end()) { - builder->SetFusionType(iter->second); - } else { - if (imply_type == kAKG) { - MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type; - } - } - - if (imply_type == kAKG) { - builder->SetKernelType(AKG_KERNEL); - } else if (imply_type == kAICPU) { - builder->SetKernelType(AICPU_KERNEL); - } else { - builder->SetKernelType(TBE_KERNEL); - } -} - -bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &op_info_ptr, Processor processor, - std::vector> *const kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node); - size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - std::vector> inputs = op_info_ptr->inputs_ptr(); - std::vector> outputs = op_info_ptr->outputs_ptr(); - std::vector dyn_input_sizes; - auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(primitive); - if (primitive->GetAttr("dyn_input_sizes") != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr("dyn_input_sizes")); - } - if (inputs.size() > 0) { - MS_EXCEPTION_IF_NULL(inputs[0]); - size_t kernel_info_cnt = inputs[0]->dtypes().size(); - for (size_t j = 0; j < kernel_info_cnt; j++) { - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - SetKernelBuildInfo(builder, processor, op_info_ptr); - - if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) { - MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed."; - return false; - } - - if (outputs.size() > 0) { - if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) { - MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed."; - return false; - } - } - - kernel_info_list->push_back(builder->Build()); - } - } else if (outputs.size() > 0) { - MS_EXCEPTION_IF_NULL(outputs[0]); - size_t kernel_info_cnt = outputs[0]->dtypes().size(); - for (size_t j = 0; j < kernel_info_cnt; j++) { - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - SetKernelBuildInfo(builder, processor, op_info_ptr); - - if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) { - MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed."; - return false; - } - - kernel_info_list->push_back(builder->Build()); - } - } else { - if (processor == AICPU) { - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - SetKernelBuildInfo(builder, processor, op_info_ptr); - kernel_info_list->push_back(builder->Build()); - } - } - return true; -} - -void SaveJsonInfo(const std::string &json_name, const std::string &info) { - char real_path[PATH_MAX] = {0}; - std::string path = kCceKernelMeta + json_name + kInfoSuffix; - if (path.size() > PATH_MAX) { - MS_LOG(DEBUG) << "file path " << path << " is too long."; - return; - } - std::ofstream filewrite; - filewrite.open(path); - if (!filewrite.is_open()) { - return; - } - filewrite << info << std::endl; - filewrite.close(); -#if defined(_WIN32) || defined(_WIN64) - if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) { - MS_LOG(DEBUG) << "dir " << path << " does not exit."; - return; - } -#else - if (nullptr == realpath(path.c_str(), real_path)) { - MS_LOG(DEBUG) << "dir " << path << " does not exit."; - return; - } -#endif - MS_LOG(INFO) << "real path is :" << real_path; - if (chmod(real_path, S_IRUSR) == -1) { - MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail."; - } -} - -std::string GetProcessor(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string device; - switch (AnfAlgo::GetProcessor(anf_node)) { - case Processor::AICORE: - device = kProcessorAiCore; - break; - - case Processor::AICPU: - device = kProcessorAiCpu; - break; - - case Processor::CUDA: - device = kProcessorCuda; - break; - - default: - MS_LOG(DEBUG) << "Unknown processor type."; - break; - } - return device; -} - -bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b) { - if (shape_a.size() != shape_b.size()) { - return false; - } - for (size_t i = 0; i < shape_a.size(); ++i) { - if (shape_a[i] != shape_b[i]) { - return false; - } - } - return true; -} - -int Sign(float x) { - if (x > 0) { - return 1; - } - if (x < 0) { - return -1; - } - return 0; -} - -void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim) { - MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); - MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); - MS_EXCEPTION_IF_NULL(unique_grad); - MS_EXCEPTION_IF_NULL(unique_grad->value_); - MS_EXCEPTION_IF_NULL(unique_grad->indices_); - std::unordered_map index_map; - size_t unique_indices_size = 0; - for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { - int index = origin_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= first_dim) { - continue; - } - auto iter = index_map.find(index); - if (iter == index_map.end()) { - index_map[index] = unique_indices_size; - unique_grad->indices_[unique_indices_size] = index; - size_t start_index = unique_indices_size * outer_dim; - size_t end_index = start_index + outer_dim; - for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { - unique_grad->value_[j] = origin_sparse_grad.value_[k]; - } - unique_indices_size++; - } else { - size_t first_index = iter->second; - size_t start_index = first_index * outer_dim; - size_t end_index = start_index + outer_dim; - for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { - unique_grad->value_[j] += origin_sparse_grad.value_[k]; - } - } - } - unique_grad->indices_size_ = unique_indices_size; -} - -struct WorkerParamsForReduceSparseGradient { - size_t slice_start_{0}; - size_t slice_end_{0}; - size_t max_length_{0}; - size_t outer_dim_{0}; - std::vector> *sorted_indices_{nullptr}; - std::vector *slice_positions_{nullptr}; - float *src_value_{nullptr}; - SparseGradient *unique_grad_{nullptr}; -}; - -void WorkerForReduceSparseGradient(WorkerParamsForReduceSparseGradient param) { - MS_EXCEPTION_IF_NULL(param.sorted_indices_); - MS_EXCEPTION_IF_NULL(param.slice_positions_); - MS_EXCEPTION_IF_NULL(param.src_value_); - MS_EXCEPTION_IF_NULL(param.unique_grad_); - auto outer_dim = param.outer_dim_; - auto &sorted_indices = *(param.sorted_indices_); - auto &slice_positions = *(param.slice_positions_); - auto unique_grad = param.unique_grad_; - for (size_t slice_id = param.slice_start_; slice_id < param.slice_end_; ++slice_id) { - size_t cur_pos = slice_positions[slice_id]; - int index = sorted_indices[cur_pos].first; - unique_grad->indices_[slice_id] = index; - size_t start_index = slice_id * outer_dim; - auto ret_code = memcpy_s(unique_grad->value_ + start_index, (param.max_length_ - start_index) * sizeof(float), - param.src_value_ + sorted_indices[cur_pos].second, outer_dim * sizeof(float)); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data!"; - } - cur_pos++; - size_t end_pos; - if (slice_id + 1 < slice_positions.size()) { - end_pos = slice_positions[slice_id + 1]; - } else { - end_pos = sorted_indices.size(); - } - while (cur_pos < end_pos) { - for (size_t i = 0; i < outer_dim; ++i) { - unique_grad->value_[start_index + i] += param.src_value_[sorted_indices[cur_pos].second + i]; - } - cur_pos++; - } - } -} - -void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim) { - MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); - MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); - MS_EXCEPTION_IF_NULL(unique_grad); - MS_EXCEPTION_IF_NULL(unique_grad->value_); - MS_EXCEPTION_IF_NULL(unique_grad->indices_); - std::vector> sorted_indices; - sorted_indices.reserve(origin_sparse_grad.indices_size_); - for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { - int index = origin_sparse_grad.indices_[i]; - if (index >= 0 && IntToSize(index) < first_dim) { - sorted_indices.emplace_back(std::pair(index, i * outer_dim)); - } - } - std::sort( - sorted_indices.begin(), sorted_indices.end(), - [](const std::pair &left, const std::pair &right) { return left.first < right.first; }); - int last_index = 0; - std::vector slice_positions; - for (size_t i = 0; i < sorted_indices.size(); ++i) { - if (i == 0 || last_index != sorted_indices[i].first) { - slice_positions.emplace_back(i); - } - last_index = sorted_indices[i].first; - } - size_t thread_num = 8; - if (slice_positions.size() < thread_num) { - thread_num = slice_positions.size(); - } - size_t stride = (slice_positions.size() + thread_num - 1) / thread_num; - thread_num = (slice_positions.size() + stride - 1) / stride; - std::vector threads; - size_t max_length = sorted_indices.size() * outer_dim; - for (size_t i = 0; i < thread_num; ++i) { - size_t slice_start = i * stride; - size_t slice_end = 0; - if (i == thread_num - 1) { - slice_end = slice_positions.size(); - } else { - slice_end = slice_start + stride; - } - WorkerParamsForReduceSparseGradient params{ - slice_start, slice_end, max_length, outer_dim, &sorted_indices, &slice_positions, origin_sparse_grad.value_, - unique_grad}; - threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params)); - } - for (size_t i = 0; i < thread_num; ++i) { - threads[i].join(); - } - unique_grad->indices_size_ = slice_positions.size(); -} - -std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index) { - MS_EXCEPTION_IF_NULL(anf_node); - - if (index >= AnfAlgo::GetInputTensorNum(anf_node)) { - MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs."; - } - - auto cnode = anf_node->cast(); - if (cnode == nullptr) { - return AnfAlgo::VisitKernel(anf_node, 0); - } else { - return AnfAlgo::VisitKernel(anf_node->cast()->input(index + 1), 0); - } -} - -std::vector>> GetInputIndex(const std::vector &node_list, - const std::vector &input_list) { - std::vector>> input_index; - for (size_t i = 0; i < input_list.size(); ++i) { - auto const &input = input_list[i]; - MS_EXCEPTION_IF_NULL(input); - bool found = false; - // using NodeUsersMap = std::unordered_map>>; - auto mng = input->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(mng); - const NodeUsersMap &users = mng->node_users(); - auto input_users = users.find(input); - if (input_users == users.end() || input_users->second.empty()) { - MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" - << input->func_graph()->ToString() << "] has no users."; - } - - for (auto const &input_user : input_users->second) { - for (auto const &anf_node : node_list) { - if (anf_node != input_user.first) { - continue; - } - - std::vector dyn_input_sizes; - auto prim = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(prim); - if (prim->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(prim->GetAttr(kAttrDynInputSizes)); - } - - if (dyn_input_sizes.empty()) { - input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0))); - found = true; - break; - } else { - int used_as_idx = input_user.second - 1; - int accum_idx = 0; - size_t dyn_i = 0; - for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) { - accum_idx += dyn_input_sizes[dyn_i]; - if (used_as_idx < accum_idx) { - input_index.push_back(std::make_pair( - anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i]))))); - break; - } - } - if (dyn_i != dyn_input_sizes.size()) { - found = true; - break; - } - } - } - if (found) { - break; - } - } - - if (!found) { - MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" - << input->func_graph()->ToString() << "] found no related kernel info."; - } - } - return input_index; -} - -std::vector> GetOutputIndex(const std::vector &node_list, - const std::vector &input_list, - const std::vector &output_list) { - std::vector> output_index; - for (size_t i = 0; i < output_list.size(); ++i) { - auto const &output = output_list[i]; - MS_EXCEPTION_IF_NULL(output); - bool found = false; - auto pree_node = AnfAlgo::VisitKernel(output, 0); - auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first); - if (pos != std::end(node_list)) { - output_index.push_back(pree_node); - continue; - } - auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first); - if (ret != std::end(input_list)) { - output_index.push_back(std::make_pair(pree_node.first, 0)); - found = true; - } - if (!found) { - MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of [" - << output->func_graph()->ToString() << "] found no related kernel info."; - } - } - return output_index; -} - -void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list) { - MS_EXCEPTION_IF_NULL(node_list); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector node_lists = TopoSort(func_graph->get_return()); - for (auto const &node : node_lists) { - if (!AnfAlgo::IsRealKernel(node) || !node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (IsValueNode(cnode->input(kAnfPrimitiveIndex))) { - node_list->push_back(node); - } - } -} - -void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list, - std::vector *input_list, std::vector *output_list) { - MS_EXCEPTION_IF_NULL(node_list); - MS_EXCEPTION_IF_NULL(input_list); - MS_EXCEPTION_IF_NULL(output_list); - MS_EXCEPTION_IF_NULL(func_graph); - - GetValidKernelNodes(func_graph, node_list); - - auto parameters = func_graph->parameters(); - input_list->insert(input_list->begin(), parameters.begin(), parameters.end()); - - auto func_output = func_graph->output(); - MS_EXCEPTION_IF_NULL(func_output); - if (func_output->isa()) { - // multi output. - auto cnode = func_output->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input0 = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimMakeTuple)) { - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) { - auto input_node = cnode->input(input_idx); - MS_EXCEPTION_IF_NULL(input_node); - output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first); - } - } else { - // single output. - output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); - } - } else { - // single output. - output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); - } -} - -bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(node_json); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->size()) { - MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" - << cnode->inputs().size() << "][" << cnode->DebugString() << "]"; - } - - auto input_node = cnode->input(input_idx + 1); - if (!IsValueNode(input_node)) { - return false; - } - - auto tensor = GetValueNode(input_node); - if (tensor == nullptr) { - return false; - } - - auto type_id = tensor->data_type(); - auto *data = tensor->data_c(); - MS_EXCEPTION_IF_NULL(data); - if (tensor->DataDim() > 1 || tensor->DataSize() != 1) { - // not const tensor. - MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]"; - } - - if (type_id == kFloat32->type_id()) { - float *val = static_cast(data); - MS_EXCEPTION_IF_NULL(val); - (*node_json)["value"] = val[0]; - MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "]."; - return true; - } else if (type_id == kFloat16->type_id()) { - float16 *val = static_cast(data); - MS_EXCEPTION_IF_NULL(val); - (*node_json)["value"] = static_cast(val[0]); - MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "]."; - return true; - } else if (type_id == kInt32->type_id()) { - int *val = static_cast(data); - MS_EXCEPTION_IF_NULL(val); - (*node_json)["value"] = val[0]; - MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "]."; - return true; - } - MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]"; - return false; -} - -void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *node_list) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node_list); - auto output = func_graph->output(); - MS_EXCEPTION_IF_NULL(output); - if (AnfAlgo::IsRealKernel(output)) { - // single output. - node_list->push_back(std::make_pair(output, 0)); - return; - } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - // multi output. - auto &inputs = output_cnode->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0); - node_list->push_back(in_with_idx); - } - return; - } - MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2) - << " of graph: " << func_graph->ToString(); -} - -bool IsWeightBoundary(const AnfNodePtr &node) { - if (node->isa()) { - return true; - } - if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { - return true; - } - return false; -} - -void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params, - size_t total_compute_size) { - const size_t kThreadNum = 24; - std::vector threads; - threads.reserve(kThreadNum); - size_t start = 0; - size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum; - while (start < total_compute_size) { - size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size); - threads.emplace_back(std::thread(func, params, start, end)); - start += once_compute_size; - } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/common_utils.h b/mindspore/ccsrc/kernel/common_utils.h deleted file mode 100644 index e9d72848f6..0000000000 --- a/mindspore/ccsrc/kernel/common_utils.h +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "kernel/oplib/opinfo.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -constexpr auto kCceKernelMeta = "./kernel_meta/"; -constexpr auto kGpuKernelMeta = "./cuda_meta"; -constexpr auto kProcessorAiCore = "aicore"; -constexpr auto kProcessorAiCpu = "aicpu"; -constexpr auto kProcessorCuda = "cuda"; -constexpr auto kJsonSuffix = ".json"; -constexpr auto kInfoSuffix = ".info"; -constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600; -constexpr auto kAkgModule = "_akg"; -constexpr auto kArgDataformat = "data_format"; - -const std::vector support_devices = {"aicore", "aicpu", "cuda"}; - -struct KernelMetaInfo { - uintptr_t func_stub_; - uint32_t block_dim_; -}; -using KernelMetaPtr = std::shared_ptr; - -class KernelMeta { - public: - KernelMeta() = default; - void Initialize(); - void RemoveKernelCache(); - std::string Search(const std::string &kernel_name) const; - bool Insert(const std::string &kernel_name, const std::string &kernel_json); - std::string GetKernelMetaPath() { return kernel_meta_path_; } - - static KernelMeta *GetInstance() { - static KernelMeta kernel_meta; - return &kernel_meta; - } - ~KernelMeta() = default; - - private: - bool initialized_ = false; - std::string kernel_meta_path_; - std::unordered_map kernel_meta_map_; -}; - -struct SparseGradient { - float *value_; - int *indices_; - size_t indices_size_; -}; - -struct MultiThreadComputeParams { - float *var_; - float *accum_; - float *linear_; - float *m_; - float *m_t_; - float *v_; - float lr_; - float l1_; - float l2_; - float lr_power_; - float beta1_; - float beta2_; - float epsilon_; - SparseGradient sparse_grad_; - size_t var_first_dim_size_; - size_t var_outer_dim_size_; - bool use_nesterov_; -}; -using MultiThreadComputeFunc = std::function; - -bool CheckCache(const std::string &kernel_name); -KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); -KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); -TypeId DtypeToTypeId(const std::string &dtypes); -std::string Dtype2ShortType(const std::string &dtypes); -std::string TypeId2String(TypeId type_id); -size_t GetDtypeNbyte(const std::string &dtypes); -bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &op_info_ptr, Processor processor, - std::vector> *const kernel_info_list); -void SaveJsonInfo(const std::string &json_name, const std::string &info); -std::string GetProcessor(const AnfNodePtr &anf_node); -bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b); -int Sign(float x); -void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim); -void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim); -std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index); -std::vector>> GetInputIndex(const std::vector &node_list, - const std::vector &input_list); -std::vector> GetOutputIndex(const std::vector &node_list, - const std::vector &input_list, - const std::vector &output_list); -void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list, - std::vector *input_list, std::vector *output_list); -void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list); -bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json); -void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *node_list); -bool IsWeightBoundary(const AnfNodePtr &node); -void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params, - size_t total_compute_size); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc deleted file mode 100644 index 5b3194608e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/addn_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - input_num_ = AnfAlgo::GetInputTensorNum(kernel_node); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -bool AddNCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - - size_t offset = 0; - for (size_t i = 0; i < output_shape_[0]; ++i) { - for (size_t j = 0; j < output_shape_[1]; ++j) { - for (size_t k = 0; k < output_shape_[2]; ++k) { - for (size_t m = 0; m < output_shape_[3]; ++m) { - float sum = 0; - for (size_t index = 0; index < input_num_; ++index) { - auto input_addr = reinterpret_cast(inputs[index]->addr); - sum += input_addr[offset]; - } - output_addr[offset++] = sum; - } - } - } - } - - return true; -} - -void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but AddNCPUKernel olny support 4d or lower."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AddNCPUKernel needs 1 output."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h deleted file mode 100644 index 1a1a9157d9..0000000000 --- a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class AddNCPUKernel : public CPUKernel { - public: - AddNCPUKernel() : input_num_(0) {} - ~AddNCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void CheckParam(const CNodePtr &kernel_node); - size_t input_num_; - std::vector output_shape_; -}; - -MS_REG_CPU_KERNEL(AddN, - KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AddNCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc deleted file mode 100644 index 9cc5126c08..0000000000 --- a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/allgather_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "ir/primitive.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr auto kRanksGroup = "group"; -constexpr auto kAllGatherInputNum = 1; -} // namespace - -void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != kAllGatherInputNum) { - MS_LOG(EXCEPTION) << "allgather input num:" << input_num; - } - - auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); - if (ranks_group != nullptr) { - ranks_group_ = GetValue>(ranks_group); - } else { - MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup; - } -} - -bool AllGatherCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto input_data_num = inputs[0]->size / sizeof(float); - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h deleted file mode 100644 index 1dddf810ef..0000000000 --- a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class AllGatherCPUKernel : public CPUKernel { - public: - AllGatherCPUKernel() = default; - ~AllGatherCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - std::vector ranks_group_; -}; - -MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AllGatherCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.cc deleted file mode 100644 index 3cd6c57413..0000000000 --- a/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/apply_momentum_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void ApplyMomentumCPUKernel::InitKernel(const CNodePtr & /*kernel_node*/) {} - -bool ApplyMomentumCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector & /*outputs*/) { - if (inputs.size() < 5) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[3]->size) { - MS_LOG(EXCEPTION) << "error input data size!"; - } - auto weight = reinterpret_cast(inputs[0]->addr); - auto accumulate = reinterpret_cast(inputs[1]->addr); - float learning_rate = reinterpret_cast(inputs[2]->addr)[0]; - auto gradient = reinterpret_cast(inputs[3]->addr); - float moment = reinterpret_cast(inputs[4]->addr)[0]; - size_t elem_num = inputs[0]->size / sizeof(float); - for (size_t i = 0; i < elem_num; ++i) { - accumulate[i] = accumulate[i] * moment + gradient[i]; - weight[i] -= accumulate[i] * learning_rate; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.h deleted file mode 100644 index c0ca581974..0000000000 --- a/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class ApplyMomentumCPUKernel : public MKLCPUKernel { - public: - ApplyMomentumCPUKernel() = default; - ~ApplyMomentumCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - ApplyMomentumCPUKernel); -MS_REG_CPU_KERNEL(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - ApplyMomentumCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.cc deleted file mode 100644 index ee328df721..0000000000 --- a/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/argmax_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void ArgmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (shape.size() != 2) { - MS_LOG(EXCEPTION) << "argmax kernel dims invalid " << shape.size(); - } - batch_size_ = shape[0]; - class_num_ = shape[1]; - - int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis != -1 && axis != 1) { - MS_LOG(EXCEPTION) << "argmax kernel not support axis " << axis; - } -} - -bool ArgmaxCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspaces*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output empty!"; - } - - size_t batch_float_size = batch_size_ * sizeof(float); - size_t batch_class_float_size = class_num_ * batch_float_size; - if (inputs[0]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { - MS_LOG(EXCEPTION) << "invalid input or output data size!"; - } - auto input = reinterpret_cast(inputs[0]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - size_t row_start = 0; - for (size_t i = 0; i < batch_size_; ++i) { - size_t max_index = 0; - float max_value = input[row_start]; - for (size_t j = 1; j < class_num_; ++j) { - size_t index = row_start + j; - if (input[index] > max_value) { - max_value = input[index]; - max_index = j; - } - } - output[i] = SizeToInt(max_index); - row_start += class_num_; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.h deleted file mode 100644 index aae7435c5c..0000000000 --- a/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class ArgmaxCPUKernel : public CPUKernel { - public: - ArgmaxCPUKernel() = default; - ~ArgmaxCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t class_num_{0}; - size_t batch_size_{0}; -}; - -MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - ArgmaxCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.cc deleted file mode 100644 index 00f3017231..0000000000 --- a/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.cc +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/bias_add_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -void BiasAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - bias_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - if (input_shape_.size() == 4) { - data_shape_ = 4; - } else if (input_shape_.size() == 2) { - data_shape_ = 2; - } else { - MS_LOG(EXCEPTION) << "bias add input data format should be NCHW or NC"; - } - if (input_shape_.size() != 2 && input_shape_.size() != 4) { - MS_LOG(EXCEPTION) << "bias add input shape nchw or nc"; - } - if (bias_shape_.size() != 1) { - MS_LOG(EXCEPTION) << "bias shape invalid"; - } - if (input_shape_[1] != bias_shape_[0]) { - MS_LOG(EXCEPTION) << "bias shape not match"; - } -} - -bool BiasAddCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() != 2 || outputs.size() != 1) { - MS_LOG(EXCEPTION) << "inputs outputs size not supoort"; - } - - auto src_addr = reinterpret_cast(inputs[0]->addr); - auto bias_addr = reinterpret_cast(inputs[1]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - - if (data_shape_ == 4) { - size_t h_size = input_shape_[3]; - size_t c_size = input_shape_[2] * h_size; - size_t n_size = input_shape_[1] * c_size; - size_t hw_size = input_shape_[2] * input_shape_[3]; - size_t n_offset = 0; - for (size_t n = 0; n < input_shape_[0]; ++n) { - size_t c_offset = 0; - for (size_t c = 0; c < input_shape_[1]; ++c) { - for (size_t hw = 0; hw < hw_size; ++hw) { - size_t offset = n_offset + c_offset + hw; - output_addr[offset] = src_addr[offset] + bias_addr[c]; - } - c_offset += c_size; - } - n_offset += n_size; - } - } else { - size_t n_offset = 0; - for (size_t n = 0; n < input_shape_[0]; ++n) { - for (size_t c = 0; c < input_shape_[1]; ++c) { - output_addr[n_offset + c] = src_addr[n_offset + c] + bias_addr[c]; - } - n_offset += input_shape_[1]; - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.h deleted file mode 100644 index 516a21147b..0000000000 --- a/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class BiasAddCPUKernel : public CPUKernel { - public: - BiasAddCPUKernel() = default; - ~BiasAddCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - uint8_t data_shape_{0}; - std::vector input_shape_; - std::vector bias_shape_; -}; -MS_REG_CPU_KERNEL( - BiasAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BiasAddCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.cc deleted file mode 100644 index 1d9c7d076e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/bias_add_grad_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -void BiasAddGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (input_shape_.size() != 4 && input_shape_.size() != 2) { - MS_LOG(EXCEPTION) << "input data format not support"; - } -} - -bool BiasAddGradCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() != 1 || outputs.size() != 1) { - MS_LOG(EXCEPTION) << "input output size not support"; - } - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto input_addr = reinterpret_cast(inputs[0]->addr); - - if (input_shape_.size() == 4) { - size_t h_size = input_shape_[3]; - size_t c_size = h_size * input_shape_[2]; - size_t n_size = c_size * input_shape_[1]; - size_t hw_size = input_shape_[2] * input_shape_[3]; - size_t c_offset = 0; - for (size_t c = 0; c < input_shape_[1]; ++c) { - output_addr[c] = 0; - size_t n_offset = 0; - for (size_t n = 0; n < input_shape_[0]; ++n) { - for (size_t hw = 0; hw < hw_size; ++hw) { - size_t offset = c_offset + n_offset + hw; - output_addr[c] += input_addr[offset]; - } - n_offset += n_size; - } - c_offset += c_size; - } - } else if (input_shape_.size() == 2) { - for (size_t c = 0; c < input_shape_[1]; ++c) { - output_addr[c] = 0; - size_t n_offset = 0; - for (size_t n = 0; n < input_shape_[0]; ++n) { - output_addr[c] += input_addr[c + n_offset]; - n_offset += input_shape_[1]; - } - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.h deleted file mode 100644 index e3ac896096..0000000000 --- a/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class BiasAddGradCPUKernel : public CPUKernel { - public: - BiasAddGradCPUKernel() = default; - ~BiasAddGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - std::vector input_shape_; -}; -MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BiasAddGradCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc deleted file mode 100644 index d8f2ef421b..0000000000 --- a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/concat_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - - axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (axis_ < 0) { - axis_ = axis_ + SizeToInt(input_1_shape.size()); - } - axis_ += 4 - input_1_shape.size(); - - auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; i++) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); - CPUKernelUtils::ExpandDimsTo4(&input_shape); - input_shape_list_.push_back(input_shape); - } - - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -bool ConcatCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto buff_size = outputs[0]->size; - size_t dim0 = output_shape_[0]; - size_t dim1 = output_shape_[1]; - size_t dim2 = output_shape_[2]; - - if (axis_ == 3) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - for (size_t k = 0; k < dim2; ++k) { - CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size); - } - } - } - } else if (axis_ == 2) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size); - } - } - } else if (axis_ == 1) { - for (size_t i = 0; i < dim0; ++i) { - CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size); - } - } else if (axis_ == 0) { - CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); - } - return true; -} - -void ConcatCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, - size_t dim2, float **output_addr, size_t *buff_size) { - for (size_t i = 0; i < input_shape_list_.size(); ++i) { - auto input_i_shape = input_shape_list_[i]; - auto input_i_addr = reinterpret_cast(inputs[i]->addr); - - size_t num = CPUKernelUtils::GetElementNumOnAxis(input_i_shape, axis_); - num *= input_i_shape[axis_]; - auto pos = CPUKernelUtils::CalcOffset(input_i_shape, dim0, dim1, dim2, 0); - auto ret = memcpy_s(*output_addr, *buff_size, input_i_addr + pos, num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed."; - } - *output_addr += num; - *buff_size -= num * sizeof(float); - } -} - -void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but ConcatCPUKernel olny support 4d or lower."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h deleted file mode 100644 index 46f9078178..0000000000 --- a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class ConcatCPUKernel : public CPUKernel { - public: - ConcatCPUKernel() : axis_(0) {} - ~ConcatCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void CheckParam(const CNodePtr &kernel_node); - void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, - float **output_addr, size_t *buff_size); - int axis_; - std::vector> input_shape_list_; - std::vector output_shape_; -}; - -MS_REG_CPU_KERNEL(Concat, - KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConcatCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc deleted file mode 100644 index 2be05038d6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/cpu_kernel.h" - -namespace mindspore { -namespace kernel { -void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - size_t type_size = sizeof(float); - for (size_t input_index = 0; input_index < input_num; ++input_index) { - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index); - size_t tensor_size = - shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - input_size_list_.emplace_back(tensor_size); - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t output_index = 0; output_index < output_num; ++output_index) { - std::vector shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index); - size_t tensor_size = - shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - output_size_list_.emplace_back(tensor_size); - } -} - -void CPUKernel::Init(const CNodePtr &kernel_node) { - InitKernel(kernel_node); - InitInputOutputSize(kernel_node); -} - -void CPUKernelUtils::ExpandDimsTo4(std::vector *shape) { - auto len = shape->size(); - if (len < 4) { - for (size_t i = 0; i < 4 - len; ++i) { - shape->insert(shape->begin(), 1); - } - } -} - -size_t CPUKernelUtils::CalcOffset(const std::vector &shape, size_t dim0, size_t dim1, size_t dim2, - size_t dim3) { - size_t offset = dim0 * shape[1] * shape[2] * shape[3] + dim1 * shape[2] * shape[3] + dim2 * shape[3] + dim3; - return offset; -} - -size_t CPUKernelUtils::GetElementNumOnAxis(const std::vector &shape, int axis) { - if (axis < 0) { - axis = axis + SizeToInt(shape.size()); - } - size_t result = 1; - for (int j = 3; j > axis; --j) { - result *= shape[j]; - } - return result; -} - -void CPUKernelUtils::GetElementNumEveryDim(const std::vector &shape, std::vector *element_num) { - size_t accumulation = 1; - element_num->emplace_back(1); - for (size_t i = shape.size() - 1; i > 0; --i) { - accumulation *= shape[i]; - element_num->emplace_back(accumulation); - } - std::reverse(element_num->begin(), element_num->end()); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel.h deleted file mode 100644 index 0836529840..0000000000 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel.h +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "ir/anf.h" -#include "session/anf_runtime_algorithm.h" - -using mindspore::kernel::Address; -using mindspore::kernel::AddressPtr; -namespace mindspore { -namespace kernel { -const char KSIZE[] = "ksize"; -const char STRIDE[] = "stride"; -const char STRIDES[] = "strides"; -const char DILATION[] = "dilation"; -const char PAD[] = "pad"; -const char PAD_MODE[] = "pad_mode"; -const char PADDING[] = "padding"; -const char PAD_MODE_LOWER_SAME[] = "same"; -const char PAD_MODE_LOWER_VALID[] = "valid"; -const char PAD_MODE_UPPER_SAME[] = "SAME"; -const char PAD_MODE_UPPER_VALID[] = "VALID"; -const char TRANSPOSE_A[] = "transpose_a"; -const char TRANSPOSE_B[] = "transpose_b"; -const char IS_GRAD[] = "is_grad"; -const char TRANSPOSE_NO = 'N'; -const char TRANSPOSE_YES = 'T'; -const char AXIS[] = "axis"; -const char BEGIN[] = "begin"; -const char END[] = "end"; -const char SIZE[] = "size"; -const char USE_NESTEROV[] = "use_nesterov"; - -class CPUKernel : public kernel::KernelMod { - public: - CPUKernel() = default; - ~CPUKernel() override = default; - void Init(const CNodePtr &kernel_node); - virtual void InitKernel(const CNodePtr &kernel_node) = 0; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void * /*stream_ptr*/) override { - return Launch(inputs, workspace, outputs); - }; - virtual bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) = 0; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - protected: - virtual void InitInputOutputSize(const CNodePtr &kernel_node); - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -class CPUKernelUtils { - public: - static void ExpandDimsTo4(std::vector *shape); - static size_t CalcOffset(const std::vector &shape, size_t dim0, size_t dim1, size_t dim2, size_t dim3); - static size_t GetElementNumOnAxis(const std::vector &shape, int axis); - static void GetElementNumEveryDim(const std::vector &shape, std::vector *element_num); -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc deleted file mode 100644 index bcda7af9fd..0000000000 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/cpu_kernel_factory.h" - -#include -#include -#include - -#include "device/kernel_info.h" - -namespace mindspore { -namespace kernel { -CPUKernelFactory &CPUKernelFactory::GetInstance() { - static CPUKernelFactory instance; - return instance; -} - -void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr, - CPUKernelCreator &&kernel_creator) { - (void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator); -#if !defined(_WIN32) && !defined(_WIN64) - MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name; -#endif -} - -std::shared_ptr CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { - auto kernel_info = apply_kernel->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(kernel_build_Info); - std::pair ret_pair = CPUKernelAttrCheck(kernel_name, *kernel_build_Info); - if (ret_pair.first) { - return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second(); - } - return nullptr; -} - -std::pair CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name, - const KernelBuildInfo &kernel_info) { - auto iter = name_to_attr_creator_.find(kernel_name); - if (iter == name_to_attr_creator_.end()) { - MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!"; - return std::make_pair(false, 0); - } - auto creators = iter->second; - for (size_t index = 0; index < creators.size(); ++index) { - auto attr_creator = creators[index]; - if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) { - return std::make_pair(true, index); - } - } - return std::make_pair(false, 0); -} - -bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) { - for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) { - auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first; - if (kernel_info.GetInputDeviceType(i) != dtype) { - MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i) - << ", register type:" << dtype; - return false; - } - } - for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) { - auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first; - if (kernel_info.GetOutputDeviceType(i) != dtype) { - MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i) - << ", register type:" << dtype; - return false; - } - } - return true; -} - -std::vector CPUKernelFactory::GetSupportedKernelAttrList(const std::string &kernel_name) { - std::vector result; - auto iter = name_to_attr_creator_.find(kernel_name); - if (iter == name_to_attr_creator_.end()) { - MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!"; - return result; - } - auto creators = iter->second; - for (size_t index = 0; index < creators.size(); ++index) { - auto attr_creator = creators[index]; - result.push_back(attr_creator.first); - } - return result; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h deleted file mode 100644 index 52eda12ba7..0000000000 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ - -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "kernel/cpu/cpu_kernel.h" -#include "device/cpu/kernel_select_cpu.h" - -namespace mindspore { -namespace kernel { -using mindspore::device::cpu::KernelAttr; -using CPUKernelCreator = std::function()>; -class CPUKernelFactory { - public: - static CPUKernelFactory &GetInstance(); - void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator); - std::shared_ptr Create(const std::string &kernel_name, const CNodePtr &apply_kernel); - std::vector GetSupportedKernelAttrList(const std::string &kernel_name); - - private: - CPUKernelFactory() = default; - ~CPUKernelFactory() = default; - DISABLE_COPY_AND_ASSIGN(CPUKernelFactory) - std::pair CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info); - bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info); - std::map>> name_to_attr_creator_; -}; - -class CPUKernelRegistrar { - public: - CPUKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator) { - CPUKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator)); - } - ~CPUKernelRegistrar() = default; -}; - -#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS) -#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) -#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \ - static_assert(std::is_base_of::value, " must be base of CPUKernel"); \ - static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ - []() { return std::make_shared(); }); - -#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \ - static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ - static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR, \ - []() { return std::make_shared>(); }); - -#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ - static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ - static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg( \ - #OPNAME, ATTR, []() { return std::make_shared>(); }); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ diff --git a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc deleted file mode 100644 index a1dcaca3f3..0000000000 --- a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/debug_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" -#ifdef ENABLE_DEBUGGER -#include "debug/debugger/debugger.h" -#endif - -namespace mindspore { -namespace kernel { -void DebugCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); } - -bool DebugCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 1 || outputs.empty()) { - MS_LOG(EXCEPTION) << " input or output empty!"; - } - auto val = reinterpret_cast(inputs[0]->addr); - MS_LOG(DEBUG) << " launch DebugCountCPUKernel val " << *val; - - auto output = reinterpret_cast(outputs[0]->addr); - size_t elem_num = inputs[0]->size / sizeof(int); - for (size_t i = 0; i < elem_num; i++) { - output[i] = val[i]; - } - -#ifdef ENABLE_DEBUGGER - // debugger will suspend execution is neccessary - Debugger::GetInstance()->PostDebugOp(); -#endif - - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h deleted file mode 100644 index da9f3286b9..0000000000 --- a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class DebugCPUKernel : public CPUKernel { - public: - DebugCPUKernel() = default; - ~DebugCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(Debug, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), DebugCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc deleted file mode 100644 index 07da3dcc25..0000000000 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void EmbeddingLookUpCommGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - split_num_ = AnfAlgo::GetNodeAttr(kernel_node, "split_num"); - MS_LOG(INFO) << "split_num: " << split_num_; - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape[0] % split_num_ != 0) { - MS_LOG(EXCEPTION) << "Input shape[0] is " << input_shape[0] << ", but it must be multiple of split_num."; - } -} - -bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - size_t input_size = inputs[0]->size; - size_t output_size = outputs[0]->size; - MS_LOG(DEBUG) << "input addr: " << input_addr << "input size: " << input_size; - MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << output_size; - memset_s(output_addr, output_size, 0, output_size); - const std::vector &rank_group = {0, 1, 2, 3, 4, 5, 6, 7}; - size_t input_split_lens = input_size / split_num_ / sizeof(float_t); - size_t output_split_lens = output_size / split_num_ / sizeof(float_t); - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - for (int i = 0; i < split_num_; i++) { - mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group, - input_split_lens); - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "EmbeddingLookUpCommGradCPUKernel, used time: " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); - time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "EmbeddingLookUpCommGradCPUKernel, used time: " << time << " us"; -#endif - return true; -} - -void EmbeddingLookUpCommGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookUpCommGradCPUKernel needs 1."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h deleted file mode 100644 index 7222bd9be1..0000000000 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class EmbeddingLookUpCommGradCPUKernel : public CPUKernel { - public: - EmbeddingLookUpCommGradCPUKernel() : split_num_(1) {} - ~EmbeddingLookUpCommGradCPUKernel() override{}; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void CheckParam(const CNodePtr &kernel_node); - int split_num_; -}; - -MS_REG_CPU_KERNEL(EmbeddingLookupCommGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EmbeddingLookUpCommGradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc deleted file mode 100644 index c8c2c667ad..0000000000 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include "kernel/cpu/embedding_look_up_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_lens_ = 1; - for (auto shape : input_shape_) { - input_lens_ = input_lens_ * shape; - } - indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - indices_lens_ = 1; - for (auto shape : indices_shape_) { - indices_lens_ = indices_lens_ * shape; - } - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - axis_ = 4 - input_shape_.size(); - reduce_scatter_flag_ = AnfAlgo::GetNodeAttr(kernel_node, "reduce_scatter_flag"); -#ifdef ENABLE_MPI - if (reduce_scatter_flag_) { - size_t gatherv2_out_lens = 1; - for (int i = 0; i < SizeToInt(input_shape_.size()); i++) { - if (i == 0) { - for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) { - gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j]; - } - } else { - gatherv2_out_lens = gatherv2_out_lens * input_shape_[i]; - } - } - gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float); - gather_v2_out_ = malloc(gatherv2_out_lens_); - if (gather_v2_out_ == nullptr) { - MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_; - } - auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_); - if (ret != 0) { - MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed"; - } - split_num_ = AnfAlgo::GetNodeAttr(kernel_node, "split_num"); - } -#else - if (reduce_scatter_flag_) { - MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true"; - } -#endif - offset_ = AnfAlgo::GetNodeAttr(kernel_node, "offset"); - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -bool EmbeddingLookUpCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast(gather_v2_out_) : output_addr; - size_t dim0 = input_shape_[0]; - size_t dim1 = input_shape_[1]; - size_t dim2 = input_shape_[2]; - if (axis_ == 3) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - for (size_t k = 0; k < dim2; ++k) { - LookUpTable(inputs, i, j, k, &gather_out_addr); - } - } - } - } else if (axis_ == 2) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - LookUpTable(inputs, i, j, 0, &gather_out_addr); - } - } - } else if (axis_ == 1) { - for (size_t i = 0; i < dim0; ++i) { - LookUpTable(inputs, i, 0, 0, &gather_out_addr); - } - } else if (axis_ == 0) { - LookUpTable(inputs, 0, 0, 0, &gather_out_addr); - } -#ifdef ENABLE_MPI - if (reduce_scatter_flag_) { - size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float); - size_t reduce_scatter_out_lens = one_split_lens / 8; - const std::vector &group = {0, 1, 2, 3, 4, 5, 6, 7}; - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - for (int i = 0; i < split_num_; i++) { - mpi_instance->ReduceScatter(reinterpret_cast(gather_v2_out_) + i * one_split_lens, - output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum"); - } - } -#endif - return true; -} - -void LookUpTable_task(const float *input_addr, float *output_addr, const int *indices_addr, size_t indices_lens, - size_t num, size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, - std::vector input_shape, size_t input_lens) { - size_t lens = num * sizeof(float); - for (size_t i = 0; i < indices_lens; ++i) { - int indices = indices_addr[i] - offset; - if (indices >= 0) { - size_t index = IntToSize(indices); - if (index < input_shape[axis]) { - size_t pos = 0; - if (axis == 3) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, dim2, index); - } else if (axis == 2) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, index, 0); - } else if (axis == 1) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, index, 0, 0); - } else if (axis == 0) { - pos = CPUKernelUtils::CalcOffset(input_shape, index, 0, 0, 0); - } - if (pos + num <= input_lens) { - auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - output_addr += num; - } -} - -void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, - size_t dim2, float **output_addr) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto indices_addr = reinterpret_cast(inputs[1]->addr); - size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); - float *task_out_addr = *output_addr; - const size_t thread_num = 8; - std::thread threads[8]; - size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; - size_t i; - size_t task_offset = 0; - MS_LOG(DEBUG) << "indices_lens_: " << indices_lens_ << " one task proc lens:" << task_proc_lens; - for (i = 0; i < thread_num; i++) { - if (task_offset >= indices_lens_) { - break; - } - MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; - threads[i] = - std::thread(LookUpTable_task, input_addr, task_out_addr + task_offset * num, indices_addr + task_offset, - task_proc_lens, num, dim0, dim1, dim2, offset_, axis_, input_shape_, input_lens_); - task_offset += task_proc_lens; - if (task_offset + task_proc_lens > indices_lens_) { - task_proc_lens = indices_lens_ - task_offset; - } - } - for (size_t j = 0; j < i; j++) { - threads[j].join(); - } - *output_addr += num * indices_lens_; -} - -void EmbeddingLookUpCPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() - << ", but EmbeddingLookUpCPUKernel olny support 4d or lower."; - } - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookUpCPUKernel needs 2."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.h deleted file mode 100644 index d839571caa..0000000000 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class EmbeddingLookUpCPUKernel : public CPUKernel { - public: - EmbeddingLookUpCPUKernel() { - axis_ = 0; - offset_ = 0; - split_num_ = 0; - input_lens_ = 0; - indices_lens_ = 0; - gatherv2_out_lens_ = 0; - reduce_scatter_flag_ = false; - gather_v2_out_ = nullptr; - } - ~EmbeddingLookUpCPUKernel() override { - if (gather_v2_out_ != nullptr) { - free(gather_v2_out_); - gather_v2_out_ = nullptr; - } - } - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, - float **output_addr); - void CheckParam(const CNodePtr &kernel_node); - std::vector input_shape_; - std::vector indices_shape_; - std::vector output_shape_; - int axis_; - int offset_; - int split_num_; - size_t input_lens_; - size_t indices_lens_; - size_t gatherv2_out_lens_; - bool reduce_scatter_flag_; - - void *gather_v2_out_; -}; - -MS_REG_CPU_KERNEL( - EmbeddingLookup, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - EmbeddingLookUpCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.cc deleted file mode 100644 index 60e7eafa78..0000000000 --- a/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/equal_count_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void EqualCountCPUKernel::InitKernel(const CNodePtr & /*kernel_node*/) {} - -bool EqualCountCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output empty!"; - } - if (inputs[0]->size != inputs[1]->size) { - MS_LOG(EXCEPTION) << "input or output size!"; - } - int count = 0; - auto left = reinterpret_cast(inputs[0]->addr); - auto right = reinterpret_cast(inputs[1]->addr); - size_t elem_num = inputs[0]->size / sizeof(int); - for (size_t i = 0; i < elem_num; i++) { - if (left[i] == right[i]) { - count++; - } - } - auto output = reinterpret_cast(outputs[0]->addr); - output[0] = count; - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.h deleted file mode 100644 index 13083889d0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class EqualCountCPUKernel : public CPUKernel { - public: - EqualCountCPUKernel() = default; - ~EqualCountCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - EqualCount, - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - EqualCountCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc deleted file mode 100644 index 28090817cb..0000000000 --- a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc +++ /dev/null @@ -1,116 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/gather_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis_ < 0) { - axis_ = axis_ + SizeToInt(input_shape_.size()); - } - axis_ += 4 - input_shape_.size(); - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -bool GatherV2CPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto buff_size = outputs[0]->size; - size_t dim0 = input_shape_[0]; - size_t dim1 = input_shape_[1]; - size_t dim2 = input_shape_[2]; - if (axis_ == 3) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - for (size_t k = 0; k < dim2; ++k) { - CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size); - } - } - } - } else if (axis_ == 2) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size); - } - } - } else if (axis_ == 1) { - for (size_t i = 0; i < dim0; ++i) { - CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size); - } - } else if (axis_ == 0) { - CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); - } - return true; -} - -void GatherV2CPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, - size_t dim2, float **output_addr, size_t *buff_size) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto indices_addr = reinterpret_cast(inputs[1]->addr); - size_t elem_num = inputs[1]->size / 4; - size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); - for (size_t i = 0; i < elem_num; ++i) { - if (indices_addr[i] < 0) { - MS_LOG(EXCEPTION) << "The indices value is less than 0."; - } - size_t index = IntToSize(indices_addr[i]); - if (index >= input_shape_[IntToSize(axis_)]) { - auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memset failed."; - } - } else { - size_t pos = 0; - if (axis_ == 3) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); - } else if (axis_ == 2) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); - } else if (axis_ == 1) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); - } else if (axis_ == 0) { - pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); - } - auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed."; - } - } - *output_addr += num; - *buff_size -= num * sizeof(float); - } -} // namespace kernel - -void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel olny support 4d or lower."; - } - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h deleted file mode 100644 index 2ffd7df4d4..0000000000 --- a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class GatherV2CPUKernel : public CPUKernel { - public: - GatherV2CPUKernel() : axis_(0) {} - ~GatherV2CPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, - float **output_addr, size_t *buff_size); - void CheckParam(const CNodePtr &kernel_node); - std::vector input_shape_; - std::vector indices_shape_; - std::vector output_shape_; - int axis_; -}; - -MS_REG_CPU_KERNEL( - GatherV2, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - GatherV2CPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.cc deleted file mode 100644 index 657c85dc48..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/conv2d_cpu_kernel.h" -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 || weight_shape.size() != 4) { - MS_LOG(EXCEPTION) << "conv2d only support nchw input!"; - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); - dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - - int kernel_size = SizeToInt(weight_shape[3]); - auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); - auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); - if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) { - MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!"; - } - if (stride_ori[0] != 1 || stride_ori[1] != 1) { - MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!"; - } - if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { - MS_LOG(EXCEPTION) << "conv2d dilation only support 1, and dilation must be 4d!"; - } - if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { - MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!"; - } - int stride = stride_ori[2]; - int dilation = dilation_ori[2]; - - dnnl::memory::dims strides{stride, stride}; - dnnl::memory::dims dilates{dilation - 1, dilation - 1}; - std::vector int_padding_l; - std::vector int_padding_r; - - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); - GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); - if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { - MS_LOG(EXCEPTION) << "get padding failed"; - } - dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; - dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; - dnnl::convolution_forward::desc desc = - dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, - weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_WEIGHTS, weights_desc); - AddArgument(DNNL_ARG_DST, dst_desc); -} - -bool Conv2dCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.h deleted file mode 100644 index 1cb100299e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class Conv2dCPUKernel : public MKLCPUKernel { - public: - Conv2dCPUKernel() = default; - ~Conv2dCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - Conv2D, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - Conv2dCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc deleted file mode 100644 index fbfebaf56e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h" -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector weight_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - std::vector dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 || weight_shape.size() != 4) { - MS_LOG(EXCEPTION) << ("conv2d grad filter only support nchw input!"); - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); - dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - - int kernel_size = SizeToInt(weight_shape[3]); - auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); - auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); - if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { - MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel only support equal stride, and stride must be 2d!"; - } - if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { - MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1, and dilation must be 4d!"; - } - if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { - MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1 in N axis and C axis!"; - } - int stride = stride_ori[0]; - int dilation = dilation_ori[2]; - - dnnl::memory::dims strides{stride, stride}; - dnnl::memory::dims dilates{dilation - 1, dilation - 1}; - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); - std::vector int_padding_l; - std::vector int_padding_r; - GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); - if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { - MS_LOG(EXCEPTION) << "get padding failed"; - } - dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; - dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; - dnnl::convolution_forward::desc forward_desc = - dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, - weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); - - dnnl::convolution_backward_weights::desc backward_desc = dnnl::convolution_backward_weights::desc( - dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto backward_prim_desc = dnnl::convolution_backward_weights::primitive_desc( - backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); - primitive_ = std::make_shared(backward_prim_desc); - - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DIFF_DST, dst_desc); - AddArgument(DNNL_ARG_DIFF_WEIGHTS, weights_desc); -} - -bool Conv2dGradFilterCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h deleted file mode 100644 index 49559f452b..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class Conv2dGradFilterCPUKernel : public MKLCPUKernel { - public: - Conv2dGradFilterCPUKernel() = default; - ~Conv2dGradFilterCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - Conv2DBackpropFilter, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - Conv2dGradFilterCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc deleted file mode 100644 index ff0b8633d4..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h" -#include -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 || weight_shape.size() != 4) { - MS_LOG(EXCEPTION) << "conv2d grad filter only support nchw input!"; - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); - dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - - int kernel_size = SizeToInt(weight_shape[3]); - auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); - auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); - if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { - MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel only support equal stride, and stride must be 2d!"; - } - if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { - MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1, and dilation must be 4d!"; - } - if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { - MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1 in N axis and C axis!"; - } - int stride = stride_ori[0]; - int dilation = dilation_ori[2]; - dnnl::memory::dims strides{stride, stride}; - dnnl::memory::dims dilates{dilation - 1, dilation - 1}; - std::vector int_padding_l; - std::vector int_padding_r; - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); - GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); - if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { - MS_LOG(EXCEPTION) << "conv2d grad get padding failed"; - } - dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; - dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; - dnnl::convolution_forward::desc forward_desc = - dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, - weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); - - dnnl::convolution_backward_data::desc backward_desc = dnnl::convolution_backward_data::desc( - dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto backward_prim_desc = - dnnl::convolution_backward_data::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); - primitive_ = std::make_shared(backward_prim_desc); - - AddArgument(DNNL_ARG_DIFF_SRC, src_desc); - AddArgument(DNNL_ARG_DIFF_DST, dst_desc); - AddArgument(DNNL_ARG_WEIGHTS, weights_desc); -} - -bool Conv2dGradInputCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h deleted file mode 100644 index 9fb024a279..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class Conv2dGradInputCPUKernel : public MKLCPUKernel { - public: - Conv2dGradInputCPUKernel() = default; - ~Conv2dGradInputCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - Conv2DBackpropInput, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - Conv2dGradInputCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc deleted file mode 100644 index 0a343785f7..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/lstm_cpu_kernel.h" -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { -#ifdef PLATFORM_86 - _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); -#endif - MS_EXCEPTION_IF_NULL(kernel_node); - using tag = dnnl::memory::format_tag; - using dim = dnnl::memory::dims; - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); - bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); - input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); - hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); - num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); - has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); - batch_size_ = SizeToInt(src_shape[1]); - seq_len_ = SizeToInt(src_shape[0]); - num_directions_ = 1; - if (bidirectional_) { - num_directions_ = 2; - } - if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { - MS_LOG(EXCEPTION) << "error iteration shape!"; - } - if (num_layers_ <= 0) { - MS_LOG(EXCEPTION) << "layers must be greater than zero!"; - } - if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { - MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; - } - const int gate_size = 4 * hidden_size_; - for (int i = 0; i < num_layers_; ++i) { - weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); - weight_h_size_ += gate_size * hidden_size_; - } - weight_size_ = weight_size_ * num_directions_; - weight_h_size_ = weight_h_size_ * num_directions_; - auto eng = MKLKernelEngine::Get().engine(); - dnnl::stream s(eng); - dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; - if (bidirectional_) { - direction = dnnl::rnn_direction::bidirectional_concat; - } - dim src_dims = {seq_len_, batch_size_, input_size_}; - dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; - weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; - bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; - dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; - dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); - dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); - dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); - dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); - dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); - dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); - dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - auto desc = std::make_shared(dnnl::prop_kind::forward_training, direction, src_desc, - src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), - formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, - dst_h_desc, dst_c_desc); - prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng); - primitive_ = std::make_shared(prim_desc_); - AddArgument(DNNL_ARG_SRC_LAYER, src_desc); - AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); - AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); - AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc()); - AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc()); - AddArgument(DNNL_ARG_BIAS, bias_desc); - AddArgument(DNNL_ARG_DST_LAYER, dst_desc); - AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); - AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); - AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); -} - -bool LstmCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - using dt = dnnl::memory::data_type; - using tag = dnnl::memory::format_tag; - auto eng = MKLKernelEngine::Get().engine(); - auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); - auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); - auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); - auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng); - user_weights_memory.set_data_handle(inputs[3]->addr); - user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); - Reorder(&user_weights_memory, &weights_memory); - Reorder(&user_weights_h_memory, &weights_h_memory); - auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng); - if (has_bias_) { - bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); - } else { - auto ret = - memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size()); - if (ret != 0) { - MS_LOG(EXCEPTION) << "bias memset error"; - } - } - // set handle - SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr); - SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h deleted file mode 100644 index d42ff803f0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ -#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) -#define PLATFORM_86 -#endif -#ifdef PLATFORM_86 -#include -#endif -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" -namespace mindspore { -namespace kernel { -class LstmCPUKernel : public MKLCPUKernel { - public: - LstmCPUKernel() = default; - ~LstmCPUKernel() override = default; - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - int weight_size_ = 0; - int weight_h_size_ = 0; - int input_size_; - int hidden_size_; - int num_layers_; - int batch_size_; - int seq_len_; - int num_directions_; - bool bidirectional_; - bool has_bias_; - dnnl::memory::dims weights_dims_; - dnnl::memory::dims weights_h_dims_; - dnnl::memory::dims bias_dims_; - dnnl::lstm_forward::primitive_desc prim_desc_; -}; - -MS_REG_CPU_KERNEL(LSTM, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LstmCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc deleted file mode 100644 index d7e7701d85..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h" -#include -#include -#include -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - using tag = dnnl::memory::format_tag; - using dim = dnnl::memory::dims; - auto eng = MKLKernelEngine::Get().engine(); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); - bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); - input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); - hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); - num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); - has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); - batch_size_ = SizeToInt(src_shape[1]); - seq_len_ = SizeToInt(src_shape[0]); - num_directions_ = 1; - if (bidirectional_) { - num_directions_ = 2; - } - if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { - MS_LOG(EXCEPTION) << "error iteration shape!"; - } - if (num_layers_ <= 0) { - MS_LOG(EXCEPTION) << "layers must be greater than zero!"; - } - if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { - MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; - } - const int gate_size = 4 * hidden_size_; - for (int i = 0; i < num_layers_; ++i) { - weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); - weight_h_size_ += gate_size * hidden_size_; - } - weight_size_ = weight_size_ * num_directions_; - weight_h_size_ = weight_h_size_ * num_directions_; - dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; - if (bidirectional_) { - direction = dnnl::rnn_direction::bidirectional_concat; - } - dim src_dims = {seq_len_, batch_size_, input_size_}; - dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; - weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; - bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; - dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; - dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); - dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); - dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); - dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); - dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); - dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); - dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - auto forward_desc = std::make_shared( - dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, - formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, - dst_c_desc); - auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng); - auto backward_desc = std::make_shared( - dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), - formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, - src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, - dst_h_desc, dst_c_desc); - prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc); - primitive_ = std::make_shared(prim_backward_desc_); - - AddArgument(DNNL_ARG_SRC_LAYER, src_desc); - AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); - AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); - AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc()); - AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc()); - AddArgument(DNNL_ARG_BIAS, bias_desc); - AddArgument(DNNL_ARG_DST_LAYER, dst_desc); - AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); - AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); - AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); - AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc); - AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc); - AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); - AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc()); - AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc()); - AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); - AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); - AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); - AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc); -} - -bool LSTMGradCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace /*workspace*/, - const std::vector &outputs) { - using dt = dnnl::memory::data_type; - using tag = dnnl::memory::format_tag; - auto eng = MKLKernelEngine::Get().engine(); - // construct fw memory - auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); - auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); - auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng); - auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng); - auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng); - user_weights_memory.set_data_handle(inputs[3]->addr); - user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); - Reorder(&user_weights_memory, &weights_memory); - Reorder(&user_weights_h_memory, &weights_h_memory); - if (has_bias_) { - bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); - } else { - if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0, - prim_backward_desc_.bias_desc().get_size())) { - MS_LOG(EXCEPTION) << "bias memset error"; - } - } - // construct bw memory - auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); - auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng); - auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng); - auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); - auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); - user_diff_weights_memory.set_data_handle(outputs[3]->addr); - user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); - if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0, - user_diff_weights_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "user weights grad memset error"; - } - if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0, - user_diff_weights_h_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "user weights iter grad memset error"; - } - if (has_bias_) { - diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); - } - if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0, - prim_backward_desc_.diff_bias_desc().get_size())) { - MS_LOG(EXCEPTION) << "bias grad memset error"; - } - if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0, - diff_weights_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "weights grad memset error"; - } - if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0, - diff_weights_h_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "weights iter grad memset error"; - } - SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); - SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); - ExecutePrimitive(); - Reorder(&diff_weights_memory, &user_diff_weights_memory); - Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h deleted file mode 100644 index 1f3fb824c0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class LSTMGradCPUKernel : public MKLCPUKernel { - public: - LSTMGradCPUKernel() = default; - ~LSTMGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - int weight_size_ = 0; - int weight_h_size_ = 0; - int input_size_; - int hidden_size_; - int num_layers_; - int batch_size_; - int seq_len_; - int num_directions_; - bool bidirectional_; - bool has_bias_; - dnnl::memory::dims weights_dims_; - dnnl::memory::dims weights_h_dims_; - dnnl::memory::dims bias_dims_; - dnnl::lstm_backward::primitive_desc prim_backward_desc_; -}; - -MS_REG_CPU_KERNEL(LSTMGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LSTMGradCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.cc deleted file mode 100644 index 28266f2aa0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/matmul_cpu_kernel.h" -#include -#include -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "common/utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - - if (src_shape.size() != 2 || weight_shape.size() != 2 || dst_shape.size() != 2) { - MS_LOG(EXCEPTION) << "matmul invalid input size"; - } - bool trans_a = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_A); - bool trans_b = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_B); - if (trans_a) { - trans_a_ = TRANSPOSE_YES; - dim_m_ = static_cast(src_shape[1]); - dim_k_ = static_cast(src_shape[0]); - } else { - dim_m_ = static_cast(src_shape[0]); - dim_k_ = static_cast(src_shape[1]); - } - if (trans_b) { - trans_b_ = TRANSPOSE_YES; - } - dim_n_ = static_cast(dst_shape[1]); -} - -bool MatMulCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "matmul error input output size!"; - } - dnnl_dim_t lda = dim_m_; - if (trans_a_ == TRANSPOSE_NO) { - lda = dim_k_; - } - dnnl_dim_t ldb = dim_k_; - if (trans_b_ == TRANSPOSE_NO) { - ldb = dim_n_; - } - auto input_a = reinterpret_cast(inputs[0]->addr); - auto input_b = reinterpret_cast(inputs[1]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - (void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, 1.f, input_a, lda, input_b, ldb, 0.f, output, dim_n_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.h deleted file mode 100644 index 10276d01fa..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class MatMulCPUKernel : public MKLCPUKernel { - public: - MatMulCPUKernel() = default; - ~MatMulCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - char trans_a_{TRANSPOSE_NO}; - char trans_b_{TRANSPOSE_NO}; - dnnl_dim_t dim_m_{0}; - dnnl_dim_t dim_n_{0}; - dnnl_dim_t dim_k_{0}; -}; - -MS_REG_CPU_KERNEL( - MatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - MatMulCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc deleted file mode 100644 index a38470e3a3..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" -#include -#include -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" - -namespace mindspore { -namespace kernel { -void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, - const std::vector &src_shape, int kernel_size, int stride, - std::vector *padding_l, std::vector *padding_r) { - MS_EXCEPTION_IF_NULL(kernel_node); - if (src_shape.size() < 2) { - MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!"; - } - std::vector weight_height; - weight_height.emplace_back(src_shape[src_shape.size() - 2]); - weight_height.emplace_back(src_shape[src_shape.size() - 1]); - int rad = kernel_size / 2; - int need_pad = kernel_size - 1; - MS_LOG(INFO) << "pad mode " << pad_mode; - if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) { - for (auto wh : weight_height) { - int re = (wh - 1) % stride; - int pad = std::max(rad - (re / 2), 0); - padding_r->emplace_back(pad); - pad = std::max(need_pad - pad - re, 0); - padding_l->emplace_back(pad); - } - } else if (pad_mode == PAD_MODE_LOWER_VALID || pad_mode == PAD_MODE_UPPER_VALID) { - MS_LOG(INFO) << "pad valid"; - padding_l->emplace_back(0); - padding_l->emplace_back(0); - padding_r->emplace_back(0); - padding_r->emplace_back(0); - } else { - std::vector pad = AnfAlgo::GetNodeAttr>(kernel_node, PAD); - if (pad.size() != 4) { - MS_LOG(EXCEPTION) << "wrong pad size in max pooling " << pad.size(); - } - padding_l->emplace_back(pad[0]); - padding_l->emplace_back(pad[1]); - padding_r->emplace_back(pad[2]); - padding_r->emplace_back(pad[3]); - } -} - -dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const { - dnnl::memory::format_tag mem_tag; - auto dim_size = dims.size(); - if (dim_size == 4) { - mem_tag = dnnl::memory::format_tag::abcd; - } else if (dim_size == 3) { - mem_tag = dnnl::memory::format_tag::abc; - } else if (dim_size == 2) { - mem_tag = dnnl::memory::format_tag::ab; - } else if (dim_size == 1) { - mem_tag = dnnl::memory::format_tag::a; - } else { - MS_LOG(EXCEPTION) << "kernel dims invalid " << dim_size; - } - return mem_tag; -} - -dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector &shape) { - dnnl::memory::dims dims; - dims.insert(dims.end(), shape.begin(), shape.end()); - dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims); - dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag); - return mem_desc; -} - -void MKLCPUKernel::AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc) { - arguments_[arg_key] = MKLKernelEngine::Get().CreateMemory(mem_desc, alloc); -} - -void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { - auto arg_iter = arguments_.find(arg_key); - if (arg_iter != arguments_.end()) { - arg_iter->second.set_data_handle(ptr); - } -} - -void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } - -void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { - MKLKernelEngine::Get().Reorder(src_mem, dst_mem); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h deleted file mode 100644 index 10a860afff..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include "dnnl.hpp" -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class MKLCPUKernel : public CPUKernel { - public: - MKLCPUKernel() = default; - ~MKLCPUKernel() override = default; - - protected: - void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector &src_shape, - int kernel_size, int stride, std::vector *padding_l, std::vector *padding_r); - void AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc = false); - void SetArgumentHandle(int arg_key, void *ptr); - dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; - dnnl::memory::desc GetDefaultMemDesc(const std::vector &shape); - void ExecutePrimitive(); - std::unordered_map arguments_; - std::shared_ptr primitive_{nullptr}; - inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) { - return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout}; - } - void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc deleted file mode 100644 index 5ae9791b12..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "utils/log_adapter.h" -#include "dnnl.hpp" - -namespace mindspore { -namespace kernel { -void MKLKernelEngine::Execute(const std::shared_ptr &primitive, - const std::unordered_map &arguments) { - MS_EXCEPTION_IF_NULL(primitive); - primitive->execute(stream_, arguments); - (void)stream_.wait(); -} - -dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, bool alloc) { - if (alloc) { - return dnnl::memory(mem_desc, engine_); - } else { - return dnnl::memory(mem_desc, engine_, nullptr); - } -} -void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { - dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.cc deleted file mode 100644 index 4f77508004..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/mul_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) { - MS_LOG(EXCEPTION) << "mul only support same dim input or tensor * scalar " << src0_shape.size() << " vs " - << src1_shape.size(); - } - if (src1_shape.size() < src0_shape.size()) { - for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) { - src1_shape.emplace_back(1); - } - } - dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape); - dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape); - dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape); - dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_mul, src0_mem_desc, src1_mem_desc, dst_mem_desc); - auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - AddArgument(DNNL_ARG_SRC_0, src0_mem_desc); - AddArgument(DNNL_ARG_SRC_1, src1_mem_desc); - AddArgument(DNNL_ARG_DST, dst_mem_desc); -} - -bool MulCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "mul error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.h deleted file mode 100644 index 1131fd594c..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class MulCPUKernel : public MKLCPUKernel { - public: - MulCPUKernel() = default; - ~MulCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - MulCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.cc deleted file mode 100644 index 5225050dc1..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/pooling_cpu_kernel.h" -#include -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - std::vector kernel_sizes = AnfAlgo::GetNodeAttr>(kernel_node, KSIZE); - std::vector strides = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - if (kernel_sizes.size() != 4 || strides.size() != 4) { - MS_LOG(EXCEPTION) << "invalid kernel size " << kernel_sizes.size() << " or stride size " << strides.size(); - } - dnnl::memory::dims strides_dims{strides[2], strides[3]}; - dnnl::memory::dims kernels_dims{kernel_sizes[2], kernel_sizes[3]}; - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PADDING); - std::vector int_padding_l; - std::vector int_padding_r; - GetPadding(kernel_node, pad_mode, src_shape, kernel_sizes[3], strides[3], &int_padding_l, &int_padding_r); - if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { - MS_LOG(EXCEPTION) << "pooling get padding failed"; - } - dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; - dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; - dnnl::pooling_forward::desc desc = - dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc, dst_desc, - strides_dims, kernels_dims, padding_l, padding_r); - auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DST, dst_desc); - AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc()); -} - -bool PoolingCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.h deleted file mode 100644 index 4993d0834d..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class PoolingCPUKernel : public MKLCPUKernel { - public: - PoolingCPUKernel() = default; - ~PoolingCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - PoolingCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.cc deleted file mode 100644 index c0459de790..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/pooling_grad_cpu_kernel.h" -#include -#include -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void PoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - src_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - dst_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector kernel_sizes = AnfAlgo::GetNodeAttr>(kernel_node, KSIZE); - std::vector strides = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - if (kernel_sizes.size() != 4 || strides.size() != 4 || src_shape_.size() != 4 || dst_shape_.size() != 4) { - MS_LOG(EXCEPTION) << "pooling grad invalid input size"; - } - std::vector padding_r; - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PADDING); - kernel_size_ = kernel_sizes[3]; - stride_ = strides[3]; - GetPadding(kernel_node, pad_mode, src_shape_, kernel_size_, stride_, &padding_l_, &padding_r); -} - -void PoolingGradCPUKernel::RowPoolingGrad(const float *input, float *output, float diff, - const std::vector> &box, - std::vector> *row_max_pair) { - float max_value = 0; - size_t max_index = box[1].second; - size_t src_width = src_shape_[3]; - size_t index_start; - size_t index; - for (size_t i = box[1].first; i < box[1].second; ++i) { - if ((*row_max_pair)[i].first == 0) { - index_start = box[0].first * src_width; - for (size_t j = box[0].first; j < box[0].second; ++j) { - index = index_start + i; - if (input[index] > (*row_max_pair)[i].second || j == box[0].first) { - (*row_max_pair)[i].second = input[index]; - (*row_max_pair)[i].first = index; - } - index_start += src_width; - } - } - if ((*row_max_pair)[i].second > max_value || max_index == box[1].second) { - max_value = (*row_max_pair)[i].second; - max_index = i; - } - } - - output[(*row_max_pair)[max_index].first] += diff; -} - -void PoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *diff, float *output) { - int src_width = SizeToInt(src_shape_[3]); - int src_height = SizeToInt(src_shape_[2]); - std::vector> row_max_pair(src_shape_[3]); - std::vector> box(2); - int h_start = -padding_l_[0]; - size_t diff_index = 0; - for (size_t h = 0; h < dst_shape_[2]; ++h) { - box[0].first = IntToSize(std::max(h_start, 0)); - box[0].second = IntToSize(std::min(h_start + kernel_size_, src_height)); - for (size_t w = 0; w < src_shape_[3]; ++w) { - row_max_pair[w].first = 0; - row_max_pair[w].second = 0; - } - int w_start = -padding_l_[1]; - for (size_t w = 0; w < dst_shape_[3]; ++w) { - box[1].first = IntToSize(std::max(w_start, 0)); - box[1].second = IntToSize(std::min(w_start + kernel_size_, src_width)); - RowPoolingGrad(input, output, diff[diff_index], box, &row_max_pair); - diff_index += 1; - w_start += stride_; - } - h_start += stride_; - } -} - -bool PoolingGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 3 || outputs.empty()) { - MS_LOG(EXCEPTION) << "pooling grad error input output size!"; - } - - auto input = reinterpret_cast(inputs[0]->addr); - auto diff = reinterpret_cast(inputs[2]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - auto ret = memset_s(output, outputs[0]->size, 0, outputs[0]->size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "pooling grad memset error"; - } - size_t src_wh = src_shape_[2] * src_shape_[3]; - size_t dst_wh = dst_shape_[2] * dst_shape_[3]; - for (size_t n = 0; n < src_shape_[0]; ++n) { - for (size_t c = 0; c < src_shape_[1]; ++c) { - ChannelPoolingGrad(input, diff, output); - input = input + src_wh; - output = output + src_wh; - diff = diff + dst_wh; - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.h deleted file mode 100644 index cdb2c69ef0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class PoolingGradCPUKernel : public MKLCPUKernel { - public: - PoolingGradCPUKernel() = default; - ~PoolingGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void RowPoolingGrad(const float *input, float *output, float diff, const std::vector> &box, - std::vector> *row_max_pair); - void ChannelPoolingGrad(const float *input, const float *diff, float *output); - int stride_{0}, kernel_size_{0}; - std::vector padding_l_; - std::vector src_shape_; - std::vector dst_shape_; -}; - -MS_REG_CPU_KERNEL(MaxPoolGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - PoolingGradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.cc deleted file mode 100644 index d5ef20a25e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/relu_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void ReluCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 && src_shape.size() != 2) { - MS_LOG(EXCEPTION) << "relu kernel dims invalid " << src_shape.size(); - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - - dnnl::eltwise_forward::desc desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0); - auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DST, src_desc); -} - -bool ReluCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.h deleted file mode 100644 index 26905e267d..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class ReluCPUKernel : public MKLCPUKernel { - public: - ReluCPUKernel() = default; - ~ReluCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReluCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.cc deleted file mode 100644 index 4a6213ddf2..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/relu_grad_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void ReluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 && src_shape.size() != 2) { - MS_LOG(EXCEPTION) << "relu grad kernel dims invalid " << src_shape.size(); - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - - dnnl::eltwise_forward::desc forward_desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0); - auto forward_prim_desc = dnnl::eltwise_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); - - dnnl::eltwise_backward::desc backward_desc = - dnnl::eltwise_backward::desc(dnnl::algorithm::eltwise_relu, src_desc, src_desc, 0.0, 0.0); - auto backward_prim_desc = - dnnl::eltwise_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); - primitive_ = std::make_shared(backward_prim_desc); - - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DIFF_SRC, src_desc); - AddArgument(DNNL_ARG_DIFF_DST, src_desc); -} - -bool ReluGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "relu grad error input output size!"; - } - if (inputs[0]->size != outputs[0]->size) { - MS_LOG(EXCEPTION) << "relu grad error input output data size!"; - } - - SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); - ExecutePrimitive(); - size_t mem_bits = outputs[0]->size; - auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.h deleted file mode 100644 index f0a77ee282..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class ReluGradCPUKernel : public MKLCPUKernel { - public: - ReluGradCPUKernel() = default; - ~ReluGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - ReluGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReluGradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.cc deleted file mode 100644 index 7fa740cfc0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/softmax_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector axis_list = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); - if (axis_list.size() != 1) { - MS_LOG(EXCEPTION) << "cpu softmax only support input axis size 1"; - } - int axis = axis_list[0]; - if (axis == -1 || axis >= SizeToInt(src_shape.size())) { - axis = SizeToInt(src_shape.size()) - 1; - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DST, src_desc); -} - -bool SoftmaxCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "softmax error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.h deleted file mode 100644 index 6acb9e5b9b..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class SoftmaxCPUKernel : public MKLCPUKernel { - public: - SoftmaxCPUKernel() = default; - ~SoftmaxCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SoftmaxCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc deleted file mode 100644 index 05b1a79924..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h" -#include -#include -#include -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void SoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - size_t type_size = sizeof(float); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - workspace_size_list_.emplace_back(tensor_size); -} - -void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - dnnl::memory::dims mem_dims; - mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); - if (mem_dims.size() != 2) { - MS_LOG(EXCEPTION) << "SoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); - } - batch_size_ = shape[0]; - class_num_ = shape[1]; - if (batch_size_ == 0 || class_num_ == 0) { - MS_LOG(EXCEPTION) << "invalid batch size or class num input!"; - } - dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); - - dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - - AddArgument(DNNL_ARG_SRC, mem_desc); - AddArgument(DNNL_ARG_DST, mem_desc); -} - -void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *logits, const float *labels, - float *output1, float *output2) const { - float epsilon = 1e-6; - for (size_t i = 0; i < batch_size_; ++i) { - output1[i] = 0; - float loss = 0.0; - for (size_t j = 0; j < class_num_; ++j) { - float logit = logf(logits[i * class_num_ + j] <= 0.0 ? epsilon : logits[i * class_num_ + j]); - output2[i * class_num_ + j] = logits[i * class_num_ + j] - labels[i * class_num_ + j]; - loss += labels[i * class_num_ + j] * logit; - } - output1[i] = -loss; - } -} - -bool SoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - if (inputs.empty() || workspace.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - size_t batch_float_size = batch_size_ * sizeof(float); - size_t batch_class_float_size = class_num_ * batch_float_size; - if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || - inputs[1]->size != batch_class_float_size) { - MS_LOG(EXCEPTION) << "error input data size!"; - } - if (outputs[1]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { - MS_LOG(EXCEPTION) << "error output data size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); - ExecutePrimitive(); - auto labels = reinterpret_cast(inputs[1]->addr); - auto logits = reinterpret_cast(workspace[0]->addr); - auto output1 = reinterpret_cast(outputs[0]->addr); - auto output2 = reinterpret_cast(outputs[1]->addr); - ForwardPostExecute(logits, labels, output1, output2); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h deleted file mode 100644 index f663508059..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class SoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { - public: - SoftmaxCrossEntropyWithLogitsCPUKernel() = default; - ~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - void InitInputOutputSize(const CNodePtr &kernel_node) override; - - private: - void ForwardPostExecute(const float *logits, const float *labels, float *output1, float *output2) const; - size_t class_num_{0}; - size_t batch_size_{0}; -}; -MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SoftmaxCrossEntropyWithLogitsCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc deleted file mode 100644 index c33fcd246f..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h" -#include -#include -#include -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - size_t type_size = sizeof(float); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - workspace_size_list_.emplace_back(tensor_size); -} - -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - dnnl::memory::dims mem_dims; - mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); - if (mem_dims.size() != 2) { - MS_LOG(EXCEPTION) << "SparseSoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); - } - batch_size_ = shape[0]; - class_num_ = shape[1]; - if (batch_size_ == 0 || class_num_ == 0) { - MS_LOG(EXCEPTION) << "invalid batch size or class num input!"; - } - is_grad_ = AnfAlgo::GetNodeAttr(kernel_node, IS_GRAD); - dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); - - dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - - AddArgument(DNNL_ARG_SRC, mem_desc); - AddArgument(DNNL_ARG_DST, mem_desc); -} - -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, - float *output) const { - float total_loss = 0; - for (size_t i = 0; i < batch_size_; ++i) { - if (labels[i] < 0) { - MS_LOG(EXCEPTION) << "label value must >= 0"; - } - size_t label = IntToSize(labels[i]); - if (label > class_num_) { - MS_LOG(EXCEPTION) << "error label input!"; - } - total_loss -= logf(losses[i * class_num_ + label]); - } - output[0] = total_loss / batch_size_; -} - -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, - float *output) const { - size_t row_start = 0; - for (size_t i = 0; i < batch_size_; ++i) { - if (labels[i] < 0) { - MS_LOG(EXCEPTION) << "label value must >= 0"; - } - size_t label = IntToSize(labels[i]); - if (label > class_num_) { - MS_LOG(EXCEPTION) << "error label input!"; - } - for (size_t j = 0; j < class_num_; ++j) { - size_t index = row_start + j; - if (j == label) { - output[index] = (losses[index] - 1) / batch_size_; - } else { - output[index] = losses[index] / batch_size_; - } - } - row_start += class_num_; - } -} - -bool SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - if (inputs.empty() || workspace.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - size_t batch_float_size = batch_size_ * sizeof(float); - size_t batch_class_float_size = class_num_ * batch_float_size; - if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || - inputs[1]->size != batch_float_size) { - MS_LOG(EXCEPTION) << "error input data size!"; - } - if (is_grad_ && outputs[0]->size != batch_class_float_size) { - MS_LOG(EXCEPTION) << "error output data size!"; - } else if (!is_grad_ && outputs[0]->size != sizeof(float)) { - MS_LOG(EXCEPTION) << "error output data size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); - ExecutePrimitive(); - auto labels = reinterpret_cast(inputs[1]->addr); - auto losses = reinterpret_cast(workspace[0]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - if (is_grad_) { - GradPostExecute(labels, losses, output); - } else { - ForwardPostExecute(labels, losses, output); - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h deleted file mode 100644 index 6391b27de6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { - public: - SparseSoftmaxCrossEntropyWithLogitsCPUKernel() = default; - ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - void InitInputOutputSize(const CNodePtr &kernel_node) override; - - private: - void ForwardPostExecute(const int *labels, const float *losses, float *output) const; - void GradPostExecute(const int *labels, const float *losses, float *output) const; - bool is_grad_{false}; - size_t class_num_{0}; - size_t batch_size_{0}; -}; - -MS_REG_CPU_KERNEL( - SparseSoftmaxCrossEntropyWithLogits, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - SparseSoftmaxCrossEntropyWithLogitsCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.cc deleted file mode 100644 index 00dfe73f28..0000000000 --- a/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/one_hot_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void OneHotCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - if (output_shape.size() < 2) { - MS_LOG(EXCEPTION) << "invalid output shape size: " << output_shape.size(); - } - int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis != -1 && IntToSize(axis) >= output_shape.size()) { - MS_LOG(EXCEPTION) << "invalid axis: " << axis; - } - if (axis == -1) { - axis_ = output_shape.size() - 1; - } else { - axis_ = IntToSize(axis); - } - depth_ = output_shape[axis_]; - stride_ = 1; - for (size_t i = axis_ + 1; i < output_shape.size(); ++i) { - stride_ *= output_shape[i]; - } -} - -bool OneHotCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 3 || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output invalid!"; - } - auto indices = reinterpret_cast(inputs[0]->addr); - auto on_value = reinterpret_cast(inputs[1]->addr)[0]; - auto off_value = reinterpret_cast(inputs[2]->addr)[0]; - auto output = reinterpret_cast(outputs[0]->addr); - size_t elem_num = inputs[0]->size / sizeof(int); - - for (size_t i = 0; i < elem_num; i++) { - size_t stride_num = i / stride_; - size_t output_index = stride_num * depth_ * stride_ + i % stride_; - size_t index = IntToSize(indices[i]); - for (size_t j = 0; j < depth_; j++) { - if (index == j) { - output[output_index] = on_value; - } else { - output[output_index] = off_value; - } - output_index += stride_; - } - } - - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.h deleted file mode 100644 index ef13047343..0000000000 --- a/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class OneHotCPUKernel : public CPUKernel { - public: - OneHotCPUKernel() = default; - ~OneHotCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t depth_; - size_t stride_; - size_t axis_; -}; - -MS_REG_CPU_KERNEL(OneHot, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - OneHotCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc deleted file mode 100644 index e56f2af8c7..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include "kernel/cpu/reduce_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -const size_t kReduceTypeMax = 0; -const size_t kReduceTypeMean = 1; -const size_t kReduceTypeSum = 2; -const size_t kMaxDim = 100; -void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "ReduceMax") { - reduce_type_ = kReduceTypeMax; - } else if (kernel_name == "ReduceMean") { - reduce_type_ = kReduceTypeMean; - } else if (kernel_name == "ReduceSum") { - reduce_type_ = kReduceTypeSum; - } else { - MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; - } - shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); - if (axis_addr->isa()) { - auto attr_axis = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); - if (attr_axis.size() > shape_.size()) { - MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size(); - } else if (attr_axis.empty()) { - axis_.push_back(shape_.size() - 1); - } else { - for (auto axis : attr_axis) { - if (IntToSize(axis) >= (shape_.size())) { - MS_LOG(EXCEPTION) << "axis value is oversize."; - } - axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); - } - } - } else if (axis_addr->isa()) { - int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis >= 0 && IntToSize(axis) >= shape_.size()) { - MS_LOG(EXCEPTION) << "axis value is oversize."; - } - axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; - } - for (size_t i = 0; i < shape_.size(); ++i) { - if (shape_[i] <= 0) { - MS_LOG(EXCEPTION) << "shape value is invalid."; - } - left_dims_ *= shape_[i]; - } - for (size_t i = 0; i < axis_.size(); ++i) { - stride_ *= shape_[axis_[i]]; - } - if (stride_ <= 0) { - MS_LOG(EXCEPTION) << "stride_ must greater than zero."; - } - left_dims_ = left_dims_ / stride_; -} -bool ReduceCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspaces*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output empty!"; - } - size_t out_float_size = left_dims_ * sizeof(float); - size_t in_float_size = stride_ * out_float_size; - if (inputs[0]->size != in_float_size || outputs[0]->size != out_float_size) { - MS_LOG(EXCEPTION) << "invalid input or output data size!"; - } - auto input = reinterpret_cast(inputs[0]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - int size = inputs[0]->size / sizeof(float); - std::vector new_input(IntToSize(size), 0.0); - std::vector transpose_axis; - for (size_t i = 0; i < shape_.size(); ++i) { - bool insert = true; - for (size_t j = 0; j < axis_.size(); ++j) { - if (axis_[j] == i) { - insert = false; - break; - } - } - if (insert) { - transpose_axis.push_back(i); - } - } - (void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end()); - Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_input[0]); - if (reduce_type_ == kReduceTypeMax) { - for (size_t i = 0; i < left_dims_; ++i) { - float value = new_input[i * stride_]; - for (size_t k = 0; k < stride_; ++k) { - if (value < new_input[i * stride_ + k]) { - value = new_input[i * stride_ + k]; - } - } - output[i] = value; - } - } else { - for (size_t i = 0; i < left_dims_; ++i) { - float value = 0.0; - for (size_t k = 0; k < stride_; ++k) { - value += new_input[i * stride_ + k]; - } - if (reduce_type_ == kReduceTypeMean) { - output[i] = value / stride_; - } else { - output[i] = value; - } - } - } - return true; -} -void ReduceCPUKernel::Transpose(const int size, const float *input, const std::vector &input_shape, - const std::vector &input_axis, const int shape_size, float *output) { - int pos_array[kMaxDim]; - int size_offset[kMaxDim]; - size_offset[0] = size / SizeToInt(input_shape[0]); - for (int i = 1; i < shape_size; i++) { - size_offset[i] = size_offset[i - 1] / SizeToInt(input_shape[i]); - } - for (int position = 0; position < size; position += 1) { - int temp_position = position; - pos_array[0] = temp_position / size_offset[0]; - for (int i = 1; i < shape_size; i++) { - temp_position -= pos_array[i - 1] * size_offset[i - 1]; - pos_array[i] = temp_position / size_offset[i]; - } - int new_position = pos_array[SizeToInt(input_axis[shape_size - 1])]; - int new_position_size = 1; - for (int j = shape_size - 2; j >= 0; j--) { - new_position_size *= SizeToInt(input_shape[SizeToInt(input_axis[j + 1])]); - new_position += pos_array[SizeToInt(input_axis[j])] * new_position_size; - } - output[new_position] = input[position]; - } - return; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h deleted file mode 100644 index 3317ec72ed..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ -#include -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class ReduceCPUKernel : public CPUKernel { - public: - ReduceCPUKernel() = default; - ~ReduceCPUKernel() override = default; - void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void Transpose(const int size, const float *input, const std::vector &input_shape, - const std::vector &input_axis, const int shape_size, float *output); - size_t reduce_type_; - std::vector axis_; - std::vector shape_; - size_t left_dims_ = 1; - size_t stride_ = 1; -}; -MS_REG_CPU_KERNEL(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReduceCPUKernel); -MS_REG_CPU_KERNEL(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReduceCPUKernel); -MS_REG_CPU_KERNEL(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReduceCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc deleted file mode 100644 index 19a4e907a0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/reduce_scatter_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr auto kRanksGroup = "group"; -} // namespace - -ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {} - -void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) { - auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); - if (op != nullptr) { - op_type_ = GetValue(op); - } - - auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); - if (ranks_group != nullptr) { - ranks_group_ = GetValue>(ranks_group); - } else { - MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup; - } -} - -bool ReduceScatterCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto output_data_num = outputs[0]->size / sizeof(float); - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h deleted file mode 100644 index 5c6907602a..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class ReduceScatterCPUKernel : public CPUKernel { - public: - ReduceScatterCPUKernel(); - ~ReduceScatterCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - std::string op_type_; - std::vector ranks_group_; -}; - -MS_REG_CPU_KERNEL(_HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReduceScatterCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.cc deleted file mode 100644 index 7342a19e99..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/reshape_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); } - -bool ReshapeCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output empty!"; - } - if (inputs[0]->size != outputs[0]->size) { - return false; - } - - if (inputs[0]->addr == outputs[0]->addr) { - return true; - } - - size_t mem_bits = outputs[0]->size; - auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h deleted file mode 100644 index 6ca746f4ac..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class ReshapeCPUKernel : public CPUKernel { - public: - ReshapeCPUKernel() = default; - ~ReshapeCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReshapeCPUKernel); -MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ReshapeCPUKernel); - -MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReshapeCPUKernel); -MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ReshapeCPUKernel); - -MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReshapeCPUKernel); -MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ReshapeCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc deleted file mode 100644 index d2530430e9..0000000000 --- a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc +++ /dev/null @@ -1,180 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/slice_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - - begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); - for (size_t i = 0; i < begin_.size(); i++) { - if (begin_[i] < 0) { - begin_[i] = begin_[i] + input_shape_[i]; - } - } - auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(prim); - auto strides = prim->GetAttr(STRIDES); - if (strides != nullptr) { - strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); - if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) { - MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; - } - for (size_t i = 0; i < strides_.size(); ++i) { - if (strides_[i] < 0) { - strides_[i] = (strides_[i] + input_shape_[i]) > 0 ? (strides_[i] + input_shape_[i]) : 0; - } - if (end_[i] < 0) { - end_[i] = (end_[i] + input_shape_[i]) > 0 ? (end_[i] + input_shape_[i]) : 0; - } - } - } else { - auto sizes = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); - if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { - MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; - } - for (size_t i = 0; i < sizes.size(); ++i) { - if (sizes[i] < 0) { - sizes[i] = (sizes[i] + input_shape_[i]) > 0 ? (sizes[i] + input_shape_[i]) : 0; - } - strides_.emplace_back(1); - end_.emplace_back(begin_[i] + sizes[i]); - } - } - - ExpandAllMemberDims(); - CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); - CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); -} - -void SliceCPUKernel::ExpandAllMemberDims() { - CPUKernelUtils::ExpandDimsTo4(&output_shape_); - - auto input_len = input_shape_.size(); - if (input_len < 4) { - for (size_t i = 0; i < 4 - input_len; ++i) { - input_shape_.insert(input_shape_.begin(), 1); - begin_.insert(begin_.begin(), 0); - strides_.insert(strides_.begin(), 1); - end_.insert(end_.begin(), 1); - } - } -} - -bool SliceCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - - bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; - size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1], - begin_[2] * input_element_num_[2]}; - size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1], - strides_[2] * input_element_num_[2]}; - - auto in_n_offset = in_start_offset[0]; - auto out_n_offset = 0; - for (int i = begin_[0]; i < end_[0]; - i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) { - if (can_copy_memory[0]) { - CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]); - continue; - } - auto in_c_offset = in_start_offset[1]; - auto out_c_offset = 0; - for (int j = begin_[1]; j < end_[1]; - j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) { - if (can_copy_memory[1]) { - CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, - input_element_num_[1]); - continue; - } - auto in_h_offset = in_start_offset[2]; - auto out_h_offset = 0; - for (int k = begin_[2]; k < end_[2]; - k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) { - if (can_copy_memory[2]) { - CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs, - out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]); - continue; - } - for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { - *output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m]; - } - } - } - } - - return true; -} - -bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const { - for (size_t i = dim + 1; i < 4; ++i) { - if (begin_[i] != 0 || end_[i] != SizeToInt(input_shape_[i]) || strides_[i] != 1) { - return false; - } - } - return true; -} - -void SliceCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t in_offset, - const std::vector &outputs, size_t out_offset, - size_t copy_num) const { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto in_buff_size = inputs[0]->size; - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto out_buff_size = outputs[0]->size; - - if ((in_offset + copy_num) * sizeof(float) > in_buff_size) { - MS_LOG(EXCEPTION) << "input memory out of bounds."; - } - if ((out_offset + copy_num) * sizeof(float) > out_buff_size) { - MS_LOG(EXCEPTION) << "output memory out of bounds."; - } - - auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset, - copy_num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret; - } -} - -void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 inputs."; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output."; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower."; - } - if (input_shape.size() == 0) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h deleted file mode 100644 index 913c993d7a..0000000000 --- a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SliceCPUKernel : public CPUKernel { - public: - SliceCPUKernel() = default; - ~SliceCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void ExpandAllMemberDims(); - bool CanCopyMemoryOnAxis(size_t dim) const; - void CopyDataToOutput(const std::vector &inputs, size_t in_offset, - const std::vector &outputs, size_t out_offset, size_t copy_num) const; - void CheckParam(const CNodePtr &kernel_node) const; - std::vector begin_; - std::vector end_; - std::vector strides_; - std::vector input_shape_; - std::vector input_element_num_; - std::vector output_shape_; - std::vector output_element_num_; -}; - -MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceCPUKernel); -MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc deleted file mode 100644 index 92eaffe8c6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc +++ /dev/null @@ -1,182 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/slice_grad_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - - begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); - for (size_t i = 0; i < begin_.size(); i++) { - if (begin_[i] < 0) { - begin_[i] = begin_[i] + output_shape_[i]; - } - } - - auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(prim); - auto strides = prim->GetAttr(STRIDES); - if (strides != nullptr) { - strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); - if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) { - MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; - } - for (size_t i = 0; i < strides_.size(); ++i) { - if (strides_[i] < 0) { - strides_[i] = (strides_[i] + output_shape_[i]) > 0 ? (strides_[i] + output_shape_[i]) : 0; - } - if (end_[i] < 0) { - end_[i] = (end_[i] + output_shape_[i]) > 0 ? (end_[i] + output_shape_[i]) : 0; - } - } - } else { - auto sizes = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); - if (sizes.size() != output_shape_.size() || begin_.size() != output_shape_.size()) { - MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; - } - for (size_t i = 0; i < sizes.size(); ++i) { - if (sizes[i] < 0) { - sizes[i] = (sizes[i] + output_shape_[i]) > 0 ? (sizes[i] + output_shape_[i]) : 0; - } - strides_.emplace_back(1); - end_.emplace_back(begin_[i] + sizes[i]); - } - } - - ExpandAllMemberDims(); - CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); - CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); -} - -void SliceGradCPUKernel::ExpandAllMemberDims() { - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - - auto output_len = output_shape_.size(); - if (output_len < 4) { - for (size_t i = 0; i < 4 - output_len; ++i) { - output_shape_.insert(output_shape_.begin(), 1); - begin_.insert(begin_.begin(), 0); - strides_.insert(strides_.begin(), 1); - end_.insert(end_.begin(), 1); - } - } -} - -bool SliceGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - - auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size); - if (ret != EOK) { - MS_LOG(ERROR) << "output buff memset fail. ret:" << ret; - return false; - } - - bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; - size_t out_start_offset[3] = {begin_[0] * output_element_num_[0], begin_[1] * output_element_num_[1], - begin_[2] * output_element_num_[2]}; - size_t out_step_size[3] = {strides_[0] * output_element_num_[0], strides_[1] * output_element_num_[1], - strides_[2] * output_element_num_[2]}; - - auto in_n_offset = 0; - auto out_n_offset = out_start_offset[0]; - for (int i = begin_[0]; i < end_[0]; - i += strides_[0], in_n_offset += input_element_num_[0], out_n_offset += out_step_size[0]) { - if (can_copy_memory[0]) { - CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]); - continue; - } - auto in_c_offset = 0; - auto out_c_offset = out_start_offset[1]; - for (int j = begin_[1]; j < end_[1]; - j += strides_[1], in_c_offset += input_element_num_[1], out_c_offset += out_step_size[1]) { - if (can_copy_memory[1]) { - CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, - input_element_num_[1]); - continue; - } - auto in_h_offset = 0; - auto out_h_offset = out_start_offset[2]; - for (int k = begin_[2]; k < end_[2]; - k += strides_[2], in_h_offset += input_element_num_[2], out_h_offset += out_step_size[2]) { - if (can_copy_memory[2]) { - CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs, - out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]); - continue; - } - for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { - output_addr[out_n_offset + out_c_offset + out_h_offset + m] = *input_addr++; - } - } - } - } - return true; -} - -bool SliceGradCPUKernel::CanCopyMemoryOnAxis(size_t dim) const { - for (size_t i = dim + 1; i < 4; ++i) { - if (begin_[i] != 0 || end_[i] != SizeToInt(output_shape_[i]) || strides_[i] != 1) { - return false; - } - } - return true; -} - -void SliceGradCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t in_offset, - const std::vector &outputs, size_t out_offset, - size_t copy_num) const { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto in_buff_size = inputs[0]->size; - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto out_buff_size = outputs[0]->size; - - if ((in_offset + copy_num) * sizeof(float) > in_buff_size) { - MS_LOG(EXCEPTION) << "input memory out of bounds."; - } - if ((out_offset + copy_num) * sizeof(float) > out_buff_size) { - MS_LOG(EXCEPTION) << "output memory out of bounds."; - } - - auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset, - copy_num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret; - } -} - -void SliceGradCPUKernel::CheckParam(const CNodePtr &kernel_node) const { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output."; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceGradGpuKernel only support 4d or lower."; - } - if (input_shape.size() == 0) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h deleted file mode 100644 index 1e42c8ac68..0000000000 --- a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SliceGradCPUKernel : public CPUKernel { - public: - SliceGradCPUKernel() = default; - ~SliceGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void ExpandAllMemberDims(); - bool CanCopyMemoryOnAxis(size_t dim) const; - void CopyDataToOutput(const std::vector &inputs, size_t in_offset, - const std::vector &outputs, size_t out_offset, size_t copy_num) const; - void CheckParam(const CNodePtr &kernel_node) const; - std::vector begin_; - std::vector end_; - std::vector strides_; - std::vector input_shape_; - std::vector input_element_num_; - std::vector output_shape_; - std::vector output_element_num_; -}; - -MS_REG_CPU_KERNEL( - SliceGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradCPUKernel); -MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc deleted file mode 100644 index ef3db78275..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc +++ /dev/null @@ -1,177 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/sparse_apply_adam_cpu_kernel.h" -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyAdamInputSize = 11; - -void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto m = input_params->m_; - auto m_t = input_params->m_t_; - auto v = input_params->v_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; - auto use_nesterov = input_params->use_nesterov_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; - for (size_t i = start; i < end; ++i) { - int index = unique_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= var_first_dim_size) { - MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; - } - size_t start_index = var_outer_dim_size * index; - size_t end_index = start_index + var_outer_dim_size; - for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { - auto summed_grad = unique_sparse_grad.value_[k]; - m[j] += (1 - beta1) * summed_grad; - v[j] += (1 - beta2) * summed_grad * summed_grad; - if (use_nesterov) { - m_t[j] = m[j] * beta1 + (1 - beta1) * summed_grad; - } - } - } -} - -void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto m = input_params->m_; - auto v = input_params->v_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; - for (size_t i = start; i < end; ++i) { - m[i] *= beta1; - v[i] *= beta2; - } -} - -void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto var = input_params->var_; - auto m = input_params->m_; - auto v = input_params->v_; - auto lr = input_params->lr_; - auto epsilon = input_params->epsilon_; - for (size_t i = start; i < end; ++i) { - var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); - } -} -} // namespace - -void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); -} - -void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); - std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 10); - if (!IsSameShape(var_shape, m_shape)) { - MS_LOG(EXCEPTION) << "var and m should have the same shape"; - } - if (!IsSameShape(var_shape, v_shape)) { - MS_LOG(EXCEPTION) << "var and v should have the same shape"; - } - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "var must be at least 1D"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be 1D"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { - use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); - } -} - -bool SparseApplyAdamCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector & /*outputs*/) { - if (inputs.size() < kSparseApplyAdamInputSize) { - MS_LOG(EXCEPTION) << "Error input size!"; - } - - auto var = reinterpret_cast(inputs[0]->addr); - auto m = reinterpret_cast(inputs[1]->addr); - auto v = reinterpret_cast(inputs[2]->addr); - auto beta1_power = reinterpret_cast(inputs[3]->addr)[0]; - if (beta1_power == 1) { - MS_LOG(EXCEPTION) << "The beta1_power should not be 1"; - } - auto beta2_power = reinterpret_cast(inputs[4]->addr)[0]; - auto lr = reinterpret_cast(inputs[5]->addr)[0]; - auto beta1 = reinterpret_cast(inputs[6]->addr)[0]; - auto beta2 = reinterpret_cast(inputs[7]->addr)[0]; - auto epsilon = reinterpret_cast(inputs[8]->addr)[0]; - auto grad = reinterpret_cast(inputs[9]->addr); - auto indices = reinterpret_cast(inputs[10]->addr); - auto new_grad = reinterpret_cast(workspace[0]->addr); - auto new_indices = reinterpret_cast(workspace[1]->addr); - auto m_t = reinterpret_cast(workspace[2]->addr); - - SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, - var_outer_dim_size_); - size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_; - lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); - - MultiThreadComputeParams input_params; - input_params.m_ = m; - input_params.v_ = v; - input_params.beta1_ = beta1; - input_params.beta2_ = beta2; - MultiThreadCompute(ComputeMomentum, &input_params, total_dim_size); - - input_params.m_t_ = m_t; - input_params.use_nesterov_ = use_nesterov_; - input_params.sparse_grad_ = unique_sparse_grad; - input_params.var_first_dim_size_ = var_first_dim_size_; - input_params.var_outer_dim_size_ = var_outer_dim_size_; - MultiThreadCompute(ComputeAdam, &input_params, unique_sparse_grad.indices_size_); - - if (use_nesterov_) { - input_params.m_ = input_params.m_t_; - } - input_params.var_ = var; - input_params.lr_ = lr; - input_params.epsilon_ = epsilon; - MultiThreadCompute(ComputeWeight, &input_params, total_dim_size); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h deleted file mode 100644 index c2770d0ebd..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyAdamCPUKernel : public CPUKernel { - public: - SparseApplyAdamCPUKernel() = default; - ~SparseApplyAdamCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - void InitInputOutputSize(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t indices_size_{0}; - size_t var_first_dim_size_{0}; - size_t var_outer_dim_size_{1}; - bool use_nesterov_{false}; -}; - -MS_REG_CPU_KERNEL(SparseApplyAdam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyAdamCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc deleted file mode 100644 index 0537e746f3..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/sparse_apply_ftrl_cpu_kernel.h" -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyFtrlInputSize = 5; - -void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto var = input_params->var_; - auto accum = input_params->accum_; - auto linear = input_params->linear_; - auto lr = input_params->lr_; - auto l1 = input_params->l1_; - auto l2_plus = 2 * input_params->l2_; - auto lr_power = input_params->lr_power_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; - for (size_t i = start; i < end; ++i) { - int index = unique_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= var_first_dim_size) { - MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; - } - size_t start_index = var_outer_dim_size * index; - size_t end_index = start_index + var_outer_dim_size; - for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { - auto summed_grad = unique_sparse_grad.value_[k]; - auto accum_new = accum[j] + summed_grad * summed_grad; - float y; - if (lr_power == -0.5) { - y = std::sqrt(accum_new); - linear[j] += summed_grad - (y - std::sqrt(accum[j])) / lr * var[j]; - } else { - y = std::pow(accum_new, -lr_power); - linear[j] += summed_grad - (y - std::pow(accum[j], -lr_power)) / lr * var[j]; - } - accum[j] = accum_new; - auto x = Sign(linear[j]) * l1 - linear[j]; - y = y / lr + l2_plus; - var[j] = std::fabs(linear[j]) > l1 ? x / y : 0; - } - } -} -} // namespace - -void SparseApplyFtrlCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); -} - -void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); - if (!IsSameShape(var_shape, accum_shape)) { - MS_LOG(EXCEPTION) << "var and accum should have the same shape"; - } - if (!IsSameShape(var_shape, linear_shape)) { - MS_LOG(EXCEPTION) << "var and linear should have the same shape"; - } - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "var must be at least 1D"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be a 1D vector"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - lr_ = AnfAlgo::GetNodeAttr(kernel_node, "lr"); - if (lr_ <= 0) { - MS_LOG(EXCEPTION) << "lr should be a positive scalar"; - } - l1_ = AnfAlgo::GetNodeAttr(kernel_node, "l1"); - if (l1_ < 0) { - MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar"; - } - l2_ = AnfAlgo::GetNodeAttr(kernel_node, "l2"); - if (l2_ < 0) { - MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar"; - } - lr_power_ = AnfAlgo::GetNodeAttr(kernel_node, "lr_power"); - if (lr_power_ > 0) { - MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar"; - } -} - -bool SparseApplyFtrlCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector & /*outputs*/) { - if (inputs.size() < kSparseApplyFtrlInputSize) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - - auto var = reinterpret_cast(inputs[0]->addr); - auto accum = reinterpret_cast(inputs[1]->addr); - auto linear = reinterpret_cast(inputs[2]->addr); - auto grad = reinterpret_cast(inputs[3]->addr); - auto indices = reinterpret_cast(inputs[4]->addr); - auto new_grad = reinterpret_cast(workspace[0]->addr); - auto new_indices = reinterpret_cast(workspace[1]->addr); - SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, - var_outer_dim_size_); - - MultiThreadComputeParams input_params; - input_params.var_ = var; - input_params.accum_ = accum; - input_params.linear_ = linear; - input_params.lr_ = lr_; - input_params.l1_ = l1_; - input_params.l2_ = l2_; - input_params.lr_power_ = lr_power_; - input_params.sparse_grad_ = unique_sparse_grad; - input_params.var_first_dim_size_ = var_first_dim_size_; - input_params.var_outer_dim_size_ = var_outer_dim_size_; - MultiThreadCompute(ComputeFtrl, &input_params, unique_sparse_grad.indices_size_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.h deleted file mode 100644 index 9e79dc83c7..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ - -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyFtrlCPUKernel : public CPUKernel { - public: - SparseApplyFtrlCPUKernel() = default; - ~SparseApplyFtrlCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - void InitInputOutputSize(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t indices_size_{0}; - size_t var_first_dim_size_{0}; - size_t var_outer_dim_size_{1}; - float lr_{0}; - float l1_{0}; - float l2_{0}; - float lr_power_{0}; -}; - -MS_REG_CPU_KERNEL(SparseApplyFtrl, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyFtrlCPUKernel); - -MS_REG_CPU_KERNEL(SparseApplyFtrlNoReturn, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyFtrlCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc deleted file mode 100644 index 16cb901b04..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc +++ /dev/null @@ -1,146 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h" -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyLazyAdamInputSize = 11; - -void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto var = input_params->var_; - auto m = input_params->m_; - auto v = input_params->v_; - auto lr = input_params->lr_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; - auto epsilon = input_params->epsilon_; - auto use_nesterov = input_params->use_nesterov_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; - for (size_t i = start; i < end; ++i) { - int index = unique_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= var_first_dim_size) { - MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range"; - } - size_t start_index = var_outer_dim_size * index; - size_t end_index = start_index + var_outer_dim_size; - for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { - auto summed_grad = unique_sparse_grad.value_[k]; - m[j] = beta1 * m[j] + (1 - beta1) * summed_grad; - v[j] = beta2 * v[j] + (1 - beta2) * summed_grad * summed_grad; - if (use_nesterov) { - var[j] -= lr * (m[j] * beta1 + (1 - beta1) * summed_grad) / (std::sqrt(v[j]) + epsilon); - } else { - var[j] -= lr * m[j] / (std::sqrt(v[j]) + epsilon); - } - } - } -} -} // namespace - -void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); -} - -void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); - std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 10); - if (!IsSameShape(var_shape, m_shape)) { - MS_LOG(EXCEPTION) << "var and m should have the same shape"; - } - if (!IsSameShape(var_shape, v_shape)) { - MS_LOG(EXCEPTION) << "var and v should have the same shape"; - } - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "var must be at least 1D"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be 1D"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { - use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); - } -} - -bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector & /*outputs*/) { - if (inputs.size() < kSparseApplyLazyAdamInputSize) { - MS_LOG(EXCEPTION) << "Error input size!"; - } - - auto var = reinterpret_cast(inputs[0]->addr); - auto m = reinterpret_cast(inputs[1]->addr); - auto v = reinterpret_cast(inputs[2]->addr); - auto beta1_power = reinterpret_cast(inputs[3]->addr)[0]; - if (beta1_power == 1) { - MS_LOG(EXCEPTION) << "The beta1_power should not be 1"; - } - auto beta2_power = reinterpret_cast(inputs[4]->addr)[0]; - auto lr = reinterpret_cast(inputs[5]->addr)[0]; - auto beta1 = reinterpret_cast(inputs[6]->addr)[0]; - auto beta2 = reinterpret_cast(inputs[7]->addr)[0]; - auto epsilon = reinterpret_cast(inputs[8]->addr)[0]; - auto grad = reinterpret_cast(inputs[9]->addr); - auto indices = reinterpret_cast(inputs[10]->addr); - auto new_grad = reinterpret_cast(workspace[0]->addr); - auto new_indices = reinterpret_cast(workspace[1]->addr); - - SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, - var_outer_dim_size_); - - lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); - MultiThreadComputeParams input_params; - input_params.var_ = var; - input_params.m_ = m; - input_params.v_ = v; - input_params.lr_ = lr; - input_params.beta1_ = beta1; - input_params.beta2_ = beta2; - input_params.epsilon_ = epsilon; - input_params.use_nesterov_ = use_nesterov_; - input_params.sparse_grad_ = unique_sparse_grad; - input_params.var_first_dim_size_ = var_first_dim_size_; - input_params.var_outer_dim_size_ = var_outer_dim_size_; - MultiThreadCompute(ComputeLazyAdam, &input_params, unique_sparse_grad.indices_size_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h deleted file mode 100644 index 795568a64d..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyLazyAdamCPUKernel : public CPUKernel { - public: - SparseApplyLazyAdamCPUKernel() = default; - ~SparseApplyLazyAdamCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - void InitInputOutputSize(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t indices_size_{0}; - size_t var_first_dim_size_{0}; - size_t var_outer_dim_size_{1}; - bool use_nesterov_{false}; -}; - -MS_REG_CPU_KERNEL(SparseApplyLazyAdam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyLazyAdamCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc deleted file mode 100644 index 6069fb708e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h" -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyProximalAdagradInputSize = 7; - -void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto var = input_params->var_; - auto accum = input_params->accum_; - auto lr = input_params->lr_; - auto l1 = input_params->l1_; - auto l2 = input_params->l2_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; - for (size_t i = start; i < end; ++i) { - int index = unique_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= var_first_dim_size) { - MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; - } - size_t start_index = var_outer_dim_size * index; - size_t end_index = start_index + var_outer_dim_size; - for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { - auto summed_grad = unique_sparse_grad.value_[k]; - accum[j] += summed_grad * summed_grad; - auto learning_rate = lr * (1 / std::sqrt(accum[j])); - auto prox_v = var[j]; - prox_v -= summed_grad * learning_rate; - if (l1 > 0) { - var[j] = Sign(prox_v) * std::fmax(std::fabs(prox_v) - learning_rate * l1, static_cast(0.0)) / - (1 + l2 * learning_rate); - } else { - var[j] = prox_v / (1 + l2 * learning_rate); - } - } - } -} -} // namespace - -void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); -} - -void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector lr_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - std::vector l1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - std::vector l2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 5); - std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 6); - if (!IsSameShape(var_shape, accum_shape)) { - MS_LOG(EXCEPTION) << "var and accum should have the same shape"; - } - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "var must be at least 1D"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be a 1D vector"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - if (!lr_shape.empty()) { - MS_LOG(EXCEPTION) << "lr is not a scalar"; - } - if (!l1_shape.empty()) { - MS_LOG(EXCEPTION) << "l1 is not a scalar"; - } - if (!l2_shape.empty()) { - MS_LOG(EXCEPTION) << "l2 is not a scalar"; - } -} - -bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector & /*outputs*/) { - if (inputs.size() < kSparseApplyProximalAdagradInputSize) { - MS_LOG(EXCEPTION) << "Wrong input size!"; - } - - auto var = reinterpret_cast(inputs[0]->addr); - auto accum = reinterpret_cast(inputs[1]->addr); - auto lr = reinterpret_cast(inputs[2]->addr)[0]; - auto l1 = reinterpret_cast(inputs[3]->addr)[0]; - auto l2 = reinterpret_cast(inputs[4]->addr)[0]; - auto grad = reinterpret_cast(inputs[5]->addr); - auto indices = reinterpret_cast(inputs[6]->addr); - auto new_grad = reinterpret_cast(workspace[0]->addr); - auto new_indices = reinterpret_cast(workspace[1]->addr); - SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, - var_outer_dim_size_); - - MultiThreadComputeParams input_params; - input_params.var_ = var; - input_params.accum_ = accum; - input_params.lr_ = lr; - input_params.l1_ = l1; - input_params.l2_ = l2; - input_params.sparse_grad_ = unique_sparse_grad; - input_params.var_first_dim_size_ = var_first_dim_size_; - input_params.var_outer_dim_size_ = var_outer_dim_size_; - MultiThreadCompute(ComputeProximalAdagrad, &input_params, unique_sparse_grad.indices_size_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h deleted file mode 100644 index ff7da7966c..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyProximalAdagradCPUKernel : public CPUKernel { - public: - SparseApplyProximalAdagradCPUKernel() = default; - ~SparseApplyProximalAdagradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - void InitInputOutputSize(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t indices_size_{0}; - size_t var_first_dim_size_{0}; - size_t var_outer_dim_size_{1}; -}; - -MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyProximalAdagradCPUKernel); - -MS_REG_CPU_KERNEL(SparseApplyProximalAdagradNoReturn, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyProximalAdagradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc deleted file mode 100644 index 543f0e5cdd..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "kernel/cpu/sub_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - if (shape.size() == 1) { - if (shape[0] != 1) { - MS_LOG(EXCEPTION) << "input 1 only support scalar"; - } - } else { - MS_LOG(EXCEPTION) << "input 1 only support scalar"; - } -} - -void sub_task(const int *in_addr, int *out_addr, size_t lens, int offset) { - for (size_t i = 0; i < lens; i++) { - out_addr[i] = in_addr[i] - offset; - } -} - -bool SubCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - offset_ = *reinterpret_cast(inputs[1]->addr); - MS_LOG(INFO) << "offset: " << offset_; - auto lens = inputs[0]->size / sizeof(int); - if (lens < 10000) { - for (size_t i = 0; i < lens; i++) { - output_addr[i] = input_addr[i] - offset_; - } - } else { - const size_t thread_num = 4; - std::thread threads[4]; - size_t process_lens = (lens + thread_num - 1) / thread_num; - size_t process_offset = 0; - for (size_t i = 0; i < thread_num; i++) { - threads[i] = - std::thread(sub_task, input_addr + process_offset, output_addr + process_offset, process_lens, offset_); - if (process_offset + process_lens > lens) { - process_lens = lens - process_offset; - process_offset = lens; - } else { - process_offset += process_lens; - } - } - for (size_t i = 0; i < thread_num; i++) { - threads[i].join(); - } - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "SubscaleCPUKernel, used time: " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); - time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "SubCPUKernel, used time: " << time << " us"; -#endif - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h deleted file mode 100644 index 54b2c8951a..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SubCPUKernel : public CPUKernel { - public: - SubCPUKernel() : offset_(0) {} - ~SubCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - int offset_; -}; - -MS_REG_CPU_KERNEL( - Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SubCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.cc deleted file mode 100644 index f2ac9350cb..0000000000 --- a/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/transpose_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -namespace mindspore { -namespace kernel { -const size_t kMaxDim = 100; -void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - axis_ = AnfAlgo::GetNodeAttr>(kernel_node, "perm"); - if (shape_.size() != axis_.size()) { - MS_LOG(EXCEPTION) << "The size of input shape and transpose axis shape must be equal."; - } -} -bool TransposeCPUFwdKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input = reinterpret_cast(inputs[0]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - size_t size = IntToSize(inputs[0]->size / sizeof(float)); - size_t shape_size = IntToSize(shape_.size()); - if (shape_size > kMaxDim) { - MS_LOG(EXCEPTION) << "Input is " << shape_size << "-D, but transpose supports max " << kMaxDim << "-D inputs."; - } - size_t pos_array[kMaxDim]; - size_t size_offset[kMaxDim]; - size_offset[0] = size / shape_[0]; - for (size_t i = 1; i < shape_size; i++) { - size_offset[i] = size_offset[SizeToInt(i) - 1] / shape_[i]; - } - for (size_t position = 0; position < size; position += 1) { - size_t temp_position = position; - pos_array[0] = temp_position / size_offset[0]; - for (size_t i = 1; i < shape_size; i++) { - temp_position -= pos_array[SizeToInt(i) - 1] * size_offset[i - 1]; - pos_array[i] = temp_position / size_offset[i]; - } - size_t new_position = pos_array[axis_[SizeToInt(shape_size) - 1]]; - size_t new_position_size = 1; - for (int j = shape_size - 2; j >= 0; j--) { - new_position_size *= shape_[axis_[j + 1]]; - new_position += pos_array[axis_[j]] * new_position_size; - } - output[new_position] = input[position]; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.h deleted file mode 100644 index d882f4fa51..0000000000 --- a/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ -#include -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" -namespace mindspore { -namespace kernel { -class TransposeCPUFwdKernel : public CPUKernel { - public: - TransposeCPUFwdKernel() = default; - ~TransposeCPUFwdKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - std::vector shape_; - std::vector axis_; -}; - -MS_REG_CPU_KERNEL(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TransposeCPUFwdKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.cc deleted file mode 100644 index 71f612d07c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/argmax_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - ArgmaxGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), - ArgmaxGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.h deleted file mode 100644 index 3df70d0960..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/argmax_impl.cuh" -namespace mindspore { -namespace kernel { -#define ARGMAX_MAX_DIMENSION 2 -template -class ArgmaxGpuKernel : public GpuKernel { - public: - ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), batch_size_(0), channel_size_(0), axis_(0) {} - ~ArgmaxGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - int *output = GetDeviceAddress(outputs, 0); - CalArgmax(input, SizeToInt(batch_size_), SizeToInt(channel_size_), axis_, output, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but argmax needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but argmax needs 1 output."; - return false; - } - auto output_type = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("output_type")); - if (output_type->type_id() != TypeId::kNumberTypeInt32) { - MS_LOG(EXCEPTION) << "Argmax only supports int32 output type."; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > ARGMAX_MAX_DIMENSION) { - MS_LOG(EXCEPTION) << "Input is " << input_shape.size() << "-D, but argmax supports max " << ARGMAX_MAX_DIMENSION - << "-D inputs."; - } - - axis_ = GetAttr(kernel_node, "axis"); - if (axis_ < 0) { - axis_ += SizeToInt(input_shape.size()); - } - if (input_shape.size() == 1) { - batch_size_ = 0; - channel_size_ = input_shape[0]; - input_size_ = sizeof(T) * channel_size_; - output_size_ = sizeof(int); - } else { - batch_size_ = input_shape[0]; - channel_size_ = input_shape[1]; - input_size_ = sizeof(T) * batch_size_ * channel_size_; - output_size_ = (axis_ == 1) ? sizeof(int) * batch_size_ : sizeof(int) * channel_size_; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t batch_size_; - size_t channel_size_; - int axis_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc deleted file mode 100644 index 24c8a9a730..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - ArgMaxWithValue, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - ArgmaxWithValueGpuKernel, float, int) -MS_REG_GPU_KERNEL_TWO( - ArgMaxWithValue, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - ArgmaxWithValueGpuKernel, half, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h deleted file mode 100644 index 304f0ab161..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh" -namespace mindspore { -namespace kernel { -template -class ArgmaxWithValueGpuKernel : public GpuKernel { - public: - ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {} - ~ArgmaxWithValueGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 1); - S *index = GetDeviceAddress(outputs, 0); - CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - std::vector shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); - int dims = shape.size(); - int axis = GetAttr(kernel_node, "axis"); - if (axis < 0) { - axis += dims; - } - input_size_ = sizeof(T); - for (auto x : shape) { - input_size_ *= x; - } - output_size_ = sizeof(S); - for (auto x : output_shape) { - output_size_ *= x; - } - bound_ = shape[axis]; - outerSize_ = 1; - for (int i = axis - 1; i >= 0; i--) { - outerSize_ *= shape[i]; - } - - innerSize_ = 1; - for (int i = axis + 1; i < dims; i++) { - innerSize_ *= shape[i]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - output_size_list_.push_back(output_size_ / sizeof(S) * sizeof(T)); - } - - private: - size_t input_size_; - size_t output_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - int bound_; - int outerSize_; - int innerSize_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.cc deleted file mode 100644 index f378604624..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/array_reduce_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ArrayReduceGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ArrayReduceGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ArrayReduceGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ArrayReduceGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ArrayReduceGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ArrayReduceGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h deleted file mode 100644 index 4a52439305..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h +++ /dev/null @@ -1,237 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -namespace mindspore { -namespace kernel { -const std::map kReduceTypeMap = { - {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, - {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, - {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, -}; -template -class ArrayReduceGpuKernel : public GpuKernel { - public: - ArrayReduceGpuKernel() - : cudnn_handle_(nullptr), - reduce_tensor_op_(CUDNN_REDUCE_TENSOR_ADD), - data_type_(CUDNN_DATA_FLOAT), - nan_prop_(CUDNN_NOT_PROPAGATE_NAN), - reduce_indices_(CUDNN_REDUCE_TENSOR_NO_INDICES), - reduce_tensor_descriptor_(nullptr), - inputA_descriptor_(nullptr), - outputC_descriptor_(nullptr), - keep_dims_(false), - all_match_(false), - is_null_input_(false), - input_size_(0), - output_size_(0), - workspace_size_(0) {} - ~ArrayReduceGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - T *workspace_addr = GetDeviceAddress(workspace, 0); - - const float alpha = 1; - const float beta = 0; - if (all_match_) { - MS_LOG(WARNING) - << "The corresponding dimensions of the input and output tensors all match. No need to call cuDNN kernel."; - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync failed in ArrayReduceGpuKernel::Launch."); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha, - inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr), - "cudnnReduceTensor failed."); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but reduce op needs 1 output."; - return false; - } - int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size()); - - if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa() || - AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa()) { - auto attr_axis = GetAttr>(kernel_node, "axis"); - if (attr_axis.empty()) { - axis_.push_back(-1); - } else { - for (auto axis : attr_axis) { - axis < 0 ? axis_.push_back(axis + input_dim_length) : axis_.push_back(axis); - } - } - } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa()) { - int axis = GetAttr(kernel_node, "axis"); - axis < 0 ? axis_.push_back(axis + input_dim_length) : axis_.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; - } - keep_dims_ = GetAttr(kernel_node, "keep_dims"); - - auto inputA_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto outputC_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(inputA_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ArrayReduceGpuKernel input is null"; - InitSizeLists(); - return true; - } - InferInAndOutDesc(inputA_shape, outputC_shape); - InferArrayReduceType(kernel_node); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_), - "cudnnCreateReduceTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputA_descriptor_), - "cudnnCreateTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&outputC_descriptor_), - "cudnnCreateTensorDescriptor failed."); - } - void InitSizeLists() override { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputA_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed."); - input_size_list_.push_back(input_size_); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(outputC_descriptor_, &output_size_), - "cudnnGetTensorSizeInBytes failed."); - output_size_list_.push_back(output_size_); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, inputA_descriptor_, outputC_descriptor_, - &workspace_size_), - "cudnnGetReductionWorkspaceSize failed."); - workspace_size_list_.push_back(workspace_size_); - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_), - "cudnnDestroyReduceTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(outputC_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - } - void InferArrayReduceType(const CNodePtr &kernel_node) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kReduceTypeMap.find(kernel_name); - if (iter == kReduceTypeMap.end()) { - MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; - } else { - reduce_tensor_op_ = iter->second; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, reduce_tensor_op_, CUDNN_DATA_FLOAT, nan_prop_, - reduce_indices_, CUDNN_32BIT_INDICES), - "cudnnSetReduceTensorDescriptor failed"); - return; - } - void InferInAndOutDesc(const std::vector &input_shape, const std::vector &output_shape) { - std::vector inputA; - std::vector outputC_shape = output_shape; - ShapeNdTo4d(input_shape, &inputA); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0], - inputA[1], inputA[2], inputA[3]), - "cudnnSetTensor4dDescriptor failed"); - - if (axis_[0] == -1) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1), - "cudnnSetTensor4dDescriptor failed"); - if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) { - all_match_ = true; - } - return; - } - if (!keep_dims_) { - for (auto i : axis_) { - (void)(outputC_shape.insert(outputC_shape.begin() + i, 1)); - } - } - std::vector outputC; - ShapeNdTo4d(outputC_shape, &outputC); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, - outputC[0], outputC[1], outputC[2], outputC[3]), - "cudnnSetTensor4dDescriptor failed"); - if (inputA == outputC) { - all_match_ = true; - } - return; - } - - cudnnHandle_t cudnn_handle_; - cudnnReduceTensorOp_t reduce_tensor_op_; - cudnnDataType_t data_type_; - cudnnNanPropagation_t nan_prop_; - cudnnReduceTensorIndices_t reduce_indices_; - cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_; - cudnnTensorDescriptor_t inputA_descriptor_; - cudnnTensorDescriptor_t outputC_descriptor_; - - std::vector axis_; - bool keep_dims_; - bool all_match_; - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc deleted file mode 100644 index 3bca6a69d3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/concatv2_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConcatV2GpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Concat, - KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ConcatV2GpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE( - Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ConcatV2GpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h deleted file mode 100644 index a91c50ce69..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h +++ /dev/null @@ -1,128 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/concatv2_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class ConcatV2GpuFwdKernel : public GpuKernel { - public: - ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {} - ~ConcatV2GpuFwdKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (inputs.size() == 2) { - T *input_0 = GetDeviceAddress(inputs, 0); - T *input_1 = GetDeviceAddress(inputs, 1); - T *output = GetDeviceAddress(outputs, 0); - ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, - reinterpret_cast(stream_ptr)); - } - - if (inputs.size() == 3) { - T *input_0 = GetDeviceAddress(inputs, 0); - T *input_1 = GetDeviceAddress(inputs, 1); - T *input_2 = GetDeviceAddress(inputs, 2); - T *output = GetDeviceAddress(outputs, 0); - ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output, - reinterpret_cast(stream_ptr)); - } - - if (inputs.size() == 4) { - T *input_0 = GetDeviceAddress(inputs, 0); - T *input_1 = GetDeviceAddress(inputs, 1); - T *input_2 = GetDeviceAddress(inputs, 2); - T *input_3 = GetDeviceAddress(inputs, 3); - T *output = GetDeviceAddress(outputs, 0); - ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output, - reinterpret_cast(stream_ptr)); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - - axis_ = GetAttr(kernel_node, "axis"); - if (axis_ < 0) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - axis_ += SizeToInt(input_shape.size()); - } - - auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; i++) { - auto input_size = sizeof(T); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); - for (size_t j = 0; j < input_shape.size(); j++) { - input_size *= SizeToInt(input_shape[j]); - if (j >= IntToSize(axis_)) { - w_[i] *= SizeToInt(input_shape[j]); - } - input_size_list_.push_back(input_size); - } - } - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - output_size_ = sizeof(T); - for (size_t i = 0; i < output_shape.size(); i++) { - output_size_ *= output_shape[i]; - } - output_size_list_.push_back(output_size_); - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override {} - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num < 2 || input_num > 4) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output."; - return false; - } - return true; - } - int w_[4] = {1, 1, 1, 1}; - int axis_; - size_t output_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.cc deleted file mode 100644 index dc595e4793..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/gather_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - GatherV2, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - GatherGpuFwdKernel, float, int) -MS_REG_GPU_KERNEL_TWO( - GatherV2, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - GatherGpuFwdKernel, half, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.h deleted file mode 100644 index 72a05b0915..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_GATHER_GPU_KERNEL_H -#define MINDSPORE_GATHER_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/gather.cuh" - -namespace mindspore { -namespace kernel { -template -class GatherGpuFwdKernel : public GpuKernel { - public: - GatherGpuFwdKernel() : axis_(0), handle_(nullptr) {} - ~GatherGpuFwdKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - T *input_addr = GetDeviceAddress(inputs, 0); - S *indices_addr = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - - auto input_dim1 = input_shapes_[IntToSize(axis_)]; - Gather(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuFwdKernel needs 2."; - } - input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - - axis_ = GetAttr(kernel_node, "axis"); - if (axis_ < 0) { - axis_ = axis_ + SizeToInt(input_shapes_.size()); - } - - Reshape(); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - void InitSizeLists() override { - size_t size = GetSize(input_shapes_); - input_size_list_.push_back(size); - - size = GetSize(indices_shapes_); - input_size_list_.push_back(size); - - size = GetSize(output_shapes_); - output_size_list_.push_back(size); - } - - private: - void Reshape() { - size_t dim_before_axis = 1; - for (size_t i = 0; i < IntToSize(axis_); i++) { - dim_before_axis *= output_shapes_[i]; - } - - size_t dim_of_indices = 1; - for (size_t i = 0; i < indices_shapes_.size(); i++) { - dim_of_indices *= indices_shapes_[i]; - } - - size_t dim_after_indices = 1; - for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) { - dim_after_indices *= output_shapes_[i]; - } - - dims_[0] = dim_before_axis; - dims_[1] = dim_of_indices; - dims_[2] = dim_after_indices; - return; - } - size_t GetSize(const std::vector &shape) const { - if (shape.size() == 0) { - return 0; - } - size_t result = sizeof(T); - for (size_t i = 0; i < shape.size(); i++) { - result *= shape[i]; - } - return result; - } - - std::vector input_shapes_; - std::vector indices_shapes_; - std::vector output_shapes_; - - size_t dims_[3] = {}; - int axis_; - cudnnHandle_t handle_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_GATHER_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.cc deleted file mode 100644 index 7c160f8f58..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/one_hot_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO(OneHot, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - OneHotGpuFwdKernel, float, int) -MS_REG_GPU_KERNEL_TWO(OneHot, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - OneHotGpuFwdKernel, half, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.h deleted file mode 100644 index c8b64e7243..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.h +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/one_hot_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class OneHotGpuFwdKernel : public GpuKernel { - public: - OneHotGpuFwdKernel() : input_size_(1), output_size_(1), depth_(0), left_dim_size_(1), right_dim_size_(1) {} - ~OneHotGpuFwdKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - const S *indices = GetDeviceAddress(inputs, 0); - const T *on_value = GetDeviceAddress(inputs, 1); - const T *off_value = GetDeviceAddress(inputs, 2); - T *output = GetDeviceAddress(outputs, 0); - OneHot(indices, depth_, on_value, off_value, left_dim_size_, right_dim_size_, output, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - int axis = GetAttr(kernel_node, "axis"); - auto input = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto output = AnfAlgo::GetOutputInferShape(kernel_node, 0); - int input_size = SizeToInt(input.size()); - const int default_axis = -1; - - // Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims). - for (int i = 0; i < input_size; i++) { - auto dim_size = input[IntToSize(i)]; - if (axis == default_axis || i < axis) { - left_dim_size_ *= dim_size; - } - if (axis != default_axis && i >= axis) { - right_dim_size_ *= dim_size; - } - } - for (auto size : input) { - input_size_ *= size; - } - for (auto size : output) { - output_size_ *= size; - } - if (axis >= input_size) { - MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input.size(); - return false; - } - if (axis == default_axis) { - depth_ = output[output.size() - 1]; - } else { - depth_ = output[IntToSize(axis)]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - // inputs: indices, depth - input_size_list_.push_back((input_size_ + 1) * sizeof(S)); - output_size_list_.push_back(output_size_ * sizeof(T)); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; - size_t output_size_; - - size_t depth_; - size_t left_dim_size_; - size_t right_dim_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.cc deleted file mode 100644 index 41c9c2243f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.cc +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/select_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Select, - KernelAttr() - .AddInputAttr(kNumberTypeBool) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SelectGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Select, - KernelAttr() - .AddInputAttr(kNumberTypeBool) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - SelectGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Select, - KernelAttr() - .AddInputAttr(kNumberTypeBool) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - SelectGpuKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.h deleted file mode 100644 index f1b6c5853a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.h +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/select_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SelectGpuKernel : public GpuKernel { - public: - SelectGpuKernel() : input_size_(0), output_size_(0) {} - ~SelectGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - bool *input_cond = GetDeviceAddress(inputs, 0); - T *input_x = GetDeviceAddress(inputs, 1); - T *input_y = GetDeviceAddress(inputs, 2); - T *output = GetDeviceAddress(outputs, 0); - CalSelect(output_size_ / sizeof(T), input_cond, input_x, input_y, output, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_size_ = sizeof(bool); - output_size_ = sizeof(T); - for (size_t x : shape) { - input_size_ = input_size_ * x; - output_size_ = output_size_ * x; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(output_size_); - input_size_list_.push_back(output_size_); - output_size_list_.push_back(output_size_); - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but SelectGpuKernel needs 3 output."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but SelectGpuKernel needs 1 output."; - return false; - } - return true; - } - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; - size_t output_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.cc deleted file mode 100644 index 53161c29c2..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/slice_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SliceGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SliceGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGpuFwdKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h deleted file mode 100644 index 7f71e548ad..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h +++ /dev/null @@ -1,162 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/slice_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SliceGpuFwdKernel : public GpuKernel { - public: - SliceGpuFwdKernel() - : is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {} - ~SliceGpuFwdKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - if (is_strided_slice_) { - CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output, - reinterpret_cast(stream_ptr)); - } else { - Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0], - input_shape_[1], input_shape_[2], input_shape_[3], input, output, - reinterpret_cast(stream_ptr)); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - ShapeNdTo4d(input_shape, &input_shape_); - auto strides = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"); - if (strides) { - strides_ = GetAttr>(kernel_node, "strides"); - for (auto i = strides_.size(); i < 4; i++) { - (void)strides_.insert(strides_.begin(), 1); - } - size_ = GetAttr>(kernel_node, "end"); - is_strided_slice_ = true; - } else { - size_ = GetAttr>(kernel_node, "size"); - } - for (auto i = begin_.size(); i < 4; i++) { - (void)begin_.insert(begin_.begin(), 0); - } - for (size_t i = size_.size(); i < 4; i++) { - (void)size_.insert(size_.begin(), 1); - } - for (size_t i = 0; i < begin_.size(); i++) { - if (begin_[i] < 0) { - begin_[i] = begin_[i] + input_shape_[i]; - } - } - for (size_t i = 0; i < size_.size(); i++) { - if (size_[i] < 0) { - size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; - } - if (begin_[i] == size_[i] && is_strided_slice_) { - MS_LOG(WARNING) << "Output is null."; - is_null_input_ = true; - } - if (size_[i] == 0 && strides_[i] > 0) { - size_[i] = begin_[i] + 1; - } - } - - input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); - auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - - output_size_ = sizeof(T); - for (size_t x : out_shape) { - output_size_ = output_size_ * x; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but SliceGpuFwdKernel needs 1 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but SliceGpuFwdKernel needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGpuFwdKernel olny support 4d or lower."; - return false; - } - if (input_shape.size() == 0) { - MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported."; - return false; - } - begin_ = GetAttr>(kernel_node, "begin"); - for (size_t i = 0; i < input_shape.size(); i++) { - if ((begin_[i] > 0 && (begin_[i] > SizeToInt(input_shape[i]))) || - (begin_[i] < 0 && (std::abs(begin_[i]) > SizeToInt(input_shape[i])))) { - MS_LOG(INFO) << "Input out of bounds " << input_shape[i] << " in axis " << i << "."; - begin_[i] = 0; - } - } - return true; - } - std::vector begin_; - std::vector size_; - std::vector strides_; - std::vector input_shape_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - bool is_strided_slice_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.cc deleted file mode 100644 index b91aafb734..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.cc +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/slice_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - SliceGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGradGpuKernel, int) -MS_REG_GPU_KERNEL_ONE( - SliceGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SliceGradGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGradGpuKernel, int) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SliceGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h deleted file mode 100644 index bf24272d93..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/slice_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SliceGradGpuKernel : public GpuKernel { - public: - SliceGradGpuKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {} - ~SliceGradGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *dy = GetDeviceAddress(inputs, 0); - T *dx = GetDeviceAddress(outputs, 0); - FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast(stream_ptr)); - if (is_strided_slice_) { - CalStridedSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, strides_, dx, - reinterpret_cast(stream_ptr)); - } else { - CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx, - reinterpret_cast(stream_ptr)); - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "StridedSliceGrad") { - is_strided_slice_ = true; - input_shape_ = GetAttr>(kernel_node, "shapex"); - for (auto i = input_shape_.size(); i < 4; i++) { - (void)input_shape_.insert(input_shape_.begin(), 1); - } - strides_ = GetAttr>(kernel_node, "strides"); - for (auto i = strides_.size(); i < 4; i++) { - (void)strides_.insert(strides_.begin(), 1); - } - size_ = GetAttr>(kernel_node, "end"); - } else { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - ShapeNdTo4d(input_shape, &input_shape_); - size_ = GetAttr>(kernel_node, "size"); - } - - auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - ShapeNdTo4d(dy_shape, &dy_shape_); - begin_ = GetAttr>(kernel_node, "begin"); - DealParam(); - input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); - - output_size_ = sizeof(T); - for (auto x : dy_shape_) { - output_size_ = output_size_ * IntToSize(x); - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(output_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGradGpuKernel only support 4d or lower."; - return false; - } - if (input_shape.size() == 0) { - MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported."; - return false; - } - return true; - } - void DealParam() { - for (auto i = begin_.size(); i < 4; i++) { - (void)begin_.insert(begin_.begin(), 0); - } - for (auto i = size_.size(); i < 4; i++) { - (void)size_.insert(size_.begin(), 1); - } - for (size_t i = 0; i < begin_.size(); i++) { - if (begin_[i] < 0) { - begin_[i] = begin_[i] + input_shape_[i]; - } - } - for (size_t i = 0; i < size_.size(); i++) { - if (size_[i] < 0) { - size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; - } - } - } - std::vector begin_; - std::vector size_; - std::vector strides_; - std::vector input_shape_; - std::vector dy_shape_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - bool is_strided_slice_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.cc deleted file mode 100644 index 338e7a4093..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.cc +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/transpose_gpu_kernel.h" -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TransposeGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - TransposeGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h deleted file mode 100644 index 61be9b68fe..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/transpose_impl.cuh" -namespace mindspore { -namespace kernel { -template -class TransposeGpuFwdKernel : public GpuKernel { - public: - TransposeGpuFwdKernel() : shape_size_(0), input_size_(0), output_size_(0), workspace_size_(0) {} - ~TransposeGpuFwdKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - int *input_shape = GetDeviceAddress(workspace, 0); - int *input_axis = GetDeviceAddress(workspace, 1); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis failed"); - int size = SizeToInt(input_size_ / sizeof(T)); - CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - shape_size_ = input_shape.size(); - if (shape_size_ > TRANSPOSE_MAX_DIMENSION) { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION - << "-D inputs."; - } - - input_size_ = 1; - for (size_t i = 0; i < shape_size_; i++) { - input_size_ *= input_shape[i]; - input_shape_.push_back(input_shape[i]); - } - input_size_ *= sizeof(T); - output_size_ = input_size_; - auto perm = GetAttr>(kernel_node, "perm"); - for (size_t j = 0; j < perm.size(); j++) { - input_axis_.push_back(perm[j]); - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - workspace_size_ = shape_size_ * sizeof(int); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - return; - } - - private: - std::vector input_shape_; - std::vector input_axis_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t shape_size_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc deleted file mode 100644 index 9962d55988..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - UnsortedSegmentSum, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - UnsortedSegmentSumGpuKernel, float, int) - -MS_REG_GPU_KERNEL_TWO( - UnsortedSegmentSum, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - UnsortedSegmentSumGpuKernel, float, int64_t) - -MS_REG_GPU_KERNEL_TWO( - UnsortedSegmentSum, - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - UnsortedSegmentSumGpuKernel, int, int) - -MS_REG_GPU_KERNEL_TWO( - UnsortedSegmentSum, - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - UnsortedSegmentSumGpuKernel, int, int64_t) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h deleted file mode 100644 index a20375ee29..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/unsorted_segment_sum.cuh" - -namespace mindspore { -namespace kernel { -template -class UnsortedSegmentSumGpuKernel : public GpuKernel { - public: - UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {} - ~UnsortedSegmentSumGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input_addr = GetDeviceAddress(inputs, 0); - S *indices_addr = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemsetAsync(output_addr, 0, outputs[0]->size, reinterpret_cast(stream_ptr)), - "cudaMemSet Failed"); - UnsortedSegmentSum(input_dim0_, input_dim1_, output_dim0_, output_dim1_, input_addr, indices_addr, output_addr, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); - - auto axis = ids_shapes.size(); - for (size_t i = 0; i < input_shapes.size(); i++) { - if (i < axis) { - input_dim0_ *= input_shapes[i]; - } else { - input_dim1_ *= input_shapes[i]; - } - } - - output_dim0_ = output_shapes[0]; - for (size_t j = 1; j < output_shapes.size(); j++) { - output_dim1_ *= output_shapes[j]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_dim0_ * input_dim1_ * sizeof(T)); - input_size_list_.push_back(input_dim0_ * sizeof(S)); - output_size_list_.push_back(output_dim0_ * output_dim1_ * sizeof(T)); - } - - private: - size_t input_dim0_; - size_t input_dim1_; - size_t output_dim0_; - size_t output_dim1_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_ diff --git a/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.cc deleted file mode 100644 index 5468aa6500..0000000000 --- a/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.cc +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/control/recv_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h deleted file mode 100644 index 12b4eed132..0000000000 --- a/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class RecvGpuKernel : public GpuKernel { - public: - RecvGpuKernel() {} - ~RecvGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &, const std::vector &, const std::vector &, - void *) override { - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamWaitEvent(wait_stream_, wait_event_, 0), "Waiting cuda event failed."); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - wait_stream_ = reinterpret_cast(GetAttr(kernel_node, "wait_event_stream")); - wait_event_ = reinterpret_cast(GetAttr(kernel_node, "wait_event")); - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - return; - } - - private: - cudaStream_t wait_stream_{nullptr}; - cudaEvent_t wait_event_{nullptr}; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc deleted file mode 100644 index c417c30bb3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/control/send_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h deleted file mode 100644 index a26e41aa1e..0000000000 --- a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SendGpuKernel : public GpuKernel { - public: - SendGpuKernel() {} - ~SendGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &, const std::vector &, const std::vector &, - void *) override { - CHECK_CUDA_RET_WITH_EXCEPT(cudaEventRecord(record_event_, record_stream_), "Recording cuda event failed."); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - record_stream_ = reinterpret_cast(GetAttr(kernel_node, "record_event_stream")); - record_event_ = reinterpret_cast(GetAttr(kernel_node, "record_event")); - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - return; - } - - private: - cudaStream_t record_stream_{nullptr}; - cudaEvent_t record_event_{nullptr}; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu deleted file mode 100644 index 3ec63ee03a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/adam_impl.cuh" - -template -__device__ __forceinline__ T SqrtFunc(T input) { - return sqrt(input); -} - -template <> -__device__ __forceinline__ half SqrtFunc(half input) { - return hsqrt(input); -} - -template -__global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, - const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, - T *m, T *v) { - const T one = static_cast(1.0); - const T new_learning_rate = learning_rate[0] * SqrtFunc(one - beta2_power[0]) / (one - beta1_power[0]); - - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - m[i] += (gradient[i] - m[i]) * (one - beta1[0]); - v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2[0]); - variable[i] -= new_learning_rate * m[i] / (SqrtFunc(v[i]) + epsilon[0]); - } -} - -template -void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate, - const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream) { - ApplyAdamKernel<<>>( - size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v); -} - -template void ApplyAdam(const size_t size, const float *gradient, const float *beta1_power, - const float *beta2_power, const float *learning_rate, const float *beta1, - const float *beta2, const float *epsilon, float *variable, float *m, float *v, - cudaStream_t cuda_stream); -template void ApplyAdam(const size_t size, const half *gradient, const half *beta1_power, const half *beta2_power, - const half *learning_rate, const half *beta1, const half *beta2, const half *epsilon, - half *variable, half *m, half *v, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cuh deleted file mode 100644 index f48a113c26..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate, - const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu deleted file mode 100644 index dfadaa09d6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "adam_weight_decay_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -__global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1, - const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2, - const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v, - T *param, T *gradient) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) { - float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i]; - float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i]; - float update = next_m / (sqrt(next_v) + epsilon[0]); - if (need_decay && weight_decay != nullptr) { - update += weight_decay[0] * param[i]; - } - param[i] -= lr[0] * update; - m[i] = next_m; - v[i] = next_v; - } -} - -template -void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1, - const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr, - const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) { - AdamWeightDecayKernel<<>>( - element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param, - gradient); -} - -template void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, - const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2, - const float *epsilon, const float *lr, const float *weight_decay, float *m, float *v, - float *param, float *gradient, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cu deleted file mode 100755 index e8fab27dda..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cu +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "argmax_impl.cuh" -#include "device/gpu/cuda_common.h" -#include "include/cuda_fp16.h" -template -__global__ void Argmax1D(const T* input, const int channel_size, int* output) { - int max_index = 0; - T max = input[0]; - for (int pos = 1; pos < channel_size; pos++) { - if (max < input[pos]) { - max = input[pos]; - max_index = pos; - } - } - output[0] = max_index; - return; -} -template -__global__ void ArgmaxDefault2D(const T* input, const int batch_size, const int channel_size, int* output) { - int pos; - int max_index; - T max; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) { - max = input[i * channel_size]; - max_index = 0; - for (int j = 1; j < channel_size; j++) { - pos = i * channel_size + j; - if (max < input[pos]) { - max = input[pos]; - max_index = j; - } - } - - output[i] = max_index; - } - return; -} -template -__global__ void ArgmaxAxis2D(const T* input, const int batch_size, const int channel_size, int* output) { - int pos; - int max_index; - T max; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { - max = input[i]; - max_index = 0; - for (int j = 1; j < batch_size; j++) { - pos = j * channel_size + i; - if (max < input[pos]) { - max = input[pos]; - max_index = j; - } - } - output[i] = max_index; - } - return; -} -template -void CalArgmax(const T* input, const int batch_size, const int channel_size, const int axis, int* output, - cudaStream_t cuda_stream) { - if (batch_size == 0) { - Argmax1D<<<1, 1, 0, cuda_stream>>>(input, channel_size, output); - } else if (axis == 1) { - ArgmaxDefault2D<<>>(input, batch_size, channel_size, output); - } else { - ArgmaxAxis2D<<>>(input, batch_size, channel_size, output); - } - return; -} - -template void CalArgmax(const float* input, const int batch_size, const int channel_size, const int axis, - int* output, cudaStream_t cuda_stream); -template void CalArgmax(const half* input, const int batch_size, const int channel_size, const int axis, - int* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu deleted file mode 100644 index 3313fc6853..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cu +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "argmaxwithvalue_impl.cuh" -#include "device/gpu/cuda_common.h" -#include "include/cuda_fp16.h" -template -__global__ void ArgmaxWithValue(const T* input, const int bound, int outerSize, int innerSize, S* index, - T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outerSize); pos += blockDim.x * gridDim.x) { - int inputOutterOffset = pos * innerSize * bound; - int outputOutterOffset = pos * innerSize; - for (int j = 0; j < innerSize; j++) { - auto outputInnerOffset = outputOutterOffset + j; - S idx = 0; - T maxData = input[j + inputOutterOffset]; - for (S c = 0; c < bound; c++) { - int offset = j + c * innerSize; - auto inputData = input[inputOutterOffset + offset]; - idx = inputData > maxData ? c : idx; - maxData = inputData > maxData ? inputData : maxData; - } - output[outputInnerOffset] = maxData; - index[outputInnerOffset] = idx; - } - } - return; -} - -template -void CalArgmaxWithValue(const T* input, const int bound_, const int outerSize_, const int innerSize_, - S* index, T* output, cudaStream_t cuda_stream) { - ArgmaxWithValue<<>>(input, bound_, outerSize_, innerSize_, - index, output); - return; -} - -template void CalArgmaxWithValue(const float* input, const int bound_, const int outerSize_, - const int innerSize_, int* index, float* output, - cudaStream_t cuda_stream); -template void CalArgmaxWithValue(const half* input, const int bound_, const int outerSize_, - const int innerSize_, int* index, half* output, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cu deleted file mode 100644 index d44ad99202..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cu +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "assign_add_impl.cuh" -#include "device/gpu/cuda_common.h" -#include "include/cuda_fp16.h" -template -__global__ void AssignAdd(const size_t size, T* ref, const T* value, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - output[pos] = ref[pos] + value[pos]; - ref[pos] = output[pos]; - } - return; -} - -template -void CalAssignAdd(const size_t size, T* ref, const T* value, T* output, cudaStream_t cuda_stream) { - AssignAdd<<>>(size, ref, value, output); - - return; -} - -template void CalAssignAdd(const size_t size, float* ref, const float* value, float* output, - cudaStream_t cuda_stream); -template void CalAssignAdd(const size_t size, half* ref, const half* value, half* output, - cudaStream_t cuda_stream); -template void CalAssignAdd(const size_t size, int* ref, const int* value, int* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh deleted file mode 100644 index c3ce08dfd0..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ - -#include "device/gpu/cuda_common.h" -template -void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean, - const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn, - size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream); -template -void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, - const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, - T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream); -template -void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std, - const T *running_mean, const T *running_std, const T *gamma, T *d_gamma, - T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream); -template -void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N, - size_t C, size_t H, size_t W, cudaStream_t cuda_stream); - -template -void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H, - size_t W, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cu deleted file mode 100755 index ddc2803f56..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cu +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "batchnorm_fold_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -__global__ void UpdateRunningStd(int channel_size, const double epsilon, T* running_std) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { - running_std[i] = sqrtf(running_std[i] + epsilon); - } - return; -} - -template -__global__ void UpdateBatchStd(int channel_size, T* batch_std) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { - batch_std[i] = 1 / batch_std[i]; - } - return; -} - -template -__global__ void CalDx(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, const T* batch_std, - int batch_size, int channel_size, int height, int width, T* dx) { - int n = batch_size * channel_size * height * width; - int normal_size = batch_size * height * width; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { - int channel_index = i / (height * width) % channel_size; - dx[i] = d_batch_mean[channel_index] / normal_size + - d_batch_std[channel_index] * (x[i] - batch_mean[channel_index]) / batch_std[channel_index] / normal_size; - } - return; -} - -template -void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream) { - UpdateRunningStd<<>>(channel_size, epsilon, running_std); - return; -} - -template void CalUpdateRunningStd(int channel_size, double epsilon, float* running_std, - cudaStream_t cuda_stream); - -template -void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream) { - UpdateBatchStd<<>>(channel_size, batch_std); - return; -} - -template void CalUpdateBatchStd(int channel_size, float* batch_std, cudaStream_t cuda_stream); - -template -void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, - const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx, - cudaStream_t cuda_stream) { - CalDx<<>>( - d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_size, channel_size, height, width, dx); -} - -template void CalBatchNormFoldGrad(const float* d_batch_mean, const float* d_batch_std, const float* x, - const float* batch_mean, const float* batch_std, int batch_size, - int channel_size, int height, int width, float* dx, cudaStream_t cuda_stream); - -template -void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream) { - thrust::device_ptr dev_ptr(array); - thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + size, tofill); -} - -template void ThrustFillWith(float* array, int size, float tofill, cudaStream_t cuda_stream); - diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cu deleted file mode 100644 index 5aa087e7f5..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cu +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -struct MinimumGradFunc { - __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) { - if (x1 < x2) { - atomicAdd(dx1, dy); - } else { - atomicAdd(dx2, dy); - } - } -}; - -template -struct MaximumGradFunc { - __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) { - if (x1 > x2) { - atomicAdd(dx1, dy); - } else { - atomicAdd(dx2, dy); - } - } -}; - -__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } - -template -__device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &l1, const int &l2, const int &l3, - const int &r0, const int &r1, const int &r2, const int &r3, - const int &d0, const int &d1, const int &d2, const int &d3, - const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) { - int i = pos / (d1 * d2 * d3) % d0; - int j = pos / (d2 * d3) % d1; - int k = pos / d3 % d2; - int l = pos % d3; - - int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3); - int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3); - Func()(x1[l_index], x2[r_index], dy[pos], dx1 + l_index, dx2 + r_index); - } -} - -template -__global__ void BroadcastGradKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1, - const int r2, const int r3, const int d0, const int d1, const int d2, const int d3, - enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, - T *dx2) { - switch (op) { - case BROADCAST_GRAD_TYPE_MINIMUM: - return BroadcastGradOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy, - dx1, dx2); - case BROADCAST_GRAD_TYPE_MAXIMUM: - return BroadcastGradOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy, - dx1, dx2); - } -} - -template -void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, - cudaStream_t stream) { - int size = d0 * d1 * d2 * d3; - BroadcastGradKernel<<>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, - x1, x2, dy, dx1, dx2); -} - -template -__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *x1, const T *x2, const T *dy, T *dx1, - T *dx2) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { - Func()(x1[pos], x2[pos], dy[pos], dx1 + pos, dx2 + pos); - } -} - -template -__global__ void NoBroadcastGradKernel(const int nums, enum BroadcastGradOpType op, const T *x1, const T *x2, - const T *dy, T *dx1, T *dx2) { - switch (op) { - case BROADCAST_GRAD_TYPE_MINIMUM: - return NoBroadcastOperator>(nums, x1, x2, dy, dx1, dx2); - case BROADCAST_GRAD_TYPE_MAXIMUM: - return NoBroadcastOperator>(nums, x1, x2, dy, dx1, dx2); - } -} - -template -void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, - T *dx2, cudaStream_t stream) { - NoBroadcastGradKernel<<>>(nums, op, x1, x2, dy, dx1, dx2); -} - -template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const float *x1, const float *x2, - const float *dy, float *dx1, float *dx2, cudaStream_t stream); -template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const int *x1, const int *x2, - const int *dy, int *dx1, int *dx2, cudaStream_t stream); -template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1, - float *dx2, cudaStream_t stream); -template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastGradOpType op, const int *x1, const int *x2, const int *dy, int *dx1, - int *dx2, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cuh deleted file mode 100644 index d154eddd4c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cuh +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ - -#include "device/gpu/cuda_common.h" - -enum BroadcastGradOpType { - BROADCAST_GRAD_TYPE_MAXIMUM = 0, - BROADCAST_GRAD_TYPE_MINIMUM = 1, - BROADCAST_GRAD_TYPE_INVALID = 0xffffffff, -}; - -template -void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, - cudaStream_t stream); - -template -void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, - T *dx2, cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu deleted file mode 100644 index afa94fc56c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/broadcast_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -struct GreaterFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; } -}; - -template -struct LessFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; } -}; - -template -struct MinimumFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; } -}; - -template -struct MaximumFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; } -}; - -template -struct PowerFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); } -}; - -template <> -struct PowerFunc { - __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { - return __float2half(pow(__half2float(lhs), __half2float(rhs))); - } -}; - -template -struct RealDivFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); } -}; - -template -struct MulFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); } -}; - -template -struct SubFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); } -}; - -template -struct AddFunc { - __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); } -}; - -template <> -struct PowerFunc { - // invalid branch - __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; } -}; - -__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } - -template -__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3, - const int &r0, const int &r1, const int &r2, const int &r3, - const int &d0, const int &d1, const int &d2, const int &d3, - const T *input0, const T *input1, S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) { - int i = pos / (d1 * d2 * d3) % d0; - int j = pos / (d2 * d3) % d1; - int k = pos / d3 % d2; - int l = pos % d3; - - int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3); - int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3); - output[pos] = Func()(input0[l_index], input1[r_index]); - } -} - -template -__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1, - const int r2, const int r3, const int d0, const int d1, const int d2, const int d3, - enum BroadcastOpType op, const T *input0, const T *input1, S *output) { - switch (op) { - case BROADCAST_TYPE_GREATER: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - case BROADCAST_TYPE_LESS: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - case BROADCAST_TYPE_MINIMUM: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - case BROADCAST_TYPE_MAXIMUM: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - case BROADCAST_TYPE_POWER: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - case BROADCAST_TYPE_REALDIV: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - case BROADCAST_TYPE_MUL: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - case BROADCAST_TYPE_SUB: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - case BROADCAST_TYPE_ADD: - return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, - output); - } -} - -template -void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, - const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, - const T *input0, const T *input1, S *output, cudaStream_t stream) { - int size = d0 * d1 * d2 * d3; - BroadcastKernel<<>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, - input0, input1, output); -} - -template -__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { - output[pos] = Func()(input0[pos], input1[pos]); - } -} - -template -__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1, - S *output) { - switch (op) { - case BROADCAST_TYPE_GREATER: - return NoBroadcastOperator>(nums, input0, input1, output); - case BROADCAST_TYPE_LESS: - return NoBroadcastOperator>(nums, input0, input1, output); - case BROADCAST_TYPE_MINIMUM: - return NoBroadcastOperator>(nums, input0, input1, output); - case BROADCAST_TYPE_MAXIMUM: - return NoBroadcastOperator>(nums, input0, input1, output); - case BROADCAST_TYPE_POWER: - return NoBroadcastOperator>(nums, input0, input1, output); - case BROADCAST_TYPE_REALDIV: - return NoBroadcastOperator>(nums, input0, input1, output); - case BROADCAST_TYPE_MUL: - return NoBroadcastOperator>(nums, input0, input1, output); - case BROADCAST_TYPE_SUB: - return NoBroadcastOperator>(nums, input0, input1, output); - case BROADCAST_TYPE_ADD: - return NoBroadcastOperator>(nums, input0, input1, output); - } -} - -template -void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output, - cudaStream_t stream) { - NoBroadcastKernel<<>>(nums, op, input0, input1, output); -} - -template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastOpType op, const float *input0, const float *input1, bool *output, - cudaStream_t stream); -template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastOpType op, const float *input0, const float *input1, float *output, - cudaStream_t stream); -template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastOpType op, const half *input0, const half *input1, bool *output, - cudaStream_t stream); -template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastOpType op, const half *input0, const half *input1, half *output, - cudaStream_t stream); -template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, - const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, - enum BroadcastOpType op, const int *input0, const int *input1, int *output, - cudaStream_t stream); -template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, - bool *output, cudaStream_t stream); -template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, - float *output, cudaStream_t stream); -template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, - bool *output, cudaStream_t stream); -template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, - half *output, cudaStream_t stream); -template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, - int *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh deleted file mode 100644 index 5f6992511d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ - -#include "device/gpu/cuda_common.h" - -enum BroadcastOpType { - BROADCAST_TYPE_GREATER = 0, - BROADCAST_TYPE_LESS = 1, - BROADCAST_TYPE_MAXIMUM = 2, - BROADCAST_TYPE_MINIMUM = 3, - BROADCAST_TYPE_POWER = 4, - BROADCAST_TYPE_REALDIV = 5, - BROADCAST_TYPE_MUL = 6, - BROADCAST_TYPE_SUB = 7, - BROADCAST_TYPE_ADD = 8, - BROADCAST_TYPE_INVALID = 0xffffffff, -}; - -template -void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, - const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, - const T *input0, const T *input1, S *output, cudaStream_t stream); - -template -void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output, - cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu deleted file mode 100755 index 5cccf183ea..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "kernel/gpu/cuda_impl/concatv2_impl.cuh" -template -__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2); - int m = pos % (w1 + w2); - output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m]; - } - return; -} - -template -__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2 + w3); - int m = pos % (w1 + w2 + w3); - output[pos] = m < w1 ? input_1[n * w1 + m] : - m < w1 + w2 ? input_2[n * w2 + m - w1] : - input_3[n * w3 + m - w1 - w2]; - } - return; -} - -template -__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int n = pos / (w1 + w2 + w3 + w4); - int m = pos % (w1 + w2 + w3 + w4); - output[pos] = m < w1 ? input_1[n * w1 + m] : - m < w1 + w2 ? input_2[n * w2 + m - w1]: - m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]: - input_4[n * w4 + m - w1 - w2 - w3]; - } - return; -} - -template -void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, - cudaStream_t cuda_stream) { - Concat<<>>(size, w1, w2, input_1, input_2, output); - return; -} - -template -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output, - cudaStream_t cuda_stream) { - Concat<<>>(size, w1, w2, w3, input_1, input_2, input_3, output); - return; -} - -template -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, - cudaStream_t cuda_stream) { - Concat<<>>(size, w1, w2, w3, w4, input_1, - input_2, input_3, input_4, output); - return; -} - -template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, - half* output, cudaStream_t cuda_stream); - -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const float* input_1, const float* input_2, const float* input_3, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const int* input_1, const int* input_2, const int* input_3, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const half* input_1, const half* input_2, const half* input_3, - half* output, cudaStream_t cuda_stream); - -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const float* input_1, const float* input_2, const float* input_3, const float* input_4, - float* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const int* input_1, const int* input_2, const int* input_3, const int* input_4, - int* output, cudaStream_t cuda_stream); -template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const half* input_1, const half* input_2, const half* input_3, const half* input_4, - half* output, cudaStream_t cuda_stream); - diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh deleted file mode 100755 index b6932aa4a1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, - cudaStream_t cuda_stream); -template -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, - const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream); -template -void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, - const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cu deleted file mode 100755 index ac2f99ed9a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cu +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "correction_mul_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -__global__ void CorrectionMul(const T* weight, const T* gamma, const T* running_std, const int batchsize, const int chw, - T* output) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batchsize * chw; i += blockDim.x * gridDim.x) { - int n = i / chw; - output[i] = weight[i] * gamma[n] / running_std[n]; - } - return; -} - -template -__global__ void Mul(int N, const T* a, const T* b, T* c) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { - c[i] = a[i] * b[i]; - } - return; -} - -template -__global__ void Reduce(int N, int CHW, const T* tmp, const T* running_std, T* d_gamma) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { - d_gamma[i] = thrust::reduce(thrust::seq, tmp + i * CHW, tmp + (i + 1) * CHW, 0.f, thrust::plus()); - d_gamma[i] = d_gamma[i] / running_std[i]; - } - return; -} - -template -void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int N, int C, int H, int W, T* output, - cudaStream_t cuda_stream) { - CorrectionMul<<>>(weight, gamma, running_std, N, C * H * W, - output); -} - -template void CalCorrectionMul(const float* weight, const float* gamma, const float* running_std, int N, int C, - int H, int W, float* output, cudaStream_t cuda_stream); - -template -void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int N, int C, int H, int W, T* d_gamma, - T* tmp, cudaStream_t cuda_stream) { - Mul<<>>(N * C * H * W, d_out, weight, tmp); - Reduce<<>>(N, C * H * W, tmp, running_std, d_gamma); -} - -template void CalCorrectionMulGrad(const float* d_out, const float* weight, const float* running_std, int N, - int C, int H, int W, float* d_gamma, float* tmp, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh deleted file mode 100644 index 54ae072892..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ - -#include "device/gpu/cuda_common.h" - -template -void CrossEntropyWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *loss, - cudaStream_t cuda_stream); - -template -void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *grad, cudaStream_t cuda_stream); - -template -void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, - T *dlogits, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh deleted file mode 100644 index f89d42ce49..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ - -#include "device/gpu/cuda_common.h" -template -void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float keep_prob, - cudaStream_t cuda_stream); -template -void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float keep_prob, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cu deleted file mode 100755 index 38dd79c441..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cu +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "equalcount_impl.cuh" -#include "device/gpu/cuda_common.h" -template -__global__ void EqualCount(const int size, const T* input1, const T* input2, T* output) { - T equal_count = 0; - - for (int i = 0; i < size; i++) { - if (input1[i] == input2[i]) { - equal_count++; - } - } - - output[0] = equal_count; - return; -} -template -void CalEqualCount(const int size, const T* input1, const T* input2, T* output, cudaStream_t cuda_stream) { - EqualCount<<<1, 1, 0, cuda_stream>>>(size, input1, input2, output); - return; -} - -template void CalEqualCount(const int size, const int* input1, const int* input2, int* output, - cudaStream_t cuda_stream); -template void CalEqualCount(const int size, const float* input1, const float* input2, float* output, - cudaStream_t cuda_stream); -template void CalEqualCount(const int size, const half* input1, const half* input2, half* output, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh deleted file mode 100644 index ad2e387b08..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ - -#include "device/gpu/cuda_common.h" - -void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max, - float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric, - cudaStream_t cuda_stream); - -void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num, - const float *nudge_min, const float *nudge_max, const float *scale, - cudaStream_t cuda_stream); - -void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num, - const int channel_num, const float *nudge_min, const float *nudge_max, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh deleted file mode 100644 index dda95ed781..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ - -#include "device/gpu/cuda_common.h" - -void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, - float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream); - -void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, - const float *nudge_max, const float *scale, cudaStream_t cuda_stream); - -void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, - const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu deleted file mode 100644 index c2fd5ecd70..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "include/cuda_runtime.h" -#include "kernel/gpu/cuda_impl/float_status_impl.cuh" - -template -__global__ void IsNan(const size_t size, const T* input, bool* out) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (isnan(input[pos])) { - out[pos] = true; - } else { - out[pos] = false; - } - } - return; -} -template <> -__global__ void IsNan(const size_t size, const half* input, bool* out) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (__hisnan(input[pos])) { - out[pos] = true; - } else { - out[pos] = false; - } - } - return; -} - -template -__global__ void IsInf(const size_t size, const T* input, bool* out) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (isinf(input[pos]) != 0) { - out[pos] = true; - } else { - out[pos] = false; - } - } - return; -} -template <> -__global__ void IsInf(const size_t size, const half* input, bool* out) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (__hisinf(input[pos]) != 0) { - out[pos] = true; - } else { - out[pos] = false; - } - } - return; -} - -template -__global__ void IsFinite(const size_t size, const T* input, bool* out) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (isinf(input[pos]) == 0 && !isnan(input[pos])) { - out[pos] = true; - } else { - out[pos] = false; - } - } - return; -} -template <> -__global__ void IsFinite(const size_t size, const half* input, bool* out) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (__hisinf(input[pos]) == 0 && !__hisnan(input[pos])) { - out[pos] = true; - } else { - out[pos] = false; - } - } - return; -} - -template -__global__ void FloatStatus(const size_t size, const T* input, T* out) { - out[0] = 0; - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (isinf(input[pos]) != 0 || isnan(input[pos])) { - out[0] = 1; - } - } - return; -} -template <> -__global__ void FloatStatus(const size_t size, const half* input, half* out) { - out[0] = 0; - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) { - out[0] = 1; - } - } - return; -} - -template -void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) { - FloatStatus<<>>(size, input, output); - return; -} -template -void CalIsNan(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { - IsNan<<>>(size, input, output); - return; -} -template -void CalIsInf(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { - IsInf<<>>(size, input, output); - return; -} -template -void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { - IsFinite<<>>(size, input, output); - return; -} - -template void CalFloatStatus(const size_t size, const float* input, float* output, cudaStream_t cuda_stream); -template void CalFloatStatus(const size_t size, const half* input, half* output, cudaStream_t cuda_stream); -template void CalIsInf(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); -template void CalIsInf(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); -template void CalIsNan(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); -template void CalIsNan(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); -template void CalIsFinite(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); -template void CalIsFinite(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh deleted file mode 100644 index da488ff937..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ -#include "device/gpu/cuda_common.h" -template -void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream); -template -void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream); -template -void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream); -template -void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu deleted file mode 100644 index ea6ffdbbdc..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/ftrl_impl.cuh" - -template -__device__ __forceinline__ T PowFunc(T x, T y) { - return pow(x, y); -} - -template <> -__device__ __forceinline__ half PowFunc(half x, half y) { - return __float2half(pow(__half2float(x), __half2float(y))); -} - -template -__device__ __forceinline__ bool CompareFunc(T x, T y) { - return abs(x) > y; -} - -template <> -__device__ __forceinline__ bool CompareFunc(half x, half y) { - return abs(__half2float(x)) > __half2float(y); -} - -template -__device__ __forceinline__ T Sgn(T x) { - return static_cast(x != 0 ? (x > 0 ? 1 : -1) : 0); -} - -template <> -__device__ __forceinline__ half Sgn(half x) { - return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0); -} - -template -__global__ void ApplyFtrlKernel(const size_t size, const T *gradient, const T *learning_rate, - const T *l1_regularization, const T *l2_regularization, const T *learning_rate_power, - T *variable, T *accumulation, T *linear) { - const T two = static_cast(2.0); - const T learning_rate_power_val = -learning_rate_power[0]; - - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - const T cur_accumulation = accumulation[i] + gradient[i] * gradient[i]; - const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val); - const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val); - const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate[0]; - - linear[i] += gradient[i] - sigma * variable[i]; - variable[i] = CompareFunc(linear[i], l1_regularization[0]) - ? ((l1_regularization[0] * Sgn(linear[i]) - linear[i]) / - (cur_accumulation_power / learning_rate[0] + two * l2_regularization[0])) - : static_cast(0); - accumulation[i] = cur_accumulation; - } -} - -template -void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, - const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, - cudaStream_t cuda_stream) { - ApplyFtrlKernel<<>>(size, gradient, learning_rate, l1_regularization, - l2_regularization, learning_rate_power, variable, - accumulation, linear); -} - -template void ApplyFtrl(const size_t size, const float *gradient, const float *learning_rate, - const float *l1_regularization, const float *l2_regularization, - const float *learning_rate_power, float *variable, float *accumulation, float *linear, - cudaStream_t cuda_stream); -template void ApplyFtrl(const size_t size, const half *gradient, const half *learning_rate, - const half *l1_regularization, const half *l2_regularization, - const half *learning_rate_power, half *variable, half *accumulation, half *linear, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh deleted file mode 100644 index ba4a8fa816..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, - const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cu deleted file mode 100755 index 6bde359d9b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cu +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "kernel/gpu/cuda_impl/gather.cuh" -#include "device/gpu/cuda_common.h" -template -__global__ void GatherKernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, - size_t output_dim2, size_t input_dim1) { - int num = output_dim0 * output_dim1 * output_dim2; - int i, j, k; - for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; - write_index += blockDim.x * gridDim.x) { - i = write_index / (output_dim1 * output_dim2) % output_dim0; - j = write_index / output_dim2 % output_dim1; - k = write_index % output_dim2; - - if ((indices[j] >= 0) && (indices[j] < input_dim1)) { - int read_index = i * input_dim1 * output_dim2 + indices[j] * output_dim2 + k; - output[write_index] = input[read_index]; - } else { - output[write_index] = 0; - } - } - - return; -} -template -void Gather(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream) { - int size = output_dim0 * output_dim1 * output_dim2; - GatherKernel<<>>(input, indices, output, output_dim0, output_dim1, - output_dim2, input_dim1); - return; -} - -template void Gather(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1, - size_t output_dim2, size_t input_dim1, cudaStream_t stream); - -template void Gather(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1, - size_t output_dim2, size_t input_dim1, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu deleted file mode 100644 index e460caec9e..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/gelu_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -__global__ void GeluKernel(size_t size, T *input_addr, T *output_addr) { - // formula: - // gelu(x) = 0.5 * x * (1.0 + tanh(y)) - // tanh(y) = 2 / (1 + exp(-2y)) - 1) - // y = sqrt(2/pi) * (x + 0.044715 * x^3) - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - float x = input_addr[pos]; - float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); - output_addr[pos] = 0.5 * x * (1.0 + tanh_res); - } -} - -template <> -__global__ void GeluKernel(size_t size, half *input_addr, half *output_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - half x = input_addr[pos]; - float tanh_res = tanh(__half2float(half(0.7978845608) * (x + half(0.044715) * x * x * x))); - output_addr[pos] = half(0.5) * x * (half(1.0) + __float2half(tanh_res)); - } -} - -template <> -__global__ void GeluKernel(size_t size, half2 *input_addr, half2 *output_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - half2 x = input_addr[pos]; - float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); - float2 tanh_res; - tanh_res.x = tanh(tanh_param.x); - tanh_res.y = tanh(tanh_param.y); - output_addr[pos] = half2(0.5, 0.5) * x * (half2(1.0, 1.0) + __float22half2_rn(tanh_res)); - } -} - -template -void Gelu(size_t size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) { - GeluKernel<<>>(size, input_addr, output_addr); - return; -} - -template <> -void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream) { - if (size % 2 == 0) { - GeluKernel<<>>( - size / 2, reinterpret_cast(input_addr), reinterpret_cast(output_addr)); - } else { - GeluKernel<<>>(size, input_addr, output_addr); - } - return; -} - -template -__global__ void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr) { - // formula: - // dx = dy * y' - // y' = 0.5 * (1 + tanh(tanh_para)) + - // 0.5 * x * (1 - tanh(tanh_para) * tanh(tanh_para)) * mul_right - // tanh_para = sqrt(2/pi) * (x + 0.044715 * x^3) - // mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)) - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - T x = x_addr[pos]; - T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); - T mul_right = 0.7978845608 + 0.1070322244 * x * x; - T y_res = 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res * tanh_res) * mul_right; - dx_addr[pos] = dy_addr[pos] * y_res; - } -} - -template -__global__ void GeluGradKernel(size_t size, half2 *dy_addr, half2 *x_addr, half2 *dx_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - half2 x = x_addr[pos]; - float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); - float2 tanh_res; - tanh_res.x = tanh(tanh_param.x); - tanh_res.y = tanh(tanh_param.y); - half2 tanh_res_half = __float22half2_rn(tanh_res); - half2 mul_right = half2(0.7978845608, 0.7978845608) + half2(0.1070322244, 0.1070322244) * x * x; - half2 y_res = half2(0.5, 0.5) * (half2(1.0, 1.0) + tanh_res_half) + - half2(0.5, 0.5) * x * (half2(1.0, 1.0) - tanh_res_half * tanh_res_half) * mul_right; - dx_addr[pos] = dy_addr[pos] * y_res; - } -} - -template -__global__ void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - half x = x_addr[pos]; - half tanh_param = half(0.7978845608) * (x + half(0.044715) * x * x * x); - half tanh_res = __float2half_rn(tanh(__half2float(tanh_param))); - half mul_right = half(0.7978845608) + half(0.1070322244) * x * x; - half y_res = half(0.5) * (half(1.0) + tanh_res) + half(0.5) * x * (half(1.0) - tanh_res * tanh_res) * mul_right; - dx_addr[pos] = dy_addr[pos] * y_res; - } -} - -template -void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr, cudaStream_t cuda_stream) { - GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); -} - -template <> -void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream) { - if (size % 2 == 0) { - GeluGradKernel<<>>( - size / 2, reinterpret_cast(dy_addr), reinterpret_cast(x_addr), - reinterpret_cast(dx_addr)); - } else { - GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); - } - return; -} - -template void Gelu(size_t size, float *input_addr, float *output_addr, cudaStream_t cuda_stream); -template void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream); -template void GeluGradKernel(size_t size, float *dy_addr, float *x_addr, float *dx_addr, cudaStream_t cuda_stream); -template void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh deleted file mode 100644 index 7a8e1fae8a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ - -#include "device/gpu/cuda_common.h" -template -void Gelu(size_t input_size, T* input_addr, T* output_addr, cudaStream_t cuda_stream); - -template -void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu deleted file mode 100644 index e887b98eca..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu +++ /dev/null @@ -1,259 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh" -#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" - -constexpr int NUM_PER_THREAD_REDUCE = 4; -constexpr int WARP_SIZE = 32; - -template -inline __device__ T my_pow(T a, double b) { - return pow(a, static_cast(b)); -} - -template <> -inline __device__ half my_pow(half a, double b) { - return __float2half(pow(__half2float(a), static_cast(b))); -} - -template -inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim, - const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, - T* dg, T* db) { - int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; - for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { - for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { - int row = NUM_PER_THREAD_REDUCE * i + j; - if (row >= row_dim) { - return; - } - - int pos = row * col_dim + col; - dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); - db[0] += dy[pos]; - } - } -} - -template -inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) { - for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { - dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta); - db[0] += __shfl_down_sync(0xffffffff, db[0], delta); - } -} - -template -inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr, - T* db_addr) { - if (threadIdx.x >= row_dim) { - return; - } - - // load data to share memory - // thread(0, 32, 64, 96, ...) keep the data - DynamicSharedMem share_mem; - if (threadIdx.x % WARP_SIZE == 0) { - int offset = threadIdx.x / WARP_SIZE * 2; - share_mem.addr()[offset] = dg[0]; - share_mem.addr()[offset + 1] = db[0]; - } - __syncthreads(); - - for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { - if (threadIdx.x < stride) { - int offset = (threadIdx.x + stride) * 2; - share_mem.addr()[threadIdx.x * 2] += share_mem.addr()[offset]; - share_mem.addr()[threadIdx.x * 2 + 1] += share_mem.addr()[offset + 1]; - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - dg_addr[col] = share_mem.addr()[0]; - db_addr[col] = share_mem.addr()[1]; - } -} - -template -__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x, - const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) { - // row: [0:param_axis] - // col: [param_axis:] - // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i]) - // dg[j] = \Sigma_{j}dg[i][j] - for (int col = blockIdx.x; col < col_dim; col += gridDim.x) { - T dg = 0; - T db = 0; - GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); - GammaAndBetaWarpReduce(&dg, &db); - GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr); - } -} - -template -inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, - T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean, - const T* var, const T* gamma) { - int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; - for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { - for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { - int col = NUM_PER_THREAD_REDUCE * i + j; - if (col >= col_dim) { - return; - } - - int pos = row * col_dim + col; - int gamma_offset = pos % param_dim; - T v1 = dy[pos] * gamma[gamma_offset]; - T v2 = x[pos] - mean[row]; - - sum1[0] += -0.5 * v1 * v2 * my_pow(var[row] + epsilon, -1.5); - sum2[0] += v1; - sum3[0] += -2.0 * v2; - } - } -} - -template <> -inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, - half* sum1, half* sum2, half* sum3, const half* dy, const half* x, - const half* mean, const half* var, const half* gamma) { - int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; - for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { - for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { - int col = NUM_PER_THREAD_REDUCE * i + j; - if (col >= col_dim) { - return; - } - - int pos = row * col_dim + col; - int gamma_offset = pos % param_dim; - half v1 = dy[pos] * gamma[gamma_offset]; - half v2 = x[pos] - mean[row]; - - sum1[0] += __float2half(-0.5) * v1 * v2 * my_pow(var[row] + epsilon, -1.5); - sum2[0] += v1; - sum3[0] += __float2half(-2.0) * v2; - } - } -} - -template -inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { - for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { - sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta); - sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta); - sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta); - } -} - -template -inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) { - if (threadIdx.x >= col_dim) { - return; - } - - // load data to share memory - // thread(0, 32, 64, 96, ...) keep the data - if (threadIdx.x % WARP_SIZE == 0) { - int offset = threadIdx.x / WARP_SIZE * 3; - share_mem[offset] = sum1[0]; - share_mem[offset + 1] = sum2[0]; - share_mem[offset + 2] = sum3[0]; - } - __syncthreads(); - - for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { - if (threadIdx.x < stride) { - int offset = (threadIdx.x + stride) * 3; - share_mem[threadIdx.x * 3] += share_mem[offset]; - share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1]; - share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2]; - } - } - __syncthreads(); -} - -template -inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, - const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx, - const T* share_mem) { - for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { - int pos = (row * col_dim + col); - int gamma_offset = pos % param_dim; - T v1 = dy[pos] * gamma[gamma_offset]; - T v2 = x[pos] - mean[row]; - T v3 = my_pow(var[row] + epsilon, -0.5); - dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 + - (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim); - } -} - -template <> -inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, - const half* dy, const half* x, const half* mean, const half* var, const half* gamma, - half* dx, const half* share_mem) { - for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { - int pos = (row * col_dim + col); - int gamma_offset = pos % param_dim; - half v1 = dy[pos] * gamma[gamma_offset]; - half v2 = x[pos] - mean[row]; - half v3 = my_pow(var[row] + epsilon, -0.5); - dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 + - (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\ - * __float2half(1.0 / col_dim); - } -} - -template -__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy, - const T* x, const T* mean, const T* var, const T* gamma, T* dx) { - for (int row = blockIdx.x; row < row_dim; row += gridDim.x) { - T sum1 = 0; - T sum2 = 0; - T sum3 = 0; - DynamicSharedMem share_mem; - InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma); - InputWarpReduce(&sum1, &sum2, &sum3); - InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr()); - InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr()); - } -} - -template -void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, - const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) { - int share_mem_size = - ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); - InputPropKernel<<>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, - gamma, dx); - - share_mem_size = - ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T); - GammaAndBetaPropKernel<<>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db); -} - -template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, - const float* dy, const float* x, const float* mean, const float* var, const float* gamma, - float* dx, float* dg, float* db, cudaStream_t stream); -template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon, - const half* dy, const half* x, const half* mean, const half* var, const half* gamma, - half* dx, half* dg, half* db, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh deleted file mode 100644 index 9f7d57cdb9..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ - -#include "device/gpu/cuda_common.h" - -template -void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, - const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu deleted file mode 100644 index cfb60f0ba6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" - -constexpr int NUM_PER_THREAD_REDUCE = 4; -constexpr int WARP_SIZE = 32; - -template -inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val) { - // Welford Algorithm: - // \mu_k = \mu_{k-1} + (x_k - \mu_{k-1})/k - // \sigma_k^2 = \sigma_{k-1}^2 + (x_k - \mu_{k-1}) * (x_k - \mu_k) - num[0]++; - T mean_new = mean[0] + (val - mean[0]) / num[0]; - var[0] = var[0] + (val - mean[0]) * (val - mean_new); - mean[0] = mean_new; -} - -template -inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) { - T zero = 0; - if (n2 == zero) { - return; - } - - T count = n1[0] + n2; - v1[0] = v1[0] + v2 + (m1[0] - m2) * (m1[0] - m2) * n1[0] * n2 / count; - m1[0] = (n1[0] * m1[0] + n2 * m2) / count; - n1[0] = count; -} - -template -inline __device__ void ThreadReduce(const int &col_dim, const T *block_addr, T *mean, T *var, T *num) { - int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; - for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { - for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { - int pos = NUM_PER_THREAD_REDUCE * i + j; - if (pos >= col_dim) { - return; - } - MeanAndVarAccumulation(mean, var, num, block_addr[pos]); - } - } -} - -template -inline __device__ void WarpReduce(T *mean, T *var, T *num) { - for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { - T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta); - T var_other = __shfl_down_sync(0xffffffff, var[0], delta); - T num_other = __shfl_down_sync(0xffffffff, num[0], delta); - MeanAndVarMerge(mean, var, num, mean_other, var_other, num_other); - } -} - -template -inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr, - T *share_mem) { - if (threadIdx.x >= col_dim) { - return; - } - - // load data to share memory - // thread(0, 32, 64, 96, ...) keep the data - if (threadIdx.x % WARP_SIZE == 0) { - int offset = threadIdx.x / WARP_SIZE * 3; - share_mem[offset] = mean[0]; - share_mem[offset + 1] = var[0]; - share_mem[offset + 2] = num[0]; - } - __syncthreads(); - - for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { - if (threadIdx.x < stride) { - int offset = (threadIdx.x + stride) * 3; - MeanAndVarMerge(&share_mem[threadIdx.x * 3], &share_mem[threadIdx.x * 3 + 1], &share_mem[threadIdx.x * 3 + 2], - share_mem[offset], share_mem[offset + 1], share_mem[offset + 2]); - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - mean_addr[blockIdx.x] = share_mem[0]; - share_mem[1] /= col_dim; - var_addr[blockIdx.x] = share_mem[1]; - } -} - -template -inline __device__ void LayerNorm(const int &row, const int &col_dim, const int ¶m_dim, const T *x, - const T *share_mem, const T *gamma, const T *beta, const T epsilon, T *y) { - for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { - int pos = row * col_dim + col; - int i = pos % param_dim; - y[pos] = (x[pos] - share_mem[0]) / sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i]; - } -} - -template <> -inline __device__ void LayerNorm(const int &row, const int &col_dim, const int ¶m_dim, const half *x, - const half *share_mem, const half *gamma, const half *beta, const half epsilon, - half *y) { - for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { - int pos = row * col_dim + col; - int i = pos % param_dim; - y[pos] = (x[pos] - share_mem[0]) / hsqrt(share_mem[1] + epsilon) * gamma[i] + beta[i]; - } -} - -template -__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x, - const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) { - for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) { - T mean = 0; - T var = 0; - T num = 0; - const T *block_addr = x + row * col_dim; - DynamicSharedMem share_mem; - - ThreadReduce(col_dim, block_addr, &mean, &var, &num); - WarpReduce(&mean, &var, &num); - BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr()); - - __syncthreads(); - LayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y); - } -} - -template -void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const T &epsilon, const T *x, - const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) { - const dim3 block(row_dim); - const dim3 thread(256); - // keep the mean/var/num after warp reduce - int share_mem_size = - ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); - LayerNormKernel<<>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, - mean, var); -} - -template void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const float &epsilon, - const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var, - cudaStream_t stream); -template void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const half &epsilon, - const half *x, const half *gamma, const half *beta, half *y, half *mean, half *var, - cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh deleted file mode 100644 index c06a698384..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ - -#include "device/gpu/cuda_common.h" - -template -struct DynamicSharedMem; -template<> -struct DynamicSharedMem { - __device__ float *addr() { - extern __shared__ float addr_float[]; - return addr_float; - } -}; -template<> -struct DynamicSharedMem { - __device__ half *addr() { - extern __shared__ half addr_half[]; - return addr_half; - } -}; - -template -void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x, const T* gamma, - const T* beta, T* y, T* mean, T* var, cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu deleted file mode 100644 index 27b2cb0232..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include "minmax_update_impl.cuh" -#include "device/gpu/cuda_common.h" - -__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min, - float *output_max, const float min, const float max, - const float decay) { - output_min[0] = decay * (min) + (1 - decay) * (input_min[0]); - output_min[0] = input_min[0] > 0 ? 0 : input_min[0]; - output_max[0] = decay * (max) + (1 - decay) * (input_max[0]); - output_max[0] = input_max[0] < 0 ? 0 : input_max[0]; - return; -} - -__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) { - output_min[0] = min > 0 ? 0 : min; - output_max[0] = max < 0 ? 0 : max; - return; -} - -__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, - float *output_max, int channels, int per_channel_nums, bool ema, - float ema_decay) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { - thrust::pair sum = - thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); - if (ema) { - output_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i]; - output_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i]; - } else { - output_min[i] = sum.first[0]; - output_max[i] = sum.second[0]; - } - output_min[i] = input_min[i] > 0 ? 0 : input_min[i]; - output_max[i] = input_max[i] < 0 ? 0 : input_max[i]; - } - return; -} - -void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int total_num, const int channel_num, const float ema_decay, const bool ema, - cudaStream_t cuda_stream) { - int per_channel_num = total_num / channel_num; - UpdateInputMinMaxPerChannel<<>>( - input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay); - return; -} - -void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) { - float minel = 0.f; - float maxel = 0.f; - auto policy = thrust::cuda::par.on(cuda_stream); - thrust::pair, thrust::device_ptr> tuple; - tuple = - thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + total_num); - minel = tuple.first[0]; - maxel = tuple.second[0]; - - if (ema) { - UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel, - maxel, ema_decay); - } else { - UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel); - } - return; -} diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh deleted file mode 100644 index 5e9becab38..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ - -#include "device/gpu/cuda_common.h" - -void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int total_num, const int channel_num, const float ema_decay, const bool ema, - cudaStream_t cuda_stream); - -void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh deleted file mode 100755 index 5405f5ef1d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, - const S *momentum, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cu deleted file mode 100644 index cf5dc7ecd0..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cu +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "one_hot_impl.cuh" -#include "device/gpu/cuda_common.h" -template -__global__ void OneHotKernel(size_t size, const S *indices, size_t depth, const T *on_value, const T *off_value, - size_t left_dim_size, size_t right_dim_size, T *output) { - T on_v = *on_value; - T off_v = *off_value; - for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; - thread_idx += blockDim.x * gridDim.x) { - if (thread_idx < size) { - int left_idx = (thread_idx / (depth * right_dim_size)) % left_dim_size; - int d_idx = thread_idx / right_dim_size % depth; - int right_idx = thread_idx % right_dim_size; - int input_idx = left_idx * right_dim_size + right_idx; - int output_idx = left_idx * depth * right_dim_size + d_idx * right_dim_size + right_idx; - if (indices[input_idx] == d_idx) { - output[output_idx] = on_v; - } else { - output[output_idx] = off_v; - } - } - } -} -template -void OneHot(const S *indices, size_t depth, const T *on_value, const T *off_value, size_t left_dim_size, - size_t right_dim_size, T *output, cudaStream_t cuda_stream) { - size_t size = left_dim_size * depth * right_dim_size; - OneHotKernel<<>>(size, indices, depth, on_value, off_value, - left_dim_size, right_dim_size, output); - return; -} -template void OneHot(const int *indices, size_t depth, const float *on_value, const float *off_value, - size_t left_dim_size, size_t right_dim_size, float *output, cudaStream_t cuda_stream); -template void OneHot(const int *indices, size_t depth, const half *on_value, const half *off_value, - size_t left_dim_size, size_t right_dim_size, half *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cu deleted file mode 100755 index ddc615d94b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cu +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "kernel/gpu/cuda_impl/pad_impl.cuh" - -template -__global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, float pad_value, T* output) { - T pad_value_ = static_cast(pad_value); - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int block_num = pos / padded_width / padded_height; - const int padded_w = pos % padded_width; - const int padded_h = pos / padded_width % padded_height; - if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height || - padded_w - pad_left >= old_width) { - output[pos] = pad_value_; - } else { - output[pos] = input[(block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left]; - } - } - return; -} - -template -__global__ void PadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, T* dx) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int block_num = pos / old_width / old_height; - const int padded_w = pos % old_width + pad_left; - const int padded_h = pos / old_width % old_height + pad_top; - dx[pos] = dy[(block_num * padded_height + padded_h) * padded_width + padded_w]; - } - return; -} - -template -void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, - const float pad_value, T* output, cudaStream_t cuda_stream) { - Pad<<>>(size, input, num, channels, old_height, old_width, - padded_height, padded_width, pad_top, pad_left, pad_value, - output); - return; -} - -template -void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, T* dx, cudaStream_t cuda_stream) { - PadGrad<<>>(size, dy, num, channels, old_height, old_width, - padded_height, padded_width, pad_top, pad_left, dx); - return; -} - -template void CalPad(const size_t size, const float* input, const int num, const int channels, - const int old_height, const int old_width, const int padded_height, const int padded_width, - const int pad_top, const int pad_left, float pad_value, float* output, - cudaStream_t cuda_stream); -template void CalPadGrad(const size_t size, const float* dy, const int num, const int channels, - const int old_height, const int old_width, const int padded_height, - const int padded_width, const int pad_top, const int pad_left, float* dx, - cudaStream_t cuda_stream); -template void CalPad(const size_t size, const half* input, const int num, const int channels, - const int old_height, const int old_width, const int padded_height, const int padded_width, - const int pad_top, const int pad_left, float pad_value, half* output, - cudaStream_t cuda_stream); -template void CalPadGrad(const size_t size, const half* dy, const int num, const int channels, - const int old_height, const int old_width, const int padded_height, - const int padded_width, const int pad_top, const int pad_left, half* dx, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cuh deleted file mode 100755 index dc3036b8b6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ -#include -#include "device/gpu/cuda_common.h" - -template -void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, - float pad_value, T* output, cudaStream_t cuda_stream); -template -void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, T* dx, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu deleted file mode 100644 index 913aaa3b8d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "kernel/gpu/cuda_impl/rmsprop_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -__global__ void RmsPropKernel(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, - T* mean_square, T*moment, T* gradients, const size_t size) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { - mean_square[i] = decay * mean_square[i] + (1.0 - decay) * gradients[i] * gradients[i]; - moment[i] = momentum * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon) * gradients[i]; - variable[i] -= moment[i]; - } -} - -template -void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, - T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) { - RmsPropKernel<<>>(learning_rate, decay, momentum, epsilon, - variable, mean_square, moment, gradients, size); -} - -template -__global__ void RmsPropCenterKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, - T* variable, T* mean_gradients, T* mean_square, T*moment, T* gradients, - const size_t size) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { - mean_gradients[i] = decay[0] * mean_gradients[i] + (1.0 - decay[0]) * gradients[i]; - mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i]; - moment[i] = momentum[0] * moment[i] + learning_rate[0] * - rsqrt(mean_square[i] - mean_gradients[i] * mean_gradients[i] + epsilon[0]) * gradients[i]; - variable[i] -= moment[i]; - } -} - -template -void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, - T* mean_gradients, T* mean_square, T*moment, T* gradients, const size_t size, - cudaStream_t cuda_stream) { - RmsPropCenterKernel<<>>(learning_rate, decay, momentum, epsilon, - variable, mean_gradients, mean_square, - moment, gradients, size); -} - -template -void RmsProp(const float* learning_rate, const float decay, const float momentum, const float epsilon, - float* variable, float* mean_square, float* moment, float* gradients, const size_t size, - cudaStream_t cuda_stream); - -template -void RmsPropCenter(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon, - float* variable, float* mean_gradients, float* mean_square, float*moment, float* gradients, - const size_t size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh deleted file mode 100644 index b5802dbb67..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ -#include "device/gpu/cuda_common.h" - -template -void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, T* mean_square, - T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream); - -template -void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, - T* mean_gradients, T* mean_square, T* moment, T* gradients, const size_t size, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cu deleted file mode 100644 index f07a820e75..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cu +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "kernel/gpu/cuda_impl/select_impl.cuh" - -template -__global__ void Select(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - output[pos] = cond[pos] ? input_x[pos] : input_y[pos]; - } - return; -} - -template -void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, - cudaStream_t cuda_stream) { - Select<<>>(size, cond, input_x, input_y, output); - return; -} - -template void CalSelect(const size_t size, const bool* cond, const float* input_X, const float* input_y, - float* output, cudaStream_t cuda_stream); -template void CalSelect(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output, - cudaStream_t cuda_stream); -template void CalSelect(const size_t size, const bool* cond, const half* input_X, const half* input_y, - half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cuh deleted file mode 100644 index da2d7d9a7f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ - -#include "device/gpu/cuda_common.h" - -template -void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu deleted file mode 100644 index a0082b84c8..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh" - -template -__global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels, - T *outputs) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - if (logits[i] >= 0) { - outputs[i] = 1. / (1. + exp(-logits[i])) - labels[i]; - } else { - const T exp_val = exp(logits[i]); - outputs[i] = exp_val / (1. + exp_val) - labels[i]; - } - } -} - -template -void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream) { - SigmoidCrossEntropyWithLogitsGradKernel<<>>(size, logits, labels, - outputs); -} - -template void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const float *logits, - const float *labels, float *outputs, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh deleted file mode 100644 index 2cd4922d25..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu deleted file mode 100644 index 3766f367db..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" - -template -__global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - const T reverse_factor = static_cast(logits[i] >= 0); - outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor); - } -} - -template -void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream) { - SigmoidCrossEntropyWithLogitsKernel<<>>(size, logits, labels, outputs); -} - -template void SigmoidCrossEntropyWithLogits(const size_t size, const float *logits, const float *labels, - float *outputs, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh deleted file mode 100644 index 575605bde0..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu deleted file mode 100755 index e49a22bb46..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "kernel/gpu/cuda_impl/slice_impl.cuh" - -template -__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const T *input, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) { - int i = pos / (l2 * l3 * l4) % l1; - int j = pos / (l3 * l4) % l2; - int k = pos / l4 % l3; - int o = pos % l4; - - int offset = (i + s1) * (d2 * d3 * d4) + - (j + s2) * (d3 * d4) + - (k + s3) * d4 + - (o + s4); - output[pos] = input[offset]; - } -} -template -__global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) { - output[start + pos] = dy[p + pos]; - } - return; -} -template -__global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); - pos += blockDim.x * gridDim.x) { - output[p + pos] = input[start + pos * stride]; - } - return; -} -template -__global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); - pos += blockDim.x * gridDim.x) { - dx[start + pos * stride] = dy[p + pos]; - } - return; -} -template -__global__ void FillArray(T* addr, const size_t len, const float value) { - T value_ = static_cast(value); - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < len; pos += blockDim.x * gridDim.x) { - addr[pos] = value_; - } - return; -} -template -void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream) { - FillArray<<>>(addr, input_size, value); - return; -} -template -void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const T *input, T *output, cudaStream_t stream) { - Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, - d1, d2, d3, d4, input, output); -} -template -void CalSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, const std::vector begin, - const std::vector size, T* output, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int length = size[3]; - int p = 0; - for (int i = begin[0]; i < size[0] + begin[0]; i++) { - for (int j = begin[1]; j < size[1] + begin[1]; j++) { - for (int k = begin[2]; k < size[2] + begin[2]; k++) { - SliceGrad<<>>( - dy, p, i * block + j * map + k * w + begin[3], length, output); - p = p + size[3]; - } - } - } -} -template -void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* output, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int ended = end[3]; - int p = 0; - int start = 0; - for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0])); i += std::abs(strides[0])) { - for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1])); j += std::abs(strides[1])) { - for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2])); k += std::abs(strides[2])) { - start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + - (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; - StridedSlice<<>>(input, p, start, begin[3], strides[3], - ended, output); - p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); - } - } - } -} -template -void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* dx, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int ended = end[3]; - int p = 0; - int start = 0; - for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0] + 1)); i += std::abs(strides[0])) { - for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1] + 1)); - j += std::abs(strides[1])) { - for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2] + 1)); - k += std::abs(strides[2])) { - start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + - (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; - StridedSliceGrad<<>>(dy, p, start, begin[3], strides[3], - ended, dx); - p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); - } - } - } -} - -template void FillDeviceArray(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const float *input, float *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, float* output, - cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const float* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, float* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, float* dx, cudaStream_t cuda_stream); -template void FillDeviceArray(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const half *input, half *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, half* output, - cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const half* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, half* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, half* dx, cudaStream_t cuda_stream); -template void FillDeviceArray(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const int *input, int *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, int* output, - cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const int* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, int* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, int* dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh deleted file mode 100755 index 9513d6ed24..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ - -#include -#include -#include "device/gpu/cuda_common.h" - - -template -void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const T *input, T *output, cudaStream_t stream); -template -void CalSliceGrad(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector size, T* output, cudaStream_t cuda_stream); -template -void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* output, cudaStream_t cuda_stream); -template -void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* dx, cudaStream_t cuda_stream); -template -void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh deleted file mode 100755 index d16131470c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ - -#include "device/gpu/cuda_common.h" - -template -void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss, - cudaStream_t cuda_stream); - -template -void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cu deleted file mode 100755 index a0fea90136..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cu +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "transpose_impl.cuh" -#include "device/gpu/cuda_common.h" -template -__global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis, - const int shape_size, T* output) { - int pos_size; - int temp_pos; - int newpos; - int newpos_size; - int pos_array[TRANSPOSE_MAX_DIMENSION]; - - // for example 4-D: pos = posArray[0] * input_shape[1] * input_shape[2] * input_shape[3] + - // posArray[1] * input_shape[2] * input_shape[3] + - // posArray[2] * input_shape[3] + - // posArray[3] - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - temp_pos = pos; - pos_size = size / input_shape[0]; - pos_array[0] = temp_pos / pos_size; - for (int i = 1; i < shape_size; i++) { - temp_pos -= pos_array[i - 1] * pos_size; - pos_size = pos_size / input_shape[i]; - pos_array[i] = temp_pos / pos_size; - } - - newpos = pos_array[input_axis[shape_size - 1]]; - newpos_size = 1; - for (int j = shape_size - 2; j >= 0; j--) { - newpos_size *= input_shape[input_axis[j + 1]]; - newpos += pos_array[input_axis[j]] * newpos_size; - } - - output[newpos] = input[pos]; - } - return; -} -template -void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size, - T* output, cudaStream_t cuda_stream) { - Transpose<<>>(size, input, input_shape, input_axis, shape_size, - output); - return; -} - -template void CalTranspose(const int size, const float* input, const int* input_shape, const int* input_axis, - const int shape_size, float* output, cudaStream_t cuda_stream); -template void CalTranspose(const int size, const half* input, const int* input_shape, const int* input_axis, - const int shape_size, half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh deleted file mode 100755 index 623b1a8c03..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream); -template -void Logarithm(T *input, T *output, size_t count, cudaStream_t cuda_stream); -template -void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream); -template -void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); -template -void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); -template -void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); -template -void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); -template -void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cu deleted file mode 100644 index a7affd4705..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cu +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/unsorted_segment_sum.cuh" - -template -__global__ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, - T* input_addr, S* ids_addr, T* output_addr) { - for (int input_index = blockIdx.x * blockDim.x + threadIdx.x; input_index < input_dim0 * input_dim1; - input_index += blockDim.x * gridDim.x) { - size_t j = input_index / input_dim1; - size_t k = input_index % input_dim1; - - S i = ids_addr[j]; - if (i < 0 || i >= output_dim0) { - continue; - } - size_t output_index = i * output_dim1 + k; - atomicAdd(output_addr + output_index, input_addr[input_index]); - } -} - -template -void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, - T* input_addr, S* ids_addr, T* output_addr, cudaStream_t stream) { - int size = input_dim0 * input_dim1; - UnsortedSegmentSum<<>>(input_dim0, input_dim1, - output_dim0, output_dim1, input_addr, ids_addr, output_addr); - return; -} - -template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, - float* input_addr, int* ids_addr, float* output_addr, cudaStream_t stream); -template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, - float* input_addr, int64_t* ids_addr, float* output_addr, cudaStream_t stream); - -template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, - int* input_addr, int* ids_addr, int* output_addr, cudaStream_t stream); -template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, - int* input_addr, int64_t* ids_addr, int* output_addr, cudaStream_t stream); - - - diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cuh deleted file mode 100644 index ef95032996..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cuh +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ - -#include -#include "device/gpu/cuda_common.h" - -template -void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, - T* input_addr, S* ids, T* output_addr, cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.cc b/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.cc deleted file mode 100644 index 777310cebc..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/data/dataset_init_kernel.h" -#include "kernel/gpu/data/dataset_utils.h" -#include "device/gpu/gpu_buffer_mgr.h" -#include "device/gpu/gpu_memory_allocator.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::device::GpuBufferMgr; - -DatasetInitKernel::DatasetInitKernel() : total_bytes_(0) {} - -const std::vector &DatasetInitKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &DatasetInitKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool DatasetInitKernel::Init(const CNodePtr &kernel_node) { - queue_name_ = GetAttr(kernel_node, "queue_name"); - auto shapes = GetAttr>>(kernel_node, "shapes"); - auto types = GetAttr>(kernel_node, "types"); - if (shapes.size() != types.size()) { - MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types; - } - - for (size_t i = 0; i < shapes.size(); i++) { - int unit = UnitSizeInBytes(types[i]->type_id()); - int nums = ElementNums(shapes[i]); - int bytes = unit * nums; - shapes_.push_back(bytes); - total_bytes_ += bytes; - } - return true; -} - -void DatasetInitKernel::InitSizeLists() { return; } - -bool DatasetInitKernel::Launch(const std::vector &, const std::vector &, - const std::vector &, void *) { - void *addr = nullptr; - size_t len = total_bytes_ * buffer_q_capacity_; - - if (!device::gpu::GPUMemoryAllocator::GetInstance().AllocBufferQueueMem(len, &addr)) { - MS_LOG(EXCEPTION) << "Memory not enough: failed to allocate GPU buffer queue memory[" << len << "]."; - } - - auto status = GpuBufferMgr::GetInstance().Create(0, queue_name_, addr, shapes_, buffer_q_capacity_); - if (status) { - MS_LOG(EXCEPTION) << "Init Dataset Failed. len: " << len << ", status:" << status; - } - - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.h b/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.h deleted file mode 100644 index 318049f4ad..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_DATASET_INIT_KERNEL_H -#define MINDSPORE_DATASET_INIT_KERNEL_H - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class DatasetInitKernel : public GpuKernel { - public: - DatasetInitKernel(); - ~DatasetInitKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel_node) override; - - protected: - void InitSizeLists() override; - - private: - std::string queue_name_; - std::vector shapes_; - size_t total_bytes_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - // The capacity of buffer Q. - size_t buffer_q_capacity_{2}; -}; - -MS_REG_GPU_KERNEL(InitDataSetQueue, DatasetInitKernel) -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_QUEUE_CPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc b/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc deleted file mode 100644 index 13ca191b0b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/data/dataset_iterator_kernel.h" -#include -#include -#include -#include "device/gpu/gpu_buffer_mgr.h" -#include "device/gpu/gpu_common.h" -#include "kernel/gpu/data/dataset_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::device::GpuBufferMgr; -using mindspore::device::HandleMgr; - -DatasetIteratorKernel::DatasetIteratorKernel() : handle_(HandleMgr::INVALID_HANDLE), total_bytes_(0) {} - -DatasetIteratorKernel::~DatasetIteratorKernel() { GpuBufferMgr::GetInstance().Close(handle_); } - -const std::vector &DatasetIteratorKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &DatasetIteratorKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) { - queue_name_ = GetAttr(kernel_node, "shared_name"); - auto shapes = GetAttr>>(kernel_node, "shapes"); - auto types = GetAttr>(kernel_node, "types"); - if (shapes.size() != types.size()) { - MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types; - } - - for (size_t i = 0; i < shapes.size(); i++) { - int unit = UnitSizeInBytes(types[i]->type_id()); - int nums = ElementNums(shapes[i]); - int bytes = unit * nums; - output_size_list_.push_back(bytes); - total_bytes_ += bytes; - } - - handle_ = GpuBufferMgr::GetInstance().Open(0, queue_name_, output_size_list_); - if (handle_ == HandleMgr::INVALID_HANDLE) { - MS_LOG(EXCEPTION) << "Gpu Queue(" << queue_name_ << ") Open Failed"; - } - - return true; -} - -void DatasetIteratorKernel::InitSizeLists() { return; } - -bool DatasetIteratorKernel::Launch(const std::vector &, const std::vector &, - const std::vector &outputs, void *stream) { - void *addr = nullptr; - size_t len = 0; - - int repeat = 0; - while (true) { - auto ret = GpuBufferMgr::GetInstance().Front(handle_, &addr, &len); - if (ret == device::SUCCESS) { - break; - } - - if (ret == device::TIMEOUT) { - repeat++; - if (repeat < 10) { - MS_LOG(INFO) << "Waiting for data...(" << repeat << " / 10)"; - continue; - } else { - MS_LOG(ERROR) << "Get data timeout"; - return false; - } - } - - MS_LOG(ERROR) << "Get data failed, errcode " << ret; - return false; - } - - if (total_bytes_ != len) { - MS_LOG(ERROR) << "Dataset front error. read: " << len << ", expect: " << total_bytes_ << ", "; - return false; - } - - for (size_t i = 0; i < output_size_list_.size(); i++) { - void *output_addr = GetDeviceAddress(outputs, i); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice, - reinterpret_cast(stream)), - "Cuda Memcpy Failed"); - addr = reinterpret_cast(addr) + output_size_list_[i]; - } - - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream)), - "cudaStreamSynchronize failed"); - (void)GpuBufferMgr::GetInstance().Pop(handle_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.h b/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.h deleted file mode 100644 index cdd7a47e7b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_GET_NEXT_KERNEL_H -#define MINDSPORE_GET_NEXT_KERNEL_H - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class DatasetIteratorKernel : public GpuKernel { - public: - DatasetIteratorKernel(); - ~DatasetIteratorKernel(); - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel_node) override; - - protected: - void InitSizeLists() override; - - private: - std::string queue_name_; - unsigned int handle_; - size_t total_bytes_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -MS_REG_GPU_KERNEL(GetNext, DatasetIteratorKernel) -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_QUEUE_CPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_utils.cc b/mindspore/ccsrc/kernel/gpu/data/dataset_utils.cc deleted file mode 100644 index 846a63f84f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_utils.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/data/dataset_utils.h" - -namespace mindspore { -namespace kernel { -size_t UnitSizeInBytes(const mindspore::TypeId &t) { - size_t bytes = 0; - switch (t) { - case kNumberTypeBool: - case kNumberTypeInt8: - case kNumberTypeUInt8: - bytes = 1; - break; - case kNumberTypeInt16: - case kNumberTypeUInt16: - case kNumberTypeFloat16: - bytes = 2; - break; - case kNumberTypeInt: - case kNumberTypeUInt: - case kNumberTypeInt32: - case kNumberTypeUInt32: - case kNumberTypeFloat: - case kNumberTypeFloat32: - bytes = 4; - break; - case kNumberTypeUInt64: - case kNumberTypeInt64: - case kNumberTypeFloat64: - bytes = 8; - break; - default: - MS_LOG(EXCEPTION) << "Invalid types " << t; - break; - } - - return bytes; -} - -int ElementNums(const std::vector &shape) { - if (shape.size() == 0) { - return 0; - } - - int nums = 1; - for (size_t i = 0; i < shape.size(); i++) { - nums *= shape[i]; - } - - return nums; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/gpu_kernel.h deleted file mode 100644 index c935798f06..0000000000 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ - -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "kernel/gpu/kernel_constants.h" -#include "device/gpu/gpu_device_manager.h" -#include "device/gpu/gpu_common.h" -#include "session/anf_runtime_algorithm.h" -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; - -namespace mindspore { -namespace kernel { -class GpuKernel : public KernelMod { - public: - virtual ~GpuKernel() = default; - virtual bool Init(const CNodePtr &kernel_node) = 0; - - protected: - virtual void InitResource() {} - virtual void InitSizeLists() = 0; - - template - inline T *GetDeviceAddress(const std::vector &addr_list, size_t index) { - if (index >= addr_list.size()) { - MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")"; - } - // Kernels may run normally without workspace, the addr_list[index] maybe nullptr. - if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) { - return nullptr; - } - MS_EXCEPTION_IF_NULL(addr_list[index]->addr); - return reinterpret_cast(addr_list[index]->addr); - } - - template - inline T GetAttr(const CNodePtr &kernel_node, const std::string &key) const { - const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node); - const ValuePtr &attr = prim->GetAttr(key); - if (attr == nullptr) { - const std::string &prim_name = AnfAlgo::GetCNodeName(kernel_node); - MS_LOG(EXCEPTION) << "The attr(" << key << ") of kernel(" << prim_name << ") not exist"; - } - return GetValue(attr); - } - // expand Nd Shape to 4d (N in [0,4]) - void ShapeNdTo4d(const std::vector &src, std::vector *dst) { - if (src.size() > 4) { - MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!"; - } - dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4])); - dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3])); - dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2])); - dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1])); - } - - inline void CheckBroadcast4TensorOp(const std::vector &A, const std::vector &B, - const std::vector &Out) { - if (A != Out && B != Out) { - MS_EXCEPTION(ValueError) - << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n" - "InputA must match the corresponding dimension of the destination tensor outC, and each " - "dimension of the inputB " - "must match the corresponding dimension of outC or must be equal to 1."; - } - } - - // choose the suitable datatype for cudnn/cublas - inline cudnnDataType_t GetCudnnDataType(const std::string &Type) { - auto type = kCudnnDtypeMap.find(Type); - if (type == kCudnnDtypeMap.end()) { - MS_EXCEPTION(TypeError) << Type << " is not supported."; - } - return type->second; - } - inline cudaDataType_t GetCudaDataType(const std::string &Type) { - auto type = kCudaDtypeMap.find(Type); - if (type == kCudaDtypeMap.end()) { - MS_EXCEPTION(TypeError) << Type << " is not supported."; - } - return type->second; - } -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc deleted file mode 100644 index b00b5c263d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/gpu_kernel_factory.h" - -#include -#include - -#include "common/utils.h" -#include "device/kernel_info.h" -#include "device/gpu/cuda_common.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -GpuKernelFactory &GpuKernelFactory::GetInstance() { - static GpuKernelFactory instance; - return instance; -} - -void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr, - GpuKernelCreater &&creater) { - map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater); -} - -void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, - std::vector> *iter_second, - size_t attr_index) { - if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { - if (iter_second->at(attr_index).first.GetAllSame()) { - auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first; - for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddInputAttr(dtype); - } - } else { - MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; - } - } - if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { - if (iter_second->at(attr_index).first.GetAllSame()) { - auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first; - for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddOutputAttr(dtype); - } - } else { - MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; - } - } -} - -std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) { - std::string type_lists = ""; - auto iter = map_kernel_name_to_creater_.find(kernel_name); - if (map_kernel_name_to_creater_.end() == iter) { - return type_lists; - } - for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { - std::string type_list = "in["; - auto attr = (iter->second)[attr_index].first; - for (size_t input_index = 0; input_index < attr.GetInputSize(); ++input_index) { - type_list = type_list + TypeId2String(attr.GetInputAttr(input_index).first) + - ((input_index == (attr.GetInputSize() - 1)) ? "" : " "); - } - type_list = type_list + "], out["; - for (size_t input_index = 0; input_index < attr.GetOutputSize(); ++input_index) { - type_list = type_list + TypeId2String(attr.GetOutputAttr(input_index).first) + - ((input_index == (attr.GetOutputSize() - 1)) ? "" : " "); - } - type_lists = type_lists + type_list + "]; "; - } - return type_lists; -} - -std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string &kernel_name, - const KernelBuildInfo *kernel_info) { - auto iter = map_kernel_name_to_creater_.find(kernel_name); - const int marjor_sm = GET_MAJOR_SM; - if (map_kernel_name_to_creater_.end() == iter) { - MS_LOG(INFO) << "Not registered GPU kernel: op[" << kernel_name << "]!"; - return std::make_pair(false, 0); - } - if ((iter->second).size() == 1 && (iter->second)[0].first.GetInputSize() == 0) { - return std::make_pair(true, 0); - } - - for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { - CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index); - bool flag = true; - // data type matching check of all input parameters of kernel - for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) { - if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { - if (marjor_sm < MINIUM_SM) { - MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM - << ", but the current device's computing capacity is " << marjor_sm; - } - MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM - << ", but the current device's computing capacity is " << marjor_sm; - } - if (kernel_info->GetInputDeviceType(input_index) != - (iter->second)[attr_index].first.GetInputAttr(input_index).first) { - flag = false; - break; - } - } - if (!flag) { - continue; - } - // data type matching check of all output parameters of kernel - for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) { - if (kernel_info->GetOutputDeviceType(output_index) != - (iter->second)[attr_index].first.GetOutputAttr(output_index).first) { - flag = false; - break; - } - } - // finish data type matching check and return a pair maintain the whether matching is success, - // if first is true, second is index of matching KernelAttr and creater pair in vector; - if (flag) { - size_t match_index = attr_index; - return std::make_pair(true, match_index); - } - } - return std::make_pair(false, 0); -} - -GpuKernel *GpuKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { - auto kernel_info = apply_kernel->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(kernel_build_Info); - std::pair ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_Info); - if (ret_pair.first) { - return (map_kernel_name_to_creater_.find(kernel_name)->second)[ret_pair.second].second(); - } - return nullptr; -} - -bool GpuKernelFactory::SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_build_info) { - std::pair ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_info.get()); - return ret_pair.first; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.h b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.h deleted file mode 100644 index dc5f61a315..0000000000 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ - -#include -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "device/gpu/kernel_info_setter.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -using mindspore::device::gpu::KernelAttr; -using GpuKernelCreater = std::function; -class GpuKernelFactory { - public: - ~GpuKernelFactory() = default; - - static GpuKernelFactory &GetInstance(); - - void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, GpuKernelCreater &&creater); - - GpuKernel *Create(const std::string &kernel_name, const CNodePtr &apply_kernel); - - bool SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_info); - - std::string SupportedTypeList(const std::string &kernel_name); - - private: - GpuKernelFactory() = default; - - GpuKernelFactory(GpuKernelFactory const &); - - GpuKernelFactory &operator=(const GpuKernelFactory &); - - std::pair GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info); - void CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, - std::vector> *iter_second, size_t attr_index); - // map to maintain kernel and creater, KernelAttr object and creater must be registered as a pair. - std::map>> map_kernel_name_to_creater_; -}; - -class GpuKernelRegister { - public: - GpuKernelRegister(const std::string &kernel_name, const KernelAttr &kernel_attr, GpuKernelCreater &&creater) { - GpuKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(creater)); - } -}; - -#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \ - static_assert(std::is_base_of::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, KernelAttr(), []() { return new OPCLASS(); }); - -// regular register of fixed accuracy kernels -#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \ - static_assert(std::is_base_of::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); - -// register of mixed accuracy kernels which use template and maintain one typename, ignore input num -#define MS_REG_GPU_KERNEL_SAME(OPNAME, ATTR, OPCLASS, T) \ - static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); - -// register of mixed accuracy kernels which use template and maintain one typename -#define MS_REG_GPU_KERNEL_ONE(OPNAME, ATTR, OPCLASS, T) \ - static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); - -// register of mixed accuracy kernels which use template and maintain two typename -#define MS_REG_GPU_KERNEL_TWO(OPNAME, ATTR, OPCLASS, T, S) \ - static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \ - []() { return new OPCLASS(); }); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.cc deleted file mode 100644 index 4683f015ae..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/addn_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AddNGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE( - AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - AddNGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(AddN, - KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - AddNGpuFwdKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h deleted file mode 100644 index 1498da777f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class AddNGpuFwdKernel : public GpuKernel { - public: - AddNGpuFwdKernel() - : cudnn_handle_(nullptr), - input_descriptor_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - input_size_(0), - output_size_(0), - workspace_size_(0), - is_null_input_(false), - num_input_(0) {} - ~AddNGpuFwdKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *) override { - if (is_null_input_) { - return true; - } - T *output_addr = GetDeviceAddress(outputs, 0); - const float alpha = 1; - const float beta = 0; - for (size_t i = 0; i < IntToSize(num_input_); i++) { - T *input_addr = GetDeviceAddress(inputs, i); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr, - &(i > 0 ? alpha : beta), input_descriptor_, output_addr), - "cudnnAddTensor failed"); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - num_input_ = GetAttr(kernel_node, "n"); - if (IntToSize(num_input_) != input_num) { - MS_LOG(ERROR) << "Input number is " << num_input_ << " in attr, but got " << input_num << "input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "AddNGpuFwdKernel input is null"; - InitSizeLists(); - return true; - } - for (size_t i = input_shape.size(); i < 4; i++) { - (void)input_shape.insert(input_shape.begin(), 1); - } - int dimA[4]; - for (size_t i = 0; i < input_shape.size(); i++) { - dimA[i] = SizeToInt(input_shape[i]); - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - SizeToInt(input_shape.size()), dimA), - "cudnnSetTensorNdDescriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast(&input_size_)), - "cudnnGetTensorSizeInBytes failed"); - } - for (int i = 0; i < num_input_; i++) { - input_size_list_.push_back(input_size_); - } - output_size_list_.push_back(input_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t input_descriptor_; - cudnnDataType_t cudnn_data_type_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - bool is_null_input_; - int num_input_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.cc deleted file mode 100644 index 2ae1728ca3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/assign_add_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - AssignAddGpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE( - AssignAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AssignAddGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE( - AssignAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - AssignAddGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.h deleted file mode 100644 index db69fd7be6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.h +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/assign_add_impl.cuh" -namespace mindspore { -namespace kernel { -template -class AssignAddGpuFwdKernel : public GpuKernel { - public: - AssignAddGpuFwdKernel() : is_null_input_(false), input_size_(0) {} - ~AssignAddGpuFwdKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input_addr = GetDeviceAddress(inputs, 0); - T *input_addr2 = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - - CalAssignAdd(input_size_ / sizeof(T), input_addr, input_addr2, output_addr, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "AssignAddGpuFwdKernel input is null"; - InitSizeLists(); - return true; - } - input_size_ = sizeof(T); - for (size_t i : input_shape) { - input_size_ = i * input_size_; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.cc deleted file mode 100644 index 5684f0c424..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/bias_add_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - BiasAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BiasAddGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - BiasAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BiasAddGpuKernel, float16) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.h deleted file mode 100644 index 5a664db2e1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.h +++ /dev/null @@ -1,149 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_BIAS_ADD_GPU_KERNEL_H -#define MINDSPORE_BIAS_ADD_GPU_KERNEL_H -#include -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class BiasAddGpuKernel : public GpuKernel { - public: - BiasAddGpuKernel() - : cudnn_handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - x_desc_(nullptr), - b_desc_(nullptr), - op_desc_(nullptr), - is_null_input_(false) {} - ~BiasAddGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - VARIABLE_NOT_USED(stream_ptr); - if (is_null_input_) { - return true; - } - - T *x_addr = GetDeviceAddress(inputs, 0); - T *b_addr = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - - try { - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_, - b_addr, &beta, x_desc_, output_addr), - "cudnnOpTensor failed"); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cudnnOpTensor"; - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto num_dims = x_shape.size(); - is_null_input_ = CHECK_NULL_INPUT(x_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "input is null"; - InitSizeLists(); - return true; - } - - if (num_dims < 2) { - MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims; - } - - std::string format = GetAttr(kernel_node, "data_format"); - string::size_type pos = format.find("C"); - if (pos == std::string::npos || pos >= num_dims) { - MS_LOG(EXCEPTION) << "format '" << format << "' invalid"; - } - - // Expand to 4 dims for cudnnSetTensorNdDescriptorEx. - auto cudnn_dims = std::max(num_dims, 4UL); - std::unique_ptr x_dims = std::make_unique(cudnn_dims); - std::unique_ptr b_dims = std::make_unique(cudnn_dims); - for (size_t i = 0; i < cudnn_dims; i++) { - x_dims[i] = (i < num_dims) ? SizeToInt(x_shape[i]) : 1; - b_dims[i] = (i == pos) ? SizeToInt(x_shape[i]) : 1; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()), - "cudnnSetTensorNdDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(b_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()), - "cudnnSetTensorNdDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), - "cudnnSetOpTensorDescriptor failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&b_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateOpTensorDescriptor(&op_desc_), "cudnnCreateOpTensorDescriptor failed"); - } - void InitSizeLists() override { - size_t x_size, b_size; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &x_size), "cudnnGetTensorSizeInBytes failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(b_desc_, &b_size), "cudnnGetTensorSizeInBytes failed."); - input_size_list_.push_back(x_size); - input_size_list_.push_back(b_size); - output_size_list_.push_back(x_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(op_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(b_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyOpTensorDescriptor failed"); - } - - cudnnHandle_t cudnn_handle_; - cudnnDataType_t cudnn_data_type_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t b_desc_; - cudnnOpTensorDescriptor_t op_desc_; - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_BIAS_ADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc deleted file mode 100644 index 96d51b704c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/broadcast_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -// fp32 -MS_REG_GPU_KERNEL_TWO( - Greater, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - BroadcastOpGpuKernel, float, bool) -MS_REG_GPU_KERNEL_TWO( - Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - BroadcastOpGpuKernel, float, bool) -MS_REG_GPU_KERNEL_TWO( - Maximum, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - Minimum, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - RealDiv, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - TensorAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) - -// fp16 -MS_REG_GPU_KERNEL_TWO( - Greater, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - BroadcastOpGpuKernel, half, bool) -MS_REG_GPU_KERNEL_TWO( - Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - BroadcastOpGpuKernel, half, bool) -MS_REG_GPU_KERNEL_TWO( - Maximum, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - Minimum, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - RealDiv, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - TensorAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) - -// int32 -MS_REG_GPU_KERNEL_TWO( - TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - BroadcastOpGpuKernel, int, int) -MS_REG_GPU_KERNEL_TWO( - Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - BroadcastOpGpuKernel, int, int) -MS_REG_GPU_KERNEL_TWO( - Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - BroadcastOpGpuKernel, int, int) -MS_REG_GPU_KERNEL_TWO( - Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - BroadcastOpGpuKernel, int, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h deleted file mode 100644 index be7d3a19d4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/broadcast_impl.cuh" -#include "kernel/gpu/kernel_constants.h" -namespace mindspore { -namespace kernel { -template -class BroadcastOpGpuKernel : public GpuKernel { - public: - BroadcastOpGpuKernel() - : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} - ~BroadcastOpGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *lhs = GetDeviceAddress(inputs, 0); - T *rhs = GetDeviceAddress(inputs, 1); - S *output = GetDeviceAddress(outputs, 0); - - if (need_broadcast_) { - Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2], - rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs, - rhs, output, reinterpret_cast(stream_ptr)); - } else { - NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast(stream_ptr)); - } - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - GetOpType(kernel_node); - auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0); - need_broadcast_ = IsBroadcast(shape1, shape2); - if (need_broadcast_ && shape1.size() > 4) { - MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; - } - - for (size_t i = 0; i < shape3.size(); i++) { - output_shape_[i] = shape3[i]; - output_num_ *= shape3[i]; - } - int lhs_offset = shape3.size() - shape1.size(); - for (size_t j = 0; j < shape1.size(); j++) { - lhs_shape_[j + lhs_offset] = shape1[j]; - input1_num_ *= shape1[j]; - } - int rhs_offset = shape3.size() - shape2.size(); - for (size_t k = 0; k < shape2.size(); k++) { - rhs_shape_[k + rhs_offset] = shape2[k]; - input2_num_ *= shape2[k]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { return; } - void InitSizeLists() override { - input_size_list_.push_back(input1_num_ * sizeof(T)); - input_size_list_.push_back(input2_num_ * sizeof(T)); - output_size_list_.push_back(output_num_ * sizeof(S)); - } - - private: - void GetOpType(const CNodePtr &kernel_node) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - - static std::map kBroadcastTypeMap = { - {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, - {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, - {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, - }; - - auto iter = kBroadcastTypeMap.find(kernel_name); - if (iter == kBroadcastTypeMap.end()) { - MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; - } else { - op_type_ = iter->second; - } - } - - bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { - if (lhs.size() != rhs.size()) { - return true; - } - for (size_t i = 0; i < lhs.size(); i++) { - if (lhs[i] != rhs[i]) { - return true; - } - } - return false; - } - - BroadcastOpType op_type_; - bool need_broadcast_; - int input1_num_; - int input2_num_; - int output_num_; - int lhs_shape_[4] = {1, 1, 1, 1}; - int rhs_shape_[4] = {1, 1, 1, 1}; - int output_shape_[4] = {1, 1, 1, 1}; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.cc deleted file mode 100644 index 85598cf940..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/broadcast_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(MinimumGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(MaximumGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(MinimumGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - BroadcastOpGradGpuKernel, int) -MS_REG_GPU_KERNEL_ONE(MaximumGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - BroadcastOpGradGpuKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h deleted file mode 100644 index f1eb5fecf9..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" -namespace mindspore { -namespace kernel { -template -class BroadcastOpGradGpuKernel : public GpuKernel { - public: - BroadcastOpGradGpuKernel() - : op_type_(BROADCAST_GRAD_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} - ~BroadcastOpGradGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *x1 = GetDeviceAddress(inputs, 0); - T *x2 = GetDeviceAddress(inputs, 1); - T *dy = GetDeviceAddress(inputs, 2); - T *dx1 = GetDeviceAddress(outputs, 0); - T *dx2 = GetDeviceAddress(outputs, 1); - - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx1, 0, outputs[0]->size, reinterpret_cast(stream_ptr)), - "cudaMemSet Failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx2, 0, outputs[1]->size, reinterpret_cast(stream_ptr)), - "cudaMemSet Failed"); - if (need_broadcast_) { - BroadcastGrad(x1_shape_[0], x1_shape_[1], x1_shape_[2], x1_shape_[3], x2_shape_[0], x2_shape_[1], x2_shape_[2], - x2_shape_[3], dy_shape_[0], dy_shape_[1], dy_shape_[2], dy_shape_[3], op_type_, x1, x2, dy, dx1, - dx2, reinterpret_cast(stream_ptr)); - } else { - NoBroadcastGrad(output_num_, op_type_, x1, x2, dy, dx1, dx2, reinterpret_cast(stream_ptr)); - } - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - GetOpType(kernel_node); - auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto shape3 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - need_broadcast_ = IsBroadcast(shape1, shape2); - if (need_broadcast_ && shape1.size() > 4) { - MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; - } - - for (size_t i = 0; i < shape3.size(); i++) { - dy_shape_[i] = shape3[i]; - output_num_ *= shape3[i]; - } - int x1_offset = shape3.size() - shape1.size(); - for (size_t i = 0; i < shape1.size(); i++) { - x1_shape_[i + x1_offset] = shape1[i]; - input1_num_ *= shape1[i]; - } - int x2_offset = shape3.size() - shape2.size(); - for (size_t i = 0; i < shape2.size(); i++) { - x2_shape_[i + x2_offset] = shape2[i]; - input2_num_ *= shape2[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { return; } - void InitSizeLists() override { - input_size_list_.push_back(input1_num_ * sizeof(T)); - input_size_list_.push_back(input2_num_ * sizeof(T)); - input_size_list_.push_back(output_num_ * sizeof(T)); - output_size_list_.push_back(input1_num_ * sizeof(T)); - output_size_list_.push_back(input2_num_ * sizeof(T)); - } - - private: - void GetOpType(const CNodePtr &kernel_node) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - - static std::map kBroadcastTypeMap = { - {"MaximumGrad", BROADCAST_GRAD_TYPE_MAXIMUM}, - {"MinimumGrad", BROADCAST_GRAD_TYPE_MINIMUM}, - }; - - auto iter = kBroadcastTypeMap.find(kernel_name); - if (iter == kBroadcastTypeMap.end()) { - MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; - } else { - op_type_ = iter->second; - } - } - - bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { - if (lhs.size() != rhs.size()) { - return true; - } - for (size_t i = 0; i < lhs.size(); i++) { - if (lhs[i] != rhs[i]) { - return true; - } - } - return false; - } - - BroadcastGradOpType op_type_; - bool need_broadcast_; - int input1_num_; - int input2_num_; - int output_num_; - int x1_shape_[4] = {1, 1, 1, 1}; - int x2_shape_[4] = {1, 1, 1, 1}; - int dy_shape_[4] = {1, 1, 1, 1}; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.cc deleted file mode 100644 index f3c3b6164d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/equalcount_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - EqualCount, - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - EqualCountGpuKernel, int) -MS_REG_GPU_KERNEL_ONE( - EqualCount, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EqualCountGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - EqualCount, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - EqualCountGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.h deleted file mode 100644 index 7d3f74970f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.h +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_EQUALCOUNT_GPU_KERNEL_H -#define MINDSPORE_EQUALCOUNT_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/equalcount_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class EqualCountGpuKernel : public GpuKernel { - public: - EqualCountGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} - ~EqualCountGpuKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - T *input1 = GetDeviceAddress(inputs, 0); - T *input2 = GetDeviceAddress(inputs, 1); - T *output = GetDeviceAddress(outputs, 0); - int size = SizeToInt(input_size_ / sizeof(T)); - CalEqualCount(size, input1, input2, output, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but equalcount needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but equalcount needs 1 output."; - return false; - } - - output_size_ = sizeof(T); - input_size_ = sizeof(T); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - return; - } - - private: - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc deleted file mode 100644 index 374644eaf5..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/float_status_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FloatStatusGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h deleted file mode 100644 index 1aa9b18684..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H - -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/float_status_impl.cuh" - -namespace mindspore { -namespace kernel { -enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 }; -static const std::map kOpTypeMap = { - {"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}}; -template -class FloatStatusGpuKernel : public GpuKernel { - public: - FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {} - ~FloatStatusGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - - switch (kernel_name_) { - case OP_STATUS: { - T *output = GetDeviceAddress(outputs, 0); - CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); - break; - } - case OP_INF: { - bool *output = GetDeviceAddress(outputs, 0); - CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); - break; - } - case OP_NAN: { - bool *output = GetDeviceAddress(outputs, 0); - CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); - break; - } - case OP_FINITE: { - bool *output = GetDeviceAddress(outputs, 0); - CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); - break; - } - default: { - MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported."; - } - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_size_ = sizeof(T); - for (size_t x : shape) { - input_size_ = input_size_ * x; - } - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kOpTypeMap.find(kernel_name); - if (iter == kOpTypeMap.end()) { - MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported."; - } else { - kernel_name_ = iter->second; - } - if (kernel_name_ == OP_STATUS) { - output_size_ = sizeof(T); - } else { - output_size_ = input_size_ / sizeof(T) * sizeof(bool); - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output."; - return false; - } - return true; - } - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - Optype kernel_name_; - size_t input_size_; - size_t output_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc deleted file mode 100644 index 808d599853..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/matmul_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - MatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - MatMulGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - MatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - MatMulGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - BatchMatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - MatMulGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - BatchMatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - MatMulGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h deleted file mode 100644 index 3ee3493ed6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MATMUL_GPU_KERNEL_H -#define MINDSPORE_MATMUL_GPU_KERNEL_H - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace kernel { -template -class MatMulGpuKernel : public GpuKernel { - public: - MatMulGpuKernel() - : batch_(0), - m_(0), - n_(0), - k_(0), - is_null_input_(false), - transpose_x1_(CUBLAS_OP_N), - transpose_x2_(CUBLAS_OP_N), - handle_(nullptr), - dtype_a_(CUDA_R_32F), - dtype_b_(CUDA_R_32F), - dtype_c_(CUDA_R_32F), - algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {} - ~MatMulGpuKernel() = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - VARIABLE_NOT_USED(stream_ptr); - if (is_null_input_) { - return true; - } - auto input1_addr = GetDeviceAddress(inputs, 0); - auto input2_addr = GetDeviceAddress(inputs, 1); - auto output_addr = GetDeviceAddress(outputs, 0); - - const float alpha = 1; - const float beta = 0; - const int lda = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_); - const int ldb = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_); - const int ldc = n_; - - auto stride_a = SizeToInt(m_ * k_); - auto stride_b = SizeToInt(k_ * n_); - auto stride_c = SizeToInt(m_ * n_); - - try { - CHECK_CUBLAS_RET_WITH_EXCEPT( - cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), - &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a, - &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), - "cublasSgemm Call Fail"); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx"; - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); - dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); - dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(output_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "input is null"; - InitSizeLists(); - return true; - } - auto dims = output_shape.size(); - if (dims < 2) { - MS_LOG(EXCEPTION) << "Output dims " << dims << " not support."; - } - - m_ = output_shape[dims - 2]; - n_ = output_shape[dims - 1]; - batch_ = 1; - for (size_t i = 0; i < dims - 2; i++) { - batch_ *= output_shape[i]; - } - - bool transpose = GetAttr(kernel_node, "transpose_x1"); - transpose_x1_ = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - k_ = transpose ? input1_shape[dims - 2] : input1_shape[dims - 1]; - - transpose = GetAttr(kernel_node, "transpose_x2"); - transpose_x2_ = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - size_t unit_size = sizeof(T); - - size_t input_size = batch_ * m_ * k_ * unit_size; - input_size_list_.push_back(input_size); - - input_size = batch_ * n_ * k_ * unit_size; - input_size_list_.push_back(input_size); - - size_t output_size = batch_ * m_ * n_ * unit_size; - output_size_list_.push_back(output_size); - } - - private: - size_t batch_; - size_t m_; - size_t n_; - size_t k_; - bool is_null_input_; - - cublasOperation_t transpose_x1_; - cublasOperation_t transpose_x2_; - cublasHandle_t handle_; - cudaDataType_t dtype_a_; - cudaDataType_t dtype_b_; - cudaDataType_t dtype_c_; - cublasGemmAlgo_t algo_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc deleted file mode 100644 index 77f53fc417..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/unary_op_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h deleted file mode 100644 index 4503b805f6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/unary_op_impl.cuh" - -namespace mindspore { -namespace kernel { -enum UnaryOptype { - UNARY_OP_EXP = 0, - UNARY_OP_LOG, - UNARY_OP_NEG, - UNARY_OP_RECIPROCAL, - UNARY_OP_ZEROSLIKE, - UNARY_OP_SQUARE, - UNARY_OP_SQRT, - UNARY_OP_RSQRT, - UNARY_OP_INVALID_TYPE = 255 -}; -static const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, - {"Log", UNARY_OP_LOG}, - {"Neg", UNARY_OP_NEG}, - {"Reciprocal", UNARY_OP_RECIPROCAL}, - {"ZerosLike", UNARY_OP_ZEROSLIKE}, - {"Square", UNARY_OP_SQUARE}, - {"Sqrt", UNARY_OP_SQRT}, - {"Rsqrt", UNARY_OP_RSQRT}}; -template -class UnaryOpGpuKernel : public GpuKernel { - public: - UnaryOpGpuKernel() - : unary_op_type_(UNARY_OP_INVALID_TYPE), - input_size_(sizeof(T)), - output_size_(sizeof(T)), - workspace_size_(0), - is_null_input_(false) {} - ~UnaryOpGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - - switch (unary_op_type_) { - case UNARY_OP_EXP: { - Exponential(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); - break; - } - case UNARY_OP_LOG: { - Logarithm(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); - break; - } - case UNARY_OP_NEG: { - Negative(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); - break; - } - case UNARY_OP_RECIPROCAL: { - Reciprocal(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); - break; - } - case UNARY_OP_SQUARE: { - Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); - break; - } - case UNARY_OP_SQRT: { - Sqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); - break; - } - case UNARY_OP_RSQRT: { - Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); - break; - } - case UNARY_OP_ZEROSLIKE: { - Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); - return true; - } - default: { - MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported."; - } - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kUnaryOpTypeMap.find(kernel_name); - if (iter == kUnaryOpTypeMap.end()) { - MS_LOG(EXCEPTION) << "Unary operation " << kernel_name << " is not supported."; - } else { - unary_op_type_ = iter->second; - } - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "UnaryOpGpuKernel input is null"; - InitSizeLists(); - return true; - } - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - output_size_ = input_size_; - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - UnaryOptype unary_op_type_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.cc deleted file mode 100644 index 6993085a75..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nccl/nccl_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - NcclGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - NcclGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - NcclGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - NcclGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - NcclGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - NcclGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h deleted file mode 100644 index b5ab46a67d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "device/gpu/distribution/collective_init.h" - -namespace mindspore { -namespace kernel { -enum NcclKernelType { NCCL_ALL_REDUCE = 0, NCCL_ALL_GATHER, NCCL_REDUCE_SCATTER, NCCL_INVALID_TYPE = 255 }; -const std::map kNcclTypeMap = { - {"AllReduce", NCCL_ALL_REDUCE}, - {"AllGather", NCCL_ALL_GATHER}, - {"ReduceScatter", NCCL_REDUCE_SCATTER}, -}; - -static std::map kNcclDtypeMap = { - {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}}; - -typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t); -typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t); -typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t); - -template -class NcclGpuKernel : public GpuKernel { - public: - NcclGpuKernel() - : nccl_kernel_type_(NCCL_INVALID_TYPE), - nccl_reduce_type_(ncclSum), - input_size_(0), - output_size_(0), - collective_handle_(nullptr), - comm_stream_(nullptr) {} - ~NcclGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - - cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast(stream_ptr); - switch (nccl_kernel_type_) { - case NCCL_ALL_REDUCE: { - auto all_reduce_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); - MS_EXCEPTION_IF_NULL(all_reduce_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), - nccl_data_type_, nccl_reduce_type_, stream), - "ncclAllReduce failed"); - break; - } - case NCCL_ALL_GATHER: { - auto all_gather_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); - MS_EXCEPTION_IF_NULL(all_gather_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT( - (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream), - "ncclAllGather failed"); - break; - } - case NCCL_REDUCE_SCATTER: { - auto reduce_scatter_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "ReduceScatter")); - MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), - nccl_data_type_, nccl_reduce_type_, stream), - "ncclReduceScatter failed"); - break; - } - default: { - MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported."; - } - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - nccl_data_type_ = kNcclDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; ++i) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); - size_t size = sizeof(T); - for (size_t j = 0; j < shape.size(); j++) { - size *= IntToSize(shape[j]); - } - input_size_list_.push_back(size); - input_size_ += size; - } - for (size_t i = 0; i < output_num; ++i) { - auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); - size_t size = sizeof(T); - for (size_t j = 0; j < shape.size(); j++) { - size *= IntToSize(shape[j]); - } - output_size_list_.push_back(size); - output_size_ += size; - } - InferCommType(kernel_node); - collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); - MS_EXCEPTION_IF_NULL(collective_handle_); - - auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); - if (comm_stream_attr) { - comm_stream_ = reinterpret_cast(GetValue(comm_stream_attr)); - MS_EXCEPTION_IF_NULL(comm_stream_); - } - return true; - } - - protected: - void InitSizeLists() override { return; } - - private: - void InferCommType(const CNodePtr &kernel_node) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kNcclTypeMap.find(kernel_name); - if (iter == kNcclTypeMap.end()) { - MS_LOG(EXCEPTION) << "Kernel " << kernel_name << " is not supported."; - } else { - nccl_kernel_type_ = iter->second; - } - - auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); - if (reduce_op) { - std::string type = GetValue(reduce_op); - if (type == "sum") { - nccl_reduce_type_ = ncclSum; - } else if (type == "max") { - nccl_reduce_type_ = ncclMax; - } else if (type == "min") { - nccl_reduce_type_ = ncclMin; - } else if (type == "prod") { - nccl_reduce_type_ = ncclProd; - } else { - MS_LOG(EXCEPTION) << "Nccl reduce type " << type << " is not supported."; - } - } - return; - } - - NcclKernelType nccl_kernel_type_; - ncclRedOp_t nccl_reduce_type_; - ncclDataType_t nccl_data_type_; - size_t input_size_; - size_t output_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - const void *collective_handle_; - cudaStream_t comm_stream_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc deleted file mode 100644 index 5e80cccd75..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/activation_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGpuFwdKernel, half) - -MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGpuFwdKernel, half) - -MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h deleted file mode 100644 index bf6cfa7b23..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class ActivationGpuFwdKernel : public GpuKernel { - public: - ActivationGpuFwdKernel() - : cudnn_handle_(nullptr), - activation_desc_(nullptr), - mode_(CUDNN_ACTIVATION_RELU), - data_descriptor_(nullptr), - is_null_input_(false), - cudnn_data_type_(CUDNN_DATA_FLOAT), - input_size_(0), - output_size_(0), - workspace_size_(0) {} - ~ActivationGpuFwdKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *) override { - if (is_null_input_) { - return true; - } - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, - &beta, data_descriptor_, output), - "cudnnActivationForward failed"); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kernel_map.find(node_name); - if (iter == kernel_map.end()) { - MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; - } - mode_ = iter->second; - - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null."; - InitSizeLists(); - return true; - } - std::vector shape; - ShapeNdTo4d(input_shape, &shape); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0), - "cudnnSetActivationDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - shape[0], shape[1], shape[2], shape[3]), - "cudnnSetTensor4dDescriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_), - "cudnnCreateActivationDescriptor failed"); - } - - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - output_size_ = input_size_; - } - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_), - "cudnnDestroyActivationDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - - std::map kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU}, - {"Tanh", CUDNN_ACTIVATION_TANH}, - {"ELU", CUDNN_ACTIVATION_ELU}, - {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}}; - - cudnnHandle_t cudnn_handle_; - cudnnActivationDescriptor_t activation_desc_; - cudnnActivationMode_t mode_; - cudnnTensorDescriptor_t data_descriptor_; - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - cudnnDataType_t cudnn_data_type_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc deleted file mode 100644 index 35d11f8b47..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/activation_grad_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - ReluGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - ReluGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGradGpuKernel, half) - -MS_REG_GPU_KERNEL_ONE( - TanhGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - TanhGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGradGpuKernel, half) - -MS_REG_GPU_KERNEL_ONE( - SigmoidGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - SigmoidGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h deleted file mode 100644 index 38e34eb752..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h +++ /dev/null @@ -1,146 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class ActivationGradGpuKernel : public GpuKernel { - public: - ActivationGradGpuKernel() - : cudnn_handle_(nullptr), - activation_desc_(nullptr), - mode_(CUDNN_ACTIVATION_RELU), - data_descriptor_(nullptr), - is_null_input_(false), - cudnn_data_type_(CUDNN_DATA_FLOAT), - input_size_(0) {} - ~ActivationGradGpuKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *) override { - if (is_null_input_) { - return true; - } - T *dy = nullptr; - T *y = nullptr; - if (mode_ == CUDNN_ACTIVATION_RELU || mode_ == CUDNN_ACTIVATION_ELU) { - dy = GetDeviceAddress(inputs, 0); - y = GetDeviceAddress(inputs, 1); - } else { - y = GetDeviceAddress(inputs, 0); - dy = GetDeviceAddress(inputs, 1); - } - T *dx = GetDeviceAddress(outputs, 0); - - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy, - data_descriptor_, y, &beta, data_descriptor_, dx), - "cudnnActivationBackward failed"); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kernel_map.find(node_name); - if (iter == kernel_map.end()) { - MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; - } - mode_ = iter->second; - - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGradGpuKernel needs 2."; - return false; - } - auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ActivationGradGpuKernel input is null."; - InitSizeLists(); - return true; - } - std::vector shape; - ShapeNdTo4d(input_shape, &shape); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0), - "SetActivationDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - shape[0], shape[1], shape[2], shape[3]), - "SetTensor4dDescriptor failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_), - "cudnnCreateActivationDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_), - "cudnnDestroyActivationDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - - std::map kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU}, - {"TanhGrad", CUDNN_ACTIVATION_TANH}, - {"ELUGrad", CUDNN_ACTIVATION_ELU}, - {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}}; - cudnnHandle_t cudnn_handle_; - cudnnActivationDescriptor_t activation_desc_; - cudnnActivationMode_t mode_; - cudnnTensorDescriptor_t data_descriptor_; - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - cudnnDataType_t cudnn_data_type_; - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc deleted file mode 100644 index 049a5cc280..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/adam_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Adam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - AdamGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Adam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - AdamGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h deleted file mode 100644 index 93c6381ab3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/adam_impl.cuh" -namespace mindspore { -namespace kernel { -template -class AdamGpuKernel : public GpuKernel { - public: - AdamGpuKernel() - : variable_size_(0), - m_size_(0), - v_size_(0), - beta1_power_size_(0), - beta2_power_size_(0), - learning_rate_size_(0), - beta1_size_(0), - beta2_size_(0), - epsilon_size_(0), - gradient_size_(0) {} - - ~AdamGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, - void *stream_ptr) override { - T *variable = GetDeviceAddress(inputs, 0); - T *m = GetDeviceAddress(inputs, 1); - T *v = GetDeviceAddress(inputs, 2); - T *beta1_power = GetDeviceAddress(inputs, 3); - T *beta2_power = GetDeviceAddress(inputs, 4); - T *learning_rate = GetDeviceAddress(inputs, 5); - T *beta1 = GetDeviceAddress(inputs, 6); - T *beta2 = GetDeviceAddress(inputs, 7); - T *epsilon = GetDeviceAddress(inputs, 8); - T *gradient = GetDeviceAddress(inputs, 9); - ApplyAdam(inputs[0]->size / sizeof(T), gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, - variable, m, v, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 10) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 10 inputs."; - return false; - } - - variable_size_ = sizeof(T); - m_size_ = sizeof(T); - v_size_ = sizeof(T); - beta1_power_size_ = sizeof(T); - beta2_power_size_ = sizeof(T); - learning_rate_size_ = sizeof(T); - beta1_size_ = sizeof(T); - beta2_size_ = sizeof(T); - epsilon_size_ = sizeof(T); - gradient_size_ = sizeof(T); - - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < variable_shape.size(); i++) { - variable_size_ *= variable_shape[i]; - } - - auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < m_shape.size(); i++) { - m_size_ *= m_shape[i]; - } - - auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - for (size_t i = 0; i < v_shape.size(); i++) { - v_size_ *= v_shape[i]; - } - - auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); - for (size_t i = 0; i < gradient_shape.size(); i++) { - gradient_size_ *= gradient_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(variable_size_); - input_size_list_.push_back(m_size_); - input_size_list_.push_back(v_size_); - input_size_list_.push_back(beta1_power_size_); - input_size_list_.push_back(beta2_power_size_); - input_size_list_.push_back(learning_rate_size_); - input_size_list_.push_back(beta1_size_); - input_size_list_.push_back(beta2_size_); - input_size_list_.push_back(epsilon_size_); - input_size_list_.push_back(gradient_size_); - output_size_list_.push_back(0); - output_size_list_.push_back(0); - output_size_list_.push_back(0); - } - - private: - size_t variable_size_; - size_t m_size_; - size_t v_size_; - size_t beta1_power_size_; - size_t beta2_power_size_; - size_t learning_rate_size_; - size_t beta1_size_; - size_t beta2_size_; - size_t epsilon_size_; - size_t gradient_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.cc b/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.cc deleted file mode 100644 index ce6c9beeb7..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/bias_add_grad_gpu_kenel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BiasAddGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BiasAddGradGpuKernel, float16) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h b/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h deleted file mode 100644 index 9b4f18d24c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ - -#include -#include -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class BiasAddGradGpuKernel : public GpuKernel { - public: - BiasAddGradGpuKernel() - : same_dims_(true), - cudnn_handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - dy_desc_(nullptr), - db_desc_(nullptr), - op_desc_(nullptr) {} - ~BiasAddGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - T *dy_addr = GetDeviceAddress(inputs, 0); - T *db_addr = GetDeviceAddress(outputs, 0); - T *indices_addr = GetDeviceAddress(workspace, 0); - T *workspace_addr = GetDeviceAddress(workspace, 1); - - const float alpha = 1; - const float beta = 0; - if (same_dims_) { - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(db_addr, dy_addr, output_size_list_[0], cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync failed."); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnReduceTensor(cudnn_handle_, op_desc_, indices_addr, workspace_size_list_[0], workspace_addr, - workspace_size_list_[1], &alpha, dy_desc_, dy_addr, &beta, db_desc_, db_addr), - "cudnnReduceTensor failed"); - } - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto num_dims = dy_shape.size(); - if (num_dims < 2) { - MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims; - } - - std::string format = GetAttr(kernel_node, "data_format"); - string::size_type pos = format.find("C"); - if (pos == std::string::npos || pos >= num_dims) { - MS_LOG(EXCEPTION) << "format '" << format << "' invalid"; - } - - // Expand to 4 dims for cudnnSetTensorNdDescriptorEx. - auto cudnn_dims = std::max(num_dims, 4UL); - std::unique_ptr dy_dims = std::make_unique(cudnn_dims); - std::unique_ptr db_dims = std::make_unique(cudnn_dims); - for (size_t i = 0; i < cudnn_dims; i++) { - dy_dims[i] = (i < num_dims) ? SizeToInt(dy_shape[i]) : 1; - db_dims[i] = (i == pos) ? SizeToInt(dy_shape[i]) : 1; - - if (dy_dims[i] != db_dims[i]) { - same_dims_ = false; - } - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()), - "cudnnSetTensorNdDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), - "cudnnSetTensorNdDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, - CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES), - "cudnnSetReduceTensorDescriptor failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&db_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&op_desc_), "cudnnCreateOpTensorDescriptor failed"); - } - void InitSizeLists() override { - size_t dy_size, db_size; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, &dy_size), "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(db_desc_, &db_size), "cudnnGetTensorSizeInBytes failed"); - input_size_list_.push_back(dy_size); - output_size_list_.push_back(db_size); - - size_t indices_size, workspace_size; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetReductionIndicesSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &indices_size), - "cudnnGetReductionIndicesSize failed") - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetReductionWorkspaceSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &workspace_size), - "cudnnGetReductionWorkspaceSize failed") - workspace_size_list_.push_back(indices_size); - workspace_size_list_.push_back(workspace_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDestroyReduceTensorDescriptor(op_desc_), - "cudnnDestroyReduceTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(db_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyOpTensorDescriptor failed"); - } - - bool same_dims_; - cudnnHandle_t cudnn_handle_; - cudnnDataType_t cudnn_data_type_; - cudnnTensorDescriptor_t dy_desc_; - cudnnTensorDescriptor_t db_desc_; - cudnnReduceTensorDescriptor_t op_desc_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.cc deleted file mode 100644 index df6825e079..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/conv2d_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Conv2D, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - Conv2dGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE( - Conv2D, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - Conv2dGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.h deleted file mode 100644 index f51cbfef33..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.h +++ /dev/null @@ -1,320 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class Conv2dGpuFwdKernel : public GpuKernel { - public: - Conv2dGpuFwdKernel() - : cudnn_handle_(nullptr), - input_desc_(nullptr), - output_desc_(nullptr), - filter_desc_(nullptr), - conv_desc_(nullptr), - padded_desc_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - group_(1), - is_null_input_(false), - input_size_(0), - filter_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~Conv2dGpuFwdKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input_addr = GetDeviceAddress(inputs, 0); - T *filter_addr = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - T *workspace_addr = nullptr; - if (workspace_size_ != 0) { - workspace_addr = GetDeviceAddress(workspace, 0); - } - - const float alpha = 1; - const float beta = 0; - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded_addr = GetDeviceAddress(workspace, 1); - CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, - reinterpret_cast(stream_ptr)); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionForward(cudnn_handle_, &alpha, padded_desc_, padded_addr, filter_desc_, filter_addr, conv_desc_, - conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), - "cudnnConvolutionForward failed"); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionForward(cudnn_handle_, &alpha, input_desc_, input_addr, filter_desc_, filter_addr, conv_desc_, - conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), - "cudnnConvolutionForward failed"); - } - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(in_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "Conv2dGpuFwdKernel input is null."; - InitSizeLists(); - return true; - } - Set4DDesc(in_shape, filter_shape, output_shape); - group_ = GetAttr(kernel_node, "group"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; - pad_mode_ = GetAttr(kernel_node, "pad_mode"); - SetStrideAndDilation(kernel_node); - cudnnTensorDescriptor_t input_descriptor_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - SetPad(in_shape, kernel_node); - input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_; - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2], - dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - input_descriptor_real = input_desc_; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), - "cudnnSetConvolutionMathType failed.") - } - SelectAlgorithm(input_descriptor_real); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&filter_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), - "cudnnCreateConvolutionDescriptor failed"); - } - - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast(&input_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(filter_desc_, reinterpret_cast(&filter_size_)), - "cudnnGetFilterSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast(&output_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast(&padded_size_)), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - input_size_list_.push_back(filter_size_); - output_size_list_.push_back(output_size_); - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, padded_desc_, filter_desc_, conv_desc_, output_desc_, - conv_algorithm_, &workspace_size_), - "cudnnGetConvolutionForwardWorkspaceSize failed"); - workspace_size_list_.push_back(padded_size_); - } else { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, input_desc_, filter_desc_, conv_desc_, output_desc_, - conv_algorithm_, &workspace_size_), - "cudnnGetConvolutionForwardWorkspaceSize failed"); - } - } - (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); - - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), - "cudnnDestroyConvolutionDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); - } - bool CheckParam(const CNodePtr &kernel_node) { - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs."; - return false; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but conv2d needs 1 output."; - return false; - } - return true; - } - void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { - auto pad_list = GetAttr>(kernel_node, "pad_list"); - - n_ = SizeToInt(in_shape[0]); - c_ = SizeToInt(in_shape[1]); - old_height_ = SizeToInt(in_shape[2]); - old_width_ = SizeToInt(in_shape[3]); - pad_height_ = pad_list[0] + pad_list[1]; - pad_width_ = pad_list[2] + pad_list[3]; - pad_top_ = pad_list[0]; - pad_left_ = pad_list[2]; - - // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_, - old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( - conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3], - dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - } - - void Set4DDesc(const std::vector &in_shape, const std::vector &filter_shape, - const std::vector &output_shape) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), - SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(filter_shape[0]), - SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), - "cudnnSetFilter4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), - SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - } - void SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) { - if (group_ > 1 || CUDNN_MAJOR < 7) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetConvolutionForwardAlgorithm( - cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, output_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 0, &conv_algorithm_), - "cudnnGetConvolutionForwardAlgorithm failed"); - } else { - constexpr int requested_algo_count = 1; - int returned_algo_count; - cudnnConvolutionFwdAlgoPerf_t perf_results; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, - output_desc_, requested_algo_count, &returned_algo_count, &perf_results), - "cudnnGetConvolutionForwardAlgorithm_v7 failed"); - conv_algorithm_ = perf_results.algo; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; - } - } - void SetStrideAndDilation(const CNodePtr &kernel_node) { - stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); - dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); - if (stride_.size() != 4) { - MS_LOG(EXCEPTION) << "Conv2d's' stride must be 4d!"; - } - if (stride_[0] != 1 || stride_[1] != 1) { - MS_LOG(EXCEPTION) << "Conv2d stride only support 1 in N axis and C axis!"; - } - if (dilation_.size() != 4) { - MS_LOG(EXCEPTION) << "Conv2d's dilation must be 4d!"; - } - if (dilation_[0] != 1 || dilation_[1] != 1) { - MS_LOG(EXCEPTION) << "Conv2d dilation only support 1 in N axis and C axis!"; - } - } - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t input_desc_; - cudnnTensorDescriptor_t output_desc_; - cudnnFilterDescriptor_t filter_desc_; - cudnnConvolutionFwdAlgo_t conv_algorithm_; - cudnnConvolutionDescriptor_t conv_desc_; - cudnnTensorDescriptor_t padded_desc_; - std::string pad_mode_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - const float pad_value_ = 0.0; - cudnnDataType_t cudnn_data_type_; - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - std::vector stride_; - std::vector dilation_; - int group_; - bool is_null_input_; - size_t input_size_; - size_t filter_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.cc deleted file mode 100644 index 28e9a10ccc..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Conv2DBackpropFilter, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConvGradFilterGpuBkwKernel, float) -MS_REG_GPU_KERNEL_ONE( - Conv2DBackpropFilter, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ConvGradFilterGpuBkwKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.h deleted file mode 100644 index 0d7be25772..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.h +++ /dev/null @@ -1,320 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class ConvGradFilterGpuBkwKernel : public GpuKernel { - public: - ConvGradFilterGpuBkwKernel() - : cudnn_handle_(nullptr), - dw_desc_(nullptr), - conv_desc_(nullptr), - dy_desc_(nullptr), - x_desc_(nullptr), - padded_descriptor_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - group_(1), - is_null_input_(false), - input_size_(0), - dy_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~ConvGradFilterGpuBkwKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *dy = GetDeviceAddress(inputs, 0); - T *x = GetDeviceAddress(inputs, 1); - T *dw = GetDeviceAddress(outputs, 0); - T *work_space = nullptr; - if (workspace_size_ != 0) { - work_space = GetDeviceAddress(workspace, 0); - } - - const float alpha = 1; - const float beta = 0; - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded = GetDeviceAddress(workspace, 1); - CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, - reinterpret_cast(stream_ptr)); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_, - algo_, work_space, workspace_size_, &beta, dw_desc_, dw), - "ConvolutionBackwardFilter failed"); - return true; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, x_desc_, x, dy_desc_, dy, conv_desc_, algo_, work_space, - workspace_size_, &beta, dw_desc_, dw), - "ConvolutionBackwardFilter failed"); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ConvGradFilterGpuBkwKernel input is null."; - InitSizeLists(); - return true; - } - std::vector filter_shape; - GetFilterShape(kernel_node, &filter_shape); - Set4DDesc(dy_shape, filter_shape, in_shape); - group_ = GetAttr(kernel_node, "group"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; - pad_mode_ = GetAttr(kernel_node, "pad_mode"); - SetStrideAndDilation(kernel_node); - cudnnTensorDescriptor_t x_desc_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - SetPad(in_shape, kernel_node); - x_desc_real = use_pad_ ? padded_descriptor_ : x_desc_; - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], - dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "GetConvolution2dDescriptor failed"); - x_desc_real = x_desc_; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), - "cudnnSetConvolutionMathType failed.") - } - SelectAlgorithm(x_desc_real); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&dw_desc_), "cudnnCreateFilterDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), - "cudnnCreateConvolutionDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, reinterpret_cast(&dy_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, reinterpret_cast(&input_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(dw_desc_, reinterpret_cast(&output_size_)), - "cudnnGetFilterSizeInBytes failed"); - } - input_size_list_.push_back(dy_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast(&padded_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, padded_descriptor_, dy_desc_, conv_desc_, - dw_desc_, algo_, reinterpret_cast(&workspace_size_)), - "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); - workspace_size_list_.push_back(padded_size_); - } else { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, x_desc_, dy_desc_, conv_desc_, dw_desc_, algo_, - reinterpret_cast(&workspace_size_)), - "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); - } - } - (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), - "cudnnDestroyConvolutionDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "cudnnDestroyFilterDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyTensorDescriptor failed"); - } - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ConvGradFilter needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but ConvGradFilter needs 1 output."; - return false; - } - return true; - } - void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { - auto pad_list = GetAttr>(kernel_node, "pad_list"); - n_ = SizeToInt(in_shape[0]); - c_ = SizeToInt(in_shape[1]); - old_height_ = SizeToInt(in_shape[2]); - old_width_ = SizeToInt(in_shape[3]); - pad_height_ = pad_list[0] + pad_list[1]; - pad_width_ = pad_list[2] + pad_list[3]; - pad_top_ = pad_list[0]; - pad_left_ = pad_list[2]; - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, - c_, old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( - conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1], - dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - } - void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { - if (group_ > 1 || CUDNN_MAJOR < 7) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, 0, &algo_), - "GetConvolutionBackwardFilterAlgorithm failed"); - } else { - constexpr int requested_algo_count = 1; - int returned_algo_count; - cudnnConvolutionBwdFilterAlgoPerf_t perf_results; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, - requested_algo_count, &returned_algo_count, &perf_results), - "GetConvolutionBackwardFilterAlgorithm failed"); - algo_ = perf_results.algo; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - } - } - void GetFilterShape(const CNodePtr &kernel_node, std::vector *filter_shape) { - auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast()->value(); - (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape), - [](const ValuePtr &e) -> int { return e->cast()->value(); }); - } - void Set4DDesc(const std::vector &dy_shape, const std::vector &filter_shape, - const std::vector &in_shape) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), - SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), - "SetTensor4dDescriptor failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetFilter4dDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), filter_shape[1], - filter_shape[2], filter_shape[3]), - "SetFilter4dDescriptor failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), - SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), - "SetTensor4dDescriptor failed"); - } - void SetStrideAndDilation(const CNodePtr &kernel_node) { - stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); - dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); - if (stride_.size() != 2) { - MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel's stride must be 2d!"; - } - if (dilation_.size() != 4) { - MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel's dilation must be 4d!"; - } - if (dilation_[0] != 1 || dilation_[1] != 1) { - MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel dilation only support 1 in N axis and C axis!"; - } - } - cudnnHandle_t cudnn_handle_; - cudnnFilterDescriptor_t dw_desc_; - cudnnConvolutionDescriptor_t conv_desc_; - cudnnTensorDescriptor_t dy_desc_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t padded_descriptor_; - cudnnConvolutionBwdFilterAlgo_t algo_; - std::string pad_mode_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - const float pad_value_ = 0.0; - cudnnDataType_t cudnn_data_type_; - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - std::vector stride_; - std::vector dilation_; - int group_; - bool is_null_input_; - size_t input_size_; - size_t dy_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.cc deleted file mode 100644 index 12b6f91537..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/conv2d_grad_input_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Conv2DBackpropInput, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConvGradInputGpuBkwKernel, float) -MS_REG_GPU_KERNEL_ONE( - Conv2DBackpropInput, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ConvGradInputGpuBkwKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.h deleted file mode 100644 index a33ea5b4da..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.h +++ /dev/null @@ -1,315 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class ConvGradInputGpuBkwKernel : public GpuKernel { - public: - ConvGradInputGpuBkwKernel() - : cudnn_handle_(nullptr), - w_desc_(nullptr), - conv_desc_(nullptr), - dy_desc_(nullptr), - dx_desc_(nullptr), - padded_descriptor_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - group_(1), - is_null_input_(false), - dy_size_(0), - w_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~ConvGradInputGpuBkwKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *dy = GetDeviceAddress(inputs, 0); - T *w = GetDeviceAddress(inputs, 1); - T *dx = GetDeviceAddress(outputs, 0); - T *work_space = nullptr; - if (workspace_size_ != 0) { - work_space = GetDeviceAddress(workspace, 0); - } - - const float alpha = 1; - const float beta = 0; - - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded = GetDeviceAddress(workspace, 1); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, - workspace_size_, &beta, padded_descriptor_, padded), - "ConvolutionBackwardData failed"); - CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, - workspace_size_, &beta, dx_desc_, dx), - "ConvolutionBackwardData failed"); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(dy_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ConvGradInputGpuBkwKernel input is null."; - InitSizeLists(); - return true; - } - std::vector input_shape; - GetInputShape(kernel_node, &input_shape); - Set4DDesc(dy_shape, input_shape, filter_shape); - - group_ = GetAttr(kernel_node, "group"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; - pad_mode_ = GetAttr(kernel_node, "pad_mode"); - SetStrideAndDilation(kernel_node); - cudnnTensorDescriptor_t dx_desc_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - SetPad(input_shape, kernel_node); - dx_desc_real = use_pad_ ? padded_descriptor_ : dx_desc_; - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], - dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - dx_desc_real = dx_desc_; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), - "cudnnSetConvolutionMathType failed.") - } - SelectAlgorithm(dx_desc_real); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "cudnnCreateFilterDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), - "cudnnCreateConvolutionDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, &dy_size_), "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(w_desc_, &w_size_), "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dx_desc_, &output_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(dy_size_); - input_size_list_.push_back(w_size_); - output_size_list_.push_back(output_size_); - - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), - "cudnnGetTensorSizeInBytes failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, padded_descriptor_, - algo_, &workspace_size_), - "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); - workspace_size_list_.push_back(padded_size_); - } else { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_, algo_, &workspace_size_), - "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); - } - } - (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), - "cudnnDestroyConvolutionDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "cudnnDestroyFilterDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "cudnnDestroyTensorDescriptor failed"); - } - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ConvGradInput needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but ConvGradInput needs 1 output."; - return false; - } - return true; - } - void SetPad(const std::vector &input_shape, const CNodePtr &kernel_node) { - auto pad_list = GetAttr>(kernel_node, "pad_list"); - n_ = input_shape[0]; - c_ = input_shape[1]; - old_height_ = input_shape[2]; - old_width_ = input_shape[3]; - pad_height_ = pad_list[0] + pad_list[1]; - pad_width_ = pad_list[2] + pad_list[3]; - pad_top_ = pad_list[0]; - pad_left_ = pad_list[2]; - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, - c_, old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( - conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1], - dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - } - void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { - if (group_ > 1 || CUDNN_MAJOR < 7) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, 0, &algo_), - "cudnnGetConvolutionBackwardDataAlgorithm failed"); - } else { - constexpr int requested_algo_count = 1; - int returned_algo_count; - cudnnConvolutionBwdDataAlgoPerf_t perf_results; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, - requested_algo_count, &returned_algo_count, &perf_results), - "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed"); - algo_ = perf_results.algo; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - } - } - void GetInputShape(const CNodePtr &kernel_node, std::vector *input_shape) { - auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast()->value(); - (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape), - [](const ValuePtr &e) -> int { return e->cast()->value(); }); - } - void Set4DDesc(const std::vector &dy_shape, const std::vector &input_shape, - const std::vector &filter_shape) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetFilter4dDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), - SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), - "SetFilter4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), - SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), - "SetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, input_shape[0], input_shape[1], - input_shape[2], input_shape[3]), - "SetTensor4dDescriptor failed"); - } - void SetStrideAndDilation(const CNodePtr &kernel_node) { - stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); - dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); - if (stride_.size() != 2) { - MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's stride must be 2d!"; - } - if (dilation_.size() != 4) { - MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's dilation must be 4d!"; - } - if (dilation_[0] != 1 || dilation_[1] != 1) { - MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel dilation only support 1 in N axis and C axis!"; - } - } - cudnnHandle_t cudnn_handle_; - cudnnFilterDescriptor_t w_desc_; - cudnnConvolutionDescriptor_t conv_desc_; - cudnnTensorDescriptor_t dy_desc_; - cudnnTensorDescriptor_t dx_desc_; - cudnnTensorDescriptor_t padded_descriptor_; - cudnnConvolutionBwdDataAlgo_t algo_; - std::string pad_mode_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - cudnnDataType_t cudnn_data_type_; - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - std::vector stride_; - std::vector dilation_; - int group_; - bool is_null_input_; - size_t dy_size_; - size_t w_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc deleted file mode 100644 index 459010e9e9..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/dropout_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Dropout, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - DropoutGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE( - Dropout, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - DropoutGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h deleted file mode 100644 index 4dfacb7ca1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/dropout_impl.cuh" -#include "include/curand.h" - -namespace mindspore { -namespace kernel { -template -class DropoutGpuFwdKernel : public GpuKernel { - public: - DropoutGpuFwdKernel() - : cudnn_handle_(nullptr), - is_null_input_(false), - num_count_(0), - keep_prob_(0.0), - states_init_(false), - mask_generator_(nullptr) {} - - ~DropoutGpuFwdKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - T *mask = GetDeviceAddress(outputs, 1); - float *mask_f = GetDeviceAddress(workspace, 0); - - if (!states_init_) { - curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT); - curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL)); - states_init_ = true; - } - // curandGen only support float or double for mask. - curandGenerateUniform(mask_generator_, mask_f, num_count_); - DropoutForward(input, mask, output, mask_f, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); - - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1."; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - InitSizeLists(); - return true; - } - - num_count_ = 1; - for (size_t x : input_shape) { - num_count_ *= x; - } - keep_prob_ = GetAttr(kernel_node, "keep_prob"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - - void InitSizeLists() override { - size_t input_size = num_count_ * sizeof(T); - input_size_list_.push_back(input_size); - output_size_list_.push_back(input_size); // output size: the same with input size - output_size_list_.push_back(input_size); // mask size: the same with input size - workspace_size_list_.push_back(num_count_ * sizeof(float)); // temp mask_f for curandGen - } - - private: - cudnnHandle_t cudnn_handle_; - bool is_null_input_; - size_t num_count_; - float keep_prob_; - bool states_init_; - curandGenerator_t mask_generator_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc deleted file mode 100644 index 2fd21c96ee..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/dropout_grad_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - DropoutGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - DropoutGradGpuBwdKernel, float) -MS_REG_GPU_KERNEL_ONE( - DropoutGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - DropoutGradGpuBwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h deleted file mode 100644 index e6683e15dd..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/dropout_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class DropoutGradGpuBwdKernel : public GpuKernel { - public: - DropoutGradGpuBwdKernel() : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {} - ~DropoutGradGpuBwdKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - - T *dy = GetDeviceAddress(inputs, 0); - T *mask = GetDeviceAddress(inputs, 1); - T *dx = GetDeviceAddress(outputs, 0); - - DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuBwdKernel needs 2."; - return false; - } - - auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - InitSizeLists(); - return true; - } - - num_count_ = 1; - for (size_t x : input_shape) { - num_count_ *= x; - } - keep_prob_ = GetAttr(kernel_node, "keep_prob"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - void InitSizeLists() override { - size_t dy_size = num_count_ * sizeof(T); - size_t mask_size = dy_size; - size_t dx_size = dy_size; - - input_size_list_.push_back(dy_size); - input_size_list_.push_back(mask_size); - output_size_list_.push_back(dx_size); - } - - private: - cudnnHandle_t cudnn_handle_; - bool is_null_input_; - size_t num_count_; - float keep_prob_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.cc deleted file mode 100644 index f9c993d31d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/flatten_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FlattenGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - FlattenGpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FlattenGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FlattenGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - FlattenGpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FlattenGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FlattenGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FlattenGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - FlattenGpuFwdKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h deleted file mode 100644 index 3b0ad8c946..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -class FlattenGpuFwdKernel : public GpuKernel { - public: - FlattenGpuFwdKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} - ~FlattenGpuFwdKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - cudaError_t ret = - cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)); - if (ret) { - MS_LOG(ERROR) << "cudaMemcpyAsync error in FlattenGpuFwdKernel::Launch, error code is " << ret; - return false; - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_size_ = sizeof(T); - for (size_t i = 0; i < shape.size(); ++i) { - input_size_ *= shape[i]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_ = input_size_; - output_size_list_.push_back(output_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.cc deleted file mode 100644 index 0e079d137b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.cc +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/flatten_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FlattenGardGpuBkwKernel, float) -MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FlattenGardGpuBkwKernel, half) -MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - FlattenGardGpuBkwKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.h deleted file mode 100644 index 0748dc77db..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.h +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -class FlattenGardGpuBkwKernel : public GpuKernel { - public: - FlattenGardGpuBkwKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} - ~FlattenGardGpuBkwKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - cudaError_t ret = - cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)); - if (ret) { - MS_LOG(ERROR) << "cudaMemcpyAsync error in FlattenGardGpuFwdKernel::Launch, error code is " << ret; - return false; - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but FlattenGardGpuFwdKernel needs 1."; - return false; - } - - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < shape.size(); ++i) { - if (input_size_ == 0) { - input_size_ = 1; - } - input_size_ *= shape[i]; - } - input_size_ = input_size_ * sizeof(T); - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_ = input_size_; - output_size_list_.push_back(output_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc deleted file mode 100644 index 4d30130931..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/ftrl_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(ApplyFtrl, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FtrlGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ApplyFtrl, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - FtrlGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h deleted file mode 100644 index 9e2153965b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/ftrl_impl.cuh" -namespace mindspore { -namespace kernel { -template -class FtrlGpuKernel : public GpuKernel { - public: - FtrlGpuKernel() - : variable_size_(0), - accumulation_size_(0), - linear_size_(0), - gradient_size_(0), - learning_rate_size_(0), - l1_regularization_size_(0), - l2_regularization_size_(0), - learning_rate_power_size_(0) {} - - ~FtrlGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, - void *stream_ptr) override { - T *variable = GetDeviceAddress(inputs, 0); - T *accumulation = GetDeviceAddress(inputs, 1); - T *linear = GetDeviceAddress(inputs, 2); - T *gradient = GetDeviceAddress(inputs, 3); - T *learning_rate = GetDeviceAddress(inputs, 4); - T *l1_regularization = GetDeviceAddress(inputs, 5); - T *l2_regularization = GetDeviceAddress(inputs, 6); - T *learning_rate_power = GetDeviceAddress(inputs, 7); - ApplyFtrl(inputs[0]->size / sizeof(T), gradient, learning_rate, l1_regularization, l2_regularization, - learning_rate_power, variable, accumulation, linear, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 8) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 8 inputs."; - return false; - } - - variable_size_ = sizeof(T); - accumulation_size_ = sizeof(T); - linear_size_ = sizeof(T); - gradient_size_ = sizeof(T); - learning_rate_size_ = sizeof(T); - l1_regularization_size_ = sizeof(T); - l2_regularization_size_ = sizeof(T); - learning_rate_power_size_ = sizeof(T); - - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < variable_shape.size(); i++) { - variable_size_ *= variable_shape[i]; - } - - auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < accumulation_shape.size(); i++) { - accumulation_size_ *= accumulation_shape[i]; - } - - auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - for (size_t i = 0; i < linear_shape.size(); i++) { - linear_size_ *= linear_shape[i]; - } - - auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - for (size_t i = 0; i < gradient_shape.size(); i++) { - gradient_size_ *= gradient_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(variable_size_); - input_size_list_.push_back(accumulation_size_); - input_size_list_.push_back(linear_size_); - input_size_list_.push_back(gradient_size_); - input_size_list_.push_back(learning_rate_size_); - input_size_list_.push_back(l1_regularization_size_); - input_size_list_.push_back(l2_regularization_size_); - input_size_list_.push_back(learning_rate_power_size_); - output_size_list_.push_back(0); - } - - private: - size_t variable_size_; - size_t accumulation_size_; - size_t linear_size_; - size_t gradient_size_; - size_t learning_rate_size_; - size_t l1_regularization_size_; - size_t l2_regularization_size_; - size_t learning_rate_power_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc deleted file mode 100644 index 77cb7f8608..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/fused_adam_weight_decay.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedAdamWeightDecayGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FusedAdam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedAdamWeightDecayGpuKernel, float) - -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h deleted file mode 100644 index f13f6ed59f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class FusedAdamWeightDecayGpuKernel : public GpuKernel { - public: - FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {} - ~FusedAdamWeightDecayGpuKernel() override = default; - - bool Init(const CNodePtr &kernel_node) override { - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - if (node_name == "AdamWeighDecay") { - weight_decay_ = true; - } - - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7); - element_nums_ = 1; - for (auto i : shape) { - element_nums_ *= i; - } - - InitSizeLists(); - return true; - } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - float *beta1 = GetDeviceAddress(inputs, 0); - float *one_sub_beta1 = GetDeviceAddress(inputs, 1); - float *beta2 = GetDeviceAddress(inputs, 2); - float *one_sub_beta2 = GetDeviceAddress(inputs, 3); - float *epsilon = GetDeviceAddress(inputs, 4); - float *lr = GetDeviceAddress(inputs, 5); - T *param = GetDeviceAddress(inputs, 6); - T *m = GetDeviceAddress(inputs, 7); - T *v = GetDeviceAddress(inputs, 8); - T *gradient = GetDeviceAddress(inputs, 9); - float *weight_decay = nullptr; - if (weight_decay_) { - weight_decay = GetDeviceAddress(inputs, 10); - } - AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, - param, gradient, reinterpret_cast(stream_ptr)); - return true; - } - - protected: - void InitResource() override{}; - void InitSizeLists() override { - input_size_list_.push_back(sizeof(float)); - input_size_list_.push_back(sizeof(float)); - input_size_list_.push_back(sizeof(float)); - input_size_list_.push_back(sizeof(float)); - input_size_list_.push_back(element_nums_ * sizeof(T)); - input_size_list_.push_back(sizeof(float)); - input_size_list_.push_back(sizeof(float)); - input_size_list_.push_back(element_nums_ * sizeof(T)); - if (weight_decay_) { - input_size_list_.push_back(sizeof(float)); - } - output_size_list_.push_back(element_nums_ * sizeof(T)); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int element_nums_; - bool weight_decay_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc deleted file mode 100644 index 91747d24d8..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/fused_batch_norm_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedBatchNormGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - FusedBatchNormGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(BatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedBatchNormGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(BatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - FusedBatchNormGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h deleted file mode 100644 index b0a898209b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class FusedBatchNormGpuKernel : public GpuKernel { - public: - FusedBatchNormGpuKernel() - : batch_(0), - channel_(0), - height_(0), - width_(0), - mode_(CUDNN_BATCHNORM_SPATIAL), - epsilon_(10e-5), - exp_avg_factor_(0.1), - is_train_(false), - is_null_input_(false), - x_desc_(nullptr), - y_desc_(nullptr), - scale_bias_mean_var_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~FusedBatchNormGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - VARIABLE_NOT_USED(stream_ptr); - if (is_null_input_) { - return true; - } - auto x = GetDeviceAddress(inputs, 0); - auto scale = GetDeviceAddress(inputs, 1); - auto bias = GetDeviceAddress(inputs, 2); - auto runing_mean = GetDeviceAddress(inputs, 3); - auto runnig_variance = GetDeviceAddress(inputs, 4); - auto y = GetDeviceAddress(outputs, 0); - - const float alpha = 1; - const float beta = 0; - if (is_train_) { - auto save_mean = GetDeviceAddress(outputs, 3); - auto save_variance = GetDeviceAddress(outputs, 4); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, - scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, runing_mean, - runnig_variance, epsilon_, save_mean, save_variance), - "Kernel launch failed"); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardInference(handle_, mode_, &alpha, &beta, x_desc_, x, - y_desc_, y, scale_bias_mean_var_desc_, scale, - bias, runing_mean, runnig_variance, epsilon_), - "Kernel launch failed"); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 5) { - MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5"; - } - - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (shape.size() != 4) { - MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGpuKernel should be >= 4"; - } - is_null_input_ = CHECK_NULL_INPUT(shape); - if (is_null_input_) { - MS_LOG(WARNING) << "FusedBatchNormGpuKernel input is null"; - InitSizeLists(); - return true; - } - batch_ = SizeToInt(shape[0]); - channel_ = SizeToInt(shape[1]); - height_ = SizeToInt(shape[2]); - width_ = SizeToInt(shape[3]); - - mode_ = CUDNN_BATCHNORM_SPATIAL; - epsilon_ = GetAttr(kernel_node, "epsilon"); - // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - if (node_name == "FusedBatchNorm") { - is_train_ = true; - exp_avg_factor_ = GetAttr(kernel_node, "momentum"); - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set x desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set y desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), - "Set para desc failed"); - - InitSizeLists(); - - return true; - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); - } - void InitSizeLists() override { - size_t input_size = 0; - size_t para_size = 0; - size_t output_size = 0; - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &input_size), "Get input size failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, ¶_size), - "Get para size failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(y_desc_, &output_size), "Get para size failed"); - } - input_size_list_.push_back(input_size); - input_size_list_.push_back(para_size); // scale - input_size_list_.push_back(para_size); // bias - input_size_list_.push_back(para_size); // mean - input_size_list_.push_back(para_size); // variance - - output_size_list_.push_back(output_size); - output_size_list_.push_back(para_size); // running mean - output_size_list_.push_back(para_size); // running variance - output_size_list_.push_back(para_size); // save mean - output_size_list_.push_back(para_size); // save variance - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed"); - } - - int batch_; - int channel_; - int height_; - int width_; - cudnnBatchNormMode_t mode_; - double epsilon_; - double exp_avg_factor_; - bool is_train_; - bool is_null_input_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t y_desc_; - cudnnTensorDescriptor_t scale_bias_mean_var_desc_; - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc deleted file mode 100644 index 3947aaea9a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedBatchNormGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - FusedBatchNormGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h deleted file mode 100644 index 712354b17c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h +++ /dev/null @@ -1,178 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class FusedBatchNormGradGpuKernel : public GpuKernel { - public: - FusedBatchNormGradGpuKernel() - : batch_(0), - channel_(0), - height_(0), - width_(0), - mode_(CUDNN_BATCHNORM_SPATIAL), - epsilon_(10e-5), - is_null_input_(false), - x_desc_(nullptr), - dy_desc_(nullptr), - dx_desc_(nullptr), - scale_bias_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~FusedBatchNormGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - VARIABLE_NOT_USED(stream_ptr); - if (is_null_input_) { - return true; - } - auto dy = GetDeviceAddress(inputs, 0); - auto x = GetDeviceAddress(inputs, 1); - auto scale = GetDeviceAddress(inputs, 2); - auto save_mean = GetDeviceAddress(inputs, 3); - auto save_variance = GetDeviceAddress(inputs, 4); - auto dx = GetDeviceAddress(outputs, 0); - auto bn_scale = GetDeviceAddress(outputs, 1); - auto bn_bias = GetDeviceAddress(outputs, 2); - - const float alpha_data_diff = 1; - const float beta_data_diff = 0; - const float alpha_param_diff = 1; - const float beta_param_diff = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff, - &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale, - bn_scale, bn_bias, epsilon_, save_mean, save_variance), - "Kernel Launch Failed."); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 5) { - MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5"; - } - - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (shape.size() != 4) { - MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradGpuKernel should be 4"; - return false; - } - is_null_input_ = CHECK_NULL_INPUT(shape); - if (is_null_input_) { - MS_LOG(WARNING) << "FusedBatchNormGradGpuKernel input is null"; - InitSizeLists(); - return true; - } - batch_ = SizeToInt(shape[0]); - channel_ = SizeToInt(shape[1]); - height_ = SizeToInt(shape[2]); - width_ = SizeToInt(shape[3]); - - mode_ = CUDNN_BATCHNORM_SPATIAL; - epsilon_ = GetAttr(kernel_node, "epsilon"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set x desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set dy desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set dx desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), - "Set para desc failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_desc_), "Create para desc failed"); - } - - void InitSizeLists() override { - size_t input_size = 0; - size_t para_size = 0; - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &input_size), "Get input size failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_desc_, ¶_size), "Get input size failed"); - } - - input_size_list_.push_back(input_size); - input_size_list_.push_back(input_size); - input_size_list_.push_back(para_size); - input_size_list_.push_back(para_size); - input_size_list_.push_back(para_size); - - output_size_list_.push_back(input_size); - output_size_list_.push_back(para_size); - output_size_list_.push_back(para_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_desc_), "Destroy para desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); - } - - int batch_; - int channel_; - int height_; - int width_; - - cudnnBatchNormMode_t mode_; - double epsilon_; - bool is_null_input_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t dy_desc_; - cudnnTensorDescriptor_t dx_desc_; - cudnnTensorDescriptor_t scale_bias_desc_; - - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc deleted file mode 100644 index 32d91be80a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/gelu_grad_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(GeluGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - GeLUGpuGradKernel, float) -MS_REG_GPU_KERNEL_ONE(GeluGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - GeLUGpuGradKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h deleted file mode 100644 index 6415349012..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/gelu_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class GeLUGpuGradKernel : public GpuKernel { - public: - GeLUGpuGradKernel() : input_size_(0) {} - ~GeLUGpuGradKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *dy_addr = GetDeviceAddress(inputs, 0); - T *x_addr = GetDeviceAddress(inputs, 1); - T *dx_addr = GetDeviceAddress(outputs, 0); - - GeluGradKernel(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - input_size_ = sizeof(T); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (auto dim : input_shape) { - input_size_ *= dim; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc deleted file mode 100644 index ca54ff68ad..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/gelu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - GeluGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - GeluGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h deleted file mode 100644 index 60968d109b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/gelu_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class GeluGpuKernel : public GpuKernel { - public: - GeluGpuKernel() : input_size_(0) {} - ~GeluGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - - Gelu(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - input_size_ = sizeof(T); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (auto dim : input_shape) { - input_size_ *= dim; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc deleted file mode 100644 index 19e4dc17a6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/layer_norm_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LayerNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LayerNormGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LayerNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LayerNormGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h deleted file mode 100644 index d5ec3ff8f2..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class LayerNormGpuKernel : public GpuKernel { - public: - LayerNormGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} - ~LayerNormGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - auto x = GetDeviceAddress(inputs, 0); - auto gamma = GetDeviceAddress(inputs, 1); - auto beta = GetDeviceAddress(inputs, 2); - auto y = GetDeviceAddress(outputs, 0); - auto mean = GetDeviceAddress(outputs, 1); - auto variance = GetDeviceAddress(outputs, 2); - - const T epsilon = 10e-12; - LayerNorm(input_row_, input_col_, param_dim_, epsilon, x, gamma, beta, y, mean, variance, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - int begin_norm_axis = GetAttr(kernel_node, "begin_norm_axis"); - int begin_params_axis = GetAttr(kernel_node, "begin_params_axis"); - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (begin_norm_axis < 0) { - begin_norm_axis += input_shape.size(); - } - - if (begin_params_axis < 0) { - begin_params_axis += input_shape.size(); - } - - for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { - input_row_ *= input_shape[i]; - } - - for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { - input_col_ *= input_shape[i]; - } - - for (size_t i = begin_params_axis; i < input_shape.size(); i++) { - param_dim_ *= input_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - input_size_list_.push_back(param_dim_ * sizeof(T)); - input_size_list_.push_back(param_dim_ * sizeof(T)); - - output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - output_size_list_.push_back(input_row_ * sizeof(T)); - output_size_list_.push_back(input_row_ * sizeof(T)); - return; - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int input_row_; - int input_col_; - int param_dim_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc deleted file mode 100644 index 7991d42499..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/layer_norm_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LayerNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LayerNormGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LayerNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LayerNormGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h deleted file mode 100644 index 83bdedb9b3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class LayerNormGradGpuKernel : public GpuKernel { - public: - LayerNormGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} - ~LayerNormGradGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - auto x = GetDeviceAddress(inputs, 0); - auto dy = GetDeviceAddress(inputs, 1); - auto var = GetDeviceAddress(inputs, 2); - auto mean = GetDeviceAddress(inputs, 3); - auto gamma = GetDeviceAddress(inputs, 4); - auto dx = GetDeviceAddress(outputs, 0); - auto dg = GetDeviceAddress(outputs, 1); - auto db = GetDeviceAddress(outputs, 2); - - const T epsilon = 10e-12; - LayerNormGrad(input_row_, input_col_, param_dim_, epsilon, dy, x, mean, var, gamma, dx, dg, db, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - int begin_norm_axis = GetAttr(kernel_node, "begin_norm_axis"); - int begin_params_axis = GetAttr(kernel_node, "begin_params_axis"); - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (begin_norm_axis < 0) { - begin_norm_axis += input_shape.size(); - } - - if (begin_params_axis < 0) { - begin_params_axis += input_shape.size(); - } - - for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { - input_row_ *= input_shape[i]; - } - - for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { - input_col_ *= input_shape[i]; - } - - for (size_t i = begin_params_axis; i < input_shape.size(); i++) { - param_dim_ *= input_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - input_size_list_.push_back(input_row_ * sizeof(T)); - input_size_list_.push_back(input_row_ * sizeof(T)); - input_size_list_.push_back(param_dim_ * sizeof(T)); - - output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - output_size_list_.push_back(param_dim_ * sizeof(T)); - output_size_list_.push_back(param_dim_ * sizeof(T)); - return; - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int input_row_; - int input_col_; - int param_dim_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.cc deleted file mode 100644 index c745c216f7..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/lstm_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LSTM, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LstmGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LSTM, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LstmGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.h deleted file mode 100644 index 42eda96b02..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.h +++ /dev/null @@ -1,247 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class LstmGpuKernel : public GpuKernel { - public: - LstmGpuKernel() - : batch_size_(0), - seq_len_(0), - input_size_(0), - hidden_size_(0), - num_layers_(0), - has_bias_(false), - bidirectional_(false), - states_init_(false), - dropout_(0), - weight_size_(0), - reserved_size_(0), - x_desc_(nullptr), - hx_desc_(nullptr), - cx_desc_(nullptr), - w_desc_(nullptr), - dropout_desc_(nullptr), - y_desc_(nullptr), - hy_desc_(nullptr), - cy_desc_(nullptr), - rnn_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~LstmGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(stream_ptr); - auto x_addr = GetDeviceAddress(inputs, 0); - auto hx_addr = GetDeviceAddress(inputs, 1); - auto cx_addr = GetDeviceAddress(inputs, 2); - auto w_addr = GetDeviceAddress(inputs, 3); - auto y_addr = GetDeviceAddress(outputs, 0); - auto hy_addr = GetDeviceAddress(outputs, 1); - auto cy_addr = GetDeviceAddress(outputs, 2); - auto reserved_addr = GetDeviceAddress(outputs, 3); - auto states_addr = GetDeviceAddress(outputs, 4); - void *workspace_addr = GetDeviceAddress(workspace, 0); - - if (!states_init_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, output_size_list_[4], 0), - "set dropout_desc failed"); - states_init_ = true; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRNNForwardTraining(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, cx_desc_, cx_addr, - w_desc_, w_addr, y_desc_.get(), y_addr, hy_desc_, hy_addr, cy_desc_, cy_addr, - workspace_addr, workspace_size_list_[0], reserved_addr, reserved_size_), - "launch lstm kernel failed"); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - seq_len_ = SizeToInt(input_shape[0]); - batch_size_ = SizeToInt(input_shape[1]); - input_size_ = SizeToInt(input_shape[2]); - - input_size_ = GetAttr(kernel_node, "input_size"); - hidden_size_ = GetAttr(kernel_node, "hidden_size"); - num_layers_ = GetAttr(kernel_node, "num_layers"); - has_bias_ = GetAttr(kernel_node, "has_bias"); - bidirectional_ = GetAttr(kernel_node, "bidirectional"); - dropout_ = GetAttr(kernel_node, "dropout"); - - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t rnn_mode = CUDNN_LSTM; - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - CreateTensorDescGrp(); - int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set hx_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set cx_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set hy_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set cy_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), - "set dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, - input_mode, direction, rnn_mode, algo, cudnn_data_type_), - "set rnn_desc failed"); - cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); - auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), - "get weight_size_ failed"); - if (weight_size != weight_size_) { - MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; - } - int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), - "set w_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_), - "get reserve size failed"); - InitSizeLists(); - return true; - } - void CreateTensorDescGrp() { - int x_dims[3]{batch_size_, input_size_, 1}; - int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1}; - - x_desc_ = std::make_unique(seq_len_); - y_desc_ = std::make_unique(seq_len_); - - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), "set x_desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); - } - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hy_desc_), "create hy_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cy_desc_), "create cy_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); - } - void InitSizeLists() override { - size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); - - size_t h_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); - - input_size_list_.push_back(x_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(weight_size_); - - size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); - output_size_list_.push_back(y_size); - output_size_list_.push_back(h_size); - output_size_list_.push_back(h_size); - output_size_list_.push_back(reserved_size_); - size_t state_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); - output_size_list_.push_back(state_size); - - size_t workspace_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size), - "get workspace size failed"); - workspace_size_list_.push_back(workspace_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cy_desc_), "destroy cy_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hy_desc_), "destroy hy_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc failed"); - - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed"); - } - } - - int batch_size_; - int seq_len_; - int input_size_; - int hidden_size_; - int num_layers_; - - bool has_bias_; - bool bidirectional_; - bool states_init_; - float dropout_; - - size_t weight_size_; - size_t reserved_size_; - - // input desc - std::unique_ptr x_desc_; - cudnnTensorDescriptor_t hx_desc_; - cudnnTensorDescriptor_t cx_desc_; - cudnnFilterDescriptor_t w_desc_; - cudnnDropoutDescriptor_t dropout_desc_; - std::unique_ptr y_desc_; - cudnnTensorDescriptor_t hy_desc_; - cudnnTensorDescriptor_t cy_desc_; - cudnnRNNDescriptor_t rnn_desc_; - - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.cc deleted file mode 100644 index ab88308d4e..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/lstm_grad_data_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LSTMGradData, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LstmGradDataGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LSTMGradData, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LstmGradDataGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.h deleted file mode 100644 index 6eeefa262c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.h +++ /dev/null @@ -1,284 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class LstmGradDataGpuKernel : public GpuKernel { - public: - LstmGradDataGpuKernel() - : batch_size_(0), - seq_len_(0), - input_size_(0), - hidden_size_(0), - num_layers_(0), - has_bias_(false), - bidirectional_(false), - states_init_(false), - dropout_(0), - weight_size_(0), - reserved_size_(0), - rnn_desc_(nullptr), - y_desc_(nullptr), - dy_desc_(nullptr), - dhy_desc_(nullptr), - dcy_desc_(nullptr), - w_desc_(nullptr), - hx_desc_(nullptr), - cx_desc_(nullptr), - dropout_desc_(nullptr), - dx_desc_(nullptr), - dhx_desc_(nullptr), - dcx_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~LstmGradDataGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(stream_ptr); - auto y_addr = GetDeviceAddress(inputs, 0); - auto dy_addr = GetDeviceAddress(inputs, 1); - auto dhy_addr = GetDeviceAddress(inputs, 2); - auto dcy_addr = GetDeviceAddress(inputs, 3); - auto w_addr = GetDeviceAddress(inputs, 4); - auto hx_addr = GetDeviceAddress(inputs, 5); - auto cx_addr = GetDeviceAddress(inputs, 6); - auto reserved_addr = GetDeviceAddress(inputs, 7); - auto states_addr = GetDeviceAddress(inputs, 8); - auto dx_addr = GetDeviceAddress(outputs, 0); - auto dhx_addr = GetDeviceAddress(outputs, 1); - auto dcx_addr = GetDeviceAddress(outputs, 2); - void *workspace_addr = GetDeviceAddress(workspace, 0); - - if (!states_init_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRestoreDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, input_size_list_[8], 0), - "restore dropout state failed"); - states_init_ = true; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRNNBackwardData(handle_, rnn_desc_, seq_len_, y_desc_.get(), y_addr, dy_desc_.get(), dy_addr, dhy_desc_, - dhy_addr, dcy_desc_, dcy_addr, w_desc_, w_addr, hx_desc_, hx_addr, cx_desc_, cx_addr, - dx_desc_.get(), dx_addr, dhx_desc_, dhx_addr, dcx_desc_, dcx_addr, workspace_addr, - workspace_size_list_[0], reserved_addr, reserved_size_), - "launch lstm back data kernel failed"); - - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), - "stream synchronize failed."); - return true; - } - void GetAttrs(const CNodePtr &kernel_node) { - input_size_ = GetAttr(kernel_node, "input_size"); - hidden_size_ = GetAttr(kernel_node, "hidden_size"); - num_layers_ = GetAttr(kernel_node, "num_layers"); - has_bias_ = GetAttr(kernel_node, "has_bias"); - bidirectional_ = GetAttr(kernel_node, "bidirectional"); - dropout_ = GetAttr(kernel_node, "dropout"); - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - seq_len_ = SizeToInt(input_shape[0]); - batch_size_ = SizeToInt(input_shape[1]); - GetAttrs(kernel_node); - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t rnn_mode = CUDNN_LSTM; - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - CreateTensorDescGrp(); - int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dhy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dhy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dcy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dcy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set cx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dhx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dhx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dcx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dcx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), - "set dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, - input_mode, direction, rnn_mode, algo, cudnn_data_type_), - "set rnn_desc failed"); - cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); - auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); - size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, dx_desc_[0], &weight_size_, cudnn_data_type_), - "get weight_size_ failed"); - if (weight_size != weight_size_) { - MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; - } - int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), - "set w_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, dx_desc_.get(), &reserved_size_), "get size failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dhy_desc_), "create dhy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dcy_desc_), "create dcy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dhx_desc_), "create dhx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dcx_desc_), "create dcx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); - } - - void InitSizeLists() override { - size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); - - size_t h_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); - - input_size_list_.push_back(y_size); - input_size_list_.push_back(y_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(weight_size_); - input_size_list_.push_back(h_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(reserved_size_); - size_t state_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); - input_size_list_.push_back(state_size); - - size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); - output_size_list_.push_back(x_size); - output_size_list_.push_back(h_size); - output_size_list_.push_back(h_size); - - size_t workspace_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, dx_desc_.get(), &workspace_size), - "get workspace size failed"); - workspace_size_list_.push_back(workspace_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dcx_desc_), "destroy dcx_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dhx_desc_), "destroy dhx_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dcy_desc_), "destroy dcy_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dhy_desc_), "destroy dhy_desc_ failed"); - DestroyTensorDescGrp(); - } - void CreateTensorDescGrp() { - int x_dims[3]{batch_size_, input_size_, 1}; - int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1}; - - dx_desc_ = std::make_unique(seq_len_); - y_desc_ = std::make_unique(seq_len_); - dy_desc_ = std::make_unique(seq_len_); - - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_[i]), "create x_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dx_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), - "set dx_desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_[i]), "create dy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dy_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), - "set dy_desc_ failed"); - } - } - - void DestroyTensorDescGrp() { - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_[i]), "destroy dy_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_[i]), "destroy x_desc failed"); - } - } - - int batch_size_; - int seq_len_; - int input_size_; - int hidden_size_; - int num_layers_; - - bool has_bias_; - bool bidirectional_; - bool states_init_; - float dropout_; - - size_t weight_size_; - size_t reserved_size_; - - cudnnRNNDescriptor_t rnn_desc_; - - // input desc - std::unique_ptr y_desc_; - std::unique_ptr dy_desc_; - cudnnTensorDescriptor_t dhy_desc_; - cudnnTensorDescriptor_t dcy_desc_; - cudnnFilterDescriptor_t w_desc_; - cudnnTensorDescriptor_t hx_desc_; - cudnnTensorDescriptor_t cx_desc_; - - cudnnDropoutDescriptor_t dropout_desc_; - - // output desc - std::unique_ptr dx_desc_; - cudnnTensorDescriptor_t dhx_desc_; - cudnnTensorDescriptor_t dcx_desc_; - - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.cc deleted file mode 100644 index 856a986e07..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/lstm_grad_weight_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LSTMGradWeight, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LstmGradWeightGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LSTMGradWeight, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LstmGradWeightGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.h deleted file mode 100644 index a1a4852c84..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.h +++ /dev/null @@ -1,231 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -namespace mindspore { -namespace kernel { -template -class LstmGradWeightGpuKernel : public GpuKernel { - public: - LstmGradWeightGpuKernel() - : batch_size_(0), - seq_len_(0), - input_size_(0), - hidden_size_(0), - num_layers_(0), - has_bias_(false), - bidirectional_(false), - states_init_(false), - dropout_(0), - weight_size_(0), - reserved_size_(0), - rnn_desc_(nullptr), - dropout_desc_(nullptr), - x_desc_(nullptr), - hx_desc_(nullptr), - y_desc_(nullptr), - dw_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~LstmGradWeightGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(stream_ptr); - auto x_addr = GetDeviceAddress(inputs, 0); - auto hx_addr = GetDeviceAddress(inputs, 1); - auto y_addr = GetDeviceAddress(inputs, 2); - auto reserved_addr = GetDeviceAddress(inputs, 3); - auto states_addr = GetDeviceAddress(inputs, 4); - auto dw_addr = GetDeviceAddress(outputs, 0); - void *workspace_addr = GetDeviceAddress(workspace, 0); - - if (!states_init_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRestoreDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, input_size_list_[4], 0), - "restore dropout state failed"); - states_init_ = true; - } - - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemsetAsync(dw_addr, 0, outputs[0]->size, reinterpret_cast(stream_ptr)), "cudaMemSet Failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRNNBackwardWeights(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, y_desc_.get(), - y_addr, workspace_addr, workspace_size_list_[0], dw_desc_, dw_addr, reserved_addr, - reserved_size_), - "launch lstm back weight kernel failed"); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - seq_len_ = SizeToInt(input_shape[0]); - batch_size_ = SizeToInt(input_shape[1]); - - input_size_ = GetAttr(kernel_node, "input_size"); - hidden_size_ = GetAttr(kernel_node, "hidden_size"); - num_layers_ = GetAttr(kernel_node, "num_layers"); - has_bias_ = GetAttr(kernel_node, "has_bias"); - bidirectional_ = GetAttr(kernel_node, "bidirectional"); - dropout_ = GetAttr(kernel_node, "dropout"); - - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t rnn_mode = CUDNN_LSTM; - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - - CreateTensorDescGrp(); - int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), - "set dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, - input_mode, direction, rnn_mode, algo, cudnn_data_type_), - "set rnn_desc failed"); - cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); - - auto weight_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), - "get weight_size_ failed"); - if (weight_size != weight_size_) { - MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; - } - int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), - "set dw_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_), - "get reserve size failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&dw_desc_), "create dw_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); - } - void InitSizeLists() override { - size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); - - size_t h_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); - - size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); - input_size_list_.push_back(x_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(y_size); - input_size_list_.push_back(reserved_size_); - size_t state_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); - input_size_list_.push_back(state_size); - - output_size_list_.push_back(weight_size_); - - size_t workspace_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size), - "get workspace size failed"); - workspace_size_list_.push_back(workspace_size); - } - - private: - void CreateTensorDescGrp() { - int x_dims[3]{batch_size_, input_size_, 1}; - int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1}; - - x_desc_ = std::make_unique(seq_len_); - y_desc_ = std::make_unique(seq_len_); - - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), "set x_desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); - } - } - void DestroyTensorDescGrp() { - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed"); - } - } - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "destroy dw_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc_ failed"); - DestroyTensorDescGrp(); - } - - int batch_size_; - int seq_len_; - int input_size_; - int hidden_size_; - int num_layers_; - - bool has_bias_; - bool bidirectional_; - bool states_init_; - float dropout_; - - size_t weight_size_; - size_t reserved_size_; - - cudnnRNNDescriptor_t rnn_desc_; - cudnnDropoutDescriptor_t dropout_desc_; - - // input desc - std::unique_ptr x_desc_; - cudnnTensorDescriptor_t hx_desc_; - std::unique_ptr y_desc_; - - // output desc - cudnnFilterDescriptor_t dw_desc_; - - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc deleted file mode 100644 index e8b2b17706..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/momentum_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - MomentumGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - MomentumGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16), - MomentumGpuKernel, half, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h deleted file mode 100644 index 5abfb9e97b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/momentum_impl.cuh" -namespace mindspore { -namespace kernel { -template -class MomentumGpuKernel : public GpuKernel { - public: - MomentumGpuKernel() - : variable_size_(0), accumulation_size_(0), learning_rate_size_(0), gradient_size_(0), momentum_size_(0) {} - ~MomentumGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, - void *stream_ptr) override { - T *variable = GetDeviceAddress(inputs, 0); - T *accumulation = GetDeviceAddress(inputs, 1); - S *learning_rate = GetDeviceAddress(inputs, 2); - T *gradient = GetDeviceAddress(inputs, 3); - S *momentum = GetDeviceAddress(inputs, 4); - MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 5) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but momentum needs 5 inputs."; - return false; - } - - variable_size_ = sizeof(T); - accumulation_size_ = sizeof(T); - learning_rate_size_ = sizeof(S); - gradient_size_ = sizeof(T); - momentum_size_ = sizeof(S); - - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < variable_shape.size(); i++) { - variable_size_ *= variable_shape[i]; - } - auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < accumulation_shape.size(); i++) { - accumulation_size_ *= accumulation_shape[i]; - } - auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - for (size_t i = 0; i < gradient_shape.size(); i++) { - gradient_size_ *= gradient_shape[i]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(variable_size_); - input_size_list_.push_back(accumulation_size_); - input_size_list_.push_back(learning_rate_size_); - input_size_list_.push_back(gradient_size_); - input_size_list_.push_back(momentum_size_); - output_size_list_.push_back(0); - } - - private: - size_t variable_size_; - size_t accumulation_size_; - size_t learning_rate_size_; - size_t gradient_size_; - size_t momentum_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.cc deleted file mode 100644 index e871af360a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/pooling_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - PoolingGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - PoolingGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - PoolingGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - PoolingGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h deleted file mode 100644 index 0dda1e8998..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h +++ /dev/null @@ -1,252 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class PoolingGpuFwdKernel : public GpuKernel { - public: - PoolingGpuFwdKernel() - : cudnn_handle_(nullptr), - input_descriptor_(nullptr), - output_descriptor_(nullptr), - pooling_descriptor_(nullptr), - padded_descriptor_(nullptr), - pooling_mode_(CUDNN_POOLING_MAX), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - pad_value_(0), - is_null_input_(false), - input_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~PoolingGpuFwdKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (is_null_input_) { - return true; - } - T *input_addr = reinterpret_cast(inputs[0]->addr); - T *output_addr = reinterpret_cast(outputs[0]->addr); - const float alpha = 1; - const float beta = 0; - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded_addr = reinterpret_cast(workspace[0]->addr); - CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, - reinterpret_cast(stream_ptr)); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, padded_descriptor_, - padded_addr, &beta, output_descriptor_, output_addr), - "cudnnPoolingForward failed"); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, - input_addr, &beta, output_descriptor_, output_addr), - "cudnnPoolingForward failed"); - } - return true; - } - bool Init(const CNodePtr &kernel_node) { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "PoolingGpuFwdKernel input is null."; - InitSizeLists(); - return true; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), - SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), - SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - auto window = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize")); - int window_height = window[2]; - int window_width = window[3]; - stride_ = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides")); - SetPoolingMode(kernel_node); - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - SetPad(input_shape, window_height, window_width); - } else { - pad_height_ = 0; - pad_width_ = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, - window_width, pad_height_, pad_width_, stride_[2], stride_[3]), - "cudnnSetPooling2dDescriptor failed"); - } - - InitSizeLists(); - return true; - } - - protected: - void InitResource() { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), - "cudnnCreatePoolingDescriptor failed"); - } - void InitSizeLists() { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast(&input_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(output_descriptor_, reinterpret_cast(&output_size_)), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast(&padded_size_)), - "cudnnGetTensorSizeInBytes failed"); - workspace_size_list_.push_back(padded_size_); - if (padded_size_ == 0) { - MS_LOG(EXCEPTION) << "Padded size is 0."; - } - } - return; - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but pooling needs 1 inputs."; - return false; - } - return true; - } - void SetPad(const std::vector &input_shape, const int &window_height, const int &window_width) { - n_ = SizeToInt(input_shape[0]); - c_ = SizeToInt(input_shape[1]); - old_height_ = SizeToInt(input_shape[2]); - old_width_ = SizeToInt(input_shape[3]); - pad_height_ = - std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) - : (old_height_ / stride_[2]) + 1) - - 1) * - stride_[2] + - window_height - old_height_); - pad_width_ = - std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) - : (old_width_ / stride_[3]) + 1) - - 1) * - stride_[3] + - window_width - old_width_); - pad_top_ = pad_height_ / 2; - pad_left_ = pad_width_ / 2; - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, - c_, old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, - window_height, window_width, use_pad_ ? 0 : pad_top_, - use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]), - "cudnnSetPooling2dDescriptor failed"); - } - void SetPoolingMode(const CNodePtr &kernel_node) { - pad_mode_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); - mode_ = AnfAlgo::GetCNodeName(kernel_node); - if (mode_ == "AvgPool") { - pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - pad_value_ = 0.0; - } else { - pooling_mode_ = CUDNN_POOLING_MAX; - pad_value_ = kSignedMinFloat; - } - } - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), - "cudnnDestroyPoolingDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t input_descriptor_; - cudnnTensorDescriptor_t output_descriptor_; - cudnnPoolingDescriptor_t pooling_descriptor_; - cudnnTensorDescriptor_t padded_descriptor_; - cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; - std::vector stride_; - std::string mode_; - std::string pad_mode_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - cudnnDataType_t cudnn_data_type_; - - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - float pad_value_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.cc deleted file mode 100644 index c3d4a44943..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/pooling_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(MaxPoolGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - PoolingGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(MaxPoolGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - PoolingGradGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - PoolingGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - PoolingGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h deleted file mode 100644 index e8f1ebc1af..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h +++ /dev/null @@ -1,296 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class PoolingGradGpuKernel : public GpuKernel { - public: - PoolingGradGpuKernel() - : cudnn_handle_(nullptr), - pooling_descriptor_(nullptr), - y_descriptor_(nullptr), - dy_descriptor_(nullptr), - x_descriptor_(nullptr), - dx_descriptor_(nullptr), - padded_descriptor_(nullptr), - pooling_mode_(CUDNN_POOLING_MAX), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - pad_value_(0), - is_null_input_(false), - input_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~PoolingGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *x_data = GetDeviceAddress(inputs, 0); - T *y = GetDeviceAddress(inputs, 1); - T *dy = GetDeviceAddress(inputs, 2); - T *dx = GetDeviceAddress(outputs, 0); - - const float alpha = 1; - const float beta = 0; - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded = GetDeviceAddress(workspace, 0); - T *padded_dx = GetDeviceAddress(workspace, 1); - - CalPad(padded_size_ / sizeof(T), x_data, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, - reinterpret_cast(stream_ptr)); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, - padded_descriptor_, padded, &beta, padded_descriptor_, padded_dx), - "cudnnPoolingBackward failed"); - - CalPadGrad(output_size_ / sizeof(T), padded_dx, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, - x_descriptor_, x_data, &beta, dx_descriptor_, dx), - "cudnnPoolingBackward failed"); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - auto window = GetAttr>(kernel_node, "ksize"); - int window_height = window[2]; - int window_width = window[3]; - SetPoolingMode(kernel_node); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); - if (is_null_input_) { - MS_LOG(WARNING) << "PoolingGradGpuKernel input is null."; - InitSizeLists(); - return true; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(y_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_mask[0]), - SizeToInt(input_mask[1]), SizeToInt(input_mask[2]), SizeToInt(input_mask[3])), - "cudnnSetTensor4dDescriptor"); - - auto dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dy_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dout_shape[0]), - SizeToInt(dout_shape[1]), SizeToInt(dout_shape[2]), SizeToInt(dout_shape[3])), - "cudnnSetTensor4dDescriptor"); - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dx_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), - SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { - SetPad(input_shape, window_height, window_width); - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, - window_width, pad_height_, pad_width_, stride_[2], stride_[3]), - "cudnnSetPooling2dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), - SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), - "cudnnSetTensor4dDescriptor"); - } - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), - "cudnnCreatePoolingDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(y_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dx_descriptor_, &output_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), - "cudnnGetTensorSizeInBytes failed"); - if (padded_size_ == 0) { - MS_LOG(EXCEPTION) << "Padded size is 0."; - } - workspace_size_list_.push_back(padded_size_); - workspace_size_list_.push_back(padded_size_); - } - return; - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but PoolingGradGpuKernel needs 3 inputs."; - return false; - } - return true; - } - void SetPad(const std::vector &input_shape, const int &window_height, const int &window_width) { - n_ = SizeToInt(input_shape[0]); - c_ = SizeToInt(input_shape[1]); - old_height_ = SizeToInt(input_shape[2]); - old_width_ = SizeToInt(input_shape[3]); - pad_height_ = - std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) - : (old_height_ / stride_[2]) + 1) - - 1) * - stride_[2] + - window_height - old_height_); - pad_width_ = - std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) - : (old_width_ / stride_[3]) + 1) - - 1) * - stride_[3] + - window_width - old_width_); - pad_top_ = pad_height_ / 2; - pad_left_ = pad_width_ / 2; - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, - c_, old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), - SizeToInt(input_shape[1]), SizeToInt(input_shape[2]) + (use_pad_ ? pad_height_ : 0), - SizeToInt(input_shape[3]) + (use_pad_ ? pad_width_ : 0)), - "cudnnSetTensor4dDescriptor"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, - window_height, window_width, use_pad_ ? 0 : pad_top_, - use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]), - "cudnnSetPooling2dDescriptor failed"); - } - void SetPoolingMode(const CNodePtr &kernel_node) { - pad_mode_ = GetAttr(kernel_node, "padding"); - stride_ = GetAttr>(kernel_node, "strides"); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - mode_ = AnfAlgo::GetCNodeName(kernel_node); - if (mode_ == "AvgPoolGradGpu") { - pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - pad_value_ = 0.0; - } else { - pooling_mode_ = CUDNN_POOLING_MAX; - pad_value_ = kSignedMinFloat; - } - } - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), - "cudnnDestroyPoolingDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - - cudnnHandle_t cudnn_handle_; - cudnnPoolingDescriptor_t pooling_descriptor_; - cudnnTensorDescriptor_t y_descriptor_; - cudnnTensorDescriptor_t dy_descriptor_; - cudnnTensorDescriptor_t x_descriptor_; - cudnnTensorDescriptor_t dx_descriptor_; - cudnnTensorDescriptor_t padded_descriptor_; - cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; - std::vector stride_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - std::string mode_; - std::string pad_mode_; - cudnnDataType_t cudnn_data_type_; - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - float pad_value_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc deleted file mode 100644 index 032e8eeec4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/rmsprop_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(ApplyRMSProp, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - RMSPropGpuKernel, float) - -MS_REG_GPU_KERNEL_ONE(ApplyCenteredRMSProp, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - RMSPropGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h deleted file mode 100644 index 9e148b690d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h +++ /dev/null @@ -1,121 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/rmsprop_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class RMSPropGpuKernel : public GpuKernel { - public: - RMSPropGpuKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {} - ~RMSPropGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream) override { - if (!use_center_) { - T *variable = GetDeviceAddress(inputs, 0); - T *mean_square = GetDeviceAddress(inputs, 1); - T *moment = GetDeviceAddress(inputs, 2); - T *learning_rate = GetDeviceAddress(inputs, 3); - T *gradients = GetDeviceAddress(inputs, 4); - - RmsProp(learning_rate, decay_, momentum_, epsilon_, variable, mean_square, moment, gradients, size_, - reinterpret_cast(stream)); - } else { - T *variable = GetDeviceAddress(inputs, 0); - T *mean_gradients = GetDeviceAddress(inputs, 1); - T *mean_square = GetDeviceAddress(inputs, 2); - T *moment = GetDeviceAddress(inputs, 3); - T *gradients = GetDeviceAddress(inputs, 4); - T *learning_rate = GetDeviceAddress(inputs, 5); - T *decay = GetDeviceAddress(inputs, 6); - T *momentum = GetDeviceAddress(inputs, 7); - T *epsilon = GetDeviceAddress(inputs, 8); - - RmsPropCenter(learning_rate, decay, momentum, epsilon, variable, mean_gradients, mean_square, moment, gradients, - size_, reinterpret_cast(stream)); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - if (node_name == "ApplyCenteredRMSProp") { - use_center_ = true; - } - - if (node_name == "ApplyRMSProp") { - decay_ = GetAttr(kernel_node, "rho"); - momentum_ = GetAttr(kernel_node, "momentum"); - epsilon_ = GetAttr(kernel_node, "epsilon"); - } - auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - for (auto &dim : input_shape) { - size_ *= dim; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - size_t input_size = size_ * sizeof(T); - if (!use_center_) { - input_size_list_.push_back(input_size); - input_size_list_.push_back(input_size); - input_size_list_.push_back(input_size); - input_size_list_.push_back(sizeof(T)); - input_size_list_.push_back(input_size); - output_size_list_.push_back(input_size); - } else { - input_size_list_.push_back(input_size); - input_size_list_.push_back(input_size); - input_size_list_.push_back(input_size); - input_size_list_.push_back(input_size); - input_size_list_.push_back(input_size); - input_size_list_.push_back(sizeof(T)); - input_size_list_.push_back(sizeof(T)); - input_size_list_.push_back(sizeof(T)); - input_size_list_.push_back(sizeof(T)); - output_size_list_.push_back(input_size); - } - } - - private: - size_t size_; - bool use_center_; - float decay_; - float momentum_; - float epsilon_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc deleted file mode 100644 index 1e650811fd..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - SigmoidCrossEntropyWithLogits, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SigmoidCrossEntropyWithLogitsGpuKernel, float, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h deleted file mode 100644 index 8d0efe90b4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SigmoidCrossEntropyWithLogitsGpuKernel : public GpuKernel { - public: - SigmoidCrossEntropyWithLogitsGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} - - ~SigmoidCrossEntropyWithLogitsGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *logits_addr = GetDeviceAddress(inputs, 0); - S *labels_addr = GetDeviceAddress(inputs, 1); - T *outputs_addr = GetDeviceAddress(outputs, 0); - - SigmoidCrossEntropyWithLogits(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogits needs 2 inputs."; - return false; - } - logits_size_ = sizeof(T); - labels_size_ = sizeof(S); - outputs_size_ = sizeof(T); - - auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < logits_shape.size(); i++) { - logits_size_ *= logits_shape[i]; - } - - auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < labels_shape.size(); i++) { - labels_size_ *= labels_shape[i]; - } - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < output_shape.size(); i++) { - outputs_size_ *= output_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(logits_size_); - input_size_list_.push_back(labels_size_); - output_size_list_.push_back(outputs_size_); - } - - private: - size_t logits_size_; - size_t labels_size_; - size_t outputs_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc deleted file mode 100644 index dabc4df850..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h deleted file mode 100644 index 01f416f6b7..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { - public: - SigmoidCrossEntropyWithLogitsGradGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} - ~SigmoidCrossEntropyWithLogitsGradGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *logits_addr = GetDeviceAddress(inputs, 0); - S *labels_addr = GetDeviceAddress(inputs, 1); - T *outputs_addr = GetDeviceAddress(outputs, 0); - - SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogitsGrad needs 3 inputs."; - return false; - } - logits_size_ = sizeof(T); - labels_size_ = sizeof(S); - outputs_size_ = sizeof(T); - - auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < logits_shape.size(); i++) { - logits_size_ *= logits_shape[i]; - } - - auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < labels_shape.size(); i++) { - labels_size_ *= labels_shape[i]; - } - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < output_shape.size(); i++) { - outputs_size_ *= output_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(logits_size_); - input_size_list_.push_back(labels_size_); - output_size_list_.push_back(outputs_size_); - } - - private: - size_t logits_size_; - size_t labels_size_; - size_t outputs_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc deleted file mode 100644 index 160a26d200..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO(SoftmaxCrossEntropyWithLogits, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SoftmaxCrossEntropyWithLogitsGpuKernel, float, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h deleted file mode 100644 index 8256174bcb..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h +++ /dev/null @@ -1,205 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/cross_entropy_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { - public: - SoftmaxCrossEntropyWithLogitsGpuKernel() - : cudnn_handle_(nullptr), - logits_descriptor_(nullptr), - softmax_output_descriptor_(nullptr), - algo_(CUDNN_SOFTMAX_ACCURATE), - mode_(CUDNN_SOFTMAX_MODE_INSTANCE), - cudnn_data_type_(CUDNN_DATA_FLOAT), - is_null_input_(false), - logits_size_(0), - labels_size_(0), - output1_size_(0), - output2_size_(0), - softmax_output_logits_size_(0), - batch_size_(0), - channel_size_(0), - height_(0), - width_(0) {} - ~SoftmaxCrossEntropyWithLogitsGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *logits_addr = GetDeviceAddress(inputs, 0); - S *labels_addr = GetDeviceAddress(inputs, 1); - T *loss_addr = GetDeviceAddress(outputs, 0); - T *dlogits_addr = GetDeviceAddress(outputs, 1); - T *softmax_output_logits = GetDeviceAddress(workspace, 0); - - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, logits_descriptor_, logits_addr, &beta, - softmax_output_descriptor_, softmax_output_logits), - "cudnnSoftmaxForward failed."); - - CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num - << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 2) { - MS_LOG(ERROR) << "Output number is " << output_num - << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 output."; - return false; - } - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - - InferInputOutputSize(kernel_node); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - batch_size_, channel_size_, height_, width_), - "cudnnSetTensor4dDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(softmax_output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_size_, - channel_size_, height_, width_), - "cudnnSetTensor4dDescriptor failed."); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&logits_descriptor_), - "cudnnCreateTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&softmax_output_descriptor_), - "cudnnCreateTensorDescriptor failed."); - } - void InitSizeLists() override { - input_size_list_.push_back(logits_size_); - input_size_list_.push_back(labels_size_); - output_size_list_.push_back(output1_size_); - output_size_list_.push_back(output2_size_); - workspace_size_list_.push_back(softmax_output_logits_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(softmax_output_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(logits_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - } - void InferInputOutputSize(const CNodePtr &kernel_node) { - auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(logits_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input1 is null"; - InitSizeLists(); - return; - } - auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(logits_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input2 is null"; - InitSizeLists(); - return; - } - CheckShapeValidation(logits_shape, labels_shape); - - size_t logits_dims = logits_shape.size(); - batch_size_ = 1; - for (size_t i = 0; i < logits_dims - 1; i++) { - batch_size_ *= logits_shape[i]; - } - channel_size_ = logits_shape[logits_dims - 1]; - height_ = 1; - width_ = 1; - logits_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - - labels_size_ = 1; - size_t labels_dims = labels_shape.size(); - for (size_t i = 0; i < labels_dims; i++) { - labels_size_ *= labels_shape[i]; - } - labels_size_ *= sizeof(S); - - output1_size_ = logits_size_ / logits_shape[logits_dims - 1]; - output2_size_ = logits_size_; - softmax_output_logits_size_ = logits_size_; - return; - } - void CheckShapeValidation(const std::vector &logits_shape, const std::vector &labels_shape) { - size_t logits_dim_length = logits_shape.size(); - size_t labels_dim_length = labels_shape.size(); - if (labels_dim_length != logits_dim_length) { - MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length for " - "SoftmaxCrossEntropyWithLogits, but got Labels " - "shape length:" - << labels_dim_length << ", Logits shape length:" << logits_dim_length; - } - if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) { - MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last demension."; - } - return; - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t logits_descriptor_; - cudnnTensorDescriptor_t softmax_output_descriptor_; - cudnnSoftmaxAlgorithm_t algo_; - cudnnSoftmaxMode_t mode_; - cudnnDataType_t cudnn_data_type_; - bool is_null_input_; - - size_t logits_size_; - size_t labels_size_; - size_t output1_size_; - size_t output2_size_; - size_t softmax_output_logits_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t batch_size_; - size_t channel_size_; - size_t height_; - size_t width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.cc deleted file mode 100644 index b9667ed85b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/softmax_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SoftmaxGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SoftmaxGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SoftmaxGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SoftmaxGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.h deleted file mode 100644 index 9d5a2a24e1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.h +++ /dev/null @@ -1,252 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/transpose_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SoftmaxGpuKernel : public GpuKernel { - public: - SoftmaxGpuKernel() - : cudnn_handle_(nullptr), - input_descriptor_(nullptr), - output_descriptor_(nullptr), - algo_(CUDNN_SOFTMAX_ACCURATE), - mode_(CUDNN_SOFTMAX_MODE_INSTANCE), - cudnn_data_type_(CUDNN_DATA_FLOAT), - is_null_input_(false), - input_size_(0), - output_size_(0), - workspace_size_(0), - axis_(0), - shape_size_(0), - batch_size_(0), - channel_size_(0), - height_(0), - width_(0) {} - ~SoftmaxGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - const float alpha = 1; - const float beta = 0; - - if (axis_ == 1) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, input_descriptor_, - input_addr, &beta, output_descriptor_, output_addr), - "cudnnSoftmaxForward failed"); - } else { - T *transpose_input_addr = GetDeviceAddress(workspace, 0); - T *transpose_output_addr = GetDeviceAddress(workspace, 1); - int *input_shape = GetDeviceAddress(workspace, 2); - int *transpose_shape = GetDeviceAddress(workspace, 3); - int *transpose_axis = GetDeviceAddress(workspace, 4); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis failed"); - int size = SizeToInt(input_size_ / sizeof(T)); - CalTranspose(size, input_addr, input_shape, transpose_axis, shape_size_, transpose_input_addr, - reinterpret_cast(stream_ptr)); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, input_descriptor_, transpose_input_addr, &beta, - output_descriptor_, transpose_output_addr), - "cudnnSoftmaxForward failed"); - CalTranspose(size, transpose_output_addr, transpose_shape, transpose_axis, shape_size_, output_addr, - reinterpret_cast(stream_ptr)); - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SoftmaxGpuKernel input is null"; - InitSizeLists(); - return true; - } - shape_size_ = SizeToInt(input_shape.size()); - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "LogSoftmax") { - algo_ = CUDNN_SOFTMAX_LOG; - auto axis = GetAttr(kernel_node, "axis"); - InitSizeByAxis(input_shape, axis); - } else { - algo_ = CUDNN_SOFTMAX_ACCURATE; - auto axis = GetAttr>(kernel_node, "axis"); - InitSizeByAxis(input_shape, axis[0]); - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), - SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), - "set input_descriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), - SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), - "set output_descriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "create input_descriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "create output_descriptor failed"); - } - - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(input_size_); - workspace_size_list_.push_back(output_size_); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "destroy output_descriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "destroy input_descriptor failed"); - } - - void InitSizeByAxis(const std::vector &input_shape, const int &axis) { - if (input_shape.size() == 2) { - InitSizeByAxis2D(input_shape, axis); - } else { - InitSizeByAxisLastDim(input_shape, axis); - } - } - - void InitSizeByAxis2D(const std::vector &input_shape, const int &axis) { - axis_ = axis; - if (axis_ < 0) { - axis_ += shape_size_; - } - if (axis_ == 1) { - batch_size_ = input_shape[0]; - channel_size_ = input_shape[1]; - } else if (axis_ == 0) { - batch_size_ = input_shape[1]; - channel_size_ = input_shape[0]; - input_shape_.push_back(input_shape[0]); - input_shape_.push_back(input_shape[1]); - transpose_shape_.push_back(input_shape[1]); - transpose_shape_.push_back(input_shape[0]); - transpose_axis_.push_back(1); - transpose_axis_.push_back(0); - } else { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; - } - - height_ = 1; - width_ = 1; - input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - output_size_ = input_size_; - workspace_size_ = IntToSize(shape_size_) * sizeof(int); - } - - void InitSizeByAxisLastDim(const std::vector &input_shape, const int &axis) { - int axis_pos = axis; - if (axis_pos < 0) { - axis_pos += input_shape.size(); - } - // axis should be -1 with ND - if (axis_pos != SizeToInt(input_shape.size() - 1)) { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; - } - // squeeze to 2d, then invoke cudnn - size_t n = 1; - for (size_t i = 0; i < input_shape.size() - 1; i++) { - n *= input_shape[i]; - } - axis_ = 1; - batch_size_ = n; - channel_size_ = input_shape[axis_pos]; - height_ = 1; - width_ = 1; - input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - output_size_ = input_size_; - input_shape_.push_back(batch_size_); - input_shape_.push_back(channel_size_); - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t input_descriptor_; - cudnnTensorDescriptor_t output_descriptor_; - cudnnSoftmaxAlgorithm_t algo_; - cudnnSoftmaxMode_t mode_; - cudnnDataType_t cudnn_data_type_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - std::vector input_shape_; - std::vector transpose_shape_; - std::vector transpose_axis_; - int axis_; - int shape_size_; - - size_t batch_size_; - size_t channel_size_; - size_t height_; - size_t width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.cc deleted file mode 100644 index 5b07136522..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/softmax_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - LogSoftmaxGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SoftmaxGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - LogSoftmaxGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SoftmaxGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.h deleted file mode 100644 index d73503d5a5..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.h +++ /dev/null @@ -1,219 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/transpose_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SoftmaxGradGpuKernel : public GpuKernel { - public: - SoftmaxGradGpuKernel() - : cudnn_handle_(nullptr), - y_desc_(nullptr), - algo_(CUDNN_SOFTMAX_ACCURATE), - mode_(CUDNN_SOFTMAX_MODE_INSTANCE), - cudnn_data_type_(CUDNN_DATA_FLOAT), - is_null_input_(false), - input_size_(0), - output_size_(0), - workspace_size_(0), - axis_(0), - shape_size_(0), - batch_size_(0), - channel_size_(0), - height_(0), - width_(0) {} - ~SoftmaxGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *y_addr = GetDeviceAddress(inputs, 0); - T *dy_addr = GetDeviceAddress(inputs, 1); - T *dx_addr = GetDeviceAddress(outputs, 0); - - T *transpose_y_addr = GetDeviceAddress(workspace, 0); - T *transpose_dy_addr = GetDeviceAddress(workspace, 1); - T *transpose_dx_addr = GetDeviceAddress(workspace, 2); - int *input_shape = GetDeviceAddress(workspace, 3); - int *transpose_shape = GetDeviceAddress(workspace, 4); - int *transpose_axis = GetDeviceAddress(workspace, 5); - const float alpha = 1; - const float beta = 0; - - if (axis_ == 1) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr, y_desc_, - dy_addr, &beta, y_desc_, dx_addr), - "cudnnSoftmaxBackward failed"); - } else { - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis failed"); - int size = SizeToInt(input_size_ / sizeof(T)); - CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr, - reinterpret_cast(stream_ptr)); - CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr, - reinterpret_cast(stream_ptr)); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, transpose_y_addr, - y_desc_, transpose_dy_addr, &beta, y_desc_, transpose_dx_addr), - "cudnnSoftmaxBackward failed"); - CalTranspose(size, transpose_dx_addr, transpose_shape, transpose_axis, shape_size_, dx_addr, - reinterpret_cast(stream_ptr)); - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax grad needs 2 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax grad needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SoftmaxGradGpuKernel input is null"; - InitSizeLists(); - return true; - } - shape_size_ = SizeToInt(input_shape.size()); - if (shape_size_ != 2) { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs."; - } - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "LogSoftmaxGrad") { - algo_ = CUDNN_SOFTMAX_LOG; - auto axis = GetAttr(kernel_node, "axis"); - InitSizeByAxis(input_shape, axis); - } else { - algo_ = CUDNN_SOFTMAX_ACCURATE; - auto axis = GetAttr>(kernel_node, "axis"); - InitSizeByAxis(input_shape, axis[0]); - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), - SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), - "set input_descriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "create input_descriptor failed"); - } - - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(input_size_); - workspace_size_list_.push_back(input_size_); - workspace_size_list_.push_back(output_size_); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "destroy output_descriptor failed"); - } - - void InitSizeByAxis(const std::vector input_shape, const int axis) { - axis_ = axis; - if (axis_ < 0) { - axis_ += shape_size_; - } - if (axis_ == 1) { - batch_size_ = input_shape[0]; - channel_size_ = input_shape[1]; - } else if (axis_ == 0) { - batch_size_ = input_shape[1]; - channel_size_ = input_shape[0]; - input_shape_.push_back(input_shape[0]); - input_shape_.push_back(input_shape[1]); - transpose_shape_.push_back(input_shape[1]); - transpose_shape_.push_back(input_shape[0]); - transpose_axis_.push_back(1); - transpose_axis_.push_back(0); - } else { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; - } - - height_ = 1; - width_ = 1; - input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - output_size_ = input_size_; - workspace_size_ = IntToSize(shape_size_) * sizeof(int); - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t y_desc_; - cudnnSoftmaxAlgorithm_t algo_; - cudnnSoftmaxMode_t mode_; - cudnnDataType_t cudnn_data_type_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - std::vector input_shape_; - std::vector transpose_shape_; - std::vector transpose_axis_; - int axis_; - int shape_size_; - - size_t batch_size_; - size_t channel_size_; - size_t height_; - size_t width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc deleted file mode 100644 index 537eeb5726..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - SparseSoftmaxCrossEntropyWithLogits, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - SparseSoftmaxCrossEntropyWithLogitsGpuKernel, float, int) -MS_REG_GPU_KERNEL_TWO( - SparseSoftmaxCrossEntropyWithLogits, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - SparseSoftmaxCrossEntropyWithLogitsGpuKernel, float, int64_t) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h deleted file mode 100644 index 6950f0e308..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/cross_entropy_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class SparseSoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { - public: - SparseSoftmaxCrossEntropyWithLogitsGpuKernel() - : cudnn_handle_(nullptr), - logits_descriptor_(nullptr), - softmax_output_descriptor_(nullptr), - algo_(CUDNN_SOFTMAX_ACCURATE), - mode_(CUDNN_SOFTMAX_MODE_INSTANCE), - cudnn_data_type_(CUDNN_DATA_FLOAT), - is_grad_(false), - is_null_input_(false), - logits_size_(0), - labels_size_(0), - output_size_(0), - softmax_output_logits_size_(0), - batch_size_(0), - channel_size_(0), - height_(0), - width_(0) {} - ~SparseSoftmaxCrossEntropyWithLogitsGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *logits_addr = GetDeviceAddress(inputs, 0); - S *labels_addr = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - T *softmax_output_logits = GetDeviceAddress(workspace, 0); - - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, logits_descriptor_, logits_addr, &beta, - softmax_output_descriptor_, softmax_output_logits), - "cudnnSoftmaxForward failed."); - - is_grad_ ? CrossEntropyGradWithSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output_addr, - reinterpret_cast(stream_ptr)) - : CrossEntropyWithSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output_addr, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num - << ", but SparseSoftmaxCrossEntropyWithLogitsGpuKernel needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num - << ", but SparseSoftmaxCrossEntropyWithLogitsGpuKernel needs 1 output."; - return false; - } - is_grad_ = GetAttr(kernel_node, "is_grad"); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - - InferInputOutputSize(kernel_node); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - batch_size_, channel_size_, height_, width_), - "cudnnSetTensor4dDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(softmax_output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_size_, - channel_size_, height_, width_), - "cudnnSetTensor4dDescriptor failed."); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&logits_descriptor_), - "cudnnCreateTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&softmax_output_descriptor_), - "cudnnCreateTensorDescriptor failed."); - } - void InitSizeLists() override { - input_size_list_.push_back(logits_size_); - input_size_list_.push_back(labels_size_); - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(softmax_output_logits_size_); - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(softmax_output_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(logits_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - } - void InferInputOutputSize(const CNodePtr &kernel_node) { - auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(logits_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input1 is null"; - InitSizeLists(); - return; - } - auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(logits_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input2 is null"; - InitSizeLists(); - return; - } - CheckShapeValidation(logits_shape, labels_shape); - - size_t logits_dims = logits_shape.size(); - batch_size_ = 1; - for (size_t i = 0; i < logits_dims - 1; i++) { - batch_size_ *= logits_shape[i]; - } - channel_size_ = logits_shape[logits_dims - 1]; - height_ = 1; - width_ = 1; - logits_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - - labels_size_ = 1; - size_t labels_dims = labels_shape.size(); - for (size_t i = 0; i < labels_dims; i++) { - labels_size_ *= labels_shape[i]; - } - labels_size_ *= sizeof(S); - - output_size_ = is_grad_ ? logits_size_ : sizeof(T); - softmax_output_logits_size_ = logits_size_; - return; - } - void CheckShapeValidation(const std::vector &logits_shape, const std::vector &labels_shape) { - size_t logits_dim_length = logits_shape.size(); - size_t labels_dim_length = labels_shape.size(); - if (labels_dim_length != logits_dim_length - 1) { - MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length minus 1 for " - "SparseSoftmaxCrossEntropyWithLogits, " - "but got Labels shape length:" - << labels_dim_length << ", Logits shape length:" << logits_dim_length; - } - if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) { - MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last demension."; - } - return; - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t logits_descriptor_; - cudnnTensorDescriptor_t softmax_output_descriptor_; - cudnnSoftmaxAlgorithm_t algo_; - cudnnSoftmaxMode_t mode_; - cudnnDataType_t cudnn_data_type_; - bool is_grad_; - bool is_null_input_; - - size_t logits_size_; - size_t labels_size_; - size_t output_size_; - size_t softmax_output_logits_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t batch_size_; - size_t channel_size_; - size_t height_; - size_t width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc deleted file mode 100644 index 0f3e0c95f4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/other/assign_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Assign, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AssignGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - Assign, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - AssignGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - AssignGpuKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h deleted file mode 100644 index b41d583a43..0000000000 --- a/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -class AssignGpuKernel : public GpuKernel { - public: - AssignGpuKernel() : input_size_(0) {} - ~AssignGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *var = GetDeviceAddress(inputs, 0); - T *value = GetDeviceAddress(inputs, 1); - T *output = GetDeviceAddress(outputs, 0); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(var, value, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), - "cudaMemxcpyAsync failed."); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(output, value, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), - "cudaMemxcpyAsync failed."); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_size_ = sizeof(T); - for (size_t x : shape) { - input_size_ = input_size_ * x; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but AssignGpuKernel needs 1 output."; - return false; - } - return true; - } - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc deleted file mode 100644 index af95767407..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BatchNormFold2, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - BatchNormFold2GpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h deleted file mode 100644 index b898f34689..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class BatchNormFold2GpuKernel : public GpuKernel { - public: - BatchNormFold2GpuKernel() - : cudnn_handle_(nullptr), - is_null_input_(false), - batch_size_(0), - channel_(0), - height_(0), - width_(0), - freeze_bn_(0) {} - - ~BatchNormFold2GpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - - auto *input = GetDeviceAddress(inputs, 0); - auto *beta = GetDeviceAddress(inputs, 1); - auto *gamma = GetDeviceAddress(inputs, 2); - auto *batch_std = GetDeviceAddress(inputs, 3); - auto *batch_mean = GetDeviceAddress(inputs, 4); - auto *running_std = GetDeviceAddress(inputs, 5); - auto *running_mean = GetDeviceAddress(inputs, 6); - auto *global_step = GetDeviceAddress(inputs, 7); - auto *output = GetDeviceAddress(outputs, 0); - - BatchNormFold2Forward(input, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, output, - freeze_bn_, batch_size_, channel_, height_, width_, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 8) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GpuKernel needs 8."; - return false; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "BatchNormFold2GpuKernel input is null"; - InitSizeLists(); - return true; - } - - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "BatchNormFold2GpuKernel input shape needs (N,C,H,W)."; - return false; - } - batch_size_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - - void InitSizeLists() override { - size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - size_t weight_size = channel_ * sizeof(T); - input_size_list_.push_back(input_size); - input_size_list_.push_back(weight_size); // beta - input_size_list_.push_back(weight_size); // gamma - input_size_list_.push_back(weight_size); // batch_std - input_size_list_.push_back(weight_size); // batch_mean - input_size_list_.push_back(weight_size); // running_std - input_size_list_.push_back(weight_size); // running_mean - input_size_list_.push_back(sizeof(int32_t)); // global_step - output_size_list_.push_back(input_size); - } - - private: - void DestroyResource() noexcept {} - - cudnnHandle_t cudnn_handle_; - bool is_null_input_; - size_t batch_size_; - size_t channel_; - size_t height_; - size_t width_; - size_t freeze_bn_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc deleted file mode 100644 index 93862aeedd..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BatchNormFold2Grad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - BatchNormFold2GradGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h deleted file mode 100644 index e0bafdb96a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h +++ /dev/null @@ -1,168 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class BatchNormFold2GradGpuKernel : public GpuKernel { - public: - BatchNormFold2GradGpuKernel() - : cudnn_handle_(nullptr), - is_null_input_(false), - batch_size_(0), - channel_(0), - height_(0), - width_(0), - freeze_bn_(0) {} - - ~BatchNormFold2GradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - - auto *dout = GetDeviceAddress(inputs, 0); - auto *x = GetDeviceAddress(inputs, 1); - auto *gamma = GetDeviceAddress(inputs, 2); - auto *batch_std = GetDeviceAddress(inputs, 3); - auto *batch_mean = GetDeviceAddress(inputs, 4); - auto *running_std = GetDeviceAddress(inputs, 5); - auto *running_mean = GetDeviceAddress(inputs, 6); - auto *global_step = GetDeviceAddress(inputs, 7); - auto *d_batch_std = GetDeviceAddress(outputs, 0); - auto *d_batch_mean = GetDeviceAddress(outputs, 1); - auto *d_beta = GetDeviceAddress(outputs, 2); - auto *d_gamma = GetDeviceAddress(outputs, 3); - auto *d_x = GetDeviceAddress(outputs, 4); - auto *tmp = GetDeviceAddress(workspace, 0); - auto *tmp2 = GetDeviceAddress(workspace, 1); - auto *reduce_x = GetDeviceAddress(workspace, 2); - auto *tmp_x = GetDeviceAddress(workspace, 3); - - int32_t current_step_host[1]; - size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost, - reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - CHECK_CUDA_RET_WITH_ERROR( - cudaMemcpyAsync(d_x, dout, x_size, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - - BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_, - reinterpret_cast(stream_ptr)); - if (current_step_host[0] < freeze_bn_) { - CalBatchNormFold2GradNotFreezeDxMul(batch_std, running_std, d_x, batch_size_, channel_, height_, width_, - reinterpret_cast(stream_ptr)); - CalBatchNormFold2GradNotFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, - d_batch_mean, d_batch_std, channel_, reinterpret_cast(stream_ptr)); - } else { - CalBatchNormFold2GradFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, - d_batch_mean, d_batch_std, channel_, reinterpret_cast(stream_ptr)); - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 8) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GradGpuKernel needs 8."; - return false; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "BatchNormFold2GradGpuKernel input is null"; - InitSizeLists(); - return true; - } - - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "BatchNormFold2GradGpuKernel input shape needs (N,C,H,W)."; - return false; - } - batch_size_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - - void InitSizeLists() override { - size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - size_t weight_size = channel_ * sizeof(T); - size_t workspace_size = batch_size_ * channel_ * sizeof(T); - input_size_list_.push_back(input_size); // dout - input_size_list_.push_back(input_size); // x - input_size_list_.push_back(weight_size); // gamma - input_size_list_.push_back(weight_size); // batch_std - input_size_list_.push_back(weight_size); // batch_mean - input_size_list_.push_back(weight_size); // running_std - input_size_list_.push_back(weight_size); // running_mean - input_size_list_.push_back(sizeof(int32_t)); // global_step - - output_size_list_.push_back(weight_size); // d_batch_std - output_size_list_.push_back(weight_size); // d_batch_mean - output_size_list_.push_back(weight_size); // d_beta - output_size_list_.push_back(weight_size); // d_gamma - output_size_list_.push_back(input_size); // d_x - - workspace_size_list_.push_back(workspace_size); // tmp - workspace_size_list_.push_back(workspace_size); // tmp2 - workspace_size_list_.push_back(weight_size); // reduce_x - workspace_size_list_.push_back(input_size); // tmp_x - } - - private: - void DestroyResource() noexcept {} - - cudnnHandle_t cudnn_handle_; - bool is_null_input_; - size_t batch_size_; - size_t channel_; - size_t height_; - size_t width_; - int32_t freeze_bn_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc deleted file mode 100644 index 4f968a0fa3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/batchnorm_fold_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BatchNormFold, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - BatchNormFoldGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h deleted file mode 100644 index 6cd001fd2e..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h +++ /dev/null @@ -1,209 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class BatchNormFoldGpuKernel : public GpuKernel { - public: - BatchNormFoldGpuKernel() - : input_size_(0), - output_size_(0), - exp_avg_factor_(0.9), - epsilon_(1e-12), - is_training_(true), - freeze_bn_(0), - batch_(0), - channel_(0), - height_(0), - width_(0), - mode_(CUDNN_BATCHNORM_SPATIAL), - x_desc_(nullptr), - scale_bias_mean_var_desc_(nullptr), - handle_(nullptr) {} - - ~BatchNormFoldGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - (void)workspace; - auto x = GetDeviceAddress(inputs, 0); - auto mean = GetDeviceAddress(inputs, 1); - auto variance = GetDeviceAddress(inputs, 2); - int *current_step = GetDeviceAddress(inputs, 3); - int current_step_host[1]; - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost, - reinterpret_cast(stream_ptr)), - "Copy gpu memoy failed."); - if (x == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null."; - return false; - } - if (mean == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGpuKernel mean is null."; - return false; - } - if (variance == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGpuKernel variance is null."; - return false; - } - if (current_step == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null."; - return false; - } - auto batch_mean = GetDeviceAddress(outputs, 0); - auto batch_std = GetDeviceAddress(outputs, 1); - auto running_mean = GetDeviceAddress(outputs, 2); - auto running_std = GetDeviceAddress(outputs, 3); - auto y = GetDeviceAddress(workspace, 0); - - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_std, variance, output_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast(stream_ptr)); - if (!is_training_ || current_step_host[0] >= freeze_bn_) { - CHECK_CUDA_RET_WITH_ERROR(cudaMemset(batch_mean, 0, output_size_), "Failed to set gpu memory."); - ThrustFillWith(batch_std, channel_, 1.f, reinterpret_cast(stream_ptr)); - return true; - } - const T alpha = 1; - const T beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardTraining( - handle_, mode_, &alpha, &beta, x_desc_, x, x_desc_, y, scale_bias_mean_var_desc_, - mean, mean, exp_avg_factor_, mean, variance, epsilon_, batch_mean, batch_std), - "Failed to launch kernel.") - CalUpdateBatchStd(channel_, batch_std, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 4) { - MS_LOG(ERROR) << "Input number is " << input_num << " but BatchNormFold GpuKernel OP needs 4 input."; - return false; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 4) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFold GpuKernel OP needs 4 output."; - return false; - } - - T momentum = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("momentum")); - exp_avg_factor_ = 1.0 - momentum; - epsilon_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon")); - is_training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training")); - freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "Input shape is " << input_shape.size() - << ", but BatchNormFold GpuKernel OP needs 4DTensor input."; - return false; - } - batch_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - - input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; - output_size_ = sizeof(T) * channel_; - - cudnnDataType_t cudnnDataType = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, batch_, channel_, height_, width_), - "Set x desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, 1, channel_, 1, 1), - "Set para desc failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - // x, mean, variance, current_step - input_size_list_.push_back(input_size_); - input_size_list_.push_back(output_size_); - input_size_list_.push_back(output_size_); - input_size_list_.push_back(sizeof(int)); - - // batch_mean, batch_std, running_mean, running_std - output_size_list_.push_back(output_size_); - output_size_list_.push_back(output_size_); - output_size_list_.push_back(output_size_); - output_size_list_.push_back(output_size_); - - // store y - workspace_size_list_.push_back(input_size_); - } - - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed"); - } - - size_t input_size_; - size_t output_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - double exp_avg_factor_; - double epsilon_; - bool is_training_; - int freeze_bn_; - int batch_; - int channel_; - int height_; - int width_; - - cudnnBatchNormMode_t mode_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t scale_bias_mean_var_desc_; - - cudnnHandle_t handle_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc deleted file mode 100644 index 93ea66258d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BatchNormFoldGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - BatchNormFoldGradGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h deleted file mode 100644 index 7a3ed7ef91..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h +++ /dev/null @@ -1,166 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class BatchNormFoldGradGpuKernel : public GpuKernel { - public: - BatchNormFoldGradGpuKernel() - : input_size_(0), - channel_size_(0), - workspace_size_(0), - momentum_(0.1), - epsilon_(1e-12), - is_training_(true), - freeze_bn_(0), - current_step_(0), - batch_(0), - channel_(0), - height_(0), - width_(0) {} - ~BatchNormFoldGradGpuKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' - T *d_batch_mean = GetDeviceAddress(inputs, 0); - T *d_batch_std = GetDeviceAddress(inputs, 1); - T *x = GetDeviceAddress(inputs, 2); - T *batch_mean = GetDeviceAddress(inputs, 3); - T *batch_std = GetDeviceAddress(inputs, 4); - int *current_step = GetDeviceAddress(inputs, 5); - int current_step_host[1]; - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost, - reinterpret_cast(stream_ptr)), - "Copy gpu memoy failed."); - if (d_batch_mean == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null."; - return false; - } - if (d_batch_std == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_std is null."; - return false; - } - if (x == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel x is null."; - return false; - } - if (batch_mean == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_mean is null."; - return false; - } - if (batch_std == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_std is null."; - return false; - } - if (current_step == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null."; - return false; - } - T *dx = GetDeviceAddress(outputs, 0); - - if (!is_training_ || current_step_host[0] >= freeze_bn_) { - ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast(stream_ptr)); - return true; - } - CalBatchNormFoldGrad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_, channel_, height_, width_, dx, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 6) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input."; - return false; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFoldGrad GpuKernel OP needs 4 output."; - return false; - } - - epsilon_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon")); - is_training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training")); - freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "Input shape is " << input_shape.size() - << ", but BatchNormFoldGrad GpuKernel OP needs 4DTensor input."; - return false; - } - batch_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - - input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; - channel_size_ = sizeof(T) * channel_; - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' - input_size_list_.push_back(channel_size_); - input_size_list_.push_back(channel_size_); - input_size_list_.push_back(input_size_); - input_size_list_.push_back(channel_size_); - input_size_list_.push_back(channel_size_); - input_size_list_.push_back(sizeof(int)); - // 'dx' - output_size_list_.push_back(input_size_); - } - - private: - size_t input_size_; - size_t channel_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - T momentum_; - T epsilon_; - bool is_training_; - int freeze_bn_; - int current_step_; - int batch_; - int channel_; - int height_; - int width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.cc deleted file mode 100644 index a914b6ec14..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020、 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/correction_mul_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(CorrectionMul, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - CorrectionMulGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h deleted file mode 100644 index 29aeabb03a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class CorrectionMulGpuKernel : public GpuKernel { - public: - CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} - ~CorrectionMulGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - auto *weight = GetDeviceAddress(inputs, 0); - auto *gamma = GetDeviceAddress(inputs, 1); - auto *running_std = GetDeviceAddress(inputs, 2); - auto *output = GetDeviceAddress(outputs, 0); - - CalCorrectionMul(weight, gamma, running_std, batch_size_, channel_, height_, width_, output, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGpuKernel needs 3."; - return false; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "CorrectionMulGpuKernel input shape needs (N,C,H,W)."; - return false; - } - batch_size_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - size_t weight_size = batch_size_ * sizeof(T); - input_size_list_.push_back(input_size); // weight - input_size_list_.push_back(weight_size); // gamma - input_size_list_.push_back(weight_size); // running_std - output_size_list_.push_back(input_size); - } - - void InitResource() override {} - - private: - void DestroyResource() noexcept {} - - size_t batch_size_; - size_t channel_; - size_t height_; - size_t width_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc deleted file mode 100644 index 28b5d56e68..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/correction_mul_grad_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(CorrectionMulGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - CorrectionMulGradGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h deleted file mode 100644 index 3feffa586b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class CorrectionMulGradGpuKernel : public GpuKernel { - public: - CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} - ~CorrectionMulGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - auto *d_out = GetDeviceAddress(inputs, 0); - auto *weight = GetDeviceAddress(inputs, 1); - auto *gamma = GetDeviceAddress(inputs, 2); - auto *running_std = GetDeviceAddress(inputs, 3); - auto *d_weight = GetDeviceAddress(outputs, 0); - auto *d_gamma = GetDeviceAddress(outputs, 1); - auto *tmp = GetDeviceAddress(workspace, 0); - - CalCorrectionMul(d_out, gamma, running_std, batch_size_, channel_, height_, width_, d_weight, - reinterpret_cast(stream_ptr)); - CalCorrectionMulGrad(d_out, weight, running_std, batch_size_, channel_, height_, width_, d_gamma, tmp, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 4) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGradGpuKernel needs 4."; - return false; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "CorrectionMulGradGpuKernel input shape needs (N,C,H,W)."; - return false; - } - batch_size_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - size_t weight_size = batch_size_ * sizeof(T); - input_size_list_.push_back(input_size); // d_out - input_size_list_.push_back(input_size); // weight - input_size_list_.push_back(weight_size); // gamma - input_size_list_.push_back(weight_size); // running_std - output_size_list_.push_back(input_size); // d_weight - output_size_list_.push_back(weight_size); // d_gamma - workspace_size_list_.push_back(input_size); // tmp d_out * weight - } - void InitResource() override {} - - private: - void DestroyResource() noexcept {} - - size_t batch_size_; - size_t channel_; - size_t height_; - size_t width_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc deleted file mode 100644 index 8db6ddd848..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh" -#include -#include -#include -#include - -namespace mindspore { -namespace kernel { -FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel() - : input_size_(0), - num_channels_(0), - num_bits_(0), - training_(false), - symmetric_(false), - narrow_range_(false), - quant_delay_(0), - quant_min_(0), - quant_max_(0), - global_step_(0) {} - -const std::vector &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &FakeQuantPerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &FakeQuantPerChannelGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 input."; - return false; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << " but FakeQuant GpuKernel OP needs 1 output."; - return false; - } - - // get attribute - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); - - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16."; - return false; - } - - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0."; - return false; - } - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - - // shape info for gpu - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - num_channels_ = SizeToInt(input_shape[0]); - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void FakeQuantPerChannelGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // input in tensor - input_size_list_.push_back(sizeof(float) * num_channels_); // min one scalar - input_size_list_.push_back(sizeof(float) * num_channels_); // max on scalar - output_size_list_.push_back(input_size_); // output in tensor - workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel -} - -void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, - float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) { - CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, - symmetric_, reinterpret_cast(stream_ptr)); - CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); -} - -bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - (void)workspace; - float *output = GetDeviceAddress(outputs, 0); - float *input = GetDeviceAddress(inputs, 0); - float *input_min = GetDeviceAddress(inputs, 1); - float *input_max = GetDeviceAddress(inputs, 2); - float *scale = GetDeviceAddress(workspace, 0); - float *nudge_min = GetDeviceAddress(workspace, 1); - float *nudge_max = GetDeviceAddress(workspace, 2); - - if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null."; - } - if (input_min == nullptr || input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null."; - } - - if (training_) { - if (global_step_ >= quant_delay_) { - CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); - } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed."); - } - global_step_++; - } else { - CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); - } - - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h deleted file mode 100755 index 122fe96af3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class FakeQuantPerChannelGpuKernel : public GpuKernel { - public: - FakeQuantPerChannelGpuKernel(); - ~FakeQuantPerChannelGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel) override; - - protected: - void InitSizeLists() override; - - private: - void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min, - float *nudge_max, float *scale, void *stream_ptr); - - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int num_channels_; - int num_bits_; - bool training_; - bool symmetric_; - bool narrow_range_; - int quant_delay_; - float quant_min_; - float quant_max_; - int global_step_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc deleted file mode 100644 index 5c774c05ed..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh" - -namespace mindspore { -namespace kernel { -FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel() - : input_size_(0), - num_bits_(0), - quant_min_(0), - quant_max_(0), - num_channels_(0), - quant_delay_(0), - global_step_(0), - narrow_range_(false), - symmetric_(false) {} - -const std::vector &FakeQuantPerChannelGradGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &FakeQuantPerChannelGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &FakeQuantPerChannelGradGpuKernel::GetWorkspaceSizeList() const { - return workspace_size_list_; -} - -bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 4) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output."; - } - - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; - } - - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; - } - - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - num_channels_ = SizeToInt(input_shape[0]); - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void FakeQuantPerChannelGradGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // gradient - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(sizeof(float) * num_channels_); // min - input_size_list_.push_back(sizeof(float) * num_channels_); // max - output_size_list_.push_back(input_size_); // output - workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel -} - -bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - (void)workspace; - float *output = GetDeviceAddress(outputs, 0); - float *gradient = GetDeviceAddress(inputs, 0); - float *input = GetDeviceAddress(inputs, 1); - float *input_min = GetDeviceAddress(inputs, 2); - float *input_max = GetDeviceAddress(inputs, 3); - float *scale = GetDeviceAddress(workspace, 0); - float *nudge_min = GetDeviceAddress(workspace, 1); - float *nudge_max = GetDeviceAddress(workspace, 2); - - if (gradient == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null"; - } - if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input is null"; - } - if (input_min == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input min is null"; - } - if (input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input max is null"; - } - - int total_size = input_size_ / sizeof(float); - if (global_step_ >= quant_delay_) { - CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, - symmetric_, reinterpret_cast(stream_ptr)); - CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, - reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed."); - } - global_step_++; - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h deleted file mode 100644 index d863a2c99f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class FakeQuantPerChannelGradGpuKernel : public GpuKernel { - public: - FakeQuantPerChannelGradGpuKernel(); - ~FakeQuantPerChannelGradGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel_node) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int num_bits_; - float quant_min_; - float quant_max_; - int num_channels_; - int quant_delay_; - int global_step_; - bool narrow_range_; - bool symmetric_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc deleted file mode 100644 index 44869983eb..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc +++ /dev/null @@ -1,143 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh" -#include -#include -#include -#include - -namespace mindspore { -namespace kernel { -FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel() - : input_size_(0), - quant_min_(0), - quant_max_(0), - quant_num_(1), - global_step_(0), - num_bits_(0), - quant_delay_(0), - training_(false), - narrow_range_(false), - symmetric_(false) {} - -const std::vector &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; - } - - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); - training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; - } - - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0."; - } - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - - // init size - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= SizeToInt(input_shape[i]); - } - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void FakeQuantPerLayerGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // x - input_size_list_.push_back(sizeof(float)); // min - input_size_list_.push_back(sizeof(float)); // max - output_size_list_.push_back(input_size_); // y - workspace_size_list_.push_back(sizeof(float)); // scale - workspace_size_list_.push_back(sizeof(float)); // nudge_min - workspace_size_list_.push_back(sizeof(float)); // nudge_max -} - -bool FakeQuantPerLayerGpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - float *output = GetDeviceAddress(outputs, 0); - float *input = GetDeviceAddress(inputs, 0); - float *input_min = GetDeviceAddress(inputs, 1); - float *input_max = GetDeviceAddress(inputs, 2); - float *scale = GetDeviceAddress(workspace, 0); - float *nudge_min = GetDeviceAddress(workspace, 1); - float *nudge_max = GetDeviceAddress(workspace, 2); - - if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null."; - } - if (input_min == nullptr || input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null."; - } - - if (training_) { - // control flow for quant_delay - if (global_step_ >= quant_delay_) { - // real launch - CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, - reinterpret_cast(stream_ptr)); - CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - } else { - // real launch - CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, - reinterpret_cast(stream_ptr)); - CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - } - - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h deleted file mode 100755 index 38810e06df..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class FakeQuantPerLayerGpuKernel : public GpuKernel { - public: - FakeQuantPerLayerGpuKernel(); - ~FakeQuantPerLayerGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - float quant_min_; - float quant_max_; - int quant_num_; - int global_step_; - int num_bits_; - int quant_delay_; - bool training_; - bool narrow_range_; - bool symmetric_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc deleted file mode 100644 index c8d57b2bb1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh" - -namespace mindspore { -namespace kernel { -FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel() - : input_size_(0), - workspace_size_(0), - num_bits_(0), - quant_min_(0), - quant_max_(0), - quant_num_(1), - quant_delay_(0), - global_step_(0), - narrow_range_(false), - symmetric_(false) {} - -const std::vector &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 4) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output."; - } - - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; - } - - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; - } - - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - - // init size - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= SizeToInt(input_shape[i]); - } - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void FakeQuantPerLayerGradGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // gradient - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(sizeof(float)); // min - input_size_list_.push_back(sizeof(float)); // max - output_size_list_.push_back(input_size_); // output - workspace_size_list_.push_back(sizeof(float)); // scale - workspace_size_list_.push_back(sizeof(float)); // nudge_min - workspace_size_list_.push_back(sizeof(float)); // nudge_max -} - -bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - float *output = GetDeviceAddress(outputs, 0); - float *gradient = GetDeviceAddress(inputs, 0); - float *input = GetDeviceAddress(inputs, 1); - float *input_min = GetDeviceAddress(inputs, 2); - float *input_max = GetDeviceAddress(inputs, 3); - float *scale = GetDeviceAddress(workspace, 0); - float *nudge_min = GetDeviceAddress(workspace, 1); - float *nudge_max = GetDeviceAddress(workspace, 2); - - if (gradient == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null"; - } - if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null."; - } - if (input_min == nullptr || input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null."; - } - - if (global_step_ >= quant_delay_) { - CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, - reinterpret_cast(stream_ptr)); - CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, - reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h deleted file mode 100644 index ae2ea5bfac..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class FakeQuantPerLayerGradGpuKernel : public GpuKernel { - public: - FakeQuantPerLayerGradGpuKernel(); - ~FakeQuantPerLayerGradGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel_node) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int num_bits_; - float quant_min_; - float quant_max_; - int quant_num_; - int quant_delay_; - int global_step_; - bool narrow_range_; - bool symmetric_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc deleted file mode 100644 index a8ce72148b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh" -#include -#include -#include -#include - -namespace mindspore { -namespace kernel { -MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel() - : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {} - -const std::vector &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const { - return workspace_size_list_; -} - -bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 2) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; - } - - ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); - ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); - - // init size - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - num_channels_ = SizeToInt(input_shape[0]); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= SizeToInt(input_shape[i]); - } - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(sizeof(float) * num_channels_); // min - input_size_list_.push_back(sizeof(float) * num_channels_); // max - output_size_list_.push_back(sizeof(float) * num_channels_); // output min - output_size_list_.push_back(sizeof(float) * num_channels_); // output max -} - -bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - float *output_min = GetDeviceAddress(outputs, 0); - float *output_max = GetDeviceAddress(outputs, 1); - float *input = GetDeviceAddress(inputs, 0); - float *input_min = GetDeviceAddress(inputs, 1); - float *input_max = GetDeviceAddress(inputs, 2); - - if (input == nullptr) { - MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null."; - } - if (input_min == nullptr || input_max == nullptr) { - MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null."; - } - - // calculate the input min and max according by the parameter ema and ema_decay. - CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_, - ema_decay_, ema_, reinterpret_cast(stream_ptr)); - return true; -} - -MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h deleted file mode 100644 index 563a583ca1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class MinMaxUpdatePerChannelGpuKernel : public GpuKernel { - public: - MinMaxUpdatePerChannelGpuKernel(); - ~MinMaxUpdatePerChannelGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int quant_num_; - bool ema_; - float ema_decay_; - int num_channels_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc deleted file mode 100644 index 3659665b23..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh" -#include -#include -#include -#include - -namespace mindspore { -namespace kernel { -MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel() - : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {} - -const std::vector &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 2) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; - } - - ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); - ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); - - // init size - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= SizeToInt(input_shape[i]); - } - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(sizeof(float)); // input min - input_size_list_.push_back(sizeof(float)); // input max - output_size_list_.push_back(sizeof(float)); // output min - output_size_list_.push_back(sizeof(float)); // output max -} - -bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - float *output_min = GetDeviceAddress(outputs, 0); - float *output_max = GetDeviceAddress(outputs, 1); - float *input = GetDeviceAddress(inputs, 0); - float *input_min = GetDeviceAddress(inputs, 1); - float *input_max = GetDeviceAddress(inputs, 2); - - if (input == nullptr) { - MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null."; - } - if (input_min == nullptr || input_max == nullptr) { - MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null."; - } - - CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, - reinterpret_cast(stream_ptr)); - - return true; -} - -MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h deleted file mode 100644 index a237b6dc26..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class MinMaxUpdatePerLayerGpuKernel : public GpuKernel { - public: - MinMaxUpdatePerLayerGpuKernel(); - ~MinMaxUpdatePerLayerGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int quant_num_; - bool ema_; - float ema_decay_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc deleted file mode 100644 index 87fb8d743d..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc +++ /dev/null @@ -1,157 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hccl_kernel.h" -#include "device/ascend/tasksink/runtime_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" - -using HcclTaskInfoPtr = std::shared_ptr; -using ge::model_runner::HcclTaskInfo; -using mindspore::device::ascend::tasksink::RuntimeUtils; - -namespace mindspore { -namespace kernel { -void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { - hcclKernelMap_.emplace(name, std::move(fun)); -} - -std::shared_ptr HcclKernelFactory::Get(const std::string &name) { - const auto &map = Get().hcclKernelMap_; - auto it = map.find(name); - if (it != map.end() && it->second) { - return (it->second)(); - } - return nullptr; -} - -HcclKernelFactory &HcclKernelFactory::Get() { - static HcclKernelFactory _this; - return _this; -} - -HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REP_OP_SUM), root_id_(0), anf_node_(nullptr) {} - -HcclKernel::~HcclKernel() { - hccl_kernel_input_shape_list_.clear(); - hccl_kernel_output_shape_list_.clear(); - hccl_data_type_list_.clear(); - hccl_count_ = 0; - op_type_ = HCCL_REP_OP_SUM; - root_id_ = 0; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - anf_node_ = nullptr; -} - -bool HcclKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - op_name_ = AnfAlgo::GetCNodeName(anf_node); - - if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) { - MS_LOG(ERROR) << "GetKernelInputShape fail!"; - return false; - } - if (!HcomUtil::GetKernelOutputShape(anf_node, &hccl_kernel_output_shape_list_)) { - MS_LOG(ERROR) << "GetKernelOutputShape fail!"; - return false; - } - if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) { - MS_LOG(ERROR) << "GetHcomDataType fail!"; - return false; - } - if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) { - MS_LOG(ERROR) << "GetHcomCount fail!"; - return false; - } - if (op_name_ == kAllReduce || op_name_ == kReduceScatter) { - if (!HcomUtil::GetHcomOperationType(anf_node, &op_type_)) { - MS_LOG(ERROR) << "GetHcomOperationType fail!"; - return false; - } - } - if (op_name_ == kBroadcast) { - if (!HcomUtil::GetHcomRootId(anf_node, &root_id_)) { - MS_LOG(ERROR) << "GetHcomRootId fail!"; - return false; - } - } - HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_)); - anf_node_ = anf_node; - return true; -} - -const std::vector &HcclKernel::GetInputSizeList() const { - size_t size = 0; - if (!input_size_list_.empty()) { - return input_size_list_; - } - for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) { - if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_input_shape_list_[i], &size)) { - MS_LOG(ERROR) << "GetHcclOpInputSize failed"; - } - input_size_list_.push_back(size); - } - return input_size_list_; -} - -const std::vector &HcclKernel::GetOutputSizeList() const { - size_t size = 0; - if (!output_size_list_.empty()) { - return output_size_list_; - } - for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) { - if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) { - MS_LOG(ERROR) << "GetHcclOpOutputSize failed"; - } - output_size_list_.push_back(size); - } - return output_size_list_; -} - -const std::vector &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -std::vector HcclKernel::GenTask(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "Inputs or outputs is empty"; - } - stream_id_ = stream_id; - std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_); - MS_EXCEPTION_IF_NULL(inputs.at(0)); - auto input_data_addr = inputs.at(0)->addr; - MS_EXCEPTION_IF_NULL(outputs.at(0)); - auto output_data_addr = outputs.at(0)->addr; - void *workspace_address = nullptr; - const int64_t workspace_num = 0; - std::vector private_def; - hcclDataType_t data_type = hccl_data_type_list_[0]; - - MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_ - << ", root_id=" << root_id_ << ", op_type=" << static_cast(op_type_) - << ", data_type=" << static_cast(data_type); - - HcclTaskInfoPtr task_info_ptr = std::make_shared( - stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, private_def, nullptr, - hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel, - RuntimeUtils::HcomDistribute); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel.h b/mindspore/ccsrc/kernel/hccl/hccl_kernel.h deleted file mode 100644 index 72e202591f..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel.h +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_ - -#include -#include -#include -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/hccl/hcom_util.h" -#include "hccl/hcom.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -class HcclKernel : public AscendKernelMod { - public: - HcclKernel(); - ~HcclKernel() override; - virtual bool Init(const AnfNodePtr &anf_node); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - protected: - std::vector> hccl_kernel_input_shape_list_; - std::vector> hccl_kernel_output_shape_list_; - std::vector hccl_data_type_list_; - std::vector hccl_format_list_; - uint64_t hccl_count_; - hcclRedOp_t op_type_; - uint32_t root_id_; - mutable std::vector input_size_list_; - mutable std::vector output_size_list_; - mutable std::vector workspace_size_list_; - AnfNodePtr anf_node_; - std::string op_name_; - std::string group_; -}; - -using HcclKernelCreater = std::function()>; - -class HcclKernelFactory { - HcclKernelFactory() = default; - ~HcclKernelFactory() = default; - - public: - static HcclKernelFactory &Get(); - void Registe(const string &name, HcclKernelCreater &&fun); - static std::shared_ptr Get(const string &name); - - private: - std::map hcclKernelMap_; -}; - -class _HcclKernelRegister { - public: - _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) { - HcclKernelFactory::Get().Registe(name, std::move(fun)); - } - ~_HcclKernelRegister() = default; -}; - -#define _MS_HCCL_REG_KERNEL_REG(KNAME, clazz) \ - static_assert(std::is_base_of::value, " must be base of HcclKernel"); \ - static const _HcclKernelRegister g_##KNAME##_##_kernel_reg(#KNAME, []() { \ - std::shared_ptr ptr = nullptr; \ - ptr = std::make_shared(); \ - MS_EXCEPTION_IF_NULL(ptr); \ - return ptr; \ - }); - -#define MS_HCCL_REG_KERNEL(KNAME, clazz) _MS_HCCL_REG_KERNEL_REG(KNAME, clazz) -} // namespace kernel -} // namespace mindspore -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.cc deleted file mode 100644 index d6e4aa09b9..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.cc +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hccl_kernel_build.h" - -#include -#include -#include - -#include "kernel/hccl/hccl_kernel.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -KernelModPtr HcclOpBuild(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string opname = AnfAlgo::GetCNodeName(anf_node); - MS_LOG(INFO) << "Hccl op [" << opname << "]"; - auto kerPtr = HcclKernelFactory::Get(opname); - if (kerPtr == nullptr) { - MS_LOG(ERROR) << "Hccl can't find Kernel[" << opname << "]"; - return nullptr; - } - if (!kerPtr->Init(anf_node)) { - MS_LOG(ERROR) << "Kernel initialize failed!"; - return nullptr; - } - return kerPtr; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.h b/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.h deleted file mode 100644 index f20760a3eb..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_ - -#include -#include -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -KernelModPtr HcclOpBuild(const AnfNodePtr &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc deleted file mode 100755 index 601d5cf1ea..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hccl_kernel_metadata.h" -#include -#include "utils/utils.h" -#include "kernel/hccl/hcom_util.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { - const std::vector kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, - kNumberTypeFloat32, kNumberTypeInt16}; - MS_EXCEPTION_IF_NULL(kernel_info_list); - MS_EXCEPTION_IF_NULL(kernel_node); - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter) { - MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]"; - return; - } - for (const auto &type : kHcclSupportTypes) { - std::vector inputs_format{}; - std::vector inputs_type{}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); - inputs_type.push_back(type); - } - std::vector outputs_format; - std::vector outputs_type; - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); - outputs_type.push_back(type); - } - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - builder.SetInputsFormat(inputs_format); - builder.SetInputsDeviceType(inputs_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetKernelType(HCCL_KERNEL); - kernel_info_list->push_back(builder.Build()); - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.h b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.h deleted file mode 100755 index b13393d3bd..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ -#include -#include -#include -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.cc b/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.cc deleted file mode 100644 index 9dbe708ef9..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_all_broadcast.h" - -#include -#include -#include - -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -bool HcomAllBroadCastKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void *stream_ptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { - return true; - } - if (inputs.empty() || hccl_data_type_list_.empty()) { - MS_LOG(ERROR) << "BroadCast param is empty"; - return false; - } - const char *tag = "Hccl-BroadCast"; - MS_EXCEPTION_IF_NULL(inputs[0]); - hcclResult_t ret = - hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast(ret); - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.h b/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.h deleted file mode 100644 index ca8eba91af..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_ - -#include -#include -#include "hccl/hcom.h" -#include "kernel/hccl/hccl_kernel.h" - -namespace mindspore { -namespace kernel { -class HcomAllBroadCastKernel : public HcclKernel { - public: - HcomAllBroadCastKernel() = default; - ~HcomAllBroadCastKernel() override = default; - - /* Inherit from kernelmod */ - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - private: -}; -MS_HCCL_REG_KERNEL(Broadcast, HcomAllBroadCastKernel); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_gather.cc b/mindspore/ccsrc/kernel/hccl/hcom_all_gather.cc deleted file mode 100644 index 6494f7fd12..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_gather.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_all_gather.h" - -#include -#include -#include - -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -bool HcomAllGatherKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs, void *stream_ptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { - return true; - } - if (inputs.empty() || hccl_data_type_list_.empty()) { - MS_LOG(ERROR) << "AllGather param is empty"; - return false; - } - const char *tag = "Hccl-AllGather"; - hcclResult_t ret = - hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast(ret); - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_gather.h b/mindspore/ccsrc/kernel/hccl/hcom_all_gather.h deleted file mode 100644 index 5de2c513cf..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_gather.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_ - -#include -#include -#include "hccl/hcom.h" -#include "kernel/hccl/hccl_kernel.h" - -namespace mindspore { -namespace kernel { -class HcomAllGatherKernel : public HcclKernel { - public: - HcomAllGatherKernel() = default; - ~HcomAllGatherKernel() override = default; - - /* Inherit from kernelmod */ - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - private: -}; -MS_HCCL_REG_KERNEL(AllGather, HcomAllGatherKernel); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.cc deleted file mode 100644 index 35a058e766..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_all_reduce.h" - -#include -#include -#include - -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -bool HcomAllReduceKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs, void *stream_ptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { - return true; - } - if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { - MS_LOG(ERROR) << "AllReduce param is empty"; - return false; - } - const char *tag = "Hccl-AllReduce"; - hcclResult_t ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], - op_type_, nullptr, stream_ptr); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast(ret); - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.h b/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.h deleted file mode 100644 index 939abd9de7..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_ - -#include -#include -#include "kernel/hccl/hccl_kernel.h" - -namespace mindspore { -namespace kernel { -class HcomAllReduceKernel : public HcclKernel { - public: - HcomAllReduceKernel() = default; - ~HcomAllReduceKernel() override = default; - - /* Inherit from kernelmod */ - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - private: -}; - -MS_HCCL_REG_KERNEL(AllReduce, HcomAllReduceKernel); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.cc b/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.cc deleted file mode 100644 index dea516885d..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_all_reduce_scatter.h" - -#include -#include -#include - -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -bool HcomAllReduceScatterKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs, void *stream_ptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { - return true; - } - if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { - MS_LOG(ERROR) << "ReduceScatter param is empty"; - return false; - } - const char *tag = "Hccl-ReduceScatter"; - hcclResult_t ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], - op_type_, nullptr, stream_ptr); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast(ret); - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.h b/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.h deleted file mode 100644 index c734b517c6..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_ - -#include -#include -#include "hccl/hcom.h" -#include "kernel/hccl/hccl_kernel.h" - -namespace mindspore { -namespace kernel { -class HcomAllReduceScatterKernel : public HcclKernel { - public: - HcomAllReduceScatterKernel() = default; - ~HcomAllReduceScatterKernel() override = default; - - /* Inherit from kernelmod */ - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - private: -}; - -MS_HCCL_REG_KERNEL(ReduceScatter, HcomAllReduceScatterKernel); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.cc b/mindspore/ccsrc/kernel/hccl/hcom_util.cc deleted file mode 100644 index 088dbe59d5..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_util.cc +++ /dev/null @@ -1,198 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_util.h" - -#include - -#include "kernel/common_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" - -namespace mindspore { -bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_intput_shape_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { - std::vector shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i); - hccl_kernel_intput_shape_list->emplace_back(shape_i); - } - - return true; -} - -bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector> *hccl_kernel_output_shape_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list); - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) { - std::vector shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); - hccl_kernel_output_shape_list->emplace_back(shape_i); - } - - return true; -} - -bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector *data_type_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(data_type_list); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { - auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i); - auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr); - if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) { - MS_LOG(EXCEPTION) << "HcomDataType cann't support Current Ascend Data Type : " << type_ptr; - } - data_type_list->emplace_back(iter->second); - } - auto type_base = *(std::begin(*data_type_list)); - if (std::any_of(data_type_list->begin(), data_type_list->end(), - [&type_base](hcclDataType_t type) { return type != type_base; })) { - MS_LOG(ERROR) << "hccl have different data type"; - return false; - } - return true; -} - -bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector &shape, size_t *size) { - MS_EXCEPTION_IF_NULL(size); - size_t tmp_size = 1; - uint32_t type_size = 4; - for (size_t i = 0; i < shape.size(); i++) { - tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]); - } - - if (!GetHcomTypeSize(data_type, &type_size)) { - return false; - } - - *size = SizetMulWithOverflowCheck(tmp_size, type_size); - - MS_LOG(INFO) << "size[" << *size << "]"; - return true; -} - -bool HcomUtil::GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size) { - MS_EXCEPTION_IF_NULL(size); - auto iter = CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.find(data_type); - if (iter == CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.end()) { - MS_LOG(ERROR) << "HcomUtil::HcomDataTypeSize, No DataTypeSize!"; - return false; - } - *size = iter->second; - return true; -} - -bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector &data_type_list, - const vector> &shape_list, uint64_t *total_count) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(total_count); - const uint32_t align_size = 512; - const uint32_t filled_size = 32; - uint64_t total_size = 0; - uint64_t block_size; - size_t input_size; - uint32_t type_size = 4; - - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { - if (!GetHcomTypeSize(data_type_list[i], &type_size)) { - return false; - } - - if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) { - MS_LOG(ERROR) << "Get GetHcclOpSize failed"; - return false; - } - - if (AnfAlgo::GetCNodeName(anf_node) == kReduceScatterOpName) { - int32_t rank_size; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (primitive->GetAttr("rank_size") != nullptr) { - rank_size = GetValue(primitive->GetAttr("rank_size")); - } else { - MS_LOG(ERROR) << "Get rank size failed"; - return false; - } - block_size = input_size / IntToSize(rank_size); - total_size = total_size + block_size; - } else { - if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) { - block_size = input_size; - } else { - block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; - } - total_size = total_size + block_size; - } - } - - if (type_size == 0 || total_size % type_size != 0) { - MS_LOG(ERROR) << "Total_size[" << total_size << "],Type_size[" << type_size << "] != 0, fail!"; - return false; - } - *total_count = total_size / type_size; - return true; -} - -bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_type); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (primitive->GetAttr("op") == nullptr) { - MS_LOG(ERROR) << "Get HCOM_ATTR_REDUCE_TYPE fail, not support!"; - return false; - } - auto hcom_op_type_get = GetValue(primitive->GetAttr("op")); - string hcom_op_type(hcom_op_type_get); - if (hcom_op_type == "min") { - *op_type = HCCL_REP_OP_MIN; - } else if (hcom_op_type == "max") { - *op_type = HCCL_REP_OP_MAX; - } else if (hcom_op_type == "prod") { - *op_type = HCCL_REP_OP_PROD; - } else if (hcom_op_type == "sum") { - *op_type = HCCL_REP_OP_SUM; - } else { - MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!"; - return false; - } - return true; -} - -bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(root_id); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (primitive->GetAttr("root_rank") != nullptr) { - *root_id = (uint32_t)GetValue(primitive->GetAttr("root_rank")); - } else { - MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!"; - return false; - } - return true; -} - -void HcomUtil::GetHcomGroup(NotNull anf_node, NotNull group) { - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - auto attr = primitive->GetAttr("group"); - if (attr != nullptr) { - *group = GetValue(attr); - } else { - MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed"; - } -} -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kash/kernel_pack.cc b/mindspore/ccsrc/kernel/kash/kernel_pack.cc deleted file mode 100644 index a87441031b..0000000000 --- a/mindspore/ccsrc/kernel/kash/kernel_pack.cc +++ /dev/null @@ -1,249 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "mindspore/ccsrc/kernel/kernel.h" -#include "kernel/kernel.h" -#include "kernel/akg/akg_kernel_build.h" -#include "nlohmann/json.hpp" -#include "securec/include/securec.h" -#include "pipeline/parse/python_adapter.h" -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" -namespace mindspore { -namespace kernel { -constexpr auto kUtilsModule = "mindspore._extends.utils"; -constexpr auto kCalSha256Func = "cal_sha256"; - -namespace { -bool CheckHash(const std::string &json_file, const std::string &bin_file, const nlohmann::json &js) { - if (js.find("sha256") == js.end()) { - MS_LOG(ERROR) << "No sha256 found in " << json_file; - return false; - } - std::string sha256_str = js["sha256"]; - py::object ret = parse::python_adapter::CallPyFn(kUtilsModule, kCalSha256Func, bin_file); - std::string sha256_cal = py::cast(ret); - if (sha256_cal.empty()) { - MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed."; - return false; - } - if (sha256_cal != sha256_str) { - MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed."; - return false; - } - return true; -} -} // namespace - -const std::string KernelPack::Serialize() const { - MS_EXCEPTION_IF_NULL(json_); - MS_EXCEPTION_IF_NULL(kernel_); - std::string buffer; - (void)buffer.append((const char *)json_, json_->len + sizeof(json_->len)); - (void)buffer.append((const char *)kernel_, kernel_->len + sizeof(kernel_->len)); - return buffer; -} - -bool KernelPack::ReadFromJsonFileHelper(std::ifstream &kernelbin) { - size_t binsize = LongToSize(kernelbin.seekg(0, std::ios::end).tellg()); - // free old data - if (kernel_ != nullptr) { - delete[] kernel_; - kernel_ = nullptr; - } - - void *ptr = static_cast(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]); - if (ptr != nullptr) { - kernel_ = static_cast(ptr); - } - if (kernel_ == nullptr) { - MS_LOG(ERROR) << "memory malloc failed."; - kernelbin.close(); - return false; - } - if (memset_s(kernel_, sizeof(KernelPack) + binsize, 0, sizeof(KernelPack) + binsize) != EOK) { - MS_LOG(ERROR) << "memset kernel_ failed."; - delete[] kernel_; - kernel_ = nullptr; - kernelbin.close(); - return false; - } - kernel_->len = binsize; - MS_LOG(INFO) << "kernel len:" << kernel_->len; - (void)kernelbin.seekg(0, std::ios::beg); - (void)kernelbin.read(kernel_->contents, SizeToLong(kernel_->len)); - return true; -} - -bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &processor) { - if (json_f.length() <= strlen(kJsonSuffix)) { - MS_LOG(ERROR) << "please check json path."; - return false; - } - - std::ifstream kerneljson(json_f); - if (!kerneljson.is_open()) { - MS_LOG(DEBUG) << "read json file error, please check kernelmeta."; - return false; - } - nlohmann::json js; - kerneljson >> js; - - size_t binsize = LongToSize(kerneljson.seekg(0, std::ios::end).tellg()); - void *ptr = static_cast(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]); - if (ptr != nullptr) { - json_ = static_cast(ptr); - } - if (json_ == nullptr) { - MS_LOG(ERROR) << "memory malloc failed."; - kerneljson.close(); - return false; - } - json_->len = binsize; - (void)kerneljson.seekg(0, std::ios::beg); - (void)kerneljson.read(json_->contents, SizeToLong(json_->len)); - - if (processor == kProcessorCuda) { - std::string bin_f = json_f.substr(0, json_f.length() - 5) + ".ptx"; - std::ifstream kernelbin(bin_f); - if (!kernelbin.is_open()) { - MS_LOG(ERROR) << "read kernel ptx file error, please check kernelmeta."; - kerneljson.close(); - return false; - } - - if (ReadFromJsonFileHelper(kernelbin) == false) { - delete[] json_; - json_ = nullptr; - kerneljson.close(); - return false; - } - kerneljson.close(); - if (!CheckHash(json_f, bin_f, js)) { - return false; - } - return true; - } - - std::string binfilesuffix = js["binFileSuffix"]; - std::string bin_f = json_f.substr(0, json_f.length() - 5) + binfilesuffix; - if (binfilesuffix.compare(".so") == 0) { - // change "xx/xx.so" -> "xx/libxx.so" - auto sp = bin_f.rfind('/'); - if (sp == std::string::npos) { - MS_LOG(ERROR) << "illegal bin file path " << bin_f; - kerneljson.close(); - return false; - } - bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1); - } - - std::ifstream kernelbin(bin_f, std::ios::binary); - if (!kernelbin.is_open()) { - MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta."; - kerneljson.close(); - delete[] json_; - json_ = nullptr; - return false; - } - - MS_LOG(INFO) << "kernelbin_name:" << bin_f; - if (ReadFromJsonFileHelper(kernelbin) == false) { - delete[] json_; - json_ = nullptr; - kerneljson.close(); - return false; - } - kerneljson.close(); - - if (!CheckHash(json_f, bin_f, js)) { - return false; - } - - return true; -} - -void KernelPack::ParseKernelJson(const nlohmann::json &js) { - kernel_json_info_.bin_file_name = js["binFileName"]; - kernel_json_info_.bin_file_suffix = js["binFileSuffix"]; - kernel_json_info_.block_dim = js["blockDim"]; - kernel_json_info_.kernel_name = js["kernelName"]; - kernel_json_info_.magic = js["magic"]; - if (js.find("parameters") != js.end()) { - if (!js.at("parameters").is_array()) { - MS_LOG(DEBUG) << "Format error!,parameters should be array."; - } - std::vector sizes = js.at("parameters"); - for (auto size : sizes) { - MS_LOG(INFO) << "parameter " << size; - kernel_json_info_.parameters.push_back(size); - } - } - if (js.find("workspace") != js.end()) { - auto workspace = js.at("workspace"); - std::vector sizes = workspace.at("size"); - for (auto size : sizes) { - MS_LOG(INFO) << "workspace_size_list " << size; - kernel_json_info_.workspaces.push_back(size); - } - } - kernel_json_info_.sha256 = js["sha256"]; -} - -bool KernelPack::LoadKernelMeta(const std::string &json_f, const std::string &processor) { - if (json_f.length() <= strlen(kJsonSuffix)) { - MS_LOG(ERROR) << "please check json path."; - return false; - } - std::ifstream kernel_json(json_f); - if (!kernel_json.is_open()) { - MS_LOG(DEBUG) << "read json file error, please check kernelmeta."; - return false; - } - nlohmann::json js; - kernel_json >> js; - ParseKernelJson(js); - kernel_json.close(); - - std::string bin_f = json_f.substr(0, json_f.length() - 5) + kernel_json_info_.bin_file_suffix; - if (kernel_json_info_.bin_file_suffix == ".so") { - // change "xx/xx.so" -> "xx/libxx.so" - auto sp = bin_f.rfind('/'); - if (sp == std::string::npos) { - MS_LOG(ERROR) << "illegal bin file path " << bin_f; - return false; - } - bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1); - } - - std::ifstream kernelbin(bin_f, std::ios::binary); - if (!kernelbin.is_open()) { - MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta."; - return false; - } - - MS_LOG(INFO) << "kernelbin_name:" << bin_f; - if (!ReadFromJsonFileHelper(kernelbin)) { - return false; - } - - return CheckHash(json_f, bin_f, js); -} - -KernelJsonInfo KernelPack::kernel_json_info() const { return kernel_json_info_; } -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel.h b/mindspore/ccsrc/kernel/kernel.h deleted file mode 100644 index 7bccce49c3..0000000000 --- a/mindspore/ccsrc/kernel/kernel.h +++ /dev/null @@ -1,137 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNEL_H_ -#include -#include -#include -#include "nlohmann/json.hpp" -#include "ir/anf.h" -#include "ir/dtype.h" -#include "utils/utils.h" -#include "ir/tensor.h" -#include "pipeline/static_analysis/dshape.h" -#include "utils/log_adapter.h" - -namespace mindspore { -enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; - -namespace kernel { - -enum Axis : int { - N = 0, - C, - H, - W, -}; - -// Supported fusion type -enum FusionType { - CONVLUTION = 0, - ELEMWISE, - COMMREDUCE, - SEGMENT, - OPAQUE, - DYNAMIC, - UNKNOWN_FUSION_TYPE = -1, -}; -enum OpPattern { - kCommonPattern = 0, - kFormatAgnosticPattern = 1, - kBroadcastPattern = 2, - kReducePattern = 3, - kDynamicFormatPattern = 4, -}; - -// Backend processor -enum Processor { - AICORE = 0, - AICPU, - CUDA, -}; - -struct FlexArray { - size_t len; - char contents[]; -}; - -struct KernelJsonInfo { - std::string bin_file_name; - std::string bin_file_suffix; - uint32_t block_dim; - std::string kernel_name; - std::string magic; - std::vector parameters; - std::string sha256; - std::vector workspaces; - KernelJsonInfo() : block_dim(0) {} -}; - -class KernelPack { - public: - KernelPack() : json_(nullptr), kernel_(nullptr) {} - KernelPack(const KernelPack &) = default; - KernelJsonInfo kernel_json_info() const; - bool LoadKernelMeta(const std::string &json_f, const std::string &processor); - bool ReadFromJsonFile(const std::string &json_f, const std::string &processor); - const std::string Serialize() const; - const FlexArray *const GetJson() const { return json_; } - const FlexArray *const GetKernel() const { return kernel_; } - ~KernelPack() { - if (json_) { - delete[] json_; - json_ = nullptr; - } - if (kernel_) { - delete[] kernel_; - kernel_ = nullptr; - } - } - - private: - bool ReadFromJsonFileHelper(std::ifstream &kernelbin); - void ParseKernelJson(const nlohmann::json &js); - KernelJsonInfo kernel_json_info_; - FlexArray *json_; - FlexArray *kernel_; -}; -using KernelPackPtr = std::shared_ptr; - -/** - * @brief base class for autotensor kernel and cce kernel. - */ -struct Address { - void *addr; - size_t size; -}; -using AddressPtr = std::shared_ptr
; - -class KernelMod { - public: - virtual const std::vector &GetInputSizeList() const = 0; - virtual const std::vector &GetOutputSizeList() const = 0; - virtual const std::vector &GetWorkspaceSizeList() const = 0; - virtual bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) = 0; - virtual std::vector GenParameters() { return {}; } - - virtual ~KernelMod() = default; -}; -using KernelModPtr = std::shared_ptr; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc deleted file mode 100644 index c912a0c199..0000000000 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/kernel_build_info.h" -#include -#include "utils/log_adapter.h" -#include "debug/anf_ir_dump.h" -namespace mindspore { -namespace kernel { -std::string KernelBuildInfo::GetInputFormat(size_t input_index) const { - if (input_index >= inputs_format_.size()) { - MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node"; - return kInvalidFormat; - } - return inputs_format_[input_index]; -} - -std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const { - if (output_index >= outputs_format_.size()) { - MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node"; - return kInvalidFormat; - } - return outputs_format_[output_index]; -} - -TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const { - if (input_index >= inputs_device_type_.size()) { - MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input"; - return TypeId::kNumberTypeEnd; - } - return inputs_device_type_[input_index]; -} - -TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const { - if (output_index >= outputs_device_type_.size()) { - MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output"; - return TypeId::kNumberTypeEnd; - } - return outputs_device_type_[output_index]; -} - -std::vector KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; } - -std::vector KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; } - -std::vector KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; } - -std::vector KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; } - -size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); } - -size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } - -std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const { - if (input_index >= input_reshape_type_.size()) { - MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size " - << input_reshape_type_.size(); - } - return input_reshape_type_[input_index]; -} - -std::vector KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { - if (output_index >= output_reshape_type_.size()) { - MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size " - << output_reshape_type_.size(); - } - return output_reshape_type_[output_index]; -} - -std::string KernelBuildInfo::ToString() const { - std::ostringstream output_buffer; - output_buffer << "("; - for (size_t index = 0; index < GetInputNum(); ++index) { - if (index != 0) { - output_buffer << ", "; - } - output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">"; - } - output_buffer << ") -> ("; - for (size_t index = 0; index < GetOutputNum(); ++index) { - if (index != 0) { - output_buffer << ", "; - } - output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">"; - } - output_buffer << ")"; - return output_buffer.str(); -} - -bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { - if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) { - return false; - } - if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) { - if (op_pattern_ != kFormatAgnosticPattern) { - return false; - } else { - MS_LOG(INFO) << "this kernel build info:" << this->ToString() - << ", other kernel build info: " << other.ToString(); - } - } - return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_); -} - -bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); } - -bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); } - -void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->kernel_type_ = kernel_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector &inputs_format) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->inputs_format_ = inputs_format; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector &outputs_format) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->outputs_format_ = outputs_format; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector &inputs_device_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->inputs_device_type_ = inputs_device_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector &outputs_device_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->outputs_device_type_ = outputs_device_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->fusion_type_ = fusion_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->processor_ = processor; -} - -std::shared_ptr KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; } - -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType( - const std::vector> &input_reshape_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->input_reshape_type_ = input_reshape_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( - const std::vector> &output_reshape_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->output_reshape_type_ = output_reshape_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->op_pattern_ = pattern; -} -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - if (index >= kernel_build_info_->inputs_format_.size()) { - MS_LOG(EXCEPTION) << "index outof range!"; - } - kernel_build_info_->inputs_format_[index] = format; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - if (index >= kernel_build_info_->outputs_format_.size()) { - MS_LOG(EXCEPTION) << "index outof range!"; - } - kernel_build_info_->outputs_format_[index] = format; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h deleted file mode 100644 index ca1083fd68..0000000000 --- a/mindspore/ccsrc/kernel/kernel_build_info.h +++ /dev/null @@ -1,145 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ -#include -#include -#include -#include -#include -#include "ir/dtype.h" -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -class KernelBuildInfo { - public: - class KernelBuildInfoBuilder; - - KernelBuildInfo() { - kernel_type_ = TBE_KERNEL; - fusion_type_ = OPAQUE; - processor_ = AICORE; - op_pattern_ = kCommonPattern; - input_reshape_type_ = {}; - output_reshape_type_ = {}; - inputs_format_ = {}; - outputs_format_ = {}; - inputs_device_type_ = {}; - outputs_device_type_ = {}; - } - - ~KernelBuildInfo() = default; - - KernelType kernel_type() const { return kernel_type_; } - - std::string GetInputFormat(size_t input_index) const; - - std::string GetOutputFormat(size_t output_index) const; - - TypeId GetInputDeviceType(size_t input_index) const; - - TypeId GetOutputDeviceType(size_t output_index) const; - - std::vector GetInputReshapeType(size_t input_index) const; - - bool IsInputDefaultPadding() const; - - bool IsOutputDefaultPadding() const; - - std::vector GetOutputReshapeType(size_t input_index) const; - - std::vector GetAllInputFormats() const; - - std::vector GetAllOutputFormats() const; - - std::vector GetAllInputDeviceTypes() const; - - std::vector GetAllOutputDeviceTypes() const; - - OpPattern op_pattern() const { return op_pattern_; } - - FusionType fusion_type() const { return fusion_type_; } - - Processor processor() const { return processor_; } - - size_t GetInputNum() const; - - size_t GetOutputNum() const; - - std::string ToString() const; - - bool operator==(const KernelBuildInfo &other) const; - - public: - static auto constexpr kInvalidFormat = "InvalidFormat"; - - private: - KernelType kernel_type_; - std::vector inputs_format_; - OpPattern op_pattern_; - std::vector outputs_format_; - std::vector> input_reshape_type_; - std::vector> output_reshape_type_; - std::vector inputs_device_type_; - std::vector outputs_device_type_; - FusionType fusion_type_; - Processor processor_; -}; -using KernelBuildInfoPtr = std::shared_ptr; - -class KernelBuildInfo::KernelBuildInfoBuilder { - public: - KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared(); } - - explicit KernelBuildInfoBuilder(std::shared_ptr kernel_build_info) - : kernel_build_info_(std::move(kernel_build_info)) {} - - ~KernelBuildInfoBuilder() = default; - - void SetKernelType(const KernelType &kernel_type); - - void SetInputsFormat(const std::vector &inputs_format); - - void SetOutputsFormat(const std::vector &outputs_format); - - void SetInputsDeviceType(const std::vector &inputs_device_type); - - void SetOutputsDeviceType(const std::vector &outputs_device_type); - - void SetInputReshapeType(const std::vector> &input_reshape_type); - - void SetOutputReshapeType(const std::vector> &output_reshape_type); - - void SetFusionType(FusionType fusion_type); - - void SetProcessor(Processor processor); - - void SetOpPattern(OpPattern pattern); - - void SetInputFormat(const std::string &format, size_t index); - - void SetOutputFormat(const std::string &format, size_t index); - - std::shared_ptr Build(); - - private: - std::shared_ptr kernel_build_info_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ diff --git a/mindspore/ccsrc/kernel/kernel_fusion.cc b/mindspore/ccsrc/kernel/kernel_fusion.cc deleted file mode 100644 index be79eca15a..0000000000 --- a/mindspore/ccsrc/kernel/kernel_fusion.cc +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/kernel_fusion.h" - -#include -#include -#include -#include - -#include "common/utils.h" -#include "kernel/tbe/tbe_kernel_build.h" -#include "kernel/tbe/tbe_kernel_parallel_build.h" -#include "kernel/tbe/tbe_utils.h" -#include "kernel/tbe/tbe_convert_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeUtils; -static bool GenPreBuildKernelJson(const std::vector &compute_nodes, - std::vector *prebuild_op_list) { - MS_EXCEPTION_IF_NULL(prebuild_op_list); - TbeKernelJsonCreator creator(PREBUILD); - for (const auto &anf_node : compute_nodes) { - nlohmann::json prebuild; - if (!creator.GenTbeSingleKernelJson(anf_node, &prebuild)) { - MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; - return false; - } - (*prebuild_op_list).push_back(prebuild); - } - return true; -} - -std::map KernelFusion(const std::vector &fusion_scopes) { - MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size(); - std::map kernel_mod_ret; - auto build_manger = std::make_shared(); - MS_EXCEPTION_IF_NULL(build_manger); - for (const auto &fusion_scope_iter : fusion_scopes) { - auto scope_id = fusion_scope_iter.scope_id; - nlohmann::json fusion_op; - string fusion_kernel = "te_fusion"; - if (!TbeKernelBuild::GenFusionScopeJson(fusion_scope_iter.input_nodes, fusion_scope_iter.compute_nodes, &fusion_op, - &fusion_kernel)) { - continue; - } - // gen kernel_name & check cache - std::string json_str = fusion_op.dump(); - size_t hash_id = std::hash()(json_str); - auto json_name = fusion_kernel.append("_").append(std::to_string(hash_id)); - fusion_op["fusion_op_name"] = json_name; - // gen json for prebuild - std::vector prebuild_op_list; - if (!GenPreBuildKernelJson(fusion_scope_iter.compute_nodes, &prebuild_op_list)) { - continue; - } - // get io size - std::vector input_size_list; - std::vector output_size_list; - if (!TbeKernelBuild::GetIOSize(fusion_op["op_list"], fusion_scope_iter.output_nodes, &input_size_list, - &output_size_list)) { - continue; - } - // search cache - auto kernel_pack = TbeUtils::SearchCache(json_name, tbe::kProcessorAiCore); - if (kernel_pack != nullptr) { - MS_LOG(INFO) << "Use cached kernel, kernel json name: " << json_name; - auto kernel_mod = - build_manger->GenKernelMod(json_name, tbe::kProcessorAiCore, input_size_list, output_size_list, kernel_pack); - if (kernel_mod != nullptr) { - kernel_mod_ret[scope_id] = kernel_mod; - continue; - } - } - // fusion build - nlohmann::json fusion_json; - fusion_json["fusion_op"] = fusion_op; - fusion_json["prebuild_ops"] = prebuild_op_list; - auto task_id = build_manger->StartCompileOp(fusion_json); - TbeUtils::SaveJsonInfo(json_name, fusion_json.dump()); - if (task_id < 0) { - MS_EXCEPTION(ArgumentError) << "start compile failed."; - } - build_manger->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list, scope_id); - } - - int build_failed_num = 0; - while (!build_manger->IsAllTaskFinish()) { - int task_id = -1; - char *task_result = nullptr; - char *pre_build_result = nullptr; - auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); - if (!ret) { - MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; - } - - if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { - MS_LOG(INFO) << "Fusion warning: Fuison op build failed, err log: " << task_result - << " change to single op build."; - build_failed_num++; - } - auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false); - if (kernel_mod_item.second != nullptr) { - (void)kernel_mod_ret.emplace(kernel_mod_item); - } - } - MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num; - return kernel_mod_ret; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_fusion.h b/mindspore/ccsrc/kernel/kernel_fusion.h deleted file mode 100644 index 8ded21787c..0000000000 --- a/mindspore/ccsrc/kernel/kernel_fusion.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ -#include -#include -#include "kernel/kernel.h" -namespace mindspore { -namespace kernel { -/* - * @brief fuse op and return a callable mod - */ -struct FusionScopeInfo { - int32_t scope_id; - std::vector input_nodes; - std::vector compute_nodes; - std::vector output_nodes; -}; - -std::map KernelFusion(const std::vector &fusion_scopes); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc deleted file mode 100755 index 4a8ae81afa..0000000000 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/kernel_query.h" -#include -#include -#include "kernel/aicpu/aicpu_kernel_metadata.h" -#include "kernel/rts/rt_kernel_info.h" -#include "kernel/hccl/hccl_kernel_metadata.h" -#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h" -#include "kernel/akg/akg_kernel_metadata.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -namespace { -void FilterInvalidKernelInfo(const CNodePtr &kernel_node, - std::vector> *kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_info_list); - std::vector> filtered_list; - (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), - [&kernel_node](const std::shared_ptr &kernel_build_info) { - return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && - AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); - }); - if (!filtered_list.empty()) { - kernel_info_list->clear(); - (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); - } else { - MS_LOG(INFO) << "All kernel Info list does not match any kernel info "; - for (size_t index = 0; index < kernel_info_list->size(); ++index) { - std::ostringstream buffer; - auto kernel_info = kernel_info_list->at(index); - MS_EXCEPTION_IF_NULL(kernel_info); - if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) { - buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" - << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]"; - } else { - buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]" - << " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]"; - } - MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str(); - } - kernel_info_list->clear(); - MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : [" - << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" - << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; - } -} -} // namespace - -void KernelQueryAll(const CNodePtr &kernel_node, - std::vector> *kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - - TbeMetadataInfo(kernel_node, kernel_info_list); - - if (kernel_info_list->empty()) { - AicpuMetadataInfo(kernel_node, kernel_info_list); - if (!kernel_info_list->empty()) { - MS_LOG(INFO) << "The node [" << kernel_node->DebugString() - << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; - AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); - } - } - - if (kernel_info_list->empty()) { - GetRtKelInfo(kernel_node, kernel_info_list); - } - - if (kernel_info_list->empty()) { - HcclMetadataInfo(kernel_node, kernel_info_list); - } - if (kernel_info_list->empty()) { - MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!"; - } -} - -void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list, - KernelType kernel_type) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { - kernel_type = KernelType::AKG_KERNEL; - } - - switch (kernel_type) { - case KernelType::AKG_KERNEL: - AkgMetadataInfo(kernel_node, kernel_info_list); - break; - default: - KernelQueryAll(kernel_node, kernel_info_list); - break; - } - - if (kernel_info_list->empty()) { - MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!"; - } - // check output - FilterInvalidKernelInfo(kernel_node, kernel_info_list); -} - -void AICPUQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - kernel_info_list->clear(); - AicpuMetadataInfo(kernel_node, kernel_info_list); - FilterInvalidKernelInfo(kernel_node, kernel_info_list); -} -bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(select_kernel_build_info); - std::vector> kernel_info_list; - auto cnode = kernel_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - AICPUQuery(cnode, &kernel_info_list); - return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), - [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { - MS_EXCEPTION_IF_NULL(item); - return *item == *select_kernel_build_info; - }); -} - -bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(select_kernel_build_info); - std::vector> kernel_info_list; - auto cnode = kernel_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - TbeMetadataInfo(cnode, &kernel_info_list); - return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), - [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { - MS_EXCEPTION_IF_NULL(item); - return *item == *select_kernel_build_info; - }); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_query.h b/mindspore/ccsrc/kernel/kernel_query.h deleted file mode 100644 index 257b0cf073..0000000000 --- a/mindspore/ccsrc/kernel/kernel_query.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ - -#include -#include -#include -#include "kernel/kernel.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list, - KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); -void AICPUQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); -bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h deleted file mode 100644 index f224a97efc..0000000000 --- a/mindspore/ccsrc/kernel/oplib/opinfo.h +++ /dev/null @@ -1,167 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ -#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ -#include -#include -#include -#include -#include "ir/dtype.h" -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU }; -enum OpIOType { kInput = 0, kOutput }; - -class OpAttr { - public: - OpAttr() = default; - ~OpAttr() = default; - - std::string name() const { return name_; } - std::string param_type() const { return param_type_; } - std::string type() const { return type_; } - std::string value() const { return value_; } - std::string default_value() const { return default_value_; } - - void set_name(const std::string &name) { name_ = name; } - void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } - void set_type(const std::string &type) { type_ = type; } - void set_value(const std::string &value) { value_ = value; } - void set_default_value(const std::string &default_value) { default_value_ = default_value; } - - private: - std::string name_; - std::string param_type_; - std::string type_; - std::string value_; - std::string default_value_; -}; - -class OpIOInfo { - public: - OpIOInfo() = default; - ~OpIOInfo() = default; - - int index() const { return index_; } - std::string name() const { return name_; } - bool need_compile() const { return need_compile_; } - std::string param_type() const { return param_type_; } - std::string reshape_type() const { return reshape_type_; } - std::string shape() const { return shape_; } - std::vector dtypes() const { return dtypes_; } - std::vector formats() const { return formats_; } - - void set_index(const int index) { index_ = index; } - void set_name(const std::string &name) { name_ = name; } - void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } - void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } - void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } - void set_shape(const std::string &shape) { shape_ = shape; } - void set_dtypes(const std::vector &dtype) { dtypes_ = dtype; } - void set_formats(const std::vector &formats) { formats_ = formats; } - - private: - int index_ = 0; - std::string name_; - bool need_compile_ = false; - std::string param_type_; - std::string reshape_type_; - std::string shape_; - std::vector dtypes_; - std::vector formats_; -}; - -class OpInfo { - public: - OpInfo() = default; - OpInfo(const OpInfo &opinfo) { - op_name_ = opinfo.op_name(); - imply_type_ = opinfo.imply_type(); - - impl_path_ = opinfo.impl_path(); - fusion_type_ = opinfo.fusion_type(); - async_flag_ = opinfo.async_flag_; - binfile_name_ = opinfo.binfile_name_; - compute_cost_ = opinfo.compute_cost_; - kernel_name_ = opinfo.kernel_name(); - partial_flag_ = opinfo.partial_flag_; - dynamic_format_ = opinfo.dynamic_format_; - op_pattern_ = opinfo.op_pattern(); - for (auto attr : opinfo.attrs_ptr()) { - attrs_ptr_.push_back(std::make_shared(*attr)); - } - for (auto input : opinfo.inputs_ptr()) { - inputs_ptr_.push_back(std::make_shared(*input)); - } - for (auto output : opinfo.outputs_ptr()) { - outputs_ptr_.push_back(std::make_shared(*output)); - } - ref_infos_ = opinfo.ref_infos(); - } - ~OpInfo() = default; - std::string op_name() const { return op_name_; } - OpImplyType imply_type() const { return imply_type_; } - std::string impl_path() const { return impl_path_; } - std::string fusion_type() const { return fusion_type_; } - std::string kernel_name() const { return kernel_name_; } - OpPattern op_pattern() const { return op_pattern_; } - std::vector> attrs_ptr() const { return attrs_ptr_; } - std::vector> inputs_ptr() const { return inputs_ptr_; } - std::vector> outputs_ptr() const { return outputs_ptr_; } - const std::unordered_map &ref_infos() const { return ref_infos_; } - - void set_op_name(const std::string &op_name) { op_name_ = op_name; } - void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } - void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } - void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } - void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } - void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } - void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } - void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } - void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } - void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } - void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } - void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } - void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } - bool is_ref() const { return !ref_infos_.empty(); } - bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } - void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } - void ClearInputs() { (void)inputs_ptr_.clear(); } - void ClearOutputs() { (void)outputs_ptr_.clear(); } - - private: - std::string op_name_; - OpImplyType imply_type_ = kTBE; - std::string impl_path_; - std::string fusion_type_; - bool async_flag_ = false; - std::string binfile_name_; - int compute_cost_ = 0; - std::string kernel_name_; - bool partial_flag_ = false; - bool dynamic_format_ = false; - OpPattern op_pattern_ = kCommonPattern; - std::vector> attrs_ptr_; - std::vector> inputs_ptr_; - std::vector> outputs_ptr_; - std::unordered_map ref_infos_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc deleted file mode 100644 index e01bbe9162..0000000000 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ /dev/null @@ -1,329 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/oplib/oplib.h" -#include -#include -#include -#include -#include "utils/log_adapter.h" -#include "utils/overload.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -constexpr auto kImplyType = "imply_type"; -constexpr auto kOpName = "op_name"; -constexpr auto kFusionType = "fusion_type"; -constexpr auto kAsyncFlag = "async_flag"; -constexpr auto kBinfileName = "binfile_name"; -constexpr auto kComputeCost = "compute_cost"; -constexpr auto kKernelName = "kernel_name"; -constexpr auto kPartialFlag = "partial_flag"; -constexpr auto kReshapeType = "reshape_type"; -constexpr auto kOpPattern = "op_pattern"; -constexpr auto kDynamicFormat = "dynamicFormat"; -constexpr auto kFormatAgnostic = "formatAgnostic"; -constexpr auto kBroadcast = "broadcast"; -constexpr auto kReduce = "reduce"; -constexpr auto kDtypeFormat = "dtype_format"; -constexpr auto kAttr = "attr"; -constexpr auto kIputs = "inputs"; -constexpr auto kOutputs = "outputs"; -constexpr auto kAiCPU = "AiCPU"; -constexpr auto kTbe = "TBE"; -constexpr auto kAkg = "akg"; -constexpr auto kAutodiff = "AutoDiff"; -constexpr auto kName = "name"; -constexpr auto kParamType = "param_type"; -constexpr auto kDtype = "dtype"; -constexpr auto kType = "type"; -constexpr auto kValue = "value"; -constexpr auto kDefaultValue = "default_value"; -constexpr auto kIndex = "index"; -constexpr auto kFormat = "format"; -constexpr auto kNeedCompile = "need_compile"; -constexpr auto kShape = "shape"; -std::vector> OpLib::op_info_; - -std::string ImplTypeToStr(OpImplyType impl_type) { - switch (impl_type) { - case kTBE: - return kTbe; - case kAKG: - return kAkg; - case kAICPU: - return kAiCPU; - default: - return "unknow"; - } -} -bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) { - bool ret = false; - try { - auto op_json = nlohmann::json::parse(json_string); - std::string imply_type_string = op_json.at(kImplyType); - std::string op_name = op_json.at(kOpName); - if (imply_type_string == kTbe) { - OpImplyType imply_type = kTBE; - ret = DecodeOpInfo(op_json, imply_type, impl_path); - } else if (imply_type_string == kAutodiff) { - OpImplyType imply_type = kAKG; - ret = DecodeOpInfo(op_json, imply_type, impl_path); - } else if (imply_type_string == kAiCPU) { - OpImplyType imply_type = kAICPU; - ret = DecodeOpInfo(op_json, imply_type, impl_path); - } else { - MS_LOG(ERROR) << "Not support imply_type"; - } - if (!ret) { - MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string; - } - } catch (const std::exception &e) { - MS_LOG(ERROR) << "get op json elements failed: " << e.what(); - } - return ret; -} - -void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { - const std::map kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern}, - {kBroadcast, kBroadcastPattern}, - {kReduce, kReducePattern}, - {kDynamicFormat, kDynamicFormatPattern}}; - MS_EXCEPTION_IF_NULL(op_info); - op_info->set_async_flag(obj.at(kAsyncFlag)); - op_info->set_binfile_name(obj.at(kBinfileName)); - op_info->set_compute_cost(obj.at(kComputeCost)); - op_info->set_kernel_name(obj.at(kKernelName)); - op_info->set_partial_flag(obj.at(kPartialFlag)); - - if (obj.find(kOpPattern) != obj.end()) { - std::string op_pattern = obj.at(kOpPattern); - auto find_iter = kOpPatternMap.find(op_pattern); - if (find_iter == kOpPatternMap.end()) { - if (!op_pattern.empty()) { - MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern; - } - op_info->set_op_pattern(kCommonPattern); - } else { - op_info->set_op_pattern(find_iter->second); - } - } -} - -bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, - const std::string &impl_path) { - std::shared_ptr op_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(op_info); - op_info->set_op_name(obj.at(kOpName)); - op_info->set_impl_path(impl_path); - op_info->set_imply_type(imply_type); - op_info->set_fusion_type(obj.at(kFusionType)); - if (imply_type == kTBE) { - DecodeTBESpecificInfo(obj, op_info); - } - auto attrs = obj.at(kAttr); - for (const auto &attr : attrs) { - if (!DecodeAttr(attr, imply_type, op_info)) { - MS_LOG(ERROR) << "DecodeAttr Failed"; - return false; - } - } - nlohmann::json dtype_format; - if (obj.find(kDtypeFormat) != obj.end()) { - dtype_format = obj.at(kDtypeFormat); - } - auto inputs = obj.at(kIputs); - for (const auto &input : inputs) { - if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { - MS_LOG(ERROR) << "DecodeInputOutput Failed"; - return false; - } - } - auto outputs = obj.at(kOutputs); - for (const auto &output : outputs) { - if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { - MS_LOG(ERROR) << "DecodeInputOutput Failed"; - return false; - } - } - if (!GetRefInfo(op_info)) { - MS_LOG(ERROR) << "GetRefInfo Failed"; - return false; - } - if (!CheckRepetition(op_info)) { - MS_LOG(ERROR) << "CheckRepetition Failed"; - return false; - } - op_info_.push_back(op_info); - return true; -} - -bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, - const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(op_info); - bool ret = true; - try { - std::shared_ptr op_attr = std::make_shared(); - MS_EXCEPTION_IF_NULL(op_attr); - op_attr->set_name(obj.at(kName)); - if (imply_type != kAICPU) { - op_attr->set_param_type(obj.at(kParamType)); - } - op_attr->set_type(obj.at(kType)); - if (imply_type == kTBE) { - op_attr->set_value(obj.at(kValue)); - } - if (obj.find(kDefaultValue) != obj.end()) { - op_attr->set_default_value(obj.at(kDefaultValue)); - } - op_info->add_attrs_ptr(op_attr); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "DecodeAttr failed:" << e.what(); - ret = false; - } - return ret; -} - -bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, - size_t index) { - MS_EXCEPTION_IF_NULL(op_io); - bool ret = true; - try { - std::vector dtype; - std::vector format; - for (const auto &it : dtype_format) { - dtype.emplace_back(it[index][0]); - format.emplace_back(it[index][1]); - } - op_io->set_dtypes(dtype); - op_io->set_formats(format); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); - ret = false; - } - return ret; -} - -bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr &op_info, const nlohmann::json &dtype_format) { - MS_EXCEPTION_IF_NULL(op_info); - bool ret = true; - try { - std::shared_ptr op_io = std::make_shared(); - MS_EXCEPTION_IF_NULL(op_io); - op_io->set_index(obj.at(kIndex)); - op_io->set_name(obj.at(kName)); - if (!dtype_format.empty()) { - if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) { - MS_LOG(ERROR) << "Decode dtype format failed"; - return false; - } - } else { - op_io->set_dtypes(obj.at(kDtype)); - op_io->set_formats(obj.at(kFormat)); - } - if (op_io->dtypes().size() != op_io->formats().size()) { - MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes() - << " is not equal to format size: " << op_io->formats(); - return false; - } - if (obj.find(kParamType) != obj.end()) { - op_io->set_param_type(obj.at(kParamType)); - } - if (imply_type == kTBE) { - if (obj.find(kNeedCompile) != obj.end()) { - op_io->set_need_compile(obj.at(kNeedCompile)); - } - if (obj.find(kShape) != obj.end()) { - op_io->set_shape(obj.at(kShape)); - } - if (obj.find(kReshapeType) != obj.end()) { - op_io->set_reshape_type(obj.at(kReshapeType)); - } - } - - if (io_type == kInput) { - op_info->add_inputs_ptr(op_io); - } else if (io_type == kOutput) { - op_info->add_outputs_ptr(op_io); - } - } catch (const std::exception &e) { - MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what(); - ret = false; - } - return ret; -} - -std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool is_gpu = (context->device_target() == kGPUDevice); - if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { - MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) - << ", current op num: " << op_info_.size(); - return nullptr; - } - for (const auto &op_info : op_info_) { - MS_EXCEPTION_IF_NULL(op_info); - if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { - return op_info; - } - } - MS_LOG(DEBUG) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) - << ", current op num: " << op_info_.size(); - return nullptr; -} - -bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(op_info); - const auto &output_infos = op_info->outputs_ptr(); - const auto &input_infos = op_info->inputs_ptr(); - for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { - MS_EXCEPTION_IF_NULL(output_infos[out_index]); - const auto &out_name = output_infos[out_index]->name(); - for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { - MS_EXCEPTION_IF_NULL(input_infos[in_index]); - const auto &in_name = input_infos[in_index]->name(); - if (out_name == in_name) { - if (op_info->has_ref_index(out_index)) { - MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info"; - return false; - } - op_info->add_ref_pair(out_index, in_index); - MS_LOG(INFO) << "add ref info, op name is " << op_info->op_name() << ", outindex is " << out_index - << ", in_index is " << in_index; - } - } - } - return true; -} - -bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(op_info); - for (const auto &exist_op_info : op_info_) { - MS_EXCEPTION_IF_NULL(exist_op_info); - if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && - exist_op_info->impl_path() != op_info->impl_path()) { - MS_LOG(ERROR) << "Op has already exist, please use other name, op name: " << op_info->op_name() - << " op type: " << ImplTypeToStr(op_info->imply_type()); - return false; - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/oplib/oplib.h b/mindspore/ccsrc/kernel/oplib/oplib.h deleted file mode 100644 index 47183455a2..0000000000 --- a/mindspore/ccsrc/kernel/oplib/oplib.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ -#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ -#include -#include -#include -#include -#include "kernel/oplib/opinfo.h" - -namespace mindspore { -namespace kernel { -class OpLib { - public: - OpLib() = default; - virtual ~OpLib() = default; - bool RegOp(const std::string &json_string, const std::string &impl_path); - static void RegOpInfo(std::shared_ptr opinfo) { - op_info_.emplace_back(opinfo); - return; - } - static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type); - static const std::vector> &GetAllOpsInfo() { return op_info_; } - - protected: - static std::vector> op_info_; - - private: - static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path); - static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, - const std::shared_ptr &op_info); - static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, - size_t index); - static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info); - static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr &op_info, const nlohmann::json &dtype_format); - static bool GetRefInfo(const std::shared_ptr &op_info); - static bool CheckRepetition(const std::shared_ptr &op_info); -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ diff --git a/mindspore/ccsrc/kernel/oplib/oploader.h b/mindspore/ccsrc/kernel/oplib/oploader.h deleted file mode 100644 index dd4c37e80b..0000000000 --- a/mindspore/ccsrc/kernel/oplib/oploader.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_OPLOADER_H -#define MINDSPORE_OPLOADER_H - -#include -#include "kernel/oplib/oplib.h" - -namespace mindspore { -namespace kernel { -class OpInfoLoaderPy { - public: - OpInfoLoaderPy() = default; - - ~OpInfoLoaderPy() = default; - - size_t GetAllOpsInfo() { - auto ops = OpLib::GetAllOpsInfo(); - auto op_infos = new std::vector(); - for (auto op_info : ops) { - auto new_op_info = new OpInfo(*op_info); - op_infos->emplace_back(new_op_info); - } - return (size_t)op_infos; - } -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_OPLOADER_H diff --git a/mindspore/ccsrc/kernel/rts/assign.cc b/mindspore/ccsrc/kernel/rts/assign.cc deleted file mode 100644 index 7f214b6e6f..0000000000 --- a/mindspore/ccsrc/kernel/rts/assign.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/assign.h" - -#include - -#include "runtime/mem.h" -#include "common/utils.h" - -using ge::model_runner::MemcpyAsyncTaskInfo; -using MemcpyAsyncTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -AssignKernel::AssignKernel() {} - -AssignKernel::~AssignKernel() {} - -bool AssignKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void *stream_ptr) { - if (inputs.size() != 2) { - MS_LOG(ERROR) << "inputs size is not two"; - return false; - } - - if (inputs[0]->addr == inputs[1]->addr) { - MS_LOG(INFO) << "first addr is same with second addr , no need assign"; - return true; - } - rtError_t status = rtMemcpyAsync(inputs[0]->addr, inputs[0]->size, inputs[1]->addr, inputs[1]->size, - RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Assign op rtMemcpyAsync failed!"; - return false; - } - return true; -} - -std::vector AssignKernel::GenTask(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) { - if (inputs.size() != 2) { - MS_LOG(EXCEPTION) << "inputs size is not two"; - } - stream_id_ = stream_id; - - std::shared_ptr task_info_ptr = std::make_shared( - stream_id, inputs[0]->addr, inputs[0]->size, inputs[1]->addr, inputs[1]->size, RT_MEMCPY_DEVICE_TO_DEVICE); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/assign.h b/mindspore/ccsrc/kernel/rts/assign.h deleted file mode 100644 index 0e7e52d48f..0000000000 --- a/mindspore/ccsrc/kernel/rts/assign.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H -#define MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H - -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class AssignKernel : public RtKernel { - public: - AssignKernel(); - ~AssignKernel() override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; -}; - -MS_REG_RTKERNEL(assign, AssignKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H diff --git a/mindspore/ccsrc/kernel/rts/label_goto.cc b/mindspore/ccsrc/kernel/rts/label_goto.cc deleted file mode 100644 index 7bcf42a210..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_goto.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/label_goto.h" -#include -#include -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::LabelGotoTaskInfo; -using LabelGotoTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -LabelGotoKernel::LabelGotoKernel() { label_ = 0; } - -LabelGotoKernel::~LabelGotoKernel() {} - -bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "LabelGotoKernel init"; - auto cnode = anf_node->cast(); - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { - MS_LOG(EXCEPTION) << "LabelGotoKernel has no attr label_index"; - } - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - label_ = GetValue(primitive->GetAttr(kAttrLabelIndex)); - MS_LOG(INFO) << "LabelGotoKernel get attr label:" << label_; - return true; -} - -bool LabelGotoKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - MS_LOG(INFO) << "LabelGotoKernel launch"; - return true; -} - -std::vector LabelGotoKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "LabelGotoKernel GenTask label:" << label_ << ", stream id:" << stream_id; - std::vector task_info_list; - std::shared_ptr task_info_ptr = std::make_shared(stream_id, label_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - task_info_list.emplace_back(task_info_ptr); - return task_info_list; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/label_goto.h b/mindspore/ccsrc/kernel/rts/label_goto.h deleted file mode 100644 index efccc12d6f..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_goto.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class LabelGotoKernel : public RtKernel { - public: - LabelGotoKernel(); - ~LabelGotoKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - uint32_t label_; -}; - -MS_REG_RTKERNEL(labelgoto, LabelGotoKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H diff --git a/mindspore/ccsrc/kernel/rts/label_set.cc b/mindspore/ccsrc/kernel/rts/label_set.cc deleted file mode 100644 index 5aedd012dc..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_set.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/label_set.h" -#include -#include -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::LabelSetTaskInfo; -using LabelSetTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -LabelSetKernel::LabelSetKernel() { label_ = 0; } - -LabelSetKernel::~LabelSetKernel() {} - -bool LabelSetKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "LabelSetKernel init"; - auto cnode = anf_node->cast(); - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { - MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; - } - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - label_ = GetValue(primitive->GetAttr(kAttrLabelIndex)); - MS_LOG(INFO) << "LabelSetKernel get attr label:" << label_; - return true; -} - -bool LabelSetKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - MS_LOG(INFO) << "LabelSetKernel launch"; - return true; -} - -std::vector LabelSetKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "LabelSetKernel GenTask label:" << label_ << ", stream id:" << stream_id; - std::vector task_info_list; - std::shared_ptr task_info_ptr = std::make_shared(stream_id, label_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - task_info_list.emplace_back(task_info_ptr); - return task_info_list; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/label_set.h b/mindspore/ccsrc/kernel/rts/label_set.h deleted file mode 100644 index d05d81f898..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_set.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class LabelSetKernel : public RtKernel { - public: - LabelSetKernel(); - ~LabelSetKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - uint32_t label_; -}; - -MS_REG_RTKERNEL(labelset, LabelSetKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H diff --git a/mindspore/ccsrc/kernel/rts/label_switch.cc b/mindspore/ccsrc/kernel/rts/label_switch.cc deleted file mode 100644 index fb1ad1601a..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_switch.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/label_switch.h" -#include -#include -#include -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::LabelSwitchTaskInfo; -using LabelSwitchTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -LabelSwitchKernel::LabelSwitchKernel() { - label_list_ = {}; - cond_ = nullptr; - label_size_ = 0; -} - -LabelSwitchKernel::~LabelSwitchKernel() {} - -bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "LabelSwitchKernel init"; - auto cnode = anf_node->cast(); - if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) { - MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; - } - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - label_list_ = GetValue>(primitive->GetAttr(kAttrLabelSwitchList)); - label_size_ = label_list_.size(); - MS_LOG(INFO) << "LabelSwitchKernel get attr label size:" << label_size_; - for (auto label : label_list_) { - MS_LOG(INFO) << "label: " << label; - } - return true; -} - -bool LabelSwitchKernel::Launch(const std::vector & /*inputs*/, - const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - MS_LOG(INFO) << "LabelSwitchKernel launch"; - return true; -} - -std::vector LabelSwitchKernel::GenTask(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) { - MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; - std::vector task_info_list; - cond_ = inputs[0]->addr; - auto task_info_ptr = std::make_shared(stream_id, label_size_, label_list_, cond_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - task_info_list.emplace_back(task_info_ptr); - return task_info_list; -} - -std::vector> LabelSwitchDesc::GetKernelInfo() { - std::vector> label_switch_build_info{}; - vector input_format{kOpFormat_DEFAULT}; - vector input_type{kNumberTypeInt32}; - if (input_format.size() != input_type.size()) { - MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " - << input_type.size(); - } - for (size_t i = 0; i < input_format.size(); ++i) { - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - builder.SetInputsFormat({input_format[i]}); - builder.SetInputsDeviceType({input_type[i]}); - builder.SetProcessor(AICORE); - builder.SetKernelType(RT_KERNEL); - builder.SetFusionType(OPAQUE); - label_switch_build_info.emplace_back(builder.Build()); - } - return label_switch_build_info; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/label_switch.h b/mindspore/ccsrc/kernel/rts/label_switch.h deleted file mode 100644 index 858f851b2a..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_switch.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class LabelSwitchKernel : public RtKernel { - public: - LabelSwitchKernel(); - ~LabelSwitchKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - std::vector label_list_; - uint32_t label_size_; - void *cond_; -}; - -class LabelSwitchDesc : public RtKerDesc { - public: - LabelSwitchDesc() = default; - ~LabelSwitchDesc() override = default; - std::vector> GetKernelInfo() override; -}; - -MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc); -MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H diff --git a/mindspore/ccsrc/kernel/rts/memcpy_async.cc b/mindspore/ccsrc/kernel/rts/memcpy_async.cc deleted file mode 100644 index f5fbec6e56..0000000000 --- a/mindspore/ccsrc/kernel/rts/memcpy_async.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/memcpy_async.h" - -#include -#include - -#include "runtime/mem.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "common/trans.h" - -using ge::model_runner::MemcpyAsyncTaskInfo; -using MemcpyAsyncTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -MemCpyAsyncKernel::MemCpyAsyncKernel() {} - -MemCpyAsyncKernel::~MemCpyAsyncKernel() {} - -bool MemCpyAsyncKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs, void *stream_ptr) { - if (inputs.size() != 1) { - MS_LOG(ERROR) << "inputs size is not one"; - return false; - } - if (outputs.size() != 1) { - MS_LOG(ERROR) << "outputs size is not one"; - return false; - } - - if (inputs[0]->addr == outputs[0]->addr) { - MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async"; - return true; - } - if (outputs[0]->size < inputs[0]->size) { - MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size"; - } - // input x -> memcpy_async -> AllReduce - if (outputs[0]->size > inputs[0]->size) { - MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; - } - rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, - RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "MemCpyAsync op rtMemcpyAsync failed!"; - return false; - } - return true; -} - -bool MemCpyAsyncKernel::Init(const mindspore::AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - GetInputOutputDataType(anf_node); - GetInputOutputTotalCount(anf_node); - return true; -} - -void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - size_t input_size = AnfAlgo::GetInputTensorNum(anf_node); - if (input_size != 1) { - MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1"; - } - input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0); -} - -void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - size_t input_size = AnfAlgo::GetInputTensorNum(anf_node); - if (input_size != 1) { - MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1"; - } - size_t type_size = trans::TypeIdSize(input_type_id_); - std::vector shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0); - size_t total_size = 1; - for (size_t i = 0; i < shape_i.size(); i++) { - total_size = total_size * shape_i[i]; - } - total_size *= type_size; - MS_LOG(INFO) << "MemCpyAsync size[" << total_size << "]"; - input_size_list_.emplace_back(total_size); - output_size_list_.emplace_back(total_size); -} - -std::vector MemCpyAsyncKernel::GenTask(const std::vector &inputs, - const std::vector &, - const std::vector &outputs, uint32_t stream_id) { - if (inputs.size() != 1) { - MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one"; - } - - if (outputs.size() != 1) { - MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one"; - } - - if (outputs[0]->size < inputs[0]->size) { - MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size"; - } - // input x -> memcpy_async -> AllReduce - if (outputs[0]->size > inputs[0]->size) { - MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; - } - - stream_id_ = stream_id; - std::shared_ptr task_info_ptr = std::make_shared( - stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} - -const std::vector data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, - kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, - kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16, - kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool}; -const std::vector format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, - kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, - kOpFormat_C1HWNCoC0}; - -MemCpyAsyncDesc::MemCpyAsyncDesc() {} - -MemCpyAsyncDesc::~MemCpyAsyncDesc() {} - -std::vector> MemCpyAsyncDesc::GetKernelInfo() { - std::vector> memcpy_build_info{}; - for (const auto &format : format_list) { - for (const auto &type : data_type_list) { - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - vector input_format{format}; - vector input_type{type}; - vector output_format{format}; - vector output_type{type}; - builder.SetInputsFormat(input_format); - builder.SetInputsDeviceType(input_type); - builder.SetOutputsFormat(output_format); - builder.SetOutputsDeviceType(output_type); - builder.SetProcessor(AICORE); - builder.SetKernelType(RT_KERNEL); - builder.SetFusionType(OPAQUE); - memcpy_build_info.emplace_back(builder.Build()); - } - } - return memcpy_build_info; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/memcpy_async.h b/mindspore/ccsrc/kernel/rts/memcpy_async.h deleted file mode 100644 index 94bbf1ca1c..0000000000 --- a/mindspore/ccsrc/kernel/rts/memcpy_async.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H -#define MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class MemCpyAsyncKernel : public RtKernel { - public: - MemCpyAsyncKernel(); - ~MemCpyAsyncKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - void GetInputOutputDataType(const AnfNodePtr &anf_node); - void GetInputOutputTotalCount(const AnfNodePtr &anf_node); - TypeId input_type_id_{}; -}; - -class MemCpyAsyncDesc : public RtKerDesc { - public: - MemCpyAsyncDesc(); - ~MemCpyAsyncDesc() override; - std::vector> GetKernelInfo() override; -}; - -MS_REG_RTKERNEL_DESC(memcpy_async, MemCpyAsyncDesc); -MS_REG_RTKERNEL(memcpy_async, MemCpyAsyncKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H diff --git a/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.cc b/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.cc deleted file mode 100644 index ff005f399b..0000000000 --- a/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/profiling_kernel_mod.h" - -#include -#include -#include - -#include "framework/ge_runtime/task_info.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "session/anf_runtime_algorithm.h" - -using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo; -using mindspore::device::ascend::ProfilingUtils; - -namespace mindspore { -namespace kernel { -bool ProfilingKernelMod::Init(const AnfNodePtr &anf_node) { - MS_LOG(INFO) << "[profiling] init profiling kernel mod"; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - - ValuePtr notify_ptr = primitive->GetAttr(ProfilingUtils::kNotify); - MS_EXCEPTION_IF_NULL(notify_ptr); - - ValuePtr log_id_ptr = primitive->GetAttr(ProfilingUtils::kProfilerTraceId); - MS_EXCEPTION_IF_NULL(log_id_ptr); - - ValuePtr flags_ptr = primitive->GetAttr(ProfilingUtils::kFlags); - MS_EXCEPTION_IF_NULL(flags_ptr); - - notify_ = GetValue(notify_ptr); - log_id_ = GetValue(log_id_ptr); - flags_ = GetValue(flags_ptr); - MS_LOG(INFO) << "[profiling] profiling kernel notify_:" << notify_ << ", log_id_:" << log_id_ - << ", flags_:" << flags_; - return true; -} - -bool ProfilingKernelMod::Launch(const std::vector & /*inputs*/, - const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - return true; -} - -std::vector ProfilingKernelMod::GenTask(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) { - MS_LOG(INFO) << "gen task inputs size:" << inputs.size() << ", workspace size:" << workspace.size() - << ", outputs size:" << outputs.size(); - stream_id_ = stream_id; - std::shared_ptr task_info_ptr = - std::make_shared(stream_id, log_id_, notify_, flags_); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.h b/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.h deleted file mode 100644 index f77f3b5c67..0000000000 --- a/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ -#include -#include "kernel/rts/rt_kernel.h" -namespace mindspore { -namespace kernel { -class ProfilingKernelMod : public RtKernel { - public: - ProfilingKernelMod() = default; - ~ProfilingKernelMod() override = default; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - bool Init(const AnfNodePtr &anf_node) override; - - private: - uint64_t log_id_{0}; - bool notify_{true}; - uint32_t flags_{0}; -}; -MS_REG_RTKERNEL(profiling, ProfilingKernelMod); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/rts/recv.cc b/mindspore/ccsrc/kernel/rts/recv.cc deleted file mode 100644 index c195fd1c92..0000000000 --- a/mindspore/ccsrc/kernel/rts/recv.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/recv.h" -#include -#include "runtime/stream.h" -#include "utils/context/ms_context.h" -#include "device/ascend/ascend_stream_assign.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -using ge::model_runner::EventWaitTaskInfo; -using mindspore::device::ascend::AscendStreamAssign; -using EventWaitTaskInfoPtr = std::shared_ptr; - -RecvKernel::RecvKernel() { event_id_ = 0; } - -RecvKernel::~RecvKernel() {} - -bool RecvKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (!AnfAlgo::HasNodeAttr(kAttrEventId, anf_node->cast())) { - MS_LOG(EXCEPTION) << "RecvKernel has no attr kAttrEventId"; - } - event_id_ = GetValue(primitive->GetAttr(kAttrEventId)); - MS_LOG(INFO) << "recv op event_id_:" << event_id_; - return true; -} - -bool RecvKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - rtEvent_t stream_event{}; - auto status = rtStreamWaitEvent(stream_ptr, stream_event); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Recv rtStreamWaitEvent failed!"; - return false; - } - return true; -} - -std::vector RecvKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "RecvKernel GenTask event_id_:" << event_id_ << ", stream_id_:" << stream_id; - stream_id_ = stream_id; - EventWaitTaskInfoPtr task_info_ptr = std::make_shared(stream_id, event_id_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/recv.h b/mindspore/ccsrc/kernel/rts/recv.h deleted file mode 100644 index 68f0b69cc5..0000000000 --- a/mindspore/ccsrc/kernel/rts/recv.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RECV_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RECV_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class RecvKernel : public RtKernel { - public: - RecvKernel(); - ~RecvKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - uint32_t event_id_; -}; - -MS_REG_RTKERNEL(recv, RecvKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RECV_H diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel.cc b/mindspore/ccsrc/kernel/rts/rt_kernel.cc deleted file mode 100644 index 9e81372383..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/rt_kernel.h" - -namespace mindspore { -namespace kernel { -void RtKernelFactory::Registe(const std::string &name, RtKernelCreater &&fun) { - (void)fmap_.emplace(name, std::move(fun)); -} - -std::shared_ptr RtKernelFactory::Create(const std::string &name) { - const auto &map = Get().fmap_; - auto it = map.find(name); - if (it != map.end() && it->second) { - return (it->second)(); - } - return nullptr; -} - -RtKernelFactory &RtKernelFactory::Get() { - static RtKernelFactory _this; - return _this; -} - -RtKernel::RtKernel() {} - -RtKernel::~RtKernel() {} - -bool RtKernel::Init(const mindspore::AnfNodePtr & /*anf_node*/) { return true; } - -const std::vector &RtKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &RtKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &RtKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel.h b/mindspore/ccsrc/kernel/rts/rt_kernel.h deleted file mode 100644 index 44d55dca31..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H - -#include -#include -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/task_stream.h" - -namespace mindspore { -namespace kernel { -class RtKernel : public AscendKernelMod { - public: - RtKernel(); - ~RtKernel() override; - virtual bool Init(const AnfNodePtr &anf_node); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - - protected: - mutable std::vector input_size_list_; - mutable std::vector output_size_list_; - mutable std::vector workspace_size_list_; -}; - -using RTKernelPtr = std::shared_ptr; - -using RtKernelCreater = std::function()>; -class RtKernelFactory { - RtKernelFactory() = default; - ~RtKernelFactory() = default; - - public: - static RtKernelFactory &Get(); - void Registe(const std::string &name, RtKernelCreater &&fun); - static std::shared_ptr Create(const std::string &name); - - private: - std::map fmap_; -}; - -class _RtKernelRegister { - public: - _RtKernelRegister(const std::string &name, RtKernelCreater &&fun) { - RtKernelFactory::Get().Registe(name, std::move(fun)); - } - ~_RtKernelRegister() = default; -}; - -#define _MS_REG_RTKERNEL_REG(KNAME, clazz) \ - static_assert(std::is_base_of::value, " must be base of RtKernel"); \ - static const _RtKernelRegister g_##KNAME##_##_RtKernel_reg(#KNAME, []() { return std::make_shared(); }); - -#define MS_REG_RTKERNEL(KNAME, clazz) _MS_REG_RTKERNEL_REG(KNAME, clazz) -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_build.cc b/mindspore/ccsrc/kernel/rts/rt_kernel_build.cc deleted file mode 100644 index 164605fe9b..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_build.cc +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/rt_kernel_build.h" - -#include -#include -#include -#include - -#include "kernel/rts/rt_kernel.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -KernelModPtr RtOpBuild(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - (void)std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower); - MS_LOG(INFO) << "Op Name(tolower)[" << op_name << "]"; - auto ker_ptr = RtKernelFactory::Create(op_name); - MS_EXCEPTION_IF_NULL(ker_ptr); - if (!ker_ptr->Init(anf_node)) { - MS_LOG(ERROR) << "Rt Op initialize failed!"; - return nullptr; - } - - return ker_ptr; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_build.h b/mindspore/ccsrc/kernel/rts/rt_kernel_build.h deleted file mode 100644 index cbd674b751..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_build.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H - -#include -#include -#include "kernel/kernel.h" -namespace mindspore { -namespace kernel { -KernelModPtr RtOpBuild(const AnfNodePtr &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc b/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc deleted file mode 100755 index 14f5a60a07..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/rt_kernel_info.h" -#include -#include -#include "utils/convert_utils.h" -#include "utils/utils.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -void RtKerDescFactory::Register(const std::string &name, RtKerDescCreater &&fun) { - if (fmap_.find(name) == fmap_.end()) { - (void)fmap_.emplace(name, std::move(fun)); - } -} - -std::shared_ptr RtKerDescFactory::Create(const std::string &name) { - const auto &map = Get().fmap_; - auto it = map.find(name); - if (it != map.end() && it->second) { - return (it->second)(); - } - return nullptr; -} - -RtKerDescFactory &RtKerDescFactory::Get() { - static RtKerDescFactory _this; - return _this; -} - -static bool IsDefaultKernelInfo(const std::string &name) { - static const std::set white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName, - kLabelGotoOpName}; - return white_list.find(name) != white_list.end(); -} - -void GetRtKelInfo(const CNodePtr &kernel_node, - std::vector> *kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_info_list); - MS_EXCEPTION_IF_NULL(kernel_node); - std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node); - (void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower); - - auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower); - if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) { - *kernel_info_list = ker_desc_ptr->GetKernelInfo(); - return; - } - // if can't find kernel info in kernel info database, use the default kernel info - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - if (IsDefaultKernelInfo(node_name)) { - auto kernel_build_info_builder = std::make_shared(); - // set input infos - auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); - kernel_build_info_builder->SetInputsFormat(std::vector(input_num, kOpFormat_DEFAULT)); - std::vector input_types = {}; - for (size_t i = 0; i < input_num; i++) { - input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i)); - } - kernel_build_info_builder->SetInputsDeviceType(input_types); - // set output info - auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - kernel_build_info_builder->SetOutputsFormat(std::vector(output_num, kOpFormat_DEFAULT)); - kernel_build_info_builder->SetOutputsDeviceType(std::vector(output_num, TypeId::kTypeUnknown)); - // set ohter info - kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); - kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); - kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); - kernel_info_list->push_back(kernel_build_info_builder->Build()); - return; - } - MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "]."; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_info.h b/mindspore/ccsrc/kernel/rts/rt_kernel_info.h deleted file mode 100644 index ae3753b4c8..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_info.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/dtype.h" -#include "kernel/kernel_build_info.h" -#include "kernel/kernel.h" -#include "utils/utils.h" - -namespace mindspore { -namespace kernel { -class RtKerDesc { - public: - virtual ~RtKerDesc() {} - virtual std::vector> GetKernelInfo() { - return std::vector>{}; - } -}; - -using RtKerDescCreater = std::function()>; -class RtKerDescFactory { - RtKerDescFactory() = default; - ~RtKerDescFactory() = default; - - public: - static RtKerDescFactory &Get(); - void Register(const std::string &name, RtKerDescCreater &&fun); - static std::shared_ptr Create(const std::string &name); - - private: - std::map fmap_; -}; - -class _RtKerDescRegister { - public: - _RtKerDescRegister(const std::string &name, RtKerDescCreater &&fun) { - RtKerDescFactory::Get().Register(name, std::move(fun)); - } - ~_RtKerDescRegister() = default; -}; - -#define _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz) \ - static_assert(std::is_base_of::value, " must be base of RtKerDesc"); \ - static const _RtKerDescRegister g_##KNAME##_##_rtkernel_desc_reg(#KNAME, []() { return std::make_shared(); }); - -#define MS_REG_RTKERNEL_DESC(KNAME, clazz) _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz) - -void GetRtKelInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H diff --git a/mindspore/ccsrc/kernel/rts/send.cc b/mindspore/ccsrc/kernel/rts/send.cc deleted file mode 100644 index ccdd43ebb6..0000000000 --- a/mindspore/ccsrc/kernel/rts/send.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/send.h" -#include -#include "runtime/event.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::EventRecordTaskInfo; -using EventRecordTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -SendKernel::SendKernel() { event_id_ = 0; } - -SendKernel::~SendKernel() {} - -bool SendKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (!AnfAlgo::HasNodeAttr(kAttrEventId, anf_node->cast())) { - MS_LOG(EXCEPTION) << "SendKernel has no attr kAttrEventId"; - } - event_id_ = GetValue(primitive->GetAttr(kAttrEventId)); - MS_LOG(INFO) << "send op event id:" << event_id_; - return true; -} - -bool SendKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - rtEvent_t event{}; - rtError_t status = rtEventRecord(event, stream_ptr); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Send op rtEventRecord failed!"; - return false; - } - return true; -} - -std::vector SendKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "SendKernel GenTask event id:" << event_id_ << ", stream id:" << stream_id; - stream_id_ = stream_id; - EventRecordTaskInfoPtr task_info_ptr = std::make_shared(stream_id, event_id_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/send.h b/mindspore/ccsrc/kernel/rts/send.h deleted file mode 100644 index 5c5b7cf09e..0000000000 --- a/mindspore/ccsrc/kernel/rts/send.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_SEND_H -#define MINDSPORE_CCSRC_KERNEL_RTS_SEND_H -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class SendKernel : public RtKernel { - public: - SendKernel(); - ~SendKernel() override; - bool Init(const AnfNodePtr &anf_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - uint32_t event_id_; -}; - -MS_REG_RTKERNEL(send, SendKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_SEND_H diff --git a/mindspore/ccsrc/kernel/rts/stream_active.cc b/mindspore/ccsrc/kernel/rts/stream_active.cc deleted file mode 100644 index 4f0895a0be..0000000000 --- a/mindspore/ccsrc/kernel/rts/stream_active.cc +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/stream_active.h" -#include -#include -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::StreamActiveTaskInfo; -using StreamActiveTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -StreamActiveKernel::StreamActiveKernel() { active_streams_index_ = {}; } - -StreamActiveKernel::~StreamActiveKernel() {} - -bool StreamActiveKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "stream active op init start"; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (!AnfAlgo::HasNodeAttr(kAttrActiveStreamList, anf_node->cast())) { - MS_LOG(EXCEPTION) << "StreamActiveKernel has no attr kAttrActiveStreamList"; - } - active_streams_index_ = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); - return true; -} - -bool StreamActiveKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - MS_LOG(INFO) << "Stream active op launch start"; - - if (active_streams_index_.empty()) { - MS_LOG(ERROR) << "activeStreamList_ is empty!"; - return false; - } - - rtStream_t act_stream; - rtError_t status; - for (auto index : active_streams_index_) { - act_stream = kernel::TaskStream::GetInstance()->gen_stream_list()[index]; - status = rtStreamActive(act_stream, stream_ptr); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Stream active failed!"; - return false; - } - } - return true; -} - -std::vector StreamActiveKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "StreamActiveKernel GenTask active stream size:" << active_streams_index_.size() - << ", stream id:" << stream_id; - stream_id_ = stream_id; - std::vector task_info_list; - for (auto &index : active_streams_index_) { - std::shared_ptr task_info_ptr = std::make_shared(stream_id, index); - MS_EXCEPTION_IF_NULL(task_info_ptr); - task_info_list.emplace_back(task_info_ptr); - MS_LOG(INFO) << "StreamActiveKernel GenTask: streamId:" << stream_id << ", Active streamId:" << index; - } - return task_info_list; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/stream_active.h b/mindspore/ccsrc/kernel/rts/stream_active.h deleted file mode 100644 index 68c422e7c2..0000000000 --- a/mindspore/ccsrc/kernel/rts/stream_active.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H -#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class StreamActiveKernel : public RtKernel { - public: - StreamActiveKernel(); - ~StreamActiveKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - std::vector active_streams_index_; -}; - -MS_REG_RTKERNEL(streamactive, StreamActiveKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H diff --git a/mindspore/ccsrc/kernel/rts/stream_switch.cc b/mindspore/ccsrc/kernel/rts/stream_switch.cc deleted file mode 100644 index bab6b04366..0000000000 --- a/mindspore/ccsrc/kernel/rts/stream_switch.cc +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/stream_switch.h" - -#include -#include - -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::StreamSwitchTaskInfo; -using StreamSwitchTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -StreamSwitchKernel::StreamSwitchKernel() { - cond_ = RT_EQUAL; - true_stream_index_ = 0; - data_type_ = RT_SWITCH_INT32; -} - -StreamSwitchKernel::~StreamSwitchKernel() {} - -bool StreamSwitchKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "stream switch op init start"; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (!AnfAlgo::HasNodeAttr(kAttrSwitchCondition, anf_node->cast())) { - MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrSwitchCondition"; - } - cond_ = tagRtCondition(GetValue(primitive->GetAttr(kAttrSwitchCondition))); - if (!AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, anf_node->cast())) { - MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrTrueBranchStream"; - } - true_stream_index_ = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); - if (!AnfAlgo::HasNodeAttr(kAttrDataType, anf_node->cast())) { - MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrDataType"; - } - data_type_ = tagRtSwitchDataType(GetValue(primitive->GetAttr(kAttrDataType))); - MS_LOG(INFO) << "cond_:" << static_cast(cond_) << ", true_stream_index_:" << true_stream_index_ - << ", data_type_:" << static_cast(data_type_); - return true; -} - -bool StreamSwitchKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - MS_LOG(INFO) << "stream switch op launch start"; - if (inputs.size() != 2) { - MS_LOG(EXCEPTION) << "Stream switch inputs size is " << inputs.size() << ", only support 2"; - } - - void *loop_cnt = inputs[0]->addr; - void *ites_per_loop = inputs[1]->addr; - rtStream_t true_stream_ = kernel::TaskStream::GetInstance()->gen_stream_list()[true_stream_index_]; - rtError_t status = rtStreamSwitchEx(loop_cnt, cond_, ites_per_loop, true_stream_, stream_ptr, data_type_); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Stream switch failed!"; - return false; - } - return true; -} - -std::vector StreamSwitchKernel::GenTask(const std::vector &inputs, - const std::vector &, const std::vector &, - uint32_t stream_id) { - MS_LOG(INFO) << "StreamSwitchKernel GenTask start"; - if (inputs.size() != 2) { - MS_LOG(EXCEPTION) << "stream switch inputs size is " << inputs.size() << ", is not two"; - } - stream_id_ = stream_id; - MS_EXCEPTION_IF_NULL(inputs[0]); - MS_EXCEPTION_IF_NULL(inputs[1]); - auto loop_cnt = inputs[0]->addr; - auto ites_per_loop = inputs[1]->addr; - MS_LOG(INFO) << "cond_:" << static_cast(cond_) << ", true_stream_index_:" << true_stream_index_ - << ", stream_id:" << stream_id; - std::shared_ptr task_info_ptr = - std::make_shared(stream_id, true_stream_index_, loop_cnt, ites_per_loop, cond_, data_type_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/stream_switch.h b/mindspore/ccsrc/kernel/rts/stream_switch.h deleted file mode 100644 index 4e927f3059..0000000000 --- a/mindspore/ccsrc/kernel/rts/stream_switch.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H -#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class StreamSwitchKernel : public RtKernel { - public: - StreamSwitchKernel(); - ~StreamSwitchKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - rtCondition_t cond_; - uint32_t true_stream_index_; - rtSwitchDataType_t data_type_; -}; - -MS_REG_RTKERNEL(streamswitch, StreamSwitchKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc deleted file mode 100644 index c38f48763e..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ /dev/null @@ -1,423 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_adapter.h" - -#include -#include -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/opinfo.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -static std::map tbe_func_adapter_map = { - {"softmax", "softmax_v2"}, - {"log_softmax", "log_softmax_v2"}, - {"apply_momentum", "apply_momentum_d"}, - {"apply_ftrl", "apply_ftrl_d"}, - {"re_lu6", "relu6"}, - {"re_lu6_grad", "relu6_grad"}, - {"re_lu", "relu"}, - {"re_luv2", "relu_v2"}, - {"p_re_lu", "prelu"}, - {"p_re_lu_grad", "prelu_grad"}, - {"tensor_add", "add"}, - {"reduce_mean", "reduce_mean_d"}, - {"reduce_max", "reduce_max_d"}, - {"reduce_min", "reduce_min_d"}, - {"avg_pool_grad", "avg_pool_grad_d"}, - {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, - {"conv2d_backprop_input", "conv2d_backprop_input_d"}, - {"depthwise_conv2d_native", "depthwise_conv2d"}, - {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"}, - {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"}, - {"scatter_nd", "scatter_nd_d"}, - {"tile", "tile_d"}, - {"gather_v2", "gather_v2_d"}, - {"sparse_gather_v2", "gather_v2_d"}, - {"batch_mat_mul", "batch_matmul"}, - {"b_n_training_reduce", "bn_training_reduce"}, - {"b_n_training_update", "bn_training_update"}, - {"b_n_training_update_v2", "bn_training_update_v2"}, - {"b_n_training_update_v3", "bn_training_update_v3"}, - {"b_n_training_reduce_grad", "bn_training_reduce_grad"}, - {"b_n_training_update_grad", "bn_training_update_grad"}, - {"b_n_infer", "bn_infer"}, - {"b_n_infer_grad", "bn_infer_grad"}, - {"n_pu_clear_float_status", "n_p_u_clear_float_status"}, - {"n_pu_get_float_status", "n_p_u_get_float_status"}, - {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"}, - {"dropout_do_mask", "drop_out_do_mask"}, - {"strided_slice", "strided_slice_d"}, - {"strided_slice_grad", "strided_slice_grad_d"}, - {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, - {"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"}, - {"apply_ada_max", "apply_ada_max_d"}, - {"apply_adadelta", "apply_adadelta_d"}, - {"apply_adagrad", "apply_adagrad_d"}, - {"apply_adagrad_v2", "apply_adagradv2_d"}, - {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, - {"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"}, - {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, - {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, - {"apply_add_sign", "apply_add_sign_d"}, - {"apply_power_sign", "apply_power_sign_d"}, - {"transpose", "transpose_d"}, - {"fill", "fill_d"}, - {"unsorted_segment_sum", "unsorted_segment_sum_d"}, - {"concat", "concat_d"}, - {"slice", "slice_d"}, - {"reduce_sum", "reduce_sum_d"}, - {"inplace_add", "inplace_add_d"}, - {"inplace_sub", "inplace_sub_d"}, - {"one_hot", "one_hot_d"}, - {"sum", "reduce_sum_d"}, - {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"}, - {"lamb_next_mv", "lamb_next_m_v"}, - {"split", "split_d"}, - {"split_v", "split_v_d"}, - {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, - {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, - {"pad", "pad_d"}, - {"argmax", "arg_max_d"}, - {"argmin", "arg_min_d"}, - {"space_to_batch", "space_to_batch_d"}, - {"batch_to_space", "batch_to_space_d"}, - {"space_to_batch_nd", "space_to_batch_nd_d"}, - {"batch_to_space_nd", "batch_to_space_nd_d"}, - {"resize_bilinear", "resize_bilinear_v2_d"}, - {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, - {"adam", "apply_adam_d"}, - {"r_oi_align", "roi_align"}, - {"r_oi_align_grad", "roi_align_grad"}, - {"i_ou", "iou"}, - {"s_gd", "sgd"}, - {"l_rn", "lrn"}, - {"l_rn_grad", "lrn_grad"}, - {"l_ars_update", "lars_v2_update"}, - {"n_ms_with_mask", "nms_with_mask"}, - {"square_sum_all", "square_sum_all"}, - {"cum_sum", "cumsum_d"}, - {"range", "range_d"}, - {"lin_space", "lin_space_d"}, - {"inv_grad", "inv_grad"}, - {"apply_rms_prop", "apply_rms_prop_d"}, - {"cum_prod", "cumprod_d"}, - {"reduce_all", "reduce_all_d"}, - {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, - {"unsorted_segment_min", "unsorted_segment_min_d"}, - {"reduce_prod", "reduce_prod_d"}, - {"a_cos", "acos"}, - {"a_cos_grad", "acos_grad"}, - {"histogram_fixed_width", "histogram_fixed_width_d"}, - {"broadcast_to", "broadcast_to_d"}, - {"inplace_update", "inplace_update_d"}, - {"matrix_diag", "matrix_diag_d"}, - {"matrix_diag_part", "matrix_diag_part_d"}, - {"matrix_set_diag", "matrix_set_diag_d"}}; - -void TbeAdapter::NormalizeFuncName(std::string *func_name) { - if (func_name == nullptr) { - MS_LOG(EXCEPTION) << "func_name is null"; - } - std::string name_tmp; - bool sub_head = false; - for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) { - if (islower(*iter)) { - sub_head = false; - } - if (isdigit(*iter)) { - sub_head = true; - } - if (isupper(*iter) && iter != func_name->begin()) { - if (!sub_head) { - (void)name_tmp.insert(name_tmp.end(), '_'); - sub_head = true; - } else { - string::iterator iter_next = iter + 1; - if (iter_next != func_name->end()) { - if (islower(*iter_next)) { - (void)name_tmp.insert(name_tmp.end(), '_'); - } - } - } - } - (void)name_tmp.insert(name_tmp.end(), *iter); - } - (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower); - *func_name = name_tmp; - auto iter = tbe_func_adapter_map.find(*func_name); - if (iter != tbe_func_adapter_map.end()) { - MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second; - *func_name = iter->second; - } -} - -void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) { - std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0); - std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0); - if (input_format == kOpFormat_DEFAULT) { - input_format = kOpFormat_NCHW; - } - if (output_format == kOpFormat_DEFAULT) { - output_format = kOpFormat_NCHW; - } - AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node); - AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node); - } -} - -std::unordered_set input_order_adjusted_ops = { - "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop", - "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"}; - -void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, - nlohmann::json *inputs_json) { - MS_EXCEPTION_IF_NULL(inputs_json); - if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { - (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json))); - } else { - if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { - inputs_json->push_back(inputs_list[2]); - inputs_json->push_back(inputs_list[0]); - inputs_json->push_back(inputs_list[1]); - for (size_t i = 3; i < inputs_list.size(); ++i) { - inputs_json->push_back(inputs_list[i]); - } - } else if (op_name == "ApplyCenteredRMSProp") { - // Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map - // TBE parameter to correspond python API parameter by latter's index using hardcode - inputs_json->push_back(inputs_list[0]); - inputs_json->push_back(inputs_list[1]); - inputs_json->push_back(inputs_list[2]); - inputs_json->push_back(inputs_list[3]); - inputs_json->push_back(inputs_list[5]); - inputs_json->push_back(inputs_list[6]); - inputs_json->push_back(inputs_list[7]); - inputs_json->push_back(inputs_list[8]); - inputs_json->push_back(inputs_list[4]); - } else { - inputs_json->push_back(inputs_list[1]); - inputs_json->push_back(inputs_list[0]); - for (size_t i = 2; i < inputs_list.size(); ++i) { - inputs_json->push_back(inputs_list[i]); - } - } - } -} - -void TbeAdapter::FusionInputOrderPass(const std::string &op_name, const std::vector &inputs_list, - std::vector *inputs_json) { - MS_EXCEPTION_IF_NULL(inputs_json); - if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { - (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json))); - } else { - if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { - inputs_json->emplace_back(inputs_list[2]); - inputs_json->emplace_back(inputs_list[0]); - inputs_json->emplace_back(inputs_list[1]); - for (size_t i = 3; i < inputs_list.size(); ++i) { - inputs_json->emplace_back(inputs_list[i]); - } - } else { - inputs_json->emplace_back(inputs_list[1]); - inputs_json->emplace_back(inputs_list[0]); - for (size_t i = 2; i < inputs_list.size(); ++i) { - inputs_json->emplace_back(inputs_list[i]); - } - } - } -} - -void TbeAdapter::FusionDataOrderPass(const std::string &op_name, const std::vector &data_layer, - std::vector *reorder_data_layer) { - MS_EXCEPTION_IF_NULL(reorder_data_layer); - if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { - (void)std::copy(data_layer.begin(), data_layer.end(), std::back_inserter((*reorder_data_layer))); - } else { - if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { - reorder_data_layer->emplace_back(data_layer[2]); - reorder_data_layer->emplace_back(data_layer[0]); - reorder_data_layer->emplace_back(data_layer[1]); - for (size_t i = 3; i < data_layer.size(); ++i) { - reorder_data_layer->emplace_back(data_layer[i]); - } - } else { - reorder_data_layer->emplace_back(data_layer[1]); - reorder_data_layer->emplace_back(data_layer[0]); - for (size_t i = 2; i < data_layer.size(); ++i) { - reorder_data_layer->emplace_back(data_layer[i]); - } - } - } -} - -std::map TbeAdapter::build_json_attr_pass_map_ = { - {"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass}, - {"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass}, - {"Cast", TbeAdapter::CastAttrJsonPass}}; - -bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(attrs_json); - auto cnode_name = AnfAlgo::GetCNodeName(anf_node); - auto FPass = build_json_attr_pass_map_.find(cnode_name); - if (FPass != build_json_attr_pass_map_.end()) { - FPass->second(anf_node, op_info_attrs, attrs_json); - return true; - } - return false; -} - -void TbeAdapter::MaximumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(attrs_json); - auto attr_num = op_info_attrs.size(); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - for (size_t i = 0; i < attr_num; i++) { - nlohmann::json attr_obj; - MS_EXCEPTION_IF_NULL(op_info_attrs[i]); - std::string attr_name = op_info_attrs[i]->name(); - auto value = primitive->GetAttr(attr_name); - if (value != nullptr) { - bool attr_value = GetValue(value); - attr_obj["value"] = attr_value; - attr_obj["valid"] = true; - } else { - attr_obj["valid"] = false; - } - attr_obj["name"] = attr_name; - attrs_json->push_back(attr_obj); - } - MS_LOG(INFO) << "MaximumGradAttrJsonPass done."; -} - -void TbeAdapter::MinimumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(attrs_json); - auto attr_num = op_info_attrs.size(); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - for (size_t i = 0; i < attr_num; i++) { - nlohmann::json attr_obj; - MS_EXCEPTION_IF_NULL(op_info_attrs[i]); - std::string attr_name = op_info_attrs[i]->name(); - auto value = primitive->GetAttr(attr_name); - if (value != nullptr) { - bool attr_value = GetValue(value); - attr_obj["value"] = attr_value; - attr_obj["valid"] = true; - } else { - attr_obj["valid"] = false; - } - attr_obj["name"] = attr_name; - attrs_json->push_back(attr_obj); - } - MS_LOG(INFO) << "MinimumGradAttrJsonPass done."; -} - -static int TypeStrToDstType(const std::string &type_str) { - int ret = -1; - if (type_str == "Float" || type_str == "Float32") { - ret = 0; - } else if (type_str == "Float16") { - ret = 1; - } else if (type_str == "Int8") { - ret = 2; - } else if (type_str == "Int32") { - ret = 3; - } else if (type_str == "UInt8") { - ret = 4; - } else if (type_str == "UInt64") { - ret = 10; - } else if (type_str == "Bool") { - ret = 12; - } else { - MS_LOG(INFO) << "Error type str is invailed: " << type_str; - } - return ret; -} - -void TbeAdapter::CastAttrJsonPass(const mindspore::AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(attrs_json); - if (op_info_attrs.size() != 1) { - MS_LOG(INFO) << "cast node should has dst_type attr"; - return; - } - auto attr_name = op_info_attrs[0]->name(); - auto type_ptr = std::make_shared(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, 0))); - MS_EXCEPTION_IF_NULL(type_ptr); - auto type_element = type_ptr->element(); - MS_EXCEPTION_IF_NULL(type_element); - auto dtype = type_element->ToString(); - auto dst_type_value = TypeStrToDstType(dtype); - nlohmann::json attr_obj; - attr_obj["value"] = dst_type_value; - attr_obj["valid"] = true; - attr_obj["name"] = attr_name; - attrs_json->push_back(attr_obj); - MS_LOG(INFO) << "CastAttrJsonPass done."; -} - -void TbeAdapter::GenTopKV2IndicesTensorInfo(const std::shared_ptr &anf_node, - size_t real_input_index, std::vector *input_list, - mindspore::kernel::kCreaterType creater_type) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_list); - auto input_x_shape = AnfAlgo::GetOutputInferShape(anf_node, 0); - size_t last_dim = input_x_shape[input_x_shape.size() - 1]; - std::vector tensor_shape = {last_dim}; - std::vector tensor_origin_shape = {last_dim}; - std::string tensor_format = AnfAlgo::GetInputFormat(anf_node, static_cast(real_input_index)); - if (tensor_format == kOpFormat_DEFAULT) { - tensor_format = kOpFormat_NCHW; - } - std::string tensor_origin_format = kOpFormat_NCHW; - std::string tensor_dtype = "float16"; - nlohmann::json input_desc_json; - input_desc_json["dtype"] = tensor_dtype; - input_desc_json["name"] = AnfAlgo::GetCNodeName(anf_node); - input_desc_json["ori_shape"] = tensor_origin_shape; - input_desc_json["ori_format"] = tensor_origin_format; - input_desc_json["shape"] = tensor_shape; - if (creater_type == OP_SELECT_FORMAT) { - input_desc_json["format"] = tensor_origin_format; - } else { - input_desc_json["format"] = tensor_format; - } - input_desc_json["valid"] = true; - input_list->emplace_back(input_desc_json); -} -} // namespace tbe -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.h b/mindspore/ccsrc/kernel/tbe/tbe_adapter.h deleted file mode 100644 index 51c4cfd777..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H - -#include -#include -#include -#include -#include "nlohmann/json.hpp" -#include "ir/base.h" -#include "kernel/oplib/opinfo.h" -// Note: This file is mainly used to adapt the ME front-end operator description and -// the TBE back-end operator implementation difference -namespace mindspore { -namespace kernel { -enum kCreaterType : int { SINGLE_BUILD = 0, PREBUILD, OP_SELECT_FORMAT, CHECK_SUPPORTED, OP_PRE_COMPILE }; -namespace tbe { -using FAttrsPass = void (*)(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); -class TbeAdapter { - public: - TbeAdapter() = default; - ~TbeAdapter() = default; - static void NormalizeFuncName(std::string *func_name); - static void SetTbeAttrsForTransDataOp(const AnfNodePtr &anf_node); - static void InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, - nlohmann::json *inputs_json); - static bool RunAttrPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); - static void GenTopKV2IndicesTensorInfo(const std::shared_ptr &anf_node, size_t real_input_index, - std::vector *input_list, kCreaterType creater_type); - - static void FusionInputOrderPass(const std::string &op_name, const std::vector &inputs_list, - std::vector *inputs_json); - static void FusionDataOrderPass(const std::string &op_name, const std::vector &data_layer, - std::vector *reorder_data_layer); - - private: - static void MaximumGradAttrJsonPass(const AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); - static void MinimumGradAttrJsonPass(const AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); - - static void CastAttrJsonPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); - - static std::map build_json_attr_pass_map_; -}; -} // namespace tbe -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc b/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc deleted file mode 100644 index 90c5557253..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_convert_utils.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -const std::unordered_map type_str_id_maps = { - {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, - {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, - {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, - {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, - {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, - {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, - {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, - {"bool", TypeId::kNumberTypeBool}, -}; - -const std::map type_id_str_maps = { - {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, - {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, - {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, - {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, - {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, - {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, - {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, - {TypeId::kNumberTypeBool, "int8"}, -}; - -const std::map type_str_maps = { - {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, - {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, - {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"}, -}; - -const std::unordered_map type_nbyte_maps = { - {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, - {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, - {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, - {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, -}; - -const std::unordered_map fusion_type_maps = { - {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, - {"SEGMENT", FusionType::SEGMENT}, {"DYNAMIC", FusionType::DYNAMIC}, {"OPAQUE", FusionType::OPAQUE}, -}; - -TypeId DtypeToTypeId(const std::string &dtypes) { - auto iter = type_str_id_maps.find(dtypes); - if (iter == type_str_id_maps.end()) { - MS_LOG(EXCEPTION) << "Illegal input device dtype: " << dtypes; - } - return iter->second; -} - -std::string TypeIdToString(TypeId type_id) { - auto iter = type_id_str_maps.find(type_id); - if (iter == type_id_str_maps.end()) { - MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id); - } - return iter->second; -} - -size_t GetDtypeNbyte(const std::string &dtypes) { - auto iter = type_nbyte_maps.find(dtypes); - if (iter == type_nbyte_maps.end()) { - MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes; - } - return iter->second; -} - -FusionType GetFusionType(const std::string &pattern) { - auto iter = fusion_type_maps.find(pattern); - if (iter == fusion_type_maps.end()) { - MS_LOG(INFO) << "Illegal fusion pattern: " << pattern; - return UNKNOWN_FUSION_TYPE; - } - return iter->second; -} - -std::string GetProcessor(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string device; - switch (AnfAlgo::GetProcessor(anf_node)) { - case Processor::AICORE: - device = kProcessorAiCore; - break; - default: - MS_LOG(INFO) << "Unknown processor type." << anf_node->fullname_with_scope(); - break; - } - return device; -} -} // namespace tbe -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.h b/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.h deleted file mode 100644 index 2c8d3008b9..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ - -#include -#include "kernel/kernel.h" -#include "ir/base.h" -#include "ir/dtype/type.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -constexpr auto kProcessorAiCore = "aicore"; -TypeId DtypeToTypeId(const std::string &dtypes); - -std::string TypeIdToString(TypeId type_id); - -size_t GetDtypeNbyte(const std::string &dtypes); - -FusionType GetFusionType(const std::string &pattern); - -std::string GetProcessor(const AnfNodePtr &anf_node); -} // namespace tbe -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc deleted file mode 100644 index 645a195f5e..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc +++ /dev/null @@ -1,1019 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_kernel_build.h" -#include -#include -#include -#include "operator/ops.h" -#include "parallel/ops_info/ops_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/tbe/tbe_adapter.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "kernel/tbe/tbe_convert_utils.h" -#include "kernel/tbe/tbe_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeAdapter; -using mindspore::kernel::tbe::TbeUtils; -constexpr auto kFusionOpList = "op_list"; -constexpr auto kFusionKernelNamePrfix = "te_fusion"; -constexpr auto kOptional = "optional_"; -constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z"; -constexpr auto kPlatform = "platform"; -constexpr auto kPlatTBE = "TBE"; -constexpr auto kGenModel = "gen_model"; -constexpr auto kSingle = "single"; -constexpr auto kImplPath = "impl_path"; -constexpr auto kJInputs = "inputs"; -constexpr auto kJOutputs = "outputs"; -constexpr auto kJAttrs = "attrs"; -constexpr auto kJKernelName = "kernel_name"; -constexpr auto kJOpInfo = "op_info"; -constexpr auto kJDtype = "dtype"; -constexpr auto kJtype = "type"; -constexpr auto kJName = "name"; -constexpr auto kJOriShape = "ori_shape"; -constexpr auto kJOriFormat = "ori_format"; -constexpr auto kJShape = "shape"; -constexpr auto kJFormat = "format"; -constexpr auto kJValid = "valid"; -constexpr auto kJParamType = "param_type"; -constexpr auto kParamDynamic = "dynamic"; -constexpr auto kParamRequred = "required"; -constexpr auto kJDataType = "data_type"; -constexpr auto kJOutputIndex = "output_index"; -constexpr auto kJOutputDesc = "output_desc"; -constexpr auto kJInputDesc = "input_desc"; -constexpr auto kVTypeInt = "int"; -constexpr auto kVTypeStr = "str"; -constexpr auto kVTypeBool = "bool"; -constexpr auto kVTypeFloat = "float"; -constexpr auto kVTypeListInt = "listInt"; -constexpr auto kVTypeInt32 = "Int32"; -constexpr auto kVTypeListUInt64 = "listUInt64"; -constexpr auto kVTypeListFloat = "listFloat"; -constexpr auto kVTypeListListInt = "listListInt"; -constexpr auto kJValue = "value"; -constexpr auto kJDynIndex = "dyn_index"; -constexpr auto kJFuncName = "func_name"; - -std::string NormalizeFullScopeName(const string &full_scope_name) { - // exp:Default/ReLU-op0 -->Default_ReLU_op0 - string normal_ret = full_scope_name; - std::replace(normal_ret.begin(), normal_ret.end(), '/', '_'); - std::replace(normal_ret.begin(), normal_ret.end(), '-', '_'); - return normal_ret; -} - -bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr &anf_node, - nlohmann::json *kernel_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(kernel_json); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); - MS_EXCEPTION_IF_NULL(op_info_ptr); - (*kernel_json)[kPlatform] = kPlatTBE; - (*kernel_json)[kGenModel] = kSingle; - (*kernel_json)[kImplPath] = op_info_ptr->impl_path(); - nlohmann::json op_info_json; - if (op_info_ptr->impl_path().empty()) { - tbe::TbeAdapter::NormalizeFuncName(&op_name); - } else { - op_name = op_info_ptr->kernel_name(); - } - op_info_json[kJName] = op_name; - // generate inputs json - nlohmann::json inputs_json; - if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) { - MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate inputs json failed"; - return false; - } - op_info_json[kJInputs] = inputs_json; - // generate outputs json - nlohmann::json outputs_json; - if (!GenTbeOutputsJson(anf_node, op_info_ptr, &outputs_json)) { - MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate outputs json failed"; - return false; - } - op_info_json[kJOutputs] = outputs_json; - // generate attrs json - nlohmann::json attrs_json; - (void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json); - op_info_json[kJAttrs] = attrs_json; - std::string json_str = op_info_json.dump(); - size_t hash_id = std::hash()(json_str); - json_name_ = op_name + "_" + std::to_string(hash_id); - json_info_ = json_str; - if (creater_type_ == PREBUILD) { - op_info_json[kJKernelName] = NormalizeFullScopeName(anf_node->fullname_with_scope()); - } else { - op_info_json[kJKernelName] = json_name_; - } - (*kernel_json)[kJOpInfo] = op_info_json; - if (creater_type_ == SINGLE_BUILD) { - TbeUtils::SaveJsonInfo(json_name_, json_info_); - } - - MS_LOG(INFO) << "Operate type:" << creater_type_ << ", full scope name is :" << anf_node->fullname_with_scope() - << ", json info name is : " << json_name_ << ", kernel json:" << kernel_json->dump(); - - return true; -} - -bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, - bool value, const std::shared_ptr &input_ptr, - const string &op_input_name, size_t input_i, - std::vector *input_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_ptr); - MS_EXCEPTION_IF_NULL(input_list); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) { - TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_); - } else { - auto dtype = GetDeviceInputType(anf_node, real_input_index); - auto format = GetDeviceInputFormat(anf_node, real_input_index); - auto shape = GetDeviceInputShape(anf_node, real_input_index); - auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); - if (ori_shape.empty()) { - ori_shape.emplace_back(1); - } - nlohmann::json input_desc_json; - input_desc_json[kJDtype] = dtype; - input_desc_json[kJName] = op_input_name + std::to_string(input_i); - input_desc_json[kJOriShape] = ori_shape; - input_desc_json[kJOriFormat] = kOpFormat_NCHW; - input_desc_json[kJShape] = shape; - input_desc_json[kJFormat] = format; - input_desc_json[kJValid] = value; - input_desc_json[kJParamType] = input_ptr->param_type(); - input_list->emplace_back(input_desc_json); - } - return true; -} - -bool TbeKernelJsonCreator::GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num, - const std::shared_ptr &input_ptr, size_t *real_input_index, - string *op_input_name, std::vector *input_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_ptr); - MS_EXCEPTION_IF_NULL(real_input_index); - MS_EXCEPTION_IF_NULL(op_input_name); - MS_EXCEPTION_IF_NULL(input_list); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node); - bool value = true; - for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { - if (*real_input_index >= real_input_num) { - if (input_ptr->param_type() == "optional") { - *op_input_name = input_ptr->name() + "_optional_"; - nlohmann::json input_desc_json; - input_desc_json[kJValid] = false; - input_desc_json[kJName] = *op_input_name + std::to_string(*real_input_index); - input_list->emplace_back(input_desc_json); - continue; - } - MS_LOG(ERROR) << "Input num: " << *real_input_index << " is not match op inputs"; - return false; - } - if (op_name == "BatchNorm") { - if (input_ptr->name() == "mean" || input_ptr->name() == "variance") { - auto attr = primitive->GetAttr("is_training"); - MS_EXCEPTION_IF_NULL(attr); - bool is_training = GetValue(attr); - MS_LOG(INFO) << "Op_name" << op_name << ", tensor_name " << input_ptr->name() << ", is_training " - << is_training; - if (is_training) { - (*real_input_index)++; - break; - } - } - } - bool ret = GenInputDescJson(anf_node, *real_input_index, value, input_ptr, *op_input_name, input_i, input_list); - (*real_input_index)++; - if (!ret) { - return false; - } - } - return true; -} - -bool GetInputNameAndRealNum(const std::shared_ptr &anf_node, const std::shared_ptr &input_ptr, - size_t *dyn_input_index, size_t *input_num, std::string *op_input_name) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_ptr); - MS_EXCEPTION_IF_NULL(dyn_input_index); - MS_EXCEPTION_IF_NULL(input_num); - MS_EXCEPTION_IF_NULL(op_input_name); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. - std::vector dyn_input_sizes; - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - if (input_ptr->param_type() == kParamDynamic) { - if (*dyn_input_index >= dyn_input_sizes.size()) { - MS_LOG(ERROR) << "Dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size(); - return false; - } - *input_num = IntToSize(dyn_input_sizes[*dyn_input_index]); - *op_input_name = input_ptr->name() + "_dynamic_"; - (*dyn_input_index)++; - // if optional input is exist - } else { - *input_num = 1; - *op_input_name = input_ptr->name() + "_"; - } - return true; -} - -bool TbeKernelJsonCreator::GenTbeInputsJson(const std::shared_ptr &anf_node, - const std::shared_ptr &op_info, nlohmann::json *inputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(inputs_json); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kAtomicAddrCleanOpName) { - return true; - } - std::vector> inputs_ptr = op_info->inputs_ptr(); - if (inputs_ptr.empty()) { - MS_LOG(INFO) << "Apply kernel " << op_name << "registration info has no input info"; - return true; - } - auto op_info_input_num = inputs_ptr.size(); - size_t dyn_input_index = 0; - size_t real_input_index = 0; - std::vector> inputs_list; - for (size_t i = 0; i < op_info_input_num; i++) { - size_t input_tensor_num; - std::shared_ptr input_ptr = inputs_ptr[i]; - std::string op_input_name; - MS_EXCEPTION_IF_NULL(input_ptr); - if (!GetInputNameAndRealNum(anf_node, input_ptr, &dyn_input_index, &input_tensor_num, &op_input_name)) { - return false; - } - std::vector input_list; - if (!GenInputList(anf_node, input_tensor_num, input_ptr, &real_input_index, &op_input_name, &input_list)) { - return false; - } - inputs_list.emplace_back(input_list); - } - - TbeAdapter::InputOrderPass(op_name, inputs_list, inputs_json); - return true; -} - -bool TbeKernelJsonCreator::GenTbeOutputsJson(const std::shared_ptr &anf_node, - const std::shared_ptr &op_info, nlohmann::json *outputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(outputs_json); - auto op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kAtomicAddrCleanOpName) { - return true; - } - auto outputs_ptr = op_info->outputs_ptr(); - return GenOutputDescJson(anf_node, outputs_ptr, outputs_json); -} - -bool TbeKernelJsonCreator::GenOutputDescJson( - const std::shared_ptr &anf_node, - const std::vector> &outputs_ptr, nlohmann::json *outputs_json) { - MS_EXCEPTION_IF_NULL(outputs_json); - size_t output_idx = 0; - auto op_name = AnfAlgo::GetCNodeName(anf_node); - size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node); - - for (const auto &output_ptr : outputs_ptr) { - size_t output_obj_num = 0; - if (output_ptr->param_type() == kParamRequred) { - output_obj_num = 1; - } else if (output_ptr->param_type() == kParamDynamic) { - if (outputs_ptr.size() > 1) { - MS_LOG(ERROR) << "Dynamic output is unsupported multi output!"; - return false; - } - output_obj_num = real_output_num; - } else { - if (output_idx >= real_output_num) { - MS_LOG(INFO) << "Op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none."; - std::vector output_list; - nlohmann::json output_obj; - output_obj[kJName] = output_ptr->name(); - output_obj[kJValid] = false; - output_list.emplace_back(output_obj); - (*outputs_json).push_back(output_list); - continue; - } else { - output_obj_num = 1; - } - } - std::vector output_list; - GenOutputList(anf_node, output_obj_num, output_ptr, &output_idx, &output_list); - (*outputs_json).push_back(output_list); - } - return true; -} - -void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, - const std::shared_ptr &output_ptr, size_t *output_idx, - std::vector *output_list) { - MS_EXCEPTION_IF_NULL(output_idx); - MS_EXCEPTION_IF_NULL(output_list); - for (size_t i = 0; i < output_obj_num; i++) { - auto dtype = GetDeviceOutputType(anf_node, *output_idx); - auto format = GetDeviceOutputFormat(anf_node, *output_idx); - auto shape = GetDeviceOutputShape(anf_node, *output_idx); - std::vector ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); - if (ori_shape.empty()) { - ori_shape.emplace_back(1); - } - nlohmann::json output_obj; - output_obj[kJDtype] = dtype; - output_obj[kJShape] = shape; - output_obj[kJFormat] = format; - output_obj[kJOriShape] = ori_shape; - output_obj[kJOriFormat] = kOpFormat_NCHW; - output_obj[kJName] = output_ptr->name(); - output_obj[kJValid] = true; - output_obj[kJParamType] = output_ptr->param_type(); - output_list->emplace_back(output_obj); - (*output_idx)++; - } -} - -bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr &anf_node, - const std::shared_ptr &op_info, nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(attrs_json); - auto attrs_ptr = op_info->attrs_ptr(); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (TbeAdapter::RunAttrPass(anf_node, attrs_ptr, attrs_json)) { - return true; - } - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - for (const auto &attr_ptr : attrs_ptr) { - std::string attr_name = attr_ptr->name(); - nlohmann::json attr_obj; - attr_obj[kJName] = attr_name; - if (op_name == parallel::LAYER_NORM && attr_obj[kJName] == "epsilon" && creater_type_ == OP_SELECT_FORMAT) { - continue; - } - if (primitive->GetAttr(attr_name) != nullptr) { - auto value = primitive->GetAttr(attr_name); - std::string type = attr_ptr->type(); - ParseAttrValue(type, value, &attr_obj); - attr_obj[kJValid] = true; - } else { - if (op_info->impl_path().empty()) { - attr_obj[kJValid] = false; - } else { - if (attr_ptr->param_type() == kParamRequred && creater_type_ == SINGLE_BUILD) { - MS_LOG(EXCEPTION) << "Op name: " << op_info->op_name() << " attr: " << attr_name - << " is required, but not set."; - } else { - attr_obj[kJValid] = false; - } - } - } - (*attrs_json).push_back(attr_obj); - } - return true; -} - -void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, - nlohmann::json *attr_obj) { - MS_EXCEPTION_IF_NULL(value); - MS_EXCEPTION_IF_NULL(attr_obj); - if (type == kVTypeInt) { - auto attr_value = GetValue(value); - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeStr) { - auto attr_value = GetValue(value); - if (attr_value == kOpFormat_FRAC_Z) { - attr_value = kOpFormat_FRACTAL_Z; - } - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeBool) { - auto attr_value = GetValue(value); - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeFloat) { - auto attr_value = GetValue(value); - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeListInt) { - std::vector attr_value; - auto value_type = value->type(); - MS_EXCEPTION_IF_NULL(value_type); - auto value_type_str = value_type->ToString(); - if (value_type_str == kVTypeInt32) { - int data = GetValue(value); - attr_value.push_back(data); - } else { - attr_value = GetValue>(value); - } - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeListFloat) { - std::vector attr_value; - auto value_type = value->type(); - MS_EXCEPTION_IF_NULL(value_type); - auto value_type_str = value_type->ToString(); - if (value_type_str == kVTypeFloat) { - auto data = GetValue(value); - attr_value.push_back(data); - } else { - attr_value = GetValue>(value); - } - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeListUInt64) { - auto attr_value = GetValue>(value); - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeListListInt) { - auto attr_value = GetValue>>(value); - (*attr_obj)[kJValue] = attr_value; - } else { - MS_LOG(EXCEPTION) << "Type: " << type << "not support"; - } -} - -std::vector TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector shape; - if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { - shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index); - } else { - shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index); - } - if (shape.empty()) { - shape.emplace_back(1); - } - return shape; -} - -std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - TypeId type_id; - if (creater_type_ == OP_SELECT_FORMAT) { - type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index); - } else { - type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index); - } - return tbe::TypeIdToString(type_id); -} - -std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - std::string format = kOpFormat_NCHW; - if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { - format = AnfAlgo::GetInputFormat(anf_node, real_index); - if (format == kOpFormat_FRAC_Z) { - format = kOpFormat_FRACTAL_Z; - } else if (format == kOpFormat_DEFAULT) { - format = kOpFormat_NCHW; - } - } - return format; -} - -std::vector TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector shape; - if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { - shape = AnfAlgo::GetOutputInferShape(anf_node, real_index); - } else { - shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index); - } - if (shape.empty()) { - shape.emplace_back(1); - } - return shape; -} - -std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - TypeId type_id; - if (creater_type_ == OP_SELECT_FORMAT) { - type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index); - } else { - type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index); - } - return tbe::TypeIdToString(type_id); -} - -std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - std::string format = kOpFormat_NCHW; - if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { - format = AnfAlgo::GetOutputFormat(anf_node, real_index); - if (format == kOpFormat_FRAC_Z) { - format = kOpFormat_FRACTAL_Z; - } else if (format == kOpFormat_DEFAULT) { - format = kOpFormat_NCHW; - } - } - return format; -} - -bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, - std::vector *output_size_list) { - if (input_size_list == nullptr || output_size_list == nullptr) { - MS_LOG(ERROR) << "Input size or output size is nullptr"; - return false; - } - input_size_list->clear(); - output_size_list->clear(); - for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) { - for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) { - size_t size_i = 1; - if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) { - std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName]; - MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false."; - continue; - } - for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) { - size_i *= static_cast(j); - } - std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype]; - size_t nbyte = tbe::GetDtypeNbyte(dtype); - size_i *= nbyte; - input_size_list->push_back(size_i); - } - } - for (size_t i = 0; i < kernel_json[kJOpInfo][kJOutputs].size(); i++) { - for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) { - size_t size_i = 1; - if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) { - std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName]; - MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false."; - continue; - } - for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) { - size_i *= static_cast(j); - } - std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype]; - size_t nbyte = tbe::GetDtypeNbyte(dtype); - size_i *= nbyte; - output_size_list->push_back(size_i); - } - } - return true; -} - -bool TbeKernelBuild::GenFusionScopeJson(const std::vector &input_nodes, - const std::vector &compute_nodes, - nlohmann::json *fusion_str, std::string *fusion_kernel) { - MS_EXCEPTION_IF_NULL(fusion_str); - MS_EXCEPTION_IF_NULL(fusion_kernel); - // get input layer info - std::vector> input_layers; - std::map spec_data_input; - if (!GetInputLayers(input_nodes, compute_nodes, &input_layers, &spec_data_input)) { - return false; - } - // gen fusion scopre_op jsom - std::vector compute_list; - (*fusion_kernel) = kFusionKernelNamePrfix; - // index: fusion build option input record, next one from 0 - static size_t index = 0; - auto layer_iter = input_layers.begin(); - auto compute_op_iter = compute_nodes.begin(); - for (; compute_op_iter != compute_nodes.end(); ++compute_op_iter, ++layer_iter) { - nlohmann::json compute_op_str; - (void)GenFusionComputeJson(*compute_op_iter, &layer_iter, &compute_op_str, fusion_kernel, &index); - compute_list.push_back(compute_op_str); - } - index = 0; - // gen data input json - std::vector data_list; - for (const auto &layer : input_layers) { - for (const auto &data_input : layer) { - nlohmann::json data_str; - if (!GenFusionDataInputJson(data_input, spec_data_input, &data_str, &index)) { - MS_LOG(INFO) << "Fusion error: gen fusion datainput json faild."; - return false; - } - data_list.push_back(data_str); - } - } - index = 0; - data_list.insert(data_list.end(), compute_list.begin(), compute_list.end()); - (*fusion_str)[kFusionOpList] = data_list; - return true; -} - -void TbeKernelBuild::GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, - size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) { - std::string output_desc_name = anf_node->fullname_with_scope(); - if (node_out_idx > 0) { - output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx); - } - (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name); - auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx); - (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id); - auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx); - if (ori_shape.empty()) { - ori_shape.emplace_back(1); - } - (*output_desc)[kJOriShape] = ori_shape; - auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx); - if (shape.empty()) { - shape.emplace_back(1); - } - (*output_desc)[kJShape] = shape; - auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx); - if (format == kOpFormat_DEFAULT) { - format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND; - } - (*output_desc)[kJFormat] = format; - (*output_desc)[kJOriFormat] = kOpFormat_NCHW; - (*output_desc)[kJOutputIndex] = desc_output_idx; - if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) { - std::vector spec_shape = {}; - spec_shape.emplace_back(shape[0]); - spec_shape.emplace_back(shape[1]); - spec_shape.emplace_back(shape[2] * shape[3]); - spec_shape.emplace_back(shape[4]); - (*output_desc)[kJShape] = spec_shape; - } else if (fusion_data_type == kFusionReLUGradV2) { - std::vector spec_shape = {}; - spec_shape.emplace_back(shape[0]); - spec_shape.emplace_back(shape[1]); - spec_shape.emplace_back(shape[2] * shape[3]); - spec_shape.emplace_back(16); - (*output_desc)[kJShape] = spec_shape; - (*output_desc)[kJDataType] = kVTypeBool; - } -} - -void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, - size_t output_index, nlohmann::json *output_desc) { - std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index); - (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name); - (*output_desc)[kJOutputIndex] = output_index; - std::vector shape; - (*output_desc)[kJShape] = shape; -} - -bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name, - const std::vector &reorder_layer, - std::map *spec_data_input) { - if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) { - MS_LOG(INFO) << "Fusion error: node(" << op_name << " )'s input is null. "; - return false; - } - MS_LOG(INFO) << "Fusion info: op_name: " << op_name << "input layer size: " << reorder_layer.size(); - if (op_name == kReluGradV2OpName) { - (*spec_data_input)[reorder_layer[0]] = kFusionReLUGradV2; - } else if (op_name == kAddNOpName) { - for (const auto &it : reorder_layer) { - (*spec_data_input)[it] = kFusionAddN; - } - } - return true; -} - -bool TbeKernelBuild::GetInputLayers(const std::vector &input_nodes, - const std::vector &compute_nodes, - std::vector> *input_layers, - std::map *spec_data_input) { - MS_EXCEPTION_IF_NULL(input_layers); - MS_EXCEPTION_IF_NULL(spec_data_input); - auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) { - auto op_name = AnfAlgo::GetCNodeName(it); - return op_name == kConv2DBackpropInputOpName; - }); - bool need_spec = (result != compute_nodes.end()); - size_t input_size = 0; - for (const auto &compute_node : compute_nodes) { - std::vector layer = {}; - std::vector reorder_layer = {}; - MS_EXCEPTION_IF_NULL(compute_node); - auto op_name = AnfAlgo::GetCNodeName(compute_node); - auto ccompute_node = compute_node->cast(); - if (ccompute_node == nullptr) { - MS_LOG(INFO) << "Fusion error: fusion compute node must be cnode"; - return false; - } - MS_LOG(INFO) << "Fusion info: compute name: " << compute_node->fullname_with_scope(); - for (size_t i = 1; i < ccompute_node->inputs().size(); ++i) { - auto input = ccompute_node->input(i); - auto find_iter = std::find(input_nodes.begin(), input_nodes.end(), input); - if (find_iter != input_nodes.end()) { - MS_LOG(INFO) << "Fusion info: add compute node's [" << i << "] input: " << input->fullname_with_scope(); - layer.emplace_back((*find_iter)); - } else { - MS_LOG(INFO) << "Fusion warnig: this input [" << i << "] may be pre compute(" << input->fullname_with_scope() - << ") node's output."; - } - } - TbeAdapter::FusionDataOrderPass(op_name, layer, &reorder_layer); - if (need_spec) { - MS_LOG(INFO) << "Fusion info: match conv2d backprop input + ... patten."; - if (!GetSpecInputLayers(op_name, reorder_layer, spec_data_input)) { - return false; - } - } - input_size += reorder_layer.size(); - input_layers->emplace_back(reorder_layer); - } - if (input_nodes.size() != input_size) { - MS_LOG(INFO) << "Fusion error: fusion scope error, layer input:" << input_size - << ", input_node:" << input_nodes.size(); - return false; - } - return true; -} - -bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr &data_input, - const std::map &spec_data_input, - nlohmann::json *data_str, size_t *index) { - MS_EXCEPTION_IF_NULL(data_str); - MS_EXCEPTION_IF_NULL(index); - std::vector output_desc_list; - if (!data_input) { - MS_LOG(INFO) << "Data input is optional node"; - auto name = std::string(kOptional) + std::to_string(*index); - (*data_str)[kJName] = name; - nlohmann::json output_desc; - output_desc[kJName] = name; - output_desc[kJShape] = "NULL"; - output_desc_list.push_back(output_desc); - (*index)++; - } else { - FusionDataType fusion_data_type = kFusionNormal; - if (spec_data_input.find(data_input) != spec_data_input.end()) { - fusion_data_type = spec_data_input.at(data_input); - } - auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0); - auto real_node = kernel_idx.first; - size_t real_idx = kernel_idx.second; - MS_LOG(INFO) << "Real name " << real_node->fullname_with_scope() << " index:" << real_idx; - // kJOutputDesc - nlohmann::json output_desc; - GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type); - output_desc_list.push_back(output_desc); - (*data_str)[kJName] = NormalizeFullScopeName(real_node->fullname_with_scope()); - } - (*data_str)[kJOutputDesc] = output_desc_list; - (*data_str)[kJtype] = "Data"; - return true; -} - -bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. - bool ret = false; - std::vector dyn_input_sizes; - auto dynamic_input_attr = primitive->GetAttr(kAttrDynInputSizes); - if (dynamic_input_attr != nullptr) { - dyn_input_sizes = GetValue>(dynamic_input_attr); - auto real_input_size = cnode->inputs().size() - 1; - auto dyn_input_size = dyn_input_sizes.size(); - if (dyn_input_size != 1) { - MS_LOG(INFO) << "Fusion error: fusion build not support dyn_input_sizes > 1"; - return ret; - } - if (IntToSize(dyn_input_sizes[0]) != real_input_size) { - MS_LOG(INFO) << "Fusion error: dyn_input_size" << dyn_input_sizes[0] << "not equal real_input_size" - << real_input_size; - return ret; - } - ret = true; - } - return ret; -} - -size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool is_dynamic_input) { - MS_EXCEPTION_IF_NULL(cnode); - if (is_dynamic_input) { - return 0; - } - MS_EXCEPTION_IF_NULL(cnode); - auto node_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = OpLib::FindOp(node_name, kTBE); - MS_EXCEPTION_IF_NULL(cnode); - if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) { - MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope(); - } - return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size()); -} - -std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) { - static std::map buffer_fussion_op_map = { - {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}}; - string result = origin_type; - auto iter = buffer_fussion_op_map.find(origin_type); - if (iter != buffer_fussion_op_map.end()) { - result = iter->second; - } - return result; -} - -bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, - std::vector>::iterator *layer_iter, - std::vector *input_desc_list, size_t *index) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(input_desc_list); - std::vector input_desc_list_tmp = {}; - bool is_dynamic_input = IsDynamicInput(cnode); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - auto input = cnode->input(i); - auto kernel_idx = AnfAlgo::VisitKernel(input, 0); - auto real_node = kernel_idx.first; - size_t real_idx = kernel_idx.second; - MS_LOG(INFO) << "Real name" << real_node->fullname_with_scope() << "index:" << real_idx; - nlohmann::json input_desc; - GenDescJson(real_node, real_idx, real_idx, &input_desc); - if (is_dynamic_input) { - MS_LOG(INFO) << "Node has dynamic input."; - input_desc[kJDynIndex] = (i - 1); - } - input_desc_list_tmp.emplace_back(input_desc); - } - size_t optional_num = GetOptionalInput(cnode, is_dynamic_input); - if (optional_num > 0) { - MS_LOG(INFO) << "Node has optional input."; - for (size_t i = 0; i < optional_num; ++i) { - nlohmann::json optional_input_desc; - optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index); - (*index)++; - (*layer_iter)->emplace_back(nullptr); - input_desc_list_tmp.emplace_back(optional_input_desc); - } - } - auto op_name = AnfAlgo::GetCNodeName(cnode); - TbeAdapter::FusionInputOrderPass(op_name, input_desc_list_tmp, input_desc_list); - return true; -} - -std::vector TbeKernelBuild::GetDescOutputIndex(const std::vector &output_used_nums) { - std::vector desc_output_index = {}; - for (size_t idx = 0; idx < output_used_nums.size(); ++idx) { - auto output_use_num_item = output_used_nums[idx]; - MS_LOG(INFO) << "Output used num[" << idx << "] = " << output_use_num_item; - desc_output_index.emplace_back(idx); - if (output_use_num_item > 1) { - desc_output_index.emplace_back(idx); - } - } - return desc_output_index; -} - -bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, - std::vector *output_desc_list) { - MS_EXCEPTION_IF_NULL(output_desc_list); - auto output_size = AnfAlgo::GetOutputTensorNum(cnode); - if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { - auto output_used_nums = AnfAlgo::GetNodeAttr>(cnode, kAttrOutputUsedNum); - MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope(); - if (output_used_nums.size() != output_size) { - MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")" - << " is not match output used num(" << output_used_nums.size() << ")"; - return false; - } - auto desc_output_index = GetDescOutputIndex(output_used_nums); - for (size_t i = 0; i < output_size; ++i) { - MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i]; - nlohmann::json output_desc; - GenDescJson(cnode, i, desc_output_index[i], &output_desc); - output_desc_list->emplace_back(output_desc); - } - for (size_t j = output_size; j < desc_output_index.size(); ++j) { - MS_LOG(INFO) << "Fusion index: " << j << ", desc_output_index: " << desc_output_index[j]; - nlohmann::json output_desc; - GenReusedOutputDesc(cnode, j, desc_output_index[j], &output_desc); - output_desc_list->emplace_back(output_desc); - } - } else { - for (size_t i = 0; i < output_size; ++i) { - nlohmann::json output_desc; - GenDescJson(cnode, i, i, &output_desc); - output_desc_list->push_back(output_desc); - } - } - return true; -} - -bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, - std::vector>::iterator *layer_iter, - nlohmann::json *compute_op_str, std::string *fusion_kernel_name, - size_t *index) { - MS_EXCEPTION_IF_NULL(compute_node); - auto cnode = compute_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // gen input desc - std::vector input_desc_list; - (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index); - (*compute_op_str)[kJInputDesc] = input_desc_list; - // gen output desc - std::vector output_desc_list; - if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) { - MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node full name: " << cnode->fullname_with_scope(); - return false; - } - (*compute_op_str)[kJOutputDesc] = output_desc_list; - // gen others - auto origin_type = AnfAlgo::GetCNodeName(cnode); - // replace special op type for buffer fusion op - auto type = GetRealOpType(origin_type); - (*compute_op_str)[kJtype] = type; - tbe::TbeAdapter::NormalizeFuncName(&type); - (*compute_op_str)[kJFuncName] = type; - (*compute_op_str)[kJName] = NormalizeFullScopeName(cnode->fullname_with_scope()); - (void)(*fusion_kernel_name).append("_"); - (void)(*fusion_kernel_name).append(type); - return true; -} - -size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) { - size_t ret = 1; - for (const auto &shape_item : desc[kJShape]) { - ret *= static_cast(shape_item); - } - std::string data_type = desc[kJDataType]; - size_t nbyte = tbe::GetDtypeNbyte(data_type); - ret *= nbyte; - return ret; -} - -bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, - const std::vector &output_nodes, - std::vector *input_size_list, std::vector *output_size_list) { - MS_EXCEPTION_IF_NULL(input_size_list); - MS_EXCEPTION_IF_NULL(output_size_list); - input_size_list->clear(); - output_size_list->clear(); - - for (const auto &op : fusion_op_list) { - if (op[kJtype] == "Data") { - const auto &data_output_desc = op[kJOutputDesc]; - for (const auto &data_output : data_output_desc) { - if (data_output[kJShape] == "NULL") { - break; - } - auto ret = GetIOSizeImpl(data_output); - input_size_list->push_back(ret); - MS_LOG(INFO) << "Fusion info: scope input name: " << op[kJName] << ", size: " << ret; - } - } - } - - for (const auto &output_node : output_nodes) { - auto kernel_idx = AnfAlgo::VisitKernel(output_node, 0); - auto real_node = kernel_idx.first; - size_t real_idx = kernel_idx.second; - auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope()); - MS_LOG(INFO) << "Fusion info: real node name: " << normal_name << ", real output index: " << real_idx; - for (const auto &op : fusion_op_list) { - if (op[kJName] == normal_name) { - auto op_output_desces = op[kJOutputDesc]; - if (output_node != real_node) { - // tuple_get item - MS_LOG(INFO) << "Output is a tuple getitem node"; - auto output_desc = op_output_desces[real_idx]; - if (output_desc[kJShape].empty()) { - MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx; - return false; - } - auto ret = GetIOSizeImpl(output_desc); - output_size_list->push_back(ret); - MS_LOG(INFO) << "Fusion info: scope output index: " << real_idx << ", size: " << ret; - } else { - for (const auto &output_desc : op_output_desces) { - if (output_desc[kJShape].empty()) { - MS_LOG(INFO) << "Fusion info: output_desc's shape is empty, may be this node output"; - continue; - } - auto ret = GetIOSizeImpl(output_desc); - output_size_list->push_back(ret); - MS_LOG(INFO) << "Fusion info: scope output size: " << ret; - } - } - } - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h deleted file mode 100644 index eef02efa87..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "ir/dtype.h" -#include "kernel/kernel.h" -#include "pybind11/stl.h" -#include "kernel/oplib/oplib.h" -#include "kernel/tbe/tbe_adapter.h" - -namespace mindspore { -namespace kernel { -// kernel operate type used for generate json - -class TbeKernelBuild { - enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2 }; - - public: - static bool GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, - std::vector *output_size_list); - // Ub Fuison - static bool GenFusionScopeJson(const std::vector &input_nodes, - const std::vector &compute_nodes, nlohmann::json *fusion_str, - std::string *fusion_kernel); - static bool GetIOSize(const nlohmann::json &fusion_op_list, const std::vector &output_nodes, - std::vector *input_size_list, std::vector *output_size_list); - - private: - TbeKernelBuild() = default; - ~TbeKernelBuild() = default; - static bool GenFusionDataInputJson(const std::shared_ptr &data_input, - const std::map &spec_data_input, - nlohmann::json *data_str, size_t *index); - static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, - std::vector>::iterator *layer_iter, - nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index); - static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, - std::vector>::iterator *layer_iter, - std::vector *input_desc_list, size_t *index); - static std::vector GetDescOutputIndex(const std::vector &output_used_nums); - static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, - std::vector *output_desc_list); - static void GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, - size_t desc_output_idx, nlohmann::json *output_desc, - FusionDataType fusion_data_type = kFusionNormal); - static void GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, - size_t output_index, nlohmann::json *output_desc); - static size_t GetIOSizeImpl(const nlohmann::json &desc); - static bool GetSpecInputLayers(const std::string &op_name, const std::vector &reorder_layer, - std::map *spec_data_input); - static bool GetInputLayers(const std::vector &input_nodes, - const std::vector &compute_nodes, - std::vector> *input_layers, - std::map *spec_data_input); - static bool IsDynamicInput(const CNodePtr &cnode); - static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input); - static std::string GetRealOpType(const std::string &origin_type); -}; - -class TbeKernelJsonCreator { - public: - explicit TbeKernelJsonCreator(kCreaterType creater_type = SINGLE_BUILD) : creater_type_(creater_type) {} - ~TbeKernelJsonCreator() = default; - bool GenTbeSingleKernelJson(const std::shared_ptr &anf_node, nlohmann::json *kernel_json); - std::string json_name() { return json_name_; } - - private: - bool GenTbeInputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, - nlohmann::json *inputs_json); - bool GenTbeOutputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, - nlohmann::json *outputs_json); - bool GenTbeAttrJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, - nlohmann::json *attrs_json); - static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj); - bool GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, bool value, - const std::shared_ptr &input_ptr, const string &op_input_name, size_t input_i, - std::vector *input_list); - bool GenOutputDescJson(const std::shared_ptr &anf_node, - const std::vector> &outputs_ptr, nlohmann::json *outputs_json); - bool GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num, - const std::shared_ptr &input_ptr, size_t *real_input_index, string *op_input_name, - std::vector *input_list); - void GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, - const std::shared_ptr &output_ptr, size_t *output_idx, - std::vector *output_list); - std::vector GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const; - std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const; - std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const; - std::vector GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const; - std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const; - std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const; - - kCreaterType creater_type_; - std::string json_name_; - std::string json_info_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.cc deleted file mode 100644 index 0f377940da..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.cc +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_kernel_mod.h" -#include -#include "runtime/rt.h" -#include "nlohmann/json.hpp" -#include "graphengine/inc/framework/ge_runtime/task_info.h" - -namespace mindspore { -namespace kernel { -using TbeTaskInfoPtr = std::shared_ptr; -using tbe::KernelManager; -bool TbeKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (stream_ptr == nullptr) { - MS_LOG(ERROR) << "stream_ptr should not be nullptr."; - return false; - } - - if (kernel_pack_ == nullptr) { - MS_LOG(ERROR) << "kernel pack should not be nullptr."; - return false; - } - - uint32_t blockdim = 1; // default blockdim equal to 1. - auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &blockdim); - if (func_stub == 0) { - MS_LOG(ERROR) << "GenFuncStub failed."; - return false; - } - - // pack all addresses into a vector. - std::vector runtimeargs; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs), - [](const AddressPtr &output) -> void * { return output->addr; }); - if (!workspace.empty()) { - (void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs), - [](const AddressPtr &addr) -> void * { return addr->addr; }); - } - rtL2Ctrl_t *l2ctrl = nullptr; - const void *stubFunc = reinterpret_cast(func_stub); - auto argsSize = static_cast(UlongToUint(sizeof(void *)) * runtimeargs.size()); - if (RT_ERROR_NONE != rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_ptr)) { - MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; - return false; - } - - return true; -} - -std::vector TbeKernelMod::GenTask(const std::vector &inputs, - const std::vector &workspaces, - const std::vector &outputs, uint32_t stream_id) { - if (kernel_pack_ == nullptr) { - MS_EXCEPTION(ArgumentError) << "kernel pack should not be nullptr."; - } - - std::vector args; - std::vector sm_desc; - std::vector meta_data; - std::vector input_data_addrs; - std::vector output_data_addrs; - std::vector workspace_addrs; - - // pack all addresses into a vector. - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), - [](const AddressPtr &output) -> void * { return output->addr; }); - if (!workspaces.empty()) { - (void)std::transform(std::begin(workspaces), std::end(workspaces), std::back_inserter(workspace_addrs), - [](const AddressPtr &workspace) -> void * { return workspace->addr; }); - } - - stream_id_ = stream_id; - auto funcstub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim_); - if (funcstub == 0) { - MS_EXCEPTION(ArgumentError) << "GenFuncStub failed."; - } - - std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); - - MS_LOG(INFO) << "block_dim is:" << block_dim_; - - TbeTaskInfoPtr task_info_ptr = - make_shared(stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0, - meta_data, input_data_addrs, output_data_addrs, workspace_addrs); - return {task_info_ptr}; -} - -vector TbeKernelMod::GenParameters() { - auto kernel_json_info = kernel_pack_->kernel_json_info(); - return kernel_json_info.parameters; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.h deleted file mode 100644 index e0e7ab4646..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ - -#include -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/tbe/tbe_utils.h" - -namespace mindspore { -namespace kernel { -class TbeKernelMod : public AscendKernelMod { - public: - explicit TbeKernelMod(KernelPackPtr kernel_pack) : kernel_pack_(std::move(kernel_pack)) {} - ~TbeKernelMod() override = default; - - void SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } - void SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } - void SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspaces, - const std::vector &outputs, uint32_t stream_id) override; - std::vector GenParameters() override; - - private: - KernelPackPtr kernel_pack_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -using TbeKernelModPtr = std::shared_ptr; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc deleted file mode 100644 index 43d492f397..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc +++ /dev/null @@ -1,326 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_kernel_parallel_build.h" - -#include -#include -#include -#include -#include -#include - -#include "utils/context/ms_context.h" -#include "kernel/tbe/tbe_adapter.h" -#include "kernel/tbe/tbe_kernel_build.h" -#include "kernel/tbe/tbe_kernel_mod.h" -#include "session/anf_runtime_algorithm.h" -#include "./common.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "kernel/tbe/tbe_convert_utils.h" -#include "kernel/tbe/tbe_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeUtils; -constexpr auto kParallelCompileModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; -constexpr auto kCreateParallelCompiler = "create_tbe_parallel_compiler"; -constexpr auto kStartCompileOp = "start_compile_op"; -constexpr auto kWaitOne = "wait_one"; -constexpr auto kResetTaskInfo = "reset_task_info"; - -bool TbeOpParallelPreBuild(const std::vector &anf_nodes) { - auto build_manger = std::make_shared(); - MS_EXCEPTION_IF_NULL(build_manger); - for (const auto &anf_node : anf_nodes) { - // gen kernel json - MS_EXCEPTION_IF_NULL(anf_node); - nlohmann::json kernel_json; - TbeKernelJsonCreator creator(OP_PRE_COMPILE); - if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { - MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; - return false; - } - kernel_json["compile_type"] = "pre_build"; - // op build - auto task_id = build_manger->StartCompileOp(kernel_json); - build_manger->SavePreTaskInfo(task_id, anf_node); - } - while (!build_manger->IsAllPreTaskFinish()) { - int task_id = -1; - char *task_result = nullptr; - char *pre_build_result = nullptr; - auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); - if (!ret) { - MS_EXCEPTION(ArgumentError) << "Pre Build Failed. wait one ret:" << ret << ", task id:" << task_id; - } - - if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { - MS_EXCEPTION(ArgumentError) << "task pre compile Failed, task id:" << task_id << ", cause:" << task_result; - } - - build_manger->PreTaskFinishProcess(task_id, pre_build_result); - } - return true; -} - -bool TbeOpParallelBuild(const std::vector &anf_nodes) { - auto build_manger = std::make_shared(); - MS_EXCEPTION_IF_NULL(build_manger); - set processed_kernel; - for (const auto &anf_node : anf_nodes) { - // gen kernel json - tbe::TbeAdapter::SetTbeAttrsForTransDataOp(anf_node); - if (AnfAlgo::GetKernelMod(anf_node) != nullptr) { - continue; - } - const std::string &processor = tbe::GetProcessor(anf_node); - nlohmann::json kernel_json; - TbeKernelJsonCreator creator(SINGLE_BUILD); - if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { - MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; - return false; - } - // get size - std::vector input_size_list; - std::vector output_size_list; - (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); - // search cache - const std::string &json_name = creator.json_name(); - if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get())) { - MS_LOG(INFO) << "Use cached kernel, kernel json name:." << json_name; - continue; - } - // same op not need build, but need wait build finish to set kernel mode - if (processed_kernel.find(json_name) != processed_kernel.end()) { - build_manger->SaveSameOpInfo(anf_node, json_name, input_size_list, output_size_list); - continue; - } - (void)processed_kernel.insert(json_name); - // op build - auto task_id = build_manger->StartCompileOp(kernel_json); - build_manger->SaveTaskInfo(task_id, anf_node, json_name, input_size_list, output_size_list); - } - while (!build_manger->IsAllTaskFinish()) { - int task_id = -1; - char *task_result = nullptr; - char *pre_build_result = nullptr; - auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); - if (!ret) { - MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; - } - - if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { - MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; - } - (void)build_manger->TaskFinishProcess(task_id); - } - return build_manger->GenSameOpKernelMod(); -} - -ParallelBuildManager::ParallelBuildManager() { tbe_parallel_compiler_ = TbePythonFuncs::TbeParallelCompiler(); } - -ParallelBuildManager::~ParallelBuildManager() { ResetTaskInfo(); } - -int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) const { - PyObject *pRes = nullptr; - PyObject *pArgs = PyTuple_New(1); - std::string json_str = kernel_json.dump(); - PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); - (void)PyTuple_SetItem(pArgs, 0, arg1); - pRes = PyObject_CallMethod(tbe_parallel_compiler_, kStartCompileOp, "O", pArgs); - if (pRes == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function start_compile_op"; - } - int task_id; - (void)PyArg_Parse(pRes, "i", &task_id); - MS_LOG(INFO) << "start compile , task id:" << task_id; - return task_id; -} - -bool ParallelBuildManager::WaitOne(int *task_id, char **task_result, char **pre_build_result) const { - MS_LOG(INFO) << "wait task start."; - MS_EXCEPTION_IF_NULL(task_id); - MS_EXCEPTION_IF_NULL(task_result); - PyObject *pRes = nullptr; - PyObject *pArg = Py_BuildValue("()"); - pRes = PyObject_CallMethod(tbe_parallel_compiler_, kWaitOne, "O", pArg); - if (pRes == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function wait_one"; - return false; - } - (void)PyArg_ParseTuple(pRes, "iss", task_id, task_result, pre_build_result); - return true; -} - -void ParallelBuildManager::SavePreTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node) { - MS_LOG(INFO) << "SavePreTaskInfo, task id: " << task_id; - pre_task_map_[task_id] = anf_node; -} - -void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node, - const std::string &json_name, const std::vector &input_size_list, - const std::vector &output_size_list, int32_t scope_id) { - MS_LOG(INFO) << "SaveTaskInfo, task id: " << task_id; - struct KernelBuildTaskInfo task_info; - task_info.node = anf_node.get(); - task_info.json_name = json_name; - if (anf_node == nullptr) { - task_info.processor = tbe::kProcessorAiCore; - } else { - task_info.processor = tbe::GetProcessor(anf_node); - } - task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end()); - task_info.output_size_list.assign(output_size_list.begin(), output_size_list.end()); - task_info.scope_id = scope_id; - task_map_[task_id] = task_info; -} - -bool ParallelBuildManager::IsAllPreTaskFinish() const { - MS_LOG(INFO) << "wait pre build process task_num: " << pre_task_map_.size(); - return pre_task_map_.empty(); -} - -bool ParallelBuildManager::IsAllTaskFinish() const { - MS_LOG(INFO) << "wait process task_num: " << task_map_.size(); - return task_map_.empty(); -} - -void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result) { - auto task_iter = pre_task_map_.find(task_id); - if (task_iter == pre_task_map_.end()) { - MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id; - } - auto node = task_iter->second; - auto builder = - std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); - std::string start_flag = "fusion_pattern_start"; - std::string end_flag = "fusion_pattern_end"; - int start = pre_build_result.find(start_flag); - int end = pre_build_result.find(end_flag); - if (start != -1 && end != -1 && end >= start) { - std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size()); - if (result == "") { - (void)pre_task_map_.erase(task_iter); - return; - } - transform(result.begin(), result.end(), result.begin(), ::toupper); - FusionType fusion_type = tbe::GetFusionType(result); - builder->SetFusionType(fusion_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); - } - (void)pre_task_map_.erase(task_iter); -} - -std::pair ParallelBuildManager::TaskFinishProcess(int32_t task_id, bool set_kernel_mod) { - auto task_iter = task_map_.find(task_id); - if (task_iter == task_map_.end()) { - MS_EXCEPTION(ArgumentError) << "can find task_id:" << task_id; - } - auto json_name = task_iter->second.json_name; - auto processor = task_iter->second.processor; - auto kernel_pack = TbeUtils::InsertCache(json_name, processor); - if (kernel_pack == nullptr) { - if (set_kernel_mod) { - MS_EXCEPTION(ArgumentError) << "build kernel name:" << task_iter->second.json_name << " failed."; - } else { - MS_LOG(INFO) << "fusion build kernel name:" << task_iter->second.json_name << "failed."; - auto ret = std::make_pair(task_iter->second.scope_id, nullptr); - (void)task_map_.erase(task_iter); - return ret; - } - } - auto kernel_mod = GenKernelMod(json_name, processor, task_iter->second.input_size_list, - task_iter->second.output_size_list, kernel_pack); - MS_EXCEPTION_IF_NULL(kernel_mod); - if (set_kernel_mod) { - AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node); - } - auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod); - (void)task_map_.erase(task_iter); - MS_LOG(INFO) << "wait process remain task_num:" << task_map_.size(); - return ret; -} - -void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node, const std::string &json_name, - const std::vector &input_size_list, - const std::vector &output_size_list) { - struct KernelBuildTaskInfo task_info; - task_info.node = anf_node.get(); - task_info.json_name = json_name; - task_info.processor = tbe::GetProcessor(anf_node); - task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end()); - task_info.output_size_list.assign(output_size_list.begin(), output_size_list.end()); - same_op_list_.push_back(task_info); -} - -bool ParallelBuildManager::GenSameOpKernelMod() const { - for (const auto &task_info : same_op_list_) { - bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list, - task_info.output_size_list, task_info.node); - if (!ret) { - MS_LOG(INFO) << "can't find " << task_info.json_name << " in cache."; - return false; - } - } - return true; -} - -bool ParallelBuildManager::SearchInCache(const std::string &json_name, const std::string &processor, - const std::vector &input_size_list, - const std::vector &output_size_list, mindspore::AnfNode *node) const { - auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor); - if (cached_kernel_pack != nullptr) { - MS_LOG(INFO) << "Find cached kernel, kernel json name" << json_name; - auto kernel_mod_ptr = GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - AnfAlgo::SetKernelMod(kernel_mod_ptr, node); - return true; - } else { - return false; - } -} - -KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const string &processor, - const vector &input_size_list, - const vector &output_size_list, - const mindspore::kernel::KernelPackPtr &kernel_pack) const { - MS_EXCEPTION_IF_NULL(kernel_pack); - auto kernel_json_info = kernel_pack->kernel_json_info(); - auto kernel_mod_ptr = std::make_shared(kernel_pack); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - kernel_mod_ptr->SetInputSizeList(input_size_list); - kernel_mod_ptr->SetOutputSizeList(output_size_list); - kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces); - return kernel_mod_ptr; -} - -void ParallelBuildManager::ResetTaskInfo() { - if (task_map_.empty()) { - MS_LOG(INFO) << "All tasks are compiled success."; - return; - } - task_map_.clear(); - same_op_list_.clear(); - if (tbe_parallel_compiler_ != nullptr) { - PyObject *pArg = Py_BuildValue("()"); - (void)PyObject_CallMethod(tbe_parallel_compiler_, kResetTaskInfo, "O", pArg); - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h deleted file mode 100644 index 637c03bce3..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ - -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "pybind11/stl.h" -#include -namespace mindspore { -namespace kernel { -bool TbeOpParallelPreBuild(const std::vector &anf_nodes); -bool TbeOpParallelBuild(const std::vector &anf_nodes); - -struct KernelBuildTaskInfo { - AnfNode *node; - std::string processor; - std::string json_name; - std::vector input_size_list; - std::vector output_size_list; - int32_t scope_id; -}; - -class ParallelBuildManager { - public: - ParallelBuildManager(); - ~ParallelBuildManager(); - int32_t StartCompileOp(const nlohmann::json &kernel_json) const; - void SavePreTaskInfo(int32_t task_id, const AnfNodePtr &anf_node); - void SaveTaskInfo(int32_t task_id, const AnfNodePtr &anf_node, const std::string &json_name, - const std::vector &input_size_list, const std::vector &output_size_list, - int32_t scope_id = 0); - void SaveSameOpInfo(const AnfNodePtr &anf_node, const std::string &json_name, - const std::vector &input_size_list, const std::vector &output_size_list); - bool GenSameOpKernelMod() const; - bool SearchInCache(const std::string &json_name, const std::string &processor, - const std::vector &input_size_list, const std::vector &output_size_list, - AnfNode *node) const; - - bool WaitOne(int *task_id, char **task_result, char **pre_build_result) const; - bool IsAllPreTaskFinish() const; - bool IsAllTaskFinish() const; - void PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result); - std::pair TaskFinishProcess(int32_t task_id, bool set_kernel_mod = true); - KernelModPtr GenKernelMod(const string &json_name, const string &processor, - const std::vector &input_size_list, const std::vector &output_size_list, - const KernelPackPtr &kernel_pack) const; - void ResetTaskInfo(); - - private: - PyObject *tbe_parallel_compiler_; - std::map pre_task_map_; - std::map task_map_; - std::vector same_op_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc deleted file mode 100644 index 8050f02f95..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc +++ /dev/null @@ -1,318 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" - -namespace mindspore { -namespace kernel { -constexpr size_t kInputIndex_0 = 0; -constexpr size_t kChannelN = 0; -constexpr size_t kChannelC = 1; -constexpr size_t kAlignmented16 = 16; -// 1. all shape no scalar and same -// 2. part scalar : no_scalar (shape size > xxx && alig xxx) -// 3. all no_scalar and not same (broad cast xxx dim) -bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) { - MS_EXCEPTION_IF_NULL(support_format); - input_num_ = 0; - output_num_ = 0; - input_shapes_.clear(); - output_shapes_.clear(); - if (AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode_ptr_)) { - MS_LOG(INFO) << "This broadcast node has dynamic input."; - auto dynamic_size_vec = AnfAlgo::GetNodeAttr>(cnode_ptr_, kAttrDynInputSizes); - if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) { - MS_LOG(EXCEPTION) << "dynamic attr set error, please check."; - } - auto dynamic_input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); - PadScalarShape(&dynamic_input_shape0_); - input_shapes_.emplace_back(dynamic_input_shape0_); - input_num_ = 1; - } else { - input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_); - for (size_t i = 0; i < input_num_; ++i) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); - PadScalarShape(&input_shape); - input_shapes_.emplace_back(input_shape); - } - } - - output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_); - for (size_t i = 0; i < output_num_; ++i) { - auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i); - PadScalarShape(&output); - output_shapes_.emplace_back(output); - } - AssignSupportFormat(kOpFormat_DEFAULT, support_format); - return true; -} - -bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (IsSameShape()) { - if (!HasScalarInput()) { - AssignSupportFormat(kOpFormat_NC1HWC0, support_format); - return true; - } else { - return false; - } - } - SupportFormatItem input_support_format; - SupportFormatItem output_support_format; - if (HasScalarInput()) { - for (const auto &shape : input_shapes_) { - if (IsScalarShape(shape)) { - input_support_format.emplace_back(kOpFormat_DEFAULT); - } else { - if (!Is4DShape(shape)) { - return false; - } - if (shape[kChannelC] % kAlignmented16 != 0) { - return false; - } - input_support_format.emplace_back(kOpFormat_NC1HWC0); - } - } - } else { - for (const auto &shape : input_shapes_) { - if (!Is4DShape(shape)) { - return false; - } - } - auto shape_tmp = input_shapes_[0]; - auto broadcast_c_axis = std::any_of( - input_shapes_.begin(), input_shapes_.end(), - [&shape_tmp](const std::vector &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); }); - if (broadcast_c_axis) { - MS_LOG(INFO) << "This node broadcast c channel."; - return false; - } - input_support_format.assign(input_num_, kOpFormat_NC1HWC0); - } - GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format); - support_format->input_format.emplace_back(input_support_format); - support_format->output_format.emplace_back(output_support_format); - return true; -} - -bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (IsSameShape()) { - if (!HasScalarInput()) { - AssignSupportFormat(kOpFormat_FRAC_Z, support_format); - return true; - } else { - return false; - } - } - SupportFormatItem input_support_format; - SupportFormatItem output_support_format; - if (HasScalarInput()) { - for (const auto &shape : input_shapes_) { - if (IsScalarShape(shape)) { - input_support_format.emplace_back(kOpFormat_DEFAULT); - } else { - if (!Is4DShape(shape)) { - return false; - } - if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) { - return false; - } - input_support_format.emplace_back(kOpFormat_FRAC_Z); - } - } - } else { - return false; - } - GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format); - support_format->input_format.emplace_back(input_support_format); - support_format->output_format.emplace_back(output_support_format); - return true; -} -bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (IsSameShape()) { - if (!HasScalarInput()) { - AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format); - return true; - } else { - return false; - } - } - SupportFormatItem input_support_format; - SupportFormatItem output_support_format; - if (HasScalarInput()) { - for (const auto &shape : input_shapes_) { - if (IsScalarShape(shape)) { - input_support_format.emplace_back(kOpFormat_DEFAULT); - } else { - if (!Is4DShape(shape)) { - return false; - } - if (shape[kChannelN] % kAlignmented16 != 0) { - return false; - } - input_support_format.emplace_back(kOpFormat_C1HWNCoC0); - } - } - } else { - for (const auto &shape : input_shapes_) { - if (!Is4DShape(shape)) { - return false; - } - } - auto shape_tmp = input_shapes_[0]; - auto broadcast_nc_axis = - std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector &elem) { - return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN)); - }); - if (broadcast_nc_axis) { - MS_LOG(INFO) << "This node broadcast n || c channel."; - return false; - } - input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0); - } - GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format); - support_format->input_format.emplace_back(input_support_format); - support_format->output_format.emplace_back(output_support_format); - return true; -} - -bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (IsSameShape()) { - if (!HasScalarInput()) { - AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); - return true; - } else { - return false; - } - } - SupportFormatItem input_support_format; - SupportFormatItem output_support_format; - if (HasScalarInput()) { - for (const auto &shape : input_shapes_) { - if (IsScalarShape(shape)) { - input_support_format.emplace_back(kOpFormat_DEFAULT); - } else { - if (shape.size() < kShape2dDims) { - return false; - } - if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) { - return false; - } - input_support_format.emplace_back(kOpFormat_FRAC_NZ); - } - } - } else { - auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(), - [](const std::vector &elem) { return elem.size() < kShape2dDims; }); - if (less_2dims) { - MS_LOG(INFO) << "This node dim less 2."; - return false; - } - - auto shape_tmp = input_shapes_[0]; - auto broadcast_last_dim = - std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector &elem) { - return (shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) || - (shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2)); - }); - if (broadcast_last_dim) { - MS_LOG(INFO) << "This node broadcast last channel."; - return false; - } - - input_support_format.assign(input_num_, kOpFormat_FRAC_NZ); - } - GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format); - support_format->input_format.emplace_back(input_support_format); - support_format->output_format.emplace_back(output_support_format); - return true; -} - -bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - return false; -} - -bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector &shape) const { - return shape.size() == kShape4dDims; -} - -bool TbeKernelBroadCastSelecter::IsSameShape() const { - auto shape = input_shapes_.begin(); - for (const auto &item : input_shapes_) { - if (shape->size() != item.size()) { - return false; - } - for (size_t i = 0; i < shape->size(); ++i) { - if (shape->at(i) != item.at(i)) { - return false; - } - } - } - return true; -} - -void TbeKernelBroadCastSelecter::PadScalarShape(std::vector *shape) const { - MS_EXCEPTION_IF_NULL(shape); - if (shape->empty()) { - shape->emplace_back(1); - } -} - -bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector &shape) const { - return (shape.size() == 1 && shape[0] == 1); -} - -bool TbeKernelBroadCastSelecter::HasScalarInput() const { - bool ret = false; - for (const auto &shape : input_shapes_) { - if (IsScalarShape(shape)) { - ret = true; - break; - } - } - return ret; -} - -void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format, - SupportFormatItem *output_support_item) const { - MS_EXCEPTION_IF_NULL(output_support_item); - for (const auto &shape : output_shapes_) { - if (IsScalarShape(shape)) { - output_support_item->emplace_back(kOpFormat_DEFAULT); - } else { - output_support_item->emplace_back(support_format); - } - } -} - -void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str, - mindspore::kernel::SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - SupportFormatItem input_support_format; - SupportFormatItem output_support_format; - input_support_format.assign(input_num_, support_format_str); - output_support_format.assign(output_num_, support_format_str); - support_format->input_format.emplace_back(input_support_format); - support_format->output_format.emplace_back(output_support_format); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h deleted file mode 100644 index af711ddf29..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ - -#include -#include -#include -#include "ir/anf.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" - -namespace mindspore { -namespace kernel { -class TbeKernelBroadCastSelecter { - public: - explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} - ~TbeKernelBroadCastSelecter() = default; - bool GetShapeInfo(SupportFormat *support_format); - bool IsBroadCastSupport5HD(SupportFormat *support_format) const; - bool IsBroadCastSupportFracZ(SupportFormat *support_format) const; - bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const; - bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const; - bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const; - - private: - bool IsSameShape() const; - void PadScalarShape(std::vector *shape) const; - bool Is4DShape(const std::vector &shape) const; - bool IsScalarShape(const std::vector &shape) const; - bool HasScalarInput() const; - void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const; - void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; - // broadcast - CNodePtr cnode_ptr_; - size_t input_num_{}; - size_t output_num_{}; - std::vector> input_shapes_; - std::vector> output_shapes_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_TBE_KERNEL_BROADCAST_SELECTER_HELPER_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc deleted file mode 100644 index 3f8e5b85c3..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc +++ /dev/null @@ -1,178 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" -#include -#include -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" - -namespace mindspore { -namespace kernel { -constexpr char kAxis[] = "axis"; -constexpr char kTypeInt32[] = "Int32"; -constexpr size_t kInputIndex_0 = 0; -constexpr size_t kOutputIndex_0 = 0; -constexpr size_t kChannelN = 0; -constexpr size_t kChannelC = 1; -constexpr size_t kReduceNZMinDim = 3; - -bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) { - MS_EXCEPTION_IF_NULL(support_format); - input_shape_.clear(); - output_shape_.clear(); - axis_.clear(); - auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); - auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); - if (input_num != 1 || output_num != 1) { - MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num - << ", output num: " << output_num; - } - // get input/output shape - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); - PadScalarShape(&input_shape_); - output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0); - PadScalarShape(&output_shape_); - // get keep dim attr - GetReduceAttrKeepDim(); - // get axis attr - GetReduceAttrAxis(); - AssignSupportFormat(kOpFormat_DEFAULT, support_format); - return true; -} - -bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (!Is4DShape(input_shape_)) { - return false; - } - if (!keep_dims_ || axis_.empty()) { - return false; - } - auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); - if (reduce_c_axis) { - return false; - } - AssignSupportFormat(kOpFormat_NC1HWC0, support_format); - return true; -} - -bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - // like to 5HD - return false; -} - -bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { - return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format); -} - -bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const { - return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format); -} - -bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (input_shape_.size() < kReduceNZMinDim) { - return false; - } - if (axis_.empty()) { - return false; - } - auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) { - return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2)); - }); - if (reduce_last_axis) { - return false; - } - AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); - return true; -} - -bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format, - mindspore::kernel::SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (!Is4DShape(input_shape_)) { - return false; - } - if (!keep_dims_ || axis_.empty()) { - return false; - } - auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(), - [](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); }); - if (reduce_n_c_axis) { - return false; - } - AssignSupportFormat(format, support_format); - return true; -} - -void TbeKernelReduceSelecter::GetReduceAttrAxis() { - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_); - MS_EXCEPTION_IF_NULL(primitive); - auto axis = primitive->GetAttr(kAxis); - if (axis == nullptr) { - MS_LOG(INFO) << "This node does't have axie attr."; - return; - } - auto type = axis->type(); - MS_EXCEPTION_IF_NULL(type); - std::vector axis_list; - if (type->ToString() == kTypeInt32) { - axis_list.emplace_back(GetValue(axis)); - } else { - axis_list = GetValue>(axis); - } - for (const auto &elem : axis_list) { - if (elem < 0) { - axis_.emplace_back(input_shape_.size() + elem); - } else { - axis_.emplace_back(IntToSize(elem)); - } - } -} - -void TbeKernelReduceSelecter::GetReduceAttrKeepDim() { - if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode_ptr_)) { - MS_LOG(INFO) << "This node does't have keep_attr."; - keep_dims_ = false; - return; - } - keep_dims_ = AnfAlgo::GetNodeAttr(cnode_ptr_, kAttrKeepDims); -} - -void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str, - mindspore::kernel::SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - SupportFormatItem input_support_format; - SupportFormatItem output_support_format; - input_support_format.emplace_back(support_format_str); - output_support_format.emplace_back(support_format_str); - support_format->input_format.emplace_back(input_support_format); - support_format->output_format.emplace_back(output_support_format); -} - -bool TbeKernelReduceSelecter::Is4DShape(const std::vector &shape) const { return shape.size() == kShape4dDims; } - -void TbeKernelReduceSelecter::PadScalarShape(std::vector *shape) const { - MS_EXCEPTION_IF_NULL(shape); - if (shape->empty()) { - shape->emplace_back(1); - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h deleted file mode 100644 index e66525fd64..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ -#include -#include -#include -#include "ir/anf.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" -namespace mindspore { -namespace kernel { -class TbeKernelReduceSelecter { - public: - explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} - ~TbeKernelReduceSelecter() = default; - bool GetShapeInfo(SupportFormat *support_format); - bool IsReduceSupport5HD(SupportFormat *support_format) const; - bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const; - bool IsReduceSupportFracZ(SupportFormat *support_format) const; - bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const; - bool IsReduceSupportFracNZ(SupportFormat *support_format) const; - - private: - bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const; - void GetReduceAttrAxis(); - void GetReduceAttrKeepDim(); - void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; - bool Is4DShape(const std::vector &shape) const; - void PadScalarShape(std::vector *shape) const; - CNodePtr cnode_ptr_; - std::vector input_shape_{}; - std::vector output_shape_{}; - std::vector axis_{}; - bool keep_dims_ = false; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_TBE_KERNEL_REDUCE_SELECTER_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h deleted file mode 100644 index c400bdbb6f..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_TBE_KERNEL_SELECT_H -#define MINDSPORE_TBE_KERNEL_SELECT_H - -#include -#include -#include -#include "kernel/oplib/opinfo.h" -#include "kernel/kernel_build_info.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" - -namespace mindspore { -namespace kernel { -void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); - -class TbeKernelSelect { - using OpInfoPtr = std::shared_ptr; - using KernelBuildInfoIter = std::vector>::iterator; - - public: - TbeKernelSelect(CNodePtr kernel_node, std::vector> *kernel_info_list); - ~TbeKernelSelect() = default; - void TbeMetadataInfoEx(); - - private: - void GetCommonPatternKernelInfo(const OpInfo &op_info); - void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info); - void GetAgnosticPatternKernelInfo(const OpInfo &op_info); - void GetBroadcastPatternKernelInfo(const OpInfo &op_info); - void GetReducePatternKernelInfo(const OpInfo &op_info); - void FilterInVaildKernelInfo(); - bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); - static bool IsShapeMatchFormat(const std::vector &shape, const std::string &format); - bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); - static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); - bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, - const std::vector> &ios_info, const std::vector &dyn_input_sizes, - std::vector *formats, std::vector *device_types, - std::vector> *reshape_types); - static void StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec); - static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new); - static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, - const std::vector> &support_format_item, size_t index, - OpIOInfo *op_io_info_new); - // op select(dynamic) - void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new); - static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector &support_dtype, - const std::vector &support_format, OpIOInfo *op_io_info_new); - static std::vector SplitStrToVec(const std::string &op_select_json_item); - std::string OpSelectFormat(); - - static void PrintSupportedFormat(const SupportFormat &support_format); - - private: - CNodePtr cnode_ptr_; - std::vector> *kernel_info_list_; - std::string node_name_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_TBE_KERNEL_SELECT_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc b/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc deleted file mode 100644 index 7204fb7f96..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc +++ /dev/null @@ -1,198 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_python_funcs.h" -#include "kernel/tbe/tbe_utils.h" -#include "common/utils.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeUtils; -constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; -constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler"; -constexpr auto kOpSelectFormatFunc = "op_select_format"; -constexpr auto kCheckSupportedFunc = "check_supported"; -constexpr auto kTBEException = "TBEException"; - -PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr; -PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr; -PyObject *TbePythonFuncs::pOpSelectFormatFunc_ = nullptr; -PyObject *TbePythonFuncs::pCheckSupportedFunc_ = nullptr; -bool TbePythonFuncs::Init() { - static bool initialized = false; - if (initialized) { - return true; - } - // Initialize cache - TbeUtils::LoadCache(); - - // tbe_process - PyObject *pTbeProcessModule = nullptr; - pTbeProcessModule = PyImport_ImportModule(kTbeProcessModule); - if (pTbeProcessModule == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kTbeProcessModule << "] module."; - return false; - } - - pCreateTbeParallelCompilerFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCreateTbeParallelCompilerFunc); - if (pCreateTbeParallelCompilerFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kCreateTbeParallelCompilerFunc << "]."; - return false; - } - - pTbeCompiler_ = PyEval_CallObject(pCreateTbeParallelCompilerFunc_, nullptr); - if (pTbeCompiler_ == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function : create_parallel_compiler."; - return false; - } - - pOpSelectFormatFunc_ = PyObject_GetAttrString(pTbeProcessModule, kOpSelectFormatFunc); - if (pOpSelectFormatFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kOpSelectFormatFunc << "]."; - return false; - } - - pCheckSupportedFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCheckSupportedFunc); - if (pCheckSupportedFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kCheckSupportedFunc << "]."; - return false; - } - initialized = true; - MS_LOG(INFO) << "TbePythonFuncs initialized Success."; - return true; -} - -std::string TbePythonFuncs::PyObjectToStr(PyObject *PyObj) { - char *pChar = nullptr; - std::string str_res; - if (PyObj == nullptr) { - MS_LOG(ERROR) << "Input parameter is nullptr."; - return str_res; - } - PyObject *strArgs = PyObject_Str(PyObj); - if (strArgs != nullptr) { - (void)PyArg_Parse(strArgs, "s", &pChar); - } - if (pChar == nullptr) { - MS_LOG(ERROR) << "pChar is nullptr."; - return str_res; - } - str_res = pChar; - return str_res; -} - -std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) { - PyObject *pArg = nullptr; - PyObject *pRet = nullptr; - std::string res_json_str; - - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return res_json_str; - } - - // assembly Args - pArg = PyTuple_New(1); - std::string json_str = kernel_json.dump(); - (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", json_str.c_str())); - if (pArg == nullptr) { - MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; - return res_json_str; - } - - // call functions - if (pOpSelectFormatFunc_ == nullptr) { - MS_LOG(ERROR) << "function is nullptr."; - return res_json_str; - } - - pRet = PyEval_CallObject(pOpSelectFormatFunc_, pArg); - if (pRet == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc - << "], function args:" << PyObjectToStr(pArg); - } - - char *pstr = nullptr; - (void)PyArg_Parse(pRet, "s", &pstr); - res_json_str = pstr; - if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) { - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str - << " ,function args:" << PyObjectToStr(pArg); - } - return res_json_str; -} - -bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) { - PyObject *pArg = nullptr; - PyObject *pRes = nullptr; - bool ret = false; - - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return ret; - } - // assembly Args - pArg = PyTuple_New(1); - std::string json_str = kernel_json.dump(); - PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); - (void)PyTuple_SetItem(pArg, 0, arg1); - if (pArg == nullptr) { - MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; - return ret; - } - - // call functions - if (pCheckSupportedFunc_ == nullptr) { - MS_LOG(ERROR) << "function is nullptr."; - return ret; - } - - pRes = PyEval_CallObject(pCheckSupportedFunc_, pArg); - if (pRes == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc - << "], function args: " << PyObjectToStr(pArg); - } - if (PyBool_Check(pRes)) { - ret = PyObject_IsTrue(pRes) != 0; - } else { - char *pstr = nullptr; - (void)PyArg_Parse(pRes, "s", &pstr); - std::string res_str = pstr; - if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) { - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str - << ", function args: " << PyObjectToStr(pArg); - } - } - - return ret; -} - -PyObject *TbePythonFuncs::TbeParallelCompiler() { - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return nullptr; - } - return pTbeCompiler_; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_utils.cc b/mindspore/ccsrc/kernel/tbe/tbe_utils.cc deleted file mode 100644 index ae7e5cb6d5..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_utils.cc +++ /dev/null @@ -1,254 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "runtime/kernel.h" -#include "kernel/oplib/oplib.h" -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "device/kernel_info.h" -#include "ir/dtype/type.h" -#include "kernel/tbe/tbe_convert_utils.h" -#include "securec/include/securec.h" -#include "operator/ops.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -constexpr auto kCceKernelMeta = "./kernel_meta/"; -constexpr auto kJsonSuffix = ".json"; -constexpr auto kInfoSuffix = ".info"; - -uintptr_t KernelManager::kernel_stub_gen_ = 0; -std::unordered_map KernelManager::info_table_ = {}; - -void TbeUtils::SaveJsonInfo(const std::string &json_name, const std::string &info) { - char real_path[PATH_MAX] = {0}; - std::string path = kCceKernelMeta + json_name + kInfoSuffix; - if (path.size() > PATH_MAX) { - MS_LOG(ERROR) << "file path: " << path << "is too long."; - return; - } - std::ifstream fin(path); - if (fin) { - MS_LOG(INFO) << "json file exist, no need to create."; - return; - } - std::ofstream file_write; - file_write.open(path); - if (!file_write.is_open()) { - return; - } - file_write << info << std::endl; - file_write.close(); - if (realpath(path.c_str(), real_path) == nullptr) { - MS_LOG(INFO) << "dir: " << path << "does not exit."; - return; - } - MS_LOG(INFO) << "real path is: " << real_path; - if (chmod(real_path, S_IRUSR) == -1) { - MS_LOG(INFO) << "modify file: " << real_path << "to read only fail."; - } -} - -void TbeUtils::LoadCache() { - static bool has_load = false; - if (!has_load) { - KernelMeta *bin_map = KernelMeta::GetInstance(); - if (bin_map != nullptr && !bin_map->ReadIndex(kCceKernelMeta)) { - MS_LOG(INFO) << "Cache initialize failed[" << kCceKernelMeta << "]"; - } else { - MS_LOG(INFO) << "Cache initialize to " << kCceKernelMeta; - } - has_load = true; - } -} - -KernelPackPtr TbeUtils::SearchCache(const std::string &kernel_name, const std::string &processor) { - // search cache. - KernelMeta *bin_map = KernelMeta::GetInstance(); - if (bin_map == nullptr) { - MS_LOG(INFO) << "kernel cache is invalid."; - return nullptr; - } - return bin_map->GetKernelPack(kernel_name, processor); -} - -KernelPackPtr TbeUtils::InsertCache(const std::string &kernel_name, const std::string &processor) { - MS_LOG(INFO) << "kernel name: " << kernel_name << ", processr:" << processor; - if (processor != kProcessorAiCore) { - MS_LOG(EXCEPTION) << "process type should be aicore, actually is: " << processor; - } - return SearchCache(kernel_name, processor); -} - -int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module, - const string &magic) { - static std::map magic_maps = {{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF}, - {"RT_DEV_BINARY_MAGIC_PLAIN", RT_DEV_BINARY_MAGIC_PLAIN}, - {"RT_DEV_BINARY_MAGIC_PLAIN_AICPU", RT_DEV_BINARY_MAGIC_PLAIN_AICPU}, - {"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU}}; - // object for device register. - rtDevBinary_t dev_bin; - dev_bin.data = kernel_buffer.contents; - auto iter = magic_maps.find(magic); - if (iter == magic_maps.end()) { - MS_LOG(INFO) << "Invalid magic number: " << magic; - return -1; - } - dev_bin.magic = iter->second; - dev_bin.length = kernel_buffer.len; - dev_bin.version = 2; - if (RT_ERROR_NONE != rtDevBinaryRegister(&dev_bin, module)) { - MS_LOG(INFO) << "Call runtime rtDevBinaryRegister error."; - return -1; - } - return 0; -} - -uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel_pack, bool force_reload, - uint32_t *block_dim) { - auto kernel = kernel_pack.GetKernel(); - if (kernel == nullptr) { - MS_LOG(EXCEPTION) << "Invalid kernel pack, json or kernel is nullptr."; - } - auto kernel_contents = kernel->contents; - if (kernel_contents == nullptr) { - MS_LOG(EXCEPTION) << "Invalid kernel context, json or kernel is nullptr."; - } - auto kernel_json_info = kernel_pack.kernel_json_info(); - - *block_dim = kernel_json_info.block_dim; - string func_name = kernel_json_info.kernel_name; - string magic = kernel_json_info.magic; - - if (!force_reload) { - // use the cached object. - auto iter = info_table_.find(func_name); - if (iter != info_table_.end()) { - auto kernelmeta = iter->second; - *block_dim = kernelmeta->block_dim_; - return kernelmeta->func_stub_; - } - } - void *module = nullptr; - if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic) != 0) { - MS_LOG(INFO) << "Call runtime BinaryRegister error."; - return 0; - } - // to diff different funcs. - uintptr_t func_stub = ++kernel_stub_gen_; - if (RT_ERROR_NONE != - rtFunctionRegister(module, reinterpret_cast(func_stub), func_name.c_str(), func_name.c_str(), 0)) { - MS_LOG(INFO) << "Call runtime rtFunctionRegister error."; - return 0; - } - // cache the registered kernelmeta. - info_table_[func_name] = std::make_shared(KernelMetaInfo{func_stub, *block_dim}); - return func_stub; -} - -std::string KernelManager::GetStubFuncName(const KernelPackPtr &kernel_pack) { - MS_EXCEPTION_IF_NULL(kernel_pack); - auto kernel_json_info = kernel_pack->kernel_json_info(); - return kernel_json_info.kernel_name; -} - -KernelMeta *KernelMeta::GetInstance() { - static KernelMeta inst; - return &inst; -} - -bool KernelMeta::ReadIndex(const std::string &bin_dir) { - DIR *dir = opendir(bin_dir.c_str()); - if (dir == nullptr) { - auto ret = mkdir(bin_dir.c_str(), S_IRWXG | S_IRWXU); - if (ret != 0) { - MS_LOG(INFO) << "kernel dir: " << bin_dir << "not exist"; - return false; - } - dir = opendir(bin_dir.c_str()); - } - struct dirent *entry; - while ((entry = readdir(dir)) != nullptr) { - string bin_dir_tmp = bin_dir; - std::string cce_json = entry->d_name; - if (cce_json.length() <= 5) { - continue; - } - std::string suffix = cce_json.substr(cce_json.length() - 5); - if (suffix != kJsonSuffix) { - continue; - } - auto sp = cce_json.rfind('/'); - if (sp != std::string::npos) { - continue; - } - sp = cce_json.rfind('.'); - if (sp == std::string::npos) { - continue; - } - auto kernel_name = cce_json.substr(0, sp); - (void)bin_dir_tmp.append("/"); - (void)bin_dir_tmp.append(cce_json); - kernel_index_map_[kernel_name] = bin_dir_tmp; - } - (void)closedir(dir); - - MS_LOG(INFO) << "Cache kernel initialized, kernel size: " << kernel_index_map_.size(); - return true; -} - -KernelPackPtr KernelMeta::GetKernelPack(const std::string &kernel_name, const std::string &processor) { - KernelPackPtr ret = nullptr; - // 1. pack has been created - auto kernel_pack_iter = kernel_pack_map_.find(kernel_name); - if (kernel_pack_iter != kernel_pack_map_.end()) { - MS_LOG(INFO) << "kernel pack [" << kernel_name << "]has been created."; - ret = kernel_pack_iter->second; - } else { - // 2. kernel file has been create, but pack does not been created. - std::string cce_json = kCceKernelMeta; - (void)cce_json.append(kernel_name).append(kJsonSuffix); - ret = std::make_shared(); - if (!ret->LoadKernelMeta(cce_json, processor)) { - MS_LOG(INFO) << "Read cache json and bin file failed[" << cce_json << "]"; - return nullptr; - } - kernel_pack_map_[kernel_name] = ret; - auto iter = kernel_index_map_.find(kernel_name); - if (iter == kernel_index_map_.end()) { - MS_LOG(INFO) << "kernel name [" << kernel_name << "] has been ceated first."; - kernel_index_map_[kernel_name] = cce_json; - } - } - return ret; -} -} // namespace tbe -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_utils.h b/mindspore/ccsrc/kernel/tbe/tbe_utils.h deleted file mode 100644 index 56fbe7967a..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_utils.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ -#include -#include -#include -#include -#include -#include - -#include "session/kernel_graph.h" -#include "ir/anf.h" -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -using std::string; -using std::vector; - -class TbeUtils { - public: - TbeUtils() = default; - - ~TbeUtils() = default; - - static void SaveJsonInfo(const std::string &json_name, const std::string &info); - - static void LoadCache(); - - static KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); - - static KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); -}; - -struct KernelMetaInfo { - uintptr_t func_stub_; - uint32_t block_dim_; -}; -using KernelMetaPtr = std::shared_ptr; - -class KernelManager { - public: - static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim); - static std::string GetStubFuncName(const KernelPackPtr &kernel_pack); - - private: - KernelManager() = default; - ~KernelManager() = default; - static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic); - static std::unordered_map info_table_; - static uintptr_t kernel_stub_gen_; -}; - -class KernelMeta { - public: - static KernelMeta *GetInstance(); - bool ReadIndex(const std::string &bin_dir); - KernelPackPtr GetKernelPack(const std::string &kernel_name, const std::string &processor); - - private: - KernelMeta() = default; - ~KernelMeta() = default; - std::unordered_map kernel_index_map_{}; - std::unordered_map kernel_pack_map_{}; -}; -} // namespace tbe -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt new file mode 100644 index 0000000000..df9729c4ee --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -0,0 +1,159 @@ +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-reorder") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-switch") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sequence-point") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable") + +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-uninitialized") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-maybe-uninitialized") +endif() +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") + +############################# Options ################################ +if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + add_definitions(-D _CRT_RAND_S) +endif () +if (ENABLE_GPUQUE) + add_definitions(-D ENABLE_GPUQUE) + message(STATUS "GPU queue is enabled") +endif () +if (ENABLE_TDTQUE) + add_definitions(-D ENABLE_TDTQUE) + message(STATUS "TDT queue is enabled") +endif () + +# conde coverage +# option(ENABLE_COVERAGE "Enable code coverage report" OFF) +# if (ENABLE_COVERAGE) +# include(${CMAKE_SOURCE_DIR}/cmake/CodeCoverage.cmake) +# append_coverage_compiler_flags() +# endif () + +########### Set up the include directories ########################### +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc) +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/runtime/device/ascend/platform) + +include_directories(${CMAKE_BINARY_DIR}) # for protobuf generated .h + +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/mindrecord/include) +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include) +###################################################################### + +####################### Flags ######################################## +# compile flags +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") + +ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR}) + +################## Include sub-modules ############################### +add_subdirectory(util) +add_subdirectory(core) +add_subdirectory(kernels) +add_subdirectory(engine) +add_subdirectory(api) +add_subdirectory(text) +###################################################################### +add_dependencies(utils core) +add_dependencies(kernels-image core) +add_dependencies(kernels-data core) +add_dependencies(kernels core) +add_dependencies(engine-datasetops-source core) +add_dependencies(engine-datasetops-source-sampler core) +add_dependencies(engine-datasetops core) +add_dependencies(engine-opt core) +add_dependencies(engine-perf core) +add_dependencies(engine-gnn core) +add_dependencies(engine core) +add_dependencies(text core) +add_dependencies(text-kernels core) +add_dependencies(cpp-API core) +if (ENABLE_PYTHON) + add_dependencies(APItoPython core) +endif() +if (ENABLE_TDTQUE) + add_dependencies(engine-tdt core) +endif () +################### Create _c_dataengine Library ###################### +set(submodules + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + ) + +if (ENABLE_PYTHON) + set(submodules + ${submodules} + $) +endif() + +if (ENABLE_TDTQUE) + add_library(_c_dataengine SHARED ${submodules} $) +else () + add_library(_c_dataengine SHARED ${submodules}) +endif () + +add_dependencies(_c_dataengine generated_engine_files) + +set_target_properties(_c_dataengine PROPERTIES + PREFIX "${PYTHON_MODULE_PREFIX}" + SUFFIX "${PYTHON_MODULE_EXTENSION}" + ) + +###################################################################### + +################# Link with external libraries ######################## +target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar) +if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY}) +else() + set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n) + target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY}) +endif() +target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs + mindspore::opencv_imgproc mindspore::tinyxml2 ${ICU_LIB}) +if (ENABLE_GPUQUE) + target_link_libraries(_c_dataengine PRIVATE gpu_queue + ${CUDNN_PATH}/lib64/libcudnn.so + ${CUDA_PATH}/lib64/libcudart.so + ${CUDA_PATH}/lib64/stubs/libcuda.so) +endif () + +if (ENABLE_TDTQUE) + target_link_libraries(_c_dataengine PRIVATE ${TSDCLIENT}) +endif () + +add_dependencies(_c_dataengine _c_mindrecord) +if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + set(MINDRECORD_LINK_OBJECT ${CMAKE_BINARY_DIR}/mindspore/ccsrc/minddata/mindrecord/CMakeFiles/_c_mindrecord.dir/objects.a) + target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) +else() + target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) +endif() + +if (USE_GLOG) + target_link_libraries(_c_dataengine PRIVATE mindspore::glog) +else() + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_options(_c_dataengine PRIVATE -Wl,-init,mindspore_log_init) + elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") + set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON) + endif () +endif() diff --git a/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt new file mode 100644 index 0000000000..ae0b9cc28e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt @@ -0,0 +1,16 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +if (ENABLE_PYTHON) + add_library(APItoPython OBJECT + de_pipeline.cc + python_bindings.cc + ) + target_include_directories(APItoPython PRIVATE ${pybind11_INCLUDE_DIRS}) +endif() + +add_library(cpp-API OBJECT + datasets.cc + iterator.cc + transforms.cc + samplers.cc + ) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc new file mode 100644 index 0000000000..3072a62dc9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -0,0 +1,446 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/include/transforms.h" +#include "minddata/dataset/include/samplers.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/batch_op.h" +#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/engine/datasetops/project_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +namespace api { + +#define RETURN_NULL_IF_ERROR(_s) \ + do { \ + Status __rc = (_s); \ + if (__rc.IsError()) { \ + return nullptr; \ + } \ + } while (false) + +// Function to create the iterator, which will build and launch the execution tree. +std::shared_ptr Dataset::CreateIterator() { + std::shared_ptr iter; + try { + iter = std::make_shared(); + Status rc = iter->BuildAndLaunchTree(shared_from_this()); + if (rc.IsError()) { + MS_LOG(ERROR) << "CreateIterator failed."; + return nullptr; + } + + return iter; + } catch (const std::exception &err) { + MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what(); + return nullptr; + } + + return iter; +} + +// Constructor +Dataset::Dataset() { + // Fetch some default value from config manager + std::shared_ptr cfg = GlobalContext::config_manager(); + num_workers_ = cfg->num_parallel_workers(); + rows_per_buffer_ = cfg->rows_per_buffer(); + connector_que_size_ = cfg->op_connector_size(); +} + +// Function to create a ImageFolderDataset. +std::shared_ptr ImageFolder(std::string dataset_dir, bool decode, + std::shared_ptr sampler, std::set extensions, + std::map class_indexing) { + // This arg is exist in ImageFolderOp, but not externalized (in Python API). The default value is false. + bool recursive = false; + + // Create logical representation of ImageFolderDataset. + auto ds = std::make_shared(dataset_dir, decode, sampler, recursive, extensions, class_indexing); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a MnistDataset. +std::shared_ptr Mnist(std::string dataset_dir, std::shared_ptr sampler) { + auto ds = std::make_shared(dataset_dir, sampler); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a Cifar10Dataset. +std::shared_ptr Cifar10(const std::string &dataset_dir, int32_t num_samples, + std::shared_ptr sampler) { + auto ds = std::make_shared(dataset_dir, num_samples, sampler); + + // Call derived class validation method. + return ds->ValidateParams() ? ds : nullptr; +} + +// Function to create a Batch dataset +std::shared_ptr Dataset::Batch(int32_t batch_size, bool drop_remainder) { + // Default values + std::vector cols_to_map = {}; + std::map>> pad_map; + bool pad = false; + auto ds = std::make_shared(batch_size, drop_remainder, pad, cols_to_map, pad_map); + + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + +// Function to create Repeat dataset. +std::shared_ptr Dataset::Repeat(int32_t count) { + // Workaround for repeat == 1, do not inject repeat. + if (count == 1) { + return shared_from_this(); + } + + auto ds = std::make_shared(count); + + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + +// Function to create a Map dataset. +std::shared_ptr Dataset::Map(std::vector> operations, + std::vector input_columns, + std::vector output_columns, + const std::vector &project_columns) { + auto ds = std::make_shared(operations, input_columns, output_columns, project_columns); + + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + +// Function to create a ShuffleOp +std::shared_ptr Dataset::Shuffle(int32_t shuffle_size) { + // Pass in reshuffle_each_epoch with true + auto ds = std::make_shared(shuffle_size, true); + + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + +// Function to create a ProjectDataset. +std::shared_ptr Dataset::Project(const std::vector &columns) { + auto ds = std::make_shared(columns); + // Call derived class validation method. + if (!ds->ValidateParams()) { + return nullptr; + } + + ds->children.push_back(shared_from_this()); + + return ds; +} + +// Helper function to create default RandomSampler. +std::shared_ptr CreateDefaultSampler() { + int32_t num_samples = 0; // 0 means to sample all ids. + bool replacement = false; + return std::make_shared(replacement, num_samples); +} + +/* ####################################### Derived Dataset classes ################################# */ + +ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, + bool recursive, std::set extensions, + std::map class_indexing) + : dataset_dir_(dataset_dir), + decode_(decode), + sampler_(sampler), + recursive_(recursive), + class_indexing_(class_indexing), + exts_(extensions) {} + +bool ImageFolderDataset::ValidateParams() { + if (dataset_dir_.empty()) { + MS_LOG(ERROR) << "No dataset path is specified."; + return false; + } + + return true; +} + +std::shared_ptr>> ImageFolderDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler, i.e., RandomSampler. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + // Do internal Schema generation. + // This arg is exist in ImageFolderOp, but not externalized (in Python API). + std::unique_ptr schema = std::make_unique(); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_NULL_IF_ERROR( + schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_NULL_IF_ERROR( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); + node_ops.push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, + recursive_, decode_, exts_, class_indexing_, std::move(schema), + std::move(sampler_->Build()))); + return std::make_shared>>(node_ops); +} + +MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr sampler) + : dataset_dir_(dataset_dir), sampler_(sampler) {} + +bool MnistDataset::ValidateParams() { + if (dataset_dir_.empty()) { + MS_LOG(ERROR) << "No dataset path is specified."; + return false; + } + + return true; +} + +std::shared_ptr>> MnistDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler, i.e., RandomSampler. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_NULL_IF_ERROR( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + + node_ops.push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, + std::move(schema), std::move(sampler_->Build()))); + return std::make_shared>>(node_ops); +} + +BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector cols_to_map, + std::map>> pad_map) + : batch_size_(batch_size), + drop_remainder_(drop_remainder), + pad_(pad), + cols_to_map_(cols_to_map), + pad_map_(pad_map) {} + +std::shared_ptr>> BatchDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + +#ifdef ENABLE_PYTHON + py::function noop; + node_ops.push_back(std::make_shared(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, + cols_to_map_, noop, noop, pad_map_)); +#else + node_ops.push_back(std::make_shared(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, + cols_to_map_, pad_map_)); +#endif + return std::make_shared>>(node_ops); +} + +bool BatchDataset::ValidateParams() { + if (batch_size_ <= 0) { + return false; + } + + return true; +} + +RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {} + +std::shared_ptr>> RepeatDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(repeat_count_)); + return std::make_shared>>(node_ops); +} + +bool RepeatDataset::ValidateParams() { + if (repeat_count_ <= 0) { + return false; + } + + return true; +} +MapDataset::MapDataset(std::vector> operations, std::vector input_columns, + std::vector output_columns, const std::vector &project_columns) + : operations_(operations), + input_columns_(input_columns), + output_columns_(output_columns), + project_columns_(project_columns) {} + +std::shared_ptr>> MapDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // Currently default is true, and this is not exposed to user. + bool perf_mode = true; + + std::vector> tensor_ops; + + // Build tensorOp from tensorOperation vector + // This is to ensure each iterator hold its own copy of the tensorOp objects. + (void)std::transform( + operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), + [](std::shared_ptr operation) -> std::shared_ptr { return operation->Build(); }); + + // This parameter will be removed with next rebase + std::vector col_orders; + auto map_op = + std::make_shared(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_, perf_mode); + if (!project_columns_.empty()) { + auto project_op = std::make_shared(project_columns_); + node_ops.push_back(project_op); + } + + node_ops.push_back(map_op); + return std::make_shared>>(node_ops); +} + +bool MapDataset::ValidateParams() { + if (operations_.empty()) { + return false; + } + + return true; +} + +// Constructor for ShuffleDataset +ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch) + : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {} + +// Function to build the ShuffleOp +std::shared_ptr>> ShuffleDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_, + rows_per_buffer_)); + return std::make_shared>>(node_ops); +} + +// Function to validate the parameters for ShuffleDataset +bool ShuffleDataset::ValidateParams() { + if (shuffle_size_ <= 1) { + MS_LOG(ERROR) << "ShuffleDataset: Invalid input, shuffle_size: " << shuffle_size_; + return false; + } + + return true; +} + +// Constructor for Cifar10Dataset +Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler) + : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} + +bool Cifar10Dataset::ValidateParams() { + if (dataset_dir_.empty()) { + MS_LOG(ERROR) << "No dataset path is specified."; + return false; + } + if (num_samples_ < 0) { + MS_LOG(ERROR) << "Number of samples cannot be negative"; + return false; + } + return true; +} + +// Function to build CifarOp +std::shared_ptr>> Cifar10Dataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + // If user does not specify Sampler, create a default sampler based on the shuffle variable. + if (sampler_ == nullptr) { + sampler_ = CreateDefaultSampler(); + } + + // Do internal Schema generation. + auto schema = std::make_unique(); + RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_NULL_IF_ERROR( + schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + + node_ops.push_back(std::make_shared(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_, + dataset_dir_, connector_que_size_, std::move(schema), + std::move(sampler_->Build()))); + return std::make_shared>>(node_ops); +} + +// Function to build ProjectOp +ProjectDataset::ProjectDataset(const std::vector &columns) : columns_(columns) {} + +bool ProjectDataset::ValidateParams() { + if (columns_.empty()) { + MS_LOG(ERROR) << "No columns are specified."; + return false; + } + return true; +} + +std::shared_ptr>> ProjectDataset::Build() { + // A vector containing shared pointer to the Dataset Ops that this object will create + std::vector> node_ops; + + node_ops.push_back(std::make_shared(columns_)); + return std::make_shared>>(node_ops); +} + +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc new file mode 100644 index 0000000000..2a6166f868 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc @@ -0,0 +1,1605 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/api/de_pipeline.h" + +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/filter_op.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/kernels/py_func_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_category.h" +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "pybind11/stl.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *, std::shared_ptr *); + +static std::unordered_map g_parse_op_func_ = { + {kShuffle, &DEPipeline::ParseShuffleOp}, + {kMindrecord, &DEPipeline::ParseMindRecordOp}, + {kMap, &DEPipeline::ParseMapOp}, + {kFilter, &DEPipeline::ParseFilterOp}, + {kBatch, &DEPipeline::ParseBatchOp}, + {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, + {kBarrier, &DEPipeline::ParseBarrierOp}, + {kRepeat, &DEPipeline::ParseRepeatOp}, + {kSkip, &DEPipeline::ParseSkipOp}, + {kZip, &DEPipeline::ParseZipOp}, + {kConcat, &DEPipeline::ParseConcatOp}, + {kRename, &DEPipeline::ParseRenameOp}, + {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, + {kGenerator, &DEPipeline::ParseGeneratorOp}, + {kTfReader, &DEPipeline::ParseTFReaderOp}, + {kProject, &DEPipeline::ParseProjectOp}, + {kTake, &DEPipeline::ParseTakeOp}, + {kImageFolder, &DEPipeline::ParseImageFolderOp}, + {kMnist, &DEPipeline::ParseMnistOp}, + {kManifest, &DEPipeline::ParseManifestOp}, + {kVoc, &DEPipeline::ParseVOCOp}, + {kCoco, &DEPipeline::ParseCocoOp}, + {kCifar10, &DEPipeline::ParseCifar10Op}, + {kCifar100, &DEPipeline::ParseCifar100Op}, + {kCelebA, &DEPipeline::ParseCelebAOp}, + {kRandomData, &DEPipeline::ParseRandomDataOp}, + {kTextFile, &DEPipeline::ParseTextFileOp}, + {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, + {kClue, &DEPipeline::ParseClueOp}}; + +DEPipeline::DEPipeline() : iterator_(nullptr) { + try { + // One time init + (void)GlobalInit(); + + // Instantiate the execution tree + tree_ = std::make_shared(); + repeat_num_ = 1; + batch_size_ = 1; + num_rows_ = 0; + num_classes_ = 0; + temp_batch_size_ = 1; + temp_drop_remainder_ = false; + } catch (const std::exception &err) { + MS_LOG(ERROR) << "Dataset pipeline exception caught on init: " << err.what() << "."; + return; + } +} + +DEPipeline::~DEPipeline() { + { + // Release GIL before joining all threads + py::gil_scoped_release gil_release; + // Release tree + tree_.reset(); + } +} + +// Function to add a Node to the Execution Tree. +Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output) { + // For each operator, Parse through the list of arguments, then call the respective builder/constructor. + // Note that each call to the parse function may result in building more than one dataset operator. + // For example, one call to ParseNNNOp may result in multiple internal C nodes: + // nodeA + // | + // nodeB + // | + // nodeC + // However, the python side dataset is more abstract, and it does not know about the potential subtree that + // is being built here. Since the python api is hooking tree nodes together (parent/child hookups), the + // python side needs to know about nodeA and NodeC to be able to appropriately hook up parents and child + // to this subtee. + // Thus, it is required that both the top-most parent and bottom-most child are returned from the parse + // function. + DsOpPtr top = nullptr; + DsOpPtr bottom = nullptr; + auto iter = g_parse_op_func_.find(op_name); + if (iter != g_parse_op_func_.end()) { + pFunction func = iter->second; + RETURN_IF_NOT_OK((this->*func)(args, &top, &bottom)); + + if (top == nullptr) { + RETURN_STATUS_UNEXPECTED("An operator was parsed but it did not produce a C node."); + } + + // It is not required that the parse function always produces the bottom pointer. If it's still null, + // then set top and bottom to be the same operator + if (bottom == nullptr) bottom = top; + + // Pack these pointers into a py dict so that we can return both back to python. + (*output)["top"] = top; + (*output)["bottom"] = bottom; + } else { + RETURN_STATUS_UNEXPECTED("No such Op"); + } + // Associate current dataset op node with the tree. + RETURN_IF_NOT_OK(tree_->AssociateNode(top)); + return Status::OK(); +} +// Function to add a child and parent relationship. +Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op) { + // Link this relationship. + // Note parent node takes ownership of the child + return (parent_op->AddChild(child_op)); +} + +// Function to assign the node as root. +Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); } + +// Function to launch the tree execution. +Status DEPipeline::LaunchTreeExec() { + RETURN_IF_NOT_OK(tree_->Prepare()); + RETURN_IF_NOT_OK(tree_->Launch()); + iterator_ = std::make_unique(tree_); + if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator."); + return Status::OK(); +} + +void DEPipeline::PrintTree() { + for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) { + std::stringstream ss; + ss << *itr; + MS_LOG(DEBUG) << "Operator ID is " << itr->id() << ". Details: " << ss.str().c_str() << "."; + } +} + +Status DEPipeline::GetNextAsMap(py::dict *output) { + TensorMap row; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetNextAsMap(&row); + } + RETURN_IF_NOT_OK(s); + // Generate Python dict as return + for (auto el : row) { + (*output)[common::SafeCStr(el.first)] = el.second; + } + return Status::OK(); +} + +Status DEPipeline::GetNextAsList(py::list *output) { + TensorRow row; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->FetchNextTensorRow(&row); + } + RETURN_IF_NOT_OK(s); + // Generate Python list as return + for (auto el : row) { + output->append(el); + } + return Status::OK(); +} + +Status DEPipeline::GetOutputShapes(py::list *output) { + std::vector shapes; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetOutputShapes(&shapes); + } + RETURN_IF_NOT_OK(s); + for (auto el : shapes) { + py::list shape; + for (auto dim : el.AsVector()) { + shape.append(dim); + } + output->append(shape); + } + return Status::OK(); +} + +Status DEPipeline::GetOutputTypes(py::list *output) { + std::vector types; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetOutputTypes(&types); + } + RETURN_IF_NOT_OK(s); + for (auto el : types) { + output->append(el.AsNumpyType()); + } + return Status::OK(); +} + +int DEPipeline::GetDatasetSize() const { return num_rows_ / batch_size_; } + +int DEPipeline::GetBatchSize() const { return batch_size_; } + +int DEPipeline::GetRepeatCount() const { return repeat_num_; } + +float ToFloat(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +int ToInt(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +bool ToBool(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +std::string ToString(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +std::vector ToStringVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (!l.is_none()) + vector.push_back(py::str(l)); + else + vector.emplace_back(""); + } + return vector; +} + +std::set ToStringSet(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::set set; + for (auto l : list) { + if (!l.is_none()) { + (void)set.insert(py::str(l)); + } + } + return set; +} + +std::map ToStringMap(const py::handle handle) { + py::dict dict = py::reinterpret_borrow(handle); + std::map map; + for (auto p : dict) { + (void)map.insert(std::make_pair(ToString(p.first), ToInt(p.second))); + } + return map; +} + +std::vector ToIntVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (!l.is_none()) { + vector.push_back(ToInt(l)); + } + } + return vector; +} + +std::vector ToTypeVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (l.is_none()) { + vector.emplace_back(DataType()); + } else { + vector.push_back(l.cast()); + } + } + return vector; +} + +Status DEPipeline::SetBatchParameters(const py::dict &args) { + if (args["batch_size"].is_none()) { + std::string err_msg = "Error: batchSize is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + temp_batch_size_ = ToInt(args["batch_size"]); + CHECK_FAIL_RETURN_UNEXPECTED(temp_batch_size_ > 0, "Error: batchSize is invalid."); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "drop_remainder") { + temp_drop_remainder_ = ToBool(value); + } + } + } + + return Status::OK(); +} + +Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + if (!args["buffer_size"].is_none()) { + (void)builder->SetShuffleSize(ToInt(args["buffer_size"])); + } else { + std::string err_msg = "Error: Shuffle buffer size is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "reshuffle_each_epoch") { + (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"])); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, + std::vector> *operators, + int num_padded) { + auto sampler = py::reinterpret_borrow(handle); + auto create = sampler.attr("create_for_minddataset"); + auto op = create().cast>(); + std::stack> stack_ops; + while (op != nullptr) { + auto sampler_op = std::dynamic_pointer_cast(op); + if (sampler_op && num_padded > 0) { + sampler_op->SetNumPaddedSamples(num_padded); + stack_ops.push(sampler_op); + } else { + stack_ops.push(op); + } + op = op->GetChildOp(); + } + while (!stack_ops.empty()) { + operators->push_back(stack_ops.top()); + stack_ops.pop(); + } + return Status::OK(); +} + +Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_file"].is_none()) { + std::string err_msg = "Error: at least one of dataset_files is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + bool load_dataset = ToBool(args["load_dataset"]); + if (load_dataset == true) { + (void)builder->SetDatasetFile({ToString(args["dataset_file"])}); + } else { + (void)builder->SetDatasetFile(ToStringVector(args["dataset_file"])); + } + (void)builder->SetLoadDataset(load_dataset); + std::vector in_col_names; + if (!args["columns_list"].is_none()) { + in_col_names = ToStringVector(args["columns_list"]); + if (in_col_names.empty() || in_col_names[0].empty()) { + std::string err_msg = "Error: columns_list is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetColumnsToLoad(in_col_names); + } + + if (!args["padded_sample"].is_none()) { + (void)builder->SetPaddedSample(args["padded_sample"]); + (void)builder->SetNumToPadSamples(ToInt(args["num_padded"])); + } + std::vector> operators; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumMindRecordWorkers(ToInt(value)); + } else if (key == "block_reader" && ToBool(value) == true) { + (void)builder->SetBlockReader(); + } else if (key == "sampler") { + int num_padded = 0; + if (!args["num_padded"].is_none()) { + num_padded = ToInt(args["num_padded"]); + } + RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded)); + } + } + } + + if (!operators.empty()) { + (void)builder->SetOperators(operators); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + num_rows_ = op->num_rows(); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + MapOp::Builder map_builder; + std::vector> tensor_op_list; + std::vector project_columns; + std::shared_ptr cache_client = nullptr; + int num_workers = 0; + + if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n"); + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)map_builder.SetInColNames(in_col_names); + } else if (key == "output_columns") { + (void)map_builder.SetOutColNames(ToStringVector(value)); + } else if (key == "columns_order") { + project_columns = ToStringVector(value); + } else if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)map_builder.SetNumWorkers(num_workers); + } else if (key == "prefetch_size") { + (void)map_builder.SetOpConnectorSize(ToInt(value)); + } else if (key == "operations") { + py::handle tensor_ops = args["operations"]; + // operation can be a list of TensorOps or a single TensorOp. + if (py::isinstance(tensor_ops)) { + for (auto op : tensor_ops) { + std::shared_ptr tensor_op; + if (py::isinstance(op)) { + tensor_op = op.cast>(); + } else if (py::isinstance(op)) { + tensor_op = std::make_shared(op.cast()); + } else { + RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); + } + tensor_op_list.push_back(tensor_op); + } + } + if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); + (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr map_op; + RETURN_IF_NOT_OK(map_builder.Build(&map_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(map_op)); + *top = map_op; + + // Add a project op over top of the map if the user wanted to reposition the columns + if (!project_columns.empty()) { + ProjectOp::Builder proj_builder(project_columns); + std::shared_ptr proj_op; + RETURN_IF_NOT_OK(proj_builder.Build(&proj_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(proj_op)); + RETURN_IF_NOT_OK(proj_op->AddChild(map_op)); + *top = proj_op; + *bottom = map_op; + } + + // Additionally, add a cache if required. This will go over top of the project op if one + // was created, otherwise it goes over top of the map op + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op)); + *top = cache_op; + *bottom = map_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + + if (args["predicate"].is_none()) { + RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n"); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "predicate") { + py::handle op = args["predicate"]; + if (!py::isinstance(op)) { + RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); + } + py::function predicate_func = op.cast(); + (void)builder->SetPredicateFunc(std::move(predicate_func)); + } else if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)builder->SetInColNames(in_col_names); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + repeat_num_ = ToInt(args["count"]); + std::shared_ptr op; + RETURN_IF_NOT_OK(RepeatOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "source") { + py::object obj = py::cast(&value); + if (!py::isinstance(obj)) { + std::string err_msg = "Error: generator is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetGeneratorFunction(obj.cast()); + } else if (key == "column_names") { + (void)builder->SetColumnNames(ToStringVector(value)); + } else if (key == "column_types") { + (void)builder->SetColumnTypes(ToTypeVector(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder; + if (py::isinstance(args["batch_size"])) { + batch_size_ = ToInt(args["batch_size"]); + CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid."); + builder = std::make_shared(ToInt(args["batch_size"])); + } else if (py::isinstance(args["batch_size"])) { + builder = std::make_shared(1); + (void)builder->SetBatchSizeFunc(args["batch_size"].cast()); + } else { + std::string err_msg = "Error: batch_size is neither an Integer nor a python function"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "drop_remainder") { + (void)builder->SetDrop(ToBool(value)); + } + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } + if (key == "per_batch_map") { + (void)builder->SetBatchMapFunc(value.cast()); + } + if (key == "input_columns") { + (void)builder->SetColumnsToMap(ToStringVector(value)); + } + if (key == "pad_info") { + PadInfo pad_info; + RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); + (void)builder->SetPaddingMap(pad_info, true); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", + "bucket_batch_sizes"}; + for (auto name : mandatory_arguments) { + if (args[name.c_str()].is_none()) { + std::string err_msg = "Error: " + name + " is not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + + std::shared_ptr builder = std::make_shared( + ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), + ToIntVector(args[mandatory_arguments[2].c_str()])); + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "length_dependent_columns") { + (void)builder->SetLengthDependentColumns(ToStringVector(value)); + } + if (key == "bucket_boundaries") { + (void)builder->SetBucketBoundaries(ToIntVector(value)); + } + if (key == "bucket_batch_sizes") { + (void)builder->SetBucketBatchSizes(ToIntVector(value)); + } + if (key == "element_length_function") { + (void)builder->SetElementLengthFunction(value.cast()); + } + if (key == "pad_info") { + PadInfo pad_info; + RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); + (void)builder->SetPadInfo(pad_info); + } + if (key == "pad_to_bucket_boundary") { + (void)builder->SetPadToBucketBoundary(ToBool(value)); + } + if (key == "drop_remainder") { + (void)builder->SetDropRemainder(ToBool(value)); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + // Right now barrier should only take num_rows_per_buffer = 1 + // The reason for this is because having it otherwise can lead to blocking issues + // See barrier_op.h for more details + (void)builder->SetRowsPerBuffer(1); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "condition_name") { + (void)builder->SetConditionName(ToString(value)); + } else if (key == "condition_func") { + (void)builder->SetConditionFunc(value.cast()); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + int32_t prefetch_size = 0; + if (args.contains("prefetch_size")) { + if (args["prefetch_size"].is_none()) { + prefetch_size = 16; + } else { + prefetch_size = ToInt(args["prefetch_size"]); + } + } + std::shared_ptr builder = std::make_shared(prefetch_size); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "queue_name") { + (void)builder->SetChannelName(ToString(value)); + } else if (key == "device_type") { + (void)builder->SetDeviceType(ToString(value)); + } else if (key == "device_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "num_batch") { + (void)builder->SetNumBatch(ToInt(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector in_col_names; + std::vector out_col_names; + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "input_columns") { + in_col_names = ToStringVector(value); + } else if (key == "output_columns") { + out_col_names = ToStringVector(value); + } + } + } + if (in_col_names.empty() || in_col_names[0].empty()) { + std::string err_msg = "Error: input_column_names is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (out_col_names.empty() || out_col_names[0].empty()) { + std::string err_msg = "Error: output_column_names is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetInColNames(in_col_names); + (void)builder->SetOutColNames(out_col_names); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + std::vector files_list; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetDatasetFilesList(files_list); + } else { + std::string err_msg = "Error: at least one of dataset_files or schema_file is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector columns_to_load; + bool schema_exists = false; + bool shuffle_required = false; + int64_t num_devices = 0; + int64_t total_rows = 0; + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); + } else if (key == "columns_list") { + columns_to_load = ToStringVector(value); + (void)builder->SetColumnsToLoad(columns_to_load); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "schema_file_path" || key == "schema_json_string") { + schema_exists = true; + } else if (key == "num_samples") { + total_rows = ToInt(value); + (void)builder->setTotalRows(total_rows); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "shard_equal_rows") { + (void)builder->SetShardEqualRows(ToBool(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); + } + } + } + if (schema_exists) { + std::unique_ptr schema = std::make_unique(); + if (args.contains("schema_file_path")) { + RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); + } else { + RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); + } + (void)builder->SetDataSchema(std::move(schema)); + } + + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because TFReaderOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. + if (sampler) { + (void)builder->SetSampler(std::move(sampler)); + } else if (cache_client) { + int64_t num_samples = 0; + int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder->SetSampler(std::move(sampler)); + } + + std::shared_ptr tf_op; + RETURN_IF_NOT_OK(builder->Build(&tf_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op)); + *top = tf_op; + + if (!cache_client && shuffle_required) { + const boolean estimate = true; + const int64_t workers = 8; + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset via estimate and then compute the shuffle size + RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, files_list, workers, estimate)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, total_rows, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, tf_op, &shuffle_op)); + *top = shuffle_op; + *bottom = tf_op; + } + + // Add a cache op over this op if required and update the output subtree (top/bottom) + if (cache_client) { + // Note, it is not allowed to have both shuffle and cache + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op)); + *top = cache_op; + *bottom = tf_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["columns"].is_none()) { + std::string err_msg = "Error: columns is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector columns_to_project = ToStringVector(args["columns"]); + std::shared_ptr builder = std::make_shared(columns_to_project); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + int num_workers = 0; + std::shared_ptr cache_client = nullptr; + std::shared_ptr builder = std::make_shared(); + (void)builder->SetImageFolderDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "extensions") { + (void)builder->SetExtensions(ToStringSet(value)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } + } + } + std::shared_ptr if_op; + RETURN_IF_NOT_OK(builder->Build(&if_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(if_op)); + *top = if_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op)); + *top = cache_op; + *bottom = if_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_file"].is_none()) { + std::string err_msg = "Error: No dataset files specified for manifest"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr builder = std::make_shared(); + (void)builder->SetManifestFile(ToString(args["dataset_file"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "usage") { + (void)builder->SetUsage(ToString(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["task"].is_none()) { + std::string err_msg = "Error: No task specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["mode"].is_none()) { + std::string err_msg = "Error: No mode specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + (void)builder->SetTask(ToString(args["task"])); + (void)builder->SetMode(ToString(args["mode"])); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + + return Status::OK(); +} + +Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["annotation_file"].is_none()) { + std::string err_msg = "Error: No annotation_file specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["task"].is_none()) { + std::string err_msg = "Error: No task specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + (void)builder->SetFile(ToString(args["annotation_file"])); + (void)builder->SetTask(ToString(args["task"])); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetCifarDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + + (void)builder->SetCifarType(true); + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetCifarDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + + (void)builder->SetCifarType(false); + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + RandomDataOp::Builder builder; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; + + if (args["total_rows"].is_none()) { + std::string err_msg = "Error: total_rows is a required argument"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector columns_to_load; + bool schema_exists = false; + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder.SetNumWorkers(num_workers); + } else if (key == "schema_file_path" || key == "schema_json_string") { + schema_exists = true; + } else if (key == "columns_list") { + columns_to_load = ToStringVector(value); + } else if (key == "total_rows") { + // This is not sampling here. The random data op needs to know how much data to generate. + (void)builder.SetTotalRows(ToInt(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); + } + } + } + if (schema_exists) { + std::unique_ptr schema = std::make_unique(); + if (args.contains("schema_file_path")) { + RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); + } else { + RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); + } + (void)builder.SetDataSchema(std::move(schema)); + } + + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because RandomDataOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. + if (sampler) { + (void)builder.SetSampler(std::move(sampler)); + } else if (cache_client) { + int64_t num_samples = 0; + int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder.SetSampler(std::move(sampler)); + } + + std::shared_ptr random_op = nullptr; + RETURN_IF_NOT_OK(builder.Build(&random_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(random_op)); + *top = random_op; + + // Add a cache op over this op if required and update the output subtree (top/bottom) + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op)); + *top = cache_op; + *bottom = random_op; + } + + return Status::OK(); +} + +int32_t DEPipeline::GetNumClasses() const { return num_classes_; } + +Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + } + + std::shared_ptr builder = std::make_shared(); + if (builder == nullptr) { + std::string err_msg = "Create celebaop builder failed"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + } + (void)builder->SetCelebADir(ToString(args["dataset_dir"])); + for (const auto &arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "extensions") { + (void)builder->SetExtensions(ToStringSet(value)); + } else if (key == "dataset_type") { + (void)builder->SetDatasetType(ToString(value)); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + std::vector files_list; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetTextFilesList(files_list); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "num_samples") { + (void)builder->SetTotalRows(ToInt(value)); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } + } + } + + std::shared_ptr txt_op; + RETURN_IF_NOT_OK(builder->Build(&txt_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); + *top = txt_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(files_list, &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, txt_op, &shuffle_op)); + *top = shuffle_op; + *bottom = txt_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { + for (auto p : py::reinterpret_borrow(value)) { + if (!p.second.is_none()) { + auto tp = py::reinterpret_borrow(p.second); + CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)"); + TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); + std::shared_ptr pad_val = nullptr; + if (py::isinstance(tp[1])) { + std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]); + CHECK_FAIL_RETURN_UNEXPECTED( + Tensor::CreateTensor(&pad_val, std::vector{pad_val_string}, TensorShape::CreateScalar()), + "Cannot create pad_value Tensor"); + } else { + float pad_val_float = tp[1].is_none() ? 0 : ToFloat(tp[1]); + CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(), + DataType(DataType::DE_FLOAT32)), + "Cannot create pad_value Tensor"); + pad_val->SetItemAt({}, pad_val_float); + } + (void)pad_info->insert({ToString(p.first), {shape, pad_val}}); + } else { // tuple is None + (void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}}); + } + } + return Status::OK(); +} + +Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "freq_range") { + py::tuple tp = py::reinterpret_borrow(value); + if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow(tp[0])); + if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow(tp[1])); + } else if (key == "top_k") { + builder->SetTopK(py::reinterpret_borrow(value)); + } else if (key == "columns") { + (void)builder->SetColumnNames(ToStringVector(value)); + } else if (key == "vocab") { + (void)builder->SetVocab(value.cast>()); + } else if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "special_first") { + (void)builder->SetSpecialFirst(ToBool(value)); + } else if (key == "special_tokens") { + (void)builder->SetSpecialTokens(ToStringVector(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector files_list; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetClueFilesList(files_list); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "num_samples") { + (void)builder->SetNumSamples(ToInt(value)); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "cols_to_keyword") { + std::map map_dict; + for (auto p : py::reinterpret_borrow(value)) { + if (!p.second.is_none()) { + map_dict.insert({ToString(p.first), ToString(p.second)}); + } else { + map_dict.insert({ToString(p.first), ToString(p.first)}); + } + } + (void)builder->SetColsKeyMap(map_dict); + } + } + } + + std::shared_ptr clue_op; + RETURN_IF_NOT_OK(builder->Build(&clue_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); + *top = clue_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(files_list, &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, clue_op, &shuffle_op)); + *top = shuffle_op; + *bottom = clue_op; + } + + return Status::OK(); +} + +// Helper function to inject the cache operator over top of the current operation being built. +Status DEPipeline::AddCacheOp(std::shared_ptr cache_client, int num_workers, + std::shared_ptr input_op, std::shared_ptr *cache_op) { + std::shared_ptr new_cache_op = nullptr; + CacheOp::Builder cache_builder; + // use the same number of workers as the leaf. We need some optimization here, the user does not + // give the cache op number of workers directly. + if (num_workers != 0) { + (void)cache_builder.SetNumWorkers(num_workers); + } + (void)cache_builder.SetClient(cache_client); + RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op)); + RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op)); + // We have now created: + // + // CacheOp + // | + // input_op + // + *cache_op = new_cache_op; + + return Status::OK(); +} + +// Helper function to inject a shuffle operator over top of the current operation being built. +Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, + std::shared_ptr *shuffle_op) { + std::shared_ptr new_shuffle_op = nullptr; + ShuffleOp::Builder shuffle_builder; + + (void)shuffle_builder.SetShuffleSize(shuffle_size); + RETURN_IF_NOT_OK(shuffle_builder.Build(&new_shuffle_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(new_shuffle_op)); + RETURN_IF_NOT_OK(new_shuffle_op->AddChild(input_op)); + // We have now created: + // + // ShuffleOp + // | + // input_op + // + *shuffle_op = new_shuffle_op; + + return Status::OK(); +} + +// Common code for computing a default shuffle size +Status DEPipeline::ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int64_t *shuffle_size) { + const int64_t average_files_multiplier = 4; + const int64_t shuffle_max = 10000; + int64_t avg_rows_per_file = 0; + + // Adjust the num rows per shard if sharding was given + if (num_devices > 0) { + if (num_rows % num_devices == 0) { + num_rows = num_rows / num_devices; + } else { + num_rows = (num_rows / num_devices) + 1; + } + } + + // Cap based on total rows directive. Some ops do not have this and give value of 0. + if (total_rows > 0) { + num_rows = std::min(num_rows, total_rows); + } + + // get the average per file + avg_rows_per_file = num_rows / num_files; + + *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h new file mode 100644 index 0000000000..755e827ef2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h @@ -0,0 +1,225 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_API_DE_PIPELINE_H_ +#define DATASET_API_DE_PIPELINE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/client.h" // DE client +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/util/status.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; +namespace mindspore { +namespace dataset { +using DsOpPtr = std::shared_ptr; + +class CacheClient; + +// enum for the dataset operator names +enum OpName { + kShuffle, + kMindrecord, + kBatch, + kBucketBatch, + kBarrier, + kCache, + kRepeat, + kSkip, + kTake, + kZip, + kConcat, + kMap, + kFilter, + kDeviceQueue, + kGenerator, + kRename, + kTfReader, + kProject, + kImageFolder, + kMnist, + kManifest, + kVoc, + kCoco, + kCifar10, + kCifar100, + kCelebA, + kRandomData, + kTextFile, + kBuildVocab, + kClue +}; + +// The C++ binder class that we expose to the python script. +class DEPipeline { + public: + DEPipeline(); + + ~DEPipeline(); + + // Function to add a Node to the Execution Tree. + Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); + + // Function to add a child and parent relationship. + static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); + + // Function to assign the node as root. + Status AssignRootNode(const DsOpPtr &dataset_op); + + // Function to launch the tree execution. + Status LaunchTreeExec(); + + // Get a row of data as dictionary of column name to the value. + Status GetNextAsMap(py::dict *output); + + // Get a row of data as list. + Status GetNextAsList(py::list *output); + + Status GetOutputShapes(py::list *output); + + Status GetOutputTypes(py::list *output); + + int GetDatasetSize() const; + + int GetBatchSize() const; + + int GetRepeatCount() const; + + Status ParseShuffleOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status BuildMindrecordSamplerChain(const py::handle &handle, + std::vector> *operators, + int num_padded); + + Status ParseMapOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseFilterOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRepeatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseSkipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBatchOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom); + + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRenameOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTakeOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseZipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseConcatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseProjectOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseManifestOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseVOCOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCocoOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCifar10Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCifar100Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + void PrintTree(); + + int32_t GetNumClasses() const; + + Status ParseMnistOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status SetBatchParameters(const py::dict &args); + + Status ParseCelebAOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTextFileOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + private: + // Execution tree that links the dataset operators. + std::shared_ptr tree_; + + std::unique_ptr iterator_; + + static Status ParsePadInfo(py::handle value, PadInfo *pad_info); + + /// \brief Helper function to inject a cache operator over top of the current operation being built. + /// \param[in] cache_client The client to use for caching + /// \param[in] num_workers The number of workers to use in the cache op + /// \param[in] input_op The operator to build the cache on top of + /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be + /// the cache operator + /// \return Status return code + Status AddCacheOp(std::shared_ptr cache_client, int num_workers, std::shared_ptr input_op, + std::shared_ptr *cache_op); + + /// \brief Helper function to inject a shuffle operator over top of the current operation being built. + /// \param[in] shuffle_size The size to use in the shuffle buffer + /// \param[in] input_op The operator to build shuffle on top of + /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be + /// the shuffle operator + /// \return Status return code + Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, + std::shared_ptr *shuffle_op); + + /// \brief Helper function to compute the shuffle size + /// \param[in] num_files The number of files in the dataset + /// \param[in] num_devices The number of devices in the dataset + /// \param[in] num_rows The number of rows in the dataset + /// \param[in] total_rows An upper bound on the total rows in the dataset + /// \param[out] shuffle_size The resultant computed shuffle size + /// \return Status return code + Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int64_t *shuffle_size); + + int batch_size_; + int repeat_num_; + int num_rows_; + int num_classes_; + + int temp_batch_size_; + bool temp_drop_remainder_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_API_DE_PIPELINE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc new file mode 100644 index 0000000000..068bcfaa04 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/include/iterator.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/include/datasets.h" + +namespace mindspore { +namespace dataset { +namespace api { + +// Get the next row from the data pipeline. +void Iterator::GetNextRow(TensorMap *row) { + Status rc = iterator_->GetNextAsMap(row); + if (rc.IsError()) { + MS_LOG(ERROR) << "GetNextRow: Failed to get next row."; + row->clear(); + } +} + +// Shut down the data pipeline. +void Iterator::Stop() { + // Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_. + iterator_.reset(); + + // Release ownership of tree_ shared pointer. This will decrement the ref count. + tree_.reset(); +} + +// Function to build and launch the execution tree. +Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { + // One time init + Status rc; + rc = GlobalInit(); + RETURN_IF_NOT_OK(rc); + + // Instantiate the execution tree + tree_ = std::make_shared(); + + // Iterative BFS converting Dataset tree into runtime Execution tree. + std::queue, std::shared_ptr>> q; + + if (ds != nullptr) { + // Convert the current root node. + auto root_op = ds->Build()->front(); + RETURN_UNEXPECTED_IF_NULL(root_op); + + RETURN_IF_NOT_OK(tree_->AssociateNode(root_op)); + + q.push(std::make_pair(ds, root_op)); + + // Traverse down to the children and convert them to the corresponding DatasetOps (i.e. execution tree nodes) + while (!q.empty()) { + auto node_pair = q.front(); + q.pop(); + // Iterate through all the direct children of the first element in our BFS queue + for (auto child : node_pair.first->children) { + auto child_ops = child->Build(); + RETURN_UNEXPECTED_IF_NULL(child_ops); + auto node_op = node_pair.second; + // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them + // with the execution tree and add the child and parent relationship between the nodes + // Note that some Dataset objects might return more than one DatasetOps + // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset + for (auto child_op : *child_ops) { + RETURN_IF_NOT_OK(tree_->AssociateNode(child_op)); + RETURN_IF_NOT_OK(node_op->AddChild(child_op)); + node_op = child_op; + } + // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current + // execution tree) to the BFS queue + q.push(std::make_pair(child, child_ops->back())); + } + } + RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); + } + + // Launch the execution tree. + RETURN_IF_NOT_OK(tree_->Prepare()); + RETURN_IF_NOT_OK(tree_->Launch()); + iterator_ = std::make_unique(tree_); + RETURN_UNEXPECTED_IF_NULL(iterator_); + + return rc; +} + +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc new file mode 100644 index 0000000000..145291ec3b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc @@ -0,0 +1,954 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "minddata/dataset/api/de_pipeline.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/gnn/graph.h" +#include "minddata/dataset/engine/jagged_connector.h" +#include "minddata/dataset/kernels/data/concatenate_op.h" +#include "minddata/dataset/kernels/data/duplicate_op.h" +#include "minddata/dataset/kernels/data/fill_op.h" +#include "minddata/dataset/kernels/data/mask_op.h" +#include "minddata/dataset/kernels/data/one_hot_op.h" +#include "minddata/dataset/kernels/data/pad_end_op.h" +#include "minddata/dataset/kernels/data/slice_op.h" +#include "minddata/dataset/kernels/data/to_float16_op.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/image/bounding_box_augment_op.h" +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/kernels/image/cut_out_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/normalize_op.h" +#include "minddata/dataset/kernels/image/pad_op.h" +#include "minddata/dataset/kernels/image/random_color_adjust_op.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_op.h" +#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" +#include "minddata/dataset/kernels/image/random_resize_op.h" +#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h" +#include "minddata/dataset/kernels/image/rescale_op.h" +#include "minddata/dataset/kernels/image/resize_bilinear_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/uniform_aug_op.h" +#include "minddata/dataset/kernels/no_op.h" +#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" +#include "minddata/dataset/text/kernels/lookup_op.h" +#include "minddata/dataset/text/kernels/ngram_op.h" +#include "minddata/dataset/text/kernels/to_number_op.h" +#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" +#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" +#include "minddata/dataset/text/vocab.h" +#include "minddata/dataset/util/random.h" +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_pk_sample.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_sequential_sample.h" +#include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#ifdef ENABLE_ICU4C +#include "minddata/dataset/text/kernels/basic_tokenizer_op.h" +#include "minddata/dataset/text/kernels/bert_tokenizer_op.h" +#include "minddata/dataset/text/kernels/case_fold_op.h" +#include "minddata/dataset/text/kernels/normalize_utf8_op.h" +#include "minddata/dataset/text/kernels/regex_replace_op.h" +#include "minddata/dataset/text/kernels/regex_tokenizer_op.h" +#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h" +#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h" +#endif + +namespace py = pybind11; + +namespace mindspore { +namespace dataset { +#define THROW_IF_ERROR(s) \ + do { \ + Status rc = std::move(s); \ + if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ + } while (false) + +void bindDEPipeline(py::module *m) { + (void)py::class_(*m, "DEPipeline") + .def(py::init<>()) + .def( + "AddNodeToTree", + [](DEPipeline &de, const OpName &op_name, const py::dict &args) { + py::dict out; + THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); + return out; + }, + py::return_value_policy::reference) + .def_static("AddChildToParentNode", + [](const DsOpPtr &child_op, const DsOpPtr &parent_op) { + THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op)); + }) + .def("AssignRootNode", + [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) + .def("SetBatchParameters", + [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) + .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) + .def("GetNextAsMap", + [](DEPipeline &de) { + py::dict out; + THROW_IF_ERROR(de.GetNextAsMap(&out)); + return out; + }) + .def("GetNextAsList", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetNextAsList(&out)); + return out; + }) + .def("GetOutputShapes", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetOutputShapes(&out)); + return out; + }) + .def("GetOutputTypes", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetOutputTypes(&out)); + return out; + }) + .def("GetDatasetSize", &DEPipeline::GetDatasetSize) + .def("GetBatchSize", &DEPipeline::GetBatchSize) + .def("GetNumClasses", &DEPipeline::GetNumClasses) + .def("GetRepeatCount", &DEPipeline::GetRepeatCount); +} +void bindDatasetOps(py::module *m) { + (void)py::class_>(*m, "TFReaderOp") + .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) { + int64_t count = 0; + std::vector filenames; + for (auto l : files) { + !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate)); + return count; + }); + + (void)py::class_>(*m, "CifarOp") + .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) { + int64_t count = 0; + THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count)); + return count; + }); + + (void)py::class_>(*m, "ImageFolderOp") + .def_static("get_num_rows_and_classes", [](const std::string &path) { + int64_t count = 0, num_classes = 0; + THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set{}, &count, &num_classes)); + return py::make_tuple(count, num_classes); + }); + + (void)py::class_>(*m, "MindRecordOp") + .def_static("get_num_rows", [](const std::vector &paths, bool load_dataset, const py::object &sampler, + const int64_t num_padded) { + int64_t count = 0; + std::shared_ptr op; + if (py::hasattr(sampler, "create_for_minddataset")) { + auto create = sampler.attr("create_for_minddataset"); + op = create().cast>(); + } + THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded)); + return count; + }); + + (void)py::class_>(*m, "ManifestOp") + .def_static("get_num_rows_and_classes", + [](const std::string &file, const py::dict &dict, const std::string &usage) { + int64_t count = 0, num_classes = 0; + THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes)); + return py::make_tuple(count, num_classes); + }) + .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) { + std::map output_class_indexing; + THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing)); + return output_class_indexing; + }); + + (void)py::class_>(*m, "MnistOp") + .def_static("get_num_rows", [](const std::string &dir) { + int64_t count = 0; + THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count)); + return count; + }); + + (void)py::class_>(*m, "TextFileOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); + return count; + }); + + (void)py::class_>(*m, "ClueOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file)); + } + THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count)); + return count; + }); + + (void)py::class_>(*m, "VOCOp") + .def_static("get_num_rows", + [](const std::string &dir, const std::string &task_type, const std::string &task_mode, + const py::dict &dict, int64_t numSamples) { + int64_t count = 0; + THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count)); + return count; + }) + .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, + const std::string &task_mode, const py::dict &dict) { + std::map output_class_indexing; + THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing)); + return output_class_indexing; + }); + (void)py::class_>(*m, "CocoOp") + .def_static("get_class_indexing", + [](const std::string &dir, const std::string &file, const std::string &task) { + std::vector>> output_class_indexing; + THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing)); + return output_class_indexing; + }) + .def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) { + int64_t count = 0; + THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count)); + return count; + }); +} +void bindTensor(py::module *m) { + (void)py::class_(*m, "GlobalContext") + .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference); + + (void)py::class_>(*m, "ConfigManager") + .def("__str__", &ConfigManager::ToString) + .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer) + .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) + .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) + .def("set_op_connector_size", &ConfigManager::set_op_connector_size) + .def("set_seed", &ConfigManager::set_seed) + .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) + .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer) + .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers) + .def("get_worker_connector_size", &ConfigManager::worker_connector_size) + .def("get_op_connector_size", &ConfigManager::op_connector_size) + .def("get_seed", &ConfigManager::seed) + .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) + .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); + + (void)py::class_>(*m, "Tensor", py::buffer_protocol()) + .def(py::init([](py::array arr) { + std::shared_ptr out; + THROW_IF_ERROR(Tensor::CreateTensor(&out, arr)); + return out; + })) + .def_buffer([](Tensor &tensor) { + py::buffer_info info; + THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); + return info; + }) + .def("__str__", &Tensor::ToString) + .def("shape", &Tensor::shape) + .def("type", &Tensor::type) + .def("as_array", [](py::object &t) { + auto &tensor = py::cast(t); + if (tensor.type() == DataType::DE_STRING) { + py::array res; + tensor.GetDataAsNumpyStrings(&res); + return res; + } + py::buffer_info info; + THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); + return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t); + }); + + (void)py::class_(*m, "TensorShape") + .def(py::init()) + .def("__str__", &TensorShape::ToString) + .def("as_list", &TensorShape::AsPyList) + .def("is_known", &TensorShape::known); + + (void)py::class_(*m, "DataType") + .def(py::init()) + .def(py::self == py::self) + .def("__str__", &DataType::ToString) + .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); +} + +void bindTensorOps1(py::module *m) { + (void)py::class_>(*m, "TensorOp") + .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); + + (void)py::class_>( + *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.") + .def(py::init(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), + py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); + + (void)py::class_>( + *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.") + .def(py::init(), py::arg("rescale"), py::arg("shift")); + + (void)py::class_>( + *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)") + .def(py::init(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth); + + (void)py::class_>( + *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); + + (void)py::class_>( + *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth, + py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation); + + (void)py::class_>( + *m, "RandomResizeWithBBoxOp", + "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth); + + (void)py::class_>( + *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") + .def(py::init>, int32_t>(), py::arg("operations"), + py::arg("NumOps") = UniformAugOp::kDefNumOps); + + (void)py::class_>( + *m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.") + .def(py::init, float>(), py::arg("transform"), + py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio); + + (void)py::class_>( + *m, "ResizeBilinearOp", + "Tensor operation to resize an image using " + "Bilinear mode. Takes height and width.") + .def(py::init(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth); + + (void)py::class_>(*m, "DecodeOp", + "Tensor operation to decode a jpg image") + .def(py::init<>()) + .def(py::init(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat); + + (void)py::class_>( + *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.") + .def(py::init(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability); + + (void)py::class_>( + *m, "RandomHorizontalFlipWithBBoxOp", + "Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.") + .def(py::init(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability); +} + +void bindTensorOps2(py::module *m) { + (void)py::class_>( + *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.") + .def(py::init(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability); + + (void)py::class_>( + *m, "RandomVerticalFlipWithBBoxOp", + "Tensor operation to randomly flip an image vertically" + " and adjust bounding boxes.") + .def(py::init(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability); + + (void)py::class_>(*m, "RandomCropOp", + "Gives random crop of specified size " + "Takes crop size") + .def(py::init(), + py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop, + py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft, + py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType, + py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR, + py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB); + (void)py::class_>(*m, "ChannelSwapOp").def(py::init<>()); + + (void)py::class_>(*m, "RandomCropWithBBoxOp", + "Gives random crop of given " + "size + adjusts bboxes " + "Takes crop size") + .def(py::init(), + py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop, + py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom, + py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft, + py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight, + py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType, + py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded, + py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG, + py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB); + + (void)py::class_>( + *m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.") + .def(py::init()); + + (void)py::class_>( + *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") + .def(py::init>()); + + (void)py::class_>(*m, "SliceOp", "Tensor slice operation.") + .def(py::init()) + .def(py::init([](const py::list &py_list) { + std::vector c_list; + for (auto l : py_list) { + if (!l.is_none()) { + c_list.push_back(py::reinterpret_borrow(l)); + } + } + return std::make_shared(c_list); + })) + .def(py::init([](const py::tuple &py_slice) { + if (py_slice.size() != 3) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); + } + Slice c_slice; + if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1]), + py::reinterpret_borrow(py_slice[2])); + } else if (py_slice[0].is_none() && py_slice[2].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[1])); + } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1])); + } + + if (!c_slice.valid()) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); + } + return std::make_shared(c_slice); + })); + + (void)py::enum_(*m, "RelationalOp", py::arithmetic()) + .value("EQ", RelationalOp::kEqual) + .value("NE", RelationalOp::kNotEqual) + .value("LT", RelationalOp::kLess) + .value("LE", RelationalOp::kLessEqual) + .value("GT", RelationalOp::kGreater) + .value("GE", RelationalOp::kGreaterEqual) + .export_values(); + + (void)py::class_>(*m, "MaskOp", + "Tensor mask operation using relational comparator") + .def(py::init, DataType>()); + + (void)py::class_>(*m, "DuplicateOp", "Duplicate tensor.") + .def(py::init<>()); + + (void)py::class_>( + *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") + .def(py::init()); + + (void)py::class_>(*m, "ConcatenateOp", + "Tensor operation concatenate tensors.") + .def(py::init, std::shared_ptr>(), py::arg("axis"), + py::arg("prepend").none(true), py::arg("append").none(true)); + + (void)py::class_>( + *m, "RandomRotationOp", + "Tensor operation to apply RandomRotation." + "Takes a range for degrees and " + "optional parameters for rotation center and image expand") + .def(py::init(), + py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX, + py::arg("centerY") = RandomRotationOp::kDefCenterY, + py::arg("interpolation") = RandomRotationOp::kDefInterpolation, + py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR, + py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB); + + (void)py::class_>( + *m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.") + .def(py::init>()); +} + +void bindTensorOps3(py::module *m) { + (void)py::class_>( + *m, "RandomCropAndResizeOp", + "Tensor operation to randomly crop an image and resize to a given size." + "Takes output height and width and" + "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," + "interpolation mode, and max attempts to crop") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb, + py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation, + py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter); + + (void)py::class_>( + *m, "RandomCropAndResizeWithBBoxOp", + "Tensor operation to randomly crop an image (with BBoxes) and resize to a given size." + "Takes output height and width and" + "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," + "interpolation mode, and max attempts to crop") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb, + py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation, + py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter); + + (void)py::class_>( + *m, "RandomColorAdjustOp", + "Tensor operation to adjust an image's color randomly." + "Takes range for brightness, contrast, saturation, hue and") + .def(py::init(), py::arg("bright_factor_start"), + py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"), + py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"), + py::arg("hue_factor_end")); + + (void)py::class_>( + *m, "RandomResizeOp", + "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); + + (void)py::class_>( + *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. Takes height and width.") + .def(py::init(), py::arg("boxHeight"), + py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor, + py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG, + py::arg("fillB") = CutOutOp::kDefFillB); +} + +void bindTensorOps4(py::module *m) { + (void)py::class_>( + *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") + .def(py::init(), py::arg("data_type")) + .def(py::init(), py::arg("data_type")); + + (void)py::class_>(*m, "NoOp", + "TensorOp that does nothing, for testing purposes only.") + .def(py::init<>()); + + (void)py::class_>( + *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.") + .def(py::init<>()); + + (void)py::class_>( + *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb, + py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation, + py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter); + + (void)py::class_>( + *m, "PadOp", + "Pads image with specified color, default black, " + "Takes amount to pad for top, bottom, left, right of image, boarder type and color") + .def(py::init(), py::arg("padTop"), + py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType, + py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB); + (void)py::class_>(*m, "ToNumberOp", + "TensorOp to convert strings to numbers.") + .def(py::init(), py::arg("data_type")) + .def(py::init(), py::arg("data_type")); +} + +void bindTokenizerOps(py::module *m) { + (void)py::class_>(*m, "JiebaTokenizerOp", "") + .def(py::init(), py::arg("hmm_path"), + py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix, + py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets) + .def("add_word", + [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); }); + (void)py::class_>( + *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.") + .def(py::init(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets); + (void)py::class_>(*m, "LookupOp", + "Tensor operation to LookUp each word.") + .def(py::init([](std::shared_ptr vocab, const py::object &py_word) { + if (vocab == nullptr) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null.")); + } + if (py_word.is_none()) { + return std::make_shared(vocab, Vocab::kNoTokenExists); + } + std::string word = py::reinterpret_borrow(py_word); + WordIdType default_id = vocab->Lookup(word); + if (default_id == Vocab::kNoTokenExists) { + THROW_IF_ERROR( + Status(StatusCode::kUnexpectedError, "default unknown token:" + word + " doesn't exist in vocab.")); + } + return std::make_shared(vocab, default_id); + })); + (void)py::class_>(*m, "NgramOp", "TensorOp performs ngram mapping.") + .def(py::init &, int32_t, int32_t, const std::string &, const std::string &, + const std::string &>(), + py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"), + py::arg("separator")); + (void)py::class_>( + *m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.") + .def( + py::init &, const std::string &, const int &, const std::string &, const bool &>(), + py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), + py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, + py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), + py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); +} + +void bindDependIcuTokenizerOps(py::module *m) { +#ifdef ENABLE_ICU4C + (void)py::class_>( + *m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.") + .def(py::init(), py::arg("with_offsets") = WhitespaceTokenizerOp::kDefWithOffsets); + (void)py::class_>( + *m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.") + .def(py::init<>()) + .def(py::init(), + py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace, + py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets); + (void)py::class_>( + *m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor") + .def(py::init<>()); + (void)py::class_>( + *m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.") + .def(py::init<>()) + .def(py::init(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm); + (void)py::class_>( + *m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.") + .def(py::init(), py::arg("pattern"), py::arg("replace"), + py::arg("replace_all")); + (void)py::class_>( + *m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.") + .def(py::init(), py::arg("delim_pattern"), + py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets); + (void)py::class_>( + *m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.") + .def(py::init(), + py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, + py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, + py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, + py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, + py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets); + (void)py::class_>(*m, "BertTokenizerOp", + "Tokenizer used for Bert text process.") + .def(py::init &, const std::string &, const int &, const std::string &, const bool &, + const bool &, const NormalizeForm &, const bool &, const bool &>(), + py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), + py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, + py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), + py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, + py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, + py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, + py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, + py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); +#endif +} + +void bindSamplerOps(py::module *m) { + (void)py::class_>(*m, "Sampler") + .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) + .def("set_num_samples", [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) + .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) + .def("get_indices", + [](Sampler &self) { + py::array ret; + THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); + return ret; + }) + .def("add_child", + [](std::shared_ptr self, std::shared_ptr child) { THROW_IF_ERROR(self->AddChild(child)); }); + + (void)py::class_>(*m, "ShardOperator") + .def("add_child", [](std::shared_ptr self, + std::shared_ptr child) { self->SetChildOp(child); }); + + (void)py::class_>(*m, "DistributedSampler") + .def(py::init()); + + (void)py::class_>(*m, "PKSampler") + .def(py::init()); + + (void)py::class_>(*m, "RandomSampler") + .def(py::init()); + + (void)py::class_>(*m, "SequentialSampler") + .def(py::init()); + + (void)py::class_>(*m, "SubsetRandomSampler") + .def(py::init>()); + + (void)py::class_>( + *m, "MindrecordSubsetRandomSampler") + .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); + + (void)py::class_>( + *m, "MindrecordPkSampler") + .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { + if (shuffle == true) { + return std::make_shared(kColumn, kVal, std::numeric_limits::max(), + GetSeed()); + } else { + return std::make_shared(kColumn, kVal); + } + })); + + (void)py::class_>(*m, "MindrecordDistributedSampler") + .def(py::init()); + + (void)py::class_>( + *m, "MindrecordRandomSampler") + .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) { + return std::make_shared(GetSeed(), num_samples, replacement, reshuffle_each_epoch); + })); + + (void)py::class_>(*m, "MindrecordSequentialSampler") + .def(py::init([](int num_samples, int start_index) { + return std::make_shared(num_samples, start_index); + })); + + (void)py::class_>(*m, "WeightedRandomSampler") + .def(py::init, bool>()); + + (void)py::class_>(*m, "PythonSampler") + .def(py::init()); +} + +void bindInfoObjects(py::module *m) { + (void)py::class_(*m, "CBatchInfo") + .def(py::init()) + .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num) + .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); +} + +void bindCacheClient(py::module *m) { + (void)py::class_>(*m, "CacheClient") + .def(py::init()); +} + +void bindVocabObjects(py::module *m) { + (void)py::class_>(*m, "Vocab") + .def(py::init<>()) + .def_static("from_list", + [](const py::list &words, const py::list &special_tokens, bool special_first) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v)); + return v; + }) + .def_static("from_file", + [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens, + bool special_first) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v)); + return v; + }) + .def_static("from_dict", [](const py::dict &words) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v)); + return v; + }); +} + +void bindGraphData(py::module *m) { + (void)py::class_>(*m, "Graph") + .def(py::init([](std::string dataset_file, int32_t num_workers) { + std::shared_ptr g_out = std::make_shared(dataset_file, num_workers); + THROW_IF_ERROR(g_out->Init()); + return g_out; + })) + .def("get_all_nodes", + [](gnn::Graph &g, gnn::NodeType node_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); + return out; + }) + .def("get_all_edges", + [](gnn::Graph &g, gnn::EdgeType edge_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); + return out; + }) + .def("get_nodes_from_edges", + [](gnn::Graph &g, std::vector edge_list) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); + return out; + }) + .def("get_all_neighbors", + [](gnn::Graph &g, std::vector node_list, gnn::NodeType neighbor_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); + return out; + }) + .def("get_sampled_neighbors", + [](gnn::Graph &g, std::vector node_list, std::vector neighbor_nums, + std::vector neighbor_types) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); + return out; + }) + .def("get_neg_sampled_neighbors", + [](gnn::Graph &g, std::vector node_list, gnn::NodeIdType neighbor_num, + gnn::NodeType neg_neighbor_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); + return out; + }) + .def("get_node_feature", + [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { + TensorRow out; + THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); + return out.getRow(); + }) + .def("get_edge_feature", + [](gnn::Graph &g, std::shared_ptr edge_list, std::vector feature_types) { + TensorRow out; + THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); + return out.getRow(); + }) + .def("graph_info", + [](gnn::Graph &g) { + py::dict out; + THROW_IF_ERROR(g.GraphInfo(&out)); + return out; + }) + .def("random_walk", [](gnn::Graph &g, std::vector node_list, std::vector meta_path, + float step_home_param, float step_away_param, gnn::NodeIdType default_node) { + std::shared_ptr out; + THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); + return out; + }); +} + +// This is where we externalize the C logic as python modules +PYBIND11_MODULE(_c_dataengine, m) { + m.doc() = "pybind11 for _c_dataengine"; + (void)py::class_>(m, "DatasetOp"); + + (void)py::enum_(m, "OpName", py::arithmetic()) + .value("SHUFFLE", OpName::kShuffle) + .value("BATCH", OpName::kBatch) + .value("BUCKETBATCH", OpName::kBucketBatch) + .value("BARRIER", OpName::kBarrier) + .value("MINDRECORD", OpName::kMindrecord) + .value("CACHE", OpName::kCache) + .value("REPEAT", OpName::kRepeat) + .value("SKIP", OpName::kSkip) + .value("TAKE", OpName::kTake) + .value("ZIP", OpName::kZip) + .value("CONCAT", OpName::kConcat) + .value("MAP", OpName::kMap) + .value("FILTER", OpName::kFilter) + .value("DEVICEQUEUE", OpName::kDeviceQueue) + .value("GENERATOR", OpName::kGenerator) + .export_values() + .value("RENAME", OpName::kRename) + .value("TFREADER", OpName::kTfReader) + .value("PROJECT", OpName::kProject) + .value("IMAGEFOLDER", OpName::kImageFolder) + .value("MNIST", OpName::kMnist) + .value("MANIFEST", OpName::kManifest) + .value("VOC", OpName::kVoc) + .value("COCO", OpName::kCoco) + .value("CIFAR10", OpName::kCifar10) + .value("CIFAR100", OpName::kCifar100) + .value("RANDOMDATA", OpName::kRandomData) + .value("BUILDVOCAB", OpName::kBuildVocab) + .value("CELEBA", OpName::kCelebA) + .value("TEXTFILE", OpName::kTextFile) + .value("CLUE", OpName::kClue); + + (void)py::enum_(m, "JiebaMode", py::arithmetic()) + .value("DE_JIEBA_MIX", JiebaMode::kMix) + .value("DE_JIEBA_MP", JiebaMode::kMp) + .value("DE_JIEBA_HMM", JiebaMode::kHmm) + .export_values(); + +#ifdef ENABLE_ICU4C + (void)py::enum_(m, "NormalizeForm", py::arithmetic()) + .value("DE_NORMALIZE_NONE", NormalizeForm::kNone) + .value("DE_NORMALIZE_NFC", NormalizeForm::kNfc) + .value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc) + .value("DE_NORMALIZE_NFD", NormalizeForm::kNfd) + .value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd) + .export_values(); +#endif + + (void)py::enum_(m, "InterpolationMode", py::arithmetic()) + .value("DE_INTER_LINEAR", InterpolationMode::kLinear) + .value("DE_INTER_CUBIC", InterpolationMode::kCubic) + .value("DE_INTER_AREA", InterpolationMode::kArea) + .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) + .export_values(); + + (void)py::enum_(m, "BorderType", py::arithmetic()) + .value("DE_BORDER_CONSTANT", BorderType::kConstant) + .value("DE_BORDER_EDGE", BorderType::kEdge) + .value("DE_BORDER_REFLECT", BorderType::kReflect) + .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric) + .export_values(); + bindDEPipeline(&m); + bindTensor(&m); + bindTensorOps1(&m); + bindTensorOps2(&m); + bindTensorOps3(&m); + bindTensorOps4(&m); + bindTokenizerOps(&m); + bindSamplerOps(&m); + bindDatasetOps(&m); + bindInfoObjects(&m); + bindCacheClient(&m); + bindVocabObjects(&m); + bindGraphData(&m); + bindDependIcuTokenizerOps(&m); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc new file mode 100644 index 0000000000..91421f0ff8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -0,0 +1,224 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/include/samplers.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" + +namespace mindspore { +namespace dataset { +namespace api { + +SamplerObj::SamplerObj() {} + +/// Function to create a Distributed Sampler. +std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, + int64_t num_samples, uint32_t seed) { + auto sampler = std::make_shared(num_shards, shard_id, shuffle, num_samples, seed); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a PK Sampler. +std::shared_ptr PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) { + auto sampler = std::make_shared(num_val, shuffle, num_samples); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a Random Sampler. +std::shared_ptr RandomSampler(bool replacement, int64_t num_samples) { + auto sampler = std::make_shared(replacement, num_samples); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a Sequential Sampler. +std::shared_ptr SequentialSampler(int64_t start_index, int64_t num_samples) { + auto sampler = std::make_shared(start_index, num_samples); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a Subset Random Sampler. +std::shared_ptr SubsetRandomSampler(const std::vector &indices, int64_t num_samples) { + auto sampler = std::make_shared(indices, num_samples); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a Weighted Random Sampler. +std::shared_ptr WeightedRandomSampler(const std::vector &weights, int64_t num_samples, + bool replacement) { + auto sampler = std::make_shared(weights, num_samples, replacement); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/* ####################################### Derived Sampler classes ################################# */ + +// DistributedSampler +DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, + uint32_t seed) + : num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {} + +bool DistributedSamplerObj::ValidateParams() { + if (num_shards_ <= 0) { + MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_; + return false; + } + + if (shard_id_ < 0 || shard_id_ >= num_shards_) { + MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; + return false; + } + + if (num_samples_ < 0) { + MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_; + return false; + } + + return true; +} + +std::shared_ptr DistributedSamplerObj::Build() { + return std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_); +} + +// PKSampler +PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) + : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} + +bool PKSamplerObj::ValidateParams() { + if (num_val_ <= 0) { + MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_; + return false; + } + + if (num_samples_ < 0) { + MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_; + return false; + } + return true; +} + +std::shared_ptr PKSamplerObj::Build() { + return std::make_shared(num_samples_, num_val_, shuffle_); +} + +// RandomSampler +RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) + : replacement_(replacement), num_samples_(num_samples) {} + +bool RandomSamplerObj::ValidateParams() { + if (num_samples_ < 0) { + MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_; + return false; + } + return true; +} + +std::shared_ptr RandomSamplerObj::Build() { + bool reshuffle_each_epoch = true; + auto sampler = std::make_shared(num_samples_, replacement_, reshuffle_each_epoch); + return sampler; +} + +// SequentialSampler +SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) + : start_index_(start_index), num_samples_(num_samples) {} + +bool SequentialSamplerObj::ValidateParams() { + if (num_samples_ < 0) { + MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_; + return false; + } + + if (start_index_ < 0) { + MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_; + return false; + } + + return true; +} + +std::shared_ptr SequentialSamplerObj::Build() { + auto sampler = std::make_shared(num_samples_, start_index_); + return sampler; +} + +// SubsetRandomSampler +SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector &indices, int64_t num_samples) + : indices_(indices), num_samples_(num_samples) {} + +bool SubsetRandomSamplerObj::ValidateParams() { + if (num_samples_ < 0) { + MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_; + return false; + } + + return true; +} + +std::shared_ptr SubsetRandomSamplerObj::Build() { + auto sampler = std::make_shared(num_samples_, indices_); + return sampler; +} + +// WeightedRandomSampler +WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector &weights, int64_t num_samples, + bool replacement) + : weights_(weights), num_samples_(num_samples), replacement_(replacement) {} + +bool WeightedRandomSamplerObj::ValidateParams() { + if (num_samples_ < 0) { + MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_; + return false; + } + return true; +} + +std::shared_ptr WeightedRandomSamplerObj::Build() { + auto sampler = std::make_shared(num_samples_, weights_, replacement_); + return sampler; +} + +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc new file mode 100644 index 0000000000..59a25ef9f5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -0,0 +1,491 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/include/transforms.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/normalize_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_op.h" +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/kernels/image/uniform_aug_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" +#include "minddata/dataset/kernels/image/cut_out_op.h" +#include "minddata/dataset/kernels/image/random_color_adjust_op.h" +#include "minddata/dataset/kernels/image/pad_op.h" + +namespace mindspore { +namespace dataset { +namespace api { + +TensorOperation::TensorOperation() {} + +// Transform operations for computer vision. +namespace vision { + +// Function to create NormalizeOperation. +std::shared_ptr Normalize(std::vector mean, std::vector std) { + auto op = std::make_shared(mean, std); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create DecodeOperation. +std::shared_ptr Decode(bool rgb) { + auto op = std::make_shared(rgb); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create ResizeOperation. +std::shared_ptr Resize(std::vector size, InterpolationMode interpolation) { + auto op = std::make_shared(size, interpolation); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomCropOperation. +std::shared_ptr RandomCrop(std::vector size, std::vector padding, + bool pad_if_needed, std::vector fill_value) { + auto op = std::make_shared(size, padding, pad_if_needed, fill_value); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create CenterCropOperation. +std::shared_ptr CenterCrop(std::vector size) { + auto op = std::make_shared(size); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create UniformAugOperation. +std::shared_ptr UniformAugment(std::vector> operations, + int32_t num_ops) { + auto op = std::make_shared(operations, num_ops); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomHorizontalFlipOperation. +std::shared_ptr RandomHorizontalFlip(float prob) { + auto op = std::make_shared(prob); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomVerticalFlipOperation. +std::shared_ptr RandomVerticalFlip(float prob) { + auto op = std::make_shared(prob); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomRotationOperation. +std::shared_ptr RandomRotation(std::vector degrees, InterpolationMode resample, + bool expand, std::vector center, + std::vector fill_value) { + auto op = std::make_shared(degrees, resample, expand, center, fill_value); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create PadOperation. +std::shared_ptr Pad(std::vector padding, std::vector fill_value, + BorderType padding_mode) { + auto op = std::make_shared(padding, fill_value, padding_mode); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create CutOutOp. +std::shared_ptr CutOut(int32_t length, int32_t num_patches) { + auto op = std::make_shared(length, num_patches); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomColorAdjustOperation. +std::shared_ptr RandomColorAdjust(std::vector brightness, + std::vector contrast, + std::vector saturation, std::vector hue) { + auto op = std::make_shared(brightness, contrast, saturation, hue); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +/* ####################################### Derived TensorOperation classes ################################# */ + +// NormalizeOperation +NormalizeOperation::NormalizeOperation(std::vector mean, std::vector std) : mean_(mean), std_(std) {} + +bool NormalizeOperation::ValidateParams() { + if (mean_.size() != 3) { + MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size(); + return false; + } + + if (std_.size() != 3) { + MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size(); + return false; + } + + return true; +} + +std::shared_ptr NormalizeOperation::Build() { + return std::make_shared(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); +} + +// DecodeOperation +DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {} + +bool DecodeOperation::ValidateParams() { return true; } + +std::shared_ptr DecodeOperation::Build() { return std::make_shared(rgb_); } + +// ResizeOperation +ResizeOperation::ResizeOperation(std::vector size, InterpolationMode interpolation) + : size_(size), interpolation_(interpolation) {} + +bool ResizeOperation::ValidateParams() { + if (size_.empty() || size_.size() > 2) { + MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size(); + return false; + } + return true; +} + +std::shared_ptr ResizeOperation::Build() { + int32_t height = size_[0]; + int32_t width = 0; + + // User specified the width value. + if (size_.size() == 2) { + width = size_[1]; + } + + return std::make_shared(height, width, interpolation_); +} + +// RandomCropOperation +RandomCropOperation::RandomCropOperation(std::vector size, std::vector padding, bool pad_if_needed, + std::vector fill_value) + : size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {} + +bool RandomCropOperation::ValidateParams() { + if (size_.empty() || size_.size() > 2) { + MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size(); + return false; + } + + if (padding_.empty() || padding_.size() != 4) { + MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()"; + return false; + } + + if (fill_value_.empty() || fill_value_.size() != 3) { + MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()"; + return false; + } + return true; +} + +std::shared_ptr RandomCropOperation::Build() { + int32_t crop_height = size_[0]; + int32_t crop_width = 0; + + int32_t pad_top = padding_[0]; + int32_t pad_bottom = padding_[1]; + int32_t pad_left = padding_[2]; + int32_t pad_right = padding_[3]; + + uint8_t fill_r = fill_value_[0]; + uint8_t fill_g = fill_value_[1]; + uint8_t fill_b = fill_value_[2]; + + // User has specified the crop_width value. + if (size_.size() == 2) { + crop_width = size_[1]; + } + + auto tensor_op = std::make_shared(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, + BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b); + return tensor_op; +} + +// CenterCropOperation +CenterCropOperation::CenterCropOperation(std::vector size) : size_(size) {} + +bool CenterCropOperation::ValidateParams() { + if (size_.empty() || size_.size() > 2) { + MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size."; + return false; + } + return true; +} + +std::shared_ptr CenterCropOperation::Build() { + int32_t crop_height = size_[0]; + int32_t crop_width = 0; + + // User has specified crop_width. + if (size_.size() == 2) { + crop_width = size_[1]; + } + + std::shared_ptr tensor_op = std::make_shared(crop_height, crop_width); + return tensor_op; +} + +// UniformAugOperation +UniformAugOperation::UniformAugOperation(std::vector> operations, int32_t num_ops) + : operations_(operations), num_ops_(num_ops) {} + +bool UniformAugOperation::ValidateParams() { return true; } + +std::shared_ptr UniformAugOperation::Build() { + std::vector> tensor_ops; + (void)std::transform(operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), + [](std::shared_ptr op) -> std::shared_ptr { return op->Build(); }); + std::shared_ptr tensor_op = std::make_shared(tensor_ops, num_ops_); + return tensor_op; +} + +// RandomHorizontalFlipOperation +RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {} + +bool RandomHorizontalFlipOperation::ValidateParams() { return true; } + +std::shared_ptr RandomHorizontalFlipOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(probability_); + return tensor_op; +} + +// RandomVerticalFlipOperation +RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {} + +bool RandomVerticalFlipOperation::ValidateParams() { return true; } + +std::shared_ptr RandomVerticalFlipOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(probability_); + return tensor_op; +} + +// Function to create RandomRotationOperation. +RandomRotationOperation::RandomRotationOperation(std::vector degrees, InterpolationMode interpolation_mode, + bool expand, std::vector center, + std::vector fill_value) + : degrees_(degrees), + interpolation_mode_(interpolation_mode), + expand_(expand), + center_(center), + fill_value_(fill_value) {} + +bool RandomRotationOperation::ValidateParams() { + if (degrees_.empty() || degrees_.size() != 2) { + MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()"; + return false; + } + if (center_.empty() || center_.size() != 2) { + MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()"; + return false; + } + if (fill_value_.empty() || fill_value_.size() != 3) { + MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()"; + return false; + } + return true; +} + +std::shared_ptr RandomRotationOperation::Build() { + std::shared_ptr tensor_op = + std::make_shared(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_, + fill_value_[0], fill_value_[1], fill_value_[2]); + return tensor_op; +} + +// PadOperation +PadOperation::PadOperation(std::vector padding, std::vector fill_value, BorderType padding_mode) + : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {} + +bool PadOperation::ValidateParams() { + if (padding_.empty() || padding_.size() == 3 || padding_.size() > 4) { + MS_LOG(ERROR) << "Pad: padding vector has incorrect size: padding.size()"; + return false; + } + + if (fill_value_.empty() || (fill_value_.size() != 1 && fill_value_.size() != 3)) { + MS_LOG(ERROR) << "Pad: fill_value vector has incorrect size: fill_value.size()"; + return false; + } + return true; +} + +std::shared_ptr PadOperation::Build() { + int32_t pad_top, pad_bottom, pad_left, pad_right; + switch (padding_.size()) { + case 1: + pad_left = padding_[0]; + pad_top = padding_[0]; + pad_right = padding_[0]; + pad_bottom = padding_[0]; + break; + case 2: + pad_left = padding_[0]; + pad_top = padding_[1]; + pad_right = padding_[0]; + pad_bottom = padding_[1]; + break; + default: + pad_left = padding_[0]; + pad_top = padding_[1]; + pad_right = padding_[2]; + pad_bottom = padding_[3]; + } + uint8_t fill_r, fill_g, fill_b; + + fill_r = fill_value_[0]; + fill_g = fill_value_[0]; + fill_b = fill_value_[0]; + + if (fill_value_.size() == 3) { + fill_r = fill_value_[0]; + fill_g = fill_value_[1]; + fill_b = fill_value_[2]; + } + + std::shared_ptr tensor_op = + std::make_shared(pad_top, pad_bottom, pad_left, pad_right, padding_mode_, fill_r, fill_g, fill_b); + return tensor_op; +} + +// CutOutOperation +CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {} + +bool CutOutOperation::ValidateParams() { + if (length_ < 0) { + MS_LOG(ERROR) << "CutOut: length cannot be negative"; + return false; + } + if (num_patches_ < 0) { + MS_LOG(ERROR) << "CutOut: number of patches cannot be negative"; + return false; + } + return true; +} + +std::shared_ptr CutOutOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(length_, length_, num_patches_, false, 0, 0, 0); + return tensor_op; +} + +// RandomColorAdjustOperation. +RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector brightness, std::vector contrast, + std::vector saturation, std::vector hue) + : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {} + +bool RandomColorAdjustOperation::ValidateParams() { + // Do some input validation. + if (brightness_.empty() || brightness_.size() > 2) { + MS_LOG(ERROR) << "RandomColorAdjust: brightness must be a vector of one or two values"; + return false; + } + if (contrast_.empty() || contrast_.size() > 2) { + MS_LOG(ERROR) << "RandomColorAdjust: contrast must be a vector of one or two values"; + return false; + } + if (saturation_.empty() || saturation_.size() > 2) { + MS_LOG(ERROR) << "RandomColorAdjust: saturation must be a vector of one or two values"; + return false; + } + if (hue_.empty() || hue_.size() > 2) { + MS_LOG(ERROR) << "RandomColorAdjust: hue must be a vector of one or two values"; + return false; + } + return true; +} + +std::shared_ptr RandomColorAdjustOperation::Build() { + float brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub; + + brightness_lb = brightness_[0]; + brightness_ub = brightness_[0]; + + if (brightness_.size() == 2) brightness_ub = brightness_[1]; + + contrast_lb = contrast_[0]; + contrast_ub = contrast_[0]; + + if (contrast_.size() == 2) contrast_ub = contrast_[1]; + + saturation_lb = saturation_[0]; + saturation_ub = saturation_[0]; + + if (saturation_.size() == 2) saturation_ub = saturation_[1]; + + hue_lb = hue_[0]; + hue_ub = hue_[0]; + + if (hue_.size() == 2) hue_ub = hue_[1]; + + std::shared_ptr tensor_op = std::make_shared( + brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub); + return tensor_op; +} + +} // namespace vision +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt new file mode 100644 index 0000000000..bfe6e67563 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt @@ -0,0 +1,21 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +set(DATASET_CORE_SRC_FILES + client.cc + config_manager.cc + cv_tensor.cc + data_type.cc + global_context.cc + tensor.cc + tensor_row.cc + tensor_shape.cc +) + +ms_protobuf_generate(EXAMPLE_SRCS EXAMPLE_HDRS example.proto) +ms_protobuf_generate(FEATURE_SRCS FEATURE_HDRS feature.proto) +add_library(core OBJECT ${DATASET_CORE_SRC_FILES} ${EXAMPLE_SRCS} ${FEATURE_SRCS}) +add_dependencies(core mindspore::protobuf) + +if (ENABLE_PYTHON) + target_include_directories(core PRIVATE ${pybind11_INCLUDE_DIRS}) +endif() diff --git a/mindspore/ccsrc/minddata/dataset/core/client.cc b/mindspore/ccsrc/minddata/dataset/core/client.cc new file mode 100644 index 0000000000..e3fd844e66 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/client.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/sig_handler.h" + +namespace mindspore { +namespace dataset { +// This is a one-time global initializer which includes the call to instantiate singletons. +// It is external api call and not a member of the GlobalContext directly. +Status GlobalInit() { + // Bring up all the services (logger, task, bufferpool) + return (Services::CreateInstance()); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/client.h b/mindspore/ccsrc/minddata/dataset/core/client.h new file mode 100644 index 0000000000..78b298e616 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/client.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_CLIENT_H_ +#define DATASET_CORE_CLIENT_H_ + +// client.h +// Include file for DE client functions + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" + +#ifdef ENABLE_PYTHON +#include "minddata/dataset/engine/datasetops/barrier_op.h" +#include "minddata/dataset/engine/datasetops/filter_op.h" +#include "minddata/dataset/engine/datasetops/source/generator_op.h" +#include "minddata/dataset/engine/datasetops/build_vocab_op.h" +#endif + +#include "minddata/dataset/engine/datasetops/batch_op.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" +#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/project_op.h" +#include "minddata/dataset/engine/datasetops/rename_op.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/datasetops/skip_op.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/engine/datasetops/take_op.h" +#include "minddata/dataset/engine/datasetops/zip_op.h" +#include "minddata/dataset/engine/datasetops/concat_op.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// This is a one-time global initializer that needs to be called at the +// start of any minddata applications. +extern Status GlobalInit(); +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_CORE_CLIENT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc new file mode 100644 index 0000000000..e1fc7f29ba --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/config_manager.h" + +#include +#include +#include + +#include "minddata/dataset/util/system_pool.h" + +namespace mindspore { +namespace dataset { +// A print method typically used for debugging +void ConfigManager::Print(std::ostream &out) const { + // Don't show the test/internal ones. Only display the main ones here. + // fyi, boolalpha tells the output stream to write "true" and "false" for bools + out << "\nClient config settings :" + << "\nDataCache Rows per buffer : " << rows_per_buffer_ + << "\nParallelOp workers : " << num_parallel_workers_ + << "\nParallelOp worker connector size : " << worker_connector_size_ + << "\nSize of each Connector : " << op_connector_size_ << std::endl; +} + +// Private helper function that taks a nlohmann json format and populates the settings +Status ConfigManager::FromJson(const nlohmann::json &j) { + set_rows_per_buffer(j.value("rowsPerBuffer", rows_per_buffer_)); + set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_)); + set_worker_connector_size(j.value("workerConnectorSize", worker_connector_size_)); + set_op_connector_size(j.value("opConnectorSize", op_connector_size_)); + set_seed(j.value("seed", seed_)); + set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_)); + return Status::OK(); +} + +// Loads a json file with the default settings and populates all the settings +Status ConfigManager::LoadFile(const std::string &settingsFile) { + Status rc; + if (!Path(settingsFile).Exists()) { + RETURN_STATUS_UNEXPECTED("File is not found."); + } + // Some settings are mandatory, others are not (with default). If a setting + // is optional it will set a default value if the config is missing from the file. + try { + std::ifstream in(settingsFile); + nlohmann::json js; + in >> js; + rc = FromJson(js); + } catch (const nlohmann::json::type_error &e) { + std::ostringstream ss; + ss << "Client file failed to load:\n" << e.what(); + std::string err_msg = ss.str(); + RETURN_STATUS_UNEXPECTED(err_msg); + } catch (const std::exception &err) { + RETURN_STATUS_UNEXPECTED("Client file failed to load."); + } + return rc; +} + +// Setter function +void ConfigManager::set_rows_per_buffer(int32_t rows_per_buffer) { rows_per_buffer_ = rows_per_buffer; } + +// Setter function +void ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) { + num_parallel_workers_ = num_parallel_workers; +} + +// Setter function +void ConfigManager::set_worker_connector_size(int32_t connector_size) { worker_connector_size_ = connector_size; } + +// Setter function +void ConfigManager::set_op_connector_size(int32_t connector_size) { op_connector_size_ = connector_size; } + +uint32_t ConfigManager::seed() const { return seed_; } + +void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; } + +void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.h b/mindspore/ccsrc/minddata/dataset/core/config_manager.h new file mode 100644 index 0000000000..a8e1907c41 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.h @@ -0,0 +1,137 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_CONFIG_MANAGER_H_ +#define DATASET_CORE_CONFIG_MANAGER_H_ + +#include +#include +#include + +#include + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" + +// Config settings for the client-side +// example config file: +// { +// "rowsPerBuffer": 3 +// } +// + +namespace mindspore { +namespace dataset { +// The ConfigManager is a class for managing default values. When a user is constructing any objects +// in the framework, often they may choose to omit some settings instead of overriding them. +// This class manages some of the default values, for cases when the user does not manually specify +// those values. +class ConfigManager { + public: + ConfigManager() = default; + + // destructor + ~ConfigManager() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + void Print(std::ostream &out) const; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param cS - reference to the ConfigManager to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const ConfigManager &cS) { + cS.Print(out); + return out; + } + + // Another debug print helper. Converts the print info to a string for you. + // @return The string version of the debug print + std::string ToString() { + std::stringstream ss; + ss << *this; + return ss.str(); + } + + // Loads a json file with the default settings and populates all the settings + // @param settingsFile - A json file with a set of default settings + // @return Status error code + Status LoadFile(const std::string &settingsFile); + + // getter function + // @return The rows per buffer setting + int32_t rows_per_buffer() const { return rows_per_buffer_; } + + // getter function + // @return The number of workers setting + int32_t num_parallel_workers() const { return num_parallel_workers_; } + + // getter function + // @return The queue size of the operator's output connector + int32_t op_connector_size() const { return op_connector_size_; } + + // getter function + // @return The internal worker-to-master connector queue size + int32_t worker_connector_size() const { return worker_connector_size_; } + + // setter function + // @param rows_per_buffer - The setting to apply to the config + void set_rows_per_buffer(int32_t rows_per_buffer); + + // setter function + // @param num_parallel_workers - The setting to apply to the config + void set_num_parallel_workers(int32_t num_parallel_workers); + + // setter function + // @param connector_size - The setting to apply to the config + void set_worker_connector_size(int32_t connector_size); + + // setter function + // @param connector_size - The setting to apply to the config + void set_op_connector_size(int32_t connector_size); + + uint32_t seed() const; + + // setter function + // @param seed - The default seed to use + void set_seed(uint32_t seed); + + // setter function + // @param interval - The setting to apply to the config + void set_monitor_sampling_interval(uint32_t interval); + + // getter function + // @return The iterval of monitor sampling + int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; } + + private: + int32_t rows_per_buffer_{kCfgRowsPerBuffer}; + int32_t num_parallel_workers_{kCfgParallelWorkers}; + int32_t worker_connector_size_{kCfgWorkerConnectorSize}; + int32_t op_connector_size_{kCfgOpConnectorSize}; + uint32_t seed_{kCfgDefaultSeed}; + uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval}; + + // Private helper function that taks a nlohmann json format and populates the settings + // @param j - The json nlohmann json info + Status FromJson(const nlohmann::json &j); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_CORE_CONFIG_MANAGER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h new file mode 100644 index 0000000000..c85ef52bf5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -0,0 +1,66 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_CONSTANTS_H_ +#define DATASET_CORE_CONSTANTS_H_ + +#include +#include +#include + +namespace mindspore { +namespace dataset { +// Various type defines for convenience +using uchar = unsigned char; +using dsize_t = int64_t; + +// Possible dataset types for holding the data and client type +enum class DatasetType { kUnknown, kArrow, kTf }; + +// Possible flavours of Tensor implementations +enum class TensorImpl { kNone, kFlexible, kCv, kNP }; + +// Possible values for Border types +enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 }; + +// Possible interpolation modes +enum class InterpolationMode { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 }; + +// convenience functions for 32bit int bitmask +inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } + +inline void BitSet(uint32_t *bits, uint32_t bitMask) { *bits |= bitMask; } + +inline void BitClear(uint32_t *bits, uint32_t bitMask) { *bits &= (~bitMask); } + +constexpr int32_t kDeMaxDim = std::numeric_limits::max(); // 2147483647 or 2^32 -1 +constexpr int32_t kDeMaxRank = std::numeric_limits::max(); + +constexpr uint32_t kCfgRowsPerBuffer = 1; +constexpr uint32_t kCfgParallelWorkers = 4; +constexpr uint32_t kCfgWorkerConnectorSize = 16; +constexpr uint32_t kCfgOpConnectorSize = 16; +constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed; +constexpr uint32_t kCfgMonitorSamplingInterval = 10; + +// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h) +constexpr uint8_t kCVInvalidType = 255; + +using connection_id_type = int64_t; +using row_id_type = int64_t; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_CORE_CONSTANTS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc new file mode 100644 index 0000000000..5af748b5de --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/cv_tensor.h" + +#include +#include + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/tensor.h" + +namespace mindspore { +namespace dataset { +CVTensor::CVTensor(const TensorShape &shape, const DataType &type) : Tensor(shape, type) { + (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +} + +CVTensor::CVTensor(const TensorShape &shape, const DataType &type, const uchar *data) : Tensor(shape, type, data) { + (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +} + +CVTensor::CVTensor(std::shared_ptr tensor) : Tensor(std::move(*tensor)) { + (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +} + +std::pair, int> CVTensor::IsValidImage(const TensorShape &shape, const DataType &type) { + std::array size = {1, 1}; + if (shape.Rank() <= 2 || (shape.Rank() == 3 && shape[2] <= CV_CN_MAX)) { + uint8_t ch = 1; + if (shape.Rank() == 3) { + ch = static_cast(shape[2]); + } + if (shape.Rank() > 0) size[0] = static_cast(shape[0]); + if (shape.Rank() > 1) size[1] = static_cast(shape[1]); + if (type.AsCVType() == kCVInvalidType) return std::make_pair(size, -1); + + int cv_type = CV_MAKETYPE(type.AsCVType(), ch); + return std::make_pair(size, cv_type); + } + return std::make_pair(size, -1); +} + +std::shared_ptr CVTensor::AsCVTensor(std::shared_ptr t) { + std::shared_ptr cv_t = std::dynamic_pointer_cast(t); + if (cv_t != nullptr) { + return cv_t; + } else { + return std::make_shared(t); + } +} + +Status CVTensor::MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat) { + std::pair, int> cv_shape_type = IsValidImage(shape, type); + if (cv_shape_type.second == -1) { + std::vector sizes = shape.AsVector(); + std::vector sizes32(sizes.begin(), sizes.end()); // convert long to int for usage with OpenCV + if (static_cast(shape.Rank()) != shape.Rank()) { + RETURN_STATUS_UNEXPECTED("Error in creating CV mat. Wrong shape."); + } + + uint8_t cv_type = type.AsCVType(); + if (cv_type == kCVInvalidType) { + RETURN_STATUS_UNEXPECTED("Error in creating CV mat. Invalid type."); + } + *mat = cv::Mat(static_cast(shape.Rank()), &sizes32[0], cv_type, data); + } else { + *mat = cv::Mat(2, &(cv_shape_type.first[0]), cv_shape_type.second, data); + } + return Status::OK(); +} + +Status CVTensor::Reshape(const TensorShape &shape) { + RETURN_IF_NOT_OK(Tensor::Reshape(shape)); + RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_)); + return Status::OK(); +} + +Status CVTensor::ExpandDim(const dsize_t &axis) { + RETURN_IF_NOT_OK(Tensor::ExpandDim(axis)); + RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_)); + return Status::OK(); +} + +void CVTensor::Squeeze() { + Tensor::Squeeze(); + (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h new file mode 100644 index 0000000000..a614418be6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h @@ -0,0 +1,106 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_CV_TENSOR_H_ +#define DATASET_CORE_CV_TENSOR_H_ + +#include +#include +#include + +#include + +#include "./securec.h" + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" + +namespace mindspore { +namespace dataset { +class CVTensor : public Tensor { + public: + // Create an empty CVTensor of shape `shape` and type `type`. + // @note The shape and type information should be known and valid. + // @param shape TensorShape + // @param type DataType + CVTensor(const TensorShape &shape, const DataType &type); + + // Create a CVTensor from a given buffer, shape and type. + // @note This constructor allocates a new space in the memory and copies the buffer into it. + // @note The buffer should be valid and the shape and type information should be known and valid. + // @param shape TensorShape + // @param type DataType + // @param data unsigned char*, pointer to the data. + CVTensor(const TensorShape &shape, const DataType &type, const uchar *data); + + // Create a CVTensor from a given CV::Mat. + // @note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it. + // @param mat CV::Mat + explicit CVTensor(const cv::Mat &mat) + : CVTensor(TensorShape(mat.size, mat.type()), DataType::FromCVType(mat.type()), mat.data) {} + + ~CVTensor() = default; + + // Static function to cast a given Tensor as CVTensor. If the input tensor is already of type CVTensor, + // this function would be treated as a no-op. Fot other tensor types, a new CVTensor is created based on the data + // provided. The Passed Tensor will be invalidated. + // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. + // @param tensor + // @return CVTensor + static std::shared_ptr AsCVTensor(std::shared_ptr tensor); + + // Create a CVTensor from a given tensor. The input tensor will be invalidated (i.e., the shape and type will be + // set to unknown and the data buffer will point to null. + // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. + // @param tensor + explicit CVTensor(std::shared_ptr tensor); + + // Getter function for the CV::Mat + // @return + cv::Mat mat() const { return mat_; } + + // Static function to check if the passed information (shape and type) can be treated as a valid description + // of an image in OpenCV. Moreover, it returns OpenCV shape and type + // For example, if the shape is <512,512,3> and type is DE_UINT8, the output would be [512,512] and CV_8UC3. + // In case of invalid shape or type, the function will return pair + // @param shape TensorShape + // @param type DataType + // @return std::pair of OpenCV shape and type + std::pair, int> IsValidImage(const TensorShape &shape, const DataType &type); + + Status Reshape(const TensorShape &shape) override; + + Status ExpandDim(const dsize_t &axis) override; + + void Squeeze() override; + + Status Mat(const std::vector &index, cv::Mat *mat) { + uchar *start = nullptr; + TensorShape remaining({-1}); + RETURN_IF_NOT_OK(this->StartAddrOfIndex(index, &start, &remaining)); + RETURN_IF_NOT_OK(this->MatInit(start, remaining, type_, mat)); + return Status::OK(); + } + + private: + cv::Mat mat_; + + // Initialize CV::Mat with the data_, shape_ and type_ + Status MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat); +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_CV_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/data_type.cc b/mindspore/ccsrc/minddata/dataset/core/data_type.cc new file mode 100644 index 0000000000..b5641e3105 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/data_type.cc @@ -0,0 +1,166 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/data_type.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/core/pybind_support.h" +#endif + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { + +uint8_t DataType::SizeInBytes() const { + if (type_ < DataType::NUM_OF_TYPES) + return kTypeInfo[type_].sizeInBytes_; + else + return 0; +} + +#ifdef ENABLE_PYTHON +py::dtype DataType::AsNumpyType() const { + if (type_ < DataType::NUM_OF_TYPES) + return py::dtype(kTypeInfo[type_].pybindType_); + else + return py::dtype("unknown"); +} +#endif + +uint8_t DataType::AsCVType() const { + uint8_t res = kCVInvalidType; + if (type_ < DataType::NUM_OF_TYPES) { + res = kTypeInfo[type_].cvType_; + } + + if (res == kCVInvalidType) { + MS_LOG(ERROR) << "Cannot convert to OpenCV type. Return invalid type!"; + } + + return res; +} // namespace dataset + +DataType DataType::FromCVType(int cv_type) { + auto depth = static_cast(cv_type) & static_cast(CV_MAT_DEPTH_MASK); + switch (depth) { + case CV_8S: + return DataType(DataType::DE_INT8); + case CV_8U: + return DataType(DataType::DE_UINT8); + case CV_16S: + return DataType(DataType::DE_INT16); + case CV_16U: + return DataType(DataType::DE_UINT16); + case CV_32S: + return DataType(DataType::DE_INT32); + case CV_16F: + return DataType(DataType::DE_FLOAT16); + case CV_32F: + return DataType(DataType::DE_FLOAT32); + case CV_64F: + return DataType(DataType::DE_FLOAT64); + default: + MS_LOG(ERROR) << "Cannot convert from OpenCV type, unknown CV type. Unknown data type is returned!"; + return DataType(DataType::DE_UNKNOWN); + } +} + +DataType::DataType(const std::string &type_str) { + if (type_str == "bool") + type_ = DE_BOOL; + else if (type_str == "int8") + type_ = DE_INT8; + else if (type_str == "uint8") + type_ = DE_UINT8; + else if (type_str == "int16") + type_ = DE_INT16; + else if (type_str == "uint16") + type_ = DE_UINT16; + else if (type_str == "int32") + type_ = DE_INT32; + else if (type_str == "uint32") + type_ = DE_UINT32; + else if (type_str == "int64") + type_ = DE_INT64; + else if (type_str == "uint64") + type_ = DE_UINT64; + else if (type_str == "float16") + type_ = DE_FLOAT16; + else if (type_str == "float32") + type_ = DE_FLOAT32; + else if (type_str == "float64") + type_ = DE_FLOAT64; + else if (type_str == "string") + type_ = DE_STRING; + else + type_ = DE_UNKNOWN; +} + +std::string DataType::ToString() const { + if (type_ < DataType::NUM_OF_TYPES) + return kTypeInfo[type_].name_; + else + return "unknown"; +} + +#ifdef ENABLE_PYTHON +DataType DataType::FromNpArray(const py::array &arr) { + if (py::isinstance>(arr)) { + return DataType(DataType::DE_BOOL); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_INT8); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_UINT8); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_INT16); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_UINT16); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_INT32); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_UINT32); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_INT64); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_UINT64); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_FLOAT16); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_FLOAT32); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_FLOAT64); + } else if (arr.dtype().kind() == 'S' || arr.dtype().kind() == 'U') { + return DataType(DataType::DE_STRING); + } else { + MS_LOG(ERROR) << "Cannot convert from numpy type. Unknown data type is returned!"; + return DataType(DataType::DE_UNKNOWN); + } +} + +std::string DataType::GetPybindFormat() const { + std::string res; + if (type_ < DataType::NUM_OF_TYPES) { + res = kTypeInfo[type_].pybindFormatDescriptor_; + } + + if (res.empty()) { + MS_LOG(ERROR) << "Cannot convert from data type to pybind format descriptor!"; + } + return res; +} +#endif + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/data_type.h b/mindspore/ccsrc/minddata/dataset/core/data_type.h new file mode 100644 index 0000000000..db4834cae2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/data_type.h @@ -0,0 +1,350 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_DATA_TYPE_H_ +#define DATASET_CORE_DATA_TYPE_H_ + +#include + +#include +#ifdef ENABLE_PYTHON +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "minddata/dataset/core/pybind_support.h" +namespace py = pybind11; +#else +#include "Eigen/Core" +using float16 = Eigen::half; +#endif +#include "minddata/dataset/core/constants.h" +namespace mindspore { +namespace dataset { + +// Class that represents basic data types in DataEngine. +class DataType { + public: + enum Type : uint8_t { + DE_UNKNOWN = 0, + DE_BOOL, + DE_INT8, + DE_UINT8, + DE_INT16, + DE_UINT16, + DE_INT32, + DE_UINT32, + DE_INT64, + DE_UINT64, + DE_FLOAT16, + DE_FLOAT32, + DE_FLOAT64, + DE_STRING, + NUM_OF_TYPES + }; + + struct TypeInfo { + const char *name_; // name to be represent the type while printing + const uint8_t sizeInBytes_; // number of bytes needed for this type + const char *pybindType_; // Python matching type, used in get_output_types + const std::string pybindFormatDescriptor_; // pybind format used for numpy types + const uint8_t cvType_; // OpenCv matching type + }; + +#ifdef ENABLE_PYTHON + static inline const TypeInfo kTypeInfo[] = { + // name, sizeInBytes, pybindTypem formatDescriptor, openCV + {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN + {"bool", 1, "bool", py::format_descriptor::format(), CV_8U}, // DE_BOOL + {"int8", 1, "int8", py::format_descriptor::format(), CV_8S}, // DE_INT8 + {"uint8", 1, "uint8", py::format_descriptor::format(), CV_8U}, // DE_UINT8 + {"int16", 2, "int16", py::format_descriptor::format(), CV_16S}, // DE_INT16 + {"uint16", 2, "uint16", py::format_descriptor::format(), CV_16U}, // DE_UINT16 + {"int32", 4, "int32", py::format_descriptor::format(), CV_32S}, // DE_INT32 + {"uint32", 4, "uint32", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT32 + {"int64", 8, "int64", py::format_descriptor::format(), kCVInvalidType}, // DE_INT64 + {"uint64", 8, "uint64", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT64 + {"float16", 2, "float16", "e", CV_16F}, // DE_FLOAT16 + {"float32", 4, "float32", py::format_descriptor::format(), CV_32F}, // DE_FLOAT32 + {"float64", 8, "double", py::format_descriptor::format(), CV_64F}, // DE_FLOAT64 + {"string", 0, "bytes", "S", kCVInvalidType} // DE_STRING + }; +#else + static inline const TypeInfo kTypeInfo[] = { + // name, sizeInBytes, pybindTypem formatDescriptor, openCV + {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN + {"bool", 1, "bool", "", CV_8U}, // DE_BOOL + {"int8", 1, "int8", "", CV_8S}, // DE_INT8 + {"uint8", 1, "uint8", "", CV_8U}, // DE_UINT8 + {"int16", 2, "int16", "", CV_16S}, // DE_INT16 + {"uint16", 2, "uint16", "", CV_16U}, // DE_UINT16 + {"int32", 4, "int32", "", CV_32S}, // DE_INT32 + {"uint32", 4, "uint32", "", kCVInvalidType}, // DE_UINT32 + {"int64", 8, "int64", "", kCVInvalidType}, // DE_INT64 + {"uint64", 8, "uint64", "", kCVInvalidType}, // DE_UINT64 + {"float16", 2, "float16", "", CV_16F}, // DE_FLOAT16 + {"float32", 4, "float32", "", CV_32F}, // DE_FLOAT32 + {"float64", 8, "double", "", CV_64F}, // DE_FLOAT64 + {"string", 0, "bytes", "", kCVInvalidType} // DE_STRING + }; +#endif + + // No arg constructor to create an unknown shape + DataType() : type_(DE_UNKNOWN) {} + + // Create a type from a given string + /// \param type_str + explicit DataType(const std::string &type_str); + + // Default destructor + ~DataType() = default; + + // Create a type from a given enum + /// \param d + constexpr explicit DataType(Type d) : type_(d) {} + + constexpr bool operator==(const DataType a) const { return type_ == a.type_; } + + constexpr bool operator==(const Type a) const { return type_ == a; } + + constexpr bool operator!=(const DataType a) const { return type_ != a.type_; } + + constexpr bool operator!=(const Type a) const { return type_ != a; } + + // Disable this usage `if(d)` where d is of type DataType + /// \return + operator bool() = delete; + + // To be used in Switch/case + /// \return + operator Type() const { return type_; } + + // The number of bytes needed to store one value of this type + /// \return + uint8_t SizeInBytes() const; + + // Convert from DataType to OpenCV type + /// \return + uint8_t AsCVType() const; + + // Convert from OpenCV type to DataType + /// \param cv_type + /// \return + static DataType FromCVType(int cv_type); + + // Returns a string representation of the type + /// \return + std::string ToString() const; + + // returns true if the template type is the same as the Tensor type_ + /// \tparam T + /// \return true or false + template + bool IsCompatible() const { + return type_ == FromCType(); + } + + // returns true if the template type is the same as the Tensor type_ + /// \tparam T + /// \return true or false + template + bool IsLooselyCompatible() const; + + // << Stream output operator overload + /// \notes This allows you to print the info using stream operators + /// \param out - reference to the output stream being overloaded + /// \param rO - reference to the DataType to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const DataType &so) { + out << so.ToString(); + return out; + } + + template + static DataType FromCType(); + +#ifdef ENABLE_PYTHON + // Convert from DataType to Pybind type + /// \return + py::dtype AsNumpyType() const; + + // Convert from NP type to DataType + /// \param type + /// \return + static DataType FromNpType(const py::dtype &type); + + // Convert from NP array to DataType + /// \param py array + /// \return + static DataType FromNpArray(const py::array &arr); +#endif + + // Get the buffer string format of the current type. Used in pybind buffer protocol. + /// \return + std::string GetPybindFormat() const; + + bool IsSignedInt() const { + return type_ == DataType::DE_INT8 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT32 || + type_ == DataType::DE_INT64; + } + + bool IsUnsignedInt() const { + return type_ == DataType::DE_UINT8 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT32 || + type_ == DataType::DE_UINT64; + } + + bool IsInt() const { return IsSignedInt() || IsUnsignedInt(); } + + bool IsFloat() const { + return type_ == DataType::DE_FLOAT16 || type_ == DataType::DE_FLOAT32 || type_ == DataType::DE_FLOAT64; + } + + bool IsBool() const { return type_ == DataType::DE_BOOL; } + + bool IsNumeric() const { return type_ != DataType::DE_STRING; } + + Type value() const { return type_; } + + private: + Type type_; +}; + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_BOOL); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_FLOAT64); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_FLOAT32); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_FLOAT16); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_INT64); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_UINT64); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_INT32); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_UINT32); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_INT16); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_UINT16); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_INT8); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_UINT8); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_STRING); +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_BOOL; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_FLOAT64 || type_ == DataType::DE_FLOAT32; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_FLOAT32; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_FLOAT16; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_INT64 || type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || + type_ == DataType::DE_INT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_UINT64 || type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || + type_ == DataType::DE_UINT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_INT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_UINT8; +} +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_DATA_TYPE_H_ diff --git a/mindspore/ccsrc/dataset/core/example.proto b/mindspore/ccsrc/minddata/dataset/core/example.proto similarity index 100% rename from mindspore/ccsrc/dataset/core/example.proto rename to mindspore/ccsrc/minddata/dataset/core/example.proto diff --git a/mindspore/ccsrc/dataset/core/feature.proto b/mindspore/ccsrc/minddata/dataset/core/feature.proto similarity index 100% rename from mindspore/ccsrc/dataset/core/feature.proto rename to mindspore/ccsrc/minddata/dataset/core/feature.proto diff --git a/mindspore/ccsrc/minddata/dataset/core/global_context.cc b/mindspore/ccsrc/minddata/dataset/core/global_context.cc new file mode 100644 index 0000000000..eb76382ab2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/global_context.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/global_context.h" + +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/util/system_pool.h" + +namespace mindspore { +namespace dataset { +// Global static pointer for the singleton GlobalContext +std::unique_ptr GlobalContext::global_context_ = nullptr; +std::once_flag GlobalContext::init_instance_flag_; + +constexpr int GlobalContext::kArenaSize; +constexpr int GlobalContext::kMaxSize; +constexpr bool GlobalContext::kInitArena; + +// Singleton initializer +GlobalContext *GlobalContext::Instance() { + // If the single global context is not created yet, then create it. Otherwise the + // existing one is returned. + std::call_once(init_instance_flag_, []() { + global_context_.reset(new GlobalContext()); + Status rc = global_context_->Init(); + if (rc.IsError()) { + std::terminate(); + } + }); + return global_context_.get(); +} + +Status GlobalContext::Init() { + config_manager_ = std::make_shared(); + mem_pool_ = std::make_shared(); + // For testing we can use Dummy pool instead + + // Create some tensor allocators for the different types and hook them into the pool. + tensor_allocator_ = std::make_unique>(mem_pool_); + cv_tensor_allocator_ = std::make_unique>(mem_pool_); + int_allocator_ = std::make_unique(mem_pool_); + return Status::OK(); +} + +// A print method typically used for debugging +void GlobalContext::Print(std::ostream &out) const { + out << "GlobalContext contains the following default config: " << *config_manager_ << "\n"; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/global_context.h b/mindspore/ccsrc/minddata/dataset/core/global_context.h new file mode 100644 index 0000000000..fe0847f639 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/global_context.h @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_GLOBAL_CONTEXT_H_ +#define DATASET_CORE_GLOBAL_CONTEXT_H_ + +#include +#include + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// forward declare +class MemoryPool; +class ConfigManager; +class Tensor; +class CVTensor; + +using TensorAlloc = Allocator; // An allocator for Tensors +using CVTensorAlloc = Allocator; // An allocator CVTensors +using IntAlloc = Allocator; + +class GlobalContext { + // some consts for pool config + static constexpr int kArenaSize = 128; + static constexpr int kMaxSize = -1; + static constexpr bool kInitArena = true; + + public: + // Singleton pattern. This method either: + // - creates the single version of the GlobalContext for the first time and returns it + // OR + // - returns the already existing single instance of the GlobalContext + // @return the single global context + static GlobalContext *Instance(); + + // Destructor + ~GlobalContext() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + void Print(std::ostream &out) const; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param g_c - reference to the GlobalContext to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const GlobalContext &g_c) { + g_c.Print(out); + return out; + } + + // Getter method + // @return the client config as raw const pointer + static std::shared_ptr config_manager() { return Instance()->config_manager_; } + + // Getter method + // @return the mem pool + std::shared_ptr mem_pool() const { return mem_pool_; } + + // Getter method + // @return the tensor allocator as raw pointer + const TensorAlloc *tensor_allocator() const { return tensor_allocator_.get(); } + + // Getter method + // @return the CVTensor allocator as raw pointer + const CVTensorAlloc *cv_tensor_allocator() const { return cv_tensor_allocator_.get(); } + + // Getter method + // @return the integer allocator as raw pointer + const IntAlloc *int_allocator() const { return int_allocator_.get(); } + + private: + // Constructor. + // @note Singleton. Instantiation flows through instance() + // @return This is a constructor. + GlobalContext() = default; + + Status Init(); + + static std::once_flag init_instance_flag_; + static std::unique_ptr global_context_; // The instance of the singleton (global) + std::shared_ptr mem_pool_; // A global memory pool + std::shared_ptr config_manager_; // The configs + std::unique_ptr tensor_allocator_; // An allocator for Tensors + std::unique_ptr cv_tensor_allocator_; // An allocator for CV Tensors + std::unique_ptr int_allocator_; // An allocator for ints +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_CORE_GLOBAL_CONTEXT_H_ diff --git a/mindspore/ccsrc/dataset/core/pybind_support.h b/mindspore/ccsrc/minddata/dataset/core/pybind_support.h similarity index 100% rename from mindspore/ccsrc/dataset/core/pybind_support.h rename to mindspore/ccsrc/minddata/dataset/core/pybind_support.h diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.cc b/mindspore/ccsrc/minddata/dataset/core/tensor.cc new file mode 100644 index 0000000000..842615f9e1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.cc @@ -0,0 +1,1034 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/tensor.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/global_context.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/core/pybind_support.h" +namespace py = pybind11; +#endif +#include "minddata/dataset/core/tensor_shape.h" + +namespace mindspore { +namespace dataset { +// Helper macros for printing tensor elements +#define CASE_PRINT(de_type, native_type) \ + case de_type: { \ + native_type o; \ + rc = GetItemAt(&o, index); \ + out << o; \ + break; \ + } + +#define CASE_PRINT_HEX(de_type, native_type) \ + case de_type: { \ + native_type o; \ + rc = GetItemAt(&o, index); \ + out << std::hex << std::setw(2) << std::setfill('0') << o << std::dec << std::setfill(' '); \ + break; \ + } + +Tensor::Tensor(const TensorShape &shape, const DataType &type) : shape_(shape), type_(type), data_(nullptr) { + // grab the mem pool from global context and create the allocator for char data area + std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); + data_allocator_ = std::make_unique>(global_pool); +} + +Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data) : Tensor(shape, type) { + if (type.IsNumeric()) { + // If the data pointer was given, then we can also populate the tensor with data + if (data != nullptr) { + // Given the shape/type of this tensor, compute the data size and copy in the input bytes. + int64_t byte_size = this->SizeInBytes(); + Status s = this->AllocateBuffer(byte_size); // Allocates data_ inside itself + if (s.IsOk() && data_ != nullptr) { + int ret_code = memcpy_s(data_, byte_size, data, byte_size); + if (ret_code != 0) { + MS_LOG(ERROR) << "Failed to copy data into Tensor!"; + } + } else { + MS_LOG(ERROR) << "Failed to create memory for Tensor!"; + } + } + } else { + MS_LOG(ERROR) << "Type should be numeric to use this constructor."; + } +} + +Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length) + : Tensor(shape, type) { + // If the data pointer was given, then we can also populate the tensor with data + if (data != nullptr) { + // Allocates data_ inside itself + Status s = AllocateBuffer(length); + if (s.IsError()) { + MS_LOG(ERROR) << "Failed to create memory for Tensor!"; + } + if (data_ != nullptr) { + int ret_code = memcpy_s(data_, length, data, length); + if (ret_code != 0) { + MS_LOG(ERROR) << "Failed to copy data into Tensor!"; + } + } + } +} + +Tensor::Tensor(Tensor &&other) noexcept + : shape_(other.shape()), + type_(other.type()), + data_(other.GetMutableBuffer()), + data_allocator_(std::move(other.data_allocator_)) { + other.Invalidate(); +} + +Tensor &Tensor::operator=(Tensor &&other) noexcept { + if (&other != this) { + shape_ = other.shape(); + type_ = other.type(); + data_ = other.GetMutableBuffer(); + data_end_ = other.data_end_; + data_allocator_ = std::move(other.data_allocator_); + other.Invalidate(); + } + return *this; +} + +Tensor::Tensor(const std::vector &strings, const TensorShape &shape) + : Tensor(TensorShape({static_cast(strings.size())}), DataType(DataType::DE_STRING)) { + auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; + dsize_t total_length = std::accumulate(strings.begin(), strings.end(), 0, length_sum); + + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. + // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize + 1) * shape_.NumOfElements() + kOffsetSize + total_length; + + data_ = data_allocator_->allocate(num_bytes); + + auto offset_arr = reinterpret_cast(data_); + uchar *buf = GetStringsBuffer(); + + offset_t offset = buf - data_; // the first string will start here + uint32_t i = 0; + for (const auto &str : strings) { + // insert the start index of the string. + offset_arr[i++] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; + // next string will be stored right after the current one. + offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + this->data_end_ = data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + if (shape.known()) Tensor::Reshape(shape); +} + +Tensor::Tensor(const dataengine::BytesList &bytes_list, const TensorShape &shape) + : Tensor(TensorShape({static_cast(bytes_list.value_size())}), DataType(DataType::DE_STRING)) { + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. + // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize)*shape_.NumOfElements() + kOffsetSize + bytes_list.ByteSizeLong(); + + data_ = data_allocator_->allocate(num_bytes); + + auto offset_arr = reinterpret_cast(data_); + uchar *buf = GetStringsBuffer(); + + offset_t offset = buf - data_; // the first string will start here + uint32_t i = 0; + for (; i < bytes_list.value_size(); i++) { + const std::string &str = bytes_list.value(i); + // insert the start index of the string. + offset_arr[i] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) { + MS_LOG(ERROR) << "Cannot copy string into Tensor"; + } + // next string will be stored right after the current one. + offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + data_end_ = data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + if (shape.known()) Tensor::Reshape(shape); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, TensorImpl tensor_impl, const TensorShape &shape, + DataType type, const unsigned char *data) { + if (!shape.known()) { + RETURN_STATUS_UNEXPECTED("Invalid shape."); + } + if (type == DataType::DE_UNKNOWN) { + RETURN_STATUS_UNEXPECTED("Invalid data type."); + } + + switch (tensor_impl) { + case TensorImpl::kFlexible: { + // The flex tensor is really just the base class tensor implementation + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *ptr = std::allocate_shared(*alloc, shape, type, data); + break; + } + case TensorImpl::kCv: { + const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); + *ptr = std::allocate_shared(*alloc, shape, type, data); + break; + } + default: { + std::string err_msg("Invalid tensor implementation type."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); // returns base-class shared_ptr +} + +#ifdef ENABLE_PYTHON +Status Tensor::CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr) { + std::vector shape; + for (dsize_t i = 0; i < arr.ndim(); i++) { + shape.push_back(static_cast(arr.shape()[i])); + } + arr.resize({arr.size()}); // flatten the py::array so we can iterate once + std::vector strings; + + if (arr.dtype().kind() == 'U') { + std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast(s)); }); + } else { + std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast(s)); }); + } + + arr.resize(shape); // resize arr back to the original shape + + return CreateTensor(ptr, strings, TensorShape{shape}); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, py::array arr) { + if (DataType::FromNpArray(arr) == DataType::DE_STRING) { + return CreateTensorFromNumpyString(ptr, arr); + } + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *ptr = std::allocate_shared(*alloc, TensorShape({}), DataType(DataType::DE_UNKNOWN)); + + std::vector shape; + for (dsize_t i = 0; i < arr.ndim(); i++) { + shape.push_back(static_cast(arr.shape()[i])); + } + + (*ptr)->shape_ = TensorShape(shape); + (*ptr)->type_ = DataType::FromNpArray(arr); + if (!(*ptr)->shape_.known()) RETURN_STATUS_UNEXPECTED("Invalid shape."); + + if ((*ptr)->type_ == DataType::DE_UNKNOWN) RETURN_STATUS_UNEXPECTED("Invalid data type."); + + std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); + (*ptr)->data_allocator_ = std::make_unique>(global_pool); + int64_t byte_size = (*ptr)->SizeInBytes(); + RETURN_IF_NOT_OK((*ptr)->AllocateBuffer(byte_size)); + + unsigned char *data = static_cast(arr.request().ptr); + if ((*ptr)->data_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); + } + + std::vector strides; + for (dsize_t i = 0; i < arr.ndim(); i++) { + strides.push_back(static_cast(arr.strides()[i])); + } + + // check if strides are contiguous + bool is_strided = false; + dsize_t count = (*ptr)->shape_.NumOfElements(); + for (size_t i = 0; i < shape.size(); i++) { + count /= shape[i]; + if (strides[i] != (*ptr)->type_.SizeInBytes() * count) { + is_strided = true; + break; + } + } + + if (is_strided) { + RETURN_IF_NOT_OK(CopyStridedArray((*ptr)->data_, data, shape, strides, (*ptr)->type_.SizeInBytes())); + } else { + int ret_code = memcpy_s((*ptr)->data_, byte_size, data, byte_size); + if (ret_code != 0) { + RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); + } + } + + return Status::OK(); // returns base-class shared_ptr +} +#endif + +Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::vector &strings, + const TensorShape &shape) { + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *ptr = std::allocate_shared(*alloc, strings, shape); + return Status::OK(); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, + const TensorShape &shape) { + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *ptr = std::allocate_shared(*alloc, bytes_list, shape); + return Status::OK(); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::string &file_path) { + std::ifstream fs; + fs.open(file_path, std::ios::binary | std::ios::in); + CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + file_path); + int64_t num_bytes = fs.seekg(0, std::ios::end).tellg(); + CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Fail to find size of file"); + RETURN_IF_NOT_OK( + Tensor::CreateTensor(ptr, TensorImpl::kFlexible, TensorShape{num_bytes}, DataType(DataType::DE_UINT8))); + int64_t written_bytes = fs.read(reinterpret_cast((*ptr)->GetMutableBuffer()), num_bytes).gcount(); + CHECK_FAIL_RETURN_UNEXPECTED(written_bytes == num_bytes && fs.good(), "Error in writing to tensor"); + fs.close(); + return Status::OK(); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, + const TensorShape &shape, const DataType &type, dsize_t pad_size) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(ptr, TensorImpl::kFlexible, shape, type)); + + unsigned char *current_tensor_addr = (*ptr)->GetMutableBuffer(); + int64_t tensor_bytes_remaining = bytes_list.value_size() * pad_size; + + for (int i = 0; i < bytes_list.value_size(); i++) { + // read string data into tensor + const std::string ¤t_element = bytes_list.value(i); + int return_code = + memcpy_s(current_tensor_addr, tensor_bytes_remaining, common::SafeCStr(current_element), current_element.size()); + + CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when reading bytesList element into Tensor"); + + current_tensor_addr += current_element.size(); + tensor_bytes_remaining -= current_element.size(); + + // pad + int64_t chars_to_pad = pad_size - current_element.size(); + return_code = memset_s(current_tensor_addr, tensor_bytes_remaining, static_cast(' '), chars_to_pad); + CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when padding Tensor"); + + current_tensor_addr += chars_to_pad; + tensor_bytes_remaining -= chars_to_pad; + } + + return Status::OK(); +} + +// Memcpy the given strided array's used part to consecutive memory +// Consider a 3-d array +// A[(i * shape[1] + j) * shape[2] + k] = B[i][j][k] = C[i * strides[0] + j * strides[1] + k * strides[2]] +// Here we convert array C to array A, by memcpy index by index (Note that not all elements in C is copied) +Status Tensor::CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, + std::vector strides, uint8_t type_size) { + dsize_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + for (dsize_t i = 0; i < size; ++i) { + dsize_t offset = 0; + dsize_t count = i; + for (size_t j = 0; j < shape.size(); ++j) { + // convert 1d array's index to 3d array's index (A -> B) + dsize_t idx = count % shape[shape.size() - 1 - j]; + count /= shape[shape.size() - 1 - j]; + // calculate the raw data offset based on strides (B -> C) + offset += idx * strides[shape.size() - 1 - j]; + // once count = 0, the following idxes are all zero, skip them + if (count == 0) break; + } + // strides already consider byte size of the data type, but dst doesn't. + // dst[i] = dst + i * type_size = src + offset + int ret_code = memcpy_s(dst + i * type_size, type_size, src + offset, type_size); + if (ret_code != 0) { + RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); + } + } + return Status::OK(); +} + +// Name: Destructor +// Description: Destructor +Tensor::~Tensor() { + if (data_ != nullptr) { + if (data_allocator_ != nullptr) { + data_allocator_->deallocate(data_); + data_ = nullptr; + data_end_ = nullptr; + } else { + // If we didn't have an allocator, but data_ is not null then it must + // be a stand-alone tensor that used malloc directly. + free(data_); + data_ = nullptr; + data_end_ = nullptr; + } + } +} + +bool Tensor::operator==(const Tensor &rhs) const { + // 1. different shape 2. different type 3. one data_ is nullptr and the other is not + if (shape_ != rhs.shape() || type_ != rhs.type_ || (data_ == nullptr && rhs.data_ != nullptr) || + (data_ != nullptr && rhs.data_ == nullptr)) { + return false; + } + if (data_ == nullptr && rhs.data_ == nullptr) { + return true; + } + // use mem compare to compare the two data, size are already verified + return memcmp(data_, rhs.data_, SizeInBytes()) == 0; +} + +// Name: PrintItemAt() +// Description: A function that print the value as specified by its index +void Tensor::PrintItemAt(const std::vector &index, std::ostream &out) const { + Status rc; + MS_ASSERT(data_); + + switch (type_.value()) { + CASE_PRINT_HEX(DataType::DE_BOOL, bool); + + CASE_PRINT_HEX(DataType::DE_INT8, int8_t); + + CASE_PRINT_HEX(DataType::DE_UINT8, uint8_t); + + CASE_PRINT(DataType::DE_INT16, int16_t); + + CASE_PRINT(DataType::DE_UINT16, uint16_t); + + CASE_PRINT(DataType::DE_INT32, int32_t); + + CASE_PRINT(DataType::DE_UINT32, uint32_t); + + CASE_PRINT(DataType::DE_INT64, int64_t); + + CASE_PRINT(DataType::DE_UINT64, uint64_t); + + CASE_PRINT(DataType::DE_FLOAT16, float16); + + CASE_PRINT(DataType::DE_FLOAT32, float); + + CASE_PRINT(DataType::DE_FLOAT64, double); + + case DataType::DE_STRING: { + std::string_view o{""}; + GetItemAt(&o, index); + out << "\"" << o << "\""; + break; + } + default: { + out << "?"; + break; + } + } + if (rc.IsError()) { + out << rc.ToString(); + } +} + +// Name: PrintRecursive() +// Description: A function that prints Tensor recursively, first called by print +void Tensor::PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const { + if (cur_index.size() == shape_.Rank()) { + PrintItemAt(cur_index, out); + } else { + out << "["; + for (dsize_t i = 0; i < shape_[cur_dim]; i++) { + std::vector new_index = cur_index; + new_index.push_back(i); + PrintRecursive(out, cur_dim + 1, new_index); + if (i < shape_[cur_dim] - 1) { + out << ","; + } + } + out << "]"; + } +} + +// Name: Print() +// Description: A function that prints info about the tensor +void Tensor::Print(std::ostream &out) const { + out << "Tensor (shape: "; + out << shape_; + out << ", Type: " << type_ << ")\n"; + if (data_) { + PrintRecursive(out, 0, std::vector{}); + } else { + out << "[Data area is null]"; + } +} +Status Tensor::AllocateBuffer(const dsize_t &length) { + if (data_ == nullptr) { + if (data_allocator_ != nullptr) { + data_ = data_allocator_->allocate(length); + RETURN_UNEXPECTED_IF_NULL(data_); + data_end_ = data_ + length; + } else { + data_ = static_cast(malloc(length)); + data_end_ = data_ + length; + RETURN_UNEXPECTED_IF_NULL(data_); + } + } + return Status::OK(); +} +const unsigned char *Tensor::GetBuffer() const { + // This version cannot modify anything. data_ could possibly be null. + return data_; +} + +// check for empty +bool Tensor::HasData() const { + if (data_ == nullptr) { + return true; + } else { + return false; + } +} + +unsigned char *Tensor::GetMutableBuffer() { + if (!shape_.known() || type_ == DataType::DE_UNKNOWN) { + return nullptr; + } + // If the data area is already created, return the pointer to it + if (data_ != nullptr) { + return data_; + } else { + // If the data area is not created, then identify the memory size based + // on the shape and type and allocate it. + if (this->AllocateBuffer(this->SizeInBytes()).IsOk()) { + return data_; + } else { + return nullptr; + } + } +} + +Status Tensor::Reshape(const TensorShape &shape) { + if (shape.NumOfElements() == shape_.NumOfElements()) { + shape_ = shape; + return Status::OK(); + } else { + std::string err = "Cannot reshape, Number of elements do not match"; + RETURN_STATUS_UNEXPECTED(err); + } +} + +void Tensor::Invalidate() { + shape_ = TensorShape::CreateUnknownRankShape(); + type_ = DataType(DataType::DE_UNKNOWN); + data_ = nullptr; + data_end_ = nullptr; + data_allocator_ = nullptr; +} + +template +Status Tensor::GetItemPtr(T **ptr, const std::vector &index) const { + if (type_.IsCompatible()) { + if (data_ == nullptr) { + std::string err = "Data is not allocated yet"; + RETURN_STATUS_UNEXPECTED(err); + } + dsize_t flat_idx; + RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx)); + *ptr = reinterpret_cast(data_ + flat_idx * type_.SizeInBytes()); + + return Status::OK(); + } else { + std::string err = "data type not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } +} + +Status Tensor::GetItemPtr(uchar **ptr, const std::vector &index, offset_t *length) const { + if (type_ == DataType::DE_STRING) { + if (data_ == nullptr) { + std::string err = "Data is not allocated yet"; + RETURN_STATUS_UNEXPECTED(err); + } + dsize_t flat_idx; + RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx)); + offset_t length_temp = 0; + RETURN_IF_NOT_OK(GetStringAt(flat_idx, ptr, &length_temp)); + if (length != nullptr) *length = length_temp; + return Status::OK(); + } else { + std::string err = "data type not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } +} + +Status Tensor::StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining) { + if (type() == DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet."); + } + + dsize_t flat_ind; + std::vector t_shape = shape().AsVector(); + std::vector r(t_shape.begin() + ind.size(), t_shape.end()); + *remaining = TensorShape(r); + ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0); + + RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind)); + // check if GetBuffer() returns null, we should flag this as an error, this sanity check will only + // be true is the tensor failed to allocate memory. + if (GetMutableBuffer() == nullptr) { + RETURN_STATUS_UNEXPECTED("Invalid GetBuffer in Tensor, got nullptr"); + } + *start_addr_of_index = GetMutableBuffer() + flat_ind * this->type().SizeInBytes(); + return Status::OK(); +} + +Status Tensor::InsertTensor(const std::vector &ind, const std::shared_ptr &tensor) { + std::string err_msg; + err_msg += (this->type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; + err_msg += (!this->shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; + err_msg += (ind.size() + tensor->Rank() != this->Rank()) ? "[Tensor] incorrect index\n" : ""; + err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; + uchar *start_addr_of_ind = nullptr; + TensorShape remaining_shape({-1}); + err_msg += (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : ""; + err_msg += !(remaining_shape == tensor->shape()) ? "[Tensor] memory error\n" : ""; + if (!err_msg.empty()) { + MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + if (start_addr_of_ind != nullptr) { + int ret_code = + memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); + if (ret_code == 0) { + return Status::OK(); + } else { + err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; + MS_LOG(DEBUG) << "Tensor message: " << err_msg; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } else { + RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); + } + } +} + +Status Tensor::Concatenate(const std::vector &index, const std::shared_ptr &tensor) { + std::string err_msg; + err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : ""; + err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; + err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; + + err_msg += + (index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : ""; + err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; + uchar *start_addr_of_ind = nullptr; + + TensorShape remaining_shape = tensor->shape(); + StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape); + err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : ""; + + if (!err_msg.empty()) { + MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; + + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + int ret_code = + memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); + + if (ret_code == 0) { + return Status::OK(); + } else { + err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; + MS_LOG(DEBUG) << "Tensor message: " << err_msg; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } +} + +Status Tensor::ExpandDim(const dsize_t &axis) { + if (axis > Rank()) { + std::string err = "Axis is out of bound"; + RETURN_STATUS_UNEXPECTED(err); + } + if (axis == Rank()) { + shape_ = shape_.AppendDim(1); + } else { + shape_ = shape_.InsertDim(axis, 1); + } + return Status::OK(); +} + +std::vector Tensor::Strides() { + std::vector strides = shape_.Strides(); + uint8_t size = type_.SizeInBytes(); + std::transform(strides.begin(), strides.end(), strides.begin(), [&size](const auto &c) { return c * size; }); + return strides; +} + +#ifdef ENABLE_PYTHON +Status Tensor::GetBufferInfo(Tensor *t, py::buffer_info *out) { + RETURN_UNEXPECTED_IF_NULL(t); + CHECK_FAIL_RETURN_UNEXPECTED(t->type().IsNumeric(), "Cannot use GetBufferInfo on tensor of strings."); + + std::string format_desc = t->type().GetPybindFormat(); + if (format_desc.empty()) { + RETURN_STATUS_UNEXPECTED("Cannot convert DE type tp pybind format"); + } + *out = py::buffer_info(t->GetMutableBuffer(), /* Pointer to buffer */ + t->type().SizeInBytes(), /* Size of one scalar */ + format_desc, /* Python struct-style format descriptor */ + t->Rank(), /* Number of dimensions */ + t->shape().AsVector(), /* Buffer dimensions */ + t->Strides()); + return Status::OK(); +} +#endif + +template +Status Tensor::GetItemAt(T *o, const std::vector &index) const { + if (data_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); + } + if (!type_.IsLooselyCompatible()) { + std::string err = "Template type and Tensor type are not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } + if (type_.IsUnsignedInt()) { + RETURN_IF_NOT_OK(GetUnsignedIntAt(o, index)); + } else if (type_.IsSignedInt()) { + RETURN_IF_NOT_OK(GetSignedIntAt(o, index)); + } else if (type_.IsFloat()) { + RETURN_IF_NOT_OK(GetFloatAt(o, index)); + } else if (type_.IsBool()) { + bool *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + } else { + std::string err = "Tensor Type is unknown"; + RETURN_STATUS_UNEXPECTED(err); + } + return Status::OK(); +} + +Status Tensor::GetItemAt(std::string_view *o, const std::vector &index) const { + RETURN_UNEXPECTED_IF_NULL(data_); + RETURN_UNEXPECTED_IF_NULL(o); + CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string"); + + uchar *start = nullptr; + offset_t length = 0; + RETURN_IF_NOT_OK(GetItemPtr(&start, index, &length)); + std::string_view sv{reinterpret_cast(start)}; + o->swap(sv); + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +// return data as numpy, should return status +Status Tensor::GetDataAsNumpy(py::array *data) { + RETURN_UNEXPECTED_IF_NULL(data_); + RETURN_UNEXPECTED_IF_NULL(data); + if (type_ == DataType::DE_BOOL) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_INT8) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_INT16) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_INT32) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_INT64) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_UINT8) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_UINT16) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_UINT32) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_UINT64) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_FLOAT16) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_FLOAT32) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_FLOAT64) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_STRING) { + GetDataAsNumpyStrings(data); + } else { + RETURN_STATUS_UNEXPECTED("Got unexpected type when returning numpy"); + } + return Status::OK(); +} +Status Tensor::GetDataAsNumpyStrings(py::array *data) { + auto itr = begin(); + uint64_t max = 0; + for (; itr != end(); itr++) { + max = std::max((*itr).length(), max); + } + // if all strings are empty, numpy stores a byte for each string |S1 + max = (max == 0 ? 1 : max); + uint64_t total_size = shape_.NumOfElements() * max; + char *tmp_data = reinterpret_cast(data_allocator_->allocate(total_size)); + if (tmp_data == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create temp array."); + int ret_code = memset_s(tmp_data, total_size, 0, total_size); + CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to initialize temp memory"); + + itr = begin(); + uint64_t i = 0; + for (; itr != end(); itr++, i++) { + if (!(*itr).empty()) { + ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length()); + CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data."); + } + } + auto strides = shape_.Strides(); + std::transform(strides.begin(), strides.end(), strides.begin(), [&max](const auto &s) { return s * max; }); + *data = py::array(py::dtype("S" + std::to_string(max)), shape_.AsVector(), strides, tmp_data); + data_allocator_->deallocate(reinterpret_cast(tmp_data)); + return Status::OK(); +} +#endif + +void Tensor::Squeeze() { shape_ = shape_.Squeeze(); } + +template +Status Tensor::GetUnsignedIntAt(T *o, const std::vector &index) const { + if (data_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); + } + if (!type_.IsLooselyCompatible()) { + std::string err = "Template type and Tensor type are not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } + switch (type_.value()) { + case DataType::DE_UINT8: { + uint8_t *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + case DataType::DE_UINT16: { + uint16_t *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + case DataType::DE_UINT32: { + uint32_t *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + case DataType::DE_UINT64: { + uint64_t *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + default: + std::string err = "Tensor Type is not an unsigned Integer"; + RETURN_STATUS_UNEXPECTED(err); + } + return Status::OK(); +} + +template +Status Tensor::GetSignedIntAt(T *o, const std::vector &index) const { + if (data_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); + } + if (!type_.IsLooselyCompatible()) { + std::string err = "Template type and Tensor type are not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } + switch (type_.value()) { + case DataType::DE_INT8: { + int8_t *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + case DataType::DE_INT16: { + int16_t *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + case DataType::DE_INT32: { + int32_t *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + case DataType::DE_INT64: { + int64_t *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + default: + std::string err = "Tensor Type is not a signed Integer"; + RETURN_STATUS_UNEXPECTED(err); + } + return Status::OK(); +} + +template +Status Tensor::GetFloatAt(T *o, const std::vector &index) const { + if (data_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); + } + if (!type_.IsLooselyCompatible()) { + std::string err = "Template type and Tensor type are not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } + switch (type_.value()) { + case DataType::DE_FLOAT16: { + float16 *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + case DataType::DE_FLOAT32: { + float *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + case DataType::DE_FLOAT64: { + double *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + break; + } + default: + std::string err = "Tensor Type is not a float/double"; + RETURN_STATUS_UNEXPECTED(err); + } + return Status::OK(); +} +Status Tensor::GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const { + CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not string"); + RETURN_UNEXPECTED_IF_NULL(data_); + RETURN_UNEXPECTED_IF_NULL(string_start); + RETURN_UNEXPECTED_IF_NULL(length); + auto *offset_ptr = reinterpret_cast(data_); // offsets starts here + offset_t start = offset_ptr[index]; + *string_start = data_ + start; + *length = offset_ptr[index + 1] - start - 1; // -1 to skip the \0 from the string length + return Status::OK(); +} +Status Tensor::CopyLastDimAt(const std::shared_ptr &src, const std::vector &index) { + CHECK_FAIL_RETURN_UNEXPECTED(src->type() == type_, "Source Tensor has a different type"); + CHECK_FAIL_RETURN_UNEXPECTED(index.back() == 0, "Last dim in index should be 0"); + + uint8_t type_size = type_.SizeInBytes(); + size_t len = std::min(src->shape()[-1], shape_[-1]) * type_size; + dsize_t src_flat_ind = 0, dst_flat_ind = 0; + RETURN_IF_NOT_OK(src->shape().ToFlatIndex(index, &src_flat_ind)); + RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &dst_flat_ind)); + + const unsigned char *src_addr = src->GetBuffer() + src_flat_ind * type_size; + unsigned char *dst_addr = GetMutableBuffer() + dst_flat_ind * type_size; + CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error"); + return Status::OK(); +} +Status Tensor::Slice(std::shared_ptr *out, const std::vector &indices) { + CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice work with rank 1 tensors only."); + CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty."); + if (type_.IsNumeric()) { + return SliceNumeric(out, indices); + } else { + return SliceString(out, indices); + } +} +Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vector &indices) { + RETURN_IF_NOT_OK( + CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast(indices.size())}), type_)); + (*out)->GetMutableBuffer(); + dsize_t out_index = 0; + dsize_t dim_length = shape_[0]; + dsize_t type_size = type_.SizeInBytes(); + dsize_t src_start = HandleNeg(indices[0], dim_length); + uchar *dst_addr = (*out)->data_; + dsize_t count = 1; + + for (dsize_t i = 0; i < indices.size(); i++) { + dsize_t cur_index = HandleNeg(indices[i], dim_length); + CHECK_FAIL_RETURN_UNEXPECTED( + cur_index >= 0 && cur_index < dim_length, + "Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")"); + if (i < indices.size() - 1) { + dsize_t next_index = HandleNeg(indices[i + 1], dim_length); + if (next_index == cur_index + 1) { + count++; + continue; + } + } + int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, + count * type_size); + CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric"); + out_index += count; + if (i < indices.size() - 1) { + src_start = HandleNeg(indices[i + 1], dim_length); // next index + } + count = 1; + } + return Status::OK(); +} +Status Tensor::SliceString(std::shared_ptr *out, const std::vector &indices) { + dsize_t dim_length = shape_[0]; + std::vector strings; + for (dsize_t index : indices) { + dsize_t cur_index = HandleNeg(index, dim_length); + CHECK_FAIL_RETURN_UNEXPECTED( + cur_index >= 0 && cur_index < dim_length, + "Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")"); + std::string_view sv; + GetItemAt(&sv, {cur_index}); + strings.emplace_back(sv); + } + return CreateTensor(out, strings); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.h b/mindspore/ccsrc/minddata/dataset/core/tensor.h new file mode 100644 index 0000000000..b0b173e9c3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.h @@ -0,0 +1,668 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_TENSOR_H_ +#define DATASET_CORE_TENSOR_H_ + +#include +#include +#include +#include +#include "./securec.h" +#include "utils/log_adapter.h" +#if defined(_WIN32) || defined(_WIN64) +#undef HAVE_STDDEF_H +#undef HAVE_STDLIB_H +#endif + +#ifdef ENABLE_PYTHON +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#endif + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/util/status.h" +#include "proto/example.pb.h" + +#ifdef ENABLE_PYTHON +namespace py = pybind11; +#endif +namespace mindspore { +namespace dataset { +class Tensor; +template +class Allocator; + +using CharAllocPtr = std::unique_ptr>; +using TensorAllocPtr = std::shared_ptr>; // An allocator shared_ptr for Tensors + +class Tensor { + public: + Tensor() = delete; + + // Create a new tensor, does not internally allocate storage. This constructor is protected, use CreateTensor. + // @note The shape and type information should be known and valid. + // @param shape TensorShape + // @param type DataType + Tensor(const TensorShape &shape, const DataType &type); + + // Create a new tensor, allocates storage and copies in data. This constructor is protected, use CreateTensor. + // @note The buffer should be valid and the shape and type information should be known and valid. + // @param shape TensorShape + // @param type DataType + // @param data unsigned char*, pointer to the data. + Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data); + + Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length); + + Tensor(const Tensor &other) = delete; + + Tensor &operator=(const Tensor &other) = delete; + + Tensor(Tensor &&other) noexcept; + + Tensor &operator=(Tensor &&other) noexcept; + + Status AllocateBuffer(const dsize_t &length); + + // type of offest values to store strings information + using offset_t = uint32_t; + // const of the size of the offset variable + static constexpr uint8_t kOffsetSize = sizeof(offset_t); + // Tensor base class which holds the data in an unsigned char* buffer. + + // Construct a scalar string Tensor + explicit Tensor(const std::string &str) : Tensor(std::vector{str}, TensorShape::CreateScalar()) {} + + // Construct a tensor from a list of strings. Reshape the tensor with `shape` if given, otherwise assume the shape is + // the size of the vector `strings`. + // The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. + // Thr offset array will store one extra value to find the length of the last string. + // OFFSET1, OFFSET2, ..., OFFSETn+1, STRING1, STRING2, ..., STRINGn + // The value of each offset is the start index of the corresponding string + // Offsets is of type offest_t + // strings will ne null-terminated + // example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) + // |----------------------------------------------------------------| + // | OFFSET ARRAY | STRINGS | + // | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 | + // | 11 | 15 | 18 | abc\0 | de\0 | + // |----------------------------------------------------------------| + explicit Tensor(const std::vector &strings, + const TensorShape &shape = TensorShape::CreateUnknownRankShape()); + + // Same as Tensor(vector) but the input is protobuf bytelist + explicit Tensor(const dataengine::BytesList &bytes_list, + const TensorShape &shape = TensorShape::CreateUnknownRankShape()); + + // A static factory method to create the given flavour of derived Tensor + // Returns the base class reference for the Tensor. + // @param ptr output argument to hold the created Tensor of given tensor_impl + // @param tensor_impl - which implementation of Tensor + // @param shape - shape of the tensor + // @param type - datatype of the tensor + // @param data - data to be copied to Tensor new allocation + // @return Status Code + static Status CreateTensor(std::shared_ptr *, TensorImpl tensor_impl, const TensorShape &shape, DataType type, + const unsigned char *data = nullptr); + + // Create a copy of the input tensor + // @param out [out] output tensor to be generated + // @param in [in] orginal tensor to be copied + // @return Status + static Status CreateTensor(std::shared_ptr *out, const std::shared_ptr &in) { + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *out = std::allocate_shared(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes()); + return Status::OK(); + } + +#ifdef ENABLE_PYTHON + // A static factory method to create a Tensor from a given py::array. + // @param ptr output argument to hold the created Tensor + // @param arr py::array + // @return Status Code + static Status CreateTensor(std::shared_ptr *ptr, py::array arr); + + // Helper function to create a tensor from Numpy of strings + static Status CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr); +#endif + + // A static factory method to create a Tensor from a given list of strings. + // @param ptr output argument to hold the created Tensor + // @param strings elements of the tensor + // @param shape shape of the tensor + // @return Status Code + static Status CreateTensor(std::shared_ptr *ptr, const std::vector &strings, + const TensorShape &shape = TensorShape::CreateUnknownRankShape()); + + // create tensor from protobuf bytelist with strings + static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, + const TensorShape &shape); + + // A static factory method to create a Tensor from a given list of numbers. + // @param ptr output argument to hold the created Tensor + // @param items elements of the tensor + // @param shape shape of the tensor + // @return Status Code + template + static Status CreateTensor(std::shared_ptr *ptr, const std::vector &items, + const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) { + DataType type = DataType::FromCType(); + auto items_ptr = reinterpret_cast(&items[0]); + TensorShape shape = shape_req; + if (!shape.known()) { + shape = TensorShape({static_cast(items.size())}); + } + return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr); + } + + // A static factory method to create a Tensor from a given number. + // @param ptr output argument to hold the created Tensor + // @param item value + // @return Status Code + template + static Status CreateTensor(std::shared_ptr *ptr, const T &item) { + return CreateTensor(ptr, {item}, TensorShape::CreateScalar()); + } + + // Create tensor from protobuf bytelist with uint8 or int8 types + static Status CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, + const TensorShape &shape, const DataType &type, dsize_t pad_size); + + static Status CreateTensor(std::shared_ptr *ptr, const std::string &path); + + // Copy raw data of a array based on shape and strides to the destination pointer + // @param dst Pointer to the destination array where the content is to be copied + // @param src Pointer to the source of strided array to be copied + // @param shape - shape of the source array + // @param strides - strides of the source array + // @param type_size - number of bytes needed to store one array element's type + // @return Status Code + static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, + std::vector strides, uint8_t type_size); + + // Release the memory using the allocator + virtual ~Tensor(); + + // compare the tensor shape and data + bool operator==(const Tensor &rhs) const; + + bool operator!=(const Tensor &rhs) const { return !((*this) == rhs); } + + // Get item located at `index`, caller needs to provide the type. + // @tparam T + // @param index vector + // @return return the item specified at index + template + Status GetItemAt(T *o, const std::vector &index) const; + + // Get string located at `index`. + // @param index vector + // @return return std::string_view specified at index + Status GetItemAt(std::string_view *o, const std::vector &index) const; + + template + Status GetUnsignedIntAt(T *o, const std::vector &index) const; + + template + Status GetSignedIntAt(T *o, const std::vector &index) const; + + template + Status GetFloatAt(T *o, const std::vector &index) const; + + // set item at location specified by index + // @tparam `T` + // @param index + // @param value of type `T` + template + Status SetItemAt(const std::vector &index, const T &value) { + RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); + T *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *ptr = value; + return Status::OK(); + } + + // set string item at location specified by index + // @param index + // @param value of type std::string + Status SetItemAt(const std::vector &index, const std::string &value) { + RETURN_UNEXPECTED_IF_NULL(data_); + uchar *ptr = nullptr; + offset_t length = 0; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index, &length)); + if (value.length() != length) { + RETURN_STATUS_UNEXPECTED("Length of the new string does not match the item."); + } + memcpy_s(reinterpret_cast(ptr), length, value.c_str(), length); + + return Status::OK(); + } + // fill tensor with Zeros. Does not support strings. + Status Zero() { + CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use Zero on tensor of strings.."); + dsize_t size = SizeInBytes(); + CHECK_FAIL_RETURN_UNEXPECTED(memset_sp(GetMutableBuffer(), size, 0, size) == 0, + "Failed to fill tensor with zeroes."); + return Status::OK(); + } + + // Fill all elements in the Tensor with the given value of type `T`. Does not support strings. + // @tparam T + // @param value + template + Status Fill(const T &value) { + CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings."); + RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); + int64_t cellSize = type_.SizeInBytes(); + if ((data_ != nullptr) && type_.IsCompatible()) { + for (dsize_t i = 0; i < Size(); i++) { + CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize) == 0, "memcpy err"); + } + return Status::OK(); + } else { + std::string err; + err += (data_ == nullptr) ? "data_ is nullptr \t" : ""; + err += type_.IsCompatible() ? "data type not compatible\t" : ""; + return Status(StatusCode::kUnexpectedError, err); + } + } + + // Getter function for shape + // @return + const TensorShape &shape() const { return shape_; } + + /// Check if tensor has data + /// \return bool - true if tensor is empty + bool HasData() const; + + // Reshape the tensor. The given shape should have the same number of elements in the Tensor + // @param shape + virtual Status Reshape(const TensorShape &shape); + + // @return number of elements in this tensor + dsize_t Size() const { return shape().NumOfElements(); } + + // @return the number of bytes this tensor is needs + dsize_t SizeInBytes() const { + if (data_end_ == nullptr) return type_.SizeInBytes() * shape_.NumOfElements(); + return data_end_ - data_; + } + + // @return the rank of the tensor + dsize_t Rank() const { return shape().Rank(); } + + // Get the starting memory address as a constant for the data of the tensor. This potentially + // drives an allocation if the data area. + // @return const unsigned char* + const unsigned char *GetBuffer() const; + + // Getter of the type + // @return + DataType type() const { return type_; } + + // Provide stream operator for displaying it + // @param output stream + // @param so the Tensor object to be printed + // @return output stream + friend std::ostream &operator<<(std::ostream &out, const Tensor &so) { + so.Print(out); + return out; + } + + // Invalidate this Tensor by setting the type and shape to unknown and MData to null. + // Calling this method will make the Tensor and its data inaccessible, use it with caution. + void Invalidate(); + + // Copy input tensor into self at the location index. + // Index is a vector of axises which can be incomplete: + // Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. + // @param index + // @param input + // @return Status code + Status InsertTensor(const std::vector &index, const std::shared_ptr &input); + + // Find the address of the given index. Used in InsertTensor. + // Example: + // Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 + // @param index incomplete index + // @param output: startAddrofIndex + // @param output: remaining + // @return Status code + Status StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining); + + // Expand the shape of the Tensor with one extra dimension. + // For example, if the shape is <512,512,3>: + // *- ExpandDim(0) gives: <1,512,512,3> + // *- ExpandDim(1) gives: <512,1,512,3> + // *- ExpandDim(3) gives: <512,512,3,1> + // @param axis location of the dim + virtual Status ExpandDim(const dsize_t &axis); + + virtual void Squeeze(); + + // Calculates the strides of the Tensor + // Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) + // The strides will be {6,2,1}. + // Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) + // The strides will be {24,8,4}. + // @return vector of integers + std::vector Strides(); + + std::string ToString() { + std::stringstream ss; + this->Print(ss); + return ss.str(); + } + + // Handle negative indices. + static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } + + // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. + // Based on the type of tensor, SliceNumeric or SliceString will be called + // @param out Tensor + // @param indices vector of indices + // @return Status error code + Status Slice(std::shared_ptr *out, const std::vector &indices); + + // Slice numeric tensors. + Status SliceNumeric(std::shared_ptr *out, const std::vector &indices); + + // Slice string tensors + Status SliceString(std::shared_ptr *out, const std::vector &indices); + +#ifdef ENABLE_PYTHON + // Constructs numpy array from input tensor + // @param data this data is the location of python data + // @return Status code + Status GetDataAsNumpy(py::array *data); + + Status GetDataAsNumpyStrings(py::array *data); + + static Status GetBufferInfo(Tensor *t, py::buffer_info *out); +#endif + + // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor + Status Concatenate(const std::vector &index, const std::shared_ptr &input); + + // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor + // The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6 + // @tparam T type of values in the Tensor Iterator + template + class TensorIterator { + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = ptrdiff_t; + using pointer = T *; + using reference = T &; + + explicit TensorIterator(uchar *ptr = nullptr) { ptr_ = reinterpret_cast(ptr); } + + TensorIterator(const TensorIterator &raw_iterator) { ptr_ = raw_iterator.ptr_; } + + ~TensorIterator() = default; + + TensorIterator &operator=(const TensorIterator &rhs) { + ptr_ = rhs.ptr_; + return *this; + } + + TensorIterator &operator=(T *rhs) { + ptr_ = rhs; + return *this; + } + + bool operator==(const TensorIterator &rhs) { return ptr_ == rhs.ptr_; } + + bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); } + + operator bool() const { return ptr_ != nullptr; } + + T &operator*() { return *ptr_; } + + const T &operator*() const { return *ptr_; } + + T *operator->() { return ptr_; } + + TensorIterator &operator+=(const ptrdiff_t &inc) { + ptr_ += inc; + return *this; + } + + TensorIterator &operator-=(const ptrdiff_t &inc) { + ptr_ -= inc; + return *this; + } + + TensorIterator &operator++() { + ++ptr_; + return *this; + } + + TensorIterator &operator--() { + --ptr_; + return *this; + } + + TensorIterator operator++(int) { + auto temp(*this); + ++ptr_; + return temp; + } + + TensorIterator operator--(int) { + auto temp(*this); + --ptr_; + return temp; + } + + TensorIterator operator+(const ptrdiff_t &inc) { + auto oldPtr = ptr_; + ptr_ += inc; + auto temp(*this); + ptr_ = oldPtr; + return temp; + } + + TensorIterator operator-(const ptrdiff_t &inc) { + auto oldPtr = ptr_; + ptr_ -= inc; + auto temp(*this); + ptr_ = oldPtr; + return temp; + } + + protected: + T *ptr_; + }; + + // Specialization of TensorIterator for strings. It returns std::string_view for every item. + // @tparam DUMMY, used to mbe able to specialize the inner class + template + class TensorIterator { + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = std::string_view; + using difference_type = ptrdiff_t; + using pointer = std::string_view *; + using reference = std::string_view &; + + explicit TensorIterator(uchar *data = nullptr, dsize_t index = 0) { + data_ = reinterpret_cast(data); + index_ = index; + } + + TensorIterator(const TensorIterator &raw_iterator) { + data_ = raw_iterator.data_; + index_ = raw_iterator.index_; + } + + ~TensorIterator() = default; + + bool operator==(const TensorIterator &rhs) { return data_ == rhs.data_ && index_ == rhs.index_; } + + bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); } + + operator bool() const { return data_ != nullptr; } + + std::string_view operator*() const { + auto offset_ = reinterpret_cast(data_); + offset_t start = offset_[index_]; + return std::string_view{data_ + start}; + } + + TensorIterator &operator+=(const dsize_t &inc) { + index_ += inc; + return *this; + } + + TensorIterator &operator-=(const dsize_t &inc) { + index_ -= inc; + return *this; + } + + TensorIterator &operator++() { + ++index_; + return *this; + } + + TensorIterator &operator--() { + --index_; + return *this; + } + + TensorIterator operator++(int) { + auto temp(*this); + ++index_; + return temp; + } + + TensorIterator operator--(int) { + auto temp(*this); + --index_; + return temp; + } + + TensorIterator operator+(const dsize_t &inc) { + auto oldPtr = index_; + index_ += inc; + auto temp(*this); + index_ = oldPtr; + return temp; + } + + TensorIterator operator-(const dsize_t &inc) { + auto oldPtr = index_; + index_ -= inc; + auto temp(*this); + index_ = oldPtr; + return temp; + } + + protected: + dsize_t index_; + const char *data_; + }; + + // Return a TensorIterator that points to the start of the Tensor. + // It's the user responsibility to use the correct type that matches the Tensor type + // @param T The type of values in the Tensor + // @return TensorIterator + template + TensorIterator begin() { + AllocateBuffer(SizeInBytes()); + return TensorIterator(data_); + } + + // Return a linear iterator that points to the place after the last element of the Tensor. + // @tparam T The type of values in the Tensor + // @return TensorIterator + template + TensorIterator end() { + return TensorIterator(data_end_); + } + + // Copies the last dimension at `index` from Tensor `src` to this Tensor. + // @param src Tensor + // @param index vector to the start of the dimension. The last dim should be 0 + // @return Status + Status CopyLastDimAt(const std::shared_ptr &src, const std::vector &index); + + protected: + // Get the starting memory address for the data of the tensor. This potentially + // drives an allocation if the data is null. + // @return unsigned char* + unsigned char *GetMutableBuffer(); + + // A function that prints Tensor recursively, first called by print + // @param out + // @param cur_dim + // @param cur_index + void PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const; + + // A function that prints info about the tensor + // @param out output stream + void Print(std::ostream &out) const; + + // A function that print the value as specified by its index + // @param index vector representing the index + // @param out + void PrintItemAt(const std::vector &index, std::ostream &out) const; + + // Get pointer to item located at `index`, caller needs to provide the type. + // @tparam T + // @param index vector + // @return return a pointer to the item specified at index of type `T` + template + Status GetItemPtr(T **, const std::vector &index) const; + + // Get pointer to string located at `index` and the length of string + // @param index vector + // @return return a pointer to the string specified at index and the length of the string + Status GetItemPtr(uchar **, const std::vector &index, offset_t *length = nullptr) const; + + // Given a flat index of an item string, return the start and length of the item + // @param index flat index of the item + // @return start address of the ths string + // @return length of the string + Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; + + // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the + // tensor's type is a string, otherwise undefined address would be returned. + // @return address of the first string of the tensor. + uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } + + // all access to shape_ should be via shape + TensorShape shape_; + // data type of tensor + DataType type_; + // pointer to the start of the physical data + unsigned char *data_; + // An allocator for data_ + CharAllocPtr data_allocator_; + // pointer to the end of the physical data + unsigned char *data_end_ = nullptr; +}; +template <> +inline Tensor::TensorIterator Tensor::end() { + return TensorIterator(data_, shape_.NumOfElements()); +} +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_row.cc b/mindspore/ccsrc/minddata/dataset/core/tensor_row.cc new file mode 100644 index 0000000000..5d75730a4c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_row.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "minddata/dataset/core/tensor_row.h" + +namespace mindspore { +namespace dataset { + +TensorRow::TensorRow() noexcept : id_(kDefaultRowId) {} + +TensorRow::TensorRow(size_type n, TensorRow::value_type t) noexcept : id_(kDefaultRowId), row_(n, t) {} + +TensorRow::TensorRow(const TensorRow::vector_type &v) : id_(kDefaultRowId), row_(v) {} + +TensorRow::TensorRow(row_id_type id, const std::initializer_list &lst) : id_(id), row_(lst) {} + +TensorRow::TensorRow(const TensorRow &tr) : id_(tr.id_), row_(tr.row_) {} + +TensorRow &TensorRow::operator=(const TensorRow &tr) { + if (this == &tr) { + return *this; + } + row_ = tr.row_; + id_ = tr.id_; + return *this; +} + +TensorRow &TensorRow::operator=(const std::initializer_list &lst) { + row_ = lst; + return *this; +} + +TensorRow::TensorRow(TensorRow::vector_type &&v) noexcept : id_(kDefaultRowId), row_(std::move(v)) {} + +TensorRow::TensorRow(row_id_type id, std::initializer_list &&lst) noexcept + : id_(id), row_(std::move(lst)) {} + +TensorRow::TensorRow(TensorRow &&tr) noexcept { + id_ = tr.id_; + row_ = std::move(tr.row_); +} + +TensorRow &TensorRow::operator=(TensorRow &&tr) noexcept { + if (this == &tr) { + return *this; + } + row_ = std::move(tr.row_); + id_ = tr.id_; + tr.id_ = kDefaultRowId; + return *this; +} + +TensorRow &TensorRow::operator=(std::initializer_list &&lst) noexcept { + row_ = std::move(lst); + return *this; +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_row.h b/mindspore/ccsrc/minddata/dataset/core/tensor_row.h new file mode 100644 index 0000000000..e8f066c87b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_row.h @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_CORE_TENSOR_ROW_H_ +#define DATASET_CORE_TENSOR_ROW_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" + +namespace mindspore { +namespace dataset { + +class TensorRow; // A set of Tensor pointers with an id +using TensorTable = std::vector; // The table of tensors is a vector of rows +using TensorQTable = std::deque; // A different flavour of tensor table, this one has queue functionality + +class TensorRow { + public: + static constexpr row_id_type kDefaultRowId = -1; // Default row id + + // Type definitions + using size_type = dsize_t; + using value_type = std::shared_ptr; + using reference = std::shared_ptr &; + using const_reference = const std::shared_ptr &; + using vector_type = std::vector>; + using iterator = std::vector>::iterator; + using const_iterator = std::vector>::const_iterator; + + TensorRow() noexcept; + + TensorRow(size_type n, value_type t) noexcept; + + // Copy Constructors + explicit TensorRow(const vector_type &v); + + TensorRow(row_id_type id, const std::initializer_list &lst); + + TensorRow(const TensorRow &tr); + + TensorRow &operator=(const TensorRow &tr); + + TensorRow &operator=(const std::initializer_list &lst); + + // Move Constructors + explicit TensorRow(vector_type &&v) noexcept; + + TensorRow(row_id_type id, std::initializer_list &&lst) noexcept; + + TensorRow(TensorRow &&tr) noexcept; + + TensorRow &operator=(TensorRow &&tr) noexcept; + + TensorRow &operator=(std::initializer_list &&lst) noexcept; + + // Destructor + ~TensorRow() = default; + + // Functions to fetch/set id/vector + row_id_type getId() const { return id_; } + + void setId(row_id_type id) { id_ = id; } + + const vector_type &getRow() const { return row_; } + + // Wrapper functions to support vector operations + void emplace_back(value_type t) { row_.emplace_back(t); } + + void push_back(value_type t) { row_.push_back(t); } + + void clear() noexcept { row_.clear(); } + + size_type size() const noexcept { return row_.size(); } + + void reserve(size_type size) { row_.reserve(size); } + + void resize(size_type size) { row_.resize(size); } + + bool empty() { return row_.empty(); } + + void insert(iterator position, iterator first, iterator last) { row_.insert(position, first, last); } + + // Wrapper functions to support vector element access + reference at(size_type index) { return row_.at(index); } + + const_reference at(size_type index) const { return row_.at(index); } + + reference front() { return row_.front(); } + + const_reference front() const { return row_.front(); } + + reference back() { return row_.back(); } + + const_reference back() const { return row_.back(); } + + reference operator[](size_type index) { return row_[index]; } + + const_reference operator[](size_type index) const { return row_[index]; } + + // Wrapper functions to support vector iteration + iterator begin() { return row_.begin(); } + + const_iterator begin() const { return row_.begin(); } + + iterator end() { return row_.end(); } + + const_iterator end() const { return row_.end(); } + + protected: + row_id_type id_; + std::vector> row_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_TENSOR_ROW_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc new file mode 100644 index 0000000000..ff40062d37 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#define MAX_INTEGER_DTYPE 9223372036854775807 + +#include "minddata/dataset/core/tensor_shape.h" + +#include + +#include "common/utils.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/core/constants.h" + +namespace mindspore { +namespace dataset { +constexpr dsize_t TensorShape::kDimUnknown; + +bool multi_ok(dsize_t x, dsize_t y) { + dsize_t p = x * y; + if (x == 0) { + return true; + } + return p / x == y; +} + +dsize_t TensorShape::NumOfElements() const { + if (!known()) { + return 0; + } + return strides_[0]; +} + +void TensorShape::Print(std::ostream &out) const { + if (!known() && raw_shape_.empty()) { + out << ""; + } else { + out << "<"; + for (auto i = 0; i < this->Rank(); i++) { + if (raw_shape_[i] == kDimUnknown) { + out << "*"; + } else { + out << raw_shape_[i]; + } + if (i != this->Rank() - 1) { + out << ","; + } + } + out << ">"; + } +} + +TensorShape::TensorShape(const std::initializer_list &list) + : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { + AddListToShape(list); +} + +TensorShape::TensorShape(const std::vector &list) + : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { + AddListToShape(list); +} + +TensorShape::TensorShape(const TensorShape &shape) + : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { + AddListToShape(shape.AsVector()); + known_ = shape.known_; // override with the input shape in case of unknown-rank tensor shape. +} + +#ifdef ENABLE_PYTHON +TensorShape::TensorShape(py::list l) + : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { + std::vector list_c; + for (auto &i : l) { + if (!i.is_none()) { + list_c.push_back(i.cast()); + } else { + list_c.push_back(TensorShape::kDimUnknown); + } + } + AddListToShape(list_c); +} +#endif + +TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type) + : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { + for (int i = 0; i < cv_size.dims(); i++) { + raw_shape_.push_back(cv_size[i]); + } + auto channels = static_cast(1 + (type >> static_cast(CV_CN_SHIFT))); + if (channels != 1) { + raw_shape_.push_back(channels); + } + known_ = true; +} + +TensorShape TensorShape::CreateUnknownRankShape() { + TensorShape s({}); + s.known_ = false; + return s; +} + +TensorShape TensorShape::InsertDim(dsize_t axis, dsize_t dim) const { + std::vector tmp = AsVector(); + (void)tmp.insert(tmp.begin() + axis, dim); + return TensorShape(tmp); +} + +std::vector TensorShape::AsVector() const { + return std::vector(raw_shape_.begin(), raw_shape_.end()); +} + +bool TensorShape::IsValidIndex(const std::vector &index) const { + dsize_t s_rank = Rank(); + if (index.size() != s_rank) { + return false; + } + for (dsize_t i = 0; i < s_rank; i++) { + if (index[i] < 0 || raw_shape_[i] <= index[i]) { + return false; + } + } + return true; +} + +template +void TensorShape::AddListToShape(const T &list) { + raw_shape_.resize(list.size()); + strides_.resize(list.size() + 1); + strides_[list.size()] = 1; + known_ = true; + dsize_t size = 0; + auto itr = std::rbegin(list); // iterate over the list in reverse order + auto s = list.size() - 1; // to compute strides while adding dims + for (; itr != std::rend(list); itr++, s--) { + dsize_t dim = *itr; + if (dim > 0) { + if (strides_[s + 1] > std::numeric_limits::max() / dim) { + MS_LOG(ERROR) << "Invalid shape data, overflow occurred!"; + known_ = false; + raw_shape_.clear(); + return; + } + strides_[s] = dim * strides_[s + 1]; + } + if (dim < 0) { + known_ = false; + } + if (dim > kDeMaxDim) { + std::stringstream ss; + ss << "Invalid shape data, dim (" << size << ") is larger than the maximum dim size(" << kDeMaxDim << ")!"; + MS_LOG(ERROR) << ss.str().c_str(); + known_ = false; + raw_shape_.clear(); + return; + } + raw_shape_[s] = dim; + size++; + } + if (size > kDeMaxRank) { + std::stringstream ss; + ss << "Invalid shape data, rank (" << size << ") is larger than the maximum rank size(" << kDeMaxRank << ")."; + MS_LOG(ERROR) << ss.str().c_str(); + known_ = false; + raw_shape_.clear(); + return; + } +} + +TensorShape TensorShape::CreateUnknownShapeWithRank(dsize_t rank) { + TensorShape s({}); + for (dsize_t i = 0; i < rank; i++) { + s.raw_shape_.push_back(kDimUnknown); + } + s.known_ = false; + return s; +} + +TensorShape TensorShape::PrependDim(dsize_t dim) const { + if (Size() == 0) { + return TensorShape({dim}); + } + return InsertDim(0, dim); +} + +TensorShape TensorShape::AppendDim(dsize_t dim) const { + auto vec = AsVector(); + vec.push_back(dim); + return TensorShape(vec); +} + +#ifdef ENABLE_PYTHON +py::list TensorShape::AsPyList() { + py::list list; + for (auto i : raw_shape_) { + list.append(i); + } + return list; +} +#endif + +TensorShape TensorShape::Squeeze() const { + std::vector new_shape; + for (auto s : AsVector()) { + if (s != 1) { + new_shape.push_back(s); + } + } + return TensorShape(new_shape); +} + +std::vector TensorShape::Strides() const { return std::vector{strides_.begin() + 1, strides_.end()}; } + +// Name: ToFlatIndex() +// Description: convert a vector style index to number, used to access memory internal use only +Status TensorShape::ToFlatIndex(const std::vector &index, dsize_t *flat_index) const { + *flat_index = 0; + for (size_t k = 0; k < index.size(); k++) { + *flat_index += index[k] * strides_[k + 1]; // skip the first element of strides_ which is numOfElements + } + CHECK_FAIL_RETURN_UNEXPECTED(*flat_index < NumOfElements(), "Not a valid index"); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h new file mode 100644 index 0000000000..4944f9e32c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h @@ -0,0 +1,196 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_TENSOR_SHAPE_H_ +#define DATASET_CORE_TENSOR_SHAPE_H_ + +#include +#include +#include +#include +#include + +#include + +#ifdef ENABLE_PYTHON +#include "pybind11/pybind11.h" +namespace py = pybind11; +#endif + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/allocator.h" + +namespace mindspore { +namespace dataset { +// Class that represents a shape of a Tensor. A shape can be: +// -# Known shape (mKnown = true) +// -# Scalar --> empty vector --> <> +// -# n-Dim --> not empty vector --> where di is >= 0\n +// Example: <1,2>, <1>, <1,13,10,11,1> +// -# Unknown shape (mKnown = false) +// -# Rank is unknown --> empty vector --> <> +// -# one or more dim is unknown --> not empty vector --> where di is unknown\n +// Example: <3,?> (the 1st dim is unknown)\n +// <2,?,?,?> (all dims but the 0th dim are unknown) + +/// \brief TensorShape supports any dim > 0 and < 2^31-1 +class TensorShape { + public: + static constexpr dsize_t kDimUnknown = -1; // constant for an unknown dimension + + // Force the compiler to not create a no-arg constructor + TensorShape() = delete; + + /// \brief Create a Shape from an initialization list (e.g., TensorShape s = {2,2}). + /// If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown + /// \param[in] list + explicit TensorShape(const std::initializer_list &list); + + /// \brief Create a Shape from a vector (e.g., TensorShape s = std::vector({2,2}) ). + /// If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown + /// \param[in] list + explicit TensorShape(const std::vector &list); + + /// \brief Copy constructor + /// \param[in] shape + TensorShape(const TensorShape &shape); + +#ifdef ENABLE_PYTHON + /// \brief construct a TensorShape via a python list + /// \param[in] py::list l - a list object from python + explicit TensorShape(py::list l); +#endif + + ~TensorShape() = default; + + /// \brief Create a scalar Shape (i.e., empty shape with mKnown = true) + /// \return TensorShape + static TensorShape CreateScalar() { return TensorShape({}); } + + /// \brief Create a shape with an unknown rank. + /// \return TensorShape + static TensorShape CreateUnknownRankShape(); + + /// \brief Create a shape with a known rank . + /// \return TensorShape + static TensorShape CreateUnknownShapeWithRank(dsize_t rank); + + /// \brief Insert a new dim into a copy of the current shape. + /// \param[in] dim to be added + /// \param[in] axis the index where dim should be added + /// \return New modified shape + TensorShape InsertDim(dsize_t axis, dsize_t dim) const; + + /// \brief Insert new dim at index 0. For example, <2,4> --> PrependDim(4) --> <4,2,4> + /// \param[in] dim + /// \return + TensorShape PrependDim(dsize_t dim) const; + + /// \brief Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4> + /// \param[in] dim + /// \return + TensorShape AppendDim(dsize_t dim) const; + + /// \brief Create a shape based on OpenCV shape and type + /// \param[in] cv_size + /// \param[in] type int that represent the type in OpenCV, example CV_8U, CV_64S + TensorShape(cv::MatSize cv_size, uint32_t type); + + dsize_t Size() const { return raw_shape_.size(); } + + dsize_t Rank() const { return raw_shape_.size(); } + + bool known() const { return known_; } + + bool empty() const { return raw_shape_.empty(); } + + dsize_t NumOfElements() const; + + bool operator==(const TensorShape &rhs) const { return known_ == rhs.known_ && raw_shape_ == rhs.raw_shape_; } + + bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); } + + dsize_t operator[](const dsize_t index) const { + if (index < 0) return raw_shape_[raw_shape_.size() + index]; + return raw_shape_[index]; + } + + /// \brief Return the Shape as a vector + /// \return + std::vector AsVector() const; + + /// \brief Returns the class info as a string + /// \return + std::string ToString() const { + std::stringstream ss; + ss << *this; + return ss.str(); + } + + /// \brief Actual print function used by operator<< + /// \param out output string stream + void Print(std::ostream &out) const; + + /// \brief << Stream output operator overload + /// This allows you to print the info using stream operators + /// \param[in] out - reference to the output stream being overloaded + /// \param[in] rO - reference to the TensorShape to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const TensorShape &so) { + so.Print(out); + return out; + } + +#ifdef ENABLE_PYTHON + py::list AsPyList(); +#endif + + /// \brief Checks if the given index is a valid index for this tensor. + /// For example: Tensor<3,4> Index<1,1> is valid. But Index<4,1> or <1> are not. + /// \param[in] index + /// \return bool + bool IsValidIndex(const std::vector &index) const; + + TensorShape Squeeze() const; + + std::vector Strides() const; + + /// \brief Returns the location of the item assuming row major memory layout. + /// \param[in] index + /// \param[out] flat_index + /// \return + Status ToFlatIndex(const std::vector &index, dsize_t *flat_index) const; + + private: + // True if known and valid shape, false otherwise + bool known_; + // Vector to keep the dims of the shape. + std::vector raw_shape_; + // Vector to keep the strides of the shape. The size is rank+1 + std::vector strides_; + + /// \brief Internal utility function to iterate over a list, + /// check if the dim is valid and then insert it into the shape. + /// \param[in] list Iterable list + /// \return true if the shape is valid and no overflow would be generated when counting the number of elements. + /// False otherwise. + template + void AddListToShape(const T &list); +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_TENSOR_SHAPE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt new file mode 100644 index 0000000000..e3ead16d05 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt @@ -0,0 +1,26 @@ +add_subdirectory(datasetops) +add_subdirectory(opt) +add_subdirectory(gnn) +add_subdirectory(perf) +add_subdirectory(cache) +if (ENABLE_TDTQUE) + add_subdirectory(tdt) +endif () + +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_library(engine OBJECT + execution_tree.cc + data_buffer.cc + data_schema.cc + dataset_iterator.cc + ) +target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) + +if (ENABLE_TDTQUE) + add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf + engine-cache-client engine-cache-server) +else () + add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf + engine-cache-client engine-cache-server) +endif () diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt new file mode 100644 index 0000000000..5e7ebea176 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt @@ -0,0 +1,8 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_library(engine-cache-client OBJECT + cache_client.cc + cache_request.cc) +add_library(engine-cache-server OBJECT + cache_service.cc + cache_server.cc) diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc new file mode 100644 index 0000000000..04746131bb --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/cache/cache_request.h" +#include "minddata/dataset/util/bit.h" + +namespace mindspore { +namespace dataset { + +// Constructor +CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill) + : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {} + +// print method for display cache details +void CacheClient::Print(std::ostream &out) const { + out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_ + << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_ + << "\n Spilling: " << std::boolalpha << spill_; +} + +Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const { + CacheRowRequest rq(server_connection_id_, cookie()); + RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row)); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + if (row_id_from_server != nullptr) { + *row_id_from_server = rq.GetRowIdAfterCache(); + } + return Status::OK(); +} + +Status CacheClient::WriteBuffer(std::unique_ptr &&in) const { + std::unique_ptr db_ptr = std::move(in); + auto num_rows = db_ptr->NumRows(); + std::vector all_rows; + if (num_rows > 0) { + all_rows.reserve(num_rows); + // Break down the DataBuffer into TensorRow. We will send the requests async + // and then do a final wait. + MemGuard rq_arr; + RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie())); + CacheServer &cs = CacheServer::GetInstance(); + for (auto i = 0; i < num_rows; ++i) { + TensorRow row; + auto rq = rq_arr[i]; + RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); + RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row)); + RETURN_IF_NOT_OK(cs.PushRequest(rq)); + // We can't let row go out of scope. Otherwise it will free all the tensor memory. + // So park it in the vector. When this function go out of scope, its memory + // will be freed. + all_rows.push_back(std::move(row)); + } + // Now we wait for the requests to be done. + for (auto i = 0; i < num_rows; ++i) { + auto rq = rq_arr[i]; + RETURN_IF_NOT_OK(rq->Wait()); + } + } + return Status::OK(); +} + +Status CacheClient::GetRows(const std::vector &row_id, TensorTable *out) const { + RETURN_UNEXPECTED_IF_NULL(out); + BatchFetchRequest rq(server_connection_id_, row_id); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + RETURN_IF_NOT_OK(rq.RestoreRows(out)); + return Status::OK(); +} + +Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { + UniqueLock lck(&mux_); + // To create a cache, we identify ourself at the client by: + // - the shared session id + // - a crc for the tree nodes from the cache downward + // Pack these 2 into a single 64 bit request id + // + // Consider this example: + // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch + // tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch + // These are different trees in a single session, but the user wants to share the cache. + // This is not allowed because the data of these caches are different. + // + // Consider this example: + // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch + // tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch + // These are different trees in the same session, but the cached data is the same, so it is okay + // to allow the sharing of this cache between these pipelines. + + // The CRC is computed by the tree prepare phase and passed to this function when creating the cache. + // If we already have a server_connection_id_, then it means this same cache client has already been used + // to create a cache and some other tree is trying to use the same cache. + // That is allowed, however the crc better match! + if (server_connection_id_) { + if (cache_crc_ != tree_crc) { + RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!"); + } + // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should + // skip the build phase. + lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock. + CacheClient::ServiceStat stat{}; + RETURN_IF_NOT_OK(GetStat(&stat)); + if (stat.cache_service_state == static_cast(CacheService::State::kFetchPhase)) { + return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase"); + } + } else { + cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client + // Combine the session and crc. This will form our client cache identifier. + connection_id_type connection_identification = (static_cast(session_id_) << 32) | cache_crc_; + // Now execute the cache create request using this identifier and other configs + BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone; + if (spill_) { + createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk; + } + if (generate_id) { + createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId; + } + CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + Status rc = rq.Wait(); + if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) { + server_connection_id_ = rq.GetServerConnectionId(); + if (rc.IsOk()) { + // The 1st guy creating the cache will get a cookie back. + // But this object may be shared among pipelines and we don't want + // overwrite it. + cookie_ = rq.cookie(); + } + } + // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the + // CacheOp to bypass the build phase. + return rc; + } + return Status::OK(); +} + +Status CacheClient::PurgeCache() { + UniqueLock lck(&mux_); + PurgeCacheRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + return rq.Wait(); +} + +Status CacheClient::DestroyCache() { + UniqueLock lck(&mux_); + DestroyCacheRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + return rq.Wait(); +} + +Status CacheClient::GetStat(ServiceStat *stat) { + SharedLock lck(&mux_); + RETURN_UNEXPECTED_IF_NULL(stat); + GetStatRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + stat->num_disk_cached = rq.GetNumDiskCached(); + stat->num_mem_cached = rq.GetNumMemCached(); + stat->min_row_id = rq.GetMinRowId(); + stat->max_row_id = rq.GetMaxRowId(); + stat->cache_service_state = rq.GetState(); + return Status::OK(); +} + +Status CacheClient::CacheSchema(const std::unordered_map &map) { + SharedLock lck(&mux_); + CacheSchemaRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map)); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + return Status::OK(); +} + +Status CacheClient::FetchSchema(std::unordered_map *map) { + SharedLock lck(&mux_); + RETURN_UNEXPECTED_IF_NULL(map); + FetchSchemaRequest rq(server_connection_id_); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + *map = rq.GetColumnMap(); + return Status::OK(); +} + +Status CacheClient::BuildPhaseDone() const { + SharedLock lck(&mux_); + BuildPhaseDoneRequest rq(server_connection_id_, cookie()); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h new file mode 100644 index 0000000000..f25db87578 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_CACHE_CLIENT_H_ +#define DATASET_ENGINE_CACHE_CLIENT_H_ + +#include +#include +#include +#include +#include +#include + +#include "./de_tensor_generated.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/util/lock.h" + +namespace mindspore { +namespace dataset { +/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through +/// a CacheClient. Typical tasks including like creating a cache service, cache a data buffer, restore a previously +/// rows, etc. +class CacheClient { + public: + /// \brief Constructor + /// \param session_id A user assigned session id for the current pipeline + /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited + /// \param spill Spill to disk if out of memory + CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill); + + /// \brief Destructor + ~CacheClient() = default; + + /// \brief Getter function for returning the current session id + /// \return session id + uint64_t session_id() const { return session_id_; } + + /// \brief Send a TensorRow to the cache server + /// \param[in] row + /// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset + /// \return return code + Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const; + + /// \brief Send a DataBuffer to the cache server + /// \param in Unique pointer of the DataBuffer to be cached + /// \return return code + Status WriteBuffer(std::unique_ptr &&in) const; + + /// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is + /// any cache miss + /// \param row_id A vector of row id's + /// \param out A TensorTable of TensorRows. + /// \return return code + Status GetRows(const std::vector &row_id, TensorTable *out) const; + + /// \brief Create a cache. + /// \param tree_crc A crc that was generated during tree prepare phase + /// \param generate_id Let the cache service generate row id + /// \return Status object + Status CreateCache(uint32_t tree_crc, bool generate_id); + + /// \brief Purge a cache. Cache can be reused after reset. + /// \return Status object + Status PurgeCache(); + + /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused. + /// \return Status object + Status DestroyCache(); + + /// \brief Get the statistics from a cache. + /// \param[in/out] Pointer to a pre-allocated ServiceStat object + /// \return Status object + struct ServiceStat { + int64_t num_mem_cached; + int64_t num_disk_cached; + row_id_type min_row_id; + row_id_type max_row_id; + int8_t cache_service_state; + }; + Status GetStat(ServiceStat *); + + /// \brief Cache the schema at the cache server + /// \param map The unordered map of the schema + /// \return Status object + Status CacheSchema(const std::unordered_map &map); + + /// \brief Fetch the schema from the cache server + /// \param map Pointer to pre-allocated map object + /// \return Status object. + Status FetchSchema(std::unordered_map *map); + + /// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache + /// client that holds cookie can be allowed to make this request + /// \return Status object + Status BuildPhaseDone() const; + + /// \brief A print method typically used for debugging + /// \param out The output stream to write output to + void Print(std::ostream &out) const; + + /// \brief Stream output operator overload + /// \return the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) { + cc.Print(out); + return out; + } + + /// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it. + /// \return Cookie + std::string cookie() const { return cookie_; } + + private: + mutable RWLock mux_; + uint64_t cache_mem_sz_; + bool spill_; + // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow + // sharing of the cache. + uint32_t session_id_; + uint32_t cache_crc_; + // The server_connection_id_ is the actual id we use for operations after the cache is built + connection_id_type server_connection_id_; + // Some magic cookie returned from the cache server. + std::string cookie_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_CACHE_CLIENT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc new file mode 100644 index 0000000000..3b7fc057a2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -0,0 +1,223 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/cache/cache_request.h" + +namespace mindspore { +namespace dataset { + +Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) { + buffers_.reserve(row.size() + 1); + RETURN_IF_NOT_OK(SerializeTensorRowHeader(row)); + buffers_.push_back(fbb_->GetBufferPointer()); + for (const auto &ts : row) { + buffers_.push_back(ts->GetBuffer()); + } + return Status::OK(); +} + +Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) { + try { + fbb_ = std::make_shared(); + std::vector> v; + std::vector tensor_sz; + v.reserve(row.size()); + tensor_sz.reserve(row.size()); + // We will go through each column in the row. + for (const std::shared_ptr &ts_ptr : row) { + flatbuffers::Offset ts_off; + RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off)); + v.push_back(ts_off); + tensor_sz.push_back(ts_ptr->SizeInBytes()); + } + auto column_off = fbb_->CreateVector(v); + auto data_sz_off = fbb_->CreateVector(tensor_sz); + TensorRowHeaderMsgBuilder row_builder(*fbb_); + row_builder.add_column(column_off); + row_builder.add_data_sz(data_sz_off); + // Pass the row_id even if it may not be known. + row_builder.add_row_id(row.getId()); + row_builder.add_size_of_this(-1); // fill in later after we call Finish. + auto out = row_builder.Finish(); + fbb_->Finish(out); + // Now go back to fill in size_of_this in the flat buffer. + auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer()); + auto success = msg->mutate_size_of_this(fbb_->GetSize()); + if (!success) { + RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); + } + return Status::OK(); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } +} + +Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, + flatbuffers::Offset *out_off) { + RETURN_UNEXPECTED_IF_NULL(out_off); + const Tensor *ts = ts_ptr.get(); + auto shape_off = fbb_->CreateVector(ts->shape().AsVector()); + const auto ptr = ts->GetBuffer(); + if (ptr == nullptr) { + RETURN_STATUS_UNEXPECTED("Tensor buffer is null"); + } + auto src = ts->type().value(); + TensorType dest; +#define CASE(t) \ + case DataType::t: \ + dest = TensorType::TensorType_##t; \ + break + // Map the type to fill in the flat buffer. + switch (src) { + CASE(DE_BOOL); + CASE(DE_INT8); + CASE(DE_UINT8); + CASE(DE_INT16); + CASE(DE_UINT16); + CASE(DE_INT32); + CASE(DE_UINT32); + CASE(DE_INT64); + CASE(DE_UINT64); + CASE(DE_FLOAT16); + CASE(DE_FLOAT32); + CASE(DE_FLOAT64); + CASE(DE_STRING); + default: + MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts; + RETURN_STATUS_UNEXPECTED("Unknown type"); + } +#undef CASE + + TensorMetaMsgBuilder ts_builder(*fbb_); + ts_builder.add_dims(shape_off); + ts_builder.add_type(dest); + auto ts_off = ts_builder.Finish(); + *out_off = ts_off; + return Status::OK(); +} + +Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, + std::shared_ptr *out) { + RETURN_UNEXPECTED_IF_NULL(col_ts); + auto shape_in = col_ts->dims(); + auto type_in = col_ts->type(); + std::vector v; + v.reserve(shape_in->size()); + v.assign(shape_in->begin(), shape_in->end()); + TensorShape shape(v); + DataType::Type dest = DataType::DE_UNKNOWN; +#define CASE(t) \ + case TensorType_##t: \ + dest = DataType::Type::t; \ + break + + switch (type_in) { + CASE(DE_BOOL); + CASE(DE_INT8); + CASE(DE_UINT8); + CASE(DE_INT16); + CASE(DE_UINT16); + CASE(DE_INT32); + CASE(DE_UINT32); + CASE(DE_INT64); + CASE(DE_UINT64); + CASE(DE_FLOAT16); + CASE(DE_FLOAT32); + CASE(DE_FLOAT64); + CASE(DE_STRING); + } +#undef CASE + + DataType type(dest); + std::shared_ptr ts = + std::make_shared(shape, type, static_cast(data.GetPointer()), data.GetSize()); + // Next we restore the real data which can be embedded or stored separately. + if (ts->SizeInBytes() != data.GetSize()) { + MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" + << "Dumping tensor\n" + << *ts << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); + } + *out = std::move(ts); + return Status::OK(); +} + +Status BatchFetchRequest::RestoreRows(TensorTable *out) { + RETURN_UNEXPECTED_IF_NULL(out); + auto num_elements = row_id_.size(); + auto *offset_array = reinterpret_cast(mem_.GetPointer()); + TensorTable tbl; + tbl.reserve(num_elements); + ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes()); + for (auto i = 0; i < num_elements; ++i) { + auto len = offset_array[i + 1] - offset_array[i]; + TensorRow row; + row.setId(row_id_.at(i)); + if (len > 0) { + ReadableSlice row_data(all, offset_array[i], len); + // Next we de-serialize flat buffer to get back each column + auto msg = GetTensorRowHeaderMsg(row_data.GetPointer()); + auto msg_sz = msg->size_of_this(); + // Start of the tensor data + auto ts_offset = msg_sz; + row.reserve(msg->column()->size()); + for (auto k = 0; k < msg->column()->size(); ++k) { + auto col_ts = msg->column()->Get(k); + std::shared_ptr ts; + ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); + RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts)); + row.push_back(ts); + ts_offset += data.GetSize(); + } + } + tbl.push_back(std::move(row)); + } + *out = std::move(tbl); + return Status::OK(); +} + +Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map &map) { + try { + fbb_ = std::make_shared(); + std::vector> v; + v.reserve(map.size()); + for (auto &column : map) { + auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second); + v.push_back(c); + } + auto v_off = fbb_->CreateVector(v); + auto final_off = CreateSchemaMsg(*fbb_, v_off); + fbb_->Finish(final_off); + buf_ = fbb_->GetBufferPointer(); + len_of_buf_ = fbb_->GetSize(); + return Status::OK(); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } +} + +std::unordered_map FetchSchemaRequest::GetColumnMap() { + if (column_name_id_map_.empty()) { + auto *map_msg = flatbuffers::GetRoot(mem_.GetPointer()); + auto v = map_msg->column(); + for (auto i = 0; i < v->size(); ++i) { + auto col = map_msg->column()->Get(i); + column_name_id_map_.emplace(col->name()->str(), col->id()); + } + } + return column_name_id_map_; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h new file mode 100644 index 0000000000..3d0edc6dd8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -0,0 +1,225 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef DATASET_ENGINE_CACHE_REQ_H_ +#define DATASET_ENGINE_CACHE_REQ_H_ + +#include +#include +#include +#include +#include +#include + +#include "./de_tensor_generated.h" +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +/// \brief CacheClient communicates with CacheServer using Requests. +class BaseRequest { + public: + // Request types + enum class RequestType : int16_t { + kCacheRow = 0, + kBatchFetchRows = 1, + kCreateCache = 2, + kPurgeCache = 3, + kDestroyCache = 4, + kGetStat = 5, + kCacheSchema = 6, + kFetchSchema = 7, + kBuildPhaseDone = 8, + // Add new request before it. + kRequestUnknown = 32767 + }; + // For kCreateCache + enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; + friend class CacheServer; + /// \brief Base class of a cache server request + /// \param connection_id A combination of session id and crc that uniquely identifies a connection. + /// \param type Type of the request + explicit BaseRequest(connection_id_type connection_id, RequestType type) + : type_(type), connection_id_(connection_id) {} + virtual ~BaseRequest() = default; + /// \brief Wait for the completion of a request + /// \return Status returned from the cache server + Status Wait() { + RETURN_IF_NOT_OK(wp_.Wait()); + return rc_; + } + + /// \brief Getter function of the current connection id + /// \return Connection id + connection_id_type GetServerConnectionId() const { return connection_id_; } + + private: + RequestType type_; + connection_id_type connection_id_; + Status rc_; + WaitPost wp_; +}; +/// \brief Request to cache a single TensorRow +class CacheRowRequest : public BaseRequest { + public: + friend class CacheServer; + explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie) + : BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {} + ~CacheRowRequest() = default; + + /// \brief Serialize a TensorRow for streaming to the cache server + /// \param row TensorRow + /// \return Status object + Status SerializeCacheRowRequest(const TensorRow &row); + /// \brief Return the row id assigned to this row for non-mappable dataset + /// \return row id of the cached row + row_id_type GetRowIdAfterCache() { return row_id_from_server_; } + + private: + std::shared_ptr fbb_; + row_id_type row_id_from_server_; + std::vector buffers_; + std::string cookie_; + + /// \brief Private function to serialize one TensorRow + /// \param row TensorRow + /// \return Status object + Status SerializeTensorRowHeader(const TensorRow &row); + /// \brief Private function to serialize one Tensor + /// \param ts_ptr Tensor + /// \return Status object + Status SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, flatbuffers::Offset *out_off); +}; +/// \brief Request to fetch rows in batch +class BatchFetchRequest : public BaseRequest { + public: + friend class CacheServer; + friend class CacheService; + BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id) + : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} + Status RestoreRows(TensorTable *out); + + private: + std::vector row_id_; + MemGuard mem_; + Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr *out); +}; +/// \brief Request to create a cache for the current connection +class CreationCacheRequest : public BaseRequest { + public: + friend class CacheServer; + /// \brief Constructor + /// \param connection_id + /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited + /// \param flag Attributes of the cache. + explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz, + CreateCacheFlag flag = CreateCacheFlag::kNone) + : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} + + std::string cookie() const { return cookie_; } + + private: + uint64_t cache_mem_sz; + CreateCacheFlag flag_; + std::string cookie_; +}; +/// \brief Request to purge a cache. +class PurgeCacheRequest : public BaseRequest { + public: + friend class CacheServer; + explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} +}; +/// \brief Request to destroy a cache +class DestroyCacheRequest : public BaseRequest { + public: + friend class CacheServer; + explicit DestroyCacheRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kDestroyCache) {} +}; +/// \brief Obtain the statistics of the current connection +class GetStatRequest : public BaseRequest { + public: + friend class CacheServer; + friend class CacheService; + explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} + row_id_type GetMinRowId() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->min_row_id(); + } + row_id_type GetMaxRowId() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->max_row_id(); + } + int64_t GetNumMemCached() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->num_mem_cached(); + } + int64_t GetNumDiskCached() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->num_disk_cached(); + } + uint8_t GetState() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->state(); + } + + private: + MemGuard mem_; +}; +/// \brief Request to cache a schema +class CacheSchemaRequest : public BaseRequest { + public: + friend class CacheServer; + explicit CacheSchemaRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {} + ~CacheSchemaRequest() = default; + + Status SerializeCacheSchemaRequest(const std::unordered_map &map); + const void *GetBuffer() const { return buf_; } + + private: + std::shared_ptr fbb_; + const void *buf_; + int64_t len_of_buf_; +}; +/// \brief Request to fetch a schema +class FetchSchemaRequest : public BaseRequest { + public: + friend class CacheServer; + explicit FetchSchemaRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kFetchSchema) {} + ~FetchSchemaRequest() = default; + + std::unordered_map GetColumnMap(); + + private: + MemGuard mem_; + std::unordered_map column_name_id_map_; +}; +/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. +class BuildPhaseDoneRequest : public BaseRequest { + public: + friend class CacheServer; + BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) + : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {} + + private: + std::string cookie_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc new file mode 100644 index 0000000000..c9fb6ecab1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc @@ -0,0 +1,252 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/engine/cache/cache_request.h" +#include "minddata/dataset/util/bit.h" + +namespace mindspore { +namespace dataset { +Status CacheServer::DoServiceStart() { + if (!top_.empty()) { + Path spill(top_); + RETURN_IF_NOT_OK(spill.CreateDirectories()); + MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; + } + RETURN_IF_NOT_OK(vg_.ServiceStart()); + cache_q_ = std::make_shared>(1024); + RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); + auto f = std::bind(&CacheServer::ServerRequest, this); + // Spawn a a few threads to serve the request. + for (auto i = 0; i < num_workers_; ++i) { + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f)); + } + return Status::OK(); +} + +Status CacheServer::DoServiceStop() { + Status rc; + Status rc2; + // First stop all the threads. + RETURN_IF_NOT_OK(vg_.ServiceStop()); + // Clean up all the caches if any. + UniqueLock lck(&rwLock_); + auto it = all_caches_.begin(); + while (it != all_caches_.end()) { + auto cs = std::move(it->second); + rc2 = cs->ServiceStop(); + if (rc2.IsError()) { + rc = rc2; + } + ++it; + } + return rc; +} + +CacheService *CacheServer::GetService(connection_id_type id) const { + SharedLock lck(&rwLock_); + auto it = all_caches_.find(id); + if (it != all_caches_.end()) { + return it->second.get(); + } + return nullptr; +} + +Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, + BaseRequest::CreateCacheFlag flag, std::string *out_cookie) { + // We can't do spilling unless this server is setup with a spill path in the first place + bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk; + bool generate_id = + (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId; + if (spill && top_.empty()) { + RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); + } + RETURN_UNEXPECTED_IF_NULL(out_cookie); + *out_cookie = ""; + // Before creating the cache, first check if this is a request for a shared usage of an existing cache + // If two CreateService come in with identical connection_id, we need to serialize the create. + // The first create will be successful and be given a special cookie. + UniqueLock lck(&rwLock_); + auto end = all_caches_.end(); + auto it = all_caches_.find(connection_id); + if (it == end) { + std::unique_ptr cs; + try { + cs = std::make_unique(cache_mem_sz, spill ? top_ : "", generate_id); + RETURN_IF_NOT_OK(cs->ServiceStart()); + *out_cookie = cs->cookie(); + all_caches_.emplace(connection_id, std::move(cs)); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory); + } + } else { + MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; + // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it + // treat it as OK. + return Status(StatusCode::kDuplicateKey); + } + return Status::OK(); +} + +/// This is the main loop the cache server thread(s) are running. +/// Each thread will pop a request and save the result in the same request. +/// The sender will wait on the wait post in the request. Once the request +/// is fulfilled, the server thread will do a post signalling the request is +/// is processed. +/// \return +Status CacheServer::ServerRequest() { + TaskManager::FindMe()->Post(); + // Loop forever until we are interrupted. + while (true) { + BaseRequest *base_rq = nullptr; + RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq)); + auto cs = GetService(base_rq->connection_id_); + // Except for creating a new session, we expect cs is not null. + switch (base_rq->type_) { + case BaseRequest::RequestType::kCacheRow: { + if (cs == nullptr) { + std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + // Only if the cookie matches, we can accept insert into this cache that has a build phase + if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) { + rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_); + } else { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); + } + } + break; + } + case BaseRequest::RequestType::kBatchFetchRows: { + if (cs == nullptr) { + std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_); + } + break; + } + case BaseRequest::RequestType::kCreateCache: { + // If the cache is already created we still need to run the creation so that we do sanity checks on the + // client id and return the cache id back to the user. + auto *rq = reinterpret_cast(base_rq); + rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_); + break; + } + case BaseRequest::RequestType::kPurgeCache: { + if (cs != nullptr) { + base_rq->rc_ = cs->Purge(); + } else { + // it is already purged. Ignore it. + base_rq->rc_ = Status::OK(); + } + break; + } + case BaseRequest::RequestType::kDestroyCache: { + if (cs != nullptr) { + // We need a strong lock to protect the map. + connection_id_type id = base_rq->connection_id_; + UniqueLock lck(&rwLock_); + // std::map will invoke the constructor of CacheService. So we don't need to do anything here. + auto n = all_caches_.erase(id); + if (n == 0) { + // It has been destroyed by another duplicate request. + MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; + } + base_rq->rc_ = Status::OK(); + } else { + // it is already destroyed. Ignore it. + base_rq->rc_ = Status::OK(); + } + break; + } + case BaseRequest::RequestType::kGetStat: { + if (cs == nullptr) { + std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + CacheService::ServiceStat svc_stat; + rq->rc_ = cs->GetStat(&svc_stat); + if (rq->rc_.IsOk()) { + flatbuffers::FlatBufferBuilder fbb; + ServiceStatMsgBuilder bld(fbb); + bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); + bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); + bld.add_max_row_id(svc_stat.max_); + bld.add_min_row_id(svc_stat.min_); + bld.add_state(svc_stat.state_); + auto offset = bld.Finish(); + fbb.Finish(offset); + rq->rc_ = rq->mem_.allocate(fbb.GetSize()); + if (rq->rc_.IsOk()) { + WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize()); + ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize()); + RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); + } + } + } + break; + } + case BaseRequest::RequestType::kCacheSchema: { + if (cs == nullptr) { + std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_); + } + break; + } + case BaseRequest::RequestType::kFetchSchema: { + if (cs == nullptr) { + std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + rq->rc_ = cs->FetchSchema(&rq->mem_); + } + break; + } + case BaseRequest::RequestType::kBuildPhaseDone: { + if (cs == nullptr) { + std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + // We can only allow to switch phase is the cookie match. + if (rq->cookie_ == cs->cookie()) { + rq->rc_ = cs->BuildPhaseDone(); + } else { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); + } + } + break; + } + default: + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type"); + } + // Notify it is done, and move on to the next request. + base_rq->wp_.Set(); + } + return Status::OK(); +} +CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers) + : top_(spill_path), num_workers_(num_workers) {} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h new file mode 100644 index 0000000000..13b68c4389 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef DATASET_ENGINE_CACHE_SERVER_H_ +#define DATASET_ENGINE_CACHE_SERVER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/util/arena.h" +#include "minddata/dataset/util/cache_pool.h" +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/system_pool.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +class BaseRequest; +/// \brief A server which provides CacheService services. +class CacheServer : public Service { + public: + friend class Services; + using cache_index = std::map>; + + CacheServer(const CacheServer &) = delete; + CacheServer &operator=(const CacheServer &) = delete; + CacheServer(CacheServer &&) = delete; + CacheServer &operator=(CacheServer &) = delete; + static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); } + Status DoServiceStart() override; + Status DoServiceStop() override; + ~CacheServer() { (void)ServiceStop(); } + + /// \brief For the current demonstration, a cache client contacts cache server using a Queue. + /// \param rq + /// \return Status object + Status PushRequest(BaseRequest *rq) { + RETURN_UNEXPECTED_IF_NULL(rq); + RETURN_IF_NOT_OK(cache_q_->Add(rq)); + return Status::OK(); + } + + private: + mutable RWLock rwLock_; + std::string top_; + cache_index all_caches_; + std::shared_ptr> cache_q_; + TaskGroup vg_; + int32_t num_workers_; + + /// \brief Constructor + /// \param spill_path Top directory for spilling buffers to. + /// \param num_workers Number of threads for handling requests. + explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3); + + /// \brief Locate a cache service from connection id. + /// \return Pointer to cache service. Null if not found + CacheService *GetService(connection_id_type id) const; + + /// \brief Create a cache service. We allow multiple clients to create the same cache service. + /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given + /// a special unique cookie. + /// \param[in] connection_id This is from a Cache client. + /// \param[in] cache_mem_sz + /// \param[in] flag + /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator + /// \return Status object + Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, + std::string *out_cookie); + + /// \brief Entry point for all server threads. + Status ServerRequest(); +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_CACHE_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc new file mode 100644 index 0000000000..4e1208d173 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc @@ -0,0 +1,265 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/util/slice.h" + +namespace mindspore { +namespace dataset { +CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id) + : root_(root), + cache_mem_sz_(mem_sz), + cp_(nullptr), + map_(nullptr), + next_id_(0), + generate_id_(generate_id), + schema_key_(-1), + st_(generate_id ? State::kBuildPhase : State::kNone) {} +CacheService::~CacheService() { (void)ServiceStop(); } +bool CacheService::UseArena() { + // If fixed size, use Arena instead of the pool from global context. + return (cache_mem_sz_ > 0); +} +Status CacheService::DoServiceStart() { + std::shared_ptr mp_; + if (UseArena()) { + // Create a fixed size arena based on the parameter. + std::shared_ptr arena; + RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); + mp_ = std::move(arena); + } else { + // Unlimited size. Simply use a system pool. Another choice is CircularPool. + mp_ = std::make_shared(); + } + // Put together a CachePool for backing up the Tensor + cp_ = std::make_shared(CachePool::value_allocator(mp_), root_); + RETURN_IF_NOT_OK(cp_->ServiceStart()); + // Set up the B+ tree as well. But use the system pool instead. + map_ = std::make_shared(); + // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. + cookie_ = cp_->MyName(); + return Status::OK(); +} +Status CacheService::DoServiceStop() { + if (cp_ != nullptr) { + RETURN_IF_NOT_OK(cp_->ServiceStop()); + } + return Status::OK(); +} +Status CacheService::CacheRow(const std::vector &buf, row_id_type *row_id_generated) { + SharedLock rw(&rw_lock_); + RETURN_UNEXPECTED_IF_NULL(row_id_generated); + if (st_ == State::kFetchPhase) { + // For this kind of cache service, once we are done with the build phase into fetch phase, we can't + // allow other to cache more rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + try { + // The first buffer is a flatbuffer which describes the rest of the buffers follow + auto fb = buf.front(); + RETURN_UNEXPECTED_IF_NULL(fb); + auto msg = GetTensorRowHeaderMsg(fb); + // If the server side is designed to ignore incoming row id, we generate row id. + if (generate_id_) { + *row_id_generated = GetNextRowId(); + // Some debug information on how many rows we have generated so far. + if ((*row_id_generated) % 1000 == 0) { + MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated; + } + } else { + if (msg->row_id() < 0) { + std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + *row_id_generated = msg->row_id(); + } + auto size_of_this = msg->size_of_this(); + auto column_hdr = msg->column(); + // Number of tensor buffer should match the number of columns plus one. + if (buf.size() != column_hdr->size() + 1) { + std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) + + " but get " + std::to_string(buf.size()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + // Next we store in either memory or on disk. Low level code will consolidate everything in one piece. + std::vector all_data; + all_data.reserve(column_hdr->size() + 1); + all_data.emplace_back(fb, size_of_this); + for (auto i = 0; i < column_hdr->size(); ++i) { + all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); + } + // Now we cache the flat buffer. + CachePool::key_type key; + RETURN_IF_NOT_OK(cp_->Insert(all_data, &key)); + Status rc = map_->DoInsert(*row_id_generated, key); + if (rc == Status(StatusCode::kDuplicateKey)) { + MS_LOG(DEBUG) << "Ignoring duplicate key."; + } else { + RETURN_IF_NOT_OK(rc); + } + return Status::OK(); + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } +} +std::ostream &operator<<(std::ostream &out, const CacheService &cs) { + // Then show any custom derived-internal stuff + out << "\nCache memory size: " << cs.cache_mem_sz_; + out << "\nSpill path: "; + if (cs.root_.empty()) { + out << "None"; + } else { + out << cs.GetSpillPath(); + } + return out; +} +Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } +Status CacheService::Purge() { + // First we must lock exclusively. No one else can cache/restore anything. + UniqueLock rw(&rw_lock_); + RETURN_IF_NOT_OK(cp_->ServiceStop()); + auto new_map = std::make_shared(); + map_.reset(); + map_ = std::move(new_map); + next_id_ = 0; + RETURN_IF_NOT_OK(cp_->ServiceStart()); + return Status::OK(); +} +Status CacheService::GetStat(CacheService::ServiceStat *out) { + SharedLock rw(&rw_lock_); + RETURN_UNEXPECTED_IF_NULL(out); + if (st_ == State::kNone || st_ == State::kFetchPhase) { + out->stat_ = cp_->GetStat(); + out->state_ = static_cast(st_); + auto it = map_->begin(); + if (it != map_->end()) { + out->min_ = it.key(); + auto end_it = map_->end(); + --end_it; + out->max_ = end_it.key(); + } + } else { + out->state_ = static_cast(st_); + } + return Status::OK(); +} +Status CacheService::BatchFetch(const std::vector &v, MemGuard *out) const { + RETURN_UNEXPECTED_IF_NULL(out); + SharedLock rw(&rw_lock_); + if (st_ == State::kBuildPhase) { + // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + const auto num_elements = v.size(); + int64_t mem_sz = (num_elements + 1) * sizeof(int64_t); + int64_t data_offset = mem_sz; + std::vector sz_v; + std::vector keys; + sz_v.reserve(num_elements); + keys.reserve(num_elements); + for (auto row_id : v) { + auto r = map_->Search(row_id); + if (r.second) { + auto &it = r.first; + CachePool::key_type key = it.value(); + auto sz = cp_->GetSize(key); + if (sz == 0) { + std::string errMsg = "Key not found: "; + errMsg += std::to_string(key); + RETURN_STATUS_UNEXPECTED(errMsg); + } + keys.push_back(key); + sz_v.push_back(sz); + mem_sz += sz; + } else { + keys.push_back(-1); + sz_v.push_back(0); + } + } + MemGuard mem; + RETURN_IF_NOT_OK(mem.allocate(mem_sz)); + auto *offset_array = reinterpret_cast(mem.GetMutablePointer()); + offset_array[0] = data_offset; + WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes()); + for (auto i = 0; i < num_elements; ++i) { + auto sz = sz_v.at(i); + offset_array[i + 1] = offset_array[i] + sz; + if (sz > 0) { + WritableSlice row_data(all, offset_array[i], sz); + auto key = keys.at(i); + size_t bytesRead = 0; + RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); + if (bytesRead != sz) { + MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "." + << " Internal key: " << key << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); + } + } + } + *out = std::move(mem); + return Status::OK(); +} +Status CacheService::CacheSchema(const void *buf, int64_t len) { + SharedLock rw(&rw_lock_); + if (st_ == State::kFetchPhase) { + // For this kind of cache service, once we are done with the build phase into fetch phase, we can't + // allow other to cache more rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + // This is a special request and we need to remember where we store it. + // In case we are calling the same function from multiple threads, only + // the first one is considered. Rest is ignored. + CachePool::key_type cur_key = schema_key_; + CachePool::key_type key; + if (cur_key < 0) { + RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key)); + auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key); + MS_LOG(DEBUG) << "Caching Schema. Result = " << result; + } else { + MS_LOG(DEBUG) << "Caching Schema already done"; + } + return Status::OK(); +} +Status CacheService::FetchSchema(MemGuard *out) const { + SharedLock rw(&rw_lock_); + if (st_ == State::kBuildPhase) { + // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + RETURN_UNEXPECTED_IF_NULL(out); + MemGuard mem; + if (schema_key_ >= 0) { + auto len = cp_->GetSize(schema_key_); + RETURN_IF_NOT_OK(mem.allocate(len)); + auto slice = WritableSlice(mem.GetMutablePointer(), len); + RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); + *out = std::move(mem); + } else { + return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); + } + return Status::OK(); +} +Status CacheService::BuildPhaseDone() { + if (HasBuildPhase()) { + // Exclusive lock to switch phase + UniqueLock rw(&rw_lock_); + st_ = State::kFetchPhase; + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h new file mode 100644 index 0000000000..bf324e82e3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef DATASET_ENGINE_CACHE_SERVICE_H_ +#define DATASET_ENGINE_CACHE_SERVICE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "./de_tensor_generated.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/cache/cache_request.h" +#include "minddata/dataset/util/arena.h" +#include "minddata/dataset/util/btree.h" +#include "minddata/dataset/util/cache_pool.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/system_pool.h" + +namespace mindspore { +namespace dataset { +struct CacheStat; +/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is +/// created to support spilling +class CacheService : public Service { + public: + friend class CacheServer; + using row_map = BPlusTree; + + enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase }; + + /// \brief Constructor + /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited + /// \param root Spill path. Empty string means no spilling + /// \param generate_id If the cache service should generate row id for buffer that is cached. + /// For non-mappable dataset, this should be set to true. + CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); + ~CacheService(); + + /// \brief For fixed size memory, we will create an Arena. + /// \return false if unlimited memory. + bool UseArena(); + + Status DoServiceStart() override; + Status DoServiceStop() override; + + /// \brief Main function to cache a row which is in form a series of buffers. + /// The first buffer is a Google flatbuffer which describes the rest of the buffers followed. + /// \param[in] buf Vector of buffer + /// \param[out] row_id_generated The row id assigned to this row if any + /// \return Status object + Status CacheRow(const std::vector &buf, row_id_type *row_id_generated); + /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded + /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. + /// \param[in] v A vector of row id. + /// \param[out] out A contiguous memory buffer that holds the requested rows. + /// \return Status object + Status BatchFetch(const std::vector &v, MemGuard *out) const; + + /// \brief Getter function + /// \return Spilling path + Path GetSpillPath() const; + /// \brief A structure returned from the cache server for statistics request. + class ServiceStat { + public: + using state_type = std::underlying_type::type; + ServiceStat() : min_(0), max_(0), state_(0) {} + CachePool::CacheStat stat_{}; + row_id_type min_; + row_id_type max_; + state_type state_; + }; + /// \brief Statistics for the current service + /// \param[in/out] A pointer to a pre-allocated ServiceStat structure + /// \return Status Object + Status GetStat(ServiceStat *); + /// \brief Cache schema + /// \param buf A Google Flatbuffer that contains the schema + /// \param len size of the buffer + /// \return Status object + Status CacheSchema(const void *buf, int64_t len); + /// \brief Fetch schema + /// \param out A contiguous memory that contains the serialized form of schema. + /// \return Status object + Status FetchSchema(MemGuard *out) const; + /// \brief Purge the content of a cache + /// \return Status object + Status Purge(); + /// \brief Overload the << operator to print a cache service + /// \param out std::ostream + /// \param cs A cache service + /// \return std::ostream + friend std::ostream &operator<<(std::ostream &out, const CacheService &cs); + /// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient + /// is the creator + /// \return Cookie + std::string cookie() const { return cookie_; } + /// \brief If this cache service generates row id for buffer cached, it is divided into two phases, a build phase and + /// a read phase. + /// \return True if has two phases. + bool HasBuildPhase() const { return generate_id_; } + /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. + /// \return Status object + Status BuildPhaseDone(); + + private: + mutable RWLock rw_lock_; + std::string root_; + uint64_t cache_mem_sz_; + std::shared_ptr cp_; + std::shared_ptr map_; + std::atomic next_id_; + bool generate_id_; + std::atomic schema_key_; + std::string cookie_; + State st_; + + /// \brief Private function to generate a row id + /// \return Row id assigned. + row_id_type GetNextRowId() { return next_id_.fetch_add(1); } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs new file mode 100644 index 0000000000..de26069f23 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +namespace mindspore.dataset; + +/// Type of a Tensor +enum TensorType : byte { + DE_UNKNOWN = 0, + DE_BOOL = 1, + DE_INT8 = 2, + DE_UINT8 = 3, + DE_INT16 = 4, + DE_UINT16 = 5, + DE_INT32 = 6, + DE_UINT32 = 7, + DE_INT64 = 8, + DE_UINT64 = 9, + DE_FLOAT16 = 10, + DE_FLOAT32 = 11, + DE_FLOAT64 = 12, + DE_STRING = 13 +} + +/// The meta information of a Tensor +/// \note Only the type and shape are considered meta information. Tensor data is excluded. +table TensorMetaMsg { + dims:[int64] (required); + type:TensorType; +} + +/// This is the first buffer that is sent to a Cache server when a TensorRow is serialized. +/// \param row_id is the row id of the TensorRow. +/// \param column The meta information of each Tensor in the row +/// \param size of this serialized buffer +/// \param size of each tensor data buffer that follows +table TensorRowHeaderMsg { + row_id:int64; + column:[TensorMetaMsg] (required); + size_of_this:int64; + data_sz:[int64] (required); +} + +root_type TensorRowHeaderMsg; + +/// A row of row id's +table TensorRowIds { + row_id:[int64] (required); +} + +/// Statistics returned from each cache service +/// \note It must match CacheService::ServiceStat +table ServiceStatMsg { + num_mem_cached:int64; + num_disk_cached:int64; + min_row_id:int64; + max_row_id:int64; + state:int8; +} + +/// Column description of each column in a schema +table ColumnNameMsg { + name:string; + id:int32; +} + +/// Serialized form of a schema +table SchemaMsg { + column:[ColumnNameMsg]; +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/connector.h b/mindspore/ccsrc/minddata/dataset/engine/connector.h new file mode 100644 index 0000000000..a91d8e68e9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/connector.h @@ -0,0 +1,211 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_CONNECTOR_H_ +#define DATASET_ENGINE_CONNECTOR_H_ + +#include +#include +#include +#include +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/cond_var.h" + +namespace mindspore { +namespace dataset { +// Connector is a communication data structure between two group of threads that +// preserve the order. +// +// Example use case: +// An initial tasks-list of [1,2,3,4,5,6,7,8,9] with 5 threads getting/processing elements from that list, +// and pushing the processed elements to a Connector in any order whoever finishes processing first. +// If the consumer of the Connector is single threaded, when the consumer pop() the +// element from the Connector one by one, it will get [1,2,3,4,5,6,7,8,9]. +// +// Requirements: +// 1. Each thread in the group of consumer or producer threads must be assigned ids starting from 0. +// 2. If your multi-threads program is not reading from a Connector class but +// want to push to a Connector class, you must follow roundrobin element distribution, +// i.e., the thread-id0 must have the first element, thread-id1 has the second element, +// and so on; then each of this worker can push to the Connector class async in parallel. +// +// Blocking conditions: +// 1. Connector.push(int, T) can block when the internal queue it's trying to push is full. +// 2. Connector.pop(int) can block when +// - The internal queue it's trying to pop is empty. +// - The caller thread of pop() is not equal to the _expectConsumer. This is to enforce +// the ordering. +// +// Future improvement: +// 1. Fault tolerant: Right now, if one of the worker dies, the Connector will not work +// properly. +template +class Connector { + public: + // Name: Constructor + // Description: Initializing private members with the given input arguments. + // expect_consumer_ and pop_from_ is initialized to 0 as part of + // our requirements. We instantiate nProducers number of internal + // queues so that each producer thread can push to its queue without + // any sync overhead. + // Constructor of Connector + // Initializing private members with the given input arguments. + // _expectConsumer and _popFrom is initialized to 0 as part of + // our requirements. We instantiate nProducers number of internal + // queues so that each producer thread can push to its queue without + // any sync overhead. + // @param n_producers The number of threads producing data into this DbConnector. + // @param n_consumers The number of thread consuming data from this DbConnector. + // @param queue_capacity The number of element (DataBuffer) for each queue. + Connector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity) + : num_producers_(n_producers), num_consumers_(n_consumers) { + MS_LOG(DEBUG) << "A connector is created with " << n_producers << " producers and " << n_consumers << " consumers."; + my_name_ = Services::GetUniqueID(); + // We require the consumers to have ids sequentially from 0 to the num_consumers_-1, + // Otherwise a ordered list of consumer ids have to be passed here. (not implemented yet) + expect_consumer_ = 0; + + // Roundrobin pop starts from index 0 of the queues_. + pop_from_ = 0; + + // Initialize the queues_ to have num_producers_ number of queues. + // Each queue is a blocking queue and has the same queue_capacity. + queues_.Init(num_producers_, queue_capacity); + } + + // Destructor of Connector + virtual ~Connector() = default; + + // Get an element from the Connector. + // @not Call to pop() can block the caller thread, see the blocking condition at the top of this file. + // @param worker_id The id of a worker thread calling this method. + // @param result The address of an object where the popped element will be placed. + virtual Status Pop(int32_t worker_id, // The worker-id of the caller. See the requirement at the top of this file. + T *result) noexcept { + { + MS_ASSERT(worker_id < num_consumers_); + std::unique_lock lk(m_); + RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return expect_consumer_ == worker_id; })); + RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); + pop_from_ = (pop_from_ + 1) % num_producers_; + out_buffers_count_++; + expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; + } + + cv_.NotifyAll(); + return Status::OK(); + } + + // Add an element into the DbConnector without the overhead of synchronization. + // It may block when the internal queue is full. + // The element passed to this function will be copied into the internal queue. + // @param worker_id The id of a worker thread calling this method. + // @param el A const lvalue element to be passed/added/pushed. + Status Push(int32_t worker_id, const T &el) noexcept { + MS_ASSERT(worker_id < static_cast(queues_.size())); + MS_ASSERT(queues_[worker_id] != nullptr); + return (queues_[worker_id]->Add(el)); + } + + auto out_buffers_count() const { return out_buffers_count_.load(); } + + // Add an element into the DbConnector without the overhead of synchronization. + // It may block when the internal queue is full. + // The element passed to this function will be forwarded into the internal queue. + // @param worker_id The id of a worker thread calling this method. + // @param el An element to be passed/added/pushed. + virtual Status Push(int32_t worker_id, T &&el) noexcept { + MS_ASSERT(worker_id < static_cast(queues_.size())); + MS_ASSERT(queues_[worker_id] != nullptr); + return (queues_[worker_id]->Add(std::forward(el))); + } + + // Resets the internal index tracking of the queue so that it can be used again with new inputs, + // starting from the beginning. + void Reset() { + for (int i = 0; i < queues_.size(); ++i) { + queues_[i]->ResetQue(); + } + expect_consumer_ = 0; + pop_from_ = 0; + out_buffers_count_ = 0; + MS_LOG(DEBUG) << "Connector counters reset."; + } + + void Print(std::ostream &out, bool showAll) const { + out << "\n--------- Connector ------------" + << "\nConnector Name : " << my_name_ << "\nNumber of consumers : " << num_consumers_ + << "\nNumber of producers : " << num_producers_ << "\n"; + } + + friend std::ostream &operator<<(std::ostream &out, const Connector &con) { + con.print(out, false); + return out; + } + + // Get current size of connector. + int32_t size() const { + int32_t size = 0; + for (int32_t i = 0; i < queues_.size(); ++i) { + size += queues_[i]->size(); + } + return size; + } + + int32_t capacity() const { + int32_t capacity = 0; + for (int32_t i = 0; i < queues_.size(); ++i) { + capacity += queues_[i]->capacity(); + } + return capacity; + } + + // Register the internal resources with Task group for interruption service. + // @param vg + // @return + Status Register(TaskGroup *vg) { + Status rc = queues_.Register(vg); + if (rc.IsOk()) { + rc = cv_.Register(vg->GetIntrpService()); + } + return rc; + } + + protected: + std::string my_name_; + + // A list of Queues that are thread safe. + QueueList queues_; + + // The consumer that we allow to get the next data from pop() + int32_t expect_consumer_; + + // The index to the queues_ where the next data should be popped. + int32_t pop_from_; + + int32_t num_producers_; + int32_t num_consumers_; + + // Used in the Pop(), when a thread call pop() but it is not the expect_consumer_. + std::mutex m_; + CondVar cv_; + std::atomic out_buffers_count_ = 0; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_CONNECTOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_buffer.cc b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.cc new file mode 100644 index 0000000000..b36aae6837 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/core/tensor.h" + +namespace mindspore { +namespace dataset { +// Name: Constructor #1 +// Description: This is the main constructor that is used for making a buffer +DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {} + +// A method for debug printing of the buffer +void DataBuffer::Print(std::ostream &out, bool show_all) const { + out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n"; + + // If the column counts are set then it means that data has been set into + // the tensor table. Display the tensor table here. + if (this->NumCols() > 0) { + out << "Tensor table:\n"; + for (int32_t row = 0; row < DataBuffer::NumRows(); ++row) { + out << "Row # : " << row << "\n"; + TensorRow currRow = (*tensor_table_)[row]; + for (int32_t col = 0; col < this->NumCols(); ++col) { + out << "Column #: " << col << "\n"; // Should add the column name here as well? + // Call the tensor display + out << *(currRow[col]) << "\n"; + } + } + } +} + +// Remove me!! Callers should fetch rows via pop +Status DataBuffer::GetTensor(std::shared_ptr *ptr, int32_t row_id, int32_t col_id) const { + if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) { + *ptr = (tensor_table_->at(row_id)).at(col_id); + } else { + std::string err_msg = + "indices for mTensorTable out of range: (" + std::to_string(row_id) + "," + std::to_string(col_id) + ")."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +// Remove me!! Callers should fetch rows via pop +Status DataBuffer::GetRow(int32_t row_id, TensorRow *ptr) const { + if (tensor_table_ && !tensor_table_->empty() && row_id < tensor_table_->size()) { + *ptr = tensor_table_->at(row_id); + } else { + std::string err_msg = "rowId for mTensorTable out of range: " + std::to_string(row_id); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + return Status::OK(); +} + +Status DataBuffer::PopRow(TensorRow *ptr) { + if (tensor_table_ && !tensor_table_->empty()) { + *ptr = std::move(tensor_table_->front()); + tensor_table_->pop_front(); + } + + return Status::OK(); +} + +Status DataBuffer::SliceOff(int64_t number_of_rows) { + while (number_of_rows > 0) { + tensor_table_->pop_back(); + number_of_rows--; + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h new file mode 100644 index 0000000000..5fcb4c21a5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATA_BUFFER_H_ +#define DATASET_ENGINE_DATA_BUFFER_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" + +namespace mindspore { +namespace dataset { +/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between +/// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format +/// where n TensorRows may consist of m tensors (columns). +class DataBuffer { + public: + // Buffer flags + enum BufferFlags : uint32_t { + kDeBFlagNone = 0, + kDeBFlagEOF = 1, // The buffer is an eof end-of-data msg + kDeBFlagEOE = 1u << 1 // The buffer is an eoe end-of-epoch msg + }; + + // Name: Constructor #1 + // Description: This is the main constructor that is used for making a buffer + DataBuffer(int32_t id, BufferFlags flags); + + /// \brief default destructor + ~DataBuffer() = default; + + /// \brief A method for debug printing of the buffer + /// \param[inout] out The stream to write to + /// \param[in] show_all A boolean to toggle between details and summary printing + void Print(std::ostream &out, bool show_all) const; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) { + cb.Print(out, false); + return out; + } + + // Convenience getter functions for flag checking + bool eof() const { return (static_cast(buffer_flags_) & static_cast(kDeBFlagEOF)); } + + bool eoe() const { return (static_cast(buffer_flags_) & static_cast(kDeBFlagEOE)); } + + // Simple getter funcs + int32_t id() const { return buffer_id_; } + + void set_id(int32_t id) { buffer_id_ = id; } + + int32_t NumRows() const { return ((tensor_table_) ? tensor_table_->size() : 0); } + + int32_t NumCols() const { + return (tensor_table_ == nullptr || tensor_table_->empty()) ? 0 : tensor_table_->at(0).size(); + } + + BufferFlags buffer_flags() const { return buffer_flags_; } + + // Remove me!! Callers should fetch rows via pop + Status GetTensor(std::shared_ptr *, int32_t row_id, int32_t col_id) const; + + // Remove me!! Callers should drain rows via pop. + Status GetRow(int32_t row_id, TensorRow *) const; + + // Get a row from the TensorTable + Status PopRow(TensorRow *); + + Status SliceOff(int64_t number_of_rows); + + // Replacing mTensorTable, the unique_ptr assignment will release the old TensorTable. + void set_tensor_table(std::unique_ptr new_table) { tensor_table_ = std::move(new_table); } + + void set_flag(BufferFlags in_flag) { + buffer_flags_ = static_cast(static_cast(buffer_flags_) | static_cast(in_flag)); + } + + void Shuffle() {} // does nothing right now. possibly remove later + + protected: + int32_t buffer_id_; // An id for the buffer. + std::unique_ptr tensor_table_; // A table (row major) of Tensors + BufferFlags buffer_flags_; // bit mask for various buffer properties +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATA_BUFFER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc b/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc new file mode 100644 index 0000000000..50d910251d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc @@ -0,0 +1,451 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/data_schema.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +// A macro for converting an input string representing the column type to it's actual +// numeric column type. +#define STR_TO_TENSORIMPL(in_col_str, out_type) \ + do { \ + if (in_col_str == "cvmat") { \ + out_type = TensorImpl::kCv; \ + } else if (in_col_str == "flex") { \ + out_type = TensorImpl::kFlexible; \ + } else if (in_col_str == "np") { \ + out_type = TensorImpl::kNP; \ + } else { \ + out_type = TensorImpl::kNone; \ + } \ + } while (false) + +// Constructor 1: Simple constructor that leaves things uninitialized. +ColDescriptor::ColDescriptor() + : type_(DataType::DE_UNKNOWN), rank_(0), tensor_impl_(TensorImpl::kNone), tensor_shape_(nullptr) {} + +// Constructor 2: Main constructor +ColDescriptor::ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank, + const TensorShape *in_shape) + : type_(col_type), rank_(rank), tensor_impl_(tensor_impl), col_name_(col_name) { + // If a shape was provided, create unique pointer for it and copy construct it into + // our shape. Otherwise, set our shape to be empty. + if (in_shape != nullptr) { + // Create a shape and copy construct it into our column's shape. + tensor_shape_ = std::make_unique(*in_shape); + } else { + tensor_shape_ = nullptr; + } + // If the user input a shape, then the rank of the input shape needs to match + // the input rank + if (in_shape != nullptr && in_shape->known() && in_shape->Size() != rank_) { + rank_ = in_shape->Size(); + MS_LOG(WARNING) << "Rank does not match the number of dimensions in the provided shape." + << " Overriding rank with the number of dimensions in the provided shape."; + } +} + +// Explicit copy constructor is required +ColDescriptor::ColDescriptor(const ColDescriptor &in_cd) + : type_(in_cd.type_), rank_(in_cd.rank_), tensor_impl_(in_cd.tensor_impl_), col_name_(in_cd.col_name_) { + // If it has a tensor shape, make a copy of it with our own unique_ptr. + tensor_shape_ = in_cd.hasShape() ? std::make_unique(in_cd.shape()) : nullptr; +} + +// Assignment overload +ColDescriptor &ColDescriptor::operator=(const ColDescriptor &in_cd) { + if (&in_cd != this) { + type_ = in_cd.type_; + rank_ = in_cd.rank_; + tensor_impl_ = in_cd.tensor_impl_; + col_name_ = in_cd.col_name_; + // If it has a tensor shape, make a copy of it with our own unique_ptr. + tensor_shape_ = in_cd.hasShape() ? std::make_unique(in_cd.shape()) : nullptr; + } + return *this; +} + +// Destructor +ColDescriptor::~ColDescriptor() = default; + +// A print method typically used for debugging +void ColDescriptor::Print(std::ostream &out) const { + out << " Name : " << col_name_ << "\n Type : " << type_ << "\n Rank : " << rank_ + << "\n Shape : ("; + if (tensor_shape_) { + out << *tensor_shape_ << ")\n"; + } else { + out << "no shape provided)\n"; + } +} + +// Given a number of elements, this function will compute what the actual Tensor shape would be. +// If there is no starting TensorShape in this column, or if there is a shape but it contains +// an unknown dimension, then the output shape returned shall resolve dimensions as needed. +Status ColDescriptor::MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const { + if (out_shape == nullptr) { + RETURN_STATUS_UNEXPECTED("Unexpected null output shape argument."); + } + + // If the shape is not given in this column, then we assume the shape will be: {numElements} + if (tensor_shape_ == nullptr) { + if (this->rank() == 0 && num_elements == 1) { + *out_shape = TensorShape::CreateScalar(); + return Status::OK(); + } + *out_shape = TensorShape({num_elements}); + return Status::OK(); + } + + // Build the real TensorShape based on the requested shape and the number of elements in the data. + // If there are unknown dimensions, then the unknown dimension needs to be filled in. + // Example: requestedShape: {?,4,3}. + // If numElements is 24, then the output shape can be computed to: {2,4,3} + std::vector requested_shape = tensor_shape_->AsVector(); + int64_t num_elements_of_shape = 1; // init to 1 as a starting multiplier. + + // unknownDimPosition variable is overloaded to provide 2 meanings: + // 1) If it's set to DIM_UNKNOWN, then it provides a boolean knowledge to tell us if there are + // any unknown dimensions. i.e. if it's set to unknown, then there are no unknown dimensions. + // 2) If it's set to a numeric value, then this is the vector index position within the shape + // where the single unknown dimension can be found. + int64_t unknown_dim_position = TensorShape::kDimUnknown; // Assume there are no unknown dims to start + + for (int i = 0; i < requested_shape.size(); ++i) { + // If we already had an unknown dimension, then we cannot have a second unknown dimension. + // We only support the compute of a single unknown dim. + if (requested_shape[i] == TensorShape::kDimUnknown && unknown_dim_position != TensorShape::kDimUnknown) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Requested shape has more than one unknown dimension!"); + } + + // If the current dimension in the requested shape is a known value, then compute the number of + // elements so far. + if (requested_shape[i] != TensorShape::kDimUnknown) { + num_elements_of_shape *= requested_shape[i]; + } else { + // This dimension is unknown so track which dimension position has it. + unknown_dim_position = i; + } + } + + // Sanity check the the computed element counts divide evenly into the input element count + if (num_elements < num_elements_of_shape || num_elements_of_shape == 0 || num_elements % num_elements_of_shape != 0) { + RETURN_STATUS_UNEXPECTED("Requested shape has an invalid element count!"); + } + + // If there was any unknown dimensions, then update the requested shape to fill in the unknown + // dimension with the correct value. If there were no unknown dim's then the output shape will + // remain to be the same as the requested shape. + if (unknown_dim_position != TensorShape::kDimUnknown) { + requested_shape[unknown_dim_position] = (num_elements / num_elements_of_shape); + } + + // Any unknown dimension is filled in now. Set the output shape + *out_shape = TensorShape(requested_shape); + return Status::OK(); +} + +// getter function for the shape +TensorShape ColDescriptor::shape() const { + if (tensor_shape_ != nullptr) { + return *tensor_shape_; // copy construct a shape to return + } else { + return TensorShape::CreateUnknownRankShape(); // empty shape to return + } +} + +const char DataSchema::DEFAULT_DATA_SCHEMA_FILENAME[] = "datasetSchema.json"; + +// Constructor 1: Simple constructor that leaves things uninitialized. +DataSchema::DataSchema() : num_rows_(0) {} + +// Internal helper function. Parses the json schema file in any order and produces a schema that +// does not follow any particular order (json standard does not enforce any ordering protocol). +// This one produces a schema that contains all of the columns from the schema file. +Status DataSchema::AnyOrderLoad(nlohmann::json column_tree) { + // Iterate over the json file. Each parent json node is the column name, + // followed by the column properties in the child tree under the column. + // Outer loop here iterates over the parents (i.e. the column name) + if (!column_tree.is_array()) { + for (nlohmann::json::iterator it = column_tree.begin(); it != column_tree.end(); ++it) { + std::string col_name = it.key(); + nlohmann::json column_child_tree = it.value(); + RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, col_name)); + } + } else { + // Case where the schema is a list of columns not a dict + for (nlohmann::json::iterator it = column_tree.begin(); it != column_tree.end(); ++it) { + nlohmann::json column_child_tree = it.value(); + RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, "")); + } + } + return Status::OK(); +} + +// Internal helper function. For each input column name, perform a lookup to the json document to +// find the matching column. When the match is found, process that column to build the column +// descriptor and add to the schema in the order in which the input column names are given.id +Status DataSchema::ColumnOrderLoad(nlohmann::json column_tree, const std::vector &columns_to_load) { + if (!column_tree.is_array()) { + // the json file is dict (e.g., {image: ...}) + // Loop over the column name list + for (const auto &curr_col_name : columns_to_load) { + // Find the column in the json document + auto column_info = column_tree.find(common::SafeCStr(curr_col_name)); + if (column_info == column_tree.end()) { + RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); + } + // At this point, columnInfo.value() is the subtree in the json document that contains + // all of the data for a given column. This data will formulate our schema column. + const std::string &col_name = column_info.key(); + nlohmann::json column_child_tree = column_info.value(); + RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, col_name)); + } + } else { + // the json file is array (e.g., [name: image...]) + // Loop over the column name list + for (const auto &curr_col_name : columns_to_load) { + // Find the column in the json document + int32_t index = -1; + int32_t i = 0; + for (const auto &it_child : column_tree.items()) { + auto name = it_child.value().find("name"); + if (name == it_child.value().end()) { + RETURN_STATUS_UNEXPECTED("Name field is missing for this column."); + } + if (name.value() == curr_col_name) { + index = i; + break; + } + i++; + } + if (index == -1) { + RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); + } + nlohmann::json column_child_tree = column_tree[index]; + RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, curr_col_name)); + } + } + return Status::OK(); +} + +// Internal helper function for parsing shape info and building a vector for the shape construction. +static Status buildShape(const nlohmann::json &shapeVal, std::vector *outShape) { + if (outShape == nullptr) { + RETURN_STATUS_UNEXPECTED("null output shape"); + } + if (shapeVal.empty()) return Status::OK(); + + // Iterate over the integer list and add those values to the output shape tensor + auto items = shapeVal.items(); + using it_type = decltype(items.begin()); + (void)std::transform(items.begin(), items.end(), std::back_inserter(*outShape), [](it_type j) { return j.value(); }); + return Status::OK(); +} + +// Internal helper function. Given the json tree for a given column, load it into our schema. +Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name) { + int32_t rank_value = -1; + TensorImpl t_impl_value = TensorImpl::kFlexible; + std::string name, type_str; + std::vector tmp_shape = {}; + bool shape_field_exists = false; + // Iterate over this column's attributes. + // Manually iterating each of the child nodes/trees here so that we can provide our own error handling. + for (const auto &it_child : column_child_tree.items()) { + // Save the data for each of the attributes into variables. We'll use these to construct later. + if (it_child.key() == "name") { + name = it_child.value(); + } else if (it_child.key() == "type") { + type_str = it_child.value(); + } else if (it_child.key() == "rank") { + rank_value = it_child.value(); + } else if (it_child.key() == "t_impl") { + STR_TO_TENSORIMPL(it_child.value(), t_impl_value); + } else if (it_child.key() == "shape") { + shape_field_exists = true; + RETURN_IF_NOT_OK(buildShape(it_child.value(), &tmp_shape)); + } else { + std::string err_msg = "Unexpected column attribute " + it_child.key() + " for column " + col_name; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + if (!name.empty()) { + if (!col_name.empty() && col_name != name) { + std::string err_msg = + "json schema file for column " + col_name + " has column name that does not match columnsToLoad"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } else { + if (col_name.empty()) { + std::string err_msg = "json schema file for column " + col_name + " has invalid or missing column name."; + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + name = col_name; + } + } + // data type is mandatory field + if (type_str.empty()) + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "json schema file for column " + col_name + " has invalid or missing column type."); + + // rank number is mandatory field + if (rank_value <= -1) + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "json schema file for column " + col_name + " must define a positive rank value."); + + // Create the column descriptor for this column from the data we pulled from the json file + TensorShape col_shape = TensorShape(tmp_shape); + if (shape_field_exists) + (void)this->AddColumn(ColDescriptor(name, DataType(type_str), t_impl_value, rank_value, &col_shape)); + else + // Create a column descriptor that doesn't have a shape + (void)this->AddColumn(ColDescriptor(name, DataType(type_str), t_impl_value, rank_value)); + return Status::OK(); +} + +// Parses a schema json file and populates the columns and meta info. +Status DataSchema::LoadSchemaFile(const std::string &schema_file_path, + const std::vector &columns_to_load) { + try { + std::ifstream in(schema_file_path); + + nlohmann::json js; + in >> js; + RETURN_IF_NOT_OK(PreLoadExceptionCheck(js)); + try { + num_rows_ = js.at("numRows").get(); + } catch (nlohmann::json::out_of_range &e) { + num_rows_ = 0; + } catch (nlohmann::json::exception &e) { + RETURN_STATUS_UNEXPECTED("Unable to parse \"numRows\" from schema"); + } + nlohmann::json column_tree = js.at("columns"); + if (column_tree.empty()) { + RETURN_STATUS_UNEXPECTED("columns is null"); + } + if (columns_to_load.empty()) { + // Parse the json tree and load the schema's columns in whatever order that the json + // layout decides + RETURN_IF_NOT_OK(this->AnyOrderLoad(column_tree)); + } else { + RETURN_IF_NOT_OK(this->ColumnOrderLoad(column_tree, columns_to_load)); + } + } catch (const std::exception &err) { + // Catch any exception and convert to Status return code + RETURN_STATUS_UNEXPECTED("Schema file failed to load"); + } + return Status::OK(); +} + +// Parses a schema json string and populates the columns and meta info. +Status DataSchema::LoadSchemaString(const std::string &schema_json_string, + const std::vector &columns_to_load) { + try { + nlohmann::json js = nlohmann::json::parse(schema_json_string); + RETURN_IF_NOT_OK(PreLoadExceptionCheck(js)); + num_rows_ = js.value("numRows", 0); + nlohmann::json column_tree = js.at("columns"); + if (column_tree.empty()) { + RETURN_STATUS_UNEXPECTED("columns is null"); + } + if (columns_to_load.empty()) { + // Parse the json tree and load the schema's columns in whatever order that the json + // layout decides + RETURN_IF_NOT_OK(this->AnyOrderLoad(column_tree)); + } else { + RETURN_IF_NOT_OK(this->ColumnOrderLoad(column_tree, columns_to_load)); + } + } catch (const std::exception &err) { + // Catch any exception and convert to Status return code + RETURN_STATUS_UNEXPECTED("Schema file failed to load"); + } + return Status::OK(); +} + +// Destructor +DataSchema::~DataSchema() = default; + +// Getter for the ColDescriptor by index +const ColDescriptor &DataSchema::column(int32_t idx) const { + MS_ASSERT(idx < static_cast(col_descs_.size())); + return col_descs_[idx]; +} + +// A print method typically used for debugging +void DataSchema::Print(std::ostream &out) const { + out << "Dataset schema: ("; + for (const auto &col_desc : col_descs_) { + out << col_desc << "\n"; + } +} + +// Adds a column descriptor to the schema +Status DataSchema::AddColumn(const ColDescriptor &cd) { + // Sanity check there's not a duplicate name before adding the column + for (int32_t i = 0; i < col_descs_.size(); ++i) { + if (col_descs_[i].name() == cd.name()) { + std::ostringstream ss; + ss << "column name '" << cd.name() << "' already exists in schema."; + std::string err_msg = ss.str(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + col_descs_.push_back(cd); + return Status::OK(); +} + +// Internal helper function. Performs sanity checks on the json file setup. +Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) { + // Check if columns node exists. It is required for building schema from file. + if (js.find("columns") == js.end()) + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "\"columns\" node is required in the schema json file."); + return Status::OK(); +} + +// Loops through all columns in the schema and returns a map with the column +// name to column index number. +Status DataSchema::GetColumnNameMap(std::unordered_map *out_column_name_map) { + if (out_column_name_map == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map."); + } + + for (int32_t i = 0; i < col_descs_.size(); ++i) { + if (col_descs_[i].name().empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Constructing column name map from schema, but found empty column name."); + } + (*out_column_name_map)[col_descs_[i].name()] = i; + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_schema.h b/mindspore/ccsrc/minddata/dataset/engine/data_schema.h new file mode 100644 index 0000000000..96f6f2b118 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/data_schema.h @@ -0,0 +1,208 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATA_SCHEMA_H_ +#define DATASET_ENGINE_DATA_SCHEMA_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +/// \class ColDescriptor data_schema.h +/// \brief A simple class to provide meta info about a column. +class ColDescriptor { + public: + /// \brief Constructor 1: Simple constructor that leaves things uninitialized. + ColDescriptor(); + + /// \brief Constructor 2: Main constructor + /// \param[in] col_name - The name of the column + /// \param[in] col_type - The DE Datatype of the column + /// \param[in] tensor_impl - The (initial) type of tensor implementation for the column + /// \param[in] rank - The number of dimension of the data + /// \param[in] in_shape - option argument for input shape + ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank, + const TensorShape *in_shape = nullptr); + + /// \brief Explicit copy constructor is required + /// \param[in] in_cd - the source ColDescriptor + ColDescriptor(const ColDescriptor &in_cd); + + /// \brief Assignment overload + /// \param in_cd - the source ColDescriptor + ColDescriptor &operator=(const ColDescriptor &in_cd); + + /// \brief Destructor + ~ColDescriptor(); + + /// \brief A print method typically used for debugging + /// \param out - The output stream to write output to + void Print(std::ostream &out) const; + + /// \brief Given a number of elements, this function will compute what the actual Tensor shape would be. + /// If there is no starting TensorShape in this column, or if there is a shape but it contains + /// an unknown dimension, then the output shape returned shall resolve dimensions as needed. + /// \param[in] num_elements - The number of elements in the data for a Tensor + /// \param[inout] out_shape - The materialized output Tensor shape + /// \return Status - The error code return + Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; + + /// \brief << Stream output operator overload + /// This allows you to write the debug print info using stream operators + /// \param[in] out - reference to the output stream being overloaded + /// \param[in] cd - reference to the ColDescriptor to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const ColDescriptor &cd) { + cd.Print(out); + return out; + } + + /// \brief getter function + /// \return The column's DataType + DataType type() const { return type_; } + + /// \brief getter function + /// \return The column's rank + int32_t rank() const { return rank_; } + + /// \brief getter function + /// \return The column's name + std::string name() const { return col_name_; } + + /// \brief getter function + /// \return The column's shape + TensorShape shape() const; + + /// \brief getter function + /// \return TF if the column has an assigned fixed shape. + bool hasShape() const { return tensor_shape_ != nullptr; } + + /// \brief getter function + /// \return The column's tensor implementation type + TensorImpl tensorImpl() const { return tensor_impl_; } + + private: + DataType type_; // The columns type + int32_t rank_; // The rank for this column (number of dimensions) + TensorImpl tensor_impl_; // The initial flavour of the tensor for this column + std::unique_ptr tensor_shape_; // The fixed shape (if given by user) + std::string col_name_; // The name of the column +}; + +/// \class DataSchema data_schema.h +/// \brief A list of the columns. +class DataSchema { + public: + /// \brief Constructor + DataSchema(); + + /// \brief Destructor + ~DataSchema(); + + /// \brief Parses a schema json file and populates the columns and meta info. + /// \param[in] schema_file_path - the schema file that has the column's info to load + /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. + /// \return Status - The error code return + Status LoadSchemaFile(const std::string &schema_file_path, const std::vector &columns_to_load); + + /// \brief Parses a schema JSON string and populates the columns and meta info. + /// \param[in] schema_json_string - the schema file that has the column's info to load + /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. + /// \return Status - The error code return + Status LoadSchemaString(const std::string &schema_json_string, const std::vector &columns_to_load); + + /// \brief A print method typically used for debugging + /// \param[in] out - The output stream to write output to + void Print(std::ostream &out) const; + + /// \brief << Stream output operator overload. This allows you to write the debug print info using stream operators + /// \param[in] out - reference to the output stream being overloaded + /// \param[in] ds - reference to the DataSchema to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const DataSchema &ds) { + ds.Print(out); + return out; + } + + /// \brief Adds a column descriptor to the schema + /// \param[in] cd - The ColDescriptor to add + /// \return Status - The error code return + Status AddColumn(const ColDescriptor &cd); + + /// \brief getter + /// \return The reference to a ColDescriptor to get (const version) + const ColDescriptor &column(int32_t idx) const; + + /// \brief getter + /// \return The number of columns in the schema + int32_t NumColumns() const { return col_descs_.size(); } + + bool Empty() const { return NumColumns() == 0; } + + /// \brief getter + /// \return The number of rows read from schema + int64_t num_rows() const { return num_rows_; } + + static const char DEFAULT_DATA_SCHEMA_FILENAME[]; + + /// \brief Loops through all columns in the schema and returns a map with the column name to column index number. + /// \param[inout] out_column_name_map - The output map of columns names to column index + /// \return Status - The error code return + Status GetColumnNameMap(std::unordered_map *out_column_name_map); + + private: + /// \brief Internal helper function. Parses the json schema file in any order and produces a schema that + /// does not follow any particular order (json standard does not enforce any ordering protocol). + /// This one produces a schema that contains all of the columns from the schema file. + /// \param[in] column_tree - The nlohmann tree from the json file to parse + /// \return Status - The error code return + Status AnyOrderLoad(nlohmann::json column_tree); + + /// \brief Internal helper function. For each input column name, perform a lookup to the json document to + /// find the matching column. When the match is found, process that column to build the column + /// descriptor and add to the schema in the order in which the input column names are given. + /// \param[in] column_tree - The nlohmann tree from the json file to parse + /// \param[in] columns_to_load - list of strings for the columns to add to the schema + /// \return Status - The error code return + Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector &columns_to_load); + + /// \brief Internal helper function. Given the json tree for a given column, load it into our schema. + /// \param[in] columnTree - The nlohmann child tree for a given column to load. + /// \param[in] col_name - The string name of the column for that subtree. + /// \return Status - The error code return + Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); + + /// \brief Internal helper function. Performs sanity checks on the json file setup. + /// \param[in] js - The nlohmann tree for the schema file + /// \return Status - The error code return + Status PreLoadExceptionCheck(const nlohmann::json &js); + + std::vector col_descs_; // Vector of column descriptors + int64_t num_rows_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATA_SCHEMA_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc new file mode 100644 index 0000000000..f75ca5d097 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc @@ -0,0 +1,268 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/dataset_iterator.h" +#include +#include +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" + +namespace mindspore { +namespace dataset { +// Constructor of the IteratorBase +IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false) {} + +IteratorBase::~IteratorBase() = default; + +// Fetches one row of data from the iterator as a column map. +Status IteratorBase::GetNextAsMap(TensorMap *out_map) { + if (out_map == nullptr) { + RETURN_STATUS_UNEXPECTED("Null output map in iterator!"); + } + + out_map->clear(); + + TensorRow curr_row; + RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row)); + + // Return empty map if there's no data + if (curr_row.empty()) { + return Status::OK(); + } + + // The column name mapping is needed to be able to produce the tensor map output. + // The column name mapping comes from the source operator that is producing the data into the iterator. + // To avoid having to fetch this for every time, we'll take a local copy of the column name id mapping + // and save in the iterator. We only have to do this once. All subsequent iterations use the same mapping. + if (col_name_id_map_.empty()) { + // Determine the column name map by calling the derived class method to retrieve the column + // name map + col_name_id_map_ = this->GetColumnNameMap(); + } + + // Populate the out map from the row and return it + for (auto colMap : col_name_id_map_) { + (*out_map)[colMap.first] = std::move(curr_row[colMap.second]); + } + + return Status::OK(); +} + +// Fetches one row of data from the iterator. +// The base class version simply performs error handling and returns empty row. Actual +// functionality exists in the derived versions of this function. +Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) { + if (out_row == nullptr) { + RETURN_STATUS_UNEXPECTED("Null output row in iterator!"); + } + + // clear the old tensor row + out_row->clear(); + + return Status::OK(); +} + +// Constructor of the DatasetIterator +DatasetIterator::DatasetIterator(std::shared_ptr exe_tree) + : IteratorBase(), + root_(exe_tree->root()), + tracing_(nullptr), + cur_batch_num_(0), + cur_connector_size_(0), + cur_connector_capacity_(0) { + std::shared_ptr node; + Status s = exe_tree->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node); + if (s.IsOk()) { + tracing_ = std::dynamic_pointer_cast(node); + } +} + +DatasetIterator::~DatasetIterator() = default; + +// Fetches one row of data from the iterator. Overrides the base class. This one fetches +// from the tree root node directly. +Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) { + // Common code init and error checking in the base class. + RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row)); + + // Once eof is handled, always return empty row. Class must be destroyed and recreated if you + // want to iterate again. + if (eof_handled_) { + return Status::OK(); + } + + // Check if we need to get a new DataBuffer to iterate. + if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { + if (tracing_ != nullptr) { + cur_connector_size_ = root_->ConnectorSize(); + cur_connector_capacity_ = root_->ConnectorCapacity(); + } + RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); + + // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually + // handle eoe and eof messages here. + // + // An eoe buffer means we have iterated fully to the end of the tree. + // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of + // all operators. + if (curr_buffer_->eoe()) { + MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row."; + + // Before returning the last empty vector, fetch the eof buffer which should be the last + // buffer, and then free it. + RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); + + if (!curr_buffer_->eof()) { + RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!"); + } + eof_handled_ = true; + curr_buffer_.reset(); // explicitly free the eof buffer + // Set tree to Finished state + root_->Tree()->SetFinished(); + + return Status::OK(); + } + + if (curr_buffer_->eof()) { + // An eof by itself, without being preceded by an eoe, is possible if a repeat operator + // exists below us in the stack. Repeat operator eats eoe's but eventually allows the + // flow of an eof up the pipeline by itself. + eof_handled_ = true; + curr_buffer_.reset(); // explicitly free the eof buffer + // Set tree to Finished state + root_->Tree()->SetFinished(); + return Status::OK(); + } + } + + // If we got this far, now it's time to pop that next row for return to caller + RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row)); + if (tracing_ != nullptr) { + cur_batch_num_++; + tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_); + } + return Status::OK(); +} + +Status DatasetIterator::GetOutputShapes(std::vector *out_shapes) { + if (out_shapes == nullptr) { + RETURN_STATUS_UNEXPECTED("Null output shape argument"); + } + if (device_queue_row_.empty()) { + RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); + } + for (auto ts : device_queue_row_) { + out_shapes->push_back(ts->shape()); + } + + return Status::OK(); +} + +Status DatasetIterator::GetOutputTypes(std::vector *out_types) { + if (out_types == nullptr) { + RETURN_STATUS_UNEXPECTED("Null output type argument"); + } + if (device_queue_row_.empty()) { + RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); + } + for (auto ts : device_queue_row_) { + out_types->push_back(ts->type()); + } + return Status::OK(); +} + +// Getter +std::unordered_map DatasetIterator::GetColumnNameMap() const { + return root_->column_name_id_map(); +} + +// Constructor of the ChildIterator +ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx) + : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {} + +ChildIterator::~ChildIterator() { current_op_ = nullptr; } + +// Fetches one row of data from the iterator. Overrides the base class. This one fetches +// only from the child/worker id as given from the constructor. +Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) { + // Common code init and error checking in the base class. + RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row)); + + // Once eof is handled, always return empty row. Class must be destroyed and recreated if you + // want to iterate again. + if (eof_handled_) { + return Status::OK(); + } + + // Check if we need to get a new DataBuffer to iterate. + if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { + RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); + + // Unlike the DatasetIterator, this child iterator does not quit after eoe. + // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the + // caller to decide what it wants to do next. + if (curr_buffer_->eoe()) { + MS_LOG(DEBUG) << "Child iterator picked up EOE."; + end_epoch_ = true; + return Status::OK(); + } + + if (curr_buffer_->eof()) { + MS_LOG(DEBUG) << "Child iterator picked up EOF."; + eof_handled_ = true; + return Status::OK(); + } + } + + // If we got this far, now it's time to pop that next row for return to caller + RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row)); + + return Status::OK(); +} + +// drain till the next eoe +Status ChildIterator::Drain() { + if (end_epoch_ == true) { + // Calling drain against a child that is already at it's eoe state will not result in any action. + // This allows you to do: + // - fetch until empty row + // - drain (will not actually drain because you are already at the end of the iteration) + // However, the next time after that, it will perform it's normal draining activities. + end_epoch_ = false; + MS_LOG(DEBUG) << "No operation drain, already at end of epoch."; + return Status::OK(); + } + MS_LOG(DEBUG) << "Child draining buffers until eoe."; + // else we drain until eoe or eof, eof here is for sanity check + while (!curr_buffer_->eoe() && !curr_buffer_->eof()) { + RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); + } + if (curr_buffer_->eof()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain."); + } + return Status::OK(); +} + +// Getter +std::unordered_map ChildIterator::GetColumnNameMap() const { + return current_op_->child(child_idx_)->column_name_id_map(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h new file mode 100644 index 0000000000..253d1604e2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h @@ -0,0 +1,156 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASET_ITERATOR_H_ +#define DATASET_ENGINE_DATASET_ITERATOR_H_ + +#include +#include +#include +#include +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h" + +namespace mindspore { +namespace dataset { +using TensorMap = std::unordered_map>; + +// forward declare +class ExecutionTree; + +class DataBuffer; + +// IteratorBase class is used to iterate data from an executionTree one row at a time. +// The base class provides the general interface, whereas derived classes provide slightly +// different implementations. +class IteratorBase { + public: + // Constructor of IteratorBase + IteratorBase(); + + // Destructor + virtual ~IteratorBase(); + + // Fetches one row of data from the iterator. + // the base class version simply performs error handling and returns empty row. Actual + // functionality exists in the derived versions of this function. + // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data + // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. + // @return Status - The error code return + // @note The position of a Tensor/column might be different from the initial column order + // in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change + // the column ordering. + virtual Status FetchNextTensorRow(TensorRow *out_row); + + // Fetches one row of data from the iterator as a column map. + // @return A unordered map from column name to shared pointer to Tensor. + Status GetNextAsMap(TensorMap *out_map); + + // Getter + // @return T/F if this iterator is completely done after getting an eof + bool eof_handled() const { return eof_handled_; } + + // Getter + // @return The string to column id mapping. + virtual std::unordered_map GetColumnNameMap() const = 0; + + protected: + std::unique_ptr curr_buffer_; // holds the current buffer + bool eof_handled_; // T/F if this op got an eof + std::unordered_map col_name_id_map_; +}; + +// The DatasetIterator derived class is for fetching rows off the end/root of the execution tree. +class DatasetIterator : public IteratorBase { + public: + // Constructor of the DatasetIterator + // @param exe_tree The execution tree we want to pull/iterate the data from using it's root node. + explicit DatasetIterator(std::shared_ptr exe_tree); + + // Destructor + ~DatasetIterator(); + + // Fetches one row of data from the iterator. Overrides the base class. This one fetches + // from the tree root node directly. + // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data + // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. + // @return Status - The error code return + Status FetchNextTensorRow(TensorRow *out_row) override; + + // Fetches the next tensor row into device row, and returns it's shape. + // @param out_shapes - A vector of tensor shapes (one shape per column) + // @return Status - The error code return + Status GetOutputShapes(std::vector *out_shapes); + + // Fetches the next tensor row into device row, and returns it's shape. + // @param outShapes - A vector of tensor shapes (one shape per column) + // @return Status - The error code return + Status GetOutputTypes(std::vector *out_types); + + // Getter + // @return The string to column id mapping. + std::unordered_map GetColumnNameMap() const override; + + private: + std::shared_ptr root_; // saves the root of the executionTree + TensorRow device_queue_row_; + std::shared_ptr tracing_; // trace profiling data + int32_t cur_batch_num_; // current batch number,used for profiling + int32_t cur_connector_size_; // current connector size of root op,used for profiling + int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling +}; + +// The ChildIterator derived class is for fetching rows from intermediate nodes of execution tree. +// This one should only be used by internal Dataset operators, rather than an end-user. +class ChildIterator : public IteratorBase { + public: + // Constructor of the DatasetIterator + // @param current_op - The parent op from which we'll fetch from it's children. + // @param worker_id - The worker id to use when fetching from the children. + // @param child_idx - The index to the child to fetch from. + ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx); + + // Destructor + ~ChildIterator(); + + // Fetches one row of data from the iterator. Overrides the base class. This one fetches + // only from the child/worker id as given from the constructor. + // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data + // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. + // @return Status - The error code return + Status FetchNextTensorRow(TensorRow *out_row) override; + + // This function drains buffer until next eoe has been received. + // It will be a no-op if the previous row returned is empty. + // @return Status - The error code return + Status Drain(); + + // Getter + // @return The string to column id mapping. + std::unordered_map GetColumnNameMap() const override; + + private: + DatasetOp *current_op_; // The parent operator. We consume from it's children. + int32_t child_idx_; // The specific child this iterator will fetch from. + int32_t worker_id_; // The worker id uses for fetching the child data. + bool end_epoch_; // the flag used when an empty row has been returned. +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASET_ITERATOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt new file mode 100644 index 0000000000..a2cd6dc07a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt @@ -0,0 +1,38 @@ +add_subdirectory(source) + +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + +set(DATASET_ENGINE_DATASETOPS_SRC_FILES + dataset_op.cc + parallel_op.cc + pipeline_op.cc + batch_op.cc + device_queue_op.cc + map_op.cc + project_op.cc + rename_op.cc + repeat_op.cc + skip_op.cc + take_op.cc + shuffle_op.cc + zip_op.cc + concat_op.cc + cache_base_op.cc + cache_lookup_op.cc + cache_op.cc + cache_merge_op.cc + ) + +if (ENABLE_PYTHON) + set(DATASET_ENGINE_DATASETOPS_SRC_FILES + ${DATASET_ENGINE_DATASETOPS_SRC_FILES} + bucket_batch_by_length_op.cc + barrier_op.cc + filter_op.cc + build_vocab_op.cc + ) +endif() + +add_library(engine-datasetops OBJECT ${DATASET_ENGINE_DATASETOPS_SRC_FILES}) + diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc new file mode 100644 index 0000000000..51ea232e68 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc @@ -0,0 +1,242 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/barrier_op.h" +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +BarrierOp::Builder::Builder() { + // Some arguments to the BarrierOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the BarrierOp by + // using the various builder set methods. + + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } + +Status BarrierOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, + builder_condition_func_); + return Status::OK(); +} + +// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions +BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func) + : PipelineOp(op_connector_size), + rows_per_buffer_(rows_per_buffer), + buffer_id_(0), + clean_up_(false), + eof_(false), + condition_name_(condition_name), + condition_function_(condition_func) {} + +// destructor +BarrierOp::~BarrierOp() {} + +// Entry point for Barrier, called by launch() +Status BarrierOp::operator()() { + // The children_num_ parameter needs to be put here + // Synchronize with TaskManager once the thread is created. + TaskManager::FindMe()->Post(); + + // create child iterator, right now this barrier is a pipeline operator + const int32_t worker_id = 0; + const int32_t child_idx = 0; + child_iterator_ = std::make_unique(this, worker_id, child_idx); + + // Loop until eof is true + while (!eof_) { + // Create new table to put the new tensor rows + std::unique_ptr curr_table = std::make_unique(); + RETURN_IF_NOT_OK(prepare(curr_table.get())); + + // If an eof got picked up during the above prepare, then we're done + if (eof_) { + break; + } + + // we have to output new buffer with possibly different buffer size, possibly one row + while (!clean_up_) { + // 1. If a previous loop iteration sent the current table out, then create a new one. + + if (curr_table == nullptr) { + curr_table = std::make_unique(); + } + + // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished + RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); + + // 3 create and update buffer and send it to the out connector + if (!curr_table->empty()) { + std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); + curr_buffer->set_tensor_table(std::move(curr_table)); + MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " + << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << "."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + buffer_id_++; + } + } + + // 4 handle drain state. + if (clean_up_) { + MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; + // Send the eoe up. + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + } + } + // 5 handle eof + // propagate eof here. + MS_LOG(INFO) << "Barrier operator got EOF, propagating."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Handles preprocessing of the main loop, used when starting new epoch +Status BarrierOp::prepare(TensorQTable *const table) { + MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; + clean_up_ = false; + buffer_id_ = 0; + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); + } + // fill initial row + TensorRow new_row = {}; + // use iterator to get next row and invoke pyfunc wait + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + + // If the first row fetching resulted in eof, then we are done. + if (eof_) { + return Status::OK(); + } + if (new_row.empty()) { + // This epoch is empty + return Status::OK(); + } + // Pack this first row into our tensor table + // first row we also have to check if we should block + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + + // the update code below shouldn't do anything bad if the column name already exists. + return Status::OK(); +} + +// fillBuffer always expects a new table to fill +Status BarrierOp::fillBuffer(TensorQTable *const table) { + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); + } + TensorRow new_row = {}; + while (table->size() < static_cast(rows_per_buffer_)) { + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + // Early exit the loop if we got empty row from any of our child iterations + if (new_row.empty()) { + return Status::OK(); + } + // else we got a row so pack it into the tensor table. + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + } + return Status::OK(); +} + +// function executes a py_func and blocks until condition becomes true. +Status BarrierOp::blockCond() { + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + // we have condition name, however the flexibility is in python today + try { + // Invoke python function + py::object ret_py_obj = condition_function_(); + // Process the return value + if (!py::isinstance(ret_py_obj)) { + return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +// fetches next Barrier buffer row +Status BarrierOp::getNextTensorRow(TensorRow *new_row) { + // iterate over all iterators and generate a row + RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); + // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row + if (new_row->empty()) { + // If we did not get a row from any of the children, then it's the end of an epoch and we can move + // to drain state. + MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; + clean_up_ = true; + // If we picked up an eof here, then we are completely done. + if ((child_iterator_)->eof_handled()) { + MS_LOG(INFO) << "Barrier operator iterator got EOF."; + eof_ = true; + } + return Status::OK(); + } + return Status::OK(); +} + +// A function that prints info about the Operator +void BarrierOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nCondition: " << condition_name_ << "\n\n"; + } +} + +// overwrite function and handle eof +Status BarrierOp::EofReceived(int32_t) { + MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; + return Status::OK(); +} + +// overwrite function and handle eoe +Status BarrierOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h new file mode 100644 index 0000000000..a3ac843272 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h @@ -0,0 +1,169 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has +// been received. This signal is given from python layer. The current barrier design respects the +// rows per buffer design and will only output a buffer with rows once it has received rows per buffer +// signals from python. + +class BarrierOp : public PipelineOp { + public: + // The nested builder class inside of the BarrierOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @param const std::string & condition_name + // @return Builder setter method returns reference to the builder. + Builder &SetConditionName(const std::string &condition_name) { + builder_condition_name_ = condition_name; + return *this; + } + + // Setter method. + // @param py::function condition_func - blocking condition function + // @return Builder setter method returns reference to the builder. + Builder &SetConditionFunc(py::function condition_func) { + builder_condition_func_ = condition_func; + return *this; + } + + // The builder "build" method creates the BarrierOp dataset Operator. + // @return shared_ptr to the new BarrierOp object + Status Build(std::shared_ptr *); + + private: + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::string builder_condition_name_; + py::function builder_condition_func_; + + Status SanityCheck() const; + }; + + // Constructor for BarrierOp + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size + // @param condition_name - the condition name associated with this operator + // @param condition_func - the blocking condition check per row + // @note - currently rows_per_buffer should = 1 for barrier. + // The reason for this is having other values would complicate how the pipeline behaves with other operators + // One example of such case is having batch after barrier. Batch would be waiting for data and having + // rows per buffer in this case can result in hanging + BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func); + + // Destructor + ~BarrierOp(); + + Status EofReceived(int32_t) override; + + Status EoeReceived(int32_t) override; + + // Print function for Barrier + // @param out - output stream to print to + // @param show_all - if it should print everything + void Print(std::ostream &out, bool show_all) const override; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { + bo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Handles preprocessing of the main loop, used when starting new epoch + // @param table - a table of tensors to be moved into a buffer + Status prepare(TensorQTable *const table); + + // This function calls takes a table repeatedly adds rows to it. + // @param table - a table of tensors to be moved into a buffer + Status fillBuffer(TensorQTable *const table); + + // Gets next tensor row and sets control signals + Status getNextTensorRow(TensorRow *new_row); + + // This function runs the wait function on condition + Status blockCond(); + + private: + // clean up variable to return imcomplete buffer + bool clean_up_; + // end of file state, we stop reading data and shut down + bool eof_; + // rows per buffer + int32_t rows_per_buffer_; + // buffer_id + int32_t buffer_id_; + // iterator to pull new rows, we only have one child + std::unique_ptr child_iterator_; + // condition name, to support multiple barriers + std::string condition_name_; + // Function pointer of blocking function + py::function condition_function_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc new file mode 100644 index 0000000000..844d054307 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -0,0 +1,446 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/batch_op.h" + +#include +#include + +#include "common/utils.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/core/pybind_support.h" +#endif +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +using float16 = Eigen::half; + +namespace mindspore { +namespace dataset { +BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false), builder_pad_(false), builder_pad_map_({}) { + builder_batch_size_ = batch_size; + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status BatchOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); +#ifdef ENABLE_PYTHON + *ptr = std::make_shared(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_, + builder_num_workers_, builder_cols_to_map_, builder_batch_size_func_, + builder_batch_map_func_, builder_pad_map_); +#else + *ptr = std::make_shared(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_, + builder_num_workers_, builder_cols_to_map_, builder_pad_map_); +#endif + return Status::OK(); +} + +Status BatchOp::Builder::SanityCheck() { + std::string err; + err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; + err += builder_batch_size_ <= 0 ? "batch size <= 0\n" : ""; + err += builder_num_workers_ <= 0 ? "batch num_parallel_workers <= 0\n" : ""; + return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); +} + +#ifdef ENABLE_PYTHON +BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector &cols_to_map, py::function batch_size_func, py::function batch_map_func, + PadInfo pad_map) + : ParallelOp(num_workers, op_queue_size), + start_batch_size_(batch_size), + drop_(drop), + pad_(pad), + pyfunc_column_names_(cols_to_map), + batch_size_func_(batch_size_func), + batch_map_func_(batch_map_func), + pad_info_(pad_map) { + worker_queues_.Init(num_workers, op_queue_size); +} +#else +BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector &cols_to_map, PadInfo pad_map) + : ParallelOp(num_workers, op_queue_size), + start_batch_size_(batch_size), + drop_(drop), + pad_(pad), + pyfunc_column_names_(cols_to_map), + pad_info_(pad_map) { + worker_queues_.Init(num_workers, op_queue_size); +} +#endif + +Status BatchOp::operator()() { + Status rc = LaunchThreadsAndInitOp(); + // Synchronize with TaskManager + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(rc); + int64_t epoch_num = 0, batch_num = 0, cnt = 0; + TensorRow new_row; + std::unique_ptr table = std::make_unique(); + child_iterator_ = std::make_unique(this, 0, 0); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + int32_t cur_batch_size = 0; + RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(0, 0, 0))); + while (child_iterator_->eof_handled() == false) { + while (new_row.empty() == false) { + table->emplace_back(new_row); + // if # of rows is enough to make 1 batch (1 batch is buffer), send it to worker_queue + if (table->size() == static_cast(cur_batch_size)) { + RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack( + std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num)))); + table = std::make_unique(); + RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num))); + } + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } + // Reminder logic, execute only when there is a remainder (table is non empty) and don't drop + if (drop_ == false && table->empty() == false) { + RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack( + std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num)))); + } + table = std::make_unique(); // this drops when drop == true + // end of the current epoch, batch_num should start from 0 again + batch_num = 0; + epoch_num++; + RETURN_IF_NOT_OK( + worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE)))); + RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num))); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } // end of eof_handled() == false + RETURN_IF_NOT_OK( + worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF)))); + // EOF received, send quit signal (an empty buffer) to all workers + for (int32_t ind = 0; ind < num_workers_; ind++) { + RETURN_IF_NOT_OK( + worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit)))); + } + return Status::OK(); +} + +void BatchOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [batch size: " << start_batch_size_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nStart batch size: " << start_batch_size_ << "\nDrop remainder: " << (drop_ ? "yes" : "no") << "\n\n"; + } +} + +Status BatchOp::BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, + dsize_t batch_size) { + if ((*src)->size() != batch_size) { + RETURN_STATUS_UNEXPECTED("[Internal Batch ERROR] Source table size does not match the batch_size"); + } + + if (batch_size == 1) { + TensorRow row = std::move((*src)->front()); + (*src)->pop_front(); + (*dest)->push_back(row); + for (const auto &tensor : (*dest)->front()) { + RETURN_IF_NOT_OK(tensor->ExpandDim(0)); + } + return Status::OK(); + } + + TensorRow batched_row; + auto num_columns = (*src)->front().size(); + for (size_t i = 0; i < num_columns; i++) { + std::shared_ptr first_tensor = (*src)->at(0).at(i); // first row, column i + TensorShape first_shape = first_tensor->shape(); + DataType first_type = first_tensor->type(); + TensorShape new_shape = first_shape.PrependDim(static_cast(batch_size)); + + std::shared_ptr new_tensor; + if (first_type.IsNumeric()) { // numeric tensor + RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, TensorImpl::kFlexible, new_shape, first_type)); + dsize_t j = 0; + for (auto row : **src) { + std::shared_ptr old_tensor = row.at(i); // row j, column i + if (old_tensor->shape() == first_shape) { // check the newly popped rows have the same dim as the first + RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor)); + } else { + RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + std::to_string(i)); + } + } + } else { // handle string column differently + std::vector strings; + for (dsize_t j = 0; j < batch_size; j++) { + std::shared_ptr old_tensor = (*src)->at(j).at(i); + for (auto itr = old_tensor->begin(); itr != old_tensor->end(); itr++) { + strings.emplace_back(*itr); + } + } + RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, strings, new_shape)); + } + batched_row.emplace_back(new_tensor); + } + + (*dest)->emplace_back(batched_row); + + return Status::OK(); +} + +Status BatchOp::WorkerEntry(int32_t workerId) { + TaskManager::FindMe()->Post(); + std::pair, CBatchInfo> table_pair; + RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair)); + while (table_pair.second.ctrl_ != batchCtrl::kQuit) { + if (table_pair.second.ctrl_ == batchCtrl::kEOE) { + RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + } else if (table_pair.second.ctrl_ == batchCtrl::kEOF) { + RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else if (table_pair.second.ctrl_ == batchCtrl::kNoCtrl) { + std::unique_ptr db = nullptr; + RETURN_IF_NOT_OK(MakeBatchedBuffer(std::move(table_pair), &db)); + RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::move(db))); + } + RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair)); + } + return Status::OK(); +} + +Status BatchOp::MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, + std::unique_ptr *db) { + RETURN_UNEXPECTED_IF_NULL(table_pair.first); +#ifdef ENABLE_PYTHON + if (!pyfunc_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc +#endif + if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_)); // do padding if needed + (*db) = std::make_unique(table_pair.second.batch_num_, DataBuffer::kDeBFlagNone); + std::unique_ptr dest_table = std::make_unique(); + RETURN_IF_NOT_OK(BatchRows(&table_pair.first, &dest_table, table_pair.first->size())); + (*db)->set_tensor_table(std::move(dest_table)); + return Status::OK(); +} + +Status BatchOp::LaunchThreadsAndInitOp() { + RETURN_UNEXPECTED_IF_NULL(tree_); + RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1))); + return Status::OK(); +} + +Status BatchOp::EofReceived(int32_t) { return Status::OK(); } + +Status BatchOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +Status BatchOp::MapColumns(std::pair, CBatchInfo> *table_pair) { + TensorBatchTable input_table; + input_table.reserve(pyfunc_column_names_.size()); + for (std::string col_name : pyfunc_column_names_) { + if (column_name_id_map_.find(col_name) == column_name_id_map_.end()) { + RETURN_STATUS_UNEXPECTED("column : '" + col_name + "' does not exist\n"); + } + TensorBatch tensor_batch; + tensor_batch.reserve(table_pair->first->size()); + size_t col_idx = static_cast(column_name_id_map_[col_name]); + for (size_t row_idx = 0; row_idx < table_pair->first->size(); row_idx++) { + tensor_batch.push_back(std::move(table_pair->first->at(row_idx)[col_idx])); + } + input_table.push_back(std::move(tensor_batch)); + } + + // Perform batch map + TensorBatchTable output_table; + RETURN_IF_NOT_OK(InvokeBatchMapFunc(&input_table, &output_table, table_pair->second)); + + // Write back to TensorQTable + for (size_t input_idx = 0; input_idx < pyfunc_column_names_.size(); input_idx++) { + size_t col_idx = static_cast(column_name_id_map_[pyfunc_column_names_[input_idx]]); + size_t row_id = 0; + for (TensorRow &row : *(table_pair->first)) { + row[col_idx] = std::move(output_table[input_idx][row_id++]); + } + } + return Status::OK(); +} +#endif + +Status BatchOp::GetBatchSize(int32_t *batch_size, CBatchInfo info) { +#ifdef ENABLE_PYTHON + if (batch_size_func_ != nullptr) { + RETURN_IF_NOT_OK(InvokeBatchSizeFunc(batch_size, info)); + } else { + (*batch_size) = start_batch_size_; + } +#else + (*batch_size) = start_batch_size_; +#endif + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) { + { + // Acquire Python GIL + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py::object size = batch_size_func_(info); + *batch_size = size.cast(); + if (*batch_size <= 0) { + return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0"); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0"); + } + } + return Status(StatusCode::kOK, "Batch size func call succeed"); +} + +Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *output, CBatchInfo info) { + { + // Acquire Python GIL + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + // Prepare batch map call back parameters + py::tuple input_args(input->size() + 1); + for (size_t i = 0; i < input->size(); i++) { + std::vector np_batch; + for (std::shared_ptr t : input->at(i)) { + py::array np_array; + RETURN_IF_NOT_OK(t->GetDataAsNumpy(&np_array)); + np_batch.push_back(std::move(np_array)); + } + input_args[i] = np_batch; + } + input_args[input->size()] = info; + // Invoke batch map func + py::object ret_py_obj = batch_map_func_(*input_args); + // Parse batch map return value + py::tuple ret_tuple = py::cast(ret_py_obj); + if (ret_tuple.size() != pyfunc_column_names_.size() || !py::isinstance(ret_tuple)) { + return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple"); + } + for (size_t i = 0; i < ret_tuple.size(); i++) { + TensorBatch output_batch; + py::list output_list = py::cast(ret_tuple[i]); + for (size_t j = 0; j < output_list.size(); j++) { + std::shared_ptr out; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, py::cast(output_list[j]))); + output_batch.push_back(std::move(out)); + } + output->push_back(std::move(output_batch)); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Batch map function should return an tuple of list of numpy array"); + } + } + return Status(StatusCode::kOK); +} +#endif + +Status BatchOp::PadColumns(std::unique_ptr *table, const PadInfo &pad_info, + const std::unordered_map &column_name_id_map) { + RETURN_UNEXPECTED_IF_NULL(table); // placeholder for now, might need this in the future + CHECK_FAIL_RETURN_UNEXPECTED((*table)->front().size() == column_name_id_map.size(), "col_name_map mismatch"); + std::vector> pad_vals(column_name_id_map.size(), + 0); // value to pad each column's tensor with, default 0 + std::set pad_cols; + // padded_shape provided by user, maximum shapes of current batch of tensors + std::vector> pad_shapes(column_name_id_map.size()), max_shapes(column_name_id_map.size()); + RETURN_IF_NOT_OK(UnpackPadInfo(pad_info, column_name_id_map, &pad_cols, &pad_vals, &pad_shapes)); + + // init each shape in max_shape to {-1,-1...} init each unspecified shape in pad_shape to -1 as well + for (size_t col_id : pad_cols) { + max_shapes[col_id] = std::vector((*table)->front()[col_id]->Rank(), -1); + if (pad_shapes[col_id].empty()) pad_shapes[col_id] = max_shapes[col_id]; // fill pad shape with -1 + CHECK_FAIL_RETURN_UNEXPECTED(pad_shapes[col_id].size() == max_shapes[col_id].size(), "wrong rank in pad_shape"); + } + + // calculate maximum shape for each column that needs to be padded + for (const TensorRow &row : **table) { // iterator each row in a batch + for (size_t col_id : pad_cols) { // iterator each tensor in a row + CHECK_FAIL_RETURN_UNEXPECTED(row[col_id]->Rank() == max_shapes[col_id].size(), + "Tensor to be padded together need to have the same rank"); + for (size_t dim = 0; dim < row[col_id]->Rank(); dim++) { // pick the largest number in each dimension + max_shapes[col_id][dim] = std::max(max_shapes[col_id][dim], row[col_id]->shape()[dim]); + } + } + } + + // if user sets a dimension to -1 (None in python), use the max value for current dimension + for (size_t col_id : pad_cols) { + for (size_t dim = 0; dim < pad_shapes[col_id].size(); dim++) { + if (pad_shapes[col_id][dim] < 0) pad_shapes[col_id][dim] = max_shapes[col_id][dim]; + } + } + + // call pad on each tensor that needs to be padded + for (TensorRow &row : **table) { + for (size_t col_id : pad_cols) { + std::shared_ptr pad_tensor; + RETURN_IF_NOT_OK(PadEnd(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id])); + row[col_id] = pad_tensor; + } + } + return Status::OK(); +} + +Status BatchOp::UnpackPadInfo(const PadInfo &pad_info, + const std::unordered_map &column_name_id_map, + std::set *pad_cols, std::vector> *pad_vals, + std::vector> *pad_shapes) { + if (pad_info.empty()) { // if pad_info empty, pad every columns automatically + for (dsize_t col_id = 0; col_id < column_name_id_map.size(); col_id++) { + pad_cols->insert(col_id); + } + } else { + for (const auto &p : pad_info) { + auto location = column_name_id_map.find(p.first); + CHECK_FAIL_RETURN_UNEXPECTED(location != column_name_id_map.end(), "no column exists with name:" + p.first); + auto col_id = static_cast(location->second); + CHECK_FAIL_RETURN_UNEXPECTED(col_id < pad_vals->size() && col_id < pad_shapes->size(), "col_id out of bound"); + pad_cols->insert(col_id); + (*pad_vals)[col_id] = p.second.second; // set pad values + (*pad_shapes)[col_id] = p.second.first.AsVector(); // empty vector if shape is unknown + } + } + return Status::OK(); +} + +// Visitor accept method for NodePass +Status BatchOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h new file mode 100644 index 0000000000..0c042433f7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h @@ -0,0 +1,287 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DataBuffer; + +using TensorBatch = TensorRow; +using TensorBatchTable = std::vector; +using PadInfo = std::map>>; + +class BatchOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor for Batch, batch size needs to be specified + // @param int32_t batch_size + explicit Builder(int32_t batch_size); + + // Default destructor + ~Builder() = default; + + // set number of parallel Workers on batch + // @param int32_t num_workers + // @return Builder & reference to builder class object + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // set drop for batch op,default false + // @param bool drop + // @return Builder & reference to builder class object + Builder &SetDrop(bool drop) { + builder_drop_ = drop; + return *this; + } + + Builder &SetPaddingMap(const PadInfo &pad_map, bool pad = true) { + builder_pad_ = pad; + builder_pad_map_ = pad_map; + return *this; + } + + // set connector size for batch + // @param int32_t op_conn_size + // @return Builder & reference to builder class object + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = (op_connector_size == 0 ? builder_op_connector_size_ : op_connector_size); + return *this; + } + + // set columns to perform map on + // @param const std::vector & cols_to_map - name of columns to perform map on + // @return Builder & reference to builder class object + Builder &SetColumnsToMap(const std::vector &cols_to_map) { + builder_cols_to_map_ = cols_to_map; + return *this; + } + +#ifdef ENABLE_PYTHON + // set columns to perform map on + // @param const std::vector & cols_to_map - name of columns to perform map on + // @return Builder & reference to builder class object + Builder &SetBatchMapFunc(py::function batch_map_func) { + builder_batch_map_func_ = batch_map_func; + return *this; + } + + // SetBatchSizeFunc, a function that calls to python after every batch is made + // @param py::function batch_size_func - python function to call, GIL required before calling + // @return Builder & reference to builder class object + Builder &SetBatchSizeFunc(py::function batch_size_func) { + builder_batch_size_func_ = batch_size_func; + return *this; + } +#endif + + // @param std::shared_ptr *ptr pointer to shared_ptr, actual return arg + // @return Status - The error code return + Status Build(std::shared_ptr *); + + private: + // Sanity check for builder class args + // @return Status - The error code return + Status SanityCheck(); + + bool builder_drop_; + bool builder_pad_; + int32_t builder_batch_size_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + std::vector builder_cols_to_map_; + PadInfo builder_pad_map_; +#ifdef ENABLE_PYTHON + py::function builder_batch_size_func_; + py::function builder_batch_map_func_; +#endif + }; + + enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 }; + + // Parameters associate with one batch. + // This struct is used for both internal control and python callback. + // This struct is bound to python with read-only access. + struct CBatchInfo { + CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl) + : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {} + CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {} + CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {} + explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {} + int64_t epoch_num_; // i-th epoch. i starts from 0 + int64_t batch_num_; // i-th batch since the start of current epoch. i starts from 0 + int64_t total_batch_num_; // i-th batch since the start of first epoch. i starts from 0 + batchCtrl ctrl_; // No control=0, EOE=1, EOF=2, Quit=3 + const int64_t get_batch_num() const { return batch_num_; } + const int64_t get_epoch_num() const { return epoch_num_; } + }; + +#ifdef ENABLE_PYTHON + // BatchOp constructor + // @param int32_t batch_size + // @param bool drop + // @param int32_t op_queue_size + // @param int32_t rows_per_buf + // @param int32_t num_workers + BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector &, py::function batch_size_func, py::function batch_map_func, PadInfo pad_map); +#else + BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector &, PadInfo pad_map); +#endif + + // BatchOp destructor + ~BatchOp() {} + + // @param int32_t workerId + // @return Status - The error code return + Status EofReceived(int32_t) override; + + // @param int32_t workerId + // @return Status - The error code return + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param sO - reference to the BatchOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const BatchOp &bo) { + bo.Print(out, false); + return out; + } + + // Main loop of batch + // @return Status - The error code return + Status operator()() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "BatchOp"; } + + // batch the rows in src table then put it to dest table + // @param const std::unique_ptr *src - table that has the rows for batching + // @param const std::unique_ptr *dest - dest_table to hold batched rows + // @param int32_t size - batch_size + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @return Status - The error code return + static Status BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, + dsize_t batch_size); + + // @param table + // @param const PadInfo &pad_info pad info + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @return Status - The error code return + static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, + const std::unordered_map &column_name_id_map); + + private: + // Worker thread for doing the memcpy of batch + // @param int32_t param workerId + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Generate buffer with batched tensors + // @return Status - The error code return + Status MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, + std::unique_ptr *db); + +#ifdef ENABLE_PYTHON + // Function that calls pyfunc to perform map on batch + // @param (std::pair, batch_stats> *table_pair - contains un-batched tensor + // @return Status - The error code return + Status MapColumns(std::pair, CBatchInfo> *table_pair); +#endif + + // @param const PadInfo &pad_info pad info to unpack + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @param std::set *cols, col ids to perform pad on + // @param std::vector *vals, default padding value for each column + // @param std::vector> *shapes, padding shape specified by user + // @return Status - The error code return + static Status UnpackPadInfo(const PadInfo &pad_info, + const std::unordered_map &column_name_id_map, + std::set *pad_cols, std::vector> *pad_vals, + std::vector> *pad_shapes); + + // the number of thread pulling from the mOutConnector of the Op below + // @return int32_t, 1 + int32_t num_consumers() const override { return 1; } + + // get the batch size for next batch + // @return Status - The error code return + Status GetBatchSize(int32_t *batch_size, CBatchInfo info); + + // Do the initialization of all queues then start all worker threads + // @return Status - The error code return + Status LaunchThreadsAndInitOp(); + +#ifdef ENABLE_PYTHON + // Invoke batch size function with current BatchInfo to generate batch size. + // @return Status - The error code return + Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info); + + // Invoke batch map function with current BatchInfo to generate tensors to batch. + // @return Status - The error code return + Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info); +#endif + + int32_t start_batch_size_; + bool drop_; // bool for whether to drop remainder or not + bool pad_; // bool for whether to perform padding on tensor + std::vector pyfunc_column_names_; // Name of the columns to perform map op on + PadInfo pad_info_; // column names to perform padding on + std::unique_ptr child_iterator_; // child iterator for fetching TensorRows 1 by 1 + QueueList, CBatchInfo>> worker_queues_; // internal queue for syncing worker +#ifdef ENABLE_PYTHON + py::function batch_size_func_; // Function pointer of batch size function + py::function batch_map_func_; // Function pointer of per batch map function +#endif +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc new file mode 100644 index 0000000000..138bb7980b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -0,0 +1,240 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h" + +#include +#include +#include +#include +#include + +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "minddata/dataset/core/pybind_support.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/util/status.h" + +namespace py = pybind11; +namespace mindspore { +namespace dataset { +BucketBatchByLengthOp::Builder::Builder(std::vector length_dependent_columns, + std::vector bucket_boundaries, std::vector bucket_batch_sizes) + : builder_length_dependent_columns_(length_dependent_columns), + builder_bucket_boundaries_(bucket_boundaries), + builder_bucket_batch_sizes_(bucket_batch_sizes), + builder_pad_info_({}), + builder_pad_to_bucket_boundary_(false), + builder_drop_remainder_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_op_connector_size_ = config_manager->op_connector_size(); +} + +Status BucketBatchByLengthOp::Builder::SanityCheck() { + std::string error_message; + + if (builder_length_dependent_columns_.empty()) { + error_message += "At least 1 column must be specified for element length calculation.\n"; + } + + if (builder_bucket_boundaries_.empty()) { + error_message += "At least 1 bucket boundary must be specified.\n"; + } + + if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) { + error_message += "There must be exactly one bucket batch size specified for each bucket boundary.\n"; + } + + CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message); + + return Status::OK(); +} + +Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr *new_bucket_batch_by_length_op) { + RETURN_IF_NOT_OK(SanityCheck()); + + // insert 0 for the first bucket + builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0); + + *new_bucket_batch_by_length_op = std::make_shared( + builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_, + builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_, + builder_op_connector_size_); + + return Status::OK(); +} + +BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector length_dependent_columns, + std::vector bucket_boundaries, + std::vector bucket_batch_sizes, + py::function element_length_function, PadInfo pad_info, + bool pad_to_bucket_boundary, bool drop_remainder, + int32_t op_connector_size) + : PipelineOp(op_connector_size), + length_dependent_columns_(length_dependent_columns), + bucket_boundaries_(bucket_boundaries), + bucket_batch_sizes_(bucket_batch_sizes), + element_length_function_(element_length_function), + pad_info_(pad_info), + pad_to_bucket_boundary_(pad_to_bucket_boundary), + drop_remainder_(drop_remainder), + batch_count_(0) { + for (int i = 0; i < bucket_batch_sizes_.size(); i++) { + buckets_.push_back(std::make_unique()); + } +} + +Status BucketBatchByLengthOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } + +Status BucketBatchByLengthOp::operator()() { + TaskManager::FindMe()->Post(); + + TensorRow current_row; + child_iterator_ = std::make_unique(this, 0, 0); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); + while (!child_iterator_->eof_handled()) { + while (!current_row.empty()) { + int32_t element_length; + RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row)); + + int bucket_index = bucket_boundaries_.size() - 1; + while (element_length < bucket_boundaries_[bucket_index]) { + bucket_index--; + } + + buckets_[bucket_index]->push_back(current_row); + + if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) { + RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index])); + } + + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); + } + + // got EOE, do what we need to do with remainders in each bucket + if (!drop_remainder_) { + for (int i = 0; i < bucket_boundaries_.size(); i++) { + if (!buckets_[i]->empty()) { + RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size())); + } + } + } + + // need to send EOE manually since we set state to idle in EoeRecieved() + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(¤t_row)); + } + + return Status::OK(); +} + +Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) { + // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of + // the single column specified in length_dependent_columns_ + if (element_length_function_) { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + size_t number_of_arguments = length_dependent_columns_.size(); + py::tuple input_arguments(number_of_arguments); + for (size_t i = 0; i < number_of_arguments; i++) { + py::array argument_value; + int32_t column_index = column_name_id_map_[length_dependent_columns_[i]]; + RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value)); + input_arguments[i] = argument_value; + } + + py::object length = element_length_function_(*input_arguments); + *out_element_length = length.cast(); + if (*out_element_length < 0) { + return Status(StatusCode::kPyFuncException, "Element length function should return a non negative integer."); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Count not cast output of element length function to int32_t."); + } + } else { + *out_element_length = element[0]->shape()[0]; + } + + return Status::OK(); +} + +Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) { + std::unique_ptr *bucket = &buckets_[bucket_index]; + + PadInfo pad_info_copy = pad_info_; + if (pad_to_bucket_boundary_) { + for (auto &pair : pad_info_copy) { + std::vector pad_shape = pair.second.first.AsVector(); + + for (size_t i = 0; i < pad_shape.size(); i++) { + if (pad_shape[i] == TensorShape::kDimUnknown) { + if (bucket_index + 1 >= bucket_boundaries_.size()) { + std::string error_message = "Requested to pad to bucket boundary, element falls in last bucket"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message); + } + + pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1; + } + } + + pair.second.first = TensorShape(pad_shape); + } + } + + // PadColumns will change the data in bucket + RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_)); + + std::unique_ptr batched_bucket = std::make_unique(); + RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size)); + (*bucket)->clear(); + + std::unique_ptr batched_buffer = std::make_unique(batch_count_, DataBuffer::kDeBFlagNone); + batched_buffer->set_tensor_table(std::move(batched_bucket)); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer))); + + batch_count_++; + + return Status::OK(); +} + +Status BucketBatchByLengthOp::Reset() { + batch_count_ = 0; + + for (int i = 0; i < buckets_.size(); i++) { + buckets_[i] = std::make_unique(); + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h new file mode 100644 index 0000000000..332ff4bb22 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/batch_op.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DataBuffer; + +class BucketBatchByLengthOp : public PipelineOp { + public: + class Builder { + public: + Builder(std::vector length_dependent_columns, std::vector bucket_boundaries, + std::vector bucket_batch_sizes); + + ~Builder() = default; + + Builder &SetLengthDependentColumns(std::vector length_dependent_columns) { + builder_length_dependent_columns_ = length_dependent_columns; + return *this; + } + + Builder &SetBucketBoundaries(std::vector bucket_boundaries) { + builder_bucket_boundaries_ = bucket_boundaries; + return *this; + } + + Builder &SetBucketBatchSizes(std::vector bucket_batch_sizes) { + builder_bucket_batch_sizes_ = bucket_batch_sizes; + return *this; + } + + Builder &SetElementLengthFunction(py::function element_length_function) { + builder_element_length_function_ = element_length_function; + return *this; + } + + Builder &SetPadInfo(PadInfo pad_info) { + builder_pad_info_ = pad_info; + return *this; + } + + Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) { + builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary; + return *this; + } + + Builder &SetDropRemainder(bool drop_remainder) { + builder_drop_remainder_ = drop_remainder; + return *this; + } + + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + Status Build(std::shared_ptr *new_bucket_batch_by_length_op); + + private: + Status SanityCheck(); + + std::vector builder_length_dependent_columns_; + std::vector builder_bucket_boundaries_; + std::vector builder_bucket_batch_sizes_; + py::function builder_element_length_function_; + PadInfo builder_pad_info_; + bool builder_pad_to_bucket_boundary_; + bool builder_drop_remainder_; + int32_t builder_op_connector_size_; + }; + + BucketBatchByLengthOp(std::vector length_dependent_columns, std::vector bucket_boundaries, + std::vector bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, + bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); + + // Destructor + ~BucketBatchByLengthOp() = default; + + // Might need to batch remaining buckets after receiving eoe, so override this method. + // @param int32_t workerId + // @return Status - The error code returned + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param sO - reference to the BucketBatchByLengthOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) { + bo.Print(out, false); + return out; + } + + // Main loop of batch + // @return Status - The error code returned + Status operator()() override; + + // Function that is called by ResetOp at the end of every epoch + // @return Status - The error code returned + Status Reset() override; + + private: + Status ObtainElementLength(int32_t *out_element_length, TensorRow element); + + Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size); + + std::vector length_dependent_columns_; + std::vector bucket_boundaries_; + std::vector bucket_batch_sizes_; + py::function element_length_function_; + PadInfo pad_info_; + bool pad_to_bucket_boundary_; + bool drop_remainder_; + + int32_t batch_count_; + std::unique_ptr child_iterator_; + std::vector> buckets_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc new file mode 100644 index 0000000000..8ed51ebbb6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc @@ -0,0 +1,206 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/build_vocab_op.h" + +#include +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" + +namespace mindspore { +namespace dataset { + +BuildVocabOp::BuildVocabOp(std::shared_ptr vocab, std::vector col_names, + std::pair freq_r, int64_t top_k, const std::vector &tokens, + bool prepend, int32_t num_workers, int32_t op_conn_size) + : ParallelOp(num_workers, op_conn_size), + interval_(op_conn_size * num_workers), + vocab_(vocab), + col_names_(col_names), + freq_range_(freq_r), + top_k_(top_k), + special_tokens_(tokens), + special_first_(prepend) { + // init two queues for thread sync + distributor_queue_ = std::make_unique>(num_workers * op_conn_size); + collector_queue_ = + std::make_unique>>>(num_workers * op_conn_size); +} + +Status BuildVocabOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + TensorRow new_row; + RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row)); + std::unique_ptr> wrkr_map = + std::make_unique>(); + int32_t row_cnt = 0; + while (!new_row.empty()) { + for (int32_t col : col_ids_) { + CHECK_FAIL_RETURN_UNEXPECTED(!new_row[col]->type().IsNumeric(), "from_dataset only works on string columns"); + for (auto itr = new_row[col]->begin(); itr != new_row[col]->end(); itr++) { + (*wrkr_map)[std::string(*itr)] += 1; + } + } + row_cnt++; // row is processed by this point + if ((row_cnt % interval_ == 0) && ((row_cnt / interval_) % num_workers_ == worker_id) && (!wrkr_map->empty())) { + RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map))); + wrkr_map = std::make_unique>(); + } + RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row)); + } + // clean up + if (!wrkr_map->empty()) { + RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map))); + } + // empty map as quit signal + RETURN_IF_NOT_OK(collector_queue_->Add(std::make_unique>())); + return Status::OK(); +} + +Status BuildVocabOp::operator()() { + // launch the collector thread + RETURN_UNEXPECTED_IF_NULL(tree_); + RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks())); + // launch worker threads and collector thread + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&BuildVocabOp::WorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("collector", std::bind(&BuildVocabOp::CollectorThread, this))); + TaskManager::FindMe()->Post(); + child_iterator_ = std::make_unique(this, 0, 0); + TensorRow new_row; + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + if (!col_names_.empty()) { + col_ids_.reserve(col_names_.size()); + for (std::string col : col_names_) { + auto itr = column_name_id_map_.find(col); + CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col + " column doesn't exist"); + col_ids_.push_back(itr->second); + } + } else { + col_ids_.reserve(column_name_id_map_.size()); + for (const auto &p : column_name_id_map_) { + col_ids_.push_back(p.second); + } + } + bool eoe_warning = false; // give out warning if receive more than 1 eoe + while (child_iterator_->eof_handled() == false) { + while (new_row.empty() == false) { + RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(new_row)); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } + CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)"); + eoe_warning = true; + } + + // tell all workers to quit + for (int32_t wrkr_id = 0; wrkr_id < num_workers_; wrkr_id++) { + RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(TensorRow())); + } + return Status::OK(); +} + +Status BuildVocabOp::CollectorThread() { + TaskManager::FindMe()->Post(); + int32_t num_quited_worker = 0; + std::unique_ptr> wrkr_map; + while (num_quited_worker != num_workers_) { + RETURN_IF_NOT_OK(collector_queue_->PopFront(&wrkr_map)); + RETURN_UNEXPECTED_IF_NULL(wrkr_map); + if (!wrkr_map->empty()) { + for (const auto &wd : *wrkr_map) word_cnt_[wd.first] += wd.second; + } else { + ++num_quited_worker; + } + } // all frequencies are obtained + CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty"); + std::vector words; + // make sure enough is reserved, this will become a partially sorted list eventually + words.reserve(wrkr_map->size()); + + for (auto it = word_cnt_.begin(); it != word_cnt_.end();) { + if (it->second >= freq_range_.first && it->second <= freq_range_.second) { + words.push_back(it->first); + it++; + } else { + it = word_cnt_.erase(it); + } + } + std::string err_msg; + + for (const std::string &sp_tk : special_tokens_) { + // if a special word exists in dataset, warn user about this + err_msg += (word_cnt_.find(sp_tk) != word_cnt_.end() ? sp_tk + "\t" : ""); + } + + CHECK_FAIL_RETURN_UNEXPECTED(err_msg.empty(), "These specials words are already in the dataset: " + err_msg + "."); + + int64_t num_words = std::min(static_cast(words.size()), top_k_); + if (num_words == 0) { + MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second + << ") vocab would be empty (except for special tokens)."; + } + + // this would take the top-k most frequent words + std::partial_sort(words.begin(), words.begin() + num_words, words.end(), + [this](const std::string &w1, const std::string &w2) { + int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2]; + return f1 == f2 ? w1 < w2 : f1 > f2; + }); + + if (special_first_) { + for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk); + } + + for (int64_t i = 0; i < num_words; i++) { + vocab_->append_word(words[i]); + } + + if (!special_first_) { + for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk); + } + + RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + // then use std::nth_element to partial sort + return Status::OK(); +} + +Status BuildVocabOp::Builder::Build(std::shared_ptr *op) { + CHECK_FAIL_RETURN_UNEXPECTED(builder_num_workers_ > 0, "builder num_workers need to be greater than 0"); + CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be positive number"); + CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0, + "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)"); + (*op) = std::make_shared( + builder_vocab_, builder_col_names_, std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_, + builder_speical_tokens_, builder_special_first_, builder_num_workers_, builder_connector_size_); + return Status::OK(); +} + +BuildVocabOp::Builder::Builder() + : builder_top_k_(std::numeric_limits::max()), + builder_min_freq_(0), + builder_max_freq_(std::numeric_limits::max()), + builder_special_first_(true) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_connector_size_ = cfg->op_connector_size(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h new file mode 100644 index 0000000000..42ea0deb5c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h @@ -0,0 +1,174 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/text/vocab.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class BuildVocabOp : public ParallelOp { + public: + class Builder { + public: + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method + // @param int32_t size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t size) { + builder_connector_size_ = size; + return *this; + } + + // Setter method + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param int64_t top_k + // @return Builder setter method returns reference to the builder. + Builder &SetTopK(int64_t top_k) { + builder_top_k_ = top_k; + return *this; + } + + // Setter method + // @param int64_t min_freq + // @return Builder setter method returns reference to the builder. + Builder &SetMinFreq(int64_t min_freq) { + builder_min_freq_ = min_freq; + return *this; + } + + // Setter method + // @param int64_t max_freq + // @return Builder setter method returns reference to the builder. + Builder &SetMaxFreq(int64_t max_freq) { + builder_max_freq_ = max_freq; + return *this; + } + + // set columns names + // @param const std::vector & col_names - name of columns to get words + // @return Builder & reference to builder class object + Builder &SetColumnNames(const std::vector &col_names) { + builder_col_names_ = col_names; + return *this; + } + + // set special tokens + // @param const std::vector & col_names - name of columns to get words + // @return Builder & reference to builder class object + Builder &SetSpecialTokens(const std::vector &tokens) { + builder_speical_tokens_ = tokens; + return *this; + } + + // set vocab object + Builder &SetVocab(std::shared_ptr vocab) { + builder_vocab_ = vocab; + return *this; + } + + // set special tokens first (or last) + Builder &SetSpecialFirst(bool prepend) { + builder_special_first_ = prepend; + return *this; + } + + // The builder "build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + int32_t builder_num_workers_; + int32_t builder_connector_size_; + int64_t builder_min_freq_; + int64_t builder_max_freq_; + bool builder_special_first_; + std::vector builder_col_names_; + std::vector builder_speical_tokens_; + std::shared_ptr builder_vocab_; + int64_t builder_top_k_; + }; + + BuildVocabOp(std::shared_ptr vocab, std::vector col_names, std::pair freq_range, + int64_t top_k, const std::vector &tokens, bool prepend, int32_t num_workers, + int32_t op_connector_size); + + ~BuildVocabOp() = default; + + Status WorkerEntry(int32_t worker_id) override; + + // collect the work product from each worker + Status CollectorThread(); + + Status EofReceived(int32_t) override { return Status::OK(); } + + Status EoeReceived(int32_t) override { return Status::OK(); } + + Status operator()() override; + + // Getter + // @return the number of workers + int32_t num_producers() const override { return 1; } + + // Getter + // @return the number of threads consuming from the previous Connector + int32_t num_consumers() const override { return 1; } + + Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); } + + private: + const int32_t interval_; + bool special_first_; + std::shared_ptr vocab_; + std::vector col_names_; + std::vector col_ids_; + std::vector special_tokens_; + // pair = {min_f, max_f} + // make sure that 0<= min_f < max_f <= int32_max in the builder + std::pair freq_range_; + + int64_t top_k_; // every thing means top_k_ == int32_max + std::unique_ptr child_iterator_; // child iterator for fetching TensorRows 1 by 1 + std::unique_ptr> distributor_queue_; // master thread assigns each worker TensorRow via this + std::unique_ptr>>> collector_queue_; + std::unordered_map word_cnt_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc new file mode 100644 index 0000000000..1b0890686f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/cache_base_op.h" +#include +#include +#include "minddata/dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +// A print method typically used for debugging +void CacheBase::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nCache client:\n" << *cache_client_ << "\n\n"; + } +} +// Overrides base class reset method. When an operator does a reset, it cleans up any state +// info from it's previous execution and then initializes itself so that it can be executed +// again. +Status CacheBase::Reset() { + if (sampler_ != nullptr) { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + } + // Wake up the workers to get them going again in a new epoch + MS_LOG(DEBUG) << Name() << " resetting."; + epoch_sync_.Set(); + return Status::OK(); +} +CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, sampler), + cache_client_(cache_client), + rows_per_buffer_(rows_per_buf), + // We can cause deadlock if this internal Connector size is too small. + keys_miss_(num_workers_, 1, connector_capacity_) { + io_block_queues_.Init(num_workers, op_connector_size); +} +// Common function to fetch samples from the sampler and send them using the io_block_queues to +// the parallel workers +Status CacheBase::FetchSamplesToWorkers() { + int64_t buf_cnt = 0; + int64_t wait_cnt = 0; + do { + epoch_sync_.Clear(); + std::vector keys; + int64_t row_cnt = 0; + keys.reserve(rows_per_buffer_); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (!sampler_buffer->eoe()) { + TensorRow sample_row; + RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { + keys.push_back(*itr); + ++row_cnt; + if (row_cnt % rows_per_buffer_ == 0) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (!keys.empty()) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); + } + // send the eoe + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + // If repeat but the not last repeat, wait for reset. + if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; + RETURN_IF_NOT_OK(epoch_sync_.Wait()); + } else { + // We can break out from the loop. + break; + } + } while (true); + // Flow the eof before exit + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + // Ask all the workers to quit. + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); +} +Status CacheBase::FetchFromCache(int32_t worker_id) { + int64_t buffer_id = worker_id; + std::unique_ptr blk; + do { + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); + if (blk->eof()) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else if (blk->eoe()) { + if (AllowCacheMiss()) { + // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from + // a sampler, send a eoe to physical leaf op as well. + std::vector eoe; + eoe.push_back(eoe_row_id); + RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe)); + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(blk->GetKeys(&keys)); + if (keys.empty()) { + // empty key is a quit signal for workers + break; + } + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + std::unique_ptr que = std::make_unique(); + TensorTable ttbl; + RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl)); + auto row_it = ttbl.begin(); + std::vector cache_miss; + cache_miss.reserve(keys.size()); + for (auto row_id : keys) { + auto &row = *row_it; + if (row.empty()) { + if (AllowCacheMiss()) { + cache_miss.push_back(row_id); + } else { + std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; + RETURN_STATUS_UNEXPECTED(errMsg); + } + } + que->push_back(std::move(row)); + ++row_it; + } + db->set_tensor_table(std::move(que)); + if (AllowCacheMiss()) { + // Because of the way connector works, we push unconditionally even cache_miss can be empty. + RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss)); + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + } while (true); + return Status::OK(); +} +Status CacheBase::RegisterResources() { + RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + return Status::OK(); +} +CacheBase::~CacheBase() {} +Status CacheBase::UpdateColumnMapFromCache() { + Status rc; + // Get the schema from the server. It may not be there yet. So tolerate the error. + if (column_name_id_map_.empty()) { + rc = cache_client_->FetchSchema(&column_name_id_map_); + if (rc == Status(StatusCode::kFileNotExist)) { + MS_LOG(DEBUG) << "Schema not in the server yet."; + rc = Status::OK(); + } + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h new file mode 100644 index 0000000000..fb3e999b76 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/engine/datasetops/cache_base_op.h" +namespace mindspore { +namespace dataset { +/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. +/// \see CacheOp +/// \see CacheLookupOp +class CacheBase : public ParallelOp { + public: + /// \brief Base class constructor + /// \param num_workers Number of parallel workers + /// \param op_connector_size Connector size + /// \param rows_per_buf Number of rows per buffer + /// \param cache_client CacheClient for communication to the CacheServer + /// \param sampler Sampler which is mandatory + CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler); + /// \brief Destructor + ~CacheBase(); + + /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state + /// info from it's previous execution and then initializes itself so that it can be executed + /// again. + /// \return Status - The error code return + Status Reset() override; + + /// \brief A print method typically used for debugging + /// \param out The output stream to write output to + /// \param show_all A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + /// \brief << Stream output operator overload + /// \notes This allows you to write the debug print info using stream operators + /// \param out reference to the output stream being overloaded + /// \param mo reference to the CacheOp to display + /// \return the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) { + mo.Print(out, false); + return out; + } + + /// \brief Getter for the cache client + /// \return shared ptr to the cache client + std::shared_ptr cache_client() { return cache_client_; } + /// \brief Setter for the cache client + void SetCacheClient(std::shared_ptr cache_client) { cache_client_ = std::move(cache_client); } + /// \brief Derived class must implement this method if a cache miss is treated as error + virtual bool AllowCacheMiss() = 0; + + protected: + constexpr static int32_t eoe_row_id = -1; + std::shared_ptr cache_client_; + WaitPost epoch_sync_; + int32_t rows_per_buffer_; + Connector> keys_miss_; + + /// \brief Common function to register resources for interrupt + /// \note Derived should override this function for extra resources to be registered + virtual Status RegisterResources(); + /// \brief This function is called by main thread to send samples to the worker thread. + /// \note It is a non-virtual function + /// \return Status object + Status FetchSamplesToWorkers(); + /// \brief This function is called by each worker to fetch rows from the cache server for a given set of + /// sample row id's + /// \return Status object + Status FetchFromCache(int32_t worker_id); + /// \brief Get the column map from cache server + Status UpdateColumnMapFromCache(); + + private: + constexpr static int32_t connector_capacity_ = 1024; + QueueList> io_block_queues_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc new file mode 100644 index 0000000000..0a9b7544ba --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "utils/log_adapter.h" +#include "utils/system/crc32c.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_workers_ = cfg->num_parallel_workers(); + rows_per_buffer_ = cfg->rows_per_buffer(); + build_op_connector_size_ = cfg->op_connector_size(); +} + +// Check if the required parameters are set by the builder. +Status CacheLookupOp::Builder::SanityCheck() const { + if (build_cache_client_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient"); + } + // Make sure the cache client has a valid session + if (!build_cache_client_->session_id()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Cache client for CacheLookupOp is missing session id"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object and does some init on it +Status CacheLookupOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, rows_per_buffer_, + build_cache_client_, build_sampler_); + return Status::OK(); +} +Status CacheLookupOp::operator()() { + if (!sampler_) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "CacheLookupOp requires a sampler before it can be executed!"); + } + RETURN_IF_NOT_OK(RegisterResources()); + // Kick off the workers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1))); + // required task group sync after launching workers + TaskManager::FindMe()->Post(); + // We have to wait until the leaf op has handshake with us. + RETURN_IF_NOT_OK(leaf_op_wp_.Wait()); + RETURN_IF_NOT_OK(FetchSamplesToWorkers()); + return Status::OK(); +} +Status CacheLookupOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(FetchFromCache(worker_id)); + return Status::OK(); +} +Status CacheLookupOp::ResetSampler() { return Status::OK(); } +Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { + // We act like a sampler and as a dataset op. During handshake with leaf op, + // We must wait until the leaf op has indexed everything. + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op)); + // Now we notify the main thread handshake has finished. + leaf_op_wp_.Set(); + return Status::OK(); +} +Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } +void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } +Status CacheLookupOp::GetNextSample(std::unique_ptr *out_buffer) { + std::vector cache_miss; + RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); + // Ignore the case we have no cache miss, we can't return empty samples. + while (cache_miss.empty()) { + RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); + } + // Special code for eoe + if (cache_miss.at(0) == eoe_row_id) { + *out_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + std::shared_ptr sample_ts; + RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size())); + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); + auto idPtr = sample_ts->begin(); + for (auto i = 0; i < cache_miss.size(); ++i) { + *idPtr = cache_miss.at(i); + ++idPtr; + } + TensorRow row; + row.push_back(sample_ts); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + } + return Status::OK(); +} +Status CacheLookupOp::RegisterResources() { + RETURN_IF_NOT_OK(CacheBase::RegisterResources()); + RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks())); + return Status::OK(); +} +Status CacheLookupOp::ComputeColMap() { + // We don't know the column map at this point unless we contact the cache server + // to fetch the schema but the cache server may not have it at this point either. + // So we will just return OK and let MergeOp (our parent) to handle it. + return Status::OK(); +} + +// Visitor accept method for NodePass +Status CacheLookupOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h new file mode 100644 index 0000000000..46a58c5d02 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/cache_base_op.h" + +namespace mindspore { +namespace dataset { +/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset. +/// \note For non-mappable dataset, please see CacheOp +/// \see CacheOp +class CacheLookupOp : public CacheBase, public Sampler { + public: + class Builder { + public: + /// \brief Builder constructor. Creates the builder object. + /// \note No default args + Builder(); + + /// Default destructor + ~Builder() = default; + + /// Setter method. + /// \treturn Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetClient(std::shared_ptr cache_client) { + build_cache_client_ = cache_client; + return *this; + } + + /// \brief Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + build_sampler_ = std::move(sampler); + return *this; + } + + /// \brief The builder "build" method creates the final object and does some init on it. + /// \param ptr The shared_ptr to the new CacheLookupOp object + /// \return Status + Status Build(std::shared_ptr *ptr); + + private: + int32_t build_num_workers_; + int32_t rows_per_buffer_; + int32_t build_op_connector_size_; + std::shared_ptr build_cache_client_; + std::shared_ptr build_sampler_; + + // Check if the required parameters are set by the builder. + // \return Status The error code return + Status SanityCheck() const; + }; + /// \brief Constructor + /// \note It takes the same argument as the base class. + /// \see CacheBase + CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler) + : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {} + ~CacheLookupOp() = default; + // As a parallel op, we override these two functions + Status operator()() override; + Status WorkerEntry(int32_t worker_id) override; + // As a sampler, we override the following functions + Status ResetSampler() override; + Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; + Status InitSampler() override; + Status GetNextSample(std::unique_ptr *out_buffer) override; + void Print(std::ostream &out, bool show_all) const override; + bool AllowCacheMiss() override { return true; } + std::string Name() const override { return "CacheLookupOp"; } + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + protected: + Status ComputeColMap() override; + + private: + WaitPost leaf_op_wp_; + + Status RegisterResources() override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc new file mode 100644 index 0000000000..75579dc3a6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -0,0 +1,302 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" + +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +CacheMergeOp::~CacheMergeOp() = default; +void CacheMergeOp::Print(std::ostream &out, bool show_all) + const { // Always show the id and name as first line regardless if this is summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\n\n"; + } +} +CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, + std::shared_ptr cache_client, const std::shared_ptr &sampler) + : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {} +Status CacheMergeOp::operator()() { + // A queue of row id to let cleaner send cache miss rows to the cache server + // We don't want a small queue as this will block the parallel op workers. + // A row id is 8 byte integer. So bigger size doesn't consume a lot of memory. + static const int32_t queue_sz = 512; + io_que_ = std::make_unique>(queue_sz); + RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1))); + // One dedicated thread to move TensorRow from the pool to the cache server + for (auto i = 0; i < num_cleaners_; ++i) { + RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this))); + } + TaskManager::FindMe()->Post(); + return Status::OK(); +} +// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait +// until it shows up in the pool. +Status CacheMergeOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::shared_ptr cache_hit_stream = child_[kCacheHitChildIdx]; + std::unique_ptr db_ptr; + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + while (!db_ptr->eof()) { + if (db_ptr->eoe()) { + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + db_ptr.reset(); + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + } else { + // See if there is any missing row + auto tbl = std::make_unique(); + while (db_ptr->NumRows() > 0) { + TensorRow row; + RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); + if (row.empty()) { + auto row_id = row.getId(); + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + // Block until the row shows up in the pool. + RETURN_IF_NOT_OK(rq->Wait(&row)); + } + tbl->push_back(std::move(row)); + } + db_ptr->set_tensor_table(std::move(tbl)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + } + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); + return Status::OK(); +} +Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { + TaskManager::FindMe()->Post(); + // We will simply pop TensorRow from the stream and insert them into the pool and + // wake up any worker that is awaiting on the missing TensorRow. + // If we see an eoe, ignore it. For eof, we exit. + std::shared_ptr cache_missing_stream = child_[kCacheMissChildIdx]; + // Before we start, cache the schema at the server. Pick one of the workers + // do it. The schema should have been done at prepare time. + if (workerId == 0) { + RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map())); + } + std::unique_ptr db_ptr; + RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); + while (!db_ptr->eof()) { + if (db_ptr->eoe()) { + // Ignore it. + MS_LOG(DEBUG) << "Ignore eoe"; + } else { + while (db_ptr->NumRows() > 0) { + TensorRow row; + RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); + row_id_type row_id = row.getId(); + if (row_id < 0) { + std::string errMsg = "Expect positive row id: " + std::to_string(row_id); + RETURN_STATUS_UNEXPECTED(errMsg); + } + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + rq->WakeUpAny(std::move(row)); + // Let the cleaner to flush out this row (async) to the cache server. + RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); + } + } + RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); + } + return Status::OK(); +} +Status CacheMergeOp::Cleaner() { + TaskManager::FindMe()->Post(); + while (true) { + row_id_type row_id; + RETURN_IF_NOT_OK(io_que_->PopFront(&row_id)); + if (row_id < 0) { + break; + } + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + if (rq->GetState() == TensorRowRequest::State::kClean) { + // If already flushed, move on to the next one. + continue; + } + TensorRow row; + RETURN_IF_NOT_OK(rq->Release(&row)); + CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error."); + Status rc = cache_client_->WriteRow(row); + // Bad rc should not bring down the pipeline + if (rc.IsError()) { + MS_LOG(WARNING) << "Cache not successful." << rc.ToString(); + } + rq->SetState(TensorRowRequest::State::kClean); + } + return Status::OK(); +} + +Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) { + RETURN_UNEXPECTED_IF_NULL(out); + std::unique_lock lck(mux_); + auto it = cache_miss_map_.find(row_id); + if (it != cache_miss_map_.end()) { + *out = it->second.GetMutablePointer(); + } else { + // We will create a new one. + auto alloc = Services::GetAllocator(); + auto r = cache_miss_map_.emplace(row_id, MemGuard>(alloc)); + if (r.second) { + auto &mem = r.first->second; + RETURN_IF_NOT_OK(mem.allocate(1, row_id)); + *out = mem.GetMutablePointer(); + } else { + RETURN_STATUS_UNEXPECTED("Map insert fail."); + } + } + return Status::OK(); +} +Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own + // specific logic + CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); + RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); + // Get the computed check sum from all ops in the cache miss class + uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]); + // This is a mappable cache op so the id's need to be generated. + // Construct the cache + const bool generate_ids = false; + Status rc = cache_client_->CreateCache(cache_crc, generate_ids); + if (rc.get_code() == StatusCode::kDuplicateKey) { + // We are told the cache has been created already. + MS_LOG(INFO) << "Cache created already"; + rc = Status::OK(); + } + RETURN_IF_NOT_OK(rc); + return Status::OK(); +} +Status CacheMergeOp::ComputeColMap() { + CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty"); + if (column_name_id_map().empty()) { + column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map(); + } + CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected"); + return Status::OK(); +} +Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) { + RETURN_UNEXPECTED_IF_NULL(out); + // Block until the missing row is in the pool. + RETURN_IF_NOT_OK(use_count_.P()); + std::unique_lock lck(dq_mux_); + CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); + *out = std::move(row_.front()); + row_.pop_front(); + return Status::OK(); +} +void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { + std::unique_lock lck(dq_mux_); + // Technically number of this row shows up in the cache miss stream is equal to the number + // of P() call. However the cleaner wants it too. So we need an extra copy. + if (GetState() == State::kEmpty) { + // We will do a deep copy + for (auto &ts : row) { + auto out_ts = std::make_shared(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes()); + cleaner_copy_.push_back(out_ts); + } + cleaner_copy_.setId(row.getId()); + // Change the state to dirty + SetState(State::kDirty); + } + row_.push_back(std::move(row)); + // Bump up the use count by 1. This wake up any parallel worker which is waiting + // for this row. + use_count_.V(); +} +Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) { + RETURN_UNEXPECTED_IF_NULL(out); + // We are not holding any mutex here because the cleaner isn't really touching the deque row_. + // In case we have multiple cleaners and they all see the copy, only one of them will + // get it. + auto expected = State::kDirty; + if (st_.compare_exchange_strong(expected, State::kClean)) { + *out = std::move(cleaner_copy_); + } + return Status::OK(); +} +// Builder constructor. Creates the builder object. +CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_workers_ = cfg->num_parallel_workers(); + build_op_connector_size_ = cfg->op_connector_size(); + build_num_cleaners_ = 1; +} + +// Check if the required parameters are set by the builder. +Status CacheMergeOp::Builder::SanityCheck() const { + if (build_cache_client_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient"); + } + // Make sure the cache client has a valid session + if (!build_cache_client_->session_id()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Cache client for CacheMergeOp is missing session id"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object and does some init on it +Status CacheMergeOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, build_num_cleaners_, + build_cache_client_, build_sampler_); + return Status::OK(); +} + +// Pre-Visitor accept method for NodePass +Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} + +// Visitor accept method for NodePass +Status CacheMergeOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status CacheMergeOp::EoeReceived(int32_t worker_id) { + // If we are in a repeat path, send the eoe up. + // Otherwise ignore it. + if (BitTest(op_ctrl_flags_, kDeOpRepeated)) { + return DatasetOp::EoeReceived(worker_id); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h new file mode 100644 index 0000000000..df37465fc4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h @@ -0,0 +1,196 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/semaphore.h" + +namespace mindspore { +namespace dataset { +/// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single +/// stream +class CacheMergeOp : public ParallelOp { + public: + // Some handshake structures among the main thread, cleaner threads and parallel op threads. + class TensorRowRequest { + public: + enum class State : uint8_t { + kEmpty = 0, // No row in the deque + kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. + kClean = 2 // The row has been flushed already. + }; + explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {} + ~TensorRowRequest() = default; + State GetState() const { return st_; } + void SetState(State newState) { st_ = newState; } + Status Wait(TensorRow *out); + void WakeUpAny(TensorRow &&row); + Status Release(TensorRow *out); + + private: + std::mutex dq_mux_; + std::atomic st_; + Semaphore use_count_; + std::deque row_; + TensorRow cleaner_copy_; + }; + + constexpr static int kCacheHitChildIdx = 0; // Cache hit stream + constexpr static int kCacheMissChildIdx = 1; // Cache miss stream + + /// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of + /// the arguments for constructing it. Use the builder by setting each argument + /// with the provided set methods, and then finally call the build method to execute + /// the actual construction. + class Builder { + public: + /// Builder constructor. Creates the builder object. + /// \note No default args + Builder(); + + /// Default destructor + ~Builder() = default; + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetClient(std::shared_ptr cache_client) { + build_cache_client_ = cache_client; + return *this; + } + + /// \brief Setter method + /// \param sampler + /// \return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + build_sampler_ = std::move(sampler); + return *this; + } + + /// \brief Setter method + /// \param num_cleaners + /// \return Builder setter method returns reference to the builder. + Builder &SetNumCleaner(int32_t num_cleaners) { + build_num_cleaners_ = num_cleaners; + return *this; + } + + /// The builder "build" method creates the final object and does some init on it. + /// \param ptr The shared_ptr to the new CacheMergeOp object + /// \return Status + Status Build(std::shared_ptr *ptr); + + private: + int32_t build_num_workers_; + int32_t build_op_connector_size_; + int32_t build_num_cleaners_; + std::shared_ptr build_cache_client_; + std::shared_ptr build_sampler_; + + /// Check if the required parameters are set by the builder. + /// \return Status The error code return + Status SanityCheck() const; + }; + + /// \brief Constructor + /// \param numWorkers Number of parallel workers as a derived class of ParallelOp + /// \param opConnector Size Connector size as a derived class of ParallelOp + /// \param numCleaners Number of cleaners to move cache miss rows into the cache server + /// \param cache_client CacheClient to commmunicate with the Cache server + /// \param sampler as a derived class of ParallelOp + CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, + std::shared_ptr cache_client, const std::shared_ptr &sampler); + ~CacheMergeOp(); + void Print(std::ostream &out, bool show_all) const override; + friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) { + mo.Print(out, false); + return out; + } + /// \brief Master thread responsible to spawn all the necessary worker threads for the two streams and + /// the threads for the cleaners. + /// \return + Status operator()() override; + /// \brief Entry function for worker thread that fetch rows from CacheLookupOp + /// \param workerId + /// \return Status object + Status WorkerEntry(int32_t workerId) override; + Status PrepareNodePostAction() override; + /// \brief Entry function for worker thread that fetch rows from the cache miss stream + /// \param workerId + /// \return Status object + Status CacheMissWorkerEntry(int32_t workerId); + Status GetRq(row_id_type row_id, TensorRowRequest **); + + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for eoe handling + /// \param worker_id + /// \return Status object + Status EoeReceived(int32_t worker_id) override; + + protected: + Status ComputeColMap() override; + + private: + std::mutex mux_; + std::map>> cache_miss_map_; + std::unique_ptr> io_que_; + std::shared_ptr cache_client_; + int32_t num_cleaners_; + + /// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for + /// moving cache miss TensorRow into the CacheServer. + /// \return Status object + Status Cleaner(); +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc new file mode 100644 index 0000000000..143c45b2dc --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -0,0 +1,219 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/cache_op.h" + +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/util/task_manager.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_workers_ = cfg->num_parallel_workers(); + rows_per_buffer_ = cfg->rows_per_buffer(); + build_op_connector_size_ = cfg->op_connector_size(); +} + +// Check if the required parameters are set by the builder. +Status CacheOp::Builder::SanityCheck() const { + if (build_cache_client_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient"); + } + // Make sure the cache client has a valid session + if (!build_cache_client_->session_id()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object and does some init on it +Status CacheOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_, + build_sampler_); + RETURN_IF_NOT_OK((*ptr)->InitCache()); + + return Status::OK(); +} + +// Constructor of CacheOp +CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler) + : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), + num_guys_in_(0), + phase_(Phase::kBuildPhase) {} + +// Destructor +CacheOp::~CacheOp() = default; + +// Private function for cache setup/init work just after construction +Status CacheOp::InitCache() { return Status::OK(); } + +// This class functor will provide the master loop that drives the logic for performing the work +Status CacheOp::operator()() { + if (!sampler_) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "CacheOp requires a sampler before it can be executed!"); + } + RETURN_IF_NOT_OK(RegisterResources()); + // Kick off the workers + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1))); + // required task group sync after launching workers + TaskManager::FindMe()->Post(); + // Wait for the workers to finish caching the rows. + RETURN_IF_NOT_OK(WaitForCachingAllRows()); + RETURN_IF_NOT_OK(FetchSamplesToWorkers()); + return Status::OK(); +} +Status CacheOp::CacheAllRows(int32_t worker_id) { + // If the current phase is to fill the cache, do it then. + if (phase_ == Phase::kBuildPhase) { + // We will take the chance to cache the schema at the server. + // Just do it once and pick one worker to do it. + if (worker_id == 0) { + RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map())); + } + MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id; + // SAVE mode loop + std::unique_ptr db_ptr; + RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); + while (!db_ptr->eof()) { + if (!db_ptr->eoe()) { + RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr))); + } else { + // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up + // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the + // the eoe to indicate the end of the epoch, we should next expect to get the eof. + // Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch + // from again. + RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); + if (!db_ptr->eof()) { + RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child."); + } + } + RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); + } + } + // Let the main guy know we are done. + auto last_guy_in = num_guys_in_.fetch_add(1); + if ((last_guy_in + 1) == num_workers_) { + rows_cache_done_.Set(); + } else { + // Let's do a sync up here. + RETURN_IF_NOT_OK(rows_cache_done_.Wait()); + } + return Status::OK(); +} +Status CacheOp::WaitForCachingAllRows() { + // Wait for the workers to finish caching the rows. + RETURN_IF_NOT_OK(rows_cache_done_.Wait()); + // Move from build phase to fetch phase if we are the one to fill the cache + if (phase_ == Phase::kBuildPhase) { + RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone()); + // Move to the next phase + phase_ = Phase::kFetchPhase; + } + // Get statistics from the server, and if we are not the one to create the cache, + // wait until the state changed from build phase to fetch base. + CacheClient::ServiceStat stat{}; + bool BuildPhaseDone = true; + do { + RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); + BuildPhaseDone = stat.cache_service_state == static_cast(CacheService::State::kFetchPhase); + if (!BuildPhaseDone) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } while (!BuildPhaseDone); + const row_id_type min_key = stat.min_row_id; + const row_id_type max_key = stat.max_row_id; + num_rows_ = max_key - min_key + 1; + MS_LOG(INFO) << "Number of rows cached: " << num_rows_; + MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached; + MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached; + // Now all rows are cached and we have done a sync point check up. Next phase is + // is pick up fetch input from sampler and pass up to the caller. + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} +Status CacheOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(CacheAllRows(worker_id)); + RETURN_IF_NOT_OK(FetchFromCache(worker_id)); + return Status::OK(); +} +Status CacheOp::RegisterResources() { + RETURN_IF_NOT_OK(CacheBase::RegisterResources()); + RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks())); + return Status::OK(); +} + +// Base-class override for setting specific CacheOp configurations. This code will be called +// during the execution tree prepare phase BEFORE traversing down to child operators. +uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; } +// Base-class override for special eoe handler. +// CacheOp must override this because it shall not perform default handling of eoe. Instead +// the CacheOp manages actions related to the end of the epoch. +Status CacheOp::EoeReceived(int32_t worker_id) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} +// Base-class override for handling cases when an eof is received. +Status CacheOp::EofReceived(int32_t worker_id) { + // eofReceived is overloaded because we want to manually handle this eof. + // Specifically, the default behaviour is to pack it and flow it up to the next connection. + // In this case, we want a no-op behaviour so that we can perform correct action. + return Status::OK(); +} + +// Pre-Visitor accept method for NodePass +Status CacheOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} + +// Visitor accept method for NodePass +Status CacheOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +// A public wrapper for creating the cache through the client +Status CacheOp::CreateCache(uint32_t cache_crc) { + // This is a non-mappable cache op so the id's need to be generated. + // Construct the cache + const bool generate_ids = true; + Status rc = cache_client_->CreateCache(cache_crc, generate_ids); + if (rc.get_code() == StatusCode::kDuplicateKey) { + // We are told the cache has been created already. So we skip the build phase. + phase_ = Phase::kFetchPhase; + rc = Status::OK(); + } + RETURN_IF_NOT_OK(rc); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h new file mode 100644 index 0000000000..dd34d54973 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h @@ -0,0 +1,168 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/cache_base_op.h" + +namespace mindspore { +namespace dataset { +/// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset. +/// \note For mappable dataset, please see CacheLookupOp. +/// \see CacheLookupOp +class CacheOp : public CacheBase, public RandomAccessOp { + public: + // This CacheOp is for non-mappable case where it is divided into two phases. + // The first phase is we cache all the rows from the child (and let the cache server + // assigns row id). No read access in the first phase. Once the cache is fully built, + // we switch to second phase and fetch requests from the sampler. + enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 }; + + /// \brief The nested builder class inside of the CacheOp is used to help manage all of + /// the arguments for constructing it. Use the builder by setting each argument + /// with the provided set methods, and then finally call the build method to execute + /// the actual construction. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + /// \brief Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + /// \brief Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetClient(std::shared_ptr cache_client) { + build_cache_client_ = cache_client; + return *this; + } + + /// \brief Setter method + /// \param rows_per_buffer + /// \return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + rows_per_buffer_ = rows_per_buffer; + return *this; + } + + /// \brief Setter method + /// \param sampler + /// \return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + build_sampler_ = std::move(sampler); + return *this; + } + + /// \brief The builder "build" method creates the final object and does some init on it. + /// \param ptr The shared_ptr to the new CacheOp object + /// \return Status + Status Build(std::shared_ptr *ptr); + + private: + int32_t build_num_workers_; + int32_t rows_per_buffer_; + int32_t build_op_connector_size_; + std::shared_ptr build_cache_client_; + std::shared_ptr build_sampler_; + + /// \brief Check if the required parameters are set by the builder. + /// \return Status The error code return + Status SanityCheck() const; + }; + + /// \brief Constructor of CacheOp + /// \note The builder class should be used to call it. + /// \param num_workers The number of worker threads. + /// \param op_connector_size The size of each queue in the connector. + CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler); + + // Destructor + ~CacheOp(); + + /// \brief Base-class override for setting specific CacheOp configurations. This code will be called + /// during the execution tree prepare phase BEFORE traversing down to child operators. + uint32_t PrepareFlags() const override; + /// \brief Base-class override for special eoe handler. + /// CacheOp must override this because it shall not perform default handling of eoe. Instead + /// the CacheOp manages actions related to the end of the epoch. + /// \return Status - The error code return + Status EoeReceived(int32_t worker_id) override; + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + /// \brief Base-class override for handling cases when an eof is received. + /// \param worker_id - The worker id + /// \return Status - The error code return + Status EofReceived(int32_t worker_id) override; + Status operator()() override; + Status WorkerEntry(int32_t worker_id) override; + /// \brief Base-class override for handling cases if we allow cache miss + bool AllowCacheMiss() override { return false; } + /// \brief Base-class override for the name of this operator + std::string Name() const override { return "CacheOp"; } + /// \brief A public wrapper for creating the cache through the client + /// \param[in] cache_crc The crc that identifies the cache + /// \see cache_pass.cc + /// \return Status return code + Status CreateCache(uint32_t cache_crc); + + private: + WaitPost rows_cache_done_; + std::atomic num_guys_in_; + Phase phase_; + /// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler. + /// \return Status object + Status WaitForCachingAllRows(); + /// \brief For non-mappable dataset, there is a build phase where we cache all the rows. + /// \return Status object + Status CacheAllRows(int32_t worker_id); + Status RegisterResources() override; + /// \brief Private function for cache setup/init work just after construction + /// \return Status The error code return + Status InitCache(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc new file mode 100644 index 0000000000..7acb68350b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/datasetops/concat_op.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +ConcatOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +// The builder "build" method creates the final object. +Status ConcatOp::Builder::Build(std::shared_ptr *ptr) { + *ptr = std::make_shared(builder_op_connector_size_); + return Status::OK(); +} + +// Constructor of the ConcatOp. +ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {} + +// A function that prints info about the Operator +void ConcatOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this is summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nDatasets: " << children_num_ << "\n\n"; + } +} + +// Main entry point for Concat +Status ConcatOp::operator()() { + // The children_num_ parameter needs to be put here + children_num_ = static_cast(child_.size()); + TaskManager::FindMe()->Post(); + std::unique_ptr buf; + int eof_count = 0; + while (eof_count == 0) { + for (int i = 0; i < children_num_; i++) { + // 1. Read the first buffer + RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); + if (buf->eof()) { + eof_count++; + continue; + } + // 2. Do verification as for column name, column data type and rank of column data + if (!buf->eoe()) { + RETURN_IF_NOT_OK(Verify(i, buf)); + } + // 3. Put the data into output_connector + while (!buf->eoe() && !buf->eof()) { + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); + RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf)); + } + } + // 4. Add eoe buffer after get buffer from all child + if (eof_count == 0) { + auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + } + } + CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_, + "Something went wrong, eof count does not match the number of children."); + // 5. Add eof buffer in the end manually + MS_LOG(DEBUG) << "Add the eof buffer manualy in the end."; + auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + return Status::OK(); +} + +Status ConcatOp::Verify(int32_t id, const std::unique_ptr &buf) { + TensorRow new_row; + buf->GetRow(0, &new_row); + + if (id == 0) { + // Obtain the data type and data rank in child[0] + for (auto item : new_row) { + data_type_.push_back(item->type()); + data_rank_.push_back(item->Rank()); + } + } else { + // Compare the data type and data rank with these in child[0] + int32_t index = 0; + for (auto item : new_row) { + if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) { + RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same with previous dataset."); + } + } + } + return Status::OK(); +} + +// We need to overwrite the super class ComputeColMap here because the number of children is more than 1. +Status ConcatOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + // Obtain columns_name_id_map from child_[0] + column_name_id_map_ = child_[0]->column_name_id_map(); + if (column_name_id_map_.empty()) { + RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!"); + } + // Verify all children have the same column name map + for (int32_t i = 0; i < child_.size(); ++i) { + if (child_[i]->column_name_id_map() != column_name_id_map_) { + RETURN_STATUS_UNEXPECTED("The column name or column order is not the same with previous dataset."); + } + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h new file mode 100644 index 0000000000..3d3d9df71c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/pipeline_op.h" + +namespace mindspore { +namespace dataset { +class ConcatOp : public PipelineOp { + public: + // The nested builder class inside of the ConcatOp is used to help manage all of the arguments + // for constructing it. This Concat op is very simple though, so this builder is really just + // provided for a consistent look and feel for creators of Dataset operators overall. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // The builder "build" method creates the final object. + // @return shared_ptr to the new ConcatOp object + Status Build(std::shared_ptr *); + + private: + int32_t builder_op_connector_size_; + }; + + // Constructor of the ConcatOp. + // @note The builder class should be used to call it + // @param op_connector_size - connector size + explicit ConcatOp(int32_t op_connector_size); + + // Destructor + ~ConcatOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param ro - reference to the ConcatOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) { + ro.Print(out, false); + return out; + } + + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "ConcatOp"; } + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + private: + Status Verify(int32_t id, const std::unique_ptr &buf); + + int32_t children_num_; // The num of child of parent node. + std::unordered_map column_name_id_; // Mapping between col index and col name + std::vector data_type_; + std::vector data_rank_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc new file mode 100644 index 0000000000..9254141308 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc @@ -0,0 +1,391 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/dataset_op.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "utils/system/crc32c.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +// Constructor +DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler) + : oc_queue_size_(op_connector_size), + sampler_(sampler), + operator_id_(kInvalidOperatorId), + tree_(nullptr), + state_(OpState::kDeOpIdle), + op_ctrl_flags_(kDeOpNone), + out_connector_(nullptr) { + // The operator starts out with an invalid operator id. The only way to + // get it out of invalid state is to assign the operator to an execution tree. +} + +// Adds a operator to become our child. +Status DatasetOp::AddChild(std::shared_ptr child) { + if (std::dynamic_pointer_cast(child) != nullptr) { + std::string err_msg("DeviceQueueOp cannot be added as a child, DeviceQueueOp must be a root node"); + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (operator_id_ == kInvalidOperatorId) { + std::string err_msg( + "Cannot add child node. Tree node connections can only" + "be made if the node belongs to a tree."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // disallow relationships with other trees + if (tree_ != child->tree_) { + std::string err_msg( + "Cannot add child node. Tree node connections can only be made if both nodes belong to the same tree."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + child_.push_back(child); + child->AddParent(this); + return Status::OK(); +} + +Status DatasetOp::RemoveChild(std::shared_ptr child) { + if (operator_id_ == kInvalidOperatorId) { + std::string err_msg( + "Cannot remove child node. Tree node connections can only" + "be made if the node belongs to a tree."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // disallow relationships with other trees + if (tree_ != child->tree_) { + std::string err_msg( + "Cannot remove child node. Tree node connections can only be made if both nodes belong to the same tree."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end()); + child->RemoveParent(this); + return Status::OK(); +} + +Status DatasetOp::InsertAsParent(std::shared_ptr to_add) { + for (auto &prev_parent : this->parent_) { + RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this())); + RETURN_IF_NOT_OK(prev_parent->AddChild(to_add)); + } + RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this())); + if (tree_->root()->id() == this->id()) { + tree_->AssignRoot(to_add); + } + return Status::OK(); +} + +// Adds a parent operator to this operator +void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); } + +// Removes a parent operator from this operator +void DatasetOp::RemoveParent(const DatasetOp *parent) { + parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end()); +} + +// Removes this node from the tree and connects it's parent/child together +Status DatasetOp::Remove() { + if (parent_.size() > 1) { + std::string err_msg("No support for op removal if the operator has more than one parent"); + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (child_.size() > 1) { + std::string err_msg("No support for op removal if the operator has more than one child"); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Scenario's when removing node B: + // A -> B -> C + // A -> B + // B -> C + // + // If we remove B, then first take our child A and update it's parent to be C + // It's possible the parent is null if we are the root node being removed. + if (!child_.empty()) { + // If we have a parent, then assign chlid's parent to point to our parent. + if (!parent_.empty()) { + child_[0]->parent_[0] = parent_[0]; + } else { + // We don't have a parent, so we are the root node being removed. + // clear the parent list of our child so that it becomes the new root. + child_[0]->parent_.clear(); + tree_->AssignRoot(child_[0]); + } + } + + // Next, if we had a parent, then set it's child to be our child. + if (!parent_.empty()) { + // if we have a child, then set our parent to point to it + if (!child_.empty()) { + parent_[0]->child_[0] = child_[0]; + } else { + // We don't have a child, so clear the child list of the current + // parent because it will be empty once we are removed. + parent_[0]->child_.clear(); + } + } + + // Finally, clear "this" op's parent and child pointers since we have just + // disconnected it from the tree and invalidate it's fields. + child_.clear(); + parent_.clear(); + operator_id_ = kInvalidOperatorId; + tree_ = nullptr; + + return Status::OK(); +} + +// Getter function to get a shared pointer to our child +std::shared_ptr DatasetOp::child(int32_t child_index) const { + std::shared_ptr return_op = nullptr; + if (child_.empty()) { + return return_op; + } + MS_ASSERT(child_index < static_cast(child_.size())); + // Return a shared pointer + return child_[child_index]; +} + +// Getter function to get the parent pointer +void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const { + if (parent_.empty()) { + // common case if this is a root node + *parent = nullptr; + } else { + MS_ASSERT(parent_index < static_cast(parent_.size())); + *parent = parent_[parent_index]; + } +} + +// Creates the connector within this operator +void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) { + MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers + << ". Consumer: " << num_consumers << "."; + if (oc_queue_size_ > 0) { + out_connector_ = std::make_unique(num_producers, // The number of producers + num_consumers, // Only one consumer (the training App) + oc_queue_size_); + } else { + // Some op's may choose not to have an output connector + MS_LOG(DEBUG) << "Bypassed connector creation for tree operator: " << operator_id_ << "."; + out_connector_ = nullptr; + } +} + +// A print method typically used for debugging. showAll of true will recursively descend to child prints +void DatasetOp::Print(std::ostream &out, bool show_all) const { + // When show_all is false, we display a 1 liner piece of text for the op. + // When show_all is true, we display more detailed output for the op. + // Derived printers should show their own header info, then call base class printer, followed by + // derived-specific items. + // For now, the base class doesn't have any summary info to show so it's a no-op in that case. + if (show_all) { + // The detailed display will show common base class info of the op. Allow the derived class to print + // it's own id and name though as the first line. + out << "\nNumber of children : " << child_.size(); + for (size_t i = 0; i < child_.size(); i++) { + out << "\n Child[" << i << "] id: " << child_[i]->id(); + } + out << "\nNumber of parents : " << parent_.size(); + for (size_t i = 0; i < parent_.size(); i++) { + out << "\n Parent[" << i << "] id: " << parent_[i]->id(); + } + out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex + << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' '); + if (sampler_) { + sampler_->Print(out, show_all); + } + } +} + +// Gets the next buffer from the given child +Status DatasetOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { +#if defined(_WIN32) || defined(_WIN64) + RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast(worker_id), p_buffer, retry_if_eoe)); +#else + std::unique_ptr next_buff; + // pop is a blocked call and will throw an interruption if the whole group shuts down. + RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast(worker_id), &next_buff, retry_if_eoe)); + + *p_buffer = std::move(next_buff); +#endif + return Status::OK(); +} + +// Gets the next buffer from the given child . This function also has built-in eoe and eof +// message handling so that child classes don't have to manually code pass-through logic when +// those messages are received. +Status DatasetOp::GetNextInput(std::unique_ptr *p_buffer, int32_t worker_id, int32_t child_index) { + if (child_.size() == 0) { + return this->GetNextBuffer(p_buffer, worker_id); + } + CHECK_FAIL_RETURN_UNEXPECTED(child_index < child_.size(), "Child index too big : " + std::to_string(child_index)); + std::shared_ptr child = child_[child_index]; + std::unique_ptr buf; + RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id)); + // Loop until non EOE is received + while (buf->eoe()) { + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + if (state_ == OpState::kDeOpIdle) { + *p_buffer = std::move(buf); + return Status::OK(); + } + RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id)); + } + // Check if the last buf is next eof + if (buf->eof()) { + RETURN_IF_NOT_OK(EofReceived(worker_id)); + } + *p_buffer = std::move(buf); + return Status::OK(); +} + +// Performs handling for when an eoe message is received. +// The base class implementation simply flows the eoe message to output. Derived classes +// may override if they need to perform special eoe handling. +Status DatasetOp::EoeReceived(int32_t worker_id) { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + return (out_connector_->Add(static_cast(worker_id), std::move(eoe_buffer))); +} + +// Performs handling for when an eof message is received. +// The base class implementation simply flows the eof message to output. Derived classes +// may override if they need to perform special eof handling. +Status DatasetOp::EofReceived(int32_t worker_id) { + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + return (out_connector_->Add(static_cast(worker_id), std::move(eof_buffer))); +} + +// During tree prepare phase, operators may have specific pre-operations to perform depending on +// their role. +Status DatasetOp::PrepareNodePreAction() { return Status::OK(); } + +// During tree prepare phase, operators may have specific post-operations to perform depending on +// their role. +Status DatasetOp::PrepareNodePostAction() { + // Creating Connector object for each op. + // The consumer of the root node is assumed to be one thread. + // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion. + if (parent_.empty()) { + this->CreateConnector(num_producers(), 1); + } else { + this->CreateConnector(num_producers(), parent_[0]->num_consumers()); + } + if (out_connector_) { + RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks())); + } + RETURN_IF_NOT_OK(this->RegisterWorkerConnectors()); + + // Generate the column name map for the current op. + RETURN_IF_NOT_OK(this->ComputeColMap()); + + return Status::OK(); +} + +// Getter function. Base class does not have any special flags setting. +uint32_t DatasetOp::PrepareFlags() const { return ExecutionTree::kDePrepNone; } + +// Derived classes may implement the reset function if the operator is stateful and needs +// specific reset handling that is not contained in this common code version of the reset. +Status DatasetOp::Reset() { + state_ = OpState::kDeOpRunning; + return Status::OK(); +} + +// gives a string output for the column map for handy debug printing +std::string DatasetOp::ColumnNameMapAsString() const { + std::string outStr = "Column name id map: "; + for (auto &it : column_name_id_map_) { + outStr += (" " + it.first + ":" + std::to_string(it.second)); + } + return outStr; +} + +// Computing the assignment of the column name map. +// This just inherits the column map from its first child, can only be used if the number of children is 1. +// Operations changing the column map must overwrite this function. +Status DatasetOp::ComputeColMap() { + if (child_.size() > 1) { + RETURN_STATUS_UNEXPECTED("Assigning column name map from child only works for single-child operators."); + } + if (column_name_id_map_.empty()) { + column_name_id_map_ = child_[0]->column_name_id_map(); + if (column_name_id_map_.empty()) { + RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!"); + } + MS_LOG(DEBUG) << "Setting column map:\n" << DatasetOp::ColumnNameMapAsString(); + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +Status DatasetOp::PreAccept(NodePass *p, bool *modified) { + // DatasetOp is the base class of visitor target pre-visit. + // This method will only be called if its derived class does not implement one. + return p->PreRunOnNode(shared_from_this(), modified); +} + +Status DatasetOp::Accept(NodePass *p, bool *modified) { + // DatasetOp is the base class of visitor target. + // This method will only be called if its derived class does not implement one. + return p->RunOnNode(shared_from_this(), modified); +} + +// Getter for the sampler, and it also removes the sampler from the op +Status DatasetOp::FetchRemoveSampler(std::shared_ptr *sampler) { + *sampler = sampler_; // It's okay if it sampler_ points to nullptr + sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler + return Status::OK(); +} + +uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) { + std::stringstream ss; + op->tree_->Print(ss, op); + std::string ss_str = ss.str(); + + // Filter out the Operator control flags field when generating the check sum + ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), ""); + + // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline + ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), ""); + ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), ""); + + // The Cache crc and Server cache id field is different when creating new cache_client and re-using the same + // cache_client later. So we filter out these two fields to allow cache sharing. + ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), ""); + ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), ""); + + uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length()); + return cache_crc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h new file mode 100644 index 0000000000..b4630c1652 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h @@ -0,0 +1,363 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ +#define DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +// Forward declare +class ExecutionTree; + +class DataBuffer; + +class NodePass; + +class Sampler; + +/// \brief The base class DatasetOp is the main tree node. It is an abstract class, so +/// the actual implementation of the operators will be derived from here. +class DatasetOp : public std::enable_shared_from_this { + // Allow execution tree to access internal members + friend class ExecutionTree; + + public: + static constexpr int32_t kInvalidOperatorId = -1; + + // Operator control flags + enum OpControlFlags { + kDeOpNone = 0, + kDeOpRepeated = 1, // Operator is a node in a repeat path + kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop + }; + + // Flags that control operator runtime behaviours + enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; + + /// Constructor + /// \param op_connector_size - The size for the output connector of this operator. + /// \param sampler - The sampler for the op + explicit DatasetOp(int32_t op_connector_size, std::shared_ptr sampler); + + /// Destructor + virtual ~DatasetOp() { tree_ = nullptr; } + + /// Adds a operator to become our child. + /// \param child - shared pointer to the child to add. + Status AddChild(std::shared_ptr child); + + /// Remove a operator from our children. + /// \param child - shared pointer to the child to remove. + Status RemoveChild(std::shared_ptr child); + + /// \brief Removes this node from the tree and connects it's parent/child together + /// \return Status eerror code returned + Status Remove(); + + /// \brief Getter function to get a shared pointer to our child + /// \param[in] child_index An operator can have n children. Indicates which child to return. + /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index + std::shared_ptr child(int32_t child_index) const; + + /// \brief Getter function to get the pointer to our parent + /// If there are no parents, it returns null regardless of the given index + /// \param[in] parent_index An operator can have n parents. Indicates which parent to return. + void Parent(DatasetOp **parent, int32_t parent_index) const; + + // Inserts a operator as the parent current op. + // Inserted op will become the sole parent of the current op. + // The existing parent of the current op will be transferred to the inserted op. + Status InsertAsParent(std::shared_ptr to_add); + + /// \brief Creates the connector within this operator + /// \param num_producers - number of threads that write into this connector + /// \param num_consumers - number of threads that read from this connector + void CreateConnector(int32_t num_producers, int32_t num_consumers); + + /// \brief A print method typically used for debugging + /// \param out - The output stream to write output to + /// \param show_all - A bool to control if you want to show all info or just a summary + virtual void Print(std::ostream &out, bool show_all) const; + + /// \brief << Stream output operator overload + /// \notes This allows you to write the debug print info using stream operators + /// \param out - reference to the output stream being overloaded + /// \param dO - reference to the DatasetOp to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) { + dO.Print(out, false); + return out; + } + + /// \brief Class functor operator (). + /// DatasetOps operate by launching a thread (see ExecutionTree). + /// This pure virtual version makes the requirement that derived classes must provide a functor + /// that will execute their main runtime loop code. + /// \return Status - The error code return + virtual Status operator()() = 0; + + /// \brief Gets the next buffer from the given child + /// \notes See GetNextInput for similar function that has built-in message handling + /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) + /// \param worker_id - The worker id + /// \return Status - The error code return + virtual Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id) { + return GetNextBuffer(p_buffer, worker_id, false); + } + + /// \brief Gets the next buffer from the given child + /// \notes See GetNextInput for similar function that has built-in message handling + /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) + /// \return Status - The error code return + virtual Status GetNextBuffer(std::unique_ptr *p_buffer) { return GetNextBuffer(p_buffer, 0, false); } + + /// \brief Gets the next buffer from the given child + /// \notes See GetNextInput for similar function that has built-in message handling + /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) + /// \param worker_id - The worker id + /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. + /// \return Status - The error code return + virtual Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe); + + /// \brief Gets the next buffer from the given child . This function also has built-in eoe and eof + /// message handling so that child classes don't have to manually code pass-through logic when + /// those messages are received. + /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) + /// \param worker_id - The worker id + /// \return Status - The error code return + Status GetNextInput(std::unique_ptr *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); + + /// \brief Performs handling for when an eoe message is received. + /// The base class implementation simply flows the eoe message to output. Derived classes + /// may override if they need to perform special eoe handling. + /// \param worker_id - The worker id + /// \return Status - The error code return + virtual Status EoeReceived(int32_t worker_id); + + /// \brief Performs handling for when an eof message is received. + /// The base class implementation simply flows the eof message to output. Derived classes + /// may override if they need to perform special eof handling. + /// \param worker_id - The worker id + /// \return Status - The error code return + virtual Status EofReceived(int32_t worker_id); + + /// \brief Derived classes may implement the reset function if the operator is stateful and needs + /// specific reset handling that is not contained in this common code version of the reset + /// \return Status - The error code return + virtual Status Reset(); + + /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on + /// their role. + /// \notes Derived versions of this function should always call it's superclass version first + /// before providing their own implementations. + virtual Status PrepareNodePreAction(); + + /// \brief During tree prepare phase, operators may have specific post-operations to perform depending on + /// their role. + /// \notes Derived versions of this function should always call it's superclass version first + /// before providing their own implementations. + virtual Status PrepareNodePostAction(); + + /// \brief Getter function + /// \return The operator id + int32_t id() const { return operator_id_; } + + /// \brief Getter function + /// \return The prepare flags + virtual uint32_t PrepareFlags() const; + + /// \brief Getter function + /// \return The number of workers in this op + virtual int32_t num_workers() const = 0; + + /// \brief Getter function + /// \return The number of threads consuming from previous op. + virtual int32_t num_consumers() const = 0; + + /// \brief Getter function + /// \return The number of threads producing to the output connector. + virtual int32_t num_producers() const = 0; + + /// \brief Getter function + /// \return T/F if this is an inlined operator + bool inlined() const { return (oc_queue_size_ == 0); } + + /// \brief Setter function + /// \return Sets the control flags + void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); } + + /// \brief Setter function + /// \return Sets the control flags + void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); } + + /// \brief Register the internal worker connectors. No op unless it is a parallel op + /// \return Status + virtual Status RegisterWorkerConnectors() { return Status::OK(); } + + /// \brief Getter for the column name mapping + /// \return The returned map + std::unordered_map column_name_id_map() const { return column_name_id_map_; } + + /// \brief Checks if the column name map has been set up yet for this op + /// \return - T/F if the operator has the map set up + bool HasColumnNameMap() const { return (column_name_id_map_.empty()); } + + /// \brief gives a string output for the column map for handy debug printing + /// \return - the column name map as a string + std::string ColumnNameMapAsString() const; + + /// \brief Getter function + /// \return connector size of current op + int32_t ConnectorSize() const { + if (!inlined()) { + return out_connector_->size(); + } + // Return child connector size for inlined op + return ChildOpConnectorSize(); + } + + /// \brief Counting number of buffer sent out by a connector + int64_t ConnectorOutBufferCount() const { + return out_connector_ == nullptr ? int64_t(-1) : static_cast(out_connector_->out_buffers_count()); + } + + /// \brief Getter function + /// \return connector size of current op + int32_t ConnectorCapacity() const { + if (!inlined()) { + return out_connector_->capacity(); + } + // Return child connector capacity for inlined op + return ChildOpConnectorCapacity(); + } + + /// \brief Getter function + /// \return connector size of child op + int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); } + + /// \brief Getter function + /// \return connector capacity of child op + int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); } + + /// \brief Children Getter + /// \return Vector of Children + std::vector> Children() const { return child_; } + + /// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up + /// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main + /// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it + /// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details. + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + virtual Status PreAccept(NodePass *p, bool *modified); + + /// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access. + /// Check "dataset/engine/opt/pass.h" for more details. + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + virtual Status Accept(NodePass *p, bool *modified); + + /// Op name getter + /// \return Name of the current Op + virtual std::string Name() const { return "DatasetOp"; } + + /// Execution Tree getter + /// \return Pointer to the ExecutionTree the current op belongs to, no ownership + ExecutionTree *Tree() { return tree_; } + + /// Getter for the sampler + /// \return Shared pointer to the sampler (may return nullptr) + std::shared_ptr sampler() { return sampler_; } + + /// \brief Getter for the sampler, and it also removes the sampler from the op + /// \param[out] sampler A pointer to the output sampler that was removed + /// \return Status error code + Status FetchRemoveSampler(std::shared_ptr *sampler); + + // Computes a CRC value for the operator + static uint32_t GenerateCRC(const std::shared_ptr &op); + + /// \brief A helper templated function for casting "this" pointer to shared_ptr + /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr + /// \return A shared_ptr casted to the derived class + template + std::shared_ptr shared_from_base() { + return std::static_pointer_cast(shared_from_this()); + } + + /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. + void SetSampler(std::shared_ptr sampler) { sampler_ = sampler; } + + /// \brief Checks if this is a leaf node (0 children) + /// \return boolean returns true if it's a leaf + bool IsLeaf() { return (child_.empty()); } + + protected: + /// \brief Removes a parent operator from this operator + /// \notes External callers do not have access to this function + /// \param[in] parent The parent node to remove + void RemoveParent(const DatasetOp *parent); + + /// \brief Adds a parent operator to this operator + /// \notes External callers do not have access to this function + /// \param[in] parent The parent node to add + void AddParent(DatasetOp *parent); + + /// Compute the current op's column map using its child's column map. + /// Get called during the tree post-prepare phase in PrepareNodePostAction. + /// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1. + /// Operations changing the column map it inherits from the child must overwrite this function. + /// \return - Status + virtual Status ComputeColMap(); + + std::vector> child_; // Child nodes + std::vector parent_; // Parent nodes. No ownership + std::shared_ptr sampler_; // Some leaf ops might have a sampler + int32_t oc_queue_size_; // Capacity for each out_connector_ + int32_t operator_id_; // Generated id for the node + ExecutionTree *tree_; // Back pointer to our tree. + OpState state_; // The state of the operator, Running, Idle, Terminated + uint32_t op_ctrl_flags_; // Flags for the operator + std::unique_ptr out_connector_; // Output Connector + std::unordered_map column_name_id_map_; // Mapping between col index and col name + std::mutex column_name_map_mutex_; // For protecting shared access to the column map + + private: + /// Sets the operator id. + /// \notes No public interface. Only the class itself, or it's friend the execution tree can set + /// this + /// \param op_id - the Id value to set into the operator + void set_id(int32_t op_id) { operator_id_ = op_id; } + + /// Sets the tree into the op so that the operator has a back pointer to the tree. + /// \param tree - the tree to assign to the op. + void set_tree(ExecutionTree *tree) { tree_ = tree; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc new file mode 100644 index 0000000000..4fe779246b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -0,0 +1,320 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/perf/profiling.h" +#include "minddata/dataset/engine/perf/device_queue_tracing.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, + int32_t op_connector_size, int64_t num_batch) + : PipelineOp(op_connector_size), + channel_name_(channel_name), + device_type_(device_type), + device_id_(device_id), + prefetch_size_(prefetch_size), + num_batch_(num_batch) {} + +DeviceQueueOp::~DeviceQueueOp() {} + +#ifdef ENABLE_GPUQUE +void ReleaseData(void *addr) { + if (addr != nullptr) { + free(addr); + } +} +#endif + +DeviceQueueOp::Builder::Builder(int32_t prefetch_size) + : builder_prefetch_size_(prefetch_size), + builder_device_id_(0), + builder_device_type_(DeviceType::CPU), + builder_channel_name_(""), + builder_num_batch_(0) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status DeviceQueueOp::EoeReceived(int32_t worker_id) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +Status DeviceQueueOp::operator()() { + TaskManager::FindMe()->Post(); + + if (device_type_ == DeviceType::Ascend) { +#ifdef ENABLE_TDTQUE + RETURN_IF_NOT_OK(SendDataToAscend()); +#endif + } else if (device_type_ == DeviceType::GPU) { +#ifdef ENABLE_GPUQUE + RETURN_IF_NOT_OK(SendDataToGPU()); +#endif + } else if (device_type_ == DeviceType::CPU) { + RETURN_IF_NOT_OK(SendDataToCPU()); + } + + return Status::OK(); +} + +Status DeviceQueueOp::CheckExceptions(const std::unique_ptr &buffer) const { + // this method checks if the buffer meets the conditions to be sent to TDT + if (buffer->NumRows() != 0) { + TensorRow row; + buffer->GetRow(0, &row); + for (const auto &item : row) { + CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); + } + } + return Status::OK(); +} + +#ifdef ENABLE_TDTQUE +Status DeviceQueueOp::SendDataToAscend() { + MS_LOG(INFO) << "Device queue, sending data to Ascend."; + int64_t total_batch = 0; + bool is_break_loop = false; + double batch_start_time, end_time; + int32_t batch_cost, tdt_cost; + int32_t connector_size = 0; + int32_t connector_capacity; + std::shared_ptr profiling_node; + bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable(); + if (isProfilingEnable) { + std::shared_ptr node; + RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node)); + profiling_node = std::dynamic_pointer_cast(node); + batch_start_time = ProfilingTime::GetCurMilliSecond(); + connector_capacity = ChildOpConnectorCapacity(); + } + std::unique_ptr current_buffer; + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + + while (!current_buffer->eof() && !is_break_loop) { + while (!current_buffer->eoe() && !is_break_loop) { + RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); + TensorRow currRow; + for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) { + RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); + auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); + if (status == TdtStatus::FAILED) { + return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); + } + + if (isProfilingEnable) { + end_time = ProfilingTime::GetCurMilliSecond(); + // record push tdt time + profiling_node->Record(TIME, TDT_PUSH_TIME, total_batch + 1, tdt_cost); + batch_cost = (int32_t)(end_time - batch_start_time); + // record batch time + profiling_node->Record(TIME, BATCH_TIME, total_batch + 1, batch_cost); + // record pipeline time + profiling_node->Record(TIME, PIPELINE_TIME, total_batch + 1, batch_cost - tdt_cost); + batch_start_time = end_time; + // record connector depth + profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size); + } + total_batch++; + if (num_batch_ > 0 && total_batch == num_batch_) { + is_break_loop = true; + } + } + if (isProfilingEnable) { + connector_size = ChildOpConnectorSize(); + connector_capacity = ChildOpConnectorCapacity(); + } + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + } + if (isProfilingEnable) { + connector_size = ChildOpConnectorSize(); + connector_capacity = ChildOpConnectorCapacity(); + } + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + } + + tree_->SetFinished(); + MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + + return Status::OK(); +} +#endif + +#ifdef ENABLE_GPUQUE +Status DeviceQueueOp::SendDataToGPU() { + MS_LOG(INFO) << "Device queue, sending data to GPU."; + int64_t total_batch = 0; + bool is_break_loop = false; + bool is_open = false; + uint32_t handle = INVALID_HANDLE; + + std::unique_ptr current_buffer; + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + + while (!current_buffer->eof() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed()) { + while (!current_buffer->eoe() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed()) { + RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); + TensorRow curr_row; // batch data + for (int row_id = 0; + row_id < current_buffer->NumRows() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed(); row_id++) { + RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &curr_row)); + + std::vector data_size; + for (int i = 0; i < curr_row.size(); i++) { + data_size.push_back(static_cast(curr_row[i]->SizeInBytes())); + } + if (!is_open) { + handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, data_size, ReleaseData); + if (handle == INVALID_HANDLE) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "open failed"); + } + is_open = true; + } + RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle)); + total_batch++; + if (num_batch_ > 0 && total_batch == num_batch_) { + is_break_loop = true; + } + } + if (!TaskManager::FindMe()->Interrupted()) + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + else + is_break_loop = true; + } + if (!TaskManager::FindMe()->Interrupted()) + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + else + is_break_loop = true; + } + + MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + + GpuBufferMgr::GetInstance().Close(handle); + + GpuBufferMgr::GetInstance().CloseConfirm(); + + return Status::OK(); +} + +Status DeviceQueueOp::RetryPushGPUData(const std::vector &data_size, const TensorRow &curr_row, + uint32_t handle) { + std::vector items; + for (int i = 0; i < data_size.size(); i++) { + device::DataItemGpu data_item; + data_item.data_len_ = data_size[i]; + data_item.data_ptr_ = nullptr; + items.push_back(data_item); + } + + while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { + RETURN_IF_NOT_OK(MallocForGPUData(&items, curr_row)); + BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); + if (ret) { + for (int i = 0; i < items.size(); i++) { + free(items[i].data_ptr_); + } + if (ret == BlockQueueStatus_T::ERROR_INPUT) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); + } else { + MS_LOG(WARNING) << "Retry pushing data..."; + continue; + } + } else { + break; + } + } + return Status::OK(); +} + +Status DeviceQueueOp::MallocForGPUData(std::vector *items, const TensorRow &curr_row) { + int i = 0; + for (auto &sub_item : *items) { + sub_item.data_ptr_ = (unsigned char *)malloc(sub_item.data_len_); + if (sub_item.data_ptr_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memory malloc failed."); + } + (void)memset_s(sub_item.data_ptr_, sub_item.data_len_, 0, sub_item.data_len_); + const unsigned char *column_data = curr_row[i]->GetBuffer(); + if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data, + static_cast(curr_row[i++]->SizeInBytes())) != 0) { + MS_LOG(ERROR) << "memcpy_s failed!"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memcpy_s failed."); + } + } + + return Status::OK(); +} +#endif + +Status DeviceQueueOp::SendDataToCPU() { + MS_LOG(INFO) << "Device queue, sending data to CPU."; + int64_t total_batch = 0; + + std::unique_ptr child_iterator = std::make_unique(this, 0, 0); + while (!(child_iterator->eof_handled())) { + TensorRow curr_row; + RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&curr_row)); + + if (!curr_row.empty()) { + MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << "."; + MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << "."; + total_batch++; + if (num_batch_ > 0 && total_batch == num_batch_) { + break; + } + } + } + + MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + + return Status::OK(); +} + +void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; + } +} + +// Visitor accept method for NodePass +Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h new file mode 100644 index 0000000000..0fb4fb093d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h @@ -0,0 +1,175 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/status.h" + +#ifdef ENABLE_TDTQUE +#include "minddata/dataset/engine/tdt/tdt_plugin.h" +#endif + +#ifdef ENABLE_GPUQUE +#include "runtime/device/gpu/gpu_buffer_mgr.h" +using mindspore::device::BlockQueueStatus_T; +using mindspore::device::GpuBufferMgr; +#endif + +namespace mindspore { +namespace dataset { +class DeviceQueueOp : public PipelineOp { + public: + static const uint32_t INVALID_HANDLE = 0xffffffffUL; + static const uint32_t WAIT_TIME = 5; + + enum class DeviceType { Ascend = 0, GPU = 1, CPU = 2 }; + + // The nested builder class inside of the DeviceQueueOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + explicit Builder(int32_t prefetch_size); + + // Default destructor + ~Builder() = default; + + Builder &SetPrefetchSize(int32_t prefetch_size) { + builder_prefetch_size_ = prefetch_size; + return *this; + } + + Builder &SetChannelName(const std::string &channel_name) { + builder_channel_name_ = channel_name; + return *this; + } + + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + Builder &SetDeviceType(const std::string &device_type) { + if (device_type == "Ascend") { + builder_device_type_ = DeviceType::Ascend; + } else if (device_type == "GPU") { + builder_device_type_ = DeviceType::GPU; + } else if (device_type == "CPU") { + builder_device_type_ = DeviceType::CPU; + } + return *this; + } + + Builder &SetDeviceId(int32_t device_id) { + builder_device_id_ = device_id; + return *this; + } + + Builder &SetNumBatch(int64_t num_batch) { + builder_num_batch_ = num_batch; + return *this; + } + + // Name: Build() + // Description: The final step for building a DeviceQueueOp via the Builder is + // to call this Build() method. It will instantiate the DeviceQueueOp + // and return it to caller as a shared pointer. + Status Build(std::shared_ptr *ptr) { + *ptr = std::make_shared(builder_channel_name_, builder_device_type_, builder_device_id_, + builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_); + return Status::OK(); + } + + private: + int32_t builder_prefetch_size_; + int32_t builder_device_id_; + DeviceType builder_device_type_; + std::string builder_channel_name_; + int64_t builder_num_batch_; + int32_t builder_op_connector_size_; + }; + + // Name: constructor + // Description + DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, + int32_t op_connector_size, int64_t num_batch); + + // Name: destructor + // Description + ~DeviceQueueOp(); + + Status EoeReceived(int32_t worker_id) override; + + const int32_t get_prefetch_size() { return prefetch_size_; } + + // Name: Print() + // Description: A function that prints info about the node + void Print(std::ostream &out, // In: The output stream to print to + bool show_all) const override; // In: T/F if it should print everything + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const DeviceQueueOp &to) { + to.Print(out, false); + return out; + } + + Status operator()() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "DeviceQueueOp"; } + + private: + // Name: checkExceptions(DataBuffer); + // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp + Status CheckExceptions(const std::unique_ptr &buffer) const; + +#ifdef ENABLE_TDTQUE + Status SendDataToAscend(); +#endif + +#ifdef ENABLE_GPUQUE + Status SendDataToGPU(); + Status RetryPushGPUData(const std::vector &data_size, const TensorRow &curr_row, uint32_t handle); + Status MallocForGPUData(std::vector *items, const TensorRow &curr_row); +#endif + + Status SendDataToCPU(); + std::string channel_name_; + DeviceType device_type_; + const int32_t device_id_; + const int32_t prefetch_size_; + const int64_t num_batch_; + +#ifdef ENABLE_TDTQUE + std::shared_ptr tdtInstancePtr; +#endif +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc new file mode 100644 index 0000000000..f32648a3df --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc @@ -0,0 +1,267 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/filter_op.h" +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { + +Status FilterOp::Builder::SanityCheck() { + std::string err; + err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; + err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : ""; + return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); +} + +FilterOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status FilterOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_, + builder_predicate_func_); + return Status::OK(); +} + +FilterOp::FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func) + : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {} + +Status FilterOp::operator()() { + // The operator class just starts off threads by calling the tree_ function. + RETURN_UNEXPECTED_IF_NULL(tree_); + filter_queues_.Init(num_workers_, oc_queue_size_); + RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks())); + Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)); + // Synchronize with TaskManager. + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(rc); + RETURN_IF_NOT_OK(Collector()); + return Status::OK(); +} + +Status FilterOp::EofReceived(int32_t) { return Status::OK(); } + +Status FilterOp::EoeReceived(int32_t) { return Status::OK(); } + +// Validating if each of the input_columns exists in the DataBuffer. +Status FilterOp::ValidateInColumns(const std::vector *input_columns) { + for (const auto &inCol : *input_columns) { + bool found = column_name_id_map_.find(inCol) != column_name_id_map_.end() ? true : false; + if (!found) { + std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); +} + +// A print method typically used for debugging. +void FilterOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nInput column names:"; + for (size_t i = 0; i < in_columns_.size(); i++) { + out << " " << in_columns_[i]; + } + out << "\n\n"; + } +} + +Status FilterOp::WorkerEntry(int32_t worker_id) { + // Handshake with TaskManager that thread creation is successful. + TaskManager::FindMe()->Post(); + std::unique_ptr in_buffer; + bool worker_stop = false; + while (worker_stop == false) { + // Getting a databuffer to work on. + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); + if (in_buffer->eoe()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); + continue; + } else if (in_buffer->eof()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); + worker_stop = true; + continue; + } + + RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_)); + + // if the databuffer was all filtered, it is marked as kFilterEmpty. + // if the databuffer was partially filtered, it is marked as kFilterPartial. + // if the databuffer was not filtered, it is marked as kFilterFull. + int32_t num_rows = in_buffer->NumRows(); + std::unique_ptr new_tensor_table; + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table)); + + if (new_tensor_table->empty()) { + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty))); + } else if (new_tensor_table->size() == num_rows) { + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull))); + } else { // kFilterPartial + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial))); + } + } + return Status::OK(); +} + +Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out) { + *out = std::make_unique(); + int32_t num_rows = in_buffer->NumRows(); + for (int32_t i = 0; i < num_rows; i++) { + TensorRow to_process; + TensorRow cur_row; + RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); + if (in_columns_.empty() == true) { + MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; + to_process = cur_row; + } else { + (void)std::transform( + in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process), + [&cur_row, this](const auto &it) -> std::shared_ptr { return cur_row[column_name_id_map_[it]]; }); + } + bool predicate = true; + RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate)); + if (predicate) { + (*out)->push_back(std::move(cur_row)); + } + } + return Status::OK(); +} + +// if the filtered DataBuffer is written directly to out_connector_, +// the thread fetching data will block in a queue. +// Collector function will reorder the DataBuffer in order. +// for example in two work queues: +// int filter_queues_: +// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof) +// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe) +// after reorder in out_connector_: +// queue1: DB(data2) DB(data4) DB(eof) +// queue2: DB(eoe) DB(eoe) +Status FilterOp::Collector() { + bool collector_stop = false; + uint64_t task_id_cnt = 0; + uint64_t out_id_cnt = 0; + std::pair, filterCtrl> in_pair; + while (collector_stop == false) { + uint32_t w_id = task_id_cnt % num_workers_; + RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); + if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || + in_pair.second == filterCtrl::kFilterEoe) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + out_id_cnt++; + task_id_cnt++; + } else if (in_pair.second == filterCtrl::kFilterEof) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + collector_stop = true; + } else { // kFilterEmpty + task_id_cnt++; + } + } + return Status::OK(); +} + +// Private function for checking the column legality. +Status FilterOp::CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns) { + int32_t num_rows = in_buf->NumRows(); + int32_t num_cols = in_buf->NumCols(); + if (num_rows == 0 || num_cols == 0) { + RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer."); + } + // Check if there is invalid column name in the inColumns. + RETURN_IF_NOT_OK(ValidateInColumns(input_columns)); + return Status::OK(); +} + +Status FilterOp::CheckInput(const TensorRow &input) const { + for (auto &item : input) { + if (item == nullptr) { + RETURN_STATUS_UNEXPECTED("input is null."); + } + } + return Status::OK(); +} + +Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) { + RETURN_IF_NOT_OK(CheckInput(input)); + // Acquire Python GIL. + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + // Transform input tensor vector into numpy array vector. + py::tuple input_args(input.size()); + for (size_t i = 0; i < input.size(); i++) { + py::array new_data; + RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); + input_args[i] = new_data; + } + // Invoke python function. + py::object ret_py_obj = predicate_func_(*input_args); + *out_predicate = ret_py_obj.cast(); + } catch (const py::error_already_set &e) { + std::stringstream ss; + ss << e.what() << std::endl; + ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool."; + return Status(StatusCode::kPyFuncException, ss.str()); + } + return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); +} + +// Visitor accept method for NodePass +Status FilterOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h new file mode 100644 index 0000000000..fcc6e577df --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h @@ -0,0 +1,188 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/queue.h" + +namespace mindspore { +namespace dataset { + +class FilterOp : public ParallelOp { + public: + // The nested builder class inside of the FilterOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args. + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetPredicateFunc(py::function func) { + builder_predicate_func_ = std::move(func); + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetInColNames(const std::vector &in_col_names) { + build_in_col_names_ = in_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + builder_op_connector_size_ = connector_size; + return *this; + } + + // The builder "build" method creates the final object. + // @param ptr The shared_ptr to the new FilterOp object. + // @return Status. + Status Build(std::shared_ptr *ptr); + + private: + // Sanity check for builder class args. + // @return Status - The error code return. + Status SanityCheck(); + std::vector build_in_col_names_; + py::function builder_predicate_func_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + }; + + enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 }; + + // Constructor of FilterOp + // @note The builder class should be used to call it. + // @param in_col_names A list of input column names,when it is empty the predicate will be + // applied all columns in the dataset. + // @param num_workers The number of worker threads. + // @param op_connector_size The size of each queue in the connector. + // @param predicate_func python callable which returns a boolean value. + FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func); + + // Destructor + ~FilterOp() = default; + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will + // provide the master loop that drives the logic for performing the work. + // @return Status The error code return + Status operator()() override; + + // @param int32_t workerId. + // @return Status - The error code return. + Status EofReceived(int32_t) override; + + // @param int32_t workerId. + // @return Status - The error code return. + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging. + // @param out The output stream to write output to. + // @param show_all A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "FilterOp"; } + + private: + // predicate_func python callable which returns a boolean value. + py::function predicate_func_; + + // Variable to store the column name that will feed to predicate function. + std::vector in_columns_; + + // Internal queue for filter. + QueueList, filterCtrl>> filter_queues_; + + // Private function for worker/thread to loop continuously. It comprises the main + // logic of FilterOp, getting the data from previous Op, validating user specified column names, + // applying predicate to each of the data, filter the data when predicate result is false. + // @param worker_id The id assigned to this thread/worker upon creation. + // @return Status The error code return. + Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ + + // Filter the data by predicate function . + // @param in_buffer input data buffer. + // @param to_proess_indices Indices of columns to be processed. + // @param out data buffer that are filtered by predicate. + // @return Status The error code return. + Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out); + + // Collector databuffer. + // @return Status The error code return. + Status Collector(); + + // @param input tensor vector. + // @return Status - The error code return. + Status CheckInput(const TensorRow &input) const; + + // Invoke python func. + // @param input tensor vector. + // @param the result of predicate. + // @return Status - The error code return. + Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); + + // Private function for validating if each of the user specified input column names + // exist in the DataBuffer. + // @param input_columns The vector of input column names used in the current thread. + // @return Status The error code return. + Status ValidateInColumns(const std::vector *input_columns); + + // Private function for checking the column legality + // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory + // and is not shared with other threads. + // @param[out] to_process_indices Indices of columns that will feed to predicate. + // @param input_columns The vector of input column names used in the current thread. + Status CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns); +}; + +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc new file mode 100644 index 0000000000..e5e70dbbdf --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc @@ -0,0 +1,373 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/map_op.h" +#include +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +MapOp::Builder::Builder() : build_perf_mode_(true) { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_workers_ = cfg->num_parallel_workers(); + build_op_connector_size_ = cfg->op_connector_size(); +} + +// Check if the required parameters are set by the builder. +Status MapOp::Builder::sanityCheck() const { + if (build_tensor_funcs_.empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Building a MapOp that has not provided any function/operation to apply"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object. +Status MapOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(sanityCheck()); + *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), + std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_, + build_perf_mode_); + return Status::OK(); +} + +// Constructor of MapOp +MapOp::MapOp(const std::vector &in_col_names, const std::vector &out_col_names, + std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, + bool perf_mode) + : ParallelOp(num_workers, op_connector_size), + tfuncs_(std::move(tensor_funcs)), + in_columns_(in_col_names), + out_columns_(out_col_names), + perf_mode_(perf_mode) { + // If caller didn't specify the out_col_names, assume they are same as the in_columns. + if (out_columns_.empty() || out_columns_[0].empty()) { + out_columns_ = in_columns_; + } + MS_LOG(DEBUG) << "Performance Mode in map operator is " << perf_mode_ << "."; +} + +// The number of threads consuming data from previous op's output Connector. +int32_t MapOp::num_consumers() const { + // When Performance Mode is on, there is only one thread consuming from the previous Connector. + return perf_mode_ == true ? 1 : num_workers_; +} + +// A print method typically used for debugging +void MapOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nInput column names:"; + for (size_t i = 0; i < in_columns_.size(); i++) { + out << " " << in_columns_[i]; + } + out << "\n TensorOps:"; + for (size_t i = 0; i < tfuncs_.size(); i++) { + out << " " << *(tfuncs_[i].get()); + } + out << "\n\n"; + } +} + +// This class functor will provide the master loop that drives the logic for performing the work +Status MapOp::operator()() { + if (perf_mode_) { + // Create and register the local queues. + local_queues_.Init(num_workers_, oc_queue_size_); + Status rc = local_queues_.Register(tree_->AllTasks()); + if (rc.IsError()) { + TaskManager::FindMe()->Post(); + return rc; + } + } + + // The operator class just starts off threads by calling the tree_ function + Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); + // Synchronize with TaskManager + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(rc); + + if (perf_mode_) { + int64_t que_id = 0; + std::unique_ptr buff; + bool is_eof = false; + // Draining output connector of the previous op and distribute it to local queues. + // Stop when all worker threads are finished (received EOF). + while (!is_eof) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); + is_eof = buff->eof(); + RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff))); + que_id = (que_id + 1) % num_workers_; + } + } + + return Status::OK(); +} + +// Private function for worker/thread to loop continuously. It comprises the main +// logic of MapOp: getting the data from previous Op, validating user specified column names, +// applying a list of TensorOps to each of the data, process the results and then +// pushing them back to MapOp's output Connector to be fetched by the next Op. +Status MapOp::WorkerEntry(int32_t worker_id) { + // Handshake with TaskManager that thread creation is successful. + TaskManager::FindMe()->Post(); + std::unique_ptr in_buffer; + + // Getting a databuffer to work on. + // Perform the first fetch here outside of the loop. This allows us to execute one-time only + // initializations that happen after the first fetch. + RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); + + // Sanity check the databuffer. + // Special case: if there's more threads than buffers, some threads simply get the final control + // messages (eoe/eof), and so they will not perform the check. + if (!in_buffer->eoe() && !in_buffer->eof()) { + int32_t num_rows = in_buffer->NumRows(); + int32_t num_cols = in_buffer->NumCols(); + if (num_rows == 0 || num_cols == 0) { + RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer."); + } + } + + // Now that init work is done, drop into the main fetching loop. + // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself + // rather than use the base-class defaults. + while (true) { + // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work + // with Performance Mode design. + if (in_buffer->eoe()) { + // Calling base class EoeReceived to forward eoe buffer. + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); + continue; + } else if (in_buffer->eof()) { + // Calling base class EofReceived to forward eof buffer. + RETURN_IF_NOT_OK(EofReceived(worker_id)); + break; + } + + std::unique_ptr new_tensor_table(std::make_unique()); + // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get())); + + // Replace the TensorTable in DataBuffer with the new one. + in_buffer->set_tensor_table(std::move(new_tensor_table)); + + // Push the buffer onto the connector for next operator to consume. + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(worker_id), std::move(in_buffer))); + + // Fetch the next buffer and loop back to the top. + RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); + } + + return Status::OK(); +} + +Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table) { + // Getting number of rows and cols in this buffer. + int32_t num_rows = in_buffer->NumRows(); + int32_t num_cols = in_buffer->NumCols(); + + for (int32_t r = 0; r < num_rows; r++) { + // to_process : A vector of Tensors only holding cols in input_columns. + // result_row; : A vector of Tensors to hold the result after Compute(). + // cur_row : A vector of Tensors holding all the columns from DataBuffer. + TensorRow to_process, result_row, cur_row; + RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); + + // Populate the Tensor from the current row to be processed by TensorOp + for (const auto &idx : to_process_indices_) { + to_process.push_back(std::move(cur_row[idx])); + } + + // Looping over multiple TensorOps supplied in to MapOp. + // The assumption is that the result of one TensorOp matches the required input to the next TensorOp. + for (size_t i = 0; i < tfuncs_.size(); i++) { + // TensorOp can operate on single col or multiple cols. MapOp always call compute for multiple cols. + // TensorOp base class will call the single column Compute() depending on the ops. + // Note: The columns of the result_row is not preallocated, the compute function of each tensor op are + // required to resize/push back the result_row + RETURN_IF_NOT_OK(tfuncs_[i]->Compute(to_process, &result_row)); + + // Assign result_row to to_process for the next TensorOp processing, except for the last TensorOp in the list. + if (i + 1 < tfuncs_.size()) { + to_process = std::move(result_row); + } + } + + if (out_columns_.size() != result_row.size()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Result of a tensorOp doesn't match output column names"); + } + + if (in_columns_.size() == out_columns_.size()) { + for (size_t i = 0; i < result_row.size(); i++) { + cur_row[to_process_indices_[i]] = std::move(result_row[i]); + } + new_tensor_table->push_back(std::move(cur_row)); + } else { + // Add the columns we did not touch to the result_row. + for (int32_t i = 0; i < num_cols; i++) { + if (keep_input_columns_[i]) { + result_row.push_back(std::move(cur_row[i])); + } + } + + // Add this final result_row to our new TensorTable. + new_tensor_table->push_back(std::move(result_row)); + } + } + + return Status::OK(); +} + +Status MapOp::ComputeColMap() { + // If the map has not been set up yet in the base class, then set it up + if (column_name_id_map_.empty()) { + std::unordered_map current_name_id_map = child_[0]->column_name_id_map(); + // Initialize private variables + RETURN_IF_NOT_OK(InitPrivateVariable(¤t_name_id_map)); + // Create the final column name to index mapping in the base class field + CreateFinalColMap(¤t_name_id_map); + MS_LOG(DEBUG) << "Column name map for map op set: " << this->ColumnNameMapAsString(); + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +// Validating if each of the input_columns exists in the DataBuffer. +Status MapOp::ValidateInColumns(const std::unordered_map &col_name_id_map) { + for (const auto &inCol : in_columns_) { + bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; + if (!found) { + std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); +} + +Status MapOp::InitPrivateVariable(std::unordered_map *col_name_id_map) { + // If input_columns is empty(), The col at index-0 will be picked. + if (in_columns_.empty()) { + for (const auto &pair : *col_name_id_map) { + if (pair.second == 0) { + MS_LOG(INFO) << "Input columns empty for map op, will apply to the first column in the current table."; + in_columns_.push_back(pair.first); + break; + } + } + + // If caller didn't specify the out_col_names, assume they are same as the input_columns. + // This was done in the constructor, but if input columns was empty to start we have to redo it here. + if (out_columns_.empty() || out_columns_[0].empty()) { + out_columns_ = in_columns_; + } + } + + // Before we continue, issue a sanity check to make sure the input columns from user and the incoming + // columns from child are correct + RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map)); + + // initialize keep_input_columns, true means to keep the column. + keep_input_columns_.resize(col_name_id_map->size(), true); + for (const auto &col_name : in_columns_) { + int32_t missed = (*col_name_id_map)[col_name]; + keep_input_columns_[missed] = false; + } + + // initialize to_process_indices. + for (const auto &col_name : in_columns_) { + to_process_indices_.push_back((*col_name_id_map)[col_name]); + } + return Status::OK(); +} + +// Create the final column name to index mapping and get indices of the columns this mapop does not use. +void MapOp::CreateFinalColMap(std::unordered_map *col_name_id_map) { + std::unordered_map final_col_name_id_map; + size_t num_cols = col_name_id_map->size(); + std::vector new_ids(num_cols); + if (in_columns_.size() == out_columns_.size()) { + for (size_t i = 0; i < in_columns_.size(); i++) { + int32_t loc = (*col_name_id_map)[in_columns_[i]]; + (void)col_name_id_map->erase(in_columns_[i]); + (*col_name_id_map)[out_columns_[i]] = loc; + } + + // Set the base class final column id map result + column_name_id_map_ = *col_name_id_map; + } else { + int32_t fill_idx = 0; + // First columns of the tables are occupied by the output columns from tensorOp. + for (const auto &col_name : out_columns_) { + final_col_name_id_map[col_name] = fill_idx++; + } + + // Creating new_ids mapping for the columns we keep. + for (size_t i = 0; i < num_cols; i++) { + if (keep_input_columns_[i]) { + new_ids[i] = fill_idx++; + } + } + + // Iterating through the old mapping to update the final mapping for the columns we kept. + std::string name; + for (const auto &pair : *col_name_id_map) { + name = pair.first; + int32_t old_id = pair.second; + if (keep_input_columns_[old_id]) { + final_col_name_id_map[name] = new_ids[old_id]; + } + } + + // Set the base class final column id map result + column_name_id_map_ = final_col_name_id_map; + } +} + +// Visitor accept method for NodePass +Status MapOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h new file mode 100644 index 0000000000..b1cd58010f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h @@ -0,0 +1,268 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_MAP_OP_H_ +#define DATASET_ENGINE_DATASETOPS_MAP_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/queue.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names. +// The column order behavior after MapOp is as follows. +// [Case 1] If the number of Input Columns == the number of Output Column, column ordering after MapOp +// is the same as the original column order where the Remainder Columns stay in the same position, +// and the Output Columns are placed the same position of the Input Columns. +// For example, initially if the dataset has column order |A, B, C, D, E|, +// and we apply MapOp() with Input Columns {B, C} and Output Columns {X, Y}. +// The column order after applying MapOp will be |A, X, Y, D, E|. +// Note that in this case, |X, Y| is the Output Columns and |A, D, E| which is the Remainder Columns stay in +// their original position, and column B is replaced by column X and column C is replace by column Y. +// [Case 2] If the number of Input Columns != the number of Output Column, column ordering after MapOp +// is Output Columns followed by Remainder Columns. +// For example, initially if the dataset has column order |A, B, C, D, E|, +// and we apply MapOp() with Input Columns {B, C, A} and Output Columns {X, Y}. +// The column order after applying MapOp will be |X, Y, D, E|. +// Note that in this case, |X, Y| is the Output Columns and |D, E| is the Remainder Columns, +// and the Input Columns are gone and replaced by the Output Columns. + +// Keywords: +// Input Columns : a vector of column names (string) passed to MapOp specifying the column names from which +// Tensors are taken and passed to the TensorOp Compute(). +// Output Columns : a vector of column names (string) passed to MapOp specifying what are the column names +// for the Tensors produced by TensorOp Compute(). +// Remainder Columns : columns that exist in the dataset but are not mentioned in Input Columns. +// These columns will not be passed to TensorOp Compute(), but will be appended to the end of the Output Columns. +class MapOp : public ParallelOp { + public: + // The nested builder class inside of the MapOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetInColNames(const std::vector &in_col_names) { + build_in_col_names_ = in_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOutColNames(const std::vector &out_col_names) { + build_out_col_names_ = out_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetTensorFuncs(std::vector> funcs) { + build_tensor_funcs_ = std::move(funcs); + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetPerformanceMode(bool perf_mode) { + build_perf_mode_ = perf_mode; + return *this; + } + + // The builder "build" method creates the final object. + // @param ptr The shared_ptr to the new MapOp object + // @return Status + Status Build(std::shared_ptr *ptr); + + private: + std::vector build_in_col_names_; + std::vector build_out_col_names_; + std::vector> build_tensor_funcs_; + int32_t build_num_workers_; + int32_t build_op_connector_size_; + bool build_perf_mode_; // Default true. + + // Check if the required parameters are set by the builder. + // @return Status The error code return + Status sanityCheck() const; + }; + + // Constructor of MapOp + // @note The builder class should be used to call it. + // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs). + // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs). + // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data. + // @param num_workers The number of worker threads. + // @param op_connector_size The size of each queue in the connector. + MapOp(const std::vector &in_col_names, const std::vector &out_col_names, + std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, + bool perf_mode); + + // Destructor + ~MapOp() = default; + + // A print method typically used for debugging + // @param out The output stream to write output to + // @param show_all A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out reference to the output stream being overloaded + // @param mo reference to the MapOp to display + // @return the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const MapOp &mo) { + mo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status The error code return + Status operator()() override; + + // Getter + // @return the number of threads consuming data from previous op's output Connector. + int32_t num_consumers() const override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "MapOp"; } + + // List of tensor ops getter/setter + // @Return the vector of tensor ops by non-const reference + + auto &TFuncs() { return tfuncs_; } + + const auto &TFuncs() const { return tfuncs_; } + + private: + // Local queues where worker threads can pop from. + // Popping directly from the Connector can block if the previous designated threads haven't pop. + // Setting the size of these queues to 0 is essentially the same as pulling directly from Connector. + QueueList> local_queues_; + + // Static variables to be ready by worker threads, no modification and readonly + std::vector> tfuncs_; + + // Variable to store the column name that the tensorOps are consuming + std::vector in_columns_; + + // Variable to store the column name that the tensorOps are producing + std::vector out_columns_; + + // Boolean mapping, true means to keep the column. + std::vector keep_input_columns_; + + // Indices of the columns to process. + std::vector to_process_indices_; + + // Performance mode is when the main thread creates local queues, pulls databuffers from the previous + // op's Connector and distributes them to the local queues. Workers pull from the local queues. + // If this flag is false, each worker pulls directly from the Connector. This use less resources + // (thread and memory), but when the computation cost is heavy (e.g. DecodeOp) and fluctuating, it can + // cause additional blocking because pop calls to Connector from the threads are synchronized to enforce the order. + bool perf_mode_; + + // Private function for worker/thread to loop continuously. It comprises the main + // logic of MapOp: getting the data from previous Op, validating user specified column names, + // applying a list of TensorOps to each of the data, process the results and then + // pushing them back to MapOp's output Connector to be fetched by the next Op. + // @param worker_id The id assigned to this thread/worker upon creation. + // @return Status The error code return + Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ + + // Private helper function for getting the next buffer + // When PerformanceMode is enabled, workers pop from the local queue. + // Otherwise, workers pop from the first child output Connector. + // @param p_buffer - the buffer to return + // @return Status return code + Status FetchNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id) { + if (perf_mode_) { + RETURN_IF_NOT_OK(local_queues_[worker_id]->PopFront(p_buffer)); + } else { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id)); + } + return Status::OK(); + } + + // Private function for worker thread to perform TensorOp's compute function and get the result. + // @param in_buffer A raw pointer to the DataBuffer. A raw pointer is fine because this function doesn't manage memory + // and is not shared with other threads. + // @param[out] new_tensor_table A new Tensor Table to be populated in this function. + Status WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table); + + // Private function that create the final column name to index mapping and + // get indices of the columns this mapop does not use. + // @param col_name_id_map The column name to index mapping obtained from child operator + void CreateFinalColMap(std::unordered_map *col_name_id_map); + + // Validating if each of the input_columns exists in the DataBuffer. + // @param - the column map to check + // @return - status return code + Status ValidateInColumns(const std::unordered_map &col_name_id_map); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + // Private function for initializing private variables such as in_columns_, out_columns_. + // @return - Status + Status InitPrivateVariable(std::unordered_map *col_name_id_map); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_MAP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc new file mode 100644 index 0000000000..abb827aea8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/parallel_op.h" + +#include +#include +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +// Constructor +ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler) + : DatasetOp(op_connector_size, sampler), + num_workers_(num_workers), + num_producers_(num_workers), + worker_connector_size_(1), + worker_connector_(nullptr) {} + +// Creates the internal worker connector for the parallel op if the derived class wants to use it +Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { + if (worker_connector_size == 0) { + RETURN_STATUS_UNEXPECTED("Worker connector size 0 is invalid."); + } + num_producers_ = 1; + worker_connector_size_ = worker_connector_size; + // Instantiate the worker connector. This is the internal connector, not the operators + // output connector. It has single master consuming from it (num producers is 1), and the number + // of workers is the defined count from the op. + worker_connector_ = std::make_unique(num_workers_, num_producers_, worker_connector_size); + + return Status::OK(); +} + +// A print method typically used for debugging +void ParallelOp::Print(std::ostream &out, bool show_all) const { + // Summary 1-liner print + if (!show_all) { + out << " [workers: " << num_workers_ << "]"; + // Call super class printer + DatasetOp::Print(out, show_all); + } else { + // Detailed print + DatasetOp::Print(out, show_all); + out << "\nNum workers: " << num_workers_; + } +} + +// Override base class reset to provide reset actions specific to the ParallelOp class. +Status ParallelOp::Reset() { + RETURN_IF_NOT_OK(DatasetOp::Reset()); // Perform any super class reset work + + // ParallelOp is abstract, but we do own the connector between workers and master + // (if the parallel op is configured for this). Reset that connector here. + if (worker_connector_) { + worker_connector_->Reset(); + } + + return Status::OK(); +} + +// Register the internal worker connectors +Status ParallelOp::RegisterWorkerConnectors() { + if (worker_connector_) { + return (worker_connector_->Register(tree_->AllTasks())); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h new file mode 100644 index 0000000000..da54ce1331 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h @@ -0,0 +1,126 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ +#define DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ + +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// global const in our namespace +constexpr int32_t kEndOfActions = -1; + +// Forward declares +class DataBuffer; + +class DbConnector; + +// A ParallelOp provides a multi-threaded DatasetOp +class ParallelOp : public DatasetOp { + public: + // Constructor + // @param num_workers + // @param op_connector_size - size of the output connector for this operator + // @param sampler - The sampler for the op + ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler = nullptr); + + // Destructor + ~ParallelOp() = default; + + // Creates the internal worker connector for the parallel op if the derived class wants to use it. + // @notes This changes the number of producers of this op to 1, since it establishes a master/worker + // relationship within the op, making all production flow through a single master. + // @return Status - The error return code + Status CreateWorkerConnector(int32_t worker_connector_size); + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param pO - reference to the ParallelOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const ParallelOp &po) { + po.Print(out, false); + return out; + } + + // During tree prepare phase, operators may have specific pre-operations to perform depending on + // their role. + // @notes Derived versions of this function should always call it's superclass version first + // before providing their own implementations. + // @return Status - The error return code + Status PrepareNodePreAction() override { + // Run common code from super class before adding ParallelOp specific logic + return (DatasetOp::PrepareNodePreAction()); + } + + // During tree prepare phase, operators may have specific post-operations to perform depending on + // their role. + // @notes Derived versions of this function should always call it's superclass version first + // before providing their own implementations. + // @return Status - The error return code + Status PrepareNodePostAction() override { + // Run common code from super class before adding ParallelOp specific logic + return (DatasetOp::PrepareNodePostAction()); + } + + // Override base class reset to provide reset actions specific to the ParallelOp class. + // @return Status - The error code return + Status Reset() override; + + // Getter + // @return the number of workers + int32_t num_workers() const override { return num_workers_; } + + // Getter + // @return the number of threads consuming from the previous Connector + int32_t num_consumers() const override { return num_workers_; } + + // Getter + // @return the number of producers pushing to the output Connector + // @notes The number of producers is commonly the same as number of workers, except in the case + // when a worker connector is set up. In that case, there are n workers, and a single master + // such that only 1 thread is a producer rather than the n workers. + // @return the number of producers + int32_t num_producers() const override { return num_producers_; } + + // Register the internal worker connectors. + // @return Status + Status RegisterWorkerConnectors() override; + + protected: + // Interface for derived classes to implement. All derived classes must provide the entry + // function with the main execution loop for worker threads. + // @return Status - The error code return + virtual Status WorkerEntry(int32_t workerId) = 0; + + int32_t num_workers_; // The number of worker threads + int32_t num_producers_; // The number of threads pushing to the out_connector_ + int32_t worker_connector_size_; + std::unique_ptr worker_connector_; // The internal connector for worker threads +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc new file mode 100644 index 0000000000..fff5ba19e7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include +#include + +namespace mindspore { +namespace dataset { +// Constructor +PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr sampler) + : DatasetOp(op_connector_size, sampler) {} + +// A print method typically used for debugging +void PipelineOp::Print(std::ostream &out, bool show_all) const { + // Summary 1-liner print + if (!show_all) { + out << " [workers: "; + if (this->inlined()) { + out << "0 (inlined)]"; + } else { + out << "1]"; // Pipeline ops only have 1 worker + } + // Call super class printer + DatasetOp::Print(out, show_all); + } else { + // Detailed print + DatasetOp::Print(out, show_all); + out << "\nNum workers: "; + if (this->inlined()) { + out << "0 (inlined)"; + } else { + out << "1"; // Pipeline ops only have 1 worker + } + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h new file mode 100644 index 0000000000..0538349f48 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h @@ -0,0 +1,98 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ + +#include +#include +#include "minddata/dataset/engine/datasetops/dataset_op.h" + +namespace mindspore { +namespace dataset { +// forward declare +class ExecutionTree; + +class DataBuffer; + +class PipelineOp : public DatasetOp { + public: + // Constructor + // @param op_connector_size - size of the output connector + // @return Builder setter method returns reference to the builder. + // @param sampler - The sampler for the op + explicit PipelineOp(int32_t op_connector_size, std::shared_ptr sampler = nullptr); + + // Destructor + ~PipelineOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param po - reference to the PipelineOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const PipelineOp &po) { + po.Print(out, false); + return out; + } + + // Getter + // @return The number of workers inside this op. Pipeline ops only have a single worker. + int32_t num_workers() const override { return 1; } + + // Getter + // @return the number of threads consuming from the previous Connector + int32_t num_consumers() const override { return 1; } + + // Getter + // @return The number of threads that push data to the output connector + int32_t num_producers() const override { return 1; } + + // During tree prepare phase, operators may have specific pre-operations to perform depending on + // their role. + // @notes Derived versions of this function should always call it's superclass version first + // before providing their own implementations. + Status PrepareNodePreAction() override { + // Run common code from super class before adding PipelineOp specific logic + return (DatasetOp::PrepareNodePreAction()); + } + + // During tree prepare phase, operators may have specific post-operations to perform depending on + // their role. + // @notes Derived versions of this function should always call it's superclass version first + // before providing their own implementations. + Status PrepareNodePostAction() override { + // Run common code from super class before adding PipelineOp specific logic + return (DatasetOp::PrepareNodePostAction()); + } + + protected: + // ******************************************************************************* + // I'm predicting there will be common arguments or functionality for pipeline ops, + // just not sure yet what those are. perhaps this intermediate class between + // DatasetOp and the actual ops is not needed at all? + // For example, if there's no common code for all of the non-parallel ops, then + // they can just inherit from DatasetOp directly and we can put this class into the + // trash. +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc new file mode 100644 index 0000000000..e232a64164 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc @@ -0,0 +1,159 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/project_op.h" +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +ProjectOp::Builder::Builder(const std::vector &columns_to_project) + : builder_columns_to_project_(columns_to_project) {} + +Status ProjectOp::Builder::SanityCheck() const { + if (builder_columns_to_project_.empty()) { + std::string err_msg("Columns to project is empty."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +Status ProjectOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_columns_to_project_); + return Status::OK(); +} + +ProjectOp::ProjectOp(const std::vector &columns_to_project) + : PipelineOp(0), columns_to_project_(columns_to_project) {} + +void ProjectOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nColumns that are projected:"; + for (size_t i = 0; i < columns_to_project_.size(); i++) { + out << "\n" << columns_to_project_[i]; + } + out << "\n\n"; + } +} + +// Gets a buffer from the child operator and projects the buffer. +Status ProjectOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id, retry_if_eoe)); + if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) { + RETURN_IF_NOT_OK(Project(p_buffer)); + } + return Status::OK(); +} + +Status ProjectOp::Project(std::unique_ptr *data_buffer) { + std::unique_ptr new_tensor_table = std::make_unique(); + while ((*data_buffer)->NumRows() > 0) { + TensorRow current_row; + RETURN_IF_NOT_OK((*data_buffer)->PopRow(¤t_row)); + TensorRow new_row; + (void)std::transform(projected_column_indices_.begin(), projected_column_indices_.end(), + std::back_inserter(new_row), [¤t_row](uint32_t x) { return current_row[x]; }); + new_tensor_table->push_back(new_row); + } + (*data_buffer)->set_tensor_table(std::move(new_tensor_table)); + return Status::OK(); +} + +// Class functor operator () override. +// Most dataset ops operate by launching a thread (see ExecutionTree). +// However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the +// functor since this op runs inlined inside another operator. The function is overloaded to +// ensure that it is not called by mistake (it will generate an error). +Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); } + +int32_t ProjectOp::num_consumers() const { + if (parent_.empty()) { + MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1."; + return 1; + } else if (parent_[0] == nullptr) { + MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0."; + return 0; + } else { + return parent_[0]->num_consumers(); + } +} + +int32_t ProjectOp::num_producers() const { + if (child_.empty() || child_[0] == nullptr) { + MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0."; + return 0; + } else { + return child_[0]->num_producers(); + } +} + +Status ProjectOp::EoeReceived(int32_t worker_id) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); } + +// Visitor accept method for NodePass +Status ProjectOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +// Compute the column map and save it into our own column name map +// We cannot use the super class ComputeColMap here because we're making a modification of the +// map from the child map. +Status ProjectOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + std::unordered_map child_column_name_mapping = child_[0]->column_name_id_map(); + for (size_t i = 0; i < columns_to_project_.size(); i++) { + std::string ¤t_column = columns_to_project_[i]; + if (child_column_name_mapping.find(current_column) == child_column_name_mapping.end()) { + std::string err_msg = "ProjectOp: column " + current_column + " does not exist in child operator."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + // Setup the new column name mapping for ourself (base class field) + column_name_id_map_[current_column] = i; + projected_column_indices_.push_back(child_column_name_mapping[current_column]); + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h new file mode 100644 index 0000000000..c2f14d34b7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h @@ -0,0 +1,127 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ +#define DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/pipeline_op.h" + +namespace mindspore { +namespace dataset { +class ProjectOp : public PipelineOp { + public: + // The nested builder class inside of the ProjectOp is used to help manage all of the arguments + // for constructing it. This repeat op is very simple though, so this builder is really just + // provided for a consistent look and feel for creators of Dataset operators overall. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @param columns_to_project - + // @return This is a constructor. + explicit Builder(const std::vector &columns_to_project); + + // Builder destructor. + ~Builder() = default; + + // The builder "build" method creates the final object. + // @return shared_ptr to the new ProjectOp object. + Status Build(std::shared_ptr *); + + private: + std::vector builder_columns_to_project_; + Status SanityCheck() const; + }; + + // Constructor of the ProjectOp. + // @param columnsToProject - + explicit ProjectOp(const std::vector &columns_to_project); + + // Destructor. + ~ProjectOp() = default; + + // A print method typically used for debugging. + // @param out - The output stream to write output to. + // @param show_all - A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload. + // @notes This allows you to write the debug print info using stream operators. + // @param out - reference to the output stream being overloaded. + // @param project_op - reference to the ProjectOp to display. + // @return - the output stream must be returned. + friend std::ostream &operator<<(std::ostream &out, const ProjectOp &project_op) { + project_op.Print(out, false); + return out; + } + + // Class functor operator () override. + // Most dataset ops operate by launching a thread (see ExecutionTree). + // However, the ProjectOp is defined as a inlined operator, so it is invalid to launch the + // functor since this op runs inlined inside another operator. The function is overloaded to + // ensure that it is not called by mistake (it will generate an error). + // @return Status - The error code returned. + Status operator()() override; + + // Gets a buffer from the child node and projects that buffer. The caller is typically our parent node. + // @param p_buffer - output pointer to the projected buffer. + // @param worker_id - The worker id + Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) override; + + // Base-class override. Return the number of workers in the first parent. + // @param workerId - The worker id + int32_t num_consumers() const override; + + // Base-class override. Return the number of producers in the first child. + // @param workerId - The worker id + int32_t num_producers() const override; + + // Base-class override for special eoe handler. + // Inline operators must override this because there is no connector to push eoe onto. + // @return Status - The error code returned. + Status EoeReceived(int32_t worker_id) override; + + // Base-class override for special eof handler. + // Inline operators must override this because there is no connector to push eof onto. + // @return Status - The error code returned. + Status EofReceived(int32_t worker_id) override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "ProjectOp"; } + + private: + std::vector columns_to_project_; + std::vector projected_column_indices_; + + Status Project(std::unique_ptr *data_buffer); + + // Computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc new file mode 100644 index 0000000000..d12660e6f9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc @@ -0,0 +1,182 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/rename_op.h" +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +// builds +RenameOp::Builder::Builder() { + // Some arguments to the RenameOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the RenameOp by + // using the various builder set methods. + + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status RenameOp::Builder::SanityCheck() const { return Status::OK(); } + +// build method for RenameOp +Status RenameOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_in_columns_, builder_out_columns_, builder_op_connector_size_); + return Status::OK(); +} + +// constructor +RenameOp::RenameOp(const std::vector &in_col_names, const std::vector &out_col_names, + int32_t op_connector_size) + : PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {} + +// destructor +RenameOp::~RenameOp() {} + +// main entry point for rename +Status RenameOp::operator()() { + TaskManager::FindMe()->Post(); + std::unique_ptr curr_buffer; + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + if (curr_buffer->buffer_flags() != DataBuffer::kDeBFlagNone) { + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + std::string err_msg = "Rename first buffer got was control signal"; + // if 1st eoe or eof, pass it on then return + RETURN_STATUS_UNEXPECTED(err_msg); + } + + while (curr_buffer->eof() == false) { + while (curr_buffer->eoe() == false) { + // push the renamed input buffer + MS_LOG(DEBUG) << "Rename operator pushing next buffer."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } // end of while eoe loop + + // we got eoe, now try again until we get eof + MS_LOG(DEBUG) << "Rename operator EOE Received."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + MS_LOG(DEBUG) << "Rename operator fetching buffer after EOE."; + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } // end of while eof loop + + MS_LOG(DEBUG) << "Rename opeerator EOF Received."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Rename core functionality to compute the new column name id map. +// We need to overwrite the super class ComputeColMap here because we're making a modification of the +// map from the child map. +Status RenameOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + column_name_id_map_ = child_[0]->column_name_id_map(); + // iterate over my index in input vector, find the corresponding position + std::unordered_map new_col_name_id_map = {}; + // parameter for input check + size_t found = 0; + + // iterate over all the pairs and if there is a name match with rename, rename the column and add it to new map + // by doing it this way we recreate a new ColNameIdMap and allow for switching + for (const auto &pair : column_name_id_map_) { + std::string name = pair.first; + int32_t id = pair.second; + // find name + std::vector::iterator it; + it = std::find(in_columns_.begin(), in_columns_.end(), name); + // for c input checks here we have to count the number of times we find the stuff in in_columns_ + // because we iterate over the mInputList n times + if (it != in_columns_.end()) { + // found + found += 1; + int index = std::distance(in_columns_.begin(), it); + MS_LOG(DEBUG) << "Rename operator index found " << index << " value " << id << "."; + + new_col_name_id_map[out_columns_[index]] = id; + } else { + // not found + MS_LOG(DEBUG) << "Rename operator index not found: " << id << " is the column id."; + new_col_name_id_map[name] = id; + } + } + // only checks number of renamed columns have been found, this input check doesn't check everything + if (found != in_columns_.size()) { + MS_LOG(DEBUG) << "Rename operator column names found: " << found << " out of " << in_columns_.size() << "."; + std::string err_msg = "Renamed column doesn't exist in dataset"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Now, overwrite our column map with the new renamed columns/id's + column_name_id_map_ = new_col_name_id_map; + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +// prints rename +void RenameOp::Print(std::ostream &out, // In: The output stream to print to + bool show_all) const { // In: T/F if it should print everything + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nIn columns:"; + for (size_t i = 0; i < in_columns_.size(); ++i) { + out << "\n " << in_columns_[i]; + } + for (size_t i = 0; i < out_columns_.size(); ++i) { + out << "\n " << out_columns_[i]; + } + out << "\n\n"; + } +} + +Status RenameOp::EofReceived(int32_t) { + MS_LOG(DEBUG) << "Rename operator EOF received, do nothing now."; + return Status::OK(); +} + +Status RenameOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +// Visitor accept method for NodePass +Status RenameOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h new file mode 100644 index 0000000000..d846bb1b40 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h @@ -0,0 +1,138 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ +#define DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// forward declare +class DataBuffer; + +class RenameOp : public PipelineOp { + public: + // The nested builder class inside of the RenameOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetInColNames(const std::vector &in_col_names) { + builder_in_columns_ = in_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOutColNames(const std::vector &out_col_names) { + builder_out_columns_ = out_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // The builder "build" method creates the ZipOp dataset Operator. + // @return shared_ptr to the new RenameOp object + Status Build(std::shared_ptr *); + + private: + std::vector builder_in_columns_; + std::vector builder_out_columns_; + int32_t builder_op_connector_size_; + + Status SanityCheck() const; + }; + + // Constructor for RenameOp + // @param in_col_names names of columns to rename + // @param out_col_names names of columns after rename + // @param op_connector_size connector size + RenameOp(const std::vector &in_col_names, // In: Col names to consume + const std::vector &out_col_names, // In: Col names to produce + int32_t op_connector_size); + + // Destructor + ~RenameOp(); + + Status EofReceived(int32_t) override; + + Status EoeReceived(int32_t) override; + + // Print function for Rename + // @param out output stream to print to + // @param show_all if it should print everything + void Print(std::ostream &out, bool show_all) const override; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const RenameOp &ro) { + ro.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "RenameOp"; } + + protected: + // Rename core functionality + // Computing the assignment of the new column name map. + // @return - Status + Status ComputeColMap() override; + + // Variable to store the input column names + std::vector in_columns_; + + // Variable to store the output column names + std::vector out_columns_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc new file mode 100644 index 0000000000..6d3dc91ed3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc @@ -0,0 +1,199 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/opt/pass.h" + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {} + +Status RepeatOp::Builder::SanityCheck() const { + if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) { + std::string err_msg("Repeat count must be > 0 or -1."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +// The builder "build" method creates the final object. +Status RepeatOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_max_repeats_); + return Status::OK(); +} + +// Constructor of the RepeatOp. +RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {} + +// Destructor +RepeatOp::~RepeatOp() {} + +// A print method typically used for debugging +void RepeatOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [repeats: " << max_repeats_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_ + << "\nLeaf Nodes in execution path:"; + if (!eoe_ops_.empty()) { + for (size_t i = 0; i < eoe_ops_.size(); i++) { + out << "\n Operator: " << eoe_ops_[i]->id(); + } + } else { + out << " None."; + } + out << "\n\n"; + } +} + +// This function returns the buffer that is at the top of our output connector. The caller is +// typically our parent node, when the parent is asking us to provide the next buffer of data. +// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get +// a buffer from our child. +// This function sets the `retryIfEoe` flag when popping from the child connector. This way, +// this function will retry to pop the connector again and will get the non-EOE buffer if any. +Status RepeatOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { + if (child_.empty()) { + RETURN_STATUS_UNEXPECTED("RepeatOp can't be the leaf node."); + } + + std::unique_ptr buf; + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); + // Loop until non EOE is received + while (buf->eoe()) { + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + if (state_ == OpState::kDeOpIdle) { + *p_buffer = std::move(buf); + return Status::OK(); + } + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); + } + // Check if the last buf is next eof + if (buf->eof()) { + RETURN_IF_NOT_OK(EofReceived(worker_id)); + } + *p_buffer = std::move(buf); + return Status::OK(); +} + +// Base-class override for handling cases when an eoe is received. +Status RepeatOp::EoeReceived(int32_t worker_id) { + repeat_count_++; + MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ + << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; + bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); + bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); + // If we've reached the requested repeat count, then flag the eoe nodes + // to tell them they've got one more epoch to perform. When they reach the end + // of the last epoch, they quit rather than loop again. This happens in two cases: + // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or, + // 2- We are not repeated + if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) { + for (auto &eoe_op : eoe_ops_) { + eoe_op->set_control_flag(kDeOpLastRepeat); + } + } + if (repeat_count_ == max_repeats_) { + repeat_count_ = 0; + state_ = OpState::kDeOpIdle; + return Status::OK(); + } + + // Invoke a reset against the eoe nodes only. + for (auto &eoe_op : eoe_ops_) { + RETURN_IF_NOT_OK(eoe_op->Reset()); + } + + return Status::OK(); +} + +// Class functor operator () override. +// Most dataset ops operate by launching a thread (see ExecutionTree). +// However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the +// functor since this op runs inlined inside another operator. The function is overloaded to +// ensure that it is not called by mistake (it will generate an error). +Status RepeatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RepeatOp is an inlined operator."); } + +// Base-class override for handling cases when an eof is received. +Status RepeatOp::EofReceived(int32_t worker_id) { + MS_LOG(DEBUG) << "Repeat operator EOF received, do nothing now."; + return Status::OK(); +} + +int32_t RepeatOp::num_consumers() const { + if (parent_.empty()) { + MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's root and returning 1."; + return 1; + } else if (parent_[0] == nullptr) { + MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0."; + return 0; + } else { + return parent_[0]->num_consumers(); + } +} + +// Drive reset actions if needed +Status RepeatOp::Reset() { + // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op. + // In that case, we now have to bounce the reset down to our own eoe ops. + MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset."; + for (auto &eoe_op : eoe_ops_) { + RETURN_IF_NOT_OK(eoe_op->Reset()); + } + state_ = OpState::kDeOpRunning; + return Status::OK(); +} + +int32_t RepeatOp::num_producers() const { + if (child_.empty() || child_[0] == nullptr) { + MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0."; + return 0; + } else { + return child_[0]->num_producers(); + } +} + +// Pre-Visitor accept method for NodePass +Status RepeatOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} + +// Visitor accept method for NodePass +Status RepeatOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h new file mode 100644 index 0000000000..f5259de30e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h @@ -0,0 +1,146 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ +#define DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/pipeline_op.h" + +namespace mindspore { +namespace dataset { +class RepeatOp : public PipelineOp { + public: + static constexpr int32_t kInfiniteRepeat = -1; + + // The nested builder class inside of the RepeatOp is used to help manage all of the arguments + // for constructing it. This repeat op is very simple though, so this builder is really just + // provided for a consistent look and feel for creators of Dataset operators overall. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @param count - The number of repeats to do + // @return This is a constructor. + explicit Builder(int32_t count); + + // Default destructor + ~Builder() = default; + + // The builder "build" method creates the final object. + // @return shared_ptr to the new RepeatOp object + Status Build(std::shared_ptr *); + + private: + int32_t build_max_repeats_; + + Status SanityCheck() const; + }; + + // Constructor of the RepeatOp. + // @note The builder class should be used to call it + // @param count - The number of repeats to do + explicit RepeatOp(int32_t count); + + // Destructor + ~RepeatOp(); + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param ro - reference to the RepeatOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const RepeatOp &ro) { + ro.Print(out, false); + return out; + } + + // Class functor operator () override. + // Most dataset ops operate by launching a thread (see ExecutionTree). + // However, the RepeatOp is defined as a inlined operator, so it is invalid to launch the + // functor since this op runs inlined inside another operator. The function is overloaded to + // ensure that it is not called by mistake (it will generate an error). + // @return Status - The error code return + Status operator()() override; + + // This function returns the buffer that is at the top of our output connector. The caller is + // typically our parent node, when the parent is asking us to provide the next buffer of data. + // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get + // a buffer from our child. + // @note This function sets the `retryIfEoe` flag when popping from the child connector. This way, + // this function will retry to pop the connector again and will get the non-EOE buffer if any. + // @param p_buffer - output pointer to the buffer that it will fetch. + // @param worker_id - The worker id + // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. + // @return Status - The error code return + Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) override; + + // Base-class override for handling cases when an eoe is received. + // @param worker_id - The worker id + Status EoeReceived(int32_t worker_id) override; + + // Base-class override for handling cases when an eof is received. + // @param worker_id - The worker id + Status EofReceived(int32_t worker_id) override; + + /// \brief reset Op + /// \@return Status - The error code return + Status Reset() override; + + // Base-class override. Return the number of workers in the first parent. + // @param workerId - The worker id + int32_t num_consumers() const override; + + // Base-class override. Return the number of producers in the first child. + // @param workerId - The worker id + int32_t num_producers() const override; + + /// \brief Base-class override for NodePass pre-visit acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status PreAccept(NodePass *p, bool *modified) override; + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "RepeatOp"; } + + /// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes + /// \param[in] eoe_op The input leaf/eoe operator to add to the list + void AddToEoeList(std::shared_ptr eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); } + + private: + int32_t max_repeats_; // The number of repeats that the user requested + int32_t repeat_count_; // A counter for the current number of executed repeats + std::vector> eoe_ops_; // List of operators that can generate EOE underneath this repeat. +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc new file mode 100644 index 0000000000..0eb5f29eaf --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc @@ -0,0 +1,304 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(_WIN32) || defined(_WIN64) +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +constexpr int32_t ShuffleOp::kShuffleStateInit; +constexpr int32_t ShuffleOp::kShuffleStateActive; +constexpr int32_t ShuffleOp::kShuffleStateDrain; + +// Builder constructor. Creates the builder object. +ShuffleOp::Builder::Builder() : build_shuffle_size_(0), build_reshuffle_each_epoch_(true) { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_op_connector_size_ = cfg->op_connector_size(); + build_rows_per_buffer_ = cfg->rows_per_buffer(); + build_shuffle_seed_ = GetSeed(); +} + +Status ShuffleOp::Builder::SanityCheck() const { + if (build_shuffle_size_ < 2) { + RETURN_STATUS_UNEXPECTED("Shuffle buffer size must be greater than 1."); + } + return Status::OK(); +} + +// The builder "build" method creates the final object. +Status ShuffleOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_shuffle_size_, build_shuffle_seed_, build_op_connector_size_, + build_reshuffle_each_epoch_, build_rows_per_buffer_); + return Status::OK(); +} + +// Constructor of the ShuffleOp +ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch, + int32_t rows_per_buffer) + : PipelineOp(op_connector_size), + shuffle_size_(shuffle_size), + shuffle_seed_(shuffle_seed), + reshuffle_each_epoch_(reset_every_epoch), + rng_(shuffle_seed), + buffer_counter_(0), + rows_per_buffer_(rows_per_buffer), + shuffle_buffer_(std::make_unique()), + shuffle_last_row_idx_(0), + shuffle_buffer_state_(kShuffleStateInit) {} + +// Private function to re-init the shuffle op for another epoch. Shuffle op calls this by +// itself rather than waiting for the reset driven from operators above it in the pipeline. +Status ShuffleOp::SelfReset() { + MS_LOG(DEBUG) << "Shuffle operator performing a self-reset."; + // If reshuffle_each_epoch is false, then we always use the same seed for every + // epoch. + // If reshuffle_each_epoch is true, then the first epoch uses the given seed, + // and all subsequent epochs will then keep on using the rng_ without resetting it + if (!reshuffle_each_epoch_) { + rng_ = std::mt19937_64(shuffle_seed_); + } + + shuffle_buffer_ = std::make_unique(); + buffer_counter_ = 0; + shuffle_last_row_idx_ = 0; + shuffle_buffer_state_ = kShuffleStateInit; + return Status::OK(); +} + +// A print method typically used for debugging +void ShuffleOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [shuffle size: " << shuffle_size_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nShuffle size: " << shuffle_size_ << "\nRows per buffer: " << rows_per_buffer_ + << "\nShuffle buffer state: " << shuffle_buffer_state_ << "\nShuffle seed: " << shuffle_seed_ << "\n\n"; + } +} + +// Private function to add a new row to the shuffle buffer. +Status ShuffleOp::AddRowToShuffleBuffer(TensorRow new_shuffle_row) { + // If the last slot of our shuffle buffer was not the full size of the shuffle buffer then we are + // filling it during the initial fill codepath and thus growing it's size. In that case, we push + // back the new row to grow our shuffle buffer size by 1. + // If we are already at the full size, then we overwrite the last slot with our row (and the last + // slot better be empty because it should already have been swapped out during the random row + // selection that was done previously!) + if (shuffle_last_row_idx_ < (shuffle_size_ - 1)) { + shuffle_buffer_->push_back(std::move(new_shuffle_row)); + shuffle_last_row_idx_ = (shuffle_buffer_->size()) - 1; + } else { + if (!(*shuffle_buffer_)[shuffle_last_row_idx_].empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Last row of shuffle buffer should not be occupied!"); + } + (*shuffle_buffer_)[shuffle_last_row_idx_] = std::move(new_shuffle_row); + } + return Status::OK(); +} + +// Class functor operator () override. +// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will +// provide the master loop that drives the logic for performing the work +Status ShuffleOp::operator()() { + std::unique_ptr new_buffer_table; // A tensor table to be used for output. + + // Synchronize with TaskManager once the thread is launched. + TaskManager::FindMe()->Post(); + + // Shuffle op does not have workers, and only consumes from child 0. + // Create the child iterator to fetch our data from. + int32_t worker_id = 0; + int32_t child_idx = 0; + child_iterator_ = std::make_unique(this, worker_id, child_idx); + + // Main operator loop + while (true) { + // Do an initial populate of the shuffle buffer + RETURN_IF_NOT_OK(InitShuffleBuffer()); + + // This is our main loop exit condition, when the iterator has no more data completely. + if (child_iterator_->eof_handled()) { + break; + } + + // Next, enter into the main execution loop of the shuffle op. + // When the tail index position of our shuffle buffer goes negative it means that we've + // fully drained the data from the shuffle buffer and we're done. + while (shuffle_last_row_idx_ >= 0) { + // Step 1) + // Create an output tensor table if one is not created yet. + if (!new_buffer_table) { + new_buffer_table = std::make_unique(); + } + + // Step 2) + // Randomly select a slot from our shuffle buffer and copy that row into the output + // tensor table. We remove the data from the shuffle buffer, leaving that slot + // in the table as an empty vector + int64_t random_slot = rng_() % (shuffle_last_row_idx_ + 1); + new_buffer_table->push_back(std::move((*shuffle_buffer_)[random_slot])); + + // Step 3) + // If the output tensor table is at the requested size, then create a buffer for it + // and send this buffer on it's way up the pipeline. Special case is if this is the + // last row then we also send it. + if (new_buffer_table->size() == rows_per_buffer_ || shuffle_last_row_idx_ == 0) { + auto new_buffer = std::make_unique(buffer_counter_, DataBuffer::kDeBFlagNone); + new_buffer->set_tensor_table(std::move(new_buffer_table)); + buffer_counter_++; + MS_LOG(DEBUG) << "Shuffle operator sending a buffer to output."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(new_buffer))); + } + + // Step 4) + // Take the last row from shuffle buffer, and swap it into the row position that was + // just vacated. This makes the shuffle buffer contiguous, with an empty slot at the + // tail of the shuffle buffer. + if (random_slot != shuffle_last_row_idx_) { + (*shuffle_buffer_)[random_slot] = std::move((*shuffle_buffer_)[shuffle_last_row_idx_]); + } + + // Step 5) + // Refill the last slot of the shuffle buffer with the next row from input if we are in the + // active state. + // If we are in the draining state, we do not need to fetch another row to replace the one we + // just drained. + if (shuffle_buffer_state_ == kShuffleStateActive) { + TensorRow new_row; + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + + if (!new_row.empty()) { + RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); + } else { + shuffle_buffer_state_ = kShuffleStateDrain; + } + } + + // If we are draining, reposition (decrement) our tail index in the shuffle buffer since we + // just drained a row from it. + if (shuffle_buffer_state_ == kShuffleStateDrain) { + shuffle_last_row_idx_--; + } + } + + // Since we overloaded eoeReceived function, we are responsible to flow the EOE up the + // pipepline manually now that we are done draining the shuffle buffer + MS_LOG(DEBUG) << "Shuffle operator sending EOE."; + auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + // Do not wait for any reset to be flown down from operators above us. + // Instead, manually update ourselves and then go reloop to start fetching from child operator + // right away. Any Reset() from the parent will still perform common reset actions. + RETURN_IF_NOT_OK(this->SelfReset()); + } + return Status::OK(); +} + +// Private function populate the shuffle buffer initially by fetching from the child output +// connector until the shuffle buffer is full (or there is no more data coming). +Status ShuffleOp::InitShuffleBuffer() { + MS_LOG(DEBUG) << "Shuffle operator initializing the shuffle buffer."; + + // The first phase of this operator is to read incoming buffers and then drain those + // rows from the buffers, putting them into our own local table of tensors (the shuffle + // buffer). + // This shuffle buffer initialization phase stops when we've either filled up the + // shuffle buffer to it's max size, or the dataset below us is not providing any more + // rows. + if (shuffle_buffer_state_ != kShuffleStateInit) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Invalid shuffle buffer state (SHUFFLE_STATE_INIT expected)"); + } + + // Before we drop into the fetching loop, call the fetch once for the first time + // to fill the first row and grab the first buffer. + TensorRow new_row; + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + + if (child_iterator_->eof_handled()) { + MS_LOG(DEBUG) << "Shuffle operator init picked up EOF. No more epochs."; + return Status::OK(); + } + + if (new_row.empty()) { + RETURN_STATUS_UNEXPECTED("Unable to fetch a single row for shuffle buffer."); + } + + // Now fill the rest of the shuffle buffer until we are unable to get the next row or we reached + // the desired shuffle buffer size. + while (!new_row.empty() && shuffle_buffer_->size() < static_cast(shuffle_size_ - 1)) { + // Add the previously fetched row + RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); + + // Fetch the next row + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } + + // If we quit the loop due to being at the shuffle size, still need to add the last row here. + if (!new_row.empty()) { + RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); + shuffle_buffer_state_ = kShuffleStateActive; // Transition to the active state + } else { + // If init phase doesn't have more rows, then skip the active state and jump straight to the + // shuffle buffer draining state + shuffle_buffer_state_ = kShuffleStateDrain; + } + + MS_LOG(DEBUG) << "Shuffle operator finished intializing the shuffle buffer."; + return Status::OK(); +} + +Status ShuffleOp::EoeReceived(int32_t worker_id) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +// Visitor accept method for NodePass +Status ShuffleOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h new file mode 100644 index 0000000000..86bea7cc77 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h @@ -0,0 +1,204 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class ExecutionTree; + +class DbConnector; + +class DataBuffer; + +class ShuffleOp : public PipelineOp { + // Shuffle buffer state flags + // + // Shuffle buffer is in a state of being initialized + static constexpr int32_t kShuffleStateInit = 0; + + // Shuffle buffer is in a state of being actively drained from, but refilling as well + static constexpr int32_t kShuffleStateActive = 1; + + // Shuffle buffer is in a state of being drained + static constexpr int32_t kShuffleStateDrain = 2; + + public: + // The nested builder class inside of the ShuffleOp is used to help manage all of the arguments + // for constructing it. The shuffle op is fairly simple though, but the builder provides a + // consistent look and feel for creators of Dataset operators overall. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetShuffleSize(int32_t shuffle_size) { + build_shuffle_size_ = shuffle_size; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetShuffleSeed(uint32_t shuffle_seed) { + build_shuffle_seed_ = shuffle_seed; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + build_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetReshuffleEachEpoch(bool reshuffle_each_epoch) { + build_reshuffle_each_epoch_ = reshuffle_each_epoch; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + build_op_connector_size_ = op_connector_size; + return *this; + } + + // The builder "build" method creates the final object. + // @return shared_ptr to the new ShuffleOp object + Status Build(std::shared_ptr *); + + private: + // The builder saves all ShuffleOp construction arguments internally. + // The following are the arguments. + int32_t build_shuffle_size_; + uint32_t build_shuffle_seed_; + int32_t build_rows_per_buffer_; + bool build_reshuffle_each_epoch_; + int32_t build_op_connector_size_; + + Status SanityCheck() const; + }; + + // Constructor of the ShuffleOp + // @note The builder class should be used to call it + // @param shuffle_size - The size for the shuffle buffer + // @param shuffle_seed - The seed to use for random number generation + // @param op_connector_size - The output connector queue size + // @param rows_per_buffer - The requested number of rows per buffer + ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch, + int32_t rows_per_buffer); + + // Destructor + ~ShuffleOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param so - reference to the ShuffleOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const ShuffleOp &so) { + so.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Base-class override for special eoe handler. + // ShuffleOp must override this because it shall not perform default handling of eoe. Instead + // the ShuffleOp needs to manage actions related to the end of the epoch itself. + // @return Status - The error code return + Status EoeReceived(int32_t worker_id) override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "ShuffleOp"; } + + private: + // Private function to add a new row to the shuffle buffer. + // @return Status - The error code return + Status AddRowToShuffleBuffer(TensorRow new_shuffle_row); + + // Private function to populate the shuffle buffer initially by fetching from the child output + // connector until the shuffle buffer is full (or there is no more data coming). + // @return Status - The error code return + Status InitShuffleBuffer(); + + // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by + // itself rather than waiting for the reset driven from operators above it in the pipeline. + // @return Status - The error code return + Status SelfReset(); + + int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows) + uint32_t shuffle_seed_; + bool reshuffle_each_epoch_; + // rng_ is seeded initially with shuffle_seed_. mt19937 is used for its large period. + // specifically mt19937_64 is used to generate larger random numbers to reduce bias when + // modding to fit within our desired range. we dont use a distribution + // (ie uniform_int_distribution) because we will need to create up to |dataset| instances + // of the distribution object in the common case of a perfect shuffle + std::mt19937_64 rng_; + int32_t buffer_counter_; // For creating new buffer id's + int32_t rows_per_buffer_; // Number of rows to pack into output buffer + // A single (potentially large) buffer of tensor rows for performing shuffling. + std::unique_ptr shuffle_buffer_; + int32_t shuffle_last_row_idx_; // Internal tracking of the last slot of our shuffle buffer + int32_t shuffle_buffer_state_; // State tracking for the shuffle buffer phases of work + + std::unique_ptr child_iterator_; // An iterator for fetching. +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc new file mode 100644 index 0000000000..2fe8cbeaa6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/datasetops/skip_op.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status SkipOp::Builder::SanityCheck() const { + if (build_max_skips_ < 0) { + std::string err_msg("Skip count must be positive integer or 0."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +// The builder "build" method creates the final object. +Status SkipOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_max_skips_, builder_op_connector_size_); + return Status::OK(); +} + +// Constructor of the SkipOp. +SkipOp::SkipOp(int32_t count, int32_t op_connector_size) + : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {} + +// Destructor +SkipOp::~SkipOp() {} + +// A print method typically used for debugging +void SkipOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [skips: " << max_skips_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nSkip count: " << skip_count_ << "\nMax skips: " << max_skips_ << "\n\n"; + } +} + +// Base-class override for handling cases when an eoe is received. +Status SkipOp::EoeReceived(int32_t worker_id) { + skip_count_ = 0; + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +// main entry point for skip +Status SkipOp::operator()() { + TaskManager::FindMe()->Post(); + std::unique_ptr curr_buffer; + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + + while (curr_buffer->eof() == false) { + // Reset count + skip_count_ = 0; + while (curr_buffer->eoe() == false) { + // Drop first count rows + while (skip_count_ < max_skips_) { + if (curr_buffer->eoe() || curr_buffer->eof()) { + break; + } + // Consider the rows of buffer more than one + TensorRow drop_row; + int row_num = curr_buffer->NumRows(); + int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; + skip_count_ += drop_num; + for (int i = 0; i < drop_num; i++) { + RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row)); + } + if (curr_buffer->NumRows() == 0) { + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + } + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + // we got eoe, now try again until we got eof + MS_LOG(DEBUG) << "Skip operator EOE Received."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); + } + + MS_LOG(DEBUG) << "Skip operator EOF Received."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Base-class override for handling cases when an eof is received. +Status SkipOp::EofReceived(int32_t worker_id) { + MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; + return Status::OK(); +} + +// Visitor accept method for NodePass +Status SkipOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h new file mode 100644 index 0000000000..a717d0efa4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ + +#include +#include +#include +#include "minddata/dataset/engine/datasetops/pipeline_op.h" + +namespace mindspore { +namespace dataset { +class SkipOp : public PipelineOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @param count - The number of skip to do + // @return This is a constructor. + explicit Builder(int32_t count); + + // Default destructor + ~Builder() = default; + + // The builder "build" method creates the final object. + // @return shared_ptr to the new SkipOp object + Status Build(std::shared_ptr *); + + private: + int32_t build_max_skips_; + int32_t builder_op_connector_size_; + + Status SanityCheck() const; + }; + + // Constructor of the SkipOp. + // @note The builder class should be used to call it + // @param count - The number of skips to do + explicit SkipOp(int32_t count, int32_t op_connector_size); + + // Destructor + ~SkipOp(); + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Base-class override for handling cases when an eoe is received. + // @param worker_id - The worker id + Status EoeReceived(int32_t worker_id) override; + + // Base-class override for handling cases when an eof is received. + // @param worker_id - The worker id + Status EofReceived(int32_t worker_id) override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "SkipOp"; } + + private: + int32_t max_skips_; // The number of skips that the user requested + int32_t skip_count_; // A counter for the current number of executed skips +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt new file mode 100644 index 0000000000..389e3f5af6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt @@ -0,0 +1,32 @@ +add_subdirectory(sampler) +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + +set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES + io_block.cc + image_folder_op.cc + mnist_op.cc + coco_op.cc + cifar_op.cc + random_data_op.cc + celeba_op.cc + text_file_op.cc + clue_op.cc + ) + +set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES + ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES} + mindrecord_op.cc + tf_reader_op.cc + ) + +if (ENABLE_PYTHON) + set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES + ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES} + generator_op.cc + voc_op.cc + manifest_op.cc + ) +endif() + +add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}) \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc new file mode 100644 index 0000000000..9d7d5622a6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc @@ -0,0 +1,430 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" + +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { +CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status CelebAOp::Builder::Build(std::shared_ptr *op) { + MS_LOG(DEBUG) << "Celeba dataset directory is " << builder_dir_.c_str() << "."; + MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << "."; + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + // label is like this:0 1 0 0 1...... + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + *op = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, + builder_op_connector_size_, builder_decode_, builder_dataset_type_, + builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_)); + if (*op == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null"); + } + + return Status::OK(); +} + +Status CelebAOp::Builder::SanityCheck() { + Path dir(builder_dir_); + std::string err_msg; + err_msg += dir.IsDirectory() ? "" : "CelebA path is invalid or not set\n"; + err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is smaller than 1\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, + bool decode, const std::string &dataset_type, const std::set &exts, + std::unique_ptr schema, std::shared_ptr sampler) + : ParallelOp(num_workers, queue_size, std::move(sampler)), + rows_per_buffer_(rows_per_buffer), + folder_path_(dir), + decode_(decode), + extensions_(exts), + data_schema_(std::move(schema)), + num_rows_in_attr_file_(0), + dataset_type_(dataset_type) { + attr_info_queue_ = std::make_unique>>(queue_size); + io_block_queues_.Init(num_workers_, queue_size); +} + +Status CelebAOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "tree_ not set"); + } + + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + + RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this))); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(ParseImageAttrInfo()); + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + + return Status::OK(); +} + +Status CelebAOp::ParseAttrFile() { + TaskManager::FindMe()->Post(); + Path folder_path(folder_path_); + std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString()); + if (!attr_file.is_open()) { + return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "Celeba attr file does not exist"); + } + + const auto PushBackToQueue = [this](std::vector &vec, std::ifstream &attr_file, + std::ifstream &partition_file) { + Status s = attr_info_queue_->EmplaceBack(vec); + if (s.IsError()) { + CLOSE_FILE(attr_file, partition_file); + return s; + } + return Status::OK(); + }; + + std::string rows_num; + std::string attr_name; + (void)getline(attr_file, rows_num); + try { + num_rows_in_attr_file_ = static_cast(std::stoul(rows_num)); // First line is rows number in attr file + } catch (std::invalid_argument &e) { + RETURN_STATUS_UNEXPECTED("Conversion to ulong failed, invalid argument."); + } catch (std::out_of_range &e) { + RETURN_STATUS_UNEXPECTED("Conversion to ulong failed, out of range."); + } + + (void)getline(attr_file, attr_name); // Second line is attribute name,ignore it + std::string image_info; + std::vector image_infos; + image_infos.reserve(oc_queue_size_); + while (getline(attr_file, image_info)) { + if ((image_info.empty()) || (dataset_type_ != "all" && !CheckDatasetTypeValid())) { + continue; + } + image_infos.push_back(image_info); + if (image_info.size() % oc_queue_size_ == 0) { + RETURN_IF_NOT_OK(PushBackToQueue(image_infos, attr_file, partition_file_)); + image_infos.clear(); + } + } + if (!image_infos.empty()) { + RETURN_IF_NOT_OK(PushBackToQueue(image_infos, attr_file, partition_file_)); + } + std::vector end_indicator = std::vector(0); + RETURN_IF_NOT_OK(PushBackToQueue(end_indicator, attr_file, partition_file_)); // end indicator + CLOSE_FILE(attr_file, partition_file_); + return Status::OK(); +} + +bool CelebAOp::CheckDatasetTypeValid() { + if (!partition_file_.is_open()) { + Path folder_path(folder_path_); + partition_file_.open((folder_path / "list_eval_partition.txt").toString()); + if (!partition_file_.is_open()) { + MS_LOG(ERROR) << "Celeba partition file does not exist!"; + return false; + } + } + std::string line; + (void)getline(partition_file_, line); + std::vector vec = Split(line); + if (vec.size() != 2) { + return false; + } + int32_t type; + try { + type = std::stoi(vec[1]); + } catch (std::invalid_argument &e) { + MS_LOG(WARNING) << "Conversion to unsigned long failed, invalid argument, " << vec[0] << "."; + return false; + } catch (std::out_of_range &e) { + MS_LOG(WARNING) << "Conversion to unsigned long failed, out of range, " << vec[0] << "."; + return false; + } + // train:0, valid=1, test=2 + if (dataset_type_ == "train" && (type == 0)) { + return true; + } else if (dataset_type_ == "valid" && (type == 1)) { + return true; + } else if (dataset_type_ == "test" && (type == 2)) { + return true; + } + + return false; +} + +Status CelebAOp::ParseImageAttrInfo() { + std::vector image_infos; + bool needMoreData = true; + RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos)); + while (!image_infos.empty() && needMoreData) { + for (uint32_t index = 0; index < image_infos.size(); index++) { + std::string image_info = image_infos[index]; + std::vector split = Split(image_info); + std::pair> image_labels; + + Path path(folder_path_); + Path file_path = path / split[0]; + if (!extensions_.empty() && extensions_.find(file_path.Extension()) == extensions_.end()) { + MS_LOG(WARNING) << "Unsupported file found at " << file_path.toString().c_str() << ", its extension is " + << file_path.Extension().c_str() << "."; + continue; + } + image_labels.first = split[0]; + for (uint32_t label_index = 1; label_index < split.size(); label_index++) { + int32_t value; + try { + value = std::stoi(split[label_index]); + } catch (std::invalid_argument &e) { + RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument."); + } catch (std::out_of_range &e) { + RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); + } + image_labels.second.push_back(value); + } + + image_labels_vec_.push_back(image_labels); + } + + RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos)); + } + + num_rows_ = image_labels_vec_.size(); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API " + "validation first."); + } + MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << "."; + return Status::OK(); +} + +std::vector CelebAOp::Split(const std::string &line) { + std::string str = line; + std::string::size_type pos; + std::vector split; + str += " "; + int size = str.size(); + for (uint32_t index = 0; index < size;) { + pos = str.find(" ", index); + if (pos != index) { // skip space + std::string s = str.substr(index, pos - index); + split.push_back(s); + } + index = pos + 1; + } + + return split; +} + +// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work +Status CelebAOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr data_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer)); + RETURN_IF_NOT_OK(AddIOBlock(&data_buffer)); + return Status::OK(); +} + +Status CelebAOp::AddIOBlock(std::unique_ptr *data_buffer) { + int64_t buff_count = 0; + while (true) { + std::vector keys; + keys.reserve(rows_per_buffer_); + int64_t row_count = 0; + while (!(*data_buffer)->eoe()) { + TensorRow sample_row; + RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) >= num_rows_) { + MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << "."; + continue; + } + keys.push_back(*itr); + row_count++; + if (row_count % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buff_count++ % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); + } + + if (!keys.empty()) { + RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + RETURN_IF_NOT_OK( + io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK( + io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset + RETURN_IF_NOT_OK( + io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); + } + } +} + +Status CelebAOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty()) { + return Status::OK(); // empty key is a quit signal for workers + } + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unexpected nullptr received in worker"); +} + +Status CelebAOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + for (const auto &key : keys) { + TensorRow row; + RETURN_IF_NOT_OK(LoadTensorRow(key, image_labels_vec_[key], &row)); + deq->push_back(std::move(row)); + } + + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +Status CelebAOp::LoadTensorRow(row_id_type row_id, const std::pair> &image_label, + TensorRow *row) { + std::shared_ptr image; + std::shared_ptr label; + + Path path(folder_path_); + Path image_path = path / image_label.first; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, image_path.toString())); + if (decode_ == true) { + Status rc = Decode(image, &image); + if (rc.IsError()) { + image = nullptr; + std::string err_msg = "Fail to decode image: " + image_path.toString(); + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + } + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), + TensorShape({1, (uint32_t)image_label.second.size()}), + data_schema_->column(1).type())); + RETURN_IF_NOT_OK(label->Zero()); + for (uint32_t index = 0; index < image_label.second.size(); index++) { + if (image_label.second[index] == 1) { + label->SetItemAt({0, static_cast(index)}, 1); + } else { + label->SetItemAt({0, static_cast(index)}, 0); + } + } + label->Squeeze(); + + (*row) = TensorRow(row_id, {std::move(image), std::move(label)}); + return Status::OK(); +} + +void CelebAOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n"; + } +} + +// Reset Sampler and wakeup Master thread (functor) +Status CelebAOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + wp_.Set(); // wake up master thread after reset is done + return Status::OK(); +} + +// Visitor accept method for NodePass +Status CelebAOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status CelebAOp::ComputeColMap() { + // Set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (int32_t index = 0; index < data_schema_->NumColumns(); index++) { + column_name_id_map_[data_schema_->column(index).name()] = index; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h new file mode 100644 index 0000000000..ef183f8e65 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h @@ -0,0 +1,240 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H +#define DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" + +#define CLOSE_FILE(attr_file, pairition_file) \ + do { \ + attr_file.close(); \ + if (pairition_file.is_open()) { \ + pairition_file.close(); \ + } \ + } while (false) + +namespace mindspore { +namespace dataset { +class CelebAOp : public ParallelOp, RandomAccessOp { + public: + class Builder { + public: + // Constructor for Builder class of CelebAOp + // @return Builder setter method returns reference to the builder. + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method + // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method + // @param int32_t size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t size) { + builder_op_connector_size_ = size; + return *this; + } + + // Setter method + // @param std::set & exts, file extensions to be read + // @return Builder setter method returns reference to the builder. + Builder &SetExtensions(const std::set &exts) { + builder_extensions_ = exts; + return *this; + } + + // Setter method + // @param bool decode + // @return Builder setter method returns reference to the builder. + Builder &SetDecode(bool decode) { + builder_decode_ = decode; + return *this; + } + + // Setter method + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + // Setter method + // @param const std::string &dir + // @return Builder setter method returns reference to the builder. + Builder &SetCelebADir(const std::string &dir) { + builder_dir_ = dir; + return *this; + } + + // Setter method + // @param const std::string dataset_type: type to be read + // @return Builder setter method returns reference to the builder. + Builder &SetDatasetType(const std::string &dataset_type) { + builder_dataset_type_ = dataset_type; + return *this; + } + // Check validity of input args + // @return - The error code return + Status SanityCheck(); + + // The builder "build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + bool builder_decode_; + std::string builder_dir_; + int32_t builder_num_workers_; + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::set builder_extensions_; + std::shared_ptr builder_sampler_; + std::unique_ptr builder_schema_; + std::string builder_dataset_type_; + }; + + // Constructor + // @param int32_t - num_workers - Num of workers reading images in parallel + // @param int32_t - rows_per_buffer Number of images (rows) in each buffer + // @param std::string - dir directory of celeba dataset + // @param int32_t queueSize - connector queue size + // @param std::unique_ptr sampler - sampler tells CelebAOp what to read + CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, + const std::string &dataset_type, const std::set &exts, std::unique_ptr schema, + std::shared_ptr sampler); + + ~CelebAOp() override = default; + + // Main Loop of CelebaOp + // Master thread: Fill IOBlockQueue, then goes to sleep + // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector + // @return Status - The error code return + Status operator()() override; + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t worker_id - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // A print method typically used for debugging + // @param out + // @param show_all + void Print(std::ostream &out, bool show_all) const override; + + // Method in operator(), to fill IOBlockQueue + // @param std::unique_ptr sampler_buffer - to fill IOBlockQueue + // @return Status - The error code return + Status AddIOBlock(std::unique_ptr *data_buffer); + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const { return "CelebAOp"; } + + private: + // Called first when function is called + // @return + Status LaunchThreadsAndInitOp(); + + // Parse attribute file + // @return + Status ParseAttrFile(); + + // Parse each image line in attribute file + // @return + Status ParseImageAttrInfo(); + + // Split attribute info with space + // @param std::string - line - Line from att or partition file + // @return std::vector - string after split + std::vector Split(const std::string &line); + + // @param const std::vector &keys - keys in ioblock + // @param std::unique_ptr db + // @return Status - The error code return + Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); + + // Load a tensor row according to a pair + // @param row_id_type row_id - id for this tensor row + // @param std::pair - > + // @param TensorRow row - image & label read into this tensor row + // @return Status - The error code return + Status LoadTensorRow(row_id_type row_id, const std::pair> &image_label, + TensorRow *row); + + // Check if need read according to dataset type + // @return bool - if need read + bool CheckDatasetTypeValid(); + + // reset Op + // @return Status - The error code return + Status Reset() override; + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t rows_per_buffer_; + std::string folder_path_; // directory of celeba folder + bool decode_; + std::set extensions_; // extensions allowed + std::unique_ptr data_schema_; + std::unique_ptr>> attr_info_queue_; + int64_t num_rows_in_attr_file_; // rows number specified in attr file + QueueList> io_block_queues_; + WaitPost wp_; + std::vector>> image_labels_vec_; + std::string dataset_type_; + std::ifstream partition_file_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc new file mode 100644 index 0000000000..06be682bfd --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.cc @@ -0,0 +1,472 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" + +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +constexpr uint32_t kCifarImageHeight = 32; +constexpr uint32_t kCifarImageWidth = 32; +constexpr uint32_t kCifarImageChannel = 3; +constexpr uint32_t kCifarBlockImageNum = 5; +constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel; + +CifarOp::Builder::Builder() : sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + num_workers_ = cfg->num_parallel_workers(); + rows_per_buffer_ = cfg->rows_per_buffer(); + op_connect_size_ = cfg->op_connector_size(); + cifar_type_ = kCifar10; +} + +Status CifarOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + sampler_ = std::make_shared(start_index, num_samples); + } + schema_ = std::make_unique(); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + if (cifar_type_ == kCifar10) { + RETURN_IF_NOT_OK( + schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + } else { + RETURN_IF_NOT_OK(schema_->AddColumn( + ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + TensorShape another_scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema_->AddColumn( + ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar))); + } + + *ptr = std::make_shared(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, + std::move(schema_), std::move(sampler_)); + return Status::OK(); +} + +Status CifarOp::Builder::SanityCheck() { + Path dir(dir_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "Cifar path is invalid or not set\n" : ""; + err_msg += num_workers_ <= 0 ? "Num of parallel workers is negative or 0\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, + int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_works, queue_size, std::move(sampler)), + cifar_type_(type), + rows_per_buffer_(rows_per_buf), + folder_path_(file_dir), + data_schema_(std::move(data_schema)), + row_cnt_(0), + buf_cnt_(0) { + constexpr uint64_t kUtilQueueSize = 512; + cifar_raw_data_block_ = std::make_unique>>(kUtilQueueSize); + io_block_queues_.Init(num_workers_, queue_size); +} + +// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work +Status CifarOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { // each iterator is 1 epoch + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + TensorRow sample_row; + RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { + keys.push_back(*itr); + row_cnt_++; + if ((*itr) >= num_rows_) continue; // index out of bound, skipping + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +Status CifarOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK( + tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this))); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + // The order of the following 2 functions must not be changed! + RETURN_IF_NOT_OK(ParseCifarData()); // Parse cifar data and get num rows, blocking + RETURN_IF_NOT_OK(InitSampler()); // Pass numRows to Sampler + return Status::OK(); +} + +// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ +// IMPORTANT: 1 IOBlock produces 1 DataBuffer +Status CifarOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) { + return Status::OK(); // empty key is a quit signal for workers + } + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +// Load 1 TensorRow (image,label). 1 function call produces 1 TensorTow in a DataBuffer +Status CifarOp::LoadTensorRow(uint64_t index, TensorRow *trow) { + std::shared_ptr label; + std::shared_ptr fine_label; + std::shared_ptr ori_image = cifar_image_label_pairs_[index].first; + std::shared_ptr copy_image = + std::make_shared(ori_image->shape(), ori_image->type(), ori_image->GetBuffer()); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), + data_schema_->column(1).type(), + reinterpret_cast(&cifar_image_label_pairs_[index].second[0]))); + if (cifar_image_label_pairs_[index].second.size() > 1) { + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &fine_label, data_schema_->column(2).tensorImpl(), data_schema_->column(2).shape(), + data_schema_->column(2).type(), reinterpret_cast(&cifar_image_label_pairs_[index].second[1]))); + (*trow) = TensorRow(index, {copy_image, std::move(label), std::move(fine_label)}); + } else { + (*trow) = TensorRow(index, {copy_image, std::move(label)}); + } + + return Status::OK(); +} + +// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer +Status CifarOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + for (const int64_t &key : keys) { + TensorRow trow; + RETURN_IF_NOT_OK(LoadTensorRow(key, &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +void CifarOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows:" << num_rows_ << "\nCifar directory: " << folder_path_ << "\n\n"; + } +} + +// Reset Sampler and wakeup Master thread (functor) +Status CifarOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); // wake up master thread after reset is done + return Status::OK(); +} + +// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows +Status CifarOp::InitSampler() { + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +Status CifarOp::ReadCifarBlockDataAsync() { + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(GetCifarFiles()); + if (cifar_type_ == kCifar10) { + RETURN_IF_NOT_OK(ReadCifar10BlockData()); + } else { + RETURN_IF_NOT_OK(ReadCifar100BlockData()); + } + + return Status::OK(); +} + +Status CifarOp::ReadCifar10BlockData() { + constexpr uint32_t num_cifar10_records = 10000; + uint32_t block_size = (kCifarImageSize + 1) * kCifarBlockImageNum; // about 2M + std::vector image_data(block_size * sizeof(unsigned char), 0); + for (auto &file : cifar_files_) { + std::ifstream in(file, std::ios::binary); + if (!in.is_open()) { + std::string err_msg = file + " can not be opened."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + for (uint32_t index = 0; index < num_cifar10_records / kCifarBlockImageNum; ++index) { + (void)in.read(reinterpret_cast(&(image_data[0])), block_size * sizeof(unsigned char)); + if (in.fail()) { + RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file); + } + (void)cifar_raw_data_block_->EmplaceBack(image_data); + } + in.close(); + } + (void)cifar_raw_data_block_->EmplaceBack(std::vector()); // end block + + return Status::OK(); +} + +Status CifarOp::ReadCifar100BlockData() { + uint32_t num_cifar100_records = 0; // test:10000, train:50000 + uint32_t block_size = (kCifarImageSize + 2) * kCifarBlockImageNum; // about 2M + std::vector image_data(block_size * sizeof(unsigned char), 0); + for (auto &file : cifar_files_) { + int pos = file.find_last_of('/'); + if (pos == std::string::npos) { + RETURN_STATUS_UNEXPECTED("Invalid cifar100 file path"); + } + std::string file_name(file.substr(pos + 1)); + if (file_name.find("test") != std::string::npos) { + num_cifar100_records = 10000; + } else if (file_name.find("train") != std::string::npos) { + num_cifar100_records = 50000; + } else { + RETURN_STATUS_UNEXPECTED("Cifar 100 file not found!"); + } + + std::ifstream in(file, std::ios::binary); + if (!in.is_open()) { + RETURN_STATUS_UNEXPECTED(file + " can not be opened."); + } + + for (uint32_t index = 0; index < num_cifar100_records / kCifarBlockImageNum; index++) { + (void)in.read(reinterpret_cast(&(image_data[0])), block_size * sizeof(unsigned char)); + if (in.fail()) { + RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file); + } + (void)cifar_raw_data_block_->EmplaceBack(image_data); + } + in.close(); + } + (void)cifar_raw_data_block_->EmplaceBack(std::vector()); // block end + return Status::OK(); +} + +Status CifarOp::GetCifarFiles() { + // Initialize queue to hold the file names + const std::string kExtension = ".bin"; + Path dataset_directory(folder_path_); + auto dirIt = Path::DirIterator::OpenDirectory(&dataset_directory); + if (dirIt) { + while (dirIt->hasNext()) { + Path file = dirIt->next(); + std::string filename = file.toString(); + if (filename.find(kExtension) != std::string::npos) { + cifar_files_.push_back(filename); + MS_LOG(INFO) << "Cifar operator found file at " << filename << "."; + } + } + } else { + std::string err_msg = "Unable to open directory " + dataset_directory.toString(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::sort(cifar_files_.begin(), cifar_files_.end()); + return Status::OK(); +} + +Status CifarOp::ParseCifarData() { + std::vector block; + RETURN_IF_NOT_OK(cifar_raw_data_block_->PopFront(&block)); + uint32_t cur_block_index = 0; + while (!block.empty()) { + for (uint32_t index = 0; index < kCifarBlockImageNum; ++index) { + std::vector labels; + uint32_t label = block[cur_block_index++]; + labels.push_back(label); + if (cifar_type_ == kCifar100) { + uint32_t fine_label = block[cur_block_index++]; + labels.push_back(fine_label); + } + + std::shared_ptr image_tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image_tensor, data_schema_->column(0).tensorImpl(), + TensorShape({kCifarImageHeight, kCifarImageWidth, kCifarImageChannel}), + data_schema_->column(0).type())); + auto itr = image_tensor->begin(); + uint32_t total_pix = kCifarImageHeight * kCifarImageWidth; + for (int pix = 0; pix < total_pix; ++pix) { + for (int ch = 0; ch < kCifarImageChannel; ++ch) { + *itr = block[cur_block_index + ch * total_pix + pix]; + itr++; + } + } + cur_block_index += total_pix * kCifarImageChannel; + cifar_image_label_pairs_.emplace_back(std::make_pair(image_tensor, labels)); + } + RETURN_IF_NOT_OK(cifar_raw_data_block_->PopFront(&block)); + cur_block_index = 0; + } + cifar_image_label_pairs_.shrink_to_fit(); + num_rows_ = cifar_image_label_pairs_.size(); + if (num_rows_ == 0) { + std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset"; + std::string err_msg = "There is no valid data matching the dataset API " + api + + ".Please check file path or dataset API validation first."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + cifar_raw_data_block_->Reset(); + return Status::OK(); +} + +// Derived from RandomAccessOp +Status CifarOp::GetClassIds(std::map> *cls_ids) const { + if (cls_ids == nullptr || !cls_ids->empty()) { + RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); + } + + for (uint64_t index = 0; index < cifar_image_label_pairs_.size(); ++index) { + uint32_t label = (cifar_image_label_pairs_[index].second)[0]; + (*cls_ids)[label].push_back(index); + } + + for (auto &pair : (*cls_ids)) { + pair.second.shrink_to_fit(); + } + return Status::OK(); +} + +Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) { + // the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block() + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op)); + RETURN_IF_NOT_OK(op->GetCifarFiles()); + if (op->cifar_type_ == kCifar10) { + constexpr int64_t num_cifar10_records = 10000; + for (auto &file : op->cifar_files_) { + std::ifstream in(file, std::ios::binary); + if (!in.is_open()) { + std::string err_msg = file + " can not be opened."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + *count = *count + num_cifar10_records; + } + return Status::OK(); + } else { + int64_t num_cifar100_records = 0; + for (auto &file : op->cifar_files_) { + size_t pos = file.find_last_of('/'); + if (pos == std::string::npos) { + std::string err_msg = "Invalid cifar100 file path"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::string file_name; + if (file.size() > 0) + file_name = file.substr(pos + 1); + else + RETURN_STATUS_UNEXPECTED("Invalid string length!"); + if (file_name.find("test") != std::string::npos) { + num_cifar100_records = 10000; + } else if (file_name.find("train") != std::string::npos) { + num_cifar100_records = 50000; + } + std::ifstream in(file, std::ios::binary); + if (!in.is_open()) { + std::string err_msg = file + " can not be opened."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + *count = num_cifar100_records; + return Status::OK(); + } +} + +// Visitor accept method for NodePass +Status CifarOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status CifarOp::ComputeColMap() { + // set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (uint32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h new file mode 100644 index 0000000000..60169f32bf --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h @@ -0,0 +1,236 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +class CifarOp : public ParallelOp, public RandomAccessOp { + public: + enum CifarType { kCifar10, kCifar100 }; + + class Builder { + public: + // Constructor for Builder class of CifarOp + // @return Builder setter method returns reference to the builder. + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method + // @param uint32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method + // @param uint32_t size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t size) { + op_connect_size_ = size; + return *this; + } + + // Setter method + // @param uint32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + sampler_ = std::move(sampler); + return *this; + } + + // Setter method + // @param const std::string & dir + // @return + Builder &SetCifarDir(const std::string &dir) { + dir_ = dir; + return *this; + } + + // Setter method + // @param const std::string & dir + // @return + Builder &SetCifarType(const bool cifar10) { + if (cifar10) { + cifar_type_ = kCifar10; + } else { + cifar_type_ = kCifar100; + } + return *this; + } + + // Check validity of input args + // @return - The error code return + Status SanityCheck(); + + // The builder "build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + std::string dir_; + int32_t num_workers_; + int32_t rows_per_buffer_; + int32_t op_connect_size_; + std::shared_ptr sampler_; + std::unique_ptr schema_; + CifarType cifar_type_; + }; + + // Constructor + // @param CifarType type - Cifar10 or Cifar100 + // @param uint32_t numWorks - Num of workers reading images in parallel + // @param uint32_t - rowsPerBuffer Number of images (rows) in each buffer + // @param std::string - dir directory of cifar dataset + // @param uint32_t - queueSize - connector queue size + // @param std::unique_ptr sampler - sampler tells ImageFolderOp what to read + CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size, + std::unique_ptr data_schema, std::shared_ptr sampler); + // Destructor. + ~CifarOp() = default; + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param uint32_t workerId - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Main Loop of CifarOp + // Master thread: Fill IOBlockQueue, then goes to sleep + // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector + // @return Status - The error code return + Status operator()() override; + + // A print method typically used for debugging + // @param out + // @param show_all + void Print(std::ostream &out, bool show_all) const override; + + // Function to count the number of samples in the CIFAR dataset + // @param dir path to the CIFAR directory + // @param isCIFAR10 true if CIFAR10 and false if CIFAR100 + // @param count output arg that will hold the actual dataset size + // @return + static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count); + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "CifarOp"; } + + private: + // Initialize Sampler, calls sampler->Init() within + // @return Status - The error code return + Status InitSampler(); + + // Load a tensor row according to a pair + // @param uint64_t index - index need to load + // @param TensorRow row - image & label read into this tensor row + // @return Status - The error code return + Status LoadTensorRow(uint64_t index, TensorRow *row); + + // @param const std::vector &keys - keys in ioblock + // @param std::unique_ptr db + // @return Status - The error code return + Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); + + // Read block data from cifar file + // @return + Status ReadCifarBlockDataAsync(); + + // Called first when function is called + // @return + Status LaunchThreadsAndInitOp(); + + // reset Op + // @return Status - The error code return + Status Reset() override; + + // Get cifar files in dir + // @return + Status GetCifarFiles(); + + // Read cifar10 data as block + // @return + Status ReadCifar10BlockData(); + + // Read cifar100 data as block + // @return + Status ReadCifar100BlockData(); + + // Parse cifar data + // @return + Status ParseCifarData(); + + // Method derived from RandomAccess Op, enable Sampler to get all ids for each calss + // @param (std::map> * map - key label, val all ids for this class + // @return Status - The error code return + Status GetClassIds(std::map> *cls_ids) const override; + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + CifarType cifar_type_; + int32_t rows_per_buffer_; + std::string folder_path_; + std::unique_ptr data_schema_; + int64_t row_cnt_; + int64_t buf_cnt_; + + WaitPost wp_; + QueueList> io_block_queues_; + std::unique_ptr>> cifar_raw_data_block_; + std::vector cifar_files_; + std::vector, std::vector>> cifar_image_label_pairs_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc new file mode 100644 index 0000000000..958514583a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc @@ -0,0 +1,555 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/clue_op.h" + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/engine/jagged_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +ClueOp::Builder::Builder() + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); +} + +Status ClueOp::Builder::ValidateInputs() const { + std::string err; + err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; + err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : ""; + return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err); +} + +Status ClueOp::Builder::Build(std::shared_ptr *op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! + if (static_cast(builder_num_workers_) > builder_clue_files_list_.size()) { + builder_num_workers_ = builder_clue_files_list_.size(); + MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + ColKeyMap ck_map; + for (auto &p : builder_cols_to_keyword_) { + ck_map.insert({p.first, split(p.second, '/')}); + } + + std::shared_ptr clue_op = std::make_shared( + builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, + builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, + builder_device_id_); + RETURN_IF_NOT_OK(clue_op->Init()); + *op = std::move(clue_op); + + return Status::OK(); +} + +std::vector ClueOp::Builder::split(const std::string &s, char delim) { + std::vector res; + std::stringstream ss(s); + std::string item; + + while (getline(ss, item, delim)) { + res.push_back(item); + } + return res; +} + +ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, + bool shuffle_files, int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), + rows_per_buffer_(rows_per_buffer), + num_rows_per_shard_(0), + all_num_rows_(0), + num_samples_(num_samples), + filename_index_(std::make_unique()), + clue_files_list_(std::move(clue_files_list)), + load_jagged_connector_(true), + cols_to_keyword_(cols_to_keyword), + shuffle_files_(shuffle_files), + finished_reading_dataset_(false), + num_devices_(num_device), + device_id_(device_id), + load_io_block_queue_(true) { + worker_connector_size_ = worker_connector_size; +} + +Status ClueOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(clue_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + + return Status::OK(); +} + +Status ClueOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + return Status::OK(); +} + +Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { + TensorRow tRow(1, nullptr); + (*tensor_table)->push_back(std::move(tRow)); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); + (**tensor_table)[row][0] = std::move(tensor); + return Status::OK(); +} + +Status ClueOp::GetValue(const nlohmann::json &js, std::vector key_chain, std::shared_ptr *t) { + nlohmann::json cursor = js; + for (int i = 0; i < key_chain.size(); i++) { + if (cursor.find(key_chain[i]) != cursor.end()) { + cursor = cursor[key_chain[i]]; + } else { + RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]); + } + } + std::string final_str = key_chain.back(); + switch (cursor.type()) { + case nlohmann::detail::value_t::string: + RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get()}, TensorShape::CreateScalar())); + break; + + case nlohmann::detail::value_t::number_integer: + RETURN_IF_NOT_OK( + Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); + (*t)->SetItemAt({0}, cursor.get()); + break; + case nlohmann::detail::value_t::number_unsigned: + RETURN_IF_NOT_OK( + Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); + (*t)->SetItemAt({0}, cursor.get()); + break; + case nlohmann::detail::value_t::number_float: + RETURN_IF_NOT_OK( + Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32))); + (*t)->SetItemAt({0}, cursor.get()); + break; + case nlohmann::detail::value_t::array: + RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get>()}, TensorShape::CreateScalar())); + break; + default: + break; + } + return Status::OK(); +} + +Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + std::ifstream handle(file); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Failed to open file " + file); + } + + int64_t rows_each_buffer = 0; + int64_t rows_total = 0; + std::string line; + std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + std::unique_ptr tensor_table = std::make_unique(); + + while (getline(handle, line)) { + if (line.empty()) { + continue; + } + // If read to the end offset of this file, break. + if (rows_total >= end_offset) { + break; + } + // Skip line before start offset. + if (rows_total < start_offset) { + rows_total++; + continue; + } + + try { + nlohmann::json js = nlohmann::json::parse(line); + int cols_count = cols_to_keyword_.size(); + TensorRow tRow(cols_count, nullptr); + tensor_table->push_back(std::move(tRow)); + + int cout = 0; + for (auto &p : cols_to_keyword_) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor)); + (*tensor_table)[rows_each_buffer][cout] = std::move(tensor); + cout++; + } + } catch (const std::exception &err) { + // Catch any exception and convert to Status return code + RETURN_STATUS_UNEXPECTED("Failed to load json file"); + } + + // RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); + rows_each_buffer++; + rows_total++; + if (rows_each_buffer == rows_per_buffer_) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + + cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + tensor_table = std::make_unique(); + rows_each_buffer = 0; + } + } + + if (rows_each_buffer > 0) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + } + return Status::OK(); +} + +Status ClueOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this))); + + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); + NotifyToFillIOBlockQueue(); + + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (num_samples_ == 0 || rows_read < num_samples_) { + if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { + int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + return Status::OK(); +} + +Status ClueOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// A print method typically used for debugging +void ClueOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ + << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n"; + for (int i = 0; i < clue_files_list_.size(); ++i) { + out << " " << clue_files_list_[i]; + } + out << "\n\n"; + } +} + +// Pops an element from a queue in io_block_queues +Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +Status ClueOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spanwed by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +Status ClueOp::FillIOBlockQueue(const std::vector &i_keys) { + int32_t queue_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + while (!finish) { + std::vector> file_index; + if (!i_keys.empty()) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); + } + } else { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair(it.value(), it.key())); + } + } + for (auto file_info : file_index) { + if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { + auto ioBlock = + std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_info.first]; + } + + if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status ClueOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +Status ClueOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API CLUEDataset. Please check file path or dataset API " + "validation first."); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +int64_t ClueOp::CountTotalRows(const std::string &file) { + std::ifstream handle(file); + if (!handle.is_open()) { + MS_LOG(ERROR) << "Failed to open file: " << file; + return 0; + } + + std::string line; + int64_t count = 0; + while (getline(handle, line)) { + if (!line.empty()) { + count++; + } + } + + return count; +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status ClueOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +Status ClueOp::CountAllFileRows(const std::vector &files, int64_t *count) { + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op)); + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} + +Status ClueOp::ComputeColMap() { + // Set the column name mapping (base class field) + if (column_name_id_map_.empty()) { + int count = 0; + for (auto &p : cols_to_keyword_) { + column_name_id_map_[p.first] = count; + count++; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h new file mode 100644 index 0000000000..ab429561ec --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -0,0 +1,277 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" + +namespace mindspore { +namespace dataset { +using StringIndex = AutoIndexObj; +using ColKeyMap = std::map>; + +class JaggedConnector; + +class ClueOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetClueFilesList(const std::vector &files_list) { + builder_clue_files_list_ = files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumSamples(int64_t num_samples) { + builder_num_samples_ = num_samples; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetColsKeyMap(const std::map &cols_to_key) { + builder_cols_to_keyword_ = cols_to_key; + return *this; + } + + // Split string based on a character delimiter + // @return - the a string vector + std::vector split(const std::string &s, char delim); + + private: + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_num_samples_; + int32_t builder_worker_connector_size_; + std::vector builder_clue_files_list_; + bool builder_shuffle_files_; + std::map builder_cols_to_keyword_; + }; + + // Constructor of ClueOp + ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id); + + // Default destructor + ~ClueOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Get total rows in files. + // @param files - all clue files. + // @param count - number of rows. + // @return Status - the error coed returned. + static Status CountAllFileRows(const std::vector &files, int64_t *count); + + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return clue_files_list_; } + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a clue file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - clue file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // @return Status - the error code returned. + Status GetValue(const nlohmann::json &js, std::vector key_chain, std::shared_ptr *t); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t device_id_; + bool shuffle_files_; + bool finished_reading_dataset_; + int32_t num_devices_; + int64_t rows_per_buffer_; + bool load_io_block_queue_; + int64_t num_rows_per_shard_; + int64_t all_num_rows_; + int64_t num_samples_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + std::vector clue_files_list_; + WaitPost io_block_queue_wait_post_; + std::unique_ptr jagged_buffer_connector_; + QueueList> io_block_queues_; + bool load_jagged_connector_; + ColKeyMap cols_to_keyword_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc new file mode 100644 index 0000000000..daef2f284b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -0,0 +1,646 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/coco_op.h" + +#include +#include +#include +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +const char kColumnImage[] = "image"; +const char kJsonImages[] = "images"; +const char kJsonImagesFileName[] = "file_name"; +const char kJsonId[] = "id"; +const char kJsonAnnotations[] = "annotations"; +const char kJsonAnnoSegmentation[] = "segmentation"; +const char kJsonAnnoCounts[] = "counts"; +const char kJsonAnnoSegmentsInfo[] = "segments_info"; +const char kJsonAnnoIscrowd[] = "iscrowd"; +const char kJsonAnnoBbox[] = "bbox"; +const char kJsonAnnoArea[] = "area"; +const char kJsonAnnoImageId[] = "image_id"; +const char kJsonAnnoNumKeypoints[] = "num_keypoints"; +const char kJsonAnnoKeypoints[] = "keypoints"; +const char kJsonAnnoCategoryId[] = "category_id"; +const char kJsonCategories[] = "categories"; +const char kJsonCategoriesIsthing[] = "isthing"; +const char kJsonCategoriesName[] = "name"; +const float kDefaultPadValue = -1.0; +const unsigned int kPadValueZero = 0; + +CocoOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); + builder_task_type_ = TaskType::Detection; +} + +Status CocoOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + switch (builder_task_type_) { + case TaskType::Detection: + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoCategoryId), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case TaskType::Stuff: + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoSegmentation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case TaskType::Keypoint: + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoKeypoints), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoNumKeypoints), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case TaskType::Panoptic: + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoCategoryId), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoArea), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + default: + RETURN_STATUS_UNEXPECTED("Invalid task type"); + } + *ptr = std::make_shared(builder_task_type_, builder_dir_, builder_file_, builder_num_workers_, + builder_rows_per_buffer_, builder_op_connector_size_, builder_decode_, + std::move(builder_schema_), std::move(builder_sampler_)); + return Status::OK(); +} + +Status CocoOp::Builder::SanityCheck() { + Path dir(builder_dir_); + Path file(builder_file_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "Coco image folder path is invalid or not set\n" : ""; + err_msg += file.Exists() == false ? "Coco annotation json path is invalid or not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, + int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, + std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_workers, queue_size), + decode_(decode), + row_cnt_(0), + buf_cnt_(0), + task_type_(task_type), + image_folder_path_(image_folder_path), + annotation_path_(annotation_path), + rows_per_buffer_(rows_per_buffer), + sampler_(std::move(sampler)), + data_schema_(std::move(data_schema)) { + io_block_queues_.Init(num_workers_, queue_size); +} + +Status CocoOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) > num_rows_) continue; + keys->push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); + keys->clear(); + } + } + return Status::OK(); +} + +Status CocoOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + std::shared_ptr sample_ids; + RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); + if (sample_ids->type() != DataType(DataType::DE_INT64)) { + RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); + } + RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +void CocoOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows: " << num_rows_ << "\nCOCO Directory: " << image_folder_path_ << "\n\n"; + } +} + +Status CocoOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); + return Status::OK(); +} + +Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *trow) { + std::shared_ptr image, coordinate; + auto itr = coordinate_map_.find(image_id); + if (itr == coordinate_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); + + std::string kImageFile = image_folder_path_ + image_id; + RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); + + auto bboxRow = itr->second; + std::vector bbox_row; + dsize_t bbox_row_num = static_cast(bboxRow.size()); + dsize_t bbox_column_num = 0; + for (auto bbox : bboxRow) { + if (static_cast(bbox.size()) > bbox_column_num) { + bbox_column_num = static_cast(bbox.size()); + } + } + + for (auto bbox : bboxRow) { + bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); + dsize_t pad_len = bbox_column_num - static_cast(bbox.size()); + if (pad_len > 0) { + for (dsize_t i = 0; i < pad_len; i++) { + bbox_row.push_back(kDefaultPadValue); + } + } + } + + std::vector bbox_dim = {bbox_row_num, bbox_column_num}; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&coordinate, data_schema_->column(1).tensorImpl(), TensorShape(bbox_dim), + data_schema_->column(1).type(), + reinterpret_cast(&bbox_row[0]))); + if (task_type_ == TaskType::Detection) { + RETURN_IF_NOT_OK(LoadDetectionTensorRow(row_id, image_id, image, coordinate, trow)); + } else if (task_type_ == TaskType::Stuff || task_type_ == TaskType::Keypoint) { + RETURN_IF_NOT_OK(LoadSimpleTensorRow(row_id, image_id, image, coordinate, trow)); + } else if (task_type_ == TaskType::Panoptic) { + RETURN_IF_NOT_OK(LoadMixTensorRow(row_id, image_id, image, coordinate, trow)); + } else { + RETURN_STATUS_UNEXPECTED("Invalid task type."); + } + + return Status::OK(); +} + +// When task is Detection, user can get data with four columns: +// column ["image"] with datatype=uint8 +// column ["bbox"] with datatype=float32 +// column ["category_id"] with datatype=uint32 +// column ["iscrowd"] with datatype=uint32 +// By the way, column ["iscrowd"] is used for some testcases, like fasterRcnn. +// If "iscrowd" is not existed, user will get default value 0. +Status CocoOp::LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow) { + std::shared_ptr category_id, iscrowd; + std::vector category_id_row; + std::vector iscrowd_row; + auto itr_item = simple_item_map_.find(image_id); + if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); + + std::vector annotation = itr_item->second; + for (int64_t i = 0; i < annotation.size(); i++) { + if (i % 2 == 0) { + category_id_row.push_back(annotation[i]); + } else if (i % 2 == 1) { + iscrowd_row.push_back(annotation[i]); + } + } + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), + data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); + + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), + data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); + (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd)}); + return Status::OK(); +} + +// When task is "Stuff"/"Keypoint", user can get data with three columns: +// column ["image"] with datatype=uint8 +// column ["segmentation"]/["keypoints"] with datatype=float32 +// column ["iscrowd"]/["num_keypoints"] with datatype=uint32 +Status CocoOp::LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow) { + std::shared_ptr item; + std::vector item_queue; + auto itr_item = simple_item_map_.find(image_id); + if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); + + item_queue = itr_item->second; + std::vector bbox_dim = {static_cast(item_queue.size()), 1}; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&item, data_schema_->column(2).tensorImpl(), TensorShape(bbox_dim), + data_schema_->column(2).type(), + reinterpret_cast(&item_queue[0]))); + (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(item)}); + return Status::OK(); +} + +// When task is "Panoptic", user can get data with five columns: +// column ["image"] with datatype=uint8 +// column ["bbox"] with datatype=float32 +// column ["category_id"] with datatype=uint32 +// column ["iscrowd"] with datatype=uint32 +// column ["area"] with datattype=uint32 +Status CocoOp::LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow) { + std::shared_ptr category_id, iscrowd, area; + std::vector category_id_row; + std::vector iscrowd_row; + std::vector area_row; + auto itr_item = simple_item_map_.find(image_id); + if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); + + std::vector annotation = itr_item->second; + for (int64_t i = 0; i < annotation.size(); i++) { + if (i % 3 == 0) { + category_id_row.push_back(annotation[i]); + } else if (i % 3 == 1) { + iscrowd_row.push_back(annotation[i]); + } else if (i % 3 == 2) { + area_row.push_back(annotation[i]); + } + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), + data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); + + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), + data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); + + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &area, data_schema_->column(4).tensorImpl(), TensorShape({static_cast(area_row.size()), 1}), + data_schema_->column(4).type(), reinterpret_cast(&area_row[0]))); + (*trow) = TensorRow( + row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd), std::move(area)}); + return Status::OK(); +} + +Status CocoOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + TensorRow trow; + for (const int64_t &key : keys) { + RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_ids_[key], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +Status CocoOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, (std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) return Status::OK(); + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +template +Status CocoOp::SearchNodeInJson(nlohmann::json input_tree, std::string node_name, T *output_node) { + auto node = input_tree.find(node_name); + if (node == input_tree.end()) RETURN_STATUS_UNEXPECTED("Invalid node found in json : " + node_name); + (*output_node) = *node; + return Status::OK(); +} + +Status CocoOp::ParseAnnotationIds() { + std::ifstream in(annotation_path_); + nlohmann::json js; + in >> js; + + std::vector image_que; + nlohmann::json image_list; + RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonImages), &image_list)); + RETURN_IF_NOT_OK(ImageColumnLoad(image_list, &image_que)); + if (task_type_ == TaskType::Detection || task_type_ == TaskType::Panoptic) { + nlohmann::json node_categories; + RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonCategories), &node_categories)); + RETURN_IF_NOT_OK(CategoriesColumnLoad(node_categories)); + } + nlohmann::json annotations_list; + RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonAnnotations), &annotations_list)); + for (auto annotation : annotations_list) { + int32_t image_id = 0, id = 0; + std::string file_name; + RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonAnnoImageId), &image_id)); + auto itr_file = image_index_.find(image_id); + if (itr_file == image_index_.end()) + RETURN_STATUS_UNEXPECTED("Invalid image id of annotations : " + std::to_string(image_id)); + file_name = itr_file->second; + switch (task_type_) { + case TaskType::Detection: + RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); + RETURN_IF_NOT_OK(DetectionColumnLoad(annotation, file_name, id)); + break; + case TaskType::Stuff: + RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); + RETURN_IF_NOT_OK(StuffColumnLoad(annotation, file_name, id)); + break; + case TaskType::Keypoint: + RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); + RETURN_IF_NOT_OK(KeypointColumnLoad(annotation, file_name, id)); + break; + case TaskType::Panoptic: + RETURN_IF_NOT_OK(PanopticColumnLoad(annotation, file_name, image_id)); + break; + default: + RETURN_STATUS_UNEXPECTED("Invalid task type"); + } + } + for (auto img : image_que) { + if (coordinate_map_.find(img) != coordinate_map_.end()) image_ids_.push_back(img); + } + num_rows_ = image_ids_.size(); + return Status::OK(); +} + +Status CocoOp::ImageColumnLoad(nlohmann::json image_tree, std::vector *image_vec) { + if (image_tree.size() == 0) { + RETURN_STATUS_UNEXPECTED("No images found in " + annotation_path_); + } + for (auto img : image_tree) { + std::string file_name; + int32_t id = 0; + RETURN_IF_NOT_OK(SearchNodeInJson(img, std::string(kJsonImagesFileName), &file_name)); + RETURN_IF_NOT_OK(SearchNodeInJson(img, std::string(kJsonId), &id)); + + image_index_[id] = file_name; + image_vec->push_back(file_name); + } + return Status::OK(); +} + +Status CocoOp::DetectionColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, + const int32_t &unique_id) { + std::vector bbox; + nlohmann::json node_bbox; + uint32_t category_id = 0, iscrowd = 0; + RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoBbox), &node_bbox)); + RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoCategoryId), &category_id)); + auto search_category = category_set_.find(category_id); + if (search_category == category_set_.end()) + RETURN_STATUS_UNEXPECTED("category_id can't find in categories where category_id: " + std::to_string(category_id)); + auto node_iscrowd = annotation_tree.find(kJsonAnnoIscrowd); + if (node_iscrowd != annotation_tree.end()) iscrowd = *node_iscrowd; + bbox.insert(bbox.end(), node_bbox.begin(), node_bbox.end()); + coordinate_map_[image_file].push_back(bbox); + simple_item_map_[image_file].push_back(category_id); + simple_item_map_[image_file].push_back(iscrowd); + return Status::OK(); +} + +Status CocoOp::StuffColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, + const int32_t &unique_id) { + uint32_t iscrowd = 0; + std::vector bbox; + RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoIscrowd), &iscrowd)); + simple_item_map_[image_file].push_back(iscrowd); + nlohmann::json segmentation; + RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoSegmentation), &segmentation)); + if (iscrowd == 0) { + for (auto item : segmentation) { + if (bbox.size() > 0) bbox.clear(); + bbox.insert(bbox.end(), item.begin(), item.end()); + coordinate_map_[image_file].push_back(bbox); + } + } else if (iscrowd == 1) { + nlohmann::json segmentation_count; + RETURN_IF_NOT_OK(SearchNodeInJson(segmentation, std::string(kJsonAnnoCounts), &segmentation_count)); + bbox.insert(bbox.end(), segmentation_count.begin(), segmentation_count.end()); + coordinate_map_[image_file].push_back(bbox); + } + return Status::OK(); +} + +Status CocoOp::KeypointColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, + const int32_t &unique_id) { + auto itr_num_keypoint = annotation_tree.find(kJsonAnnoNumKeypoints); + if (itr_num_keypoint == annotation_tree.end()) + RETURN_STATUS_UNEXPECTED("No num_keypoint found in annotations where id: " + std::to_string(unique_id)); + simple_item_map_[image_file].push_back(*itr_num_keypoint); + auto itr_keypoint = annotation_tree.find(kJsonAnnoKeypoints); + if (itr_keypoint == annotation_tree.end()) + RETURN_STATUS_UNEXPECTED("No keypoint found in annotations where id: " + std::to_string(unique_id)); + coordinate_map_[image_file].push_back(*itr_keypoint); + return Status::OK(); +} + +Status CocoOp::PanopticColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, + const int32_t &image_id) { + auto itr_segments = annotation_tree.find(kJsonAnnoSegmentsInfo); + if (itr_segments == annotation_tree.end()) + RETURN_STATUS_UNEXPECTED("No segments_info found in annotations where image_id: " + std::to_string(image_id)); + for (auto info : *itr_segments) { + std::vector bbox; + uint32_t category_id = 0; + auto itr_bbox = info.find(kJsonAnnoBbox); + if (itr_bbox == info.end()) + RETURN_STATUS_UNEXPECTED("No bbox found in segments_info where image_id: " + std::to_string(image_id)); + bbox.insert(bbox.end(), itr_bbox->begin(), itr_bbox->end()); + coordinate_map_[image_file].push_back(bbox); + + RETURN_IF_NOT_OK(SearchNodeInJson(info, std::string(kJsonAnnoCategoryId), &category_id)); + auto search_category = category_set_.find(category_id); + if (search_category == category_set_.end()) + RETURN_STATUS_UNEXPECTED("category_id can't find in categories where category_id: " + + std::to_string(category_id)); + auto itr_iscrowd = info.find(kJsonAnnoIscrowd); + if (itr_iscrowd == info.end()) + RETURN_STATUS_UNEXPECTED("No iscrowd found in segments_info where image_id: " + std::to_string(image_id)); + auto itr_area = info.find(kJsonAnnoArea); + if (itr_area == info.end()) + RETURN_STATUS_UNEXPECTED("No area found in segments_info where image_id: " + std::to_string(image_id)); + simple_item_map_[image_file].push_back(category_id); + simple_item_map_[image_file].push_back(*itr_iscrowd); + simple_item_map_[image_file].push_back(*itr_area); + } + return Status::OK(); +} + +Status CocoOp::CategoriesColumnLoad(nlohmann::json categories_tree) { + if (categories_tree.size() == 0) RETURN_STATUS_UNEXPECTED("No categories found in " + annotation_path_); + for (auto category : categories_tree) { + int32_t id = 0; + std::string name; + std::vector label_info; + auto itr_id = category.find(kJsonId); + if (itr_id == category.end()) RETURN_STATUS_UNEXPECTED("No id found in categories of " + annotation_path_); + id = *itr_id; + label_info.push_back(id); + category_set_.insert(id); + + auto itr_name = category.find(kJsonCategoriesName); + if (itr_name == category.end()) + RETURN_STATUS_UNEXPECTED("No name found in categories where id: " + std::to_string(id)); + name = *itr_name; + + if (task_type_ == TaskType::Panoptic) { + auto itr_isthing = category.find(kJsonCategoriesIsthing); + if (itr_isthing == category.end()) + RETURN_STATUS_UNEXPECTED("No isthing found in categories of " + annotation_path_); + label_info.push_back(*itr_isthing); + } + label_index_.emplace_back(std::make_pair(name, label_info)); + } + return Status::OK(); +} + +Status CocoOp::InitSampler() { + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +Status CocoOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(this->ParseAnnotationIds()); + RETURN_IF_NOT_OK(this->InitSampler()); + return Status::OK(); +} + +Status CocoOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); + + if (decode_ == true) { + Status rc = Decode(*tensor, tensor); + if (rc.IsError()) { + RETURN_STATUS_UNEXPECTED("fail to decode file: " + path); + } + } + return Status::OK(); +} + +Status CocoOp::CountTotalRows(const std::string &dir, const std::string &file, const std::string &task, + int64_t *count) { + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetDir(dir).SetFile(file).SetTask(task).Build(&op)); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + *count = static_cast(op->image_ids_.size()); + return Status::OK(); +} + +Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file, const std::string &task, + std::vector>> *output_class_indexing) { + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetDir(dir).SetFile(file).SetTask(task).Build(&op)); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + *output_class_indexing = op->label_index_; + return Status::OK(); +} + +// Visitor accept method for NodePass +Status CocoOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status CocoOp::ComputeColMap() { + // Set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h new file mode 100644 index 0000000000..31070c26f5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -0,0 +1,340 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_COC0_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +// Forward declares +template +class Queue; + +using CoordinateRow = std::vector>; + +class CocoOp : public ParallelOp, public RandomAccessOp { + public: + enum class TaskType { Detection = 0, Stuff = 1, Panoptic = 2, Keypoint = 3 }; + + class Builder { + public: + // Constructor for Builder class of ImageFolderOp + // @param uint32_t numWrks - number of parallel workers + // @param dir - directory folder got ImageNetFolder + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method. + // @param const std::string & build_dir + // @return Builder setter method returns reference to the builder. + Builder &SetDir(const std::string &build_dir) { + builder_dir_ = build_dir; + return *this; + } + + // Setter method. + // @param const std::string & build_file + // @return Builder setter method returns reference to the builder. + Builder &SetFile(const std::string &build_file) { + builder_file_ = build_file; + return *this; + } + + // Setter method. + // @param const std::string & task_type + // @return Builder setter method returns reference to the builder. + Builder &SetTask(const std::string &task_type) { + if (task_type == "Detection") { + builder_task_type_ = TaskType::Detection; + } else if (task_type == "Stuff") { + builder_task_type_ = TaskType::Stuff; + } else if (task_type == "Panoptic") { + builder_task_type_ = TaskType::Panoptic; + } else if (task_type == "Keypoint") { + builder_task_type_ = TaskType::Keypoint; + } + return *this; + } + + // Setter method. + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + // Setter method. + // @param bool do_decode + // @return Builder setter method returns reference to the builder. + Builder &SetDecode(bool do_decode) { + builder_decode_ = do_decode; + return *this; + } + + // Check validity of input args + // @return = The error code return + Status SanityCheck(); + + // The builder "Build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + bool builder_decode_; + std::string builder_dir_; + std::string builder_file_; + TaskType builder_task_type_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int32_t builder_rows_per_buffer_; + std::shared_ptr builder_sampler_; + std::unique_ptr builder_schema_; + }; + + // Constructor + // @param TaskType task_type - task type of Coco + // @param std::string image_folder_path - image folder path of Coco + // @param std::string annotation_path - annotation json path of Coco + // @param int32_t num_workers - number of workers reading images in parallel + // @param int32_t rows_per_buffer - number of images (rows) in each buffer + // @param int32_t queue_size - connector queue size + // @param int64_t num_samples - number of samples to read + // @param bool decode - whether to decode images + // @param std::unique_ptr data_schema - the schema of the Coco dataset + // @param std::shared_ptr sampler - sampler tells CocoOp what to read + CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, + int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, + std::unique_ptr data_schema, std::shared_ptr sampler); + + // Destructor + ~CocoOp() = default; + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t workerId - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Main Loop of CocoOp + // Master thread: Fill IOBlockQueue, then goes to sleep + // Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector + // @return Status - The error code return + Status operator()() override; + + // A print method typically used for debugging + // @param out + // @param show_all + void Print(std::ostream &out, bool show_all) const override; + + // @param const std::string &dir - Coco image dir path + // @param const std::string &file - Coco json file path + // @param const std::string &task - task mode of Coco task + // @param int64_t numSamples - samples number of CocoDataset + // @param int64_t *count - output rows number of CocoDataset + static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, + int64_t *count); + + // @param const std::string &dir - Coco image dir path + // @param const std::string &file - Coco json file path + // @param const std::string &task - task mode of Coco task + // @param int64_t numSamples - samples number of CocoDataset + // @param std::map *output_class_indexing - output class index of CocoDataset + static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, + std::vector>> *output_class_indexing); + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + private: + // Initialize Sampler, calls sampler->Init() within + // @return Status - The error code return + Status InitSampler(); + + // Load a tensor row according to image id + // @param row_id_type row_id - id for this tensor row + // @param std::string image_id - image id + // @param TensorRow row - image & target read into this tensor row + // @return Status - The error code return + Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); + + // Load a tensor row with vector which a vector to a tensor + // @param row_id_type row_id - id for this tensor row + // @param const std::string &image_id - image is + // @param std::shared_ptr image - image tensor + // @param std::shared_ptr coordinate - coordinate tensor + // @param TensorRow row - image & target read into this tensor row + // @return Status - The error code return + Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow); + + // Load a tensor row with vector which a vector to a tensor + // @param row_id_type row_id - id for this tensor row + // @param const std::string &image_id - image is + // @param std::shared_ptr image - image tensor + // @param std::shared_ptr coordinate - coordinate tensor + // @param TensorRow row - image & target read into this tensor row + // @return Status - The error code return + Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow); + + // Load a tensor row with vector which a vector to multi-tensor + // @param row_id_type row_id - id for this tensor row + // @param const std::string &image_id - image is + // @param std::shared_ptr image - image tensor + // @param std::shared_ptr coordinate - coordinate tensor + // @param TensorRow row - image & target read into this tensor row + // @return Status - The error code return + Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow); + + // @param const std::string &path - path to the image file + // @param const ColDescriptor &col - contains tensor implementation and datatype + // @param std::shared_ptr tensor - return + // @return Status - The error code return + Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); + + // @param const std::vector &keys - keys in ioblock + // @param std::unique_ptr db + // @return Status - The error code return + Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); + + // Read annotation from Annotation folder + // @return Status - The error code return + Status ParseAnnotationIds(); + + // @param const std::shared_ptr &sample_ids - sample ids of tensor + // @param std::vector *keys - image id + // @return Status - The error code return + Status TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); + + // Called first when function is called + // @return Status - The error code return + Status LaunchThreadsAndInitOp(); + + // Reset dataset state + // @return Status - The error code return + Status Reset() override; + + // @param nlohmann::json image_tree - image tree of json + // @param std::vector *image_vec - image id list of json + // @return Status - The error code return + Status ImageColumnLoad(nlohmann::json image_tree, std::vector *image_vec); + + // @param nlohmann::json categories_tree - categories tree of json + // return Status - The error code return + Status CategoriesColumnLoad(nlohmann::json categories_tree); + + // @param nlohmann::json categories_tree - categories tree of json + // @param const std::string &image_file - current image name in annotation + // @param const int32_t &id - current unique id of annotation + // @return Status - The error code return + Status DetectionColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); + + // @param nlohmann::json categories_tree - categories tree of json + // @param const std::string &image_file - current image name in annotation + // @param const int32_t &id - current unique id of annotation + // @return Status - The error code return + Status StuffColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); + + // @param nlohmann::json categories_tree - categories tree of json + // @param const std::string &image_file - current image name in annotation + // @param const int32_t &id - current unique id of annotation + // @return Status - The error code return + Status KeypointColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); + + // @param nlohmann::json categories_tree - categories tree of json + // @param const std::string &image_file - current image name in annotation + // @param const int32_t &image_id - current unique id of annotation + // @return Status - The error code return + Status PanopticColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &image_id); + + template + Status SearchNodeInJson(nlohmann::json input_tree, std::string node_name, T *output_node); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + bool decode_; + int64_t row_cnt_; + int64_t buf_cnt_; + std::string image_folder_path_; + std::string annotation_path_; + TaskType task_type_; + int32_t rows_per_buffer_; + std::shared_ptr sampler_; + std::unique_ptr data_schema_; + + WaitPost wp_; + std::vector image_ids_; + std::map image_index_; + QueueList> io_block_queues_; + std::vector>> label_index_; + std::map coordinate_map_; + std::map> simple_item_map_; + std::set category_set_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_Coco_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc new file mode 100644 index 0000000000..773dfc78b6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -0,0 +1,267 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/generator_op.h" +#include +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +GeneratorOp::Builder::Builder() { + // Some arguments to the GeneratorOp constructor have a default argument that is taken + // from the client config. + build_buffer_size_ = kCfgRowsPerBuffer; + build_op_connector_size_ = kCfgOpConnectorSize; +} + +Status GeneratorOp::Builder::SanityCheck() { + // Update queue size to fit the prefetch requirement + MS_LOG(DEBUG) << "Generator operator sanity check, prefetch size is " << build_prefetch_size_ << "."; + if (build_prefetch_size_ > 0) { + build_op_connector_size_ = (build_prefetch_size_ + build_buffer_size_ - 1) / build_buffer_size_; + } + return Status::OK(); +} + +Status GeneratorOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_generator_function_, build_column_names_, build_column_types_, + build_prefetch_size_, build_buffer_size_, build_op_connector_size_); + return (*ptr)->Init(); +} + +GeneratorOp::GeneratorOp(py::function generator_function, std::vector column_names, + std::vector column_types, int32_t prefetch_size, int32_t buffer_size, + int32_t connector_size) + : PipelineOp(connector_size), + generator_function_(generator_function), + column_names_(column_names), + column_types_(column_types), + prefetch_size_(prefetch_size), + buffer_size_(buffer_size), + buffer_id_(0) {} + +GeneratorOp::~GeneratorOp() { this->Dealloc(); } + +void GeneratorOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nColumn names:\n"; + for (int i = 0; i < column_names_.size(); ++i) { + out << "\n " << column_names_[i]; + } + out << "\n\n"; + } +} + +void GeneratorOp::Dealloc() noexcept { + // Setup GIL state + PyGILState_STATE gstate; + gstate = PyGILState_Ensure(); + // GC the generator object within GIL + (void)generator_.dec_ref(); + // Release GIL + PyGILState_Release(gstate); +} + +// Reentrant init method. +Status GeneratorOp::Init() { + // Reset BufferID + buffer_id_ = 0; + Status ret; + { + // Acquire Python GIL + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + // Invoke the generatorFunction to get generator object + try { + generator_ = generator_function_(); + } catch (const py::error_already_set &e) { + ret = Status(StatusCode::kPyFuncException, e.what()); + } + } + return ret; +} + +Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) { + if (!py::isinstance(py_data)) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator should return a tuple of numpy arrays."); + } + py::tuple py_row = py_data.cast(); + // Check if returned number of columns matches with column names + if (py_row.size() != column_names_.size()) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, + "Generator should return same number of numpy arrays as specified in column names."); + } + // Iterate over two containers simultaneously for memory copy + for (int i = 0; i < py_row.size(); ++i) { + py::object ret_py_ele = py_row[i]; + if (!py::isinstance(ret_py_ele)) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, + "Generator should return a tuple of numpy arrays."); + } + std::shared_ptr tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, ret_py_ele.cast())); + if ((!column_types_.empty()) && (column_types_[i] != DataType::DE_UNKNOWN) && + (column_types_[i] != tensor->type())) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator type check failed."); + } + tensor_row->push_back(tensor); + } + return Status(StatusCode::kOK, ""); +} + +Status GeneratorOp::FillBuffer(TensorQTable *tt) { + for (int i = 0; i < buffer_size_; i++) { + TensorRow row; + RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row)); + tt->push_back(std::move(row)); + } + return Status::OK(); +} + +// Entry point for Generator, called by launch() +// Note that this function is very easy to break because of the Python GIL mechanism +// The master thread has the following workflow +// +// while !eof: +// Try: +// Prepare one data buffer GIL, Can throw +// Catch: +// Fetch Python Exception GIL +// Check if Exception is StopIteration (EOE) GIL +// Restore Python Exception GIL +// If not StopIteration: +// Return Status PyFuncException +// +// Push data buffer to connector Block +// +// if EOE +// Push EOE Block +// if more epoch: +// Block until next epoch Block +// else: +// Push EOF Block +// eof = true +// Return Status OK +// +// Note that any modification of this function need to guarantee: +// 1. All "Require GIL" operations are protected by GIL +// SegFault / Deadlock will occur if this condition is not fulfilled. +// 2. All "Block" operations are free from GIL, all block target are registered with tree. +// Deadlock will occur if this condition is not fulfilled +// 3. No Python GC should be triggered outside of GIL. +// SegFault will occur is this condition is not fulfilled +// +Status GeneratorOp::operator()() { + // Handshake with TaskManager to synchronize thread creation + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + std::unique_ptr fetched_buffer; + bool eof = false; + while (!eof) { + // Create new buffer each iteration + fetched_buffer = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); + std::unique_ptr fetched_table = std::make_unique(); + bool eoe = false; + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + RETURN_IF_NOT_OK(FillBuffer(fetched_table.get())); + } catch (py::error_already_set &e) { + eoe = e.matches(PyExc_StopIteration); + // Restore exception to python + e.restore(); + // Pop up non StopIteration Python Exception + if (!eoe) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what()); + } + } + } + if (fetched_table->size() > 0) { + fetched_buffer->set_tensor_table(std::move(fetched_table)); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); + } + if (eoe) { + // Push out EOE upon StopIteration exception from generator + MS_LOG(DEBUG) << "Generator operator sends out EOE."; + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + // If last repeat or not repeated, push out EOF and exit master loop + MS_LOG(DEBUG) << "Generator operator sends out EOF."; + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + MS_LOG(DEBUG) << "Generator operator main execution loop complete."; + eof = true; + } else { + // Waiting for repeatOp to start new epoch + // If Reset() is called first by repeat op, this wait() will return right away. + // If Reset() is not called yet, this wait() will block until reset. + RETURN_IF_NOT_OK(wp_.Wait()); + // Clear the status of the wait post + wp_.Clear(); + } + } + } + return Status::OK(); +} + +Status GeneratorOp::Reset() { + // Reset Op state + RETURN_IF_NOT_OK(this->Init()); + // Wake up master thread + wp_.Set(); + return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); +} + +// Visitor accept method for NodePass +Status GeneratorOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status GeneratorOp::ComputeColMap() { + // Setup column names map (base class field) + if (column_name_id_map_.empty()) { + for (int i = 0; i < column_names_.size(); ++i) { + column_name_id_map_[column_names_[i]] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h new file mode 100644 index 0000000000..d09bfc3d71 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h @@ -0,0 +1,163 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +#pragma GCC visibility push(hidden) + +class GeneratorOp : public PipelineOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetGeneratorFunction(py::function generator_function) { + build_generator_function_ = generator_function; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetColumnNames(const std::vector &column_names) { + build_column_names_ = column_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetColumnTypes(const std::vector &column_types) { + build_column_types_ = column_types; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetPrefetchSize(int32_t prefetch_size) { + build_prefetch_size_ = prefetch_size; + return *this; + } + + // The builder "build" method creates the final object. + // @return shared_ptr to the new GeneratorOp object + Status Build(std::shared_ptr *); + + private: + // The builder saves all GeneratorOp construction arguments internally. + // The following are the arguments. + py::function build_generator_function_; + std::vector build_column_names_; + std::vector build_column_types_; + + int32_t build_prefetch_size_ = 0; + int32_t build_buffer_size_; + int32_t build_op_connector_size_; + + Status SanityCheck(); + }; + + GeneratorOp(py::function generator_function, std::vector column_names, + std::vector column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size); + + ~GeneratorOp(); + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param generator_op - reference to the GeneratorOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) { + generator_op.Print(out, false); + return out; + } + + // Class functor operator () override. + // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work. + // @return Status - The error code return + Status operator()() override; + + // Overrides base class reset method. When an operator does a reset, it cleans up any state + // info from it's previous execution and then initializes itself so that it can be executed + // again. + // @return Status - The error code return + Status Reset() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "GeneratorOp"; } + + private: + py::function generator_function_; + std::vector column_names_; + std::vector column_types_; + int32_t prefetch_size_; + int32_t buffer_size_; + + py::object generator_; + int32_t buffer_id_; + + WaitPost wp_; + + Status Init(); + + void Dealloc() noexcept; + + Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row); + + Status FillBuffer(TensorQTable *tt); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; +}; + +#pragma GCC visibility pop +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc new file mode 100644 index 0000000000..85839303db --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -0,0 +1,429 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include +#include +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status ImageFolderOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); + *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, + builder_op_connector_size_, builder_recursive_, builder_decode_, + builder_extensions_, builder_labels_to_read_, std::move(builder_schema_), + std::move(builder_sampler_)); + return Status::OK(); +} + +Status ImageFolderOp::Builder::SanityCheck() { + Path dir(builder_dir_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "ImageFolder path is invalid or not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, + bool recursive, bool do_decode, const std::set &exts, + const std::map &map, std::unique_ptr data_schema, + std::shared_ptr sampler) + : ParallelOp(num_wkrs, queue_size, std::move(sampler)), + rows_per_buffer_(rows_per_buffer), + folder_path_(file_dir), + recursive_(recursive), + decode_(do_decode), + extensions_(exts), + class_index_(map), + data_schema_(std::move(data_schema)), + row_cnt_(0), + buf_cnt_(0), + sampler_ind_(0), + dirname_offset_(0) { + folder_name_queue_ = std::make_unique>(num_wkrs * queue_size); + image_name_queue_ = std::make_unique>(num_wkrs * queue_size); + io_block_queues_.Init(num_workers_, queue_size); +} + +// Master thread that pulls the prescan worker's results. +// Keep collecting results until all prescan workers quit +// Then consolidate 2 level shuffles together into 1 giant vector +// calculate numRows then return +Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) { + std::vector v; + int64_t cnt = 0; + while (cnt != num_workers_) { // count number of end signals + FolderImagesPair p; + RETURN_IF_NOT_OK(image_name_queue_->PopFront(&p)); + if (p == nullptr) { + cnt++; + } else { + v.push_back(p); + } + } + std::sort(v.begin(), v.end(), + [](const FolderImagesPair &lhs, const FolderImagesPair &rhs) { return lhs->first < rhs->first; }); + // following loop puts the 2 level of shuffles together into 1 vector + for (size_t ind = 0; ind < v.size(); ++ind) { + while (v[ind]->second.empty() == false) { + MS_ASSERT(!(v[ind]->first.empty())); // make sure that v[ind]->first.substr(1) is not out of bound + v[ind]->second.front()->second = class_index_.empty() ? ind : class_index_[v[ind]->first.substr(1)]; + image_label_pairs_.push_back(v[ind]->second.front()); + v[ind]->second.pop(); + } + } + image_label_pairs_.shrink_to_fit(); + num_rows_ = image_label_pairs_.size(); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset " + "API validation first."); + } + // free memory of two queues used for pre-scan + folder_name_queue_->Reset(); + image_name_queue_->Reset(); + return Status::OK(); +} + +// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work +Status ImageFolderOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { // each iterator is 1 epoch + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + TensorRow sample_row; + RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) >= num_rows_) continue; // index out of bound, skipping + keys.push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK( + io_block_queues_[buf_cnt_++ % num_workers_]->Add(std::make_unique(keys, IOBlock::kDeIoBlockNone))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(keys, IOBlock::kDeIoBlockNone))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); + for (int32_t i = 0; i < num_workers_; ++i) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { // not the last repeat. Sleep master thread, wait for the wake-up from reset + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ +// IMPORTANT: 1 IOBlock produces 1 DataBuffer +Status ImageFolderOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) return Status::OK(); // empty key is a quit signal for workers + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer +Status ImageFolderOp::LoadTensorRow(row_id_type row_id, ImageLabelPair pairPtr, TensorRow *trow) { + std::shared_ptr image, label; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), + data_schema_->column(1).type(), + reinterpret_cast(&pairPtr->second))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, folder_path_ + (pairPtr->first))); + + if (decode_ == true) { + Status rc = Decode(image, &image); + if (rc.IsError()) { + std::string err = "Fail to decode image:" + folder_path_ + (pairPtr->first); + RETURN_STATUS_UNEXPECTED(err); + } + } + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); + return Status::OK(); +} + +// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer +Status ImageFolderOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + TensorRow trow; + for (const int64_t &key : keys) { + RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_label_pairs_[key], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +void ImageFolderOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows:" << num_rows_ << "\nImageFolder directory: " << folder_path_ << "\n\n"; + } +} + +// Reset Sampler and wakeup Master thread (functor) +Status ImageFolderOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); // wake up master thread after reset is done + return Status::OK(); +} + +// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows +Status ImageFolderOp::InitSampler() { + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +// Derived from RandomAccessOp +Status ImageFolderOp::GetClassIds(std::map> *cls_ids) const { + if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { + RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); + } + for (size_t i = 0; i < image_label_pairs_.size(); ++i) { + (*cls_ids)[image_label_pairs_[i]->second].push_back(i); + } + for (auto &pair : (*cls_ids)) { + pair.second.shrink_to_fit(); + } + return Status::OK(); +} + +// Worker Entry for pre-scanning all the folders and do the 1st level shuffle +// Worker pull a file name from mFoldernameQueue (which is a Queue), walks all the images under that foldername +// After walking is complete, sort all the file names (relative path to all jpeg files under the same directory ) +// (Sort is automatically conducted using a set which is implemented using a Red-Black Tree) +// Add the sorted filenames in to a queue. The make a pair (foldername, queue*), +// foldername is used for 2nd level sorting. +// FYI: 1st level sorting: sort all images under the same directory. +// FYI: 2nd level sorting: sort all folder names +// push this pair to mImagenameQueue (which is again a Queue) +Status ImageFolderOp::PrescanWorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::string folder_name; + RETURN_IF_NOT_OK(folder_name_queue_->PopFront(&folder_name)); + while (folder_name.empty() == false) { + Path folder(folder_path_ + folder_name); + std::shared_ptr dirItr = Path::DirIterator::OpenDirectory(&folder); + if (folder.Exists() == false || dirItr == nullptr) { + RETURN_STATUS_UNEXPECTED("Error unable to open: " + folder_name); + } + std::set imgs; // use this for ordering + while (dirItr->hasNext()) { + Path file = dirItr->next(); + if (extensions_.empty() || extensions_.find(file.Extension()) != extensions_.end()) { + (void)imgs.insert(file.toString().substr(dirname_offset_)); + } else { + MS_LOG(WARNING) << "Image folder operator unsupported file found: " << file.toString() + << ", extension: " << file.Extension() << "."; + } + } + FolderImagesPair p = std::make_shared>>(); + p->first = folder_name; + for (const std::string &img : imgs) { + p->second.push(std::make_shared>(img, 0)); + } + RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(p)); + RETURN_IF_NOT_OK(folder_name_queue_->PopFront(&folder_name)); + } + RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(nullptr)); // end signal + return Status::OK(); +} + +// This helper function recursively walks all foldernames, and send each foldername to mFoldernameQueue +// if mRecursive == false, don't go into folder of folders +Status ImageFolderOp::RecursiveWalkFolder(Path *dir) { + std::shared_ptr dir_itr = Path::DirIterator::OpenDirectory(dir); + RETURN_UNEXPECTED_IF_NULL(dir_itr); + while (dir_itr->hasNext()) { + Path subdir = dir_itr->next(); + if (subdir.IsDirectory()) { + if (class_index_.empty() || + class_index_.find(subdir.toString().substr(dirname_offset_ + 1)) != class_index_.end()) { + RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack(subdir.toString().substr(dirname_offset_))); + } + if (recursive_ == true) { + RETURN_IF_NOT_OK(RecursiveWalkFolder(&subdir)); + } + } + } + return Status::OK(); +} + +// A thread that calls RecursiveWalkFolder +Status ImageFolderOp::startAsyncWalk() { + TaskManager::FindMe()->Post(); + Path dir(folder_path_); + if (dir.Exists() == false || dir.IsDirectory() == false) { + RETURN_STATUS_UNEXPECTED("Error unable to open: " + folder_path_); + } + dirname_offset_ = folder_path_.length(); + RETURN_IF_NOT_OK(RecursiveWalkFolder(&dir)); + // send out num_workers_ end signal to mFoldernameQueue, 1 for each worker. + // Upon receiving end Signal, worker quits and set another end Signal to mImagenameQueue. + for (int32_t ind = 0; ind < num_workers_; ++ind) { + RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack("")); // end signal + } + return Status::OK(); +} + +Status ImageFolderOp::LaunchThreadsAndInitOp() { + RETURN_UNEXPECTED_IF_NULL(tree_); + // Registers QueueList and individual Queues for interrupt services + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + // The following code launch 3 threads group + // 1) A thread that walks all folders and push the folder names to a util:Queue mFoldernameQueue. + // 2) Workers that pull foldername from mFoldernameQueue, walk it and return the sorted images to mImagenameQueue + // 3) Launch main workers that load DataBuffers by reading all images + RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("walk dir", std::bind(&ImageFolderOp::startAsyncWalk, this))); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + // The order of the following 2 functions must not be changed! + RETURN_IF_NOT_OK(this->PrescanMasterEntry(folder_path_)); // Master thread of pre-scan workers, blocking + RETURN_IF_NOT_OK(this->InitSampler()); // pass numRows to Sampler + return Status::OK(); +} + +Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set &exts, int64_t *num_rows, + int64_t *num_classes, int64_t dev_id, int64_t num_dev) { + Path dir(path); + std::string err_msg = ""; + int64_t row_cnt = 0; + err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : ""; + err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : ""; + err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : ""; + if (err_msg.empty() == false) { + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::queue foldernames; + std::shared_ptr dir_itr = Path::DirIterator::OpenDirectory(&dir); + while (dir_itr->hasNext()) { + Path subdir = dir_itr->next(); + if (subdir.IsDirectory()) { + foldernames.push(subdir.toString()); + } + } + (*num_classes) = foldernames.size(); + while (foldernames.empty() == false) { + Path subdir(foldernames.front()); + dir_itr = Path::DirIterator::OpenDirectory(&subdir); + while (dir_itr->hasNext()) { + if (exts.empty() || exts.find(subdir.Extension()) != exts.end()) { + ++row_cnt; + } + } + foldernames.pop(); + } + (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1); + return Status::OK(); +} + +// Visitor accept method for NodePass +Status ImageFolderOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status ImageFolderOp::ComputeColMap() { + // Set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h new file mode 100644 index 0000000000..153751d3c5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h @@ -0,0 +1,274 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +// Forward declares +template +class Queue; + +using ImageLabelPair = std::shared_ptr>; +using FolderImagesPair = std::shared_ptr>>; + +class ImageFolderOp : public ParallelOp, public RandomAccessOp { + public: + class Builder { + public: + // Constructor for Builder class of ImageFolderOp + // @param int32_t numWrks - number of parallel workers + // @param dir - directory folder got ImageNetFolder + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method + // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method + // @param int32_t size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t size) { + builder_op_connector_size_ = size; + return *this; + } + + // Setter method + // @param std::set & exts, file extensions to be read + // @return Builder setter method returns reference to the builder. + Builder &SetExtensions(const std::set &exts) { + builder_extensions_ = exts; + return *this; + } + + // Setter method + // @paramconst std::map& map - a class name to label map + // @return + Builder &SetClassIndex(const std::map &map) { + builder_labels_to_read_ = map; + return *this; + } + + // Setter method + // @param bool do_decode + // @return Builder setter method returns reference to the builder. + Builder &SetDecode(bool do_decode) { + builder_decode_ = do_decode; + return *this; + } + + // Setter method + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + // Setter method + // @param const std::string & dir + // @return + Builder &SetImageFolderDir(const std::string &dir) { + builder_dir_ = dir; + return *this; + } + + // Whether dir are walked recursively + // @param bool recursive - if set to false, only get dirs in top level dir + // @return + Builder &SetRecursive(bool recursive) { + builder_recursive_ = recursive; + return *this; + } + + // Check validity of input args + // @return - The error code return + Status SanityCheck(); + + // The builder "build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + bool builder_decode_; + bool builder_recursive_; + std::string builder_dir_; + int32_t builder_num_workers_; + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::set builder_extensions_; + std::shared_ptr builder_sampler_; + std::unique_ptr builder_schema_; + std::map builder_labels_to_read_; + }; + + // Constructor + // @param int32_t num_wkrs - Num of workers reading images in parallel + // @param int32_t - rows_per_buffer Number of images (rows) in each buffer + // @param std::string - dir directory of ImageNetFolder + // @param int32_t queue_size - connector queue size + // @param std::set exts - set of file extensions to read, if empty, read everything under the dir + // @param td::unique_ptr sampler - sampler tells ImageFolderOp what to read + ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, + bool do_decode, const std::set &exts, const std::map &map, + std::unique_ptr, std::shared_ptr sampler); + + // Destructor. + ~ImageFolderOp() = default; + + // Initialize ImageFOlderOp related var, calls the function to walk all files + // @param - std::string dir file directory to ImageNetFolder + // @return - The error code return + Status PrescanMasterEntry(const std::string &dir); + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t workerId - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t workerId - id of each worker + // @return Status - The error code return + Status PrescanWorkerEntry(int32_t worker_id); + + // Main Loop of ImageFolderOp + // Master thread: Fill IOBlockQueue, then goes to sleep + // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector + // @return Status - The error code return + Status operator()() override; + + // Method derived from RandomAccess Op, enable Sampler to get all ids for each class + // @param (std::map> * map - key label, val all ids for this class + // @return Status - The error code return + Status GetClassIds(std::map> *cls_ids) const override; + + // A print method typically used for debugging + // @param out + // @param show_all + void Print(std::ostream &out, bool show_all) const override; + + // This function is a hack! It is to return the num_class and num_rows. The result + // returned by this function may not be consistent with what image_folder_op is going to return + // user this at your own risk! + static Status CountRowsAndClasses(const std::string &path, const std::set &exts, int64_t *num_rows, + int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1); + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "ImageFolderOp"; } + + private: + // Initialize Sampler, calls sampler->Init() within + // @return Status - The error code return + Status InitSampler(); + + // Load a tensor row according to a pair + // @param row_id_type row_id - id for this tensor row + // @param ImageLabelPair pair - + // @param TensorRow row - image & label read into this tensor row + // @return Status - The error code return + Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row); + + // @param const std::vector &keys - keys in ioblock + // @param std::unique_ptr db + // @return Status - The error code return + Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); + + // @param std::string & dir - dir to walk all images + // @param int64_t * cnt - number of non folder files under the current dir + // @return + Status RecursiveWalkFolder(Path *dir); + + // start walking of all dirs + // @return + Status startAsyncWalk(); + + // Called first when function is called + // @return + Status LaunchThreadsAndInitOp(); + + // reset Op + // @return Status - The error code return + Status Reset() override; + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t rows_per_buffer_; + std::string folder_path_; // directory of image folder + bool recursive_; + bool decode_; + std::set extensions_; // extensions allowed + std::map class_index_; + std::unique_ptr data_schema_; + int64_t row_cnt_; + int64_t buf_cnt_; + int64_t sampler_ind_; + int64_t dirname_offset_; + WaitPost wp_; + std::vector image_label_pairs_; + QueueList> io_block_queues_; // queues of IOBlocks + std::unique_ptr> folder_name_queue_; + std::unique_ptr> image_name_queue_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.cc new file mode 100644 index 0000000000..2b2542430b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/io_block.h" + +#include +#include + +namespace mindspore { +namespace dataset { +// IOBlock Class // + +// Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. +IOBlock::IOBlock(int64_t inKey, IOBlockFlags io_block_flags) : index_keys_(1, inKey), io_block_flags_(io_block_flags) {} + +// Constructor of the IOBlock (2) +IOBlock::IOBlock(const std::vector &in_keys, IOBlockFlags io_block_flags) : io_block_flags_(io_block_flags) { + index_keys_.insert(index_keys_.end(), in_keys.begin(), in_keys.end()); +} + +// Constructor of the IOBlock (3). A special IOBlock that is used for control messaging. +IOBlock::IOBlock(IOBlockFlags io_block_flags) : io_block_flags_(io_block_flags) {} + +// Fetches the first key from this block +Status IOBlock::GetKey(int64_t *out_key) const { + if (out_key == nullptr || index_keys_.empty()) { + RETURN_STATUS_UNEXPECTED("Failed to get the key from IOBlock"); + } + *out_key = index_keys_[0]; + return Status::OK(); +} + +// Fetches the list of keys from this block. +Status IOBlock::GetKeys(std::vector *out_keys) const { + if (out_keys == nullptr) { + RETURN_STATUS_UNEXPECTED("Output arg for GetKeys is null"); + } + *out_keys = index_keys_; // vector copy assign + return Status::OK(); +} + +// FilenameBlock derived class // + +// Constructor of the FilenameBlock (1) +FilenameBlock::FilenameBlock(int64_t key, int64_t start_offset, int64_t end_offset, IOBlockFlags io_block_flags) + : IOBlock(key, io_block_flags), start_offset_(start_offset), end_offset_(end_offset) {} + +// Constructor of the FilenameBlock (2). A special IOBlock that is used for control messaging. +FilenameBlock::FilenameBlock(IOBlockFlags io_block_flags) + : IOBlock(io_block_flags), start_offset_(kInvalidOffset), end_offset_(kInvalidOffset) {} + +// Gets the filename from the block using the provided index container +Status FilenameBlock::GetFilename(std::string *out_filename, const AutoIndexObj &index) const { + if (out_filename == nullptr) { + RETURN_STATUS_UNEXPECTED("Failed to get filename from FilenameBlock"); + } + + // a FilenameBlock only has one key. Call base class method to fetch that key + int64_t fetched_key; + RETURN_IF_NOT_OK(IOBlock::GetKey(&fetched_key)); + + // Do an index lookup using that key to get the filename. + auto r = index.Search(fetched_key); + if (r.second) { + auto &it = r.first; + *out_filename = it.value(); + } else { + RETURN_STATUS_UNEXPECTED("Could not find filename from index"); + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h new file mode 100644 index 0000000000..df26aa1fc1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h @@ -0,0 +1,125 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ + +#include +#include + +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// The IOBlock class is used to describe a "unit of work" that a storage leaf operator worker thread +// is responsible for acting on. +// The IOBlocks and it's derived classes abstracts a key-store and key-lookup interface where each +// block contains 1 to n keys, and the keys are used in conjunction with an index to provide the meta +// information for satisfying an IO request. +class IOBlock { + public: + enum IOBlockFlags : uint32_t { + kDeIoBlockNone = 0, + kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch + kDeIoBlockFlagEof = 1u << 1 // end of IOBlocks for entire program + }; + + // Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. + // @param inKey - A single key to add into the block + // @param io_block_flags - The flag setting for the block + IOBlock(int64_t inKey, IOBlockFlags io_block_flags); + + // Constructor of the IOBlock (2). + // @param in_keys - A vector of keys to add into the block + // @param io_block_flags - The flag setting for the block + IOBlock(const std::vector &in_keys, IOBlockFlags io_block_flags); + + // Constructor of the IOBlock (3). A special IOBlock that is used for control messaging. + // @param io_block_flags - The flag setting for the block + explicit IOBlock(IOBlockFlags io_block_flags); + + // Destructor + virtual ~IOBlock() = default; + + // Fetches the first key from the block. + // @note Only useful if you know the block only has 1 key. + // @return A copy of the first key from the block + // @return Status - The error code return + Status GetKey(int64_t *out_key) const; + + // Fetches the list of keys from this block. + // @param out_keys - A copy of the vector of keys from the block. + // @return Status - The error code return + Status GetKeys(std::vector *out_keys) const; + + // Does this block have the eoe flag turned on? + // @return T/F if the IOBlock is eoe + bool eoe() const { return static_cast(io_block_flags_) & static_cast(kDeIoBlockFlagEoe); } + + // Does this block have the eof flag turned on? + // @return T/F if the IOBlock is eof + bool eof() const { return static_cast(io_block_flags_) & static_cast(kDeIoBlockFlagEof); } + + // Adds a key to this block + // @param key - The key to add to this block + void AddKey(int64_t key) { index_keys_.push_back(key); } + + protected: + std::vector index_keys_; // keys used for lookups to the meta info for the data + IOBlockFlags io_block_flags_; +}; // class IOBlock + +const int64_t kInvalidOffset = -1; + +// The Filename block derived class implements a style of IO block where each block contains only a +// single key that maps to a filename. +class FilenameBlock : public IOBlock { + public: + // Constructor of the FilenameBlock (1) + // @param key - The key identifier that can be used to find the data for this block + // @param start_offset - Start offset + // @param end_offset - End offset + // @param io_block_flags - The flag setting for the block + FilenameBlock(int64_t key, int64_t start_offset, int64_t end_offset, IOBlockFlags io_block_flags); + + // Constructor of the FilenameBlock (2). A special IOBlock that is used for control messaging. + // @param io_block_flags - The flag setting for the block + explicit FilenameBlock(IOBlockFlags io_block_flags); + + // Destructor + ~FilenameBlock() = default; + + // Gets the filename from the block using the provided index container + // @param out_filename - The filename to add to the block + // @param index - The index to perform lookup against + // @return Status - The error code return + Status GetFilename(std::string *out_filename, const AutoIndexObj &index) const; + + // Get the start offset of file + // @return int64_t - Start offset + int64_t GetStartOffset() const { return start_offset_; } + + // Get the end offset of the file + // @return int64_t - Start offset + int64_t GetEndOffset() const { return end_offset_; } + + private: + int64_t start_offset_; + int64_t end_offset_; +}; // class TFBlock +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc new file mode 100644 index 0000000000..0476baf56f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc @@ -0,0 +1,438 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" + +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status ManifestOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_file_, + builder_op_connector_size_, builder_decode_, builder_labels_to_read_, + std::move(builder_schema_), std::move(builder_sampler_), builder_usage_); + return Status::OK(); +} + +Status ManifestOp::Builder::SanityCheck() { + std::string err_msg; + err_msg += builder_file_.empty() ? "Manifest file is not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers smaller than 1\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, + const std::map &class_index, std::unique_ptr data_schema, + std::shared_ptr sampler, std::string usage) + : ParallelOp(num_works, queue_size, std::move(sampler)), + rows_per_buffer_(rows_per_buffer), + io_block_pushed_(0), + row_cnt_(0), + sampler_ind_(0), + data_schema_(std::move(data_schema)), + file_(file), + class_index_(class_index), + decode_(decode), + usage_(usage), + buf_cnt_(0) { + io_block_queues_.Init(num_workers_, queue_size); + (void)std::transform(usage_.begin(), usage_.end(), usage_.begin(), ::tolower); +} + +// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work +Status ManifestOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + return AddIoBlock(&sampler_buffer); +} + +Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { + while (true) { // each iterator is 1 epoch + std::vector keys; + keys.reserve(rows_per_buffer_); + while (!(*sampler_buffer)->eoe()) { + TensorRow sample_row; + RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) >= num_rows_) continue; // index out of bound, skipping + keys.push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); + } + } +} + +Status ManifestOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&ManifestOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(ParseManifestFile()); + RETURN_IF_NOT_OK(CountDatasetInfo()); + RETURN_IF_NOT_OK(InitSampler()); + return Status::OK(); +} + +// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ +// IMPORTANT: 1 IOBlock produces 1 DataBuffer +Status ManifestOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty()) { + return Status::OK(); // empty key is a quit signal for workers + } + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer +Status ManifestOp::LoadTensorRow(row_id_type row_id, const std::pair> &data, + TensorRow *trow) { + std::shared_ptr image; + std::shared_ptr label; + std::vector label_index(data.second.size()); + (void)std::transform(data.second.begin(), data.second.end(), label_index.begin(), + [this](const std::string &label_name) { return label_index_[label_name]; }); + if (label_index.size() == 1) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), TensorShape({}), + data_schema_->column(1).type(), + reinterpret_cast(&label_index[0]))); + } else { + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &label, data_schema_->column(1).tensorImpl(), TensorShape(std::vector(1, label_index.size())), + data_schema_->column(1).type(), reinterpret_cast(&label_index[0]))); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data.first)); + if (decode_ == true) { + Status rc = Decode(image, &image); + if (rc.IsError()) { + std::string err = "Fail to decode image:" + data.first; + RETURN_STATUS_UNEXPECTED(err); + } + } + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); + return Status::OK(); +} + +// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer +Status ManifestOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + for (const auto &key : keys) { + TensorRow trow; + RETURN_IF_NOT_OK(LoadTensorRow(key, image_labelname_[static_cast(key)], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +void ManifestOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows:" << num_rows_ << "\nManifest file: " << file_ << "\n\n"; + } +} + +// Reset Sampler and wakeup Master thread (functor) +Status ManifestOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); // wake up master thread after reset is done + return Status::OK(); +} + +// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows +Status ManifestOp::InitSampler() { + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +// Derived from RandomAccessOp +Status ManifestOp::GetClassIds(std::map> *cls_ids) const { + if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) { + RETURN_STATUS_UNEXPECTED("Class indexing is invalid."); + } + + for (size_t i = 0; i < image_labelname_.size(); i++) { + size_t image_index = i; + for (size_t j = 0; j < image_labelname_[image_index].second.size(); j++) { + std::string label_name = (image_labelname_[image_index].second)[j]; + int32_t label_index = label_index_.at(label_name); + (*cls_ids)[label_index].emplace_back(image_index); + } + } + + for (auto &pair : (*cls_ids)) { + pair.second.shrink_to_fit(); + } + return Status::OK(); +} + +// Manifest file content +// {"source": "/path/to/image1.jpg", "usage":"train", annotation": ...} +// {"source": "/path/to/image2.jpg", "usage":"eval", "annotation": ...} +Status ManifestOp::ParseManifestFile() { + std::ifstream file_handle(file_); + if (!file_handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Manifest file " + file_ + " can not open."); + } + std::string line; + while (getline(file_handle, line)) { + try { + nlohmann::json js = nlohmann::json::parse(line); + std::string image_file_path = js.value("source", ""); + // If image is not JPEG/PNG/GIF/BMP, drop it + bool valid = false; + RETURN_IF_NOT_OK(CheckImageType(image_file_path, &valid)); + if (!valid) { + continue; + } + std::string usage = js.value("usage", ""); + (void)std::transform(usage.begin(), usage.end(), usage.begin(), ::tolower); + if (usage != usage_) { + continue; + } + std::vector labels; + nlohmann::json annotations = js.at("annotation"); + for (nlohmann::json::iterator it = annotations.begin(); it != annotations.end(); ++it) { + nlohmann::json annotation = it.value(); + std::string label_name = annotation.value("name", ""); + if (label_name == "") { + file_handle.close(); + RETURN_STATUS_UNEXPECTED("Label name is not found in manifest file for " + image_file_path); + } + if (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) { + if (label_index_.find(label_name) == label_index_.end()) { + label_index_[label_name] = 0; + } + labels.emplace_back(label_name); + } + } + if (!labels.empty()) { + image_labelname_.emplace_back(std::make_pair(image_file_path, labels)); + } + } catch (const std::exception &err) { + file_handle.close(); + RETURN_STATUS_UNEXPECTED("Parse manifest file failed"); + } + } + file_handle.close(); + + return Status::OK(); +} + +// Only support JPEG/PNG/GIF/BMP +Status ManifestOp::CheckImageType(const std::string &file_name, bool *valid) { + std::ifstream file_handle; + constexpr int read_num = 3; + *valid = false; + file_handle.open(file_name, std::ios::binary | std::ios::in); + if (!file_handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Can not open image file " + file_name); + } + unsigned char file_type[read_num]; + (void)file_handle.read(reinterpret_cast(file_type), read_num); + + if (file_handle.fail()) { + file_handle.close(); + RETURN_STATUS_UNEXPECTED("Read image file failed " + file_name); + } + file_handle.close(); + if (file_type[0] == 0xff && file_type[1] == 0xd8 && file_type[2] == 0xff) { + // Normal JPEGs start with \xff\xd8\xff\xe0 + // JPEG with EXIF stats with \xff\xd8\xff\xe1 + // Use \xff\xd8\xff to cover both. + *valid = true; + } else if (file_type[0] == 0x89 && file_type[1] == 0x50 && file_type[2] == 0x4e) { + // It's a PNG + *valid = true; + } else if (file_type[0] == 0x47 && file_type[1] == 0x49 && file_type[2] == 0x46) { + // It's a GIF + *valid = true; + } else if (file_type[0] == 0x42 && file_type[1] == 0x4d) { + // It's a BMP + *valid = true; + } + return Status::OK(); +} + +Status ManifestOp::CountDatasetInfo() { + int32_t index = 0; + for (auto &label : label_index_) { + label.second = class_index_.empty() ? index : class_index_[label.first]; + index++; + } + + num_rows_ = static_cast(image_labelname_.size()); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API " + "validation first."); + } + return Status::OK(); +} + +Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, + int64_t *count, int64_t *numClasses) { + // the logic of counting the number of samples is copied from ParseManifestFile() + std::map map; + for (auto p : dict) { + (void)map.insert(std::pair(py::reinterpret_borrow(p.first), + py::reinterpret_borrow(p.second))); + } + + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op)); + RETURN_IF_NOT_OK(op->ParseManifestFile()); + *numClasses = static_cast(op->label_index_.size()); + *count = static_cast(op->image_labelname_.size()); + return Status::OK(); +} + +Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, + std::map *output_class_indexing) { + std::map input_class_indexing; + for (auto p : dict) { + (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), + py::reinterpret_borrow(p.second))); + } + + if (!input_class_indexing.empty()) { + *output_class_indexing = input_class_indexing; + } else { + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op)); + RETURN_IF_NOT_OK(op->ParseManifestFile()); + RETURN_IF_NOT_OK(op->CountDatasetInfo()); + uint32_t count = 0; + for (const auto label : op->label_index_) { + (*output_class_indexing).insert(std::make_pair(label.first, count)); + count++; + } + } + + return Status::OK(); +} + +// Visitor accept method for NodePass +Status ManifestOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status ManifestOp::ComputeColMap() { + // Set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h new file mode 100644 index 0000000000..bac8f04c94 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h @@ -0,0 +1,250 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +class ManifestOp : public ParallelOp, public RandomAccessOp { + public: + class Builder { + public: + // Constructor for Builder class of ManifestOp + Builder(); + + // Destructor + ~Builder() = default; + + // Setter method + // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method + // @param int32_t size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t size) { + builder_op_connector_size_ = size; + return *this; + } + + // Setter method + // @param const std::map& map - a class name to label map + // @return + Builder &SetClassIndex(const std::map &map) { + builder_labels_to_read_ = map; + return *this; + } + + // Setter method + // @param bool do_decode + // @return Builder setter method returns reference to the builder. + Builder &SetDecode(bool do_decode) { + builder_decode_ = do_decode; + return *this; + } + + // Setter method + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + // Setter method + // @param const std::string & dir + // @return Builder setter method returns reference to the builder. + Builder &SetManifestFile(const std::string &file) { + builder_file_ = file; + return *this; + } + + // Setter method + // @param const std::string & dir + // @return Builder setter method returns reference to the builder. + Builder &SetUsage(const std::string &usage) { + builder_usage_ = usage; + return *this; + } + + // Check validity of input args + // @return Status - The error code return + Status SanityCheck(); + + // The builder "build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + std::shared_ptr builder_sampler_; + bool builder_decode_; + + std::string builder_file_; + int32_t builder_num_workers_; + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::unique_ptr builder_schema_; + std::string builder_usage_; + std::map builder_labels_to_read_; + }; + + // Constructor + // @param int32_t num_works - Num of workers reading images in parallel + // @param int32_t - rows_per_buffer Number of images (rows) in each buffer + // @param std::string - file list of Manifest + // @param int32_t queue_size - connector queue size + // @param td::unique_ptr sampler - sampler tells ImageFolderOp what to read + ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, + const std::map &class_index, std::unique_ptr data_schema, + std::shared_ptr sampler, std::string usage); + // Destructor. + ~ManifestOp() = default; + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t worker_id - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Main Loop of ManifestOp + // Master thread: Fill IOBlockQueue, then goes to sleep + // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector + // @return Status - The error code return + Status operator()() override; + + // Method derived from RandomAccess Op, enable Sampler to get all ids for each class + // @param (std::map> * map - key label, val all ids for this class + // @return Status - The error code return + Status GetClassIds(std::map> *cls_ids) const override; + + // A print method typically used for debugging + // @param out + // @param show_all + void Print(std::ostream &out, bool show_all) const override; + + static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count, + int64_t *numClasses); + + // Get str-to-int mapping from label name to index + static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, + std::map *output_class_indexing); + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "ManifestOp"; } + + private: + // Initialize Sampler, calls sampler->Init() within + // @return Status - The error code return + Status InitSampler(); + + // Method in operator(), to fill IOBlockQueue + // @param std::unique_ptr sampler_buffer - to fill IOBlockQueue + // @return Status - The error code return + Status AddIoBlock(std::unique_ptr *sampler_buffer); + + // Load a tensor row according to a pair + // @param row_id_type row_id - id for this tensor row + // @param std::pair> - > + // @param TensorRow row - image & label read into this tensor row + // @return Status - The error code return + Status LoadTensorRow(row_id_type row_id, const std::pair> &data, + TensorRow *row); + + // @param const std::vector &keys - keys in ioblock + // @param std::unique_ptr db + // @return Status - The error code return + Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); + + // Parse manifest file to get image path and label and so on. + // @return Status - The error code return + Status ParseManifestFile(); + + // Called first when function is called + // @return Status - The error code return + Status LaunchThreadsAndInitOp(); + + // reset Op + // @return Status - The error code return + Status Reset() override; + + // Check if image ia valid.Only support JPEG/PNG/GIF/BMP + // @return + Status CheckImageType(const std::string &file_name, bool *valid); + + // Count label index,num rows and num samples + // @return Status - The error code return + Status CountDatasetInfo(); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t rows_per_buffer_; + int64_t io_block_pushed_; + int64_t row_cnt_; + int64_t sampler_ind_; + std::unique_ptr data_schema_; + std::string file_; // file that store the information of images + std::map class_index_; + bool decode_; + std::string usage_; + int64_t buf_cnt_; + + WaitPost wp_; + QueueList> io_block_queues_; + std::map label_index_; + std::vector>> image_labelname_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc new file mode 100644 index 0000000000..cf1493eb78 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -0,0 +1,513 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" + +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +using mindrecord::kInt64Len; +using mindrecord::MSRStatus; +using mindrecord::Schema; +using mindrecord::ShardOperator; +using mindrecord::ShardReader; + +// Builder constructor. Creates the builder object. +MindRecordOp::Builder::Builder() : build_dataset_file_({}) { + // Some arguments to the MindRecordOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the MindRecordOp by + // using the various builder set methods. + + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_mind_record_workers_ = kDefaultMindRecordWorkers; + build_rows_per_buffer_ = cfg->rows_per_buffer(); + build_op_connector_queue_size_ = cfg->op_connector_size(); + build_block_reader_ = false; + builder_num_workers_ = 0; + build_num_padded_ = 0; + build_sample_ = nullptr; +} + +// The builder "build" method creates the final object. +Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { + std::shared_ptr new_mind_record_op; + + if (build_dataset_file_.empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Building a MindRecordOp that has not provided a file."); + } + mindrecord::json sample_json; + if (build_num_padded_ > 0) { + sample_json = ToJson(build_sample_); + } + new_mind_record_op = std::make_shared( + build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_, + build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_, build_num_padded_, + sample_json, build_sample_bytes_); + + RETURN_IF_NOT_OK(new_mind_record_op->Init()); + *ptr = std::move(new_mind_record_op); + return Status::OK(); +} + +Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); } + +mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) { + if (obj.is_none()) { + return nullptr; + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { // also catch py::bytes + return obj.cast(); + } + if (py::isinstance(obj)) { + auto out = mindrecord::json::object(); + for (const py::handle &key : obj) { + if (py::isinstance(obj[key])) { + build_sample_bytes_[py::str(key).cast()] = obj[key].cast(); + } else { + out[py::str(key).cast()] = ToJson(obj[key]); + } + } + return out; + } + MS_LOG(ERROR) << "Python object convert to json failed, object is: " << py::cast(obj); + return mindrecord::json(); +} + +// Constructor of the MindRecordOp. +MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, + std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, + const std::vector &columns_to_load, + const std::vector> &operators, const bool &block_reader, + int64_t num_padded, const mindrecord::json &sample_json, + const std::map &sample_bytes) + : ParallelOp(num_mind_record_workers, op_connector_queue_size), + rows_per_buffer_(rows_per_buffer), + dataset_file_(dataset_file), + load_dataset_(load_dataset), + columns_to_load_(columns_to_load), + operators_(operators), + num_mind_record_workers_(num_mind_record_workers), + block_reader_(block_reader), + num_rows_(0), + buffers_needed_(0), + buf_cnt_(0), + ended_worker_(0), + buffer_water_mark_(0), + num_padded_(num_padded), + sample_json_(sample_json), + sample_bytes_(sample_bytes) { + io_blk_queues_.Init(num_workers_, op_connector_queue_size); + if (!block_reader_) return; + for (int32_t i = 0; i < num_workers_; ++i) { + block_buffer_.emplace_back(std::make_unique>(std::vector{})); + } +} + +// Private helper method to encapsulate some common construction/reset tasks +Status MindRecordOp::Init() { + shard_reader_ = std::make_unique(); + auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, + block_reader_, num_padded_); + + CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, + "MindRecordOp init failed. Error message: " + ErrnoToMessage(rc)); + + data_schema_ = std::make_unique(); + + std::vector col_names = shard_reader_->GetShardColumn()->GetColumnName(); + CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found"); + std::vector col_data_types = shard_reader_->GetShardColumn()->GeColumnDataType(); + std::vector> col_shapes = shard_reader_->GetShardColumn()->GetColumnShape(); + + bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything + std::map colname_to_ind; + for (uint32_t i = 0; i < col_names.size(); i++) { + std::string colname = col_names[i]; + ColDescriptor col_desc; + + TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown + std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]]; + DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"} + + if (col_data_types[i] == mindrecord::ColumnBytes) { // rank = 1 + col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1); + } else if (col_data_types[i] == mindrecord::ColumnString) { // rank = 0 + col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 0); + } else if (col_shapes[i].size() > 0) { + std::vector vec(col_shapes[i].size()); // temporary vector to hold shape + (void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin()); + t_shape = TensorShape(vec); + col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); + } else { // unknown shape + // create colDesc and add it to schema + col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); + } + + colname_to_ind[colname] = data_schema_->NumColumns(); + RETURN_IF_NOT_OK(data_schema_->AddColumn(col_desc)); + + if (load_all_cols) { + columns_to_load_.emplace_back(colname); + } + } + + if (!load_all_cols) { + std::unique_ptr tmp_schema = std::make_unique(); + for (std::string colname : columns_to_load_) { + CHECK_FAIL_RETURN_UNEXPECTED(colname_to_ind.find(colname) != colname_to_ind.end(), colname + ": doesn't exist"); + RETURN_IF_NOT_OK(tmp_schema->AddColumn(data_schema_->column(colname_to_ind[colname]))); + } + data_schema_ = std::move(tmp_schema); + } + + return Status::OK(); +} + +// Destructor +MindRecordOp::~MindRecordOp() {} + +// A print method typically used for debugging +void MindRecordOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\n Dataset file : "; + for (auto &file : dataset_file_) { + out << file << " "; + } + out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_ + << "\nNumber of buffers : " << buffers_needed_ + << "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n"; + } +} + +Status MindRecordOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe()) { + RETURN_IF_NOT_OK( + out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + continue; + } + if (io_block->eof()) { + RETURN_IF_NOT_OK( + out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + continue; + } + + // load data buffer + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) { + { + std::unique_lock lock(ended_worker_mutex_); + ended_worker_++; + if (ended_worker_ == num_workers_) shard_reader_->Close(); + } + return Status::OK(); // empty key is a quit signal for workers + } + + const uint64_t buffer_id = keys[0]; + std::unique_ptr fetched_buffer; + + // Get the next buffer. Push it up to the output connector. + if (buffer_id % LOG_INTERVAL == 0) { + MS_LOG(DEBUG) << "MindRecord operator consumed buffer " << buffer_id << " by worker " << worker_id << "."; + } + RETURN_IF_NOT_OK(GetBufferFromReader(&fetched_buffer, buffer_id, worker_id)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(fetched_buffer))); + if (!block_reader_) { + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + continue; + } + + // update block-reader buffer + block_buffer_[buffer_id % num_workers_]->clear(); + { + std::unique_lock lck(mtx_block_reader_); + if (buffer_id == buffer_water_mark_) { + buffer_water_mark_++; + while (block_set_.count(buffer_water_mark_) > 0) (void)block_set_.erase(buffer_water_mark_++); + } else { + (void)block_set_.insert(buffer_id); + } + } + cv_reader_.notify_one(); + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, + int32_t worker_id) { + *fetched_buffer = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + std::unique_ptr tensor_table = std::make_unique(); + for (int32_t i = 0; i < rows_per_buffer_; ++i) { + ShardTuple tupled_buffer; + mindrecord::TaskType task_type = mindrecord::TaskType::kCommonTask; + if (block_reader_) { + if (i >= block_buffer_[buffer_id % num_workers_]->size()) break; + tupled_buffer = block_buffer_[buffer_id % num_workers_]->at(i); + } else { + int32_t row_id = buffer_id * rows_per_buffer_ + i; + auto rc = shard_reader_->GetNextById(row_id, worker_id); + task_type = rc.first; + tupled_buffer = rc.second; + if (task_type == mindrecord::TaskType::kPaddedTask) { + TensorRow tensor_row; + RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, {}, mindrecord::json(), task_type)); + tensor_table->push_back(std::move(tensor_row)); + } + if (tupled_buffer.empty()) break; + } + if (task_type == mindrecord::TaskType::kCommonTask) { + for (const auto &tupled_row : tupled_buffer) { + std::vector columns_blob = std::get<0>(tupled_row); + mindrecord::json columns_json = std::get<1>(tupled_row); + TensorRow tensor_row; + RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json, task_type)); + tensor_table->push_back(std::move(tensor_row)); + } + } + } + + // Replace the TensorTable in DataBuffer with the new one. + (*fetched_buffer)->set_tensor_table(std::move(tensor_table)); + return Status::OK(); +} + +Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, + const mindrecord::json &columns_json, const mindrecord::TaskType task_type) { + for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) { + auto column_name = columns_to_load_[i_col]; + + // Initialize column parameters + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0; + mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType; + uint64_t column_data_type_size = 1; + std::vector column_shape; + + // Get column data + auto shard_column = shard_reader_->GetShardColumn(); + if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { + auto rc = + shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape); + if (rc.first != MSRStatus::SUCCESS) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve data type."); + } + if (rc.second == mindrecord::ColumnInRaw) { + auto has_column = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes); + if (has_column == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve raw data from padding sample."); + } + } else if (rc.second == mindrecord::ColumnInBlob) { + if (sample_bytes_.find(column_name) == sample_bytes_.end()) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve blob data from padding sample."); + } + std::string ss(sample_bytes_[column_name]); + n_bytes = ss.size(); + data_ptr = std::make_unique(n_bytes); + std::copy(ss.begin(), ss.end(), data_ptr.get()); + } else { + RETURN_STATUS_UNEXPECTED("Retrieved data type is unknown."); + } + if (data == nullptr) { + data = reinterpret_cast(data_ptr.get()); + } + } else { + auto has_column = + shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, + &column_data_type, &column_data_type_size, &column_shape); + if (has_column == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader."); + } + } + + std::shared_ptr tensor; + const ColDescriptor &column = data_schema_->column(i_col); + DataType type = column.type(); + + // Set shape + auto num_elements = n_bytes / column_data_type_size; + if (type == DataType::DE_STRING) { + std::string s{data, data + n_bytes}; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {s}, TensorShape::CreateScalar())); + } else if (column.hasShape()) { + auto new_shape = TensorShape(column.shape()); + RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast(num_elements), &new_shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + } else { + std::vector shapeDetails = {static_cast(num_elements)}; + auto new_shape = TensorShape(shapeDetails); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + } + tensor_row->push_back(std::move(tensor)); + } + return Status::OK(); +} + +Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { + { + std::unique_lock lck(mtx_block_reader_); + cv_reader_.wait(lck, [buffer_id, this] { return buffer_id < buffer_water_mark_ + num_workers_; }); + } + for (int32_t i = 0; i < rows_per_buffer_; i++) { + // Block reader does NOT care about argument + auto rc = shard_reader_->GetNextById(i, i); + ShardTuple tuple_buffer = rc.second; + if (tuple_buffer.empty()) break; + block_buffer_[buffer_id % num_workers_]->push_back(std::move(tuple_buffer)); + } + return Status::OK(); +} + +// Class functor operator () override. +// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will +// provide the master loop that drives the logic for performing the work +// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work +Status MindRecordOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadAndInitOp()); + num_rows_ = shard_reader_->GetNumRows(); + // Compute how many buffers we would need to accomplish rowsPerBuffer + buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; + + while (true) { // each iterator is 1 epoch + for (int32_t i = 0; i < buffers_needed_; ++i) { + if (block_reader_) RETURN_IF_NOT_OK(FetchBlockBuffer(i)); + std::vector keys(1, i); + RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + RETURN_IF_NOT_OK( + io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK( + io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK(io_blk_queues_[i]->Add( + std::move(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone)))); + } + return Status::OK(); + } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset + RETURN_IF_NOT_OK( + io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + + // reset our buffer count and go to loop again. + RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); + shard_reader_wait_post_.Clear(); + } + } +} + +// Overrides base class reset method. When an operator does a reset, it cleans up any state +// info from it's previous execution and then initializes itself so that it can be executed +// again. +Status MindRecordOp::Reset() { + RETURN_IF_NOT_OK(ParallelOp::Reset()); // Call our super class reset first. + + if (block_reader_) { + shard_reader_->Reset(); + buffer_water_mark_ = 0; + } else { + shard_reader_->ShuffleTask(); + } + shard_reader_wait_post_.Set(); + + return Status::OK(); +} + +Status MindRecordOp::LaunchThreadAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + + RETURN_IF_NOT_OK(io_blk_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(shard_reader_wait_post_.Register(tree_->AllTasks())); + if (shard_reader_->Launch(!block_reader_) == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); + } + // Launch main workers that load DataBuffers by reading all images + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + return Status::OK(); +} + +Status MindRecordOp::CountTotalRows(const std::vector dataset_path, bool load_dataset, + const std::shared_ptr &op, int64_t *count, int64_t num_padded) { + std::unique_ptr shard_reader = std::make_unique(); + MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded); + if (rc == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); + } + return Status::OK(); +} + +// Visitor accept method for NodePass +Status MindRecordOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status MindRecordOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + for (int i = 0; i < static_cast(columns_to_load_.size()); i++) { + column_name_id_map_[columns_to_load_[i]] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h new file mode 100644 index 0000000000..367505b172 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h @@ -0,0 +1,276 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_column.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +// Forward declares +template +class Queue; +class DataBuffer; + +using mindrecord::ShardOperator; +using mindrecord::ShardReader; +using ShardTuple = std::vector, mindrecord::json>>; // Row of data from ShardReader + +const int32_t LOG_INTERVAL = 19; + +class MindRecordOp : public ParallelOp { + public: + // The nested builder class inside of the MindRecordOp is used to help manage all of the arguments + // for constructing it. Use the builder by setting each argument with the provided set methods, + // and then finally call the build method to execute the actual construction. + class Builder { + public: + Builder(); + + ~Builder() = default; + + Status Build(std::shared_ptr *); + + Builder &SetRowsPerBuffer(int rows_per_buffer) { + build_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + Builder &SetNumMindRecordWorkers(int32_t num_mind_record_workers) { + build_num_mind_record_workers_ = num_mind_record_workers; + return *this; + } + + Builder &SetOpConnectorQueueSize(int32_t queue_size) { + build_op_connector_queue_size_ = queue_size; + return *this; + } + + Builder &SetDatasetFile(const std::vector &files) { + build_dataset_file_ = files; + return *this; + } + + Builder &SetColumnsToLoad(const std::vector &columns) { + build_columns_to_load_ = columns; + return *this; + } + + Builder &SetOperators(const std::vector> &operators) { + build_operators_ = operators; + return *this; + } + + Builder &SetBlockReader() { + build_block_reader_ = true; + return *this; + } + + Builder &SetLoadDataset(bool load_dataset) { + build_load_dataset_ = load_dataset; + return *this; + } + + Builder &SetNumToPadSamples(int64_t num_padded) { + build_num_padded_ = num_padded; + return *this; + } + + Builder &SetPaddedSample(const py::handle &sample) { + build_sample_ = sample; + return *this; + } + + Status SanityCheck() const; + + static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; } + + mindrecord::json ToJson(const py::handle &obj); + + private: + static constexpr int32_t kDefaultMindRecordWorkers = 4; + // The builder saves all MindRecordOp construction arguments internally. + // The following are the arguments. + int32_t build_num_mind_record_workers_; + int32_t builder_num_workers_; + int32_t build_rows_per_buffer_; + int32_t build_op_connector_queue_size_; + std::vector build_dataset_file_; + bool build_load_dataset_; + std::vector build_columns_to_load_; + std::vector> build_operators_; + bool build_block_reader_; + int64_t build_num_padded_; + py::handle build_sample_; + std::map build_sample_bytes_; + }; + + // Constructor of the MindRecordOp. + // @note The builder class should be used to call it + // @param num_mind_record_workers - The number of workers for the op (run by ShardReader) + // @param rows_per_buffer - The requested number of rows per buffer + // @param dataset_file - dataset files + // @param op_connector_queue_size - The output connector queue size + // @param columns_to_load - The list of columns to use (column name) + // @param operators - ShardOperators for Shuffle, Category, Sample + MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, + bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, + const std::vector> &operators, const bool &block_reader, + int64_t num_padded_, const mindrecord::json &sample_json, + const std::map &sample_bytes_); + + // Destructor + ~MindRecordOp() override; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param op - reference to the MindRecordOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const MindRecordOp &op) { + op.Print(out, false); + return out; + } + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t workerId - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Class functor operator () override. + // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work. + // @return Status - The error code return + Status operator()() override; + + // Called first when function is called + // @return + Status LaunchThreadAndInitOp(); + + // Overrides base class reset method. When an operator does a reset, it cleans up any state + // info from it's previous execution and then initializes itself so that it can be executed + // again. + // @return Status - The error code return + Status Reset() override; + + // Getter method + int32_t num_rows() const { return num_rows_; } + + static Status CountTotalRows(const std::vector dataset_path, bool load_dataset, + const std::shared_ptr &op, int64_t *count, int64_t num_padded); + + // Getter method + int32_t rows_per_buffer() const { return rows_per_buffer_; } + + // Getter method + std::vector dataset_file() const { return dataset_file_; } + + // Getter method + std::vector columns_to_load() const { return columns_to_load_; } + + bool block_reader() const { return block_reader_; } + + bool load_dataset() const { return load_dataset_; } + + Status Init(); + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "MindRecordOp"; } + + private: + Status GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, int32_t worker_id); + + // Parses a single cell and puts the data into a tensor + // @param tensor_row - the tensor row to put the parsed data in + // @param columns_blob - the blob data received from the reader + // @param columns_json - the data for fields received from the reader + Status LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, + const mindrecord::json &columns_json, const mindrecord::TaskType task_type); + + Status FetchBlockBuffer(const int32_t &buffer_id); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t rows_per_buffer_; // The number of requested rows per buffer. + std::vector dataset_file_; // dataset files + bool load_dataset_; // load dataset from single file or not + std::vector columns_to_load_; // Columns to load from dataset + std::vector> operators_; // ShardOperators to use + int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader + bool block_reader_; // block reader switch + int32_t buffers_needed_; // Counter for the buffers that were fetched + int64_t buf_cnt_; // Buffer counter + int32_t num_rows_; // One more than the last row id in the range for this cache + std::atomic ended_worker_; + std::atomic buffer_water_mark_; + + int64_t num_padded_; + mindrecord::json sample_json_; + std::map sample_bytes_; + + std::unique_ptr data_schema_; // Data schema for column typing + std::vector columns_blob_; // Blob Columns to load from dataset + std::vector columns_blob_index_; // Blob Columns to load from dataset + + std::unique_ptr shard_reader_; + WaitPost shard_reader_wait_post_; + QueueList> io_blk_queues_; + + // For block reader + std::mutex mtx_block_reader_; + std::condition_variable cv_reader_; + std::vector>> block_buffer_; + std::unordered_set block_set_; + + std::mutex ended_worker_mutex_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc new file mode 100644 index 0000000000..11ad18865e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -0,0 +1,450 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" + +#include +#include +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +const int32_t kMnistImageFileMagicNumber = 2051; +const int32_t kMnistLabelFileMagicNumber = 2049; +const int32_t kMnistImageRows = 28; +const int32_t kMnistImageCols = 28; + +MnistOp::Builder::Builder() : builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status MnistOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, + builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_)); + return Status::OK(); +} + +Status MnistOp::Builder::SanityCheck() { + Path dir(builder_dir_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, + std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_workers, queue_size, std::move(sampler)), + buf_cnt_(0), + row_cnt_(0), + folder_path_(folder_path), + rows_per_buffer_(rows_per_buffer), + data_schema_(std::move(data_schema)) { + io_block_queues_.Init(num_workers, queue_size); +} + +Status MnistOp::TraversalSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) >= num_rows_) continue; // index out of bound, skipping + keys->push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); + keys->clear(); + } + } + return Status::OK(); +} + +// functor that contains the main logic of MNIST op +Status MnistOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { // each iterator is 1 epoch + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + std::shared_ptr sample_ids; + RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); + if (sample_ids->type() != DataType(DataType::DE_INT64)) { + RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't UINT64"); + } + RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + for (int32_t i = 0; i < num_workers_; ++i) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +// contains the logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ +Status MnistOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr iOBlock; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); + while (iOBlock != nullptr) { + if (iOBlock->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (iOBlock->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(iOBlock->GetKeys(&keys)); + if (keys.empty() == true) return Status::OK(); // empty key is a quit signal for workers + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +// Load 1 TensorRow (image,label) using 1 MnistLabelPair. +Status MnistOp::LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *trow) { + std::shared_ptr image, label; + int32_t l = mnist_pair.second; + // make a copy of cached tensor + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), mnist_pair.first->shape(), + mnist_pair.first->type(), mnist_pair.first->GetBuffer())); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), + data_schema_->column(1).type(), reinterpret_cast(&l))); + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); + return Status::OK(); +} + +// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer +Status MnistOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + TensorRow trow; + for (const int64_t &key : keys) { + RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_label_pairs_[key], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +void MnistOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows:" << num_rows_ << "\nMNIST Directory: " << folder_path_ << "\n\n"; + } +} + +// Reset Sampler and wakeup Master thread (functor) +Status MnistOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); // wake up master thread after reset is done + return Status::OK(); +} + +// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows +Status MnistOp::InitSampler() { + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +// Derived from RandomAccessOp +Status MnistOp::GetClassIds(std::map> *cls_ids) const { + if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { + RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); + } + for (size_t i = 0; i < image_label_pairs_.size(); ++i) { + (*cls_ids)[image_label_pairs_[i].second].push_back(i); + } + for (auto &pair : (*cls_ids)) { + pair.second.shrink_to_fit(); + } + return Status::OK(); +} + +Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) { + uint32_t res = 0; + reader->read(reinterpret_cast(&res), 4); + if (reader->fail()) { + RETURN_STATUS_UNEXPECTED("Failed to read 4 bytes from file"); + } + *result = SwapEndian(res); + return Status::OK(); +} + +uint32_t MnistOp::SwapEndian(uint32_t val) const { + val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); + return (val << 16) | (val >> 16); +} + +Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) { + if (image_reader->is_open() == false) { + RETURN_STATUS_UNEXPECTED("Cannot open mnist image file: " + file_name); + } + int64_t image_len = image_reader->seekg(0, std::ios::end).tellg(); + (void)image_reader->seekg(0, std::ios::beg); + // The first 16 bytes of the image file are type, number, row and column + if (image_len < 16) { + RETURN_STATUS_UNEXPECTED("Mnist file is corrupted."); + } + uint32_t magic_number; + RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number)); + CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber, + "This is not the mnist image file: " + file_name); + + uint32_t num_items; + RETURN_IF_NOT_OK(ReadFromReader(image_reader, &num_items)); + uint32_t rows; + RETURN_IF_NOT_OK(ReadFromReader(image_reader, &rows)); + uint32_t cols; + RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols)); + // The image size of the Mnist dataset is fixed at [28,28] + if ((rows != kMnistImageRows) || (cols != kMnistImageCols)) { + RETURN_STATUS_UNEXPECTED("Wrong shape of image."); + } + if ((image_len - 16) != num_items * rows * cols) { + RETURN_STATUS_UNEXPECTED("Wrong number of image."); + } + *num_images = num_items; + return Status::OK(); +} + +Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) { + if (label_reader->is_open() == false) { + RETURN_STATUS_UNEXPECTED("Cannot open mnist label file: " + file_name); + } + int64_t label_len = label_reader->seekg(0, std::ios::end).tellg(); + (void)label_reader->seekg(0, std::ios::beg); + // The first 8 bytes of the image file are type and number + if (label_len < 8) { + RETURN_STATUS_UNEXPECTED("Mnist file is corrupted."); + } + uint32_t magic_number; + RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number)); + CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber, + "This is not the mnist label file: " + file_name); + uint32_t num_items; + RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items)); + if ((label_len - 8) != num_items) { + RETURN_STATUS_UNEXPECTED("Wrong number of labels!"); + } + *num_labels = num_items; + return Status::OK(); +} + +Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) { + uint32_t num_images, num_labels; + RETURN_IF_NOT_OK(CheckImage(image_names_[index], image_reader, &num_images)); + RETURN_IF_NOT_OK(CheckLabel(label_names_[index], label_reader, &num_labels)); + CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), "num_images != num_labels"); + // The image size of the Mnist dataset is fixed at [28,28] + int64_t size = kMnistImageRows * kMnistImageCols; + auto images_buf = std::make_unique(size * num_images); + auto labels_buf = std::make_unique(num_images); + if (images_buf == nullptr || labels_buf == nullptr) { + std::string err_msg = "Fail to allocate memory for MNIST Buffer."; + MS_LOG(ERROR) << err_msg.c_str(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)image_reader->read(images_buf.get(), size * num_images); + if (image_reader->fail()) { + RETURN_STATUS_UNEXPECTED("Fail to read:" + image_names_[index] + " size:" + std::to_string(size * num_images)); + } + (void)label_reader->read(labels_buf.get(), num_images); + if (label_reader->fail()) { + RETURN_STATUS_UNEXPECTED("Fail to read:" + label_names_[index] + " size: " + std::to_string(num_images)); + } + TensorShape img_tensor_shape = TensorShape({kMnistImageRows, kMnistImageCols, 1}); + for (int64_t j = 0; j != num_images; ++j) { + auto pixels = &images_buf[j * size]; + for (int64_t m = 0; m < size; ++m) { + pixels[m] = (pixels[m] == 0) ? 0 : 255; + } + std::shared_ptr image; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), img_tensor_shape, + data_schema_->column(0).type(), reinterpret_cast(pixels))); + image_label_pairs_.emplace_back(std::make_pair(image, labels_buf[j])); + } + return Status::OK(); +} + +Status MnistOp::ParseMnistData() { + for (size_t i = 0; i < image_names_.size(); ++i) { + std::ifstream image_reader, label_reader; + image_reader.open(image_names_[i], std::ios::binary); + label_reader.open(label_names_[i], std::ios::binary); + + Status s = ReadImageAndLabel(&image_reader, &label_reader, i); + // Close the readers + image_reader.close(); + label_reader.close(); + RETURN_IF_NOT_OK(s); + } + image_label_pairs_.shrink_to_fit(); + num_rows_ = image_label_pairs_.size(); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API " + "validation first."); + } + return Status::OK(); +} + +Status MnistOp::WalkAllFiles() { + const std::string kImageExtension = "idx3-ubyte"; + const std::string kLabelExtension = "idx1-ubyte"; + + Path dir(folder_path_); + auto dir_it = Path::DirIterator::OpenDirectory(&dir); + if (dir_it != nullptr) { + while (dir_it->hasNext()) { + Path file = dir_it->next(); + std::string filename = file.toString(); + if (filename.find(kImageExtension) != std::string::npos) { + image_names_.push_back(filename); + MS_LOG(INFO) << "Mnist operator found image file at " << filename << "."; + } else if (filename.find(kLabelExtension) != std::string::npos) { + label_names_.push_back(filename); + MS_LOG(INFO) << "Mnist Operator found label file at " << filename << "."; + } + } + } else { + MS_LOG(WARNING) << "Mnist operator unable to open directory " << dir.toString() << "."; + } + + std::sort(image_names_.begin(), image_names_.end()); + std::sort(label_names_.begin(), label_names_.end()); + + if (image_names_.size() != label_names_.size()) { + RETURN_STATUS_UNEXPECTED("num of images does not equal to num of labels"); + } + + return Status::OK(); +} + +Status MnistOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&MnistOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(this->WalkAllFiles()); + RETURN_IF_NOT_OK(this->ParseMnistData()); + RETURN_IF_NOT_OK(this->InitSampler()); // handle shake with sampler + return Status::OK(); +} + +Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) { + // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader() + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op)); + + RETURN_IF_NOT_OK(op->WalkAllFiles()); + + for (size_t i = 0; i < op->image_names_.size(); ++i) { + std::ifstream image_reader; + image_reader.open(op->image_names_[i], std::ios::binary); + std::ifstream label_reader; + label_reader.open(op->label_names_[i], std::ios::binary); + + uint32_t num_images; + RETURN_IF_NOT_OK(op->CheckImage(op->image_names_[i], &image_reader, &num_images)); + uint32_t num_labels; + RETURN_IF_NOT_OK(op->CheckLabel(op->label_names_[i], &label_reader, &num_labels)); + CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), "num of images does not equal to num of labels"); + *count = *count + num_images; + + // Close the readers + image_reader.close(); + label_reader.close(); + } + + return Status::OK(); +} + +// Visitor accept method for NodePass +Status MnistOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status MnistOp::ComputeColMap() { + // set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h new file mode 100644 index 0000000000..039f6b112f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h @@ -0,0 +1,252 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +// Forward declares +template +class Queue; + +using MnistLabelPair = std::pair, int32_t>; + +class MnistOp : public ParallelOp, public RandomAccessOp { + public: + class Builder { + public: + // Constructor for Builder class of MnistOp + // @param uint32_t numWrks - number of parallel workers + // @param dir - directory folder got ImageNetFolder + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method + // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + // Setter method + // @param const std::string & dir + // @return + Builder &SetDir(const std::string &dir) { + builder_dir_ = dir; + return *this; + } + + // Check validity of input args + // @return - The error code return + Status SanityCheck(); + + // The builder "Build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + std::string builder_dir_; + int32_t builder_num_workers_; + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::shared_ptr builder_sampler_; + std::unique_ptr builder_schema_; + }; + + // Constructor + // @param int32_t num_workers - number of workers reading images in parallel + // @param int32_t rows_per_buffer - number of images (rows) in each buffer + // @param std::string folder_path - dir directory of mnist + // @param int32_t queue_size - connector queue size + // @param std::unique_ptr data_schema - the schema of the mnist dataset + // @param td::unique_ptr sampler - sampler tells MnistOp what to read + MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, + std::unique_ptr data_schema, std::shared_ptr sampler); + + // Destructor. + ~MnistOp() = default; + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t worker_id - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Main Loop of MnistOp + // Master thread: Fill IOBlockQueue, then goes to sleep + // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector + // @return Status - The error code return + Status operator()() override; + + // Method derived from RandomAccess Op, enable Sampler to get all ids for each class + // @param (std::map> * map - key label, val all ids for this class + // @return Status - The error code return + Status GetClassIds(std::map> *cls_ids) const override; + + // A print method typically used for debugging + // @param out + // @param show_all + void Print(std::ostream &out, bool show_all) const override; + + // Function to count the number of samples in the MNIST dataset + // @param dir path to the MNIST directory + // @param count output arg that will hold the minimum of the actual dataset size and numSamples + // @return + static Status CountTotalRows(const std::string &dir, int64_t *count); + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "MnistOp"; } + + private: + // Initialize Sampler, calls sampler->Init() within + // @return Status - The error code return + Status InitSampler(); + + // Load a tensor row according to a pair + // @param row_id_type row_id - id for this tensor row + // @param ImageLabelPair pair - + // @param TensorRow row - image & label read into this tensor row + // @return Status - The error code return + Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row); + + // @param const std::vector &keys - keys in ioblock + // @param std::unique_ptr db + // @return Status - The error code return + Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); + + // Iterate through all members in sampleIds and fill them into IOBlock. + // @param std::shared_ptr sample_ids - + // @param std::vector *keys - keys in ioblock + // @return Status - The error code return + Status TraversalSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); + + // Check image file stream. + // @param const std::string *file_name - image file name + // @param std::ifstream *image_reader - image file stream + // @param uint32_t num_images - returns the number of images + // @return Status - The error code return + Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); + + // Check label stream. + // @param const std::string &file_name - label file name + // @param std::ifstream *label_reader - label file stream + // @param uint32_t num_labels - returns the number of labels + // @return Status - The error code return + Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); + + // Read 4 bytes of data from a file stream. + // @param std::ifstream *reader - file stream to read + // @return uint32_t - read out data + Status ReadFromReader(std::ifstream *reader, uint32_t *result); + + // Swap endian + // @param uint32_t val - + // @return uint32_t - swap endian data + uint32_t SwapEndian(uint32_t val) const; + + // Read the specified number of images and labels from the file stream + // @param std::ifstream *image_reader - image file stream + // @param std::ifstream *label_reader - label file stream + // @param int64_t read_num - number of image to read + // @return Status - The error code return + Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); + + // Parse all mnist dataset files + // @return Status - The error code return + Status ParseMnistData(); + + // Read all files in the directory + // @return Status - The error code return + Status WalkAllFiles(); + + // Called first when function is called + // @return Status - The error code return + Status LaunchThreadsAndInitOp(); + + // reset Op + // @return Status - The error code return + Status Reset() override; + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int64_t buf_cnt_; + int64_t row_cnt_; + WaitPost wp_; + std::string folder_path_; // directory of image folder + int32_t rows_per_buffer_; + std::unique_ptr data_schema_; + std::vector image_label_pairs_; + std::vector image_names_; + std::vector label_names_; + QueueList> io_block_queues_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc new file mode 100644 index 0000000000..46f3adfa62 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -0,0 +1,426 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include +#include +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +RandomDataOp::Builder::Builder() + : builder_data_schema_(nullptr), + builder_num_workers_(0), + builder_op_connector_size_(0), + builder_rows_per_buffer_(0), + builder_total_rows_(0), + builder_sampler_(nullptr) { + // Some arguments to the RandomDataOp have a default argument that is taken from the config. + // The user may override these defaults by using the builder set methods. + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +// The build method that produces the instantiated RandomDataOp as a shared pointer +Status RandomDataOp::Builder::Build(std::shared_ptr *out_op) { + RETURN_IF_NOT_OK(SanityCheck()); + + *out_op = + std::make_shared(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, + builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_)); + + // If the user did not provide a schema, then we will ask the op to generate a pseudo-random + // schema. + // See details of generateSchema function to learn what type of schema it will create. + if ((*out_op)->data_schema_ == nullptr) { + RETURN_IF_NOT_OK((*out_op)->GenerateSchema()); + } + + return Status::OK(); +} + +// Check if the required parameters are set by the builder. +Status RandomDataOp::Builder::SanityCheck() const { + // There actually is no required arguments for the random data op at all. + // Some arguments are preset with global values from config, and if they are not given by the user + // then we create them randomly. Leaving this function here for consistency with other operators. + return Status::OK(); +} + +// Constructor for RandomDataOp +RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, + std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + buffer_id_(0), + rows_per_buffer_(rows_per_buffer), + total_rows_(total_rows), + epoch_buffers_sent_(0), + guys_in_(0), + guys_out_(num_workers_), + eoe_worker_id_(0), + data_schema_(std::move(data_schema)) { + rand_gen_.seed(GetSeed()); // seed the random generator + // If total rows was not given, then randomly pick a number + if (total_rows_ == 0) { + total_rows_ = GenRandomInt(1, kMaxTotalRows); + } + // Everyone is already out from the sync area. + all_out_.Set(); +} + +// A print method typically used for debugging +void RandomDataOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [total rows: " << total_rows_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nTotal_rows: " << total_rows_ << "\nRows per buffer: " << rows_per_buffer_ << "\nSchema:\n" + << *data_schema_ << "\n\n"; + } +} + +// Helper function to produce a default/random schema if one didn't exist +Status RandomDataOp::GenerateSchema() { + if (data_schema_ != nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Generating a schema but one already exists!"); + } + + // To randomly create a schema, we need to choose: + // a) how many columns + // b) the type of each column + // c) the shape of each column (number of dimensions i.e. rank) + // d) the shape of each column (dimension values) + data_schema_ = std::make_unique(); + std::unique_ptr newShape; + std::unique_ptr newCol; + + // Loop over the number of chosen columns + int32_t numColumns = GenRandomInt(1, kMaxNumColumns); + for (int32_t i = 0; i < numColumns; i++) { + // For each column: + // - choose a datatype + // - generate a shape that randomly chooses the number of dimensions and the dimension values. + DataType::Type newType = static_cast(GenRandomInt(1, DataType::NUM_OF_TYPES - 2)); + int32_t rank = GenRandomInt(1, kMaxRank); + std::vector dims; + for (int32_t d = 0; d < rank; d++) { + // 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random + // 0 value to the unknown attribute if 0 is chosen + dsize_t dim_value = static_cast(GenRandomInt(0, kMaxDimValue)); + if (dim_value == 0) dim_value = TensorShape::kDimUnknown; + dims.push_back(dim_value); + } + newShape = std::make_unique(dims); + + // Create the column descriptor + std::string colName = "c" + std::to_string(i); + newCol = std::make_unique(colName, DataType(newType), TensorImpl::kFlexible, rank, newShape.get()); + + data_schema_->AddColumn(*newCol); + } + + return Status::OK(); +} + +// Class functor operator () override. +// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will +// provide the master loop that drives the logic for performing the work. +Status RandomDataOp::operator()() { + // First, compute how many buffers we'll need to satisfy the total row count. + // The only reason we do this is for the purpose of throttling worker count if needed. + int64_t buffers_needed = total_rows_ / rows_per_buffer_; + if (total_rows_ % rows_per_buffer_ != 0) { + buffers_needed++; + } + + // If the amount of workers we have exceeds the number of buffers to produce, then we'll have + // idle workers doing nothing. In that case, let's throttle the worker count. + if (num_workers_ > buffers_needed) { + MS_LOG(INFO) << "RandomDataOp throttling worker count from " << num_workers_ << "to " << buffers_needed; + num_workers_ = buffers_needed; + num_producers_ = num_workers_; + guys_out_ = num_workers_; + // The output connector was already created with a different worker count. We have to drop and recreate + // that connector. + DatasetOp::CreateConnector(num_producers_, num_workers_); + } + + // Assign the number of rows to each worker in a round robin fashion. + worker_max_rows_.reserve(num_workers_); + worker_rows_packed_.reserve(num_workers_); + // init the counts to zero to start. + for (int32_t w = 0; w < num_workers_; w++) { + worker_max_rows_.push_back(0); + worker_rows_packed_.push_back(0); + } + // then assign round robin row counts + int32_t currentWorker = 0; + for (int64_t r = 0; r < total_rows_; r++) { + worker_max_rows_[currentWorker]++; + currentWorker = (currentWorker + 1) % num_workers_; + } + + // Next, compute the total buffer count. This stat is needed during reset logic + for (int32_t w = 0; w < num_workers_; w++) { + int64_t worker_buffers = 0; + worker_buffers = worker_max_rows_[w] / rows_per_buffer_; + if (worker_max_rows_[w] % rows_per_buffer_ != 0) worker_buffers++; + epoch_buffers_sent_ += worker_buffers; + } + + // For the connector to work, we need to target the correct worker channel for the eoe. + // This will initialize it for the first one. reset() handles for the rest of the epochs. + eoe_worker_id_ = epoch_buffers_sent_ % num_workers_; + epoch_buffers_sent_++; // Add the eoe buffer to the count for subsequent epochs + + // RandomDataOp doesn't need the master thread to stay around. Kick off the workers and then master exits. + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&RandomDataOp::WorkerEntry, this, std::placeholders::_1))); + + // required task group setup after launching workers + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(epoch_sync_wait_post_.Register(tree_->AllTasks())); + + return Status::OK(); +} + +// Performs a synchronization between workers at the end of an epoch +Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " syncing at end of epoch"; + + // Sync on the guys_in counter + // We have to wait the last guy is out. + all_out_.Wait(); + // If we are not in a repeat loop, or that was the last repeat already, then setup our exit + // condition from the master loop. + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + *quitting = true; + } + + auto prev = guys_in_.fetch_add(1); + bool last_guy_in = (prev + 1) == num_workers_; + // If we are the last worker to hit this sync point, we have some extra tasks + if (last_guy_in) { + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker " + << eoe_worker_id_; + // Prepare for sync + all_out_.Clear(); + // Always flow eoe at the end + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(eoe_worker_id_, std::move(eoe_buffer))); + // If we're done then also flow the eof + if (*quitting) { + // The eof needs to be sent from the next sender in the round robin, so +1 + int32_t eof_worker_id = (eoe_worker_id_ + 1) % num_workers_; + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " has no more epochs. sending eof as worker " + << eof_worker_id; + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(eof_worker_id, std::move(eof_buffer))); + } + } + + // Wait for the reset to wake us up if we're not quitting + if (!(*quitting)) { + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait."; + RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); + prev = guys_out_.fetch_add(1); + bool last_guy_out = (prev + 1) == num_workers_; + // Last guy out will clear the wait post and set the row counts + if (last_guy_out) { + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " last guy out clearing wait post."; + epoch_sync_wait_post_.Clear(); + guys_in_ = 0; + all_out_.Set(); + } + } + + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " epoch sync complete."; + return Status::OK(); +} + +// The entry point code for when workers are launched +Status RandomDataOp::WorkerEntry(int32_t worker_id) { + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entry"; + + // handshake with the master first to tell it we're alive + TaskManager::FindMe()->Post(); + + bool quitting = false; + std::unique_ptr new_tensor_table = nullptr; + + // Loop until the quitting variable gets set to true + do { + // If we have not yet reached the row count for this worker then produce another record + if (worker_rows_packed_[worker_id] < worker_max_rows_[worker_id]) { + TensorRow new_row; + + // Start a new tensor table if needed + if (new_tensor_table == nullptr) { + new_tensor_table = std::make_unique(); + } + + // Create the data for the row + RETURN_IF_NOT_OK(CreateRandomRow(worker_id, &new_row)); + + // Add the row to our table + new_tensor_table->push_back(std::move(new_row)); + worker_rows_packed_[worker_id]++; + + // If the tensor table is at capacity then it's time to send it to output + if (new_tensor_table->size() == rows_per_buffer_) { + RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table))); + } + } else { + // We've reached the total row count for this worker, so it's time for epoch sync. + // There is likely some records built but not sent yet, so take care of those first + // (this buffer will be smaller than rows_per_buffer) + if (new_tensor_table != nullptr && new_tensor_table->size() > 0) { + RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table))); + } + + // Now, let's enter the epoch sync + RETURN_IF_NOT_OK(EpochSync(worker_id, &quitting)); + } + } while (!quitting); + + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is now quitting."; + + return Status::OK(); +} + +// A helper function to stuff the tensor table into a buffer and send it to output connector +Status RandomDataOp::PackAndSend(int32_t worker_id, std::unique_ptr in_table) { + auto new_buffer = std::make_unique(GetNextBufferId(), DataBuffer::kDeBFlagNone); + new_buffer->set_tensor_table(std::move(in_table)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(new_buffer))); + return Status::OK(); +} + +// A helper function to create random data for the row +Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) { + if (new_row == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Missing tensor row output"); + } + + // Create a tensor for each column, then add the tensor to the row + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + const ColDescriptor current_col = data_schema_->column(i); + std::vector current_shape = current_col.shape().AsVector(); + std::unique_ptr new_shape = nullptr; + std::unique_ptr buf = nullptr; + std::shared_ptr new_tensor = nullptr; + + // We need to resolve the shape to fill in any unknown dimensions with random + // values, then use that as our shape for this tensor. + for (int j = 0; j < current_shape.size(); ++j) { + if (current_shape[j] == TensorShape::kDimUnknown) { + current_shape[j] = static_cast(GenRandomInt(1, kMaxDimValue)); + } + } + + new_shape = std::make_unique(current_shape); + int64_t size_in_bytes = new_shape->NumOfElements() * current_col.type().SizeInBytes(); + + // Generate a random byte of data. This may cause some funny data for things like doubles,floats, bools + // however the random data op is not too concerned about the physical data itself. + std::uniform_int_distribution uniDist(0, 255); + uint8_t random_byte = uniDist(rand_gen_); + + // Now, create a chunk of memory for the entire tensor and copy this byte in repeatedly. + buf = std::make_unique(size_in_bytes); + int ret_code = memset_s(buf.get(), size_in_bytes, random_byte, size_in_bytes); + if (ret_code != 0) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to set random bytes for a tensor."); + } + + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&new_tensor, current_col.tensorImpl(), *new_shape, current_col.type(), buf.get())); + + // Add this tensor to the tensor row for output + (*new_row).push_back(std::move(new_tensor)); + } + return Status::OK(); +} + +// Overrides base class reset method. When an operator does a reset, it cleans up any state +// info from it's previous execution and then initializes itself so that it can be executed +// again. +Status RandomDataOp::Reset() { + MS_LOG(INFO) << "RandomDataOp resetting."; + + // Ensure all guys are in the waitpost + if (guys_in_ != num_workers_) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Issuing a reset, but some workers are missing from epochSync!"); + } + + // reset the row counters for all workers + for (int32_t w = 0; w < num_workers_; w++) { + worker_rows_packed_[w] = 0; + worker_max_rows_[w] = 0; + } + buffer_id_ = 0; + + // Re-assign round robin row counts, starting from the worker after the one that gave + // the eoe last time + int32_t currentWorker = (eoe_worker_id_ + 1) % num_workers_; + for (int64_t r = 0; r < total_rows_; r++) { + worker_max_rows_[currentWorker]++; + currentWorker = (currentWorker + 1) % num_workers_; + } + + // Compute which worker should get the eoe for the next epoch + eoe_worker_id_ = ((epoch_buffers_sent_ % num_workers_) + eoe_worker_id_) % num_workers_; + + // Wake up the workers to get them going again in a new epoch + guys_out_ = 0; + epoch_sync_wait_post_.Set(); + + return Status::OK(); +} + +// Visitor accept method for NodePass +Status RandomDataOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status RandomDataOp::ComputeColMap() { + // Extract the column name mapping from the schema and save it in the class. + if (column_name_id_map_.empty()) { + RETURN_IF_NOT_OK(data_schema_->GetColumnNameMap(&(column_name_id_map_))); + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h new file mode 100644 index 0000000000..c77695439d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h @@ -0,0 +1,291 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +// The RandomDataOp is a leaf node storage operator that generates random data based +// on the schema specifications. Typically, it's used for testing and demonstrating +// various dataset operator pipelines. It is not "real" data to train with. +// The data that is random created is just random and repeated bytes, there is no +// "meaning" behind what these bytes are. +class RandomDataOp : public ParallelOp { + public: + // Some constants to provide limits to random generation. + static constexpr int32_t kMaxNumColumns = 4; + static constexpr int32_t kMaxRank = 4; + static constexpr int32_t kMaxDimValue = 32; + static constexpr int32_t kMaxTotalRows = 1024; + + // A nested builder class to aid in the construction of a RandomDataOp + class Builder { + public: + /** + * Builder constructor. Creates the builder object. + * @note No default args. + * @return This is a constructor. + */ + Builder(); + + /** + * Default destructor + */ + ~Builder() = default; + + /** + * The build method that produces the instantiated RandomDataOp as a shared pointer + * @param out_op - The output RandomDataOperator that was constructed + * @return Status - The error code return + */ + Status Build(std::shared_ptr *out_op); + + /** + * Builder set method + * @param data_schema - A user-provided schema + * @return Builder - The modified builder by reference + */ + Builder &SetDataSchema(std::unique_ptr data_schema) { + builder_data_schema_ = std::move(data_schema); + return *this; + } + + /** + * Builder set method + * @param num_workers - The number of workers + * @return Builder - The modified builder by reference + */ + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + /** + * Builder set method + * @param op_connector_size - The size of the output connector + * @return Builder - The modified builder by reference + */ + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + /** + * Builder set method + * @param rows_per_buffer - The number of rows in each DataBuffer + * @return Builder - The modified builder by reference + */ + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + /** + * Builder set method + * @param total_rows - The total number of rows in the dataset + * @return Builder - The modified builder by reference + */ + Builder &SetTotalRows(int64_t total_rows) { + builder_total_rows_ = total_rows; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + private: + /** + * Check if the required parameters are set by the builder. + * @return Status - The error code return + */ + Status SanityCheck() const; + + std::unique_ptr builder_data_schema_; + std::shared_ptr builder_sampler_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_total_rows_; + }; // class Builder + + /** + * Constructor for RandomDataOp + * @note Private constructor. Must use builder to construct. + * @param num_workers - The number of workers + * @param op_connector_size - The size of the output connector + * @param rows_per_buffer - The number of rows in each DataBuffer + * @param data_schema - A user-provided schema + * @param total_rows - The total number of rows in the dataset + * @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes + * @return Builder - The modified builder by reference + */ + RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, + std::unique_ptr data_schema, std::shared_ptr sampler); + + /** + * Destructor + */ + ~RandomDataOp() = default; + + /** + * A print method typically used for debugging + * @param out - The output stream to write output to + * @param show_all - A bool to control if you want to show all info or just a summary + */ + void Print(std::ostream &out, bool show_all) const override; + + /** + * << Stream output operator overload + * @notes This allows you to write the debug print info using stream operators + * @param out - reference to the output stream being overloaded + * @param so - reference to the ShuffleOp to display + * @return - the output stream must be returned + */ + friend std::ostream &operator<<(std::ostream &out, const RandomDataOp &op) { + op.Print(out, false); + return out; + } + + /** + * Class functor operator () override. + * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will + * provide the master loop that drives the logic for performing the work. + * @return Status - The error code return + */ + Status operator()() override; + + /** + * Overrides base class reset method. When an operator does a reset, it cleans up any state + * info from it's previous execution and then initializes itself so that it can be executed + * again. + * @return Status - The error code return + */ + Status Reset() override; + + /** + * Quick getter for total rows. + */ + int64_t GetTotalRows() const { return total_rows_; } + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "RandomDataOp"; } + + private: + /** + * The entry point code for when workers are launched + * @param worker_id - The worker id + * @return Status - The error code return + */ + Status WorkerEntry(int32_t worker_id) override; + + /** + * Helper function to produce a default/random schema if one didn't exist + @return Status - The error code return + */ + Status GenerateSchema(); + + /** + * Performs a synchronization between workers at the end of an epoch + * @param worker_id - The worker id + * @return Status - The error code return + */ + Status EpochSync(int32_t worker_id, bool *quitting); + + /** + * A helper function to stuff the tensor table into a buffer and send it to output connector + * @param worker_id - The worker id + * @param in_table - The tensor table to pack and send + * @return Status - The error code return + */ + Status PackAndSend(int32_t worker_id, std::unique_ptr in_table); + + /** + * A helper function to create random data for the row + * @param worker_id - The worker id + * @param new_row - The output row to produce + * @return Status - The error code return + */ + Status CreateRandomRow(int32_t worker_id, TensorRow *new_row); + + /** + * A quick inline for producing a random number between (and including) min/max + * @param min - minimum number that can be generated + * @param max - maximum number that can be generated + * @return - The generated random number + */ + inline int32_t GenRandomInt(int32_t min, int32_t max) { + std::uniform_int_distribution uniDist(min, max); + return uniDist(rand_gen_); + } + + /** + * A quick inline for producing the next buffer id in sequence, threadsafe + * @return - The next buffer id. + */ + inline int32_t GetNextBufferId() { + std::unique_lock lock(buffer_id_mutex_); + return ++buffer_id_; + } + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t buffer_id_; + int64_t rows_per_buffer_; + int64_t total_rows_; + int64_t epoch_buffers_sent_; + std::atomic guys_in_; + std::atomic guys_out_; + int32_t eoe_worker_id_; + std::unique_ptr data_schema_; + std::vector worker_max_rows_; + std::vector worker_rows_packed_; + std::mt19937 rand_gen_; + WaitPost epoch_sync_wait_post_; + WaitPost all_out_; + std::mutex buffer_id_mutex_; +}; // class RandomDataOp +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt new file mode 100644 index 0000000000..1335d987e8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -0,0 +1,21 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + +set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES + distributed_sampler.cc + pk_sampler.cc + random_sampler.cc + sampler.cc + sequential_sampler.cc + subset_random_sampler.cc + weighted_random_sampler.cc + ) + +if (ENABLE_PYTHON) + set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES + ${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES} + python_sampler.cc + ) +endif() + +add_library(engine-datasetops-source-sampler OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES}) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc new file mode 100644 index 0000000000..2b5e7c67c8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc @@ -0,0 +1,119 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" + +#include +#include + +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, + uint32_t seed) + : Sampler(num_samples, std::numeric_limits::max()), + cnt_(0), + seed_(seed == std::numeric_limits::max() ? GetSeed() : seed), + device_id_(dev_id), + num_devices_(num_dev), + shuffle_(shuffle) {} + +Status DistributedSampler::InitSampler() { + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, + "fail to init DistributedSampler"); + rnd_.seed(seed_++); + samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices) + samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_; + if (shuffle_ == true) { + shuffle_vec_.reserve(num_rows_); + for (int64_t i = 0; i < num_rows_; i++) { + shuffle_vec_.push_back(i); + } + std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_); + } + return Status::OK(); +} + +Status DistributedSampler::GetNextSample(std::unique_ptr *out_buffer) { + if (cnt_ > samples_per_buffer_) { + RETURN_STATUS_UNEXPECTED("Distributed Sampler Error"); + } else if (cnt_ == samples_per_buffer_) { + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + + (*out_buffer) = std::make_unique(cnt_, DataBuffer::kDeBFlagNone); + std::shared_ptr sample_ids; + RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_)); + auto id_ptr = sample_ids->begin(); + while (cnt_ < samples_per_buffer_ && id_ptr != sample_ids->end()) { + int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_; + if (shuffle_) { + sampled_id = shuffle_vec_[static_cast(sampled_id)]; + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *id_ptr = sampled_id; + id_ptr++; + cnt_++; + } + TensorRow row(1, sample_ids); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + } + return Status::OK(); +} + +Status DistributedSampler::ResetSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late"); + cnt_ = 0; + + if (shuffle_ == true) { + rnd_.seed(seed_); + seed_++; + std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_); + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +void DistributedSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: DistributedSampler"; + if (show_all) { + Sampler::Print(out, show_all); + out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_ + << "\nshuffle: " << shuffle_; + } +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h new file mode 100644 index 0000000000..76bcf052f9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h @@ -0,0 +1,66 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class DistributedSampler : public Sampler { + public: + // @param num_samples + // @param int64_t num_dev + // @param int64_t dev_id + // @param bool shuffle + DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, + uint32_t seed = std::numeric_limits::max()); + + // default destructor + ~DistributedSampler() = default; + + // @param std::unique_ptr * pBuffer + // @param int32_t workerId + // @return - The error code return + Status GetNextSample(std::unique_ptr *out_buffer) override; + + // Init sampler, called by base class or python + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status ResetSampler() override; + + void Print(std::ostream &out, bool show_all) const override; + + private: + int64_t cnt_; // number of samples that have already been filled in to buffer + uint32_t seed_; + int64_t device_id_; + int64_t num_devices_; + bool shuffle_; + std::mt19937 rnd_; + std::vector shuffle_vec_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc new file mode 100644 index 0000000000..770c24c8c5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include +#include +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), + shuffle_(shuffle), + seed_(GetSeed()), + next_id_(0), + samples_per_class_(val) {} + +Status PKSampler::InitSampler() { + labels_.reserve(label_to_ids_.size()); + for (const auto &pair : label_to_ids_) { + if (pair.second.empty() == false) { + labels_.push_back(pair.first); + } + } + rnd_.seed(seed_++); + + // The special handshake gives the list of classes and id's, but it did not set the num_rows_ to + // capture the total number of possible sample ids. + // Compute that here for this case to find the total number of samples that are available to return. + // (in this case, samples per class * total classes). + num_rows_ = samples_per_class_ * static_cast(labels_.size()); + + // The user may have chosen to sample less than the total amount. + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + + samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; + if (shuffle_ == true) { + std::shuffle(labels_.begin(), labels_.end(), rnd_); + } else { + std::sort(labels_.begin(), labels_.end()); + } + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive"); + return Status::OK(); +} + +Status PKSampler::GetNextSample(std::unique_ptr *out_buffer) { + if (next_id_ > num_samples_ || num_samples_ == 0) { + RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); + } else if (next_id_ == num_samples_) { + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + + (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); + std::shared_ptr sample_ids; + int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; + RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); + auto id_ptr = sample_ids->begin(); + while (next_id_ < last_id && id_ptr != sample_ids->end()) { + int64_t cls_id = next_id_++ / samples_per_class_; + const std::vector &samples = label_to_ids_[labels_[cls_id]]; + int64_t rnd_ind = std::uniform_int_distribution(0, samples.size() - 1)(rnd_); + int64_t sampled_id = samples[rnd_ind]; + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *id_ptr = sampled_id; + id_ptr++; + } + + TensorRow row(1, sample_ids); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + } + return Status::OK(); +} + +Status PKSampler::ResetSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); + next_id_ = 0; + rnd_.seed(seed_++); + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { + RETURN_UNEXPECTED_IF_NULL(op); + RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); + RETURN_IF_NOT_OK(InitSampler()); + return Status::OK(); +} + +void PKSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: PKSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h new file mode 100644 index 0000000000..aed61fa273 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class PKSampler : public Sampler { // NOT YET FINISHED + public: + // @param num_samples - the number of samples to draw. value of 0 means to take the full amount + // @param int64_t val + // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 + // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // default destructor + ~PKSampler() = default; + + // @param std::unique_ptr *out_buffer) override; + + // first handshake between leaf source op and Sampler. This func will determine the amount of data + // in the dataset that we can sample from. + // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is + // @return + Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; + + // init sampler, to be called by python or Handshake + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status ResetSampler() override; + + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + + private: + bool shuffle_; + uint32_t seed_; + int64_t next_id_; + int64_t samples_per_class_; + std::mt19937 rnd_; + std::vector labels_; + std::map> label_to_ids_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc new file mode 100644 index 0000000000..50c67bca6c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" + +#include + +namespace mindspore { +namespace dataset { + +PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} + +Status PythonSampler::GetNextSample(std::unique_ptr *out_buffer) { + if (need_to_reset_) { + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + + std::shared_ptr sample_ids; + { + py::gil_scoped_acquire gil_acquire; + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py::object py_ret = py_sampler_instance.attr("_get_indices")(); + py::array np_sample_ids = py_ret.cast(); + Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor + + if (HasChildSampler()) { + for (auto it = sample_ids->begin(); it != sample_ids->end(); ++it) { + int64_t associated_child_id = 0; + RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, associated_child_id)); + *it = associated_child_id; + } + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Python Sampler iterator should return integer index"); + } + } + TensorRow row(1, sample_ids); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + need_to_reset_ = true; + } + return Status::OK(); +} + +Status PythonSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py_sampler_instance.attr("_handshake")(num_rows_, num_samples_); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +Status PythonSampler::ResetSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); + need_to_reset_ = false; + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py_sampler_instance.attr("reset")(); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +void PythonSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: PythonSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h new file mode 100644 index 0000000000..61716feb94 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -0,0 +1,66 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ + +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class PythonSampler : public Sampler { + public: + // Constructor + // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the + // data from the dataset. + // @param py_sampler_instance - the python instance of the sampler + // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~PythonSampler() = default; + + // Initialize the sampler. + // @return Status + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status ResetSampler() override; + + // Op calls this to get next Buffer that contains all the sampleIds + // @param std::unique_ptr pBuffer - Buffer to be returned to corresponding Dataset Op + // @param int32_t workerId - not meant to be used + // @return - The error code return + Status GetNextSample(std::unique_ptr *out_buffer) override; + + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + + private: + bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() + + py::object py_sampler_instance; // The handle to the py_sampler python object +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc new file mode 100644 index 0000000000..998dee2a07 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" + +#include +#include +#include +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, + int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), + seed_(GetSeed()), + replacement_(replacement), + next_id_(0), + reshuffle_each_epoch_(reshuffle_each_epoch), + dist(nullptr) {} + +Status RandomSampler::GetNextSample(std::unique_ptr *out_buffer) { + if (next_id_ > num_samples_) { + RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); + } else if (next_id_ == num_samples_) { + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); + + std::shared_ptr sampleIds; + int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_); + RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_)); + auto id_ptr = sampleIds->begin(); + + for (int64_t i = 0; i < (last_id - next_id_); i++) { + int64_t sampled_id = 0; + if (replacement_) { + sampled_id = (*dist)(rnd_); + } else { + sampled_id = shuffled_ids_[static_cast(i + next_id_)]; + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *(id_ptr + i) = sampled_id; + } + next_id_ = last_id; + TensorRow row(1, sampleIds); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + } + return Status::OK(); +} + +Status RandomSampler::InitSampler() { + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); + samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; + rnd_.seed(seed_); + + if (replacement_ == false) { + shuffled_ids_.reserve(num_rows_); + for (int64_t i = 0; i < num_rows_; i++) { + shuffled_ids_.push_back(i); + } + std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); + } else { + dist = std::make_unique>(0, num_rows_ - 1); + } + + return Status::OK(); +} + +Status RandomSampler::ResetSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); + next_id_ = 0; + + if (reshuffle_each_epoch_) { + seed_++; + } + + rnd_.seed(seed_); + + if (replacement_ == false && reshuffle_each_epoch_) { + std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +void RandomSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: RandomSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h new file mode 100644 index 0000000000..6e21b088b9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h @@ -0,0 +1,66 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class RandomSampler : public Sampler { + public: + // Constructor + // @param int64_t num_samples - number samples to draw + // @param bool replacement - put he id back / or not after a sample + // @param reshuffle_each_epoch - T/F to reshuffle after epoch + // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~RandomSampler() = default; + + // Op calls this to get next Buffer that contains all the sampleIds + // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp + // @param int32_t workerId - not meant to be used + // @return - The error code return + Status GetNextSample(std::unique_ptr *out_buffer) override; + + // meant to be called by base class or python + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status ResetSampler() override; + + virtual void Print(std::ostream &out, bool show_all) const; + + private: + uint32_t seed_; + bool replacement_; + std::vector shuffled_ids_; // only used for NO REPLACEMENT + int64_t next_id_; + std::mt19937 rnd_; + std::unique_ptr> dist; + bool reshuffle_each_epoch_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc new file mode 100644 index 0000000000..60d75d2eec --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc @@ -0,0 +1,178 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +#include + +namespace mindspore { +namespace dataset { +Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { + // The sampler base class itself does not compute it's own num_rows_ value. + // Instead, this value is computed by the derived leaf op during it's own initialization + // after it has interacted with it's storage layers. + // Here, it is just a getter method to return the value. However, it is invalid if there is + // not a value set for this count, so generate a failure if that is the case. + if (num == nullptr || num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet."); + } + (*num) = num_rows_; + return Status::OK(); +} + +Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) + : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} + +Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { + std::shared_ptr child_sampler; + if (HasChildSampler()) { + child_sampler = std::dynamic_pointer_cast(child_[0]); + if (!child_sampler) { + std::string err_msg("Cannot handshake, child is not a sampler object."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Handshake and init child first. + RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op)); + } + + CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); + + // If there's a child sampler, set the row count to be it's sample count + if (HasChildSampler()) { + num_rows_ = child_sampler->num_samples_; + } else { + RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); + } + + // It's up to the derived class to check the validity of the two args + // Because some sampler only needs one of the arg (weighted_random_sampler) + RETURN_IF_NOT_OK(InitSampler()); // init sampler after callback + + return Status::OK(); +} + +Status Sampler::CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t num_elements) { + if (num_elements == 0) { + RETURN_STATUS_UNEXPECTED("num of Elements is 0"); + } + if (col_desc_ == nullptr) { + // a ColDescriptor for Tensor that holds SampleIds + col_desc_ = std::make_unique("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1); + } + TensorShape shape(std::vector(1, num_elements)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, col_desc_->tensorImpl(), shape, col_desc_->type())); + RETURN_IF_NOT_OK( + (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes())); // allocate memory in case user forgets! + return Status::OK(); +} + +void Sampler::Print(std::ostream &out, bool show_all) const { + // Sampler printing is usually only called in the show_all mode. + // Derived classes will display the name, then call back to this base + // for common info. + // No-op in the summary mode. + if (show_all) { + out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_; + } +} + +#ifdef ENABLE_PYTHON +Status Sampler::GetAllIdsThenReset(py::array *data) { + std::unique_ptr db; + std::shared_ptr sample_ids; + TensorRow sample_row; + + // A call to derived class to get sample ids wrapped inside a buffer + RETURN_IF_NOT_OK(GetNextSample(&db)); + // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch + RETURN_IF_NOT_OK(db->GetRow(0, &sample_row)); + sample_ids = sample_row[0]; + + // check this buffer is not a ctrl buffer + CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received"); + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data)); + } catch (const std::runtime_error &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + // perform error checking! Next buffer supposed to be EOE since last one already contains all ids for current epoch + RETURN_IF_NOT_OK(GetNextSample(&db)); + CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received"); + // Reset Sampler since this is the end of the epoch + RETURN_IF_NOT_OK(ResetSampler()); + return Status::OK(); +} +#endif + +Status Sampler::SetNumSamples(int64_t num_samples) { + CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative"); + num_samples_ = num_samples; + return Status::OK(); +} + +Status Sampler::SetNumRowsInDataset(int64_t num_rows) { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0"); + num_rows_ = num_rows; + return Status::OK(); +} + +Status Sampler::AddChild(std::shared_ptr child) { + if (child == nullptr) { + return Status::OK(); + } + + // Only samplers can be added, not any other DatasetOp. + std::shared_ptr sampler = std::dynamic_pointer_cast(child); + if (!sampler) { + std::string err_msg("Cannot add child, child is not a sampler object."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Samplers can have at most 1 child. + if (!child_.empty()) { + std::string err_msg("Cannot add child sampler, this sampler already has a child."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + child_.push_back(child); + + // doesn't work, protected? + // child->AddParent(this); + return Status::OK(); +} + +bool Sampler::HasChildSampler() { return !child_.empty(); } + +Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { + if (child_ids_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!"); + } + + TensorRow sample_row; + RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + RETURN_IF_NOT_OK(sample_ids->GetItemAt(out_associated_id, {id})); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h new file mode 100644 index 0000000000..4cae935a42 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h @@ -0,0 +1,161 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" + +namespace mindspore { +namespace dataset { +// RandomAccessOp is a base class that all data-producing leaf operators +// must inherit from if those leaf operator wish to support sampling. +class RandomAccessOp { + public: + // Sampler get number of rows in the dataset + // @param int64_t num - return number of rows for this dataset + // @return - The error code return + Status GetNumRowsInDataset(int64_t *num_rows) const; + + // sampler gets label , imageIds from corresponding Dataset Op, this function is unique to PK + // @param std::map> * map + // @return - The error code return + virtual Status GetClassIds(std::map> *map) const { + RETURN_STATUS_UNEXPECTED("GetClassIds needs to be override to support PK"); + } + + // default destructor + virtual ~RandomAccessOp() = default; + + protected: + // The amount of rows in the dataset itself. This is the before-sampling value, the + // total count of rows. A sampler may choose to sample less than this amount. + int64_t num_rows_; +}; + +class Sampler { + public: + // Constructor + // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 + // indicates that the sampler should produce the complete set of ids. + // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); + + Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {} + + // default destructor + ~Sampler() = default; + + // Get a list of sample ids. + // @note It is Sampler responsibility to make sure that the id is not out of bound. + // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp + // @param int32_t workerId - not meant to be used + // @return - The error code return + virtual Status GetNextSample(std::unique_ptr *out_buffer) = 0; + +// This function only called by python layer. Not needed by Android. +#ifdef ENABLE_PYTHON + // return all ids in one epoch as a numpy array, then call reset + Status GetAllIdsThenReset(py::array *data); +#endif + + // for next epoch of sampleIds + // @return - The error code return + virtual Status ResetSampler() = 0; + + // first handshake between leaf source op and Sampler. This func will determine the amount of data + // in the dataset that we can sample from. + // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is + // @return + virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op); + + // initialize sampler and perform checks on certain vars + virtual Status InitSampler() { return Status::OK(); } + + // setter for num samples + // @param num_samples - the number of samples to assign. + // @return status error code + Status SetNumSamples(int64_t num_samples); + + // setter for num or records in the dataset + // @param num_rows - the number of records + // @return status error code + Status SetNumRowsInDataset(int64_t num_rows); + + // Adds a sampler to become our child. + // @param std::shared_ptr - The sampler to add as a child. + // @return - The error code returned. + Status AddChild(std::shared_ptr child); + + // A helper function to create a int64_t 1-D Tensor specifically used to hold sampleIds for Sampler + // @param std::shared_ptr* sampleIds + // @param int64_t numElements - must be a non 0 number + // @return - The error code returned. + Status CreateSamplerTensor(std::shared_ptr *sample_ids, int64_t num_elements); + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + virtual void Print(std::ostream &out, bool show_all) const; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param sampler - reference to teh sampler to print + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) { + sampler.Print(out, false); + return out; + } + + // Checks if this sampler has a child sampler. + // @return - tre if there is a child sampler, false otherwise. + bool HasChildSampler(); + + // Uses id as an index for the list of ids generated by the child sampler, and gets the + // associated id. + // @param int64_t* out_associated_id - Out parameter, contains the associated id. + // @param int64_t id - The id used as an index to get the associated child id. + // @return - The error code returned. + Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id); + + protected: + // Number of rows of data from the place this sampler is sampling from. If this sampler + // has a child sampler, num_rows_ is the number of ids the child sampler will + // output. Otherwise, num_rows_ is the number of rows in the dataset. + int64_t num_rows_; + + // The user may want to sample less than the full amount of data. num_samples_ reduces the number + // of id's returned as request by the user. Derived classes will choose how to sample the smaller + // amount. + int64_t num_samples_; + + int64_t samples_per_buffer_; + std::unique_ptr col_desc_; + std::vector> child_; // Child nodes + std::unique_ptr child_ids_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc new file mode 100644 index 0000000000..1cc4ac831a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" + +#include +#include + +namespace mindspore { +namespace dataset { +SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} + +Status SequentialSampler::GetNextSample(std::unique_ptr *out_buffer) { + if (id_count_ > num_samples_) { + RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); + } else if (id_count_ == num_samples_) { + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + + (*out_buffer) = std::make_unique(current_id_, DataBuffer::kDeBFlagNone); + std::shared_ptr sampleIds; + + // Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for + // samples per buffer though. + int64_t remaining_ids = num_samples_ - id_count_; + int64_t num_elements = std::min(remaining_ids, samples_per_buffer_); + + RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements)); + auto idPtr = sampleIds->begin(); + for (int64_t i = 0; i < num_elements; i++) { + int64_t sampled_id = current_id_; + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *idPtr = sampled_id; + current_id_++; // Move the current id to the next one in the sequence + idPtr++; + } + + id_count_ += num_elements; // Count the packed ids towards our overall sample count + + TensorRow row(1, sampleIds); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + } + return Status::OK(); +} + +Status SequentialSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n"); + // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample + // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data. + int64_t available_row_count = num_rows_ - start_index_; + if (num_samples_ == 0 || num_samples_ > available_row_count) { + num_samples_ = available_row_count; + } + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); + samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; + return Status::OK(); +} + +Status SequentialSampler::ResetSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); + current_id_ = start_index_; + id_count_ = 0; + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +void SequentialSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: SequentialSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info + out << "\nStart index: " << start_index_; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h new file mode 100644 index 0000000000..c6ccd0d1eb --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ + +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class SequentialSampler : public Sampler { + public: + // Constructor + // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the + // full amount of ids from the dataset + // @param start_index - The starting index value + // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit SequentialSampler(int64_t num_samples, int64_t start_index, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~SequentialSampler() = default; + + // init sampler, called by python + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status ResetSampler() override; + + // Op calls this to get next Buffer that contains all the sampleIds + // @param std::unique_ptr pBuffer - Buffer to be returned to corresponding Dataset Op + // @param int32_t workerId - not meant to be used + // @return - The error code return + Status GetNextSample(std::unique_ptr *out_buffer) override; + + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + + private: + int64_t current_id_; // The id sequencer. Each new id increments from this + int64_t start_index_; // The starting id. current_id_ begins from here. + int64_t id_count_; // An internal counter that tracks how many ids have been produced +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc new file mode 100644 index 0000000000..db2078795e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" + +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +// Constructor. +SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector &indices, + int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} + +// Initialized this Sampler. +Status SubsetRandomSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); + + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // In this case, the id's are provided by the user. Cap the num_samples on the number of id's given. + if (num_samples_ == 0 || num_samples_ > static_cast(indices_.size())) { + num_samples_ = static_cast(indices_.size()); + } + // Initialize random generator with seed from config manager + rand_gen_.seed(GetSeed()); + + if (samples_per_buffer_ > num_samples_) { + samples_per_buffer_ = num_samples_; + } + + // num_samples_ could be smaller than the total number of input id's. + // We will shuffle the full set of id's, but only select the first num_samples_ of them later. + std::shuffle(indices_.begin(), indices_.end(), rand_gen_); + + return Status::OK(); +} + +// Reset the internal variable to the initial state. +Status SubsetRandomSampler::ResetSampler() { + // Reset the internal counters. + sample_id_ = 0; + buffer_id_ = 0; + + // Randomized the indices again. + rand_gen_.seed(GetSeed()); + std::shuffle(indices_.begin(), indices_.end(), rand_gen_); + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +// Get the sample ids. +Status SubsetRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { + // All samples have been drawn + if (sample_id_ == num_samples_) { + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); + std::shared_ptr outputIds; + + int64_t last_id = sample_id_ + samples_per_buffer_; + // Handling the return all samples at once, and when last draw is not a full batch. + if (last_id > num_samples_) { + last_id = num_samples_; + } + + // Allocate tensor + RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_)); + + // Initialize tensor + auto id_ptr = outputIds->begin(); + while (sample_id_ < last_id) { + if (indices_[sample_id_] >= num_rows_) { + std::string err_msg = + "Generated id is bigger than numRows (out of bound). indices_: " + std::to_string(indices_[sample_id_]) + + " num_rows_: " + std::to_string(num_rows_); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + int64_t sampled_id = indices_[sample_id_]; + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *id_ptr = sampled_id; + id_ptr++; + sample_id_++; + } + + // Create a TensorTable from that single tensor and push into DataBuffer + (*out_buffer)->set_tensor_table(std::make_unique(1, TensorRow(1, outputIds))); + } + + return Status::OK(); +} + +void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: SubsetRandomSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h new file mode 100644 index 0000000000..fccc15e57b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +// Randomly samples elements from a given list of indices, without replacement. +class SubsetRandomSampler : public Sampler { + public: + // Constructor. + // @param num_samples The number of samples to draw. 0 for the full amount. + // @param indices List of indices from where we will randomly draw samples. + // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). + // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. + explicit SubsetRandomSampler(int64_t num_samples, const std::vector &indices, + std::int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~SubsetRandomSampler() = default; + + // Initialize the sampler. + // @return Status + Status InitSampler() override; + + // Reset the internal variable to the initial state and reshuffle the indices. + // @return Status + Status ResetSampler() override; + + // Get the sample ids. + // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. + // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. + Status GetNextSample(std::unique_ptr *out_buffer) override; + + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + + private: + // A list of indices (already randomized in constructor). + std::vector indices_; + + // Current sample id. + int64_t sample_id_; + + // Current buffer id. + int64_t buffer_id_; + + // A random number generator. + std::mt19937 rand_gen_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc new file mode 100644 index 0000000000..13863143c0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -0,0 +1,169 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +// Constructor. +WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, + int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), + weights_(weights), + replacement_(replacement), + sample_id_(0), + buffer_id_(0) {} + +// Initialized this Sampler. +Status WeightedRandomSampler::InitSampler() { + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); + + // Initialize random generator with seed from config manager + rand_gen_.seed(GetSeed()); + + samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; + + if (!replacement_) { + exp_dist_ = std::make_unique>(1); + InitOnePassSampling(); + } else { + discrete_dist_ = std::make_unique>(weights_.begin(), weights_.end()); + } + + return Status::OK(); +} + +// Initialized the computation for generating weighted random numbers without replacement using onepass method. +void WeightedRandomSampler::InitOnePassSampling() { + exp_dist_->reset(); + onepass_ids_.clear(); + std::vector> val_idx; + for (size_t i = 0; i < weights_.size(); i++) { + val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i)); + } + + // Partial sort the first `numSamples` elements. + std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end()); + for (int64_t i = 0; i < num_samples_; i++) { + onepass_ids_.push_back(val_idx[i].second); + } +} + +// Reset the internal variable to the initial state and reshuffle the indices. +Status WeightedRandomSampler::ResetSampler() { + sample_id_ = 0; + buffer_id_ = 0; + rand_gen_.seed(GetSeed()); + if (!replacement_) { + InitOnePassSampling(); + } else { + discrete_dist_->reset(); + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +// Get the sample ids. +Status WeightedRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { + if (weights_.size() > static_cast(num_rows_)) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); + } + + if (!replacement_ && (weights_.size() < static_cast(num_samples_))) { + RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); + } + + if (sample_id_ == num_samples_) { + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); + std::shared_ptr outputIds; + + int64_t last_id = sample_id_ + samples_per_buffer_; + // Handling the return all samples at once, and when last draw is not a full batch. + if (last_id > num_samples_) { + last_id = num_samples_; + } + + // Allocate tensor. + RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_)); + + // Initialize tensor. + auto id_ptr = outputIds->begin(); + // Assign the data to tensor element. + while (sample_id_ < last_id) { + int64_t genId; + if (replacement_) { + genId = (*discrete_dist_)(rand_gen_); + } else { + // Draw sample without replacement. + genId = onepass_ids_.front(); + onepass_ids_.pop_front(); + } + + if (genId >= num_rows_) { + RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound)."); + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId)); + } + + *id_ptr = genId; + id_ptr++; + sample_id_++; + } + + // Create a TensorTable from that single tensor and push into DataBuffer + (*out_buffer)->set_tensor_table(std::make_unique(1, TensorRow(1, outputIds))); + } + + return Status::OK(); +} + +void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: WeightedRandomSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h new file mode 100644 index 0000000000..b1a531abe9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -0,0 +1,94 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +// Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights). +class WeightedRandomSampler : public Sampler { + public: + // Constructor. + // @param num_samples Number of samples to be drawn. + // @param weights A lift of sample weights. + // @param replacement Determine if samples are drawn with/without replacement. + // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). + // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. + WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~WeightedRandomSampler() = default; + + // Initialize the sampler. + // @param op (Not used in this sampler) + // @return Status + Status InitSampler() override; + + // Reset the internal variable to the initial state and reshuffle the indices. + Status ResetSampler() override; + + // Get the sample ids. + // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. + // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. + Status GetNextSample(std::unique_ptr *out_buffer) override; + + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + + private: + // A list of weights for each sample. + std::vector weights_; + + // A flag indicating if samples are drawn with/without replacement. + bool replacement_; + + // Current sample id. + int64_t sample_id_; + + // Current buffer id. + int64_t buffer_id_; + + // Random engine and device + std::mt19937 rand_gen_; + + // Discrete distribution for generating weighted random numbers with replacement. + std::unique_ptr> discrete_dist_; + + // Exponential distribution for generating weighted random numbers without replacement. + // based on "Accelerating weighted random sampling without replacement" by Kirill Muller. + std::unique_ptr> exp_dist_; + + // Initialized the computation for generating weighted random numbers without replacement + // using onepass method. + void InitOnePassSampling(); + + // Store the random weighted ids generated by onepass method in `InitOnePassSampling` + std::deque onepass_ids_; +}; +} // namespace dataset +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc new file mode 100644 index 0000000000..c1f5b13a94 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -0,0 +1,498 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +TextFileOp::Builder::Builder() + : builder_device_id_(0), + builder_num_devices_(1), + builder_total_rows_(0), + builder_shuffle_files_(false), + builder_sampler_(nullptr) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); +} + +Status TextFileOp::Builder::ValidateInputs() const { + std::string err_msg; + err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; + err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +Status TextFileOp::Builder::Build(std::shared_ptr *op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! + if (static_cast(builder_num_workers_) > builder_text_files_list_.size()) { + builder_num_workers_ = builder_text_files_list_.size(); + MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + + std::shared_ptr text_file_op = std::make_shared( + builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, + std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, + builder_num_devices_, builder_device_id_, std::move(builder_sampler_)); + RETURN_IF_NOT_OK(text_file_op->Init()); + *op = std::move(text_file_op); + + return Status::OK(); +} + +TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, + std::unique_ptr schema, std::vector text_files_list, + int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, + std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + device_id_(device_id), + num_devices_(num_device), + rows_per_buffer_(rows_per_buffer), + total_rows_(total_rows), + text_files_list_(std::move(text_files_list)), + shuffle_files_(shuffle_files), + data_schema_(std::move(schema)), + all_num_rows_(0), + num_rows_per_shard_(0), + filename_index_(std::make_unique()), + finished_reading_dataset_(false), + load_io_block_queue_(true), + load_jagged_connector_(true) { + worker_connector_size_ = worker_connector_size; +} + +// A print method typically used for debugging +void TextFileOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ + << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") + << "\nText files list:\n"; + for (int i = 0; i < text_files_list_.size(); ++i) { + out << " " << text_files_list_[i]; + } + out << "\nData Schema:\n"; + out << *data_schema_ << "\n\n"; + } +} + +Status TextFileOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(text_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + return Status::OK(); +} + +Status TextFileOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + return Status::OK(); +} + +Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { + TensorRow tRow(1, nullptr); + (*tensor_table)->push_back(std::move(tRow)); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); + (**tensor_table)[row][0] = std::move(tensor); + return Status::OK(); +} + +Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + std::ifstream handle(file); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Failed to open file " + file); + } + + int64_t rows_each_buffer = 0; + int64_t rows_total = 0; + std::string line; + std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + std::unique_ptr tensor_table = std::make_unique(); + + while (getline(handle, line)) { + if (line.empty()) { + continue; + } + // If read to the end offset of this file, break. + if (rows_total >= end_offset) { + break; + } + // Skip line before start offset. + if (rows_total < start_offset) { + rows_total++; + continue; + } + + RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); + rows_each_buffer++; + rows_total++; + if (rows_each_buffer == rows_per_buffer_) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + + cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + tensor_table = std::make_unique(); + rows_each_buffer = 0; + } + } + + if (rows_each_buffer > 0) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + } + + return Status::OK(); +} + +Status TextFileOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// Pops an element from a queue in io_block_queues +Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status TextFileOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +Status TextFileOp::FillIOBlockQueue(const std::vector &i_keys) { + int32_t queue_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + while (!finish) { + std::vector> file_index; + if (!i_keys.empty()) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); + } + } else { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair(it.value(), it.key())); + } + } + for (auto file_info : file_index) { + if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { + auto ioBlock = + std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_info.first]; + } + + if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +Status TextFileOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spanwed by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +Status TextFileOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); + + // Read data from disk into buffers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. + TaskManager::FindMe()->Post(); + + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); + NotifyToFillIOBlockQueue(); + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (total_rows_ == 0 || rows_read < total_rows_) { + if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) { + int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + + return Status::OK(); +} + +int64_t TextFileOp::CountTotalRows(const std::string &file) { + std::ifstream handle(file); + if (!handle.is_open()) { + MS_LOG(ERROR) << "Failed to open file: " << file; + return 0; + } + + std::string line; + int64_t count = 0; + while (getline(handle, line)) { + if (!line.empty()) { + count++; + } + } + + return count; +} + +Status TextFileOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API TextFileDataset.Please check file path or dataset API " + "validation first."); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +Status TextFileOp::CountAllFileRows(const std::vector &files, int64_t *count) { + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} + +Status TextFileOp::ComputeColMap() { + // Set the column name mapping (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h new file mode 100644 index 0000000000..68c226ab80 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -0,0 +1,289 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/engine/jagged_connector.h" + +namespace mindspore { +namespace dataset { +using StringIndex = AutoIndexObj; + +class TextFileOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetTextFilesList(const std::vector &files_list) { + builder_text_files_list_ = files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetTotalRows(int64_t total_rows) { + builder_total_rows_ = total_rows; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + private: + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_total_rows_; + int32_t builder_worker_connector_size_; + std::vector builder_text_files_list_; + bool builder_shuffle_files_; + std::unique_ptr builder_schema_; + std::shared_ptr builder_sampler_; + }; + + // Constructor of TextFileOp + // @note The builder class should be used to call this constructor. + // @param num_workers - number of worker threads reading data from tf_file files. + // @param rows_per_buffer - number of rows that a full buffer will contain. + // @param total_num_rows - number of rows to read + // @param dataset_files_list - list of filepaths for the dataset files. + // @param data_schema - the data schema object. + // @param op_connector_size - size of each queue in the connector that the child operator pulls from. + // @param columns_to_load - the names of the columns to load data from. + // @param shuffle_files - whether or not to shuffle the files before reading data. + // @param equal_rows_per_shard - whether or not to get equal rows for each process. + // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes + TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, + std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler); + + // Default destructor + ~TextFileOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Get total rows in files. + // @param files - all text files. + // @param count - number of rows. + // @return Status - the error coed returned. + static Status CountAllFileRows(const std::vector &files, int64_t *count); + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "TextFileOp"; } + + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return text_files_list_; } + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a text file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - text file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t device_id_; + int32_t num_devices_; + int64_t rows_per_buffer_; + int64_t total_rows_; + std::vector text_files_list_; + bool shuffle_files_; + std::unique_ptr data_schema_; + int64_t all_num_rows_; + int64_t num_rows_per_shard_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + QueueList> io_block_queues_; + WaitPost io_block_queue_wait_post_; + bool finished_reading_dataset_; + bool load_io_block_queue_; + bool load_jagged_connector_; + std::unique_ptr jagged_buffer_connector_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc new file mode 100644 index 0000000000..ae7907b5ce --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -0,0 +1,1054 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "proto/example.pb.h" +#include "./securec.h" +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/connector.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/jagged_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/wait_post.h" +#include "utils/system/crc32c.h" + +namespace mindspore { +namespace dataset { +TFReaderOp::Builder::Builder() + : builder_device_id_(0), + builder_num_devices_(1), + builder_total_rows_(0), + builder_equal_rows_per_shard_(false), + builder_sampler_(nullptr) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_shuffle_files_ = false; + builder_data_schema_ = std::make_unique(); +} + +bool ValidateFirstRowCrc(const std::string &filename) { + std::ifstream reader; + reader.open(filename); + if (!reader) { + return false; + } + + // read data + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // read crc from file + uint32_t masked_crc = 0; + (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); + + // generate crc from data + uint32_t generated_crc = + system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); + + return masked_crc == generated_crc; +} + +Status TFReaderOp::Builder::ValidateInputs() const { + std::string err_msg; + + if (builder_num_workers_ <= 0) { + err_msg += "Number of parallel workers is smaller or equal to 0\n"; + } + + if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { + err_msg += "Wrong sharding configs\n"; + } + + std::vector invalid_files(builder_dataset_files_list_.size()); + auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), + [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); + invalid_files.resize(std::distance(invalid_files.begin(), it)); + + if (!invalid_files.empty()) { + err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; + + std::string accumulated_filenames = std::accumulate( + invalid_files.begin(), invalid_files.end(), std::string(""), + [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); + err_msg += accumulated_filenames; + } + + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! + if (static_cast(builder_num_workers_) > builder_dataset_files_list_.size()) { + builder_num_workers_ = builder_dataset_files_list_.size(); + MS_LOG(WARNING) << "TFReader operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + std::shared_ptr new_tf_reader_op = std::make_shared( + builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, + builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, + builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_, + std::move(builder_sampler_)); + + RETURN_IF_NOT_OK(new_tf_reader_op->Init()); + *out_tf_reader_op = std::move(new_tf_reader_op); + return Status::OK(); +} + +TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, + int64_t total_num_rows, std::vector dataset_files_list, + std::unique_ptr data_schema, int32_t op_connector_size, + std::vector columns_to_load, bool shuffle_files, int32_t num_device, + int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + device_id_(device_id), + num_devices_(num_device), + rows_per_buffer_(rows_per_buffer), + total_rows_(total_num_rows), + dataset_files_list_(std::move(dataset_files_list)), + columns_to_load_(std::move(columns_to_load)), + finished_reading_dataset_(false), + shuffle_files_(shuffle_files), + data_schema_(std::move(data_schema)), + filename_index_(std::make_unique()), + load_io_block_queue_(true), + load_jagged_connector_(true), + num_rows_(0), + num_rows_per_shard_(0), + equal_rows_per_shard_(equal_rows_per_shard) { + worker_connector_size_ = worker_connector_size; +} + +// A print method typically used for debugging +void TFReaderOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ + << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") + << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n"; + for (int i = 0; i < dataset_files_list_.size(); ++i) { + out << " " << dataset_files_list_[i]; + } + if (!columns_to_load_.empty()) { + out << "\nColumns to load:\n"; + for (int i = 0; i < columns_to_load_.size(); ++i) { + out << " " << columns_to_load_[i]; + } + } + out << "\nData Schema:\n"; + out << *data_schema_ << "\n\n"; + } +} + +Status TFReaderOp::Init() { + if (data_schema_->Empty()) { + RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_)); + } + + if (total_rows_ == 0) { + total_rows_ = data_schema_->num_rows(); + } + if (total_rows_ < 0) { + RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0"); + } + + // Build the index with our files such that each file corresponds to a key id. + RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_)); + + // The creation of the internal connector has been delayed until now, since we may have adjusted the + // number of workers. Now that the worker count is established, create the connector now in the + // parallel op base. + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + + // temporary: make size large enough to hold all files + EOE to avoid hangs + int32_t safe_queue_size = static_cast(std::ceil(dataset_files_list_.size() / num_workers_)) + 1; + io_block_queues_.Init(num_workers_, safe_queue_size); + + return Status::OK(); +} + +Status TFReaderOp::CalculateNumRowsPerShard() { + if (!equal_rows_per_shard_) { + return Status::OK(); + } + + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + std::vector file(1, it.value()); + int64_t num = CountTotalRowsSectioned(file, 0, 1); + filename_numrows_[it.value()] = num; + num_rows_ += num; + } + num_rows_per_shard_ = static_cast(std::ceil(num_rows_ * 1.0 / num_devices_)); + if (num_rows_per_shard_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API TFRecordDataset.Please check file path or dataset API " + "validation first."); + } + return Status::OK(); +} +// Class functor operator () override. +// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will +// provide the master loop that drives the logic for performing the work +Status TFReaderOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling mIOBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::WaitToFillIOBlockQueue, this))); + + // launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading + // data from disk into buffers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. workers can't be spawned after this post, + // so workers have to be kept alive until the end of the program + TaskManager::FindMe()->Post(); + + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); + + NotifyToFillIOBlockQueue(); + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + { + std::unique_lock lock(load_io_block_queue_mutex_); + load_io_block_queue_ = true; + } + + while (workers_done < num_workers_) { + std::unique_ptr fetched_buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer)); + if (fetched_buffer->eoe()) { + workers_done++; + } else if (total_rows_ == 0 || rows_read < total_rows_) { + // we need to push a buffer + if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) { + // this is last buffer we need, and we only need a part of it + int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read); + RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove)); + } + + rows_read += fetched_buffer->NumRows(); + fetched_buffer->set_id(buffer_id); + buffer_id++; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); + } else { + // user specified number of rows they want, and we read enough rows + // + // IOBlockQueue thread needs to: + // -stop pushing stuff to IOBlockQueue + // -call PostEndOfEpoch (will send EOE) + // -wait for reset + // + // Worker threads need to: + // -stop reading the file they are currently reading and throw it away + // -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE) + // + // Master thread needs to: + // -tell IOBlockQueue thread to stop pushing + // -tell worker threads to stop reading the file tey are currently reading + // -keep pulling until EOE + + // don't think we need a lock for now + load_jagged_connector_ = false; + + std::unique_lock lock(load_io_block_queue_mutex_); + load_io_block_queue_ = false; + } + } + + // all workers finished reading for this epoch, and we have read all the data from all workers + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + + return Status::OK(); +} + +// static local-only helper function +static void shuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +// The entry point for when workers are launched. +Status TFReaderOp::WorkerEntry(int32_t worker_id) { + // must be called first if called by worker spawned by taskgroup + TaskManager::FindMe()->Post(); + + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + MS_LOG(DEBUG) << "TFReader operator worker " << worker_id << " loaded file " << filename << "."; + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(1, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status TFReaderOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +bool TFReaderOp::NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +Status TFReaderOp::FillIOBlockShuffle(const std::vector &i_keys) { + int32_t queue_index = 0; + int32_t key_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + bool end_of_epoch = false; + while (!finish) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + std::unique_lock lock(load_io_block_queue_mutex_); + if (load_io_block_queue_ == false) { + end_of_epoch = true; + break; + } + } + if (!equal_rows_per_shard_) { + if (key_index++ % num_devices_ == device_id_) { + auto ioBlock = std::make_unique(*it, kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + } else { + // Do an index lookup using that key to get the filename. + std::string file_name = (*filename_index_)[*it]; + if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { + auto ioBlock = std::make_unique(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset; + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_name]; + } + } + if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_ && + !end_of_epoch) { + finish = false; + } else { + finish = true; + } + } + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +Status TFReaderOp::FillIOBlockNoShuffle() { + int32_t queue_index = 0; + int32_t key_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + bool end_of_epoch = false; + while (!finish) { + // Iterate over all the keys and add one key to each block. + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + std::unique_lock lock(load_io_block_queue_mutex_); + if (load_io_block_queue_ == false) { + end_of_epoch = true; + break; + } + } + if (!equal_rows_per_shard_) { + if (key_index++ % num_devices_ == device_id_) { + auto ioBlock = + std::make_unique(it.key(), kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + } else { + std::string file_name = it.value(); + if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { + auto ioBlock = std::make_unique(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_name]; + } + } + if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_ && + !end_of_epoch) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. +Status TFReaderOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spawned by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + // Generate a vector of keys that we can shuffle + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + shuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + RETURN_IF_NOT_OK(FillIOBlockShuffle(i_keys)); + } else { // shuffle_files_ == false + RETURN_IF_NOT_OK(FillIOBlockNoShuffle()); + } + } + + return Status::OK(); +} + +// Notifies the thread which called WaitToFillIOBlockQueue to resume execution. +void TFReaderOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +// Pops an element from a queue in io_block_queues +Status TFReaderOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status TFReaderOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +// Reads a tf_file file and loads the data into multiple buffers. +Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, + const int32_t &worker_id) { + std::ifstream reader; + reader.open(filename); + if (!reader) { + RETURN_STATUS_UNEXPECTED("failed to open file: " + filename); + } + + int64_t rows_read = 0; + int64_t rows_total = 0; + std::unique_ptr current_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + std::unique_ptr new_tensor_table = std::make_unique(); + + while (reader.peek() != EOF) { + if (!load_jagged_connector_) { + break; + } + + // read length + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // ignore crc header + (void)reader.ignore(static_cast(sizeof(int32_t))); + + // read serialized Example + std::string serialized_example; + serialized_example.resize(record_length); + (void)reader.read(&serialized_example[0], static_cast(record_length)); + if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) { + dataengine::Example tf_file; + if (!tf_file.ParseFromString(serialized_example)) { + std::string errMsg = "parse tfrecord failed"; + RETURN_STATUS_UNEXPECTED(errMsg); + } + RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); + rows_read++; + } + + // ignore crc footer + (void)reader.ignore(static_cast(sizeof(int32_t))); + rows_total++; + + if (rows_read == rows_per_buffer_) { + current_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(current_buffer))); + + current_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + new_tensor_table = std::make_unique(); + rows_read = 0; + } + } + + if (rows_read > 0) { + current_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(current_buffer))); + } + + return Status::OK(); +} + +// Parses a single row and puts the data into a tensor table. +Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, std::unique_ptr *tensor_table, + int64_t row) { + int32_t num_columns = data_schema_->NumColumns(); + TensorRow newRow(num_columns, nullptr); + (*tensor_table)->push_back(std::move(newRow)); + + for (int32_t col = 0; col < num_columns; ++col) { + const ColDescriptor current_col = data_schema_->column(col); + const dataengine::Features &example_features = tf_file->features(); + const google::protobuf::Map &feature_map = example_features.feature(); + const dataengine::Feature &column_values_list = feature_map.at(current_col.name()); + RETURN_IF_NOT_OK(LoadFeature(tensor_table, column_values_list, current_col, row, col)); + } + + return Status::OK(); +} + +// Parses a single cell and puts the data into a tensor table. +Status TFReaderOp::LoadFeature(const std::unique_ptr *tensor_table, + const dataengine::Feature &column_values_list, const ColDescriptor ¤t_col, + int64_t row, int32_t col) { + const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case(); + std::unique_ptr float_array; // For staging data from protobuf deserialization + const unsigned char *data_ptr = nullptr; // Generic pointer used for populating the Tensor + + // This variable will point into the above staging variables. + // Also used for creating shape attributes. + int32_t num_elements = 0; + + // we build a tensor first a read directly into it if we need to cast + std::shared_ptr ts; + + // Depending on the type of data from the tf_file, we want to extract 2 things: + // 1) A pointer to the data as a const unsigned char * + // 2) The number of elements of the data + // After those are determined, we can then build the tensor to represent this data. + switch (column_list_type) { + case dataengine::Feature::KindCase::kBytesList: { + RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &num_elements, &ts)); + + break; + } + case dataengine::Feature::KindCase::kFloatList: { + RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array)); + + data_ptr = reinterpret_cast(float_array.get()); + + // only floatList needs to create the tensor here, other two lists read directly + // into the tensor + TensorShape current_shape = TensorShape::CreateUnknownRankShape(); + RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, ¤t_shape)); + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&ts, current_col.tensorImpl(), current_shape, current_col.type(), data_ptr)); + break; + } + case dataengine::Feature::KindCase::kInt64List: { + RETURN_IF_NOT_OK(LoadIntListSwitch(current_col, column_values_list, &num_elements, &ts)); + break; + } + case dataengine::Feature::KindCase::KIND_NOT_SET: { + std::string err_msg = "tf_file column list type enum is KIND_NOT_SET"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + default: { + std::string err_msg = "tf_file column list type enum does not match any known DE type"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + + (**tensor_table)[row][col] = std::move(ts); + + return Status::OK(); +} + +// Overrides base class reset method. Cleans up any state info from it's previous execution and +// reinitializes itself so that it can be executed again, as if it was just created. +Status TFReaderOp::Reset() { + // start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true + load_jagged_connector_ = true; + + { + std::unique_lock lock(load_io_block_queue_mutex_); + load_io_block_queue_ = true; + } + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + + return Status::OK(); +} + +Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::shared_ptr *tensor) { + // kBytesList can map to the following DE types ONLY! + // DE_UINT8, DE_INT8 + // Must be single byte type for each element! + if (current_col.type() != DataType::DE_UINT8 && current_col.type() != DataType::DE_INT8 && + current_col.type() != DataType::DE_STRING) { + std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + const dataengine::BytesList &bytes_list = column_values_list.bytes_list(); + + *num_elements = bytes_list.value_size(); + + if (current_col.type() == DataType::DE_STRING) { + TensorShape shape = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, shape)); + return Status::OK(); + } + + uint64_t max_size = 0; + for (uint32_t i = 0; i < bytes_list.value_size(); ++i) max_size = std::max(max_size, bytes_list.value(i).size()); + + int64_t pad_size = max_size; + + // if user provides a shape in the form of [-1, d1, 2d, ... , dn], we need to pad to d1 * d2 * ... * dn + if (current_col.hasShape()) { + TensorShape cur_shape = current_col.shape(); + if (cur_shape.Size() >= 2 && cur_shape[0] == TensorShape::kDimUnknown) { + int64_t new_pad_size = 1; + for (int i = 1; i < cur_shape.Size(); ++i) { + if (cur_shape[i] == TensorShape::kDimUnknown) { + std::string err_msg = "More than one unknown dimension in the shape of column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + new_pad_size *= cur_shape[i]; + } + pad_size = new_pad_size; + } + } + + // know how many elements there are and the total bytes, create tensor here: + TensorShape current_shape = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, ¤t_shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, current_shape, current_col.type(), pad_size)); + + return Status::OK(); +} + +Status TFReaderOp::LoadFloatList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::unique_ptr *float_array) { + // KFloatList can only map to DE types: + // DE_FLOAT32 + if (current_col.type() != DataType::DE_FLOAT32) { + std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + const dataengine::FloatList &float_list = column_values_list.float_list(); + + // Identify how many values we have and then create a local array of these + // to deserialize into + *num_elements = float_list.value_size(); + *float_array = std::make_unique(*num_elements); + for (int i = 0; i < float_list.value_size(); ++i) { + (*float_array)[i] = float_list.value(i); + } + + return Status::OK(); +} + +// Determines which template type to use and calls LoadIntList +Status TFReaderOp::LoadIntListSwitch(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::shared_ptr *tensor) { + if (current_col.type() == DataType::DE_UINT64) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_INT64) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_UINT32) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_INT32) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_UINT16) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_INT16) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_UINT8) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_INT8) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else { + std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + return Status::OK(); +} + +// Reads values from a bytes list and casts the value to type T, must be an integral type +// compatible with int64_t +template +Status TFReaderOp::LoadIntList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::shared_ptr *tensor) { + if (!(current_col.type().IsInt())) { + std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + const dataengine::Int64List &int64_list = column_values_list.int64_list(); + + // Identify how many values we have and then create a local array of these + // to deserialize into + *num_elements = int64_list.value_size(); + + // know how many elements there are, create tensor here: + TensorShape current_shape = TensorShape::CreateUnknownRankShape(); + RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, ¤t_shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, current_col.tensorImpl(), current_shape, current_col.type())); + + // Tensors are lazily allocated, this eagerly allocates memory for the tensor. + RETURN_IF_NOT_OK((*tensor)->AllocateBuffer((*tensor)->SizeInBytes())); + + int64_t i = 0; + auto it = (*tensor)->begin(); + for (; it != (*tensor)->end(); i++, ++it) { + T element = static_cast(int64_list.value(i)); + *it = element; + } + + return Status::OK(); +} + +Status TFReaderOp::CreateSchema(const std::string tf_file, std::vector columns_to_load) { + std::ifstream reader; + reader.open(tf_file); + + // read length + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // ignore crc header + (void)reader.ignore(static_cast(sizeof(int32_t))); + + // read serialized Example + std::string serialized_example; + serialized_example.resize(record_length); + (void)reader.read(&serialized_example[0], static_cast(record_length)); + + dataengine::Example example; + if (!example.ParseFromString(serialized_example)) RETURN_STATUS_UNEXPECTED("parse tf_file failed"); + + const dataengine::Features &example_features = example.features(); + const google::protobuf::Map &feature_map = example_features.feature(); + + if (columns_to_load.empty()) { + (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load), + [](const auto &it) -> std::string { return it.first; }); + std::sort(columns_to_load.begin(), columns_to_load.end()); + } + + for (const auto &curr_col_name : columns_to_load) { + auto it = feature_map.find(curr_col_name); + if (it == feature_map.end()) { + RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); + } + std::string column_name = it->first; + + std::string column_type; + + const dataengine::Feature &feature = it->second; + const dataengine::Feature::KindCase kind_case = feature.kind_case(); + switch (kind_case) { + case dataengine::Feature::KindCase::kBytesList: + column_type = "uint8"; + break; + + case dataengine::Feature::KindCase::kFloatList: + column_type = "float32"; + break; + + case dataengine::Feature::KindCase::kInt64List: + column_type = "int64"; + break; + + case dataengine::Feature::KindCase::KIND_NOT_SET: + RETURN_STATUS_UNEXPECTED("trying to make schema, tf_file column list type enum is KIND_NOT_SET"); + + default: + RETURN_STATUS_UNEXPECTED( + "trying to make schema, tf_file column list type enum does not match any known DE type"); + } + + RETURN_IF_NOT_OK( + data_schema_->AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1))); + } + + return Status::OK(); +} + +Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector &filenames, int64_t threads, + bool estimate) { + try { + if (threads > filenames.size()) { + threads = filenames.size(); + } + + std::vector> async_results; + + int64_t chunk_size = filenames.size() / threads; + int64_t remainder = filenames.size() % threads; + + int64_t begin = 0; + int64_t end = begin; + for (int i = 0; i < threads; i++) { + end += chunk_size; + if (remainder > 0) { + end++; + remainder--; + } + + if (estimate) { + // Parse a single file for each chunk with estimate mode on + async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, begin + 1)); + } else { + // Parse the whole chunk with estimate mode off + async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, end)); + } + + begin = end; + } + + int64_t total_rows = 0; + for (int i = 0; i < async_results.size(); i++) { + total_rows += async_results[i].get(); + } + + if (estimate) { + // Each thread only scans 1 file + // Estimated total rows = Average rows * total number of files + total_rows = total_rows / threads * filenames.size(); + } + + *out_total_rows = total_rows; + } catch (const std::exception &e) { + std::string err_msg = "Unexpected error occurred: "; + err_msg += e.what(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + return Status::OK(); +} + +int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector &filenames, int64_t begin, int64_t end) { + int64_t rows_read = 0; + for (int i = begin; i < end; i++) { + std::ifstream reader; + reader.open(filenames[i]); + if (!reader) { + MS_LOG(DEBUG) << "TFReader operator failed to open file " << filenames[i] << "."; + } + + while (reader.peek() != EOF) { + // read length + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // ignore crc header + (void)reader.ignore(static_cast(sizeof(int32_t))); + + // ignore tf_file contents + (void)reader.ignore(static_cast(record_length)); + + // ignore crc footer + (void)reader.ignore(static_cast(sizeof(int32_t))); + + rows_read++; + } + } + + return rows_read; +} + +// Visitor accept method for NodePass +Status TFReaderOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status TFReaderOp::ComputeColMap() { + // Construct the column name map for this operator (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing +// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so +// that this tf reader will produce the full set of data into the cache. +void TFReaderOp::MakeSimpleProducer() { + device_id_ = 0; + num_devices_ = 1; + total_rows_ = 0; + shuffle_files_ = false; + equal_rows_per_shard_ = false; +} + +// During tree prepare phase, operators may have specific post-operations to perform depending on +// their role. +Status TFReaderOp::PrepareNodePostAction() { + // Run common code from super class before adding TFReaderOp specific handling + RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); + + // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into + // a simpler producer of all data (no shuffling or sharding or anything) + if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { + // This sanity check had been delayed until now in the prepare loop. + // If we are not in a cache path, then we can validate the file-based sharding config. + // If we are in a cache path, there is no file-based sharding so the check is not correct in that + // situation. + if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast(num_devices_)) { + RETURN_STATUS_UNEXPECTED("Not enough tfrecord files provided\n"); + } + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h new file mode 100644 index 0000000000..c03f3957e9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h @@ -0,0 +1,420 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" + +namespace dataengine { +class Example; +class Feature; +class BytesList; +} // namespace dataengine + +namespace mindspore { +namespace dataset { +template +class Queue; + +template +class Connector; + +class JaggedConnector; +class FilenameBlock; + +using StringIndex = AutoIndexObj; + +class TFReaderOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + Status Build(std::shared_ptr *out_tf_reader_op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDataSchema(std::unique_ptr data_schema) { + builder_data_schema_ = std::move(data_schema); + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetWorkerConnectorSize(int32_t size) { + builder_worker_connector_size_ = size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &setTotalRows(int64_t total_rows) { + builder_total_rows_ = total_rows; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDatasetFilesList(const std::vector &dataset_files_list) { + builder_dataset_files_list_ = dataset_files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetColumnsToLoad(const std::vector &columns_to_load) { + builder_columns_to_load_ = columns_to_load; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShardEqualRows(bool shard_equal_rows) { + builder_equal_rows_per_shard_ = shard_equal_rows; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + private: + std::unique_ptr builder_data_schema_; + std::shared_ptr builder_sampler_; + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_worker_connector_size_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_total_rows_; + std::vector builder_dataset_files_list_; + std::vector builder_columns_to_load_; + bool builder_shuffle_files_; + bool builder_equal_rows_per_shard_; + }; + + // Constructor of TFReaderOp (2) + // @note The builder class should be used to call this constructor. + // @param num_workers - number of worker threads reading data from tf_file files. + // @param worker_connector_size - size of each internal queue. + // @param rows_per_buffer - number of rows that a full buffer will contain. + // @param total_num_rows - Number of rows to read + // @param dataset_files_list - list of filepaths for the dataset files. + // @param data_schema - the data schema object. + // @param op_connector_size - size of each queue in the connector that the child operator pulls from. + // @param columns_to_load - the names of the columns to load data from. + // @param shuffle_files - whether or not to shuffle the files before reading data. + // @param equal_rows_per_shard - whether or not to get equal rows for each process. + // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes + TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, + std::vector dataset_files_list, std::unique_ptr data_schema, + int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, + int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler); + + // Default destructor + ~TFReaderOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // Instantiates the internal queues and connectors. + // @return Status - the error code returned. + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution and + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Getter method + int64_t rows_per_buffer() const { return rows_per_buffer_; } + + // Reads all the provided tf_file files and counts the total number of rows. filenames will + // first be sectioned into equal parts, then sections are read in parallel. If threads is + // greater than the number of files, threads will be clamped to the number of files. + // @param out_total_tows - output parameter which contains the total number of rows + // @param filenames - a list of tf_file filenames. + // @param threads - number of threads to use to read the tf_file files. + // @param estimate - estimate mode, under this mode each threads will sample a single file from each chunk + // @return Status - the error code returned. + static Status CountTotalRows(int64_t *out_total_rows, const std::vector &filenames, int64_t threads = 1, + bool estimate = false); + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "TFReaderOp"; } + + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return dataset_files_list_; } + + /// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing + /// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so + /// that this tf reader will produce the full set of data into the cache. + void MakeSimpleProducer(); + + // During tree prepare phase, operators may have specific post-operations to perform depending on + // their role. + // @notes Derived versions of this function should always call it's superclass version first + // before providing their own implementations. + Status PrepareNodePostAction() override; + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. + void NotifyToFillIOBlockQueue(); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Reads a tf_file file and loads the data into multiple buffers. + // @param filename - the tf_file file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, + const int32_t &worker_id); + + // Parses a single row and puts the data into a tensor table. + // @param tf_file - the row to be parsed. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadExample(const dataengine::Example *tf_file, std::unique_ptr *tensor_table, int64_t row); + + // Parses a single cell and puts the data into a tensor table. + // @param tensor_table - the tensor table to put the parsed data in. + // @param column_values_list - the cell to parse. + // @param current_col - the column descriptor containing the expected shape and type of the data. + // @return Status - the error code returned. + Status LoadFeature(const std::unique_ptr *tensor_table, const dataengine::Feature &column_values_list, + const ColDescriptor ¤t_col, int64_t row, int32_t col); + + // Reads values from a bytes list + // @param current_col - the column descriptor containing the expected shape and type of the data. + // @param column_values_list - the cell that contains the bytes list to read from. + // @param elementStr - the string we read the value into. + // @return Status - the error code returned. + static Status LoadBytesList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::shared_ptr *tensor); + + // Reads values from a float list + // @param current_col - the column descriptor containing the expected shape and type of the data. + // @param column_values_list - the cell that contains the float list to read from. + // @Param numElements - number of values in the float list. + // @param float_array - the array we read the values into. + // @return Status - the error code returned. + Status LoadFloatList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::unique_ptr *float_array); + + // Reads values from a bytes list and casts the value to type T, must be an integral + // type compatible with int64_t + // @param current_col - the column descriptor containing the expected shape and type of the data. + // @param column_values_list - the cell that contains the int list to read from. + // @Param num_elements - number of values in the int list. + // @param tensor - the tensor we read the values into. + // @return Status - the error code returned. + template + Status LoadIntList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::shared_ptr *tensor); + + // Determines which template type to use and calls LoadIntList + // @param current_col - the column descriptor containing the expected shape and type of the data. + // @param column_values_list - the cell that contains the int list to read from. + // @Param numElements - number of values in the int list. + // @param tensor - the tensor we read the values into. + // @return Status - the error code returned. + Status LoadIntListSwitch(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::shared_ptr *tensor); + + // Reads one row of data from a tf file and creates a schema based on that row + // @return Status - the error code returned. + Status CreateSchema(const std::string tf_file, std::vector columns_to_load); + + // Meant to be called async. Will read files in the range [begin, end) and return the total rows + // @param filenames - a list of tf data filenames. + // @param begin - index of first file to read. + // @param end - one greater than the index of the last file to read. + // @return int63_t - the total number of rows of files read. + static int64_t CountTotalRowsSectioned(const std::vector &filenames, const int64_t begin, + const int64_t end); + // Fill IO block queue if shuffle is true + // @param i_keys - shuffle keys. + // @return Status - the error code returned. + Status FillIOBlockShuffle(const std::vector &i_keys); + + /** + * Fill IO block queue if shuffle is false + * @param i_keys - shuffle keys. + * @return Status - the error code returned. + */ + Status FillIOBlockNoShuffle(); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Caculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t device_id_; + int32_t num_devices_; + int64_t rows_per_buffer_; + int64_t total_rows_; + std::vector dataset_files_list_; + std::vector columns_to_load_; + bool finished_reading_dataset_; + bool shuffle_files_; + std::unique_ptr data_schema_; + std::unique_ptr filename_index_; + bool load_io_block_queue_; + bool load_jagged_connector_; + + std::unique_ptr jagged_buffer_connector_; + QueueList> io_block_queues_; + WaitPost io_block_queue_wait_post_; + std::mutex load_io_block_queue_mutex_; + std::map filename_numrows_; + int64_t num_rows_; + int64_t num_rows_per_shard_; + bool equal_rows_per_shard_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc new file mode 100644 index 0000000000..e90d423ef4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -0,0 +1,471 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/voc_op.h" + +#include +#include +#include +#include "./tinyxml2.h" +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +using tinyxml2::XMLDocument; +using tinyxml2::XMLElement; +using tinyxml2::XMLError; +namespace mindspore { +namespace dataset { +const char kColumnImage[] = "image"; +const char kColumnTarget[] = "target"; +const char kColumnAnnotation[] = "annotation"; +const char kJPEGImagesFolder[] = "/JPEGImages/"; +const char kSegmentationClassFolder[] = "/SegmentationClass/"; +const char kAnnotationsFolder[] = "/Annotations/"; +const char kImageSetsSegmentation[] = "/ImageSets/Segmentation/"; +const char kImageSetsMain[] = "/ImageSets/Main/"; +const char kImageExtension[] = ".jpg"; +const char kSegmentationExtension[] = ".png"; +const char kAnnotationExtension[] = ".xml"; +const char kImageSetsExtension[] = ".txt"; + +VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); + builder_task_type_ = TaskType::Segmentation; +} + +Status VOCOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + if (builder_task_type_ == TaskType::Segmentation) { + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + } else if (builder_task_type_ == TaskType::Detection) { + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + } + *ptr = std::make_shared(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, + builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, + builder_decode_, std::move(builder_schema_), std::move(builder_sampler_)); + return Status::OK(); +} + +Status VOCOp::Builder::SanityCheck() { + Path dir(builder_dir_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, + const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, + int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_workers, queue_size, std::move(sampler)), + decode_(decode), + row_cnt_(0), + buf_cnt_(0), + task_type_(task_type), + task_mode_(task_mode), + folder_path_(folder_path), + class_index_(class_index), + rows_per_buffer_(rows_per_buffer), + data_schema_(std::move(data_schema)) { + io_block_queues_.Init(num_workers_, queue_size); +} + +Status VOCOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) > num_rows_) continue; + keys->push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); + keys->clear(); + } + } + return Status::OK(); +} + +Status VOCOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + std::shared_ptr sample_ids; + RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); + if (sample_ids->type() != DataType(DataType::DE_INT64)) { + RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); + } + RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +void VOCOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows: " << num_rows_ << "\nVOC Directory: " << folder_path_ << "\n\n"; + } +} + +Status VOCOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); + return Status::OK(); +} + +Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *trow) { + if (task_type_ == TaskType::Segmentation) { + std::shared_ptr image, target; + const std::string kImageFile = + folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); + const std::string kTargetFile = + folder_path_ + std::string(kSegmentationClassFolder) + image_id + std::string(kSegmentationExtension); + RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); + RETURN_IF_NOT_OK(ReadImageToTensor(kTargetFile, data_schema_->column(1), &target)); + (*trow) = TensorRow(row_id, {std::move(image), std::move(target)}); + } else if (task_type_ == TaskType::Detection) { + std::shared_ptr image, annotation; + const std::string kImageFile = + folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); + const std::string kAnnotationFile = + folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); + RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); + RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, data_schema_->column(1), &annotation)); + (*trow) = TensorRow(row_id, {std::move(image), std::move(annotation)}); + } + return Status::OK(); +} + +Status VOCOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + TensorRow trow; + for (const uint64_t &key : keys) { + RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_ids_[key], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +Status VOCOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, (std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) return Status::OK(); + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +Status VOCOp::ParseImageIds() { + std::string image_sets_file; + if (task_type_ == TaskType::Segmentation) { + image_sets_file = + folder_path_ + std::string(kImageSetsSegmentation) + task_mode_ + std::string(kImageSetsExtension); + } else if (task_type_ == TaskType::Detection) { + image_sets_file = folder_path_ + std::string(kImageSetsMain) + task_mode_ + std::string(kImageSetsExtension); + } + std::ifstream in_file; + in_file.open(image_sets_file); + if (in_file.fail()) { + RETURN_STATUS_UNEXPECTED("Fail to open file: " + image_sets_file); + } + std::string id; + while (getline(in_file, id)) { + if (id.size() > 0 && id[id.size() - 1] == '\r') { + image_ids_.push_back(id.substr(0, id.size() - 1)); + } else { + image_ids_.push_back(id); + } + } + in_file.close(); + image_ids_.shrink_to_fit(); + num_rows_ = image_ids_.size(); + return Status::OK(); +} + +Status VOCOp::ParseAnnotationIds() { + std::vector new_image_ids; + for (auto id : image_ids_) { + const std::string kAnnotationName = + folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension); + RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName)); + if (label_map_.find(kAnnotationName) != label_map_.end()) { + new_image_ids.push_back(id); + } + } + + if (image_ids_.size() != new_image_ids.size()) { + image_ids_.clear(); + image_ids_.insert(image_ids_.end(), new_image_ids.begin(), new_image_ids.end()); + } + uint32_t count = 0; + for (auto &label : label_index_) { + label.second = count++; + } + + num_rows_ = image_ids_.size(); + return Status::OK(); +} + +Status VOCOp::ParseAnnotationBbox(const std::string &path) { + if (!Path(path).Exists()) { + RETURN_STATUS_UNEXPECTED("File is not found : " + path); + } + Bbox bbox; + XMLDocument doc; + XMLError e = doc.LoadFile(common::SafeCStr(path)); + if (e != XMLError::XML_SUCCESS) { + RETURN_STATUS_UNEXPECTED("Xml load failed"); + } + XMLElement *root = doc.RootElement(); + if (root == nullptr) { + RETURN_STATUS_UNEXPECTED("Xml load root element error"); + } + XMLElement *object = root->FirstChildElement("object"); + if (object == nullptr) { + RETURN_STATUS_UNEXPECTED("No object find in " + path); + } + while (object != nullptr) { + std::string label_name; + float xmin = 0.0, ymin = 0.0, xmax = 0.0, ymax = 0.0, truncated = 0.0, difficult = 0.0; + XMLElement *name_node = object->FirstChildElement("name"); + if (name_node != nullptr && name_node->GetText() != 0) label_name = name_node->GetText(); + XMLElement *truncated_node = object->FirstChildElement("truncated"); + if (truncated_node != nullptr) truncated = truncated_node->FloatText(); + XMLElement *difficult_node = object->FirstChildElement("difficult"); + if (difficult_node != nullptr) difficult = difficult_node->FloatText(); + + XMLElement *bbox_node = object->FirstChildElement("bndbox"); + if (bbox_node != nullptr) { + XMLElement *xmin_node = bbox_node->FirstChildElement("xmin"); + if (xmin_node != nullptr) xmin = xmin_node->FloatText(); + XMLElement *ymin_node = bbox_node->FirstChildElement("ymin"); + if (ymin_node != nullptr) ymin = ymin_node->FloatText(); + XMLElement *xmax_node = bbox_node->FirstChildElement("xmax"); + if (xmax_node != nullptr) xmax = xmax_node->FloatText(); + XMLElement *ymax_node = bbox_node->FirstChildElement("ymax"); + if (ymax_node != nullptr) ymax = ymax_node->FloatText(); + } else { + RETURN_STATUS_UNEXPECTED("bndbox dismatch in " + path); + } + if (label_name != "" && (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) && xmin > 0 && + ymin > 0 && xmax > xmin && ymax > ymin) { + std::vector bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult}; + bbox.emplace_back(std::make_pair(label_name, bbox_list)); + label_index_[label_name] = 0; + } + object = object->NextSiblingElement("object"); + } + if (bbox.size() > 0) label_map_[path] = bbox; + return Status::OK(); +} + +Status VOCOp::InitSampler() { + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +Status VOCOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&VOCOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(this->ParseImageIds()); + if (task_type_ == TaskType::Detection) { + RETURN_IF_NOT_OK(this->ParseAnnotationIds()); + } + RETURN_IF_NOT_OK(this->InitSampler()); + return Status::OK(); +} + +Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); + if (decode_ == true) { + Status rc = Decode(*tensor, tensor); + if (rc.IsError()) { + RETURN_STATUS_UNEXPECTED("fail to decode file: " + path); + } + } + return Status::OK(); +} + +Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, + std::shared_ptr *tensor) { + Bbox bbox_info = label_map_[path]; + std::vector bbox_row; + dsize_t bbox_column_num = 0, bbox_num = 0; + for (auto box : bbox_info) { + if (label_index_.find(box.first) != label_index_.end()) { + std::vector bbox; + bbox.insert(bbox.end(), box.second.begin(), box.second.end()); + if (class_index_.find(box.first) != class_index_.end()) { + bbox.push_back(static_cast(class_index_[box.first])); + } else { + bbox.push_back(static_cast(label_index_[box.first])); + } + bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); + if (bbox_column_num == 0) { + bbox_column_num = static_cast(bbox.size()); + } + bbox_num++; + } + } + + std::vector bbox_dim = {bbox_num, bbox_column_num}; + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, col.tensorImpl(), TensorShape(bbox_dim), col.type(), + reinterpret_cast(&bbox_row[0]))); + return Status::OK(); +} + +Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, + const py::dict &dict, int64_t *count) { + if (task_type == "Detection") { + std::map input_class_indexing; + for (auto p : dict) { + (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), + py::reinterpret_borrow(p.second))); + } + + std::shared_ptr op; + RETURN_IF_NOT_OK( + Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); + RETURN_IF_NOT_OK(op->ParseImageIds()); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + *count = static_cast(op->image_ids_.size()); + } else if (task_type == "Segmentation") { + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op)); + RETURN_IF_NOT_OK(op->ParseImageIds()); + *count = static_cast(op->image_ids_.size()); + } + + return Status::OK(); +} + +Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, + const py::dict &dict, std::map *output_class_indexing) { + std::map input_class_indexing; + for (auto p : dict) { + (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), + py::reinterpret_borrow(p.second))); + } + + if (!input_class_indexing.empty()) { + *output_class_indexing = input_class_indexing; + } else { + std::shared_ptr op; + RETURN_IF_NOT_OK( + Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); + RETURN_IF_NOT_OK(op->ParseImageIds()); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + for (const auto label : op->label_index_) { + (*output_class_indexing).insert(std::make_pair(label.first, label.second)); + } + } + + return Status::OK(); +} +// Visitor accept method for NodePass +Status VOCOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status VOCOp::ComputeColMap() { + // Set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h new file mode 100644 index 0000000000..e0c46c7a94 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h @@ -0,0 +1,294 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +// Forward declares +template +class Queue; + +using Bbox = std::vector>>; + +class VOCOp : public ParallelOp, public RandomAccessOp { + public: + enum class TaskType { Segmentation = 0, Detection = 1 }; + + class Builder { + public: + // Constructor for Builder class of ImageFolderOp + // @param uint32_t numWrks - number of parallel workers + // @param dir - directory folder got ImageNetFolder + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method. + // @param const std::string & build_dir + // @return Builder setter method returns reference to the builder. + Builder &SetDir(const std::string &build_dir) { + builder_dir_ = build_dir; + return *this; + } + + // Setter method. + // @param const std::map &map - a class name to label map + // @return Builder setter method returns reference to the builder. + Builder &SetClassIndex(const std::map &map) { + builder_labels_to_read_ = map; + return *this; + } + + // Setter method. + // @param const std::string & task_type + // @return Builder setter method returns reference to the builder. + Builder &SetTask(const std::string &task_type) { + if (task_type == "Segmentation") { + builder_task_type_ = TaskType::Segmentation; + } else if (task_type == "Detection") { + builder_task_type_ = TaskType::Detection; + } + return *this; + } + + // Setter method. + // @param const std::string & task_mode + // @return Builder setter method returns reference to the builder. + Builder &SetMode(const std::string &task_mode) { + builder_task_mode_ = task_mode; + return *this; + } + + // Setter method. + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + // Setter method. + // @param bool do_decode + // @return Builder setter method returns reference to the builder. + Builder &SetDecode(bool do_decode) { + builder_decode_ = do_decode; + return *this; + } + + // Check validity of input args + // @return = The error code return + Status SanityCheck(); + + // The builder "Build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + bool builder_decode_; + std::string builder_dir_; + TaskType builder_task_type_; + std::string builder_task_mode_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int32_t builder_rows_per_buffer_; + std::shared_ptr builder_sampler_; + std::unique_ptr builder_schema_; + std::map builder_labels_to_read_; + }; + + // Constructor + // @param TaskType task_type - task type of VOC + // @param std::string task_mode - task mode of VOC + // @param std::string folder_path - dir directory of VOC + // @param std::map class_index - input class-to-index of annotation + // @param int32_t num_workers - number of workers reading images in parallel + // @param int32_t rows_per_buffer - number of images (rows) in each buffer + // @param int32_t queue_size - connector queue size + // @param bool decode - whether to decode images + // @param std::unique_ptr data_schema - the schema of the VOC dataset + // @param std::shared_ptr sampler - sampler tells VOCOp what to read + VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, + const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, + int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler); + + // Destructor + ~VOCOp() = default; + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t workerId - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Main Loop of VOCOp + // Master thread: Fill IOBlockQueue, then goes to sleep + // Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector + // @return Status - The error code return + Status operator()() override; + + // A print method typically used for debugging + // @param out + // @param show_all + void Print(std::ostream &out, bool show_all) const override; + + // @param const std::string &dir - VOC dir path + // @param const std::string &task_type - task type of reading voc job + // @param const std::string &task_mode - task mode of reading voc job + // @param const py::dict &dict - input dict of class index + // @param int64_t *count - output rows number of VOCDataset + static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, + const py::dict &dict, int64_t *count); + + // @param const std::string &dir - VOC dir path + // @param const std::string &task_type - task type of reading voc job + // @param const std::string &task_mode - task mode of reading voc job + // @param const py::dict &dict - input dict of class index + // @param int64_t numSamples - samples number of VOCDataset + // @param std::map *output_class_indexing - output class index of VOCDataset + static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, + const py::dict &dict, std::map *output_class_indexing); + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "VOCOp"; } + + private: + // Initialize Sampler, calls sampler->Init() within + // @return Status - The error code return + Status InitSampler(); + + // Load a tensor row according to image id + // @param row_id_type row_id - id for this tensor row + // @param std::string image_id - image id + // @param TensorRow row - image & target read into this tensor row + // @return Status - The error code return + Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); + + // @param const std::string &path - path to the image file + // @param const ColDescriptor &col - contains tensor implementation and datatype + // @param std::shared_ptr tensor - return + // @return Status - The error code return + Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); + + // @param const std::string &path - path to the image file + // @param const ColDescriptor &col - contains tensor implementation and datatype + // @param std::shared_ptr tensor - return + // @return Status - The error code return + Status ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); + + // @param const std::vector &keys - keys in ioblock + // @param std::unique_ptr db + // @return Status - The error code return + Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); + + // Read image list from ImageSets + // @return Status - The error code return + Status ParseImageIds(); + + // Read annotation from Annotation folder + // @return Status - The error code return + Status ParseAnnotationIds(); + + // @param const std::string &path - path to annotation xml + // @return Status - The error code return + Status ParseAnnotationBbox(const std::string &path); + + // @param const std::shared_ptr &sample_ids - sample ids of tensor + // @param std::vector *keys - image id + // @return Status - The error code return + Status TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); + + // Called first when function is called + // @return Status - The error code return + Status LaunchThreadsAndInitOp(); + + // Reset dataset state + // @return Status - The error code return + Status Reset() override; + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + bool decode_; + int64_t row_cnt_; + int64_t buf_cnt_; + std::string folder_path_; + TaskType task_type_; + std::string task_mode_; + int32_t rows_per_buffer_; + std::unique_ptr data_schema_; + + WaitPost wp_; + std::vector image_ids_; + QueueList> io_block_queues_; + std::map class_index_; + std::map label_index_; + std::map label_map_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc new file mode 100644 index 0000000000..d1f07983f7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/datasetops/take_op.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status TakeOp::Builder::SanityCheck() const { + if (build_max_takes_ <= 0) { + std::string err_msg("Take count must be greater than 0."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +// The builder "build" method creates the final object. +Status TakeOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_max_takes_, builder_op_connector_size_); + return Status::OK(); +} + +// Constructor of the TakeOp. +TakeOp::TakeOp(int32_t count, int32_t op_connector_size) + : PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {} + +// A print method typically used for debugging +void TakeOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [takes: " << max_takes_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nTake count: " << take_count_ << "\nMax takes: " << max_takes_ << "\n\n"; + } +} + +// Main entry point for Take +Status TakeOp::operator()() { + TaskManager::FindMe()->Post(); + std::unique_ptr buf; + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); + + while (buf->eof() == false) { + if (take_count_ == max_takes_) { + // Do drain Operation + while (!buf->eoe() && !buf->eof()) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); + } + } + + // Loop until non EOE is received + if (buf->eoe()) { + take_count_ = 0; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); + continue; + } + + // Get buffer and push back when take_count is still small + if (take_count_ < max_takes_) { + std::unique_ptr p_buffer; + RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer)); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer))); + } + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); + } + + take_count_ = 0; + MS_LOG(DEBUG) << "Meet the end and push-back eof buffer."; + auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + return Status::OK(); +} + +// Function FillBuffer mainly prepare the buffer for returning +Status TakeOp::FillBuffer(std::unique_ptr *buffer, std::unique_ptr *data_buffer) { + int32_t buffer_size = (*buffer)->NumRows(); + if (take_count_ + buffer_size < max_takes_) { + *data_buffer = std::move(*buffer); + take_count_ = take_count_ + buffer_size; + } else { + MS_LOG(DEBUG) << "In last buffer: Push one buffer."; + std::unique_ptr new_tensor_table = std::make_unique(); + while (take_count_ < max_takes_) { + TensorRow new_row; + RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row)); + take_count_++; + new_tensor_table->push_back(new_row); + } + (*buffer)->set_tensor_table(std::move(new_tensor_table)); + *data_buffer = std::move(*buffer); + } + return Status::OK(); +} + +// Visitor accept method for NodePass +Status TakeOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h new file mode 100644 index 0000000000..7f3f821bd8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ + +#include +#include +#include +#include "minddata/dataset/engine/datasetops/pipeline_op.h" + +namespace mindspore { +namespace dataset { +class TakeOp : public PipelineOp { + public: + // The nested builder class inside of the TakeOp is used to help manage all of the arguments + // for constructing it. This take op is very simple though, so this builder is really just + // provided for a consistent look and feel for creators of Dataset operators overall. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @param count - The number of takes to do + // @return This is a constructor. + explicit Builder(int32_t count); + + // Default destructor + ~Builder() = default; + + // The builder "build" method creates the final object. + // @return shared_ptr to the new TakeOp object + Status Build(std::shared_ptr *); + + private: + int32_t build_max_takes_; + int32_t builder_op_connector_size_; + + Status SanityCheck() const; + }; + + // Constructor of the TakeOp. + // @note The builder class should be used to call it + // @param count - The number of takes to do + explicit TakeOp(int32_t count, int32_t op_connector_size); + + // Destructor + ~TakeOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param ro - reference to the TakeOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) { + ro.Print(out, false); + return out; + } + + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "TakeOp"; } + + private: + int32_t max_takes_; // The number of takes that the user requested + int32_t take_count_; // A counter for the current number of executed takes + + Status FillBuffer(std::unique_ptr *buffer, std::unique_ptr *data_buffer); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc new file mode 100644 index 0000000000..88019c30fc --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc @@ -0,0 +1,268 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/zip_op.h" +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +ZipOp::Builder::Builder() { + // Some arguments to the ZipOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the ZipOp by + // using the various builder set methods. + + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status ZipOp::Builder::SanityCheck() const { return Status::OK(); } + +Status ZipOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_); + return Status::OK(); +} + +// Construct ZipOp here, local variables initialized in operator due to tree construction restrictions +ZipOp::ZipOp(int32_t rows_per_buffer, int32_t op_connector_size) + : PipelineOp(op_connector_size), + children_num_(0), + rows_per_buffer_(rows_per_buffer), + buffer_id_(0), + draining_(false), + eof_(false) {} + +// destructor +ZipOp::~ZipOp() {} + +// Entry point for Zip, called by launch() +Status ZipOp::operator()() { + // The children_num_ parameter needs to be put here + children_num_ = child_.size(); + // Synchronize with TaskManager once the thread is created. + TaskManager::FindMe()->Post(); + + // initialize the iterators + for (int32_t i = 0; i < children_num_; ++i) { + // magic number 0 since Zip is not a parallel Op + child_iterators_.push_back(std::make_unique(this, 0, i)); + } + + // Loop until eof is true + while (!eof_) { + // Create tensor table and prepare it by fetching and packing the first zipped row into it. + std::unique_ptr curr_table = std::make_unique(); + RETURN_IF_NOT_OK(prepare(curr_table.get())); + + // If an eof got picked up during the above prepare, then we're done + if (eof_) { + break; + } + while (!draining_) { + // 1. If a previous loop iteration sent the current table out, then create a new one. + if (curr_table == nullptr) { + curr_table = std::make_unique(); + } + + // 2 fill the table. Note: draining mode might get turned on if any of the child inputs were done + RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); + + // 3 create and update buffer and send it to the out connector + if (!curr_table->empty()) { + std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); + curr_buffer->set_tensor_table(std::move(curr_table)); + MS_LOG(DEBUG) << "Zip operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " + << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << "."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + buffer_id_++; + } + } + + // 4 handle drain state. + if (draining_) { + MS_LOG(DEBUG) << "Zip operator is now draining child inputs."; + RETURN_IF_NOT_OK(drainPipeline()); + // Now that we have drained child inputs, send the eoe up. + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + } + } + + // 5 handle eof + // propagate eof here. + MS_LOG(DEBUG) << "Zip operator got EOF, propagating."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Handles preprocessing of the main loop, used when starting new epoch +Status ZipOp::prepare(TensorQTable *const table) { + MS_LOG(DEBUG) << "Zip operator prepares for new epoch."; + draining_ = false; + buffer_id_ = 0; + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase requires a tensor table."); + } + // fill initial row + TensorRow new_row; + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + + // If the first row fetching resulted in eof, then we are done. + if (eof_) { + return Status::OK(); + } + if (new_row.empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!"); + } + + // Pack this first row into our tensor table + table->push_back(std::move(new_row)); + + return Status::OK(); +} + +// fillBuffer always expects a new table to fill +Status ZipOp::fillBuffer(TensorQTable *const table) { + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp fillBuffer null table pointer."); + } + TensorRow new_row; + while (table->size() < static_cast(rows_per_buffer_)) { + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + // Early exit the loop if we got empty row from any of our child iterations + if (new_row.empty()) { + return Status::OK(); + } + // else we got a row so pack it into the tensor table. + table->push_back(std::move(new_row)); + } + return Status::OK(); +} + +// fetches next zip buffer row (merged row) +Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) { + // iterate over all iterators and generate a row + for (int32_t i = 0; i < children_num_; ++i) { + TensorRow new_row = {}; + RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row)); + // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row + if (new_row.empty()) { + // If we did not get a row from any of the children, then it's the end of an epoch and we can move + // to drain state. + MS_LOG(DEBUG) << "Zip operator child iterator produced empty row."; + draining_ = true; + new_zip_row->clear(); + // If we picked up an eof here, then we are completely done. + if ((child_iterators_[i])->eof_handled()) { + MS_LOG(DEBUG) << "Zip operator iterator got EOF."; + eof_ = true; + } + return Status::OK(); + } else { + MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << "."; + // if row isn't empty then we can append the fetched row with new_zip_row + new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end()); + } + } + MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << "."; + return Status::OK(); +} + +// drain end of epoch messages from iterator for this epoch +Status ZipOp::drainPipeline() { + // we don't need to drain if we reached eof + if (eof_) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "ZipOp draining should not be done if already at eof!"); + } + for (int32_t con = 0; con < children_num_; ++con) { + MS_LOG(DEBUG) << "Zip operator draining child at " << con << "."; + RETURN_IF_NOT_OK(child_iterators_[con]->Drain()); + } + // at this point all connectors don't contain end of epoch messages. next iteration should be clean + return Status::OK(); +} + +// A function that prints info about the Operator +void ZipOp::Print(std::ostream &out, // In: The output stream to print to + bool show_all) const { // In: T/F if it should print everything + // Always show the id and name as first line regardless if this is summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nDatasets: " << children_num_ << "\n\n"; + } +} + +// overwrite function and handle eof +Status ZipOp::EofReceived(int32_t) { + MS_LOG(DEBUG) << "Zip operator EOF received, do nothing now."; + return Status::OK(); +} + +// overwrite function and handle eoe +Status ZipOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +// Visitor accept method for NodePass +Status ZipOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status ZipOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + column_name_id_map_ = {}; + for (int32_t i = 0; i < child_.size(); ++i) { + // Initializing col_name_id_map from the child. + const std::unordered_map col_name_id_map = child_[i]->column_name_id_map(); + int32_t colsCurrent = column_name_id_map_.size(); + // the update code below shouldn't do anything bad if the column name already exists. + for (const auto &pair : col_name_id_map) { + std::string name = pair.first; + int32_t old_id = pair.second; + // check if name already exists in column name descriptor + if (column_name_id_map_.count(name) == 1) { + RETURN_STATUS_UNEXPECTED("key already exists when zipping datasets"); + } + column_name_id_map_[name] = old_id + colsCurrent; + } + } + MS_LOG(DEBUG) << "Setting column map:\n" << this->ColumnNameMapAsString(); + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h new file mode 100644 index 0000000000..c9466e26e2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h @@ -0,0 +1,158 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ +#define DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// forward declare +class DataBuffer; + +class ZipOp : public PipelineOp { + public: + // The nested builder class inside of the ZipOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + // NOTE: the rows per buffer with initial value 0 means to default to the number of rows from the first child + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // The builder "build" method creates the ZipOp dataset Operator. + // @return shared_ptr to the new ZipOp object + Status Build(std::shared_ptr *); + + private: + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + + Status SanityCheck() const; + }; + + // Constructor for ZipOp + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size + ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); + + // Destructor + ~ZipOp(); + + Status EofReceived(int32_t) override; + + Status EoeReceived(int32_t) override; + + // Print function for Zip + // @param out - output stream to print to + // @param show_all - if it should print everything + void Print(std::ostream &out, bool show_all) const override; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const ZipOp &zo) { + zo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "ZipOp"; } + + private: + // Handles preprocessing of the main loop, used when starting new epoch + Status prepare(TensorQTable *const table); + + // This function calls takes a table repeatedly adds rows to it. + // @param table a table of tensors to be moved into a buffer + Status fillBuffer(TensorQTable *const table); + + // Special handle case where an empty row has been received from child iterator + // @note - we need to drain eoe signals from all children connectors. + // @details - when this function is called, then we encountered eoe at child iterator + // we have to drain rows from other child iterators until we hit eoe from all other child iterators + Status drainPipeline(); + + // Merges 1 row from each childIterator together + // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty + // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true + // @details merge rows from iterator together. This is the main functionality for ZipOp + // this function takes one row and fills it with tensors from rows fetched + // from childIterators. + // @example: + // Zips multiple rows at a time, the output is store in newZipRow + // 1 a T + // \ | / + // 1, a, T + Status getNextTensorRow(TensorRow *const new_zip_row); + + // Computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t children_num_; + int32_t rows_per_buffer_; + int32_t buffer_id_; + bool draining_; + bool eof_; + std::vector> child_iterators_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/db_connector.h b/mindspore/ccsrc/minddata/dataset/engine/db_connector.h new file mode 100644 index 0000000000..4a5c20bc12 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/db_connector.h @@ -0,0 +1,98 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DB_CONNECTOR_H_ +#define DATASET_ENGINE_DB_CONNECTOR_H_ + +#include +#include +#include "minddata/dataset/engine/connector.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/core/constants.h" + +namespace mindspore { +namespace dataset { +// DbConnector is a derived class from Connector with added logic to handle EOE and EOF. +// The Connector class itself is responsible to ensure deterministic order on every run. +class DbConnector : public Connector> { + public: + // Constructor of DbConnector + // @note DbConnector will create internal N number of blocking queues, where N = nProducers. + // See Connector.h for more details. + // @param n_producers The number of threads producing data into this DbConnector. + // @param n_consumers The number of thread consuming data from this DbConnector. + // @param queue_capacity The number of element (DataBuffer) for each internal queue. + DbConnector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity) + : Connector>(n_producers, n_consumers, queue_capacity), end_of_file_(false) {} + + // Destructor of DbConnector + ~DbConnector() = default; + + // Add a unique_ptr into the DbConnector. + // @note The caller of this add method should use std::move to pass the ownership to DbConnector. + // @param worker_id The id of a worker thread calling this method. + // @param el A rvalue reference to an element to be passed/added/pushed. + Status Add(int32_t worker_id, std::unique_ptr &&el) noexcept { + return (Connector>::Push(worker_id, std::move(el))); + } + + // Get a unique_ptr from the DbConnector. + // @note After the first EOF Buffer is encountered, subsequent pop()s will return EOF Buffer. + // This will provide/propagate the EOF to all consumer threads of this Connector. + // Thus, When the num_consumers < num_producers, there will be extra EOF messages in some of the internal queues + // and reset() must be called before reusing DbConnector. + // @param worker_id The id of a worker thread calling this method. + // @param result The address of a unique_ptr where the popped element will be placed. + // @param retry_if_eoe A flag to allow the same thread invoke pop() again if the current pop returns eoe buffer. + Status PopWithRetry(int32_t worker_id, std::unique_ptr *result, bool retry_if_eoe = false) noexcept { + if (result == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "[ERROR] nullptr detected when getting data from db connector"); + } else { + std::unique_lock lk(m_); + RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return (expect_consumer_ == worker_id) || end_of_file_; })); + // Once an EOF message is encountered this flag will be set and we can return early. + if (end_of_file_) { + *result = std::make_unique(0, DataBuffer::kDeBFlagEOF); + } else { + RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); + if (*result == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "[ERROR] nullptr detected when getting data from db connector"); + } + // Setting the internal flag once the first EOF is encountered. + if ((*result)->eof()) { + end_of_file_ = true; + } + pop_from_ = (pop_from_ + 1) % num_producers_; + } + // Do not increment expect_consumer_ when result is eoe and retry_if_eoe is set. + if (!((*result)->eoe() && retry_if_eoe)) { + expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; + } + } + out_buffers_count_++; + cv_.NotifyAll(); + return Status::OK(); + } + + private: + // A flag to indicate the end of stream has been encountered. + bool end_of_file_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DB_CONNECTOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc new file mode 100644 index 0000000000..55dec24e79 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc @@ -0,0 +1,312 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/execution_tree.h" +#include +#include +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/opt/pre/removal_pass.h" +#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" +#include "minddata/dataset/engine/opt/post/repeat_pass.h" +#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" +#include "minddata/dataset/engine/perf/profiling.h" +#include "minddata/dataset/engine/perf/monitor.h" + +namespace mindspore { +namespace dataset { +// Constructor +ExecutionTree::ExecutionTree() : id_count_(0) { + tg_ = std::make_unique(); + tree_state_ = kDeTStateInit; + prepare_flags_ = kDePrepNone; + perf_monitor_ = std::make_unique(this); + profiling_manager_ = std::make_unique(this); + optimize_ = common::GetEnv("OPTIMIZE") == "true" ? true : false; +} + +// Destructor +ExecutionTree::~ExecutionTree() { (void)tg_->ServiceStop(); } + +// Associates a DatasetOp with this tree. This assigns a valid node id to the operator and +// provides it with a link to the tree. A node cannot form any relationships (parent/child) with +// other nodes unless they are associated with the same tree. +Status ExecutionTree::AssociateNode(const std::shared_ptr &op) { + // If we are already a part of the tree, no-op + if (op->tree_ == this) { + return Status::OK(); + } + if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { + std::string err_msg = + "Invalid tree state for adding a node. Current state: " + std::to_string(static_cast(tree_state_)) + + " Expected states: " + std::to_string(static_cast(kDeTStateInit)) + " or " + + std::to_string(static_cast(kDeTStateBuilding)); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Enter the building state if we were not already there + tree_state_ = kDeTStateBuilding; + + // Assign an id to the operator + op->set_id(id_count_); + id_count_++; + + // Assign our tree into the op so that each op has a link back to the tree + op->set_tree(this); + return Status::OK(); +} + +// Sets the root node of the tree +Status ExecutionTree::AssignRoot(const std::shared_ptr &op) { + // Tree must be in building state before we can assign root to it + if (tree_state_ != kDeTStateBuilding) { + std::string err_msg = + "Invalid tree state for assigning a root node. Current state: " + std::to_string(static_cast(tree_state_)) + + " Expected state: " + std::to_string(static_cast(kDeTStateBuilding)); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // If they didn't already call AssociateNode for this node before calling AssignRoot, + // then do so now. + if (op->operator_id_ == DatasetOp::kInvalidOperatorId) { + RETURN_IF_NOT_OK(this->AssociateNode(op)); + } + + // Then add it as the root. + root_ = op; + + return Status::OK(); +} + +// A print method typically used for debugging +void ExecutionTree::Print(std::ostream &out, const std::shared_ptr &op) const { + out << "Execution tree summary:\n" + << "-----------------------\n"; + this->PrintNode(out, op == nullptr ? root_ : op, "", true, false); + out << "\nExecution tree operator details:\n" + << "--------------------------------\n"; + this->PrintNode(out, op == nullptr ? root_ : op, "", true, true); +} + +// A helper functions for doing the recursive printing +void ExecutionTree::PrintNode(std::ostream &out, const std::shared_ptr &dataset_op, std::string indent, + bool last, bool detailed) const { + // Decide which printer to use based on detailed arg. + if (!detailed) { + out << indent << "+- " << *dataset_op; + indent += (last ? " " : "| "); + } else { + dataset_op->Print(out, detailed); + } + + // Descend to children + for (int32_t i = 0; i < dataset_op->child_.size(); ++i) { + this->PrintNode(out, dataset_op->child_[i], indent, (i == (dataset_op->child_.size() - 1)), detailed); + } +} + +// Start the execution of the tree +Status ExecutionTree::Launch() { + // Tree must be built and prepared before it can be launched! + if (tree_state_ != kDeTStateReady) { + std::string err_msg = + "Invalid tree state for launching tree. Current state: " + std::to_string(static_cast(tree_state_)) + + " Expected state: " + std::to_string(static_cast(kDeTStateReady)); + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::ostringstream ss; + ss << *this; + + // Profiling infrastructures need to be initialized before Op launching + if (profiling_manager_->IsProfilingEnable()) { + // Setup profiling manager + RETURN_IF_NOT_OK(profiling_manager_->Initialize()); + // Launch Monitor Thread + RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Monitor Thread launched", std::ref(*perf_monitor_))); + } + + MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str(); + for (auto itr = this->begin(); itr != this->end(); ++itr) { + // An inlined operator is one that has an output connector size of 0, and it does not + // require a thread to execute. Instead, the work of this operator is executed inlined + // from the tree node directly above it (or in the case of a root node, it runs from within + // the launching tree/user thread. Do not exec any thread for an inlined op. + itr->state_ = DatasetOp::OpState::kDeOpRunning; + if (!itr->inlined()) { + RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr))); + // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp + } + } + + tree_state_ = kDeTStateExecuting; + + return Status::OK(); +} + +// A function that traverse the tree in postorder then save the results in nodes +void ExecutionTree::Iterator::PostOrderTraverse(const std::shared_ptr &node) { + if (node == nullptr) { + return; + } + for (int32_t i = 0; i < node->child_.size(); ++i) { + PostOrderTraverse(node->child_[i]); + } + nodes_.push_back(node); +} + +ExecutionTree::Iterator::Iterator(const std::shared_ptr &root) : ind_(0) { + // post-order traverse the tree, if root is null, it return + PostOrderTraverse(root); + nodes_.emplace_back(nullptr); +} + +// Given the number of workers, launches the worker entry function for each. Essentially a +// wrapper for the TaskGroup handling that is stored inside the execution tree. +Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function func) { + // Launch the workers + for (int32_t i = 0; i < num_workers; ++i) { + RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i))); + } + return Status::OK(); +} + +// The driver of the prepare phase of the execution tree. +// Prepare phase consists of three sub phases +// +// 1. PrepareTreePreAction() +// Compulsory transformation/action pre optimization. +// For example, CacheOp Insertion +// +// 2. Optimize() +// Optimization transformation/action, optional +// For example, MapOp Fusion +// +// 3. PrepareTreePostAction() +// Compulsory transformation/action post optimization. +// For example, repeatOp inlining +// +// @return Status - The error code return +Status ExecutionTree::Prepare() { + // Pre optimization compulsory transformation + RETURN_IF_NOT_OK(this->PrepareTreePreAction()); + + // If optional optimizations are enabled + if (optimize_) { + RETURN_IF_NOT_OK(this->Optimize()); + } + + // Post optimization compulsory transformation + RETURN_IF_NOT_OK(this->PrepareTreePostAction()); + + // Existing transformation implementation, will be removed later + RETURN_IF_NOT_OK(this->PrepareDeprecated()); + return Status::OK(); +} + +Status ExecutionTree::PrepareTreePreAction() { + bool modified = false; + std::vector> pre_actions; + // Construct pre actions + MS_LOG(INFO) << "Running pre pass loops."; + pre_actions.push_back(std::make_unique()); + pre_actions.push_back(std::make_unique()); + // Apply pre action passes + for (auto &pass : pre_actions) { + RETURN_IF_NOT_OK(pass->Run(this, &modified)); + } + MS_LOG(INFO) << "Pre passes complete."; + return Status::OK(); +} + +Status ExecutionTree::PrepareTreePostAction() { + // The tree is ready to be prepared. + tree_state_ = kDeTStatePrepare; + + bool modified = false; + std::vector> post_actions; + // Construct pre actions + MS_LOG(INFO) << "Running post pass loops."; + post_actions.push_back(std::make_unique()); + + // Apply post action passes + for (auto &pass : post_actions) { + RETURN_IF_NOT_OK(pass->Run(this, &modified)); + } + MS_LOG(INFO) << "Post passes complete."; + + return Status::OK(); +} + +Status ExecutionTree::Optimize() { + // Vector of optimizations, currently only 1, add more as necessary + std::vector> optimizations; + optimizations.push_back(std::make_unique()); + // vector of flags for each optimization + std::vector modified(optimizations.size(), false); + for (auto i = 0; i < optimizations.size(); i++) { + auto m = false; + optimizations[i]->Run(this, &m); + modified[i] = m; + } + return Status::OK(); +} + +// The driver of the prepare phase of the execution tree. The prepare phase will recursively +// walk the tree to perform modifications to the tree or specific nodes within the tree to get +// it ready for execution. +// +// This driver is deprecated. +Status ExecutionTree::PrepareDeprecated() { + // Tree must be in pending prepare state before we can assign root to it + if (tree_state_ != kDeTStatePrepare) { + std::string err_msg = + "Invalid tree state for preparing the tree. Current state: " + std::to_string(static_cast(tree_state_)) + + " Expected state: " + std::to_string(static_cast(kDeTStatePrepare)); + RETURN_STATUS_UNEXPECTED(err_msg); + } + // Start the recursive prepare + RETURN_IF_NOT_OK(this->PrepareNode(root_)); + tree_state_ = kDeTStateReady; + return Status::OK(); +} + +// Recursive function used during prepare phase to visit a node and drive any pre- and post- +// node actions during a tree walk. +Status ExecutionTree::PrepareNode(const std::shared_ptr &dataset_op) { + // execute PreAction + RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction()); + + // Before going down into children, make any prepare flags updates based on this operator. + uint32_t op_prep_flags = dataset_op->PrepareFlags(); + BitSet(&prepare_flags_, op_prep_flags); + + // Now, descend to children + for (const auto &i : dataset_op->child_) { + RETURN_IF_NOT_OK(this->PrepareNode(i)); + } + + // No more children, now we execute any prepare actions before going back up the + // the tree on recursive function + RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction()); + + // Then clear the flags from this op now that we have prepared it. + BitClear(&prepare_flags_, op_prep_flags); + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h new file mode 100644 index 0000000000..b62bf8e85d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h @@ -0,0 +1,257 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_EXECUTION_TREE_H_ +#define DATASET_ENGINE_EXECUTION_TREE_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/util/status.h" +#include "mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h" + +namespace mindspore { +namespace dataset { +// Forward declares +class TaskGroup; +class DatasetOp; +class Monitor; + +class ExecutionTree { + public: + // Prepare flags used during tree prepare phase + enum PrepareFlags { + kDePrepNone = 0, + kDePrepRepeat = 1, // Processing a repeat operation + kDePrepCache = 2 // Processing a cache operation + }; + + // State flags for the lifecycle of the tree + enum TreeState { + kDeTStateInit = 0, // The freshly initialized state after construction + kDeTStateBuilding, // The tree is being built, nodes are being added + kDeTStatePrepare, // The tree has been assigned a root node and is pending prepare + kDeTStateReady, // The tree has been prepared and is ready to be launched + kDeTStateExecuting, // The tree has been launched and is executing + kDeTStateFinished // The tree has been drained, dataset iterator received EOF + }; + + class Iterator { + public: + // Constructor + // @param root The root node to start iterating from + explicit Iterator(const std::shared_ptr &root = nullptr); + + // Destructor + ~Iterator() {} + + Iterator &operator++() { + ++ind_; + return *this; + } // prefix ++ overload + Iterator operator++(int) { + Iterator it = *this; + it.ind_ = ind_; + ind_++; + return it; + } // post-fix ++ overload + Iterator &operator--() { + --ind_; + return *this; + } // prefix -- overload + Iterator operator--(int) { + Iterator it = *this; + it.ind_ = ind_; + ind_--; + return it; + } // post-fix -- overload + DatasetOp &operator*() { return *nodes_[ind_]; } // dereference operator + std::shared_ptr operator->() { return nodes_[ind_]; } + + // getter function + // @return Shared pointer to the current operator + std::shared_ptr get() { return nodes_[ind_]; } + + bool operator==(const Iterator &rhs) { return nodes_[ind_] == rhs.nodes_[rhs.ind_]; } + + bool operator!=(const Iterator &rhs) { return nodes_[ind_] != rhs.nodes_[rhs.ind_]; } + + int32_t NumNodes() { return nodes_.size(); } + + private: + int32_t ind_; // the cur node our Iterator points to + std::vector> nodes_; // store the nodes in post order + void PostOrderTraverse(const std::shared_ptr &); + }; + + // Constructor + ExecutionTree(); + + // Destructor + ~ExecutionTree(); + + // Associates a DatasetOp with this tree. This assigns a valid node id to the operator and + // provides it with a link to the tree. A node cannot form any relationships (parent/child) with + // other nodes unless they are associated with the same tree. + // @param op - The operator to associate + // @return Status - The error code return + Status AssociateNode(const std::shared_ptr &op); + + // Sets the root node of the tree + // @param op - The operator to assign as root + // @return Status - The error code return + Status AssignRoot(const std::shared_ptr &op); + + // Start the execution of the tree + // @return Status - The error code return + Status Launch(); + + /// A print method typically used for debugging + /// \param out - The output stream to write output to + void Print(std::ostream &out, const std::shared_ptr &op = nullptr) const; + + // Returns an iterator positioned at the start + // @return Iterator - The iterator + ExecutionTree::Iterator begin(const std::shared_ptr &root = nullptr) const { + return Iterator(root == nullptr ? root_ : root); + } + + // Returns an iterator positioned at the end + // @return Iterator - The iterator + ExecutionTree::Iterator end() const { return Iterator(nullptr); } + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param exe_tree - reference to the execution tree to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, ExecutionTree &exe_tree) { + exe_tree.Print(out); + return out; + } + + // Given the number of workers, launches the worker entry function for each. Essentially a + // wrapper for the TaskGroup handling that is stored inside the execution tree. + // @param num_workers - The number of workers to launch + // @param func - The function entry point that workers will execute + // @return Status - The error code return + Status LaunchWorkers(int32_t num_workers, std::function func); + + // Getter method + // @return shared_ptr to the root operator + std::shared_ptr root() const { return root_; } + + // Getter method + // @return the prepare flags + uint32_t PrepareFlags() const { return prepare_flags_; } + + // The driver of the prepare phase of the execution tree. + // Prepare phase consists of three sub phases + // + // 1. PrepareTreePreAction() + // Compulsory transformation/action pre optimization. + // For example, CacheOp Insertion + // + // 2. Optimize() + // Optimization transformation/action, optional + // For example, MapOp Fusion + // + // 3. PrepareTreePostAction() + // Compulsory transformation/action post optimization. + // For example, repeatOp inlining + // + // @return Status - The error code return + Status Prepare(); + + // Compulsory transformation/action pre optimization. + // @return Status - The error code return + Status PrepareTreePreAction(); + + // Compulsory transformation/action post optimization. + // @return Status - The error code return + Status PrepareTreePostAction(); + + // Optimization transformation/action, optional. + // @return Status - The error code return + Status Optimize(); + + // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively + // walk the tree to perform modifications to the tree or specific nodes within the tree to get + // it ready for execution. + // @return Status - The error code return + Status PrepareDeprecated(); + + // Recursive function used during prepare phase to visit a node and drive any pre- and post- + // node actions during a tree walk. + // @param op - The dataset op to work on + // @return Status - The error code return + Status PrepareNode(const std::shared_ptr &dataset_op); + + // Return the pointer to the TaskGroup + // @return raw pointer to the TaskGroup + TaskGroup *AllTasks() const { return tg_.get(); } + + // Return if the ExecutionTree is finished (iterator receives EOF). + // @return Bool - true is ExecutionTree is finished + bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; } + + // Set the ExecutionTree to Finished state. + void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; } + + // Getter for profiling manager, no ownership + ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); } + + // Set optional optimization if tree has not been prepared yet + Status SetOptimize(bool value) { + if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { + std::string optimize = (optimize_ == true) ? "true" : "false"; + std::string msg = "Tree has already been prepared with OPTIMIZE set to " + optimize; + RETURN_STATUS_UNEXPECTED(msg); + } else { + optimize_ = value; + return Status::OK(); + } + } + + // Optional optimizations status + bool OptimizationEnabled() const { return optimize_; } + + private: + // A helper functions for doing the recursive printing + // @param dataset_op - The dataset op to print + // @param indent - an indent string for aligning child levels in output + // @param last - an indicator if it's the last child or not + // @param detailed - should it display the detailed node output or the summary line + void PrintNode(std::ostream &out, const std::shared_ptr &dataset_op, std::string indent, bool last, + bool detailed) const; + + std::unique_ptr tg_; // Class for worker management + std::shared_ptr root_; // The root node of the tree + int32_t id_count_; // Counter for generating operator id's + uint32_t prepare_flags_; // Flags used during tree prepare + TreeState tree_state_; // Tracking the current tree state + std::unique_ptr perf_monitor_; // Performance Monitor + std::unique_ptr profiling_manager_; // Profiling manager + bool optimize_; // Flag to enable optional optimizations +}; + +inline bool operator==(const ExecutionTree::Iterator &lhs, const ExecutionTree::Iterator &rhs) { return lhs == rhs; } +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_EXECUTION_TREE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/engine/gnn/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h new file mode 100644 index 0000000000..c62c088bab --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_EDGE_H_ +#define DATASET_ENGINE_GNN_EDGE_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/node.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using EdgeType = int8_t; +using EdgeIdType = int32_t; + +class Edge { + public: + // Constructor + // @param EdgeIdType id - edge id + // @param EdgeType type - edge type + // @param std::shared_ptr src_node - source node + // @param std::shared_ptr dst_node - destination node + Edge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) + : id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {} + + virtual ~Edge() = default; + + // @return NodeIdType - Returned edge id + EdgeIdType id() const { return id_; } + + // @return NodeIdType - Returned edge type + EdgeType type() const { return type_; } + + // Get the feature of a edge + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; + + // Get nodes on the edge + // @param std::pair, std::shared_ptr> *out_node - Source and destination nodes returned + Status GetNode(std::pair, std::shared_ptr> *out_node) { + *out_node = std::make_pair(src_node_, dst_node_); + return Status::OK(); + } + + // Set node to edge + // @param const std::pair, std::shared_ptr> &in_node - + Status SetNode(const std::pair, std::shared_ptr> &in_node) { + src_node_ = in_node.first; + dst_node_ = in_node.second; + return Status::OK(); + } + + // Update feature of edge + // @param std::shared_ptr feature - + // @return Status - The error code return + virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; + + protected: + EdgeIdType id_; + EdgeType type_; + std::shared_ptr src_node_; + std::shared_ptr dst_node_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_EDGE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc new file mode 100644 index 0000000000..dba4a6fa60 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +Feature::Feature(FeatureType type_name, std::shared_ptr value) : type_name_(type_name), value_(value) {} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h new file mode 100644 index 0000000000..0d7eba1009 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_FEATURE_H_ +#define DATASET_ENGINE_GNN_FEATURE_H_ + +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using FeatureType = int16_t; + +class Feature { + public: + // Constructor + // @param FeatureType type_name - feature type + // @param std::shared_ptr value - feature value + Feature(FeatureType type_name, std::shared_ptr value); + + ~Feature() = default; + + // Get feature value + // @return std::shared_ptr *out_value - feature value + const std::shared_ptr Value() const { return value_; } + + // @return NodeIdType - Returned feature type + FeatureType type() const { return type_name_; } + + private: + FeatureType type_name_; + std::shared_ptr value_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_FEATURE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc new file mode 100644 index 0000000000..9083eb4c4b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc @@ -0,0 +1,681 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/gnn/graph.h" + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +Graph::Graph(std::string dataset_file, int32_t num_workers) + : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { + rnd_.seed(GetSeed()); + MS_LOG(INFO) << "num_workers:" << num_workers; +} + +Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr *out) { + auto itr = node_type_map_.find(node_type); + if (itr == node_type_map_.end()) { + std::string err_msg = "Invalid node type:" + std::to_string(node_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); + } + return Status::OK(); +} + +template +Status Graph::CreateTensorByVector(const std::vector> &data, DataType type, + std::shared_ptr *out) { + if (!type.IsCompatible()) { + RETURN_STATUS_UNEXPECTED("Data type not compatible"); + } + if (data.empty()) { + RETURN_STATUS_UNEXPECTED("Input data is empty"); + } + std::shared_ptr tensor; + size_t m = data.size(); + size_t n = data[0].size(); + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &tensor, TensorImpl::kFlexible, TensorShape({static_cast(m), static_cast(n)}), type, nullptr)); + auto ptr = tensor->begin(); + for (const auto &id_m : data) { + CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); + for (const auto &id_n : id_m) { + *ptr = id_n; + ptr++; + } + } + tensor->Squeeze(); + *out = std::move(tensor); + return Status::OK(); +} + +template +Status Graph::ComplementVector(std::vector> *data, size_t max_size, T default_value) { + if (!data || data->empty()) { + RETURN_STATUS_UNEXPECTED("Input data is empty"); + } + for (std::vector &vec : *data) { + size_t size = vec.size(); + if (size > max_size) { + RETURN_STATUS_UNEXPECTED("The max_size parameter is abnormal"); + } else { + for (size_t i = 0; i < (max_size - size); ++i) { + vec.push_back(default_value); + } + } + } + return Status::OK(); +} + +Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { + auto itr = edge_type_map_.find(edge_type); + if (itr == edge_type_map_.end()) { + std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); + } + return Status::OK(); +} + +Status Graph::GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) { + if (edge_list.empty()) { + RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); + } + + std::vector> node_list; + node_list.reserve(edge_list.size()); + for (const auto &edge_id : edge_list) { + auto itr = edge_id_map_.find(edge_id); + if (itr == edge_id_map_.end()) { + std::string err_msg = "Invalid edge id:" + std::to_string(edge_id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + std::pair, std::shared_ptr> nodes; + RETURN_IF_NOT_OK(itr->second->GetNode(&nodes)); + node_list.push_back({nodes.first->id(), nodes.second->id()}); + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(node_list, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + +Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); + + std::vector> neighbors; + size_t max_neighbor_num = 0; + neighbors.resize(node_list.size()); + for (size_t i = 0; i < node_list.size(); ++i) { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node)); + RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); + max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); + } + + RETURN_IF_NOT_OK(ComplementVector(&neighbors, max_neighbor_num, kDefaultNodeId)); + RETURN_IF_NOT_OK(CreateTensorByVector(neighbors, DataType(DataType::DE_INT32), out)); + + return Status::OK(); +} + +Status Graph::CheckSamplesNum(NodeIdType samples_num) { + NodeIdType all_nodes_number = + std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, + [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); + if ((samples_num < 1) || (samples_num > all_nodes_number)) { + std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) + + ", got " + std::to_string(samples_num); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +Status Graph::CheckNeighborType(NodeType neighbor_type) { + if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { + std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +Status Graph::GetSampledNeighbors(const std::vector &node_list, + const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), + "The sizes of neighbor_nums and neighbor_types are inconsistent."); + for (const auto &num : neighbor_nums) { + RETURN_IF_NOT_OK(CheckSamplesNum(num)); + } + for (const auto &type : neighbor_types) { + RETURN_IF_NOT_OK(CheckNeighborType(type)); + } + std::vector> neighbors_vec(node_list.size()); + for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { + std::shared_ptr input_node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node)); + neighbors_vec[node_idx].emplace_back(node_list[node_idx]); + std::vector input_list = {node_list[node_idx]}; + for (size_t i = 0; i < neighbor_nums.size(); ++i) { + std::vector neighbors; + neighbors.reserve(input_list.size() * neighbor_nums[i]); + for (const auto &node_id : input_list) { + if (node_id == kDefaultNodeId) { + for (int32_t j = 0; j < neighbor_nums[i]; ++j) { + neighbors.emplace_back(kDefaultNodeId); + } + } else { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); + std::vector out; + RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); + neighbors.insert(neighbors.end(), out.begin(), out.end()); + } + } + neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end()); + input_list = std::move(neighbors); + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(neighbors_vec, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + +Status Graph::NegativeSample(const std::vector &data, const std::unordered_set &exclude_data, + int32_t samples_num, std::vector *out_samples) { + CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); + std::vector shuffled_id(data.size()); + std::iota(shuffled_id.begin(), shuffled_id.end(), 0); + std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); + for (const auto &index : shuffled_id) { + if (exclude_data.find(data[index]) != exclude_data.end()) { + continue; + } + out_samples->emplace_back(data[index]); + if (out_samples->size() >= samples_num) { + break; + } + } + return Status::OK(); +} + +Status Graph::GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); + RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); + + std::vector> neg_neighbors_vec; + neg_neighbors_vec.resize(node_list.size()); + for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); + std::vector neighbors; + RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); + std::unordered_set exclude_nodes; + std::transform(neighbors.begin(), neighbors.end(), + std::insert_iterator>(exclude_nodes, exclude_nodes.begin()), + [](const NodeIdType node) { return node; }); + const std::vector &all_nodes = node_type_map_[neg_neighbor_type]; + neg_neighbors_vec[node_idx].emplace_back(node->id()); + if (all_nodes.size() > exclude_nodes.size()) { + while (neg_neighbors_vec[node_idx].size() < samples_num + 1) { + RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(), + &neg_neighbors_vec[node_idx])); + } + } else { + MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() + << " neg_neighbor_type:" << neg_neighbor_type; + // If there are no negative neighbors, they are filled with kDefaultNodeId + for (int32_t i = 0; i < samples_num; ++i) { + neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId); + } + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(neg_neighbors_vec, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + +Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out) { + RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); + std::vector> walks; + RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); + RETURN_IF_NOT_OK(CreateTensorByVector({walks}, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + +Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = default_node_feature_map_.find(feature_type); + if (itr == default_node_feature_map_.end()) { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *out_feature = itr->second; + } + return Status::OK(); +} + +Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = default_edge_feature_map_.find(feature_type); + if (itr == default_edge_feature_map_.end()) { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *out_feature = itr->second; + } + return Status::OK(); +} + +Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, + TensorRow *out) { + if (!nodes || nodes->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input nodes is empty"); + } + CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); + TensorRow tensors; + for (const auto &f_type : feature_types) { + std::shared_ptr default_feature; + // If no feature can be obtained, fill in the default value + RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); + + TensorShape shape(default_feature->Value()->shape()); + auto shape_vec = nodes->shape().AsVector(); + dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); + shape = shape.PrependDim(size); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + + dsize_t index = 0; + for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { + std::shared_ptr feature; + if (*node_itr == kDefaultNodeId) { + feature = default_feature; + } else { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); + if (!node->GetFeatures(f_type, &feature).IsOk()) { + feature = default_feature; + } + } + RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); + index++; + } + + TensorShape reshape(nodes->shape()); + for (auto s : default_feature->Value()->shape().AsVector()) { + reshape = reshape.AppendDim(s); + } + RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); + fea_tensor->Squeeze(); + tensors.push_back(fea_tensor); + } + *out = std::move(tensors); + return Status::OK(); +} + +Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, + TensorRow *out) { + if (!edges || edges->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input edges is empty"); + } + CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); + TensorRow tensors; + for (const auto &f_type : feature_types) { + std::shared_ptr default_feature; + // If no feature can be obtained, fill in the default value + RETURN_IF_NOT_OK(GetEdgeDefaultFeature(f_type, &default_feature)); + + TensorShape shape(default_feature->Value()->shape()); + auto shape_vec = edges->shape().AsVector(); + dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); + shape = shape.PrependDim(size); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + + dsize_t index = 0; + for (auto edge_itr = edges->begin(); edge_itr != edges->end(); ++edge_itr) { + std::shared_ptr edge; + RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge)); + std::shared_ptr feature; + if (!edge->GetFeatures(f_type, &feature).IsOk()) { + feature = default_feature; + } + RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); + index++; + } + + TensorShape reshape(edges->shape()); + for (auto s : default_feature->Value()->shape().AsVector()) { + reshape = reshape.AppendDim(s); + } + RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); + fea_tensor->Squeeze(); + tensors.push_back(fea_tensor); + } + *out = std::move(tensors); + return Status::OK(); +} + +Status Graph::Init() { + RETURN_IF_NOT_OK(LoadNodeAndEdge()); + return Status::OK(); +} + +Status Graph::GetMetaInfo(MetaInfo *meta_info) { + meta_info->node_type.resize(node_type_map_.size()); + std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), + [](auto itr) { return itr.first; }); + std::sort(meta_info->node_type.begin(), meta_info->node_type.end()); + + meta_info->edge_type.resize(edge_type_map_.size()); + std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(), + [](auto itr) { return itr.first; }); + std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end()); + + for (const auto &node : node_type_map_) { + meta_info->node_num[node.first] = node.second.size(); + } + + for (const auto &edge : edge_type_map_) { + meta_info->edge_num[edge.first] = edge.second.size(); + } + + for (const auto &node_feature : node_feature_map_) { + for (auto type : node_feature.second) { + meta_info->node_feature_type.emplace_back(type); + } + } + std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); + auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); + meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end()); + + for (const auto &edge_feature : edge_feature_map_) { + for (const auto &type : edge_feature.second) { + meta_info->edge_feature_type.emplace_back(type); + } + } + std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); + auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); + meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end()); + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +Status Graph::GraphInfo(py::dict *out) { + MetaInfo meta_info; + RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); + (*out)["node_type"] = py::cast(meta_info.node_type); + (*out)["edge_type"] = py::cast(meta_info.edge_type); + (*out)["node_num"] = py::cast(meta_info.node_num); + (*out)["edge_num"] = py::cast(meta_info.edge_num); + (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); + (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); + return Status::OK(); +} +#endif + +Status Graph::LoadNodeAndEdge() { + GraphLoader gl(dataset_file_, num_workers_); + // ask graph_loader to load everything into memory + RETURN_IF_NOT_OK(gl.InitAndLoad()); + // get all maps + RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, + &node_feature_map_, &edge_feature_map_, &default_node_feature_map_, + &default_edge_feature_map_)); + return Status::OK(); +} + +Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { + auto itr = node_id_map_.find(id); + if (itr == node_id_map_.end()) { + std::string err_msg = "Invalid node id:" + std::to_string(id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *node = itr->second; + } + return Status::OK(); +} + +Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge) { + auto itr = edge_id_map_.find(id); + if (itr == edge_id_map_.end()) { + std::string err_msg = "Invalid edge id:" + std::to_string(id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *edge = itr->second; + } + return Status::OK(); +} + +Graph::RandomWalkBase::RandomWalkBase(Graph *graph) + : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} + +Status Graph::RandomWalkBase::Build(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, const NodeIdType default_node, + int32_t num_walks, int32_t num_workers) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + node_list_ = node_list; + if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { + std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) + + ". The size of input path is " + std::to_string(meta_path.size()); + RETURN_STATUS_UNEXPECTED(err_msg); + } + for (const auto &type : meta_path) { + RETURN_IF_NOT_OK(graph_->CheckNeighborType(type)); + } + meta_path_ = meta_path; + if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { + std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + + std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) + + ", step_away_param: " + std::to_string(step_away_param); + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (default_node < -1) { + std::string err_msg = "Failed, default_node required to be greater or equal to -1."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (num_walks <= 0) { + std::string err_msg = "Failed, num_walks parameter required to be greater than 0"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (num_workers <= 0) { + std::string err_msg = "Failed, num_workers parameter required to be greater than 0"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + step_home_param_ = step_home_param; + step_away_param_ = step_away_param; + default_node_ = default_node; + num_walks_ = num_walks; + num_workers_ = num_workers; + return Status::OK(); +} + +Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path) { + // Simulate a random walk starting from start node. + auto walk = std::vector(1, start_node); // walk is an vector + // walk simulate + while (walk.size() - 1 < meta_path_.size()) { + // current nodE + auto cur_node_id = walk.back(); + std::shared_ptr cur_node; + RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node)); + + // current neighbors + std::vector cur_neighbors; + RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true)); + std::sort(cur_neighbors.begin(), cur_neighbors.end()); + + // break if no neighbors + if (cur_neighbors.empty()) { + break; + } + + // walk by the fist node, then by the previous 2 nodes + std::shared_ptr stochastic_index; + if (walk.size() == 1) { + RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index)); + } else { + NodeIdType prev_node_id = walk[walk.size() - 2]; + RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index)); + } + NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)]; + walk.push_back(next_node_id); + } + + while (walk.size() - 1 < meta_path_.size()) { + walk.push_back(default_node_); + } + + *walk_path = std::move(walk); + return Status::OK(); +} + +Status Graph::RandomWalkBase::SimulateWalk(std::vector> *walks) { + for (int32_t i = 0; i < num_walks_; i++) { + for (const auto &node : node_list_) { + std::vector walk; + RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); + walks->push_back(walk); + } + } + return Status::OK(); +} + +Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability) { + // Generate alias nodes + std::shared_ptr node; + graph_->GetNodeByNodeId(node_id, &node); + std::vector neighbors; + RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true)); + std::sort(neighbors.begin(), neighbors.end()); + auto non_normalized_probability = std::vector(neighbors.size(), 1.0); + *node_probability = + std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); + return Status::OK(); +} + +Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, + std::shared_ptr *edge_probability) { + // Get the alias edge setup lists for a given edge. + std::shared_ptr src_node; + graph_->GetNodeByNodeId(src, &src_node); + std::vector src_neighbors; + RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true)); + + std::shared_ptr dst_node; + graph_->GetNodeByNodeId(dst, &dst_node); + std::vector dst_neighbors; + RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true)); + + std::sort(dst_neighbors.begin(), dst_neighbors.end()); + std::vector non_normalized_probability; + for (const auto &dst_nbr : dst_neighbors) { + if (dst_nbr == src) { + non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] + continue; + } + auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr); + if (it != src_neighbors.end()) { + // stay close, this node connect both src and dst + non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight'] + } else { + // step far away + non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] + } + } + + *edge_probability = + std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); + return Status::OK(); +} + +StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector &probability) { + uint32_t K = probability.size(); + std::vector switch_to_large_index(K, 0); + std::vector weight(K, .0); + std::vector smaller; + std::vector larger; + auto random_device = GetRandomDevice(); + std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon); + float accumulate_threshold = 0.0; + for (uint32_t i = 0; i < K; i++) { + float threshold_one = distribution(random_device); + accumulate_threshold += threshold_one; + weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold; + weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i); + } + + while ((!smaller.empty()) && (!larger.empty())) { + uint32_t small = smaller.back(); + smaller.pop_back(); + uint32_t large = larger.back(); + larger.pop_back(); + switch_to_large_index[small] = large; + weight[large] = weight[large] + weight[small] - 1.0; + weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large); + } + return StochasticIndex(switch_to_large_index, weight); +} + +uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { + auto switch_to_large_index = stochastic_index.first; + auto weight = stochastic_index.second; + const uint32_t size_of_index = switch_to_large_index.size(); + + auto random_device = GetRandomDevice(); + std::uniform_real_distribution<> distribution(0.0, 1.0); + + // Generate random integer between [0, K) + uint32_t random_idx = std::floor(distribution(random_device) * size_of_index); + + if (distribution(random_device) < weight[random_idx]) { + return random_idx; + } + return switch_to_large_index[random_idx]; +} + +template +std::vector Graph::RandomWalkBase::Normalize(const std::vector &non_normalized_probability) { + float sum_probability = + 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); + if (sum_probability < kGnnEpsilon) { + sum_probability = 1.0; + } + std::vector normalized_probability; + std::transform(non_normalized_probability.begin(), non_normalized_probability.end(), + std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; }); + return normalized_probability; +} +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h new file mode 100644 index 0000000000..76930d91f2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h @@ -0,0 +1,267 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_GRAPH_H_ +#define DATASET_ENGINE_GNN_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/engine/gnn/graph_loader.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +const float kGnnEpsilon = 0.0001; +const uint32_t kMaxNumWalks = 80; +using StochasticIndex = std::pair, std::vector>; + +struct MetaInfo { + std::vector node_type; + std::vector edge_type; + std::map node_num; + std::map edge_num; + std::vector node_feature_type; + std::vector edge_feature_type; +}; + +class Graph { + public: + // Constructor + // @param std::string dataset_file - + // @param int32_t num_workers - number of parallel threads + Graph(std::string dataset_file, int32_t num_workers); + + ~Graph() = default; + + // Get all nodes from the graph. + // @param NodeType node_type - type of node + // @param std::shared_ptr *out - Returned nodes id + // @return Status - The error code return + Status GetAllNodes(NodeType node_type, std::shared_ptr *out); + + // Get all edges from the graph. + // @param NodeType edge_type - type of edge + // @param std::shared_ptr *out - Returned edge ids + // @return Status - The error code return + Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out); + + // Get the node id from the edge. + // @param std::vector edge_list - List of edges + // @param std::shared_ptr *out - Returned node ids + // @return Status - The error code return + Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out); + + // All neighbors of the acquisition node. + // @param std::vector node_list - List of nodes + // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported + // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is + // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors + // is not enough, fill in tensor as -1. + // @return Status - The error code return + Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out); + + // Get sampled neighbors. + // @param std::vector node_list - List of nodes + // @param std::vector neighbor_nums - Number of neighbors sampled per hop + // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param std::shared_ptr *out - Returned neighbor's id. + // @return Status - The error code return + Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out); + + // Get negative sampled neighbors. + // @param std::vector node_list - List of nodes + // @param NodeIdType samples_num - Number of neighbors sampled + // @param NodeType neg_neighbor_type - The type of negative neighbor. + // @param std::shared_ptr *out - Returned negative neighbor's id. + // @return Status - The error code return + Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out); + + // Node2vec random walk. + // @param std::vector node_list - List of nodes + // @param std::vector meta_path - node type of each step + // @param float step_home_param - return hyper parameter in node2vec algorithm + // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param NodeIdType default_node - default node id + // @param std::shared_ptr *out - Returned nodes id in walk path + // @return Status - The error code return + Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out); + + // Get the feature of a node + // @param std::shared_ptr nodes - List of nodes + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. + // @param TensorRow *out - Returned features + // @return Status - The error code return + Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, + TensorRow *out); + + // Get the feature of a edge + // @param std::shared_ptr edget - List of edges + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. + // @param Tensor *out - Returned features + // @return Status - The error code return + Status GetEdgeFeature(const std::shared_ptr &edget, const std::vector &feature_types, + TensorRow *out); + + // Get meta information of graph + // @param MetaInfo *meta_info - Returned meta information + // @return Status - The error code return + Status GetMetaInfo(MetaInfo *meta_info); + +#ifdef ENABLE_PYTHON + // Return meta information to python layer + Status GraphInfo(py::dict *out); +#endif + + Status Init(); + + private: + class RandomWalkBase { + public: + explicit RandomWalkBase(Graph *graph); + + Status Build(const std::vector &node_list, const std::vector &meta_path, + float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, + int32_t num_walks = 1, int32_t num_workers = 1); + + ~RandomWalkBase() = default; + + Status SimulateWalk(std::vector> *walks); + + private: + Status Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path); + + Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability); + + Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, + std::shared_ptr *edge_probability); + + static StochasticIndex GenerateProbability(const std::vector &probability); + + static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index); + + template + std::vector Normalize(const std::vector &non_normalized_probability); + + Graph *graph_; + std::vector node_list_; + std::vector meta_path_; + float step_home_param_; // Return hyper parameter. Default is 1.0 + float step_away_param_; // Inout hyper parameter. Default is 1.0 + NodeIdType default_node_; + + int32_t num_walks_; // Number of walks per source. Default is 1 + int32_t num_workers_; // The number of worker threads. Default is 1 + }; + + // Load graph data from mindrecord file + // @return Status - The error code return + Status LoadNodeAndEdge(); + + // Create Tensor By Vector + // @param std::vector> &data - + // @param DataType type - + // @param std::shared_ptr *out - + // @return Status - The error code return + template + Status CreateTensorByVector(const std::vector> &data, DataType type, std::shared_ptr *out); + + // Complete vector + // @param std::vector> *data - To be completed vector + // @param size_t max_size - The size of the completed vector + // @param T default_value - Filled default + // @return Status - The error code return + template + Status ComplementVector(std::vector> *data, size_t max_size, T default_value); + + // Get the default feature of a node + // @param FeatureType feature_type - + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + + // Get the default feature of a edge + // @param FeatureType feature_type - + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + + // Find node object using node id + // @param NodeIdType id - + // @param std::shared_ptr *node - Returned node object + // @return Status - The error code return + Status GetNodeByNodeId(NodeIdType id, std::shared_ptr *node); + + // Find edge object using edge id + // @param EdgeIdType id - + // @param std::shared_ptr *edge - Returned edge object + // @return Status - The error code return + Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge); + + // Negative sampling + // @param std::vector &input_data - The data set to be sampled + // @param std::unordered_set &exclude_data - Data to be excluded + // @param int32_t samples_num - + // @param std::vector *out_samples - Sampling results returned + // @return Status - The error code return + Status NegativeSample(const std::vector &input_data, const std::unordered_set &exclude_data, + int32_t samples_num, std::vector *out_samples); + + Status CheckSamplesNum(NodeIdType samples_num); + + Status CheckNeighborType(NodeType neighbor_type); + + std::string dataset_file_; + int32_t num_workers_; // The number of worker threads + std::mt19937 rnd_; + RandomWalkBase random_walk_; + + std::unordered_map> node_type_map_; + std::unordered_map> node_id_map_; + + std::unordered_map> edge_type_map_; + std::unordered_map> edge_id_map_; + + std::unordered_map> node_feature_map_; + std::unordered_map> edge_feature_map_; + + std::unordered_map> default_node_feature_map_; + std::unordered_map> default_edge_feature_map_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_GRAPH_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc new file mode 100644 index 0000000000..9d2c6211f4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -0,0 +1,260 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "minddata/dataset/engine/gnn/graph_loader.h" +#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h" +#include "minddata/dataset/engine/gnn/local_edge.h" +#include "minddata/dataset/engine/gnn/local_node.h" +#include "minddata/dataset/util/task_manager.h" + +using ShardTuple = std::vector, mindspore::mindrecord::json>>; + +namespace mindspore { +namespace dataset { +namespace gnn { + +using mindrecord::MSRStatus; + +GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) + : mr_path_(mr_filepath), + num_workers_(num_workers), + row_id_(0), + shard_reader_(nullptr), + keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} + +Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, + EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, + EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, + DefaultEdgeFeatureMap *default_edge_feature_map) { + for (std::deque> &dq : n_deques_) { + while (dq.empty() == false) { + std::shared_ptr node_ptr = dq.front(); + n_id_map->insert({node_ptr->id(), node_ptr}); + (*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); + dq.pop_front(); + } + } + + for (std::deque> &dq : e_deques_) { + while (dq.empty() == false) { + std::shared_ptr edge_ptr = dq.front(); + std::pair, std::shared_ptr> p; + RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); + auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); + CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); + CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); + RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); + RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); + e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ + (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); + dq.pop_front(); + } + } + + for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); + for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); + + MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); + return Status::OK(); +} + +Status GraphLoader::InitAndLoad() { + CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n"); + CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n"); + n_deques_.resize(num_workers_); + e_deques_.resize(num_workers_); + n_feature_maps_.resize(num_workers_); + e_feature_maps_.resize(num_workers_); + default_node_feature_maps_.resize(num_workers_); + default_edge_feature_maps_.resize(num_workers_); + TaskGroup vg; + + shard_reader_ = std::make_unique(); + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, + "Fail to open" + mr_path_); + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); + + mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; + for (const std::string &key : keys_) { + if (schema.find(key) == schema.end()) { + RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); + } + } + + // launching worker threads + for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { + RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); + } + // wait for threads to finish and check its return code + vg.join_all(Task::WaitFlag::kBlocking); + RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny()); + return Status::OK(); +} + +Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrecord::json &col_jsn, + std::shared_ptr *node, NodeFeatureMap *feature_map, + DefaultNodeFeatureMap *default_feature) { + NodeIdType node_id = col_jsn["first_id"]; + NodeType node_type = static_cast(col_jsn["type"]); + (*node) = std::make_shared(node_id, node_type); + std::vector indices; + RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); + + for (int32_t ind : indices) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); + RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared(ind, tensor))); + (*feature_map)[node_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } + } + return Status::OK(); +} + +Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrecord::json &col_jsn, + std::shared_ptr *edge, EdgeFeatureMap *feature_map, + DefaultEdgeFeatureMap *default_feature) { + EdgeIdType edge_id = col_jsn["first_id"]; + EdgeType edge_type = static_cast(col_jsn["type"]); + NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; + std::shared_ptr src = std::make_shared(src_id, -1); + std::shared_ptr dst = std::make_shared(dst_id, -1); + (*edge) = std::make_shared(edge_id, edge_type, src, dst); + std::vector indices; + RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); + for (int32_t ind : indices) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); + RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared(ind, tensor))); + (*feature_map)[edge_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } + } + return Status::OK(); +} + +Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector &col_blob, + const mindrecord::json &col_jsn, std::shared_ptr *tensor) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( + key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible, + std::move(TensorShape({static_cast(n_bytes / col_type_size)})), + std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data)); + return Status::OK(); +} + +Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector &col_blob, + const mindrecord::json &col_jsn, std::vector *indices) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( + key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); + + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + + for (int i = 0; i < n_bytes; i += col_type_size) { + int32_t feature_ind = -1; + if (col_type == mindrecord::ColumnInt32) { + feature_ind = *(reinterpret_cast(data + i)); + } else if (col_type == mindrecord::ColumnInt64) { + feature_ind = *(reinterpret_cast(data + i)); + } else { + RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); + } + if (feature_ind >= 0) indices->push_back(feature_ind); + } + return Status::OK(); +} + +Status GraphLoader::WorkerEntry(int32_t worker_id) { + // Handshake + TaskManager::FindMe()->Post(); + auto ret = shard_reader_->GetNextById(row_id_++, worker_id); + ShardTuple rows = ret.second; + while (rows.empty() == false) { + RETURN_IF_INTERRUPTED(); + for (const auto &tupled_row : rows) { + std::vector col_blob = std::get<0>(tupled_row); + mindrecord::json col_jsn = std::get<1>(tupled_row); + std::string attr = col_jsn["attribute"]; + if (attr == "n") { + std::shared_ptr node_ptr; + RETURN_IF_NOT_OK(LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), + &default_node_feature_maps_[worker_id])); + n_deques_[worker_id].emplace_back(node_ptr); + } else if (attr == "e") { + std::shared_ptr edge_ptr; + RETURN_IF_NOT_OK(LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), + &default_edge_feature_maps_[worker_id])); + e_deques_[worker_id].emplace_back(edge_ptr); + } else { + MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; + } + } + auto rc = shard_reader_->GetNextById(row_id_++, worker_id); + rows = rc.second; + } + return Status::OK(); +} + +void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, + DefaultNodeFeatureMap *default_node_feature_map, + DefaultEdgeFeatureMap *default_edge_feature_map) { + for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { + for (auto &m : n_feature_maps_[wkr_id]) { + for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); + } + for (auto &m : e_feature_maps_[wkr_id]) { + for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); + } + for (auto &m : default_node_feature_maps_[wkr_id]) { + (*default_node_feature_map)[m.first] = m.second; + } + for (auto &m : default_edge_feature_maps_[wkr_id]) { + (*default_edge_feature_map)[m.first] = m.second; + } + } + n_feature_maps_.clear(); + e_feature_maps_.clear(); +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h new file mode 100644 index 0000000000..f7f9245b8a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_GRAPH_LOADER_H_ +#define DATASET_ENGINE_GNN_GRAPH_LOADER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/graph.h" +#include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_reader.h" +namespace mindspore { +namespace dataset { +namespace gnn { + +using mindrecord::ShardReader; +using NodeIdMap = std::unordered_map>; +using EdgeIdMap = std::unordered_map>; +using NodeTypeMap = std::unordered_map>; +using EdgeTypeMap = std::unordered_map>; +using NodeFeatureMap = std::unordered_map>; +using EdgeFeatureMap = std::unordered_map>; +using DefaultNodeFeatureMap = std::unordered_map>; +using DefaultEdgeFeatureMap = std::unordered_map>; + +// this class interfaces with the underlying storage format (mindrecord) +// it returns raw nodes and edges via GetNodesAndEdges +// it is then the responsibility of graph to construct itself based on the nodes and edges +// if needed, this class could become a base where each derived class handles a specific storage format +class GraphLoader { + public: + explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); + + ~GraphLoader() = default; + // Init mindrecord and load everything into memory multi-threaded + // @return Status - the status code + Status InitAndLoad(); + + // this function will query mindrecord and construct all nodes and edges + // nodes and edges are added to map without any connection. That's because there nodes and edges are read in + // random order. src_node and dst_node in Edge are node_id only with -1 as type. + // features attached to each node and edge are expected to be filled correctly + Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, + DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); + + private: + // + // worker thread that reads mindrecord file + // @param int32_t worker_id - id of each worker + // @return Status - the status code + Status WorkerEntry(int32_t worker_id); + + // Load a node based on 1 row of mindrecord, returns a shared_ptr + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *node - return value + // @param NodeFeatureMap *feature_map - + // @param DefaultNodeFeatureMap *default_feature - + // @return Status - the status code + Status LoadNode(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *node, + NodeFeatureMap *feature_map, DefaultNodeFeatureMap *default_feature); + + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *edge - return value, the edge ptr, edge is not yet connected + // @param FeatureMap *feature_map + // @param DefaultEdgeFeatureMap *default_feature - + // @return Status - the status code + Status LoadEdge(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *edge, + EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); + + // @param std::string key - column name + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::vector *ind - return value, list of feature index in int32_t + // @return Status - the status code + Status LoadFeatureIndex(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, + std::vector *ind); + + // @param std::string &key - column name + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *tensor - return value feature tensor + // @return Status - the status code + Status LoadFeatureTensor(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, + std::shared_ptr *tensor); + + // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 + void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); + + const int32_t num_workers_; + std::atomic_int row_id_; + std::string mr_path_; + std::unique_ptr shard_reader_; + std::vector>> n_deques_; + std::vector>> e_deques_; + std::vector n_feature_maps_; + std::vector e_feature_maps_; + std::vector default_node_feature_maps_; + std::vector default_edge_feature_maps_; + const std::vector keys_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_GRAPH_LOADER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc new file mode 100644 index 0000000000..642c73eed3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/gnn/local_edge.h" + +#include + +namespace mindspore { +namespace dataset { +namespace gnn { + +LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) + : Edge(id, type, src_node, dst_node) {} + +Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = features_.find(feature_type); + if (itr != features_.end()) { + *out_feature = itr->second; + return Status::OK(); + } else { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } +} + +Status LocalEdge::UpdateFeature(const std::shared_ptr &feature) { + auto itr = features_.find(feature->type()); + if (itr != features_.end()) { + RETURN_STATUS_UNEXPECTED("Feature already exists"); + } else { + features_[feature->type()] = feature; + return Status::OK(); + } +} +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h new file mode 100644 index 0000000000..d112972f8f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_LOCAL_EDGE_H_ +#define DATASET_ENGINE_GNN_LOCAL_EDGE_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/node.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class LocalEdge : public Edge { + public: + // Constructor + // @param EdgeIdType id - edge id + // @param EdgeType type - edge type + // @param std::shared_ptr src_node - source node + // @param std::shared_ptr dst_node - destination node + LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node); + + ~LocalEdge() = default; + + // Get the feature of a edge + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; + + // Update feature of edge + // @param std::shared_ptr feature - + // @return Status - The error code return + Status UpdateFeature(const std::shared_ptr &feature) override; + + private: + std::unordered_map> features_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_LOCAL_EDGE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc new file mode 100644 index 0000000000..8eaf9bb716 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/gnn/local_node.h" + +#include +#include +#include + +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } + +Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = features_.find(feature_type); + if (itr != features_.end()) { + *out_feature = itr->second; + return Status::OK(); + } else { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } +} + +Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, bool exclude_itself) { + std::vector neighbors; + auto itr = neighbor_nodes_.find(neighbor_type); + if (itr != neighbor_nodes_.end()) { + if (exclude_itself) { + neighbors.resize(itr->second.size()); + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), + [](const std::shared_ptr node) { return node->id(); }); + } else { + neighbors.resize(itr->second.size() + 1); + neighbors[0] = id_; + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, + [](const std::shared_ptr node) { return node->id(); }); + } + } else { + MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; + if (!exclude_itself) { + neighbors.emplace_back(id_); + } + } + *out_neighbors = std::move(neighbors); + return Status::OK(); +} + +Status LocalNode::GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out) { + std::vector shuffled_id(neighbors.size()); + std::iota(shuffled_id.begin(), shuffled_id.end(), 0); + std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); + int32_t num = std::min(samples_num, static_cast(neighbors.size())); + for (int32_t i = 0; i < num; ++i) { + out->emplace_back(neighbors[shuffled_id[i]]->id()); + } + return Status::OK(); +} + +Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) { + std::vector neighbors; + neighbors.reserve(samples_num); + auto itr = neighbor_nodes_.find(neighbor_type); + if (itr != neighbor_nodes_.end()) { + while (neighbors.size() < samples_num) { + RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); + } + } else { + MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; + // If there are no neighbors, they are filled with kDefaultNodeId + for (int32_t i = 0; i < samples_num; ++i) { + neighbors.emplace_back(kDefaultNodeId); + } + } + *out_neighbors = std::move(neighbors); + return Status::OK(); +} + +Status LocalNode::AddNeighbor(const std::shared_ptr &node) { + auto itr = neighbor_nodes_.find(node->type()); + if (itr != neighbor_nodes_.end()) { + itr->second.push_back(node); + } else { + neighbor_nodes_[node->type()] = {node}; + } + return Status::OK(); +} + +Status LocalNode::UpdateFeature(const std::shared_ptr &feature) { + auto itr = features_.find(feature->type()); + if (itr != features_.end()) { + RETURN_STATUS_UNEXPECTED("Feature already exists"); + } else { + features_[feature->type()] = feature; + return Status::OK(); + } +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h new file mode 100644 index 0000000000..9c122931e7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_LOCAL_NODE_H_ +#define DATASET_ENGINE_GNN_LOCAL_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class LocalNode : public Node { + public: + // Constructor + // @param NodeIdType id - node id + // @param NodeType type - node type + LocalNode(NodeIdType id, NodeType type); + + ~LocalNode() = default; + + // Get the feature of a node + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; + + // Get the all neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, + bool exclude_itself = false) override; + + // Get the sampled neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param int32_t samples_num - Number of neighbors to be acquired + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) override; + + // Add neighbor of node + // @param std::shared_ptr node - + // @return Status - The error code return + Status AddNeighbor(const std::shared_ptr &node) override; + + // Update feature of node + // @param std::shared_ptr feature - + // @return Status - The error code return + Status UpdateFeature(const std::shared_ptr &feature) override; + + private: + Status GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out); + + std::mt19937 rnd_; + std::unordered_map> features_; + std::unordered_map>> neighbor_nodes_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_LOCAL_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h new file mode 100644 index 0000000000..a7c803fee2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_NODE_H_ +#define DATASET_ENGINE_GNN_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using NodeType = int8_t; +using NodeIdType = int32_t; + +constexpr NodeIdType kDefaultNodeId = -1; + +class Node { + public: + // Constructor + // @param NodeIdType id - node id + // @param NodeType type - node type + Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} + + virtual ~Node() = default; + + // @return NodeIdType - Returned node id + NodeIdType id() const { return id_; } + + // @return NodeIdType - Returned node type + NodeType type() const { return type_; } + + // Get the feature of a node + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; + + // Get the all neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, + bool exclude_itself = false) = 0; + + // Get the sampled neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param int32_t samples_num - Number of neighbors to be acquired + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) = 0; + + // Add neighbor of node + // @param std::shared_ptr node - + // @return Status - The error code return + virtual Status AddNeighbor(const std::shared_ptr &node) = 0; + + // Update feature of node + // @param std::shared_ptr feature - + // @return Status - The error code return + virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; + + protected: + NodeIdType id_; + NodeType type_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h b/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h new file mode 100644 index 0000000000..cee0b7abf3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h @@ -0,0 +1,88 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_JAGGED_CONNECTOR_H_ +#define DATASET_ENGINE_JAGGED_CONNECTOR_H_ + +#include +#include +#include +#include +#include "minddata/dataset/engine/connector.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/constants.h" + +namespace mindspore { +namespace dataset { +class JaggedConnector : public Connector> { + public: + JaggedConnector(int32_t num_producers, int32_t num_consumers, int32_t queue_capacity) + : Connector>(num_producers, num_consumers, queue_capacity) { + for (int i = 0; i < num_producers; i++) { + is_queue_finished_.push_back(false); + } + } + + ~JaggedConnector() = default; + + Status Add(int32_t worker_d, std::unique_ptr &&element) noexcept { + return Connector>::Push(worker_d, std::move(element)); + } + + Status Pop(int32_t worker_id, std::unique_ptr *result) noexcept override { + { + MS_ASSERT(worker_id < num_consumers_); + std::unique_lock lock(m_); + RETURN_IF_NOT_OK(cv_.Wait(&lock, [this, worker_id]() { return expect_consumer_ == worker_id; })); + if (is_queue_finished_[pop_from_]) { + std::string errMsg = "ERROR: popping from a finished queue in JaggedConnector"; + RETURN_STATUS_UNEXPECTED(errMsg); + } + + RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); + if ((*result)->eoe()) { + is_queue_finished_[pop_from_] = true; + } + + for (int offset = 1; offset <= num_producers_; offset++) { + int32_t nextQueueIndex = (pop_from_ + offset) % num_producers_; + if (is_queue_finished_[nextQueueIndex] == false) { + pop_from_ = nextQueueIndex; + break; + } + } + + expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; + } + + cv_.NotifyAll(); + return Status::OK(); + } + + void DoReset() { + for (int i = 0; i < is_queue_finished_.size(); i++) { + is_queue_finished_[i] = false; + } + + Connector>::Reset(); + } + + private: + std::vector is_queue_finished_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_JAGGED_CONNECTOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt new file mode 100644 index 0000000000..0ab1fb7925 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt @@ -0,0 +1,12 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_library(engine-opt OBJECT + pass.cc + post/repeat_pass.cc + pre/cache_pass.cc + pre/cache_transform_pass.cc + pre/removal_nodes.cc + pre/removal_pass.cc + optional/tensor_op_fusion_pass.cc + util/printer_pass.cc + ) diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc new file mode 100644 index 0000000000..d8ce2dd863 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" + +namespace mindspore { +namespace dataset { + +Status TensorOpFusionPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Most primitive pattern: DecodeOp immediately followed by RandomCropAndResizeOp + // Abstract into a more general member function that can find any pattern, expressed + // by regular expressions, for instance. + // Add a list of optimisation policies. For now, just this lambda + auto FindPattern = [](auto &tfuncs) { + auto it = + std::find_if(tfuncs.begin(), tfuncs.end(), [](const auto &tf) -> bool { return tf->Name() == kDecodeOp; }); + auto next = it + 1; + if (it != tfuncs.end() && next != tfuncs.end() && (*next)->Name() == kRandomCropAndResizeOp) { + return it; + } else { + return tfuncs.end(); + } + }; + + auto &tfuncs = node->TFuncs(); + auto it = FindPattern(tfuncs); + if (it != tfuncs.end()) { + auto next = it + 1; + auto op = static_cast(next->get()); + *it = std::static_pointer_cast(std::make_shared(*op)); + tfuncs.erase(next); + } + if (modified != nullptr) { + *modified = true; + } else { + RETURN_STATUS_UNEXPECTED("modified is nullptr"); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h new file mode 100644 index 0000000000..a109af396c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TENSOR_OP_FUSION_PASS_H_ +#define DATASET_TENSOR_OP_FUSION_PASS_H_ + +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +/// \class TensorOpFusionPass tensor_op_fusion_pass.h +/// \brief And optional optimization pass identifying and fusing +/// tensor ops within MapOp +class TensorOpFusionPass : public NodePass { + /// \brief Identifies and fuses tensor ops within MapOp + /// \param[in] node The node being visited + /// \param[inout] *modified indicates whether the node has been visited + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TENSOR_OP_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc new file mode 100644 index 0000000000..4a8bbaf38f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc @@ -0,0 +1,248 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/datasetops/batch_op.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" +#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/project_op.h" +#include "minddata/dataset/engine/datasetops/rename_op.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/datasetops/skip_op.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/engine/datasetops/filter_op.h" +#include "minddata/dataset/engine/datasetops/source/generator_op.h" +#endif +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/take_op.h" +#include "minddata/dataset/engine/datasetops/zip_op.h" + +namespace mindspore { +namespace dataset { + +// Driver method for TreePass +Status TreePass::Run(ExecutionTree *tree, bool *modified) { + if (tree == nullptr || modified == nullptr) { + return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); + } + return this->RunOnTree(tree, modified); +} + +// Driver method for NodePass +Status NodePass::Run(ExecutionTree *tree, bool *modified) { + if (tree == nullptr || modified == nullptr) { + return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); + } + std::shared_ptr root = tree->root(); + if (traversalOrder_ == Order::DFS) { + // DFS + return DFSNodeVisit(root, modified); + } else if (traversalOrder_ == Order::BFS) { + // BFS + return BFSNodeVisit(root, modified); + } + return Status::OK(); +} + +// Helper function to perform DFS visit +Status NodePass::DFSNodeVisit(std::shared_ptr node, bool *modified) { + RETURN_IF_NOT_OK(node->PreAccept(this, modified)); + for (const auto &c : node->Children()) { + RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); + } + return node->Accept(this, modified); +} + +// Helper function to perform BFS visit +Status NodePass::BFSNodeVisit(std::shared_ptr root, bool *modified) { + // Initialize bfs queue with root + std::queue> bfsQueue; + bfsQueue.push(root); + + // BFS loop + while (!bfsQueue.empty()) { + // Pop the front of the bfs queue + auto curNode = bfsQueue.front(); + bfsQueue.pop(); + + // Run node pass + RETURN_IF_NOT_OK(curNode->Accept(this, modified)); + + // Push children into bfs queue + for (const auto &c : curNode->Children()) { + bfsQueue.push(c); + } + } + return Status::OK(); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +#ifdef ENABLE_PYTHON +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} +#endif + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h new file mode 100644 index 0000000000..845ab34d66 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -0,0 +1,213 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_H_ + +#include +#include + +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class BatchOp; + +class MapOp; + +class ProjectOp; + +class RenameOp; + +class SkipOp; + +class ShuffleOp; + +class MindRecordOp; + +class TFReaderOp; + +#ifdef ENABLE_PYTHON +class FilterOp; + +class GeneratorOp; +#endif + +class RandomDataOp; + +class RepeatOp; + +class TakeOp; + +class ZipOp; + +class DeviceQueueOp; + +class ImageFolderOp; + +class CacheOp; + +class MnistOp; + +class ManifestOp; + +class CifarOp; + +class VOCOp; + +class CocoOp; + +class CelebAOp; + +class CacheMergeOp; + +class CacheLookupOp; + +// The base class Pass is the basic unit of tree transformation. +// The actual implementation of the passes will be derived from here. +class Pass : public std::enable_shared_from_this { + public: + // Run the transformation pass against the execution tree. + // @param tree - Pointer to the execution tree to be transformed. + // @param modified - Pointer to the modified flag, + virtual Status Run(ExecutionTree *tree, bool *modified) = 0; +}; + +// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. +class TreePass : public Pass { + public: + /// \brief Run the transformation pass against the execution tree. + /// \param[inout] tree Pointer to the execution tree to be transformed. + /// \param[inout] modified Indicate if the tree was modified + Status Run(ExecutionTree *tree, bool *modified) final; + + /// \brief Derived classes may implement the runOnTree function to implement tree transformation. + /// "modified" flag needs to be set to true if tree is modified during the pass execution. + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } +}; + +// NodePass is a basic Pass class which performs transformation on Node visiting. +// NodePass implements Visitor design pattern. +class NodePass : public Pass { + public: + // Tree traversal order + enum Order { DFS, BFS }; + + // Constructor + // Default DFS traversal + explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } + + ~NodePass() = default; + + /// \brief Run the transformation pass against the execution tree + /// \param[inout] tree Pointer to the execution tree to be transformed + /// \param[inout] modified Indicator if the tree was changed + Status Run(ExecutionTree *tree, bool *modified) final; + + /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down + /// a tree traversal. "modified" flag needs to be set to true if tree is modified during the pass execution + /// \param[in] node The node being visited + /// \param[out] modified Indicator if the node was changed at all + /// \return Status The error code return + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } + + /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation + /// "modified" flag needs to be set to true if tree is modified during the pass execution + /// \param[in] node The node being visited + /// \param[out] modified Indicator if the node was changed at all. + /// \return Status The error code return + virtual Status RunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } + + // Visit methods to be overridden. + // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode + // of its own type and override "Accept" from DatasetOp. + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + +#ifdef ENABLE_PYTHON + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); +#endif + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + private: + // Helper function to perform DFS visit + Status DFSNodeVisit(std::shared_ptr node, bool *modified); + + // Helper function to perform BFS visit + Status BFSNodeVisit(std::shared_ptr root, bool *modified); + + // Tree traversal order of the NodePass + Order traversalOrder_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc new file mode 100644 index 0000000000..59a3f71c53 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/engine/opt/post/repeat_pass.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" + +namespace mindspore { +namespace dataset { + +RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {} + +// Identifies the subtree below this node as being in a repeated path of the tree. +Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // If we are already repeated, then this is a nested repeat. + if (is_repeated_) { + nested_repeats_++; + } + is_repeated_ = true; + return Status::OK(); +} + +// Identifies the subtree below this node as being in a cache merge path +Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Turn on the flag that we're under a merge op + is_merge_ = true; + return Status::OK(); +} + +// Hooks up any identified eoe nodes under this repeat. +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking + std::shared_ptr leaf_op = PopFromEOEOpStack(); + while (leaf_op != nullptr) { + node->AddToEoeList(leaf_op); + leaf_op = PopFromEOEOpStack(); + } + + // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up + // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area. + if (is_merge_ && cache_lookup_) { + cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); + node->AddToEoeList(std::move(cache_lookup_)); + } + + // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. + // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. + if (nested_repeats_ > 0) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + AddToEOEOpStack(node); + nested_repeats_--; + } + + // If we are not nested, or we were the top-most repeat, now we clear the flag + if (nested_repeats_ == 0) { + is_repeated_ = false; + } + + return Status::OK(); +} + +// CacheOp removes previous leaf ops and replaces them with itself +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_repeated_) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + // if we are a cache within a repeat path of the tree, then there will be + // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the + // repeat or epoch ctrl operators can work with them for repeat activity during runtime. + // However, since a cache is present: + // - unflag those ops as being repeated ops + // - remove them from the eoe op stack so that repeat op above in the tree won't know about them + // - add ourself (the cache op), as an eoe op + // We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead + // the repeating behaviours shall be invoked against the cache op. + std::shared_ptr leaf_op = PopFromEOEOpStack(); + while (leaf_op != nullptr) { + leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat); + leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated); + leaf_op = PopFromEOEOpStack(); + } + AddToEOEOpStack(std::static_pointer_cast(node)); + } + + return Status::OK(); +} + +// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up +// for use with a controlling repeat above it. +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // If we are in a repeat path, then set our repeated flag + if (is_repeated_) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + + // if we are a leaf node then save ourself in a stack for the repeat operator above us + if (node->IsLeaf()) { + AddToEOEOpStack(node); + } + } + return Status::OK(); +} + +// Turns off the tracking for operations under merge op +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Setting the flag is needed since we didn't call the base class DatasetOp version + if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated); + is_merge_ = false; + cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed + return Status::OK(); +} + +// Saves the lookup up in case it needs to be referenced by a repeat +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + if (!node->IsLeaf()) { + // By definition, the CacheLookup must be a leaf op. Make that clear here. + RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); + } + + // If we are in a repeat path already, then there must be a repeat above the merge op + // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. + if (is_repeated_) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + AddToEOEOpStack(node); + } else { + // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we + // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself + // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. + cache_lookup_ = std::static_pointer_cast(node); + } + return Status::OK(); +} + +// Adds an operator to the eoe operator stack save area +void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { eoe_stack_.push(dataset_op); } + +// Pops an operator from the eoe operator stack save area +std::shared_ptr RepeatPass::PopFromEOEOpStack() { + std::shared_ptr top_op = nullptr; + if (!eoe_stack_.empty()) { + top_op = eoe_stack_.top(); + eoe_stack_.pop(); + } + return top_op; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h new file mode 100644 index 0000000000..9b733e2329 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ +#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ + +#include +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +/// \class RepeatPass repeat_pass.h +/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references +/// to the eoe-producing (typically leaf) nodes underneath it. +class RepeatPass : public NodePass { + public: + /// \brief Constructor + RepeatPass(); + + /// \brief Identifies the subtree below this node as being in a repeated path of the tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the subtree below this node as being in a cache merge path + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Hooks up any identified eoe nodes under this repeat. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief CacheOp removes previous leaf ops and replaces them with itself + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Turns of the tracking for operations under merge op + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Saves the lookup up in case it needs to be referenced by a repeat + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up + /// for use with a controlling repeat above it. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + private: + /// \brief Adds an operator to the eoe operator stack save area + /// \param op - The dataset op to work add to eoe stack + /// \return Status - The error code return + void AddToEOEOpStack(std::shared_ptr dataset_op); + + /// \brief Pops an operator from the eoe operator stack save area + /// \return shared_ptr to the popped operator + std::shared_ptr PopFromEOEOpStack(); + + bool is_repeated_; // T/F if we are processing under a repeat + bool is_merge_; // T/F if we are processing under a cache merge op + int32_t nested_repeats_; // A counter for nested repeats + std::stack> eoe_stack_; // A save area for leaf/eoe ops + std::shared_ptr cache_lookup_; // A save area for a cache lookup op +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc new file mode 100644 index 0000000000..09b5f14a17 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc @@ -0,0 +1,181 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/engine/opt/pre/cache_pass.h" +#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/generator_op.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" + +namespace mindspore { +namespace dataset { + +// Constructor +CachePass::CachePass(CacheTransformPass *transform_pass) + : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {} + +// Identifies the subtree below this node as a cached descendant tree. +Status CachePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; + if (is_caching_) { + RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); + } + is_caching_ = true; + return Status::OK(); +} + +// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache +// transformation +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + is_caching_ = false; // We a no longer in a cache subtree. clear the flag. + if (leaf_op_) { + MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; + // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, + // using base class pointers. + transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node); + } else { + // If there was no leaf_op set, then this is a non-mappable scenario. + + if (sampler_) { + // Grab the sampler that was saved from the leaf and plug it into the cache op + node->SetSampler(std::move(sampler_)); + MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; + } else { + // We're a cache op but no sampler was saved from leaf, so create a default sampler + int64_t num_samples = 0; + int64_t start_index = 0; + sampler_ = std::make_shared(num_samples, start_index); + node->SetSampler(std::move(sampler_)); + MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; + } + + // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache + uint32_t cache_crc = DatasetOp::GenerateCRC(node); + RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); + } + + return Status::OK(); +} + +// Common code for mappable leaf setup. +Status CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) { + // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. + if (is_caching_ && leaf_op_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); + } + + // If we are a leaf in the caching path, then save this leaf. + if (is_caching_) { + MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; + leaf_op_ = std::move(leaf_op); + } + return Status::OK(); +} + +// Common code for non mappable leaf setup. +Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) { + // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. + if (is_caching_ && leaf_op_) { + RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); + } + + // Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf + // as save it for use by cache op in ascendant tree. + if (is_caching_) { + RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); + MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; + } else { + // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can + // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) + std::shared_ptr sampler_from_leaf; + RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); + } + return Status::OK(); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + if (is_caching_) { + // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic + // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. + node->MakeSimpleProducer(); + } + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} + +// Perform leaf node cache tranform identifications +Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { + return MappableCacheLeafSetup(std::static_pointer_cast(node)); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h new file mode 100644 index 0000000000..cbc805cd3e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ + +#include +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class CacheTransformPass; + +/// \class CachePass cache_pass.h +/// \brief This is a NodePass who's job is to identify and set up the nodes that will be involved in a cache +/// transformation. It works in conjunction with the CacheTransformPass +class CachePass : public NodePass { + public: + /// \brief Constructor + /// \param[in] transform_pass Raw pointer back to controlling tree pass + explicit CachePass(CacheTransformPass *transform_pass); + + /// \brief Identifies the subtree below this node as a cached descendant tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache + /// transformation + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Perform leaf node cache tranform identifications + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + private: + /// \brief Common code for mappable leaf setup. + /// \param[in] node The leaf node performing setup work. + /// \return Status The error code return + Status MappableCacheLeafSetup(std::shared_ptr leaf_op); + + /// \brief Common code for non-mappable leaf setup. + /// \param[in] node The leaf node performing setup work. + /// \return Status The error code return + Status NonMappableCacheLeafSetup(std::shared_ptr leaf_op); + + bool is_caching_; + std::shared_ptr leaf_op_; + std::shared_ptr sampler_; + CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc new file mode 100644 index 0000000000..033150e8f4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/engine/opt/pre/cache_pass.h" +#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" + +namespace mindspore { +namespace dataset { + +// constructor +CacheTransformPass::CacheTransformPass() {} + +// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations +Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { + MS_LOG(INFO) << "Pre pass: Cache transform pass started."; + // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will + // use to execute a transform. + std::unique_ptr cache_pass = std::make_unique(this); + RETURN_IF_NOT_OK(cache_pass->Run(tree, modified)); + + // Then, execute the transform for each pair + for (auto cache_pair : cache_pairs_) { + MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; + ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); + } + MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; + return Status::OK(); +} + +// Helper function to execute the cache transformation. +Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op, + std::shared_ptr cache_op, + std::shared_ptr cache_client) { + // Get local pointers the child/parent of the cache op. It's possible that the parent is null if the cache was + // the root node. It is also possible that cache_child == leaf_op + std::shared_ptr cache_child = cache_op->child(0); + DatasetOp *cache_parent = nullptr; + cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent + + // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later. + std::shared_ptr leaf_sampler = leaf_op->sampler(); + + // Construct the merge op with defaults + std::shared_ptr merge_op; + CacheMergeOp::Builder merge_builder; + RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op)); + RETURN_IF_NOT_OK(tree->AssociateNode(merge_op)); + + // Construct the cache lookup op with defaults + std::shared_ptr cache_lookup_op; + CacheLookupOp::Builder lookup_builder; + RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op)); + RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op)); + + // Overwrite the old sampler in this leaf op to become the lookup op + leaf_op->SetSampler(cache_lookup_op); + + // If the cache had a parent, then go into that parent to remove the cache from it's child list and then + // replace it with the merge op. + if (cache_parent != nullptr) { + RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op)); + RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op)); + } else { + // If we didn't have a parent, then the merge op is the root node + RETURN_IF_NOT_OK(tree->AssignRoot(merge_op)); + } + + // Set the cache op to no longer be a parent over it's child. This will fully disconnect the old cache op. + // We maintain a local pointer to the old child though. + RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child)); + + // Connect the merge op + RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op))); + RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child))); + + // At this point, the cache op has already had it's children and parents taken away. Calling remove + // on it at this point will not do any node hookups, and instead set internal fields to invalid. + RETURN_IF_NOT_OK(cache_op->Remove()); + + return Status::OK(); +} + +// Assigns the leaf and cache operators that are involved in a cache transformation +void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr leaf_op, + std::shared_ptr cache_op) { + cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h new file mode 100644 index 0000000000..02c22c4472 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ + +#include +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class DatasetOp; + +class CacheClient; + +/// \class CacheTransformPass cache_transform_pass.h +/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching +/// operations +class CacheTransformPass : public TreePass { + public: + /// \brief Constructor + CacheTransformPass(); + + /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + Status RunOnTree(ExecutionTree *tree, bool *modified) override; + + /// \brief Assigns the leaf and cache operators that are involved in a cache transformation + /// \param[in] leaf_op The leaf operator involved in the cache transform + /// \param[in] cache_op The cache operator involved in the cache transform + void AddMappableCacheOperators(std::shared_ptr leaf_op, std::shared_ptr cache_op); + + private: + /// \brief Helper function to execute the cache transformation. + /// + /// Input: + /// Sampler + /// | + /// LeafOp --> OtherOps --> CacheOp + /// + /// Transformed: + /// Sampler --> CacheLookupOp ----------------> + /// | | + /// | MergeOp + /// | | + /// LeafOp --> OtherOps --> + /// + /// \param[in] leaf_op The leaf node in the transform + /// \param[in] cache_op The cache op in the transform (will get removed) + /// \param[in] cache_client The cache client + /// \return Status The error code return + Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op, + std::shared_ptr cache_op, std::shared_ptr cache_client); + + // The two operators that work together to establish the cache transform + std::vector, std::shared_ptr>> cache_pairs_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc new file mode 100644 index 0000000000..f04d7bc07d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/engine/opt/pre/removal_nodes.h" +#include "minddata/dataset/engine/opt/pre/removal_pass.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" + +namespace mindspore { +namespace dataset { + +RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} + +// Identifies the subtree below this node as a cached descendant tree. +Status RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; + is_caching_ = true; + return Status::OK(); +} + +// Resets the tracking of the cache within the tree +Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; + is_caching_ = false; + return Status::OK(); +} + +// Perform ShuffleOp removal check. +Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + // If we are in a cache descendant tree, then this shuffle op needs to be removed + if (is_caching_) { + MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; + if (removal_pass_) { + removal_pass_->AddToRemovalList(std::static_pointer_cast(node)); + } else { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); + } + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h new file mode 100644 index 0000000000..32025cd597 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ + +#include +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/opt/pre/removal_pass.h" + +namespace mindspore { +namespace dataset { +/// \class RemovalNodes removal_nodes.h +/// \brief This is a NodePass who's job is to identify which nodes should be removed. +/// It works in conjunction with the removal_pass. +class RemovalNodes : public NodePass { + public: + /// \brief Constructor + /// \param[in] removal_pass Raw pointer back to controlling tree pass + explicit RemovalNodes(RemovalPass *removal_pass); + + /// \brief Identifies the subtree below this node as a cached descendant tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Resets the tracking of the cache within the tree + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Destructor + ~RemovalNodes() = default; + + /// \brief Perform ShuffleOp removal check + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + private: + bool is_caching_; + RemovalPass *removal_pass_; // Back pointer to the owning removal pass +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc new file mode 100644 index 0000000000..0db422a7c2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/engine/opt/pre/removal_nodes.h" +#include "minddata/dataset/engine/opt/pre/removal_pass.h" +#include "minddata/dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { + +// constructor +RemovalPass::RemovalPass() {} + +// Runs a removal_nodes pass first to find out which nodes to remove, then removes them. +Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { + MS_LOG(INFO) << "Pre pass: removal pass started."; + // Create the removal node pass which can identify which nodes need to be removed. + std::unique_ptr removal_nodes = std::make_unique(this); + RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); + + // Then, execute the removal of any nodes that were set up for removal + for (auto node : removal_nodes_) { + node->Remove(); + } + MS_LOG(INFO) << "Pre pass: removal pass complete."; + return Status::OK(); +} + +// Adds an operator to the list of operators to be removed +void RemovalPass::AddToRemovalList(std::shared_ptr dataset_op) { removal_nodes_.push_back(dataset_op); } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h new file mode 100644 index 0000000000..bcab7cf08c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ + +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class DatasetOp; + +/// \class RemovalPass removal_pass.h +/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which +/// nodes should be removed, and then removes them. +class RemovalPass : public TreePass { + public: + /// \brief Constructor + RemovalPass(); + + /// \brief Destructor + ~RemovalPass() = default; + + /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + Status RunOnTree(ExecutionTree *tree, bool *modified) override; + + /// \brief Adds an operator to the list of operators to be removed + /// \param[in] dataset_op The operator to add to the removal list + void AddToRemovalList(std::shared_ptr dataset_op); + + private: + std::vector> removal_nodes_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc new file mode 100644 index 0000000000..eb74d8fcc3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/engine/opt/util/printer_pass.h" + +namespace mindspore { +namespace dataset { + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting DatasetOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting BatchOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting MapOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting ProjectOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting RenameOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting SkipOp" << '\n'; + return Status::OK(); +} +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting ShuffleOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting MindRecordOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting TFReaderOp" << '\n'; + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting FilterOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting GeneratorOp" << '\n'; + return Status::OK(); +} +#endif + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting TakeOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting ZipOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting DeviceQueueOp" << '\n'; + return Status::OK(); +} + +Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) { + *modified = false; + std::cout << "Visiting ImageFolderOp" << '\n'; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h new file mode 100644 index 0000000000..527df3ccc9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H +#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H + +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class PrinterPass : public NodePass { + public: + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + +#ifdef ENABLE_PYTHON + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; +#endif + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H diff --git a/mindspore/ccsrc/dataset/engine/perf/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/perf/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/engine/perf/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/engine/perf/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.cc new file mode 100644 index 0000000000..20b4908030 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/perf/connector_size.h" +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/path.h" + +using json = nlohmann::json; +namespace mindspore { +namespace dataset { +using Qrow = std::vector; + +// Sample action +Status ConnectorSize::Sample() { + Qrow cur_row; + std::transform(tree_->begin(), tree_->end(), std::back_inserter(cur_row), + [](DatasetOp &op) { return op.ConnectorSize(); }); + // Push new row of sample + sample_table_.push_back(cur_row); + return Status::OK(); +} + +// JSON serializer helper function +json ConnectorSize::ParseOpInfo(const DatasetOp &node, const std::vector &size) { + auto children = node.Children(); + std::vector children_id; + std::transform(children.begin(), children.end(), std::back_inserter(children_id), + [](std::shared_ptr op) -> int32_t { return op->id(); }); + json json_node; + json_node["op_id"] = node.id(); + json_node["op_type"] = node.Name(); + json_node["num_workers"] = node.num_workers(); + json metrics; + // DeviceQueueOp is a special op,it is not inlined but its output queue is invalid. + // So we should not output its queue size. + if (!node.inlined() && node.Name() != "DeviceQueueOp") { + metrics["output_queue"] = {{"size", size}, {"length", node.ConnectorCapacity()}}; + } + json_node["metrics"] = metrics; + if (!children_id.empty()) { + json_node["children"] = children_id; + } + + return json_node; +} + +// Save profiling data to file +Status ConnectorSize::SaveToFile() { + std::ofstream os(file_path_, std::ios::trunc); + uint32_t idx = 0; + json output; + std::shared_ptr cfg = GlobalContext::config_manager(); + output["sampling_interval"] = cfg->monitor_sampling_interval(); + // Traverse the ExecutionTree for JSON node generation + for (auto &node : *tree_) { + std::vector cur_queue_size; + std::transform(sample_table_.begin(), sample_table_.end(), std::back_inserter(cur_queue_size), + [&](const ConnectorSizeSample &sample) { return sample[idx]; }); + json json_node = ParseOpInfo(node, cur_queue_size); + output["op_info"].push_back(json_node); + idx++; + } + os << output; + return Status::OK(); +} +Status ConnectorSize::Init(const std::string &dir_path, const std::string &device_id) { + file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + device_id + ".json")).toString(); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h new file mode 100644 index 0000000000..61ba06a76f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CONNECTOR_SIZE_H +#define DATASET_CONNECTOR_SIZE_H + +#include +#include +#include +#include "minddata/dataset/engine/perf/profiling.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" + +using json = nlohmann::json; + +namespace mindspore { +namespace dataset { +class ExecutionTree; + +// Connector size sampling samples the output connector size of each op in the pipeline. +// It support JSON serialization for external usage. +class ConnectorSize : public Sampling { + // Connecto size sampling data is stored as a 2D vector + // op_0 ... op_m + // sample_0 size_0_0 ... size_m_0 + // ... ... ... ... + // sample_n size_0_m ... size_m_n + // + // A circular buffer will be implemented in the future to make this table more flexible. + using ConnectorSizeSample = std::vector; + using ConnectorSizeSampleTable = std::vector; + + public: + explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {} + + ~ConnectorSize() override = default; + + // Driver function for connector size sampling. + // This function samples the connector size of every nodes within the ExecutionTree + Status Sample() override; + + std::string Name() const override { return kConnectorSizeSamplingName; } + + // Save sampling data to file + // @return Status - The error code return + Status SaveToFile() override; + + Status Init(const std::string &dir_path, const std::string &device_id) override; + + // Parse op infomation and transform to json format + json ParseOpInfo(const DatasetOp &node, const std::vector &size); + + private: + ExecutionTree *tree_ = nullptr; // ExecutionTree pointer + ConnectorSizeSampleTable sample_table_; // Dataset structure to store all samples of connector size sampling +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_CONNECTOR_SIZE_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc new file mode 100644 index 0000000000..b5e2efaf73 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/perf/connector_throughput.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { + +// temporary helper +int ConnectorThroughput::InitNodes() { + auto it = (*tree_).begin(); + return it.NumNodes(); +} +// Sample action +Status ConnectorThroughput::Sample() { + std::vector out_buffer_count_row(n_nodes_); + std::vector throughput_row(n_nodes_); + TimePoint cur_time; // initialised inside the loop, used outside the loop to update prev sample time. + auto col = 0; + for (const auto &node : *tree_) { + auto cur_out_buffer_count = node.ConnectorOutBufferCount(); + out_buffer_count_row[col] = cur_out_buffer_count; + auto sz = timestamps_.size(); + cur_time = std::chrono::steady_clock::now(); + auto _dt = std::chrono::duration_cast(timestamps_[0][sz - 1] - timestamps_[0][sz - 2]); + auto dt = std::chrono::duration(_dt).count(); + auto prev_out_buffer_count = out_buffer_count_table_[col][out_buffer_count_table_.size() - 1]; + if (dt != 0) { + auto thr = (cur_out_buffer_count - prev_out_buffer_count) / (1000 * dt); + throughput_row[col] = thr; + } else { + throughput_row[col] = -1; + } + col++; + } + std::vector v = {cur_time}; // temporary fix + timestamps_.AddSample(v); + // Push new row of sample + out_buffer_count_table_.AddSample(out_buffer_count_row); + throughput_.AddSample(throughput_row); + return Status::OK(); +} + +json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector &thr) { + auto children = node.Children(); + std::vector children_id; + std::transform(children.begin(), children.end(), std::back_inserter(children_id), + [](std::shared_ptr op) -> int32_t { return op->id(); }); + json json_node; + json_node["op_id"] = node.id(); + json_node["op_type"] = node.Name(); + json_node["num_workers"] = node.num_workers(); + json metrics; + metrics["output_queue"] = {{"throughput", thr}}; + + json_node["metrics"] = metrics; + if (!children_id.empty()) { + json_node["children"] = children_id; + } + + return json_node; +} + +// Save profiling data to file +Status ConnectorThroughput::SaveToFile() { + std::ofstream os(file_path_); + json output; + output["sampling_interval"] = 10; + // Traverse the ExecutionTree for JSON node generation + int col = 0; + for (auto &node : *tree_) { + std::vector throughput; + for (auto i = 0; i < throughput_.size(); i++) { + throughput.push_back(throughput_[col][i]); + } + json json_node = ParseOpInfo(node, throughput); + output["op_info"].push_back(json_node); + col++; + } + os << output; + return Status::OK(); +} +Status ConnectorThroughput::Init(const std::string &dir_path, const std::string &device_id) { + file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + Name() + "_" + device_id + ".json")).toString(); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h new file mode 100644 index 0000000000..9cf387230a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_CONNECTOR_THROUGHPUT_H +#define DATASET_CONNECTOR_THROUGHPUT_H + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/perf/profiling.h" +#include "minddata/dataset/engine/perf/perf_data.h" +#include "minddata/dataset/engine/perf/cyclic_array.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/execution_tree.h" + +using json = nlohmann::json; +namespace mindspore { +namespace dataset { +// Connector throughput samples the output connector size of each op in the pipeline. +// For the description of the data structure see perf_buffer.h +// It support JSON serialization for external usage. +class ConnectorThroughput : public Sampling { + using OutBufferCount = PerfData>; + using Throughput = PerfData>; + using TimePoint = std::chrono::time_point; + using TimeStamps = PerfData>; + + public: + explicit ConnectorThroughput(ExecutionTree *tree, int64_t max_rows = 1000000) + : tree_(tree), + max_rows_(max_rows), + n_nodes_(InitNodes()), + out_buffer_count_table_(OutBufferCount(max_rows_, n_nodes_)), + throughput_(Throughput(max_rows_, n_nodes_)), + timestamps_(TimeStamps(max_rows_, 1)) { + timestamps_.AddSample(std::vector(1)); + out_buffer_count_table_.AddSample(std::vector(n_nodes_)); + } + + /// \brief Destructor + ~ConnectorThroughput() = default; + + // Driver function for connector size sampling. + // This function samples the connector size of every nodes within the ExecutionTree + Status Sample() override; + + /* Status TestPrint() override { + std::ofstream os("performance_monitor.txt"); + if (throughput_.size() == 0) { + os << "data is empty" << std::endl; + return Status::OK(); + } + for (int i = 0; i < throughput_.size(); i++) { + for (int j = 0; j < n_nodes_; j++) { + os << throughput_[j][i] << " "; + } + os << std::endl; + } + return Status::OK(); + };*/ + + // Traverse the tree nodes and count them + int InitNodes(); + + std::string Name() const override { return name_; }; + + // Save sampling data to file + // @return Status - The error code return + Status SaveToFile() override; + + Status Init(const std::string &dir_path, const std::string &device_id); + + json ParseOpInfo(const DatasetOp &node, const std::vector &thr); + + private: + ExecutionTree *tree_ = nullptr; // ExecutionTree pointer + int64_t max_rows_; + int32_t n_nodes_; + OutBufferCount out_buffer_count_table_; + Throughput throughput_; + TimeStamps timestamps_; + std::string name_ = kConnectorThroughputSamplingName; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_CONNECTOR_THROUGHPUT_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h b/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h new file mode 100644 index 0000000000..2dfc3fd99d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h @@ -0,0 +1,197 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_CYCLIC_ARRAY_H +#define DATASET_CYCLIC_ARRAY_H + +#include +#include +#include +#include +#include "minddata/dataset/core/constants.h" + +namespace mindspore { +namespace dataset { + +/// \class CyclicArray "include/cyclic_array.h +/// \brief This is a container with a contiguous memory layout that pnly keeps N last entries, +/// when the number of entries exceeds the capacity +/// Must be preallocated +template +class CyclicArray { + public: + using value_type = T; + class Iterator { + // Add operator[] and make fully compliant with random access iterator + // and add a const iterator + // add resize(), empty() + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = CyclicArray::value_type; + using difference_type = std::ptrdiff_t; + using pointer = CyclicArray::value_type *; + using reference = CyclicArray::value_type &; + + Iterator() = default; + + Iterator(dsize_t idx, pointer ptr, dsize_t capacity, dsize_t head) + : cur_idx_(idx), ptr_(ptr), capacity_(capacity), head_(head) {} + + Iterator(const Iterator &rhs) = default; + + ~Iterator() = default; + + Iterator &operator++() { + cur_idx_ = (cur_idx_ + 1) % (capacity_ + 1); + return *this; + } + + Iterator operator++(int) { + Iterator tmp(*this); + cur_idx_ = (cur_idx_ + 1) % (capacity_ + 1); + return tmp; + } + + Iterator &operator--() { + cur_idx_ = (cur_idx_ + capacity_) % (capacity_ + 1); + return *this; + } + + Iterator operator--(int) { + Iterator tmp(*this); + cur_idx_ = (cur_idx_ + capacity_) % (capacity_ + 1); + return tmp; + } + + Iterator operator+(dsize_t x) { return Iterator((cur_idx_ + x) % (capacity_ + 1), ptr_, capacity_, head_); } + + Iterator operator-(dsize_t x) { + return Iterator((cur_idx_ + (capacity_ + 1 - x)) % (capacity_ + 1), ptr_, capacity_, head_); + } + + bool operator<(const Iterator &rhs) { + return (head_ + cur_idx_) % (capacity_ + 1) < (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); + } + + bool operator>(const Iterator &rhs) { + return (head_ + cur_idx_) % (capacity_ + 1) > (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); + } + + bool operator>=(const Iterator &rhs) { + return (head_ + cur_idx_) % (capacity_ + 1) >= (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); + } + + bool operator<=(const Iterator &rhs) { + return (head_ + cur_idx_) % (capacity_ + 1) <= (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); + } + + difference_type operator-(const Iterator &rhs) { + return (cur_idx_ - rhs.cur_idx_ + capacity_ + 1) % (capacity_ + 1); + } + + reference operator*() { return ptr_[cur_idx_]; } + + pointer operator->() { return &(ptr_[cur_idx_]); } + + bool operator==(const Iterator &rhs) { return cur_idx_ == rhs.cur_idx_; } + + bool operator!=(const Iterator &rhs) { return cur_idx_ != rhs.cur_idx_; } + + private: + dsize_t cur_idx_; + pointer ptr_; + dsize_t capacity_; + dsize_t head_; + }; + + /// \brief Default constructor + CyclicArray() : buf_(nullptr), head_(0), tail_(0), size_(0), capacity_(0) {} + + /// \brief Constructor + /// \param[in] capacity + explicit CyclicArray(dsize_t capacity) + : buf_(std::make_unique(capacity + 1)), head_(0), tail_(0), size_(0), capacity_(capacity) {} + + CyclicArray(const CyclicArray &rhs) + : buf_(std::make_unique(rhs.capacity_ + 1)), + head_(rhs.head_), + tail_(rhs.tail_), + size_(rhs.size_), + capacity_(rhs.capacity_) { + std::copy(rhs.begin(), rhs.end(), begin()); + } + + CyclicArray(CyclicArray &&rhs) = default; + + ~CyclicArray() = default; + + /// \brief Iterator begin() + Iterator begin() { return Iterator(head_, buf_.get(), capacity_, head_); } + + /// \brief Iterator end() + Iterator end() { return Iterator(tail_, buf_.get(), capacity_, head_); } + + // not really const. + Iterator begin() const { return Iterator(head_, buf_.get(), capacity_, head_); } + + Iterator end() const { return Iterator(tail_, buf_.get(), capacity_, head_); } + + /// \brief clear the array. Does not deallocate memory, capacity remains the same + void clear() { + head_ = 0; + tail_ = 0; + size_ = 0; + } + + /// \brief returns current size + dsize_t size() { return size_; } + + /// \brief returns capacity + dsize_t capacity() { return capacity_; } + + /// \brief pushes a value + /// \param[in] val value + void push_back(T val) { + buf_[tail_] = val; + if (size_ >= capacity_) { + (tail_ != capacity_) ? tail_++ : tail_ = 0; + (head_ != capacity_) ? head_++ : head_ = 0; + } else { + tail_++; + size_++; + } + } + + /// \brief returns const reference to an element of the array + /// \param[in] idx index of the element + /// \param[out] const T& reference to an element of the array + const T &operator[](dsize_t idx) const { return buf_[(head_ + idx) % (capacity_ + 1)]; } + + /// \brief returns non-const reference to an element of the array + /// \param[in] idx index of the element + /// \param[out] T& reference to an element of the array + T &operator[](dsize_t idx) { return buf_[(head_ + idx) % (capacity_ + 1)]; } + + private: + std::unique_ptr buf_; + dsize_t head_; + dsize_t tail_; + dsize_t size_; + dsize_t capacity_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CYCLIC_ARRAY_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.cc new file mode 100644 index 0000000000..4491db144e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h" +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { + +Status DatasetIteratorTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, + const int32_t value) { + // Format: "type extra-info batch-num value" + // type: 0: time, 1: connector size + // extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time + // if type is 1 - connector capacity + // batch-num: batch number + // value: if type is 0 - value is time(ms) + // if type is 1 - value is connector size + // Examples: + // 0 0 20 10 - The 20th batch took 10ms to get data from pipeline. + // 1 64 20 5 - Connector size is 5 when get the 20th batch.Connector capacity is 64. + std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " + + std::to_string(value); + value_.emplace_back(data); + return Status::OK(); +} + +Status DatasetIteratorTracing::SaveToFile() { + if (value_.empty()) { + return Status::OK(); + } + + std::ofstream handle(file_path_, std::ios::trunc); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Profiling file can not be opened."); + } + for (auto value : value_) { + handle << value << "\n"; + } + handle.close(); + + return Status::OK(); +} + +Status DatasetIteratorTracing::Init(const std::string &dir_path, const std::string &device_id) { + file_path_ = (Path(dir_path) / Path("dataset_iterator_profiling_" + device_id + ".txt")).toString(); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.h b/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.h new file mode 100644 index 0000000000..e7ba237a0a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_DATASET_ITERATOR_TRACING_H +#define MINDSPORE_DATASET_ITERATOR_TRACING_H + +#include +#include +#include "minddata/dataset/engine/perf/profiling.h" + +namespace mindspore { +namespace dataset { +class DatasetIteratorTracing : public Tracing { + public: + // Constructor + DatasetIteratorTracing() = default; + + // Destructor + ~DatasetIteratorTracing() override = default; + + // Record tracing data + // @return Status - The error code return + Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); + + std::string Name() const override { return kDatasetIteratorTracingName; }; + + // Save tracing data to file + // @return Status - The error code return + Status SaveToFile() override; + + Status Init(const std::string &dir_path, const std::string &device_id) override; + + private: + std::vector value_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_DATASET_ITERATOR_TRACING_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.cc new file mode 100644 index 0000000000..776b483b79 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/engine/perf/device_queue_tracing.h" +#include "minddata/dataset/util/path.h" +namespace mindspore { +namespace dataset { + +Status DeviceQueueTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, + const int32_t value) { + // Format: "type extra-info batch-num value" + // type: 0: time, 1: connector size + // extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time + // if type is 1 - connector capacity + // batch-num: batch number + // value: if type is 0 - value is time(ms) + // if type is 1 - value is connector size + // Examples: + // 0 0 20 10 - The 20th batch took 10ms to get data from pipeline. + // 1 64 20 5 - Connector size is 5 when get the 20th batch.Connector capacity is 64. + std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " + + std::to_string(value); + value_.emplace_back(data); + return Status::OK(); +} + +Status DeviceQueueTracing::SaveToFile() { + if (value_.empty()) { + return Status::OK(); + } + + std::ofstream handle(file_path_, std::ios::trunc); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Profiling file can not be opened."); + } + for (auto value : value_) { + handle << value << "\n"; + } + handle.close(); + + return Status::OK(); +} + +Status DeviceQueueTracing::Init(const std::string &dir_path, const std::string &device_id) { + file_path_ = (Path(dir_path) / Path("device_queue_profiling_" + device_id + ".txt")).toString(); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.h b/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.h new file mode 100644 index 0000000000..32f9d2d8c2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_DEVICE_QUEUE_TRACING_H +#define MINDSPORE_DEVICE_QUEUE_TRACING_H + +#include +#include +#include "minddata/dataset/engine/perf/profiling.h" + +namespace mindspore { +namespace dataset { +class DeviceQueueTracing : public Tracing { + public: + // Constructor + DeviceQueueTracing() = default; + + // Destructor + ~DeviceQueueTracing() override = default; + + // Record tracing data + // @return Status - The error code return + Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); + + std::string Name() const override { return kDeviceQueueTracingName; }; + + // Save tracing data to file + // @return Status - The error code return + Status SaveToFile() override; + + Status Init(const std::string &dir_path, const std::string &device_id) override; + + private: + std::vector value_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_DEVICE_QUEUE_TRACING_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.cc new file mode 100644 index 0000000000..7fa7e6fc78 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/engine/perf/monitor.h" +#include "minddata/dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { + +Monitor::Monitor(ExecutionTree *tree) : tree_(tree) { + std::shared_ptr cfg = GlobalContext::config_manager(); + sampling_interval_ = cfg->monitor_sampling_interval(); + max_samples_ = 0; + cur_row_ = 0; +} +Status Monitor::operator()() { + // Register this thread with TaskManager to receive proper interrupt signal. + TaskManager::FindMe()->Post(); + + // Keep sampling if + // 1) Monitor Task is not interrupted by TaskManager AND + // 2) Iterator has not received EOF + while (!this_thread::is_interrupted() && !(tree_->isFinished())) { + for (auto &node : tree_->GetProfilingManager()->GetSamplingNodes()) { + RETURN_IF_NOT_OK(node.second->Sample()); + std::this_thread::sleep_for(std::chrono::milliseconds(sampling_interval_)); + } + } + + // Output all profiling data upon request. + tree_->GetProfilingManager()->SaveProfilingData(); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.h b/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.h new file mode 100644 index 0000000000..1e669dad71 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MONITOR_H +#define MINDSPORE_MONITOR_H + +#include +#include +#include +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/perf/profiling.h" + +namespace mindspore { +namespace dataset { +class ExecutionTree; +class Monitor { + public: + // Monitor object constructor + + explicit Monitor(ExecutionTree *tree); + + Monitor() = default; + + ~Monitor() = default; + + // Functor for Perf Monitor main loop. + // This function will be the entry point of mindspore::Dataset::Task + Status operator()(); + + int64_t GetSamplingInterval() { return sampling_interval_; } + + private: + int64_t cur_row_; + int64_t max_samples_; + int64_t sampling_interval_; + ExecutionTree *tree_; + std::vector> sampling_list_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_MONITOR_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h b/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h new file mode 100644 index 0000000000..8f215fd8df --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_PERF_DATA_H +#define DATASET_PERF_DATA_H + +#include +#include "minddata/dataset/core/constants.h" + +namespace mindspore { +namespace dataset { + +// PerfData is a convenience class to record and store the data produced by Monitor +// and represents a 2D column major table with every column storing samples +// for an operator. The number of rows equals to the number of samples, +// the number of columns equals to the number of operators. +// The capacity is determined on construction and cannot be changed. +// ColumnType can be std::vector or CyclicArray. In case of the latter data can be added +// indefinitely without the risk of overflowing otherwise the capacity must not be exceeded. +// Given PerfData pd(n_rows, n_cols) an element in the column i and row j can be accessed as +// pd[i][j] + +template +class PerfData { + public: + PerfData() = default; + ~PerfData() = default; + PerfData(dsize_t max_rows, dsize_t n_cols) : counter_(0), max_rows_(max_rows), n_cols_(n_cols) { + for (auto i = 0; i < n_cols_; i++) { + data_.push_back(ColumnType(max_rows_)); + } + } + PerfData(const PerfData &rhs) = default; + PerfData(PerfData &&rhs) = default; + + // Adds a row of data + // T must be any container working with range based loops + template + void AddSample(const T &row) { + auto i = 0; + for (const auto &e : row) { + data_[i++].push_back(e); + } + counter_++; + } + + // Fetches a row of data by copy + template + auto Row(dsize_t idx) { + std::vector row(n_cols_); + for (auto i = 0; i < n_cols_; i++) { + row[i] = data_[i][idx]; + } + return row; + } + + // returns a column of data + ColumnType &operator[](size_t idx) { return data_[idx]; } + + const ColumnType &operator[](size_t idx) const { return data_[idx]; } + + dsize_t size() { return counter_ < max_rows_ ? counter_ : max_rows_; } + + dsize_t capacity() { return max_rows_; } + + private: + std::vector data_; + dsize_t counter_; + dsize_t max_rows_; + int n_cols_; +}; + +} // namespace dataset +} // namespace mindspore +#endif // DATASET_PERF_DATA_H diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc new file mode 100644 index 0000000000..f5c018c03b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/perf/profiling.h" +#include +#include +#include +#include "common/utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/engine/perf/monitor.h" +#include "minddata/dataset/engine/perf/device_queue_tracing.h" +#include "minddata/dataset/engine/perf/connector_size.h" +#include "minddata/dataset/engine/perf/connector_throughput.h" +#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { + +bool ProfilingManager::IsProfilingEnable() const { + auto profiling = common::GetEnv("PROFILING_MODE"); + if (profiling.empty() || profiling != "true") { + return false; + } + return true; +} + +Status ProfilingManager::Initialize() { + // Register nodes based on config + std::string dir = common::GetEnv("MINDDATA_PROFILING_DIR"); + if (dir.empty()) { + RETURN_STATUS_UNEXPECTED("Profiling dir is not set."); + } + char real_path[PATH_MAX] = {0}; + if (dir.size() >= PATH_MAX) { + RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); + } +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) { + RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); + } +#else + if (realpath(common::SafeCStr(dir), real_path) == nullptr) { + RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); + } +#endif + dir_path_ = real_path; + + // If DEVICE_ID is not set,defult value is 0 + device_id_ = common::GetEnv("DEVICE_ID"); + if (device_id_.empty()) { + device_id_ = "0"; + } + + // Register all profiling node. + // device_queue node is used for graph mode + std::shared_ptr device_queue_tracing = std::make_shared(); + RETURN_IF_NOT_OK(RegisterTracingNode(device_queue_tracing)); + // dataset_iterator node is used for graph mode + std::shared_ptr dataset_iterator_tracing = std::make_shared(); + RETURN_IF_NOT_OK(RegisterTracingNode(dataset_iterator_tracing)); + + std::shared_ptr connector_size_sampling = std::make_shared(tree_); + RETURN_IF_NOT_OK(RegisterSamplingNode(connector_size_sampling)); + + std::shared_ptr connector_thr_sampling = std::make_shared(tree_); + RETURN_IF_NOT_OK(RegisterSamplingNode(connector_thr_sampling)); + return Status::OK(); +} + +// Profiling node registration +Status ProfilingManager::RegisterTracingNode(std::shared_ptr node) { + // Check if node with the same name has already been registered. + auto exist = tracing_nodes_.find(node->Name()); + if (exist != tracing_nodes_.end()) { + return Status(StatusCode::kProfilingError, "Profiling node already exist: " + node->Name()); + } + // Register the node with its name as key. + RETURN_IF_NOT_OK(node->Init(dir_path_, device_id_)); + tracing_nodes_[node->Name()] = node; + return Status::OK(); +} + +// Profiling node getter +Status ProfilingManager::GetTracingNode(const std::string &name, std::shared_ptr *node) { + // Check if node with the same name has already been registered. + auto exist = tracing_nodes_.find(name); + if (exist == tracing_nodes_.end()) { + return Status(StatusCode::kProfilingError, "Profiling node does not exist: " + name); + } + // Fetch node. + *node = tracing_nodes_[name]; + return Status::OK(); +} + +// Profiling node registration +Status ProfilingManager::RegisterSamplingNode(std::shared_ptr node) { + // Check if node with the same name has already been registered. + auto exist = sampling_nodes_.find(node->Name()); + if (exist != sampling_nodes_.end()) { + return Status(StatusCode::kProfilingError, "Profiling node already exist: " + node->Name()); + } + // Register the node with its name as key. + RETURN_IF_NOT_OK(node->Init(dir_path_, device_id_)); + sampling_nodes_[node->Name()] = node; + return Status::OK(); +} + +// Profiling node getter +Status ProfilingManager::GetSamplingNode(const std::string &name, std::shared_ptr *node) { + // Check if node with the same name has already been registered. + auto exist = sampling_nodes_.find(name); + if (exist == sampling_nodes_.end()) { + return Status(StatusCode::kProfilingError, "Profiling node does not exist: " + name); + } + // Fetch node. + *node = sampling_nodes_[name]; + return Status::OK(); +} + +Status ProfilingManager::SaveProfilingData() { + if (!IsProfilingEnable()) { + return Status::OK(); + } + MS_LOG(INFO) << "Start to save profiling data."; + for (auto node : tracing_nodes_) { + RETURN_IF_NOT_OK(node.second->SaveToFile()); + } + for (auto node : sampling_nodes_) { + RETURN_IF_NOT_OK(node.second->SaveToFile()); + } + MS_LOG(INFO) << "Save profiling data end."; + return Status::OK(); +} + +int64_t ProfilingTime::GetCurMilliSecond() { + // because cpplint does not allow using namespace + using std::chrono::duration_cast; + using std::chrono::milliseconds; + using std::chrono::steady_clock; + return duration_cast(steady_clock::now().time_since_epoch()).count(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h new file mode 100644 index 0000000000..24f7f2efe8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_PROFILE_H_ +#define DATASET_UTIL_PROFILE_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class Monitor; +class ExecutionTree; + +const char kDeviceQueueTracingName[] = "Device_Queue_Tracing"; +const char kDatasetIteratorTracingName[] = "Dataset_Iterator_Tracing"; +const char kConnectorSizeSamplingName[] = "Connector_Size_Sampling"; +const char kConnectorThroughputSamplingName[] = "Connector_Throughput_Sampling"; + +// Profiling is a class of basic unit of profiling action +// This base class encapsulate the serialization output logic +class Profiling : std::enable_shared_from_this { + public: + // Constructor + Profiling() = default; + + // Destructor + virtual ~Profiling() = default; + + virtual Status Init(const std::string &dir_path, const std::string &device_id) = 0; + + // Default serialization file generator + virtual Status SaveToFile() = 0; + + // Profiling name + virtual std::string Name() const = 0; + + protected: + std::string file_path_; +}; + +// Sampling is a class of profiling which generate samples periodically. +class Sampling : public Profiling { + public: + // Sampling action function. This function will be invoked by performance monitor thread. + virtual Status Sample() = 0; + // virtual Status TestPrint() = 0; + virtual ~Sampling() = default; +}; + +// Tracing is class of profiling which record samples upon request. +class Tracing : public Profiling { + // Tracing does not define a fixed interface to provide flexible on data recording. +}; + +// ProfilingManager is a class manages all profiling infrastructure +// It serves the following purposes: +// 1) Fetch profiling configs from global contexts +// 2) Setup all profiling node based on config +// 3) Provide access of profiling nodes for profiling actions +// 4) Manage profiling data serialization process +class ProfilingManager { + public: + explicit ProfilingManager(ExecutionTree *tree) : tree_(tree) {} + + ~ProfilingManager() = default; + + Status Initialize(); + + // Save profile data to file + // @return Status - The error code return + Status SaveProfilingData(); + + // Sampling node getter + // @param name - The name of the requested node + // @param node - Pointer to the shared pointer for the Sampling node + // @return Status - The error code return + Status GetSamplingNode(const std::string &name, std::shared_ptr *node); + + // Tracing node getter + // @param name - The name of the requested node + // @param node - Pointer to the shared pointer for the Tracing node + // @return Status - The error code return + Status GetTracingNode(const std::string &name, std::shared_ptr *node); + + // If profiling is enabled. + bool IsProfilingEnable() const; + + const std::unordered_map> &GetSamplingNodes() { return sampling_nodes_; } + + private: + std::unordered_map> tracing_nodes_; + + std::unordered_map> sampling_nodes_; + + // Register profile node to tree + // @param node - Profiling node + // @return Status - The error code return + Status RegisterTracingNode(std::shared_ptr node); + + // Register profile node to tree + // @param node - Profiling node + // @return Status - The error code return + Status RegisterSamplingNode(std::shared_ptr node); + + ExecutionTree *tree_ = nullptr; // ExecutionTree pointer + std::string dir_path_; // where to create profiling file + std::string device_id_; // used when create profiling file,filename_deviceid.suffix +}; + +enum ProfilingType { TIME, CONNECTOR_DEPTH }; + +enum ProfilingTimeSubType { + PIPELINE_TIME, + TDT_PUSH_TIME, + BATCH_TIME, + INVALID_TIME, +}; + +class ProfilingTime { + public: + static int64_t GetCurMilliSecond(); +}; + +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/dataset/engine/tdt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/tdt/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/engine/tdt/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/engine/tdt/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc new file mode 100644 index 0000000000..126291179a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/tdt/tdt_plugin.h" +#include "common/utils.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/engine/perf/profiling.h" + +namespace mindspore { +namespace dataset { +static std::shared_ptr instance_ptr_ = nullptr; + +std::shared_ptr TdtPlugin::GetInstance() { + if (instance_ptr_ == nullptr) { + instance_ptr_ = std::shared_ptr(new TdtPlugin); + } + return instance_ptr_; +} + +TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) { + MS_LOG(DEBUG) << "TDT channel name is " << channel_name << "."; + std::vector items; + double start_time; + auto ret = translate(ts_row, items); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "TDT converting tensor failed!"; + return FAILED; + } + if (profiling) { + start_time = ProfilingTime::GetCurMilliSecond(); + } + if (tdt::TdtHostPushData(channel_name, items) != 0) { + MS_LOG(ERROR) << "TDT pushing data failed!"; + return FAILED; + } + if (profiling) { + double end_time = ProfilingTime::GetCurMilliSecond(); + time = (int32_t)(end_time - start_time); + } + return SUCCESS; +} + +TdtStatus TdtPlugin::getTdtType(DataType d_type, std::string &datatype) { + switch (d_type.value()) { + case DataType::DE_BOOL: + datatype = "bool"; + break; + case DataType::DE_INT8: + datatype = "int8"; + break; + case DataType::DE_UINT8: + datatype = "uint8"; + break; + case DataType::DE_INT16: + datatype = "int16"; + break; + case DataType::DE_UINT16: + datatype = "uint16"; + break; + case DataType::DE_INT32: + datatype = "int32"; + break; + case DataType::DE_UINT32: + datatype = "uint32"; + break; + case DataType::DE_FLOAT16: + datatype = "float16"; + break; + case DataType::DE_FLOAT32: + datatype = "float32"; + break; + case DataType::DE_FLOAT64: + datatype = "float64"; + break; + case DataType::DE_INT64: + datatype = "int64"; + break; + case DataType::DE_UINT64: + datatype = "uint64"; + break; + default: + return FAILED; + } + return SUCCESS; +} + +TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector &items) { + if (ts_row.size() == 0) { + MS_LOG(ERROR) << "TDT the size of row is zero."; + return SUCCESS; + } + for (auto ts : ts_row) { + std::string datatype; + TdtStatus status = getTdtType(ts->type(), datatype); + if (status != SUCCESS) { + return status; + } + TensorShape tsShape = ts->shape(); + std::string dataShapes = "["; + for (auto dim : tsShape.AsVector()) { + (void)dataShapes.append(std::to_string(dim)).append(","); + } + dataShapes.pop_back(); + (void)dataShapes.append("]"); + DataItem data_item; + data_item.dataType_ = tdt::TDT_TENSOR; + data_item.tensorShape_ = dataShapes; + data_item.tensorType_ = datatype; + data_item.dataLen_ = ts->SizeInBytes(); + data_item.dataPtr_ = + std::shared_ptr(reinterpret_cast(&(*ts->begin())), [](const void *elem) {}); + items.emplace_back(data_item); + MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is " + << ts->Size() << "."; + } + return SUCCESS; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h new file mode 100644 index 0000000000..a7db08b7f5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_TDT_TDT_PLUGIN_H_ +#define DATASET_ENGINE_TDT_TDT_PLUGIN_H_ + +#include +#include +#include +#include +#include +#include +#include "tdt/tdt_host_interface.h" + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" + +namespace mindspore { +namespace dataset { +enum TdtStatus { SUCCESS, FAILED }; + +using tdt::DataItem; + +class TdtPlugin { + public: + static std::shared_ptr GetInstance(); + + TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time); + + private: + TdtPlugin() {} + + TdtStatus getTdtType(DataType d_type, std::string &datatype); + + TdtStatus translate(const TensorRow &ts_row, std::vector &items); + + void *tdt_handle_ = nullptr; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_TDT_TDT_PLUGIN_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/constants.h new file mode 120000 index 0000000000..22fe6d07e1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/core/constants.h @@ -0,0 +1 @@ +../../../core/constants.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/core/data_type.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/data_type.h new file mode 120000 index 0000000000..37a0e1b686 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/core/data_type.h @@ -0,0 +1 @@ +../../../core/data_type.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/core/tensor_shape.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/tensor_shape.h new file mode 120000 index 0000000000..1fb7a24d91 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/core/tensor_shape.h @@ -0,0 +1 @@ +../../../core/tensor_shape.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/include/dataset/util/status.h new file mode 120000 index 0000000000..b06279c05b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/util/status.h @@ -0,0 +1 @@ +../../../util/status.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h new file mode 100644 index 0000000000..6f38f5ea16 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -0,0 +1,357 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_INCLUDE_DATASETS_H_ +#define DATASET_INCLUDE_DATASETS_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/include/tensor.h" +#include "minddata/dataset/include/iterator.h" +#include "minddata/dataset/include/samplers.h" + +namespace mindspore { +namespace dataset { + +// Forward declare +class DatasetOp; +class DataSchema; +class Tensor; +class TensorShape; + +namespace api { + +class TensorOperation; +class SamplerObj; +class ImageFolderDataset; +class MnistDataset; +class BatchDataset; +class RepeatDataset; +class MapDataset; +class ShuffleDataset; +class Cifar10Dataset; +class ProjectDataset; + +/// \brief Function to create an ImageFolderDataset +/// \notes A source dataset that reads images from a tree of directories +/// All images within one folder have the same label +/// The generated dataset has two columns ['image', 'label'] +/// \param[in] dataset_dir Path to the root directory that contains the dataset +/// \param[in] decode A flag to decode in ImageFolder +/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, +/// A `RandomSampler` will be used to randomly iterate the entire dataset +/// \param[in] extensions File extensions to be read +/// \param[in] class_indexing a class name to label map +/// \return Shared pointer to the current ImageFolderDataset +std::shared_ptr ImageFolder(std::string dataset_dir, bool decode = false, + std::shared_ptr sampler = nullptr, + std::set extensions = {}, + std::map class_indexing = {}); + +/// \brief Function to create a MnistDataset +/// \notes The generated dataset has two columns ['image', 'label'] +/// \param[in] dataset_dir Path to the root directory that contains the dataset +/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, +/// A `RandomSampler` will be used to randomly iterate the entire dataset +/// \return Shared pointer to the current MnistDataset +std::shared_ptr Mnist(std::string dataset_dir, std::shared_ptr sampler = nullptr); + +/// \brief Function to create a Cifar10 Dataset +/// \notes The generated dataset has two columns ['image', 'label'] +/// \param[in] dataset_dir Path to the root directory that contains the dataset +/// \param[in] num_samples The number of images to be included in the dataset +/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` +/// will be used to randomly iterate the entire dataset +/// \return Shared pointer to the current Dataset +std::shared_ptr Cifar10(const std::string &dataset_dir, int32_t num_samples, + std::shared_ptr sampler); + +/// \class Dataset datasets.h +/// \brief A base class to represent a dataset in the data pipeline. +class Dataset : public std::enable_shared_from_this { + public: + friend class Iterator; + + /// \brief Constructor + Dataset(); + + /// \brief Destructor + ~Dataset() = default; + + /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object + /// \return shared pointer to the list of newly created DatasetOps + virtual std::shared_ptr>> Build() = 0; + + /// \brief Pure virtual function for derived class to implement parameters validation + /// \return bool True if all the params are valid + virtual bool ValidateParams() = 0; + + /// \brief Setter function for runtime number of workers + /// \param[in] num_workers The number of threads in this operator + /// \return Shared pointer to the original object + std::shared_ptr SetNumWorkers(int32_t num_workers) { + num_workers_ = num_workers; + return shared_from_this(); + } + + /// \brief Function to create an Iterator over the Dataset pipeline + /// \return Shared pointer to the Iterator + std::shared_ptr CreateIterator(); + + /// \brief Function to create a BatchDataset + /// \notes Combines batch_size number of consecutive rows into batches + /// \param[in] batch_size Path to the root directory that contains the dataset + /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete + /// batch. If true, and if there are less than batch_size rows + /// available to make the last batch, then those rows will + /// be dropped and not propagated to the next node + /// \return Shared pointer to the current BatchDataset + std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false); + + /// \brief Function to create a RepeatDataset + /// \notes Repeats this dataset count times. Repeat indefinitely if count is -1 + /// \param[in] count Number of times the dataset should be repeated + /// \return Shared pointer to the current Dataset + /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset` + /// due to a limitation in the current implementation + std::shared_ptr Repeat(int32_t count = -1); + + /// \brief Function to create a MapDataset + /// \notes Applies each operation in operations to this dataset + /// \param[in] operations Vector of operations to be applied on the dataset. Operations are + /// applied in the order they appear in this list + /// \param[in] input_columns Vector of the names of the columns that will be passed to the first + /// operation as input. The size of this list must match the number of + /// input columns expected by the first operator. The default input_columns + /// is the first column + /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation + /// This parameter is mandatory if len(input_columns) != len(output_columns) + /// The size of this list must match the number of output columns of the + /// last operation. The default output_columns will have the same + /// name as the input columns, i.e., the columns will be replaced + /// \param[in] project_columns A list of column names to project + /// \return Shared pointer to the current MapDataset + std::shared_ptr Map(std::vector> operations, + std::vector input_columns = {}, + std::vector output_columns = {}, + const std::vector &project_columns = {}); + + /// \brief Function to create a Shuffle Dataset + /// \notes Randomly shuffles the rows of this dataset + /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling + /// \return Shared pointer to the current ShuffleDataset + std::shared_ptr Shuffle(int32_t shuffle_size); + + /// \brief Function to create a Project Dataset + /// \notes Applies project to the dataset + /// \param[in] columns The name of columns to project + /// \return Shared pointer to the current Dataset + std::shared_ptr Project(const std::vector &columns); + + protected: + std::vector> children; + std::shared_ptr parent; + + int32_t num_workers_; + int32_t rows_per_buffer_; + int32_t connector_que_size_; +}; + +/* ####################################### Derived Dataset classes ################################# */ + +/// \class ImageFolderDataset +/// \brief A Dataset derived class to represent ImageFolder dataset +class ImageFolderDataset : public Dataset { + public: + /// \brief Constructor + ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, + std::set extensions, std::map class_indexing); + + /// \brief Destructor + ~ImageFolderDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + bool decode_; + bool recursive_; + std::shared_ptr sampler_; + std::map class_indexing_; + std::set exts_; +}; + +class MnistDataset : public Dataset { + public: + /// \brief Constructor + MnistDataset(std::string dataset_dir, std::shared_ptr sampler); + + /// \brief Destructor + ~MnistDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + std::shared_ptr sampler_; +}; + +class BatchDataset : public Dataset { + public: + /// \brief Constructor + BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector cols_to_map, + std::map>> pad_map); + + /// \brief Destructor + ~BatchDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + int32_t batch_size_; + bool drop_remainder_; + bool pad_; + std::vector cols_to_map_; + std::map>> pad_map_; +}; + +class RepeatDataset : public Dataset { + public: + /// \brief Constructor + explicit RepeatDataset(uint32_t count); + + /// \brief Destructor + ~RepeatDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + uint32_t repeat_count_; +}; + +class ShuffleDataset : public Dataset { + public: + ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch); + + ~ShuffleDataset() = default; + + std::shared_ptr>> Build() override; + + bool ValidateParams() override; + + private: + int32_t shuffle_size_; + uint32_t shuffle_seed_; + bool reset_every_epoch_; +}; + +class MapDataset : public Dataset { + public: + /// \brief Constructor + MapDataset(std::vector> operations, std::vector input_columns = {}, + std::vector output_columns = {}, const std::vector &columns = {}); + + /// \brief Destructor + ~MapDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector> operations_; + std::vector input_columns_; + std::vector output_columns_; + std::vector project_columns_; +}; + +class Cifar10Dataset : public Dataset { + public: + /// \brief Constructor + Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler); + + /// \brief Destructor + ~Cifar10Dataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + int32_t num_samples_; + std::shared_ptr sampler_; +}; + +class ProjectDataset : public Dataset { + public: + /// \brief Constructor + explicit ProjectDataset(const std::vector &columns); + + /// \brief Destructor + ~ProjectDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector columns_; +}; +} // namespace api +} // namespace dataset +} // namespace mindspore +#endif // DATASET_INCLUDE_DATASETS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/iterator.h b/mindspore/ccsrc/minddata/dataset/include/iterator.h new file mode 100644 index 0000000000..c3784821a6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/iterator.h @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_INCLUDE_ITERATOR_H_ +#define DATASET_INCLUDE_ITERATOR_H_ + +#include +#include +#include +#include +#include "minddata/dataset/include/status.h" + +namespace mindspore { +namespace dataset { + +// Forward declare +class ExecutionTree; +class DatasetIterator; +class DatasetOp; +class Tensor; + +namespace api { + +class Dataset; + +using TensorMap = std::unordered_map>; + +// Abstract class for iterating over the dataset. +class Iterator { + public: + /// \brief Constructor + Iterator() = default; + + /// \brief Destructor + ~Iterator() = default; + + /// \brief Method for building and launching the pipeline. + /// \param[in] ops - a vector of DatasetOp in the data pipeline. + /// \return - a Status error code, returns OK if no error encountered. + Status BuildAndLaunchTree(std::shared_ptr ds); + + /// \brief Function to get the next row from the data pipeline. + /// \param[out] row - the output tensor row. + void GetNextRow(TensorMap *row); + + /// \brief Function to shut down the data pipeline. + void Stop(); + + class _Iterator { + public: + explicit _Iterator(Iterator *lt) : lt_{lt}, cur_row_{nullptr} { + if (lt_) { + cur_row_ = new TensorMap(); + lt_->GetNextRow(cur_row_); + } + } + + // Destructor + ~_Iterator() { + if (cur_row_) { + delete cur_row_; + } + } + + _Iterator &operator++() { + if (lt_) { + ++ind_; + lt_->GetNextRow(cur_row_); + } + if (cur_row_ && cur_row_->size() == 0) { + delete cur_row_; + cur_row_ = nullptr; + } + return *this; + } // prefix ++ overload + TensorMap &operator*() { return *cur_row_; } // dereference operator + TensorMap *operator->() { return cur_row_; } + + bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; } + + private: + int ind_; // the cur node our Iterator points to + Iterator *lt_; + TensorMap *cur_row_; + }; + + _Iterator begin() { return _Iterator(this); } + + _Iterator end() { return _Iterator(nullptr); } + + private: + // Runtime tree. + // Use shared_ptr instead of unique_ptr because the DatasetIterator constructor takes in a shared_ptr type. + std::shared_ptr tree_; + + // Runtime iterator + std::unique_ptr iterator_; +}; +} // namespace api +} // namespace dataset +} // namespace mindspore +#endif // DATASET_INCLUDE_ITERATOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h new file mode 100644 index 0000000000..3d57e67059 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -0,0 +1,199 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_API_SAMPLERS_H_ +#define DATASET_API_SAMPLERS_H_ + +#include +#include + +namespace mindspore { +namespace dataset { + +// Internal Sampler class forward declaration +class Sampler; + +namespace api { + +class SamplerObj : public std::enable_shared_from_this { + public: + SamplerObj(); + + ~SamplerObj() = default; + + virtual std::shared_ptr Build() = 0; + virtual bool ValidateParams() = 0; +}; + +class DistributedSamplerObj; +class PKSamplerObj; +class RandomSamplerObj; +class SequentialSamplerObj; +class SubsetRandomSamplerObj; +class WeightedRandomSamplerObj; + +/// Function to create a Distributed Sampler. +/// \notes A Sampler that access a shard of the dataset. +/// \param[in] num_shards - Number of shards to divide the dataset into. +/// \param[in] shard_id - Shard ID of the current shard within num_shards. +/// \param[in] shuffle - If true, the indices are shuffled. +/// \param[in] num_samples - The number of samples to draw (default to all elements). +/// \param[in] seed - The seed in use when shuffle is true. +/// \return Shared pointer to the current Sampler. +std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, + int64_t num_samples = 0, uint32_t seed = 1); + +/// Function to create a PK Sampler. +/// \notes Samples K elements for each P class in the dataset. +/// This will sample all classes. +/// \param[in] num_val - Number of elements to sample for each class. +/// \param[in] shuffle - If true, the class IDs are shuffled. +/// \param[in] num_samples - The number of samples to draw (default to all elements). +/// \return Shared pointer to the current Sampler. +std::shared_ptr PKSampler(int64_t num_val, bool shuffle = false, int64_t num_samples = 0); + +/// Function to create a Random Sampler. +/// \notes Samples the elements randomly. +/// \param[in] replacement - If True, put the sample ID back for the next draw. +/// \param[in] num_samples - The number of samples to draw (default to all elements). +/// \return Shared pointer to the current Sampler. +std::shared_ptr RandomSampler(bool replacement = false, int64_t num_samples = 0); + +/// Function to create a Sequential Sampler. +/// \notes Samples the dataset elements sequentially, same as not having a sampler. +/// \param[in] start_index - Index to start sampling at (dafault to start at first id). +/// \param[in] num_samples - The number of samples to draw (default to all elements). +/// \return Shared pointer to the current Sampler. +std::shared_ptr SequentialSampler(int64_t start_index = 0, int64_t num_samples = 0); + +/// Function to create a Subset Random Sampler. +/// \notes Samples the elements randomly from a sequence of indices. +/// \param[in] indices - A vector sequence of indices. +/// \param[in] num_samples - The number of samples to draw (default to all elements). +/// \return Shared pointer to the current Sampler. +std::shared_ptr SubsetRandomSampler(const std::vector &indices, + int64_t num_samples = 0); + +/// Function to create a Weighted Random Sampler. +/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given +/// weights (probabilities). +/// \param[in] weights - A vector sequence of weights, not necessarily summing up to 1. +/// \param[in] num_samples - The number of samples to draw (default to all elements). +/// \param[in] replacement - If True, put the sample ID back for the next draw. +/// \return Shared pointer to the current Sampler. +std::shared_ptr WeightedRandomSampler(const std::vector &weights, + int64_t num_samples = 0, bool replacement = true); + +/* ####################################### Derived Sampler classes ################################# */ +class DistributedSamplerObj : public SamplerObj { + public: + DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, uint32_t seed); + + ~DistributedSamplerObj() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + int64_t num_shards_; + int64_t shard_id_; + bool shuffle_; + int64_t num_samples_; + uint32_t seed_; +}; + +class PKSamplerObj : public SamplerObj { + public: + PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples); + + ~PKSamplerObj() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + int64_t num_val_; + bool shuffle_; + int64_t num_samples_; +}; + +class RandomSamplerObj : public SamplerObj { + public: + RandomSamplerObj(bool replacement, int64_t num_samples); + + ~RandomSamplerObj() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + bool replacement_; + int64_t num_samples_; +}; + +class SequentialSamplerObj : public SamplerObj { + public: + SequentialSamplerObj(int64_t start_index, int64_t num_samples); + + ~SequentialSamplerObj() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + int64_t start_index_; + int64_t num_samples_; +}; + +class SubsetRandomSamplerObj : public SamplerObj { + public: + SubsetRandomSamplerObj(const std::vector &indices, int64_t num_samples); + + ~SubsetRandomSamplerObj() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + const std::vector &indices_; + int64_t num_samples_; +}; + +class WeightedRandomSamplerObj : public SamplerObj { + public: + explicit WeightedRandomSamplerObj(const std::vector &weights, int64_t num_samples = 0, + bool replacement = true); + + ~WeightedRandomSamplerObj() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + const std::vector &weights_; + int64_t num_samples_; + bool replacement_; +}; +} // namespace api +} // namespace dataset +} // namespace mindspore +#endif // DATASET_API_SAMPLERS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/status.h b/mindspore/ccsrc/minddata/dataset/include/status.h new file mode 120000 index 0000000000..bba92b63ad --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/status.h @@ -0,0 +1 @@ +../util/status.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/tensor.h b/mindspore/ccsrc/minddata/dataset/include/tensor.h new file mode 120000 index 0000000000..34b5e020a9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/tensor.h @@ -0,0 +1 @@ +../core/tensor.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h new file mode 100644 index 0000000000..31531a20af --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -0,0 +1,380 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_API_TRANSFORMS_H_ +#define DATASET_API_TRANSFORMS_H_ + +#include +#include +#include "minddata/dataset/core/constants.h" + +namespace mindspore { +namespace dataset { + +class TensorOp; + +namespace api { +// Abstract class to represent a dataset in the data pipeline. +class TensorOperation : public std::enable_shared_from_this { + public: + /// \brief Constructor + TensorOperation(); + + /// \brief Destructor + ~TensorOperation() = default; + + /// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object. + /// \return shared pointer to the newly created TensorOp. + virtual std::shared_ptr Build() = 0; + + virtual bool ValidateParams() = 0; +}; + +// Transform operations for performing computer vision. +namespace vision { + +class NormalizeOperation; +class DecodeOperation; +class ResizeOperation; +class RandomCropOperation; +class CenterCropOperation; +class UniformAugOperation; +class RandomHorizontalFlipOperation; +class RandomVerticalFlipOperation; +class RandomRotationOperation; +class PadOperation; +class CutOutOperation; +class RandomColorAdjustOperation; + +/// \brief Function to create a Normalize TensorOperation. +/// \notes Normalize the input image with respect to mean and standard deviation. +/// \param[in] mean - a vector of mean values for each channel, w.r.t channel order. +/// \param[in] std - a vector of standard deviations for each channel, w.r.t. channel order. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr Normalize(std::vector mean, std::vector std); + +/// \brief Function to create a Decode TensorOperation. +/// \notes Decode the input image in RGB mode. +/// \param[in] rgb - a boolean of whether to decode in RGB mode or not. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr Decode(bool rgb = true); + +/// \brief Function to create a Resize TensorOperation. +/// \notes Resize the input image to the given size.. +/// \param[in] size - a vector representing the output size of the resized image. +/// If size is a single value, the image will be resized to this value with +/// the same image aspect ratio. If size has 2 values, it should be (height, width). +/// \param[in] interpolation An enum for the mode of interpolation +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr Resize(std::vector size, + InterpolationMode interpolation = InterpolationMode::kLinear); + +/// \brief Function to create a RandomCrop TensorOperation. +/// \notes Crop the input image at a random location. +/// \param[in] size - a vector representing the output size of the cropped image. +/// If size is a single value, a square crop of size (size, size) is returned. +/// If size has 2 values, it should be (height, width). +/// \param[in] padding - a vector with the value of pixels to pad the image. If 4 values are provided, +/// it pads the left, top, right and bottom respectively. +/// \param[in] pad_if_needed - a boolean whether to pad the image if either side is smaller than +/// the given output size. +/// \param[in] fill_value - a vector representing the pixel intensity of the borders, it is used to +/// fill R, G, B channels respectively. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr RandomCrop(std::vector size, std::vector padding = {0, 0, 0, 0}, + bool pad_if_needed = false, + std::vector fill_value = {0, 0, 0}); + +/// \brief Function to create a CenterCrop TensorOperation. +/// \notes Crops the input image at the center to the given size. +/// \param[in] size - a vector representing the output size of the cropped image. +/// If size is a single value, a square crop of size (size, size) is returned. +/// If size has 2 values, it should be (height, width). +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr CenterCrop(std::vector size); + +/// \brief Function to create a UniformAugment TensorOperation. +/// \notes Tensor operation to perform randomly selected augmentation. +/// \param[in] operations - a vector of TensorOperation operations. +/// \param[in] num_ops - integer representing the number of OPs to be selected and applied. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr UniformAugment(std::vector> operations, + int32_t num_ops = 2); + +/// \brief Function to create a RandomHorizontalFlip TensorOperation. +/// \notes Tensor operation to perform random horizontal flip. +/// \param[in] prob - float representing the probability of flip. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr RandomHorizontalFlip(float prob = 0.5); + +/// \brief Function to create a RandomVerticalFlip TensorOperation. +/// \notes Tensor operation to perform random vertical flip. +/// \param[in] prob - float representing the probability of flip. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr RandomVerticalFlip(float prob = 0.5); + +/// \brief Function to create a RandomRotation TensorOp +/// \notes Rotates the image according to parameters +/// \param[in] degrees A float vector size 2, representing the starting and ending degree +/// \param[in] resample An enum for the mode of interpolation +/// \param[in] expand A boolean representing whether the image is expanded after rotation +/// \param[in] center A float vector size 2, representing the x and y center of rotation. +/// \param[in] fill_value A uint8_t vector size 3, representing the rgb value of the fill color +/// \return Shared pointer to the current TensorOp +std::shared_ptr RandomRotation( + std::vector degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false, + std::vector center = {-1, -1}, std::vector fill_value = {0, 0, 0}); + +/// \brief Function to create a Pad TensorOp +/// \notes Pads the image according to padding parameters +/// \param[in] padding A vector representing the number of pixels to pad the image +/// If vector has one value, it pads all sides of the image with that value +/// If vector has two values, it pads left and right with the first and +/// top and bottom with the second value +/// If vector has four values, it pads left, top, right, and bottom with +/// those values respectively +/// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is +/// BorderType.kConstant. If 3 values are provided, +/// it is used to fill R, G, B channels respectively +/// \param[in] padding_mode The method of padding (default=BorderType.kConstant) +/// Can be any of +/// [BorderType.kConstant, BorderType.kEdge, BorderType.kReflect, BorderType.kSymmetric] +/// - BorderType.kConstant, means it fills the border with constant values +/// - BorderType.kEdge, means it pads with the last value on the edge +/// - BorderType.kReflect, means it reflects the values on the edge omitting the last value of edge +/// - BorderType.kSymmetric, means it reflects the values on the edge repeating the last value of edge +/// \return Shared pointer to the current TensorOp +std::shared_ptr Pad(std::vector padding, std::vector fill_value = {0}, + BorderType padding_mode = BorderType::kConstant); + +/// \brief Function to create a CutOut TensorOp +/// \notes Randomly cut (mask) out a given number of square patches from the input image +/// \param[in] length Integer representing the side length of each square patch +/// \param[in] num_patches Integer representing the number of patches to be cut out of an image +/// \return Shared pointer to the current TensorOp +std::shared_ptr CutOut(int32_t length, int32_t num_patches = 1); + +/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image +/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values +/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} +/// \param[in] contrast Contrast adjustment factor. Must be a vector of one or two values +/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} +/// \param[in] saturation Saturation adjustment factor. Must be a vector of one or two values +/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} +/// \param[in] hue Brightness adjustment factor. Must be a vector of one or two values +/// if it's a vector of two values it must be in the form of [min, max] where -0.5 <= min <= max <= 0.5 +/// Default value is {0, 0} +/// \return Shared pointer to the current TensorOp +std::shared_ptr RandomColorAdjust(std::vector brightness = {1.0, 1.0}, + std::vector contrast = {1.0, 1.0}, + std::vector saturation = {1.0, 1.0}, + std::vector hue = {0.0, 0.0}); + +/* ####################################### Derived TensorOperation classes ################################# */ + +class NormalizeOperation : public TensorOperation { + public: + NormalizeOperation(std::vector mean, std::vector std); + + ~NormalizeOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector mean_; + std::vector std_; +}; + +class DecodeOperation : public TensorOperation { + public: + explicit DecodeOperation(bool rgb = true); + + ~DecodeOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + bool rgb_; +}; + +class ResizeOperation : public TensorOperation { + public: + explicit ResizeOperation(std::vector size, + InterpolationMode interpolation_mode = InterpolationMode::kLinear); + + ~ResizeOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector size_; + InterpolationMode interpolation_; +}; + +class RandomCropOperation : public TensorOperation { + public: + RandomCropOperation(std::vector size, std::vector padding = {0, 0, 0, 0}, + bool pad_if_needed = false, std::vector fill_value = {0, 0, 0}); + + ~RandomCropOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector size_; + std::vector padding_; + bool pad_if_needed_; + std::vector fill_value_; +}; + +class CenterCropOperation : public TensorOperation { + public: + explicit CenterCropOperation(std::vector size); + + ~CenterCropOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector size_; +}; + +class UniformAugOperation : public TensorOperation { + public: + explicit UniformAugOperation(std::vector> operations, int32_t num_ops = 2); + + ~UniformAugOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector> operations_; + int32_t num_ops_; +}; + +class RandomHorizontalFlipOperation : public TensorOperation { + public: + explicit RandomHorizontalFlipOperation(float probability = 0.5); + + ~RandomHorizontalFlipOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + float probability_; +}; + +class RandomVerticalFlipOperation : public TensorOperation { + public: + explicit RandomVerticalFlipOperation(float probability = 0.5); + + ~RandomVerticalFlipOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + float probability_; +}; + +class RandomRotationOperation : public TensorOperation { + public: + RandomRotationOperation(std::vector degrees, InterpolationMode interpolation_mode, bool expand, + std::vector center, std::vector fill_value); + + ~RandomRotationOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector degrees_; + InterpolationMode interpolation_mode_; + std::vector center_; + bool expand_; + std::vector fill_value_; +}; + +class PadOperation : public TensorOperation { + public: + PadOperation(std::vector padding, std::vector fill_value = {0}, + BorderType padding_mode = BorderType::kConstant); + + ~PadOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector padding_; + std::vector fill_value_; + BorderType padding_mode_; +}; + +class CutOutOperation : public TensorOperation { + public: + explicit CutOutOperation(int32_t length, int32_t num_patches = 1); + + ~CutOutOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + int32_t length_; + int32_t num_patches_; +}; + +class RandomColorAdjustOperation : public TensorOperation { + public: + RandomColorAdjustOperation(std::vector brightness = {1.0, 1.0}, std::vector contrast = {1.0, 1.0}, + std::vector saturation = {1.0, 1.0}, std::vector hue = {0.0, 0.0}); + + ~RandomColorAdjustOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + std::vector brightness_; + std::vector contrast_; + std::vector saturation_; + std::vector hue_; +}; +} // namespace vision +} // namespace api +} // namespace dataset +} // namespace mindspore +#endif // DATASET_API_TRANSFORMS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h b/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h new file mode 120000 index 0000000000..f2c939bc0b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h @@ -0,0 +1 @@ +../../../../utils/log_adapter.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/include/utils/overload.h b/mindspore/ccsrc/minddata/dataset/include/utils/overload.h new file mode 120000 index 0000000000..7dc313d512 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/utils/overload.h @@ -0,0 +1 @@ +../../../../utils/overload.h \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt new file mode 100644 index 0000000000..8a9096ff23 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt @@ -0,0 +1,14 @@ +add_subdirectory(image) +add_subdirectory(data) +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +if (ENABLE_PYTHON) + add_library(kernels OBJECT + py_func_op.cc + tensor_op.cc) + target_include_directories(kernels PRIVATE ${pybind11_INCLUDE_DIRS}) +else() + add_library(kernels OBJECT + tensor_op.cc) +endif() + diff --git a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/data/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/kernels/data/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.cc new file mode 100644 index 0000000000..0c91b38b2d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/concatenate_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +Status ConcatenateOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + RETURN_IF_NOT_OK(Concatenate(input, output, axis_, prepend_, append_)); + return Status::OK(); +} + +Status ConcatenateOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + + std::vector inputs_copy; + inputs_copy.push_back(inputs[0].Squeeze()); + + CHECK_FAIL_RETURN_UNEXPECTED(inputs.at(0).Rank() == 1, "Only 1D input tensors supported"); + + outputs.clear(); + dsize_t output_shape = 0; + output_shape = output_shape + inputs.at(0).NumOfElements(); + if (prepend_ != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(prepend_->shape().Rank() == 1, "Only 1D prepend tensors supported"); + output_shape = output_shape + prepend_->shape().NumOfElements(); + } + if (append_ != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(append_->shape().Rank() == 1, "Only 1D append tensors supported"); + output_shape = output_shape + append_->shape().NumOfElements(); + } + + outputs.emplace_back(std::vector{output_shape}); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h new file mode 100644 index 0000000000..46cc613049 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_ +#define DATASET_KERNELS_DATA_CONCATENATE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +class ConcatenateOp : public TensorOp { + public: + /// Constructor to ConcatenateOp. + /// @param int8_t axis - axis to concatenate tensors along. + /// @param std::shared_ptr prepend - prepend tensor. + /// @param std::shared_ptr append -append tensor. + explicit ConcatenateOp(int8_t axis, std::shared_ptr prepend, std::shared_ptr append) + : axis_(axis), prepend_(prepend), append_(append) {} + + ~ConcatenateOp() override = default; + + /// Print method to see which tensor Op this is. + /// @param std::ostream &out - output stream object. + void Print(std::ostream &out) const override { out << "ConcatenateOp"; } + + /// Compute method allowing multiple tensors as inputs + /// @param TensorRow &input - input tensor rows + /// @param TensorRow *output - output tensor rows + Status Compute(const TensorRow &input, TensorRow *output) override; + + /// Compute tensor output shape + /// @param std::vector &inputs - vector of input tensor shapes + /// @param std::vector &inputs, std::vector &outputs) override; + + /// Number of inputs the tensor operation accepts + uint32_t NumInput() override { return 0; } + + std::string Name() const override { return kConcatenateOp; } + + private: + int8_t axis_; + std::shared_ptr prepend_; + std::shared_ptr append_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CONCATENATE_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc new file mode 100644 index 0000000000..b1d51a6c08 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc @@ -0,0 +1,656 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/data/data_utils.h" + +#include +#include +#include +#include + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/core/pybind_support.h" +#endif +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status OneHotEncodingUnsigned(const std::shared_ptr &input, std::shared_ptr *output, + dsize_t num_classes, int64_t index) { + uint64_t class_idx; + if (input->Rank() == 0) { + RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {})); + } else { + RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {index})); + } + if (class_idx >= static_cast(num_classes)) { + RETURN_STATUS_UNEXPECTED("One_hot index values are not in range"); + } + if (input->type() == DataType::DE_UINT64) { + RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); + } else if (input->type() == DataType::DE_UINT32) { + RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); + } else if (input->type() == DataType::DE_UINT16) { + RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); + } else if (input->type() == DataType::DE_UINT8) { + RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); + } else { + RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input."); + } + return Status::OK(); +} + +Status OneHotEncodingSigned(const std::shared_ptr &input, std::shared_ptr *output, dsize_t num_classes, + int64_t index) { + int64_t class_idx; + if (input->Rank() == 0) { + RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {})); + } else { + RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {index})); + } + if (class_idx >= static_cast(num_classes)) { + RETURN_STATUS_UNEXPECTED("One_hot index values are not in range"); + } + if (input->type() == DataType::DE_INT64) { + RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); + } else if (input->type() == DataType::DE_INT32) { + RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); + } else if (input->type() == DataType::DE_INT16) { + RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); + } else if (input->type() == DataType::DE_INT8) { + RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); + } else { + RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input."); + } + return Status::OK(); +} + +Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *output, dsize_t num_classes) { + input->Squeeze(); + + if (input->Rank() > 1) { // We expect the input to be int he first dimension + RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors."); + } + if (!input->type().IsInt()) { + RETURN_STATUS_UNEXPECTED("One hot does not support input of this type."); + } + try { + dsize_t num_elements = 1; + if (input->Rank() == 1) num_elements = input->shape()[0]; + TensorShape out_shape({num_elements, num_classes}); + std::shared_ptr out; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type())); + RETURN_IF_NOT_OK(out->Zero()); + for (dsize_t i = 0; i < num_elements; ++i) { + if (input->type().IsUnsignedInt()) { + RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i)); + } else { + RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i)); + } + } + out->Squeeze(); + *output = out; + return Status::OK(); + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp"); + } +} + +Status Fill(const std::shared_ptr input, std::shared_ptr *output, std::shared_ptr fill_value) { + const DataType &fill_type = fill_value->type(); + const DataType &input_type = input->type(); + const TensorShape &input_shape = input->shape(); + + CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)), + "Types do not match"); + + CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar"); + + std::shared_ptr out, fill_output; + + if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) { + auto op = std::make_unique(input_type); + RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output)); + } else { + fill_output = fill_value; + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type)); + + switch (input_type.value()) { + case DataType::DE_BOOL: { + bool value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_INT8: { + int8_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_UINT8: { + uint8_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_UINT16: { + uint16_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_INT16: { + int16_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_UINT32: { + uint32_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_INT32: { + int32_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_UINT64: { + uint64_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_INT64: { + int64_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_FLOAT16: { + int64_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_FLOAT32: { + float value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_FLOAT64: { + double value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_STRING: { + std::vector strings; + std::string_view fill_string_view; + RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {})); + std::string fill_string = std::string(fill_string_view); + for (int i = 0; i < input_shape.NumOfElements(); i++) { + strings.emplace_back(fill_string); + } + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape)); + break; + } + case DataType::DE_UNKNOWN: { + RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type."); + break; + } + } + + *output = out; + return Status::OK(); +} +template +void Cast(const std::shared_ptr &input, std::shared_ptr *output) { + auto in_itr = input->begin(); + auto out_itr = (*output)->begin(); + auto out_end = (*output)->end(); + + for (; out_itr != out_end; static_cast(in_itr++), static_cast(out_itr++)) + *out_itr = static_cast(*in_itr); +} + +template +void CastFrom(const std::shared_ptr &input, std::shared_ptr *output) { + switch ((*output)->type().value()) { + case DataType::DE_BOOL: + Cast(input, output); + break; + case DataType::DE_INT8: + Cast(input, output); + break; + case DataType::DE_UINT8: + Cast(input, output); + break; + case DataType::DE_INT16: + Cast(input, output); + break; + case DataType::DE_UINT16: + Cast(input, output); + break; + case DataType::DE_INT32: + Cast(input, output); + break; + case DataType::DE_UINT32: + Cast(input, output); + break; + case DataType::DE_INT64: + Cast(input, output); + break; + case DataType::DE_UINT64: + Cast(input, output); + break; + case DataType::DE_FLOAT16: + Cast(input, output); + break; + case DataType::DE_FLOAT32: + Cast(input, output); + break; + case DataType::DE_FLOAT64: + Cast(input, output); + break; + case DataType::DE_UNKNOWN: + MS_LOG(ERROR) << "Unknown data type."; + break; + } +} + +// Type cast operator +Status TypeCast(const std::shared_ptr &input, std::shared_ptr *output, const DataType &data_type) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type)); + + RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); + switch (input->type().value()) { + case DataType::DE_BOOL: + CastFrom(input, output); + break; + case DataType::DE_INT8: + CastFrom(input, output); + break; + case DataType::DE_UINT8: + CastFrom(input, output); + break; + case DataType::DE_INT16: + CastFrom(input, output); + break; + case DataType::DE_UINT16: + CastFrom(input, output); + break; + case DataType::DE_INT32: + CastFrom(input, output); + break; + case DataType::DE_UINT32: + CastFrom(input, output); + break; + case DataType::DE_INT64: + CastFrom(input, output); + break; + case DataType::DE_UINT64: + CastFrom(input, output); + break; + case DataType::DE_FLOAT16: + CastFrom(input, output); + break; + case DataType::DE_FLOAT32: + CastFrom(input, output); + break; + case DataType::DE_FLOAT64: + CastFrom(input, output); + break; + case DataType::DE_UNKNOWN: + // sanity check, unreachable code. + RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type."); + } + return Status::OK(); +} + +Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { + // initiate new tensor for type cast + DataType new_type = DataType("float16"); + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type)); + RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); + + auto in_itr = input->begin(); + auto out_itr = (*output)->begin(); + auto out_end = (*output)->end(); + + for (; out_itr != out_end; in_itr++, out_itr++) { + float element = *in_itr; + float float16_max = static_cast(std::numeric_limits::max()); + float float16_min = static_cast(std::numeric_limits::lowest()); + if (element > float16_max || element < float16_min) { + RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" + + std::to_string(float16_max) + ", " + std::to_string(float16_min) + "]."); + } + + *out_itr = Eigen::half(*in_itr); + } + + return Status::OK(); +} + +Status PadEnd(const std::shared_ptr &src, std::shared_ptr *dst, const std::vector &pad_shape, + const std::shared_ptr &pad_val) { + if (pad_val == nullptr) { + if (src->type().IsNumeric()) { + return PadEndNumeric(src, dst, pad_shape, 0); + } else { + return PadEndString(src, dst, pad_shape, ""); + } + } + CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(), + "Source and pad_value tensors are not of the same type."); + if (pad_val->type().IsNumeric()) { + std::shared_ptr float_pad_value; + RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32))); + float val = 0; + RETURN_IF_NOT_OK(float_pad_value->GetItemAt(&val, {})); + return PadEndNumeric(src, dst, pad_shape, val); + } + std::string_view val; + RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {})); + return PadEndString(src, dst, pad_shape, std::string(val)); +} + +Status PadEndNumeric(const std::shared_ptr &src, std::shared_ptr *dst, + const std::vector &pad_shape, float pad_val) { + CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr"); + if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) { + (*dst) = src; // if no padding, copy the pointer + } else { + CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); + RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type())); + auto tensor_type = src->type().value(); + if (pad_val == 0) { // if pad with zero, don't care what type it is + RETURN_IF_NOT_OK((*dst)->Zero()); + } else if (tensor_type == DataType::DE_INT8) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_BOOL) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT8) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT16) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT16) { + RETURN_IF_NOT_OK((*dst)->Fill(static_cast(pad_val))); + } else if (tensor_type == DataType::DE_UINT16) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else { + RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type"); + } + std::vector cur_ind(src->Rank(), 0); + RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0)); + } + return Status::OK(); +} +Status PadEndNumericHelper(const std::shared_ptr &src, std::shared_ptr dst, + std::vector cur_ind, size_t cur_dim) { + if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data + dst->CopyLastDimAt(src, cur_ind); + } else { // not the last dimension, keep doing recursion + dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); + for (dsize_t i = 0; i < min_ind; i++) { + cur_ind[cur_dim] = i; + RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1)); + } + } + return Status::OK(); +} + +Status PadEndString(const std::shared_ptr &src, std::shared_ptr *dst, + const std::vector &pad_shape, const std::string &pad_val) { + CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr"); + if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) { + (*dst) = src; // if no padding, copy the pointer + } else { + CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); + std::vector cur_ind(src->Rank(), 0); + std::vector strings; + RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape))); + } + return Status::OK(); +} + +Status PadEndStringHelper(const std::shared_ptr &src, std::vector *dst, + const TensorShape &dst_shape, std::vector cur_ind, size_t cur_dim, + const std::string &pad_value) { + if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data + dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]); + for (dsize_t i = 0; i < min_ind; i++) { + cur_ind[cur_dim] = i; + std::string_view item; + RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind)); + dst->emplace_back(item); + } + for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) { + dst->emplace_back(pad_value); + } + + } else { // not the last dimension, keep doing recursion + dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]); + for (dsize_t i = 0; i < min_ind; i++) { + cur_ind[cur_dim] = i; + RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value)); + } + dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim]; + for (dsize_t i = 0; i < count; i++) { + dst->emplace_back(pad_value); + } + } + return Status::OK(); +} + +template +Status MaskHelper(const std::shared_ptr &input, const std::shared_ptr &output, + const std::shared_ptr &value_tensor, RelationalOp op) { + T value; + RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {})); + auto in_itr = input->begin(); + auto out_itr = output->begin(); + for (; in_itr != input->end(); in_itr++, out_itr++) { + switch (op) { + case RelationalOp::kEqual: + *out_itr = (*in_itr == value); + break; + case RelationalOp::kNotEqual: + *out_itr = (*in_itr != value); + break; + case RelationalOp::kGreater: + *out_itr = (*in_itr > value); + break; + case RelationalOp::kGreaterEqual: + *out_itr = (*in_itr >= value); + break; + case RelationalOp::kLess: + *out_itr = (*in_itr < value); + break; + case RelationalOp::kLessEqual: + *out_itr = (*in_itr <= value); + break; + default: + RETURN_STATUS_UNEXPECTED("Unknown relational operator."); + } + } + return Status::OK(); +} + +Status Mask(const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &value, + RelationalOp op) { + CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(), + "Cannot convert constant value to the type of the input tensor."); + CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar"); + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL))); + + std::unique_ptr value_cast_op(new TypeCastOp(input->type())); + std::shared_ptr casted_value; + if (input->type().IsNumeric()) { + RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value)); + } else { + casted_value = value; + } + + switch (input->type().value()) { + case DataType::DE_BOOL: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_INT8: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UINT8: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UINT16: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_INT16: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UINT32: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_INT32: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UINT64: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_INT64: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_FLOAT16: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_FLOAT32: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_FLOAT64: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_STRING: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UNKNOWN: + RETURN_STATUS_UNEXPECTED("Unsupported input type."); + break; + } + return Status::OK(); +} + +Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr prepend, + std::shared_ptr append) { + CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported"); + CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported"); + + axis = Tensor::HandleNeg(axis, input[0]->shape().Rank()); + CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported"); + + std::shared_ptr out; + if (prepend != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported"); + RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0])); + } else { + out = input[0]; + } + for (dsize_t i = 1; i < input.size(); i++) { + std::shared_ptr out_t; + CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported"); + RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i])); + out = out_t; + } + std::shared_ptr out_t; + if (append != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported"); + RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append)); + } else { + out_t = out; + } + output->push_back(out_t); + + return Status::OK(); +} + +Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr *output, int8_t axis, + std::shared_ptr append) { + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match"); + + TensorShape t({}); + + for (dsize_t i = 0; i < input->shape().Rank(); i++) { + if (i != axis) { + t = t.AppendDim(input->shape()[i]); + } else { + dsize_t new_shape = input->shape()[i] + append->shape()[i]; + + t = t.AppendDim(new_shape); + } + } + std::shared_ptr out; + + if (input->type().IsNumeric()) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type())); + + RETURN_IF_NOT_OK(out->Concatenate({0}, input)); + RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append)); + *output = out; + } else { + std::vector strings; + + auto itr = input->begin(); + for (; itr != input->end(); itr++) { + strings.emplace_back(*itr); + } + itr = append->begin(); + for (; itr != append->end(); itr++) { + strings.emplace_back(*itr); + } + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t)); + + *output = out; + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h new file mode 100644 index 0000000000..141545a583 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h @@ -0,0 +1,163 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_DATA_UTILS_H_ +#define DATASET_KERNELS_DATA_DATA_UTILS_H_ + +#include +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" + +namespace mindspore { +namespace dataset { +// Returns Onehot encoding of the input tensor. +// Example: if input=2 and numClasses=3, the output is [0 0 1]. +// @param input: Tensor has type DE_UINT64, the non-one hot values are stored +// along the first dimensions or rows.. +// If the rank of input is not 1 or the type is not DE_UINT64, +// then it will fail. +// @param output: Tensor. The shape of the output tensor is +// and the type is same as input. +// @param num_classes: Number of classes to. +Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *output, dsize_t num_classes); + +Status OneHotEncodingUnsigned(const std::shared_ptr &input, std::shared_ptr *output, + dsize_t num_classes, int64_t index); + +Status OneHotEncodingSigned(const std::shared_ptr &input, std::shared_ptr *output, dsize_t num_classes, + int64_t index); + +// Returns a tensor of shape input filled with the passed fill_value +// @param input Tensor +// @param output Tensor. The shape and type of the output tensor is same as input +// @param fill_value Tensor. A scalar tensor used to fill the output tensor + +Status Fill(const std::shared_ptr input, std::shared_ptr *output, std::shared_ptr fill_value); + +// Returns a type changed input tensor. +// Example: if input tensor is float64, the output will the specified dataType. See DataTypes.cpp +// @param input Tensor +// @param output Tensor. The shape of the output tensor is same as input with the type changed. +// @param data_type: type of data to cast data to +// @note: this operation will do a memcpy and if the value is truncated then precision will be lost + +template +void CastFrom(const std::shared_ptr &input, std::shared_ptr *output); + +template +void Cast(const std::shared_ptr &input, std::shared_ptr *output); + +Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output); + +Status TypeCast(const std::shared_ptr &input, std::shared_ptr *output, const DataType &data_type); + +// Pad input tensor according pad_shape, need to have same rank. +// Based on the type of the input tensor, PadEndNumeric/String will be called. +// @param std::shared_ptr src - tensor to pad from +// @param std::shared_ptr *dst - return tensor padded +// @param std::vector pad_shape - shape to pad to +// @param std::shared_ptr pad_val - value to pad with in Tensor format, +// @return - The error code return +Status PadEnd(const std::shared_ptr &src, std::shared_ptr *dst, const std::vector &pad_shape, + const std::shared_ptr &pad_val); + +// Pad input numeric tensor according pad_shape, need to have same rank. +// @param std::shared_ptr src - tensor to pad from +// @param std::shared_ptr *dst - return tensor padded +// @param std::vector pad_shape - shape to pad to +// @param float pad_val - value to pad with +// @return - The error code return +Status PadEndNumeric(const std::shared_ptr &src, std::shared_ptr *dst, + const std::vector &pad_shape, float pad_val); + +// recursive helper function for padding numric tensors. This function could be very expensive if called on a +// multi-dimensional tensor it is only meant to be called by PadEndNumeric. +// @tparam T - type of tensor and fill value +// @param std::shared_ptr src - Tensor to pad from +// @param std::shared_ptr* dst - Tensor to pad to, return value +// @param std::vector cur_ind - recursion helper +// @param T pad_val - value to pad tensor with +// @param size_t cur_dim - recursion helper +// @return Status - The error code return +Status PadEndNumericHelper(const std::shared_ptr &src, std::shared_ptr dst, + std::vector cur_ind, size_t cur_dim = 0); + +// Pad input string tensor according pad_shape, need to have same rank. +// @param std::shared_ptr src - tensor to pad from +// @param std::shared_ptr *dst - return tensor padded +// @param std::vector pad_shape - shape to pad to +// @param std::string pad_val - value to pad with +// @return - The error code return +Status PadEndString(const std::shared_ptr &src, std::shared_ptr *dst, + const std::vector &pad_shape, const std::string &pad_val); + +// recursive helper function for padding string tensors. This function could be very expensive if called on a +// multi-dimensional tensor it is only meant to be called by PadEndString. +// @tparam T - type of tensor and fill value +// @param std::shared_ptr src - Tensor to pad from +// @param std::shared_ptr* dst - Tensor to pad to, return value +// @param std::vector cur_ind - recursion helperas text +// @param std::string pad_val - value to pad tensor with +// @param size_t cur_dim - recursion helper +// @return Status - The error code return +Status PadEndStringHelper(const std::shared_ptr &src, std::vector *dst, + const TensorShape &dst_shape, std::vector cur_ind, size_t cur_dim, + const std::string &pad_value); + +enum class RelationalOp { + kEqual = 0, // == + kNotEqual, // != + kLess, // < + kLessEqual, // <= + kGreater, // > + kGreaterEqual, // >= +}; + +/// Helper method that masks the input tensor +/// @tparam T type of the tensor +/// @param input[in] input tensor +/// @param output[out] output tensor +/// @param value_tensor[in] scalar tensor value to compared with +/// @param op[in] RelationalOp enum +/// @return Status ok/error +template +Status MaskHelper(const std::shared_ptr &input, const std::shared_ptr &output, + const std::shared_ptr &value_tensor, RelationalOp op); + +/// Mask the input tensor +/// @param input[in] input tensor +/// @param output[out] output tensor +/// @param value[in] scalar tensor value to compared with +/// @param op[in] RelationalOp enum +/// @return Status ok/error +Status Mask(const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &value, + RelationalOp op); + +Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr prepend, + std::shared_ptr append); + +// helper for concat, always append to the input, and pass that to the output +Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr *output, int8_t axis, + std::shared_ptr append); + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_DATA_DATA_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc new file mode 100644 index 0000000000..57a424704f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/data/duplicate_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +Status DuplicateOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + std::shared_ptr out; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, input[0])); + output->push_back(input[0]); + output->push_back(out); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h new file mode 100644 index 0000000000..60b2d8c33b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_DUPLICATE_OP_H_ +#define DATASET_KERNELS_DATA_DUPLICATE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +class DuplicateOp : public TensorOp { + public: + DuplicateOp() = default; + + ~DuplicateOp() override = default; + + void Print(std::ostream &out) const override { out << "DuplicateOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + uint32_t NumOutput() override { return 2; } + + std::string Name() const override { return kDuplicateOp; } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DUPLICATE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.cc new file mode 100644 index 0000000000..f8dc746dff --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/fill_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status FillOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + Status s = Fill(input, output, fill_value_); + return s; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h new file mode 100644 index 0000000000..af0d9e7941 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_KERNELS_DATA_FILL_OP_H_ +#define DATASET_KERNELS_DATA_FILL_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class FillOp : public TensorOp { + public: + explicit FillOp(std::shared_ptr value) : fill_value_(value) {} + + ~FillOp() override = default; + void Print(std::ostream &out) const override { out << "FillOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kFillOp; } + + private: + std::shared_ptr fill_value_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_FILL_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.cc new file mode 100644 index 0000000000..2dbe501a47 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/data/mask_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +Status MaskOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + std::shared_ptr temp_output; + CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), "Cannot generate a string mask. Type should be numeric."); + + RETURN_IF_NOT_OK(Mask(input, &temp_output, value_, op_)); + + // cast the output to the the required type. Skip casting if type_ is bool. + if (type_ != DataType::DE_BOOL) { + RETURN_IF_NOT_OK(cast_->Compute(temp_output, output)); + } else { + *output = std::move(temp_output); + } + + return Status::OK(); +} + +Status MaskOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + outputs[0] = type_; + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h new file mode 100644 index 0000000000..e6ac8c3964 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_MASK_OP_H_ +#define DATASET_KERNELS_DATA_MASK_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { + +class MaskOp : public TensorOp { + public: + MaskOp(RelationalOp op, std::shared_ptr value, DataType type = DataType(DataType::DE_BOOL)) + : op_(op), value_(std::move(value)), type_(type), cast_(new TypeCastOp(type)) {} + + ~MaskOp() override = default; + + void Print(std::ostream &out) const override { out << "MaskOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kMaskOp; } + + private: + RelationalOp op_; + std::shared_ptr value_; + DataType type_; + std::unique_ptr cast_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_MASK_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.cc new file mode 100644 index 0000000000..e2b7b74a96 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/one_hot_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status OneHotOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + Status s = OneHotEncoding(input, output, num_classes_); + return s; +} + +Status OneHotOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + std::vector inputs_copy; + inputs_copy.push_back(inputs[0].Squeeze()); + if (inputs_copy[0].Rank() == 0) outputs.emplace_back(std::vector{num_classes_}); + if (inputs_copy[0].Rank() == 1) outputs.emplace_back(std::vector{inputs_copy[0][0], num_classes_}); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h new file mode 100644 index 0000000000..06a4823573 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_ONE_HOT_OP_H_ +#define DATASET_KERNELS_DATA_ONE_HOT_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class OneHotOp : public TensorOp { + public: + explicit OneHotOp(int num_classes) : num_classes_(num_classes) {} + + ~OneHotOp() override = default; + + void Print(std::ostream &out) const override { out << "OneHotOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kOneHotOp; } + + private: + int num_classes_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.cc new file mode 100644 index 0000000000..7b83137d88 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/pad_end_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status PadEndOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + Status s = PadEnd(input, output, output_shape_.AsVector(), pad_val_); + return s; +} + +Status PadEndOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + for (auto s : inputs) { + outputs.emplace_back(TensorShape(output_shape_.AsVector())); + } + CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "Input has a wrong shape"); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h new file mode 100644 index 0000000000..c28f7250e0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_PAD_END_OP_H_ +#define DATASET_KERNELS_DATA_PAD_END_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class PadEndOp : public TensorOp { + public: + explicit PadEndOp(const TensorShape &pad_shape, const std::shared_ptr &pad_value) + : output_shape_(pad_shape), pad_val_(pad_value) {} + + ~PadEndOp() override = default; + + void Print(std::ostream &out) const override { out << "PadEndOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kPadEndOp; } + + private: + TensorShape output_shape_; + std::shared_ptr pad_val_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_PAD_END_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.cc new file mode 100644 index 0000000000..66f48d5c2b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/slice_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status SliceOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SliceOp supports 1D Tensors only for now."); + + // if `all` flag is true, output is just the input. + if (all_) { + *output = input; + return Status::OK(); + } + + // if slice object was provided, indices should be empty. Generate indices from the slice object. + if (slice_.valid() && indices_.empty()) { + dsize_t len = input->shape()[0]; + std::vector indices = slice_.Indices(len); + return input->Slice(output, indices); + } + + // if indices are not empty, slices should be invalid, use indices_ to slice + if (!indices_.empty() && !slice_.valid()) { + return input->Slice(output, indices_); + } + RETURN_STATUS_UNEXPECTED("The indexing parameters are invalid"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.h new file mode 100644 index 0000000000..1cf99830c9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_SLICE_OP_H_ +#define DATASET_KERNELS_DATA_SLICE_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class Slice { + public: + Slice() : start_(0), stop_(0), step_(0) {} + Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {} + Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {} + explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {} + + ~Slice() = default; + + std::vector Indices(dsize_t length) { + std::vector indices; + dsize_t index = std::min(Tensor::HandleNeg(start_, length), length); + dsize_t end_index = std::min(Tensor::HandleNeg(stop_, length), length); + if (step_ > 0) { + for (; index < end_index; index += step_) { + indices.push_back(index); + } + } else { + for (; index > end_index; index += step_) { + indices.push_back(index); + } + } + return indices; + } + + bool valid() { return !(start_ == 0 && stop_ == 0 && step_ == 0); } + + dsize_t start_; + dsize_t stop_; + dsize_t step_; +}; + +class SliceOp : public TensorOp { + public: + explicit SliceOp(std::vector indices) : indices_(std::move(indices)) {} + explicit SliceOp(Slice slice) : slice_(slice) {} + explicit SliceOp(bool all) : all_(all) {} + + ~SliceOp() override = default; + + void Print(std::ostream &out) const override { out << "SliceOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kSliceOp; } + + private: + // only on of the following will be valid + // given indices to slice the Tensor. Empty vector if invalid. + std::vector indices_; + // Slice object. All start, stop and step are 0 if invalid. + Slice slice_; + // Flag to read all indcies in the dim. + bool all_ = false; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_SLICE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.cc new file mode 100644 index 0000000000..c52162b1aa --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/to_float16_op.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status ToFloat16Op::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + return ToFloat16(input, output); +} +Status ToFloat16Op::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + outputs[0] = DataType(DataType::DE_FLOAT16); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h new file mode 100644 index 0000000000..91f660ca9c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDDATA_TOFLOAT16OP_H +#define MINDDATA_TOFLOAT16OP_H + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class ToFloat16Op : public TensorOp { + public: + ToFloat16Op() = default; + + ~ToFloat16Op() override = default; + + // Overrides the base class compute function + // Calls the ToFloat16 function in ImageUtils, this function takes an input tensor + // and transforms its data to float16, the output memory is manipulated to contain the result + // @return Status - The error code return + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + void Print(std::ostream &out) const override { out << "ToFloat16Op"; } + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kToFloat16Op; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDDATA_TOFLOAT16OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc new file mode 100644 index 0000000000..5a58745293 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/type_cast_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +TypeCastOp::TypeCastOp(const DataType &new_type) : type_(new_type) {} + +TypeCastOp::TypeCastOp(const std::string &data_type) { type_ = DataType(data_type); } + +Status TypeCastOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + return TypeCast(input, output, type_); +} +Status TypeCastOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + outputs[0] = type_; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h new file mode 100644 index 0000000000..b82bc32342 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ +#define DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class TypeCastOp : public TensorOp { + public: + // Constructor for TypecastOp + // @param data_type datatype to cast to + explicit TypeCastOp(const DataType &data_type); + + // Constructor for TypecastOp + // @param data_type datatype to cast to + explicit TypeCastOp(const std::string &data_type); + + ~TypeCastOp() override = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + void Print(std::ostream &out) const override { out << "TypeCastOp"; } + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kTypeCastOp; } + + private: + DataType type_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt new file mode 100644 index 0000000000..c0c575de9a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -0,0 +1,30 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_library(kernels-image OBJECT + center_crop_op.cc + cut_out_op.cc + decode_op.cc + hwc_to_chw_op.cc + image_utils.cc + normalize_op.cc + pad_op.cc + random_color_adjust_op.cc + random_crop_decode_resize_op.cc + random_crop_and_resize_with_bbox_op.cc + random_crop_and_resize_op.cc + random_crop_op.cc + random_crop_with_bbox_op.cc + random_horizontal_flip_op.cc + random_horizontal_flip_with_bbox_op.cc + bounding_box_augment_op.cc + random_resize_op.cc + random_rotation_op.cc + random_vertical_flip_op.cc + random_vertical_flip_with_bbox_op.cc + rescale_op.cc + resize_bilinear_op.cc + resize_op.cc + uniform_aug_op.cc + resize_with_bbox_op.cc + random_resize_with_bbox_op.cc + ) diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.cc new file mode 100644 index 0000000000..618ed4d356 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/kernels/image/bounding_box_augment_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/core/cv_tensor.h" + +namespace mindspore { +namespace dataset { +const float BoundingBoxAugmentOp::kDefRatio = 0.3; + +BoundingBoxAugmentOp::BoundingBoxAugmentOp(std::shared_ptr transform, float ratio) + : ratio_(ratio), uniform_(0, 1), transform_(std::move(transform)) { + rnd_.seed(GetSeed()); +} + +Status BoundingBoxAugmentOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + BOUNDING_BOX_CHECK(input); // check if bounding boxes are valid + uint32_t num_of_boxes = input[1]->shape()[0]; + std::shared_ptr crop_out; + std::shared_ptr res_out; + std::shared_ptr input_restore = CVTensor::AsCVTensor(input[0]); + for (uint32_t i = 0; i < num_of_boxes; i++) { + // using a uniform distribution to ensure op happens with probability ratio_ + if (uniform_(rnd_) < ratio_) { + float min_x = 0; + float min_y = 0; + float b_w = 0; + float b_h = 0; + // get the required items + RETURN_IF_NOT_OK(input[1]->GetItemAt(&min_x, {i, 0})); + RETURN_IF_NOT_OK(input[1]->GetItemAt(&min_y, {i, 1})); + RETURN_IF_NOT_OK(input[1]->GetItemAt(&b_w, {i, 2})); + RETURN_IF_NOT_OK(input[1]->GetItemAt(&b_h, {i, 3})); + RETURN_IF_NOT_OK(Crop(input_restore, &crop_out, static_cast(min_x), static_cast(min_y), + static_cast(b_w), static_cast(b_h))); + // transform the cropped bbox region + RETURN_IF_NOT_OK(transform_->Compute(crop_out, &res_out)); + // place the transformed region back in the restored input + std::shared_ptr res_img = CVTensor::AsCVTensor(res_out); + // check if transformed crop is out of bounds of the box + if (res_img->mat().cols > b_w || res_img->mat().rows > b_h || res_img->mat().cols < b_w || + res_img->mat().rows < b_h) { + // if so, resize to fit in the box + std::shared_ptr resize_op = + std::make_shared(static_cast(b_h), static_cast(b_w)); + RETURN_IF_NOT_OK(resize_op->Compute(std::static_pointer_cast(res_img), &res_out)); + res_img = CVTensor::AsCVTensor(res_out); + } + res_img->mat().copyTo(input_restore->mat()(cv::Rect(min_x, min_y, res_img->mat().cols, res_img->mat().rows))); + } + } + (*output).push_back(std::move(std::static_pointer_cast(input_restore))); + (*output).push_back(input[1]); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h new file mode 100644 index 0000000000..8e30c5738d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ +#define DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +class BoundingBoxAugmentOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefRatio; + + // Constructor for BoundingBoxAugmentOp + // @param std::shared_ptr transform transform: C++ opration to apply on select bounding boxes + // @param float ratio: ratio of bounding boxes to have the transform applied on + BoundingBoxAugmentOp(std::shared_ptr transform, float ratio); + + ~BoundingBoxAugmentOp() override = default; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const BoundingBoxAugmentOp &so) { + so.Print(out); + return out; + } + + void Print(std::ostream &out) const override { out << "BoundingBoxAugmentOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kBoundingBoxAugmentOp; } + + private: + float ratio_; + std::mt19937 rnd_; + std::uniform_real_distribution uniform_; + std::shared_ptr transform_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc new file mode 100644 index 0000000000..35079b05cd --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include +#include "common/utils.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t CenterCropOp::kDefWidth = 0; + +Status CenterCropOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + std::string err_msg; + dsize_t rank = input->shape().Rank(); + err_msg += (rank < 2 || rank > 3) ? "Rank received::" + std::to_string(rank) + " Expected: 2 or 3 \t" : ""; + err_msg += (crop_het_ <= 0 || crop_wid_ <= 0) ? "crop size needs to be positive integers\t" : ""; + + if (err_msg.length() != 0) RETURN_STATUS_UNEXPECTED(common::SafeCStr(err_msg)); + + int32_t top = crop_het_ - input->shape()[0]; // number of pixels to pad (top and bottom) + int32_t left = crop_wid_ - input->shape()[1]; + std::shared_ptr pad_image; + if (top > 0 && left > 0) { // padding only + return Pad(input, output, top / 2 + top % 2, top / 2, left / 2 + left % 2, left / 2, BorderType::kConstant); + } else if (top > 0) { + RETURN_IF_NOT_OK(Pad(input, &pad_image, top / 2 + top % 2, top / 2, 0, 0, BorderType::kConstant)); + return Crop(pad_image, output, (static_cast(pad_image->shape()[1]) - crop_wid_) / 2, + (static_cast(pad_image->shape()[0]) - crop_het_) / 2, crop_wid_, crop_het_); + } else if (left > 0) { + RETURN_IF_NOT_OK(Pad(input, &pad_image, 0, 0, left / 2 + left % 2, left / 2, BorderType::kConstant)); + return Crop(pad_image, output, (static_cast(pad_image->shape()[1]) - crop_wid_) / 2, + (static_cast(pad_image->shape()[0]) - crop_het_) / 2, crop_wid_, crop_het_); + } + return Crop(input, output, (input->shape()[1] - crop_wid_) / 2, (input->shape()[0] - crop_het_) / 2, crop_wid_, + crop_het_); +} + +void CenterCropOp::Print(std::ostream &out) const { + out << "CenterCropOp: " + << "cropWidth: " << crop_wid_ << "cropHeight: " << crop_het_ << "\n"; +} +Status CenterCropOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out = TensorShape{crop_het_, crop_wid_}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h new file mode 100644 index 0000000000..1f8cbcf230 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ +#define DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class CenterCropOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefWidth; + + explicit CenterCropOp(int32_t het, int32_t wid = kDefWidth) : crop_het_(het), crop_wid_(wid == 0 ? het : wid) {} + + ~CenterCropOp() override = default; + + void Print(std::ostream &out) const override; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kCenterCropOp; } + + private: + int32_t crop_het_; + int32_t crop_wid_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.cc new file mode 100644 index 0000000000..578138d427 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/kernels/image/cut_out_op.h" + +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const bool CutOutOp::kDefRandomColor = false; +const uint8_t CutOutOp::kDefFillR = 0; +const uint8_t CutOutOp::kDefFillG = 0; +const uint8_t CutOutOp::kDefFillB = 0; + +// constructor +CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r, + uint8_t fill_g, uint8_t fill_b) + : rnd_(GetSeed()), + box_height_(box_height), + box_width_(box_width), + num_patches_(num_patches), + random_color_(random_color), + fill_r_(fill_r), + fill_g_(fill_g), + fill_b_(fill_b) {} + +// main function call for cut out +Status CutOutOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + std::shared_ptr inputCV = CVTensor::AsCVTensor(input); + // cut out will clip the erasing area if the box is near the edge of the image and the boxes are black + RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_, + fill_g_, fill_b_)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h new file mode 100644 index 0000000000..263cbdb27c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ +#define DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class CutOutOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const bool kDefRandomColor; + static const uint8_t kDefFillR; + static const uint8_t kDefFillG; + static const uint8_t kDefFillB; + + // Constructor for CutOutOp + // @param box_height box height + // @param box_width box_width + // @param num_patches how many patches to erase from image + // @param random_color boolean value to indicate fill patch with random color + // @param fill_r R value for the color to fill patch with + // @param fill_g G value for the color to fill patch with + // @param fill_b B value for the color to fill patch with + // @note maybe using unsigned long int isn't the best here according to our coding rules + CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color = kDefRandomColor, + uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); + + ~CutOutOp() override = default; + + void Print(std::ostream &out) const override { + out << "CutOut:: box_height: " << box_height_ << " box_width: " << box_width_ << " num_patches: " << num_patches_; + } + + // Overrides the base class compute function + // Calls the erase function in ImageUtils, this function takes an input tensor + // and overwrites some of its data using openCV, the output memory is manipulated to contain the result + // @return Status - The error code return + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kCutOutOp; } + + private: + std::mt19937 rnd_; + int32_t box_height_; + int32_t box_width_; + int32_t num_patches_; + bool random_color_; + uint8_t fill_r_; + uint8_t fill_g_; + uint8_t fill_b_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.cc new file mode 100644 index 0000000000..5bc5377de9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/decode_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const bool DecodeOp::kDefRgbFormat = true; + +DecodeOp::DecodeOp(bool is_rgb_format) : is_rgb_format_(is_rgb_format) { + if (is_rgb_format_) { // RGB colour mode + MS_LOG(DEBUG) << "Decode colour mode is RGB."; + } else { + MS_LOG(DEBUG) << "Decode colour mode is BGR."; + } +} + +Status DecodeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (is_rgb_format_) { // RGB colour mode + return Decode(input, output); + } else { // BGR colour mode + RETURN_STATUS_UNEXPECTED("Decode BGR is deprecated"); + } +} +Status DecodeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out({-1, -1, 3}); // we don't know what is output image size, but we know it should be 3 channels + if (inputs[0].Rank() == 1) outputs.emplace_back(out); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} + +Status DecodeOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + outputs[0] = DataType(DataType::DE_UINT8); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h new file mode 100644 index 0000000000..29bf1d0146 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_DECODE_OP_H_ +#define DATASET_KERNELS_IMAGE_DECODE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DecodeOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const bool kDefRgbFormat; + + explicit DecodeOp(bool is_rgb_format = true); + + ~DecodeOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + void Print(std::ostream &out) const override { out << "DecodeOp"; } + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kDecodeOp; } + + private: + bool is_rgb_format_ = true; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_DECODE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.cc new file mode 100644 index 0000000000..5013958562 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status HwcToChwOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // input.shape == HWC + // output.shape == CHW + return HwcToChw(input, output); +} +Status HwcToChwOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape in = inputs[0]; + TensorShape out = TensorShape{in[2], in[0], in[1]}; + if (inputs[0].Rank() == 3) outputs.emplace_back(out); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h new file mode 100644 index 0000000000..0d5f70f895 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ +#define DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class HwcToChwOp : public TensorOp { + public: + void Print(std::ostream &out) const override { out << "HwcToChw"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kHwcToChwOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc new file mode 100644 index 0000000000..ddbce3e23a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -0,0 +1,836 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/image_utils.h" +#include +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/util/random.h" + +#define MAX_INT_PRECISION 16777216 // float int precision is 16777216 +namespace mindspore { +namespace dataset { +int GetCVInterpolationMode(InterpolationMode mode) { + switch (mode) { + case InterpolationMode::kLinear: + return static_cast(cv::InterpolationFlags::INTER_LINEAR); + case InterpolationMode::kCubic: + return static_cast(cv::InterpolationFlags::INTER_CUBIC); + case InterpolationMode::kArea: + return static_cast(cv::InterpolationFlags::INTER_AREA); + case InterpolationMode::kNearestNeighbour: + return static_cast(cv::InterpolationFlags::INTER_NEAREST); + default: + return static_cast(cv::InterpolationFlags::INTER_LINEAR); + } +} + +int GetCVBorderType(BorderType type) { + switch (type) { + case BorderType::kConstant: + return static_cast(cv::BorderTypes::BORDER_CONSTANT); + case BorderType::kEdge: + return static_cast(cv::BorderTypes::BORDER_REPLICATE); + case BorderType::kReflect: + return static_cast(cv::BorderTypes::BORDER_REFLECT101); + case BorderType::kSymmetric: + return static_cast(cv::BorderTypes::BORDER_REFLECT); + default: + return static_cast(cv::BorderTypes::BORDER_CONSTANT); + } +} + +Status Flip(std::shared_ptr input, std::shared_ptr *output, int flip_code) { + std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); + + std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + RETURN_IF_NOT_OK(output_cv->AllocateBuffer(output_cv->SizeInBytes())); + + if (input_cv->mat().data) { + try { + cv::flip(input_cv->mat(), output_cv->mat(), flip_code); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in flip op."); + } + } else { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor, the input data is null"); + } +} + +Status HorizontalFlip(std::shared_ptr input, std::shared_ptr *output) { + return Flip(std::move(input), output, 1); +} + +Status VerticalFlip(std::shared_ptr input, std::shared_ptr *output) { + return Flip(std::move(input), output, 0); +} + +Status Resize(const std::shared_ptr &input, std::shared_ptr *output, int32_t output_height, + int32_t output_width, double fx, double fy, InterpolationMode mode) { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of or "); + } + cv::Mat in_image = input_cv->mat(); + // resize image too large or too small + if (output_height == 0 || output_height > in_image.rows * 1000 || output_width == 0 || + output_width > in_image.cols * 1000) { + std::string err_msg = + "The resizing width or height 1) is too big, it's up to " + "1000 times the original image; 2) can not be 0."; + return Status(StatusCode::kShapeMisMatch, err_msg); + } + try { + TensorShape shape{output_height, output_width}; + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); + std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + auto cv_mode = GetCVInterpolationMode(mode); + cv::resize(in_image, output_cv->mat(), cv::Size(output_width, output_height), fx, fy, cv_mode); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in image resize."); + } +} + +bool IsNonEmptyJPEG(const std::shared_ptr &input) { + const unsigned char *kJpegMagic = (unsigned char *)"\xFF\xD8\xFF"; + constexpr size_t kJpegMagicLen = 3; + return input->SizeInBytes() > kJpegMagicLen && memcmp(input->GetBuffer(), kJpegMagic, kJpegMagicLen) == 0; +} + +Status Decode(const std::shared_ptr &input, std::shared_ptr *output) { + if (IsNonEmptyJPEG(input)) { + return JpegCropAndDecode(input, output); + } else { + return DecodeCv(input, output); + } +} + +Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *output) { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + try { + cv::Mat img_mat = cv::imdecode(input_cv->mat(), cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION); + if (img_mat.data == nullptr) { + std::string err = "Error in decoding\t"; + RETURN_STATUS_UNEXPECTED(err); + } + cv::cvtColor(img_mat, img_mat, static_cast(cv::COLOR_BGR2RGB)); + std::shared_ptr output_cv = std::make_shared(img_mat); + RETURN_UNEXPECTED_IF_NULL(output_cv); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in image Decode"); + } +} + +static void JpegInitSource(j_decompress_ptr cinfo) {} + +static boolean JpegFillInputBuffer(j_decompress_ptr cinfo) { + if (cinfo->src->bytes_in_buffer == 0) { + ERREXIT(cinfo, JERR_INPUT_EMPTY); + return FALSE; + } + return TRUE; +} + +static void JpegTermSource(j_decompress_ptr cinfo) {} + +static void JpegSkipInputData(j_decompress_ptr cinfo, int64_t jump) { + if (jump < 0) { + return; + } + if (static_cast(jump) > cinfo->src->bytes_in_buffer) { + cinfo->src->bytes_in_buffer = 0; + return; + } else { + cinfo->src->bytes_in_buffer -= jump; + cinfo->src->next_input_byte += jump; + } +} + +void JpegSetSource(j_decompress_ptr cinfo, const void *data, int64_t datasize) { + cinfo->src = static_cast( + (*cinfo->mem->alloc_small)(reinterpret_cast(cinfo), JPOOL_PERMANENT, sizeof(struct jpeg_source_mgr))); + cinfo->src->init_source = JpegInitSource; + cinfo->src->fill_input_buffer = JpegFillInputBuffer; +#if defined(_WIN32) || defined(_WIN64) + cinfo->src->skip_input_data = reinterpret_cast(JpegSkipInputData); +#else + cinfo->src->skip_input_data = JpegSkipInputData; +#endif + cinfo->src->resync_to_restart = jpeg_resync_to_restart; + cinfo->src->term_source = JpegTermSource; + cinfo->src->bytes_in_buffer = datasize; + cinfo->src->next_input_byte = static_cast(data); +} + +static Status JpegReadScanlines(jpeg_decompress_struct *const cinfo, int max_scanlines_to_read, JSAMPLE *buffer, + int buffer_size, int crop_w, int crop_w_aligned, int offset, int stride) { + // scanlines will be read to this buffer first, must have the number + // of components equal to the number of components in the image + int64_t scanline_size = crop_w_aligned * cinfo->output_components; + std::vector scanline(scanline_size); + JSAMPLE *scanline_ptr = &scanline[0]; + while (cinfo->output_scanline < static_cast(max_scanlines_to_read)) { + int num_lines_read = jpeg_read_scanlines(cinfo, &scanline_ptr, 1); + if (cinfo->out_color_space == JCS_CMYK && num_lines_read > 0) { + for (int i = 0; i < crop_w; ++i) { + int cmyk_pixel = 4 * i + offset; + const int c = scanline_ptr[cmyk_pixel]; + const int m = scanline_ptr[cmyk_pixel + 1]; + const int y = scanline_ptr[cmyk_pixel + 2]; + const int k = scanline_ptr[cmyk_pixel + 3]; + int r, g, b; + if (cinfo->saw_Adobe_marker) { + r = (k * c) / 255; + g = (k * m) / 255; + b = (k * y) / 255; + } else { + r = (255 - c) * (255 - k) / 255; + g = (255 - m) * (255 - k) / 255; + b = (255 - y) * (255 - k) / 255; + } + buffer[3 * i + 0] = r; + buffer[3 * i + 1] = g; + buffer[3 * i + 2] = b; + } + } else if (num_lines_read > 0) { + int copy_status = memcpy_s(buffer, buffer_size, scanline_ptr + offset, stride); + if (copy_status != 0) { + jpeg_destroy_decompress(cinfo); + RETURN_STATUS_UNEXPECTED("memcpy failed"); + } + } else { + jpeg_destroy_decompress(cinfo); + std::string err_msg = "failed to read scanline"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + buffer += stride; + buffer_size = buffer_size - stride; + } + return Status::OK(); +} + +static Status JpegSetColorSpace(jpeg_decompress_struct *cinfo) { + switch (cinfo->num_components) { + case 1: + // we want to output 3 components if it's grayscale + cinfo->out_color_space = JCS_RGB; + return Status::OK(); + case 3: + cinfo->out_color_space = JCS_RGB; + return Status::OK(); + case 4: + // Need to manually convert to RGB + cinfo->out_color_space = JCS_CMYK; + return Status::OK(); + default: + jpeg_destroy_decompress(cinfo); + std::string err_msg = "wrong number of components"; + RETURN_STATUS_UNEXPECTED(err_msg); + } +} + +void JpegErrorExitCustom(j_common_ptr cinfo) { + char jpeg_last_error_msg[JMSG_LENGTH_MAX]; + (*(cinfo->err->format_message))(cinfo, jpeg_last_error_msg); + throw std::runtime_error(jpeg_last_error_msg); +} + +Status JpegCropAndDecode(const std::shared_ptr &input, std::shared_ptr *output, int crop_x, int crop_y, + int crop_w, int crop_h) { + struct jpeg_decompress_struct cinfo; + auto DestroyDecompressAndReturnError = [&cinfo](const std::string &err) { + jpeg_destroy_decompress(&cinfo); + RETURN_STATUS_UNEXPECTED(err); + }; + struct JpegErrorManagerCustom jerr; + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = JpegErrorExitCustom; + try { + jpeg_create_decompress(&cinfo); + JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes()); + (void)jpeg_read_header(&cinfo, TRUE); + RETURN_IF_NOT_OK(JpegSetColorSpace(&cinfo)); + jpeg_calc_output_dimensions(&cinfo); + } catch (std::runtime_error &e) { + return DestroyDecompressAndReturnError(e.what()); + } + if (crop_x == 0 && crop_y == 0 && crop_w == 0 && crop_h == 0) { + crop_w = cinfo.output_width; + crop_h = cinfo.output_height; + } else if (crop_w == 0 || static_cast(crop_w + crop_x) > cinfo.output_width || crop_h == 0 || + static_cast(crop_h + crop_y) > cinfo.output_height) { + return DestroyDecompressAndReturnError("Crop window is not valid"); + } + const int mcu_size = cinfo.min_DCT_scaled_size; + unsigned int crop_x_aligned = (crop_x / mcu_size) * mcu_size; + unsigned int crop_w_aligned = crop_w + crop_x - crop_x_aligned; + try { + (void)jpeg_start_decompress(&cinfo); + jpeg_crop_scanline(&cinfo, &crop_x_aligned, &crop_w_aligned); + } catch (std::runtime_error &e) { + return DestroyDecompressAndReturnError(e.what()); + } + JDIMENSION skipped_scanlines = jpeg_skip_scanlines(&cinfo, crop_y); + // three number of output components, always convert to RGB and output + constexpr int kOutNumComponents = 3; + TensorShape ts = TensorShape({crop_h, crop_w, kOutNumComponents}); + auto output_tensor = std::make_shared(ts, DataType(DataType::DE_UINT8)); + const int buffer_size = output_tensor->SizeInBytes(); + JSAMPLE *buffer = reinterpret_cast(&(*output_tensor->begin())); + const int max_scanlines_to_read = skipped_scanlines + crop_h; + // stride refers to output tensor, which has 3 components at most + const int stride = crop_w * kOutNumComponents; + // offset is calculated for scanlines read from the image, therefore + // has the same number of components as the image + const int offset = (crop_x - crop_x_aligned) * cinfo.output_components; + RETURN_IF_NOT_OK( + JpegReadScanlines(&cinfo, max_scanlines_to_read, buffer, buffer_size, crop_w, crop_w_aligned, offset, stride)); + *output = output_tensor; + jpeg_destroy_decompress(&cinfo); + return Status::OK(); +} + +Status Rescale(const std::shared_ptr &input, std::shared_ptr *output, float rescale, float shift) { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + cv::Mat input_image = input_cv->mat(); + std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); + RETURN_UNEXPECTED_IF_NULL(output_cv); + try { + input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift); + *output = std::static_pointer_cast(output_cv); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in image rescale"); + } + return Status::OK(); +} + +Status Crop(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, int w, int h) { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("Shape not or "); + } + try { + TensorShape shape{h, w}; + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); + std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + cv::Rect roi(x, y, w, h); + (input_cv->mat())(roi).copyTo(output_cv->mat()); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Unexpected error in crop."); + } +} + +Status HwcToChw(std::shared_ptr input, std::shared_ptr *output) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + if (input_cv->Rank() == 2) { + // If input tensor is 2D, we assume we have hw dimensions + *output = input; + return Status::OK(); + } + int num_channels = input_cv->shape()[2]; + if (input_cv->shape().Size() < 2 || input_cv->shape().Size() > 3 || + (input_cv->shape().Size() == 3 && num_channels != 3 && num_channels != 1)) { + RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3 nor 1"); + } + cv::Mat output_img; + + int height = input_cv->shape()[0]; + int width = input_cv->shape()[1]; + + auto output_cv = std::make_unique(TensorShape{num_channels, height, width}, input_cv->type()); + for (int i = 0; i < num_channels; ++i) { + cv::Mat mat; + RETURN_IF_NOT_OK(output_cv->Mat({i}, &mat)); + cv::extractChannel(input_cv->mat(), mat, i); + } + *output = std::move(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Unexpected error in ChannelSwap."); + } +} + +Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); + int num_channels = input_cv->shape()[2]; + if (input_cv->shape().Size() != 3 || num_channels != 3) { + RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); + } + auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast(cv::COLOR_BGR2RGB)); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Unexpected error in ChangeMode."); + } +} + +Status CropAndResize(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, + int crop_height, int crop_width, int target_height, int target_width, InterpolationMode mode) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("Shape not or "); + } + // image too large or too small + if (crop_height == 0 || crop_width == 0 || target_height == 0 || target_height > crop_height * 1000 || + target_width == 0 || target_height > crop_width * 1000) { + std::string err_msg = + "The resizing width or height 1) is too big, it's up to " + "1000 times the original image; 2) can not be 0."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + cv::Rect roi(x, y, crop_width, crop_height); + auto cv_mode = GetCVInterpolationMode(mode); + cv::Mat cv_in = input_cv->mat(); + TensorShape shape{target_height, target_width}; + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); + std::shared_ptr cvt_out = std::make_shared(shape, input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(cvt_out); + cv::resize(cv_in(roi), cvt_out->mat(), cv::Size(target_width, target_height), 0, 0, cv_mode); + *output = std::static_pointer_cast(cvt_out); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Unexpected error in CropAndResize."); + } +} + +Status Rotate(const std::shared_ptr &input, std::shared_ptr *output, float fx, float fy, float degree, + InterpolationMode interpolation, bool expand, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + cv::Mat input_img = input_cv->mat(); + if (input_img.cols > (MAX_INT_PRECISION * 2) || input_img.rows > (MAX_INT_PRECISION * 2)) { + RETURN_STATUS_UNEXPECTED("Image too large center not precise"); + } + // default to center of image + if (fx == -1 && fy == -1) { + fx = (input_img.cols - 1) / 2.0; + fy = (input_img.rows - 1) / 2.0; + } + cv::Mat output_img; + cv::Scalar fill_color = cv::Scalar(fill_b, fill_g, fill_r); + // maybe don't use uint32 for image dimension here + cv::Point2f pc(fx, fy); + cv::Mat rot = cv::getRotationMatrix2D(pc, degree, 1.0); + std::shared_ptr output_cv; + if (!expand) { + // this case means that the shape doesn't change, size stays the same + // We may not need this memcpy if it is in place. + output_cv = std::make_shared(input_cv->shape(), input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + // using inter_nearest to comply with python default + cv::warpAffine(input_img, output_cv->mat(), rot, input_img.size(), GetCVInterpolationMode(interpolation), + cv::BORDER_CONSTANT, fill_color); + } else { + // we resize here since the shape changes + // create a new bounding box with the rotate + cv::Rect2f bbox = cv::RotatedRect(cv::Point2f(), input_img.size(), degree).boundingRect2f(); + rot.at(0, 2) += bbox.width / 2.0 - input_img.cols / 2.0; + rot.at(1, 2) += bbox.height / 2.0 - input_img.rows / 2.0; + // use memcpy and don't compute the new shape since openCV has a rounding problem + cv::warpAffine(input_img, output_img, rot, bbox.size(), GetCVInterpolationMode(interpolation), + cv::BORDER_CONSTANT, fill_color); + output_cv = std::make_shared(output_img); + RETURN_UNEXPECTED_IF_NULL(output_cv); + } + *output = std::static_pointer_cast(output_cv); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in image rotation"); + } + return Status::OK(); +} + +Status Normalize(const std::shared_ptr &input, std::shared_ptr *output, + const std::shared_ptr &mean, const std::shared_ptr &std) { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!(input_cv->mat().data && input_cv->Rank() == 3)) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + cv::Mat in_image = input_cv->mat(); + std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); + RETURN_UNEXPECTED_IF_NULL(output_cv); + mean->Squeeze(); + if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) { + std::string err_msg = "Mean tensor should be of size 3 and type float."; + return Status(StatusCode::kShapeMisMatch, err_msg); + } + std->Squeeze(); + if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) { + std::string err_msg = "Std tensor should be of size 3 and type float."; + return Status(StatusCode::kShapeMisMatch, err_msg); + } + try { + // NOTE: We are assuming the input image is in RGB and the mean + // and std are in RGB + cv::Mat rgb[3]; + cv::split(in_image, rgb); + for (uint8_t i = 0; i < 3; i++) { + float mean_c, std_c; + RETURN_IF_NOT_OK(mean->GetItemAt(&mean_c, {i})); + RETURN_IF_NOT_OK(std->GetItemAt(&std_c, {i})); + rgb[i].convertTo(rgb[i], CV_32F, 1.0 / std_c, (-mean_c / std_c)); + } + cv::merge(rgb, 3, output_cv->mat()); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Unexpected error in Normalize"); + } +} + +Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + cv::Mat input_img = input_cv->mat(); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() != 3 || num_channels != 3) { + RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); + } + auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + output_cv->mat() = input_img * alpha; + *output = std::static_pointer_cast(output_cv); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in adjust brightness"); + } + return Status::OK(); +} + +Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + cv::Mat input_img = input_cv->mat(); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() != 3 || num_channels != 3) { + RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); + } + cv::Mat gray, output_img; + cv::cvtColor(input_img, gray, CV_RGB2GRAY); + int mean_img = static_cast(cv::mean(gray).val[0] + 0.5); + std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + output_img = cv::Mat::zeros(input_img.rows, input_img.cols, CV_8UC1); + output_img = output_img + mean_img; + cv::cvtColor(output_img, output_img, CV_GRAY2RGB); + output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha; + *output = std::static_pointer_cast(output_cv); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in adjust contrast"); + } + return Status::OK(); +} + +Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + cv::Mat input_img = input_cv->mat(); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() != 3 || num_channels != 3) { + RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); + } + auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + cv::Mat output_img = output_cv->mat(); + cv::Mat gray; + cv::cvtColor(input_img, gray, CV_RGB2GRAY); + cv::cvtColor(gray, output_img, CV_GRAY2RGB); + output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha; + *output = std::static_pointer_cast(output_cv); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in adjust saturation"); + } + return Status::OK(); +} + +Status AdjustHue(const std::shared_ptr &input, std::shared_ptr *output, const float &hue) { + if (hue > 0.5 || hue < -0.5) { + MS_LOG(ERROR) << "Hue factor is not in [-0.5, 0.5]."; + RETURN_STATUS_UNEXPECTED("hue_factor is not in [-0.5, 0.5]."); + } + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + cv::Mat input_img = input_cv->mat(); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() != 3 || num_channels != 3) { + RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); + } + auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); + RETURN_UNEXPECTED_IF_NULL(output_cv); + cv::Mat output_img; + cv::cvtColor(input_img, output_img, CV_RGB2HSV_FULL); + for (int y = 0; y < output_img.cols; y++) { + for (int x = 0; x < output_img.rows; x++) { + uint8_t cur1 = output_img.at(cv::Point(y, x))[0]; + uint8_t h_hue = 0; + h_hue = static_cast(hue * 255); + cur1 += h_hue; + output_img.at(cv::Point(y, x))[0] = cur1; + } + } + cv::cvtColor(output_img, output_cv->mat(), CV_HSV2RGB_FULL); + *output = std::static_pointer_cast(output_cv); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in adjust hue"); + } + return Status::OK(); +} + +Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, + int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, + uint8_t fill_g, uint8_t fill_b) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + int num_channels = input_cv->shape()[2]; + if (input_cv->mat().data == nullptr || input_cv->Rank() != 3 || num_channels != 3) { + RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); + } + cv::Mat input_img = input_cv->mat(); + int32_t image_h = input_cv->shape()[0]; + int32_t image_w = input_cv->shape()[1]; + // check if erase size is bigger than image itself + if (box_height > image_h || box_width > image_w) { + RETURN_STATUS_UNEXPECTED("input box size too large for image erase"); + } + + // for random color + std::normal_distribution normal_distribution(0, 1); + std::uniform_int_distribution height_distribution_bound(0, image_h - box_height); + std::uniform_int_distribution width_distribution_bound(0, image_w - box_width); + std::uniform_int_distribution height_distribution_unbound(0, image_h + box_height); + std::uniform_int_distribution width_distribution_unbound(0, image_w + box_width); + // core logic + // update values based on random erasing or cutout + + for (int32_t i = 0; i < num_patches; i++) { + // rows in cv mat refers to the height of the cropped box + // we determine h_start and w_start using two different distributions as erasing is used by two different + // image augmentations. The bounds are also different in each case. + int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height); + int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width); + + int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width; + int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height; + // check for starting range >= 0, here the start range is checked after for cut out, for random erasing + // w_start and h_start will never be less than 0. + h_start = (h_start < 0) ? 0 : h_start; + w_start = (w_start < 0) ? 0 : w_start; + for (int y = w_start; y < max_width; y++) { + for (int x = h_start; x < max_height; x++) { + if (random_color) { + // fill each box with a random value + input_img.at(cv::Point(y, x))[0] = static_cast(normal_distribution(*rnd)); + input_img.at(cv::Point(y, x))[1] = static_cast(normal_distribution(*rnd)); + input_img.at(cv::Point(y, x))[2] = static_cast(normal_distribution(*rnd)); + } else { + input_img.at(cv::Point(y, x))[0] = fill_r; + input_img.at(cv::Point(y, x))[1] = fill_g; + input_img.at(cv::Point(y, x))[2] = fill_b; + } + } + } + } + *output = std::static_pointer_cast(input); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in erasing"); + } +} + +Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, + const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, + uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { + try { + // input image + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + // get the border type in openCV + auto b_type = GetCVBorderType(border_types); + // output image + cv::Mat out_image; + if (b_type == cv::BORDER_CONSTANT) { + cv::Scalar fill_color = cv::Scalar(fill_b, fill_g, fill_r); + cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type, fill_color); + } else { + cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type); + } + std::shared_ptr output_cv = std::make_shared(out_image); + RETURN_UNEXPECTED_IF_NULL(output_cv); + // pad the dimension if shape information is only 2 dimensional, this is grayscale + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() == 3 && num_channels == 1 && output_cv->Rank() == 2) output_cv->ExpandDim(2); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Unexpected error in pad"); + } +} +// -------- BBOX OPERATIONS -------- // +Status UpdateBBoxesForCrop(std::shared_ptr *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax, + int CB_Ymax) { + // PASS LIST, COUNT OF BOUNDING BOXES + // Also PAss X/Y Min/Max of image cropped region - normally obtained from 'GetCropBox' functions + float bb_Xmin = 0.0, bb_Ymin = 0.0, bb_Xmax = 0.0, bb_Ymax = 0.0; + std::vector correct_ind; + std::vector copyVals; + dsize_t bboxDim = (*bboxList)->shape()[1]; + bool retFlag = false; // true unless overlap found + for (int i = 0; i < *bboxCount; i++) { + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Xmin, {i, 0})); + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Ymin, {i, 1})); + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Xmax, {i, 2})); + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Ymax, {i, 3})); + bb_Xmax = bb_Xmin + bb_Xmax; + bb_Ymax = bb_Ymin + bb_Ymax; + // check for image / BB overlap + if (((bb_Xmin > CB_Xmax) || (bb_Ymin > CB_Ymax)) || ((bb_Xmax < CB_Xmin) || (bb_Ymax < CB_Ymin))) { + continue; // no overlap found + } + // Update this bbox and select it to move to the final output tensor + correct_ind.push_back(i); + // adjust BBox corners by bringing into new CropBox if beyond + // Also reseting/adjusting for boxes to lie within CropBox instead of Image - subtract CropBox Xmin/YMin + + bb_Xmin = bb_Xmin - std::min(static_cast(0.0), (bb_Xmin - CB_Xmin)) - CB_Xmin; + bb_Xmax = bb_Xmax - std::max(static_cast(0.0), (bb_Xmax - CB_Xmax)) - CB_Xmin; + bb_Ymin = bb_Ymin - std::min(static_cast(0.0), (bb_Ymin - CB_Ymin)) - CB_Ymin; + bb_Ymax = bb_Ymax - std::max(static_cast(0.0), (bb_Ymax - CB_Ymax)) - CB_Ymin; + + // bound check for float values + bb_Xmin = std::max(bb_Xmin, static_cast(0)); + bb_Ymin = std::max(bb_Ymin, static_cast(0)); + bb_Xmax = std::min(bb_Xmax, static_cast(CB_Xmax - CB_Xmin)); // find max value relative to new image + bb_Ymax = std::min(bb_Ymax, static_cast(CB_Ymax - CB_Ymin)); + + // reset min values and calculate width/height from Box corners + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, bb_Xmin)); + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, bb_Ymin)); + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 2}, bb_Xmax - bb_Xmin)); + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 3}, bb_Ymax - bb_Ymin)); + } + // create new tensor and copy over bboxes still valid to the image + // bboxes outside of new cropped region are ignored - empty tensor returned in case of none + *bboxCount = correct_ind.size(); + float temp = 0.0; + for (auto slice : correct_ind) { // for every index in the loop + for (int ix = 0; ix < bboxDim; ix++) { + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&temp, {slice, ix})); + copyVals.push_back(temp); + } + } + std::shared_ptr retV; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&retV, copyVals, TensorShape({static_cast(*bboxCount), bboxDim}))); + (*bboxList) = retV; // reset pointer + return Status::OK(); +} + +Status PadBBoxes(const std::shared_ptr *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left) { + for (int i = 0; i < bboxCount; i++) { + float xMin = 0.0, yMin = 0.0; + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&xMin, {i, 0})); + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&yMin, {i, 1})); + xMin += pad_left; + yMin += pad_top; + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, xMin)); + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, yMin)); + } + return Status::OK(); +} + +Status UpdateBBoxesForResize(const std::shared_ptr &bboxList, const size_t &bboxCount, int32_t target_width_, + int32_t target_height_, int orig_width, int orig_height) { + float bb_Xmin = 0, bb_Ymin = 0, bb_Xwidth = 0, bb_Ywidth = 0; + // cast to float to preserve fractional + float W_aspRatio = (target_width_ * 1.0) / (orig_width * 1.0); + float H_aspRatio = (target_height_ * 1.0) / (orig_height * 1.0); + for (int i = 0; i < bboxCount; i++) { + // for each bounding box + RETURN_IF_NOT_OK(bboxList->GetItemAt(&bb_Xmin, {i, 0})); + RETURN_IF_NOT_OK(bboxList->GetItemAt(&bb_Ymin, {i, 1})); + RETURN_IF_NOT_OK(bboxList->GetItemAt(&bb_Xwidth, {i, 2})); + RETURN_IF_NOT_OK(bboxList->GetItemAt(&bb_Ywidth, {i, 3})); + // update positions and widths + bb_Xmin = bb_Xmin * W_aspRatio; + bb_Ymin = bb_Ymin * H_aspRatio; + bb_Xwidth = bb_Xwidth * W_aspRatio; + bb_Ywidth = bb_Ywidth * H_aspRatio; + // reset bounding box values + RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 0}, bb_Xmin)); + RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 1}, bb_Ymin)); + RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 2}, bb_Xwidth)); + RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 3}, bb_Ywidth)); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h new file mode 100644 index 0000000000..f489c7367b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -0,0 +1,259 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ +#define DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ + +#include + +#include +#include +#include +#include +#if defined(_WIN32) || defined(_WIN64) +#undef HAVE_STDDEF_H +#undef HAVE_STDLIB_H +#endif +#include "./jpeglib.h" +#include "./jerror.h" +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +void JpegErrorExitCustom(j_common_ptr cinfo); + +struct JpegErrorManagerCustom { + // "public" fields + struct jpeg_error_mgr pub; + // for return to caller + jmp_buf setjmp_buffer; +}; + +// Returns the interpolation mode in openCV format +// @param mode: interpolation mode in DE format +int GetCVInterpolationMode(InterpolationMode mode); + +// Returns the openCV equivalent of the border type used for padding. +// @param type +// @return +int GetCVBorderType(BorderType type); + +// Returns flipped image +// @param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// @param flip_code: 1 for Horizontal (around y-axis), 0 for Vertical (around x-axis), -1 for both +// The flipping happens in place. +Status Flip(std::shared_ptr input, std::shared_ptr *output, int flip_code); + +// Returns Horizontally flipped image +// @param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// The flipping happens in place. +Status HorizontalFlip(std::shared_ptr input, std::shared_ptr *output); + +// Returns Vertically flipped image +// @param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// The flipping happens in place. +Status VerticalFlip(std::shared_ptr input, std::shared_ptr *output); + +// Returns Resized image. +// @param input/output: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// @param output_height: height of output +// @param output_width: width of output +// @param fx: horizontal scale +// @param fy: vertical scale +// @param InterpolationMode: the interpolation mode +// @param output: Resized image of shape or +// and same type as input +Status Resize(const std::shared_ptr &input, std::shared_ptr *output, int32_t output_height, + int32_t output_width, double fx = 0.0, double fy = 0.0, + InterpolationMode mode = InterpolationMode::kLinear); + +// Returns Decoded image +// Supported images: +// BMP JPEG JPG PNG TIFF +// supported by opencv, if user need more image analysis capabilities, please compile opencv particularlly. +// @param input: CVTensor containing the not decoded image 1D bytes +// @param output: Decoded image Tensor of shape and type DE_UINT8. Pixel order is RGB +Status Decode(const std::shared_ptr &input, std::shared_ptr *output); + +Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *output); + +bool IsNonEmptyJPEG(const std::shared_ptr &input); + +void JpegSetSource(j_decompress_ptr c_info, const void *data, int64_t data_size); + +Status JpegCropAndDecode(const std::shared_ptr &input, std::shared_ptr *output, int x = 0, int y = 0, + int w = 0, int h = 0); +// Returns Rescaled image +// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// @param rescale: rescale parameter +// @param shift: shift parameter +// @param output: Rescaled image Tensor of same input shape and type DE_FLOAT32 +Status Rescale(const std::shared_ptr &input, std::shared_ptr *output, float rescale, float shift); + +// Returns cropped ROI of an image +// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// @param x: starting horizontal position of ROI +// @param y: starting vertical position of ROI +// @param w: width of the ROI +// @param h: height of the ROI +// @param output: Cropped image Tensor of shape or and same input type. +Status Crop(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, int w, int h); + +// Swaps the channels in the image, i.e. converts HWC to CHW +// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// @param output: Tensor of shape or and same input type. +Status HwcToChw(std::shared_ptr input, std::shared_ptr *output); + +// Swap the red and blue pixels (RGB <-> BGR) +// @param input: Tensor of shape and any OpenCv compatible type, see CVTensor. +// @param output: Swapped image of same shape and type +Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output); + +// Crops and resizes the image +// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// @param x: horizontal start point +// @param y: vertical start point +// @param crop_height: height of the cropped ROI +// @param crop_width: width of the cropped ROI +// @param target_width: width of the final resized image +// @param target_height: height of the final resized image +// @param InterpolationMode: the interpolation used in resize operation +// @param output: Tensor of shape or +// and same type as input +Status CropAndResize(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, + int crop_height, int crop_width, int target_height, int target_width, InterpolationMode mode); + +// Returns rotated image +// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. +// @param fx: rotation center x coordinate +// @param fy: rotation center y coordinate +// @param degree: degree to rotate +// @param expand: if reshape is necessary +// @param output: rotated image of same input type. +Status Rotate(const std::shared_ptr &input, std::shared_ptr *output, float fx, float fy, float degree, + InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, bool expand = false, + uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); + +// Returns Normalized image +// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. +// @param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order +// @param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order +// @param output: Normalized image Tensor of same input shape and type DE_FLOAT32 +Status Normalize(const std::shared_ptr &input, std::shared_ptr *output, + const std::shared_ptr &mean, const std::shared_ptr &std); + +// Returns image with adjusted brightness. +// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. +// @param alpha: Alpha value to adjust brightness by. Should be a positive number. +// If user input one value in python, the range is [1 - value, 1 + value]. +// This will output original image multiplied by alpha. 0 gives a black image, 1 gives the +// original image while 2 increases the brightness by a factor of 2. +// @param output: Adjusted image of same shape and type. +Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); + +// Returns image with adjusted contrast. +// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. +// @param alpha: Alpha value to adjust contrast by. Should be a positive number. +// If user input one value in python, the range is [1 - value, 1 + value]. +// 0 gives a solid gray image, 1 gives the original image while 2 increases +// the contrast by a factor of 2. +// @param output: Adjusted image of same shape and type. +Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); + +// Returns image with adjusted saturation. +// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. +// @param alpha: Alpha value to adjust saturation by. Should be a positive number. +// If user input one value in python, the range is [1 - value, 1 + value]. +// 0 will give a black and white image, 1 will give the original image while +// 2 will enhance the saturation by a factor of 2. +// @param output: Adjusted image of same shape and type. +Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); + +// Returns image with adjusted hue. +// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. +// @param hue: Hue value to adjust by, should be within range [-0.5, 0.5]. 0.5 and - 0.5 will reverse the hue channel +// completely. +// If user input one value in python, the range is [-value, value]. +// @param output: Adjusted image of same shape and type. +Status AdjustHue(const std::shared_ptr &input, std::shared_ptr *output, const float &hue); + +// Masks out a random section from the image with set dimension +// @param input: input Tensor +// @param output: cutOut Tensor +// @param box_height: height of the cropped box +// @param box_width: width of the cropped box +// @param num_patches: number of boxes to cut out from the image +// @param bounded: boolean flag to toggle between random erasing and cutout +// @param random_color: whether or not random fill value should be used +// @param fill_r: red fill value for erase +// @param fill_g: green fill value for erase +// @param fill_b: blue fill value for erase. +Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, + int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, + uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); + +// Pads the input image and puts the padded image in the output +// @param input: input Tensor +// @param output: padded Tensor +// @param pad_top: amount of padding done in top +// @param pad_bottom: amount of padding done in bottom +// @param pad_left: amount of padding done in left +// @param pad_right: amount of padding done in right +// @param border_types: the interpolation to be done in the border +// @param fill_r: red fill value for pad +// @param fill_g: green fill value for pad +// @param fill_b: blue fill value for pad. +Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, + const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, + uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); + +// -------- BBOX OPERATIONS -------- // +// Updates and checks bounding boxes for new cropped region of image +// @param bboxList: A tensor contaning bounding box tensors +// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop +// @param CB_Xmin: Image's CropBox Xmin coordinate +// @param CB_Xmin: Image's CropBox Ymin coordinate +// @param CB_Xmax: Image's CropBox Xmax coordinate - (Xmin + width) +// @param CB_Xmax: Image's CropBox Ymax coordinate - (Ymin + height) +Status UpdateBBoxesForCrop(std::shared_ptr *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax, + int CB_Ymax); + +// Updates bounding boxes with required Top and Left padding +// Top and Left padding amounts required to adjust bboxs min X,Y values according to padding 'push' +// Top/Left since images 0,0 coordinate is taken from top left +// @param bboxList: A tensor contaning bounding box tensors +// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop +// @param pad_top: Total amount of padding applied to image top +// @param pad_left: Total amount of padding applied to image left side +Status PadBBoxes(const std::shared_ptr *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left); + +// Updates bounding boxes for an Image Resize Operation - Takes in set of valid BBoxes +// For e.g those that remain after a crop +// @param bboxList: A tensor contaning bounding box tensors +// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop +// @param bboxList: A tensor contaning bounding box tensors +// @param target_width_: required width of image post resize +// @param target_width_: required height of image post resize +// @param orig_width: current width of image pre resize +// @param orig_height: current height of image pre resize +Status UpdateBBoxesForResize(const std::shared_ptr &bboxList, const size_t &bboxCount, int32_t target_width_, + int32_t target_height_, int orig_width, int orig_height); + +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc new file mode 100644 index 0000000000..de5deb31ef --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/normalize_op.h" + +#include + +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +NormalizeOp::NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b) { + int size[] = {3}; + cv::Mat mean_cv(1, size, CV_32F); + mean_cv.at(0) = mean_r; + mean_cv.at(1) = mean_g; + mean_cv.at(2) = mean_b; + mean_ = std::make_shared(mean_cv); + mean_->Squeeze(); + + cv::Mat std_cv(1, size, CV_32F); + std_cv.at(0) = std_r; + std_cv.at(1) = std_g; + std_cv.at(2) = std_b; + std_ = std::make_shared(std_cv); + std_->Squeeze(); +} + +Status NormalizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + // Doing the normalization + return Normalize(input, output, mean_, std_); +} + +void NormalizeOp::Print(std::ostream &out) const { + out << "NormalizeOp, mean: " << mean_->mat().at(0) << ", " << mean_->mat().at(1) << ", " + << mean_->mat().at(2) << "std: " << std_->mat().at(0) << ", " << std_->mat().at(1) << ", " + << std_->mat().at(2) << std::endl; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h new file mode 100644 index 0000000000..7821869c8f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ +#define DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ + +#include +#include + +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class NormalizeOp : public TensorOp { + public: + NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b); + + ~NormalizeOp() override = default; + + void Print(std::ostream &out) const override; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kNormalizeOp; } + + private: + std::shared_ptr mean_; + std::shared_ptr std_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.cc new file mode 100644 index 0000000000..52f32e2b1b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/pad_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const BorderType PadOp::kDefBorderType = BorderType::kConstant; +const uint8_t PadOp::kDefFillR = 0; +const uint8_t PadOp::kDefFillG = 0; +const uint8_t PadOp::kDefFillB = 0; + +PadOp::PadOp(int32_t pad_top, int32_t pad_bottom, int32_t pad_left, int32_t pad_right, BorderType border_types, + uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) + : pad_top_(pad_top), + pad_bottom_(pad_bottom), + pad_left_(pad_left), + pad_right_(pad_right), + boarder_type_(border_types), + fill_r_(fill_r), + fill_g_(fill_g), + fill_b_(fill_b) {} + +Status PadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return Pad(input, output, pad_top_, pad_bottom_, pad_left_, pad_right_, boarder_type_, fill_r_, fill_g_, fill_b_); +} + +Status PadOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out({-1, -1, 3}); // we don't know what is output image size, but we know it should be 3 channels + if (inputs[0].Rank() == 1) outputs.emplace_back(out); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h new file mode 100644 index 0000000000..9437058406 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_PAD_OP_H_ +#define DATASET_KERNELS_IMAGE_PAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class PadOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const BorderType kDefBorderType; + static const uint8_t kDefFillR; + static const uint8_t kDefFillG; + static const uint8_t kDefFillB; + + // Constructor for PadOp. + // @param pad_top number of pixels to pad the top of image with. + // @param pad_bottom number of pixels to pad the bottom of the image with. + // @param pad_left number of pixels to pad the left of the image with. + // @param pad_right number of pixels to pad the right of the image with. + // @param border_types BorderType enum, the type of boarders that we are using. + // @param fill_r R value for the color to pad with. + // @param fill_g G value for the color to pad with. + // @param fill_b B value for the color to pad with. + PadOp(int32_t pad_top, int32_t pad_bottom, int32_t pad_left, int32_t pad_right, BorderType border_types, + uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); + + ~PadOp() override = default; + + void Print(std::ostream &out) const override { out << "PadOp: "; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kPadOp; } + + private: + int32_t pad_top_; + int32_t pad_bottom_; + int32_t pad_left_; + int32_t pad_right_; + BorderType boarder_type_; + uint8_t fill_r_; + uint8_t fill_g_; + uint8_t fill_b_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_PAD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc new file mode 100644 index 0000000000..6dbf30c33e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_color_adjust_op.h" + +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_factor, float s_contrast_factor, + float e_contrast_factor, float s_saturation_factor, float e_saturation_factor, + float s_hue_factor, float e_hue_factor) + : bright_factor_start_(s_bright_factor), + bright_factor_end_(e_bright_factor), + contrast_factor_start_(s_contrast_factor), + contrast_factor_end_(e_contrast_factor), + saturation_factor_start_(s_saturation_factor), + saturation_factor_end_(e_saturation_factor), + hue_factor_start_(s_hue_factor), + hue_factor_end_(e_hue_factor) { + rnd_.seed(GetSeed()); +} + +Status RandomColorAdjustOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + + // randomly select an augmentation to apply to the input image until all the transformations run + std::vector params_vector = {"brightness", "contrast", "saturation", "hue"}; + + std::shuffle(params_vector.begin(), params_vector.end(), rnd_); + + *output = std::static_pointer_cast(input); + // determine if certain augmentation needs to be executed: + for (const auto ¶m : params_vector) { + // case switch + if (param == "brightness") { + if (CmpFloat(bright_factor_start_, bright_factor_end_) && CmpFloat(bright_factor_start_, 1.0f)) { + MS_LOG(DEBUG) << "Not running brightness."; + } else { + // adjust the brightness of an image + float random_factor = std::uniform_real_distribution(bright_factor_start_, bright_factor_end_)(rnd_); + RETURN_IF_NOT_OK(AdjustBrightness(*output, output, random_factor)); + } + } else if (param == "contrast") { + if (CmpFloat(contrast_factor_start_, contrast_factor_end_) && CmpFloat(contrast_factor_start_, 1.0f)) { + MS_LOG(DEBUG) << "Not running contrast."; + } else { + float random_factor = std::uniform_real_distribution(contrast_factor_start_, contrast_factor_end_)(rnd_); + RETURN_IF_NOT_OK(AdjustContrast(*output, output, random_factor)); + } + } else if (param == "saturation") { + // adjust the Saturation of an image + if (CmpFloat(saturation_factor_start_, saturation_factor_end_) && CmpFloat(saturation_factor_start_, 1.0f)) { + MS_LOG(DEBUG) << "Not running saturation."; + } else { + float random_factor = + std::uniform_real_distribution(saturation_factor_start_, saturation_factor_end_)(rnd_); + RETURN_IF_NOT_OK(AdjustSaturation(*output, output, random_factor)); + } + } else if (param == "hue") { + if (CmpFloat(hue_factor_start_, hue_factor_end_) && CmpFloat(hue_factor_start_, 0.0f)) { + MS_LOG(DEBUG) << "Not running hue."; + } else { + // adjust the Hue of an image + float random_factor = std::uniform_real_distribution(hue_factor_start_, hue_factor_end_)(rnd_); + RETURN_IF_NOT_OK(AdjustHue(*output, output, random_factor)); + } + } + } + // now after we do all the transformations, the last one is fine + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h new file mode 100644 index 0000000000..fb29b57062 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h @@ -0,0 +1,80 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomColorAdjustOp : public TensorOp { + public: + static const uint32_t kDefSeed; + + // Constructor for RandomColorAdjustOp. + // @param s_bright_factor brightness change range start value. + // @param e_bright_factor brightness change range end value. + // @param s_contrast_factor contrast change range start value. + // @param e_contrast_factor contrast change range start value. + // @param s_saturation_factor saturation change range end value. + // @param e_saturation_factor saturation change range end value. + // @param s_hue_factor hue change factor start value, this should be greater than -0.5. + // @param e_hue_factor hue change factor start value, this should be less than 0.5. + // @param seed optional seed to pass in to the constructor. + // @details the randomly chosen degree is uniformly distributed. + RandomColorAdjustOp(float s_bright_factor, float e_bright_factor, float s_contrast_factor, float e_contrast_factor, + float s_saturation_factor, float e_saturation_factor, float s_hue_factor, float e_hue_factor); + + ~RandomColorAdjustOp() override = default; + + // Print function for RandomJitter. + // @param out output stream to print to. + void Print(std::ostream &out) const override { out << "RandomColorAdjustOp: "; } + + // Overrides the base class compute function. + // Calls multiple transform functions in ImageUtils, this function takes an input tensor. + // and transforms its data using openCV, the output memory is manipulated to contain the result. + // @return Status - The error code return. + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomColorAdjustOp; } + + private: + std::mt19937 rnd_; + float bright_factor_start_; + float bright_factor_end_; + float contrast_factor_start_; + float contrast_factor_end_; + float saturation_factor_start_; + float saturation_factor_end_; + float hue_factor_start_; + float hue_factor_end_; + // Compare two floating point variables. Return true if they are same / very close. + inline bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f) const { + return (std::fabs(a - b) < epsilon); + } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.cc new file mode 100644 index 0000000000..8a7364d666 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float RandomCropAndResizeOp::kDefScaleLb = 0.08; +const float RandomCropAndResizeOp::kDefScaleUb = 1.0; +const float RandomCropAndResizeOp::kDefAspectLb = 0.75; +const float RandomCropAndResizeOp::kDefAspectUb = 1.333333; +const InterpolationMode RandomCropAndResizeOp::kDefInterpolation = InterpolationMode::kLinear; +const int32_t RandomCropAndResizeOp::kDefMaxIter = 10; + +RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t target_width, float scale_lb, + float scale_ub, float aspect_lb, float aspect_ub, + InterpolationMode interpolation, int32_t max_iter) + : target_height_(target_height), + target_width_(target_width), + rnd_scale_(scale_lb, scale_ub), + rnd_aspect_(log(aspect_lb), log(aspect_ub)), + interpolation_(interpolation), + aspect_lb_(aspect_lb), + aspect_ub_(aspect_ub), + max_iter_(max_iter) { + rnd_.seed(GetSeed()); +} + +Status RandomCropAndResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); + + int h_in = input->shape()[0]; + int w_in = input->shape()[1]; + int x = 0; + int y = 0; + int crop_height = 0; + int crop_width = 0; + (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); + return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); +} +Status RandomCropAndResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out = TensorShape{target_height_, target_width_}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { + *crop_width = w_in; + *crop_height = h_in; + CHECK_FAIL_RETURN_UNEXPECTED(w_in != 0, "Width is 0"); + CHECK_FAIL_RETURN_UNEXPECTED(h_in != 0, "Height is 0"); + CHECK_FAIL_RETURN_UNEXPECTED(aspect_lb_ > 0, "Aspect lower bound must be greater than zero"); + for (int32_t i = 0; i < max_iter_; i++) { + double const sample_scale = rnd_scale_(rnd_); + // In case of non-symmetrical aspect ratios, use uniform distribution on a logarithmic sample_scale. + // Note rnd_aspect_ is already a random distribution of the input aspect ratio in logarithmic sample_scale. + double const sample_aspect = exp(rnd_aspect_(rnd_)); + + *crop_width = static_cast(std::round(std::sqrt(h_in * w_in * sample_scale * sample_aspect))); + *crop_height = static_cast(std::round(*crop_width / sample_aspect)); + if (*crop_width <= w_in && *crop_height <= h_in) { + std::uniform_int_distribution<> rd_x(0, w_in - *crop_width); + std::uniform_int_distribution<> rd_y(0, h_in - *crop_height); + *x = rd_x(rnd_); + *y = rd_y(rnd_); + return Status::OK(); + } + } + double const img_aspect = static_cast(w_in) / h_in; + if (img_aspect < aspect_lb_) { + *crop_width = w_in; + *crop_height = static_cast(std::round(*crop_width / static_cast(aspect_lb_))); + } else { + if (img_aspect > aspect_ub_) { + *crop_height = h_in; + *crop_width = static_cast(std::round(*crop_height * static_cast(aspect_ub_))); + } else { + *crop_width = w_in; + *crop_height = h_in; + } + } + *x = static_cast(std::round((w_in - *crop_width) / 2.0)); + *y = static_cast(std::round((h_in - *crop_height) / 2.0)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h new file mode 100644 index 0000000000..41d775fdf7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomCropAndResizeOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefScaleLb; + static const float kDefScaleUb; + static const float kDefAspectLb; + static const float kDefAspectUb; + static const InterpolationMode kDefInterpolation; + static const int32_t kDefMaxIter; + + RandomCropAndResizeOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, + float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb, + InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter); + + RandomCropAndResizeOp() = default; + + RandomCropAndResizeOp(const RandomCropAndResizeOp &rhs) = default; + + RandomCropAndResizeOp(RandomCropAndResizeOp &&rhs) = default; + + ~RandomCropAndResizeOp() override = default; + + void Print(std::ostream &out) const override { + out << "RandomCropAndResize: " << target_height_ << " " << target_width_; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + Status GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width); + + std::string Name() const override { return kRandomCropAndResizeOp; } + + protected: + int32_t target_height_; + int32_t target_width_; + std::uniform_real_distribution rnd_scale_; + std::uniform_real_distribution rnd_aspect_; + std::mt19937 rnd_; + InterpolationMode interpolation_; + int32_t max_iter_; + double aspect_lb_; + double aspect_ub_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc new file mode 100644 index 0000000000..98bfe41241 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" + +namespace mindspore { +namespace dataset { + +Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + BOUNDING_BOX_CHECK(input); + CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); + + output->resize(2); + (*output)[1] = std::move(input[1]); // move boxes over to output + + size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor + int h_in = input[0]->shape()[0]; + int w_in = input[0]->shape()[1]; + int x = 0; + int y = 0; + int crop_height = 0; + int crop_width = 0; + + RETURN_IF_NOT_OK(RandomCropAndResizeOp::GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width)); + + int maxX = x + crop_width; // max dims of selected CropBox on image + int maxY = y + crop_height; + + RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &bboxCount, x, y, maxX, maxY)); // IMAGE_UTIL + RETURN_IF_NOT_OK(CropAndResize(input[0], &(*output)[0], x, y, crop_height, crop_width, target_height_, target_width_, + interpolation_)); + + RETURN_IF_NOT_OK( + UpdateBBoxesForResize((*output)[1], bboxCount, target_width_, target_height_, crop_width, crop_height)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h new file mode 100644 index 0000000000..ddaac10fac --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ + +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include + +namespace mindspore { +namespace dataset { + +class RandomCropAndResizeWithBBoxOp : public RandomCropAndResizeOp { + public: + // Constructor for RandomCropAndResizeWithBBoxOp, with default value and passing to base class constructor + RandomCropAndResizeWithBBoxOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, + float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, + float aspect_ub = kDefAspectUb, InterpolationMode interpolation = kDefInterpolation, + int32_t max_iter = kDefMaxIter) + : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub, interpolation, + max_iter) {} + + ~RandomCropAndResizeWithBBoxOp() override = default; + + void Print(std::ostream &out) const override { + out << "RandomCropAndResizeWithBBox: " << RandomCropAndResizeOp::target_height_ << " " + << RandomCropAndResizeOp::target_width_; + } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomCropAndResizeWithBBoxOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.cc new file mode 100644 index 0000000000..d62aebd37f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" +#include +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/kernels/image/decode_op.h" + +namespace mindspore { +namespace dataset { +RandomCropDecodeResizeOp::RandomCropDecodeResizeOp(int32_t target_height, int32_t target_width, float scale_lb, + float scale_ub, float aspect_lb, float aspect_ub, + InterpolationMode interpolation, int32_t max_iter) + : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub, interpolation, + max_iter) {} + +Status RandomCropDecodeResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + if (input == nullptr) { + RETURN_STATUS_UNEXPECTED("input tensor is null"); + } + if (!IsNonEmptyJPEG(input)) { + DecodeOp op(true); + std::shared_ptr decoded; + RETURN_IF_NOT_OK(op.Compute(input, &decoded)); + return RandomCropAndResizeOp::Compute(decoded, output); + } else { + struct jpeg_decompress_struct cinfo {}; + struct JpegErrorManagerCustom jerr {}; + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = JpegErrorExitCustom; + try { + jpeg_create_decompress(&cinfo); + JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes()); + (void)jpeg_read_header(&cinfo, TRUE); + jpeg_calc_output_dimensions(&cinfo); + } catch (std::runtime_error &e) { + jpeg_destroy_decompress(&cinfo); + RETURN_STATUS_UNEXPECTED(e.what()); + } + int h_in = cinfo.output_height; + int w_in = cinfo.output_width; + jpeg_destroy_decompress(&cinfo); + + int x = 0; + int y = 0; + int crop_height = 0; + int crop_width = 0; + (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); + + std::shared_ptr decoded; + RETURN_IF_NOT_OK(JpegCropAndDecode(input, &decoded, x, y, crop_width, crop_height)); + return Resize(decoded, output, target_height_, target_width_, 0.0, 0.0, interpolation_); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h new file mode 100644 index 0000000000..863fd48c14 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomCropDecodeResizeOp : public RandomCropAndResizeOp { + public: + RandomCropDecodeResizeOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, + float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb, + InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter); + + explicit RandomCropDecodeResizeOp(const RandomCropAndResizeOp &rhs) : RandomCropAndResizeOp(rhs) {} + + ~RandomCropDecodeResizeOp() override = default; + + void Print(std::ostream &out) const override { + out << "RandomCropDecodeResize: " << RandomCropAndResizeOp::target_height_ << " " + << RandomCropAndResizeOp::target_width_; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomCropDecodeResizeOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc new file mode 100644 index 0000000000..51772e9ec3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_crop_op.h" +#include +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t RandomCropOp::kDefPadTop = 0; +const int32_t RandomCropOp::kDefPadBottom = 0; +const int32_t RandomCropOp::kDefPadLeft = 0; +const int32_t RandomCropOp::kDefPadRight = 0; +const BorderType RandomCropOp::kDefBorderType = BorderType::kConstant; +const bool RandomCropOp::kDefPadIfNeeded = false; +const uint8_t RandomCropOp::kDefFillR = 0; +const uint8_t RandomCropOp::kDefFillG = 0; +const uint8_t RandomCropOp::kDefFillB = 0; + +RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_top, int32_t pad_bottom, + int32_t pad_left, int32_t pad_right, BorderType border_types, bool pad_if_needed, + uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) + : crop_height_(crop_height), + crop_width_(crop_width), + pad_top_(pad_top), + pad_bottom_(pad_bottom), + pad_left_(pad_left), + pad_right_(pad_right), + pad_if_needed_(pad_if_needed), + border_type_(border_types), + fill_r_(fill_r), + fill_g_(fill_g), + fill_b_(fill_b) { + rnd_.seed(GetSeed()); +} + +Status RandomCropOp::ImagePadding(const std::shared_ptr &input, std::shared_ptr *pad_image, + int32_t *t_pad_top, int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right, + int32_t *padded_image_w, int32_t *padded_image_h, bool *crop_further) { + *t_pad_top = pad_top_; + *t_pad_bottom = pad_bottom_; + *t_pad_left = pad_left_; + *t_pad_right = pad_right_; + + RETURN_IF_NOT_OK( + Pad(input, pad_image, pad_top_, pad_bottom_, pad_left_, pad_right_, border_type_, fill_r_, fill_g_, fill_b_)); + CHECK_FAIL_RETURN_UNEXPECTED((*pad_image)->shape().Size() >= 2, "Abnormal shape"); + + *padded_image_h = (*pad_image)->shape()[0]; + *padded_image_w = (*pad_image)->shape()[1]; + + if (*padded_image_h == crop_height_ && *padded_image_w == crop_width_) { + *crop_further = false; // no need for further crop + return Status::OK(); + } else if (pad_if_needed_) { + // check the dimensions of the image for padding, if we do need padding, then we change the pad values + if (*padded_image_h < crop_height_) { + RETURN_IF_NOT_OK(Pad(*pad_image, pad_image, crop_height_ - *padded_image_h, crop_height_ - *padded_image_h, 0, 0, + border_type_, fill_r_, fill_g_, fill_b_)); + + // update pad total above/below + t_pad_top += (crop_height_ - *padded_image_h); + t_pad_bottom += (crop_height_ - *padded_image_h); + } + if (*padded_image_w < crop_width_) { + RETURN_IF_NOT_OK(Pad(*pad_image, pad_image, 0, 0, crop_width_ - *padded_image_w, crop_width_ - *padded_image_w, + border_type_, fill_r_, fill_g_, fill_b_)); + // update pad total left/right + t_pad_left += (crop_width_ - *padded_image_w); + t_pad_right += (crop_width_ - *padded_image_w); + } + *padded_image_h = (*pad_image)->shape()[0]; + *padded_image_w = (*pad_image)->shape()[1]; + } + + if (*padded_image_h < crop_height_ || *padded_image_w < crop_width_ || crop_height_ == 0 || crop_width_ == 0) { + return Status(StatusCode::kShapeMisMatch, __LINE__, __FILE__, + "Crop size is greater than the image dimensions or is zero."); + } + return Status::OK(); +} + +void RandomCropOp::GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h) { + // GenCropPoints for cropping + *x = std::uniform_int_distribution(0, padded_image_w - crop_width_)(rnd_); + *y = std::uniform_int_distribution(0, padded_image_h - crop_height_)(rnd_); +} + +Status RandomCropOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + + // Apply padding first then crop + std::shared_ptr pad_image; + int32_t t_pad_top, t_pad_bottom, t_pad_left, t_pad_right; + int32_t padded_image_w; + int32_t padded_image_h; + bool crop_further = true; // whether image needs further cropping based on new size & requirements + + RETURN_IF_NOT_OK( // error code sent back directly + ImagePadding(input, &pad_image, &t_pad_top, &t_pad_bottom, &t_pad_left, &t_pad_right, &padded_image_w, + &padded_image_h, &crop_further)); + if (!crop_further) { + *output = pad_image; + return Status::OK(); + } + + int x, y; + GenRandomXY(&x, &y, padded_image_w, padded_image_h); + return Crop(pad_image, output, x, y, crop_width_, crop_height_); +} + +Status RandomCropOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out = TensorShape{crop_height_, crop_width_}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h new file mode 100644 index 0000000000..44f1789f9d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h @@ -0,0 +1,101 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomCropOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefPadTop; + static const int32_t kDefPadBottom; + static const int32_t kDefPadLeft; + static const int32_t kDefPadRight; + static const BorderType kDefBorderType; + static const bool kDefPadIfNeeded; + static const uint8_t kDefFillR; + static const uint8_t kDefFillG; + static const uint8_t kDefFillB; + + RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_top = kDefPadTop, + int32_t pad_bottom = kDefPadBottom, int32_t pad_left = kDefPadLeft, int32_t pad_right = kDefPadRight, + BorderType border_types = kDefBorderType, bool pad_if_needed = kDefPadIfNeeded, + uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); + + RandomCropOp(const RandomCropOp &rhs) = default; + + RandomCropOp(RandomCropOp &&rhs) = default; + + ~RandomCropOp() override = default; + + void Print(std::ostream &out) const override { out << "RandomCropOp: " << crop_height_ << " " << crop_width_; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + // Function breaks out the compute function's image padding functionality and makes available to other Ops + // Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op + // @param input: Input is the original Image + // @param pad_image: Pointer to new Padded image + // @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required + // @param t_pad_bottom: Total bottom Padding - Based on input and value calculated in function if required + // @param t_pad_left: Total left Padding - Based on input and value calculated in function if required + // @param t_pad_right: Total right Padding - Based on input and value calculated in function if required + // @param padded_image_w: Final Width of the 'pad_image' + // @param padded_image_h: Final Height of the 'pad_image' + // @param crop_further: Whether image required cropping after padding - False if new padded image matches required + // dimensions + Status ImagePadding(const std::shared_ptr &input, std::shared_ptr *pad_image, int32_t *t_pad_top, + int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right, int32_t *padded_image_w, + int32_t *padded_image_h, bool *crop_further); + + // Function breaks X,Y generation functionality out of original compute function and makes available to other Ops + void GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h); + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kRandomCropOp; } + + protected: + int32_t crop_height_ = 0; + int32_t crop_width_ = 0; + + private: + int32_t pad_top_ = 0; + int32_t pad_bottom_ = 0; + int32_t pad_left_ = 0; + int32_t pad_right_ = 0; + bool pad_if_needed_ = false; + BorderType border_type_; + uint8_t fill_r_ = 0; + uint8_t fill_g_ = 0; + uint8_t fill_b_ = 0; + std::mt19937 rnd_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.cc new file mode 100644 index 0000000000..08b12b8b70 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + BOUNDING_BOX_CHECK(input); + + std::shared_ptr pad_image; + int32_t t_pad_top, t_pad_bottom, t_pad_left, t_pad_right; + size_t boxCount = input[1]->shape()[0]; // number of rows + + int32_t padded_image_h; + int32_t padded_image_w; + + output->resize(2); + (*output)[1] = std::move(input[1]); // since some boxes may be removed + + bool crop_further = true; // Whether further cropping will be required or not, true unless required size matches + RETURN_IF_NOT_OK( // Error passed back to caller + RandomCropOp::ImagePadding(input[0], &pad_image, &t_pad_top, &t_pad_bottom, &t_pad_left, &t_pad_right, + &padded_image_w, &padded_image_h, &crop_further)); + + // update bounding boxes with new values based on relevant image padding + if (t_pad_left || t_pad_bottom) { + RETURN_IF_NOT_OK(PadBBoxes(&(*output)[1], boxCount, t_pad_left, t_pad_top)); + } + if (!crop_further) { + // no further cropping required + (*output)[0] = pad_image; + (*output)[1] = std::move(input[1]); + return Status::OK(); + } + + int x, y; + RandomCropOp::GenRandomXY(&x, &y, padded_image_w, padded_image_h); + int maxX = x + RandomCropOp::crop_width_; // max dims of selected CropBox on image + int maxY = y + RandomCropOp::crop_height_; + RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &boxCount, x, y, maxX, maxY)); + return Crop(pad_image, &(*output)[0], x, y, RandomCropOp::crop_width_, RandomCropOp::crop_height_); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h new file mode 100644 index 0000000000..bfcd1610d3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/kernels/image/random_crop_op.h" + +namespace mindspore { +namespace dataset { +class RandomCropWithBBoxOp : public RandomCropOp { + public: + // Constructor for RandomCropWithBBoxOp, with default value and passing to base class constructor + RandomCropWithBBoxOp(int32_t crop_height, int32_t crop_width, int32_t pad_top = kDefPadTop, + int32_t pad_bottom = kDefPadBottom, int32_t pad_left = kDefPadLeft, + int32_t pad_right = kDefPadRight, BorderType border_types = kDefBorderType, + bool pad_if_needed = kDefPadIfNeeded, uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, + uint8_t fill_b = kDefFillB) + : RandomCropOp(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, border_types, pad_if_needed, + fill_r, fill_g, fill_b) {} + + ~RandomCropWithBBoxOp() override = default; + + void Print(std::ostream &out) const override { + out << "RandomCropWithBBoxOp: " << RandomCropOp::crop_height_ << " " << RandomCropOp::crop_width_; + } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomCropWithBBoxOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.cc new file mode 100644 index 0000000000..5e8ab8a634 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float RandomHorizontalFlipOp::kDefProbability = 0.5; + +Status RandomHorizontalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (distribution_(rnd_)) { + return HorizontalFlip(input, output); + } + *output = input; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h new file mode 100644 index 0000000000..9e08929180 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomHorizontalFlipOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefProbability; + + explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) { + rnd_.seed(GetSeed()); + } + + ~RandomHorizontalFlipOp() override = default; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const RandomHorizontalFlipOp &so) { + so.Print(out); + return out; + } + + void Print(std::ostream &out) const override { out << "RandomHorizontalFlipOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomHorizontalFlipOp; } + + private: + std::mt19937 rnd_; + std::bernoulli_distribution distribution_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc new file mode 100644 index 0000000000..809f564b18 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/cv_tensor.h" + +namespace mindspore { +namespace dataset { +const float RandomHorizontalFlipWithBBoxOp::kDefProbability = 0.5; + +Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + BOUNDING_BOX_CHECK(input); + if (distribution_(rnd_)) { + // To test bounding boxes algorithm, create random bboxes from image dims + size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes + float img_center = (input[0]->shape()[1] / 2.); // get the center of the image + for (int i = 0; i < num_of_boxes; i++) { + float b_w = 0; // bounding box width + float min_x = 0; + // get the required items + RETURN_IF_NOT_OK(input[1]->GetItemAt(&min_x, {i, 0})); + RETURN_IF_NOT_OK(input[1]->GetItemAt(&b_w, {i, 2})); + // do the flip + float diff = img_center - min_x; // get distance from min_x to center + float refl_min_x = diff + img_center; // get reflection of min_x + float new_min_x = refl_min_x - b_w; // subtract from the reflected min_x to get the new one + RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 0}, new_min_x)); + } + (*output).resize(2); + // move input to output pointer of bounding boxes + (*output)[1] = std::move(input[1]); + // perform HorizontalFlip on the image + std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input[0])); + return HorizontalFlip(std::static_pointer_cast(input_cv), &(*output)[0]); + } + *output = input; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h new file mode 100644 index 0000000000..d98669ea13 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomHorizontalFlipWithBBoxOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefProbability; + + explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { + rnd_.seed(GetSeed()); + } + + ~RandomHorizontalFlipWithBBoxOp() override = default; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const RandomHorizontalFlipWithBBoxOp &so) { + so.Print(out); + return out; + } + + void Print(std::ostream &out) const override { out << "RandomHorizontalFlipWithBBoxOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomHorizontalFlipWithBBoxOp; } + + private: + std::mt19937 rnd_; + std::bernoulli_distribution distribution_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.cc new file mode 100644 index 0000000000..8736f0a6a5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_resize_op.h" + +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t RandomResizeOp::kDefTargetWidth = 0; + +Status RandomResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + // Randomly selects from the following four interpolation methods + // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area + interpolation_ = static_cast(distribution_(random_generator_)); + return ResizeOp::Compute(input, output); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h new file mode 100644 index 0000000000..8b2b067751 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomResizeOp : public ResizeOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefTargetWidth; + + explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) { + random_generator_.seed(GetSeed()); + } + + ~RandomResizeOp() = default; + + // Description: A function that prints info about the node + void Print(std::ostream &out) const override { + out << "RandomResizeOp: " << ResizeOp::size1_ << " " << ResizeOp::size2_; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomResizeOp; } + + private: + std::mt19937 random_generator_; + std::uniform_int_distribution distribution_{0, 3}; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.cc new file mode 100644 index 0000000000..e099b78a0f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t RandomResizeWithBBoxOp::kDefTargetWidth = 0; + +Status RandomResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { + // Randomly selects from the following four interpolation methods + // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area + interpolation_ = static_cast(distribution_(random_generator_)); + RETURN_IF_NOT_OK(ResizeWithBBoxOp::Compute(input, output)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h new file mode 100644 index 0000000000..6bad0d30fa --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H +#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefTargetWidth; + explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { + random_generator_.seed(GetSeed()); + } + + ~RandomResizeWithBBoxOp() = default; + + // Description: A function that prints info about the node + void Print(std::ostream &out) const override { + out << "RandomResizeWithBBoxOp: " << ResizeWithBBoxOp::size1_ << " " << ResizeWithBBoxOp::size2_; + } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomResizeWithBBoxOp; } + + private: + std::mt19937 random_generator_; + std::uniform_int_distribution distribution_{0, 3}; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc new file mode 100644 index 0000000000..b2cb4facae --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/random_rotation_op.h" + +#include + +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float RandomRotationOp::kDefCenterX = -1; +const float RandomRotationOp::kDefCenterY = -1; +const InterpolationMode RandomRotationOp::kDefInterpolation = InterpolationMode::kNearestNeighbour; +const bool RandomRotationOp::kDefExpand = false; +const uint8_t RandomRotationOp::kDefFillR = 0; +const uint8_t RandomRotationOp::kDefFillG = 0; +const uint8_t RandomRotationOp::kDefFillB = 0; + +// constructor +RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float center_x, float center_y, + InterpolationMode interpolation, bool expand, uint8_t fill_r, uint8_t fill_g, + uint8_t fill_b) + : degree_start_(start_degree), + degree_end_(end_degree), + center_x_(center_x), + center_y_(center_y), + interpolation_(interpolation), + expand_(expand), + fill_r_(fill_r), + fill_g_(fill_g), + fill_b_(fill_b) { + rnd_.seed(GetSeed()); +} + +// main function call for random rotation : Generate the random degrees +Status RandomRotationOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + float random_double = distribution_(rnd_); + // get the degree rotation range, mod by 360 because full rotation doesn't affect + // the way this op works (uniform distribution) + // assumption here is that mDegreesEnd > mDegreeStart so we always get positive number + // Note: the range technically is greater than 360 degrees, but will be halved + float degree_range = (degree_end_ - degree_start_) / 2; + float mid = (degree_end_ + degree_start_) / 2; + float degree = mid + random_double * degree_range; + + return Rotate(input, output, center_x_, center_y_, degree, interpolation_, expand_, fill_r_, fill_g_, fill_b_); +} +Status RandomRotationOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + int32_t outputH = -1, outputW = -1; + // if expand_, then we cannot know the shape. We need the input image to find the output shape --> set it to + // <-1,-1[,3]> + if (!expand_) { + outputH = inputs[0][0]; + outputW = inputs[0][1]; + } + TensorShape out = TensorShape{outputH, outputW}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h new file mode 100644 index 0000000000..ea679ccb56 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h @@ -0,0 +1,90 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { +class RandomRotationOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefCenterX; + static const float kDefCenterY; + static const InterpolationMode kDefInterpolation; + static const bool kDefExpand; + static const uint8_t kDefFillR; + static const uint8_t kDefFillG; + static const uint8_t kDefFillB; + + // Constructor for RandomRotationOp + // @param startDegree starting range for random degree + // @param endDegree ending range for random degree + // @param centerX x coordinate for center of image rotation + // @param centerY y coordinate for center of image rotation + // @param interpolation DE interpolation mode for rotation + // @param expand option for the output image shape to change + // @param fill_r R value for the color to pad with + // @param fill_g G value for the color to pad with + // @param fill_b B value for the color to pad with + // @details the randomly chosen degree is uniformly distributed + // @details the output shape, if changed, will contain the entire rotated image + // @note maybe using unsigned long int isn't the best here according to our coding rules + RandomRotationOp(float start_degree, float end_degree, float center_x = kDefCenterX, float center_y = kDefCenterY, + InterpolationMode interpolation = kDefInterpolation, bool expand = kDefExpand, + uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); + + ~RandomRotationOp() override = default; + + // Print function for RandomRotation + // @param out output stream to print to + void Print(std::ostream &out) const override { out << "RandomRotationOp: "; } + + // Overrides the base class compute function + // Calls the rotate function in ImageUtils, this function takes an input tensor + // and transforms its data using openCV, the output memory is manipulated to contain the result + // @return Status - The error code return + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kRandomRotationOp; } + + private: + float degree_start_; + float degree_end_; + float center_x_; + float center_y_; + InterpolationMode interpolation_; + bool expand_; + uint8_t fill_r_; + uint8_t fill_g_; + uint8_t fill_b_; + std::uniform_real_distribution distribution_{-1.0, 1.0}; + std::mt19937 rnd_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.cc new file mode 100644 index 0000000000..24d816ef1a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float RandomVerticalFlipOp::kDefProbability = 0.5; + +Status RandomVerticalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (distribution_(rnd_)) { + return VerticalFlip(input, output); + } + *output = input; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h new file mode 100644 index 0000000000..cee5869c71 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +class RandomVerticalFlipOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefProbability; + + explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) { + rnd_.seed(GetSeed()); + } + + ~RandomVerticalFlipOp() override = default; + + void Print(std::ostream &out) const override { out << "RandomVerticalFlipOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomVerticalFlipOp; } + + private: + std::mt19937 rnd_; + std::bernoulli_distribution distribution_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc new file mode 100644 index 0000000000..7d2fa7bab5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h" + +namespace mindspore { +namespace dataset { +const float RandomVerticalFlipWithBBoxOp::kDefProbability = 0.5; +Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + BOUNDING_BOX_CHECK(input); + + if (distribution_(rnd_)) { + dsize_t imHeight = input[0]->shape()[0]; + size_t boxCount = input[1]->shape()[0]; // number of rows in tensor + + // one time allocation -> updated in the loop + // type defined based on VOC test dataset + for (int i = 0; i < boxCount; i++) { + float boxCorner_y = 0.0, boxHeight = 0.0; + float newBoxCorner_y = 0.0; + RETURN_IF_NOT_OK(input[1]->GetItemAt(&boxCorner_y, {i, 1})); // get min y of bbox + RETURN_IF_NOT_OK(input[1]->GetItemAt(&boxHeight, {i, 3})); // get height of bbox + + // subtract (curCorner + height) from (max) for new Corner position + newBoxCorner_y = (imHeight - 1.0) - ((boxCorner_y + boxHeight) - 1.0); + RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); + } + + output->resize(2); + (*output)[1] = std::move(input[1]); + + return VerticalFlip(input[0], &(*output)[0]); + } + *output = input; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h new file mode 100644 index 0000000000..c9f19f5217 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +class RandomVerticalFlipWithBBoxOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefProbability; + // Constructor for RandomVerticalFlipWithBBoxOp + // @param probability: Probablity of Image flipping, 0.5 by default + explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { + rnd_.seed(GetSeed()); + } + + ~RandomVerticalFlipWithBBoxOp() override = default; + + void Print(std::ostream &out) const override { out << "RandomVerticalFlipWithBBoxOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomVerticalFlipWithBBoxOp; } + + private: + std::mt19937 rnd_; + std::bernoulli_distribution distribution_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc new file mode 100644 index 0000000000..2a500d6c34 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/rescale_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status RescaleOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return Rescale(input, output, rescale_, shift_); +} +Status RescaleOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + outputs[0] = DataType(DataType::DE_FLOAT32); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h new file mode 100644 index 0000000000..c70b7bf6cf --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RESCALE_OP_H_ +#define DATASET_KERNELS_IMAGE_RESCALE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RescaleOp : public TensorOp { + public: + RescaleOp(float rescale_ratio, float shift_ratio) : rescale_(rescale_ratio), shift_(shift_ratio) {} + + ~RescaleOp() override = default; + + void Print(std::ostream &out) const override { + out << "RescaleOp: shift: " << shift_ << ", Rescale: " << rescale_ << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kRescaleOp; } + + private: + float rescale_; + float shift_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_IMAGE_RESCALE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc new file mode 100644 index 0000000000..48a8fbbc53 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/resize_bilinear_op.h" +#include + +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t ResizeBilinearOp::kDefWidth = 0; + +void ResizeBilinearOp::Print(std::ostream &out) const { out << "ResizeBilinearOp: "; } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h new file mode 100644 index 0000000000..fd8f940946 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ +#define DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class ResizeBilinearOp : public ResizeOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefWidth; + + // Name: constructor + // Resizes the image to the output specified size using Bilinear interpolation. + // If only one value is provided, the it will resize the smaller size and maintains + // the aspect ratio. + // @param size1: the first size of output. If only this parameter is provided + // the smaller dimension will be resized to this and then the other dimension changes + // such that the aspect ratio is maintained. + // @param size2: the second size of output. If this is also provided, the output size + // will be (size1, size2) + explicit ResizeBilinearOp(int32_t size1, int32_t size2 = kDefWidth) + : ResizeOp(size1, size2, ResizeOp::kDefInterpolation) {} + + // Name: Destructor + // Description: Destructor + ~ResizeBilinearOp() = default; + + // Name: Print() + // Description: A function that prints info about the node + void Print(std::ostream &out) const override; + + std::string Name() const override { return kResizeBilinearOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.cc new file mode 100644 index 0000000000..7456f50f32 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/resize_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t ResizeOp::kDefWidth = 0; +const InterpolationMode ResizeOp::kDefInterpolation = InterpolationMode::kLinear; + +Status ResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape size " + std::to_string(input->shape().Size()) + + " of input tensor is invalid"); + int32_t output_h, output_w = 0; + int32_t input_h = static_cast(input->shape()[0]); + int32_t input_w = static_cast(input->shape()[1]); + if (size2_ == 0) { + if (input_h < input_w) { + CHECK_FAIL_RETURN_UNEXPECTED(input_h != 0, "The input height is 0"); + output_h = size1_; + output_w = static_cast(std::lround(static_cast(input_w) / input_h * output_h)); + } else { + CHECK_FAIL_RETURN_UNEXPECTED(input_w != 0, "The input width is 0"); + output_w = size1_; + output_h = static_cast(std::lround(static_cast(input_h) / input_w * output_w)); + } + } else { + output_h = size1_; + output_w = size2_; + } + return Resize(input, output, output_h, output_w, 0, 0, interpolation_); +} + +Status ResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + int32_t outputH = -1, outputW = -1; + // if size2_ == 0, then we cannot know the shape. We need the input image to find the output shape --> set it to + // <-1,-1[,3]> + if (size2_ != 0) { + outputH = size1_; + outputW = size2_; + } + TensorShape out = TensorShape{outputH, outputW}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h new file mode 100644 index 0000000000..3f847243ff --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RESIZE_OP_H_ +#define DATASET_KERNELS_IMAGE_RESIZE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class ResizeOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefWidth; + static const InterpolationMode kDefInterpolation; + + // Resizes the image to the output specified size. If only one value is provided, + // the it will resize the smaller size and maintains the aspect ratio. + // @param size1: the first size of output. If only this parameter is provided + // the smaller dimension will be resized to this and then the other dimension changes + // such that the aspect ratio is maintained. + // @param size2: the second size of output. If this is also provided, the output size + // will be (size1, size2) + // @param InterpolationMode: the interpolation mode being used. + explicit ResizeOp(int32_t size1, int32_t size2 = kDefWidth, InterpolationMode mInterpolation = kDefInterpolation) + : size1_(size1), size2_(size2), interpolation_(mInterpolation) {} + + ResizeOp(const ResizeOp &rhs) = default; + + ResizeOp(ResizeOp &&rhs) = default; + + ~ResizeOp() override = default; + + void Print(std::ostream &out) const override { out << "ResizeOp: " << size1_ << " " << size2_; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kResizeOp; } + + protected: + int32_t size1_; + int32_t size2_; + InterpolationMode interpolation_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc new file mode 100644 index 0000000000..9df2d8a25e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include +#include +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/pybind_support.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +Status ResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + BOUNDING_BOX_CHECK(input); + + int32_t input_h = input[0]->shape()[0]; + int32_t input_w = input[0]->shape()[1]; + + output->resize(2); + (*output)[1] = std::move(input[1]); // move boxes over to output + + std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input[0])); + + RETURN_IF_NOT_OK(ResizeOp::Compute(std::static_pointer_cast(input_cv), &(*output)[0])); + + int32_t output_h = (*output)[0]->shape()[0]; // output height if ResizeWithBBox + int32_t output_w = (*output)[0]->shape()[1]; // output width if ResizeWithBBox + + size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor + RETURN_IF_NOT_OK(UpdateBBoxesForResize((*output)[1], bboxCount, output_w, output_h, input_w, input_h)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h new file mode 100644 index 0000000000..d2b5c96bf3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H +#define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H + +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/kernels/image/resize_op.h" + +namespace mindspore { +namespace dataset { +class ResizeWithBBoxOp : public ResizeOp { + public: + // Constructor for ResizeWithBBoxOp, with default value and passing to base class constructor + explicit ResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefWidth, + InterpolationMode mInterpolation = kDefInterpolation) + : ResizeOp(size_1, size_2, mInterpolation) {} + + ~ResizeWithBBoxOp() override = default; + + void Print(std::ostream &out) const override { out << "ResizeWithBBoxOp: " << size1_ << " " << size2_; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kResizeWithBBoxOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.cc new file mode 100644 index 0000000000..95d75af0f2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include +#include "minddata/dataset/kernels/image/uniform_aug_op.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +const int UniformAugOp::kDefNumOps = 2; + +UniformAugOp::UniformAugOp(std::vector> op_list, int32_t num_ops) + : tensor_op_list_(op_list), num_ops_(num_ops) { + rnd_.seed(GetSeed()); +} + +// compute method to apply uniformly random selected augmentations from a list +Status UniformAugOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + + // randomly select ops to be applied + std::vector> selected_tensor_ops; + std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_); + + bool first = true; + for (const auto &tensor_op : selected_tensor_ops) { + // Do NOT apply the op, if second random generator returned zero + if (std::uniform_int_distribution(0, 1)(rnd_)) { + continue; + } + // apply C++ ops (note: python OPs are not accepted) + if (first) { + RETURN_IF_NOT_OK(tensor_op->Compute(input, output)); + first = false; + } else { + RETURN_IF_NOT_OK(tensor_op->Compute(std::move(*output), output)); + } + } + + // The case where no tensor op is applied. + if (output->empty()) { + *output = input; + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h new file mode 100644 index 0000000000..0ae0fda92b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ +#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class UniformAugOp : public TensorOp { + public: + // Default number of Operations to be applied + static const int kDefNumOps; + + // Constructor for UniformAugOp + // @param std::vector> op_list: list of candidate C++ operations + // @param int32_t num_ops: number of augemtation operations to applied + UniformAugOp(std::vector> op_list, int32_t num_ops); + + // Destructor + ~UniformAugOp() override = default; + + void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; } + + // Overrides the base class compute function + // @return Status - The error code return + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kUniformAugOp; } + + private: + int32_t num_ops_; + std::vector> tensor_op_list_; + std::mt19937 rnd_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/no_op.h b/mindspore/ccsrc/minddata/dataset/kernels/no_op.h new file mode 100644 index 0000000000..f5a6a58f2b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/no_op.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_NO_OP_H_ +#define DATASET_KERNELS_NO_OP_H_ + +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class NoOp : public TensorOp { + public: + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override { + *output = input; + return Status::OK(); + } + + void Print(std::ostream &out) const override { out << "NoOp"; }; + + std::string Name() const override { return kNoOp; } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_NO_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc new file mode 100644 index 0000000000..f501dd4b4f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/py_func_op.h" + +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + Status ret = Status(StatusCode::kOK, "PyFunc Call Succeed"); + { + // Acquire Python GIL + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + ret = Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + goto ComputeReturn; + } + try { + // Transform input tensor vector into numpy array vector + py::tuple input_args(input.size()); + for (size_t i = 0; i < input.size(); i++) { + py::array new_data; + RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); + // possible memcpy here + input_args[i] = new_data; + } + // Invoke python function + py::object ret_py_obj = this->py_func_ptr_(*input_args); + // Process the return value + if (py::isinstance(ret_py_obj)) { + // In case of a n-1 mapping, the return value will be a numpy array + std::shared_ptr out; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast())); + output->push_back(out); + } else if (py::isinstance(ret_py_obj)) { + // In case of a n-m mapping, the return value will be a tuple of numpy arrays + py::tuple ret_py_tuple = ret_py_obj.cast(); + // Iterate over two containers simultaneously for memory copy + for (size_t i = 0; i < ret_py_tuple.size(); i++) { + py::object ret_py_ele = ret_py_tuple[i]; + if (!py::isinstance(ret_py_ele)) { + goto ShapeMisMatch; + } + std::shared_ptr out; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast())); + output->push_back(out); + } + } else { + goto ShapeMisMatch; + } + } catch (const py::error_already_set &e) { + ret = Status(StatusCode::kPyFuncException, e.what()); + } + } + +ComputeReturn: + return ret; + +ShapeMisMatch: + ret = Status(StatusCode::kShapeMisMatch, "PyFunc should return a numpy array or a numpy array tuple"); + goto ComputeReturn; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h new file mode 100644 index 0000000000..75d222b433 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_KERNELS_PY_FUNC_OP_H_ +#define DATASET_KERNELS_PY_FUNC_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class __attribute__((visibility("hidden"))) PyFuncOp : public TensorOp { + public: + explicit PyFuncOp(py::function func) : py_func_ptr_(std::move(func)) {} + + ~PyFuncOp() override = default; + + uint32_t NumInput() override { return 0; } + uint32_t NumOutput() override { return 0; } + + // Compute function for n-n mapping. + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kPyFuncOp; } + + private: + py::function py_func_ptr_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_PY_FUNC_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc new file mode 100644 index 0000000000..b625e3b532 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/tensor_op.h" +#include +#include +#include +#include + +namespace mindspore { +namespace dataset { +// Name: Compute() +// Description: This Compute() take 1 Tensor and produce 1 Tensor. +// The derived class should override this function otherwise error. +Status TensorOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (!OneToOne()) { + return Status(StatusCode::kUnexpectedError, "Wrong Compute() function is called. This is not 1-1 TensorOp."); + } else { + return Status(StatusCode::kUnexpectedError, + "Is this TensorOp 1-1? If yes, please implement this Compute() in the derived class."); + } +} + +// Name: Compute() +// Description: This Compute() take multiple Tensors from different columns and produce multiple Tensors too. +// The derived class should override this function otherwise error. +Status TensorOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + if (OneToOne()) { + output->resize(1); + return Compute(input[0], &(*output)[0]); + } + + return Status(StatusCode::kUnexpectedError, + "Is this TensorOp oneToOne? If no, please implement this Compute() in the derived class."); +} + +void TensorOp::Print(std::ostream &out) const { out << "TensorOp" << std::endl; } + +Status TensorOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + if (inputs.size() != NumInput()) + return Status(StatusCode::kUnexpectedError, + "The size of the input argument vector does not match the number of inputs"); + outputs = inputs; + return Status::OK(); +} + +Status TensorOp::OutputType(const std::vector &inputs, std::vector &outputs) { + if (inputs.size() != NumInput()) + return Status(StatusCode::kUnexpectedError, + "The size of the input argument vector does not match the number of inputs"); + outputs = inputs; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h new file mode 100644 index 0000000000..3bcba4b463 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -0,0 +1,212 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_TENSOR_OP_H_ +#define DATASET_KERNELS_TENSOR_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/util/status.h" + +#define IO_CHECK(input, output) \ + do { \ + if (input == nullptr || output == nullptr) { \ + RETURN_STATUS_UNEXPECTED("input or output is null."); \ + } \ + } while (false) + +#define IO_CHECK_VECTOR(input, output) \ + do { \ + if (output == nullptr) { \ + RETURN_STATUS_UNEXPECTED("output is null."); \ + } \ + for (auto &_i : input) { \ + if (_i == nullptr) { \ + RETURN_STATUS_UNEXPECTED("input is null."); \ + } \ + } \ + } while (false) + +#define BOUNDING_BOX_CHECK(input) \ + do { \ + if (input.size() != 2) { \ + return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ + "Requires Image and Bounding Boxes, likely missed bounding boxes."); \ + } \ + if (input[1]->shape().Size() < 2) { \ + return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ + "Bounding boxes shape should have at least two dimensions."); \ + } \ + uint32_t num_of_features = input[1]->shape()[1]; \ + if (num_of_features < 4) { \ + return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ + "Bounding boxes should be have at least 4 features."); \ + } \ + uint32_t num_of_boxes = input[1]->shape()[0]; \ + uint32_t img_h = input[0]->shape()[0]; \ + uint32_t img_w = input[0]->shape()[1]; \ + for (uint32_t i = 0; i < num_of_boxes; i++) { \ + float min_x = 0.0, min_y = 0.0, b_w = 0.0, b_h = 0.0; \ + bool passing_data_fetch = true; \ + passing_data_fetch &= input[1]->GetItemAt(&min_x, {i, 0}).IsOk(); \ + passing_data_fetch &= input[1]->GetItemAt(&min_y, {i, 1}).IsOk(); \ + passing_data_fetch &= input[1]->GetItemAt(&b_w, {i, 2}).IsOk(); \ + passing_data_fetch &= input[1]->GetItemAt(&b_h, {i, 3}).IsOk(); \ + if (!passing_data_fetch) { \ + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, \ + "Fetching BBox values failed in BOUNDING_BOX_CHECK."); \ + } \ + if ((min_x + b_w > img_w) || (min_y + b_h > img_h)) { \ + return Status(StatusCode::kBoundingBoxOutOfBounds, __LINE__, __FILE__, \ + "At least one of the bounding boxes is out of bounds of the image."); \ + } \ + if (static_cast(min_x) < 0 || static_cast(min_y) < 0) { \ + return Status(StatusCode::kBoundingBoxOutOfBounds, __LINE__, __FILE__, \ + "At least one of the bounding boxes has negative min_x or min_y."); \ + } \ + } \ + } while (false) + +namespace mindspore { +namespace dataset { + +// image +constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; +constexpr char kDecodeOp[] = "DecodeOp"; +constexpr char kCenterCropOp[] = "CenterCropOp"; +constexpr char kCutOutOp[] = "CutOutOp"; +constexpr char kHwcToChwOp[] = "HwcToChwOp"; +constexpr char kNormalizeOp[] = "NormalizeOp"; +constexpr char kPadOp[] = "PadOp"; +constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; +constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp"; +constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp"; +constexpr char kRandomCropDecodeResizeOp[] = "RandomCropDecodeResizeOp"; +constexpr char kRandomCropOp[] = "RandomCropOp"; +constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp"; +constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp"; +constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp"; +constexpr char kRandomResizeOp[] = "RandomResizeOp"; +constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp"; +constexpr char kRandomRotationOp[] = "RandomRotationOp"; +constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp"; +constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp"; +constexpr char kRescaleOp[] = "RescaleOp"; +constexpr char kResizeBilinearOp[] = "ResizeBilinearOp"; +constexpr char kResizeOp[] = "ResizeOp"; +constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; +constexpr char kUniformAugOp[] = "UniformAugOp"; + +// text +constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp"; +constexpr char kBertTokenizerOp[] = "BertTokenizerOp"; +constexpr char kCaseFoldOp[] = "CaseFoldOp"; +constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; +constexpr char kLookupOp[] = "LookupOp"; +constexpr char kNgramOp[] = "NgramOp"; +constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; +constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; +constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; +constexpr char kToNumberOp[] = "ToNumberOp"; +constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp"; +constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp"; +constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp"; +constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp"; +constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp"; + +// data +constexpr char kConcatenateOp[] = "kConcatenateOp"; +constexpr char kDuplicateOp[] = "DuplicateOp"; +constexpr char kFillOp[] = "FillOp"; +constexpr char kMaskOp[] = "MaskOp"; +constexpr char kOneHotOp[] = "OneHotOp"; +constexpr char kPadEndOp[] = "PadEndOp"; +constexpr char kSliceOp[] = "SliceOp"; +constexpr char kToFloat16Op[] = "ToFloat16Op"; +constexpr char kTypeCastOp[] = "TypeCastOp"; + +// other +constexpr char kPyFuncOp[] = "PyFuncOp"; +constexpr char kNoOp[] = "NoOp"; + +// A class that does a computation on a Tensor +class TensorOp { + public: + TensorOp() = default; + + virtual ~TensorOp() = default; + + // A function that prints info about the tensor operation + // @param out + virtual void Print(std::ostream &out) const; + + // Provide stream operator for displaying it + // @param output stream + // @param so the TensorOp object to be printed + // @return output stream + friend std::ostream &operator<<(std::ostream &out, const TensorOp &so) { + so.Print(out); + return out; + } + + // Perform an operation on one Tensor and produce one Tensor. This is for 1-to-1 column MapOp + // @param input shares the ownership of the Tensor (increase the ref count). + // @param output the address to a shared_ptr where the result will be placed. + // @return Status + virtual Status Compute(const std::shared_ptr &input, std::shared_ptr *output); + + // Perform an operation on Tensors from multiple columns, and produce multiple Tensors. + // This is for m-to-n column MapOp. + // @param input is a vector of shared_ptr to Tensor (pass by const reference). + // @param output is the address to an empty vector of shared_ptr to Tensor. + // @return Status + virtual Status Compute(const TensorRow &input, TensorRow *output); + + // Returns true oif the TensorOp takes one input and returns one output. + // @return true/false + bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } + + // Function to determine the number of inputs the TensorOp can take. 0: means undefined. + // @return uint32_t + virtual uint32_t NumInput() { return 1; } + + // Function to determine the number of output the TensorOp generates. 0: means undefined. + // @return uint32_t + virtual uint32_t NumOutput() { return 1; } + + // Function to determine the shapes of the output tensor given the input tensors' shapes. + // If a subclass did not override this function, it means that the shape does not change. + // @param inputs in: vector of the shapes of the input tensors. + // @param outputs out: vector of the shapes of the output tensors to be filled. + // @return Status + virtual Status OutputShape(const std::vector &inputs, std::vector &outputs); + + // Function to determine the types of the output tensor given the input tensor's types. + // If a subclass did not override this function, it means that the type does not change. + // @param inputs in: vector of the types of the input tensors. + // @param outputs out: vector of the types of the output tensors to be filled. + // @return Status + virtual Status OutputType(const std::vector &inputs, std::vector &outputs); + + virtual std::string Name() const = 0; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_TENSOR_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/text/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/text/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/text/CMakeLists.txt diff --git a/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc new file mode 100644 index 0000000000..6195572944 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc @@ -0,0 +1,173 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/basic_tokenizer_op.h" +#include +#include +#include +#include +#include +#include + +#include "unicode/errorcode.h" +#include "unicode/normalizer2.h" +#include "unicode/utypes.h" + +namespace mindspore { +namespace dataset { + +const bool BasicTokenizerOp::kDefLowerCase = false; +const bool BasicTokenizerOp::kDefKeepWhitespace = false; +const NormalizeForm BasicTokenizerOp::kDefNormalizationForm = NormalizeForm::kNone; +const bool BasicTokenizerOp::kDefPreserveUnusedToken = true; +const bool BasicTokenizerOp::kDefWithOffsets = false; +const char BasicTokenizerOp::kCommonPattern[] = + "[!-/]" + "|[:-@]" + "|[\\[-`]" + "|[{-~]" + "|[\\p{P}]" + "|[\\x{4E00}-\\x{9FFF}]" + "|[\\x{3400}-\\x{4DBF}]" + "|[\\x{20000}-\\x{2A6DF}]" + "|[\\x{2A700}-\\x{2B73F}]" + "|[\\x{2B740}-\\x{2B81F}]" + "|[\\x{2B820}-\\x{2CEAF}]" + "|[\\x{F900}-\\x{FAFF}]" + "|[\\x{2F800}-\\x{2FA1F}]"; +const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|\\[unused\\d+\\]|"; +const std::unordered_set BasicTokenizerOp::kUnusedWords{"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"}; + +BasicTokenizerOp::BasicTokenizerOp(const bool &lower_case, const bool &keep_whitespace, + const NormalizeForm &normalization_form, const bool &preserve_unused_token, + const bool &with_offsets) + : lower_case_(lower_case), + keep_whitespace_(keep_whitespace), + preserve_unused_token_(preserve_unused_token), + with_offsets_(with_offsets), + case_fold_(std::make_unique()), + nfd_normalize_(std::make_unique(NormalizeForm::kNfd)), + normalization_form_(normalization_form), + common_normalize_(std::make_unique(normalization_form)), + replace_accent_chars_(std::make_unique("\\p{Mn}", "")), + replace_control_chars_(std::make_unique("\\p{Cc}|\\p{Cf}", " ")) { + std::string delim_pattern = std::string("\\s+|") + kCommonPattern; + std::string keep_delim_pattern; + if (keep_whitespace_) { + keep_delim_pattern = delim_pattern; + } else { + keep_delim_pattern = kCommonPattern; + } + if (preserve_unused_token_) { + keep_delim_pattern = kUnusedPattern + keep_delim_pattern; + delim_pattern = kUnusedPattern + delim_pattern; + } + regex_tokenizer_ = std::make_unique(delim_pattern, keep_delim_pattern, with_offsets_); +} + +Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text, + const std::unordered_set &unused_words, + std::string *outupt) { + icu::ErrorCode error; + const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); + outupt->clear(); + + // 1. get start and end offsets of not case fold strs + std::queue> offsets; // offsets of not used words + int start = -1; + int len = 0; + for (int i = 0; i < text.length(); i++) { + if (text[i] == '[') { + start = i; + ++len; + } else if (text[i] == ']' && start >= 0) { + ++len; + std::string word(text.substr(start, len)); + if (unused_words.find(word) != unused_words.end()) { + offsets.push(std::make_pair(start, start + len - 1)); + } + start = -1; + len = 0; + } else if (start >= 0) { + ++len; + } + } + + // 2. Do not apply case fold on `unused_words` + start = 0; + for (int i = 0; i < text.length();) { + std::string_view process_text; + std::string preserve_token; + if (offsets.empty()) { + i = text.length(); + process_text = text.substr(start, i - start); + } else { + preserve_token = text.substr(offsets.front().first, offsets.front().second - offsets.front().first + 1); + process_text = text.substr(start, offsets.front().first - start); + i = offsets.front().second + 1; + offsets.pop(); + } + std::string temp; + icu::StringByteSink sink(&temp); + nfkc_case_fold->normalizeUTF8(0, icu::StringPiece(process_text.data(), process_text.size()), sink, nullptr, error); + *outupt += temp + preserve_token; + } + return Status::OK(); +} + +Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr &input, + std::shared_ptr *output) { + IO_CHECK(input, output); + std::vector strs(input->Size()); + int i = 0; + for (auto iter = input->begin(); iter != input->end(); iter++) { + RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++])); + } + *output = std::make_shared(std::move(strs), input->shape()); + return Status::OK(); +} + +Status BasicTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::shared_ptr cur_input; + std::shared_ptr processed_tensor; + if (lower_case_) { + if (!preserve_unused_token_) { + // to lower case + RETURN_IF_NOT_OK(case_fold_->Compute(input[0], &processed_tensor)); + } else { + // to lower case except words in kUnusedWords + RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(input[0], &processed_tensor)); + } + cur_input = processed_tensor; + // strip accent characters + RETURN_IF_NOT_OK(nfd_normalize_->Compute(cur_input, &processed_tensor)); + cur_input = processed_tensor; + RETURN_IF_NOT_OK(replace_accent_chars_->Compute(cur_input, &processed_tensor)); + } else { + RETURN_IF_NOT_OK(common_normalize_->Compute(input[0], &processed_tensor)); + } + // strip control characters + cur_input = processed_tensor; + RETURN_IF_NOT_OK(replace_control_chars_->Compute(cur_input, &processed_tensor)); + return regex_tokenizer_->Compute(TensorRow(0, {std::move(processed_tensor)}), output); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h new file mode 100644 index 0000000000..cbc21273c2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/text/kernels/case_fold_op.h" +#include "minddata/dataset/text/kernels/normalize_utf8_op.h" +#include "minddata/dataset/text/kernels/regex_replace_op.h" +#include "minddata/dataset/text/kernels/regex_tokenizer_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class BasicTokenizerOp : public TensorOp { + public: + static const bool kDefLowerCase; + static const bool kDefKeepWhitespace; + static const NormalizeForm kDefNormalizationForm; + static const bool kDefPreserveUnusedToken; + static const bool kDefWithOffsets; + + explicit BasicTokenizerOp(const bool &lower_case = kDefLowerCase, const bool &keep_whitespace = kDefKeepWhitespace, + const NormalizeForm &normalization_form = kDefNormalizationForm, + const bool &preserve_unused_token = kDefPreserveUnusedToken, + const bool &with_offsets = kDefWithOffsets); + + ~BasicTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "BasicTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + protected: + Status CaseFoldWithoutUnusedWords(const std::string_view &text, const std::unordered_set &unused_words, + std::string *outupt); + Status CaseFoldWithoutUnusedWords(const std::shared_ptr &input, std::shared_ptr *output); + + std::string Name() const override { return kBasicTokenizerOp; } + + private: + static const char kCommonPattern[]; + static const char kUnusedPattern[]; + static const std::unordered_set kUnusedWords; + bool with_offsets_; + bool lower_case_; + bool keep_whitespace_; + NormalizeForm normalization_form_; + bool preserve_unused_token_; + std::unique_ptr case_fold_; + std::unique_ptr nfd_normalize_; + std::unique_ptr common_normalize_; + std::unique_ptr replace_accent_chars_; + std::unique_ptr replace_control_chars_; + std::unique_ptr regex_tokenizer_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.cc new file mode 100644 index 0000000000..631597ba24 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/bert_tokenizer_op.h" +namespace mindspore { +namespace dataset { +Status BertTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + TensorRow basic_tensor; + RETURN_IF_NOT_OK(basic_tokenizer_.Compute(input, &basic_tensor)); + RETURN_IF_NOT_OK(wordpiece_tokenizer_.Compute(basic_tensor, output)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h new file mode 100644 index 0000000000..b281903349 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/text/kernels/basic_tokenizer_op.h" +#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class BertTokenizerOp : public TensorOp { + public: + explicit BertTokenizerOp(const std::shared_ptr &vocab, + const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator, + const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken, + const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken, + const bool &lower_case = BasicTokenizerOp::kDefLowerCase, + const bool &keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace, + const NormalizeForm &normalization_form = BasicTokenizerOp::kDefNormalizationForm, + const bool &preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken, + const bool &with_offsets = WordpieceTokenizerOp::kDefWithOffsets) + : wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets), + basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token, with_offsets) {} + + ~BertTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "BertTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kBertTokenizerOp; } + + private: + WordpieceTokenizerOp wordpiece_tokenizer_; + BasicTokenizerOp basic_tokenizer_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc new file mode 100644 index 0000000000..0ea5cadedb --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/case_fold_op.h" +#include +#include +#include +#include +#include + +#include "unicode/errorcode.h" +#include "unicode/normalizer2.h" +#include "unicode/utypes.h" + +namespace mindspore { +namespace dataset { + +Status CaseFoldOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + icu::ErrorCode error; + const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); + std::vector strs(input->Size()); + int i = 0; + for (auto iter = input->begin(); iter != input->end(); iter++) { + icu::StringByteSink sink(&strs[i++]); + nfkc_case_fold->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); + } + *output = std::make_shared(std::move(strs), input->shape()); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h new file mode 100644 index 0000000000..f7a2105269 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ +#define DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class CaseFoldOp : public TensorOp { + public: + CaseFoldOp() {} + + ~CaseFoldOp() override = default; + + void Print(std::ostream &out) const override { out << "CaseFoldOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kCaseFoldOp; } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc new file mode 100644 index 0000000000..0a1ae92d14 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" + +#include +#include +#include +#include "minddata/dataset/util/path.h" + +namespace mindspore { +namespace dataset { + +const bool JiebaTokenizerOp::kDefWithOffsets = false; + +JiebaTokenizerOp::JiebaTokenizerOp(const std::string &hmm_path, const std::string &dict_path, const JiebaMode &mode, + const bool &with_offsets) + : jieba_mode_(mode), hmm_model_path_(hmm_path), mp_dict_path_(dict_path), with_offsets_(with_offsets) { + jieba_parser_ = std::make_unique(mp_dict_path_, hmm_model_path_, ""); +} + +Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + RETURN_UNEXPECTED_IF_NULL(jieba_parser_); + + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor"); + } + + std::string_view sentence_v; + RETURN_IF_NOT_OK(input[0]->GetItemAt(&sentence_v, {})); + std::string sentence{sentence_v}; + std::vector words; + std::vector offsets_start, offsets_limit; + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + if (sentence == "") { + words.push_back(""); + } else { + std::vector tmp; + if (jieba_mode_ == JiebaMode::kMp) { + std::unique_ptr mp_seg = std::make_unique(jieba_parser_->GetDictTrie()); + mp_seg->Cut(sentence, tmp, MAX_WORD_LENGTH); + } else if (jieba_mode_ == JiebaMode::kHmm) { + std::unique_ptr hmm_seg = + std::make_unique(jieba_parser_->GetHMMModel()); + hmm_seg->Cut(sentence, tmp); + } else { // Mix + std::unique_ptr mix_seg = + std::make_unique(jieba_parser_->GetDictTrie(), jieba_parser_->GetHMMModel()); + mix_seg->Cut(sentence, tmp, true); + } + GetStringsFromWords(tmp, words); + for (auto item : tmp) { + offsets_start.push_back(static_cast(item.offset)); + offsets_limit.push_back(static_cast(item.offset + item.word.length())); + } + } + token_tensor = std::make_shared(words, TensorShape({(dsize_t)words.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} + +Status JiebaTokenizerOp::AddWord(const std::string &word, int freq) { + RETURN_UNEXPECTED_IF_NULL(jieba_parser_); + if (jieba_parser_->InsertUserWord(word, freq, "") == false) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "add word error"); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h new file mode 100644 index 0000000000..4e49891c00 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_ +#define DATASET_ENGINE_TEXT_JIEBA_OP_H_ + +#include +#include + +#include "cppjieba/Jieba.hpp" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 }; + +class JiebaTokenizerOp : public TensorOp { + public: + // default constant for Jieba MPSegment algorithm. + static constexpr size_t MAX_WORD_LENGTH = 512; + // default const for set whether Jieba output offsets tensor. + static const bool kDefWithOffsets; + // Constructor for JiebaTokenizerOp. + // @param hmm_path HMM model file. + // @param mp_path MP model file. + // @mode tokenization mode [Default "MIX"], "MP" model will tokenize with MPSegment algorithm, "HMM" mode will + // tokenize with Hiddel Markov Model Segment algorithm, "MIx" model will tokenize with a mix of MPSegment and + // HMMSegment algorithm. + // @with_offsets user set this value to choose whether output offset tensor. + JiebaTokenizerOp(const std::string &hmm_path, const std::string &mp_path, const JiebaMode &mode = JiebaMode::kMix, + const bool &with_offsets = kDefWithOffsets); + ~JiebaTokenizerOp() override = default; + + void Print(std::ostream &out) const override { + out << "JiebaTokenizerOp: " << jieba_mode_ << "hmm_model_path_ " << hmm_model_path_ << "mp_dict_path_" + << mp_dict_path_; + } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + // @word the word to be added to the JiebaTokenizer. + // @freq [Default 0] the frequency fo the word to be added. + // @tag [Default ""] the tag of the word to be added. + Status AddWord(const std::string &word, int freq = 0); + + std::string Name() const override { return kJiebaTokenizerOp; } + + protected: + std::string hmm_model_path_; + std::string mp_dict_path_; + std::unique_ptr jieba_parser_; + JiebaMode jieba_mode_; + bool with_offsets_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_TEXT_JIEBA_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc new file mode 100644 index 0000000000..02b75bc4f9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/lookup_op.h" + +#include + +namespace mindspore { +namespace dataset { + +LookupOp::LookupOp(std::shared_ptr vocab, WordIdType default_id) + : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {} + +Status LookupOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + RETURN_UNEXPECTED_IF_NULL(vocab_); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor."); + std::vector word_ids; + word_ids.reserve(input->Size()); + for (auto itr = input->begin(); itr != input->end(); itr++) { + WordIdType word_id = vocab_->Lookup(std::string(*itr)); + word_ids.emplace_back(word_id == Vocab::kNoTokenExists ? default_id_ : word_id); + CHECK_FAIL_RETURN_UNEXPECTED( + word_ids.back() != Vocab::kNoTokenExists, + "Lookup Error: token" + std::string(*itr) + "doesn't exist in vocab and no unknown token is specified."); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_, + reinterpret_cast(word_ids.data()))); + return Status::OK(); +} +Status LookupOp::OutputType(const std::vector &inputs, std::vector &outputs) { + CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match"); + CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type"); + outputs[0] = type_; + return Status::OK(); +} + +void LookupOp::Print(std::ostream &out) const { + out << "LookupOp: " + << "type: " << type_ << "\n default lookup id: " << default_id_ << "\n"; +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h new file mode 100644 index 0000000000..4efc64321b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_KERNELS_LOOKUP_OP_H_ +#define DATASET_TEXT_KERNELS_LOOKUP_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/text/vocab.h" + +namespace mindspore { +namespace dataset { +class LookupOp : public TensorOp { + public: + // constructor for lookup, takes in a vocab object + // @param std::shared_ptr vocab - + // @param WordIdType default_id, id to lookup if a word is not in vocab + explicit LookupOp(std::shared_ptr vocab, WordIdType default_id = 1); + + ~LookupOp() = default; + + // perform actual lookup on each tensor + // @param const std::shared_ptr &input + // @param std::shared_ptr *output + // @return error code + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + // print method + // @param std::ostream out + void Print(std::ostream &out) const override; + + // @param std::vector &inputs - + // @param std::vector &outputs - + // @return error code + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kLookupOp; } + + private: + std::shared_ptr vocab_; + WordIdType default_id_; + DataType type_; // type of tensor after lookup +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TEXT_KERNELS_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc new file mode 100644 index 0000000000..36781b9b4d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/kernels/ngram_op.h" + +#include +#include +#include +#include + +namespace mindspore { +namespace dataset { + +NgramOp::NgramOp(const std::vector &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad, + const std::string &r_pad, const std::string &separator) + : ngrams_(ngrams), + l_len_(l_len), + r_len_(r_len), + l_pad_with_sp_(l_pad + separator), + r_pad_with_sp_(r_pad + separator), + separator_(separator) {} + +Status NgramOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING && input->Rank() == 1, "Not a 1-D str Tensor"); + std::vector offsets; // offsets for each str + std::vector res; // holds the result of ngrams + std::string str_buffer; // concat all pad tokens with string interleaved with separators + res.reserve(input->shape().NumOfElements()); // this should be more than enough + offsets.reserve(1 + l_len_ + r_len_ + input->shape().NumOfElements()); + str_buffer.reserve(l_pad_with_sp_.size() * l_len_ + r_pad_with_sp_.size() * r_len_ + input->SizeInBytes()); + offsets.push_back(str_buffer.size()); // insert 0 as the starting pos + for (int i = 0; i < l_len_; i++) offsets.push_back((str_buffer += l_pad_with_sp_).size()); + + for (auto itr = input->begin(); itr != input->end(); itr++) { + str_buffer += (*itr); + str_buffer += separator_; + offsets.push_back(str_buffer.size()); + } + + for (int i = 0; i < r_len_; i++) offsets.push_back((str_buffer += r_pad_with_sp_).size()); + + for (auto n : ngrams_) { + CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "n gram needs to be a positive number.\n"); + int32_t start_ind = l_len_ - std::min(l_len_, n - 1); + int32_t end_ind = offsets.size() - r_len_ + std::min(r_len_, n - 1); + if (end_ind - start_ind <= n) { + res.emplace_back(std::string()); // push back empty string + } else { + CHECK_FAIL_RETURN_UNEXPECTED(end_ind - n >= 0, "Incorrect loop condition"); + + for (int i = start_ind; i < end_ind - n; i++) { + res.emplace_back(str_buffer.substr(offsets[i], offsets[i + n] - offsets[i] - separator_.size())); + } + } + } + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, res, TensorShape({static_cast(res.size())}))); + return Status::OK(); +} + +void NgramOp::Print(std::ostream &out) const { + out << "NgramOp: " + << "left pad width: " << l_len_ << " left pad token with separator: " << l_pad_with_sp_ << "\n" + << "right pad width: " << r_len_ << " right pad token with separator: " << r_pad_with_sp_ << "\n" + << "separator: " << separator_ << "\n"; +} + +Status NgramOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n"); + CHECK_FAIL_RETURN_UNEXPECTED(inputs[0].Rank() == 1, "ngram only works with 1-dim data\n"); + dsize_t num_elements = ngrams_.size(); + for (int32_t n : ngrams_) { + // here since rank == 1, NumOfElements == shape[0]. add padding length to string + int32_t len_with_padding = inputs[0].NumOfElements() + std::min(n - 1, l_len_) + std::min(n - 1, r_len_); + // if len_with_padding - n < 0, this would return an empty string + num_elements += std::max(len_with_padding - n, 0); + } + outputs.emplace_back(TensorShape({num_elements})); + CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n"); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h new file mode 100644 index 0000000000..6ce3881638 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_KERNELS_NGRAM_OP_H_ +#define DATASET_TEXT_KERNELS_NGRAM_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class NgramOp : public TensorOp { + public: + // Constructor of Ngram model + // @param const std::vector &ngrams + // @param int32_tl_len - padding length on the left + // @param int32_t r_len - padding length on the right + // @param const std::string &l_pad - padding token on the left + // @param const std::string &r_pad - padding token on the right + // @param const std::string &separator - use to join strings + NgramOp(const std::vector &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad, + const std::string &r_pad, const std::string &separator); + + // perform ngram model on each tensor + // @param const std::shared_ptr &input + // @param std::shared_ptr *output + // @return error code + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + // destructor + ~NgramOp() override = default; + + // @param std::vector &inputs - shape of input tensors + // @param std::vector &outputs - shape of output tensors + // @return error code + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + // print arg for debugging + // @param std::ostream &out + void Print(std::ostream &out) const override; + + std::string Name() const override { return kNgramOp; } + + private: + std::vector ngrams_; // list of n grams + int32_t l_len_; // left padding length + int32_t r_len_; // right padding length + std::string l_pad_with_sp_; // left padding appended with separator + std::string r_pad_with_sp_; // right padding appended with separator + std::string separator_; // separator +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TEXT_KERNELS_NGRAM_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc new file mode 100644 index 0000000000..0c0aa5fa2d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/normalize_utf8_op.h" +#include +#include +#include +#include +#include + +#include "unicode/errorcode.h" +#include "unicode/normalizer2.h" +#include "unicode/utypes.h" + +namespace mindspore { +namespace dataset { +const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; +Status NormalizeUTF8Op::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + icu::ErrorCode error; + const icu::Normalizer2 *normalize = nullptr; + switch (normalize_form_) { + case NormalizeForm::kNone: { + *output = input; + return Status::OK(); + } + case NormalizeForm::kNfc: { + normalize = icu::Normalizer2::getNFCInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFCInstance failed"); + break; + } + case NormalizeForm::kNfkc: { + normalize = icu::Normalizer2::getNFKCInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCInstance failed"); + break; + } + case NormalizeForm::kNfd: { + normalize = icu::Normalizer2::getNFDInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFDInstance failed"); + break; + } + case NormalizeForm::kNfkd: { + normalize = icu::Normalizer2::getNFKDInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKDInstance failed"); + break; + } + default: { + RETURN_STATUS_UNEXPECTED("unexpected normalize form"); + break; + } + } + std::vector strs(input->Size()); + int i = 0; + for (auto iter = input->begin(); iter != input->end(); iter++) { + icu::StringByteSink sink(&strs[i++]); + normalize->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); + } + *output = std::make_shared(std::move(strs), input->shape()); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h new file mode 100644 index 0000000000..f914be1c58 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ +#define DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +enum class NormalizeForm { + kNone = 0, + kNfc, + kNfkc, + kNfd, + kNfkd, +}; + +class NormalizeUTF8Op : public TensorOp { + public: + static const NormalizeForm kDefNormalizeForm; + explicit NormalizeUTF8Op(NormalizeForm normalize_form = kDefNormalizeForm) : normalize_form_(normalize_form) {} + + ~NormalizeUTF8Op() override = default; + + void Print(std::ostream &out) const override { out << "NormalizeUTF8Op"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kNormalizeUTF8Op; } + + private: + NormalizeForm normalize_form_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc new file mode 100644 index 0000000000..c370393e76 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/regex_replace_op.h" +#include +#include +#include +#include +#include + +namespace mindspore { +namespace dataset { + +Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, + std::string *out) const { + CHECK_FAIL_RETURN_UNEXPECTED((matcher != nullptr && out != nullptr), "Input is null"); + UErrorCode icu_error = U_ZERO_ERROR; + icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(text); + matcher->reset(unicode_text); + icu::UnicodeString unicode_out; + if (replace_all_) { + unicode_out = matcher->replaceAll(replace_, icu_error); + } else { + unicode_out = matcher->replaceFirst(replace_, icu_error); + } + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "RegexReplace failed"); + unicode_out.toUTF8String(*out); + return Status::OK(); +} + +Status RegexReplaceOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + UErrorCode icu_error = U_ZERO_ERROR; + icu::RegexMatcher matcher(pattern_, 0, icu_error); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern"); + std::vector strs(input->Size()); + int i = 0; + for (auto iter = input->begin(); iter != input->end(); iter++) { + RETURN_IF_NOT_OK(RegexReplace(&matcher, *iter, &strs[i])); + } + *output = std::make_shared(std::move(strs), input->shape()); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h new file mode 100644 index 0000000000..ac3d3f7ff0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ +#define DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ +#include +#include + +#include "unicode/regex.h" +#include "unicode/errorcode.h" +#include "unicode/utypes.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class RegexReplaceOp : public TensorOp { + public: + RegexReplaceOp(const std::string &pattern, const std::string &replace, bool replace_all = true) + : pattern_(icu::UnicodeString::fromUTF8(pattern)), + replace_(icu::UnicodeString::fromUTF8(replace)), + replace_all_(replace_all) {} + + ~RegexReplaceOp() override = default; + + void Print(std::ostream &out) const override { out << "RegexReplaceOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRegexReplaceOp; } + + protected: + Status RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, std::string *out) const; + + private: + const icu::UnicodeString pattern_; + const icu::UnicodeString replace_; + const bool replace_all_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc new file mode 100644 index 0000000000..7ff1d994be --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/regex_tokenizer_op.h" +#include +#include +#include +#include +#include + +namespace mindspore { +namespace dataset { + +const bool RegexTokenizerOp::kDefWithOffsets = false; + +Status RegexTokenizerOp::GetUnicodeSubstr(const icu::UnicodeString &input, const int &start, const int &len, + std::string *out_utf8, icu::UnicodeString *out_unicode) const { + CHECK_FAIL_RETURN_UNEXPECTED((out_utf8 != nullptr || out_unicode != nullptr), "Wrong input"); + int total_len = input.length(); + int end = start + len; + CHECK_FAIL_RETURN_UNEXPECTED((start >= 0 && len > 0 && end <= total_len), "Out of range"); + icu::UnicodeString temp; + input.extract(start, len, temp); + if (out_utf8 != nullptr) { + temp.toUTF8String(*out_utf8); + } + if (out_unicode != nullptr) { + *out_unicode = temp; + } + return Status::OK(); +} + +Status RegexTokenizerOp::GetRegexTokens(const std::string &text, std::vector *out_tokens, + std::vector *offsets_start, + std::vector *offsets_limit) const { + UErrorCode status = U_ZERO_ERROR; + out_tokens->clear(); + icu::RegexMatcher token_matcher(delim_pattern_, 0, status); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Create icu RegexMatcher failed, you may input one error pattern"); + icu::RegexMatcher delim_matcher(keep_delim_pattern_, 0, status); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Create icu RegexMatcher failed, you may input one error pattern"); + + icu::UnicodeString utext(icu::UnicodeString::fromUTF8(text)); + token_matcher.reset(utext); + + int text_start_index = 0; + int token_start_index = 0; + status = U_ZERO_ERROR; + while (token_matcher.find(status) && U_SUCCESS(status)) { + int deli_start_index = token_matcher.start(status); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Get RegexMatcher matched start index failed"); + int deli_end_index = token_matcher.end(status); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Get RegexMatcher matched start index failed"); + + // Add non-empty token + int token_len = deli_start_index - token_start_index; + if (token_len > 0) { + std::string token; + uint32_t token_offset = 0; + RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, token_start_index, token_len, &token)); + token_offset = token.length(); + out_tokens->emplace_back(std::move(token)); + offsets_start->push_back(static_cast(text_start_index)); + offsets_limit->push_back(static_cast(text_start_index + token_offset)); + text_start_index += token_offset; + } + + int delim_len = deli_end_index - deli_start_index; + if (delim_len > 0) { + icu::UnicodeString delim_str; + std::string delim_utf8_str; + uint32_t delim_str_offset = 0; + RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, deli_start_index, delim_len, &delim_utf8_str, &delim_str)); + delim_matcher.reset(delim_str); + delim_str_offset = delim_utf8_str.length(); + if (keep_delim_ && delim_matcher.matches(status) && U_SUCCESS(status)) { + out_tokens->emplace_back(std::move(delim_utf8_str)); + offsets_start->push_back(static_cast(text_start_index)); + offsets_limit->push_back(static_cast(text_start_index + delim_str_offset)); + } + text_start_index += delim_str_offset; + } + token_start_index = deli_end_index; + } + + if (token_start_index < utext.length()) { + std::string temp; + uint32_t temp_offset = 0; + RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, token_start_index, utext.length() - token_start_index, &temp)); + temp_offset = temp.length(); + out_tokens->emplace_back(std::move(temp)); + offsets_start->push_back(static_cast(text_start_index)); + offsets_limit->push_back(static_cast(text_start_index + temp_offset)); + } + return Status::OK(); +} + +Status RegexTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::string_view text; + std::vector tokens; + std::vector offsets_start; + std::vector offsets_limit; + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + RETURN_IF_NOT_OK(input[0]->GetItemAt(&text, {})); + RETURN_IF_NOT_OK(GetRegexTokens(std::string(text.data(), text.size()), &tokens, &offsets_start, &offsets_limit)); + token_tensor = std::make_shared(std::move(tokens), TensorShape({(dsize_t)tokens.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h new file mode 100644 index 0000000000..56271f9551 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_REGEX_TOKENIZER_OP_H_ +#define DATASET_TEXT_REGEX_TOKENIZER_OP_H_ +#include +#include +#include + +#include "unicode/regex.h" +#include "unicode/errorcode.h" +#include "unicode/utypes.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class RegexTokenizerOp : public TensorOp { + public: + static const bool kDefWithOffsets; + + RegexTokenizerOp(const std::string &delim_pattern, const std::string &keep_delim_pattern, + const bool &with_offsets = kDefWithOffsets) + : delim_pattern_(icu::UnicodeString::fromUTF8(delim_pattern)), + keep_delim_pattern_(icu::UnicodeString::fromUTF8(keep_delim_pattern)), + with_offsets_(with_offsets), + keep_delim_(!keep_delim_pattern.empty()) {} + + ~RegexTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "RegexTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + protected: + Status GetUnicodeSubstr(const icu::UnicodeString &input, const int &start, const int &len, std::string *out_utf8, + icu::UnicodeString *out_unicode = nullptr) const; + Status GetRegexTokens(const std::string &text, std::vector *out_tokens, + std::vector *offsets_start, std::vector *offsets_limit) const; + + std::string Name() const override { return kRegexTokenizerOp; } + + private: + const icu::UnicodeString delim_pattern_; + const icu::UnicodeString keep_delim_pattern_; + bool with_offsets_; + const bool keep_delim_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_REGEX_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc new file mode 100644 index 0000000000..a6685a2d64 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc @@ -0,0 +1,241 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/kernels/to_number_op.h" + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {} + +ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(DataType(cast_to_type)) {} + +Status ToNumberOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tenosrs should have type string."); + + switch (cast_to_type_.value()) { + case DataType::DE_INT8: + RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); + break; + case DataType::DE_INT16: + RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); + break; + case DataType::DE_INT32: + RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); + break; + case DataType::DE_INT64: + RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); + break; + case DataType::DE_UINT8: + RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); + break; + case DataType::DE_UINT16: + RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); + break; + case DataType::DE_UINT32: + RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); + break; + case DataType::DE_UINT64: + RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); + break; + case DataType::DE_FLOAT16: + RETURN_IF_NOT_OK(this->ToFloat16(input, output)); + break; + case DataType::DE_FLOAT32: + RETURN_IF_NOT_OK(ToFloat(input, output)); + break; + case DataType::DE_FLOAT64: + RETURN_IF_NOT_OK(ToDouble(input, output)); + break; + } + + return Status::OK(); +} + +void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << '\n'; } + +Status ToNumberOp::OutputShape(const std::vector &input_shapes, std::vector &output_shapes) { + (void)std::copy(input_shapes.begin(), input_shapes.end(), std::back_inserter(output_shapes)); + return Status::OK(); +} + +template +Status ToNumberOp::ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + int64_t result = 0; + + try { + result = std::stoll(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::min()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + T casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +template +Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + uint64_t result = 0; + + // If there is a - at the start of the string, it is considered by us to + // be out of bounds. If the - is somewhere else in the string, it is + // deemed invalid by std::stoull and will throw std::invalid_argument + for (int i = 0; i < (*it).size(); i++) { + if ((*it)[i] == '-') { + is_cast_out_of_range = true; + break; + } + } + + try { + result = std::stoull(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::min()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + T casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +Status ToNumberOp::ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { + // special case, float16 does not exist in c++, no native support for + // casting, so cast to float first then use this method, which use Eigen. + std::shared_ptr temp; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32"))); + RETURN_IF_NOT_OK(ToFloat(input, &temp)); + RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output)); + return Status::OK(); +} + +Status ToNumberOp::ToFloat(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + float result = 0; + + try { + result = std::stof(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || + is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::lowest()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + float casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +Status ToNumberOp::ToDouble(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + double result = 0; + + try { + result = std::stod(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || + is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::lowest()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + double casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h new file mode 100644 index 0000000000..8582fcf073 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ +#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class ToNumberOp : public TensorOp { + public: + // Constructor of ToNumberOp + // @param const DataType &cast_to_type - the type to convert string inputs to. + explicit ToNumberOp(const DataType &cast_to_type); + + // Constructor of ToNumberOp + // @param const std::string &cast_to_type - the type in string form to convert string inputs to. + explicit ToNumberOp(const std::string &cast_to_type); + + ~ToNumberOp() override = default; + + // Perform numeric conversion on each string in each tensor. + // @param const std::shared_ptr &input + // @param std::shared_ptr *output + // @return error code + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + // For each input shape, find the output shape + // @param std::vector &inputs - shape of input tensors + // @param std::vector &outputs - shape of output tensors + // @return error code + Status OutputShape(const std::vector &input_shapes, std::vector &output_shapes) override; + + // print arg for debugging + // @param std::ostream &out + void Print(std::ostream &out) const override; + + std::string Name() const override { return kToNumberOp; } + + private: + template + Status ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); + + template + Status ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToFloat(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToDouble(const std::shared_ptr &input, std::shared_ptr *output); + + DataType cast_to_type_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.cc new file mode 100644 index 0000000000..53a803c542 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/data/slice_op.h" + +namespace mindspore { +namespace dataset { + +Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 2, "Number of inputs should be two."); + std::shared_ptr seq1 = input[0]; + std::shared_ptr seq2 = input[1]; + CHECK_FAIL_RETURN_UNEXPECTED(seq1->shape().Rank() == 1 && seq2->shape().Rank() == 1, + "Both sequences should be of rank 1"); + dsize_t length1 = seq1->shape()[0]; + dsize_t length2 = seq2->shape()[0]; + dsize_t outLength1 = length1; + dsize_t outLength2 = length2; + + dsize_t total = length1 + length2; + while (total > max_length_) { + if (outLength1 > outLength2) + outLength1--; + else + outLength2--; + total--; + } + std::shared_ptr outSeq1; + if (length1 != outLength1) { + std::unique_ptr slice1(new SliceOp(Slice(outLength1 - length1))); + RETURN_IF_NOT_OK(slice1->Compute(seq1, &outSeq1)); + } else { + outSeq1 = std::move(seq1); + } + + std::shared_ptr outSeq2; + if (length2 != outLength2) { + std::unique_ptr slice2(new SliceOp(Slice(outLength2 - length2))); + RETURN_IF_NOT_OK(slice2->Compute(seq2, &outSeq2)); + } else { + outSeq2 = std::move(seq2); + } + output->push_back(outSeq1); + output->push_back(outSeq2); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h new file mode 100644 index 0000000000..ce82735645 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ +#define DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { + +class TruncateSequencePairOp : public TensorOp { + public: + explicit TruncateSequencePairOp(dsize_t length) : max_length_(length) {} + + ~TruncateSequencePairOp() override = default; + + void Print(std::ostream &out) const override { out << "TruncateSequencePairOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kTruncateSequencePairOp; } + + private: + dsize_t max_length_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc new file mode 100644 index 0000000000..e08f61100b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; + +namespace mindspore { +namespace dataset { + +const bool UnicodeCharTokenizerOp::kDefWithOffsets = false; + +Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::string_view str; + RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); + + RuneStrArray runes; + if (!DecodeRunesInString(str.data(), str.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + std::vector splits(runes.size()); + std::vector offsets_start, offsets_limit; + for (size_t i = 0; i < runes.size(); i++) { + offsets_start.push_back(runes[i].offset); + offsets_limit.push_back(runes[i].offset + runes[i].len); + splits[i] = str.substr(runes[i].offset, runes[i].len); + } + if (splits.empty()) { + splits.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h new file mode 100644 index 0000000000..415d99b451 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class UnicodeCharTokenizerOp : public TensorOp { + public: + static const bool kDefWithOffsets; + + explicit UnicodeCharTokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {} + + ~UnicodeCharTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kUnicodeCharTokenizerOp; } + + private: + bool with_offsets_; +}; + +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc new file mode 100644 index 0000000000..60fe8dd0e4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h" +#include +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" +#include "unicode/errorcode.h" +#include "unicode/uchar.h" +#include "unicode/uscript.h" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; + +namespace mindspore { +namespace dataset { + +const bool UnicodeScriptTokenizerOp::kDefKeepWhitespace = false; +const bool UnicodeScriptTokenizerOp::kDefWithOffsets = false; + +Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::string_view str; + RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); + RuneStrArray runes; + if (!DecodeRunesInString(str.data(), str.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + UScriptCode last_script = USCRIPT_INVALID_CODE; + icu::ErrorCode status; + int start = 0; + int len = 0; + std::vector splits; + std::vector offsets_start, offsets_limit; + + bool was_space = false; + for (size_t i = 0; i < runes.size(); i++) { + bool is_space = u_isUWhiteSpace(runes[i].rune); + UScriptCode script = uscript_getScript(runes[i].rune, status); + if (status.isFailure()) { + status.reset(); + script = USCRIPT_INVALID_CODE; + } + // 1) Seperate UTF-8 strings of different UScriptCode values + // (such as: "Chinese中国" should be splited to ["Chinese", "中国"]) + // 2) Seperate whitespace and non-whitespace UTF-8 strings + // (such as: " ." should be split to [" ", "."]) + if (len > 0 && (script != last_script || is_space != was_space)) { + // 3) If keep_whitespace_ is false, all the whitespace characters will be discard + if (keep_whitespace_ || !was_space) { + offsets_start.push_back(static_cast(start)); + offsets_limit.push_back(static_cast(start + len)); + std::string temp(str.substr(start, len)); + splits.emplace_back(std::move(temp)); + } + start = runes[i].offset; + len = runes[i].len; + } else { + len += runes[i].len; + } + last_script = script; + was_space = is_space; + } + + if (len > 0 && (keep_whitespace_ || !was_space)) { + offsets_start.push_back(static_cast(start)); + offsets_limit.push_back(static_cast(start + len)); + std::string temp(str.substr(start, len)); + splits.emplace_back(std::move(temp)); + } + // 4) If the input is empty scalar string, the output will be 1-D empty string. + if (splits.empty()) { + splits.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h new file mode 100644 index 0000000000..fc3b9e620a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class UnicodeScriptTokenizerOp : public TensorOp { + public: + static const bool kDefKeepWhitespace; + static const bool kDefWithOffsets; + + explicit UnicodeScriptTokenizerOp(const bool &keep_whitespace = kDefKeepWhitespace, + const bool &with_offsets = kDefWithOffsets) + : keep_whitespace_(keep_whitespace), with_offsets_(with_offsets) {} + + ~UnicodeScriptTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "UnicodeScriptTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kUnicodeScriptTokenizerOp; } + + private: + bool keep_whitespace_; // If or not keep whitespace tokens + bool with_offsets_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc new file mode 100644 index 0000000000..d3bb32081e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h" +#include +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" +#include "unicode/errorcode.h" +#include "unicode/uchar.h" +#include "unicode/uscript.h" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; + +namespace mindspore { +namespace dataset { + +const bool WhitespaceTokenizerOp::kDefWithOffsets = false; + +Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::string_view str; + RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); + + RuneStrArray runes; + if (!DecodeRunesInString(str.data(), str.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + std::vector offsets_start, offsets_limit; + std::vector splits; + int start = 0; + int len = 0; + for (size_t i = 0; i < runes.size(); i++) { + if (u_isUWhiteSpace(runes[i].rune)) { + if (len > 0) { + offsets_start.push_back(static_cast(start)); + offsets_limit.push_back(static_cast(start + len)); + std::string temp(str.substr(start, len)); + splits.emplace_back(std::move(temp)); + len = 0; + } + } else { + if (len == 0) { + start = runes[i].offset; + } + len += runes[i].len; + } + } + if (len > 0) { + offsets_start.push_back(static_cast(start)); + offsets_limit.push_back(static_cast(start + len)); + std::string temp(str.substr(start, len)); + splits.emplace_back(std::move(temp)); + } + if (splits.empty()) { + splits.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h new file mode 100644 index 0000000000..7cc37fd705 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class WhitespaceTokenizerOp : public TensorOp { + public: + static const bool kDefWithOffsets; + + explicit WhitespaceTokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {} + + ~WhitespaceTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "WhitespaceTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kWhitespaceTokenizerOp; } + + private: + bool with_offsets_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc new file mode 100644 index 0000000000..f0bd448e39 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" +#include +#include + +namespace mindspore { +namespace dataset { + +const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##"; +const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100; +const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]"; +const bool WordpieceTokenizerOp::kDefWithOffsets = false; + +WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator, + const int &max_bytes_per_token, const std::string &unknown_token, + const bool &with_offsets) + : vocab_(vocab), + suffix_indicator_(suffix_indicator), + max_bytes_per_token_(max_bytes_per_token), + unknown_token_(unknown_token), + with_offsets_(with_offsets) {} + +Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, + bool *out_found, int *out_end) const { + CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && start < input_token.size(), "Out of range"); + *out_found = false; + for (int i = runes.size() - 1; i >= 0; i--) { + *out_end = runes[i].offset + runes[i].len; + int len = *out_end - start; + std::string word = input_token.substr(start, len); + if (start > 0) { + word = suffix_indicator_ + word; + } + if (vocab_->Lookup(word) != Vocab::kNoTokenExists) { + *out_found = true; + break; + } + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::FoundNoToken(const std::string &input_token, const uint32_t &basic_start, + std::vector *out_tokens, std::vector *offsets_start, + std::vector *offsets_limit) const { + out_tokens->clear(); + offsets_start->push_back(basic_start); + if (unknown_token_.empty()) { + out_tokens->emplace_back(input_token); + offsets_limit->push_back(basic_start + input_token.length()); + } else { + out_tokens->emplace_back(unknown_token_); + offsets_limit->push_back(basic_start + input_token.length()); + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::AddSubword(const std::string &input_token, const int &start, const int &end, + std::vector *out_tokens) const { + CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end > start && end <= input_token.size(), "Out of range"); + std::string subword = input_token.substr(start, end - start); + if (start > 0) { + subword = suffix_indicator_ + subword; + } + out_tokens->emplace_back(subword); + return Status::OK(); +} + +Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, const uint32_t &basic_start, + std::vector *out_tokens, std::vector *offsets_start, + std::vector *offsets_limit) const { + if (input_token.size() > max_bytes_per_token_) { + offsets_start->push_back(basic_start); + if (!unknown_token_.empty()) { + offsets_limit->push_back(basic_start + unknown_token_.size()); + out_tokens->emplace_back(unknown_token_); + } else { + out_tokens->emplace_back(input_token); + offsets_limit->push_back(basic_start + input_token.size()); + } + return Status::OK(); + } + RuneStrArray runes; + if (!DecodeRunesInString(input_token.data(), input_token.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + int end = 0; + for (int start = 0; start < input_token.size();) { + bool found = false; + RETURN_IF_NOT_OK(LookupWord(input_token, runes, start, &found, &end)); + if (found) { + RETURN_IF_NOT_OK(AddSubword(input_token, start, end, out_tokens)); + offsets_start->push_back(static_cast(basic_start + start)); + offsets_limit->push_back(static_cast(basic_start + end)); + start = end; + } else { + return FoundNoToken(input_token, basic_start, out_tokens, offsets_start, offsets_limit); + } + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor"); + } + dsize_t count = 0; + std::vector out_tokens; + std::vector offsets_start, offsets_limit; + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + for (auto iter = input[0]->begin(); iter != input[0]->end(); iter++) { + uint32_t basic_start = 0; + std::vector temp_tokens; + if (with_offsets_ && input.size() == 3) { + RETURN_IF_NOT_OK(input[1]->GetItemAt(&basic_start, {count, 0})); + } + RETURN_IF_NOT_OK(GetTokens(std::string(*iter), basic_start, &temp_tokens, &offsets_start, &offsets_limit)); + out_tokens.insert(out_tokens.end(), temp_tokens.begin(), temp_tokens.end()); + count++; + } + if (out_tokens.empty()) { + out_tokens.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + token_tensor = std::make_shared(out_tokens, TensorShape({(dsize_t)out_tokens.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h new file mode 100644 index 0000000000..4f9c76f57e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/text/vocab.h" +#include "minddata/dataset/util/status.h" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; +namespace mindspore { +namespace dataset { + +class WordpieceTokenizerOp : public TensorOp { + public: + static const char kDefSuffixIndicator[]; + static const int kDefMaxBytesPerToken; + static const char kDefUnknownToken[]; + static const bool kDefWithOffsets; + WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator = kDefSuffixIndicator, + const int &max_bytes_per_token = kDefMaxBytesPerToken, + const std::string &unknown_token = kDefUnknownToken, const bool &with_offsets = kDefWithOffsets); + + ~WordpieceTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "WordpieceTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + protected: + Status AddSubword(const std::string &input_token, const int &start, const int &end, + std::vector *out_token) const; + Status FoundNoToken(const std::string &input_token, const uint32_t &basic_start, std::vector *out_tokens, + std::vector *offsets_start, std::vector *offsets_limit) const; + Status LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, bool *out_found, + int *out_end) const; + Status GetTokens(const std::string &input_token, const uint32_t &basic_start, std::vector *out_tokens, + std::vector *offsets_start, std::vector *offsets_limit) const; + + std::string Name() const override { return kWordpieceTokenizerOp; } + + private: + const std::shared_ptr vocab_; + const std::string suffix_indicator_; + const bool with_offsets_; + const int max_bytes_per_token_; + const std::string unknown_token_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/vocab.cc b/mindspore/ccsrc/minddata/dataset/text/vocab.cc new file mode 100644 index 0000000000..c1b7e6265c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/vocab.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +#include "minddata/dataset/text/vocab.h" + +namespace mindspore { +namespace dataset { +Vocab::Vocab(std::unordered_map word2id) { word2id_ = std::move(word2id); } + +WordIdType Vocab::Lookup(const WordType &word) const { + auto itr = word2id_.find(word); + return itr == word2id_.end() ? kNoTokenExists : itr->second; +} + +Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, + std::shared_ptr *vocab) { + // check of duplication on both words and special_tokens will be performed in python + // special_tokens and words both need to be unique, and shouldn't overlap + std::unordered_map word2id; + // if special is added in front, normal words id will start from number of special tokens + WordIdType word_id = prepend_special ? static_cast(special_tokens.size()) : 0; + + for (auto word : words) { + word2id[py::str(word)] = word_id++; + } + + word_id = prepend_special ? 0 : word2id.size(); + + for (auto special_token : special_tokens) { + word2id[py::str(special_token)] = word_id++; + } + + *vocab = std::make_shared(std::move(word2id)); + return Status::OK(); +} + +Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, + const py::list &special_tokens, bool prepend_special, std::shared_ptr *vocab) { + // python validator checks special_tokens doesn't contain any duplicate words + std::unordered_set specials; + // used to check that words in file don't contain any special token that already exists + for (auto word : special_tokens) { + specials.insert(py::str(word)); + } + WordIdType word_id = prepend_special ? static_cast(special_tokens.size()) : 0; + std::unordered_map word2id; + std::fstream handle(path, std::ios::in); + CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path); + std::string word; + while (std::getline(handle, word)) { + if (!delimiter.empty()) { + // if delimiter is not found, find_first_of would return std::string::npos which is -1 + word = word.substr(0, word.find_first_of(delimiter)); + } + CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word + "."); + CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(), word + " is already in special_tokens."); + word2id[word] = word_id++; + // break if enough row is read, if vocab_size is smaller than 0 + if (word2id.size() == vocab_size) break; + } + + word_id = prepend_special ? 0 : word2id.size(); + + for (auto special_token : special_tokens) { + word2id[py::str(special_token)] = word_id++; + } + + *vocab = std::make_shared(std::move(word2id)); + return Status::OK(); +} + +Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr *vocab) { + std::unordered_map word2id; + for (auto p : words) { + word2id[py::str(p.first)] = py::reinterpret_borrow(p.second); + } + *vocab = std::make_shared(std::move(word2id)); + return Status::OK(); +} + +void Vocab::append_word(const std::string &word) { + if (word2id_.find(word) == word2id_.end()) { + word2id_[word] = word2id_.size(); + } +} + +const WordIdType Vocab::kNoTokenExists = -1; + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/vocab.h b/mindspore/ccsrc/minddata/dataset/text/vocab.h new file mode 100644 index 0000000000..6bf6c488c5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/vocab.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_VOCAB_H_ +#define DATASET_TEXT_VOCAB_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace mindspore { +namespace dataset { +namespace py = pybind11; + +using WordIdType = int32_t; +using WordType = std::string; + +class Vocab { + public: + // Build a vocab from a python dictionary key is each word ,id needs to start from 2, no duplicate and continuous + // @param const py::dict &words - a dictionary containing word, word id pair. + // @param std::shared_ptr *vocab - return value, vocab object + // @return error code + static Status BuildFromPyDict(const py::dict &words, std::shared_ptr *vocab); + + // Build a vocab from a python list, id will be assigned automatically, start from 2 + // @param const py::list &words - a list of string, used to build vocab, id starts from 2 + // @param std::shared_ptr *vocab - return value, vocab object + // @return error code + static Status BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, + std::shared_ptr *vocab); + + // Build a vocab from reading a vocab file, id are automatically assigned, start from 2 + // @param std::string &path - path to vocab file , each line is assumed to contain 1 word + // @param std::string &delimiter - delimiter to break each line with + // @param int32_t vocab_size - number of words to read from file + // @param std::shared_ptr *vocab - return value, vocab object + // @return error code + static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, + const py::list &special_tokens, bool prepend_special, std::shared_ptr *vocab); + + // Lookup the id of a word, if word doesn't exist in vocab, return default_id + // @param const WordType word - word to look up + // @param WordIdType default_id - word id to return to user when its not in the vocab + // @return WordIdType, word_id + WordIdType Lookup(const WordType &word) const; + + // constructor, shouldn't be called directly, can't be private due to std::make_unique() + // @param std::unordered_map map - sanitized word2id map + explicit Vocab(std::unordered_map map); + + Vocab() = default; + + // add one word to vocab, increment it's index automatically + // @param std::string & word - word to be added will skip if word already exists + void append_word(const std::string &word); + + // destructor + ~Vocab() = default; + + static const WordIdType kNoTokenExists; + + private: + std::unordered_map word2id_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TEXT_VOCAB_H_ diff --git a/mindspore/ccsrc/dataset/util/.gitignore b/mindspore/ccsrc/minddata/dataset/util/.gitignore similarity index 100% rename from mindspore/ccsrc/dataset/util/.gitignore rename to mindspore/ccsrc/minddata/dataset/util/.gitignore diff --git a/mindspore/ccsrc/dataset/util/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/util/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt diff --git a/mindspore/ccsrc/dataset/util/README.md b/mindspore/ccsrc/minddata/dataset/util/README.md similarity index 100% rename from mindspore/ccsrc/dataset/util/README.md rename to mindspore/ccsrc/minddata/dataset/util/README.md diff --git a/mindspore/ccsrc/minddata/dataset/util/allocator.h b/mindspore/ccsrc/minddata/dataset/util/allocator.h new file mode 100644 index 0000000000..b5eaed97a6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/allocator.h @@ -0,0 +1,178 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_ALLOCATOR_H_ +#define DATASET_UTIL_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/memory_pool.h" + +namespace mindspore { +namespace dataset { +// The following conforms to the requirements of +// std::allocator. Do not rename/change any needed +// requirements, e.g. function names, typedef etc. +template +class Allocator { + public: + template + friend class Allocator; + + using value_type = T; + using pointer = T *; + using const_pointer = const T *; + using reference = T &; + using const_reference = const T &; + using size_type = uint64_t; + + template + struct rebind { + using other = Allocator; + }; + + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + + explicit Allocator(const std::shared_ptr &b) : pool_(b) {} + + ~Allocator() = default; + + template + explicit Allocator(Allocator const &rhs) : pool_(rhs.pool_) {} + + template + bool operator==(Allocator const &rhs) const { + return pool_ == rhs.pool_; + } + + template + bool operator!=(Allocator const &rhs) const { + return pool_ != rhs.pool_; + } + + pointer allocate(std::size_t n) { + void *p; + Status rc = pool_->Allocate(n * sizeof(T), &p); + if (rc.IsOk()) { + return reinterpret_cast(p); + } else if (rc.IsOutofMemory()) { + throw std::bad_alloc(); + } else { + throw std::exception(); + } + } + + void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); } + + size_type max_size() { return pool_->get_max_size(); } + + private: + std::shared_ptr pool_; +}; +/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will +/// be released when the object goes out of scope +/// \tparam T The type of object to be allocated +/// \tparam C Allocator. Default to std::allocator +template > +class MemGuard { + public: + using allocator = C; + MemGuard() : n_(0) {} + explicit MemGuard(allocator a) : n_(0), alloc_(a) {} + // There is no copy constructor nor assignment operator because the memory is solely owned by this object. + MemGuard(const MemGuard &) = delete; + MemGuard &operator=(const MemGuard &) = delete; + // On the other hand, We can support move constructor + MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} + MemGuard &operator=(MemGuard &&lhs) noexcept { + if (this != &lhs) { + this->deallocate(); + n_ = lhs.n_; + alloc_ = std::move(lhs.alloc_); + ptr_ = std::move(lhs.ptr_); + } + return *this; + } + /// \brief Explicitly deallocate the memory if allocated + void deallocate() { + if (ptr_) { + auto *p = ptr_.release(); + if (!std::is_arithmetic::value && std::is_destructible::value) { + for (auto i = 0; i < n_; ++i) { + p[i].~T(); + } + } + alloc_.deallocate(p, n_); + n_ = 0; + } + } + /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is + /// allocated. + /// \param n Number of objects of type T to be allocated + /// \tparam Args Extra arguments pass to the constructor of T + template + Status allocate(size_t n, Args &&... args) noexcept { + try { + deallocate(); + if (n > 0) { + T *data = alloc_.allocate(n); + if (!std::is_arithmetic::value) { + for (auto i = 0; i < n; i++) { + std::allocator_traits::construct(alloc_, &(data[i]), std::forward(args)...); + } + } + ptr_ = std::unique_ptr(data); + n_ = n; + } + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory); + } catch (std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return Status::OK(); + } + ~MemGuard() noexcept { deallocate(); } + /// \brief Getter function + /// \return The pointer to the memory allocated + T *GetPointer() const { return ptr_.get(); } + /// \brief Getter function + /// \return The pointer to the memory allocated + T *GetMutablePointer() { return ptr_.get(); } + /// \brief Overload [] operator to access a particular element + /// \param x index to the element. Must be less than number of element allocated. + /// \return pointer to the x-th element + T *operator[](size_t x) { return GetMutablePointer() + x; } + /// \brief Overload [] operator to access a particular element + /// \param x index to the element. Must be less than number of element allocated. + /// \return pointer to the x-th element + T *operator[](size_t x) const { return GetPointer() + x; } + /// \brief Return how many bytes are allocated in total + /// \return Number of bytes allocated in total + size_t GetSizeInBytes() const { return n_ * sizeof(T); } + + private: + allocator alloc_; + std::unique_ptr ptr_; + size_t n_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/arena.cc b/mindspore/ccsrc/minddata/dataset/util/arena.cc new file mode 100644 index 0000000000..87a9c614a8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/arena.cc @@ -0,0 +1,256 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/arena.h" +#include +#include +#include "minddata/dataset/util/system_pool.h" +#include "./securec.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +struct MemHdr { + uint32_t sig; + uint64_t addr; + uint64_t blk_size; + MemHdr(uint64_t a, uint64_t sz) : sig(0xDEADBEEF), addr(a), blk_size(sz) {} + static void setHdr(void *p, uint64_t addr, uint64_t sz) { new (p) MemHdr(addr, sz); } + static void getHdr(void *p, MemHdr *hdr) { + auto *tmp = reinterpret_cast(p); + *hdr = *tmp; + } +}; +Status Arena::Init() { + RETURN_IF_NOT_OK(DeMalloc(size_in_MB_ * 1048576L, &ptr_, false)); + // Divide the memory into blocks. Ignore the last partial block. + uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ; + MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << "."; + tr_.Insert(0, num_blks); + return Status::OK(); +} + +Status Arena::Allocate(size_t n, void **p) { + if (n == 0) { + *p = nullptr; + return Status::OK(); + } + std::unique_lock lck(mux_); + // Round up n to 1K block + uint64_t req_size = static_cast(n) + ARENA_WALL_OVERHEAD_SZ; + if (req_size > this->get_max_size()) { + return Status(StatusCode::kOutOfMemory); + } + uint64_t reqBlk = SizeToBlk(req_size); + // Do a first fit search + auto blk = tr_.Top(); + if (blk.second && reqBlk <= blk.first.priority) { + uint64_t addr = blk.first.key; + uint64_t size = blk.first.priority; + // Trim to the required size and return the rest to the tree. + tr_.Pop(); + if (size > reqBlk) { + tr_.Insert(addr + reqBlk, size - reqBlk); + } + lck.unlock(); + char *q = static_cast(ptr_) + addr * ARENA_BLK_SZ; + MemHdr::setHdr(q, addr, reqBlk); + *p = get_user_addr(q); + } else { + return Status(StatusCode::kOutOfMemory); + } + return Status::OK(); +} + +void Arena::Deallocate(void *p) { + auto *q = get_base_addr(p); + MemHdr hdr(0, 0); + MemHdr::getHdr(q, &hdr); + MS_ASSERT(hdr.sig == 0xDEADBEEF); + // We are going to insert a free block back to the treap. But first, check if we can combine + // with the free blocks before and after to form a bigger block. + std::unique_lock lck(mux_); + // Query if we have a free block after us. + auto nextBlk = tr_.Search(hdr.addr + hdr.blk_size); + if (nextBlk.second) { + // Form a bigger block + hdr.blk_size += nextBlk.first.priority; + tr_.DeleteKey(nextBlk.first.key); + } + // Next find a block in front of us. + auto result = FindPrevBlk(hdr.addr); + if (result.second) { + // We can combine with this block + hdr.addr = result.first.first; + hdr.blk_size += result.first.second; + tr_.DeleteKey(result.first.first); + } + // Now we can insert the free node + tr_.Insert(hdr.addr, hdr.blk_size); +} + +Status Arena::Reallocate(void **pp, size_t old_sz, size_t new_sz) { + MS_ASSERT(pp); + MS_ASSERT(*pp); + uint64_t actual_size = static_cast(new_sz) + ARENA_WALL_OVERHEAD_SZ; + if (actual_size > this->get_max_size()) { + RETURN_STATUS_UNEXPECTED("Request size too big : " + std::to_string(new_sz)); + } + uint64_t req_blk = SizeToBlk(actual_size); + char *oldAddr = reinterpret_cast(*pp); + auto *oldHdr = get_base_addr(oldAddr); + MemHdr hdr(0, 0); + MemHdr::getHdr(oldHdr, &hdr); + MS_ASSERT(hdr.sig == 0xDEADBEEF); + std::unique_lock lck(mux_); + if (hdr.blk_size > req_blk) { + // Refresh the header with the new smaller size. + MemHdr::setHdr(oldHdr, hdr.addr, req_blk); + // Return the unused memory back to the tree. Unlike allocate, we we need to merge with the block after us. + auto next_blk = tr_.Search(hdr.addr + hdr.blk_size); + if (next_blk.second) { + hdr.blk_size += next_blk.first.priority; + tr_.DeleteKey(next_blk.first.key); + } + tr_.Insert(hdr.addr + req_blk, hdr.blk_size - req_blk); + } else if (hdr.blk_size < req_blk) { + uint64_t addr = hdr.addr; + // Attempt a block enlarge. No guarantee it is always successful. + bool success = BlockEnlarge(&addr, hdr.blk_size, req_blk); + if (success) { + auto *newHdr = static_cast(ptr_) + addr * ARENA_BLK_SZ; + MemHdr::setHdr(newHdr, addr, req_blk); + if (addr != hdr.addr) { + errno_t err = + memmove_s(get_user_addr(newHdr), (req_blk * ARENA_BLK_SZ) - ARENA_WALL_OVERHEAD_SZ, oldAddr, old_sz); + if (err) { + RETURN_STATUS_UNEXPECTED("Error from memmove: " + std::to_string(err)); + } + } + *pp = get_user_addr(newHdr); + return Status::OK(); + } + // If we reach here, allocate a new block and simply move the content from the old to the new place. + // Unlock since allocate will grab the lock again. + lck.unlock(); + return FreeAndAlloc(pp, old_sz, new_sz); + } + return Status::OK(); +} + +std::ostream &operator<<(std::ostream &os, const Arena &s) { + for (auto &it : s.tr_) { + os << "Address : " << it.key << ". Size : " << it.priority << "\n"; + } + return os; +} + +Arena::Arena(size_t val_in_MB) : ptr_(nullptr), size_in_MB_(val_in_MB), size_in_bytes_(val_in_MB * 1048576L) {} + +Status Arena::CreateArena(std::shared_ptr *p_ba, size_t val_in_MB) { + if (p_ba == nullptr) { + RETURN_STATUS_UNEXPECTED("p_ba is null"); + } + Status rc; + auto ba = new (std::nothrow) Arena(val_in_MB); + if (ba == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = ba->Init(); + if (rc.IsOk()) { + (*p_ba).reset(ba); + } else { + delete ba; + } + return rc; +} + +int Arena::PercentFree() const { + uint64_t sz = 0; + for (auto &it : tr_) { + sz += it.priority; + } + double ratio = static_cast(sz * ARENA_BLK_SZ) / static_cast(size_in_bytes_); + return static_cast(ratio * 100.0); +} + +uint64_t Arena::get_max_size() const { return (size_in_bytes_ - ARENA_WALL_OVERHEAD_SZ); } + +std::pair, bool> Arena::FindPrevBlk(uint64_t addr) { + for (auto &it : tr_) { + if (it.key + it.priority == addr) { + return std::make_pair(std::make_pair(it.key, it.priority), true); + } else if (it.key > addr) { + break; + } + } + return std::make_pair(std::make_pair(0, 0), false); +} + +bool Arena::BlockEnlarge(uint64_t *addr, uint64_t old_sz, uint64_t new_sz) { + uint64_t size = old_sz; + // The logic is very much identical to Deallocate. We will see if we can combine with the blocks before and after. + auto next_blk = tr_.Search(*addr + old_sz); + if (next_blk.second) { + size += next_blk.first.priority; + if (size >= new_sz) { + // In this case, we can just enlarge the block without doing any moving. + tr_.DeleteKey(next_blk.first.key); + // Return unused back to the tree. + if (size > new_sz) { + tr_.Insert(*addr + new_sz, size - new_sz); + } + } + return true; + } + // If we still get here, we have to look at the block before us. + auto result = FindPrevBlk(*addr); + if (result.second) { + // We can combine with this block together with the next block (if any) + size += result.first.second; + *addr = result.first.first; + if (size >= new_sz) { + // We can combine with this block together with the next block (if any) + tr_.DeleteKey(*addr); + if (next_blk.second) { + tr_.DeleteKey(next_blk.first.key); + } + // Return unused back to the tree. + if (size > new_sz) { + tr_.Insert(*addr + new_sz, size - new_sz); + } + return true; + } + } + return false; +} + +Status Arena::FreeAndAlloc(void **pp, size_t old_sz, size_t new_sz) { + MS_ASSERT(pp); + MS_ASSERT(*pp); + void *p = nullptr; + void *q = *pp; + RETURN_IF_NOT_OK(Allocate(new_sz, &p)); + errno_t err = memmove_s(p, new_sz, q, old_sz); + if (err) { + RETURN_STATUS_UNEXPECTED("Error from memmove: " + std::to_string(err)); + } + *pp = p; + // Free the old one. + Deallocate(q); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/arena.h b/mindspore/ccsrc/minddata/dataset/util/arena.h new file mode 100644 index 0000000000..8887757af1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/arena.h @@ -0,0 +1,105 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_ARENA_H_ +#define DATASET_UTIL_ARENA_H_ + +#include +#include +#include +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/treap.h" + +#define ARENA_LOG_BLK_SZ (6u) +#define ARENA_BLK_SZ (static_cast(1u << ARENA_LOG_BLK_SZ)) +#define ARENA_WALL_OVERHEAD_SZ 32 +namespace mindspore { +namespace dataset { +// This is a memory arena based on a treap data structure. +// The constructor of the Arena takes the size of the initial memory size (in MB). +// Internally we divide the memory into multiple blocks. Each block is 64 bytes. +// The treap contains all the free blocks with the relative memory address as key +// and the size of the block as priority. +// +// Initially the treap has only one root which is the whole memory piece. +// +// For memory suballocation, we pop the root node of the treap which contains the largest free block. +// We allocate what we need and return the rest back to the treap. We search for the first fit instead +// of the best fit so to give us a constant time in memory allocation. +// +// When a block of memory is freed. It is joined with the blocks before and after (if they are available) to +// form a bigger block. +class Arena : public MemoryPool { + public: + Arena(const Arena &) = delete; + + Arena &operator=(const Arena &) = delete; + + ~Arena() override { + if (ptr_ != nullptr) { + free(ptr_); + ptr_ = nullptr; + } + } + + Status Allocate(size_t n, void **p) override; + + Status Reallocate(void **, size_t old_sz, size_t new_sz) override; + + void Deallocate(void *) override; + + uint64_t get_max_size() const override; + + static uint64_t SizeToBlk(uint64_t sz) { + uint64_t req_blk = sz / ARENA_BLK_SZ; + if (sz % ARENA_BLK_SZ) { + ++req_blk; + } + return req_blk; + } + + int PercentFree() const override; + + const void *get_base_addr() const { return ptr_; } + + friend std::ostream &operator<<(std::ostream &os, const Arena &s); + + static Status CreateArena(std::shared_ptr *p_ba, size_t val_in_MB = 4096); + + private: + std::mutex mux_; + Treap tr_; + void *ptr_; + size_t size_in_MB_; + size_t size_in_bytes_; + + explicit Arena(size_t val_in_MB = 4096); + + std::pair, bool> FindPrevBlk(uint64_t addr); + + Status Init(); + + bool BlockEnlarge(uint64_t *addr, uint64_t old_sz, uint64_t new_sz); + + Status FreeAndAlloc(void **pp, size_t old_sz, size_t new_sz); + + void *get_user_addr(void *base_addr) const { return reinterpret_cast(base_addr) + ARENA_WALL_OVERHEAD_SZ; } + + void *get_base_addr(void *user_addr) const { return reinterpret_cast(user_addr) - ARENA_WALL_OVERHEAD_SZ; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_ARENA_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/auto_index.h b/mindspore/ccsrc/minddata/dataset/util/auto_index.h new file mode 100644 index 0000000000..0fe55159e6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/auto_index.h @@ -0,0 +1,99 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_AUTO_INDEX_H_ +#define DATASET_UTIL_AUTO_INDEX_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/util/btree.h" +#include "minddata/dataset/util/system_pool.h" + +namespace mindspore { +namespace dataset { +/// This is a B+ tree with generated int64_t value as key. +/// Use minKey() function to query the min key. +/// Use maxKey() function to query the max key. +/// @tparam T +template > +class AutoIndexObj : public BPlusTree { + public: + using my_tree = BPlusTree; + using key_type = typename my_tree::key_type; + using value_type = typename my_tree::value_type; + + AutoIndexObj() : my_tree::BPlusTree(), inx_(kMinKey) {} + + explicit AutoIndexObj(const Allocator &alloc) : my_tree::BPlusTree(alloc), inx_(kMinKey) {} + + ~AutoIndexObj() = default; + + // Insert an object into the tree. + // @param val + // @return + Status insert(const value_type &val, key_type *key = nullptr) { + key_type my_inx = inx_.fetch_add(1); + if (key != nullptr) { + *key = my_inx; + } + return my_tree::DoInsert(my_inx, val); + } + + Status insert(std::unique_ptr &&val, key_type *key = nullptr) { + key_type my_inx = inx_.fetch_add(1); + if (key) { + *key = my_inx; + } + return my_tree::DoInsert(my_inx, std::move(val)); + } + + // Insert a vector of objects into the tree. + // @param v + // @return + Status insert(std::vector v) { + uint64_t num_ele = v.size(); + if (num_ele > 0) { + // reserve a range of keys rather than getting it one by one. + key_type my_inx = inx_.fetch_add(num_ele); + for (uint64_t i = 0; i < num_ele; i++) { + RETURN_IF_NOT_OK(my_tree::DoInsert(my_inx + i, v.at(i))); + } + } + return Status::OK(); + } + + // @return the minimum key + key_type min_key() const { + auto it = this->cbegin(); + return it.key(); + } + + // @return the maximum key + key_type max_key() const { + auto it = this->cend(); + --it; + return it.key(); + } + + private: + static constexpr key_type kMinKey = 0; + std::atomic inx_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_AUTO_INDEX_H_ diff --git a/mindspore/ccsrc/dataset/util/bit.h b/mindspore/ccsrc/minddata/dataset/util/bit.h similarity index 100% rename from mindspore/ccsrc/dataset/util/bit.h rename to mindspore/ccsrc/minddata/dataset/util/bit.h diff --git a/mindspore/ccsrc/minddata/dataset/util/btree.h b/mindspore/ccsrc/minddata/dataset/util/btree.h new file mode 100644 index 0000000000..828976a0a1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/btree.h @@ -0,0 +1,459 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_INDEX_H_ +#define DATASET_UTIL_INDEX_H_ + +#include +#include +#include +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/list.h" +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// Default traits for a B+ tree +struct BPlusTreeTraits { + // This determines the limit of number of keys in a node. + using slot_type = uint16_t; + // Number of slots in each leaf of the tree. + static constexpr slot_type kLeafSlots = 256; + // Number of slots in each inner node of the tree + static constexpr slot_type kInnerSlots = 128; +}; + +/// Implementation of B+ tree +/// @tparam K -- the type of key +/// @tparam V -- the type of value +/// @tparam A -- allocator +/// @tparam C -- comparison class +/// @tparam T -- trait +template , typename C = std::less, + typename T = BPlusTreeTraits> +class BPlusTree { + public: + enum class IndexRc : char { + kOk = 0, + kDuplicateKey = 1, + kSlotFull = 2, + kKeyNotFound = 3, + kNullPointer = 4, + kOutOfMemory = 5, + kRetry = 6, + kUnexpectedError = 127 + }; +#define RETURN_IF_BAD_RC(_s) \ + do { \ + IndexRc __rc = (_s); \ + if (__rc != IndexRc::kOk) { \ + return __rc; \ + } \ + } while (false) + + Status IndexRc2Status(IndexRc rc) { + if (rc == IndexRc::kOk) { + return Status(StatusCode::kOK); + } else if (rc == IndexRc::kOutOfMemory) { + return Status(StatusCode::kOutOfMemory); + } else if (rc == IndexRc::kDuplicateKey) { + return Status(StatusCode::kDuplicateKey); + } else { + RETURN_STATUS_UNEXPECTED(std::to_string(static_cast(rc))); + } + } + + using key_type = K; + using value_type = V; + using key_compare = C; + using slot_type = typename T::slot_type; + using traits = T; + using value_allocator = A; + using key_allocator = typename value_allocator::template rebind::other; + using slot_allocator = typename value_allocator::template rebind::other; + + BPlusTree(); + + explicit BPlusTree(const Allocator &alloc); + + ~BPlusTree() noexcept; + + BPlusTree(const BPlusTree &) = delete; + + BPlusTree(BPlusTree &&) = delete; + + BPlusTree &operator=(const BPlusTree &) = delete; + + BPlusTree &operator=(BPlusTree &&) = delete; + + key_compare key_comp() const { return key_less_; } + + size_t size() const { return stats_.size_; } + + bool empty() const { return (size() == 0); } + + /// @param key + /// @param value + /// @return + Status DoInsert(const key_type &key, const value_type &value); + Status DoInsert(const key_type &key, std::unique_ptr &&value); + + // Update a new value for a given key. + std::unique_ptr DoUpdate(const key_type &key, const value_type &new_value); + std::unique_ptr DoUpdate(const key_type &key, std::unique_ptr &&new_value); + + // Statistics + struct tree_stats { + std::atomic size_; + uint32_t leaves_; + uint32_t inner_nodes_; + uint32_t level_; + + tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {} + }; + + private: + // Abstract class of a node (leaf or inner) + class BaseNode { + public: + friend class BPlusTree; + + virtual bool is_leafnode() const = 0; + + virtual bool is_full() const = 0; + + explicit BaseNode(const value_allocator &alloc) : alloc_(alloc) {} + + virtual ~BaseNode() = default; + + protected: + mutable RWLock rw_lock_; + value_allocator alloc_; + + private: + Node lru_; + }; + + // This control block keeps track of all the nodes we traverse on insert. + // To maximize concurrency, internal nodes are latched S. If a node split + // is required, we must releases all the latches and redo it again and change + // the latch mode from S to X. + struct LockPathCB { + enum class LockMode : char { kShared = 0, kExclusive = 1, kNone = 2 }; + + struct path { + BaseNode *node_; + bool locked_; + + path() : node_(nullptr), locked_(false) {} + + path(BaseNode *p, LockMode lockmode) : node_(p), locked_(false) { + if (lockmode == LockMode::kExclusive) { + p->rw_lock_.LockExclusive(); + locked_ = true; + } else if (lockmode == LockMode::kShared) { + p->rw_lock_.LockShared(); + locked_ = true; + } + } + }; + + LockPathCB(BPlusTree *tree, bool retryWithXlock) : self_(tree), latch_shared_(true) { + if (retryWithXlock) { + latch_shared_ = false; + } + if (latch_shared_) { + tree->rw_lock_.LockShared(); + } else { + tree->rw_lock_.LockExclusive(); + } + } + + ~LockPathCB() noexcept { + // Make sure all locks are released. + while (!paths_.empty()) { + path p = paths_.back(); + paths_.pop_back(); + if (p.locked_) { + p.node_->rw_lock_.Unlock(); + } + } + self_->rw_lock_.Unlock(); + self_ = nullptr; + } + + void LockNode(BaseNode *p, LockMode locktype) { paths_.emplace_back(p, locktype); } + + void UnlockMyParents(BaseNode *me) { + path p = paths_.front(); + while (p.node_ != me) { + if (p.locked_) { + p.node_->rw_lock_.Unlock(); + } + paths_.pop_front(); + p = paths_.front(); + } + } + + BPlusTree *self_; + std::deque paths_; + bool latch_shared_; + }; + + // Definition of inner node which fans to either inner node or leaf node. + class InnerNode : public BaseNode { + public: + friend class BPlusTree; + + using alloc_type = typename value_allocator::template rebind::other; + + bool is_leafnode() const override { return false; } + + bool is_full() const override { return (slotuse_ == traits::kInnerSlots); } + + IndexRc Sort(); + + // 50/50 split + IndexRc Split(InnerNode *to, key_type *split_key); + + IndexRc InsertIntoSlot(slot_type slot, const key_type &key, BaseNode *ptr); + + explicit InnerNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} + + ~InnerNode() = default; + + slot_type slot_dir_[traits::kInnerSlots] = {0}; + key_type keys_[traits::kInnerSlots] = {0}; + BaseNode *data_[traits::kInnerSlots + 1] = {nullptr}; + slot_type slotuse_; + }; + + // Definition of a leaf node which contains the key/value pair + class LeafNode : public BaseNode { + public: + friend class BPlusTree; + + using alloc_type = typename value_allocator::template rebind::other; + Node link_; + + bool is_leafnode() const override { return true; } + + bool is_full() const override { return (slotuse_ == traits::kLeafSlots); } + + IndexRc Sort(); + + // 50/50 split + IndexRc Split(LeafNode *to); + + IndexRc InsertIntoSlot(LockPathCB *insCB, slot_type slot, const key_type &key, std::unique_ptr &&value); + + explicit LeafNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} + + ~LeafNode() = default; + + slot_type slot_dir_[traits::kLeafSlots] = {0}; + key_type keys_[traits::kLeafSlots] = {0}; + std::unique_ptr data_[traits::kLeafSlots]; + slot_type slotuse_; + }; + + mutable RWLock rw_lock_; + value_allocator alloc_; + // All the leaf nodes. Used by the iterator to traverse all the key/values. + List leaf_nodes_; + // All the nodes (inner + leaf). Used by the destructor to free the memory of all the nodes. + List all_; + // Pointer to the root of the tree. + BaseNode *root_; + // Key comparison object + key_compare key_less_; + // Stat + tree_stats stats_; + + bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); } + + bool EqualOrLessThan(const key_type &a, const key_type &b) const { return !key_less_(b, a); } + + bool Equal(const key_type &a, const key_type &b) const { return !key_less_(a, b) && !key_less_(b, a); } + + IndexRc AllocateInner(InnerNode **p); + + IndexRc AllocateLeaf(LeafNode **p); + + template + slot_type FindSlot(const node_type *node, const key_type &key, bool *duplicate = nullptr) const { + slot_type lo = 0; + while (lo < node->slotuse_ && key_comp()(node->keys_[node->slot_dir_[lo]], key)) { + ++lo; + } + bool keymatch = (lo < node->slotuse_ && Equal(key, node->keys_[node->slot_dir_[lo]])); + if (keymatch && !node->is_leafnode()) { + // For an inner node and we match a key during search, we should look into the next slot. + ++lo; + } + if (duplicate != nullptr) { + *duplicate = keymatch; + } + return lo; + } + + IndexRc LeafInsertKeyValue(LockPathCB *ins_cb, LeafNode *node, const key_type &key, + std::unique_ptr &&value, key_type *split_key, LeafNode **split_node); + + IndexRc InnerInsertKeyChild(InnerNode *node, const key_type &key, BaseNode *ptr, key_type *split_key, + InnerNode **split_node); + + inline BaseNode *FindBranch(InnerNode *inner, slot_type slot) const { + BaseNode *child = nullptr; + if (slot == 0) { + child = inner->data_[0]; + } else { + child = inner->data_[inner->slot_dir_[slot - 1] + 1]; + } + return child; + } + + IndexRc InsertKeyValue(LockPathCB *ins_cb, BaseNode *n, const key_type &key, std::unique_ptr &&value, + key_type *split_key, BaseNode **split_node); + + IndexRc Locate(RWLock *parent_lock, bool forUpdate, BaseNode *top, const key_type &key, LeafNode **ln, + slot_type *s) const; + + public: + class Iterator : public std::iterator { + public: + using reference = BPlusTree::value_type &; + using pointer = BPlusTree::value_type *; + + explicit Iterator(BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} + + Iterator(LeafNode *leaf, slot_type slot, bool locked = false) : cur_(leaf), slot_(slot), locked_(locked) {} + + ~Iterator(); + + explicit Iterator(const Iterator &); + + Iterator &operator=(const Iterator &lhs); + + Iterator(Iterator &&); + + Iterator &operator=(Iterator &&lhs); + + pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } + + reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + + const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } + + value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + + // Prefix++ + Iterator &operator++(); + + // Postfix++ + Iterator operator++(int); + + // Prefix-- + Iterator &operator--(); + + // Postfix-- + Iterator operator--(int); + + bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } + bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } + + private: + typename BPlusTree::LeafNode *cur_; + slot_type slot_; + bool locked_; + }; + + class ConstIterator : public std::iterator { + public: + using reference = BPlusTree::value_type &; + using pointer = BPlusTree::value_type *; + + explicit ConstIterator(const BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} + + ~ConstIterator(); + + ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) + : cur_(leaf), slot_(slot), locked_(locked) {} + + explicit ConstIterator(const ConstIterator &); + + ConstIterator &operator=(const ConstIterator &lhs); + + ConstIterator(ConstIterator &&); + + ConstIterator &operator=(ConstIterator &&lhs); + + pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } + + reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + + const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } + + value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + + // Prefix++ + ConstIterator &operator++(); + + // Postfix++ + ConstIterator operator++(int); + + // Prefix-- + ConstIterator &operator--(); + + // Postfix-- + ConstIterator operator--(int); + + bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } + bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } + + private: + const typename BPlusTree::LeafNode *cur_; + slot_type slot_; + bool locked_; + }; + + Iterator begin(); + Iterator end(); + + ConstIterator begin() const; + ConstIterator end() const; + + ConstIterator cbegin() const; + ConstIterator cend() const; + + // Locate the entry with key + std::pair Search(const key_type &key) const; + std::pair Search(const key_type &key); + + value_type operator[](key_type key); +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_INDEX_H_ + +#include "btree_impl.tpp" +#include "btree_iterator.tpp" diff --git a/mindspore/ccsrc/dataset/util/btree_impl.tpp b/mindspore/ccsrc/minddata/dataset/util/btree_impl.tpp similarity index 100% rename from mindspore/ccsrc/dataset/util/btree_impl.tpp rename to mindspore/ccsrc/minddata/dataset/util/btree_impl.tpp diff --git a/mindspore/ccsrc/dataset/util/btree_iterator.tpp b/mindspore/ccsrc/minddata/dataset/util/btree_iterator.tpp similarity index 100% rename from mindspore/ccsrc/dataset/util/btree_iterator.tpp rename to mindspore/ccsrc/minddata/dataset/util/btree_iterator.tpp diff --git a/mindspore/ccsrc/minddata/dataset/util/buddy.cc b/mindspore/ccsrc/minddata/dataset/util/buddy.cc new file mode 100644 index 0000000000..d4f5434f81 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/buddy.cc @@ -0,0 +1,388 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/buddy.h" +#include +#include +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/system_pool.h" +#include "utils/log_adapter.h" +#include "./securec.h" + +inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); } + +inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); } + +inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; } + +inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; } + +inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; } + +namespace mindspore { +namespace dataset { +Status BuddySpace::Init() { + if (log_min_ < 0) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "log_min must be positive : " + std::to_string(log_min_)); + } + if (num_lvl_ < 3 || num_lvl_ > 18) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_)); + } + min_ = BitLeftShift(1, log_min_); + max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1); + size_t offset_1 = sizeof(rel_addr_t) * num_lvl_; + size_t offset_2 = sizeof(int) * num_lvl_ + offset_1; + size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2; + RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true)); + hint_ = reinterpret_cast(ptr_); + count_ = reinterpret_cast((reinterpret_cast(ptr_) + offset_1)); + map_ = reinterpret_cast(ptr_) + offset_2; + count_[num_lvl_ - 1] = 1; + map_[0] = BitOr(MORE_BIT, num_lvl_ - 3); + return Status::OK(); +} + +Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept { + std::lock_guard lock(mutex_); + addr_t addr = AllocNoLock(sz, desc); + if (addr != NOSPACE) { + *p = addr; + return Status::OK(); + } else { + return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. Please ignore."); + } +} + +addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept { + MS_ASSERT(sz <= max_); + uint32_t reqSize = SizeToBlock(sz); + rel_addr_t rel_addr = AllocBuddySeg(reqSize); + if (rel_addr != static_cast(NOSPACE)) { + (void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor)); + desc->sig = static_cast(0xDEADBEEF); + desc->addr = rel_addr; + desc->req_size = reqSize; + desc->blk_size = NextPowerOf2(reqSize); + return static_cast(rel_addr * min_); + } else { + return NOSPACE; + } +} + +void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) { + MS_ASSERT(desc->sig == 0XDEADBEEF); + rel_addr_t rel_addr = desc->addr; + size_t blk_size = desc->blk_size; + size_t req_size = desc->req_size; + FreeBuddySeg(rel_addr, blk_size, req_size); +} + +void BuddySpace::Free(const BSpaceDescriptor *desc) { + std::lock_guard lock(mutex_); + return FreeNoLock(desc); +} + +std::ostream &operator<<(std::ostream &os, const BuddySpace &s) { + os << "1 unit = " << s.GetMinSize() << "\n" + << "Size of buddy space = " << s.GetMaxSize() << "\n" + << "Number of levels = " << s.num_lvl_ << "\n\n" + << "Percent free = " << s.PercentFree() << "\n" + << "Dumping count array : " + << "\n"; + for (int i = 0; i < s.num_lvl_; i++) { + os << "[" << i << "] = " << s.count_[i] << " "; + if (((i + 1) % 4) == 0) { + os << "\n"; + } + } + os << "\n"; + os << "Dumping allocation info:" + << "\n"; + auto max_addr = static_cast(BitLeftShift(1, s.num_lvl_ - 1)); + rel_addr_t addr = 0; + while (addr < max_addr) { + size_t sz = 0; + BuddySpace::STATE st; + s.GetBuddySegState(addr, &sz, &st); + os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : " + << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unkonwn")) + << "\n"; + addr += sz; + } + return os; +} + +void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const { + char byte; + int pos; + int offset; + uint64_t val = 0; + int shift; + pos = BitRightShift(rel_addr, 2); + offset = rel_addr % 4; + shift = offset * 2; + byte = map_[pos]; + switch (offset) { + case 0: + val = byte; + break; + case 1: + case 3: + if (offset == 1) { + val = BitLeftShift(BitAnd(byte, 0x30), shift); + } else { + val = BitLeftShift(BitAnd(byte, 0x03), shift); + } + break; + case 2: + val = BitLeftShift(BitAnd(byte, 0x0F), shift); + break; + } + if (BitAnd(val, ONE_BIT)) { + *rel_sz = 1; + } else if (BitAnd(val, TWO_BIT)) { + *rel_sz = 2; + } else if (BitAnd(val, MORE_BIT)) { + log_t lg = BitAnd(val, 0x0F); + *rel_sz = BitLeftShift(1, lg + 2); + } else { + *st = STATE::kEmpty; + return; + } + *st = BitAnd(val, ALLOC_BIT) ? STATE::kAlloc : STATE::kFree; +} + +void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) { + int clr; + int mask; + int pos; + int offset; + int val = 0; + int shift; + auto log_sz = static_cast(Log2(rel_sz)); + pos = BitRightShift(rel_addr, 2); + offset = rel_addr % 4; + shift = offset * 2; + if (rel_sz == 1) { + val = ONE_BIT; + mask = 0xC0; + } else if (rel_sz == 2) { + val = TWO_BIT; + mask = 0xF0; + } else { + val = BitOr(log_sz - 2, MORE_BIT); + mask = 0xFF; + } + if (st == STATE::kAlloc) { + val = BitOr(val, ALLOC_BIT); + } else if (st == STATE::kFree) { + val = BitAnd(val, ~(static_cast(ALLOC_BIT))); + } else if (st == STATE::kEmpty) { + val = 0; + } + clr = static_cast(~(BitRightShift(mask, shift))); + map_[pos] = static_cast(BitAnd(map_[pos], clr)); + map_[pos] = static_cast(BitOr(map_[pos], BitRightShift(val, shift))); + if (st == STATE::kAlloc) { + count_[log_sz]--; + } else if (st == STATE::kFree) { + count_[log_sz]++; + if (rel_addr < hint_[log_sz]) { + hint_[log_sz] = rel_addr; + } + } +} + +void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) { + while (blk_sz < BitLeftShift(1, num_lvl_)) { + rel_addr_t buddy = BitEx(addr, blk_sz); + size_t sz = 0; + STATE st; + GetBuddySegState(buddy, &sz, &st); + if (st == STATE::kFree && sz == blk_sz) { + auto log_sz = static_cast(Log2(blk_sz)); + rel_addr_t left = (buddy < addr) ? buddy : addr; + rel_addr_t right = left + blk_sz; + MS_ASSERT(count_[log_sz] >= 2); + count_[log_sz] -= 2; + SetBuddySegState(right, blk_sz, STATE::kEmpty); + SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree); + for (int i = 0; i < log_sz; i++) { + if (hint_[i] == right) { + hint_[i] = left; + } + } + addr = left; + blk_sz <<= 1u; + } else { + break; + } + } +} + +void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { + MS_ASSERT(ask_sz < blk_sz); + uint32_t inx = Log2(blk_sz); + size_t remaining_sz = ask_sz; + for (int i = inx; i > 0; i--) { + size_t b_size = BitLeftShift(1, i); + size_t half_sz = BitRightShift(b_size, 1); + count_[i]--; + SetBuddySegState(addr, half_sz, STATE::kFree); + SetBuddySegState(addr + half_sz, half_sz, STATE::kFree); + if (remaining_sz >= half_sz) { + SetBuddySegState(addr, half_sz, STATE::kAlloc); + remaining_sz -= half_sz; + if (remaining_sz == 0) { + break; + } + addr += half_sz; + } + } +} + +void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { + MS_ASSERT(ask_sz < blk_sz); + uint32_t inx = Log2(blk_sz); + size_t remaining_sz = ask_sz; + for (int i = inx; i > 0; i--) { + size_t b_size = BitLeftShift(1, i); + size_t half_sz = BitRightShift(b_size, 1); + if (remaining_sz >= half_sz) { +#ifdef DEBUG + { + size_t sz = 0; + STATE st; + GetBuddySegState(addr, &sz, &st); + MS_ASSERT(sz == half_sz && st == STATE::kAlloc); + } +#endif + SetBuddySegState(addr, half_sz, STATE::kFree); + remaining_sz -= half_sz; + if (remaining_sz == 0) { + JoinBuddySeg(addr, half_sz); + break; + } + addr += half_sz; + } + } +} + +rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept { + uint32_t blk_size = NextPowerOf2(req_size); + int start_inx = static_cast(Log2(blk_size)); + bool found = false; + rel_addr_t ask_addr = 0; + auto max_addr = static_cast(BitLeftShift(1, num_lvl_ - 1)); + STATE st; + size_t sz = 0; + for (int i = start_inx; !found && i < num_lvl_; i++) { + MS_ASSERT(count_[i] >= 0); + if (count_[i] == 0) { + continue; + } + auto blk_sz = static_cast(BitLeftShift(1, i)); + ask_addr = hint_[i]; + while (ask_addr < max_addr && !found) { + GetBuddySegState(ask_addr, &sz, &st); + if (st == STATE::kFree && sz == blk_sz) { + found = true; + } else { + MS_ASSERT(st != STATE::kEmpty); + ask_addr += ((sz > blk_sz) ? sz : blk_sz); + } + } + } + if (found) { + if (sz > req_size) { + TrimBuddySeg(ask_addr, sz, req_size); + } else { + SetBuddySegState(ask_addr, sz, STATE::kAlloc); + hint_[start_inx] = ask_addr; + } + return ask_addr; + } else { + return static_cast(NOSPACE); + } +} + +void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) { + if (req_size == blk_size) { +#ifdef DEBUG + { + size_t sz = 0; + STATE st; + GetBuddySegState(addr, &sz, &st); + } +#endif + SetBuddySegState(addr, blk_size, STATE::kFree); + JoinBuddySeg(addr, blk_size); + } else { + UnTrimBuddySeg(addr, blk_size, req_size); + } +} + +int BuddySpace::PercentFree() const { + uint64_t total_free_sz = 0; + uint64_t max_sz_in_unit = BitLeftShift(1, num_lvl_ - 1); + // Go through the count array without lock + for (int i = 0; i < num_lvl_; i++) { + int cnt = count_[i]; + if (cnt == 0) { + continue; + } + uint64_t blk_sz = BitLeftShift(1, i); + total_free_sz += (blk_sz * cnt); + } + return static_cast(static_cast(total_free_sz) / static_cast(max_sz_in_unit) * 100); +} + +BuddySpace::BuddySpace(int log_min, int num_lvl) + : hint_(nullptr), + count_(nullptr), + map_(nullptr), + log_min_(log_min), + num_lvl_(num_lvl), + min_(0), + max_(0), + ptr_(nullptr) {} + +BuddySpace::~BuddySpace() { + if (ptr_ != nullptr) { + free(ptr_); + } + hint_ = nullptr; + count_ = nullptr; + map_ = nullptr; +} + +Status BuddySpace::CreateBuddySpace(std::unique_ptr *out_bs, int log_min, int num_lvl) { + Status rc; + auto bs = new (std::nothrow) BuddySpace(log_min, num_lvl); + if (bs == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = bs->Init(); + if (rc.IsOk()) { + (*out_bs).reset(bs); + } else { + delete bs; + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/buddy.h b/mindspore/ccsrc/minddata/dataset/util/buddy.h new file mode 100644 index 0000000000..b1bcd3ce41 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/buddy.h @@ -0,0 +1,133 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_BUDDY_H_ +#define DATASET_UTIL_BUDDY_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/util/status.h" + +using addr_t = int64_t; +using rel_addr_t = int32_t; +using log_t = int; +#define ALLOC_BIT 0x80 +#define ONE_BIT 0x40 +#define TWO_BIT 0x20 +#define MORE_BIT 0x10 +#define NOSPACE ((addr_t)(-1)) +namespace mindspore { +namespace dataset { +struct BSpaceDescriptor { + int32_t sig; + rel_addr_t addr; + size_t req_size; + size_t blk_size; +}; + +class BuddySpace { + public: + // C++11 feature. Change STATE into a type safe class with + // the keyword. Don't take out the keyword 'class' + enum class STATE { kFree, kAlloc, kEmpty }; + + BuddySpace(const BuddySpace &) = delete; + + BuddySpace &operator=(const BuddySpace &) = delete; + + virtual ~BuddySpace(); + + Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept; + + void Free(const BSpaceDescriptor *desc); + + uint64_t GetMinSize() const { return min_; } + + uint64_t GetMaxSize() const { return max_; } + + int PercentFree() const; + + friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s); + + static uint64_t NextPowerOf2(uint64_t n) { + if (n <= 1) { + return 1; + } + n = n - 1; + while (n & (n - 1)) { + n = n & (n - 1); + } + return n << 1; + } + + static uint32_t Log2(uint64_t n) { + uint32_t cnt = 0; + while (n >>= 1) { + cnt++; + } + return cnt; + } + + static Status CreateBuddySpace(std::unique_ptr *out_bs, int log_min = 15, int num_lvl = 18); + + private: + rel_addr_t *hint_; + int *count_; + char *map_; + int log_min_; + int num_lvl_; + uint64_t min_; + uint64_t max_; + void *ptr_; + std::mutex mutex_; + + explicit BuddySpace(int log_min = 15, int num_lvl = 18); + + Status Init(); + + addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept; + + void FreeNoLock(const BSpaceDescriptor *desc); + + uint32_t SizeToBlock(const uint64_t sz) const { + uint32_t reqSize = (sz / min_); + if (sz % min_) { + reqSize++; + } + return reqSize; + } + + void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const; + + void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st); + + void JoinBuddySeg(rel_addr_t addr, size_t blk_sz); + + void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); + + void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); + + rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept; + + void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_BUDDY_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc new file mode 100644 index 0000000000..22fb72eb8a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc @@ -0,0 +1,197 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "common/utils.h" +#include "minddata/dataset/util/cache_pool.h" +#include "minddata/dataset/util/services.h" + +namespace mindspore { +namespace dataset { +CachePool::CachePool(const value_allocator &alloc, const std::string &root) + : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} + +Status CachePool::DoServiceStart() { + tree_ = std::make_shared(); + // If we are given a disk path, set up the StorageManager + if (!root_.toString().empty()) { + Path spill = GetSpillPath(); + RETURN_IF_NOT_OK(spill.CreateDirectories()); + sm_ = std::make_shared(spill); + RETURN_IF_NOT_OK(sm_->ServiceStart()); + MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); + } + return Status::OK(); +} +Status CachePool::DoServiceStop() { + Status rc; + Status rc2; + if (sm_ != nullptr) { + rc = sm_->ServiceStop(); + if (rc.IsError()) { + rc2 = rc; + } + } + sm_.reset(); + for (auto &bl : *tree_) { + if (bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, bl.sz); + } + } + tree_.reset(); + if (!root_.toString().empty()) { + Path spill = GetSpillPath(); + auto it = Path::DirIterator::OpenDirectory(&spill); + while (it->hasNext()) { + rc = it->next().Remove(); + if (rc.IsError() && rc2.IsOk()) { + rc2 = rc; + } + } + rc = spill.Remove(); + if (rc.IsError() && rc2.IsOk()) { + rc2 = rc; + } + } + return rc2; +} +CachePool::~CachePool() noexcept { (void)ServiceStop(); } +Status CachePool::Insert(const std::vector &buf, CachePool::key_type *key) { + DataLocator bl; + Status rc; + size_t sz = 0; + // We will consolidate all the slices into one piece. + for (auto &v : buf) { + sz += v.GetSize(); + } + bl.sz = sz; + try { + bl.ptr = alloc_.allocate(sz); + // We will do a piecewise copy. + WritableSlice dest(bl.ptr, bl.sz); + size_t pos = 0; + for (auto &v : buf) { + WritableSlice out(dest, pos); + rc = WritableSlice::Copy(&out, v); + if (rc.IsError()) { + break; + } + pos += v.GetSize(); + } + if (rc.IsError()) { + alloc_.deallocate(bl.ptr, sz); + bl.ptr = nullptr; + return rc; + } + } catch (std::bad_alloc &e) { + if (sm_ != nullptr) { + RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); + } else { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + rc = tree_->insert(bl, key); + if (rc.IsError() && bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, sz); + } + return rc; +} +Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { + RETURN_UNEXPECTED_IF_NULL(dest); + auto r = tree_->Search(key); + if (r.second) { + auto &it = r.first; + if (it->ptr != nullptr) { + ReadableSlice src(it->ptr, it->sz); + RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src)); + } else if (sm_ != nullptr) { + size_t expectedLength = 0; + RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength)); + if (expectedLength != it->sz) { + MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "." + << " Internal key: " << key << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); + } + } + if (bytesRead != nullptr) { + *bytesRead = it->sz; + } + } else { + RETURN_STATUS_UNEXPECTED("Key not found"); + } + return Status::OK(); +} +const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } +Path CachePool::GetSpillPath() const { + auto spill = Path(root_) / subfolder_; + return spill; +} +CachePool::CacheStat CachePool::GetStat() const { + CacheStat cs{0}; + for (auto &it : *tree_) { + if (it.ptr != nullptr) { + ++cs.num_mem_cached; + } else { + ++cs.num_disk_cached; + } + } + return cs; +} +Status CachePool::Spill(CachePool::DataLocator *dl) { + if (sm_ == nullptr) { + RETURN_STATUS_UNEXPECTED("No disk storage to spill"); + } + RETURN_UNEXPECTED_IF_NULL(dl); + RETURN_UNEXPECTED_IF_NULL(dl->ptr); + if (dl->storage_key == 0) { + ReadableSlice data(dl->ptr, dl->sz); + RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); + } + alloc_.deallocate(dl->ptr, dl->sz); + dl->ptr = nullptr; + return Status::OK(); +} +Status CachePool::Locate(CachePool::DataLocator *dl) { + RETURN_UNEXPECTED_IF_NULL(dl); + if (dl->ptr == nullptr) { + if (sm_ == nullptr) { + RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); + } + try { + dl->ptr = alloc_.allocate(dl->sz); + WritableSlice dest(dl->ptr, dl->sz); + Status rc = Read(dl->storage_key, &dest); + if (rc.IsError()) { + alloc_.deallocate(dl->ptr, dl->sz); + dl->ptr = nullptr; + return rc; + } + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + return Status::OK(); +} +size_t CachePool::GetSize(CachePool::key_type key) const { + auto r = tree_->Search(key); + if (r.second) { + auto &it = r.first; + return it->sz; + } else { + return 0; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h new file mode 100644 index 0000000000..cdb6da16b6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h @@ -0,0 +1,139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_CACHE_POOL_H_ +#define DATASET_UTIL_CACHE_POOL_H_ + +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/storage_manager.h" +#include "minddata/dataset/util/auto_index.h" + +namespace mindspore { +namespace dataset { +/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of +/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to +/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to +/// restore the buffer. +/// \see ReadableSlice +class CachePool : public Service { + public: + using base_type = uint8_t; + using pointer = base_type *; + using const_pointer = const base_type *; + using reference = base_type &; + using const_reference = const base_type &; + using value_allocator = Allocator; + + // An internal class to locate the whereabouts of a backed up buffer which can be either in + class DataLocator { + public: + DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} + ~DataLocator() = default; + DataLocator(const DataLocator &other) = default; + DataLocator &operator=(const DataLocator &other) = default; + DataLocator(DataLocator &&other) noexcept { + ptr = other.ptr; + sz = other.sz; + storage_key = other.storage_key; + other.ptr = nullptr; + other.sz = 0; + other.storage_key = 0; + } + DataLocator &operator=(DataLocator &&other) noexcept { + if (&other != this) { + ptr = other.ptr; + sz = other.sz; + storage_key = other.storage_key; + other.ptr = nullptr; + other.sz = 0; + other.storage_key = 0; + } + return *this; + } + pointer ptr; + size_t sz; + StorageManager::key_type storage_key; + }; + + using data_index = AutoIndexObj; + using key_type = data_index::key_type; + using bl_alloc_type = typename value_allocator::template rebind::other; + + /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and + /// how many elements are spilled to disk. + struct CacheStat { + int64_t num_mem_cached; + int64_t num_disk_cached; + }; + + /// \brief Constructor + /// \param alloc Allocator to allocate memory from + /// \param root Optional disk folder to spill + explicit CachePool(const value_allocator &alloc, const std::string &root = ""); + + CachePool(const CachePool &) = delete; + CachePool(CachePool &&) = delete; + CachePool &operator=(const CachePool &) = delete; + CachePool &operator=(CachePool &&) = delete; + ~CachePool() noexcept; + + Status DoServiceStart() override; + Status DoServiceStop() override; + + Path GetSpillPath() const; + + /// \brief Insert a sequence of ReadableSlice objects into the pool. + /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. + /// \param[in] buf A sequence of ReadableSlice objects. + /// \param[out] key Generated key + /// \return Error code + Status Insert(const std::vector &buf, key_type *key); + /// \brief Restore a cached buffer (from memory or disk) + /// \param[in] key A previous key returned from Insert + /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice + /// \param[out] bytesRead Optional. Number of bytes read. + /// \return Error code + Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; + + Status Spill(DataLocator *dl); + + Status Locate(DataLocator *dl); + + size_t GetSize(key_type key) const; + + /// \brief Get statistics. + /// \return CacheStat object + CacheStat GetStat() const; + + const value_allocator &get_allocator() const; + + std::string MyName() const { return subfolder_; } + + private: + value_allocator alloc_; + Path root_; + const std::string subfolder_; + std::shared_ptr sm_; + std::shared_ptr tree_; +}; +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc b/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc new file mode 100644 index 0000000000..f99e6de2f1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc @@ -0,0 +1,225 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/circular_pool.h" + +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/system_pool.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +Status CircularPool::AddOneArena() { + Status rc; + std::shared_ptr b; + RETURN_IF_NOT_OK(Arena::CreateArena(&b, arena_size_)); + tail_ = b.get(); + cur_size_in_mb_ += arena_size_; + mem_segments_.push_back(std::move(b)); + return Status::OK(); +} + +ListOfArenas::iterator CircularPool::CircularIterator::Next() { + ListOfArenas::iterator it = dp_->mem_segments_.begin(); + uint32_t size = dp_->mem_segments_.size(); + // This is what we return + it += cur_; + // Prepare for the next round + cur_++; + if (cur_ == size) { + if (start_ == 0) { + has_next_ = false; + } else { + wrap_ = true; + cur_ = 0; + } + } else if (cur_ == start_) { + has_next_ = false; + } + return it; +} + +bool CircularPool::CircularIterator::has_next() const { return has_next_; } + +void CircularPool::CircularIterator::Reset() { + wrap_ = false; + has_next_ = false; + if (!dp_->mem_segments_.empty()) { + // Find the buddy arena that corresponds to the tail. + cur_tail_ = dp_->tail_; + auto list_end = dp_->mem_segments_.end(); + auto it = std::find_if(dp_->mem_segments_.begin(), list_end, + [this](const std::shared_ptr &b) { return b.get() == cur_tail_; }); + MS_ASSERT(it != list_end); + start_ = std::distance(dp_->mem_segments_.begin(), it); + cur_ = start_; + has_next_ = true; + } +} + +CircularPool::CircularIterator::CircularIterator(CircularPool *dp) : dp_(dp) { Reset(); } + +Status CircularPool::Allocate(size_t n, void **p) { + if (p == nullptr) { + RETURN_STATUS_UNEXPECTED("p is null"); + } + Status rc; + void *ptr = nullptr; + do { + SharedLock lock_s(&rw_lock_); + int prevSzInMB = cur_size_in_mb_; + bool move_tail = false; + CircularIterator cirIt(this); + while (cirIt.has_next()) { + auto it = cirIt.Next(); + Arena *ba = it->get(); + if (ba->get_max_size() < n) { + return Status(StatusCode::kOutOfMemory); + } + // If we are asked to move forward the tail + if (move_tail) { + Arena *expected = cirIt.cur_tail_; + (void)atomic_compare_exchange_weak(&tail_, &expected, ba); + move_tail = false; + } + rc = ba->Allocate(n, &ptr); + if (rc.IsOk()) { + *p = ptr; + break; + } else if (rc.IsOutofMemory()) { + // Make the next arena a new tail and continue. + move_tail = true; + } else { + return rc; + } + } + + // Handle the case we have done one round robin search. + if (ptr == nullptr) { + // If we have room to expand. + if (unlimited_ || cur_size_in_mb_ < max_size_in_mb_) { + // lock in exclusively mode. + lock_s.Upgrade(); + // Check again if someone has already expanded. + if (cur_size_in_mb_ == prevSzInMB) { + RETURN_IF_NOT_OK(AddOneArena()); + } + // Re-acquire the shared lock and try again + lock_s.Downgrade(); + } else { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + } while (ptr == nullptr); + return rc; +} + +void CircularPool::Deallocate(void *p) { + // Lock in the chain in shared mode and find out which + // segment it comes from + SharedLock lock(&rw_lock_); + auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { + char *q = reinterpret_cast(p); + char *base = const_cast(reinterpret_cast(b->get_base_addr())); + return (q > base && q < base + b->get_max_size()); + }); + lock.Unlock(); + it->get()->Deallocate(p); +} + +Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) { + // Lock in the chain in shared mode and find out which + // segment it comes from + if (pp == nullptr) { + RETURN_STATUS_UNEXPECTED("pp is null"); + } + void *p = *pp; + SharedLock lock(&rw_lock_); + auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { + char *q = reinterpret_cast(p); + char *base = const_cast(reinterpret_cast(b->get_base_addr())); + return (q > base && q < base + b->get_max_size()); + }); + lock.Unlock(); + MS_ASSERT(it != mem_segments_.end()); + Arena *ba = it->get(); + Status rc = ba->Reallocate(pp, old_sz, new_sz); + if (rc.IsOutofMemory()) { + // The current arena has no room for the bigger size. + // Allocate free space from another arena and copy + // the content over. + void *q = nullptr; + rc = this->Allocate(new_sz, &q); + RETURN_IF_NOT_OK(rc); + errno_t err = memcpy_s(q, new_sz, p, old_sz); + if (err) { + this->Deallocate(q); + RETURN_STATUS_UNEXPECTED(std::to_string(err)); + } + *pp = q; + ba->Deallocate(p); + } + return Status::OK(); +} + +uint64_t CircularPool::get_max_size() const { return mem_segments_.front()->get_max_size(); } + +int CircularPool::PercentFree() const { + int percent_free = 0; + int num_arena = 0; + for (auto const &p : mem_segments_) { + percent_free += p->PercentFree(); + num_arena++; + } + if (num_arena) { + return percent_free / num_arena; + } else { + return 100; + } +} + +CircularPool::CircularPool(int max_size_in_gb, int arena_size) + : unlimited_(max_size_in_gb <= 0), + max_size_in_mb_(unlimited_ ? std::numeric_limits::max() : max_size_in_gb * 1024), + arena_size_(arena_size), + cur_size_in_mb_(0) {} + +Status CircularPool::CreateCircularPool(std::shared_ptr *out_pool, int max_size_in_gb, int arena_size, + bool createOneArena) { + Status rc; + if (out_pool == nullptr) { + RETURN_STATUS_UNEXPECTED("pPool is null"); + } + auto pool = new (std::nothrow) CircularPool(max_size_in_gb, arena_size); + if (pool == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + if (createOneArena) { + rc = pool->AddOneArena(); + } + if (rc.IsOk()) { + (*out_pool).reset(pool); + } else { + delete pool; + } + return rc; +} + +CircularPool::~CircularPool() = default; +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/circular_pool.h b/mindspore/ccsrc/minddata/dataset/util/circular_pool.h new file mode 100644 index 0000000000..a63afbd691 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/circular_pool.h @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_CIRCULAR_POOL_H_ +#define DATASET_UTIL_CIRCULAR_POOL_H_ + +#include +#include +#include +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/arena.h" +#include "minddata/dataset/util/lock.h" + +namespace mindspore { +namespace dataset { +using ListOfArenas = std::vector>; + +// This is a dynamic memory pool built on top of memory +// segment each of which is 4G in size. Initially we start +// with one segment, and gradually add segments (not +// guaranteed contiguous) until we reach 32G in size. There +// is an assumption about this kind of memory pool. Allocated +// memory is not held for the whole duration of the pool and +// will be released soon. Based on this assumption, memory is +// obtained from the tail while allocated memory is returned +// to the head of the pool. +class CircularPool : public MemoryPool { + public: + class CircularIterator { + friend class CircularPool; + + public: + explicit CircularIterator(CircularPool *dp); + + ~CircularIterator() = default; + + bool has_next() const; + + ListOfArenas::iterator Next(); + + void Reset(); + + private: + CircularPool *dp_; + Arena *cur_tail_{}; + uint32_t start_{}; + uint32_t cur_{}; + bool wrap_{}; + bool has_next_{}; + }; + + CircularPool(const CircularPool &) = delete; + + CircularPool &operator=(const CircularPool &) = delete; + + ~CircularPool() override; + + Status Allocate(size_t n, void **) override; + + Status Reallocate(void **, size_t old_size, size_t new_size) override; + + void Deallocate(void *) override; + + uint64_t get_max_size() const override; + + int PercentFree() const override; + + friend std::ostream &operator<<(std::ostream &os, const CircularPool &s) { + int i = 0; + for (auto it = s.mem_segments_.begin(); it != s.mem_segments_.end(); ++it, ++i) { + os << "Dumping segment " << i << "\n" << *(it->get()); + } + return os; + } + + static Status CreateCircularPool(std::shared_ptr *out_pool, int max_size_in_gb = -1, + int arena_size = 4096, bool create_one_arena = false); + + private: + ListOfArenas mem_segments_; + std::atomic tail_{}; + bool unlimited_; + int max_size_in_mb_; + int arena_size_; + int cur_size_in_mb_; + RWLock rw_lock_; + + // We can take negative or 0 as input which means unlimited. + CircularPool(int max_size_in_gb, int arena_size); + + Status AddOneArena(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_CIRCULAR_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/cond_var.cc b/mindspore/ccsrc/minddata/dataset/util/cond_var.cc new file mode 100644 index 0000000000..b7c7b76cae --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/cond_var.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/cond_var.h" +#include +#include +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {} + +Status CondVar::Wait(std::unique_lock *lck, const std::function &pred) { + try { + if (svc_ != nullptr) { + // If this cv registers with a global resource tracking, then wait unconditionally. + auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); }; + cv_.wait(*lck, f); + // If we are interrupted, override the return value if this is the master thread. + // Master thread is being interrupted mostly because of some thread is reporting error. + RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus())); + } else { + // Otherwise we wake up once a while to check for interrupt (for this thread). + auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); }; + while (!f()) { + (void)cv_.wait_for(*lck, std::chrono::milliseconds(1)); + } + RETURN_IF_INTERRUPTED(); + } + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return Status::OK(); +} + +CondVar::~CondVar() noexcept { + if (svc_ != nullptr) { + (void)svc_->Deregister(my_name_); + svc_ = nullptr; + } +} + +void CondVar::NotifyOne() noexcept { cv_.notify_one(); } + +void CondVar::NotifyAll() noexcept { cv_.notify_all(); } + +Status CondVar::Register(std::shared_ptr svc) { + Status rc = svc->Register(my_name_, this); + if (rc.IsOk()) { + svc_ = svc; + } + return rc; +} + +void CondVar::Interrupt() { + IntrpResource::Interrupt(); + cv_.notify_all(); +} + +std::string CondVar::my_name() const { return my_name_; } + +Status CondVar::Deregister() { + if (svc_) { + Status rc = svc_->Deregister(my_name_); + svc_ = nullptr; + return rc; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/cond_var.h b/mindspore/ccsrc/minddata/dataset/util/cond_var.h new file mode 100644 index 0000000000..88fcad24a2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/cond_var.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_COND_VAR_H_ +#define DATASET_UTIL_COND_VAR_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/intrp_resource.h" +#include "minddata/dataset/util/intrp_service.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class CondVar : public IntrpResource { + public: + CondVar(); + + ~CondVar() noexcept; + + Status Wait(std::unique_lock *lck, const std::function &pred); + + void Interrupt() override; + + void NotifyOne() noexcept; + + void NotifyAll() noexcept; + + Status Register(std::shared_ptr svc); + + std::string my_name() const; + + Status Deregister(); + + protected: + std::condition_variable cv_; + std::shared_ptr svc_; + + private: + std::string my_name_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_COND_VAR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h b/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h new file mode 100644 index 0000000000..9d78e2cd32 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_INTRP_RESOURCE_H_ +#define DATASET_UTIL_INTRP_RESOURCE_H_ + +#include +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class IntrpResource { + public: + enum class State : int { kRunning, kInterrupted }; + + IntrpResource() : st_(State::kRunning) {} + + virtual ~IntrpResource() = default; + + virtual void Interrupt() { st_ = State::kInterrupted; } + + virtual void ResetIntrpState() { st_ = State::kRunning; } + + State CurState() const { return st_; } + + bool Interrupted() const { return CurState() == State::kInterrupted; } + + virtual Status GetInterruptStatus() const { + if (Interrupted()) { + return Status(StatusCode::kInterrupted); + } + return Status::OK(); + } + + protected: + std::atomic st_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_INTRP_RESOURCE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc b/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc new file mode 100644 index 0000000000..a82c82cdc9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/intrp_service.h" +#include +#include "common/utils.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +IntrpService::IntrpService() : high_water_mark_(0) { (void)ServiceStart(); } + +IntrpService::~IntrpService() noexcept { + MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << "."; + if (!all_intrp_resources_.empty()) { + try { + InterruptAll(); + } catch (const std::exception &e) { + // Ignore all error as we can't throw in the destructor. + } + } + (void)ServiceStop(); +} + +Status IntrpService::Register(const std::string &name, IntrpResource *res) { + SharedLock stateLck(&state_lock_); + // Now double check the state + if (ServiceState() != STATE::kRunning) { + return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "Interrupt service is shutting down"); + } else { + std::lock_guard lck(mutex_); + try { + std::ostringstream ss; + ss << this_thread::get_id(); + MS_LOG(DEBUG) << "Register resource with name " << name << ". Thread ID " << ss.str() << "."; + auto it = all_intrp_resources_.emplace(name, res); + if (it.second == false) { + return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, name); + } + high_water_mark_++; + } catch (std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + } + return Status::OK(); +} + +Status IntrpService::Deregister(const std::string &name) noexcept { + std::lock_guard lck(mutex_); + try { + std::ostringstream ss; + ss << this_thread::get_id(); + MS_LOG(DEBUG) << "De-register resource with name " << name << ". Thread ID is " << ss.str() << "."; + auto n = all_intrp_resources_.erase(name); + if (n == 0) { + MS_LOG(INFO) << "Key " << name << " not found."; + } + } catch (std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return Status::OK(); +} + +void IntrpService::InterruptAll() noexcept { + std::lock_guard lck(mutex_); + for (auto const &it : all_intrp_resources_) { + std::string kName = it.first; + try { + it.second->Interrupt(); + } catch (const std::exception &e) { + // continue the clean up. + } + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_service.h b/mindspore/ccsrc/minddata/dataset/util/intrp_service.h new file mode 100644 index 0000000000..cb6bf30c73 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_service.h @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_INTRP_SERVICE_H_ +#define DATASET_UTIL_INTRP_SERVICE_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/intrp_resource.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +using SvcAllocator = Allocator>; + +class IntrpService : public Service { + public: + IntrpService(); + + ~IntrpService() noexcept override; + + IntrpService(const IntrpService &) = delete; + + IntrpService &operator=(const IntrpService &) = delete; + + Status Register(const std::string &name, IntrpResource *res); + + Status Deregister(const std::string &name) noexcept; + + void InterruptAll() noexcept; + + Status DoServiceStart() override { return Status::OK(); } + + Status DoServiceStop() override { return Status::OK(); } + + private: + int high_water_mark_; + std::mutex mutex_; + std::map all_intrp_resources_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_INTRP_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/util/list.h b/mindspore/ccsrc/minddata/dataset/util/list.h similarity index 100% rename from mindspore/ccsrc/dataset/util/list.h rename to mindspore/ccsrc/minddata/dataset/util/list.h diff --git a/mindspore/ccsrc/minddata/dataset/util/lock.cc b/mindspore/ccsrc/minddata/dataset/util/lock.cc new file mode 100644 index 0000000000..5302196a46 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/lock.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/lock.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +void SpinLock::Lock() { + while (true) { + int expected = kUnlocked; + if (val_.compare_exchange_weak(expected, kLocked)) { + break; + } + } +} + +bool SpinLock::TryLock() { + int expected = kUnlocked; + return val_.compare_exchange_strong(expected, kLocked); +} + +void SpinLock::Unlock() noexcept { val_.store(kUnlocked); } + +void RWLock::LockShared() { + std::unique_lock lck(mtx_); + waiting_readers_ += 1; + read_cv_.wait(lck, [this]() { return (waiting_writers_ == 0 && status_ >= 0); }); + waiting_readers_ -= 1; + status_ += 1; +} + +void RWLock::Unlock() noexcept { + std::unique_lock lck(mtx_); + if (status_ == -1) { + // I am the writer. By definition, no other writer nor reader. + status_ = 0; + } else if (status_ > 0) { + // One less reader + status_ -= 1; + } + // Wake up writer only if there is no reader. + if (waiting_writers_ > 0) { + if (status_ == 0) { + write_cv_.notify_one(); + } + } else { + read_cv_.notify_all(); + } +} + +void RWLock::Upgrade() { + std::unique_lock lck(mtx_); + MS_ASSERT(status_); + if (status_ == -1) { + // I am a writer already. + return; + } else if (status_ == 1) { + // If I am the only reader. Just change the status. + status_ = -1; + return; + } else { + // In all other cases, let of the shared lock and relock in exclusive. + lck.unlock(); + this->Unlock(); + this->LockExclusive(); + } +} + +void RWLock::Downgrade() { + std::unique_lock lck(mtx_); + MS_ASSERT(status_); + if (status_ == -1) { + // If there are no other writers waiting, just change the status + if (waiting_writers_ == 0) { + status_ = 1; + } else { + // Otherwise just unlock and relock in shared + lck.unlock(); + this->Unlock(); + this->LockShared(); + } + } else if (status_ > 0) { + return; + } +} + +SharedLock::SharedLock(RWLock *rw) : rw_(rw), ownlock_(false) { + rw_->LockShared(); + ownlock_ = true; +} + +SharedLock::~SharedLock() { + if (ownlock_) { + rw_->Unlock(); + ownlock_ = false; + } + rw_ = nullptr; +} + +void SharedLock::Unlock() { + MS_ASSERT(ownlock_ == true); + rw_->Unlock(); + ownlock_ = false; +} + +void SharedLock::Lock() { + MS_ASSERT(ownlock_ == false); + rw_->LockShared(); + ownlock_ = true; +} + +void SharedLock::Upgrade() { + MS_ASSERT(ownlock_ == true); + rw_->Upgrade(); +} + +void SharedLock::Downgrade() { + MS_ASSERT(ownlock_ == true); + rw_->Downgrade(); +} + +UniqueLock::UniqueLock(RWLock *rw) : rw_(rw), ownlock_(false) { + rw_->LockExclusive(); + ownlock_ = true; +} + +UniqueLock::~UniqueLock() { + if (ownlock_) { + rw_->Unlock(); + ownlock_ = false; + } + rw_ = nullptr; +} + +void UniqueLock::Unlock() { + MS_ASSERT(ownlock_ == true); + rw_->Unlock(); + ownlock_ = false; +} + +void UniqueLock::Lock() { + MS_ASSERT(ownlock_ == false); + rw_->LockExclusive(); + ownlock_ = true; +} + +LockGuard::LockGuard(SpinLock *lock) : lck_(lock), own_lock_(false) { + lck_->Lock(); + own_lock_ = true; +} + +LockGuard::~LockGuard() { + if (own_lock_) { + lck_->Unlock(); + own_lock_ = false; + } + lck_ = nullptr; +} + +void LockGuard::Unlock() { + MS_ASSERT(own_lock_); + lck_->Unlock(); + own_lock_ = false; +} + +void LockGuard::Lock() { + MS_ASSERT(own_lock_ == false); + lck_->Lock(); + own_lock_ = true; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/lock.h b/mindspore/ccsrc/minddata/dataset/util/lock.h similarity index 100% rename from mindspore/ccsrc/dataset/util/lock.h rename to mindspore/ccsrc/minddata/dataset/util/lock.h diff --git a/mindspore/ccsrc/minddata/dataset/util/memory_pool.cc b/mindspore/ccsrc/minddata/dataset/util/memory_pool.cc new file mode 100644 index 0000000000..0e1be9d798 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/memory_pool.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/memory_pool.h" +#include "./securec.h" + +namespace mindspore { +namespace dataset { +Status DeMalloc(std::size_t s, void **p, bool init_to_zero = false) { + if (p == nullptr) { + RETURN_STATUS_UNEXPECTED("p is null"); + } + void *q = ::malloc(s); + if (q == nullptr) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } else { + *p = q; + if (init_to_zero) { + (void)memset_s(q, s, 0, s); + } + return Status::OK(); + } +} +} // namespace dataset +} // namespace mindspore + +void *operator new(std::size_t s, mindspore::dataset::Status *rc, std::shared_ptr b) { + void *ptr = nullptr; + *rc = b->Allocate(s, &ptr); + return ptr; +} + +void *operator new[](std::size_t s, mindspore::dataset::Status *rc, std::shared_ptr b) { + void *ptr = nullptr; + *rc = b->Allocate(s, &ptr); + return ptr; +} + +void operator delete(void *p, std::shared_ptr b) { + if (p != nullptr) b->Deallocate(p); +} + +void operator delete[](void *p, std::shared_ptr b) { + if (p != nullptr) b->Deallocate(p); +} diff --git a/mindspore/ccsrc/minddata/dataset/util/memory_pool.h b/mindspore/ccsrc/minddata/dataset/util/memory_pool.h new file mode 100644 index 0000000000..c7cc473109 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/memory_pool.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_MEMORY_POOL_H_ +#define DATASET_UTIL_MEMORY_POOL_H_ + +#include +#include +#include +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// Abstract class of a memory pool +class MemoryPool { + public: + // Allocate a block of size n + virtual Status Allocate(size_t, void **) = 0; + + // Enlarge or shrink a block from oldSz to newSz + virtual Status Reallocate(void **, size_t old_sz, size_t new_sz) = 0; + + // Free a pointer + virtual void Deallocate(void *) = 0; + + // What is the maximum size I can allocate ? + virtual uint64_t get_max_size() const = 0; + + virtual int PercentFree() const = 0; + + // Destructor + virtual ~MemoryPool() {} +}; + +Status DeMalloc(std::size_t s, void **p, bool); +} // namespace dataset +} // namespace mindspore + +void *operator new(std::size_t, mindspore::dataset::Status *, std::shared_ptr); + +void *operator new[](std::size_t, mindspore::dataset::Status *, std::shared_ptr); + +void operator delete(void *, std::shared_ptr); + +void operator delete[](void *, std::shared_ptr); + +#endif // DATASET_UTIL_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/path.cc b/mindspore/ccsrc/minddata/dataset/util/path.cc new file mode 100644 index 0000000000..8740ecb8e0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/path.cc @@ -0,0 +1,340 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/path.h" + +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +#if defined(_WIN32) || defined(_WIN64) +char Path::separator_ = '\\'; +#else +char Path::separator_ = '/'; +#endif + +Path::Path(const std::string &s) : path_(s) {} + +Path::Path(const char *p) : path_(p) {} + +Path::Path(const Path &p) : path_(p.path_) {} + +Path &Path::operator=(const Path &p) { + if (&p != this) { + this->path_ = p.path_; + } + return *this; +} + +Path &Path::operator=(Path &&p) noexcept { + if (&p != this) { + this->path_ = std::move(p.path_); + } + return *this; +} + +Path::Path(Path &&p) noexcept { this->path_ = std::move(p.path_); } + +Path Path::operator+(const Path &p) { + std::string q = path_ + p.toString(); + return Path(q); +} + +Path Path::operator+(const std::string &p) { + std::string q = path_ + p; + return Path(q); +} + +Path Path::operator+(const char *p) { + std::string q = path_ + p; + return Path(q); +} + +Path &Path::operator+=(const Path &rhs) { + path_ += rhs.toString(); + return *this; +} + +Path &Path::operator+=(const std::string &p) { + path_ += p; + return *this; +} + +Path &Path::operator+=(const char *p) { + path_ += p; + return *this; +} + +Path Path::operator/(const Path &p) { + std::string q = path_ + separator_ + p.toString(); + return Path(q); +} + +Path Path::operator/(const std::string &p) { + std::string q = path_ + separator_ + p; + return Path(q); +} + +Path Path::operator/(const char *p) { + std::string q = path_ + separator_ + p; + return Path(q); +} + +std::string Path::Extension() const { + std::size_t found = path_.find_last_of('.'); + if (found != std::string::npos) { + return path_.substr(found); + } else { + return std::string(""); + } +} + +bool Path::Exists() { + struct stat sb; + int rc = stat(common::SafeCStr(path_), &sb); + if (rc == -1 && errno != ENOENT) { + MS_LOG(WARNING) << "Unable to query the status of " << path_ << ". Errno = " << errno << "."; + } + return (rc == 0); +} + +bool Path::IsDirectory() { + struct stat sb; + int rc = stat(common::SafeCStr(path_), &sb); + if (rc == 0) { + return S_ISDIR(sb.st_mode); + } else { + return false; + } +} + +Status Path::CreateDirectory() { + if (!Exists()) { +#if defined(_WIN32) || defined(_WIN64) + int rc = mkdir(common::SafeCStr(path_)); +#else + int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR); +#endif + if (rc) { + std::ostringstream oss; + oss << "Unable to create directory " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + return Status::OK(); + } else { + if (IsDirectory()) { + return Status::OK(); + } else { + std::ostringstream oss; + oss << "Unable to create directory " << path_ << ". It exists but is not a directory"; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } +} + +std::string Path::ParentPath() { + std::string r(""); + std::size_t found = path_.find_last_of(separator_); + if (found != std::string::npos) { + if (found == 0) { + r += separator_; + } else { + r = std::string(path_.substr(0, found)); + } + } + return r; +} + +Status Path::CreateDirectories() { + if (IsDirectory()) { + MS_LOG(DEBUG) << "Directory " << toString() << " already exists."; + return Status::OK(); + } else { + MS_LOG(DEBUG) << "Creating directory " << toString() << "."; + std::string parent = ParentPath(); + if (!parent.empty()) { + if (Path(parent).CreateDirectories()) { + return CreateDirectory(); + } + } else { + return CreateDirectory(); + } + } + return Status::OK(); +} + +Status Path::Remove() { + if (Exists()) { + if (IsDirectory()) { + errno_t err = rmdir(common::SafeCStr(path_)); + if (err == -1) { + std::ostringstream oss; + oss << "Unable to delete directory " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } else { + errno_t err = unlink(common::SafeCStr(path_)); + if (err == -1) { + std::ostringstream oss; + oss << "Unable to delete file " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } + } + return Status::OK(); +} + +Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); } + +Status Path::OpenFile(int *file_descriptor, bool create) { + int fd; + if (file_descriptor == nullptr) { + RETURN_STATUS_UNEXPECTED("null pointer"); + } + if (IsDirectory()) { + std::ostringstream oss; + oss << "Unable to create file " << path_ << " which is a directory."; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + // Convert to canonical form. + if (strlen(common::SafeCStr(path_)) > PATH_MAX) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + char canonical_path[PATH_MAX + 1] = {0x00}; +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) { +#else + if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) { +#endif + if (errno == ENOENT && create) { + // File doesn't exist and we are to create it. Let's break it down. + auto file_part = Basename(); + auto parent_part = ParentPath(); +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) { +#else + if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) { +#endif + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto cur_inx = strlen(canonical_path); + if ((cur_inx + file_part.length() + 1) > PATH_MAX) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + canonical_path[cur_inx++] = separator_; + if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) != + EOK) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + } else { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + } + if (create) { + fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP); + } else { + fd = open(canonical_path, O_RDWR); + } + if (fd == -1) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + *file_descriptor = fd; + return Status::OK(); +} + +Status Path::CloseFile(int fd) const { + if (close(fd) < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + return Status::OK(); +} + +Status Path::TruncateFile(int fd) const { + int rc; + rc = ftruncate(fd, 0); + if (rc == 0) { + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } +} + +std::string Path::Basename() { + std::size_t found = path_.find_last_of(separator_); + if (found != std::string::npos) { + return path_.substr(found + 1); + } else { + return path_; + } +} + +std::shared_ptr Path::DirIterator::OpenDirectory(Path *f) { + auto it = new (std::nothrow) DirIterator(f); + + if (it == nullptr) { + return nullptr; + } + + if (it->dp_) { + return std::shared_ptr(it); + } else { + delete it; + return nullptr; + } +} + +Path::DirIterator::~DirIterator() { + if (dp_) { + (void)closedir(dp_); + } + dp_ = nullptr; + dir_ = nullptr; + entry_ = nullptr; +} + +Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) { + MS_LOG(DEBUG) << "Open directory " << f->toString() << "."; + dp_ = opendir(f->toString().c_str()); +} + +bool Path::DirIterator::hasNext() { + do { + entry_ = readdir(dp_); + if (entry_) { + if (strcmp(entry_->d_name, ".") == 0 || strcmp(entry_->d_name, "..") == 0) { + continue; + } + } + break; + } while (true); + return (entry_ != nullptr); +} + +Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); } + +std::ostream &operator<<(std::ostream &os, const Path &s) { + os << s.path_; + return os; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/path.h b/mindspore/ccsrc/minddata/dataset/util/path.h new file mode 100644 index 0000000000..8bc07ca8f3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/path.h @@ -0,0 +1,114 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_PATH_H_ +#define DATASET_UTIL_PATH_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class Path { + public: + class DirIterator { + public: + static std::shared_ptr OpenDirectory(Path *f); + + ~DirIterator(); + + bool hasNext(); + + Path next(); + + private: + explicit DirIterator(Path *f); + + Path *dir_; + DIR *dp_; + struct dirent *entry_; + }; + + explicit Path(const std::string &); + + explicit Path(const char *); + + ~Path() = default; + + Path(const Path &); + + Path &operator=(const Path &); + + Path(Path &&) noexcept; + + Path &operator=(Path &&) noexcept; + + std::string toString() const { return path_; } + + Path operator+(const Path &); + + Path operator+(const std::string &); + + Path operator+(const char *); + + Path &operator+=(const Path &rhs); + + Path &operator+=(const std::string &); + + Path &operator+=(const char *); + + Path operator/(const Path &); + + Path operator/(const std::string &); + + Path operator/(const char *); + + bool Exists(); + + bool IsDirectory(); + + Status CreateDirectory(); + + Status CreateDirectories(); + + std::string Extension() const; + + std::string ParentPath(); + + Status Remove(); + + Status CreateFile(int *fd); + + Status OpenFile(int *fd, bool create = false); + + Status CloseFile(int fd) const; + + Status TruncateFile(int fd) const; + + std::string Basename(); + + friend std::ostream &operator<<(std::ostream &os, const Path &s); + + private: + static char separator_; + std::string path_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_PATH_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/queue.h b/mindspore/ccsrc/minddata/dataset/util/queue.h new file mode 100644 index 0000000000..7a0a987499 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/queue.h @@ -0,0 +1,256 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_QUEUE_H_ +#define DATASET_UTIL_QUEUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/cond_var.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +template +struct is_shared_ptr : public std::false_type {}; + +template +struct is_shared_ptr> : public std::true_type {}; + +template +struct is_unique_ptr : public std::false_type {}; + +template +struct is_unique_ptr> : public std::true_type {}; + +// A simple thread safe queue using a fixed size array +template +class Queue { + public: + using value_type = T; + using pointer = T *; + using const_pointer = const T *; + using reference = T &; + using const_reference = const T &; + + void Init() { + if (sz_ > 0) { + // We allocate a block of memory and then call the default constructor for each slot. Maybe simpler to call + // new[] but we want to control where the memory is allocated from. + arr_ = alloc_.allocate(sz_); + for (uint64_t i = 0; i < sz_; i++) { + std::allocator_traits>::construct(alloc_, &(arr_[i])); + } + } + } + + explicit Queue(int sz) + : sz_(sz), + arr_(nullptr), + head_(0), + tail_(0), + my_name_(Services::GetUniqueID()), + alloc_(Services::GetInstance().GetServiceMemPool()) { + Init(); + MS_LOG(DEBUG) << "Create Q with uuid " << my_name_ << " of size " << sz_ << "."; + } + + virtual ~Queue() { + ResetQue(); + if (arr_) { + // Simply free the pointer. Since there is nothing in the queue. We don't want to invoke the destructor + // of T in each slot. + alloc_.deallocate(arr_); + arr_ = nullptr; + } + } + + int size() const { + int v = tail_ - head_; + return (v >= 0) ? v : 0; + } + + int capacity() const { return sz_; } + + bool empty() const { return head_ == tail_; } + + void Reset() { ResetQue(); } + + // Producer + Status Add(const_reference ele) noexcept { + std::unique_lock _lock(mux_); + // Block when full + Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); + if (rc.IsOk()) { + uint32_t k = tail_++ % sz_; + arr_[k] = ele; + empty_cv_.NotifyAll(); + _lock.unlock(); + } else { + empty_cv_.Interrupt(); + } + return rc; + } + + Status Add(T &&ele) noexcept { + std::unique_lock _lock(mux_); + // Block when full + Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); + if (rc.IsOk()) { + uint32_t k = tail_++ % sz_; + arr_[k] = std::forward(ele); + empty_cv_.NotifyAll(); + _lock.unlock(); + } else { + empty_cv_.Interrupt(); + } + return rc; + } + + template + Status EmplaceBack(Ts &&... args) noexcept { + std::unique_lock _lock(mux_); + // Block when full + Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); + if (rc.IsOk()) { + uint32_t k = tail_++ % sz_; + new (&(arr_[k])) T(std::forward(args)...); + empty_cv_.NotifyAll(); + _lock.unlock(); + } else { + empty_cv_.Interrupt(); + } + return rc; + } + + // Consumer + Status PopFront(pointer p) { + std::unique_lock _lock(mux_); + // Block when empty + Status rc = empty_cv_.Wait(&_lock, [this]() -> bool { return !empty(); }); + if (rc.IsOk()) { + uint32_t k = head_++ % sz_; + *p = std::move(arr_[k]); + if (std::is_destructible::value) { + // std::move above only changes arr_[k] from rvalue to lvalue. + // The real implementation of move constructor depends on T. + // It may be compiler generated or user defined. But either case + // the result of arr_[k] is still a valid object of type T, and + // we will not keep any extra copy in the queue. + arr_[k].~T(); + // For gcc 9, an extra fix is needed here to clear the memory content + // of arr_[k] because this slot can be reused by another Add which can + // do another std::move. We have seen SEGV here in this case. + std::allocator_traits>::construct(alloc_, &(arr_[k])); + } + full_cv_.NotifyAll(); + _lock.unlock(); + } else { + full_cv_.Interrupt(); + } + return rc; + } + + void ResetQue() noexcept { + std::unique_lock _lock(mux_); + // If there are elements in the queue, invoke its destructor one by one. + if (!empty() && std::is_destructible::value) { + for (uint64_t i = head_; i < tail_; i++) { + uint32_t k = i % sz_; + arr_[k].~T(); + } + } + for (uint64_t i = 0; i < sz_; i++) { + std::allocator_traits>::construct(alloc_, &(arr_[i])); + } + empty_cv_.ResetIntrpState(); + full_cv_.ResetIntrpState(); + head_ = 0; + tail_ = 0; + } + + Status Register(TaskGroup *vg) { + Status rc1 = empty_cv_.Register(vg->GetIntrpService()); + Status rc2 = full_cv_.Register(vg->GetIntrpService()); + if (rc1.IsOk()) { + return rc2; + } else { + return rc1; + } + } + + private: + uint64_t sz_; + pointer arr_; + uint64_t head_; + uint64_t tail_; + std::string my_name_; + std::mutex mux_; + CondVar empty_cv_; + CondVar full_cv_; + Allocator alloc_; +}; + +// A container of queues with [] operator accessors. Basically this is a wrapper over of a vector of queues +// to help abstract/simplify code that is maintaining multiple queues. +template +class QueueList { + public: + QueueList() {} + + void Init(int num_queues, int capacity) { + queue_list_.reserve(num_queues); + for (int i = 0; i < num_queues; i++) { + queue_list_.emplace_back(std::make_unique>(capacity)); + } + } + + Status Register(TaskGroup *vg) { + if (vg == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Null task group during QueueList registration."); + } + for (int i = 0; i < queue_list_.size(); ++i) { + RETURN_IF_NOT_OK(queue_list_[i]->Register(vg)); + } + return Status::OK(); + } + + int size() const { return queue_list_.size(); } + + std::unique_ptr> &operator[](const int index) { return queue_list_[index]; } + + const std::unique_ptr> &operator[](const int index) const { return queue_list_[index]; } + + ~QueueList() = default; + + private: + // Queue contains non-copyable objects, so it cannot be added to a vector due to the vector + // requirement that objects must have copy semantics. To resolve this, we use a vector of unique + // pointers. This allows us to provide dynamic creation of queues in a container. + std::vector>> queue_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_QUEUE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/random.h b/mindspore/ccsrc/minddata/dataset/util/random.h new file mode 100644 index 0000000000..d2658f67ec --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/random.h @@ -0,0 +1,74 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_RANDOM_H_ +#define DATASET_UTIL_RANDOM_H_ + +#if defined(_WIN32) || defined(_WIN64) +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +inline std::mt19937 GetRandomDevice() { +#if defined(_WIN32) || defined(_WIN64) + unsigned int number; + rand_s(&number); + std::mt19937 random_device{static_cast(number)}; +#else + int i = 0; + while (i < 5) { + try { + std::mt19937 random_device{std::random_device("/dev/urandom")()}; + return random_device; + } catch (const std::exception &e) { + MS_LOG(WARNING) << "Get std::random_device failed, retry: " << i << ", error: " << e.what(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + i++; + } + } + std::mt19937 random_device{std::random_device("/dev/urandom")()}; +#endif + return random_device; +} + +inline uint32_t GetNewSeed() { + std::mt19937 random_device = GetRandomDevice(); + std::uniform_int_distribution distribution(0, std::numeric_limits::max()); + return distribution(random_device); +} + +inline uint32_t GetSeed() { + uint32_t seed = GlobalContext::config_manager()->seed(); + if (seed == std::mt19937::default_seed) { + seed = GetNewSeed(); + } + return seed; +} + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_RANDOM_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.cc b/mindspore/ccsrc/minddata/dataset/util/semaphore.cc new file mode 100644 index 0000000000..5dadd98f3c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/semaphore.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +Status Semaphore::P() { + std::unique_lock lck(mutex_); + RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; })); + --value_; + return Status::OK(); +} +void Semaphore::V() { + std::unique_lock lck(mutex_); + ++value_; + wait_cond_.NotifyOne(); +} +int Semaphore::Peek() { + std::unique_lock lck(mutex_); + return value_; +} +Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } +Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } +void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.h b/mindspore/ccsrc/minddata/dataset/util/semaphore.h new file mode 100644 index 0000000000..d07398acb1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SEMAPHORE_H_ +#define DATASET_UTIL_SEMAPHORE_H_ + +#include "minddata/dataset/util/cond_var.h" + +namespace mindspore { +namespace dataset { +class TaskGroup; + +/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be +/// blocked if the count is 0 (zero). V increments the internal count and wake up one of the waiters. +class Semaphore { + public: + /// \brief Constructor + /// \param init Initial value of the internal counter. + explicit Semaphore(int init) : value_(init) {} + + virtual ~Semaphore() {} + /// \brief Decrement the internal counter. Will be blocked if the value is 0. + /// \return Error code. Can get interrupt. + Status P(); + /// \brief Increment the internal counter. Wakeup on of the watiers if any. + void V(); + /// \brief Peek the internal value + /// \return The internal value + int Peek(); + Status Register(TaskGroup *vg); + Status Deregister(); + void ResetIntrpState(); + + private: + int value_; + + std::mutex mutex_; + CondVar wait_cond_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SEMAPHORE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/service.cc b/mindspore/ccsrc/minddata/dataset/util/service.cc new file mode 100644 index 0000000000..19d60ab47a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/service.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/service.h" +#include + +namespace mindspore { +namespace dataset { +Status Service::ServiceStart() { + do { + UniqueLock lck(&state_lock_); + // No-op if it is already up or some other thread is + // in the process of bring it up. + if (state_ == STATE::kRunning || state_ == STATE::kStartInProg) { + return Status::OK(); + } + // If a stop is in progress, we line up after it + // is done. + if (state_ == STATE::kStopInProg) { + std::this_thread::yield(); + } else { + state_ = STATE::kStartInProg; + // At this point, we will let go of the lock. This allow others to proceed. + lck.Unlock(); + RETURN_IF_NOT_OK(DoServiceStart()); + // Lock again to change state. + lck.Lock(); + state_ = STATE::kRunning; + return Status::OK(); + } + } while (true); +} + +Status Service::ServiceStop() noexcept { + do { + UniqueLock lck(&state_lock_); + // No-op if it is already stopped or some other thread is + // in the process of shutting it down + if (state_ == STATE::kStopped || state_ == STATE::kStopInProg) { + return Status::OK(); + } + // If a start is in progress, we line up after it + // is done. + if (state_ == STATE::kStartInProg) { + std::this_thread::yield(); + } else { + state_ = STATE::kStopInProg; + // At this point, we will let go of the lock. This allows others to proceed. + lck.Unlock(); + RETURN_IF_NOT_OK(DoServiceStop()); + // Lock again to change state. + lck.Lock(); + state_ = STATE::kStopped; + return Status::OK(); + } + } while (true); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/service.h b/mindspore/ccsrc/minddata/dataset/util/service.h new file mode 100644 index 0000000000..2b9c7197fe --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/service.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SERVICE_H_ +#define DATASET_UTIL_SERVICE_H_ + +#include +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class Service { + public: + enum class STATE : int { kStartInProg = 1, kRunning, kStopInProg, kStopped }; + + Service() : state_(STATE::kStopped) {} + + Service(const Service &) = delete; + + Service &operator=(const Service &) = delete; + + virtual ~Service() {} + + STATE ServiceState() const { return state_; } + + virtual Status DoServiceStart() = 0; + + virtual Status DoServiceStop() = 0; + + Status ServiceStart(); + + Status ServiceStop() noexcept; + + protected: + STATE state_; + RWLock state_lock_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/services.cc b/mindspore/ccsrc/minddata/dataset/util/services.cc new file mode 100644 index 0000000000..547773e0f1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/services.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/services.h" + +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#else +#include +#endif +#include +#include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +std::unique_ptr Services::instance_ = nullptr; +std::once_flag Services::init_instance_flag_; + +#if !defined(_WIN32) && !defined(_WIN64) +std::string Services::GetUserName() { + char user[LOGIN_NAME_MAX]; + (void)getlogin_r(user, sizeof(user)); + return std::string(user); +} + +std::string Services::GetHostName() { + char host[LOGIN_NAME_MAX]; + (void)gethostname(host, sizeof(host)); + return std::string(host); +} + +int Services::GetLWP() { return syscall(SYS_gettid); } +#endif + +std::string Services::GetUniqueID() { + const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789"; + std::mt19937 gen = GetRandomDevice(); + std::uniform_int_distribution dist(0, kStr.size() - 1); + char buffer[UNIQUEID_LEN]; + for (int i = 0; i < UNIQUEID_LEN; i++) { + buffer[i] = kStr[dist(gen)]; + } + return std::string(buffer, UNIQUEID_LEN); +} + +TaskManager &Services::getTaskMgrInstance() { + Services &sm = GetInstance(); + return *(static_cast(sm.sa_[kSlotTaskMgr_])); +} + +CacheServer &Services::getCacheServer() { + Services &sm = GetInstance(); + return *(static_cast(sm.sa_[kSlotCacheMgr_])); +} + +Status Services::CreateAllInstances() { + // In order, TaskMgr, BufferMgr + Status rc; + sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager(); + RETURN_IF_NOT_OK(rc); + rc = sa_[kSlotTaskMgr_]->ServiceStart(); + RETURN_IF_NOT_OK(rc); + // TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers + sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); + RETURN_IF_NOT_OK(rc); + rc = sa_[kSlotCacheMgr_]->ServiceStart(); + return rc; +} + +Services::Services() : pool_(nullptr), sa_{nullptr} { + Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M + if (rc.IsError()) { + std::terminate(); + } +} + +Services::~Services() noexcept { + try { + // In reverse order + CacheServer *cs = static_cast(sa_[kSlotCacheMgr_]); + if (cs != nullptr) { + (void)cs->ServiceStop(); + cs->~CacheServer(); + pool_->Deallocate(cs); + } + TaskManager *tm = static_cast(sa_[kSlotTaskMgr_]); + if (tm != nullptr) { + (void)tm->ServiceStop(); + tm->~TaskManager(); + pool_->Deallocate(tm); + } + } catch (const std::exception &e) { + // Do nothing. + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/services.h b/mindspore/ccsrc/minddata/dataset/util/services.h new file mode 100644 index 0000000000..c7adea0b6e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/services.h @@ -0,0 +1,104 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SERVICES_H_ +#define DATASET_UTIL_SERVICES_H_ + +#include +#include +#include +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/service.h" + +#define UNIQUEID_LEN 36 +namespace mindspore { +namespace dataset { +class TaskManager; +class CacheServer; +class Services { + public: + static Status CreateInstance() { + std::call_once(init_instance_flag_, [&]() -> Status { + instance_.reset(new Services()); + return (instance_->CreateAllInstances()); + }); + + if (instance_ == nullptr) { + instance_.reset(new Services()); + return (instance_->CreateAllInstances()); + } + + return Status::OK(); + } + + static Services &GetInstance() { + if (instance_ == nullptr) { + if (!CreateInstance()) { + std::terminate(); + } + } + return *instance_; + } + + Services(const Services &) = delete; + + Services &operator=(const Services &) = delete; + + ~Services() noexcept; + + static TaskManager &getTaskMgrInstance(); + + static CacheServer &getCacheServer(); + + std::shared_ptr GetServiceMemPool() { return pool_; } + +#if !defined(_WIN32) && !defined(_WIN64) + static std::string GetUserName(); + + static std::string GetHostName(); + + static int GetLWP(); +#endif + + static std::string GetUniqueID(); + + template + static Allocator GetAllocator() { + return Allocator(Services::GetInstance().GetServiceMemPool()); + } + + private: + static std::once_flag init_instance_flag_; + static std::unique_ptr instance_; + // A small pool used for small objects that last until the + // Services Manager shuts down. Used by all sub-services. + std::shared_ptr pool_; + // We use pointers here instead of unique_ptr because we + // want to have ultimate control on the order of + // construction and destruction. + static constexpr int kSlotTaskMgr_ = 0; + static constexpr int kSlotCacheMgr_ = 1; + static constexpr int kNumServices_ = 2; + Service *sa_[kNumServices_]; + + Services(); + + Status CreateAllInstances(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_SERVICES_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/sig_handler.cc b/mindspore/ccsrc/minddata/dataset/util/sig_handler.cc new file mode 100644 index 0000000000..eed3b4ee4d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/sig_handler.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/sig_handler.h" +#include +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#endif +#include +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +// Register the custom signal handlers +#if !defined(_WIN32) && !defined(_WIN64) +void RegisterHandlers() { + struct sigaction new_int_action; + + // For the interrupt handler, we do not use SA_RESETHAND so this handler remains in play + // permanently, do not use the OS default handler for it. + new_int_action.sa_sigaction = &IntHandler; + (void)sigemptyset(&new_int_action.sa_mask); + new_int_action.sa_flags = SA_RESTART | SA_SIGINFO; + (void)sigaction(SIGINT, &new_int_action, nullptr); +} + +extern void IntHandler(int sig_num, // The signal that was raised + siginfo_t *sig_info, // The siginfo structure. + void *context) { // context info + // Wake up the watchdog which is designed as async-signal-safe. + TaskManager::WakeUpWatchDog(); +} +#endif +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/sig_handler.h b/mindspore/ccsrc/minddata/dataset/util/sig_handler.h similarity index 100% rename from mindspore/ccsrc/dataset/util/sig_handler.h rename to mindspore/ccsrc/minddata/dataset/util/sig_handler.h diff --git a/mindspore/ccsrc/minddata/dataset/util/slice.cc b/mindspore/ccsrc/minddata/dataset/util/slice.cc new file mode 100644 index 0000000000..beff2b3dd2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/slice.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/util/slice.h" + +namespace mindspore { +namespace dataset { +WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) { + mutable_data_ = static_cast(src.mutable_data_) + offset; +} +WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset) + : WritableSlice(src, offset, src.GetSize() - offset) {} +Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) { + RETURN_UNEXPECTED_IF_NULL(dest); + RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer()); + if (dest->GetSize() <= 0) { + RETURN_STATUS_UNEXPECTED("Destination length is non-positive"); + } + auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize()); + if (err) { + RETURN_STATUS_UNEXPECTED(std::to_string(err)); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/slice.h b/mindspore/ccsrc/minddata/dataset/util/slice.h new file mode 100644 index 0000000000..1caee0f816 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/slice.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SLICE_H_ +#define DATASET_UTIL_SLICE_H_ + +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/status.h" +namespace mindspore { +namespace dataset { +/// \brief A ReadableSlice wraps a const pointer in memory and its size. +/// \see WritableSlice for a non-const version +/// +class ReadableSlice { + public: + ReadableSlice() : ptr_(nullptr), sz_(0) {} + ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {} + + /// \brief Destructor + ~ReadableSlice() = default; + + ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) { + ptr_ = static_cast(src.GetPointer()) + offset; + sz_ = len; + } + ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {} + ReadableSlice(const ReadableSlice &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + } + ReadableSlice &operator=(const ReadableSlice &lhs) { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + } + return *this; + } + ReadableSlice(ReadableSlice &&lhs) noexcept { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + lhs.ptr_ = nullptr; + lhs.sz_ = 0; + } + } + ReadableSlice &operator=(ReadableSlice &&lhs) noexcept { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + lhs.ptr_ = nullptr; + lhs.sz_ = 0; + } + return *this; + } + /// \brief Getter function + /// \return Const version of the pointer + const void *GetPointer() const { return ptr_; } + /// \brief Getter function + /// \return Size of the slice + size_t GetSize() const { return sz_; } + bool empty() const { return ptr_ == nullptr; } + + private: + const void *ptr_; + size_t sz_; +}; +/// \brief A WritableSlice inherits from ReadableSlice to allow +/// one to write to the address pointed to by the pointer. +/// +class WritableSlice : public ReadableSlice { + public: + friend class StorageContainer; + /// \brief Default constructor + WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} + /// \brief This form of a constructor takes a pointer and its size. + WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {} + WritableSlice(const WritableSlice &src, off64_t offset, size_t len); + WritableSlice(const WritableSlice &src, off64_t offset); + WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; } + /// \brief Destructor + ~WritableSlice() = default; + WritableSlice &operator=(const WritableSlice &lhs) { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + ReadableSlice::operator=(lhs); + } + return *this; + } + WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + lhs.mutable_data_ = nullptr; + } + } + WritableSlice &operator=(WritableSlice &&lhs) noexcept { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + lhs.mutable_data_ = nullptr; + ReadableSlice::operator=(std::move(lhs)); + } + return *this; + } + /// \brief Copy the content from one slice onto another. + static Status Copy(WritableSlice *dest, const ReadableSlice &src); + + private: + void *mutable_data_; + void *GetMutablePointer() { return mutable_data_; } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SLICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/status.cc b/mindspore/ccsrc/minddata/dataset/util/status.cc new file mode 100644 index 0000000000..3fc498b701 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/status.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/status.h" +#include +#include "common/utils.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +std::string CodeAsString(const StatusCode c) { + const char *s = nullptr; + if (c == StatusCode::kOK) { + // Optimize the most frequent case + return std::string("OK"); + } else { + switch (c) { + case StatusCode::kOutOfMemory: + s = "Out of memory"; + break; + case StatusCode::kInterrupted: + s = "Interrupted system call"; + break; + case StatusCode::kShapeMisMatch: + s = "Shape is incorrect."; + break; + case StatusCode::kNoSpace: + s = "No space left on device"; + break; + case StatusCode::kPyFuncException: + s = "Exception thrown from PyFunc"; + break; + case StatusCode::kDuplicateKey: + s = "Duplicate key"; + break; + case StatusCode::kProfilingError: + s = "Error encountered while profiling"; + break; + case StatusCode::kUnexpectedError: + default: + s = "Unexpected error"; + break; + } + } + return std::string(s); +} + +Status::Status(StatusCode c) noexcept : code_(c), err_msg_(std::move(CodeAsString(c))) {} + +Status::Status() noexcept : code_(StatusCode::kOK), err_msg_("") {} + +Status::~Status() noexcept {} + +Status::Status(const Status &s) : code_(s.code_), err_msg_(s.err_msg_) {} + +Status &Status::operator=(const Status &s) { + if (this == &s) { + return *this; + } + code_ = s.code_; + err_msg_ = s.err_msg_; + return *this; +} + +Status::Status(Status &&s) noexcept { + code_ = s.code_; + s.code_ = StatusCode::kOK; + err_msg_ = std::move(s.err_msg_); +} + +Status &Status::operator=(Status &&s) noexcept { + if (this == &s) { + return *this; + } + code_ = s.code_; + s.code_ = StatusCode::kOK; + err_msg_ = std::move(s.err_msg_); + return *this; +} + +Status::Status(const StatusCode code, const std::string &msg) : code_(code), err_msg_(msg) {} + +Status::Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra) { + code_ = code; + std::ostringstream ss; + ss << "Thread ID " << this_thread::get_id() << " " << CodeAsString(code) << ". "; + if (!extra.empty()) { + ss << extra; + } + ss << "\n"; + ss << "Line of code : " << line_of_code << "\n"; + if (file_name != nullptr) { + ss << "File : " << file_name << "\n"; + } + err_msg_ = ss.str(); + MS_LOG(INFO) << err_msg_; +} + +std::ostream &operator<<(std::ostream &os, const Status &s) { + os << s.ToString(); + return os; +} + +std::string Status::ToString() const { return err_msg_; } + +StatusCode Status::get_code() const { return code_; } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h similarity index 100% rename from mindspore/ccsrc/dataset/util/status.h rename to mindspore/ccsrc/minddata/dataset/util/status.h diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.cc b/mindspore/ccsrc/minddata/dataset/util/storage_container.cc new file mode 100644 index 0000000000..506495227d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/storage_container.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/storage_container.h" + +#include +#include +#include +#include +#include "common/utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +Status StorageContainer::Create() { + RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_)); + RETURN_IF_NOT_OK(cont_.CreateFile(&fd_)); + is_open_ = true; + MS_LOG(INFO) << "Container " << cont_ << " created"; + return Status::OK(); +} + +Status StorageContainer::Open() noexcept { + std::lock_guard lck(mutex_); + // Check again + if (!is_open_) { + RETURN_IF_NOT_OK(cont_.OpenFile(&fd_)); + is_open_ = true; + } + return Status::OK(); +} + +Status StorageContainer::Close() noexcept { + if (is_open_) { + std::lock_guard lck(mutex_); + // Check again + if (is_open_) { + RETURN_IF_NOT_OK(cont_.CloseFile(fd_)); + is_open_ = false; + fd_ = -1; + } + } + return Status::OK(); +} + +Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept { + MS_ASSERT(is_open_); + RETURN_UNEXPECTED_IF_NULL(dest); + auto sz = dest->GetSize(); +#if defined(_WIN32) || defined(_WIN64) + // Doesn't seem there is any pread64 on mingw. + // So we will do a seek and then a read under + // a protection of mutex. + std::lock_guard lck(mutex_); + auto seek_err = lseek(fd_, offset, SEEK_SET); + if (seek_err < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto r_sz = read(fd_, dest->GetMutablePointer(), sz); +#else + auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset); +#endif + if (r_sz != sz) { + errno_t err = (r_sz == 0) ? EOF : errno; + RETURN_STATUS_UNEXPECTED(strerror(err)); + } + return Status::OK(); +} + +Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept { + MS_ASSERT(is_open_); + auto sz = dest.GetSize(); +#if defined(_WIN32) || defined(_WIN64) + // Doesn't seem there is any pwrite64 on mingw. + // So we will do a seek and then a read under + // a protection of mutex. + std::lock_guard lck(mutex_); + auto seek_err = lseek(fd_, offset, SEEK_SET); + if (seek_err < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto r_sz = write(fd_, dest.GetPointer(), sz); +#else + auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset); +#endif + if (r_sz != sz) { + errno_t err = (r_sz == 0) ? EOF : errno; + RETURN_STATUS_UNEXPECTED(strerror(err)); + } + return Status::OK(); +} + +Status StorageContainer::Insert(const std::vector &buf, off64_t *offset) noexcept { + size_t sz = 0; + for (auto &v : buf) { + sz += v.GetSize(); + } + if (sz == 0) { + RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); + } + if (sz > bs_->GetMaxSize()) { + RETURN_STATUS_UNEXPECTED("Request size too big"); + } + BSpaceDescriptor bspd{0}; + addr_t addr = 0; + RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr)); + *offset = static_cast(addr); + // We will do piecewise copy of the data to disk. + for (auto &v : buf) { + RETURN_IF_NOT_OK(Write(v, addr)); + addr += v.GetSize(); + } + return Status::OK(); +} + +Status StorageContainer::Truncate() const noexcept { + if (is_open_) { + RETURN_IF_NOT_OK(cont_.TruncateFile(fd_)); + MS_LOG(INFO) << "Container " << cont_ << " truncated"; + } + return Status::OK(); +} + +StorageContainer::~StorageContainer() noexcept { + (void)Truncate(); + (void)Close(); +} + +std::ostream &operator<<(std::ostream &os, const StorageContainer &s) { + os << "File path : " << s.cont_ << "\n" << *(s.bs_.get()); + return os; +} + +Status StorageContainer::CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path) { + Status rc; + auto sc = new (std::nothrow) StorageContainer(path); + if (sc == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = sc->Create(); + if (rc.IsOk()) { + (*out_sc).reset(sc); + } else { + delete sc; + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.h b/mindspore/ccsrc/minddata/dataset/util/storage_container.h new file mode 100644 index 0000000000..a304012b60 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/storage_container.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_ +#define DATASET_UTIL_STORAGE_CONTAINER_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/util/system_pool.h" +#include "minddata/dataset/util/buddy.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class StorageManager; + +class StorageContainer { + public: + friend class StorageManager; + + ~StorageContainer() noexcept; + + StorageContainer(const StorageContainer &) = delete; + + StorageContainer &operator=(const StorageContainer &) = delete; + + friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s); + + Status Open() noexcept; + + Status Close() noexcept; + + Status Insert(const std::vector &buf, off64_t *offset) noexcept; + + Status Write(const ReadableSlice &dest, off64_t offset) const noexcept; + + Status Read(WritableSlice *dest, off64_t offset) const noexcept; + + Status Truncate() const noexcept; + + bool IsOpen() const { return is_open_; } + + static Status CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path); + + private: + mutable std::mutex mutex_; + Path cont_; + int fd_; + bool is_open_; + std::unique_ptr bs_; + + // Use the default value of BuddySpace + // which can map upto 4G of space. + explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {} + + Status Create(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_STORAGE_CONTAINER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc b/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc new file mode 100644 index 0000000000..2f85d00a45 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc @@ -0,0 +1,166 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/storage_manager.h" + +#include +#include +#include +#include +#include "common/utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/services.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) { + std::ostringstream oss; + oss << prefix << std::setfill('0') << std::setw(5) << file_id; + return oss.str(); +} + +std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) { + std::string base_name = GetBaseName(prefix, file_id); + return (base_name + "." + suffix); +} + +Status StorageManager::AddOneContainer() { + const std::string kPrefix = "IMG"; + const std::string kSuffix = "LB"; + Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix); + std::shared_ptr sc; + RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString())); + containers_.push_back(sc); + file_id_++; + return Status::OK(); +} + +Status StorageManager::DoServiceStart() { + containers_.reserve(1000); + if (root_.IsDirectory()) { + RETURN_IF_NOT_OK(AddOneContainer()); + } else { + RETURN_STATUS_UNEXPECTED("Not a directory"); + } + return Status::OK(); +} + +Status StorageManager::Write(key_type *key, const std::vector &buf) { + RETURN_UNEXPECTED_IF_NULL(key); + size_t sz = 0; + for (auto &v : buf) { + sz += v.GetSize(); + } + if (sz == 0) { + RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); + } + std::shared_ptr cont; + key_type out_key; + value_type out_value; + bool create_new_container = false; + do { + SharedLock lock_s(&rw_lock_); + size_t num_containers = containers_.size(); + if (create_new_container) { + // Upgrade to exclusvie lock. + lock_s.Upgrade(); + create_new_container = false; + // Check again if someone has already added a + // new container after we got the x lock + if (containers_.size() == num_containers) { + RETURN_IF_NOT_OK(AddOneContainer()); + } + // Refresh how many containers there are. + num_containers = containers_.size(); + // Downgrade back to shared lock + lock_s.Downgrade(); + } + if (num_containers == 0) { + RETURN_STATUS_UNEXPECTED("num_containers is zero"); + } + // Go to the last container to insert. + cont = containers_.at(num_containers - 1); + off64_t offset; + Status rc = cont->Insert(buf, &offset); + if (rc.IsNoSpace()) { + create_new_container = true; + } else if (rc.IsOk()) { + out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); + RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); + *key = out_key; + break; + } else { + return rc; + } + } while (true); + return Status::OK(); +} + +Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const { + RETURN_UNEXPECTED_IF_NULL(dest); + auto r = index_.Search(key); + if (r.second) { + auto &it = r.first; + value_type v = *it; + int container_inx = v.first; + off_t offset = v.second.first; + size_t sz = v.second.second; + if (dest->GetSize() < sz) { + std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) + + " but length = " + std::to_string(dest->GetSize()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + if (bytesRead != nullptr) { + *bytesRead = sz; + } + auto cont = containers_.at(container_inx); + RETURN_IF_NOT_OK(cont->Read(dest, offset)); + } else { + RETURN_STATUS_UNEXPECTED("Key not found"); + } + return Status::OK(); +} + +Status StorageManager::DoServiceStop() noexcept { + Status rc; + Status rc1; + for (auto const &p : containers_) { + // The destructor of StorageContainer is not called automatically until the use + // count drops to 0. But it is not always the case. We will do it ourselves. + rc = p.get()->Truncate(); + if (rc.IsError()) { + rc1 = rc; + } + } + containers_.clear(); + file_id_ = 0; + return rc1; +} + +StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {} + +StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); } + +std::ostream &operator<<(std::ostream &os, const StorageManager &s) { + os << "Dumping all containers ..." + << "\n"; + for (auto const &p : s.containers_) { + os << *(p.get()); + } + return os; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.h b/mindspore/ccsrc/minddata/dataset/util/storage_manager.h new file mode 100644 index 0000000000..e79e7c6e63 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/storage_manager.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_STORAGE_MANAGER_H_ +#define DATASET_UTIL_STORAGE_MANAGER_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/storage_container.h" + +using ListOfContainers = std::vector>; +namespace mindspore { +namespace dataset { +class StorageManager : public Service { + public: + using storage_index = AutoIndexObj>>; + using key_type = storage_index::key_type; + using value_type = storage_index::value_type; + + explicit StorageManager(const Path &); + + ~StorageManager() override; + + StorageManager(const StorageManager &) = delete; + + StorageManager &operator=(const StorageManager &) = delete; + + Status Write(key_type *out_key, const std::vector &buf); + + Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const; + + Status DoServiceStart() override; + + Status DoServiceStop() noexcept override; + + friend std::ostream &operator<<(std::ostream &os, const StorageManager &s); + + private: + Path root_; + ListOfContainers containers_; + int file_id_; + RWLock rw_lock_; + storage_index index_; + + std::string GetBaseName(const std::string &prefix, int32_t file_id); + + std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix); + + Status AddOneContainer(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_STORAGE_MANAGER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/system_pool.h b/mindspore/ccsrc/minddata/dataset/util/system_pool.h new file mode 100644 index 0000000000..3a7e61d16b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/system_pool.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SYSTEM_POOL_H_ +#define DATASET_UTIL_SYSTEM_POOL_H_ + +#include +#include +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/memory_pool.h" + +namespace mindspore { +namespace dataset { +// This class demonstrate how to implement a simple MemoryPool +// for minddata/dataset using malloc/free/realloc. We need to +// implement 4 virtual functions. Other MemoryPool +// implementation, e.g., are BuddyArena and CircularPool. All +// these MemoryPool can be used together with Allocator.h for +// C++ STL containers. +class SystemPool : public MemoryPool { + public: + ~SystemPool() override {} + + Status Allocate(size_t n, void **pp) override { return DeMalloc(n, pp, false); } + + void Deallocate(void *p) override { free(p); } + + Status Reallocate(void **p, size_t old_sz, size_t new_sz) override { + if (old_sz >= new_sz) { + // Do nothing if we shrink. + return Status::OK(); + } else { + void *ptr = *p; + void *q = nullptr; + RETURN_IF_NOT_OK(DeMalloc(new_sz, &q, false)); + errno_t err = memcpy_s(q, new_sz, ptr, old_sz); + if (err) { + free(q); + RETURN_STATUS_UNEXPECTED(std::to_string(err)); + } + free(ptr); + *p = q; + return Status::OK(); + } + } + + uint64_t get_max_size() const override { return std::numeric_limits::max(); } + + int PercentFree() const override { return 100; } + + template + static Allocator GetAllocator() { + return Allocator(std::make_shared()); + } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_SYSTEM_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/task.cc b/mindspore/ccsrc/minddata/dataset/util/task.cc new file mode 100644 index 0000000000..39d754e806 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/task.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/task.h" +#include "common/utils.h" +#include "minddata/dataset/util/task_manager.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +thread_local Task *gMyTask = nullptr; + +void Task::operator()() { +#if !defined(_WIN32) && !defined(_WIN64) + gMyTask = this; +#endif + id_ = this_thread::get_id(); + std::stringstream ss; + ss << id_; + MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started."; + try { + // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set + // the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can + // get the thread id. + TaskGroup *vg = MyTaskGroup(); + rc_ = vg->GetIntrpService()->Register(ss.str(), this); + if (rc_.IsOk()) { + // Now we can run the given task. + rc_ = fnc_obj_(); + } + // Some error codes are ignored, e.g. interrupt. Others we just shutdown the group. + if (rc_.IsError() && !rc_.IsInterrupted()) { + ShutdownGroup(); + } + } catch (const std::bad_alloc &e) { + rc_ = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what()); + ShutdownGroup(); + } catch (const std::exception &e) { + rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); + ShutdownGroup(); + } +} + +void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine. + { + std::lock_guard lk(mux_); + caught_severe_exception_ = true; + } + TaskGroup *vg = MyTaskGroup(); + // If multiple threads hit severe errors in the same group. Keep the first one and + // discard the rest. + if (vg->rc_.IsOk()) { + std::unique_lock rcLock(vg->rc_mux_); + // Check again after we get the lock + if (vg->rc_.IsOk()) { + vg->rc_ = rc_; + rcLock.unlock(); + TaskManager::InterruptMaster(rc_); + TaskManager::InterruptGroup(*this); + } + } +} + +Status Task::GetTaskErrorIfAny() const { + std::lock_guard lk(mux_); + if (caught_severe_exception_) { + return rc_; + } else { + return Status::OK(); + } +} + +Task::Task(const std::string &myName, const std::function &f) + : my_name_(myName), + rc_(), + fnc_obj_(f), + task_group_(nullptr), + is_master_(false), + running_(false), + caught_severe_exception_(false) { + IntrpResource::ResetIntrpState(); + wp_.ResetIntrpState(); + wp_.Clear(); +} + +Status Task::Run() { + Status rc; + if (running_ == false) { + try { + thrd_ = std::async(std::launch::async, std::ref(*this)); + running_ = true; + caught_severe_exception_ = false; + } catch (const std::exception &e) { + rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); + } + } + return rc; +} + +Status Task::Join(WaitFlag blocking) { + if (running_) { + RETURN_UNEXPECTED_IF_NULL(MyTaskGroup()); + auto interrupt_svc = MyTaskGroup()->GetIntrpService(); + try { + if (blocking == WaitFlag::kBlocking) { + // If we are asked to wait, then wait + thrd_.get(); + } else if (blocking == WaitFlag::kNonBlocking) { + // There is a race condition in the global resource tracking such that a thread can miss the + // interrupt and becomes blocked on a conditional variable forever. As a result, calling + // join() will not come back. We need some timeout version of join such that if the thread + // doesn't come back in a reasonable of time, we will send the interrupt again. + while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { + // We can't tell which conditional_variable this thread is waiting on. So we may need + // to interrupt everything one more time. + MS_LOG(INFO) << "Some threads not responding. Interrupt again"; + interrupt_svc->InterruptAll(); + } + } else { + RETURN_STATUS_UNEXPECTED("Unknown WaitFlag"); + } + std::stringstream ss; + ss << get_id(); + MS_LOG(DEBUG) << MyName() << " Thread ID " << ss.str() << " Stopped."; + running_ = false; + RETURN_IF_NOT_OK(wp_.Deregister()); + RETURN_IF_NOT_OK(interrupt_svc->Deregister(ss.str())); + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + } + return Status::OK(); +} + +TaskGroup *Task::MyTaskGroup() { return task_group_; } + +void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; } + +Task::~Task() { task_group_ = nullptr; } +Status Task::OverrideInterruptRc(const Status &rc) { + if (rc.IsInterrupted() && this_thread::is_master_thread()) { + // If we are interrupted, override the return value if this is the master thread. + // Master thread is being interrupted mostly because of some thread is reporting error. + return TaskManager::GetMasterThreadRc(); + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/task.h b/mindspore/ccsrc/minddata/dataset/util/task.h new file mode 100644 index 0000000000..9309a3de7b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/task.h @@ -0,0 +1,125 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_TASK_H_ +#define DATASET_UTIL_TASK_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/util/intrp_resource.h" +#include "minddata/dataset/util/list.h" +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/wait_post.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +class TaskManager; + +class Task : public IntrpResource { + public: + friend class TaskManager; + friend class TaskGroup; + + enum class WaitFlag : int { kBlocking, kNonBlocking }; + + Task(const std::string &myName, const std::function &f); + + // Future objects are not copyable. + Task(const Task &) = delete; + + ~Task() override; + + Task &operator=(const Task &) = delete; + + // Move constructor and Assignment are not supported. + // Too many things in this class. + Task(Task &&) = delete; + + Task &operator=(Task &&) = delete; + + Status GetTaskErrorIfAny() const; + + void ChangeName(const std::string &newName) { my_name_ = newName; } + + // To execute the _fncObj + void operator()(); + + Node node; + Node group; + Node free; + + // Run the task + Status Run(); + + Status Join(WaitFlag wf = WaitFlag::kBlocking); + + bool Running() const { return running_; } + + bool CaughtSevereException() const { return caught_severe_exception_; } + + bool IsMasterThread() const { return is_master_; } + + std::thread::id get_id() { return id_; } + + std::string MyName() { return my_name_; } + + // An operator used by std::find + bool operator==(const Task &other) const { return (this == &other); } + + bool operator!=(const Task &other) const { return !(*this == other); } + + void Post() { wp_.Set(); } + + Status Wait() { return (wp_.Wait()); } + + static Status OverrideInterruptRc(const Status &rc); + + private: + mutable std::mutex mux_; + std::string my_name_; + Status rc_; + WaitPost wp_; + // Task need to provide definition for this function. It + // will be called by thread function. + std::function fnc_obj_; + // Misc fields used by TaskManager. + TaskGroup *task_group_; + std::future thrd_; + std::thread::id id_; + bool is_master_; + volatile bool running_; + volatile bool caught_severe_exception_; + + void ShutdownGroup(); + TaskGroup *MyTaskGroup(); + void set_task_group(TaskGroup *vg); +}; + +extern thread_local Task *gMyTask; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_TASK_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/task_manager.cc b/mindspore/ccsrc/minddata/dataset/util/task_manager.cc new file mode 100644 index 0000000000..fefea0b97c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/task_manager.cc @@ -0,0 +1,353 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +// This takes the same parameter as Task constructor. +Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function &f, TaskGroup *vg, + Task **task) { + // We need to block destructor coming otherwise we will deadlock. We will grab the + // stateLock in shared allowing CreateAsyncTask to run concurrently. + SharedLock stateLck(&state_lock_); + // Now double check the state + if (ServiceState() == STATE::kStopInProg || ServiceState() == STATE::kStopped) { + return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "TaskManager is shutting down"); + } + RETURN_IF_NOT_OK(GetFreeTask(my_name, f, task)); + if (vg == nullptr) { + RETURN_STATUS_UNEXPECTED("TaskGroup is null"); + } + // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set + // the TaskGroup pointer. We will do the set here before we call run(). The run() will do the registration. + (*task)->set_task_group(vg); + // Link to the master lru list. + { + UniqueLock lck(&lru_lock_); + lru_.Append(*task); + } + // Link to the group list as well before we spawn. + { + UniqueLock lck(&vg->rw_lock_); + vg->grp_list_.Append(*task); + } + // Track all the TaskGroup. Used for control-c + { + LockGuard lck(&tg_lock_); + this->grp_list_.insert(vg); + } + RETURN_IF_NOT_OK((*task)->wp_.Register(vg)); + RETURN_IF_NOT_OK((*task)->Run()); + // Wait for the thread to initialize successfully. + RETURN_IF_NOT_OK((*task)->Wait()); + return Status::OK(); +} + +Status TaskManager::join_all() { + Status rc; + Status rc2; + SharedLock lck(&lru_lock_); + for (Task &tk : lru_) { + rc = tk.Join(); + if (rc.IsError()) { + rc2 = rc; + } + } + return rc2; +} + +void TaskManager::interrupt_all() noexcept { + global_interrupt_ = 1; + LockGuard lck(&tg_lock_); + for (TaskGroup *vg : grp_list_) { + auto svc = vg->GetIntrpService(); + if (svc) { + // Stop the interrupt service. No new request is accepted. + svc->ServiceStop(); + svc->InterruptAll(); + } + } + master_->Interrupt(); +} + +Task *TaskManager::FindMe() { +#if !defined(_WIN32) && !defined(_WIN64) + return gMyTask; +#else + TaskManager &tm = TaskManager::GetInstance(); + SharedLock lock(&tm.lru_lock_); + auto id = this_thread::get_id(); + auto tk = std::find_if(tm.lru_.begin(), tm.lru_.end(), [id](const Task &tk) { return tk.id_ == id; }); + if (tk != tm.lru_.end()) { + return &(*tk); + } + // If we get here, either I am the watchdog or the master thread. + if (tm.master_->id_ == id) { + return tm.master_.get(); + } else if (tm.watchdog_ != nullptr && tm.watchdog_->id_ == id) { + return tm.watchdog_; + } + MS_LOG(ERROR) << "Task not found."; + return nullptr; +#endif +} + +TaskManager::TaskManager() try : global_interrupt_(0), + lru_(&Task::node), + free_lst_(&Task::free), + watchdog_grp_(nullptr), + watchdog_(nullptr) { + auto alloc = Services::GetAllocator(); + // Create a dummy Task for the master thread (this thread) + master_ = std::allocate_shared(alloc, "master", []() -> Status { return Status::OK(); }); + master_->id_ = this_thread::get_id(); + master_->running_ = true; + master_->is_master_ = true; +#if !defined(_WIN32) && !defined(_WIN64) + gMyTask = master_.get(); + // Initialize the semaphore for the watchdog + errno_t rc = sem_init(&sem_, 0, 0); + if (rc == -1) { + MS_LOG(ERROR) << "Unable to initialize a semaphore. Errno = " << rc << "."; + std::terminate(); + } +#endif +} catch (const std::exception &e) { + MS_LOG(ERROR) << "MindData initialization failed: " << e.what() << "."; + std::terminate(); +} + +TaskManager::~TaskManager() { + if (watchdog_) { + WakeUpWatchDog(); + watchdog_->Join(); + // watchdog_grp_ and watchdog_ pointers come from Services::GetInstance().GetServiceMemPool() which we will free it + // on shutdown. So no need to free these pointers one by one. + watchdog_grp_ = nullptr; + watchdog_ = nullptr; + } +#if !defined(_WIN32) && !defined(_WIN64) + (void)sem_destroy(&sem_); +#endif +} + +Status TaskManager::DoServiceStart() { + MS_LOG(INFO) << "Starting Task Manager."; +#if !defined(_WIN32) && !defined(_WIN64) + // Create a watchdog for control-c + std::shared_ptr mp = Services::GetInstance().GetServiceMemPool(); + // A dummy group just for the watchdog. We aren't really using it. But most code assumes a thread must + // belong to a group. + auto f = std::bind(&TaskManager::WatchDog, this); + Status rc; + watchdog_grp_ = new (&rc, mp) TaskGroup(); + RETURN_IF_NOT_OK(rc); + rc = watchdog_grp_->CreateAsyncTask("Watchdog", f, &watchdog_); + if (rc.IsError()) { + ::operator delete(watchdog_grp_, mp); + watchdog_grp_ = nullptr; + return rc; + } + grp_list_.erase(watchdog_grp_); + lru_.Remove(watchdog_); +#endif + return Status::OK(); +} + +Status TaskManager::DoServiceStop() { + WakeUpWatchDog(); + interrupt_all(); + return Status::OK(); +} + +Status TaskManager::WatchDog() { + TaskManager::FindMe()->Post(); +#if !defined(_WIN32) && !defined(_WIN64) + errno_t err = sem_wait(&sem_); + if (err == -1) { + RETURN_STATUS_UNEXPECTED("Errno = " + std::to_string(errno)); + } + // We are woken up by control-c and we are going to stop all threads that are running. + // In addition, we also want to prevent new thread from creating. This can be done + // easily by calling the parent function. + RETURN_IF_NOT_OK(ServiceStop()); +#endif + return Status::OK(); +} + +// Follow the group link and interrupt other +// Task in the same group. It is used by +// Watchdog only. +void TaskManager::InterruptGroup(Task &curTk) { + TaskGroup *vg = curTk.MyTaskGroup(); + vg->interrupt_all(); +} + +void TaskManager::InterruptMaster(const Status &rc) { + TaskManager &tm = TaskManager::GetInstance(); + std::shared_ptr master = tm.master_; + std::lock_guard lck(master->mux_); + master->Interrupt(); + if (rc.IsError() && master->rc_.IsOk()) { + master->rc_ = rc; + master->caught_severe_exception_ = true; + } +} + +Status TaskManager::GetMasterThreadRc() { + TaskManager &tm = TaskManager::GetInstance(); + std::shared_ptr master = tm.master_; + Status rc = tm.master_->GetTaskErrorIfAny(); + if (rc.IsError()) { + // Reset the state once we retrieve the value. + std::lock_guard lck(master->mux_); + master->rc_ = Status::OK(); + master->caught_severe_exception_ = false; + master->ResetIntrpState(); + } + return rc; +} + +void TaskManager::ReturnFreeTask(Task *p) noexcept { + // Take it out from lru_ if any + { + UniqueLock lck(&lru_lock_); + auto it = std::find(lru_.begin(), lru_.end(), *p); + if (it != lru_.end()) { + lru_.Remove(p); + } + } + // We need to deallocate the string resources associated with the Task class + // before we cache its memory for future use. + p->~Task(); + // Put it back into free list + { + LockGuard lck(&free_lock_); + free_lst_.Append(p); + } +} + +Status TaskManager::GetFreeTask(const std::string &my_name, const std::function &f, Task **p) { + if (p == nullptr) { + RETURN_STATUS_UNEXPECTED("p is null"); + } + Task *q = nullptr; + // First try the free list + { + LockGuard lck(&free_lock_); + if (free_lst_.count > 0) { + q = free_lst_.head; + free_lst_.Remove(q); + } + } + if (q) { + new (q) Task(my_name, f); + } else { + std::shared_ptr mp = Services::GetInstance().GetServiceMemPool(); + Status rc; + q = new (&rc, mp) Task(my_name, f); + RETURN_IF_NOT_OK(rc); + } + *p = q; + return Status::OK(); +} + +Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::function &f, Task **ppTask) { + auto pMytask = TaskManager::FindMe(); + // We need to block ~TaskGroup coming otherwise we will deadlock. We will grab the + // stateLock in shared allowing CreateAsyncTask to run concurrently. + SharedLock state_lck(&state_lock_); + // Now double check the state + if (ServiceState() != STATE::kRunning) { + return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "Taskgroup is shutting down"); + } + TaskManager &dm = TaskManager::GetInstance(); + Task *pTask = nullptr; + // If the group is already in error, early exit too. + // We can't hold the rc_mux_ throughout because the thread spawned by CreateAsyncTask may hit error which + // will try to shutdown the group and grab the rc_mux_ and we will deadlock. + { + std::unique_lock rcLock(rc_mux_); + if (rc_.IsError()) { + return pMytask->IsMasterThread() ? rc_ : Status(StatusCode::kInterrupted); + } + } + RETURN_IF_NOT_OK(dm.CreateAsyncTask(my_name, f, this, &pTask)); + if (ppTask) { + *ppTask = pTask; + } + return Status::OK(); +} + +void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); } + +Status TaskGroup::join_all(Task::WaitFlag wf) { + Status rc; + Status rc2; + SharedLock lck(&rw_lock_); + for (Task &tk : grp_list_) { + rc = tk.Join(wf); + if (rc.IsError()) { + rc2 = rc; + } + } + return rc2; +} + +Status TaskGroup::DoServiceStop() { + intrp_svc_->ServiceStop(); + interrupt_all(); + return (join_all(Task::WaitFlag::kNonBlocking)); +} + +TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) { + auto alloc = Services::GetAllocator(); + intrp_svc_ = std::allocate_shared(alloc); + (void)Service::ServiceStart(); +} + +TaskGroup::~TaskGroup() { + (void)Service::ServiceStop(); + // The TaskGroup is going out of scope, and we can return the Task list to the free list. + Task *cur = grp_list_.head; + TaskManager &tm = TaskManager::GetInstance(); + while (cur) { + Task *next = cur->group.next; + grp_list_.Remove(cur); + tm.ReturnFreeTask(cur); + cur = next; + } + { + LockGuard lck(&tm.tg_lock_); + (void)tm.grp_list_.erase(this); + } +} + +Status TaskGroup::GetTaskErrorIfAny() { + SharedLock lck(&rw_lock_); + for (Task &tk : grp_list_) { + RETURN_IF_NOT_OK(tk.GetTaskErrorIfAny()); + } + return Status::OK(); +} + +std::shared_ptr TaskGroup::GetIntrpService() { return intrp_svc_; } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/task_manager.h b/mindspore/ccsrc/minddata/dataset/util/task_manager.h new file mode 100644 index 0000000000..3030390bab --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/task_manager.h @@ -0,0 +1,181 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_TASK_MANAGER_H_ +#define DATASET_UTIL_TASK_MANAGER_H_ + +#if !defined(_WIN32) && !defined(_WIN64) +#include +#include // for sig_atomic_t +#endif +#include +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/intrp_service.h" +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/task.h" + +namespace mindspore { +namespace dataset { +namespace thread { +using id = std::thread::id; +} // namespace thread + +namespace this_thread { +inline thread::id get_id() { return std::this_thread::get_id(); } +} // namespace this_thread + +class TaskManager : public Service { + public: + friend class Services; + + friend class TaskGroup; + + ~TaskManager() override; + + TaskManager(const TaskManager &) = delete; + + TaskManager &operator=(const TaskManager &) = delete; + + static TaskManager &GetInstance() noexcept { return Services::getTaskMgrInstance(); } + + Status DoServiceStart() override; + + Status DoServiceStop() override; + + // A public global interrupt flag for signal handlers + volatile sig_atomic_t global_interrupt_; + + // API + // This takes the same parameter as Task constructor. Take a look + // of the test-thread.cc for usage. + Status CreateAsyncTask(const std::string &my_name, const std::function &f, TaskGroup *vg, Task **); + + // Same usage as boot thread group + Status join_all(); + + void interrupt_all() noexcept; + + // Locate a particular Task. + static Task *FindMe(); + + static void InterruptGroup(Task &); + + static Status GetMasterThreadRc(); + + static void InterruptMaster(const Status &rc = Status::OK()); + + static void WakeUpWatchDog() { +#if !defined(_WIN32) && !defined(_WIN64) + TaskManager &tm = TaskManager::GetInstance(); + (void)sem_post(&tm.sem_); +#endif + } + + void ReturnFreeTask(Task *p) noexcept; + + Status GetFreeTask(const std::string &my_name, const std::function &f, Task **p); + + Status WatchDog(); + + private: + RWLock lru_lock_; + SpinLock free_lock_; + SpinLock tg_lock_; + std::shared_ptr master_; + List lru_; + List free_lst_; +#if !defined(_WIN32) && !defined(_WIN64) + sem_t sem_; +#endif + TaskGroup *watchdog_grp_; + std::set grp_list_; + Task *watchdog_; + + TaskManager(); +}; + +// A group of related tasks. +class TaskGroup : public Service { + public: + friend class Task; + friend class TaskManager; + + Status CreateAsyncTask(const std::string &my_name, const std::function &f, Task **pTask = nullptr); + + void interrupt_all() noexcept; + + Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking); + + int size() const noexcept { return grp_list_.count; } + + Status DoServiceStart() override { return Status::OK(); } + + Status DoServiceStop() override; + + TaskGroup(); + + ~TaskGroup() override; + + Status GetTaskErrorIfAny(); + + std::shared_ptr GetIntrpService(); + + private: + Status rc_; + // Can't use rw_lock_ as we will lead to deadlatch. Create another mutex to serialize access to rc_. + std::mutex rc_mux_; + RWLock rw_lock_; + List grp_list_; + std::shared_ptr intrp_svc_; +}; + +namespace this_thread { +inline bool is_interrupted() { + TaskManager &tm = TaskManager::GetInstance(); + if (tm.global_interrupt_ == 1) { + return true; + } + Task *my_task = TaskManager::FindMe(); + return my_task->Interrupted(); +} + +inline bool is_master_thread() { + Task *my_task = TaskManager::FindMe(); + return my_task->IsMasterThread(); +} + +inline Status GetInterruptStatus() { + Task *my_task = TaskManager::FindMe(); + return my_task->GetInterruptStatus(); +} +} // namespace this_thread + +#define RETURN_IF_INTERRUPTED() \ + do { \ + if (mindspore::dataset::this_thread::is_interrupted()) { \ + return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \ + } \ + } while (false) + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_TASK_MANAGER_H_ diff --git a/mindspore/ccsrc/dataset/util/treap.h b/mindspore/ccsrc/minddata/dataset/util/treap.h similarity index 100% rename from mindspore/ccsrc/dataset/util/treap.h rename to mindspore/ccsrc/minddata/dataset/util/treap.h diff --git a/mindspore/ccsrc/minddata/dataset/util/wait_post.cc b/mindspore/ccsrc/minddata/dataset/util/wait_post.cc new file mode 100644 index 0000000000..944d9ca245 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/wait_post.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +WaitPost::WaitPost() : value_(0) {} + +Status WaitPost::Wait() { + std::unique_lock lck(mutex_); + return (wait_cond_.Wait(&lck, [this]() { return value_ != 0; })); +} + +void WaitPost::Set() { + std::unique_lock lck(mutex_); + value_ = 1; + wait_cond_.NotifyAll(); +} + +void WaitPost::Clear() { + std::unique_lock lck(mutex_); + value_ = 0; +} + +Status WaitPost::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } + +void WaitPost::ResetIntrpState() { wait_cond_.ResetIntrpState(); } + +Status WaitPost::Deregister() { return wait_cond_.Deregister(); } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/wait_post.h b/mindspore/ccsrc/minddata/dataset/util/wait_post.h new file mode 100644 index 0000000000..afd3bea38b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/wait_post.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_WAIT_POST_H_ +#define DATASET_UTIL_WAIT_POST_H_ + +#include +#include "minddata/dataset/util/cond_var.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class TaskGroup; + +class WaitPost { + public: + WaitPost(); + + ~WaitPost() = default; + + Status Wait(); + + void Set(); + + void Clear(); + + Status Register(TaskGroup *vg); + + Status Deregister(); + + void ResetIntrpState(); + + private: + std::mutex mutex_; + CondVar wait_cond_; + int value_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_WAIT_POST_H_ diff --git a/mindspore/ccsrc/mindrecord/CMakeLists.txt b/mindspore/ccsrc/minddata/mindrecord/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/mindrecord/CMakeLists.txt rename to mindspore/ccsrc/minddata/mindrecord/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc new file mode 100644 index 0000000000..e4d35b8305 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc @@ -0,0 +1,181 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_error.h" + +namespace mindspore { +namespace mindrecord { +std::string ErrnoToMessage(MSRStatus status) { + switch (status) { + case FAILED: + return "operator failed"; + break; + case SUCCESS: + return "operator success"; + break; + case OPEN_FILE_FAILED: + return "open file failed"; + break; + case CLOSE_FILE_FAILED: + return "close file failed"; + break; + case WRITE_METADATA_FAILED: + return "write metadata failed"; + break; + case WRITE_RAWDATA_FAILED: + return "write rawdata failed"; + break; + case GET_SCHEMA_FAILED: + return "get schema failed"; + break; + case ILLEGAL_RAWDATA: + return "illegal raw data"; + break; + case PYTHON_TO_JSON_FAILED: + return "pybind: python object to json failed"; + break; + case DIR_CREATE_FAILED: + return "directory create failed"; + break; + case OPEN_DIR_FAILED: + return "open directory failed"; + break; + case INVALID_STATISTICS: + return "invalid statistics object"; + break; + case OPEN_DATABASE_FAILED: + return "open database failed"; + break; + case CLOSE_DATABASE_FAILED: + return "close database failed"; + break; + case DATABASE_OPERATE_FAILED: + return "database operate failed"; + break; + case BUILD_SCHEMA_FAILED: + return "build schema failed"; + break; + case DIVISOR_IS_ILLEGAL: + return "divisor is illegal"; + break; + case INVALID_FILE_PATH: + return "file path is invalid"; + break; + case SECURE_FUNC_FAILED: + return "secure function failed"; + break; + case ALLOCATE_MEM_FAILED: + return "allocate memory failed"; + break; + case ILLEGAL_FIELD_NAME: + return "illegal field name"; + break; + case ILLEGAL_FIELD_TYPE: + return "illegal field type"; + break; + case SET_METADATA_FAILED: + return "set metadata failed"; + break; + case ILLEGAL_SCHEMA_DEFINITION: + return "illegal schema definition"; + break; + case ILLEGAL_COLUMN_LIST: + return "illegal column list"; + break; + case SQL_ERROR: + return "sql error"; + break; + case ILLEGAL_SHARD_COUNT: + return "illegal shard count"; + break; + case ILLEGAL_SCHEMA_COUNT: + return "illegal schema count"; + break; + case VERSION_ERROR: + return "data version is not matched"; + break; + case ADD_SCHEMA_FAILED: + return "add schema failed"; + break; + case ILLEGAL_Header_SIZE: + return "illegal header size"; + break; + case ILLEGAL_Page_SIZE: + return "illegal page size"; + break; + case ILLEGAL_SIZE_VALUE: + return "illegal size value"; + break; + case INDEX_FIELD_ERROR: + return "add index fields failed"; + break; + case GET_CANDIDATE_CATEGORYFIELDS_FAILED: + return "get candidate category fields failed"; + break; + case GET_CATEGORY_INFO_FAILED: + return "get category information failed"; + break; + case ILLEGAL_CATEGORY_ID: + return "illegal category id"; + break; + case ILLEGAL_ROWNUMBER_OF_PAGE: + return "illegal row number of page"; + break; + case ILLEGAL_SCHEMA_ID: + return "illegal schema id"; + break; + case DESERIALIZE_SCHEMA_FAILED: + return "deserialize schema failed"; + break; + case DESERIALIZE_STATISTICS_FAILED: + return "deserialize statistics failed"; + break; + case ILLEGAL_DB_FILE: + return "illegal db file"; + break; + case OVERWRITE_DB_FILE: + return "overwrite db file"; + break; + case OVERWRITE_MINDRECORD_FILE: + return "overwrite mindrecord file"; + break; + case ILLEGAL_MINDRECORD_FILE: + return "illegal mindrecord file"; + break; + case PARSE_JSON_FAILED: + return "parse json failed"; + break; + case ILLEGAL_PARAMETERS: + return "illegal parameters"; + break; + case GET_PAGE_BY_GROUP_ID_FAILED: + return "get page by group id failed"; + break; + case GET_SYSTEM_STATE_FAILED: + return "get system state failed"; + break; + case IO_FAILED: + return "io operate failed"; + break; + case MATCH_HEADER_FAILED: + return "match header failed"; + break; + default: + return "invalid error no"; + } +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc new file mode 100644 index 0000000000..d9e51efc4e --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc @@ -0,0 +1,230 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/utils.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_segment.h" +#include "minddata/mindrecord/include/shard_writer.h" +#include "nlohmann/json.hpp" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "utils/log_adapter.h" + +namespace py = pybind11; + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +void BindSchema(py::module *m) { + (void)py::class_>(*m, "Schema", py::module_local()) + .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Schema::Build) + .def("get_desc", &Schema::GetDesc) + .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) + .def("get_blob_fields", &Schema::GetBlobFields) + .def("get_schema_id", &Schema::GetSchemaID); +} + +void BindStatistics(const py::module *m) { + (void)py::class_>(*m, "Statistics", py::module_local()) + .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Statistics::Build) + .def("get_desc", &Statistics::GetDesc) + .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) + .def("get_statistics_id", &Statistics::GetStatisticsID); +} + +void BindShardHeader(const py::module *m) { + (void)py::class_>(*m, "ShardHeader", py::module_local()) + .def(py::init<>()) + .def("add_schema", &ShardHeader::AddSchema) + .def("add_statistics", &ShardHeader::AddStatistic) + .def("add_index_fields", + (MSRStatus(ShardHeader::*)(const std::vector &)) & ShardHeader::AddIndexFields) + .def("get_meta", &ShardHeader::GetSchemas) + .def("get_statistics", &ShardHeader::GetStatistics) + .def("get_fields", &ShardHeader::GetFields) + .def("get_schema_by_id", &ShardHeader::GetSchemaByID) + .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); +} + +void BindShardWriter(py::module *m) { + (void)py::class_(*m, "ShardWriter", py::module_local()) + .def(py::init<>()) + .def("open", &ShardWriter::Open) + .def("open_for_append", &ShardWriter::OpenForAppend) + .def("set_header_size", &ShardWriter::SetHeaderSize) + .def("set_page_size", &ShardWriter::SetPageSize) + .def("set_shard_header", &ShardWriter::SetShardHeader) + .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, + vector> &, bool, bool)) & + ShardWriter::WriteRawData) + .def("commit", &ShardWriter::Commit); +} + +void BindShardReader(const py::module *m) { + (void)py::class_>(*m, "ShardReader", py::module_local()) + .def(py::init<>()) + .def("open", (MSRStatus(ShardReader::*)(const std::vector &, bool, const int &, + const std::vector &, + const std::vector> &)) & + ShardReader::OpenPy) + .def("launch", &ShardReader::Launch) + .def("get_header", &ShardReader::GetShardHeader) + .def("get_blob_fields", &ShardReader::GetBlobFields) + .def("get_next", (std::vector>, pybind11::object>>(ShardReader::*)()) & + ShardReader::GetNextPy) + .def("finish", &ShardReader::Finish) + .def("close", &ShardReader::Close); +} + +void BindShardIndexGenerator(const py::module *m) { + (void)py::class_(*m, "ShardIndexGenerator", py::module_local()) + .def(py::init()) + .def("build", &ShardIndexGenerator::Build) + .def("write_to_db", &ShardIndexGenerator::WriteToDatabase); +} + +void BindShardSegment(py::module *m) { + (void)py::class_(*m, "ShardSegment", py::module_local()) + .def(py::init<>()) + .def("open", (MSRStatus(ShardSegment::*)(const std::vector &, bool, const int &, + const std::vector &, + const std::vector> &)) & + ShardSegment::OpenPy) + .def("get_category_fields", + (std::pair>(ShardSegment::*)()) & ShardSegment::GetCategoryFields) + .def("set_category_field", (MSRStatus(ShardSegment::*)(std::string)) & ShardSegment::SetCategoryField) + .def("read_category_info", (std::pair(ShardSegment::*)()) & ShardSegment::ReadCategoryInfo) + .def("read_at_page_by_id", (std::pair, pybind11::object>>>( + ShardSegment::*)(int64_t, int64_t, int64_t)) & + ShardSegment::ReadAtPageByIdPy) + .def("read_at_page_by_name", (std::pair, pybind11::object>>>( + ShardSegment::*)(std::string, int64_t, int64_t)) & + ShardSegment::ReadAtPageByNamePy) + .def("get_header", &ShardSegment::GetShardHeader) + .def("get_blob_fields", + (std::pair>(ShardSegment::*)()) & ShardSegment::GetBlobFields); +} + +void BindGlobalParams(py::module *m) { + (*m).attr("MIN_HEADER_SIZE") = kMinHeaderSize; + (*m).attr("MAX_HEADER_SIZE") = kMaxHeaderSize; + (*m).attr("MIN_PAGE_SIZE") = kMinPageSize; + (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize; + (*m).attr("MIN_SHARD_COUNT") = kMinShardCount; + (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount; + (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount; + (void)(*m).def("get_max_thread_num", &GetMaxThreadNum); +} + +PYBIND11_MODULE(_c_mindrecord, m) { + m.doc() = "pybind11 mindrecord plugin"; // optional module docstring + (void)py::enum_(m, "MSRStatus", py::module_local()) + .value("SUCCESS", SUCCESS) + .value("FAILED", FAILED) + .export_values(); + (void)py::enum_(m, "ShardType", py::module_local()).value("NLP", kNLP).value("CV", kCV).export_values(); + BindGlobalParams(&m); + BindSchema(&m); + BindStatistics(&m); + BindShardHeader(&m); + BindShardWriter(&m); + BindShardReader(&m); + BindShardIndexGenerator(&m); + BindShardSegment(&m); +} +} // namespace mindrecord +} // namespace mindspore + +namespace nlohmann { +namespace detail { +py::object FromJsonImpl(const json &j) { + if (j.is_null()) { + return py::none(); + } else if (j.is_boolean()) { + return py::bool_(j.get()); + } else if (j.is_number()) { + double number = j.get(); + if (fabs(number - std::floor(number)) < mindspore::mindrecord::kEpsilon) { + return py::int_(j.get()); + } else { + return py::float_(number); + } + } else if (j.is_string()) { + return py::str(j.get()); + } else if (j.is_array()) { + py::list obj; + for (const auto &el : j) { + (void)obj.attr("append")(FromJsonImpl(el)); + } + return std::move(obj); + } else { + py::dict obj; + for (json::const_iterator it = j.cbegin(); it != j.cend(); ++it) { + obj[py::str(it.key())] = FromJsonImpl(it.value()); + } + return std::move(obj); + } +} + +json ToJsonImpl(const py::handle &obj) { + if (obj.is_none()) { + return nullptr; + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj) || py::isinstance(obj)) { + auto out = json::array(); + for (const py::handle &value : obj) { + out.push_back(ToJsonImpl(value)); + } + return out; + } + if (py::isinstance(obj)) { + auto out = json::object(); + for (const py::handle &key : obj) { + out[py::str(key).cast()] = ToJsonImpl(obj[key]); + } + return out; + } + MS_LOG(ERROR) << "Python to json failed, obj is: " << py::cast(obj); + return json(); +} +} // namespace detail + +py::object adl_serializer::FromJson(const json &j) { return detail::FromJsonImpl(j); } + +void adl_serializer::ToJson(json *j, const py::object &obj) { + *j = detail::ToJsonImpl(obj); +} // namespace detail +} // namespace nlohmann diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc new file mode 100644 index 0000000000..b5021802a0 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc @@ -0,0 +1,204 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "common/utils.h" +#include "./securec.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::DEBUG; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +// split a string using a character +std::vector StringSplit(const std::string &field, char separator) { + std::vector res; + uint64_t s_pos = 0; + while (s_pos < field.length()) { + size_t e_pos = field.find_first_of(separator, s_pos); + if (e_pos != std::string::npos) { + res.push_back(field.substr(s_pos, e_pos - s_pos)); + } else { + res.push_back(field.substr(s_pos, field.length() - s_pos)); + break; + } + s_pos = e_pos + 1; + } + return res; +} + +bool ValidateFieldName(const std::string &str) { + std::string::const_iterator it = str.begin(); + if (it == str.end()) { + return false; + } + for (; it != str.end(); ++it) { + if (*it == '_' || ((*it >= '0') && (*it <= '9')) || ((*it >= 'A') && (*it <= 'Z')) || + ((*it >= 'a') && (*it <= 'z'))) { + continue; + } + return false; + } + return true; +} + +std::pair GetFileName(const std::string &path) { + char real_path[PATH_MAX] = {0}; + char buf[PATH_MAX] = {0}; + if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { + MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; + return {FAILED, ""}; + } + char tmp[PATH_MAX] = {0}; +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { + MS_LOG(ERROR) << "Invalid file path, path: " << buf; + return {FAILED, ""}; + } + if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { + MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; + } +#else + if (realpath(dirname(&(buf[0])), tmp) == nullptr) { + MS_LOG(ERROR) << "Invalid file path, path: " << buf; + return {FAILED, ""}; + } + if (realpath(common::SafeCStr(path), real_path) == nullptr) { + MS_LOG(DEBUG) << "Path: " << path << "check successfully"; + } +#endif + std::string s = real_path; + char sep = '/'; + size_t i = s.rfind(sep, s.length()); + if (i != std::string::npos) { + if (i + 1 < s.size()) { + return {SUCCESS, s.substr(i + 1)}; + } + } + return {SUCCESS, s}; +} + +std::pair GetParentDir(const std::string &path) { + char real_path[PATH_MAX] = {0}; + char buf[PATH_MAX] = {0}; + if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { + MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; + return {FAILED, ""}; + } + char tmp[PATH_MAX] = {0}; +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { + MS_LOG(ERROR) << "Invalid file path, path: " << buf; + return {FAILED, ""}; + } + if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { + MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; + } +#else + if (realpath(dirname(&(buf[0])), tmp) == nullptr) { + MS_LOG(ERROR) << "Invalid file path, path: " << buf; + return {FAILED, ""}; + } + if (realpath(common::SafeCStr(path), real_path) == nullptr) { + MS_LOG(DEBUG) << "Path: " << path << "check successfully"; + } +#endif + std::string s = real_path; + if (s.rfind('/') + 1 <= s.size()) { + return {SUCCESS, s.substr(0, s.rfind('/') + 1)}; + } + return {SUCCESS, "/"}; +} + +bool CheckIsValidUtf8(const std::string &str) { + int n = 0; + int ix = str.length(); + for (int i = 0; i < ix; ++i) { + uint8_t c = static_cast(str[i]); + if (c <= 0x7f) { + n = 0; + } else if ((c & 0xE0) == 0xC0) { + n = 1; + } else if (c == 0xed && i < (ix - 1) && (static_cast(str[i + 1]) & 0xa0) == 0xa0) { + return false; + } else if ((c & 0xF0) == 0xE0) { + n = 2; + } else if ((c & 0xF8) == 0xF0) { + n = 3; + } else { + return false; + } + for (int j = 0; j < n && i < ix; ++j) { + if ((++i == ix) || ((static_cast(str[i]) & 0xC0) != 0x80)) { + return false; + } + } + } + return true; +} + +bool IsLegalFile(const std::string &path) { + struct stat s; + if (stat(common::SafeCStr(path), &s) == 0) { + if (s.st_mode & S_IFDIR) { + return false; + } + return true; + } + return false; +} + +std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) { +#if defined(_WIN32) || defined(_WIN64) + return {SUCCESS, 100}; +#else + uint64_t ll_count = 0; + struct statfs disk_info; + if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { + MS_LOG(ERROR) << "Get disk size error"; + return {FAILED, 0}; + } + + switch (disk_type) { + case kTotalSize: + ll_count = disk_info.f_bsize * disk_info.f_blocks; + ll_count = ll_count >> 20; + break; + case kFreeSize: + ll_count = disk_info.f_bsize * disk_info.f_bavail; + ll_count = ll_count >> 20; + break; + default: + ll_count = 0; + break; + } + + return {SUCCESS, ll_count}; +#endif +} + +uint32_t GetMaxThreadNum() { + // define the number of thread + uint32_t thread_num = std::thread::hardware_concurrency(); + if (thread_num == 0) { + thread_num = kMaxConsumerCount; + } + return thread_num; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h new file mode 100644 index 0000000000..3b3698ca68 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ +#define MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ + +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "pybind11/pybind11.h" + +namespace py = pybind11; +namespace nlohmann { +template <> +struct adl_serializer { + py::object FromJson(const json &j); + + void ToJson(json *j, const py::object &obj); +}; + +namespace detail { +py::object FromJsonImpl(const json &j); + +json ToJsonImpl(const py::handle &obj); +} // namespace detail +} // namespace nlohmann +#endif // MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h new file mode 100644 index 0000000000..bd1cda8a99 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h @@ -0,0 +1,182 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ +#define MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ + +#include +#include +#include +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_error.h" +#include "nlohmann/json.hpp" +#include "./sqlite3.h" +#include "utils/log_adapter.h" + +/* To be used when dlog is ok #include "./slog.h" */ +#ifdef DEBUG +#define MS_ASSERT(f) assert(f) +#else +#define MS_ASSERT(f) ((void)0) +#endif + +namespace mindspore { +namespace mindrecord { +using json = nlohmann::json; + +const int kInt0 = 0; +const int kInt1 = 1; +const int kInt2 = 2; +const int kInt3 = 3; +const int kUnsignedInt4 = 4; + +enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel }; + +const char kVersion[] = "3.0"; +const std::vector kSupportedVersion = {"2.0", kVersion}; + +enum ShardType { + kNLP = 0, + kCV = 1, +}; + +enum TaskType { + kCommonTask = 0, + kPaddedTask = 1, +}; +enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; + +enum ShuffleType { kShuffleCategory, kShuffleSample }; + +const double kEpsilon = 1e-7; + +const int kThreadNumber = 14; + +// Shard default parameters +const uint64_t kDefaultHeaderSize = 1 << 24; // 16MB +const uint64_t kDefaultPageSize = 1 << 25; // 32MB + +// HeaderSize [16KB, 128MB] +const int kMinHeaderSize = 1 << 14; // 16KB +const int kMaxHeaderSize = 1 << 27; // 128MB + +// PageSize [32KB, 256MB] +const int kMinPageSize = 1 << 15; // 32KB +const int kMaxPageSize = 1 << 28; // 256MB + +// used by value length / schema id length / statistic id length ... +const uint64_t kInt64Len = 8; + +// Minimum file size +const uint64_t kMinFileSize = kInt64Len; + +const int kMinShardCount = 1; +const int kMaxShardCount = 1000; + +const int kMinConsumerCount = 1; +const int kMaxConsumerCount = 128; + +const int kMaxSchemaCount = 1; +const int kMaxThreadCount = 32; +const int kMaxFieldCount = 100; + +// Minimum free disk size +const int kMinFreeDiskSize = 10; // 10M + +// dummy json +const json kDummyId = R"({"id": 0})"_json; + +// translate type in schema to type in sqlite3(NULL, INTEGER, REAL, TEXT, BLOB) +const std::unordered_map kDbJsonMap = { + {"string", "TEXT"}, {"date", "DATE"}, {"date-time", "DATETIME"}, {"null", "NULL"}, + {"integer", "INTEGER"}, {"boolean", "BOOLEAN"}, {"array", "BLOB"}, {"number", "NUMERIC"}, + {"int32", "INTEGER"}, {"int64", "INTEGER"}, {"float32", "NUMERIC"}, {"float64", "NUMERIC"}, + {"bytes", "BLOB"}}; + +const char kPoint = '.'; + +// field type used by check schema validation +const std::set kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"}; + +// can be searched field list +const std::set kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"}; + +// number field list +const std::set kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"}; + +/// \brief split a string using a character +/// \param[in] field target string +/// \param[in] separator a character for spliting +/// \return vector type result +std::vector StringSplit(const std::string &field, char separator); + +/// \brief validate field name is composed of '0-9' or 'a-z' or 'A-Z' or '_' or '-' +/// \param[in] str target string +/// \return +bool ValidateFieldName(const std::string &str); + +/// \brief get the filename by the path +/// \param s file path +/// \return +std::pair GetFileName(const std::string &s); + +/// \brief get parent dir +/// \param path file path +/// \return parent path +std::pair GetParentDir(const std::string &path); + +bool CheckIsValidUtf8(const std::string &str); + +/// \brief judge if a path is legal file +/// \param path file path +/// \return parent path +bool IsLegalFile(const std::string &path); + +enum DiskSizeType { kTotalSize = 0, kFreeSize }; + +/// \brief get the free space about the disk +/// \param str_dir file path +/// \param disk_type: kTotalSize / kFreeSize +/// \return size in Megabytes +std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type); + +/// \brief get the max hardware concurrency +/// \return max concurrency +uint32_t GetMaxThreadNum(); +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h new file mode 100644 index 0000000000..ed1e748afe --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ +#define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ + +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_operator.h" + +namespace mindspore { +namespace mindrecord { +class ShardCategory : public ShardOperator { + public: + explicit ShardCategory(const std::vector> &categories, + int64_t num_elements = std::numeric_limits::max(), bool replacement = false); + + ShardCategory(const std::string &category_field, int64_t num_elements, + int64_t num_categories = std::numeric_limits::max(), bool replacement = false); + + ~ShardCategory() override{}; + + const std::vector> &GetCategories() const { return categories_; } + + const std::string GetCategoryField() const { return category_field_; } + + int64_t GetNumElements() const { return num_elements_; } + + int64_t GetNumCategories() const { return num_categories_; } + + bool GetReplacement() const { return replacement_; } + + MSRStatus Execute(ShardTask &tasks) override; + + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + + private: + std::vector> categories_; + std::string category_field_; + int64_t num_elements_; + int64_t num_categories_; + bool replacement_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h new file mode 100644 index 0000000000..f6353ed3ce --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h @@ -0,0 +1,167 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_ +#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_ + +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_header.h" + +namespace mindspore { +namespace mindrecord { +const uint64_t kUnsignedOne = 1; +const uint64_t kBitsOfByte = 8; +const uint64_t kDataTypeBits = 2; +const uint64_t kNumDataOfByte = 4; +const uint64_t kBytesOfColumnLen = 4; +const uint64_t kDataTypeBitMask = 3; +const uint64_t kDataTypes = 6; + +enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type }; + +enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound }; + +enum ColumnDataType { + ColumnBytes = 0, + ColumnString = 1, + ColumnInt32 = 2, + ColumnInt64 = 3, + ColumnFloat32 = 4, + ColumnFloat64 = 5, + ColumnNoDataType = 6 +}; + +// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"}; +const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8}; + +const std::vector ColumnDataTypeNameNormalized = {"uint8", "string", "int32", + "int64", "float32", "float64"}; + +const std::unordered_map ColumnDataTypeMap = { + {"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32}, + {"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}}; + +class ShardColumn { + public: + explicit ShardColumn(const std::shared_ptr &shard_header, bool compress_integer = true); + + ~ShardColumn() = default; + + /// \brief get column value by column name + MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, + const json &columns_json, const unsigned char **data, + std::unique_ptr *data_ptr, uint64_t *const n_bytes, + ColumnDataType *column_data_type, uint64_t *column_data_type_size, + std::vector *column_shape); + + /// \brief compress blob + std::vector CompressBlob(const std::vector &blob); + + /// \brief check if blob compressed + bool CheckCompressBlob() const { return has_compress_blob_; } + + uint64_t GetNumBlobColumn() const { return num_blob_column_; } + + std::vector GetColumnName() { return column_name_; } + + std::vector GeColumnDataType() { return column_data_type_; } + + std::vector> GetColumnShape() { return column_shape_; } + + /// \brief get column value from blob + MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, + const unsigned char **data, std::unique_ptr *data_ptr, + uint64_t *const n_bytes); + std::pair GetColumnTypeByName(const std::string &column_name, + ColumnDataType *column_data_type, + uint64_t *column_data_type_size, + std::vector *column_shape); + + /// \brief get column value from json + MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, + std::unique_ptr *data_ptr, uint64_t *n_bytes); + + private: + /// \brief get float value from json + template + MSRStatus GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, bool use_double); + + /// \brief get integer value from json + template + MSRStatus GetInt(std::unique_ptr *data_ptr, const json &json_column_value); + + /// \brief get column offset address and size from blob + MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, + uint64_t *num_bytes, uint64_t *shift_idx); + + /// \brief check if column name is available + ColumnCategory CheckColumnName(const std::string &column_name); + + /// \brief compress integer column + static vector CompressInt(const vector &src_bytes, const IntegerType &int_type); + + /// \brief uncompress integer array column + template + static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, + const std::vector &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); + + /// \brief convert big-endian bytes to unsigned int + /// \param bytes_array bytes array + /// \param pos shift address in bytes array + /// \param i_type integer type + /// \return unsigned int + static uint64_t BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, + const IntegerType &i_type); + + /// \brief convert unsigned int to big-endian bytes + /// \param value integer value + /// \param i_type integer type + /// \return bytes + static std::vector UIntToBytesBig(uint64_t value, const IntegerType &i_type); + + /// \brief convert unsigned int to little-endian bytes + /// \param value integer value + /// \param i_type integer type + /// \return bytes + static std::vector UIntToBytesLittle(uint64_t value, const IntegerType &i_type); + + /// \brief convert unsigned int to little-endian bytes + /// \param bytes_array bytes array + /// \param pos shift address in bytes array + /// \param src_i_type source integer typ0e + /// \param dst_i_type (output), destination integer type + /// \return integer + static int64_t BytesLittleToMinIntType(const std::vector &bytes_array, const uint64_t &pos, + const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr); + + private: + std::vector column_name_; // column name list + std::vector column_data_type_; // column data type list + std::vector> column_shape_; // column shape list + std::unordered_map column_name_id_; // column name id map + std::vector blob_column_; // blob column list + std::unordered_map blob_column_id_; // blob column name id map + bool has_compress_blob_; // if has compress blob + uint64_t num_blob_column_; // number of blob columns +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h new file mode 100644 index 0000000000..f166ec1e6c --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ + +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "minddata/mindrecord/include/shard_sample.h" + +namespace mindspore { +namespace mindrecord { +class ShardDistributedSample : public ShardSample { + public: + ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed); + + ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed); + + void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; } + + ~ShardDistributedSample() override{}; + + MSRStatus PreExecute(ShardTask &tasks) override; + + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + + private: + bool shuffle_; + int no_of_padded_samples_; + bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch + ShardTask task_; // maintain the input tasks in first epoch +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_error.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h similarity index 100% rename from mindspore/ccsrc/mindrecord/include/shard_error.h rename to mindspore/ccsrc/minddata/mindrecord/include/shard_error.h diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h new file mode 100644 index 0000000000..67169e8696 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h @@ -0,0 +1,186 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_HEADER_H_ +#define MINDRECORD_INCLUDE_SHARD_HEADER_H_ + +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_index.h" +#include "minddata/mindrecord/include/shard_page.h" +#include "minddata/mindrecord/include/shard_schema.h" +#include "minddata/mindrecord/include/shard_statistics.h" + +namespace mindspore { +namespace mindrecord { +class ShardHeader { + public: + ShardHeader(); + + ~ShardHeader() = default; + + MSRStatus BuildDataset(const std::vector &file_paths, bool load_dataset = true); + + static std::pair BuildSingleHeader(const std::string &file_path); + /// \brief add the schema and save it + /// \param[in] schema the schema needs to be added + /// \return the last schema's id + int AddSchema(std::shared_ptr schema); + + /// \brief add the statistic and save it + /// \param[in] statistic the statistic needs to be added + /// \return the last statistic's id + void AddStatistic(std::shared_ptr statistic); + + /// \brief create index and add fields which from schema for each schema + /// \param[in] fields the index fields needs to be added + /// \return SUCCESS if add successfully, FAILED if not + MSRStatus AddIndexFields(std::vector> fields); + + MSRStatus AddIndexFields(const std::vector &fields); + + /// \brief get the schema + /// \return the schema + std::vector> GetSchemas(); + + /// \brief get Statistics + /// \return the Statistic + std::vector> GetStatistics(); + + /// \brief get the fields of the index + /// \return the fields of the index + std::vector> GetFields(); + + /// \brief get the index + /// \return the index + std::shared_ptr GetIndex(); + + /// \brief get the schema by schemaid + /// \param[in] schemaId the id of schema needs to be got + /// \return the schema obtained by schemaId + std::pair, MSRStatus> GetSchemaByID(int64_t schema_id); + + /// \brief get the filepath to shard by shardID + /// \param[in] shardID the id of shard which filepath needs to be obtained + /// \return the filepath obtained by shardID + std::string GetShardAddressByID(int64_t shard_id); + + /// \brief get the statistic by statistic id + /// \param[in] statisticId the id of statistic needs to be get + /// \return the statistics obtained by statistic id + std::pair, MSRStatus> GetStatisticByID(int64_t statistic_id); + + MSRStatus InitByFiles(const std::vector &file_paths); + + void SetIndex(Index index) { index_ = std::make_shared(index); } + + std::pair, MSRStatus> GetPage(const int &shard_id, const int &page_id); + + MSRStatus SetPage(const std::shared_ptr &new_page); + + MSRStatus AddPage(const std::shared_ptr &new_page); + + int64_t GetLastPageId(const int &shard_id); + + int GetLastPageIdByType(const int &shard_id, const std::string &page_type); + + const std::pair> GetPageByGroupId(const int &group_id, const int &shard_id); + + std::vector GetShardAddresses() const { return shard_addresses_; } + + int GetShardCount() const { return shard_count_; } + + int GetSchemaCount() const { return schema_.size(); } + + uint64_t GetHeaderSize() const { return header_size_; } + + uint64_t GetPageSize() const { return page_size_; } + + void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } + + void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } + + std::vector SerializeHeader(); + + MSRStatus PagesToFile(const std::string dump_file_name); + + MSRStatus FileToPages(const std::string dump_file_name); + + private: + MSRStatus InitializeHeader(const std::vector &headers, bool load_dataset); + + /// \brief get the headers from all the shard data + /// \param[in] the shard data real path + /// \param[in] the headers which readed from the shard data + /// \return SUCCESS/FAILED + MSRStatus GetHeaders(const vector &real_addresses, std::vector &headers); + + MSRStatus ValidateField(const std::vector &field_name, json schema, const uint64_t &schema_id); + + /// \brief check the binary file status + static MSRStatus CheckFileStatus(const std::string &path); + + static std::pair ValidateHeader(const std::string &path); + + void ParseHeader(const json &header); + + void GetHeadersOneTask(int start, int end, std::vector &headers, const vector &realAddresses); + + MSRStatus ParseIndexFields(const json &index_fields); + + MSRStatus CheckIndexField(const std::string &field, const json &schema); + + void ParsePage(const json &page, int shard_index, bool load_dataset); + + MSRStatus ParseStatistics(const json &statistics); + + MSRStatus ParseSchema(const json &schema); + + void ParseShardAddress(const json &address); + + std::string SerializeIndexFields(); + + std::vector SerializePage(); + + std::string SerializeStatistics(); + + std::string SerializeSchema(); + + std::string SerializeShardAddress(); + + std::shared_ptr InitIndexPtr(); + + MSRStatus GetAllSchemaID(std::set &bucket_count); + + uint32_t shard_count_; + uint64_t header_size_; + uint64_t page_size_; + + std::shared_ptr index_; + std::vector shard_addresses_; + std::vector> schema_; + std::vector> statistics_; + std::vector>> pages_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_HEADER_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h new file mode 100644 index 0000000000..79b10893fb --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INDEX_H +#define MINDRECORD_INDEX_H +#pragma once + +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_schema.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +using std::cin; +using std::endl; +using std::pair; +using std::string; +using std::vector; + +class Index { + public: + Index(); + + ~Index() {} + + /// \brief Add field which from schema according to schemaId + /// \param[in] schemaId the id of schema to be added + /// \param[in] field the field need to be added + /// + /// add the field to the fields_ vector + void AddIndexField(const int64_t &schemaId, const std::string &field); + + /// \brief get stored fields + /// \return fields stored + std::vector > GetFields(); + + private: + std::vector > fields_; + string database_name_; + string table_name_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INDEX_H diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h new file mode 100644 index 0000000000..fb85d9adbc --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h @@ -0,0 +1,120 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ +#define MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_header.h" +#include "./sqlite3.h" + +namespace mindspore { +namespace mindrecord { +using INDEX_FIELDS = std::pair>>; +using ROW_DATA = std::pair>>>; +class ShardIndexGenerator { + public: + explicit ShardIndexGenerator(const std::string &file_path, bool append = false); + + MSRStatus Build(); + + static std::pair GenerateFieldName(const std::pair &field); + + ~ShardIndexGenerator() {} + + /// \brief fetch value in json by field name + /// \param[in] field + /// \param[in] input + /// \return pair + std::pair GetValueByField(const string &field, json input); + + /// \brief fetch field type in schema n by field path + /// \param[in] field_path + /// \param[in] schema + /// \return the type of field + static std::string TakeFieldType(const std::string &field_path, json schema); + + /// \brief create databases for indexes + MSRStatus WriteToDatabase(); + + private: + static int Callback(void *not_used, int argc, char **argv, char **az_col_name); + + static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); + + static std::string ConvertJsonToSQL(const std::string &json); + + std::pair CreateDatabase(int shard_no); + + std::pair> GetSchemaDetails(const std::vector &schema_lens, std::fstream &in); + + static std::pair GenerateRawSQL(const std::vector> &fields); + + std::pair CheckDatabase(const std::string &shard_address); + + /// + /// \param shard_no + /// \param blob_id_to_page_id + /// \param raw_page_id + /// \param in + /// \return field name, db type, field value + ROW_DATA GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, int raw_page_id, + std::fstream &in); + /// + /// \param db + /// \param sql + /// \param data + /// \return + MSRStatus BindParameterExecuteSQL( + sqlite3 *db, const std::string &sql, + const std::vector>> &data); + + INDEX_FIELDS GenerateIndexFields(const std::vector &schema_detail); + + MSRStatus ExecuteTransaction(const int &shard_no, std::pair &db, + const std::vector &raw_page_ids, const std::map &blob_id_to_page_id); + + MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); + + MSRStatus AddBlobPageInfo(std::vector> &row_data, + const std::shared_ptr cur_blob_page, uint64_t &cur_blob_page_offset, + std::fstream &in); + + void AddIndexFieldByRawData(const std::vector &schema_detail, + std::vector> &row_data); + + void DatabaseWriter(); // worker thread + + std::string file_path_; + bool append_; + ShardHeader shard_header_; + uint64_t page_size_; + uint64_t header_size_; + int schema_count_; + std::atomic_int task_; + std::atomic_bool write_success_; + std::vector> fields_; +}; +} // namespace mindrecord +} // namespace mindspore +#endif // MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h new file mode 100644 index 0000000000..b5ea53b759 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ +#define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ + +#include +#include "minddata/mindrecord/include/shard_task.h" + +namespace mindspore { +namespace mindrecord { +class ShardOperator { + public: + virtual ~ShardOperator() = default; + + MSRStatus operator()(ShardTask &tasks) { + if (SUCCESS != this->PreExecute(tasks)) { + return FAILED; + } + if (SUCCESS != this->Execute(tasks)) { + return FAILED; + } + if (SUCCESS != this->SufExecute(tasks)) { + return FAILED; + } + return SUCCESS; + } + virtual bool HasChildOp() { return child_op_ != nullptr; } + + virtual MSRStatus SetChildOp(std::shared_ptr child_op) { + if (child_op != nullptr) child_op_ = child_op; + return SUCCESS; + } + + virtual std::shared_ptr GetChildOp() { return child_op_; } + + virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } + + virtual MSRStatus Execute(ShardTask &tasks) = 0; + + virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } + + virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } + + private: + std::shared_ptr child_op_ = nullptr; +}; +} // namespace mindrecord +} // namespace mindspore +#endif // MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h new file mode 100644 index 0000000000..01c70acf29 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h @@ -0,0 +1,106 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_PAGE_H_ +#define MINDRECORD_INCLUDE_SHARD_PAGE_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +const std::string kPageTypeRaw = "RAW_DATA"; +const std::string kPageTypeBlob = "BLOB_DATA"; +const std::string kPageTypeNewColumn = "NEW_COLUMN_DATA"; + +class Page { + public: + Page(const int &page_id, const int &shard_id, const std::string &page_type, const int &page_type_id, + const uint64_t &start_row_id, const uint64_t end_row_id, + const std::vector> &row_group_ids, const uint64_t page_size) + : page_id_(page_id), + shard_id_(shard_id), + page_type_(page_type), + page_type_id_(page_type_id), + start_row_id_(start_row_id), + end_row_id_(end_row_id), + row_group_ids_(row_group_ids), + page_size_(page_size) {} + + ~Page() = default; + + /// \brief get the page and its description + /// \return the json format of the page and its description + json GetPage() const; + + int GetPageID() const { return page_id_; } + + int GetShardID() const { return shard_id_; } + + int GetPageTypeID() const { return page_type_id_; } + + std::string GetPageType() const { return page_type_; } + + uint64_t GetPageSize() const { return page_size_; } + + uint64_t GetStartRowID() const { return start_row_id_; } + + uint64_t GetEndRowID() const { return end_row_id_; } + + void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } + + void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } + + std::pair GetLastRowGroupID() const { return row_group_ids_.back(); } + + std::vector> GetRowGroupIds() const { return row_group_ids_; } + + void SetRowGroupIds(const std::vector> &last_row_group_ids) { + row_group_ids_ = last_row_group_ids; + } + + void DeleteLastGroupId(); + + private: + int page_id_; + int shard_id_; + std::string page_type_; + int page_type_id_; + uint64_t start_row_id_; + uint64_t end_row_id_; + std::vector> row_group_ids_; + uint64_t page_size_; + // JSON page: { + // "page_id":X, + // "shard_id":X, + // "page_type":"XXX", (enum "raw_data", "blob_data", "new_column") + // "page_type_id":X, + // "start_row_id":X, + // "end_row_id":X, + // "row_group_ids":[{"id":X, "offset":X}], + // "page_size":X, +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_PAGE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h new file mode 100644 index 0000000000..2d420b563d --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ + +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "minddata/mindrecord/include/shard_category.h" + +namespace mindspore { +namespace mindrecord { +class ShardPkSample : public ShardCategory { + public: + ShardPkSample(const std::string &category_field, int64_t num_elements); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); + + ~ShardPkSample() override{}; + + MSRStatus SufExecute(ShardTask &tasks) override; + + private: + bool shuffle_; + std::shared_ptr shuffle_op_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h new file mode 100644 index 0000000000..b1b0c1397a --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -0,0 +1,366 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_READER_H_ +#define MINDRECORD_INCLUDE_SHARD_READER_H_ + +#include +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_category.h" +#include "minddata/mindrecord/include/shard_column.h" +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +using ROW_GROUPS = + std::tuple>>, std::vector>>; +using ROW_GROUP_BRIEF = + std::tuple>, std::vector>; +using TASK_RETURN_CONTENT = + std::pair, json>>>>; +const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode +const int kNumPageInBuffer = 16; // page buffer size in block-reader mode + +class ShardReader { + public: + ShardReader(); + + virtual ~ShardReader(); + + /// \brief open files and initialize reader, c++ API + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not + /// \param[in] n_consumer number of threads when reading + /// \param[in] selected_columns column list to be populated + /// \param[in] operators operators applied to data, operator type is shuffle, sample or category + /// \param[in] block_reader block-reader mode if true, otherwise row-reader mode + /// \return MSRStatus the status of MSRStatus + MSRStatus Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, + const std::vector &selected_columns = {}, + const std::vector> &operators = {}, const bool &block_reader = false, + const int num_padded = 0); + + /// \brief open files and initialize reader, python API + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not + /// \param[in] n_consumer number of threads when reading + /// \param[in] selected_columns column list to be populated + /// \param[in] operators operators applied to data, operator type is shuffle, sample or category + /// \return MSRStatus the status of MSRStatus + MSRStatus OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer = 4, + const std::vector &selected_columns = {}, + const std::vector> &operators = {}); + + /// \brief close reader + /// \return null + void Close(); + + /// \brief read the file, get schema meta,statistics and index, single-thread mode + /// \return MSRStatus the status of MSRStatus + MSRStatus Open(); + + /// \brief read the file, get schema meta,statistics and index, multiple-thread mode + /// \return MSRStatus the status of MSRStatus + MSRStatus Open(int n_consumer); + + /// \brief launch threads to get batches + /// \param[in] is_simple_reader trigger threads if false; do nothing if true + /// \return MSRStatus the status of MSRStatus + MSRStatus Launch(bool is_simple_reader = false); + + /// \brief aim to get the meta data + /// \return the metadata + std::shared_ptr GetShardHeader() const; + + /// \brief aim to get columns context + /// \return the columns + std::shared_ptr GetShardColumn() const; + + /// \brief get the number of shards + /// \return # of shards + int GetShardCount() const; + + /// \brief get the number of rows in database + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not + /// \param[in] op smart pointer refer to ShardCategory or ShardSample object + /// \param[out] count # of rows + /// \return MSRStatus the status of MSRStatus + MSRStatus CountTotalRows(const std::vector &file_paths, bool load_dataset, + const std::shared_ptr &op, int64_t *count, const int num_padded); + + /// \brief shuffle task with incremental seed + /// \return void + void ShuffleTask(); + + /// \brief get the number of rows in database + /// \return # of rows + int GetNumRows() const; + + /// \brief Read the summary of row groups + /// \return the tuple of 4 elements + /// 1. Sharding ID + /// 2. Row group ID + /// 3. The row ID started in row group + /// 4. # of rows in row group + std::vector> ReadRowGroupSummary(); + + /// \brief Read 1 row group data, excluding images + /// \param[in] groupID row group ID + /// \param[in] shard_id sharding ID + /// \param[in] columns multi-columns retrieved + /// \return the tuple of 5 elements + /// 1. file name where row group is located + /// 2. Actual row group size + /// 3. Offset address of row group in file + /// 4. The list of image offset in page [startOffset, endOffset) + /// 5. The list of columns data + ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id, + const std::vector &columns = std::vector()); + + /// \brief Read 1 row group data, excluding images, following an index field criteria + /// \param[in] groupID row group ID + /// \param[in] shard_id sharding ID + /// \param[in] column-value pair of criteria to fulfill + /// \param[in] columns multi-columns retrieved + /// \return the tuple of 5 elements + /// 1. file name where row group is located + /// 2. Actual row group size + /// 3. Offset address of row group in file + /// 4. The list of image offset in page [startOffset, endOffset) + /// 5. The list of columns data + ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair &criteria, + const std::vector &columns = std::vector()); + + /// \brief join all created threads + /// \return MSRStatus the status of MSRStatus + MSRStatus Finish(); + + /// \brief return a batch, given that one is ready + /// \return a batch of images and image data + std::vector, json>> GetNext(); + + /// \brief return a row by id + /// \return a batch of images and image data + std::pair, json>>> GetNextById(const int64_t &task_id, + const int32_t &consumer_id); + + /// \brief return a batch in block-reader mode, given that one is ready + /// \return a batch of images and image data + std::vector, json>> GetBlockNext(); + + /// \brief return a batch, given that one is ready, python API + /// \return a batch of images and image data + std::vector>, pybind11::object>> GetNextPy(); + + /// \brief get blob filed list + /// \return blob field list + std::pair> GetBlobFields(); + + /// \brief reset reader + /// \return null + void Reset(); + + /// \brief set flag of all-in-index + /// \return null + void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } + + /// \brief get NLP flag + bool GetNlpFlag(); + + /// \brief get all classes + MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); + + protected: + /// \brief sqlite call back function + static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); + + private: + /// \brief wrap up labels to json format + MSRStatus ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs, + std::vector>> &offsets, int shard_id, + const std::vector &columns, std::vector> &column_values); + + /// \brief read all rows for specified columns + ROW_GROUPS ReadAllRowGroup(std::vector &columns); + + /// \brief read all rows in one shard + MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, + std::vector>> &offsets, + std::vector> &column_values); + + /// \brief initialize reader + MSRStatus Init(const std::vector &file_paths, bool load_dataset); + + /// \brief validate column list + MSRStatus CheckColumnList(const std::vector &selected_columns); + + /// \brief populate one row by task list in row-reader mode + MSRStatus ConsumerByRow(int consumer_id); + + /// \brief populate one row by task list in block-reader mode + MSRStatus ConsumerByBlock(int consumer_id); + + /// \brief get offset address of images within page + std::vector> GetImageOffset(int group_id, int shard_id, + const std::pair &criteria = {"", ""}); + + /// \brief execute sqlite query with prepare statement + MSRStatus QueryWithCriteria(sqlite3 *db, string &sql, string criteria, std::vector> &labels); + + /// \brief get column values + std::pair> GetLabels(int group_id, int shard_id, const std::vector &columns, + const std::pair &criteria = {"", ""}); + + /// \brief get column values from raw data page + std::pair> GetLabelsFromPage(int group_id, int shard_id, + const std::vector &columns, + const std::pair &criteria = {"", + ""}); + + /// \brief create task list in block-reader mode + MSRStatus CreateTasksByBlock(const std::vector> &row_group_summary, + const std::vector> &operators); + + /// \brief create category-applied task list + MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op); + + /// \brief create task list in row-reader mode + MSRStatus CreateTasksByRow(const std::vector> &row_group_summary, + const std::vector> &operators); + + /// \brief crate task list + MSRStatus CreateTasks(const std::vector> &row_group_summary, + const std::vector> &operators); + + /// \brief set NLP flag + void CheckNlp(); + + /// \brief check if all specified columns are in index table + void CheckIfColumnInIndex(const std::vector &columns); + + /// \brief open multiple file handle + void FileStreamsOperator(); + + /// \brief read one row by one task + TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id); + + /// \brief get one row from buffer in block-reader mode + std::shared_ptr, json>>> GetRowFromBuffer(int bufId, int rowId); + + /// \brief get labels from binary file + std::pair> GetLabelsFromBinaryFile( + int shard_id, const std::vector &columns, const std::vector> &label_offsets); + + MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); + + /// \brief get classes in one shard + void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); + + /// \brief get number of classes + int64_t GetNumClasses(const std::string &category_field); + + /// \brief get meta of header + std::pair> GetMeta(const std::string &file_path, json &meta_data); + + /// \brief extract uncompressed data based on column list + std::pair>> UnCompressBlob(const std::vector &raw_blob_data); + + protected: + uint64_t header_size_; // header size + uint64_t page_size_; // page size + int shard_count_; // number of shards + std::shared_ptr shard_header_; // shard header + std::shared_ptr shard_column_; // shard column + + std::vector database_paths_; // sqlite handle list + std::vector file_paths_; // file paths + std::vector> file_streams_; // single-file handle list + std::vector>> file_streams_random_; // multiple-file handle list + + private: + int n_consumer_; // number of workers (threads) + std::vector selected_columns_; // columns which will be read + std::map column_schema_id_; // column-schema map + std::vector> operators_; // data operators, including shuffle, sample and category + ShardTask tasks_; // shard task + std::mutex shard_locker_; // locker of shard + + // flags + bool all_in_index_ = true; // if all columns are stored in index-table + bool interrupt_ = false; // reader interrupted + + int num_padded_; // number of padding samples + + // Delivery/Iterator mode begin + const std::string kThreadName = "THRD_ITER_"; // prefix of thread name + std::vector thread_set_; // thread list + int num_rows_; // number of rows + std::mutex mtx_delivery_; // locker for delivery + std::condition_variable cv_delivery_; // conditional variable for delivery + std::condition_variable cv_iterator_; // conditional variable for iterator + std::atomic task_id_; // task ID which is working + std::atomic deliver_id_; // delivery ID which is picked up by iterator + // map of delivery + std::unordered_map, json>>>> delivery_map_; + // Delivery/Iterator mode end + + // Block reader mode begin + bool block_reader_; // block-reader mode + int row_id_; // row id in one page + int num_blocks_; // number of pages + // raw data page + std::vector>, std::vector>>> delivery_block_; + std::unordered_set delivery_block_set_; // set of delivered pages + std::vector> buf_; // page buffer + // Block reader mode end +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_READER_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h new file mode 100644 index 0000000000..ce813bc4bf --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ + +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_shuffle.h" + +namespace mindspore { +namespace mindrecord { +class ShardSample : public ShardOperator { + public: + explicit ShardSample(int n); + + ShardSample(int num, int den); + + ShardSample(int num, int den, int par); + + ShardSample(const std::vector &indices, uint32_t seed); + + ~ShardSample() override{}; + + MSRStatus Execute(ShardTask &tasks) override; + + MSRStatus SufExecute(ShardTask &tasks) override; + + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + + protected: + int numerator_; + int denominator_; + int partition_id_; + int no_of_samples_; + std::shared_ptr shuffle_op_; + + private: + std::vector indices_; + SamplerType sampler_type_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h new file mode 100644 index 0000000000..56eae85e5a --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h @@ -0,0 +1,90 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ +#define MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ + +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_pybind.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +class Schema { + public: + ~Schema() = default; + + /// \brief obtain the json schema ,its description, its block fields + /// \param[in] desc the description of the schema + /// \param[in] schema the schema's json + static std::shared_ptr Build(std::string desc, const json &schema); + + /// \brief obtain the json schema and its description for python + /// \param[in] desc the description of the schema + /// \param[in] schema the schema's json + static std::shared_ptr Build(std::string desc, pybind11::handle schema); + + /// \brief compare two schema to judge if they are equal + /// \param b another schema to be judged + /// \return true if they are equal,false if not + bool operator==(const Schema &b) const; + + /// \brief get the schema and its description + /// \return the json format of the schema and its description + std::string GetDesc() const; + + /// \brief get the schema and its description + /// \return the json format of the schema and its description + json GetSchema() const; + + /// \brief get the schema and its description for python method + /// \return the python object of the schema and its description + pybind11::object GetSchemaForPython() const; + + /// set the schema id + /// \param[in] id the id need to be set + void SetSchemaID(int64_t id); + + /// get the schema id + /// \return the int64 schema id + int64_t GetSchemaID() const; + + /// get the blob fields + /// \return the vector blob fields + std::vector GetBlobFields() const; + + private: + Schema() = default; + static bool ValidateNumberShape(const json &it_value); + static bool Validate(json schema); + static std::vector PopulateBlobFields(json schema); + + std::string desc_; + json schema_; + std::vector blob_fields_; + int64_t schema_id_ = -1; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h new file mode 100644 index 0000000000..45d9bda338 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h @@ -0,0 +1,102 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ +#define MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ + +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_reader.h" + +namespace mindspore { +namespace mindrecord { +class ShardSegment : public ShardReader { + public: + ShardSegment(); + + ~ShardSegment() override = default; + + /// \brief Get candidate category fields + /// \return a list of fields names which are the candidates of category + std::pair> GetCategoryFields(); + + /// \brief Set category field + /// \param[in] category_field category name + /// \return true if category name is existed + MSRStatus SetCategoryField(std::string category_field); + + /// \brief Thread-safe implementation of ReadCategoryInfo + /// \return statistics data in json format with 2 field: "key" and "categories". + /// The value of "categories" is a list. Each Element in list is {count, id, name} + /// count: count of images in category + /// id: internal unique identification, persistent + /// name: category name + /// example: + /// { "key": "label", + /// "categories": [ { "count": 3, "id": 0, "name": "sport", }, + /// { "count": 3, "id": 1, "name": "finance", } ] } + std::pair ReadCategoryInfo(); + + /// \brief Thread-safe implementation of ReadAtPageById + /// \param[in] category_id category ID + /// \param[in] page_no page number + /// \param[in] n_rows_of_page rows number in one page + /// \return images array, image is a vector of uint8_t + std::pair>> ReadAtPageById(int64_t category_id, int64_t page_no, + int64_t n_rows_of_page); + + /// \brief Thread-safe implementation of ReadAtPageByName + /// \param[in] category_name category Name + /// \param[in] page_no page number + /// \param[in] n_rows_of_page rows number in one page + /// \return images array, image is a vector of uint8_t + std::pair>> ReadAtPageByName(std::string category_name, int64_t page_no, + int64_t n_rows_of_page); + + std::pair, json>>> ReadAllAtPageById(int64_t category_id, + int64_t page_no, + int64_t n_rows_of_page); + + std::pair, json>>> ReadAllAtPageByName( + std::string category_name, int64_t page_no, int64_t n_rows_of_page); + + std::pair, pybind11::object>>> ReadAtPageByIdPy( + int64_t category_id, int64_t page_no, int64_t n_rows_of_page); + + std::pair, pybind11::object>>> ReadAtPageByNamePy( + std::string category_name, int64_t page_no, int64_t n_rows_of_page); + + std::pair> GetBlobFields(); + + private: + std::pair>> WrapCategoryInfo(); + + std::string ToJsonForCategory(const std::vector> &tri_vec); + + std::string CleanUp(std::string fieldName); + + std::pair> PackImages(int group_id, int shard_id, std::vector offset); + + std::vector candidate_category_fields_; + std::string current_category_field_; + const uint32_t kStartFieldId = 9; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h new file mode 100644 index 0000000000..724be9acaf --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ + +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_sample.h" + +namespace mindspore { +namespace mindrecord { +class ShardSequentialSample : public ShardSample { + public: + ShardSequentialSample(int n, int offset); + + ShardSequentialSample(float per, float per_offset); + + ~ShardSequentialSample() override{}; + + MSRStatus Execute(ShardTask &tasks) override; + + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + + private: + int offset_; + float per_; + float per_offset_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h new file mode 100644 index 0000000000..d7f736b55b --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ +#define MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ + +#include +#include "minddata/mindrecord/include/shard_operator.h" + +namespace mindspore { +namespace mindrecord { +class ShardShuffle : public ShardOperator { + public: + explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); + + ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch, + ShuffleType shuffle_type = kShuffleSample); + + ~ShardShuffle() override{}; + + MSRStatus Execute(ShardTask &tasks) override; + + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + + private: + uint32_t shuffle_seed_; + int64_t no_of_samples_; + bool replacement_; + bool reshuffle_each_epoch_; + ShuffleType shuffle_type_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h new file mode 100644 index 0000000000..f100bb9833 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#ifndef MINDRECORD_STATISTICS_H +#define MINDRECORD_STATISTICS_H + +#include +#include +#include +#include +#include + +#include "minddata/mindrecord/include/common/shard_pybind.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +class Statistics { + public: + /// \brief save the statistic and its description + /// \param[in] desc the statistic's description + /// \param[in] statistics the statistic needs to be saved + static std::shared_ptr Build(std::string desc, const json &statistics); + + /// \brief save the statistic from python and its description + /// \param[in] desc the statistic's description + /// \param[in] statistics the statistic needs to be saved + static std::shared_ptr Build(std::string desc, pybind11::handle statistics); + + ~Statistics() = default; + + /// \brief compare two statistics to judge if they are equal + /// \param b another statistics to be judged + /// \return true if they are equal,false if not + bool operator==(const Statistics &b) const; + + /// \brief get the description + /// \return the description + std::string GetDesc() const; + + /// \brief get the statistic + /// \return json format of the statistic + json GetStatistics() const; + + /// \brief get the statistic for python + /// \return the python object of statistics + pybind11::object GetStatisticsForPython() const; + + /// \brief decode the bson statistics to json + /// \param[in] encodedStatistics the bson type of statistics + /// \return json type of statistic + void SetStatisticsID(int64_t id); + + /// \brief get the statistics id + /// \return the int64 statistics id + int64_t GetStatisticsID() const; + + private: + /// \brief validate the statistic + /// \return true / false + static bool Validate(const json &statistics); + + static bool LevelRecursive(json level); + + Statistics() = default; + + std::string desc_; + json statistics_; + int64_t statistics_id_ = -1; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_STATISTICS_H diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h new file mode 100644 index 0000000000..f07da656f2 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h @@ -0,0 +1,67 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_TASK_H_ +#define MINDRECORD_INCLUDE_SHARD_TASK_H_ + +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" + +namespace mindspore { +namespace mindrecord { +class ShardTask { + public: + ShardTask(); + + ShardTask(const ShardTask &task); // copy construction + + ShardTask &operator=(const ShardTask &task); // assignment operator + + ~ShardTask() = default; + + void MakePerm(); + + void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, + const json &label); + + void InsertTask(std::tuple, std::vector, json> task); + + void PopBack(); + + uint32_t Size() const; + + uint32_t SizeOfRows() const; + + std::tuple, std::vector, json> &GetTaskByID(size_t id); + + std::tuple, std::vector, json> &GetRandomTask(); + + static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); + + uint32_t categories; + + std::vector permutation_; + + std::vector, std::vector, json>> task_list_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_TASK_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h new file mode 100644 index 0000000000..833928773e --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h @@ -0,0 +1,257 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_WRITER_H_ +#define MINDRECORD_INCLUDE_SHARD_WRITER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_column.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_header.h" +#include "minddata/mindrecord/include/shard_index.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +class ShardWriter { + public: + ShardWriter(); + + ~ShardWriter(); + + /// \brief Open file at the beginning + /// \param[in] paths the file names list + /// \param[in] append new data at the end of file if true, otherwise overwrite file + /// \return MSRStatus the status of MSRStatus + MSRStatus Open(const std::vector &paths, bool append = false); + + /// \brief Open file at the ending + /// \param[in] paths the file names list + /// \return MSRStatus the status of MSRStatus + MSRStatus OpenForAppend(const std::string &path); + + /// \brief Write header to disk + /// \return MSRStatus the status of MSRStatus + MSRStatus Commit(); + + /// \brief Set file size + /// \param[in] header_size the size of header, only (1< header_data); + + /// \brief write raw data by group size + /// \param[in] raw_data the vector of raw json data, vector format + /// \param[in] blob_data the vector of image data + /// \param[in] sign validate data or not + /// \return MSRStatus the status of MSRStatus to judge if write successfully + MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, + bool sign = true, bool parallel_writer = false); + + /// \brief write raw data by group size for call from python + /// \param[in] raw_data the vector of raw json data, python-handle format + /// \param[in] blob_data the vector of image data + /// \param[in] sign validate data or not + /// \return MSRStatus the status of MSRStatus to judge if write successfully + MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, + bool sign = true, bool parallel_writer = false); + + /// \brief write raw data by group size for call from python + /// \param[in] raw_data the vector of raw json data, python-handle format + /// \param[in] blob_data the vector of blob json data, python-handle format + /// \param[in] sign validate data or not + /// \return MSRStatus the status of MSRStatus to judge if write successfully + MSRStatus WriteRawData(std::map> &raw_data, + std::map> &blob_data, bool sign = true, + bool parallel_writer = false); + + private: + /// \brief write shard header data to disk + MSRStatus WriteShardHeader(); + + /// \brief erase error data + void DeleteErrorData(std::map> &raw_data, std::vector> &blob_data); + + /// \brief populate error data + void PopulateMutexErrorData(const int &row, const std::string &message, std::map &err_raw_data); + + /// \brief check data + void CheckSliceData(int start_row, int end_row, json schema, const std::vector &sub_raw_data, + std::map &err_raw_data); + + /// \brief write shard header data to disk + std::tuple ValidateRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign); + + /// \brief fill data array in multiple thread run + void FillArray(int start, int end, std::map> &raw_data, + std::vector> &bin_data); + + /// \brief serialized raw data + MSRStatus SerializeRawData(std::map> &raw_data, + std::vector> &bin_data, uint32_t row_count); + + /// \brief write all data parallel + MSRStatus ParallelWriteData(const std::vector> &blob_data, + const std::vector> &bin_raw_data); + + /// \brief write data shard by shard + MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector> &blob_data, + const std::vector> &bin_raw_data); + + /// \brief break image data up into multiple row groups + MSRStatus CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, + std::vector> &rows_in_group, const std::shared_ptr &last_raw_page, + const std::shared_ptr &last_blob_page); + + /// \brief append partial blob data to previous page + MSRStatus AppendBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page); + + /// \brief write new blob data page to disk + MSRStatus NewBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page); + + /// \brief shift last row group to next raw page for new appending + MSRStatus ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page); + + /// \brief write raw data page to disk + MSRStatus WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page, const std::vector> &bin_raw_data); + + /// \brief generate empty raw data page + void EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page); + + /// \brief append a row group at the end of raw page + MSRStatus AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, + const int &chunk_id, int &last_row_groupId, std::shared_ptr last_raw_page, + const std::vector> &bin_raw_data); + + /// \brief write blob chunk to disk + MSRStatus FlushBlobChunk(const std::shared_ptr &out, const std::vector> &blob_data, + const std::pair &blob_row); + + /// \brief write raw chunk to disk + MSRStatus FlushRawChunk(const std::shared_ptr &out, + const std::vector> &rows_in_group, const int &chunk_id, + const std::vector> &bin_raw_data); + + /// \brief break up into tasks by shard + std::vector> BreakIntoShards(); + + /// \brief calculate raw data size row by row + MSRStatus SetRawDataSize(const std::vector> &bin_raw_data); + + /// \brief calculate blob data size row by row + MSRStatus SetBlobDataSize(const std::vector> &blob_data); + + /// \brief populate last raw page pointer + void SetLastRawPage(const int &shard_id, std::shared_ptr &last_raw_page); + + /// \brief populate last blob page pointer + void SetLastBlobPage(const int &shard_id, std::shared_ptr &last_blob_page); + + /// \brief check the data by schema + MSRStatus CheckData(const std::map> &raw_data); + + /// \brief check the data and type + MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, + std::map &err_raw_data); + + /// \brief Lock writer and save pages info + int LockWriter(bool parallel_writer = false); + + /// \brief Unlock writer and save pages info + MSRStatus UnlockWriter(int fd, bool parallel_writer = false); + + /// \brief Check raw data before writing + MSRStatus WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, + bool sign, int *schema_count, int *row_count); + + /// \brief Get full path from file name + MSRStatus GetFullPathFromFileName(const std::vector &paths); + + /// \brief Open files + MSRStatus OpenDataFiles(bool append); + + /// \brief Remove lock file + MSRStatus RemoveLockFile(); + + /// \brief Remove lock file + MSRStatus InitLockFile(); + + private: + const std::string kLockFileSuffix = "_Locker"; + const std::string kPageFileSuffix = "_Pages"; + std::string lock_file_; // lock file for parallel run + std::string pages_file_; // temporary file of pages info for parallel run + + int shard_count_; // number of files + uint64_t header_size_; // header size + uint64_t page_size_; // page size + uint32_t row_count_; // count of rows + uint32_t schema_count_; // count of schemas + + std::vector raw_data_size_; // Raw data size + std::vector blob_data_size_; // Blob data size + + std::vector file_paths_; // file paths + std::vector> file_streams_; // file handles + std::shared_ptr shard_header_; // shard header + std::shared_ptr shard_column_; // shard columns + + std::map> err_mg_; // used for storing error raw_data info + + std::mutex check_mutex_; // mutex for data check + std::atomic flag_{false}; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_WRITER_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc new file mode 100644 index 0000000000..f9b18a3bf0 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc @@ -0,0 +1,626 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "common/utils.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::DEBUG; +using mindspore::MsLogLevel::ERROR; +using mindspore::MsLogLevel::INFO; + +namespace mindspore { +namespace mindrecord { +ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append) + : file_path_(file_path), + append_(append), + page_size_(0), + header_size_(0), + schema_count_(0), + task_(0), + write_success_(true) {} + +MSRStatus ShardIndexGenerator::Build() { + auto ret = ShardHeader::BuildSingleHeader(file_path_); + if (ret.first != SUCCESS) { + return FAILED; + } + auto json_header = ret.second; + + auto ret2 = GetParentDir(file_path_); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : json_header["shard_addresses"]) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } + ShardHeader header = ShardHeader(); + if (header.BuildDataset(real_addresses) == FAILED) { + return FAILED; + } + shard_header_ = header; + MS_LOG(INFO) << "Init header from mindrecord file for index successfully."; + return SUCCESS; +} + +std::pair ShardIndexGenerator::GetValueByField(const string &field, json input) { + if (field.empty()) { + MS_LOG(ERROR) << "The input field is None."; + return {FAILED, ""}; + } + + if (input.empty()) { + MS_LOG(ERROR) << "The input json is None."; + return {FAILED, ""}; + } + + // parameter input does not contain the field + if (input.find(field) == input.end()) { + MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input; + return {FAILED, ""}; + } + + // schema does not contain the field + auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; + if (schema.find(field) == schema.end()) { + MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; + return {FAILED, ""}; + } + + // field should be scalar type + if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) { + MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable"; + return {FAILED, ""}; + } + + if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { + auto schema_field_options = schema[field]; + if (schema_field_options.find("shape") == schema_field_options.end()) { + return {SUCCESS, input[field].dump()}; + } else { + // field with shape option + MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable"; + return {FAILED, ""}; + } + } + + // the field type is string in here + return {SUCCESS, input[field].get()}; +} + +std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { + std::vector field_name = StringSplit(field_path, kPoint); + for (uint64_t i = 0; i < field_name.size(); i++) { + if (i != field_name.size() - 1) { + // Get type information from json schema + schema = schema.at(field_name[i]); + schema = schema.at("properties"); + } else { + // standard root layer exist "properties" if type is "object" + if (schema.find("properties") != schema.end()) { + schema = schema.at("properties"); + } + schema = schema.at(field_name[i]); + std::string field_type = schema.at("type").dump(); + if (field_type.length() <= 2) { + return ""; + } else { + return field_type.substr(1, field_type.length() - 2); + } + } + } + return ""; +} + +std::string ShardIndexGenerator::ConvertJsonToSQL(const std::string &json) { + if (kDbJsonMap.find(json) != kDbJsonMap.end()) { + return kDbJsonMap.at(json); + } else { + return "TEXT"; + } +} + +int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **az_col_name) { + for (auto i = 0; i < argc; i++) { + if (argv[i] != nullptr) { + MS_LOG(INFO) << az_col_name[i] << " = " << (argv[i] ? argv[i] : "nullptr"); + } + } + MS_LOG(INFO) << "\n"; + return 0; +} + +MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { + char *z_err_msg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Sql error: " << z_err_msg; + sqlite3_free(z_err_msg); + return FAILED; + } else { + if (!success_msg.empty()) { + MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg; + } + sqlite3_free(z_err_msg); + return SUCCESS; + } +} + +std::pair ShardIndexGenerator::GenerateFieldName( + const std::pair &field) { + // Replaces dots and dashes with underscores for SQL use + std::string field_name = field.second; + // white list to avoid sql injection + std::replace_if( + field_name.begin(), field_name.end(), [](char x) { return (x == '-' || x == '.'); }, '_'); + auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { + return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); + }); + if (pos != field_name.end()) { + MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name; + return {FAILED, ""}; + } + return {SUCCESS, field_name + "_" + std::to_string(field.first)}; +} + +std::pair ShardIndexGenerator::CheckDatabase(const std::string &shard_address) { + sqlite3 *db = nullptr; + std::ifstream fin(common::SafeCStr(shard_address)); + if (!append_ && fin.good()) { + MS_LOG(ERROR) << "DB file already exist"; + fin.close(); + return {FAILED, nullptr}; + } + fin.close(); + int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); + if (rc) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return {FAILED, nullptr}; + } else { + MS_LOG(DEBUG) << "Opened database successfully"; + return {SUCCESS, db}; + } +} + +MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { + // create shard_name table + std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; + if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { + return FAILED; + } + sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; + if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { + return FAILED; + } + sql = "INSERT INTO SHARD_NAME (NAME) VALUES ('" + shard_name + "');"; + if (ExecuteSQL(sql, db, "insert name successfully.") != SUCCESS) { + return FAILED; + } + return SUCCESS; +} + +std::pair ShardIndexGenerator::CreateDatabase(int shard_no) { + std::string shard_address = shard_header_.GetShardAddressByID(shard_no); + if (shard_address.empty()) { + MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; + return {FAILED, nullptr}; + } + + string shard_name = GetFileName(shard_address).second; + shard_address += ".db"; + auto ret1 = CheckDatabase(shard_address); + if (ret1.first != SUCCESS) { + return {FAILED, nullptr}; + } + sqlite3 *db = ret1.second; + std::string sql = "DROP TABLE IF EXISTS INDEXES;"; + if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { + return {FAILED, nullptr}; + } + sql = + "CREATE TABLE INDEXES(" + " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" + ", PAGE_OFFSET_RAW INT NOT NULL, PAGE_OFFSET_RAW_END INT NOT NULL" + ", ROW_GROUP_ID INT NOT NULL, PAGE_ID_BLOB INT NOT NULL" + ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; + + int field_no = 0; + for (const auto &field : fields_) { + uint64_t schema_id = field.first; + auto result = shard_header_.GetSchemaByID(schema_id); + if (result.second != SUCCESS) { + return {FAILED, nullptr}; + } + json json_schema = (result.first->GetSchema())["schema"]; + std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); + auto ret = GenerateFieldName(field); + if (ret.first != SUCCESS) { + return {FAILED, nullptr}; + } + sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type; + } + sql += ", PRIMARY KEY(ROW_ID"; + for (uint64_t i = 0; i < fields_.size(); ++i) sql += ",INC_" + std::to_string(i); + sql += "));"; + if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { + return {FAILED, nullptr}; + } + + if (CreateShardNameTable(db, shard_name) != SUCCESS) { + return {FAILED, nullptr}; + } + return {SUCCESS, db}; +} + +std::pair> ShardIndexGenerator::GetSchemaDetails(const std::vector &schema_lens, + std::fstream &in) { + std::vector schema_details; + if (schema_count_ <= kMaxSchemaCount) { + for (int sc = 0; sc < schema_count_; ++sc) { + std::vector schema_detail(schema_lens[sc]); + + auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + in.close(); + return {FAILED, {}}; + } + + schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()))); + } + } + + return {SUCCESS, schema_details}; +} + +std::pair ShardIndexGenerator::GenerateRawSQL( + const std::vector> &fields) { + std::string sql = + "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," + "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; + + int field_no = 0; + for (const auto &field : fields) { + auto ret = GenerateFieldName(field); + if (ret.first != SUCCESS) { + return {FAILED, ""}; + } + sql += ",INC_" + std::to_string(field_no++) + "," + ret.second; + } + sql += + ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," + ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; + field_no = 0; + for (const auto &field : fields) { + auto ret = GenerateFieldName(field); + if (ret.first != SUCCESS) { + return {FAILED, ""}; + } + sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second; + } + sql += " )"; + return {SUCCESS, sql}; +} + +MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( + sqlite3 *db, const std::string &sql, + const std::vector>> &data) { + sqlite3_stmt *stmt = nullptr; + if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; + return FAILED; + } + for (auto &row : data) { + for (auto &field : row) { + const auto &place_holder = std::get<0>(field); + const auto &field_type = std::get<1>(field); + const auto &field_value = std::get<2>(field); + + int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder)); + if (field_type == "INTEGER") { + if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index + << ", field value: " << std::stoll(field_value); + return FAILED; + } + } else if (field_type == "NUMERIC") { + if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index + << ", field value: " << std::stold(field_value); + return FAILED; + } + } else if (field_type == "NULL") { + if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL"; + return FAILED; + } + } else { + if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value; + return FAILED; + } + } + } + if (sqlite3_step(stmt) != SQLITE_DONE) { + MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; + return FAILED; + } + (void)sqlite3_reset(stmt); + } + (void)sqlite3_finalize(stmt); + return SUCCESS; +} + +MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector> &row_data, + const std::shared_ptr cur_blob_page, + uint64_t &cur_blob_page_offset, std::fstream &in) { + row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); + + // blob data start + row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset)); + auto &io_seekg_blob = + in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); + if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + in.close(); + return FAILED; + } + + uint64_t image_size = 0; + + auto &io_read = in.read(reinterpret_cast(&image_size), kInt64Len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + in.close(); + return FAILED; + } + + cur_blob_page_offset += (kInt64Len + image_size); + row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); + + return SUCCESS; +} + +void ShardIndexGenerator::AddIndexFieldByRawData( + const std::vector &schema_detail, std::vector> &row_data) { + auto result = GenerateIndexFields(schema_detail); + if (result.first == SUCCESS) { + int index = 0; + for (const auto &field : result.second) { + // assume simple field: string , number etc. + row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); + row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); + } + } +} + +ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, + int raw_page_id, std::fstream &in) { + std::vector>> full_data; + + // current raw data page + std::shared_ptr cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first; + + // related blob page + vector> row_group_list = cur_raw_page->GetRowGroupIds(); + + // pair: row_group id, offset in raw data page + for (pair blob_ids : row_group_list) { + // get blob data page according to row_group id + std::shared_ptr cur_blob_page = shard_header_.GetPage(shard_no, blob_id_to_page_id.at(blob_ids.first)).first; + + // offset in current raw data page + auto cur_raw_page_offset = static_cast(blob_ids.second); + uint64_t cur_blob_page_offset = 0; + for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { + std::vector> row_data; + row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); + row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); + row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); + + // raw data start + row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); + + // calculate raw data end + auto &io_seekg = + in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + in.close(); + return {FAILED, {}}; + } + + std::vector schema_lens; + if (schema_count_ <= kMaxSchemaCount) { + for (int sc = 0; sc < schema_count_; sc++) { + uint64_t schema_size = 0; + + auto &io_read = in.read(reinterpret_cast(&schema_size), kInt64Len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + in.close(); + return {FAILED, {}}; + } + + cur_raw_page_offset += (kInt64Len + schema_size); + schema_lens.push_back(schema_size); + } + } + row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset)); + + // Getting schema for getting data for fields + auto st_schema_detail = GetSchemaDetails(schema_lens, in); + if (st_schema_detail.first != SUCCESS) { + return {FAILED, {}}; + } + + // start blob page info + if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) { + return {FAILED, {}}; + } + + // start index field + AddIndexFieldByRawData(st_schema_detail.second, row_data); + full_data.push_back(std::move(row_data)); + } + } + return {SUCCESS, full_data}; +} + +INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector &schema_detail) { + std::vector> fields; + // index fields + std::vector> index_fields = shard_header_.GetFields(); + for (const auto &field : index_fields) { + if (field.first >= schema_detail.size()) { + return {FAILED, {}}; + } + auto field_value = GetValueByField(field.second, schema_detail[field.first]); + if (field_value.first != SUCCESS) { + MS_LOG(ERROR) << "Get value from json by field name failed"; + return {FAILED, {}}; + } + + auto result = shard_header_.GetSchemaByID(field.first); + if (result.second != SUCCESS) { + return {FAILED, {}}; + } + + std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); + auto ret = GenerateFieldName(field); + if (ret.first != SUCCESS) { + return {FAILED, {}}; + } + + fields.emplace_back(ret.second, field_type, field_value.second); + } + return {SUCCESS, std::move(fields)}; +} + +MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair &db, + const std::vector &raw_page_ids, + const std::map &blob_id_to_page_id) { + // Add index data to database + std::string shard_address = shard_header_.GetShardAddressByID(shard_no); + if (shard_address.empty()) { + MS_LOG(ERROR) << "Shard address is null"; + return FAILED; + } + + std::fstream in; + in.open(common::SafeCStr(shard_address), std::ios::in | std::ios::binary); + if (!in.good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; + } + (void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); + for (int raw_page_id : raw_page_ids) { + auto sql = GenerateRawSQL(fields_); + if (sql.first != SUCCESS) { + MS_LOG(ERROR) << "Generate raw SQL failed"; + return FAILED; + } + auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); + if (data.first != SUCCESS) { + MS_LOG(ERROR) << "Generate raw data failed"; + return FAILED; + } + if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { + MS_LOG(ERROR) << "Execute SQL failed"; + return FAILED; + } + MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; + } + (void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr); + in.close(); + + // Close database + if (sqlite3_close(db.second) != SQLITE_OK) { + MS_LOG(ERROR) << "Close database failed"; + return FAILED; + } + db.second = nullptr; + return SUCCESS; +} + +MSRStatus ShardIndexGenerator::WriteToDatabase() { + fields_ = shard_header_.GetFields(); + page_size_ = shard_header_.GetPageSize(); + header_size_ = shard_header_.GetHeaderSize(); + schema_count_ = shard_header_.GetSchemaCount(); + if (shard_header_.GetShardCount() > kMaxShardCount) { + MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount; + return FAILED; + } + task_ = 0; // set two atomic vars to initial value + write_success_ = true; + + // spawn half the physical threads or total number of shards whichever is smaller + const unsigned int num_workers = + std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast(shard_header_.GetShardCount())); + + std::vector threads; + threads.reserve(num_workers); + + for (size_t t = 0; t < threads.capacity(); t++) { + threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this)); + } + + for (size_t t = 0; t < threads.capacity(); t++) { + threads[t].join(); + } + return write_success_ ? SUCCESS : FAILED; +} + +void ShardIndexGenerator::DatabaseWriter() { + int shard_no = task_++; + while (shard_no < shard_header_.GetShardCount()) { + auto db = CreateDatabase(shard_no); + if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { + write_success_ = false; + return; + } + + MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; + + // Pre-processing page information + auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; + + std::map blob_id_to_page_id; + std::vector raw_page_ids; + for (uint64_t i = 0; i < total_pages; ++i) { + std::shared_ptr cur_page = shard_header_.GetPage(shard_no, i).first; + if (cur_page->GetPageType() == "RAW_DATA") { + raw_page_ids.push_back(i); + } else if (cur_page->GetPageType() == "BLOB_DATA") { + blob_id_to_page_id[cur_page->GetPageTypeID()] = i; + } + } + + if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { + write_success_ = false; + return; + } + MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully."; + shard_no = task_++; + } +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc new file mode 100644 index 0000000000..84d7fddb6f --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -0,0 +1,1449 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "common/utils.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::DEBUG; +using mindspore::MsLogLevel::ERROR; +using mindspore::MsLogLevel::INFO; + +namespace mindspore { +namespace mindrecord { +template +// convert the string to exactly number type (int32_t/int64_t/float/double) +Type StringToNum(const std::string &str) { + std::istringstream iss(str); + Type num; + iss >> num; + return num; +} + +ShardReader::ShardReader() { + task_id_ = 0; + deliver_id_ = 0; + shard_count_ = 0; + n_consumer_ = 0; + page_size_ = 0; + header_size_ = 0; + num_rows_ = 0; + row_id_ = 0; + num_blocks_ = 0; + block_reader_ = false; + num_padded_ = 0; +} + +std::pair> ShardReader::GetMeta(const std::string &file_path, json &meta_data) { + if (!IsLegalFile(file_path)) { + return {FAILED, {}}; + } + auto ret = ShardHeader::BuildSingleHeader(file_path); + if (ret.first != SUCCESS) { + return {FAILED, {}}; + } + auto header = ret.second; + meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, + {"version", header["version"]}, {"index_fields", header["index_fields"]}, + {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; + return {SUCCESS, header["shard_addresses"]}; +} + +MSRStatus ShardReader::Init(const std::vector &file_paths, bool load_dataset) { + std::string file_path = file_paths[0]; + json first_meta_data = json(); + auto ret = GetMeta(file_path, first_meta_data); + if (ret.first != SUCCESS) { + return FAILED; + } + if (file_paths.size() == 1 && load_dataset == true) { + auto ret2 = GetParentDir(file_path); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : ret.second) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } + file_paths_ = real_addresses; + } else if (file_paths.size() >= 1 && load_dataset == false) { + file_paths_ = file_paths; + } else { + MS_LOG(ERROR) << "Error in parameter file_path or load_dataset."; + return FAILED; + } + for (const auto &file : file_paths_) { + json meta_data = json(); + auto ret1 = GetMeta(file, meta_data); + if (ret1.first != SUCCESS) { + return FAILED; + } + if (meta_data != first_meta_data) { + MS_LOG(ERROR) << "Mindrecord files meta information is different."; + return FAILED; + } + sqlite3 *db = nullptr; + // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it + int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return FAILED; + } + MS_LOG(DEBUG) << "Opened database successfully"; + + string sql = "select NAME from SHARD_NAME;"; + std::vector> name; + char *errmsg = nullptr; + rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &name, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return FAILED; + } else { + MS_LOG(DEBUG) << "Get " << static_cast(name.size()) << " records from index."; + string shardName = GetFileName(file).second; + if (name.empty() || name[0][0] != shardName) { + MS_LOG(ERROR) << "DB file can not match file " << file; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return FAILED; + } + } + database_paths_.push_back(db); + } + ShardHeader sh = ShardHeader(); + if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) { + return FAILED; + } + shard_header_ = std::make_shared(sh); + header_size_ = shard_header_->GetHeaderSize(); + page_size_ = shard_header_->GetPageSize(); + // version < 3.0 + if (first_meta_data["version"] < kVersion) { + shard_column_ = std::make_shared(shard_header_, false); + } else { + shard_column_ = std::make_shared(shard_header_, true); + } + num_rows_ = 0; + auto row_group_summary = ReadRowGroupSummary(); + for (const auto &rg : row_group_summary) { + num_rows_ += std::get<3>(rg); + } + + MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; + + return SUCCESS; +} + +MSRStatus ShardReader::CheckColumnList(const std::vector &selected_columns) { + vector inSchema(selected_columns.size(), 0); + for (auto &p : GetShardHeader()->GetSchemas()) { + auto schema = p->GetSchema()["schema"]; + for (unsigned int i = 0; i < selected_columns.size(); ++i) { + if (schema.find(selected_columns[i]) != schema.end()) { + inSchema[i] = 1; + } + } + } + if (std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; })) { + return FAILED; + } + + return SUCCESS; +} + +MSRStatus ShardReader::Open() { + file_streams_.clear(); + + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; + } + MS_LOG(INFO) << "Open shard file successfully."; + file_streams_.push_back(fs); + } + + return SUCCESS; +} + +MSRStatus ShardReader::Open(int n_consumer) { + file_streams_random_ = + std::vector>>(n_consumer, std::vector>()); + for (const auto &file : file_paths_) { + for (int j = 0; j < n_consumer; ++j) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; + } + file_streams_random_[j].push_back(fs); + } + MS_LOG(INFO) << "Open shard file successfully."; + } + + return SUCCESS; +} + +void ShardReader::FileStreamsOperator() { + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; --i) { + if (file_streams_[i] != nullptr) { + file_streams_[i]->close(); + } + } + for (int i = static_cast(file_streams_random_.size()) - 1; i >= 0; --i) { + for (int j = static_cast(file_streams_random_[i].size()) - 1; j >= 0; --j) { + if (file_streams_random_[i][j] != nullptr) { + file_streams_random_[i][j]->close(); + } + } + } + for (int i = static_cast(database_paths_.size()) - 1; i >= 0; --i) { + if (database_paths_[i] != nullptr) { + auto ret = sqlite3_close(database_paths_[i]); + if (ret != SQLITE_OK) { + MS_LOG(ERROR) << "Close db failed. Error code: " << ret << "."; + } + database_paths_[i] = nullptr; + } + } +} + +ShardReader::~ShardReader() { Close(); } + +void ShardReader::Close() { + (void)Finish(); // interrupt reading and stop threads + FileStreamsOperator(); +} + +std::shared_ptr ShardReader::GetShardHeader() const { return shard_header_; } + +std::shared_ptr ShardReader::GetShardColumn() const { return shard_column_; } + +int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } + +int ShardReader::GetNumRows() const { return num_rows_; } + +std::vector> ShardReader::ReadRowGroupSummary() { + std::vector> row_group_summary; + int shard_count = shard_header_->GetShardCount(); + if (shard_count <= 0) { + return row_group_summary; + } + if (shard_count <= kMaxShardCount) { + for (int shard_id = 0; shard_id < shard_count; ++shard_id) { + // return -1 when page's size equals to 0. + auto last_page_id = shard_header_->GetLastPageId(shard_id); + if (static_cast(last_page_id) == -1) { + continue; + } + for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { + const auto &page_t = shard_header_->GetPage(shard_id, page_id); + const auto &page = page_t.first; + if (page->GetPageType() != kPageTypeBlob) continue; + uint64_t start_row_id = page->GetStartRowID(); + if (start_row_id > page->GetEndRowID()) { + return std::vector>(); + } + uint64_t number_of_rows = page->GetEndRowID() - start_row_id; + row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows); + } + } + } + return row_group_summary; +} + +MSRStatus ShardReader::ConvertLabelToJson(const std::vector> &labels, + std::shared_ptr fs, + std::vector>> &offsets, int shard_id, + const std::vector &columns, + std::vector> &column_values) { + for (int i = 0; i < static_cast(labels.size()); ++i) { + uint64_t group_id = std::stoull(labels[i][0]); + uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len; + uint64_t offset_end = std::stoull(labels[i][2]); + offsets[shard_id].emplace_back( + std::vector{static_cast(shard_id), group_id, offset_start, offset_end}); + if (!all_in_index_) { + int raw_page_id = std::stoi(labels[i][3]); + uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len; + uint64_t label_end = std::stoull(labels[i][5]); + auto len = label_end - label_start; + auto label_raw = std::vector(len); + auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + fs->close(); + return FAILED; + } + + auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + fs->close(); + return FAILED; + } + json label_json = json::from_msgpack(label_raw); + json tmp; + if (!columns.empty()) { + for (auto &col : columns) { + if (label_json.find(col) != label_json.end()) { + tmp[col] = label_json[col]; + } + } + } else { + tmp = label_json; + } + column_values[shard_id].emplace_back(tmp); + } else { + json construct_json; + for (unsigned int j = 0; j < columns.size(); ++j) { + // construct json "f1": value + auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; + + // convert the string to base type by schema + if (schema[columns[j]]["type"] == "int32") { + construct_json[columns[j]] = StringToNum(labels[i][j + 3]); + } else if (schema[columns[j]]["type"] == "int64") { + construct_json[columns[j]] = StringToNum(labels[i][j + 3]); + } else if (schema[columns[j]]["type"] == "float32") { + construct_json[columns[j]] = StringToNum(labels[i][j + 3]); + } else if (schema[columns[j]]["type"] == "float64") { + construct_json[columns[j]] = StringToNum(labels[i][j + 3]); + } else { + construct_json[columns[j]] = std::string(labels[i][j + 3]); + } + } + column_values[shard_id].emplace_back(construct_json); + } + } + + return SUCCESS; +} + +MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, + std::vector>> &offsets, + std::vector> &column_values) { + auto db = database_paths_[shard_id]; + std::vector> labels; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return FAILED; + } + MS_LOG(INFO) << "Get " << static_cast(labels.size()) << " records from shard " << shard_id << " index."; + + std::string file_name = file_paths_[shard_id]; + std::shared_ptr fs = std::make_shared(); + if (!all_in_index_) { + fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; + } + } + sqlite3_free(errmsg); + return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); +} + +MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { + std::map index_columns; + for (auto &field : GetShardHeader()->GetFields()) { + index_columns[field.second] = field.first; + } + if (index_columns.find(category_field) == index_columns.end()) { + MS_LOG(ERROR) << "Index field " << category_field << " does not exist."; + return FAILED; + } + auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field)); + if (SUCCESS != ret.first) { + return FAILED; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count_); + for (int x = 0; x < shard_count_; x++) { + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count_; x++) { + threads[x].join(); + } + return SUCCESS; +} + +void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, + std::set &categories) { + if (nullptr == db) { + return; + } + std::vector> columns; + char *errmsg = nullptr; + int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg); + if (ret != SQLITE_OK) { + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; + return; + } + MS_LOG(INFO) << "Get " << static_cast(columns.size()) << " records from shard " << shard_id << " index."; + std::lock_guard lck(shard_locker_); + for (int i = 0; i < static_cast(columns.size()); ++i) { + categories.emplace(columns[i][0]); + } +} + +ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { + std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; + std::vector>> offsets(shard_count_, std::vector>{}); + std::vector> column_values(shard_count_, std::vector{}); + if (all_in_index_) { + for (unsigned int i = 0; i < columns.size(); ++i) { + fields += ','; + auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); + if (ret.first != SUCCESS) { + return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); + } + fields += ret.second; + } + } else { // fetch raw data from Raw page while some field is not index. + fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END "; + } + + std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;"; + + std::vector thread_read_db = std::vector(shard_count_); + for (int x = 0; x < shard_count_; x++) { + thread_read_db[x] = + std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values)); + } + + for (int x = 0; x < shard_count_; x++) { + thread_read_db[x].join(); + } + return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values)); +} + +ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector &columns) { + const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); + if (SUCCESS != ret.first) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + const std::shared_ptr &page = ret.second; + std::string file_name = file_paths_[shard_id]; + uint64_t page_length = page->GetPageSize(); + uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; + std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id); + + auto status_labels = GetLabels(page->GetPageID(), shard_id, columns); + if (status_labels.first != SUCCESS) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), + std::move(status_labels.second)); +} + +ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, + const std::pair &criteria, + const std::vector &columns) { + const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); + if (SUCCESS != ret.first) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + vector criteria_list{criteria.first}; + if (CheckColumnList(criteria_list) == FAILED) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + const std::shared_ptr &page = ret.second; + std::string file_name = file_paths_[shard_id]; + uint64_t page_length = page->GetPageSize(); + uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; + std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); + + auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); + if (status_labels.first != SUCCESS) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + + return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), + std::move(status_labels.second)); +} + +int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) { + auto *records = static_cast> *>(p_data); + if (num_fields > 0 && num_fields <= kMaxFieldCount) { + for (int i = 0; i < num_fields; ++i) + if (p_fields[i] == nullptr) p_fields[i] = const_cast(""); + } + records->emplace_back(p_fields, p_fields + num_fields); + return 0; +} + +std::vector> ShardReader::GetImageOffset(int page_id, int shard_id, + const std::pair &criteria) { + auto db = database_paths_[shard_id]; + + std::string sql = + "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); + + // whether use index search + if (!criteria.first.empty()) { + auto schema = shard_header_->GetSchemas()[0]->GetSchema(); + + // not number field should add '' in sql + if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { + sql += + " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; + } else { + sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" + + criteria.second + "'"; + } + } + sql += ";"; + std::vector> image_offsets; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &image_offsets, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return std::vector>(); + } else { + MS_LOG(DEBUG) << "Get " << static_cast(image_offsets.size()) << "records from index."; + } + std::vector> res; + for (int i = static_cast(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector{0, 0}); + for (int i = 0; i < static_cast(image_offsets.size()); i++) { + const auto &image_offset = image_offsets[i]; + res[i][0] = std::stoull(image_offset[0]) + kInt64Len; + res[i][1] = std::stoull(image_offset[1]); + } + sqlite3_free(errmsg); + return res; +} + +std::pair> ShardReader::GetBlobFields() { + std::vector blob_fields; + for (auto &p : GetShardHeader()->GetSchemas()) { + // assume one schema + const auto &fields = p->GetBlobFields(); + blob_fields.assign(fields.begin(), fields.end()); + break; + } + return std::make_pair(kCV, blob_fields); +} + +void ShardReader::CheckIfColumnInIndex(const std::vector &columns) { + // assume different schemas do not contain same key. + if (columns.empty()) { + all_in_index_ = false; + return; + } + for (auto &field : GetShardHeader()->GetFields()) { + column_schema_id_[field.second] = field.first; + } + for (auto &col : columns) { + if (column_schema_id_.find(col) == column_schema_id_.end()) { + all_in_index_ = false; + return; + } + } +} + +MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria, + std::vector> &labels) { + sqlite3_stmt *stmt = nullptr; + if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not prepare statement"; + return FAILED; + } + int index = sqlite3_bind_parameter_index(stmt, ":criteria"); + if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << criteria; + return FAILED; + } + int rc = sqlite3_step(stmt); + while (rc != SQLITE_DONE) { + vector tmp; + int ncols = sqlite3_column_count(stmt); + for (int i = 0; i < ncols; i++) { + tmp.emplace_back(reinterpret_cast(sqlite3_column_text(stmt, i))); + } + labels.push_back(tmp); + rc = sqlite3_step(stmt); + } + (void)sqlite3_finalize(stmt); + return SUCCESS; +} + +std::pair> ShardReader::GetLabelsFromBinaryFile( + int shard_id, const std::vector &columns, const std::vector> &label_offsets) { + std::string file_name = file_paths_[shard_id]; + std::vector res; + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return {FAILED, {}}; + } + + // init the return + for (unsigned int i = 0; i < label_offsets.size(); ++i) { + res.emplace_back(json{}); + } + + for (unsigned int i = 0; i < label_offsets.size(); ++i) { + const auto &labelOffset = label_offsets[i]; + uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len; + uint64_t label_end = std::stoull(labelOffset[2]); + int raw_page_id = std::stoi(labelOffset[0]); + auto len = label_end - label_start; + auto label_raw = std::vector(len); + auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + fs->close(); + return {FAILED, {}}; + } + + auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + fs->close(); + return {FAILED, {}}; + } + + json label_json = json::from_msgpack(label_raw); + json tmp = label_json; + for (auto &col : columns) { + if (label_json.find(col) != label_json.end()) { + tmp[col] = label_json[col]; + } + } + res[i] = tmp; + } + return {SUCCESS, res}; +} + +std::pair> ShardReader::GetLabelsFromPage( + int page_id, int shard_id, const std::vector &columns, + const std::pair &criteria) { + // get page info from sqlite + auto db = database_paths_[shard_id]; + std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " + + std::to_string(page_id); + std::vector> label_offsets; + if (!criteria.first.empty()) { + sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria"; + if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) { + return {FAILED, {}}; + } + } else { + sql += ";"; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &label_offsets, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return {FAILED, {}}; + } + MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index."; + sqlite3_free(errmsg); + } + // get labels from binary file + return GetLabelsFromBinaryFile(shard_id, columns, label_offsets); +} + +std::pair> ShardReader::GetLabels(int page_id, int shard_id, + const std::vector &columns, + const std::pair &criteria) { + if (all_in_index_) { + auto db = database_paths_[shard_id]; + std::string fields; + for (unsigned int i = 0; i < columns.size(); ++i) { + if (i > 0) fields += ','; + uint64_t schema_id = column_schema_id_[columns[i]]; + fields += columns[i] + "_" + std::to_string(schema_id); + } + if (fields.empty()) fields = "*"; + std::vector> labels; + std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); + if (!criteria.first.empty()) { + sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria"; + if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) { + return {FAILED, {}}; + } + } else { + sql += ";"; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return {FAILED, {}}; + } else { + MS_LOG(DEBUG) << "Get " << static_cast(labels.size()) << "records from index."; + } + sqlite3_free(errmsg); + } + std::vector ret; + for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{}); + for (unsigned int i = 0; i < labels.size(); ++i) { + json construct_json; + for (unsigned int j = 0; j < columns.size(); ++j) { + // construct json "f1": value + auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; + + // convert the string to base type by schema + if (schema[columns[j]]["type"] == "int32") { + construct_json[columns[j]] = StringToNum(labels[i][j]); + } else if (schema[columns[j]]["type"] == "int64") { + construct_json[columns[j]] = StringToNum(labels[i][j]); + } else if (schema[columns[j]]["type"] == "float32") { + construct_json[columns[j]] = StringToNum(labels[i][j]); + } else if (schema[columns[j]]["type"] == "float64") { + construct_json[columns[j]] = StringToNum(labels[i][j]); + } else { + construct_json[columns[j]] = std::string(labels[i][j]); + } + } + ret[i] = construct_json; + } + return {SUCCESS, ret}; + } + return GetLabelsFromPage(page_id, shard_id, columns, criteria); +} + +bool ResortRowGroups(std::tuple a, std::tuple b) { + return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); +} + +MSRStatus ShardReader::Finish() { + { + std::lock_guard lck(mtx_delivery_); + interrupt_ = true; + } + cv_delivery_.notify_all(); + + // Wait for all threads to finish + for (auto &i_thread : thread_set_) { + if (i_thread.joinable()) { + i_thread.join(); + } + } + return SUCCESS; +} + +int64_t ShardReader::GetNumClasses(const std::string &category_field) { + auto shard_count = file_paths_.size(); + auto index_fields = shard_header_->GetFields(); + + std::map map_schema_id_fields; + for (auto &field : index_fields) { + map_schema_id_fields[field.second] = field.first; + } + + if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) { + MS_LOG(ERROR) << "Field " << category_field << " does not exist."; + return -1; + } + auto ret = + ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); + if (SUCCESS != ret.first) { + return -1; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count); + std::set categories; + for (int x = 0; x < shard_count; x++) { + sqlite3 *db = nullptr; + int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); + if (SQLITE_OK != rc) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return -1; + } + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count; x++) { + threads[x].join(); + } + return categories.size(); +} + +MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, + const std::shared_ptr &ops, int64_t *count, const int num_padded) { + if (SUCCESS != Init(file_paths, load_dataset)) { + return FAILED; + } + int64_t num_samples = num_rows_; + bool root = true; + std::stack> stack_ops; + std::shared_ptr op(ops); + while (op != nullptr) { + stack_ops.push(op); + op = op->GetChildOp(); + } + while (!stack_ops.empty()) { + op = stack_ops.top(); + stack_ops.pop(); + if (std::dynamic_pointer_cast(op)) { + num_samples = op->GetNumSamples(num_samples, 0); + if (num_padded > 0 && root == true) { + num_samples += num_padded; + MS_LOG(DEBUG) << "Padding samples work on shuffle sampler."; + root = false; + } + } else if (std::dynamic_pointer_cast(op)) { + auto category_op = std::dynamic_pointer_cast(op); + std::string category_field = category_op->GetCategoryField(); + auto num_classes = GetNumClasses(category_field); + num_samples = category_op->GetNumSamples(num_samples, num_classes); + } else if (std::dynamic_pointer_cast(op)) { + if (std::dynamic_pointer_cast(op)) { + auto sampler_op = std::dynamic_pointer_cast(op); + if (root == true) { + sampler_op->SetNumPaddedSamples(num_padded); + num_samples = op->GetNumSamples(num_samples, 0); + if (-1 == num_samples) { + MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; + return FAILED; + } + root = false; + } + } else { + num_samples = op->GetNumSamples(num_samples, 0); + } + } else { + if (num_padded > 0) num_samples += num_padded; + } + } + *count = num_samples; + return SUCCESS; +} + +MSRStatus ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, + const std::vector &selected_columns, + const std::vector> &operators, const bool &block_reader, + int num_padded) { + // Open file and set header by ShardReader + auto ret = Init(file_paths, load_dataset); + if (SUCCESS != ret) { + return ret; + } + auto thread_limit = GetMaxThreadNum(); + if (n_consumer > thread_limit) { + n_consumer = thread_limit; + } + if (n_consumer < kMinConsumerCount) { + n_consumer = kMinConsumerCount; + } + vector blob_fields = GetBlobFields().second; + for (unsigned int i = 0; i < selected_columns.size(); ++i) { + if (!std::any_of(blob_fields.begin(), blob_fields.end(), + [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { + selected_columns_.push_back(selected_columns[i]); + } + } + selected_columns_ = selected_columns; + + if (CheckColumnList(selected_columns_) == FAILED) { + MS_LOG(ERROR) << "Illegal column list"; + return ILLEGAL_COLUMN_LIST; + } + + // Initialize argument + shard_count_ = static_cast(file_paths_.size()); + n_consumer_ = n_consumer; + num_padded_ = num_padded; + + operators_ = operators; + + if (block_reader) { + block_reader_ = true; + if (Open() == FAILED) { + return FAILED; + } + delivery_block_ = std::vector>, std::vector>>>( + kNumPageInBuffer, std::shared_ptr>, std::vector>>{}); + buf_ = std::vector>(kNumPageInBuffer, std::vector(page_size_)); + } else { + block_reader_ = false; + if (Open(n_consumer) == FAILED) { + return FAILED; + } + } + return SUCCESS; +} + +MSRStatus ShardReader::OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer, + const std::vector &selected_columns, + const std::vector> &operators) { + // Open file and set header by ShardReader + if (SUCCESS != Init(file_paths, load_dataset)) { + return FAILED; + } + // should remove blob field from selected_columns when call from python + std::vector columns(selected_columns); + auto blob_fields = GetBlobFields().second; + for (auto &blob_field : blob_fields) { + auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); + if (it != selected_columns.end()) { + columns.erase(columns.begin() + std::distance(selected_columns.begin(), it)); + } + } + if (CheckColumnList(columns) == FAILED) { + MS_LOG(ERROR) << "Illegal column list"; + return FAILED; + } + if (Open(n_consumer) == FAILED) { + return FAILED; + } + // Initialize argument + shard_count_ = static_cast(file_paths_.size()); + n_consumer_ = n_consumer; + + // Initialize columns which will be read + selected_columns_ = selected_columns; + operators_ = operators; + + return SUCCESS; +} + +MSRStatus ShardReader::Launch(bool isSimpleReader) { + // Get all row groups' info + auto row_group_summary = ReadRowGroupSummary(); + + // Sort row group by (group_id, shard_id), prepare for parallel reading + std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); + if (CreateTasks(row_group_summary, operators_) != SUCCESS) { + MS_LOG(ERROR) << "Failed to launch read threads."; + interrupt_ = true; + return FAILED; + } + if (isSimpleReader) return SUCCESS; + // Start provider consumer threads + thread_set_ = std::vector(n_consumer_); + if (n_consumer_ <= 0 || n_consumer_ > kMaxConsumerCount) { + return FAILED; + } + + for (int x = 0; x < n_consumer_; ++x) { + if (block_reader_) { + thread_set_[x] = std::thread(&ShardReader::ConsumerByBlock, this, x); + } else { + thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); + } + } + + MS_LOG(INFO) << "Launch read thread successfully."; + return SUCCESS; +} + +MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, + const std::vector> &operators) { + CheckIfColumnInIndex(selected_columns_); + for (const auto &rg : row_group_summary) { + auto shard_id = std::get<0>(rg); + auto group_id = std::get<1>(rg); + auto n_Rows = std::get<3>(rg); + tasks_.InsertTask(TaskType::kCommonTask, shard_id, group_id, std::vector{n_Rows}, json{}); + } + return SUCCESS; +} + +MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op) { + CheckIfColumnInIndex(selected_columns_); + auto category_op = std::dynamic_pointer_cast(op); + auto categories = category_op->GetCategories(); + int64_t num_elements = category_op->GetNumElements(); + if (num_elements <= 0) { + MS_LOG(ERROR) << "Parameter num_element is not positive"; + return FAILED; + } + if (categories.empty() == true) { + std::string category_field = category_op->GetCategoryField(); + int64_t num_categories = category_op->GetNumCategories(); + if (num_categories <= 0) { + MS_LOG(ERROR) << "Parameter num_categories is not positive"; + return FAILED; + } + std::set categories_set; + auto ret = GetAllClasses(category_field, categories_set); + if (SUCCESS != ret) { + return FAILED; + } + int i = 0; + for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { + categories.emplace_back(category_field, *it); + i++; + } + } + // Generate task list, a task will create a batch + std::vector categoryTasks(categories.size()); + for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { + int category_index = 0; + for (const auto &rg : row_group_summary) { + if (category_index >= num_elements) break; + auto shard_id = std::get<0>(rg); + auto group_id = std::get<1>(rg); + + auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); + if (SUCCESS != std::get<0>(details)) { + return FAILED; + } + auto offsets = std::get<4>(details); + + auto number_of_rows = offsets.size(); + for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { + if (category_index < num_elements) { + categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart], + std::get<5>(details)[iStart]); + category_index++; + } + } + } + MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; + } + tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); + if (SUCCESS != (*category_op)(tasks_)) { + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, + const std::vector> &operators) { + CheckIfColumnInIndex(selected_columns_); + + auto ret = ReadAllRowGroup(selected_columns_); + if (std::get<0>(ret) != SUCCESS) { + return FAILED; + } + auto offsets = std::get<1>(ret); + auto local_columns = std::get<2>(ret); + if (shard_count_ <= kMaxShardCount) { + for (int shard_id = 0; shard_id < shard_count_; shard_id++) { + for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { + tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], + std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, + local_columns[shard_id][i]); + } + } + } else { + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardReader::CreateTasks(const std::vector> &row_group_summary, + const std::vector> &operators) { + if (block_reader_) { + if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { + return FAILED; + } + } else { + int category_operator = -1; + for (uint32_t i = 0; i < operators.size(); ++i) { + const auto &op = operators[i]; + if (std::dynamic_pointer_cast(op)) { + category_operator = static_cast(i); + break; + } + } + if (-1 == category_operator) { + if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { + return FAILED; + } + if (num_padded_ > 0) { + for (int i = 0; i < num_padded_; ++i) { + tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json()); + } + } + } else { + if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { + return FAILED; + } + } + } + + for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) { + const auto &op = operators[operator_no]; + if (std::dynamic_pointer_cast(op)) continue; + if (block_reader_ && std::dynamic_pointer_cast(op)) continue; + if (SUCCESS != (*op)(tasks_)) { + return FAILED; + } + } + + if (tasks_.permutation_.empty()) tasks_.MakePerm(); + num_rows_ = block_reader_ ? tasks_.SizeOfRows() : tasks_.Size(); + num_blocks_ = block_reader_ ? tasks_.Size() : 0; + MS_LOG(INFO) << "Total rows is " << num_rows_; + return SUCCESS; +} + +TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) { + // All tasks are done + if (task_id >= static_cast(tasks_.Size())) { + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); + } + + // Pick up task from task list + auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); + + // check task type + auto task_type = std::get<0>(task); + if (task_type == TaskType::kPaddedTask) { + return std::make_pair(SUCCESS, + std::make_pair(TaskType::kPaddedTask, std::vector, json>>())); + } + + auto shard_id = std::get<0>(std::get<1>(task)); + auto group_id = std::get<1>(std::get<1>(task)); + auto addr = std::get<2>(task); + const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); + if (SUCCESS != ret.first) { + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); + } + const std::shared_ptr &page = ret.second; + + // Pack image list + std::vector images(addr[1] - addr[0]); + auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0]; + + auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + file_streams_random_[consumer_id][shard_id]->close(); + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); + } + + auto &io_read = + file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast(&images[0]), addr[1] - addr[0]); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + file_streams_random_[consumer_id][shard_id]->close(); + return std::make_pair(FAILED, + std::pair(TaskType::kCommonTask, std::vector, json>>())); + } + + // Deliver batch data to output map + std::vector, json>> batch; + batch.emplace_back(std::move(images), std::move(std::get<3>(task))); + + return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch))); +} + +MSRStatus ShardReader::ConsumerByRow(int consumer_id) { + // Set thread name +#if !defined(_WIN32) && !defined(_WIN64) + auto thread_id = kThreadName + std::to_string(consumer_id); + prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); +#endif + + // Loop forever + for (;;) { + int task_id = 0; + + // Get next task ID + task_id = task_id_++; + + // All tasks are done + if (task_id >= static_cast(tasks_.Size())) { + return FAILED; + } + const auto &ret = ConsumerOneTask(task_id, consumer_id); + if (SUCCESS != ret.first) { + return FAILED; + } + const auto &batch = (ret.second).second; + // Hanging if maximum map size exceeded + // otherwise, set batch data in map + { + std::unique_lock lck(mtx_delivery_); + cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; }); + if (interrupt_) { + return SUCCESS; + } + delivery_map_[task_id] = std::make_shared, json>>>(std::move(batch)); + } + cv_iterator_.notify_one(); + } +} + +MSRStatus ShardReader::ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, + const int &buf_id) { + auto &io_seekg = file_streams_[shard_id]->seekg(page_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf_[buf_id][0]), page_length); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { + // Set thread name +#if !defined(_WIN32) && !defined(_WIN64) + auto thread_id = kThreadName + std::to_string(consumer_id); + prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); +#endif + + // Loop forever + for (;;) { + int task_id = 0; + + // Get next task ID + task_id = task_id_++; + + // All tasks are done, either quit or repeat again + if (task_id >= num_blocks_) { + std::unique_lock lck(mtx_delivery_); + cv_delivery_.wait(lck, [this] { return interrupt_ || task_id_ < num_blocks_; }); + if (interrupt_) { + return SUCCESS; + } + continue; + } + + // Pick up task from task list + auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); + + auto shard_id = std::get<0>(std::get<1>(task)); + auto group_id = std::get<1>(std::get<1>(task)); + auto row_group_brief = ReadRowGroupBrief(group_id, shard_id, selected_columns_); + if (SUCCESS != std::get<0>(row_group_brief)) { + return FAILED; + } + auto page_length = std::get<2>(row_group_brief); + auto page_offset = std::get<3>(row_group_brief); + + MS_LOG(DEBUG) << "Block task " << task_id << tasks_.permutation_[task_id] << ", shard " << shard_id << ", group " + << group_id << ", page length " << page_length << ", page offset " << page_offset; + + // Deliver block data to output map + auto offset_and_labels = std::make_pair(std::get<4>(row_group_brief), std::get<5>(row_group_brief)); + + int deliver_id = deliver_id_; + // Hanging if maximum map size exceeded otherwise, set batch data in buffer + { + std::unique_lock lck(mtx_delivery_); + cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id < deliver_id_ + kNumPageInBuffer; }); + if (interrupt_) { + return SUCCESS; + } + } + + auto buf_id = task_id % kNumPageInBuffer; + delivery_block_[buf_id] = + std::make_shared>, std::vector>>(offset_and_labels); + + // Read blob + if (ReadBlob(shard_id, page_offset, page_length, buf_id) != SUCCESS) { + return FAILED; + } + + { + std::unique_lock lck(mtx_delivery_); + delivery_block_set_.insert(task_id); + } + cv_iterator_.notify_one(); + } +} + +std::shared_ptr, json>>> ShardReader::GetRowFromBuffer(int buf_id, + int rowId) { + auto &blob_page = buf_[buf_id]; + auto &offsets = (*delivery_block_[buf_id]).first; + auto &labels = (*delivery_block_[buf_id]).second; + auto &addr_start = offsets[rowId][0]; + auto &addr_end = offsets[rowId][1]; + std::vector images(blob_page.begin() + addr_start, blob_page.begin() + addr_end); + std::vector, json>> batch; + batch.emplace_back(std::move(images), std::move(labels[rowId])); + return std::make_shared, json>>>(std::move(batch)); +} + +std::vector, json>> ShardReader::GetBlockNext() { + if (deliver_id_ >= num_blocks_) { + return std::vector, json>>(); + } + + if (row_id_ == 0) { + std::unique_lock lck(mtx_delivery_); + cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_block_set_.count(deliver_id_) > 0); }); + + if (interrupt_) { + return std::vector, json>>(); + } + } + auto buf_id = deliver_id_ % kNumPageInBuffer; + auto res = GetRowFromBuffer(buf_id, row_id_); + + row_id_++; + if (row_id_ == (*delivery_block_[buf_id]).first.size()) { + row_id_ = 0; + { + std::unique_lock lck(mtx_delivery_); + delivery_block_set_.erase(deliver_id_++); + } + cv_delivery_.notify_all(); + } + + return *res; +} + +std::vector, json>> ShardReader::GetNext() { + if (interrupt_) { + return std::vector, json>>(); + } + if (block_reader_) return GetBlockNext(); + if (deliver_id_ >= static_cast(tasks_.Size())) { + return std::vector, json>>(); + } + + std::shared_ptr, json>>> res; + { + std::unique_lock lck(mtx_delivery_); + cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); }); + if (interrupt_) { + return std::vector, json>>(); + } + res = delivery_map_[deliver_id_]; + delivery_map_.erase(deliver_id_++); + } + + cv_delivery_.notify_all(); + + return *res; +} + +std::pair, json>>> ShardReader::GetNextById( + const int64_t &task_id, const int32_t &consumer_id) { + if (interrupt_) { + return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); + } + if (block_reader_) { + return std::make_pair(TaskType::kCommonTask, GetBlockNext()); + } + const auto &ret = ConsumerOneTask(task_id, consumer_id); + if (SUCCESS != ret.first) { + return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); + } + return std::move(ret.second); +} + +std::pair>> ShardReader::UnCompressBlob( + const std::vector &raw_blob_data) { + auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_; + auto blob_fields = GetBlobFields().second; + std::vector> blob_data; + for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) { + if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue; + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0; + auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << "."; + return {FAILED, std::vector>(blob_fields.size(), std::vector())}; + } + if (data == nullptr) { + data = reinterpret_cast(data_ptr.get()); + } + std::vector column(data, data + (n_bytes / sizeof(unsigned char))); + blob_data.push_back(column); + } + return {SUCCESS, blob_data}; +} + +std::vector>, pybind11::object>> ShardReader::GetNextPy() { + auto res = GetNext(); + vector>, pybind11::object>> data; + std::transform(res.begin(), res.end(), std::back_inserter(data), + [this](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + auto ret = UnCompressBlob(std::get<0>(item)); + return std::make_tuple(ret.second, std::move(obj)); + }); + return data; +} + +void ShardReader::Reset() { + { + std::lock_guard lck(mtx_delivery_); + task_id_ = 0; + deliver_id_ = 0; + } + cv_delivery_.notify_all(); +} + +void ShardReader::ShuffleTask() { + if (block_reader_) return; + // exist shuffle and distributed sampler in ops, skip shuffle + bool has_sharding = false; + for (const auto &op : operators_) { + if (std::dynamic_pointer_cast(op)) { + has_sharding = true; + } + } + for (const auto &op : operators_) { + if (std::dynamic_pointer_cast(op) && has_sharding == false) { + if (SUCCESS != (*op)(tasks_)) { + MS_LOG(WARNING) << "Redo randomSampler failed."; + } + } else if (std::dynamic_pointer_cast(op)) { + if (SUCCESS != (*op)(tasks_)) { + MS_LOG(WARNING) << "Redo distributeSampler failed."; + } + } + } + if (tasks_.permutation_.empty()) tasks_.MakePerm(); +} + +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc new file mode 100644 index 0000000000..eda8924e13 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc @@ -0,0 +1,385 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_segment.h" +#include "common/utils.h" + +#include "./securec.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "pybind11/pybind11.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; +using mindspore::MsLogLevel::INFO; + +namespace mindspore { +namespace mindrecord { +ShardSegment::ShardSegment() { SetAllInIndex(false); } + +std::pair> ShardSegment::GetCategoryFields() { + // Skip if already populated + if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_}; + + std::string sql = "PRAGMA table_info(INDEXES);"; + std::vector> field_names; + + char *errmsg = nullptr; + int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(database_paths_[0]); + database_paths_[0] = nullptr; + return {FAILED, vector{}}; + } else { + MS_LOG(INFO) << "Get " << static_cast(field_names.size()) << " records from index."; + } + + uint32_t idx = kStartFieldId; + while (idx < field_names.size()) { + if (field_names[idx].size() < 2) { + sqlite3_free(errmsg); + sqlite3_close(database_paths_[0]); + database_paths_[0] = nullptr; + return {FAILED, vector{}}; + } + candidate_category_fields_.push_back(field_names[idx][1]); + idx += 2; + } + sqlite3_free(errmsg); + return {SUCCESS, candidate_category_fields_}; +} + +MSRStatus ShardSegment::SetCategoryField(std::string category_field) { + if (GetCategoryFields().first != SUCCESS) { + MS_LOG(ERROR) << "Get candidate category field failed"; + return FAILED; + } + category_field = category_field + "_0"; + if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), + [category_field](std::string x) { return x == category_field; })) { + current_category_field_ = category_field; + return SUCCESS; + } + MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field."; + return FAILED; +} + +std::pair ShardSegment::ReadCategoryInfo() { + MS_LOG(INFO) << "Read category begin"; + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS) { + MS_LOG(ERROR) << "Get category info failed"; + return {FAILED, ""}; + } + // Convert category info to json string + auto category_json_string = ToJsonForCategory(ret.second); + + MS_LOG(INFO) << "Read category end"; + + return {SUCCESS, category_json_string}; +} + +std::pair>> ShardSegment::WrapCategoryInfo() { + std::map counter; + + std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + + ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; + + for (auto &db : database_paths_) { + std::vector> field_count; + + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return {FAILED, std::vector>()}; + } else { + MS_LOG(INFO) << "Get " << static_cast(field_count.size()) << " records from index."; + } + + for (const auto &field : field_count) { + counter[field[0]] += std::stoi(field[1]); + } + sqlite3_free(errmsg); + } + + int idx = 0; + std::vector> category_vec(counter.size()); + (void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple item) { + return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); + }); + return {SUCCESS, std::move(category_vec)}; +} + +std::string ShardSegment::ToJsonForCategory(const std::vector> &tri_vec) { + std::vector category_json_vec; + for (auto q : tri_vec) { + json j; + j["id"] = std::get<0>(q); + j["name"] = std::get<1>(q); + j["count"] = std::get<2>(q); + + category_json_vec.emplace_back(j); + } + + json j_vec(category_json_vec); + json category_info; + category_info["key"] = current_category_field_; + category_info["categories"] = j_vec; + return category_info.dump(); +} + +std::pair>> ShardSegment::ReadAtPageById(int64_t category_id, + int64_t page_no, + int64_t n_rows_of_page) { + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS) { + MS_LOG(ERROR) << "Get category info"; + return {FAILED, std::vector>{}}; + } + if (category_id >= static_cast(ret.second.size()) || category_id < 0) { + MS_LOG(ERROR) << "Illegal category id, id: " << category_id; + return {FAILED, std::vector>{}}; + } + int total_rows_in_category = std::get<2>(ret.second[category_id]); + // Quit if category not found or page number is out of range + if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 || + page_no * n_rows_of_page >= total_rows_in_category) { + MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page; + return {FAILED, std::vector>{}}; + } + + std::vector> page; + auto row_group_summary = ReadRowGroupSummary(); + + uint64_t i_start = page_no * n_rows_of_page; + uint64_t i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); + uint64_t idx = 0; + for (const auto &rg : row_group_summary) { + if (idx >= i_end) break; + + auto shard_id = std::get<0>(rg); + auto group_id = std::get<1>(rg); + auto details = ReadRowGroupCriteria( + group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); + if (SUCCESS != std::get<0>(details)) { + return {FAILED, std::vector>{}}; + } + auto offsets = std::get<4>(details); + uint64_t number_of_rows = offsets.size(); + if (idx + number_of_rows < i_start) { + idx += number_of_rows; + continue; + } + + for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { + if (idx >= i_start && idx < i_end) { + auto ret1 = PackImages(group_id, shard_id, offsets[i]); + if (SUCCESS != ret1.first) { + return {FAILED, std::vector>{}}; + } + page.push_back(std::move(ret1.second)); + } + } + } + + return {SUCCESS, std::move(page)}; +} + +std::pair> ShardSegment::PackImages(int group_id, int shard_id, + std::vector offset) { + const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); + if (SUCCESS != ret.first) { + return {FAILED, std::vector()}; + } + const std::shared_ptr &blob_page = ret.second; + + // Pack image list + std::vector images(offset[1] - offset[0]); + auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; + auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + file_streams_random_[0][shard_id]->close(); + return {FAILED, {}}; + } + + auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast(&images[0]), offset[1] - offset[0]); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + file_streams_random_[0][shard_id]->close(); + return {FAILED, {}}; + } + + return {SUCCESS, std::move(images)}; +} + +std::pair>> ShardSegment::ReadAtPageByName(std::string category_name, + int64_t page_no, + int64_t n_rows_of_page) { + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS) { + MS_LOG(ERROR) << "Get category info"; + return {FAILED, std::vector>{}}; + } + for (const auto &categories : ret.second) { + if (std::get<1>(categories) == category_name) { + auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page); + return result; + } + } + + return {FAILED, std::vector>()}; +} + +std::pair, json>>> ShardSegment::ReadAllAtPageById( + int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS || category_id >= static_cast(ret.second.size())) { + MS_LOG(ERROR) << "Illegal category id, id: " << category_id; + return {FAILED, std::vector, json>>{}}; + } + int total_rows_in_category = std::get<2>(ret.second[category_id]); + // Quit if category not found or page number is out of range + if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) { + MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page; + return {FAILED, std::vector, json>>{}}; + } + + std::vector, json>> page; + auto row_group_summary = ReadRowGroupSummary(); + + int i_start = page_no * n_rows_of_page; + int i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); + int idx = 0; + for (const auto &rg : row_group_summary) { + if (idx >= i_end) break; + + auto shard_id = std::get<0>(rg); + auto group_id = std::get<1>(rg); + auto details = ReadRowGroupCriteria( + group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); + if (SUCCESS != std::get<0>(details)) { + return {FAILED, std::vector, json>>{}}; + } + auto offsets = std::get<4>(details); + auto labels = std::get<5>(details); + + int number_of_rows = offsets.size(); + if (idx + number_of_rows < i_start) { + idx += number_of_rows; + continue; + } + + if (number_of_rows > static_cast(labels.size())) { + MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows; + return {FAILED, std::vector, json>>{}}; + } + for (int i = 0; i < number_of_rows; ++i, ++idx) { + if (idx >= i_start && idx < i_end) { + auto ret1 = PackImages(group_id, shard_id, offsets[i]); + if (SUCCESS != ret1.first) { + return {FAILED, std::vector, json>>{}}; + } + page.emplace_back(std::move(ret1.second), std::move(labels[i])); + } + } + } + return {SUCCESS, std::move(page)}; +} + +std::pair, json>>> ShardSegment::ReadAllAtPageByName( + std::string category_name, int64_t page_no, int64_t n_rows_of_page) { + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS) { + MS_LOG(ERROR) << "Get category info"; + return {FAILED, std::vector, json>>{}}; + } + + // category_name to category_id + int64_t category_id = -1; + for (const auto &categories : ret.second) { + std::string categories_name = std::get<1>(categories); + + if (categories_name == category_name) { + category_id = std::get<0>(categories); + break; + } + } + + if (category_id == -1) { + return {FAILED, std::vector, json>>{}}; + } + + return ReadAllAtPageById(category_id, page_no, n_rows_of_page); +} + +std::pair, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( + int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { + auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page); + if (res.first != SUCCESS) { + return {FAILED, std::vector, pybind11::object>>{}}; + } + + vector, pybind11::object>> json_data; + std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), + [](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + return std::make_tuple(std::get<0>(item), std::move(obj)); + }); + return {SUCCESS, std::move(json_data)}; +} + +std::pair, pybind11::object>>> ShardSegment::ReadAtPageByNamePy( + std::string category_name, int64_t page_no, int64_t n_rows_of_page) { + auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page); + if (res.first != SUCCESS) { + return {FAILED, std::vector, pybind11::object>>{}}; + } + vector, pybind11::object>> json_data; + std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), + [](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + return std::make_tuple(std::get<0>(item), std::move(obj)); + }); + return {SUCCESS, std::move(json_data)}; +} + +std::pair> ShardSegment::GetBlobFields() { + std::vector blob_fields; + for (auto &p : GetShardHeader()->GetSchemas()) { + // assume one schema + const auto &fields = p->GetBlobFields(); + blob_fields.assign(fields.begin(), fields.end()); + break; + } + return std::make_pair(kCV, blob_fields); +} + +std::string ShardSegment::CleanUp(std::string field_name) { + while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back(); + field_name.pop_back(); + return field_name; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc new file mode 100644 index 0000000000..e85229cc34 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc @@ -0,0 +1,1254 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_writer.h" +#include "common/utils.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "./securec.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::DEBUG; +using mindspore::MsLogLevel::ERROR; +using mindspore::MsLogLevel::INFO; + +namespace mindspore { +namespace mindrecord { +ShardWriter::ShardWriter() + : shard_count_(1), + header_size_(kDefaultHeaderSize), + page_size_(kDefaultPageSize), + row_count_(0), + schema_count_(1) {} + +ShardWriter::~ShardWriter() { + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { + file_streams_[i]->close(); + } +} + +MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { + // Get full path from file name + for (const auto &path : paths) { + if (!CheckIsValidUtf8(path)) { + MS_LOG(ERROR) << "The filename contains invalid uft-8 data: " << path << "."; + return FAILED; + } + char resolved_path[PATH_MAX] = {0}; + char buf[PATH_MAX] = {0}; + if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { + MS_LOG(ERROR) << "Secure func failed"; + return FAILED; + } +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX) == nullptr) { + MS_LOG(ERROR) << "Invalid file path"; + return FAILED; + } + if (_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX) == nullptr) { + MS_LOG(DEBUG) << "Path " << resolved_path; + } +#else + if (realpath(dirname(&(buf[0])), resolved_path) == nullptr) { + MS_LOG(ERROR) << "Invalid file path"; + return FAILED; + } + if (realpath(common::SafeCStr(path), resolved_path) == nullptr) { + MS_LOG(DEBUG) << "Path " << resolved_path; + } +#endif + file_paths_.emplace_back(string(resolved_path)); + } + return SUCCESS; +} + +MSRStatus ShardWriter::OpenDataFiles(bool append) { + // Open files + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + if (!append) { + // if not append and mindrecord file exist, return FAILED + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (fs->good()) { + MS_LOG(ERROR) << "MindRecord file already existed."; + fs->close(); + return FAILED; + } + fs->close(); + + // open the mindrecord file to write + fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc); + if (!fs->good()) { + MS_LOG(ERROR) << "MindRecord file could not opened."; + return FAILED; + } + } else { + // open the mindrecord file to append + fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "MindRecord file could not opened for append."; + return FAILED; + } + } + MS_LOG(INFO) << "Open shard file successfully."; + file_streams_.push_back(fs); + } + return SUCCESS; +} + +MSRStatus ShardWriter::RemoveLockFile() { + // Remove temporary file + int ret = std::remove(pages_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove page file."; + } + + ret = std::remove(lock_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove lock file."; + } + return SUCCESS; +} + +MSRStatus ShardWriter::InitLockFile() { + if (file_paths_.size() == 0) { + MS_LOG(ERROR) << "File path not initialized."; + return FAILED; + } + + lock_file_ = file_paths_[0] + kLockFileSuffix; + pages_file_ = file_paths_[0] + kPageFileSuffix; + + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove file failed."; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { + shard_count_ = paths.size(); + if (shard_count_ > kMaxShardCount || shard_count_ == 0) { + MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; + return FAILED; + } + if (schema_count_ > kMaxSchemaCount) { + MS_LOG(ERROR) << "The schema Count greater than max value."; + return FAILED; + } + + // Get full path from file name + if (GetFullPathFromFileName(paths) == FAILED) { + MS_LOG(ERROR) << "Get full path from file name failed."; + return FAILED; + } + + // Open files + if (OpenDataFiles(append) == FAILED) { + MS_LOG(ERROR) << "Open data files failed."; + return FAILED; + } + + // Init lock file + if (InitLockFile() == FAILED) { + MS_LOG(ERROR) << "Init lock file failed."; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardWriter::OpenForAppend(const std::string &path) { + if (!IsLegalFile(path)) { + return FAILED; + } + auto ret1 = ShardHeader::BuildSingleHeader(path); + if (ret1.first != SUCCESS) { + return FAILED; + } + auto json_header = ret1.second; + auto ret2 = GetParentDir(path); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : json_header["shard_addresses"]) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } + ShardHeader header = ShardHeader(); + if (header.BuildDataset(real_addresses) == FAILED) { + return FAILED; + } + shard_header_ = std::make_shared(header); + MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize()); + if (ret == FAILED) { + return FAILED; + } + ret = SetPageSize(shard_header_->GetPageSize()); + if (ret == FAILED) { + return FAILED; + } + ret = Open(real_addresses, true); + if (ret == FAILED) { + MS_LOG(ERROR) << "Open file failed"; + return FAILED; + } + shard_column_ = std::make_shared(shard_header_); + return SUCCESS; +} + +MSRStatus ShardWriter::Commit() { + // Read pages file + std::ifstream page_file(pages_file_.c_str()); + if (page_file.good()) { + page_file.close(); + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return FAILED; + } + } + + if (WriteShardHeader() == FAILED) { + MS_LOG(ERROR) << "Write metadata failed"; + return FAILED; + } + MS_LOG(INFO) << "Write metadata successfully."; + + // Remove lock file + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove lock file failed."; + return FAILED; + } + + return SUCCESS; +} + +MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) { + MSRStatus ret = header_data->InitByFiles(file_paths_); + if (ret == FAILED) { + return FAILED; + } + + // set fields in mindrecord when empty + std::vector> fields = header_data->GetFields(); + if (fields.empty()) { + MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields."; + std::vector> schemas = header_data->GetSchemas(); + for (const auto &schema : schemas) { + json jsonSchema = schema->GetSchema()["schema"]; + for (const auto &el : jsonSchema.items()) { + if (el.value()["type"] == "string" || + (el.value()["type"] == "int32" && el.value().find("shape") == el.value().end()) || + (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) || + (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) || + (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) { + fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key())); + } + } + } + // only blob data + if (!fields.empty()) { + ret = header_data->AddIndexFields(fields); + if (ret == FAILED) { + MS_LOG(ERROR) << "Add index field failed"; + return FAILED; + } + } + } + + shard_header_ = header_data; + shard_header_->SetHeaderSize(header_size_); + shard_header_->SetPageSize(page_size_); + shard_column_ = std::make_shared(shard_header_); + return SUCCESS; +} + +MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) { + // header_size [16KB, 128MB] + if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) { + MS_LOG(ERROR) << "Header size should between 16KB and 128MB."; + return FAILED; + } + if (header_size % 4 != 0) { + MS_LOG(ERROR) << "Header size should be divided by four."; + return FAILED; + } + + header_size_ = header_size; + return SUCCESS; +} + +MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) { + // PageSize [32KB, 256MB] + if (page_size < kMinPageSize || page_size > kMaxPageSize) { + MS_LOG(ERROR) << "Page size should between 16KB and 256MB."; + return FAILED; + } + if (page_size % 4 != 0) { + MS_LOG(ERROR) << "Page size should be divided by four."; + return FAILED; + } + page_size_ = page_size; + return SUCCESS; +} + +void ShardWriter::DeleteErrorData(std::map> &raw_data, + std::vector> &blob_data) { + // get wrong data location + std::set> delete_set; + for (auto &err_mg : err_mg_) { + uint64_t id = err_mg.first; + auto sub_err_mg = err_mg.second; + for (auto &subMg : sub_err_mg) { + int loc = subMg.first; + std::string message = subMg.second; + MS_LOG(ERROR) << "For schema " << id << ", " << loc + 1 << " th data is wrong: " << message; + (void)delete_set.insert(loc); + } + } + + auto it = raw_data.begin(); + if (delete_set.size() == it->second.size()) { + raw_data.clear(); + blob_data.clear(); + return; + } + + // delete wrong raw data + for (auto &loc : delete_set) { + // delete row data + for (auto &raw : raw_data) { + (void)raw.second.erase(raw.second.begin() + loc); + } + + // delete blob data + (void)blob_data.erase(blob_data.begin() + loc); + } +} + +void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &message, + std::map &err_raw_data) { + std::lock_guard lock(check_mutex_); + (void)err_raw_data.insert(std::make_pair(row, message)); +} + +MSRStatus ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, + std::map &err_raw_data) { + auto data_type = std::string(value["type"].get()); + + if ((data_type == "int32" && !data[key].is_number_integer()) || + (data_type == "int64" && !data[key].is_number_integer()) || + (data_type == "float32" && !data[key].is_number_float()) || + (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) { + std::string message = "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is not matched"; + PopulateMutexErrorData(i, message, err_raw_data); + return FAILED; + } + + if (data_type == "int32" && data[key].is_number_integer()) { + int64_t temp_value = data[key]; + if (static_cast(temp_value) < static_cast(std::numeric_limits::min()) && + static_cast(temp_value) > static_cast(std::numeric_limits::max())) { + std::string message = + "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is out of range"; + PopulateMutexErrorData(i, message, err_raw_data); + return FAILED; + } + } + return SUCCESS; +} + +void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector &sub_raw_data, + std::map &err_raw_data) { + if (start_row < 0 || start_row > end_row || end_row > static_cast(sub_raw_data.size())) { + return; + } + for (int i = start_row; i < end_row; i++) { + json data = sub_raw_data[i]; + + for (auto iter = schema.begin(); iter != schema.end(); iter++) { + std::string key = iter.key(); + json value = iter.value(); + if (data.find(key) == data.end()) { + std::string message = "there is not '" + key + "' object in the raw data"; + PopulateMutexErrorData(i, message, err_raw_data); + break; + } + + if (value.size() == kInt2) { + // Skip check since all shaped data will store as blob + continue; + } + + if (CheckDataTypeAndValue(key, value, data, i, err_raw_data) != SUCCESS) { + break; + } + } + } +} + +MSRStatus ShardWriter::CheckData(const std::map> &raw_data) { + auto rawdata_iter = raw_data.begin(); + + // make sure rawdata match schema + for (; rawdata_iter != raw_data.end(); ++rawdata_iter) { + // used for storing error + std::map sub_err_mg; + int schema_id = rawdata_iter->first; + auto result = shard_header_->GetSchemaByID(schema_id); + if (result.second != SUCCESS) { + return FAILED; + } + json schema = result.first->GetSchema()["schema"]; + for (const auto &field : result.first->GetBlobFields()) { + (void)schema.erase(field); + } + std::vector sub_raw_data = rawdata_iter->second; + + // calculate start position and end position for each thread + int batch_size = rawdata_iter->second.size() / shard_count_; + int thread_num = shard_count_; + if (thread_num <= 0) { + return FAILED; + } + if (thread_num > kMaxThreadCount) { + thread_num = kMaxThreadCount; + } + std::vector thread_set(thread_num); + + // start multiple thread + int start_row = 0, end_row = 0; + for (int x = 0; x < thread_num; ++x) { + if (x != thread_num - 1) { + start_row = batch_size * x; + end_row = batch_size * (x + 1); + } else { + start_row = batch_size * x; + end_row = rawdata_iter->second.size(); + } + thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema, + std::ref(sub_raw_data), std::ref(sub_err_mg)); + } + if (thread_num > kMaxThreadCount) { + return FAILED; + } + // Wait for threads done + for (int x = 0; x < thread_num; ++x) { + thread_set[x].join(); + } + + (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg)); + } + return SUCCESS; +} + +std::tuple ShardWriter::ValidateRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign) { + auto rawdata_iter = raw_data.begin(); + schema_count_ = raw_data.size(); + std::tuple failed(FAILED, 0, 0); + if (schema_count_ == 0) { + MS_LOG(ERROR) << "Data size is zero"; + return failed; + } + + // keep schema_id + std::set schema_ids; + row_count_ = (rawdata_iter->second).size(); + MS_LOG(DEBUG) << "Schema count is " << schema_count_; + + // Determine if the number of schemas is the same + if (shard_header_->GetSchemas().size() != schema_count_) { + MS_LOG(ERROR) << "Data size is not equal with the schema size"; + return failed; + } + + // Determine raw_data size == blob_data size + if (raw_data[0].size() != blob_data.size()) { + MS_LOG(ERROR) << "Raw data size is not equal blob data size"; + return failed; + } + + // Determine whether the number of samples corresponding to each schema is the same + for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) { + if (row_count_ != rawdata_iter->second.size()) { + MS_LOG(ERROR) << "Data size is not equal"; + return failed; + } + (void)schema_ids.insert(rawdata_iter->first); + } + const std::vector> &schemas = shard_header_->GetSchemas(); + if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr &schema) { + return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); + })) { + // There is not enough data which is not matching the number of schema + MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id."; + return failed; + } + + if (!sign) { + std::tuple success(SUCCESS, schema_count_, row_count_); + return success; + } + + // check the data according the schema + if (CheckData(raw_data) != SUCCESS) { + MS_LOG(ERROR) << "Data validate check failed"; + return std::tuple(FAILED, schema_count_, row_count_); + } + + // delete wrong data from raw data + DeleteErrorData(raw_data, blob_data); + + // update raw count + row_count_ = row_count_ - err_mg_.begin()->second.size(); + std::tuple success(SUCCESS, schema_count_, row_count_); + return success; +} + +void ShardWriter::FillArray(int start, int end, std::map> &raw_data, + std::vector> &bin_data) { + // Prevent excessive thread opening and cause cross-border + if (start >= end) { + flag_ = true; + return; + } + int schema_count = static_cast(raw_data.size()); + std::map>::const_iterator rawdata_iter; + for (int x = start; x < end; ++x) { + int cnt = 0; + for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) { + const json &line = raw_data.at(rawdata_iter->first)[x]; + std::vector bline = json::to_msgpack(line); + + // Storage form is [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2] + bin_data[x * schema_count + cnt] = bline; + cnt++; + } + } +} + +int ShardWriter::LockWriter(bool parallel_writer) { + if (!parallel_writer) { + return 0; + } + +#if defined(_WIN32) || defined(_WIN64) + MS_LOG(DEBUG) << "Lock file done by python."; + const int fd = 0; +#else + const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); + if (fd >= 0) { + flock(fd, LOCK_EX); + } else { + MS_LOG(ERROR) << "Shard writer failed when locking file"; + return -1; + } +#endif + + // Open files + file_streams_.clear(); + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); + if (fs->fail()) { + MS_LOG(ERROR) << "File could not opened"; + return -1; + } + file_streams_.push_back(fs); + } + + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return -1; + } + return fd; +} + +MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { + if (!parallel_writer) { + return SUCCESS; + } + + if (shard_header_->PagesToFile(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Write pages to file failed"; + return FAILED; + } + + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { + file_streams_[i]->close(); + } + +#if defined(_WIN32) || defined(_WIN64) + MS_LOG(DEBUG) << "Unlock file done by python."; +#else + flock(fd, LOCK_UN); + close(fd); +#endif + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, + std::vector> &blob_data, bool sign, int *schema_count, + int *row_count) { + // check the free disk size + auto st_space = GetDiskSize(file_paths_[0], kFreeSize); + if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { + MS_LOG(ERROR) << "IO error / there is no free disk to be used"; + return FAILED; + } + + // compress blob + if (shard_column_->CheckCompressBlob()) { + for (auto &blob : blob_data) { + blob = shard_column_->CompressBlob(blob); + } + } + + // Add 4-bytes dummy blob data if no any blob fields + if (blob_data.size() == 0 && raw_data.size() > 0) { + blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); + } + + // Add dummy id if all are blob fields + if (blob_data.size() > 0 && raw_data.size() == 0) { + raw_data.insert(std::pair>(0, std::vector(blob_data.size(), kDummyId))); + } + + auto v = ValidateRawData(raw_data, blob_data, sign); + if (std::get<0>(v) == FAILED) { + MS_LOG(ERROR) << "Validate raw data failed"; + return FAILED; + } + *schema_count = std::get<1>(v); + *row_count = std::get<2>(v); + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign, bool parallel_writer) { + // Lock Writer if loading data parallel + int fd = LockWriter(parallel_writer); + if (fd < 0) { + MS_LOG(ERROR) << "Lock writer failed"; + return FAILED; + } + + // Get the count of schemas and rows + int schema_count = 0; + int row_count = 0; + + // Serialize raw data + if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { + MS_LOG(ERROR) << "Check raw data failed"; + return FAILED; + } + + if (row_count == kInt0) { + MS_LOG(INFO) << "Raw data size is 0."; + return SUCCESS; + } + + std::vector> bin_raw_data(row_count * schema_count); + + // Serialize raw data + if (SerializeRawData(raw_data, bin_raw_data, row_count) == FAILED) { + MS_LOG(ERROR) << "Serialize raw data failed"; + return FAILED; + } + + // Set row size of raw data + if (SetRawDataSize(bin_raw_data) == FAILED) { + MS_LOG(ERROR) << "Set raw data size failed"; + return FAILED; + } + + // Set row size of blob data + if (SetBlobDataSize(blob_data) == FAILED) { + MS_LOG(ERROR) << "Set blob data size failed"; + return FAILED; + } + + // Write data to disk with multi threads + if (ParallelWriteData(blob_data, bin_raw_data) == FAILED) { + MS_LOG(ERROR) << "Parallel write data failed"; + return FAILED; + } + MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; + + if (UnlockWriter(fd, parallel_writer) == FAILED) { + MS_LOG(ERROR) << "Unlock writer failed"; + return FAILED; + } + + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::map> &blob_data, bool sign, + bool parallel_writer) { + std::map> raw_data_json; + std::map> blob_data_json; + + (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), + [](const std::pair> &pair) { + auto &py_raw_data = pair.second; + std::vector json_raw_data; + (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), + [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); + return std::make_pair(pair.first, std::move(json_raw_data)); + }); + + (void)std::transform(blob_data.begin(), blob_data.end(), std::inserter(blob_data_json, blob_data_json.end()), + [](const std::pair> &pair) { + auto &py_blob_data = pair.second; + std::vector jsonBlobData; + (void)std::transform(py_blob_data.begin(), py_blob_data.end(), + std::back_inserter(jsonBlobData), + [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); + return std::make_pair(pair.first, std::move(jsonBlobData)); + }); + + // Serialize blob page + auto blob_data_iter = blob_data.begin(); + auto schema_count = blob_data.size(); + auto row_count = blob_data_iter->second.size(); + + std::vector> bin_blob_data(row_count * schema_count); + // Serialize blob data + if (SerializeRawData(blob_data_json, bin_blob_data, row_count) == FAILED) { + MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; + return FAILED; + } + return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + vector> &blob_data, bool sign, bool parallel_writer) { + std::map> raw_data_json; + (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), + [](const std::pair> &pair) { + auto &py_raw_data = pair.second; + std::vector json_raw_data; + (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), + [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); + return std::make_pair(pair.first, std::move(json_raw_data)); + }); + return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); +} + +MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, + const std::vector> &bin_raw_data) { + auto shards = BreakIntoShards(); + // define the number of thread + int thread_num = static_cast(shard_count_); + if (thread_num < 0) { + return FAILED; + } + if (thread_num > kMaxThreadCount) { + thread_num = kMaxThreadCount; + } + int left_thread = shard_count_; + int current_thread = 0; + while (left_thread) { + if (left_thread < thread_num) { + thread_num = left_thread; + } + // Start one thread for one shard + std::vector thread_set(thread_num); + if (thread_num <= kMaxThreadCount) { + for (int x = 0; x < thread_num; ++x) { + int start_row = shards[current_thread + x].first; + int end_row = shards[current_thread + x].second; + thread_set[x] = std::thread(&ShardWriter::WriteByShard, this, current_thread + x, start_row, end_row, + std::ref(blob_data), std::ref(bin_raw_data)); + } + // Wait for threads done + for (int x = 0; x < thread_num; ++x) { + thread_set[x].join(); + } + left_thread -= thread_num; + current_thread += thread_num; + } + } + return SUCCESS; +} + +MSRStatus ShardWriter::WriteByShard(int shard_id, int start_row, int end_row, + const std::vector> &blob_data, + const std::vector> &bin_raw_data) { + MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row + << ", schema size: " << schema_count_; + if (start_row == end_row) { + return SUCCESS; + } + vector> rows_in_group; + std::shared_ptr last_raw_page = nullptr; + std::shared_ptr last_blob_page = nullptr; + SetLastRawPage(shard_id, last_raw_page); + SetLastBlobPage(shard_id, last_blob_page); + + if (CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page) == FAILED) { + MS_LOG(ERROR) << "Cut row group failed"; + return FAILED; + } + + if (AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) { + MS_LOG(ERROR) << "Append bolb page failed"; + return FAILED; + } + + if (NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) { + MS_LOG(ERROR) << "New blob page failed"; + return FAILED; + } + + if (ShiftRawPage(shard_id, rows_in_group, last_raw_page) == FAILED) { + MS_LOG(ERROR) << "Shit raw page failed"; + return FAILED; + } + + if (WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data) == FAILED) { + MS_LOG(ERROR) << "Write raw page failed"; + return FAILED; + } + + return SUCCESS; +} + +MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, + std::vector> &rows_in_group, + const std::shared_ptr &last_raw_page, + const std::shared_ptr &last_blob_page) { + auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0; + + auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; + auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0; + auto n_byte_raw = last_raw_page_size - last_raw_offset; + + int page_start_row = start_row; + if (start_row > end_row) { + return FAILED; + } + if (end_row > static_cast(blob_data_size_.size()) || end_row > static_cast(raw_data_size_.size())) { + return FAILED; + } + for (int i = start_row; i < end_row; ++i) { + // n_byte_blob(0) indicate appendBlobPage + if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ || + n_byte_raw + raw_data_size_[i] > page_size_) { + rows_in_group.emplace_back(page_start_row, i); + page_start_row = i; + n_byte_blob = blob_data_size_[i]; + n_byte_raw = raw_data_size_[i]; + } else { + n_byte_blob += blob_data_size_[i]; + n_byte_raw += raw_data_size_[i]; + } + } + + // Not forget last one + rows_in_group.emplace_back(page_start_row, end_row); + return SUCCESS; +} + +MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page) { + auto blob_row = rows_in_group[0]; + if (blob_row.first == blob_row.second) return SUCCESS; + + // Write disk + auto page_id = last_blob_page->GetPageID(); + auto bytes_page = last_blob_page->GetPageSize(); + auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); + + // Update last blob page + bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); + last_blob_page->SetPageSize(bytes_page); + uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first; + last_blob_page->SetEndRowID(end_row); + (void)shard_header_->SetPage(last_blob_page); + return SUCCESS; +} + +MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page) { + auto page_id = shard_header_->GetLastPageId(shard_id); + auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1; + auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0; + // index(0) indicate appendBlobPage + for (uint32_t i = 1; i < rows_in_group.size(); ++i) { + auto blob_row = rows_in_group[i]; + + // Write 1 blob page to disk + auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); + // Create new page info for header + auto page_size = + std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); + std::vector> row_group_ids; + auto start_row = current_row; + auto end_row = start_row + blob_row.second - blob_row.first; + auto page = Page(++page_id, shard_id, kPageTypeBlob, ++page_type_id, start_row, end_row, row_group_ids, page_size); + (void)shard_header_->AddPage(std::make_shared(page)); + current_row = end_row; + } + return SUCCESS; +} + +MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page) { + auto blob_row = rows_in_group[0]; + if (blob_row.first == blob_row.second) return SUCCESS; + auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; + if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + + last_raw_page_size <= + page_size_) { + return SUCCESS; + } + auto page_id = shard_header_->GetLastPageId(shard_id); + auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second; + auto last_raw_page_id = last_raw_page->GetPageID(); + auto shift_size = last_raw_page_size - last_row_group_id_offset; + + std::vector buf(shift_size); + + // Read last row group from previous raw data page + if (shard_id < 0 || shard_id >= file_streams_.size()) { + return FAILED; + } + + auto &io_seekg = file_streams_[shard_id]->seekg( + page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf[0]), buf.size()); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + // Merge into new row group at new raw data page + auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&buf[0]), buf.size()); + if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { + MS_LOG(ERROR) << "File write failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + last_raw_page->DeleteLastGroupId(); + (void)shard_header_->SetPage(last_raw_page); + + // Refresh page info in header + int row_group_id = last_raw_page->GetLastRowGroupID().first + 1; + std::vector> row_group_ids; + row_group_ids.emplace_back(row_group_id, 0); + int page_type_id = last_raw_page->GetPageID(); + auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size); + (void)shard_header_->AddPage(std::make_shared(page)); + + // Reset: last raw page + SetLastRawPage(shard_id, last_raw_page); + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page, + const std::vector> &bin_raw_data) { + int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1; + for (uint32_t i = 0; i < rows_in_group.size(); ++i) { + const auto &blob_row = rows_in_group[i]; + if (blob_row.first == blob_row.second) continue; + auto raw_size = + std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0); + if (!last_raw_page) { + EmptyRawPage(shard_id, last_raw_page); + } else if (last_raw_page->GetPageSize() + raw_size > page_size_) { + (void)shard_header_->SetPage(last_raw_page); + EmptyRawPage(shard_id, last_raw_page); + } + if (AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data) != SUCCESS) { + return FAILED; + } + } + (void)shard_header_->SetPage(last_raw_page); + return SUCCESS; +} + +void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { + auto row_group_ids = std::vector>(); + auto page_id = shard_header_->GetLastPageId(shard_id); + auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1; + auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); + (void)shard_header_->AddPage(std::make_shared(page)); + SetLastRawPage(shard_id, last_raw_page); +} + +MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, + const int &chunk_id, int &last_row_group_id, std::shared_ptr last_raw_page, + const std::vector> &bin_raw_data) { + std::vector> row_group_ids = last_raw_page->GetRowGroupIds(); + auto last_raw_page_id = last_raw_page->GetPageID(); + auto n_bytes = last_raw_page->GetPageSize(); + + // previous raw data page + auto &io_seekp = + file_streams_[shard_id]->seekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + if (chunk_id > 0) row_group_ids.emplace_back(++last_row_group_id, n_bytes); + n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first, + raw_data_size_.begin() + rows_in_group[chunk_id].second, 0); + (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data); + + // Update previous raw data page + last_raw_page->SetPageSize(n_bytes); + last_raw_page->SetRowGroupIds(row_group_ids); + (void)shard_header_->SetPage(last_raw_page); + + return SUCCESS; +} + +MSRStatus ShardWriter::FlushBlobChunk(const std::shared_ptr &out, + const std::vector> &blob_data, + const std::pair &blob_row) { + if (blob_row.first > blob_row.second) { + return FAILED; + } + if (blob_row.second > static_cast(blob_data.size()) || blob_row.first < 0) { + return FAILED; + } + for (int j = blob_row.first; j < blob_row.second; ++j) { + // Write the size of blob + uint64_t line_len = blob_data[j].size(); + auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); + if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { + MS_LOG(ERROR) << "File write failed"; + out->close(); + return FAILED; + } + + // Write the data of blob + auto line = blob_data[j]; + auto &io_handle_data = out->write(reinterpret_cast(&line[0]), line_len); + if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) { + MS_LOG(ERROR) << "File write failed"; + out->close(); + return FAILED; + } + } + return SUCCESS; +} + +MSRStatus ShardWriter::FlushRawChunk(const std::shared_ptr &out, + const std::vector> &rows_in_group, const int &chunk_id, + const std::vector> &bin_raw_data) { + for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) { + // Write the size of multi schemas + for (uint32_t j = 0; j < schema_count_; ++j) { + uint64_t line_len = bin_raw_data[i * schema_count_ + j].size(); + auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); + if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { + MS_LOG(ERROR) << "File write failed"; + out->close(); + return FAILED; + } + } + // Write the data of multi schemas + for (uint32_t j = 0; j < schema_count_; ++j) { + auto line = bin_raw_data[i * schema_count_ + j]; + auto &io_handle = out->write(reinterpret_cast(&line[0]), line.size()); + if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { + MS_LOG(ERROR) << "File write failed"; + out->close(); + return FAILED; + } + } + } + return SUCCESS; +} + +// Allocate data to shards evenly +std::vector> ShardWriter::BreakIntoShards() { + std::vector> shards; + int row_in_shard = row_count_ / shard_count_; + int remains = row_count_ % shard_count_; + + std::vector v_list(shard_count_); + std::iota(v_list.begin(), v_list.end(), 0); + std::random_device rd; + std::mt19937 g(rd()); + std::shuffle(v_list.begin(), v_list.end(), g); + std::unordered_set set(v_list.begin(), v_list.begin() + remains); + + if (shard_count_ <= kMaxShardCount) { + int start_row = 0; + for (int i = 0; i < shard_count_; ++i) { + int end_row = start_row + row_in_shard; + if (set.count(i)) end_row++; + shards.emplace_back(start_row, end_row); + start_row = end_row; + } + } + return shards; +} + +MSRStatus ShardWriter::WriteShardHeader() { + if (shard_header_ == nullptr) { + MS_LOG(ERROR) << "Shard header is null"; + return FAILED; + } + auto shard_header = shard_header_->SerializeHeader(); + // Write header data to multi files + if (shard_count_ > static_cast(file_streams_.size()) || shard_count_ > static_cast(shard_header.size())) { + return FAILED; + } + if (shard_count_ <= kMaxShardCount) { + for (int shard_id = 0; shard_id < shard_count_; ++shard_id) { + auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + std::vector bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end()); + uint64_t line_len = bin_header.size(); + if (line_len + kInt64Len > header_size_) { + MS_LOG(ERROR) << "Shard header is too big"; + return FAILED; + } + + auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&line_len), kInt64Len); + if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { + MS_LOG(ERROR) << "File write failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast(&bin_header[0]), line_len); + if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) { + MS_LOG(ERROR) << "File write failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + file_streams_[shard_id]->close(); + } + } + return SUCCESS; +} + +MSRStatus ShardWriter::SerializeRawData(std::map> &raw_data, + std::vector> &bin_data, uint32_t row_count) { + // define the number of thread + uint32_t thread_num = std::thread::hardware_concurrency(); + if (thread_num == 0) thread_num = kThreadNumber; + // Set the number of samples processed by each thread + int group_num = ceil(row_count * 1.0 / thread_num); + std::vector thread_set(thread_num); + int work_thread_num = 0; + for (uint32_t x = 0; x < thread_num; ++x) { + int start_num = x * group_num; + int end_num = ((x + 1) * group_num > row_count) ? row_count : (x + 1) * group_num; + if (start_num >= end_num) { + continue; + } + // Define the run boundary and start the child thread + thread_set[x] = + std::thread(&ShardWriter::FillArray, this, start_num, end_num, std::ref(raw_data), std::ref(bin_data)); + work_thread_num++; + } + for (uint32_t x = 0; x < work_thread_num; ++x) { + // Set obstacles to prevent the main thread from running + thread_set[x].join(); + } + return flag_ == true ? FAILED : SUCCESS; +} + +MSRStatus ShardWriter::SetRawDataSize(const std::vector> &bin_raw_data) { + raw_data_size_ = std::vector(row_count_, 0); + for (uint32_t i = 0; i < row_count_; ++i) { + raw_data_size_[i] = std::accumulate( + bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_, 0, + [](uint64_t accumulator, const std::vector &row) { return accumulator + kInt64Len + row.size(); }); + } + if (*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) > page_size_) { + MS_LOG(ERROR) << "Page size is too small to save a row!"; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardWriter::SetBlobDataSize(const std::vector> &blob_data) { + blob_data_size_ = std::vector(row_count_); + (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(), + [](const std::vector &row) { return kInt64Len + row.size(); }); + if (*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) > page_size_) { + MS_LOG(ERROR) << "Page size is too small to save a row!"; + return FAILED; + } + return SUCCESS; +} + +void ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { + // Get last raw page + auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw); + if (last_raw_page_id >= 0) { + auto page = shard_header_->GetPage(shard_id, last_raw_page_id); + last_raw_page = page.first; + } +} + +void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr &last_blob_page) { + // Get last blob page + auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob); + if (last_blob_page_id >= 0) { + auto page = shard_header_->GetPage(shard_id, last_blob_page_id); + last_blob_page = page.first; + } +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc new file mode 100644 index 0000000000..eb1428a2ad --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_category.h" + +namespace mindspore { +namespace mindrecord { +ShardCategory::ShardCategory(const std::vector> &categories, int64_t num_elements, + bool replacement) + : categories_(categories), + category_field_(""), + num_elements_(num_elements), + num_categories_(0), + replacement_(replacement) {} + +ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories, + bool replacement) + : categories_({}), + category_field_(category_field), + num_elements_(num_elements), + num_categories_(num_categories), + replacement_(replacement) {} + +MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; } + +int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (dataset_size == 0) return dataset_size; + if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) { + return std::min(num_categories_, num_classes) * num_elements_; + } + return 0; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc new file mode 100644 index 0000000000..4cc5e9f413 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc @@ -0,0 +1,496 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_column.h" + +#include "common/utils.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_error.h" + +namespace mindspore { +namespace mindrecord { +ShardColumn::ShardColumn(const std::shared_ptr &shard_header, bool compress_integer) { + auto first_schema = shard_header->GetSchemas()[0]; + auto schema = first_schema->GetSchema()["schema"]; + + bool has_integer_array = false; + for (json::iterator it = schema.begin(); it != schema.end(); ++it) { + const std::string &column_name = it.key(); + column_name_.push_back(column_name); + + json it_value = it.value(); + + std::string str_type = it_value["type"]; + column_data_type_.push_back(ColumnDataTypeMap.at(str_type)); + if (it_value.find("shape") != it_value.end()) { + std::vector vec(it_value["shape"].size()); + std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin()); + column_shape_.push_back(vec); + if (str_type == "int32" || str_type == "int64") { + has_integer_array = true; + } + } else { + std::vector vec = {}; + column_shape_.push_back(vec); + } + } + + for (uint64_t i = 0; i < column_name_.size(); i++) { + column_name_id_[column_name_[i]] = i; + } + + auto blob_fields = first_schema->GetBlobFields(); + + for (const auto &field : blob_fields) { + blob_column_.push_back(field); + } + + for (uint64_t i = 0; i < blob_column_.size(); i++) { + blob_column_id_[blob_column_[i]] = i; + } + + has_compress_blob_ = (compress_integer && has_integer_array); + num_blob_column_ = blob_column_.size(); +} + +std::pair ShardColumn::GetColumnTypeByName(const std::string &column_name, + ColumnDataType *column_data_type, + uint64_t *column_data_type_size, + std::vector *column_shape) { + // Skip if column not found + auto column_category = CheckColumnName(column_name); + if (column_category == ColumnNotFound) { + return {FAILED, ColumnNotFound}; + } + + // Get data type and size + auto column_id = column_name_id_[column_name]; + *column_data_type = column_data_type_[column_id]; + *column_data_type_size = ColumnDataTypeSize[*column_data_type]; + *column_shape = column_shape_[column_id]; + + return {SUCCESS, column_category}; +} + +MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, + const json &columns_json, const unsigned char **data, + std::unique_ptr *data_ptr, uint64_t *const n_bytes, + ColumnDataType *column_data_type, uint64_t *column_data_type_size, + std::vector *column_shape) { + // Skip if column not found + auto column_category = CheckColumnName(column_name); + if (column_category == ColumnNotFound) { + return FAILED; + } + + // Get data type and size + auto column_id = column_name_id_[column_name]; + *column_data_type = column_data_type_[column_id]; + *column_data_type_size = ColumnDataTypeSize[*column_data_type]; + *column_shape = column_shape_[column_id]; + + // Retrieve value from json + if (column_category == ColumnInRaw) { + if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) { + MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << "."; + return FAILED; + } + *data = reinterpret_cast(data_ptr->get()); + return SUCCESS; + } + + // Retrieve value from blob + if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) { + MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << "."; + return FAILED; + } + if (*data == nullptr) { + *data = reinterpret_cast(data_ptr->get()); + } + return SUCCESS; +} + +MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, + std::unique_ptr *data_ptr, uint64_t *n_bytes) { + auto column_id = column_name_id_[column_name]; + auto column_data_type = column_data_type_[column_id]; + + // Initialize num bytes + *n_bytes = ColumnDataTypeSize[column_data_type]; + auto json_column_value = columns_json[column_name]; + switch (column_data_type) { + case ColumnFloat32: { + return GetFloat(data_ptr, json_column_value, false); + } + case ColumnFloat64: { + return GetFloat(data_ptr, json_column_value, true); + } + case ColumnInt32: { + return GetInt(data_ptr, json_column_value); + } + case ColumnInt64: { + return GetInt(data_ptr, json_column_value); + } + default: { + // Convert string to c_str + std::string tmp_string = json_column_value; + *n_bytes = tmp_string.size(); + auto data = reinterpret_cast(common::SafeCStr(tmp_string)); + *data_ptr = std::make_unique(*n_bytes); + for (uint32_t i = 0; i < *n_bytes; i++) { + (*data_ptr)[i] = *(data + i); + } + break; + } + } + return SUCCESS; +} + +template +MSRStatus ShardColumn::GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, + bool use_double) { + std::unique_ptr array_data = std::make_unique(1); + if (!json_column_value.is_string() && !json_column_value.is_number()) { + MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; + return FAILED; + } + if (json_column_value.is_number()) { + array_data[0] = json_column_value; + } else { + // Convert string to float + try { + if (use_double) { + array_data[0] = json_column_value.get(); + } else { + array_data[0] = json_column_value.get(); + } + } catch (json::exception &e) { + MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; + return FAILED; + } + } + + auto data = reinterpret_cast(array_data.get()); + *data_ptr = std::make_unique(sizeof(T)); + for (uint32_t i = 0; i < sizeof(T); i++) { + (*data_ptr)[i] = *(data + i); + } + + return SUCCESS; +} + +template +MSRStatus ShardColumn::GetInt(std::unique_ptr *data_ptr, const json &json_column_value) { + std::unique_ptr array_data = std::make_unique(1); + int64_t temp_value; + bool less_than_zero = false; + + if (json_column_value.is_number_integer()) { + const json json_zero = 0; + if (json_column_value < json_zero) less_than_zero = true; + temp_value = json_column_value; + } else if (json_column_value.is_string()) { + std::string string_value = json_column_value; + + if (!string_value.empty() && string_value[0] == '-') { + try { + temp_value = std::stoll(string_value); + less_than_zero = true; + } catch (std::invalid_argument &e) { + MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; + return FAILED; + } catch (std::out_of_range &e) { + MS_LOG(ERROR) << "Conversion to int failed, out of range."; + return FAILED; + } + } else { + try { + temp_value = static_cast(std::stoull(string_value)); + } catch (std::invalid_argument &e) { + MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; + return FAILED; + } catch (std::out_of_range &e) { + MS_LOG(ERROR) << "Conversion to int failed, out of range."; + return FAILED; + } + } + } else { + MS_LOG(ERROR) << "Conversion to int failed."; + return FAILED; + } + + if ((less_than_zero && temp_value < static_cast(std::numeric_limits::min())) || + (!less_than_zero && static_cast(temp_value) > static_cast(std::numeric_limits::max()))) { + MS_LOG(ERROR) << "Conversion to int failed. Out of range"; + return FAILED; + } + array_data[0] = static_cast(temp_value); + + auto data = reinterpret_cast(array_data.get()); + *data_ptr = std::make_unique(sizeof(T)); + for (uint32_t i = 0; i < sizeof(T); i++) { + (*data_ptr)[i] = *(data + i); + } + + return SUCCESS; +} + +MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, + const unsigned char **data, std::unique_ptr *data_ptr, + uint64_t *const n_bytes) { + uint64_t offset_address = 0; + auto column_id = column_name_id_[column_name]; + if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { + return FAILED; + } + + auto column_data_type = column_data_type_[column_id]; + if (has_compress_blob_ && column_data_type == ColumnInt32) { + if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { + return FAILED; + } + } else if (has_compress_blob_ && column_data_type == ColumnInt64) { + if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { + return FAILED; + } + } else { + *data = reinterpret_cast(&(columns_blob[offset_address])); + } + + return SUCCESS; +} + +ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { + auto it_column = column_name_id_.find(column_name); + if (it_column == column_name_id_.end()) { + return ColumnNotFound; + } + auto it_blob = blob_column_id_.find(column_name); + return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; +} + +std::vector ShardColumn::CompressBlob(const std::vector &blob) { + // Skip if no compress columns + if (!CheckCompressBlob()) return blob; + + std::vector dst_blob; + uint64_t i_src = 0; + for (int64_t i = 0; i < num_blob_column_; i++) { + // Get column data type + auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]]; + auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type; + + // Compress and return is blob has 1 column only + if (num_blob_column_ == 1) { + return CompressInt(blob, int_type); + } + + // Just copy and continue if column dat type is not int32/int64 + uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type); + if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) { + dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes); + i_src += kInt64Len + num_bytes; + continue; + } + + // Get column slice in source blob + std::vector blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes); + // Compress column + auto dst_blob_slice = CompressInt(blob_slice, int_type); + // Get new column size + auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type); + // Append new colmn size + dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end()); + // Append new colmn data + dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end()); + i_src += kInt64Len + num_bytes; + } + MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; + return dst_blob; +} + +vector ShardColumn::CompressInt(const vector &src_bytes, const IntegerType &int_type) { + uint64_t i_size = kUnsignedOne << static_cast(int_type); + // Get number of elements + uint64_t src_n_int = src_bytes.size() / i_size; + // Calculate bitmap size (bytes) + uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte; + + // Initilize destination blob, more space than needed, will be resized + vector dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0); + + // Write number of elements to destination blob + vector size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type); + for (uint64_t n = 0; n < kBytesOfColumnLen; n++) { + dst_bytes[n] = size_by_bytes[n]; + } + + // Write compressed int + uint64_t i_dst = kBytesOfColumnLen + bitmap_size; + for (uint64_t i = 0; i < src_n_int; i++) { + // Initialize destination data type + IntegerType dst_int_type = kInt8Type; + // Shift to next int position + uint64_t pos = i * (kUnsignedOne << static_cast(int_type)); + // Narrow down this int + int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type); + + // Write this int to destination blob + uint64_t u_n = *reinterpret_cast(&i_n); + auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type); + for (uint64_t j = 0; j < (kUnsignedOne << static_cast(dst_int_type)); j++) { + dst_bytes[i_dst++] = temp_bytes[j]; + } + + // Update date type in bit map + dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |= + (static_cast(dst_int_type) << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte)))); + } + // Resize destination blob + dst_bytes.resize(i_dst); + MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << "."; + return dst_bytes; +} + +MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, + uint64_t *num_bytes, uint64_t *shift_idx) { + if (num_blob_column_ == 1) { + *num_bytes = columns_blob.size(); + *shift_idx = 0; + return SUCCESS; + } + auto blob_id = blob_column_id_[column_name_[column_id]]; + + for (int32_t i = 0; i < blob_id; i++) { + *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); + } + *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); + + (*shift_idx) += kInt64Len; + + return SUCCESS; +} + +template +MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, + const std::vector &columns_blob, uint64_t *num_bytes, + uint64_t shift_idx) { + auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); + *num_bytes = sizeof(T) * num_elements; + + // Parse integer array + uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte; + auto array_data = std::make_unique(num_elements); + + for (uint64_t i = 0; i < num_elements; i++) { + uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte]; + uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask; + auto mr_int_type = static_cast(i_type); + int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type); + i_source += (kUnsignedOne << i_type); + array_data[i] = static_cast(i64); + } + + auto data = reinterpret_cast(array_data.get()); + *data_ptr = std::make_unique(*num_bytes); + int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes); + if (ret_code != 0) { + MS_LOG(ERROR) << "Failed to copy data!"; + } + + return SUCCESS; +} + +uint64_t ShardColumn::BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, + const IntegerType &i_type) { + uint64_t result = 0; + for (uint64_t i = 0; i < (kUnsignedOne << static_cast(i_type)); i++) { + result = (result << kBitsOfByte) + bytes_array[pos + i]; + } + return result; +} + +std::vector ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) { + uint64_t n_bytes = kUnsignedOne << static_cast(i_type); + std::vector result(n_bytes, 0); + for (uint64_t i = 0; i < n_bytes; i++) { + result[n_bytes - 1 - i] = value & std::numeric_limits::max(); + value >>= kBitsOfByte; + } + return result; +} + +std::vector ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) { + uint64_t n_bytes = kUnsignedOne << static_cast(i_type); + std::vector result(n_bytes, 0); + for (uint64_t i = 0; i < n_bytes; i++) { + result[i] = value & std::numeric_limits::max(); + value >>= kBitsOfByte; + } + return result; +} + +int64_t ShardColumn::BytesLittleToMinIntType(const std::vector &bytes_array, const uint64_t &pos, + const IntegerType &src_i_type, IntegerType *dst_i_type) { + uint64_t u_temp = 0; + for (uint64_t i = 0; i < (kUnsignedOne << static_cast(src_i_type)); i++) { + u_temp = (u_temp << kBitsOfByte) + + bytes_array[pos + (kUnsignedOne << static_cast(src_i_type)) - kUnsignedOne - i]; + } + + int64_t i_out; + switch (src_i_type) { + case kInt8Type: { + i_out = (int8_t)(u_temp & std::numeric_limits::max()); + break; + } + case kInt16Type: { + i_out = (int16_t)(u_temp & std::numeric_limits::max()); + break; + } + case kInt32Type: { + i_out = (int32_t)(u_temp & std::numeric_limits::max()); + break; + } + case kInt64Type: { + i_out = (int64_t)(u_temp & std::numeric_limits::max()); + break; + } + default: { + i_out = 0; + } + } + + if (!dst_i_type) { + return i_out; + } + + if (i_out >= static_cast(std::numeric_limits::min()) && + i_out <= static_cast(std::numeric_limits::max())) { + *dst_i_type = kInt8Type; + } else if (i_out >= static_cast(std::numeric_limits::min()) && + i_out <= static_cast(std::numeric_limits::max())) { + *dst_i_type = kInt16Type; + } else if (i_out >= static_cast(std::numeric_limits::min()) && + i_out <= static_cast(std::numeric_limits::max())) { + *dst_i_type = kInt32Type; + } else { + *dst_i_type = kInt64Type; + } + return i_out; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc new file mode 100644 index 0000000000..4c7abbb4b4 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_distributed_sample.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, + uint32_t seed) + : ShardSample(1, num_shards, shard_id), + shuffle_(shuffle), + no_of_padded_samples_(no_of_padded_samples), + first_epoch_(true) { + shuffle_op_ = std::make_shared(seed, kShuffleSample); +} + +ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed) + : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {} + +int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (no_of_padded_samples_ <= 0) { + if (dataset_size % denominator_ == 0) { + return dataset_size / denominator_ * numerator_; + } else { + return dataset_size / denominator_ * numerator_ + 1; + } + } else { + auto padded_size = dataset_size + no_of_padded_samples_; + if (padded_size % denominator_ == 0) { + return padded_size / denominator_ * numerator_; + } else { + return -1; + } + } + return 0; +} + +MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) { + auto total_no = tasks.Size(); + if (no_of_padded_samples_ > 0 && first_epoch_) { + if (total_no % denominator_ != 0) { + MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. " + << "task size: " << total_no << ", number padded: " << no_of_padded_samples_ + << ", denominator: " << denominator_; + return FAILED; + } + } + if (first_epoch_) { + first_epoch_ = false; + task_ = tasks; + } else { + tasks = task_; + } + if (shuffle_ == true) { + if (SUCCESS != (*shuffle_op_)(tasks)) { + return FAILED; + } + } + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc new file mode 100644 index 0000000000..500037399b --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc @@ -0,0 +1,725 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_header.h" + +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_page.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +std::atomic thread_status(false); +ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared(); } + +MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool load_dataset) { + shard_count_ = headers.size(); + int shard_index = 0; + bool first = true; + for (const auto &header : headers) { + if (first) { + first = false; + if (ParseSchema(header["schema"]) != SUCCESS) { + return FAILED; + } + if (ParseIndexFields(header["index_fields"]) != SUCCESS) { + return FAILED; + } + if (ParseStatistics(header["statistics"]) != SUCCESS) { + return FAILED; + } + ParseShardAddress(header["shard_addresses"]); + header_size_ = header["header_size"].get(); + page_size_ = header["page_size"].get(); + } + ParsePage(header["page"], shard_index, load_dataset); + shard_index++; + } + return SUCCESS; +} + +MSRStatus ShardHeader::CheckFileStatus(const std::string &path) { + std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); + if (!fin) { + MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path; + return FAILED; + } + if (fin.fail()) { + MS_LOG(ERROR) << "Failed to open file. path: " << path; + return FAILED; + } + + // fetch file size + auto &io_seekg = fin.seekg(0, std::ios::end); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + fin.close(); + MS_LOG(ERROR) << "File seekg failed"; + return FAILED; + } + + size_t file_size = fin.tellg(); + if (file_size < kMinFileSize) { + fin.close(); + MS_LOG(ERROR) << "File size %d is smaller than the minimum value."; + return FAILED; + } + fin.close(); + return SUCCESS; +} + +std::pair ShardHeader::ValidateHeader(const std::string &path) { + if (CheckFileStatus(path) != SUCCESS) { + return {FAILED, {}}; + } + + // read header size + json json_header; + std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); + if (!fin.is_open()) { + MS_LOG(ERROR) << "File seekg failed"; + return {FAILED, json_header}; + } + + uint64_t header_size = 0; + auto &io_read = fin.read(reinterpret_cast(&header_size), kInt64Len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + fin.close(); + return {FAILED, json_header}; + } + + if (header_size > kMaxHeaderSize) { + fin.close(); + MS_LOG(ERROR) << "Header size is illegal."; + return {FAILED, json_header}; + } + + // read header content + std::vector header_content(header_size); + auto &io_read_content = fin.read(reinterpret_cast(&header_content[0]), header_size); + if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { + MS_LOG(ERROR) << "File read failed"; + fin.close(); + return {FAILED, json_header}; + } + + fin.close(); + std::string raw_header_content = std::string(header_content.begin(), header_content.end()); + // parse json content + try { + json_header = json::parse(raw_header_content); + } catch (json::parse_error &e) { + MS_LOG(ERROR) << "Json parse error: " << e.what(); + return {FAILED, json_header}; + } + return {SUCCESS, json_header}; +} + +std::pair ShardHeader::BuildSingleHeader(const std::string &file_path) { + auto ret = ValidateHeader(file_path); + if (SUCCESS != ret.first) { + return {FAILED, json()}; + } + json raw_header = ret.second; + json header = {{"shard_addresses", raw_header["shard_addresses"]}, + {"header_size", raw_header["header_size"]}, + {"page_size", raw_header["page_size"]}, + {"index_fields", raw_header["index_fields"]}, + {"blob_fields", raw_header["schema"][0]["blob_fields"]}, + {"schema", raw_header["schema"][0]["schema"]}, + {"version", raw_header["version"]}}; + return {SUCCESS, header}; +} + +MSRStatus ShardHeader::BuildDataset(const std::vector &file_paths, bool load_dataset) { + uint32_t thread_num = std::thread::hardware_concurrency(); + if (thread_num == 0) thread_num = kThreadNumber; + uint32_t work_thread_num = 0; + uint32_t shard_count = file_paths.size(); + int group_num = ceil(shard_count * 1.0 / thread_num); + std::vector thread_set(thread_num); + std::vector headers(shard_count); + for (uint32_t x = 0; x < thread_num; ++x) { + int start_num = x * group_num; + int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num; + if (start_num >= end_num) { + continue; + } + + thread_set[x] = + std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths); + work_thread_num++; + } + + for (uint32_t x = 0; x < work_thread_num; ++x) { + thread_set[x].join(); + } + if (thread_status) { + thread_status = false; + return FAILED; + } + if (SUCCESS != InitializeHeader(headers, load_dataset)) { + return FAILED; + } + return SUCCESS; +} + +void ShardHeader::GetHeadersOneTask(int start, int end, std::vector &headers, + const vector &realAddresses) { + if (thread_status || end > realAddresses.size()) { + return; + } + for (int x = start; x < end; ++x) { + auto ret = ValidateHeader(realAddresses[x]); + if (SUCCESS != ret.first) { + thread_status = true; + return; + } + json header; + header = ret.second; + header["shard_addresses"] = realAddresses; + if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { + MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() + << ", lib version is: " << kVersion; + thread_status = true; + return; + } + headers[x] = header; + } +} + +MSRStatus ShardHeader::InitByFiles(const std::vector &file_paths) { + std::vector file_names(file_paths.size()); + std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string { + if (GetFileName(fp).first == SUCCESS) { + return GetFileName(fp).second; + } + }); + + shard_addresses_ = std::move(file_names); + shard_count_ = file_paths.size(); + if (shard_count_ == 0) { + return FAILED; + } + if (shard_count_ <= kMaxShardCount) { + pages_.resize(shard_count_); + } else { + return FAILED; + } + return SUCCESS; +} + +void ShardHeader::ParseHeader(const json &header) {} + +MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { + std::vector> parsed_index_fields; + for (auto &index_field : index_fields) { + auto schema_id = index_field["schema_id"].get(); + std::string field_name = index_field["index_field"].get(); + std::pair parsed_index_field(schema_id, field_name); + parsed_index_fields.push_back(parsed_index_field); + } + if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) { + return FAILED; + } + return SUCCESS; +} + +void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { + // set shard_index when load_dataset is false + if (pages_.empty() && shard_count_ <= kMaxShardCount) { + pages_.resize(shard_count_); + } + for (auto &page : pages) { + int page_id = page["page_id"]; + int shard_id = page["shard_id"]; + std::string page_type = page["page_type"]; + int page_type_id = page["page_type_id"]; + auto start_row_id = page["start_row_id"].get(); + auto end_row_id = page["end_row_id"].get(); + + std::vector> row_group_ids(page["row_group_ids"].size()); + std::transform(page["row_group_ids"].begin(), page["row_group_ids"].end(), row_group_ids.begin(), + [](json rg) { return std::make_pair(rg["id"], rg["offset"].get()); }); + + auto page_size = page["page_size"].get(); + + std::shared_ptr parsed_page = std::make_shared(page_id, shard_id, page_type, page_type_id, start_row_id, + end_row_id, row_group_ids, page_size); + if (load_dataset == true) { + pages_[shard_id].push_back(std::move(parsed_page)); + } else { + pages_[shard_index].push_back(std::move(parsed_page)); + } + } +} + +MSRStatus ShardHeader::ParseStatistics(const json &statistics) { + for (auto &statistic : statistics) { + if (statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) { + MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump(); + return FAILED; + } + std::string statistic_description = statistic["desc"].get(); + json statistic_body = statistic["statistics"]; + std::shared_ptr parsed_statistic = Statistics::Build(statistic_description, statistic_body); + if (!parsed_statistic) { + return FAILED; + } + AddStatistic(parsed_statistic); + } + return SUCCESS; +} + +MSRStatus ShardHeader::ParseSchema(const json &schemas) { + for (auto &schema : schemas) { + // change how we get schemaBody once design is finalized + if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() || + schema.find("schema") == schema.end()) { + MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump(); + return FAILED; + } + std::string schema_description = schema["desc"].get(); + std::vector blob_fields = schema["blob_fields"].get>(); + json schema_body = schema["schema"]; + std::shared_ptr parsed_schema = Schema::Build(schema_description, schema_body); + if (!parsed_schema) { + return FAILED; + } + AddSchema(parsed_schema); + } + return SUCCESS; +} + +void ShardHeader::ParseShardAddress(const json &address) { + std::copy(address.begin(), address.end(), std::back_inserter(shard_addresses_)); +} + +std::vector ShardHeader::SerializeHeader() { + std::vector header; + auto index = SerializeIndexFields(); + auto stats = SerializeStatistics(); + auto schema = SerializeSchema(); + auto pages = SerializePage(); + auto address = SerializeShardAddress(); + if (shard_count_ > static_cast(pages.size())) { + return std::vector{}; + } + if (shard_count_ <= kMaxShardCount) { + for (int shardId = 0; shardId < shard_count_; shardId++) { + string s; + s += "{\"header_size\":" + std::to_string(header_size_) + ","; + s += "\"index_fields\":" + index + ","; + s += "\"page\":" + pages[shardId] + ","; + s += "\"page_size\":" + std::to_string(page_size_) + ","; + s += "\"schema\":" + schema + ","; + s += "\"shard_addresses\":" + address + ","; + s += "\"shard_id\":" + std::to_string(shardId) + ","; + s += "\"statistics\":" + stats + ","; + s += "\"version\":\"" + std::string(kVersion) + "\""; + s += "}"; + header.emplace_back(s); + } + } + return header; +} + +std::string ShardHeader::SerializeIndexFields() { + json j; + auto fields = index_->GetFields(); + for (const auto &field : fields) { + j.push_back({{"schema_id", field.first}, {"index_field", field.second}}); + } + return j.dump(); +} + +std::vector ShardHeader::SerializePage() { + std::vector pages; + for (auto &shard_pages : pages_) { + json j; + for (const auto &p : shard_pages) { + j.emplace_back(p->GetPage()); + } + pages.emplace_back(j.dump()); + } + return pages; +} + +std::string ShardHeader::SerializeStatistics() { + json j; + for (const auto &stats : statistics_) { + j.emplace_back(stats->GetStatistics()); + } + return j.dump(); +} + +std::string ShardHeader::SerializeSchema() { + json j; + for (const auto &schema : schema_) { + j.emplace_back(schema->GetSchema()); + } + return j.dump(); +} + +std::string ShardHeader::SerializeShardAddress() { + json j; + for (const auto &addr : shard_addresses_) { + j.emplace_back(GetFileName(addr).second); + } + return j.dump(); +} + +std::pair, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) { + if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { + return std::make_pair(pages_[shard_id][page_id], SUCCESS); + } else { + return std::make_pair(nullptr, FAILED); + } +} + +MSRStatus ShardHeader::SetPage(const std::shared_ptr &new_page) { + if (new_page == nullptr) { + return FAILED; + } + int shard_id = new_page->GetShardID(); + int page_id = new_page->GetPageID(); + if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { + pages_[shard_id][page_id] = new_page; + return SUCCESS; + } else { + return FAILED; + } +} + +MSRStatus ShardHeader::AddPage(const std::shared_ptr &new_page) { + if (new_page == nullptr) { + return FAILED; + } + int shard_id = new_page->GetShardID(); + int page_id = new_page->GetPageID(); + if (shard_id < static_cast(pages_.size()) && page_id == static_cast(pages_[shard_id].size())) { + pages_[shard_id].push_back(new_page); + return SUCCESS; + } else { + return FAILED; + } +} + +int64_t ShardHeader::GetLastPageId(const int &shard_id) { + if (shard_id >= static_cast(pages_.size())) { + return 0; + } + return pages_[shard_id].size() - 1; +} + +int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &page_type) { + if (shard_id >= static_cast(pages_.size())) { + return 0; + } + int last_page_id = -1; + for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { + if (pages_[shard_id][i - 1]->GetPageType() == page_type) { + last_page_id = pages_[shard_id][i - 1]->GetPageID(); + return last_page_id; + } + } + return last_page_id; +} + +const std::pair> ShardHeader::GetPageByGroupId(const int &group_id, + const int &shard_id) { + if (shard_id >= static_cast(pages_.size())) { + MS_LOG(ERROR) << "Shard id is more than sum of shards."; + return {FAILED, nullptr}; + } + for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { + auto page = pages_[shard_id][i - 1]; + if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { + return {SUCCESS, page}; + } + } + MS_LOG(ERROR) << "Could not get page by group id " << group_id; + return {FAILED, nullptr}; +} + +int ShardHeader::AddSchema(std::shared_ptr schema) { + if (schema == nullptr) { + MS_LOG(ERROR) << "Schema is illegal"; + return -1; + } + + if (!schema_.empty()) { + MS_LOG(ERROR) << "Only support one schema"; + return -1; + } + + int64_t schema_id = schema->GetSchemaID(); + if (schema_id == -1) { + schema_id = schema_.size(); + schema->SetSchemaID(schema_id); + } + schema_.push_back(schema); + return schema_id; +} + +void ShardHeader::AddStatistic(std::shared_ptr statistic) { + if (statistic) { + int64_t statistics_id = statistic->GetStatisticsID(); + if (statistics_id == -1) { + statistics_id = statistics_.size(); + statistic->SetStatisticsID(statistics_id); + } + statistics_.push_back(statistic); + } +} + +std::shared_ptr ShardHeader::InitIndexPtr() { + std::shared_ptr index = index_; + if (!index_) { + index = std::make_shared(); + index_ = index; + } + return index; +} + +MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) { + // check field name is or is not valid + if (schema.find(field) == schema.end()) { + MS_LOG(ERROR) << "Schema do not contain the field: " << field << "."; + return FAILED; + } + + if (schema[field]["type"] == "bytes") { + MS_LOG(ERROR) << field << " is bytes type, can not be schema index field."; + return FAILED; + } + + if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) { + MS_LOG(ERROR) << field << " array can not be schema index field."; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardHeader::AddIndexFields(const std::vector &fields) { + // create index Object + std::shared_ptr index = InitIndexPtr(); + + if (fields.size() == kInt0) { + MS_LOG(ERROR) << "There are no index fields"; + return FAILED; + } + + if (GetSchemas().empty()) { + MS_LOG(ERROR) << "No schema is set"; + return FAILED; + } + + for (const auto &schemaPtr : schema_) { + auto result = GetSchemaByID(schemaPtr->GetSchemaID()); + if (result.second != SUCCESS) { + MS_LOG(ERROR) << "Could not get schema by id."; + return FAILED; + } + + if (result.first == nullptr) { + MS_LOG(ERROR) << "Could not get schema by id."; + return FAILED; + } + + json schema = result.first->GetSchema().at("schema"); + + // checkout and add fields for each schema + std::set field_set; + for (const auto &item : index->GetFields()) { + field_set.insert(item.second); + } + for (const auto &field : fields) { + if (field_set.find(field) != field_set.end()) { + MS_LOG(ERROR) << "Add same index field twice"; + return FAILED; + } + + // check field name is or is not valid + if (CheckIndexField(field, schema) == FAILED) { + return FAILED; + } + field_set.insert(field); + + // add field into index + index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); + } + } + + index_ = index; + return SUCCESS; +} + +MSRStatus ShardHeader::GetAllSchemaID(std::set &bucket_count) { + // get all schema id + for (const auto &schema : schema_) { + auto bucket_it = bucket_count.find(schema->GetSchemaID()); + if (bucket_it != bucket_count.end()) { + MS_LOG(ERROR) << "Schema duplication"; + return FAILED; + } else { + bucket_count.insert(schema->GetSchemaID()); + } + } + return SUCCESS; +} + +MSRStatus ShardHeader::AddIndexFields(std::vector> fields) { + // create index Object + std::shared_ptr index = InitIndexPtr(); + + if (fields.size() == kInt0) { + MS_LOG(ERROR) << "There are no index fields"; + return FAILED; + } + + // get all schema id + std::set bucket_count; + if (GetAllSchemaID(bucket_count) != SUCCESS) { + return FAILED; + } + + // check and add fields for each schema + std::set> field_set; + for (const auto &item : index->GetFields()) { + field_set.insert(item); + } + for (const auto &field : fields) { + if (field_set.find(field) != field_set.end()) { + MS_LOG(ERROR) << "Add same index field twice"; + return FAILED; + } + + uint64_t schema_id = field.first; + std::string field_name = field.second; + + // check schemaId is or is not valid + if (bucket_count.find(schema_id) == bucket_count.end()) { + MS_LOG(ERROR) << "Illegal schema id: " << schema_id; + return FAILED; + } + + // check field name is or is not valid + auto result = GetSchemaByID(schema_id); + if (result.second != SUCCESS) { + MS_LOG(ERROR) << "Could not get schema by id."; + return FAILED; + } + json schema = result.first->GetSchema().at("schema"); + if (schema.find(field_name) == schema.end()) { + MS_LOG(ERROR) << "Schema " << schema_id << " do not contain the field: " << field_name; + return FAILED; + } + + if (CheckIndexField(field_name, schema) == FAILED) { + return FAILED; + } + + field_set.insert(field); + + // add field into index + index.get()->AddIndexField(schema_id, field_name); + } + index_ = index; + return SUCCESS; +} + +std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { + if (shard_id >= shard_addresses_.size()) { + return ""; + } + return shard_addresses_.at(shard_id); +} + +std::vector> ShardHeader::GetSchemas() { return schema_; } + +std::vector> ShardHeader::GetStatistics() { return statistics_; } + +std::vector> ShardHeader::GetFields() { return index_->GetFields(); } + +std::shared_ptr ShardHeader::GetIndex() { return index_; } + +std::pair, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { + int64_t schemaSize = schema_.size(); + if (schema_id < 0 || schema_id >= schemaSize) { + MS_LOG(ERROR) << "Illegal schema id"; + return std::make_pair(nullptr, FAILED); + } + return std::make_pair(schema_.at(schema_id), SUCCESS); +} + +std::pair, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) { + int64_t statistics_size = statistics_.size(); + if (statistic_id < 0 || statistic_id >= statistics_size) { + return std::make_pair(nullptr, FAILED); + } + return std::make_pair(statistics_.at(statistic_id), SUCCESS); +} + +MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { + // write header content to file, dump whatever is in the file before + std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); + if (page_out_handle.fail()) { + MS_LOG(ERROR) << "Failed in opening page file"; + return FAILED; + } + + auto pages = SerializePage(); + for (const auto &shard_pages : pages) { + page_out_handle << shard_pages << "\n"; + } + + page_out_handle.close(); + return SUCCESS; +} + +MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { + for (auto &v : pages_) { // clean pages + v.clear(); + } + // attempt to open the file contains the page in json + std::ifstream page_in_handle(dump_file_name.c_str()); + + if (!page_in_handle.good()) { + MS_LOG(INFO) << "No page file exists."; + return SUCCESS; + } + + std::string line; + while (std::getline(page_in_handle, line)) { + ParsePage(json::parse(line), -1, true); + } + + page_in_handle.close(); + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_index.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_index.cc new file mode 100644 index 0000000000..73397b5bba --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_index.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_index.h" + +namespace mindspore { +namespace mindrecord { +// table name for index +const char TABLENAME[] = "index_table"; + +Index::Index() : database_name_(""), table_name_(TABLENAME) {} + +void Index::AddIndexField(const int64_t &schemaId, const std::string &field) { + fields_.emplace_back(pair(schemaId, field)); +} + +// Get attribute list +std::vector> Index::GetFields() { return fields_; } +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_page.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_page.cc new file mode 100644 index 0000000000..ba2292415f --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_page.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_page.h" +#include "pybind11/pybind11.h" + +namespace mindspore { +namespace mindrecord { +json Page::GetPage() const { + json str_page; + str_page["page_id"] = page_id_; + str_page["shard_id"] = shard_id_; + str_page["page_type"] = page_type_; + str_page["page_type_id"] = page_type_id_; + str_page["start_row_id"] = start_row_id_; + str_page["end_row_id"] = end_row_id_; + if (row_group_ids_.size() == 0) { + json row_groups = json({}); + row_groups["id"] = 0; + row_groups["offset"] = 0; + str_page["row_group_ids"].push_back(row_groups); + } else { + for (const auto &rg : row_group_ids_) { + json row_groups = json({}); + row_groups["id"] = rg.first; + row_groups["offset"] = rg.second; + str_page["row_group_ids"].push_back(row_groups); + } + } + str_page["page_size"] = page_size_; + return str_page; +} + +void Page::DeleteLastGroupId() { + if (!row_group_ids_.empty()) { + page_size_ = row_group_ids_.back().second; + row_group_ids_.pop_back(); + } +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc new file mode 100644 index 0000000000..081a48352d --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_pk_sample.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) + : ShardCategory(category_field, num_elements, std::numeric_limits::max(), true), shuffle_(false) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, + uint32_t seed) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { + shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement +} + +MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) { + if (shuffle_ == true) { + if (SUCCESS != (*shuffle_op_)(tasks)) { + return FAILED; + } + } + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc new file mode 100644 index 0000000000..808ab55bfb --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc @@ -0,0 +1,141 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_sample.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +ShardSample::ShardSample(int n) + : numerator_(0), + denominator_(0), + partition_id_(0), + no_of_samples_(n), + indices_({}), + sampler_type_(kCustomTopNSampler) {} + +ShardSample::ShardSample(int num, int den) + : numerator_(num), + denominator_(den), + partition_id_(0), + no_of_samples_(0), + indices_({}), + sampler_type_(kCustomTopPercentSampler) {} + +ShardSample::ShardSample(int num, int den, int par) + : numerator_(num), + denominator_(den), + partition_id_(par), + no_of_samples_(0), + indices_({}), + sampler_type_(kCustomTopPercentSampler) {} + +ShardSample::ShardSample(const std::vector &indices, uint32_t seed) + : numerator_(0), + denominator_(0), + partition_id_(0), + no_of_samples_(0), + indices_(indices), + sampler_type_(kSubsetRandomSampler) { + shuffle_op_ = std::make_shared(seed); +} + +int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (sampler_type_ == kCustomTopNSampler) { + return no_of_samples_; + } + + if (sampler_type_ == kCustomTopPercentSampler) { + if (dataset_size % denominator_ == 0) { + return dataset_size / denominator_ * numerator_; + } else { + return dataset_size / denominator_ * numerator_ + 1; + } + } + if (sampler_type_ == kSubsetRandomSampler) { + return indices_.size(); + } + return 0; +} + +MSRStatus ShardSample::Execute(ShardTask &tasks) { + int no_of_categories = static_cast(tasks.categories); + int total_no = static_cast(tasks.Size()); // make sure task_size + + int taking = 0; + if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1 + no_of_samples_ = std::min(no_of_samples_, total_no); + taking = no_of_samples_ - no_of_samples_ % no_of_categories; + } else if (sampler_type_ == kSubsetRandomSampler) { + if (indices_.size() > total_no) { + MS_LOG(ERROR) << "parameter indices's size is greater than dataset size."; + return FAILED; + } + } else { // constructor TopPercent + if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { + if (numerator_ == 1 && denominator_ > 1) { // sharding + taking = (total_no + denominator_ - 1) / denominator_; + } else { // non sharding + taking = total_no * numerator_ / denominator_; + taking -= (taking % no_of_categories); + } + } else { + MS_LOG(ERROR) << "parameter numerator or denominator is illegal"; + return FAILED; + } + } + + if (tasks.permutation_.empty()) { + ShardTask new_tasks; + total_no = static_cast(tasks.Size()); + if (sampler_type_ == kSubsetRandomSampler) { + for (int i = 0; i < indices_.size(); ++i) { + int index = ((indices_[i] % total_no) + total_no) % total_no; + new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python + } + } else { + for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { + new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start + } + } + std::swap(tasks, new_tasks); + } else { + ShardTask new_tasks; + if (taking > static_cast(tasks.permutation_.size())) { + return FAILED; + } + total_no = static_cast(tasks.permutation_.size()); + for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { + new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); + } + std::swap(tasks, new_tasks); + } + return SUCCESS; +} + +MSRStatus ShardSample::SufExecute(ShardTask &tasks) { + if (sampler_type_ == kSubsetRandomSampler) { + if (SUCCESS != (*shuffle_op_)(tasks)) { + return FAILED; + } + } + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc new file mode 100644 index 0000000000..093be9792f --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc @@ -0,0 +1,164 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_schema.h" +#include "common/utils.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +std::shared_ptr Schema::Build(std::string desc, const json &schema) { + // validate check + if (!Validate(schema)) { + return nullptr; + } + + std::vector blob_fields = PopulateBlobFields(schema); + Schema object_schema; + object_schema.desc_ = std::move(desc); + object_schema.blob_fields_ = std::move(blob_fields); + object_schema.schema_ = schema; + object_schema.schema_id_ = -1; + return std::make_shared(object_schema); +} + +std::shared_ptr Schema::Build(std::string desc, pybind11::handle schema) { + // validate check + json schema_json = nlohmann::detail::ToJsonImpl(schema); + return Build(std::move(desc), schema_json); +} + +std::string Schema::GetDesc() const { return desc_; } + +json Schema::GetSchema() const { + json str_schema; + str_schema["desc"] = desc_; + str_schema["schema"] = schema_; + str_schema["blob_fields"] = blob_fields_; + return str_schema; +} + +pybind11::object Schema::GetSchemaForPython() const { + json schema_json = GetSchema(); + pybind11::object schema_py = nlohmann::detail::FromJsonImpl(schema_json); + return schema_py; +} + +void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } + +int64_t Schema::GetSchemaID() const { return schema_id_; } + +std::vector Schema::GetBlobFields() const { return blob_fields_; } + +std::vector Schema::PopulateBlobFields(json schema) { + std::vector blob_fields; + for (json::iterator it = schema.begin(); it != schema.end(); ++it) { + json it_value = it.value(); + if ((it_value.size() == kInt2 && it_value.find("shape") != it_value.end()) || it_value["type"] == "bytes") { + blob_fields.emplace_back(it.key()); + } + } + return blob_fields; +} + +bool Schema::ValidateNumberShape(const json &it_value) { + if (it_value.find("shape") == it_value.end()) { + MS_LOG(ERROR) << "%s supports shape only." << it_value["type"].dump(); + return false; + } + + auto shape = it_value["shape"]; + if (!shape.is_array()) { + MS_LOG(ERROR) << "%s shape format is wrong." << it_value["type"].dump(); + return false; + } + + int num_negtive_one = 0; + for (const auto &i : shape) { + if (i == 0 || i < -1) { + MS_LOG(ERROR) << "Shape %s, number is wrong." << it_value["shape"].dump(); + return false; + } + if (i == -1) { + num_negtive_one++; + } + } + + if (num_negtive_one > 1) { + MS_LOG(ERROR) << "Shape %s, have at most 1 variable-length dimension." << it_value["shape"].dump(); + return false; + } + + return true; +} + +bool Schema::Validate(json schema) { + if (schema.size() == kInt0) { + MS_LOG(ERROR) << "Schema is null"; + return false; + } + + for (json::iterator it = schema.begin(); it != schema.end(); ++it) { + // make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_' + if (!ValidateFieldName(it.key())) { + MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', fieldName: " << it.key(); + return false; + } + + json it_value = it.value(); + if (it_value.find("type") == it_value.end()) { + MS_LOG(ERROR) << "No 'type' field exist: " << it_value.dump(); + return false; + } + + if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) { + MS_LOG(ERROR) << "Wrong type: " << it_value["type"].dump(); + return false; + } + + if (it_value.size() == kInt1) { + continue; + } + + if (it_value["type"] == "bytes" || it_value["type"] == "string") { + MS_LOG(ERROR) << it_value["type"].dump() << " can not 1 field only."; + return false; + } + + if (it_value.size() != kInt2) { + MS_LOG(ERROR) << it_value["type"].dump() << " can have at most 2 fields."; + return false; + } + + if (!ValidateNumberShape(it_value)) { + return false; + } + } + + return true; +} + +bool Schema::operator==(const mindrecord::Schema &b) const { + if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) { + return false; + } + return true; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc new file mode 100644 index 0000000000..3aa695e03b --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_sequential_sample.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +ShardSequentialSample::ShardSequentialSample(int n, int offset) + : ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {} + +ShardSequentialSample::ShardSequentialSample(float per, float per_offset) + : ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {} + +int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { + return dataset_size; + } + if (per_ > kEpsilon && per_ <= 1.0f) { + return dataset_size * kEpsilon; + } + return no_of_samples_; +} + +MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { + int total_no = static_cast(tasks.Size()); + int taking; + if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { + taking = total_no; + } else if (per_ > kEpsilon && per_ <= 1.0f) { + taking = total_no * kEpsilon; + } else { + taking = no_of_samples_; + } + + if (tasks.permutation_.empty()) { + ShardTask new_tasks; + total_no = static_cast(tasks.Size()); + for (int i = offset_; i < taking + offset_; ++i) { + new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); + } + std::swap(tasks, new_tasks); + } else { // shuffled + ShardTask new_tasks; + if (taking > static_cast(tasks.permutation_.size())) { + return FAILED; + } + total_no = static_cast(tasks.permutation_.size()); + for (size_t i = offset_; i < taking + offset_; ++i) { + new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); + } + std::swap(tasks, new_tasks); + } + return SUCCESS; +} + +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc new file mode 100644 index 0000000000..7743cabea3 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_shuffle.h" + +#include + +namespace mindspore { +namespace mindrecord { +ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) + : shuffle_seed_(seed), + no_of_samples_(0), + replacement_(false), + reshuffle_each_epoch_(true), + shuffle_type_(shuffle_type) {} + +ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch, + ShuffleType shuffle_type) + : shuffle_seed_(seed), + no_of_samples_(no_of_samples), + replacement_(replacement), + reshuffle_each_epoch_(reshuffle_each_epoch), + shuffle_type_(shuffle_type) {} + +int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (replacement_) { + return no_of_samples_ == 0 ? dataset_size : no_of_samples_; + } + return dataset_size; +} + +MSRStatus ShardShuffle::Execute(ShardTask &tasks) { + if (reshuffle_each_epoch_) shuffle_seed_++; + if (tasks.categories < 1) { + return FAILED; + } + if (shuffle_type_ == kShuffleSample) { // shuffle each sample + if (tasks.permutation_.empty() == true) { + tasks.MakePerm(); + } + if (replacement_ == true) { + ShardTask new_tasks; + if (no_of_samples_ == 0) { + no_of_samples_ = static_cast(tasks.Size()); + } + if (no_of_samples_ <= 0) { + MS_LOG(ERROR) << "no_of_samples need to be positive."; + return FAILED; + } + new_tasks.task_list_.reserve(no_of_samples_); + for (uint32_t i = 0; i < no_of_samples_; ++i) { + new_tasks.InsertTask(tasks.GetRandomTask()); + } + std::swap(tasks, new_tasks); + } else { + std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); + } + } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) + uint32_t individual_size = tasks.Size() / tasks.categories; + std::vector> new_permutations(tasks.categories, std::vector(individual_size)); + for (uint32_t i = 0; i < tasks.categories; i++) { + for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); + std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); + } + tasks.permutation_.clear(); + for (uint32_t j = 0; j < individual_size; j++) { + for (uint32_t i = 0; i < tasks.categories; i++) { + tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); + } + } + } + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc new file mode 100644 index 0000000000..7024a2ab06 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_statistics.h" +#include "pybind11/pybind11.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +std::shared_ptr Statistics::Build(std::string desc, const json &statistics) { + // validate check + if (!Validate(statistics)) { + return nullptr; + } + Statistics object_statistics; + object_statistics.desc_ = std::move(desc); + object_statistics.statistics_ = statistics; + object_statistics.statistics_id_ = -1; + return std::make_shared(object_statistics); +} + +std::shared_ptr Statistics::Build(std::string desc, pybind11::handle statistics) { + // validate check + json statistics_json = nlohmann::detail::ToJsonImpl(statistics); + if (!Validate(statistics_json)) { + return nullptr; + } + Statistics object_statistics; + object_statistics.desc_ = std::move(desc); + object_statistics.statistics_ = statistics_json; + object_statistics.statistics_id_ = -1; + return std::make_shared(object_statistics); +} + +std::string Statistics::GetDesc() const { return desc_; } + +json Statistics::GetStatistics() const { + json str_statistics; + str_statistics["desc"] = desc_; + str_statistics["statistics"] = statistics_; + return str_statistics; +} + +pybind11::object Statistics::GetStatisticsForPython() const { + json str_statistics = Statistics::GetStatistics(); + return nlohmann::detail::FromJsonImpl(str_statistics); +} + +void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } + +int64_t Statistics::GetStatisticsID() const { return statistics_id_; } + +bool Statistics::Validate(const json &statistics) { + if (statistics.size() != kInt1) { + MS_LOG(ERROR) << "Statistics object is null"; + return false; + } + if (statistics.find("level") == statistics.end()) { + MS_LOG(ERROR) << "There is not 'level' object in statistic"; + return false; + } + return LevelRecursive(statistics["level"]); +} + +bool Statistics::LevelRecursive(json level) { + bool ini = true; + for (json::iterator it = level.begin(); it != level.end(); ++it) { + json a = it.value(); + if (a.size() == kInt2) { + if ((a.find("key") == a.end()) || (a.find("count") == a.end())) { + MS_LOG(ERROR) << "The node field is 2, but 'key'/'count' is not existed"; + return false; + } + } else if (a.size() == kInt3) { + if ((a.find("key") == a.end()) || (a.find("count") == a.end()) || a.find("level") == a.end()) { + MS_LOG(ERROR) << "The node field is 3, but 'key'/'count'/'level' is not existed"; + return false; + } else { + ini = LevelRecursive(a.at("level")); + } + } else { + MS_LOG(ERROR) << "The node field is not equal 2/3"; + return false; + } + } + return ini; +} + +bool Statistics::operator==(const Statistics &b) const { + if (this->GetStatistics() != b.GetStatistics()) { + return false; + } + return true; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc new file mode 100644 index 0000000000..6f8e440f91 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc @@ -0,0 +1,121 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_task.h" +#include "common/utils.h" +#include "minddata/mindrecord/include/common/shard_utils.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::DEBUG; + +namespace mindspore { +namespace mindrecord { +ShardTask::ShardTask() : categories(1) {} + +ShardTask::ShardTask(const ShardTask &other) + : categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {} + +ShardTask &ShardTask::operator=(const ShardTask &other) { + ShardTask tmp(other); + std::swap(categories, tmp.categories); + permutation_.swap(tmp.permutation_); + task_list_.swap(tmp.task_list_); + return *this; +} + +void ShardTask::MakePerm() { + permutation_ = std::vector(task_list_.size()); + for (uint32_t i = 0; i < task_list_.size(); i++) { + permutation_[i] = static_cast(i); + } +} + +void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, + const json &label) { + MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id + << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; + task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); +} + +void ShardTask::InsertTask(std::tuple, std::vector, json> task) { + MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task)) + << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() + << ", size of task_list_: " << task_list_.size() << "."; + + task_list_.push_back(std::move(task)); +} + +void ShardTask::PopBack() { task_list_.pop_back(); } + +uint32_t ShardTask::Size() const { return static_cast(task_list_.size()); } + +uint32_t ShardTask::SizeOfRows() const { + if (task_list_.size() == 0) return static_cast(0); + + // 1 task is 1 page + auto sum_num_rows = [](int x, std::tuple, std::vector, json> y) { + return x + std::get<2>(y)[0]; + }; + uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows); + return nRows; +} + +std::tuple, std::vector, json> &ShardTask::GetTaskByID(size_t id) { + MS_ASSERT(id < task_list_.size()); + return task_list_[id]; +} + +std::tuple, std::vector, json> &ShardTask::GetRandomTask() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, task_list_.size() - 1); + return task_list_[dis(gen)]; +} + +ShardTask ShardTask::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements) { + ShardTask res; + if (category_tasks.empty()) return res; + auto total_categories = category_tasks.size(); + res.categories = static_cast(total_categories); + if (replacement == false) { + auto minTasks = category_tasks[0].Size(); + for (uint32_t i = 1; i < total_categories; i++) { + minTasks = std::min(minTasks, category_tasks[i].Size()); + } + for (uint32_t task_no = 0; task_no < minTasks; task_no++) { + for (uint32_t i = 0; i < total_categories; i++) { + res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast(task_no)))); + } + } + } else { + auto maxTasks = category_tasks[0].Size(); + for (uint32_t i = 1; i < total_categories; i++) { + maxTasks = std::max(maxTasks, category_tasks[i].Size()); + } + if (num_elements != std::numeric_limits::max()) { + maxTasks = static_cast(num_elements); + } + for (uint32_t i = 0; i < total_categories; i++) { + for (uint32_t j = 0; j < maxTasks; j++) { + res.InsertTask(category_tasks[i].GetRandomTask()); + } + } + } + return res; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/common/shard_error.cc b/mindspore/ccsrc/mindrecord/common/shard_error.cc deleted file mode 100644 index ad68aaf92c..0000000000 --- a/mindspore/ccsrc/mindrecord/common/shard_error.cc +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_error.h" - -namespace mindspore { -namespace mindrecord { -std::string ErrnoToMessage(MSRStatus status) { - switch (status) { - case FAILED: - return "operator failed"; - break; - case SUCCESS: - return "operator success"; - break; - case OPEN_FILE_FAILED: - return "open file failed"; - break; - case CLOSE_FILE_FAILED: - return "close file failed"; - break; - case WRITE_METADATA_FAILED: - return "write metadata failed"; - break; - case WRITE_RAWDATA_FAILED: - return "write rawdata failed"; - break; - case GET_SCHEMA_FAILED: - return "get schema failed"; - break; - case ILLEGAL_RAWDATA: - return "illegal raw data"; - break; - case PYTHON_TO_JSON_FAILED: - return "pybind: python object to json failed"; - break; - case DIR_CREATE_FAILED: - return "directory create failed"; - break; - case OPEN_DIR_FAILED: - return "open directory failed"; - break; - case INVALID_STATISTICS: - return "invalid statistics object"; - break; - case OPEN_DATABASE_FAILED: - return "open database failed"; - break; - case CLOSE_DATABASE_FAILED: - return "close database failed"; - break; - case DATABASE_OPERATE_FAILED: - return "database operate failed"; - break; - case BUILD_SCHEMA_FAILED: - return "build schema failed"; - break; - case DIVISOR_IS_ILLEGAL: - return "divisor is illegal"; - break; - case INVALID_FILE_PATH: - return "file path is invalid"; - break; - case SECURE_FUNC_FAILED: - return "secure function failed"; - break; - case ALLOCATE_MEM_FAILED: - return "allocate memory failed"; - break; - case ILLEGAL_FIELD_NAME: - return "illegal field name"; - break; - case ILLEGAL_FIELD_TYPE: - return "illegal field type"; - break; - case SET_METADATA_FAILED: - return "set metadata failed"; - break; - case ILLEGAL_SCHEMA_DEFINITION: - return "illegal schema definition"; - break; - case ILLEGAL_COLUMN_LIST: - return "illegal column list"; - break; - case SQL_ERROR: - return "sql error"; - break; - case ILLEGAL_SHARD_COUNT: - return "illegal shard count"; - break; - case ILLEGAL_SCHEMA_COUNT: - return "illegal schema count"; - break; - case VERSION_ERROR: - return "data version is not matched"; - break; - case ADD_SCHEMA_FAILED: - return "add schema failed"; - break; - case ILLEGAL_Header_SIZE: - return "illegal header size"; - break; - case ILLEGAL_Page_SIZE: - return "illegal page size"; - break; - case ILLEGAL_SIZE_VALUE: - return "illegal size value"; - break; - case INDEX_FIELD_ERROR: - return "add index fields failed"; - break; - case GET_CANDIDATE_CATEGORYFIELDS_FAILED: - return "get candidate category fields failed"; - break; - case GET_CATEGORY_INFO_FAILED: - return "get category information failed"; - break; - case ILLEGAL_CATEGORY_ID: - return "illegal category id"; - break; - case ILLEGAL_ROWNUMBER_OF_PAGE: - return "illegal row number of page"; - break; - case ILLEGAL_SCHEMA_ID: - return "illegal schema id"; - break; - case DESERIALIZE_SCHEMA_FAILED: - return "deserialize schema failed"; - break; - case DESERIALIZE_STATISTICS_FAILED: - return "deserialize statistics failed"; - break; - case ILLEGAL_DB_FILE: - return "illegal db file"; - break; - case OVERWRITE_DB_FILE: - return "overwrite db file"; - break; - case OVERWRITE_MINDRECORD_FILE: - return "overwrite mindrecord file"; - break; - case ILLEGAL_MINDRECORD_FILE: - return "illegal mindrecord file"; - break; - case PARSE_JSON_FAILED: - return "parse json failed"; - break; - case ILLEGAL_PARAMETERS: - return "illegal parameters"; - break; - case GET_PAGE_BY_GROUP_ID_FAILED: - return "get page by group id failed"; - break; - case GET_SYSTEM_STATE_FAILED: - return "get system state failed"; - break; - case IO_FAILED: - return "io operate failed"; - break; - case MATCH_HEADER_FAILED: - return "match header failed"; - break; - default: - return "invalid error no"; - } -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc deleted file mode 100644 index ee923ebc97..0000000000 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ /dev/null @@ -1,230 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "common/utils.h" -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index_generator.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_segment.h" -#include "mindrecord/include/shard_writer.h" -#include "nlohmann/json.hpp" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "utils/log_adapter.h" - -namespace py = pybind11; - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -void BindSchema(py::module *m) { - (void)py::class_>(*m, "Schema", py::module_local()) - .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Schema::Build) - .def("get_desc", &Schema::GetDesc) - .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) - .def("get_blob_fields", &Schema::GetBlobFields) - .def("get_schema_id", &Schema::GetSchemaID); -} - -void BindStatistics(const py::module *m) { - (void)py::class_>(*m, "Statistics", py::module_local()) - .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Statistics::Build) - .def("get_desc", &Statistics::GetDesc) - .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) - .def("get_statistics_id", &Statistics::GetStatisticsID); -} - -void BindShardHeader(const py::module *m) { - (void)py::class_>(*m, "ShardHeader", py::module_local()) - .def(py::init<>()) - .def("add_schema", &ShardHeader::AddSchema) - .def("add_statistics", &ShardHeader::AddStatistic) - .def("add_index_fields", - (MSRStatus(ShardHeader::*)(const std::vector &)) & ShardHeader::AddIndexFields) - .def("get_meta", &ShardHeader::GetSchemas) - .def("get_statistics", &ShardHeader::GetStatistics) - .def("get_fields", &ShardHeader::GetFields) - .def("get_schema_by_id", &ShardHeader::GetSchemaByID) - .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); -} - -void BindShardWriter(py::module *m) { - (void)py::class_(*m, "ShardWriter", py::module_local()) - .def(py::init<>()) - .def("open", &ShardWriter::Open) - .def("open_for_append", &ShardWriter::OpenForAppend) - .def("set_header_size", &ShardWriter::SetHeaderSize) - .def("set_page_size", &ShardWriter::SetPageSize) - .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, - vector> &, bool, bool)) & - ShardWriter::WriteRawData) - .def("commit", &ShardWriter::Commit); -} - -void BindShardReader(const py::module *m) { - (void)py::class_>(*m, "ShardReader", py::module_local()) - .def(py::init<>()) - .def("open", (MSRStatus(ShardReader::*)(const std::vector &, bool, const int &, - const std::vector &, - const std::vector> &)) & - ShardReader::OpenPy) - .def("launch", &ShardReader::Launch) - .def("get_header", &ShardReader::GetShardHeader) - .def("get_blob_fields", &ShardReader::GetBlobFields) - .def("get_next", (std::vector>, pybind11::object>>(ShardReader::*)()) & - ShardReader::GetNextPy) - .def("finish", &ShardReader::Finish) - .def("close", &ShardReader::Close); -} - -void BindShardIndexGenerator(const py::module *m) { - (void)py::class_(*m, "ShardIndexGenerator", py::module_local()) - .def(py::init()) - .def("build", &ShardIndexGenerator::Build) - .def("write_to_db", &ShardIndexGenerator::WriteToDatabase); -} - -void BindShardSegment(py::module *m) { - (void)py::class_(*m, "ShardSegment", py::module_local()) - .def(py::init<>()) - .def("open", (MSRStatus(ShardSegment::*)(const std::vector &, bool, const int &, - const std::vector &, - const std::vector> &)) & - ShardSegment::OpenPy) - .def("get_category_fields", - (std::pair>(ShardSegment::*)()) & ShardSegment::GetCategoryFields) - .def("set_category_field", (MSRStatus(ShardSegment::*)(std::string)) & ShardSegment::SetCategoryField) - .def("read_category_info", (std::pair(ShardSegment::*)()) & ShardSegment::ReadCategoryInfo) - .def("read_at_page_by_id", (std::pair, pybind11::object>>>( - ShardSegment::*)(int64_t, int64_t, int64_t)) & - ShardSegment::ReadAtPageByIdPy) - .def("read_at_page_by_name", (std::pair, pybind11::object>>>( - ShardSegment::*)(std::string, int64_t, int64_t)) & - ShardSegment::ReadAtPageByNamePy) - .def("get_header", &ShardSegment::GetShardHeader) - .def("get_blob_fields", - (std::pair>(ShardSegment::*)()) & ShardSegment::GetBlobFields); -} - -void BindGlobalParams(py::module *m) { - (*m).attr("MIN_HEADER_SIZE") = kMinHeaderSize; - (*m).attr("MAX_HEADER_SIZE") = kMaxHeaderSize; - (*m).attr("MIN_PAGE_SIZE") = kMinPageSize; - (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize; - (*m).attr("MIN_SHARD_COUNT") = kMinShardCount; - (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount; - (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount; - (void)(*m).def("get_max_thread_num", &GetMaxThreadNum); -} - -PYBIND11_MODULE(_c_mindrecord, m) { - m.doc() = "pybind11 mindrecord plugin"; // optional module docstring - (void)py::enum_(m, "MSRStatus", py::module_local()) - .value("SUCCESS", SUCCESS) - .value("FAILED", FAILED) - .export_values(); - (void)py::enum_(m, "ShardType", py::module_local()).value("NLP", kNLP).value("CV", kCV).export_values(); - BindGlobalParams(&m); - BindSchema(&m); - BindStatistics(&m); - BindShardHeader(&m); - BindShardWriter(&m); - BindShardReader(&m); - BindShardIndexGenerator(&m); - BindShardSegment(&m); -} -} // namespace mindrecord -} // namespace mindspore - -namespace nlohmann { -namespace detail { -py::object FromJsonImpl(const json &j) { - if (j.is_null()) { - return py::none(); - } else if (j.is_boolean()) { - return py::bool_(j.get()); - } else if (j.is_number()) { - double number = j.get(); - if (fabs(number - std::floor(number)) < mindspore::mindrecord::kEpsilon) { - return py::int_(j.get()); - } else { - return py::float_(number); - } - } else if (j.is_string()) { - return py::str(j.get()); - } else if (j.is_array()) { - py::list obj; - for (const auto &el : j) { - (void)obj.attr("append")(FromJsonImpl(el)); - } - return std::move(obj); - } else { - py::dict obj; - for (json::const_iterator it = j.cbegin(); it != j.cend(); ++it) { - obj[py::str(it.key())] = FromJsonImpl(it.value()); - } - return std::move(obj); - } -} - -json ToJsonImpl(const py::handle &obj) { - if (obj.is_none()) { - return nullptr; - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj) || py::isinstance(obj)) { - auto out = json::array(); - for (const py::handle &value : obj) { - out.push_back(ToJsonImpl(value)); - } - return out; - } - if (py::isinstance(obj)) { - auto out = json::object(); - for (const py::handle &key : obj) { - out[py::str(key).cast()] = ToJsonImpl(obj[key]); - } - return out; - } - MS_LOG(ERROR) << "Python to json failed, obj is: " << py::cast(obj); - return json(); -} -} // namespace detail - -py::object adl_serializer::FromJson(const json &j) { return detail::FromJsonImpl(j); } - -void adl_serializer::ToJson(json *j, const py::object &obj) { - *j = detail::ToJsonImpl(obj); -} // namespace detail -} // namespace nlohmann diff --git a/mindspore/ccsrc/mindrecord/common/shard_utils.cc b/mindspore/ccsrc/mindrecord/common/shard_utils.cc deleted file mode 100644 index edeabb3cde..0000000000 --- a/mindspore/ccsrc/mindrecord/common/shard_utils.cc +++ /dev/null @@ -1,204 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/common/shard_utils.h" -#include "common/utils.h" -#include "./securec.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::DEBUG; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -// split a string using a character -std::vector StringSplit(const std::string &field, char separator) { - std::vector res; - uint64_t s_pos = 0; - while (s_pos < field.length()) { - size_t e_pos = field.find_first_of(separator, s_pos); - if (e_pos != std::string::npos) { - res.push_back(field.substr(s_pos, e_pos - s_pos)); - } else { - res.push_back(field.substr(s_pos, field.length() - s_pos)); - break; - } - s_pos = e_pos + 1; - } - return res; -} - -bool ValidateFieldName(const std::string &str) { - std::string::const_iterator it = str.begin(); - if (it == str.end()) { - return false; - } - for (; it != str.end(); ++it) { - if (*it == '_' || ((*it >= '0') && (*it <= '9')) || ((*it >= 'A') && (*it <= 'Z')) || - ((*it >= 'a') && (*it <= 'z'))) { - continue; - } - return false; - } - return true; -} - -std::pair GetFileName(const std::string &path) { - char real_path[PATH_MAX] = {0}; - char buf[PATH_MAX] = {0}; - if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; - return {FAILED, ""}; - } - char tmp[PATH_MAX] = {0}; -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; - } - if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { - MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; - } -#else - if (realpath(dirname(&(buf[0])), tmp) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; - } - if (realpath(common::SafeCStr(path), real_path) == nullptr) { - MS_LOG(DEBUG) << "Path: " << path << "check successfully"; - } -#endif - std::string s = real_path; - char sep = '/'; - size_t i = s.rfind(sep, s.length()); - if (i != std::string::npos) { - if (i + 1 < s.size()) { - return {SUCCESS, s.substr(i + 1)}; - } - } - return {SUCCESS, s}; -} - -std::pair GetParentDir(const std::string &path) { - char real_path[PATH_MAX] = {0}; - char buf[PATH_MAX] = {0}; - if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; - return {FAILED, ""}; - } - char tmp[PATH_MAX] = {0}; -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; - } - if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { - MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; - } -#else - if (realpath(dirname(&(buf[0])), tmp) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; - } - if (realpath(common::SafeCStr(path), real_path) == nullptr) { - MS_LOG(DEBUG) << "Path: " << path << "check successfully"; - } -#endif - std::string s = real_path; - if (s.rfind('/') + 1 <= s.size()) { - return {SUCCESS, s.substr(0, s.rfind('/') + 1)}; - } - return {SUCCESS, "/"}; -} - -bool CheckIsValidUtf8(const std::string &str) { - int n = 0; - int ix = str.length(); - for (int i = 0; i < ix; ++i) { - uint8_t c = static_cast(str[i]); - if (c <= 0x7f) { - n = 0; - } else if ((c & 0xE0) == 0xC0) { - n = 1; - } else if (c == 0xed && i < (ix - 1) && (static_cast(str[i + 1]) & 0xa0) == 0xa0) { - return false; - } else if ((c & 0xF0) == 0xE0) { - n = 2; - } else if ((c & 0xF8) == 0xF0) { - n = 3; - } else { - return false; - } - for (int j = 0; j < n && i < ix; ++j) { - if ((++i == ix) || ((static_cast(str[i]) & 0xC0) != 0x80)) { - return false; - } - } - } - return true; -} - -bool IsLegalFile(const std::string &path) { - struct stat s; - if (stat(common::SafeCStr(path), &s) == 0) { - if (s.st_mode & S_IFDIR) { - return false; - } - return true; - } - return false; -} - -std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) { -#if defined(_WIN32) || defined(_WIN64) - return {SUCCESS, 100}; -#else - uint64_t ll_count = 0; - struct statfs disk_info; - if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { - MS_LOG(ERROR) << "Get disk size error"; - return {FAILED, 0}; - } - - switch (disk_type) { - case kTotalSize: - ll_count = disk_info.f_bsize * disk_info.f_blocks; - ll_count = ll_count >> 20; - break; - case kFreeSize: - ll_count = disk_info.f_bsize * disk_info.f_bavail; - ll_count = ll_count >> 20; - break; - default: - ll_count = 0; - break; - } - - return {SUCCESS, ll_count}; -#endif -} - -uint32_t GetMaxThreadNum() { - // define the number of thread - uint32_t thread_num = std::thread::hardware_concurrency(); - if (thread_num == 0) { - thread_num = kMaxConsumerCount; - } - return thread_num; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_pybind.h b/mindspore/ccsrc/mindrecord/include/common/shard_pybind.h deleted file mode 100644 index 86c71a0ea7..0000000000 --- a/mindspore/ccsrc/mindrecord/include/common/shard_pybind.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ -#define MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ - -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "pybind11/pybind11.h" - -namespace py = pybind11; -namespace nlohmann { -template <> -struct adl_serializer { - py::object FromJson(const json &j); - - void ToJson(json *j, const py::object &obj); -}; - -namespace detail { -py::object FromJsonImpl(const json &j); - -json ToJsonImpl(const py::handle &obj); -} // namespace detail -} // namespace nlohmann -#endif // MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h deleted file mode 100644 index 8aa5bdfbda..0000000000 --- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h +++ /dev/null @@ -1,182 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ -#define MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ - -#include -#include -#include -#include -#if !defined(_WIN32) && !defined(_WIN64) -#include -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/shard_error.h" -#include "nlohmann/json.hpp" -#include "./sqlite3.h" -#include "utils/log_adapter.h" - -/* To be used when dlog is ok #include "./slog.h" */ -#ifdef DEBUG -#define MS_ASSERT(f) assert(f) -#else -#define MS_ASSERT(f) ((void)0) -#endif - -namespace mindspore { -namespace mindrecord { -using json = nlohmann::json; - -const int kInt0 = 0; -const int kInt1 = 1; -const int kInt2 = 2; -const int kInt3 = 3; -const int kUnsignedInt4 = 4; - -enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel }; - -const char kVersion[] = "3.0"; -const std::vector kSupportedVersion = {"2.0", kVersion}; - -enum ShardType { - kNLP = 0, - kCV = 1, -}; - -enum TaskType { - kCommonTask = 0, - kPaddedTask = 1, -}; -enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; - -enum ShuffleType { kShuffleCategory, kShuffleSample }; - -const double kEpsilon = 1e-7; - -const int kThreadNumber = 14; - -// Shard default parameters -const uint64_t kDefaultHeaderSize = 1 << 24; // 16MB -const uint64_t kDefaultPageSize = 1 << 25; // 32MB - -// HeaderSize [16KB, 128MB] -const int kMinHeaderSize = 1 << 14; // 16KB -const int kMaxHeaderSize = 1 << 27; // 128MB - -// PageSize [32KB, 256MB] -const int kMinPageSize = 1 << 15; // 32KB -const int kMaxPageSize = 1 << 28; // 256MB - -// used by value length / schema id length / statistic id length ... -const uint64_t kInt64Len = 8; - -// Minimum file size -const uint64_t kMinFileSize = kInt64Len; - -const int kMinShardCount = 1; -const int kMaxShardCount = 1000; - -const int kMinConsumerCount = 1; -const int kMaxConsumerCount = 128; - -const int kMaxSchemaCount = 1; -const int kMaxThreadCount = 32; -const int kMaxFieldCount = 100; - -// Minimum free disk size -const int kMinFreeDiskSize = 10; // 10M - -// dummy json -const json kDummyId = R"({"id": 0})"_json; - -// translate type in schema to type in sqlite3(NULL, INTEGER, REAL, TEXT, BLOB) -const std::unordered_map kDbJsonMap = { - {"string", "TEXT"}, {"date", "DATE"}, {"date-time", "DATETIME"}, {"null", "NULL"}, - {"integer", "INTEGER"}, {"boolean", "BOOLEAN"}, {"array", "BLOB"}, {"number", "NUMERIC"}, - {"int32", "INTEGER"}, {"int64", "INTEGER"}, {"float32", "NUMERIC"}, {"float64", "NUMERIC"}, - {"bytes", "BLOB"}}; - -const char kPoint = '.'; - -// field type used by check schema validation -const std::set kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"}; - -// can be searched field list -const std::set kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"}; - -// number field list -const std::set kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"}; - -/// \brief split a string using a character -/// \param[in] field target string -/// \param[in] separator a character for spliting -/// \return vector type result -std::vector StringSplit(const std::string &field, char separator); - -/// \brief validate field name is composed of '0-9' or 'a-z' or 'A-Z' or '_' or '-' -/// \param[in] str target string -/// \return -bool ValidateFieldName(const std::string &str); - -/// \brief get the filename by the path -/// \param s file path -/// \return -std::pair GetFileName(const std::string &s); - -/// \brief get parent dir -/// \param path file path -/// \return parent path -std::pair GetParentDir(const std::string &path); - -bool CheckIsValidUtf8(const std::string &str); - -/// \brief judge if a path is legal file -/// \param path file path -/// \return parent path -bool IsLegalFile(const std::string &path); - -enum DiskSizeType { kTotalSize = 0, kFreeSize }; - -/// \brief get the free space about the disk -/// \param str_dir file path -/// \param disk_type: kTotalSize / kFreeSize -/// \return size in Megabytes -std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type); - -/// \brief get the max hardware concurrency -/// \return max concurrency -uint32_t GetMaxThreadNum(); -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_category.h b/mindspore/ccsrc/mindrecord/include/shard_category.h deleted file mode 100644 index 618a91b1d8..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_category.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ -#define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/shard_operator.h" - -namespace mindspore { -namespace mindrecord { -class ShardCategory : public ShardOperator { - public: - explicit ShardCategory(const std::vector> &categories, - int64_t num_elements = std::numeric_limits::max(), bool replacement = false); - - ShardCategory(const std::string &category_field, int64_t num_elements, - int64_t num_categories = std::numeric_limits::max(), bool replacement = false); - - ~ShardCategory() override{}; - - const std::vector> &GetCategories() const { return categories_; } - - const std::string GetCategoryField() const { return category_field_; } - - int64_t GetNumElements() const { return num_elements_; } - - int64_t GetNumCategories() const { return num_categories_; } - - bool GetReplacement() const { return replacement_; } - - MSRStatus Execute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - private: - std::vector> categories_; - std::string category_field_; - int64_t num_elements_; - int64_t num_categories_; - bool replacement_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_column.h b/mindspore/ccsrc/mindrecord/include/shard_column.h deleted file mode 100644 index 968d82e717..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_column.h +++ /dev/null @@ -1,167 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_ -#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/shard_header.h" - -namespace mindspore { -namespace mindrecord { -const uint64_t kUnsignedOne = 1; -const uint64_t kBitsOfByte = 8; -const uint64_t kDataTypeBits = 2; -const uint64_t kNumDataOfByte = 4; -const uint64_t kBytesOfColumnLen = 4; -const uint64_t kDataTypeBitMask = 3; -const uint64_t kDataTypes = 6; - -enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type }; - -enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound }; - -enum ColumnDataType { - ColumnBytes = 0, - ColumnString = 1, - ColumnInt32 = 2, - ColumnInt64 = 3, - ColumnFloat32 = 4, - ColumnFloat64 = 5, - ColumnNoDataType = 6 -}; - -// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"}; -const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8}; - -const std::vector ColumnDataTypeNameNormalized = {"uint8", "string", "int32", - "int64", "float32", "float64"}; - -const std::unordered_map ColumnDataTypeMap = { - {"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32}, - {"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}}; - -class ShardColumn { - public: - explicit ShardColumn(const std::shared_ptr &shard_header, bool compress_integer = true); - - ~ShardColumn() = default; - - /// \brief get column value by column name - MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, - const json &columns_json, const unsigned char **data, - std::unique_ptr *data_ptr, uint64_t *const n_bytes, - ColumnDataType *column_data_type, uint64_t *column_data_type_size, - std::vector *column_shape); - - /// \brief compress blob - std::vector CompressBlob(const std::vector &blob); - - /// \brief check if blob compressed - bool CheckCompressBlob() const { return has_compress_blob_; } - - uint64_t GetNumBlobColumn() const { return num_blob_column_; } - - std::vector GetColumnName() { return column_name_; } - - std::vector GeColumnDataType() { return column_data_type_; } - - std::vector> GetColumnShape() { return column_shape_; } - - /// \brief get column value from blob - MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, - const unsigned char **data, std::unique_ptr *data_ptr, - uint64_t *const n_bytes); - std::pair GetColumnTypeByName(const std::string &column_name, - ColumnDataType *column_data_type, - uint64_t *column_data_type_size, - std::vector *column_shape); - - /// \brief get column value from json - MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, - std::unique_ptr *data_ptr, uint64_t *n_bytes); - - private: - /// \brief get float value from json - template - MSRStatus GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, bool use_double); - - /// \brief get integer value from json - template - MSRStatus GetInt(std::unique_ptr *data_ptr, const json &json_column_value); - - /// \brief get column offset address and size from blob - MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, - uint64_t *num_bytes, uint64_t *shift_idx); - - /// \brief check if column name is available - ColumnCategory CheckColumnName(const std::string &column_name); - - /// \brief compress integer column - static vector CompressInt(const vector &src_bytes, const IntegerType &int_type); - - /// \brief uncompress integer array column - template - static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, - const std::vector &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); - - /// \brief convert big-endian bytes to unsigned int - /// \param bytes_array bytes array - /// \param pos shift address in bytes array - /// \param i_type integer type - /// \return unsigned int - static uint64_t BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, - const IntegerType &i_type); - - /// \brief convert unsigned int to big-endian bytes - /// \param value integer value - /// \param i_type integer type - /// \return bytes - static std::vector UIntToBytesBig(uint64_t value, const IntegerType &i_type); - - /// \brief convert unsigned int to little-endian bytes - /// \param value integer value - /// \param i_type integer type - /// \return bytes - static std::vector UIntToBytesLittle(uint64_t value, const IntegerType &i_type); - - /// \brief convert unsigned int to little-endian bytes - /// \param bytes_array bytes array - /// \param pos shift address in bytes array - /// \param src_i_type source integer typ0e - /// \param dst_i_type (output), destination integer type - /// \return integer - static int64_t BytesLittleToMinIntType(const std::vector &bytes_array, const uint64_t &pos, - const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr); - - private: - std::vector column_name_; // column name list - std::vector column_data_type_; // column data type list - std::vector> column_shape_; // column shape list - std::unordered_map column_name_id_; // column name id map - std::vector blob_column_; // blob column list - std::unordered_map blob_column_id_; // blob column name id map - bool has_compress_blob_; // if has compress blob - uint64_t num_blob_column_; // number of blob columns -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h deleted file mode 100644 index ef0ad738c4..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_shuffle.h" -#include "mindrecord/include/shard_sample.h" - -namespace mindspore { -namespace mindrecord { -class ShardDistributedSample : public ShardSample { - public: - ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed); - - ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed); - - void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; } - - ~ShardDistributedSample() override{}; - - MSRStatus PreExecute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - private: - bool shuffle_; - int no_of_padded_samples_; - bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch - ShardTask task_; // maintain the input tasks in first epoch -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h deleted file mode 100644 index e4361c466a..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ /dev/null @@ -1,186 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_HEADER_H_ -#define MINDRECORD_INCLUDE_SHARD_HEADER_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index.h" -#include "mindrecord/include/shard_page.h" -#include "mindrecord/include/shard_schema.h" -#include "mindrecord/include/shard_statistics.h" - -namespace mindspore { -namespace mindrecord { -class ShardHeader { - public: - ShardHeader(); - - ~ShardHeader() = default; - - MSRStatus BuildDataset(const std::vector &file_paths, bool load_dataset = true); - - static std::pair BuildSingleHeader(const std::string &file_path); - /// \brief add the schema and save it - /// \param[in] schema the schema needs to be added - /// \return the last schema's id - int AddSchema(std::shared_ptr schema); - - /// \brief add the statistic and save it - /// \param[in] statistic the statistic needs to be added - /// \return the last statistic's id - void AddStatistic(std::shared_ptr statistic); - - /// \brief create index and add fields which from schema for each schema - /// \param[in] fields the index fields needs to be added - /// \return SUCCESS if add successfully, FAILED if not - MSRStatus AddIndexFields(std::vector> fields); - - MSRStatus AddIndexFields(const std::vector &fields); - - /// \brief get the schema - /// \return the schema - std::vector> GetSchemas(); - - /// \brief get Statistics - /// \return the Statistic - std::vector> GetStatistics(); - - /// \brief get the fields of the index - /// \return the fields of the index - std::vector> GetFields(); - - /// \brief get the index - /// \return the index - std::shared_ptr GetIndex(); - - /// \brief get the schema by schemaid - /// \param[in] schemaId the id of schema needs to be got - /// \return the schema obtained by schemaId - std::pair, MSRStatus> GetSchemaByID(int64_t schema_id); - - /// \brief get the filepath to shard by shardID - /// \param[in] shardID the id of shard which filepath needs to be obtained - /// \return the filepath obtained by shardID - std::string GetShardAddressByID(int64_t shard_id); - - /// \brief get the statistic by statistic id - /// \param[in] statisticId the id of statistic needs to be get - /// \return the statistics obtained by statistic id - std::pair, MSRStatus> GetStatisticByID(int64_t statistic_id); - - MSRStatus InitByFiles(const std::vector &file_paths); - - void SetIndex(Index index) { index_ = std::make_shared(index); } - - std::pair, MSRStatus> GetPage(const int &shard_id, const int &page_id); - - MSRStatus SetPage(const std::shared_ptr &new_page); - - MSRStatus AddPage(const std::shared_ptr &new_page); - - int64_t GetLastPageId(const int &shard_id); - - int GetLastPageIdByType(const int &shard_id, const std::string &page_type); - - const std::pair> GetPageByGroupId(const int &group_id, const int &shard_id); - - std::vector GetShardAddresses() const { return shard_addresses_; } - - int GetShardCount() const { return shard_count_; } - - int GetSchemaCount() const { return schema_.size(); } - - uint64_t GetHeaderSize() const { return header_size_; } - - uint64_t GetPageSize() const { return page_size_; } - - void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } - - void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } - - std::vector SerializeHeader(); - - MSRStatus PagesToFile(const std::string dump_file_name); - - MSRStatus FileToPages(const std::string dump_file_name); - - private: - MSRStatus InitializeHeader(const std::vector &headers, bool load_dataset); - - /// \brief get the headers from all the shard data - /// \param[in] the shard data real path - /// \param[in] the headers which readed from the shard data - /// \return SUCCESS/FAILED - MSRStatus GetHeaders(const vector &real_addresses, std::vector &headers); - - MSRStatus ValidateField(const std::vector &field_name, json schema, const uint64_t &schema_id); - - /// \brief check the binary file status - static MSRStatus CheckFileStatus(const std::string &path); - - static std::pair ValidateHeader(const std::string &path); - - void ParseHeader(const json &header); - - void GetHeadersOneTask(int start, int end, std::vector &headers, const vector &realAddresses); - - MSRStatus ParseIndexFields(const json &index_fields); - - MSRStatus CheckIndexField(const std::string &field, const json &schema); - - void ParsePage(const json &page, int shard_index, bool load_dataset); - - MSRStatus ParseStatistics(const json &statistics); - - MSRStatus ParseSchema(const json &schema); - - void ParseShardAddress(const json &address); - - std::string SerializeIndexFields(); - - std::vector SerializePage(); - - std::string SerializeStatistics(); - - std::string SerializeSchema(); - - std::string SerializeShardAddress(); - - std::shared_ptr InitIndexPtr(); - - MSRStatus GetAllSchemaID(std::set &bucket_count); - - uint32_t shard_count_; - uint64_t header_size_; - uint64_t page_size_; - - std::shared_ptr index_; - std::vector shard_addresses_; - std::vector> schema_; - std::vector> statistics_; - std::vector>> pages_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_HEADER_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_index.h b/mindspore/ccsrc/mindrecord/include/shard_index.h deleted file mode 100644 index d430c5bdcf..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_index.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INDEX_H -#define MINDRECORD_INDEX_H -#pragma once - -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_schema.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -using std::cin; -using std::endl; -using std::pair; -using std::string; -using std::vector; - -class Index { - public: - Index(); - - ~Index() {} - - /// \brief Add field which from schema according to schemaId - /// \param[in] schemaId the id of schema to be added - /// \param[in] field the field need to be added - /// - /// add the field to the fields_ vector - void AddIndexField(const int64_t &schemaId, const std::string &field); - - /// \brief get stored fields - /// \return fields stored - std::vector > GetFields(); - - private: - std::vector > fields_; - string database_name_; - string table_name_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INDEX_H diff --git a/mindspore/ccsrc/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/mindrecord/include/shard_index_generator.h deleted file mode 100644 index b081b7a0a0..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_index_generator.h +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ -#define MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/shard_header.h" -#include "./sqlite3.h" - -namespace mindspore { -namespace mindrecord { -using INDEX_FIELDS = std::pair>>; -using ROW_DATA = std::pair>>>; -class ShardIndexGenerator { - public: - explicit ShardIndexGenerator(const std::string &file_path, bool append = false); - - MSRStatus Build(); - - static std::pair GenerateFieldName(const std::pair &field); - - ~ShardIndexGenerator() {} - - /// \brief fetch value in json by field name - /// \param[in] field - /// \param[in] input - /// \return pair - std::pair GetValueByField(const string &field, json input); - - /// \brief fetch field type in schema n by field path - /// \param[in] field_path - /// \param[in] schema - /// \return the type of field - static std::string TakeFieldType(const std::string &field_path, json schema); - - /// \brief create databases for indexes - MSRStatus WriteToDatabase(); - - private: - static int Callback(void *not_used, int argc, char **argv, char **az_col_name); - - static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); - - static std::string ConvertJsonToSQL(const std::string &json); - - std::pair CreateDatabase(int shard_no); - - std::pair> GetSchemaDetails(const std::vector &schema_lens, std::fstream &in); - - static std::pair GenerateRawSQL(const std::vector> &fields); - - std::pair CheckDatabase(const std::string &shard_address); - - /// - /// \param shard_no - /// \param blob_id_to_page_id - /// \param raw_page_id - /// \param in - /// \return field name, db type, field value - ROW_DATA GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, int raw_page_id, - std::fstream &in); - /// - /// \param db - /// \param sql - /// \param data - /// \return - MSRStatus BindParameterExecuteSQL( - sqlite3 *db, const std::string &sql, - const std::vector>> &data); - - INDEX_FIELDS GenerateIndexFields(const std::vector &schema_detail); - - MSRStatus ExecuteTransaction(const int &shard_no, std::pair &db, - const std::vector &raw_page_ids, const std::map &blob_id_to_page_id); - - MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); - - MSRStatus AddBlobPageInfo(std::vector> &row_data, - const std::shared_ptr cur_blob_page, uint64_t &cur_blob_page_offset, - std::fstream &in); - - void AddIndexFieldByRawData(const std::vector &schema_detail, - std::vector> &row_data); - - void DatabaseWriter(); // worker thread - - std::string file_path_; - bool append_; - ShardHeader shard_header_; - uint64_t page_size_; - uint64_t header_size_; - int schema_count_; - std::atomic_int task_; - std::atomic_bool write_success_; - std::vector> fields_; -}; -} // namespace mindrecord -} // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_operator.h b/mindspore/ccsrc/mindrecord/include/shard_operator.h deleted file mode 100644 index f33e3db5f4..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_operator.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ -#define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ - -#include -#include "mindrecord/include/shard_task.h" - -namespace mindspore { -namespace mindrecord { -class ShardOperator { - public: - virtual ~ShardOperator() = default; - - MSRStatus operator()(ShardTask &tasks) { - if (SUCCESS != this->PreExecute(tasks)) { - return FAILED; - } - if (SUCCESS != this->Execute(tasks)) { - return FAILED; - } - if (SUCCESS != this->SufExecute(tasks)) { - return FAILED; - } - return SUCCESS; - } - virtual bool HasChildOp() { return child_op_ != nullptr; } - - virtual MSRStatus SetChildOp(std::shared_ptr child_op) { - if (child_op != nullptr) child_op_ = child_op; - return SUCCESS; - } - - virtual std::shared_ptr GetChildOp() { return child_op_; } - - virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } - - virtual MSRStatus Execute(ShardTask &tasks) = 0; - - virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } - - virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } - - private: - std::shared_ptr child_op_ = nullptr; -}; -} // namespace mindrecord -} // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_page.h b/mindspore/ccsrc/mindrecord/include/shard_page.h deleted file mode 100644 index c22acd8d2c..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_page.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_PAGE_H_ -#define MINDRECORD_INCLUDE_SHARD_PAGE_H_ - -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "pybind11/pybind11.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -const std::string kPageTypeRaw = "RAW_DATA"; -const std::string kPageTypeBlob = "BLOB_DATA"; -const std::string kPageTypeNewColumn = "NEW_COLUMN_DATA"; - -class Page { - public: - Page(const int &page_id, const int &shard_id, const std::string &page_type, const int &page_type_id, - const uint64_t &start_row_id, const uint64_t end_row_id, - const std::vector> &row_group_ids, const uint64_t page_size) - : page_id_(page_id), - shard_id_(shard_id), - page_type_(page_type), - page_type_id_(page_type_id), - start_row_id_(start_row_id), - end_row_id_(end_row_id), - row_group_ids_(row_group_ids), - page_size_(page_size) {} - - ~Page() = default; - - /// \brief get the page and its description - /// \return the json format of the page and its description - json GetPage() const; - - int GetPageID() const { return page_id_; } - - int GetShardID() const { return shard_id_; } - - int GetPageTypeID() const { return page_type_id_; } - - std::string GetPageType() const { return page_type_; } - - uint64_t GetPageSize() const { return page_size_; } - - uint64_t GetStartRowID() const { return start_row_id_; } - - uint64_t GetEndRowID() const { return end_row_id_; } - - void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } - - void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } - - std::pair GetLastRowGroupID() const { return row_group_ids_.back(); } - - std::vector> GetRowGroupIds() const { return row_group_ids_; } - - void SetRowGroupIds(const std::vector> &last_row_group_ids) { - row_group_ids_ = last_row_group_ids; - } - - void DeleteLastGroupId(); - - private: - int page_id_; - int shard_id_; - std::string page_type_; - int page_type_id_; - uint64_t start_row_id_; - uint64_t end_row_id_; - std::vector> row_group_ids_; - uint64_t page_size_; - // JSON page: { - // "page_id":X, - // "shard_id":X, - // "page_type":"XXX", (enum "raw_data", "blob_data", "new_column") - // "page_type_id":X, - // "start_row_id":X, - // "end_row_id":X, - // "row_group_ids":[{"id":X, "offset":X}], - // "page_size":X, -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_PAGE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h deleted file mode 100644 index 4f1a1c307a..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_shuffle.h" -#include "mindrecord/include/shard_category.h" - -namespace mindspore { -namespace mindrecord { -class ShardPkSample : public ShardCategory { - public: - ShardPkSample(const std::string &category_field, int64_t num_elements); - - ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); - - ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); - - ~ShardPkSample() override{}; - - MSRStatus SufExecute(ShardTask &tasks) override; - - private: - bool shuffle_; - std::shared_ptr shuffle_op_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h deleted file mode 100644 index 1f2138d6d5..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ /dev/null @@ -1,366 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_READER_H_ -#define MINDRECORD_INCLUDE_SHARD_READER_H_ - -#include -#include -#if !defined(_WIN32) && !defined(_WIN64) -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_category.h" -#include "mindrecord/include/shard_column.h" -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index_generator.h" -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_shuffle.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -using ROW_GROUPS = - std::tuple>>, std::vector>>; -using ROW_GROUP_BRIEF = - std::tuple>, std::vector>; -using TASK_RETURN_CONTENT = - std::pair, json>>>>; -const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode -const int kNumPageInBuffer = 16; // page buffer size in block-reader mode - -class ShardReader { - public: - ShardReader(); - - virtual ~ShardReader(); - - /// \brief open files and initialize reader, c++ API - /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list - /// \param[in] load_dataset load dataset from single file or not - /// \param[in] n_consumer number of threads when reading - /// \param[in] selected_columns column list to be populated - /// \param[in] operators operators applied to data, operator type is shuffle, sample or category - /// \param[in] block_reader block-reader mode if true, otherwise row-reader mode - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, - const std::vector &selected_columns = {}, - const std::vector> &operators = {}, const bool &block_reader = false, - const int num_padded = 0); - - /// \brief open files and initialize reader, python API - /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list - /// \param[in] load_dataset load dataset from single file or not - /// \param[in] n_consumer number of threads when reading - /// \param[in] selected_columns column list to be populated - /// \param[in] operators operators applied to data, operator type is shuffle, sample or category - /// \return MSRStatus the status of MSRStatus - MSRStatus OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer = 4, - const std::vector &selected_columns = {}, - const std::vector> &operators = {}); - - /// \brief close reader - /// \return null - void Close(); - - /// \brief read the file, get schema meta,statistics and index, single-thread mode - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(); - - /// \brief read the file, get schema meta,statistics and index, multiple-thread mode - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(int n_consumer); - - /// \brief launch threads to get batches - /// \param[in] is_simple_reader trigger threads if false; do nothing if true - /// \return MSRStatus the status of MSRStatus - MSRStatus Launch(bool is_simple_reader = false); - - /// \brief aim to get the meta data - /// \return the metadata - std::shared_ptr GetShardHeader() const; - - /// \brief aim to get columns context - /// \return the columns - std::shared_ptr GetShardColumn() const; - - /// \brief get the number of shards - /// \return # of shards - int GetShardCount() const; - - /// \brief get the number of rows in database - /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list - /// \param[in] load_dataset load dataset from single file or not - /// \param[in] op smart pointer refer to ShardCategory or ShardSample object - /// \param[out] count # of rows - /// \return MSRStatus the status of MSRStatus - MSRStatus CountTotalRows(const std::vector &file_paths, bool load_dataset, - const std::shared_ptr &op, int64_t *count, const int num_padded); - - /// \brief shuffle task with incremental seed - /// \return void - void ShuffleTask(); - - /// \brief get the number of rows in database - /// \return # of rows - int GetNumRows() const; - - /// \brief Read the summary of row groups - /// \return the tuple of 4 elements - /// 1. Sharding ID - /// 2. Row group ID - /// 3. The row ID started in row group - /// 4. # of rows in row group - std::vector> ReadRowGroupSummary(); - - /// \brief Read 1 row group data, excluding images - /// \param[in] groupID row group ID - /// \param[in] shard_id sharding ID - /// \param[in] columns multi-columns retrieved - /// \return the tuple of 5 elements - /// 1. file name where row group is located - /// 2. Actual row group size - /// 3. Offset address of row group in file - /// 4. The list of image offset in page [startOffset, endOffset) - /// 5. The list of columns data - ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id, - const std::vector &columns = std::vector()); - - /// \brief Read 1 row group data, excluding images, following an index field criteria - /// \param[in] groupID row group ID - /// \param[in] shard_id sharding ID - /// \param[in] column-value pair of criteria to fulfill - /// \param[in] columns multi-columns retrieved - /// \return the tuple of 5 elements - /// 1. file name where row group is located - /// 2. Actual row group size - /// 3. Offset address of row group in file - /// 4. The list of image offset in page [startOffset, endOffset) - /// 5. The list of columns data - ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair &criteria, - const std::vector &columns = std::vector()); - - /// \brief join all created threads - /// \return MSRStatus the status of MSRStatus - MSRStatus Finish(); - - /// \brief return a batch, given that one is ready - /// \return a batch of images and image data - std::vector, json>> GetNext(); - - /// \brief return a row by id - /// \return a batch of images and image data - std::pair, json>>> GetNextById(const int64_t &task_id, - const int32_t &consumer_id); - - /// \brief return a batch in block-reader mode, given that one is ready - /// \return a batch of images and image data - std::vector, json>> GetBlockNext(); - - /// \brief return a batch, given that one is ready, python API - /// \return a batch of images and image data - std::vector>, pybind11::object>> GetNextPy(); - - /// \brief get blob filed list - /// \return blob field list - std::pair> GetBlobFields(); - - /// \brief reset reader - /// \return null - void Reset(); - - /// \brief set flag of all-in-index - /// \return null - void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } - - /// \brief get NLP flag - bool GetNlpFlag(); - - /// \brief get all classes - MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); - - protected: - /// \brief sqlite call back function - static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); - - private: - /// \brief wrap up labels to json format - MSRStatus ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs, - std::vector>> &offsets, int shard_id, - const std::vector &columns, std::vector> &column_values); - - /// \brief read all rows for specified columns - ROW_GROUPS ReadAllRowGroup(std::vector &columns); - - /// \brief read all rows in one shard - MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, - std::vector>> &offsets, - std::vector> &column_values); - - /// \brief initialize reader - MSRStatus Init(const std::vector &file_paths, bool load_dataset); - - /// \brief validate column list - MSRStatus CheckColumnList(const std::vector &selected_columns); - - /// \brief populate one row by task list in row-reader mode - MSRStatus ConsumerByRow(int consumer_id); - - /// \brief populate one row by task list in block-reader mode - MSRStatus ConsumerByBlock(int consumer_id); - - /// \brief get offset address of images within page - std::vector> GetImageOffset(int group_id, int shard_id, - const std::pair &criteria = {"", ""}); - - /// \brief execute sqlite query with prepare statement - MSRStatus QueryWithCriteria(sqlite3 *db, string &sql, string criteria, std::vector> &labels); - - /// \brief get column values - std::pair> GetLabels(int group_id, int shard_id, const std::vector &columns, - const std::pair &criteria = {"", ""}); - - /// \brief get column values from raw data page - std::pair> GetLabelsFromPage(int group_id, int shard_id, - const std::vector &columns, - const std::pair &criteria = {"", - ""}); - - /// \brief create task list in block-reader mode - MSRStatus CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators); - - /// \brief create category-applied task list - MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary, - const std::shared_ptr &op); - - /// \brief create task list in row-reader mode - MSRStatus CreateTasksByRow(const std::vector> &row_group_summary, - const std::vector> &operators); - - /// \brief crate task list - MSRStatus CreateTasks(const std::vector> &row_group_summary, - const std::vector> &operators); - - /// \brief set NLP flag - void CheckNlp(); - - /// \brief check if all specified columns are in index table - void CheckIfColumnInIndex(const std::vector &columns); - - /// \brief open multiple file handle - void FileStreamsOperator(); - - /// \brief read one row by one task - TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id); - - /// \brief get one row from buffer in block-reader mode - std::shared_ptr, json>>> GetRowFromBuffer(int bufId, int rowId); - - /// \brief get labels from binary file - std::pair> GetLabelsFromBinaryFile( - int shard_id, const std::vector &columns, const std::vector> &label_offsets); - - MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); - - /// \brief get classes in one shard - void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); - - /// \brief get number of classes - int64_t GetNumClasses(const std::string &category_field); - - /// \brief get meta of header - std::pair> GetMeta(const std::string &file_path, json &meta_data); - - /// \brief extract uncompressed data based on column list - std::pair>> UnCompressBlob(const std::vector &raw_blob_data); - - protected: - uint64_t header_size_; // header size - uint64_t page_size_; // page size - int shard_count_; // number of shards - std::shared_ptr shard_header_; // shard header - std::shared_ptr shard_column_; // shard column - - std::vector database_paths_; // sqlite handle list - std::vector file_paths_; // file paths - std::vector> file_streams_; // single-file handle list - std::vector>> file_streams_random_; // multiple-file handle list - - private: - int n_consumer_; // number of workers (threads) - std::vector selected_columns_; // columns which will be read - std::map column_schema_id_; // column-schema map - std::vector> operators_; // data operators, including shuffle, sample and category - ShardTask tasks_; // shard task - std::mutex shard_locker_; // locker of shard - - // flags - bool all_in_index_ = true; // if all columns are stored in index-table - bool interrupt_ = false; // reader interrupted - - int num_padded_; // number of padding samples - - // Delivery/Iterator mode begin - const std::string kThreadName = "THRD_ITER_"; // prefix of thread name - std::vector thread_set_; // thread list - int num_rows_; // number of rows - std::mutex mtx_delivery_; // locker for delivery - std::condition_variable cv_delivery_; // conditional variable for delivery - std::condition_variable cv_iterator_; // conditional variable for iterator - std::atomic task_id_; // task ID which is working - std::atomic deliver_id_; // delivery ID which is picked up by iterator - // map of delivery - std::unordered_map, json>>>> delivery_map_; - // Delivery/Iterator mode end - - // Block reader mode begin - bool block_reader_; // block-reader mode - int row_id_; // row id in one page - int num_blocks_; // number of pages - // raw data page - std::vector>, std::vector>>> delivery_block_; - std::unordered_set delivery_block_set_; // set of delivered pages - std::vector> buf_; // page buffer - // Block reader mode end -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_READER_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h deleted file mode 100644 index a32acbff6e..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_sample.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_shuffle.h" - -namespace mindspore { -namespace mindrecord { -class ShardSample : public ShardOperator { - public: - explicit ShardSample(int n); - - ShardSample(int num, int den); - - ShardSample(int num, int den, int par); - - ShardSample(const std::vector &indices, uint32_t seed); - - ~ShardSample() override{}; - - MSRStatus Execute(ShardTask &tasks) override; - - MSRStatus SufExecute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - protected: - int numerator_; - int denominator_; - int partition_id_; - int no_of_samples_; - std::shared_ptr shuffle_op_; - - private: - std::vector indices_; - SamplerType sampler_type_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_schema.h b/mindspore/ccsrc/mindrecord/include/shard_schema.h deleted file mode 100644 index 4ef134bde2..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_schema.h +++ /dev/null @@ -1,90 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ -#define MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_pybind.h" -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "pybind11/pybind11.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -class Schema { - public: - ~Schema() = default; - - /// \brief obtain the json schema ,its description, its block fields - /// \param[in] desc the description of the schema - /// \param[in] schema the schema's json - static std::shared_ptr Build(std::string desc, const json &schema); - - /// \brief obtain the json schema and its description for python - /// \param[in] desc the description of the schema - /// \param[in] schema the schema's json - static std::shared_ptr Build(std::string desc, pybind11::handle schema); - - /// \brief compare two schema to judge if they are equal - /// \param b another schema to be judged - /// \return true if they are equal,false if not - bool operator==(const Schema &b) const; - - /// \brief get the schema and its description - /// \return the json format of the schema and its description - std::string GetDesc() const; - - /// \brief get the schema and its description - /// \return the json format of the schema and its description - json GetSchema() const; - - /// \brief get the schema and its description for python method - /// \return the python object of the schema and its description - pybind11::object GetSchemaForPython() const; - - /// set the schema id - /// \param[in] id the id need to be set - void SetSchemaID(int64_t id); - - /// get the schema id - /// \return the int64 schema id - int64_t GetSchemaID() const; - - /// get the blob fields - /// \return the vector blob fields - std::vector GetBlobFields() const; - - private: - Schema() = default; - static bool ValidateNumberShape(const json &it_value); - static bool Validate(json schema); - static std::vector PopulateBlobFields(json schema); - - std::string desc_; - json schema_; - std::vector blob_fields_; - int64_t schema_id_ = -1; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_segment.h b/mindspore/ccsrc/mindrecord/include/shard_segment.h deleted file mode 100644 index 12497a5ace..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_segment.h +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ -#define MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_reader.h" - -namespace mindspore { -namespace mindrecord { -class ShardSegment : public ShardReader { - public: - ShardSegment(); - - ~ShardSegment() override = default; - - /// \brief Get candidate category fields - /// \return a list of fields names which are the candidates of category - std::pair> GetCategoryFields(); - - /// \brief Set category field - /// \param[in] category_field category name - /// \return true if category name is existed - MSRStatus SetCategoryField(std::string category_field); - - /// \brief Thread-safe implementation of ReadCategoryInfo - /// \return statistics data in json format with 2 field: "key" and "categories". - /// The value of "categories" is a list. Each Element in list is {count, id, name} - /// count: count of images in category - /// id: internal unique identification, persistent - /// name: category name - /// example: - /// { "key": "label", - /// "categories": [ { "count": 3, "id": 0, "name": "sport", }, - /// { "count": 3, "id": 1, "name": "finance", } ] } - std::pair ReadCategoryInfo(); - - /// \brief Thread-safe implementation of ReadAtPageById - /// \param[in] category_id category ID - /// \param[in] page_no page number - /// \param[in] n_rows_of_page rows number in one page - /// \return images array, image is a vector of uint8_t - std::pair>> ReadAtPageById(int64_t category_id, int64_t page_no, - int64_t n_rows_of_page); - - /// \brief Thread-safe implementation of ReadAtPageByName - /// \param[in] category_name category Name - /// \param[in] page_no page number - /// \param[in] n_rows_of_page rows number in one page - /// \return images array, image is a vector of uint8_t - std::pair>> ReadAtPageByName(std::string category_name, int64_t page_no, - int64_t n_rows_of_page); - - std::pair, json>>> ReadAllAtPageById(int64_t category_id, - int64_t page_no, - int64_t n_rows_of_page); - - std::pair, json>>> ReadAllAtPageByName( - std::string category_name, int64_t page_no, int64_t n_rows_of_page); - - std::pair, pybind11::object>>> ReadAtPageByIdPy( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page); - - std::pair, pybind11::object>>> ReadAtPageByNamePy( - std::string category_name, int64_t page_no, int64_t n_rows_of_page); - - std::pair> GetBlobFields(); - - private: - std::pair>> WrapCategoryInfo(); - - std::string ToJsonForCategory(const std::vector> &tri_vec); - - std::string CleanUp(std::string fieldName); - - std::pair> PackImages(int group_id, int shard_id, std::vector offset); - - std::vector candidate_category_fields_; - std::string current_category_field_; - const uint32_t kStartFieldId = 9; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h deleted file mode 100644 index a8ee3a36db..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_sample.h" - -namespace mindspore { -namespace mindrecord { -class ShardSequentialSample : public ShardSample { - public: - ShardSequentialSample(int n, int offset); - - ShardSequentialSample(float per, float per_offset); - - ~ShardSequentialSample() override{}; - - MSRStatus Execute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - private: - int offset_; - float per_; - float per_offset_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h deleted file mode 100644 index adb172bdcc..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ - -#include -#include "mindrecord/include/shard_operator.h" - -namespace mindspore { -namespace mindrecord { -class ShardShuffle : public ShardOperator { - public: - explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); - - ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch, - ShuffleType shuffle_type = kShuffleSample); - - ~ShardShuffle() override{}; - - MSRStatus Execute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - private: - uint32_t shuffle_seed_; - int64_t no_of_samples_; - bool replacement_; - bool reshuffle_each_epoch_; - ShuffleType shuffle_type_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_statistics.h b/mindspore/ccsrc/mindrecord/include/shard_statistics.h deleted file mode 100644 index 7fc2f968cd..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_statistics.h +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#ifndef MINDRECORD_STATISTICS_H -#define MINDRECORD_STATISTICS_H - -#include -#include -#include -#include -#include - -#include "mindrecord/include/common/shard_pybind.h" -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "pybind11/pybind11.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -class Statistics { - public: - /// \brief save the statistic and its description - /// \param[in] desc the statistic's description - /// \param[in] statistics the statistic needs to be saved - static std::shared_ptr Build(std::string desc, const json &statistics); - - /// \brief save the statistic from python and its description - /// \param[in] desc the statistic's description - /// \param[in] statistics the statistic needs to be saved - static std::shared_ptr Build(std::string desc, pybind11::handle statistics); - - ~Statistics() = default; - - /// \brief compare two statistics to judge if they are equal - /// \param b another statistics to be judged - /// \return true if they are equal,false if not - bool operator==(const Statistics &b) const; - - /// \brief get the description - /// \return the description - std::string GetDesc() const; - - /// \brief get the statistic - /// \return json format of the statistic - json GetStatistics() const; - - /// \brief get the statistic for python - /// \return the python object of statistics - pybind11::object GetStatisticsForPython() const; - - /// \brief decode the bson statistics to json - /// \param[in] encodedStatistics the bson type of statistics - /// \return json type of statistic - void SetStatisticsID(int64_t id); - - /// \brief get the statistics id - /// \return the int64 statistics id - int64_t GetStatisticsID() const; - - private: - /// \brief validate the statistic - /// \return true / false - static bool Validate(const json &statistics); - - static bool LevelRecursive(json level); - - Statistics() = default; - - std::string desc_; - json statistics_; - int64_t statistics_id_ = -1; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_STATISTICS_H diff --git a/mindspore/ccsrc/mindrecord/include/shard_task.h b/mindspore/ccsrc/mindrecord/include/shard_task.h deleted file mode 100644 index 4a12eb9e45..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_task.h +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_TASK_H_ -#define MINDRECORD_INCLUDE_SHARD_TASK_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" - -namespace mindspore { -namespace mindrecord { -class ShardTask { - public: - ShardTask(); - - ShardTask(const ShardTask &task); // copy construction - - ShardTask &operator=(const ShardTask &task); // assignment operator - - ~ShardTask() = default; - - void MakePerm(); - - void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, - const json &label); - - void InsertTask(std::tuple, std::vector, json> task); - - void PopBack(); - - uint32_t Size() const; - - uint32_t SizeOfRows() const; - - std::tuple, std::vector, json> &GetTaskByID(size_t id); - - std::tuple, std::vector, json> &GetRandomTask(); - - static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); - - uint32_t categories; - - std::vector permutation_; - - std::vector, std::vector, json>> task_list_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_TASK_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h deleted file mode 100644 index 6175180c92..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ /dev/null @@ -1,257 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_WRITER_H_ -#define MINDRECORD_INCLUDE_SHARD_WRITER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_column.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_header.h" -#include "mindrecord/include/shard_index.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -class ShardWriter { - public: - ShardWriter(); - - ~ShardWriter(); - - /// \brief Open file at the beginning - /// \param[in] paths the file names list - /// \param[in] append new data at the end of file if true, otherwise overwrite file - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(const std::vector &paths, bool append = false); - - /// \brief Open file at the ending - /// \param[in] paths the file names list - /// \return MSRStatus the status of MSRStatus - MSRStatus OpenForAppend(const std::string &path); - - /// \brief Write header to disk - /// \return MSRStatus the status of MSRStatus - MSRStatus Commit(); - - /// \brief Set file size - /// \param[in] header_size the size of header, only (1< header_data); - - /// \brief write raw data by group size - /// \param[in] raw_data the vector of raw json data, vector format - /// \param[in] blob_data the vector of image data - /// \param[in] sign validate data or not - /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); - - /// \brief write raw data by group size for call from python - /// \param[in] raw_data the vector of raw json data, python-handle format - /// \param[in] blob_data the vector of image data - /// \param[in] sign validate data or not - /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); - - /// \brief write raw data by group size for call from python - /// \param[in] raw_data the vector of raw json data, python-handle format - /// \param[in] blob_data the vector of blob json data, python-handle format - /// \param[in] sign validate data or not - /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign = true, - bool parallel_writer = false); - - private: - /// \brief write shard header data to disk - MSRStatus WriteShardHeader(); - - /// \brief erase error data - void DeleteErrorData(std::map> &raw_data, std::vector> &blob_data); - - /// \brief populate error data - void PopulateMutexErrorData(const int &row, const std::string &message, std::map &err_raw_data); - - /// \brief check data - void CheckSliceData(int start_row, int end_row, json schema, const std::vector &sub_raw_data, - std::map &err_raw_data); - - /// \brief write shard header data to disk - std::tuple ValidateRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign); - - /// \brief fill data array in multiple thread run - void FillArray(int start, int end, std::map> &raw_data, - std::vector> &bin_data); - - /// \brief serialized raw data - MSRStatus SerializeRawData(std::map> &raw_data, - std::vector> &bin_data, uint32_t row_count); - - /// \brief write all data parallel - MSRStatus ParallelWriteData(const std::vector> &blob_data, - const std::vector> &bin_raw_data); - - /// \brief write data shard by shard - MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector> &blob_data, - const std::vector> &bin_raw_data); - - /// \brief break image data up into multiple row groups - MSRStatus CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, - std::vector> &rows_in_group, const std::shared_ptr &last_raw_page, - const std::shared_ptr &last_blob_page); - - /// \brief append partial blob data to previous page - MSRStatus AppendBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page); - - /// \brief write new blob data page to disk - MSRStatus NewBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page); - - /// \brief shift last row group to next raw page for new appending - MSRStatus ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page); - - /// \brief write raw data page to disk - MSRStatus WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page, const std::vector> &bin_raw_data); - - /// \brief generate empty raw data page - void EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page); - - /// \brief append a row group at the end of raw page - MSRStatus AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, - const int &chunk_id, int &last_row_groupId, std::shared_ptr last_raw_page, - const std::vector> &bin_raw_data); - - /// \brief write blob chunk to disk - MSRStatus FlushBlobChunk(const std::shared_ptr &out, const std::vector> &blob_data, - const std::pair &blob_row); - - /// \brief write raw chunk to disk - MSRStatus FlushRawChunk(const std::shared_ptr &out, - const std::vector> &rows_in_group, const int &chunk_id, - const std::vector> &bin_raw_data); - - /// \brief break up into tasks by shard - std::vector> BreakIntoShards(); - - /// \brief calculate raw data size row by row - MSRStatus SetRawDataSize(const std::vector> &bin_raw_data); - - /// \brief calculate blob data size row by row - MSRStatus SetBlobDataSize(const std::vector> &blob_data); - - /// \brief populate last raw page pointer - void SetLastRawPage(const int &shard_id, std::shared_ptr &last_raw_page); - - /// \brief populate last blob page pointer - void SetLastBlobPage(const int &shard_id, std::shared_ptr &last_blob_page); - - /// \brief check the data by schema - MSRStatus CheckData(const std::map> &raw_data); - - /// \brief check the data and type - MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, - std::map &err_raw_data); - - /// \brief Lock writer and save pages info - int LockWriter(bool parallel_writer = false); - - /// \brief Unlock writer and save pages info - MSRStatus UnlockWriter(int fd, bool parallel_writer = false); - - /// \brief Check raw data before writing - MSRStatus WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, - bool sign, int *schema_count, int *row_count); - - /// \brief Get full path from file name - MSRStatus GetFullPathFromFileName(const std::vector &paths); - - /// \brief Open files - MSRStatus OpenDataFiles(bool append); - - /// \brief Remove lock file - MSRStatus RemoveLockFile(); - - /// \brief Remove lock file - MSRStatus InitLockFile(); - - private: - const std::string kLockFileSuffix = "_Locker"; - const std::string kPageFileSuffix = "_Pages"; - std::string lock_file_; // lock file for parallel run - std::string pages_file_; // temporary file of pages info for parallel run - - int shard_count_; // number of files - uint64_t header_size_; // header size - uint64_t page_size_; // page size - uint32_t row_count_; // count of rows - uint32_t schema_count_; // count of schemas - - std::vector raw_data_size_; // Raw data size - std::vector blob_data_size_; // Blob data size - - std::vector file_paths_; // file paths - std::vector> file_streams_; // file handles - std::shared_ptr shard_header_; // shard header - std::shared_ptr shard_column_; // shard columns - - std::map> err_mg_; // used for storing error raw_data info - - std::mutex check_mutex_; // mutex for data check - std::atomic flag_{false}; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_WRITER_H_ diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc deleted file mode 100644 index 16c730bd4c..0000000000 --- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc +++ /dev/null @@ -1,626 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "mindrecord/include/shard_index_generator.h" -#include "common/utils.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::DEBUG; -using mindspore::MsLogLevel::ERROR; -using mindspore::MsLogLevel::INFO; - -namespace mindspore { -namespace mindrecord { -ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append) - : file_path_(file_path), - append_(append), - page_size_(0), - header_size_(0), - schema_count_(0), - task_(0), - write_success_(true) {} - -MSRStatus ShardIndexGenerator::Build() { - auto ret = ShardHeader::BuildSingleHeader(file_path_); - if (ret.first != SUCCESS) { - return FAILED; - } - auto json_header = ret.second; - - auto ret2 = GetParentDir(file_path_); - if (SUCCESS != ret2.first) { - return FAILED; - } - std::vector real_addresses; - for (const auto &path : json_header["shard_addresses"]) { - std::string abs_path = ret2.second + string(path); - real_addresses.emplace_back(abs_path); - } - ShardHeader header = ShardHeader(); - if (header.BuildDataset(real_addresses) == FAILED) { - return FAILED; - } - shard_header_ = header; - MS_LOG(INFO) << "Init header from mindrecord file for index successfully."; - return SUCCESS; -} - -std::pair ShardIndexGenerator::GetValueByField(const string &field, json input) { - if (field.empty()) { - MS_LOG(ERROR) << "The input field is None."; - return {FAILED, ""}; - } - - if (input.empty()) { - MS_LOG(ERROR) << "The input json is None."; - return {FAILED, ""}; - } - - // parameter input does not contain the field - if (input.find(field) == input.end()) { - MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input; - return {FAILED, ""}; - } - - // schema does not contain the field - auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; - if (schema.find(field) == schema.end()) { - MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; - return {FAILED, ""}; - } - - // field should be scalar type - if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) { - MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable"; - return {FAILED, ""}; - } - - if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { - auto schema_field_options = schema[field]; - if (schema_field_options.find("shape") == schema_field_options.end()) { - return {SUCCESS, input[field].dump()}; - } else { - // field with shape option - MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable"; - return {FAILED, ""}; - } - } - - // the field type is string in here - return {SUCCESS, input[field].get()}; -} - -std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { - std::vector field_name = StringSplit(field_path, kPoint); - for (uint64_t i = 0; i < field_name.size(); i++) { - if (i != field_name.size() - 1) { - // Get type information from json schema - schema = schema.at(field_name[i]); - schema = schema.at("properties"); - } else { - // standard root layer exist "properties" if type is "object" - if (schema.find("properties") != schema.end()) { - schema = schema.at("properties"); - } - schema = schema.at(field_name[i]); - std::string field_type = schema.at("type").dump(); - if (field_type.length() <= 2) { - return ""; - } else { - return field_type.substr(1, field_type.length() - 2); - } - } - } - return ""; -} - -std::string ShardIndexGenerator::ConvertJsonToSQL(const std::string &json) { - if (kDbJsonMap.find(json) != kDbJsonMap.end()) { - return kDbJsonMap.at(json); - } else { - return "TEXT"; - } -} - -int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **az_col_name) { - for (auto i = 0; i < argc; i++) { - if (argv[i] != nullptr) { - MS_LOG(INFO) << az_col_name[i] << " = " << (argv[i] ? argv[i] : "nullptr"); - } - } - MS_LOG(INFO) << "\n"; - return 0; -} - -MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { - char *z_err_msg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Sql error: " << z_err_msg; - sqlite3_free(z_err_msg); - return FAILED; - } else { - if (!success_msg.empty()) { - MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg; - } - sqlite3_free(z_err_msg); - return SUCCESS; - } -} - -std::pair ShardIndexGenerator::GenerateFieldName( - const std::pair &field) { - // Replaces dots and dashes with underscores for SQL use - std::string field_name = field.second; - // white list to avoid sql injection - std::replace_if( - field_name.begin(), field_name.end(), [](char x) { return (x == '-' || x == '.'); }, '_'); - auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { - return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); - }); - if (pos != field_name.end()) { - MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name; - return {FAILED, ""}; - } - return {SUCCESS, field_name + "_" + std::to_string(field.first)}; -} - -std::pair ShardIndexGenerator::CheckDatabase(const std::string &shard_address) { - sqlite3 *db = nullptr; - std::ifstream fin(common::SafeCStr(shard_address)); - if (!append_ && fin.good()) { - MS_LOG(ERROR) << "DB file already exist"; - fin.close(); - return {FAILED, nullptr}; - } - fin.close(); - int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); - if (rc) { - MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); - return {FAILED, nullptr}; - } else { - MS_LOG(DEBUG) << "Opened database successfully"; - return {SUCCESS, db}; - } -} - -MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { - // create shard_name table - std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; - if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { - return FAILED; - } - sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; - if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { - return FAILED; - } - sql = "INSERT INTO SHARD_NAME (NAME) VALUES ('" + shard_name + "');"; - if (ExecuteSQL(sql, db, "insert name successfully.") != SUCCESS) { - return FAILED; - } - return SUCCESS; -} - -std::pair ShardIndexGenerator::CreateDatabase(int shard_no) { - std::string shard_address = shard_header_.GetShardAddressByID(shard_no); - if (shard_address.empty()) { - MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; - return {FAILED, nullptr}; - } - - string shard_name = GetFileName(shard_address).second; - shard_address += ".db"; - auto ret1 = CheckDatabase(shard_address); - if (ret1.first != SUCCESS) { - return {FAILED, nullptr}; - } - sqlite3 *db = ret1.second; - std::string sql = "DROP TABLE IF EXISTS INDEXES;"; - if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { - return {FAILED, nullptr}; - } - sql = - "CREATE TABLE INDEXES(" - " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" - ", PAGE_OFFSET_RAW INT NOT NULL, PAGE_OFFSET_RAW_END INT NOT NULL" - ", ROW_GROUP_ID INT NOT NULL, PAGE_ID_BLOB INT NOT NULL" - ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; - - int field_no = 0; - for (const auto &field : fields_) { - uint64_t schema_id = field.first; - auto result = shard_header_.GetSchemaByID(schema_id); - if (result.second != SUCCESS) { - return {FAILED, nullptr}; - } - json json_schema = (result.first->GetSchema())["schema"]; - std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, nullptr}; - } - sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type; - } - sql += ", PRIMARY KEY(ROW_ID"; - for (uint64_t i = 0; i < fields_.size(); ++i) sql += ",INC_" + std::to_string(i); - sql += "));"; - if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { - return {FAILED, nullptr}; - } - - if (CreateShardNameTable(db, shard_name) != SUCCESS) { - return {FAILED, nullptr}; - } - return {SUCCESS, db}; -} - -std::pair> ShardIndexGenerator::GetSchemaDetails(const std::vector &schema_lens, - std::fstream &in) { - std::vector schema_details; - if (schema_count_ <= kMaxSchemaCount) { - for (int sc = 0; sc < schema_count_; ++sc) { - std::vector schema_detail(schema_lens[sc]); - - auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - in.close(); - return {FAILED, {}}; - } - - schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()))); - } - } - - return {SUCCESS, schema_details}; -} - -std::pair ShardIndexGenerator::GenerateRawSQL( - const std::vector> &fields) { - std::string sql = - "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," - "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; - - int field_no = 0; - for (const auto &field : fields) { - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, ""}; - } - sql += ",INC_" + std::to_string(field_no++) + "," + ret.second; - } - sql += - ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," - ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; - field_no = 0; - for (const auto &field : fields) { - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, ""}; - } - sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second; - } - sql += " )"; - return {SUCCESS, sql}; -} - -MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( - sqlite3 *db, const std::string &sql, - const std::vector>> &data) { - sqlite3_stmt *stmt = nullptr; - if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; - return FAILED; - } - for (auto &row : data) { - for (auto &field : row) { - const auto &place_holder = std::get<0>(field); - const auto &field_type = std::get<1>(field); - const auto &field_value = std::get<2>(field); - - int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder)); - if (field_type == "INTEGER") { - if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index - << ", field value: " << std::stoll(field_value); - return FAILED; - } - } else if (field_type == "NUMERIC") { - if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index - << ", field value: " << std::stold(field_value); - return FAILED; - } - } else if (field_type == "NULL") { - if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL"; - return FAILED; - } - } else { - if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value; - return FAILED; - } - } - } - if (sqlite3_step(stmt) != SQLITE_DONE) { - MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; - return FAILED; - } - (void)sqlite3_reset(stmt); - } - (void)sqlite3_finalize(stmt); - return SUCCESS; -} - -MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector> &row_data, - const std::shared_ptr cur_blob_page, - uint64_t &cur_blob_page_offset, std::fstream &in) { - row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); - - // blob data start - row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset)); - auto &io_seekg_blob = - in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); - if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - in.close(); - return FAILED; - } - - uint64_t image_size = 0; - - auto &io_read = in.read(reinterpret_cast(&image_size), kInt64Len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - in.close(); - return FAILED; - } - - cur_blob_page_offset += (kInt64Len + image_size); - row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); - - return SUCCESS; -} - -void ShardIndexGenerator::AddIndexFieldByRawData( - const std::vector &schema_detail, std::vector> &row_data) { - auto result = GenerateIndexFields(schema_detail); - if (result.first == SUCCESS) { - int index = 0; - for (const auto &field : result.second) { - // assume simple field: string , number etc. - row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); - row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); - } - } -} - -ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, - int raw_page_id, std::fstream &in) { - std::vector>> full_data; - - // current raw data page - std::shared_ptr cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first; - - // related blob page - vector> row_group_list = cur_raw_page->GetRowGroupIds(); - - // pair: row_group id, offset in raw data page - for (pair blob_ids : row_group_list) { - // get blob data page according to row_group id - std::shared_ptr cur_blob_page = shard_header_.GetPage(shard_no, blob_id_to_page_id.at(blob_ids.first)).first; - - // offset in current raw data page - auto cur_raw_page_offset = static_cast(blob_ids.second); - uint64_t cur_blob_page_offset = 0; - for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { - std::vector> row_data; - row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); - row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); - row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); - - // raw data start - row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); - - // calculate raw data end - auto &io_seekg = - in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - in.close(); - return {FAILED, {}}; - } - - std::vector schema_lens; - if (schema_count_ <= kMaxSchemaCount) { - for (int sc = 0; sc < schema_count_; sc++) { - uint64_t schema_size = 0; - - auto &io_read = in.read(reinterpret_cast(&schema_size), kInt64Len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - in.close(); - return {FAILED, {}}; - } - - cur_raw_page_offset += (kInt64Len + schema_size); - schema_lens.push_back(schema_size); - } - } - row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset)); - - // Getting schema for getting data for fields - auto st_schema_detail = GetSchemaDetails(schema_lens, in); - if (st_schema_detail.first != SUCCESS) { - return {FAILED, {}}; - } - - // start blob page info - if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) { - return {FAILED, {}}; - } - - // start index field - AddIndexFieldByRawData(st_schema_detail.second, row_data); - full_data.push_back(std::move(row_data)); - } - } - return {SUCCESS, full_data}; -} - -INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector &schema_detail) { - std::vector> fields; - // index fields - std::vector> index_fields = shard_header_.GetFields(); - for (const auto &field : index_fields) { - if (field.first >= schema_detail.size()) { - return {FAILED, {}}; - } - auto field_value = GetValueByField(field.second, schema_detail[field.first]); - if (field_value.first != SUCCESS) { - MS_LOG(ERROR) << "Get value from json by field name failed"; - return {FAILED, {}}; - } - - auto result = shard_header_.GetSchemaByID(field.first); - if (result.second != SUCCESS) { - return {FAILED, {}}; - } - - std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, {}}; - } - - fields.emplace_back(ret.second, field_type, field_value.second); - } - return {SUCCESS, std::move(fields)}; -} - -MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair &db, - const std::vector &raw_page_ids, - const std::map &blob_id_to_page_id) { - // Add index data to database - std::string shard_address = shard_header_.GetShardAddressByID(shard_no); - if (shard_address.empty()) { - MS_LOG(ERROR) << "Shard address is null"; - return FAILED; - } - - std::fstream in; - in.open(common::SafeCStr(shard_address), std::ios::in | std::ios::binary); - if (!in.good()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } - (void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); - for (int raw_page_id : raw_page_ids) { - auto sql = GenerateRawSQL(fields_); - if (sql.first != SUCCESS) { - MS_LOG(ERROR) << "Generate raw SQL failed"; - return FAILED; - } - auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); - if (data.first != SUCCESS) { - MS_LOG(ERROR) << "Generate raw data failed"; - return FAILED; - } - if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { - MS_LOG(ERROR) << "Execute SQL failed"; - return FAILED; - } - MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; - } - (void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr); - in.close(); - - // Close database - if (sqlite3_close(db.second) != SQLITE_OK) { - MS_LOG(ERROR) << "Close database failed"; - return FAILED; - } - db.second = nullptr; - return SUCCESS; -} - -MSRStatus ShardIndexGenerator::WriteToDatabase() { - fields_ = shard_header_.GetFields(); - page_size_ = shard_header_.GetPageSize(); - header_size_ = shard_header_.GetHeaderSize(); - schema_count_ = shard_header_.GetSchemaCount(); - if (shard_header_.GetShardCount() > kMaxShardCount) { - MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount; - return FAILED; - } - task_ = 0; // set two atomic vars to initial value - write_success_ = true; - - // spawn half the physical threads or total number of shards whichever is smaller - const unsigned int num_workers = - std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast(shard_header_.GetShardCount())); - - std::vector threads; - threads.reserve(num_workers); - - for (size_t t = 0; t < threads.capacity(); t++) { - threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this)); - } - - for (size_t t = 0; t < threads.capacity(); t++) { - threads[t].join(); - } - return write_success_ ? SUCCESS : FAILED; -} - -void ShardIndexGenerator::DatabaseWriter() { - int shard_no = task_++; - while (shard_no < shard_header_.GetShardCount()) { - auto db = CreateDatabase(shard_no); - if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { - write_success_ = false; - return; - } - - MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; - - // Pre-processing page information - auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; - - std::map blob_id_to_page_id; - std::vector raw_page_ids; - for (uint64_t i = 0; i < total_pages; ++i) { - std::shared_ptr cur_page = shard_header_.GetPage(shard_no, i).first; - if (cur_page->GetPageType() == "RAW_DATA") { - raw_page_ids.push_back(i); - } else if (cur_page->GetPageType() == "BLOB_DATA") { - blob_id_to_page_id[cur_page->GetPageTypeID()] = i; - } - } - - if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { - write_success_ = false; - return; - } - MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully."; - shard_no = task_++; - } -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc deleted file mode 100644 index 99fa0c447d..0000000000 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ /dev/null @@ -1,1449 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_reader.h" -#include "common/utils.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::DEBUG; -using mindspore::MsLogLevel::ERROR; -using mindspore::MsLogLevel::INFO; - -namespace mindspore { -namespace mindrecord { -template -// convert the string to exactly number type (int32_t/int64_t/float/double) -Type StringToNum(const std::string &str) { - std::istringstream iss(str); - Type num; - iss >> num; - return num; -} - -ShardReader::ShardReader() { - task_id_ = 0; - deliver_id_ = 0; - shard_count_ = 0; - n_consumer_ = 0; - page_size_ = 0; - header_size_ = 0; - num_rows_ = 0; - row_id_ = 0; - num_blocks_ = 0; - block_reader_ = false; - num_padded_ = 0; -} - -std::pair> ShardReader::GetMeta(const std::string &file_path, json &meta_data) { - if (!IsLegalFile(file_path)) { - return {FAILED, {}}; - } - auto ret = ShardHeader::BuildSingleHeader(file_path); - if (ret.first != SUCCESS) { - return {FAILED, {}}; - } - auto header = ret.second; - meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, - {"version", header["version"]}, {"index_fields", header["index_fields"]}, - {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; - return {SUCCESS, header["shard_addresses"]}; -} - -MSRStatus ShardReader::Init(const std::vector &file_paths, bool load_dataset) { - std::string file_path = file_paths[0]; - json first_meta_data = json(); - auto ret = GetMeta(file_path, first_meta_data); - if (ret.first != SUCCESS) { - return FAILED; - } - if (file_paths.size() == 1 && load_dataset == true) { - auto ret2 = GetParentDir(file_path); - if (SUCCESS != ret2.first) { - return FAILED; - } - std::vector real_addresses; - for (const auto &path : ret.second) { - std::string abs_path = ret2.second + string(path); - real_addresses.emplace_back(abs_path); - } - file_paths_ = real_addresses; - } else if (file_paths.size() >= 1 && load_dataset == false) { - file_paths_ = file_paths; - } else { - MS_LOG(ERROR) << "Error in parameter file_path or load_dataset."; - return FAILED; - } - for (const auto &file : file_paths_) { - json meta_data = json(); - auto ret1 = GetMeta(file, meta_data); - if (ret1.first != SUCCESS) { - return FAILED; - } - if (meta_data != first_meta_data) { - MS_LOG(ERROR) << "Mindrecord files meta information is different."; - return FAILED; - } - sqlite3 *db = nullptr; - // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it - int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); - return FAILED; - } - MS_LOG(DEBUG) << "Opened database successfully"; - - string sql = "select NAME from SHARD_NAME;"; - std::vector> name; - char *errmsg = nullptr; - rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &name, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return FAILED; - } else { - MS_LOG(DEBUG) << "Get " << static_cast(name.size()) << " records from index."; - string shardName = GetFileName(file).second; - if (name.empty() || name[0][0] != shardName) { - MS_LOG(ERROR) << "DB file can not match file " << file; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return FAILED; - } - } - database_paths_.push_back(db); - } - ShardHeader sh = ShardHeader(); - if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) { - return FAILED; - } - shard_header_ = std::make_shared(sh); - header_size_ = shard_header_->GetHeaderSize(); - page_size_ = shard_header_->GetPageSize(); - // version < 3.0 - if (first_meta_data["version"] < kVersion) { - shard_column_ = std::make_shared(shard_header_, false); - } else { - shard_column_ = std::make_shared(shard_header_, true); - } - num_rows_ = 0; - auto row_group_summary = ReadRowGroupSummary(); - for (const auto &rg : row_group_summary) { - num_rows_ += std::get<3>(rg); - } - - MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; - - return SUCCESS; -} - -MSRStatus ShardReader::CheckColumnList(const std::vector &selected_columns) { - vector inSchema(selected_columns.size(), 0); - for (auto &p : GetShardHeader()->GetSchemas()) { - auto schema = p->GetSchema()["schema"]; - for (unsigned int i = 0; i < selected_columns.size(); ++i) { - if (schema.find(selected_columns[i]) != schema.end()) { - inSchema[i] = 1; - } - } - } - if (std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; })) { - return FAILED; - } - - return SUCCESS; -} - -MSRStatus ShardReader::Open() { - file_streams_.clear(); - - for (const auto &file : file_paths_) { - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } - MS_LOG(INFO) << "Open shard file successfully."; - file_streams_.push_back(fs); - } - - return SUCCESS; -} - -MSRStatus ShardReader::Open(int n_consumer) { - file_streams_random_ = - std::vector>>(n_consumer, std::vector>()); - for (const auto &file : file_paths_) { - for (int j = 0; j < n_consumer; ++j) { - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } - file_streams_random_[j].push_back(fs); - } - MS_LOG(INFO) << "Open shard file successfully."; - } - - return SUCCESS; -} - -void ShardReader::FileStreamsOperator() { - for (int i = static_cast(file_streams_.size()) - 1; i >= 0; --i) { - if (file_streams_[i] != nullptr) { - file_streams_[i]->close(); - } - } - for (int i = static_cast(file_streams_random_.size()) - 1; i >= 0; --i) { - for (int j = static_cast(file_streams_random_[i].size()) - 1; j >= 0; --j) { - if (file_streams_random_[i][j] != nullptr) { - file_streams_random_[i][j]->close(); - } - } - } - for (int i = static_cast(database_paths_.size()) - 1; i >= 0; --i) { - if (database_paths_[i] != nullptr) { - auto ret = sqlite3_close(database_paths_[i]); - if (ret != SQLITE_OK) { - MS_LOG(ERROR) << "Close db failed. Error code: " << ret << "."; - } - database_paths_[i] = nullptr; - } - } -} - -ShardReader::~ShardReader() { Close(); } - -void ShardReader::Close() { - (void)Finish(); // interrupt reading and stop threads - FileStreamsOperator(); -} - -std::shared_ptr ShardReader::GetShardHeader() const { return shard_header_; } - -std::shared_ptr ShardReader::GetShardColumn() const { return shard_column_; } - -int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } - -int ShardReader::GetNumRows() const { return num_rows_; } - -std::vector> ShardReader::ReadRowGroupSummary() { - std::vector> row_group_summary; - int shard_count = shard_header_->GetShardCount(); - if (shard_count <= 0) { - return row_group_summary; - } - if (shard_count <= kMaxShardCount) { - for (int shard_id = 0; shard_id < shard_count; ++shard_id) { - // return -1 when page's size equals to 0. - auto last_page_id = shard_header_->GetLastPageId(shard_id); - if (static_cast(last_page_id) == -1) { - continue; - } - for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { - const auto &page_t = shard_header_->GetPage(shard_id, page_id); - const auto &page = page_t.first; - if (page->GetPageType() != kPageTypeBlob) continue; - uint64_t start_row_id = page->GetStartRowID(); - if (start_row_id > page->GetEndRowID()) { - return std::vector>(); - } - uint64_t number_of_rows = page->GetEndRowID() - start_row_id; - row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows); - } - } - } - return row_group_summary; -} - -MSRStatus ShardReader::ConvertLabelToJson(const std::vector> &labels, - std::shared_ptr fs, - std::vector>> &offsets, int shard_id, - const std::vector &columns, - std::vector> &column_values) { - for (int i = 0; i < static_cast(labels.size()); ++i) { - uint64_t group_id = std::stoull(labels[i][0]); - uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len; - uint64_t offset_end = std::stoull(labels[i][2]); - offsets[shard_id].emplace_back( - std::vector{static_cast(shard_id), group_id, offset_start, offset_end}); - if (!all_in_index_) { - int raw_page_id = std::stoi(labels[i][3]); - uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len; - uint64_t label_end = std::stoull(labels[i][5]); - auto len = label_end - label_start; - auto label_raw = std::vector(len); - auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - fs->close(); - return FAILED; - } - - auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - fs->close(); - return FAILED; - } - json label_json = json::from_msgpack(label_raw); - json tmp; - if (!columns.empty()) { - for (auto &col : columns) { - if (label_json.find(col) != label_json.end()) { - tmp[col] = label_json[col]; - } - } - } else { - tmp = label_json; - } - column_values[shard_id].emplace_back(tmp); - } else { - json construct_json; - for (unsigned int j = 0; j < columns.size(); ++j) { - // construct json "f1": value - auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; - - // convert the string to base type by schema - if (schema[columns[j]]["type"] == "int32") { - construct_json[columns[j]] = StringToNum(labels[i][j + 3]); - } else if (schema[columns[j]]["type"] == "int64") { - construct_json[columns[j]] = StringToNum(labels[i][j + 3]); - } else if (schema[columns[j]]["type"] == "float32") { - construct_json[columns[j]] = StringToNum(labels[i][j + 3]); - } else if (schema[columns[j]]["type"] == "float64") { - construct_json[columns[j]] = StringToNum(labels[i][j + 3]); - } else { - construct_json[columns[j]] = std::string(labels[i][j + 3]); - } - } - column_values[shard_id].emplace_back(construct_json); - } - } - - return SUCCESS; -} - -MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, - std::vector>> &offsets, - std::vector> &column_values) { - auto db = database_paths_[shard_id]; - std::vector> labels; - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return FAILED; - } - MS_LOG(INFO) << "Get " << static_cast(labels.size()) << " records from shard " << shard_id << " index."; - - std::string file_name = file_paths_[shard_id]; - std::shared_ptr fs = std::make_shared(); - if (!all_in_index_) { - fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } - } - sqlite3_free(errmsg); - return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); -} - -MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { - std::map index_columns; - for (auto &field : GetShardHeader()->GetFields()) { - index_columns[field.second] = field.first; - } - if (index_columns.find(category_field) == index_columns.end()) { - MS_LOG(ERROR) << "Index field " << category_field << " does not exist."; - return FAILED; - } - auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field)); - if (SUCCESS != ret.first) { - return FAILED; - } - std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; - std::vector threads = std::vector(shard_count_); - for (int x = 0; x < shard_count_; x++) { - threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); - } - - for (int x = 0; x < shard_count_; x++) { - threads[x].join(); - } - return SUCCESS; -} - -void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, - std::set &categories) { - if (nullptr == db) { - return; - } - std::vector> columns; - char *errmsg = nullptr; - int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg); - if (ret != SQLITE_OK) { - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; - return; - } - MS_LOG(INFO) << "Get " << static_cast(columns.size()) << " records from shard " << shard_id << " index."; - std::lock_guard lck(shard_locker_); - for (int i = 0; i < static_cast(columns.size()); ++i) { - categories.emplace(columns[i][0]); - } -} - -ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { - std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; - std::vector>> offsets(shard_count_, std::vector>{}); - std::vector> column_values(shard_count_, std::vector{}); - if (all_in_index_) { - for (unsigned int i = 0; i < columns.size(); ++i) { - fields += ','; - auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); - if (ret.first != SUCCESS) { - return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); - } - fields += ret.second; - } - } else { // fetch raw data from Raw page while some field is not index. - fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END "; - } - - std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;"; - - std::vector thread_read_db = std::vector(shard_count_); - for (int x = 0; x < shard_count_; x++) { - thread_read_db[x] = - std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values)); - } - - for (int x = 0; x < shard_count_; x++) { - thread_read_db[x].join(); - } - return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values)); -} - -ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector &columns) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - const std::shared_ptr &page = ret.second; - std::string file_name = file_paths_[shard_id]; - uint64_t page_length = page->GetPageSize(); - uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; - std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id); - - auto status_labels = GetLabels(page->GetPageID(), shard_id, columns); - if (status_labels.first != SUCCESS) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), - std::move(status_labels.second)); -} - -ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, - const std::pair &criteria, - const std::vector &columns) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - vector criteria_list{criteria.first}; - if (CheckColumnList(criteria_list) == FAILED) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - const std::shared_ptr &page = ret.second; - std::string file_name = file_paths_[shard_id]; - uint64_t page_length = page->GetPageSize(); - uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; - std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); - - auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); - if (status_labels.first != SUCCESS) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - - return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), - std::move(status_labels.second)); -} - -int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) { - auto *records = static_cast> *>(p_data); - if (num_fields > 0 && num_fields <= kMaxFieldCount) { - for (int i = 0; i < num_fields; ++i) - if (p_fields[i] == nullptr) p_fields[i] = const_cast(""); - } - records->emplace_back(p_fields, p_fields + num_fields); - return 0; -} - -std::vector> ShardReader::GetImageOffset(int page_id, int shard_id, - const std::pair &criteria) { - auto db = database_paths_[shard_id]; - - std::string sql = - "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); - - // whether use index search - if (!criteria.first.empty()) { - auto schema = shard_header_->GetSchemas()[0]->GetSchema(); - - // not number field should add '' in sql - if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { - sql += - " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; - } else { - sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" + - criteria.second + "'"; - } - } - sql += ";"; - std::vector> image_offsets; - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &image_offsets, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return std::vector>(); - } else { - MS_LOG(DEBUG) << "Get " << static_cast(image_offsets.size()) << "records from index."; - } - std::vector> res; - for (int i = static_cast(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector{0, 0}); - for (int i = 0; i < static_cast(image_offsets.size()); i++) { - const auto &image_offset = image_offsets[i]; - res[i][0] = std::stoull(image_offset[0]) + kInt64Len; - res[i][1] = std::stoull(image_offset[1]); - } - sqlite3_free(errmsg); - return res; -} - -std::pair> ShardReader::GetBlobFields() { - std::vector blob_fields; - for (auto &p : GetShardHeader()->GetSchemas()) { - // assume one schema - const auto &fields = p->GetBlobFields(); - blob_fields.assign(fields.begin(), fields.end()); - break; - } - return std::make_pair(kCV, blob_fields); -} - -void ShardReader::CheckIfColumnInIndex(const std::vector &columns) { - // assume different schemas do not contain same key. - if (columns.empty()) { - all_in_index_ = false; - return; - } - for (auto &field : GetShardHeader()->GetFields()) { - column_schema_id_[field.second] = field.first; - } - for (auto &col : columns) { - if (column_schema_id_.find(col) == column_schema_id_.end()) { - all_in_index_ = false; - return; - } - } -} - -MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria, - std::vector> &labels) { - sqlite3_stmt *stmt = nullptr; - if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not prepare statement"; - return FAILED; - } - int index = sqlite3_bind_parameter_index(stmt, ":criteria"); - if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << criteria; - return FAILED; - } - int rc = sqlite3_step(stmt); - while (rc != SQLITE_DONE) { - vector tmp; - int ncols = sqlite3_column_count(stmt); - for (int i = 0; i < ncols; i++) { - tmp.emplace_back(reinterpret_cast(sqlite3_column_text(stmt, i))); - } - labels.push_back(tmp); - rc = sqlite3_step(stmt); - } - (void)sqlite3_finalize(stmt); - return SUCCESS; -} - -std::pair> ShardReader::GetLabelsFromBinaryFile( - int shard_id, const std::vector &columns, const std::vector> &label_offsets) { - std::string file_name = file_paths_[shard_id]; - std::vector res; - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "File could not opened"; - return {FAILED, {}}; - } - - // init the return - for (unsigned int i = 0; i < label_offsets.size(); ++i) { - res.emplace_back(json{}); - } - - for (unsigned int i = 0; i < label_offsets.size(); ++i) { - const auto &labelOffset = label_offsets[i]; - uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len; - uint64_t label_end = std::stoull(labelOffset[2]); - int raw_page_id = std::stoi(labelOffset[0]); - auto len = label_end - label_start; - auto label_raw = std::vector(len); - auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - fs->close(); - return {FAILED, {}}; - } - - auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - fs->close(); - return {FAILED, {}}; - } - - json label_json = json::from_msgpack(label_raw); - json tmp = label_json; - for (auto &col : columns) { - if (label_json.find(col) != label_json.end()) { - tmp[col] = label_json[col]; - } - } - res[i] = tmp; - } - return {SUCCESS, res}; -} - -std::pair> ShardReader::GetLabelsFromPage( - int page_id, int shard_id, const std::vector &columns, - const std::pair &criteria) { - // get page info from sqlite - auto db = database_paths_[shard_id]; - std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " + - std::to_string(page_id); - std::vector> label_offsets; - if (!criteria.first.empty()) { - sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria"; - if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) { - return {FAILED, {}}; - } - } else { - sql += ";"; - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &label_offsets, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return {FAILED, {}}; - } - MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index."; - sqlite3_free(errmsg); - } - // get labels from binary file - return GetLabelsFromBinaryFile(shard_id, columns, label_offsets); -} - -std::pair> ShardReader::GetLabels(int page_id, int shard_id, - const std::vector &columns, - const std::pair &criteria) { - if (all_in_index_) { - auto db = database_paths_[shard_id]; - std::string fields; - for (unsigned int i = 0; i < columns.size(); ++i) { - if (i > 0) fields += ','; - uint64_t schema_id = column_schema_id_[columns[i]]; - fields += columns[i] + "_" + std::to_string(schema_id); - } - if (fields.empty()) fields = "*"; - std::vector> labels; - std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); - if (!criteria.first.empty()) { - sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria"; - if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) { - return {FAILED, {}}; - } - } else { - sql += ";"; - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return {FAILED, {}}; - } else { - MS_LOG(DEBUG) << "Get " << static_cast(labels.size()) << "records from index."; - } - sqlite3_free(errmsg); - } - std::vector ret; - for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{}); - for (unsigned int i = 0; i < labels.size(); ++i) { - json construct_json; - for (unsigned int j = 0; j < columns.size(); ++j) { - // construct json "f1": value - auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; - - // convert the string to base type by schema - if (schema[columns[j]]["type"] == "int32") { - construct_json[columns[j]] = StringToNum(labels[i][j]); - } else if (schema[columns[j]]["type"] == "int64") { - construct_json[columns[j]] = StringToNum(labels[i][j]); - } else if (schema[columns[j]]["type"] == "float32") { - construct_json[columns[j]] = StringToNum(labels[i][j]); - } else if (schema[columns[j]]["type"] == "float64") { - construct_json[columns[j]] = StringToNum(labels[i][j]); - } else { - construct_json[columns[j]] = std::string(labels[i][j]); - } - } - ret[i] = construct_json; - } - return {SUCCESS, ret}; - } - return GetLabelsFromPage(page_id, shard_id, columns, criteria); -} - -bool ResortRowGroups(std::tuple a, std::tuple b) { - return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); -} - -MSRStatus ShardReader::Finish() { - { - std::lock_guard lck(mtx_delivery_); - interrupt_ = true; - } - cv_delivery_.notify_all(); - - // Wait for all threads to finish - for (auto &i_thread : thread_set_) { - if (i_thread.joinable()) { - i_thread.join(); - } - } - return SUCCESS; -} - -int64_t ShardReader::GetNumClasses(const std::string &category_field) { - auto shard_count = file_paths_.size(); - auto index_fields = shard_header_->GetFields(); - - std::map map_schema_id_fields; - for (auto &field : index_fields) { - map_schema_id_fields[field.second] = field.first; - } - - if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) { - MS_LOG(ERROR) << "Field " << category_field << " does not exist."; - return -1; - } - auto ret = - ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); - if (SUCCESS != ret.first) { - return -1; - } - std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; - std::vector threads = std::vector(shard_count); - std::set categories; - for (int x = 0; x < shard_count; x++) { - sqlite3 *db = nullptr; - int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); - if (SQLITE_OK != rc) { - MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); - return -1; - } - threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); - } - - for (int x = 0; x < shard_count; x++) { - threads[x].join(); - } - return categories.size(); -} - -MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, - const std::shared_ptr &ops, int64_t *count, const int num_padded) { - if (SUCCESS != Init(file_paths, load_dataset)) { - return FAILED; - } - int64_t num_samples = num_rows_; - bool root = true; - std::stack> stack_ops; - std::shared_ptr op(ops); - while (op != nullptr) { - stack_ops.push(op); - op = op->GetChildOp(); - } - while (!stack_ops.empty()) { - op = stack_ops.top(); - stack_ops.pop(); - if (std::dynamic_pointer_cast(op)) { - num_samples = op->GetNumSamples(num_samples, 0); - if (num_padded > 0 && root == true) { - num_samples += num_padded; - MS_LOG(DEBUG) << "Padding samples work on shuffle sampler."; - root = false; - } - } else if (std::dynamic_pointer_cast(op)) { - auto category_op = std::dynamic_pointer_cast(op); - std::string category_field = category_op->GetCategoryField(); - auto num_classes = GetNumClasses(category_field); - num_samples = category_op->GetNumSamples(num_samples, num_classes); - } else if (std::dynamic_pointer_cast(op)) { - if (std::dynamic_pointer_cast(op)) { - auto sampler_op = std::dynamic_pointer_cast(op); - if (root == true) { - sampler_op->SetNumPaddedSamples(num_padded); - num_samples = op->GetNumSamples(num_samples, 0); - if (-1 == num_samples) { - MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; - return FAILED; - } - root = false; - } - } else { - num_samples = op->GetNumSamples(num_samples, 0); - } - } else { - if (num_padded > 0) num_samples += num_padded; - } - } - *count = num_samples; - return SUCCESS; -} - -MSRStatus ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, - const std::vector &selected_columns, - const std::vector> &operators, const bool &block_reader, - int num_padded) { - // Open file and set header by ShardReader - auto ret = Init(file_paths, load_dataset); - if (SUCCESS != ret) { - return ret; - } - auto thread_limit = GetMaxThreadNum(); - if (n_consumer > thread_limit) { - n_consumer = thread_limit; - } - if (n_consumer < kMinConsumerCount) { - n_consumer = kMinConsumerCount; - } - vector blob_fields = GetBlobFields().second; - for (unsigned int i = 0; i < selected_columns.size(); ++i) { - if (!std::any_of(blob_fields.begin(), blob_fields.end(), - [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { - selected_columns_.push_back(selected_columns[i]); - } - } - selected_columns_ = selected_columns; - - if (CheckColumnList(selected_columns_) == FAILED) { - MS_LOG(ERROR) << "Illegal column list"; - return ILLEGAL_COLUMN_LIST; - } - - // Initialize argument - shard_count_ = static_cast(file_paths_.size()); - n_consumer_ = n_consumer; - num_padded_ = num_padded; - - operators_ = operators; - - if (block_reader) { - block_reader_ = true; - if (Open() == FAILED) { - return FAILED; - } - delivery_block_ = std::vector>, std::vector>>>( - kNumPageInBuffer, std::shared_ptr>, std::vector>>{}); - buf_ = std::vector>(kNumPageInBuffer, std::vector(page_size_)); - } else { - block_reader_ = false; - if (Open(n_consumer) == FAILED) { - return FAILED; - } - } - return SUCCESS; -} - -MSRStatus ShardReader::OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer, - const std::vector &selected_columns, - const std::vector> &operators) { - // Open file and set header by ShardReader - if (SUCCESS != Init(file_paths, load_dataset)) { - return FAILED; - } - // should remove blob field from selected_columns when call from python - std::vector columns(selected_columns); - auto blob_fields = GetBlobFields().second; - for (auto &blob_field : blob_fields) { - auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); - if (it != selected_columns.end()) { - columns.erase(columns.begin() + std::distance(selected_columns.begin(), it)); - } - } - if (CheckColumnList(columns) == FAILED) { - MS_LOG(ERROR) << "Illegal column list"; - return FAILED; - } - if (Open(n_consumer) == FAILED) { - return FAILED; - } - // Initialize argument - shard_count_ = static_cast(file_paths_.size()); - n_consumer_ = n_consumer; - - // Initialize columns which will be read - selected_columns_ = selected_columns; - operators_ = operators; - - return SUCCESS; -} - -MSRStatus ShardReader::Launch(bool isSimpleReader) { - // Get all row groups' info - auto row_group_summary = ReadRowGroupSummary(); - - // Sort row group by (group_id, shard_id), prepare for parallel reading - std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); - if (CreateTasks(row_group_summary, operators_) != SUCCESS) { - MS_LOG(ERROR) << "Failed to launch read threads."; - interrupt_ = true; - return FAILED; - } - if (isSimpleReader) return SUCCESS; - // Start provider consumer threads - thread_set_ = std::vector(n_consumer_); - if (n_consumer_ <= 0 || n_consumer_ > kMaxConsumerCount) { - return FAILED; - } - - for (int x = 0; x < n_consumer_; ++x) { - if (block_reader_) { - thread_set_[x] = std::thread(&ShardReader::ConsumerByBlock, this, x); - } else { - thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); - } - } - - MS_LOG(INFO) << "Launch read thread successfully."; - return SUCCESS; -} - -MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators) { - CheckIfColumnInIndex(selected_columns_); - for (const auto &rg : row_group_summary) { - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - auto n_Rows = std::get<3>(rg); - tasks_.InsertTask(TaskType::kCommonTask, shard_id, group_id, std::vector{n_Rows}, json{}); - } - return SUCCESS; -} - -MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, - const std::shared_ptr &op) { - CheckIfColumnInIndex(selected_columns_); - auto category_op = std::dynamic_pointer_cast(op); - auto categories = category_op->GetCategories(); - int64_t num_elements = category_op->GetNumElements(); - if (num_elements <= 0) { - MS_LOG(ERROR) << "Parameter num_element is not positive"; - return FAILED; - } - if (categories.empty() == true) { - std::string category_field = category_op->GetCategoryField(); - int64_t num_categories = category_op->GetNumCategories(); - if (num_categories <= 0) { - MS_LOG(ERROR) << "Parameter num_categories is not positive"; - return FAILED; - } - std::set categories_set; - auto ret = GetAllClasses(category_field, categories_set); - if (SUCCESS != ret) { - return FAILED; - } - int i = 0; - for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { - categories.emplace_back(category_field, *it); - i++; - } - } - // Generate task list, a task will create a batch - std::vector categoryTasks(categories.size()); - for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { - int category_index = 0; - for (const auto &rg : row_group_summary) { - if (category_index >= num_elements) break; - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - - auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); - if (SUCCESS != std::get<0>(details)) { - return FAILED; - } - auto offsets = std::get<4>(details); - - auto number_of_rows = offsets.size(); - for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { - if (category_index < num_elements) { - categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart], - std::get<5>(details)[iStart]); - category_index++; - } - } - } - MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; - } - tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); - if (SUCCESS != (*category_op)(tasks_)) { - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, - const std::vector> &operators) { - CheckIfColumnInIndex(selected_columns_); - - auto ret = ReadAllRowGroup(selected_columns_); - if (std::get<0>(ret) != SUCCESS) { - return FAILED; - } - auto offsets = std::get<1>(ret); - auto local_columns = std::get<2>(ret); - if (shard_count_ <= kMaxShardCount) { - for (int shard_id = 0; shard_id < shard_count_; shard_id++) { - for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { - tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], - std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, - local_columns[shard_id][i]); - } - } - } else { - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardReader::CreateTasks(const std::vector> &row_group_summary, - const std::vector> &operators) { - if (block_reader_) { - if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { - return FAILED; - } - } else { - int category_operator = -1; - for (uint32_t i = 0; i < operators.size(); ++i) { - const auto &op = operators[i]; - if (std::dynamic_pointer_cast(op)) { - category_operator = static_cast(i); - break; - } - } - if (-1 == category_operator) { - if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { - return FAILED; - } - if (num_padded_ > 0) { - for (int i = 0; i < num_padded_; ++i) { - tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json()); - } - } - } else { - if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { - return FAILED; - } - } - } - - for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) { - const auto &op = operators[operator_no]; - if (std::dynamic_pointer_cast(op)) continue; - if (block_reader_ && std::dynamic_pointer_cast(op)) continue; - if (SUCCESS != (*op)(tasks_)) { - return FAILED; - } - } - - if (tasks_.permutation_.empty()) tasks_.MakePerm(); - num_rows_ = block_reader_ ? tasks_.SizeOfRows() : tasks_.Size(); - num_blocks_ = block_reader_ ? tasks_.Size() : 0; - MS_LOG(INFO) << "Total rows is " << num_rows_; - return SUCCESS; -} - -TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) { - // All tasks are done - if (task_id >= static_cast(tasks_.Size())) { - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - - // Pick up task from task list - auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); - - // check task type - auto task_type = std::get<0>(task); - if (task_type == TaskType::kPaddedTask) { - return std::make_pair(SUCCESS, - std::make_pair(TaskType::kPaddedTask, std::vector, json>>())); - } - - auto shard_id = std::get<0>(std::get<1>(task)); - auto group_id = std::get<1>(std::get<1>(task)); - auto addr = std::get<2>(task); - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - const std::shared_ptr &page = ret.second; - - // Pack image list - std::vector images(addr[1] - addr[0]); - auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0]; - - auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_random_[consumer_id][shard_id]->close(); - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - - auto &io_read = - file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast(&images[0]), addr[1] - addr[0]); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_random_[consumer_id][shard_id]->close(); - return std::make_pair(FAILED, - std::pair(TaskType::kCommonTask, std::vector, json>>())); - } - - // Deliver batch data to output map - std::vector, json>> batch; - batch.emplace_back(std::move(images), std::move(std::get<3>(task))); - - return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch))); -} - -MSRStatus ShardReader::ConsumerByRow(int consumer_id) { - // Set thread name -#if !defined(_WIN32) && !defined(_WIN64) - auto thread_id = kThreadName + std::to_string(consumer_id); - prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); -#endif - - // Loop forever - for (;;) { - int task_id = 0; - - // Get next task ID - task_id = task_id_++; - - // All tasks are done - if (task_id >= static_cast(tasks_.Size())) { - return FAILED; - } - const auto &ret = ConsumerOneTask(task_id, consumer_id); - if (SUCCESS != ret.first) { - return FAILED; - } - const auto &batch = (ret.second).second; - // Hanging if maximum map size exceeded - // otherwise, set batch data in map - { - std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; }); - if (interrupt_) { - return SUCCESS; - } - delivery_map_[task_id] = std::make_shared, json>>>(std::move(batch)); - } - cv_iterator_.notify_one(); - } -} - -MSRStatus ShardReader::ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, - const int &buf_id) { - auto &io_seekg = file_streams_[shard_id]->seekg(page_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf_[buf_id][0]), page_length); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { - // Set thread name -#if !defined(_WIN32) && !defined(_WIN64) - auto thread_id = kThreadName + std::to_string(consumer_id); - prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); -#endif - - // Loop forever - for (;;) { - int task_id = 0; - - // Get next task ID - task_id = task_id_++; - - // All tasks are done, either quit or repeat again - if (task_id >= num_blocks_) { - std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [this] { return interrupt_ || task_id_ < num_blocks_; }); - if (interrupt_) { - return SUCCESS; - } - continue; - } - - // Pick up task from task list - auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); - - auto shard_id = std::get<0>(std::get<1>(task)); - auto group_id = std::get<1>(std::get<1>(task)); - auto row_group_brief = ReadRowGroupBrief(group_id, shard_id, selected_columns_); - if (SUCCESS != std::get<0>(row_group_brief)) { - return FAILED; - } - auto page_length = std::get<2>(row_group_brief); - auto page_offset = std::get<3>(row_group_brief); - - MS_LOG(DEBUG) << "Block task " << task_id << tasks_.permutation_[task_id] << ", shard " << shard_id << ", group " - << group_id << ", page length " << page_length << ", page offset " << page_offset; - - // Deliver block data to output map - auto offset_and_labels = std::make_pair(std::get<4>(row_group_brief), std::get<5>(row_group_brief)); - - int deliver_id = deliver_id_; - // Hanging if maximum map size exceeded otherwise, set batch data in buffer - { - std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id < deliver_id_ + kNumPageInBuffer; }); - if (interrupt_) { - return SUCCESS; - } - } - - auto buf_id = task_id % kNumPageInBuffer; - delivery_block_[buf_id] = - std::make_shared>, std::vector>>(offset_and_labels); - - // Read blob - if (ReadBlob(shard_id, page_offset, page_length, buf_id) != SUCCESS) { - return FAILED; - } - - { - std::unique_lock lck(mtx_delivery_); - delivery_block_set_.insert(task_id); - } - cv_iterator_.notify_one(); - } -} - -std::shared_ptr, json>>> ShardReader::GetRowFromBuffer(int buf_id, - int rowId) { - auto &blob_page = buf_[buf_id]; - auto &offsets = (*delivery_block_[buf_id]).first; - auto &labels = (*delivery_block_[buf_id]).second; - auto &addr_start = offsets[rowId][0]; - auto &addr_end = offsets[rowId][1]; - std::vector images(blob_page.begin() + addr_start, blob_page.begin() + addr_end); - std::vector, json>> batch; - batch.emplace_back(std::move(images), std::move(labels[rowId])); - return std::make_shared, json>>>(std::move(batch)); -} - -std::vector, json>> ShardReader::GetBlockNext() { - if (deliver_id_ >= num_blocks_) { - return std::vector, json>>(); - } - - if (row_id_ == 0) { - std::unique_lock lck(mtx_delivery_); - cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_block_set_.count(deliver_id_) > 0); }); - - if (interrupt_) { - return std::vector, json>>(); - } - } - auto buf_id = deliver_id_ % kNumPageInBuffer; - auto res = GetRowFromBuffer(buf_id, row_id_); - - row_id_++; - if (row_id_ == (*delivery_block_[buf_id]).first.size()) { - row_id_ = 0; - { - std::unique_lock lck(mtx_delivery_); - delivery_block_set_.erase(deliver_id_++); - } - cv_delivery_.notify_all(); - } - - return *res; -} - -std::vector, json>> ShardReader::GetNext() { - if (interrupt_) { - return std::vector, json>>(); - } - if (block_reader_) return GetBlockNext(); - if (deliver_id_ >= static_cast(tasks_.Size())) { - return std::vector, json>>(); - } - - std::shared_ptr, json>>> res; - { - std::unique_lock lck(mtx_delivery_); - cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); }); - if (interrupt_) { - return std::vector, json>>(); - } - res = delivery_map_[deliver_id_]; - delivery_map_.erase(deliver_id_++); - } - - cv_delivery_.notify_all(); - - return *res; -} - -std::pair, json>>> ShardReader::GetNextById( - const int64_t &task_id, const int32_t &consumer_id) { - if (interrupt_) { - return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); - } - if (block_reader_) { - return std::make_pair(TaskType::kCommonTask, GetBlockNext()); - } - const auto &ret = ConsumerOneTask(task_id, consumer_id); - if (SUCCESS != ret.first) { - return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); - } - return std::move(ret.second); -} - -std::pair>> ShardReader::UnCompressBlob( - const std::vector &raw_blob_data) { - auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_; - auto blob_fields = GetBlobFields().second; - std::vector> blob_data; - for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) { - if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue; - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0; - auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << "."; - return {FAILED, std::vector>(blob_fields.size(), std::vector())}; - } - if (data == nullptr) { - data = reinterpret_cast(data_ptr.get()); - } - std::vector column(data, data + (n_bytes / sizeof(unsigned char))); - blob_data.push_back(column); - } - return {SUCCESS, blob_data}; -} - -std::vector>, pybind11::object>> ShardReader::GetNextPy() { - auto res = GetNext(); - vector>, pybind11::object>> data; - std::transform(res.begin(), res.end(), std::back_inserter(data), - [this](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - auto ret = UnCompressBlob(std::get<0>(item)); - return std::make_tuple(ret.second, std::move(obj)); - }); - return data; -} - -void ShardReader::Reset() { - { - std::lock_guard lck(mtx_delivery_); - task_id_ = 0; - deliver_id_ = 0; - } - cv_delivery_.notify_all(); -} - -void ShardReader::ShuffleTask() { - if (block_reader_) return; - // exist shuffle and distributed sampler in ops, skip shuffle - bool has_sharding = false; - for (const auto &op : operators_) { - if (std::dynamic_pointer_cast(op)) { - has_sharding = true; - } - } - for (const auto &op : operators_) { - if (std::dynamic_pointer_cast(op) && has_sharding == false) { - if (SUCCESS != (*op)(tasks_)) { - MS_LOG(WARNING) << "Redo randomSampler failed."; - } - } else if (std::dynamic_pointer_cast(op)) { - if (SUCCESS != (*op)(tasks_)) { - MS_LOG(WARNING) << "Redo distributeSampler failed."; - } - } - } - if (tasks_.permutation_.empty()) tasks_.MakePerm(); -} - -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/mindrecord/io/shard_segment.cc deleted file mode 100644 index fb1120b178..0000000000 --- a/mindspore/ccsrc/mindrecord/io/shard_segment.cc +++ /dev/null @@ -1,385 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_segment.h" -#include "common/utils.h" - -#include "./securec.h" -#include "mindrecord/include/common/shard_utils.h" -#include "pybind11/pybind11.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; -using mindspore::MsLogLevel::INFO; - -namespace mindspore { -namespace mindrecord { -ShardSegment::ShardSegment() { SetAllInIndex(false); } - -std::pair> ShardSegment::GetCategoryFields() { - // Skip if already populated - if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_}; - - std::string sql = "PRAGMA table_info(INDEXES);"; - std::vector> field_names; - - char *errmsg = nullptr; - int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(database_paths_[0]); - database_paths_[0] = nullptr; - return {FAILED, vector{}}; - } else { - MS_LOG(INFO) << "Get " << static_cast(field_names.size()) << " records from index."; - } - - uint32_t idx = kStartFieldId; - while (idx < field_names.size()) { - if (field_names[idx].size() < 2) { - sqlite3_free(errmsg); - sqlite3_close(database_paths_[0]); - database_paths_[0] = nullptr; - return {FAILED, vector{}}; - } - candidate_category_fields_.push_back(field_names[idx][1]); - idx += 2; - } - sqlite3_free(errmsg); - return {SUCCESS, candidate_category_fields_}; -} - -MSRStatus ShardSegment::SetCategoryField(std::string category_field) { - if (GetCategoryFields().first != SUCCESS) { - MS_LOG(ERROR) << "Get candidate category field failed"; - return FAILED; - } - category_field = category_field + "_0"; - if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), - [category_field](std::string x) { return x == category_field; })) { - current_category_field_ = category_field; - return SUCCESS; - } - MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field."; - return FAILED; -} - -std::pair ShardSegment::ReadCategoryInfo() { - MS_LOG(INFO) << "Read category begin"; - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info failed"; - return {FAILED, ""}; - } - // Convert category info to json string - auto category_json_string = ToJsonForCategory(ret.second); - - MS_LOG(INFO) << "Read category end"; - - return {SUCCESS, category_json_string}; -} - -std::pair>> ShardSegment::WrapCategoryInfo() { - std::map counter; - - std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + - ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; - - for (auto &db : database_paths_) { - std::vector> field_count; - - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return {FAILED, std::vector>()}; - } else { - MS_LOG(INFO) << "Get " << static_cast(field_count.size()) << " records from index."; - } - - for (const auto &field : field_count) { - counter[field[0]] += std::stoi(field[1]); - } - sqlite3_free(errmsg); - } - - int idx = 0; - std::vector> category_vec(counter.size()); - (void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple item) { - return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); - }); - return {SUCCESS, std::move(category_vec)}; -} - -std::string ShardSegment::ToJsonForCategory(const std::vector> &tri_vec) { - std::vector category_json_vec; - for (auto q : tri_vec) { - json j; - j["id"] = std::get<0>(q); - j["name"] = std::get<1>(q); - j["count"] = std::get<2>(q); - - category_json_vec.emplace_back(j); - } - - json j_vec(category_json_vec); - json category_info; - category_info["key"] = current_category_field_; - category_info["categories"] = j_vec; - return category_info.dump(); -} - -std::pair>> ShardSegment::ReadAtPageById(int64_t category_id, - int64_t page_no, - int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector>{}}; - } - if (category_id >= static_cast(ret.second.size()) || category_id < 0) { - MS_LOG(ERROR) << "Illegal category id, id: " << category_id; - return {FAILED, std::vector>{}}; - } - int total_rows_in_category = std::get<2>(ret.second[category_id]); - // Quit if category not found or page number is out of range - if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 || - page_no * n_rows_of_page >= total_rows_in_category) { - MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page; - return {FAILED, std::vector>{}}; - } - - std::vector> page; - auto row_group_summary = ReadRowGroupSummary(); - - uint64_t i_start = page_no * n_rows_of_page; - uint64_t i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); - uint64_t idx = 0; - for (const auto &rg : row_group_summary) { - if (idx >= i_end) break; - - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - auto details = ReadRowGroupCriteria( - group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); - if (SUCCESS != std::get<0>(details)) { - return {FAILED, std::vector>{}}; - } - auto offsets = std::get<4>(details); - uint64_t number_of_rows = offsets.size(); - if (idx + number_of_rows < i_start) { - idx += number_of_rows; - continue; - } - - for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { - if (idx >= i_start && idx < i_end) { - auto ret1 = PackImages(group_id, shard_id, offsets[i]); - if (SUCCESS != ret1.first) { - return {FAILED, std::vector>{}}; - } - page.push_back(std::move(ret1.second)); - } - } - } - - return {SUCCESS, std::move(page)}; -} - -std::pair> ShardSegment::PackImages(int group_id, int shard_id, - std::vector offset) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return {FAILED, std::vector()}; - } - const std::shared_ptr &blob_page = ret.second; - - // Pack image list - std::vector images(offset[1] - offset[0]); - auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; - auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_random_[0][shard_id]->close(); - return {FAILED, {}}; - } - - auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast(&images[0]), offset[1] - offset[0]); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_random_[0][shard_id]->close(); - return {FAILED, {}}; - } - - return {SUCCESS, std::move(images)}; -} - -std::pair>> ShardSegment::ReadAtPageByName(std::string category_name, - int64_t page_no, - int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector>{}}; - } - for (const auto &categories : ret.second) { - if (std::get<1>(categories) == category_name) { - auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page); - return result; - } - } - - return {FAILED, std::vector>()}; -} - -std::pair, json>>> ShardSegment::ReadAllAtPageById( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS || category_id >= static_cast(ret.second.size())) { - MS_LOG(ERROR) << "Illegal category id, id: " << category_id; - return {FAILED, std::vector, json>>{}}; - } - int total_rows_in_category = std::get<2>(ret.second[category_id]); - // Quit if category not found or page number is out of range - if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) { - MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page; - return {FAILED, std::vector, json>>{}}; - } - - std::vector, json>> page; - auto row_group_summary = ReadRowGroupSummary(); - - int i_start = page_no * n_rows_of_page; - int i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); - int idx = 0; - for (const auto &rg : row_group_summary) { - if (idx >= i_end) break; - - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - auto details = ReadRowGroupCriteria( - group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); - if (SUCCESS != std::get<0>(details)) { - return {FAILED, std::vector, json>>{}}; - } - auto offsets = std::get<4>(details); - auto labels = std::get<5>(details); - - int number_of_rows = offsets.size(); - if (idx + number_of_rows < i_start) { - idx += number_of_rows; - continue; - } - - if (number_of_rows > static_cast(labels.size())) { - MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows; - return {FAILED, std::vector, json>>{}}; - } - for (int i = 0; i < number_of_rows; ++i, ++idx) { - if (idx >= i_start && idx < i_end) { - auto ret1 = PackImages(group_id, shard_id, offsets[i]); - if (SUCCESS != ret1.first) { - return {FAILED, std::vector, json>>{}}; - } - page.emplace_back(std::move(ret1.second), std::move(labels[i])); - } - } - } - return {SUCCESS, std::move(page)}; -} - -std::pair, json>>> ShardSegment::ReadAllAtPageByName( - std::string category_name, int64_t page_no, int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector, json>>{}}; - } - - // category_name to category_id - int64_t category_id = -1; - for (const auto &categories : ret.second) { - std::string categories_name = std::get<1>(categories); - - if (categories_name == category_name) { - category_id = std::get<0>(categories); - break; - } - } - - if (category_id == -1) { - return {FAILED, std::vector, json>>{}}; - } - - return ReadAllAtPageById(category_id, page_no, n_rows_of_page); -} - -std::pair, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { - auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page); - if (res.first != SUCCESS) { - return {FAILED, std::vector, pybind11::object>>{}}; - } - - vector, pybind11::object>> json_data; - std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), - [](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - return std::make_tuple(std::get<0>(item), std::move(obj)); - }); - return {SUCCESS, std::move(json_data)}; -} - -std::pair, pybind11::object>>> ShardSegment::ReadAtPageByNamePy( - std::string category_name, int64_t page_no, int64_t n_rows_of_page) { - auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page); - if (res.first != SUCCESS) { - return {FAILED, std::vector, pybind11::object>>{}}; - } - vector, pybind11::object>> json_data; - std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), - [](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - return std::make_tuple(std::get<0>(item), std::move(obj)); - }); - return {SUCCESS, std::move(json_data)}; -} - -std::pair> ShardSegment::GetBlobFields() { - std::vector blob_fields; - for (auto &p : GetShardHeader()->GetSchemas()) { - // assume one schema - const auto &fields = p->GetBlobFields(); - blob_fields.assign(fields.begin(), fields.end()); - break; - } - return std::make_pair(kCV, blob_fields); -} - -std::string ShardSegment::CleanUp(std::string field_name) { - while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back(); - field_name.pop_back(); - return field_name; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc deleted file mode 100644 index 913caab550..0000000000 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ /dev/null @@ -1,1254 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_writer.h" -#include "common/utils.h" -#include "mindrecord/include/common/shard_utils.h" -#include "./securec.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::DEBUG; -using mindspore::MsLogLevel::ERROR; -using mindspore::MsLogLevel::INFO; - -namespace mindspore { -namespace mindrecord { -ShardWriter::ShardWriter() - : shard_count_(1), - header_size_(kDefaultHeaderSize), - page_size_(kDefaultPageSize), - row_count_(0), - schema_count_(1) {} - -ShardWriter::~ShardWriter() { - for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { - file_streams_[i]->close(); - } -} - -MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { - // Get full path from file name - for (const auto &path : paths) { - if (!CheckIsValidUtf8(path)) { - MS_LOG(ERROR) << "The filename contains invalid uft-8 data: " << path << "."; - return FAILED; - } - char resolved_path[PATH_MAX] = {0}; - char buf[PATH_MAX] = {0}; - if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Secure func failed"; - return FAILED; - } -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Invalid file path"; - return FAILED; - } - if (_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX) == nullptr) { - MS_LOG(DEBUG) << "Path " << resolved_path; - } -#else - if (realpath(dirname(&(buf[0])), resolved_path) == nullptr) { - MS_LOG(ERROR) << "Invalid file path"; - return FAILED; - } - if (realpath(common::SafeCStr(path), resolved_path) == nullptr) { - MS_LOG(DEBUG) << "Path " << resolved_path; - } -#endif - file_paths_.emplace_back(string(resolved_path)); - } - return SUCCESS; -} - -MSRStatus ShardWriter::OpenDataFiles(bool append) { - // Open files - for (const auto &file : file_paths_) { - std::shared_ptr fs = std::make_shared(); - if (!append) { - // if not append and mindrecord file exist, return FAILED - fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); - if (fs->good()) { - MS_LOG(ERROR) << "MindRecord file already existed."; - fs->close(); - return FAILED; - } - fs->close(); - - // open the mindrecord file to write - fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc); - if (!fs->good()) { - MS_LOG(ERROR) << "MindRecord file could not opened."; - return FAILED; - } - } else { - // open the mindrecord file to append - fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "MindRecord file could not opened for append."; - return FAILED; - } - } - MS_LOG(INFO) << "Open shard file successfully."; - file_streams_.push_back(fs); - } - return SUCCESS; -} - -MSRStatus ShardWriter::RemoveLockFile() { - // Remove temporary file - int ret = std::remove(pages_file_.c_str()); - if (ret == 0) { - MS_LOG(DEBUG) << "Remove page file."; - } - - ret = std::remove(lock_file_.c_str()); - if (ret == 0) { - MS_LOG(DEBUG) << "Remove lock file."; - } - return SUCCESS; -} - -MSRStatus ShardWriter::InitLockFile() { - if (file_paths_.size() == 0) { - MS_LOG(ERROR) << "File path not initialized."; - return FAILED; - } - - lock_file_ = file_paths_[0] + kLockFileSuffix; - pages_file_ = file_paths_[0] + kPageFileSuffix; - - if (RemoveLockFile() == FAILED) { - MS_LOG(ERROR) << "Remove file failed."; - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { - shard_count_ = paths.size(); - if (shard_count_ > kMaxShardCount || shard_count_ == 0) { - MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; - return FAILED; - } - if (schema_count_ > kMaxSchemaCount) { - MS_LOG(ERROR) << "The schema Count greater than max value."; - return FAILED; - } - - // Get full path from file name - if (GetFullPathFromFileName(paths) == FAILED) { - MS_LOG(ERROR) << "Get full path from file name failed."; - return FAILED; - } - - // Open files - if (OpenDataFiles(append) == FAILED) { - MS_LOG(ERROR) << "Open data files failed."; - return FAILED; - } - - // Init lock file - if (InitLockFile() == FAILED) { - MS_LOG(ERROR) << "Init lock file failed."; - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardWriter::OpenForAppend(const std::string &path) { - if (!IsLegalFile(path)) { - return FAILED; - } - auto ret1 = ShardHeader::BuildSingleHeader(path); - if (ret1.first != SUCCESS) { - return FAILED; - } - auto json_header = ret1.second; - auto ret2 = GetParentDir(path); - if (SUCCESS != ret2.first) { - return FAILED; - } - std::vector real_addresses; - for (const auto &path : json_header["shard_addresses"]) { - std::string abs_path = ret2.second + string(path); - real_addresses.emplace_back(abs_path); - } - ShardHeader header = ShardHeader(); - if (header.BuildDataset(real_addresses) == FAILED) { - return FAILED; - } - shard_header_ = std::make_shared(header); - MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize()); - if (ret == FAILED) { - return FAILED; - } - ret = SetPageSize(shard_header_->GetPageSize()); - if (ret == FAILED) { - return FAILED; - } - ret = Open(real_addresses, true); - if (ret == FAILED) { - MS_LOG(ERROR) << "Open file failed"; - return FAILED; - } - shard_column_ = std::make_shared(shard_header_); - return SUCCESS; -} - -MSRStatus ShardWriter::Commit() { - // Read pages file - std::ifstream page_file(pages_file_.c_str()); - if (page_file.good()) { - page_file.close(); - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Read pages from file failed"; - return FAILED; - } - } - - if (WriteShardHeader() == FAILED) { - MS_LOG(ERROR) << "Write metadata failed"; - return FAILED; - } - MS_LOG(INFO) << "Write metadata successfully."; - - // Remove lock file - if (RemoveLockFile() == FAILED) { - MS_LOG(ERROR) << "Remove lock file failed."; - return FAILED; - } - - return SUCCESS; -} - -MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) { - MSRStatus ret = header_data->InitByFiles(file_paths_); - if (ret == FAILED) { - return FAILED; - } - - // set fields in mindrecord when empty - std::vector> fields = header_data->GetFields(); - if (fields.empty()) { - MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields."; - std::vector> schemas = header_data->GetSchemas(); - for (const auto &schema : schemas) { - json jsonSchema = schema->GetSchema()["schema"]; - for (const auto &el : jsonSchema.items()) { - if (el.value()["type"] == "string" || - (el.value()["type"] == "int32" && el.value().find("shape") == el.value().end()) || - (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) || - (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) || - (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) { - fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key())); - } - } - } - // only blob data - if (!fields.empty()) { - ret = header_data->AddIndexFields(fields); - if (ret == FAILED) { - MS_LOG(ERROR) << "Add index field failed"; - return FAILED; - } - } - } - - shard_header_ = header_data; - shard_header_->SetHeaderSize(header_size_); - shard_header_->SetPageSize(page_size_); - shard_column_ = std::make_shared(shard_header_); - return SUCCESS; -} - -MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) { - // header_size [16KB, 128MB] - if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) { - MS_LOG(ERROR) << "Header size should between 16KB and 128MB."; - return FAILED; - } - if (header_size % 4 != 0) { - MS_LOG(ERROR) << "Header size should be divided by four."; - return FAILED; - } - - header_size_ = header_size; - return SUCCESS; -} - -MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) { - // PageSize [32KB, 256MB] - if (page_size < kMinPageSize || page_size > kMaxPageSize) { - MS_LOG(ERROR) << "Page size should between 16KB and 256MB."; - return FAILED; - } - if (page_size % 4 != 0) { - MS_LOG(ERROR) << "Page size should be divided by four."; - return FAILED; - } - page_size_ = page_size; - return SUCCESS; -} - -void ShardWriter::DeleteErrorData(std::map> &raw_data, - std::vector> &blob_data) { - // get wrong data location - std::set> delete_set; - for (auto &err_mg : err_mg_) { - uint64_t id = err_mg.first; - auto sub_err_mg = err_mg.second; - for (auto &subMg : sub_err_mg) { - int loc = subMg.first; - std::string message = subMg.second; - MS_LOG(ERROR) << "For schema " << id << ", " << loc + 1 << " th data is wrong: " << message; - (void)delete_set.insert(loc); - } - } - - auto it = raw_data.begin(); - if (delete_set.size() == it->second.size()) { - raw_data.clear(); - blob_data.clear(); - return; - } - - // delete wrong raw data - for (auto &loc : delete_set) { - // delete row data - for (auto &raw : raw_data) { - (void)raw.second.erase(raw.second.begin() + loc); - } - - // delete blob data - (void)blob_data.erase(blob_data.begin() + loc); - } -} - -void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &message, - std::map &err_raw_data) { - std::lock_guard lock(check_mutex_); - (void)err_raw_data.insert(std::make_pair(row, message)); -} - -MSRStatus ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, - std::map &err_raw_data) { - auto data_type = std::string(value["type"].get()); - - if ((data_type == "int32" && !data[key].is_number_integer()) || - (data_type == "int64" && !data[key].is_number_integer()) || - (data_type == "float32" && !data[key].is_number_float()) || - (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) { - std::string message = "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is not matched"; - PopulateMutexErrorData(i, message, err_raw_data); - return FAILED; - } - - if (data_type == "int32" && data[key].is_number_integer()) { - int64_t temp_value = data[key]; - if (static_cast(temp_value) < static_cast(std::numeric_limits::min()) && - static_cast(temp_value) > static_cast(std::numeric_limits::max())) { - std::string message = - "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is out of range"; - PopulateMutexErrorData(i, message, err_raw_data); - return FAILED; - } - } - return SUCCESS; -} - -void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector &sub_raw_data, - std::map &err_raw_data) { - if (start_row < 0 || start_row > end_row || end_row > static_cast(sub_raw_data.size())) { - return; - } - for (int i = start_row; i < end_row; i++) { - json data = sub_raw_data[i]; - - for (auto iter = schema.begin(); iter != schema.end(); iter++) { - std::string key = iter.key(); - json value = iter.value(); - if (data.find(key) == data.end()) { - std::string message = "there is not '" + key + "' object in the raw data"; - PopulateMutexErrorData(i, message, err_raw_data); - break; - } - - if (value.size() == kInt2) { - // Skip check since all shaped data will store as blob - continue; - } - - if (CheckDataTypeAndValue(key, value, data, i, err_raw_data) != SUCCESS) { - break; - } - } - } -} - -MSRStatus ShardWriter::CheckData(const std::map> &raw_data) { - auto rawdata_iter = raw_data.begin(); - - // make sure rawdata match schema - for (; rawdata_iter != raw_data.end(); ++rawdata_iter) { - // used for storing error - std::map sub_err_mg; - int schema_id = rawdata_iter->first; - auto result = shard_header_->GetSchemaByID(schema_id); - if (result.second != SUCCESS) { - return FAILED; - } - json schema = result.first->GetSchema()["schema"]; - for (const auto &field : result.first->GetBlobFields()) { - (void)schema.erase(field); - } - std::vector sub_raw_data = rawdata_iter->second; - - // calculate start position and end position for each thread - int batch_size = rawdata_iter->second.size() / shard_count_; - int thread_num = shard_count_; - if (thread_num <= 0) { - return FAILED; - } - if (thread_num > kMaxThreadCount) { - thread_num = kMaxThreadCount; - } - std::vector thread_set(thread_num); - - // start multiple thread - int start_row = 0, end_row = 0; - for (int x = 0; x < thread_num; ++x) { - if (x != thread_num - 1) { - start_row = batch_size * x; - end_row = batch_size * (x + 1); - } else { - start_row = batch_size * x; - end_row = rawdata_iter->second.size(); - } - thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema, - std::ref(sub_raw_data), std::ref(sub_err_mg)); - } - if (thread_num > kMaxThreadCount) { - return FAILED; - } - // Wait for threads done - for (int x = 0; x < thread_num; ++x) { - thread_set[x].join(); - } - - (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg)); - } - return SUCCESS; -} - -std::tuple ShardWriter::ValidateRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign) { - auto rawdata_iter = raw_data.begin(); - schema_count_ = raw_data.size(); - std::tuple failed(FAILED, 0, 0); - if (schema_count_ == 0) { - MS_LOG(ERROR) << "Data size is zero"; - return failed; - } - - // keep schema_id - std::set schema_ids; - row_count_ = (rawdata_iter->second).size(); - MS_LOG(DEBUG) << "Schema count is " << schema_count_; - - // Determine if the number of schemas is the same - if (shard_header_->GetSchemas().size() != schema_count_) { - MS_LOG(ERROR) << "Data size is not equal with the schema size"; - return failed; - } - - // Determine raw_data size == blob_data size - if (raw_data[0].size() != blob_data.size()) { - MS_LOG(ERROR) << "Raw data size is not equal blob data size"; - return failed; - } - - // Determine whether the number of samples corresponding to each schema is the same - for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) { - if (row_count_ != rawdata_iter->second.size()) { - MS_LOG(ERROR) << "Data size is not equal"; - return failed; - } - (void)schema_ids.insert(rawdata_iter->first); - } - const std::vector> &schemas = shard_header_->GetSchemas(); - if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr &schema) { - return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); - })) { - // There is not enough data which is not matching the number of schema - MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id."; - return failed; - } - - if (!sign) { - std::tuple success(SUCCESS, schema_count_, row_count_); - return success; - } - - // check the data according the schema - if (CheckData(raw_data) != SUCCESS) { - MS_LOG(ERROR) << "Data validate check failed"; - return std::tuple(FAILED, schema_count_, row_count_); - } - - // delete wrong data from raw data - DeleteErrorData(raw_data, blob_data); - - // update raw count - row_count_ = row_count_ - err_mg_.begin()->second.size(); - std::tuple success(SUCCESS, schema_count_, row_count_); - return success; -} - -void ShardWriter::FillArray(int start, int end, std::map> &raw_data, - std::vector> &bin_data) { - // Prevent excessive thread opening and cause cross-border - if (start >= end) { - flag_ = true; - return; - } - int schema_count = static_cast(raw_data.size()); - std::map>::const_iterator rawdata_iter; - for (int x = start; x < end; ++x) { - int cnt = 0; - for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) { - const json &line = raw_data.at(rawdata_iter->first)[x]; - std::vector bline = json::to_msgpack(line); - - // Storage form is [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2] - bin_data[x * schema_count + cnt] = bline; - cnt++; - } - } -} - -int ShardWriter::LockWriter(bool parallel_writer) { - if (!parallel_writer) { - return 0; - } - -#if defined(_WIN32) || defined(_WIN64) - MS_LOG(DEBUG) << "Lock file done by python."; - const int fd = 0; -#else - const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); - if (fd >= 0) { - flock(fd, LOCK_EX); - } else { - MS_LOG(ERROR) << "Shard writer failed when locking file"; - return -1; - } -#endif - - // Open files - file_streams_.clear(); - for (const auto &file : file_paths_) { - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { - MS_LOG(ERROR) << "File could not opened"; - return -1; - } - file_streams_.push_back(fs); - } - - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Read pages from file failed"; - return -1; - } - return fd; -} - -MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { - if (!parallel_writer) { - return SUCCESS; - } - - if (shard_header_->PagesToFile(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Write pages to file failed"; - return FAILED; - } - - for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { - file_streams_[i]->close(); - } - -#if defined(_WIN32) || defined(_WIN64) - MS_LOG(DEBUG) << "Unlock file done by python."; -#else - flock(fd, LOCK_UN); - close(fd); -#endif - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, - std::vector> &blob_data, bool sign, int *schema_count, - int *row_count) { - // check the free disk size - auto st_space = GetDiskSize(file_paths_[0], kFreeSize); - if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { - MS_LOG(ERROR) << "IO error / there is no free disk to be used"; - return FAILED; - } - - // compress blob - if (shard_column_->CheckCompressBlob()) { - for (auto &blob : blob_data) { - blob = shard_column_->CompressBlob(blob); - } - } - - // Add 4-bytes dummy blob data if no any blob fields - if (blob_data.size() == 0 && raw_data.size() > 0) { - blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); - } - - // Add dummy id if all are blob fields - if (blob_data.size() > 0 && raw_data.size() == 0) { - raw_data.insert(std::pair>(0, std::vector(blob_data.size(), kDummyId))); - } - - auto v = ValidateRawData(raw_data, blob_data, sign); - if (std::get<0>(v) == FAILED) { - MS_LOG(ERROR) << "Validate raw data failed"; - return FAILED; - } - *schema_count = std::get<1>(v); - *row_count = std::get<2>(v); - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign, bool parallel_writer) { - // Lock Writer if loading data parallel - int fd = LockWriter(parallel_writer); - if (fd < 0) { - MS_LOG(ERROR) << "Lock writer failed"; - return FAILED; - } - - // Get the count of schemas and rows - int schema_count = 0; - int row_count = 0; - - // Serialize raw data - if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { - MS_LOG(ERROR) << "Check raw data failed"; - return FAILED; - } - - if (row_count == kInt0) { - MS_LOG(INFO) << "Raw data size is 0."; - return SUCCESS; - } - - std::vector> bin_raw_data(row_count * schema_count); - - // Serialize raw data - if (SerializeRawData(raw_data, bin_raw_data, row_count) == FAILED) { - MS_LOG(ERROR) << "Serialize raw data failed"; - return FAILED; - } - - // Set row size of raw data - if (SetRawDataSize(bin_raw_data) == FAILED) { - MS_LOG(ERROR) << "Set raw data size failed"; - return FAILED; - } - - // Set row size of blob data - if (SetBlobDataSize(blob_data) == FAILED) { - MS_LOG(ERROR) << "Set blob data size failed"; - return FAILED; - } - - // Write data to disk with multi threads - if (ParallelWriteData(blob_data, bin_raw_data) == FAILED) { - MS_LOG(ERROR) << "Parallel write data failed"; - return FAILED; - } - MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; - - if (UnlockWriter(fd, parallel_writer) == FAILED) { - MS_LOG(ERROR) << "Unlock writer failed"; - return FAILED; - } - - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign, - bool parallel_writer) { - std::map> raw_data_json; - std::map> blob_data_json; - - (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), - [](const std::pair> &pair) { - auto &py_raw_data = pair.second; - std::vector json_raw_data; - (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), - [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); - return std::make_pair(pair.first, std::move(json_raw_data)); - }); - - (void)std::transform(blob_data.begin(), blob_data.end(), std::inserter(blob_data_json, blob_data_json.end()), - [](const std::pair> &pair) { - auto &py_blob_data = pair.second; - std::vector jsonBlobData; - (void)std::transform(py_blob_data.begin(), py_blob_data.end(), - std::back_inserter(jsonBlobData), - [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); - return std::make_pair(pair.first, std::move(jsonBlobData)); - }); - - // Serialize blob page - auto blob_data_iter = blob_data.begin(); - auto schema_count = blob_data.size(); - auto row_count = blob_data_iter->second.size(); - - std::vector> bin_blob_data(row_count * schema_count); - // Serialize blob data - if (SerializeRawData(blob_data_json, bin_blob_data, row_count) == FAILED) { - MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; - return FAILED; - } - return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); -} - -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - vector> &blob_data, bool sign, bool parallel_writer) { - std::map> raw_data_json; - (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), - [](const std::pair> &pair) { - auto &py_raw_data = pair.second; - std::vector json_raw_data; - (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), - [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); - return std::make_pair(pair.first, std::move(json_raw_data)); - }); - return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); -} - -MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, - const std::vector> &bin_raw_data) { - auto shards = BreakIntoShards(); - // define the number of thread - int thread_num = static_cast(shard_count_); - if (thread_num < 0) { - return FAILED; - } - if (thread_num > kMaxThreadCount) { - thread_num = kMaxThreadCount; - } - int left_thread = shard_count_; - int current_thread = 0; - while (left_thread) { - if (left_thread < thread_num) { - thread_num = left_thread; - } - // Start one thread for one shard - std::vector thread_set(thread_num); - if (thread_num <= kMaxThreadCount) { - for (int x = 0; x < thread_num; ++x) { - int start_row = shards[current_thread + x].first; - int end_row = shards[current_thread + x].second; - thread_set[x] = std::thread(&ShardWriter::WriteByShard, this, current_thread + x, start_row, end_row, - std::ref(blob_data), std::ref(bin_raw_data)); - } - // Wait for threads done - for (int x = 0; x < thread_num; ++x) { - thread_set[x].join(); - } - left_thread -= thread_num; - current_thread += thread_num; - } - } - return SUCCESS; -} - -MSRStatus ShardWriter::WriteByShard(int shard_id, int start_row, int end_row, - const std::vector> &blob_data, - const std::vector> &bin_raw_data) { - MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row - << ", schema size: " << schema_count_; - if (start_row == end_row) { - return SUCCESS; - } - vector> rows_in_group; - std::shared_ptr last_raw_page = nullptr; - std::shared_ptr last_blob_page = nullptr; - SetLastRawPage(shard_id, last_raw_page); - SetLastBlobPage(shard_id, last_blob_page); - - if (CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page) == FAILED) { - MS_LOG(ERROR) << "Cut row group failed"; - return FAILED; - } - - if (AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) { - MS_LOG(ERROR) << "Append bolb page failed"; - return FAILED; - } - - if (NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) { - MS_LOG(ERROR) << "New blob page failed"; - return FAILED; - } - - if (ShiftRawPage(shard_id, rows_in_group, last_raw_page) == FAILED) { - MS_LOG(ERROR) << "Shit raw page failed"; - return FAILED; - } - - if (WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data) == FAILED) { - MS_LOG(ERROR) << "Write raw page failed"; - return FAILED; - } - - return SUCCESS; -} - -MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, - std::vector> &rows_in_group, - const std::shared_ptr &last_raw_page, - const std::shared_ptr &last_blob_page) { - auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0; - - auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; - auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0; - auto n_byte_raw = last_raw_page_size - last_raw_offset; - - int page_start_row = start_row; - if (start_row > end_row) { - return FAILED; - } - if (end_row > static_cast(blob_data_size_.size()) || end_row > static_cast(raw_data_size_.size())) { - return FAILED; - } - for (int i = start_row; i < end_row; ++i) { - // n_byte_blob(0) indicate appendBlobPage - if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ || - n_byte_raw + raw_data_size_[i] > page_size_) { - rows_in_group.emplace_back(page_start_row, i); - page_start_row = i; - n_byte_blob = blob_data_size_[i]; - n_byte_raw = raw_data_size_[i]; - } else { - n_byte_blob += blob_data_size_[i]; - n_byte_raw += raw_data_size_[i]; - } - } - - // Not forget last one - rows_in_group.emplace_back(page_start_row, end_row); - return SUCCESS; -} - -MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page) { - auto blob_row = rows_in_group[0]; - if (blob_row.first == blob_row.second) return SUCCESS; - - // Write disk - auto page_id = last_blob_page->GetPageID(); - auto bytes_page = last_blob_page->GetPageSize(); - auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); - - // Update last blob page - bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); - last_blob_page->SetPageSize(bytes_page); - uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first; - last_blob_page->SetEndRowID(end_row); - (void)shard_header_->SetPage(last_blob_page); - return SUCCESS; -} - -MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page) { - auto page_id = shard_header_->GetLastPageId(shard_id); - auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1; - auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0; - // index(0) indicate appendBlobPage - for (uint32_t i = 1; i < rows_in_group.size(); ++i) { - auto blob_row = rows_in_group[i]; - - // Write 1 blob page to disk - auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); - // Create new page info for header - auto page_size = - std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); - std::vector> row_group_ids; - auto start_row = current_row; - auto end_row = start_row + blob_row.second - blob_row.first; - auto page = Page(++page_id, shard_id, kPageTypeBlob, ++page_type_id, start_row, end_row, row_group_ids, page_size); - (void)shard_header_->AddPage(std::make_shared(page)); - current_row = end_row; - } - return SUCCESS; -} - -MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page) { - auto blob_row = rows_in_group[0]; - if (blob_row.first == blob_row.second) return SUCCESS; - auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; - if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + - last_raw_page_size <= - page_size_) { - return SUCCESS; - } - auto page_id = shard_header_->GetLastPageId(shard_id); - auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second; - auto last_raw_page_id = last_raw_page->GetPageID(); - auto shift_size = last_raw_page_size - last_row_group_id_offset; - - std::vector buf(shift_size); - - // Read last row group from previous raw data page - if (shard_id < 0 || shard_id >= file_streams_.size()) { - return FAILED; - } - - auto &io_seekg = file_streams_[shard_id]->seekg( - page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf[0]), buf.size()); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - // Merge into new row group at new raw data page - auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&buf[0]), buf.size()); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - last_raw_page->DeleteLastGroupId(); - (void)shard_header_->SetPage(last_raw_page); - - // Refresh page info in header - int row_group_id = last_raw_page->GetLastRowGroupID().first + 1; - std::vector> row_group_ids; - row_group_ids.emplace_back(row_group_id, 0); - int page_type_id = last_raw_page->GetPageID(); - auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size); - (void)shard_header_->AddPage(std::make_shared(page)); - - // Reset: last raw page - SetLastRawPage(shard_id, last_raw_page); - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page, - const std::vector> &bin_raw_data) { - int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1; - for (uint32_t i = 0; i < rows_in_group.size(); ++i) { - const auto &blob_row = rows_in_group[i]; - if (blob_row.first == blob_row.second) continue; - auto raw_size = - std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0); - if (!last_raw_page) { - EmptyRawPage(shard_id, last_raw_page); - } else if (last_raw_page->GetPageSize() + raw_size > page_size_) { - (void)shard_header_->SetPage(last_raw_page); - EmptyRawPage(shard_id, last_raw_page); - } - if (AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data) != SUCCESS) { - return FAILED; - } - } - (void)shard_header_->SetPage(last_raw_page); - return SUCCESS; -} - -void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { - auto row_group_ids = std::vector>(); - auto page_id = shard_header_->GetLastPageId(shard_id); - auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1; - auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); - (void)shard_header_->AddPage(std::make_shared(page)); - SetLastRawPage(shard_id, last_raw_page); -} - -MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, - const int &chunk_id, int &last_row_group_id, std::shared_ptr last_raw_page, - const std::vector> &bin_raw_data) { - std::vector> row_group_ids = last_raw_page->GetRowGroupIds(); - auto last_raw_page_id = last_raw_page->GetPageID(); - auto n_bytes = last_raw_page->GetPageSize(); - - // previous raw data page - auto &io_seekp = - file_streams_[shard_id]->seekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - if (chunk_id > 0) row_group_ids.emplace_back(++last_row_group_id, n_bytes); - n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first, - raw_data_size_.begin() + rows_in_group[chunk_id].second, 0); - (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data); - - // Update previous raw data page - last_raw_page->SetPageSize(n_bytes); - last_raw_page->SetRowGroupIds(row_group_ids); - (void)shard_header_->SetPage(last_raw_page); - - return SUCCESS; -} - -MSRStatus ShardWriter::FlushBlobChunk(const std::shared_ptr &out, - const std::vector> &blob_data, - const std::pair &blob_row) { - if (blob_row.first > blob_row.second) { - return FAILED; - } - if (blob_row.second > static_cast(blob_data.size()) || blob_row.first < 0) { - return FAILED; - } - for (int j = blob_row.first; j < blob_row.second; ++j) { - // Write the size of blob - uint64_t line_len = blob_data[j].size(); - auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - out->close(); - return FAILED; - } - - // Write the data of blob - auto line = blob_data[j]; - auto &io_handle_data = out->write(reinterpret_cast(&line[0]), line_len); - if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) { - MS_LOG(ERROR) << "File write failed"; - out->close(); - return FAILED; - } - } - return SUCCESS; -} - -MSRStatus ShardWriter::FlushRawChunk(const std::shared_ptr &out, - const std::vector> &rows_in_group, const int &chunk_id, - const std::vector> &bin_raw_data) { - for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) { - // Write the size of multi schemas - for (uint32_t j = 0; j < schema_count_; ++j) { - uint64_t line_len = bin_raw_data[i * schema_count_ + j].size(); - auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - out->close(); - return FAILED; - } - } - // Write the data of multi schemas - for (uint32_t j = 0; j < schema_count_; ++j) { - auto line = bin_raw_data[i * schema_count_ + j]; - auto &io_handle = out->write(reinterpret_cast(&line[0]), line.size()); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - out->close(); - return FAILED; - } - } - } - return SUCCESS; -} - -// Allocate data to shards evenly -std::vector> ShardWriter::BreakIntoShards() { - std::vector> shards; - int row_in_shard = row_count_ / shard_count_; - int remains = row_count_ % shard_count_; - - std::vector v_list(shard_count_); - std::iota(v_list.begin(), v_list.end(), 0); - std::random_device rd; - std::mt19937 g(rd()); - std::shuffle(v_list.begin(), v_list.end(), g); - std::unordered_set set(v_list.begin(), v_list.begin() + remains); - - if (shard_count_ <= kMaxShardCount) { - int start_row = 0; - for (int i = 0; i < shard_count_; ++i) { - int end_row = start_row + row_in_shard; - if (set.count(i)) end_row++; - shards.emplace_back(start_row, end_row); - start_row = end_row; - } - } - return shards; -} - -MSRStatus ShardWriter::WriteShardHeader() { - if (shard_header_ == nullptr) { - MS_LOG(ERROR) << "Shard header is null"; - return FAILED; - } - auto shard_header = shard_header_->SerializeHeader(); - // Write header data to multi files - if (shard_count_ > static_cast(file_streams_.size()) || shard_count_ > static_cast(shard_header.size())) { - return FAILED; - } - if (shard_count_ <= kMaxShardCount) { - for (int shard_id = 0; shard_id < shard_count_; ++shard_id) { - auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - std::vector bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end()); - uint64_t line_len = bin_header.size(); - if (line_len + kInt64Len > header_size_) { - MS_LOG(ERROR) << "Shard header is too big"; - return FAILED; - } - - auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&line_len), kInt64Len); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast(&bin_header[0]), line_len); - if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) { - MS_LOG(ERROR) << "File write failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - file_streams_[shard_id]->close(); - } - } - return SUCCESS; -} - -MSRStatus ShardWriter::SerializeRawData(std::map> &raw_data, - std::vector> &bin_data, uint32_t row_count) { - // define the number of thread - uint32_t thread_num = std::thread::hardware_concurrency(); - if (thread_num == 0) thread_num = kThreadNumber; - // Set the number of samples processed by each thread - int group_num = ceil(row_count * 1.0 / thread_num); - std::vector thread_set(thread_num); - int work_thread_num = 0; - for (uint32_t x = 0; x < thread_num; ++x) { - int start_num = x * group_num; - int end_num = ((x + 1) * group_num > row_count) ? row_count : (x + 1) * group_num; - if (start_num >= end_num) { - continue; - } - // Define the run boundary and start the child thread - thread_set[x] = - std::thread(&ShardWriter::FillArray, this, start_num, end_num, std::ref(raw_data), std::ref(bin_data)); - work_thread_num++; - } - for (uint32_t x = 0; x < work_thread_num; ++x) { - // Set obstacles to prevent the main thread from running - thread_set[x].join(); - } - return flag_ == true ? FAILED : SUCCESS; -} - -MSRStatus ShardWriter::SetRawDataSize(const std::vector> &bin_raw_data) { - raw_data_size_ = std::vector(row_count_, 0); - for (uint32_t i = 0; i < row_count_; ++i) { - raw_data_size_[i] = std::accumulate( - bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_, 0, - [](uint64_t accumulator, const std::vector &row) { return accumulator + kInt64Len + row.size(); }); - } - if (*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) > page_size_) { - MS_LOG(ERROR) << "Page size is too small to save a row!"; - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardWriter::SetBlobDataSize(const std::vector> &blob_data) { - blob_data_size_ = std::vector(row_count_); - (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(), - [](const std::vector &row) { return kInt64Len + row.size(); }); - if (*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) > page_size_) { - MS_LOG(ERROR) << "Page size is too small to save a row!"; - return FAILED; - } - return SUCCESS; -} - -void ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { - // Get last raw page - auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw); - if (last_raw_page_id >= 0) { - auto page = shard_header_->GetPage(shard_id, last_raw_page_id); - last_raw_page = page.first; - } -} - -void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr &last_blob_page) { - // Get last blob page - auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob); - if (last_blob_page_id >= 0) { - auto page = shard_header_->GetPage(shard_id, last_blob_page_id); - last_blob_page = page.first; - } -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc deleted file mode 100644 index bd427a330a..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_category.h" - -namespace mindspore { -namespace mindrecord { -ShardCategory::ShardCategory(const std::vector> &categories, int64_t num_elements, - bool replacement) - : categories_(categories), - category_field_(""), - num_elements_(num_elements), - num_categories_(0), - replacement_(replacement) {} - -ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories, - bool replacement) - : categories_({}), - category_field_(category_field), - num_elements_(num_elements), - num_categories_(num_categories), - replacement_(replacement) {} - -MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; } - -int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { - if (dataset_size == 0) return dataset_size; - if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) { - return std::min(num_categories_, num_classes) * num_elements_; - } - return 0; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/mindrecord/meta/shard_column.cc deleted file mode 100644 index 28dc243e17..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_column.cc +++ /dev/null @@ -1,496 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_column.h" - -#include "common/utils.h" -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" - -namespace mindspore { -namespace mindrecord { -ShardColumn::ShardColumn(const std::shared_ptr &shard_header, bool compress_integer) { - auto first_schema = shard_header->GetSchemas()[0]; - auto schema = first_schema->GetSchema()["schema"]; - - bool has_integer_array = false; - for (json::iterator it = schema.begin(); it != schema.end(); ++it) { - const std::string &column_name = it.key(); - column_name_.push_back(column_name); - - json it_value = it.value(); - - std::string str_type = it_value["type"]; - column_data_type_.push_back(ColumnDataTypeMap.at(str_type)); - if (it_value.find("shape") != it_value.end()) { - std::vector vec(it_value["shape"].size()); - std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin()); - column_shape_.push_back(vec); - if (str_type == "int32" || str_type == "int64") { - has_integer_array = true; - } - } else { - std::vector vec = {}; - column_shape_.push_back(vec); - } - } - - for (uint64_t i = 0; i < column_name_.size(); i++) { - column_name_id_[column_name_[i]] = i; - } - - auto blob_fields = first_schema->GetBlobFields(); - - for (const auto &field : blob_fields) { - blob_column_.push_back(field); - } - - for (uint64_t i = 0; i < blob_column_.size(); i++) { - blob_column_id_[blob_column_[i]] = i; - } - - has_compress_blob_ = (compress_integer && has_integer_array); - num_blob_column_ = blob_column_.size(); -} - -std::pair ShardColumn::GetColumnTypeByName(const std::string &column_name, - ColumnDataType *column_data_type, - uint64_t *column_data_type_size, - std::vector *column_shape) { - // Skip if column not found - auto column_category = CheckColumnName(column_name); - if (column_category == ColumnNotFound) { - return {FAILED, ColumnNotFound}; - } - - // Get data type and size - auto column_id = column_name_id_[column_name]; - *column_data_type = column_data_type_[column_id]; - *column_data_type_size = ColumnDataTypeSize[*column_data_type]; - *column_shape = column_shape_[column_id]; - - return {SUCCESS, column_category}; -} - -MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, - const json &columns_json, const unsigned char **data, - std::unique_ptr *data_ptr, uint64_t *const n_bytes, - ColumnDataType *column_data_type, uint64_t *column_data_type_size, - std::vector *column_shape) { - // Skip if column not found - auto column_category = CheckColumnName(column_name); - if (column_category == ColumnNotFound) { - return FAILED; - } - - // Get data type and size - auto column_id = column_name_id_[column_name]; - *column_data_type = column_data_type_[column_id]; - *column_data_type_size = ColumnDataTypeSize[*column_data_type]; - *column_shape = column_shape_[column_id]; - - // Retrieve value from json - if (column_category == ColumnInRaw) { - if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) { - MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << "."; - return FAILED; - } - *data = reinterpret_cast(data_ptr->get()); - return SUCCESS; - } - - // Retrieve value from blob - if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) { - MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << "."; - return FAILED; - } - if (*data == nullptr) { - *data = reinterpret_cast(data_ptr->get()); - } - return SUCCESS; -} - -MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, - std::unique_ptr *data_ptr, uint64_t *n_bytes) { - auto column_id = column_name_id_[column_name]; - auto column_data_type = column_data_type_[column_id]; - - // Initialize num bytes - *n_bytes = ColumnDataTypeSize[column_data_type]; - auto json_column_value = columns_json[column_name]; - switch (column_data_type) { - case ColumnFloat32: { - return GetFloat(data_ptr, json_column_value, false); - } - case ColumnFloat64: { - return GetFloat(data_ptr, json_column_value, true); - } - case ColumnInt32: { - return GetInt(data_ptr, json_column_value); - } - case ColumnInt64: { - return GetInt(data_ptr, json_column_value); - } - default: { - // Convert string to c_str - std::string tmp_string = json_column_value; - *n_bytes = tmp_string.size(); - auto data = reinterpret_cast(common::SafeCStr(tmp_string)); - *data_ptr = std::make_unique(*n_bytes); - for (uint32_t i = 0; i < *n_bytes; i++) { - (*data_ptr)[i] = *(data + i); - } - break; - } - } - return SUCCESS; -} - -template -MSRStatus ShardColumn::GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, - bool use_double) { - std::unique_ptr array_data = std::make_unique(1); - if (!json_column_value.is_string() && !json_column_value.is_number()) { - MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; - return FAILED; - } - if (json_column_value.is_number()) { - array_data[0] = json_column_value; - } else { - // Convert string to float - try { - if (use_double) { - array_data[0] = json_column_value.get(); - } else { - array_data[0] = json_column_value.get(); - } - } catch (json::exception &e) { - MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; - return FAILED; - } - } - - auto data = reinterpret_cast(array_data.get()); - *data_ptr = std::make_unique(sizeof(T)); - for (uint32_t i = 0; i < sizeof(T); i++) { - (*data_ptr)[i] = *(data + i); - } - - return SUCCESS; -} - -template -MSRStatus ShardColumn::GetInt(std::unique_ptr *data_ptr, const json &json_column_value) { - std::unique_ptr array_data = std::make_unique(1); - int64_t temp_value; - bool less_than_zero = false; - - if (json_column_value.is_number_integer()) { - const json json_zero = 0; - if (json_column_value < json_zero) less_than_zero = true; - temp_value = json_column_value; - } else if (json_column_value.is_string()) { - std::string string_value = json_column_value; - - if (!string_value.empty() && string_value[0] == '-') { - try { - temp_value = std::stoll(string_value); - less_than_zero = true; - } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; - return FAILED; - } catch (std::out_of_range &e) { - MS_LOG(ERROR) << "Conversion to int failed, out of range."; - return FAILED; - } - } else { - try { - temp_value = static_cast(std::stoull(string_value)); - } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; - return FAILED; - } catch (std::out_of_range &e) { - MS_LOG(ERROR) << "Conversion to int failed, out of range."; - return FAILED; - } - } - } else { - MS_LOG(ERROR) << "Conversion to int failed."; - return FAILED; - } - - if ((less_than_zero && temp_value < static_cast(std::numeric_limits::min())) || - (!less_than_zero && static_cast(temp_value) > static_cast(std::numeric_limits::max()))) { - MS_LOG(ERROR) << "Conversion to int failed. Out of range"; - return FAILED; - } - array_data[0] = static_cast(temp_value); - - auto data = reinterpret_cast(array_data.get()); - *data_ptr = std::make_unique(sizeof(T)); - for (uint32_t i = 0; i < sizeof(T); i++) { - (*data_ptr)[i] = *(data + i); - } - - return SUCCESS; -} - -MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, - const unsigned char **data, std::unique_ptr *data_ptr, - uint64_t *const n_bytes) { - uint64_t offset_address = 0; - auto column_id = column_name_id_[column_name]; - if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { - return FAILED; - } - - auto column_data_type = column_data_type_[column_id]; - if (has_compress_blob_ && column_data_type == ColumnInt32) { - if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { - return FAILED; - } - } else if (has_compress_blob_ && column_data_type == ColumnInt64) { - if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { - return FAILED; - } - } else { - *data = reinterpret_cast(&(columns_blob[offset_address])); - } - - return SUCCESS; -} - -ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { - auto it_column = column_name_id_.find(column_name); - if (it_column == column_name_id_.end()) { - return ColumnNotFound; - } - auto it_blob = blob_column_id_.find(column_name); - return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; -} - -std::vector ShardColumn::CompressBlob(const std::vector &blob) { - // Skip if no compress columns - if (!CheckCompressBlob()) return blob; - - std::vector dst_blob; - uint64_t i_src = 0; - for (int64_t i = 0; i < num_blob_column_; i++) { - // Get column data type - auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]]; - auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type; - - // Compress and return is blob has 1 column only - if (num_blob_column_ == 1) { - return CompressInt(blob, int_type); - } - - // Just copy and continue if column dat type is not int32/int64 - uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type); - if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) { - dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes); - i_src += kInt64Len + num_bytes; - continue; - } - - // Get column slice in source blob - std::vector blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes); - // Compress column - auto dst_blob_slice = CompressInt(blob_slice, int_type); - // Get new column size - auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type); - // Append new colmn size - dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end()); - // Append new colmn data - dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end()); - i_src += kInt64Len + num_bytes; - } - MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; - return dst_blob; -} - -vector ShardColumn::CompressInt(const vector &src_bytes, const IntegerType &int_type) { - uint64_t i_size = kUnsignedOne << static_cast(int_type); - // Get number of elements - uint64_t src_n_int = src_bytes.size() / i_size; - // Calculate bitmap size (bytes) - uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte; - - // Initilize destination blob, more space than needed, will be resized - vector dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0); - - // Write number of elements to destination blob - vector size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type); - for (uint64_t n = 0; n < kBytesOfColumnLen; n++) { - dst_bytes[n] = size_by_bytes[n]; - } - - // Write compressed int - uint64_t i_dst = kBytesOfColumnLen + bitmap_size; - for (uint64_t i = 0; i < src_n_int; i++) { - // Initialize destination data type - IntegerType dst_int_type = kInt8Type; - // Shift to next int position - uint64_t pos = i * (kUnsignedOne << static_cast(int_type)); - // Narrow down this int - int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type); - - // Write this int to destination blob - uint64_t u_n = *reinterpret_cast(&i_n); - auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type); - for (uint64_t j = 0; j < (kUnsignedOne << static_cast(dst_int_type)); j++) { - dst_bytes[i_dst++] = temp_bytes[j]; - } - - // Update date type in bit map - dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |= - (static_cast(dst_int_type) << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte)))); - } - // Resize destination blob - dst_bytes.resize(i_dst); - MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << "."; - return dst_bytes; -} - -MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, - uint64_t *num_bytes, uint64_t *shift_idx) { - if (num_blob_column_ == 1) { - *num_bytes = columns_blob.size(); - *shift_idx = 0; - return SUCCESS; - } - auto blob_id = blob_column_id_[column_name_[column_id]]; - - for (int32_t i = 0; i < blob_id; i++) { - *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); - } - *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type); - - (*shift_idx) += kInt64Len; - - return SUCCESS; -} - -template -MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, - const std::vector &columns_blob, uint64_t *num_bytes, - uint64_t shift_idx) { - auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); - *num_bytes = sizeof(T) * num_elements; - - // Parse integer array - uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte; - auto array_data = std::make_unique(num_elements); - - for (uint64_t i = 0; i < num_elements; i++) { - uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte]; - uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask; - auto mr_int_type = static_cast(i_type); - int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type); - i_source += (kUnsignedOne << i_type); - array_data[i] = static_cast(i64); - } - - auto data = reinterpret_cast(array_data.get()); - *data_ptr = std::make_unique(*num_bytes); - int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data!"; - } - - return SUCCESS; -} - -uint64_t ShardColumn::BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, - const IntegerType &i_type) { - uint64_t result = 0; - for (uint64_t i = 0; i < (kUnsignedOne << static_cast(i_type)); i++) { - result = (result << kBitsOfByte) + bytes_array[pos + i]; - } - return result; -} - -std::vector ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) { - uint64_t n_bytes = kUnsignedOne << static_cast(i_type); - std::vector result(n_bytes, 0); - for (uint64_t i = 0; i < n_bytes; i++) { - result[n_bytes - 1 - i] = value & std::numeric_limits::max(); - value >>= kBitsOfByte; - } - return result; -} - -std::vector ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) { - uint64_t n_bytes = kUnsignedOne << static_cast(i_type); - std::vector result(n_bytes, 0); - for (uint64_t i = 0; i < n_bytes; i++) { - result[i] = value & std::numeric_limits::max(); - value >>= kBitsOfByte; - } - return result; -} - -int64_t ShardColumn::BytesLittleToMinIntType(const std::vector &bytes_array, const uint64_t &pos, - const IntegerType &src_i_type, IntegerType *dst_i_type) { - uint64_t u_temp = 0; - for (uint64_t i = 0; i < (kUnsignedOne << static_cast(src_i_type)); i++) { - u_temp = (u_temp << kBitsOfByte) + - bytes_array[pos + (kUnsignedOne << static_cast(src_i_type)) - kUnsignedOne - i]; - } - - int64_t i_out; - switch (src_i_type) { - case kInt8Type: { - i_out = (int8_t)(u_temp & std::numeric_limits::max()); - break; - } - case kInt16Type: { - i_out = (int16_t)(u_temp & std::numeric_limits::max()); - break; - } - case kInt32Type: { - i_out = (int32_t)(u_temp & std::numeric_limits::max()); - break; - } - case kInt64Type: { - i_out = (int64_t)(u_temp & std::numeric_limits::max()); - break; - } - default: { - i_out = 0; - } - } - - if (!dst_i_type) { - return i_out; - } - - if (i_out >= static_cast(std::numeric_limits::min()) && - i_out <= static_cast(std::numeric_limits::max())) { - *dst_i_type = kInt8Type; - } else if (i_out >= static_cast(std::numeric_limits::min()) && - i_out <= static_cast(std::numeric_limits::max())) { - *dst_i_type = kInt16Type; - } else if (i_out >= static_cast(std::numeric_limits::min()) && - i_out <= static_cast(std::numeric_limits::max())) { - *dst_i_type = kInt32Type; - } else { - *dst_i_type = kInt64Type; - } - return i_out; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc deleted file mode 100644 index b7e890da7c..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_distributed_sample.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, - uint32_t seed) - : ShardSample(1, num_shards, shard_id), - shuffle_(shuffle), - no_of_padded_samples_(no_of_padded_samples), - first_epoch_(true) { - shuffle_op_ = std::make_shared(seed, kShuffleSample); -} - -ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed) - : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {} - -int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { - if (no_of_padded_samples_ <= 0) { - if (dataset_size % denominator_ == 0) { - return dataset_size / denominator_ * numerator_; - } else { - return dataset_size / denominator_ * numerator_ + 1; - } - } else { - auto padded_size = dataset_size + no_of_padded_samples_; - if (padded_size % denominator_ == 0) { - return padded_size / denominator_ * numerator_; - } else { - return -1; - } - } - return 0; -} - -MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) { - auto total_no = tasks.Size(); - if (no_of_padded_samples_ > 0 && first_epoch_) { - if (total_no % denominator_ != 0) { - MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. " - << "task size: " << total_no << ", number padded: " << no_of_padded_samples_ - << ", denominator: " << denominator_; - return FAILED; - } - } - if (first_epoch_) { - first_epoch_ = false; - task_ = tasks; - } else { - tasks = task_; - } - if (shuffle_ == true) { - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } - } - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc deleted file mode 100644 index ec177394ef..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ /dev/null @@ -1,725 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_header.h" - -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_page.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -std::atomic thread_status(false); -ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared(); } - -MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool load_dataset) { - shard_count_ = headers.size(); - int shard_index = 0; - bool first = true; - for (const auto &header : headers) { - if (first) { - first = false; - if (ParseSchema(header["schema"]) != SUCCESS) { - return FAILED; - } - if (ParseIndexFields(header["index_fields"]) != SUCCESS) { - return FAILED; - } - if (ParseStatistics(header["statistics"]) != SUCCESS) { - return FAILED; - } - ParseShardAddress(header["shard_addresses"]); - header_size_ = header["header_size"].get(); - page_size_ = header["page_size"].get(); - } - ParsePage(header["page"], shard_index, load_dataset); - shard_index++; - } - return SUCCESS; -} - -MSRStatus ShardHeader::CheckFileStatus(const std::string &path) { - std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); - if (!fin) { - MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path; - return FAILED; - } - if (fin.fail()) { - MS_LOG(ERROR) << "Failed to open file. path: " << path; - return FAILED; - } - - // fetch file size - auto &io_seekg = fin.seekg(0, std::ios::end); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - fin.close(); - MS_LOG(ERROR) << "File seekg failed"; - return FAILED; - } - - size_t file_size = fin.tellg(); - if (file_size < kMinFileSize) { - fin.close(); - MS_LOG(ERROR) << "File size %d is smaller than the minimum value."; - return FAILED; - } - fin.close(); - return SUCCESS; -} - -std::pair ShardHeader::ValidateHeader(const std::string &path) { - if (CheckFileStatus(path) != SUCCESS) { - return {FAILED, {}}; - } - - // read header size - json json_header; - std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); - if (!fin.is_open()) { - MS_LOG(ERROR) << "File seekg failed"; - return {FAILED, json_header}; - } - - uint64_t header_size = 0; - auto &io_read = fin.read(reinterpret_cast(&header_size), kInt64Len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - fin.close(); - return {FAILED, json_header}; - } - - if (header_size > kMaxHeaderSize) { - fin.close(); - MS_LOG(ERROR) << "Header size is illegal."; - return {FAILED, json_header}; - } - - // read header content - std::vector header_content(header_size); - auto &io_read_content = fin.read(reinterpret_cast(&header_content[0]), header_size); - if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { - MS_LOG(ERROR) << "File read failed"; - fin.close(); - return {FAILED, json_header}; - } - - fin.close(); - std::string raw_header_content = std::string(header_content.begin(), header_content.end()); - // parse json content - try { - json_header = json::parse(raw_header_content); - } catch (json::parse_error &e) { - MS_LOG(ERROR) << "Json parse error: " << e.what(); - return {FAILED, json_header}; - } - return {SUCCESS, json_header}; -} - -std::pair ShardHeader::BuildSingleHeader(const std::string &file_path) { - auto ret = ValidateHeader(file_path); - if (SUCCESS != ret.first) { - return {FAILED, json()}; - } - json raw_header = ret.second; - json header = {{"shard_addresses", raw_header["shard_addresses"]}, - {"header_size", raw_header["header_size"]}, - {"page_size", raw_header["page_size"]}, - {"index_fields", raw_header["index_fields"]}, - {"blob_fields", raw_header["schema"][0]["blob_fields"]}, - {"schema", raw_header["schema"][0]["schema"]}, - {"version", raw_header["version"]}}; - return {SUCCESS, header}; -} - -MSRStatus ShardHeader::BuildDataset(const std::vector &file_paths, bool load_dataset) { - uint32_t thread_num = std::thread::hardware_concurrency(); - if (thread_num == 0) thread_num = kThreadNumber; - uint32_t work_thread_num = 0; - uint32_t shard_count = file_paths.size(); - int group_num = ceil(shard_count * 1.0 / thread_num); - std::vector thread_set(thread_num); - std::vector headers(shard_count); - for (uint32_t x = 0; x < thread_num; ++x) { - int start_num = x * group_num; - int end_num = ((x + 1) * group_num > shard_count) ? shard_count : (x + 1) * group_num; - if (start_num >= end_num) { - continue; - } - - thread_set[x] = - std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths); - work_thread_num++; - } - - for (uint32_t x = 0; x < work_thread_num; ++x) { - thread_set[x].join(); - } - if (thread_status) { - thread_status = false; - return FAILED; - } - if (SUCCESS != InitializeHeader(headers, load_dataset)) { - return FAILED; - } - return SUCCESS; -} - -void ShardHeader::GetHeadersOneTask(int start, int end, std::vector &headers, - const vector &realAddresses) { - if (thread_status || end > realAddresses.size()) { - return; - } - for (int x = start; x < end; ++x) { - auto ret = ValidateHeader(realAddresses[x]); - if (SUCCESS != ret.first) { - thread_status = true; - return; - } - json header; - header = ret.second; - header["shard_addresses"] = realAddresses; - if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { - MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() - << ", lib version is: " << kVersion; - thread_status = true; - return; - } - headers[x] = header; - } -} - -MSRStatus ShardHeader::InitByFiles(const std::vector &file_paths) { - std::vector file_names(file_paths.size()); - std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string { - if (GetFileName(fp).first == SUCCESS) { - return GetFileName(fp).second; - } - }); - - shard_addresses_ = std::move(file_names); - shard_count_ = file_paths.size(); - if (shard_count_ == 0) { - return FAILED; - } - if (shard_count_ <= kMaxShardCount) { - pages_.resize(shard_count_); - } else { - return FAILED; - } - return SUCCESS; -} - -void ShardHeader::ParseHeader(const json &header) {} - -MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { - std::vector> parsed_index_fields; - for (auto &index_field : index_fields) { - auto schema_id = index_field["schema_id"].get(); - std::string field_name = index_field["index_field"].get(); - std::pair parsed_index_field(schema_id, field_name); - parsed_index_fields.push_back(parsed_index_field); - } - if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) { - return FAILED; - } - return SUCCESS; -} - -void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { - // set shard_index when load_dataset is false - if (pages_.empty() && shard_count_ <= kMaxShardCount) { - pages_.resize(shard_count_); - } - for (auto &page : pages) { - int page_id = page["page_id"]; - int shard_id = page["shard_id"]; - std::string page_type = page["page_type"]; - int page_type_id = page["page_type_id"]; - auto start_row_id = page["start_row_id"].get(); - auto end_row_id = page["end_row_id"].get(); - - std::vector> row_group_ids(page["row_group_ids"].size()); - std::transform(page["row_group_ids"].begin(), page["row_group_ids"].end(), row_group_ids.begin(), - [](json rg) { return std::make_pair(rg["id"], rg["offset"].get()); }); - - auto page_size = page["page_size"].get(); - - std::shared_ptr parsed_page = std::make_shared(page_id, shard_id, page_type, page_type_id, start_row_id, - end_row_id, row_group_ids, page_size); - if (load_dataset == true) { - pages_[shard_id].push_back(std::move(parsed_page)); - } else { - pages_[shard_index].push_back(std::move(parsed_page)); - } - } -} - -MSRStatus ShardHeader::ParseStatistics(const json &statistics) { - for (auto &statistic : statistics) { - if (statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) { - MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump(); - return FAILED; - } - std::string statistic_description = statistic["desc"].get(); - json statistic_body = statistic["statistics"]; - std::shared_ptr parsed_statistic = Statistics::Build(statistic_description, statistic_body); - if (!parsed_statistic) { - return FAILED; - } - AddStatistic(parsed_statistic); - } - return SUCCESS; -} - -MSRStatus ShardHeader::ParseSchema(const json &schemas) { - for (auto &schema : schemas) { - // change how we get schemaBody once design is finalized - if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() || - schema.find("schema") == schema.end()) { - MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump(); - return FAILED; - } - std::string schema_description = schema["desc"].get(); - std::vector blob_fields = schema["blob_fields"].get>(); - json schema_body = schema["schema"]; - std::shared_ptr parsed_schema = Schema::Build(schema_description, schema_body); - if (!parsed_schema) { - return FAILED; - } - AddSchema(parsed_schema); - } - return SUCCESS; -} - -void ShardHeader::ParseShardAddress(const json &address) { - std::copy(address.begin(), address.end(), std::back_inserter(shard_addresses_)); -} - -std::vector ShardHeader::SerializeHeader() { - std::vector header; - auto index = SerializeIndexFields(); - auto stats = SerializeStatistics(); - auto schema = SerializeSchema(); - auto pages = SerializePage(); - auto address = SerializeShardAddress(); - if (shard_count_ > static_cast(pages.size())) { - return std::vector{}; - } - if (shard_count_ <= kMaxShardCount) { - for (int shardId = 0; shardId < shard_count_; shardId++) { - string s; - s += "{\"header_size\":" + std::to_string(header_size_) + ","; - s += "\"index_fields\":" + index + ","; - s += "\"page\":" + pages[shardId] + ","; - s += "\"page_size\":" + std::to_string(page_size_) + ","; - s += "\"schema\":" + schema + ","; - s += "\"shard_addresses\":" + address + ","; - s += "\"shard_id\":" + std::to_string(shardId) + ","; - s += "\"statistics\":" + stats + ","; - s += "\"version\":\"" + std::string(kVersion) + "\""; - s += "}"; - header.emplace_back(s); - } - } - return header; -} - -std::string ShardHeader::SerializeIndexFields() { - json j; - auto fields = index_->GetFields(); - for (const auto &field : fields) { - j.push_back({{"schema_id", field.first}, {"index_field", field.second}}); - } - return j.dump(); -} - -std::vector ShardHeader::SerializePage() { - std::vector pages; - for (auto &shard_pages : pages_) { - json j; - for (const auto &p : shard_pages) { - j.emplace_back(p->GetPage()); - } - pages.emplace_back(j.dump()); - } - return pages; -} - -std::string ShardHeader::SerializeStatistics() { - json j; - for (const auto &stats : statistics_) { - j.emplace_back(stats->GetStatistics()); - } - return j.dump(); -} - -std::string ShardHeader::SerializeSchema() { - json j; - for (const auto &schema : schema_) { - j.emplace_back(schema->GetSchema()); - } - return j.dump(); -} - -std::string ShardHeader::SerializeShardAddress() { - json j; - for (const auto &addr : shard_addresses_) { - j.emplace_back(GetFileName(addr).second); - } - return j.dump(); -} - -std::pair, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) { - if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { - return std::make_pair(pages_[shard_id][page_id], SUCCESS); - } else { - return std::make_pair(nullptr, FAILED); - } -} - -MSRStatus ShardHeader::SetPage(const std::shared_ptr &new_page) { - if (new_page == nullptr) { - return FAILED; - } - int shard_id = new_page->GetShardID(); - int page_id = new_page->GetPageID(); - if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { - pages_[shard_id][page_id] = new_page; - return SUCCESS; - } else { - return FAILED; - } -} - -MSRStatus ShardHeader::AddPage(const std::shared_ptr &new_page) { - if (new_page == nullptr) { - return FAILED; - } - int shard_id = new_page->GetShardID(); - int page_id = new_page->GetPageID(); - if (shard_id < static_cast(pages_.size()) && page_id == static_cast(pages_[shard_id].size())) { - pages_[shard_id].push_back(new_page); - return SUCCESS; - } else { - return FAILED; - } -} - -int64_t ShardHeader::GetLastPageId(const int &shard_id) { - if (shard_id >= static_cast(pages_.size())) { - return 0; - } - return pages_[shard_id].size() - 1; -} - -int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &page_type) { - if (shard_id >= static_cast(pages_.size())) { - return 0; - } - int last_page_id = -1; - for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { - if (pages_[shard_id][i - 1]->GetPageType() == page_type) { - last_page_id = pages_[shard_id][i - 1]->GetPageID(); - return last_page_id; - } - } - return last_page_id; -} - -const std::pair> ShardHeader::GetPageByGroupId(const int &group_id, - const int &shard_id) { - if (shard_id >= static_cast(pages_.size())) { - MS_LOG(ERROR) << "Shard id is more than sum of shards."; - return {FAILED, nullptr}; - } - for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { - auto page = pages_[shard_id][i - 1]; - if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { - return {SUCCESS, page}; - } - } - MS_LOG(ERROR) << "Could not get page by group id " << group_id; - return {FAILED, nullptr}; -} - -int ShardHeader::AddSchema(std::shared_ptr schema) { - if (schema == nullptr) { - MS_LOG(ERROR) << "Schema is illegal"; - return -1; - } - - if (!schema_.empty()) { - MS_LOG(ERROR) << "Only support one schema"; - return -1; - } - - int64_t schema_id = schema->GetSchemaID(); - if (schema_id == -1) { - schema_id = schema_.size(); - schema->SetSchemaID(schema_id); - } - schema_.push_back(schema); - return schema_id; -} - -void ShardHeader::AddStatistic(std::shared_ptr statistic) { - if (statistic) { - int64_t statistics_id = statistic->GetStatisticsID(); - if (statistics_id == -1) { - statistics_id = statistics_.size(); - statistic->SetStatisticsID(statistics_id); - } - statistics_.push_back(statistic); - } -} - -std::shared_ptr ShardHeader::InitIndexPtr() { - std::shared_ptr index = index_; - if (!index_) { - index = std::make_shared(); - index_ = index; - } - return index; -} - -MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) { - // check field name is or is not valid - if (schema.find(field) == schema.end()) { - MS_LOG(ERROR) << "Schema do not contain the field: " << field << "."; - return FAILED; - } - - if (schema[field]["type"] == "bytes") { - MS_LOG(ERROR) << field << " is bytes type, can not be schema index field."; - return FAILED; - } - - if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) { - MS_LOG(ERROR) << field << " array can not be schema index field."; - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardHeader::AddIndexFields(const std::vector &fields) { - // create index Object - std::shared_ptr index = InitIndexPtr(); - - if (fields.size() == kInt0) { - MS_LOG(ERROR) << "There are no index fields"; - return FAILED; - } - - if (GetSchemas().empty()) { - MS_LOG(ERROR) << "No schema is set"; - return FAILED; - } - - for (const auto &schemaPtr : schema_) { - auto result = GetSchemaByID(schemaPtr->GetSchemaID()); - if (result.second != SUCCESS) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - - if (result.first == nullptr) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - - json schema = result.first->GetSchema().at("schema"); - - // checkout and add fields for each schema - std::set field_set; - for (const auto &item : index->GetFields()) { - field_set.insert(item.second); - } - for (const auto &field : fields) { - if (field_set.find(field) != field_set.end()) { - MS_LOG(ERROR) << "Add same index field twice"; - return FAILED; - } - - // check field name is or is not valid - if (CheckIndexField(field, schema) == FAILED) { - return FAILED; - } - field_set.insert(field); - - // add field into index - index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); - } - } - - index_ = index; - return SUCCESS; -} - -MSRStatus ShardHeader::GetAllSchemaID(std::set &bucket_count) { - // get all schema id - for (const auto &schema : schema_) { - auto bucket_it = bucket_count.find(schema->GetSchemaID()); - if (bucket_it != bucket_count.end()) { - MS_LOG(ERROR) << "Schema duplication"; - return FAILED; - } else { - bucket_count.insert(schema->GetSchemaID()); - } - } - return SUCCESS; -} - -MSRStatus ShardHeader::AddIndexFields(std::vector> fields) { - // create index Object - std::shared_ptr index = InitIndexPtr(); - - if (fields.size() == kInt0) { - MS_LOG(ERROR) << "There are no index fields"; - return FAILED; - } - - // get all schema id - std::set bucket_count; - if (GetAllSchemaID(bucket_count) != SUCCESS) { - return FAILED; - } - - // check and add fields for each schema - std::set> field_set; - for (const auto &item : index->GetFields()) { - field_set.insert(item); - } - for (const auto &field : fields) { - if (field_set.find(field) != field_set.end()) { - MS_LOG(ERROR) << "Add same index field twice"; - return FAILED; - } - - uint64_t schema_id = field.first; - std::string field_name = field.second; - - // check schemaId is or is not valid - if (bucket_count.find(schema_id) == bucket_count.end()) { - MS_LOG(ERROR) << "Illegal schema id: " << schema_id; - return FAILED; - } - - // check field name is or is not valid - auto result = GetSchemaByID(schema_id); - if (result.second != SUCCESS) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - json schema = result.first->GetSchema().at("schema"); - if (schema.find(field_name) == schema.end()) { - MS_LOG(ERROR) << "Schema " << schema_id << " do not contain the field: " << field_name; - return FAILED; - } - - if (CheckIndexField(field_name, schema) == FAILED) { - return FAILED; - } - - field_set.insert(field); - - // add field into index - index.get()->AddIndexField(schema_id, field_name); - } - index_ = index; - return SUCCESS; -} - -std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { - if (shard_id >= shard_addresses_.size()) { - return ""; - } - return shard_addresses_.at(shard_id); -} - -std::vector> ShardHeader::GetSchemas() { return schema_; } - -std::vector> ShardHeader::GetStatistics() { return statistics_; } - -std::vector> ShardHeader::GetFields() { return index_->GetFields(); } - -std::shared_ptr ShardHeader::GetIndex() { return index_; } - -std::pair, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { - int64_t schemaSize = schema_.size(); - if (schema_id < 0 || schema_id >= schemaSize) { - MS_LOG(ERROR) << "Illegal schema id"; - return std::make_pair(nullptr, FAILED); - } - return std::make_pair(schema_.at(schema_id), SUCCESS); -} - -std::pair, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) { - int64_t statistics_size = statistics_.size(); - if (statistic_id < 0 || statistic_id >= statistics_size) { - return std::make_pair(nullptr, FAILED); - } - return std::make_pair(statistics_.at(statistic_id), SUCCESS); -} - -MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { - // write header content to file, dump whatever is in the file before - std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); - if (page_out_handle.fail()) { - MS_LOG(ERROR) << "Failed in opening page file"; - return FAILED; - } - - auto pages = SerializePage(); - for (const auto &shard_pages : pages) { - page_out_handle << shard_pages << "\n"; - } - - page_out_handle.close(); - return SUCCESS; -} - -MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { - for (auto &v : pages_) { // clean pages - v.clear(); - } - // attempt to open the file contains the page in json - std::ifstream page_in_handle(dump_file_name.c_str()); - - if (!page_in_handle.good()) { - MS_LOG(INFO) << "No page file exists."; - return SUCCESS; - } - - std::string line; - while (std::getline(page_in_handle, line)) { - ParsePage(json::parse(line), -1, true); - } - - page_in_handle.close(); - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_index.cc b/mindspore/ccsrc/mindrecord/meta/shard_index.cc deleted file mode 100644 index 8b7a3c0342..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_index.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_index.h" - -namespace mindspore { -namespace mindrecord { -// table name for index -const char TABLENAME[] = "index_table"; - -Index::Index() : database_name_(""), table_name_(TABLENAME) {} - -void Index::AddIndexField(const int64_t &schemaId, const std::string &field) { - fields_.emplace_back(pair(schemaId, field)); -} - -// Get attribute list -std::vector> Index::GetFields() { return fields_; } -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_page.cc b/mindspore/ccsrc/mindrecord/meta/shard_page.cc deleted file mode 100644 index 6bb849ae1d..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_page.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_page.h" -#include "pybind11/pybind11.h" - -namespace mindspore { -namespace mindrecord { -json Page::GetPage() const { - json str_page; - str_page["page_id"] = page_id_; - str_page["shard_id"] = shard_id_; - str_page["page_type"] = page_type_; - str_page["page_type_id"] = page_type_id_; - str_page["start_row_id"] = start_row_id_; - str_page["end_row_id"] = end_row_id_; - if (row_group_ids_.size() == 0) { - json row_groups = json({}); - row_groups["id"] = 0; - row_groups["offset"] = 0; - str_page["row_group_ids"].push_back(row_groups); - } else { - for (const auto &rg : row_group_ids_) { - json row_groups = json({}); - row_groups["id"] = rg.first; - row_groups["offset"] = rg.second; - str_page["row_group_ids"].push_back(row_groups); - } - } - str_page["page_size"] = page_size_; - return str_page; -} - -void Page::DeleteLastGroupId() { - if (!row_group_ids_.empty()) { - page_size_ = row_group_ids_.back().second; - row_group_ids_.pop_back(); - } -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc deleted file mode 100644 index fac2fec708..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_pk_sample.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) - : ShardCategory(category_field, num_elements, std::numeric_limits::max(), true), shuffle_(false) {} - -ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) - : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} - -ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, - uint32_t seed) - : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { - shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement -} - -MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) { - if (shuffle_ == true) { - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } - } - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc deleted file mode 100644 index c207747194..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_sample.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -ShardSample::ShardSample(int n) - : numerator_(0), - denominator_(0), - partition_id_(0), - no_of_samples_(n), - indices_({}), - sampler_type_(kCustomTopNSampler) {} - -ShardSample::ShardSample(int num, int den) - : numerator_(num), - denominator_(den), - partition_id_(0), - no_of_samples_(0), - indices_({}), - sampler_type_(kCustomTopPercentSampler) {} - -ShardSample::ShardSample(int num, int den, int par) - : numerator_(num), - denominator_(den), - partition_id_(par), - no_of_samples_(0), - indices_({}), - sampler_type_(kCustomTopPercentSampler) {} - -ShardSample::ShardSample(const std::vector &indices, uint32_t seed) - : numerator_(0), - denominator_(0), - partition_id_(0), - no_of_samples_(0), - indices_(indices), - sampler_type_(kSubsetRandomSampler) { - shuffle_op_ = std::make_shared(seed); -} - -int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { - if (sampler_type_ == kCustomTopNSampler) { - return no_of_samples_; - } - - if (sampler_type_ == kCustomTopPercentSampler) { - if (dataset_size % denominator_ == 0) { - return dataset_size / denominator_ * numerator_; - } else { - return dataset_size / denominator_ * numerator_ + 1; - } - } - if (sampler_type_ == kSubsetRandomSampler) { - return indices_.size(); - } - return 0; -} - -MSRStatus ShardSample::Execute(ShardTask &tasks) { - int no_of_categories = static_cast(tasks.categories); - int total_no = static_cast(tasks.Size()); // make sure task_size - - int taking = 0; - if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1 - no_of_samples_ = std::min(no_of_samples_, total_no); - taking = no_of_samples_ - no_of_samples_ % no_of_categories; - } else if (sampler_type_ == kSubsetRandomSampler) { - if (indices_.size() > total_no) { - MS_LOG(ERROR) << "parameter indices's size is greater than dataset size."; - return FAILED; - } - } else { // constructor TopPercent - if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { - if (numerator_ == 1 && denominator_ > 1) { // sharding - taking = (total_no + denominator_ - 1) / denominator_; - } else { // non sharding - taking = total_no * numerator_ / denominator_; - taking -= (taking % no_of_categories); - } - } else { - MS_LOG(ERROR) << "parameter numerator or denominator is illegal"; - return FAILED; - } - } - - if (tasks.permutation_.empty()) { - ShardTask new_tasks; - total_no = static_cast(tasks.Size()); - if (sampler_type_ == kSubsetRandomSampler) { - for (int i = 0; i < indices_.size(); ++i) { - int index = ((indices_[i] % total_no) + total_no) % total_no; - new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python - } - } else { - for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { - new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start - } - } - std::swap(tasks, new_tasks); - } else { - ShardTask new_tasks; - if (taking > static_cast(tasks.permutation_.size())) { - return FAILED; - } - total_no = static_cast(tasks.permutation_.size()); - for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { - new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); - } - std::swap(tasks, new_tasks); - } - return SUCCESS; -} - -MSRStatus ShardSample::SufExecute(ShardTask &tasks) { - if (sampler_type_ == kSubsetRandomSampler) { - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } - } - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_schema.cc b/mindspore/ccsrc/mindrecord/meta/shard_schema.cc deleted file mode 100644 index ee0f5afa4a..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_schema.cc +++ /dev/null @@ -1,164 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_schema.h" -#include "common/utils.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -std::shared_ptr Schema::Build(std::string desc, const json &schema) { - // validate check - if (!Validate(schema)) { - return nullptr; - } - - std::vector blob_fields = PopulateBlobFields(schema); - Schema object_schema; - object_schema.desc_ = std::move(desc); - object_schema.blob_fields_ = std::move(blob_fields); - object_schema.schema_ = schema; - object_schema.schema_id_ = -1; - return std::make_shared(object_schema); -} - -std::shared_ptr Schema::Build(std::string desc, pybind11::handle schema) { - // validate check - json schema_json = nlohmann::detail::ToJsonImpl(schema); - return Build(std::move(desc), schema_json); -} - -std::string Schema::GetDesc() const { return desc_; } - -json Schema::GetSchema() const { - json str_schema; - str_schema["desc"] = desc_; - str_schema["schema"] = schema_; - str_schema["blob_fields"] = blob_fields_; - return str_schema; -} - -pybind11::object Schema::GetSchemaForPython() const { - json schema_json = GetSchema(); - pybind11::object schema_py = nlohmann::detail::FromJsonImpl(schema_json); - return schema_py; -} - -void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } - -int64_t Schema::GetSchemaID() const { return schema_id_; } - -std::vector Schema::GetBlobFields() const { return blob_fields_; } - -std::vector Schema::PopulateBlobFields(json schema) { - std::vector blob_fields; - for (json::iterator it = schema.begin(); it != schema.end(); ++it) { - json it_value = it.value(); - if ((it_value.size() == kInt2 && it_value.find("shape") != it_value.end()) || it_value["type"] == "bytes") { - blob_fields.emplace_back(it.key()); - } - } - return blob_fields; -} - -bool Schema::ValidateNumberShape(const json &it_value) { - if (it_value.find("shape") == it_value.end()) { - MS_LOG(ERROR) << "%s supports shape only." << it_value["type"].dump(); - return false; - } - - auto shape = it_value["shape"]; - if (!shape.is_array()) { - MS_LOG(ERROR) << "%s shape format is wrong." << it_value["type"].dump(); - return false; - } - - int num_negtive_one = 0; - for (const auto &i : shape) { - if (i == 0 || i < -1) { - MS_LOG(ERROR) << "Shape %s, number is wrong." << it_value["shape"].dump(); - return false; - } - if (i == -1) { - num_negtive_one++; - } - } - - if (num_negtive_one > 1) { - MS_LOG(ERROR) << "Shape %s, have at most 1 variable-length dimension." << it_value["shape"].dump(); - return false; - } - - return true; -} - -bool Schema::Validate(json schema) { - if (schema.size() == kInt0) { - MS_LOG(ERROR) << "Schema is null"; - return false; - } - - for (json::iterator it = schema.begin(); it != schema.end(); ++it) { - // make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_' - if (!ValidateFieldName(it.key())) { - MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', fieldName: " << it.key(); - return false; - } - - json it_value = it.value(); - if (it_value.find("type") == it_value.end()) { - MS_LOG(ERROR) << "No 'type' field exist: " << it_value.dump(); - return false; - } - - if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) { - MS_LOG(ERROR) << "Wrong type: " << it_value["type"].dump(); - return false; - } - - if (it_value.size() == kInt1) { - continue; - } - - if (it_value["type"] == "bytes" || it_value["type"] == "string") { - MS_LOG(ERROR) << it_value["type"].dump() << " can not 1 field only."; - return false; - } - - if (it_value.size() != kInt2) { - MS_LOG(ERROR) << it_value["type"].dump() << " can have at most 2 fields."; - return false; - } - - if (!ValidateNumberShape(it_value)) { - return false; - } - } - - return true; -} - -bool Schema::operator==(const mindrecord::Schema &b) const { - if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) { - return false; - } - return true; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc deleted file mode 100644 index a7fa4e7343..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_sequential_sample.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -ShardSequentialSample::ShardSequentialSample(int n, int offset) - : ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {} - -ShardSequentialSample::ShardSequentialSample(float per, float per_offset) - : ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {} - -int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { - if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { - return dataset_size; - } - if (per_ > kEpsilon && per_ <= 1.0f) { - return dataset_size * kEpsilon; - } - return no_of_samples_; -} - -MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) { - int total_no = static_cast(tasks.Size()); - int taking; - if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { - taking = total_no; - } else if (per_ > kEpsilon && per_ <= 1.0f) { - taking = total_no * kEpsilon; - } else { - taking = no_of_samples_; - } - - if (tasks.permutation_.empty()) { - ShardTask new_tasks; - total_no = static_cast(tasks.Size()); - for (int i = offset_; i < taking + offset_; ++i) { - new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); - } - std::swap(tasks, new_tasks); - } else { // shuffled - ShardTask new_tasks; - if (taking > static_cast(tasks.permutation_.size())) { - return FAILED; - } - total_no = static_cast(tasks.permutation_.size()); - for (size_t i = offset_; i < taking + offset_; ++i) { - new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); - } - std::swap(tasks, new_tasks); - } - return SUCCESS; -} - -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc deleted file mode 100644 index 5cf49b04f0..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_shuffle.h" - -#include - -namespace mindspore { -namespace mindrecord { -ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) - : shuffle_seed_(seed), - no_of_samples_(0), - replacement_(false), - reshuffle_each_epoch_(true), - shuffle_type_(shuffle_type) {} - -ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch, - ShuffleType shuffle_type) - : shuffle_seed_(seed), - no_of_samples_(no_of_samples), - replacement_(replacement), - reshuffle_each_epoch_(reshuffle_each_epoch), - shuffle_type_(shuffle_type) {} - -int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { - if (replacement_) { - return no_of_samples_ == 0 ? dataset_size : no_of_samples_; - } - return dataset_size; -} - -MSRStatus ShardShuffle::Execute(ShardTask &tasks) { - if (reshuffle_each_epoch_) shuffle_seed_++; - if (tasks.categories < 1) { - return FAILED; - } - if (shuffle_type_ == kShuffleSample) { // shuffle each sample - if (tasks.permutation_.empty() == true) { - tasks.MakePerm(); - } - if (replacement_ == true) { - ShardTask new_tasks; - if (no_of_samples_ == 0) { - no_of_samples_ = static_cast(tasks.Size()); - } - if (no_of_samples_ <= 0) { - MS_LOG(ERROR) << "no_of_samples need to be positive."; - return FAILED; - } - new_tasks.task_list_.reserve(no_of_samples_); - for (uint32_t i = 0; i < no_of_samples_; ++i) { - new_tasks.InsertTask(tasks.GetRandomTask()); - } - std::swap(tasks, new_tasks); - } else { - std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); - } - } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) - uint32_t individual_size = tasks.Size() / tasks.categories; - std::vector> new_permutations(tasks.categories, std::vector(individual_size)); - for (uint32_t i = 0; i < tasks.categories; i++) { - for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); - std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); - } - tasks.permutation_.clear(); - for (uint32_t j = 0; j < individual_size; j++) { - for (uint32_t i = 0; i < tasks.categories; i++) { - tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); - } - } - } - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc b/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc deleted file mode 100644 index ca36c50863..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_statistics.h" -#include "pybind11/pybind11.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -std::shared_ptr Statistics::Build(std::string desc, const json &statistics) { - // validate check - if (!Validate(statistics)) { - return nullptr; - } - Statistics object_statistics; - object_statistics.desc_ = std::move(desc); - object_statistics.statistics_ = statistics; - object_statistics.statistics_id_ = -1; - return std::make_shared(object_statistics); -} - -std::shared_ptr Statistics::Build(std::string desc, pybind11::handle statistics) { - // validate check - json statistics_json = nlohmann::detail::ToJsonImpl(statistics); - if (!Validate(statistics_json)) { - return nullptr; - } - Statistics object_statistics; - object_statistics.desc_ = std::move(desc); - object_statistics.statistics_ = statistics_json; - object_statistics.statistics_id_ = -1; - return std::make_shared(object_statistics); -} - -std::string Statistics::GetDesc() const { return desc_; } - -json Statistics::GetStatistics() const { - json str_statistics; - str_statistics["desc"] = desc_; - str_statistics["statistics"] = statistics_; - return str_statistics; -} - -pybind11::object Statistics::GetStatisticsForPython() const { - json str_statistics = Statistics::GetStatistics(); - return nlohmann::detail::FromJsonImpl(str_statistics); -} - -void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } - -int64_t Statistics::GetStatisticsID() const { return statistics_id_; } - -bool Statistics::Validate(const json &statistics) { - if (statistics.size() != kInt1) { - MS_LOG(ERROR) << "Statistics object is null"; - return false; - } - if (statistics.find("level") == statistics.end()) { - MS_LOG(ERROR) << "There is not 'level' object in statistic"; - return false; - } - return LevelRecursive(statistics["level"]); -} - -bool Statistics::LevelRecursive(json level) { - bool ini = true; - for (json::iterator it = level.begin(); it != level.end(); ++it) { - json a = it.value(); - if (a.size() == kInt2) { - if ((a.find("key") == a.end()) || (a.find("count") == a.end())) { - MS_LOG(ERROR) << "The node field is 2, but 'key'/'count' is not existed"; - return false; - } - } else if (a.size() == kInt3) { - if ((a.find("key") == a.end()) || (a.find("count") == a.end()) || a.find("level") == a.end()) { - MS_LOG(ERROR) << "The node field is 3, but 'key'/'count'/'level' is not existed"; - return false; - } else { - ini = LevelRecursive(a.at("level")); - } - } else { - MS_LOG(ERROR) << "The node field is not equal 2/3"; - return false; - } - } - return ini; -} - -bool Statistics::operator==(const Statistics &b) const { - if (this->GetStatistics() != b.GetStatistics()) { - return false; - } - return true; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/mindrecord/meta/shard_task.cc deleted file mode 100644 index 8baa3c26cd..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_task.cc +++ /dev/null @@ -1,121 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_task.h" -#include "common/utils.h" -#include "mindrecord/include/common/shard_utils.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::DEBUG; - -namespace mindspore { -namespace mindrecord { -ShardTask::ShardTask() : categories(1) {} - -ShardTask::ShardTask(const ShardTask &other) - : categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {} - -ShardTask &ShardTask::operator=(const ShardTask &other) { - ShardTask tmp(other); - std::swap(categories, tmp.categories); - permutation_.swap(tmp.permutation_); - task_list_.swap(tmp.task_list_); - return *this; -} - -void ShardTask::MakePerm() { - permutation_ = std::vector(task_list_.size()); - for (uint32_t i = 0; i < task_list_.size(); i++) { - permutation_[i] = static_cast(i); - } -} - -void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, - const json &label) { - MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id - << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; - task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); -} - -void ShardTask::InsertTask(std::tuple, std::vector, json> task) { - MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task)) - << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() - << ", size of task_list_: " << task_list_.size() << "."; - - task_list_.push_back(std::move(task)); -} - -void ShardTask::PopBack() { task_list_.pop_back(); } - -uint32_t ShardTask::Size() const { return static_cast(task_list_.size()); } - -uint32_t ShardTask::SizeOfRows() const { - if (task_list_.size() == 0) return static_cast(0); - - // 1 task is 1 page - auto sum_num_rows = [](int x, std::tuple, std::vector, json> y) { - return x + std::get<2>(y)[0]; - }; - uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows); - return nRows; -} - -std::tuple, std::vector, json> &ShardTask::GetTaskByID(size_t id) { - MS_ASSERT(id < task_list_.size()); - return task_list_[id]; -} - -std::tuple, std::vector, json> &ShardTask::GetRandomTask() { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, task_list_.size() - 1); - return task_list_[dis(gen)]; -} - -ShardTask ShardTask::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements) { - ShardTask res; - if (category_tasks.empty()) return res; - auto total_categories = category_tasks.size(); - res.categories = static_cast(total_categories); - if (replacement == false) { - auto minTasks = category_tasks[0].Size(); - for (uint32_t i = 1; i < total_categories; i++) { - minTasks = std::min(minTasks, category_tasks[i].Size()); - } - for (uint32_t task_no = 0; task_no < minTasks; task_no++) { - for (uint32_t i = 0; i < total_categories; i++) { - res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast(task_no)))); - } - } - } else { - auto maxTasks = category_tasks[0].Size(); - for (uint32_t i = 1; i < total_categories; i++) { - maxTasks = std::max(maxTasks, category_tasks[i].Size()); - } - if (num_elements != std::numeric_limits::max()) { - maxTasks = static_cast(num_elements); - } - for (uint32_t i = 0; i < total_categories; i++) { - for (uint32_t j = 0; j < maxTasks; j++) { - res.InsertTask(category_tasks[i].GetRandomTask()); - } - } - } - return res; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/onnx/CMakeLists.txt b/mindspore/ccsrc/onnx/CMakeLists.txt deleted file mode 100644 index a65ea6d450..0000000000 --- a/mindspore/ccsrc/onnx/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB_RECURSE _ONNX_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX) -add_library(_mindspore_onnx_obj OBJECT ${_ONNX_SRC_FILES}) diff --git a/mindspore/ccsrc/onnx/ir_exporter.cc b/mindspore/ccsrc/onnx/ir_exporter.cc deleted file mode 100644 index 2f02f483f5..0000000000 --- a/mindspore/ccsrc/onnx/ir_exporter.cc +++ /dev/null @@ -1,622 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/tensor_py.h" -#include "ir/param_value_py.h" -#include "debug/anf_ir_utils.h" -#include "operator/ops.h" -#include "proto/onnx.pb.h" - -namespace mindspore { -using FloatPtr = std::shared_ptr; -using IntPtr = std::shared_ptr; - -// anf type to onnx type map -static std::unordered_map g_data_type_map = { - {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, - {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, - {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, - {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, - {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, - {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, - {kObjectTypeString, onnx::TensorProto_DataType_STRING}, -}; - -static std::unordered_map g_data_bits_int_map = { - {8, onnx::TensorProto_DataType_INT8}, - {16, onnx::TensorProto_DataType_INT16}, - {32, onnx::TensorProto_DataType_INT32}, - {64, onnx::TensorProto_DataType_INT64}, -}; - -static std::unordered_map g_data_bits_float_map = { - {16, onnx::TensorProto_DataType_FLOAT16}, - {32, onnx::TensorProto_DataType_FLOAT}, -}; - -// Can build different builder according to format -class IrExportBuilder; -using IrExportBuilderPtr = std::shared_ptr; - -class IrExporter { - public: - explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {} - virtual ~IrExporter() = default; - std::string GetDumpString(const FuncGraphPtr &func_graph); - - private: - IrExportBuilderPtr builder_; -}; - -class IrExportBuilder { - public: - IrExportBuilder() = default; - ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); } - std::string GetProtoString(const FuncGraphPtr &func_graph); - void BuildModelInfo(); - void BuildModel(const FuncGraphPtr &func_graph); - - private: - void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); - void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); - void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); - void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto); - void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto); - std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto); - - void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto); - void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto); - void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto); - void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); - void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); - void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, - std::string suffix = "0"); - void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); - void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); - void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); - - onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); - onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); - onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits); - std::string GetNodeName(const AnfNodePtr &node); - std::string GetUniqueNodeName(const AnfNodePtr &node); - std::string GetOpTypeName(const AnfNodePtr &node); - size_t AllocateIndex() { return ++node_index_; } - void ResetIndex() { node_index_ = 0; } - - private: - onnx::ModelProto model_; - onnx::NodeProto *last_node_{nullptr}; - std::list todo_; - std::map node_index_map_; - size_t node_index_{0}; -}; - -using IrExporterPtr = std::shared_ptr; - -std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) { - if ((builder_ == nullptr) || (func_graph == nullptr)) { - MS_LOG(EXCEPTION) << "Input params is null."; - } - - // Export model info - builder_->BuildModelInfo(); - - // Export model and return string - builder_->BuildModel(func_graph); - - return builder_->GetProtoString(func_graph); -} - -std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) { - MS_LOG(DEBUG) << "BuildModel complete!"; - return model_.SerializeAsString(); -} - -void IrExportBuilder::BuildModelInfo() { - model_.set_ir_version(onnx::IR_VERSION_2019_1_22); - model_.set_producer_name("MindSpore"); - model_.set_model_version(1); -} - -void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { - onnx::GraphProto *graph_proto = model_.mutable_graph(); - graph_proto->set_name(func_graph->ToString()); - ResetIndex(); - todo_.clear(); - todo_.push_back(func_graph); - while (!todo_.empty()) { - FuncGraphPtr fg = todo_.back(); - todo_.pop_back(); - BuildFuncGraph(fg, graph_proto); - } -} - -void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { - // Export parameters - // 1. parameters should be mapped to ValueInfoProto - // 2. parameters with default value should be mapped to Initializer - BuildParameters(func_graph, graph_proto); - - // Export operator nodes(include output) - BuildNodes(func_graph, graph_proto); -} - -void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { - for (auto &item : func_graph->parameters()) { - auto param = item->cast(); - if (param == nullptr) { - MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; - } - onnx::ValueInfoProto *input_proto = graph_proto->add_input(); - std::string param_name = GetUniqueNodeName(param); - input_proto->set_name(param_name); - SetValueInfoProto(param, input_proto); - if (!param->has_default()) { - MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; - continue; - } - - // Using ONNX initializer to set parameter's default value - onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); - initializer_proto->set_name(param_name); - SetParamToTensorProto(param, initializer_proto); - auto param_value = std::dynamic_pointer_cast(param->default_param()); - py::object obj = param_value->value(); - py::object data = obj.attr("data"); - if (py::isinstance(data)) { - auto method = data.attr("asnumpy"); - py::array npy_data = method(); - initializer_proto->set_raw_data(npy_data.request(true).ptr, static_cast(npy_data.nbytes())); - } - } -} - -onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) { - auto iter = g_data_type_map.find(type_id); - if (iter == g_data_type_map.end()) { - MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id; - } - return iter->second; -} - -onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) { - auto iter = g_data_bits_int_map.find(bits); - if (iter == g_data_bits_int_map.end()) { - MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits; - } - return iter->second; -} - -onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) { - auto iter = g_data_bits_float_map.find(bits); - if (iter == g_data_bits_float_map.end()) { - MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits; - } - return iter->second; -} - -void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) { - if (node == nullptr || value_proto == nullptr) { - MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; - } - MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); - SetValueInfoProto(node->Type(), node->Shape(), value_proto); -} - -void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::ValueInfoProto *const value_proto) { - onnx::TypeProto *type_proto = value_proto->mutable_type(); - if (type->isa() && shape->isa()) { - auto tensor = type->cast(); - auto elem_type = tensor->element(); - const auto &dims = shape->cast()->shape(); - type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); - for (const auto &dim : dims) { - MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); - } - } else if (type->isa()) { - auto tup_shape = shape->cast(); - type_proto->set_denotation(std::to_string(tup_shape->shape().size())); - } else { - MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; - } -} - -void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { - if (value == nullptr || attr_proto == nullptr) { - MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; - } - attr_proto->set_ref_attr_name("tensor"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - auto data = value->cast(); - tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); - auto dtype = data->data_type(); - auto shape = data->shape_c(); - tensor_proto->set_data_type(GetOnnxDataType(dtype)); - for (const auto &dim : shape) { - tensor_proto->add_dims(dim); - } -} - -void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::TensorProto *const tensor_proto) { - if (!type->isa() || !shape->isa()) { - MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString(); - } - auto tensor = type->cast(); - const auto &dims = shape->cast()->shape(); - tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id())); - for (const auto &dim : dims) { - tensor_proto->add_dims(dim); - } -} - -void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { - if (param == nullptr || tensor_proto == nullptr) { - MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; - } - MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString(); - SetTensorProto(param->Type(), param->Shape(), tensor_proto); -} - -void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { - std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); - for (const AnfNodePtr &node : nodes) { - if (!node->isa()) { - MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; - continue; - } - auto cnode = node->cast(); - if (cnode == func_graph->get_return()) { - BuildOutput(cnode, graph_proto); - } else { - BuildCNode(cnode, graph_proto); - } - } -} - -void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) { - if (node->size() != 2) { - MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; - } - AnfNodePtr arg = node->input(1); - // Using make_tuple to set multi-output - if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { - auto tuple_node = arg->cast(); - for (size_t i = 1; i < tuple_node->size(); i++) { - auto input_node = arg->cast()->input(i); - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - auto output_name = GetUniqueNodeName(tuple_node->input(i)); - output_proto->set_name(output_name); - last_node_->add_output(output_name); - SetValueInfoProto(tuple_node->input(i), output_proto); - } - } else { - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - std::string output_name = GetUniqueNodeName(node); - output_proto->set_name(output_name); - last_node_->add_output(output_name); - SetValueInfoProto(arg, output_proto); - } -} - -std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { - // May be ValueNode/CNode/Parameter - std::string type_name = ""; - if (IsValueNode(node)) { - PrimitivePtr prim = GetValueNode(node); - type_name = prim->ToString(); - } else if (IsValueNode(node)) { - FuncGraphPtr fg = GetValueNode(node); - todo_.push_back(fg); - type_name = fg->ToString(); - } else if (node->isa() || node->isa()) { - type_name = node->ToString(); - } else { - MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name(); - } - MS_LOG(DEBUG) << "ExportType: " << type_name; - return type_name; -} - -void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, - onnx::NodeProto *const node_proto, std::string suffix) { - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_ref_attr_name("shape"); - if (suffix.compare("0") != 0) { - attr_proto->set_name("shape" + suffix); - } else { - attr_proto->set_name("shape"); - } - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - SetTensorProto(type, shape, tensor_proto); -} - -void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { - // Get shape of cnode - // 1. prim ArgMaxWithValue need to get shape from tuple element - // 2. some cnode doesn't has shape, such as LayerNorm - // 3. other cnodes have shape - if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { - auto type = node->Type(); - auto shape = node->Shape(); - if (!type->isa()) { - MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); - } - auto elements = type->cast()->elements(); - auto tuple_shape = shape->cast()->shape(); - for (size_t i = 0; i < elements.size(); i++) { - SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); - } - } else { - auto type = node->Type(); - auto shape = node->Shape(); - if (!type->isa() || !shape->isa()) { - MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); - return; - } - SetShapeToNodeProto(type, shape, node_proto); - } -} - -void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { - auto inputs_size = node->size(); - if (inputs_size < 1) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - - // Need to build input node before dealing with cnode - std::vector op_inputs; - std::vector input_names; - for (size_t i = 1; i < inputs_size; i++) { - auto input = node->input(i); - op_inputs.push_back(input); - input_names.push_back(BuildInputNode(input, graph_proto)); - } - - // Build cnode - onnx::NodeProto *node_proto = graph_proto->add_node(); - std::string output_name = GetUniqueNodeName(node); - node_proto->add_output(output_name); - node_proto->set_name(output_name); - node_proto->set_domain(node->fullname_with_scope()); - AnfNodePtr op = node->input(0); - std::string type_name = GetOpTypeName(op); - node_proto->set_op_type(type_name); - last_node_ = node_proto; - SetShapeToNodeProto(node, node_proto); - (void)std::for_each(input_names.begin(), input_names.end(), - [&node_proto](const string &name) { node_proto->add_input(name); }); - - // Add primitive attrs - if (IsValueNode(op)) { - auto prim = GetValueNode(op); - for (auto attr : prim->attrs()) { - MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name(attr.first); - SetValueToAttributeProto(attr.second, attr_proto); - } - } else { - MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name(); - } -} - -std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) { - std::string node_name = GetUniqueNodeName(node); - if (node->isa()) { - // When node input is a ValueNode, need to create a Constant Node - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->add_output(node_name); - SetAttributeProto(node, node_proto); - } - return node_name; -} - -std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { - // Naming anfnode - // 1. parameter is unique in one func_graph - // 2. cnode and valuenode may be reduplicative, so add index to identify. - std::string node_name = ""; - if (node->isa()) { - node_name = GetNodeName(node); - } else if (node->isa() || node->isa()) { - auto iter = node_index_map_.find(node); - if (iter != node_index_map_.end()) { - node_name = GetNodeName(node) + ":" + std::to_string(iter->second); - } else { - auto node_idx = AllocateIndex(); - node_index_map_[node] = node_idx; - node_name = GetNodeName(node) + ":" + std::to_string(node_idx); - } - } else { - MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); - } - MS_LOG(DEBUG) << "Node name: " << node_name; - return node_name; -} - -std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { - std::string node_name = ""; - if ((node != nullptr) && (node->func_graph() != nullptr)) { - node_name = node->func_graph()->ToString() + ":"; - } - node_name += node->ToString(); - MS_LOG(DEBUG) << "GetNodeName: " << node_name; - return node_name; -} - -void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) { - if (node == nullptr || node_proto == nullptr) { - MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; - } - auto value = node->cast()->value(); - node_proto->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("value"); - MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString(); - SetValueToAttributeProto(value, attr_proto); -} - -void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { - if (value == nullptr || attr_proto == nullptr) { - MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; - } - attr_proto->set_ref_attr_name("type"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - if (value->isa()) { - auto int_value = value->cast(); - tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); - } else if (value->isa()) { - auto float_value = value->cast(); - tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); - } else if (value->isa()) { - tensor_proto->set_name("tensor"); - auto elem_type = value->cast()->element(); - if (elem_type->isa()) { - auto int_value = elem_type->cast(); - tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); - } else if (elem_type->isa()) { - auto float_value = elem_type->cast(); - tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); - } else { - MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name(); - } - } else { - MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); - } -} - -void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { - if (value == nullptr || attr_proto == nullptr) { - MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; - } - if (value->isa() || value->isa()) { - SetScalarToAttributeProto(value, attr_proto); - } else if (value->isa() || value->isa()) { - SetTypeToAttributeProto(value, attr_proto); - } else if (value->isa()) { - SetSequenceToAttributeProto(value->cast(), attr_proto); - } else if (value->isa()) { - SetTensorToAttributeProto(value, attr_proto); - } else { - MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); - } -} - -void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { - if (value == nullptr || attr_proto == nullptr) { - MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; - } - attr_proto->set_ref_attr_name("scalar"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - SetScalarToProto(value, tensor_proto); -} - -void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { - if (value == nullptr || tensor_proto == nullptr) { - MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; - } - if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); - tensor_proto->add_string_data(GetValue(value)); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); - tensor_proto->add_int32_data(GetValue(value)); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); - tensor_proto->add_int32_data(value->cast()->value()); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); - tensor_proto->add_int32_data(value->cast()->value()); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); - tensor_proto->add_int32_data(value->cast()->value()); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); - tensor_proto->add_int64_data(value->cast()->value()); - } else if (value->isa()) { - tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); - tensor_proto->add_float_data(GetValue(value)); - } else { - MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); - } -} - -void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, - onnx::AttributeProto *const attr_proto) { - if (value == nullptr || attr_proto == nullptr) { - MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; - } - attr_proto->set_ref_attr_name("scalar"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - if (value->isa()) { - const ValueTuplePtr &tuple_value = value->cast(); - if (tuple_value->value().size() == 0) { - MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; - return; - } - auto type_id = tuple_value->value()[0]->type()->type_id(); - tensor_proto->set_data_type(GetOnnxDataType(type_id)); - for (const auto &item : tuple_value->value()) { - SetScalarToProto(item, tensor_proto); - } - } else if (value->isa()) { - const ValueListPtr &list_value = value->cast(); - if (list_value->value().size() == 0) { - MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; - return; - } - auto type_id = list_value->value()[0]->type()->type_id(); - tensor_proto->set_data_type(GetOnnxDataType(type_id)); - for (const auto &item : list_value->value()) { - SetScalarToProto(item, tensor_proto); - } - } -} - -std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { - auto builder = std::make_shared(); - if (builder == nullptr) { - MS_LOG(ERROR) << "Create ir exporter failed!"; - return ""; - } - auto exporter = std::make_shared(builder); - if (exporter == nullptr) { - return ""; - } - return exporter->GetDumpString(func_graph); -} -} // namespace mindspore diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc deleted file mode 100644 index 65a841246b..0000000000 --- a/mindspore/ccsrc/onnx/onnx_exporter.cc +++ /dev/null @@ -1,1211 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "debug/anf_ir_utils.h" -#include "proto/onnx.pb.h" -#include "operator/ops.h" -#include "ir/param_value_py.h" -#include "ir/tensor_py.h" - -namespace mindspore { -enum OpMergeMode { - OP_MERGE_UNDEFINED = 0, // undefined behavior - OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list - OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` - OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` - OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` - OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` -}; - -struct OpMergedInfo { - OpMergeMode mode = OP_MERGE_UNDEFINED; - int referred_count = 0; -}; - -using GenAttrFuncType = - std::function; - -template -void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { - auto casted_value = dyn_cast(value); - if (casted_value == nullptr) { - MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; - } - auto attr_value = casted_value->value(); - switch (attr_type) { - case onnx::AttributeProto_AttributeType_INT: - attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); - break; - case onnx::AttributeProto_AttributeType_FLOAT: - attr_proto->set_f(static_cast(attr_value)); - break; - case onnx::AttributeProto_AttributeType_INTS: - for (size_t i = 0; i < rep_cnt; ++i) { - attr_proto->add_ints(static_cast<::google::protobuf::int64>(attr_value)); - } - break; - case onnx::AttributeProto_AttributeType_FLOATS: - for (size_t i = 0; i < rep_cnt; ++i) { - attr_proto->add_floats(static_cast(attr_value)); - } - break; - default: - MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type; - } - attr_proto->set_type(attr_type); -} - -template -void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { - auto tuple_ptr = dyn_cast(value); - if (tuple_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; - } - switch (attr_type) { - case onnx::AttributeProto_AttributeType_INTS: - for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) { - attr_proto->add_ints(GetValue((*tuple_ptr)[i])); - } - break; - case onnx::AttributeProto_AttributeType_FLOATS: - for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) { - attr_proto->add_floats(GetValue((*tuple_ptr)[i])); - } - break; - default: - MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type; - } - attr_proto->set_type(attr_type); -} - -void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, - onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); - auto attr_value = GetValue(value); - if (attr_value == "VALID") { - attr_proto->set_s("VALID"); - } else { - attr_proto->set_s("SAME_UPPER"); - } -} - -class OpAttrInfo { - public: - OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) - : attr_name_(attr_name), - onnx_attr_name_(onnx_attr_name), - onnx_attr_type_(onnx_attr_type), - fn_gen_attr_(fn_gen_attr) {} - ~OpAttrInfo() {} - - const std::string &attr_name() const { return attr_name_; } - const std::string &onnx_attr_name() const { return onnx_attr_name_; } - onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } - GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } - - private: - std::string attr_name_; // attribute name of MindSpore - std::string onnx_attr_name_; // corresponding attribute name of ONNX - onnx::AttributeProto_AttributeType onnx_attr_type_; // corresponding attribute type of ONNX - GenAttrFuncType fn_gen_attr_; // function used convert -}; - -class OpNameInfo { - public: - OpNameInfo &set_op_type(const std::string &op_type) { - op_type_ = op_type; - return *this; - } - - const std::string &op_type() const { return op_type_; } - - OpNameInfo &set_onnx_type(const std::string &onnx_type) { - onnx_type_ = onnx_type; - return *this; - } - - const std::string &onnx_type() const { return onnx_type_; } - - OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { - op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); - return *this; - } - - const std::vector &op_attrs() const { return op_attrs_; } - - private: - std::string op_type_; // operator type of MindSpore - std::string onnx_type_; // corresponding ONNX operator type - std::vector op_attrs_; // operator attributes map info -}; - -#define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \ - OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); } - -OPERATOR_ONNX_CONVERT_DEFINE(TensorAdd, Add, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo()) - -OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo()) - -OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Squeeze, Squeeze, - OpNameInfo().Attr("axis", "axes", onnx::AttributeProto_AttributeType_INTS, - SetAttrTupleValueToProto<0>)) - -OPERATOR_ONNX_CONVERT_DEFINE( - Conv2D, Conv, - OpNameInfo() - .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) - .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto) - .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) - .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, - [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, - const PrimitivePtr &prim) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); - auto attr_value = GetValue(value); - if (attr_value == "valid") { - attr_proto->set_s("VALID"); - } else if (attr_value == "same") { - attr_proto->set_s("SAME_UPPER"); - } else { // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads' - attr_proto->set_name("pads"); - SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, attr_proto, - prim); - } - }) - .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) -OPERATOR_ONNX_CONVERT_DEFINE(BiasAdd, Add, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(MatMul, Gemm, - OpNameInfo() - .Attr("transpose_a", "transA", onnx::AttributeProto_AttributeType_INT, - SetAttrValueToProto) - .Attr("transpose_b", "transB", onnx::AttributeProto_AttributeType_INT, - SetAttrValueToProto)) - -OPERATOR_ONNX_CONVERT_DEFINE(BatchNorm, BatchNormalization, - OpNameInfo().Attr("epsilon", "epsilon", onnx::AttributeProto_AttributeType_FLOAT, - SetAttrValueToProto)) - -OPERATOR_ONNX_CONVERT_DEFINE(Reshape, Reshape, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(ReduceMean, ReduceMean, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Cast, Cast, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(PReLU, PRelu, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, - OpNameInfo() - .Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT, - SetAttrValueToProto) - .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, - [](ValuePtr, onnx::AttributeProto_AttributeType, - onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - attr_proto->set_i(0); - })) - -OPERATOR_ONNX_CONVERT_DEFINE(SimpleMean, AveragePool, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE( - MaxPool, MaxPool, - OpNameInfo() - .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) - .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) - .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) - -OPERATOR_ONNX_CONVERT_DEFINE( - MaxPoolWithArgmax, MaxPool, - OpNameInfo() - .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) - .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) - .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) - -OPERATOR_ONNX_CONVERT_DEFINE( - AvgPool, AveragePool, - OpNameInfo() - .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) - .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) - .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) - -OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) - -#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name - -void RegisterOpConverters(const std::function &fn) { - fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); - fn(OP_CONVERT_FUNCTION_NAME(Mul)()); - - fn(OP_CONVERT_FUNCTION_NAME(ReLU)()); - fn(OP_CONVERT_FUNCTION_NAME(Sigmoid)()); - - fn(OP_CONVERT_FUNCTION_NAME(Conv2D)()); - fn(OP_CONVERT_FUNCTION_NAME(Argmax)()); - - fn(OP_CONVERT_FUNCTION_NAME(Flatten)()); - fn(OP_CONVERT_FUNCTION_NAME(MaxPool)()); - fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)()); - fn(OP_CONVERT_FUNCTION_NAME(AvgPool)()); - - fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); - fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)()); - fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); - - fn(OP_CONVERT_FUNCTION_NAME(make_tuple)()); - fn(OP_CONVERT_FUNCTION_NAME(Concat)()); - fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); - fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); - fn(OP_CONVERT_FUNCTION_NAME(Sub)()); -} - -class OpConvertRegistry { - public: - ~OpConvertRegistry() { Clear(); } - - static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } - - static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } - - static OpConvertRegistry &GetSingleton() { - static OpConvertRegistry registry = OpConvertRegistry(); - return registry; - } - - static const std::unordered_map &GetOpConvertMap() { return GetSingleton().op_map_; } - - void Clear() noexcept { op_map_.clear(); } - - private: - OpConvertRegistry() {} - - std::unordered_map op_map_; -}; - -class OnnxExporter { - public: - OnnxExporter() {} - ~OnnxExporter() {} - - std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); - - private: - void InitModelInfo(); - - void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); - void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); - - size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map *node_map_ptr, - const PrimitivePtr &prim, const std::vector &inputs, - onnx::GraphProto *graph_proto); - - static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); - void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); - void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); - - void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, - std::unordered_map *op_merged_infos_ptr); - void ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - - void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - - void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - - void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - - void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - std::string GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *const graph_proto); - - void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); - void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); - - size_t AllocateNodeIndex() { return ++onnx_node_index_; } - - void ResetNodeIndex() { onnx_node_index_ = 0; } - - static int GetInt32Value(const AnfNodePtr &node) { - auto value_node_ptr = dyn_cast(node); - MS_EXCEPTION_IF_NULL(value_node_ptr); - return GetValue(value_node_ptr->value()); - } - - onnx::ModelProto model_; - - size_t onnx_node_index_ = 0; -}; - -std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { - if (func_graph == nullptr) { - return ""; - } - ResetNodeIndex(); - OpConvertRegistry::GetSingleton().Clear(); - OpConvertRegistry::RegisterAllOpConverters(); - InitModelInfo(); - onnx::GraphProto *graph_proto = model_.mutable_graph(); - ExportFuncGraph(func_graph, graph_proto); - return model_.SerializeAsString(); -} - -void OnnxExporter::InitModelInfo() { - model_.set_ir_version(onnx::IR_VERSION_2019_1_22); - model_.set_producer_name("MindSpore"); - model_.set_producer_version("1.0"); - onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); - opset_proto->set_version(9); -} - -void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { - std::map node_map; - - MS_LOG(INFO) << "Begin exporting onnx model for graph " << func_graph->ToString(); - - onnx_node_index_ = func_graph->parameters().size(); - - // set graph name - graph_proto->set_name(func_graph->ToString()); - - // export parameters - // 1. all parameters (with or without default value) will be mapped to ONNX parameters - // 2. parameters with default value will mapped to ONNX initializers - ExportParameters(func_graph, graph_proto); - - // export computational nodes and output nodes - ExportNodes(func_graph, &node_map, graph_proto); - - MS_LOG(INFO) << "End exporting onnx model for graph " << func_graph->ToString(); -} - -void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { - for (auto ¶m : func_graph->parameters()) { - const ParameterPtr param_ptr = dyn_cast(param); - if (param_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; - } - - onnx::ValueInfoProto *input_proto = graph_proto->add_input(); - input_proto->set_name(param_ptr->ToString()); - SetValueInfoType(param_ptr, input_proto); - - if (!param_ptr->has_default()) { - continue; - } - // parameter with default value is an ONNX initializer - onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); - initializer_proto->set_name(param_ptr->ToString()); - SetTensorProtoInfo(param_ptr, initializer_proto); - // set value for initializer - auto param_value = std::dynamic_pointer_cast(param_ptr->default_param()); - py::object obj = param_value->value(); - py::object data = obj.attr("data"); - if (py::isinstance(data)) { - auto method = data.attr("asnumpy"); - py::array npy_data = method(); - initializer_proto->set_raw_data(npy_data.request(true).ptr, static_cast(npy_data.nbytes())); - } - } -} - -onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) { - // clang-format off - static std::unordered_map type_map = { - {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, - {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, - {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, - {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, - {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, - {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, - {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, - {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, - {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, - {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, - {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, - {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, - }; - // clang-format on - - auto iter = type_map.find(type_id); - if (iter == type_map.end()) { - MS_LOG(EXCEPTION) << "Convert type error, unsupported type " << type_id; - } - - return iter->second; -} - -void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) { - auto dtype = node->Type(); - auto shape = node->Shape(); - onnx::TypeProto *type_proto = value_proto->mutable_type(); - if (dtype->isa() && shape->isa()) { - auto tensor = dyn_cast(dtype); - auto elem_type = tensor->element(); - const auto &dims = dyn_cast(shape)->shape(); - // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 - auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); - type_proto->mutable_tensor_type()->set_elem_type(type); - - for (const auto &dim : dims) { - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); - } - } -} - -void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { - auto dtype = param->Type(); - auto shape = param->Shape(); - if (!dtype->isa() || !shape->isa()) { - MS_LOG(EXCEPTION) << "Parameter " << param->name() << " is not a regular tensor, with value " << param->ToString(); - } - - auto tensor = dyn_cast(dtype); - auto elem_type = tensor->element(); - const auto &dims = dyn_cast(shape)->shape(); - tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); - for (const auto &dim : dims) { - tensor_proto->add_dims(dim); - } -} - -void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, - std::unordered_map *op_merged_infos_ptr) { - std::unordered_map &op_merged_infos = *op_merged_infos_ptr; - - for (auto &node : nodes) { - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (cnode == func_graph->get_return()) { - // if the key `input` does not exist, just create a new one - op_merged_infos[cnode].referred_count += 1; - } - for (auto &input : cnode->inputs()) { - if (!input->isa()) { - continue; - } - // if the key `input` does not exist, just create a new one - op_merged_infos[input].referred_count += 1; - } - // MindSpore Conv + BiasAdd --> ONNX Conv - if (cnode->IsApply(std::make_shared("BiasAdd")) && - IsPrimitiveCNode(cnode->input(1), prim::kPrimConv2D)) { - op_merged_infos[cnode].mode = OP_MERGE_CONV; - op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; - op_merged_infos[cnode->input(1)].referred_count -= 1; - } else if (cnode->IsApply(std::make_shared("BiasAdd")) && - IsPrimitiveCNode(cnode->input(1), prim::kPrimMatMul)) { - op_merged_infos[cnode].mode = OP_MERGE_GEMM; - op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; - op_merged_infos[cnode->input(1)].referred_count -= 1; - } else if (cnode->IsApply(prim::kPrimTupleGetItem) && - IsPrimitiveCNode(cnode->input(1), std::make_shared("BatchNorm")) && - GetInt32Value(cnode->input(2)) == 0) { - op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM; - op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; - op_merged_infos[cnode->input(1)].referred_count -= 1; - } else if (cnode->IsApply(prim::kPrimTupleGetItem) && - IsPrimitiveCNode(cnode->input(1), std::make_shared("MaxPoolWithArgmax")) && - GetInt32Value(cnode->input(2)) == 0) { - op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX; - op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; - op_merged_infos[cnode->input(1)].referred_count -= 1; - } - } -} - -/** - * AnfNode - * +-- CNode - * +-- ANode - * | +-- Parameter - * | `-- ValueNode - */ -void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); - - std::unordered_map op_merged_infos; - MatchAndMark(func_graph, nodes, &op_merged_infos); - - for (const AnfNodePtr &node : nodes) { - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - auto iter = op_merged_infos.find(cnode); - // the node is not referenced by any other nodes, skip it - if (iter == op_merged_infos.end()) { - continue; - } - auto merged_info = iter->second; - // the op node is merged with other node and not used any more, skip it - if (merged_info.mode == OP_MERGE_IGNORE && merged_info.referred_count == 0) { - continue; - } - if (cnode == func_graph->get_return()) { - ExportOutput(func_graph, cnode, node_map_ptr, graph_proto); - continue; - } - switch (merged_info.mode) { - case OP_MERGE_CONV: - ExportMergeConv(func_graph, cnode, node_map_ptr, graph_proto); - break; - case OP_MERGE_GEMM: - ExportMergeGemm(func_graph, cnode, node_map_ptr, graph_proto); - break; - case OP_MERGE_BATCH_NORM: - ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto); - break; - case OP_MERGE_MAXPOOL_WITH_ARGMAX: - ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto); - break; - default: - ExportCNode(func_graph, cnode, node_map_ptr, graph_proto); - break; - } - } -} - -void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_shape = node->input(2); - std::string name_shape; - if (input_shape->isa()) { - auto const_node_idx = AllocateNodeIndex(); - (*node_map_ptr)[input_shape] = const_node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - name_shape = std::to_string(const_node_idx); - node_proto->add_output(name_shape); - - node_proto->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("value"); - - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - ConvertTupleToTensor(dyn_cast(input_shape)->value(), attr_proto->mutable_t()); - } else { - name_shape = GetNodeInputName(input_shape, node_map_ptr, graph_proto); - MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Reshape."; - } - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type(prim::kPrimReshape->name()); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(name_x); - node_proto->add_input(name_shape); -} - -void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_axis = node->input(2); - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - auto name = prim::kPrimReduceMean->name(); - if (node->IsApply(prim::kPrimReduceSum)) { - name = prim::kPrimReduceSum->name(); - } - node_proto->set_op_type(name); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(input_data); - - if (input_axis->isa()) { - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("axes"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); - auto axis_value = dyn_cast(input_axis)->value(); - auto int_ptr = dyn_cast(axis_value); - if (int_ptr == nullptr) { - auto tuple_ptr = dyn_cast(axis_value); - MS_EXCEPTION_IF_NULL(tuple_ptr); - for (size_t i = 0; i < tuple_ptr->size(); ++i) { - attr_proto->add_ints(GetValue((*tuple_ptr)[i])); - } - } else { - attr_proto->add_ints(int_ptr->value()); - } - } else { - MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name; - } -} - -void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_type = node->input(2); - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type(prim::kPrimCast->name()); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(input_data); - - if (input_type->isa()) { - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("to"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - auto type_value = dyn_cast(input_type)->value(); - auto type_ptr = dyn_cast(type_value); - MS_EXCEPTION_IF_NULL(type_ptr); - attr_proto->set_i(GetOnnxDataType(type_ptr->type_id())); - } else { - MS_LOG(EXCEPTION) << "Need to convert MindSpore Cast input(1) to ONNX Cast to attribute."; - } -} - -void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); - - auto x_shape = dyn_cast(node->input(1)->Shape()); - auto slope_shape = dyn_cast(node->input(2)->Shape()); - MS_EXCEPTION_IF_NULL(x_shape); - MS_EXCEPTION_IF_NULL(slope_shape); - - // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] - if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { - auto node_idx = AllocateNodeIndex(); - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Unsqueeze"); - node_proto->add_output(std::to_string(node_idx)); - - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); - attr_proto->set_name("axes"); - attr_proto->add_ints(1); - attr_proto->add_ints(2); - - node_proto->add_input(input_slope); - input_slope = std::to_string(node_idx); - } - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("PRelu"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(input_x); - node_proto->add_input(input_slope); -} - -void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Clip"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(input_x); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); - attr_proto->set_name("min"); - attr_proto->set_f(0.f); - attr_proto = node_proto->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); - attr_proto->set_name("max"); - attr_proto->set_f(6.f); -} - -void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_w = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); - auto x_shape = dyn_cast(node->input(1)->Shape()); - auto w_shape = dyn_cast(node->input(2)->Shape()); - MS_EXCEPTION_IF_NULL(x_shape); - MS_EXCEPTION_IF_NULL(w_shape); - if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) { - MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d."; - } - if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) { - MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape"; - } - // create w_shape constant node - auto node_idx = AllocateNodeIndex(); - onnx::NodeProto *node_proto = graph_proto->add_node(); - std::string name_w_shape = std::to_string(node_idx); - node_proto->add_output(name_w_shape); - node_proto->set_op_type("Constant"); - // create Value Tensor - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("value"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size())); - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); - // reshape - tensor_proto->add_int64_data(w_shape->shape()[1]); - tensor_proto->add_int64_data(w_shape->shape()[0]); - tensor_proto->add_int64_data(w_shape->shape()[2]); - tensor_proto->add_int64_data(w_shape->shape()[3]); - - // add reshape node - node_idx = AllocateNodeIndex(); - node_proto = graph_proto->add_node(); - node_proto->set_op_type(prim::kPrimReshape->name()); - node_proto->add_input(input_w); - node_proto->add_input(name_w_shape); - input_w = std::to_string(node_idx); - node_proto->add_output(input_w); - - // add conv node - node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - node_proto = graph_proto->add_node(); - node_proto->set_op_type("Conv"); - node_proto->add_input(input_x); - node_proto->add_input(input_w); - node_proto->add_output(std::to_string(node_idx)); - // set attributes - AnfNodePtr op = node->input(0); - auto op_value = dyn_cast(op); - auto prim = dyn_cast(op_value->value()); - // set dilations - onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name("dilations"); - SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, - prim); - // set group - onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name("group"); - onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - onnx_attr_proto->set_i(x_shape->shape()[1]); - // set kernel_shape - onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name("kernel_shape"); - SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, - prim); - - // set pad - onnx_attr_proto = node_proto->add_attribute(); - auto attr_value = GetValue(prim->GetAttr("pad_mode")); - onnx_attr_proto->set_name("auto_pad"); - onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); - if (attr_value == "valid") { - onnx_attr_proto->set_s("VALID"); - } else if (attr_value == "same") { - onnx_attr_proto->set_s("SAME_UPPER"); - } else { - onnx_attr_proto->set_name("pads"); - SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); - } - // set strides - onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name("strides"); - SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); -} - -void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto multiples = node->input(2); - std::string name_multiples; - if (multiples->isa()) { - auto const_node_idx = AllocateNodeIndex(); - (*node_map_ptr)[multiples] = const_node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - name_multiples = std::to_string(const_node_idx); - node_proto->add_output(name_multiples); - - node_proto->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("repeat"); - - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - ConvertTupleToTensor(dyn_cast(multiples)->value(), attr_proto->mutable_t()); - } else { - name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto); - MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile."; - } - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Tile"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(name_x); - node_proto->add_input(name_multiples); -} - -void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - std::string name_exponent; - auto const_node_idx = AllocateNodeIndex(); - onnx::NodeProto *node_proto_exp = graph_proto->add_node(); - name_exponent = std::to_string(const_node_idx); - node_proto_exp->add_output(name_exponent); - - node_proto_exp->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - tensor_proto->set_name("exponent"); - tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1)); - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); - tensor_proto->add_int64_data(2); - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Pow"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(name_x); - node_proto->add_input(name_exponent); -} - -void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); - auto axis = node->input(3)->cast()->value(); - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Gather"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(name_x); - node_proto->add_input(name_indices); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast(axis)->value())); -} - -void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert - if (node->IsApply(prim::kPrimReshape)) { - return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); - } - - if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) { - return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore Cast(x, T) --> ONNX Cast[to=T](x) - if (node->IsApply(prim::kPrimCast)) { - return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto); - } - - // ONNX PRelu requires unidirectional broadcasting, here need some process - if (node->IsApply(std::make_shared("PReLU"))) { - return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x) - if (node->IsApply(std::make_shared("ReLU6"))) { - return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w)) - if (node->IsApply(std::make_shared("DepthwiseConv2dNative"))) { - return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore Tile(x) --> ONNX Tile(x, repeat) - if (node->IsApply(prim::kPrimTile)) { - return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore Square(x) --> ONNX Pow(x, 2) - if (node->IsApply(prim::kPrimSquare)) { - return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices) - if (node->IsApply(prim::kPrimGatherV2)) { - return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto); - } - - auto inputs = node->inputs(); - if (inputs.size() < 1) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - - AnfNodePtr op = inputs[0]; - std::vector op_inputs; - // first process node input 1,2,..., since when node input is a ValueNode, here need to create a Constant Operator - for (size_t i = 1; i < inputs.size(); i++) { - op_inputs.push_back(inputs[i]); - } - auto op_value = dyn_cast(op); - if (op_value == nullptr) { - MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name(); - } - auto prim = dyn_cast(op_value->value()); - if (prim == nullptr) { - MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name(); - } - - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); -} - -size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map *node_map_ptr, - const PrimitivePtr &prim, const std::vector &inputs, - onnx::GraphProto *const graph_proto) { - auto op_map = OpConvertRegistry::GetOpConvertMap(); - auto op_iter = op_map.find(prim->name()); - if (op_iter == op_map.end()) { - MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; - } - const OpNameInfo &op_convert_info = op_iter->second; - - auto node_idx = AllocateNodeIndex(); - - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->add_output(std::to_string(node_idx)); - node_proto->set_op_type(op_convert_info.onnx_type()); - - // Set inputs - for (const auto &input : inputs) { - auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); - node_proto->add_input(input_name); - } - - // Set node attribute - for (const OpAttrInfo &attr : op_convert_info.op_attrs()) { - const std::string &attr_name = attr.attr_name(); - ValuePtr attr_value = nullptr; - if (!attr_name.empty()) { - attr_value = prim->GetAttr(attr_name); - if (attr_value == nullptr) { - MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; - } - } - onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name(attr.onnx_attr_name()); - attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); - } - return node_idx; -} - -void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto conv_node = dyn_cast(node->input(1)); - auto input_x = conv_node->input(1); // conv input x - auto input_w = conv_node->input(2); // conv weight(filter) - auto input_b = node->input(2); // conv bias - - PrimitivePtr prim_conv = dyn_cast((dyn_cast(conv_node->input(0)))->value()); - std::vector inputs{input_x, input_w, input_b}; - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); -} - -void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto matmul_node = dyn_cast(node->input(1)); - auto input_x = matmul_node->input(1); // matmul input x - auto input_y = matmul_node->input(2); // matmul input y - auto input_b = node->input(2); // matmul bias - - PrimitivePtr prim_matmul = dyn_cast((dyn_cast(matmul_node->input(0)))->value()); - std::vector inputs{input_x, input_y, input_b}; - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); -} - -void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - auto batch_norm_node = dyn_cast(node->input(1)); - - PrimitivePtr prim_batch_norm = dyn_cast((dyn_cast(batch_norm_node->input(0)))->value()); - std::vector inputs; - for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) { - inputs.push_back(batch_norm_node->input(i)); - } - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); -} - -void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - auto maxpool_with_argmax_node = dyn_cast(node->input(1)); - - PrimitivePtr prim_maxpool_with_argmax = - dyn_cast((dyn_cast(maxpool_with_argmax_node->input(0)))->value()); - std::vector inputs; - for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) { - inputs.push_back(maxpool_with_argmax_node->input(i)); - } - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto); -} - -void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - if (node->inputs().size() != 2) { - MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; - } - AnfNodePtr arg = node->input(1); - std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - output_proto->set_name(name); - SetValueInfoType(arg, output_proto, false); -} - -std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - if (node->isa()) { - auto iter = node_map_ptr->find(node); - if (iter == node_map_ptr->end()) { - MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in node_map"; - } - return std::to_string(iter->second); - } - - if (node->isa()) { - return node->ToString(); - } - - // for ValueNode input, create a Constant Operator - if (node->isa()) { - auto iter = node_map_ptr->find(node); - if (iter != node_map_ptr->end()) { - return std::to_string(iter->second); - } - // the id number starts at 1, so the id of created node should be size of map plus one - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - std::string node_name = std::to_string(node_idx); - - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->add_output(node_name); - - SetNodeAttribute(node->cast()->value(), node_proto); - - return node_name; - } - - MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name(); -} - -void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { - auto tuple_ptr = dyn_cast(value); - MS_EXCEPTION_IF_NULL(tuple_ptr); - if (tuple_ptr->size() == 0) { - MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, the size of converted tuple is 0."; - } - auto type_id = (*tuple_ptr)[0]->type()->type_id(); - for (size_t i = 1; i < tuple_ptr->size(); ++i) { - if ((*tuple_ptr)[i]->type()->type_id() != type_id) { - MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, type of tuple elements is not same."; - } - } - - tensor_proto->add_dims(static_cast<::google::protobuf::int64>(tuple_ptr->size())); - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); - for (size_t i = 0; i < tuple_ptr->size(); ++i) { - ValuePtr elem = (*tuple_ptr)[i]; - if (elem->isa()) { - tensor_proto->add_int64_data(dyn_cast(elem)->value()); - } else if (elem->isa()) { - tensor_proto->add_int64_data(dyn_cast(elem)->value()); - } else if (elem->isa()) { - tensor_proto->add_int64_data(dyn_cast(elem)->value()); - } else if (elem->isa()) { - tensor_proto->add_int64_data(dyn_cast(elem)->value()); - } else { - MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, unexpected tuple element type " << elem->type()->type_name() - << "."; - } - } -} - -void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { - node_proto->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("value"); - if (value->isa()) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - auto casted_value = dyn_cast(value); - if (casted_value == nullptr) { - MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; - } - auto attr_value = casted_value->value(); - attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - } else if (value->isa()) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - auto data = dyn_cast(value); - tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); - auto dtype = data->data_type(); - auto shape = data->shape_c(); - - tensor_proto->set_data_type(GetOnnxDataType(dtype)); - for (const auto &dim : shape) { - tensor_proto->add_dims(dim); - } - } else { - MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; - } -} - -std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { - OnnxExporter exporter; - return exporter.GetOnnxProtoString(func_graph); -} -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/CMakeLists.txt b/mindspore/ccsrc/operator/CMakeLists.txt deleted file mode 100644 index 88bcf0e532..0000000000 --- a/mindspore/ccsrc/operator/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB_RECURSE _OPERATOR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_OPERATOR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) -add_library(_mindspore_operator_obj OBJECT ${_OPERATOR_SRC_FILES}) diff --git a/mindspore/ccsrc/operator/cc_implementations.cc b/mindspore/ccsrc/operator/cc_implementations.cc deleted file mode 100644 index 52b71f410f..0000000000 --- a/mindspore/ccsrc/operator/cc_implementations.cc +++ /dev/null @@ -1,432 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/cc_implementations.h" -#include -#include -#include -#include -#include -#include "utils/misc.h" -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" -#include "common/utils.h" - -namespace mindspore { -// namespace to support primitive operators definition -namespace prim { -enum class DataType { kInt, kFloat, kDouble, kUnknown }; - -// Whether has a T type data in AnyPtrList. -template -bool HasType(const AnyPtrList &list) { - bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is(); }); - return ret; -} - -DataType InferType(const AnyPtrList &list) { - if (HasType(list)) { - return DataType::kDouble; - } else if (HasType(list)) { - return DataType::kFloat; - } else if (HasType(list)) { - return DataType::kInt; - } - return DataType::kUnknown; -} - -enum OpType { ADD, SUB, MUL, DIV, MOD }; - -template -bool IsSignedIntOverflow(T x, T y, OpType opType) { - auto max = std::numeric_limits::max(); - auto min = std::numeric_limits::min(); - - if (opType == OpType::ADD) { - return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x); - } - - if (opType == OpType::SUB) { - return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x); - } - - if (opType == OpType::MUL) { - return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || - (x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x); - } - - if (opType == OpType::DIV || opType == OpType::MOD) { - return x == min && static_cast(y) == -1; - } - - MS_LOG(EXCEPTION) << "Unsupported operation type."; -} - -template -T InnerScalarAdd(T x, T y) { - if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::ADD)) { - MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - return x + y; -} - -template -T InnerScalarSub(T x, T y) { - if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::SUB)) { - MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - return x - y; -} - -template -T InnerScalarMul(T x, T y) { - if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::MUL)) { - MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - return x * y; -} - -template -float InnerScalarDiv(T x, T y) { - if (y == 0) { - MS_LOG(EXCEPTION) << "Divisor could not be zero"; - } - if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::DIV)) { - MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - return static_cast(x) / static_cast(y); -} - -template -T InnerScalarFloordiv(T x, T y) { - auto ret = std::floor(InnerScalarDiv(x, y)); - if (std::is_integral::value) { - return static_cast(ret); - } - return ret; -} - -template -T InnerScalarMod(T x, T y) { - if (y == 0) { - MS_LOG(EXCEPTION) << "Could not mod to zero."; - } - if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::MOD)) { - MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - if (std::is_integral::value) { - return static_cast(x) % static_cast(y); - } - int x_int = std::floor(x); - int y_int = std::ceil(y); - int max = x_int / y_int; - float ret = x - y * max; - return ret; -} - -template -T InnerScalarPow(T x, U y) { - return std::pow(x, y); -} - -template -bool InnerScalarEq(T x, U y) { - double error = static_cast(x) - static_cast(y); - error = fabs(error); - return error < DBL_EPSILON; -} - -template -bool InnerScalarLt(T x, U y) { - return x < y; -} - -template -bool InnerScalarGt(T x, U y) { - return x > y; -} - -template -bool InnerScalarNe(T x, U y) { - return !InnerScalarEq(x, y); -} - -template -bool InnerScalarLe(T x, U y) { - return x <= y; -} - -template -bool InnerScalarGe(T x, U y) { - return x >= y; -} - -#define SCALAR_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList &list) { \ - do { \ - if (list.size() < 2) { \ - MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ - } \ - ValuePtr x = list[0]; \ - ValuePtr y = list[1]; \ - MS_EXCEPTION_IF_NULL(x); \ - MS_EXCEPTION_IF_NULL(y); \ - if (x->isa() && y->isa()) { \ - double sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - float sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - int sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - float sum = InnerScalar##op_t(IntToFloat(GetValue(x)), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - float sum = InnerScalar##op_t(GetValue(x), IntToFloat(GetValue(y))); \ - return MakeValue(sum); \ - } \ - MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ - << ", y: " << y->ToString(); \ - } while (0); \ - } - -SCALAR_OP(Add) -SCALAR_OP(Sub) -SCALAR_OP(Mul) -SCALAR_OP(Div) -SCALAR_OP(Mod) -SCALAR_OP(Pow) -SCALAR_OP(Floordiv) - -#define LOGIC_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList &list) { \ - if (list.size() < 2) { \ - MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ - } \ - ValuePtr x = list[0]; \ - ValuePtr y = list[1]; \ - MS_EXCEPTION_IF_NULL(x); \ - MS_EXCEPTION_IF_NULL(y); \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ - << ", y: " << y->ToString() << "."; \ - } - -LOGIC_OP(Eq) -LOGIC_OP(Lt) -LOGIC_OP(Gt) -LOGIC_OP(Ne) -LOGIC_OP(Le) -LOGIC_OP(Ge) - -ValuePtr ScalarUAdd(const ValuePtrList &list) { - if (list.size() != 1) { - MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); - } - ValuePtr x = list[0]; - MS_EXCEPTION_IF_NULL(x); - return x; -} - -ValuePtr ScalarUSub(const ValuePtrList &list) { - if (list.size() != 1) { - MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); - } - ValuePtr x = list[0]; - MS_EXCEPTION_IF_NULL(x); - - if (x->isa()) { - int32_t sum = -1 * GetValue(x); - return MakeValue(sum); - } - if (x->isa()) { - float sum = -1.0f * GetValue(x); - return MakeValue(sum); - } - - MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; -} - -ValuePtr ScalarLog(const ValuePtrList &list) { - if (list.empty()) { - MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; - } - ValuePtr x = list[0]; - MS_EXCEPTION_IF_NULL(x); - - if (x->isa()) { - double v = log(GetValue(x)); - return MakeValue(v); - } - if (x->isa()) { - auto v = static_cast(log(GetValue(x))); - return MakeValue(v); - } - - MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); -} - -ValuePtr BoolNot(const ValuePtrList &list) { - if (list.empty()) { - MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; - } - ValuePtr x = list[0]; - MS_EXCEPTION_IF_NULL(x); - bool convert = false; - - if (ValueToBool(x, &convert)) { - auto res = !convert; - return MakeValue(res); - } - - MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); -} - -ValuePtr BoolAnd(const ValuePtrList &list) { - if (list.size() < 2) { - MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; - } - ValuePtr x = list[0]; - ValuePtr y = list[1]; - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(y); - bool x_b = false; - bool y_b = false; - - if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { - auto res = x_b && y_b; - return MakeValue(res); - } - - MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; -} - -ValuePtr BoolOr(const ValuePtrList &list) { - if (list.size() < 2) { - MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; - } - ValuePtr x = list[0]; - ValuePtr y = list[1]; - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(y); - bool x_b = false; - bool y_b = false; - - if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { - auto res = x_b || y_b; - return MakeValue(res); - } - - MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; -} - -ValuePtr BoolEq(const ValuePtrList &list) { - if (list.size() < 2) { - MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; - } - ValuePtr x = list[0]; - ValuePtr y = list[1]; - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(y); - bool x_b = false; - bool y_b = false; - - if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { - auto res = x_b == y_b; - return MakeValue(res); - } - - MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << "."; -} - -std::vector BroadcastShape_(std::vector shpx, std::vector shpy) { - int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); - if (dlen < 0) { - for (int i = 0; i < -dlen; ++i) { - (void)shpx.insert(shpx.begin(), 1); - } - } else if (dlen > 0) { - for (int i = 0; i < dlen; i++) { - (void)shpy.insert(shpy.begin(), 1); - } - } - if (shpx.size() != shpy.size()) { - MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; - } - std::vector shp; - for (size_t i = 0; i < shpx.size(); i++) { - auto a = shpx[i]; - auto b = shpy[i]; - if (a == 1) { - shp.push_back(b); - } else if (b == 1) { - shp.push_back(a); - } else if (a == -1) { - shp.push_back(b); - } else if (b == -1) { - shp.push_back(a); - } else if (a == b) { - shp.push_back(a); - } else { - return std::vector(); - } - } - return shp; -} -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc deleted file mode 100644 index 75532b9fbd..0000000000 --- a/mindspore/ccsrc/operator/composite/composite.cc +++ /dev/null @@ -1,971 +0,0 @@ - -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/composite.h" -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "pipeline/static_analysis/dshape.h" -#include "pipeline/static_analysis/param_validator.h" -#include "operator/cc_implementations.h" -#include "optimizer/opt.h" -#include "utils/symbolic.h" -#include "pybind_api/api_register.h" -#include "./common.h" -#include "ir/signature.h" -#include "debug/trace.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using AbstractTensor = mindspore::abstract::AbstractTensor; -using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; - -using mindspore::abstract::AbstractAttribute; -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractClass; -using mindspore::abstract::AbstractDictionary; -using mindspore::abstract::AbstractDictionaryPtr; -using mindspore::abstract::AbstractEllipsis; -using mindspore::abstract::AbstractEllipsisPtr; -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractFunctionPtr; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractNone; -using mindspore::abstract::AbstractScalar; -using mindspore::abstract::AbstractSlice; -using mindspore::abstract::AbstractTuple; - -ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul}, - {"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod}, - {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt}, - {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe}, - {"__ge__", kPrimScalarGe}}; - -const MetaFuncGraphPtr kTail = std::make_shared("tail"); - -// copy from python API: reduce. -// Apply a function of two arguments cumulatively to the items of a sequence, -// from left to right, so as to reduce the sequence to a single value.For example, -// reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). -AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { - std::shared_ptr ret; - size_t size = list.size(); - if (size < 2) { - MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; - } - - AnyPtrList input; - input.push_back(list[0]); - input.push_back(list[1]); - ret = std::make_shared(func(input)); - - for (size_t i = 2; i < size; ++i) { - input.clear(); - input.push_back(ret); - input.push_back(list[i]); - ret = std::make_shared(func(input)); - } - - return ret; -} - -AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector &list) { - size_t size = list.size(); - if (size < 2) { - MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; - } - - std::vector input; - input.push_back(list[0]); - input.push_back(list[1]); - AnfNodePtr ret = func(input); - - for (size_t i = 2; i < size; ++i) { - input.clear(); - input.push_back(ret); - input.push_back(list[i]); - ret = func(input); - } - - return ret; -} - -ValuePtr kCompositeHyperMap = std::make_shared(); - -void HyperMap::Init() { - if (fn_leaf_) { - name_ = "hyper_map[" + fn_leaf_->name() + "]"; - } - signatures_ = - // def hypermap(func:read, *args:ref): - std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, - {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); -} - -HyperMap::HyperMap(const std::shared_ptr &fn_leaf) - : MetaFuncGraph("hyper_map"), - fn_leaf_(fn_leaf), - broadcast_(false), - nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { - Init(); -} - -HyperMap::HyperMap(const HyperMap &h) - : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { - Init(); -} - -AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map) { - MS_EXCEPTION_IF_NULL(func_graph); - std::vector inputs; - if (fn_arg != nullptr) { - inputs.push_back(fn_arg); - } else { - inputs.push_back(NewValueNode(fn_leaf_)); - } - - (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), - [](const std::pair &item) { return item.first; }); - return func_graph->NewCNode(inputs); -} - -AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(type); - - std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { - auto lhs = std::static_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); - if (is_not_same) { - MS_LOG(EXCEPTION) << "List in HyperMap should have same length"; - } - - // cannot use shared_from_base() also known as this, as it will make a reference cycle on - // hypermap and graph generated, it will cause memory leak. - auto fn_rec = NewValueNode(std::make_shared(*this)); - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeList)); - - for (int i = 0; i < SizeToInt(size); ++i) { - std::vector inputs2; - inputs2.push_back(fn_rec); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - (void)std::transform( - arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair &item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); - }); - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(type); - - std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { - auto lhs = std::static_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); - if (is_not_same) { - MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length"; - } - - // cannot use shared_from_base() also known as this, as it will make a reference cycle on - // hypermap and graph generated, it will cause memory leak. - auto fn_rec = NewValueNode(std::make_shared(*this)); - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - - for (int i = 0; i < SizeToInt(size); ++i) { - std::vector inputs2; - inputs2.push_back(fn_rec); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - (void)std::transform( - arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); - }); - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { - MS_EXCEPTION_IF_NULL(type); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); - inputs.push_back(NewValueNode(type)); - - // cannot use shared_from_base() also known as this, as it will make a reference cycle on - // hypermap and graph generated, it will cause memory leak. - auto fn_rec = NewValueNode(std::make_shared(*this)); - std::size_t attrSize = type->GetAttributes().size(); - for (std::size_t i = 0; i < attrSize; ++i) { - std::vector inputs2; - inputs2.push_back(fn_rec); - if (fn_arg) { - inputs2.push_back(fn_arg); - } - - int j = 0; - for (auto item : arg_map) { - inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); - j++; - } - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { - bool found = false; - TypeId id = kObjectTypeEnd; - std::pair pair; - for (auto &item : arg_map) { - pair = item; - id = item.second->type_id(); - if (nonleaf_.count(id)) { - found = true; - break; - } - } - - if (found) { - // In a nonleaf situation, all arguments must have the same generic. - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair &item) { - if (item.first != pair.first) { - return item.second->type_id() != pair.second->type_id(); - } - return false; - }); - if (is_not_same) { - std::ostringstream oss; - oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" - << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; - int idx = 0; - for (auto &item : arg_map) { - oss << ++idx << ": " << item.second->ToString() << "\n"; - } - MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); - } - } - - switch (id) { - case kObjectTypeList: { - auto type = std::static_pointer_cast(pair.second); - return FullMake(type, func_graph, fn_arg, arg_map); - } - case kObjectTypeTuple: { - auto type = std::static_pointer_cast(pair.second); - return FullMake(type, func_graph, fn_arg, arg_map); - } - case kObjectTypeClass: { - auto type = std::static_pointer_cast(pair.second); - return FullMake(type, func_graph, fn_arg, arg_map); - } - default: - return FullMake(pair.second, func_graph, fn_arg, arg_map); - } -} - -ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { - TypePtr type_tensor = std::make_shared(); - bool flag = std::any_of( - args_spec_list.begin(), args_spec_list.end(), - [type_tensor](const std::pair &item) { return IsSubType(item.second, type_tensor); }); - if (flag && broadcast_) { - ArgsPairList ret; - for (auto &item : args_spec_list) { - if (!IsSubType(item.second, type_tensor)) { - TypePtr type_tensor_ele = std::make_shared(item.second); - ret.push_back( - std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele)); - } else { - ret.push_back(std::make_pair(item.first, item.second)); - } - } - return ret; - } - return args_spec_list; -} - -FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { - FuncGraphPtr ptrGraph = std::make_shared(); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); - ptrGraph->debug_info()->set_name("hyper_map"); - - AnfNodePtr ptrFnArg = nullptr; - std::size_t i = 0; - ArgsPairList argmap; - ArgsPairList argmap2; - if (fn_leaf_ == nullptr) { - ptrFnArg = ptrGraph->add_parameter(); - i = 1; - } - - std::size_t size = args_spec_list.size(); - for (; i < size; ++i) { - argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); - } - - argmap2 = Harmonize(ptrGraph, argmap); - ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2)); - return ptrGraph; -} - -abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { - if (fn_leaf_ == nullptr) { - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - // Assert that hypermap's function param does not contain free variables - if (args_spec_list[0]->isa()) { - auto graph_func = dyn_cast(args_spec_list[0]); - auto func_graph = graph_func->func_graph(); - if (func_graph->parent() != nullptr) { - MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet."; - } - } - } - - AbstractBasePtrList broadened; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(arg); - return arg->Broaden(); - }); - return broadened; -} - -REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { - (void)py::class_>(*m, "HyperMap_") - .def(py::init>(), py::arg("leaf")) - .def(py::init<>()); - })); - -FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { - MS_EXCEPTION_IF_NULL(a_tuple); - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ret->debug_info()->set_name("tail"); - AnfNodePtr ptrTup = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeTuple)); - - int tuple_size = SizeToInt(a_tuple->size()); - for (int i = 1; i < tuple_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)})); - } - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { - MS_EXCEPTION_IF_NULL(a_list); - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ret->debug_info()->set_name("tail"); - AnfNodePtr ptrList = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeList)); - - int list_size = SizeToInt(a_list->size()); - for (int i = 1; i < list_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)})); - } - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; - } - - AbstractBasePtr a = args_spec_list[0]; - abstract::AbstractTuplePtr a_tuple = dyn_cast(a); - if (a_tuple != nullptr) { - return GenerateTupleFuncGraph(a_tuple); - } - - abstract::AbstractListPtr a_list = dyn_cast(a); - if (a_list != nullptr) { - return GenerateListFuncGraph(a_list); - } - - MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString(); -} - -REGISTER_PYBIND_DEFINE( - Tail_, ([](const py::module *m) { - (void)py::class_>(*m, "Tail_").def(py::init()); - })); - -FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - int tuple_size = SizeToInt(args_spec_list.size()); - - std::ostringstream ss; - ss << "▶make_tuple_" << tuple_size; - FuncGraphPtr fg = std::make_shared(); - fg->debug_info()->set_name(ss.str()); - - std::vector params; - params.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (int i = 0; i < tuple_size; ++i) { - params.push_back(fg->add_parameter()); - } - - // make fprob first result, maketuple's forward result. - AnfNodePtr out = fg->NewCNode(params); - - // make fprob second result, maketuple's backward function. - FuncGraphPtr b = std::make_shared(); - - ss.clear(); - ss << "◀make_tuple_" << tuple_size; - b->debug_info()->set_name(ss.str()); - AnfNodePtr dout = b->add_parameter(); - - std::vector grads; - grads.push_back(NewValueNode(prim::kPrimMakeTuple)); - grads.push_back(NewValueNode(newenv)); - for (int i = 0; i < tuple_size; ++i) { - grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); - } - - b->set_flag(FUNC_GRAPH_FLAG_CORE, true); - b->set_output(b->NewCNode(grads)); - - fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); - fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); - (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); - return fg; -} - -GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) - : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { - if (get_by_list) { - signatures_ = - // def grad(func:read, weight_list:ref): - std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, - {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}}); - } -} - -FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, - const std::vector ¶ms_list, const std::vector &args, - bool applyJ) { - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - - auto weights_node = weights; - if (weights == nullptr && !args.empty()) { - weights_node = ret->NewCNode(args); - } - - ValueNodePtr opsJ = NewValueNode(prim::kPrimJ); - ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem); - - std::vector inputs; - if (applyJ) { - inputs.push_back(opsJ); - inputs.push_back(node); - node = ret->NewCNode(inputs); - } - - std::vector params; - for (size_t i = 0; i < params_list.size(); ++i) { - params.push_back(ret->add_parameter()); - } - - inputs.clear(); - inputs.push_back(node); - (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); - AnfNodePtr cnode = ret->NewCNode(inputs); - - inputs.clear(); - inputs.push_back(opsTupleItem); - inputs.push_back(cnode); - inputs.push_back(NewValueNode(0)); - auto out = ret->NewCNode(inputs); - - inputs.clear(); - inputs.push_back(opsTupleItem); - inputs.push_back(cnode); - inputs.push_back(NewValueNode(1)); - AnfNodePtr ptrBprop = ret->NewCNode(inputs); - - doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem); - return ret; -} - -void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, - ValueNodePtr opsTupleItem) { - MS_EXCEPTION_IF_NULL(func_graph); - - AnfNodePtr ptrBPropArg = nullptr; - if (sens_param_) { - ptrBPropArg = func_graph->add_parameter(); - } else { - auto ones_like = prim::GetPythonOps("ones_like"); - ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out}); - } - - AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg}); - - CNodePtr fv_bprop = nullptr; - if (get_by_list_) { - // python code: grads = hyper_map(F.partial(env_get, env), weights) - AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)}); - AnfNodePtr partial_env_get = - func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); - MetaFuncGraphPtr hyper_map = std::make_shared(); - fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights}); - } - - CNodePtr inputs_bprop = nullptr; - if (get_all_) { - inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp}); - } - - // Gradients wrt inputs and parameters - if (fv_bprop != nullptr && inputs_bprop != nullptr) { - func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); - return; - } - - // Gradients wrt parameters - if (fv_bprop != nullptr) { - func_graph->set_output(fv_bprop); - return; - } - - // Gradients wrt inputs - if (inputs_bprop != nullptr) { - func_graph->set_output(inputs_bprop); - return; - } - - // Gradients wrt first input. - // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input - func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)})); -} - -// Generate the graph. -FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.size() < 1) { - MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " - << args_spec_list.size() << "."; - } - - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - AbstractFunctionPtr fn = dyn_cast(args_spec_list[0]); - if (fn == nullptr) { - MS_LOG(EXCEPTION) << "GradOperation arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); - } - - // Waiting for implementation. - auto real_fn = dyn_cast(fn); - MS_EXCEPTION_IF_NULL(real_fn); - - FuncGraphPtr ptrGraph = real_fn->func_graph(); - MS_EXCEPTION_IF_NULL(ptrGraph); - TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); - FuncGraphPtr dfBuilder = std::make_shared(); - TraceManager::EndTrace(); - auto nparam = ptrGraph->parameters().size(); - - std::ostringstream ss; - ss << "grad{" << nparam << "}"; - dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); - dfBuilder->debug_info()->set_name(ss.str()); - ParameterPtr param_graph = dfBuilder->add_parameter(); - - AnfNodePtr weights = nullptr; - if (get_by_list_) { - weights = dfBuilder->add_parameter(); - } - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimJ)); - inputs.push_back(param_graph); - auto jf = dfBuilder->NewCNode(inputs); - // df is checked in GetGrad - TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); - auto df = GetGrad(jf, weights, ptrGraph->parameters()); - TraceManager::EndTrace(); - dfBuilder->set_output(NewValueNode(df)); - - return dfBuilder; -} - -REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { - (void)py::class_>( - *m, "GradOperation_") - .def(py::init(), py::arg("fn")) - .def(py::init(), py::arg("fn"), py::arg("get_all"), - py::arg("get_by_list"), py::arg("sens_param")); - })); - -// Generate the ListMap func graph. -FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - size_t args_num = args_spec_list.size(); - // args: fn, list1, list2, ... - if (args_num < 2) { - MS_LOG(EXCEPTION) << "list_map takes at least two arguments"; - } - - for (size_t i = 1; i < args_num; ++i) { - if (typeid(args_spec_list[i]) != typeid(AbstractBase)) { - // The function currently not be use - MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'"; - } - } - - FuncGraphPtr fg_ptr = std::make_shared(); - fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); - fg_ptr->debug_info()->set_name("list_map"); - AnfNodePtr fn = fg_ptr->add_parameter(); - - std::vector lists; - for (size_t i = 1; i < args_num; ++i) { - lists.push_back(fg_ptr->add_parameter()); - } - - std::vector iters; - (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item}); - }); - - std::vector nexts; - (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); - }); - - std::vector values; - (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item}); - }); - - (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)}); - }); - - (void)values.insert(values.begin(), fn); - AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); - AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph}); - - FuncGraphPtr fgnext_ptr = std::make_shared(); - fgnext_ptr->debug_info()->set_name("body"); - - FuncGraphPtr fgcond_ptr = std::make_shared(); - fgcond_ptr->debug_info()->set_name("cond"); - - MakeCond(lists, fgnext_ptr, fgcond_ptr); - MakeNext(lists, fgcond_ptr, fgnext_ptr); - - CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); - - auto inputs = output_cnode->inputs(); - (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); - output_cnode->set_inputs(inputs); - - fg_ptr->set_output(output_cnode); - return fg_ptr; -} - -void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr &fgnext_ptr, - const FuncGraphPtr &fg_ptr) { - MS_EXCEPTION_IF_NULL(fg_ptr); - - AnfNodePtr fn = fg_ptr->add_parameter(); - AnfNodePtr resl = fg_ptr->add_parameter(); - - std::vector iters; - (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), - [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); }); - - std::vector hasnexts; - (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item}); - }); - - // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) - FuncGraphPtr fgtrue_ptr = std::make_shared(); - fgtrue_ptr->debug_info()->set_name("ftrue"); - fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); - - CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); - auto inputs = fgtrue_output_cnode->inputs(); - (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); - fgtrue_output_cnode->set_inputs(inputs); - fgtrue_ptr->set_output(fgtrue_output_cnode); - - FuncGraphPtr fgfalse_ptr = std::make_shared(); - fgfalse_ptr->debug_info()->set_name("ffalse"); - fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); - fgfalse_ptr->set_output(resl); - - AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), - NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)}); - fgtrue_ptr->set_output(output_cnode); -} - -void ListMap::MakeNext(const std::vector &lists, const FuncGraphPtr &fgcond_ptr, - const FuncGraphPtr &fg_ptr) { - MS_EXCEPTION_IF_NULL(fg_ptr); - AnfNodePtr fn = fg_ptr->add_parameter(); - - std::vector iters; - (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), - [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); }); - - std::vector nexts; - (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); - }); - - std::vector values; - (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, nullptr}); - }); - - iters.clear(); - (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)}); - }); - - (void)values.insert(values.begin(), fn); - AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); - AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph}); - CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); - - auto inputs = output_cnode->inputs(); - (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); - output_cnode->set_inputs(inputs); - fg_ptr->set_output(output_cnode); -} - -FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // args: tuple1, tuple2 - abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); - AbstractBasePtr abs_a = args_spec_list[0]; - AbstractBasePtr abs_b = args_spec_list[1]; - - abstract::AbstractTuplePtr a_tuple = dyn_cast(abs_a); - abstract::AbstractTuplePtr b_tuple = dyn_cast(abs_b); - if (a_tuple == nullptr || b_tuple == nullptr) { - MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", " - << args_spec_list[1]->ToString(); - } - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - AnfNodePtr p_tup_a = ret->add_parameter(); - AnfNodePtr p_tup_b = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeTuple)); - - int tuple_size = SizeToInt(a_tuple->size()); - for (int i = 0; i < tuple_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)})); - } - - tuple_size = SizeToInt(b_tuple->size()); - for (int i = 0; i < tuple_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)})); - } - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { - MS_EXCEPTION_IF_NULL(scalar); - return GetValue(scalar->BuildValue()); -} - -bool CheckIndexInRange(int index, int min, int max) { return (index >= min && index <= max); } - -int GetPositiveIndex(int index, int length) { - if (index < 0) { - index += length; - } - return index; -} - -int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { - MS_EXCEPTION_IF_NULL(member); - - if (member->isa()) { - return GetArgScalarValue(dyn_cast(member), member_name); - } - - if (member->isa()) { - return default_value; - } - - MS_LOG(EXCEPTION) << member_name << " should be a AbstractScalar or AbstractNone, but got " << member->ToString(); -} - -void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index, - int *stop_index, int *step_value) { - MS_EXCEPTION_IF_NULL(tuple); - MS_EXCEPTION_IF_NULL(slice); - MS_EXCEPTION_IF_NULL(start_index); - MS_EXCEPTION_IF_NULL(stop_index); - MS_EXCEPTION_IF_NULL(step_value); - - const std::string start_name("Slice start index"); - const std::string stop_name("Slice stop index"); - const std::string step_name("Slice step value"); - - int tuple_size = SizeToInt(tuple->size()); - int start_default = 0; - int stop_default = tuple_size; - int step_default = 1; - - *step_value = CheckSliceMember(slice->step(), step_default, step_name); - if (*step_value == 0) { - MS_LOG(EXCEPTION) << "TupleSlice require the step value could not be 0, but got 0."; - } - - if (*step_value < 0) { - start_default = tuple_size - 1; - stop_default = -1; - } - - *start_index = CheckSliceMember(slice->start(), start_default, start_name); - *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name); - if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) || - !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) { - MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end end index " << *stop_index - << " out of range, tuple size " << tuple_size << "."; - } - - *start_index = GetPositiveIndex(*start_index, tuple_size); - if (!slice->stop()->isa()) { - *stop_index = GetPositiveIndex(*stop_index, tuple_size); - } -} - -FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // slice a tuple - // args: tuple, start index, end index, step - const std::string op_name("TupleSlice"); - abstract::CheckArgsSize(op_name, args_spec_list, 2); - AbstractTuplePtr tuple = abstract::CheckArg(op_name, args_spec_list, 0); - AbstractSlicePtr slice = abstract::CheckArg(op_name, args_spec_list, 1); - - int start_index; - int stop_index; - int step_value; - GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value); - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - AnfNodePtr p_tuple = ret->add_parameter(); - (void)ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeTuple)); - if (step_value > 0) { - for (int index = start_index; index < stop_index; index = index + step_value) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); - } - } else { - for (int index = start_index; index > stop_index; index = index + step_value) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); - } - } - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // select indexed item - // args: tuple of items, index - const std::string op_name = std::string("TupleGetItemTensor"); - abstract::CheckArgsSize(op_name, args_spec_list, 2); - AbstractTuplePtr branches_abs = abstract::CheckArg(op_name, args_spec_list, 0); - AbstractBasePtrList branches = branches_abs->elements(); - if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa()) { - FuncGraphPtr ret_graph = std::make_shared(); - ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - AnfNodePtr functions = ret_graph->add_parameter(); - auto index = ret_graph->add_parameter(); - - ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); - return ret_graph; - } - - MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << "."; -} - -REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { - (void)py::class_>(*m, "TupleAdd_") - .def(py::init()); - })); - -REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { - (void)py::class_>(*m, "TupleSlice_") - .def(py::init()); - })); - -REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) { - (void)py::class_>( - *m, "TupleGetItemTensor_") - .def(py::init()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h deleted file mode 100644 index 5944c81fb0..0000000000 --- a/mindspore/ccsrc/operator/composite/composite.h +++ /dev/null @@ -1,192 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "operator/composite/zip_operation.h" -#include "operator/composite/list_append_operation.h" -#include "operator/composite/do_signature.h" -#include "operator/composite/unpack_call.h" -#include "operator/composite/multitype_funcgraph.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "utils/any.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using AbstractSlicePtr = abstract::AbstractSlicePtr; -using AbstractScalarPtr = abstract::AbstractScalarPtr; -using AbstractTensorPtr = abstract::AbstractTensorPtr; -using ElemwiseMap = std::unordered_map; -using ArgsPairList = std::vector>; - -class HyperMap : public MetaFuncGraph { - public: - explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr); - HyperMap(const HyperMap &h); - void Init(); - HyperMap &operator=(const HyperMap &h) { - if (this != &h) { - fn_leaf_ = h.fn_leaf_; - broadcast_ = h.broadcast_; - nonleaf_ = h.nonleaf_; - if (fn_leaf_) { - name_ = "hyper_map[" + fn_leaf_->name() + "]"; - } - } - return *this; - } - ~HyperMap() override = default; - MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) - - abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; - FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; - MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } - - private: - AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map); - AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map); - AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map); - AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map); - AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); - ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); - - MultitypeFuncGraphPtr fn_leaf_; - bool broadcast_; - std::set nonleaf_; -}; -using HyperMapPtr = std::shared_ptr; - -class HyperMapPy : public HyperMap { - public: - explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {} - ~HyperMapPy() override = default; - MS_DECLARE_PARENT(HyperMapPy, HyperMap) -}; -using HyperMapPyPtr = std::shared_ptr; - -extern ValuePtr kCompositeHyperMap; - -class Tail : public MetaFuncGraph { - public: - explicit Tail(const std::string &name) : MetaFuncGraph(name) {} - ~Tail() override = default; - MS_DECLARE_PARENT(Tail, MetaFuncGraph) - - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); - FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); - - friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } -}; -using TailPtr = std::shared_ptr; - -class MakeTupleGradient : public MetaFuncGraph { - public: - explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} - ~MakeTupleGradient() override = default; - MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } -}; -using MakeTupleGradientPtr = std::shared_ptr; - -class GradOperation : public MetaFuncGraph { - public: - explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, - bool sens_param = false); - ~GradOperation() override = default; - MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) - - FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector &ptrParams, - const std::vector &args = {}, bool applyJ = false); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - bool sens_param() const { return sens_param_; } - bool get_all_; - bool get_by_list_; - bool sens_param_; - - private: - void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, - ValueNodePtr opsTupleItem); -}; -using GradOperationPtr = std::shared_ptr; - -class ListMap { - public: - explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } - ~ListMap() = default; - void MakeCond(const std::vector &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); - void MakeNext(const std::vector &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); - - private: - std::string name_; - std::map, FuncGraphPtr> cache_; -}; - -class TupleAdd : public MetaFuncGraph { - public: - explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} - ~TupleAdd() override = default; - MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } -}; -using TupleAddPtr = std::shared_ptr; - -class TupleSlice : public MetaFuncGraph { - public: - explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} - ~TupleSlice() override = default; - MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } -}; -using TupleSlicePtr = std::shared_ptr; - -class TupleGetItemTensor : public MetaFuncGraph { - public: - explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} - ~TupleGetItemTensor() override = default; - MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) { - return lhs.name_ == rhs.name_; - } -}; -using TupleGetItemTensorPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc deleted file mode 100644 index d9bcef3031..0000000000 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ /dev/null @@ -1,317 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/do_signature.h" -#include -#include - -#include "pipeline/static_analysis/abstract_value.h" -#include "ir/anf.h" -#include "pipeline/static_analysis/dshape.h" -#include "pipeline/static_analysis/param_validator.h" -#include "operator/cc_implementations.h" -#include "optimizer/opt.h" -#include "utils/symbolic.h" -#include "./common.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -namespace { -using PatternListType = std::initializer_list; -const std::map type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3}, - {kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6}, - {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}}; - -const std::vector &GetSignature(const ValuePtr &function) { - static const auto empty = std::vector(); - if (function->isa() && function->cast()->has_signature()) { - return function->cast()->signatures(); - } else if (function->isa()) { - return function->cast()->signatures(); - } - return empty; -} - -void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, - const std::vector &signature, bool has_var, std::vector *const op_inputs) { - std::size_t sig_size = signature.size(); - auto positional_size = sig_size; - if (has_var) { - positional_size = sig_size - 1; - } - if (args_spec_list.size() < positional_size) { - for (size_t i = args_spec_list.size(); i < sig_size; ++i) { - auto default_value = signature[i].default_value; - if (default_value == nullptr) { - MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; - } else { - (*op_inputs).push_back(NewValueNode(default_value)); - } - } - } -} - -void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) { - *max_type_id = type_id; - *max_type_number = type_number; -} - -bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, - TypeId *arg_type = nullptr) { - if (arg_value->isa()) { - if (is_write) { - arg_value = arg_value->cast()->ref_origin(); - } else { - arg_value = arg_value->cast()->ref(); - } - } - if (arg_value->isa()) { - auto tensor = arg_value->cast(); - auto tensor_type = tensor->element()->BuildType(); - MS_EXCEPTION_IF_NULL(tensor_type); - *arg_type_id = tensor_type->type_id(); - if (arg_type != nullptr) { - *arg_type = kObjectTypeTensorType; - } - return true; - } - if (arg_value->isa()) { - auto scalar = arg_value->cast(); - auto scalar_type = scalar->BuildType(); - MS_EXCEPTION_IF_NULL(scalar_type); - *arg_type_id = scalar_type->type_id(); - if (arg_type != nullptr) { - *arg_type = kObjectTypeNumber; - } - return true; - } - return false; -} - -TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indices, - const std::set &write_indices) { - TypeId max_type_id = kTypeUnknown; - size_t max_type_number = 0; - bool has_int8 = false; - for (const auto &index : indices) { - TypeId arg_type_id = kTypeUnknown; - TypeId arg_type = kTypeUnknown; - auto is_write = (write_indices.find(index) != write_indices.end()); - if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { - continue; - } - if (arg_type != kObjectTypeTensorType) { - continue; - } - auto it = type_map.find(arg_type_id); - if (it == type_map.end()) { - continue; - } - if (arg_type_id == kNumberTypeInt8) { - has_int8 = true; - } - if (max_type_id == kTypeUnknown) { - SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); - continue; - } - if (it->second > max_type_number) { - SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); - } - } - - if (max_type_id == kNumberTypeUInt8 && has_int8 == true) { - max_type_id = kNumberTypeInt16; - } - return max_type_id; -} - -// Get the largest type of index in the same SignatureEnumDType of arguments. -std::map GetMaxDtype(const std::vector &dtypes, - const abstract::AbstractBasePtrList &args_spec_list, - const std::set &write_indices) { - // record index for signature.dtypes of the same type - // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} - std::map> type_indices; - for (size_t i = 0; i < dtypes.size(); ++i) { - auto it = type_indices.find(dtypes[i]); - if (it == type_indices.end()) { - (void)type_indices.insert(std::make_pair(dtypes[i], std::vector{i})); - } else { - it->second.push_back(i); - } - } - std::map dst_type; - for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) { - auto type = it->first; - auto indices = it->second; - // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. - if (indices.size() < 2) { - continue; - } - bool has_tensor = false; - for (const auto &index : indices) { - AbstractBasePtr arg_value = args_spec_list[index]; - if (arg_value->isa()) { - arg_value = arg_value->cast()->ref(); - } - if (arg_value->isa()) { - has_tensor = true; - break; - } - } - if (!has_tensor) { - (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); - continue; - } - (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); - } - return dst_type; -} - -AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGraphPtr &graph) { - auto prim_cast_class = prim::GetPythonOps("Cast", "mindspore.ops.operations"); - MS_EXCEPTION_IF_NULL(prim_cast_class); - auto dtype_node = NewValueNode(TypeIdToType(type_id)); - auto cast_node = NewCNode({NewValueNode(prim_cast_class)}, graph); - return NewCNode({cast_node, param, dtype_node}, graph); -} - -void DoAutoCast(const std::string &func_name, const std::vector &signature, - const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, - std::vector *const op_inputs, const std::set &write_indices) { - std::vector dtypes; - (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature &sig) { return sig.dtype; }); - int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); - if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { - return; - } - // Stat the index of the arguments with the largest type in the same SignatureEnumDType. - std::map dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); - // Identify which arg requires auto cast - for (size_t i = 0; i < args_spec_list.size(); ++i) { - auto it = dst_type.find(dtypes[i]); - if (it == dst_type.end() || it->second == kTypeUnknown) { - continue; - } - auto rw_it = write_indices.find(i); - auto is_write = (rw_it != write_indices.end()); - - TypeId arg_type_id = kTypeUnknown; - AbstractBasePtr arg_value = args_spec_list[i]; - (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); - auto it_map = type_name_map.find(arg_type_id); - if (it_map == type_name_map.end()) { - continue; - } - if (is_write) { - if (arg_type_id != it->second) { - auto it_name_map = type_name_map.find(it->second); - if (it_name_map == type_name_map.end()) { - continue; - } - MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n" - << "the type of writable argument is '" << it_map->second << "', " - << "but the largest type in the same SignatureEumDtype is '" << it_name_map->second - << "'. The writable arg type is not equal to the largest type, " - << "so can not cast automatically."; - } - continue; - } - if (arg_value->isa() && arg_type_id == it->second) { - continue; - } - (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); - } -} - -AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, - const AbstractBasePtrList &args_spec_list, const std::vector ¶ms_list) { - // args: original inputs - auto &signature = GetSignature(function); - std::size_t sig_size = signature.size(); - auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); - if (sig_size > 0) { - if (has_var) { - if (sig_size - 1 > args_spec_list.size()) { - MS_LOG(EXCEPTION) << "Function " << func_name - << "'s input length less than PositionalKeyword Signature length."; - } - } else if (args_spec_list.size() > sig_size) { - MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; - } - } - std::vector op_inputs; - std::set write_indices; - op_inputs.push_back(NewValueNode(function)); - // Assume, the write input of op is always the first input. We check if any write op, - // and add cast op on other inputs to keep the same type with assigned parameter. - for (size_t i = 0; i < args_spec_list.size(); ++i) { - AnfNodePtr param = params_list[i]; - if (args_spec_list[i] == nullptr) { - op_inputs.push_back(param); - continue; - } - SignatureEnumRW sig = SignatureEnumRW::kRWDefault; - // If sig_size is 0 use defalut. - if (sig_size > 0 && i < sig_size) { - sig = signature[i].rw; - } else if (has_var && i >= sig_size) { - sig = signature[sig_size - 1].rw; - } - - TypePtr type = args_spec_list[i]->GetTypeTrack(); - if (type && type->type_id() == kObjectTypeRef) { - if (sig == SignatureEnumRW::kRWRead) { - param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); - } else if (sig == SignatureEnumRW::kRWWrite) { - param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); - write_indices.insert(i); - } - // If sig is SignatureEnumRW::kRWRef, not do anything. - } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { - MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; - } - op_inputs.push_back(param); - } - // process default - ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); - DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); - return func_graph->NewCNode(op_inputs); -} -} // namespace - -AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, - const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) { - auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); - return new_cnode; -} - -FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - FuncGraphPtr func_graph = std::make_shared(); - - for (size_t i = 0; i < args_spec_list.size(); ++i) { - (void)func_graph->add_parameter(); - } - auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters()); - func_graph->set_output(new_cnode); - func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - return func_graph; -} -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/do_signature.h b/mindspore/ccsrc/operator/composite/do_signature.h deleted file mode 100644 index 3e1596d63f..0000000000 --- a/mindspore/ccsrc/operator/composite/do_signature.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "utils/any.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" -#include "common/utils.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -class DoSignatureMetaFuncGraph : public MetaFuncGraph { - public: - explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) - : MetaFuncGraph("S-" + name), function_(function) {} - - ~DoSignatureMetaFuncGraph() override = default; - - MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) - - FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override; - const ValuePtr function() const { return function_; } - - friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { - return &lhs == &rhs; - } - - private: - ValuePtr function_; -}; -using RWSignaturePtr = std::shared_ptr; - -AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, - const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.cc b/mindspore/ccsrc/operator/composite/list_append_operation.cc deleted file mode 100644 index 236a5b7062..0000000000 --- a/mindspore/ccsrc/operator/composite/list_append_operation.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/list_append_operation.h" - -#include -#include -#include - -#include "pipeline/static_analysis/param_validator.h" -#include "optimizer/opt.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { - abstract::CheckArgsSize("ListAppend", args_list, 2); - - AbstractBasePtr arg0 = args_list[0]; - abstract::AbstractListPtr arg0_list = dyn_cast(arg0); - MS_EXCEPTION_IF_NULL(arg0_list); - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ret->debug_info()->set_name("append"); - AnfNodePtr arg0_node = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeList)); - size_t arg0_length = arg0_list->size(); - for (size_t i = 0; i < arg0_length; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(SizeToInt(i))})); - } - AnfNodePtr arg1_node = ret->add_parameter(); - elems.push_back(arg1_node); - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { - (void)py::class_>(*m, "ListAppend_") - .def(py::init()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/map.cc b/mindspore/ccsrc/operator/composite/map.cc deleted file mode 100644 index 2149285323..0000000000 --- a/mindspore/ccsrc/operator/composite/map.cc +++ /dev/null @@ -1,292 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/map.h" -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "pipeline/static_analysis/dshape.h" -#include "pybind_api/api_register.h" -#include "debug/trace.h" -#include "operator/ops.h" -#include "./common.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; - -AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) { - MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n"; - MS_EXCEPTION_IF_NULL(func_graph); - std::vector inputs; - if (fn_arg != nullptr) { - inputs.emplace_back(fn_arg); - } else { - inputs.emplace_back(NewValueNode(fn_leaf_)); - } - inputs.insert(inputs.end(), args.begin(), args.end()); - return func_graph->NewCNode(inputs); -} - -FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { - // Generate func for leaf nodes - FuncGraphPtr ptrGraph = std::make_shared(); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); - ptrGraph->debug_info()->set_name("map"); - AnfNodePtr ptrFnArg = nullptr; - if (fn_leaf_ == nullptr) { - ptrFnArg = ptrGraph->add_parameter(); - } - AnfNodePtrList args; - for (size_t i = 0; i < args_size; ++i) { - args.emplace_back(ptrGraph->add_parameter()); - } - ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args)); - return ptrGraph; -} - -AnfNodePtr Map::FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(type); - - std::size_t size = type->elements().size(); - bool is_not_same = - std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { - auto lhs = std::dynamic_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); - if (is_not_same) { - MS_LOG(EXCEPTION) << "List in Map should have same length"; - } - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeList)); - - for (int i = 0; i < SizeToInt(size); ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target"; - auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); - auto fn = NewValueNode(ptrGraph); - - std::vector inputs2; - inputs2.push_back(fn); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - (void)std::transform( - arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair &item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); - }); - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr Map::FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(type); - - std::size_t size = type->elements().size(); - bool is_not_same = - std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { - auto lhs = std::dynamic_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); - if (is_not_same) { - MS_LOG(EXCEPTION) << "tuple in Map should have same length"; - } - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - - for (int i = 0; i < SizeToInt(size); ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs"; - auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); - auto fn = NewValueNode(ptrGraph); - - std::vector inputs2; - inputs2.push_back(fn); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - (void)std::transform( - arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), - [&func_graph, &i](std::pair item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); - }); - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { - MS_EXCEPTION_IF_NULL(type); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); - inputs.push_back(NewValueNode(type)); - - std::size_t attrSize = type->GetAttributes().size(); - for (std::size_t i = 0; i < attrSize; ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs"; - auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); - auto fn = NewValueNode(ptrGraph); - - std::vector inputs2; - inputs2.push_back(fn); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - int j = 0; - for (auto item : arg_pairs) { - inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); - j++; - } - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { - if (arg_pairs.empty()) { - MS_EXCEPTION(TypeError) << "map() must have at least two arguments"; - } - bool found = false; - TypeId id = kObjectTypeEnd; - std::pair pair; - for (auto &item : arg_pairs) { - pair = item; - MS_LOG(DEBUG) << "Map " << pair.second->ToString(); - id = item.second->type_id(); - if (nonleaf_.count(id)) { - found = true; - break; - } - } - - if (found) { - // In a nonleaf situation, all arguments must have the same generic. - bool is_not_same = - std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair &item) { - if (item.first != pair.first) { - return item.second->type_id() != pair.second->type_id(); - } - return false; - }); - if (is_not_same) { - std::ostringstream oss; - oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n" - << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; - int idx = 0; - for (auto &item : arg_pairs) { - oss << ++idx << ": " << item.second->ToString() << "\n"; - } - MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n" - << oss.str() << pair.second->ToString() << "\n"; - } - } - - switch (id) { - case kObjectTypeList: { - auto type = std::static_pointer_cast(pair.second); - return FullMakeList(type, func_graph, fn_arg, arg_pairs); - } - case kObjectTypeTuple: { - auto type = std::static_pointer_cast(pair.second); - return FullMakeTuple(type, func_graph, fn_arg, arg_pairs); - } - case kObjectTypeClass: { - auto type = std::static_pointer_cast(pair.second); - return FullMakeClass(type, func_graph, fn_arg, arg_pairs); - } - default: - MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class " - << ", but got " << pair.second->ToString(); - } -} - -FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { - FuncGraphPtr ptrGraph = std::make_shared(); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); - ptrGraph->debug_info()->set_name("map"); - - AnfNodePtr ptrFnArg = nullptr; - std::size_t i = 0; - if (fn_leaf_ == nullptr) { - ptrFnArg = ptrGraph->add_parameter(); - i = 1; - } - ArgsPairList arg_pairs; - std::size_t size = args_spec_list.size(); - for (; i < size; ++i) { - MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString(); - arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); - } - - ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs)); - return ptrGraph; -} - -abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { - if (fn_leaf_ == nullptr) { - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - // Assert that map's function param does not contain free variables - if (args_spec_list[0]->isa()) { - auto graph_func = dyn_cast(args_spec_list[0]); - auto func_graph = graph_func->func_graph(); - if (func_graph->parent() != nullptr) { - MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet."; - } - } - } - - AbstractBasePtrList broadened; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(arg); - return arg->Broaden(); - }); - return broadened; -} - -REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) { - (void)py::class_>(*m, "Map_") - .def(py::init>(), py::arg("leaf")) - .def(py::init<>()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/map.h b/mindspore/ccsrc/operator/composite/map.h deleted file mode 100644 index 02d374214a..0000000000 --- a/mindspore/ccsrc/operator/composite/map.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ - -#include -#include -#include -#include - -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" -#include "operator/composite/multitype_funcgraph.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using ArgsPairList = std::vector>; - -class Map : public MetaFuncGraph { - public: - explicit Map(const std::shared_ptr &fn_leaf = nullptr) - : MetaFuncGraph("map"), - fn_leaf_(fn_leaf), - broadcast_(false), - nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { - Init(); - } - Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { - Init(); - } - Map &operator=(const Map &h) { - if (this != &h) { - fn_leaf_ = h.fn_leaf_; - broadcast_ = h.broadcast_; - nonleaf_ = h.nonleaf_; - if (fn_leaf_) { - name_ = "map[" + fn_leaf_->name() + "]"; - } - } - return *this; - } - ~Map() override = default; - MS_DECLARE_PARENT(Map, MetaFuncGraph) - abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; - FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; - MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } - - private: - FuncGraphPtr GenerateLeafFunc(const size_t &args_size); - AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); - AnfNodePtr FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_pairs); - AnfNodePtr FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_pairs); - AnfNodePtr FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_pairs); - AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); - void Init() { - if (fn_leaf_ != nullptr) { - name_ = "map[" + fn_leaf_->name() + "]"; - } - signatures_ = - // def map(func:read, *args:ref): - std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, - {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); - } - - MultitypeFuncGraphPtr fn_leaf_; - bool broadcast_; - std::set nonleaf_; -}; -using MapPtr = std::shared_ptr; -class MapPy : public Map { - public: - explicit MapPy(const std::shared_ptr &fn_leaf = nullptr) : Map(fn_leaf) {} - ~MapPy() override = default; - MS_DECLARE_PARENT(MapPy, Map) -}; -using MapPyPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ diff --git a/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc deleted file mode 100644 index de6526f642..0000000000 --- a/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc +++ /dev/null @@ -1,185 +0,0 @@ - -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/multitype_funcgraph.h" -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "pipeline/static_analysis/dshape.h" -#include "pipeline/static_analysis/param_validator.h" -#include "operator/cc_implementations.h" -#include "optimizer/opt.h" -#include "utils/symbolic.h" -#include "pybind_api/api_register.h" -#include "./common.h" -#include "ir/signature.h" -#include "debug/trace.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { - fn_cache_.clear(); - signatures_ = std::vector({// def multitype(*args:ref): - {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); -} - -void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { - MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; - auto fn = fn_cache_.find(types); - if (fn != fn_cache_.end()) { - MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; - } - fn_cache_[types] = s_fn; -} - -void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { - MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; - auto fn = fn_cache_.find(types); - if (fn != fn_cache_.end()) { - MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; - } - fn_cache_py_[types] = py_fn; -} - -void MultitypeFuncGraph::Register(const std::vector &types_name, const py::function &py_fn) { - TypePtrList types; - for (auto &type_name : types_name) { - auto type_ptr = StringToType(type_name); - if (type_ptr == nullptr) { - MS_LOG(EXCEPTION) << type_name << " convert from string error "; - } - types.push_back(type_ptr); - } - Register(types, py_fn); -} - -void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { - std::vector types_name; - for (size_t it = 0; it < tuple.size(); ++it) { - py::object name_py = tuple[it]; - if (py::isinstance(name_py)) { - types_name.push_back(name_py.cast()); - continue; - } - MS_LOG(EXCEPTION) << "Register must be string"; - } - Register(types_name, py_fn); -} -static TypePtr UnwrapRef(const TypePtr &type) { - if (type->isa()) { - return type->cast()->subtype(); - } - return type; -} - -// Return Exact match if exists, else return non ambiguous sub class match -// Return py::none() if matching is ambiguous -const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { - // Exact match - for (auto &item : fn_cache_py_) { - TypePtrList sign = item.first; - if (sign.size() != types.size()) { - continue; - } - auto match = true; - for (size_t i = 0; i < sign.size(); ++i) { - if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { - match = false; - break; - } - } - if (!match) { - continue; - } - return item.second; - } - // Try best match - py::function py_fn_subclass; - size_t subclass_match_cnt = 0; - for (auto &item : fn_cache_py_) { - TypePtrList sign = item.first; - if (sign.size() != types.size()) { - continue; - } - auto match = true; - for (size_t i = 0; i < sign.size(); ++i) { - if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) && - !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) { - match = false; - break; - } - } - if (!match) { - continue; - } - py_fn_subclass = item.second; - subclass_match_cnt++; - } - if (subclass_match_cnt > 1) { - MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass"; - } - if (subclass_match_cnt == 1) { - MS_LOG(DEBUG) << "Found one subclass match"; - return py_fn_subclass; - } - return py::none(); -} - -FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { - auto py_fn = SignMatch(types); - std::ostringstream buffer; - buffer << types; - if (py_fn != py::none()) { - FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); - if (func_graph == nullptr) { - MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); - } - MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString(); - return func_graph; - } - std::ostringstream oss; - oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ - << "`, corresponding location info:\n"; - int idx = 0; - for (auto &item : fn_cache_py_) { - FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); - if (func_graph == nullptr) { - MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; - continue; - } - oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; - } - MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n" - << oss.str(); -} - -REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { - (void)py::class_>( - *m, "MultitypeFuncGraph_") - .def(py::init()) - .def("register_fn", &MultitypeFuncGraph::PyRegister); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/multitype_funcgraph.h b/mindspore/ccsrc/operator/composite/multitype_funcgraph.h deleted file mode 100644 index ababf21883..0000000000 --- a/mindspore/ccsrc/operator/composite/multitype_funcgraph.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -class MultitypeFuncGraph : public MetaFuncGraph { - public: - explicit MultitypeFuncGraph(const std::string &name); - ~MultitypeFuncGraph() override = default; - MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) - - using specialize_fn = FuncGraph *(*)(TypePtrList); - // Register a method which specialize based on types vectors; - virtual void Register(const TypePtrList &types, specialize_fn s_fn); - virtual void Register(const TypePtrList &types, const py::function &py_fn); - virtual void Register(const std::vector &types_name, const py::function &py_fn); - virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); - - FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; - size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } - const std::unordered_map &GetPyFunctions() const { - return fn_cache_py_; - } - - private: - const py::function SignMatch(const TypePtrList &types); - std::unordered_map fn_cache_; - std::unordered_map fn_cache_py_; -}; -using MultitypeFuncGraphPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/operator/composite/unpack_call.cc b/mindspore/ccsrc/operator/composite/unpack_call.cc deleted file mode 100644 index 3993d41597..0000000000 --- a/mindspore/ccsrc/operator/composite/unpack_call.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/unpack_call.h" -#include -#include - -#include "./common.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/dshape.h" -#include "pipeline/static_analysis/param_validator.h" -#include "operator/cc_implementations.h" -#include "ir/anf.h" -#include "optimizer/opt.h" -#include "utils/symbolic.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using mindspore::abstract::AbstractAttribute; -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractDictionary; -using mindspore::abstract::AbstractDictionaryPtr; -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractKeywordArg; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractTuplePtr; - -FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // slice a tensor - // args: tensor, slice or slice tuple - const std::string op_name = std::string("UnpackCall"); - size_t arg_length = args_spec_list.size(); - if (arg_length < 2) { - MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; - } - - (void)abstract::CheckArg(op_name, args_spec_list, 0); - auto ret_graph = std::make_shared(); - ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - - AnfNodePtr fnNode = ret_graph->add_parameter(); - std::vector elems; - elems.push_back(fnNode); - for (size_t index = 1; index < arg_length; index++) { - MS_EXCEPTION_IF_NULL(args_spec_list[index]); - if (args_spec_list[index]->isa()) { - auto arg_tuple = args_spec_list[index]->cast(); - AnfNodePtr para_tuple = ret_graph->add_parameter(); - for (size_t i = 0; i < arg_tuple->size(); ++i) { - elems.push_back( - ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))})); - } - } else if (args_spec_list[index]->isa()) { - AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast(); - AnfNodePtr para_dict = ret_graph->add_parameter(); - auto dict_elems = arg_dict->elements(); - (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), - [ret_graph, para_dict](const AbstractAttribute &item) { - auto dict_get_item = ret_graph->NewCNode( - {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); - return ret_graph->NewCNode( - {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item}); - }); - } else { - MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got " - << args_spec_list[index]->ToString(); - } - } - ret_graph->set_output(ret_graph->NewCNode(elems)); - return ret_graph; -} - -REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { - (void)py::class_>(*m, "UnpackCall_") - .def(py::init()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/unpack_call.h b/mindspore/ccsrc/operator/composite/unpack_call.h deleted file mode 100644 index 8c055a9386..0000000000 --- a/mindspore/ccsrc/operator/composite/unpack_call.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "utils/any.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" -#include "common/utils.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -// Expand the tuple and dict parameters generated when parsing the function call, -// and generate positional parameters and key-value pairs for function. -class UnpackCall : public MetaFuncGraph { - public: - explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} - ~UnpackCall() override = default; - MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } -}; -using UnpackCallPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ diff --git a/mindspore/ccsrc/operator/composite/zip_operation.cc b/mindspore/ccsrc/operator/composite/zip_operation.cc deleted file mode 100644 index 38f2b51614..0000000000 --- a/mindspore/ccsrc/operator/composite/zip_operation.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/zip_operation.h" -#include - -#include "pipeline/static_analysis/abstract_value.h" -#include "ir/anf.h" -#include "pipeline/static_analysis/dshape.h" -#include "operator/cc_implementations.h" -#include "optimizer/opt.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractSequeue; -using mindspore::abstract::AbstractSequeuePtr; -using mindspore::abstract::AbstractTuple; - -FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // zip operation: - // input: tuple arguments - // output: tuple of items of input iterated on every input - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << "For 'zip', there is at least one input."; - } - - auto is_all_sequeue = - std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool { - MS_EXCEPTION_IF_NULL(abs); - return abs->isa(); - }); - if (!is_all_sequeue) { - MS_LOG(EXCEPTION) << "For 'zip', all inputs must be sequence."; - } - - auto min_abs = std::min_element( - args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &x, const AbstractBasePtr &y) { - return (x->cast()->size() < y->cast()->size()); - }); - FuncGraphPtr ret_graph = std::make_shared(); - ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - for (size_t idx = 0; idx < args_spec_list.size(); idx++) { - (void)ret_graph->add_parameter(); - } - - // generate tuple output of ziped arguments input - std::vector make_tuple_nodes; - make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t idx = 0; idx < (*min_abs)->cast()->size(); idx++) { - std::vector make_tuple_zip_nodes; - make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl"; - ValuePtr op = prim::GetPythonOps("getitem", module_name); - for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); arg_idx++) { - std::vector tuple_get_item_nodes{NewValueNode(op), ret_graph->parameters()[arg_idx], - NewValueNode(SizeToInt(idx))}; - auto tuple_get_item_op = ret_graph->NewCNode(tuple_get_item_nodes); - make_tuple_zip_nodes.push_back(tuple_get_item_op); - } - auto make_tuple_zip_op = ret_graph->NewCNode(make_tuple_zip_nodes); - make_tuple_nodes.push_back(make_tuple_zip_op); - } - ret_graph->set_output(ret_graph->NewCNode(make_tuple_nodes)); - return ret_graph; -} - -REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) { - (void)py::class_>(*m, - "ZipOperation_") - .def(py::init()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/zip_operation.h b/mindspore/ccsrc/operator/composite/zip_operation.h deleted file mode 100644 index 1a3fa1f5fe..0000000000 --- a/mindspore/ccsrc/operator/composite/zip_operation.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "utils/any.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using AbstractBasePtr = abstract::AbstractBasePtr; -using AbstractBasePtrList = abstract::AbstractBasePtrList; -using AbstractTuplePtr = abstract::AbstractTuplePtr; - -class ZipOperation : public MetaFuncGraph { - public: - explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} - ~ZipOperation() override = default; - MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { - os << op.name_; - return os; - } - friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } -}; -using ZipOperationPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc deleted file mode 100755 index b682847ed7..0000000000 --- a/mindspore/ccsrc/operator/ops.cc +++ /dev/null @@ -1,288 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/ops.h" -#include -#include - -namespace mindspore { -// namespace to support primitive operators -namespace prim { -// Arithmetic -const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); -const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); -const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); -const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); -const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); -const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); -const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); -const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); -const PrimitivePtr kPrimScalarFloor = std::make_shared("scalar_floor"); -const PrimitivePtr kPrimScalarUadd = std::make_shared("scalar_uadd"); -const PrimitivePtr kPrimScalarUsub = std::make_shared("scalar_usub"); -const PrimitivePtr kPrimScalarExp = std::make_shared("scalar_exp"); -const PrimitivePtr kPrimScalarLog = std::make_shared("scalar_log"); -const PrimitivePtr kPrimScalarSin = std::make_shared("scalar_sin"); -const PrimitivePtr kPrimScalarCos = std::make_shared("scalar_cos"); -const PrimitivePtr kPrimScalarTan = std::make_shared("scalar_tan"); - -// Comparisons -const PrimitivePtr kPrimScalarEq = std::make_shared("scalar_eq"); -const PrimitivePtr kPrimScalarLt = std::make_shared("scalar_lt"); -const PrimitivePtr kPrimScalarGt = std::make_shared("scalar_gt"); -const PrimitivePtr kPrimScalarNe = std::make_shared("scalar_ne"); -const PrimitivePtr kPrimScalarLe = std::make_shared("scalar_le"); -const PrimitivePtr kPrimScalarGe = std::make_shared("scalar_ge"); -const PrimitivePtr kPrimBoolNot = std::make_shared("bool_not"); -const PrimitivePtr kPrimBoolAnd = std::make_shared("bool_and"); -const PrimitivePtr kPrimBoolOr = std::make_shared("bool_or"); -const PrimitivePtr kPrimBoolEq = std::make_shared("bool_eq"); -const PrimitivePtr kPrimGreater = std::make_shared("Greater"); -const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); -const PrimitivePtr kPrimLess = std::make_shared("Less"); -const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); -const PrimitivePtr kPrimEqual = std::make_shared("Equal"); -const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); - -// Type introspection -const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); -const PrimitivePtr kPrimHasType = std::make_shared("hastype"); - -// Statements -const PrimitivePtr kPrimSwitch = std::make_shared("switch"); -const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); -const PrimitivePtr kPrimReturn = std::make_shared("return"); -const PrimitivePtr kPrimAssign = std::make_shared("Assign"); -const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); -const PrimitivePtr kPrimAssignSub = std::make_shared("AssignSub"); -const PrimitivePtr kPrimSelect = std::make_shared("Select"); -const PrimitivePtr kPrimCall = std::make_shared("call"); - -const PrimitivePtr kPrimDistribute = std::make_shared("distribute"); -const PrimitivePtr kPrimDot = std::make_shared("dot"); -const PrimitivePtr kPrimIm2Col = std::make_shared("im2col"); -const PrimitivePtr kPrimCol2Im = std::make_shared("col2im"); -const PrimitivePtr kPrimIm2ColV1 = std::make_shared("im2col_v1"); -const PrimitivePtr kPrimCol2ImV1 = std::make_shared("col2im_v1"); - -const PrimitivePtr kPrimResolve = std::make_shared("resolve"); -const PrimitivePtr kPrimEmbed = std::make_shared("embed"); -const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); -const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); - -const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); -const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); -const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); - -// Structure -const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); -const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); -const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); -const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); -const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); -const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); -const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); -const PrimitivePtr kPrimMakeSlice = std::make_shared("make_slice"); -const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); -const PrimitivePtr kPrimTupleGetItem = std::make_shared("tuple_getitem"); -const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); -const PrimitivePtr kPrimArrayGetItem = std::make_shared("array_getitem"); -const PrimitivePtr kPrimTupleSetItem = std::make_shared("tuple_setitem"); -const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); -const PrimitivePtr kPrimArraySetItem = std::make_shared("array_setitem"); -const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); -const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); -const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); -const PrimitivePtr kPrimGetAttr = std::make_shared("getattr"); -const PrimitivePtr kPrimTupleLen = std::make_shared("tuple_len"); -const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); -const PrimitivePtr kPrimListLen = std::make_shared("list_len"); -const PrimitivePtr kPrimArrayLen = std::make_shared("array_len"); -const PrimitivePtr kPrimListMap = std::make_shared("list_map"); -const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); -const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); - -const PrimitivePtr kPrimTileShape = std::make_shared("tile_shape"); -const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); -const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); -const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); -const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); -const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared("generate_shape_index"); -const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared("generate_inverse_index"); -const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); -const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); -const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); -const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); - -// Arrays -const PrimitivePtr kPrimScalarToArray = std::make_shared("scalar_to_array"); -const PrimitivePtr kPrimArrayToScalar = std::make_shared("array_to_scalar"); -const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); -const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); -const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); -const PrimitivePtr kPrimShape = std::make_shared("Shape"); -const PrimitivePtr kPrimCast = std::make_shared("Cast"); -const PrimitivePtr kPrimConcat = std::make_shared("Concat"); -const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); -const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); -const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); -const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); -const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); -const PrimitivePtr kPrimSize = std::make_shared("Size"); -const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); -const PrimitivePtr kPrimPack = std::make_shared("Pack"); -const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); -const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); -const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); -const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); -const PrimitivePtr kPrimTile = std::make_shared("Tile"); -const PrimitivePtr kPrimAddN = std::make_shared("AddN"); -const PrimitivePtr KPrimTransData = std::make_shared("TransData"); -const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); -const PrimitivePtr kPrimPad = std::make_shared("Pad"); -const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); - -// Maths -const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); -const PrimitivePtr kPrimMatMul = std::make_shared("MatMul"); -const PrimitivePtr kPrimBatchMatMul = std::make_shared("BatchMatMul"); -const PrimitivePtr kPrimMaximumGrad = std::make_shared("MaximumGrad"); -const PrimitivePtr kPrimMinimumGrad = std::make_shared("MinimumGrad"); -const PrimitivePtr kPrimReduceMean = std::make_shared("ReduceMean"); -const PrimitivePtr kPrimReduceSum = std::make_shared("ReduceSum"); -const PrimitivePtr kPrimReduceAll = std::make_shared("ReduceAll"); -const PrimitivePtr kPrimReduceMax = std::make_shared("ReduceMax"); -const PrimitivePtr kPrimReduceMin = std::make_shared("ReduceMin"); -const PrimitivePtr kPrimNeg = std::make_shared("Neg"); -const PrimitivePtr kPrimSub = std::make_shared("Sub"); -const PrimitivePtr kPrimMul = std::make_shared("Mul"); -const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); -const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); -const PrimitivePtr kPrimSquare = std::make_shared("Square"); -const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); -const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); -const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); -const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); -const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); -const PrimitivePtr kPrimPow = std::make_shared("Pow"); -const PrimitivePtr kPrimRealDiv = std::make_shared("RealDiv"); -const PrimitivePtr kPrimSqrt = std::make_shared("Sqrt"); -const PrimitivePtr kPrimReciprocal = std::make_shared("Reciprocal"); -const PrimitivePtr kPrimExpandDims = std::make_shared("ExpandDims"); - -// NN -const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); -const PrimitivePtr kPrimSoftmax = std::make_shared("Softmax"); -const PrimitivePtr kPrimLogSoftmax = std::make_shared("LogSoftmax"); -const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared("LogSoftmaxGrad"); -const PrimitivePtr kPrimTanh = std::make_shared("Tanh"); -const PrimitivePtr kPrimTanhGrad = std::make_shared("TanhGrad"); -const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); -const PrimitivePtr kPrimPoolingGrad = std::make_shared("PoolingGrad"); -const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); -const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); -const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("ApplyCenteredRMSProp"); -const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); -const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); -const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D"); -const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared("FusedBatchNormGrad"); -const PrimitivePtr kPrimBatchNorm = std::make_shared("BatchNorm"); -const PrimitivePtr kPrimBatchNormGrad = std::make_shared("BatchNormGrad"); -const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad"); -const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared("Conv2DBackpropInput"); -const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv2DBackpropFilter"); -const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); -const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = - std::make_shared("DepthwiseConv2dNativeBackpropFilter"); -const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = - std::make_shared("DepthwiseConv2dNativeBackpropInput"); -const PrimitivePtr kPrimBiasAddGrad = std::make_shared("BiasAddGrad"); -const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared("SoftmaxCrossEntropyWithLogits"); -const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = - std::make_shared("SparseSoftmaxCrossEntropyWithLogits"); -const PrimitivePtr kPrimMomentum = std::make_shared("Momentum"); -const PrimitivePtr kPrimApplyMomentum = std::make_shared("ApplyMomentum"); -const PrimitivePtr kPrimLayerNorm = std::make_shared("LayerNorm"); -const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGrad"); -const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared("LayerNormXBackprop"); -const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); -const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); -const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); -const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); -const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); -const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); -const PrimitivePtr kPrimRelu = std::make_shared("ReLU"); -const PrimitivePtr kPrimReluV2 = std::make_shared("ReLUV2"); -const PrimitivePtr kPrimZerosLike = std::make_shared("ZerosLike"); -const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); -const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); -const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); -const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); -const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); - -// Other miscellaneous -const PrimitivePtr kPrimIdentity = std::make_shared("identity"); -const PrimitivePtr kPrimPartial = std::make_shared("Partial"); -const PrimitivePtr kPrimJ = std::make_shared("J"); -const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); -const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); -const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); -const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); -const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); -const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); -const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); -const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); -const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); -const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); -const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); -const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); -const PrimitivePtr kPrimPrint = std::make_shared("Print"); - -const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); -const PrimitivePtr kPrimDepend = std::make_shared("Depend"); -const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); - -const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); -const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); -const PrimitivePtr kPrimIs_ = std::make_shared("is_"); -const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); -const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); -const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_dict"); -const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); -const PrimitivePtr kPrimIsConsant = std::make_shared("is_constant"); -const PrimitivePtr kPrimEquivFormat = std::make_shared("EquivFormat"); - -// Comm ops -const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); -const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); -const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); -const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); - -// Debug ops -const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSummary"); -const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary"); -const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); -const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); -const PrimitivePtr kPrimDebug = std::make_shared("Debug"); - -// IndexedSlices -const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared("MakeIndexedSlices"); -const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared("IndexedSlicesGetValues"); -const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared("IndexedSlicesGetIndices"); -const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared("IndexedSlicesGetDenseShape"); -const PrimitivePtr kPrimIsIndexedSlices = std::make_shared("IsIndexedSlices"); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h deleted file mode 100755 index f778013896..0000000000 --- a/mindspore/ccsrc/operator/ops.h +++ /dev/null @@ -1,330 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_OPS_H_ -#define MINDSPORE_CCSRC_OPERATOR_OPS_H_ - -#include -#include -#include -#include "ir/anf.h" -#include "ir/primitive_base.h" - -namespace mindspore { -// namespace to support primitive operators -namespace prim { -ValuePtr GetPythonOps(const std::string &op_name, - const std::string &module_name = "mindspore._extends.parse.standard_method", - bool use_signature = false); - -// Arithmetic -extern const PrimitivePtr kPrimScalarAdd; -extern const PrimitivePtr kPrimScalarSub; -extern const PrimitivePtr kPrimScalarMul; -extern const PrimitivePtr kPrimScalarDiv; -extern const PrimitivePtr kPrimScalarFloordiv; -extern const PrimitivePtr kPrimScalarMod; -extern const PrimitivePtr kPrimScalarPow; -extern const PrimitivePtr kPrimScalarTrunc; -extern const PrimitivePtr kPrimScalarFloor; -extern const PrimitivePtr kPrimScalarUadd; -extern const PrimitivePtr kPrimScalarUsub; -extern const PrimitivePtr kPrimScalarExp; -extern const PrimitivePtr kPrimScalarLog; -extern const PrimitivePtr kPrimScalarSin; -extern const PrimitivePtr kPrimScalarCos; -extern const PrimitivePtr kPrimScalarTan; - -// Comparisons -extern const PrimitivePtr kPrimScalarEq; -extern const PrimitivePtr kPrimScalarLt; -extern const PrimitivePtr kPrimScalarGt; -extern const PrimitivePtr kPrimScalarNe; -extern const PrimitivePtr kPrimScalarLe; -extern const PrimitivePtr kPrimScalarGe; -extern const PrimitivePtr kPrimBoolNot; -extern const PrimitivePtr kPrimBoolAnd; -extern const PrimitivePtr kPrimBoolOr; -extern const PrimitivePtr kPrimBoolEq; -extern const PrimitivePtr kPrimGreater; -extern const PrimitivePtr kPrimGreaterEqual; -extern const PrimitivePtr kPrimLess; -extern const PrimitivePtr kPrimLessEqual; -extern const PrimitivePtr kPrimEqual; -extern const PrimitivePtr kPrimNotEqual; - -// Type introspection -extern const PrimitivePtr kPrimTypeOf; -extern const PrimitivePtr kPrimHasType; - -// Statements -extern const PrimitivePtr kPrimSwitch; -extern const PrimitivePtr kPrimSwitchLayer; -extern const PrimitivePtr kPrimReturn; -extern const PrimitivePtr kPrimAssign; -extern const PrimitivePtr kPrimAssignAdd; -extern const PrimitivePtr kPrimAssignSub; -extern const PrimitivePtr kPrimSelect; -extern const PrimitivePtr kPrimCall; - -extern const PrimitivePtr kPrimDistribute; -extern const PrimitivePtr kPrimDot; -extern const PrimitivePtr kPrimIm2Col; -extern const PrimitivePtr kPrimCol2Im; -extern const PrimitivePtr kPrimIm2ColV1; -extern const PrimitivePtr kPrimCol2ImV1; - -extern const PrimitivePtr kPrimResolve; -extern const PrimitivePtr kPrimEmbed; -extern const PrimitivePtr kPrimRefToEmbed; -extern const PrimitivePtr kPrimCreateInstance; - -extern const PrimitivePtr kPrimLabelGoto; -extern const PrimitivePtr kPrimLabelSwitch; -extern const PrimitivePtr kPrimLabelSet; - -// Structure -extern const PrimitivePtr kPrimStringEqual; -extern const PrimitivePtr kPrimStringConcat; -extern const PrimitivePtr kPrimMakeTuple; -extern const PrimitivePtr kPrimMakeList; -extern const PrimitivePtr kPrimMakeDict; -extern const PrimitivePtr kPrimMakeKeywordArg; -extern const PrimitivePtr kPrimExtractKeywordArg; -extern const PrimitivePtr kPrimMakeSlice; -extern const PrimitivePtr kPrimMakeRecord; -extern const PrimitivePtr kPrimTupleGetItem; -extern const PrimitivePtr kPrimListGetItem; -extern const PrimitivePtr kPrimArrayGetItem; -extern const PrimitivePtr kPrimTupleSetItem; -extern const PrimitivePtr kPrimListSetItem; -extern const PrimitivePtr kPrimArraySetItem; -extern const PrimitivePtr kPrimDictGetItem; -extern const PrimitivePtr kPrimDictSetItem; -extern const PrimitivePtr kPrimListAppend; -extern const PrimitivePtr kPrimGetAttr; -extern const PrimitivePtr kPrimTupleLen; -extern const PrimitivePtr kPrimDictLen; -extern const PrimitivePtr kPrimListLen; -extern const PrimitivePtr kPrimArrayLen; -extern const PrimitivePtr kPrimListMap; -extern const PrimitivePtr kPrimListReduce; -extern const PrimitivePtr kPrimTupleReversed; -extern const PrimitivePtr kPrimTileShape; -extern const PrimitivePtr kPrimReducedShape; -extern const PrimitivePtr kPrimTupleDiv; -extern const PrimitivePtr kPrimTupleToArray; -extern const PrimitivePtr kPrimShapeMul; -extern const PrimitivePtr kPrimGenerateShapeIndex; -extern const PrimitivePtr kPrimGenerateInverseIndex; -extern const PrimitivePtr kPrimTupleEqual; -extern const PrimitivePtr kPrimListEqual; -extern const PrimitivePtr kPrimMakeRange; -extern const PrimitivePtr kPrimStopGradient; - -// Arrays -extern const PrimitivePtr kPrimScalarToArray; -extern const PrimitivePtr kPrimArrayToScalar; -extern const PrimitivePtr kPrimBroadcastShape; -extern const PrimitivePtr kPrimArrayMap; -extern const PrimitivePtr kPrimArrayReduce; -extern const PrimitivePtr kPrimShape; -extern const PrimitivePtr kPrimCast; -extern const PrimitivePtr kPrimConcat; -extern const PrimitivePtr kPrimSqueeze; -extern const PrimitivePtr kPrimTranspose; -extern const PrimitivePtr kPrimGatherV2; -extern const PrimitivePtr kPrimEmbeddingLookup; -extern const PrimitivePtr kPrimEmbeddingLookupCommGrad; -extern const PrimitivePtr kPrimSize; -extern const PrimitivePtr kPrimArgMax; -extern const PrimitivePtr kPrimPack; -extern const PrimitivePtr kPrimUnpack; -extern const PrimitivePtr kPrimUnsortedSegmentMin; -extern const PrimitivePtr kPrimUnsortedSegmentSum; -extern const PrimitivePtr kPrimConcatOffset; -extern const PrimitivePtr kPrimReshape; -extern const PrimitivePtr kPrimTile; -extern const PrimitivePtr kPrimAddN; -extern const PrimitivePtr KPrimTransData; -extern const PrimitivePtr kPrimNMSWithMask; -extern const PrimitivePtr kPrimPad; -extern const PrimitivePtr kPrimArgMaxWithValue; -extern const PrimitivePtr kPrimRealDiv; -extern const PrimitivePtr kPrimSqrt; -extern const PrimitivePtr kPrimReciprocal; -extern const PrimitivePtr kPrimExpandDims; - -// Maths -extern const PrimitivePtr kPrimTensorAdd; -extern const PrimitivePtr kPrimMatMul; -extern const PrimitivePtr kPrimBatchMatMul; -extern const PrimitivePtr kPrimMaximumGrad; -extern const PrimitivePtr kPrimMinimumGrad; -extern const PrimitivePtr kPrimReduceMean; -extern const PrimitivePtr kPrimReduceSum; -extern const PrimitivePtr kPrimReduceAll; -extern const PrimitivePtr kPrimReduceMax; -extern const PrimitivePtr kPrimReduceMin; -extern const PrimitivePtr kPrimNeg; -extern const PrimitivePtr kPrimSub; -extern const PrimitivePtr kPrimMul; -extern const PrimitivePtr kPrimRealDiv; -extern const PrimitivePtr kPrimMinimum; -extern const PrimitivePtr kPrimMaximum; -extern const PrimitivePtr kPrimSquare; -extern const PrimitivePtr kPrimSqrt; -extern const PrimitivePtr kPrimEqual; -extern const PrimitivePtr kPrimLess; -extern const PrimitivePtr kPrimLessEqual; -extern const PrimitivePtr kPrimCumSum; -extern const PrimitivePtr kPrimCumProd; -extern const PrimitivePtr kPrimSubscalar; -extern const PrimitivePtr kPrimInplaceAdd; -extern const PrimitivePtr kPrimInplaceSub; -extern const PrimitivePtr kPrimPow; - -// NN -extern const PrimitivePtr kPrimFlatten; -extern const PrimitivePtr kPrimSoftmax; -extern const PrimitivePtr kPrimLogSoftmax; -extern const PrimitivePtr kPrimLogSoftmaxGrad; -extern const PrimitivePtr kPrimApplyCenteredRMSProp; -extern const PrimitivePtr kPrimTanh; -extern const PrimitivePtr kPrimTanhGrad; -extern const PrimitivePtr kPrimPooling; -extern const PrimitivePtr kPrimPoolingGrad; -extern const PrimitivePtr kPrimFusedBatchNorm; -extern const PrimitivePtr kPrimBatchNorm; -extern const PrimitivePtr kPrimBatchNormGrad; -extern const PrimitivePtr kPrimConv2D; -extern const PrimitivePtr kPrimMaxPool; -extern const PrimitivePtr kPrimMaxPoolGrad; -extern const PrimitivePtr kPrimAvgPoolGrad; -extern const PrimitivePtr kPrimFusedBatchNormGrad; -extern const PrimitivePtr kPrimReluGrad; -extern const PrimitivePtr kPrimConv2DBackpropInput; -extern const PrimitivePtr kPrimConv2DBackpropFilter; -extern const PrimitivePtr kPrimDepthwiseConv2dNative; -extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter; -extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput; - -extern const PrimitivePtr kPrimBiasAddGrad; -extern const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits; -extern const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits; -extern const PrimitivePtr kPrimMomentum; -extern const PrimitivePtr kPrimApplyMomentum; -extern const PrimitivePtr kPrimLayerNorm; -extern const PrimitivePtr kPrimLayerNormGrad; -extern const PrimitivePtr kPrimLayerNormXBackprop; -extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop; -extern const PrimitivePtr kPrimDropoutGenMask; -extern const PrimitivePtr kPrimDropoutDoMask; -extern const PrimitivePtr kPrimOneHot; -extern const PrimitivePtr kPrimGelu; -extern const PrimitivePtr kPrimGeluGrad; -extern const PrimitivePtr kPrimRelu; -extern const PrimitivePtr kPrimReluV2; -extern const PrimitivePtr kPrimActivation; -extern const PrimitivePtr kPrimZerosLike; -extern const PrimitivePtr kPrimFakeBprop; -extern const PrimitivePtr kPrimBpropCut; -extern const PrimitivePtr kPrimFakeQuantPerLayer; -extern const PrimitivePtr kPrimFakeQuantPerChannel; -extern const PrimitivePtr kPrimApplyRMSProp; - -// Other Miscellaneous -extern const PrimitivePtr kPrimIdentity; -extern const PrimitivePtr kPrimPartial; -extern const PrimitivePtr kPrimJ; -extern const PrimitivePtr kPrimEnvSetItem; -extern const PrimitivePtr kPrimEnvGetItem; -extern const PrimitivePtr kPrimEnvAdd; -extern const PrimitivePtr kPrimMakeRefKey; -extern const PrimitivePtr kPrimMakeRef; -extern const PrimitivePtr kPrimGetRefKey; -extern const PrimitivePtr kPrimGetRefValue; -extern const PrimitivePtr kPrimGetRefOrigin; -extern const PrimitivePtr kPrimInsertGradientOf; -extern const PrimitivePtr kPrimHookBackward; -extern const PrimitivePtr kPrimPrintShapeType; -extern const PrimitivePtr kPrimPrint; -extern const PrimitivePtr kPrimSameTypeShape; -extern const PrimitivePtr kPrimCheckBprop; -extern const PrimitivePtr kPrimDepend; -extern const PrimitivePtr kPrimStateSetItem; -extern const PrimitivePtr kPrimScalarSummary; -extern const PrimitivePtr kPrimImageSummary; -extern const PrimitivePtr kPrimTensorSummary; -extern const PrimitivePtr kPrimHistogramSummary; -extern const PrimitivePtr kPrimBroadcastGradientArgs; -extern const PrimitivePtr kPrimControlDepend; -extern const PrimitivePtr kPrimIs_; -extern const PrimitivePtr kPrimIsNot; -extern const PrimitivePtr kPrimInDict; -extern const PrimitivePtr kPrimNotInDict; -extern const PrimitivePtr kPrimMixedPrecisionCast; -extern const PrimitivePtr kPrimIsConsant; -extern const PrimitivePtr kPrimEquivFormat; -extern const PrimitivePtr kPrimDebug; - -// Comm ops -extern const PrimitivePtr kPrimAllReduce; -extern const PrimitivePtr kPrimMirror; -extern const PrimitivePtr kPrimVirtualDiv; -extern const PrimitivePtr kPrimVirtualDataset; - -// IndexedSlices -extern const PrimitivePtr kPrimMakeIndexedSlices; -extern const PrimitivePtr kPrimIndexedSlicesGetValues; -extern const PrimitivePtr kPrimIndexedSlicesGetIndices; -extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape; -extern const PrimitivePtr kPrimIsIndexedSlices; - -class DoSignaturePrimitive : public Primitive { - public: - explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) - : Primitive("S-Prim-" + name), function_(function) {} - - ~DoSignaturePrimitive() override = default; - - MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) - - const ValuePtr function() const { return function_; } - - private: - ValuePtr function_; -}; -using DoSignaturePrimitivePtr = std::shared_ptr; - -class UnpackGraphPrimitive : public Primitive { - public: - explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) - : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} - ~UnpackGraphPrimitive() override = default; - MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) - bool with_sens_in_args() const { return with_sens_in_args_; } - bool need_unpack_args() const { return need_unpack_args_; } - - private: - bool with_sens_in_args_; - bool need_unpack_args_; -}; -using UnpackGraphPrimitivePtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_OPS_H_ diff --git a/mindspore/ccsrc/operator/ops_extends.cc b/mindspore/ccsrc/operator/ops_extends.cc deleted file mode 100755 index d415b45adf..0000000000 --- a/mindspore/ccsrc/operator/ops_extends.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/ops.h" -#include -#include -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/data_converter.h" - -namespace mindspore { -// namespace to support primitive operators -namespace prim { -ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) { - py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); - ValuePtr node = nullptr; - bool succ = parse::ConvertData(obj, &node, use_signature); - if (!succ) { - MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail"; - } - return node; -} -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_arrays.cc b/mindspore/ccsrc/operator/prim_arrays.cc deleted file mode 100644 index 237ca795eb..0000000000 --- a/mindspore/ccsrc/operator/prim_arrays.cc +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/utils.h" -#include "operator/cc_implementations.h" -#include "pipeline/static_analysis/param_validator.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a scalar. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractScalarPtr arg = CheckArg(op_name, args_spec_list, 0); - return std::make_shared(arg, std::make_shared()); -} - -AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor with 0 shape. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto arg = CheckArg(op_name, args_spec_list, 0); - auto a_shp = arg->shape(); - if (!a_shp->shape().empty()) { - MS_LOG(EXCEPTION) << "array_to_scalar requires zero size shape."; - } - return arg->element(); -} - -AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tuples. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - auto xs = CheckArg(op_name, args_spec_list, 0); - auto ys = CheckArg(op_name, args_spec_list, 1); - - auto value_tuple_x = xs->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(value_tuple_x); - auto shp_tuple_x = value_tuple_x->value(); - std::vector shp_x; - (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x), - [](const ValuePtr &e) -> int { return GetValue(e); }); - - auto value_tuple_y = ys->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(value_tuple_y); - auto shp_tuple_y = value_tuple_y->value(); - std::vector shp_y; - (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y), - [](const ValuePtr &e) -> int { return GetValue(e); }); - - std::vector res = prim::BroadcastShape_(shp_x, shp_y); - if (res.empty()) { - MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," - << args_spec_list[1]->ToString(); - } - - AbstractBasePtrList elems; - (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int n) -> AbstractBasePtr { - return std::make_shared(std::make_shared(n), kInt32); - }); - - return std::make_shared(elems); -} - -AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractTensorPtr arg = CheckArg(op_name, args_spec_list, 0); - MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString(); - - AbstractBasePtrList values; - auto shp = arg->shape(); - for (int entry : shp->shape()) { - auto entry_v = MakeValue(entry); - values.push_back(std::make_shared(entry_v, entry_v->type())); - } - return std::make_shared(values); -} - -AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - auto arg = CheckArg(op_name, args_spec_list, 0); - auto multiples = CheckArg(op_name, args_spec_list, 1); - - ShapePtr input_shape = arg->shape(); - (void)CheckTensorDType(arg, {kInt16, kFloat16, kInt32, kFloat32}, "Input 0 of Tile should be %s"); - - auto mul_shp_value = multiples->BuildValue(); - if (mul_shp_value->isa()) { - MS_LOG(EXCEPTION) << "shape's data field can't be anything: " << args_spec_list[1]->ToString(); - } - - std::vector mul_shp; - auto value_tuple_mul = mul_shp_value->cast(); - auto mul_shp_data = value_tuple_mul->value(); - (void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp), - [](const ValuePtr &e) -> int { return GetValue(e); }); - if (input_shape->shape().size() != mul_shp_data.size()) { - MS_LOG(EXCEPTION) << "Tile requires input and multiples size equal, while the input size is " - << input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << "."; - } - - std::vector result_shp; - for (size_t i = 0; i < mul_shp_data.size(); ++i) { - result_shp.push_back(input_shape->shape()[i] * mul_shp[i]); - } - return std::make_shared(arg->element(), std::make_shared(result_shp)); -} - -AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple of tensor. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto arg = CheckArg(op_name, args_spec_list, 0); - if (arg->elements().empty()) { - MS_LOG(EXCEPTION) << "Arg elements is empty."; - } - - size_t tuple_len = arg->elements().size(); - AbstractTensorPtr tensor_base = CheckArg(op_name, arg->elements(), 0); - int rank_base = SizeToInt(tensor_base->shape()->shape().size()); - - ValuePtr axis = primitive->GetAttr("axis"); - // Axis value should be in [-(rank_base + 1), rank_base). - int axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base); - // If axis is negative, add offset(rank_base + 1) to turn it to positive. - axis_value = GetPositiveAxis(axis_value, IntToSize(rank_base + 1)); - - for (size_t i = 1; i < tuple_len; ++i) { - AbstractTensorPtr tensor = CheckArg(op_name, arg->elements(), i); - (void)CheckDtypeSame(op_name, tensor_base, tensor); - (void)CheckShapeSame(op_name, tensor_base, tensor); - } - - primitive->set_attr("N", MakeValue(SizeToInt(tuple_len))); - primitive->set_attr("T", tensor_base->element()->BuildType()); - - AbstractTensorPtr ret = dyn_cast(tensor_base->Broaden()); - MS_EXCEPTION_IF_NULL(ret); - auto shape = ret->shape()->shape(); - (void)shape.insert(shape.begin() + axis_value, tuple_len); - ret->set_shape(std::make_shared(shape)); - return ret; -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_debug.cc b/mindspore/ccsrc/operator/prim_debug.cc deleted file mode 100644 index 5e6cdcc318..0000000000 --- a/mindspore/ccsrc/operator/prim_debug.cc +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/param_validator.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/utils.h" -#include "utils/symbolic.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor(value) - const std::string op_name = primitive->name(); - - CheckArgsSize(op_name, args_spec_list, 1); - auto tensor_value = CheckArg(op_name, args_spec_list, 0); - - int tensor_rank = SizeToInt(tensor_value->shape()->shape().size()); - if (tensor_rank == 0) { - MS_LOG(EXCEPTION) << op_name << " summary evaluator second arg should be an tensor, but got a scalar, rank is 0"; - } - - return std::make_shared(AbstractBasePtrList({tensor_value->Broaden()})); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_maths.cc b/mindspore/ccsrc/operator/prim_maths.cc deleted file mode 100644 index 02b86603e7..0000000000 --- a/mindspore/ccsrc/operator/prim_maths.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/utils.h" -#include "pipeline/static_analysis/param_validator.h" -#include "common/utils.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - auto input_x = CheckArg(op_name, args_spec_list, 0); - auto input_y = CheckArg(op_name, args_spec_list, 1); - auto dout = CheckArg(op_name, args_spec_list, 2); - (void)CheckTensorsDTypeSame({input_x, input_y, dout}, {kInt, kUInt, kFloat}, - op_name + "evaluator three inputs should be %s"); - - AbstractBasePtr dx = input_x->Broaden(); - AbstractBasePtr dy = input_y->Broaden(); - - return std::make_shared(AbstractBasePtrList({dx, dy})); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_nn.cc b/mindspore/ccsrc/operator/prim_nn.cc deleted file mode 100644 index d9a0071757..0000000000 --- a/mindspore/ccsrc/operator/prim_nn.cc +++ /dev/null @@ -1,432 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/utils.h" -#include "pipeline/static_analysis/param_validator.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractTensorPtr input_tensor = CheckArg(op_name, args_spec_list, 0); - (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input 0 of Pooling should be %s"); - - ShapePtr input_shape = dyn_cast(input_tensor->GetShapeTrack()); // NCHW - MS_EXCEPTION_IF_NULL(input_shape); - if (input_shape->shape().size() != 4) { - MS_LOG(EXCEPTION) << "Pooling input should be a 4-D tensor."; - } - int h_input = input_shape->shape()[2]; - int w_input = input_shape->shape()[3]; - - int window = primitive->GetAttr("window")->cast()->value(); - int stride = primitive->GetAttr("stride")->cast()->value(); - int padding = primitive->GetAttr("pad")->cast()->value(); - int nan_opt = primitive->GetAttr("nan_opt")->cast()->value(); - int data_mode = primitive->GetAttr("data_mode")->cast()->value(); - int ceil_mode = primitive->GetAttr("ceil_mode")->cast()->value(); - - if (stride <= 0) { - MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should greater then 0"; - } - if (nan_opt != 0) { - MS_LOG(EXCEPTION) << "Invalid nan_opt value: " << nan_opt << ", should be 0"; - } - if (data_mode != 1) { - MS_LOG(EXCEPTION) << "Invalid data_mode value: " << data_mode << ", should be 1"; - } - if (ceil_mode != 0) { - MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0"; - } - - std::set available_pad_mode{"pad", "same", "valid"}; - auto pad_mode_ptr = primitive->GetAttr("pad_mode"); - if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa()) { - auto pad_mode = pad_mode_ptr->cast()->value(); - if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) { - MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid"; - } - if (pad_mode == "valid") { - padding = 0; - } else if (pad_mode == "same") { - padding = (window - 1) / 2; - } - } - - std::set available_mode{"max", "avg"}; - auto mode_ptr = primitive->GetAttr("mode"); - if ((mode_ptr != nullptr) && mode_ptr->isa()) { - auto mode = mode_ptr->cast()->value(); - if (available_mode.find(mode) == available_mode.end()) { - MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << "."; - } - } - - int h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1; - int w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1; - std::vector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out}; - AbstractBasePtr ret = input_tensor->Broaden(); - ret->set_shape(std::make_shared(shape_out)); - return ret; -} - -AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors(y, dy, x). - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - auto out_y = CheckArg(op_name, args_spec_list, 0); - auto d_out = CheckArg(op_name, args_spec_list, 1); - auto input_x = CheckArg(op_name, args_spec_list, 2); - (void)CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat}, - op_name + "evaluator three inputs should be %s"); - - AbstractBasePtr ret = d_out->Broaden(); - auto x_shape = dyn_cast(args_spec_list[2]->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(x_shape); - - ret->set_shape(x_shape); - return ret; -} - -void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - // check dimension, x > 1, others equal 1 - const std::string op_name = primitive->name(); - for (std::size_t i = 0; i < args_spec_list.size(); ++i) { - AbstractTensorPtr arg = CheckArg(op_name, args_spec_list, i); - ShapePtr arg_shape = dyn_cast(arg->GetShapeTrack()); - if (arg_shape == nullptr) { - MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString(); - } - - if (i == 0) { - if (arg_shape->shape().size() < 2) { - MS_LOG(EXCEPTION) << op_name << " shape of args[" << i - << "] should be TensorShape with dimension greater than 1, but shape: " - << arg_shape->ToString(); - } - continue; - } - - if (arg_shape->shape().size() != 1) { - MS_LOG(EXCEPTION) << op_name << " shape of args[" << i - << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString(); - } - } -} - -AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: five tensors(x, gamma, beta, mean, variance). - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 5); - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - MS_LOG(DEBUG) << "InferImplFusedBatchNorm args0:" << args_spec_list[0]->ToString() - << ", arg1:" << args_spec_list[1]->ToString(); - FusedBatchNormCheckDim(primitive, args_spec_list); - - auto input = args_spec_list[0]; - auto input_shape = dyn_cast(input->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(input_shape); - const auto &input_shape_list = input_shape->shape(); - if (input_shape_list.size() < 2) { - MS_LOG(EXCEPTION) << "Input shape size should >= 2."; - } - - for (size_t i = 1; i < args_spec_list.size(); ++i) { - auto arg_shape = dyn_cast(args_spec_list[i]->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(arg_shape); - const auto &arg_shape_list = arg_shape->shape(); - if (arg_shape_list.size() < 1) { - MS_LOG(EXCEPTION) << "Arg shape size should >= 1."; - } - if (arg_shape_list[0] != input_shape_list[1]) { - MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0] - << ") should match the second dimension of tensor" - " param[0](which is " - << input_shape_list[1] << ")."; - } - } - auto input_tensor = CheckArg(op_name, args_spec_list, 0); - (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param 0 of FusedBatchNorm should be %s"); - - AbstractTensorPtrList tensorPtrList = std::vector(); - for (size_t i = 1; i < args_spec_list.size(); ++i) { - auto param = CheckArg(op_name, args_spec_list, i); - tensorPtrList.push_back(param); - } - (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, "param 1 to 4 of FusedBatchNorm should be %s"); - - // check validity; - auto epsilon_value = primitive->GetAttr("epsilon"); - auto momentum_value = primitive->GetAttr("momentum"); - MS_EXCEPTION_IF_NULL(epsilon_value); - MS_EXCEPTION_IF_NULL(momentum_value); - if (!epsilon_value->isa() || !momentum_value->isa()) { - MS_LOG(EXCEPTION) << "expect epsilon and momentum be float, but: epsilon: " << epsilon_value->ToString() - << ", momentum: " << momentum_value->ToString(); - } - - auto epsilon = epsilon_value->cast()->value(); - auto momentum = momentum_value->cast()->value(); - - if (epsilon > 1.0f || epsilon <= 0.0f) { - MS_LOG(EXCEPTION) << "expect epsilon is greater than 0 and less or equal than 1, but epsilon: " << epsilon; - } - if (momentum > 1.0f || momentum < 0.0f) { - MS_LOG(EXCEPTION) << "expect momentum is great or equal than 0 and less or equal than 1, but epsilon: " << momentum; - } - - // Outputs: y, running_mean, running_variance, save_mean, save_inv_variance. - AbstractBasePtr y = input->Broaden(); - AbstractBasePtr other = args_spec_list[1]->Broaden(); - MS_LOG(DEBUG) << "output y: " << y->ToString() << ", other: " << other->ToString(); - - AbstractBasePtrList elements = {y, other, other, other, other}; - return std::make_shared(elements); -} - -AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance). - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - MS_EXCEPTION_IF_NULL(args_spec_list[2]); - MS_EXCEPTION_IF_NULL(args_spec_list[3]); - - CheckArgsSize(primitive->name(), args_spec_list, 5); - auto dx = args_spec_list[1]->Broaden(); - auto dscale = args_spec_list[2]->Broaden(); - auto dbias = args_spec_list[3]->Broaden(); - - AbstractBasePtrList rets = {dx, dscale, dbias}; - return std::make_shared(rets); -} - -AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors(y_backprop, x). - CheckArgsSize(primitive->name(), args_spec_list, 2); - return args_spec_list[1]->Broaden(); -} - -AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors(doutput, input, filters). - CheckArgsSize(primitive->name(), args_spec_list, 3); - return args_spec_list[1]->Broaden(); -} - -AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors(inputs, filter, doutput). - CheckArgsSize(primitive->name(), args_spec_list, 3); - return args_spec_list[2]->Broaden(); -} - -AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: at least one tensor(y_backprop) - // Outputs: dbias - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << primitive->name() << " evaluator at least has 1 parameters, while the input size is " - << args_spec_list.size() << "."; - } - - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - ShapePtr shape_y = dyn_cast(args_spec_list[0]->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(shape_y); - std::vector y_dims = shape_y->shape(); - if (y_dims.size() < 2) { - MS_LOG(EXCEPTION) << primitive->name() << " input y backprop, dim should >= 2, while " << y_dims.size() << "."; - } - std::vector bias_dims = {y_dims[1]}; - ShapePtr ret_shape = std::make_shared(bias_dims); - AbstractBasePtr ret = args_spec_list[0]->Broaden(); - ret->set_shape(ret_shape); - return ret; -} - -AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Broaden(); -} - -AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Broaden(); -} - -AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Broaden(); -} - -AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - AbstractBasePtrList args_list; - for (size_t i = 0; i < args_spec_list.size() - 2; i++) { - args_list.push_back(args_spec_list[i]->Broaden()); - } - return std::make_shared(args_list); -} - -AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors(x, gamma, beta). - // outputs: y, mean, variance - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - auto input_x = CheckArg(op_name, args_spec_list, 0); - auto input_shape = input_x->shape(); - auto const &input_shape_list = input_shape->shape(); - const size_t input_rank = input_shape_list.size(); - if (input_rank == 0) { - MS_LOG(EXCEPTION) << "input_rank should not be zero"; - } - - // begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1 - ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis"); - int begin_norm_axis = CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1); - - ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis"); - int begin_params_axis = CheckAxis(op_name, bpa_ptr, -1, SizeToInt(input_rank) - 1); - begin_params_axis = GetPositiveAxis(begin_params_axis, input_rank); - - // the beta and gama shape should be x_shape[begin_params_axis:] - auto tensor = CheckArg(op_name, args_spec_list, 0); - auto gamma = CheckArg(op_name, args_spec_list, 1); - auto beta = CheckArg(op_name, args_spec_list, 2); - (void)CheckTensorDType(tensor, {kFloat16, kFloat32}, "input 0 of LayerNorm should be %s"); - (void)CheckTensorDType(gamma, {kFloat16, kFloat32}, "input 1 of LayerNorm should be %s"); - (void)CheckTensorDType(beta, {kFloat16, kFloat32}, "input 2 of LayerNorm should be %s"); - auto gamma_shape = dyn_cast(gamma->BuildShape()); - auto beta_shape = dyn_cast(beta->BuildShape()); - MS_EXCEPTION_IF_NULL(gamma_shape); - MS_EXCEPTION_IF_NULL(beta_shape); - - auto const &gamma_shape_list = gamma_shape->shape(); - auto const &beta_shape_list = beta_shape->shape(); - if (gamma_shape_list.empty() || beta_shape_list.empty()) { - MS_LOG(EXCEPTION) << "LayerNorm evaluator gamma or beta is a AbstractScalar that is not support."; - } - - size_t begin_params_axis_u = IntToSize(begin_params_axis); - if ((begin_params_axis_u > input_shape_list.size()) || - (gamma_shape_list.size() + begin_params_axis_u < input_shape_list.size()) || - (beta_shape_list.size() + begin_params_axis_u < input_shape_list.size())) { - MS_LOG(EXCEPTION) << "Gamma and beta shape get wrong size."; - } - for (size_t i = begin_params_axis_u; i < input_shape_list.size(); ++i) { - size_t gamma_beta_shape_dim = i - begin_params_axis_u; - if ((gamma_shape_list[gamma_beta_shape_dim] != input_shape_list[i]) || - (beta_shape_list[gamma_beta_shape_dim] != input_shape_list[i])) { - MS_LOG(EXCEPTION) << "Gamma or beta shape not match input shape, input_shape=" << input_shape->ToString() - << ", gamma_shape=" << gamma_shape->ToString() << ", beta_shape=" << beta_shape->ToString(); - } - } - - auto mean_var_shape_value = input_shape->shape(); - if (begin_norm_axis == -1) { - mean_var_shape_value[input_rank - 1] = 1; - } else { - for (size_t i = begin_norm_axis; i < input_rank; ++i) { - mean_var_shape_value[i] = 1; - } - } - - auto mean = input_x->Broaden(); - mean->set_shape(std::make_shared(mean_var_shape_value)); - auto var = input_x->Broaden(); - var->set_shape(std::make_shared(mean_var_shape_value)); - - AbstractBasePtrList args_list({input_x->Broaden(), mean, var}); - return std::make_shared(args_list); -} - -AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: five tensors(y_backprob, x, variance, mean, gamma). - // Outputs: x_backprob, gamma_backprob, beta_backprob - CheckArgsSize(primitive->name(), args_spec_list, 5); - - auto x_backprob = args_spec_list[0]->Broaden(); - auto gamma_backprob = args_spec_list[4]->Broaden(); - auto beta_backprob = args_spec_list[4]->Broaden(); - - AbstractBasePtrList args_list({x_backprob, gamma_backprob, beta_backprob}); - return std::make_shared(args_list); -} - -AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple and a tensor. - // Outputs: mask. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTuplePtr x_shape = CheckArg(op_name, args_spec_list, 0); - AbstractTensorPtr keep_prob = CheckArg(op_name, args_spec_list, 1); - - TypePtr prob_type = keep_prob->element()->BuildType(); - if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) { - MS_LOG(EXCEPTION) << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString() - << "."; - } - - auto x_shape_data = x_shape->elements(); - int count = 1; - for (std::size_t i = 0; i < x_shape->size(); ++i) { - auto value_track = x_shape_data[i]->GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - if (!value_track->isa()) { - MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int32, but " << value_track->ToString() << "."; - } - - int e_value = GetValue(value_track); - if (e_value <= 0) { - MS_LOG(EXCEPTION) << "DropOutGenMask product of x_shape should be > 0"; - } - if (std::numeric_limits::max() / count / e_value < 1) { - MS_LOG(EXCEPTION) << "integer multiply integer overflow"; - } - count = count * e_value; - } - - // convert to bytes(8 bits) mask, using round up - int n128s = count / 128; - if ((count % 128) != 0) { - n128s++; - } - int bytes_count = n128s * 16; - std::vector shape_y{bytes_count}; - - primitive->set_attr("T", kInt32); - return std::make_shared(std::make_shared(kAnyValue, kUInt8), - std::make_shared(std::vector{shape_y})); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_others.cc b/mindspore/ccsrc/operator/prim_others.cc deleted file mode 100644 index ff9ec712bb..0000000000 --- a/mindspore/ccsrc/operator/prim_others.cc +++ /dev/null @@ -1,517 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "ir/dtype.h" -#include "common/utils.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/param_validator.h" -#include "pipeline/static_analysis/prim.h" -#include "pipeline/static_analysis/utils.h" -#include "utils/symbolic.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // An object of a subclass of AbstractBase - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]; -} - -AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: An object of AbstractFunction. - CheckArgsSize(primitive->name(), args_spec_list, 1); - MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString(); - - AbstractFunctionPtr x = dyn_cast(args_spec_list[0]); - if (x == nullptr) { - return std::make_shared(args_spec_list[0]); - } - - AbstractFuncAtomPtrList jv; - auto build_jv = [&jv](const AbstractFuncAtomPtr &func) { - auto j_closure = std::make_shared(func); - jv.push_back(j_closure); - }; - x->Visit(build_jv); - - return AbstractFunction::MakeAbstractFunction(jv); -} - -class UndeterminedShapeType { - public: - explicit UndeterminedShapeType(const std::string &env_str) { - // param_name indices_shape indices_type values_shape values_type dense_shape - // export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 - // 2:Float32:3 1 2" - std::vector fields; - string tmp; - std::stringstream input(env_str); - while (std::getline(input, tmp, ':')) { - fields.push_back(tmp); - } - if (fields.size() != fields_num) { - MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size(); - } - - param_name_ = fields[0]; - - indices_shape_ = GetShape(fields[1]); - indices_type_ = StringToType(fields[2]); - - values_shape_ = GetShape(fields[3]); - values_type_ = StringToType(fields[4]); - - auto dense_shape_vec = GetShape(fields[5]); - AbstractBasePtrList dense_shape_list; - (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list), - [](const auto &elem) { return FromValue(elem, false); }); - dense_shape_ = dense_shape_list; - } - ~UndeterminedShapeType() = default; - const std::string ¶m_name() { return param_name_; } - const std::vector &indices_shape() { return indices_shape_; } - const TypePtr &indices_type() { return indices_type_; } - const std::vector &values_shape() { return values_shape_; } - const TypePtr &values_type() { return values_type_; } - const AbstractBasePtrList &dense_shape() { return dense_shape_; } - - private: - std::string param_name_; - std::vector indices_shape_; - TypePtr indices_type_; - std::vector values_shape_; - TypePtr values_type_; - AbstractBasePtrList dense_shape_; - static const size_t fields_num; - - std::vector GetShape(const std::string &shape_str); -}; -std::vector UndeterminedShapeType::GetShape(const std::string &shape_str) { - std::vector ret; - std::istringstream iss(shape_str); - int elem; - while (iss.good()) { - iss >> elem; - ret.emplace_back(elem); - } - return ret; -} -const size_t UndeterminedShapeType::fields_num = 6; - -std::unordered_map g_undetermined_configs; -void InitUndeterminedFromEnv(const std::string &sparse_shape_types) { - std::string tmp; - std::stringstream input(sparse_shape_types); - g_undetermined_configs.clear(); - while (std::getline(input, tmp, ';')) { - auto config = UndeterminedShapeType(tmp); - g_undetermined_configs.insert(std::make_pair(config.param_name(), config)); - MS_LOG(DEBUG) << "Undetermined config from env: " << tmp; - } -} - -AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - MS_EXCEPTION_IF_NULL(primitive); - // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). - CheckArgsSize(primitive->name(), args_spec_list, 3); - auto key = args_spec_list[1]; - auto dflt = args_spec_list[2]; - TypePtr type = key->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(type); - if (type->type_id() != kObjectTypeSymbolicKeyType) { - MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); - } - - if (!key->sparse_grad().empty()) { - // Will be fixed once undetermined type ready - if (g_undetermined_configs.empty()) { - auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); - MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types; - if (sparse_shape_types.empty()) { - sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2"; - } - InitUndeterminedFromEnv(sparse_shape_types); - } - - auto shape_types = g_undetermined_configs.find(key->sparse_grad()); - if (shape_types == g_undetermined_configs.end()) { - MS_LOG(EXCEPTION) << "Param " << key->ToString() - << " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES"; - } - MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString(); - AbstractBasePtrList sparse_list; - // indices - auto indices_ele = std::make_shared(kAnyValue, shape_types->second.indices_type()); - auto indices = - std::make_shared(indices_ele, std::make_shared(shape_types->second.indices_shape())); - sparse_list.emplace_back(indices); - // values - auto dout_ele = std::make_shared(kAnyValue, shape_types->second.values_type()); - auto dout = std::make_shared(dout_ele, std::make_shared(shape_types->second.values_shape())); - sparse_list.emplace_back(dout); - // dense_shape - sparse_list.emplace_back(std::make_shared(shape_types->second.dense_shape())); - return std::make_shared(sparse_list); - } - - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa()) { - auto dflt_tensor = dflt->cast(); - return std::make_shared(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); - } - if (!key->GetValueTrack()->isa()) { - return dflt; - } - ValuePtr key_value_ptr = key->GetValueTrack(); - MS_EXCEPTION_IF_NULL(key_value_ptr); - auto key_value_track = key_value_ptr->cast(); - auto expected = key_value_track->abstract(); - MS_EXCEPTION_IF_NULL(expected); - (void)expected->Join(dflt); - return expected; -} - -AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). - CheckArgsSize(primitive->name(), args_spec_list, 3); - - auto key = args_spec_list[1]; - ValuePtr key_value_ptr = key->GetValueTrack(); - MS_EXCEPTION_IF_NULL(key_value_ptr); - auto key_value_track = key_value_ptr->cast(); - if (key_value_track == nullptr) { - MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: " - << key_value_ptr->ToString(); - } - auto expected = key_value_track->abstract(); - MS_EXCEPTION_IF_NULL(expected); - return std::make_shared(kAnyValue, std::make_shared()); -} - -AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). - CheckArgsSize(primitive->name(), args_spec_list, 2); - return std::make_shared(kAnyValue, std::make_shared()); -} - -AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) { - ValuePtr name_value = prim->GetAttr("tag"); - auto name = name_value->cast(); - if (name == nullptr) { - MS_LOG(EXCEPTION) << "MakeRefKey attr tag sould be a String " << name_value->ToString() << "."; - } - auto refkey = std::make_shared(name->value()); - if (refkey == nullptr) { - MS_LOG(EXCEPTION) << "MakeRefKey std::make_shared failed"; - } - return refkey->ToAbstract(); -} - -AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: key, value, original value - if (args_spec_list.size() != 3) { - MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() - << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRefKey) { - MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); - } - auto ret = std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); - ret->set_sparse_grad(args_spec_list[2]->sparse_grad()); - ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad()); - return ret; -} - -AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: value - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "get_ref_key requires 1 parameters, while the input size is " << args_spec_list.size() << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRef) { - MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString(); - } - return args_spec_list[0]->cast()->ref(); -} - -AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: value - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size() - << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRef) { - MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); - } - return args_spec_list[0]->cast()->ref(); -} - -AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: value - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size() - << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRef) { - MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); - } - return args_spec_list[0]->cast()->ref_origin(); -} - -AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Two objects of a subclass of AbstractBase, key and value. - CheckArgsSize(primitive->name(), args_spec_list, 2); - - TypePtr type = args_spec_list[0]->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(type); - if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) { - MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString(); - } - return std::make_shared(kAnyValue, kBool); -} - -AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0"; - } - auto depends = args_spec_list[0]->Broaden(); - return depends; -} - -bool CompareShape(const std::vector &x_shape, const std::vector &y_shape) { - if (x_shape.size() != y_shape.size()) { - return false; - } - - for (size_t i = 0; i < x_shape.size(); ++i) { - if (GetValue(x_shape[i]) != GetValue(y_shape[i])) { - return false; - } - } - - return true; -} - -enum State { - SAME, - X_ONE, - Y_ONE, -}; - -void ComputeReduceIndex(const std::vector &reverse_x, const std::vector &reverse_y, - std::vector *grad_x_reduce_idx, std::vector *grad_y_reduce_idy) { - const size_t n = reverse_x.size(); - for (size_t i = 0; i < n; ++i) { - State curr; - const int32_t x_i = reverse_x[i]; - const int32_t y_i = reverse_y[i]; - const int reduce_idx = SizeToInt(n - 1 - i); - if (x_i == y_i) { - curr = SAME; - } else if (x_i == 1) { - grad_x_reduce_idx->push_back(reduce_idx); - curr = X_ONE; - } else if (y_i == 1) { - grad_y_reduce_idy->push_back(reduce_idx); - curr = Y_ONE; - } else { - MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs"; - } - if (curr == SAME && x_i == 1) { - grad_x_reduce_idx->push_back(reduce_idx); - grad_y_reduce_idy->push_back(reduce_idx); - continue; - } - } - - std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); - std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); -} - -AbstractBasePtr BroadcastGradientArgsDiff(const std::vector &x_shape, const std::vector &y_shape) { - std::vector reverse_x; - std::vector reverse_y; - - (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x), - [](const ValuePtr &v) { return v->cast()->value(); }); - (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y), - [](const ValuePtr &v) { return v->cast()->value(); }); - - if (reverse_x.size() > reverse_y.size()) { - reverse_y.resize(reverse_x.size(), 1); - } else { - reverse_x.resize(reverse_y.size(), 1); - } - - std::vector grad_x_reduce_idx; - std::vector grad_y_reduce_idy; - ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); - - AbstractBasePtrList abs_list_x; - AbstractBasePtrList abs_list_y; - (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x), - [](int v) { return abstract::FromValue(v); }); - (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y), - [](int v) { return abstract::FromValue(v); }); - auto x_reduce_idx = std::make_shared(abs_list_x); - auto y_reduce_idx = std::make_shared(abs_list_y); - AbstractBasePtrList elem_list; - elem_list.push_back(x_reduce_idx); - elem_list.push_back(y_reduce_idx); - - return std::make_shared(elem_list); -} - -AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // this primitive get the index that need to reduce - // input: x's shape and y's shape, inputs should be tuple - // output: tuple of x and y 's reduce index, reduce index should be a tuple - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - auto arg_x = CheckArg(op_name, args_spec_list, 0); - auto arg_y = CheckArg(op_name, args_spec_list, 1); - - ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(arg_x_value); - - ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(arg_y_value); - - const std::vector x_shape = arg_x_value->value(); - const std::vector y_shape = arg_y_value->value(); - bool is_same_shape = CompareShape(x_shape, y_shape); - // if it is the same shape , do not need reduce , return empty tuple - if (is_same_shape) { - AbstractBasePtrList empty_list; - auto x_reduce_idx = std::make_shared(empty_list); - auto y_reduce_idx = std::make_shared(empty_list); - - AbstractBasePtrList elem_list; - elem_list.push_back(x_reduce_idx); - elem_list.push_back(y_reduce_idx); - - return std::make_shared(elem_list); - } - - return BroadcastGradientArgsDiff(x_shape, y_shape); -} - -AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Two objects of a subclass of AbstractBase - CheckArgsSize(primitive->name(), args_spec_list, 2); - auto arg_src = args_spec_list[0]; - auto arg_dst = args_spec_list[1]; - // control depend can not setup tuple of ops to tuple of ops dependency relation - if (arg_src->isa() && arg_dst->isa()) { - auto src_size = arg_src->cast()->size(); - auto dst_size = arg_src->cast()->size(); - if (src_size > 1 && dst_size > 1) { - MS_LOG(EXCEPTION) << "Control depend can not setup operator dependcy relationship from tuple from tuple"; - } - } - return std::make_shared(kAnyValue, kBool); -} - -AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - auto indices = CheckArg(op_name, args_spec_list, 0); - auto values = CheckArg(op_name, args_spec_list, 1); - auto dense_shape = CheckArg(op_name, args_spec_list, 2); - - auto dense_shape_value = dense_shape->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(dense_shape_value); - auto shp = dense_shape_value->value(); - std::vector dense_shape_vec; - (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), - [](const ValuePtr &e) -> int { - auto elem = GetValue(e); - return elem; - }); - auto ret = std::make_shared(values->element()->BuildType(), dense_shape_vec); - ret->set_indices(indices); - ret->set_values(values); - ret->set_dense_shape(dense_shape); - return ret; -} - -AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->values()); - return indexed_slices->values(); -} - -AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->indices()); - return indexed_slices->indices(); -} - -AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape()); - return indexed_slices->dense_shape(); -} - -AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - bool ret = false; - if (args_spec_list[0]->isa()) { - ret = true; - } - MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString(); - return std::make_shared(ret); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_statement.cc b/mindspore/ccsrc/operator/prim_statement.cc deleted file mode 100644 index fc40e511e1..0000000000 --- a/mindspore/ccsrc/operator/prim_statement.cc +++ /dev/null @@ -1,244 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/param_validator.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/utils.h" -#include "utils/symbolic.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a pointer to an AbstractBase object - if (args_spec_list.size() != 1) { - MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? " - "while the input size is " - << args_spec_list.size() << "."; - } - AbstractBasePtr abs_base = args_spec_list[0]; - return abs_base; -} - -AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a pointer to an AbstractBase object - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size() - << "."; - } - AbstractBasePtr abs_base = args_spec_list[0]; - MS_EXCEPTION_IF_NULL(abs_base); - TypePtr type = abs_base->BuildType(); - return std::make_shared(type); -} - -AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a pointer to an AbstractBase object and a pointer to a Type - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTypePtr abs_type = CheckArg(op_name, args_spec_list, 1); - - auto mode_v = abs_type->GetValueTrack(); - MS_EXCEPTION_IF_NULL(mode_v); - if (!mode_v->isa()) { - MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed."; - } - - TypePtr mode_t = mode_v->cast(); - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - bool v = IsSubtype(args_spec_list[0], mode_t); - return std::make_shared(std::make_shared(v), kBool); -} - -AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTensorPtr input_x = CheckArg(op_name, args_spec_list, 0); - AbstractTensorPtr input_y = CheckArg(op_name, args_spec_list, 1); - - ShapePtr x_shp = input_x->shape(); - auto x_shp_value = x_shp->shape(); - ShapePtr y_shp = input_y->shape(); - auto y_shp_value = y_shp->shape(); - // Should be matrix which shape size is 2. - if (x_shp_value.size() != 2 || y_shp_value.size() != 2) { - MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are " - << x_shp_value.size() << ", " << y_shp_value.size() << " "; - } - if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) { - MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}"; - } - - auto x_element = input_x->element(); - MS_EXCEPTION_IF_NULL(x_element); - (void)x_element->Join(input_y->element()); - auto param = {x_shp_value[0], y_shp_value[1]}; - - return std::make_shared(input_x->element(), std::make_shared(param)); -} - -AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // Inputs: condition, true branch, false branch - if (args_spec_list.size() != 3) { - MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size() - << "."; - } - - auto cond = args_spec_list[0]; - auto tb = args_spec_list[1]; - auto fb = args_spec_list[2]; - MS_EXCEPTION_IF_NULL(cond); - - ValuePtr v = cond->GetValueTrack(); - MS_EXCEPTION_IF_NULL(v); - // for tensor as condition, keeps both true and false branch. - if (v->isa() || cond->isa()) { - MS_EXCEPTION_IF_NULL(tb); - return tb->Join(fb); - } - - if (v->isa()) { - if (v->cast()->IsOne()) { - return tb; - } else { - return fb; - } - } - - MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString(); -} - -AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: index, branch - const std::string op_name = primitive->name(); - abstract::CheckArgsSize(op_name, args_spec_list, 2); - (void)CheckArg(op_name, args_spec_list, 0); - AbstractTuplePtr branches_abs = CheckArg(op_name, args_spec_list, 1); - AbstractBasePtrList branches = branches_abs->elements(); - const size_t maximum_layer_num = 1000; - if (branches.size() < 0 || branches.size() > maximum_layer_num) { - MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got " - << branches.size() << " branches."; - } - - for (size_t i = 0; i < branches.size(); i++) { - MS_EXCEPTION_IF_NULL(branches[i]); - if (!branches[i]->isa()) { - MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got " - << branches[i]->ToString() << " as the " << i << "th element."; - } - } - - auto b = branches[0]; - for (size_t i = 1; i < branches.size(); i++) { - b = b->Join(branches[i]); - } - return b; -} - -std::vector GetSupportedTargetValue() { - std::vector list = {kNone, MakeValue(false), MakeValue(true)}; - return list; -} - -bool SupportedIsTargetValue(const ValuePtr t) { - auto list = GetSupportedTargetValue(); - auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; }); - return match; -} - -AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: x is t - // Inputs: x, t - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - ValuePtr t = args_spec_list[1]->BuildValue(); - if (!SupportedIsTargetValue(t)) { - MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString() - << " for statement is, supported list is:None, False, True "; - } - ValuePtr x = args_spec_list[0]->BuildValue(); - - return std::make_shared(*t == *x); -} - -AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: x is not t - // Inputs: x, t - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - ValuePtr t = args_spec_list[1]->BuildValue(); - if (!SupportedIsTargetValue(t)) { - MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString() - << " for statement is not, supported list is:None, False, True "; - } - ValuePtr x = args_spec_list[0]->BuildValue(); - - return std::make_shared(!(*t == *x)); -} - -bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - auto key = CheckArg(op_name, args_spec_list, 0); - auto dict = CheckArg(op_name, args_spec_list, 1); - - ValuePtr key_value = key->BuildValue(); - if (!key_value->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); - } - auto key_str = GetValue(key_value); - std::vector dict_elems = dict->elements(); - auto it = std::find_if(dict_elems.begin(), dict_elems.end(), - [key_str](const AbstractAttribute &item) { return item.first == key_str; }); - return it != dict_elems.end(); -} - -AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: x in t - // Inputs: x, t - return std::make_shared(IsInDict(primitive, args_spec_list)); -} - -AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: x not in t - // Inputs: x, t - return std::make_shared(!IsInDict(primitive, args_spec_list)); -} - -AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: isconstant(x) - // Inputs: x - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1"; - } - ValuePtr v = args_spec_list[0]->BuildValue(); - return std::make_shared(!v->isa()); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_to_function.cc b/mindspore/ccsrc/operator/prim_to_function.cc deleted file mode 100644 index 733cdbdb73..0000000000 --- a/mindspore/ccsrc/operator/prim_to_function.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/prim_to_function.h" -#include -#include -#include - -namespace mindspore { -// namespace to support prim related definition -namespace prim { - -PrimToFunction::PrimToFunction() - : prim_func_type_map_({// ONE_ARG prim - {"bool_not", kPrimTypeOneArg}, - {"scalar_cos", kPrimTypeOneArg}, - {"scalar_exp", kPrimTypeOneArg}, - {"scalar_floor", kPrimTypeOneArg}, - {"scalar_log", kPrimTypeOneArg}, - {"scalar_sin", kPrimTypeOneArg}, - {"scalar_tan", kPrimTypeOneArg}, - {"scalar_trunc", kPrimTypeOneArg}, - {"typeof", kPrimTypeOneArg}, - {"scalar_uadd", kPrimTypeOneArg}, - {"scalar_usub", kPrimTypeOneArg}, - // TWO_ARGS prim - {"scalar_add", kPrimTypeTwoArgs}, - {"bool_and", kPrimTypeTwoArgs}, - {"bool_eq", kPrimTypeTwoArgs}, - {"bool_or", kPrimTypeTwoArgs}, - {"scalar_div", kPrimTypeTwoArgs}, - {"scalar_eq", kPrimTypeTwoArgs}, - {"scalar_ge", kPrimTypeTwoArgs}, - {"scalar_gt", kPrimTypeTwoArgs}, - {"scalar_le", kPrimTypeTwoArgs}, - {"scalar_lt", kPrimTypeTwoArgs}, - {"scalar_ne", kPrimTypeTwoArgs}, - {"scalar_mod", kPrimTypeTwoArgs}, - {"scalar_mul", kPrimTypeTwoArgs}, - {"scalar_pow", kPrimTypeTwoArgs}, - {"scalar_sub", kPrimTypeTwoArgs}, - {"scalar_floordiv", kPrimTypeTwoArgs}}) {} - -bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { - bool result = false; - - if (func != nullptr) { - int args_num = GetPrimType(prim); - std::vector one_arg{std::make_shared()}; - std::vector two_args{std::make_shared(), std::make_shared()}; - TypePtr retval = std::make_shared(); - result = true; - switch (args_num) { - case kPrimTypeOneArg: - *func = Function(one_arg, retval).DeepCopy()->cast(); - break; - case kPrimTypeTwoArgs: - *func = Function(two_args, retval).DeepCopy()->cast(); - break; - default: - result = false; - break; - } - } - - return result; -} - -int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { - MS_EXCEPTION_IF_NULL(prim); - int prim_type = static_cast(kPrimTypeUnknown); - - auto value = prim_func_type_map_.find(prim->name()); - if (value != prim_func_type_map_.end()) { - prim_type = value->second; - } - return prim_type; -} -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/CMakeLists.txt b/mindspore/ccsrc/optimizer/CMakeLists.txt deleted file mode 100644 index 44af01735a..0000000000 --- a/mindspore/ccsrc/optimizer/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB_RECURSE _OPTIMIZER_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_OPTIMIZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_OPTIMIZER) -add_library(_mindspore_optimizer_obj OBJECT ${_OPTIMIZER_SRC_FILES}) diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.cc b/mindspore/ccsrc/optimizer/ad/adjoint.cc deleted file mode 100644 index ed89aba20e..0000000000 --- a/mindspore/ccsrc/optimizer/ad/adjoint.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/ad/adjoint.h" - -#include -#include - -#include "ir/anf.h" -#include "optimizer/ad/dfunctor.h" - -namespace mindspore { -namespace ad { -Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) - : primal_(primal), caller_(caller), dout_(nullptr) { - if (k != nullptr) { - k_ = k; - MS_LOG(DEBUG) << "Add adjoint for " << primal->ToString() << " " << k_->ToString(); - } else { - // Init k hole in a recursive case. - auto k_hole = std::make_shared("k_hole"); - (void)k_hole->AddAttr("info", MakeValue(primal->ToString())); - k_ = NewValueNode(k_hole); - MS_LOG(DEBUG) << "Add hole for " << primal->ToString() << " " << k_->ToString(); - } - - dout_hole_ = caller_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), k_}); - RegisterKUser(dout_hole_->cast(), 1); -} - -AnfNodePtr Adjoint::k() { return k_; } - -void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } - -void Adjoint::UpdateK(const AnfNodePtr &new_k) { - MS_EXCEPTION_IF_NULL(new_k); - MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); - // In recursive case, it needs update. - for (auto &user : k_user_) { - MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" - << new_k->ToString(); - if (user.first->input(user.second) != k_) { - MS_LOG(EXCEPTION) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k " - << new_k->ToString() << ", user relation is set wrongly"; - } - user.first->set_input(user.second, new_k); - } - k_ = new_k; -} - -AnfNodePtr Adjoint::primal() { return primal_; } - -AnfNodePtr Adjoint::dout() { return dout_hole_; } - -void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { - dout_user_.emplace_back(std::make_pair(user, index)); -} - -void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { - if (dout_ != nullptr) { - MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); - auto add = prim::GetPythonOps("hyper_add"); - dout_ = caller_->NewCNode({NewValueNode(add), dout_, dout_factor}); - return; - } - dout_ = dout_factor; -} - -void Adjoint::CallDoutHole() { - if (dout_ != nullptr) { - for (auto &user : dout_user_) { - MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " - << dout_->ToString(); - if (user.first->input(user.second) != dout_hole_) { - MS_LOG(EXCEPTION) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " - << dout_->ToString() << ", user relation is set wrongly"; - } - user.first->set_input(user.second, dout_); - } - } -} -} // namespace ad -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.h b/mindspore/ccsrc/optimizer/ad/adjoint.h deleted file mode 100644 index b2dae8e66f..0000000000 --- a/mindspore/ccsrc/optimizer/ad/adjoint.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ - -#include -#include -#include - -#include "ir/anf.h" -#include "optimizer/opt.h" - -namespace mindspore { -namespace ad { -class Adjoint { - public: - Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); - ~Adjoint() = default; - AnfNodePtr primal(); - AnfNodePtr k(); - void UpdateK(const AnfNodePtr &k); - void RegisterKUser(const CNodePtr &user, size_t index); - AnfNodePtr dout(); - void AccumulateDout(const AnfNodePtr &dout_factor); - void RegisterDoutUser(const CNodePtr &user, size_t index); - void CallDoutHole(); - - private: - AnfNodePtr primal_; - FuncGraphPtr caller_; - // For ```def f(x): return expr```, The representation graph k is ```def kf(kx): return expr, bprop{expr}```. - AnfNodePtr k_; - std::vector> k_user_; - AnfNodePtr dout_; - AnfNodePtr dout_hole_; - std::vector> dout_user_; -}; - -using AdjointPtr = std::shared_ptr; -} // namespace ad -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/optimizer/ad/dfunctor.cc deleted file mode 100644 index f9c056a84e..0000000000 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.cc +++ /dev/null @@ -1,617 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/ad/dfunctor.h" - -#include -#include -#include - -#include "ir/anf.h" -#include "ir/meta_func_graph.h" -#include "debug/info.h" -#include "ir/func_graph_cloner.h" -#include "ir/manager.h" -#include "pipeline/resource.h" -#include "pipeline/parse/parse.h" -#include "optimizer/ad/adjoint.h" -#include "optimizer/opt.h" -#include "operator/ops.h" -#include "operator/composite/composite.h" -#include "utils/symbolic.h" -#include "utils/context/ms_context.h" -#include "./common.h" - -namespace mindspore { -namespace ad { -std::unordered_map DFunctor::func_graph_to_functor_; -std::unordered_map DFunctor::anfnode_to_adjoin_definition_; -FuncGraphSet DFunctor::scope_; - -DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources) - : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { - TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info())); - k_graph_ = std::make_shared(); - if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - std::string grad_op_name = GetValue(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); - } - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info())); - tape_ = std::make_shared(); - // Add "_Grad" postfix - if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - std::string grad_op_name = GetValue(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad"; - tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); - } - TraceManager::EndTrace(); - - dout_ = tape_->add_parameter(); -} - -void DFunctor::Init(bool is_top) { - func_graph_to_functor_[primal_graph_] = shared_from_this(); - is_top_ = is_top; - if (is_top) { - scope_ = primal_graph_->scope(); - } -} - -void DFunctor::Clear() { - func_graph_to_functor_.clear(); - anfnode_to_adjoin_definition_.clear(); - scope_.clear(); -} - -void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { - auto fv_adjoint = anfnode_to_adjoin_.find(fv); - if (fv_adjoint == anfnode_to_adjoin_.end()) { - MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() - << " " << fv->ToString() << "."; - fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); - if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) { - MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv " - << fv->func_graph()->ToString() << " " << fv->ToString() << "."; - auto parent_adjoint = FindAdjoint(fv); - AdjointPtr adjoint = nullptr; - if (parent_adjoint != nullptr) { - adjoint = std::make_shared(fv, parent_adjoint->k(), tape_); - } else { - MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole " - << fv->func_graph()->ToString() << " " << fv->ToString() << "."; - adjoint = std::make_shared(fv, nullptr, tape_); - } - anfnode_to_adjoin_indirect_fv_[fv] = adjoint; - fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); - } - } - auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()}); - fv_adjoint->second->RegisterKUser(key, 1); - auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()}); - fv_adjoint->second->RegisterKUser(default_val, 1); - auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, key, default_val}); - MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv " - << fv->func_graph()->ToString() << " " << fv->ToString() << "."; - MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << key->ToString() << "."; - fv_adjoint->second->AccumulateDout(dfv); -} - -void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) { - // Take switch_layer as a set of candidate functions. - auto input = cnode_morph->input(2); - if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { - MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << "."; - } - auto tuple_graphs = input->cast(); - for (size_t i = 1; i < tuple_graphs->size(); ++i) { - auto graph = tuple_graphs->input(i); - if (!IsValueNode(graph)) { - MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString() - << " as the " << i << "th element."; - } - auto func_graph = GetValueNode(graph); - auto functor = func_graph_to_functor_.find(func_graph); - if (functor == func_graph_to_functor_.end()) { - MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] " - << func_graph->ToString() << "."; - } - // Consider direct and indirect fvs. - for (auto fv : func_graph->free_variables_nodes()) { - BackPropagateFv(fv, env); - } - for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { - MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " " - << indirect_fv.first->ToString() << "."; - BackPropagateFv(indirect_fv.first, env); - } - } -} - -void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) { - auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)}); - // Call with delimited continuation dout. - auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()}); - node_adjoint->RegisterDoutUser(bprop_app, 1); - // Special case for switch_layer - if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) { - auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)}); - BackPropagateSwitchLayer(cnode_morph, din); - return; - } - for (size_t i = 0; i < cnode_morph->size(); i++) { - auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))}); - auto input = cnode_morph->input(i); - // Backprop sens wrt fvs. - if (IsValueNode(input)) { - auto func_graph = GetValueNode(input); - auto functor = func_graph_to_functor_.find(func_graph); - if (functor == func_graph_to_functor_.end()) { - MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] " - << func_graph->ToString() << "."; - } - // Consider direct and indirect fvs. - for (auto fv : func_graph->free_variables_nodes()) { - BackPropagateFv(fv, din); - } - for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { - MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " " - << indirect_fv.first->ToString() << "."; - BackPropagateFv(indirect_fv.first, din); - } - continue; - } - // Backprop sens wrt inputs. - auto input_adjoint = anfnode_to_adjoin_.find(input); - if (input_adjoint == anfnode_to_adjoin_.end()) { - MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << "."; - } - input_adjoint->second->AccumulateDout(din); - } -} - -// Map a morphism. -AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { - // MapMorphism All type except CNode should already be mapped by MapObject. - if (!morph->isa()) { - return nullptr; - } - ScopeGuard scope_guard(morph->scope()); - auto cnode_morph = morph->cast(); - - std::vector inputs; - std::vector param_adjoints; - for (size_t i = 0; i < cnode_morph->size(); i++) { - auto node = cnode_morph->input(i); - auto node_adjoint_iter = anfnode_to_adjoin_.find(node); - AdjointPtr node_adjoint = nullptr; - AnfNodePtr k = nullptr; - if (node_adjoint_iter != anfnode_to_adjoin_.end()) { - node_adjoint = node_adjoint_iter->second; - } else { - // Input might be a CNode that needs to be handled before hand. - node_adjoint = MapMorphism(node); - } - MS_EXCEPTION_IF_NULL(node_adjoint); - k = node_adjoint->k(); - if (k == nullptr) { - MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; - } - inputs.push_back(k); - param_adjoints.push_back(node_adjoint); - } - TraceManager::DebugTrace(std::make_shared(cnode_morph->debug_info())); - auto k_app = k_graph_->NewCNode(inputs); - TraceManager::EndTrace(); - for (size_t i = 0; i < param_adjoints.size(); ++i) { - param_adjoints[i]->RegisterKUser(k_app, i); - } - - // Do forward computation - auto foward_app = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(0)}); - // K:: cnode -> forward_app - auto node_adjoint = std::make_shared(morph, foward_app, tape_); - UpdateAdjoint(node_adjoint); - anfnode_to_adjoin_[morph] = node_adjoint; - if (cnode_morph->stop_gradient()) { - MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped."; - return node_adjoint; - } - - // Do sens backpropagation - BackPropagate(cnode_morph, k_app, node_adjoint); - MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << "."; - return node_adjoint; -} - -bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) { - // Do not care about non-CNode - if (!node->isa()) { - return false; - } - // Do not care about kPrimReturn - if (IsPrimitiveCNode(node, prim::kPrimReturn)) { - return false; - } - auto &users = primal_graph_->manager()->node_users()[node]; - // Do not care about isolated morphisms - if (users.empty()) { - return false; - } - // Not free if it's used by some node in primal_graph - bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) { - auto &user = kv.first; - return user->func_graph() == primal_graph_; - }); - return !nonfree; -} - -void DFunctor::MapFreeMorphism() { - // Handle cnode not attached to output, that might be refered in other functions. - for (auto &node : primal_graph_->nodes()) { - if (!IsFreeMorphism(node)) { - continue; - } - MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << "."; - (void)MapMorphism(node); - } -} - -AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { - AnfNodePtr new_grad_fv = grad_fv; - // Add grads wrt fv. - const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); - for (auto &fv : free_variables_nodes) { - auto fv_adjoint = anfnode_to_adjoin_.find(fv); - if (fv_adjoint == anfnode_to_adjoin_.end()) { - MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << "."; - } - auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()}); - fv_adjoint->second->RegisterKUser(key, 1); - auto sens = fv_adjoint->second->dout(); - new_grad_fv = tape_->NewCNode({ - NewValueNode(prim::kPrimEnvSetItem), - new_grad_fv, - key, - sens, - }); - fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast(), 3); - MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " " - << fv->ToString() << " " << primal_graph_->ToString() << "."; - } - return new_grad_fv; -} - -AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) { - AnfNodePtr new_grad_fv = grad_fv; - // Add indirect fv bprop. - for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) { - MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " " - << primal_graph_->ToString() << "."; - auto key = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()}); - fv_adjoint.second->RegisterKUser(key, 1); - auto sens = fv_adjoint.second->dout(); - new_grad_fv = tape_->NewCNode({ - NewValueNode(prim::kPrimEnvSetItem), - new_grad_fv, - key, - sens, - }); - fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast(), 3); - MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to " - << new_grad_fv->ToString() << "."; - } - return new_grad_fv; -} - -void DFunctor::MapMorphism() { - // Set stop_gradient before MapMorphism. - BroadCastStopFlag(); - - // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent - MapFreeMorphism(); - // Handle morphism from output. - (void)MapMorphism(primal_graph_->output()); - - // Construct K for primal_graph_ - auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output()); - // Attach dout_ parameter to output_adjoint. - output_adjoint->second->AccumulateDout(dout_); - - // Set output for tape closure. - auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv))); - - std::vector inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv}; - // Add grads wrt inputs. - std::vector param_adjoints; - for (auto ¶m : primal_graph_->parameters()) { - auto param_adjoint = anfnode_to_adjoin_.find(param); - inputs.push_back(param_adjoint->second->dout()); - param_adjoints.push_back(param_adjoint->second); - } - auto tape_output = tape_->NewCNode(inputs); - for (size_t i = 0; i < param_adjoints.size(); ++i) { - param_adjoints[i]->RegisterDoutUser(tape_output, i + 2); - } - tape_->set_output(tape_output); - // Set output for k_graph_, K:: cnode->forward_app. - auto forward_app = output_adjoint->second->k(); - auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)}); - output_adjoint->second->RegisterKUser(output, 1); - k_graph_->set_output(output); - (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_))); - (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_))); -} - -FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { - // K user defined cell bprop. - auto bprop = primal->transforms().find("bprop"); - if (bprop != primal->transforms().end()) { - FuncGraphPtr bprop_graph = bprop->second.func_graph(); - resources_->manager()->AddFuncGraph(bprop_graph); - - if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) { - MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope " - << primal->output()->scope()->name() << " does not support Parameter data type."; - } - auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph); - if (fg == nullptr) { - MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope " - << primal->output()->scope()->name() << "."; - } - - // Cache the grad func - (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg))); - (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal))); - // Reset defer_inline to enable successive inlining - primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); - - auto functor = std::make_shared(primal, resources_); - functor->Init(); - functor->k_graph_ = fg; - - return fg; - } - return nullptr; -} - -// MapToK(func) -AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { - auto f = func_graph_to_functor_.find(primal); - if (f != func_graph_to_functor_.end()) { - MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << "."; - return NewValueNode(f->second->k_graph_); - } - - auto k_user_defined = KUserDefined(primal); - if (k_user_defined != nullptr) { - MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << "."; - return NewValueNode(k_user_defined); - } - - auto functor = std::make_shared(primal, resources_); - functor->Init(); - functor->MapObject(); - functor->MapMorphism(); - - MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << "."; - return NewValueNode(functor->k_graph_); -} - -// Construct representation graph for given node. -AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { - ScopeGuard scope_guard(primal->scope()); - // MapToK(prim) - if (IsValueNode(primal)) { - auto value_node = primal->cast(); - auto prim = GetValueNode(value_node); - if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { - MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; - need_cut_ = true; - } - auto k_prim = g_k_prims.KPrimitive(value_node, resources_); - if (k_prim != nullptr) { - return NewValueNode(k_prim); - } - // When failed to find k_prim, try k_meta. - auto k_meta = g_k_prims.KMetaFuncGraph(prim); - if (k_meta != nullptr) { - return NewValueNode(k_meta); - } - } - - // MapToK(func) - if (IsValueNode(primal)) { - auto func_graph = GetValueNode(primal); - auto k_func = MapToK(func_graph); - return k_func; - } - - if (primal->isa()) { - TraceManager::DebugTrace(std::make_shared(primal->debug_info())); - auto ret = k_graph_->add_parameter(); - TraceManager::EndTrace(); - return ret; - } - - if (!primal->isa()) { - MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode."; - } - return primal; -} - -bool DFunctor::IsInScope(const AnfNodePtr &node) { - return std::any_of(scope_.begin(), scope_.end(), - [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; }); -} - -void DFunctor::MapFvObject() { - // Map free variable. - const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); - for (auto &node : free_variables_nodes) { - ScopeGuard scope_guard(node->scope()); - MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << "."; - // Find fv's K from parent. - AdjointPtr adjoint = nullptr; - auto parent_adjoint = FindAdjoint(node); - if (parent_adjoint != nullptr) { - adjoint = std::make_shared(node, parent_adjoint->k(), tape_); - } else { - if (is_top_ || node->isa() || !IsInScope(node)) { - // Out of ad scope, add adjoint for free variables. - adjoint = std::make_shared(node, node, tape_); - UpdateAdjoint(adjoint); - } else { - MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << "."; - adjoint = std::make_shared(node, nullptr, tape_); - } - } - if (adjoint == nullptr) { - MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << "."; - } - anfnode_to_adjoin_[node] = adjoint; - } -} - -void DFunctor::MapParamObject() { - // Map parameter. - for (auto &p : primal_graph_->parameters()) { - ScopeGuard scope_guard(p->scope()); - MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << "."; - auto adjoint = std::make_shared(p, MapToK(p), tape_); - UpdateAdjoint(adjoint); - anfnode_to_adjoin_[p] = adjoint; - } -} - -void DFunctor::MapValueObject() { - // Map ValueNode. - auto manager = resources_->manager(); - auto &value_nodes = primal_graph_->value_nodes(); - for (const auto &value_pair : value_nodes) { - auto node = value_pair.first; - auto parent_adjoint = FindAdjoint(node); - if (parent_adjoint != nullptr) { - auto adjoint = std::make_shared(node, parent_adjoint->k(), tape_); - anfnode_to_adjoin_[node] = adjoint; - continue; - } - // Skip Return. - if (IsValueNode(node) && GetValueNode(node) == prim::kPrimReturn) { - continue; - } - MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << "."; - auto adjoint = std::make_shared(node, MapToK(node), tape_); - UpdateAdjoint(adjoint); - anfnode_to_adjoin_[node] = adjoint; - } -} - -// Skip morphism. -void DFunctor::MapObject() { - // The order does not matter - MapFvObject(); - MapParamObject(); - MapValueObject(); -} - -void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) { - auto primal = adjoint_definition->primal(); - if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) { - MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " " - << primal->ToString() << "."; - } - anfnode_to_adjoin_definition_[primal] = adjoint_definition; - // Update k hole for primal. - for (auto &f : func_graph_to_functor_) { - auto adjoint = f.second->anfnode_to_adjoin_.find(primal); - if (adjoint != f.second->anfnode_to_adjoin_.end()) { - adjoint->second->UpdateK(adjoint_definition->k()); - } - adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal); - if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) { - adjoint->second->UpdateK(adjoint_definition->k()); - } - } -} - -AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) { - auto adjoint = anfnode_to_adjoin_definition_.find(primal); - if (adjoint != anfnode_to_adjoin_definition_.end()) { - MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << "."; - return adjoint->second; - } - MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << "."; - return nullptr; -} - -void DFunctor::CallDoutHoleOnTape() { - if (!is_top_) { - return; - } - - // Call dout hole of all adjoint. - for (auto &f : func_graph_to_functor_) { - for (auto &adjoint : f.second->anfnode_to_adjoin_) { - adjoint.second->CallDoutHole(); - } - for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) { - adjoint.second->CallDoutHole(); - } - } -} -FuncGraphPtr DFunctor::k_graph() { - CallDoutHoleOnTape(); - return k_graph_; -} - -void DFunctor::BroadCastStopFlag() { - // As stop set expanding, all directly or indirectly stopped CNode will be cut off - while (need_cut_) { - need_cut_ = false; - for (auto &node : primal_graph_->nodes()) { - if (node->isa()) { - auto cnode = node->cast(); - if (!cnode->stop_gradient()) { - // Cut off the cnode only when it's not referred any more - if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || AllReferencesStopped(cnode)) { - MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << "."; - cnode->set_stop_gradient(true); - // The stop set changed, more cut required - need_cut_ = true; - } - } - } - } - } -} - -bool DFunctor::AllReferencesStopped(const CNodePtr &node) { - auto &users = primal_graph_->manager()->node_users()[node]; - // Only care about stop_gradient caused cutting - if (users.empty()) { - return false; - } - for (auto &kv : users) { - auto &user = kv.first; - if (!user->isa() || !user->cast()->stop_gradient()) { - return false; - } - } - return true; -} -} // namespace ad -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.h b/mindspore/ccsrc/optimizer/ad/dfunctor.h deleted file mode 100644 index 4fa9cf6bb5..0000000000 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.h +++ /dev/null @@ -1,228 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ - -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/meta_func_graph.h" -#include "ir/func_graph_cloner.h" -#include "pipeline/resource.h" -#include "optimizer/ad/adjoint.h" -#include "operator/ops.h" -#include "debug/trace.h" - -namespace mindspore { -namespace ad { -struct PrimitiveTotalEqual { - bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { - if (t1->name() != t2->name()) { - return false; - } - - auto const &attrs1 = t1->attrs(); - auto const &attrs2 = t2->attrs(); - if (attrs1.size() != attrs2.size()) { - return false; - } - - for (auto &attr1 : attrs1) { - if (!t2->HasAttr(attr1.first)) { - return false; - } - - if (!(*(attr1.second) == *(t2->GetAttr(attr1.first)))) { - return false; - } - } - - return true; - } -}; - -using Registry = std::unordered_map; -class KPrim; -extern KPrim g_k_prims; -class DFunctor; -using DFunctorPtr = std::shared_ptr; - -// D Functor's rules to map closure object and morphisms. -class DFunctor : public std::enable_shared_from_this { - public: - DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources); - ~DFunctor() = default; - // Map object in D category to K category. - void MapObject(); - // Map morphism in D category to K category. - void MapMorphism(); - FuncGraphPtr k_graph(); - // Construct user defined k object. - FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); - // Register functor objects to form a global view. - void Init(bool is_top = false); - bool IsInScope(const AnfNodePtr &node); - - // Clear resources. - static void Clear(); - - private: - // Map one morphism. - AdjointPtr MapMorphism(const AnfNodePtr &morph); - bool IsFreeMorphism(const AnfNodePtr &node); - // Map morphism that's not attached to output. - void MapFreeMorphism(); - void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din); - void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env); - void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); - AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); - AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); - // Map Anfnode object from D category to K category. - AnfNodePtr MapToK(const AnfNodePtr &primal); - // Map FuncGraph object from D category to K category. - AnfNodePtr MapToK(const FuncGraphPtr &primal); - // MapObject impls. - void MapFvObject(); - void MapValueObject(); - void MapParamObject(); - // Find adjoint with its primary k. - AdjointPtr FindAdjoint(const AnfNodePtr &primal); - // Broadcast stop flags. - void BroadCastStopFlag(); - bool AllReferencesStopped(const CNodePtr &node); - // Update k hole with adjoint_definition, only applied in recursive case. - void UpdateAdjoint(const AdjointPtr &adjoint_definition); - void CallDoutHoleOnTape(); - - std::unordered_map anfnode_to_adjoin_; - // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. - std::unordered_map anfnode_to_adjoin_indirect_fv_; - FuncGraphPtr primal_graph_; - // K object for primal_graph_; - FuncGraphPtr k_graph_; - // The Backprop part of k_graph_. - FuncGraphPtr tape_; - // Dout parameter for primal_graph_. - AnfNodePtr dout_; - pipeline::ResourceBasePtr resources_; - // Cut off stopped objects in category D. - bool need_cut_; - bool is_top_; - static std::unordered_map> func_graph_to_functor_; - static std::unordered_map anfnode_to_adjoin_definition_; - static FuncGraphSet scope_; -}; - -// D Functor's rules to map primitive object. -class KPrim { - public: - KPrim() = default; - ~KPrim() = default; - - FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); - MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); - FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); - - void clear() { - bprop_registry_meta_.clear(); - bprop_registry_.clear(); - } - - private: - FuncGraphPtr GetBprop(const PrimitivePtr &prim); - FuncGraphPtr GetFprop(const PrimitivePtr &prim); - FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); - FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); - // Given a bprop rule, do the K mapping. - template - FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g); - AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); - void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, - std::vector *const transf_args); - void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check); - - Registry bprop_registry_; - std::unordered_map bprop_registry_meta_; -}; - -template -FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { - MS_EXCEPTION_IF_NULL(primal); - MS_EXCEPTION_IF_NULL(bprop_fg); - CheckBprop(bprop_fg, primal->ToString()); - - auto debug_info = std::make_shared(); - debug_info->set_name(primal->ToString()); - - auto cloned_bprop_fg = BasicClone(bprop_fg); - MS_EXCEPTION_IF_NULL(cloned_bprop_fg); - - cloned_bprop_fg->debug_info()->set_name(""); - cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared(debug_info)); - - AnfNodePtr bout = BuildOutput(cloned_bprop_fg); - cloned_bprop_fg->set_output(bout); - - TraceManager::DebugTrace(std::make_shared(debug_info)); - auto outer = std::make_shared(); - (void)outer->transforms().emplace("primal", FuncGraphTransform(primal)); - outer->set_output(NewValueNode(kNone)); - TraceManager::EndTrace(); - - auto mng = Manage({cloned_bprop_fg, outer}, false); - - // Make sure (out, dout) provided. - if (cloned_bprop_fg->parameters().size() < 2) { - MS_LOG(EXCEPTION) << "Primitive or Cell " << primal->ToString() - << " bprop requires out and dout at least, but only got " << cloned_bprop_fg->parameters().size() - << " params. NodeInfo: " << trace::GetDebugInfo(cloned_bprop_fg->debug_info()); - } - - // In a bprop definition, the last two param should be out and dout. - auto dout = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 1]; - auto out_param = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 2]; - std::vector transf_args; - TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); - - TraceManager::DebugTrace(std::make_shared(dout->debug_info())); - (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); - auto out_value = outer->NewCNode(transf_args); - TraceManager::EndTrace(); - - (void)mng->Replace(out_param, out_value); - - TraceManager::DebugTrace(std::make_shared(out_param->debug_info())); - auto new_dout = cloned_bprop_fg->add_parameter(); - (void)mng->Replace(dout, new_dout); - // We remove all parameters except new_dout. - std::vector newBpropParams = {new_dout}; - cloned_bprop_fg->set_parameters(newBpropParams); - TraceManager::EndTrace(); - - outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); - return BasicClone(outer); -} -} // namespace ad -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ diff --git a/mindspore/ccsrc/optimizer/ad/grad.cc b/mindspore/ccsrc/optimizer/ad/grad.cc deleted file mode 100644 index d141dc6eea..0000000000 --- a/mindspore/ccsrc/optimizer/ad/grad.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/ad/grad.h" -#include "optimizer/ad/dfunctor.h" -#include "ir/func_graph_cloner.h" -#include "utils/context/ms_context.h" -#include "utils/symbolic.h" -#include "utils/graph_utils.h" - -namespace mindspore { -namespace ad { -FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) { - MS_EXCEPTION_IF_NULL(func_graph); - auto gradkv = func_graph->transforms().find("grad"); - if (gradkv != func_graph->transforms().end()) { - return gradkv->second.func_graph(); - } - - auto manager_ptr = resources->manager(); - MS_EXCEPTION_IF_NULL(manager_ptr); - manager_ptr->AddFuncGraph(func_graph); - - auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { - if (MsContext::GetInstance()->is_multi_graph_sink()) { - if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { - f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); - } - } - }; - - auto f = std::make_shared(func_graph, resources); - auto user_defined = f->KUserDefined(func_graph); - if (user_defined != nullptr) { - multi_graph_sink(user_defined); - if (is_top) { - DFunctor::Clear(); - } - return user_defined; - } - f->Init(is_top); - f->MapObject(); - f->MapMorphism(); - auto ret = f->k_graph(); - if (is_top) { - DFunctor::Clear(); - } - - multi_graph_sink(ret); - return ret; -} - -FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - auto fg = g_k_prims.KPrimitive(value_node, resources); - if (fg == nullptr) { - return nullptr; - } - return BasicClone(fg); -} - -MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &) { - MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim); - return fg; -} - -void CleanRes() { DFunctor::Clear(); } -} // namespace ad -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/ad/grad.h b/mindspore/ccsrc/optimizer/ad/grad.h deleted file mode 100644 index a878aa9df7..0000000000 --- a/mindspore/ccsrc/optimizer/ad/grad.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ - -#include -#include - -#include "ir/anf.h" -#include "ir/meta_func_graph.h" -#include "pipeline/resource.h" - -namespace mindspore { -namespace ad { -using ResourcePtr = std::shared_ptr; - -FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true); -FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); -MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &); -void CleanRes(); -} // namespace ad -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc deleted file mode 100644 index 4141fb5413..0000000000 --- a/mindspore/ccsrc/optimizer/ad/kprim.cc +++ /dev/null @@ -1,294 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "ir/anf.h" -#include "ir/primitive.h" -#include "ir/meta_func_graph.h" -#include "ir/func_graph_cloner.h" -#include "ir/manager.h" -#include "pipeline/resource.h" -#include "pipeline/parse/parse.h" -#include "optimizer/ad/dfunctor.h" -#include "optimizer/opt.h" -#include "operator/ops.h" -#include "operator/composite/composite.h" -#include "utils/symbolic.h" -#include "utils/primitive_utils.h" -#include "utils/context/ms_context.h" -#include "debug/info.h" -#include "debug/trace.h" - -#include "./common.h" - -namespace mindspore { -namespace ad { -using PatternListType = std::initializer_list; -KPrim g_k_prims; - -FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { - // Set a child scope named "grad'PrimitiveName'" for the bprop function, - // and add "Gradients" to the front. - static const std::string gradients_scope = "Gradients/"; - static const std::string grad_op_child_scope_prefix = "/grad"; - MS_EXCEPTION_IF_NULL(prim); - auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + - grad_op_child_scope_prefix + prim->name()); - ScopeGuard scope_guard(scope); - py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast()->GetBpropFunction(); - if (fn == nullptr || py::isinstance(fn)) { - MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; - return nullptr; - } - FuncGraphPtr func_graph = parse::ParsePythonCode(fn); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; - return nullptr; - } - return func_graph; -} - -FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) { - static const std::string ad_module = "mindspore.ops._grad.grad_implementations"; - std::string func_name = "_fprop_" + prim->name(); - py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name); - auto func_graph = parse::ParsePythonCode(fn); - MS_EXCEPTION_IF_NULL(func_graph); - return BasicClone(func_graph); -} - -MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(prim); - - auto iter = bprop_registry_meta_.find(prim); - if (iter != bprop_registry_meta_.end()) { - return iter->second; - } - - if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { - MetaFuncGraphPtr meta = std::make_shared("make_tuple_gradient"); - bprop_registry_meta_[prim::kPrimMakeTuple] = meta; - return meta; - } - - MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; -} - -FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - if (!IsValueNode(value_node)) { - MS_LOG(EXCEPTION) << "Primitive node is not valid."; - } - - auto prim = GetValueNode(value_node); - if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) { - auto fprop = GetFprop(prim); - fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer)); - return fprop; - } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { - return nullptr; - } - - FuncGraphPtr bprop_fg = nullptr; - if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { - bprop_fg = BpropCut(value_node, resources); - } else { - auto iter = bprop_registry_.find(prim); - if (iter != bprop_registry_.end()) { - bprop_fg = iter->second; - } - - if (bprop_fg == nullptr) { - bprop_fg = GetBprop(prim); - if (bprop_fg != nullptr) { - // Set bprop_g graph cache - bprop_registry_[prim] = bprop_fg; - } else { - bprop_fg = FakeBprop(value_node, resources); - } - } - } - - auto expanded_fg = BpropToK(prim, bprop_fg); - if (expanded_fg == nullptr) { - MS_LOG(EXCEPTION) << "Failed convert " << prim->name() - << " prim bprop function to J expanded func graph. NodeInfo: " - << trace::GetDebugInfo(bprop_fg->debug_info()); - } - - return expanded_fg; -} - -AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg) { - // bprop_fg has been checked in caller - if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) { - // Set bprop output as (env, dx, dy, dz, ...) - auto cbprop = bprop_fg->output()->cast(); - auto &inputs = cbprop->inputs(); - - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - args.push_back(NewValueNode(newenv)); - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - return NewCNode(args, bprop_fg); - } - - // Set bprop output as (env, dx) - std::string model_name("mindspore.ops.composite.multitype_ops.add_impl"); - std::string python_ops("_tuple_add"); - auto tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); - return NewCNode({NewValueNode(prim::GetPythonOps(python_ops, model_name)), tuple, bprop_fg->output()}, bprop_fg); -} - -void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, - std::vector *const transf_args) { - MS_EXCEPTION_IF_NULL(mng); - // bprop_fg has been checked in caller - // transform except the last 2 parameters: out, dout. - for (size_t i = 0; i < bprop_fg->parameters().size() - 2; ++i) { - auto p = bprop_fg->parameters()[i]; - MS_EXCEPTION_IF_NULL(p); - - TraceManager::DebugTrace(std::make_shared(p->debug_info())); - auto transf_p = outer->add_parameter(); - TraceManager::EndTrace(); - - (void)mng->Replace(p, transf_p); - transf_args->push_back(transf_p); - } -} - -void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool check_bprop_flag = context->check_bprop_flag(); - // Skip checking if check_bprop not set - if (!check_bprop_flag) { - return; - } - - // bprop_fg has been checked in caller - auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops"); - MS_EXCEPTION_IF_NULL(check_bprop_class); - auto check_bprop = - bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared(prim_to_check))}); - - std::vector inputs; - inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2); - AnfNodePtr params = bprop_fg->NewCNode(inputs); - - inputs.clear(); - inputs.push_back(check_bprop); - inputs.push_back(bprop_fg->output()); - inputs.push_back(params); - AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); - bprop_fg->set_output(bprop_out); -} - -FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { - MS_EXCEPTION_IF_NULL(bprop_fg); - auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); - auto expanded_fg = BpropToK(fprop_fg, bprop_fg); - if (expanded_fg == nullptr) { - MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() - << " Cell bprop function to K expanded func graph. NodeInfo: " - << trace::GetDebugInfo(fprop_fg->debug_info()); - } - return expanded_fg; -} - -FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - auto prim = GetValueNode(value_node); - MS_EXCEPTION_IF_NULL(prim); - auto &node_users = resources->manager()->node_users(); - - auto &users = node_users[value_node]; - auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { - return IsPrimitiveCNode(user.first, prim); - }); - if (cnode == users.end()) { - MS_LOG(EXCEPTION) << "Fail to find cnode."; - } - auto inputs_num = cnode->first->cast()->size() - 1; - - auto func_graph = std::make_shared(); - std::vector outputs; - - auto bprop_cut = std::make_shared("bprop_cut", py::object()); - if (!prim->is_base()) { - PrimitivePyPtr prim_py = dyn_cast(prim); - bprop_cut->set_hook(prim_py->hook()); - } - - auto cell_id = GetValue(prim->GetAttr("cell_id")); - if (cell_id != "") { - (void)bprop_cut->AddAttr("cell_hook", MakeValue(true)); - (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id)); - } - - outputs.push_back(NewValueNode(bprop_cut)); - for (size_t i = 0; i < inputs_num; ++i) { - auto param = func_graph->add_parameter(); - outputs.push_back(param); - } - auto p1 = func_graph->add_parameter(); - auto p2 = func_graph->add_parameter(); - outputs.push_back(p1); - outputs.push_back(p2); - - func_graph->set_output(func_graph->NewCNode(outputs)); - return func_graph; -} - -FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - auto prim = value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto &node_users = resources->manager()->node_users(); - - auto &users = node_users[value_node]; - auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { - return IsPrimitiveCNode(user.first, prim); - }); - if (cnode == users.end()) { - MS_LOG(EXCEPTION) << "Fail to find cnode."; - } - auto inputs_num = cnode->first->cast()->inputs().size() - 1; - - auto func_graph = std::make_shared(); - std::vector outputs; - outputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - - auto fake_bprop = std::make_shared("fake_bprop"); - (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined.")); - - for (size_t i = 0; i < inputs_num; ++i) { - // Mock params for inputs - auto param = func_graph->add_parameter(); - // Mock derivatives for each inputs - outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param})); - } - // mock params for out and dout - (void)func_graph->add_parameter(); - (void)func_graph->add_parameter(); - func_graph->set_output(func_graph->NewCNode(outputs)); - return func_graph; -} -} // namespace ad -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc deleted file mode 100644 index bb52273568..0000000000 --- a/mindspore/ccsrc/optimizer/clean.cc +++ /dev/null @@ -1,531 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/clean.h" -#include -#include -#include -#include -#include -#include "./common.h" -#include "debug/trace.h" -#include "operator/composite/composite.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { -using mindspore::abstract::AbstractAttribute; -using mindspore::abstract::AbstractClass; -using mindspore::abstract::AbstractDictionary; -using mindspore::abstract::AbstractJTagged; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractScalar; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractUndetermined; - -static AbstractBasePtr Reabs(const AbstractBasePtr &t) { - if (t == nullptr) { - return nullptr; - } - - AbstractBasePtr res = t; - if (t->isa()) { - auto abs_class = dyn_cast(t); - AbstractBasePtrList baselist; - auto attributes = abs_class->attributes(); - (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), - [](const AbstractAttribute &item) { return item.second; }); - res = std::make_shared(baselist); - } else if (t->isa()) { - auto abs_dict = dyn_cast(t); - AbstractBasePtrList baselist; - auto elements = abs_dict->elements(); - (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), - [](const AbstractAttribute &item) { return item.second; }); - res = std::make_shared(baselist); - } else if (t->isa()) { - auto abs_dict = dyn_cast(t); - res = std::make_shared(abs_dict->elements()); - } - return res; -} - -AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - const auto &inputs = node->inputs(); - // Inputs should be [getattr, data, attribute] - MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_NULL(cons); - - auto dt = data->abstract(); - if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) { - return nullptr; - } - - if (!dt->isa()) { - MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; - } - - auto cons_is_str = IsValueNode(cons); - auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; - - auto ct = dyn_cast(dt); - const auto &cmap = ct->attributes(); - int count = 0; - for (auto &item : cmap) { - if (cons_is_str && item.first == cons_str) { - break; - } - count++; - } - - auto idx_c = NewValueNode(count); - AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); - idx_c->set_abstract(aptr); - - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); -} - -AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - // Inputs should be [dict_getitem, dict, item] - const auto &inputs = node->inputs(); - MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_NULL(cons); - - auto dt = data->abstract(); - MS_EXCEPTION_IF_NULL(dt); - if (!dt->isa()) { - MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name(); - } - auto cons_is_str = IsValueNode(cons); - auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; - - auto ct = dyn_cast(dt); - const auto &cmap = ct->elements(); - int count = 0; - for (auto &item : cmap) { - if (cons_is_str && item.first == cons_str) { - break; - } - count++; - } - - auto idx_c = NewValueNode(count); - AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); - idx_c->set_abstract(aptr); - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); -} - -AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - // Inputs should be [dict_setitem, dict, item, value] - const auto &inputs = node->inputs(); - MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs."); - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - AnfNodePtr item_value = inputs[3]; - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_NULL(cons); - - auto dt = data->abstract(); - MS_EXCEPTION_IF_NULL(dt); - if (!dt->isa()) { - MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name(); - } - auto cons_is_str = IsValueNode(cons); - auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; - - auto ct = dyn_cast(dt); - const auto &cmap = ct->elements(); - int count = 0; - for (auto &item : cmap) { - if (cons_is_str && item.first == cons_str) { - break; - } - count++; - } - if (IntToSize(count) >= cmap.size()) { - // for dictionary set, if the key does not exist, we should create a new item - auto tuple_add_op = std::make_shared("tuple_add"); - auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value}); - return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item}); - } - auto idx_c = NewValueNode(count); - AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); - idx_c->set_abstract(aptr); - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value}); -} - -AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - std::vector inputs; - inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr; - (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end()); - return node->func_graph()->NewCNode(inputs); -} - -AnfNodePtr ErasePartialNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - const auto &inputs = node->inputs(); - // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; - MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); - - std::vector args(inputs.begin() + 2, inputs.end()); - auto oper = inputs[1]; - if (IsPrimitive(oper, prim::kPrimMakeRecord)) { - if (args.size() == 1) { - return NewValueNode(prim::kPrimMakeTuple); - } - - if (args.size() > 1) { - std::vector new_inputs; - new_inputs.emplace_back(NewValueNode(prim::kPrimPartial)); - new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end()); - - MS_EXCEPTION_IF_NULL(node->func_graph()); - return node->func_graph()->NewCNode(new_inputs); - } - } - return nullptr; -} - -AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - std::vector inputs; - inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items; - (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end()); - return node->func_graph()->NewCNode(inputs); -} - -AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - const auto &inputs = node->inputs(); - // Inputs should be [list_getitem, list, item] - if (inputs.size() < 3) { - MS_LOG(EXCEPTION) << "Node's input number < 3."; - } - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_NULL(cons); - - auto cons_node = cons->cast(); - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); -} - -AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - const auto &inputs = node->inputs(); - // Inputs should be [list_setitem, list, index, item] - if (inputs.size() < 4) { - MS_LOG(EXCEPTION) << "Node's input number < 4."; - } - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - AnfNodePtr value = inputs[3]; - - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); -} - -AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - const auto &inputs = node->inputs(); - MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); - return inputs[2]; -} - -AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - const auto &inputs = node->inputs(); - // Inputs should be [make_keyword_arg, key, value] - MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); - return inputs[2]; -} - -AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - const auto &inputs = node->inputs(); - // Inputs should be [extract_keyword_arg, arg, key] - MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); - return inputs[2]; -} - -ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) { - const int DEPTH_MAX = 5; - if (depth > DEPTH_MAX) { - MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; - } - std::vector elements; - for (const auto &it : value_list->value()) { - ValuePtr value = nullptr; - if (it->isa()) { - value = ConvertValueListToValueTuple(it->cast(), depth + 1); - } else { - value = it; - } - elements.push_back(value); - } - return std::make_shared(elements); -} - -AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - ValuePtr value = node->value(); - auto value_list = value->cast(); - MS_EXCEPTION_IF_NULL(value_list); - int depth = 0; - return std::make_shared(ConvertValueListToValueTuple(value_list, depth)); -} - -// Convert class to Tuple -// Convert getattr to getitem -// Convert make_record to make_tuple -bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(root); - - bool changed = false; - - // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var - AnfNodeSet all_node = manager->all_nodes(); - for (auto &node : all_node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - AnfNodePtr new_node = nullptr; - if (IsValueNode(node)) { - new_node = NewValueNode(prim::kPrimMakeTuple); - } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) { - new_node = ConvertGetAttrToTupleGetItem(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) { - new_node = ConvertMakeRecordToMakeTuple(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) { - new_node = ErasePartialNode(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) { - new_node = ConvertDictGetItemToTupleGetItem(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) { - new_node = ConvertDictSetItemToTupleSetItem(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) { - new_node = EraseMakeDictNode(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) { - new_node = EraseMakeKeywordArgNode(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) { - new_node = EraseExtractKeywordArg(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) { - new_node = ConvertMakeListToMakeTuple(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) { - new_node = ConvertListGetItemToTupleGetItem(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) { - new_node = ConvertListSetItemToTupleSetItem(cnode); - } else if (IsValueNode(node)) { - new_node = ConvertValueListNodeToValueTupleNode(node->cast()); - } - - if (new_node != nullptr) { - new_node->set_abstract(node->abstract()); - MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString(); - (void)manager->Replace(node, new_node); - changed = true; - } - } - - for (auto &node : manager->all_nodes()) { - auto ret = Reabs(node->abstract()); - node->set_abstract(ret); - } - return changed; -} - -// expand tuples in graph parameters -static std::vector ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph, - const std::vector ¶ms) { - MS_EXCEPTION_IF_NULL(mng); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector new_params; - for (const auto ¶m : params) { - MS_EXCEPTION_IF_NULL(param); - auto param_abs = param->abstract(); - MS_EXCEPTION_IF_NULL(param_abs); - - if (param_abs->isa()) { - MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info()); - } - - if (!param_abs->isa()) { - new_params.emplace_back(param); - continue; - } - - std::vector new_param; - std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; - auto abs_tuple = dyn_cast(param_abs); - for (auto &elem : abs_tuple->elements()) { - auto np = std::make_shared(func_graph); - np->set_abstract(elem); - new_param.emplace_back(np); - } - (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end()); - auto new_tuple = func_graph->NewCNode(inputs); - (void)mng->Replace(param, new_tuple); - - auto expand_param = ExpandTuplesP(mng, func_graph, new_param); - (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end()); - } - return new_params; -} - -// expand tuples in graph applies -static std::vector ExpandTuplesC(const FuncGraphPtr &graph, const std::vector &inputs) { - MS_EXCEPTION_IF_NULL(graph); - - std::vector new_inputs; - for (const auto &input : inputs) { - MS_EXCEPTION_IF_NULL(input); - - auto input_abs = input->abstract(); - MS_EXCEPTION_IF_NULL(input_abs); - - if (input_abs->isa()) { - auto abstract_tag = dyn_cast(input_abs); - if (abstract_tag->element()->isa()) { - MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info()); - } - } - - if (!input_abs->isa()) { - new_inputs.emplace_back(input); - continue; - } - - int idx = 0; - std::vector new_input; - auto abs_tuple = dyn_cast(input_abs); - for (auto &elem : abs_tuple->elements()) { - auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); - AbstractBasePtr aptr = std::make_shared(std::make_shared(idx)); - c_node->input(2)->set_abstract(aptr); - c_node->set_abstract(elem); - new_input.emplace_back(c_node); - idx++; - } - - auto expand_tuple = ExpandTuplesC(graph, new_input); - (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end()); - } - - return new_inputs; -} - -// remove most uses of tuples from the graph parameters & apply inputs -// tuples that are returned will be kept -// tuples in CNode's inputs: AbstractTuple (a, b ,c) --> -// CNode("tuple_getitem", (a,b,c), 0) -// CNode("tuple_getitem", (a,b,c), 1) -// CNode("tuple_getitem", (a,b,c), 2) -// tuples in Graph's parameters: AbstractTuple (a, b, c) --> -// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) -// cppcheck-suppress unusedFunction -void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(root); - - // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var - AnfNodeSet all_node = manager->all_nodes(); - for (auto &node : all_node) { - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } - - const auto &inputs = cnode->inputs(); - - // Bypass the first input in inputs as it's fn. - if (!IsValueNode(inputs[0])) { - std::vector expand_inputs; - (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end()); - - auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs); - if (new_inputs != expand_inputs) { - std::vector cnode_inputs{inputs[0]}; - (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end()); - - MS_EXCEPTION_IF_NULL(node->func_graph()); - auto new_node = node->func_graph()->NewCNode(cnode_inputs); - new_node->set_abstract(node->abstract()); - - (void)manager->Replace(node, new_node); - } - // Bypass the first 2 inputs in inputs as it's [partial, fn]. - } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode(inputs[1])) { - std::vector expand_inputs; - (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end()); - - auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs); - if (new_inputs != expand_inputs) { - std::vector cnode_inputs{inputs[0], inputs[1]}; - (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end()); - - MS_EXCEPTION_IF_NULL(cnode->func_graph()); - auto new_node = cnode->func_graph()->NewCNode(cnode_inputs); - new_node->set_abstract(cnode->abstract()); - - (void)manager->Replace(node, new_node); - } - } - } - - FuncGraphSet all_graph = manager->func_graphs(); - for (auto &func_graph : all_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); - manager->SetParameters(func_graph, expand_p); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/clean.h b/mindspore/ccsrc/optimizer/clean.h deleted file mode 100644 index 0130ecfb32..0000000000 --- a/mindspore/ccsrc/optimizer/clean.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ - -#include -#include "ir/anf.h" -#include "operator/ops.h" -#include "utils/any.h" -#include "ir/manager.h" -#include "pipeline/static_analysis/dshape.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { - -// Remove the class type from graphs -bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); - -// Remove most uses of tuples from the graph -// tuples that are returned will be kept -void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); - -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ diff --git a/mindspore/ccsrc/optimizer/control_depend.cc b/mindspore/ccsrc/optimizer/control_depend.cc deleted file mode 100644 index 0b5c85b1e0..0000000000 --- a/mindspore/ccsrc/optimizer/control_depend.cc +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/control_depend.h" - -#include -#include -#include -#include -#include - -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -std::vector DoControlDepend(const FuncGraphPtr &graph, const CNodePtr &return_node, - const std::vector &effect_index, const std::vector &cnodes) { - std::vector depend_nodes{NewValueNode(prim::kPrimDepend), return_node->input(1)}; - std::vector make_tuple{NewValueNode(prim::kPrimMakeTuple)}; - size_t effect_size = effect_index.size(); - for (size_t i = 0; i < effect_size; i++) { - size_t pre_index = 0; - if (i > 0) { - pre_index = effect_index[i - 1] + 1; - } - size_t this_index = effect_index[i]; - size_t last_index = cnodes.size() - 2; - if (i < effect_size - 1) { - last_index = effect_index[i + 1]; - } - - if (this_index > pre_index) { - std::vector pre_segment; - for (size_t k = pre_index; k < this_index; k++) { - // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. - if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || - IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { - continue; - } - pre_segment.push_back(cnodes[k]); - } - auto roots = FindRoots(pre_segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - AnfNodePtr control_depend = - graph->NewCNode({NewValueNode(prim::kPrimControlDepend), *iter, cnodes[this_index]}); - make_tuple.push_back(control_depend); - } - } - if (last_index > this_index) { - std::vector last_segment; - for (size_t k = this_index + 1; k <= last_index; k++) { - // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. - if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || - IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { - continue; - } - last_segment.push_back(cnodes[k]); - } - auto leaves = FindLeaves(last_segment); - for (auto iter = leaves->begin(); iter != leaves->end(); (void)iter++) { - AnfNodePtr control_depend = - graph->NewCNode({NewValueNode(prim::kPrimControlDepend), cnodes[this_index], *iter}); - make_tuple.push_back(control_depend); - } - } - } - depend_nodes.push_back(graph->NewCNode(make_tuple)); - return depend_nodes; -} - -void AddControlDepend(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - std::list orders = graph->GetOrderedCnodes(); - std::vector cnodes(orders.begin(), orders.end()); - size_t cnodes_size = cnodes.size(); - // get effect index of cnodes - std::vector effect_index{}; - for (size_t i = 0; i < cnodes_size; i++) { - if (graph->HasEffect(cnodes[i])) { - effect_index.push_back(i); - } - } - if (effect_index.empty()) { - return; - } - AnfNodePtr last_node = cnodes[cnodes_size - 1]; - CNodePtr return_node; - if (last_node->isa()) { - return_node = last_node->cast(); - } - MS_EXCEPTION_IF_NULL(return_node); - if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) { - MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode."; - } - if (return_node->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2."; - } - - auto depend_node_inputs = DoControlDepend(graph, return_node, effect_index, cnodes); - auto depend_cnode = graph->NewCNode(depend_node_inputs); - depend_cnode->set_abstract(depend_cnode->input(1)->abstract()); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (!manager->Replace(return_node->input(1), depend_cnode)) { - MS_LOG(EXCEPTION) << "Depend replace node failed"; - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc deleted file mode 100644 index 0b675cca72..0000000000 --- a/mindspore/ccsrc/optimizer/cse.cc +++ /dev/null @@ -1,231 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/cse.h" -#include -#include -#include -#include "./common.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractFunctionPtr; - -BasePtr AbsOf(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto node_abs = node->abstract(); - // in testcase: TestOptOpt.CSE, node->abstract() is null; - if (node_abs == nullptr) { - return kAnyValue; - } - - return node_abs; -} - -bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { - bool changed = false; - for (FuncGraphPtr fg : manager->func_graphs()) { - MS_EXCEPTION_IF_NULL(fg); - std::vector order_group; - std::unordered_map> groups; - std::unordered_map hashes; - - std::vector toposet = TopoSort(fg->get_return()); - for (auto node : toposet) { - MS_EXCEPTION_IF_NULL(node); - if (hashes.find(node) != hashes.end()) { - continue; - } - - std::size_t h = 0; - if (node->isa()) { - ValueNodePtr value_node = node->cast(); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - h = hash_combine(value->hash(), (AbsOf(value_node)->hash())); - } else if (node->isa()) { - auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - size_t init = 0; - h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) { - return hash_combine(hash, hashes[node_in]); - }); - } else if (node->isa()) { - h = node->hash(); - } else { - MS_LOG(ERROR) << "Unknow node type"; - } - - hashes[node] = h; - if (groups.find(h) == groups.end()) { - std::vector innervec({node}); - groups[h] = innervec; - order_group.emplace_back(h); - } else { - groups[h].push_back(node); - } - } - - changed = DoReplace(manager, order_group, &groups) || changed; - } - - return changed; -} -// The op like print, summary, or the op do not has true output, and always as a depend node input. -static bool HasSideEffect(const AnfNodePtr &node) { - auto prim = GetCNodePrimitive(node); - if (prim == nullptr) { - return false; - } - auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT); - if (side_effect_v != nullptr && side_effect_v->isa()) { - return GetValue(side_effect_v); - } - return false; -} -// If true do not merge the node. -bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const { - bool has_random_effect = false; - auto prim_main = GetCNodePrimitive(main); - auto prim_node = GetCNodePrimitive(node); - // if has random effect, when generate by different op (not same object), do not merge. - if (prim_main != nullptr) { - if (prim_main == prim_node) { - return false; - } - auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT); - if (effect_val != nullptr && effect_val->isa()) { - has_random_effect = GetValue(effect_val); - } - } - return has_random_effect; -} - -bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const { - MS_EXCEPTION_IF_NULL(main); - MS_EXCEPTION_IF_NULL(node); - - if (main->isa() && node->isa()) { - auto main_value = GetValueNode(main); - auto node_value = GetValueNode(node); - return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); - } else if (main->isa() && node->isa()) { - auto c_main = main->cast(); - auto c_node = node->cast(); - // When appsame is true, check if has side effect, do not merge. - if (check_side_effect && HasSideEffect(main)) { - return false; - } - const auto &inp1 = c_main->inputs(); - const auto &inp2 = c_node->inputs(); - if (inp1.size() != inp2.size()) { - return false; - } - for (size_t j = 0; j < inp1.size(); j++) { - auto inp1_j = inp1[j]; - auto inp2_j = inp2[j]; - MS_EXCEPTION_IF_NULL(inp1_j); - MS_EXCEPTION_IF_NULL(inp2_j); - if (!(*inp1_j == *inp2_j)) { - // Handle the case of two different Tensor, but with the same value - if (IsValueNode(inp1_j) && IsValueNode(inp2_j)) { - auto tensor1 = GetValueNode(inp1_j); - auto tensor2 = GetValueNode(inp2_j); - if (tensor1->ValueEqual(*tensor2)) { - continue; - } - } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) { - // When the same side effect node as another two nodes' inputs, we still merge the node. - // Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the - // node. - if (CheckReplace(inp1_j, inp2_j, false)) { - continue; - } - } - return false; - } - } - // When appsame is true, check if has random effect do not merge - if (CheckRandomEffect(c_main, c_node)) { - return false; - } - return true; - } - // a parameter node. - return false; -} - -bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, - std::unordered_map> *groups) const { - bool changes = false; - std::set clear_set; - for (auto &h : order_group) { - std::vector &group = (*groups)[h]; - // If there are more than 2 node in that group, they may be same common expression can be eliminated. - if (group.size() > 1) { - for (size_t k = 0; k < group.size() - 1; k++) { - AnfNodePtr main = group[k]; - MS_EXCEPTION_IF_NULL(main); - - // When all node in group has been replaced - // or a valuenode node, skip compare in group - if ((k + 1 + clear_set.size() == group.size()) || (k > 0 && main->isa())) { - break; - } - - // skip node has been replaced - if (clear_set.find(k) != clear_set.end()) { - continue; - } - - // Compare with rest elements in this group. - for (size_t i = k + 1; i < group.size(); i++) { - auto node = group[i]; - MS_EXCEPTION_IF_NULL(node); - - if (clear_set.find(i) != clear_set.end()) { - continue; - } - if (main->func_graph() != node->func_graph()) { - continue; - } - if (CheckReplace(node, main)) { - changes = true; - (void)manager->Replace(node, main); - (void)clear_set.insert(i); - } - } - } - clear_set.clear(); - } - } - - return changes; -} - -bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const { - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(root); - - return BuildOrderGroupAndDoReplace(manager); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h deleted file mode 100644 index 57163cc5c9..0000000000 --- a/mindspore/ccsrc/optimizer/cse.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ - -#include -#include -#include -#include "ir/anf.h" -#include "ir/manager.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { - -// Common subexpression elimination. -class CSE { - public: - explicit CSE(bool report_changes = true) : report_changes_(report_changes) {} - virtual ~CSE() = default; - - bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { - bool chg = Cse(root, optimizer->resource()->manager()); - return chg && report_changes_; - } - - virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; - - virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; - - bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const; - - private: - bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const; - bool DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, - std::unordered_map> *groups) const; - bool report_changes_; -}; - -BasePtr AbsOf(const AnfNodePtr &node); -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ diff --git a/mindspore/ccsrc/optimizer/graph_kernel_reuse.cc b/mindspore/ccsrc/optimizer/graph_kernel_reuse.cc deleted file mode 100644 index dc20ad925e..0000000000 --- a/mindspore/ccsrc/optimizer/graph_kernel_reuse.cc +++ /dev/null @@ -1,157 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/graph_kernel_reuse.h" -#include -#include -#include -#include "./common.h" -#include "utils/graph_utils.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { - -bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) { - if (a->abstract() && b->abstract()) { - auto a_type = a->abstract()->GetTypeTrack(); - auto b_type = b->abstract()->GetTypeTrack(); - - if (a_type != b_type) { - return false; - } - - auto a_shape = a->abstract()->GetShapeTrack(); - auto b_shape = b->abstract()->GetShapeTrack(); - if (a_shape != nullptr && a_shape == b_shape) { - return true; - } - - if (a_shape != nullptr && b_shape != nullptr && a_shape->isa() && - b_shape->isa()) { - return a_shape->cast()->shape() == b_shape->cast()->shape(); - } - } - return false; -} - -bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { - bool changed = false; - auto fgs = manager->func_graphs(); - for (FuncGraphPtr &fg : fgs) { - if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - continue; - } - std::string key = GetValue(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) { - if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) { - FuncGraphPtr new_fg = nullptr; - for (auto &cfg : graph_kernel_ops[key]) { - // If two graphs have different size then continue - auto fg_topos = TopoSort(fg->get_return()); - auto cfg_topos = TopoSort(cfg->get_return()); - if (fg_topos.size() != cfg_topos.size()) { - continue; - } - - // Compare const tensor - bool has_same = true; - for (size_t i = 0; i < fg_topos.size(); ++i) { - if (IsValueNode(fg_topos[i])) { - if (!IsValueNode(cfg_topos[i])) { - has_same = false; - break; - } - - auto tensor1 = GetValueNode(fg_topos[i]); - auto tensor2 = GetValueNode(cfg_topos[i]); - if (!tensor1->ValueEqual(*tensor2)) { - has_same = false; - break; - } - } - } - - if (!has_same) { - continue; - } - - auto fg_input = fg->parameters(); - auto cfg_input = cfg->parameters(); - if (fg_input.size() != cfg_input.size()) { - continue; - } - // Compare input - for (size_t i = 0; i < fg_input.size(); ++i) { - if (!CompareNode(fg_input[i], cfg_input[i])) { - has_same = false; - break; - } - } - if (!has_same) { - continue; - } - - // Compare output - if (!CompareNode(fg->output(), cfg->output())) { - continue; - } - - // Find reusable fg - new_fg = cfg; - break; - } - - if (new_fg != nullptr) { - // Replace current fg with existing fg - auto users = fg->func_graph_cnodes_index(); - for (auto &iter : users) { - auto cnode = iter.first->first->cast(); - auto new_input = cnode->inputs(); - auto main_graph = cnode->func_graph(); - MS_EXCEPTION_IF_NULL(main_graph); - if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) { - new_input[1] = NewValueNode(new_fg); - } else { - new_input[0] = NewValueNode(new_fg); - } - auto new_cnode = main_graph->NewCNode(new_input); - manager->Replace(iter.first->first, new_cnode); - changed = true; - } - - } else { - // Add current fg to map - graph_kernel_ops[key].push_back(fg); - } - } - } else { - graph_kernel_ops[key] = {fg}; - } - } - - return changed; -} - -bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) { - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(root); - - return DoReplace(manager); -} - -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/graph_kernel_reuse.h b/mindspore/ccsrc/optimizer/graph_kernel_reuse.h deleted file mode 100644 index ed5cc93d18..0000000000 --- a/mindspore/ccsrc/optimizer/graph_kernel_reuse.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H -#define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H - -#include -#include -#include -#include - -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { - -// Common subexpression elimination. -class GraphKernelReuse { - public: - GraphKernelReuse() : count(0) {} - virtual ~GraphKernelReuse() = default; - - bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { - bool chg = ReuseGraphKernel(root, optimizer->resource()->manager()); - return chg; - } - - bool CompareNode(const AnfNodePtr a, const AnfNodePtr other); - bool DoReplace(const FuncGraphManagerPtr manager); - - bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager); - - private: - std::unordered_map> graph_kernel_ops; - int count; -}; - -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc deleted file mode 100644 index 166151751f..0000000000 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ /dev/null @@ -1,174 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "optimizer/irpass.h" -#include "optimizer/irpass/arithmetic_simplify.h" -#include "optimizer/irpass/branch_culling.h" -#include "optimizer/irpass/cast_eliminate.h" -#include "optimizer/irpass/convert.h" -#include "optimizer/irpass/env_item_eliminate.h" -#include "optimizer/irpass/grad_var_prepare.h" -#include "optimizer/irpass/gradient_eliminate.h" -#include "optimizer/irpass/inline.h" -#include "optimizer/irpass/incorporate_call.h" -#include "optimizer/irpass/incorporate_getitem.h" -#include "optimizer/irpass/item_tuple_eliminate.h" -#include "optimizer/irpass/mark_interface_fusion.h" -#include "optimizer/irpass/merge_addn.h" -#include "optimizer/irpass/minmax_grad.h" -#include "optimizer/irpass/param_replace.h" -#include "optimizer/irpass/partial_eliminate.h" -#include "optimizer/irpass/reduce_eliminate.h" -#include "optimizer/irpass/ref_eliminate.h" -#include "optimizer/irpass/reshape_eliminate.h" -#include "optimizer/irpass/special_op_eliminate.h" -#include "optimizer/irpass/specialize_transform.h" -#include "optimizer/irpass/symbol_resolver.h" -#include "optimizer/irpass/tile_eliminate.h" -#include "optimizer/irpass/transpose_eliminate.h" -#include "optimizer/opt.h" -#include "optimizer/irpass/indexed_slices_eliminate.h" - -namespace mindspore { -namespace opt { -namespace irpass { -OptimizeIRPassLib::OptimizeIRPassLib() { - arithmetic_simplify_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify", - {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, - prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); - arithmetic_simplify2_ = - MakeSubstitution(std::make_shared(), "arithmetic_simplify2", {prim::kPrimMul}); - special_op_eliminate_ = - MakeSubstitution(std::make_shared(), "special_op_eliminate", - {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, - prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv}); - zero_like_fill_zero_ = - MakeSubstitution(std::make_shared(), "zero_like_fill_zero", prim::kPrimZerosLike); - adjust_all_reduce_mul_add_ = - MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); - - // ops eliminate - item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", - {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); - tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); - cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); - reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); - transpose_eliminate_ = - MakeSubstitution(std::make_shared(), "transpose_eliminate", prim::kPrimTranspose); - reduce_eliminate_ = MakeSubstitution( - std::make_shared(), "reduce_eliminate", - {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); - partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); - same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); - check_bprop_eliminate_ = - MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); - reset_defer_inline_ = - MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); - depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); - - // Env Item Eliminate - env_get_item_eliminate_ = - MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); - new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_ = - MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), - "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); - - // Ref eliminate - make_ref_eliminate_ = - MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); - get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", - {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", - {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - - replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", - IsValueNode, opt::FORCE_RENORM); - replace_old_param_ = MakeSubstitution(std::make_shared(), "replace_old_param", IsParam); - // Gradient transforms - expand_jprim_ = MakeSubstitution(std::make_shared(), "expand_jprim", prim::kPrimJ); - minmaximum_grad_ = MakeSubstitution(std::make_shared(), "minmaximum_grad", prim::kPrimTupleGetItem); - - // branch culling - switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); - float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), - "float_tuple_getitem_switch", prim::kPrimTupleGetItem); - float_env_getitem_switch_ = - MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); - convert_switch_replacement_ = - MakeSubstitution(std::make_shared(), "convert_switch_replacement", IsCNodeDup); - - // Addn - merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); - addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); - - // inline - inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); - replace_applicator_ = - MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); - specialize_transform_ = - MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); - - // Incorporation - incorporate_getitem_set_ = - MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); - incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared(), - "incorporate_getitem_from_param", IsCNodeGraphKernel); - incorporate_call_ = MakeSubstitution(std::make_shared(), "incorporate_call", IsCNodeDup); - incorporate_call_switch_ = - MakeSubstitution(std::make_shared(), "incorporate_call_switch", IsCNodeDup); - - // Virtual Dataset - virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), - "virtual_dataset_eliminate", prim::kPrimVirtualDataset); - - // Convert - print_tuple_wrapper_ = - MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); - - // Unused parameter eliminate - unused_parameter_eliminate_ = - MakeSubstitution(std::make_shared(), "unused_parameter_eliminate", IsCNodeGraphKernel); - unused_output_eliminate_ = - MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); - - // AddN eliminate - addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); - - // Mark interface fusion - mark_interface_fusion_ = - MakeSubstitution(std::make_shared(), "mark_interface_fusion", prim::kPrimSelect); - - // IndexedSlices Eliminate - indexed_slices_eliminate_ = MakeSubstitution( - std::make_shared(), "indexed_slices_eliminate", - {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); -} - -ResolveIRPassLib::ResolveIRPassLib() { - resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); - resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); -} - -InferenceOptPrepareLib::InferenceOptPrepareLib() { - grad_var_prepare_ = MakeSubstitution(std::make_shared(), "grad_var_prepare", IsCNode); -} -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h deleted file mode 100644 index 782eae6124..0000000000 --- a/mindspore/ccsrc/optimizer/irpass.h +++ /dev/null @@ -1,192 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ - -#include - -#include "optimizer/optimizer.h" -#include "optimizer/opt.h" -#include "ir/visitor.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// the collection of irpass for optimie action -class OptimizeIRPassLib { - public: - OptimizeIRPassLib(); - ~OptimizeIRPassLib() = default; - - SubstitutionPtr arithmetic_simplify_; - SubstitutionPtr arithmetic_simplify2_; - SubstitutionPtr special_op_eliminate_; - SubstitutionPtr zero_like_fill_zero_; - SubstitutionPtr adjust_all_reduce_mul_add_; - - // ops eliminate - SubstitutionPtr item_tuple_eliminate_; - SubstitutionPtr tile_eliminate_; - SubstitutionPtr cast_eliminate_; - SubstitutionPtr reshape_eliminate_; - SubstitutionPtr transpose_eliminate_; - SubstitutionPtr reduce_eliminate_; - SubstitutionPtr partial_eliminate_; - SubstitutionPtr same_eliminate_; - SubstitutionPtr check_bprop_eliminate_; - SubstitutionPtr reset_defer_inline_; - SubstitutionPtr depend_value_elim_; - - // Env Item Eliminate - SubstitutionPtr env_get_item_eliminate_; - SubstitutionPtr new_env_get_item_; - SubstitutionPtr incorporate_env_getitem_; - SubstitutionPtr incorporate_env_getitem_switch_; - - // Ref eliminate - SubstitutionPtr make_ref_eliminate_; - SubstitutionPtr get_ref_param_eliminate_; - SubstitutionPtr get_make_ref_eliminate_; - SubstitutionPtr replace_refkey_by_param_; - SubstitutionPtr replace_old_param_; - - // Branch culling - SubstitutionPtr switch_simplify_; - SubstitutionPtr float_tuple_getitem_switch_; - SubstitutionPtr float_env_getitem_switch_; - SubstitutionPtr convert_switch_replacement_; - - // AddN - SubstitutionPtr merge_addn_; - SubstitutionPtr addn_zero_filter_; - - // Gradient irpasses - SubstitutionPtr expand_jprim_; - SubstitutionPtr minmaximum_grad_; - - // inline - SubstitutionPtr inline_; - SubstitutionPtr replace_applicator_; - SubstitutionPtr specialize_transform_; - - // Incorporation - SubstitutionPtr incorporate_getitem_set_; - SubstitutionPtr incorporate_getitem_from_param_; - SubstitutionPtr incorporate_call_; - SubstitutionPtr incorporate_call_switch_; - - // virtual dataset - SubstitutionPtr virtual_dataset_eliminate_; - - // Convert - SubstitutionPtr print_tuple_wrapper_; - - // Unused parameter eliminate - SubstitutionPtr unused_parameter_eliminate_; - SubstitutionPtr unused_output_eliminate_; - - // AddN eliminate - SubstitutionPtr addn_eliminate_; - - // Fusion - SubstitutionPtr mark_interface_fusion_; - - // IndexedSlices Eliminate - SubstitutionPtr indexed_slices_eliminate_; -}; - -// the collection of irpass for resolve action -class ResolveIRPassLib { - public: - ResolveIRPassLib(); - ~ResolveIRPassLib() = default; - - SubstitutionPtr resolver_resolve_; - SubstitutionPtr resolver_getattr_; -}; - -class InferenceOptPrepareLib { - public: - InferenceOptPrepareLib(); - ~InferenceOptPrepareLib() = default; - SubstitutionPtr grad_var_prepare_; -}; - -// predicate functions -inline bool IsNode(const AnfNodePtr &) { return true; } - -inline bool IsCNode(const AnfNodePtr &node) { - if (node != nullptr) { - return node->isa(); - } - return false; -} - -inline bool IsVNode(const AnfNodePtr &node) { - if (node != nullptr) { - return node->isa(); - } - return false; -} - -inline bool IsParam(const AnfNodePtr &node) { - if (node != nullptr) { - return node->isa(); - } - return false; -} - -// Check if CNode Input 0 is Func Graph -inline bool IsCNodeGraph(const AnfNodePtr &node) { - if (node == nullptr || !node->isa()) { - return false; - } - - auto inp0 = node->cast()->input(0); - return IsValueNode(inp0); -} - -// Check if CNode Input 0 is Func Graph of graph kernel. -inline bool IsCNodeGraphKernel(const AnfNodePtr &node) { - if (node == nullptr || !node->isa()) { - return false; - } - - auto inp0 = node->cast()->input(0); - if (IsValueNode(inp0)) { - auto fg = GetValueNode(inp0); - if (fg == nullptr) { - return false; - } - return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - } - return false; -} - -// Check if CNode Input 0 is CNode -inline bool IsCNodeDup(const AnfNodePtr &node) { - if (node == nullptr || !node->isa()) { - return false; - } - - auto inp0 = node->cast()->input(0); - return (inp0 != nullptr) && inp0->isa(); -} -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc deleted file mode 100644 index b111a6b67a..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc +++ /dev/null @@ -1,680 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -#include "optimizer/irpass/arithmetic_simplify.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} -// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} -AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimScalarMul)(node); - - if (is_zero_) { - return NewValueNode(zero_); - } - if (is_one_) { - return x_; - } - return nullptr; -} - -void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { - if (is_one_ || node->isa()) { - x_ = node; - return; - } - - AnfVisitor::Visit(node); - if (!is_one_) { - x_ = node; - } -} - -void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (*value == *zero_) { - is_zero_ = true; - } else if (*value == *one_) { - is_one_ = true; - } -} - -void MultiplyByZeroOrOne::Reset() { - x_ = nullptr; - is_one_ = false; - is_zero_ = false; -} - -// Support class used for checking if all values of a Tensor are equal `check_value_` -// Supported data types: double, float/float32, int/int32 -bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - TypeId tensor_type = tensor_ptr->Dtype()->type_id(); - if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { - float *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > FLT_EPSILON) { - return false; - } - } - return true; - } else if (tensor_type == TypeId::kNumberTypeFloat64) { - double *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > DBL_EPSILON) { - return false; - } - } - return true; - } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { - int *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (data2[i] != check_value_) { - return false; - } - } - return true; - } - // input Data Types is not supported - return false; -} - -bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { - return false; - } - return IsTensorConstant(value); -} - -void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { - if (!node->isa()) { - return nullptr; - } - - auto value = node->cast()->value(); - - if (!value->isa()) { - return nullptr; - } - - tensor::TensorPtr tensor_ptr = dyn_cast(value); - return tensor_ptr->data_c(); -} - -// Make a new tensor (when possible) with the same shape as of `node` -// If x is nullptr then fill new tensor will "0" -// If x is a tensor with empty shape then fill new tensor with the single value of x -// If x is a tensor with same shape as `node` then return x as result -AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { - if ((node->abstract() == nullptr) || !node->abstract()->isa()) { - return nullptr; - } - - auto tensor_abstract = node->abstract()->cast(); - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - - if (x == nullptr) { - std::memset(data, 0, mem_size); - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - // x is not nullptr - if (x->isa()) { - if ((x->abstract() == nullptr) || !x->abstract()->isa()) { - return nullptr; - } - auto x_abstract = x->abstract()->cast(); - std::vector x_shape = x_abstract->shape()->shape(); - - if (x_shape != tensor_shape) { - return nullptr; - } - return x; - } - - if (!x->isa()) { - return nullptr; - } - auto x_value = x->cast()->value(); - if (!x_value->isa()) { - return nullptr; - } - - auto x_tensor_ptr = dyn_cast(x_value); - - if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { - return nullptr; - } - char *source_data = reinterpret_cast(GetPointerToTensorData(x)); - if (x_tensor_ptr->DataSize() == 1) { - for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { - memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); - } - } else { - memcpy(data, source_data, mem_size); - } - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; -} - -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_zero_) { - if (x_->func_graph() != node->func_graph()) { - return nullptr; - } - return NewTensorFilledWithData(node); - } - return nullptr; -} - -void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { - if (is_zero_) { - x_ = node; - return; - } - - if (IsParam(node)) { - x_ = node; - return; - } - - if (IsCNode(node)) { - CNodePtr cnode = node->cast(); - if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { - is_zero_ = true; - return; - } - x_ = node; - return; - } - auto value = node->cast()->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = node; -} - -void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = vnode; -} -void TensorMultiplyByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_one_) { - return NewTensorFilledWithData(node, x_); - } - return nullptr; -} - -void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { - if (is_one_) { - x_ = node; - return; - } - - if (IsParam(node) || IsCNode(node)) { - x_ = node; - return; - } - - auto value = node->cast()->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = node; -} - -void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = vnode; -} -void TensorMultiplyByOne::Reset() { - x_ = nullptr; - is_one_ = false; -} - -// {prim::kPrimScalarAdd, X, 0} -// {prim::kPrimScalarAdd, 0, X} -AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimScalarAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; -} - -void AddByZero::Visit(const AnfNodePtr &node) { - if (node->isa() && - ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { - is_zero_ = true; - return; - } - - x_ = node; -} - -void AddByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, -// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} -AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimTensorAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; -} - -void TensorAddByZero::Visit(const AnfNodePtr &node) { - if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { - is_zero_ = true; - return; - } - - x_ = node; -} - -void TensorAddByZero::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } -} - -void TensorAddByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} -AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { - return nullptr; - } - - // {PrimMomentum, {...}, Y, Z, Xs} - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { - return nullptr; - } - auto y = inputs[2]; - auto z = inputs[3]; - - // {kPrimZerosLike, X} - if (inputs[1]->cast()->size() != 2) { - return nullptr; - } - - // {prim::kPrimMakeTuple, Z, Y} - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); -} - -// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> -// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} -// Support function to multiply two constant tensors: partially support broadcasting shapes -template -void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, - void **out_data, int out_data_size) { - T *data_1 = reinterpret_cast(in_data_1); - T *data_2 = reinterpret_cast(in_data_2); - T *data_out = new T[out_data_size]; - - if (in_data_1_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[i]; - } - } - if (in_data_2_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[i]; - } - } - *out_data = reinterpret_cast(data_out); - return; -} - -AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, - const AnfNodePtr &node_3) { - if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || - (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { - return nullptr; - } - - auto value_1 = GetValueNode(vnode_1); - auto value_2 = GetValueNode(vnode_2); - - if (!value_1->isa() || !value_2->isa()) { - return nullptr; - } - - auto tensor_ptr_1 = dyn_cast(value_1); - auto tensor_ptr_2 = dyn_cast(value_2); - - auto tensor_1_abstract = vnode_1->abstract()->cast(); - auto tensor_2_abstract = vnode_1->abstract()->cast(); - auto tensor_3_abstract = node_3->abstract()->cast(); - - TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); - TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); - TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); - - if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || - (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { - return nullptr; - } - - std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); - - int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); - - if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { - return nullptr; - } - if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { - return nullptr; - } - - void *data_out; - - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), - &data_out, data_out_size); - } else { - if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - // Un-support data types - return nullptr; - } - } - } - - auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); - size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - memcpy(data, data_out, mem_size); - - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; -} - -AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - // {prim::kPrimMul, Tensor1, {...}} - AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); - if (vnode_ == nullptr || c_p_node_ == nullptr) { - return nullptr; - } - - if (!IsCNode(c_p_node_)) { - return nullptr; - } - - auto tensor1 = vnode_; - auto mul = c_p_node_->cast(); - - Reset(); - // {prim::kPrimMul, Tensor2, {...}} - AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); - if (vnode_ == nullptr || c_p_node_ == nullptr) { - return nullptr; - } - auto tensor2 = vnode_; - auto c_p_node = c_p_node_; - - auto PrimMul = GetValueNode(mul->input(0)); - auto fg = node->func_graph(); - - auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); - if (new_mul_tensor == nullptr) { - auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); - return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); - } - return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); -} - -void ConstantDuplicateMul::Visit(const AnfNodePtr &node) { - if (IsValueNode(node)) { - vnode_ = node; - } - - if (IsCNode(node) || IsParam(node)) { - c_p_node_ = node; - } -} - -void ConstantDuplicateMul::Reset() { - vnode_ = nullptr; - c_p_node_ = nullptr; -} - -AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (!IsValueNode(inputs[2])) { - return nullptr; - } - auto scalar = GetValueNode(inputs[2]); - if (scalar->isa() && GetValue(scalar) == 1.0) { - return inputs[1]; - } else if (scalar->isa() && GetValue(scalar) == 1) { - return inputs[1]; - } - return nullptr; -} - -// grad = AllReduce(grad) / worker_number -// grad = grad + weight * decy -// -> -// grad = grad + weight * decy -// grad = AllReduce(grad) / worker_number -// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> -// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} -AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - // {prim::kPrimAddN, Zs} - if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { - return nullptr; - } - auto addn = node->cast(); - if (addn->size() != 2) { - return nullptr; - } - AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); - if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { - return nullptr; - } - auto addn_maketuple = addn->input(1); - - auto fg = all_reduce_fg_; - // addn inputs cross the graph, make the inputs same as allreduce node. - if (z_->isa() && fg != z_->func_graph()) { - auto cnode_z = z_->cast(); - z_ = NewCNode(cnode_z->inputs(), fg); - } - - auto addn_op_node = addn->input(0); - auto make_tuple_op_node = addn->input(1)->cast()->input(0); - - AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); - AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); - AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); - AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); - ProcessDependEdge(fg, addn_maketuple, all_reduce); - return mul; -} - -void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, - const AnfNodePtr &new_node) { - // If has dynamic loss scale. - auto &users_map = fg->manager()->node_users(); - auto it = users_map.find(mul_cnode_); - if (it != users_map.end()) { - auto users = it->second; - for (auto &user_pair : users) { - auto node = user_pair.first; - if (node != addn_maketuple) { - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - fg->manager()->SetEdge(node, user_pair.second, new_node); - } - } - } - } -} - -void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) { - if (level_ == 0) { - level_ = 1; - is_reduce_match_ = false; - // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} - AnfVisitor::Match(prim::kPrimMul)(node); - level_ = 0; - if (is_reduce_match_) { - mul_ = node->cast()->input(0); - mul_cnode_ = node->cast(); - y_ = tmp_; - } else { - z_ = node; - } - } - - if (level_ == 1) { - // {prim::kPrimAllReduce, X} - if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { - auto cnode = node->cast(); - if (cnode->size() > 1) { - all_reduce_ = cnode->input(0); - x_ = cnode->input(1); - is_reduce_match_ = true; - all_reduce_fg_ = cnode->func_graph(); - } - } else { - tmp_ = node; - } - } -} - -void AdjustAllReduceMulAdd::Reset() { - level_ = 0; - is_reduce_match_ = false; - x_ = nullptr; - y_ = nullptr; - z_ = nullptr; - tmp_ = nullptr; - all_reduce_fg_ = nullptr; -} - -AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; -} - -AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; -} -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h deleted file mode 100644 index f4bdb0d655..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ /dev/null @@ -1,259 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ - -#include -#include -#include - -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} -// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} -class MultiplyByZeroOrOne : public AnfVisitor { - public: - MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} - ~MultiplyByZeroOrOne() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}, is_one_{false}; - ValuePtr zero_, one_; - AnfNodePtr x_{nullptr}; -}; - -// Support class used for checking if all values of a Tensor are equal `check_value_` -// Supported data types: double, float/float32, int/int32 -class CheckTensorConstant { - public: - explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} - ~CheckTensorConstant() = default; - - bool IsTensorConstant(const ValuePtr &value); - bool IsTensorScalarConstant(const ValuePtr &value); - - private: - int check_value_; -}; - -class TensorMultiplyBase : public AnfVisitor { - protected: - void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); - - // Make a new tensor (when possible) with the same shape as of `node` - // If x is nullptr then fill new tensor will "0" - // If x is a tensor with empty shape then fill new tensor with the single value of x - // If x is a tensor with same shape as `node` then return x as result - AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); - - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -class TensorMultiplyByZero : public TensorMultiplyBase { - public: - TensorMultiplyByZero() : zero_(MakeValue(0)) {} - ~TensorMultiplyByZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}; - ValuePtr zero_; -}; - -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -class TensorMultiplyByOne : public TensorMultiplyBase { - public: - TensorMultiplyByOne() {} - ~TensorMultiplyByOne() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_one_{false}; -}; - -// {prim::kPrimScalarAdd, X, 0} -// {prim::kPrimScalarAdd, 0, X} -class AddByZero : public AnfVisitor { - public: - AddByZero() : zero_(MakeValue(0)) {} - ~AddByZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - bool is_zero_{false}; - ValuePtr zero_; - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, -// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} -class TensorAddByZero : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}; - AnfNodePtr x_{nullptr}; -}; - -// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} -class OptUpdateZeroTensor : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; - -// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> -// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} -class ConstantDuplicateMul : public AnfVisitor { - public: - // Support function to multiply two constant tensors: partially support broadcasting shapes - template - void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, - int out_data_size); - - AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - AnfNodePtr vnode_; - AnfNodePtr c_p_node_; -}; - -class PowerOneEliminate : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; - -// grad = AllReduce(grad) / worker_number -// grad = grad + weight * decy -// -> -// grad = grad + weight * decy -// grad = AllReduce(grad) / worker_number - -// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> -// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} -class AdjustAllReduceMulAdd : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - int level_{0}; - bool is_reduce_match_{false}; - AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; - AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; - FuncGraphPtr all_reduce_fg_{nullptr}; -}; - -class ArithmeticSimplify : public OptimizerCaller { - public: - ArithmeticSimplify() - : multiply_by_zero_or_one_(std::make_shared()), - tensor_multiply_by_one_(std::make_shared()), - add_by_zero_(std::make_shared()), - tensor_add_by_zero_(std::make_shared()), - identity_(std::make_shared(prim::kPrimIdentity)), - opt_update_zero_tensor_(std::make_shared()), - constant_duplicate_mul_(std::make_shared()), - power_one_(std::make_shared()) { - eliminaters_.emplace_back(multiply_by_zero_or_one_); - eliminaters_.emplace_back(tensor_multiply_by_one_); - eliminaters_.emplace_back(add_by_zero_); - eliminaters_.emplace_back(tensor_add_by_zero_); - eliminaters_.emplace_back(identity_); - eliminaters_.emplace_back(opt_update_zero_tensor_); - eliminaters_.emplace_back(constant_duplicate_mul_); - eliminaters_.emplace_back(power_one_); - } - ~ArithmeticSimplify() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; - - private: - OptimizerCallerPtr multiply_by_zero_or_one_; - OptimizerCallerPtr tensor_multiply_by_one_; - OptimizerCallerPtr add_by_zero_; - OptimizerCallerPtr tensor_add_by_zero_; - OptimizerCallerPtr identity_; - OptimizerCallerPtr opt_update_zero_tensor_; - OptimizerCallerPtr constant_duplicate_mul_; - OptimizerCallerPtr power_one_; - - std::vector eliminaters_{}; -}; - -// Arithmetic Simplifications should be done after step_parallel. -// eg: Mul(0, weight) where weight is a parameter will be simplified to a constant tensor -// with shape(weight), but after step_parallel, shape of weight may be changed, so the -// shape of the constant tensor should also be changed. So this pass is seperated from -// ArithmeticSimplify and deferred until step_parallel. -class ArithmeticSimplify2 : public OptimizerCaller { - public: - ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { - eliminaters_.emplace_back(tensor_multiply_by_zero_); - } - ~ArithmeticSimplify2() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; - - private: - OptimizerCallerPtr tensor_multiply_by_zero_; - std::vector eliminaters_{}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc deleted file mode 100644 index 726f4a28b0..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc +++ /dev/null @@ -1,584 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/irpass/branch_culling.h" - -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -AnfNodePtr GenerateSwitchNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data, - int switch_idx) { - auto switch_node = prim::GetPythonOps("geswitch", "mindspore.ops.functional")->cast(); - std::vector switch_nodes{NewValueNode(switch_node), data, cond}; - auto switch_apply = graph->NewCNode(switch_nodes); - std::vector tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), switch_apply, - NewValueNode(MakeValue(switch_idx))}; - return graph->NewCNode(tuple_getitem_nodes); -} - -AnfNodePtr GenerateSwitchTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { - return GenerateSwitchNode(graph, cond, data, 1); -} - -AnfNodePtr GenerateSwitchFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { - return GenerateSwitchNode(graph, cond, data, 0); -} - -bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { - // The CNode inputs of the following Primitive with index in std::vector should not be guarded by geswitch - // node because it is attribute or ge specific reason. - // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be - // converted to switch guarded. - std::vector>> white_list({{prim::kPrimApplyMomentum, {1, 2}}, - {prim::kPrimMomentum, {2, 3}}, - {prim::kPrimStateSetItem, {1}}, - {prim::kPrimTupleGetItem, {2}}, - {prim::kPrimEnvGetItem, {1}}, - {prim::kPrimEnvSetItem, {1}}, - {prim::kPrimReduceSum, {2}}, - {prim::kPrimReduceMean, {2}}, - {prim::kPrimReduceAll, {2}}, - {prim::kPrimCast, {2}}, - {prim::kPrimTranspose, {2}}, - {prim::kPrimOneHot, {2}}, - {prim::kPrimGatherV2, {3}}, - {prim::kPrimReshape, {2}}, - {prim::kPrimAssign, {1}}, - {prim::kPrimAssignAdd, {1}}, - {prim::kPrimAssignSub, {1}}, - {prim::kPrimTensorSummary, {1}}, - {prim::kPrimImageSummary, {1}}, - {prim::kPrimScalarSummary, {1}}, - {prim::kPrimApplyRMSProp, {6, 7, 8}}, - {prim::kPrimCumSum, {2}}, - {prim::kPrimTile, {2}}, - {prim::kPrimExpandDims, {2}}, - {prim::kPrimHistogramSummary, {1}}}); - for (auto &item : white_list) { - auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { - return IsPrimitiveCNode(node, item.first) && idx == index; - }); - if (matched) { - return true; - } - } - - std::vector adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend}; - for (auto &item : adapter_convert_ops) { - if (IsPrimitiveCNode(node, item)) { - return true; - } - } - return false; -} - -using NodeInputReplMap = std::unordered_map, AnfNodePtr, PairHasher>; -// replace the nodes which should be changed -void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector> nodes_changed, - std::unordered_map repl_node, NodeInputReplMap repl_node_inputs, - const FuncGraphPtr &func_graph) { - for (auto &node_pair : nodes_changed) { - CNodePtr old_node = node_pair.first; - CNodePtr new_node = node_pair.second; - MS_EXCEPTION_IF_NULL(old_node); - MS_EXCEPTION_IF_NULL(new_node); - for (size_t i = 0; i < old_node->size(); i++) { - auto input = old_node->input(i); - if (repl_node.count(input) != 0) { - new_node->add_input(repl_node[input]); - } else if (repl_node_inputs.count(std::pair(old_node, i)) != 0) { - new_node->add_input(repl_node_inputs[std::pair(old_node, i)]); - } else { - new_node->add_input(input); - } - } - } - - for (auto &item : repl_node) { - if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) { - func_graph->set_output(item.second->cast()->input(1)); - } else if (!manager->Replace(item.first, item.second)) { - MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2) - << " to new: " << item.second->DebugString(2); - } - } -} - -// trace the node that should add switch and replace them with new nodes in the graph -FuncGraphPtr TransformGraphCondBranchNodes( - const FuncGraphPtr &graph, const AnfNodePtr &cond, - const std::function &generate_func) { - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - // record the node that has been changed - std::vector> nodes_changed; - // record the node to be replaced - std::unordered_map repl_node; - // record the node input to be replaced - NodeInputReplMap repl_node_inputs; - const AnfNodeSet &nodes = graph->nodes(); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto inputs = node->cast()->inputs(); - bool should_replace = false; - // if the apply input does not belong to graph, insert a switch node - for (size_t index = 0; index < inputs.size(); index++) { - auto input_node = inputs[index]; - MS_EXCEPTION_IF_NULL(input_node); - // for some ops input should not guard it with switch - if (InConvertWhiteList(node, index)) { - continue; - } - - // If the input for node is not the graph belonged, or it is an ValueNode. - // Bypass the Primitive node which is inputs[0]. - if ((index >= 1 && inputs[index]->func_graph() != nullptr && inputs[index]->func_graph() != graph) || - ((index >= 1 && inputs[index]->isa()))) { - input_node = generate_func(graph, cond, inputs[index]); - repl_node_inputs[std::pair(node, index)] = input_node; - should_replace = true; - } - if (input_node == nullptr) { - MS_LOG(EXCEPTION) << "generate switch node failed"; - } - } - if (should_replace) { - auto new_node = graph->NewCNode(); - repl_node[node] = new_node; - nodes_changed.emplace_back(node->cast(), new_node); - } - } - RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph); - return graph; -} - -struct SharedOp { - tensor::TensorPtr const_data; - CNodePtr square_ops[2]; - CNodePtr merge_ops[2]; -} MergeNetOutput; - -inline tensor::TensorPtr GetConstData() { return MergeNetOutput.const_data; } -inline void SetConstData(const tensor::TensorPtr &const_value) { MergeNetOutput.const_data = const_value; } - -inline CNodePtr GetSquareOp(int switch_idx) { return MergeNetOutput.square_ops[switch_idx]; } -inline void SetSquareOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.square_ops[switch_idx] = op; } - -inline CNodePtr GetMergeOp(int switch_idx) { return MergeNetOutput.merge_ops[switch_idx]; } -inline void SetMergeOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.merge_ops[switch_idx] = op; } - -inline void ResetSharedOp() { - SetConstData(nullptr); - SetSquareOp(0, nullptr); - SetSquareOp(1, nullptr); - SetMergeOp(0, nullptr); - SetMergeOp(1, nullptr); -} - -tensor::TensorPtr ConstData() { - std::vector shp = {1}; - tensor::TensorPtr const_data = std::make_shared(kInt32->type_id(), shp); - auto *val = static_cast(const_data->data_c()); - *val = 0; - return const_data; -} - -CNodePtr SquareOp(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx, - const tensor::TensorPtr &const_data) { - auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast(); - // for the depended node , add two const data to merge the flow ,one for depended node with same switch, - // the other use the opposite - auto ctrl_data = NewValueNode(const_data); - auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); - - std::vector square_nodes{NewValueNode(PrimSquare), ctrl_node}; - auto square_op = graph->NewCNode(square_nodes); - - return square_op; -} - -CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx, - const tensor::TensorPtr &const_data, const CNodePtr &square_op) { - // for the depended node , add two const data to merge the flow ,one for depended node with same switch, - // the other use the opposite - auto oppsite_ctrl_data = NewValueNode(const_data); - auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); - - std::vector merge_nodes; - auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); - merge_nodes.push_back(NewValueNode(PrimMerge)); - std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; - merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); - auto merge_op = graph->NewCNode(merge_nodes); - - return merge_op; -} - -// construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) -// control_depend(output_node, square_op) -AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node, - int switch_idx) { - tensor::TensorPtr const_data = GetConstData(); - if (const_data == nullptr) { - const_data = ConstData(); - SetConstData(const_data); - } - - CNodePtr square_op = GetSquareOp(switch_idx); - if (square_op == nullptr) { - square_op = SquareOp(graph, cond, switch_idx, const_data); - SetSquareOp(switch_idx, square_op); - } - - CNodePtr merge_op = GetMergeOp(switch_idx); - if (merge_op == nullptr) { - merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op); - SetMergeOp(switch_idx, merge_op); - } - - std::vector control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op}; - auto control_depend_op = graph->NewCNode(control_depend_nodes); - - std::vector depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op}; - auto depend_op = graph->NewCNode(depend_nodes); - - return depend_op; -} - -// construct a merge output and add dependency with the netoutput node from control_depend -// we need to reserve the control_depend node, besides the generated merge node and control_depend node -CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst, - int switch_idx) { - auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); - auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast(); - std::vector shp = {1}; - tensor::TensorPtr const_data = std::make_shared(kInt32->type_id(), shp); - auto *val = static_cast(const_data->data_c()); - *val = 0; - // for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same - // switch the other use the opposite - auto ctrl_data = NewValueNode(const_data); - auto oppsite_ctrl_data = NewValueNode(const_data); - auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); - auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); - - std::vector square_nodes{NewValueNode(PrimSquare), ctrl_node}; - auto square_op = graph->NewCNode(square_nodes); - - std::vector merge_nodes; - merge_nodes.push_back(NewValueNode(PrimMerge)); - std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; - merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); - auto merge_output = graph->NewCNode(merge_nodes); - - std::vector control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op}; - auto cond_dep_output = graph->NewCNode(control_depend_nodes); - - std::vector depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output, - cond_dep_output}; - return graph->NewCNode(depended_make_tuple_nodes); -} - -// generate switch nodes for true graph node inputs -AnfNodePtr GenerateSwitchDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchDependNode(graph, cond, data, 1); -} - -// generate switch nodes for false graph node inputs -AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchDependNode(graph, cond, data, 0); -} - -// generate switch nodes for true graph node inputs -CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &con_input, const AnfNodePtr &output) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1); -} - -// generate switch nodes for false graph node inputs -CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &con_input, const AnfNodePtr &output) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0); -} - -// to judge if the node used in ControlDepend is a net output node -bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { - auto uses = manager->node_users()[node]; - bool is_output_node = true; - for (auto &item : uses) { - if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { - continue; - } - is_output_node = false; - break; - } - return is_output_node; -} - -// generate node for Depended MakeTuple -void GenerateReplNodeForDependMakeTuple( - const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, - const std::shared_ptr> &repl_node, - const std::function &generate_func, - const std::function &gen_ctl_depd_func) { - MS_EXCEPTION_IF_NULL(graph->manager()); - - auto make_tuple_inputs = depended_node->cast()->inputs(); - const size_t make_tuple_begin_idx = 1; - std::vector new_make_tuple_nodes; - bool replace_make_tuple = false; - new_make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t idx = make_tuple_begin_idx; idx < make_tuple_inputs.size(); idx++) { - auto depended_tuple_input_node = make_tuple_inputs[idx]; - if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimDepend)) { - new_make_tuple_nodes.push_back(depended_tuple_input_node); - continue; - } - if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimControlDepend)) { - // only when the control depend input is not square op (the op to use as merge output) - auto control_inputs = depended_tuple_input_node->cast()->inputs(); - if (control_inputs.size() != 3) { - MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); - } - // control inputs: primitive, src, dst - auto dst_node = control_inputs[2]; - if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { - auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node); - MS_EXCEPTION_IF_NULL(gen_node); - auto tuple_inputs = gen_node->inputs(); - // add depended tuple inputs to new_make_tuple directly - for (size_t i = 1; i < tuple_inputs.size(); i++) { - new_make_tuple_nodes.push_back(tuple_inputs[i]); - } - } - replace_make_tuple = true; - continue; - } - - if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { - auto gen_node = generate_func(graph, cond, depended_tuple_input_node); - new_make_tuple_nodes.push_back(gen_node); - replace_make_tuple = true; - continue; - } - - MS_LOG(WARNING) << "depended node being used by others, "; - } - if (replace_make_tuple) { - auto make_tuple_op = graph->NewCNode(new_make_tuple_nodes); - (*repl_node)[depended_node] = make_tuple_op; - } -} - -// generate a replace depend node for a single network output node -void GenerateRepDepend( - const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond, - const std::shared_ptr> &repl_node, - const std::function &generate_func, - const std::function &gen_ctl_depd_func) { - auto inputs = node->inputs(); - if (inputs.size() != 3) { - MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node]."; - } - - std::vector new_depened_inputs; - // Inputs should be [depend, actual_value, depended_node] - auto depended_node = inputs[2]; - new_depened_inputs.push_back(inputs[0]); - new_depened_inputs.push_back(inputs[1]); - // depended node should be make_tuple or a single depended node - if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) { - GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func); - } else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) { - // only when the control depend input is not square op (the op to use as merge output) - auto control_inputs = depended_node->cast()->inputs(); - // control inputs: primitive, src, dst - if (control_inputs.size() != 3) { - MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); - } - auto dst_node = control_inputs[2]; - if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { - auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node); - (*repl_node)[depended_node] = gen_node; - } - } else { - // Check if there is only single user for depend_node. - if (graph->manager()->node_users()[depended_node].size() == 1) { - auto gen_node = generate_func(graph, cond, depended_node); - (*repl_node)[depended_node] = gen_node; - } else { - MS_LOG(WARNING) << "depended node being used by others"; - } - } -} - -// generate depend node for netoutput node, to resolve the stream synchronize problem of ge -// traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const) -// and add control_depend of graph output node and square node. -FuncGraphPtr TransformGraphDependNode( - const FuncGraphPtr &graph, const AnfNodePtr &cond, - const std::function &gen_depend_func, - const std::function &gen_ctl_depd_func) { - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - ResetSharedOp(); - std::shared_ptr> repl_node = - std::make_shared>(); // record the node to be replaced - const AnfNodeSet &nodes = graph->nodes(); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - if (IsPrimitiveCNode(node, prim::kPrimDepend)) { - auto cnode = node->cast(); - if (cnode->size() != 3) { - MS_LOG(EXCEPTION) << "Dependnode input size != 3"; - } - auto depended_node = cnode->input(2); - MS_EXCEPTION_IF_NULL(depended_node); - if (!depended_node->isa()) { - continue; - } - if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { - continue; - } - GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func); - } - } - ResetSharedOp(); - - for (auto &item : *repl_node) { - if (!manager->Replace(item.first, item.second)) { - MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed"; - } - } - - return graph; -} - -FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { - (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); - return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode); -} - -FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { - (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); - return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode); -} - -// judge if the true and false graph output is compatible(they shall have same tuple size) -bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const AbstractBasePtr &false_branch_abs) { - MS_EXCEPTION_IF_NULL(true_branch_abs); - MS_EXCEPTION_IF_NULL(false_branch_abs); - - if (true_branch_abs->isa() && false_branch_abs->isa()) { - abstract::AbstractTuplePtr true_branch_tuple = true_branch_abs->cast(); - abstract::AbstractTuplePtr false_branch_tuple = false_branch_abs->cast(); - if (true_branch_tuple->elements().size() != false_branch_tuple->elements().size()) { - MS_LOG(ERROR) << "true branch size:" << true_branch_tuple->elements().size() - << ", not equal to false banch size:" << false_branch_tuple->elements().size() << " "; - return false; - } - bool all_compatible = true; - for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { - all_compatible = - all_compatible && GraphOutputCompatible(true_branch_tuple->elements()[i], false_branch_tuple->elements()[i]); - } - return all_compatible; - } - TypePtr true_branch_type = true_branch_abs->BuildType(); - TypePtr false_branch_type = false_branch_abs->BuildType(); - MS_LOG(DEBUG) << "branch output Type equal?" << (*true_branch_type == *false_branch_type) - << " true:" << true_branch_type->ToString() << " false:" << false_branch_type->ToString(); - return (*true_branch_type == *false_branch_type); -} - -AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, - const AbstractBasePtr &true_graph_output_abs, - const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph, - const AnfNodePtr &cond) { - MS_EXCEPTION_IF_NULL(true_graph_output_abs); - MS_EXCEPTION_IF_NULL(false_graph_output_abs); - MS_EXCEPTION_IF_NULL(cond); - MS_EXCEPTION_IF_NULL(switch_graph); - auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); - MS_EXCEPTION_IF_NULL(PrimMerge); - - if (!true_graph_output_abs->isa()) { - std::vector merge_nodes; - merge_nodes.push_back(NewValueNode(PrimMerge)); - std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node}; - merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes)); - std::vector tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), - switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))}; - return switch_graph->NewCNode(tuple_getitem_nodes); - } else { - abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast(); - abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast(); - - std::vector make_tuple_nodes; - make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { - std::vector true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node, - NewValueNode(MakeValue(SizeToInt(i)))}; - auto true_node = switch_graph->NewCNode(true_getitem_nodes); - std::vector false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node, - NewValueNode(MakeValue(SizeToInt(i)))}; - auto false_node = switch_graph->NewCNode(false_getitem_nodes); - - auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i], - false_branch_tuple->elements()[i], switch_graph, cond); - make_tuple_nodes.push_back(merge_node); - } - return switch_graph->NewCNode(make_tuple_nodes); - } -} - -AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, - const AbstractBasePtr &true_graph_output_abs, - const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, - const FuncGraphPtr &switch_graph) { - if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) { - MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString() - << ", false:" << false_graph_output_abs->ToString(); - } - return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, - switch_graph, cond); -} -} // namespace internal -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/optimizer/irpass/branch_culling.h deleted file mode 100644 index 2b5b30bdbf..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.h +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ - -#include -#include - -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" -#include "ir/pattern_matcher.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimSwitch, true, X, Y} -// {prim::kPrimSwitch, false, X, Y} -class SwitchSimplify : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode cond, true_br, false_br; - auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { - auto cond_value_ = GetValue(GetValueNode(cond.GetNode(node))); - if (cond_value_) { - return true_br.GetNode(node); - } - return false_br.GetNode(node); - }; - - MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda, - cond.CheckFunc(IsValueNode, node)); - - return nullptr; - } -}; - -// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => -// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} -class FloatTupleGetItemSwitch : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode cond, true_br, false_br, x; - MATCH_REPLACE_IF(node, - PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), - PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x), - PPrimitive(prim::kPrimTupleGetItem, false_br, x)), - x.CheckFunc(IsVNode, node)); - return nullptr; - } -}; - -// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => -// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} -class FloatEnvGetItemSwitch : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode cond, true_br, false_br, x, x2; - MATCH_REPLACE(node, - PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), - PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), - PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2))); - - return nullptr; - } -}; - -namespace internal { -FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); -FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); -AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, - const AbstractBasePtr &true_graph_output_abs, - const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, - const FuncGraphPtr &func_graph); -} // namespace internal - -// {{prim::kPrimSwitch, X, G1, G2}, Xs} -class ConvertSwitchReplacement : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto cnode_ = node->cast(); - if (cnode_->size() < 1) { - return nullptr; - } - - auto node_ = cnode_->input(0); - - PatternNode cond, true_br, false_br; - - auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { - auto g1_ = GetValueNode(true_br.GetNode(node_)); - auto g2_ = GetValueNode(false_br.GetNode(node_)); - auto x_ = cond.GetNode(node_); - - // for switch replace method, only graphs without graph inside can be replaced - for (auto &item : g1_->value_nodes()) { - auto value_node = item.first; - if (IsValueNode(value_node)) { - return nullptr; - } - } - - for (auto &item : g2_->value_nodes()) { - auto value_node = item.first; - if (IsValueNode(value_node)) { - return nullptr; - } - } - - auto true_output = g1_->output()->abstract(); - auto false_output = g2_->output()->abstract(); - auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_); - auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); - - std::vector params; - auto fg = node_->func_graph(); - auto cloned_g1 = InlineClone(trans_g1, fg, params); - auto cloned_g2 = InlineClone(trans_g2, fg, params); - auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); - - return nnode; - }; - - MATCH_REPLACE_LAMBDA_IF( - node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, - true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); - - return nullptr; - } -}; - -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc deleted file mode 100644 index a497f3d5bd..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/irpass/cast_eliminate.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "ir/func_graph.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/python_adapter.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimCast, X, T} -AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node); - - // check pattern match - if (tgt_ == nullptr) { - return nullptr; - } - - // src type check - auto src_type = src_->Type(); - if (src_type == nullptr || !src_type->isa()) { - return nullptr; - } - - src_type = src_type->cast()->element(); - - // tgt type check - auto tgt_type = GetValueNode(tgt_); - if (tgt_type->isa()) { - tgt_type = tgt_type->cast()->element(); - } - - if (src_type->type_id() == tgt_type->type_id()) { - return src_; - } - - return nullptr; -} - -void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { - if (src_ == nullptr) { - src_ = node; - } else { - tgt_ = node; - } -} - -// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} -AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node); - - if (x_ != nullptr && t_ != nullptr) { - auto cast_op = parse::python_adapter::GetPyFn("mindspore.ops.operations", "Cast")(); - ValuePtr cast = parse::data_converter::PyDataToValue(cast_op); - auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph()); - cnode->set_abstract(node->abstract()); - return cnode; - } - return nullptr; -} - -void TwoCastEliminater::Visit(const AnfNodePtr &node) { - if (IsPrimitiveCNode(node, prim::kPrimCast)) { - auto cnode = node->cast(); - // {prim::kPrimCast, X, Y} - if (cnode->size() != 3) { - return; - } - x_ = cnode->input(1); - } else { - t_ = node; - } -} -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h deleted file mode 100644 index d98d0b677b..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ - -#include "ir/visitor.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimCast, X, T} -class CastSameTypeEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Visit(const AnfNodePtr &node) override; - void Reset() { - src_ = nullptr; - tgt_ = nullptr; - } - - private: - AnfNodePtr src_{nullptr}, tgt_{nullptr}; -}; - -// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} -class TwoCastEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Visit(const AnfNodePtr &node) override; - void Reset() { - x_ = nullptr; - t_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, t_{nullptr}; -}; - -class CastEliminater : public OptimizerCaller { - public: - CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} - ~CastEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - auto new_node = cast_same_type_eliminater_(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - - new_node = two_cast_eliminater_(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - - return nullptr; - } - - private: - CastSameTypeEliminater cast_same_type_eliminater_; - TwoCastEliminater two_cast_eliminater_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/convert.h b/mindspore/ccsrc/optimizer/irpass/convert.h deleted file mode 100644 index 3049bafb1e..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/convert.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ - -#include - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimPrint, Xs} -> {prim::kPrimPrint, {prim::kPrinMakeTuple, Xs}} -class PrintTupleWrapper : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!IsPrimitiveCNode(node, prim::kPrimPrint)) { - return nullptr; - } - - // already be {prim::kPrimPrint, {prim::kPrinMakeTuple, Xs}} - auto cnode = node->cast(); - if (cnode->size() == 2 && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) { - return nullptr; - } - - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - - // {prim::kPrimPrint, Xs} - auto &inputs = cnode->inputs(); - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - - // {prim::kPrinMakeTuple, Xs} - auto fg = node->func_graph(); - auto tuple = NewCNode(args, fg); - auto print = GetValueNode(cnode->input(0)); - return NewCNode({NewValueNode(print), tuple}, fg); - } -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h deleted file mode 100644 index 3f100dcaec..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h +++ /dev/null @@ -1,364 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ - -#include -#include -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "utils/symbolic.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -class EnvGetitemTransform { - public: - EnvGetitemTransform() : cache_() {} - ~EnvGetitemTransform() = default; - - FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { - if (cache_.find(fg) == cache_.end()) { - cache_[fg] = {}; - } - - auto &cache = cache_[fg]; - auto hash_key = std::make_pair(key, default_node); - if (cache.find(hash_key) == cache.end()) { - std::ostringstream ss("env", std::ostringstream::app); - if (key->node() != nullptr) { - ss << key->node()->ToString(); - } - - auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); - auto env = new_fg->output(); - while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { - // {prim::kPrimEnvSetItem, env, symbolickey, value} - auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4 || !IsValueNode(inputs[2])) { - MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance."; - } - - env = inputs[1]; - auto value = inputs[3]; - auto key2 = GetValueNode(inputs[2]); - if (*key2 == *key) { - new_fg->set_output(value); - cache[hash_key] = new_fg; - cache_[fg] = cache; - return new_fg; - } - } - new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); - cache[hash_key] = new_fg; - } - - return cache[hash_key]; - } - - private: - std::unordered_map, FuncGraphPtr, PairHasher>> - cache_; -}; -} // namespace internal - -// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y -class NewEnvGetItem : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - auto gety = [this](const AnfNodePtr &node) -> bool { - this->y_ = node; - return true; - }; - - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode, IsVNode, gety})(node); - if (env_ != nullptr && env_->Len() == 0) { - return y_; - } - return nullptr; - } - - void Visit(const ValueNodePtr &vnode) override { - if (env_ == nullptr) { - env_ = GetValueNode(vnode); - } - } - - void Reset() { - y_ = nullptr; - env_ = nullptr; - } - - private: - AnfNodePtr y_{nullptr}; - EnvInstancePtr env_{nullptr}; -}; - -// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> -// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}} -class AddEnvGetItem : public AnfVisitor { - public: - AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} - ~AddEnvGetItem() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - auto IsAddCNode = [](const AnfNodePtr &node) -> bool { - return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast()->size() == 3; - }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node); - - if (!is_match_ || node->func_graph() == nullptr) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, C, Z} - auto cnode = node->cast(); - auto inp1 = cnode->input(1)->cast(); - auto c = cnode->input(2); - auto z = cnode->input(3); - - // {prim::kPrimEnvAdd, X, Y} - auto x = inp1->input(1); - auto y = inp1->input(2); - - auto fg = node->func_graph(); - auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z}); - auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z}); - - return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz}); - } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; - ValuePtr PrimHyperAdd_; -}; - -// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z} -class EnvGetSetItem : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - auto IsSetCNode = [](const AnfNodePtr &node) -> bool { - if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) { - return false; - } - - // {prim::kPrimEnvSetItem, X, C1, Y} - auto &inputs = node->cast()->inputs(); - if (inputs.size() != 4) { - return false; - } - - return IsValueNode(inputs[2]); - }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode, IsNode})(node); - - if (!is_match_ || node->func_graph() == nullptr) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, C2, Z} - auto cnode = node->cast(); - auto inp1 = cnode->input(1)->cast(); - auto key2 = cnode->input(2); - auto c2 = GetValueNode(key2); - auto default_v = cnode->input(3); - - // {prim::kPrimEnvSetItem, X, C1, Y} - auto env = inp1->input(1); - auto c1 = GetValueNode(inp1->input(2)); - auto last_set = inp1->input(3); - - if (*c1 == *c2) { - return last_set; - } - - while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { - // {prim::kPrimEnvSetItem, env, symbolickey, value} - auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4 || !IsValueNode(inputs[2])) { - MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance."; - } - - env = inputs[1]; - last_set = inputs[3]; - auto symbolic_c1 = GetValueNode(inputs[2]); - if (*symbolic_c1 == *c2) { - return last_set; - } - } - - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v}); - } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; -}; - -class EnvGetItemEliminater : public OptimizerCaller { - public: - EnvGetItemEliminater() - : new_env_get_item_(std::make_shared()), - add_env_get_item_(std::make_shared()), - env_get_set_item_(std::make_shared()) { - eliminaters_.emplace_back(new_env_get_item_); - eliminaters_.emplace_back(add_env_get_item_); - eliminaters_.emplace_back(env_get_set_item_); - } - ~EnvGetItemEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; - std::vector eliminaters_{}; -}; - -// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} -class IncorporateEnvGetitem : public AnfVisitor { - public: - IncorporateEnvGetitem() : env_get_item_transform_() {} - ~IncorporateEnvGetitem() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - auto IsGCNode = [](const AnfNodePtr &node) -> bool { - auto cnode = node->cast(); - if (cnode == nullptr || cnode->size() < 1) { - return false; - } - return IsValueNode(cnode->input(0)); - }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode, IsNode})(node); - - if (!is_match_) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, C, Y} - auto cnode = node->cast(); - auto inp1 = cnode->input(1)->cast(); - auto key = GetValueNode(cnode->input(2)); - auto default_v = cnode->input(3); - - // {G, Xs} - auto inputs = inp1->inputs(); - auto fg = GetValueNode(inputs[0]); - auto new_fg = env_get_item_transform_(fg, key, default_v); - - std::vector args; - args.push_back(NewValueNode(new_fg)); - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - - return node->func_graph()->NewCNode(args); - } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; - internal::EnvGetitemTransform env_get_item_transform_; -}; - -// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} -class IncorporateEnvGetitemSwitch : public AnfVisitor { - public: - IncorporateEnvGetitemSwitch() : env_get_item_transform_() {} - ~IncorporateEnvGetitemSwitch() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - auto IsSwNode = [](const AnfNodePtr &node) -> bool { - auto cnode = node->cast(); - if (cnode == nullptr || cnode->size() < 1) { - return false; - } - - return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch); - }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode, IsNode})(node); - if (!is_match_ || node->func_graph() == nullptr) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, C, Y} - auto cnode = node->cast(); - auto inp1 = cnode->input(1)->cast(); - auto key = GetValueNode(cnode->input(2)); - auto default_v = cnode->input(3); - - // {{prim::kPrimSwitch, X, G1, G2}, Xs} - auto inputs = inp1->inputs(); - is_match_ = false; - AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs[0]); - if (!is_match_) { - return nullptr; - } - - // {prim::kPrimSwitch, X, G1, G2} - auto sw = inputs[0]->cast(); - auto x = sw->input(1); - auto g1 = GetValueNode(sw->input(2)); - auto g2 = GetValueNode(sw->input(3)); - auto new_g1 = env_get_item_transform_(g1, key, default_v); - auto new_g2 = env_get_item_transform_(g2, key, default_v); - - auto fg = node->func_graph(); - auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)}); - - std::vector args{new_sw}; - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - - return fg->NewCNode(args); - } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; - internal::EnvGetitemTransform env_get_item_transform_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc deleted file mode 100644 index 317d67e792..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc +++ /dev/null @@ -1,143 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/irpass/grad_var_prepare.h" -#include -#include -#include -#include - -#include "operator/composite/composite.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace opt { -namespace irpass { -static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, FuncGraphPtr func_graph, - AnfNodePtr func_node, bool is_unpack, bool sens_param) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(func_node); - std::vector nodes; - AnfNodePtr unpack_graph_node = nullptr; - if (is_unpack) { - auto unpack_graph = std::make_shared("unpack_graph", sens_param, true); - nodes.push_back(NewValueNode(unpack_graph)); - nodes.push_back(func_node); - // {unpackcall, {GradOperation, ...}, args...} - std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr &node) { return node; }); - unpack_graph_node = func_graph->NewCNode(nodes); - } else { - auto unpack_graph = std::make_shared("unpack_graph", sens_param, false); - nodes.push_back(NewValueNode(unpack_graph)); - nodes.push_back(func_node); - // {{GradOperation, ...}, args...} - std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr &node) { return node; }); - unpack_graph_node = func_graph->NewCNode(nodes); - } - return unpack_graph_node; -} - -// get metagraph of value node -MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { - ValuePtr value; - if (IsValueNode(node)) { - value = GetValueNode(node)->cast()->function(); - } else { - value = GetValueNode(node); - } - if (value == nullptr) { - return nullptr; - } - return value->cast(); -} - -// check if node is a specific metafuncgraph op -bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { - if (node != nullptr) { - auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); - if (meta_func_graph_ptr == nullptr) { - return false; - } - - if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) { - return true; - } - } - return false; -} - -// {{GradOperation, g, w}, Ys} -// {UnPackCall, {GradOperation, g, w}, Ys} -AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - // {{...}, Ys} - auto inputs_y = node->cast()->inputs(); - std::vector inputs_x; - if (IsCNode(inputs_y[0])) { - inputs_x = inputs_y[0]->cast()->inputs(); - } else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) { - inputs_x = inputs_y[1]->cast()->inputs(); - } else { - return nullptr; - } - - // {{...}, Xs} - if (inputs_x.size() < 2) { - return nullptr; - } - - // {GradOperation, g, w} or {GradOperation, g} - if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) { - return nullptr; - } - - auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]); - if (meta_func == nullptr) { - return nullptr; - } - auto grad_op_ptr = meta_func->cast(); - auto func_node = inputs_x[1]; - if (!IsValueNode(func_node)) { - return nullptr; - } - - AnfNodePtr unpack_graph_node = - GenerateUnpackGraphNode(inputs_y, node->cast()->func_graph(), func_node, - IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param()); - // constuct new grad_opration - inputs_x[1] = unpack_graph_node; - auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x); - if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) { - inputs_y[1] = grad_op_cnode; - } else { - inputs_y[0] = grad_op_cnode; - } - auto cnode = node->func_graph()->NewCNode(inputs_y); - return cnode; -} -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h deleted file mode 100644 index 9713017d12..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ - -#include -#include -#include -#include - -#include "operator/composite/composite.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {{GradOperation, g, w}, Ys} -// {UnPackCall, {GradOperation, g, w}, Ys} -class GradVarPrepare : public AnfVisitor { - public: - GradVarPrepare() - : grad_op_(std::make_shared("grad")), - unpack_op_(std::make_shared("unpack_call")) {} - ~GradVarPrepare() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - private: - MetaFuncGraphPtr grad_op_; - MetaFuncGraphPtr unpack_op_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.cc b/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.cc deleted file mode 100644 index 3347fa9dc0..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/irpass/gradient_eliminate.h" - -#include - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { - ScopeGuard scope_guard(vnode->scope()); - - auto newg = ad::Kprim(vnode, resource); - if (newg != nullptr) { - return NewValueNode(newg); - } - - // when find in J failed, try in Jmeta - auto prim = GetValueNode(vnode); - MetaFuncGraphPtr meta = ad::Kmeta(prim, resource); - if (meta != nullptr) { - return NewValueNode(meta); - } - - return nullptr; -} - -bool CheckIfEmbedJFuncGraph(const FuncGraphPtr func_graph) { - // if func graph also contain J FuncGraph, then ignore this funcgraph. ExpandJ innermost graph first; - auto func_graph_manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(func_graph_manager); - return func_graph_manager->func_graph_j_total(func_graph); -} - -AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { - if (IsValueNode(vnode)) { - ScopeGuard scope_guard(vnode->scope()); - - auto func_graph = GetValueNode(vnode); - MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString(); - - // high_order_grad begin; - // if graph also contain J Graph, then ignore this graph. ExpandJ innermost graph first; - if (CheckIfEmbedJFuncGraph(func_graph)) { - MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J(funcgraph), will expandJ later"; - return nullptr; - } - // high_order_grad end; - - MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now"; - auto newfg = ad::Grad(func_graph, resource); - return NewValueNode(newfg); - } - - if (IsValueNode(vnode)) { - return ExpandJPrimitive(vnode, resource); - } - - return nullptr; -} -} // namespace internal -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h deleted file mode 100644 index 671d9bde49..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ - -#include -#include -#include - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "common/utils.h" -#include "operator/ops.h" -#include "optimizer/ad/grad.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource); -} // namespace internal - -// {prim::kPrimJ, C} -class ExpandJPrim : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node); - if (x_ != nullptr) { - TraceManager::DebugTrace(std::make_shared(node->debug_info())); - auto j_node = internal::ExpandJ(x_, optimizer->resource()); - TraceManager::EndTrace(); - return j_node; - } - return nullptr; - } - - void Visit(const ValueNodePtr &node) override { x_ = node; } - - private: - ValueNodePtr x_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/incorporate_call.h b/mindspore/ccsrc/optimizer/irpass/incorporate_call.h deleted file mode 100644 index 5842b7bfd6..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/incorporate_call.h +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ - -#include -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -class CallOutputTransform { - public: - CallOutputTransform() : cache_() {} - ~CallOutputTransform() = default; - - FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs) { - if (cache_.find(fg) == cache_.end()) { - cache_[fg] = {}; - } - - auto &cache = cache_[fg]; - if (cache.find(nargs) == cache.end()) { - FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("call")); - - std::vector new_items; - new_items.push_back(new_fg->output()); - for (size_t i = 0; i < nargs; i++) { - new_items.push_back(new_fg->add_parameter()); - } - new_fg->set_output(new_fg->NewCNode(new_items)); - - cache[nargs] = new_fg; - } - return cache[nargs]; - } - - private: - std::unordered_map> cache_; -}; -} // namespace internal - -// {{G, Xs}, Ys} -class IncorporateCall : public AnfVisitor { - public: - IncorporateCall() : call_output_transform_() {} - ~IncorporateCall() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (inputs[0] == nullptr || !inputs[0]->isa()) { - return nullptr; - } - - AnfVisitor::Visit(inputs[0]); - if (fg_ == nullptr) { - return nullptr; - } - - auto xs_size = Xs_.size(); - auto ys_size = inputs.size() - 1; - auto new_fg = call_output_transform_(fg_, ys_size); - - std::vector args; - args.push_back(NewValueNode(new_fg)); - - if (xs_size > 0) { - (void)args.insert(args.end(), Xs_.begin(), Xs_.end()); - } - - if (ys_size > 0) { - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - } - - return node->func_graph()->NewCNode(args); - } - - void Visit(const CNodePtr &cnode) override { - // {G, Xs} - if (cnode->size() < 1 || !IsValueNode(cnode->input(0))) { - return; - } - - auto &inputs = cnode->inputs(); - fg_ = GetValueNode(inputs[0]); - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); - } - - void Reset() { - Xs_.clear(); - fg_ = nullptr; - } - - private: - FuncGraphPtr fg_; - std::vector Xs_{}; - internal::CallOutputTransform call_output_transform_; -}; - -// {{{prim::kPrimSwitch, X, G1, G2}, Xs}, Ys} -class IncorporateCallSwitch : public AnfVisitor { - public: - IncorporateCallSwitch() : call_output_transform_() {} - ~IncorporateCallSwitch() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - // {{...}, Ys} - auto &inputs = node->cast()->inputs(); - if (inputs[0] == nullptr || !inputs[0]->isa()) { - return nullptr; - } - - // {{...}, Xs} - auto &inputs_x = inputs[0]->cast()->inputs(); - if (inputs_x[0] == nullptr || !inputs_x[0]->isa()) { - return nullptr; - } - - // {prim::kPrimSwitch, X, G1, G2} - AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs_x[0]); - if (g2_ == nullptr) { - return nullptr; - } - - auto fg = node->func_graph(); - auto xs_size = inputs_x.size() - 1; - auto ys_size = inputs.size() - 1; - auto new_g1 = call_output_transform_(g1_, ys_size); - auto new_g2 = call_output_transform_(g2_, ys_size); - auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); - - std::vector args{sw_node}; - if (xs_size > 0) { - (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end()); - } - if (ys_size > 0) { - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - } - - return fg->NewCNode(args); - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - return; - } - AnfVisitor::Visit(node); - } - - void Visit(const ValueNodePtr &vnode) override { - auto g = GetValueNode(vnode); - if (g1_ == nullptr) { - g1_ = g; - } else { - g2_ = g; - } - } - - void Reset() { - x_ = nullptr; - g1_ = nullptr; - g2_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}; - FuncGraphPtr g1_{nullptr}, g2_{nullptr}; - internal::CallOutputTransform call_output_transform_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h deleted file mode 100644 index b6c8fb0e18..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h +++ /dev/null @@ -1,416 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ - -#include -#include -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -class GetitemTransform { - public: - GetitemTransform() : cache_() {} - ~GetitemTransform() = default; - - FuncGraphPtr operator()(const FuncGraphPtr &fg, int idx) { - if (cache_.find(fg) == cache_.end()) { - cache_[fg] = {}; - } - - auto &cache = cache_[fg]; - if (cache.find(idx) == cache.end()) { - std::ostringstream ss("tp", std::ostringstream::app); - ss << idx; - - auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); - auto output = new_fg->output(); - if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { - auto cnode = output->cast(); - auto ids = IntToSize(idx + 1); - // Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1. - if (ids >= cnode->size()) { - MS_LOG(EXCEPTION) << "index " << ids << " is out of inputs length " << cnode->size(); - } - new_fg->set_output(cnode->input(ids)); - } else { - new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(idx)})); - } - - cache[idx] = new_fg; - } - return cache[idx]; - } - - private: - std::unordered_map> cache_; -}; -} // namespace internal - -// {prim::kPrimTupleGetItem, {G, Xs}, C} -class IncorporateGetitem : public AnfVisitor { - public: - IncorporateGetitem() : getitem_transform_() {} - ~IncorporateGetitem() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); - if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) { - return nullptr; - } - - if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - // If graph kernel has muti output, do not split. - // some graph kernel output has EnvInstance node or DeadCode node should split. - auto output = fg_->output(); - if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { - auto output_cnode = output->cast(); - auto outputs = output_cnode->inputs(); - int real_output_cnt = 0; - for (size_t i = 1; i < outputs.size(); ++i) { - if (IsCNode(outputs[i]) || IsValueNode(outputs[i]) || IsParam(outputs[i])) { - real_output_cnt++; - if (real_output_cnt > 1) { - return nullptr; - } - } - } - } - } - - auto new_fg = getitem_transform_(fg_, idx_); - (void)args_.insert(args_.begin(), NewValueNode(new_fg)); - return node->func_graph()->NewCNode(args_); - } - - void Visit(const CNodePtr &cnode) override { - if (cnode->size() == 0 || !IsValueNode(cnode->input(0))) { - return; - } - - auto &inputs = cnode->inputs(); - fg_ = GetValueNode(inputs[0]); - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); - } - - void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } - - void Reset() { - idx_ = -1; - fg_ = nullptr; - args_.clear(); - } - - private: - int idx_{-1}; - FuncGraphPtr fg_{nullptr}; - std::vector args_{}; - internal::GetitemTransform getitem_transform_; -}; - -class IncorporateGetitemFromParam : public AnfVisitor { - public: - void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr ¶m, size_t input_idx) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto &node_users = mng->node_users(); - if (node_users.find(param) == node_users.end() || node_users[param].empty()) { - args_.push_back(cnode->input(input_idx + 1)); - return; - } - - for (auto &user : node_users[param]) { - if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { - // we do not process this case. - args_.push_back(cnode->input(input_idx + 1)); - return; - } - } - - // update new args. - if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) { - // case 1 - replace_parameters_[input_idx] = true; - need_update_ = true; - auto make_tuple_cnode = cnode->input(input_idx + 1)->cast(); - auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs(); - inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1; - args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end()); - } else { - // case 2 - auto prev_cnode = cnode->input(input_idx + 1)->cast(); - auto prev_fg = GetValueNode(prev_cnode->input(0)); - auto fg_output = prev_fg->output(); - if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) { - MS_LOG(ERROR) << "The return of: " << prev_fg->ToString() - << " should be a make tuple, but got: " << fg_output->DebugString(); - return; - } - replace_parameters_[input_idx] = true; - need_update_ = true; - auto make_tuple_cnode = fg_output->cast(); - inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1; - for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) { - auto new_getitem = - func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))}); - auto aptr = std::make_shared(std::make_shared(SizeToInt(output_i))); - new_getitem->input(2)->set_abstract(aptr); - new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract()); - args_.push_back(new_getitem); - } - } - } - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (node->func_graph() == nullptr) { - return nullptr; - } - - Reset(); - - auto cnode = node->cast(); - if (cnode == nullptr) { - return nullptr; - } - auto &inputs = cnode->inputs(); - auto fg = GetValueNode(inputs[0]); - if (fg == nullptr) { - return nullptr; - } - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto parameters = fg->parameters(); - if (parameters.size() != inputs.size() - 1) { - return nullptr; - } - replace_parameters_ = std::vector(parameters.size(), false); - inputs_num_ = std::vector(parameters.size(), 1); - auto node_fg = node->func_graph(); - - for (size_t i = 1; i < inputs.size(); ++i) { - if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) { - Process(node_fg, cnode, parameters[i - 1], i - 1); - } else { - args_.push_back(inputs[i]); - } - } - - if (!need_update_) { - return nullptr; - } - - FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); - mng->AddFuncGraph(new_fg); - - auto node_users = mng->node_users(); - std::vector new_fg_parameters = new_fg->parameters(); - std::vector new_parameters; - size_t curr_input_idx{0}; - for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) { - if (!replace_parameters_[param_i]) { - if (parameters[param_i]->abstract() != nullptr) { - new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract()); - } - new_parameters.push_back(new_fg_parameters[param_i]); - curr_input_idx++; - continue; - } - - // make a new parameter. - for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) { - auto new_param = std::make_shared(new_fg); - new_param->set_abstract(args_.at(curr_input_idx)->abstract()); - - // update users of new parameter. - for (auto &user : node_users[new_fg_parameters[param_i]]) { - idx_ = -1; - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode})(user.first); - if (idx_ == -1) { - MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString() - << " must be tuple getitem here, but got: " << user.first->DebugString(); - return nullptr; - } - - if (input_i == IntToSize(idx_)) { - for (auto &sub_user : node_users[user.first]) { - auto sub_user_cnode = sub_user.first->cast(); - MS_EXCEPTION_IF_NULL(sub_user_cnode); - sub_user_cnode->set_input(sub_user.second, new_param); - (void)mng->Replace(sub_user.first, sub_user_cnode); - } - } - } - - // (void)mng->Replace(new_fg_parameters[param_i], new_param); - new_parameters.push_back(new_param); - curr_input_idx++; - } - } - - mng->SetParameters(new_fg, new_parameters); - (void)args_.insert(args_.begin(), NewValueNode(new_fg)); - auto new_call = node_fg->NewCNode(args_); - new_call->set_abstract(node->abstract()); - return new_call; - } - - void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } - - void Visit(const CNodePtr &cnode) override {} - - void Reset() { - replace_parameters_.clear(); - args_.clear(); - inputs_num_.clear(); - need_update_ = false; - idx_ = -1; - } - - private: - std::vector replace_parameters_{}; - std::vector args_{}; - std::vector inputs_num_{}; - bool need_update_{false}; - int idx_{-1}; -}; - -// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} -class IncorporateGetitemSwitch : public AnfVisitor { - public: - IncorporateGetitemSwitch() : getitem_transform_() {} - ~IncorporateGetitemSwitch() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - is_in_get_ = true; - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); - is_in_get_ = false; - - auto fg = node->func_graph(); - if (idx_ == -1 || switch_ == nullptr || fg == nullptr) { - return nullptr; - } - - is_in_switch_ = true; - AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(switch_); - is_in_switch_ = false; - - if (g2_ == nullptr) { - return nullptr; - } - - auto new_g1 = getitem_transform_(g1_, idx_); - auto new_g2 = getitem_transform_(g2_, idx_); - auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); - (void)args_.insert(args_.begin(), sw_node); - - return fg->NewCNode(args_); - } - - void Visit(const AnfNodePtr &node) override { - if (is_in_switch_ && x_ == nullptr) { - x_ = node; - return; - } - AnfVisitor::Visit(node); - } - - void Visit(const CNodePtr &cnode) override { - if (is_in_get_ && cnode->size() != 0) { - auto &inputs = cnode->inputs(); - switch_ = inputs[0]; - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (is_in_get_) { - idx_ = GetValue(vnode->value()); - } - - if (is_in_switch_) { - auto g = GetValueNode(vnode); - if (g1_ == nullptr) { - g1_ = g; - } else { - g2_ = g; - } - } - } - - void Reset() { - x_ = nullptr; - g1_ = nullptr; - g2_ = nullptr; - switch_ = nullptr; - args_.clear(); - is_in_get_ = false; - is_in_switch_ = false; - } - - private: - int idx_{-1}; - AnfNodePtr switch_{nullptr}, x_{nullptr}; - FuncGraphPtr g1_{nullptr}, g2_{nullptr}; - bool is_in_get_{false}, is_in_switch_{false}; - std::vector args_{}; - internal::GetitemTransform getitem_transform_; -}; - -class IncorporateGetitemSet : public OptimizerCaller { - public: - IncorporateGetitemSet() - : incorporate_getitem_(std::make_shared()), - incorporate_getitem_switch_(std::make_shared()) { - eliminaters_.emplace_back(incorporate_getitem_); - eliminaters_.emplace_back(incorporate_getitem_switch_); - } - ~IncorporateGetitemSet() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; - std::vector eliminaters_{}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h b/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h deleted file mode 100644 index 630d567549..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ - -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}} -// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}} -// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}} -class IndexedSlicesEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(1); - } - AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(2); - } - AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(3); - } - return nullptr; - } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) { - tuple_ = cnode; - is_match_ = true; - } - } - - void Reset() { - tuple_ = nullptr; - is_match_ = false; - } - - private: - bool is_match_{false}; - CNodePtr tuple_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/inline.h b/mindspore/ccsrc/optimizer/irpass/inline.h deleted file mode 100644 index 64f192347c..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/inline.h +++ /dev/null @@ -1,204 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ - -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -class ReplaceApplicator : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!IsValueNode(node)) { - return nullptr; - } - - auto fg = GetValueNode(node); - if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { - return nullptr; - } - - auto out = fg->output(); - MS_EXCEPTION_IF_NULL(out); - if (!out->isa()) { - return nullptr; - } - - auto &inputs = out->cast()->inputs(); - auto params = fg->parameters(); - - // Exclude first elements of inputs which is fn. - auto input_size = inputs.size(); - auto param_size = params.size(); - if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size && - std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) { - auto inner = inputs[0]; - if (IsValueNode(inner) || - (IsValueNode(inner) && GetValueNode(inner)->parent() == nullptr)) { - return inner; - } - } - - return nullptr; - } -}; - -using CriterionFuncType = std::function; - -bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { - auto n_cnode = fg->nodes().size() - fg->parameters().size(); - // There is at least one CNode(return, other_node). - return n_cnode <= 2; -} - -bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { - auto &cnodes = fg->func_graph_cnodes_index(); - int n_use = - std::accumulate(cnodes.begin(), cnodes.end(), 0, - [](int sum, const std::pair &item) { return sum + item.second; }); - return n_use == 1; -} - -bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node->func_graph()); - return node->func_graph()->has_flag("inline_inside"); -} - -bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } - -bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } - -// {G, Xs} -class InlinerBase : public AnfVisitor { - public: - explicit InlinerBase(std::vector> criterions) : criterions_(criterions) {} - ~InlinerBase() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa()) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 1 || !IsValueNode(inputs[0])) { - return nullptr; - } - - // G - auto fg = GetValueNode(inputs[0]); - if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { - return nullptr; - } - // Do not inline GraphKernel to Cell. - if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - // If the GraphKernel only contains a return node, we make it inlined. - if (fg->nodes().size() - fg->parameters().size() > 1) { - return nullptr; - } - } - - Reset(); - bool is_match = false; - for (auto &criterion : criterions_) { - if (!criterion.first(fg, node)) { - continue; - } - - if (criterion.second && IsRecursive(fg)) { - continue; - } - - is_match = true; - break; - } - - if (!is_match) { - return nullptr; - } - - std::vector params; - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params)); - - if (IsUniqueUse(fg, nullptr)) { - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - ReplaceParams(mng, params, fg); - auto out_node = fg->output(); - mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); - return out_node; - } - - return InlineClone(fg, node->func_graph(), params, inputs[0]->scope()); - } - - void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector &new_params, - const FuncGraphPtr &fg) { - auto params = fg->parameters(); - auto old_size = params.size(); - if (old_size != new_params.size()) { - MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size() - << fg->output()->DebugString(10); - } - for (size_t i = 0; i < old_size; i++) { - (void)mng->Replace(params[i], new_params[i]); - } - } - - bool IsRecursive(const FuncGraphPtr &fg) { - if (!is_checked_) { - is_checked_ = true; - is_recursive_ = fg->recursive(); - } - return is_recursive_; - } - - void Reset() { - is_checked_ = false; - is_recursive_ = false; - } - - private: - bool is_checked_{false}, is_recursive_{false}; - std::vector> criterions_; -}; - -class Inliner : public InlinerBase { - public: - Inliner() - : InlinerBase({ - {IsUniqueUse, true}, - {IsTrivial, false}, - {IsInside, false}, - {IsCore, false}, - {NoCriterion, true}, - }) {} - ~Inliner() override = default; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h deleted file mode 100644 index 202951a254..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h +++ /dev/null @@ -1,301 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ - -#include -#include -#include - -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// (a, b, c, ...)[0] => a -// (a, b, c, ...)[1] => b -// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} -class GetitemEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); - - if (is_match_) { - return tuple_->input(id_); - } - return nullptr; - } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { - tuple_ = cnode; - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (tuple_ != nullptr && IsValueNode(vnode)) { - id_ = IntToSize(GetValue(vnode->value()) + 1); - if (tuple_->size() > id_) { - is_match_ = true; - } - } - } - - void Reset() { - id_ = 0; - tuple_ = nullptr; - is_match_ = false; - } - - private: - bool is_match_{false}; - size_t id_{0}; - CNodePtr tuple_{nullptr}; -}; - -// (a, b, c, ...)[0] => a -// (a, b, c, ...)[1] => b -// {prim::kPrimTupleGetItem, C1, C} -class GetitemConstEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node); - - if (is_match_) { - return NewValueNode((*tuple_)[id_]); - } - return nullptr; - } - - void Visit(const ValueNodePtr &vnode) override { - if (IsValueNode(vnode)) { - tuple_ = GetValueNode(vnode); - } - if (tuple_ != nullptr && IsValueNode(vnode)) { - id_ = IntToSize(GetValue(vnode->value())); - if (tuple_->size() > id_) { - is_match_ = true; - } - } - } - - void Reset() { - id_ = 0; - tuple_ = nullptr; - is_match_ = false; - } - - private: - bool is_match_{false}; - size_t id_{0}; - ValueTuplePtr tuple_{nullptr}; -}; - -// setitem((a, b, c, ...), 0, z) => (z, b, c, ...) -// setitem((a, b, c, ...), 1, z) => (a, z, c, ...) -// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} -class SetitemEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); - - auto fg = node->func_graph(); - if (fg != nullptr && z_ != nullptr) { - args_[id_] = z_; - return fg->NewCNode(args_); - } - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (is_match_) { - z_ = node; - return; - } - - AnfVisitor::Visit(node); - } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { - auto &inputs = cnode->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(args_)); - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (args_.size() > 0 && IsValueNode(vnode)) { - id_ = IntToSize(GetValue(vnode->value()) + 1); - if (id_ < args_.size()) { - is_match_ = true; - } - } - } - - void Reset() { - id_ = 0; - z_ = nullptr; - is_match_ = false; - args_.clear(); - } - - private: - bool is_match_{false}; - size_t id_{0}; - AnfNodePtr z_{nullptr}; - std::vector args_{}; -}; - -// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} -class GetSetitemEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); - - auto fg = node->func_graph(); - if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { - if (key1_ == key2_) { - return last_; - } - return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_, c2_}); - } - return nullptr; - } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) { - if (cnode->size() < 4) { - return; - } - - tuple_ = cnode->input(1); - last_ = cnode->input(3); - - // key of setitem - is_in_set_ = true; - AnfVisitor::Visit(cnode->input(2)); - is_in_set_ = false; - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (IsValueNode(vnode)) { - auto key = GetValue(vnode->value()); - if (is_in_set_) { - key1_ = key; - } else { - c2_ = vnode; - key2_ = key; - } - } - } - - void Reset() { - key1_ = -1; - key2_ = -1; - c2_ = nullptr; - last_ = nullptr; - tuple_ = nullptr; - is_in_set_ = false; - } - - private: - bool is_in_set_{false}; - int key1_{-1}, key2_{-1}; - AnfNodePtr tuple_{nullptr}, last_{nullptr}, c2_{nullptr}; -}; - -// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> -// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} -class GetitemDependReorder : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); - if (x_ == nullptr) { - return nullptr; - } - - auto fg = node->func_graph(); - auto item_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), x_, c_}, fg); - return NewCNode({NewValueNode(prim::kPrimDepend), item_node, y_}, fg); - } - - void Visit(const CNodePtr &cnode) override { - // {prim::kPrimDepend, X, Y} - if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && cnode->size() == 3) { - x_ = cnode->input(1); - y_ = cnode->input(2); - } - } - - void Visit(const ValueNodePtr &vnode) override { c_ = vnode; } - - void Reset() { - x_ = nullptr; - y_ = nullptr; - c_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; -}; - -class ItemTupleEliminater : public OptimizerCaller { - public: - ItemTupleEliminater() - : get_item_eliminater_(std::make_shared()), - get_item_const_eliminater_(std::make_shared()), - set_item_eliminater_(std::make_shared()), - get_set_item_eliminater_(std::make_shared()), - get_item_depend_reorder_(std::make_shared()) { - eliminaters_.emplace_back(get_item_eliminater_); - eliminaters_.emplace_back(get_item_const_eliminater_); - eliminaters_.emplace_back(set_item_eliminater_); - eliminaters_.emplace_back(get_set_item_eliminater_); - eliminaters_.emplace_back(get_item_depend_reorder_); - } - ~ItemTupleEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, - get_item_depend_reorder_; - std::vector eliminaters_{}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/mark_interface_fusion.h b/mindspore/ccsrc/optimizer/irpass/mark_interface_fusion.h deleted file mode 100644 index 6f2bcc187f..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/mark_interface_fusion.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "utils/graph_utils.h" -#include "operator/composite/composite.h" - -namespace mindspore { -namespace opt { -namespace irpass { - -static int count = 0; - -std::string GetFusionNumber() { - std::stringstream ss; - ss << std::setw(4) << std::setfill('0') << count; - std::string num = ss.str(); - ++count; - - return "_" + num; -} - -// Mark CNodes which can be merged in kernel build -class MarkInterfaceFusion : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) { - auto cnode = node->cast(); - auto condition = cnode->input(1); - std::string cmp; - std::unordered_map cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"}, - {"LessEqual", "LE"}, {"Less", "LT"}, - {"Equal", "EQ"}, {"NotEqual", "NE"}}; - if (IsPrimitiveCNode(condition)) { - auto prim_name = GetCNodeFuncName(condition->cast()); - if (cmp_list.count(prim_name) != 0) { - // Mark Select and compare node - cmp = cmp_list[prim_name]; - auto cnt = GetFusionNumber(); - AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition); - AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) { - AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i)); - } - } - } - } - } - return nullptr; - } - - void Visit(const AnfNodePtr &) override {} - - private: - AnfNodePtr y_{nullptr}; -}; - -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H diff --git a/mindspore/ccsrc/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/optimizer/irpass/merge_addn.h deleted file mode 100644 index e1e4b8878b..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/merge_addn.h +++ /dev/null @@ -1,320 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ - -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} -> -// {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}} -// {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} -> -// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}} -class MergeAddN : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - Reset(); - optimizer_ = optimizer; - is_outer_ = true; - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); - if (!is_match_ || node->func_graph() == nullptr) { - return nullptr; - } - - auto cnode = node->cast(); - auto addn = NewValueNode(GetValueNode(cnode->input(0))); - - // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs} - (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple)); - auto fg = node->func_graph(); - auto make_node = fg->NewCNode(args_); - - return fg->NewCNode({addn, make_node}); - } - - void Visit(const CNodePtr &cnode) override { - if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { - return; - } - - auto &inputs = cnode->inputs(); - - if (is_outer_) { - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_)); - - is_outer_ = false; - is_inner_ = true; - - // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys} - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]); - if (is_match_) { - if (!is_unique(inputs[1])) { - is_match_ = false; - return; - } - (void)Ys_.erase(Ys_.begin()); - (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); - (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); - return; - } - - // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}} - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back()); - if (is_match_) { - if (!is_unique(inputs.back())) { - is_match_ = false; - return; - } - Ys_.pop_back(); - (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); - (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); - return; - } - - return; - } - - if (is_inner_) { - is_match_ = true; - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); - } - } - - bool is_unique(const AnfNodePtr &node) { - auto mng = optimizer_->resource()->manager(); - auto &node_users = mng->node_users(); - if (node_users.find(node) == node_users.end()) { - return false; - } - - size_t n_use = node_users[node].size(); - return n_use == 1; - } - - void Reset() { - Xs_.clear(); - Ys_.clear(); - args_.clear(); - is_inner_ = false; - is_outer_ = false; - is_match_ = false; - } - - private: - OptimizerPtr optimizer_{nullptr}; - std::vector Xs_{}, Ys_{}, args_{}; - bool is_inner_{false}, is_outer_{false}, is_match_{false}; -}; - -// {PrimAddN, {kPrimMakeTuple, Xs}} -class AddNZeroFilter : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); - - if (filtered_Xs_.empty() || node->func_graph() == nullptr) { - return nullptr; - } - - // if only two node in filtered_nodes, {make_tuple, x}. return x. - if (filtered_Xs_.size() == 2) { - return filtered_Xs_[1]; - } - - // if only one node in filtered_nodes, all node is zerolike, return one of the input. - if (filtered_Xs_.size() == 1 && Xs_.size() > 0) { - return Xs_[0]; - } - - if (!has_zero_like_) { - return nullptr; - } - - auto cnode = node->cast(); - auto addn = NewValueNode(GetValueNode(cnode->input(0))); - auto fg = node->func_graph(); - auto make_tuple = fg->NewCNode(filtered_Xs_); - return fg->NewCNode({addn, make_tuple}); - } - - void Visit(const CNodePtr &cnode) override { - if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { - return; - } - - auto &inputs = cnode->inputs(); - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); - - // {kPrimMakeTuple, X1, X2, ...} - filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (auto &x : Xs_) { - if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) { - filtered_Xs_.push_back(x); - } else { - has_zero_like_ = true; - } - } - } - - void Reset() { - Xs_.clear(); - filtered_Xs_.clear(); - has_zero_like_ = false; - } - - private: - std::vector filtered_Xs_{}, Xs_{}; - bool has_zero_like_{false}; -}; - -// {PrimAddN, {kPrimMakeTuple, Xs}} -// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd. -// case0: AddN(inputs)(inputs size < 2) -> error -// case1: AddN(inputs)(all inputs is ValueNode) -> error -// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor) -// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input) -// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) -class AddNEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - auto fg = GetValueNode(inputs[0]); - MS_EXCEPTION_IF_NULL(fg); - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - if (fg->recursive()) { - return nullptr; - } - - auto new_fg = TransformableClone(fg, std::make_shared("fg")); - mng->AddFuncGraph(new_fg); - need_update_ = false; - bool changed; - do { - changed = Process(new_fg); - } while (changed); - - if (!need_update_) { - return nullptr; - } else { - auto new_sx = inputs; - new_sx[0] = NewValueNode(new_fg); - return node->func_graph()->NewCNode(new_sx); - } - } - - bool Process(const FuncGraphPtr &func_graph) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto nodes = TopoSort(func_graph->output()); - bool changed = false; - - for (size_t i = 0; i < nodes.size(); ++i) { - auto node = nodes[i]; - if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { - continue; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto &tuple_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(tuple_input); - auto tuple_input_cnode = tuple_input->cast(); - MS_EXCEPTION_IF_NULL(tuple_input_cnode); - auto &tuple_inputs = tuple_input_cnode->inputs(); - if (tuple_inputs.size() < 3) { - // case0: inputs size < 2, error - MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2); - } - - int valuenode_num = - std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) { - if (IsValueNode(node)) { - return accumulator + 1; - } else { - return accumulator; - } - }); - if (IntToSize(valuenode_num) == tuple_inputs.size()) { - // case1: all inputs is ValueNode, error - MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. " << cnode->DebugString(2); - } - - if (tuple_inputs.size() == 3) { - // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) - MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); - ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); - std::vector new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], - tuple_inputs[2]}; - mng->Replace(node, func_graph->NewCNode(new_xs)); - changed = true; - continue; - } - - auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(), - [](const AnfNodePtr &node) { return IsValueNode(node); }); - if (first_valuenode == tuple_inputs.end()) { - // no ValueNode input found. - continue; - } else { - // case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) - std::vector make_tuple_new_xs{ - NewValueNode(prim::kPrimMakeTuple), - }; - std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(), - [&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) { - if (node != *first_valuenode) { - make_tuple_new_xs.push_back(node); - } - }); - ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); - auto new_addn = func_graph->NewCNode( - {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); - ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); - auto new_add = - func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); - (void)mng->Replace(node, new_add); - changed = true; - continue; - } - } - - need_update_ = need_update_ || changed; - return changed; - } - - private: - bool need_update_{false}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/minmax_grad.h b/mindspore/ccsrc/optimizer/irpass/minmax_grad.h deleted file mode 100644 index a426a9fb9b..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/minmax_grad.h +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ - -#include -#include - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -// check if node is MinimumGrad() or MaximumGrad() -bool IsOriginMaxMinGrad(const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimMaximumGrad) && !IsPrimitiveCNode(node, prim::kPrimMinimumGrad)) { - return false; - } - - auto cnode = node->cast(); - auto prim = GetValueNode(cnode->input(0)); - auto x_v = prim->GetAttr("grad_x"); - auto y_v = prim->GetAttr("grad_y"); - if (x_v == nullptr || y_v == nullptr || !x_v->isa() || !y_v->isa()) { - return false; - } - - bool x = GetValue(x_v); - bool y = GetValue(y_v); - return x && y; -} -} // namespace internal - -// {prim::kPrimTupleGetItem, {target_grad, Xs}, C} -class MinMaximumGrad : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {internal::IsOriginMaxMinGrad, IsValueNode})(node); - if (grad_ == nullptr || idx_ < 0 || idx_ > 1 || node->func_graph() == nullptr) { - return nullptr; - } - - // check single use - auto mng = optimizer->resource()->manager(); - auto &users = mng->node_users(); - if (users.find(grad_) == users.end() || users[grad_].size() != 1) { - return nullptr; - } - - // {target_grad, Xs} - auto &inputs = grad_->inputs(); - auto prim = GetValueNode(inputs[0]); - - auto new_prim = std::make_shared(prim->name()); - new_prim->set_attr("grad_x", MakeValue(true)); - new_prim->set_attr("grad_y", MakeValue(true)); - - if (idx_ == 0) { - new_prim->set_attr("grad_y", MakeValue(false)); - } - if (idx_ == 1) { - new_prim->set_attr("grad_x", MakeValue(false)); - } - - std::vector args; - args.push_back(NewValueNode(new_prim)); - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - - auto fg = node->func_graph(); - auto tuple = fg->NewCNode(args); - - return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple, NewValueNode(MakeValue(idx_))}); - } - - void Visit(const CNodePtr &cnode) override { grad_ = cnode; } - - void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } - - void Reset() { - idx_ = -1; - grad_ = nullptr; - } - - private: - int idx_{-1}; - CNodePtr grad_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/param_replace.h b/mindspore/ccsrc/optimizer/irpass/param_replace.h deleted file mode 100644 index c0c4c832d7..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/param_replace.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ - -#include - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "pipeline/parse/parse.h" - -namespace mindspore { -namespace opt { -namespace irpass { -class ReplaceOldParam : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - if (!IsParam(node)) { - return nullptr; - } - auto resource = std::dynamic_pointer_cast(optimizer->resource()); - MS_EXCEPTION_IF_NULL(resource); - - auto top_graph = resource->func_graph(); // parse::Parser::GetTopFuncGraph(); - MS_EXCEPTION_IF_NULL(top_graph); - - auto param_node = node->cast(); - if (!param_node->has_default() || node->func_graph() == top_graph) { - return nullptr; - } - auto para_name = param_node->name(); - for (const auto &tnode : top_graph->parameters()) { - auto para = tnode->cast(); - if (para != nullptr && para->name() == para_name) { - return para; - } - } - return nullptr; - } -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/partial_eliminate.h b/mindspore/ccsrc/optimizer/irpass/partial_eliminate.h deleted file mode 100644 index bc8ef9d8f3..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/partial_eliminate.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ - -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} -class PartialEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - Xs_.clear(); - auto &inputs = node->cast()->inputs(); - Visit(inputs[0]); - - if (Xs_.size() == 0) { - return nullptr; - } - - // {X, Xs, Ys} - std::vector args{}; - (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args)); - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); - TraceManager::DebugTrace(std::make_shared(node->debug_info())); - auto new_node = node->func_graph()->NewCNode(args); - TraceManager::EndTrace(); - return new_node; - } - - void Visit(const AnfNodePtr &node) override { - if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { - return; - } - - auto &inputs = node->cast()->inputs(); - // {prim::kPrimPartial, X, Xs} - if (inputs.size() < 2) { - return; - } - - // fill Xs - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); - } - - private: - std::vector Xs_{}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/prim_eliminate.h b/mindspore/ccsrc/optimizer/irpass/prim_eliminate.h deleted file mode 100644 index 725c30a6b9..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/prim_eliminate.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim, X} -class PrimEliminater : public AnfVisitor { - public: - explicit PrimEliminater(const PrimitivePtr &prim) : prim_(prim) {} - ~PrimEliminater() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim_, {IsNode})(node); - return x_; - } - - void Visit(const AnfNodePtr &node) override { x_ = node; } - - private: - AnfNodePtr x_{nullptr}; - PrimitivePtr prim_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/reduce_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reduce_eliminate.h deleted file mode 100644 index d2e1d15f91..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/reduce_eliminate.h +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ - -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/dshape.h" - -namespace mindspore { -namespace opt { -namespace irpass { -using abstract::Shape; -using abstract::ShapePtr; - -// {ReduceLike, X, axis} -class ReduceOneEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - PrimitivePtr prim; - if (IsPrimitiveCNode(node, prim::kPrimReduceMean) || IsPrimitiveCNode(node, prim::kPrimReduceAll) || - IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimReduceMax) || - IsPrimitiveCNode(node, prim::kPrimReduceMin)) { - prim = GetValueNode(node->cast()->input(0)); - AnfVisitor::Match(prim, {IsNode, IsVNode})(node); - if (!is_axis_one_) { - return nullptr; - } - - // consider keep_dims - auto keep_dims = prim->GetAttr("keep_dims"); - auto is_keep_dims = GetValue(keep_dims); - // {_Reduce, X, axis} -> X - if (is_keep_dims) { - return x_; - } - - // {_Reduce, Tensor} - if (is_tensor_) { - return nullptr; - } - - // {_Reduce, X, axis} -> {Reshape, X, new_shape} - std::vector elements; - for (size_t i = 0; i < x_shape_.size(); i++) { - auto iter = find(axis_.begin(), axis_.end(), i); - if (iter == axis_.end()) { - ValuePtr s = MakeValue(x_shape_[i]); - elements.push_back(s); - } - } - auto new_shape = std::make_shared(elements); - auto reshape_op = prim::GetPythonOps("reshape", "mindspore.ops.functional")->cast(); - return node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)}); - } - - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (!IsVNode(node) && x_ == nullptr) { - if (IsValueNode(node)) { - is_tensor_ = true; - } - // get X's shape - auto x_shape_abs = node->abstract(); - if (x_shape_abs != nullptr) { - auto x_track = x_shape_abs->GetShapeTrack()->cast(); - if (x_track == nullptr) { - return; - } - auto x_shape = x_track->shape(); - (void)std::copy(x_shape.begin(), x_shape.end(), std::back_inserter(x_shape_)); - x_ = node; - } - return; - } - - // check axis - AnfVisitor::Visit(node); - } - - void Visit(const ValueNodePtr &vnode) override { - if (x_shape_.empty()) { - return; - } - - // axis : int - if (IsValueNode(vnode)) { - auto idx = GetValue(vnode->value()); - // axis could be negative - if (idx < 0) { - idx += SizeToInt(x_shape_.size()); - } - if (SizeToInt(x_shape_.size()) > idx && x_shape_[IntToSize(idx)] == 1) { - is_axis_one_ = true; - axis_.push_back(idx); - } - return; - } - - // axis : tuple(int), default () - if (IsValueNode(vnode)) { - auto axis = GetValue>(vnode->value()); - if (axis.empty()) { - return; - } - - auto cmp = std::all_of(axis.cbegin(), axis.cend(), [this](int idx) { - // axis could be negative - if (idx < 0) { - idx += SizeToInt(x_shape_.size()); - } - return SizeToInt(this->x_shape_.size()) > idx && this->x_shape_[IntToSize(idx)] == 1; - }); - if (cmp) { - is_axis_one_ = true; - (void)std::copy(axis.begin(), axis.end(), std::back_inserter(axis_)); - } - } - } - - void Reset() { - axis_.clear(); - x_shape_.clear(); - x_ = nullptr; - is_axis_one_ = false; - is_tensor_ = false; - } - - private: - bool is_axis_one_{false}, is_tensor_{false}; - std::vector axis_{}, x_shape_{}; - AnfNodePtr x_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h deleted file mode 100644 index cafc8b796c..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h +++ /dev/null @@ -1,154 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ - -#include - -#include "ir/func_graph.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "pipeline/static_analysis/dshape.h" - -namespace mindspore { -namespace opt { -namespace irpass { -using abstract::Shape; -using abstract::ShapePtr; - -// {reshape_op, X, Shape} -class ReshapeSameShapeEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimReshape, {IsNode, IsVNode})(node); - - // check pattern match - if (shape_ == nullptr) { - return nullptr; - } - - auto src_shape_abs = x_->abstract(); - if (src_shape_abs == nullptr) { - return nullptr; - } - - auto src_shape = src_shape_abs->GetShapeTrack(); - auto tgt_shape_abs = node->abstract(); - if (tgt_shape_abs == nullptr) { - return nullptr; - } - auto tgt_shape = tgt_shape_abs->GetShapeTrack(); - if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa() && tgt_shape->isa()) { - auto elements = tgt_shape->cast(); - auto shape = src_shape->cast(); - if (shape->shape() == elements->shape()) { - return x_; - } - } - - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } else { - shape_ = node; - } - } - - void Reset() { - x_ = nullptr; - shape_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, shape_{nullptr}; -}; - -// {PrimReshape, {PrimReshape, X, Y}, Shape} -class TwoReshapeEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimReshape, {IsCNode, IsNode})(node); - - auto fg = node->func_graph(); - if (fg != nullptr && x_ != nullptr && shape_ != nullptr) { - auto new_node = fg->NewCNode({NewValueNode(prim_), x_, shape_}); - new_node->set_abstract(node->abstract()); - return new_node; - } - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (IsPrimitiveCNode(node, prim::kPrimReshape)) { - auto &inputs = node->cast()->inputs(); - // {PrimReshape, X, Y} - if (inputs.size() != 3) { - return; - } - prim_ = GetValueNode(inputs[0]); - x_ = inputs[1]; - } else { - shape_ = node; - } - } - - void Reset() { - prim_ = nullptr; - x_ = nullptr; - shape_ = nullptr; - } - - private: - PrimitivePtr prim_{nullptr}; - AnfNodePtr x_{nullptr}, shape_{nullptr}; -}; - -class ReshapeEliminater : public OptimizerCaller { - public: - ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} - ~ReshapeEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - auto new_node = reshape_same_shape_eliminater_(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - - new_node = two_reshape_eliminater_(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - - return nullptr; - } - - private: - ReshapeSameShapeEliminater reshape_same_shape_eliminater_; - TwoReshapeEliminater two_reshape_eliminater_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h deleted file mode 100644 index b6a4e1c852..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ /dev/null @@ -1,210 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ - -#include -#include -#include -#include - -#include "ir/optimizer_caller.h" -#include "ir/pattern_matcher.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -class SpecialOpEliminater : public OptimizerCaller { - public: - SpecialOpEliminater() - : insert_gradient_of_(std::make_shared(prim::kPrimInsertGradientOf)), - stop_gradient_(std::make_shared(prim::kPrimStopGradient)), - hook_backward_(std::make_shared(prim::kPrimHookBackward)), - print_shape_type_(std::make_shared(prim::kPrimPrintShapeType)), - get_ref_value_(std::make_shared(prim::kPrimGetRefValue)), - mirror_(std::make_shared(prim::kPrimMirror)), - virtual_div_(std::make_shared(prim::kPrimVirtualDiv)) { - eliminaters_.emplace_back(insert_gradient_of_); - eliminaters_.emplace_back(stop_gradient_); - eliminaters_.emplace_back(hook_backward_); - eliminaters_.emplace_back(print_shape_type_); - eliminaters_.emplace_back(get_ref_value_); - eliminaters_.emplace_back(mirror_); - eliminaters_.emplace_back(virtual_div_); - } - ~SpecialOpEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, - virtual_div_; - std::vector eliminaters_{}; -}; - -// {PrimVirtualDataset, X} -> X -// {PrimVirtualDataset, Xs} -> {prim::kPrimMakeTuple, Xs} -class VirtualDatasetEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!IsPrimitiveCNode(node, prim::kPrimVirtualDataset) || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 1) { - return nullptr; - } - - std::vector args; - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); - if (args.size() == 1) { - return args.front(); - } - - (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple)); - - return node->func_graph()->NewCNode(args); - } - - void Visit(const AnfNodePtr &) override {} -}; - -// {prim::kPrimSameTypeShape, X, Y} -> X -class SameEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim::kPrimSameTypeShape, {IsNode, IsNode})(node); - return x_; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } - } - - private: - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimCheckBprop, X, Y} -> X -class CheckBpropEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node); - return x_; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } - } - - private: - AnfNodePtr x_{nullptr}; -}; - -// Reset defer_inline flag -class ResetDeferInline : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (IsValueNode(node)) { - auto fg = GetValueNode(node); - fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); - } - return nullptr; - } -}; - -// {PrimZerosLike, Y} -> -// {PrimFill, {PrimDType, Y}, {PrimShape, Y}, 0} -class ZeroLikeFillZero : public AnfVisitor { - public: - ZeroLikeFillZero() - : PrimFill_(prim::GetPythonOps("fill", "mindspore.ops.functional")->cast()), - PrimShape_(prim::GetPythonOps("shape", "mindspore.ops.functional")->cast()), - PrimDType_(prim::GetPythonOps("dtype", "mindspore.ops.functional")->cast()) {} - ~ZeroLikeFillZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - y_ = nullptr; - AnfVisitor::Match(prim::kPrimZerosLike, {IsNode})(node); - if (y_ == nullptr || node->func_graph() == nullptr) { - return nullptr; - } - if ((y_->abstract() == nullptr) || !y_->abstract()->isa()) { - auto fg = node->func_graph(); - auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); - auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); - return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); - } - - abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast(); - - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - tensor::TensorPtr new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - (void)memset_s(data, mem_size, 0, mem_size); - - auto new_cnode = NewValueNode(new_tensor_ptr); - new_cnode->set_abstract(new_tensor_ptr->ToAbstract()); - - return new_cnode; - } - - void Visit(const AnfNodePtr &node) override { y_ = node; } - - private: - AnfNodePtr y_{nullptr}; - PrimitivePtr PrimFill_, PrimShape_, PrimDType_; -}; - -// {prim::kPrimDepend, X, ValueCond}->X -class DependValueElim : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode x, cond; - MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node))); - return nullptr; - } -}; - -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/specialize_transform.h b/mindspore/ccsrc/optimizer/irpass/specialize_transform.h deleted file mode 100644 index 3db9e7bd51..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/specialize_transform.h +++ /dev/null @@ -1,305 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ - -#include -#include -#include -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/manager.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -class SpecializeTransform { - public: - SpecializeTransform() : cache_() {} - ~SpecializeTransform() = default; - - FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector graph_args, - std::vector prim_args, std::vector value_args) { - if (cache_.count(func_graph) == 0) { - cache_[func_graph] = {}; - } - - auto &cache = cache_[func_graph]; - auto key = std::make_pair(graph_args, prim_args); - if (cache.count(key) == 0) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - - FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared("sp")); - mng->AddFuncGraph(new_fg); - - std::vector params = new_fg->parameters(); - std::vector new_params; - size_t n = graph_args.size(); - for (size_t i = 0; i < n; i++) { - if (graph_args[i] != nullptr) { - auto arg = NewValueNode(graph_args[i]); - (void)mng->Replace(params[i], arg); - continue; - } - if (prim_args[i] != nullptr) { - auto arg = NewValueNode(prim_args[i]); - (void)mng->Replace(params[i], arg); - continue; - } - if (value_args[i] != nullptr) { - auto &const_tensor = *value_args[i]; - auto const_tensor_ptr = std::make_shared(const_tensor); - AnfNodePtr arg = NewValueNode(const_tensor_ptr); - (void)mng->Replace(params[i], arg); - continue; - } - new_params.push_back(params[i]); - } - - mng->SetParameters(new_fg, new_params); - cache[key] = new_fg; - } - return cache[key]; - } - - private: - std::unordered_map, std::vector>, FuncGraphPtr>> - cache_; -}; -} // namespace internal - -// {G, Xs} -class SpecializeOnGraphArguments : public AnfVisitor { - public: - SpecializeOnGraphArguments() : specialize_transform_() {} - ~SpecializeOnGraphArguments() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (!IsValueNode(inputs[0])) { - return nullptr; - } - - auto inp0_fg = GetValueNode(inputs[0]); - if (inp0_fg->recursive()) { - return nullptr; - } - - std::vector graph_args; - std::vector prim_args; - std::vector value_node_args; - std::vector new_xs; - bool hasVNode = false; - for (size_t i = 1; i < inputs.size(); i++) { - if (IsValueNode(inputs[i])) { - auto fg_vnode = GetValueNode(inputs[i]); - graph_args.push_back(fg_vnode); - prim_args.emplace_back(nullptr); - value_node_args.emplace_back(nullptr); - hasVNode = true; - } else if (IsValueNode(inputs[i])) { - auto p_vnode = GetValueNode(inputs[i]); - graph_args.emplace_back(nullptr); - prim_args.push_back(p_vnode); - value_node_args.emplace_back(nullptr); - hasVNode = true; - } else if (IsValueNode(inputs[i])) { - tensor::TensorPtr t_vnode = GetValueNode(inputs[i]); - graph_args.emplace_back(nullptr); - prim_args.emplace_back(nullptr); - value_node_args.emplace_back(t_vnode); - hasVNode = true; - } else { - graph_args.emplace_back(nullptr); - prim_args.emplace_back(nullptr); - value_node_args.emplace_back(nullptr); - new_xs.push_back(inputs[i]); - } - } - - if (!hasVNode) { - return nullptr; - } - - auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); - (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); - - return node->func_graph()->NewCNode(new_xs); - } - - private: - internal::SpecializeTransform specialize_transform_; -}; - -// Eliminate unused parameters. -// {G, Xs} -class UnusedParasEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto &inputs = cnode->inputs(); - auto fg = GetValueNode(inputs[0]); - MS_EXCEPTION_IF_NULL(fg); - - std::vector parameters = fg->parameters(); - size_t size = parameters.size(); - if (size != inputs.size() - 1) { - return nullptr; - } - - std::vector new_xs; - std::vector keep_parameters; - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto &node_users = mng->node_users(); - bool has_unused_para = false; - for (size_t i = 0; i < size; ++i) { - auto iter = node_users.find(parameters[i]); - if (iter != node_users.end() && !iter->second.empty()) { - keep_parameters.push_back(true); - new_xs.push_back(inputs[i + 1]); - continue; - } - keep_parameters.push_back(false); - has_unused_para = true; - } - - if (!has_unused_para) { - return nullptr; - } - FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); - mng->AddFuncGraph(new_fg); - - std::vector new_fg_parameters = new_fg->parameters(); - std::vector new_parameters; - for (size_t i = 0; i < size; i++) { - if (keep_parameters[i]) { - if (parameters[i]->abstract() != nullptr) { - new_fg_parameters[i]->set_abstract(parameters[i]->abstract()); - } - new_parameters.push_back(new_fg_parameters[i]); - } - } - mng->SetParameters(new_fg, new_parameters); - - (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); - return node->func_graph()->NewCNode(new_xs); - } -}; - -// Eliminate unused outputs. -// {G, Xs} -class UnusedOutputEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - auto fg = GetValueNode(inputs[0]); - MS_EXCEPTION_IF_NULL(fg); - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - if (fg->recursive()) { - return nullptr; - } - - auto new_fg = TransformableClone(fg, std::make_shared("fg")); - mng->AddFuncGraph(new_fg); - auto new_fg_output = new_fg->output(); - if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) { - return nullptr; - } - - auto output_cnode = new_fg_output->cast(); - auto &node_users = mng->node_users(); - if (node_users.count(node) == 0 || node_users[node].empty()) { - return nullptr; - } - std::unordered_set used_output_idx; - std::vector> all_users; - for (auto &node_user : node_users[node]) { - if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) { - return nullptr; - } - auto user_cnode = node_user.first->cast(); - size_t used_idx = GetValue(user_cnode->input(2)->cast()->value()); - used_output_idx.insert(used_idx); - all_users.push_back(std::make_pair(node_user.first, used_idx)); - } - - if (used_output_idx.size() >= output_cnode->inputs().size() - 1) { - // all output has users. - return nullptr; - } - - if (used_output_idx.empty()) { - // we do not process this case. - return nullptr; - } else if (used_output_idx.size() == 1) { - // after eliminate, only one output left. - new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1)); - // update users. - for (auto &ret_user : all_users) { - (void)mng->Replace(ret_user.first, node); - } - } else { - // after eliminate, create new multi output. - std::vector new_output_inputs{output_cnode->input(0)}; - std::unordered_map new_idx_map; - for (auto idx : used_output_idx) { - new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1); - new_output_inputs.push_back(output_cnode->input(idx + 1)); - } - new_fg->set_output(new_fg->NewCNode(new_output_inputs)); - // update users. - for (auto &ret_user : all_users) { - auto ret_user_cnode = ret_user.first->cast(); - ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second])); - } - } - - auto new_sx = inputs; - new_sx[0] = NewValueNode(new_fg); - return node->func_graph()->NewCNode(new_sx); - } -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/symbol_resolver.h b/mindspore/ccsrc/optimizer/irpass/symbol_resolver.h deleted file mode 100644 index 7b35cf5451..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/symbol_resolver.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ - -#include -#include - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/python_adapter.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimResolve, Ns, Sym} -class ResolverResolve : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimResolve, {IsVNode, IsVNode})(node); - if (sym_ != nullptr) { - return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node); - } - return nullptr; - } - - void Visit(const ValueNodePtr &vnode) override { - if (IsValueNode(vnode)) { - ns_ = GetValueNode(vnode); - } else if (ns_ != nullptr && IsValueNode(vnode)) { - sym_ = GetValueNode(vnode); - } - } - - void Reset() { - ns_ = nullptr; - sym_ = nullptr; - } - - private: - parse::NameSpacePtr ns_{nullptr}; - parse::SymbolPtr sym_{nullptr}; -}; - -// {prim::kPrimGetAttr, Ns, Str} -class ResolverGetattr : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimGetAttr, {IsVNode, IsVNode})(node); - if (sym_ != nullptr) { - return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node); - } - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (IsValueNode(node)) { - ns_ = GetValueNode(node); - } else if (ns_ != nullptr && IsValueNode(node)) { - auto str = GetValue(GetValueNode(node)); - sym_ = std::make_shared(str); - } - } - - void Reset() { - ns_ = nullptr; - sym_ = nullptr; - } - - private: - parse::NameSpacePtr ns_{nullptr}; - parse::SymbolPtr sym_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/tile_eliminate.h b/mindspore/ccsrc/optimizer/irpass/tile_eliminate.h deleted file mode 100644 index 86ac5bab73..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/tile_eliminate.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ - -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// check if node is value tuple and all one. e.g. (1, 1, 1) -// {PrimTile, X, MultiOne} -class TileMultiplyByOne : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTile, {IsNode, IsVNode})(node); - - // check pattern match - if (tuple_ == nullptr) { - return nullptr; - } - - auto value = GetValueNode(tuple_); - auto elements = GetValue>(value); - if (elements.empty()) { - return nullptr; - } - - auto cmp = std::all_of(elements.cbegin(), elements.cend(), [](int i) { return i == 1; }); - if (cmp) { - return x_; - } - - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } else { - tuple_ = node; - } - } - - void Reset() { - x_ = nullptr; - tuple_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, tuple_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/transpose_eliminate.h b/mindspore/ccsrc/optimizer/irpass/transpose_eliminate.h deleted file mode 100644 index de196ea619..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/transpose_eliminate.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ - -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// check if node is value tuple and ascends one by one from zero. e.g., (0, 1, 2, 3) -// {PrimTranspose, X, AscendingNums} -class TransposeSameIOEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTranspose, {IsNode, IsVNode})(node); - - // check pattern match - if (tuple_ == nullptr) { - return nullptr; - } - - auto value = GetValueNode(tuple_); - auto elements = GetValue>(value); - if (elements.empty()) { - return nullptr; - } - - int j = 0; - bool cmp = std::all_of(elements.cbegin(), elements.cend(), [&j](int i) { return i == j++; }); - // same IO settings, eliminate this transpose - if (cmp) { - return x_; - } - - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } else { - tuple_ = node; - } - } - - void Reset() { - x_ = nullptr; - tuple_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, tuple_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc deleted file mode 100644 index 462d08ad3c..0000000000 --- a/mindspore/ccsrc/optimizer/opt.cc +++ /dev/null @@ -1,246 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/opt.h" - -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/manager.h" -#include "optimizer/optimizer.h" -#include "utils/log_adapter.h" -#include "utils/ordered_set.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { -SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, - const RenormAction &renorm_action) { - auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; - return std::make_shared(transform, name, fn, renorm_action); -} - -SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, - const std::vector &prims, const RenormAction &renorm_action) { - auto fn = [prims](const AnfNodePtr &node) -> bool { - if (!node->isa()) { - return false; - } - - auto cnode = node->cast(); - auto inp0 = cnode->input(0); - auto prim0 = GetValueNode(inp0); - if (prim0 == nullptr) { - return false; - } - - auto hash = prim0->Hash(); - auto const &name = prim0->name(); - for (auto &prim : prims) { - if (hash == prim->Hash() && name == prim->name()) { - return true; - } - } - return false; - }; - - return std::make_shared(transform, name, fn, renorm_action); -} - -SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, - const PredicateFuncType &predicate, const RenormAction &renorm_action) { - return std::make_shared(transform, name, predicate, renorm_action); -} - -AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { -#ifdef ENABLE_PROFILE - double t = GetTime(); -#endif - AnfNodePtr result = (*transform_)(optimizer, node); -#ifdef ENABLE_PROFILE - if (optimizer != nullptr) { - auto time = GetTime(); - MsProfile::StatTime("substitution." + name_, time - t); - if (result != nullptr) { - MsProfile::StatTime("match." + name_, time - t); - } - } -#endif - if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { - if (renorm_action_ == FORCE_RENORM) { - optimizer->add_node_to_renormalize(result); - } else { - // renorm_action_ is CHECK_RENORM - if (result->abstract() == nullptr) { - optimizer->add_node_to_renormalize(result); - } - } - } - - return result; -} - -static bool isTraversable(const AnfNodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->isa() || node->isa()) { - return true; - } - if (IsValueNode(node) || IsValueNode(node)) { - return true; - } - return false; -} - -bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, - const SubstitutionPtr &transform) const { -#ifdef ENABLE_PROFILE - double start = GetTime(); -#endif - FuncGraphManagerPtr manager = optimizer->manager(); - auto seen = NewSeenGeneration(); - // 1024 is for the initial capacity of deque - std::deque todo(1024); - todo.clear(); - todo.push_back(root_node); - bool changes = false; - - auto &all_nodes = manager->all_nodes(); - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); - - // check whether this node has been matched. - if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { - continue; - } - node->seen_ = seen; - - // select nodes that this transform can be applied. - bool is_match = transform->predicate_(node); - - // apply transform on this node - bool change = false; - if (is_match) { - auto ret = (*transform)(optimizer, node); - if (ret != nullptr && ret != node) { - change = true; - changes = true; -#ifdef ENABLE_PROFILE - double t = GetTime(); -#endif - (void)manager->Replace(node, ret); -#ifdef ENABLE_PROFILE - MsProfile::StatTime("replace." + transform->name_, GetTime() - t); -#endif - node = ret; - } - } - - // find success, and add them to todo list - if (IsValueNode(node)) { - todo.push_back(GetValueNode(node)->output()); - } - - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); - } - - auto &node_users = manager->node_users(); - if (change && node_users.find(node) != node_users.end()) { - for (auto &use : node_users[node]) { - auto use_node = use.first; - if (use_node == nullptr) { - continue; - } - todo.push_back(use_node); - if (use_node->seen_ == seen) { - use_node->seen_--; - } - } - } - } - -#ifdef ENABLE_PROFILE - MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start); -#endif - return changes; -} - -bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { - MS_EXCEPTION_IF_NULL(optimizer); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = optimizer->manager(); - manager->AddFuncGraph(func_graph); - - // for transform status counting - size_t space = 0; - std::unordered_map> status; - if (optimizer->is_on_debug_) { - for (size_t i = 0; i < list_.size(); i++) { - status[list_[i]->name_ + std::to_string(i)] = {}; - } - } - - bool loop = false; - bool changes = false; - - do { - loop = false; - for (size_t i = 0; i < list_.size(); i++) { - auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]); - changes = changes || change; - loop = loop || change; - - // record the status of each transform - if (optimizer->is_on_debug_) { - status[list_[i]->name_ + std::to_string(i)].push_back(change); - space = std::max(list_[i]->name_.size(), space); - } - } - - if (is_once_) { - break; - } - } while (loop); - - // display the status of each transform - if (optimizer->is_on_debug_) { - std::stringstream ss; - ss << std::endl - << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name - << std::endl; - for (size_t i = 0; i < list_.size(); i++) { - auto name = list_[i]->name_; - ss << std::left << std::setw(space + 4) << name << "\t"; - for (auto change : status[name + std::to_string(i)]) { - ss << change << " "; - } - ss << std::endl; - } - MS_LOG(DEBUG) << ss.str(); - } - - return changes; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/opt.h b/mindspore/ccsrc/optimizer/opt.h deleted file mode 100644 index 6601d969d2..0000000000 --- a/mindspore/ccsrc/optimizer/opt.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ - -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/optimizer_caller.h" -#include "operator/ops.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { - -// Define the interaction mode between an Optimize pass and Renormalize pass -// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed -// CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executted -enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; - -class Substitution { - public: - OptimizerCallerPtr transform_; - std::string name_; - PredicateFuncType predicate_{nullptr}; - // an enum to mark this Substitution relation to renormalize pass - RenormAction renorm_action_; - Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate, - const RenormAction &renorm_action) - : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} - ~Substitution() = default; - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node); -}; - -using SubstitutionPtr = std::shared_ptr; - -SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, - const RenormAction &action_renorm = CHECK_RENORM); -SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, - const std::vector &prims, - const RenormAction &action_renorm = CHECK_RENORM); -SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, - const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); - -class SubstitutionList { - public: - explicit SubstitutionList(const std::vector &patterns, bool is_once = false) - : list_(patterns), is_once_(is_once) {} - ~SubstitutionList() = default; - - bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; - - private: - bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const; - std::vector list_; - // a flag to mark this list of Substitution can only be executed only once - bool is_once_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ diff --git a/mindspore/ccsrc/optimizer/optimizer.h b/mindspore/ccsrc/optimizer/optimizer.h deleted file mode 100644 index dc423ed314..0000000000 --- a/mindspore/ccsrc/optimizer/optimizer.h +++ /dev/null @@ -1,241 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "debug/draw.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" -#include "debug/trace.h" -#include "optimizer/opt.h" -#include "pipeline/resource.h" -#include "pipeline/action.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -using OptimizeGraphFunc = std::function; - -class OptPassConfig { - public: - explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} - explicit OptPassConfig(const std::vector &list, bool is_once = false) - : list_(list), is_once_(is_once) {} - OptPassConfig(const std::initializer_list &list, bool is_once = false) - : list_(list), is_once_(is_once) {} - ~OptPassConfig() = default; - - const std::vector &list() const { return list_; } - const OptimizeGraphFunc &func() const { return func_; } - - static OptPassConfig Renormalize() { return OptPassConfig(); } - const bool is_renormalize() const { return is_renormalize_; } - - const bool is_once() const { return is_once_; } - - private: - OptPassConfig() : is_renormalize_(true) {} - - OptimizeGraphFunc func_; - std::vector list_; - bool is_renormalize_{false}; - bool is_once_{false}; -}; - -class OptPass { - public: - explicit OptPass(const OptimizeGraphFunc &func) : pass_func_(func) {} - ~OptPass() = default; - - bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { - return pass_func_(func_graph, optimizer); - } - - static OptPass Renormalize() { return OptPass(); } - const bool is_renormalize() const { return is_renormalize_; } - - private: - OptPass() : is_renormalize_(true) {} - - OptimizeGraphFunc pass_func_; - bool is_renormalize_{false}; -}; -using OptPassGroupMap = std::vector>; - -class Optimizer : public std::enable_shared_from_this { - public: - Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) - : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {} - virtual ~Optimizer() = default; - - void Init(const OptPassGroupMap &passes, bool run_only_once) { - run_only_once_ = run_only_once; - is_watch_renormalize_ = false; - is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG); - - for (auto &iter : passes) { - const std::string &name = iter.first; - pass_names_.push_back(name); - - const OptPassConfig &config = iter.second; - if (config.is_renormalize()) { - passes_.push_back(OptPass::Renormalize()); - continue; - } - - if (config.list().size() > 0) { - OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once()); - passes_.push_back(OptPass(func)); - continue; - } - - passes_.push_back(OptPass(config.func())); - } - - if (passes_.size() == 1) { - run_only_once_ = true; - } - } - - static std::shared_ptr MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, - const OptPassGroupMap &passes, bool run_only_once = false, - bool watch_renormalize = false) { - OptimizerPtr optimizer = std::make_shared(name, resource_ptr); - optimizer->Init(passes, run_only_once); - if (watch_renormalize) { - optimizer->enable_watch_renormalize(); - } - return optimizer; - } - - FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { - if (!is_enable_) { - return func_graph; - } - // Optimizer step counter; - int counter = 1; - bool changes = true; - - while (changes) { - changes = false; - auto run_runc = [&counter, &func_graph, &changes, use_profile, this]() { - for (size_t i = 0; i < passes_.size(); ++i) { - const OptPass &opt = passes_[i]; - CurPass_ = {counter, pass_names_[i]}; - auto opt_func = [&func_graph, &changes, &opt, this]() { - if (opt.is_renormalize()) { - auto resource_ptr = std::dynamic_pointer_cast(resource_); - if (resource_ptr != nullptr) { - // StepParallel may replace the AbstractValue of the parameters of func_graph, - // So generate the args_spec from parameters. - abstract::AbstractBasePtrList maybe_new_args_spec; - if (is_watch_renormalize_) { - if (untyped_nodes_.size() > 0) { - std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), - std::back_inserter(maybe_new_args_spec), - [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); - func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); - clear_untyped_nodes(); - } else { - MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty."; - } - } else { - std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), - std::back_inserter(maybe_new_args_spec), - [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); - func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); - } - } - } else if (opt(func_graph, shared_from_this())) { - changes = true; - } - }; - use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); - if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) { - MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; - auto fg_name = - "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; - func_graph->DumpFuncGraph(fg_name); - DumpIR(fg_name + ".ir", func_graph); - ExportIR(fg_name + ".dat", "", func_graph); - MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; - } - } - }; - use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter)) run_runc) : run_runc(); - counter++; - - if (run_only_once_) { - break; - } - } - return func_graph; - } - - pipeline::ResourceBasePtr resource() const { return resource_; } - FuncGraphManagerPtr manager() const { - if (resource_ != nullptr) { - return resource_->manager(); - } - MS_LOG(EXCEPTION) << "No ResourceBase exists."; - } - - const std::string name() const { return name_; } - - void add_node_to_renormalize(AnfNodePtr anode) { - if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) { - untyped_nodes_.push_back(anode); - } - } - - void clear_untyped_nodes() { untyped_nodes_.clear(); } - - void enable_watch_renormalize() { is_watch_renormalize_ = true; } - void disable_watch_renormalize() { is_watch_renormalize_ = false; } - bool is_watch_renormalize() { return is_watch_renormalize_; } - void set_enable(bool enable) { is_enable_ = enable; } - - struct { - int counter; - std::string name; - } CurPass_; - - bool is_on_debug_{false}; - - private: - const std::string name_; - pipeline::ResourceBasePtr resource_; - std::vector passes_; - std::vector pass_names_; - bool run_only_once_; - std::vector untyped_nodes_; - bool is_watch_renormalize_; - bool is_enable_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/optimizer/pass_group.cc b/mindspore/ccsrc/optimizer/pass_group.cc deleted file mode 100644 index 2d1ab07f7d..0000000000 --- a/mindspore/ccsrc/optimizer/pass_group.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "optimizer/pass_group.h" - -namespace mindspore { -namespace opt { -namespace python_pass { -void PassGroup::AddPass(const PythonPassPtr &pass) { - if (pass != nullptr) { - passes_.push_back(pass); - } -} - -bool PassGroup::DeletePass(const std::string &pass_name) { - for (auto iter = passes_.begin(); iter != passes_.end(); iter++) { - if ((*iter)->name() == pass_name) { - *iter = nullptr; - passes_.erase(iter); - return true; - } - } - return false; -} - -bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { - if (func_graph == nullptr) { - return false; - } - bool changed = false; - for (const auto &pass : passes) { - if (pass != nullptr) { - if (pass->Run(func_graph)) { - changed = true; - } - } - } - return changed; -} - -bool PassGroup::Run(const FuncGraphPtr &func_graph) const { - bool changed = false; - // run all passes - bool change = true; - while (change) { - change = Run(func_graph, passes_); - changed = change || changed; - if (run_only_once_) { - break; - } - } - return changed; -} - -} // namespace python_pass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/pass_group.h b/mindspore/ccsrc/optimizer/pass_group.h deleted file mode 100644 index 895f5a4128..0000000000 --- a/mindspore/ccsrc/optimizer/pass_group.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ - -#include -#include -#include -#include - -#include "optimizer/py_pass.h" - -namespace mindspore { -namespace opt { -namespace python_pass { -class PassGroup { - public: - explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false) - : name_(name), passes_{}, run_only_once_(run_only_once) {} - virtual ~PassGroup() = default; - // Add graph pass, the pass object will be freed when pass manager freed. - void AddPass(const PythonPassPtr &pass); - // Delete graph pass before the pass manager is freed. - bool DeletePass(const std::string &pass_name); - // Run passes added in pass manager on the input graph - // @param [inout] graph The graph to be optimized - // @return true, graph changed - // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph) const; - // Run the given graph passes on the input graph - // @param [inout] graph The graph to be optimized - // @param [in] passes The given graph passes - // @return true, graph changed - // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; - std::string name() const { return name_; } - - private: - const std::string name_; - std::vector passes_; - bool run_only_once_; -}; -using PassGroupPtr = std::shared_ptr; -} // namespace python_pass -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ diff --git a/mindspore/ccsrc/optimizer/py_pass.cc b/mindspore/ccsrc/optimizer/py_pass.cc deleted file mode 100644 index 8ce348b22e..0000000000 --- a/mindspore/ccsrc/optimizer/py_pass.cc +++ /dev/null @@ -1,236 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "optimizer/py_pass.h" -#include -#include -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/manager.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/resource.h" - -namespace mindspore { -namespace opt { -namespace python_pass { -namespace internal { -std::string GetNodeRepr(AnfNodePtr node) { - if (node != nullptr) { - if (node->isa()) { - std::string repr = "("; - auto const &inputs = node->cast()->inputs(); - for (auto &input : inputs) { - repr += " "; - repr += GetNodeRepr(input); - repr += " "; - } - repr += ")"; - return repr; - } - if (node->isa()) { - return GetValueNode(node)->ToString(); - } - return node->ToString(); - } - return ""; -} - -void ResolveFuncGraph_(const FuncGraphPtr &fg) { - auto manager = Manage(fg, false); - parse::python_adapter::set_use_signature_in_resolve(false); - parse::ResolveAll(manager); -} - -bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { - if (node == nullptr) { - return false; - } - MS_EXCEPTION_IF_NULL(pattern); - if (pattern->isa()) { - if (!node->isa()) { - return false; - } - if (GetNodeRepr(pattern) == GetNodeRepr(node)) { - // add to equiv_ptr - equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); - return true; - } - return false; - } else if (pattern->isa()) { - MS_LOG(DEBUG) << pattern->ToString() + "\n"; - // add to equiv_ptr - equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); - return true; - } else if (pattern->isa()) { - // match every single sub ANode - if (!node->isa()) { - return false; - } - auto pattern_inputs = pattern->cast()->inputs(); - auto node_inputs = node->cast()->inputs(); - if (pattern_inputs.size() != node_inputs.size()) { - return false; - } - for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); - p_item++, node_item++) { - auto res = Match(*p_item, *node_item, equiv_ptr); - if (!res) { - return false; - } - } - return true; - } - MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; -} - -AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, - const NodeEquivPtr &equiv_ptr) { - if (cur_raw_dst_node_->isa()) { - auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); - if (sub_pair != equiv_ptr->end()) { - return sub_pair->second; - } - MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; - } else if (cur_raw_dst_node_->isa()) { - // check primitive ValueNode - auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast()->value()->ToString()); - if (sub_pair != equiv_ptr->end()) { - return sub_pair->second; - } - return cur_raw_dst_node_; - } else if (cur_raw_dst_node_->isa()) { - std::vector new_inputs; - auto inputs = cur_raw_dst_node_->cast()->inputs(); - for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { - auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); - new_inputs.push_back(subed); - } - return func_graph->NewCNode(new_inputs); - } - MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); -} - -bool isTraversable(const AnfNodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->isa() || node->isa()) { - return true; - } - if (IsValueNode(node) || IsValueNode(node)) { - return true; - } - return false; -} -} // namespace internal - -void PythonPass::Build(const py::function &src, const py::function &dst) { - // 1. get FuncGraph from py::function - auto src_fg_ = parse::ParsePythonCode(src); - auto dst_fg_ = parse::ParsePythonCode(dst); - if (src_fg_ == nullptr || dst_fg_ == nullptr) { - MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; - } - // 2. Resolve - internal::ResolveFuncGraph_(src_fg_); - internal::ResolveFuncGraph_(dst_fg_); - // 3. from FuncGraphPtr to ValueNode - src_node_ = src_fg_->output(); - dst_node_ = dst_fg_->output(); -} - -PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, - bool multigraph) - : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { - Build(src, dst); -} - -AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - auto equiv_ptr = std::make_shared(); - bool is_a_match = internal::Match(src_node_, node, equiv_ptr); - if (is_a_match) { - auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); - MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; - return new_node; - } - return nullptr; -} - -bool PythonPass::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - auto seen = NewSeenGeneration(); - // 1024 is for the initial capacity of deque - std::deque todo(1024); - todo.push_back(func_graph->output()); - bool changes = false; - - auto &all_nodes = manager->all_nodes(); - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); - - // check whether this node has been matched. - if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { - continue; - } - node->seen_ = seen; - - // select nodes that this transform can be applied. - AnfNodePtr new_node = Run(func_graph, node); - bool change = (new_node != nullptr); - if (new_node != nullptr && new_node != node) { - (void)manager->Replace(node, new_node); - } else if (new_node == nullptr) { - new_node = node; - } - if (run_only_once_) { - return change; - } - - // find success, and add them to todo list - if (IsValueNode(node)) { - todo.push_back(GetValueNode(node)->output()); - } - - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); - } - - auto &node_users = manager->node_users(); - if (change && node_users.find(node) != node_users.end()) { - for (auto &use : node_users[node]) { - auto use_node = use.first; - if (use_node == nullptr) { - continue; - } - todo.push_back(use_node); - if (use_node->seen_ == seen) { - use_node->seen_--; - } - } - } - } - return changes; -} -} // namespace python_pass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/py_pass_manager.cc b/mindspore/ccsrc/optimizer/py_pass_manager.cc deleted file mode 100644 index 1c36e93c9a..0000000000 --- a/mindspore/ccsrc/optimizer/py_pass_manager.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "optimizer/py_pass_manager.h" - -#include -#include -#include -#include - -#include "ir/manager.h" -#include "optimizer/pass_group.h" - -namespace mindspore { -namespace opt { -namespace python_pass { -PyPassManagerPtr PyPassManager::global_instance = nullptr; -std::unordered_map PyPassManager::phase_to_group_; - -PassGroupPtr PyPassManager::GetPassGroup(Phase phase) { - auto pm = phase_to_group_.find(phase); - if (pm == phase_to_group_.end()) { - return nullptr; - } - return pm->second; -} - -PyPassManagerPtr PyPassManager::GetInstance() { - if (global_instance == nullptr) { - global_instance = std::shared_ptr(new (std::nothrow) PyPassManager()); - } - return global_instance; -} - -PyPassManager::PyPassManager() { - phase_to_group_[Phase::RESOLVE] = std::make_shared(); - phase_to_group_[Phase::OPT] = std::make_shared(); -} - -void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, - Phase phase, bool run_only_once, bool multigraph) { - auto cur_pm = GetPassGroup(phase); - MS_EXCEPTION_IF_NULL(cur_pm); - PythonPassPtr new_pass = std::make_shared(pass_name, pattern, target, run_only_once, multigraph); - cur_pm->AddPass(new_pass); -} - -void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { - auto cur_pm = GetPassGroup(phase); - MS_EXCEPTION_IF_NULL(cur_pm); - if (!cur_pm->DeletePass(pass_name)) { - MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; - } -} - -void PyPassManager::ClearRes() { - MS_LOG(INFO) << "Clear PyPassManager resources!"; - global_instance = nullptr; - phase_to_group_.clear(); -} - -REGISTER_PYBIND_DEFINE( - PyPassManager_, ([](const py::module *m) { - (void)py::enum_(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT); - (void)py::class_>(*m, "PyPassManager_") - .def(py::init([]() { return PyPassManager::GetInstance(); })) - .def("registe", &PyPassManager::Registe, "Registe python pass") - .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass"); - })); -} // namespace python_pass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/py_pass_manager.h b/mindspore/ccsrc/optimizer/py_pass_manager.h deleted file mode 100644 index eaeefce213..0000000000 --- a/mindspore/ccsrc/optimizer/py_pass_manager.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ - -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" -#include "utils/graph_utils.h" -#include "common/utils.h" - -#include "pipeline/parse/resolve.h" -#include "optimizer/py_pass.h" -#include "optimizer/pass_group.h" - -namespace mindspore { -namespace opt { -namespace python_pass { -class PyPassManager; -using PyPassManagerPtr = std::shared_ptr; - -enum Phase { RESOLVE, OPT }; - -class PyPassManager { - protected: - PyPassManager(); - static PyPassManagerPtr global_instance; - - public: - // Singletons should not be cloneable and assignable - PyPassManager(const PyPassManager &other) = delete; - void operator=(const PyPassManager &) = delete; - // Access the only global instance - static PyPassManagerPtr GetInstance(); - virtual ~PyPassManager() = default; - void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, - Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); - void Unregiste(const std::string &pass_name, Phase phase); - PassGroupPtr GetPassGroup(Phase phase); - void ClearRes(); - - private: - static std::unordered_map phase_to_group_; -}; -} // namespace python_pass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/parallel/CMakeLists.txt b/mindspore/ccsrc/parallel/CMakeLists.txt deleted file mode 100644 index 940b1ed1d8..0000000000 --- a/mindspore/ccsrc/parallel/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -file(GLOB_RECURSE _PARALLEL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -if (ENABLE_DUMP_PROTO) - list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") -endif () - -set_property(SOURCE ${_PARALLEL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARALLEL) -add_library(_mindspore_parallel_obj OBJECT ${_PARALLEL_SRC_FILES}) diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc deleted file mode 100644 index 30173e533c..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc +++ /dev/null @@ -1,435 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/allreduce_fusion/allreduce_fusion.h" -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "parallel/costmodel_context.h" -#include "parallel/graph_util/node_info.h" -#include "parallel/status.h" -#include "parallel/step_parallel.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -std::unordered_set FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(para); - MS_EXCEPTION_IF_NULL(para->func_graph()); - FuncGraphManagerPtr manager = para->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto node_set = manager->node_users()[para]; - std::unordered_set cnode_set; - for (auto &node_pair : node_set) { - auto cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - continue; - } - auto node_prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { - (void)cnode_set.emplace(cnode); - } else { - auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); - for (auto &cnode_sub : cnode_set_sub) { - (void)cnode_set.emplace(cnode_sub); - } - } - } - return cnode_set; -} - -Status AllreduceFusion::AddNodeToGraph() { - const auto ¶meters = root_graph_->parameters(); - for (auto ¶meter : parameters) { - if (!ParameterRequireGrad(parameter)) { - continue; - } - auto cnode_set = FindCNodesWithPara(parameter); - if (cnode_set.empty()) { - continue; - } - for (auto &cnode : cnode_set) { - MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); - if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { - MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); - return FAILED; - } - } - } - return SUCCESS; -} - -CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(from); - std::unordered_map cnode_dist; - if (!from->isa()) { - return cnode_dist; - } - auto cnode = from->cast(); - if (!IsValueNode(cnode->input(0))) { - return cnode_dist; - } - - MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) - << " operator_info: " << (cnode->operator_info() != nullptr); - - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode(); - MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; - - if (allreduce_graph_.NodeInGraph(cnode)) { - cnode_dist[cnode] = cost; - return cnode_dist; - } else { - auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); - for (auto &ele : cnode_dist_next) { - cnode_dist[ele.first] = cost + ele.second; - } - } - } else { - auto cnode_dist_next = FindNextCNodes(cnode); - for (auto &ele : cnode_dist_next) { - cnode_dist[ele.first] = ele.second; - } - } - return cnode_dist; -} - -CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - const auto &from_inputs = from->inputs(); - std::unordered_map dist_map; - MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; - for (auto &input_node : from_inputs) { - auto cnode_dist = FindCNode(input_node, recursive_times + 1); - for (auto &ele : cnode_dist) { - (void)dist_map.emplace(ele); - } - } - return dist_map; -} - -Status AllreduceFusion::AddEdgeToGraph() { - std::unordered_map cnode_state_map; - const auto &cnodes = allreduce_graph_.cnode_set(); - for (auto &cnode : cnodes) { - cnode_state_map[cnode] = 0; - } - const auto &head_cnode = allreduce_graph_.head_cnode(); - std::queue cnode_queue; - cnode_queue.emplace(head_cnode); - cnode_state_map[head_cnode] = 1; - - while (!cnode_queue.empty()) { - const auto cur_cnode = cnode_queue.front(); - cnode_queue.pop(); - cnode_state_map[cur_cnode] = 2; - auto next = FindNextCNodes(cur_cnode); - for (auto &ele : next) { - auto &cnode = ele.first; - auto &dist = ele.second; - if (cnode_state_map[cnode] == 0) { - cnode_queue.emplace(cnode); - cnode_state_map[cnode] = 1; - } - if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) { - MS_LOG(ERROR) << "AddEdge error"; - return FAILED; - } - MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist; - } - } - return SUCCESS; -} - -std::vector FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(para); - MS_EXCEPTION_IF_NULL(para->func_graph()); - FuncGraphManagerPtr manager = para->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[para]; - std::vector cnode_list; - for (auto &node_pair : node_set) { - auto cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - continue; - } - auto node_prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == CAST) { - auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1); - if (mirror_cnodes.empty()) { - MS_LOG(WARNING) << "mirror node after cast not found"; - continue; - } - if (mirror_cnodes.size() > 1) { - MS_LOG(EXCEPTION) << "mirror node after cast number is not 1"; - } - cnode_list.emplace_back(mirror_cnodes[0]); - } - if (node_prim->name() == MIRROR_OPERATOR) { - cnode_list.emplace_back(cnode); - } - } - return cnode_list; -} - -void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) { - MS_EXCEPTION_IF_NULL(mirror_cnode); - MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; - auto node_prim = GetValueNode(mirror_cnode->input(0)); - auto old_value_ptr = node_prim->GetAttr(FUSION); - if (old_value_ptr != nullptr) { - if (old_value_ptr->isa()) { - int32_t old_value = old_value_ptr->cast()->value(); - if (old_value < fusion) { - return; - } - } - } - (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared(fusion))); - (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared(parameter_name))); -} - -Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int32_t fusion) { - auto mirror_cnodes = FindMirror(para); - if (mirror_cnodes.empty()) { - MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; - return SUCCESS; - } - if (mirror_cnodes.size() > 2) { - for (auto &mirror_cnode : mirror_cnodes) { - MS_EXCEPTION_IF_NULL(mirror_cnode); - MS_LOG(INFO) << mirror_cnode->DebugString(); - } - MS_EXCEPTION_IF_NULL(para); - MS_LOG(ERROR) << para->ToString() << " FindMirror is more than 2. " << mirror_cnodes.size() - << "Mirror CNode found."; - return FAILED; - } - for (auto &mirror_cnode : mirror_cnodes) { - auto parameter_name = ParameterName(para); - SetMirrorFusion(mirror_cnode, fusion, parameter_name); - } - return SUCCESS; -} - -Status FindMirrorAndSetFusion(const std::vector ¶s, int32_t fusion) { - for (auto ¶m_node : paras) { - if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; - } - } - return SUCCESS; -} - -Status AllreduceFusion::SetFusion(const std::vector &cost_map) { - if (cost_map.size() < 2) { - MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); - return FAILED; - } - int32_t fusion = 1; - for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) { - auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter); - if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; - } - fusion++; - } - return SUCCESS; -} - -std::vector AllreduceFusion::GenerateCostMap(int32_t fusion_times, double tail_percent) const { - double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1); - MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset; - std::vector cost_map; - double begin = 0; - for (auto i = 0; i < fusion_times - 1; i++) { - cost_map.push_back(begin); - begin += offset; - } - cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent)); - cost_map.push_back(allreduce_graph_.max()); - MS_LOG(DEBUG) << "cost_map = " << cost_map; - return cost_map; -} - -Status AllreduceFusion::SetFusionByBackwardCompTime() { - auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times(); - if (fusion_times < 2) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion"; - return SUCCESS; - } - auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent(); - if (tail_percent < 0 || tail_percent >= 1) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent - << ". Bypass ProcessAllreduceFusion"; - return SUCCESS; - } - const auto cost_map = GenerateCostMap(fusion_times, tail_percent); - MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed."; - if (SetFusion(cost_map) != SUCCESS) { - MS_LOG(ERROR) << "SetFusion failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed."; - return SUCCESS; -} - -Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() { - tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time(); - if (tail_time_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time(); - if (allreduce_inherent_time_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ - << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - if (tail_time_ <= allreduce_inherent_time_) { - MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ - << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ - << ".tail_time is not more than allreduce_inherent_time. Bypass ProcessAllreduceFusion"; - return FAILED; - } - allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth(); - if (allreduce_bandwidth_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_ - << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - computation_time_parameter_ = - CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter(); - if (computation_time_parameter_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_ - << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - return SUCCESS; -} - -Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() { - if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) { - MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!"; - return FAILED; - } - allreduce_graph_.SortArnode(); - if (allreduce_graph_.RemoveExtraParas() != SUCCESS) { - MS_LOG(ERROR) << "RemoveExtraParas failed!"; - return FAILED; - } - double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_; - double to_cost = allreduce_graph_.max(); - int32_t fusion = 1; - while (to_cost != 0) { - MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size; - auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size); - MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second; - auto paras = node_cost_pair.first; - if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; - } - fusion++; - para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) / - allreduce_bandwidth_; - to_cost = node_cost_pair.second; - } - MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed."; - return SUCCESS; -} - -Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { - if (algorithm == 1) { - return SetFusionByBackwardCompTime(); - } - return SetFusionByBackwardCompAndAllreduceTime(); -} - -Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { - if (ret == nullptr) { - MS_LOG(ERROR) << "ret is nullptr."; - return FAILED; - } - auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm(); - if (algorithm < 1 || algorithm > 2) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". Bypass ProcessAllreduceFusion"; - return SUCCESS; - } - ret_ = ret; - root_graph_ = ret_->func_graph(); - MS_EXCEPTION_IF_NULL(root_graph_); - auto graph_set = ForwardGraph(root_graph_); - if (graph_set.size() > 1) { - MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now."; - return SUCCESS; - } - auto forward_graph = *(graph_set.begin()); - MS_EXCEPTION_IF_NULL(forward_graph); - forward_ret_ = forward_graph->get_return(); - MS_EXCEPTION_IF_NULL(forward_ret_); - - if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) { - MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed."; - if (AddNodeToGraph() != SUCCESS) { - MS_LOG(ERROR) << "AddNodeToGraph failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed."; - if (AddEdgeToGraph() != SUCCESS) { - MS_LOG(ERROR) << "AddNodeToGraph failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed."; - if (SetFusionByAlgorithm(algorithm) != SUCCESS) { - MS_LOG(ERROR) << "SetFusionByAlgorithm failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed."; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h deleted file mode 100644 index 43a9935095..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ - -#include -#include -#include "ir/anf.h" -#include "parallel/allreduce_fusion/allreduce_graph.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -using CNodeCostMap = std::unordered_map; - -constexpr int32_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM = 0; -constexpr int32_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES = 0; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT = 0.1; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME = 0.1; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0.1; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1; - -constexpr char FUSION[] = "fusion"; -constexpr char PARAMETER[] = "parameter"; -const uint32_t MAX_RECURSIVE_CALL_TIMES = 100; -class AllreduceFusion { - public: - AllreduceFusion() - : allreduce_graph_(), - ret_(nullptr), - forward_ret_(nullptr), - root_graph_(nullptr), - tail_time_(0), - allreduce_inherent_time_(0), - allreduce_bandwidth_(0), - computation_time_parameter_(0) {} - virtual ~AllreduceFusion() = default; - Status ProcessAllreduceFusion(const CNodePtr &ret); - - private: - Status AddNodeToGraph(); - CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; - CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; - Status AddEdgeToGraph(); - std::vector GenerateCostMap(int32_t fusion_times, double tail_percent) const; - Status SetFusion(const std::vector &cost_map); - Status SetFusionByAlgorithm(int32_t algorithm); - Status SetFusionByBackwardCompTime(); - Status SetFusionByBackwardCompAndAllreduceTime(); - Status GetSetFusionByBackwardCompAndAllreduceTimeParams(); - - AllreduceGraph allreduce_graph_; - CNodePtr ret_; - CNodePtr forward_ret_; - FuncGraphPtr root_graph_; - double tail_time_; - double allreduce_inherent_time_; - double allreduce_bandwidth_; - double computation_time_parameter_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc deleted file mode 100644 index 2a98a38add..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc +++ /dev/null @@ -1,209 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/allreduce_fusion/allreduce_graph.h" -#include -#include -#include "ir/anf.h" -#include "parallel/allreduce_fusion/allreduce_node.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) { - AllreduceNodePtr arnode; - auto cnode_emplace_return = cnode_set_.emplace(node); - if (!cnode_emplace_return.second) { - MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!"; - auto cnode_arnode_pair = cnode_arnode_map_.find(node); - if (cnode_arnode_pair == cnode_arnode_map_.end()) { - MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!"; - } - arnode = cnode_arnode_pair->second; - } else { - arnode = std::make_shared(AllreduceNode()); - } - - if (arnode->Init(node) != SUCCESS) { - MS_LOG(ERROR) << "AllreduceNode Init failed"; - return FAILED; - } - if (arnode->AddPara(para) != SUCCESS) { - MS_LOG(ERROR) << "AllreduceNode AddPara failed"; - return FAILED; - } - cnode_arnode_map_[node] = arnode; - - auto arnode_emplace_return = arnode_set_.insert(arnode); - if (!arnode_emplace_return.second) { - MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!"; - } - cnode_emplace_return = para_cnodeset_map_[para].emplace(node); - if (!cnode_emplace_return.second) { - MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope() - << "'s cnodeset!"; - } - auto para_emplace_return = cnode_paraset_map_[node].emplace(para); - if (!para_emplace_return.second) { - MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString() - << "'s paraset!"; - } - return SUCCESS; -} - -Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { - auto from_arnode_iter = cnode_arnode_map_.find(from); - if (from_arnode_iter == cnode_arnode_map_.end()) { - MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; - PrintCNodeSet(); - return FAILED; - } - auto to_arnode_iter = cnode_arnode_map_.find(to); - if (to_arnode_iter == cnode_arnode_map_.end()) { - MS_LOG(ERROR) << "cnode to: " << to->DebugString() << "has not been added"; - PrintCNodeSet(); - return FAILED; - } - auto from_arnode = from_arnode_iter->second; - auto to_arnode = to_arnode_iter->second; - if (from_arnode->AddNext(to_arnode) != SUCCESS) { - MS_LOG(ERROR) << "from_arnode AddNext failed"; - return FAILED; - } - if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) { - MS_LOG(ERROR) << "to_arnode AddPrev failed"; - return FAILED; - } - max_ = std::max(max_, to_arnode->depend_feat_size()); - MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString(); - MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size() - << ", to depend_feat_size: " << to_arnode->depend_feat_size(); - return SUCCESS; -} - -bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { - auto cnode_iter = cnode_set_.find(node); - return !(cnode_iter == cnode_set_.end()); -} - -std::vector AllreduceGraph::GetParaByCost(double from, double to) { - std::vector nodes; - for (auto &cnode_arnode : cnode_arnode_map_) { - MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() - << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() - << " curr_para_size: " << cnode_arnode.second->curr_para_size(); - if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) { - (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(), - cnode_paraset_map_[cnode_arnode.first].end()); - } - } - return nodes; -} - -std::pair, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) { - std::vector nodes; - double cur_para_size = 0; - double from = to; - for (auto &arnode : arnode_vec_) { - if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { - continue; - } - if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) { - return std::make_pair(nodes, from); - } - (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end()); - cur_para_size += arnode.curr_para_size(); - from = arnode.depend_feat_size(); - } - MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size - << " cur_para_size: " << cur_para_size << " from: " << from; - return std::make_pair(nodes, from); -} - -void AllreduceGraph::PrintCNodeSet() const { - MS_LOG(INFO) << "CNodeSet:"; - for (auto &cnode : cnode_set_) { - MS_LOG(INFO) << cnode->DebugString(); - } -} - -void AllreduceGraph::PrintAllredueGraphInfo() const { - MS_LOG(INFO) << "max: " << max_; - for (auto &cnode_arnode : cnode_arnode_map_) { - MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); - MS_LOG(INFO) << "arnode info: "; - cnode_arnode.second->ToString(); - } -} - -void AllreduceGraph::PrintArnodeVec() const { - MS_LOG(INFO) << "ArnodeVec:"; - for (auto &arnode : arnode_vec_) { - arnode.ToString(); - } -} - -void AllreduceGraph::PrintArnodeSet() const { - MS_LOG(INFO) << "ArnodeSet:"; - for (auto &arnode : arnode_set_) { - arnode->ToString(); - } -} - -void AllreduceGraph::SortArnode() { - arnode_vec_.clear(); - for (auto &node : arnode_set_) { - arnode_vec_.emplace_back(*node); - } - std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); -} - -Status AllreduceGraph::RemoveExtraParas() { - std::unordered_set para_map; - for (auto &node : arnode_vec_) { - for (auto ¶ : node.paras()) { - auto emplac_result = para_map.emplace(para); - if (!emplac_result.second) { - MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; - if (node.RemovePara(para) != SUCCESS) { - MS_LOG(ERROR) << "remove para failed"; - return FAILED; - } - } - } - } - return SUCCESS; -} - -Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { - auto arnode = std::make_shared(AllreduceNode()); - if (arnode->Init(node) != SUCCESS) { - MS_LOG(ERROR) << "AllreduceNode Init failed"; - } - head_cnode_ = node; - cnode_arnode_map_[node] = arnode; - auto arnode_emplace_return = arnode_set_.insert(arnode); - if (!arnode_emplace_return.second) { - MS_LOG(WARNING) << "node: " << node->DebugString() << "'s arnode has already been added!"; - } - auto cnode_emplace_return = cnode_set_.emplace(node); - if (!cnode_emplace_return.second) { - MS_LOG(WARNING) << "node: " << node->DebugString() << " has already been added!"; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h deleted file mode 100644 index b2084b735c..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "parallel/allreduce_fusion/allreduce_node.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -class AllreduceGraph { - public: - AllreduceGraph() - : head_cnode_(nullptr), - arnode_set_(), - arnode_vec_(), - cnode_set_(), - para_cnode_map_(), - para_cnodeset_map_(), - cnode_paraset_map_(), - cnode_arnode_map_(), - max_(0) {} - virtual ~AllreduceGraph() = default; - Status AddNode(const CNodePtr &node, const AnfNodePtr ¶); - Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); - bool NodeInGraph(const CNodePtr &node) const; - std::vector GetParaByCost(double from, double to); - // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is - // over para_size. - // Return the parameter AnfNodePtr vector corresponding to these AllreduceNodes and the smallest depend_feat_size. - // If the sum of left AllreduceNode's parameter size is less than para_size, the returned depend_feat_size must be 0. - std::pair, double> GetParaByParaSize(double to, double para_size); - // If one parameter is used by multiple AllreduceNode, parameter belong to the last node for backward computation - // is saved by the corresponding AllreduceNode, parameters belong to other AllreduceNode are removed. - // Called during precise optimization, not implemented temporarily. - void SortArnode(); - Status RemoveExtraParas(); - void PrintCNodeSet() const; - void PrintAllredueGraphInfo() const; - void PrintArnodeVec() const; - void PrintArnodeSet() const; - const std::unordered_set &cnode_set() const { return cnode_set_; } - CNodePtr head_cnode() const { return head_cnode_; } - Status set_head_cnode(const CNodePtr &node); - double max() const { return max_; } - - private: - CNodePtr head_cnode_; - std::set arnode_set_; - std::vector arnode_vec_; - std::unordered_set cnode_set_; - // If One ParameterPtr is used by multiple CNode, the last node for backward computation is saved. - std::unordered_map> para_cnode_map_; - // One ParameterPtr may be used by multiple CNode - std::unordered_map> para_cnodeset_map_; - // Multiple Parameter may be inputs to the same CNode - std::unordered_map> cnode_paraset_map_; - std::unordered_map cnode_arnode_map_; - double max_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc deleted file mode 100644 index 113d4ec59b..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/allreduce_fusion/allreduce_node.h" -#include -#include "parallel/tensor_layout/tensor_layout.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status AllreduceNode::AddNext(const AllreduceNodePtr &next_node) { - if (next_node == nullptr) { - MS_LOG(ERROR) << "next_node is nullptr!"; - return FAILED; - } - next_.emplace_back(next_node); - return SUCCESS; -} - -Status AllreduceNode::AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max) { - if (prev_node == nullptr) { - MS_LOG(ERROR) << "next_node is nullptr!"; - return FAILED; - } - if (dist <= 0) { - MS_LOG(ERROR) << "dist must be positive! dist: " << dist; - return FAILED; - } - prev_.emplace_back(prev_node); - double add_dist = prev_node->depend_feat_size() + dist; - depend_feat_size_ += add_dist; - if (depend_feat_size_ > *max) { - *max = depend_feat_size_; - } - std::queue next_queue; - for (auto &next : next_) { - next_queue.push(next); - } - while (!next_queue.empty()) { - auto ele = next_queue.front(); - ele->AddDependFeatSize(add_dist); - if (ele->depend_feat_size() > *max) { - *max = ele->depend_feat_size(); - } - for (auto &next : ele->next()) { - next_queue.push(next); - } - next_queue.pop(); - } - return SUCCESS; -} - -Status AllreduceNode::Init(const CNodePtr &cnode_ptr) { - if (cnode_ptr == nullptr) { - MS_LOG(ERROR) << "cnode_ptr is nullptr!"; - return FAILED; - } - cnode_ptr_ = cnode_ptr; - return SUCCESS; -} - -Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { - if (node_ptr == nullptr) { - MS_LOG(ERROR) << "node_ptr is nullptr!"; - return FAILED; - } - if (!node_ptr->isa()) { - MS_LOG(ERROR) << "node_ptr is not a ParameterPtr!"; - return FAILED; - } - auto para_ptr = node_ptr->cast(); - MS_EXCEPTION_IF_NULL(para_ptr); - auto layout_ptr = para_ptr->tensor_layout(); - if (layout_ptr == nullptr) { - MS_LOG(ERROR) << "layout_ptr is nullptr!"; - return FAILED; - } - auto emplace_return = paras_.emplace(node_ptr); - if (emplace_return.second) { - double para_size = static_cast(layout_ptr->slice_shape().size()); - curr_para_size_ += para_size; - para_size_map_[node_ptr] = para_size; - } else { - MS_LOG(INFO) << "node already exist!"; - } - return SUCCESS; -} - -Status AllreduceNode::RemovePara(const AnfNodePtr &node_ptr) { - if (node_ptr == nullptr) { - MS_LOG(ERROR) << "node_ptr is nullptr!"; - return FAILED; - } - auto erase_num = paras_.erase(node_ptr); - if (erase_num == 0) { - MS_LOG(ERROR) << "para not find!"; - return FAILED; - } - curr_para_size_ -= para_size_map_[node_ptr]; - return SUCCESS; -} - -void AllreduceNode::ToString() const { - MS_LOG(INFO) << "cnode: " << cnode_ptr_->DebugString() << "para size: " << paras_.size(); - for (auto ¶ : paras_) { - MS_LOG(INFO) << "para name: " << para->fullname_with_scope() << " size: " << para_size_map_.at(para); - } - MS_LOG(INFO) << "depend_feat_size: " << depend_feat_size_ << " curr_para_size: " << curr_para_size_; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h deleted file mode 100644 index db1c4e3f2e..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -class AllreduceNode; -using AllreduceNodePtr = std::shared_ptr; - -class AllreduceNode { - public: - AllreduceNode() - : cnode_ptr_(nullptr), prev_(), next_(), paras_(), para_size_map_(), curr_para_size_(0), depend_feat_size_(0) {} - Status Init(const CNodePtr &cnode_ptr); - Status AddPara(const AnfNodePtr &node_ptr); - Status RemovePara(const AnfNodePtr &node_ptr); - const std::unordered_set ¶s() const { return paras_; } - double curr_para_size() const { return curr_para_size_; } - virtual ~AllreduceNode() = default; - // Add previous node - // prev_node is the previous to be added - // max is the current max depend_feat_size of the AllreduceGraph - Status AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max); - Status AddNext(const AllreduceNodePtr &next_node); - double depend_feat_size() const { return depend_feat_size_; } - void AddDependFeatSize(double add_dist) { depend_feat_size_ += add_dist; } - const std::vector &next() const { return next_; } - void ToString() const; - bool operator<(const AllreduceNode &node) const { return depend_feat_size_ < node.depend_feat_size(); } - bool operator>(const AllreduceNode &node) const { return depend_feat_size_ > node.depend_feat_size(); } - - private: - CNodePtr cnode_ptr_; - std::vector prev_; - std::vector next_; - std::unordered_set paras_; - std::unordered_map para_size_map_; - double curr_para_size_; - double depend_feat_size_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc deleted file mode 100644 index 999c4a85a9..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/allreduce_fusion/step_allreduce_fusion.h" -#include -#include -#include "optimizer/optimizer.h" -#include "parallel/allreduce_fusion/allreduce_fusion.h" -#include "parallel/context.h" -#include "parallel/graph_util/graph_info.h" -#include "parallel/status.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(optimizer); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); - bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion(); - // assume no change to graph - bool changes = false; - // control whether use model_parallel mode - if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || - (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) { - return changes; - } -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif - MS_LOG(INFO) << "Now entering allreduce fusion"; - DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN)); - - pipeline::ResourceBasePtr res = optimizer->resource(); - MS_EXCEPTION_IF_NULL(res); - - FuncGraphManagerPtr manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); - CNodePtr ret = root->get_return(); - MS_EXCEPTION_IF_NULL(ret); - - AllreduceFusion allreduce_fusion; - if (allreduce_fusion.ProcessAllreduceFusion(ret) != SUCCESS) { - MS_LOG(EXCEPTION) << "ProcessAllreduceFusion failed"; - } - - DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); - - // allreduce fusion only run once - root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true); - res->results()[pipeline::kStepParallelGraph] = root; -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); - time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << time << " us"; -#endif - return changes; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.h b/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.h deleted file mode 100644 index 2343a7a2fe..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ - -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace parallel { -constexpr char ALLREDUCE_FUSION_RUN_ONCE_ONLY[] = "allreduce_fusion_run_once_only"; -constexpr char ALLREDUCE_FUSION_BEGIN[] = "allreduce_fusion_begin"; -constexpr char ALLREDUCE_FUSION_END[] = "allreduce_fusion_end"; - -bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc deleted file mode 100644 index 65e9acf714..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc +++ /dev/null @@ -1,123 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/costmodel.h" -#include -#include -#include -#include "parallel/auto_parallel/graph_costmodel.h" - -namespace mindspore { -namespace parallel { -void Simplify(CostPtrList *clist_ptrs) { - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs); - } else { - // inference phase - SimplifyForDecreasingCommunicationForward(clist_ptrs); - } -} -void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { - // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method - // excludes the cost with greater computation_cost_ and greater communication_forward. - // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} - if (!COST_MODEL_SIMPLIFY_CALCULATION) { - return; - } - MS_EXCEPTION_IF_NULL(clist_ptrs); - std::vector id(clist_ptrs->size()); - std::iota(id.begin(), id.end(), size_t(0)); - std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { - return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; - }); - CostPtrList ret; - for (size_t i = 0; i < clist_ptrs->size(); ++i) { - if ((ret.size() == size_t(0)) || - (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) { - ret.emplace_back(std::move(clist_ptrs->at(id[i]))); - } - } - *clist_ptrs = std::move(ret); -} - -void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { - // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing - // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. - if (!COST_MODEL_SIMPLIFY_CALCULATION) { - return; - } - MS_EXCEPTION_IF_NULL(clist_ptrs); - std::vector id(clist_ptrs->size()); - std::iota(id.begin(), id.end(), size_t(0)); - std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { - return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; - }); - CostPtrList ret; - for (size_t i = 0; i < clist_ptrs->size(); ++i) { - if ((ret.size() == size_t(0)) || - (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) { - ret.emplace_back(std::move(clist_ptrs->at(id[i]))); - } - } - *clist_ptrs = std::move(ret); -} - -void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) { - MS_EXCEPTION_IF_NULL(origin_cost); - if (is_redistribution) { - // Redistribution cost - if ((origin_cost->communication_redis_forward_ > EPS) && - (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS; - } - if ((origin_cost->communication_redis_backward_ > EPS) && - (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS; - } - origin_cost->communication_cost_ = - origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_; - origin_cost->communication_without_parameter_ = origin_cost->communication_cost_; - origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_; - } else { - // Operator cost - double backward = 0.0; - if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) { - backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_; - } - // forward cost - if ((origin_cost->communication_without_parameter_ > EPS) && - (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS; - } - // total - if (origin_cost->communication_cost_ > EPS) { - origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward; - } - if (origin_cost->communication_with_partial_para_ > EPS) { - origin_cost->communication_with_partial_para_ = - origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward; - } - } -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h deleted file mode 100644 index 8b92e18cd8..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ /dev/null @@ -1,311 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ - -#include -#include -#include -#include -#include -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_info.h" - -namespace mindspore { -namespace parallel { -struct Decision; -using OperatorName = std::string; -using Attr = std::pair; -using Param = std::pair, int32_t>; -using OperatorParams = std::vector; -using OperatorAttrs = std::vector; -// OutPutInfo.fist: true if the operator's output is a tuple -// OutPutInfo.second: elements number of the tuple output. Only meaningful if OutPutInfo.fist is true. -using OutPutInfo = std::pair; -using OutPutInfoVector = std::vector; -using OperatorArgs = std::pair; -using Operator = std::pair; -using OperatorVector = std::vector; -using RedistributionOpListPtr = std::shared_ptr>; - -struct Cost { - Cost(); - Cost(double computation, double commuication, const std::shared_ptr &decision_ = nullptr) - : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { - memory_with_reuse_ = 0.0; - communication_without_parameter_ = 0.0; - communication_with_partial_para_ = 0.0; - communication_redis_forward_ = 0.0; - communication_redis_backward_ = 0.0; - communication_forward_ = 0.0; - } - // 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase - double memory_with_reuse_; - // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated - // by ONLY forward phase - double computation_cost_; - // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution) - double communication_cost_; - // communication_without_parameter_ = communication_cost_ - (backward communication from operators) - double communication_without_parameter_; - // communication_with_partial_para_ = - // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ ) - double communication_with_partial_para_; - // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution. - double communication_forward_; - double communication_redis_forward_; - double communication_redis_backward_; - std::shared_ptr decision_ptr_; -}; - -using CostPtr = std::shared_ptr; -using CostPtrList = std::vector>; - -class StrategyWithCost { - public: - StrategyWithCost(StrategyPtr strategy, std::vector inputs_, std::vector outputs_) - : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} - - StrategyWithCost(const StrategyWithCost &swc) = delete; - StrategyWithCost(StrategyWithCost &&swc) - : strategy_ptr(swc.strategy_ptr), - inputs_ptr(swc.inputs_ptr), - outputs_ptr(swc.outputs_ptr), - cost_list(swc.cost_list) {} - ~StrategyWithCost() = default; - - StrategyPtr strategy_ptr; - std::vector inputs_ptr; - std::vector outputs_ptr; - CostPtrList cost_list; -}; - -enum DecisionType { - OP_ELIMINATION, - EDGE_ELIMINATION, - MERGE_ELIMINATION, - CONTRACT_ELIMINATION, - TRIANGLE_ELIMINATION, - STAR_ELIMINATION, - FINAL_TYPE, - FINAL_SINGLE -}; - -struct Decision : public Base { - ~Decision() override = default; - DecisionType type_; -}; - -// 'OpEliminationDecision' is for the Operator Elimination in DP algorithm: u --> v --> w ==> u --> w. -// This data structure records the strategy 'op_strategy_' for v, the edge cost 'left_cost_' for 'u --> v', the -// operator cost 'middle_cost_' for v, and the edge cost 'right_cost_' for 'v --> w' -struct OpEliminationDecision : public Decision { - OpEliminationDecision(StrategyPtr op_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost) - : op_strategy_(std::move(op_stra)), - left_cost_(std::move(l_cost)), - middle_cost_(std::move(m_cost)), - right_cost_(std::move(r_cost)) { - type_ = DecisionType::OP_ELIMINATION; - } - - StrategyPtr op_strategy_; - CostPtr left_cost_; - CostPtr middle_cost_; - CostPtr right_cost_; - MS_DECLARE_PARENT(OpEliminationDecision, Decision); -}; - -/* 'EdgeEliminationDecision' is for the Edge Elimination in DP algorithm: - ____ - / \ - u v ==> u --> v, which replace the multi-edges by a single edge. - \____/ - This data structure records the cost list for all edges 'edges_cost_list_' - */ -struct EdgeEliminationDecision : public Decision { - explicit EdgeEliminationDecision(CostPtrList cost_list) : edges_cost_list_(std::move(cost_list)) { - type_ = DecisionType::EDGE_ELIMINATION; - } - - CostPtrList edges_cost_list_; - MS_DECLARE_PARENT(EdgeEliminationDecision, Decision); -}; - -// 'MergeEliminationDecision' is for the Merge Elimination in DP algorithm: -// w -// | -// | ==> u --> v -// u --> v In the original graph, v has two alive incoming edges, w has one alive outgoing edge, -// and w has zero alive incoming edges. After the Merge Elimination, the result graph contains only 'u -- >v'. -// This data structure records the strategy 'merged_op_strategy_' for operator 'w', -// the cost 'merged_op_cost_' for operator 'w', and the edge cost 'edge_cost_' for 'w --> v'. -struct MergeEliminationDecision : public Decision { - MergeEliminationDecision(StrategyPtr op_stra, CostPtr op_cost, CostPtr edge_c, StrategyPtr tar_op_stra, - CostPtr target_op_c) - : merged_op_strategy_(std::move(op_stra)), - merged_op_cost_(std::move(op_cost)), - edge_cost_(std::move(edge_c)), - target_op_strategy_(std::move(tar_op_stra)), - target_op_cost_(std::move(target_op_c)) { - type_ = DecisionType::MERGE_ELIMINATION; - } - - StrategyPtr merged_op_strategy_; - CostPtr merged_op_cost_; - CostPtr edge_cost_; - StrategyPtr target_op_strategy_; - CostPtr target_op_cost_; - MS_DECLARE_PARENT(MergeEliminationDecision, Decision); -}; - -// 'ContractEliminationDecision' is for the Contract Elimination in DP algorithm: -// u --> v -// | -// | ==> u --> w -// w In the original graph, u has two alive outgoing edges, v has one alive incoming edge, -// and v has zero outgoing edge. After the Contract Elimination, the result graph contains only 'u --> w'. -// This data structure records the strategy 'contracted_op_strategy_' for operator 'v', the cost for -// operator 'contracted_op_cost_', and the edge cost for 'edge_cost_'. -struct ContractEliminationDecision : public Decision { - ContractEliminationDecision(StrategyPtr contra_stra, CostPtr contra_op_cost, CostPtr edge_cost, - StrategyPtr target_stra, CostPtr tar_cost) - : contracted_op_strategy_(std::move(contra_stra)), - contracted_op_cost_(std::move(contra_op_cost)), - edge_cost_(std::move(edge_cost)), - target_op_strategy_(std::move(target_stra)), - target_cost_(std::move(tar_cost)) { - type_ = DecisionType::CONTRACT_ELIMINATION; - } - - StrategyPtr contracted_op_strategy_; - CostPtr contracted_op_cost_; - CostPtr edge_cost_; - StrategyPtr target_op_strategy_; - CostPtr target_cost_; - MS_DECLARE_PARENT(ContractEliminationDecision, Decision); -}; - -/* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm: - * - * u - * / \ - * / \ - * v --- w ==> v --- w In the original graph, u has 2 outgoing edges, v has 1 outgoing edge, - * and w has 2 incoming edges, u can be eliminated into v. - * 'eliminated_op_strategy_' is for u, 'eliminated_op_cost_' is for u, 'eliminated_left_edge_' is for edge u --> v, - * 'eliminated_right_edge_' is for edge u --> w. - */ -struct TriangleEliminationDecision : public Decision { - TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost, - StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra) - : eliminated_op_strategy_(std::move(elimi_stra)), - eliminated_op_cost_(std::move(elimi_op_cost)), - left_edge_cost_(std::move(l_edge_cost)), - right_edge_cost_(std::move(r_edge_cost)), - left_node_strategy_(std::move(left_stra)), - left_node_cost_(std::move(l_node_cost)), - right_node_strategy_(std::move(right_stra)) { - type_ = DecisionType::TRIANGLE_ELIMINATION; - } - - StrategyPtr eliminated_op_strategy_; - CostPtr eliminated_op_cost_; - CostPtr left_edge_cost_; - CostPtr right_edge_cost_; - StrategyPtr left_node_strategy_; - CostPtr left_node_cost_; - StrategyPtr right_node_strategy_; - MS_DECLARE_PARENT(TriangleEliminationDecision, Decision); -}; - -/* 'StarEliminationDecision' is for the Star Elimination in DP algorithm: - * - * v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges. - * In addition, v and w have other complicated connections, resulting in v and w can not be performed other - * eliminations. After the StarElimination, u is merged into v, and the resulting graph is splitted into multiple - * connected components. - * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. - */ -struct StarEliminationDecision : public Decision { - StarEliminationDecision(StrategyPtr elimi_op_stra, CostPtr elimi_op_cost, CostPtrList succ_edges_clist, - std::vector succ_ops_stra_list, CostPtrList succ_ops_clist) - : eliminated_op_strategy_(std::move(elimi_op_stra)), - eliminated_op_cost_(std::move(elimi_op_cost)), - succ_edges_cost_list_(std::move(succ_edges_clist)), - succ_ops_stra_list_(std::move(succ_ops_stra_list)), - succ_ops_cost_list_(std::move(succ_ops_clist)) { - type_ = DecisionType::STAR_ELIMINATION; - } - - StrategyPtr eliminated_op_strategy_; - CostPtr eliminated_op_cost_; - CostPtrList succ_edges_cost_list_; - std::vector succ_ops_stra_list_; - CostPtrList succ_ops_cost_list_; - MS_DECLARE_PARENT(StarEliminationDecision, Decision); -}; - -// This data structure records the decision for the graph which contains two nodes: u --> v. This includes -// the strategy 'u_strategy_' for 'u', the strategy 'v_strategy_' for 'v', the cost 'left_cost_' for 'u'. -struct FinalDecision : public Decision { - FinalDecision(StrategyPtr u_stra, StrategyPtr v_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost) - : u_strategy_(std::move(u_stra)), - v_strategy_(std::move(v_stra)), - left_cost_(std::move(l_cost)), - middle_cost_(std::move(m_cost)), - right_cost_(std::move(r_cost)) { - type_ = DecisionType::FINAL_TYPE; - } - - StrategyPtr u_strategy_; - StrategyPtr v_strategy_; - CostPtr left_cost_; - CostPtr middle_cost_; - CostPtr right_cost_; - MS_DECLARE_PARENT(FinalDecision, Decision); -}; - -// This data structure records the final decision for the graph containing a single node: u. This includes -// the strategy 'u_strategy_' for 'u', the cost 'u_cost_' for 'u'. -struct FinalSingleDecision : public Decision { - FinalSingleDecision(StrategyPtr u_stra, CostPtr u_cost) : u_strategy_(std::move(u_stra)), u_cost_(std::move(u_cost)) { - type_ = DecisionType::FINAL_SINGLE; - } - - StrategyPtr u_strategy_; - CostPtr u_cost_; - MS_DECLARE_PARENT(FinalSingleDecision, Decision); -}; - -using DecisionPtr = std::shared_ptr; -using OpEliminationDecisionPtr = std::shared_ptr; -using EdgeEliminationDecisionPtr = std::shared_ptr; -using MergeEliminationDecisionPtr = std::shared_ptr; -using ContractEliminationDecisionPtr = std::shared_ptr; -using TriangleEliminationDecisionPtr = std::shared_ptr; -using StarEliminationDecisionPtr = std::shared_ptr; -using FinalDecisionPtr = std::shared_ptr; -using FinalSingleDecisionPtr = std::shared_ptr; - -void Simplify(CostPtrList *clist); -void SimplifyForDecreasingCommunicationForward(CostPtrList *clist); -void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist); -void RefineForPracticalCost(const CostPtr &, bool is_redistribution); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc deleted file mode 100644 index 72451fab57..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc +++ /dev/null @@ -1,226 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/dp_algo_costmodel.h" - -#include -#include -#include - -namespace mindspore { -namespace parallel { -Status GetStrategy(const CostGraphPtr &graph) { - MS_LOG(INFO) << "Searching strategies begins."; - MS_EXCEPTION_IF_NULL(graph); - std::vector eliminations; - bool flag = true; - - // Phase 1: Shrink the CostGraph using 6 operations, and record them in the order. - // Note: the checking and applying of the 6 operations MUST in current order. - while (flag) { - flag = false; - auto node = graph->CheckOpElimination(); - if (node != nullptr) { - // Applying the Operator Elimination - flag = true; - auto l_edge = node->GetAlivePrevEdges()[0]; - auto r_edge = node->GetAliveSuccEdges()[0]; - auto n_edge = graph->EliminationOp(node); - auto elimi = std::make_shared(n_edge, l_edge, node, r_edge); - eliminations.emplace_back(std::move(elimi)); - } - auto edges = graph->CheckEdgeElimination(); - if ((!flag) && (!edges.empty())) { - // Applying the Edge Elimination - flag = true; - auto n_edge = graph->EliminationEdges(edges); - auto elimi = std::make_shared(n_edge, edges); - eliminations.emplace_back(std::move(elimi)); - } - auto merge_node = graph->CheckMergeElimination(); - if ((!flag) && (merge_node != nullptr)) { - // Applying the Merge Elimination - flag = true; - auto succ_edge = merge_node->GetAliveSuccEdges()[0]; - auto target_node = graph->EliminationMerge(merge_node); - auto elimi = std::make_shared(merge_node, succ_edge, target_node); - eliminations.emplace_back(std::move(elimi)); - } - auto contracted_node = graph->CheckContractElimination(); - if ((!flag) && (contracted_node != nullptr)) { - // Applying the Contract Elimination - flag = true; - auto prev_edge = contracted_node->GetAlivePrevEdges()[0]; - auto target_node = graph->EliminationContract(contracted_node); - auto elimi = std::make_shared(target_node, prev_edge, contracted_node); - eliminations.emplace_back(std::move(elimi)); - } - auto triangle_pair = graph->CheckTriangleElimination(); - if ((!flag) && (triangle_pair.first != nullptr)) { - // Applying the Triangle Elimination - flag = true; - auto eliminated_node = triangle_pair.first; - auto l_r_edge = triangle_pair.second; - - auto left_node = l_r_edge->prev_operator(); - auto left_edge = eliminated_node->GetAliveSuccEdges()[0]; - auto right_edge = eliminated_node->GetAliveSuccEdges()[1]; - MS_EXCEPTION_IF_NULL(left_edge); - if (left_edge->next_operator() != left_node) { - auto tmp = left_edge; - left_edge = right_edge; - right_edge = tmp; - } - auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); - auto right_node = l_r_edge->next_operator(); - auto elimi = - std::make_shared(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); - eliminations.emplace_back(std::move(elimi)); - } - auto star_center = graph->CheckStarElimination(); - if ((!flag) && (star_center != nullptr)) { - // Applying the Star Elimination - flag = true; - auto succ_edges = graph->EliminationStar(star_center); - std::vector succ_nodes; - for (size_t i = 0; i < succ_edges.size(); ++i) { - MS_EXCEPTION_IF_NULL(succ_edges[i]); - succ_nodes.push_back(succ_edges[i]->next_operator()); - } - auto elimi = std::make_shared(star_center, succ_edges, succ_nodes); - eliminations.emplace_back(std::move(elimi)); - } - } - - // Phase 2: Search the cost_list in the final graph, and determine the optimal one - if (graph->SearchStrategy() != SUCCESS) { - MS_LOG(ERROR) << "Searching strategy for the final failed."; - return FAILED; - } - - // Phase 3: Recover the original CostGraph, the determine strategy for each operator - if (RecoverStrategy(eliminations) == SUCCESS) { - MS_LOG(INFO) << "Searching strategies ends."; - return SUCCESS; - } else { - MS_LOG(EXCEPTION) << "Searching strategies failed."; - } -} - -Status RecoverStrategy(std::vector eliminations) { - std::vector::reverse_iterator rit; - - for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) { - if ((*rit)->isa()) { - auto elimination = (*rit)->cast(); - auto e = elimination->new_edge_; - auto w = elimination->op_; - MS_EXCEPTION_IF_NULL(e); - MS_EXCEPTION_IF_NULL(w); - auto left_edge = elimination->left_edge_; - auto right_edge = elimination->right_edge_; - MS_EXCEPTION_IF_NULL(left_edge); - MS_EXCEPTION_IF_NULL(right_edge); - auto decision = e->selected_cost()->decision_ptr_->cast(); - w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_); - left_edge->set_selected_cost(decision->left_cost_); - right_edge->set_selected_cost(decision->right_cost_); - MS_LOG(INFO) << "Recover opElimination succeeded."; - } else if ((*rit)->isa()) { - auto elimination = (*rit)->cast(); - auto new_edge = elimination->new_edge_; - MS_EXCEPTION_IF_NULL(new_edge); - auto &edges = elimination->edges_; - auto decision = new_edge->selected_cost()->decision_ptr_->cast(); - for (size_t j = 0; j < edges.size(); ++j) { - MS_EXCEPTION_IF_NULL(edges[j]); - edges[j]->set_selected_cost(decision->edges_cost_list_[j]); - } - MS_LOG(INFO) << "Recover edgeElimination succeeded."; - } else if ((*rit)->isa()) { - auto elimination = (*rit)->cast(); - auto target_node = elimination->target_node_; - MS_EXCEPTION_IF_NULL(target_node); - auto merged_node = elimination->merged_node_; - MS_EXCEPTION_IF_NULL(merged_node); - auto merged_edge = elimination->dir_edge_; - MS_EXCEPTION_IF_NULL(merged_edge); - MS_EXCEPTION_IF_NULL(target_node->selected_cost()); - MS_EXCEPTION_IF_NULL(target_node->selected_cost()->decision_ptr_); - auto decision = target_node->selected_cost()->decision_ptr_->cast(); - merged_node->SetSelectedStrategyAndCost(decision->merged_op_strategy_, decision->merged_op_cost_); - merged_edge->set_selected_cost(decision->edge_cost_); - target_node->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_op_cost_); - - MS_LOG(INFO) << "Recover mergeElimination succeeded."; - } else if ((*rit)->isa()) { - auto elimination = (*rit)->cast(); - auto target_node = elimination->target_node_; - auto contracted_node = elimination->contracted_node_; - auto contracted_edge = elimination->dir_edge_; - auto decision = target_node->selected_cost()->decision_ptr_->cast(); - - contracted_node->SetSelectedStrategyAndCost(decision->contracted_op_strategy_, decision->contracted_op_cost_); - contracted_edge->set_selected_cost(decision->edge_cost_); - target_node->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_cost_); - MS_LOG(INFO) << "Recover contractElimination succeeded."; - } else if ((*rit)->isa()) { - auto elimination = (*rit)->cast(); - auto left_node = elimination->left_node_; - auto left_edge = elimination->left_edge_; - auto eliminated_node = elimination->eliminated_node_; - auto right_edge = elimination->right_edge_; - auto right_node = elimination->right_node_; - auto decision = left_node->selected_cost()->decision_ptr_->cast(); - - eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); - left_edge->set_selected_cost(decision->left_edge_cost_); - right_edge->set_selected_cost(decision->right_edge_cost_); - // Since Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy. - left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); - right_node->CheckSelectedStrategy(decision->right_node_strategy_); - MS_LOG(INFO) << "Recover triangleElimination succeeded."; - } else if ((*rit)->isa()) { - auto elimination = (*rit)->cast(); - auto merged_node = elimination->eliminated_node_; - auto succ_edges = elimination->succ_edges_; - auto succ_nodes = elimination->succ_ops_; - // decision is hided in succ_nodes[0] - auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast(); - - merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); - for (size_t i = 0; i < succ_edges.size(); ++i) { - succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]); - } - MS_EXCEPTION_IF_NULL(succ_nodes[0]); - MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]); - MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); - // Since Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy. - succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); - for (size_t k = 1; k < succ_nodes.size(); ++k) { - succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); - } - MS_LOG(INFO) << "Recover starElimination succeeded."; - } else { - MS_LOG(ERROR) << "Unknown Elimination type."; - return FAILED; - } - } - - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h deleted file mode 100644 index e3fbfba5a7..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ - -#include -#include -#include -#include "ir/value.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/auto_parallel/graph_costmodel.h" - -namespace mindspore { -namespace parallel { -// There are 3 meta phases of the Dynamic Programming (DP) algorithm. The input is a CostGraph, and the goal -// is to compute the strategy for each operator in the CostGraph. -// -// Phase 1: Shrink the CostGraph using 6 operations, and record them in the order -// Using for operations: Operator Elimination, Edge Elimination, Merge Elimination, and Contract Elimination, -// each connected component in the CostGraph can be shrunk in to the final graph: u --> v. See the -// interpretation of 6 operations in costmodel.h. -// Phase 2: Search the cost_list in the final graph, and determine the optimal one -// Create the cost_list for the final graph, and choose the optimal one: one the minimum quantity -// COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost -// Phase 3: Recover the original CostGraph, the determine strategy for each operator -// After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying -// the 4 operations in the reverse order in the Phase 1. Because each operation decision contains the strategy, -// the operators' strategies can be all determined. - -struct Elimination : public Base { - enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR }; - Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {} - - EdgePtr new_edge_; - EliminationType type_; -}; - -// Operator Elimination -struct OpElimination : public Elimination { - OpElimination(EdgePtr n_edge, EdgePtr l_edge, OperatorInfoPtr op_info, EdgePtr r_edge) - : Elimination(std::move(n_edge), Elimination::EliminationType::OPERA), - left_edge_(std::move(l_edge)), - op_(std::move(op_info)), - right_edge_(std::move(r_edge)) {} - - EdgePtr left_edge_; - OperatorInfoPtr op_; - EdgePtr right_edge_; - MS_DECLARE_PARENT(OpElimination, Elimination); -}; - -// Edge Elimination -struct EdgeElimination : public Elimination { - EdgeElimination(const EdgePtr &n_edge, std::vector eds) - : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {} - - std::vector edges_; - MS_DECLARE_PARENT(EdgeElimination, Elimination); -}; - -// Merge Elimination -struct MergeElimination : public Elimination { - MergeElimination(OperatorInfoPtr u_info, EdgePtr merged_target_edge, OperatorInfoPtr v_info) - : Elimination(nullptr, Elimination::EliminationType::MERGE), - merged_node_(std::move(u_info)), - dir_edge_(std::move(merged_target_edge)), - target_node_(std::move(v_info)) {} - - OperatorInfoPtr merged_node_; - EdgePtr dir_edge_; - OperatorInfoPtr target_node_; - MS_DECLARE_PARENT(MergeElimination, Elimination); -}; - -// Contract Elimination -struct ContractElimination : public Elimination { - ContractElimination(OperatorInfoPtr tar_info, EdgePtr tar_con_edge, OperatorInfoPtr con_info) - : Elimination(nullptr, Elimination::EliminationType::CONTRACT), - contracted_node_(std::move(con_info)), - dir_edge_(std::move(tar_con_edge)), - target_node_(std::move(tar_info)) {} - - OperatorInfoPtr contracted_node_; - EdgePtr dir_edge_; - OperatorInfoPtr target_node_; - MS_DECLARE_PARENT(ContractElimination, Elimination); -}; - -// Triangle Elimination -struct TriangleElimination : public Elimination { - TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, - OperatorInfoPtr r_node) - : Elimination(nullptr, Elimination::EliminationType::TRIANGLE), - eliminated_node_(std::move(elim_node)), - left_edge_(std::move(l_edge)), - left_node_(std::move(l_node)), - right_edge_(std::move(r_edge)), - right_node_(std::move(r_node)) {} - - OperatorInfoPtr eliminated_node_; - EdgePtr left_edge_; - OperatorInfoPtr left_node_; - EdgePtr right_edge_; - OperatorInfoPtr right_node_; - MS_DECLARE_PARENT(TriangleElimination, Elimination); -}; - -// Star Elimination -struct StarElimination : public Elimination { - StarElimination(OperatorInfoPtr elimi_node, std::vector s_edges, std::vector s_ops) - : Elimination(nullptr, Elimination::EliminationType::STAR), - eliminated_node_(std::move(elimi_node)), - succ_edges_(std::move(s_edges)), - succ_ops_(std::move(s_ops)) {} - - OperatorInfoPtr eliminated_node_; - std::vector succ_edges_; - std::vector succ_ops_; - MS_DECLARE_PARENT(StarElimination, Elimination); -}; - -using EliminationPtr = std::shared_ptr; -using OpEliminationPtr = std::shared_ptr; -using EdgeEliminationPtr = std::shared_ptr; -using MergeEliminationPtr = std::shared_ptr; -using ContractEliminationPtr = std::shared_ptr; -using TriangleEliminationPtr = std::shared_ptr; -using StarEliminationPtr = std::shared_ptr; - -// Phase 1 and Phase 2 -Status GetStrategy(const CostGraphPtr &graph); - -// Phase 3 -Status RecoverStrategy(std::vector eliminations); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc deleted file mode 100644 index 60256a3ae3..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc +++ /dev/null @@ -1,324 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/edge_costmodel.h" - -#include -#include -#include -#include -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status Edge::InitEdgeCost() { - bool has_available_cost = false; - for (auto &swc : prev_op_->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(swc); - pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr)); - } - for (auto &swc : next_op_->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(swc); - next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr)); - } - if (is_identity_edge) { - for (auto &target_output : pre_op_output_) { - auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); - auto target_output_str = target_output.first; - for (auto &target_input : next_op_input_) { - auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); - auto target_input_str = target_input.first; - if (target_output_lyt == target_input_lyt) { - CostPtrKey ck = {target_output_str, target_input_str}; - CostPtr cost = std::make_shared(0.0, 0.0); - MS_EXCEPTION_IF_NULL(cost); - cost->communication_without_parameter_ = 0.0; - cost->communication_with_partial_para_ = 0.0; - CostPtrList cl; - cl.push_back(cost); - (void)cost_map_.emplace(std::make_pair(ck, cl)); - has_available_cost = true; - } - } - } - } else { - for (auto &target_output : pre_op_output_) { - auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); - auto target_output_str = target_output.first; - auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_]; - auto type = prev_op_->outputs_type()[prev_op_output_index_]; - for (auto &target_input : next_op_input_) { - auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); - auto target_input_str = target_input.first; - CostPtr cost; - if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) { - MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed"; - } - MS_EXCEPTION_IF_NULL(cost); - MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_ - << ", communication_cost: " << cost->communication_cost_ - << ", communication_without_parameter_: " << cost->communication_without_parameter_ - << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << "."; - // refine communication cost calculation for practice - RefineForPracticalCost(cost, true); - cost->communication_forward_ = cost->communication_redis_forward_; - CostPtrKey ck = {target_output_str, target_input_str}; - CostPtrList cl; - cl.push_back(cost); - (void)cost_map_.emplace(std::make_pair(ck, cl)); - has_available_cost = true; - } - } - } - if (!has_available_cost) { - if (FULLY_USE_DEVICES) { - MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ - << " failed, it may be caused by setting 'fully_use_devices' true. Try to set " - "'fully_use_devices' false."; - } else if (ELEMENTWISE_OP_STRA_FOLLOW) { - MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ - << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. " - "Try to set 'elementwise_op_strategy_follow' false."; - } - if (edge_name_.find(RESHAPE) != std::string::npos) { - MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ - << " failed, it may be caused by setting different strategies for operators following Reshape. " - "Try to fix that."; - } - MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed."; - } - return Status::SUCCESS; -} - -Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, - size_t type_length, TypePtr type, CostPtr *cost) { - MS_EXCEPTION_IF_NULL(prev_op_); - MS_EXCEPTION_IF_NULL(cost); - RankList dev_list = prev_op_->global_device_list(); - TensorRedistribution tensor_redistribution(false); - - // Init TensorRedistribution - if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; - } - - if (tensor_redistribution.ComputeCost() == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; - } - - double comm_cost = tensor_redistribution.comm_cost(); - double forward_comm_cost = tensor_redistribution.forward_comm_cost(); - double backward_comm_cost = tensor_redistribution.backward_comm_cost(); - double computation_cost = tensor_redistribution.computation_cost(); - double mem_cost = tensor_redistribution.memory_cost(); - - // Now AllGather, ReduceScatter, AlltoAll don't support bool type - MS_EXCEPTION_IF_NULL(type); - if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) { - computation_cost = INF; - comm_cost = INF; - MS_LOG(WARNING) << "Communication Operators don't support bool dtype!"; - } - *cost = std::make_shared(type_length * computation_cost, type_length * comm_cost); - (*cost)->communication_without_parameter_ = type_length * comm_cost; - (*cost)->communication_with_partial_para_ = - (*cost)->communication_without_parameter_ + - COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_); - (*cost)->communication_redis_forward_ = type_length * forward_comm_cost; - (*cost)->communication_redis_backward_ = type_length * backward_comm_cost; - (*cost)->memory_with_reuse_ = mem_cost; - return Status::SUCCESS; -} - -CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) { - CostPtrKey ck = {output_str, input_str}; - CostPtrList result; - if (cost_map_.find(ck) != cost_map_.end()) { - return cost_map_.at(ck); - } - return result; -} - -CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector &edges, - const StrategyPtr &input_st_ptr) { - std::function LocalGetCostList = [&](const EdgePtr &edge) { - MS_EXCEPTION_IF_NULL(edge); - return edge->GetCostList(output_st_ptr, input_st_ptr); - }; - CostPtrList result; - std::vector all_cost_list; - all_cost_list.resize(edges.size()); - (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList); - - CostPtrList selected_cost_list(all_cost_list.size(), nullptr); - std::function recursive = - [&](size_t k, double computation, double memory, double communication, double communication_without_para, - double communication_forward) { - if (k == edges.size()) { - auto decision = std::make_shared(selected_cost_list); - CostPtr new_cost = std::make_shared(computation, communication); - MS_EXCEPTION_IF_NULL(new_cost); - new_cost->communication_without_parameter_ = communication_without_para; - new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - new_cost->memory_with_reuse_ = memory; - new_cost->communication_forward_ = communication_forward; - new_cost->decision_ptr_ = decision; - result.push_back(new_cost); - return; - } - for (auto &c : all_cost_list[k]) { - MS_EXCEPTION_IF_NULL(c); - selected_cost_list[k] = c; - recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_, - communication + c->communication_cost_, - communication_without_para + c->communication_without_parameter_, - communication_forward + c->communication_forward_); - } - }; - recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0); - Simplify(&result); - return result; -} - -void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector &edges, OperatorInfoPtr) { - bool valid = false; - for (const auto &output_pair : pre_op_output_) { - StrategyPtr output_st_ptr = output_pair.first; - for (const auto &input_pair : next_op_input_) { - StrategyPtr input_st_ptr = input_pair.first; - CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr); - CostPtrKey key = {output_st_ptr, input_st_ptr}; - cost_map_[key] = clist; - if ((!valid) && (!clist.empty())) { - valid = true; - } - } - } - if (!valid) { - MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed."; - } -} - -void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, - const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, - CostPtrList *ret_cost_list) { - for (auto &left_cost : left_cost_list) { - MS_EXCEPTION_IF_NULL(left_cost); - for (auto &middle_cost : middle_cost_list) { - MS_EXCEPTION_IF_NULL(middle_cost); - for (auto &right_cost : right_cost_list) { - MS_EXCEPTION_IF_NULL(right_cost); - double computation = - left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; - double communication = - left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_; - double communication_forward = - left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_; - double communication_without_para = left_cost->communication_without_parameter_ + - middle_cost->communication_without_parameter_ + - right_cost->communication_without_parameter_; - double memory_cost = - left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_; - - auto decision = std::make_shared(op_strategy, left_cost, middle_cost, right_cost); - auto cost = std::make_shared(computation, communication, decision); - MS_EXCEPTION_IF_NULL(cost); - cost->communication_without_parameter_ = communication_without_para; - cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - cost->memory_with_reuse_ = memory_cost; - cost->communication_forward_ = communication_forward; - ret_cost_list->emplace_back(std::move(cost)); - } - } - } -} - -CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr, - const OperatorInfoPtr &op, const EdgePtr &e2, - const StrategyPtr &input_st_ptr) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(e1); - MS_EXCEPTION_IF_NULL(e2); - CostPtrList result; - for (const auto &op_strategy : op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(op_strategy); - auto middle_strategy = op_strategy->strategy_ptr; - CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy), - op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result); - } - Simplify(&result); - return result; -} - -void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) { - bool valid = false; - for (const auto &output_pair : pre_op_output_) { - StrategyPtr output_st_ptr = output_pair.first; - for (const auto &input_pair : next_op_input_) { - StrategyPtr input_st_ptr = input_pair.first; - - CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr); - CostPtrKey key = {output_st_ptr, input_st_ptr}; - cost_map_[key] = clist; - if ((!valid) && (!clist.empty())) { - valid = true; - } - } - } - if (!valid) { - MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed."; - } -} - -Status Edge::CalculateMemoryCost() { - if (is_output_parameter_involve_ == -1) { - MS_LOG(ERROR) << "is_output_parameter_involve_ is unset."; - return FAILED; - } - if (is_output_parameter_involve_ == 0) { - // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is - // unnecessary to keep them in memory. - for (auto &cost_kv : cost_map_) { - auto &cost_v = cost_kv.second; - if (!cost_v.empty()) { - cost_v[0]->memory_with_reuse_ = 0; - } - } - } - - return SUCCESS; -} - -Status Edge::CalculateMemoryCostForInference() { - // Currently, memory cost is NOT calculated for redistribution - if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) { - MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_; - return FAILED; - } - for (auto &cost_kv : cost_map_) { - auto &cost_v = cost_kv.second; - if (!cost_v.empty()) { - cost_v[0]->memory_with_reuse_ = 0; - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h deleted file mode 100644 index 2a5ed3b2a4..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h +++ /dev/null @@ -1,171 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ -#define PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ - -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/tensor_layout/tensor_info.h" -#include "parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -using CostPtrKey = std::pair; -using OperatorInfoPtr = std::shared_ptr; -using EdgePtr = std::shared_ptr; - -class Edge { - // An 'Edge' connects two Operators in the CostGraph. - public: - Edge(const std::string &edge_name, const std::shared_ptr &prev_op, - const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, - const bool &is_com) - : edge_name_(edge_name), - prev_op_(prev_op), - next_op_(next_op), - prev_op_output_index_(output_index_), - next_op_input_index_(input_index_), - is_combined_(is_com) { - is_identity_edge = false; - } - - Edge(const std::string &edge_name, const std::shared_ptr &prev_op, - const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, - const bool &is_com, const bool &is_iden) - : edge_name_(edge_name), - prev_op_(prev_op), - next_op_(next_op), - prev_op_output_index_(output_index_), - next_op_input_index_(input_index_), - is_combined_(is_com), - is_identity_edge(is_iden) {} - - Edge(const std::string &edge_name, const std::shared_ptr &prev_op, - const std::shared_ptr &next_op, const std::vector &output_indexs_, - const std::vector &input_indexs_, const bool &is_com) - : edge_name_(edge_name), - prev_op_(prev_op), - next_op_(next_op), - pre_op_output_indexs_(output_indexs_), - next_op_input_indexs_(input_indexs_), - is_combined_(is_com) { - prev_op_output_index_ = 0; - next_op_input_index_ = 0; - is_identity_edge = false; - } - - ~Edge() = default; - std::shared_ptr prev_operator() const { return prev_op_; } - std::shared_ptr next_operator() const { return next_op_; } - std::string edge_name() const { return edge_name_; } - // Init cost_map_: for each output layout and input layout, calculate the cost - Status InitEdgeCost(); - // For two operators u--->v, given the output tensor layout of u, - // and the input tensor layout of v, return the redistribution cost, - // and the op_list to carry out the redistribution. - Status GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, - size_t, TypePtr type, CostPtr *cost); - - void set_pre_op_output(const std::vector, std::vector>> &output_set) { - pre_op_output_ = output_set; - } - void set_next_op_input(const std::vector, std::vector>> &input_set) { - next_op_input_ = input_set; - } - - // Given a pair of output strategy and input strategy, return the corresponding costlist - CostPtrList GetCostList(StrategyPtr output_str, StrategyPtr input_str); - - std::vector, std::vector>> prev_op_output() const { - return pre_op_output_; - } - std::vector, std::vector>> next_op_input() const { - return next_op_input_; - } - - bool is_combined() const { return is_combined_; } - size_t prev_op_output_index() const { return prev_op_output_index_; } - size_t next_op_input_index() const { return next_op_input_index_; } - std::vector prev_op_output_indexs() const { return pre_op_output_indexs_; } - std::vector next_op_input_indexs() const { return next_op_input_indexs_; } - - CostPtrList CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, - const std::vector> &edges, - const StrategyPtr &input_st_ptr); - // In the Edge Elimination operation in DP algorithm, 'edges' is replaced by a new edge. This method is used to - // set cost for this new edge - void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector> &edges, - std::shared_ptr v); - void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, - const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, - CostPtrList *ret_cost_list); - - CostPtrList CreateOpEliminationCostList(const std::shared_ptr &e1, const StrategyPtr &output_st_ptr, - const std::shared_ptr &op, const std::shared_ptr &e2, - const StrategyPtr &input_st_ptr); - // In the Operation Elimination operation in DP algorithm, 'op', 'e1' and 'e2' are replaced by a new edge. - // This method is used to set cost for this new edge - void OpEliminationSetNewCost(const std::shared_ptr &e1, const std::shared_ptr &op, - const std::shared_ptr &e2); - - void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } - const CostPtr &selected_cost() const { return selected_cost_; } - void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } - // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving - // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory - // at the end of forward phase. - Status CalculateMemoryCost(); - // In the inference phase, - Status CalculateMemoryCostForInference(); - void mark_output_critical() { is_output_critical_ = 1; } - - private: - std::string edge_name_; - std::shared_ptr prev_op_, next_op_; - std::map cost_map_; - // pre_op_output_ - std::vector, std::vector>> pre_op_output_; - std::vector, std::vector>> next_op_input_; - // the index of outputs of prev_op, and the index of inputs of next_op - size_t prev_op_output_index_, next_op_input_index_; - - // pre_op_output_indexs_ and next_op_input_indexs_ store the indexs of inputs and outputs if is_combined = true - std::vector pre_op_output_indexs_; - std::vector next_op_input_indexs_; - // is this edge constructed by combining multiple edges? If is is, then is_combined = true, else is_combined = false - bool is_combined_; - // When a Parameter in the ANF graph being used by multiple operators, we include the Parameter in the costgraph by - // replace the Parameter by a TmpIdentity operator, and connecting this TmpIdentity operator with subsequent - // operators. The resulting edges are different from those normal edges, thus this Bool variable distinguishes them. - // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor. - bool is_identity_edge; - CostPtr selected_cost_; - // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator - // is parameter-involved - int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved - // In the inference phase, this is used to mark whether the output of the previous operator is critical. - int is_output_critical_ = 0; -}; -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc deleted file mode 100644 index 05be097e6a..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ /dev/null @@ -1,1678 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include -#include -#include - -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/ops_info/reshape_info.h" -#include "parallel/step_auto_parallel.h" - -namespace mindspore { -namespace parallel { -CostGraphPtr entire_costgraph = nullptr; -size_t TOTAL_OPS = 0; -double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA; -bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; -double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY; -double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; -double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST; -double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS; -bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; -size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; -bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; -bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; -bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; -int32_t RUN_PHASE = DEFAULT_RUN_PHASE; -constexpr char RESHAPEINFO[] = "ReshapeInfo"; - -void CostGraph::SetDeviceMemoryAndCostParameter() { - MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); - - // DEVICE_MEMORY_CAPACITY - auto device_memory = CostModelContext::GetInstance()->device_memory_capacity(); - if (device_memory <= 0) { - MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; - } - dev_memory_ = device_memory; - DEVICE_MEMORY_CAPACITY = device_memory; - MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << "."; - - // COST_MODEL_ALPHA - auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); - if (alpha <= 0) { - MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; - } - costmodel_alpha_ = alpha; - MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; - - // COST_MODEL_BETA - auto beta = CostModelContext::GetInstance()->costmodel_beta(); - if (beta <= 0) { - MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; - } - costmodel_beta_ = beta; - MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; - - // COST_MODEL_GAMMA - auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); - if ((gamma < 0) || (gamma > 1)) { - MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; - } - COST_MODEL_GAMMA = gamma; - MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << "."; - - // COST_MODEL_SIMPLIFY_CALCULATION - auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal(); - COST_MODEL_SIMPLIFY_CALCULATION = simplify; - if (COST_MODEL_SIMPLIFY_CALCULATION) { - MS_LOG(INFO) << "costmodel_simplify_cal: true."; - } else { - MS_LOG(INFO) << "costmodel_simplify_cal: false."; - } - - // COST_MODEL_COMMUNI_THRESHOLD - auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold(); - if (communi_threshold < 0) { - MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; - } - COST_MODEL_COMMUNI_THRESHOLD = communi_threshold; - MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << "."; - - // COST_MODEL_COMMUNI_CONST - auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const(); - if (communi_const < 0) { - MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; - } - COST_MODEL_COMMUNI_CONST = communi_const; - MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << "."; - - // COST_MODEL_COMMUNI_BIAS - auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias(); - if (communi_bias < 0) { - MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; - } - COST_MODEL_COMMUNI_BIAS = communi_bias; - MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << "."; - - // TENSOR_SLICE_ALIGNMENT_ENABLE - auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable(); - TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable; - if (TENSOR_SLICE_ALIGNMENT_ENABLE) { - MS_LOG(INFO) << "tensor_slice_align_enable: true."; - } else { - MS_LOG(INFO) << "tensor_slice_align_enable: false."; - } - - // TENSOR_SLICE_ALIGNMENT_SIZE - auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size(); - if (align_size == 0) { - MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; - } - TENSOR_SLICE_ALIGNMENT_SIZE = align_size; - MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << "."; - - // FULLY_USE_DEVICES - auto fully_devices = CostModelContext::GetInstance()->fully_use_device(); - FULLY_USE_DEVICES = fully_devices; - if (FULLY_USE_DEVICES) { - MS_LOG(INFO) << "fully_use_devices: true."; - } else { - MS_LOG(INFO) << "fully_use_devices: false."; - } - - // ELEMENTWISE_OP_STRA_FOLLOW - auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); - ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow; - if (ELEMENTWISE_OP_STRA_FOLLOW) { - MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; - } else { - MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; - } - - // MULTI_SUBGRAPHS - auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs(); - MULTI_SUBGRAPHS = multi_subgraphs; - if (MULTI_SUBGRAPHS) { - MS_LOG(INFO) << "multi_subgraphs: true."; - } else { - MS_LOG(INFO) << "multi_subgraphs: false."; - } - - // RUN_PHASE - auto phase = CostModelContext::GetInstance()->run_phase(); - if (phase != 0 && phase != 1) { - MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; - } - RUN_PHASE = phase; - MS_LOG(INFO) << "run_phase: " << RUN_PHASE << "."; -} - -void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { - for (auto it = ops_.begin(); it != ops_.end();) { - if ((*it) == op) { - it = ops_.erase(it); - } else { - ++it; - } - } -} - -bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) { - struct IsInGraph { - const OperatorInfoPtr test_; - explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {} - bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); } - }; - return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test)); -} - -void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) { - std::vector curr_edges(edges_[{u_node, v_node}]); - curr_edges.push_back(edge); - edges_[{u_node, v_node}] = curr_edges; - - std::vector curr_out_edges(out_edges_[u_node]); - curr_out_edges.push_back(edge); - out_edges_[u_node] = curr_out_edges; - - std::vector curr_in_edges(in_edges_[v_node]); - curr_in_edges.push_back(edge); - in_edges_[v_node] = curr_in_edges; -} - -bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) { - for (auto &edge_pair : edges_) { - auto edges = edge_pair.second; - for (auto &edge : edges) { - MS_EXCEPTION_IF_NULL(edge); - bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) && - (edge->next_op_input_index() == input_index); - if (bool_result) { - return true; - } - } - } - return false; -} - -std::vector> CostGraph::ConstructConnectedComponents( - std::vector alive_ops) { - std::map visited; - - for (auto &op : alive_ops) { - visited[op] = false; - } - - MS_LOG(INFO) << "visited: " << visited.size() << "."; - for (auto &op : alive_ops) { - if ((!visited[op]) && op->is_alive()) { - std::shared_ptr new_component = std::make_shared(); - MS_EXCEPTION_IF_NULL(new_component); - new_component->SetDeviceMemoryAndCostParameter(); - DFS(op, &visited, new_component); - connected_compoents_.push_back(new_component); - } - } - return connected_compoents_; -} - -void CostGraph::DFS(const OperatorInfoPtr ¤t_op, std::map *visited, - const std::shared_ptr &component) { - MS_EXCEPTION_IF_NULL(visited); - MS_EXCEPTION_IF_NULL(component); - visited->at(current_op) = true; - component->AddOperator(current_op); - - for (auto &edge : current_op->succ_edges()) { - bool bool_test = (visited->find(edge->next_operator()) != visited->end()) && - (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive(); - if (bool_test) { - component->AddEdge(current_op, edge->next_operator(), edge); - DFS(edge->next_operator(), visited, component); - } - } - - for (auto &edge : current_op->prev_edges()) { - bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) && - (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive(); - if (bool_test) { - component->AddEdge(edge->prev_operator(), current_op, edge); - DFS(edge->prev_operator(), visited, component); - } - } -} - -// Create final cost list for the graph: u --> v -CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr &e, - const OperatorInfoPtr &v) { - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(v); - MS_EXCEPTION_IF_NULL(e); - CostPtrList ret; - for (const auto &u_strategy : u->GetStrategyCost()) { - for (const auto &v_strategy : v->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(u_strategy); - MS_EXCEPTION_IF_NULL(v_strategy); - auto u_strategy_ptr = u_strategy->strategy_ptr; - auto v_strategy_ptr = v_strategy->strategy_ptr; - CostPtrList clist1 = u_strategy->cost_list; - CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr); - CostPtrList clist3 = v_strategy->cost_list; - for (const auto &cost1 : clist1) { - for (const auto &cost2 : clist2) { - for (const auto &cost3 : clist3) { - MS_EXCEPTION_IF_NULL(cost1); - MS_EXCEPTION_IF_NULL(cost2); - MS_EXCEPTION_IF_NULL(cost3); - double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_; - double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_; - double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_; - double communication_forward = - cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_; - double communication_without_para = cost1->communication_without_parameter_ + - cost2->communication_without_parameter_ + - cost3->communication_without_parameter_; - auto decision = - std::make_shared(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3); - auto cost = std::make_shared(computation, communication, decision); - MS_EXCEPTION_IF_NULL(cost); - cost->communication_without_parameter_ = communication_without_para; - cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - cost->memory_with_reuse_ = memory; - cost->communication_forward_ = communication_forward; - ret.push_back(cost); - } - } - } - } - } - - Simplify(&ret); - return ret; -} - -// Create final cost list for the graph containing a signle node: u -CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { - MS_EXCEPTION_IF_NULL(u); - CostPtrList ret; - for (const auto &u_strategy : u->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(u_strategy); - auto u_strategy_ptr = u_strategy->strategy_ptr; - CostPtrList clist1 = u_strategy->cost_list; - for (const auto &cost1 : clist1) { - MS_EXCEPTION_IF_NULL(cost1); - auto decision = std::make_shared(u_strategy_ptr, cost1); - auto new_cost = std::make_shared(cost1->computation_cost_, cost1->communication_cost_, decision); - MS_EXCEPTION_IF_NULL(new_cost); - new_cost->communication_without_parameter_ = cost1->communication_without_parameter_; - new_cost->communication_with_partial_para_ = - cost1->communication_without_parameter_ + - COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_); - new_cost->memory_with_reuse_ = cost1->memory_with_reuse_; - new_cost->communication_forward_ = cost1->communication_forward_; - ret.push_back(new_cost); - } - } - - Simplify(&ret); - return ret; -} - -CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) { - // Select the cost with minimum inference time. Currently, the inference time is modeled as = - // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_ - if (cost_list.empty()) { - MS_LOG(ERROR) << "Final cost list is null."; - return nullptr; - } - CostPtrList after_mem_filter; - double minimum_memory = DBL_MAX; - // Filter out the valid costs. - for (auto &a_cost : cost_list) { - if (a_cost->memory_with_reuse_ <= memory) { - after_mem_filter.emplace_back(std::move(a_cost)); - } else if (a_cost->memory_with_reuse_ < minimum_memory) { - minimum_memory = a_cost->memory_with_reuse_; - } - } - if (after_mem_filter.empty()) { - MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory - << ", the memory capacity is: " << memory << "."; - return nullptr; - } - // Init the returned value with first cost. - CostPtr ret = after_mem_filter[0]; - - double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_; - MS_LOG(INFO) << "Cost 0: " - << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ - << ", communication_forward_: " << ret->communication_forward_ - << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ - << ", communication_cost_: " << ret->communication_cost_ - << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; - MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; - for (size_t i = 1; i < after_mem_filter.size(); ++i) { - MS_EXCEPTION_IF_NULL(after_mem_filter[i]); - MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ - << ", computation_cost_: " << after_mem_filter[i]->computation_cost_ - << ", communication_forward_: " << after_mem_filter[i]->communication_forward_ - << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_ - << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ - << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ - << "."; - auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + - costmodel_beta_ * after_mem_filter[i]->communication_forward_; - MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; - if (minimum > tmp) { - minimum = tmp; - ret = after_mem_filter[i]; - MS_LOG(INFO) << "Selected: " << i; - } - } - return ret; -} - -CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) { - // Select the cost with minimum training time. Currently, the training time is modeled as = - // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_ - if (cost_list.empty()) { - MS_LOG(ERROR) << "Final cost list is null."; - return nullptr; - } - CostPtrList after_mem_filter; - double minimum_memory = DBL_MAX; - // Filter out the valid costs. - for (auto &a_cost : cost_list) { - if (a_cost->memory_with_reuse_ <= memory) { - after_mem_filter.emplace_back(std::move(a_cost)); - } else if (a_cost->memory_with_reuse_ < minimum_memory) { - minimum_memory = a_cost->memory_with_reuse_; - } - } - if (after_mem_filter.empty()) { - MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory - << ", the memory capacity is: " << memory << "."; - return nullptr; - } - // Init the returned value with first cost. - CostPtr ret = after_mem_filter[0]; - - double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_; - MS_LOG(INFO) << "Cost 0: " - << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ - << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ - << ", communication_cost_: " << ret->communication_cost_ - << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; - MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; - for (size_t i = 1; i < after_mem_filter.size(); ++i) { - MS_EXCEPTION_IF_NULL(after_mem_filter[i]); - MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ - << ", computation_cost_: " << after_mem_filter[i]->computation_cost_ - << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_ - << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ - << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ - << "."; - auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + - costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_; - MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; - if (minimum > tmp) { - minimum = tmp; - ret = after_mem_filter[i]; - MS_LOG(INFO) << "Selected: " << i; - } - } - return ret; -} - -CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_cost_list, - double available_memory) { - CostPtrList selected_cost_list(all_cost_list.size(), nullptr); - double minimum = DBL_MAX, total_memory = 0.0; - CostPtrList ret(all_cost_list.size(), nullptr); - // Check whether valid costs exist. - for (size_t i = 0; i < all_cost_list.size(); ++i) { - if (all_cost_list[i][0] == nullptr) { - MS_LOG(ERROR) << "The cost list " << i << " is empty."; - return ret; - } else { - double memory_i_cost = DBL_MAX; - for (size_t j = 0; j < all_cost_list[i].size(); ++j) { - if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) { - memory_i_cost = all_cost_list[i][j]->memory_with_reuse_; - } - } - total_memory += memory_i_cost; - } - } - if (total_memory >= available_memory) { - MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory - << ", minimum strategy cost: " << total_memory << "."; - return selected_cost_list; - } - - std::function recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive, - &available_memory, this](size_t k) { - if (k == all_cost_list.size()) { - double tmp_memory = 0.0, tmp_minimum = 0.0; - for (size_t i = 0; i < selected_cost_list.size(); ++i) { - MS_EXCEPTION_IF_NULL(selected_cost_list[i]); - tmp_memory += selected_cost_list[i]->memory_with_reuse_; - tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ + - costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_; - } - MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum - << "."; - if (tmp_memory < available_memory && tmp_minimum < minimum) { - ret = selected_cost_list; - minimum = tmp_minimum; - MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << "."; - } - return; - } - - MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << "."; - for (auto &c : all_cost_list[k]) { - selected_cost_list[k] = c; - recursive(k + 1); - } - }; - recursive(0); - return ret; -} - -Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector &alive_ops) { - MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph."; - auto connected_components = ConstructConnectedComponents(alive_ops); - MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; - std::vector all_list; - for (size_t j = 0; j < connected_components.size(); ++j) { - auto one_component = connected_components[j]; - MS_EXCEPTION_IF_NULL(one_component); - if (one_component->GetOperators().size() == 1) { - MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; - auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); - all_list.push_back(cost_list); - } else if (one_component->GetOperators().size() == 2) { - MS_LOG(INFO) << "There are 2 operators in a component in the final graph."; - OperatorInfoPtr u, v; - auto first_op = one_component->GetOperators()[0]; - auto second_op = one_component->GetOperators()[1]; - MS_EXCEPTION_IF_NULL(first_op); - MS_EXCEPTION_IF_NULL(second_op); - if (!first_op->GetAliveSuccEdges().empty() && - first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) { - u = first_op; - v = second_op; - } else if (!second_op->GetAliveSuccEdges().empty() && - second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) { - u = second_op; - v = first_op; - } else { - MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size() - << ", " << second_op->GetAliveSuccEdges().size() << "."; - } - MS_EXCEPTION_IF_NULL(u); - auto e = u->GetAliveSuccEdges()[0]; - auto cost_list = one_component->CreateFinalCostList(u, e, v); - all_list.push_back(cost_list); - } else { - MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size() - << " operators in a component in the final graph."; - } - } - // - auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); - for (size_t k = 0; k < selected_cost_list.size(); ++k) { - auto selected_cost = selected_cost_list[k]; - if (selected_cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(connected_components[k]); - if (connected_components[k]->GetOperators().size() == 1) { - auto u = connected_components[k]->GetOperators()[0]; - auto decision = selected_cost->decision_ptr_->cast(); - u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); - MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; - } else if (connected_components[k]->GetOperators().size() == 2) { - OperatorInfoPtr u = nullptr, v = nullptr; - auto first_op = connected_components[k]->GetOperators()[0]; - auto second_op = connected_components[k]->GetOperators()[1]; - MS_EXCEPTION_IF_NULL(first_op); - MS_EXCEPTION_IF_NULL(second_op); - if (!first_op->GetAliveSuccEdges().empty() && - first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) { - u = first_op; - v = second_op; - } else if (!second_op->GetAliveSuccEdges().empty() && - second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) { - u = second_op; - v = first_op; - } - MS_EXCEPTION_IF_NULL(u); - auto e = u->GetAliveSuccEdges()[0]; - MS_EXCEPTION_IF_NULL(v); - MS_EXCEPTION_IF_NULL(e); - MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); - auto decision = selected_cost->decision_ptr_->cast(); - MS_EXCEPTION_IF_NULL(decision); - u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); - v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); - e->set_selected_cost(decision->middle_cost_); - MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; - } - } - return SUCCESS; -} - -// searching the strategy for the final eliminated graph -Status CostGraph::SearchStrategy() { - MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began."; - std::vector alive_ops; - (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) { - MS_EXCEPTION_IF_NULL(op); - if (op->is_alive()) { - alive_ops.push_back(op); - } - }); - - if (alive_ops.size() > 2) { - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - return SearchStrategyForMultiNodeFinalGraph(alive_ops); - } else { - // inference phase - MS_LOG(EXCEPTION) - << "Currently, searching strategy for the multi-node final graph in inference phase is not supported."; - } - } else if (alive_ops.size() == 1) { - MS_LOG(INFO) << "There are 1 single node in the final graph."; - OperatorInfoPtr u = alive_ops[0]; - auto cost_list = CreateFinalSingleCostList(u); - CostPtr cost = nullptr; - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); - } else { - // inference phase - cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_); - } - if (cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(cost->decision_ptr_); - auto decision = cost->decision_ptr_->cast(); - MS_EXCEPTION_IF_NULL(decision); - u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); - MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; - return SUCCESS; - } else { - // In this case, the final graph should contains exactly 2 nodes. - if (alive_ops.empty()) { - MS_LOG(INFO) << "0 Operator in the final graph."; - return SUCCESS; - } - OperatorInfoPtr u, v; - MS_EXCEPTION_IF_NULL(alive_ops[0]); - MS_EXCEPTION_IF_NULL(alive_ops[1]); - if (!alive_ops[0]->GetAliveSuccEdges().empty() && - alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) { - u = alive_ops[0]; - v = alive_ops[1]; - } else if (!alive_ops[1]->GetAliveSuccEdges().empty() && - alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) { - u = alive_ops[1]; - v = alive_ops[0]; - } else { - if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) { - MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size() - << ", " << alive_ops[1]->GetAliveSuccEdges().size() << "."; - } else { - // In this case, the final graph consists of two single nodes - MS_LOG(INFO) << "There are 2 single nodes in the final graph."; - std::vector all_list; - auto connected_components = ConstructConnectedComponents(alive_ops); - MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; - for (size_t i = 0; i < connected_components.size(); ++i) { - MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; - auto one_component = connected_components[i]; - MS_EXCEPTION_IF_NULL(one_component); - auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); - all_list.push_back(cost_list); - } - CostPtrList selected_cost_list; - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); - } else { - // inference phase - MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference " - "phase is not supported."; - } - for (size_t k = 0; k < selected_cost_list.size(); ++k) { - auto selected_cost = selected_cost_list[k]; - if (selected_cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(connected_components[k]); - auto one_operator = connected_components[k]->GetOperators()[0]; - MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); - auto decision = selected_cost->decision_ptr_->cast(); - MS_EXCEPTION_IF_NULL(decision); - one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); - MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; - } - - return SUCCESS; - } - } - MS_LOG(INFO) << "There are 2 nodes in the final graph."; - // In this case, the finale graph is exactly of the form: u --> v - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(v); - auto e = u->GetAliveSuccEdges()[0]; - MS_EXCEPTION_IF_NULL(e); - auto cost_list = CreateFinalCostList(u, e, v); - CostPtr cost = nullptr; - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); - } else { - MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference " - "phase is not supported."; - } - if (cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(cost->decision_ptr_); - auto decision = cost->decision_ptr_->cast(); - MS_EXCEPTION_IF_NULL(decision); - u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); - v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); - e->set_selected_cost(decision->middle_cost_); - MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; - return SUCCESS; - } -} - -// Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated -// return the v and the edge u --> v -OperatorInfoPtr CostGraph::CheckOpElimination() const { - for (auto &op : ops_) { - bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1; - if (bool_test) { - if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) { - return op; - } - } - } - return nullptr; -} - -// Check the graph whether an EdgeElimination can be performed -std::vector> CostGraph::CheckEdgeElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - if (!op->is_alive()) continue; - std::map count; - for (auto &edge : op->GetAliveSuccEdges()) { - MS_EXCEPTION_IF_NULL(edge); - auto v = edge->next_operator(); - count[v.get()]++; - } - for (auto &pair : count) { - auto *op_ptr = pair.first; - int op_count = pair.second; - if (op_count > 1) { - std::vector> ret; - for (auto &edge : op->GetAliveSuccEdges()) { - MS_EXCEPTION_IF_NULL(edge); - if (edge->next_operator().get() == op_ptr) { - ret.push_back(edge); - } - } - return ret; - } - } - } - return {}; -} - -// Check the graph whether a MergeElimination can be performed -OperatorInfoPtr CostGraph::CheckMergeElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1; - if (bool_test) { - auto next_op = op->GetAliveSuccEdges()[0]->next_operator(); - MS_EXCEPTION_IF_NULL(next_op); - if (!next_op->GetAlivePrevEdges().empty()) { - return op; - } - } - } - return nullptr; -} - -// Check the graph whether a ContractElimination can be performed -OperatorInfoPtr CostGraph::CheckContractElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty(); - if (bool_test) { - auto edge = op->GetAlivePrevEdges()[0]; - MS_EXCEPTION_IF_NULL(edge); - auto prev_op = edge->prev_operator(); - MS_EXCEPTION_IF_NULL(prev_op); - if (!prev_op->GetAliveSuccEdges().empty()) { - return op; - } - } - } - return nullptr; -} - -// Check the graph whether a TriangleElimination can be performed -std::pair> CostGraph::CheckTriangleElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2); - if (bool_test) { - auto edge1 = op->GetAliveSuccEdges()[0]; - auto edge2 = op->GetAliveSuccEdges()[1]; - MS_EXCEPTION_IF_NULL(edge1); - MS_EXCEPTION_IF_NULL(edge2); - auto first_op = edge1->next_operator(); - auto second_op = edge2->next_operator(); - MS_EXCEPTION_IF_NULL(first_op); - for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) { - if (first_op_succ_edge->next_operator() == second_op) { - return {op, first_op_succ_edge}; - } - } - MS_EXCEPTION_IF_NULL(second_op); - for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) { - if (second_op_succ_edge->next_operator() == first_op) { - return {op, second_op_succ_edge}; - } - } - } - } - return {nullptr, nullptr}; -} - -// Check the graph whether a StarElimination can be performed. -// NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. -OperatorInfoPtr CostGraph::CheckStarElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1); - if (bool_test) { - return op; - } - } - return nullptr; -} - -// This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace -// 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge. -std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr &op) { - // in this case, the operators are organised in the form of u-->op-->v, and the goal - // is to eliminate 'op'. - MS_EXCEPTION_IF_NULL(op); - MS_LOG(INFO) << "Now eliminating node: " << op->name() << "."; - auto edge_u_op = op->GetAlivePrevEdges()[0]; - auto edge_op_v = op->GetAliveSuccEdges()[0]; - MS_EXCEPTION_IF_NULL(edge_u_op); - MS_EXCEPTION_IF_NULL(edge_op_v); - auto u = edge_u_op->prev_operator(); - auto v = edge_op_v->next_operator(); - std::vector output_indexs, input_indexs; - size_t output_index, input_index; - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(v); - std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); - std::shared_ptr new_edge; - if (edge_u_op->is_combined()) { - output_indexs = edge_u_op->prev_op_output_indexs(); - } else { - output_index = edge_u_op->prev_op_output_index(); - output_indexs.push_back(output_index); - } - if (edge_op_v->is_combined()) { - input_indexs = edge_op_v->next_op_input_indexs(); - } else { - input_index = edge_op_v->next_op_input_index(); - input_indexs.push_back(input_index); - } - - if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) { - new_edge = std::make_shared(new_edge_name, u, v, output_index, input_index, false); - } else { - new_edge = std::make_shared(new_edge_name, u, v, output_indexs, input_indexs, true); - } - MS_EXCEPTION_IF_NULL(new_edge); - new_edge->set_pre_op_output(edge_u_op->prev_op_output()); - new_edge->set_next_op_input(edge_op_v->next_op_input()); - new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v); - u->ReplaceSuccEdge(op, new_edge); - v->ReplacePreEdge(op, new_edge); - op->SetNotAlive(); - MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded."; - return new_edge; -} - -// This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges', -// and sets new costlist for the new edge. -std::shared_ptr CostGraph::EliminationEdges(const std::vector> &edges) { - MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges."; - MS_EXCEPTION_IF_NULL(edges[0]); - auto u = edges[0]->prev_operator(); - auto v = edges[0]->next_operator(); - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(v); - std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); - std::vector output_indexs, input_indexs; - - for (auto &edge : edges) { - MS_EXCEPTION_IF_NULL(edge); - if (edge->is_combined()) { - auto from_output_indexs = edge->prev_op_output_indexs(); - auto from_input_indexs = edge->next_op_input_indexs(); - (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs)); - (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs)); - } else { - output_indexs.push_back(edge->prev_op_output_index()); - input_indexs.push_back(edge->next_op_input_index()); - } - } - - std::shared_ptr new_edge = std::make_shared(new_edge_name, u, v, output_indexs, input_indexs, true); - MS_EXCEPTION_IF_NULL(new_edge); - new_edge->set_pre_op_output(edges[0]->prev_op_output()); - new_edge->set_next_op_input(edges[0]->next_op_input()); - - new_edge->EdgeEliminationSetNewCost(u, edges, v); - - u->ReplaceSuccEdges(v, new_edge); - v->ReplacePreEdges(u, new_edge); - MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded."; - return new_edge; -} - -// Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' -// for this contract under the strategy 'op_strategy' -void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, - const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, - const CostPtrList &tar_cost_list, - CostPtrList *const tar_cost_list_new) { - for (size_t i = 0; i < op_cost_list.size(); ++i) { - auto &op_cost = op_cost_list[i]; - MS_EXCEPTION_IF_NULL(op_cost); - for (size_t j = 0; j < edge_cost_list.size(); ++j) { - auto &edge_cost = edge_cost_list[j]; - MS_EXCEPTION_IF_NULL(edge_cost); - for (size_t k = 0; k < tar_cost_list.size(); ++k) { - auto &tar_cost = tar_cost_list[k]; - MS_EXCEPTION_IF_NULL(tar_cost); - double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; - double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; - double communication = - op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; - double communication_forward = - op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_; - double communication_without_para = op_cost->communication_without_parameter_ + - edge_cost->communication_without_parameter_ + - tar_cost->communication_without_parameter_; - - auto decision = - std::make_shared(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); - auto new_cost = std::make_shared(computation, communication, decision); - MS_EXCEPTION_IF_NULL(new_cost); - new_cost->communication_without_parameter_ = communication_without_para; - new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - new_cost->memory_with_reuse_ = memory; - new_cost->communication_forward_ = communication_forward; - MS_EXCEPTION_IF_NULL(tar_cost_list_new); - tar_cost_list_new->emplace_back(std::move(new_cost)); - } - } - } -} - -// This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the -// target_op -OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { - MS_EXCEPTION_IF_NULL(op); - auto target_op = op->GetAliveSuccEdges()[0]->next_operator(); - auto edge_ptr = op->GetAliveSuccEdges()[0]; - MS_EXCEPTION_IF_NULL(target_op); - MS_EXCEPTION_IF_NULL(edge_ptr); - MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << "."; - bool valid = false; - - for (auto &tar_stra_cost : target_op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(tar_stra_cost); - auto tar_stra = tar_stra_cost->strategy_ptr; - auto tar_clist_origin = tar_stra_cost->cost_list; - CostPtrList tar_clist_new; - - for (auto &op_stra_cost : op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(op_stra_cost); - auto op_stra = op_stra_cost->strategy_ptr; - auto op_clist = op_stra_cost->cost_list; - auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra); - - CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); - } - Simplify(&tar_clist_new); - // Set the new costlist w.r.t the strategy - tar_stra_cost->cost_list = tar_clist_new; - if ((!valid) && (!tar_clist_new.empty())) { - valid = true; - } - } - - if (!valid) { - MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed."; - } - op->SetNotAlive(); - MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded."; - return target_op; -} - -// Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' -// for this contract under the strategy 'contract_op_stra' -void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra, - const CostPtrList &contract_op_cost_list, - const CostPtrList &edge_cost_list, StrategyPtr target_op_stra, - const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) { - for (size_t i = 0; i < contract_op_cost_list.size(); ++i) { - auto &contract_op_cost = contract_op_cost_list[i]; - MS_EXCEPTION_IF_NULL(contract_op_cost); - for (size_t j = 0; j < edge_cost_list.size(); ++j) { - auto &edge_cost = edge_cost_list[j]; - MS_EXCEPTION_IF_NULL(edge_cost); - for (size_t k = 0; k < tar_cost_list.size(); ++k) { - auto &tar_cost = tar_cost_list[k]; - MS_EXCEPTION_IF_NULL(tar_cost); - double computation = - contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; - double memory = - contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; - double communication = - contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; - double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ + - tar_cost->communication_forward_; - double communication_without_para = contract_op_cost->communication_without_parameter_ + - edge_cost->communication_without_parameter_ + - tar_cost->communication_without_parameter_; - - auto decision = std::make_shared(contract_op_stra, contract_op_cost, edge_cost, - target_op_stra, tar_cost); - auto new_cost = std::make_shared(computation, communication, decision); - new_cost->communication_without_parameter_ = communication_without_para; - new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - new_cost->memory_with_reuse_ = memory; - new_cost->communication_forward_ = communication_forward; - tar_cost_list_new->emplace_back(std::move(new_cost)); - } - } - } -} - -// This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the -// target_op -OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { - MS_EXCEPTION_IF_NULL(op); - auto target_op = op->GetAlivePrevEdges()[0]->prev_operator(); - auto edge_ptr = op->GetAlivePrevEdges()[0]; - MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << "."; - bool valid = false; - - for (auto &tar_stra_cost : target_op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(tar_stra_cost); - auto tar_stra = tar_stra_cost->strategy_ptr; - auto tar_clist_origin = tar_stra_cost->cost_list; - CostPtrList tar_clist_new; - - for (auto &op_stra_cost : op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(op_stra_cost); - auto op_stra = op_stra_cost->strategy_ptr; - auto op_clist = op_stra_cost->cost_list; - auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra); - - CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); - } - Simplify(&tar_clist_new); - // Set the new costlist w.r.t the strategy - tar_stra_cost->cost_list = tar_clist_new; - if ((!valid) && (!tar_clist_new.empty())) { - valid = true; - } - } - if (!valid) { - MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed."; - } - op->SetNotAlive(); - MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded."; - return target_op; -} - -void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, - StrategyPtr right_op_stra, const CostPtr &right_op_cost, - const CostPtrList &elimi_op_clist, - const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost, - const CostPtrList &left_node_clist_origin, - CostPtrList *left_node_clist_new) { - MS_EXCEPTION_IF_NULL(right_edge_cost); - MS_EXCEPTION_IF_NULL(right_op_cost); - MS_EXCEPTION_IF_NULL(left_node_clist_new); - for (auto &elimi_op_cost : elimi_op_clist) { - MS_EXCEPTION_IF_NULL(elimi_op_cost); - for (auto &left_edge_cost : left_edge_clist) { - MS_EXCEPTION_IF_NULL(left_edge_cost); - for (auto &left_node_cost : left_node_clist_origin) { - MS_EXCEPTION_IF_NULL(left_node_cost); - double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + - left_node_cost->computation_cost_ + right_edge_cost->computation_cost_; - double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ + - left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_; - double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ + - left_node_cost->communication_cost_ + right_edge_cost->communication_cost_; - double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ + - left_node_cost->communication_forward_ + right_edge_cost->communication_forward_; - double new_commu_without = - elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + - left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; - - auto decision = std::make_shared( - elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra); - auto new_cost = std::make_shared(new_computation, new_commu_cost, decision); - new_cost->communication_without_parameter_ = new_commu_without; - new_cost->communication_with_partial_para_ = - new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); - new_cost->memory_with_reuse_ = new_memory; - new_cost->communication_forward_ = new_commu_forward; - left_node_clist_new->emplace_back(std::move(new_cost)); - } - } - } -} - -void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist, - const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra, - const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra, - const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist, - const CostPtrList &left_node_clist_origin, - CostPtrList *left_node_clist_new) { - MS_EXCEPTION_IF_NULL(elimi_op); - for (auto &right_node_cost : right_node_clist) { - MS_EXCEPTION_IF_NULL(right_node_cost); - for (auto &right_edge_cost : right_edge_clist) { - MS_EXCEPTION_IF_NULL(right_edge_cost); - CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, - elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, - left_node_clist_new); - } - } -} - -OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, - const std::shared_ptr &edge_left_right) { - MS_EXCEPTION_IF_NULL(edge_left_right); - MS_EXCEPTION_IF_NULL(elimi_op); - MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << "."; - auto left_node = edge_left_right->prev_operator(); - auto right_node = edge_left_right->next_operator(); - auto left_edge = elimi_op->GetAliveSuccEdges()[0]; - auto right_edge = elimi_op->GetAliveSuccEdges()[1]; - MS_EXCEPTION_IF_NULL(left_node); - MS_EXCEPTION_IF_NULL(right_node); - MS_EXCEPTION_IF_NULL(left_edge); - MS_EXCEPTION_IF_NULL(right_edge); - MS_LOG(INFO) << "The left operator is: " << left_node->name() << "."; - MS_LOG(INFO) << "The right operator is: " << right_node->name() << "."; - - if (left_edge->next_operator() != left_node) { - auto tmp = left_edge; - left_edge = right_edge; - right_edge = tmp; - } - bool valid = false; - - for (auto &left_node_stra_cost : left_node->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(left_node_stra_cost); - auto left_node_stra = left_node_stra_cost->strategy_ptr; - auto left_node_clist_origin = left_node_stra_cost->cost_list; - CostPtrList left_node_clist_new; - - for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(elimi_op_stra_cost); - auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr; - auto elimi_op_clist = elimi_op_stra_cost->cost_list; - auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra); - - for (auto &right_node_stra_cost : right_node->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(right_node_stra_cost); - auto right_node_stra = right_node_stra_cost->strategy_ptr; - auto right_node_clist = right_node_stra_cost->cost_list; - auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra); - - CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra, - right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin, - &left_node_clist_new); - } - } - Simplify(&left_node_clist_new); - // Set the new costlist w.r.t the strategy - left_node_stra_cost->cost_list = left_node_clist_new; - if ((!valid) && (!left_node_clist_new.empty())) { - valid = true; - } - } - - if (!valid) { - MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed."; - } - elimi_op->SetNotAlive(); - MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded."; - return left_node; -} - -void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra, - const CostPtrList &first_succ_node_clist, - const CostPtrList &first_succ_edge_clist, - const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, - std::vector succ_nodes_stras, - CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs, - CostPtrList *first_succ_node_clist_new) { - for (auto &first_succ_node_cost : first_succ_node_clist) { - for (auto &first_succ_edge_cost : first_succ_edge_clist) { - for (auto &merged_node_cost : merged_op_clist) { - MS_EXCEPTION_IF_NULL(merged_node_cost); - succ_nodes_stras[0] = first_succ_node_stra; - succ_edges_costs[0] = first_succ_edge_cost; - succ_nodes_costs[0] = first_succ_node_cost; - - double computation_cost = merged_node_cost->computation_cost_, - memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, - commu_without = merged_node_cost->communication_without_parameter_, - commu_forward = merged_node_cost->communication_forward_; - for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { - MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); - if (i == 0) { - computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; - memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_; - commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; - commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_; - commu_without += succ_edges_costs[i]->communication_without_parameter_ + - succ_nodes_costs[i]->communication_without_parameter_; - } else { - computation_cost += succ_edges_costs[i]->computation_cost_; - memory_cost += succ_edges_costs[i]->memory_with_reuse_; - commu_cost += succ_edges_costs[i]->communication_cost_; - commu_forward += succ_edges_costs[i]->communication_forward_; - commu_without += succ_edges_costs[i]->communication_without_parameter_; - } - } - - auto decision = std::make_shared(merged_op_stra, merged_node_cost, succ_edges_costs, - succ_nodes_stras, succ_nodes_costs); - auto new_cost = std::make_shared(computation_cost, commu_cost, decision); - new_cost->communication_without_parameter_ = commu_without; - new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); - new_cost->memory_with_reuse_ = memory_cost; - new_cost->communication_forward_ = commu_forward; - first_succ_node_clist_new->emplace_back(std::move(new_cost)); - } - } - } -} - -void CostGraph::CreateStarEliminationCostList(std::vector> &succ_edges, - const StrategyPtr &first_succ_node_stra, - const CostPtrList &first_succ_node_clist, - const CostPtrList &first_succ_edge_clist, - const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, - CostPtrList *first_succ_node_clist_new) { - std::vector succ_nodes_stras(succ_edges.size(), nullptr); - CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr); - std::function recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist, - &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs, - &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive, - this](size_t k) { - if (k == succ_edges.size()) { - CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, - merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs, - succ_nodes_costs, first_succ_node_clist_new); - return; - } - MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size() - << ", first_succ_edge_clist: " << first_succ_edge_clist.size() - << ", merged_op_clist: " << merged_op_clist.size() - << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << "."; - auto succ_edge = succ_edges[k]; - MS_EXCEPTION_IF_NULL(succ_edge); - auto succ_node = succ_edge->next_operator(); - MS_EXCEPTION_IF_NULL(succ_node); - for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(succ_node_stra_cost); - auto succ_node_stra = succ_node_stra_cost->strategy_ptr; - auto succ_node_clist = succ_node_stra_cost->cost_list; - auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra); - - for (auto &succ_node_cost : succ_node_clist) { - MS_EXCEPTION_IF_NULL(succ_node_cost); - for (auto &succ_edge_cost : succ_edge_clist) { - MS_EXCEPTION_IF_NULL(succ_edge_cost); - succ_nodes_stras[k] = succ_node_stra; - succ_edges_costs[k] = succ_edge_cost; - succ_nodes_costs[k] = succ_node_cost; - recursive(k + 1); - } - } - } - }; - - recursive(1); -} - -std::vector> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) { - MS_EXCEPTION_IF_NULL(merged_op); - auto succ_edges = merged_op->GetAliveSuccEdges(); - MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << "."; - for (auto &succ_edge : succ_edges) { - MS_EXCEPTION_IF_NULL(succ_edge->next_operator()); - MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << "."; - } - - MS_EXCEPTION_IF_NULL(succ_edges[0]); - auto first_succ_node = succ_edges[0]->next_operator(); - auto first_succ_edge = succ_edges[0]; - bool valid = false; - - // 'merged_op' is merged into first_node - MS_EXCEPTION_IF_NULL(first_succ_node); - for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost); - auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr; - auto first_succ_node_clist = first_succ_node_stra_cost->cost_list; - CostPtrList first_succ_node_clist_new; - - for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(merged_op_stra_cost); - auto merged_op_stra = merged_op_stra_cost->strategy_ptr; - auto merged_op_clist = merged_op_stra_cost->cost_list; - auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra); - - CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, - merged_op_stra, merged_op_clist, &first_succ_node_clist_new); - } - Simplify(&first_succ_node_clist_new); - // Set the new costlist w.r.t the strategy - first_succ_node_stra_cost->cost_list = first_succ_node_clist_new; - if ((!valid) && (!first_succ_node_clist_new.empty())) { - valid = true; - } - } - - if (!valid) { - MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed."; - } - - merged_op->SetNotAlive(); - MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded."; - return succ_edges; -} - -size_t CostGraph::GetNumEdges() const { - size_t sum = 0; - for (const auto &kv : edges_) { - auto &edges = kv.second; - sum += edges.size(); - } - return sum; -} -Status CostGraph::InitSelectedStrategy() { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - if (op->name().find(RESHAPEINFO) != std::string::npos) { - continue; - } - auto result = op->InitSelectedStrategy(op->selected_strategy()); - if (result != SUCCESS) { - return result; - } - } - // reshape init should be apply after the init of it's previous node and next node. - for (size_t i = 0; i < ops_.size(); ++i) { - if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) { - auto reshape_info = std::dynamic_pointer_cast(ops_[i]); - auto in_edges = GetOriginalPrevEdges(ops_[i]); - auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr edge) { - return edge->prev_operator()->name() == reshape_info->pre_operator_name(); - }); - auto out_edges = GetOriginalNextEdges(ops_[i]); - auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr edge) { - return edge->next_operator()->name() == reshape_info->next_operator_name(); - }); - if (pre_iter != in_edges.end()) { - MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name(); - int32_t pre_index = reshape_info->pre_operator_index(); - TensorInfo pre_info; - if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) { - pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index]; - } else { - pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index]; - } - reshape_info->SetInputLayout(pre_info.tensor_layout()); - Dimensions stra = pre_info.InferStrategy(); - if (stra.empty()) { - MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed"; - } - std::vector stra_inputs = {stra}; - StrategyPtr reshape_stra = - std::make_shared((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs); - reshape_info->set_strategy(reshape_stra); - } - if (next_iter != out_edges.end()) { - MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name(); - int32_t next_index = reshape_info->next_operator_index(); - reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout()); - } - if (reshape_info->Init(nullptr) != SUCCESS) { - return FAILED; - } - } - } - return SUCCESS; -} - -Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved(); - if ((output_parameter != 0) && (output_parameter != 1)) { - MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed."; - return FAILED; - } - } - return SUCCESS; -} - -void CostGraph::DFSForTopoOrder(const OperatorInfoPtr ¤t_op, std::map *visited, - std::vector *topo_order) { - MS_EXCEPTION_IF_NULL(current_op); - MS_EXCEPTION_IF_NULL(visited); - MS_EXCEPTION_IF_NULL(topo_order); - - visited->at(current_op) = true; - for (const auto &s_edge : current_op->succ_edges()) { - if (!visited->at(s_edge->next_operator())) { - DFSForTopoOrder(s_edge->next_operator(), visited, topo_order); - } - } - topo_order->push_back(current_op); -} - -// Compute a topological order of the costgraph -void CostGraph::TopologyOrder(std::vector *topo_order) { - std::map visited; - for (auto &op : ops_) { - visited[op] = false; - } - - for (auto &op : ops_) { - if (!visited[op]) { - DFSForTopoOrder(op, &visited, topo_order); - } - } -} -void CostGraph::MarkCriticalOpsAndEdges(const std::map &candidate_ops) { - for (auto &op : ops_) { - auto search = candidate_ops.find(op); - if (search != candidate_ops.end()) { - // Mark the critical operators - op->mark_output_critical(); - // Mark the successive edges - for (auto &s_edge : op->succ_edges()) { - s_edge->mark_output_critical(); - } - } else { - op->mark_output_not_critical(); - } - } -} - -Status CostGraph::DetermineCriticalOps(const std::vector &topo_order) { - if (topo_order.size() == 0) { - MS_LOG(ERROR) << "0 operator in costgraph."; - return FAILED; - } - auto &first_op = topo_order[0]; - if (first_op->prev_edges().size() > 0) { - MS_LOG(ERROR) << "The first operator in the first of topological order of " - "costgraph should have 0 incoming edge, but has " - << first_op->prev_edges() << "edges."; - return FAILED; - } - // The 'curr_memory_state' records , where remaining_output_cnt is the number - // of the output of OperatorInfo that currently has not been used - std::map curr_memory_state; - (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size()))); - std::map max_memory_state = curr_memory_state; - // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has - // not been used - double curr_memory_size = first_op->GetOutputsTotalSize(); - double max_memory_size = curr_memory_size; - - for (size_t finished = 1; finished < topo_order.size(); ++finished) { - // Produce - (void)curr_memory_state.emplace( - std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size()))); - curr_memory_size += topo_order[finished]->GetOutputsTotalSize(); - // Consume - for (const auto &prev_edge : topo_order[finished]->prev_edges()) { - const auto &prev_op = prev_edge->prev_operator(); - curr_memory_state[prev_op]--; - } - for (const auto &prev_edge : topo_order[finished]->prev_edges()) { - const auto &prev_op = prev_edge->prev_operator(); - if (curr_memory_state[prev_op] < 0) { - MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op]; - return FAILED; - } else if (curr_memory_state[prev_op] == 0) { - curr_memory_state.erase(prev_op); - curr_memory_size -= prev_op->GetOutputsTotalSize(); - } - } - - if (curr_memory_size < 0) { - MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size; - } - // Modify the max - if (curr_memory_size > max_memory_size) { - max_memory_size = curr_memory_size; - max_memory_state = curr_memory_state; - } - } - // Mark those critical operators - MarkCriticalOpsAndEdges(max_memory_state); - return SUCCESS; -} - -Status CostGraph::ComputeOpsAndEdgesOutputCritical() { - // Two steps to do: - // 1. Compute a topological order of the costgraph - // 2. Determine and mark the operators (and necessary edges) that are critical - std::vector topo_order; - TopologyOrder(&topo_order); - std::reverse(std::begin(topo_order), std::end(topo_order)); - - if (DetermineCriticalOps(topo_order) != SUCCESS) { - MS_LOG(ERROR) << "Determining critical operators failed."; - return FAILED; - } - return SUCCESS; -} - -Status CostGraph::CalculateOpsMemoryCost() { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - if (op->CalculateMemoryCost() != SUCCESS) { - MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; - return FAILED; - } - } - return SUCCESS; -} - -Status CostGraph::CalculateOpsMemoryCostForInference() { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - if (op->CalculateMemoryCostForInference() != SUCCESS) { - MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; - return FAILED; - } - } - return SUCCESS; -} - -Status CostGraph::CalculateEdgesMemoryCost() { - for (auto &edge_pair : edges_) { - const auto &edges = edge_pair.second; - for (auto &one_edge : edges) { - if (one_edge->CalculateMemoryCost() != SUCCESS) { - MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; - return FAILED; - } - } - } - return SUCCESS; -} - -Status CostGraph::CalculateEdgesMemoryCostForInference() { - for (auto &edge_pair : edges_) { - const auto &edges = edge_pair.second; - for (auto &one_edge : edges) { - if (one_edge->CalculateMemoryCostForInference() != SUCCESS) { - MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; - return FAILED; - } - } - } - return SUCCESS; -} - -OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { - for (auto one_op : ops_) { - if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { - if (one_op->refkey_parameter_name() == p_name) { - return one_op; - } - } - } - return nullptr; -} -Status CostGraph::CorrectOpsMemoryCost() { - for (auto &one_op : ops_) { - if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) { - if (one_op->GetAliveSuccEdges().size() > 1) { - // Filter out the case when the TmpIdentity being used by multiple operators - std::map output_count; - for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) { - auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index(); - output_count[output_index]++; - } - for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) { - auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index(); - if (output_count[output_index] <= 1) { - continue; - } - auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator(); - MS_EXCEPTION_IF_NULL(next_op); - auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index(); - if (next_op->CorrectMemoryCost(input_index) != SUCCESS) { - MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name() - << ", the output_index: " << output_index << ", the input_index: " << input_index << "."; - return FAILED; - } - output_count[output_index]--; - } - } - } - } - return SUCCESS; -} - -Status CostGraph::CalculateMemoryCost() { - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { - // Calculate operators' memory usage - if (CalculateOpsMemoryCost() != SUCCESS) { - MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed."; - return FAILED; - } - // Calculate edges' memory usage - if (CalculateEdgesMemoryCost() != SUCCESS) { - MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed."; - return FAILED; - } - // Correct memory usage caused by TmpIdentity - if (CorrectOpsMemoryCost() != SUCCESS) { - MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed."; - return FAILED; - } - } else { - MS_LOG(ERROR) << "Computing operators' parameter_involved failed."; - return FAILED; - } - } else { - // inference phase - if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) { - // Calculate operators' memory usage - if (CalculateOpsMemoryCostForInference() != SUCCESS) { - MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; - return FAILED; - } - // Calculate edges's memory usage - if (CalculateEdgesMemoryCostForInference() != SUCCESS) { - MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed."; - return FAILED; - } - } else { - MS_LOG(ERROR) << "Computing operators' critical flag failed."; - return FAILED; - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h deleted file mode 100644 index 3b8b389d81..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ /dev/null @@ -1,238 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ - -#include -#include -#include -#include -#include -#include "../../common.h" -#include "common/utils.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/costmodel_context.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/ops_info/tmp_identity_info.h" - -namespace mindspore { -namespace parallel { -#define OPERATOR_TO_OPERATOR_CONNECTOR "-" -#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) -#define DEFAULT_COST_MODEL_ALPHA 1.0 -#define DEFAULT_COST_MODEL_BETA 400.0 -#define DEFAULT_COST_MODEL_GAMMA 0.001 -#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true -#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 -#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0 -#define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0 -#define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false -#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 -#define DEFAULT_FULLY_USE_DEVICES true -#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false -#define DEFAULT_IS_MULTI_SUBGRAPHS false -#define DEFAULT_RUN_PHASE 0 -#define TRAINING_PHASE 0 -#define INFERENCE_PHASE 1 - -class CostGraph; -using CostGraphPtr = std::shared_ptr; -extern CostGraphPtr entire_costgraph; -extern size_t TOTAL_OPS; -extern double COST_MODEL_GAMMA; -extern bool COST_MODEL_SIMPLIFY_CALCULATION; -extern double DEVICE_MEMORY_CAPACITY; -extern double COST_MODEL_COMMUNI_THRESHOLD; -extern double COST_MODEL_COMMUNI_CONST; -extern double COST_MODEL_COMMUNI_BIAS; -extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; -extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; -extern bool FULLY_USE_DEVICES; -extern bool ELEMENTWISE_OP_STRA_FOLLOW; -extern bool MULTI_SUBGRAPHS; -extern int32_t RUN_PHASE; - -class CostGraph { - // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have - // output-input dependency relationship. - public: - CostGraph() { - dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY; - costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; - costmodel_beta_ = DEFAULT_COST_MODEL_BETA; - } - ~CostGraph() = default; - void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } - OperatorInfoPtr FindOperatorByIndex(size_t index) { - if (index >= ops_.size()) { - MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << "."; - return nullptr; - } - return ops_[index]; - } - void RemoveOperator(const OperatorInfoPtr &op); - bool IsOperatorInCostGraph(const OperatorInfoPtr &op); - // the edge is in the form: u --> v - void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge); - std::vector> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; } - std::vector> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; } - // An edge is uniquely identified by its name, and its output index and input index. - bool IsEdgeInCostGraph(const std::string &, size_t, size_t); - - void SetDeviceMemoryAndCostParameter(); - - std::vector> ConstructConnectedComponents(std::vector); - void DFS(const OperatorInfoPtr ¤t_op, std::map *visited, - const std::shared_ptr &component); - - CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); - CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); - CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory); - CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); - CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_costlist, double memory); - Status SearchStrategyForMultiNodeFinalGraph(const std::vector &); - std::vector> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { - return edges_[{u_node, v_node}]; - } - double GetDeviceMemory() const { return dev_memory_; } - - // Search the cost_list in the final graph, and determine the optimal one - Status SearchStrategy(); - - // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated - OperatorInfoPtr CheckOpElimination() const; - // Given a graph which contains the following subgraph where there are multiple edges between u and v, these edges - // can be eliminated into one - std::vector CheckEdgeElimination() const; - // Given a graph which contains the following subgraph: - // u - // | - // w --- v --- x - // where u has 0 incoming edge, u has 1 outgoing edge, and v has > 1 incoming edges, u can be merged into v. - // u is returned. - OperatorInfoPtr CheckMergeElimination() const; - // Given a graph which contains the following subgraph: - // u - // | - // v --- x - // where v has 2 outgoing edges, and u has 1 incoming edges and no outgoing edges. In this case, u can be contracted - // into v. u is returned. - OperatorInfoPtr CheckContractElimination() const; - /* Given a graph which contains the following subgraph: - * u - * / \ - * / \ - * v --- w - * where u has 2 outgoing edges, v has 1 outgoing edge, and w has 2 incoming edges, u can be eliminated into v. - * The returned value includes u and the edge >. - */ - std::pair CheckTriangleElimination() const; - /* Given a graph which contains the following subgraph: - * v <--- u ---> w - * where u has 0 incoming edges, and multiple outgoing edges. In addition, v and w have other complicated connections, - * resulting in v and w can not be performed ContractElimination. u is returned. - * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. - */ - OperatorInfoPtr CheckStarElimination() const; - // Applying Operator Elimination in DP algorithm - EdgePtr EliminationOp(const OperatorInfoPtr &op); - // Applying Edge Elimination in DP algorithm - EdgePtr EliminationEdges(const std::vector &edges); - // Applying Merge Elimination in DP algorithm - OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op); - void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, - const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, - const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new); - // Applying Contract Elimination in DP algorithm - OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op); - void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr, - const CostPtrList &, CostPtrList *); - - // Applying Triangle Elimination in DP algorithm. return the left_node - OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right); - void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &, - const StrategyPtr &, const StrategyPtr &, const StrategyPtr &, - const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *); - // Given the relevant costlist, create the TriangleElimination cost - void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &, - const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *); - - // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op - // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. - std::vector EliminationStar(const OperatorInfoPtr &op); - void CreateStarEliminationCostList(std::vector &, const StrategyPtr &, const CostPtrList &, - const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *); - void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, - const StrategyPtr &, const CostPtrList &, std::vector, - CostPtrList &, CostPtrList &, CostPtrList *); - // Calculate memory cost for training phase or inference phase. - Status CalculateMemoryCost(); - // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then - // the memory cost can be resused. This is used to calculate memory in the training phase. - Status CalculateOpsMemoryCost(); - // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then - // the memory cost can be reused. This is used to calculate memory in the training phase. - Status CalculateEdgesMemoryCost(); - // Calculate memory cost of operators in the inference phase. - Status CalculateOpsMemoryCostForInference(); - // Calculate memory cost of edges in the inference phase. - Status CalculateEdgesMemoryCostForInference(); - Status ComputeOpsAndEdgesParameterInvolved(); - // Compute for each operator whether the output is critical. - Status ComputeOpsAndEdgesOutputCritical(); - - std::vector GetOperators() const { return ops_; } - size_t GetNumEdges() const; - Status InitSelectedStrategy(); - OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; - // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only - // once (instead of multiple times), this method is used to correct this. - Status CorrectOpsMemoryCost(); - // Needed by rec_parser - void add_inputs_tensor_name(const std::vector &inputs_tensor_name) { - inputs_tensor_name_list_.push_back(inputs_tensor_name); - } - const std::vector> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; } - void add_tuple_getitem(const std::pair &tuple_getitem) { - auto ret = tuple_getitem_list_.insert(tuple_getitem); - if (ret.second == false) { - MS_LOG(EXCEPTION) << "The insert item is already exist."; - } - } - const std::map get_tuple_getitem_list() const { return tuple_getitem_list_; } - - private: - void TopologyOrder(std::vector *); - void DFSForTopoOrder(const OperatorInfoPtr &, std::map *, std::vector *); - Status DetermineCriticalOps(const std::vector &); - void MarkCriticalOpsAndEdges(const std::map &); - // Needed by rec_parser - std::vector> inputs_tensor_name_list_; - std::map tuple_getitem_list_; - double dev_memory_; - double costmodel_alpha_; - double costmodel_beta_; - std::vector ops_; - std::map, std::vector> edges_; - std::vector> connected_compoents_; - std::map> out_edges_; - std::map> in_edges_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc deleted file mode 100644 index 8ebfdb7d13..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ /dev/null @@ -1,892 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/operator_costmodel.h" - -#include -#include -#include "parallel/device_matrix.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -void OperatorCost::set_is_parameter(const std::vector &is_parameter) { is_parameter_ = is_parameter; } - -void OperatorCost::set_is_parameter_involve(const std::vector &is_parameter_inv) { - is_parameter_involve_ = is_parameter_inv; -} - -void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; } - -void OperatorCost::SetInputAndOutputTypeLength(const std::vector &input_lengths, - const std::vector &output_lengths) { - inputs_type_lengths_ = input_lengths; - outputs_type_lengths_ = output_lengths; -} - -void OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; } - -double OperatorCost::GetMemoryCost(const std::vector &inputs, - const std::vector &outputs) const { - double result = 0.0; - if (output_parameter_involve_ == 1) { - // When this operator has multiple outputs, they all contributes to the memory. - for (size_t i = 0; i < outputs.size(); ++i) { - result += ListProduct(outputs[i].slice_shape()) * static_cast(outputs_type_lengths_[i]); - } - bool is_any_para_inv = - std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; }); - if (is_any_para_inv) { - for (size_t i = 0; i < inputs.size(); ++i) { - if (is_parameter_[i]) { - result += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); - } else if (inputs_related_ && (!is_parameter_involve_[i])) { - // When the inputs of this operator are related, and they are not parameter-involved, then they are included - // in the memory cost. - result += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); - } - } - } - } - - return result; -} - -double OperatorCost::GetMemoryCostForInference(const std::vector &, - const std::vector &outputs) const { - double result = 0.0; - if (is_outputs_critical_ == -1) { - MS_LOG(EXCEPTION) << "The critical flag is not set."; - } - if (is_outputs_critical_ == 1) { - for (size_t i = 0; i < outputs.size(); ++i) { - result += ListProduct(outputs[i].slice_shape()) * static_cast(outputs_type_lengths_[i]); - } - } - return result; -} - -// return the per device communication cost in the forward phase. -double MatMulCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t) const { - TensorInfo input0 = inputs[0]; - TensorInfo output0 = outputs[0]; - Shape input0_shape = input0.shape(); - Shape input0_slice_shape = input0.slice_shape(); - if (input0_shape[input0_shape.size() - 1] == input0_slice_shape[input0_slice_shape.size() - 1]) { - // If the reduced dimension has not been partitioned, then there is no communication cost. - return 0.0; - } else { - // Else, the communication cost is the size (number of bytes) of a slice of output tensor. - return ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); - } -} - -// return the per device communication cost in the forward phase. -double MatMulCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not - // fully utilize all devices - double result = 0.0; - if (is_parameter_[1]) { - TensorInfo input1 = inputs[1]; // tensor B - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double MatMulCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t) const { - // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) - double result = 0.0; - TensorInfo output0 = outputs[0]; - Shape input0_slice_shape = inputs[0].slice_shape(); - Shape input1_slice_shape = inputs[1].slice_shape(); - Shape input0_shape = inputs[0].shape(); - if (input0_shape[input0_shape.size() - 1] != input0_slice_shape[input0_slice_shape.size() - 1]) { - // If the reduced dimension has been partitioned, then there is no communication cost. - result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); - } - result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double MatMulCost::GetBackwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) - double result = 0.0; - if (is_parameter_[1]) { - TensorInfo input1 = inputs[1]; // tensor B - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - - return result; -} - -// Return the per device communication cost in the forward phase. -double ActivationCost::GetForwardCommCost(const std::vector &, const std::vector &, - int32_t) const { - // ReLU is the element-wise operator, thus it does not need communication in the forward phase - return 0.0; -} - -// Return the per device communication cost in the backward phase. -double ActivationCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_[0]) { - TensorInfo input1 = inputs[0]; - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double ActivationCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - TensorInfo input0_info = inputs[0]; - Shape input0_slice_shape = input0_info.slice_shape(); - return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double ActivationCost::GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const { - return 0.0; -} - -// Return the per device communication cost in the forward phase. -double SoftmaxCost::GetForwardCommCost(const std::vector &, const std::vector &, - int32_t) const { - // In the forward phase, the communication cost = 0 - return 0.0; -} - -// Return the per device communication cost in the backward phase. -double SoftmaxCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_[0]) { - TensorInfo input1 = inputs[0]; - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double SoftmaxCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - // In the forward phase, the computation cost = slice(A) - TensorInfo input0 = inputs[0]; - Shape input0_slice_shape = input0.slice_shape(); - return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double SoftmaxCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, int32_t) const { - return 0.0; -} - -// return the per device communication cost in the forward phase. -double TmpIdentityCost::GetForwardCommCost(const std::vector &, - const std::vector &, int32_t) const { - // Identity is the element-wise operator, thus it does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double TmpIdentityCost::GetBackwardCommCost(const std::vector &, - const std::vector &, int32_t) const { - // Identity is the element-wise operator, thus it does not need communication in the backward phase - return 0.0; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double TmpIdentityCost::GetForwardComputationCost(const std::vector &, - const std::vector &, int32_t) const { - return 0.0; -} - -// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes -// this operator uses -double TmpIdentityCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, - int32_t) const { - return 0.0; -} - -// Return the per device PEAK memory cost contributed by this operator in a training iteration. -double TmpIdentityCost::GetMemoryCost(const std::vector &, const std::vector &) const { - return 0.0; -} - -double BatchParallelCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &, - int32_t) const { - double cost = 0.0; - for (size_t i = 0; i < inputs.size(); ++i) { - cost += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); - } - return cost; -} - -double BatchParallelCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, - int32_t) const { - return 0.0; -} - -double BatchParallelCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - for (size_t j = 0; j < inputs.size(); ++j) { - if (!is_parameter_[j]) { - continue; - } - TensorInfo input_a_tensor_info = inputs[j]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - } - - return result; -} -// return the per device communication cost in the forward phase. -double PReLUCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { - // prelu does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double PReLUCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_[1]) { - TensorInfo input1 = inputs[1]; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double PReLUCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - // In forward phase, the computation cost = slice(A) + slice(B) - Shape input0_slice_shape = inputs[0].slice_shape(); - Shape input1_slice_shape = inputs[1].slice_shape(); - double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - return result; -} - -// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes -// this operator uses -double PReLUCost::GetBackwardComputationCost(const std::vector &inputs, - const std::vector &, - int32_t stage_id) const { - // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) - double result = 0.0; - if (is_parameter_[1]) { - TensorInfo input1 = inputs[1]; // tensor B - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) { - result += ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// return the per device communication cost in the forward phase. -double OneHotCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { - // onehot does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double OneHotCost::GetBackwardCommCost(const std::vector &, const std::vector &, - int32_t) const { - // onehot does not need communication in the backward phase - return 0.0; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double OneHotCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - // In onehot's forward phase, the computation cost = slice(A) - Shape input0_slice_shape = inputs[0].slice_shape(); - return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); -} - -// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes -// this operator uses -double OneHotCost::GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const { - return 0.0; -} - -// return the per device communication cost in the forward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector &, - const std::vector &, int32_t) const { - // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector &, - const std::vector &, int32_t) const { - // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase - return 0.0; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &, int32_t) const { - // In forward phase, the computation cost = slice(A) + slice(B) - Shape input0_slice_shape = inputs[0].slice_shape(); - Shape input1_slice_shape = inputs[1].slice_shape(); - double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - return result; -} - -// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes -// this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, int32_t) const { - return 0.0; -} - -// return the per device communication cost in the forward phase. -double ReshapeCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const { - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); - TensorRedistribution tensor_redistribution(false, true); - if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; - } - if (tensor_redistribution.ComputeCost() == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; - } - return (inputs_type_lengths_[0] * tensor_redistribution.comm_cost()); -} - -// return the per device communication cost in the backward phase. -double ReshapeCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_[0]) { - TensorInfo input1 = inputs[0]; - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double ReshapeCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t stage_id) const { - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); - TensorRedistribution tensor_redistribution(false, true); - if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; - } - if (tensor_redistribution.ComputeCost() == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; - } - return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost()); -} - -// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes -// this operator uses -double ReshapeCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, int32_t) const { - return 0.0; -} - -double ArithmeticCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - double result; - result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + - ListProduct(inputs[1].slice_shape()) * static_cast(inputs_type_lengths_[1]); - return result; -} - -double ArithmeticCost::GetBackwardComputationCost(const std::vector &inputs, - const std::vector &, int32_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - if (is_parameter_[0]) { - TensorInfo input_a_tensor_info = inputs[0]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - - if (is_parameter_[1]) { - TensorInfo input_b_tensor_info = inputs[1]; - Shape input_b_shape = input_b_tensor_info.shape(); - Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_b_shape.size(); ++i) { - used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - return result; -} - -double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - if (is_parameter_[0]) { - TensorInfo input_a_tensor_info = inputs[0]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - - if (is_parameter_[1]) { - TensorInfo input_b_tensor_info = inputs[1]; - Shape input_b_shape = input_b_tensor_info.shape(); - Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_b_shape.size(); ++i) { - used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - - return result; -} - -bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int32_t stage_id) { - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - auto strategy0 = shape[0] / slice_shape[0]; - - return (total_device_num == IntToSize(strategy0)); -} - -double ReduceMethodCost::GetForwardCommCost(const std::vector &inputs, - const std::vector &outputs, int32_t stage_id) const { - double result = 0.0; - TensorInfo input0 = inputs[0]; - TensorInfo output0 = outputs[0]; - Shape input0_shape = input0.shape(); - Shape input0_slice_shape = input0.slice_shape(); - if (cross_batch_ && IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { - return result; - } - std::vector dim_list = input0.reduce_dim(); - std::vector::iterator pos; - pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { - return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; - }); - if (pos != dim_list.end()) { - result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); - } - - return result; -} - -double ReduceMethodCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_[0]) { - TensorInfo input_tensor_info = inputs[0]; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - Shape input_shape = input_tensor_info.shape(); - Shape input_slice_shape = input_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_shape.size(); ++i) { - used_device_num *= input_shape[i] / input_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - - return result; -} - -double ReduceMethodCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t stage_id) const { - double result = 0.0; - TensorInfo input0 = inputs[0]; - TensorInfo output0 = outputs[0]; - std::vector dim_list = input0.reduce_dim(); - Shape input0_slice_shape = input0.slice_shape(); - Shape input0_shape = input0.shape(); - if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { - std::vector::iterator pos; - pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { - return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; - }); - if (pos != dim_list.end()) { - result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); - } - } - result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); - - return result; -} - -double ReduceMeanCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t stage_id) const { - double result = 0.0; - TensorInfo input0 = inputs[0]; - TensorInfo output0 = outputs[0]; - std::vector dim_list = input0.reduce_dim(); - Shape input0_slice_shape = input0.slice_shape(); - Shape input0_shape = input0.shape(); - if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { - std::vector::iterator pos; - pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { - return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; - }); - if (pos != dim_list.end()) { - result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]) * 2.0; - } - } - result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); - - return result; -} - -double DropOutCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - if (inputs.empty()) { - return 0.0; - } - TensorInfo input0 = inputs[0]; - Shape input0_slice_shape = input0.slice_shape(); - return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * DROPOUT_COST_RATE; -} - -// return the per device communication cost in the forward phase. -double GatherV2Cost::GetForwardCommCost(const std::vector &, const std::vector &, - int32_t) const { - // GatherV2Cost does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double GatherV2Cost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - for (size_t j = 0; j < inputs.size(); ++j) { - if (!is_parameter_[j]) { - continue; - } - TensorInfo input_a_tensor_info = inputs[j]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - } - - return result; -} - -double GatherV2Cost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - // In forward phase, the computation cost = slice(A) + slice(B) - Shape input0_slice_shape = inputs[0].slice_shape(); - Shape input1_slice_shape = inputs[1].slice_shape(); - double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - return result; -} - -double GatherV2Cost::GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const { - return 0.0; -} - -double LayerNormCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost"; - } - if (inputs_type_lengths_.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost"; - } - - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - for (size_t index = 0; index < inputs.size(); ++index) { - if (is_parameter_[index]) { - TensorInfo tensor_info = inputs[index]; - Shape shape = tensor_info.shape(); - Shape slice_shape = tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < shape.size(); ++i) { - if (slice_shape[i] == 0) { - MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape); - } - used_device_num *= shape[i] / slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result += ListProduct(slice_shape) * static_cast(inputs_type_lengths_[index]); - } - } - } - return result; -} - -double LayerNormCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - double result = 0.0; - if (inputs_type_lengths_.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost"; - } - - for (size_t index = 0; index < inputs.size(); ++index) { - TensorInfo tensor_info = inputs[index]; - Shape slice_shape = tensor_info.slice_shape(); - result += ListProduct(slice_shape) * static_cast(inputs_type_lengths_[index]); - } - return result; -} - -double GatherV2PCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const { - double result = 0.0; - if (outputs_type_lengths_.size() != outputs.size()) { - MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; - } - // don't split axis - if (strategy_.at(IntToSize(axis_)) == 1) { - return result; - } - - // split axis - auto param_shape = inputs[0].slice_shape(); - auto index_shape = inputs[1].slice_shape(); - Shape reducescatter_shape = index_shape; - if (param_shape.size() == 2) { - reducescatter_shape.push_back(param_shape.at(1 - axis_)); - } - result += ListProduct(reducescatter_shape) * static_cast(outputs_type_lengths_[0]); - return result; -} - -double GatherV2PCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - for (size_t j = 0; j < inputs.size(); ++j) { - if (!is_parameter_[j]) { - continue; - } - TensorInfo input_a_tensor_info = inputs[j]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - } - return result; -} - -double GatherV2PCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t stage_id) const { - double result = 0.0; - Shape input0_slice_shape = inputs[0].slice_shape(); - Shape input1_slice_shape = inputs[1].slice_shape(); - if (inputs_type_lengths_.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; - } - // don't split axis - if (strategy_.at(IntToSize(axis_)) == 1) { - result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } else { - // split axis - result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1; - } - - return result; -} - -double GatherV2PCost::GetBackwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t) const { - double result = 0.0; - Shape input1_slice_shape = inputs[1].slice_shape(); - Shape output0_slice_shape = outputs[0].slice_shape(); - // don't split axis - if (strategy_.at(IntToSize(axis_)) == 1) { - result += ListProduct(output0_slice_shape) * static_cast(inputs_type_lengths_[0]); - } else { - // split axis - result += ListProduct(output0_slice_shape) * static_cast(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3; - } - - return result; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h deleted file mode 100644 index a08a4dbb13..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ /dev/null @@ -1,656 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ -#define PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ - -#include -#include -#include "parallel/device_manager.h" -#include "parallel/tensor_layout/tensor_info.h" - -namespace mindspore { -namespace parallel { -#define MAXIMUM_INPUT_NUMBER 100 -#define DEFAULT_DATA_TYPE_LENGTH 4 -#define DROPOUT_COST_RATE 1.125 // the DropoutGenMask need 12.5% memory -#define GATHERV2_COST_WEIGHT0 3 -#define GATHERV2_COST_WEIGHT1 7 -#define GATHERV2_COST_WEIGHT2 2 -#define GATHERV2_COST_WEIGHT3 6 - -class OperatorCost; -using OperatorCostPtr = std::shared_ptr; - -template -double ListProduct(std::vector vec) { - double result = 1; - for (size_t i = 0; i < vec.size(); ++i) { - result *= vec[i]; - } - return result; -} -// NOTE: Currently, the returned value in each method is bytes of memory size, which is calculated by the number of -// entries timing the length of each entry's data type -class OperatorCost { - public: - explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) { - // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked - for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) { - is_parameter_.push_back(false); - is_parameter_involve_.push_back(false); - inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); - outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); - } - } - OperatorCost() : inputs_related_(false) { - // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked - for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) { - is_parameter_.push_back(false); - is_parameter_involve_.push_back(false); - inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); - outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH); - } - } - virtual ~OperatorCost() = default; - - void set_is_parameter(const std::vector &is_parameter); - void set_is_parameter_involve(const std::vector &); - void set_output_parameter_involve(int); - void set_output_critical(int); - void SetInputAndOutputTypeLength(const std::vector &input_lengths, const std::vector &output_lengths); - std::vector inputs_type_lengths() const { return inputs_type_lengths_; } - std::vector outputs_type_lengths() const { return outputs_type_lengths_; } - - // per device communication cost - virtual double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const = 0; - virtual double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const = 0; - virtual double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const = 0; - // per device computation cost - virtual double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const = 0; - virtual double GetForwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetBackwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t stage_id) const = 0; - // per device PEAK memory cost in a training iteration - // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), - // plus necessary inputs. - virtual double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const; - // per device memory cost in a inference phase - double GetMemoryCostForInference(const std::vector &, const std::vector &) const; - - protected: - // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of - // pre-operator that has parameters as input. - std::vector is_parameter_involve_; - int output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved - // Whether the inputs are related or not? For example, TensorAdd's two inputs are independent (not related), while - // Mul's two inputs are dependent (related). - bool inputs_related_; - // for each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter - std::vector is_parameter_; - // for each input and output, the followings record the number of bytes of each element - std::vector inputs_type_lengths_; - std::vector outputs_type_lengths_; - // Whether the output is critical, which means that this output is included in calculating peak memory cost - // in the inference phase. - int is_outputs_critical_ = -1; -}; - -using OperatorCostPtr = std::shared_ptr; - -class MatMulCost : public OperatorCost { - public: - explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - MatMulCost() : OperatorCost(true) {} - ~MatMulCost() override = default; - - // per device communication cost - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - - // per device computation cost - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using MatMulCostPtr = std::shared_ptr; - -class ActivationCost : public OperatorCost { - public: - explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - ActivationCost() : OperatorCost(false) {} - ~ActivationCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using ActivationCostPtr = std::shared_ptr; -using TransposeCost = ActivationCost; -using TransposeCostPtr = std::shared_ptr; - -class SoftmaxCost : public OperatorCost { - public: - explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - SoftmaxCost() : OperatorCost(false) {} - ~SoftmaxCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t) const override; -}; -using SoftmaxCostPtr = std::shared_ptr; - -class TmpIdentityCost : public OperatorCost { - public: - explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - TmpIdentityCost() : OperatorCost(false) {} - ~TmpIdentityCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override; -}; -using TmpIdentityCostPtr = std::shared_ptr; - -class BatchParallelCost : public OperatorCost { - public: - explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - BatchParallelCost() : OperatorCost(false) {} - ~BatchParallelCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using BatchParallelCostPtr = std::shared_ptr; - -class VirtualDatasetCost : public OperatorCost { - public: - explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - VirtualDatasetCost() : OperatorCost(false) {} - ~VirtualDatasetCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } - double GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } - // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override { - return 0.0; - } -}; -using VirtualDatasetCostPtr = std::shared_ptr; - -class GeneratorBaseCost : public OperatorCost { - public: - explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - GeneratorBaseCost() : OperatorCost(false) {} - ~GeneratorBaseCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - // Inputs vector is empty for generator ops. - double GetForwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } - // Generator ops don't have backward steps. - double GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } -}; -using GeneratorBaseCostPtr = std::shared_ptr; - -class PReLUCost : public OperatorCost { - public: - explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - PReLUCost() : OperatorCost(true) {} - ~PReLUCost() override = default; - - // per device communication cost - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - - // per device computation cost - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using PReLUCostPtr = std::shared_ptr; - -class OneHotCost : public OperatorCost { - public: - explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - OneHotCost() : OperatorCost(true) {} - ~OneHotCost() override = default; - - // per device communication cost - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - - // per device computation cost - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using OneHotCostPtr = std::shared_ptr; - -class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { - public: - explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {} - ~SoftmaxCrossEntropyWithLogitsCost() override = default; - - // per device communication cost - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - - // per device computation cost - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; - -class ReshapeCost : public OperatorCost { - public: - explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - ReshapeCost() : OperatorCost(true) {} - - ~ReshapeCost() override = default; - - // per device communication cost - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - - // per device computation cost - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using ReshapeCostPtr = std::shared_ptr; - -class ArithmeticCost : public OperatorCost { - public: - explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - ArithmeticCost() : OperatorCost(false) {} - ~ArithmeticCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; - - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using ArithmeticCostPtr = std::shared_ptr; -using BiasAddCost = ArithmeticCost; -using BiasAddCostPtr = std::shared_ptr; - -class ReduceMethodCost : public OperatorCost { - public: - explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - ReduceMethodCost() : OperatorCost(true) {} - ~ReduceMethodCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &, const std::vector &, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } - void set_cross_batch(bool cb) { cross_batch_ = cb; } - - protected: - bool cross_batch_ = false; -}; -using ReduceMethodCostPtr = std::shared_ptr; - -class ReduceMeanCost : public ReduceMethodCost { - public: - explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {} - ReduceMeanCost() : ReduceMethodCost(true) {} - ~ReduceMeanCost() override = default; - - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; -}; -using ReduceMeanCostPtr = std::shared_ptr; - -class GetNextCost : public OperatorCost { - public: - explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - GetNextCost() : OperatorCost(false) {} - ~GetNextCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - // Inputs vector is empty for generator ops. - double GetForwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } - // Generator ops don't have backward steps. - double GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } -}; -using GetNextCostPtr = std::shared_ptr; - -class DropOutCost : public OperatorCost { - public: - explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - DropOutCost() : OperatorCost(true) {} - ~DropOutCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override; - double GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } -}; - -using DropOutCostPtr = std::shared_ptr; - -class LayerNormCost : public OperatorCost { - public: - explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - LayerNormCost() : OperatorCost(true) {} - ~LayerNormCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { - return 0.0; - } - double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override; - double GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const override { - return 0.0; - } -}; - -using DropOutCostPtr = std::shared_ptr; - -class GatherV2Cost : public OperatorCost { - public: - explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} - GatherV2Cost() : OperatorCost(true) {} - ~GatherV2Cost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t) const override; -}; - -using GatherV2CostPtr = std::shared_ptr; - -class GatherV2PCost : public OperatorCost { - public: - explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {} - GatherV2PCost() : OperatorCost(true), axis_(0) {} - ~GatherV2PCost() override = default; - - double GetCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); - } - double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override { - return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); - } - double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, - int32_t) const override; - void set_axis(int32_t axis) { axis_ = axis; } - void set_strategy(const Shape &strategy) { strategy_ = strategy; } - - protected: - int32_t axis_; - Shape strategy_; -}; - -using GatherV2PCostPtr = std::shared_ptr; -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc deleted file mode 100644 index 9fb79ceee4..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc +++ /dev/null @@ -1,750 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/rec_core/rec_cost.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" - -namespace mindspore { -namespace parallel { - -// Compute redistributed cost -double CostRedis(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const std::vector> &mode, const Graph &graph) { - // Store value of cost redist - double cost_redis = 0; - - // Number of current strategies. - size_t num_strategy = node_name_to_strategy.size(); - - // Number of node-in and node-out - size_t num_node_in = node.node_in.size(); - size_t num_node_out = node.node_out.size(); - - // Set tensor edge value with original tensor shape and cutting times. - double input_tensor = node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n * - node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c * - node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h * - node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w; - - double output_tensor = node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n * - node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c * - node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h * - node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w; - - // For each strategy candidate. - for (size_t i_strategy = 0; i_strategy < num_strategy; i_strategy++) { - // Find its forward nodes - for (size_t i_node = 0; i_node < num_node_in; i_node++) { - if (graph.nodes[node.node_in[i_node]].name == node_name_to_strategy[i_strategy].first) { - bool is_search_forward = true; - cost_redis += - CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, input_tensor, is_search_forward); - } - } - - // Find its backward nodes - for (size_t i_node = 0; i_node < num_node_out; i_node++) { - if (graph.nodes[node.node_out[i_node]].name == node_name_to_strategy[i_strategy].first) { - bool is_search_forward = false; - cost_redis += - CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, output_tensor, is_search_forward); - } - } - } - - return cost_redis; -} - -double CostRedisWithAdjacentNode(const std::vector> &node_name_to_strategy, - const std::vector> &mode, size_t i_strategy, size_t i_node, - double tensor_size, bool search_forward) { - double new_redis_cost = 0; - int counter = 0; - - if (search_forward) { - if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_n) != - static_cast(1 / mode[i_node][0])) { - counter += 1; - } - if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_c) != - static_cast(1 / mode[i_node][1])) { - counter += 1; - } - if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_h) != - static_cast(1 / mode[i_node][2])) { - counter += 1; - } - if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_w) != - static_cast(1 / mode[i_node][3])) { - counter += 1; - } - } else { - if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_n) != - static_cast(1 / mode[2][0])) { - counter += 1; - } - if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_c) != - static_cast(1 / mode[2][1])) { - counter += 1; - } - if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_h) != - static_cast(1 / mode[2][2])) { - counter += 1; - } - if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_w) != - static_cast(1 / mode[2][3])) { - counter += 1; - } - } - - if (counter >= 2) { - new_redis_cost = tensor_size / 4.0; - } else if (counter == 0 || counter == 1) { - new_redis_cost = 0; - } else { - MS_LOG(EXCEPTION) << "Failure: CostRedis failed."; - } - - return new_redis_cost; -} - -// Get optimal strategy for MatMul -StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph) { - int edge_i = - static_cast(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h); - int edge_j = - static_cast(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w); - int edge_k = - static_cast(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w); - - std::vector cost_op; - std::vector> mode; - - if (edge_i < 2 || edge_i % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(StrConcatDimI(edge_j, edge_k) + CostRedis(node, node_name_to_strategy, - mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}}, - graph)); - } - - if (edge_j < 2 || edge_j % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(StrConcatDimJ(edge_i, edge_k) + CostRedis(node, node_name_to_strategy, - mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}}, - graph)); - } - - if (edge_k < 2 || edge_k % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(StrReduceDimK(edge_i, edge_j) + CostRedis(node, node_name_to_strategy, - mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}}, - graph)); - } - - return ChoseStr(cost_op, node.apply.str); -} - -// Get weight for MatMul -double CostMatMul::GetMinCostIn(const OperatorRec &op) { - int edge_i = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); - int edge_j = static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w); - int edge_k = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); - - std::vector cost_in; - cost_in.push_back(StrConcatDimI(edge_j, edge_k)); - cost_in.push_back(StrConcatDimJ(edge_i, edge_k)); - cost_in.push_back(StrReduceDimK(edge_i, edge_j)); - - return *min_element(cost_in.begin(), cost_in.end()); -} - -// Chose strategy for MatMul -StrategyRec CostMatMul::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_i_; - break; - - case 1: - str.inputTensor[1].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_j_; - break; - - case 2: - str.inputTensor[0].str_w /= 2.0; - str.inputTensor[1].str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_k_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure:CostMatMul failed."; - } - - return str; -} - -// Get optimal strategy for Conv -StrategyRec CostConvolution::GetOptimalStr( - const Graph::NodeType &node, const std::vector> &node_name_to_strategy, - const Graph &graph, bool channel_partition) { - const OperatorRec &op = node.apply; - - int input_tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); - int input_tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); - int input_tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); - int input_tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); - - int tensor_in = input_tensor_h * input_tensor_w * input_tensor_n * input_tensor_c; - - int tensor_filter_h = static_cast(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h); - int tensor_filter_w = static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w); - int tensor_filter_n = static_cast(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n); - int tensor_filter_c = static_cast(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c); - - int tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c; - - int output_tensor_h = static_cast(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h); - int output_tensor_w = static_cast(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w); - int output_tensor_n = static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n); - int output_tensor_c = static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); - - int tensor_out = output_tensor_h * output_tensor_w * output_tensor_n * output_tensor_c; - - std::vector cost_op; - cost_op.reserve(7); - std::vector> mode; - - if (input_tensor_n < 2 || input_tensor_n % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy, - mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); - } - - cost_op.push_back(DOUBLE_MAX); - cost_op.push_back(DOUBLE_MAX); - - if (channel_partition == false || tensor_filter < 2 || tensor_filter % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(StrDimK(tensor_in) + CostRedis(node, node_name_to_strategy, - mode = {{1, 1, 1, 1}, {0.5, 1, 1, 1}, {1, 0.5, 1, 1}}, graph)); - } - - cost_op.push_back(DOUBLE_MAX); - cost_op.push_back(DOUBLE_MAX); - - if (channel_partition == false || tensor_filter_c < 2 || tensor_filter_c % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(StrDimQ(tensor_out) + CostRedis(node, node_name_to_strategy, - mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 1, 1, 1}}, graph)); - } - - return ChoseStr(cost_op, node.apply.str); -} - -// Get weight for Conv -double CostConvolution::GetMinCostIn(const Graph::NodeType &node) { - const OperatorRec &op = node.apply; - - int tensor_in = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) * - static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) * - static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) * - static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); - int tensor_filter = static_cast(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h) * - static_cast(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) * - static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) * - static_cast(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c); - int tensor_out = static_cast(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) * - static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) * - static_cast(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) * - static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); - - std::vector cost_in; - cost_in.push_back(StrDimB(tensor_filter)); - cost_in.push_back(StrDimI(tensor_in, tensor_filter)); - cost_in.push_back(StrDimJ(tensor_in, tensor_filter)); - cost_in.push_back(StrDimK(tensor_in)); - cost_in.push_back(StrDimDI(tensor_in, tensor_out)); - cost_in.push_back(StrDimDJ(tensor_in, tensor_out)); - cost_in.push_back(StrDimQ(tensor_out)); - - return *min_element(cost_in.begin(), cost_in.end()); -} - -// Chose strategy for Conv -StrategyRec CostConvolution::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.outputTensor.str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_b_; - break; - - case 1: - str.inputTensor[0].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_i_; - break; - - case 2: - str.inputTensor[0].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_j_; - break; - - case 3: - str.inputTensor[1].str_n /= 2.0; - str.outputTensor.str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_k_; - break; - - case 4: - str.inputTensor[1].str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_di_; - break; - - case 5: - str.inputTensor[1].str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_dj_; - break; - - case 6: - str.inputTensor[0].str_c /= 2.0; - str.inputTensor[1].str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_q_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: CostConvolution failed."; - } - return str; -} - -// Get optimal strategy for Pooling -StrategyRec CostPooling::GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph) { - int tensor_n = static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n); - int tensor_c = static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); - - std::vector cost_op; - std::vector> mode; - - if (tensor_n < 2 || tensor_n % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, - mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); - } - - if (tensor_c < 2 || tensor_c % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, - mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph)); - } - - cost_op.push_back(DOUBLE_MAX); - cost_op.push_back(DOUBLE_MAX); - - return ChoseStr(cost_op, node.apply.str); -} - -// Chose strategy for Pooling -StrategyRec CostPooling::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.outputTensor.str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 1: - str.inputTensor[0].str_c /= 2.0; - str.outputTensor.str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 2: - str.inputTensor[0].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 3: - str.inputTensor[0].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: CostPooling failed."; - } - return str; -} - -// Chose strategy for Add -StrategyRec CostTensorAdd::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.inputTensor[1].str_n /= 2.0; - str.outputTensor.str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 1: - str.inputTensor[0].str_c /= 2.0; - str.inputTensor[1].str_c /= 2.0; - str.outputTensor.str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 2: - str.inputTensor[0].str_h /= 2.0; - str.inputTensor[1].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 3: - str.inputTensor[0].str_w /= 2.0; - str.inputTensor[1].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: CostAdd failed."; - } - return str; -} - -// Get optimal strategy for Reshape -StrategyRec CostReshape::GetOptimalStr(const Graph::NodeType &node) const { return ChoseStr(node.apply.str); } - -StrategyRec CostReshape::ChoseStr(StrategyRec str) const { return str; } - -// Chose strategy for BiasAdd -StrategyRec CostBiasAdd::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.outputTensor.str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 1: - str.inputTensor[0].str_c /= 2.0; - str.outputTensor.str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 2: - str.inputTensor[0].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 3: - str.inputTensor[0].str_w /= 2.0; - str.inputTensor[1].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed."; - } - return str; -} - -// Get optimal strategy for Common OPs -StrategyRec CostCommon::GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph) { - const OperatorRec &op = node.apply; - int tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); - int tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); - int tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); - int tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); - - std::vector cost_op; - std::vector> mode; - - if (tensor_n < 2 || tensor_n % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, - mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); - } - - if (tensor_c < 2 || tensor_c % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, - mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph)); - } - - if (tensor_h < 2 || tensor_h % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, - mode = {{1, 1, 0.5, 1}, {1, 1, 0.5, 1}, {1, 1, 0.5, 1}}, graph)); - } - - if (tensor_w < 2 || tensor_w % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, - mode = {{1, 1, 1, 0.5}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}}, graph)); - } - - return ChoseStr(cost_op, node.apply.str); -} - -// Chose strategy for Common op -StrategyRec CostCommon::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.outputTensor.str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 1: - str.inputTensor[0].str_c /= 2.0; - str.outputTensor.str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 2: - str.inputTensor[0].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 3: - str.inputTensor[0].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: Common failed."; - } - return str; -} - -// Get optimal strategy for BatchParallel OPs -StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) { - const OperatorRec &op = node.apply; - int tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); - int tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); - int tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); - int tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); - - std::vector cost_op; - - if (tensor_n < 2 || tensor_n % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_); - } - - if (tensor_c < 2 || tensor_c % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_); - } - - if (tensor_h < 2 || tensor_h % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_); - } - - if (tensor_w < 2 || tensor_w % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_); - } - - return ChoseStr(cost_op, node.apply.str); -} - -// Chose strategy for BatchParallel op -StrategyRec CostBatchParallel::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.outputTensor.str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 1: - str.inputTensor[0].str_c /= 2.0; - str.outputTensor.str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 2: - str.inputTensor[0].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 3: - str.inputTensor[0].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed."; - } - return str; -} - -// Chose strategy for CostSoftmaxCrossEntropyWithLogits -StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.inputTensor[1].str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 1: - str.inputTensor[0].str_c /= 2.0; - str.inputTensor[1].str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 2: - str.inputTensor[0].str_h /= 2.0; - str.inputTensor[1].str_h /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 3: - str.inputTensor[0].str_w /= 2.0; - str.inputTensor[1].str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed."; - } - return str; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h deleted file mode 100644 index fb4fc27164..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h +++ /dev/null @@ -1,233 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_COST_H_ -#define PARALLEL_AUTO_PARALLEL_REC_COST_H_ - -#include -#include -#include -#include -#include - -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/auto_parallel/rec_core/rec_strategy.h" - -namespace mindspore { -namespace parallel { -#define DOUBLE_MAX (std::numeric_limits::max)() - -double CostRedis(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const std::vector> &mode, const Graph &graph); - -double CostRedisWithAdjacentNode(const std::vector> &node_name_to_strategy, - const std::vector> &mode, size_t i_strategy, size_t i_node, - double tensor_size, bool is_search_forward); - -// class CostMatMul is used to compute the cost of MatMul operator. -class CostMatMul { - public: - StrategyRec GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph); - - double GetMinCostIn(const OperatorRec &op); - - private: - double StrConcatDimI(int32_t a, int32_t b) { - cost_in_i_ = (static_cast(a) * static_cast(b)) / 2.0; - - return cost_in_i_; - } - - double StrConcatDimJ(int32_t a, int32_t b) { - cost_in_j_ = (static_cast(a) * static_cast(b)) / 2.0; - - return cost_in_j_; - } - - double StrReduceDimK(int32_t a, int32_t b) { - cost_in_k_ = (static_cast(a) * static_cast(b)) / 2.0; - - return cost_in_k_; - } - - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_i_ = 0; - - double cost_in_j_ = 0; - - double cost_in_k_ = 0; -}; // class CostMatMul is used to compute the cost of MatMul operator. - -// class CostConvolution is used to compute the cost of Conv operator. -class CostConvolution { - public: - StrategyRec GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph, bool channel_partition); - - double GetMinCostIn(const Graph::NodeType &node); - - private: - double StrDimB(int32_t TensorFilter) { - cost_in_b_ = static_cast((TensorFilter) / 2.0); - - return cost_in_b_; - } - - double StrDimI(int32_t TensorIn, int32_t TensorFilter) { - cost_in_i_ = static_cast((TensorIn + TensorFilter) / 2.0); - - return cost_in_i_; - } - - double StrDimJ(int32_t TensorIn, int32_t TensorFilter) { - cost_in_j_ = static_cast((TensorIn + TensorFilter) / 2.0); - - return cost_in_j_; - } - - double StrDimK(int32_t TensorIn) { - cost_in_k_ = static_cast((TensorIn) / 2.0); - - return cost_in_k_; - } - - double StrDimDI(int32_t TensorIn, int32_t TensorOut) { - cost_in_di_ = static_cast((TensorIn + TensorOut) / 2.0); - - return cost_in_di_; - } - - double StrDimDJ(int32_t TensorIn, int32_t TensorOut) { - cost_in_dj_ = static_cast((TensorIn + TensorOut) / 2.0); - - return cost_in_dj_; - } - - double StrDimQ(int32_t TensorOut) { - cost_in_q_ = static_cast((TensorOut) / 2.0); - - return cost_in_q_; - } - - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_b_ = 0; - - double cost_in_i_ = 0; - - double cost_in_j_ = 0; - - double cost_in_k_ = 0; - - double cost_in_di_ = 0; - - double cost_in_dj_ = 0; - - double cost_in_q_ = 0; -}; // class CostConvolution is used to compute the cost of Conv operator. - -// class CostPooling is used to compute the cost of Pooling operator. -class CostPooling { - public: - StrategyRec GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph); - - double GetMinCostIn() const { return cost_in_; } - - private: - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_ = 0; -}; // class CostPooling is used to compute the cost of Pooling operator. - -// class CostReshape is used to compute the cost of Reshape operator. -class CostReshape { - public: - StrategyRec GetOptimalStr(const Graph::NodeType &node) const; - - double GetMinCostIn() const { return cost_in_; } - - private: - StrategyRec ChoseStr(StrategyRec str) const; - - double cost_in_ = 0; -}; // class CostReshape is used to compute the cost of Reshape operator. - -// class CostCommon is used to compute the cost of an element-wise operator -class CostCommon { - public: - virtual StrategyRec GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph); - - virtual double GetMinCostIn() const { return cost_in_; } - - protected: - virtual StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_ = 0; -}; // class CostCommon is used to compute the cost of an element-wise operator - -// class CostBiasAdd is used to compute the cost of the addition between a tensor and a bias -class CostBiasAdd : public CostCommon { - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); -}; -// class CostAdd is used to compute the cost of Add operator. -class CostTensorAdd : public CostCommon { - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); -}; - -// all the following operation are element-wise and have the same cost -class CostReLU : public CostCommon {}; -class CostLog : public CostCommon {}; -class CostExp : public CostCommon {}; -class CostAdd : public CostCommon {}; -class CostSub : public CostCommon {}; -class CostMul : public CostCommon {}; -class CostDiv : public CostCommon {}; -class CostSqueeze : public CostCommon {}; -class CostCast : public CostCommon {}; - -// class BatchParallel is used to compute the cost of BatchParallel operator. -class CostBatchParallel { - public: - virtual StrategyRec GetOptimalStr(const Graph::NodeType &node); - - virtual double GetMaxCostIn() const { return DOUBLE_MAX; } - - protected: - virtual StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_ = 0; -}; // class BatchParallel is used to compute the cost of BatchParallel operator. - -class CostBatchNorm : public CostBatchParallel {}; -class CostOneHot : public CostBatchParallel {}; -class CostPRelu : public CostBatchParallel {}; -class CostSoftmax : public CostBatchParallel {}; - -class CostSoftmaxCrossEntropyWithLogits : public CostBatchParallel { - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); -}; -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_AUTO_PARALLEL_REC_COST_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc deleted file mode 100644 index 9de71231c0..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ /dev/null @@ -1,838 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/rec_core/rec_generate_strategy.h" - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/rec_core/rec_parse_graph.h" -#include "parallel/auto_parallel/rec_core/rec_partition.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, - const std::shared_ptr>> &eli_list, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(eli_list); - MS_EXCEPTION_IF_NULL(index_list); - GeneratePartitionedOperatorStrategy(graph, ops, index_list); - std::shared_ptr> no_stra_op_list(new std::vector); - for (size_t i = 0; i < eli_list->size(); i++) { - no_stra_op_list->push_back(eli_list->at(i)[0]); - } - GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); - GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); - GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); -} - -std::vector> PrepareMatMul(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies; - auto attrs = ops[iter_ops]->attrs(); - bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); - bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); - - // HCCL does not support multi-dimension partition, and the hardware does not support excessive - // number of EVENT, so we temporarily disable matmul's multi-dimension partition function. - const auto max_cut = 1.0 / g_device_manager->DeviceNum(); - if (graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h != max_cut && - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w != max_cut) { - graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = 1.0; - graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_w = 1.0; - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_h = 1.0; - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; - - auto shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; - if (transpose_a) { - shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[1]; - } - auto shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[1]; - if (transpose_b) { - shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[0]; - } - - bool already_cut = false; - if (shape_1 >= shape_4) { - if (shape_1 % g_device_manager->DeviceNum() == 0) { - graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut; - already_cut = true; - } - if (!already_cut && shape_4 % g_device_manager->DeviceNum() == 0) { - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut; - already_cut = true; - } - } else { - if (shape_4 % g_device_manager->DeviceNum() == 0) { - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut; - already_cut = true; - } - if (!already_cut && shape_1 % g_device_manager->DeviceNum() == 0) { - graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut; - already_cut = true; - } - } - - if (!already_cut) { - MS_LOG(EXCEPTION) << "Failure: MatMul's shape is invalid."; - } - } - - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - std::vector s; - if (transpose_a && (iter_op_inputs == 0)) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - } else if (transpose_b && (iter_op_inputs == 1)) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - } else { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - } - strategies.push_back(s); - } - return strategies; -} - -std::vector> PrepareBiasAdd(const std::shared_ptr> &s) { - std::vector> strategies; - strategies.push_back(*s); - std::vector s_biasadd; - s_biasadd.push_back(s->at(1)); - strategies.push_back(s_biasadd); - return strategies; -} - -std::vector> PrepareOneHot(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); - - int32_t axis = -1; - auto iter = ops[iter_ops]->attrs().find(AXIS); - if (iter != ops[iter_ops]->attrs().end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - axis = iter->second->cast()->value(); - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int."; - } - } - if (axis == -1) { - strategies[0][0] = strategies[0][1]; - strategies[0][1] = 1; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; - } - - std::vector s_empty = {}; - strategies.push_back(s_empty); - strategies.push_back(s_empty); - return strategies; -} - -std::vector> PrepareGatherV2(const std::vector> &ops, - const size_t iter_ops, std::vector s) { - std::vector> strategies; - - int32_t axis = 0; - auto axis_input = GetValue(ops[iter_ops]->input_value().at(2)); - if (axis_input < 0) { - axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); - } - axis = axis_input; - if (axis >= SizeToInt(s.size())) { - MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; - } - s[axis] = 1; - strategies.push_back(s); - - auto pos = ops[iter_ops]->name().find("Info"); - auto name = ops[iter_ops]->name().substr(0, pos); - if (name == "GatherV2") { - return strategies; - } - - std::vector s_indices; - for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { - s_indices.push_back(1); - } - strategies.push_back(s_indices); - - return strategies; -} - -std::vector> PrepareL2Normalize(const std::vector> &ops, - const size_t iter_ops, std::vector s) { - int32_t axis = 0; - auto iter = ops[iter_ops]->attrs().find(AXIS); - if (iter != ops[iter_ops]->attrs().end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - axis = iter->second->cast()->value(); - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " : The value of axis is not int."; - } - } - - int32_t axis_index = axis; - if (axis < 0) { - size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); - axis_index = static_cast(input_dim) + axis; - } - - s[IntToSize(axis_index)] = 1; - - std::vector> strategies; - strategies.push_back(s); - return strategies; -} - -std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - if (ops.empty()) { - MS_LOG(EXCEPTION) << "Failure: Operators is empty."; - } - if (iter_ops >= ops.size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - - StrategyPtr origin_strategy = ops[iter_ops]->strategy(); - std::vector> strategies; - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { - MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; - } - - size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); - std::vector s; - if (output_size == 4) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - } else if (output_size == 2) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - } else if (output_size == 1) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - } else if (output_size == 0) { - s = {}; - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted."; - } - strategies.push_back(s); - } - return strategies; -} - -std::vector> MakeDataParallelStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - if (ops.empty()) { - MS_LOG(EXCEPTION) << "Failure: Operators is empty."; - } - if (iter_ops >= ops.size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - - StrategyPtr origin_strategy = ops[iter_ops]->strategy(); - std::vector> strategies; - size_t max_device_num = g_device_manager->DeviceNum(); - size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { - MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; - } - - std::vector s; - size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); - for (size_t dim = 0; dim < input_size; dim++) { - if (input_size == 1 || input_size == 2 || input_size == 4) { - if (dim == 0) { - s.push_back(std::min(max_device_num, target_tensor_batch)); - } else { - s.push_back(1); - } - } else if (input_size == 0) { - s = {}; - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; - } - } - strategies.push_back(s); - } - - graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; - if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { - graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch); - } - - return strategies; -} - -std::vector> PrepareStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - if (ops.empty()) { - MS_LOG(EXCEPTION) << "Failure: Operators is empty."; - } - if (iter_ops >= ops.size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - MS_EXCEPTION_IF_NULL(ops[iter_ops]); - - auto type = ops[iter_ops]->type(); - auto idx = DictOpType.find(type); - if (idx == DictOpType.end()) { - return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); - } - - if (type == MATMUL) { - return PrepareMatMul(graph, ops, iter_graph, iter_ops); - } else if (type == ONEHOT) { - return PrepareOneHot(graph, ops, iter_graph, iter_ops); - } else { - return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); - } -} - -void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const std::shared_ptr> &index_list) { - for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { - std::vector> strategies; - size_t iter_graph = index_list->at(iter_ops); - if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) { - strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); - } - StrategyPtr sp = std::make_shared(0, strategies); - ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); - } -} - -size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, - const size_t iter_ops) { - size_t incoming_op_index = SIZE_MAX; - for (size_t i = 1; i < input_tensor_names[iter_ops].size(); i++) { - for (size_t j = 0; j < input_tensor_names.size(); j++) { - if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) { - incoming_op_index = j; - break; - } - } - if (incoming_op_index != SIZE_MAX) { - break; - } - } - return incoming_op_index; -} - -std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_ops, const size_t iter_graph) { - std::vector s; - for (auto input : ops[iter_ops]->inputs_tensor_info()) { - auto input_stra_dim = input.shape().size(); - if (input_stra_dim == 0) { - continue; - } - if (input_stra_dim == 1) { - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); - } else if (input_stra_dim == 2) { - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); - } else if (input_stra_dim == 4) { - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n); - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; - } - break; - } - return s; -} - -std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t incoming_op_index) { - std::vector s; - if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || - ops[incoming_op_index]->type() == TRANSPOSE) { - return s; - } - auto strategy = ops[incoming_op_index]->selected_strategy(); - if (strategy->GetInputNumber() == 0) { - return s; - } - - for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) { - if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) { - continue; - } - for (size_t j = 0; j < ops[incoming_op_index]->inputs_tensor_info()[i].shape().size(); ++j) { - s.push_back(strategy->GetInputDim()[i][j]); - } - break; - } - return s; -} - -std::vector GetAxisList(const std::vector> &ops, const int iter_ops) { - std::vector axis_list; - auto axis_param = ops[iter_ops]->attrs().find(AXIS)->second; - std::vector elements; - if (axis_param->isa()) { - elements = axis_param->cast()->value(); - } else if (axis_param->isa()) { - elements = axis_param->cast()->value(); - } else { - MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl; - } - - for (auto &element : elements) { - if (!element->isa()) { - MS_LOG(EXCEPTION) << "Failure: Dimension indexes is not Int32." << std::endl; - } - auto axis = element->cast()->value(); - axis_list.push_back(axis); - } - return axis_list; -} - -std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { - std::vector s_Squeeze; - std::vector stra_dim_list; - for (size_t i = 0; i < s.size(); i++) { - stra_dim_list.push_back(i); - } - - auto axis_list = GetAxisList(ops, incoming_op_index); - for (auto axis : axis_list) { - auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis); - if (it == stra_dim_list.end()) { - MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; - } - if (ops[incoming_op_index]->inputs_tensor_info()[0].shape()[axis] != 1) { - MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1." << std::endl; - } - stra_dim_list.erase(it); - } - - for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) { - s_Squeeze.push_back(s[stra_dim_list[i]]); - } - return s_Squeeze; -} - -bool GetKeepDims(const std::vector> &ops, const size_t iter_ops) { - bool keepdims = false; - auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS); - if (keep_dims_iter == ops[iter_ops]->attrs().end()) { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims."; - } - MS_EXCEPTION_IF_NULL(keep_dims_iter->second); - if (!keep_dims_iter->second->isa()) { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool."; - } - keepdims = keep_dims_iter->second->cast()->value(); - return keepdims; -} - -std::vector GetDimList(const std::vector> &ops, const size_t iter_ops) { - std::vector dim_list; - bool keep_dims = GetKeepDims(ops, iter_ops); - if (keep_dims != false) { - return dim_list; - } - auto input_value = ops[iter_ops]->input_value(); - auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); - if (input_value.back()->isa()) { - auto attr_axis = GetValue>(input_value.back()); - if (attr_axis.empty()) { - MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl; - } - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } - } else if (input_value.back()->isa()) { - int axis = GetValue(input_value.back()); - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl; - } - return dim_list; -} - -std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { - std::vector s_Reduce; - std::vector axis_list; - for (size_t i = 0; i < s.size(); i++) { - axis_list.push_back(i); - } - - auto dim_list = GetDimList(ops, incoming_op_index); - for (auto axis : dim_list) { - auto it = find(axis_list.begin(), axis_list.end(), axis); - if (it == axis_list.end()) { - MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; - } - axis_list.erase(it); - } - - for (size_t i = 0; i < (size_t)axis_list.size(); i++) { - s_Reduce.push_back(s[axis_list[i]]); - } - return s_Reduce; -} - -std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops) { - std::vector dim_list; - auto iter = ops[iter_ops]->attrs().find(AXIS); - if (iter == ops[iter_ops]->attrs().end()) { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis."; - } - auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - auto attr_axis = GetValue>(iter->second); - if (attr_axis.empty()) { - for (size_t i = 0; i < input_dim; ++i) { - dim_list.push_back(SizeToInt(i)); - } - } else { - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } - } - } else if (iter->second->isa()) { - int axis = GetValue(iter->second); - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Axis type is invalid."; - } - return dim_list; -} - -std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { - bool keepdims = GetKeepDims(ops, incoming_op_index); - if (keepdims) { - return s; - } - - std::vector s_Arg; - std::vector axis_list; - for (size_t i = 0; i < s.size(); i++) { - axis_list.push_back(i); - } - - auto dim_list = GetDimListFromAttrs(ops, incoming_op_index); - for (auto axis : dim_list) { - auto it = find(axis_list.begin(), axis_list.end(), axis); - if (it == axis_list.end()) { - MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; - } - axis_list.erase(it); - } - - for (size_t i = 0; i < (size_t)axis_list.size(); i++) { - s_Arg.push_back(s[axis_list[i]]); - } - return s_Arg; -} - -std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t iter_ops, const size_t incoming_op_index) { - std::vector s; - s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index); - if (s.size() != 0) { - if (ops[incoming_op_index]->type() == SQUEEZE) { - s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s); - } - if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX || - ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { - s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); - } - if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) { - s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s); - } - } - return s; -} - -std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, - const size_t iter_ops, - std::vector basic_stra) { - std::vector s_empty = {}; - std::vector> stra; - MS_EXCEPTION_IF_NULL(ops[iter_ops]); - - if (basic_stra.size() == 0) { - for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); - iter_op_inputs++) { - stra.push_back(basic_stra); - } - return stra; - } - - auto s_ptr = std::make_shared>(basic_stra); - if (ops[iter_ops]->type() == BIAS_ADD) { - return PrepareBiasAdd(s_ptr); - } - if (ops[iter_ops]->type() == GATHERV2) { - return PrepareGatherV2(ops, iter_ops, basic_stra); - } - if (ops[iter_ops]->type() == L2_NORMALIZE) { - return PrepareL2Normalize(ops, iter_ops, basic_stra); - } - - for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); - iter_op_inputs++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) { - stra.push_back(s_empty); - continue; - } - - std::vector tmp_stra = basic_stra; - bool modified = false; - for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) { - tmp_stra[j] = 1; - modified = true; - } - } - if (modified) { - stra.push_back(tmp_stra); - } else { - stra.push_back(basic_stra); - } - } - return stra; -} - -void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, - const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list, - const std::shared_ptr> &no_stra_op_list) { - if (no_stra_op_list->size() == 0) { - return; - } - std::vector no_stra_op_list_bis; - - for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { - size_t iter_ops = no_stra_op_list->at(iter_list - 1); - std::vector> stra; - std::vector s; - size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops); - if (incoming_op_index != SIZE_MAX) { - auto iter_graph = index_list->at(incoming_op_index); - if (iter_graph != SIZE_MAX) { - s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph); - } else { - s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index); - } - } - - if (s.size() == 0) { - no_stra_op_list_bis.push_back(iter_ops); - } else { - stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); - } - - StrategyPtr sp = std::make_shared(0, stra); - ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); - } - - no_stra_op_list->clear(); - for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) { - no_stra_op_list->push_back(no_stra_op_list_bis[i]); - } -} - -std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, - const size_t iter_ops, std::vector s) { - std::vector s_Squeeze; - auto axis_list = GetAxisList(ops, iter_ops); - size_t s_index = 0; - size_t axis_list_index = 0; - for (size_t i = 0; i < (size_t)(s.size() + axis_list.size()); i++) { - if (i == (size_t)axis_list[axis_list_index]) { - s_Squeeze.push_back(1); - axis_list_index++; - } else { - s_Squeeze.push_back(s[s_index]); - s_index++; - } - } - - size_t cut = 1; - for (size_t i = 0; i < s_Squeeze.size(); i++) { - cut *= s_Squeeze[i]; - } - if (cut != g_device_manager->DeviceNum()) { - s_Squeeze.clear(); - } - - return s_Squeeze; -} - -std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, - const std::vector> &input_tensor_names, - const size_t iter_ops) { - std::vector s; - if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || - ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || - ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE || - ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) { - return s; - } - - bool found = false; - size_t outgoing_op_index = SIZE_MAX; - size_t iter_op_inputs = SIZE_MAX; - for (size_t i = 0; i < input_tensor_names.size(); i++) { - for (size_t j = 1; j < input_tensor_names[i].size(); j++) { - if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] && - ops[i]->selected_strategy()->GetInputNumber() != 0) { - outgoing_op_index = i; - iter_op_inputs = j - 1; - found = true; - break; - } - } - if (found) { - break; - } - } - - if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) { - for (size_t k = 0; k < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); ++k) { - s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]); - } - } - return s; -} - -void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &no_stra_op_list) { - if (no_stra_op_list->size() == 0) { - return; - } - std::vector no_stra_op_list_bis; - - for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { - auto iter_ops = no_stra_op_list->at(iter_list - 1); - std::vector> stra; - std::vector s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops); - - if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) { - s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s); - } - if (s.size() != 0) { - stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); - } else { - no_stra_op_list_bis.push_back(iter_ops); - } - - StrategyPtr sp = std::make_shared(0, stra); - ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); - } - - no_stra_op_list->clear(); - for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) { - no_stra_op_list->push_back(no_stra_op_list_bis[i]); - } -} - -void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list, - const std::shared_ptr> &no_stra_op_list) { - if (no_stra_op_list->size() == 0) { - return; - } - - size_t no_stra_op_list_size = no_stra_op_list->size(); - do { - no_stra_op_list_size = no_stra_op_list->size(); - GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); - GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); - } while (no_stra_op_list_size > no_stra_op_list->size()); - - for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) { - auto iter_ops = no_stra_op_list->at(iter_list); - std::vector> stra; - std::vector s; - - size_t max_dim_num = 0; - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() > max_dim_num) { - max_dim_num = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); - } - } - for (size_t i = 0; i < max_dim_num; i++) { - s.push_back(1); - } - - stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); - StrategyPtr sp = std::make_shared(0, stra); - ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); - } -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h deleted file mode 100644 index e82efe6798..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ -#define PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ - -#include -#include -#include -#include - -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/ops_info/operator_info.h" - -namespace mindspore { -namespace parallel { -void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, - const std::shared_ptr>> &eli_list, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list); -std::vector> PrepareMatMul(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareBiasAdd(const std::shared_ptr> &s); -std::vector> PrepareOneHot(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareGatherV2(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector> PrepareL2Normalize(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> MakeDataParallelStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const std::shared_ptr> &index_list); -size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, - const size_t iter_ops); -std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_ops, const size_t iter_graph); -std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t incoming_op_index); -std::vector GetAxisList(const std::vector> &ops, const int iter_ops); -std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); -bool GetKeepDims(const std::vector> &ops, const size_t iter_ops); -std::vector GetDimList(const std::vector> &ops, const size_t iter_ops); -std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); -std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops); -std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); -std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t iter_ops, const size_t incoming_op_index); -std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, - const size_t iter_ops, - std::vector basic_stra); -void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, - const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list, - const std::shared_ptr> &no_stra_op_list); -std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, - const std::vector> &input_tensor_names, - const size_t iter_ops); -void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &no_stra_op_list); -void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list, - const std::shared_ptr> &no_stra_op_list); -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h deleted file mode 100644 index 9007218d15..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ -#define PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ - -#include -#include -#include - -#include "parallel/auto_parallel/rec_core/rec_strategy.h" -#include "parallel/auto_parallel/rec_core/rec_tensor.h" - -namespace mindspore { -namespace parallel { -enum OperatorType { - kRecUnkownType, - kRecMatMul, - kRecConvolution, - kRecPooling, - kRecElmWiseOp, - kRecReLU, - kRecBatchNorm, - kRecReshape, - kRecBiasAdd, - kRecSoftmax, - kRecSparseSoftmaxCrossEntropyWithLogits, - kRecSoftmaxCrossEntropyWithLogits, - kRecOneHot, - kRecLog, - kRecExp, - kRecAdd, - kRecSub, - kRecMul, - kRecDiv, - kRecSqueeze, - kRecCast, - kRecReduce, - kRecPReLU, - kRecGatherV2, - kRecArgWithValue -}; - -enum InfoType { kApplication, kConstant }; - -struct OperatorRec { - OperatorType op_type; - TensorParam arguments[MAX_INPUT_NUM]; - StrategyRec str; -}; - -// Define simplified dataflow Graph for partitioning -class Graph { - public: - struct NodeType { - std::string name; - // Nodes that point to this node - std::vector node_in; - // Nodes that point from this node - std::vector node_out; - std::vector node_in_aux; - // Node Type Info: Application or Constant. Defined in enum . - InfoType info; - // Operator info. Defined in struct . - OperatorRec apply; - // Tensor info. Defined in tensor.h struct . - TensorParam tensor_parm; - }; - - std::vector nodes; // Nodes of the graph. Pubic. -}; // Define simplified dataflow Graph for partitioning -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc deleted file mode 100644 index c0412e9108..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ /dev/null @@ -1,274 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/rec_core/rec_parse_graph.h" - -#include -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/auto_parallel/rec_core/rec_tensor.h" -#include "parallel/ops_info/operator_info.h" - -namespace mindspore { -namespace parallel { -const TensorParam MakeTensor(int n, int c, int h, int w) { - TensorParam new_tensor; - new_tensor.tensor_type = kFloat32; - new_tensor.tensor_shape.shape_n = n; - new_tensor.tensor_shape.shape_c = c; - new_tensor.tensor_shape.shape_h = h; - new_tensor.tensor_shape.shape_w = w; - const TensorParam &tensor = new_tensor; - return tensor; -} - -Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops) { - Graph::NodeType NewOp; - NewOp.name = ops[iter_ops]->name(); - NewOp.info = InfoType::kApplication; - - auto op_type = ops[iter_ops]->type(); - auto idx = DictOpType.find(op_type); - if (idx == DictOpType.end()) { - NewOp.apply.op_type = OperatorType::kRecUnkownType; - MS_LOG(INFO) << "Unknown operator type."; - } else { - NewOp.apply.op_type = DictOpType.at(op_type); - } - - if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { - NewOp.tensor_parm = MakeTensor( - ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], - ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { - NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0], - ops[iter_ops]->outputs_tensor_info()[0].shape()[1]); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { - NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0]); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) { - NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(ERROR) << "Tensor's shape is unknown."; - } - - NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp); - return NewOp; -} - -OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, - Graph::NodeType NewTensor) { - for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); - iter_input_tensors++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { - NewTensor.apply.arguments[iter_input_tensors] = - MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { - NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { - NewTensor.apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) { - NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(ERROR) << "Tensor's shape is unknown."; - } - } - return NewTensor.apply; -} - -TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, - const size_t iter_input_tensors, Graph::NodeType NewTensor) { - if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { - auto attrs = ops[iter_ops]->attrs(); - bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); - bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); - if (transpose_a && (iter_input_tensors == 0)) { - NewTensor.apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else if (transpose_b && (iter_input_tensors == 1)) { - NewTensor.apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else { - NewTensor.apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); - } - } else { - NewTensor.apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); - } - return NewTensor.apply.arguments[iter_input_tensors]; -} - -std::shared_ptr ParseGraph(const std::vector> &ops, - const std::vector> &input_tensor_names) { - std::shared_ptr graph(new Graph); - if (ops.size() > SIZE_MAX / 2) { - MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2; - } - - for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) { - Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops); - graph->nodes.push_back(NewOp); - } - MakeEdge(input_tensor_names, graph); - - return graph; -} - -void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph) { - for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { - for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { - size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); - if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) { - graph->nodes[iter_i].node_in.push_back(head_node_index); - graph->nodes[head_node_index].node_out.push_back(iter_i); - } - } - } -} - -size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_name, - const std::string &input_name) { - for (size_t index = 0; index < input_tensor_name.size(); index++) { - if (input_tensor_name[index][0] == input_name) { - return index; - } - } - MS_LOG(INFO) << "Get index failed, using SIZE_MAX insted"; - return SIZE_MAX; -} - -void Eliminate_Aux(const size_t node_index, const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list) { - std::vector eli; - eli.push_back(node_index); - for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) { - eli.push_back(graph->nodes[node_index].node_out[i]); - } - eli_list->push_back(eli); - - for (size_t i = 0; i < graph->nodes[node_index].node_in.size(); i++) { - auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out; - auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index); - if (it != incoming_outputs->end()) { - it = incoming_outputs->erase(it); - incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end()); - } - } - - for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) { - auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out; - auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index); - if (it != aux_incoming_outputs->end()) { - it = aux_incoming_outputs->erase(it); - aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), - graph->nodes[node_index].node_out.end()); - } - } - - for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { - auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in; - auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index); - if (it != outgoing_inputs->end()) { - if (graph->nodes[node_index].node_in.size() > 0) { - outgoing_inputs->at(std::distance(outgoing_inputs->begin(), it)) = graph->nodes[node_index].node_in[0]; - for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]); - } - for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) { - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back( - graph->nodes[node_index].node_in_aux[j]); - } - } else { - outgoing_inputs->erase(it); - } - } - } -} - -std::shared_ptr EliminateGraph(const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list, - const std::shared_ptr> &index_list) { - MS_EXCEPTION_IF_NULL(graph); - static const std::set elementwise_type = { - OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, - OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, - OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, - OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue}; - for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { - auto type = graph->nodes[node_index].apply.op_type; - if (elementwise_type.find(type) != elementwise_type.end()) { - Eliminate_Aux(node_index, graph, eli_list); - } - } - - index_list->reserve(graph->nodes.size()); - for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { - index_list->push_back(i); - } - - for (size_t i = 0; i < (size_t)eli_list->size(); i++) { - if (eli_list->at(i)[0] >= index_list->size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - index_list->at(eli_list->at(i)[0]) = SIZE_MAX; - for (size_t j = eli_list->at(i)[0] + 1; j < (size_t)index_list->size(); j++) { - index_list->at(j)--; - } - } - - std::shared_ptr new_graph(new Graph); - for (size_t i = 0; i < graph->nodes.size(); i++) { - if (index_list->at(i) > SIZE_MAX / 2) { - continue; - } - - new_graph->nodes.push_back(graph->nodes[i]); - auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; - for (size_t j = node_in->size(); j > 0; j--) { - bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX); - if (IsEliminated) { - node_in->erase(node_in->begin() + j - 1); - } else { - node_in->at(j - 1) = index_list->at(node_in->at(j - 1)); - } - } - auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; - for (size_t j = node_out->size(); j > 0; j--) { - bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX); - if (IsEliminated) { - node_out->erase(node_out->begin() + j - 1); - } else { - node_out->at(j - 1) = index_list->at(node_out->at(j - 1)); - } - } - } - return new_graph; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h deleted file mode 100644 index 66fc82b8ce..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ -#define PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ - -#include -#include -#include -#include -#include - -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/ops_info/operator_info.h" - -namespace mindspore { -namespace parallel { -const std::map DictOpType{ - {MATMUL, OperatorType::kRecMatMul}, - {CONV2D, OperatorType::kRecConvolution}, - {MAXPOOL, OperatorType::kRecPooling}, - {MAXPOOLV2, OperatorType::kRecPooling}, - {SIMPLE_MEAN, OperatorType::kRecPooling}, - {RESHAPE, OperatorType::kRecReshape}, - {BIAS_ADD, OperatorType::kRecBiasAdd}, - {BATCH_NORM, OperatorType::kRecBatchNorm}, - {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm}, - {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits}, - {ONEHOT, OperatorType::kRecOneHot}, - {SQUEEZE, OperatorType::kRecSqueeze}, - {CAST, OperatorType::kRecCast}, - {REDUCE_SUM, OperatorType::kRecReduce}, - {REDUCE_MAX, OperatorType::kRecReduce}, - {REDUCE_MIN, OperatorType::kRecReduce}, - {REDUCE_MEAN, OperatorType::kRecReduce}, - {GATHERV2, OperatorType::kRecGatherV2}, - {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue}, - {ARGMINWITHVALUE, OperatorType::kRecArgWithValue}, - - {RELU, OperatorType::kRecReLU}, - {"ReLU6", OperatorType::kRecReLU}, - {"ReLUV2", OperatorType::kRecReLU}, - {SIGMOID, OperatorType::kRecReLU}, - {SIGMOID_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecReLU}, - {"HSigmoid", OperatorType::kRecReLU}, - {GELU, OperatorType::kRecReLU}, - {TANH, OperatorType::kRecReLU}, - - {PRELU, OperatorType::kRecPReLU}, - - {TRANSPOSE, OperatorType::kRecElmWiseOp}, - {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, - {TENSOR_ADD, OperatorType::kRecElmWiseOp}, - {SUB, OperatorType::kRecElmWiseOp}, - {MUL, OperatorType::kRecElmWiseOp}, - {DIV, OperatorType::kRecElmWiseOp}, - {REAL_DIV, OperatorType::kRecElmWiseOp}, - {SOFTMAX, OperatorType::kRecSoftmax}, - {LOG_SOFTMAX, OperatorType::kRecSoftmax}, - {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits}, - {SQRT, OperatorType::kRecElmWiseOp}, - {NEG, OperatorType::kRecElmWiseOp}, - {POW, OperatorType::kRecElmWiseOp}, - {EXP, OperatorType::kRecElmWiseOp}, - {LOG, OperatorType::kRecElmWiseOp}, - {COS, OperatorType::kRecElmWiseOp}, - {ACOS, OperatorType::kRecElmWiseOp}, - {LOGICALNOT, OperatorType::kRecElmWiseOp}, - {"LogicalAnd", OperatorType::kRecElmWiseOp}, - {"LogicalOr", OperatorType::kRecElmWiseOp}, - {SQUARE, OperatorType::kRecElmWiseOp}, - {"Abs", OperatorType::kRecElmWiseOp}, - {"Acosh", OperatorType::kRecElmWiseOp}, - {"AddN", OperatorType::kRecElmWiseOp}, - {"AccumulateNV2", OperatorType::kRecElmWiseOp}, - {"Atan2", OperatorType::kRecElmWiseOp}, - {"Erf", OperatorType::kRecElmWiseOp}, - {"Floor", OperatorType::kRecElmWiseOp}, - {FLOORDIV, OperatorType::kRecElmWiseOp}, - {"FloorMod", OperatorType::kRecElmWiseOp}, - {GREATER, OperatorType::kRecElmWiseOp}, - {"GreaterEqual", OperatorType::kRecElmWiseOp}, - {"HSwish", OperatorType::kRecElmWiseOp}, - {"Less", OperatorType::kRecElmWiseOp}, - {"LessEqual", OperatorType::kRecElmWiseOp}, - {MAXIMUM, OperatorType::kRecElmWiseOp}, - {MINIMUM, OperatorType::kRecElmWiseOp}, - {EQUAL, OperatorType::kRecElmWiseOp}, - {NOT_EQUAL, OperatorType::kRecElmWiseOp}, - {"Reciprocal", OperatorType::kRecElmWiseOp}, - {"Round", OperatorType::kRecElmWiseOp}, - {"Rsqrt", OperatorType::kRecElmWiseOp}, - {"Sign", OperatorType::kRecElmWiseOp}, - {"Sin", OperatorType::kRecElmWiseOp}, - {ASSIGN, OperatorType::kRecElmWiseOp}, - {ASSIGN_SUB, OperatorType::kRecElmWiseOp}, - {"AssignAdd", OperatorType::kRecElmWiseOp}}; - -const TensorParam MakeTensor(int n, int c, int h, int w); - -Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops); - -OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, - Graph::NodeType NewTensor); - -TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, - const size_t iter_input_tensor, Graph::NodeType NewTensor); - -std::shared_ptr ParseGraph(const std::vector> &ops, - const std::vector> &input_tensor_names); - -void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph); - -size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_names, - const std::string &input_name); - -void Eliminate_Aux(const size_t node_index, const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list); - -std::shared_ptr EliminateGraph(const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list, - const std::shared_ptr> &index_list); -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc deleted file mode 100644 index d5200f54d8..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc +++ /dev/null @@ -1,310 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/rec_core/rec_partition.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -// Get the target node's weight for sorting. -double GetWeights(const Graph::NodeType &node) { - const OperatorRec &op = node.apply; - - if (op.op_type == OperatorType::kRecMatMul) { - // For MatMul - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(op); - } else if (op.op_type == OperatorType::kRecConvolution) { - // For Convolution - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(node); - } else if (op.op_type == OperatorType::kRecPooling) { - // For Pooling - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecElmWiseOp) { - // For TensorAdd - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecReLU) { - // For Activation - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecReshape) { - // For Reshape - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecBiasAdd) { - // For BiasAdd - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecLog || op.op_type == OperatorType::kRecExp || - op.op_type == OperatorType::kRecAdd || op.op_type == OperatorType::kRecSub || - op.op_type == OperatorType::kRecMul || op.op_type == OperatorType::kRecDiv || - op.op_type == OperatorType::kRecSqueeze || op.op_type == OperatorType::kRecCast) { - // For element-wise op - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot || - op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax || - op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits || - op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { - // For BatchParallel op - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMaxCostIn(); - } else if (op.op_type == OperatorType::kRecUnkownType) { - // For Unkown type - return 0.0; - } else { - MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed."; - } -} - -// Sort all the nodes by their weights -std::vector SortByWeight(const std::shared_ptr &graph) { - MS_EXCEPTION_IF_NULL(graph); - - std::vector> weight_to_node_index; - std::vector node_index_by_weights; - - // Get node's weight. - for (size_t i = 0; i < graph->nodes.size(); i++) { - if (graph->nodes[i].info == kApplication) { - const Graph::NodeType &node_ptr = graph->nodes[i]; - double weight = GetWeights(node_ptr); - size_t index = i; - weight_to_node_index.push_back(std::make_pair(weight, index)); - } - } - - // Ordering ops aka nodes of the graph - std::sort(weight_to_node_index.begin(), weight_to_node_index.end()); - - // Store the result in node_index_by_weights. - uint64_t size = weight_to_node_index.size(); - for (uint64_t i = 1; i <= size; i++) { - node_index_by_weights.push_back(weight_to_node_index[size - i].second); - } - - return node_index_by_weights; -} - -// Get optimal strategy to partition the target node -StrategyRec PartitionNode(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const std::shared_ptr &graph) { - bool enable_conv_chw_partition = false; - MS_EXCEPTION_IF_NULL(graph); - - if (node.apply.op_type == OperatorType::kRecMatMul) { - // For MatMul - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecConvolution) { - // For Convolution - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, enable_conv_chw_partition); - } else if (node.apply.op_type == OperatorType::kRecPooling) { - // For Pooling - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecElmWiseOp) { - // For TensorAdd - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecReLU) { - // For Activation - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecReshape) { - // For Reshape - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node); - } else if (node.apply.op_type == OperatorType::kRecBiasAdd) { - // For BiasAdd - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecLog || node.apply.op_type == OperatorType::kRecExp || - node.apply.op_type == OperatorType::kRecAdd || node.apply.op_type == OperatorType::kRecSub || - node.apply.op_type == OperatorType::kRecMul || node.apply.op_type == OperatorType::kRecDiv || - node.apply.op_type == OperatorType::kRecSqueeze || node.apply.op_type == OperatorType::kRecCast) { - // For element-wise op - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot || - node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax || - node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) { - // For BatchParallel type - auto cost_ptr = std::make_shared(); - return cost_ptr->GetOptimalStr(node); - } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { - // For SoftmaxCrossEntropyWithLogits type - auto cost_ptr = std::make_shared(); - return cost_ptr->GetOptimalStr(node); - } else if (node.apply.op_type == OperatorType::kRecUnkownType) { - // For Unkown type - StrategyRec default_strategy; - return default_strategy; - } else { - MS_LOG(EXCEPTION) << "Failure: Partition Operator failed."; - } -} - -// Parttion graph into all devices. -Status PartitionForAllDevices(const size_t num_device, const double device_memory, - const std::shared_ptr &graph) { - if (num_device < 1) { - MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << "."; - } - - if (num_device > 1024) { - MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be larger than 1024."; - } - - MS_EXCEPTION_IF_NULL(graph); - - // Comopute iter times - int iter_times = static_cast(log2(num_device)); - - // N-cuts loop - for (int loop = 0; loop < iter_times; loop++) { - // Sort by weights - std::vector reorder_node_list = SortByWeight(graph); - - // get total node number - size_t iter_nodes = reorder_node_list.size(); - - // temp vector to map nodename to its strategy. - std::vector> node_name_to_strategy; - - // Loop for all the nodes - for (size_t i_node = 0; i_node < iter_nodes; i_node++) { - // get current node's index - size_t index = reorder_node_list[i_node]; - - Graph::NodeType &node_ptr = graph->nodes[index]; - - // Serch optimal strategy to cut this operator. And store the result optimal strategy in graph. - graph->nodes[index].apply.str = PartitionNode(node_ptr, node_name_to_strategy, graph); - - // Apply OP Strategy to Tensor Strategy. - graph->nodes[index] = ApplyStrToTensor(node_ptr); - - // Note down the node name and its strategy in this loop. - auto node_name_to_str = - std::pair(graph->nodes[index].name, graph->nodes[index].apply.str); - node_name_to_strategy.push_back(node_name_to_str); - } - } - - if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) { - return FAILED; - } else { - return SUCCESS; - } -} - -// Apply OP Strategy to Tensor Strategy -Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) { - // Set Node's tensor_parm - Node.tensor_parm.tensor_str.str_n = Node.apply.str.outputTensor.str_n; - Node.tensor_parm.tensor_str.str_c = Node.apply.str.outputTensor.str_c; - Node.tensor_parm.tensor_str.str_h = Node.apply.str.outputTensor.str_h; - Node.tensor_parm.tensor_str.str_w = Node.apply.str.outputTensor.str_w; - - // Set input tensors' tersor_parm - for (int i = 0; i < 2; i++) { - Node.apply.arguments[i].tensor_str.str_n = Node.apply.str.inputTensor[i].str_n; - Node.apply.arguments[i].tensor_str.str_c = Node.apply.str.inputTensor[i].str_c; - Node.apply.arguments[i].tensor_str.str_h = Node.apply.str.inputTensor[i].str_h; - Node.apply.arguments[i].tensor_str.str_w = Node.apply.str.inputTensor[i].str_w; - } - return Node; -} - -Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr &graph) { - MS_EXCEPTION_IF_NULL(graph); - if (num_device == 0) { - MS_LOG(EXCEPTION) << "Failure: device number is 0."; - } - - uint64_t iter_nodes = graph->nodes.size(); - double used_memory = 0.0; - - for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) { - if (graph->nodes[i_node].info == 0) { - Graph::NodeType &Node = graph->nodes[i_node]; - for (int index = 0; index < 2; index++) { - used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n * - Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c * - Node.apply.arguments[index].tensor_str.str_h * Node.apply.arguments[index].tensor_shape.shape_h * - Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w * - GetDataTypeSize(Node.apply.arguments[index].tensor_type); - } - } - } - - if (device_memory < (used_memory / num_device)) { - MS_LOG(EXCEPTION) << "Failure: Out of memory!"; - return FAILED; - } else { - return SUCCESS; - } -} - -size_t GetDataTypeSize(const TensorType &type) { - switch (type) { - case kInt8: - return sizeof(int); - case kFloat16: - return sizeof(float) / 2; - case kFloat32: - return sizeof(float); - case kDouble64: - return sizeof(double); - default: - MS_LOG(EXCEPTION) << "GetDataTypeSize Failed. Unexpected type"; - } -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h deleted file mode 100644 index c98f3317f8..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_ -#define PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "parallel/auto_parallel/rec_core/rec_cost.h" -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/auto_parallel/rec_core/rec_strategy.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -std::vector SortByWeight(const std::shared_ptr &graph); - -double GetWeights(const Graph::NodeType &node); - -StrategyRec PartitionNode(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const std::shared_ptr &graph); - -Status PartitionForAllDevices(const size_t num_device, const double device_memory, const std::shared_ptr &graph); - -Graph::NodeType ApplyStrToTensor(Graph::NodeType Node); - -Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr &graph); - -size_t GetDataTypeSize(const TensorType &type); -} // namespace parallel -} // namespace mindspore - -#endif // PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_tensor.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_tensor.h deleted file mode 100644 index 51ffca4023..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_tensor.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ -#define PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ - -#include "parallel/auto_parallel/rec_core/rec_strategy.h" - -namespace mindspore { -namespace parallel { -enum TensorType { kInt8, kFloat16, kFloat32, kDouble64 }; - -struct Shape4D { - int32_t shape_n = 1; - int32_t shape_c = 1; - int32_t shape_h = 1; - int32_t shape_w = 1; -}; - -struct TensorParam { - TensorType tensor_type = kFloat32; // default as float. - Shape4D tensor_shape; - TensorStr4D tensor_str; -}; -} // namespace parallel -} // namespace mindspore - -#endif // PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc deleted file mode 100644 index 062d814aa0..0000000000 --- a/mindspore/ccsrc/parallel/context.cc +++ /dev/null @@ -1,198 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/context.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "parallel/device_manager.h" - -namespace mindspore { -namespace parallel { -static std::map> param_shapes; - -std::vector PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, - AUTO_PARALLEL}; -std::vector STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING}; - -std::shared_ptr ParallelContext::inst_context_ = nullptr; - -std::shared_ptr ParallelContext::GetInstance() { - if (inst_context_ == nullptr) { - inst_context_.reset(new (std::nothrow) ParallelContext()); - } - return inst_context_; -} - -ParallelContext::ParallelContext() { Reset(); } - -void ParallelContext::Reset() { - mirror_mean_ = false; - full_batch_ = false; - cast_before_mirror_ = true; - loss_repeated_mean_ = true; - device_num_ = 1; - global_rank_ = 0; - communication_backend_ = HCCL_BACKEND; - device_num_is_set_ = false; - global_rank_is_set_ = false; - parallel_mode_ = STAND_ALONE; - parameter_broadcast_ = false; - parameter_broadcast_is_set_ = false; - enable_all_reduce_fusion_ = false; - strategy_ckpt_load_file_ = ""; - strategy_ckpt_save_file_ = ""; - enable_parallel_optimizer_ = false; -} - -void ParallelContext::set_device_num(int32_t device_num) { - device_num_ = device_num; - device_num_is_set_ = true; -} - -void ParallelContext::set_global_rank(int32_t global_rank) { - global_rank_ = global_rank; - global_rank_is_set_ = true; -} - -void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; } - -void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } - -void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; } - -void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } - -void ParallelContext::set_communication_backend(const std::string &communication_backend) { - communication_backend_ = communication_backend; -} - -bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) { - auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); - if (iter == PARALLEL_MODE_LIST.end()) { - MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode; - return false; - } - parallel_mode_ = parallel_mode; - return true; -} - -bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) { - auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode); - if (iter == STRATEGY_SEARCH_MODE_LIST.end()) { - MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode; - return false; - } - strategy_search_mode_ = strategy_search_mode; - return true; -} - -void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) { - parameter_broadcast_ = parameter_broadcast; - parameter_broadcast_is_set_ = true; -} - -void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) { - strategy_ckpt_load_file_ = strategy_ckpt_load_file; -} - -void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) { - strategy_ckpt_save_file_ = strategy_ckpt_save_file; -} - -void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group) { - all_reduce_fusion_split_indices_[group] = indices; -} - -const std::vector ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const { - auto iter = all_reduce_fusion_split_indices_.find(group); - if (iter != all_reduce_fusion_split_indices_.end()) { - return iter->second; - } - return {}; -} - -void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group) { - all_reduce_fusion_split_sizes_[group] = sizes; -} - -const std::vector ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const { - auto iter = all_reduce_fusion_split_sizes_.find(group); - if (iter != all_reduce_fusion_split_sizes_.end()) { - return iter->second; - } - return {}; -} - -// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextInit(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { - return; - } - param_shapes.clear(); -} - -// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - AbstractBasePtr ptr) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(param_node); - MS_EXCEPTION_IF_NULL(ptr); - if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) || - func_graph->has_flag(TRAINING)) { - return; - } - - auto iter = param_shapes.find(param_node->name()); - if (iter == param_shapes.end()) { - MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); - return; - } - std::vector shape = iter->second; - std::shared_ptr base_shape = std::make_shared(shape); - ptr->set_shape(base_shape); - MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; -} - -// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - const AbstractBasePtr &ptr) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(param_node); - MS_EXCEPTION_IF_NULL(ptr); - if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { - return; - } - - std::vector shape = dyn_cast(ptr->GetShapeTrack())->shape(); - auto ret = param_shapes.try_emplace(param_node->name(), shape); - if (!ret.second) { - MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed"; - return; - } - - MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h deleted file mode 100644 index 6a503ca7ed..0000000000 --- a/mindspore/ccsrc/parallel/context.h +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ -#define MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ - -#include -#include -#include -#include -#include - -#include "parallel/ops_info/ops_utils.h" -#include "parallel/status.h" -#include "utils/convert_utils.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "debug/info.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -namespace parallel { -constexpr char STAND_ALONE[] = "stand_alone"; -constexpr char DATA_PARALLEL[] = "data_parallel"; -constexpr char HYBRID_PARALLEL[] = "hybrid_parallel"; -constexpr char AUTO_PARALLEL[] = "auto_parallel"; -constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel"; - -constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming"; -constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; - -constexpr char TRAINING[] = "training"; - -class ParallelContext { - public: - ~ParallelContext() = default; - ParallelContext(const ParallelContext &) = delete; - ParallelContext &operator=(const ParallelContext &) = delete; - - static std::shared_ptr GetInstance(); - - void set_mirror_mean(bool mirror_mean); - bool mirror_mean() const { return mirror_mean_; } - - void set_full_batch(bool full_batch); - bool full_batch() const { return full_batch_; } - - void set_cast_before_mirror(bool cast_before_mirror); - bool cast_before_mirror() const { return cast_before_mirror_; } - - void set_loss_repeated_mean(bool loss_repeated_mean); - bool loss_repeated_mean() const { return loss_repeated_mean_; } - - void set_device_num(int32_t device_num); - int32_t device_num() const { return device_num_; } - - void set_global_rank(int32_t global_rank); - int32_t global_rank() const { return global_rank_; } - - void set_communication_backend(const std::string &communication_backend); - std::string communication_backend() const { return communication_backend_; } - - bool set_parallel_mode(const std::string ¶llel_mode); - std::string parallel_mode() const { return parallel_mode_; } - - bool set_strategy_search_mode(const std::string &strategy_search_mode); - std::string strategy_search_mode() const { return strategy_search_mode_; } - - void set_parameter_broadcast(bool parameter_broadcast); - bool parameter_broadcast() const { return parameter_broadcast_; } - - bool device_num_is_set() const { return device_num_is_set_; } - bool global_rank_is_set() const { return global_rank_is_set_; } - bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; } - - void SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group); - const std::vector GetAllReduceFusionSplitIndices(const std::string &group) const; - void SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group); - const std::vector GetAllReduceFusionSplitSizes(const std::string &group) const; - void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) { - enable_all_reduce_fusion_ = enable_all_reduce_fusion; - } - bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; } - - void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file); - std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; } - void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); - std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } - - void set_enable_parallel_optimizer(bool enable_parallel_optimizer) { - enable_parallel_optimizer_ = enable_parallel_optimizer; - } - bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } - - void Reset(); - - private: - ParallelContext(); - static std::shared_ptr inst_context_; - bool mirror_mean_; - bool full_batch_; - bool cast_before_mirror_; - bool loss_repeated_mean_; - int32_t device_num_; - int32_t global_rank_; - std::string communication_backend_; - std::string parallel_mode_; - std::string strategy_search_mode_; - bool parameter_broadcast_; - bool device_num_is_set_; - bool global_rank_is_set_; - bool parameter_broadcast_is_set_; - bool enable_all_reduce_fusion_; - std::map> all_reduce_fusion_split_indices_; - std::map> all_reduce_fusion_split_sizes_; - std::string strategy_ckpt_load_file_; - std::string strategy_ckpt_save_file_; - bool enable_parallel_optimizer_; -}; - -void ParallelParameterContextInit(const FuncGraphPtr &func_graph); -void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - AbstractBasePtr ptr); -void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - const AbstractBasePtr &ptr); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ diff --git a/mindspore/ccsrc/parallel/costmodel_context.cc b/mindspore/ccsrc/parallel/costmodel_context.cc deleted file mode 100644 index 92aff29557..0000000000 --- a/mindspore/ccsrc/parallel/costmodel_context.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/costmodel_context.h" - -#include - -#include "parallel/allreduce_fusion/allreduce_fusion.h" -#include "parallel/auto_parallel/graph_costmodel.h" - -namespace mindspore { -namespace parallel { -std::shared_ptr CostModelContext::cm_context_inst_ = nullptr; - -std::shared_ptr CostModelContext::GetInstance() { - if (cm_context_inst_ == nullptr) { - MS_LOG(INFO) << "Create costmodel_context"; - cm_context_inst_.reset(new (std::nothrow) CostModelContext()); - } - return cm_context_inst_; -} - -CostModelContext::CostModelContext() { - ResetCostModel(); - ResetAlgoParameters(); -} - -void CostModelContext::ResetCostModel() { - device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY; - costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; - costmodel_beta_ = DEFAULT_COST_MODEL_BETA; - costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA; - costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; - costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; - costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; - is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; - run_phase_ = DEFAULT_RUN_PHASE; - costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; - costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; - costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; - costmodel_allreduce_fusion_tail_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME; - costmodel_allreduce_fusion_allreduce_inherent_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME; - costmodel_allreduce_fusion_allreduce_bandwidth_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH; - costmodel_allreduce_fusion_computation_time_parameter_ = - DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER; -} - -void CostModelContext::ResetAlgoParameters() { - costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; - tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; - tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; - fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; - elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; -} - -void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; } - -void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; } - -void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; } - -void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; } - -void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; } - -void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { - costmodel_communi_threshold_ = cm_communi_th; -} - -void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { - costmodel_communi_const_ = cm_communi_const; -} - -void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; } - -void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; } -void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) { - costmodel_allreduce_fusion_algorithm_ = algorithm; -} - -void CostModelContext::set_costmodel_allreduce_fusion_times(int32_t allreduce_fusion_times) { - costmodel_allreduce_fusion_times_ = allreduce_fusion_times; -} - -void CostModelContext::set_costmodel_allreduce_fusion_tail_percent(double tail_percent) { - costmodel_allreduce_fusion_tail_percent_ = tail_percent; -} - -void CostModelContext::set_costmodel_allreduce_fusion_tail_time(double tail_time) { - costmodel_allreduce_fusion_tail_time_ = tail_time; -} - -void CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time) { - costmodel_allreduce_fusion_allreduce_inherent_time_ = allreduce_inherent_time; -} - -void CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth) { - costmodel_allreduce_fusion_allreduce_bandwidth_ = allreduce_bandwidth; -} - -void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter) { - costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; -} - -void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; } - -void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { - tensor_slice_alignment_size_ = ts_align_size; -} - -void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; } - -void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { - elementwise_stra_follow_ = elementwise_follow; -} - -void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/device.h b/mindspore/ccsrc/parallel/device.h deleted file mode 100644 index 8c3174ae55..0000000000 --- a/mindspore/ccsrc/parallel/device.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ -#define MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ - -#include -#include -#include - -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -class Device { - // This class abstract the 'device' information, used in Parallel module. - public: - Device() : rank_(0) { name_.clear(); } - explicit Device(int32_t rank) : rank_(rank) { name_.clear(); } - Device(std::string name, int32_t rank) : name_(std::move(name)), rank_(rank) {} - ~Device() = default; - std::string name() const { return name_; } - int32_t rank() const { return rank_; } - - private: - std::string name_; - int32_t rank_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ diff --git a/mindspore/ccsrc/parallel/device_manager.cc b/mindspore/ccsrc/parallel/device_manager.cc deleted file mode 100644 index 45628bec65..0000000000 --- a/mindspore/ccsrc/parallel/device_manager.cc +++ /dev/null @@ -1,374 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/device_manager.h" - -#include -#include -#include -#include -#include -#include - -#include "parallel/step_parallel.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -DeviceManagerPtr g_device_manager = nullptr; - -Stage::Stage(const std::vector &devices, int num, int rank) - : devices_(devices), number_(num), rank_(rank) { - gm_ = GroupManager(); -} - -// NOTE: '-1' indicates ERROR -int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); } - -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) { - if (device_num <= 0) { - MS_LOG(ERROR) << "'device_num' must be positive."; - return false; - } - if (global_rank < 0) { - MS_LOG(ERROR) << "'global_rank' must be nonnegative."; - return false; - } - if (device_num > MAX_DEVICE_NUM) { - MS_LOG(ERROR) << "'device_num' must be no more than " << MAX_DEVICE_NUM << "."; - return false; - } - // 'device_num_converted' must be the power of 2 - if ((IntToUint(device_num) & IntToUint(device_num - 1)) != 0) { - MS_LOG(ERROR) << "'device_num' must be the power of 2."; - return false; - } - if (global_rank >= device_num) { - MS_LOG(ERROR) << "'global_rank' must be less than 'device_num'."; - return false; - } - if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) { - MS_LOG(ERROR) << "Invalid backend: " << backend; - return false; - } - - RankList devices, stage_map; - for (int i = 0; i < device_num; ++i) { - devices.push_back(i); - } - - stage_map.push_back(device_num); - g_device_manager = std::make_shared(); - if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) { - MS_LOG(INFO) << "Device initialization succeeds."; - return true; - } else { - MS_LOG(ERROR) << "Device initialization fails."; - return false; - } -} - -void CheckGlobalDeviceManager() { - if (g_device_manager == nullptr) { - MS_LOG(EXCEPTION) << "Device information has not been set!"; - } -} - -int32_t GetListMemberByIndex(size_t index, const RankList &devices) { - size_t i = 0; - int32_t result = 0; - if ((devices.empty()) || (index >= devices.size())) { - MS_LOG(EXCEPTION) << "Index is out of the list scope"; - } - auto it = devices.begin(); - for (; it != devices.end(); ++it) { - if (i == index) { - result = *it; - break; - } - ++i; - } - return result; -} - -std::shared_ptr GetListMemberByIndex(size_t index, const std::vector> &device_list) { - size_t i = 0; - std::shared_ptr result; - if ((device_list.empty()) || (index >= device_list.size())) { - MS_LOG(EXCEPTION) << "Index is out of the list scope"; - } - auto it = device_list.begin(); - for (; it != device_list.end(); ++it) { - if (i == index) { - result = *it; - break; - } - ++i; - } - return result; -} - -// E.g. devices = [4, 5, 2, 1, 7, 8, 10], stage_map = [4, 3], -// therefore the stage_devices_ = [[4, 5, 2, 1], [7, 8, 10]]. -Status DeviceManager::Init(const RankList &devices, int32_t global_device_rank, const RankList &stage_map, - const std::string &backend) { - auto dev_it = devices.begin(); - auto stage_it = stage_map.begin(); - int32_t sum = 0; - - if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) { - MS_LOG(ERROR) << "Invalid backend: " << backend; - return Status::FAILED; - } - - for (; stage_it != stage_map.end(); ++stage_it) { - sum += (*stage_it); - } - if (IntToSize(sum) != devices.size()) { - MS_LOG(ERROR) << "The number of 'devices' in the list is not equal to the mentioned " - << "size of 'stage_map'"; - return Status::FAILED; - } - - for (; dev_it != devices.end(); ++dev_it) { - std::shared_ptr one = std::make_shared(*dev_it); - devices_.push_back(one); - } - - size_t global_index = 0; - for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) { - int num_device = *stage_it; - if (num_device > MAX_DEVICE_NUM) { - MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM; - return Status::FAILED; - } - if (num_device <= 0) { - MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive"; - return Status::FAILED; - } - RankList curr_dev_list; - for (int i = 0; i < num_device; ++i) { - curr_dev_list.push_back(GetListMemberByIndex(global_index, devices)); - global_index++; - } - stage_devices_.push_back(curr_dev_list); - } - - global_index = 0; - for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) { - int num_device = *stage_it; - if (num_device > MAX_DEVICE_NUM) { - MS_LOG(ERROR) << "The number of 'devices' in a stage must be less than " << MAX_DEVICE_NUM; - return Status::FAILED; - } - if (num_device <= 0) { - MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive"; - return Status::FAILED; - } - std::vector curr_dev_list; - for (int i = 0; i < num_device; ++i) { - curr_dev_list.push_back(*GetListMemberByIndex(global_index, devices_)); - global_index++; - } - std::shared_ptr new_stage = std::make_shared(curr_dev_list); - stages_.push_back(new_stage); - } - - std::shared_ptr dev = std::make_shared(global_device_rank); - device_ = dev; - set_global_rank(global_device_rank); - backend_ = backend; - - if (backend == HCCL_BACKEND) { - gm_.set_world_group(HCCL_WORLD_GROUP); - } else if (backend_ == NCCL_BACKEND) { - gm_.set_world_group(NCCL_WORLD_GROUP); - } else { - gm_.set_world_group(UNDEFINED_WORLD_GROUP); - } - MS_LOG(INFO) << "The device num: " << devices.size() << "rank id: " << global_device_rank - << "the backend: " << backend; - return Status::SUCCESS; -} - -std::shared_ptr DeviceManager::GetStageById(int32_t stage_id) { - std::shared_ptr res; - if (IntToSize(stage_id) >= stages_.size()) { - MS_LOG(ERROR) << "the 'stage_id': " << stage_id << ", is out of the scope of 'stage_devices_': " << stages_.size(); - return res; - } - int32_t index = 0; - for (auto &stage : stages_) { - if (index == stage_id) return stage; - index++; - } - return res; -} - -RankList DeviceManager::GetDeviceListByStageId(int32_t stage_id) const { - if (IntToSize(stage_id) >= stage_devices_.size()) - MS_LOG(ERROR) << "the 'stage_id': " << stage_id - << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); - RankList res; - int32_t index = 0; - for (auto &stage : stage_devices_) { - if (index == stage_id) { - return stage; - } - index++; - } - return res; -} - -RankList DeviceManager::global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const { - RankList res; - if (split_num <= 0) { - return res; - } - if (IntToSize(stage_id) >= stage_devices_.size()) { - MS_LOG(ERROR) << "the 'stage_id': " << stage_id - << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); - return res; - } - - RankList global_list = GetDeviceListByStageId(stage_id); - if (global_list.size() % IntToSize(split_num)) { - MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") can not be divisible by split num: " << stage_id; - return res; - } - - std::vector dev_list; - (void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list)); - - size_t index = 0; - size_t slice_size = dev_list.size() / IntToSize(split_num); - for (int32_t i = 0; i < split_num; ++i) { - bool found = false; - index = slice_size * IntToSize(i); - for (size_t j = 0; j < slice_size; ++j) { - if (dev_list[index + j] == rank) { - found = true; - break; - } - } - - if (found) { - break; - } - } - - for (size_t k = 0; k < slice_size; ++k) { - res.push_back(dev_list[index + k]); - } - return res; -} - -Device DeviceManager::CreateNewDeviceByRank(int32_t rank) const { return Device(rank); } - -std::vector DeviceManager::CreateDeviceListByRankList(RankList ranks) { - std::vector dev_list; - for (auto &rank : ranks) { - Device one = CreateNewDeviceByRank(rank); - dev_list.push_back(one); - } - return dev_list; -} - -DeviceManager &DeviceManager::GetInstance() { - static DeviceManager instance = DeviceManager(); - return instance; -} - -std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) { - std::string tmp = "WORLD_GROUP"; - if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) { - return tmp; - } - auto iter = group_to_rank_.find(hash_name); - if (iter == group_to_rank_.end()) { - MS_LOG(WARNING) << "Can not find the rank list name by hash name: " << hash_name; - return tmp; - } - return iter->second; -} - -std::string HashName(const std::string &origin_name) { return std::to_string(std::hash{}(origin_name)); } - -// Group name is generated using the increasing ranks of the devices. -// E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name -// is '0-1-3-5-7'. -std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) { - std::string rank_list_name; - std::vector::iterator it; - std::sort(ranks.begin(), ranks.end()); // sorted in increasing order - for (it = ranks.begin(); it != ranks.end(); ++it) { - if (it == ranks.begin()) { - rank_list_name = std::to_string(*it); - } else { - rank_list_name += "-" + std::to_string(*it); - } - } - - // hash rank-list-name and add ranks' size as prefix - std::string group_hash_name = HashName(rank_list_name); - std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name; - - if (rank_to_group_.find(rank_list_name) == rank_to_group_.end()) { - if (group_to_rank_.find(group_name) == group_to_rank_.end()) { - rank_to_group_[rank_list_name] = group_name; - group_to_rank_[group_name] = rank_list_name; - MS_LOG(INFO) << "The rank list name is " << rank_list_name << "nd group name is " << group_name; - } else { - MS_LOG(EXCEPTION) << "Hash collision, the current rank list: " << rank_list_name - << "the old rank list:" << group_to_rank_.find(group_name)->second - << "the group name: " << group_name; - } - } - return group_name; -} - -// Create the group with the given devices and the given name. The GroupManager -// gm_ will create a new group only if there does not exit a group with the same -// name. Otherwise, let the pointer g point to that group. -Group DeviceManager::CreateGroup(const std::string &group_name, - const std::vector &devices) { - if ((world_group() == NCCL_WORLD_GROUP) && (devices.size() != devices_.size())) { - MS_LOG(EXCEPTION) << "Do not support sub group for nccl"; - } - Group g; - (void)gm_.CreateGroup(group_name, devices, &g); - return g; -} - -// Create the group with only the given devices' ranks. -Group DeviceManager::CreateGroup(const RankList &dev_ranks) { - std::unordered_set rank_set(dev_ranks.begin(), dev_ranks.end()); - if (dev_ranks.size() != rank_set.size()) { - MS_LOG(EXCEPTION) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list"; - } - - std::string group_name = GenerateGroupNameByRanks(dev_ranks); - auto dev_list = CreateDeviceListByRankList(dev_ranks); - return CreateGroup(group_name, dev_list); -} - -void DeviceManager::Clear() { - devices_.clear(); - stage_devices_.clear(); - gm_.Clear(); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/device_manager.h b/mindspore/ccsrc/parallel/device_manager.h deleted file mode 100644 index 3afafe6a9c..0000000000 --- a/mindspore/ccsrc/parallel/device_manager.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ -#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "parallel/device.h" -#include "parallel/device_matrix.h" -#include "parallel/group_manager.h" -#include "parallel/status.h" -#include "parallel/strategy.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace parallel { -#define MAX_DEVICE_NUM 1024 - -constexpr char HCCL_BACKEND[] = "hccl"; -constexpr char NCCL_BACKEND[] = "nccl"; -constexpr char UNDEFINED_BACKEND[] = "undefined_backend"; - -class DeviceManager; -using DeviceManagerPtr = std::shared_ptr; -// 'g_device_manager' is the globally unique manager to manage the devices. -extern DeviceManagerPtr g_device_manager; - -class Stage { - // This class is used in pipeline-parallelization. Available devices are partitioned into multiple stages. - // Currently, the function of pipeline-parallelization and this class are NOT implemented. - public: - explicit Stage(std::vector devices) : devices_(std::move(devices)), number_(0), rank_(0) { - gm_ = GroupManager(); - } - Stage(const std::vector &devices, int num, int rank); - ~Stage() = default; - - int GetStageNum() const { return number_; } - size_t GetDevicesNum() const { return devices_.size(); } - std::vector GetDevicesList() { return devices_; } - int global_rank(Group *g) const; - - private: - std::vector devices_; - int number_; - int32_t rank_; - GroupManager gm_; -}; - -// This method is used for initializing the global DeviceManager 'g_device_manager', -// arguments including 'device_num' and 'global_rank' -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend); - -void CheckGlobalDeviceManager(); - -std::string HashName(const std::string &rank_list_name); - -class DeviceManager { - // This class is used to manage the abstract devices, including group-related and stage-related management. - public: - DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(0) { gm_ = GroupManager(); } - ~DeviceManager() = default; - - Status Init(const RankList &devices, int32_t local_device, const RankList &stage_map, const std::string &backend); - - static DeviceManager &GetInstance(); - RankList GetDeviceListByStageId(int32_t stage_id) const; - RankList global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const; - - Device CreateNewDeviceByRank(int32_t rank) const; - std::vector CreateDeviceListByRankList(RankList ranks); - - std::string GenerateGroupNameByRanks(RankList dev_ranks); - Group CreateGroup(const std::string &group_name, const std::vector &devices); - Group CreateGroup(const RankList &dev_ranks); - std::shared_ptr GetStageById(int32_t stage_id); - - size_t DeviceNum() const { return devices_.size(); } - - int32_t GetStageNum() const { return static_cast(stage_devices_.size()); } - - int32_t global_rank() const { return global_rank_; } - std::string backend() const { return backend_; } - void set_global_rank(int32_t global_rank) { global_rank_ = global_rank; } - void Clear(); - std::string world_group() const { return gm_.world_group(); } - std::string FindRankListNameByHashName(const std::string &hash_name); - - private: - std::vector> devices_; - // each stage has a list of devices - std::vector> stage_devices_; - std::shared_ptr device_; - std::vector> stages_; - GroupManager gm_; - std::string backend_; - - // bimap: - std::map rank_to_group_; // the key is rank list, value is hash name - std::map group_to_rank_; // the key is hash name, value is rank list - - int32_t local_rank_; - int32_t global_rank_; - int32_t stage_num_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/parallel/device_matrix.cc b/mindspore/ccsrc/parallel/device_matrix.cc deleted file mode 100644 index 3c9467a223..0000000000 --- a/mindspore/ccsrc/parallel/device_matrix.cc +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/device_matrix.h" - -#include -#include -#include -#include -#include -#include - -#include "parallel/ops_info/operator_info.h" -#include "parallel/status.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -DeviceMatrix::DeviceMatrix(int32_t rank, RankList dev_list, Shape dev_shape) - : rank_(rank), dev_list_(std::move(dev_list)), dev_shape_(std::move(dev_shape)) { - if (!std::any_of(dev_list_.begin(), dev_list_.end(), [rank](int32_t a) { return a == rank; })) { - MS_LOG(EXCEPTION) << "Rank " << rank << " is not in the current stage!"; - } - int32_t total = std::accumulate(dev_shape_.begin(), dev_shape_.end(), 1, std::multiplies()); - if (IntToSize(total) != dev_list_.size()) { - MS_LOG(EXCEPTION) << "Device shape does not match the size of the device list!"; - } -} - -Status DeviceMatrix::CreateGroupList() { - size_t size = dev_shape_.size(); - RankList group; - for (size_t i = 0; i < size; i++) { - Status status = GetDevicesAlongDim(SizeToUint(i), &group); - group_list_.push_back(group); - if (status == Status::FAILED) { - return Status::FAILED; - } - } - return Status::SUCCESS; -} - -Status DeviceMatrix::GetDevicesAlongDim(const uint32_t &dim, RankList *devices) { - if (dim >= dev_shape_.size()) { - MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!"; - } - if (dev_shape_[dim] == 1) { - *devices = {rank_}; - return Status::SUCCESS; - } - - RankList group; - std::vector local_group_list; - - // lower than dim - int32_t step = 1; - for (uint32_t i = dim + 1; i < dev_shape_.size(); i++) { - step = step * dev_shape_[i]; - } - int32_t num = *dev_list_.begin(); - for (int32_t i = 0; i < dev_shape_[dim]; i++) { - group.push_back(num); - num += step; - } - - for (int32_t i = 0; i < step; i++) { - local_group_list.push_back(group); - (void)std::for_each(group.begin(), group.end(), [](int32_t &a) { a++; }); - } - - // higher than dim - step = step * dev_shape_[dim]; - int32_t len = SizeToInt(dev_list_.size()) / step; - - // search rank - int32_t target = rank_; - for (int32_t i = 0; i < len; i++) { - for (RankList &temp : local_group_list) { - if (std::any_of(temp.begin(), temp.end(), [target](int32_t a) { return a == target; })) { - *devices = temp; - return Status::SUCCESS; - } - (void)std::for_each(temp.begin(), temp.end(), [step](int32_t &a) { a = a + step; }); - } - } - MS_LOG(ERROR) << "Can't find groups for rank" << rank_ << " in device list!"; - return Status::FAILED; -} - -Shape ConvertRankToCoordinate(int32_t rank, const Shape &dev_shape) { - Shape dev_coordinate; - for (size_t i = 0; i < dev_shape.size(); ++i) { - int32_t size = dev_shape[dev_shape.size() - i - 1]; - if (size == 0) { - MS_LOG(EXCEPTION) << "Invalid dev shape: " << ShapeToString(dev_shape); - } else { - int32_t index = rank % size; - (void)dev_coordinate.insert(dev_coordinate.begin(), index); - rank = rank / size; - } - } - return dev_coordinate; -} - -Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) { - for (auto &element : tensor_map) { - // -1 means the corresponding dimension is not split. - if (element == MAP_NONE) { - continue; - } else if ((element < 0) || (IntToSize(element) >= dev_shape_.size())) { - MS_LOG(ERROR) << "create group by tensor map: the tensor map is invalid"; - return FAILED; - } - } - - Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_); - for (auto &tmp_rank : dev_list_) { - Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_); - bool matched = true; - for (auto &map : tensor_map) { - if (map == MAP_NONE) { - continue; - } - size_t index = dev_shape_.size() - IntToSize(map) - 1; - if (current_rank_coordinate[index] != tmp_rank_coordinate[index]) { - matched = false; - break; - } - } - if (matched) { - rank_list->push_back(tmp_rank); - } - } - - return SUCCESS; -} - -std::string ShapeToString(const Shape &shape) { - std::string str = "["; - for (size_t i = 0; i < shape.size(); ++i) { - str += std::to_string(shape[i]); - if (i < shape.size() - 1) { - str += ", "; - } - } - return str + "]"; -} - -std::string ListToString(const std::vector &list) { - std::string str = "["; - for (auto &element : list) { - str += std::to_string(element) + ", "; - } - return str + "]"; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/device_matrix.h b/mindspore/ccsrc/parallel/device_matrix.h deleted file mode 100644 index 295bf33836..0000000000 --- a/mindspore/ccsrc/parallel/device_matrix.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ -#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ - -#include -#include -#include - -#include "parallel/status.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace parallel { -using RankList = std::vector; -using Shape = std::vector; - -class DeviceMatrix { - public: - DeviceMatrix(int32_t rank, RankList devices, Shape dev_shape); - DeviceMatrix() = default; - ~DeviceMatrix() = default; - std::vector group_list() const { return group_list_; } - Status CreateGroupList(); - Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list); - Status GetDevicesAlongDim(const uint32_t &dim, RankList *devices); - - private: - int32_t rank_ = -1; - RankList dev_list_; - // From low dim to high dim. eg: [D0 D1 D2 D3] - Shape dev_shape_; - std::vector group_list_; -}; - -std::string ShapeToString(const Shape &shape); -std::string ListToString(const std::vector &list); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_ diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h deleted file mode 100644 index f8e1d62d0a..0000000000 --- a/mindspore/ccsrc/parallel/dynamic_creator.h +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ -#define MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ - -#include -#include -#include -#include - -#include "parallel/ops_info/ops_info_head_files.h" -#include "parallel/step_parallel.h" - -namespace mindspore { -namespace parallel { -#define REGISTER(className) \ - OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \ - return std::make_shared(name, in, out, attrs); \ - } \ - RegisterAction className##Register(#className, (CreatFn)objectCreator##className); - -typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out, - const PrimitiveAttrs &attrs); - -class DynCreator { - public: - ~DynCreator() = default; - - // creat static singleton dyn_creator instance - static DynCreator &Instance() { - static DynCreator fac = DynCreator(); - return fac; - } - // register - void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } - // creator - OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, - const PrimitiveAttrs &attrs, size_t count) { - std::string op_name = name + std::to_string(count); - auto iter = Function_map_.find(name); - if (iter == Function_map_.end()) { - MS_LOG(INFO) << name << " is not register yet"; - return nullptr; - } - return iter->second(op_name, shape_in, shape_out, attrs); - } - - private: - DynCreator() = default; - std::map Function_map_; -}; - -class RegisterAction { - public: - RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { - DynCreator::Instance().Regist(name, creatfn); - } - ~RegisterAction() = default; - - private: - std::string name_; -}; - -// operator register -REGISTER(MatMulInfo); -REGISTER(GeluInfo); -REGISTER(VirtualDatasetInfo); -REGISTER(BatchParallelInfo); -REGISTER(TanhInfo); -REGISTER(SoftmaxInfo); -REGISTER(LogSoftmaxInfo); -REGISTER(ActivationInfo); -REGISTER(SoftmaxCrossEntropyWithLogitsInfo); -REGISTER(SubInfo); -REGISTER(TensorAddInfo); -REGISTER(BiasAddInfo); -REGISTER(MulInfo); -REGISTER(DivInfo); -REGISTER(RealDivInfo); -REGISTER(PowInfo); -REGISTER(ExpInfo); -REGISTER(OneHotInfo); -REGISTER(EqualInfo); -REGISTER(NotEqualInfo); -REGISTER(LogInfo); -REGISTER(CosInfo); -REGISTER(ACosInfo); -REGISTER(LogicalNotInfo); -REGISTER(L2NormalizeInfo); -REGISTER(LayerNormInfo); -REGISTER(ReduceMaxInfo); -REGISTER(ArgMaxWithValueInfo); -REGISTER(ArgMinWithValueInfo); -REGISTER(ReduceMeanInfo); -REGISTER(ReduceSumInfo); -REGISTER(ReduceMinInfo); -REGISTER(TransposeInfo); -REGISTER(PReLUInfo); -REGISTER(DropoutDoMaskInfo); -REGISTER(ReshapeInfo); -REGISTER(FloorDivInfo); -REGISTER(MaximumInfo); -REGISTER(MinimumInfo); -REGISTER(CastInfo); -REGISTER(GreaterInfo); -REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo); -REGISTER(AssignSubInfo); -REGISTER(ReLUInfo); -REGISTER(GatherV2Info); -REGISTER(SparseGatherV2Info); -REGISTER(SqrtInfo); -REGISTER(SigmoidInfo); -REGISTER(GetNextInfo); -REGISTER(NegInfo); -REGISTER(BatchMatMulInfo); -REGISTER(ExpandDimsInfo); -REGISTER(SqueezeInfo); -REGISTER(SigmoidCrossEntropyWithLogitsInfo); -REGISTER(SquareInfo); -REGISTER(GatherV2PInfo); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/parallel/graph_util/generate_graph.h deleted file mode 100644 index 71227a6e7b..0000000000 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ -#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ - -#include -#include -#include -#include -#include -#include - -#include "./common.h" -#include "optimizer/opt.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -#define USING_HASH_NAME "USING_HASH_NAME" -// Get the operator's path where the operator has be defined -std::string GetOpPythonPath(const OperatorName &op_name); - -// Init python operator Instance -ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name); - -AnfNodePtr CreatTypeInt(int32_t value); -AnfNodePtr CreatInt32Imm(int32_t value); -AnfNodePtr CreateInt32Tensor(int32_t value); -AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr); -std::string HashInstanceName(const std::string &name); - -class GenerateGraph { - public: - GenerateGraph() : name_idx_(0) {} - Status Init(const CNodePtr &cnode); - ~GenerateGraph() = default; - AnfNodePtr virtual_input_node() { return virtual_input_node_; } - AnfNodePtr NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs); - AnfNodePtr NewOpInst(const OperatorName &op_name); - AnfNodePtr PushBack(const std::vector &inputs); - - private: - CNodePtr cnode_; - FuncGraphManagerPtr manager_; - ScopePtr scope_; - FuncGraphPtr func_graph_; - AnfNodePtr virtual_input_node_; - std::string instance_name_base_; - int64_t name_idx_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc deleted file mode 100644 index 32cd106d8e..0000000000 --- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/graph_util/get_parallel_info.h" - -#include -#include -#include -#include - -#include "common/utils.h" -#include "ir/func_graph.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/graph_util/graph_info.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -py::dict GetParameterLayout(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - py::dict dict; - std::vector graph_params = graph->parameters(); - - for (auto para : graph_params) { - std::string name = std::static_pointer_cast(para)->name(); - std::shared_ptr tensor_layout = std::static_pointer_cast(para)->tensor_layout(); - if (tensor_layout == nullptr) { - MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; - } else { - auto device_arrangement = tensor_layout->device_arrangement().array(); - auto tensor_map = tensor_layout->tensor_map().array(); - auto slice_shape = tensor_layout->slice_shape().array(); - std::vector> layout = {device_arrangement, tensor_map, slice_shape}; - dict[py::str(name)] = layout; - MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString(); - } - } - return dict; -} - -py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - py::dict dict; - auto ret = graph->get_return(); - MS_EXCEPTION_IF_NULL(ret); - auto nodes = DeepScopedGraphSearch(ret); - - for (auto node : nodes) { - if (node->isa()) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto distributed_operation_info = cnode->operator_info(); - if (distributed_operation_info != nullptr) { - auto strategyPtr = distributed_operation_info->strategy(); - if (strategyPtr != nullptr) { - auto strategy = strategyPtr->GetInputDim(); - auto name = cnode->fullname_with_scope(); - dict[py::str(name)] = strategy; - } - } - } - } - return dict; -} - -py::dict GetAllreduceFusion(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - py::dict dict; - auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE); - - for (auto prim : allreduce_prim_list) { - auto name_ptr = prim->GetAttr("parameter"); - auto fusion_ptr = prim->GetAttr("fusion"); - if (fusion_ptr == nullptr) { - MS_LOG(EXCEPTION) << "fusion_ptr is nullptr"; - } else if (name_ptr == nullptr) { - continue; - } - if (!name_ptr->isa()) { - MS_LOG(EXCEPTION) << "name is not StringImm"; - } - auto name = name_ptr->cast()->value(); - if (!fusion_ptr->isa()) { - MS_LOG(EXCEPTION) << "fusion is not Int32Imm"; - } - int32_t fusion = fusion_ptr->cast()->value(); - dict[py::str(name)] = fusion; - } - return dict; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/parallel/graph_util/graph_info.cc deleted file mode 100644 index 175413c0fd..0000000000 --- a/mindspore/ccsrc/parallel/graph_util/graph_info.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/graph_util/graph_info.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" -#include "debug/draw.h" -#include "ir/func_graph.h" -#include "utils/context/ms_context.h" -#include "utils/graph_utils.h" - -namespace mindspore { -namespace parallel { -std::vector FindPrimtive(const FuncGraphPtr &graph, const std::string &name) { - AnfNodePtr ret = graph->get_return(); - MS_EXCEPTION_IF_NULL(ret); - std::vector all_nodes = DeepScopedGraphSearch(ret); - std::vector prim_list; - for (auto &node : all_nodes) { - if (!IsValueNode(node)) { - continue; - } - ValueNodePtr prim_node_anf = node->cast(); - MS_EXCEPTION_IF_NULL(prim_node_anf); - PrimitivePtr node_prim = prim_node_anf->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == name) { - prim_list.emplace_back(node_prim); - } - } - return prim_list; -} - -void DumpGraph(const FuncGraphPtr &root, const std::string &name) { - if (MsContext::GetInstance()->save_graphs_flag()) { - draw::Draw(name + ".dot", root); - DumpIR(name + ".ir", root); - ExportIR(name + ".dat", "0", root); - } -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.cc b/mindspore/ccsrc/parallel/graph_util/node_info.cc deleted file mode 100644 index 7298b06832..0000000000 --- a/mindspore/ccsrc/parallel/graph_util/node_info.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/graph_util/node_info.h" - -#include - -#include "ir/anf.h" -#include "ir/param_value_py.h" -#include "pipeline/parse/python_adapter.h" - -namespace mindspore { -namespace parallel { -std::string ParameterName(const AnfNodePtr &node_ptr) { - auto para_ptr = node_ptr->cast(); - MS_EXCEPTION_IF_NULL(para_ptr); - return para_ptr->name(); -} - -bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { - auto para_ptr = node_ptr->cast(); - if (para_ptr == nullptr) { - return false; - } - if (!para_ptr->has_default()) { - return false; - } - auto param_value = std::dynamic_pointer_cast(para_ptr->default_param()); - return py::cast(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad")); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.h b/mindspore/ccsrc/parallel/graph_util/node_info.h deleted file mode 100644 index bda268e582..0000000000 --- a/mindspore/ccsrc/parallel/graph_util/node_info.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ - -#include -#include "ir/base.h" - -namespace mindspore { -namespace parallel { -std::string ParameterName(const AnfNodePtr &node_ptr); - -bool ParameterRequireGrad(const AnfNodePtr &node_ptr); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ diff --git a/mindspore/ccsrc/parallel/group_manager.cc b/mindspore/ccsrc/parallel/group_manager.cc deleted file mode 100644 index 1562cbc140..0000000000 --- a/mindspore/ccsrc/parallel/group_manager.cc +++ /dev/null @@ -1,178 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/group_manager.h" - -#include -#include - -#include "parallel/device_manager.h" -#include "parallel/ops_info/ops_utils.h" -#include "utils/comm_manager.h" - -namespace mindspore { -namespace parallel { -Group::Group() { - name_.clear(); - devices_.clear(); -} - -Status Group::Init(const std::string &name, const std::vector &devices) { - this->name_ = name; - this->devices_ = devices; - return Status::SUCCESS; -} - -std::vector Group::GetDevicesList() const { return devices_; } - -bool Group::IsInThisGroup(int32_t device_rank) { - for (auto &device : devices_) { - if (device.rank() == device_rank) { - return true; - } - } - return false; -} - -// Get the position of the device in the group -Status Group::GetIndex(size_t *index) { - size_t pos = 0; - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - for (auto &device : devices_) { - if (device.rank() == rank) { - *index = pos; - return Status::SUCCESS; - } else { - pos++; - } - } - MS_LOG(ERROR) << "Could not find device rank " << rank << "in this group!"; - return Status::FAILED; -} - -GroupManager::GroupManager() { groups_.clear(); } - -Status GroupManager::CreateGroup(const std::string &group_name, const std::vector &devices, - mindspore::parallel::Group *const group) { - // it is simple to use size to determine whether it is a world group - uint32_t world_size = 0; - if (world_group_ != NCCL_WORLD_GROUP) { - (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size); - } - - if ((world_group_ == NCCL_WORLD_GROUP) || (devices.size() == world_size)) { - auto it = groups_.find(world_group_); - if (it == groups_.end()) { - (void)group->Init(world_group_, devices); - groups_[world_group_] = *group; - } else { - *group = it->second; - } - MS_LOG(INFO) << "It is world group " << world_group_ << ", no need to create it."; - return Status::SUCCESS; - } - - auto it = groups_.find(group_name); - // If there already exits a group with the desired 'name', - // let the pointer point to the group. - if (it != groups_.end()) { - *group = it->second; - return Status::SUCCESS; - } else { - (void)group->Init(group_name, devices); - groups_[group_name] = *group; - - vector ranks; - (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks), - [](const Device dev) { return (uint32_t)dev.rank(); }); - // Create group through the CommManager interface - bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks); - if (!ret) { - MS_LOG(ERROR) << "Create group failed, group name is " << group_name; - return Status::FAILED; - } - - MS_LOG(INFO) << "Create group success, group name is " << group_name; - return Status::SUCCESS; - } -} - -Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) { - std::string name = (*group).name(); - auto it = groups_.find(name); - if (it == groups_.end()) { - MS_LOG(ERROR) << "Could not find group name :" << name; - return Status::FAILED; - } - (void)groups_.erase(it); - bool ret = CommManager::GetInstance().DestroyGroup(name); - if (!ret) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -Status GroupManager::DestroyAllGroups() { - for (auto &it : groups_) { - std::string name = it.first; - bool ret = CommManager::GetInstance().DestroyGroup(name); - if (!ret) { - return Status::FAILED; - } - } - groups_.clear(); - return Status::SUCCESS; -} - -Status GroupManager::GetRankID(const std::string &name, unsigned int *const rank_id) { - auto it = groups_.find(name); - if (it == groups_.end()) { - MS_LOG(ERROR) << "Could not find group name :" << name; - return Status::FAILED; - } - bool ret = CommManager::GetInstance().GetRankID(name, rank_id); - if (!ret) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -Status GroupManager::GetRankSize(const std::string &name, unsigned int *const rank_size) { - auto it = groups_.find(name); - if (it == groups_.end()) { - MS_LOG(ERROR) << "Could not find group name :" << name; - return Status::FAILED; - } - bool ret = CommManager::GetInstance().GetRankSize(name, rank_size); - if (!ret) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Group **group) { - auto it = groups_.find(name); - if (it == groups_.end()) { - return Status::FAILED; - } - *group = &it->second; - return Status::SUCCESS; -} - -void GroupManager::Clear() { (void)DestroyAllGroups(); } -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/group_manager.h b/mindspore/ccsrc/parallel/group_manager.h deleted file mode 100644 index f763d483cc..0000000000 --- a/mindspore/ccsrc/parallel/group_manager.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ -#define MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ - -#include -#include -#include -#include - -#include "parallel/device.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -constexpr char HCCL_WORLD_GROUP[] = "hccl_world_group"; -constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; -constexpr char UNDEFINED_WORLD_GROUP[] = "undefined_world_group"; - -// Devices that need communication should in the same group. These classes are used to -// create and destroy group among devices. -class Group { - public: - Group(); - ~Group() = default; - Status Init(const std::string &name, const std::vector &devices); - std::vector GetDevicesList() const; - std::string name() const { return name_; } - bool IsInThisGroup(int32_t device_rank); - Status GetIndex(size_t *index); - size_t GetDevNum() const { return devices_.size(); } - - private: - std::string name_; - std::vector devices_; -}; - -class GroupManager { - public: - GroupManager(); - ~GroupManager() = default; - - Status CreateGroup(const std::string &name, const std::vector &devices, Group *group); - Status DestroyGroup(Group *group); - Status DestroyAllGroups(); - Status GetRankID(const std::string &name, unsigned int *rank_id); - Status GetRankSize(const std::string &name, unsigned int *rank_size); - Status FindGroup(const std::string &name, Group **group); - std::string world_group() const { return world_group_; } - void set_world_group(const std::string &name) { world_group_ = name; } - void Clear(); - - private: - // the key is group name (name_) - std::map groups_; - std::string world_group_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_ diff --git a/mindspore/ccsrc/parallel/node_check.cc b/mindspore/ccsrc/parallel/node_check.cc deleted file mode 100644 index 6b920f82ec..0000000000 --- a/mindspore/ccsrc/parallel/node_check.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/node_check.h" - -#include -#include - -#include "parallel/ops_info/ops_utils.h" - -namespace mindspore { -namespace parallel { -const std::set BLACK_LIST = {TUPLE_GETITEM, - MAKE_TUPLE, - J, - LIST_GETITEM, - ARRAY_GETITEM, - TUPLE_SETITEM, - DEPEND, - LIST_SETITEM, - ARRAY_SETITEM, - DICT_GETITEM, - LIST_APPEND, - LIST_MAP, - LIST_REDUCE, - TUPLE_REVERSED, - TILE_SHAPE, - TUPLE_DIV, - TUPLE_TO_ARRAY, - MAKE_LIST, - MAKE_DICT, - MAKE_SLICE, - MAKE_RECORD, - STRING_EQUAL, - VIRTUALLOSS, - RETURN, - ENV_GETITEM, - IDENTITY, - PARTIAL, - ENVSETITEM, - ENVGETITEM, - ENVADD, - MAKEREFKEY, - MAKEREF, - GETREFKEY, - GETREFVALUE, - GETREFORIGIN, - DOT, - IM2COL, - COL2IM, - IM2COLV1, - STATESETITEM, - SCALARSUMMARY, - IMAGESUMMARY, - TENSORSUMMARY, - DEBUG, - HISTOGRAMSUMMARY, - COL2IMV1, - RESOLVE, - BROADCASTGRADIENTARGS, - INVERTPERMUTATION, - CONTROLDEPEND, - DROPOUT_GEN_MASK, - EMBED, - CREATINSTANCE, - ZEROSLIKE, - ASSIGN, - REF_TO_EMBED, - STOP_GRADIENT}; - -bool IsInBlackList(const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(prim); - return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end()); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/parallel/ops_info/activation_info.cc deleted file mode 100644 index 6bc33677a6..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.cc +++ /dev/null @@ -1,705 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/activation_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status Activation::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - - return SUCCESS; -} - -Status ActivationInfo::GetAttrs() { - if (attrs_.size() < ACTIVATION_ATTR_SIZE) { - MS_LOG(ERROR) << name_ << " : The size of attrs small than 1."; - return FAILED; - } - - if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" - << outputs_shape_.size() << "is wrong."; - return FAILED; - } - - auto iter = attrs_.find(ACTIVATION_TYPE); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - std::string val = iter->second->cast()->value(); - if ((val != RELU_TYPE) && (val != RELU6_TYPE) && (val != SIGMOID_TYPE)) { - MS_LOG(ERROR) << name_ << " : Activation type is wrong."; - return FAILED; - } - } else { - MS_LOG(ERROR) << name_ << " : The value of activation_type is not string."; - return FAILED; - } - } - - return SUCCESS; -} - -Status ActivationOther::GetAttrs() { - if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" - << outputs_shape_.size() << "is wrong."; - return FAILED; - } - return SUCCESS; -} - -Status Activation::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" - << outputs_shape_.size() << "is wrong."; - return FAILED; - } - - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status Softmax::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - Dimensions input_strategy = stra.at(0); - - for (auto &element : axis_) { - int32_t axis_index = element; - if (element < 0) { - size_t input_dim = inputs_shape_.at(0).size(); - axis_index = static_cast(input_dim) + element; - } - - int32_t axis_strategy = input_strategy.at(IntToSize(axis_index)); - // Dimension corresponding to axis is un-splittable - if (axis_strategy != MIN_SLICE_NUM) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1"; - } else { - MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1"; - } - return FAILED; - } - } - - return SUCCESS; -} - -Status Softmax::GetAttrs() { - if (attrs_.size() < SOFTMAX_ATTR_SIZE) { - MS_LOG(ERROR) << name_ << " : The size of attrs small than 1."; - return FAILED; - } - - auto iter = attrs_.find(AXIS); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { // the axis is a number - int32_t axis_element = iter->second->cast()->value(); - axis_.push_back(axis_element); - MS_LOG(INFO) << name_ << " : The axis is int, value is " << axis_element; - } else if (iter->second->isa()) { // the axis is a tuple - ValueTuplePtr value_tuple = iter->second->cast(); - if (value_tuple == nullptr) { - MS_LOG(ERROR) << name_ << " : The value_tuple is nullptr."; - return FAILED; - } - std::vector value_vector = value_tuple->value(); - (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_), - [](const ValuePtr &value) { return static_cast(GetValue(value)); }); - if (axis_.empty()) { - MS_LOG(ERROR) << name_ << " : The axis tuple is empty."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : The axis is tuple, value is " << ShapeToString(axis_); - } else { - MS_LOG(ERROR) << name_ << " : The value of axis is not int or tuple int."; - return FAILED; - } - } - - if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; - return FAILED; - } - - // for example: tensor dimension is 4, then axis range [-4, 3] - int32_t dim = SizeToInt(inputs_shape_.at(0).size()); - auto it = - std::find_if(axis_.begin(), axis_.end(), [dim](int32_t element) { return ((element >= dim) || (element < -dim)); }); - if (it != axis_.end()) { - MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << -dim << ", " << dim - 1 << "]."; - return FAILED; - } - - return SUCCESS; -} - -Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status Softmax::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << " : GetAttrs failed."; - return FAILED; - } - if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; - return FAILED; - } - - is_auto_parallel_ = true; - Shape input0_split; - (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); - for (auto &element : axis_) { - int32_t axis_index = element; - if (element < 0) { - size_t input_dim = inputs_shape_.at(0).size(); - axis_index = static_cast(input_dim) + element; - } - input0_split[IntToSize(axis_index)] = 0; - } - Shapes splittable_inputs = {input0_split}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status ActivationBase::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - - dev_matrix_shape_ = input_strategy; - - return SUCCESS; -} - -Status ActivationBase::InferMirrorOps() { - mirror_ops_.clear(); - - Shape tensor_map = inputs_tensor_map_[0]; - std::vector group; - if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group failed."; - return FAILED; - } - - OperatorVector mirror_op; - if (group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror ops is empty."; - return SUCCESS; - } else { - mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); - mirror_ops_.push_back(mirror_op); - std::string group_name = group[0].name(); - MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; - } - - return SUCCESS; -} - -Status ActivationBase::InferForwardCommunication() { - // do nothing - return SUCCESS; -} - -Status ActivationBase::InferTensorMap() { - std::vector tensor_map_index; - size_t size = inputs_shape_.at(0).size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - i - 1)); - } - - inputs_tensor_map_.push_back(tensor_map_index); - outputs_tensor_map_.push_back(tensor_map_index); - return SUCCESS; -} - -Status ActivationBase::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {inputs_strategy.at(0)}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - - TensorLayout input_tensor_layout; - if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(input_tensor_info); // the same as input - - return SUCCESS; -} - -Status ActivationBase::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -Status CastInfo::InferMirrorOps() { - mirror_ops_.clear(); - - Shape tensor_map = inputs_tensor_map_[0]; - std::vector group; - if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group failed."; - return FAILED; - } - - OperatorVector mirror_op; - OperatorVector op_for_value; - if (group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror ops is empty."; - return SUCCESS; - } else { - mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); - mirror_ops_.push_back(mirror_op); - mirror_ops_.push_back(op_for_value); - std::string group_name = group[0].name(); - MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; - } - - return SUCCESS; -} - -Status ExpandDimsInfo::GetAttrs() { - if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) { - MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size(); - return FAILED; - } - - if (!input_value_.back()->isa()) { - MS_LOG(ERROR) << name_ << ": The type of axis is not int"; - return FAILED; - } - - int32_t axis = GetValue(input_value_.back()); - - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - int32_t dim = SizeToInt(inputs_shape_[0].size()); - if ((axis > dim) || (axis < -dim - 1)) { - MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim - 1 << ", " << dim << "]"; - return FAILED; - } - - if (axis < 0) { - positive_axis_ = dim + axis + 1; - } else { - positive_axis_ = axis; - } - MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_; - return SUCCESS; -} - -Status ExpandDimsInfo::InferTensorMap() { - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - // for example: if the dimension of input is 3, and the axis is 2, - // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0] - std::vector input_tensor_map, output_tensor_map; - size_t size = inputs_shape_[0].size(); - for (size_t i = 0; i < size; ++i) { - input_tensor_map.push_back(SizeToInt(size - i - 1)); - } - - inputs_tensor_map_.push_back(input_tensor_map); - - output_tensor_map = input_tensor_map; - if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(size))) { - MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_; - return FAILED; - } - (void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP); - outputs_tensor_map_.push_back(output_tensor_map); - - MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map) - << ", and the tensor map of output is " << ShapeToString(output_tensor_map); - return SUCCESS; -} - -Status ExpandDimsInfo::InferTensorStrategy() { - if (strategy_ == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - - inputs_strategy_ = strategy_->GetInputDim(); - if (inputs_strategy_.empty()) { - MS_LOG(ERROR) << name_ << ": The strategy is empty"; - return FAILED; - } - - Shape output_strategy = inputs_strategy_[0]; - if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(output_strategy.size()))) { - MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_; - return FAILED; - } - (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY); - outputs_strategy_ = {output_strategy}; - return SUCCESS; -} - -Status ExpandDimsInfo::InferTensorInfo() { - if (inputs_shape_.empty() || outputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty"; - return FAILED; - } - - if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { - MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty"; - return FAILED; - } - - Shape input_shape = inputs_shape_[0]; - Shape output_shape = outputs_shape_[0]; - - // infer slice shape - if (InferTensorStrategy() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer tensor strategy failed"; - return FAILED; - } - Shapes inputs_slice_shape, outputs_slice_shape; - if (InferSliceShape(inputs_strategy_, outputs_strategy_, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer slice shape failed"; - return FAILED; - } - - if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) { - MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty"; - return FAILED; - } - - Shape input_slice_shape = inputs_slice_shape[0]; - Shape output_slice_shape = outputs_slice_shape[0]; - - TensorLayout input_tensor_layout, output_tensor_layout; - if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed"; - return FAILED; - } - - if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed"; - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status ExpandDimsInfo::InferMirrorOps() { - mirror_ops_.clear(); - - if (inputs_tensor_map_.empty()) { - MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty"; - return FAILED; - } - - std::vector group; - if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Create group failed"; - return FAILED; - } - - if (group.empty()) { - MS_LOG(INFO) << name_ << ": No need to create mirror ops"; - return SUCCESS; - } - - OperatorVector mirror_op, placeholder_op; - mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); - mirror_ops_.push_back(mirror_op); - mirror_ops_.push_back(placeholder_op); - MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name(); - return SUCCESS; -} - -Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) { - std::vector axis; - auto axis_list = value_tuple->value(); - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - Shape input_shape = inputs_shape_.at(0); - size_t input_size = input_shape.size(); - // if axis tuple is empty, we should exclude the axis that the corresponding slice shape is 1. - if (axis_list.empty()) { - for (size_t i = 0; i < input_size; ++i) { - if (input_shape[i] == 1) { - axis.push_back(i); - } - } - axis_ = MakeValue(axis)->cast(); - return SUCCESS; - } - - // convert negative axis to positive. - for (auto &dim : axis_list) { - if (!dim->isa()) { - MS_LOG(ERROR) << name_ << ": The type of axis is not int"; - return FAILED; - } - int32_t dim_value = GetValue(dim); - int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value; - axis.push_back(positive_value); - } - axis_ = MakeValue(axis)->cast(); - return SUCCESS; -} - -Status SqueezeInfo::GetAttrs() { - auto iter = attrs_.find(AXIS); - if (iter == attrs_.end()) { - MS_LOG(ERROR) << name_ << ": Can't find axis attribute."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(iter->second); - auto value_tuple = iter->second->cast(); - MS_EXCEPTION_IF_NULL(value_tuple); - InferAxis(value_tuple); - attrs_[AXIS] = axis_; - return SUCCESS; -} - -Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) { - Attr attr = std::make_pair(AXIS, axis_); - OperatorAttrs attrs = {attr}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - replace_op_ = {std::make_pair(SQUEEZE, args)}; - return SUCCESS; -} - -Status SqueezeInfo::InferTensorMap() { - // for example: if the shape of input is [32, 32, 1], and the axis is (2, ), - // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1] - std::vector input_tensor_map, output_tensor_map; - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - size_t size = inputs_shape_[0].size(); - std::vector axis = GetValue>(axis_); - for (size_t i = 0; i < size; ++i) { - size_t index = size - i - 1; - auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i)); - if (iter == axis.end()) { - output_tensor_map.push_back(SizeToInt(index)); - } - input_tensor_map.push_back(SizeToInt(index)); - } - inputs_tensor_map_.push_back(input_tensor_map); - outputs_tensor_map_.push_back(output_tensor_map); - MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map) - << ", and the tensor map of output is " << ShapeToString(output_tensor_map); - - return SUCCESS; -} - -Status SqueezeInfo::InferTensorInfo() { - if (inputs_shape_.empty() || outputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty"; - return FAILED; - } - - if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) { - MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty"; - return FAILED; - } - - Shape input_shape = inputs_shape_[0]; - Shape output_shape = outputs_shape_[0]; - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Dimensions output_strategy; - std::vector axis = GetValue>(axis_); - for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { - auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i)); - if (iter == axis.end()) { - output_strategy.push_back(inputs_strategy[0].at(i)); - } - } - Strategys outputs_strategy = {output_strategy}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer slice shape failed"; - return FAILED; - } - - if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) { - MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty"; - return FAILED; - } - - Shape input_slice_shape = inputs_slice_shape[0]; - Shape output_slice_shape = outputs_slice_shape[0]; - - // infer tensor layout - TensorLayout input_tensor_layout, output_tensor_layout; - if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed"; - return FAILED; - } - - if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed"; - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status SqueezeInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - } - - if (InferReplaceOps(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer replace ops failed"; - } - - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h deleted file mode 100644 index cd66bf8e8b..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.h +++ /dev/null @@ -1,224 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ - -#include -#include -#include -#include -#include - -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class ActivationBase : public OperatorInfo { - public: - ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs, OperatorCostPtr cost) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} - ~ActivationBase() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - protected: - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; -}; - -class Activation : public ActivationBase { - public: - Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~Activation() override = default; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; -}; - -class ActivationInfo : public Activation { - public: - ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Activation(name, inputs_shape, outputs_shape, attrs) {} - ~ActivationInfo() override = default; - - protected: - Status GetAttrs() override; // activation_type: relu, relu6, sigmoid -}; - -class ActivationOther : public Activation { - public: - ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Activation(name, inputs_shape, outputs_shape, attrs) {} - ~ActivationOther() override = default; - - protected: - Status GetAttrs() override; -}; - -class GeluInfo : public ActivationOther { - public: - GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~GeluInfo() override = default; -}; - -class TanhInfo : public ActivationOther { - public: - TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~TanhInfo() override = default; -}; - -class Softmax : public ActivationBase { - public: - explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~Softmax() override = default; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - - private: - std::vector axis_; -}; - -class SoftmaxInfo : public Softmax { - public: - SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Softmax(name, inputs_shape, outputs_shape, attrs) {} - ~SoftmaxInfo() override = default; -}; - -class LogSoftmaxInfo : public Softmax { - public: - LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Softmax(name, inputs_shape, outputs_shape, attrs) {} - ~LogSoftmaxInfo() override = default; -}; - -class ReLUInfo : public ActivationOther { - public: - ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~ReLUInfo() override = default; -}; - -class CastInfo : public ActivationOther { - public: - CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~CastInfo() override = default; - - protected: - Status InferMirrorOps() override; -}; - -class SqrtInfo : public ActivationOther { - public: - SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~SqrtInfo() override = default; -}; - -class NegInfo : public ActivationOther { - public: - NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~NegInfo() override = default; -}; - -class ExpandDimsInfo : public ActivationOther { - public: - ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~ExpandDimsInfo() override = default; - - protected: - Status GetAttrs() override; - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferMirrorOps() override; - Status InferTensorStrategy(); - - private: - int32_t positive_axis_ = -1; - Strategys inputs_strategy_; - Strategys outputs_strategy_; -}; - -class SqueezeInfo : public ActivationOther { - public: - SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~SqueezeInfo() override = default; - - protected: - Status InferAxis(const ValueTuplePtr &value_tuple); - Status GetAttrs() override; - Status InferReplaceOps(const StrategyPtr &strategy); - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status Init(const StrategyPtr &strategy) override; - - private: - ValueTuplePtr axis_; -}; - -class SquareInfo : public ActivationOther { - public: - SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~SquareInfo() override = default; -}; - -class SigmoidInfo : public ActivationOther { - public: - SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~SigmoidInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.cc b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.cc deleted file mode 100644 index 02c26ea965..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.cc +++ /dev/null @@ -1,363 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/arithmetic_info.h" - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) { - size_t insert_num = bigger_size_shape.size() - smaller_size_shape.size(); - for (size_t num = 0; num < insert_num; ++num) { - (void)smaller_size_shape.insert(smaller_size_shape.begin(), 1); - } - return smaller_size_shape; -} - -Shapes ArithmeticBase::InferExpendShape() { - Shape input_a_shape = inputs_shape_.at(0); - Shape input_b_shape = inputs_shape_.at(1); - Shapes input_shapes; - size_t input_a_size = input_a_shape.size(); - size_t input_b_size = input_b_shape.size(); - if (input_a_size > input_b_size) { - input_shapes.push_back(input_a_shape); - input_shapes.push_back(ExpendShape(input_a_shape, input_b_shape)); - } else if (input_a_size < input_b_size) { - input_shapes.push_back(ExpendShape(input_b_shape, input_a_shape)); - input_shapes.push_back(input_b_shape); - } else { - input_shapes.push_back(input_a_shape); - input_shapes.push_back(input_b_shape); - } - return input_shapes; -} - -std::vector ExpendStrategy(const StrategyPtr &strategy) { - std::vector expend_strategy; - std::vector stra = strategy->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - Dimensions sub_b_strategy = stra.at(1); - size_t input_a_size = sub_a_strategy.size(); - size_t input_b_size = sub_b_strategy.size(); - if (input_a_size > input_b_size) { - expend_strategy.push_back(sub_a_strategy); - expend_strategy.push_back(ExpendShape(sub_a_strategy, sub_b_strategy)); - } else if (input_a_size < input_b_size) { - expend_strategy.push_back(ExpendShape(sub_b_strategy, sub_a_strategy)); - expend_strategy.push_back(sub_b_strategy); - } else { - expend_strategy = stra; - } - return expend_strategy; -} - -Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - Shapes input_shapes = InferExpendShape(); - std::vector expend_strategy = ExpendStrategy(strategy); - Dimensions sub_a_strategy = expend_strategy.at(0); - Dimensions sub_b_strategy = expend_strategy.at(1); - Shape input_a_shape = input_shapes.at(0); - Shape input_b_shape = input_shapes.at(1); - - for (size_t i = 0; i < input_a_shape.size(); ++i) { - if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != 1) && (input_b_shape[i] != 1)) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - } - return SUCCESS; -} - -Status ArithmeticBase::InferDevMatrixShape() { - std::vector expend_strategy = ExpendStrategy(strategy_); - Dimensions sub_a_strategy = expend_strategy.at(0); - Dimensions sub_b_strategy = expend_strategy.at(1); - Shape dev_shape; - for (size_t i = 0; i < sub_a_strategy.size(); ++i) { - if (sub_a_strategy[i] != sub_b_strategy[i]) { - dev_shape.push_back(sub_a_strategy[i] * sub_b_strategy[i]); - } else { - dev_shape.push_back(sub_a_strategy[i]); - } - } - dev_matrix_shape_ = dev_shape; - - return SUCCESS; -} - -TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) { - TensorMap tensor_map_index; - for (size_t i = 0; i < strategy.size(); ++i) { - if (strategy[i] == dev_matrix_shape[i]) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(strategy.size())) - i)); - } else { - tensor_map_index.push_back(-1); - } - } - return tensor_map_index; -} - -TensorMap SetTensorMap(const Shape &strategy_expend, const Shape &dev_matrix_shape, const Shape &strategy) { - TensorMap expend_map = SetExpendTensorMap(strategy_expend, dev_matrix_shape); - size_t dev_matrix_size = dev_matrix_shape.size(); - size_t strategy_size = strategy.size(); - if (dev_matrix_size != strategy_size) { - (void)expend_map.erase(expend_map.begin(), - expend_map.begin() + static_cast(dev_matrix_size - strategy_size)); - } - return expend_map; -} - -void ArithmeticBase::ReComputeBatchSplitFlagList() { - Shapes expend_shapes = InferExpendShape(); - Shape expend_a_shape = expend_shapes.at(0); - Shape expend_b_shape = expend_shapes.at(1); - if (expend_a_shape.size() != expend_b_shape.size()) { - MS_LOG(EXCEPTION) << name_ << " : Recompute batch split flag list is wrong."; - } - if (expend_a_shape.empty()) { - split_flag_list_[0] = false; - split_flag_list_[1] = false; - return; - } - (expend_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false); - (expend_b_shape.at(0) != 1) ? (split_flag_list_[1] = true) : (split_flag_list_[1] = false); -} - -Status ArithmeticBase::InferTensorMap() { - std::vector tensor_map_index; - std::vector expend_strategy = ExpendStrategy(strategy_); - Dimensions sub_a_expend_strategy = expend_strategy.at(0); - Dimensions sub_b_expend_strategy = expend_strategy.at(1); - Strategys stra = strategy_->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - Dimensions sub_b_strategy = stra.at(1); - for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_expend_strategy.size())) - i)); - } - - Shape dev_shape; - for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { - if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) { - dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]); - } else { - dev_shape.push_back(sub_a_expend_strategy[i]); - } - } - inputs_tensor_map_.push_back(SetTensorMap(sub_a_expend_strategy, dev_shape, sub_a_strategy)); - inputs_tensor_map_.push_back(SetTensorMap(sub_b_expend_strategy, dev_shape, sub_b_strategy)); - outputs_tensor_map_.push_back(tensor_map_index); - - return SUCCESS; -} - -Status ArithmeticBase::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_a_tensor_map = inputs_tensor_map_.at(0); - Shape input_b_tensor_map = inputs_tensor_map_.at(1); - std::vector input_a_group, input_b_group; - if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input a failed."; - return FAILED; - } - if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input b failed."; - return FAILED; - } - - OperatorVector op_for_input_a, op_for_input_b; - if (input_a_group.empty() && input_b_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror group is empty."; - return SUCCESS; - } - if (!input_a_group.empty()) { - op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); - } - if (!input_b_group.empty()) { - op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name(); - } - mirror_ops_.push_back(op_for_input_a); - mirror_ops_.push_back(op_for_input_b); - - return SUCCESS; -} - -Status ArithmeticBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, - const Shape &dev_matrix_array) { - if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { - MS_LOG(ERROR) << name_ << " : The layout is null."; - return FAILED; - } - TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0); - TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1); - TensorMap out_tensor_map_array = outputs_tensor_map_.at(0); - Shape input_a_shape_array = inputs_shape_.at(0); - Shape input_b_shape_array = inputs_shape_.at(1); - Shape out_shape_array = outputs_shape_.at(0); - - TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout; - if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed."; - return FAILED; - } - if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed."; - return FAILED; - } - if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed."; - return FAILED; - } - inputs_layout->push_back(input_a_tensor_layout); - inputs_layout->push_back(input_b_tensor_layout); - outputs_layout->push_back(out_tensor_layout); - - return SUCCESS; -} - -Status ArithmeticBase::InferTensorInfo() { - // infer tensor shape - Shape input_a_shape = inputs_shape_.at(0); - Shape input_b_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - std::vector expend_strategy = ExpendStrategy(strategy_); - Dimensions sub_a_expend_strategy = expend_strategy.at(0); - Dimensions sub_b_expend_strategy = expend_strategy.at(1); - Strategys inputs_strategy = strategy_->GetInputDim(); - Shape dev_shape; - for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { - if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) { - dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]); - } else { - dev_shape.push_back(sub_a_expend_strategy[i]); - } - } - Strategys outputs_strategy = {dev_shape}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_a_slice_shape = inputs_slice_shape.at(0); - Shape input_b_slice_shape = inputs_slice_shape.at(1); - Shape output_slice_shape = outputs_slice_shape.at(0); - - // infer tensor layout - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer tensor layout failed."; - return FAILED; - } - - TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape); - TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape); - TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a - inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b - outputs_tensor_info_.push_back(out_tensor_info); // output - - return SUCCESS; -} - -Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status ArithmeticBase::GenerateStrategies(int32_t stage_id) { - Shape input0_split(inputs_shape_[0].size(), 1); - Shape input1_split(inputs_shape_[1].size(), 1); - Shapes splittable_inputs = {input0_split, input1_split}; - - std::vector sp_vector; - is_auto_parallel_ = true; - if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies with broadcast failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; - - size_t success = 0; - for (auto &sp : sp_vector) { - PrintStrategy(sp); - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status ArithmeticBase::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h deleted file mode 100644 index 27caacc30c..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class ArithmeticBase : public OperatorInfo { - public: - ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs, OperatorCostPtr cost) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} - ~ArithmeticBase() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr &) override; - void ReComputeBatchSplitFlagList() override; - - protected: - Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); - Shapes InferExpendShape(); -}; - -class SubInfo : public ArithmeticBase { - public: - SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~SubInfo() override = default; -}; - -class TensorAddInfo : public ArithmeticBase { - public: - TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~TensorAddInfo() override = default; -}; - -class MulInfo : public ArithmeticBase { - public: - MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~MulInfo() override = default; -}; - -class DivInfo : public ArithmeticBase { - public: - DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~DivInfo() override = default; -}; - -class RealDivInfo : public ArithmeticBase { - public: - RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~RealDivInfo() override = default; -}; - -class FloorDivInfo : public ArithmeticBase { - public: - FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~FloorDivInfo() override = default; -}; - -class PowInfo : public ArithmeticBase { - public: - PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~PowInfo() override = default; -}; - -class GreaterInfo : public ArithmeticBase { - public: - GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~GreaterInfo() override = default; -}; - -class AssignSubInfo : public ArithmeticBase { - public: - AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~AssignSubInfo() override = default; -}; - -// All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. -class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { - public: - SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~SigmoidCrossEntropyWithLogitsInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc deleted file mode 100644 index dac3b0a675..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc +++ /dev/null @@ -1,235 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/batch_parallel_info.h" - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" - -namespace mindspore { -namespace parallel { -Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - - int32_t stage = strategy->GetInputStage(); - CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); - dev_num_ = dev_num; - - size_t strategy_size = strategy->GetInputNumber(); - std::vector stra = strategy->GetInputDim(); - for (size_t i = 0; i < strategy_size; ++i) { - Shape sub_strategy = stra.at(i); - size_t strategy_len = sub_strategy.size(); - bool flag = false; - for (size_t j = 0; j < strategy_len; ++j) { - int32_t strategy_value = sub_strategy.at(j); - if (strategy_value > 1) { - if (flag || strategy_value != dev_num_) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : It is not a valid data parallel strategy."; - } else { - MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy."; - } - return FAILED; - } - flag = true; - } - } - } - return SUCCESS; -} - -Status BatchParallelInfo::InferDevMatrixShape() { - dev_matrix_shape_.push_back(dev_num_); - return SUCCESS; -} - -Status BatchParallelInfo::InferMirrorOps() { - mirror_ops_.clear(); - if (g_device_manager->DeviceNum() == 1) { - MS_LOG(INFO) << name_ << " : The device num is 1, no need to create mirror ops."; - return SUCCESS; - } - - MS_LOG(INFO) << name_ << " : Batch parallel input number " << strategy_->GetInputNumber(); - for (size_t i = 0; i < input_value_.size(); i++) { - MS_EXCEPTION_IF_NULL(g_device_manager); - OperatorVector op_vec = CreateMirrorOps(g_device_manager->world_group(), g_device_manager->DeviceNum()); - mirror_ops_.push_back(op_vec); - } - return SUCCESS; -} - -Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; } - -Status BatchParallelInfo::InferTensorMap() { - if (strategy_->GetInputDim()[0][0] != dev_num_) { - MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy."; - return FAILED; - } - for (size_t i = 0; i < inputs_shape_.size(); i++) { - std::vector tensor_map_index; - for (size_t j = 0; j < inputs_shape_[i].size(); ++j) { - if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) { - tensor_map_index.push_back(0); - } else { - tensor_map_index.push_back(MAP_NONE); - } - } - inputs_tensor_map_.push_back(tensor_map_index); - } - for (size_t i = 0; i < outputs_shape_.size(); i++) { - std::vector tensor_map_index; - for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { - if (i == 0 && j == 0) { - tensor_map_index.push_back(0); - } else { - tensor_map_index.push_back(MAP_NONE); - } - } - outputs_tensor_map_.push_back(tensor_map_index); - } - return SUCCESS; -} - -Strategys BatchParallelInfo::GetOutputsStrategy() { - Strategys outputs_strategy; - - for (size_t i = 0; i < outputs_shape_.size(); ++i) { - std::vector strategy; - for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { - if (i == 0 && j == 0) { - strategy.push_back(dev_num_); - } else { - strategy.push_back(1); - } - } - outputs_strategy.push_back(strategy); - } - - return outputs_strategy; -} - -Status BatchParallelInfo::InferTensorInfo() { - for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { - MS_LOG(INFO) << name_ << " : The input size is " << strategy_->GetInputNumber(); - TensorLayout tensor_layout_in; - if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) { - return FAILED; - } - TensorInfo tensor_info_in(tensor_layout_in); - inputs_tensor_info_.push_back(tensor_info_in); - } - for (size_t i = 0; i < outputs_shape_.size(); i++) { - TensorLayout tensor_layout_out; - if (tensor_layout_out.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(i), outputs_shape_.at(i)) != - SUCCESS) { - return FAILED; - } - TensorInfo tensor_info_out(tensor_layout_out); - outputs_tensor_info_.push_back(tensor_info_out); - } - return SUCCESS; -} - -Status BatchParallelInfo::GetAttrs() { return SUCCESS; } - -Status BatchParallelInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { - CheckGlobalDeviceManager(); - is_auto_parallel_ = true; - size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - StrategyPtr sp; - std::vector strategy; - for (size_t i = 0; i < inputs_shape_.size(); i++) { - Shape temp(inputs_shape_[i].size(), 1); - if (split_flag_list_[i]) { - temp[0] = SizeToInt(total_dev_num); - } - strategy.push_back(temp); - } - sp = std::make_shared(stage_id, strategy); - - if (SetCostUnderStrategy(sp) == SUCCESS) { - MS_LOG(INFO) << name_ << " : Successfully generated batch-parallel-strategy."; - PrintStrategy(sp); - } else { - MS_LOG(ERROR) << name_ << " : Generating batch-parallel-strategy failed."; - return FAILED; - } - return SUCCESS; -} - -void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() { - for (size_t i = 0; i < inputs_shape_.size(); i++) { - split_flag_list_[i] = true; - } -} - -Status BatchParallelInfo::InferAsLossDivisor() { - as_loss_divisor_ = 1; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h deleted file mode 100644 index db6cb206d5..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ - -#include -#include -#include -#include -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class BatchParallelInfo : public OperatorInfo { - public: - BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs, OperatorCostPtr cost) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} - BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), - dev_num_(1) {} - - ~BatchParallelInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status GetAttrs() override; - Strategys GetOutputsStrategy(); - Status InferAsLossDivisor() override; - - private: - int32_t dev_num_; -}; - -class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { - public: - SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, - const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; - void ReComputeBatchSplitFlagList() override; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.cc b/mindspore/ccsrc/parallel/ops_info/bias_add_info.cc deleted file mode 100644 index 005edaf7c7..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.cc +++ /dev/null @@ -1,261 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/bias_add_info.h" - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - std::vector stra = strategy->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - Dimensions sub_b_strategy = stra.at(1); - int32_t channel_a_strategy = sub_a_strategy.at(1); - int32_t channel_b_strategy = sub_b_strategy.at(0); - if (channel_a_strategy != channel_b_strategy) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - return SUCCESS; -} - -Status BiasAddInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - dev_matrix_shape_ = sub_a_strategy; - return SUCCESS; -} - -void BiasAddInfo::ReComputeBatchSplitFlagList() { - split_flag_list_[0] = true; - split_flag_list_[1] = false; -} - -Status BiasAddInfo::InferTensorMap() { - TensorMap sub_a_tensor_map; - TensorMap sub_b_tensor_map; - std::vector stra = strategy_->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - size_t sub_a_strategy_size = sub_a_strategy.size(); - for (size_t i = 0; i < sub_a_strategy_size; ++i) { - sub_a_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - i)); - } - sub_b_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - 1)); - - inputs_tensor_map_.push_back(sub_a_tensor_map); - inputs_tensor_map_.push_back(sub_b_tensor_map); - outputs_tensor_map_.push_back(sub_a_tensor_map); - - return SUCCESS; -} - -Status BiasAddInfo::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_a_tensor_map = inputs_tensor_map_.at(0); - Shape input_b_tensor_map = inputs_tensor_map_.at(1); - std::vector input_a_group, input_b_group; - if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input a failed."; - return FAILED; - } - if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input b failed."; - return FAILED; - } - - OperatorVector op_for_input_a, op_for_input_b; - if (input_a_group.empty() && input_b_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror group is empty."; - return SUCCESS; - } - if (!input_a_group.empty()) { - op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); - } - if (!input_b_group.empty()) { - op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name(); - } - mirror_ops_.push_back(op_for_input_a); - mirror_ops_.push_back(op_for_input_b); - - return SUCCESS; -} - -Status BiasAddInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, - const Shape &dev_matrix_array) { - if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { - MS_LOG(ERROR) << name_ << " : The layout is null."; - return FAILED; - } - TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0); - TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1); - TensorMap out_tensor_map_array = outputs_tensor_map_.at(0); - Shape input_a_shape_array = inputs_shape_.at(0); - Shape input_b_shape_array = inputs_shape_.at(1); - Shape out_shape_array = outputs_shape_.at(0); - - TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout; - if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed."; - return FAILED; - } - if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed."; - return FAILED; - } - if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed."; - return FAILED; - } - inputs_layout->push_back(input_a_tensor_layout); - inputs_layout->push_back(input_b_tensor_layout); - outputs_layout->push_back(out_tensor_layout); - - return SUCCESS; -} - -Status BiasAddInfo::InferTensorInfo() { - // infer tensor shape - Shape input_a_shape = inputs_shape_.at(0); - Shape input_b_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {inputs_strategy.at(0)}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_a_slice_shape = inputs_slice_shape.at(0); - Shape input_b_slice_shape = inputs_slice_shape.at(1); - Shape output_slice_shape = outputs_slice_shape.at(0); - - // infer tensor layout - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer tensor layout failed."; - return FAILED; - } - - TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape); - TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape); - TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a - inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b - outputs_tensor_info_.push_back(out_tensor_info); // output - - return SUCCESS; -} - -Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split, input0_split}; - - std::vector sp_vector; - is_auto_parallel_ = true; - Shapes tmp_inputs_shape = {inputs_shape_[0], inputs_shape_[0]}; - Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]}; - if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, &sp_vector) != - SUCCESS) { - return FAILED; - } - MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; - - for (auto &sp : sp_vector) { - std::vector tmp_strategy; - Dimensions input0_strategy = sp->GetInputDim()[0]; - tmp_strategy.push_back(input0_strategy); // input0 - - Dimensions input1_strategy = {input0_strategy.at(1)}; - - // reset the strategy - tmp_strategy.push_back(input1_strategy); // input1 - sp->ResetInputs(tmp_strategy); - } - size_t success = 0; - for (auto &sp : sp_vector) { - PrintStrategy(sp); - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status BiasAddInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status BiasAddInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h deleted file mode 100644 index 37f555a258..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ - -#include - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class BiasAddInfo : public OperatorInfo { - public: - BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~BiasAddInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr &) override; - void ReComputeBatchSplitFlagList() override; - - protected: - Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h deleted file mode 100644 index 8dd2976b04..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ - -#include -#include -#include -#include -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class EqualInfo : public ArithmeticBase { - public: - EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~EqualInfo() override = default; -}; - -class NotEqualInfo : public ArithmeticBase { - public: - NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~NotEqualInfo() override = default; -}; - -class MaximumInfo : public ArithmeticBase { - public: - MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~MaximumInfo() override = default; -}; - -class MinimumInfo : public ArithmeticBase { - public: - MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~MinimumInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc deleted file mode 100644 index e88868c772..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc +++ /dev/null @@ -1,323 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/dropout_do_mask_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "pipeline/resource.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -static int32_t SEED_NUM = 1; - -Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - if (stra.size() != 1) { - MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1"; - return FAILED; - } - - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - // only check the input[0] - Shapes input_shape = {inputs_shape_[0]}; - if (CheckStrategyValue(strategy, input_shape, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy"; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy"; - } - return FAILED; - } - return SUCCESS; -} - -Status DropoutDoMaskInfo::InferDevMatrixShape() { - if (strategy_ == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - - std::vector strategy = strategy_->GetInputDim(); - if (strategy.empty()) { - MS_LOG(ERROR) << name_ << ": The strategy is empty"; - return FAILED; - } - - dev_matrix_shape_ = strategy[0]; - return SUCCESS; -} - -Status DropoutDoMaskInfo::InferTensorMap() { - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - std::vector tensor_map_index; - size_t size = inputs_shape_[0].size(); - // if the dimension of input is 4, and tensor_map_index is [3, 2, 1, 0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back(SizeToInt(size - i - 1)); - } - - // the input[1] do not need tensor map - inputs_tensor_map_.push_back(tensor_map_index); // input_0 - outputs_tensor_map_.push_back(tensor_map_index); // output - return SUCCESS; -} - -Status DropoutDoMaskInfo::InferTensorInfo() { - if (inputs_shape_.size() != 3) { - MS_LOG(ERROR) << name_ << ": Invalid inputs shape size " << inputs_shape_.size(); - return FAILED; - } - - if (strategy_ == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - - Shape input_0_shape = inputs_shape_[0]; - - if (inputs_tensor_map_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty"; - return FAILED; - } - - TensorLayout input_0_tensor_layout; - if (input_0_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_0_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout failed"; - return FAILED; - } - - TensorInfo input_0_tensor_info(input_0_tensor_layout); - - // input_1 do not need tensor info - inputs_tensor_info_.push_back(input_0_tensor_info); // input_0 - outputs_tensor_info_.push_back(input_0_tensor_info); // output - return SUCCESS; -} - -Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - Shapes used_inputs_shape = {inputs_shape_[0]}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, used_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Generate strategies failed"; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy"; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -std::shared_ptr>> DropoutDoMaskInfo::GenerateBatchStrategies() { - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - Dimensions strategy(inputs_shape_[0].size() - 1, 1); - (void)strategy.insert(strategy.begin(), SizeToInt(dev_num)); - std::vector strategy_v = {strategy}; - return std::make_shared>>(strategy_v); -} - -Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; - } - - AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX); - MS_EXCEPTION_IF_NULL(dropout_gen_mask); - if (!dropout_gen_mask->isa()) { - MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode"; - } - - auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); - if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; - } - if (!IsValueNode(dropout_gen_mask_cnode->input(0))) { - MS_LOG(EXCEPTION) << "The input[0] of dropout gen mask cnode is not primitive"; - } - - ValueNodePtr value_node = dropout_gen_mask_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(value_node); - PrimitivePtr prim = value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() != DROPOUT_GEN_MASK) { - MS_LOG(EXCEPTION) << "The primitive name is not DropoutGenMask"; - } - return prim; -} - -void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; - } - - AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX); - MS_EXCEPTION_IF_NULL(dropout_gen_mask); - if (!dropout_gen_mask->isa()) { - MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode."; - } - - auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); - if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; - } - - if (!IsValueNode(dropout_gen_mask_cnode->input(1))) { - MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple."; - } - - FuncGraphPtr func_graph = cnode->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr."; - } - - ValuePtr new_shape = MakeValue(input_slice_shape); - AnfNodePtr val = NewValueNode(new_shape); - (void)manager->Replace(dropout_gen_mask_cnode->input(1), val); -} - -// DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is -// split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape -// of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation -// and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask. -std::vector DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { - std::vector replace_ops; - MS_EXCEPTION_IF_NULL(cnode); - PrimitivePtr prim = GetDropoutGenMaskPrim(cnode); - MS_EXCEPTION_IF_NULL(prim); - - if (inputs_tensor_info_.empty()) { - MS_LOG(EXCEPTION) << "The tensor info of dropout do mask is empty"; - } - - if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; - } - - if (!cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX)->isa()) { - MS_LOG(EXCEPTION) << "The keep prob of dropout do mask is not value node"; - } - - ValuePtr keep_prob = GetValueNode(cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX)); - MS_EXCEPTION_IF_NULL(keep_prob); - auto attr = prim->attrs(); - if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) { - MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1"; - } - - Shape input_slice_shape = inputs_tensor_info_[0].slice_shape(); - int32_t seed_0 = GetValue(attr[SEED0]); - int32_t seed_1 = GetValue(attr[SEED1]); - if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) { - seed_0 = SEED_NUM; - seed_1 = SEED_NUM; - SEED_NUM++; - } else { - SetGenMaskShape(cnode, input_slice_shape); - MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape); - return replace_ops; - } - - ValuePtr new_shape = MakeValue(input_slice_shape); - Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0)); - Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1)); - OperatorAttrs attrs = {attr_0, attr_1}; - Attr param_0 = std::make_pair(SHAPE, new_shape); - Attr param_1 = std::make_pair(KEEP_PROB, keep_prob); - OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)}; - OperatorArgs args = std::make_pair(attrs, params); - Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)}; - replace_ops.push_back(replace_op); - return replace_ops; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h deleted file mode 100644 index c51a0a9513..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class DropoutDoMaskInfo : public OperatorInfo { - public: - DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~DropoutDoMaskInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - std::shared_ptr>> GenerateBatchStrategies() override; - std::vector GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorMap() override; - Status GetAttrs() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; -}; - -using DropoutDoMaskInfoPtr = std::shared_ptr; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h deleted file mode 100644 index 2172c5cd89..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ - -#include -#include -#include -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class ExpInfo : public ActivationOther { - public: - ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~ExpInfo() override = default; -}; - -class LogInfo : public ActivationOther { - public: - LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~LogInfo() override = default; -}; - -class CosInfo : public ActivationOther { - public: - CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~CosInfo() override = default; -}; - -class ACosInfo : public ActivationOther { - public: - ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~ACosInfo() override = default; -}; - -class LogicalNotInfo : public ActivationOther { - public: - LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~LogicalNotInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc deleted file mode 100644 index 078be08128..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc +++ /dev/null @@ -1,350 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/gather_v2_info.h" - -#include -#include -#include - -#include "ir/tensor.h" -#include "ir/value.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/device_matrix.h" -#include "parallel/graph_util/generate_graph.h" -#include "parallel/strategy.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status GatherV2Info::GetAttrs() { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size(); - return FAILED; - } - if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) { - MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size(); - return FAILED; - } - // the second input is the index tensor - - // the third input is the axis, is a ValueNode - if (input_value_.at(2) == nullptr) { - MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; - return FAILED; - } - - if (inputs_shape_.at(0).size() == 0) { - MS_LOG(ERROR) << name_ << ": input can not be a scalar!"; - return FAILED; - } - int axis = GetValue(input_value_.at(2)); - if (axis >= SizeToInt(inputs_shape_.at(0).size()) || axis < 0 - SizeToInt(inputs_shape_.at(0).size())) { - MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", " - << inputs_shape_.at(0).size() << ")."; - } - if (axis < 0) { - axis += SizeToInt(inputs_shape_[0].size()); - } - axis_ = axis; - - index_size_ = inputs_shape_.at(1).size(); - - return SUCCESS; -} - -Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " - << outputs_shape_.size(); - return FAILED; - } - // Only strategy of the first input should be set. - if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - axis_strategy_ = strategy->GetInputDim().at(0).at(axis_); - if (index_size_ != 1 && axis_strategy_ != 1) { - MS_LOG(ERROR) << name_ - << ": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy " - "corresponding to axis must be 1, but is " - << axis_strategy_; - return FAILED; - } - if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) { - MS_LOG(ERROR) << name_ - << ": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to " - "axis. The first dimension of index is " - << inputs_shape_.at(1).at(0) << " strategy corresponding to axis is " << axis_strategy_; - return FAILED; - } - return SUCCESS; -} - -Status GatherV2Info::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - dev_matrix_shape_ = stra.at(0); - return SUCCESS; -} - -// If index is a scalar, output dimension is input dimension minus 1; -// If index is a n dimension tensor, output dimension is input dimension plus (n - 1). -// Tensor map dimension is equal to the corresponding input and output dimension. -// If index's dimension is more than 1, we insert -1 for the output tensor map. -Status GatherV2Info::InferTensorMap() { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " - << outputs_shape_.size(); - return FAILED; - } - std::vector tensor_map_in; - std::vector tensor_map_out; - size_t size = inputs_shape_.at(0).size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_in.push_back(SizeToInt(size - i - 1)); - tensor_map_out.push_back(SizeToInt(size - i - 1)); - } - - if (index_size_ == 0) { - (void)tensor_map_out.erase(tensor_map_out.begin() + axis_); - } else if (index_size_ > 1) { - (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1); - } - if (tensor_map_out.size() != outputs_shape_.at(0).size()) { - MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size() - << " output size is " << outputs_shape_.at(0).size(); - return FAILED; - } - - std::vector tensor_map_in_index; - if (index_size_ >= 1) { - tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1)); - } - for (size_t i = 1; i < index_size_; ++i) { - tensor_map_in_index.push_back(-1); - } - inputs_tensor_map_.emplace_back(std::move(tensor_map_in)); - inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index)); - outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); - return SUCCESS; -} - -Status GatherV2Info::InferTensorInfo() { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " - << outputs_shape_.size(); - return FAILED; - } - if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_tensor_map_.size(); - return FAILED; - } - if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " - << outputs_tensor_map_.size(); - return FAILED; - } - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape input_index_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - - TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || - (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) { - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout); - TensorInfo input_index_info(input_index_layout); - TensorInfo output_tensor_info(output_tensor_layout); - - inputs_tensor_info_.push_back(input_tensor_info); - inputs_tensor_info_.push_back(input_index_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -OperatorVector CreateSubOp(int32_t sub_value) { - OperatorVector ops; - OperatorName operator_name = SUB; - OperatorAttrs operator_attrs; - - std::vector tensor_data = {sub_value}; - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(tensor_data, kInt32); - ValuePtr op_param_value = MakeValue(tensor_ptr); - - Attr op1_param = std::make_pair("", op_param_value); - OperatorParams operator_param = {std::make_pair(op1_param, 2)}; - - OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); - Operator op = std::make_pair(operator_name, operator_args); - ops.push_back(op); - return ops; -} - -Status GatherV2Info::InferTensorSubOps() { - sub_ops_.clear(); - if ((index_size_ == 0) || (axis_strategy_ == 1)) { - return SUCCESS; - } - int32_t mod_n = 1; - for (size_t i = IntToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) { - mod_n *= dev_matrix_shape_.at(i); - } - if ((axis_ >= SizeToInt(dev_matrix_shape_.size())) || axis_ < 0) { - MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ")."; - } - int32_t mod_p = mod_n * dev_matrix_shape_.at(axis_); - int32_t rank = g_device_manager->global_rank(); - int32_t mod_rank = rank % mod_p; - mod_rank = static_cast(mod_rank / mod_n); - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - return FAILED; - } - if ((axis_ >= SizeToInt(inputs_shape_.at(0).size())) || axis_ < 0) { - MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ")."; - } - int32_t sub_value = static_cast(inputs_shape_.at(0).at(axis_) / dev_matrix_shape_.at(axis_)) * mod_rank; - - OperatorVector sub_op; - sub_ops_.emplace_back(std::move(sub_op)); - sub_op = CreateSubOp(sub_value); - sub_ops_.emplace_back(std::move(sub_op)); - return SUCCESS; -} - -Status GatherV2Info::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - Status status = InferTensorSubOps(); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed."; - return status; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status GatherV2Info::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size(" - << outputs_shape_.size() << "is wrong."; - return FAILED; - } - - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -std::shared_ptr>> GatherV2Info::GenerateBatchStrategies() { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - } - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - if (GetAttrs() != SUCCESS) { - MS_LOG(EXCEPTION) << "GetAttrs failed!"; - } - - Dimensions strategy; - if (index_size_ != 1) { - strategy.push_back(1); - } else { - strategy.push_back(SizeToInt(dev_num)); - } - for (size_t i = 1; i < inputs_shape_[0].size(); i++) { - strategy.push_back(1); - } - std::vector strategy_v = {strategy}; - return std::make_shared>>(strategy_v); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h deleted file mode 100644 index f7aeb6a0d9..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -constexpr size_t GATHER_V2_INPUTS_SIZE = 2; -constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1; -constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; -// We now supported limited parallel strategies. -// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of -// the input. -// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. -class GatherV2Info : public OperatorInfo { - public: - GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), - axis_(-1), - index_size_(0), - axis_strategy_(1) {} - ~GatherV2Info() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - std::shared_ptr>> GenerateBatchStrategies() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status GetAttrs() override; - - private: - Status InferTensorSubOps(); - - int32_t axis_; - size_t index_size_; - int32_t axis_strategy_; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc deleted file mode 100644 index 9fb8df0883..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc +++ /dev/null @@ -1,556 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/gather_v2_p_info.h" - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/graph_util/generate_graph.h" - -namespace mindspore { -namespace parallel { -Status GatherV2PInfo::GetAttrs() { - // get axis, the third input is the axis, is a ValueNode - if (input_value_.at(2) == nullptr) { - MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; - return FAILED; - } - auto axis = GetValue(input_value_.at(2)); - // if axis is negative then convert it to positive - auto params_shape = inputs_shape_.at(0); - if (params_shape.size() == 0) { - MS_LOG(ERROR) << name_ << ": params can not be a scalar!"; - return FAILED; - } - if (axis < 0) { - axis += SizeToInt(inputs_shape_[0].size()); - } - axis_ = axis; - - // get target - auto target_iter = attrs_.find(TARGET); - if (target_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(target_iter->second); - if (target_iter->second->isa()) { - target_ = target_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of target is not a string."; - return FAILED; - } - } - - // target=CPU, axis must be 0 - if (target_ == "CPU" && axis_ != 0) { - MS_LOG(ERROR) << name_ << ": target is CPU, axis must be 0, but got " << axis_; - return FAILED; - } - - return SUCCESS; -} - -Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - // param slice shape need 32Byte aligned - auto param_shape = inputs_shape_.at(0); - auto param_strategy = strategy->GetInputDim().at(0); - auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1); - if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) { - MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned."; - return FAILED; - } - - // only support 1-dim and 2-dim param - if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) { - MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size(); - return FAILED; - } - - // don't support scalar index - if (inputs_shape_.at(1).size() == 0) { - MS_LOG(DEBUG) << name_ << ": Don't support scalar index."; - return FAILED; - } - - // axis=0, index_shape(0)%param_strategy(0) must be 0 - Shape index_shape = inputs_shape_.at(1); - if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) { - MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0)."; - return FAILED; - } - - // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 - if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) { - MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; - return FAILED; - } - - // param_strategy(axis) != 1, index can't be splited - auto index_strategy = strategy->GetInputDim().at(1); - auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); - if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) { - MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited."; - return FAILED; - } - - // param_strategy(axis) != 1, Don't support repeated calc - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); - if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc."; - return FAILED; - } - - return SUCCESS; -} - -Status GatherV2PInfo::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_a_tensor_map = inputs_tensor_map_.at(0); - std::vector input_a_group; - if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input a failed."; - return FAILED; - } - - OperatorVector op_for_input_a, op_for_input_b, op_for_axis; - if (input_a_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror group is empty."; - return SUCCESS; - } else { - op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); - } - - mirror_ops_.push_back(op_for_input_a); - mirror_ops_.push_back(op_for_input_b); - mirror_ops_.push_back(op_for_axis); - - return SUCCESS; -} - -Status GatherV2PInfo::InferDevMatrixShape() { - dev_matrix_shape_.clear(); - out_dev_matrix_shape_.clear(); - // infer input dev_matrix_shape - auto param_strategy = strategy_->GetInputDim().at(0); - auto index_strategy = strategy_->GetInputDim().at(1); - dev_matrix_shape_ = param_strategy; - - // param_strategy(axis)!=1, - if (param_strategy.at(IntToSize(axis_)) != 1) { - std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end()); - } else { - dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end()); - } - - // infer out dev_matrix_shape - // axis!=0, split axis - if (axis_ != 0 && param_strategy.at(IntToSize(axis_)) != 1) { - out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(IntToSize(axis_))); - for (size_t i = 1; i < param_strategy.size(); ++i) { - if (i == IntToSize(axis_)) { - out_dev_matrix_shape_.push_back(1); - } else { - out_dev_matrix_shape_.push_back(param_strategy.at(i)); - } - } - } else { - out_dev_matrix_shape_ = dev_matrix_shape_; - } - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); - auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); - if (param_product * index_product < SizeToInt(dev_num)) { - out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), SizeToInt(dev_num / (param_product * index_product))); - } - - return SUCCESS; -} - -Status GatherV2PInfo::InferTensorMap() { - // infer input tensor map - // param_strategy(axis) != 1 - size_t param_size = inputs_shape_.at(0).size(); - size_t index_size = inputs_shape_.at(1).size(); - size_t total_size = param_size + index_size; - std::vector tensor_map_index; - std::vector tensor_map_params; - auto param_strategy = strategy_->GetInputDim().at(0); - if (param_strategy.at(IntToSize(axis_)) != 1) { - tensor_map_index.insert(tensor_map_index.begin(), index_size, -1); - for (size_t i = 0; i < param_size; ++i) { - tensor_map_params.push_back(SizeToInt(i)); - } - } else { - // param_strategy(axis) == 1 - for (size_t i = 0; i < param_size; ++i) { - tensor_map_params.push_back(SizeToInt(total_size - i - 1)); - } - for (size_t i = 0; i < index_size; ++i) { - tensor_map_index.push_back(SizeToInt(index_size - i - 1)); - } - } - - // infer output tensor map - std::vector tensor_map_out; - if (param_strategy.at(IntToSize(axis_)) == 1) { - // param_strategy(axis) == 1 - for (size_t i = 0; i < param_size; ++i) { - if (i == IntToSize(axis_)) { - for (size_t j = 0; j < index_size; ++j) { - tensor_map_out.push_back(SizeToInt(index_size - j - 1)); - } - } else { - tensor_map_out.push_back(SizeToInt(total_size - i - 1)); - } - } - } else { - // param_strategy(axis) != 1 - if (axis_ == 0) { - tensor_map_out.insert(tensor_map_out.end(), 0); - tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); - for (size_t i = 1; i < param_size; ++i) { - tensor_map_out.push_back(i); - } - } else { - for (size_t i = 0; i < param_size; ++i) { - if (i == IntToSize(axis_)) { - tensor_map_out.insert(tensor_map_out.end(), index_size, -1); - } else { - tensor_map_out.push_back(SizeToInt(param_size - i - 1)); - } - } - } - } - - inputs_tensor_map_.emplace_back(std::move(tensor_map_params)); - inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); - outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); - return SUCCESS; -} - -Status GatherV2PInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape input_index_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - // infer tensor layout - TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || - (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != - SUCCESS)) { - return FAILED; - } - // infer tensor info - TensorInfo input_tensor_info(input_tensor_layout); - TensorInfo input_index_info(input_index_layout); - TensorInfo output_tensor_info(output_tensor_layout); - - inputs_tensor_info_.push_back(input_tensor_info); - inputs_tensor_info_.push_back(input_index_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status GatherV2PInfo::InferBias() { - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - auto input_shape = inputs_shape_.at(0); - auto params_strategy = strategy_->GetInputDim().at(0); - // axis don't split - if (params_strategy.at(axis_) == 1) { - bias_ = 0; - return SUCCESS; - } - // params_size=1, axis=0 - if ((input_shape.size() == 1) && (axis_ == 0)) { - slice_size_ = input_shape.at(0) / params_strategy.at(0); - bias_ = rank * slice_size_; - return SUCCESS; - } - // params_size=2, axis=0 - if ((input_shape.size() == 2) && (axis_ == 0)) { - slice_size_ = input_shape.at(0) / params_strategy.at(0); - bias_ = rank / params_strategy.at(1) * slice_size_; - return SUCCESS; - } - // params_size=2, axis=1 - if ((input_shape.size() == 2) && (axis_ == 1)) { - slice_size_ = input_shape.at(1) / params_strategy.at(1); - bias_ = rank % params_strategy.at(1) * slice_size_; - return SUCCESS; - } - MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_; - return FAILED; -} - -Status GatherV2PInfo::InferGroup() { - auto param_strategy = strategy_->GetInputDim().at(0); - size_t dim = IntToSize(axis_); - if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) { - dim = (axis_ + 1) % 2; - } - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - int32_t rank = g_device_manager->global_rank(); - RankList dev_list = g_device_manager->GetDeviceListByStageId(0); - DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_); - RankList group_devices; - if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Create group failed."; - return FAILED; - } - if (group_devices.size() == 1) { - MS_LOG(INFO) << "the group is empty"; - return SUCCESS; - } - - group_ = g_device_manager->CreateGroup(group_devices); - return SUCCESS; -} - -std::vector GetRankFromGroup(const Group &group) { - std::vector rank_list; - auto device_list = group.GetDevicesList(); - for (auto &device : device_list) { - rank_list.insert(rank_list.end(), device.rank() % 8); - } - return rank_list; -} - -Status GatherV2PInfo::InferForwardCommunication() { - forward_op_.clear(); - if (target_ != CPU) { - return SUCCESS; - } - auto param_strategy = strategy_->GetInputDim().at(0); - // don't split axis, no need forward communication - if (param_strategy.at(IntToSize(axis_)) == 1) { - return SUCCESS; - } - // split axis - OperatorName operator_name; - if (InferGroup() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Group failed."; - return FAILED; - } - auto group_size = group_.GetDevNum(); - Attr attr_group; - if (host_reduce_scatter_) { - // group size <= 8 - std::vector rank_list; - if (group_size <= 8) { - reduce_scatter_flag_ = false; - operator_name = HOST_REDUCE_SCATTER; - rank_list = GetRankFromGroup(group_); - attr_group = std::make_pair(GROUP, MakeValue(rank_list)); - } else { - // group size > 8, don't support host reduce_scatter - reduce_scatter_flag_ = true; - split_num_ = SizeToInt(group_size / 8); - CheckGlobalDeviceManager(); - operator_name = REDUCE_SCATTER; - int32_t rank = g_device_manager->global_rank(); - size_t repeat = group_size / 8; - for (size_t i = 0; i < repeat; ++i) { - rank_list.push_back(rank + SizeToInt(i * 8)); - } - Group g = g_device_manager->CreateGroup(rank_list); - attr_group = std::make_pair(GROUP, MakeValue(g.name())); - } - } else { - operator_name = REDUCE_SCATTER; - if (InferGroup() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Group failed."; - return FAILED; - } - attr_group = std::make_pair(GROUP, MakeValue(group_.name())); - } - Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); - OperatorAttrs attrs = {attr_op, attr_group}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - Operator op = std::make_pair(operator_name, args); - - forward_op_.push_back(op); - return SUCCESS; -} - -Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { - GenerateGraph gen_g = GenerateGraph(); - if (gen_g.Init(cnode) != SUCCESS) { - MS_LOG(ERROR) << "GenerateGraph Init failed"; - return FAILED; - } - if (InferBias() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Bias failed."; - return FAILED; - } - auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)}); - auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); - auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)}); - auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum}); - auto gather_v2 = - gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)}); - auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2}); - auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype}); - auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)}); - auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims}); - // don't need expandim,if param_size = 1, - if (inputs_shape_.at(0).size() == 1) { - mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast}); - } - if (InferGroup() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Group failed."; - return FAILED; - } - Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); - Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); - OperatorAttrs attrs = {attr_op, attr_group}; - auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); - std::vector> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; - replace_graph_ = std::make_shared>, AnfNodePtr>>( - std::make_pair(input_nodes, reduce_scatter)); - - return SUCCESS; -} - -ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { - auto param_strategy = strategy_->GetInputDim().at(0); - // target_ == CPU, no need to raplace graph - if (target_ == CPU) { - return nullptr; - } - if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return nullptr; - } - return replace_graph_; -} - -Status GatherV2PInfo::ComputeReplaceOp() { - if (InferBias() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer offset failed."; - return FAILED; - } - OperatorName op_name = EMBEDDING_LOOKUP; - OperatorAttrs attrs; - Attr param_offset = std::make_pair("offset", MakeValue(bias_)); - Attr param_flag = std::make_pair("reduce_scatter_flag", MakeValue(reduce_scatter_flag_)); - Attr param_split_num = std::make_pair("split_num", MakeValue(split_num_)); - OperatorParams params = {std::make_pair(param_offset, 3), std::make_pair(param_flag, 4), - std::make_pair(param_split_num, 5)}; - OperatorArgs args = std::make_pair(attrs, params); - Operator op = std::make_pair(op_name, args); - replace_op_.push_back(op); - - return SUCCESS; -} - -Status GatherV2PInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - // only target_ == CPU, we need to replace op - if (target_ == CPU && ComputeReplaceOp() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed."; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - auto param_strategy = strategy_->GetInputDim().at(0); - // cost model set axis and strategy - auto gatherv2_2cost = std::dynamic_pointer_cast(operator_cost()); - gatherv2_2cost->set_axis(axis_); - gatherv2_2cost->set_strategy(param_strategy); - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shape input1_split(inputs_shape_[1].size(), 1); - Shapes splittable_inputs = {input0_split, input1_split}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -std::shared_ptr>> GatherV2PInfo::GenerateBatchStrategies() { - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - Dimensions param_strategy(inputs_shape_[0].size(), 1); - Dimensions index_strategy; - index_strategy.push_back(SizeToInt(dev_num)); - for (size_t i = 1; i < inputs_shape_[1].size(); i++) { - index_strategy.push_back(1); - } - std::vector strategy_v = {param_strategy, index_strategy}; - return std::make_shared>>(strategy_v); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h deleted file mode 100644 index 83868606d1..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class GatherV2PInfo : public OperatorInfo { - public: - GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), - axis_(0), - bias_(0), - slice_size_(0) {} - ~GatherV2PInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; - std::shared_ptr>> GenerateBatchStrategies() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status GetAttrs() override; - - private: - Status ComputeReplaceGraph(const CNodePtr &cnode); - Status ComputeReplaceOp(); - Status InferBias(); - Status InferGroup(); - - int32_t axis_; - std::string target_; - std::string replace_op_name_ = GATHERV2; - int32_t bias_; - int32_t slice_size_; - Shape out_dev_matrix_shape_; - Group group_; - bool reduce_scatter_flag_ = false; - int32_t split_num_ = 1; - bool host_reduce_scatter_ = false; -}; - -class SparseGatherV2Info : public GatherV2PInfo { - public: - SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} - ~SparseGatherV2Info() override = default; - - private: - std::string replace_op_name_ = SPARSE_GATHERV2; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc deleted file mode 100644 index 0fb49364f0..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc +++ /dev/null @@ -1,269 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/get_next_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/context.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status GetNextInfo::InferTensorMap() { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - - for (auto shp : shapes_) { - TensorMap out_tensor_map; - for (size_t i = 0; i < shp.size(); ++i) { - if (full_batch) { - out_tensor_map.push_back(MAP_NONE); - } else { - out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1)); - } - } - outputs_tensor_map_.push_back(out_tensor_map); - } - return SUCCESS; -} - -Status GetNextInfo::InferTensorLayout(TensorLayouts *outputs_layout) { - if (outputs_layout == nullptr) { - MS_LOG(ERROR) << name_ << " : The layout is null."; - return FAILED; - } - for (size_t i = 0; i < outputs_shape_.size(); ++i) { - TensorLayout output_layout; - if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) { - return FAILED; - } - outputs_layout->push_back(output_layout); - } - return SUCCESS; -} - -Strategys GetNextInfo::GetOutputStrategy() { - Strategys outputs_strategy; - for (auto shp : shapes_) { - Dimensions out_strategy; - out_strategy.push_back(dev_num_); - for (size_t i = 1; i < shp.size(); ++i) { - out_strategy.push_back(1); - } - outputs_strategy.push_back(out_strategy); - } - return outputs_strategy; -} - -Status GetNextInfo::InferTensorInfo() { - TensorLayouts outputs_layout; - if (InferTensorLayout(&outputs_layout) != SUCCESS) { - return FAILED; - } - for (size_t i = 0; i < outputs_shape_.size(); ++i) { - TensorInfo output_tensor_info(outputs_layout[i]); - outputs_tensor_info_.push_back(output_tensor_info); - } - return SUCCESS; -} - -Status GetNextInfo::InferDevMatrixShape() { - size_t max_shape_length = 0; - for (auto shp : shapes_) { - if (max_shape_length < shp.size()) { - max_shape_length = shp.size(); - } - } - if (max_shape_length == 0) { - MS_LOG(ERROR) << name_ << " : shape is 0"; - } - dev_matrix_shape_.push_back(dev_num_); - for (size_t i = 1; i < max_shape_length; ++i) { - dev_matrix_shape_.push_back(1); - } - return SUCCESS; -} - -Status GetNextInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed"; - return FAILED; - } - if (InferReplaceOps(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer replace Ops failed"; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init success"; - return SUCCESS; -} - -Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { - std::vector stras = strategy->GetInputDim(); - for (Dimensions stra : stras) { - if (stra.size() != 0) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - } - int32_t stage = strategy->GetInputStage(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); - dev_num_ = dev_num; - return SUCCESS; -} - -Status GetNextInfo::GetAttrTypes() { - auto iter = attrs_.find(TYPES); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - auto iter_cast = iter->second->cast(); - MS_EXCEPTION_IF_NULL(iter_cast); - auto types = iter_cast->value(); - for (auto &type : types) { - MS_EXCEPTION_IF_NULL(type); - types_.push_back(type->ToString()); - } - } else if (iter->second->isa()) { - auto iter_cast = iter->second->cast(); - MS_EXCEPTION_IF_NULL(iter_cast); - auto types = iter_cast->value(); - for (auto &type : types) { - MS_EXCEPTION_IF_NULL(type); - types_.push_back(type->ToString()); - } - } else { - MS_LOG(ERROR) << name_ << " : The value of types is not list."; - return FAILED; - } - } - return SUCCESS; -} - -Status GetNextInfo::GetAttrShapes() { - shapes_ = outputs_shape_; - if (shapes_.size() == 0) { - MS_LOG(ERROR) << name_ << " : Shape is None."; - return FAILED; - } - return SUCCESS; -} - -Status GetNextInfo::GetAttrOutPutNum() { - auto iter = attrs_.find(GETNEXT_NUM); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - output_num_ = iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of output_num is not int."; - return FAILED; - } - } - return SUCCESS; -} - -Status GetNextInfo::GetAttrs() { - if (GetAttrTypes() == FAILED || GetAttrShapes() == FAILED || GetAttrOutPutNum() == FAILED) { - return FAILED; - } - if (types_.size() != IntToSize(output_num_) || shapes_.size() != IntToSize(output_num_) || output_num_ == 0) { - MS_LOG(ERROR) << name_ << " : The output_num is not equal to shapes size."; - return FAILED; - } - return SUCCESS; -} - -Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - - Shapes out_shapes = outputs_shape_; - for (size_t i = 0; i < out_shapes.size(); ++i) { - if (dev_num_ <= 0) { - MS_LOG(ERROR) << name_ << " : The dev num is 0."; - return FAILED; - } - if (out_shapes[i][0] % dev_num_ != 0) { - MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num."; - return FAILED; - } - if (!full_batch) { - out_shapes[i][0] = out_shapes[i][0] / dev_num_; - } - } - ValuePtr new_shapes = MakeValue(out_shapes); - Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]); - Attr attr_shapes = std::make_pair(SHAPES, new_shapes); - Attr attr_num = std::make_pair(GETNEXT_NUM, attrs_[GETNEXT_NUM]); - Attr attr_shared_name = std::make_pair(SHARED_NAME, attrs_[SHARED_NAME]); - OperatorAttrs attrs = {attr_types, attr_shapes, attr_num, attr_shared_name}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - replace_op_ = {std::make_pair(GET_NEXT, args)}; - return SUCCESS; -} - -Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -Status GetNextInfo::GenerateStrategies(int32_t stage_id) { - is_auto_parallel_ = true; - std::vector stra; - StrategyPtr sp = std::make_shared(stage_id, stra); - if (SetCostUnderStrategy(sp) == SUCCESS) { - MS_LOG(INFO) << name_ << " : Successfully generated strategy."; - PrintStrategy(sp); - } else { - MS_LOG(ERROR) << name_ << " : Generating strategy failed."; - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/parallel/ops_info/get_next_info.h deleted file mode 100644 index ba209910b7..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ - -#include -#include -#include -#include - -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class GetNextInfo : public OperatorInfo { - public: - GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~GetNextInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *outputs_layout); - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferReplaceOps(const StrategyPtr &strategy); - Status GetAttrTypes(); - Status GetAttrShapes(); - Status GetAttrOutPutNum(); - Strategys GetOutputStrategy(); - Status InferAsLossDivisor() override { return SUCCESS; } - - private: - int32_t dev_num_ = 1; - std::vector types_; - Shapes shapes_; - int32_t output_num_ = 0; - std::string shared_name_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc deleted file mode 100644 index 8716997d9f..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/l2_normalize_info.h" - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(INFO) << name_ << " : Init success."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - Dimensions input_strategy = stra.at(0); - int32_t axis_index = axis_; - if (axis_ < 0) { - size_t input_dim = inputs_shape_.at(0).size(); - axis_index = static_cast(input_dim) + axis_; - } - - if (input_strategy[IntToSize(axis_index)] != 1) { - MS_LOG(ERROR) << name_ << " : The dim " << axis_index << " of input strategy must be 1."; - return FAILED; - } - - return SUCCESS; -} - -Status L2NormalizeInfo::GetAttrs() { - auto iter = attrs_.find(AXIS); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - axis_ = iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of axis is not int."; - return FAILED; - } - } - - return SUCCESS; -} - -Status L2NormalizeInfo::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_tensor_map = inputs_tensor_map_.at(0); - std::vector input_group; - if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group failed."; - return FAILED; - } - - OperatorVector op_for_weight; - if (input_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror ops is empty."; - return SUCCESS; - } else { - op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); - mirror_ops_.push_back(op_for_weight); - MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group is " << input_group[0].name(); - } - - return SUCCESS; -} - -Status L2NormalizeInfo::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << " : GetAttrs failed."; - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size() - 1, 1); - int32_t axis_index = axis_; - if (axis_ < 0) { - size_t input_dim = inputs_shape_.at(0).size(); - axis_index = static_cast(input_dim) + axis_; - } - (void)input0_split.insert(input0_split.begin() + axis_index, 0); - Shapes splittable_inputs = {input0_split}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h deleted file mode 100644 index ca063d01d8..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class L2NormalizeInfo : public Activation { - public: - L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Activation(name, inputs_shape, outputs_shape, attrs) {} - ~L2NormalizeInfo() override = default; - Status GenerateStrategies(int32_t stage_id) override; - - protected: - Status GetAttrs() override; - Status InferMirrorOps() override; - Status CheckStrategy(const StrategyPtr &strategy) override; - - private: - int32_t axis_ = 0; // Default value = 0 -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc deleted file mode 100644 index 5bdd24090f..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc +++ /dev/null @@ -1,324 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/layer_norm_info.h" -#include -#include -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -Status LayerNormInfo::GetAttrs() { - auto iter = attrs_.find(BEGIN_NORM_AXIS); - if (iter == attrs_.end()) { - MS_LOG(ERROR) << name_ << ": Can not find the attr of begin norm axis"; - return FAILED; - } - if ((iter->second == nullptr) || !iter->second->isa()) { - MS_LOG(ERROR) << name_ << ": The axis type is not int"; - return FAILED; - } - - int32_t dim = SizeToInt(input_shape_.size()); - auto axis = GetValue(iter->second); - if ((axis >= dim) || (axis < -dim)) { - MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim << ", " << dim - 1 << "]"; - return FAILED; - } - - if (axis < 0) { - axis = axis + dim; - } - begin_norm_axis_ = IntToSize(axis); - return SUCCESS; -} - -Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { - MS_EXCEPTION_IF_NULL(strategy); - std::vector stra = strategy->GetInputDim(); - if (stra.size() != LAYER_NORM_INPUT_SIZE) { - MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size(); - return FAILED; - } - - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy value"; - return FAILED; - } - - Dimensions input_strategy = stra[LAYER_NORM_INPUT_INDEX]; - Dimensions gamma_strategy = stra[LAYER_NORM_GAMMA_INDEX]; - Dimensions beta_strategy = stra[LAYER_NORM_BETA_INDEX]; - if (begin_norm_axis_ >= input_strategy.size()) { - MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_; - return FAILED; - } - // check input strategy - for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) { - if (input_strategy[i] != NO_SPLIT_STRATEGY) { - MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy); - return FAILED; - } - } - - // check gamma and beta strategy - if ((gamma_strategy.size() > input_strategy.size()) || (beta_strategy.size() > input_strategy.size())) { - MS_LOG(ERROR) << name_ << " : The strategy size of gamma or beta is lager than input strategy"; - return FAILED; - } - - size_t gamma_diff = input_strategy.size() - gamma_strategy.size(); - for (size_t j = 0; j < gamma_strategy.size(); ++j) { - if (gamma_strategy[j] != input_strategy[gamma_diff + j]) { - MS_LOG(ERROR) << name_ << ": Invalid gamma strategy " << ShapeToString(gamma_strategy); - return FAILED; - } - } - - size_t beta_diff = input_strategy.size() - beta_strategy.size(); - for (size_t k = 0; k < beta_strategy.size(); ++k) { - if (beta_strategy[k] != input_strategy[beta_diff + k]) { - MS_LOG(ERROR) << name_ << ": Invalid beta strategy " << ShapeToString(beta_strategy); - return FAILED; - } - } - return SUCCESS; -} - -Status LayerNormInfo::InferDevMatrixShape() { - if (strategy_ == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - std::vector stra = strategy_->GetInputDim(); - if (stra.empty()) { - MS_LOG(ERROR) << name_ << ": The strategy is empty"; - return FAILED; - } - dev_matrix_shape_ = stra[0]; - return SUCCESS; -} - -Status LayerNormInfo::CreateTensorMap(size_t input_index) { - if (inputs_shape_.size() <= input_index) { - MS_LOG(ERROR) << name_ << ": Invalid index" << input_index; - return FAILED; - } - Shape shape = inputs_shape_[input_index]; - Shape tensor_map; - for (size_t i = 0; i < shape.size(); ++i) { - tensor_map.push_back(SizeToInt(shape.size() - i - 1)); - } - inputs_tensor_map_.push_back(tensor_map); - outputs_tensor_map_.push_back(tensor_map); - return SUCCESS; -} - -Status LayerNormInfo::InferTensorMap() { - if ((CreateTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || - (CreateTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) { - MS_LOG(ERROR) << name_ << ": Create tensor map failed"; - return FAILED; - } - return SUCCESS; -} - -Status LayerNormInfo::CreateMirrorOp(size_t input_index) { - if (inputs_tensor_map_.size() <= input_index) { - MS_LOG(ERROR) << name_ << ": Invalid index " << input_index; - return FAILED; - } - Shape tensor_map = inputs_tensor_map_[input_index]; - std::vector group; - if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input " << input_index << " failed"; - return FAILED; - } - OperatorVector mirror_op; - if (!group.empty()) { - mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input " << input_index << " success, group is " - << group[0].name(); - } - mirror_ops_.push_back(mirror_op); - return SUCCESS; -} - -Status LayerNormInfo::InferMirrorOps() { - if ((CreateMirrorOp(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateMirrorOp(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || - (CreateMirrorOp(LAYER_NORM_BETA_INDEX) != SUCCESS)) { - MS_LOG(ERROR) << name_ << ": Create mirror op failed"; - return FAILED; - } - return SUCCESS; -} - -Status LayerNormInfo::CreateTensorInfo(size_t input_index) { - if ((inputs_shape_.size() <= input_index) || (inputs_tensor_map_.size() <= input_index)) { - MS_LOG(ERROR) << name_ << ": Invalid input index" << input_index; - return FAILED; - } - Shape tensor_map = inputs_tensor_map_[input_index]; - Shape shape = inputs_shape_[input_index]; - TensorLayout tensor_layout; - if (tensor_layout.InitFromVector(dev_matrix_shape_, tensor_map, shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout for input " << input_index << " failed"; - return FAILED; - } - - TensorInfo tensor_info(tensor_layout); - inputs_tensor_info_.push_back(tensor_info); - outputs_tensor_info_.push_back(tensor_info); - return SUCCESS; -} - -Status LayerNormInfo::InferTensorInfo() { - if ((CreateTensorInfo(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorInfo(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || - (CreateTensorInfo(LAYER_NORM_BETA_INDEX) != SUCCESS)) { - MS_LOG(ERROR) << name_ << ": Create tensor info failed"; - return FAILED; - } - return SUCCESS; -} - -Status LayerNormInfo::InferAsLossDivisor() { - if (outputs_tensor_map_.size() != LAYER_NORM_INPUT_SIZE) { - MS_LOG(ERROR) << name_ << ": The size of outputs tensor map " << outputs_tensor_map_.size() << " is error"; - return FAILED; - } - as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); - MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_) - << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0]) - << ", as_loss_divisor_ is " << as_loss_divisor_; - return SUCCESS; -} - -Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost failed"; - return FAILED; - } - return SUCCESS; -} - -Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector &sp_vector) { - if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) { - MS_LOG(ERROR) << name_ << ": The dimension of gamma or beta is lager than input"; - return FAILED; - } - - size_t gamma_diff = input_shape_.size() - gamma_shape_.size(); - size_t beta_diff = input_shape_.size() - beta_shape_.size(); - for (auto &sp : sp_vector) { - if ((sp == nullptr) || sp->GetInputDim().empty()) { - MS_LOG(ERROR) << name_ << ": Invalid strategy"; - return FAILED; - } - std::vector tmp_strategy; - Dimensions input_strategy = sp->GetInputDim()[0]; - Dimensions gamma_strategy = input_strategy; - (void)gamma_strategy.erase(gamma_strategy.begin(), - gamma_strategy.begin() + static_cast(gamma_diff)); - Dimensions beta_strategy = input_strategy; - (void)beta_strategy.erase(beta_strategy.begin(), beta_strategy.begin() + static_cast(beta_diff)); - - // reset the strategy - tmp_strategy.push_back(input_strategy); - tmp_strategy.push_back(gamma_strategy); - tmp_strategy.push_back(beta_strategy); - sp->ResetInputs(tmp_strategy); - } - return SUCCESS; -} - -Status LayerNormInfo::GenerateStrategies(int32_t stage_id) { - if (InitShapes() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init shapes failed"; - return FAILED; - } - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Get attrs failed"; - return FAILED; - } - Shape input_split(input_shape_.size(), SPLIT_FLAG); - if (begin_norm_axis_ >= input_split.size()) { - MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_; - return FAILED; - } - - // Can not split the dimensions from begin norm axis - for (size_t i = begin_norm_axis_; i < input_split.size(); ++i) { - input_split[i] = NO_SPLIT_FLAG; - } - - // Generate strategy for input - Shapes splittable_inputs = {input_split}; - Shapes tmp_inputs_shape = {input_shape_}; - std::vector sp_vector; - is_auto_parallel_ = true; - if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Generate input strategy failed"; - return FAILED; - } - - // Generate the strategies for gamma and beta - if (GenerateGammaAndBetaStrategies(sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Generate gamma and beta strategies failed"; - return FAILED; - } - - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(DEBUG) << name_ << ": Successfully generated " << success << " strategy"; - } - } - return SUCCESS; -} - -Status LayerNormInfo::InitShapes() { - if (inputs_shape_.size() != LAYER_NORM_INPUT_SIZE) { - MS_LOG(ERROR) << name_ << ": Invalid inputs size"; - return FAILED; - } - input_shape_ = inputs_shape_[LAYER_NORM_INPUT_INDEX]; - gamma_shape_ = inputs_shape_[LAYER_NORM_GAMMA_INDEX]; - beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX]; - return SUCCESS; -} - -Status LayerNormInfo::Init(const StrategyPtr &strategy) { - if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy)) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed"; - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init success"; - return SUCCESS; -} - -Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) { - if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) { - MS_LOG(ERROR) << name_ << ": Init for cost model failed"; - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success"; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h deleted file mode 100644 index 50117b8185..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ - -#include -#include -#include -#include -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -constexpr size_t LAYER_NORM_INPUT_SIZE = 3; -constexpr size_t LAYER_NORM_INPUT_INDEX = 0; -constexpr size_t LAYER_NORM_GAMMA_INDEX = 1; -constexpr size_t LAYER_NORM_BETA_INDEX = 2; -constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis"; - -// The dimensions of input tensor starting from begin norm axis cannot be split. Other dimensions can be split -// arbitrarily. Gamma and beta should match input to meet the broadcast requirements of mul and add. -class LayerNormInfo : public OperatorInfo { - public: - LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(true)), - begin_norm_axis_(0) {} - ~LayerNormInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr &) override; - - protected: - Status GetAttrs() override; - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferAsLossDivisor() override; - Status CreateTensorMap(size_t input_index); - Status CreateTensorInfo(size_t input_index); - Status CreateMirrorOp(size_t input_index); - Status GenerateGammaAndBetaStrategies(const std::vector &sp_vector); - Status InitShapes(); - - private: - size_t begin_norm_axis_; - Shape input_shape_; - Shape gamma_shape_; - Shape beta_shape_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/parallel/ops_info/loss_info.cc deleted file mode 100644 index 0ba325c0cd..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.cc +++ /dev/null @@ -1,232 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/loss_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - Dimensions input_strategy = stra.at(0); - Dimensions label_strategy = stra.at(1); - if (input_strategy != label_strategy) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } - - int32_t axis_index = axis_; - if (axis_ < 0) { - size_t input_dim = inputs_shape_.at(0).size(); - axis_index = static_cast(input_dim) + axis_; - } - - int32_t input_axis_strategy = input_strategy.at(IntToSize(axis_index)); - int32_t label_axis_strategy = label_strategy.at(IntToSize(axis_index)); - // Dimension corresponding to axis is un-splittable - if ((input_axis_strategy != MIN_SLICE_NUM) && (label_axis_strategy != MIN_SLICE_NUM)) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ - << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy - << ", label: " << label_axis_strategy; - } else { - MS_LOG(ERROR) << name_ - << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy - << ", label: " << label_axis_strategy; - } - return FAILED; - } - - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() { - if ((inputs_shape_.size() != SoftmaxCrossEntropyWithLogitsInputsSize) || - (outputs_shape_.size() != SoftmaxCrossEntropyWithLogitsOutputsSize)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; - return FAILED; - } - - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - dev_matrix_shape_ = input_strategy; - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() { - std::vector tensor_map_index; - size_t size = inputs_shape_[0].size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - i - 1)); - } - - std::vector first_output_tensor_map = {tensor_map_index[0]}; - inputs_tensor_map_.push_back(tensor_map_index); // input - inputs_tensor_map_.push_back(tensor_map_index); // label - outputs_tensor_map_.push_back(first_output_tensor_map); // output-0 - outputs_tensor_map_.push_back(tensor_map_index); // output-1 - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape first_output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {{inputs_strategy[0][0]}, inputs_strategy.at(0)}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - Shape first_output_slice_shape = outputs_slice_shape.at(0); - - TensorMap input_tensor_map = inputs_tensor_map_.at(0); - TensorMap first_output_tensor_map = outputs_tensor_map_.at(0); - - TensorLayout input_tensor_layout, first_output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, input_tensor_map, input_shape) != SUCCESS) || - (first_output_tensor_layout.InitFromVector(dev_matrix_shape_, first_output_tensor_map, first_output_shape) != - SUCCESS)) { - return FAILED; - } - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo first_output_tensor_info(first_output_tensor_layout, first_output_shape, first_output_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); // input - inputs_tensor_info_.push_back(input_tensor_info); // label - outputs_tensor_info_.push_back(first_output_tensor_info); // output-0 - outputs_tensor_info_.push_back(input_tensor_info); // output-1 - - return SUCCESS; -} - -// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload the function. -Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() { - if (outputs_tensor_map_.size() != 2) { - MS_LOG(ERROR) << name_ << " : The size of outputs tensor map " << outputs_tensor_map_.size() << " is error."; - return FAILED; - } - as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[1]); - MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_) - << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[1]) << ", as_loss_divisor_ is " - << as_loss_divisor_; - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -void SoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() { - for (size_t i = 0; i < inputs_shape_.size(); ++i) { - split_flag_list_[i] = true; - } -} - -Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << " : GetAttrs failed."; - return FAILED; - } - int32_t axis_index = axis_; - if (axis_ < 0) { - size_t input_dim = inputs_shape_[0].size(); - axis_index = static_cast(input_dim) + axis_; - } - is_auto_parallel_ = true; - - Shape input0_split; - (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); - input0_split[IntToSize(axis_index)] = 0; - Shapes splittable_inputs = {input0_split, input0_split}; - std::vector sp_vector; - if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies failed."; - return FAILED; - } - - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - PrintStrategy(strategy); - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.h b/mindspore/ccsrc/parallel/ops_info/loss_info.h deleted file mode 100644 index 2679c2d62b..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.h +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -// infer shape: -// input_0 : [a, b], input_1 : [a, b] -// output_0 : [a], output_1: [a, b] -class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { - public: - SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, - std::make_shared(false)) {} - ~SoftmaxCrossEntropyWithLogitsInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - void ReComputeBatchSplitFlagList() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - // There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload - // the InferAsLossDivisor. - Status InferAsLossDivisor() override; - - private: - int32_t axis_ = -1; // default -1 -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc deleted file mode 100644 index 7d1ab8dc0f..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc +++ /dev/null @@ -1,647 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/matmul_info.h" - -#include -#include -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b, - Shape *dev_matrix_shape) { - MS_EXCEPTION_IF_NULL(dev_matrix_shape); - size_t mat_a_size = mat_a_strategy.size(); - size_t mat_b_size = mat_b_strategy.size(); - if (mat_a_size >= mat_b_size) { - // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32] - // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) - - // [2],[4] in the example above - for (size_t i = 0; i < SECOND_FROM_END(mat_a_size); ++i) { - dev_matrix_shape->push_back(mat_a_strategy.at(i)); - } - } else { - // for example: mat_a_strategy:[8,16], mat_b_strategy:[2,4,16,32] - // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) - - // [2],[4] in the example above - for (size_t i = 0; i < SECOND_FROM_END(mat_b_size); ++i) { - dev_matrix_shape->push_back(mat_b_strategy.at(i)); - } - } - - // [8],[16] in the example above - dev_matrix_shape->push_back(mat_a_strategy.at(SECOND_FROM_END(mat_a_size))); - dev_matrix_shape->push_back(mat_a_strategy.back()); - - // [32] in the example above - if (!transpose_b) { - dev_matrix_shape->push_back(mat_b_strategy.back()); - } else { - dev_matrix_shape->push_back(mat_b_strategy.at(SECOND_FROM_END(mat_b_size))); - } -} - -Status MatMulBase::GetAttrs() { - if (attrs_.size() < MATMUL_ATTRS_SIZE) { - MS_LOG(ERROR) << name_ << " : The size of attrs small than 2."; - return FAILED; - } - - auto transpose_a_iter = attrs_.find(TRANSPOSE_A); - if (transpose_a_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(transpose_a_iter->second); - if (transpose_a_iter->second->isa()) { - transpose_a_ = transpose_a_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool."; - return FAILED; - } - } - - auto transpose_b_iter = attrs_.find(TRANSPOSE_B); - if (transpose_b_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(transpose_b_iter->second); - if (transpose_b_iter->second->isa()) { - transpose_b_ = transpose_b_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool."; - return FAILED; - } - } - - auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER); - if (forward_reduce_scatter_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second); - if (forward_reduce_scatter_iter->second->isa()) { - forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool."; - return FAILED; - } - } - - // infer inputs dimension size - if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; - return FAILED; - } - mat_a_dimension_ = inputs_shape_.at(0).size(); - mat_b_dimension_ = inputs_shape_.at(1).size(); - - return SUCCESS; -} - -Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) { - size_t long_size = long_strategy.size(); - size_t short_size = short_strategy.size(); - if (long_size < short_size) { - MS_LOG(ERROR) << "Size error, the size of long strategy is " << long_size << ", the size of short strategy is " - << short_size; - return FAILED; - } - - size_t len_diff = long_size - short_size; - for (size_t j = 0; j < SECOND_FROM_END(short_size); ++j) { - if (long_strategy.at(len_diff + j) != short_strategy.at(j)) { - MS_LOG(ERROR) << "Strategies of relevant dimensions are not equal, long strategy is " - << ShapeToString(long_strategy) << ", short strategy is " << ShapeToString(short_strategy); - return FAILED; - } - } - - return SUCCESS; -} - -Status MatMul::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - Dimensions mat_a_strategy = stra.at(0); - Dimensions mat_b_strategy = stra.at(1); - - size_t mat_a_size = mat_a_strategy.size(); - size_t mat_b_size = mat_b_strategy.size(); - if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong."; - } else { - MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong."; - } - return FAILED; - } - - // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32] - // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) - // [16] in the example above - if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } - - if (mat_a_size >= mat_b_size) { - if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } - } else { - if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } - } - - if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) { - MS_LOG(WARNING) << name_ - << ": The dimension of mat a and mat b must be 2 in forward reduce scatter mode, " - "setting the forward reduce scatter mode to false here"; - forward_reduce_scatter_ = false; - } - - return SUCCESS; -} - -Status MatMulBase::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions mat_a_strategy = stra.at(0); - Dimensions mat_b_strategy = stra.at(1); - - SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_); - return SUCCESS; -} - -// all-reduce weight's grad -Status MatMulBase::InferMirrorOps() { - mirror_ops_.clear(); - - Shape mat_b_tensor_map = inputs_tensor_map_[1]; - std::vector mat_b_group; - if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) { - return FAILED; - } - - OperatorVector op_for_inputs; // op_for_inputs is empty - OperatorVector op_for_weight; - - if (mat_b_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror ops is empty."; - return SUCCESS; - } else { - op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum()); - mirror_ops_.push_back(op_for_inputs); - mirror_ops_.push_back(op_for_weight); - MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name(); - } - - return SUCCESS; -} - -Status MatMulBase::InferForwardCommunication() { - forward_op_.clear(); - size_t dimension = dev_matrix_shape_.size(); - size_t relevant_dimension_index = SECOND_FROM_END(dimension); - // Relevant dimension is not split and all reduce is not required - if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) { - MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; - return SUCCESS; - } - - std::vector group_list; - if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed."; - return FAILED; - } else if (group_list.empty()) { - MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; - return SUCCESS; - } - - Operator op; - if (forward_reduce_scatter_) { - op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name()); - } else { - op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name()); - } - - forward_op_.push_back(op); - MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name(); - return SUCCESS; -} - -Status MatMulBase::InferTensorMap() { - size_t size = dev_matrix_shape_.size(); - if (repeated_calc_num_ > 1) { - // move the first dimension(repeated_calc_num_), just for the convenience of tensor-map's calculation - size = dev_matrix_shape_.size() - 1; - } - - std::vector tensor_map_index; - // such as 5: tensor_map_index [4,3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); - } - - // infer output tensor map: [4,3,2,0], delete the second-from-end element - TensorMap output_tensor_map = tensor_map_index; - (void)output_tensor_map.erase(output_tensor_map.begin() + static_cast(SECOND_FROM_END(size))); - - // infer mat_a tensor map - // for example: mat_a_dimension is 4, mat_a tensor map:[4,3,2,1] - TensorMap mat_a_tensor_map = tensor_map_index; - // delete last one element - mat_a_tensor_map.pop_back(); - // delete the first (dev_matrix_size - 1 - mat_a_dimension) elements - (void)mat_a_tensor_map.erase( - mat_a_tensor_map.begin(), - mat_a_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_a_dimension_)); - - // infer mat_b tensor map - TensorMap mat_b_tensor_map = tensor_map_index; - // delete the third-to-last element - (void)mat_b_tensor_map.erase(mat_b_tensor_map.begin() + static_cast(THIRD_FROM_END(size))); - // delete the first (dev_matrix_size - 1 - mat_b_dimension) elements - (void)mat_b_tensor_map.erase( - mat_b_tensor_map.begin(), - mat_b_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_b_dimension_)); - if (transpose_b_) { - // swap the last two elements - int32_t last_value = mat_b_tensor_map.back(); - mat_b_tensor_map.pop_back(); - (void)mat_b_tensor_map.insert( - mat_b_tensor_map.begin() + static_cast(LAST_INDEX(mat_b_tensor_map.size())), last_value); - } - - if (forward_reduce_scatter_) { - if (dev_matrix_shape_.size() != 3) { - MS_LOG(WARNING) << name_ - << ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, " - "setting the forward reduce scatter mode to false here"; - forward_reduce_scatter_ = false; - } else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) { - MS_LOG(WARNING) << name_ - << ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in " - "forward reduce scatter mode, setting the forward reduce scatter mode to false here"; - forward_reduce_scatter_ = false; - } else { - // the forward reduce scatter only support that the dimension of output is 2 - output_tensor_map = {1, 0}; - } - } - - inputs_tensor_map_.push_back(mat_a_tensor_map); - inputs_tensor_map_.push_back(mat_b_tensor_map); - outputs_tensor_map_.push_back(output_tensor_map); - return SUCCESS; -} - -Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { - Shape output_dev_matrix_shape; - if (forward_reduce_scatter_) { - if (dev_matrix_shape_.size() != 3) { - MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode"; - return FAILED; - } - output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]}; - } else { - output_dev_matrix_shape = dev_matrix_shape_; - } - - TensorLayout mat_a_layout, mat_b_layout, output_layout; - if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || - (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || - (output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) { - return FAILED; - } - - inputs_layout->push_back(mat_a_layout); - inputs_layout->push_back(mat_b_layout); - outputs_layout->push_back(output_layout); - return SUCCESS; -} - -Status MatMulBase::InferTensorInfo() { - // infer tensor layout - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { - return FAILED; - } - - TensorLayout mat_a_layout = inputs_layout.at(0); - TensorLayout mat_b_layout = inputs_layout.at(1); - TensorLayout output_layout = outputs_layout.at(0); - TensorInfo mat_a_tensor_info(mat_a_layout); - TensorInfo mat_b_tensor_info(mat_b_layout); - TensorInfo output_tensor_info(output_layout); - - inputs_tensor_info_.push_back(mat_a_tensor_info); - inputs_tensor_info_.push_back(mat_b_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status MatMulBase::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - - if (forward_reduce_scatter_) { - virtual_div_op_.clear(); - MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated calculation, clear the virtual div op"; - } - - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) { - if (input->size() < 2) { - MS_LOG(ERROR) << name_ << " : The size of inputs small than 2."; - return FAILED; - } - auto last_1st_value = input->at(input->size() - 1); - auto last_2nd_value = input->at(input->size() - 2); - input->pop_back(); - input->pop_back(); - input->push_back(last_1st_value); - input->push_back(last_2nd_value); - return SUCCESS; -} - -Status MatMulBase::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << " : GetAttrs failed."; - return FAILED; - } - CheckGlobalDeviceManager(); - std::vector dev_list = g_device_manager->GetDeviceListByStageId(stage_id); - size_t dev_num = dev_list.size(); - Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1]; - if (transpose_a_) { - if (SwapLastTwoElements(&input0_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - } - if (transpose_b_) { - if (SwapLastTwoElements(&input1_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - } - // The shape of input0 (input1) - // E.g., input0 = [100, 200, 300], input1 = [300, 400] - - // Combining the input0_shape and input1_shape - // E.g., combined_shape = [100, 200, 300, 400] - is_auto_parallel_ = true; - size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size(); - Dimensions combined_partitions; - Shape combined_shape; - // In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2 - if (input0_shape.size() >= input1_shape.size()) { - combined_shape = input0_shape; - combined_shape.push_back(input1_shape[input1_shape.size() - 1]); - } else { - combined_shape = input1_shape; - combined_shape.push_back(input0_shape[input0_shape.size() - 2]); - } - std::function recursive = [&stage_id, &dev_num, &combined_partitions, &combined_shape, - &input1_shape_size, &recursive, &input0_shape_size, - this](uint32_t current_index, size_t n) { - // Finishing the recursive steps, if the strategy is valid, then calculate the cost - // for this operator under the strategy. - if (current_index == combined_shape.size()) { - StrategyPtr sp; - if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) == - FAILED) { - return; - } - if (this->SetCostUnderStrategy(sp) == FAILED) { - MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed."; - return; - } - } else { - MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size - << ", input1_shape_size: " << input1_shape_size; - for (uint32_t i = 1; i <= n; i *= 2) { - if (n % i == 0 && IntToSize(combined_shape[current_index]) % i == 0) { - combined_partitions.push_back(i); - recursive(current_index + 1, n / i); - combined_partitions.pop_back(); - } - } - } - }; - recursive(0, dev_num); - if (strategy_cost_.empty()) { - MS_LOG(EXCEPTION) << name_ << " : No available strategy."; - } - return Status::SUCCESS; -} - -Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, - mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { - int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); - if (!FULLY_USE_DEVICES) { - if (IntToSize(product) > dev_num) { - return FAILED; - } - } else { - if (IntToSize(product) != dev_num) { - return FAILED; - } - } - Dimensions input0_partitions, input1_partitions; - if (input0_shape_size >= input1_shape_size) { - for (size_t i = 0; i < input0_shape_size; ++i) { - input0_partitions.push_back(combined_partitions[i]); - } - if (input1_shape_size == 2) { - input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]); - input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]); - } else { - // input1_shape.size() > 2 - for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) { - if (j == combined_partitions.size() - 3) { - continue; - } - input1_partitions.push_back(combined_partitions[j]); - } - } - } else { - for (size_t i = 0; i < input1_shape_size; ++i) { - input1_partitions.push_back(combined_partitions[i]); - } - for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) { - input0_partitions.push_back(combined_partitions[j]); - } - input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]); - input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]); - } - if (transpose_a_) { - if (SwapLastTwoElements(&input0_partitions) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - } - if (transpose_b_) { - if (SwapLastTwoElements(&input1_partitions) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - } - std::vector stras; - stras.push_back(input0_partitions); - stras.push_back(input1_partitions); - (*sp) = std::make_shared(stage_id, stras); - - return SUCCESS; -} - -void MatMulBase::InitTensorInfoForCost(std::vector *relica_inputs_tensor_vector) { - TensorLayout tly; - if (transpose_a_) { - Shape replica_input0_shape(inputs_tensor_info_[0].shape()); - Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape()); - if (SwapLastTwoElements(&replica_input0_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - - TensorInfo replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape); - relica_inputs_tensor_vector->push_back(replica_input0_info); - } else { - relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]); - } - if (transpose_b_) { - Shape replica_input1_shape(inputs_tensor_info_[1].shape()); - Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape()); - if (SwapLastTwoElements(&replica_input1_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - - TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape); - relica_inputs_tensor_vector->push_back(replica_input1_info); - } else { - relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]); - } -} - -Status MatMulBase::CheckForTensorSliceValid() const { - if (!TENSOR_SLICE_ALIGNMENT_ENABLE) { - return SUCCESS; - } - if (inputs_tensor_info_.empty()) { - return FAILED; - } - for (auto &one_input_tensor : inputs_tensor_info_) { - auto slice_shape = one_input_tensor.slice_shape(); - if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || - (IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { - return FAILED; - } - } - return SUCCESS; -} - -Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (InitForCostModel(strategy) == FAILED) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Initialization under the strategy failed."; - } - return FAILED; - } - PrintStrategy(strategy); - // Check whether the tensor slice of input_tensor_info is valid or not - if (CheckForTensorSliceValid() != SUCCESS) { - MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy."; - return FAILED; - } - // Here, a replicated inputs_ is constructed for the transposed TensorInfo. - std::vector relica_inputs_tensor_vector; - InitTensorInfoForCost(&relica_inputs_tensor_vector); - - int32_t stage_id = strategy->GetInputStage(); - // Here, we use the origin outputs_, because we only use the slice size of the output tensor. - // It does not matter whether the output tensor is transposed or not. - double computation_cost = - operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); - double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); - std::shared_ptr result = std::make_shared(computation_cost, communication_cost); - result->communication_without_parameter_ = - operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); - result->communication_with_partial_para_ = - result->communication_without_parameter_ + - COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); - - // Breaking ties for preferring data parallelization - BreakingTiesForPerferringDataParallel(strategy, result); - MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_ - << ", communication_cost: " << result->communication_cost_ - << ", communication_without_parameter_: " << result->communication_without_parameter_ - << ", communication_with_partial_para_: " << result->communication_with_partial_para_; - // refine communication cost calculation for practice - RefineForPracticalCost(result, false); - result->communication_forward_ = result->communication_without_parameter_; - - std::shared_ptr swc = - std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); - swc->cost_list.push_back(result); - strategy_cost_.emplace_back(swc); - - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h deleted file mode 100644 index cb3e54a048..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ - -#include -#include -#include -#include - -#include "common/utils.h" -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class MatMulBase : public OperatorInfo { - public: - MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~MatMulBase() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - // Generate all strategies and the corresponding cost for this MatMul operator - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, StrategyPtr *sp); - - Status SwapLastTwoElements(Shape *shape); - - protected: - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); - void InitTensorInfoForCost(std::vector *); - Status CheckForTensorSliceValid() const; - Status GetAttrs() override; - - bool transpose_a_ = false; - bool transpose_b_ = false; - bool forward_reduce_scatter_ = false; - size_t mat_a_dimension_ = 0; - size_t mat_b_dimension_ = 0; -}; - -class MatMul : public MatMulBase { - public: - MatMul(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : MatMulBase(name, inputs_shape, outputs_shape, attrs) {} - ~MatMul() override = default; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; -}; - -class MatMulInfo : public MatMul { - public: - MatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : MatMul(name, inputs_shape, outputs_shape, attrs) {} - ~MatMulInfo() override = default; -}; - -class BatchMatMulInfo : public MatMul { - public: - BatchMatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : MatMul(name, inputs_shape, outputs_shape, attrs) {} - ~BatchMatMulInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc deleted file mode 100644 index ea2d045104..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc +++ /dev/null @@ -1,311 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/onehot_info.h" - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/device_matrix.h" -#include "parallel/graph_util/generate_graph.h" -#include "parallel/strategy.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status OneHotInfo::GetAttrs() { - auto iter = attrs_.find(AXIS); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - axis_value_ptr_ = iter->second; - axis_ = iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << ": The value of axis is not int."; - return FAILED; - } - } - - if (inputs_shape_[0].size() != 1) { - MS_LOG(ERROR) << name_ << ": Input's shape only support 1-D now."; - return FAILED; - } - - if ((axis_ > 1) || (axis_ < -1)) { - MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range[-1, 1]."; - return FAILED; - } - return SUCCESS; -} - -Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { - if (inputs_shape_.size() != 3) { - MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != 1) { - MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size(); - return FAILED; - } - if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}, - is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - return SUCCESS; -} - -Status OneHotInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - - // Now input only support 1-D tensor, so the output is a 2-D tensor - // If input is a vector of length features, the output shape will be: - // [features, depth] if axis == -1 (or axis == 1) - // [depth, features] if axis == 0 - if (axis_ == 0) { - dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable - dev_matrix_shape_.push_back(input_strategy[0]); // the features is splittable - } else { - dev_matrix_shape_.push_back(input_strategy[0]); // the features is splittable - dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable - } - - return SUCCESS; -} - -Status OneHotInfo::InferTensorMap() { - std::vector input_tensor_map_index, output_tensor_map_index; - size_t size = outputs_shape_[0].size(); - // such as 2: tensor_map_index [1,0] - if (axis_ == 0) { - for (size_t i = 0; i < size; ++i) { - output_tensor_map_index.push_back((int32_t)(i)); - } - } else { - for (size_t i = 0; i < size; ++i) { - output_tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); - } - } - outputs_tensor_map_.push_back(output_tensor_map_index); - - // Now input only support 1-D tensor - input_tensor_map_index.push_back(1); - - inputs_tensor_map_.push_back(input_tensor_map_index); - return SUCCESS; -} - -// axis = -1 -// (0,(1,16),(),())reid dev_matrix=(1,16) map_in=(1) map_out=(1,0) -// (0,(16,1),(),())data parallel dev_matrix=(16,1) map_in=(1) map_out=(1,0) -// (0,(2,8),(),())16 devices two machines,model parallel among devices in the same machine,data parallel between -// machines dev_matrix=(2,8) map_in=(1) map_out=(1,0) (0, (2,4),(),())16 devices dev_matrix=(2,4,2) map_in=(1) -// map_out=(1,0) -// axis = 0 -// (0, (16,1),(),())reid dev_matrix=(1,16) map_in=(1) map_out=(0,1) -// (0, (1,16),(),())data parallel dev_matrix=(16,1) map_in=(1) map_out=(0,1) -// (0, (8,2),(),())16 devices two machines,model parallel among devices in the same machine,data parallel between -// machines dev_matrix=(2,8) map_in=(1) map_out=(0,1) (0,(4,2),(),())16 devices dev_matrix=(2,4,2) map_in=(1) -// map_out=(0,1) -Status OneHotInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape output_shape = outputs_shape_.at(0); - - TensorLayout input_tensor_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout); - TensorInfo output_tensor_info(output_tensor_layout); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - - return SUCCESS; -} - -Status OneHotInfo::ExtractInputInfo() { - CheckGlobalDeviceManager(); - rank_ = g_device_manager->global_rank(); - mod_rank_ = rank_ % dev_matrix_shape_.back(); - if (!cnode_) { - MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr"; - return FAILED; - } - if (cnode_->inputs().size() != 5) { - MS_LOG(ERROR) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive, real input size is " - << cnode_->inputs().size(); - return FAILED; - } - if (input_value_.size() != 4) { - MS_LOG(ERROR) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive, and input value size " - "must be 4, real size is " - << input_value_.size(); - return FAILED; - } - auto value_ptr = input_value_.at(1); - if (value_ptr == nullptr) { - MS_LOG(WARNING) << "Input 2 of cnode is not a value node, its type is " << cnode_->input(2)->type_name(); - return FAILED; - } - - if (value_ptr->isa()) { - total_class_number_ = value_ptr->cast()->value(); - } else { - MS_LOG(ERROR) << "OneHot Primitive depth type must be int"; - return FAILED; - } - classes_each_device_ = total_class_number_ / dev_matrix_shape_.back(); - - return SUCCESS; -} - -Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { - if (dev_matrix_shape_.back() == 1) { - replace_graph_ = nullptr; - return SUCCESS; - } - if (ExtractInputInfo() != SUCCESS) { - MS_LOG(ERROR) << "ExtractInputInfo failed"; - return FAILED; - } - GenerateGraph gen_g = GenerateGraph(); - Status status = gen_g.Init(cnode); - if (status != SUCCESS) { - MS_LOG(ERROR) << "GenerateGraph Init failed"; - return FAILED; - } - - auto floor_div = - gen_g.PushBack({gen_g.NewOpInst(FLOORDIV), gen_g.virtual_input_node(), CreateInt32Tensor(classes_each_device_)}); - auto mul1 = gen_g.PushBack({gen_g.NewOpInst(MUL), floor_div, CreateInt32Tensor(classes_each_device_)}); - auto sub1 = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), mul1}); - auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)}); - auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)}); - auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast}); - auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(TENSOR_ADD), mul2, CreateInt32Tensor(1)}); - auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add}); - auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)}); - Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_); - OperatorAttrs attrs_onehot = {attr_onehot_axis}; - auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt32Imm(classes_each_device_), - cnode->input(3), cnode->input(4)}); - std::vector> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)}; - replace_graph_ = std::make_shared>, AnfNodePtr>>( - std::make_pair(input_nodes, onehot)); - - return SUCCESS; -} - -ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { - if (ComputeReplaceGraph(cnode) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return nullptr; - } - return replace_graph_; -} - -Status OneHotInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - Status status = ComputeReplaceGraph(cnode_); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return status; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status OneHotInfo::GenerateStrategies(int32_t stage_id) { - Shapes splittable_inputs = {{1, 1}, {}, {}}; - std::vector sp_vector; - if (inputs_shape_.size() != 3) { - MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != 1) { - MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size(); - return FAILED; - } - is_auto_parallel_ = true; - if (GenerateStrategiesForIndependentInputs(stage_id, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}, - splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategies failed."; - return FAILED; - } - - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - - return SUCCESS; -} - -Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -std::shared_ptr>> OneHotInfo::GenerateBatchStrategies() { - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - Dimensions strategy = {SizeToInt(dev_num), 1}; - Dimensions empty_strategy; - std::vector strategy_v = {strategy, empty_strategy, empty_strategy}; - return std::make_shared>>(strategy_v); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/parallel/ops_info/onehot_info.h deleted file mode 100644 index 3c8a64f954..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class OneHotInfo : public OperatorInfo { - public: - OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~OneHotInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; - std::shared_ptr>> GenerateBatchStrategies() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status ExtractInputInfo(); - - private: - Status ComputeReplaceGraph(const CNodePtr &cnode); - - int axis_ = -1; - int32_t rank_ = 0; - int32_t total_class_number_ = 1; - int32_t classes_each_device_ = 1; - ValuePtr axis_value_ptr_; - int32_t mod_rank_ = 0; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc deleted file mode 100644 index f9b294898c..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc +++ /dev/null @@ -1,1334 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/operator_info.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/dtype.h" -#include "ir/tensor.h" -#include "ir/value.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/context.h" -#include "utils/context/ms_context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) { - if (strategy == nullptr) { - MS_LOG(ERROR) << "The strategy is null."; - return FAILED; - } - - size_t strategy_size = strategy->GetInputNumber(); - size_t inputs_shape_size = inputs_shape.size(); - if (strategy_size != inputs_shape_size) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; - } else { - MS_LOG(ERROR) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - for (size_t i = 0; i < strategy_size; ++i) { - Shape sub_strategy = stra.at(i); - Shape sub_input_shape = inputs_shape.at(i); - size_t strategy_len = sub_strategy.size(); - size_t inputs_len = sub_input_shape.size(); - if (strategy_len != inputs_len) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len - << ", index: " << i; - } else { - MS_LOG(ERROR) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len - << ", index: " << i; - } - return FAILED; - } - - for (size_t j = 0; j < strategy_len; ++j) { - int32_t strategy_value = sub_strategy.at(j); - if (strategy_value < MIN_SLICE_NUM) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value; - } else { - MS_LOG(ERROR) << "Invalid strategy value: " << strategy_value; - } - return FAILED; - } - - if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Invalid Strategy value it is not the power of 2, " << strategy_value; - } else { - MS_LOG(ERROR) << "Invalid Strategy value it is not the power of 2, " << strategy_value; - } - return FAILED; - } - - int32_t shape_value = sub_input_shape.at(j); - if ((shape_value % strategy_value) != 0) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; - } else { - MS_LOG(ERROR) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; - } - return FAILED; - } - } - } - - return SUCCESS; -} - -void OperatorInfo::ResetQueueMember() { - inputs_tensor_info_.clear(); - outputs_tensor_info_.clear(); - inputs_tensor_map_.clear(); - outputs_tensor_map_.clear(); - dev_matrix_shape_.clear(); - forward_op_.clear(); - mirror_ops_.clear(); - sub_ops_.clear(); - replace_op_.clear(); - replace_op_info_.clear(); - virtual_div_op_.clear(); - global_device_list_.clear(); -} - -Status OperatorInfo::InferAttrs() { - if (infer_attrs_completed_) { - return SUCCESS; - } - - if (GetAttrs() != SUCCESS) { - return FAILED; - } - infer_attrs_completed_ = true; - return SUCCESS; -} - -void OperatorInfo::SetDeviceListByStrategy() { - int32_t stage = strategy_->GetInputStage(); - CheckGlobalDeviceManager(); - global_device_list_ = g_device_manager->GetDeviceListByStageId(stage); -} - -Status OperatorInfo::InferRepeatedCalcInfo() { - int32_t g_dev_list_size = SizeToInt(global_device_list_.size()); - int32_t dev_matrix_size = - std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); - if (dev_matrix_size == 0) { - MS_LOG(ERROR) << name_ << ": The dev matrix size is 0"; - return FAILED; - } - - if (g_dev_list_size == dev_matrix_size) { - repeated_calc_num_ = 1; - } else if (g_dev_list_size % dev_matrix_size == 0) { - repeated_calc_num_ = g_dev_list_size / dev_matrix_size; - } else { - MS_LOG(ERROR) << name_ << ": Dev list size " << g_dev_list_size << " can not be divisible by dev matrix size " - << dev_matrix_size; - return FAILED; - } - - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - int32_t stage = strategy_->GetInputStage(); - local_device_list_ = g_device_manager->global_device_list(stage, rank, repeated_calc_num_); - - return SUCCESS; -} - -// if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix, -// only use for infer tensor layout -void OperatorInfo::SetRepeatedCalcDevMatrix() { - if (repeated_calc_num_ <= 1) { - return; - } - - (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_); -} - -// use for loss repeated calculation -Operator CreateVirtualDivOp(int32_t div_num) { - OperatorName operator_name = VIRTUAL_DIV; - ValuePtr attr0_value = MakeValue(div_num); - Attr attr0 = std::make_pair(DIVISOR, attr0_value); - OperatorAttrs operator_attrs; - operator_attrs.push_back(attr0); - - OperatorParams operator_param; - OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_arg); - return op; -} - -// use for forward all reduce -Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) { - OperatorName operator_name = ALL_REDUCE; - ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM - ValuePtr attr1_value = MakeValue(group); // group - Attr attr0 = std::make_pair(OP, attr0_value); - Attr attr1 = std::make_pair(GROUP, attr1_value); - OperatorAttrs operator_attrs; - operator_attrs.push_back(attr0); - operator_attrs.push_back(attr1); - - OperatorParams operator_param; - OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_arg); - MS_LOG(INFO) << "Create all reduce op success, the reduce_op is " << reduce_op << ", the group is " << group; - return op; -} - -Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) { - OperatorName operator_name = REDUCE_SCATTER; - ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM - ValuePtr attr1_value = MakeValue(group); // group - Attr attr0 = std::make_pair(OP, attr0_value); - Attr attr1 = std::make_pair(GROUP, attr1_value); - OperatorAttrs operator_attrs; - operator_attrs.push_back(attr0); - operator_attrs.push_back(attr1); - - OperatorParams operator_param; - OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_arg); - MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group; - return op; -} - -// use for get tensor slice -Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { - Shape tensor_map = tensor_layout.tensor_map().array(); - Shape dev_matrix_shape = tensor_layout.device_arrangement().array(); - OperatorName operator_name = GET_TENSOR_SLICE; - - OperatorAttrs attrs; - ValuePtr dev_mat_value = MakeValue(dev_matrix_shape); - Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2); - ValuePtr tensor_map_value = MakeValue(tensor_map); - Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3); - OperatorParams params = {dev_mat_param, tensor_map_param}; - OperatorArgs operator_arg = std::make_pair(attrs, params); - - Operator op = std::make_pair(operator_name, operator_arg); - MS_LOG(INFO) << "Create get tensor slice op success, the dev mat and tensor map is " - << ShapeToString(dev_matrix_shape) << ", " << ShapeToString(tensor_map); - return op; -} - -OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { - if ((dev_num == 0) || (dev_num == 1)) { - MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num; - } - OperatorVector op_for_weight; - bool mean_flag = ParallelContext::GetInstance()->mirror_mean(); - - OperatorName operator_name = MIRROR_OPERATOR; - ValuePtr attr0_value = MakeValue(group_name); - ValuePtr attr1_value = MakeValue(SizeToInt(dev_num)); - ValuePtr attr2_value = MakeValue(mean_flag); - - Attr attr0 = std::make_pair(GROUP, attr0_value); - Attr attr1 = std::make_pair(DEV_NUM, attr1_value); - Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value); - - OperatorAttrs operator_attrs; - operator_attrs.push_back(attr0); - operator_attrs.push_back(attr1); - operator_attrs.push_back(attr2); - - OperatorParams operator_param; - OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_args); - - op_for_weight.push_back(op); - MS_LOG(INFO) << "The group name is " << group_name << ", the dev num is " << dev_num << ", the mean flag is " - << mean_flag; - return op_for_weight; -} - -Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group) { - if (group == nullptr) { - MS_LOG(ERROR) << "The group is null."; - return FAILED; - } - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_); - RankList group_devices; - if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) { - return FAILED; - } - - if (group_devices.size() == 1) { - MS_LOG(INFO) << "The dev size is 1, no need to create group."; - return SUCCESS; - } - - Group g = g_device_manager->CreateGroup(group_devices); - group->push_back(g); - return SUCCESS; -} - -Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector *group) { - if (group == nullptr) { - MS_LOG(ERROR) << "The group is null."; - return FAILED; - } - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_); - RankList group_devices; - if (dev_matrix.GetDevicesAlongDim(SizeToUint(axis), &group_devices) != SUCCESS) { - return FAILED; - } - - if (group_devices.size() == 1) { - MS_LOG(INFO) << "The dev size is 1, no need to create group."; - return SUCCESS; - } - - Group g = g_device_manager->CreateGroup(group_devices); - group->push_back(g); - return SUCCESS; -} - -Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { - Shape slice_shape; - if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { - MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; - return slice_shape; - } - for (size_t i = 0; i < strategy.size(); ++i) { - slice_shape.push_back(tensor_shape.at(i) / strategy.at(i)); - } - return slice_shape; -} - -Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) { - if (slice_shapes == nullptr) { - MS_LOG(ERROR) << "The slice_shapes is null."; - return FAILED; - } - if (strategys.size() != shapes.size()) { - MS_LOG(ERROR) << "Strategy size " << strategys.size() << " not equal to shape size " << shapes.size(); - return FAILED; - } - - for (size_t i = 0; i < strategys.size(); ++i) { - if (strategys.at(i).size() != shapes.at(i).size()) { - MS_LOG(ERROR) << "Strategy dimension " << strategys.at(i).size() << " not equal to shape dimension " - << shapes.at(i).size(); - slice_shapes->clear(); - return FAILED; - } - - for (size_t j = 0; j < shapes.at(i).size(); ++j) { - if (strategys.at(i).at(j) <= 0) { - MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategys[i]) - << " the element is less than or equal to 0."; - slice_shapes->clear(); - return FAILED; - } - if (shapes.at(i).at(j) % strategys.at(i).at(j) != 0) { - MS_LOG(ERROR) << "Shape cannot be divisible by strategy, " << shapes.at(i).at(j) << " : " - << strategys.at(i).at(j); - slice_shapes->clear(); - return FAILED; - } - } - Shape slice_shape = GetSliceShape(shapes.at(i), strategys.at(i)); - slice_shapes->push_back(slice_shape); - } - - return SUCCESS; -} - -Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, - Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) { - if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) { - MS_LOG(ERROR) << "The slice_shape is null."; - return FAILED; - } - - if (InferSliceShapeByStrategy(inputs_strategy, inputs_shape_, inputs_slice_shape) != SUCCESS) { - MS_LOG(ERROR) << "Infer inputs slice shape error."; - return FAILED; - } - - if (InferSliceShapeByStrategy(outputs_strategy, outputs_shape_, outputs_slice_shape) != SUCCESS) { - MS_LOG(ERROR) << "Infer outputs slice shape error."; - inputs_slice_shape->clear(); - return FAILED; - } - - return SUCCESS; -} - -// method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1 -Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null."; - return FAILED; - } - - if (InferAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferAttrs failed."; - return FAILED; - } - - // must be after InferAttrs() - if (CheckStrategy(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": CheckStrategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; - } - return FAILED; - } - - // need to clear queues before Init(), - // because Init() may be called multiple times by cost model - ResetQueueMember(); - - strategy_ = strategy; - SetDeviceListByStrategy(); - - if (InferDevMatrixShape() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; - return FAILED; - } - - used_devices_ = std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); - - // must be after InferDevMatrixShape - if (InferRepeatedCalcInfo() != SUCCESS) { - MS_LOG(ERROR) << ": InferRepeatedCalcInfo failed."; - return FAILED; - } - - // if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix for layout - SetRepeatedCalcDevMatrix(); - - if (InferTensorMap() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; - return FAILED; - } - - if (InferTensorInfo() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; - return FAILED; - } - - return SUCCESS; -} - -// method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape -Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null."; - return FAILED; - } - - if (InferAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferAttrs failed."; - return FAILED; - } - - // must be after InferAttrs() - if (CheckStrategy(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; - return FAILED; - } - - // need to clear queues before Init(), - // because Init() may be called multiple times by cost model - ResetQueueMember(); - - strategy_ = strategy; - SetDeviceListByStrategy(); - - if (InferDevMatrixShape() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; - return FAILED; - } - - // must be after InferDevMatrixShape - if (InferRepeatedCalcInfo() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed."; - return FAILED; - } - - if (InferTensorMap() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; - return FAILED; - } - - if (InferTensorInfo() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; - return FAILED; - } - - return SUCCESS; -} - -Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null."; - return FAILED; - } - - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - return FAILED; - } - - if (InferForwardCommunication() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; - return FAILED; - } - - if (InferMirrorOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; - return FAILED; - } - - if (InferVirtualDivOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; - return FAILED; - } - - return SUCCESS; -} - -Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null."; - return FAILED; - } - - if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { - return FAILED; - } - - if (InferForwardCommunication() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; - return FAILED; - } - - if (InferMirrorOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; - return FAILED; - } - - if (InferVirtualDivOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; - return FAILED; - } - - return SUCCESS; -} - -std::vector> OperatorInfo::GetAliveSuccEdges() { - std::vector> ret; - for (auto &edge : succ_edges_) { - if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) { - ret.push_back(edge); - } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) { - // CAST is ordered in front of L2NORMALIZE - ret.push_back(edge); - } - } - for (auto &edge : succ_edges_) { - if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) && - (edge->next_operator()->name().find(CAST) == std::string::npos)) { - ret.push_back(edge); - } - } - return ret; -} - -std::vector> OperatorInfo::GetAlivePrevEdges() { - std::vector> ret; - for (auto &edge : prev_edges_) { - if (edge->prev_operator()->is_alive()) { - ret.push_back(edge); - } - } - return ret; -} - -void OperatorInfo::ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { - if (op == nullptr) { - MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null."; - return; - } - for (auto &edge : prev_edges_) { - if (edge->prev_operator() == op) { - edge = new_edge; - return; - } - } - MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; -} - -void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { - if (op == nullptr) { - MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null."; - return; - } - for (auto &edge : succ_edges_) { - if (edge->next_operator() == op) { - edge = new_edge; - return; - } - } - MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; -} - -void OperatorInfo::ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { - if (op == nullptr) { - MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null."; - return; - } - std::vector> new_pre_edges; - for (auto &edge : prev_edges_) { - if (edge->prev_operator() != op) { - new_pre_edges.push_back(edge); - } - } - new_pre_edges.push_back(new_edge); - prev_edges_ = new_pre_edges; -} - -void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { - if (op == nullptr) { - MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null"; - return; - } - std::vector> new_succ_edges; - for (auto &edge : succ_edges_) { - if (edge->next_operator() != op) { - new_succ_edges.push_back(edge); - } - } - new_succ_edges.push_back(new_edge); - succ_edges_ = new_succ_edges; -} - -std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes &shapes, const std::vector &split_flag_list) { - if (shapes.size() != split_flag_list.size()) { - MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : " - << shapes.size(); - return nullptr; - } - CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); - std::vector> strategy_v; - for (size_t i = 0; i != shapes.size(); i++) { - if (shapes[i].empty()) { - MS_LOG(INFO) << "Elements of shapes is empty."; - std::vector empty_element; - strategy_v.push_back(empty_element); - } else { - std::vector element(shapes[i].size(), 1); - if (split_flag_list[i]) { - element[0] = dev_num; - } - strategy_v.push_back(element); - } - } - return std::make_shared>>(strategy_v); -} - -void OperatorInfo::ReComputeBatchSplitFlagList() { - if (!inputs_shape_.empty()) { - split_flag_list_[0] = true; - } -} - -void OperatorInfo::ComputeBatchSplitFlagList() { - split_flag_list_.clear(); - for (auto iter = inputs_shape_.begin(); iter != inputs_shape_.end(); ++iter) { - split_flag_list_.push_back(false); - } - ReComputeBatchSplitFlagList(); -} - -// This is a common method for checking whether the generated stragegy has the correct number of devuces. -Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) { - if (sp == nullptr) { - MS_LOG(ERROR) << "The strategy is null."; - return FAILED; - } - int32_t product = 1; - - for (auto &input_partition : inputs_partitions) { - product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); - } - if (!FULLY_USE_DEVICES) { - if (IntToSize(product) > dev_num) { - return FAILED; - } - } else { - if ((product != 1) && (IntToSize(product) != dev_num)) { - return FAILED; - } - } - std::vector stras(inputs_partitions); - (*sp) = std::make_shared(stage_id, stras); - return SUCCESS; -} - -std::shared_ptr>> OperatorInfo::GenerateBatchStrategies() { - ComputeBatchSplitFlagList(); - return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); -} - -void PrintStrategy(const StrategyPtr &strategy) { - if (strategy == nullptr) { - return; - } - std::string all_strategy = ""; - for (size_t i = 0; i < strategy->GetInputNumber(); ++i) { - all_strategy += "["; - for (size_t j = 0; j < strategy->GetInputDim()[i].size(); ++j) { - all_strategy += std::to_string(strategy->GetInputDim()[i][j]); - if (j != strategy->GetInputDim()[i].size() - 1) { - all_strategy += ", "; - } - } - all_strategy += "]"; - if (i != strategy->GetInputNumber() - 1) { - all_strategy += ", "; - } - } - MS_LOG(INFO) << "The strategy is: " << all_strategy; -} - -// generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d]) -Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes &inputs_shape, - const Shapes &splittable_inputs, std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) { - MS_LOG(ERROR) << "The inputs size is wrong."; - return FAILED; - } - - if ((inputs_shape[0].size() != inputs_shape[1].size()) || - (splittable_inputs[0].size() != splittable_inputs[1].size())) { - MS_LOG(ERROR) << "The size of two inputs are not equal."; - return FAILED; - } - - Shapes input0_shape = {inputs_shape[0]}; - Shapes input0_splittable = {splittable_inputs[0]}; - if (GenerateStrategiesForIndependentInputs(stage_id, input0_shape, input0_splittable, sp_vector) != SUCCESS) { - return FAILED; - } - - for (auto &sp : *sp_vector) { - sp->ExpandInputDimFromOneToTwo(); - } - - return SUCCESS; -} - -// generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast -// such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, - std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if (inputs_shape[0].size() >= inputs_shape[1].size()) { - MS_LOG(ERROR) << "Invalid inputs shape."; - return FAILED; - } - - // first, generate strategy for input0 the same as input1 - Shapes tmp_inputs_shape = {inputs_shape[1], inputs_shape[1]}; - Shapes tmp_splittable_inputs = {splittable_inputs[1], splittable_inputs[1]}; - if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; - return FAILED; - } - - // second, get the correct strategy for input0 - for (auto &sp : *sp_vector) { - std::vector tmp_strategy; - Dimensions input0_strategy = sp->GetInputDim()[0]; - size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); - - // erase the unnecessary part - (void)input0_strategy.erase(input0_strategy.begin(), - input0_strategy.begin() + static_cast(size_diff)); - - // handel the case likes ([1, c, d], [a, b, c, d]) - for (size_t i = 0; i < inputs_shape[0].size(); ++i) { - if (inputs_shape[0][i] == 1) { - input0_strategy[i] = 1; - } else { - break; - } - } - - // reset the strategy - tmp_strategy.push_back(input0_strategy); // input0 - tmp_strategy.push_back(sp->GetInputDim()[1]); // input1 - sp->ResetInputs(tmp_strategy); - } - return SUCCESS; -} - -// generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast -// such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &inputs_shape, - const Shapes &splittable_inputs, std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if (inputs_shape[0].size() <= inputs_shape[1].size()) { - MS_LOG(ERROR) << "Invalid inputs shape."; - return FAILED; - } - - // first, generate strategy for input1 the same as input0 - Shapes tmp_inputs_shape = {inputs_shape[0], inputs_shape[0]}; - Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]}; - if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; - return FAILED; - } - - // second, get the correct strategy for input1 - for (auto &sp : *sp_vector) { - std::vector tmp_strategy; - tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 - - Dimensions input1_strategy = sp->GetInputDim()[1]; - size_t size_diff = inputs_shape[0].size() - inputs_shape[1].size(); - - // erase the unnecessary part - (void)input1_strategy.erase(input1_strategy.begin(), - input1_strategy.begin() + static_cast(size_diff)); - - // handel the case likes ([a, b, c, d], [1, c, d]) - for (size_t i = 0; i < inputs_shape[1].size(); ++i) { - if (inputs_shape[1][i] == 1) { - input1_strategy[i] = 1; - } else { - break; - } - } - - // reset the strategy - tmp_strategy.push_back(input1_strategy); // input1 - sp->ResetInputs(tmp_strategy); - } - return SUCCESS; -} - -// generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast -// such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, - std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if (inputs_shape[0].size() != inputs_shape[1].size()) { - MS_LOG(ERROR) << "Invalid inputs shape."; - return FAILED; - } - - // step1: ([a, 1], [1, b]) -> [a, b] - Shape max_shape, splittable_vector; - for (size_t i = 0; i < inputs_shape[0].size(); ++i) { - if (inputs_shape[0][i] >= inputs_shape[1][i]) { - max_shape.push_back(inputs_shape[0][i]); - splittable_vector.push_back(splittable_inputs[0][i]); - } else { - max_shape.push_back(inputs_shape[1][i]); - splittable_vector.push_back(splittable_inputs[1][i]); - } - } - - // step2: ([a, 1], [1, b]) -> generate strategy for ([a, b], [a, b]) - Shapes tmp_inputs_shape = {max_shape, max_shape}; - Shapes tmp_splittable_inputs = {splittable_vector, splittable_vector}; - if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; - return FAILED; - } - - // step3: reset the strategy if the dimension is 1 - for (auto &sp : *sp_vector) { - Dimensions input0_strategy = sp->GetInputDim()[0]; - Dimensions input1_strategy = sp->GetInputDim()[1]; - for (size_t i = 0; i < inputs_shape[0].size(); ++i) { - if (inputs_shape[0][i] == 1) { - input0_strategy[i] = 1; - } - - if (inputs_shape[1][i] == 1) { - input1_strategy[i] = 1; - } - } - sp->ResetInputs({input0_strategy, input1_strategy}); - } - - return SUCCESS; -} - -// 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that -// the corresponding dimension is unsplittable, '1' in 'splittable_inputs' means that the corresponding -// dimension is splittable. 'inputs_partitions' is the result of partitions. -// NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring -// specific dimensions in inputs have the identical partition should have individual implementation. -Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, - const Shapes &splittable_inputs, - std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - if (splittable_inputs.size() != inputs_shape.size()) { - MS_LOG(ERROR) << "Splittable_inputs do not have the same input number of inputs shape, " << splittable_inputs.size() - << " : " << inputs_shape.size(); - return FAILED; - } - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - Shape combined_inputs_shape, combined_splittable_inputs, combined_partitions; - for (size_t j = 0; j < inputs_shape.size(); ++j) { - (void)combined_inputs_shape.insert(combined_inputs_shape.end(), inputs_shape[j].begin(), inputs_shape[j].end()); - (void)combined_splittable_inputs.insert(combined_splittable_inputs.end(), splittable_inputs[j].begin(), - splittable_inputs[j].end()); - } - std::function recursive = [&stage_id, &dev_num, &sp_vector, &combined_inputs_shape, - &combined_splittable_inputs, &combined_partitions, &recursive, - &inputs_shape](uint32_t current_index, size_t n) { - if (current_index == combined_inputs_shape.size()) { - MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size(); - Shapes inputs_partitions; - size_t global_index = 0; - for (auto &shape : inputs_shape) { - Shape tmp_partition; - for (size_t j = 0; j < shape.size(); ++j) { - tmp_partition.push_back(combined_partitions[global_index]); - global_index++; - } - inputs_partitions.push_back(tmp_partition); - } - StrategyPtr sp; - if (PrepareStrategyBase(stage_id, dev_num, inputs_partitions, &sp) == SUCCESS) { - sp_vector->push_back(sp); - } - return; - } else { - MS_LOG(DEBUG) << "The value of sp_vector size is " << sp_vector->size(); - if (combined_splittable_inputs[current_index] == 0) { - combined_partitions.push_back(MIN_SLICE_NUM); - recursive(current_index + 1, n / MIN_SLICE_NUM); - combined_partitions.pop_back(); - } else if (combined_splittable_inputs[current_index] == 1) { - for (uint32_t i = 1; i <= n; i *= 2) { - if (n % i == 0 && IntToSize(combined_inputs_shape[current_index]) % i == 0) { - combined_partitions.push_back(i); - recursive(current_index + 1, n / i); - combined_partitions.pop_back(); - } - } - } - } - }; - recursive(0, dev_num); - if (sp_vector->empty()) { - MS_LOG(EXCEPTION) << "No available strategy for current OperatorInfo."; - } - return SUCCESS; -} - -// generate strategies for that have two inputs, and input0 or input1 maybe broadcast, -// and the corresponding dimensions that are not broadcast are all relevant dimensions -// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, - std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) { - MS_LOG(ERROR) << "The inputs' size is wrong."; - return FAILED; - } - - if (inputs_shape[0] == inputs_shape[1]) { - // element wise operation([a, b, c, d], [a, b, c, d]), so input0's strategy is equal to input1's strategy - if (GenerateStrategiesForTwoEqualInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; - return FAILED; - } - MS_LOG(INFO) << "GenerateStrategiesForTwoEqualInputs success."; - } else if (inputs_shape[0].empty() || inputs_shape[1].empty()) { - // ([a, b, c, d], []) or ([], [a, b, c, d]) - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "Generate strategies for scalar case failed."; - return FAILED; - } - MS_LOG(INFO) << "Generate strategies for scalar case success."; - } else if (inputs_shape[0].size() > inputs_shape[1].size()) { - // ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) - if (GenerateStrategiesForBroadcastRight(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForBroadcastRight failed."; - return FAILED; - } - MS_LOG(INFO) << "GenerateStrategiesForBroadcastRight success."; - } else if (inputs_shape[0].size() < inputs_shape[1].size()) { - // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) - if (GenerateStrategiesForBroadcastLeft(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForBroadcastLeft failed."; - return FAILED; - } - MS_LOG(INFO) << "GenerateStrategiesForBroadcastLeft success."; - } else { // same size, but different value - // ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) - if (GenerateStrategiesForBroadcastBoth(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForBroadcastBoth failed."; - return FAILED; - } - MS_LOG(INFO) << "GenerateStrategiesForBroadcastBoth success."; - } - return SUCCESS; -} - -Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { - if (InitForCostModel(strategy) == FAILED) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Initialization under the strategy failed."; - } - return FAILED; - } - int32_t stage_id = strategy->GetInputStage(); - double computation_cost = - operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - std::shared_ptr result = std::make_shared(computation_cost, communication_cost); - result->communication_without_parameter_ = - operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - result->communication_with_partial_para_ = - result->communication_without_parameter_ + - COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); - - // Breaking ties for preferring data parallelization - BreakingTiesForPerferringDataParallel(strategy, result); - // refine communication cost calculation for practice - RefineForPracticalCost(result, false); - result->communication_forward_ = result->communication_without_parameter_; - - std::shared_ptr swc = - std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); - swc->cost_list.push_back(result); - strategy_cost_.emplace_back(swc); - - return SUCCESS; -} - -int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { - if (is_output_parameter_involve_ != -1) { - return is_output_parameter_involve_; - } - is_parameter_involve_ = is_parameter_; - const auto &prev_edges = this->GetAlivePrevEdges(); - for (auto &p_edge : prev_edges) { - auto input_index = p_edge->next_op_input_index(); - auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved(); - if (input_index >= is_parameter_involve_.size()) { - MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size() - << ", but got wrong input_index: " << input_index; - } - if (prev_op_para == 0) { - is_parameter_involve_[input_index] = false; - } else if (prev_op_para == 1) { - is_parameter_involve_[input_index] = true; - } else { - MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index; - } - p_edge->set_parameter_involve(prev_op_para); - } - if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) { - // If anyone of the input is a parameter_involved, the output is parameter_involved. - is_output_parameter_involve_ = 1; - } else { - is_output_parameter_involve_ = 0; - } - - return is_output_parameter_involve_; -} - -Status OperatorInfo::set_is_parameter(const std::vector &is_parameter) { - if (is_parameter.size() != inputs_shape_.size()) { - MS_LOG(ERROR) << "Is_parameter: " << is_parameter.size() - << " do not have the same number of inputs_shape_: " << inputs_shape_.size(); - return FAILED; - } - is_parameter_ = is_parameter; - operator_cost()->set_is_parameter(is_parameter); - return SUCCESS; -} - -Status OperatorInfo::CalculateMemoryCost() { - // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to - // calculate memory cost. - if (is_parameter_involve_.size() != is_parameter_.size()) { - MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'."; - return FAILED; - } - operator_cost()->set_is_parameter_involve(is_parameter_involve_); - operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); - // Set the memory cost in the 'strategy_cost_' - for (auto &swc : strategy_cost_) { - auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); - swc->cost_list[0]->memory_with_reuse_ = mem_cost; - } - return SUCCESS; -} - -Status OperatorInfo::CalculateMemoryCostForInference() { - // First, set the 'is_outputs_critical_' flag into OperatorCost. - if (is_output_critical_ == -1) { - MS_LOG(EXCEPTION) << "The critical flag is not set."; - return FAILED; - } - operator_cost()->set_output_critical(is_output_critical_); - // Set the memory cost in the 'strategy_cost_' - for (auto &swc : strategy_cost_) { - auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr); - swc->cost_list[0]->memory_with_reuse_ = mem_cost; - } - return SUCCESS; -} - -Status OperatorInfo::CorrectMemoryCost(size_t input_index) { - for (auto &swc : strategy_cost_) { - double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * - static_cast(operator_cost()->inputs_type_lengths()[input_index]); - swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost; - if (swc->cost_list[0]->memory_with_reuse_ < 0) { - MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_ - << ", the parameter memory cost is: " << parameter_mem_cost; - return FAILED; - } - } - return SUCCESS; -} - -int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) { - int32_t ret = -1; - - // The number of repetitions is equal to the number of all devices divided by the number of devices use for - // tensor map. - int32_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies()); - for (auto &element : tensor_map) { - // -1 means the corresponding dimension is not split. - if (element == MAP_NONE) { - continue; - } else if ((element < 0) || (IntToSize(element) >= dev_matrix_shape.size())) { - MS_LOG(ERROR) << "Invalid tensor map: " << ShapeToString(tensor_map) << ", the dev matrix shape is " - << ShapeToString(dev_matrix_shape); - return ret; - } else { - size_t index = dev_matrix_shape.size() - IntToSize(element) - 1; - if (dev_matrix_shape[index] <= 0) { - MS_LOG(ERROR) << "Invalid dev matrix shape: " << ShapeToString(dev_matrix_shape); - return ret; - } - device_num /= dev_matrix_shape[index]; - } - } - - return device_num; -} - -Status OperatorInfo::InferAsLossDivisor() { - if (!ParallelContext::GetInstance()->loss_repeated_mean()) { - as_loss_divisor_ = 1; - return SUCCESS; - } - - if (outputs_tensor_map_.empty()) { - MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty."; - return FAILED; - } - - if (outputs_tensor_map_.size() > 1) { - MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_map_.size() - << ", need to override this function "; - return FAILED; - } - - if (outputs_tensor_map_[0].empty()) { - as_loss_divisor_ = SizeToInt(global_device_list_.size()); - MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor."; - return SUCCESS; - } - - as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); - MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(dev_matrix_shape_) - << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is " - << as_loss_divisor_; - return SUCCESS; -} - -// If the operator is used as a loss, a div node is inserted for the grad of all its inputs. -Status OperatorInfo::InferVirtualDivOps() { - if (InferAsLossDivisor() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed."; - return FAILED; - } - - if (as_loss_divisor_ <= 0) { - MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_; - return FAILED; - } else if (as_loss_divisor_ == 1) { - MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op."; - return SUCCESS; - } - - virtual_div_op_.clear(); - // if loss is repeated calculation, insert div op - Operator op = CreateVirtualDivOp(as_loss_divisor_); - virtual_div_op_.push_back(op); - return SUCCESS; -} - -Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector &input_lengths, - const std::vector &output_lengths) { - if (input_lengths.size() != inputs_shape_.size()) { - MS_LOG(ERROR) << "Input_lengths: " << input_lengths.size() - << " do not have the same number of inputs shape: " << inputs_shape_.size(); - return FAILED; - } - if (output_lengths.size() != outputs_shape_.size()) { - MS_LOG(ERROR) << "Output_lengths: " << output_lengths.size() - << " do not have the same number of outputs shape: " << outputs_shape_.size(); - return FAILED; - } - inputs_type_lengths_ = input_lengths; - outputs_type_lengths_ = output_lengths; - operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths); - return SUCCESS; -} - -double OperatorInfo::GetOutputsTotalSize() { - if (is_calculated_outputs_size_) { - return outputs_total_size_; - } - if (outputs_type_lengths_.size() != outputs_shape_.size()) { - MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size() - << " do not have the same number of outputs shape: " << outputs_shape_.size(); - } - double sum = 0.0; - for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) { - auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast(1.0), - std::multiplies()); - sum += size * static_cast(outputs_type_lengths_[i]); - } - is_calculated_outputs_size_ = true; - outputs_total_size_ = sum; - return outputs_total_size_; -} - -Status OperatorInfo::set_outputs_type(const std::vector &outputs_type) { - if (outputs_type.size() != outputs_shape_.size()) { - MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() - << " do not have the same number of outputs shape: " << outputs_shape_.size(); - return FAILED; - } - outputs_type_ = outputs_type; - return SUCCESS; -} - -void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) { - if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) { - CheckGlobalDeviceManager(); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); - if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) { - if (cost->computation_cost_ > 1.0) { - cost->computation_cost_ -= 1.0; - } - if (cost->communication_cost_ > 1.0) { - cost->communication_cost_ -= 1.0; - } - if (cost->communication_with_partial_para_ > 1.0) { - cost->communication_with_partial_para_ -= 1.0; - } - if (cost->communication_without_parameter_ > 1.0) { - cost->communication_without_parameter_ -= 1.0; - } - } - } -} - -double OperatorInfo::GetForwardMemoryCostFromCNode() { - return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); -} - -void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) { - MS_EXCEPTION_IF_NULL(s_strategy); - if (!s_strategy->IsEqual(selected_strategy_)) { - MS_LOG(INFO) << name() << "'s strategy may cause suboptimal, the determined strategy:"; - PrintStrategy(selected_strategy_); - MS_LOG(INFO) << "The minimal strategy:"; - PrintStrategy(s_strategy); - } -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h deleted file mode 100644 index 21041c3e94..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.h +++ /dev/null @@ -1,289 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "ir/base.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/group_manager.h" -#include "parallel/ops_info/ops_utils.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_info.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -using ForwardOp = OperatorVector; -using MirrorOps = std::vector; -using Ops = std::vector; -using VirtualDivOp = OperatorVector; -using TensorMaps = std::vector>; -using TensorLayouts = std::vector; -using different_type = std::vector::difference_type; -using PrimitiveAttrs = std::unordered_map; -using Strategys = std::vector; -using ReplaceGraphPtr = std::shared_ptr>, AnfNodePtr>>; - -class Edge; - -class OperatorInfo { - public: - OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs, OperatorCostPtr cost) - : name_(std::move(name)), - inputs_shape_(std::move(inputs_shape)), - outputs_shape_(std::move(outputs_shape)), - attrs_(std::move(attrs)), - is_alive_(true), - operator_cost_(cost), - outputs_type_() { - std::vector not_parameteter(inputs_shape_.size(), false); - is_parameter_ = not_parameteter; - refkey_parameter_name_ = ""; - } - - virtual ~OperatorInfo() = default; - - Status set_is_parameter(const std::vector &is_parameter); - Status SetInputAndOutputTypeLength(const std::vector &input_lengths, - const std::vector &output_lengths); - double GetOutputsTotalSize(); - // Set outputs dtype. - // If only one output, outputs_type.size() is 1. - // If output is tuple, outputs_type.size() is greater than 1. - Status set_outputs_type(const std::vector &outputs_type); - const std::vector &outputs_type() const { return outputs_type_; } - virtual Status Init(const StrategyPtr &strategy) = 0; - virtual Status InitForCostModel(const StrategyPtr &strategy) = 0; // only init the necessary parts - - // Given the stage_id (which indicates the number of devices), - // generate all strategies for this operator - virtual Status GenerateStrategies(int32_t stage_id) = 0; - const OperatorCostPtr &operator_cost() const { return operator_cost_; } - void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; } - virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; - - virtual std::shared_ptr>> GenerateBatchStrategies(); - virtual void ReComputeBatchSplitFlagList(); - void ComputeBatchSplitFlagList(); - - double GetForwardMemoryCostFromCNode(); - // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy - // is checked - Status SetCostUnderStrategyBase(const StrategyPtr &strategy); - std::vector> GetStrategyCost() { return strategy_cost_; } - // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving - // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory - // at the end of forward phase. - Status CalculateMemoryCost(); - // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated - // by the output - Status CalculateMemoryCostForInference(); - int ComputeOpAndPrevEdgeParameterInvolved(); - - ForwardOp forward_op() const { return forward_op_; } - ForwardOp replace_op() const { return replace_op_; } - OutPutInfoVector replace_op_info() const { return replace_op_info_; } - virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; } - MirrorOps mirror_ops() const { return mirror_ops_; } - Ops sub_ops() const { return sub_ops_; } - VirtualDivOp virtual_div_op() const { return virtual_div_op_; } - Shape dev_matrix_shape() const { return dev_matrix_shape_; } - std::vector inputs_tensor_info() const { return inputs_tensor_info_; } - std::vector outputs_tensor_info() const { return outputs_tensor_info_; } - std::vector> strategy_cost() const { return strategy_cost_; } - const std::string &name() const { return name_; } - void set_name(const std::string &name) { name_ = name; } - RankList global_device_list() const { return global_device_list_; } - - void AddSuccEdge(const std::shared_ptr &e) { succ_edges_.push_back(e); } - void AddPrevEdge(const std::shared_ptr &e) { prev_edges_.push_back(e); } - std::vector> succ_edges() const { return succ_edges_; } - std::vector> prev_edges() const { return prev_edges_; } - std::vector> GetAliveSuccEdges(); - std::vector> GetAlivePrevEdges(); - void ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); - void ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); - void ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); - void ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); - std::vector GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); } - void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) { - selected_strategy_ = s_strategy; - selected_cost_ = cost; - } - StrategyPtr selected_strategy() const { return selected_strategy_; } - CostPtr selected_cost() const { return selected_cost_; } - void CheckSelectedStrategy(const StrategyPtr &); - Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); } - void set_input_value(const std::vector &input_value) { input_value_ = input_value; } - const std::vector &input_value() const { return input_value_; } - void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; } - void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } - bool is_alive() const { return is_alive_; } - void SetNotAlive() { is_alive_ = false; } - StrategyPtr strategy() const { return strategy_; } - void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; } - void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); } - const std::string &refkey_parameter_name() const { return refkey_parameter_name_; } - // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated - // multiple times. This method is to correct this, and makes the cost is calulated only once. - Status CorrectMemoryCost(size_t input_index); - int is_output_parameter_involve() const { return is_output_parameter_involve_; } - int is_output_critical() const { return is_output_critical_; } - void mark_output_critical() { is_output_critical_ = 1; } - void mark_output_not_critical() { is_output_critical_ = 0; } - int used_devices() const { return used_devices_; } - // needed by rec_parser - void set_type(const std::string &type) { type_ = type; } - const std::string &type() const { return type_; } - const std::unordered_map &attrs() const { return attrs_; } - - protected: - // needed by rec_parser - std::string type_; - virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; - virtual Status InferTensorMap() = 0; - virtual Status InferForwardCommunication() = 0; - virtual Status InferMirrorOps() = 0; - virtual Status GetAttrs() = 0; - virtual Status InferTensorInfo() = 0; - virtual Status InferDevMatrixShape() = 0; - void SetDeviceListByStrategy(); - void SetRepeatedCalcDevMatrix(); - Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group); - Status CreateGroupByDim(size_t axis, std::vector *group); - Status InferAttrs(); - void ResetQueueMember(); - Status InitWithAutoRepeatCalc(const StrategyPtr &strategy); - Status InitWithManualRepeatCalc(const StrategyPtr &strategy); - Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy); - Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy); - Status InferRepeatedCalcInfo(); - Status InferVirtualDivOps(); - - // Calculate the number of repeated calculations for the output by the number of devices and the output tensor map. - // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output - // is used for grad and overload the function. If the output is a scalar, need to override the function too. - virtual Status InferAsLossDivisor(); - Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, - Shapes *inputs_slice_shape, Shapes *outputs_slice_shape); - void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &); - - std::string name_; - Shapes inputs_shape_; - Shapes outputs_shape_; - std::unordered_map attrs_; - std::vector input_value_; - TypePtr outputs_dtype_; - - StrategyPtr strategy_; - std::vector inputs_tensor_info_; - std::vector outputs_tensor_info_; - Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num as the first dimension - int32_t repeated_calc_num_ = 1; - int32_t as_loss_divisor_ = 1; - TensorMaps inputs_tensor_map_; - TensorMaps outputs_tensor_map_; - ForwardOp forward_op_; - Ops sub_ops_; - ForwardOp replace_op_; - OutPutInfoVector replace_op_info_; - ReplaceGraphPtr replace_graph_; - MirrorOps mirror_ops_; - VirtualDivOp virtual_div_op_; - RankList global_device_list_; // the size of global_device_list equal to the size of stageID - RankList local_device_list_; // the size equal to global_device_list_.size() / repeated_calc_num_ - bool infer_attrs_completed_ = false; - - bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel - // 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected. - std::vector corrected_input_indices_; - // Given a parallization strategy, there is a cost. - std::vector> strategy_cost_; - // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter - std::vector is_parameter_; - // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of - // pre-operator that has parameters as input. - std::vector is_parameter_involve_; - // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating - // peak memory cost in the training phase. - // -1: unset; 0: not parameter_involved; 1: parameter_involved - int is_output_parameter_involve_ = -1; - // Whether this output is critical, which means that this output is included in calculating peak memory cost - // in the inference phase. - // -1 : unset; 0: not critical; 1: critical - int is_output_critical_ = -1; - double outputs_total_size_ = 0.0; - bool is_calculated_outputs_size_ = false; - // for each input and output, the followings record the number of bytes of each element - std::vector inputs_type_lengths_; - std::vector outputs_type_lengths_; - std::vector> prev_edges_; - std::vector> succ_edges_; - StrategyPtr selected_strategy_; - // Used in DP algorithm - bool is_alive_; - CostPtr selected_cost_; - std::vector split_flag_list_; - std::string refkey_parameter_name_; - CNodePtr cnode_; - int32_t used_devices_ = -1; - - private: - OperatorCostPtr operator_cost_; - std::vector outputs_type_; -}; - -Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy); -Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool); -Operator CreateVirtualDivOp(int32_t div_num); -Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); -Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group); -Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); -OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); -int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); -std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes &shapes, const std::vector &split_flag_list); - -void PrintStrategy(const StrategyPtr &strategy); -// generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d]) -Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, - const Shapes &splittable_inputs, std::vector *sp_vector); -// generate strategies for that have two inputs, and input0 or input1 maybe broadcast, -// and the corresponding dimensions that are not broadcast are all relevant dimensions -// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, - std::vector *sp_vector); - -Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h deleted file mode 100644 index 45b00aed30..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_ - -#include "parallel/ops_info/activation_info.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/ops_info/batch_parallel_info.h" -#include "parallel/ops_info/bias_add_info.h" -#include "parallel/ops_info/comparison_function_info.h" -#include "parallel/ops_info/dropout_do_mask_info.h" -#include "parallel/ops_info/elementary_function_info.h" -#include "parallel/ops_info/gather_v2_info.h" -#include "parallel/ops_info/get_next_info.h" -#include "parallel/ops_info/l2_normalize_info.h" -#include "parallel/ops_info/layer_norm_info.h" -#include "parallel/ops_info/loss_info.h" -#include "parallel/ops_info/matmul_info.h" -#include "parallel/ops_info/onehot_info.h" -#include "parallel/ops_info/prelu_info.h" -#include "parallel/ops_info/reduce_method_info.h" -#include "parallel/ops_info/reshape_info.h" -#include "parallel/ops_info/transpose_info.h" -#include "parallel/ops_info/virtual_dataset_info.h" -#include "parallel/ops_info/gather_v2_p_info.h" - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_HEAD_FILES_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h deleted file mode 100644 index 9cb3c7040a..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ /dev/null @@ -1,294 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ - -namespace mindspore { -namespace parallel { -constexpr size_t PRELU_INPUTS_SIZE = 2; -constexpr size_t PRELU_OUTPUTS_SIZE = 1; -constexpr size_t PRELU_SECOND_INPUT_SIZE = 1; -constexpr int32_t PRELU_CHANNEL_INDEX = 1; -constexpr int32_t PRELU_CHANNEL_STRATEGY = 1; -constexpr int32_t NO_SPLIT_MAP = -1; -constexpr int32_t NO_SPLIT_STRATEGY = 1; -constexpr int32_t SPLIT_FLAG = 1; -constexpr int32_t NO_SPLIT_FLAG = 0; -constexpr size_t MATMUL_ATTRS_SIZE = 2; -constexpr size_t MATMUL_INPUTS_SIZE = 2; -constexpr size_t MATMUL_OUTPUTS_SIZE = 1; -constexpr size_t ACTIVATION_ATTR_SIZE = 1; -constexpr size_t SOFTMAX_ATTR_SIZE = 1; -constexpr size_t ACTIVATION_INPUTS_SIZE = 1; -constexpr size_t ACTIVATION_OUTPUTS_SIZE = 1; -constexpr size_t EXPANDDIMS_INPUT_SIZE = 2; -constexpr size_t DROPOUT_DO_MASK_CNODE_INPUT_SIZE = 4; -constexpr size_t DROPOUT_GEN_MASK_CNODE_INPUT_SIZE = 3; -constexpr size_t DROPOUT_GEN_MASK_INDEX = 2; -constexpr size_t DROPOUT_DO_MASK_KEEP_PROB_INDEX = 3; -constexpr size_t SoftmaxCrossEntropyWithLogitsAttrSize = 1; -constexpr size_t SoftmaxCrossEntropyWithLogitsInputsSize = 2; -constexpr size_t SoftmaxCrossEntropyWithLogitsOutputsSize = 2; -constexpr double EPS = 1e-6; -constexpr double INF = 1e20; - -constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only"; -constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only"; -constexpr char CHECK_SET_STRATEGY_VALID_ONCE_ONLY[] = "check_set_strategy_valid_once_only"; -constexpr char STRATEGY[] = "strategy"; -constexpr char GEN_STRATEGY[] = "gen_strategy"; -constexpr char REDUCE_OP_SUM[] = "sum"; -constexpr char REDUCE_OP_MAX[] = "max"; -constexpr char REDUCE_OP_MIN[] = "min"; -constexpr char OP_PATH[] = "mindspore.ops.operations"; -constexpr char INNER_OP_PATH[] = "mindspore.ops.operations._inner_ops"; -constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils"; -constexpr char GET_OP_FUNCTION[] = "_get_python_op"; -constexpr char KEEP_DIMS[] = "keep_dims"; -constexpr char CROSS_BATCH[] = "cross_batch"; -constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin"; -constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; -constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; -constexpr char REQUIRES_GRAD[] = "requires_grad"; -constexpr char PARAM_NAME[] = "name"; - -constexpr char RELU_TYPE[] = "relu"; -constexpr char RELU6_TYPE[] = "relu6"; -constexpr char SIGMOID_TYPE[] = "sigmoid"; -constexpr char OP[] = "op"; -constexpr char IDENTITY_INFO[] = "identity_info"; -constexpr char DIVISOR[] = "divisor"; -constexpr char NONE[] = "None"; -constexpr char DEPEND[] = "Depend"; -constexpr char BATCH_PARALLEL[] = "BatchParallel"; - -constexpr char ACTIVATION_TYPE[] = "activation_type"; -constexpr char TARGET[] = "primitive_target"; -constexpr char CPU[] = "CPU"; -constexpr char TRANSPOSE_A[] = "transpose_a"; -constexpr char TRANSPOSE_B[] = "transpose_b"; -constexpr char SHAPE[] = "shape"; -constexpr char BEGIN_MASK[] = "begin_mask"; -constexpr char END_MASK[] = "end_mask"; -constexpr char ELLIPSIS_MASK[] = "ellipsis_mask"; -constexpr char NEW_AXIS_MASK[] = "new_axis_mask"; -constexpr char SHRINK_AXIS_MASK[] = "shrink_axis_mask"; -constexpr char BEGIN[] = "begin"; -constexpr char END[] = "end"; -constexpr char STRIDES[] = "strides"; -constexpr char GROUP[] = "group"; -constexpr char AXIS[] = "axis"; -constexpr char OUTPUT_NUM[] = "output_num"; -constexpr char SPLIT_COUNT[] = "split_count"; -constexpr char SPLIT_DIM[] = "split_dim"; -constexpr char CONCAT_DIM[] = "concat_dim"; -constexpr char FORWARD[] = "forward"; -constexpr char BACKWARD[] = "backward"; -constexpr char REDISTRIBUTION[] = "redistribution"; -constexpr char REPLACE[] = "replace"; -constexpr char CONNSYMBOL[] = "/"; -constexpr char INSTANCE_NAME[] = "instance_name"; -constexpr char SPLIT_SENS[] = "split_sens"; -constexpr char SPLIT_TENSOR[] = "split_tensor"; -constexpr char DEV_MAT[] = "dev_mat"; -constexpr char TENSOR_MAP[] = "tensor_map"; -constexpr char SEED0[] = "Seed0"; -constexpr char SEED1[] = "Seed1"; -constexpr char KEEP_PROB[] = "keep_prob"; -constexpr char SRC[] = "src"; -constexpr char CLONE_INFO[] = "clone_info"; -constexpr char CLONED[] = "cloned"; -constexpr char BE_CLONED[] = "be_cloned"; -constexpr char CLONED_INDEX[] = "cloned_index"; -constexpr char BE_CLONED_INDEX[] = "be_cloned_index"; -constexpr char GROUP_RANKS[] = "group_ranks"; -constexpr char IS_IN_FORWARD[] = "is_in_forward"; -constexpr char DEFAULT_INPUT[] = "default_input"; -constexpr char DTYPE[] = "DType"; -constexpr char DEV_NUM[] = "dev_num"; -constexpr char MEAN_FLAG[] = "mean_flag"; -constexpr char TYPES[] = "types"; -constexpr char SHAPES[] = "shapes"; -constexpr char GETNEXT_NUM[] = "output_num"; -constexpr char SHARED_NAME[] = "shared_name"; -constexpr char MIRROR_OP[] = "mirror_op"; -constexpr char FORWARD_OP[] = "forward_op"; -constexpr char REDISTRIBUTION_OP[] = "redistribution_op"; -constexpr char DARA_PARALLEL[] = "data_parallel"; -constexpr char FORWARD_REDUCE_SCATTER[] = "forward_reduce_scatter"; -constexpr char OPTIMIZER_SUB_STRING[] = "optimizer"; - -// Operator -constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; -constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice"; -constexpr char SPLIT[] = "Split"; -constexpr char ALL_TO_ALL[] = "_AlltoAll"; -constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis"; -constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis"; -constexpr char SPLIT_BY_AXIS[] = "SplitByAxis"; -constexpr char ALL_REDUCE[] = "AllReduce"; -constexpr char MIRROR_OPERATOR[] = "_MirrorOperator"; -constexpr char STRIDED_SLICE[] = "StridedSlice"; -constexpr char ALL_GATHER[] = "AllGather"; -constexpr char REDUCE_SCATTER[] = "ReduceScatter"; -constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter"; -constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup"; -constexpr char CONCAT[] = "Concat"; -constexpr char SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SoftmaxCrossEntropyWithLogits"; -constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLogits"; -constexpr char MATMUL[] = "MatMul"; -constexpr char GELU[] = "Gelu"; -constexpr char TANH[] = "Tanh"; -constexpr char SOFTMAX[] = "Softmax"; -constexpr char LOG_SOFTMAX[] = "LogSoftmax"; -constexpr char ACTIVATION[] = "Activation"; -constexpr char PRELU[] = "PReLU"; -constexpr char FLOORDIV[] = "FloorDiv"; -constexpr char MAXPOOL[] = "MaxPool"; -constexpr char MAXPOOLV2[] = "MaxPoolV2"; -constexpr char L2_NORMALIZE[] = "L2Normalize"; -constexpr char TRANSPOSE[] = "Transpose"; -constexpr char RESHAPE[] = "Reshape"; -constexpr char TENSOR_ADD[] = "TensorAdd"; -constexpr char BIAS_ADD[] = "BiasAdd"; -constexpr char SUB[] = "Sub"; -constexpr char MUL[] = "Mul"; -constexpr char DIV[] = "Div"; -constexpr char REAL_DIV[] = "RealDiv"; -constexpr char ASSIGN_SUB[] = "AssignSub"; -constexpr char GREATER[] = "Greater"; -constexpr char VIRTUAL_DATA_SET[] = "_VirtualDataset"; -constexpr char VIRTUAL_DATA_SET_INFO[] = "VirtualDatasetInfo"; -constexpr char SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS[] = "SparseSoftmaxCrossEntropyWithLogits"; -constexpr char RELU[] = "ReLU"; -constexpr char ONEHOT[] = "OneHot"; -constexpr char DROPOUT_DO_MASK[] = "DropoutDoMask"; -constexpr char DROPOUT_GEN_MASK[] = "DropoutGenMask"; -constexpr char REDUCE_MAX[] = "ReduceMax"; -constexpr char REDUCE_MIN[] = "ReduceMin"; -constexpr char REDUCE_SUM[] = "ReduceSum"; -constexpr char REDUCE_MEAN[] = "ReduceMean"; -constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue"; -constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue"; -constexpr char CONV2D[] = "Conv2D"; -constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm"; -constexpr char BATCH_NORM[] = "BatchNorm"; -constexpr char LAYER_NORM[] = "LayerNorm"; -constexpr char POOLING[] = "Pooling"; -constexpr char CAST[] = "Cast"; -constexpr char MAX_POOL_WITH_ARGMAX[] = "MaxPoolWithArgmax"; -constexpr char SIMPLE_MEAN[] = "SimpleMean"; -constexpr char FLATTEN[] = "Flatten"; -constexpr char J[] = "J"; -constexpr char TMPIDENTITY_INFO_NAME[] = "identity_info"; -constexpr char COS[] = "Cos"; -constexpr char ACOS[] = "ACos"; -constexpr char EXP[] = "Exp"; -constexpr char LOG[] = "Log"; -constexpr char SIGMOID[] = "Sigmoid"; -constexpr char POW[] = "Pow"; -constexpr char MAXIMUM[] = "Maximum"; -constexpr char MINIMUM[] = "Minimum"; -constexpr char EQUAL[] = "Equal"; -constexpr char NOT_EQUAL[] = "NotEqual"; -constexpr char LOGICALNOT[] = "LogicalNot"; -constexpr char GATHERV2[] = "GatherV2"; -constexpr char SPARSE_GATHERV2[] = "SparseGatherV2"; -constexpr char STRIDEDSLICE[] = "StridedSlice"; -constexpr char BROADCAST[] = "Broadcast"; -constexpr char SQRT[] = "Sqrt"; -constexpr char ASSIGN[] = "Assign"; -constexpr char GET_NEXT[] = "GetNext"; -constexpr char SQUEEZE[] = "Squeeze"; -constexpr char NEG[] = "Neg"; -constexpr char BATCH_MATMUL[] = "BatchMatMul"; -constexpr char EXPAND_DIMS[] = "ExpandDims"; -constexpr char SQUARE[] = "Square"; -constexpr char BATCHMATMUL[] = "BatchMatMul"; -constexpr char TOPK[] = "TopK"; -constexpr char IN_TOPK[] = "InTopK"; -constexpr char PACK[] = "Pack"; -constexpr char GATHER_ND[] = "GatherNd"; -constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD"; -constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; -constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; -constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; -constexpr char ADD[] = "Add"; - -// Parallel don't care -constexpr char TUPLE_GETITEM[] = "tuple_getitem"; -constexpr char STRING_EQUAL[] = "string_equal"; -constexpr char MAKE_TUPLE[] = "make_tuple"; -constexpr char MAKE_LIST[] = "make_list"; -constexpr char MAKE_DICT[] = "make_dict"; -constexpr char MAKE_SLICE[] = "make_slice"; -constexpr char MAKE_RECORD[] = "make_record"; -constexpr char LIST_GETITEM[] = "list_getitem"; -constexpr char ARRAY_GETITEM[] = "array_getitem"; -constexpr char TUPLE_SETITEM[] = "tuple_setitem"; -constexpr char LIST_SETITEM[] = "list_setitem"; -constexpr char ARRAY_SETITEM[] = "array_setitem"; -constexpr char DICT_GETITEM[] = "dict_getitem"; -constexpr char LIST_APPEND[] = "list_append"; -constexpr char LIST_MAP[] = "list_map"; -constexpr char LIST_REDUCE[] = "list_reduce"; -constexpr char TUPLE_REVERSED[] = "tuple_reversed"; -constexpr char TILE_SHAPE[] = "tile_shape"; -constexpr char REDUCED_SHAPE[] = "reduced_shape"; -constexpr char TUPLE_DIV[] = "tuple_div"; -constexpr char TUPLE_TO_ARRAY[] = "tuple_to_array"; -constexpr char VIRTUALLOSS[] = "VirtualLoss"; -constexpr char RETURN[] = "return"; -constexpr char ENV_GETITEM[] = "env_getitem"; -constexpr char IDENTITY[] = "identity"; -constexpr char PARTIAL[] = "partial"; -constexpr char ENVSETITEM[] = "env_setitem"; -constexpr char ENVGETITEM[] = "env_getitem"; -constexpr char ENVADD[] = "env_add"; -constexpr char MAKEREFKEY[] = "MakeRefKey"; -constexpr char MAKEREF[] = "make_ref"; -constexpr char GETREFKEY[] = "get_ref_key"; -constexpr char GETREFVALUE[] = "get_ref_value"; -constexpr char GETREFORIGIN[] = "get_ref_origin"; -constexpr char STATESETITEM[] = "state_setitem"; -constexpr char SCALARSUMMARY[] = "ScalarSummary"; -constexpr char IMAGESUMMARY[] = "ImageSummary"; -constexpr char TENSORSUMMARY[] = "TensorSummary"; -constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary"; -constexpr char DEBUG[] = "Debug"; -constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; -constexpr char INVERTPERMUTATION[] = "InvertPermutation"; -constexpr char CONTROLDEPEND[] = "ControlDepend"; -constexpr char DOT[] = "dot"; -constexpr char IM2COL[] = "im2col"; -constexpr char COL2IM[] = "col2im"; -constexpr char IM2COLV1[] = "im2col_v1"; -constexpr char COL2IMV1[] = "col2im_v1"; -constexpr char RESOLVE[] = "resolve"; -constexpr char EMBED[] = "embed"; -constexpr char CREATINSTANCE[] = "create_instance"; -constexpr char ZEROSLIKE[] = "ZerosLike"; -constexpr char REF_TO_EMBED[] = "RefToEmbed"; -constexpr char STOP_GRADIENT[] = "stop_gradient"; - -constexpr size_t LAST_INDEX(size_t s) { return s - 1; } -constexpr size_t SECOND_FROM_END(size_t s) { return s - 2; } -constexpr size_t THIRD_FROM_END(size_t s) { return s - 3; } -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_UTILS_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc deleted file mode 100644 index 14483e97a1..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc +++ /dev/null @@ -1,253 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/prelu_info.h" - -#include -#include -#include - -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -/* - * prelu has 2 input - * A: A float tensor of shape [NCHW] representing the output of the preview layer. - * w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input. - * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 - */ -Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - std::vector stra = strategy->GetInputDim(); - if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy size."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy size."; - } - return FAILED; - } - if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid channel strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid channel strategy."; - } - return FAILED; - } - return SUCCESS; -} - -/* - * device matrix is same with the strategy matrix - */ -Status PReLUInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - input_strategy_ = input_strategy; - dev_matrix_shape_ = input_strategy; - return SUCCESS; -} - -Status PReLUInfo::InferMirrorOps() { - Shape param_tensor_map = inputs_tensor_map_[1]; - std::vector param_group; - if (CreateGroupByTensorMap(param_tensor_map, ¶m_group) != SUCCESS) { - return FAILED; - } else if (param_group.empty()) { - MS_LOG(INFO) << name_ << ": The mirror ops is empty."; - return SUCCESS; - } - OperatorVector op_for_param; - op_for_param = CreateMirrorOps(param_group[0].name(), param_group[0].GetDevNum()); - // op_for_inputs is empty - OperatorVector op_for_inputs; - mirror_ops_.push_back(op_for_inputs); - mirror_ops_.push_back(op_for_param); - std::string group_name = param_group[0].name(); - MS_LOG(INFO) << name_ << ": The mirror ops group is " << group_name; - return SUCCESS; -} - -Status PReLUInfo::InferForwardCommunication() { return SUCCESS; } - -/* - * the output tensor map is the same as the input tensor map - */ -Status PReLUInfo::InferTensorMap() { - TensorMap input_tensor_map; - // such as 4: input_tensor_map [3,2,1,0] - for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { - input_tensor_map.push_back((int32_t)(inputs_shape_[0].size() - i - 1)); - } - - TensorMap param_tensor_map; - if (inputs_shape_[1][0] == 1) { - param_tensor_map.push_back(-1); - } else { - param_tensor_map.push_back(input_tensor_map.at(1)); - } - inputs_tensor_map_.push_back(input_tensor_map); - inputs_tensor_map_.push_back(param_tensor_map); - outputs_tensor_map_.push_back(input_tensor_map); - return SUCCESS; -} - -Dimensions PReLUInfo::GetOutputStrategy() { - Dimensions output_strategy = input_strategy_; - return output_strategy; -} - -Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { - if (inputs_layout == nullptr || outputs_layout == nullptr) { - MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; - return FAILED; - } - TensorLayout input_layout, param_layout, output_layout; - if ((input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || - (param_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || - (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) { - return FAILED; - } - inputs_layout->push_back(input_layout); - inputs_layout->push_back(param_layout); - outputs_layout->push_back(output_layout); - return SUCCESS; -} - -Status PReLUInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape param_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Dimensions output_strategy = GetOutputStrategy(); - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {output_strategy}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - Shape param_slice_shape = inputs_slice_shape.at(1); - Shape output_slice_shape = outputs_slice_shape.at(0); - - // infer tensor layout - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { - return FAILED; - } - - TensorLayout input_layout = inputs_layout.at(0); - TensorLayout param_layout = inputs_layout.at(1); - TensorLayout output_layout = outputs_layout.at(0); - TensorInfo input_tensor_info(input_layout, input_shape, input_slice_shape); - TensorInfo param_tensor_info(param_layout, param_shape, param_slice_shape); - TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - inputs_tensor_info_.push_back(param_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status PReLUInfo::GetAttrs() { - if ((inputs_shape_.size() != PRELU_INPUTS_SIZE) || (outputs_shape_.size() != PRELU_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size " - << outputs_shape_.size() << " is wrong."; - return FAILED; - } - return SUCCESS; -} - -Status PReLUInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status PReLUInfo::GenerateStrategies(int32_t stage_id) { - if (inputs_shape_.size() != PRELU_INPUTS_SIZE) { - return FAILED; - } - if (inputs_shape_[1].size() != PRELU_SECOND_INPUT_SIZE) { - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split; - input0_split.emplace_back(1); - input0_split.emplace_back(0); - (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 2, 1); - Shape input1_split(inputs_shape_[1].size(), 0); - Shapes splittable_inputs = {input0_split, input1_split}; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/parallel/ops_info/prelu_info.h deleted file mode 100644 index 28e149fad7..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -/* - * parallel class for PReLU Primitive - */ -class PReLUInfo : public OperatorInfo { - public: - PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~PReLUInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); - Status GetAttrs() override; - Dimensions GetOutputStrategy(); - - private: - Dimensions input_strategy_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc deleted file mode 100644 index 7304666a77..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc +++ /dev/null @@ -1,571 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/reduce_method_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/tensor_layout/tensor_redistribution.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - return SUCCESS; -} - -Status ReduceMethod::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - - dev_matrix_shape_ = input_strategy; - - return SUCCESS; -} - -std::vector ReduceMethod::reduce_dim() { - std::vector dim_list; - if (input_value_.size() < 2) { - MS_LOG(EXCEPTION) << name_ << ": Input value size is smaller than 2."; - } - if (input_value_.back() == nullptr) { - MS_LOG(EXCEPTION) << name_ << ": Input value is nullptr."; - } - MS_ASSERT(inputs_shape_.size() == 1); - auto input_dim = inputs_shape_.at(0).size(); - if (input_value_.back()->isa()) { - auto attr_axis = GetValue>(input_value_.back()); - // axis is (), reduce all dim - if (attr_axis.empty()) { - for (size_t i = 0; i < input_dim; ++i) { - dim_list.push_back(SizeToInt(i)); - } - } else { - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } - } - } else if (input_value_.back()->isa()) { - int axis = GetValue(input_value_.back()); - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Axis type is invalid."; - } - - return dim_list; -} - -Status ReduceMethod::GetAttrs() { - // get attr cross_batch and keep_dims - auto keep_dims_iter = attrs_.find(KEEP_DIMS); - if (keep_dims_iter == attrs_.end()) { - MS_LOG(ERROR) << name_ << ": Don't have attr keep_dims."; - return FAILED; - } - - if (keep_dims_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(keep_dims_iter->second); - if (!keep_dims_iter->second->isa()) { - MS_LOG(ERROR) << name_ << ": Keep_dims is not a bool."; - return FAILED; - } - keepdims_ = keep_dims_iter->second->cast()->value(); - } - - auto cross_batch_iter = attrs_.find(CROSS_BATCH); - if (cross_batch_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(cross_batch_iter->second); - if (!cross_batch_iter->second->isa()) { - MS_LOG(ERROR) << name_ << ": cross_batch is not a bool."; - return FAILED; - } - cross_batch_ = cross_batch_iter->second->cast()->value(); - } - auto reducemethodcost = std::dynamic_pointer_cast(operator_cost()); - if (reducemethodcost == nullptr) { - MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!"; - return FAILED; - } - reducemethodcost->set_cross_batch(cross_batch_); - return SUCCESS; -} - -Status ReduceMethod::InferTensorMap() { - std::vector tensor_map_index, dim_list, output_tensor_map; - size_t size = inputs_shape_.at(0).size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - 1 - i)); - } - dim_list = reduce_dim(); - for (size_t i = 0; i < size; ++i) { - if (find(dim_list.begin(), dim_list.end(), SizeToInt(i)) != dim_list.end()) { - if (keepdims_) { - output_tensor_map.push_back(-1); - } else { - continue; - } - } else { - output_tensor_map.push_back(tensor_map_index[i]); - } - } - inputs_tensor_map_.push_back(tensor_map_index); - outputs_tensor_map_.push_back(output_tensor_map); - - return SUCCESS; -} - -bool IsDataParallelStrategy(const Dimensions &strategy) { - CheckGlobalDeviceManager(); - size_t total_dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - if (strategy.empty()) { - MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty"; - } - - return (IntToSize(strategy[0]) == total_dev_num); -} - -Status ReduceMethod::InferForwardCommunication() { - Dimensions stra = strategy_->GetInputDim().at(0); - if (cross_batch_ && IsDataParallelStrategy(stra)) { - MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; - return SUCCESS; - } - if (cross_batch_) { - MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; - return SUCCESS; - } - forward_op_.clear(); - std::vector dim_list = reduce_dim(); - size_t size = stra.size(); - // judge if the reduce dim is partitioned. - Shape group_creat_map; - if (dev_matrix_shape_.size() > size) { - group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); - } - for (size_t index = 0; index < size; ++index) { - auto pos = - std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; }); - if (pos != dim_list.end() && stra[index] != 1) { - continue; - } - group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); - } - std::vector forward_group; - if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; - return FAILED; - } - if (!forward_group.empty()) { - Operator op = CreateAllReduceOp(reduce_method_, forward_group[0].name()); - forward_op_.push_back(op); - std::string group_name = forward_group[0].name(); - MS_LOG(INFO) << name_ << ": Forward communication group is " << group_name; - } - - return SUCCESS; -} - -ForwardOp CreatReduceMeanForwardOp(const std::vector &forward_group, const TypePtr &dtype) { - // Creat AllReduceSum op - Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name()); - std::string group_name = forward_group[0].name(); - MS_LOG(INFO) << "The group of forward all reduce is " << group_name; - - // Creat RealDiv op - OperatorName operator1_name = REAL_DIV; - std::vector device_list = forward_group[0].GetDevicesList(); - auto divisor = static_cast(device_list.size()); - std::vector tensor_data = {divisor}; - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(tensor_data, dtype); - ValuePtr op1_param_value = MakeValue(tensor_ptr); - Attr op1_param = std::make_pair("divisor", op1_param_value); - OperatorParams operator1_params = {std::make_pair(op1_param, 2)}; - OperatorAttrs operator1_attrs; - OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params); - Operator op1 = std::make_pair(operator1_name, operator1_args); - ForwardOp forward_op = {op0, op1}; - - std::string dtype_name = dtype->ToString(); - MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name; - return forward_op; -} - -Status ReduceMeanInfo::InferForwardCommunication() { - Dimensions stra = strategy_->GetInputDim().at(0); - if (cross_batch_ && IsDataParallelStrategy(stra)) { - MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; - return SUCCESS; - } - forward_op_.clear(); - std::vector dim_list = reduce_dim(); - size_t size = stra.size(); - // judge if the reduce dim is partitioned. - Shape group_creat_map; - if (dev_matrix_shape_.size() > size) { - group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); - } - for (size_t index = 0; index < size; ++index) { - auto pos = - std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; }); - if (pos != dim_list.end() && stra[index] != 1) { - continue; - } - group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); - } - std::vector forward_group; - if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; - return FAILED; - } - if (!forward_group.empty()) { - if ((outputs_dtype_ == nullptr) || !outputs_dtype_->isa()) { - MS_LOG(ERROR) << name_ << ": The dtype of output is not Array"; - return FAILED; - } - - auto element_type = outputs_dtype_->cast()->element(); - forward_op_ = CreatReduceMeanForwardOp(forward_group, element_type); - } - - return SUCCESS; -} - -Status ReduceMethod::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_tensor_map = inputs_tensor_map_.at(0); - std::vector input_group; - if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " Infer MirrorOps failed."; - return FAILED; - } - - OperatorVector op_for_weight; - OperatorVector op_for_reduce_axis; // helper node - if (input_group.empty()) { - MS_LOG(INFO) << name_ << ": The mirror ops is empty."; - return SUCCESS; - } else { - op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); - mirror_ops_.push_back(op_for_weight); - mirror_ops_.push_back(op_for_reduce_axis); - std::string group_name = input_group[0].name(); - MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success, the group is " << group_name; - } - - return SUCCESS; -} - -Status ArgMaxWithValueInfo::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_tensor_map = inputs_tensor_map_.at(0); - std::vector input_group; - if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed."; - return FAILED; - } - - OperatorVector op_for_weight; - if (input_group.empty()) { - MS_LOG(INFO) << name_ << ": The mirror ops is empty."; - return SUCCESS; - } else { - op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); - mirror_ops_.push_back(op_for_weight); - MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success."; - } - - return SUCCESS; -} - -Dimensions ReduceMethod::InferOutputStrategy() { - std::vector dim_list = reduce_dim(); - Dimensions output_strategy; - Dimensions stra = strategy_->GetInputDim().at(0); - // if keepdims_ is true,then output strategy is same with input. - for (size_t i = 0; i < stra.size(); ++i) { - if (find(dim_list.begin(), dim_list.end(), SizeToInt(i)) != dim_list.end()) { - if (keepdims_) { - output_strategy.push_back(1); - } - } else { - output_strategy.push_back(stra[i]); - } - } - return output_strategy; -} - -Status ReduceMethod::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Dimensions output_strategy = InferOutputStrategy(); - - Strategys outputs_strategy = {output_strategy}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - Shape output_slice_shape = outputs_slice_shape.at(0); - - TensorLayout input_tensor_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { - return FAILED; - } - - std::vector dim_list = reduce_dim(); - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); - input_tensor_info.set_reduce_dim(dim_list); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - - return SUCCESS; -} - -Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status ReduceMethod::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - is_auto_parallel_ = true; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status ReduceMethod::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - - return SUCCESS; -} - -Status ReduceMethod::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed"; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed"; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success"; - return SUCCESS; -} - -std::vector ArgMaxWithValueInfo::reduce_dim() { - std::vector dim_list; - auto iter = attrs_.find(AXIS); - if (iter == attrs_.end()) { - MS_LOG(EXCEPTION) << name_ << ": Don't have attr axis."; - } - - MS_ASSERT(inputs_shape_.size() == 1); - auto input_dim = inputs_shape_.at(0).size(); - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - auto attr_axis = GetValue>(iter->second); - if (attr_axis.empty()) { - for (size_t i = 0; i < input_dim; ++i) { - dim_list.push_back(SizeToInt(i)); - } - } else { - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } - } - } else if (iter->second->isa()) { - int axis = GetValue(iter->second); - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Axis type is invalid."; - } - - return dim_list; -} - -Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) { - if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; - } else { - MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; - } - return FAILED; - } - std::vector dim_list = reduce_dim(); - MS_ASSERT(dim_list.size() == 1); - - std::vector stra = strategy->GetInputDim(); - MS_ASSERT(stra.size() == 1); - Shape input_strategy = stra.at(0); - MS_ASSERT(dim_list.at(0) < input_strategy.size()); - if (input_strategy.at(IntToSize(dim_list.at(0))) != 1) { - MS_LOG(WARNING) - << name_ - << " CheckStrategy for ArgMaxWithValueInfo, the strategy corresponding to axis is not one, real strategy " - "is " - << input_strategy.at(IntToSize(dim_list.at(0))) - << ", the output index may be not compatible with the stand alone Primitive"; - } - return SUCCESS; -} - -Status ArgMaxWithValueInfo::InferTensorMap() { - if (ReduceMethod::InferTensorMap() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed"; - return FAILED; - } - MS_ASSERT(outputs_tensor_map_.size() == 1); - outputs_tensor_map_.push_back(outputs_tensor_map_[0]); - return SUCCESS; -} - -Status ArgMaxWithValueInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Dimensions output_strategy = InferOutputStrategy(); - - Strategys outputs_strategy = {output_strategy, output_strategy}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - Shape output_slice_shape = outputs_slice_shape.at(0); - - TensorLayout input_tensor_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { - return FAILED; - } - - std::vector dim_list = reduce_dim(); - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); - input_tensor_info.set_reduce_dim(dim_list); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status ArgMaxWithValueInfo::InferAsLossDivisor() { - if (outputs_tensor_map_.empty()) { - MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty."; - return FAILED; - } - - MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer"; - if (outputs_tensor_map_[0].empty()) { - as_loss_divisor_ = SizeToInt(global_device_list_.size()); - MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor."; - return SUCCESS; - } - - as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); - - std::string dev_matrix_shape_str = ShapeToString(dev_matrix_shape_); - std::string output_tensor_map_str = ShapeToString(outputs_tensor_map_[0]); - MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and loss divisor is " << dev_matrix_shape_str - << ", " << output_tensor_map_str << ", " << as_loss_divisor_; - return SUCCESS; -} - -Status ArgMaxWithValueInfo::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 2)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - is_auto_parallel_ = true; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated strategy " << success; - PrintStrategy(sp); - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h deleted file mode 100644 index 796c7e457b..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ - -#include -#include -#include -#include - -#include "ir/tensor.h" -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class ReduceMethod : public OperatorInfo { - public: - ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~ReduceMethod() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - std::string reduce_method_; - bool keepdims_ = false; - bool cross_batch_ = false; - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - Dimensions InferOutputStrategy(); - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferMirrorOps() override; - virtual std::vector reduce_dim(); - Status InferForwardCommunication() override; - Status InferDevMatrixShape() override; -}; - -class ReduceMaxInfo : public ReduceMethod { - public: - ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_MAX; - } - - ~ReduceMaxInfo() override = default; -}; - -class ArgMaxWithValueInfo : public ReduceMethod { - public: - ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_MAX; - } - - ~ArgMaxWithValueInfo() override = default; - - Status GenerateStrategies(int32_t stage_id) override; - - protected: - std::vector reduce_dim() override; - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferAsLossDivisor() override; -}; - -class ArgMinWithValueInfo : public ArgMaxWithValueInfo { - public: - ArgMinWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArgMaxWithValueInfo(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_MIN; - } - - ~ArgMinWithValueInfo() override = default; -}; - -class ReduceMeanInfo : public ReduceMethod { - public: - ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - set_cost(std::make_shared()); - } - - ~ReduceMeanInfo() override = default; - - protected: - Status InferForwardCommunication() override; -}; - -class ReduceSumInfo : public ReduceMethod { - public: - ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_SUM; - } - - ~ReduceSumInfo() override = default; -}; - -class ReduceMinInfo : public ReduceMethod { - public: - ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_MIN; - } - - ~ReduceMinInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc deleted file mode 100644 index 57e1a76d0a..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc +++ /dev/null @@ -1,507 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/reshape_info.h" - -#include -#include - -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - size_t strategy_size = strategy->GetInputNumber(); - if (strategy_size != 1) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy size " << strategy_size; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy size " << strategy_size; - } - return FAILED; - } - return SUCCESS; -} - -/* - * support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of - * device matrix - * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) - */ -Status ReshapeInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - input_strategy_ = stra.at(0); - dev_matrix_shape_.push_back(input_strategy_[0]); - return SUCCESS; -} - -/* - * there is no Parameter for Reshape Primitive, so no need to do allreduce - */ -Status ReshapeInfo::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_tensor_map = input_layout_.tensor_map().array(); - std::vector input_group; - if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed."; - return FAILED; - } - - OperatorVector op_for_input; - if (input_group.empty()) { - MS_LOG(INFO) << name_ << ": The mirror ops is empty."; - return SUCCESS; - } - if (!input_group.empty()) { - op_for_input = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); - std::string group_name = input_group[0].name(); - MS_LOG(INFO) << name_ << ": Create the mirror ops for input_a success, group is " << group_name; - } - mirror_ops_.push_back(op_for_input); - OperatorVector op_for_input_empty; - mirror_ops_.push_back(op_for_input_empty); - - return SUCCESS; -} - -/* - * there is no reduction dimension for forward computation of Reshape Primitive, so no need to do allreduce - */ -Status ReshapeInfo::InferForwardCommunication() { return SUCCESS; } - -/* - * get shape input of Reshape Primitive - * the result is saved in parameter_input_v_ - * not support -1 - */ -Status ReshapeInfo::GetParameterInput() { - if (input_value_[1] == nullptr) { - MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr."; - return FAILED; - } - std::vector elements; - ValueTuplePtr dim_tuple = input_value_[1]->cast(); - if (dim_tuple == nullptr) { - MS_LOG(ERROR) << name_ << ": Input_value_[1] must be ValueTuplePtr."; - return FAILED; - } - elements = dim_tuple->value(); - if (elements.size() != outputs_shape_[0].size()) { - MS_LOG(ERROR) << name_ << ": Elements size must equal to outputs shape[0] size."; - return FAILED; - } - - for (auto &element : elements) { - MS_EXCEPTION_IF_NULL(element); - if (element->isa()) { - int32_t axis = element->cast()->value(); - parameter_input_v_.push_back(axis); - } else { - MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; - return FAILED; - } - } - return SUCCESS; -} - -Status ReshapeInfo::ComputeReplaceOp() { - RankList dev_list = global_device_list(); - TensorRedistribution tensor_redistribution(!is_generating_costs_, true); - if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) { - if (is_generating_costs_) { - MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed."; - } else { - MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; - } - return FAILED; - } - MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); - MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString(); - MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); - RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); - if (redistribution_oplist_ptr == nullptr) { - if (is_generating_costs_) { - MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed."; - } else { - MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; - } - return FAILED; - } - replace_op_ = redistribution_oplist_ptr->first; - replace_op_info_ = redistribution_oplist_ptr->second; - MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size(); - return SUCCESS; -} - -/* - * the first dimension of input tensor map and output tensor map is set to the last dimension of device arrangement, - * all other dimension is set to None - * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) - */ -Status ReshapeInfo::InferTensorMap() { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": inputs shape and outputs shape size must be 1. inputs shape and outputs shape are " - << inputs_shape_.size() << " and " << outputs_shape_.size(); - return FAILED; - } - - std::vector tensor_map_index_input; - tensor_map_index_input.push_back(0); - - for (size_t j = 1; j < inputs_shape_[0].size(); ++j) { - tensor_map_index_input.push_back(MAP_NONE); - } - inputs_tensor_map_.push_back(tensor_map_index_input); - - std::vector tensor_map_index_output; - tensor_map_index_output.push_back(0); - - for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { - tensor_map_index_output.push_back(MAP_NONE); - } - outputs_tensor_map_.push_back(tensor_map_index_output); - return SUCCESS; -} - -/* - * the output tensor strategy is the same as input tensor strategy - * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) - */ -Strategys ReshapeInfo::GetOutputsStrategy() { - Strategys outputs_strategy; - std::vector strategy; - strategy.push_back(input_strategy_[0]); - for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { - strategy.push_back(1); - } - outputs_strategy.push_back(strategy); - return outputs_strategy; -} - -Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { - if (inputs_layout == nullptr || outputs_layout == nullptr) { - MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; - return FAILED; - } - Arrangement dev_matrix; - Status status = dev_matrix.Init(dev_matrix_shape_); - if (status != Status::SUCCESS) { - return status; - } - // infer input tensor info - Shape shape_array_in = inputs_shape_.at(0); - TensorMap tensor_map_array_in = inputs_tensor_map_.at(0); - TensorLayout tensor_layout_in; - Map tensor_map_in; - status = tensor_map_in.Init(tensor_map_array_in); - if (status != Status::SUCCESS) { - return status; - } - Arrangement shape_in; - status = shape_in.Init(shape_array_in); - if (status != Status::SUCCESS) { - return status; - } - (void)tensor_layout_in.Init(dev_matrix, tensor_map_in, shape_in); - inputs_layout->push_back(tensor_layout_in); - // infer output tensor info - Shape shape_array_out = outputs_shape_.at(0); - - TensorMap tensor_map_array_out = outputs_tensor_map_.at(0); - TensorLayout tensor_layout_out; - Map tensor_map_out; - status = tensor_map_out.Init(tensor_map_array_out); - if (status != Status::SUCCESS) { - return status; - } - Arrangement shape_out; - status = shape_out.Init(shape_array_out); - if (status != Status::SUCCESS) { - return status; - } - (void)tensor_layout_out.Init(dev_matrix, tensor_map_out, shape_out); - outputs_layout->push_back(tensor_layout_out); - - input_layout_ = tensor_layout_in; - output_layout_ = tensor_layout_out; - return SUCCESS; -} - -Status ReshapeInfo::InferTensorInfo() { - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = GetOutputsStrategy(); - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { - return FAILED; - } - TensorLayout tensor_layout_in = inputs_layout.at(0); - TensorLayout tensor_layout_out = outputs_layout.at(0); - Shape shape_array_in = inputs_shape_.at(0); - Shape slice_shape_in = inputs_slice_shape.at(0); - Shape shape_array_out = outputs_shape_.at(0); - Shape slice_shape_out = outputs_slice_shape.at(0); - TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in); - TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out); - inputs_tensor_info_.push_back(tensor_info_in); - outputs_tensor_info_.push_back(tensor_info_out); - return SUCCESS; -} - -void ReshapeInfo::InferTensorInfoByLayout() { - TensorInfo tensor_info_in(input_layout_); - TensorInfo tensor_info_out(output_layout_); - inputs_tensor_info_.push_back(tensor_info_in); - outputs_tensor_info_.push_back(tensor_info_out); -} - -/* - * compute parameter_input_v_ during this method - */ -Status ReshapeInfo::GetAttrs() { return GetParameterInput(); } - -void ReshapeInfo::device_number(const StrategyPtr &strategy) { - int32_t stage = 0; - if (strategy != nullptr) { - stage = strategy->GetInputStage(); - } - CheckGlobalDeviceManager(); - global_device_list_ = g_device_manager->GetDeviceListByStageId(stage); - dev_num_ = SizeToInt(global_device_list_.size()); - MS_ASSERT(dev_num_ > 0); -} - -Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) { - std::vector tensor_map_index; - for (size_t i = 0; i < shape.size(); i++) { - tensor_map_index.push_back(MAP_NONE); - } - Status status = layout->InitFromVector({dev_num_}, tensor_map_index, shape); - if (status != Status::SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferDefaultLayout failed."; - return status; - } - return Status::SUCCESS; -} - -Status ReshapeInfo::Init(const StrategyPtr &strategy) { - ResetQueueMember(); - device_number(strategy); - if (strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - } else { - if (!input_layout_set_flag_) { - MS_ASSERT(inputs_shape_.size() == 1); - Status status = InferDefaultLayout(inputs_shape_.at(0), &input_layout_); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": infer input default layout failed."; - return status; - } - } - if (!output_layout_set_flag_) { - MS_ASSERT(output_layout_.size() == 1); - Status status = InferDefaultLayout(outputs_shape_.at(0), &output_layout_); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": infer output default layout failed."; - return status; - } - } - inputs_tensor_map_.push_back(input_layout_.tensor_map().array()); - outputs_tensor_map_.push_back(output_layout_.tensor_map().array()); - InferTensorInfoByLayout(); - // change dev_matrix_shape_ to input_layout_ device_arrangement before InferMirrorOps - dev_matrix_shape_ = input_layout_.device_arrangement().array(); - if (InferMirrorOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; - return FAILED; - } - // change dev_matrix_shape_ to output_layout_ device_arrangement before InferVirtualDivOps - dev_matrix_shape_ = output_layout_.device_arrangement().array(); - if (InferVirtualDivOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; - return FAILED; - } - } - Status status = ComputeReplaceOp(); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed."; - return status; - } - return SUCCESS; -} - -Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -void ReshapeInfo::SetCostForReshapeWithParameter() { - size_t success = 0; - for (auto &sp : sp_vector_) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } -} - -void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) { - MS_EXCEPTION_IF_NULL(strategy); - int32_t stage_id = strategy->GetInputStage(); - double computation_cost = - operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - std::shared_ptr result = std::make_shared(computation_cost, communication_cost); - result->communication_without_parameter_ = - operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - result->communication_with_partial_para_ = - result->communication_without_parameter_ + - COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); - - // Breaking ties for preferring data parallelization - BreakingTiesForPerferringDataParallel(strategy, result); - // refine communication cost calculation for practice - RefineForPracticalCost(result, false); - - std::shared_ptr swc = - std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); - swc->cost_list.push_back(result); - strategy_cost_.emplace_back(swc); -} - -Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GetAttrs failed."; - return FAILED; - } - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split; - (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - // strategy used only in the input node is parameter, - // in other case, use the input node's output_layout as input_layout. - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; - return FAILED; - } - return SUCCESS; -} - -Status ReshapeInfo::GenetateStrategyCosts(const std::vector> &pre_stra_costs, - const std::vector> &next_stra_costs, - int32_t out_index, int32_t in_index, bool is_prev_param) { - is_generating_costs_ = true; - for (auto pre_stra_cost : pre_stra_costs) { - std::vector pre_out_tensor_infos; - if (is_prev_param) { - pre_out_tensor_infos = pre_stra_cost->inputs_ptr; - } else { - pre_out_tensor_infos = pre_stra_cost->outputs_ptr; - } - if (pre_out_tensor_infos.size() <= IntToSize(out_index)) { - MS_LOG(ERROR) << "out_index is out of range of the tensor_infos in setting reshape's input_layout"; - return FAILED; - } - TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; - SetInputLayout(pre_out_tensor_info.tensor_layout()); - // infer pre_node output strategy from output_layout. - Dimensions stra = pre_out_tensor_info.InferStrategy(); - if (stra.empty()) { - MS_LOG(ERROR) << "Infer strategy by tensor_info failed"; - return FAILED; - } - std::vector stra_inputs = {stra}; - StrategyPtr reshape_stra = std::make_shared(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); - if (next_stra_costs.empty()) { - if (Init(nullptr) == FAILED) { - MS_LOG(ERROR) << "Failure:operator reshape init failed"; - return FAILED; - } - SetCostForReshape(reshape_stra); - continue; - } - for (auto next_stra_cost : next_stra_costs) { - std::vector next_in_tensor_infos = next_stra_cost->inputs_ptr; - if (next_in_tensor_infos.size() <= IntToSize(in_index)) { - MS_LOG(ERROR) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; - return FAILED; - } - TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; - SetOutputLayout(next_in_tensor_info.tensor_layout()); - if (Init(nullptr) == FAILED) { - MS_LOG(DEBUG) << "Failure:operator reshape init failed"; - continue; - } - SetCostForReshape(reshape_stra); - } - } - is_generating_costs_ = false; - if (strategy_cost_.empty()) { - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h deleted file mode 100644 index 77a1f8e7f1..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ - -#include - -#include -#include -#include -#include - -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -/* - * parallel class for Reshape Primitive - */ -class ReshapeInfo : public OperatorInfo { - public: - ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), - dev_num_(0), - pre_operator_index_(0), - next_operator_index_(0), - input_layout_set_flag_(false), - output_layout_set_flag_(false) {} - ~ReshapeInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - void SetInputLayout(const TensorLayout &input_layout) { - input_layout_ = input_layout; - input_layout_set_flag_ = true; - } - void SetOutputLayout(const TensorLayout &output_layout) { - output_layout_ = output_layout; - output_layout_set_flag_ = true; - } - void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy); - void SetCostForReshapeWithParameter(); - void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; } - void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } - void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; } - void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; } - Status GenetateStrategyCosts(const std::vector> &pre_stra_costs, - const std::vector> &next_stra_costs, int32_t out_index, - int32_t in_index, bool is_prev_param); - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - std::string pre_operator_name() const { return pre_operator_name_; } - std::string next_operator_name() const { return next_operator_name_; } - int32_t pre_operator_index() const { return pre_operator_index_; } - int32_t next_operator_index() const { return next_operator_index_; } - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); - Status GetAttrs() override; - Strategys GetOutputsStrategy(); - - private: - Status GetParameterInput(); - Status ComputeReplaceOp(); - void InferTensorInfoByLayout(); - void device_number(const StrategyPtr &strategy); - Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); - - int32_t dev_num_; - int32_t pre_operator_index_; - int32_t next_operator_index_; - std::vector parameter_input_v_; - std::vector sp_vector_; - Dimensions input_strategy_; - TensorLayout input_layout_; - TensorLayout output_layout_; - bool input_layout_set_flag_; - bool output_layout_set_flag_; - bool is_generating_costs_; - std::string pre_operator_name_; - std::string next_operator_name_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.cc b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.cc deleted file mode 100644 index 772a4f83f6..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.cc +++ /dev/null @@ -1,147 +0,0 @@ -/** -#include "utils/log_adapter.h" - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/tmp_identity_info.h" - -#include -#include - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": invalid strategy."; - } - return FAILED; - } - return SUCCESS; -} - -Status TmpIdentityInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - dev_matrix_shape_ = input_strategy; - return SUCCESS; -} - -Status TmpIdentityInfo::InferTensorMap() { - std::vector tensor_map_index; - size_t size = inputs_shape_[0].size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - 1 - i)); - } - - inputs_tensor_map_.push_back(tensor_map_index); - outputs_tensor_map_.push_back(tensor_map_index); - return SUCCESS; -} - -Status TmpIdentityInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {inputs_strategy.at(0)}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - - TensorLayout input_tensor_layout; - if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(input_tensor_info); // the same as input - - return SUCCESS; -} - -Status TmpIdentityInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h deleted file mode 100644 index f7895d0511..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ - -#include -#include -#include - -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class TmpIdentityInfo : public OperatorInfo { - // This operator is only used for the case of a parameter tensor being used by multiple operators, where we - // consider this parameter tensor as TmpIdentityInfo operator. TmpIdentityInfo operator tasks as input a tensor, - // and outputs the same tensor. After the transformation, subsequent operators can share the output tensor. - public: - TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs, - const std::string &name = IDENTITY_INFO) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~TmpIdentityInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override { return SUCCESS; } - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/parallel/ops_info/transpose_info.cc deleted file mode 100644 index 49bbae0cb4..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc +++ /dev/null @@ -1,247 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/transpose_info.h" - -#include -#include - -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - return SUCCESS; -} - -Status TransposeInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - input_strategy_ = stra.at(0); - for (auto &iter : input_strategy_) { - dev_matrix_shape_.push_back(iter); - } - return SUCCESS; -} - -// there is no Parameter for Transpose Primitive, so no need to do all reduce -Status TransposeInfo::InferMirrorOps() { return SUCCESS; } - -// there is no reduction dimension for forward computation of Transpose Primitive, so no need to do all reduce -Status TransposeInfo::InferForwardCommunication() { return SUCCESS; } - -/* - * get perm input of Transpose Primitive - * perm is a permutation of the dimensions of input - * the result is saved in axis_v_ - */ -Status TransposeInfo::ComputeAxis() { - if (input_value_[1] == nullptr) { - MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr."; - return FAILED; - } - std::vector elements; - ValueTuplePtr dim_tuple = input_value_[1]->cast(); - if (dim_tuple == nullptr) { - MS_LOG(ERROR) << name_ << ": input_value_[1] must be ValueTuplePtr."; - return FAILED; - } - elements = dim_tuple->value(); - if (elements.size() != inputs_shape_[0].size()) { - MS_LOG(ERROR) << name_ << ": elements size must equal to inputs shape 0 size."; - return FAILED; - } - axis_v_.clear(); - for (auto &element : elements) { - MS_EXCEPTION_IF_NULL(element); - if (element->isa()) { - int32_t axis = element->cast()->value(); - axis_v_.push_back(axis); - } else { - MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; - return FAILED; - } - } - - for (int32_t i = 0; i < SizeToInt(axis_v_.size()); i++) { - auto iter = std::find(axis_v_.begin(), axis_v_.end(), i); - if (iter == axis_v_.end()) { - MS_LOG(ERROR) << name_ << ": axis_v_ must be a permutation."; - } - } - return SUCCESS; -} - -// the output tensor map is the permutation of input tensor map, the permutation is axis_v -Status TransposeInfo::InferTensorMap() { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": inputs_shape_ and outputs_shape_ size must be 1, inputs shape and outputs shape is " - << inputs_shape_.size() << ", " << outputs_shape_.size(); - return FAILED; - } - - std::vector tensor_map_index_input; - for (size_t j = 0; j < inputs_shape_[0].size(); ++j) { - tensor_map_index_input.push_back(SizeToInt(inputs_shape_[0].size() - j - 1)); - } - inputs_tensor_map_.push_back(tensor_map_index_input); - - std::vector tensor_map_index_output = tensor_map_index_input; - for (uint32_t i = 0; i < tensor_map_index_output.size(); i++) { - tensor_map_index_output[i] = tensor_map_index_input[IntToUint(axis_v_[i])]; - } - outputs_tensor_map_.push_back(tensor_map_index_output); - return SUCCESS; -} - -// the output tensor strategy is the permutation of input tensor strategy, the permutation is axis_v -Strategys TransposeInfo::GetOutputsStrategy() { - Strategys outputs_strategy; - std::vector strategy = input_strategy_; - for (uint32_t i = 0; i < strategy.size(); i++) { - strategy[i] = input_strategy_[IntToUint(axis_v_[i])]; - } - outputs_strategy.push_back(strategy); - return outputs_strategy; -} - -Status TransposeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { - if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { - MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; - return FAILED; - } - Shape shape_in = inputs_shape_.at(0); - TensorMap tensor_map_in = inputs_tensor_map_.at(0); - Shape shape_out = outputs_shape_.at(0); - TensorMap tensor_map_out = outputs_tensor_map_.at(0); - - TensorLayout tensor_layout_in, tensor_layout_out; - if ((tensor_layout_in.InitFromVector(dev_matrix_shape_, tensor_map_in, shape_in) != SUCCESS) || - (tensor_layout_out.InitFromVector(dev_matrix_shape_, tensor_map_out, shape_out) != SUCCESS)) { - return FAILED; - } - - inputs_layout->push_back(tensor_layout_in); - outputs_layout->push_back(tensor_layout_out); - return SUCCESS; -} - -Status TransposeInfo::InferTensorInfo() { - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = GetOutputsStrategy(); - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { - return FAILED; - } - TensorLayout tensor_layout_in = inputs_layout.at(0); - TensorLayout tensor_layout_out = outputs_layout.at(0); - Shape shape_array_in = inputs_shape_.at(0); - Shape slice_shape_in = inputs_slice_shape.at(0); - Shape shape_array_out = outputs_shape_.at(0); - Shape slice_shape_out = outputs_slice_shape.at(0); - TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in); - TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out); - inputs_tensor_info_.push_back(tensor_info_in); - outputs_tensor_info_.push_back(tensor_info_out); - return SUCCESS; -} - -// compute axis_v_ during this method -Status TransposeInfo::GetAttrs() { return ComputeAxis(); } - -Status TransposeInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status TransposeInfo::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GetAttrs failed."; - return FAILED; - } - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << "strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/parallel/ops_info/transpose_info.h deleted file mode 100644 index 50b76bde65..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -/* - * parallel class for Transpose Primitive - */ -class TransposeInfo : public OperatorInfo { - public: - TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~TransposeInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); - Status GetAttrs() override; - Strategys GetOutputsStrategy(); - - private: - Status ComputeAxis(); - std::vector axis_v_; - Dimensions input_strategy_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc deleted file mode 100644 index ce8b04d802..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc +++ /dev/null @@ -1,229 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/virtual_dataset_info.h" - -#include -#include -#include - -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" -#include "parallel/context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - if (stra.size() < 1) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Strategy size must be larger than 1."; - } else { - MS_LOG(ERROR) << name_ << ": Strategy size must be larger than 1."; - } - return FAILED; - } - if (stra.size() == 1) { - MS_LOG(WARNING) << name_ << ": Strategy size is 1."; - return SUCCESS; - } - Dimensions strategy_first = stra.at(1); - for (auto iter_strategy = stra.begin() + 1; iter_strategy != stra.end(); ++iter_strategy) { - if (iter_strategy->empty()) { - MS_LOG(ERROR) << name_ << ": iter_strategy size is zero."; - } - if (strategy_first.at(0) != *(iter_strategy->begin())) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": The first dimension of each strategy must be the same."; - } else { - MS_LOG(ERROR) << name_ << ": The first dimension of each strategy must be the same."; - } - return FAILED; - } - - for (auto iter_element = iter_strategy->begin() + 1; iter_element != iter_strategy->end(); ++iter_element) { - if (*iter_element != 1) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": All dimension except the first dimension of each strategy must be 1."; - } else { - MS_LOG(ERROR) << name_ << ": All dimension except the first dimension of each strategy must be 1."; - } - return FAILED; - } - } - } - return SUCCESS; -} - -Status VirtualDatasetInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions strategy_first = stra.at(0); - int32_t stage = strategy_->GetInputStage(); - CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); - int32_t batch_split_num = strategy_first.at(0); - dev_matrix_shape_.push_back(batch_split_num); - if (dev_num > batch_split_num) { - dev_matrix_shape_.push_back(dev_num / batch_split_num); - } - - return SUCCESS; -} - -Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; } - -Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; } - -Status VirtualDatasetInfo::InferTensorMap() { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - - for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { - std::vector tensor_map_index; - if (full_batch) { - tensor_map_index.push_back(MAP_NONE); - } else { - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); - } - for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) { - tensor_map_index.push_back(MAP_NONE); - } - inputs_tensor_map_.push_back(tensor_map_index); - outputs_tensor_map_.push_back(tensor_map_index); - } - return SUCCESS; -} - -Status VirtualDatasetInfo::InferTensorInfo() { - for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { - MS_LOG(INFO) << name_ << ": InferTensorInfo " << i << ", size " << strategy_->GetInputNumber(); - TensorLayout tensor_layout_in; - if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) { - return FAILED; - } - TensorInfo tensor_info_in(tensor_layout_in); - inputs_tensor_info_.push_back(tensor_info_in); - outputs_tensor_info_.push_back(tensor_info_in); - } - return SUCCESS; -} - -Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; } - -Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) { - if (InitWithManualRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - return SUCCESS; -} - -Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { - for (size_t i = 0; i < inputs_shape_.size(); i++) { - split_flag_list_[i] = true; - } -} - -Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - size_t total_dev_num; - - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GetAttrs failed"; - return FAILED; - } - - CheckGlobalDeviceManager(); - is_auto_parallel_ = true; - if (full_batch) { - total_dev_num = 1; - } else { - total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - } - StrategyPtr sp; - std::vector strategy; - for (auto &shape : inputs_shape_) { - Shape temp; - temp.emplace_back(SizeToInt(total_dev_num)); - (void)temp.insert(temp.end(), shape.size() - 1, 1); - strategy.push_back(temp); - } - sp = std::make_shared(stage_id, strategy); - - if (SetCostUnderStrategy(sp) == SUCCESS) { - if (full_batch) { - MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy."; - } else { - MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy."; - } - PrintStrategy(sp); - } else { - if (full_batch) { - MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -Status VirtualDatasetInfo::InferAsLossDivisor() { - // no need to insert div op - as_loss_divisor_ = 1; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h deleted file mode 100644 index 312ac7a6a4..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_OPS_INFO_DATASET_INFO_H_ -#define PARALLEL_OPS_INFO_DATASET_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class VirtualDatasetInfo : public OperatorInfo { - public: - VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~VirtualDatasetInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - void ReComputeBatchSplitFlagList() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status GetAttrs() override; - Status InferAsLossDivisor() override; -}; -} // namespace parallel -} // namespace mindspore - -#endif // PARALLEL_OPS_INFO_VIRTUAL_DATASET_INFO_H_ diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc deleted file mode 100644 index 894177df8d..0000000000 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ /dev/null @@ -1,1189 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/step_auto_parallel.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/param_value_py.h" -#include "ir/tensor.h" -#include "optimizer/opt.h" -#include "optimizer/optimizer.h" -#include "parallel/auto_parallel/dp_algo_costmodel.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/auto_parallel/rec_core/rec_generate_strategy.h" -#include "parallel/auto_parallel/rec_core/rec_parse_graph.h" -#include "parallel/auto_parallel/rec_core/rec_partition.h" -#include "parallel/context.h" -#include "parallel/ops_info/tmp_identity_info.h" -#include "parallel/ops_info/reshape_info.h" -#include "parallel/step_parallel.h" -#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/pipeline.h" - -namespace mindspore { -namespace parallel { -bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); - // assume no change to graph - bool changes = false; - // control whether use model_parallel mode - if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) || - root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) { - return changes; - } - // check whether strategy_search_mode is valid - std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode(); - if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) { - // Setting searching mode: dynanic programming as default. - strategy_search_mode = DYNAMIC_PROGRAMMING; - MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default"; - } - - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - - if (MsContext::GetInstance()->save_graphs_flag()) { - draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root); - } - MS_LOG(INFO) << "Now entering step auto parallel"; - TOTAL_OPS = 0; - AnfNodePtr ret = root->get_return(); - std::vector all_nodes = DeepScopedGraphSearch(ret); - - if (ParallelInit() != SUCCESS) { - MS_LOG(EXCEPTION) << "Parallel init failed"; - } - - // mark the forward cnodes, parallel only care these nodes - MarkForwardCNode(root); - - if (FindCommunicationOp(all_nodes)) { - MS_LOG(EXCEPTION) << "The graph contain communication op"; - } - - // search parallelization strategy - if (strategy_search_mode == DYNAMIC_PROGRAMMING) { - if (ParallelStrategySearch(all_nodes, root) != SUCCESS) { - MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode"; - } - } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) { - if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) { - MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode"; - } - } else { - MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected"; - } - - (void)gettimeofday(&end_time, nullptr); - uint64_t time = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us"; - - root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true); - return changes; -} - -// Given the node, return whether each input is a parameter or a output of a operator. -// The returned boolean vector should be the same order of the inputs, thus its implementation -// is closely consistent with ExtractShape() in step_parallel.cc -std::vector ExtractInputParameterByNode(const CNodePtr &node) { - std::vector is_parameter; - std::vector node_inputs{node->inputs()}; - for (size_t i = 1; i < node_inputs.size(); ++i) { - auto input = node_inputs[i]; - - if (input->isa()) { - auto input_parameter = input->cast(); - if (input_parameter->has_default()) { - auto param_value = std::dynamic_pointer_cast(input_parameter->default_param()); - bool require_grad = py::cast(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad")); - is_parameter.push_back(require_grad); - } else { - is_parameter.push_back(false); - } - } else if (input->isa() || IsValueNode(input) || IsValueNode(input)) { - is_parameter.push_back(false); - } - } - return is_parameter; -} - -// Given the type, return the number of bytes to represent this type -size_t GetLengthOfDataType(const TypePtr &type) { - switch (type->type_id()) { - case kNumberTypeBool: - return sizeof(bool); - case kNumberTypeInt8: - return sizeof(int8_t); - case kNumberTypeInt16: - return sizeof(int16_t); - case kNumberTypeInt32: - return sizeof(int32_t); - case kNumberTypeInt64: - return sizeof(int64_t); - case kNumberTypeUInt8: - return sizeof(uint8_t); - case kNumberTypeUInt16: - return sizeof(uint16_t); - case kNumberTypeUInt32: - return sizeof(uint32_t); - case kNumberTypeUInt64: - return sizeof(uint64_t); - case kNumberTypeFloat16: - return sizeof(float) / 2; - case kNumberTypeFloat32: - return sizeof(float); - case kNumberTypeFloat64: - return sizeof(double); - case kNumberTypeInt: - return sizeof(int); - case kNumberTypeUInt: - return sizeof(unsigned int); - case kNumberTypeFloat: - return sizeof(float); - default: - MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name(); - } -} - -size_t GetInputsTypeLen(const AnfNodePtr &input) { - MS_EXCEPTION_IF_NULL(input); - if (!input->isa() && !input->isa() && !IsValueNode(input)) { - MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor"; - } - - size_t input_type_len = 0; - auto type = input->Type(); - MS_EXCEPTION_IF_NULL(type); - if (type->isa()) { - auto input_element_type = type->cast()->element(); - input_type_len = GetLengthOfDataType(input_element_type); - } else { - MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); - } - return input_type_len; -} - -std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - std::vector inputs_type_len; - std::vector node_inputs{node->inputs()}; - - // extract input element length - for (auto &input : node_inputs) { - if (IsValueNode(input)) { - auto func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector parameters = FindParameterByRefKeyNode(input, func_graph); - if (parameters.size() != 1) { - MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; - } - inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); - } else if (input->isa() || input->isa() || IsValueNode(input)) { - // extract input shape from parameter and apply node - inputs_type_len.push_back(GetInputsTypeLen(input)); - } - } - return inputs_type_len; -} - -std::vector ExtractOutputTypeByNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - std::vector outputs_type; - // extract output element type - auto primary_output_type = node->Type(); - MS_EXCEPTION_IF_NULL(primary_output_type); - if (primary_output_type->isa()) { - // in this case, the output is a tuple - auto tuple_output_type = primary_output_type->cast(); - auto elements = tuple_output_type->elements(); - for (auto &ele : elements) { - if (ele->isa()) { - auto ele_element_type = ele->cast()->element(); - outputs_type.push_back(ele_element_type); - } else { - MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); - } - } - } else { - // in this case, the output is a single tensor - if (primary_output_type->isa()) { - auto element_type = primary_output_type->cast()->element(); - outputs_type.push_back(element_type); - } else { - MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); - } - } - return outputs_type; -} - -bool IsElementWiseOperator(const std::string &op_name) { - static const std::set elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, - SQRT, CAST, POW, EXP, LOG, COS, - ACOS, LOGICALNOT, NEG, SQUARE, SIGMOID}; - auto iter = elementwise_op.find(op_name); - return (iter != elementwise_op.end()); -} - -bool IsSplittableOperator(const std::string &op_name) { - // clang-format off - static const std::set splittable_op = - {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, - FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, - REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, - MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, - LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, - STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, - SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS}; - // clang-format on - - auto iter = splittable_op.find(op_name); - return (iter != splittable_op.end()); -} - -bool IsAutoParallelCareNode(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - ValueNodePtr prim_node = cnode->input(0)->cast(); - if (prim_node == nullptr) { - return false; - } - PrimitivePtr prim = GetValueNode(prim_node); - if (prim == nullptr) { - return false; - } - bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name()); - if (bool_result) { - MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name(); - } else if (prim->name() == CAST) { - if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) { - // Do not care CASTs from optimizer - return false; - } - return true; - } - return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name()); -} - -OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { - MS_EXCEPTION_IF_NULL(prim); - MS_EXCEPTION_IF_NULL(cnode); - auto attrs = prim->attrs(); - std::vector shape_list = ExtractShape(cnode); - if (shape_list.empty()) { - MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape"; - } - // Create an OperatorInfo instance - OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list); - MS_EXCEPTION_IF_NULL(operator_info); - // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not) - std::vector parameter_info = ExtractInputParameterByNode(cnode); - if (operator_info->set_is_parameter(parameter_info) != SUCCESS) { - MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name(); - return nullptr; - } - // Set the data type for inputs and outputs of this OperatorInfo - auto inputs_type_length = ExtractInputTypeLengthByNode(cnode); - auto outputs_type = ExtractOutputTypeByNode(cnode); - std::vector outputs_type_length; - outputs_type_length.reserve(outputs_type.size()); - std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length), - GetLengthOfDataType); - if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) { - MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name(); - return nullptr; - } - if (operator_info->set_outputs_type(outputs_type) != SUCCESS) { - MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name(); - return nullptr; - } - // When the 'inputs' contains numerical values for some operators, these values should be extracted from - // ANF graph - auto &inputs = cnode->inputs(); - std::vector input_value; - for (size_t index = 1; index < inputs.size(); ++index) { - if (inputs[index]->isa()) { - input_value.push_back(GetValueNode(inputs[index])); - } else { - input_value.emplace_back(nullptr); - } - } - operator_info->set_input_value(input_value); - operator_info->set_outputs_dtype(cnode->Type()); - operator_info->set_cnode(cnode); - // key of strategy map - std::string strategy_key_name = NodeParameterName(cnode); - bool load_strategy_from_ckpt = - StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); - // If no strategy has been configured for this operator, then candidate strategies are generated for - // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy. - // if strategy is set to load from checkpoint, it is prefer to load strategy from checkpoint . - if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) { - // Compute split_flag_list_, indicating which input has batch dimension. This is ONLY used for preparation for - // BatchParallelInfo operator - operator_info->ComputeBatchSplitFlagList(); - if (operator_info->GenerateStrategies(0) != SUCCESS) { - MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed."; - return nullptr; - } - } else { - // In this case, the configured strategy should be extracted to help setting cost - StrategyPtr strategyPtr; - if (load_strategy_from_ckpt) { - strategyPtr = (*stra_map)[strategy_key_name]; - } else { - strategyPtr = parallel::ExtractStrategy(attrs); - } - if (strategyPtr != nullptr) { - if (prim->name() == RESHAPE) { - MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; - } - // Set cost for this configured strategy - if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { - MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; - } else if (FULLY_USE_DEVICES) { - // If configured to fully use devices, then checking for the user-specified strategy - int32_t used_devices = operator_info->used_devices(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); - // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel - if (used_devices == 1) { - return operator_info; - } - // 'used_devices == -1' means that 'used_devices_' is not set - if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) { - MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, " - << "but the specified strategy uses device: " << used_devices - << ", total devices: " << total_device_num; - } - } - } - } - return operator_info; -} - -// Using CNode's UniqueIds to construct nodes -Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &) { - MS_LOG(INFO) << "Constructing nodes for cost graph begins."; - entire_costgraph = std::make_shared(); - entire_costgraph->SetDeviceMemoryAndCostParameter(); - // The map from CNode's UniqueId to its operatorInfo - std::map from_cnode_to_info; - // extract strategy from checkpoint for multi-train - StrategyMap stra_map; - if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { - if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { - MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; - } - } - // Step 1 - for (auto &node : all_nodes) { - // NOTE: we only care about splittable Primitive operators - auto cnode = node->cast(); - bool bool_result = (cnode == nullptr) || (!IsValueNode(cnode->input(0))); - if (bool_result) { - continue; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsAutoParallelCareNode(cnode)) { - // Needed by rec_parser - if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) { - auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node); - if (prev_cnode != nullptr) { - entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId())); - } - } - continue; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - MS_EXCEPTION_IF_NULL(prim); - - auto search_cnode = from_cnode_to_info.find(cnode->UniqueId()); - if (search_cnode == from_cnode_to_info.end()) { - auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); - if (operator_info == nullptr) { - return FAILED; - } - // Needed by rec_parser - operator_info->set_type(prim->name()); - std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); - - entire_costgraph->AddOperator(operator_info); - (void)cnode->set_operator_info(operator_info); - MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); - (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); - // Needed by rec_parser - entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); - } else { - // Two CNODEs' UniqueIds should not be equal - MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name(); - } - } - - MS_LOG(INFO) << "Constructing nodes for cost graph ends."; - return SUCCESS; -} - -// Using CNode's UniqueIdThroughCopys to construct nodes -Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &) { - MS_LOG(INFO) << "Constructing nodes for cost graph begins."; - entire_costgraph = std::make_shared(); - entire_costgraph->SetDeviceMemoryAndCostParameter(); - // The map from CNode's UniqueIdThroughCopy to its operatorInfo - std::map from_cnode_to_info; - // extract strategy from checkpoint for multi-train - StrategyMap stra_map; - if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { - if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { - MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; - } - } - for (auto &node : all_nodes) { - // NOTE: we only care about splittable Primitive operators - auto cnode = node->cast(); - bool bool_result = (cnode == nullptr) || (!IsValueNode(cnode->input(0))); - if (bool_result) { - continue; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsAutoParallelCareNode(cnode)) { - // Needed by rec_parser - if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) { - auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node); - if (prev_cnode != nullptr) { - entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId())); - } - } - continue; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - - // Find the operatorInfo if it exists - auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy()); - if (search_cnode == from_cnode_to_info.end()) { - // In this case, the corresponding OperatorInfo is not created, create the new one. - auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); - if (operator_info == nullptr) { - return FAILED; - } - // Needed by rec_parser - operator_info->set_type(prim->name()); - std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); - - entire_costgraph->AddOperator(operator_info); - (void)cnode->set_operator_info(operator_info); - MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); - (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); - // Needed by rec_parser - entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); - } else { - auto current_op_ptr = search_cnode->second; - if (current_op_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; - } else { - bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && - (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && - (current_op_ptr->name().find(prim->name()) == std::string::npos); - if (is_find_wrong) { - MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() - << " does not match the Prim: " << prim->name(); - } - (void)cnode->set_operator_info(current_op_ptr); - MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); - } - } - } - - MS_LOG(INFO) << "Constructing nodes for cost graph ends."; - return SUCCESS; -} - -void ConstructCostGraphEdges(const std::vector &all_nodes) { - // Step 2 - MS_LOG(INFO) << "Constructing edges for cost graph begins."; - for (auto &node : all_nodes) { - auto cnode = node->cast(); - bool bool_result_cnode = (cnode == nullptr) || !IsValueNode(cnode->input(0)); - if (bool_result_cnode) { - continue; - } - auto &inputs = cnode->inputs(); - ValueNodePtr prim_anf_node = inputs[0]->cast(); - if (!IsAutoParallelCareNode(cnode)) { - continue; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - size_t edge_count = 0; - - for (size_t i = 1; i < inputs.size(); ++i) { - auto prev_cnode = inputs[i]->cast(); - bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); - if (bool_result_prev_cnode) { - continue; - } - ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast(); - PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast(); - size_t output_index = 0; - - bool bool_result = - (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); - while (bool_result) { - if (IsAutoParallelCareNode(prev_cnode)) { - std::string edge_name = - prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name(); - // If the edge between these two operators already has been added, then the edge will not be added again. - if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { - break; - } - EdgePtr edge_ptr; - MS_LOG(INFO) << "Creating edge: " << edge_name; - - bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || - (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); - if (follow_strategy) { - // Redistribution in not allowed on the edge. - // Elementwise operators have the same strategy as their previous operators. - edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), - output_index, i - 1, false, true); - } else { - edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), - output_index, i - 1, false); - } - - // Init costs for this edge - if (edge_ptr->InitEdgeCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Edge cost initialization failed"; - } - cnode->operator_info()->AddPrevEdge(edge_ptr); - prev_cnode->operator_info()->AddSuccEdge(edge_ptr); - entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr); - MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and " - << cnode->operator_info()->name(); - edge_count++; - - break; - } else if (prev_prim->name() == TUPLE_GETITEM) { - // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before - // this 'tuple_getitem' - MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator."; - output_index = IntToSize(GetValue(GetValueNode(prev_cnode->input(2)))); - prev_cnode = prev_cnode->input(1)->cast(); - bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); - if (bool_result_tuple) { - break; - } - prev_prim_anf_node = prev_cnode->input(0)->cast(); - prev_prim = prev_prim_anf_node->value()->cast(); - if (!IsAutoParallelCareNode(prev_cnode)) { - MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name(); - } - MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, " - << "and creating an edge between the Operator before " - << "'tuple_getitem' and the Operator after 'tuple_getitem'."; - } else if (prev_prim->name() == DEPEND) { - // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before - // this 'depend' - MS_LOG(INFO) << "Jumping the 'depend' operator."; - prev_cnode = prev_cnode->input(1)->cast(); - bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); - if (bool_result_depend) { - break; - } - prev_prim_anf_node = prev_cnode->input(0)->cast(); - prev_prim = prev_prim_anf_node->value()->cast(); - MS_LOG(INFO) << "Jumped the 'depend' operator, " - << "and creating an edge between the Operator before " - << "'depend' and the Operator after 'depend'."; - } - bool_result = - (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); - } - } - MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); - } - - MS_LOG(INFO) << "Constructing edges for cost graph ends."; -} - -std::pair> CNodeWithRefKeys(const AnfNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector refkeys; - if (cnode->isa()) { - auto cnode_ptr = cnode->cast(); - auto inputs = cnode_ptr->inputs(); - for (auto &one_input : inputs) { - if (IsValueNode(one_input)) { - refkeys.push_back(one_input); - } - } - if (refkeys.size() >= 1) { - return std::make_pair(cnode, refkeys); - } - } - return {nullptr, refkeys}; -} - -void AugmentCostGraph(const std::vector &all_nodes) { - // Step 3 - for (auto &node : all_nodes) { - auto cnode_with_refkeys = CNodeWithRefKeys(node); - if ((!node->isa()) && (cnode_with_refkeys.first == nullptr)) { - continue; - } - std::string parameter_name; - AnfNodePtr target_parameter = nullptr; - AnfNodeIndexSet target_set; - - if (cnode_with_refkeys.first != nullptr) { - // Dealing with the RefKey case - auto refkeys = cnode_with_refkeys.second; - auto cnode = cnode_with_refkeys.first; - - auto cnode_ptr = cnode->cast(); - if (cnode_ptr == nullptr || !IsValueNode(cnode_ptr->input(0))) { - continue; - } - if (!IsAutoParallelCareNode(cnode_ptr)) { - continue; - } - - if (refkeys.size() > 1) { - MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys."; - } - MS_EXCEPTION_IF_NULL(cnode->func_graph()); - auto cnode_func_graph = cnode->func_graph(); - MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager()); - - // Find the RefKey being used - auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]]; - for (auto &candidate : candidate_set_by_refkey) { - auto candidate_node = candidate.first; - auto c = candidate_node->cast(); - if (c == nullptr || !IsValueNode(c->input(0))) { - continue; - } - if (!IsAutoParallelCareNode(c)) { - continue; - } - target_set.add(candidate); - } - - // Find the corresponding Parameter being used - std::vector parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph); - if (parameters.size() != 1) { - MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; - } - parameter_name = parameters[0]->cast()->name(); - target_parameter = parameters[0]; - auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]]; - for (auto &candidate : candidate_set_by_para) { - auto candidate_node = candidate.first; - auto c = candidate_node->cast(); - if (c == nullptr || !IsValueNode(c->input(0))) { - continue; - } - if (!IsAutoParallelCareNode(c)) { - continue; - } - (void)target_set.insert(candidate); - } - } else if (node->isa()) { - // Dealing with the Parameter case - MS_EXCEPTION_IF_NULL(node->func_graph()); - MS_EXCEPTION_IF_NULL(node->func_graph()->manager()); - auto candidate_set = node->func_graph()->manager()->node_users()[node]; - for (auto &candidate : candidate_set) { - auto candidate_node = candidate.first; - auto c = candidate_node->cast(); - if (c == nullptr || !IsValueNode(c->input(0))) { - continue; - } - if (!IsAutoParallelCareNode(c)) { - continue; - } - (void)target_set.insert(candidate); - } - // In this case, node is a Parameter - parameter_name = node->cast()->name(); - target_parameter = node; - } - if (target_set.size() <= 1) { - continue; - } - - // Rule out the case when a Parameter being used by a Operator, but the Operator appears in multiple CNODEs - std::set target_without_duplicate; - for (auto &target : target_set) { - auto target_cnode = target.first->cast(); - auto input_index = target.second; - (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name()); - } - if (target_without_duplicate.size() <= 1) { - continue; - } - - // Here, it is sure that this Parameter (RefKey) is being used by multiple Operators. - OperatorInfoPtr tmp_identity_ptr; - bool new_identity = false; - std::string tmp_identity_name; - auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name); - if (returned_identity != nullptr) { - // In this case, the TmpIdentityInfo instance has already been created - new_identity = false; - tmp_identity_ptr = returned_identity; - tmp_identity_name = tmp_identity_ptr->name(); - } else { - // In the case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created. - new_identity = true; - // 1) extract input shape from this Parameter - MS_EXCEPTION_IF_NULL(target_parameter); - AbstractBasePtr abstract = target_parameter->abstract(); - if (abstract == nullptr) { - MS_LOG(EXCEPTION) << "Failure: abstract is nullptr"; - } - auto input_shape = dyn_cast(abstract->GetShapeTrack()); - if (input_shape == nullptr) { - MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr"; - } - std::vector shape_int = input_shape->shape(); - Shape shape; - (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape), - [](int sub_shape) { return static_cast(sub_shape); }); - Shapes inputs_shape = {shape}; - Shapes outputs_shape = {shape}; - // 2) init the attr - std::unordered_map attr = {}; - - // Create the TmpIdentity instance - tmp_identity_ptr = std::make_shared(inputs_shape, outputs_shape, attr); - tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS)); - TOTAL_OPS++; - tmp_identity_ptr->set_refkey_parameter_name(parameter_name); - // Set the parameter and type lengths for inputs and outputs - std::vector is_parameter; - auto casted_target_parameter = target_parameter->cast(); - MS_EXCEPTION_IF_NULL(casted_target_parameter); - if (casted_target_parameter->has_default()) { - auto param_value = std::dynamic_pointer_cast(casted_target_parameter->default_param()); - bool require_grad = py::cast(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad")); - is_parameter.push_back(require_grad); - } else { - is_parameter.push_back(false); - } - if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) { - MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed"; - } - auto node_type = target_parameter->Type(); - if (node_type->isa()) { - auto input_element_type = node_type->cast()->element(); - std::vector type_length = {GetLengthOfDataType(input_element_type)}; - if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) { - MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed"; - } - } else { - MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name(); - } - - // Generate strategies for this TmpIdentityInfo instance; - if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) { - MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name(); - } - } - // A flag recording whether new edges have been created or not - bool add_identity_edge = false; - - // Create edges between this TmpIdentityInfo instance and subsequent Operator instances - for (auto &target : target_set) { - auto target_cnode = target.first->cast(); - auto prim = GetValueNode(target_cnode->input(0)); - auto input_index = target.second; - - std::string edge_name = - std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name(); - // If the edge between these two operators already has been added, then the edge will not be added again. - if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) { - continue; - } - std::shared_ptr edge_ptr = std::make_shared( - edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true); - - if (edge_ptr->InitEdgeCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Edge cost initialization failed"; - } - target_cnode->operator_info()->AddPrevEdge(edge_ptr); - tmp_identity_ptr->AddSuccEdge(edge_ptr); - entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr); - MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and " - << target_cnode->operator_info()->name(); - add_identity_edge = true; - } - if (new_identity && add_identity_edge) { - // Add the TmpIdentityInfo to CostGraph if BOTH two conditions are satisfied - entire_costgraph->AddOperator(tmp_identity_ptr); - } - } -} - -bool FindReshape(const CNodePtr &cnode) { - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - return false; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { - return false; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); - if (operator_info == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; - } - if (prim->name() != RESHAPE) { - return false; - } - return true; -} - -// find previous node, then obtain its strategy_cost_ vector to get its layout vector. -bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) { - // if previous node is a parameter, handle it in the outsize. - if (node->isa()) { - return false; - } - if (!node->isa()) { - return false; - } - CNodePtr cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - return false; - } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - *pre_operator_info = cnode->operator_info(); - *out_index = 0; - return true; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - PrimitivePtr prim = prim_anf_node->value()->cast(); - if (prim->name() == TUPLE_GETITEM) { - *out_index = GetTupleGetItemIndex(cnode); - // find tuple_get_item's previous node - auto pre_node = cnode->input(1); - if (!pre_node->isa()) { - MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; - } - CNodePtr pre_cnode = pre_node->cast(); - if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) { - *pre_operator_info = pre_cnode->operator_info(); - return true; - } - return false; - } - for (size_t index = 0; index < cnode->inputs().size(); ++index) { - if (prim->name() == DEPEND && index != 1) { - continue; - } - if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) { - continue; - } - return true; - } - MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error"; - return false; -} - -// find next node, then obtain its strategy_cost_ vector to get its layout vector. -// if reshape's output connect to several primitive, return the first layout found -bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(cnode->func_graph()); - FuncGraphManagerPtr manager = cnode->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[cnode]; - for (auto &node_pair : node_set) { - CNodePtr use_apply = node_pair.first->cast(); - if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr node_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { - MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); - *next_operator_info = use_apply->operator_info(); - *in_index = node_pair.second - 1; - return true; - } - MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) - << " " << (use_apply->operator_info() != nullptr); - - if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { - return true; - } - } - return false; -} - -void ReshapeCostCompute(const std::vector &all_nodes) { - for (auto node : all_nodes) { - auto cnode = node->cast(); - if (!FindReshape(cnode)) { - continue; - } - MS_ASSERT(cnode->inputs().size() == 3); - // get previous node's strategy_cost_ - auto pre_node = cnode->input(1); - int32_t out_index = 0; - OperatorInfoPtr pre_operator_info; - std::vector> pre_stra_costs; - if (pre_node->isa()) { - OperatorInfoPtr operator_info = cnode->operator_info(); - auto reshape_info = std::dynamic_pointer_cast(operator_info); - reshape_info->SetCostForReshapeWithParameter(); - pre_operator_info = reshape_info; - pre_stra_costs = reshape_info->strategy_cost(); - } else { - if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) { - MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed"; - } - pre_stra_costs = pre_operator_info->strategy_cost(); - } - // get next node's strategy_cost_ - int32_t in_index = 0; - OperatorInfoPtr next_operator_info; - std::vector> next_stra_costs; - bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index); - if (!find_next_node) { - MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed"; - } - // set input_layout and output_layout for reshape. - // init reshape and set cost for each input_layout and output_layout. - OperatorInfoPtr operator_info = cnode->operator_info(); - auto reshape_info = std::dynamic_pointer_cast(operator_info); - reshape_info->set_pre_operator_name(pre_operator_info->name()); - reshape_info->set_pre_operator_index(out_index); - if (find_next_node) { - next_stra_costs = next_operator_info->strategy_cost(); - reshape_info->set_next_operator_name(next_operator_info->name()); - reshape_info->set_next_operator_index(in_index); - } - bool is_prev_param = pre_node->isa(); - if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) != - SUCCESS) { - MS_LOG(EXCEPTION) << "reshape genetate strategy_costs failed!"; - } - } -} - -Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root) { - // There are 4 meta-steps to determine the parallelization strategy for the ANF graph. - // Step 1: Traverse the ANF graph, and create NODEs for costgraph: - // create the OperatorInfo object for each primitive, and enumerate the parallelization strategies - // for each OperatorInfo; - // Step 1.1: Deal with 'Reshape': - // For 'Reshape', it takes its previous operator's layout as its input layout, and takes its next operator's - // layout as its output layout. - // Step 2: Traverse the ANF graph, and create EDGES for costgraph: - // create the Edge object for each pair of OperatorInfo, and enumerate the parallelization strategies - // for each edge, based on the strategies of two OperatorInfos; - // Step 3: Augment the costgraph: - // taking care for the case of a single Parameter being used by multiple operators. Create a TmpIdentity - // operator for this Parameter, and add an edge for the use of this Parameter by each - // subsequent operator; - // Step 3.1: Calculate memory usage: - // note the memory usage calculation is different in training phase and inference phase. - // Step 4: Run the Dynamic Programming algorithm: - // in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge - // cost is caused by the redistribution of a operator's output tensor layout to the next operator's input - // tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm - // runs on each of them. - // - // OUTPUT: the determined strategy for each operator. - - // Step 1 - if (CostModelContext::GetInstance()->is_multi_subgraphs()) { - if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { - MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " - << entire_costgraph->GetOperators().size() << " operators."; - } else { - MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; - } - } else { - if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { - MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " - << entire_costgraph->GetOperators().size() << " operators."; - } else { - MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; - } - } - // Step 1.1 - ReshapeCostCompute(all_nodes); - // Step 2 - ConstructCostGraphEdges(all_nodes); - MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size() - << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; - - // Step 3: Augment the costgraph. - AugmentCostGraph(all_nodes); - MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() - << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; - - // Step 3.1: Calculate the memory usage - if (entire_costgraph->CalculateMemoryCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Calculating memory cost failed."; - } - - // Step 4: run DP algorithm on the costgraph. - if (GetStrategy(entire_costgraph) != SUCCESS) { - MS_LOG(ERROR) << "Strategy search for cost-graph fails"; - return FAILED; - } - MS_LOG(INFO) << "Searching strategy succeeded."; - - if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { - MS_LOG(INFO) << "Init selected strategy succeeded."; - } else { - MS_LOG(EXCEPTION) << "Init selected strategy failed."; - } - - // print the selected strategy - for (auto &op : entire_costgraph->GetOperators()) { - StrategyPtr s_strategy = op->selected_strategy(); - MS_LOG(INFO) << op->name() << " : The strategy is:"; - PrintStrategy(s_strategy); - } - - return SUCCESS; -} - -std::vector> RecInputTensorNames(const std::map::iterator &it, - std::vector> input_tensor_names) { - for (size_t j = 0; j < input_tensor_names.size(); j++) { - for (size_t k = 0; k < input_tensor_names[j].size(); k++) { - if (it->first == input_tensor_names[j][k]) { - input_tensor_names[j][k] = it->second; - break; - } - } - } - return input_tensor_names; -} - -CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) { - PrimitivePtr prim = GetValueNode(prim_anf_node); - if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) { - auto prev_cnode = cnode->input(1)->cast(); - if (prev_cnode == nullptr || !IsValueNode(prev_cnode->input(0))) { - return nullptr; - } - auto prev_prim = prev_cnode->input(0)->cast()->value()->cast(); - while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) { - prev_cnode = prev_cnode->input(1)->cast(); - if (prev_cnode == nullptr || !IsValueNode(prev_cnode->input(0))) { - return nullptr; - } - prev_prim = prev_cnode->input(0)->cast()->value()->cast(); - } - return prev_cnode; - } - return nullptr; -} - -Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root) { - if (CostModelContext::GetInstance()->is_multi_subgraphs()) { - if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { - MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " - << entire_costgraph->GetOperators().size() << " operators."; - } else { - MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; - } - } else { - if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { - MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " - << entire_costgraph->GetOperators().size() << " operators."; - } else { - MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; - } - } - ReshapeCostCompute(all_nodes); - - auto ops = entire_costgraph->GetOperators(); - std::vector> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list(); - auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list(); - for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) { - input_tensor_names = RecInputTensorNames(it++, input_tensor_names); - } - std::shared_ptr graph = ParseGraph(ops, input_tensor_names); - - std::shared_ptr>> eli_list(new std::vector>); - std::shared_ptr> index_list(new std::vector); - graph = EliminateGraph(graph, eli_list, index_list); - - size_t num_device = g_device_manager->DeviceNum(); - double device_memory = entire_costgraph->GetDeviceMemory(); - if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) { - MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; - } else { - MS_LOG(ERROR) << "PartitionForAllDevices failed."; - return FAILED; - } - - GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list); - - if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { - MS_LOG(INFO) << "Init selected strategy succeeded."; - } else { - MS_LOG(ERROR) << "Init selected strategy failed."; - return FAILED; - } - - // print the selected strategy - for (auto &op : entire_costgraph->GetOperators()) { - StrategyPtr s_strategy = op->selected_strategy(); - MS_LOG(INFO) << op->name() << " : The strategy is:"; - PrintStrategy(s_strategy); - } - - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.h b/mindspore/ccsrc/parallel/step_auto_parallel.h deleted file mode 100644 index c923e5770f..0000000000 --- a/mindspore/ccsrc/parallel/step_auto_parallel.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_STEP_AUTO_PARALLEL_H_ -#define PARALLEL_STEP_AUTO_PARALLEL_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "optimizer/opt.h" -#include "parallel/status.h" -#include "pipeline/pipeline.h" - -namespace mindspore { -namespace parallel { -bool IsSplittableOperator(const std::string &); - -bool IsAutoParallelCareNode(const CNodePtr &); - -// main step of Auto-parallel -bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); - -size_t GetLengthOfDataType(const TypePtr &type); - -std::vector ExtractInputParameterByNode(const CNodePtr &node); - -std::vector ExtractInputTypeLengthByNode(const CNodePtr &node); - -std::vector ExtractOutputTypeByNode(const CNodePtr &node); - -Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &root); - -Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &root); - -void ConstructCostGraphEdges(const std::vector &all_nodes); - -void AugmentCostGraph(const std::vector &all_nodes); - -Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root); - -Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root); - -std::vector> RecInputTensorNames(const std::map::iterator &it, - std::vector> input_tensor_names); - -CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node); -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_STEP_AUTO_PARALLEL_H_ diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc deleted file mode 100644 index 7d1200b190..0000000000 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ /dev/null @@ -1,2368 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/step_parallel.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "ir/tensor.h" -#include "ir/param_value_py.h" -#include "operator/ops.h" -#include "optimizer/optimizer.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/context.h" -#include "parallel/device_manager.h" -#include "parallel/dynamic_creator.h" -#include "parallel/graph_util/generate_graph.h" -#include "parallel/graph_util/graph_info.h" -#include "parallel/graph_util/node_info.h" -#include "parallel/node_check.h" -#include "parallel/ops_info/matmul_info.h" -#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" -#include "utils/comm_manager.h" -#include "utils/symbolic.h" -#include "pipeline/static_analysis/prim.h" - -using mindspore::tensor::Tensor; - -namespace mindspore { -namespace parallel { -static const std::set COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; -static const std::set INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; -// g_RefMap, for CNode B input i is a RefKey[Parameter C], -// it will be one item in map with key: C, and value: (B, i) -static std::map> g_RefMap; - -void SetCommunicationOpGroupLabel(std::vector new_node_input) { - if (new_node_input.empty()) { - return; - } - - ValueNodePtr prim_anf_node = new_node_input[0]->cast(); - PrimitivePtr prim = GetValueNode(prim_anf_node); - MS_EXCEPTION_IF_NULL(prim); - - auto attrs = prim->attrs(); - auto iter = attrs.find(GROUP); - if (iter != attrs.end()) { - auto value = iter->second; - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - std::string hash_name = value->cast()->value(); - MS_EXCEPTION_IF_NULL(g_device_manager); - std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name); - (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name)); - } - } -} - -std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { - MS_EXCEPTION_IF_NULL(node); - OperatorArgs arg_forward = op.second; - ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name); - MS_EXCEPTION_IF_NULL(pyop_instance); - OperatorParams params = arg_forward.second; - - std::vector new_node_input = {NewValueNode(pyop_instance), node}; - if (!params.empty()) { - for (auto ¶m : params) { - AnfNodePtr val = NewValueNode(param.first.second); - MS_EXCEPTION_IF_NULL(val); - int32_t position = param.second; - (void)new_node_input.insert(new_node_input.begin() + position, val); - } - } - - // if the op have 'group' attr, set the rank list name for the op - SetCommunicationOpGroupLabel(new_node_input); - return new_node_input; -} - -void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node, - const FuncGraphPtr &func_graph, const std::string &instance_name) { - // insert new node before the node - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - ScopePtr scope = node->scope(); - MS_EXCEPTION_IF_NULL(scope); - std::vector node_input = CreateInput(op, pre_node, instance_name); - CNodePtr new_node = func_graph->NewCNode(node_input); - MS_EXCEPTION_IF_NULL(new_node); - if (instance_name.find(SPLIT_SENS) == std::string::npos) { - new_node->set_in_forward_flag(true); // mark forward flag - } - auto new_node_value = node_input[0]->cast(); - MS_EXCEPTION_IF_NULL(new_node_value); - PrimitivePtr new_node_prim = new_node_value->value()->cast(); - new_node_prim->set_instance_name(instance_name); - new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); - new_node->set_scope(scope); - node_input[0]->set_scope(scope); - manager->SetEdge(node, SizeToInt(index), new_node); -} - -std::string CreateInstanceName(const CNodePtr &node, size_t index) { - MS_EXCEPTION_IF_NULL(node); - if (!IsValueNode(node->input(0))) { - MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive"; - } - std::string name_base = node->fullname_with_scope(); - std::string name = name_base + "_" + std::to_string(index); - std::string instance_name = HashInstanceName(name); - return instance_name; -} - -void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - // step1:get graph manager distribute_operator - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto uses_set = manager->node_users()[node]; - CNodePtr node_to_insert = node; - for (auto &uses_pair : uses_set) { - auto uses_cnode = uses_pair.first->cast(); - MS_EXCEPTION_IF_NULL(uses_cnode); - if (!IsValueNode(uses_cnode->input(0))) { - break; - } - PrimitivePtr value_node_prim = GetValueNode(uses_cnode->input(0)); - MS_EXCEPTION_IF_NULL(value_node_prim); - if (value_node_prim->name() == TUPLE_GETITEM) { - if (uses_set.size() > 1) { - MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size(); - } - node_to_insert = uses_cnode; - } - } - MS_EXCEPTION_IF_NULL(node_to_insert); - std::reverse(forward_op.begin(), forward_op.end()); - - // step2:traverse op_list and insert node - for (size_t index = 0; index < forward_op.size(); ++index) { - std::string instance_name_base = FORWARD_OP; - std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index); - std::vector forward_input = CreateInput(forward_op[index], node_to_insert, instance_name); - CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to creat anfnode - MS_EXCEPTION_IF_NULL(forward_node); - ScopePtr scope = node->scope(); - MS_EXCEPTION_IF_NULL(scope); - forward_node->set_scope(scope); - forward_node->set_in_forward_flag(true); - forward_input[0]->set_scope(scope); - (void)manager->Replace(node_to_insert, forward_node); // using Replace function to insert node - } -} - -CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint32_t num, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(prev); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (uint32_t i = 0; i < num; i++) { - std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), prev, - CreatInt32Imm(UintToInt(i))}; - auto tuple_get_item = func_graph->NewCNode(tuple_get_item_inputs); - MS_EXCEPTION_IF_NULL(tuple_get_item); - make_tuple_inputs.push_back(tuple_get_item); - } - auto make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(prev, make_tuple); - return make_tuple; -} - -void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, - const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(pre_node); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if ((redistribution_oplist_ptr->first).size() != (redistribution_oplist_ptr->second).size()) { - MS_LOG(EXCEPTION) << "size of OperatorVector and OutPutInfoVector must be the same!"; - } - for (size_t index = 0; index < (redistribution_oplist_ptr->first).size(); ++index) { - if (pos >= SizeToInt(node->inputs().size())) { - MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size"; - } - // Creat new node - AnfNodePtr target_node = node->input(IntToSize(pos)); - MS_EXCEPTION_IF_NULL(target_node); - // Creat instance_name - auto op = (redistribution_oplist_ptr->first)[index]; - std::string op_name = (redistribution_oplist_ptr->first)[index].first; - std::string instance_name_base = REDISTRIBUTION_OP; - std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name; - InsertNode(op, node, IntToSize(pos), target_node, func_graph, instance_name); - if ((redistribution_oplist_ptr->second)[index].first) { - target_node = node->input(IntToSize(pos)); - MS_EXCEPTION_IF_NULL(target_node); - (void)InsertMakeTuple(target_node, (redistribution_oplist_ptr->second)[index].second, func_graph); - } - } -} - -void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int pos, - const std::string &instance_name) { - if (func_graph == nullptr) { - MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name; - } - - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (pos >= SizeToInt(node->inputs().size())) { - MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is " - << instance_name; - } - // Creat new node - AnfNodePtr pre_node = node->input(IntToSize(pos)); - MS_EXCEPTION_IF_NULL(pre_node); - InsertNode(op, node, IntToSize(pos), pre_node, func_graph, instance_name); -} - -TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim, - const OperatorInfoPtr &distribute_operator) { - TensorInfo tensorinfo_in; - if (middle_prim->name() == TUPLE_GETITEM) { - auto value_node = middle_node->input(2)->cast(); - MS_EXCEPTION_IF_NULL(value_node); - size_t index_s = IntToSize(GetValue(value_node->value())); - if (index_s >= distribute_operator->outputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index out of range, index: " << index_s - << ", vector size: " << distribute_operator->outputs_tensor_info().size(); - } - tensorinfo_in = distribute_operator->outputs_tensor_info()[index_s]; - } else { - if (distribute_operator->outputs_tensor_info().empty()) { - MS_LOG(EXCEPTION) << "The outputs tensor info is empty"; - } - tensorinfo_in = distribute_operator->outputs_tensor_info()[0]; - } - return tensorinfo_in.tensor_layout(); -} - -OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!IsParallelCareNode(node)) { - return nullptr; - } - OperatorInfoPtr distribute_operator = node->operator_info(); - if (distribute_operator == nullptr) { - MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; - } - return distribute_operator; -} - -void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, - const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr &pre_node) { - FuncGraphPtr func_graph = middle_node->func_graph(); - if (func_graph == nullptr) { - MS_LOG(EXCEPTION) << "Redistribution:get graph failed"; - } - CNodePtr next_node = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(next_node); - auto middle_value = middle_node->input(0)->cast(); - MS_EXCEPTION_IF_NULL(middle_value); - PrimitivePtr middle_prim = middle_value->value()->cast(); - MS_EXCEPTION_IF_NULL(middle_prim); - OperatorInfoPtr next_distribute_operator = GetDistributeOperator(next_node); - if (next_distribute_operator == nullptr) { - MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed"; - } - RankList dev_list = distribute_operator->global_device_list(); - std::string next_prim_name = GetValueNode(next_node->input(0))->name(); - MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name; - MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString(); - // extract tensor layout in and out - if (distribute_operator->outputs_tensor_info().empty()) { - MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty"; - } - - if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is " - << next_distribute_operator->inputs_tensor_info().size(); - } - TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)]; - TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); - TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator); - if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) { - MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name; - MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " - << next_node->ToString(); - DumpGraph(func_graph, "redistribution_error"); - MS_LOG(EXCEPTION) << "Failure:tensor_redistribution init failed"; - } - RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); - if (redistribution_oplist_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Failure:InferTensorRedistribution failed"; - } - MS_LOG(DEBUG) << "Redistribution size " << redistribution_oplist_ptr->first.size(); - if (!redistribution_oplist_ptr->first.empty()) { - // insert node before next node - InsertRedistribution(redistribution_oplist_ptr, next_node, func_graph, node_pair.second, pre_node); - } -} - -bool StrategyFound(std::unordered_map attrs) { - auto iter = attrs.find(STRATEGY); - return !((iter == attrs.end()) || (iter->second->type_name() == NONE)); -} - -bool HasStrategy(const FuncGraphPtr &root) { - AnfNodePtr ret = root->get_return(); - MS_EXCEPTION_IF_NULL(ret); - std::vector all_nodes = DeepScopedGraphSearch(ret); - - for (auto &node : all_nodes) { - auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - continue; - } - - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - PrimitivePtr prim = GetValueNode(prim_anf_node); - auto attrs = prim->attrs(); - if (StrategyFound(attrs)) { - return true; - } - } - - return false; -} - -bool IsCommunicationOp(const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(prim); - return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()); -} - -bool FindCommunicationOp(const std::vector &all_nodes) { - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - ValueNodePtr prim_value_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_value_node); - PrimitivePtr prim = GetValueNode(prim_value_node); - MS_EXCEPTION_IF_NULL(prim); - - if (IsCommunicationOp(prim) && cnode->in_forward_flag()) { - MS_EXCEPTION_IF_NULL(prim_value_node->scope()); - MS_LOG(INFO) << "The graph contain communication op: " << prim->name() << ", scope name is " - << prim_value_node->scope()->name(); - return true; - } - } - return false; -} - -bool IsParallelCareNode(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - ValueNodePtr prim_node = cnode->input(0)->cast(); - if (prim_node == nullptr) { - return false; - } - PrimitivePtr prim = prim_node->value()->cast(); - if (prim == nullptr) { - return false; - } - if (IsInBlackList(prim)) { - MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); - return false; - } - // get_next is not in the forward graph, we need mark the get_next as the forward node - if (prim->name() == GET_NEXT) { - return true; - } - if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) { - return false; - } - - return cnode->in_forward_flag(); -} - -void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, - const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(node->func_graph()); - FuncGraphManagerPtr manager = node->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[node]; - CNodePtr insert_node_new; - if (IsValueNode(node->input(0))) { - auto current_value = node->input(0)->cast(); - MS_EXCEPTION_IF_NULL(current_value); - PrimitivePtr current_prim = current_value->value()->cast(); - MS_EXCEPTION_IF_NULL(current_prim); - insert_node_new = ((current_prim->name() == TUPLE_GETITEM) ? node : insert_node); - } else { - insert_node_new = insert_node; - } - MS_EXCEPTION_IF_NULL(insert_node_new); - for (auto &node_pair : node_set) { - CNodePtr use_cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(use_cnode); - if (!IsValueNode(use_cnode->input(0))) { - StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node); - } else { - ValueNodePtr prim_anf_node = use_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr node_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) { - Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, - pre_node); - } else { - StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node); - } - } - } -} - -void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(next_node); - OperatorInfoPtr op_info = next_node->operator_info(); - MS_EXCEPTION_IF_NULL(op_info); - - // If the shape of tensor is [] or [1], no need to split it. - Shapes shapes = GetNodeShape(node); - if (shapes.size() != 1) { - MS_LOG(EXCEPTION) << "Split tensor for " << op_info->name() - << ": GetNodeShape for tensor_node, output size is not 1"; - } - Shape shape = shapes[0]; - std::string shape_str = ShapeToString(shape); - if (shape.empty() || ((shape.size() == 1) && (shape[0] == 1))) { - MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape is " << shape_str - << ", no need to split it."; - return; - } - - MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape of tensor is " << shape_str; - - // extract tensor layout - if (IntToSize(index - 1) >= op_info->inputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index is out of range, index is " << index - 1 << ", vector size is " - << op_info->inputs_tensor_info().size(); - } - TensorInfo tensor_info = op_info->inputs_tensor_info()[IntToSize(index - 1)]; - TensorLayout tensor_layout = tensor_info.tensor_layout(); - - // Use _GetTensorSlice operator to split the tensor - FuncGraphPtr func_graph = next_node->func_graph(); // only cnode can get the graph - MS_EXCEPTION_IF_NULL(func_graph); - Operator op = CreateGetTensorSliceOp(tensor_layout); - InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR); - if (!op_info->sub_ops().empty()) { - auto sub_ops = op_info->sub_ops(); - for (size_t i = 0; i < sub_ops.size(); i++) { - if (!sub_ops.at(i).empty()) { - InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB); - } - } - } -} - -void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto &node_pair : node_set) { - CNodePtr use_cnode = node_pair.first->cast(); - if (use_cnode == nullptr || !IsValueNode(use_cnode->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = use_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(use_cnode_prim); - if (use_cnode_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(use_cnode)) { - SplitTensor(node, use_cnode, node_pair.second); - } - } -} - -std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, - const CNodePtr &node) { - OperatorArgs arg_replace_op = replace_op.second; - ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name); - if (pyop_instance == nullptr) { - MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed"; - } - OperatorParams params = arg_replace_op.second; - if (node->inputs().size() < 2) { - // GetNext operator dose not has input - if (node->inputs().size() == 1) { - return {NewValueNode(pyop_instance)}; - } - MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2"; - } - std::vector replace_input = {NewValueNode(pyop_instance), node->input(1)}; - auto prim = GetValueNode(node->input(0)); - if (prim->name() == GATHERV2 || prim->name() == SPARSE_GATHERV2) { - replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; - } - if (!params.empty()) { - Param param_first = *(params.begin()); - int32_t first_position = param_first.second; - if (first_position == 1) { - replace_input.pop_back(); - } - for (auto ¶m : params) { - AnfNodePtr val = NewValueNode(param.first.second); - if (val == nullptr) { - MS_LOG(EXCEPTION) << "Failure:val is nullptr"; - } - int32_t position = param.second; - (void)replace_input.insert(replace_input.begin() + position, val); - } - } - - return replace_input; -} - -void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; - } - std::string instance_name = CreateInstanceName(node, 0); - std::vector replace_input; - replace_input = ReplaceOpInput(replace_op, instance_name, node); - CNodePtr replace_node = func_graph->NewCNode(replace_input); - MS_EXCEPTION_IF_NULL(replace_node); - ScopePtr scope = node->scope(); - MS_EXCEPTION_IF_NULL(scope); - replace_node->set_scope(scope); - replace_node->set_in_forward_flag(true); - replace_input[0]->set_scope(scope); - (void)manager->Replace(node, replace_node); -} - -void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { - // step1:get graph manager distribute_operator - OperatorInfoPtr distribute_operator = node->operator_info(); - if (distribute_operator == nullptr) { - MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; - } - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; - } - // step2:traverse op_list and insert node - std::reverse(replace_op.begin(), replace_op.end()); - auto replace_op_info = distribute_operator->replace_op_info(); - std::reverse(replace_op_info.begin(), replace_op_info.end()); - if (!replace_op_info.empty() && replace_op_info.size() != replace_op.size()) { - MS_LOG(EXCEPTION) << "replace_op_info is not empty and size not equal to replace_op!"; - } - bool replace_op_info_flag = !replace_op_info.empty(); - for (size_t index = 0; index < replace_op.size(); ++index) { - std::string instance_name = CreateInstanceName(node, index); - std::vector replace_input; - if (index != replace_op.size() - 1) { - replace_input = CreateInput(replace_op[index], node, instance_name); - } else { - replace_input = ReplaceOpInput(replace_op[index], instance_name, node); - } - CNodePtr replace_node = func_graph->NewCNode(replace_input); - MS_EXCEPTION_IF_NULL(replace_node); - ScopePtr scope = node->scope(); - MS_EXCEPTION_IF_NULL(scope); - replace_node->set_scope(scope); - if (index == replace_op.size() - 1) { - (void)replace_node->set_operator_info(node->operator_info()); - } - replace_node->set_in_forward_flag(true); - replace_input[0]->set_scope(scope); - if (replace_op_info_flag && replace_op_info[index].first) { - auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph); - (void)manager->Replace(node, new_cnode); // using Replace function to insert node - } else { - (void)manager->Replace(node, replace_node); // using Replace function to insert node - } - } - MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name(); -} - -bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { - ValueNodePtr anf_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(anf_node); - PrimitivePtr prim = anf_node->value()->cast(); - return (prim->name() == name); -} - -void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(replace_graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(replace_graph->second); - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; - } - for (auto &replace_input : replace_graph->first) { - auto pre_node = node->input(IntToSize(replace_input.second)); - manager->SetEdge(replace_input.first, 1, pre_node); - } - // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called - auto replace_output = replace_graph->second; - MS_EXCEPTION_IF_NULL(replace_output); - (void)manager->Replace(node, replace_output); -} - -int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != 3) { - MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3"; - } - - if (!cnode->input(2)->isa()) { - MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node"; - } - - ValuePtr tuple_index_value = GetValueNode(cnode->input(2)); - MS_EXCEPTION_IF_NULL(tuple_index_value); - if (!tuple_index_value->isa()) { - MS_LOG(EXCEPTION) << "The index of tuple getitem is not int32"; - } - return tuple_index_value->cast()->value(); -} - -// Judge whether the node is a loss, and if there are multiple outputs, -// get which output is a grad according to the tuple getitem. -// Currently, it is not supported that the sens is a tuple. -LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { - MS_EXCEPTION_IF_NULL(loss_node); - FuncGraphPtr sub_graph = loss_node->func_graph(); - MS_EXCEPTION_IF_NULL(sub_graph); - CNodePtr return_node = sub_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; - } - AnfNodePtr pre_node = return_node->input(1); - MS_EXCEPTION_IF_NULL(pre_node); - - LossNodeInfo node_info; - - // return -> cast - auto pre_cnode = pre_node->cast(); - MS_EXCEPTION_IF_NULL(pre_cnode); - auto pre_prim = GetValueNode(pre_cnode->input(0)); - if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { - pre_node = pre_cnode->input(1); - } - - // return -> loss - if (pre_node == loss_node) { - node_info.has_tuple_getitem = false; - node_info.dout_index = 0; - return node_info; - } - - // return -> tuple_getitem -> loss - auto cnode = pre_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto current_value = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(current_value); - PrimitivePtr current_prim = current_value->value()->cast(); - MS_EXCEPTION_IF_NULL(current_prim); - // size of common cnode is larger than 1 - if (cnode->inputs().size() < 2) { - MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2"; - } - - if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) { - // size of tuple_getitem cnode is 3 - auto tuple_index = GetTupleGetItemIndex(cnode); - node_info.has_tuple_getitem = true; - node_info.dout_index = tuple_index; - return node_info; - } - - MS_LOG(EXCEPTION) << "Invalid loss"; -} - -void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - size_t node_size = node->inputs().size(); - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - for (size_t index = 1; index < node_size; ++index) { - AnfNodePtr input = node->input(index); - MS_EXCEPTION_IF_NULL(input); - if (!input->isa() && !input->isa()) { // if it is not a tensor, continue - MS_LOG(INFO) << "insert div op: the index " << index << " is not tensor, skip"; - continue; - } - - for (size_t pos = 0; pos < virtual_div_op.size(); ++pos) { - std::string instance_name = CreateInstanceName(node, pos); - InsertNode(virtual_div_op[pos], node, index, node->input(index), func_graph, instance_name); - } - MS_LOG(INFO) << "insert div op for input index " << index << " of node"; - } -} - -std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { - if (!node->isa() && !node->isa() && !node->isa()) { - return std::make_pair(nullptr, false); - } else if (node->isa()) { - return std::make_pair(node, false); - } else if (node->isa()) { - if (IsValueNode(node)) { - std::vector param_v = FindParameterByRefKeyNode(node, func_graph); - if (param_v.size() != 1) { - MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is " - << param_v.size(); - } - return std::make_pair(node, true); - } - return std::make_pair(nullptr, false); - } else { - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - for (size_t index = 0; index < cnode->inputs().size(); ++index) { - if (!FindParameter(cnode->input(index), func_graph).first) { - continue; - } - return FindParameter(cnode->input(index), func_graph); - } - } else { - if (IsParallelCareNode(cnode)) { - return std::make_pair(nullptr, false); - } else { - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - for (size_t index = 0; index < cnode->inputs().size(); ++index) { - PrimitivePtr prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == DEPEND && index != 1) { - continue; - } - if (!FindParameter(cnode->input(index), func_graph).first) { - continue; - } - return FindParameter(cnode->input(index), func_graph); - } - } - } - } - return std::make_pair(nullptr, false); -} - -std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(anode); - MS_EXCEPTION_IF_NULL(anode->func_graph()); - FuncGraphManagerPtr manager = anode->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[anode]; - bool result = false; - CNodePtr cnode_return = nullptr; - for (auto &node_pair : node_set) { - CNodePtr use_apply = node_pair.first->cast(); - if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr node_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == name && node_pair.second == 1) { - if (use_apply->func_graph() == func_graph) { - result = true; - cnode_return = use_apply; - MS_LOG(INFO) << "Find Primitive " << name << " in the same func_graph"; - continue; - } - MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph"; - } - } - return std::make_pair(result, cnode_return); -} - -bool IsCastBeforMirror(const CNodePtr &node, size_t index) { - // only if cast_before_mirror is true, pre node is cast and type is not float32 return true - if (!ParallelContext::GetInstance()->cast_before_mirror()) { - return false; - } - auto pre_node = node->input(index); - MS_EXCEPTION_IF_NULL(pre_node); - auto cnode = pre_node->cast(); - if (cnode == nullptr || !IsValueNode(cnode->input(0))) { - return false; - } - auto pre_value_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(pre_value_node); - auto pre_prim = pre_value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(pre_prim); - if (pre_prim->name() != CAST) { - return false; - } - auto node_type = pre_node->Type(); - MS_EXCEPTION_IF_NULL(node_type); - if (!node_type->isa()) { - MS_LOG(EXCEPTION) << "Unknown type."; - } - auto input_element_type = node_type->cast()->element(); - MS_EXCEPTION_IF_NULL(input_element_type); - auto type_id = input_element_type->type_id(); - - return (type_id != kNumberTypeFloat32); -} - -void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - size_t node_size = node->inputs().size(); - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (mirror_ops.size() != node_size - 1) { - MS_LOG(EXCEPTION) << "Failure:Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() - << ", node_size is " << node_size; - } - for (size_t index = 1; index < node_size; ++index) { - OperatorVector backward_op = mirror_ops[index - 1]; - if (backward_op.empty()) { - continue; - } - std::pair param_node_pair = FindParameter(node->input(index), func_graph); - if (!param_node_pair.first) { - continue; - } - // not a RefKey - if (!param_node_pair.second) { - auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph); - // if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead - if (next_cnode.first) { - MS_EXCEPTION_IF_NULL(next_cnode.second); - manager->SetEdge(node, SizeToInt(index), next_cnode.second); - continue; - } - } - // if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp - // only one MirrorOp in backward_op - if (backward_op.size() != 1) { - MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size(); - } - std::string instance_name = MIRROR_OP; - if (IsCastBeforMirror(node, index)) { - for (auto &op : backward_op) { - // insert new node before the node - CNodePtr cnode = node->input(index)->cast(); - MS_EXCEPTION_IF_NULL(cnode); - AnfNodePtr pre_node = cnode->input(1); - InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); - } - } else { - for (auto &op : backward_op) { - AnfNodePtr pre_node = node->input(index); - InsertNode(op, node, index, pre_node, func_graph, instance_name); - } - } - } -} - -void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, - const std::vector> &sens_loss_pairs) { - MS_EXCEPTION_IF_NULL(distribute_operator); - MS_EXCEPTION_IF_NULL(node); - - bool is_loss_cnode = - std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(), - [node](const std::pair &element) { return element.second == node; }); - - MirrorOps mirror_ops = distribute_operator->mirror_ops(); - VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op(); - // insert mirror op - if (!mirror_ops.empty()) { - MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); - InsertMirrorOps(mirror_ops, node); - } - // insert virtual div op - if (!virtual_div_op.empty() && is_loss_cnode) { - MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name(); - InsertVirtualDivOp(virtual_div_op, node); - } -} - -std::string GetDisOpName(const std::string &prim_name) { - std::string op_name = prim_name; - if (!prim_name.empty() && (prim_name[0] == '_')) { - op_name = prim_name.substr(1); - } - return op_name + "Info"; -} - -OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs, - const std::vector &shape_list) { - if (shape_list.size() != 2) { - MS_LOG(ERROR) << "The size of shape list is not 2"; - return nullptr; - } - if (name.length() == 0) { - MS_LOG(EXCEPTION) << "Length of name is zero!"; - } - std::string distribute_opname = GetDisOpName(name); - if (name == GATHERV2) { - distribute_opname = name + "PInfo"; - auto data_parallel_iter = attrs.find(DATA_PARALLEL); - if (data_parallel_iter != attrs.end()) { - MS_EXCEPTION_IF_NULL(data_parallel_iter->second); - if (!data_parallel_iter->second->isa()) { - MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool."; - } - bool data_parallel = data_parallel_iter->second->cast()->value(); - if (data_parallel) { - distribute_opname = name + "Info"; - } - } - } - OperatorInfoPtr operator_ = - (OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); - if (operator_ == nullptr) { - MS_LOG(INFO) << "Creat " << name << " failed"; - return nullptr; - } - std::string origin_name = operator_->name(); - operator_->set_name(origin_name + std::to_string(TOTAL_OPS)); - MS_LOG(INFO) << "Successfully created operator " << origin_name; - ++TOTAL_OPS; - return operator_; -} - -OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, - const std::vector &shape_list) { - MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list); - if (operator_ == nullptr) { - MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel"; - operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list); - MS_EXCEPTION_IF_NULL(operator_); - } - return operator_; -} - -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, - std::vector shape_list) { - OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); - for (size_t i = 0; i < shape_list[0].size(); ++i) { - MS_LOG(INFO) << "No: " << i << " input's shape: " << ShapeToString(shape_list[0][i]); - } - return operator_; -} - -StrategyPtr ExtractStrategy(std::unordered_map attrs) { - ValueTuplePtr var = attrs[STRATEGY]->cast(); - StrategyPtr strategyPtr; - MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString(); - if (var == nullptr) { - MS_LOG(EXCEPTION) << "Strategy value is nullptr"; - } - if (var->size() > 0) { - std::vector elements = var->value(); - std::vector strategy; - for (uint32_t index = 0; index < elements.size(); ++index) { - Dimensions dim; - if (elements[index]->isa()) { - ValueTuplePtr value_tuple = elements[index]->cast(); - std::vector value_vector = value_tuple->value(); - (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim), - [](const ValuePtr &value) { return static_cast(GetValue(value)); }); - strategy.push_back(dim); - } else { - MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue"; - } - } - if (strategy.empty()) { - MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy"; - } - strategyPtr = NewStrategy(0, strategy); - } - - return strategyPtr; -} - -Shapes GetNodeShape(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - Shapes shapes; - BaseShapePtr base_shape_ptr = node->Shape(); - if (node->isa()) { - auto cnode = node->cast(); - if (IsValueNode(cnode->input(0))) { - PrimitivePtr prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == MAKEREF) { - AnfNodePtr ref_node = cnode->input(1); - auto func_graph = cnode->func_graph(); - MS_EXCEPTION_IF_NULL(ref_node); - MS_EXCEPTION_IF_NULL(func_graph); - return GetRefKeyNodeShape(ref_node, func_graph); - } - } - if (cnode->input(0)->isa()) { - if (cnode->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is samller than 2"; - } - base_shape_ptr = cnode->input(1)->Shape(); - } - } - if (base_shape_ptr == nullptr) { - MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is " - << node->fullname_with_scope(); - } - auto tuple_shape_ptr = dyn_cast(base_shape_ptr); - if (tuple_shape_ptr != nullptr) { - auto tuple_shape = tuple_shape_ptr->shape(); - for (auto &shape : tuple_shape) { - auto each_shape = dyn_cast(shape); - MS_EXCEPTION_IF_NULL(each_shape); - shapes.push_back(each_shape->shape()); - } - } else { - auto shape_ptr = dyn_cast(base_shape_ptr); - MS_EXCEPTION_IF_NULL(shape_ptr); - shapes.push_back(shape_ptr->shape()); - } - return shapes; -} - -std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector parameters; - if (!IsValueNode(node)) { - MS_LOG(ERROR) << "The node is not a ref key"; - return parameters; - } - - auto ref_key = GetValueNode(node); - MS_EXCEPTION_IF_NULL(ref_key); - auto name = ref_key->tag(); - - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto roots = manager->roots(); - if (roots.size() != 1) { - MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1"; - return parameters; - } - - FuncGraphPtr root_g = roots.back(); - MS_EXCEPTION_IF_NULL(root_g); - for (auto ¶m_node : root_g->parameters()) { - auto param = param_node->cast(); - if (param && (name == param->name())) { - parameters.push_back(param_node); - MS_LOG(INFO) << "The name of ref key is: " << name; - return parameters; - } - } - - MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter"; - return parameters; -} - -Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector parameters = FindParameterByRefKeyNode(node, func_graph); - if (parameters.size() != 1) { - MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; - } - - Shapes input_shapes; - input_shapes = GetNodeShape(parameters[0]); - if (input_shapes.size() != 1) { - MS_LOG(EXCEPTION) << "Get input shape failed"; - } - - MS_LOG(INFO) << "The parameter shape is " << ShapeToString(input_shapes[0]); - return input_shapes; -} - -std::vector ExtractShape(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - Shapes shape_inputs, shape_outputs; - std::vector shape_all; - std::vector all_inputs = node->inputs(); - std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; - - size_t inputs_size = all_inputs.size(); - for (size_t i = 1; i < inputs_size; ++i) { - Shapes input_shapes; - AnfNodePtr input = all_inputs[i]; - if (IsValueNode(input)) { - auto func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector parameters = FindParameterByRefKeyNode(input, func_graph); - if (parameters.size() != 1) { - MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; - } - std::pair node_pair = std::make_pair(node, SizeToInt(i)); - g_RefMap[parameters[0]] = node_pair; - input_shapes = GetRefKeyNodeShape(input, func_graph); - } else if (IsValueNode(input) || input->isa() || input->isa()) { - input_shapes = GetNodeShape(input); - } else { - continue; - } - if (input_shapes.size() != 1) { - MS_LOG(EXCEPTION) << "ExtractShape:Get input shape failed"; - } - shape_inputs.push_back(input_shapes[0]); - } - shape_all.push_back(shape_inputs); - // extract out shape - shape_outputs = GetNodeShape(node); - shape_all.push_back(shape_outputs); - return shape_all; -} - -std::pair FindParallelCareNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto &node_pair : node_set) { - CNodePtr cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - continue; - } - ValueNodePtr prim_node_anf = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_node_anf); - PrimitivePtr node_prim = prim_node_anf->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { - return node_pair; - } else if (FindParallelCareNode(node_pair.first).first != nullptr) { - return FindParallelCareNode(node_pair.first); - } - } - return std::make_pair(nullptr, 0); -} - -std::pair FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(parameter); - FuncGraphManagerPtr manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - std::pair prim_anf_node_pair = FindParallelCareNode(parameter); - if (prim_anf_node_pair.first != nullptr) { - return prim_anf_node_pair; - } else { - AnfNodeIndexSet param_sub_set = manager->node_users()[parameter]; - for (auto ¶m_pair : param_sub_set) { - CNodePtr graph_cnode = param_pair.first->cast(); - if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa()) { - continue; - } - CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast(); - if (!IsValueNode(graph_cnode_inp0->input(1))) { - continue; - } - FuncGraphPtr graph_sub = GetValueNode(graph_cnode_inp0->input(1)); - auto parameters = graph_sub->parameters(); - if (IntToSize(param_pair.second - 1) >= parameters.size()) { - MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is " - << parameters.size(); - } - std::pair res = FindSubGraph(graph_sub, parameters[IntToSize(param_pair.second - 1)]); - if (res.first != nullptr) { - return res; - } - } - } - return std::make_pair(nullptr, 0); -} - -void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res) { - MS_EXCEPTION_IF_NULL(parameter); - AbstractBasePtr abstract = parameter->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); - CNodePtr cnode = res.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr distribute_operator = cnode->operator_info(); - if (distribute_operator == nullptr) { - MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; - } - - if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is " - << distribute_operator->inputs_tensor_info().size(); - } - TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)]; - Shape slice_shape = tensorinfo_in.slice_shape(); - MS_LOG(DEBUG) << "SetParallelShape slice_shape " << parameter->ToString() << " shape " - << MakeValue(slice_shape)->ToString(); - std::shared_ptr parallel_shape = std::make_shared(slice_shape); - MS_EXCEPTION_IF_NULL(parallel_shape); - // Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis. - auto cloned_abstract = abstract->Clone(); - MS_EXCEPTION_IF_NULL(cloned_abstract); - cloned_abstract->set_shape(parallel_shape); - parameter->set_abstract(cloned_abstract); - TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); - ParameterPtr parameter_ptr = parameter->cast(); - MS_EXCEPTION_IF_NULL(parameter_ptr); - parameter_ptr->set_tensor_layout(std::make_shared(tensor_layout)); -} - -void CoverSliceShape(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - auto parameters = root->parameters(); - for (auto ¶meter : parameters) { - MS_EXCEPTION_IF_NULL(parameter->Shape()); - auto iter = g_RefMap.find(parameter); - if (iter != g_RefMap.end()) { - SetParallelShape(parameter, g_RefMap[parameter]); - continue; - } - std::pair res = FindSubGraph(root, parameter); - if (res.first == nullptr) { - MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape"; - } else { - SetParallelShape(parameter, res); - MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); - } - } - g_RefMap.clear(); -} - -bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_node) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(parameter_node); - FuncGraphManagerPtr manager = root->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto cloned_parameter = parameter_node->cast(); - MS_EXCEPTION_IF_NULL(cloned_parameter); - - // find the clone parameter - if (!cloned_parameter->has_default()) { - return false; - } - - auto param_value = std::dynamic_pointer_cast(cloned_parameter->default_param()); - py::object clone_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO); - bool cloned = py::cast(parse::python_adapter::GetPyObjAttr(clone_info, CLONED)); - if (!cloned) { - return false; - } - - MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; - return true; -} - -void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - for (auto &cloned_parameter_node : root->parameters()) { - MS_EXCEPTION_IF_NULL(cloned_parameter_node); - auto cloned_parameter = cloned_parameter_node->cast(); - MS_EXCEPTION_IF_NULL(cloned_parameter); - - if (!ParameterIsCloned(root, cloned_parameter_node)) { - continue; - } - - // get the cloned index - auto param_value = std::dynamic_pointer_cast(cloned_parameter->default_param()); - py::object cloned_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO); - int32_t cloned_index = py::cast(parse::python_adapter::GetPyObjAttr(cloned_info, CLONED_INDEX)); - - // find the be cloned parameter - bool found_be_cloned_parameter = false; - ParameterPtr cloned_from_parameter = nullptr; - AnfNodePtr cloned_from_node = nullptr; - for (auto &be_cloned_parameter_node : root->parameters()) { - MS_EXCEPTION_IF_NULL(be_cloned_parameter_node); - auto be_cloned_parameter = be_cloned_parameter_node->cast(); - MS_EXCEPTION_IF_NULL(be_cloned_parameter); - if (!be_cloned_parameter->has_default()) { - continue; - } - - auto param_value_cloned = std::dynamic_pointer_cast(be_cloned_parameter->default_param()); - py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(param_value_cloned->value(), CLONE_INFO); - if (!py::cast(parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED))) { - continue; - } - - // get the be cloned index - py::list be_cloned_index = parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED_INDEX); - for (auto &index : be_cloned_index) { - if (cloned_index == py::cast(index)) { - found_be_cloned_parameter = true; - cloned_from_parameter = be_cloned_parameter; - cloned_from_node = be_cloned_parameter_node; - break; - } - } - } - - if (found_be_cloned_parameter) { - // set the shape and tensor layout for cloned parameter - cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout()); - MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); - MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); - auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); - MS_EXCEPTION_IF_NULL(cloned_abstract); - cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); - cloned_parameter_node->set_abstract(cloned_abstract); - MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() - << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name() - << ", clone index is: " << cloned_index; - } else { - MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is " - << cloned_index << ", but not found the be cloned parameter"; - } - } - std::string env = common::GetEnv("SLICE_ENV"); - if (!env.empty()) { - MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env; - abstract::InitUndeterminedFromEnv(env); - } -} - -void SetVirtualDatasetStrategy(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - - PrimitivePtr prim = GetValueNode(node->input(0)); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == VIRTUAL_DATA_SET) { - CheckGlobalDeviceManager(); - int32_t dev_num; - if (full_batch) { - dev_num = 1; - } else { - dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); - } - auto attrs_temp = prim->attrs(); - std::vector shape_list = ExtractShape(node); - if (shape_list.empty()) { - MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; - } - std::vector elements; - for (size_t i = 0; i < shape_list[0].size(); i++) { - if (shape_list[0][i].empty()) { - MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero"; - } - std::vector input_strategy = {dev_num}; - for (size_t j = 1; j < shape_list[0][i].size(); j++) { - input_strategy.push_back(1); - } - elements.push_back(MakeValue(input_strategy)); - } - ValueTuplePtr strategy = std::make_shared(elements); - attrs_temp[STRATEGY] = strategy; - (void)prim->SetAttrs(attrs_temp); - } -} - -void ExtractInformation(const std::vector &all_nodes) { - // load strategy map from checkpoint - StrategyMap stra_map; - if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { - if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { - MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; - } - } - for (auto &node : all_nodes) { - auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - continue; - } - SetVirtualDatasetStrategy(cnode); - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - PrimitivePtr prim = GetValueNode(prim_anf_node); - auto attrs = prim->attrs(); - MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name(); - if (IsParallelCareNode(cnode)) { - std::vector shape_list = ExtractShape(cnode); - if (shape_list.empty()) { - MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; - } - OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); - if (operator_ == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed"; - } - auto &inputs = cnode->inputs(); - std::vector input_value; - for (size_t index = 1; index < inputs.size(); ++index) { - if (inputs[index]->isa()) { - input_value.push_back(GetValueNode(inputs[index])); - } else { - input_value.emplace_back(nullptr); - } - } - StrategyPtr strategyPtr = nullptr; - (*operator_).set_input_value(input_value); - (*operator_).set_outputs_dtype(cnode->Type()); - (*operator_).set_cnode(cnode); - if (prim->name() == RESHAPE) { - (void)cnode->set_operator_info(operator_); - continue; - } - // load strategy checkpoint - // key of strategy map - std::string strategy_key_name = NodeParameterName(cnode); - bool load_strategy_from_ckpt = - StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); - if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { - MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() - << " is empty, using batch parallel"; - std::shared_ptr> strategy_v_ptr = operator_->GenerateBatchStrategies(); - if (strategy_v_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed"; - } - std::vector elements; - for (size_t i = 0; i < strategy_v_ptr->size(); i++) { - elements.push_back(MakeValue((*strategy_v_ptr)[i])); - } - ValueTuplePtr strategy = std::make_shared(elements); - // display the strategy generated by batch parallel - attrs[GEN_STRATEGY] = strategy; - (void)prim->SetAttrs(attrs); - MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is " - << attrs[GEN_STRATEGY]->ToString(); - strategyPtr = NewStrategy(0, *strategy_v_ptr); - } else if (load_strategy_from_ckpt) { - strategyPtr = stra_map[strategy_key_name]; - } else { - strategyPtr = ExtractStrategy(attrs); - } - if (strategyPtr != nullptr) { - if (operator_->Init(strategyPtr) == FAILED) { - MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; - } - (void)cnode->set_operator_info(operator_); - } else { - MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; - } - } - } -} - -TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair) { - CNodePtr cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); - MS_EXCEPTION_IF_NULL(distribute_operator); - int index = node_pair.second; - if (index > SizeToInt(distribute_operator->inputs_tensor_info().size())) { - MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is " << index - 1 << ", the vector size is " - << distribute_operator->inputs_tensor_info().size(); - } - TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(index - 1)]; - TensorLayout tensorlayout_in = tensorinfo_in.tensor_layout(); - return tensorlayout_in; -} - -// if reshape's output connect to several primitive, return the first layout found -std::shared_ptr FindNextLayout(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(cnode->func_graph()); - FuncGraphManagerPtr manager = cnode->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[cnode]; - for (auto &node_pair : node_set) { - CNodePtr use_apply = node_pair.first->cast(); - if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr node_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { - MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); - auto layout = GetInputLayoutFromCNode(node_pair); - return std::make_shared(layout); - } - MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) - << " " << (use_apply->operator_info() != nullptr); - - auto layout_ptr = FindNextLayout(use_apply); - if (layout_ptr) { - return layout_ptr; - } - } - MS_LOG(WARNING) << "FindNextLayout return nullptr, if reshape is not the last primitive, there must be some error"; - return nullptr; -} - -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) { - MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); - MS_EXCEPTION_IF_NULL(distribute_operator); - if (distribute_operator->outputs_tensor_info().size() < output_index) { - MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size() - << ", must be less than output_index " << output_index; - } - TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index]; - TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); - return std::make_shared(tensorlayout_out); -} - -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) { - if (!node->isa()) { - return nullptr; - } - CNodePtr cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - return nullptr; - } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); - if (!layout_ptr) { - MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; - } - return layout_ptr; - } - return nullptr; -} - -std::shared_ptr CreateParameterLayout(const AnfNodePtr &node) { - // Create DataParallel tensor layout for parameter(support WideDeep). - CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); - TensorLayout input_tensor_layout; - // create input_shape - Shapes inputs_shape = GetNodeShape(node); - Shape input_shape_array = inputs_shape[0]; - if (input_shape_array.empty()) { - MS_LOG(EXCEPTION) << "Don't support reshape a scalar parameter."; - } - // create tensor_map - size_t shape_size = input_shape_array.size(); - TensorMap input_tensor_map_array(SizeToInt(shape_size) - 1, -1); - input_tensor_map_array.insert(input_tensor_map_array.begin(), 0); - // create dev_matrix - Shape dev_matrix_array = {dev_num}; - if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) { - MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed."; - } - return std::make_shared(input_tensor_layout); -} - -std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { - if (node->isa()) { - return CreateParameterLayout(node); - } - if (!node->isa()) { - return nullptr; - } - CNodePtr cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - return nullptr; - } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); - if (!layout_ptr) { - MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; - } - return layout_ptr; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - PrimitivePtr prim = prim_anf_node->value()->cast(); - if (prim->name() == TUPLE_GETITEM) { - auto tuple_index = GetTupleGetItemIndex(cnode); - auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), IntToSize(tuple_index)); - if (!layout_ptr) { - MS_LOG(EXCEPTION) - << " Failure:FindPrevLayout failed, tuple_getitem before reshape, but there does not exit a parallel care node " - "before tuple_getitem!"; - } - return layout_ptr; - } - for (size_t index = 0; index < cnode->inputs().size(); ++index) { - if (prim->name() == DEPEND && index != 1) { - continue; - } - auto layout_ptr = FindPrevLayout(cnode->inputs()[index]); - if (!layout_ptr) { - continue; - } - return layout_ptr; - } - MS_LOG(WARNING) << "FindPrevLayout return nullptr, if reshape is not the first primitive, there must be some error"; - return nullptr; -} - -void ReshapeInit(const std::vector &all_nodes) { - for (auto &node : all_nodes) { - auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { - continue; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); - if (operator_info == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; - } - if (prim->name() != RESHAPE) { - continue; - } - auto attrs = prim->attrs(); - if (StrategyFound(attrs)) { - MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; - } - MS_ASSERT(cnode->inputs().size() == 3); - auto prev_layout_ptr = FindPrevLayout(cnode->input(1)); - if (prev_layout_ptr) { - auto reshape_info_ptr = std::dynamic_pointer_cast(operator_info); - reshape_info_ptr->SetInputLayout(*prev_layout_ptr); - } - auto next_layout_ptr = FindNextLayout(cnode); - if (next_layout_ptr) { - auto reshape_info_ptr = std::dynamic_pointer_cast(operator_info); - reshape_info_ptr->SetOutputLayout(*next_layout_ptr); - } - if (operator_info->Init(nullptr) == FAILED) { - MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed"; - } - } -} - -CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - CNodePtr return_node = func_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->size() < 2) { - MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; - } - AnfNodePtr pre_node = return_node->input(1); - MS_EXCEPTION_IF_NULL(pre_node); - - auto pre_cnode = pre_node->cast(); - if (pre_cnode == nullptr) { - return nullptr; - } - - auto current_prim = GetValueNode(pre_cnode->input(0)); - // return -> cast - if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { - pre_cnode = pre_cnode->input(1)->cast(); - MS_EXCEPTION_IF_NULL(pre_cnode); - current_prim = GetValueNode(pre_cnode->input(0)); - } - - // notice: the GetNext op has not input - if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { - MS_LOG(INFO) << "The loss is: " << current_prim->name(); - return pre_cnode; - } - - // size of common cnode is larger than 1 - if (pre_cnode->size() < 2) { - MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; - } - - // return -> tuple_getitem -> loss - if (current_prim->name() == TUPLE_GETITEM) { - AnfNodePtr pre_pre_node = pre_cnode->input(1); - MS_EXCEPTION_IF_NULL(pre_pre_node); - - auto pre_pre_cnode = pre_pre_node->cast(); - auto value = pre_pre_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(value); - PrimitivePtr prim = value->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - MS_LOG(DEBUG) << "The loss name is " << prim->name(); - return pre_pre_cnode; - } - - // return -> make_tuple - if (current_prim->name() == MAKE_TUPLE) { - MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; - } - - // return -> loss - MS_LOG(DEBUG) << "The loss name is " << current_prim->name(); - return pre_cnode; -} - -TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { - TensorLayouts ret; - MS_EXCEPTION_IF_NULL(loss_cnode); - AnfNodePtr node = loss_cnode->cast(); - MS_EXCEPTION_IF_NULL(node); - - LossNodeInfo node_info = GetLossNodeInfo(node); - ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) { - MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now"; - return ret; - } - - OperatorInfoPtr operator_info = loss_cnode->operator_info(); - MS_EXCEPTION_IF_NULL(operator_info); - TensorInfo loss_grad_tensor_info; - size_t op_output_size = operator_info->outputs_tensor_info().size(); - MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is " - << node_info.has_tuple_getitem << ", the output size is " << op_output_size << ", the dout_index is " - << node_info.dout_index; - - if ((op_output_size == 0) || (op_output_size <= IntToSize(node_info.dout_index))) { - MS_LOG(EXCEPTION) << "The index is " << node_info.dout_index << ", but the size of outputs is " << op_output_size; - } - - if (!node_info.has_tuple_getitem && (op_output_size > 1)) { - MS_LOG(EXCEPTION) << "Currently, it is not supported that the sens is a tuple."; - } - - loss_grad_tensor_info = operator_info->outputs_tensor_info()[IntToSize(node_info.dout_index)]; - ret.push_back(loss_grad_tensor_info.tensor_layout()); - return ret; -} - -void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) { - MS_EXCEPTION_IF_NULL(grad_sens_node); - if (grad_sens_node->size() <= 1) { - MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2"; - } - AnfNodePtr sens_tensor_node = grad_sens_node->input(1); - MS_EXCEPTION_IF_NULL(sens_tensor_node); - Shapes sens_shapes = GetNodeShape(sens_tensor_node); - if (sens_shapes.size() != 1) { - MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1"; - } - // If the shape of sens tensor is [] or [1], no need to split it. - Shape sens_shape = sens_shapes[0]; - if (sens_shape.empty() || ((sens_shape.size() == 1) && (sens_shape[0] == 1))) { - if (sens_tensor_node->isa()) { - auto sens_tensor_param = sens_tensor_node->cast(); - MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); - sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); - } - MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; - return; - } - auto loss_shape = loss_grad_layout.tensor_shape().array(); - if (loss_shape != sens_shape) { - MS_LOG(EXCEPTION) << "The shape of sens is not equal to loss output, it is unsupported now. Sens shape is " - << ShapeToString(sens_shape) << ", loss shape is " << ShapeToString(loss_shape); - } - MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", split it."; - - if (!IsValueNode(sens_tensor_node)) { - if (sens_tensor_node->isa()) { - MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); - AbstractBasePtr abstract = sens_tensor_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - auto slice_shape = loss_grad_layout.slice_shape().array(); - std::shared_ptr parallel_shape = std::make_shared(slice_shape); - MS_EXCEPTION_IF_NULL(parallel_shape); - auto cloned_abstract = abstract->Clone(); - MS_EXCEPTION_IF_NULL(cloned_abstract); - cloned_abstract->set_shape(parallel_shape); - sens_tensor_node->set_abstract(cloned_abstract); - auto sens_tensor_param = sens_tensor_node->cast(); - sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); - return; - } - MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; - } - - // Use _GetTensorSlice operator to split the sens tensor - FuncGraphPtr func_graph = grad_sens_node->func_graph(); // only cnode can get the graph - MS_EXCEPTION_IF_NULL(func_graph); - Operator op = CreateGetTensorSliceOp(loss_grad_layout); - InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS); -} - -void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(distribute_operator); - MS_EXCEPTION_IF_NULL(cnode); - OperatorVector forward_op = distribute_operator->forward_op(); - if (!forward_op.empty()) { - MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name(); - ForwardCommunication(forward_op, cnode); - } -} - -void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(distribute_operator); - MS_EXCEPTION_IF_NULL(cnode); - // StepReplaceOp - OperatorVector replace_op = distribute_operator->replace_op(); - if (!replace_op.empty()) { - MS_LOG(INFO) << "StepReplaceOp " << cnode->ToString(); - StepReplaceOp(replace_op, cnode); - } - - // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore. - ReplaceGraphPtr replace_graph = distribute_operator->replace_graph(cnode); - if (!replace_op.empty() && replace_graph) { - MS_LOG(EXCEPTION) << "Only one of replace_op or replace_op can be used"; - } - if (replace_graph) { - MS_LOG(INFO) << "StepReplaceGraph " << cnode->ToString(); - StepReplaceGraph(replace_graph, cnode); - } -} - -void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(distribute_operator); - MS_EXCEPTION_IF_NULL(cnode); - - std::string op_name = distribute_operator->name(); - if (op_name.find(DROPOUT_DO_MASK) == std::string::npos) { - return; - } - - DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast(distribute_operator); - MS_EXCEPTION_IF_NULL(dropout_do_mask); - std::vector replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode); - if (replace_op.empty()) { - MS_LOG(DEBUG) << "No need to replace dropout_gen_mask"; - return; - } - if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; - } - ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); -} - -void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { - HandleDropoutNode(distribute_operator, cnode); -} - -std::set FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { - // J->CNode->Graph - std::set graph_set; - for (auto &node : root_all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - - auto cnode = node->cast(); - if ((cnode->size() < 2) || !IsValueNode(cnode->input(0))) { - continue; - } - auto expect_j_prim = GetValueNode(cnode->input(0)); - if (expect_j_prim->name() != J) { - continue; - } - if (IsValueNode(cnode->input(1))) { - auto graph = GetValueNode(cnode->input(1)); - MS_LOG(DEBUG) << "Find the forward graph success"; - graph_set.insert(graph); - } - } - return graph_set; -} - -void StepSplitSens(const std::pair &sens_loss_pair) { - CNodePtr sens_node = sens_loss_pair.first; - CNodePtr loss_node = sens_loss_pair.second; - auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node); - if (!loss_grad_layout.empty()) { - SplitSens(sens_node, loss_grad_layout[0]); - } -} - -// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) -std::vector> GetSensLossPairs(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - std::vector> sens_loss_pairs; - for (auto &node : root->nodes()) { - if (!node->isa()) { - continue; - } - - // cnode(sens)-->cnode(tuple_getitem) - auto sens_cnode = node->cast(); - AnfNodePtr expect_tuple_getitem = sens_cnode->input(0); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem); - if (!expect_tuple_getitem->isa()) { - continue; - } - - auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast(); - if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) { - continue; - } - - // cnode(sens)-->cnode(tuple_getitem)-->cnode - AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1); - MS_EXCEPTION_IF_NULL(expect_anonymous); - if (!expect_anonymous->isa()) { - continue; - } - - // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) - auto expect_anonymous_cnode = expect_anonymous->cast(); - AnfNodePtr expect_j = expect_anonymous_cnode->input(0); - MS_EXCEPTION_IF_NULL(expect_j); - if (!expect_j->isa()) { - continue; - } - auto expect_j_cnode = expect_j->cast(); - if (!IsSomePrimitive(expect_j_cnode, J)) { - continue; - } - - if (!IsValueNode(expect_j_cnode->input(1))) { - MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph."; - } - auto func_graph = GetValueNode(expect_j_cnode->input(1)); - auto loss_cnode = FindLossCNode(func_graph); - if (loss_cnode == nullptr) { - MS_LOG(WARNING) << "Can not find the loss cnode"; - continue; - } - std::pair sens_loss_pair = std::make_pair(sens_cnode, loss_cnode); - sens_loss_pairs.push_back(sens_loss_pair); - } - return sens_loss_pairs; -} - -void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, - const FuncGraphManagerPtr &manager) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(manager); - TensorRedistribution tensor_redistribution; - - std::vector> sens_loss_pairs = GetSensLossPairs(root); - bool has_backward = !sens_loss_pairs.empty(); - // split sens must before inserting the operators. - for (auto &pair : sens_loss_pairs) { - // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it. - // If the type of sens node is not Tensor, it is unsupported now, do nothing default. - StepSplitSens(pair); - } - - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); - if (distribute_operator == nullptr) { - continue; - } - - // insert forward ops - InsertForwardOps(distribute_operator, cnode); - - // insert redistribution ops - StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode); - - // insert backward ops - if (has_backward) { - BackwardCommunication(distribute_operator, cnode, sens_loss_pairs); - } - - HandleSpecialNode(distribute_operator, cnode); - } else if (IsValueNode(node)) { - StepSplitTensor(node, manager); - } - } - - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); - if (distribute_operator == nullptr) { - continue; - } - // StepReplace - StepReplace(distribute_operator, cnode); - } - } -} - -namespace { -void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(node); - auto symbolic_key = GetValueNode(node); - MS_EXCEPTION_IF_NULL(symbolic_key); - auto all_upstream_node = root->manager()->node_users()[node]; - for (auto &upstream_node : all_upstream_node) { - FuncGraphPtr fg = upstream_node.first->func_graph(); - if (symbolic_key->node()->isa()) { - for (auto ¶m : root->parameters()) { - if (*param == *symbolic_key->node()) { - AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param}); - MS_EXCEPTION_IF_NULL(reverted_node); - MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString(); - (void)fg->manager()->Replace(node, reverted_node); - MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString(); - } - } - } - } -} -} // namespace - -void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector &all_nodes) { - MS_EXCEPTION_IF_NULL(root); - for (auto &node : all_nodes) { - // revert back SymbolicKeyInstance to embed() primitive - if (IsValueNode(node)) { - RevertSymbolicKeyInstance(root, node); - continue; - } - } -} - -std::string NodeParameterName(const CNodePtr &node) { - std::vector node_inputs{node->inputs()}; - for (auto input : node_inputs) { - if (input->isa()) { - auto input_parameter = input->cast(); - if (input_parameter->has_default()) { - auto param_value = std::dynamic_pointer_cast(input_parameter->default_param()); - if (py::cast(parse::python_adapter::GetPyObjAttr(param_value->value(), REQUIRES_GRAD))) { - return py::cast(parse::python_adapter::GetPyObjAttr(param_value->value(), PARAM_NAME)); - } - } - } - } - return ""; -} - -void CheckpointStrategy(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; - StrategyMap stra_map; - auto ret = func_graph->get_return(); - auto all_nodes = DeepScopedGraphSearch(ret); - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - continue; - } - std::string param_name = NodeParameterName(cnode); - if (param_name.empty()) { - continue; - } - PrimitivePtr prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); - if (operator_info) { - StrategyPtr strategyPtr = operator_info->strategy(); - MS_EXCEPTION_IF_NULL(node->scope()); - stra_map[param_name] = strategyPtr; - } - } - if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { - MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; - } -} - -void SetForwardFlag(const std::vector &all_nodes) { - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - - // CNode is globally unique. - MS_LOG(DEBUG) << "Set forward flag " << cnode->DebugString() << "."; - cnode->set_in_forward_flag(true); - } -} - -void SetForwardFlag(const AnfNodeSet &all_nodes) { - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - - // CNode is globally unique. - cnode->set_in_forward_flag(true); - } -} - -std::set ForwardGraph(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - const auto &all_nodes = root->nodes(); - std::set graph_set = FindForwardGraphByRootNodes(all_nodes); - return graph_set; -} - -std::vector FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { - MS_EXCEPTION_IF_NULL(graph); - std::vector root_forward_nodes; - auto loss_cnode = FindLossCNode(graph); - if (loss_cnode == nullptr) { - MS_LOG(WARNING) << "Can not find the loss cnode"; - return root_forward_nodes; - } - - auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy(); - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - auto root_node_id = node->UniqueIdThroughCopy(); - if (loss_cnode_id == root_node_id) { - root_forward_nodes = DeepLinkedGraphSearch(cnode); - break; - } - } - return root_forward_nodes; -} - -void MarkForwardCNode(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - auto all_nodes = root->nodes(); - std::set graph_set = FindForwardGraphByRootNodes(all_nodes); - - if (graph_set.empty()) { - MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph"; - SetForwardFlag(all_nodes); - } else { - for (auto &func_graph : graph_set) { - MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); - auto return_node = func_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); - SetForwardFlag(all_dfs_nodes); - auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes); - if (root_forward_nodes.empty()) { - continue; - } - // Mark forward flag for the nodes in root graph. - SetForwardFlag(root_forward_nodes); - } - } -} - -Status ParallelInit() { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - int32_t device_num = ParallelContext::GetInstance()->device_num(); - int32_t global_rank = ParallelContext::GetInstance()->global_rank(); - std::string backend = ParallelContext::GetInstance()->communication_backend(); - std::string world_group; - - if (backend == HCCL_BACKEND) { - world_group = HCCL_WORLD_GROUP; - } else if (backend == NCCL_BACKEND) { - world_group = NCCL_WORLD_GROUP; - } else { - MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; - } - - uint32_t world_rank_size = 0; - if (!ParallelContext::GetInstance()->device_num_is_set()) { - if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) { - MS_LOG(EXCEPTION) << "Get rank size failed"; - } - device_num = UintToInt(world_rank_size); - MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num; - } - - uint32_t rank_id = 0; - if (!ParallelContext::GetInstance()->global_rank_is_set()) { - if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { - MS_LOG(EXCEPTION) << "Get rank id failed"; - } - global_rank = UintToInt(rank_id); - MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; - } - - if (!InitDevice(device_num, global_rank, backend)) { - MS_LOG(ERROR) << "Init device failed"; - return FAILED; - } - - MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank - << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean() - << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror(); - return SUCCESS; -} - -bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(optimizer); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); - // assume no change to graph - bool changes = false; - // control whether use model_parallel mode - if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || - (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { - if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { - if (HasStrategy(root)) { - MS_LOG(INFO) << "Strategies ignored in " << parallel_mode - << ", set_strategy() only valid in [semi_]auto_parallel."; - } - root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); - } - - return changes; - } - - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - - MS_LOG(INFO) << "Now entering step parallel"; - DumpGraph(root, std::string(STEP_PARALLEL_BEGIN)); - - pipeline::ResourceBasePtr res = optimizer->resource(); - MS_EXCEPTION_IF_NULL(res); - - FuncGraphManagerPtr manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodePtr ret = root->get_return(); - MS_EXCEPTION_IF_NULL(ret); - std::vector all_nodes = DeepScopedGraphSearch(ret); - std::reverse(all_nodes.begin(), all_nodes.end()); - if (parallel_mode != AUTO_PARALLEL) { - TOTAL_OPS = 0; - if (ParallelInit() != SUCCESS) { - MS_LOG(EXCEPTION) << "Parallel init failed"; - } - - // mark the forward cnodes, parallel only care these nodes - MarkForwardCNode(root); - - if (FindCommunicationOp(all_nodes)) { - MS_LOG(EXCEPTION) << "The graph contain communication op"; - } - - // extract shape and strategy, set operator_info - ExtractInformation(all_nodes); - ReshapeInit(all_nodes); - } - // save strategy as checkpoint for multi-train - if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { - CheckpointStrategy(root); - } - - HandleSymbolicKeyInstance(root, all_nodes); - - // cover Parallel shape - CoverSliceShape(root); - - // set the shape for optimizer's clone tensor - SetClonedTensorShapeForOptimizer(root); - - // ForwardCommunication BackwardCommunication TensorRedistribution - ParallelCommunication(root, all_nodes, manager); - - DumpGraph(root, std::string(STEP_PARALLEL_END)); - - // step parallel only run once - root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true); - res->results()[pipeline::kStepParallelGraph] = root; - - // in auto parallel mode, no need to check if stategies set - root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); - - (void)gettimeofday(&end_time, nullptr); - uint64_t time = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us"; - return changes; -} - -// Needed by rec_parser -std::vector ExtractInputsTensorName(const CNodePtr &node) { - std::vector name_inputs; - std::vector all_inputs = node->inputs(); - std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; - - std::string node_id = node->UniqueId(); - name_inputs.push_back(node_id); - for (auto &input : node_inputs) { - std::string name = input->UniqueId(); - name_inputs.push_back(name); - } - - return name_inputs; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h deleted file mode 100644 index 308473dcd7..0000000000 --- a/mindspore/ccsrc/parallel/step_parallel.h +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ - -#include - -#include -#include -#include -#include -#include -#include - -#include "./common.h" -#include "optimizer/opt.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -using OperatorInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace parallel { -const uint64_t kUSecondInSecond = 1000000; - -struct LossNodeInfo { - bool has_tuple_getitem = false; - int dout_index = 0; // now don't support the sens is a tuple -}; - -std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name); -std::string CreateInstanceName(const CNodePtr &node, size_t index); -void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node); - -void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, - const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node); - -TensorLayout GetTensorInLayout(const CNodePtr &pre_node, const PrimitivePtr &pre_prim, - const OperatorInfoPtr &distribute_operator_pre); - -OperatorInfoPtr GetDistributeOperator(const CNodePtr &node); - -void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, - const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr &pre_node); - -bool StrategyFound(std::unordered_map attrs); - -bool IsParallelCareNode(const CNodePtr &cnode); - -void MarkForwardCNode(const FuncGraphPtr &root); - -bool FindCommunicationOp(const std::vector &all_nodes); - -void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, - const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node); - -std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, - const CNodePtr &node); - -void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node); - -void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node); - -std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph); - -std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); - -void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); - -void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, - const std::vector> &sens_loss_pairs); - -// Generate and init parallel operator -OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, - const std::vector &shape_list); - -// Generate without initing parallel operator -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, - std::vector shape_list); - -// Extract strategy from attr -StrategyPtr ExtractStrategy(std::unordered_map attrs); - -Shapes GetNodeShape(const AnfNodePtr &node); - -std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph); - -// Extract shape from anfnode -std::vector ExtractShape(const CNodePtr &node); - -std::pair FindParallelCareNode(const AnfNodePtr &node); - -// Find finally sub graph -std::pair FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr ¶meter); - -// Set distribute shape for parameters abstract -void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res); - -// change parameters'shape in resource -void CoverSliceShape(const FuncGraphPtr &root); - -void SetVirtualDatasetStrategy(const CNodePtr &node); - -// Creat parallel operator for primitive node(has strategy) -void ExtractInformation(const std::vector &all_nodes); - -TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair); - -std::shared_ptr FindNextLayout(const CNodePtr &node); - -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index); - -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index); - -std::shared_ptr FindPrevLayout(const AnfNodePtr &node); - -void ReshapeInit(const std::vector &all_nodes); - -// Add node for whole graph -void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, - const FuncGraphManagerPtr &manager); - -std::string NodeParameterName(const CNodePtr &node); - -void CheckpointStrategy(const FuncGraphPtr &func_graph); - -// main step of Parallel -bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); - -int32_t GetTupleGetItemIndex(const CNodePtr &cnode); - -Status ParallelInit(); - -std::vector ExtractInputsTensorName(const CNodePtr &node); - -std::set ForwardGraph(const FuncGraphPtr &root); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ diff --git a/mindspore/ccsrc/parallel/strategy.h b/mindspore/ccsrc/parallel/strategy.h deleted file mode 100644 index bc62dd5308..0000000000 --- a/mindspore/ccsrc/parallel/strategy.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ -#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ - -#include -#include -#include -#include -#include - -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -#define MIN_SLICE_NUM 1 - -using Dimensions = std::vector; - -class Strategy; -using StrategyPtr = std::shared_ptr; - -class Strategy { - public: - Strategy(int32_t stage, std::vector inputs) : stage_(stage), inputs_(std::move(inputs)) {} - ~Strategy() = default; - size_t GetInputNumber() const { return inputs_.size(); } - std::vector GetInputDim() const { return inputs_; } - int32_t GetInputStage() const { return stage_; } - void ExpandInputDimFromOneToTwo() { - if (inputs_.size() == 1) { - inputs_.push_back(inputs_[0]); - } - } - void ResetInputs(const std::vector &input) { inputs_ = input; } - - bool IsEqual(const StrategyPtr &another_stra) { - if (another_stra == nullptr) { - return false; - } - if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) { - return false; - } - return true; - } - - private: - const int32_t stage_; - - // The size of Dimensions must equal to inputs_ tensor dimension. - std::vector inputs_; -}; - -inline StrategyPtr NewStrategy(const int32_t stage, const std::vector &inputs) { - return std::make_shared(stage, inputs); -} -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc deleted file mode 100644 index de10f4beb4..0000000000 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" - -#include -#include -#include - -#include "common/utils.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" -#include "proto/node_strategy.pb.h" - -namespace mindspore { -namespace parallel { -StrategyCheckpoint &StrategyCheckpoint::GetInstance() { - static StrategyCheckpoint instance = StrategyCheckpoint(); - if (ParallelContext::GetInstance() != nullptr) { - instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file(); - instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty(); - instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file(); - instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty(); - } - return instance; -} - -bool StrategyCheckpoint::CheckPointExit(const std::string path) const { - std::ifstream fin(path); - if (fin) { - return true; - } - return false; -} - -Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { - if (strategy_map == nullptr) { - MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr"; - } - if (!CheckPointExit(load_file_)) { - MS_LOG(EXCEPTION) << "CheckPoint file is not found"; - } - straspb::ParallelStrategyMap parallel_strategy_map; - std::fstream input(load_file_, std::ios::in | std::ios::binary); - if (!parallel_strategy_map.ParseFromIstream(&input)) { - MS_LOG(ERROR) << "Load strategy file failed"; - return FAILED; - } - size_t node_num = IntToSize(parallel_strategy_map.parallel_strategy_item_size()); - for (size_t i = 0; i < node_num; i++) { - straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i)); - std::string node_name = parallel_strategy_item.node_name(); - straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys(); - auto stage = (int32_t)parallel_strategys.stage(); - size_t strategys_num = IntToSize(parallel_strategys.parallel_strategy_size()); - std::vector> strategy_inputs; - for (size_t j = 0; j < strategys_num; j++) { - straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j)); - std::vector dimension; - size_t dim_num = IntToSize(parallel_strategy.dim_size()); - for (size_t k = 0; k < dim_num; k++) { - dimension.push_back(parallel_strategy.dim(SizeToInt(k))); - } - strategy_inputs.push_back(dimension); - } - - StrategyPtr strategy = NewStrategy(stage, strategy_inputs); - (*strategy_map)[node_name] = strategy; - current_stage_ = (int32_t)parallel_strategy_map.current_stage(); - } - return SUCCESS; -} - -Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { - straspb::ParallelStrategyMap parallel_strategy_map; - parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); - for (auto &node_stra : strategy_map) { - straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); - MS_EXCEPTION_IF_NULL(parallel_strategy_item); - parallel_strategy_item->set_node_name(node_stra.first); - straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); - MS_EXCEPTION_IF_NULL(parallel_strategys); - parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); - for (auto &dims : node_stra.second->GetInputDim()) { - straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); - MS_EXCEPTION_IF_NULL(parallel_strategy); - for (auto dim : dims) { - parallel_strategy->add_dim(IntToUint(dim)); - } - } - } - std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); - if (!parallel_strategy_map.SerializeToOstream(&output)) { - MS_LOG(ERROR) << "Save strategy file failed"; - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h deleted file mode 100644 index a758a9e7bb..0000000000 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ -#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ - -#include -#include -#include "parallel/ops_info/ops_utils.h" -#include "parallel/strategy.h" -#include "parallel/context.h" - -namespace mindspore { -namespace parallel { -using StrategyMap = std::unordered_map; -class StrategyCheckpoint { - public: - StrategyCheckpoint() { - current_stage_ = 0; - load_file_ = ""; - load_checkpoint_on_ = false; - save_file_ = ""; - save_checkpoint_on_ = false; - } - ~StrategyCheckpoint() = default; - - Status Load(StrategyMap *strategy_map); - Status Save(const StrategyMap &strategy_map); - - static StrategyCheckpoint &GetInstance(); - bool LoadCheckPointOn() const { return load_checkpoint_on_; } - bool SaveCheckPointOn() const { return save_checkpoint_on_; } - - private: - std::string load_file_; - std::string save_file_; - bool load_checkpoint_on_; - bool save_checkpoint_on_; - bool CheckPointExit(const std::string path) const; - int32_t current_stage_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc deleted file mode 100644 index 235ab00302..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc +++ /dev/null @@ -1,248 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/arrangement.h" -#include -#include -#include -#include "common/utils.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/shape_util.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status Arrangement::Init(const std::vector &array) { - Status status = Array::Init(array); - if (status != Status::SUCCESS) { - return Status::FAILED; - } - if (!IsValidArrangement()) { - MS_LOG(ERROR) << "invalid arrangement " << this->ToString(); - return Status::FAILED; - } - ComputeSize(); - return Status::SUCCESS; -} - -bool Arrangement::IsValidArrangement() { - return !std::any_of(array_.begin(), array_.end(), [](int32_t value) { return value <= 0; }); -} - -void Arrangement::ComputeSize() { - size_ = 1; - for (auto &value : array_) { - size_ *= value; - } -} - -/* - * if GetDimSize() = 0, return [] - * if value <= array_[0], return [value] - * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]], - * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1], - * if value > size_, return [] - */ -std::vector Arrangement::GetFrontElementByValue(int32_t value) const { - std::vector out; - if (GetDimSize() == 0) { - return out; - } - if (value <= size_) { - int32_t size = 1; - uint32_t shape_list_idx = 0; - while (size < value) { - size *= array_[shape_list_idx]; - if (size <= value) { - out.push_back(array_[shape_list_idx]); - } else { - if (size == 0) { - MS_LOG(ERROR) << "The size is 0"; - out.clear(); - return out; - } - out.push_back(value * array_[shape_list_idx] / size); - } - shape_list_idx++; - } - } - return out; -} - -std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft( - const std::vector &expand_list) const { - if (expand_list.size() != GetDimSize()) { - return nullptr; - } - std::vector new_shape; - for (uint32_t i = 0; i < expand_list.size(); i++) { - std::vector expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i)); - if (expand_shape.empty()) { - new_shape.push_back(GetDimByIdx(i)); - } else { - (void)new_shape.insert(new_shape.end(), expand_shape.begin(), expand_shape.end()); - } - } - Arrangement arrangement_new; - (void)arrangement_new.Init(new_shape); - return std::make_shared(arrangement_new); -} - -/* - * example: - * expand_shape = [4, 2, 2, 2] - * array_ = [8, 4], - * arrangement_list = [[4, 2], [2, 2]] - */ -std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const { - int32_t size = 1; - uint32_t ind = 0; - std::vector arrangement_list; - std::vector shape; - for (uint32_t i = 0; i < expand_shape.GetDimSize(); i++) { - size *= expand_shape.GetDimByIdx(i); - if (size > GetDimByIdx(ind)) { - MS_LOG(ERROR) << "invalid expand_shape"; - return nullptr; - } else if (size < GetDimByIdx(ind)) { - shape.push_back(expand_shape.GetDimByIdx(i)); - continue; - } else { - shape.push_back(expand_shape.GetDimByIdx(i)); - Arrangement arrangement; - (void)arrangement.Init(shape); - arrangement_list.push_back(arrangement); - shape.clear(); - ind++; - size = 1; - } - } - if (ind != GetDimSize()) { - MS_LOG(ERROR) << "invalid expand_shape"; - return nullptr; - } - auto arrangement_new = std::make_shared>(arrangement_list); - return arrangement_new; -} - -std::shared_ptr, Arrangement>> Arrangement::GetExpandShapeListPair( - const Arrangement &expand_shape) const { - std::shared_ptr> expand_shape_list_ptr = GetExpandShapeList(expand_shape); - if (expand_shape_list_ptr == nullptr) { - return nullptr; - } - std::vector expand_num_list_shape; - (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(), - std::back_inserter(expand_num_list_shape), - [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); }); - Arrangement expand_num_list; - Status status = expand_num_list.Init(expand_num_list_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list); - return std::make_shared, Arrangement>>(out_value); -} - -std::vector Arrangement::ComputeReverseAccumulateSumInReverseOrder() const { - std::vector shape_accum; - int32_t size = 0; - for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) { - shape_accum.push_back(size); - size += *iter; - } - return shape_accum; -} - -std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLeft( - const std::vector &expand_list) const { - if (expand_list.size() != GetDimSize()) { - return nullptr; - } - std::vector new_shape; - for (uint32_t i = 0; i < expand_list.size(); i++) { - if (expand_list[i].GetDimSize() >= 1) { - int32_t size = 1; - for (uint32_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) { - new_shape.push_back(expand_list[i].GetDimByIdx(k)); - size *= expand_list[i].GetDimByIdx(k); - } - new_shape.push_back(GetDimByIdx(i) / size); - } else { - new_shape.push_back(GetDimByIdx(i)); - } - } - Arrangement arrangement_new; - (void)arrangement_new.Init(new_shape); - return std::make_shared(arrangement_new); -} - -std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement &in2) const { - std::vector in1_accum; - Status status = ShapeToAccumulateProduct(array_, &in1_accum); - if (status != Status::SUCCESS) { - return nullptr; - } - std::vector in2_accum; - status = ShapeToAccumulateProduct(in2.array(), &in2_accum); - if (status != Status::SUCCESS) { - return nullptr; - } - std::vector out_accum; - status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum); - if (status != Status::SUCCESS) { - return nullptr; - } - std::vector out_shape; - status = AccumulateProductToShape(out_accum, &out_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - Arrangement out; - status = out.Init(out_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(out); -} - -std::vector Arrangement::GetSqueezeIdx() const { - std::vector out; - for (size_t i = 0; i < GetDimSize(); i++) { - if (GetDimByIdx(SizeToUint(i)) == 1) { - out.push_back(i); - } - } - return out; -} - -Arrangement Arrangement::GetSqueezeArrangement() const { - std::vector out_shape(array_.size()); - auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int32_t value) { return value != 1; }); - out_shape.resize(LongToSize(std::distance(out_shape.begin(), it))); - - // if all elements are 1, out_shape = {1} - if (out_shape.empty()) { - MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation"; - out_shape.push_back(1); - } - Arrangement out; - (void)out.Init(out_shape); - return out; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h b/mindspore/ccsrc/parallel/tensor_layout/arrangement.h deleted file mode 100644 index ca71b05c91..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ - -#include -#include -#include -#include -#include -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/array.h" - -namespace mindspore { -namespace parallel { -class Arrangement : public Array { - public: - Arrangement() : size_(1) {} - ~Arrangement() override = default; - Status Init(const std::vector &array) override; - int32_t size() const { return size_; } - std::vector GetFrontElementByValue(int32_t value) const; - std::shared_ptr> GetExpandShapeList(const Arrangement &expand_shape) const; - std::vector ComputeReverseAccumulateSumInReverseOrder() const; - std::shared_ptr GetExpandedShapeByExpandListReserveLeft( - const std::vector &expand_list) const; - std::shared_ptr GetExpandedShapeByExpandListRemoveLeft( - const std::vector &expand_list) const; - std::shared_ptr, Arrangement>> GetExpandShapeListPair( - const Arrangement &expand_shape) const; - std::shared_ptr GetUnifiedShape(const Arrangement &in2) const; - std::vector GetSqueezeIdx() const; - Arrangement GetSqueezeArrangement() const; - - private: - bool IsValidArrangement(); - void ComputeSize(); - int32_t size_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.cc b/mindspore/ccsrc/parallel/tensor_layout/array.cc deleted file mode 100644 index ef358e7cde..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/array.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/array.h" -#include -#include "parallel/status.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -std::string Array::ToString() const { - std::ostringstream buffer; - buffer << "[ "; - for (auto &element : array_) { - buffer << std::to_string(element) + " "; - } - buffer << "]"; - return buffer.str(); -} - -Status Array::Init(const std::vector &array) { - array_ = array; - return IsvalidArray() ? Status::SUCCESS : Status::FAILED; -} - -bool Array::IsvalidArray() const { return true; } - -int32_t Array::GetDimByIdx(uint32_t idx) const { - size_t mod_idx = idx; - if (idx >= GetDimSize()) { - MS_LOG(EXCEPTION) << "idx is " << idx << ", but array size is " << GetDimSize(); - } - return array_[mod_idx]; -} - -int32_t Array::GetDimByReverseIdx(uint32_t idx) const { - size_t mod_idx = idx; - if (idx >= GetDimSize()) { - MS_LOG(EXCEPTION) << "idx is " << idx << " but array size is " << GetDimSize(); - } - return array_[GetDimSize() - 1 - mod_idx]; -} - -bool Array::operator==(const Array &shape) const { - if (GetDimSize() != shape.GetDimSize()) { - return false; - } - for (uint32_t i = 0; i < GetDimSize(); i++) { - if (GetDimByIdx(i) != shape.GetDimByIdx(i)) { - return false; - } - } - return true; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.h b/mindspore/ccsrc/parallel/tensor_layout/array.h deleted file mode 100644 index 5aa3bdb138..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/array.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ - -#include -#include -#include -#include -#include -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -class Array { - public: - Array() = default; - virtual ~Array() = default; - std::string ToString() const; - virtual Status Init(const std::vector &array); - bool IsvalidArray() const; - std::vector array() const { return array_; } - size_t GetDimSize() const { return array_.size(); } - int32_t GetDimByIdx(uint32_t idx) const; - int32_t GetDimByReverseIdx(uint32_t idx) const; - bool operator==(const Array &a1) const; - - protected: - std::vector array_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc deleted file mode 100644 index b5ca5ed60a..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc +++ /dev/null @@ -1,254 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/construct_operator.h" - -#include -#include - -namespace mindspore { -namespace parallel { -Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) { - dev_size_ = dev_matrix_shape.size(); - dev_matrix_shape_ = dev_matrix_shape; - dev_list_ = dev_list; - return Status::SUCCESS; -} - -Status ConstructOperator::ReshapeOP(Shape shape) { - int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies()); - if (prod != prod_expect) { - ValuePtr ptr = MakeValue(shape); - MS_EXCEPTION_IF_NULL(ptr); - MS_LOG(ERROR) << "Invalid tensor shape " << ptr->ToString() << "when construct Reshape operator!"; - return Status::INVALID_ARGUMENT; - } - OperatorAttrs attrs; - ValuePtr param_value = MakeValue(shape); - Attr param = std::make_pair(SHAPE, param_value); - OperatorParams params = {std::make_pair(param, 2)}; - OperatorArgs args = std::make_pair(attrs, params); - op_ = std::make_pair(RESHAPE, args); - return Status::SUCCESS; -} - -Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &end, const Shape &strides) { - ValuePtr attr_value = MakeValue(value); - Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value); - Attr attr_end_mask = std::make_pair(END_MASK, attr_value); - Attr attr_ellipsis_mask = std::make_pair(ELLIPSIS_MASK, attr_value); - Attr attr_new_axis_mask = std::make_pair(NEW_AXIS_MASK, attr_value); - Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value); - OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask}; - - ValuePtr param_begin_value = MakeValue(begin); - Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2); - ValuePtr param_end_value = MakeValue(end); - Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3); - - ValuePtr param_strides_value = MakeValue(strides); - Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4); - OperatorParams params = {param_begin, param_end, param_strides}; - OperatorArgs op_args = std::make_pair(attrs, params); - - return std::make_pair(STRIDED_SLICE, op_args); -} - -Status ConstructOperator::StridedSliceOP(Args args) { - if (args.size() < 3) { - MS_LOG(ERROR) << "args size should not be less than 3!"; - return Status::FAILED; - } - int32_t split_count = args[0]; - if (split_count <= 0) { - MS_LOG(ERROR) << "split_count should not be less than 0!"; - return Status::FAILED; - } - int32_t split_dim = args[1]; - int32_t dev_dim = args[2]; - std::vector group_list; - - if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { - MS_LOG(ERROR) << "stride slice op: create group failed"; - return FAILED; - } else if (group_list.empty()) { // this group only has one device, don't need do StridedSlice - MS_LOG(INFO) << "no need stride slice op"; - return SUCCESS; - } - - Group group = group_list[0]; - size_t rank; - if (group.GetIndex(&rank) == Status::FAILED) { - return Status::FAILED; - } - size_t size = tensor_shape_.size(); - Shape begin(size); - Shape end(size); - Shape strides(size, 1); - size_t index = 0; - for (auto num : tensor_shape_) { - if (index != IntToSize(split_dim)) { - begin[index] = 0; - end[index] = num; - } else { - if (num % split_count != 0) { - MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim - << "! when construct StridedSlice operator"; - return Status::INVALID_ARGUMENT; - } - int32_t count = num / split_count; - begin[index] = SizeToInt(rank) * count; - end[index] = (SizeToInt(rank) + 1) * count; - } - index++; - } - - op_ = CreateStridedSliceOp(DEFAULT, begin, end, strides); - - return Status::SUCCESS; -} - -Status ConstructOperator::AllGatherOP(int32_t dev_dim) { - if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { - MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AllGather operator!"; - return Status::INVALID_ARGUMENT; - } - - std::vector group_list; - if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { - MS_LOG(ERROR) << "AllGather op: create group failed"; - return FAILED; - } else if (group_list.empty()) { // this group only has one device, don't need do allgather - MS_LOG(INFO) << "no need all gather op"; - return SUCCESS; - } - - std::string group_name = group_list[0].name(); - ValuePtr attr_value = MakeValue(group_name); - Attr attr = std::make_pair(GROUP, attr_value); - OperatorAttrs attrs = {attr}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - op_ = std::make_pair(ALL_GATHER, args); - return Status::SUCCESS; -} - -Status ConstructOperator::ConcatOP(int32_t concat_dim) { - if (IntToSize(concat_dim) >= tensor_shape_.size()) { - MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when construct Concat operator!"; - return Status::INVALID_ARGUMENT; - } - ValuePtr attr_value = MakeValue(concat_dim); - Attr attr = std::make_pair(AXIS, attr_value); - OperatorAttrs attrs = {attr}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - op_ = std::make_pair(CONCAT, args); - return Status::SUCCESS; -} - -Status ConstructOperator::SplitOP(int32_t split_count) { - if (split_count <= 0) { - MS_LOG(ERROR) << "Invalid split count when construct Split operator!"; - return Status::FAILED; - } - OperatorAttrs attrs; - ValuePtr attr_value_axis = MakeValue(DEFAULT); - Attr attr_axis = std::make_pair(AXIS, attr_value_axis); - ValuePtr attr_value_split = MakeValue(split_count); - Attr attr_split = std::make_pair(OUTPUT_NUM, attr_value_split); - attrs = {attr_axis, attr_split}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - op_ = std::make_pair(SPLIT, args); - return Status::SUCCESS; -} - -Status ConstructOperator::AlltoAllOP(Args args) { - if (args.size() < 4) { - MS_LOG(ERROR) << "args size should not be less than 4!"; - return Status::FAILED; - } - int32_t split_count = args[0]; - int32_t split_dim = args[1]; - int32_t concat_dim = args[2]; - int32_t dev_dim = args[3]; - if (split_count <= 0) { - MS_LOG(ERROR) << "Invalid split count when construct AlltoAll operator!"; - return Status::FAILED; - } - if (tensor_shape_[IntToSize(split_dim)] % split_count != 0) { - MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim - << "when construct AlltoAll operator!"; - return Status::INVALID_ARGUMENT; - } - if (IntToSize(concat_dim) >= tensor_shape_.size()) { - MS_LOG(ERROR) << "Invalid split count " << split_count << " when construct AlltoAll operator!"; - return Status::INVALID_ARGUMENT; - } - if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { - MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when construct AlltoAll operator!"; - return Status::INVALID_ARGUMENT; - } - - std::vector group_list; - if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { - MS_LOG(ERROR) << "AlltoAll op: create group failed"; - return FAILED; - } else if (group_list.empty()) { // this group only has one device, don't need do alltoall - MS_LOG(INFO) << "no need all to all op"; - return SUCCESS; - } - - std::string group_name = group_list[0].name(); - ValuePtr attr_value_group = MakeValue(group_name); - Attr attr_group = std::make_pair(GROUP, attr_value_group); - ValuePtr attr_value_split_count = MakeValue(split_count); - Attr attr_split_count = std::make_pair(SPLIT_COUNT, attr_value_split_count); - ValuePtr attr_value_split_dim = MakeValue(split_dim); - Attr attr_split_dim = std::make_pair(SPLIT_DIM, attr_value_split_dim); - ValuePtr attr_value_concat_dim = MakeValue(concat_dim); - Attr attr_concat_dim = std::make_pair(CONCAT_DIM, attr_value_concat_dim); - OperatorAttrs attrs = {attr_split_count, attr_split_dim, attr_concat_dim, attr_group}; - OperatorParams params; - OperatorArgs op_args = std::make_pair(attrs, params); - op_ = std::make_pair(ALL_TO_ALL, op_args); - return Status::SUCCESS; -} - -Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector *group) { - MS_EXCEPTION_IF_NULL(group); - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - int32_t rank = g_device_manager->global_rank(); - DeviceMatrix dev_matrix(rank, dev_list_, dev_matrix_shape_); - RankList group_devices; - if (dev_matrix.GetDevicesAlongDim(SizeToUint(axis), &group_devices) != SUCCESS) { - return FAILED; - } - // this group only has one device, don't need create the group - if (group_devices.size() == 1) { - MS_LOG(INFO) << "the group is empty"; - return SUCCESS; - } - - Group g = g_device_manager->CreateGroup(group_devices); - group->push_back(g); - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h deleted file mode 100644 index 1a69638fb6..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -using Args = std::vector; - -class ConstructOperator { - public: - const int32_t DEFAULT = 0; - ConstructOperator() : dev_size_(0) {} - ~ConstructOperator() = default; - Status Init(const RankList &dev_list, const Shape &dev_matrix_shape); - Status ReshapeOP(Shape shape); - Status StridedSliceOP(Args args); - Status AllGatherOP(int32_t dev_dim); - Status SplitOP(int32_t split_count); - Status ConcatOP(int32_t concat_dim); - Status AlltoAllOP(Args args); - Operator GetOperator() const { return op_; } - void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; } - - private: - Operator op_; - size_t dev_size_; - Shape tensor_shape_; - RankList dev_list_; - Shape dev_matrix_shape_; - Status CreateGroupByDim(size_t axis, std::vector *group); -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc deleted file mode 100644 index 84c0580ba8..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/layout_transfer.h" -#include "common/utils.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -std::string LayoutTransfer::ToString() const { - std::ostringstream buffer; - buffer << std::endl << std::string("from_in_ tensor layout:" + from_in_.ToString()); - buffer << std::endl << std::string("to_in_ tensor layout:" + to_in_.ToString()); - return buffer.str(); -} - -LayoutTransfer::~LayoutTransfer() = default; - -Status LayoutTransfer::Init(const TensorLayout &from_in, const TensorLayout &to_in) { - from_in_ = from_in; - to_in_ = to_in; - MS_LOG(DEBUG) << "LayoutTransfer " << this->ToString(); - Status status = CheckValidTransfer(); - return status; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h deleted file mode 100644 index c4da4b728f..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ - -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -class LayoutTransfer { - public: - LayoutTransfer() = default; - virtual ~LayoutTransfer() = 0; - std::string ToString() const; - Status Init(const TensorLayout &from_in, const TensorLayout &to_in); - TensorLayout from_in() const { return from_in_; } - TensorLayout to_in() const { return to_in_; } - - protected: - bool IsSameTensorShape() const { return from_in_.IsSameTensorShape(to_in_); } - bool IsSameDeviceArrangement() const { return from_in_.IsSameDeviceArrangement(to_in_); } - - TensorLayout from_in_; - TensorLayout to_in_; - - private: - virtual Status CheckValidTransfer() = 0; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.cc b/mindspore/ccsrc/parallel/tensor_layout/map.cc deleted file mode 100644 index 669920fc44..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/map.cc +++ /dev/null @@ -1,171 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/map.h" -#include -#include -#include -#include "common/utils.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/shape_util.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status Map::Init(const std::vector &array) { - Status status = Array::Init(array); - if (status != Status::SUCCESS) { - return Status::FAILED; - } - if (!IsValidMap()) { - MS_LOG(ERROR) << "invalid map " << this->ToString(); - return Status::FAILED; - } - return Status::SUCCESS; -} - -bool Map::IsValidMap() { - if (std::any_of(array_.begin(), array_.end(), [](int32_t value) { return ((value < 0) && (value != MAP_NONE)); })) { - return false; - } - // check that all none -1 value in array_ is different - std::vector sorted_array = array_; - std::sort(sorted_array.begin(), sorted_array.end()); - int32_t value = MAP_NONE; - for (auto &element : sorted_array) { - if (element == MAP_NONE) { - continue; - } - if (element == value) { - return false; - } - value = element; - } - return true; -} - -int32_t Map::GetMaxItem() const { - if (!array_.empty()) { - return *std::max_element(array_.begin(), array_.end()); - } else { - return MAP_NONE; - } -} - -int32_t Map::GetIndexByValue(int32_t value) const { - auto iter = find(array_.begin(), array_.end(), value); - if (iter != array_.end()) { - return static_cast(std::distance(array_.begin(), iter)); - } else { - return MAP_NONE; - } -} - -/* - * expand.size() should be equal to array_.size() - */ -std::shared_ptr Map::ExpandMapByNone(const Arrangement &expand_num_list) const { - if (expand_num_list.GetDimSize() != GetDimSize()) { - return nullptr; - } - std::vector new_shape; - for (uint32_t i = 0; i != GetDimSize(); i++) { - if (GetDimByIdx(i) == MAP_NONE) { - for (int32_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) { - new_shape.push_back(MAP_NONE); - } - } else { - new_shape.push_back(GetDimByIdx(i)); - int32_t j = 1; - while (j < expand_num_list.GetDimByIdx(i)) { - new_shape.push_back(MAP_NONE); - j++; - } - } - } - auto map_new = std::make_shared(); - (void)map_new->Init(new_shape); - return map_new; -} - -/* - * expand.size() should be equal to array_.size() - */ -std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const { - if (GetMaxItem() >= static_cast(expand_num_list.GetDimSize())) { - return nullptr; - } - std::vector new_shape; - for (uint32_t i = 0; i < GetDimSize(); i++) { - if (GetDimByIdx(i) == MAP_NONE) { - new_shape.push_back(MAP_NONE); - } else { - int32_t start_map = - expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast(GetDimByIdx(i))]; - for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast(GetDimByIdx(i))) - 1; k >= 0; k--) { - new_shape.push_back(k + start_map); - } - } - } - auto map_new = std::make_shared(); - (void)map_new->Init(new_shape); - return map_new; -} - -std::shared_ptr> Map::ReMapVector(const std::vector &input_vector) const { - if (GetMaxItem() >= static_cast(input_vector.size())) { - return nullptr; - } - std::vector out; - Arrangement empty_arrangement; - for (uint32_t i = 0; i < GetDimSize(); i++) { - if (GetDimByIdx(i) == MAP_NONE) { - out.push_back(empty_arrangement); - } else { - out.push_back(input_vector[IntToUint(SizeToInt(input_vector.size()) - 1 - GetDimByIdx(i))]); - } - } - return std::make_shared>(out); -} - -bool Map::CheckNoneByIdxList(std::vector idx_list) const { - for (auto &value : idx_list) { - if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) { - return false; - } - } - return true; -} - -Map Map::SqueezeMapByIdxList(std::vector idx_list) const { - std::vector out_shape; - for (size_t i = 0; i < GetDimSize(); i++) { - auto it = std::find(idx_list.begin(), idx_list.end(), i); - if (it == idx_list.end()) { - out_shape.push_back(GetDimByIdx(SizeToUint(i))); - } - } - if (out_shape.empty()) { - MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation"; - out_shape.push_back(MAP_NONE); - } - Map out; - (void)out.Init(out_shape); - return out; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.h b/mindspore/ccsrc/parallel/tensor_layout/map.h deleted file mode 100644 index 8c8bba2775..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/map.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ - -#include -#include -#include -#include -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/arrangement.h" -#include "parallel/tensor_layout/array.h" - -namespace mindspore { -namespace parallel { -constexpr int32_t MAP_NONE = -1; - -class Map : public Array { - public: - Map() = default; - ~Map() override = default; - Status Init(const std::vector &array) override; - int32_t GetMaxItem() const; - int32_t GetIndexByValue(int32_t value) const; - std::shared_ptr ExpandMapByNone(const Arrangement &expand_num_list) const; - std::shared_ptr ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const; - std::shared_ptr> ReMapVector(const std::vector &input_vector) const; - bool CheckNoneByIdxList(std::vector idx_list) const; - Map SqueezeMapByIdxList(std::vector idx_list) const; - - private: - bool IsValidMap(); -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.cc deleted file mode 100644 index 7ed07ac02e..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/redistribution_layout_transfer.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/reshape_layout_transfer.h" -#include "parallel/tensor_layout/shape_util.h" - -namespace mindspore { -namespace parallel { -Status RedistributionLayoutTransfer::CheckValidTransfer() { return Status::SUCCESS; } - -/* - * unify device arrangement between in_layout and out_layout - * after this function is called, - * in_step1_layout.device_arrangement and out_step1_layout.device_arrangement will be the same - */ -std::shared_ptr RedistributionLayoutTransfer::UnifyDeviceArrangement() const { - Arrangement in_arrangement; - Arrangement out_arrangement; - in_arrangement = from_in_.device_arrangement(); - out_arrangement = to_in_.device_arrangement(); - std::shared_ptr unify_arrangement_ptr = in_arrangement.GetUnifiedShape(out_arrangement); - if (unify_arrangement_ptr == nullptr) { - return nullptr; - } - std::shared_ptr from_out_ptr = from_in_.ExpandDeviceArrangement(*unify_arrangement_ptr); - if (from_out_ptr == nullptr) { - return nullptr; - } - std::shared_ptr to_out_ptr = to_in_.ExpandDeviceArrangement(*unify_arrangement_ptr); - if (to_out_ptr == nullptr) { - return nullptr; - } - ReshapeLayoutTransfer out; - Status status = out.Init(*from_out_ptr, *to_out_ptr); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(out); -} - -/* - * unify tensor shape between in_step1_layout.tensor_shape and out_step1_layout.tensor_shape - * after this function is called, - * in_step2_layout.tensor_shape and out_step2_layout.tensor_shape will be the same - */ -std::shared_ptr RedistributionLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { - std::shared_ptr unified_device_arrangement_ptr = UnifyDeviceArrangement(); - if (unified_device_arrangement_ptr == nullptr) { - return nullptr; - } - return unified_device_arrangement_ptr->UnifyDeviceArrangementAndTensorShape(); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.h deleted file mode 100644 index 7b57f46dd6..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ - -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/layout_transfer.h" -#include "parallel/tensor_layout/reshape_layout_transfer.h" - -namespace mindspore { -namespace parallel { -class RedistributionLayoutTransfer : public LayoutTransfer { - public: - RedistributionLayoutTransfer() = default; - ~RedistributionLayoutTransfer() override = default; - std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; - - private: - Status CheckValidTransfer() override; - std::shared_ptr UnifyDeviceArrangement() const; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc deleted file mode 100644 index 946620ec4c..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc +++ /dev/null @@ -1,289 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/redistribution_operator_infer.h" - -#include - -#include "parallel/device_manager.h" - -namespace mindspore { -namespace parallel { -Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, - RankList dev_list, bool is_cost_model) { - in_tensor_map_ = tensor_layout.tensor_map(); - dev_mat_ = tensor_layout.device_arrangement(); - - if (in_tensor_map_.GetDimSize() == 0 || out_tensor_map.GetDimSize() != in_tensor_map_.GetDimSize()) { - MS_LOG(ERROR) << "Invalid input when initialize RedistributionOperatorInfer!"; - return Status::FAILED; - } - - cur_tensor_layout_ = tensor_layout; - out_tensor_map_ = out_tensor_map; - dev_list_ = std::move(dev_list); - - operator_list_.clear(); - operator_vector_.clear(); - output_info_vector_.clear(); - - if (constructor_.Init(dev_list_, dev_mat_.array()) != Status::SUCCESS) { - MS_LOG(ERROR) << "Init constructor failed"; - return Status::FAILED; - } - constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); - - size_t key = 0; - std::vector map = in_tensor_map_.array(); - for (int32_t item : map) { - map_[key++] = item; - } - - is_cost_model_ = is_cost_model; - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::InferRedistributionOperator() { - while (!map_.empty()) { - size_t len_global = operator_list_.size(); - - while (!map_.empty()) { - size_t len_split_by_axis = operator_list_.size(); - // split_by_axis operation - if (InferSplitByAxis() == Status::FAILED) { - return Status::FAILED; - } - // permute_by_axis operation - while (!map_.empty()) { - size_t len_permute_by_axis = operator_list_.size(); - if (InferPermuteByAxis() == Status::FAILED) { - return Status::FAILED; - } - if (len_permute_by_axis == operator_list_.size()) break; - } - if (len_split_by_axis == operator_list_.size()) break; - } - // concat_by_axis operation - if (InferConcatByAxis() == Status::FAILED) { - return Status::FAILED; - } - // break loop structure with concat_by_axis - if (len_global == operator_list_.size() && !map_.empty()) { - size_t index = map_.begin()->first; - int32_t in_dim = map_[index]; - map_[index] = NONE; - Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; - if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { - return Status::FAILED; - } - } - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::InferSplitByAxis() { - for (auto iter = map_.begin(); iter != map_.end();) { - uint32_t index = iter->first; - int32_t in_dim = iter->second; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); - if (in_dim == out_dim) { - (void)map_.erase(iter++); - continue; - } - if (in_dim == NONE && - !std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { - Args args = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; - if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) { - MS_LOG(ERROR) << "Insert SplitByAxis Error!"; - return Status::FAILED; - } - (void)map_.erase(iter++); - } else { - (void)++iter; - } - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::InferPermuteByAxis() { - for (auto iter = map_.begin(); iter != map_.end();) { - uint32_t index = iter->first; - int32_t in_dim = map_[index]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); - if (in_dim == out_dim) { - (void)map_.erase(iter++); - continue; - } - if (in_dim == NONE && - std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { - int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); - int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); - if (is_cost_model_) { - int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); - Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, - dev_num}; - if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { - MS_LOG(ERROR) << "Insert PermuteByAxis Error!"; - return Status::FAILED; - } - } else { - Args args_allconcat = {cat_dim, out_dim, dev_num}; - Args args_allsplit = {dev_num, UintToInt(index), out_dim}; - if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { - MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; - return Status::FAILED; - } - if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { - MS_LOG(ERROR) << "Insert SplitByAxis Error!"; - return Status::FAILED; - } - } - (void)map_.erase(iter++); - map_[IntToSize(cat_dim)] = NONE; - } else { - (void)++iter; - } - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::InferConcatByAxis() { - for (auto iter = map_.begin(); iter != map_.end();) { - uint32_t index = iter->first; - int32_t in_dim = map_[index]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); - if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) { - Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; - if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { - MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; - return Status::FAILED; - } - if (out_dim == NONE) { - (void)map_.erase(iter++); - } else { - map_[index] = NONE; - (void)++iter; - } - } else { - (void)++iter; - } - } - return Status::SUCCESS; -} - -// Transfer communicative operators into primitives and insert them into vector -Status RedistributionOperatorInfer::InsertOperator(OperatorName name, Args args) { - OperatorR op = std::make_pair(name, args); - OperatorC op_cost = std::make_pair(op, cur_tensor_layout_.slice_shape().array()); - operator_list_.push_back(op_cost); - if (construct_op_flag_) { - if (name == SPLIT_BY_AXIS) { - if (TransferSplitByAxis(args) == Status::FAILED) { - return Status::FAILED; - } - } else if (name == PERMUTE_BY_AXIS) { - if (TransferPermuteByAxis(args) == Status::FAILED) { - return Status::FAILED; - } - } else { - if (TransferConcatByAxis(args) == Status::FAILED) { - return Status::FAILED; - } - } - constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::TransferSplitByAxis(Args args) { - if (args.size() < 3) { - MS_LOG(ERROR) << "args size should not be less than 3!"; - return Status::FAILED; - } - uint32_t index = IntToUint(args[1]); - if (constructor_.StridedSliceOP(args) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(false, 0)); - } - if (cur_tensor_layout_.UpdateTensorMap(index, args[2]) == Status::FAILED) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) { - if (args.size() < 3) { - MS_LOG(ERROR) << "args size should not be less than 3!"; - return Status::FAILED; - } - if (constructor_.AlltoAllOP(args) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(false, 0)); - } - uint32_t index = IntToUint(args[1]); - int32_t val = args[2]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); - - if (cur_tensor_layout_.UpdateTensorMap(IntToUint(val), NONE) == Status::FAILED) { - return Status::FAILED; - } - if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) { - if (args.size() < 3) { - MS_LOG(ERROR) << "args size should not be less than 3!"; - return Status::FAILED; - } - int32_t tensor_dim = args[0]; - int32_t dev_dim = args[1]; - int32_t split_count = args[2]; - if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(false, 0)); - } - if (tensor_dim != 0) { - if (constructor_.SplitOP(split_count) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(true, split_count)); - } - if (constructor_.ConcatOP(tensor_dim) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(false, 0)); - } - } - if (cur_tensor_layout_.UpdateTensorMap(IntToUint(tensor_dim), NONE) == Status::FAILED) { - return Status::FAILED; - } - return Status::SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h deleted file mode 100644 index 37a8ac3d9e..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ - -#include -#include -#include -#include -#include - -#include "parallel/tensor_layout/construct_operator.h" -#include "parallel/tensor_layout/redistribution_layout_transfer.h" -#include "utils/convert_utils.h" -namespace mindspore { -namespace parallel { -using DeviceArrangement = std::vector; -using TensorMap = std::vector; -using TensorShape = std::vector; -using RedistributionOperatorMap = std::unordered_map; -using OperatorR = std::pair; -using OperatorC = std::pair; -using OperatorList = std::vector; - -class RedistributionOperatorInfer { - public: - const int NONE = -1; - explicit RedistributionOperatorInfer(bool construct_op_flag = true) - : construct_op_flag_(construct_op_flag), is_cost_model_(false) {} - Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, - bool is_cost_model = false); - ~RedistributionOperatorInfer() = default; - OperatorList operator_list() const { return operator_list_; } - OperatorVector operator_vector() const { return operator_vector_; } - OutPutInfoVector output_info_vector() const { return output_info_vector_; } - Status InferRedistributionOperator(); - - private: - Status InferSplitByAxis(); - Status InferPermuteByAxis(); - Status InferConcatByAxis(); - Status TransferSplitByAxis(Args args); - Status TransferPermuteByAxis(Args args); - Status TransferConcatByAxis(Args args); - Status InsertOperator(OperatorName name, Args args); - - OperatorList operator_list_; - OperatorVector operator_vector_; - OutPutInfoVector output_info_vector_; - Arrangement dev_mat_; - RedistributionOperatorMap map_; - Map in_tensor_map_; - Map out_tensor_map_; - TensorLayout cur_tensor_layout_; - ConstructOperator constructor_; - RankList dev_list_; - bool construct_op_flag_; - bool is_cost_model_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc deleted file mode 100644 index 4c66befd78..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/reshape_layout_transfer.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/shape_util.h" - -namespace mindspore { -namespace parallel { -Status ReshapeLayoutTransfer::CheckValidTransfer() { - if (!IsSameDeviceArrangement()) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -std::shared_ptr ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { - bool is_unified = IsSameTensorShape(); - std::shared_ptr out_layout_ptr = std::make_shared(*this); - if (out_layout_ptr == nullptr) { - return nullptr; - } - while (!is_unified) { - std::shared_ptr temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo(); - if (temp_layout_ptr == nullptr) { - return nullptr; - } - out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom(); - if (out_layout_ptr == nullptr) { - return nullptr; - } - is_unified = out_layout_ptr->IsSameTensorShape(); - } - return out_layout_ptr; -} - -std::shared_ptr ReshapeLayoutTransfer::ExtendFromTensorShapeByTo() const { - std::shared_ptr out_ptr = std::make_shared(*this); - bool is_expanded = FromTensorShapeCanBeExpandByTo(); - while (!is_expanded) { - out_ptr = out_ptr->ExtendFromTensorShapeByExpandedTensorShape(); - if (out_ptr == nullptr) { - return nullptr; - } - is_expanded = out_ptr->FromTensorShapeCanBeExpandByTo(); - } - return out_ptr; -} - -std::shared_ptr ReshapeLayoutTransfer::ExtendToTensorShapeByFrom() const { - std::shared_ptr out_ptr = std::make_shared(*this); - bool is_expanded = ToTensorShapeCanBeExpandByFrom(); - while (!is_expanded) { - out_ptr = out_ptr->ExtendToTensorShapeByExpandedTensorShape(); - if (out_ptr == nullptr) { - return nullptr; - } - is_expanded = out_ptr->ToTensorShapeCanBeExpandByFrom(); - } - return out_ptr; -} - -bool ReshapeLayoutTransfer::FromTensorShapeCanBeExpandByTo() const { - return from_in_.TensorShapeCanBeExpanded(to_in_.tensor_shape()); -} - -bool ReshapeLayoutTransfer::ToTensorShapeCanBeExpandByFrom() const { - return to_in_.TensorShapeCanBeExpanded(from_in_.tensor_shape()); -} - -std::shared_ptr ReshapeLayoutTransfer::ExtendFromTensorShapeByExpandedTensorShape() const { - std::shared_ptr expanded_shape_ptr = ComputeExpandedFromTensorShapeByTo(); - if (expanded_shape_ptr == nullptr) { - return nullptr; - } - return ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); -} - -std::shared_ptr ReshapeLayoutTransfer::ExtendToTensorShapeByExpandedTensorShape() const { - std::shared_ptr exchanged_from_and_to_ptr = ExchangeFromAndTo(); - if (exchanged_from_and_to_ptr == nullptr) { - return nullptr; - } - std::shared_ptr expanded_shape_ptr = exchanged_from_and_to_ptr->ComputeExpandedFromTensorShapeByTo(); - if (expanded_shape_ptr == nullptr) { - return nullptr; - } - std::shared_ptr exchanged_out = - exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); - if (exchanged_out == nullptr) { - return nullptr; - } - return exchanged_out->ExchangeFromAndTo(); -} - -std::shared_ptr ReshapeLayoutTransfer::ExchangeFromAndTo() const { - ReshapeLayoutTransfer out; - Status status = out.Init(to_in_, from_in_); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(out); -} - -std::shared_ptr ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement &expand_shape) const { - std::shared_ptr extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape); - if (extend_tensor_shape_from_ptr == nullptr) { - return nullptr; - } - Arrangement unified_device_arrangement = extend_tensor_shape_from_ptr->device_arrangement(); - std::shared_ptr extend_device_arrangement_to_ptr = - to_in_.ExpandDeviceArrangement(unified_device_arrangement); - if (extend_device_arrangement_to_ptr == nullptr) { - return nullptr; - } - ReshapeLayoutTransfer out; - Status status = out.Init(*extend_tensor_shape_from_ptr, *extend_device_arrangement_to_ptr); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(out); -} - -std::shared_ptr ReshapeLayoutTransfer::ComputeExpandedFromTensorShapeByTo() const { - return from_in_.ComputeExpandedTensorShape(to_in_.tensor_shape()); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h deleted file mode 100644 index ed62cb59da..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ - -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/layout_transfer.h" - -namespace mindspore { -namespace parallel { -class ReshapeLayoutTransfer : public LayoutTransfer { - public: - ReshapeLayoutTransfer() = default; - ~ReshapeLayoutTransfer() override = default; - std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; - std::shared_ptr ExtendFromTensorShapeByTo() const; - std::shared_ptr ExtendToTensorShapeByFrom() const; - std::shared_ptr ExtendFromTensorShapeByExpandedTensorShape() const; - std::shared_ptr ExtendToTensorShapeByExpandedTensorShape() const; - std::shared_ptr ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement &expand_shape) const; - std::shared_ptr ExchangeFromAndTo() const; - - private: - Status CheckValidTransfer() override; - std::shared_ptr ComputeExpandedFromTensorShapeByTo() const; - bool FromTensorShapeCanBeExpandByTo() const; - bool ToTensorShapeCanBeExpandByFrom() const; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc deleted file mode 100644 index e8f208708c..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc +++ /dev/null @@ -1,263 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/shape_util.h" -#include -#include "parallel/status.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -/* - * example: - * shape = [2, 8, 32] - * shape_accum = [2, 2 * 8, 2 * 8 * 32] - */ -Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum) { - MS_EXCEPTION_IF_NULL(shape_accum); - shape_accum->clear(); - int64_t size = 1; - for (auto iter = shape.begin(); iter < shape.end(); ++iter) { - size *= *iter; - if (size <= 0) { - MS_LOG(ERROR) << "element of shape should not be zero"; - return Status::FAILED; - } - shape_accum->push_back(size); - } - return Status::SUCCESS; -} - -/* - * example: - * shape = [2, 8, 32] - * shape_accum = [2 * 8 * 32, 8 * 32, 32] - * - */ -Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum) { - MS_EXCEPTION_IF_NULL(shape_accum); - shape_accum->clear(); - int64_t size = 1; - for (auto iter = shape.end() - 1; iter >= shape.begin(); --iter) { - size *= *iter; - if (size <= 0) { - MS_LOG(ERROR) << "element of shape should not be zero"; - return Status::FAILED; - } - (void)shape_accum->insert(shape_accum->begin(), size); - } - return Status::SUCCESS; -} - -/* - * example: - * shape_accum = [2, 2 * 8, 2 * 8 * 32] - * shape = [2, 8, 32] - * - */ -Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape) { - MS_EXCEPTION_IF_NULL(shape); - shape->clear(); - int64_t value = 1; - for (auto iter = shape_accum.begin(); iter < shape_accum.end(); ++iter) { - if ((*iter) == 0) { - MS_LOG(ERROR) << "element of shape_accum should not be zero"; - return Status::FAILED; - } - if ((*iter) % value != 0) { - MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; - return Status::FAILED; - } - shape->push_back(static_cast((*iter) / value)); - value = (*iter); - } - return Status::SUCCESS; -} - -/* - * example: - * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * shape = [2, 8, 32] - */ -Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape) { - MS_EXCEPTION_IF_NULL(shape); - shape->clear(); - int64_t value = 1; - for (auto iter = shape_accum_reverse.end() - 1; iter >= shape_accum_reverse.begin(); --iter) { - if (*iter == 0) { - MS_LOG(ERROR) << "element of shape_accum should not be zero"; - return Status::FAILED; - } - if ((*iter) % value != 0) { - MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order"; - return Status::FAILED; - } - (void)shape->insert(shape->begin(), static_cast((*iter) / value)); - value = *iter; - } - return Status::SUCCESS; -} - -/* - * example1: - * in1 = [2, 8] - * in2 = [4, 8] - * *out = [2, 4, 8] - * - * example2: - * in1 = [2, 4, 16] - * in2 = [8, 16] - * *out = [2, 4, 8, 16] - */ -Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, - std::vector *out_accum) { - MS_EXCEPTION_IF_NULL(out_accum); - out_accum->clear(); - auto in1_iter = in1_accum.begin(); - auto in2_iter = in2_accum.begin(); - while ((in1_iter < in1_accum.end()) || (in2_iter < in2_accum.end())) { - if ((*in1_iter <= 0) || (*in2_iter <= 0)) { - MS_LOG(ERROR) << "element of in1 and in2 must be larger than zero"; - return Status::FAILED; - } - if (*in1_iter < *in2_iter) { - out_accum->push_back(*in1_iter); - ++in1_iter; - continue; - } else if (*in1_iter == *in2_iter) { - out_accum->push_back(*in1_iter); - ++in1_iter; - ++in2_iter; - } else { - out_accum->push_back(*in2_iter); - ++in2_iter; - } - } - if ((in1_iter != in1_accum.end()) || (in2_iter != in2_accum.end())) { - MS_LOG(ERROR) << "last element of in1 and in2 must be equal"; - return Status::FAILED; - } - return Status::SUCCESS; -} - -/* - * example: - * in1 = [8, 4] - * in2 = [2, 16] - * out = [2, 4, 4] - */ -Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out) { - MS_EXCEPTION_IF_NULL(out); - std::vector in1_accum; - Status status = ShapeToAccumulateProduct(in1, &in1_accum); - if (status != Status::SUCCESS) { - return status; - } - std::vector in2_accum; - status = ShapeToAccumulateProduct(in2, &in2_accum); - if (status != Status::SUCCESS) { - return status; - } - std::vector out_accum; - status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum); - if (status != Status::SUCCESS) { - return status; - } - status = AccumulateProductToShape(out_accum, out); - if (status != Status::SUCCESS) { - return status; - } - return status; -} - -/* - * example1: - * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * expand_accum_reverse = [2 * 8 * 32, 32, 8] - * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8] - * - * example2: - * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] - * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] - */ -Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, - const std::vector &expand_accum_reverse, - std::vector *out_accum_reverse) { - MS_EXCEPTION_IF_NULL(out_accum_reverse); - out_accum_reverse->clear(); - auto in_riter = in_accum_reverse.rbegin(); - auto expand_riter = expand_accum_reverse.rbegin(); - while (expand_riter != expand_accum_reverse.rend()) { - if (in_riter == in_accum_reverse.rend()) { - MS_LOG(ERROR) << "invalid ExpandAccumProd inputs"; - return Status::FAILED; - } - if (*in_riter > *expand_riter) { - (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter); - ++expand_riter; - } else if (*in_riter == *expand_riter) { - (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter); - ++in_riter; - ++expand_riter; - } else { - (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter); - ++in_riter; - } - } - while (in_riter != in_accum_reverse.rend()) { - (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter); - ++in_riter; - } - return Status::SUCCESS; -} - -/* - * example1: - * in = [2, 8, 32] - * expand = [16, 4, 8] - * out = [2, 8, 4, 8] - * - * example2: - * in = [2, 8, 32] - * expand = [2, 4, 8] - * out = [2, 4, 2, 4, 8] - */ -Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out) { - MS_EXCEPTION_IF_NULL(out); - std::vector in_accum_reverse; - Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse); - if (status != Status::SUCCESS) { - return status; - } - std::vector expand_accum_reverse; - status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse); - if (status != Status::SUCCESS) { - return status; - } - std::vector out_accum_reverse; - status = ExpandAccumulateProduct(in_accum_reverse, expand_accum_reverse, &out_accum_reverse); - if (status != Status::SUCCESS) { - return status; - } - status = AccumulateProductReverseToShape(out_accum_reverse, out); - if (status != Status::SUCCESS) { - return status; - } - return status; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h b/mindspore/ccsrc/parallel/tensor_layout/shape_util.h deleted file mode 100644 index 2ec21f3881..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h +++ /dev/null @@ -1,172 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ - -#include -#include -#include -#include -#include - -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -/* - * compute the accumulating product of all the values in shape from left to right, - * the accumulating results are saved in shape_accum from left to right - * - * given a shape = [d_n-1, d_n-2, ..., d_0](d_i > 0, i=0,1,...,n-1, elements of shape must be larger than zero), - * then *shape_accum = [d_n-1, d_n-1 * d_n-2, d_n-1 * d_n-2 * d_n-3, ..., d_n-1 * d_n-2 * ... *d_0] - * - * example: - * shape = [2, 8, 32] - * shape_accum = [2, 2 * 8, 2 * 8 * 32] - * - */ -Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum); - -/* - * compute the accumulating product of all the values in shape from right to left, - * the accumulating results are saved in shape_accum from right to left - * - * given a shape = [d_n-1, d_n-2, ..., d_0](d_i > 0, i=0,1,...,n-1, elements of shape must be larger than zero), - * then *shape_accum = [d_n-1 * d_n-2 * ... *d_0, d_n-2 * d_n-3 * ... *d_0, ..., d_0] - * - * example: - * shape = [2, 8, 32] - * shape_accum = [2 * 8 * 32, 8 * 32, 32] - * - */ -Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum); - -/* - * compute the original shape from the accumulating product shape_accum, - * elements of shape_accum is saved from left to right, - * given shape_accum = [accum_n-1, accum_n-2, accum_n-3, ..., accum_0] - * (accum_i > 0, i=0,1,...,n-1, elements of shape_accum must be larger than zero), - * (accum_i-1 % accum_i == 0, i=1,...,n-1) - * then *shape = [accum_n-2/accum_n-1, accum_n-3/accum_n-2, ..., accum_0/accum_1] - * - * example: - * shape_accum = [2, 2 * 8, 2 * 8 * 32] - * shape = [2, 8, 32] - * - */ -Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape); - -/* - * compute the original shape from the accumulating product shape_accum, - * elements of shape_accum is saved from right to left, - * given shape_accum_reverse = [accum_n-1, accum_n-2, accum_n-3, ..., accum_0] - * (accum_i > 0, i=0,1,...,n-1, elements of shape_accum must be larger than zero), - * (accum_i % accum_i-1 == 0, i=1,...,n-1) - * then *shape = [accum_n-1/accum_n-2, accum_n-2/accum_n-1, ..., accum_1/accum_0] - * - * example: - * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * shape = [2, 8, 32] - * - */ -Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape); - -/* - * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, - * results are saved in out. - * i.e. *out_accum = in1_accum U in2_accum - * elements of out are saved in increasing order - * - * example1: - * in1_accum = [2, 8] - * in2_accum = [4, 8] - * out_accum = [2, 4, 8] - * - * example2: - * in1_accum = [2, 4, 16] - * in2_accum = [8, 16] - * out_accum = [2, 4, 8, 16] - */ -Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, - std::vector *out_accum); - -/* - * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] - * size = din1_n-1 * din1n-2 * ... * din1_0 = din2_m-1 * din2_m-2 * ... * din2_0 - * find *out = [dout_k-1, dout_k-2, ..., dout_0], s.t. dout_k-1 * dout_k-2 * ... * dout_0 = size and - * suppose in1_accum, in2_accum, and *out_accum is the ShapeToAccumulateProduct result of in1, in2, and *out - * then for each din1_i in in1_accum, din1_i is in *out_accumulate, - * for each din2_i in in2_accum, din2_i is in *out_accumulate - * - * example: - * in1 = [8, 4] - * in2 = [2, 16] - * out = [2, 4, 4] - */ -Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out); - -/* - * given two accumulate product in reverse order of in and expand, - * in_accum_reverse = [din_n-1, din_n-2, ..., din_0] and expand_pos_reverse = [dexp_n-1, dexp_n-2, ..., dexp_0], - * i.e. in_accum_reverse is the ShapeToAccumulateProductReverse result of a shape in, - * expand_accum_reverse is the ShapeToAccumulateProductReverse result of a shape expand, - * compute the accumulate product in reverse order out_accum_reverse = [dout_k-1, dout_k-2, ..., dout_0], - * s.t. elements in out_accum_reverse are union of elements in in_accum_reverse and expand_accum_reverse - * (out_accum_reverse = in_accum_reverse U expand_accum_reverse), and - * out_accum_reverse is the ShapeToAccumulateProductReverse result of shape expand, - * i.e. dout_i > 0, i=0,1,...,k-1, elements of out_accum_reverse must be larger than zero, - * dout_i-1 % dout_i == 0, i=1,...,k-1 - * - * example1: - * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * expand_accum_reverse = [2 * 8 * 32, 32, 8] - * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8] - * - * example2: - * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] - * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] - */ -Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, - const std::vector &expand_accum_reverse, - std::vector *out_accum_reverse); - -/* - * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], - * compute the expended shape out = [dout_k-1, dout_k-2, ..., dout_0], - * s.t. dout_k-1 * dout_k-2 * ...* dout_0 = din_n-1 * din_n-2 * ... * d_0 - * suppose in_accum_reverse is the ShapeToAccumulateProductReverse result of in, - * expand_accum_reverse is the ShapeToAccumulateProductReverse result of expand, - * out_accum_reverse is the ShapeToAccumulateProductReverse result of out, - * then out_accum_reverse is the union of in_accum_reverse and expand_accum_reverse - * (out_accum_reverse = in_accum_reverse U expand_accum_reverse) - * - * example1: - * in = [2, 8, 32] - * expand = [16, 4, 8] - * out = [2, 8, 4, 8] - * - * example2: - * in = [2, 8, 32] - * expand = [2, 4, 8] - * out = [2, 4, 2, 4, 8] - */ -Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h deleted file mode 100644 index 0eee736cea..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -using Shapes = std::vector; - -class TensorInfo { - public: - TensorInfo(const TensorLayout &tensor_layout, Shape shape, Shape slice_shape) - : tensor_layout_(tensor_layout), shape_(std::move(shape)), slice_shape_(std::move(slice_shape)) {} - explicit TensorInfo(const TensorLayout &tensor_layout) : tensor_layout_(tensor_layout) { - shape_ = tensor_layout.tensor_shape().array(); - slice_shape_ = tensor_layout.slice_shape().array(); - } - // trivial default constructor will not initialize c language types. - TensorInfo() = default; - ~TensorInfo() = default; - TensorLayout tensor_layout() const { return tensor_layout_; } - Shape slice_shape() const { return slice_shape_; } - Shape shape() const { return shape_; } - void set_reduce_dim(const std::vector &dim) { reduce_dim_ = dim; } - std::vector reduce_dim() const { return reduce_dim_; } - Dimensions InferStrategy() const { - Dimensions stra; - for (size_t i = 0; i < shape_.size(); ++i) { - if ((slice_shape_[i] == 0) || (shape_[i] % slice_shape_[i] != 0)) { - return stra; - } - int32_t dim = (int32_t)(shape_[i] / slice_shape_[i]); - stra.push_back(dim); - } - return stra; - } - - private: - TensorLayout tensor_layout_; - Shape shape_; - Shape slice_shape_; - // reduce method's reduce dim - std::vector reduce_dim_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc deleted file mode 100644 index f3498065f2..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc +++ /dev/null @@ -1,394 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/tensor_layout.h" -#include -#include -#include "common/utils.h" -#include "ir/value.h" -#include "parallel/device_matrix.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/array.h" -#include "parallel/tensor_layout/shape_util.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -std::string TensorLayout::ToString() const { return StandardToString() + OriginToString(); } - -std::string TensorLayout::StandardToString() const { - std::ostringstream buffer; - buffer << std::endl << std::string("device arrangement = " + device_arrangement_.ToString()); - buffer << std::endl << std::string("tensor map = " + tensor_map_.ToString()); - buffer << std::endl << std::string("tensor shape = " + tensor_shape_.ToString()); - return buffer.str(); -} - -std::string TensorLayout::OriginToString() const { - std::ostringstream buffer; - buffer << std::endl << std::string("device arrangement origin = " + device_arrangement_origin_.ToString()); - buffer << std::endl << std::string("tensor map origin = " + tensor_map_origin_.ToString()); - buffer << std::endl << std::string("tensor shape origin = " + tensor_shape_origin_.ToString()); - return buffer.str(); -} - -Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map, - const Arrangement &tensor_shape) { - device_arrangement_origin_ = device_arrangement; - tensor_map_origin_ = tensor_map; - tensor_shape_origin_ = tensor_shape; - device_arrangement_ = device_arrangement; - tensor_map_ = tensor_map; - tensor_shape_ = tensor_shape; - if (IsValidTensorLayout()) { - MS_LOG(DEBUG) << "valid origin tensor layout " << this->OriginToString(); - RemoveElementEqualToOneInDeviceArrangement(); - MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString(); - return Status::SUCCESS; - } else { - MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString(); - return Status::FAILED; - } -} - -Status TensorLayout::InitFromVector(const std::vector &device_arrangement, - const std::vector &tensor_map, const std::vector &tensor_shape) { - if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { - return FAILED; - } - if (tensor_map_origin_.Init(tensor_map) != SUCCESS) { - return FAILED; - } - if (tensor_shape_origin_.Init(tensor_shape) != SUCCESS) { - return FAILED; - } - if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) { - return FAILED; - } - return SUCCESS; -} - -bool TensorLayout::IsValidTensorLayout() const { - if (tensor_map_origin_.GetMaxItem() >= static_cast(device_arrangement_origin_.GetDimSize())) { - MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!"; - return false; - } - if (tensor_map_origin_.GetDimSize() != tensor_shape_origin_.GetDimSize()) { - MS_LOG(ERROR) << "tensor_map_origin_ size must be equal to tensor_shape_origin_ size!"; - return false; - } - if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) { - MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!"; - return false; - } - return true; -} - -bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const { - for (uint32_t i = 0; i < tensor_map_.GetDimSize(); i++) { - if (tensor_map_.GetDimByIdx(i) != -1) { - int32_t divisor = GetSliceNumByTensorDimensionIndex(i); - if (divisor == 0) { - MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0"; - return false; - } - if (tensor_shape_.GetDimByIdx(i) % divisor != 0) { - return false; - } - } - } - return true; -} - -void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { - std::vector device_arrangement_shape; - std::vector tensor_map_shape = tensor_map_origin_.array(); - uint32_t dev_num = SizeToUint(device_arrangement_origin_.GetDimSize()); - int32_t dev_num_left = SizeToInt(device_arrangement_origin_.GetDimSize()); - for (uint32_t i = 0; i < dev_num; i++) { - if (device_arrangement_origin_.GetDimByIdx(i) == 1) { - int32_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast(dev_num - 1 - i)); - if (idx != -1) { - tensor_map_shape[static_cast(idx)] = -1; - } - for (auto &value : tensor_map_shape) { - if (value >= dev_num_left - 1 - static_cast(i)) { - value--; - } - } - continue; - } - device_arrangement_shape.push_back(device_arrangement_origin_.GetDimByIdx(i)); - } - (void)device_arrangement_.Init(device_arrangement_shape); - (void)tensor_map_.Init(tensor_map_shape); - tensor_shape_ = tensor_shape_origin_; -} - -// if idx is not in tensor_map, return -1 -int32_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const { - return tensor_map_.GetIndexByValue(idx); -} - -// tensor_map_.GetDimByIdx(idx) should not be -1 -int32_t TensorLayout::GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const { - return static_cast(device_arrangement_.GetDimSize()) - 1 - tensor_map_.GetDimByIdx(idx); -} - -// tensor_map_.GetDimByIdx(idx) should not be -1 -int32_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint32_t idx) const { - return device_arrangement_.GetDimByIdx(static_cast(GetSliceDeviceDimensionByTensorDimensionIndex(idx))); -} - -std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const { - std::shared_ptr expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape); - if (expanded_arrangement_ptr == nullptr) { - return nullptr; - } - std::shared_ptr temp_tensor_layout_ptr = ExpandDeviceArrangement(*expanded_arrangement_ptr); - if (temp_tensor_layout_ptr == nullptr) { - return nullptr; - } - return temp_tensor_layout_ptr->ExpandTensorShapeWithoutExtendDeviceArrangement(expanded_shape); -} - -/* - * example1: - * in_device_arrangement = [8, 4], - * in_tensor_map = [1, 0], - * in_tensor_shape = [512, 1024], - * out_tensor_shape = [128, 4, 2, 512], - * => - * out_device_arrangement = [8, 2, 2] - */ -std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const { - std::shared_ptr> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape); - if (expand_list_ptr == nullptr) { - return nullptr; - } - std::vector re_map_expand_list; - Arrangement empty_arrangement; - for (int32_t i = static_cast(device_arrangement_.GetDimSize()) - 1; i >= 0; i--) { - if (tensor_map_.GetIndexByValue(i) < 0) { - re_map_expand_list.push_back(empty_arrangement); - } else { - re_map_expand_list.push_back((*expand_list_ptr)[IntToUint(tensor_map_.GetIndexByValue(i))]); - } - } - std::shared_ptr new_arrangement_ptr = - device_arrangement_.GetExpandedShapeByExpandListRemoveLeft(re_map_expand_list); - return new_arrangement_ptr; -} - -/* - * example1: - * in_device_arrangement = [8, 4], - * in_tensor_map = [1, 0], - * in_tensor_shape = [512, 1024], - * out_tensor_shape = [8, 64, 4, 256] - * => - * out_device_arrangement = [8, 4], - * out_tensor_map = [1, -1, 0, -1], - */ -std::shared_ptr TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement &expanded_shape) const { - std::shared_ptr, Arrangement>> expand_list_pair_ptr = - tensor_shape_.GetExpandShapeListPair(expanded_shape); - if (expand_list_pair_ptr == nullptr) { - return nullptr; - } - std::shared_ptr tensor_map_new_ptr = tensor_map_.ExpandMapByNone(expand_list_pair_ptr->second); - if (tensor_map_new_ptr == nullptr) { - return nullptr; - } - TensorLayout tensor_layout_new; - Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(tensor_layout_new); -} - -/* - * example1: - * in_device_arrangement = [8, 4], - * in_tensor_map = [1, 0], - * in_tensor_shape = [512, 1024], - * out_device_arrangement = [4, 2, 2, 2] - * => - * out_tensor_map = [3, 2, 1, 0], - * out_tensor_shape = [4, 128, 2, 512] - * - * example2: - * in_device_arrangement = [8, 4], - * in_tensor_map = [0, 1], - * in_tensor_shape = [512, 1024], - * out_device_arrangement = [4, 2, 2, 2] - * => - * out_tensor_map = [1, 0, 3, 2], - * out_tensor_shape = [2, 256, 4, 256] - * - * example3: - * in_device_arrangement = [8, 4], - * in_tensor_map = [1, -1], - * in_tensor_shape = [512, 1024], - * out_device_arrangement = [4, 2, 2, 2] - * => - * out_tensor_map = [3, 2, -1], - * out_tensor_shape = [4, 128, 1024] - * - * example4: - * in_device_arrangement = [8, 4], - * in_tensor_map = [0, 1], - * in_tensor_shape = [512, 1024], - * out_device_arrangement = [4, 2, 4] - * => - * out_tensor_map = [0, 2, 1], - * out_tensor_shape = [512, 4, 256] - */ -std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const { - std::shared_ptr, Arrangement>> expand_list_pair_ptr = - device_arrangement_.GetExpandShapeListPair(expanded_arrangement); - if (expand_list_pair_ptr == nullptr) { - return nullptr; - } - std::shared_ptr tensor_map_new_ptr = tensor_map_.ExpandMapByDecreaseNumber(expand_list_pair_ptr->second); - if (tensor_map_new_ptr == nullptr) { - return nullptr; - } - std::shared_ptr> re_map_shape_list_ptr = - tensor_map_.ReMapVector(expand_list_pair_ptr->first); - if (re_map_shape_list_ptr == nullptr) { - return nullptr; - } - std::shared_ptr tensor_shape_new_ptr = - tensor_shape_.GetExpandedShapeByExpandListReserveLeft(*re_map_shape_list_ptr); - if (tensor_shape_new_ptr == nullptr) { - return nullptr; - } - TensorLayout tensor_layout_new; - Status status = tensor_layout_new.Init(expanded_arrangement, *tensor_map_new_ptr, *tensor_shape_new_ptr); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(tensor_layout_new); -} - -bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { - std::vector in_expand_shape_shape; - Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); - if (status != Status::SUCCESS) { - return false; - } - return (in_expand_shape_shape == tensor_shape_.array()); -} - -std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { - std::vector in_expand_shape_shape; - Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - Arrangement expanded_shape; - status = expanded_shape.Init(in_expand_shape_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(expanded_shape); -} - -Arrangement TensorLayout::slice_shape() const { - std::vector shape; - for (uint32_t index = 0; index < tensor_map_.GetDimSize(); index++) { - int32_t dim = tensor_map_.GetDimByIdx(index); - int32_t num = tensor_shape_.GetDimByIdx(index); - if (dim == -1) { - shape.push_back(num); - } else { - int32_t divisor = device_arrangement_.GetDimByReverseIdx(IntToUint(dim)); - shape.push_back(num / divisor); - } - } - Arrangement new_tensor_shape; - if (new_tensor_shape.Init(shape) == Status::FAILED) { - ValuePtr ptr = MakeValue(shape); - MS_LOG(EXCEPTION) << "Can't get slice shape when initialize a new shape " << ptr->ToString(); - } else { - return new_tensor_shape; - } -} - -Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { - if (index >= tensor_map_.GetDimSize()) { - MS_LOG(ERROR) << "Index is out of the size of the tensor map!"; - return Status::FAILED; - } - auto shape = tensor_map_.array(); - shape[index] = value; - if (tensor_map_.Init(shape) == Status::FAILED) { - MS_LOG(ERROR) << "Update tensor map failed!"; - return Status::FAILED; - } - return Status::SUCCESS; -} - -bool TensorLayout::operator==(const TensorLayout &t1) const { - return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); -} - -/* - * remove elements equal to 1 in tensor_shape, if all elements are 1, squeeze the tensor_shape to [ 1 ] - * example 1: - * original tensor layout: - * device arrangement = [ 8 ] - * tensor map = [ 0 -1 -1 -1 ] - * tensor shape = [ 128 64 1 1 ] - * return tensor layout: - * device arrangement = [ 8 ] - * tensor map = [ 0 -1 ] - * tensor shape = [ 128 64 ] - * - * example 2: - * device arrangement = [ 8 ] - * tensor map = [ -1 -1 -1 -1 ] - * tensor shape = [ 1 1 1 1 ] - * return tensor layout: - * device arrangement = [ 8 ] - * tensor map = [ -1 ] - * tensor shape = [ 1 ] - */ -TensorLayout TensorLayout::SqueezeShape() const { - TensorLayout out; - Map out_map; - Arrangement out_shape; - if (tensor_shape_.size() == 1) { - (void)out_map.Init({MAP_NONE}); - (void)out_shape.Init({1}); - (void)out.Init(device_arrangement_, out_map, out_shape); - return out; - } - std::vector squeeze_list = tensor_shape_.GetSqueezeIdx(); - if (!tensor_map_.CheckNoneByIdxList(squeeze_list)) { - MS_LOG(ERROR) << "CheckNoneByIdxList failed, this may not happen under current situation"; - return *this; - } - out_shape = tensor_shape_.GetSqueezeArrangement(); - out_map = tensor_map_.SqueezeMapByIdxList(squeeze_list); - (void)out.Init(device_arrangement_, out_map, out_shape); - return out; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h deleted file mode 100644 index f51ed4e3e0..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ - -#include -#include -#include -#include -#include -#include "parallel/device_manager.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/arrangement.h" -#include "parallel/tensor_layout/map.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace parallel { -class TensorLayout { - public: - TensorLayout() = default; - ~TensorLayout() = default; - std::string ToString() const; - std::string StandardToString() const; - std::string OriginToString() const; - Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); - Status InitFromVector(const std::vector &device_arrangement, const std::vector &tensor_map, - const std::vector &tensor_shape); - - Arrangement device_arrangement() const { return device_arrangement_; } - - Map tensor_map() const { return tensor_map_; } - - Arrangement tensor_shape() const { return tensor_shape_; } - - Map origin_tensor_map() const { return tensor_map_origin_; } - - std::shared_ptr ExpandTensorShape(const Arrangement &expanded_shape) const; - - std::shared_ptr ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const; - - bool IsSameTensorShape(const TensorLayout &tensor_layout) const { - return (tensor_shape_ == tensor_layout.tensor_shape()); - } - - bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const { - return (device_arrangement_ == tensor_layout.device_arrangement()); - } - - bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } - - bool operator==(const TensorLayout &t1) const; - - bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const; - - std::shared_ptr ComputeExpandedTensorShape(const Arrangement &expand_shape) const; - - Arrangement slice_shape() const; - - Status UpdateTensorMap(uint32_t index, int32_t value); - - TensorLayout SqueezeShape() const; - - private: - std::shared_ptr ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement &expanded_shape) const; - std::shared_ptr ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const; - bool IsValidTensorLayout() const; - void RemoveElementEqualToOneInDeviceArrangement(); - int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const; - int32_t GetSliceNumByTensorDimensionIndex(uint32_t idx) const; - bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const; - int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const; - - Arrangement device_arrangement_origin_; - Map tensor_map_origin_; - Arrangement tensor_shape_origin_; - Arrangement device_arrangement_; - Map tensor_map_; - Arrangement tensor_shape_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc deleted file mode 100644 index 7824c21f3d..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc +++ /dev/null @@ -1,209 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/tensor_redistribution.h" -#include -#include -#include -#include "common/utils.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/shape_util.h" - -namespace mindspore { -namespace parallel { -Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) { - from_origin_ = from; - to_origin_ = to; - if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) { - MS_LOG(ERROR) << "from shape size must be equal to to shape size!"; - MS_LOG(ERROR) << "reshape from_origin_ " << from_origin_.ToString(); - MS_LOG(ERROR) << "reshape to_origin_ " << to_origin_.ToString(); - return Status::FAILED; - } - - dev_list_ = dev_list; - from_ = from_origin_.SqueezeShape(); - to_ = to_origin_.SqueezeShape(); - return Status::SUCCESS; -} - -RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) { - // Step 1: Match device arrangement between from_ and to_ - RedistributionLayoutTransfer layout_transfer; - Status status = layout_transfer.Init(from_, to_); - if (status != Status::SUCCESS) { - return nullptr; - } - std::shared_ptr ptr = layout_transfer.UnifyDeviceArrangementAndTensorShape(); - if (ptr == nullptr) { - MS_LOG(ERROR) << "Infer tensor layout return nullptr!"; - return nullptr; - } - TensorLayout from_layout = ptr->from_in(); - TensorLayout to_layout = ptr->to_in(); - MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString(); - MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString(); - MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString(); - MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString(); - MS_LOG(DEBUG) << "reshape from_ " << from_.ToString(); - MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); - // Step 2: Infer redistribution and insert operators - RedistributionOperatorInfer operator_infer(construct_op_flag_); - if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { - MS_LOG(ERROR) << "Init operatorInfer failed!"; - return nullptr; - } - OperatorVector operator_vector; - OutPutInfoVector output_info_vector; - if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) { - MS_LOG(ERROR) << "Infer redistribution failed!"; - return nullptr; - } else { - operator_vector = operator_infer.operator_vector(); - output_info_vector = operator_infer.output_info_vector(); - operator_list_ = operator_infer.operator_list(); - } - - // Step 3: Infer reshape and insert operators - if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) { - MS_LOG(ERROR) << "Construct Reshape operator failed!"; - return nullptr; - } - - return std::make_shared>( - std::make_pair(operator_vector, output_info_vector)); -} - -Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, - OperatorVector *const operator_vector, - OutPutInfoVector *const output_info_vector) { - MS_EXCEPTION_IF_NULL(operator_vector); - MS_EXCEPTION_IF_NULL(output_info_vector); - ConstructOperator constructor; - if (operator_list_.empty()) { - if (from_origin_.slice_shape().array() != to_origin_.slice_shape().array() || keep_reshape_) { - reshape_flag_ = true; - constructor.UpdateTensorShape(from_origin_.slice_shape().array()); - Arrangement shape = to_origin_.slice_shape(); - MS_LOG(DEBUG) << "reshape " << shape.ToString(); - if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { - return Status::FAILED; - } else { - (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator()); - (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0)); - } - } - return Status::SUCCESS; - } - - if (from_origin_.slice_shape().array() != from_layout.slice_shape().array()) { - reshape_flag_ = true; - constructor.UpdateTensorShape(from_origin_.slice_shape().array()); - Arrangement shape = from_layout.slice_shape(); - MS_LOG(DEBUG) << "reshape " << shape.ToString(); - if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { - return Status::FAILED; - } else { - (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator()); - (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0)); - } - } - - if (to_origin_.slice_shape().array() != to_layout.slice_shape().array()) { - reshape_flag_ = true; - constructor.UpdateTensorShape(to_layout.slice_shape().array()); - Arrangement shape = to_origin_.slice_shape(); - MS_LOG(DEBUG) << "step_parallel to reshape " << shape.ToString(); - if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { - return Status::FAILED; - } else { - (void)operator_vector->insert(operator_vector->end(), constructor.GetOperator()); - (void)output_info_vector->insert(output_info_vector->end(), std::make_pair(false, 0)); - } - } - return Status::SUCCESS; -} - -Status TensorRedistribution::ComputeCost() { - RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true); - if (redistribution_oplist_ptr == nullptr) { - MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; - return Status::FAILED; - } - // Compute redistribution communication cost and computation cost - for (auto &op_cost : operator_list_) { - OperatorR op = op_cost.first; - Shape slice_shape = op_cost.second; - double prod = - std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast(1.0), std::multiplies()); - std::string str = op.first; - if (str == PERMUTE_BY_AXIS) { - // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost. - // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape - forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; - backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; - comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR; - int32_t concat_dim = op.second[2]; - if (concat_dim == 0) { - // memory cost = all_gather - computation_cost_ += prod; - memory_cost_ += prod; - } else { - // memory cost = all_gather + split + concat - int32_t dev_num = op.second[4]; - computation_cost_ += (prod + prod * dev_num + prod * dev_num); - memory_cost_ += (prod * dev_num + prod * dev_num + prod); - } - } else if (str == CONCAT_BY_AXIS) { - // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape - // computation cost = before_slice_shape - if (op.second.size() < 3) { - MS_LOG(ERROR) << "op.second size should not be less than 3!"; - return Status::FAILED; - } - double dev_num = op.second[2]; - // here, communication cost = all_gather + reduce_scatter - forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; - backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; - comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; - int32_t concat_dim = op.second[0]; - if (concat_dim == 0) { - // computation cost = all_gather - computation_cost_ += prod; - memory_cost_ += prod * dev_num; - } else { - // computation cost = all_gather + split + concat - computation_cost_ += (prod + prod * dev_num + prod * dev_num); - memory_cost_ += (prod * dev_num + prod * dev_num + prod); - } - } else { - // There is only computation cost in SplitByAxis. - // computation cost = before_slice_shape - computation_cost_ += prod; - // This addtion may be erroneous - memory_cost_ += prod; - } - } - if (reshape_flag()) { - Shape prev_slice_shape = from_.slice_shape().array(); - double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies()); - computation_cost_ += 2.0 * prev_prod; - memory_cost_ += 2.0 * prev_prod; - } - return Status::SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h deleted file mode 100644 index d1f46108bb..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h +++ /dev/null @@ -1,90 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ - -#include -#include -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/construct_operator.h" -#include "parallel/tensor_layout/redistribution_operator_infer.h" -#include "parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -constexpr double ALLTOALL_SCALE_FACTOR = 2.0; -constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5; -class TensorRedistribution { - public: - explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) - : reshape_flag_(false), - comm_cost_(0.0), - forward_comm_cost_(0.0), - backward_comm_cost_(0.0), - computation_cost_(0.0), - memory_cost_(0.0), - construct_op_flag_(construct_op_flag), - keep_reshape_(keep_reshape) {} - Status Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list); - ~TensorRedistribution() = default; - RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); - OperatorList operator_list() const { return operator_list_; } - bool reshape_flag() const { return reshape_flag_; } - Status ComputeCost(); - double comm_cost() const { return comm_cost_; } - double computation_cost() const { return computation_cost_; } - double forward_comm_cost() const { return forward_comm_cost_; } - double backward_comm_cost() const { return backward_comm_cost_; } - double memory_cost() const { return memory_cost_; } - - private: - Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, - OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector); - - TensorLayout from_origin_; - TensorLayout to_origin_; - TensorLayout from_; - TensorLayout to_; - RankList dev_list_; - OperatorList operator_list_; - bool reshape_flag_; - // communication cost, which is the sum of forward communication cost and backward communication cost - double comm_cost_; - // forward communication cost - double forward_comm_cost_; - // backward communication cost - double backward_comm_cost_; - // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the - // inputs. This is calculated ONLY for forward phase. - double computation_cost_; - // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is - // calculated by the outputs. - double memory_cost_; - bool construct_op_flag_; - bool keep_reshape_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ diff --git a/mindspore/ccsrc/pipeline/CMakeLists.txt b/mindspore/ccsrc/pipeline/CMakeLists.txt deleted file mode 100644 index 39664d717d..0000000000 --- a/mindspore/ccsrc/pipeline/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "pipeline.cc" - "resource.cc" - "pass.cc" - "action.cc" - "validator.cc" - "remove_value_node_dup.cc" - "parse/*.cc" - "static_analysis/*.cc" -) - - -file(GLOB PIPELINE_SRC_FILES "*.cc") -set_property(SOURCE ${PIPELINE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) - -file(GLOB_RECURSE PARSER_SRC_FILES "parse/*.cc") -set_property(SOURCE ${PARSER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARSER) - -file(GLOB_RECURSE ANALYZER_SRC_FILES "static_analysis/*.cc") -set_property(SOURCE ${ANALYZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) - -if (ENABLE_GE OR ENABLE_D) - file(GLOB_RECURSE _PIPELINE_GE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pipeline_ge.cc") - list(APPEND _PIPELINE_SRC_FILES ${_PIPELINE_GE_SRC_FILES}) -endif () - -add_library(_mindspore_pipeline_obj OBJECT ${_PIPELINE_SRC_FILES}) diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc deleted file mode 100644 index 89598ae85d..0000000000 --- a/mindspore/ccsrc/pipeline/action.cc +++ /dev/null @@ -1,498 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/action.h" - -#include -#include -#include -#include -#include -#include - -#include "ir/func_graph_cloner.h" -#include "ir/param_value_py.h" -#include "parallel/costmodel_context.h" -#include "parallel/context.h" -#include "pipeline/pass.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "pipeline/static_analysis/program_specialize.h" -#include "pipeline/resource.h" -#include "utils/context/ms_context.h" -#include "pipeline/remove_value_node_dup.h" -#include "optimizer/optimizer.h" -#include "vm/transform.h" -#include "parse/python_adapter.h" -#include "optimizer/py_pass_manager.h" - -namespace mindspore { -namespace pipeline { -using CompileGraphs = compile::CompileGraphs; -using abstract::AnalysisResult; -using mindspore::abstract::AnalysisContextPtr; - -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AbstractBasePtrList &args_spec, bool clear) { - MS_LOG(DEBUG) << "AbstractAnalyze start"; - auto engine = res->engine(); - MS_EXCEPTION_IF_NULL(engine); - if (clear) { - auto manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); - engine->Clear(); - for (auto &node : manager->all_nodes()) { - MS_EXCEPTION_IF_NULL(node); - const AbstractBasePtr &prev_inferred = node->abstract(); - // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction. - if (!node->isa() || (prev_inferred != nullptr && prev_inferred->isa())) { - node->set_abstract(nullptr); - MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr"; - } - } - } - auto ret = engine->Run(func_graph, args_spec); - MS_LOG(DEBUG) << "AbstractAnalyze end"; - return ret; -} - -FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AnalysisContextPtr &context) { - MS_LOG(DEBUG) << "ProgramSpecialize start"; - abstract::ProgramSpecializer spc(res->engine()); - FuncGraphPtr result = spc.Run(func_graph, context); - auto manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->KeepRoots({result}); - MS_LOG(DEBUG) << "ProgramSpecialize end"; - return result; -} - -FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AbstractBasePtrList &args_spec) { - MS_LOG(DEBUG) << "Renormalize start"; -#ifdef ENABLE_PROFILE - double t1 = GetTime(); -#endif - abstract::AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec, true); -#ifdef ENABLE_PROFILE - double t2 = GetTime(); -#endif - auto ret = ProgramSpecialize(res, func_graph, result.context); - res->set_func_graph(ret); -#ifdef ENABLE_PROFILE - double t3 = GetTime(); - MsProfile::StatTime("renormalize.infer", t2 - t1); - MsProfile::StatTime("renormalize.specialize", t3 - t2); -#endif - MS_LOG(DEBUG) << "Renormalize end"; - return ret; -} - -bool ParseAction(const ResourcePtr &res) { - if (!res->input()) { - MS_LOG(EXCEPTION) << "Parse error"; - } - - py::object input = res->input(); - parse::Parser::InitParserEnvironment(input); - py::module path = py::module::import("os.path"); - std::string dir = path.attr("dirname")(py::globals()["__file__"]).cast(); - - parse::python_adapter::set_python_env_flag(true); - parse::python_adapter::SetPythonPath(dir); - - FuncGraphPtr fg = parse::ConvertToFuncGraph(input); - if (fg == nullptr) { - MS_LOG(EXCEPTION) << "Parse error."; - } - res->set_func_graph(fg); - - FuncGraphManagerPtr manager = res->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Manager is nullptr."; - } - manager->AddFuncGraph(fg); - return true; -} - -// obj_map's graphs have the same construct, these graphs can be optimized to one graph. -// This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}-> -// graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx} -// all obj_map's graph shared base_graph -bool CombineLikeGraphs(const ResourcePtr &res) { - auto &obj_map = parse::data_converter::GetObjGraphs(); - - for (auto it : obj_map) { - auto &graphs = it.second; - MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size(); - auto fg = graphs[0]; - FuncGraphPtrList func_graphs = {fg}; - ClonerPtr cloner = std::make_shared(func_graphs, false, false, true, std::make_shared(), - std::make_shared()); - cloner->Run(); - auto base_graph = cloner->cloned_func_graph()[fg]; - MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString(); - - if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) { - continue; - } - for (auto &fv : fg->paramter_obj_nodes()) { - TraceManager::DebugTrace(std::make_shared(fv->debug_info())); - auto param = base_graph->add_parameter(); - TraceManager::EndTrace(); - auto &node_users = res->manager()->node_users()[fv]; - for (auto &n : node_users) { - auto repl_n = (*cloner->cloned_node())[n.first]->cast(); - repl_n->set_input(n.second, param); - } - } - MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size(); - - for (auto &g : graphs) { - auto fvs = g->paramter_obj_nodes(); - std::vector new_node_inputs; - new_node_inputs.push_back(NewValueNode(base_graph)); - for (auto &p : g->parameters()) { - AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p); - new_node_inputs.push_back(para_after_cast); - } - (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end()); - AnfNodePtr out = g->NewCNode(new_node_inputs); - g->set_output(out); - MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(4); - } - MS_LOG(DEBUG) << "End combine graph:" << it.first; - } - return true; -} - -bool SymbolResolveAction(const ResourcePtr &res) { - if (res->manager() == nullptr) { - MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; - } - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null"; - } - FuncGraphPtr func_graph = res->func_graph(); - auto succ = parse::ResolveFuncGraph(func_graph, res); - - // Remove unused nodes in cnode order list. - func_graph->EraseUnusedNodeInOrder(); - func_graph->ReleaseFullOrderToEffectOrder(); - for (auto fg : func_graph->func_graphs_used_total()) { - MS_EXCEPTION_IF_NULL(fg); - fg->EraseUnusedNodeInOrder(); - fg->ReleaseFullOrderToEffectOrder(); - } - return succ; -} - -bool InferenceOptPrepareAction(const ResourcePtr &res) { - if (res->manager() == nullptr) { - MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; - } - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null."; - } - return InferenceOptPreparePass(res); -} - -bool AbstractSpecializeAction(const ResourcePtr &res) { - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "AbstractSpecialize error"; - } - - FuncGraphPtr func_graph = res->func_graph(); - abstract::AbstractBasePtrList args_spec = res->args_spec(); - - parallel::ParallelParameterContextInit(func_graph); - - // suppose that there is not KeywordArgument for the top graph - // get the hyper parameter - for (const auto ¶m : func_graph->parameters()) { - auto param_node = std::static_pointer_cast(param); - if (param_node->has_default()) { - auto param_value = std::dynamic_pointer_cast(param_node->default_param()); - AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true); - auto sparse_grad = - py::cast(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad")); - ptr->set_sparse_grad(sparse_grad); - auto has_indexed_slices_grad = - py::cast(parse::python_adapter::GetPyObjAttr(param_value->value(), "has_indexed_slices_grad")); - ptr->set_has_indexed_slices_grad(has_indexed_slices_grad); - - parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); - args_spec.push_back(ptr); - parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr); - } - } - // Analyze - AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec); - // The top graph may be replaced by infer, update the top graph when the infer is done - parse::Parser::UpdateTopFuncGraph(result.context->func_graph()); - - // Specialize - FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context); - res->set_func_graph(new_fg); - - MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true); - return true; -} - -bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) { - size_t counter = 0; - for (auto &pass : passes) { - WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() { - MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; - auto result = pass.second(res); - if (!result) { - MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; - } - if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) { - auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; - auto func_graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - func_graph->DumpFuncGraph(fg_name); - DumpIR(fg_name + ".ir", func_graph); - MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; - } - counter++; - MS_LOG(DEBUG) << "Pass " << pass.first << " end."; - }; - } - - return true; -} - -bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } - -bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } - -bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } - -static bool IsCtrlSink() { - auto ms_ctx = MsContext::GetInstance(); - if (ms_ctx->execution_mode() != kGraphMode) { - return false; - } - - std::string device_target = ms_ctx->device_target(); - if (device_target != kAscendDevice) { - return false; - } - - if (!ms_ctx->enable_task_sink()) { - return false; - } - - if (!ms_ctx->is_multi_graph_sink()) { - return false; - } - return true; -} - -bool TaskEmitAction(const ResourcePtr &res) { - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "TaskEmit args error"; - } - FuncGraphPtr func_graph = res->func_graph(); - auto bc_ptr = res->results()[kBackend].cast(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (CompileGraphs::ContainMixedTarget(func_graph)) { - bc_ptr->set_is_multi_graph_sink(false); - context_ptr->set_is_multi_graph_sink(false); - context_ptr->set_loop_sink_flag(false); - } else if (context_ptr->execution_mode() != kPynativeMode) { - std::string device_target = context_ptr->device_target(); - if (device_target == kAscendDevice) { - bc_ptr->set_is_multi_graph_sink(true); - context_ptr->set_is_multi_graph_sink(true); - } - } - - if (IsCtrlSink()) { - res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); - return true; - } - std::vector cut_list = compile::nonlinear_ops; - if (bc_ptr->name() == kMsConvert) { - cut_list = compile::GetMsNonlinearOps(); - } - std::shared_ptr compile = std::make_shared(bc_ptr, cut_list); - res->results()[kOutput] = compile->CompileAndLink(func_graph); - return true; -} - -bool ExecuteAction(const ResourcePtr &res) { - if (res->results().count(kOutput) == 0) { - MS_LOG(EXCEPTION) << "Execute args error"; - } - - if (IsCtrlSink()) { - if (!res->results()[kOutput].is()) { - MS_LOG(EXCEPTION) << "Execute args error"; - } - auto graph_id = res->results()[kOutput].cast(); - std::shared_ptr bc_ptr = res->results()[kBackend].cast>(); - std::shared_ptr msbc_ptr = std::dynamic_pointer_cast(bc_ptr); - MS_EXCEPTION_IF_NULL(msbc_ptr); - compile::VmEvalFuncPtr run = - std::make_shared([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef { - MS_LOG(INFO) << "Execute args size " << args.size(); - auto outs = msbc_ptr->RunGraph(graph_id, args); - MS_LOG(DEBUG) << "out size " << outs.size(); - return outs[0]; - }); - res->results()[kOutput] = run; - return true; - } - - if (!res->results()[kOutput].is()) { - MS_LOG(EXCEPTION) << "Execute args error"; - } - compile::FinalVMPtr vm = res->results()[kOutput].cast(); - if (vm == nullptr) { - MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM"; - return true; - } - compile::VmEvalFuncPtr run = - std::make_shared(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1)); - res->results()[kOutput] = run; - return true; -} - -// The parallel primitive related valuenode might be partitioned so that its value changes by device, -// that will result in a syncronization error due to different executing order. -// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, -// the final solution will be proposed later as a parallel feature. -bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) { - auto &node_users = res->manager()->node_users(); - auto &users = node_users[value_node]; - auto used_by_keep_value_prim = - std::any_of(users.begin(), users.end(), [](const std::pair &user) -> bool { - MS_EXCEPTION_IF_NULL(user.first); - auto cnode = user.first->cast(); - if (cnode == nullptr) { - return false; - } - auto prim_node = cnode->input(0); - if (IsValueNode(prim_node)) { - auto prim = GetValue(prim_node->cast()->value()); - // value_node is referenced by some parallel primitive - return prim->HasAttr("keep_value_node_input"); - } - return false; - }); - return used_by_keep_value_prim; -} - -bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "Remove value node duplications error."; - } - FuncGraphPtr func_graph = res->func_graph(); - auto manager = res->manager(); - // Remove duplicated value nodes, due to replace operation, can't use reference. - auto value_nodes = func_graph->value_nodes(); - HashCache hash_cache; - HashValue hashes; - for (const auto &value_pair : value_nodes) { - if (KeepValueNodeDuplication(value_pair.first, res)) { - continue; - } - TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes); - } - return true; -} - -bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } - -void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { - MS_EXCEPTION_IF_NULL(res->manager()); - MS_EXCEPTION_IF_NULL(res->func_graph()); - auto ppm = opt::python_pass::PyPassManager::GetInstance(); - if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { - MS_LOG(DEBUG) << "No match.\n"; - } -} - -bool ResolveActionPyStub(const ResourcePtr &res) { - ActionPyStub(res, opt::python_pass::Phase::RESOLVE); - return true; -} - -bool OptActionPyStub(const ResourcePtr &res) { - ActionPyStub(res, opt::python_pass::Phase::RESOLVE); - return true; -} - -static std::vector CommonPipeline() { - std::vector actions; - - // Parse the python ast to ANF graph - actions.emplace_back(std::make_pair("parse", ParseAction)); - - // Resolve the python func - actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); - auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs(); - if (!multi_graphs) { - actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); - } - // Add resolve-stage python pass stub - actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); - actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); - // Evaluate type and shape, and specialize - actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); - - return actions; -} - -std::vector GePipeline() { - auto actions = CommonPipeline(); - // optimize - actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); - // Add opt-stage python pass stub - actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); - actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); - actions.emplace_back(std::make_pair("validate", ValidateAction)); - return actions; -} - -std::vector VmPipeline() { - auto actions = CommonPipeline(); - - // optimize - actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); - - // Add opt-stage python pass stub - actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); - - actions.emplace_back(std::make_pair("validate", ValidateAction)); - - // compile the ANF graph - actions.emplace_back(std::make_pair("task_emit", TaskEmitAction)); - - // to execute the graph - actions.emplace_back(std::make_pair("execute", ExecuteAction)); - - return actions; -} -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/action.h b/mindspore/ccsrc/pipeline/action.h deleted file mode 100644 index eed1307872..0000000000 --- a/mindspore/ccsrc/pipeline/action.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_ACTION_H_ -#define MINDSPORE_CCSRC_PIPELINE_ACTION_H_ - -#include -#include -#include -#include -#include "pipeline/resource.h" -#include "vm/segment_runner.h" - -namespace mindspore { -extern const char kMsConvert[]; - -namespace pipeline { -using ActionItem = std::pair>; - -bool ParseAction(const ResourcePtr &res); -bool SymbolResolveAction(const ResourcePtr &res); -bool AbstractSpecializeAction(const ResourcePtr &res); -bool GeOptimizeAction(const ResourcePtr &res); -bool VmOptimizeAction(const ResourcePtr &res); -bool PynativeOptimizeAction(const ResourcePtr &res); -bool TaskEmitAction(const ResourcePtr &res); -bool ExecuteAction(const ResourcePtr &res); - -std::vector GePipeline(); -std::vector VmPipeline(); -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AbstractBasePtrList &args_spec, bool clear = false); -FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AnalysisContextPtr &context); -FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AbstractBasePtrList &args_spec); -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_ACTION_H_ diff --git a/mindspore/ccsrc/pipeline/base.h b/mindspore/ccsrc/pipeline/base.h deleted file mode 100644 index 57edea03a2..0000000000 --- a/mindspore/ccsrc/pipeline/base.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_BASE_H_ -#define MINDSPORE_CCSRC_PIPELINE_BASE_H_ - -#include -#include -#include -#include - -#include "ir/anf.h" -#include "pipeline/resource.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace pipeline { -struct ExecutorInfo { - FuncGraphPtr func_graph; - ResourcePtr resource; - std::size_t arg_list_size; -}; -using ExecutorInfoPtr = std::shared_ptr; - -inline std::string GetPhasePrefix(const std::string &phase) { - auto pos = phase.find('.'); - if (pos == std::string::npos) { - MS_LOG(EXCEPTION) << "Phase has no . for prefix" << phase; - } - return phase.substr(0, pos); -} - -inline std::string GetFilePathName(const std::string &file_name) { - std::ostringstream oss; - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(EXCEPTION) << "ms_context is nullptr"; - } - auto save_graphs_path = ms_context->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - oss << save_graphs_path << "/" << file_name; - return oss.str(); -} -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc deleted file mode 100644 index f28be181dd..0000000000 --- a/mindspore/ccsrc/pipeline/init.cc +++ /dev/null @@ -1,343 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "kernel/oplib/oplib.h" -#include "kernel/oplib/oploader.h" -#include "pipeline/pipeline.h" -#include "operator/composite/composite.h" -#include "ir/signature.h" -#include "pynative/pynative_execute.h" -#include "utils/symbolic.h" -#include "pybind_api/api_register.h" -#include "pipeline/parse/python_adapter.h" -#include "utils/summary/event_writer.h" -#include "utils/config_manager.h" -#include "utils/mpi/mpi_config.h" -#include "parallel/context.h" -#include "parallel/device_manager.h" -#include "parallel/costmodel_context.h" -#ifdef ENABLE_GPU_COLLECTIVE -#include "device/gpu/distribution/collective_init.h" -#else -#include "device/gpu/distribution/collective_fake_init.h" -#endif -namespace py = pybind11; - -using FuncGraph = mindspore::FuncGraph; -using EnvInstance = mindspore::EnvInstance; -using ExecutorPy = mindspore::pipeline::ExecutorPy; -using Pipeline = mindspore::pipeline::Pipeline; -using PrimitivePy = mindspore::PrimitivePy; -using MetaFuncGraph = mindspore::MetaFuncGraph; -using EventWriter = mindspore::summary::EventWriter; -using OpLib = mindspore::kernel::OpLib; -using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy; -using ParallelContext = mindspore::parallel::ParallelContext; -using CostModelContext = mindspore::parallel::CostModelContext; - -// Interface with python -PYBIND11_MODULE(_c_expression, m) { - m.doc() = "MindSpore c plugin"; - - (void)py::class_>(*m, "MetaFuncGraph_") - .def_readonly(mindspore::PYTHON_METAFUNCGRAPH_FLAG, &mindspore::MetaFuncGraph::parse_info_) - .def(py::init()); - - auto fns = mindspore::PybindDefineRegister::AllFuncs(); - for (auto &item : fns) { - item.second(&m); - } - - // Class Pipeline interface - (void)py::class_>(m, "Executor_") - .def_static("get_instance", &ExecutorPy::GetInstance, "Executor get_instance.") - .def("__call__", &ExecutorPy::Run, py::arg("args"), py::arg("phase") = py::str(""), "Executor run function.") - .def("del_net_res", &ExecutorPy::DelNetRes, py::arg("network_id") = py::str(""), "Delete network resource.") - .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.") - .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""), - py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") - .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), - py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") - .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"), - "Get Parameter Tensor Layout Dictionary.") - .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"), - "Get CNode Strategy Dictionary.") - .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"), - "Get Allreduce Fusion Dictionary.") - .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"), - "Fetch the inputs of Conv or Matmul for quant export.") - .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"), - py::arg("broadcast_params") = py::dict(), "Build data graph.") - .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") - .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); - // Class Graph interface - (void)py::class_(m, "FuncGraph").def(py::init()); - - (void)py::class_>(m, "EnvInstance_") - .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) - .def(py::init()); - - (void)m.def("generate_key", &mindspore::pipeline::GenerateKey, "Generate the function graph key."); - (void)m.def("real_run_op", &mindspore::pynative::RunOp, "Run op pynatively."); - (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id"); - (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl"); - (void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl"); - (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature."); - (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"), - py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"), - py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset."); - (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode."); - (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); - - (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); - - (void)py::class_>(m, "MSContext") - .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") - .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") - .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.") - .def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.") - .def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.") - .def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.") - .def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.") - .def("get_device_target", &mindspore::MsContext::device_target, "Get device target.") - .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") - .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") - .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") - .def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.") - .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.") - .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") - .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") - .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, - "Get whether to enable auto mixed precision.") - .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag, - "Set whether to enable auto mixed precision.") - .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision, - "Get whether to enable reduce precision.") - .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision, - "Set whether to enable reduce precision.") - .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.") - .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.") - .def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.") - .def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.") - .def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.") - .def("set_save_ms_model_path", &mindspore::MsContext::set_save_ms_model_path, "Set path to save ms model") - .def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.") - .def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.") - .def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.") - .def("set_save_dump_path", &mindspore::MsContext::set_save_dump_path, "Set path to dump.") - .def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "set graph memory max size.") - .def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size, - "set variable memory max size") - .def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.") - .def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.") - .def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.") - .def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.") - .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.") - .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") - .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") - .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") - .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") - .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, - "Set the GraphKernel switch to on or off.") - .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.") - .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.") - .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse."); - - (void)py::class_>(m, "MpiConfig") - .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") - .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.") - .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi."); - - (void)py::class_>(m, "AutoParallelContext") - .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.") - .def("get_device_num", &ParallelContext::device_num, "Get device num.") - .def("set_device_num", &ParallelContext::set_device_num, "Set device num.") - .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.") - .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.") - .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.") - .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.") - .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.") - .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.") - .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.") - .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.") - .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.") - .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") - .def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.") - .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.") - .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") - .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") - .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") - .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") - .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, - "Set all reduce fusion split indices.") - .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices, - "Get all reduce fusion split indices.") - .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes, - "Set all reduce fusion split sizes.") - .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes, - "Get all reduce fusion split sizes.") - .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion, - "Set enable/disable all reduce fusion.") - .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion, - "Get enable/disable all reduce fusion.") - .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.") - .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set, - "Get parameter broadcast is set.") - .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.") - .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file, - "Set strategy checkpoint load file.") - .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file, - "Set strategy checkpoint save file.") - .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") - .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") - .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") - .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") - .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, - "Set enable/disable parallel optimizer.") - .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, - "Get enable/disable parallel optimizer.") - .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); - - (void)py::class_>(m, "CostModelContext") - .def_static("get_instance", &CostModelContext::GetInstance, "Get cost_model context instance.") - .def("set_device_memory_capacity", &CostModelContext::set_device_memory_capacity, - "Set the capacity of device memory.") - .def("get_device_memory_capacity", &CostModelContext::device_memory_capacity, "Get the capacity of device memory.") - .def("set_costmodel_alpha", &CostModelContext::set_costmodel_alpha, - "Set the parameter cost_model_alpha of the DP algorithm.") - .def("get_costmodel_alpha", &CostModelContext::costmodel_alpha, - "Get the parameter cost_model_alpha of the DP algorithm.") - .def("set_costmodel_beta", &CostModelContext::set_costmodel_beta, - "Set the parameter cost_model_beta of the DP algorithm.") - .def("get_costmodel_beta", &CostModelContext::costmodel_beta, - "Get the parameter cost_model_beta of the DP algorithm.") - .def("set_costmodel_gamma", &CostModelContext::set_costmodel_gamma, - "Set the parameter cost_model_gamma of the DP algorithm") - .def("get_costmodel_gamma", &CostModelContext::costmodel_gamma, - "Get the parameter cost_model_gamma of the DP algorithm.") - .def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold, - "Set the parameter cost_model_communi_threshold of the DP algorithm.") - .def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold, - "Get the parameter cost_model_communi_threshold of the DP algorithm.") - .def("set_costmodel_communi_const", &CostModelContext::set_costmodel_communi_const, - "Set the parameter cost_model_communi_const of the DP algorithm.") - .def("get_costmodel_communi_const", &CostModelContext::costmodel_communi_const, - "Get the parameter cost_model_communi_const of the DP algorithm.") - .def("set_costmodel_communi_bias", &CostModelContext::set_costmodel_communi_bias, - "Set the parameter cost_model_communi_bias of the DP algorithm.") - .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias, - "Get the parameter cost_model_communi_bias of the DP algorithm.") - .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.") - .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.") - .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.") - .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.") - .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm, - "Set the parameter gradient AllReduce fusion algorithm.") - .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm, - "Get the parameter gradient AllReduce fusion algorithm.") - .def("set_costmodel_allreduce_fusion_times", &CostModelContext::set_costmodel_allreduce_fusion_times, - "Set the parameter gradient AllReduce times.") - .def("get_costmodel_allreduce_fusion_times", &CostModelContext::costmodel_allreduce_fusion_times, - "Get the parameter gradient AllReduce times.") - .def("set_costmodel_allreduce_fusion_tail_percent", &CostModelContext::set_costmodel_allreduce_fusion_tail_percent, - "Set the parameter gradient AllReduce fusion tail percent.") - .def("get_costmodel_allreduce_fusion_tail_percent", &CostModelContext::costmodel_allreduce_fusion_tail_percent, - "Get the parameter gradient AllReduce fusion tail percent.") - .def("set_costmodel_allreduce_fusion_tail_time", &CostModelContext::set_costmodel_allreduce_fusion_tail_time, - "Set the parameter gradient AllReduce fusion tail time.") - .def("get_costmodel_allreduce_fusion_tail_time", &CostModelContext::costmodel_allreduce_fusion_tail_time, - "Get the parameter gradient AllReduce fusion tail time.") - .def("set_costmodel_allreduce_fusion_allreduce_inherent_time", - &CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time, - "Set the parameter gradient AllReduce fusion allreduce inherent time.") - .def("get_costmodel_allreduce_fusion_allreduce_inherent_time", - &CostModelContext::costmodel_allreduce_fusion_allreduce_inherent_time, - "Get the parameter gradient AllReduce fusion allreduce inherent time.") - .def("set_costmodel_allreduce_fusion_allreduce_bandwidth", - &CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth, - "Set the parameter gradient AllReduce fusion allreduce bandwidth.") - .def("get_costmodel_allreduce_fusion_allreduce_bandwidth", - &CostModelContext::costmodel_allreduce_fusion_allreduce_bandwidth, - "Get the parameter gradient AllReduce fusion allreduce bandwidth.") - .def("set_costmodel_allreduce_fusion_computation_time_parameter", - &CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter, - "Set the parameter gradient AllReduce fusion computation time parameter.") - .def("get_costmodel_allreduce_fusion_computation_time_parameter", - &CostModelContext::costmodel_allreduce_fusion_computation_time_parameter, - "Get the parameter gradient AllReduce fusion computation time parameter.") - .def("set_tensor_slice_align_enable", &CostModelContext::set_tensor_slice_alignment_enable, - "Set the parameter tensor_slice_align_enable in strategy generation.") - .def("get_tensor_slice_align_enable", &CostModelContext::tensor_slice_alignment_enable, - "Get the parameter tensor_slice_align_enable in strategy generation.") - .def("set_tensor_slice_align_size", &CostModelContext::set_tensor_slice_alignment_size, - "Set the parameter tensor_slice_size in strategy generation.") - .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size, - "Get the parameter tensor_slice_size in strategy generation.") - .def("set_fully_use_devices", &CostModelContext::set_fully_use_device, - "Set the parameter fully_use_devices in the DP algorithm.") - .def("get_fully_use_devices", &CostModelContext::fully_use_device, - "Get the parameter fully_use_devices in the DP algorithm.") - .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow, - "Set the parameter elementwise_op_strategy_follow in the DP algorithm.") - .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow, - "Get the parameter elementwise_op_strategy_follow in the DP algorithm.") - .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.") - .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters."); - - (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { - // only in case that c++ calling python interface, ClearResAtexit should be called. - if (mindspore::parse::python_adapter::IsPythonEnv()) { - mindspore::pipeline::ClearResAtexit(); - -#ifdef ENABLE_MINDDATA - py::module iterators = py::module::import("mindspore.dataset.engine.iterators"); - (void)iterators.attr("_cleanup")(); -#endif - } - }}); - - (void)py::class_>(m, "EventWriter_") - .def(py::init()) - .def("GetFileName", &EventWriter::GetFileName, "Get the file name.") - .def("Open", &EventWriter::Open, "Open the write file.") - .def("Write", &EventWriter::Write, "Write the serialize event.") - .def("EventCount", &EventWriter::GetWriteEventCount, "Write event count.") - .def("Flush", &EventWriter::Flush, "Flush the event.") - .def("Close", &EventWriter::Close, "Close the write.") - .def("Shut", &EventWriter::Shut, "Final close the write."); - - (void)py::class_>(m, "Oplib") - .def(py::init()) - .def("reg_op", &OpLib::RegOp, "Register op info."); -#ifdef ENABLE_GPU_COLLECTIVE - (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective, - "Init gpu collective communication mode."); - (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective, - "Finalize gpu collective communication mode."); -#else - (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::InitCollective, - "Init gpu collective communication mode."); - (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective, - "Finalize gpu collective communication mode."); - -#endif - - (void)py::class_>(m, "OpInfoLoaderPy") - .def(py::init()) - .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info."); -} diff --git a/mindspore/ccsrc/pipeline/jit/CMakeLists.txt b/mindspore/ccsrc/pipeline/jit/CMakeLists.txt new file mode 100644 index 0000000000..6188546ce5 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/CMakeLists.txt @@ -0,0 +1,27 @@ +file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "pipeline.cc" + "resource.cc" + "pass.cc" + "action.cc" + "validator.cc" + "remove_value_node_dup.cc" + "parse/*.cc" + "static_analysis/*.cc" +) + + +file(GLOB PIPELINE_SRC_FILES "*.cc") +set_property(SOURCE ${PIPELINE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) + +file(GLOB_RECURSE PARSER_SRC_FILES "parse/*.cc") +set_property(SOURCE ${PARSER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARSER) + +file(GLOB_RECURSE ANALYZER_SRC_FILES "static_analysis/*.cc") +set_property(SOURCE ${ANALYZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) + +if (ENABLE_GE OR ENABLE_D) + file(GLOB_RECURSE _PIPELINE_GE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pipeline_ge.cc") + list(APPEND _PIPELINE_SRC_FILES ${_PIPELINE_GE_SRC_FILES}) +endif () + +add_library(_mindspore_pipeline_jit_obj OBJECT ${_PIPELINE_SRC_FILES}) diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc new file mode 100644 index 0000000000..74eb9f3f9b --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -0,0 +1,494 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/action.h" + +#include +#include +#include +#include +#include +#include + +#include "ir/func_graph_cloner.h" +#include "ir/param_value.h" +#include "frontend/parallel/costmodel_context.h" +#include "frontend/parallel/context.h" +#include "pipeline/jit/pass.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/data_converter.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "pipeline/jit/static_analysis/program_specialize.h" +#include "pipeline/jit/resource.h" +#include "utils/context/ms_context.h" +#include "pipeline/jit/remove_value_node_dup.h" +#include "frontend/optimizer/optimizer.h" +#include "vm/transform.h" +#include "parse/python_adapter.h" +#include "frontend/optimizer/py_pass_manager.h" + +namespace mindspore { +namespace pipeline { +using CompileGraphs = compile::CompileGraphs; +using abstract::AnalysisResult; +using mindspore::abstract::AnalysisContextPtr; + +abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec, bool clear) { + MS_LOG(DEBUG) << "AbstractAnalyze start"; + auto engine = res->engine(); + MS_EXCEPTION_IF_NULL(engine); + if (clear) { + auto manager = res->manager(); + MS_EXCEPTION_IF_NULL(manager); + engine->Clear(); + for (auto &node : manager->all_nodes()) { + MS_EXCEPTION_IF_NULL(node); + const AbstractBasePtr &prev_inferred = node->abstract(); + // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction. + if (!node->isa() || (prev_inferred != nullptr && prev_inferred->isa())) { + node->set_abstract(nullptr); + MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr"; + } + } + } + auto ret = engine->Run(func_graph, args_spec); + MS_LOG(DEBUG) << "AbstractAnalyze end"; + return ret; +} + +FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context) { + MS_LOG(DEBUG) << "ProgramSpecialize start"; + abstract::ProgramSpecializer spc(res->engine()); + FuncGraphPtr result = spc.Run(func_graph, context); + auto manager = res->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->KeepRoots({result}); + MS_LOG(DEBUG) << "ProgramSpecialize end"; + return result; +} + +FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec) { + MS_LOG(DEBUG) << "Renormalize start"; +#ifdef ENABLE_PROFILE + double t1 = GetTime(); +#endif + abstract::AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec, true); +#ifdef ENABLE_PROFILE + double t2 = GetTime(); +#endif + auto ret = ProgramSpecialize(res, func_graph, result.context); + res->set_func_graph(ret); +#ifdef ENABLE_PROFILE + double t3 = GetTime(); + MsProfile::StatTime("renormalize.infer", t2 - t1); + MsProfile::StatTime("renormalize.specialize", t3 - t2); +#endif + MS_LOG(DEBUG) << "Renormalize end"; + return ret; +} + +bool ParseAction(const ResourcePtr &res) { + if (!res->input()) { + MS_LOG(EXCEPTION) << "Parse error"; + } + + py::object input = res->input(); + parse::Parser::InitParserEnvironment(input); + py::module path = py::module::import("os.path"); + std::string dir = path.attr("dirname")(py::globals()["__file__"]).cast(); + + parse::python_adapter::set_python_env_flag(true); + parse::python_adapter::SetPythonPath(dir); + + FuncGraphPtr fg = parse::ConvertToFuncGraph(input); + if (fg == nullptr) { + MS_LOG(EXCEPTION) << "Parse error."; + } + res->set_func_graph(fg); + + FuncGraphManagerPtr manager = res->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Manager is nullptr."; + } + manager->AddFuncGraph(fg); + return true; +} + +// obj_map's graphs have the same construct, these graphs can be optimized to one graph. +// This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}-> +// graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx} +// all obj_map's graph shared base_graph +bool CombineLikeGraphs(const ResourcePtr &res) { + auto &obj_map = parse::data_converter::GetObjGraphs(); + + for (auto it : obj_map) { + auto &graphs = it.second; + MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size(); + auto fg = graphs[0]; + FuncGraphPtrList func_graphs = {fg}; + ClonerPtr cloner = std::make_shared(func_graphs, false, false, true, std::make_shared(), + std::make_shared()); + cloner->Run(); + auto base_graph = cloner->cloned_func_graph()[fg]; + MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString(); + + if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) { + continue; + } + for (auto &fv : fg->paramter_obj_nodes()) { + TraceManager::DebugTrace(std::make_shared(fv->debug_info())); + auto param = base_graph->add_parameter(); + TraceManager::EndTrace(); + auto &node_users = res->manager()->node_users()[fv]; + for (auto &n : node_users) { + auto repl_n = (*cloner->cloned_node())[n.first]->cast(); + repl_n->set_input(n.second, param); + } + } + MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size(); + + for (auto &g : graphs) { + auto fvs = g->paramter_obj_nodes(); + std::vector new_node_inputs; + new_node_inputs.push_back(NewValueNode(base_graph)); + for (auto &p : g->parameters()) { + AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p); + new_node_inputs.push_back(para_after_cast); + } + (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end()); + AnfNodePtr out = g->NewCNode(new_node_inputs); + g->set_output(out); + MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(4); + } + MS_LOG(DEBUG) << "End combine graph:" << it.first; + } + return true; +} + +bool SymbolResolveAction(const ResourcePtr &res) { + if (res->manager() == nullptr) { + MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; + } + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null"; + } + FuncGraphPtr func_graph = res->func_graph(); + auto succ = parse::ResolveFuncGraph(func_graph, res); + + // Remove unused nodes in cnode order list. + func_graph->EraseUnusedNodeInOrder(); + func_graph->ReleaseFullOrderToEffectOrder(); + for (auto fg : func_graph->func_graphs_used_total()) { + MS_EXCEPTION_IF_NULL(fg); + fg->EraseUnusedNodeInOrder(); + fg->ReleaseFullOrderToEffectOrder(); + } + return succ; +} + +bool InferenceOptPrepareAction(const ResourcePtr &res) { + if (res->manager() == nullptr) { + MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; + } + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null."; + } + return InferenceOptPreparePass(res); +} + +bool AbstractSpecializeAction(const ResourcePtr &res) { + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "AbstractSpecialize error"; + } + + FuncGraphPtr func_graph = res->func_graph(); + abstract::AbstractBasePtrList args_spec = res->args_spec(); + + parallel::ParallelParameterContextInit(func_graph); + + // suppose that there is not KeywordArgument for the top graph + // get the hyper parameter + for (const auto ¶m : func_graph->parameters()) { + auto param_node = std::static_pointer_cast(param); + if (param_node->has_default()) { + const auto ¶m_value = param_node->default_param(); + ValuePtr value = param_value->value(); + constexpr bool broaden = true; + AbstractBasePtr ptr = abstract::FromValue(value, broaden); + + parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); + args_spec.push_back(ptr); + parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr); + } + } + // Analyze + AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec); + // The top graph may be replaced by infer, update the top graph when the infer is done + parse::Parser::UpdateTopFuncGraph(result.context->func_graph()); + + // Specialize + FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context); + res->set_func_graph(new_fg); + + MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true); + return true; +} + +bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) { + size_t counter = 0; + for (auto &pass : passes) { + WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() { + MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; + auto result = pass.second(res); + if (!result) { + MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; + } + if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) { + auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; + auto func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + func_graph->DumpFuncGraph(fg_name); + DumpIR(fg_name + ".ir", func_graph); + MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; + } + counter++; + MS_LOG(DEBUG) << "Pass " << pass.first << " end."; + }; + } + + return true; +} + +bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } + +bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } + +bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } + +static bool IsCtrlSink() { + auto ms_ctx = MsContext::GetInstance(); + if (ms_ctx->execution_mode() != kGraphMode) { + return false; + } + + std::string device_target = ms_ctx->device_target(); + if (device_target != kAscendDevice) { + return false; + } + + if (!ms_ctx->enable_task_sink()) { + return false; + } + + if (!ms_ctx->is_multi_graph_sink()) { + return false; + } + return true; +} + +bool TaskEmitAction(const ResourcePtr &res) { + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "TaskEmit args error"; + } + FuncGraphPtr func_graph = res->func_graph(); + auto bc_ptr = res->results()[kBackend].cast(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (CompileGraphs::ContainMixedTarget(func_graph)) { + bc_ptr->set_is_multi_graph_sink(false); + context_ptr->set_is_multi_graph_sink(false); + context_ptr->set_loop_sink_flag(false); + } else if (context_ptr->execution_mode() != kPynativeMode) { + std::string device_target = context_ptr->device_target(); + if (device_target == kAscendDevice) { + bc_ptr->set_is_multi_graph_sink(true); + context_ptr->set_is_multi_graph_sink(true); + } + } + + if (IsCtrlSink()) { + res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); + return true; + } + std::vector cut_list = compile::nonlinear_ops; + if (bc_ptr->name() == kMsConvert) { + cut_list = compile::GetMsNonlinearOps(); + } + std::shared_ptr compile = std::make_shared(bc_ptr, cut_list); + res->results()[kOutput] = compile->CompileAndLink(func_graph); + return true; +} + +bool ExecuteAction(const ResourcePtr &res) { + if (res->results().count(kOutput) == 0) { + MS_LOG(EXCEPTION) << "Execute args error"; + } + + if (IsCtrlSink()) { + if (!res->results()[kOutput].is()) { + MS_LOG(EXCEPTION) << "Execute args error"; + } + auto graph_id = res->results()[kOutput].cast(); + std::shared_ptr bc_ptr = res->results()[kBackend].cast>(); + std::shared_ptr msbc_ptr = std::dynamic_pointer_cast(bc_ptr); + MS_EXCEPTION_IF_NULL(msbc_ptr); + compile::VmEvalFuncPtr run = + std::make_shared([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef { + MS_LOG(INFO) << "Execute args size " << args.size(); + auto outs = msbc_ptr->RunGraph(graph_id, args); + MS_LOG(DEBUG) << "out size " << outs.size(); + return outs[0]; + }); + res->results()[kOutput] = run; + return true; + } + + if (!res->results()[kOutput].is()) { + MS_LOG(EXCEPTION) << "Execute args error"; + } + compile::FinalVMPtr vm = res->results()[kOutput].cast(); + if (vm == nullptr) { + MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM"; + return true; + } + compile::VmEvalFuncPtr run = + std::make_shared(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1)); + res->results()[kOutput] = run; + return true; +} + +// The parallel primitive related valuenode might be partitioned so that its value changes by device, +// that will result in a syncronization error due to different executing order. +// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, +// the final solution will be proposed later as a parallel feature. +bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) { + auto &node_users = res->manager()->node_users(); + auto &users = node_users[value_node]; + auto used_by_keep_value_prim = + std::any_of(users.begin(), users.end(), [](const std::pair &user) -> bool { + MS_EXCEPTION_IF_NULL(user.first); + auto cnode = user.first->cast(); + if (cnode == nullptr) { + return false; + } + auto prim_node = cnode->input(0); + if (IsValueNode(prim_node)) { + auto prim = GetValue(prim_node->cast()->value()); + // value_node is referenced by some parallel primitive + return prim->HasAttr("keep_value_node_input"); + } + return false; + }); + return used_by_keep_value_prim; +} + +bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "Remove value node duplications error."; + } + FuncGraphPtr func_graph = res->func_graph(); + auto manager = res->manager(); + // Remove duplicated value nodes, due to replace operation, can't use reference. + auto value_nodes = func_graph->value_nodes(); + HashCache hash_cache; + HashValue hashes; + for (const auto &value_pair : value_nodes) { + if (KeepValueNodeDuplication(value_pair.first, res)) { + continue; + } + TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes); + } + return true; +} + +bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } + +void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { + MS_EXCEPTION_IF_NULL(res->manager()); + MS_EXCEPTION_IF_NULL(res->func_graph()); + auto ppm = opt::python_pass::PyPassManager::GetInstance(); + if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { + MS_LOG(DEBUG) << "No match.\n"; + } +} + +bool ResolveActionPyStub(const ResourcePtr &res) { + ActionPyStub(res, opt::python_pass::Phase::RESOLVE); + return true; +} + +bool OptActionPyStub(const ResourcePtr &res) { + ActionPyStub(res, opt::python_pass::Phase::OPT); + return true; +} + +static std::vector CommonPipeline() { + std::vector actions; + + // Parse the python ast to ANF graph + actions.emplace_back(std::make_pair("parse", ParseAction)); + + // Resolve the python func + actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); + auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs(); + if (!multi_graphs) { + actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); + } + // Add resolve-stage python pass stub + actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); + actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); + // Evaluate type and shape, and specialize + actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); + + return actions; +} + +std::vector GePipeline() { + auto actions = CommonPipeline(); + // optimize + actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); + // Add opt-stage python pass stub + actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); + actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); + actions.emplace_back(std::make_pair("validate", ValidateAction)); + return actions; +} + +std::vector VmPipeline() { + auto actions = CommonPipeline(); + + // optimize + actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); + + // Add opt-stage python pass stub + actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); + + actions.emplace_back(std::make_pair("validate", ValidateAction)); + + // compile the ANF graph + actions.emplace_back(std::make_pair("task_emit", TaskEmitAction)); + + // to execute the graph + actions.emplace_back(std::make_pair("execute", ExecuteAction)); + + return actions; +} +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/action.h b/mindspore/ccsrc/pipeline/jit/action.h new file mode 100644 index 0000000000..0a1feab1c9 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/action.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_ACTION_H_ +#define MINDSPORE_CCSRC_PIPELINE_ACTION_H_ + +#include +#include +#include +#include +#include "pipeline/jit/resource.h" +#include "vm/segment_runner.h" + +namespace mindspore { +extern const char kMsConvert[]; + +namespace pipeline { +using ActionItem = std::pair>; + +bool ParseAction(const ResourcePtr &res); +bool SymbolResolveAction(const ResourcePtr &res); +bool AbstractSpecializeAction(const ResourcePtr &res); +bool GeOptimizeAction(const ResourcePtr &res); +bool VmOptimizeAction(const ResourcePtr &res); +bool PynativeOptimizeAction(const ResourcePtr &res); +bool TaskEmitAction(const ResourcePtr &res); +bool ExecuteAction(const ResourcePtr &res); + +std::vector GePipeline(); +std::vector VmPipeline(); +abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec, bool clear = false); +FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context); +FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec); +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_ACTION_H_ diff --git a/mindspore/ccsrc/pipeline/jit/base.h b/mindspore/ccsrc/pipeline/jit/base.h new file mode 100644 index 0000000000..0a8a2b75f3 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/base.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_BASE_H_ +#define MINDSPORE_CCSRC_PIPELINE_BASE_H_ + +#include +#include +#include +#include + +#include "ir/anf.h" +#include "pipeline/jit/resource.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace pipeline { +struct ExecutorInfo { + FuncGraphPtr func_graph; + ResourcePtr resource; + std::size_t arg_list_size; +}; +using ExecutorInfoPtr = std::shared_ptr; + +inline std::string GetPhasePrefix(const std::string &phase) { + auto pos = phase.find('.'); + if (pos == std::string::npos) { + MS_LOG(EXCEPTION) << "Phase has no . for prefix" << phase; + } + return phase.substr(0, pos); +} + +inline std::string GetFilePathName(const std::string &file_name) { + std::ostringstream oss; + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(EXCEPTION) << "ms_context is nullptr"; + } + auto save_graphs_path = ms_context->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + oss << save_graphs_path << "/" << file_name; + return oss.str(); +} +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc new file mode 100644 index 0000000000..65adebb6e2 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -0,0 +1,336 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/oplib/oploader.h" +#include "pipeline/jit/pipeline.h" +#include "frontend/operator/composite/composite.h" +#include "ir/signature.h" +#include "pipeline/pynative/pynative_execute.h" +#include "utils/symbolic.h" +#include "pybind_api/api_register.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "utils/summary/event_writer.h" +#include "utils/config_manager.h" +#include "utils/mpi/mpi_config.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/costmodel_context.h" +#ifdef ENABLE_GPU_COLLECTIVE +#include "runtime/device/gpu/distribution/collective_init.h" +#else +#include "runtime/device/gpu/distribution/collective_fake_init.h" +#endif +namespace py = pybind11; + +using EnvInstance = mindspore::EnvInstance; +using ExecutorPy = mindspore::pipeline::ExecutorPy; +using Pipeline = mindspore::pipeline::Pipeline; +using PrimitivePy = mindspore::PrimitivePy; +using MetaFuncGraph = mindspore::MetaFuncGraph; +using EventWriter = mindspore::summary::EventWriter; +using OpLib = mindspore::kernel::OpLib; +using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy; +using ParallelContext = mindspore::parallel::ParallelContext; +using CostModelContext = mindspore::parallel::CostModelContext; + +// Interface with python +PYBIND11_MODULE(_c_expression, m) { + m.doc() = "MindSpore c plugin"; + + auto fns = mindspore::PybindDefineRegister::AllFuncs(); + for (auto &item : fns) { + item.second(&m); + } + + // Class Pipeline interface + (void)py::class_>(m, "Executor_") + .def_static("get_instance", &ExecutorPy::GetInstance, "Executor get_instance.") + .def("__call__", &ExecutorPy::Run, py::arg("args"), py::arg("phase") = py::str(""), "Executor run function.") + .def("del_net_res", &ExecutorPy::DelNetRes, py::arg("network_id") = py::str(""), "Delete network resource.") + .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.") + .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""), + py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") + .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), + py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") + .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"), + "Get Parameter Tensor Layout Dictionary.") + .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"), + "Get CNode Strategy Dictionary.") + .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"), + "Get Allreduce Fusion Dictionary.") + .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"), + "Fetch the inputs of Conv or Matmul for quant export.") + .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"), + py::arg("broadcast_params") = py::dict(), "Build data graph.") + .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") + .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); + + (void)py::class_>(m, "EnvInstance_") + .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) + .def(py::init()); + + (void)m.def("generate_key", &mindspore::pipeline::GenerateKey, "Generate the function graph key."); + (void)m.def("real_run_op", &mindspore::pynative::RunOp, "Run op pynatively."); + (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id"); + (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl"); + (void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl"); + (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature."); + (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"), + py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"), + py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset."); + (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode."); + (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); + + (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); + + (void)py::class_>(m, "MSContext") + .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") + .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") + .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.") + .def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.") + .def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.") + .def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.") + .def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.") + .def("get_device_target", &mindspore::MsContext::device_target, "Get device target.") + .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") + .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") + .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") + .def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.") + .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.") + .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") + .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") + .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, + "Get whether to enable auto mixed precision.") + .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag, + "Set whether to enable auto mixed precision.") + .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision, + "Get whether to enable reduce precision.") + .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision, + "Set whether to enable reduce precision.") + .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.") + .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.") + .def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.") + .def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.") + .def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.") + .def("set_save_ms_model_path", &mindspore::MsContext::set_save_ms_model_path, "Set path to save ms model") + .def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.") + .def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.") + .def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.") + .def("set_save_dump_path", &mindspore::MsContext::set_save_dump_path, "Set path to dump.") + .def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "set graph memory max size.") + .def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size, + "set variable memory max size") + .def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.") + .def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.") + .def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.") + .def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.") + .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.") + .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") + .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") + .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") + .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") + .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, + "Set the GraphKernel switch to on or off.") + .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.") + .def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.") + .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity."); + + (void)py::class_>(m, "MpiConfig") + .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") + .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.") + .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi."); + + (void)py::class_>(m, "AutoParallelContext") + .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.") + .def("get_device_num", &ParallelContext::device_num, "Get device num.") + .def("set_device_num", &ParallelContext::set_device_num, "Set device num.") + .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.") + .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.") + .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.") + .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.") + .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.") + .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.") + .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.") + .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.") + .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.") + .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") + .def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.") + .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.") + .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") + .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") + .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") + .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") + .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, + "Set all reduce fusion split indices.") + .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices, + "Get all reduce fusion split indices.") + .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes, + "Set all reduce fusion split sizes.") + .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes, + "Get all reduce fusion split sizes.") + .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion, + "Set enable/disable all reduce fusion.") + .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion, + "Get enable/disable all reduce fusion.") + .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.") + .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set, + "Get parameter broadcast is set.") + .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.") + .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file, + "Set strategy checkpoint load file.") + .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file, + "Set strategy checkpoint save file.") + .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") + .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") + .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") + .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") + .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, + "Set enable/disable parallel optimizer.") + .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, + "Get enable/disable parallel optimizer.") + .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); + + (void)py::class_>(m, "CostModelContext") + .def_static("get_instance", &CostModelContext::GetInstance, "Get cost_model context instance.") + .def("set_device_memory_capacity", &CostModelContext::set_device_memory_capacity, + "Set the capacity of device memory.") + .def("get_device_memory_capacity", &CostModelContext::device_memory_capacity, "Get the capacity of device memory.") + .def("set_costmodel_alpha", &CostModelContext::set_costmodel_alpha, + "Set the parameter cost_model_alpha of the DP algorithm.") + .def("get_costmodel_alpha", &CostModelContext::costmodel_alpha, + "Get the parameter cost_model_alpha of the DP algorithm.") + .def("set_costmodel_beta", &CostModelContext::set_costmodel_beta, + "Set the parameter cost_model_beta of the DP algorithm.") + .def("get_costmodel_beta", &CostModelContext::costmodel_beta, + "Get the parameter cost_model_beta of the DP algorithm.") + .def("set_costmodel_gamma", &CostModelContext::set_costmodel_gamma, + "Set the parameter cost_model_gamma of the DP algorithm") + .def("get_costmodel_gamma", &CostModelContext::costmodel_gamma, + "Get the parameter cost_model_gamma of the DP algorithm.") + .def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold, + "Set the parameter cost_model_communi_threshold of the DP algorithm.") + .def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold, + "Get the parameter cost_model_communi_threshold of the DP algorithm.") + .def("set_costmodel_communi_const", &CostModelContext::set_costmodel_communi_const, + "Set the parameter cost_model_communi_const of the DP algorithm.") + .def("get_costmodel_communi_const", &CostModelContext::costmodel_communi_const, + "Get the parameter cost_model_communi_const of the DP algorithm.") + .def("set_costmodel_communi_bias", &CostModelContext::set_costmodel_communi_bias, + "Set the parameter cost_model_communi_bias of the DP algorithm.") + .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias, + "Get the parameter cost_model_communi_bias of the DP algorithm.") + .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.") + .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.") + .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.") + .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.") + .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm, + "Set the parameter gradient AllReduce fusion algorithm.") + .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm, + "Get the parameter gradient AllReduce fusion algorithm.") + .def("set_costmodel_allreduce_fusion_times", &CostModelContext::set_costmodel_allreduce_fusion_times, + "Set the parameter gradient AllReduce times.") + .def("get_costmodel_allreduce_fusion_times", &CostModelContext::costmodel_allreduce_fusion_times, + "Get the parameter gradient AllReduce times.") + .def("set_costmodel_allreduce_fusion_tail_percent", &CostModelContext::set_costmodel_allreduce_fusion_tail_percent, + "Set the parameter gradient AllReduce fusion tail percent.") + .def("get_costmodel_allreduce_fusion_tail_percent", &CostModelContext::costmodel_allreduce_fusion_tail_percent, + "Get the parameter gradient AllReduce fusion tail percent.") + .def("set_costmodel_allreduce_fusion_tail_time", &CostModelContext::set_costmodel_allreduce_fusion_tail_time, + "Set the parameter gradient AllReduce fusion tail time.") + .def("get_costmodel_allreduce_fusion_tail_time", &CostModelContext::costmodel_allreduce_fusion_tail_time, + "Get the parameter gradient AllReduce fusion tail time.") + .def("set_costmodel_allreduce_fusion_allreduce_inherent_time", + &CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time, + "Set the parameter gradient AllReduce fusion allreduce inherent time.") + .def("get_costmodel_allreduce_fusion_allreduce_inherent_time", + &CostModelContext::costmodel_allreduce_fusion_allreduce_inherent_time, + "Get the parameter gradient AllReduce fusion allreduce inherent time.") + .def("set_costmodel_allreduce_fusion_allreduce_bandwidth", + &CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth, + "Set the parameter gradient AllReduce fusion allreduce bandwidth.") + .def("get_costmodel_allreduce_fusion_allreduce_bandwidth", + &CostModelContext::costmodel_allreduce_fusion_allreduce_bandwidth, + "Get the parameter gradient AllReduce fusion allreduce bandwidth.") + .def("set_costmodel_allreduce_fusion_computation_time_parameter", + &CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter, + "Set the parameter gradient AllReduce fusion computation time parameter.") + .def("get_costmodel_allreduce_fusion_computation_time_parameter", + &CostModelContext::costmodel_allreduce_fusion_computation_time_parameter, + "Get the parameter gradient AllReduce fusion computation time parameter.") + .def("set_tensor_slice_align_enable", &CostModelContext::set_tensor_slice_alignment_enable, + "Set the parameter tensor_slice_align_enable in strategy generation.") + .def("get_tensor_slice_align_enable", &CostModelContext::tensor_slice_alignment_enable, + "Get the parameter tensor_slice_align_enable in strategy generation.") + .def("set_tensor_slice_align_size", &CostModelContext::set_tensor_slice_alignment_size, + "Set the parameter tensor_slice_size in strategy generation.") + .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size, + "Get the parameter tensor_slice_size in strategy generation.") + .def("set_fully_use_devices", &CostModelContext::set_fully_use_device, + "Set the parameter fully_use_devices in the DP algorithm.") + .def("get_fully_use_devices", &CostModelContext::fully_use_device, + "Get the parameter fully_use_devices in the DP algorithm.") + .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow, + "Set the parameter elementwise_op_strategy_follow in the DP algorithm.") + .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow, + "Get the parameter elementwise_op_strategy_follow in the DP algorithm.") + .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.") + .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters."); + + (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { + // only in case that c++ calling python interface, ClearResAtexit should be called. + if (mindspore::parse::python_adapter::IsPythonEnv()) { + mindspore::pipeline::ClearResAtexit(); + +#ifdef ENABLE_MINDDATA + py::module iterators = py::module::import("mindspore.dataset.engine.iterators"); + (void)iterators.attr("_cleanup")(); +#endif + } + }}); + + (void)py::class_>(m, "EventWriter_") + .def(py::init()) + .def("GetFileName", &EventWriter::GetFileName, "Get the file name.") + .def("Open", &EventWriter::Open, "Open the write file.") + .def("Write", &EventWriter::Write, "Write the serialize event.") + .def("EventCount", &EventWriter::GetWriteEventCount, "Write event count.") + .def("Flush", &EventWriter::Flush, "Flush the event.") + .def("Close", &EventWriter::Close, "Close the write.") + .def("Shut", &EventWriter::Shut, "Final close the write."); + + (void)py::class_>(m, "Oplib") + .def(py::init()) + .def_static("reg_op", &OpLib::RegOp, "Register op info."); +#ifdef ENABLE_GPU_COLLECTIVE + (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective, + "Init gpu collective communication mode."); + (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective, + "Finalize gpu collective communication mode."); +#else + (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::InitCollective, + "Init gpu collective communication mode."); + (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective, + "Finalize gpu collective communication mode."); + +#endif + + (void)py::class_>(m, "OpInfoLoaderPy") + .def(py::init()) + .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info."); +} diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc new file mode 100644 index 0000000000..baef64481b --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -0,0 +1,559 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/parse/data_converter.h" +#include +#include +#include +#include +#include +#include +#include +#include "pipeline/jit/parse/resolve.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/composite.h" +#include "ir/func_graph_cloner.h" +#include "utils/symbolic.h" +#include "utils/context/ms_context.h" +#include "debug/trace.h" +#include "frontend/optimizer/ad/grad.h" + +namespace mindspore { +namespace parse { +using Tensor = mindspore::tensor::Tensor; +using TensorPtr = mindspore::tensor::TensorPtr; +using MetaTensor = mindspore::tensor::MetaTensor; +using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; + +FuncGraphPtr ConvertToBpropCut(const py::object &obj) { + std::vector results = data_converter::GetObjKey(obj); + std::string obj_key = results[0]; + py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME); + + auto bprop_graph = std::make_shared(); + std::vector outputs; + + auto fake_bprop = std::make_shared("bprop_cut", py::object()); + fake_bprop->set_hook(bprop_func); + (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true)); + outputs.push_back(NewValueNode(fake_bprop)); + + py::object code_obj = py::getattr(bprop_func, "__code__"); + size_t inputs_num = py::cast(py::getattr(code_obj, "co_argcount")) - 3; + for (size_t i = 0; i < inputs_num; ++i) { + auto param = bprop_graph->add_parameter(); + outputs.push_back(param); + } + auto p1 = bprop_graph->add_parameter(); + auto p2 = bprop_graph->add_parameter(); + outputs.push_back(p1); + outputs.push_back(p2); + + bprop_graph->set_output(bprop_graph->NewCNode(outputs)); + data_converter::SetObjGraphValue(obj_key, bprop_graph); + return bprop_graph; +} + +namespace { +bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { + MS_LOG(DEBUG) << "Converting python tuple"; + py::tuple tuple = obj.cast(); + std::vector value_list; + for (size_t it = 0; it < tuple.size(); ++it) { + ValuePtr out = nullptr; + bool success = ConvertData(tuple[it], &out, use_signature); + if (!success) { + return false; + } + value_list.push_back(out); + } + *data = std::make_shared(value_list); + + return true; +} + +bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { + MS_LOG(DEBUG) << "Converting python list"; + + py::list list = obj.cast(); + std::vector value_list; + for (size_t it = 0; it < list.size(); ++it) { + ValuePtr out = nullptr; + bool success = ConvertData(list[it], &out, use_signature); + if (!success) { + return false; + } + value_list.push_back(out); + } + *data = std::make_shared(value_list); + return true; +} + +bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) { + MS_LOG(DEBUG) << "Converting cell list"; + py::sequence list = obj; + std::vector value_list; + for (size_t it = 0; it < list.size(); ++it) { + ValuePtr out = nullptr; + bool success = ConvertData(list[it], &out, use_signature); + if (!success) { + return false; + } + value_list.push_back(out); + } + *data = std::make_shared(value_list); + return true; +} + +bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { + MS_LOG(DEBUG) << "Converting python dict"; + + py::dict dict_values = obj.cast(); + std::vector> key_values; + for (auto item : dict_values) { + if (!py::isinstance(item.first)) { + MS_LOG(EXCEPTION) << "The key of dict is only support str."; + } + std::string key = py::str(item.first); + ValuePtr out = nullptr; + bool success = ConvertData(dict_values[item.first], &out, use_signature); + if (!success) { + return false; + } + key_values.emplace_back(key, out); + } + *data = std::make_shared(key_values); + return true; +} + +void ConvertNameSpace(const py::object &obj, ValuePtr *const data) { + MS_LOG(DEBUG) << "Converting python module"; + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj); + *data = std::make_shared(RESOLVE_NAMESPACE_NAME_MODULE, py::cast(module_namespace)); +} + +void ConvertDataClass(py::object obj, ValuePtr *const data) { + MS_LOG(DEBUG) << "Converting dataclass"; + // Maybe the obj is dataclass define + auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); + // desc has format "", strip the '<' and '>' by offset 1; + *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); +} + +bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { + MS_LOG(DEBUG) << "Converting primitive object"; + + // need check the primitive is class type or instance + auto obj_type = data_converter::GetObjType(obj); + if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { + auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); + // desc has format "", strip the '<' and '>' by offset 1; + *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); + } else { + auto primitive = obj.cast(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null"; + return false; + } + if (py::hasattr(obj, "__setattr_flag__")) { + if (py::hasattr(obj, "_clone")) { + auto clone_fn = obj.attr("_clone"); + py::object new_obj = clone_fn(); + primitive = new_obj.cast(); + } + } + if (use_signature) { + *data = std::make_shared(primitive->name(), primitive); + } else { + *data = primitive; + } + } + return true; +} + +bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) { + MS_LOG(DEBUG) << "Converting MetaFuncGraph object"; + auto meta = obj.cast(); + if (meta == nullptr) { + MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null"; + return false; + } + if (use_signature) { + *data = std::make_shared(meta->name(), meta); + } else { + *data = meta; + } + return true; +} + +bool ConvertDataType(const py::object &obj, ValuePtr *const data) { + MS_LOG(DEBUG) << "Converting type object"; + auto typeptr = obj.cast(); + if (typeptr == nullptr) { + MS_LOG(ERROR) << "Resolve TypePtr error, get ptr is null"; + return false; + } + *data = typeptr; + return true; +} + +bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) { + MS_LOG(DEBUG) << "Converting MetaTensor object."; + + auto m_tensor = obj.cast(); + if (m_tensor == nullptr) { + MS_LOG(ERROR) << "Resolve MetaTensor error, get ptr is null."; + return false; + } + *data = m_tensor; + return true; +} + +bool ConvertTensor(const py::object &obj, ValuePtr *const data) { + MS_LOG(DEBUG) << "Converting tensor object"; + + auto m_tensor = obj.cast(); + if (m_tensor == nullptr) { + MS_LOG(ERROR) << "Resolve Tensor error, get ptr is null"; + return false; + } + *data = m_tensor; + return true; +} + +bool ConvertSlice(const py::object &obj, ValuePtr *const data) { + MS_LOG(DEBUG) << "Converting slice object"; + + py::slice slice_obj = obj.cast(); + auto convert_func = [obj](std::string attr) -> ValuePtr { + auto py_attr = py::getattr(obj, attr.c_str()); + if (py::isinstance(py_attr)) { + return kNone; + } else if (py::isinstance(py_attr)) { + int value = py::cast(py_attr); + return MakeValue(value); + } else { + MS_LOG(EXCEPTION) << "Slice should contain only int or none"; + } + }; + ValuePtr start = convert_func("start"); + ValuePtr stop = convert_func("stop"); + ValuePtr step = convert_func("step"); + *data = std::make_shared(start, stop, step); + return true; +} + +bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { + FuncGraphPtr func_graph = ConvertToFuncGraph(obj); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Parse resolve function error."; + return false; + } + // if the cell object has specified bprop, it has user-defined bprop function parse and record it + if (py::hasattr(obj, CUSTOM_BPROP_NAME)) { + FuncGraphPtr bprop_graph = nullptr; + bool enable_bprop_debug = py::cast(py::getattr(obj, "bprop_debug")); + if (enable_bprop_debug) { + bprop_graph = ConvertToBpropCut(obj); + } else { + bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); + } + if (bprop_graph != nullptr) { + (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); + (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); + func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true); + } + } + *data = func_graph; + return true; +} + +bool ConvertOtherObj(py::object obj, ValuePtr *const data) { + auto obj_type = data_converter::GetObjType(obj); + MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; + if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { + MS_LOG(DEBUG) << "Resolve the class type, need create class instance."; + std::string desc = py::str(obj); + // desc has format "", strip the '<' and '>' by offset 1; + *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); + return true; + } + if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) { + MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type; + FuncGraphPtr func_graph = ConvertToFuncGraph(obj); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Parse resolve function error."; + return false; + } + *data = func_graph; + return true; + } + if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { + // Create the namespace for common class instance + // When the obj is Cell, default parse the 'construct' + if (data_converter::IsCellInstance(obj)) { + return ConvertCellObjToFuncGraph(obj, data); + } + + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); + *data = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); + return true; + } + MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); + return false; +} +} // namespace + +bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) { + // check parameter valid + if (data == nullptr) { + MS_LOG(ERROR) << "Data is null pointer"; + return false; + } + + bool ret = true; + ValuePtr converted = nullptr; + if (py::isinstance(obj)) { + converted = kNone; + } else if (py::isinstance(obj)) { + converted = std::make_shared(py::cast(obj)); + } else if (py::isinstance(obj)) { + converted = std::make_shared(py::cast(obj)); + } else if (py::isinstance(obj)) { + converted = std::make_shared(py::cast(obj)); + } else if (py::isinstance(obj)) { + converted = std::make_shared(py::cast(obj)); + } else if (py::isinstance(obj)) { + ret = ConvertDict(obj, &converted, use_signature); + } else if (py::isinstance(obj)) { + ret = ConvertSlice(obj, &converted); + } else if (py::isinstance(obj)) { + converted = kEllipsis; + } else if (py::isinstance(obj)) { + ret = ConvertTuple(obj, &converted, use_signature); + } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { + ret = ConvertCellList(obj, &converted, use_signature); + } else if (py::isinstance(obj)) { + ret = ConvertList(obj, &converted, use_signature); + } else if (py::isinstance(obj)) { + ConvertNameSpace(obj, &converted); + } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) { + ConvertDataClass(obj, &converted); + } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) { + ret = ConvertPrimitive(obj, &converted, use_signature); + } else if (py::hasattr(obj, PYTHON_METAFUNCGRAPH_FLAG)) { + ret = ConvertMetaFuncGraph(obj, &converted, use_signature); + } else if (py::hasattr(obj, PYTHON_DTYPE_FLAG)) { + ret = ConvertDataType(obj, &converted); + } else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) { + ret = ConvertTensor(obj, &converted); + } else if (py::hasattr(obj, PYTHON_META_TENSOR_FLAG)) { + ret = ConvertMetaTensor(obj, &converted); + } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) { + std::shared_ptr env = obj.cast>(); + converted = env; + } else if (py::hasattr(obj, "__parameter__")) { + auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); + ret = ConvertData(to_convert, &converted); + } else { + ret = ConvertOtherObj(obj, &converted); + } + + *data = converted; + return ret; +} + +// convert data to graph +FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) { + std::vector results = data_converter::GetObjKey(obj); + std::string obj_id = results[0] + python_mod_get_parse_method; + std::string obj_key = results[1]; + FuncGraphPtr func_graph = nullptr; + Any value = Any(); + bool is_cache = data_converter::GetObjectValue(obj_id, &value); + if (is_cache) { + if (value.is()) { + MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; + func_graph = value.cast(); + return func_graph; + } + } + + func_graph = ParsePythonCode(obj, python_mod_get_parse_method); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Parse resolve function error."; + return nullptr; + } + + data_converter::MakeProperNameToFuncGraph(func_graph, obj_id); + data_converter::CacheObjectValue(obj_id, func_graph); + if (obj_key != "") { + MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString(); + data_converter::SetObjGraphValue(obj_key, func_graph); + } + + return func_graph; +} +namespace data_converter { +static std::unordered_map object_map_ = std::unordered_map(); + +static std::unordered_map> object_graphs_map_ = + std::unordered_map>(); + +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { + object_graphs_map_[obj_key].push_back(data); + MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size(); +} + +const std::unordered_map> &GetObjGraphs() { + MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size(); + return object_graphs_map_; +} + +void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; } +bool GetObjectValue(const std::string &obj_key, Any *const data) { + if (object_map_.count(obj_key)) { + *data = object_map_[obj_key]; + return true; + } + return false; +} +std::vector GetObjKey(const py::object &obj) { + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj); + if (obj_tuple.size() != 2) { + MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements"; + } + return {py::cast(obj_tuple[0]), py::cast(obj_tuple[1])}; +} + +// get obj detail type +ResolveTypeDef GetObjType(const py::object &obj) { + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + auto obj_type = + ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast()); + return obj_type; +} + +// get class instance detail type +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + auto class_type = + ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast()); + return class_type; +} + +// check the object is Cell Instance +bool IsCellInstance(const py::object &obj) { + auto class_type = GetClassInstanceType(obj); + bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); + return isCell; +} + +// create the python class instance +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + py::object obj; + if (params.size() == 0) { + obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type); + } else { + obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); + } + return obj; +} + +// Generate an appropriate name and set to graph debuginfo +// character <> can not used in the dot file, so change to another symbol +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(func_graph->debug_info()); + // set detail name info of function + std::ostringstream oss; + for (size_t i = 0; i < name.size(); i++) { + if (name[i] == '<') { + oss << "「"; + } else if (name[i] == '>') { + oss << "」"; + } else { + oss << name[i]; + } + } + func_graph->debug_info()->set_full_name(oss.str()); +} + +ValuePtr PyDataToValue(const py::object &obj) { + py::object to_convert = obj; + if (py::hasattr(obj, "__parameter__")) { + to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); + } + ValuePtr value = nullptr; + (void)ConvertData(to_convert, &value); + return value; +} + +void ClearObjectCache() { + object_map_.clear(); + object_graphs_map_.clear(); +} +} // namespace data_converter + +static std::unordered_map g_dataClassToClass = {}; + +// parse dataclass to mindspore Class type +ClassPtr ParseDataClass(const py::object &cls_obj) { + std::string cls_name = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__name__")); + std::string cls_module = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__module__")); + std::string cls = cls_module + "." + cls_name; + auto iterator = g_dataClassToClass.find(cls); + if (iterator != g_dataClassToClass.end()) { + return iterator->second; + } + + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + ClassAttrVector attributes; + py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); + for (auto &item : names) { + TypePtr type_value = item.second.cast(); + MS_EXCEPTION_IF_NULL(type_value); + MS_LOG(DEBUG) << "(Name: " << py::cast(item.first) << ", type: " << type_value->ToString() << ")"; + attributes.push_back(std::make_pair(py::cast(item.first), type_value)); + } + + std::unordered_map methods_map; + py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); + for (auto &item : methods) { + std::string fun_name = item.first.cast(); + py::object obj = py::cast(item.second); + std::shared_ptr method_obj = std::make_shared(obj, fun_name); + methods_map[fun_name] = method_obj; + } + + std::shared_ptr me_class = std::make_shared(Named(cls_name), attributes, methods_map); + // static Variable for cache + // cppcheck-suppress unreadVariable + g_dataClassToClass[cls] = me_class; + + return me_class; +} + +void CleanDataClassToClassMap() { g_dataClassToClass.clear(); } +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.h b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h new file mode 100644 index 0000000000..6632d4801e --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h @@ -0,0 +1,61 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_DATA_CONVERTER_H_ +#define PIPELINE_PARSE_DATA_CONVERTER_H_ + +#include +#include +#include +#include +#include +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parse { +// data convert for parse +namespace data_converter { +void CacheObjectValue(const std::string &obj_key, const Any &data); +bool GetObjectValue(const std::string &obj_key, Any *const data); + +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data); + +const std::unordered_map> &GetObjGraphs(); + +std::vector GetObjKey(const py::object &obj); +ResolveTypeDef GetObjType(const py::object &obj); +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); + +bool IsCellInstance(const py::object &obj); +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms); +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); +ValuePtr PyDataToValue(const py::object &obj); +void ClearObjectCache(); +} // namespace data_converter + +ClassPtr ParseDataClass(const py::object &cls_obj); +FuncGraphPtr ConvertToBpropCut(const py::object &obj); + +void CleanDataClassToClassMap(); + +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_DATA_CONVERTER_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc new file mode 100644 index 0000000000..b52dddda66 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -0,0 +1,374 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/parse/function_block.h" +#include +#include +#include +#include "pipeline/jit/parse/resolve.h" +#include "pipeline/jit/parse/parse.h" +#include "frontend/operator/ops.h" +#include "debug/info.h" +#include "debug/trace.h" +#include "pybind11/pybind11.h" + +namespace mindspore { +namespace py = pybind11; + +namespace parse { +FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) { + func_graph_ = std::make_shared(); + matured_ = false; +} + +void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); } + +// write variable records the variable name to corresponding node +void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) { + MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString(); + vars_[var_name] = node; +} + +// read variable from predecessors +AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { + // get var node if it is found + if (vars_.count(var)) { + AnfNodePtr node = vars_[var]; + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + return NewValueNode(GetValueNode(node)); + } else { + return node; + } + } + // get var from predecessor block ,if can't get the make a resolve node to it + if (matured_) { + // If only one predecessor block, read the definition of var from it. + if (prev_blocks_.size() == 1) { + auto block = prev_blocks_[0]; + MS_EXCEPTION_IF_NULL(block); + return block->ReadVariable(var); + } else if (prev_blocks_.empty()) { + // get namespace and make Reslove + return MakeResolveSymbol(var); + } + } + // If have more than one predecessor blocks then build a phi node. + auto debug_info = std::make_shared(); + debug_info->set_name(var); + TraceManager::DebugTrace(std::make_shared(debug_info)); + ParameterPtr phi_param = std::make_shared(func_graph()); + TraceManager::EndTrace(); + MS_LOG(DEBUG) << func_graph_->ToString() << " generate phi node " << phi_param->ToString() << " for " << var; + func_graph()->add_parameter(phi_param); + phi_nodes_[phi_param] = var; + WriteVariable(var, phi_param); + if (matured_) { + SetPhiArgument(phi_param); + } + return phi_param; +} + +// Resolve Ast operator node +AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) { + auto ast = parser_.ast(); + MS_EXCEPTION_IF_NULL(ast); + TraceGuard trace_guard(parser_.GetLocation(op)); + py::tuple namespace_var = ast->CallParserObjMethod(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op); + if (namespace_var.size() != 2) { + MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size(); + } + NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]); + SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); + return MakeResolve(name_space, symbol); +} + +// Resolve class member, two possible: method, member variable +AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { + py::object namespace_var = + parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj()); + NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); + SymbolPtr symbol = std::make_shared(attr); + return MakeResolve(name_space, symbol); +} + +// Make a resolve node for symbol string +AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { + if (value.compare(0, strlen("self."), "self.") == 0) { + auto start = value.find_first_of('.') + 1; + if (start >= value.size()) { + MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value; + return nullptr; + } + auto bits_str = value.substr(start); + return MakeResolveClassMember(bits_str); + } + py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value); + + NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_var[0]); + SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); + return MakeResolve(name_space, symbol); +} + +AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) { + py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value); + NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]); + SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); + return MakeResolve(name_space, symbol); +} + +AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) { + MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , " + << ((std::string)resolve_symbol->symbol()); + ValueNodePtr module_node = NewValueNode(name_space); + ValueNodePtr symbol_node = NewValueNode(resolve_symbol); + auto node = func_graph()->NewCNode({NewValueNode(prim::kPrimResolve), module_node, symbol_node}); + return node; +} + +// add input for the block's phi parameter +void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { + std::string var = phi_nodes_[phi]; + MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; + for (auto &pred : prev_blocks_) { + MS_EXCEPTION_IF_NULL(pred); + MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); + AnfNodePtr arg_node = pred->ReadVariable(var); + CNodePtr jump = pred->jumps_[this]; + jump->add_input(arg_node); + } + // If the phi node in the body part of a for/while loop is being removed, + // then the closure convert phase will generate a cycle in graph if the + // loop is kept after specialization. This should be investigate further. + // Just now user has to set a flag on a function to indicate the for loop + // will definitely can be unroll as the sequence in for statement is fixed + // size in compile time. + if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) || + parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) { + CollectRemovablePhi(phi); + } +} + +AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { + AnfNodePtr arg_node = nullptr; + for (auto &prev : prev_blocks_) { + MS_EXCEPTION_IF_NULL(prev); + AnfNodePtr temp_node = prev->ReadVariable(var); + MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var + << " is " << temp_node->DebugString(); + if (temp_node != phi) { + if (arg_node == nullptr) { + arg_node = temp_node; + MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() + << " may be replaced by node " << arg_node->DebugString(); + } else if (temp_node == arg_node) { + MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " is same as node " + << arg_node->DebugString(); + } else { + MS_LOG(DEBUG) << "phi " << phi->ToString() + << " cannot be removed as it assigns to different node. node1: " << arg_node->DebugString() + << ", node2: " << temp_node->DebugString(); + return nullptr; + } + } + } + return arg_node; +} + +// Check if there is removable unnecessary phi node in this graph. +// as per the FIRM TR 3.2, a phi node can be remove if: +// +// If all arguments of a φ-function are the same value s or the φfunction itself, +// then we remove the φ-function and let all users directly uses. We call such a +// φ-function obviously unnecessary. +// When we removed a φ-function p, then we recursively try to apply this simplification +// rule with all (former) users of p, because they may have become obviously unnecessary +// due to the removal of p +// +// phi node in graph will be removed after the whole function is parsed in a DFS visit +// of that graph.The reason is : +// 1. when this function is called, not all usage of this phi node had bound to the +// graph of this function block, some may stay in vars_ in other blocks. +// 2. it's costly to iterate the graph to replace the phi for each phi. +// Args : +// phi : This parameter node is functioning as a phi node. +void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { + MS_EXCEPTION_IF_NULL(phi); + std::string var = phi_nodes_[phi]; + MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); + if (prev_blocks_.size() == 0) { + MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); + return; + } + AnfNodePtr arg_node = SearchReplaceNode(var, phi); + if (arg_node != nullptr) { + MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with " + << arg_node->DebugString(); + // replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1." + WriteVariable(var, arg_node); + removable_phis_[phi] = arg_node; + // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized + // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. + for (auto &prev : prev_blocks_) { + MS_EXCEPTION_IF_NULL(prev); + if (!prev->matured_) { + continue; + } + for (auto &phi_iter : prev->removable_phis_) { + MS_EXCEPTION_IF_NULL(phi_iter.second); + if (phi_iter.second->isa()) { + const auto ¶m = phi_iter.second->cast(); + if (param == phi) { + MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() + << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); + prev->removable_phis_[phi_iter.first] = arg_node; + } + } + } + } + } +} + +// A block should be marked matured if its predecessor blocks have been processed +void FunctionBlock::Mature() { + const auto &graphParamVec = func_graph_->parameters(); + for (auto ¶mItr : graphParamVec) { + MS_EXCEPTION_IF_NULL(paramItr); + ParameterPtr param = paramItr->cast(); + if (phi_nodes_.find(param) != phi_nodes_.cend()) { + SetPhiArgument(param); + } + } + matured_ = true; +} + +// Force the conditIon node to bool using bool operation +CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) { + TraceManager::DebugTrace(std::make_shared(cond->debug_info())); + CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond}); + TraceManager::EndTrace(); + return op_apply_node; +} + +CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) { + TraceManager::DebugTrace(std::make_shared(cond->debug_info())); + CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond}); + TraceManager::EndTrace(); + return op_apply_node; +} + +// Perform a jump from this block to target block +void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) { + if (func_graph()->get_return() != nullptr) { + MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " + << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); + } + std::vector input_nodes; + input_nodes.emplace_back(NewValueNode(target_block->func_graph())); + if (node != nullptr) { + input_nodes.emplace_back(node); + } + + CNodePtr jump = func_graph()->NewCNode(input_nodes); + jumps_[target_block.get()] = jump; + target_block->AddPrevBlock(shared_from_this()); + func_graph()->set_output(jump); + InsertDependItemsBeforeReturn(); +} + +// Perform a conditional jump using switch operation. +// The first CNode select graph with condition, and than execute this graph +void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block, + const FunctionBlockPtr &false_block, bool unroll_loop) { + if (func_graph()->get_return() != nullptr) { + MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " + << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); + } + // Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch' + auto prim_switch = std::make_shared(prim::kPrimSwitch->name()); + if (!unroll_loop) { + prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0)); + } + CNodePtr switch_app = + func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()), + NewValueNode(false_block->func_graph())}); + CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); + func_graph()->set_output(switch_app_new); + InsertDependItemsBeforeReturn(); +} + +void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { + state_assign_[target] = readid; +} + +void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } + +void FunctionBlock::InsertDependItemsBeforeReturn() { + if (!prev_blocks_.empty()) { + for (auto &prev_block : prev_blocks_) { + MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get(); + } + } + + ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); + ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); + ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); + const std::string primitive_name("assign"); + const std::string module_name("mindspore.ops.functional"); + ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); + if (state_assign_.size() == 0 && auto_depends_.size() == 0) { + return; + } + AnfNodePtr state = nullptr; + std::vector vec_states; + vec_states.emplace_back(make_tuple_op); + for (auto &item : state_assign_) { + auto source = ReadVariable(item.second); + auto assign = func_graph()->NewCNode({assign_op, item.first, source}); + MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; + vec_states.emplace_back(assign); + } + for (auto &item : auto_depends_) { + MS_LOG(DEBUG) << "auto_depends " << item->ToString(); + vec_states.emplace_back(item); + } + // if there are only make_tuple_op and another node in vec_states(the vec_states size is 2) + // do not need to make_tuple, just use the node. + if (vec_states.size() == 2) { + state = vec_states[1]; + } else { + state = func_graph()->NewCNode(vec_states); + } + + AnfNodePtr old_ret = nullptr; + auto return_node = func_graph()->get_return(); + if (return_node) { + if (return_node->inputs().size() < 1) { + MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2"; + } + old_ret = return_node->input(1); + } else { + old_ret = NewValueNode(kNone); + } + AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state}); + AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped}); + func_graph()->set_output(ret, true); + state_assign_.clear(); +} +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h new file mode 100644 index 0000000000..cbf75a3dd8 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h @@ -0,0 +1,118 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_FUNCTION_BLOCK_H_ +#define PIPELINE_PARSE_FUNCTION_BLOCK_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "pipeline/jit/parse/parse_base.h" +#include "utils/log_adapter.h" +#include "utils/ordered_map.h" + +namespace mindspore { +namespace parse { + +class Parser; +class NameSpace; +class Symbol; +class FunctionBlock; +using FunctionBlockPtr = std::shared_ptr; + +// A function block is a straight-line code sequence with no branches, every block has one one exit point +// which is return. When parsing function, loop or branch , we use function block to track the structure of +// the original source code. +class FunctionBlock : public std::enable_shared_from_this { + public: + explicit FunctionBlock(const Parser &parser); + virtual ~FunctionBlock() {} + + FuncGraphPtr func_graph() { return func_graph_; } + void WriteVariable(const std::string &var_name, const AnfNodePtr &node); + AnfNodePtr ReadVariable(const std::string &var_name); + void AddPrevBlock(const FunctionBlockPtr &block); + void SetPhiArgument(const ParameterPtr &phi); + void CollectRemovablePhi(const ParameterPtr &phi); + // A block is matured if all its predecessors is generated + void Mature(); + CNodePtr ForceToBoolNode(const AnfNodePtr &cond); + CNodePtr ForceToWhileCond(const AnfNodePtr &cond); + void Jump(const FunctionBlockPtr &block, AnfNodePtr node); + AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi); + void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock, + bool unroll_loop = true); + // record the assign statement of self.xx weight parameter ,which will use state_setitem op + void SetStateAssgin(const AnfNodePtr &target, const std::string &readid); + void AddAutoDepend(const AnfNodePtr &target); + void InsertDependItemsBeforeReturn(); + void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } + bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } + AnfNodePtr MakeResolveAstOp(const py::object &op); + AnfNodePtr MakeResolveClassMember(std::string attr); + AnfNodePtr MakeResolveSymbol(const std::string &value); + AnfNodePtr MakeResolveOperation(const std::string &value); + AnfNodePtr MakeResolve(const std::shared_ptr &name_space, const std::shared_ptr &resolve_symbol); + const std::unordered_map &removable_phis() const { return removable_phis_; } + + private: + // block graph + FuncGraphPtr func_graph_; + + // the block's parser + const Parser &parser_; + + // A block is matured if all its prev_blocks is processed + bool matured_; + + // store the nest-level block + // refer to comments in Parser::func_block_list_; + std::vector prev_blocks_; + + // store args and variable's node + std::map vars_; + + // phi_nodes map the parameter node to variable, it can be resolved if the block's predecessors are processed + std::map phi_nodes_; + + // jumps map the successor block and the function call that perform jump + // refer to comments in Parser::func_block_list_ that how to break the cyclic reference + std::map jumps_; + + // keeps all removable phis which will be removed in one pass. + std::unordered_map removable_phis_; + + // set state nodes need to insert before function return nodes. + OrderedMap state_assign_; + + // hold declared global variables in function + std::set global_vars_; + + // other depend need to insert before function return nodes. + // summary or some other node + std::vector auto_depends_; +}; + +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_FUNCTION_BLOCK_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc new file mode 100644 index 0000000000..edc9a66594 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -0,0 +1,1604 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/parse/parse.h" +#include +#include +#include +#include +#include +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/operator/composite/composite.h" +#include "utils/context/ms_context.h" +#include "debug/trace.h" + +namespace mindspore { +namespace parse { + +FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) { + (void)python_adapter::set_python_scoped(); + + if (obj == nullptr || py::isinstance(obj)) { + MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none"; + return nullptr; + } + + auto ast = std::make_shared(obj); + bool success = ast->InitParseAstInfo(python_mod_get_parse_method); + if (!success) { + MS_LOG(ERROR) << "Parse code to ast tree failed."; + return nullptr; + } + + auto parser = std::make_shared(ast); + + FuncGraphPtr func_graph = parser->ParseFuncGraph(); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode(); + return nullptr; + } + + return func_graph; +} + +// if any mixed precision flag add a cast node after the parameter node. +AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { + TypePtr dst_type; + if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { + dst_type = kFloat32; + } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { + dst_type = kFloat16; + } else { + return param; + } + auto cast_helper = prim::kPrimMixedPrecisionCast; + auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param}); + return cast; +} + +FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr(); + +Parser::Parser(const std::shared_ptr &ast) : ast_(ast) { + errcode_ = PARSE_SUCCESS; + BuildMethodMap(); +} + +void Parser::BuildMethodMap() { + stmt_method_map_["Return"] = &Parser::ParseReturn; + stmt_method_map_["Expr"] = &Parser::ParseExpr; + stmt_method_map_["If"] = &Parser::ParseIf; + stmt_method_map_["Assign"] = &Parser::ParseAssign; + stmt_method_map_["While"] = &Parser::ParseWhile; + stmt_method_map_["For"] = &Parser::ParseFor; + stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; + stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; + stmt_method_map_["Global"] = &Parser::ParseGlobal; + stmt_method_map_["Break"] = &Parser::ParseBreak; + stmt_method_map_["Continue"] = &Parser::ParseContinue; + stmt_method_map_["Pass"] = &Parser::ParsePass; + expr_method_map_["NoneType"] = &Parser::ParseNone; + expr_method_map_["BinOp"] = &Parser::ParseBinOp; + expr_method_map_["Name"] = &Parser::ParseName; + expr_method_map_["Num"] = &Parser::ParseNum; + expr_method_map_["Str"] = &Parser::ParseStr; + expr_method_map_["NameConstant"] = &Parser::ParseNameConstant; + expr_method_map_["Call"] = &Parser::ParseCall; + expr_method_map_["IfExp"] = &Parser::ParseIfExp; + expr_method_map_["Attribute"] = &Parser::ParseAttribute; + expr_method_map_["Compare"] = &Parser::ParseCompare; + expr_method_map_["BoolOp"] = &Parser::ParseBoolOp; + expr_method_map_["Lambda"] = &Parser::ParseLambda; + expr_method_map_["Tuple"] = &Parser::ParseTuple; + expr_method_map_["List"] = &Parser::ParseList; + expr_method_map_["Subscript"] = &Parser::ParseSubscript; + expr_method_map_["Slice"] = &Parser::ParseSlice; + expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice; + expr_method_map_["Index"] = &Parser::ParseIndex; + expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; + expr_method_map_["Dict"] = &Parser::ParseDict; + expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; +} + +void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } + +void Parser::InitParserEnvironment(const py::object &obj) { + Parser::top_func_graph_ = FuncGraphWeakPtr(); + ScopeManager::GetInstance().ClearScope(); + (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj); +} + +void Parser::CleanParserResource() { + Parser::top_func_graph_ = FuncGraphWeakPtr(); + ScopeManager::GetInstance().ClearScope(); +} + +FuncGraphPtr Parser::ParseFuncGraph() { + // get ast FunctionDef node + py::object node = ast_->GetAstNode(); + FunctionBlockPtr pFnBlock = ParseFunction(node); + if (errcode() != PARSE_SUCCESS) { + MS_LOG(ERROR) << "Parse function error, code is " << errcode(); + return nullptr; + } + + RemoveUnnecessaryPhis(); + + MS_EXCEPTION_IF_NULL(pFnBlock); + return pFnBlock->func_graph(); +} + +void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { + py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args"); + py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg"); + block->func_graph()->set_has_vararg(!py::isinstance(var_arg_node)); + + py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg"); + block->func_graph()->set_has_kwarg(!py::isinstance(kw_arg_node)); + + py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs"); + block->func_graph()->set_kwonlyargs_count(SizeToInt(kwonly_args.size())); + + MS_EXCEPTION_IF_NULL(ast_); + py::list args = ast_->GetArgs(fn_node); + for (std::size_t i = 0; i < args.size(); i++) { + std::string arg_name = py::cast(args[i].attr("arg")); + if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { + if (arg_name == "self") { + continue; + } + } + TraceManager::DebugTrace(GetLocation(args[i])); + auto para_node = std::make_shared(block->func_graph()); + MS_EXCEPTION_IF_NULL(para_node); + TraceManager::EndTrace(); + para_node->set_name(arg_name); + para_node->debug_info()->set_name(arg_name); + block->func_graph()->add_parameter(para_node); + AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node); + block->WriteVariable(arg_name, para_after_cast); + MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name; + } +} + +void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { + py::list defaults = ast_->GetArgsDefaultValues(fn_node); + py::list args = ast_->GetArgs(fn_node); + std::vector namelist_for_default_value; + std::vector default_values; + for (std::size_t i = 0; i < args.size(); i++) { + std::string arg_name = py::cast(args[i].attr("arg")); + if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { + if (arg_name == "self") { + continue; + } + } + + namelist_for_default_value.push_back(arg_name); + if (py::isinstance(defaults[i])) { + default_values.push_back(NewValueNode(kNull)); + } else { + default_values.push_back(ParseExprNode(block, defaults[i])); + } + } + block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values); +} + +ScopePtr Parser::GetScopeForParseFunction() { + ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope(); + if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { + py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj()); + if (!py::isinstance(scope_str)) { + auto scope_name = py::cast(scope_str); + scope = std::make_shared(scope_name); + } + } + return scope; +} + +FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) { + ScopePtr scope = GetScopeForParseFunction(); + // the node created in the parsefunction context, will inherit the scope created using scope_guard + ScopeGuard scope_guard(scope); + TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node)); + FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this); + if (block != nullptr) { + pFunBlock->AddPrevBlock(block); + } else { + func_graph_ = pFunBlock->func_graph(); + } + pFunBlock->Mature(); + auto current_fg = pFunBlock->func_graph(); + auto function_name = py::cast(python_adapter::GetPyObjAttr(node, "name")); + MS_LOG(DEBUG) << "The function name is " << function_name; + + current_fg->debug_info()->set_name(function_name); + MS_EXCEPTION_IF_NULL(ast_); + py::list deco_list = node.attr("decorator_list"); + if (deco_list.size() > 0) { + current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); + } + + bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg); + if (ast_->obj() != ast_->function()) { + set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg); + } + + if (!set_flag) { + MS_LOG(ERROR) << "Set flags failed"; + return nullptr; + } + GenerateArgsNodeForFunction(pFunBlock, node); + + // when parsing the top graph of construct, save the top graph + if (GetTopFuncGraph() == nullptr) { + UpdateTopFuncGraph(pFunBlock->func_graph()); + } + + // save the function node to block + pFunBlock->WriteVariable(function_name, NewValueNode(current_fg)); + + py::object funcObj = python_adapter::GetPyObjAttr(node, "body"); + (void)ParseStatements(pFunBlock, funcObj); + + if (current_fg->get_return() == nullptr) { + MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString(); + errcode_ = PARSE_NO_RETURN; + return pFunBlock; + } + GenerateArgsDefaultValueForFunction(pFunBlock, node); + return pFunBlock; +} + +FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) { + py::int_ pcount = python_adapter::CallPyObjMethod(nodes, "__len__"); + size_t count = IntToSize(pcount); + MS_LOG(DEBUG) << "The nodes count is " << count; + for (size_t i = 0; i < count; i++) { + auto node = py::cast(nodes)[i]; + TraceManager::DebugTrace(GetLocation(node)); + fn_block = ParseStatement(fn_block, node); + TraceManager::EndTrace(); + // insert appropriate depended items for the function block if it has a return node + if (fn_block->func_graph()->get_return() != nullptr) { + fn_block->InsertDependItemsBeforeReturn(); + // Skip statements after 'return' (or 'break', 'continue'). + break; + } + } + return fn_block; +} + +FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) { + auto node_type = ast_->GetNodeType(node); + + // check the node type + AstMainType nodeType = node_type->main_type(); + if (nodeType != AST_MAIN_TYPE_STMT) { + MS_LOG(INFO) << "Node type is error : " << nodeType; + return block; + } + // call the process function + std::string node_name = node_type->node_name(); + MS_LOG(DEBUG) << "Ast node is " << node_name; + if (stmt_method_map_.count(node_name)) { + TraceManager::DebugTrace(GetLocation(node)); + auto stmt_block = (this->*stmt_method_map_[node_name])(block, node); + TraceManager::EndTrace(); + return stmt_block; + } else { + errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; + } +} + +AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast expr"; + auto node_type = ast_->GetNodeType(node); + // check the node type + AstMainType node_main_type = node_type->main_type(); + if (node_main_type != AST_MAIN_TYPE_EXPR) { + MS_LOG(ERROR) << "Node type is error : " << node_main_type; + errcode_ = PARSE_NODE_TYPE_NO_MATCH; + return nullptr; + } + // call the process function + std::string node_name = node_type->node_name(); + MS_LOG(DEBUG) << "Ast node is " << node_name; + if (expr_method_map_.count(node_name)) { + TraceManager::DebugTrace(GetLocation(node)); + auto expr_node = (this->*expr_method_map_[node_name])(block, node); + TraceManager::EndTrace(); + return expr_node; + } else { + errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; + py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + auto filename = ret[0].cast(); + auto line_no = ret[1].cast(); + MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; + } +} + +// process the expr statement and expand it +// eg: x.append(y) -> x = x.append(y) +FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Expr"; + // Expr only have value , no target + py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node); + + // refer python function expand_expr_statement, expand_info is one of the following: + // True, expr.value, x + // True, expr.value + // False, None, None + // check the expand info result + auto is_expand = py::cast(expand_info[0]); + if (is_expand) { + // process the expr statement + py::object value_object = expand_info[1]; + AnfNodePtr value_node = ParseExprNode(block, value_object); + + if (py::len(expand_info) == 2) { + // add to depend list and insert before output + block->AddAutoDepend(value_node); + } else { + // expand the assign statement + py::object target_node = expand_info[2]; + WriteAssignVars(block, target_node, value_node); + } + } + return block; +} + +LocationPtr Parser::GetLocation(const py::object &node) const { + MS_EXCEPTION_IF_NULL(ast_); + py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + if (ret.size() < 5) { + MS_LOG(EXCEPTION) << "List size should not be less than 5."; + } + // refer to Location::Location() for each member of ret: line, column, line_end, column_end. + auto location = std::make_shared(ret[0].cast(), ret[1].cast(), ret[2].cast(), + ret[3].cast(), ret[4].cast()); + return location; +} + +void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block, + const FunctionBlockPtr &false_block) { + true_block->AddPrevBlock(pre_block); + true_block->Mature(); + + false_block->AddPrevBlock(pre_block); + false_block->Mature(); +} + +FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast return"; + MS_EXCEPTION_IF_NULL(block); + // create return valuenode + AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn); + // parse the return Statements value + py::object value = python_adapter::GetPyObjAttr(node, "value"); + AnfNodePtr pReturnStatementNode = ParseExprNode(block, value); + // Create the cnode + CNodePtr pReturnCNode = block->func_graph()->NewCNode({pReturnValueNode, pReturnStatementNode}); + + block->func_graph()->set_return(pReturnCNode); + + return block; +} + +// Process binary operators,eg: `a + b`, `a | b`, etc. +AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast BinOP"; + + py::object left = python_adapter::GetPyObjAttr(node, "left"); + py::object right = python_adapter::GetPyObjAttr(node, "right"); + py::object op = python_adapter::GetPyObjAttr(node, "op"); + // create left and right ANF node + AnfNodePtr left_node = ParseExprNode(block, left); + if (left_node == nullptr) { + MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode(); + return nullptr; + } + AnfNodePtr right_node = ParseExprNode(block, right); + if (right_node == nullptr) { + MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode(); + return nullptr; + } + // resolve the op + AnfNodePtr op_node = block->MakeResolveAstOp(op); + // create apply node + return block->func_graph()->NewCNode({op_node, left_node, right_node}); +} + +AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Name"; + auto name_id = py::cast(python_adapter::GetPyObjAttr(node, "id")); + MS_LOG(DEBUG) << "The Name id is " << name_id; + TraceGuard trace_guard(GetLocation(node)); + if (block->IsGlobalVar(name_id)) { + return block->MakeResolveSymbol(name_id); + } + return block->ReadVariable(name_id); +} + +AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) { + MS_LOG(DEBUG) << "Process ast NoneType"; + return NewValueNode(kNone); +} + +AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) { + MS_LOG(DEBUG) << "Process ast Ellipsis"; + return NewValueNode(kEllipsis); +} + +AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Num"; + py::object obj = python_adapter::GetPyObjAttr(node, "n"); + TraceGuard trace_guard(GetLocation(node)); + if (py::isinstance(obj)) { + MS_LOG(INFO) << "The Num is int:" << (std::string)py::str(obj); + auto data = py::cast(obj); + return NewValueNode(data); + } else if (py::isinstance(obj)) { + MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj); + auto data = py::cast(obj); + return NewValueNode(data); + } else { + // no else actually + MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj) << GetLocation(node)->ToString(); + errcode_ = PARSE_NODE_TYPE_UNKOWN; + return nullptr; + } +} + +AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Str"; + auto str_s = py::cast(python_adapter::GetPyObjAttr(node, "s")); + return NewValueNode(str_s); +} + +AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) { + MS_LOG(DEBUG) << "Process ast NameConstant"; + py::object obj = python_adapter::GetPyObjAttr(node, "value"); + TraceGuard trace_guard(GetLocation(node)); + if (py::isinstance(obj)) { + MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj); + auto data = py::cast(obj); + return NewValueNode(data); + } else if (py::isinstance(obj)) { + MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj); + return NewValueNode(kNone); + } else { + // no else actually + MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj) << GetLocation(node)->ToString(); + errcode_ = PARSE_NODE_TYPE_UNKOWN; + return nullptr; + } +} +AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes) { + AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); + std::vector make_tuple_nodes; + make_tuple_nodes.push_back(make_tuple_op); + (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes), + [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); + return block->func_graph()->NewCNode(make_tuple_nodes); +} +// process function call, eg : f1(x, y) ... +AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Call"; + // process function call + py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func"); + AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node); + // function call arguments should be passed in as groups and unpacked later using unpack call + py::list args = python_adapter::GetPyObjAttr(node, "args"); + std::vector packed_arguments; + std::vector group_arguments; + + bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments); + bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments); + // if there is stared or keyword argument, unpack may be needed + bool need_unpack = need_unpack_args || need_unpack_keywords; + + return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack); +} + +AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, + const std::vector &packed_arguments, + const std::vector &group_arguments, bool need_unpack) const { + // if there is keyword arguments or starred, using an unpack_call op to unpack the argument + if (need_unpack) { + std::vector unpack_call_nodes; + auto unpack_call_op = NewValueNode(std::make_shared(NAMED_METAGRAPH_UNPACKCALL)); + unpack_call_nodes.push_back(unpack_call_op); + unpack_call_nodes.push_back(call_function_anf_node); + (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), + [](AnfNodePtr node) -> AnfNodePtr { return node; }); + CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes); + return unpack_call; + } + // else there is no keyword arguments and starred, parsed as normal arguments without unpack + std::vector func_call_nodes; + func_call_nodes.push_back(call_function_anf_node); + (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes), + [](AnfNodePtr node) -> AnfNodePtr { return node; }); + CNodePtr call_anf_node = block->func_graph()->NewCNode(func_call_nodes); + return call_anf_node; +} + +bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, + std::vector *packed_arguments, std::vector *group_arguments) { + bool need_unpack = false; + for (size_t i = 0; i < args.size(); i++) { + auto arg_node = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[i]))); + if (arg_node == AST_SUB_TYPE_STARRED) { + if (!group_arguments->empty()) { + packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); + } + packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value"))); + group_arguments->clear(); + need_unpack = true; + } else { + group_arguments->push_back(ParseExprNode(block, args[i])); + } + } + if (!group_arguments->empty()) { + packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); + } + return need_unpack; +} + +bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, + std::vector *packed_arguments) { + bool need_unpack = false; + py::list keywords = python_adapter::GetPyObjAttr(node, "keywords"); + if (!keywords.empty()) { + need_unpack = true; + std::vector keys; + std::vector values; + for (size_t index = 0; index < keywords.size(); index++) { + auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg"); + auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value"); + if (py::isinstance(kw_key)) { + packed_arguments->push_back(ParseExprNode(block, kw_value)); + } else { + auto kw_key_c = kw_key.cast(); + keys.push_back(NewValueNode(kw_key_c)); + values.push_back(ParseExprNode(block, kw_value)); + } + } + auto keys_tuple = GenerateMakeTuple(block, keys); + auto values_tuple = GenerateMakeTuple(block, values); + auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); + std::vector make_dict_nodes; + make_dict_nodes.push_back(make_dict_op); + make_dict_nodes.push_back(keys_tuple); + make_dict_nodes.push_back(values_tuple); + packed_arguments->push_back(block->func_graph()->NewCNode(make_dict_nodes)); + } + return need_unpack; +} + +// process call attributes of class type define, eg: x.y() +AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Attribute"; + + // process class value,eg: self.xx + if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { + if (ast_->IsClassMember(node)) { + std::string var_name = "self."; + std::string attr_name = node.attr("attr").cast(); + (void)var_name.append(attr_name); + auto attr_obj = ast()->obj().attr(attr_name.c_str()); + if (py::hasattr(ast()->obj(), attr_name.c_str()) && + (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance(attr_obj) || + py::isinstance(attr_obj) || py::isinstance(attr_obj) || + py::isinstance(attr_obj) || data_converter::IsCellInstance(attr_obj))) { + return block->MakeResolveSymbol(var_name); + } else { + return block->ReadVariable(var_name); + } + } + } + + // process the get attr + // Use the Primitive replace the operation resolve node (getattr) + // because the getattr will eventually be converted to Primitive node + AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr); + + // process the attr body + py::object value_body = python_adapter::GetPyObjAttr(node, "value"); + AnfNodePtr value_node = ParseExprNode(block, value_body); + if (value_node == nullptr) { + MS_LOG(WARNING) << "Parse attribute failed"; + return nullptr; + } + + // process the node attr + auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast(); + MS_LOG(DEBUG) << "Attr = " << attr_str; + TraceManager::DebugTrace(GetLocation(python_adapter::GetPyObjAttr(node, "attr"))); + AnfNodePtr attr_node = NewValueNode(attr_str); + TraceManager::EndTrace(); + + // create the apply node + return block->func_graph()->NewCNode({op_node, value_node, attr_node}); +} + +// Process comparison expression : a == b. a > b etc. +AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Compare"; + + // for python comparison ,there may be if x>y>5 , + // which there is two ops , but we only support one now + py::list ops = python_adapter::GetPyObjAttr(node, "ops"); + if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) { + MS_LOG(ERROR) << "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size(); + return nullptr; + } + + py::object left = python_adapter::GetPyObjAttr(node, "left"); + py::list comparators = python_adapter::GetPyObjAttr(node, "comparators"); + AnfNodePtr left_node = ParseExprNode(block, left); + AnfNodePtr right_node = ParseExprNode(block, comparators[0]); + + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]); + + return block->func_graph()->NewCNode({op_node, left_node, right_node}); +} + +AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, + const py::object &op) { + // if there is only one bool op now + if (value_list.size() == 1) { + AnfNodePtr first_node = ParseExprNode(block, value_list[0]); + return first_node; + } else { + py::object first = value_list[0]; + py::list rest; + for (size_t i = 1; i < value_list.size(); i++) { + rest.append(value_list[i]); + } + + AnfNodePtr first_node = ParseExprNode(block, first); + AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op); + auto op_node = block->MakeResolveAstOp(op); + return block->func_graph()->NewCNode({op_node, first_node, rest_node}); + } +} + +// Process comparison expression : a and b. a or b . +AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast BoolOp"; + py::object op_node = python_adapter::GetPyObjAttr(node, "op"); + py::list op_values = python_adapter::GetPyObjAttr(node, "values"); + return ProcessBoolOpValueList(block, op_values, op_node); +} + +// Process a function def +FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast FunctionDef"; + FunctionBlockPtr function_block = ParseFunction(node, block); + MS_EXCEPTION_IF_NULL(function_block); + + // get function name + py::str name = python_adapter::GetPyObjAttr(node, "name"); + std::string function_name = name; + ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph()); + block->WriteVariable(function_name, valuenode_graph); + return block; +} + +// Process a lambda expression . like lambda x,y: x + y +AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Lambda"; + FunctionBlockPtr func_block = MakeFunctionBlock(*this); + func_block->AddPrevBlock(block); + func_block->Mature(); + + // get lambda args + py::list args = ast_->GetArgs(node); + for (std::size_t i = 0; i < args.size(); i++) { + std::string arg = py::cast(args[i].attr("arg")); + TraceManager::DebugTrace(GetLocation(args[i])); + auto para_node = std::make_shared(func_block->func_graph()); + TraceManager::EndTrace(); + para_node->debug_info()->set_name(arg); + func_block->func_graph()->add_parameter(para_node); + func_block->WriteVariable(arg, para_node); + MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg; + } + + py::object body_node = python_adapter::GetPyObjAttr(node, "body"); + AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node); + func_block->func_graph()->set_output(lambda_body_node); + ValueNodePtr const_graph = NewValueNode(func_block->func_graph()); + return const_graph; +} + +// process a tuple +AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Tuple"; + MS_EXCEPTION_IF_NULL(block); + py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); + if (elts.size() == 0) { + auto empty_tuple = std::vector(); + return NewValueNode(std::make_shared(empty_tuple)); + } + + std::vector tuple_vec; + AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); + tuple_vec.emplace_back(make_tuple_op); + for (size_t i = 0; i < elts.size(); i++) { + AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); + tuple_vec.emplace_back(node_ptr); + } + CNodePtr tuple_app = block->func_graph()->NewCNode(tuple_vec); + return tuple_app; +} + +// process a list +AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast List"; + MS_EXCEPTION_IF_NULL(block); + py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); + if (elts.size() == 0) { + auto empty_list = std::vector(); + return NewValueNode(std::make_shared(empty_list)); + } + + std::vector list_vec; + AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST); + list_vec.emplace_back(make_list_op); + for (size_t i = 0; i < elts.size(); i++) { + AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); + list_vec.emplace_back(node_ptr); + } + CNodePtr list_app = block->func_graph()->NewCNode(list_vec); + return list_app; +} + +// process a subscript, such as x[y] , node expressed as value[slice] +AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Subscript"; + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); + py::object value_node = python_adapter::GetPyObjAttr(node, "value"); + py::object slice_node = python_adapter::GetPyObjAttr(node, "slice"); + AnfNodePtr value = ParseExprNode(block, value_node); + AnfNodePtr slice = ParseExprNode(block, slice_node); + + return block->func_graph()->NewCNode({op_getitem, value, slice}); +} + +// process a slice, get the slice value +AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Slice"; + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE); + py::object start = python_adapter::GetPyObjAttr(node, "lower"); + py::object stop = python_adapter::GetPyObjAttr(node, "upper"); + py::object step = python_adapter::GetPyObjAttr(node, "step"); + AnfNodePtr start_node = ParseExprNode(block, start); + AnfNodePtr stop_node = ParseExprNode(block, stop); + AnfNodePtr step_node = ParseExprNode(block, step); + + return block->func_graph()->NewCNode({op_makeslice, start_node, stop_node, step_node}); +} + +// process a extslice +AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast ExtSlice"; + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); + py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims"); + + std::vector node_vec; + node_vec.emplace_back(make_tuple_op); + for (size_t i = 0; i < slice_tuple.size(); i++) { + AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]); + node_vec.emplace_back(node_ptr); + } + CNodePtr tuple_conde = block->func_graph()->NewCNode(node_vec); + return tuple_conde; +} + +// process a index, get the index number +AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Index"; + py::object value_node = python_adapter::GetPyObjAttr(node, "value"); + return ParseExprNode(block, value_node); +} + +// process a UnaryOp, +a, -b +AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast UnaryOp"; + py::object op = python_adapter::GetPyObjAttr(node, "op"); + + MS_EXCEPTION_IF_NULL(block); + // resolve the op + AnfNodePtr op_node = block->MakeResolveAstOp(op); + + py::object operand = python_adapter::GetPyObjAttr(node, "operand"); + AnfNodePtr operand_node = ParseExprNode(block, operand); + return block->func_graph()->NewCNode({op_node, operand_node}); +} + +// process a dict ast node expression +AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Dict"; + py::list keys = node.attr("keys"); + py::list values = node.attr("values"); + std::vector key_nodes; + std::vector value_nodes; + for (size_t i = 0; i < keys.size(); i++) { + key_nodes.push_back(ParseExprNode(block, keys[i])); + value_nodes.push_back(ParseExprNode(block, values[i])); + } + auto keys_tuple = GenerateMakeTuple(block, key_nodes); + auto values_tuple = GenerateMakeTuple(block, value_nodes); + MS_EXCEPTION_IF_NULL(block); + auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); + return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple}); +} + +// process a augment assign such as a += b; +FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast AugAssign"; + py::object op = python_adapter::GetPyObjAttr(node, "op"); + + MS_EXCEPTION_IF_NULL(block); + // resolve the op + AnfNodePtr op_node = block->MakeResolveAstOp(op); + py::object target_node = python_adapter::GetPyObjAttr(node, "target"); + MS_EXCEPTION_IF_NULL(ast_); + auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_node))); + AnfNodePtr read_node = nullptr; + if (ast_type == AST_SUB_TYPE_NAME) { + read_node = ParseName(block, target_node); + } else if (ast_->IsClassMember(target_node)) { + read_node = ParseAttribute(block, target_node); + } else { + MS_LOG(EXCEPTION) << "Not supported augassign"; + } + if (read_node == nullptr) { + MS_LOG(EXCEPTION) << "Can not get target node "; + } + + py::object value = python_adapter::GetPyObjAttr(node, "value"); + AnfNodePtr value_node = ParseExprNode(block, value); + CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, read_node, value_node}); + WriteAssignVars(block, target_node, augassign_app); + return block; +} + +// process global declaration such as 'global x'; +FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Global"; + MS_EXCEPTION_IF_NULL(block); + py::list vars = python_adapter::GetPyObjAttr(node, "names"); + for (auto &item : vars) { + block->AddGlobalVar(py::cast(item)); + } + return block; +} + +// process a if statement +FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast If"; + py::object test_node = python_adapter::GetPyObjAttr(node, "test"); + AnfNodePtr condition_node = ParseExprNode(block, test_node); + MS_EXCEPTION_IF_NULL(block); + CNodePtr bool_node = block->ForceToBoolNode(condition_node); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr true_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr false_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + MakeConditionBlocks(block, true_block, false_block); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr after_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + // process the if-true branch + py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); + FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode); + + // if the return_ is set ,it has its own continuation block + if (true_end->func_graph()->get_return() == nullptr) { + true_end->Jump(after_block, nullptr); + } + + // process the orelse branch + py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); + FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode); + + // if the return_ is set ,it has its own continuation block + if (false_end->func_graph()->get_return() == nullptr) { + false_end->Jump(after_block, nullptr); + } + + block->ConditionalJump(bool_node, true_block, false_block); + after_block->Mature(); + return after_block; +} + +FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast While"; + MS_EXCEPTION_IF_NULL(block); + MS_LOG(INFO) << "Parse while statement"; + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr header_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr body_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr after_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + body_block->AddPrevBlock(header_block); + after_block->AddPrevBlock(header_block); + block->Jump(header_block, nullptr); + + py::object test_node = python_adapter::GetPyObjAttr(node, "test"); + AnfNodePtr condition_node = ParseExprNode(header_block, test_node); + condition_node = header_block->ForceToWhileCond(condition_node); + body_block->Mature(); + header_block->ConditionalJump(condition_node, body_block, after_block); + + // Parse loop body statements with loop context. + LoopContext loop_context{&loops_, header_block, nullptr}; + py::object body_node = python_adapter::GetPyObjAttr(node, "body"); + FunctionBlockPtr after_body = ParseStatements(body_block, body_node); + if (after_body->func_graph()->get_return() == nullptr) { + after_body->Jump(header_block, nullptr); + } + + header_block->Mature(); + after_block->Mature(); + auto &end_block = loop_context.EndBlock(); + if (end_block) { + // end_block exists if we encounter 'break' in loop body. + after_block->Jump(end_block, nullptr); + end_block->Mature(); + return end_block; + } + // No 'break', no end_block. + return after_block; +} + +CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node, + const AnfNodePtr &op_iter) { + py::object iter_node = python_adapter::GetPyObjAttr(node, "iter"); + AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node); + return block->func_graph()->NewCNode({op_iter, iter_anf_node}); +} + +CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, + const AnfNodePtr &op_hasnext) { + MS_EXCEPTION_IF_NULL(header_block); + return header_block->func_graph()->NewCNode({op_hasnext, iter_param}); +} + +FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { + TraceManager::DebugTrace(trace_info); + FunctionBlockPtr body_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + return body_block; +} + +// A for loop will generate 3 functions :the test, the body, and the continuation +// for x in xs: +// body +// it is compiled to be following statement +// if len(xs) < max_loop_cnt: +// ParseForIter() // use iter to implement for loop, which always unroll loop +// else: +// ParseForLoop() // use loop var to implement for loop, which always sink loop +FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast For, create an if else statement"; + MS_EXCEPTION_IF_NULL(block); + // create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT' + AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); + py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); + AnfNodePtr iter_node = ParseExprNode(block, iter_obj); + CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); + CNodePtr bool_node = block->func_graph()->NewCNode( + {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)}); + + // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr true_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr false_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + MakeConditionBlocks(block, true_block, false_block); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr after_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + FunctionBlockPtr true_end = ParseForIter(true_block, node); + true_end->Jump(after_block, nullptr); + + FunctionBlockPtr false_end = ParseForLoop(false_block, node); + false_end->Jump(after_block, nullptr); + + block->ConditionalJump(bool_node, true_block, false_block); + after_block->Mature(); + return after_block; +} + +// A for loop will generate 3 functions :the test, the body, and the continuation +// for x in xs: +// body +// it is compiled to be following statement +// it = iter(xs) +// while hastnext(it) +// x, it = next(it) +// body +FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast For"; + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER); + AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT); + AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); + AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT); + // generate the iterator apply + CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter); + MS_EXCEPTION_IF_NULL(iter_apply); + FunctionBlockPtr header_block = + GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); + MS_EXCEPTION_IF_NULL(header_block); + // generate the hasnext apply which is a condition + ParameterPtr iter_param = header_block->func_graph()->add_parameter(); + CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext); + // generate the body of the for statement + FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); + MS_EXCEPTION_IF_NULL(body_block); + body_block->AddPrevBlock(header_block); + // generate the iterator next apply + // process as following: `app = next(it); target = app[0]; it = app[1];` + CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param}); + CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)}); + py::object target_node = python_adapter::GetPyObjAttr(node, "target"); + + CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)}); + WriteAssignVars(body_block, target_node, target_app); + + // link the variable name with the target + auto it_info = std::make_shared(target_app->debug_info()); + iter_param->debug_info()->set_trace_info(it_info); + iter2_app->debug_info()->set_trace_info(it_info); + iter_apply->debug_info()->set_trace_info(it_info); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr after_block = MakeFunctionBlock(*this); + MS_EXCEPTION_IF_NULL(after_block); + TraceManager::EndTrace(); + after_block->AddPrevBlock(header_block); + + block->Jump(header_block, iter_apply); + body_block->Mature(); + header_block->ConditionalJump(cond_apply, body_block, after_block); + + // Parse loop body statements with loop context. + LoopContext loop_context{&loops_, header_block, iter2_app}; + py::object body_node = python_adapter::GetPyObjAttr(node, "body"); + FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); + if (after_body_block->func_graph()->get_return() == nullptr) { + after_body_block->Jump(header_block, iter2_app); + } + + header_block->Mature(); + after_block->Mature(); + auto &end_block = loop_context.EndBlock(); + if (end_block) { + // end_block exists if we encounter 'break' in loop body. + after_block->Jump(end_block, nullptr); + end_block->Mature(); + return end_block; + } + // No 'break', no end_block. + return after_block; +} + +// A for loop will generate 3 functions :the test, the body, and the continuation +// for x in xs: +// body +// it is compiled to be following statement +// i = 0 +// while i < len(xs) +// x = xs[i] +// i = i + 1 +// body +FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast For by loop variable"; + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); + AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); + + // get varibale name of 'x' in statement 'for x in xs' + py::object target_node = python_adapter::GetPyObjAttr(node, "target"); + + // create statement 'len(xs)' + py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); + AnfNodePtr iter_node = ParseExprNode(block, iter_obj); + MS_EXCEPTION_IF_NULL(iter_node); + CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); + + FunctionBlockPtr header_block = + GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); + MS_EXCEPTION_IF_NULL(header_block); + // create loop variable 'i' + ParameterPtr loop_var = header_block->func_graph()->add_parameter(); + // create loop condition 'i < len(xs)' + CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter}); + + // generate the body of the for statement + FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); + MS_EXCEPTION_IF_NULL(body_block); + body_block->AddPrevBlock(header_block); + // create 'x = xs[i]' + CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var}); + WriteAssignVars(body_block, target_node, target_var); + // create 'i = i + 1' + CNodePtr loop_var_inc = + body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)}); + body_block->WriteVariable(loop_var->name(), loop_var_inc); + + // link the variable name with the target + auto it_info = std::make_shared(loop_var_inc->debug_info()); + loop_var->debug_info()->set_trace_info(it_info); + len_iter->debug_info()->set_trace_info(it_info); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr after_block = MakeFunctionBlock(*this); + MS_EXCEPTION_IF_NULL(after_block); + TraceManager::EndTrace(); + after_block->AddPrevBlock(header_block); + + block->Jump(header_block, NewValueNode(0)); + body_block->Mature(); + + header_block->ConditionalJump(cond_node, body_block, after_block, false); + + // Parse loop body statements with loop context. + LoopContext loop_context{&loops_, header_block, loop_var_inc}; + py::object body_node = python_adapter::GetPyObjAttr(node, "body"); + FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); + if (after_body_block->func_graph()->get_return() == nullptr) { + after_body_block->Jump(header_block, loop_var_inc); + } + + header_block->Mature(); + after_block->Mature(); + auto &end_block = loop_context.EndBlock(); + if (end_block) { + // end_block exists if we encounter 'break' in loop body. + after_block->Jump(end_block, nullptr); + end_block->Mature(); + return end_block; + } + // No 'break', no end_block. + return after_block; +} + +AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast IfExp"; + MS_EXCEPTION_IF_NULL(block); + py::object test_node = python_adapter::GetPyObjAttr(node, "test"); + AnfNodePtr condition_node = ParseExprNode(block, test_node); + CNodePtr bool_node = block->ForceToBoolNode(condition_node); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr true_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr false_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + MakeConditionBlocks(block, true_block, false_block); + + // process the if-true branch + py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); + true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode)); + AnfNodePtr true_node = ParseExprNode(true_block, bodyNode); + + // process the orelse branch + py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); + false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode)); + AnfNodePtr false_node = ParseExprNode(false_block, orelseNode); + + true_block->func_graph()->set_output(true_node); + false_block->func_graph()->set_output(false_node); + + // Use the Primitive replace the operation resolve node (switch) + // because the switch will eventually be converted to Primitive node + CNodePtr switch_app = + block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()), + NewValueNode(false_block->func_graph())}); + + std::vector call_graph_nodes{switch_app}; + CNodePtr switch_app_call = block->func_graph()->NewCNode(call_graph_nodes); + return switch_app_call; +} + +void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) { + MS_EXCEPTION_IF_NULL(block); + MS_EXCEPTION_IF_NULL(assigned_node); + py::str name = python_adapter::GetPyObjAttr(targ, "id"); + std::string name_id = name; + assigned_node->debug_info()->set_name(name_id); + // set the debug name of the constant graph + if (IsValueNode(assigned_node)) { + // the value should be graph + auto fg = GetValueNode(assigned_node); + if (fg->debug_info()->name().empty()) { + fg->debug_info()->set_name(name_id); + } + } + block->WriteVariable(name_id, assigned_node); +} + +void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) { + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); + py::list items = python_adapter::GetPyObjAttr(targ, "elts"); + for (size_t i = 0; i < items.size(); i++) { + // Use the Primitive replace the operation resolve node (getitem) + // because the getitem will eventually be converted to Primitive node + CNodePtr item_apply = block->func_graph()->NewCNode({op_getitem, assigned_node, NewValueNode(static_cast(i))}); + + py::object elt = items[i]; + WriteAssignVars(block, elt, item_apply); + } +} + +void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, + const AnfNodePtr &assigned_node) { + // Now only support the self.xx = xxxxx, can't support x.y = xxxx + AnfNodePtr target_node = ParseExprNode(block, targ); + MS_EXCEPTION_IF_NULL(target_node); + + std::string attr_name = targ.attr("attr").cast(); + std::string var_name = "self."; + (void)var_name.append(attr_name); + MS_LOG(DEBUG) << "assign " << var_name; + + // Get targ location info for error printing + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type + if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":" + << line_no; + } + auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); + auto obj_type = obj.attr("__class__").attr("__name__"); + if (!py::hasattr(obj, "__parameter__")) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" + << py::str(obj).cast() << "' with type '" + << py::str(obj_type).cast() << "' at " << filename << ":" << line_no; + } + + MS_EXCEPTION_IF_NULL(block); + block->WriteVariable(var_name, assigned_node); + MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString(); + block->SetStateAssgin(target_node, var_name); +} + +void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, + const AnfNodePtr &assigned_node) { + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM); + py::object value_obj = python_adapter::GetPyObjAttr(targ, "value"); + py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice"); + AnfNodePtr value_node = ParseExprNode(block, value_obj); + AnfNodePtr slice_node = ParseExprNode(block, slice_obj); + CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node}); + // getitem apply should return the sequence data structure itself + std::string var_name = ""; + if (ast_->IsClassMember(value_obj)) { + std::string attr_name = value_obj.attr("attr").cast(); + var_name = "self." + attr_name; + if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function."; + } + auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); + auto obj_type = obj.attr("__class__").attr("__name__"); + if (!py::hasattr(obj, "__parameter__")) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" + << py::str(obj).cast() << "' with type '" + << py::str(obj_type).cast() << "'."; + } + } else { + var_name = value_obj.attr("id").cast(); + } + block->WriteVariable(var_name, setitem_app); +} + +void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + MS_LOG(DEBUG) << "Process WriteAssignVars"; + auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, targ))); + if (ast_type == AST_SUB_TYPE_NAME) { + HandleAssignName(block, targ, value_node); + } else if (ast_type == AST_SUB_TYPE_TUPLE) { + HandleAssignTuple(block, targ, value_node); + } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) { + HandleAssignSubscript(block, targ, value_node); + } else if (ast_->IsClassMember(targ)) { + HandleAssignClassMember(block, targ, value_node); + } else { + MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type + << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info()); + } +} + +// process a assign statement, such as a =b, a,b = tup +FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast assgin"; + py::object value_object = python_adapter::GetPyObjAttr(node, "value"); + AnfNodePtr value_node = ParseExprNode(block, value_object); + py::object targets_object = python_adapter::GetPyObjAttr(node, "targets"); + py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__"); + size_t count = IntToSize(pcount); + MS_LOG(DEBUG) << "The nodes count is " << count; + for (size_t i = 0; i < count; i++) { + auto target_node = py::cast(targets_object)[i]; + WriteAssignVars(block, target_node, value_node); + } + + return block; +} + +FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) { + if (loops_.empty()) { + // Report error if loop context not set for the 'break' statement. + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no; + } + // Get current loop. + Loop &loop = loops_.top(); + if (loop.end == nullptr) { + // Create end_block if it is not existed. + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + loop.end = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + } + // Jump to the end_block. + block->Jump(loop.end, nullptr); + return block; +} + +FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) { + if (loops_.empty()) { + // Report error if loop context not set for the 'continue' statement. + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no; + } + // Jump to the header of the loop with iterator called. + Loop &loop = loops_.top(); + block->Jump(loop.header, loop.iterator); + return block; +} + +FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) { + // We just bypass 'pass' statement. + return block; +} + +void Parser::RemoveUnnecessaryPhis() { + // merge all removable phis to one map; + std::unordered_map removable_phis; + for (FunctionBlockPtr &block : func_block_list_) { + MS_EXCEPTION_IF_NULL(block); + removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); + } + + if (removable_phis.size() == 0) { + return; + } + for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) { + if (node->isa()) { + const auto &cnode = node->cast(); + auto &inputs = cnode->inputs(); + for (std::size_t i = 0; i < inputs.size(); i++) { + if (inputs[i]->isa()) { + const auto &inp = inputs[i]->cast(); + const auto &iter = removable_phis.find(inp); + if (iter == removable_phis.end()) { + continue; + } + auto &argNode = iter->second; + MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in " + << cnode->DebugString() << " with " << argNode->DebugString(); + cnode->set_input(i, argNode); + } + } + } + } +} + +// ParseAst class code +bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) { + // init the type + target_type_ = PARSE_TARGET_UNKNOW; + + // call python parse, get the parser fn + module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD); + + // get the obj type + auto type = data_converter::GetObjType(obj_); + if (type == RESOLVE_TYPE_FUNCTION) { + target_type_ = PARSE_TARGET_FUNCTION; + function_ = obj_; + } else if (type == RESOLVE_TYPE_METHOD) { + // process the method ,need get the method's self obj + target_type_ = PARSE_TARGET_METHOD; + py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS); + if (py::isinstance(method_object)) { + MS_LOG(ERROR) << "Get method's self object instance failed."; + return false; + } + target_type_ = PARSE_TARGET_OBJECT_INSTANCE; + function_ = obj_; + obj_ = method_object; + } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) { + // obj is class instance, get the method to parse. + function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method); + if (py::isinstance(function_)) { + MS_LOG(ERROR) << "Get obj method function failed."; + return false; + } + target_type_ = PARSE_TARGET_OBJECT_INSTANCE; + // check the fn is method + auto obj_type = data_converter::GetObjType(function_); + if (obj_type != RESOLVE_TYPE_METHOD) { + MS_LOG(WARNING) << "Parse method function is invalid."; + return false; + } + } else { + MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type; + return false; + } + + // call python parse get ast tree + parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method); + ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse"); + + // get fn name and module + function_module_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_module")); + function_name_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_name")); + function_filename_ = py::cast(python_adapter::GetPyObjAttr(parser_, "filename")); + function_line_offset_ = py::cast(python_adapter::GetPyObjAttr(parser_, "line_offset")); + + return true; +} + +// Get ast tree node : is the tree bode list[0] +py::object ParseAst::GetAstNode() { + py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body"); + py::object ast_node = tree_body[0]; + return ast_node; +} + +py::list ParseAst::GetArgs(const py::object &func_node) { + py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS, func_node); + return ret; +} + +py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) { + py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node); + return ret; +} + +AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) { + py::list list_value = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_NODE_TYPE, node); + if (list_value.size() < 2) { + MS_LOG(ERROR) << "The node of python method must has 2 values."; + return nullptr; + } + auto node_name = py::cast(list_value[0]); + auto type = AstMainType(py::cast(list_value[1])); + return std::make_shared(node, node_name, type); +} + +AstSubType ParseAst::GetOpType(const py::object &node) { + auto op_type = AstSubType(python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_AST_TYPE, node).cast()); + return op_type; +} + +bool ParseAst::IsClassMember(const py::object &node) { + py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node); + if (!py::isinstance(ret)) { + MS_LOG(ERROR) << "The result of mod function parse, should be bool type."; + return false; + } + return ret.cast(); +} + +bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + MS_LOG(ERROR) << "FuncGraph is null"; + return false; + } + + if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) { + MS_LOG(DEBUG) << "No flags"; + return true; + } + py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG); + for (auto &item : flags) { + if (!py::isinstance(item.first)) { + MS_LOG(ERROR) << "Type error in flags dict convert"; + return false; + } + auto name = py::cast(item.first); + if (py::isinstance(item.second)) { + auto value = py::cast(item.second); + MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; + func_graph->set_flag(name, value); + } else if (py::isinstance(item.second)) { + auto value = py::cast(item.second); + MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; + func_graph->set_attr(name, MakeValue(value)); + } else { + MS_LOG(ERROR) << "Type error in flags/attrs dict convert"; + return false; + } + } + return true; +} + +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h new file mode 100644 index 0000000000..90e965389f --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -0,0 +1,360 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_PARSE_H_ +#define PIPELINE_PARSE_PARSE_H_ + +#include +#include +#include +#include +#include +#include +#include "utils/misc.h" +#include "ir/anf.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/function_block.h" + +namespace mindspore { +namespace parse { + +// Parse status define +enum ParseStatusCode : int { + PARSE_SUCCESS = 0, + PARSE_FUNCTION_IS_NULL, // python function is null + PARSE_PARAMETER_INVALID, // parameter is invalid + PARSE_NO_RETURN, // function no return node + PARSE_NODE_TYPE_NO_MATCH, // ast node type is error + PARSE_NODE_TYPE_UNKOWN, // node type is unkown + PARSE_NODE_METHOD_UNSUPPORTED, // no method to parse the node + PARSE_DONT_RESOLVE_SYMBOL, // can't resolve the string + PARSE_NOT_SUPPORTED_COMPARE_EXPR, // the comparison is not supported + PARSE_FAILURE = 0xFF +}; + +class AstNodeType; +class ParseAst; + +// Save loop info for 'continue' and 'break' statements. +struct Loop { + // Loop header block. + FunctionBlockPtr header; + // Loop iterator node, used in 'for loop'. + AnfNodePtr iterator; + // Loop end block. + FunctionBlockPtr end; + + Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end) + : header(header), iterator(iterator), end(end) {} + ~Loop() = default; +}; + +// Loop context for loop stack management. +class LoopContext { + public: + LoopContext(std::stack *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) { + loops_->emplace(header, iterator, nullptr); + } + ~LoopContext() { loops_->pop(); } + const FunctionBlockPtr &EndBlock() const { return loops_->top().end; } + + private: + std::stack *loops_; +}; + +// Parser to parse python function +class Parser { + public: + explicit Parser(const std::shared_ptr &ast); + + ~Parser() {} + FuncGraphPtr ParseFuncGraph(); + FuncGraphPtr func_graph() const { return func_graph_; } + ParseStatusCode errcode() const { return errcode_; } + std::shared_ptr ast() const { return ast_; } + // get location info from the ast node + LocationPtr GetLocation(const py::object &node) const; + static void InitParserEnvironment(const py::object &obj); + static void CleanParserResource(); + static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); } + static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph); + + private: + // process the stmt node method list + FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node); + // parse expression + FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node); + // process a if statement + FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node); + // process a while statement + FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node); + // process a for statement + FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node); + FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node); + FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node); + // process a function def statement + FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node); + // process a augment assign + FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node); + // process a global declaration + FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node); + // process assign statement + FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node); + // process break statement + FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node); + // process continue statement + FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node); + // process pass statement + FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node); + // process the expr and slice node method list + AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node); + // process a variable name + AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); + // process NoneType + AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); + // process Ellipsis + AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node); + // process a integer or float number + AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); + // process a string variable + AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node); + // process a name + AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node); + // process a function call + AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node); + // process the if expression + AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node); + // process class type define + AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node); + // process a compare expression + AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node); + // process a bool operation + AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node); + // process a lambda operation + AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node); + // process a tuple + AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node); + // process a tuple + AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node); + // process a tuple + AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node); + // process a slice + AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node); + + // process a extslice + AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node); + + // process a tuple + AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node); + + // process a unaryop + AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node); + + // process a dict ast node expression + AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node); + // generate argument nodes for ast function node + void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node); + // generate argument default value for ast function node + void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node); + // parse ast function node + FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr); + // parse ast statements + FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node); + // parse one ast statement node + FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node); + // parse an ast expresion node + AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node); + + void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock, + const FunctionBlockPtr &falseBlock); + void RemoveUnnecessaryPhis(); + // write a new var + void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node); + + // assign value to single variable name + void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); + + // assign value to tuple + void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); + + // assign value to class member + void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); + + // assign value to subscript + void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); + + // process a bool operation value list + AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op); + + CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node, + const AnfNodePtr &op_iter); + + CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, + const AnfNodePtr &op_hasnext); + + FunctionBlockPtr GenerateBlockInFor(const TraceInfoPtr &trace_info); + + bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, + std::vector *packed_arguments); + + bool ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector *packed_arguments, + std::vector *group_arguments); + + AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, + const std::vector &packed_arguments, + const std::vector &group_arguments, bool need_unpack) const; + ScopePtr GetScopeForParseFunction(); + void BuildMethodMap(); + FunctionBlockPtr MakeFunctionBlock(const Parser &parse) { + FunctionBlockPtr block = std::make_shared(parse); + // In order to keep effect order in the sub-graphs which generated by control flow. + // We copy the flags from the top graph to the sub-graphs. + if (func_graph_ && !func_graph_->attrs().empty()) { + block->func_graph()->set_attrs(func_graph_->attrs()); + } + func_block_list_.push_back(block); + return block; + } + // return a make tuple for input elements list + AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes); + + // shared_ptr will be hold by GraphManager, so just hold a weak ref here. + static FuncGraphWeakPtr top_func_graph_; + // Python function id, used to indicate whether two CNodes come from the same Python function + const std::shared_ptr &ast_; + FuncGraphPtr func_graph_; + // error code setwhen parsing ast tree + ParseStatusCode errcode_; + + // hold all reference for FunctionBlock in this round of parsing, + // so in FunctionBlock class we can use FunctionBlock* in member + // pre_blocks_ and jumps_ to break reference cycle. + std::vector func_block_list_; + using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); + using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); + // define the function map to parse ast Statement + std::map stmt_method_map_; + // define the function map to parse ast expression + std::map expr_method_map_; + // Save current loops to support 'continue', 'break' statement. + std::stack loops_; +}; + +// AST node type define code to ast +class AstNodeType { + public: + AstNodeType(const py::object &node, const std::string &name, AstMainType type) + : node_(node), node_name_(name), main_type_(type) {} + + ~AstNodeType() {} + + std::string node_name() const { return node_name_; } + + py::object node() const { return node_; } + + AstMainType main_type() const { return main_type_; } + + private: + const py::object &node_; + const std::string node_name_; + AstMainType main_type_; +}; + +using AstNodeTypePtr = std::shared_ptr; + +// A helper class to parse python function +class ParseAst { + public: + explicit ParseAst(const py::object &obj) : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {} + + ~ParseAst() = default; + + bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); + + py::object GetAstNode(); + + py::list GetArgs(const py::object &func_node); + + py::list GetArgsDefaultValues(const py::object &func_node); + + AstNodeTypePtr GetNodeType(const py::object &node); + + AstSubType GetOpType(const py::object &node); + + template + py::object CallParserObjMethod(const std::string &method, const T &... args) { + return python_adapter::CallPyObjMethod(parser_, method, args...); + } + + template + py::object CallParseModFunction(const std::string &function, const T &... args) { + return python_adapter::CallPyModFn(module_, function, args...); + } + + const std::string &function_name() const { return function_name_; } + + const std::string &function_module() const { return function_module_; } + + const std::string &function_filename() const { return function_filename_; } + + int function_line_offset() const { return function_line_offset_; } + + py::function function() { return function_; } + + ParseTargetTypeDef target_type() const { return target_type_; } + + py::object obj() { return obj_; } + + py::object parser() { return parser_; } + + py::object module() { return module_; } + + py::object ast_tree() { return ast_tree_; } + + bool IsClassMember(const py::object &node); + + private: + // save obj,eg: class instance or function + py::object obj_; + + // function or class method. + py::function function_; + + py::object ast_tree_; + py::object parser_; + py::module module_; + + // Is function or method + ParseTargetTypeDef target_type_; + + std::string function_name_; + std::string function_module_; + std::string function_filename_; + int function_line_offset_; +}; + +// update the graph flags +bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); + +AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); + +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_PARSE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h new file mode 100644 index 0000000000..bdd79d00bd --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -0,0 +1,152 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_PARSE_BASE_H_ +#define PIPELINE_PARSE_PARSE_BASE_H_ +#include +#include +#include "pybind11/pybind11.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/manager.h" +#include "pybind_api/export_flags.h" + +namespace py = pybind11; +namespace mindspore { +namespace parse { +// define the node type +enum AstMainType : int { + AST_MAIN_TYPE_STMT = 0, // ast.Stmt + AST_MAIN_TYPE_EXPR = 1, // ast.Expr + AST_MAIN_TYPE_SLICE = 2, // ast.Slice + AST_MAIN_TYPE_UNKNOWN = 0xFF // Error +}; + +enum AstSubType : int { + AST_SUB_TYPE_AND = 3, // ast.And + AST_SUB_TYPE_OR = 4, // ast.Or + AST_SUB_TYPE_NAME = 5, // ast.Name + AST_SUB_TYPE_TUPLE = 6, // ast.Tuple + AST_SUB_TYPE_SUBSCRIPT = 7, // ast.Subscript + AST_SUB_TYPE_STARRED = 8, // ast.Starred + AST_SUB_TYPE_UNKNOWN = 0xFF // Error +}; + +// define the parse target type +enum ParseTargetTypeDef { + PARSE_TARGET_FUNCTION = 0, // function + PARSE_TARGET_METHOD = 1, // method + PARSE_TARGET_OBJECT_INSTANCE = 2, // object instance + PARSE_TARGET_UNKNOW = 0xFF // ERROR TYPE +}; + +// define python module name +const char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse"; +const char PYTHON_MOD_PARSE_OBJECT_FUNCTION[] = "parse_cb"; +const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol"; +const char PYTHON_MOD_RESOLVE_GET_OBJ_KEY[] = "get_object_key"; +const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member"; +const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type"; +const char PYTHON_MOD_GET_OBJ_ID[] = "get_obj_id"; +const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type"; +const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance"; +const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes"; +const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods"; +const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; +const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; +const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; +const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; + +const char PYTHON_PARSE_GET_ARGS[] = "get_args"; +const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values"; +const char PYTHON_PARSE_GET_NODE_TYPE[] = "get_node_type"; +const char PYTHON_PARSE_GET_AST_TYPE[] = "get_ast_type"; +const char PYTHON_PARSE_GET_NAMESPACE_SYMBOL[] = "get_namespace_symbol"; +const char PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL[] = "get_ast_namespace_symbol"; +const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol"; +const char PYTHON_PARSE_GET_LOCATION[] = "get_location"; +const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement"; +const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope"; +const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; + +const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; +const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; +const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input"; + +// define the common name +const char NAMED_PRIMITIVE_LEN[] = "len"; +const char NAMED_PRIMITIVE_ITER[] = "iter"; +const char NAMED_PRIMITIVE_NEXT[] = "next"; +const char NAMED_PRIMITIVE_GETITEM[] = "getitem"; +const char NAMED_PRIMITIVE_SETITEM[] = "setitem"; +const char NAMED_PRIMITIVE_HASNEXT[] = "hasnext"; +const char NAMED_PRIMITIVE_BOOL[] = "bool"; // bool: P.identity +const char NAMED_PRIMITIVE_MAKETUPLE[] = "make_tuple"; +const char NAMED_PRIMITIVE_MAKELIST[] = "make_list"; +const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice"; +const char NAMED_PRIMITIVE_MAKEDICT[] = "make_dict"; +const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call"; + +// define NAMED_PRIMITIVE_GETATTR "getattr" +// define python inline attr +const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__"; +const char PYTHON_GET_OBJ_DESC[] = "__str__"; + +const char PYTHON_EXTERN_PARSE_METHOD[] = "__parse_method__"; +const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags"; + +// define the parse constant +const int MAX_COMPARISON_OPS_SUPPORTED = 1; +const char CUSTOM_BPROP_NAME[] = "bprop"; + +// define the Namespace name +const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace +const char RESOLVE_NAMESPACE_NAME_CLASS_MEMBER[] = "ClassMember"; // for class member namespace +const char RESOLVE_NAMESPACE_NAME_SYMBOL_STR[] = "SymbolStr"; // for symbol str namespace +const char RESOLVE_NAMESPACE_NAME_COMMON_OPS[] = "CommonOPS"; // for common ops, eg: hasnext, next +const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // fro Module namespace + +// define Resolve type +enum ResolveTypeDef : int { + RESOLVE_TYPE_NONE = 0, // resolve None + RESOLVE_TYPE_FUNCTION = 1, // reslove function + RESOLVE_TYPE_METHOD = 2, // resolve class method + RESOLVE_TYPE_CLASS_TYPE = 3, // resolve class type + RESOLVE_TYPE_CLASS_INSTANCE = 4, // resolve the class instance of common class + RESOLVE_TYPE_INVALID = 0xFF // resolve invalid +}; + +// define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE +enum ClassInstanceTypeDef { + CLASS_INSTANCE_TYPE_CELL = 0, // class instance type is Cell + CLASS_INSTANCE_TYPE_PRIMITIVE = 1, // class instance type is Primitive + CLASS_INSTANCE_TYPE_INVALID = 0xFF +}; + +// Convert python object to ValuePtr +bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false); + +// Convert python obj to graph +FuncGraphPtr ConvertToFuncGraph(const py::object &obj, + const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); + +// Parse the python object to graph +FuncGraphPtr ParsePythonCode(const py::object &obj, + const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_PARSE_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/python_adapter.cc b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.cc new file mode 100644 index 0000000000..17be74b2a1 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/parse/python_adapter.h" +#include +#include +#include + +namespace mindspore { +namespace parse { +namespace python_adapter { +// python scoped env, should only have one scoped_ instance +static std::shared_ptr scoped_ = nullptr; +// true: start process from python, false: start process from c++ +static bool python_env_ = false; +static bool use_signature_in_resolve_ = true; +void ResetPythonScope() { scoped_ = nullptr; } +void set_use_signature_in_resolve(bool use_signature) noexcept { use_signature_in_resolve_ = use_signature; } +bool UseSignatureInResolve() { return use_signature_in_resolve_; } +void set_python_env_flag(bool python_env) noexcept { python_env_ = python_env; } +bool IsPythonEnv() { return python_env_; } +void SetPythonPath(const std::string &path) { + // load the python module path + (void)python_adapter::set_python_scoped(); + py::module sys = py::module::import("sys"); + py::list sys_path = sys.attr("path"); + + // check the path is exist? + bool is_exist = false; + for (size_t i = 0; i < sys_path.size(); i++) { + std::string path_str = py::cast(sys_path[i]); + if (path_str == path) { + is_exist = true; + } + } + if (!is_exist) { + (void)sys_path.attr("append")(path.c_str()); + } +} + +std::shared_ptr set_python_scoped() { + // if start process from python, no need set the python scope. + if (!python_env_) { + if ((Py_IsInitialized() == 0) && (scoped_ == nullptr)) { + scoped_ = std::make_shared(); + } + } + return scoped_; +} + +// return the module of python +py::module GetPyModule(const std::string &module) { + if (!module.empty()) { + return py::module::import(module.c_str()); + } else { + return py::none(); + } +} + +// Get the obj of attr +py::object GetPyObjAttr(const py::object &obj, const std::string &attr) { + if (!attr.empty() && !py::isinstance(obj)) { + if (py::hasattr(obj, attr.c_str())) { + return obj.attr(attr.c_str()); + } + MS_LOG(DEBUG) << "Obj have not the attr: " << attr; + } + return py::none(); +} + +py::object GetPyFn(const std::string &module, const std::string &name) { + (void)python_adapter::set_python_scoped(); + if (!module.empty() && !name.empty()) { + py::module mod = py::module::import(module.c_str()); + py::object fn = mod.attr(name.c_str()); + return fn; + } + return py::none(); +} + +} // namespace python_adapter +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h new file mode 100644 index 0000000000..0f49539bc8 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_PYTHON_ADAPTER_H_ +#define PIPELINE_PARSE_PYTHON_ADAPTER_H_ +#include +#include +#include + +#include "pybind11/embed.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +#include "pipeline/jit/parse/parse_base.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parse { +// A utility to call python interface +namespace python_adapter { +py::module GetPyModule(const std::string &module); +py::object GetPyObjAttr(const py::object &obj, const std::string &attr); +template +py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) { + if (!method.empty() && !py::isinstance(obj)) { + return obj.attr(method.c_str())(args...); + } + return py::none(); +} + +// call python function of module +template +py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) { + if (!function.empty() && !py::isinstance(mod)) { + return mod.attr(function.c_str())(args...); + } + return py::none(); +} + +// turn off the signature when ut use parser to construct a graph. +void set_use_signature_in_resolve(bool use_signature) noexcept; +bool UseSignatureInResolve(); + +std::shared_ptr set_python_scoped(); +void ResetPythonScope(); +bool IsPythonEnv(); +void SetPythonPath(const std::string &path); +void set_python_env_flag(bool python_env) noexcept; +py::object GetPyFn(const std::string &module, const std::string &name); +// Call the python function +template +py::object CallPyFn(const std::string &module, const std::string &name, T... args) { + (void)set_python_scoped(); + if (!module.empty() && !name.empty()) { + py::module mod = py::module::import(module.c_str()); + py::object fn = mod.attr(name.c_str())(args...); + return fn; + } + return py::none(); +} +} // namespace python_adapter +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_PYTHON_ADAPTER_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc new file mode 100644 index 0000000000..8d4c402639 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -0,0 +1,324 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/parse/resolve.h" + +#include +#include +#include +#include + +#include "ir/param_value.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "utils/any.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/opt.h" +#include "frontend/optimizer/irpass.h" +#include "./common.h" + +namespace mindspore { +namespace parse { +abstract::AbstractBasePtr ClassObject::ToAbstract() { + ClassPtr cls_ptr = ParseDataClass(obj()); + auto abs_scalar = std::make_shared(); + abs_scalar->set_type(std::make_shared()); + abs_scalar->set_value(cls_ptr); + + AbstractBasePtrList args_spec_list = {abs_scalar}; + auto func_ptr = std::make_shared(prim::kPrimMakeRecord); + return std::make_shared(func_ptr, args_spec_list); +} + +abstract::AbstractBasePtr ClassType::ToAbstract() { + auto abs_scalar = + std::make_shared(shared_from_base(), std::make_shared()); + AbstractBasePtrList args_spec_list = {abs_scalar}; + + auto func_ptr = std::make_shared(prim::kPrimCreateInstance); + auto ret_val = std::make_shared(func_ptr, args_spec_list); + ret_val->set_value_desc(ToString()); + return ret_val; +} + +// call python PYTHON_MOD_RESOLVE_FUNCTION interface to resolve the symbol in corresponding namespace +bool SymbolResolver::Resolve() { + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + + py::object obj = namespace_->obj(); + std::string symbol = symbol_->symbol(); + if (py::isinstance(obj)) { + MS_LOG(ERROR) << "Unresolved symbol: " << symbol; + return false; + } + result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol)); + return true; +} + +namespace { +// argument obj should be python Parameter object +// it will be converted to Parameter node here +AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { + MS_EXCEPTION_IF_NULL(func_graph); + + // parameter object should not be none + if (py::isinstance(obj)) { + MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null."; + } + + if (!py::hasattr(obj, "name")) { + MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj"; + } + + // get the parameter name from parameter object + auto name_attr = python_adapter::GetPyObjAttr(obj, "name"); + if (py::isinstance(name_attr)) { + MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; + } + + std::string param_name = py::cast(name_attr); + auto top_graph = Parser::GetTopFuncGraph(); + // if the parameter node has been created , return it + AnfNodePtr para_node = nullptr; + for (auto const ¶m : top_graph->parameters()) { + auto param_node = dyn_cast(param); + if (param_node != nullptr && param_node->name() == param_name) { + para_node = param; + break; + } + } + if (para_node == nullptr) { + auto node = top_graph->AddWeightParameter(param_name); + auto param_value = py::cast(python_adapter::GetPyObjAttr(obj, "_value")); + node->set_default_param(param_value); + // set_abstract for parameter + ValuePtr value = param_value->value(); + constexpr bool broaden = true; + node->set_abstract(abstract::FromValue(value, broaden)); + para_node = node; + } + auto iter = func_graph->make_ref_params().find(para_node); + if (iter == func_graph->make_ref_params().end()) { + AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node); + + AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); + AnfNodePtr ref_key = NewValueNode(std::make_shared(param_name)); + AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node}); + func_graph->make_ref_params()[para_node] = ref_node; + func_graph->add_parameter_obj_node(ref_node); + return ref_node; + } else { + return iter->second; + } +} + +bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { + AnfNodePtr output = nullptr; + if (py::hasattr(obj, "__parameter__")) { + auto param = ResolveParameterObj(func_graph, obj); + if (param == nullptr) { + MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr"; + return false; + } + MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString(); + + output = param; + } else if (py::hasattr(obj, "__parameter_tuple__")) { + auto tuple = obj.cast(); + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t it = 0; it < tuple.size(); ++it) { + AnfNodePtr out = nullptr; + bool success = ResolveObjectToNode(func_graph, tuple[it], &out); + if (!success) { + MS_LOG(ERROR) << "Resolve object to node failed"; + return false; + } + args.push_back(out); + } + output = NewCNode(args, func_graph); + } else { + ValuePtr convert_result = nullptr; + bool converted = ConvertData(obj, &convert_result, parse::python_adapter::UseSignatureInResolve()); + if (!converted) { + MS_LOG(ERROR) << "Convert data failed"; + return false; + } + MS_EXCEPTION_IF_NULL(convert_result); + output = NewValueNode(convert_result); + if (convert_result->isa()) { + output = GetMixedPrecisionCastHelp(func_graph, output); + } + } + *node = output; + return true; +} + +bool IsAllFuncInValueSequence(const std::vector &value_vec) { + for (auto &elem : value_vec) { + if (elem->isa() || elem->isa()) { + const auto &vec = GetValue>(elem); + auto is_graph = IsAllFuncInValueSequence(vec); + if (!is_graph) { + return false; + } + } else if (!elem->isa() && !elem->isa()) { + return false; + } + } + return true; +} + +AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, + const std::vector &value_vec) { + std::vector nodes; + nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + for (auto &elem : value_vec) { + AnfNodePtr node = nullptr; + if (elem->isa() || elem->isa()) { + const auto &vec = GetValue>(elem); + node = TransformToMakeTupleNodes(manager, func_graph, vec); + } else if (elem->isa()) { + FuncGraphPtr new_fg = elem->cast(); + manager->AddFuncGraph(new_fg); + node = NewValueNode(new_fg); + } else if (elem->isa()) { + node = NewValueNode(elem); + } else { + MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString(); + } + nodes.emplace_back(node); + } + auto cnode = func_graph->NewCNode(nodes); + return cnode; +} + +// transform the ValueTuple or ValueList of graph/primitve node to make tuple of const graph/primitve node +bool TransformVectorFuncValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, + const ValueNodePtr &value_node, AnfNodePtr *const transformed) { + MS_EXCEPTION_IF_NULL(value_node); + const auto &value_vec = GetValue>(value_node->value()); + if (!IsAllFuncInValueSequence(value_vec)) { + return false; + } + + // (1) The celllist or ordered_cell will be parsed as valuetuple of const graph in it, + // So if has graph in list, try to replace the node with make tuple of graph value node. + // we do this because the graphmanger won't investigate the graph inside valuetuple, + // change the vector of graph to be make_tuple of graph value node. + // (2) the primitve valuetuple or valuelist may encounter to abstract error, make it all + // independent nodes. + auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); + // replace the ret ptr to be make tuple of graph value node + *transformed = node_tuple_graphs; + + return true; +} +} // namespace + +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node) { + if (node->func_graph() == nullptr || manager == nullptr) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; + } + SymbolResolver symbol_resolver(name_space, symbol, node); + if (!symbol_resolver.Resolve()) { + MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + + py::object obj = symbol_resolver.result(); + ScopeGuard scope_guard(node->scope()); + AnfNodePtr resolved_node = nullptr; + TraceManager::DebugTrace(std::make_shared(node->debug_info())); + bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node); + if (!success) { + MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + if (IsValueNode(resolved_node)) { + auto new_fg = GetValueNode(resolved_node); + manager->AddFuncGraph(new_fg); + } + + // if the constant node is constant of vector of graph ,add graph to manager + if (IsValueNode(resolved_node) || IsValueNode(resolved_node)) { + (void)TransformVectorFuncValueNode(manager, node->func_graph(), resolved_node->cast(), + &resolved_node); + } + + TraceManager::EndTrace(); + return resolved_node; +} + +namespace { +opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { + opt::OptPassGroupMap map({ + {"resolve", + { + // for resolve and getattr primitive; + irpass.resolver_resolve_, + irpass.resolver_getattr_, + }}, + }); + return map; +} +} // namespace + +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) { + if (func_graph == nullptr || res == nullptr) { + MS_LOG(ERROR) << "func_graph or resource is null"; + return false; + } + opt::irpass::ResolveIRPassLib irpass; + opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass)); + + (void)parse::python_adapter::set_python_scoped(); + + MS_EXCEPTION_IF_NULL(opt_resolve); + (void)opt_resolve->step(func_graph, use_profile); + return true; +} + +bool ResolveAll(const FuncGraphManagerPtr &manager) { + if (manager == nullptr) { + MS_LOG(ERROR) << "func graph manager is null"; + return false; + } + + if (manager->roots().size() > 1) { + MS_LOG(WARNING) + << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs" + "called from root graph, so it's not necessary to pass all graphs as roots. " + "Please ensure your usage."; + } + // should not use pipeline::Resource as Resource::Clean will clean some + // global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop + // fail as valid scope has been cleaned. + auto res = std::make_shared(); + res->set_manager(manager); + + auto roots = manager->roots(); + for (auto &fg : roots) { + bool ret = ResolveFuncGraph(fg, res, false); + if (!ret) { + MS_EXCEPTION_IF_NULL(fg); + MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed"; + } + } + return true; +} +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h new file mode 100644 index 0000000000..d924f1ef44 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h @@ -0,0 +1,158 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_RESOLVE_H_ +#define PIPELINE_PARSE_RESOLVE_H_ + +#include +#include +#include "ir/anf.h" +#include "ir/manager.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/parse_base.h" +#include "abstract/abstract_value.h" +#include "utils/log_adapter.h" + +// forward declaration of ResourceBase +namespace mindspore { +namespace pipeline { +class ResourceBase; +using ResourceBasePtr = std::shared_ptr; +} // namespace pipeline +} // namespace mindspore + +namespace mindspore { +namespace parse { + +// NameSpace class for resolving python code. +class NameSpace : public Named { + public: + NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {} + ~NameSpace() override = default; + MS_DECLARE_PARENT(NameSpace, Named); + + py::object obj() { return obj_; } + std::string module() { return module_; } + abstract::AbstractBasePtr ToAbstract() override { + return std::make_shared(shared_from_base(), std::make_shared()); + } + + private: + // namespace of the module + std::string module_; + // namespace object + py::object obj_; +}; +using NameSpacePtr = std::shared_ptr; + +// Symbol in NameSpace or Class which shall be resolved. +class Symbol : public Named { + public: + explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {} + explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {} + + ~Symbol() override = default; + MS_DECLARE_PARENT(Symbol, Named); + + std::string symbol() { return symbol_; } + abstract::AbstractBasePtr ToAbstract() override { + return std::make_shared(shared_from_base(), std::make_shared()); + } + + private: + std::string symbol_; +}; +using SymbolPtr = std::shared_ptr; + +// PyObjectWrapper class wrappers resolved python object for further processing. +class PyObjectWrapper : public Named { + public: + explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} + ~PyObjectWrapper() override = default; + MS_DECLARE_PARENT(PyObjectWrapper, Named); + py::object obj() { return obj_; } + + private: + // the object that needs to be resolved + py::object obj_; +}; + +// ClassObject class wrappers dataclass +class ClassObject : public PyObjectWrapper { + public: + explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") + : PyObjectWrapper(obj, name) {} + ~ClassObject() override = default; + MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); + abstract::AbstractBasePtr ToAbstract() override; +}; + +// ClassType class wrappers class name in python +class ClassType : public PyObjectWrapper { + public: + explicit ClassType(const py::object &obj, const std::string name = "Python class type") + : PyObjectWrapper(obj, name) {} + ~ClassType() override = default; + MS_DECLARE_PARENT(ClassType, PyObjectWrapper); + abstract::AbstractBasePtr ToAbstract() override; +}; + +// SymbolResolver class for resolving symbol extracted from AnfNode. +class SymbolResolver { + public: + SymbolResolver(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) + : namespace_(name_space), symbol_(symbol), resolved_node_(node) {} + + ~SymbolResolver() = default; + + // resolve symbol in namespace and save it in result_; + bool Resolve(); + + NameSpacePtr get_namespace() { return namespace_; } + + SymbolPtr symbol() { return symbol_; } + + py::object &result() { return result_; } + + AnfNodePtr resolved_node() { return resolved_node_; } + + // Resolve result + py::object result_; + + private: + // namespace where the symbol locates + NameSpacePtr namespace_; + // the symbol that needs to be resovled + SymbolPtr symbol_; + // the node that has been resolved + AnfNodePtr resolved_node_; +}; +using SymbolResolverPtr = std::shared_ptr; +// Resolve symbol in namespace. +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node); + +// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); + +// Resolve all graphs in manager which is defined outside of pipeline::Resource. +// Mainly used for test cases or resolve graphs which will not be managed by manager. +bool ResolveAll(const FuncGraphManagerPtr &manager); + +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_RESOLVE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc new file mode 100644 index 0000000000..bb9a517556 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -0,0 +1,340 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/pass.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ir/func_graph_cloner.h" +#include "debug/anf_ir_utils.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/validator.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/cse.h" +#include "frontend/optimizer/graph_kernel_reuse.h" +#include "frontend/optimizer/clean.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/control_depend.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/step_auto_parallel.h" +#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" +#include "utils/any.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace pipeline { +using OptPassGroupMap = opt::OptPassGroupMap; +using Optimizer = opt::Optimizer; +using CompileGraphs = compile::CompileGraphs; +using abstract::AnalysisResult; +using mindspore::abstract::AnalysisContextPtr; +using mindspore::validator::Validate; + +bool SimplifyDataStructuresPass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res->func_graph()); + + FuncGraphPtr func_graph = res->func_graph(); + bool changed = opt::SimplifyDataStructures(func_graph, res->manager()); + + abstract::AbstractBasePtrList args_spec; + auto parameters = func_graph->parameters(); + (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); + if (changed) { + FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); + res->set_func_graph(new_fg); + } + res->set_args_spec(args_spec); + return true; +} + +namespace { +OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig a_1 = opt::OptPassConfig({ + irpass.switch_simplify_, + + // Safe inlining + irpass.inline_, + irpass.partial_eliminate_, + irpass.replace_applicator_, + + // Specialization + irpass.specialize_transform_, + + // Miscellaneous + irpass.item_tuple_eliminate_, + irpass.env_get_item_eliminate_, + irpass.cast_eliminate_, + irpass.reshape_eliminate_, + irpass.reduce_eliminate_, + irpass.tile_eliminate_, + irpass.transpose_eliminate_, + irpass.minmaximum_grad_, + irpass.get_make_ref_eliminate_, + + // Arithmetic simplifications + irpass.arithmetic_simplify_, + irpass.addn_zero_filter_, + irpass.adjust_all_reduce_mul_add_, + + // Safe inlining + irpass.inline_, + }); + opt::OptPassConfig a_2 = opt::OptPassConfig({ + irpass.merge_addn_, + irpass.float_tuple_getitem_switch_, + irpass.float_env_getitem_switch_, + irpass.incorporate_getitem_set_, + irpass.incorporate_call_, + irpass.incorporate_call_switch_, + irpass.incorporate_env_getitem_, + irpass.incorporate_env_getitem_switch_, + irpass.new_env_get_item_, + irpass.depend_value_elim_, + }); + opt::OptPassConfig a_3 = opt::OptPassConfig({ + irpass.arithmetic_simplify2_, + irpass.same_eliminate_, + irpass.check_bprop_eliminate_, + irpass.replace_applicator_, + }); + opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); + opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); + opt::irpass::ResolveIRPassLib resolve_irpass; + + opt::OptPassConfig resolve_pass = + opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, + irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); + + OptPassGroupMap map_a({{"a_1", a_1}, + {"a_2", a_2}, + {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, + {"parallel", opt::OptPassConfig(parallel::StepParallel)}, + {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, + {"virtual_dataset", virtual_dataset}, + {"grad", grad}, + {"resolve", resolve_pass}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"cse", opt::OptPassConfig(opt::CSE(false))}, + {"a_3", a_3}}); + + return map_a; +} + +OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig b_1 = opt::OptPassConfig({ + irpass.zero_like_fill_zero_, + irpass.item_tuple_eliminate_, + irpass.float_tuple_getitem_switch_, + irpass.reset_defer_inline_, + irpass.inline_, + irpass.special_op_eliminate_, + irpass.get_make_ref_eliminate_, + }); + opt::OptPassConfig b_2 = opt::OptPassConfig({ + irpass.replace_refkey_by_param_, + irpass.make_ref_eliminate_, + irpass.get_ref_param_eliminate_, + irpass.indexed_slices_eliminate_, + }); + OptPassGroupMap map({ + {"b_1", b_1}, + {"b_2", b_2}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"cse", opt::OptPassConfig(opt::CSE(false))}, + }); + return map; +} + +OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig interface_fusion = opt::OptPassConfig({ + irpass.mark_interface_fusion_, + }); + OptPassGroupMap map({ + {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, + {"interface_fusion", interface_fusion}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"cse", opt::OptPassConfig(opt::CSE(false))}, + }); + return map; +} + +OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig elim_1 = opt::OptPassConfig({ + irpass.addn_eliminate_, + irpass.incorporate_getitem_from_param_, + }); + opt::OptPassConfig elim_2 = opt::OptPassConfig({ + irpass.unused_parameter_eliminate_, + irpass.unused_output_eliminate_, + }); + OptPassGroupMap map({ + {"elim_1", elim_1}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"elim_2", elim_2}, + }); + return map; +} + +OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) { + return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}}); +} + +OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true); + OptPassGroupMap map({ + {"control_group", control_group}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + }); + return map; +} + +OptPassGroupMap GetInferenceOptPreparePhases() { + opt::irpass::InferenceOptPrepareLib irpass; + auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); + opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}}); + return prepare_map; +} + +OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); + OptPassGroupMap map({{"prepare_group", prepare_group}}); + return map; +} + +static std::unordered_map> g_pass_opts = {}; + +void InitOpt(const ResourcePtr &res) { + if (g_pass_opts.size() == 0) { + opt::irpass::OptimizeIRPassLib irpass; + g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); + g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); + g_pass_opts["opt_graph_kernel_a"] = + Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); + g_pass_opts["opt_graph_kernel_b"] = + Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); + g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); + g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); + g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!(context_ptr->enable_graph_kernel())) { + g_pass_opts["opt_graph_kernel_a"]->set_enable(false); + g_pass_opts["opt_graph_kernel_b"]->set_enable(false); + } + } +} +} // namespace + +void ReclaimOptimizer() { + for (auto &opt : g_pass_opts) { + opt.second = nullptr; + } + g_pass_opts.clear(); +} + +bool OptPassGroup(const ResourcePtr &res, const std::string &name) { + if (res->func_graph() == nullptr) { + MS_LOG(ERROR) << "Opt passes int error"; + return false; + } + + FuncGraphPtr func_graph = res->func_graph(); + MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", " + << func_graph->get_return()->DebugString(true); + InitOpt(res); + if (g_pass_opts.find(name) != g_pass_opts.end()) { + res->set_func_graph(g_pass_opts[name]->step(func_graph)); + } + // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to + // res->args_spec_ yet. So if any later pass or action want to use that variable, it should be set here. + return true; +} + +bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } +bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } +bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } +bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } +bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } +bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } + +bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } + +bool AddControlDependPass(const ResourcePtr &res) { + FuncGraphPtr func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + + if (func_graph->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { + opt::AddControlDepend(func_graph); + } + for (auto fg : func_graph->func_graphs_used_total()) { + MS_EXCEPTION_IF_NULL(fg); + if (fg->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { + opt::AddControlDepend(fg); + } + } + return true; +} + +bool CconvPass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res->func_graph()); + FuncGraphPtr func_graph = res->func_graph(); + FuncGraphPtr new_fg = LiftingClone(func_graph); + res->set_func_graph(new_fg); + return true; +} + +bool ValidatePass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res->func_graph()); + FuncGraphPtr func_graph = res->func_graph(); + Validate(func_graph); + return true; +} + +bool InferenceOptPreparePass(const ResourcePtr &res) { + FuncGraphPtr func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto prepare_map = GetInferenceOptPreparePhases(); + auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map); + (void)infer_opt_prepare->step(func_graph, false); + return true; +} + +std::vector kVmPasses = {{"opt_a", OptPassAGroup}, + {"simplify_data_structures", SimplifyDataStructuresPass}, + {"opt_b", OptPassBGroup}, + {"cconv", CconvPass}, + {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, + {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, + {"add_control_depend", AddControlDependPass}}; + +std::vector kGePasses = { + {"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass}, + {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, + {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, + {"cconv", CconvPass}}; + +std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.h b/mindspore/ccsrc/pipeline/jit/pass.h new file mode 100644 index 0000000000..0233b6cf26 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pass.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_PASS_H_ +#define MINDSPORE_CCSRC_PIPELINE_PASS_H_ + +#include +#include +#include +#include +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace pipeline { +using PassItem = std::pair>; + +extern std::vector kGePasses; +extern std::vector kVmPasses; +extern std::vector kPynativePasses; + +bool CconvPass(const ResourcePtr &res); +bool ValidatePass(const ResourcePtr &res); +bool ConvertPrepareAdapt(const ResourcePtr &res); +bool AddControlDependPass(const ResourcePtr &res); +bool InferenceOptPreparePass(const ResourcePtr &res); +void ReclaimOptimizer(); +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_PASS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc new file mode 100644 index 0000000000..05699793ff --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -0,0 +1,948 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/pipeline.h" + +#include +#include +#include +#include +#include + +#include "ir/param_value.h" +#include "pipeline/jit/pass.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" +#include "utils/config_manager.h" +#include "utils/convert_utils.h" +#include "utils/utils.h" +#include "vm/segment_runner.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/graph_util/get_parallel_info.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "debug/trace.h" +#include "pipeline/pynative/pynative_execute.h" +#include "frontend/optimizer/py_pass_manager.h" + +#if (ENABLE_GE || ENABLE_D) +#include "pipeline/jit/pipeline_ge.h" +#include "transform/graph_ir/convert.h" +#include "transform/graph_ir/df_graph_manager.h" +#endif + +namespace mindspore { +// namespace to support intermediate representation definition +namespace pipeline { +using Tensor = mindspore::tensor::Tensor; +using MetaTensor = mindspore::tensor::MetaTensor; +using TensorOrderMap = std::map>; +using mindspore::abstract::AbstractTensor; +using mindspore::abstract::AbstractTensorPtr; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; + +const char IR_TYPE_ANF[] = "anf_ir"; +const char IR_TYPE_ONNX[] = "onnx_ir"; +const char IR_TYPE_BINARY[] = "binary_ir"; + +ExecutorPyPtr ExecutorPy::executor_ = nullptr; +std::mutex ExecutorPy::instance_lock_; + +std::unordered_map + g_args_cache; + +namespace { +std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) { + std::ostringstream oss; + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(EXCEPTION) << "ms_context is nullptr"; + } + auto save_graphs_path = ms_context->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + oss << save_graphs_path << "/" << stage_idx << "_" << action_name; + return oss.str(); +} +} // namespace + +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults) { + MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); + abstract::AbstractBasePtrList args_spec; + + for (auto arg : defaults) { + if (py::isinstance(arg.second)) { + MS_LOG(EXCEPTION) << "GenerateKey failed, argument input should not be py::module"; + } + ValuePtr converted = nullptr; + if (!parse::ConvertData(arg.second, &converted)) { + MS_LOG(EXCEPTION) << "GenerateKey convert arg failed"; + } + args_spec.push_back(abstract::FromValue(converted, true)); + } + if (g_args_cache.count(args_spec) == 0) { + static int key = 0; + MS_LOG(INFO) << "Start new args and compile key:" << key; + g_args_cache[args_spec] = key++; + } + auto argSpec = py::tuple(2); + argSpec[0] = name; + argSpec[1] = g_args_cache[args_spec]; + return argSpec; +} + +py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs) { + MS_LOG(DEBUG) << "Verify args size:" << inputs.size(); + if (inputs.size() != input_signature.size()) { + MS_LOG(ERROR) << "Signature size not equal to args size"; + return false; + } + + size_t count = 0; + for (auto arg_obj : inputs) { + if (py::hasattr(arg_obj, PYTHON_TENSOR_FLAG)) { + MS_LOG(DEBUG) << "Verify Tensor"; + std::shared_ptr m_tensor = arg_obj.cast>(); + if (m_tensor == nullptr) { + MS_LOG(ERROR) << "Verify Tensor error, get ptr is null"; + return false; + } + std::shared_ptr sig = input_signature[count].cast>(); + std::vector sig_shape = sig->shape(); + TypePtr sig_type = sig->Dtype(); + + std::vector tensor_shape = m_tensor->shape_c(); + if (tensor_shape != sig_shape) { + MS_LOG(ERROR) << "Python input shape is incompatible with input_signature"; + return false; + } + + if (*m_tensor->Dtype() != *sig_type) { + MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString() << ") incompatible with input_signature(" + << sig_type->ToString() << ")"; + return false; + } + } + count++; + } + + return true; +} + +ExecutorPy::ExecutorPy() {} + +ResourcePtr ExecutorPy::GetResource(const std::string &phase) { + MS_LOG(DEBUG) << "Phase size:" << info_.size(); + if (info_.count(phase) == 0) { + return nullptr; + } + return info_[phase]->resource; +} + +FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { + if (info_.count(phase) == 0) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); + } + return info_[phase]->func_graph; +} + +std::size_t ExecutorPy::ArgListSize(const std::string &phase) { + if (info_.count(phase) == 0) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); + } + return info_[phase]->arg_list_size; +} + +compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { + ResourcePtr res = GetResource(phase); + MS_EXCEPTION_IF_NULL(res); + if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is()) { + return res->results()[kOutput].cast(); + } + MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput; + return nullptr; +} + +bool ExecutorPy::HasCompiled(const std::string &phase) const { + if (info_.count(phase) == 0) { + return false; + } + return true; +} + +py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) { + FuncGraphPtr fg_ptr = GetFuncGraph(phase); + if (fg_ptr == nullptr) { + for (auto &item : info_) { + MS_LOG(DEBUG) << "Phase key is: " << item.first; + } + MS_LOG(EXCEPTION) << "Can not find func graph " << phase; + } + + if (ir_type == IR_TYPE_ANF) { + std::string proto_str = GetFuncGraphProtoString(fg_ptr); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "Graph proto is empty."; + } + return proto_str; + } + + if (ir_type == IR_TYPE_ONNX) { + std::string proto_str = GetOnnxProtoString(fg_ptr); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "Graph proto is empty."; + } + return proto_str; + } + + if (ir_type == IR_TYPE_BINARY) { + std::string proto_str = GetBinaryProtoString(fg_ptr); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "Graph proto is empty."; + } + return proto_str; + } + + MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; +} + +py::dict ExecutorPy::GetParameterLayout(const std::string &phase) { + MS_LOG(DEBUG) << "GetParameterLayout!"; + std::string layout_graph = phase + kStepParallelGraph; + auto graph = GetFuncGraph(layout_graph); + return mindspore::parallel::GetParameterLayout(graph); +} + +py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) { + MS_LOG(DEBUG) << "GetCNodeStrategy!"; + std::string layout_graph = phase + kStepParallelGraph; + auto graph = GetFuncGraph(layout_graph); + return mindspore::parallel::GetCNodeStrategy(graph); +} + +py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { + MS_LOG(INFO) << "GetAllreduceFusion!"; + auto graph = GetFuncGraph(phase); + return mindspore::parallel::GetAllreduceFusion(graph); +} + +void ExecutorPy::DelNetRes(const std::string &id) { +#ifdef ENABLE_GE + FinalizeBackend(); +#endif + if (executor_ != nullptr) { + bool flag = false; + auto tmp_info = info_; + for (auto &item : tmp_info) { + if (item.first.find(id) != string::npos) { + MS_LOG(DEBUG) << "Delete network res:" << item.first; + (void)info_.erase(item.first); + flag = true; + } + } + + MS_LOG(DEBUG) << "Delete flag:" << flag; +#ifdef ENABLE_GE + if (flag && info_.size() == 0) { + // because Ge only support one Session exist at the same time ,so we delete the old one + transform::DfGraphManager::GetInstance().DeleteGraphRunner(); + transform::DfGraphManager::GetInstance().EraseAnfGraph(); + transform::DfGraphManager::GetInstance().DeleteGeSession(); + } +#endif + } +} + +void ExecutorPy::ClearRes() { + MS_LOG(INFO) << "Clean executor resource!"; + executor_ = nullptr; +} + +ExecutorPy::~ExecutorPy() { + MS_LOG(INFO) << "Release Executor!"; + ConfigManager::GetInstance().ResetConfig(); +} + +std::map> ExecutorPy::FetchInfoForQuantExport( + const std::string &phase_s) { + FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; + std::map> fake_quant_table; + auto filter = [](AnfNodePtr node) { + return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) || + IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative)); + }; + std::vector nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter); + auto is_quant_cnode = [](AnfNodePtr node) { + return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) || + IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel); + }; + for (auto node : nodes) { + auto cnode = node->cast(); + if (cnode == nullptr || cnode->size() != 3) { + continue; + } + auto x = cnode->input(1); + auto weight = cnode->input(2); + if (!is_quant_cnode(weight)) { + continue; + } + // get parameter weight's name + cnode = weight->cast(); + auto weight_node = cnode->input(2); + if (!weight_node->isa()) { + continue; + } + auto weight_name = weight_node->cast()->name(); + // find the fakequant from input + int count = 0; + const int max_depth = 5; + while (!is_quant_cnode(x)) { + if (count >= max_depth) { + break; + } + cnode = x->cast(); + if (cnode == nullptr || cnode->size() <= 1) { + break; + } + x = cnode->input(1); + count += 1; + } + if (x->isa()) { + fake_quant_table[weight_name] = std::make_pair(nullptr, "input"); + } + // get the fakequant parameter minq's name + if (!is_quant_cnode(x)) { + continue; + } + cnode = x->cast(); + if (cnode == nullptr || cnode->size() != 4) { + continue; + } + auto fakequant_min_node = cnode->input(2); + if (!fakequant_min_node->isa()) { + continue; + } + auto fakequant_min_node_name = fakequant_min_node->cast()->name(); + auto quant_op_value = cnode->input(0)->cast()->value(); + if (!quant_op_value->isa()) { + continue; + } + auto quant_op = quant_op_value->cast(); + fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name); + } + + return fake_quant_table; +} + +void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { + // save the graph to ExecutorPy + FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); + std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); + + MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; + info_[phase_s]->func_graph = func_graph; + if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) && + ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) { + MS_LOG(DEBUG) << "Save model parallel parameter layout graph!"; + func_graph = info_[phase_s]->resource->results()[kStepParallelGraph].cast(); + ExecutorInfoPtr executor_info = std::make_shared(); + std::string layout_graph = phase_s + kStepParallelGraph; + executor_info->func_graph = func_graph; + info_[layout_graph] = executor_info; + } else { + MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!"; + } + MS_LOG(INFO) << "End save compiled func graph!"; +} + +bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { + std::string phase_prefix = GetPhasePrefix(phase_s); + + if (use_vm && phase_prefix == "export") { + MS_LOG(INFO) << "Use ge backend to export geir"; + use_vm = false; + } + return use_vm; +} + +void ExecutorPy::GetGeBackendPolicy() const { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + std::string backend = ms_context->backend_policy(); + if (backend != "ge") { + MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!"; + } +} + +bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { + MS_LOG(DEBUG) << "Start ExecutorPy compile!"; + if ((!py::isinstance(phase))) { + MS_LOG(ERROR) << "Arg phase must be string."; + return false; + } + // check the arg valid? + if (py::isinstance(obj)) { + MS_LOG(ERROR) << "Find error: parse obj is None."; + return false; + } +#ifdef ENABLE_GE + GetGeBackendPolicy(); +#endif + ExecutorInfoPtr executor_info = std::make_shared(); + std::string phase_s = py::cast(phase); + MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!"; + ResourcePtr resource = std::make_shared(obj); + std::vector p_actions; + + use_vm = ChangeExportGeirUseVmFlag(use_vm, phase_s); + + std::string backend = MsContext::GetInstance()->backend_policy(); + if (use_vm && backend != "ge") { + // Create backend and session + auto backend_ptr = compile::CreateBackend(); + // Connect session to debugger + backend_ptr->SetDebugger(); + resource->results()[kBackend] = backend_ptr; + p_actions = VmPipeline(); + } else { + p_actions = GePipeline(); + } + + std::shared_ptr pip = std::make_shared(resource, FilterActions(p_actions, phase_s)); + + // get the parameters items and add the value to args_spec + abstract::AbstractBasePtrList args_spec; + std::size_t size = args.size(); + for (std::size_t i = 0; i < size; i++) { + ValuePtr converted = nullptr; + bool succ = parse::ConvertData(args[i], &converted); + if (!succ) { + MS_LOG(EXCEPTION) << "Args convert error"; + } + bool broaden = true; + args_spec.push_back(abstract::FromValue(converted, broaden)); + } + + resource->set_args_spec(args_spec); + executor_info->arg_list_size = size; + executor_info->resource = resource; + info_[phase_s] = executor_info; + pip->Run(); + + // save the run graph func to MsPipeLine + SaveCompiledGraph(phase_s); + + resource->Clean(); + // Reclaim all resource used by optimizer; + ReclaimOptimizer(); + + MS_LOG(INFO) << "End ExecutorPy compile!"; + return true; +} + +std::vector ExecutorPy::FilterActions(const std::vector &actions, const std::string &phase) { + // phase does not contain 'export_onnx' + if (GetPhasePrefix(phase).find("export_onnx") == std::string::npos) { + return actions; + } + MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'"; + std::vector filtered_actions; + for (const auto &item : actions) { + filtered_actions.emplace_back(item); + if (item.first == "validate") { + break; + } + } + return filtered_actions; +} + +void ExecutorPy::ReleaseResource(const py::object &phase) { + ResourcePtr res = GetResource(py::cast(phase)); + if (res != nullptr) { + res->Clean(); + } + // Reclaim all resource used by optimizer; + ReclaimOptimizer(); +} + +static std::string PrintArgs(const py::tuple &args) { + py::print(args); + return ""; +} + +bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { + bool ret_value = false; + + try { + MS_LOG(DEBUG) << PrintArgs(args); + ret_value = CompileInner(obj, args, phase, use_vm); + } catch (const py::error_already_set &ex) { + // print function call stack info before release + std::ostringstream oss; + trace::TraceGraphEval(); + trace::GetEvalStackInfo(oss); + // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see + // these info from screen, no need to open log file to find these info + py::print(oss.str()); + MS_LOG(ERROR) << oss.str(); + ReleaseResource(phase); + + // re-throw this exception to Python interpreter to handle it + throw(py::error_already_set(ex)); + } catch (const py::type_error &ex) { + ReleaseResource(phase); + throw py::type_error(ex); + } catch (const py::value_error &ex) { + ReleaseResource(phase); + throw py::value_error(ex); + } catch (const py::index_error &ex) { + ReleaseResource(phase); + throw py::index_error(ex); + } catch (const std::exception &ex) { + ReleaseResource(phase); + // re-throw this exception to Python interpreter to handle it + throw(std::runtime_error(ex.what())); + } catch (...) { + ReleaseResource(phase); + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; + } + + return ret_value; +} + +#ifdef ENABLE_LOAD_ANF_IR +// get MindSpore Intermediate Representation File +std::string GetMsIrFile(void) { + std::string file; + const char *path = getenv("MS_IR_FILE"); + if (path == nullptr) { + return file; + } + + char real_path[PATH_MAX] = {0}; + if (realpath(path, real_path) == nullptr) { + MS_LOG(ERROR) << "MS IR path error, " << path; + return file; + } + file = real_path; + return file; +} + +void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, bool *result) { + MS_EXCEPTION_IF_NULL(resource); + MS_EXCEPTION_IF_NULL(result); + + std::string ir_file = GetMsIrFile(); + (void)parse::python_adapter::set_python_scoped(); + if (ir_file.empty()) { + *result = action.second(resource); + return; + } + + // when in loading anf ir mode, action `parse` do nothing + if (action.first == "parse") { + return; + } + + // load MindSpore IR from file + if (action.first == "symbol_resolve") { + MS_LOG(DEBUG) << action.first << " read ir file: " << ir_file; + std::vector graphs = ImportIR(ir_file); + if (graphs.size() == 0) { + MS_LOG(EXCEPTION) << action.first << " read ir file " << ir_file << " failed as no graph found"; + } + auto manager = resource->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (auto &graph : graphs) { + manager->AddFuncGraph(graph); + } + resource->set_func_graph(graphs[0]); + return; + } + + // do normal action when not in `parse` and `symbol_resolve` stage + *result = action.second(resource); +} +#endif + +void Pipeline::Run() { + MS_LOG(INFO) << "Pipeline run"; + MS_EXCEPTION_IF_NULL(resource_); + FuncGraphPtr user_graph = nullptr; + + WITH(MsProfile::GetProfile())[&user_graph, this]() { + int i = 0; + for (auto &action : actions_) { +#ifdef ENABLE_TIMELINE + DumpTime &dump_time = DumpTime::GetInstance(); + dump_time.Record(action.first, GetTime(), true); +#endif + bool result = true; + WITH(MsProfile::GetProfile()->Step(action.first))[&result, &action, this]() { + MS_LOG(DEBUG) << "Action " << action.first << " start ..."; +#ifdef ENABLE_LOAD_ANF_IR + RunPipelineAction(action, resource_, &result); +#else + result = action.second(resource_); +#endif + MS_LOG(DEBUG) << "Action " << action.first << " end."; + }; + if (!result) { + MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first; + } + if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) { + auto graph = resource_->func_graph(); + if (graph != nullptr) { + user_graph = graph; + std::string base_name = GetBaseNameForIR(i, action.first); + + // generate IR file in dot format, which can be converted to svg file using graphviz dot command + draw::Draw(base_name + ".dot", graph); + // generate IR file in human readable format + DumpIR(base_name + ".ir", graph); + // generate IR file in a heavily commented format, which can also be reloaded + ExportIR(base_name + ".dat", std::to_string(i), graph); + } +#ifdef MS_DEBUG + // Dump graph cnode list + MS_LOG(INFO) << "Show CNode list after " << action.first; + graph->DumpCNodeList(); +#endif + } + if (resource_->func_graph() != nullptr) { + auto func_graph = resource_->func_graph(); + if (func_graph->has_flag(GRAPH_FLAG_HAS_EFFECT)) { + func_graph->EraseUnusedNodeInOrder(); + func_graph->CheckOrder(); + for (auto fg : func_graph->func_graphs_used_total()) { + MS_LOG(DEBUG) << "Check order graph " << fg->ToString() << "."; + fg->EraseUnusedNodeInOrder(); + fg->CheckOrder(); + } + } + } + i++; +#ifdef ENABLE_TIMELINE + dump_time.Record(action.first, GetTime(), false); +#endif + } + }; +#ifdef ENABLE_PROFILE + MsProfile::Print(); + MsProfile::Reset(); +#endif + + if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) { + std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); + MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file; + draw::DrawUserFuncGraph(user_graph_file, user_graph); + } + MS_LOG(INFO) << "End"; +} + +void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { + std::size_t size = args.size(); + + for (std::size_t i = 0; i < size; i++) { + py::object arg = args[i]; + auto ms_context = MsContext::GetInstance(); + if (ms_context->backend_policy() == kMsConvert && py::isinstance(arg)) { + MS_LOG(EXCEPTION) << "The " << i << "th arg is numpy array, not tensor."; + } + ValuePtr converted = nullptr; + bool succ = parse::ConvertData(arg, &converted); + if (!succ) { + MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; + } + if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa()) { + MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString() + << " is not tensor."; + } + arg_list->push_back(converted); + } + + MS_EXCEPTION_IF_NULL(res); + auto graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(graph); + std::vector graph_params = graph->parameters(); + std::size_t graph_params_size = graph_params.size(); + if ((*arg_list).size() != graph_params_size) { + // maybe some default parameter + for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) { + MS_EXCEPTION_IF_NULL(graph_params[i]); + auto param_ptr = (graph_params[i])->cast(); + if (!param_ptr->has_default()) { + MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param"; + } + arg_list->push_back(param_ptr->default_param()->value()); + } + } +} + +void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) { + ProcessVmArgInner(args, GetResource(phase), arg_list); +} + +py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { + std::size_t size = args.size(); + if (!py::isinstance(phase)) { + MS_LOG(EXCEPTION) << "Run failed, phase input is not a str"; + } + auto phase_s = py::cast(phase); + std::string backend = MsContext::GetInstance()->backend_policy(); +#ifdef ENABLE_GE + if (backend == "ge") { + return ExecDFGraph(info_, args, phase_s); + } +#else + if (backend == "ms" || backend == "ge") { + auto ret_val = std::make_shared(); + if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) { + if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) { + return *ret_val; + } + } + if (backend == "ge") { + if (args.size() > 0) { + return args[0]; + } + return args; + } + } +#endif + std::size_t full_arg_size = ArgListSize(phase_s); + if (size > full_arg_size) { + MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << full_arg_size; + } + VectorRef arg_list; + ProcessVmArg(args, phase_s, &arg_list); + + compile::VmEvalFuncPtr run = GetVmEvalFunc(phase_s); + if (run == nullptr) { + MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s; + } + + MS_LOG(DEBUG) << "Eval run" << backend; + BaseRef value = (*run)(arg_list); + MS_LOG(DEBUG) << "Run end"; + return BaseRefToPyData(value); +} + +FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params) { +#if (ENABLE_GE || ENABLE_D) + return BuildDFGraph(info_, init_params, phase, broadcast_params); +#else + return nullptr; +#endif +} + +void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { +#if ENABLE_GE + RunGEInitGraph(init_params, phase); +#endif +} + +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase, bool need_run) { + std::string name = MsContext::GetInstance()->backend_policy(); +#ifndef NO_DLIB + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!ms_context->IsTsdOpened() || !ms_context->IsGeInited()) { + (void)InitBackend(); + } +#endif + if (name == kMsConvert || name == kMsVm) { + return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run); + } +#if ENABLE_GE + return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase); +#else + std::string backend = MsContext::GetInstance()->backend_policy(); + if (backend == "ge") { + return true; + } +#endif + return false; +} + +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, bool need_run) { + MS_LOG(INFO) << "Start InitDataSet Entry"; + std::vector int_input_indexes; + (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), + [](int64_t item) { return static_cast(item); }); + std::vector> int_shapes; + (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes), + [](const std::vector &item) { + std::vector vector_item; + (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item), + [](int64_t inner_item) { return static_cast(inner_item); }); + return vector_item; + }); + auto p_init = std::make_shared("InitDataSetQueue"); + p_init->set_attr("queue_name", MakeValue(queue_name)); + p_init->set_attr("size", MakeValue(static_cast(size))); + p_init->set_attr("batch_size", MakeValue(static_cast(batch_size))); + p_init->set_attr("types", MakeValue(types)); + p_init->set_attr("shapes", MakeValue(int_shapes)); + p_init->set_attr("input_indexes", MakeValue(int_input_indexes)); + + const std::vector empty_str_list; + p_init->set_attr("input_names", MakeValue(empty_str_list)); + p_init->set_attr("output_names", MakeValue(empty_str_list)); + + FuncGraphPtr func_graph = std::make_shared(); + auto app_init = std::make_shared(AnfNodePtrList{NewValueNode(p_init)}, func_graph); + func_graph->set_output(app_init); + auto manager = MakeManager(); + manager->AddFuncGraph(func_graph); + + // AbstractNone indicates there is no output for this apply node. + auto abstract_none = std::make_shared(); + app_init->set_abstract(abstract_none); + + auto backend = compile::CreateBackend(); + MS_EXCEPTION_IF_NULL(backend); + auto convert_fn = backend->convert_fn(); + MS_EXCEPTION_IF_NULL(convert_fn); + // Convert CNodeList to LinConvertResult. + ConfigManager::GetInstance().set_iter_num(1); + auto runner = convert_fn({app_init}, ""); + if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { + backend->Link(runner.graph_id); + } + ConfigManager::GetInstance().set_iter_num(size); + + if (!(*runner.run)) { + // empty function + MS_LOG(EXCEPTION) << "Backend " << backend->name() << " unsupported tdt dataset."; + } + + // launch init dataset runner without inputs and outputs + VectorRef args; + auto fn = runner.run; + if (need_run) { + (void)(*fn)(args); + } + MS_LOG(DEBUG) << "InitDataSetVm End."; + return true; +} + +void ResetOpId() { mindspore::id_generator::reset_id(); } + +void InitHccl() { +#ifdef ENABLE_GE + (void)InitBackend(); +#else + mindspore::parse::python_adapter::set_python_env_flag(true); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + (void)ms_context->OpenTsd(); + uint32_t device_id = ms_context->device_id(); + std::string device_name = ms_context->device_target(); + ms_context->set_enable_hccl(true); + if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) { + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id); + MS_EXCEPTION_IF_NULL(runtime_instance); + if (!runtime_instance->Init()) { + MS_LOG(ERROR) << "Kernel runtime init error."; + return; + } + } +#endif +} + +void FinalizeHccl() { +#ifdef ENABLE_GE + (void)FinalizeBackend(); +#else + device::KernelRuntimeManager::Instance().ClearRuntimeResource(); +#endif +} + +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) { +#if (ENABLE_GE || ENABLE_D) + ExportDFGraph(file_name, phase); +#endif + MS_LOG(WARNING) << "In ut test no export_graph"; +} + +void ReleaseGeTsd() { + auto context_ptr = MsContext::GetInstance(); + if (context_ptr != nullptr) { + (void)context_ptr->FinalizeGe(true); + (void)context_ptr->CloseTsd(true); + } +} + +void InitBackend() { + // set python env flag + mindspore::parse::python_adapter::set_python_env_flag(true); + // open tsd before ge initialize + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!ms_context->OpenTsd()) { + MS_LOG(EXCEPTION) << "Open tsd failed"; + } + (void)ms_context->InitGe(); +} + +void FinalizeBackend() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + (void)context_ptr->FinalizeGe(); + (void)context_ptr->CloseTsd(); +} + +void ClearResAtexit() { + MS_LOG(DEBUG) << "Pipeline clear all resource"; + pynative::ClearPyNativeSession(); + session::ClearPythonParasMap(); + device::KernelRuntimeManager::Instance().ClearRuntimeResource(); + + ad::g_k_prims.clear(); + + abstract::ClearPrimEvaluatorMap(); + compile::ClearConvertCache(); + pipeline::GetMethodMap().clear(); + pipeline::ExecutorPy::ClearRes(); + pipeline::ReclaimOptimizer(); + pynative::PynativeExecutor::GetInstance()->ClearRes(); + opt::python_pass::PyPassManager::GetInstance()->ClearRes(); +#ifdef ENABLE_GE + transform::DfGraphManager::GetInstance().ClearGraph(); + transform::DfGraphConvertor::get_adpt_map().clear(); +#endif + ReleaseGeTsd(); + parse::python_adapter::ResetPythonScope(); +} +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h new file mode 100644 index 0000000000..705853d086 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -0,0 +1,148 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ +#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "utils/base_ref_extends.h" +#include "debug/draw.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "pipeline/jit/action.h" +#include "vm/segment_runner.h" +#include "vm/transform.h" +#include "pipeline/jit/base.h" + +namespace mindspore { +extern const char kMsConvert[]; +extern const char kMsVm[]; + +// namespace to support pipeline structures definition +namespace pipeline { + +namespace py = pybind11; + +class Pipeline { + public: + Pipeline(const ResourcePtr &res, const std::vector &actions) : resource_(res), actions_(actions) {} + + ~Pipeline() = default; + + void Run(); + + ResourcePtr resource() { return resource_; } + + private: + ResourcePtr resource_; + std::vector actions_; +}; + +// A function pipeline. +class ExecutorPy : public std::enable_shared_from_this { + public: + static std::shared_ptr GetInstance() { + std::lock_guard i_lock(instance_lock_); + if (executor_ == nullptr) { + executor_ = std::shared_ptr(new (std::nothrow) ExecutorPy()); + } + return executor_; + } + + ~ExecutorPy(); + + void SaveCompiledGraph(const std::string &phase_s); + bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); + bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); + + void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list); + + // for pynative mode when use_vm is on + py::object Run(const py::tuple &args, const py::object &phase); + ResourcePtr GetResource(const std::string &phase); + FuncGraphPtr GetFuncGraph(const std::string &phase); + py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); + std::size_t ArgListSize(const std::string &phase); + compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); + bool HasCompiled(const std::string &phase) const; + + FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params = {}); + void RunInitGraph(const py::dict &init_params, const std::string &phase); + py::dict GetParameterLayout(const std::string &phase); + py::dict GetCNodeStrategy(const std::string &phase); + py::dict GetAllreduceFusion(const std::string &phase); + void DelNetRes(const std::string &id); + void ReleaseResource(const py::object &phase); + static void ClearRes(); + + std::map> FetchInfoForQuantExport(const std::string &phase_s); + + private: + ExecutorPy(); + void ConvertObjectToTensors(const py::dict &dict, std::map *tensors); + bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const; + void GetGeBackendPolicy() const; + // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after + // 'validate' stage + static std::vector FilterActions(const std::vector &actions, const std::string &phase); + + std::map info_; + static std::shared_ptr executor_; + static std::mutex instance_lock_; +}; +using ExecutorPyPtr = std::shared_ptr; + +// Generate a key for mapping function graph +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults); +py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs); + +bool InitDistribute(const std::map &options); + +void ResetOpId(); +void InitHccl(); +void FinalizeHccl(); +void InitBackend(); +void FinalizeBackend(); + +void ClearResAtexit(); +void ReleaseGeTsd(); + +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); + +// init and exec dataset sub graph +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase, bool need_run); + +// Build and run dataset subgraph for ms backend +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, bool need_run); + +void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list); + +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc new file mode 100644 index 0000000000..e08af4f2dc --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc @@ -0,0 +1,535 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/pipeline_ge.h" + +#include +#include +#include +#include +#include + +#include "debug/anf_ir_dump.h" +#include "ir/tensor.h" +#include "transform/graph_ir/convert.h" +#include "transform/graph_ir/df_graph_manager.h" +#include "transform/graph_ir/graph_builder.h" +#include "transform/graph_ir/graph_runner.h" +#include "debug/draw.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace pipeline { +using Tensor = mindspore::tensor::Tensor; +using MetaTensor = mindspore::tensor::MetaTensor; +using TensorOrderMap = std::map>; +using mindspore::abstract::AbstractTensor; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; +using mindspore::transform::DfGraphConvertor; +using mindspore::transform::DfGraphManager; +using mindspore::transform::GeTensorPtr; +using mindspore::transform::MeTensorPtr; +using mindspore::transform::Status; +using mindspore::transform::TransformUtil; + +void DoExecNonInputGraph(const std::string &phase) { + std::vector ge_tensors; + std::vector ge_outputs; + transform::RunOptions run_options; + run_options.name = phase; + auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(ERROR) << "Can not found GraphRunner"; + return; + } + + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); + if (ret != Status::SUCCESS) { + MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed"; + return; + } + } +} + +void SetGeOption(const std::map &options) { + ConfigManager::GetInstance().set_ge_initialize_options(options); +} + +Status CreateSessionAndGraphRunner(bool is_training = true) { + std::shared_ptr sess = DfGraphManager::GetInstance().GetGeSession(); + if (sess == nullptr) { + transform::SessionOptions options; + if (is_training) { + options["ge.trainFlag"] = "1"; + options["ge.streamNum"] = "100"; + options["ge.enabledLocalFmkop"] = "1"; + options["ge.hcomParallel"] = "1"; + } else { + options["ge.trainFlag"] = "0"; + } + + options["ge.enablePrintOpPass"] = "0"; + sess = transform::GraphRunner::NewSession(options); + if (sess == nullptr) { + MS_LOG(ERROR) << "Init data graph failed, because of create Ge session failed"; + return Status::FAILED; + } else { + DfGraphManager::GetInstance().SetGeSession(sess); + } + } + + transform::GraphRunnerOptions options; + options.sess_ptr = sess; + auto graph_runner = std::make_shared(options); + if (graph_runner == nullptr) { + MS_LOG(ERROR) << "Create new graph runner failed"; + return Status::FAILED; + } else { + DfGraphManager::GetInstance().SetGraphRunner(graph_runner); + } + + return Status::SUCCESS; +} + +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { + std::vector ge_types; + (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr &i) -> int64_t { + return transform::TransformUtil::ConvertDataType(i->type_id()); + }); + + ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE); + ConfigManager::GetInstance().set_iter_num(size); + ConfigManager::GetInstance().set_dataset_phase(phase); + + DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes); + ConfigManager::GetInstance().set_dataset_param(param); + + if (transform::BuildDatasetGraph(param, phase) != transform::SUCCESS) { + MS_LOG(ERROR) << "Build dateset graph failed."; + return false; + } + +#if ENABLE_TRAIN + (void)setenv("GE_TRAIN", "1", 1); +#else + (void)setenv("GE_TRAIN", "0", 1); +#endif + + if (CreateSessionAndGraphRunner(static_cast(ENABLE_TRAIN)) != Status::SUCCESS) { + MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; + return false; + } + + MS_LOG(INFO) << "DoExecNonInputGraph:" << phase; + DoExecNonInputGraph(phase); + + return true; +} + +void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) { + for (auto item : dict) { + if ((!py::isinstance(item.first))) { + MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it."; + continue; + } + std::shared_ptr tensor; + std::string name = py::cast(item.first); + if (py::isinstance(item.second.attr("default_input"))) { + // convert float to tensor with shape([1]) + tensor = std::make_shared(kNumberTypeFloat32, std::vector({1})); + *(static_cast(tensor->data_c())) = py::cast(item.second.attr("default_input")); + } else if (py::isinstance(item.second.attr("default_input"))) { + // convert int to tensor with shape([1]) + tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); + *(static_cast(tensor->data_c())) = py::cast(item.second.attr("default_input")); + } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { + // cast tensor + tensor = py::cast>(item.second.attr("default_input")); + } + + if (tensor == nullptr) { + MS_LOG(EXCEPTION) << "Get default value for " << name << " failed"; + } + (void)tensors->emplace(name, tensor); + } +} + +bool AddDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { + FuncGraphPtr anf_graph = info.at(phase)->func_graph; + DfGraphConvertor convertor(anf_graph); + + size_t pos = phase.find('.'); + std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1)); + std::string phase_prefix = phase.substr(0, pos); + if (phase_prefix == "export") { + MS_LOG(INFO) << "Set DfGraphConvertor training : false"; + convertor.set_training(false); + } + + TensorOrderMap init_tensors{}; + ConvertObjectToTensors(init_params, &init_tensors); + (void)convertor.ConvertAllNode().InitParam(init_tensors).BuildGraph(); + + if (broadcast_params != py::none()) { + if (!py::isinstance(broadcast_params)) { + MS_LOG(ERROR) << "Invalid broadcast params, it must be py::dict type"; + return false; + } + py::dict broadcast = broadcast_params.cast(); + if (broadcast.empty()) { + (void)convertor.GenerateBroadcastGraph(init_tensors); + } else { + TensorOrderMap broadcast_tensors{}; + ConvertObjectToTensors(broadcast, &broadcast_tensors); + (void)convertor.GenerateBroadcastGraph(broadcast_tensors); + } + MS_LOG(INFO) << "Generate broadcast graph with params and broadcast_empty is " << broadcast.empty(); + } + + (void)convertor.GenerateCheckpointGraph(); + if (convertor.ErrCode() != 0) { + DfGraphManager::GetInstance().ClearGraph(); + MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode(); + return false; + } + + if (MsContext::GetInstance()->save_graphs_flag()) { + convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug + convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug + convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug + } + std::string init_graph = "init_subgraph." + net_id; + std::string checkpoint_name = "save." + net_id; + if (phase.find("train") != std::string::npos) { + (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); + } else { + (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); + } + (void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); + (void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); + + Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); + if (ret == Status::SUCCESS) { + DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); + } + + return true; +} + +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { + if (info.count(phase) == 0) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); + } + FuncGraphPtr anf_graph = info.at(phase)->func_graph; + + if (MsContext::GetInstance()->save_graphs_flag()) { + draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug + DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true); + } + + if (!AddDFGraph(info, init_params, phase, broadcast_params)) { + MS_LOG(ERROR) << "GenConvertor failed"; + return nullptr; + } + +#if ENABLE_TRAIN + (void)setenv("GE_TRAIN", "1", 1); +#else + (void)setenv("GE_TRAIN", "0", 1); +#endif + + if (CreateSessionAndGraphRunner(static_cast(ENABLE_TRAIN)) != Status::SUCCESS) { + MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; + return nullptr; + } + + return anf_graph; +} + +void RunGEInitGraph(const py::dict &init_params, const std::string &phase) { + MS_LOG(DEBUG) << "ExecInitGraph start."; + TensorOrderMap inputs_with_name{}; + ConvertObjectToTensors(init_params, &inputs_with_name); + std::vector inputs; + (void)std::transform(inputs_with_name.begin(), inputs_with_name.end(), std::back_inserter(inputs), + [](const std::pair &item) { return item.second; }); + + std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); + if (ge_tensors.size() != inputs.size()) { + MS_LOG(ERROR) << "Args convert to ge tensor error."; + return; + } + MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size() << "."; + + std::vector ge_outputs; + transform::RunOptions run_options; + + run_options.name = phase; + if (DfGraphManager::GetInstance().GetGraphByName(phase) == nullptr) { + MS_LOG(WARNING) << "Can not find " << phase << " sub graph, don't need data init subgraph in INFER mode."; + return; + } + auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(EXCEPTION) << "Can not found GraphRunner."; + } + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); + if (ret != Status::SUCCESS) { + MS_LOG(EXCEPTION) << "Exec " << phase << " graph failed."; + } + + MS_LOG(INFO) << "Exec " << phase << " graph success."; + + if ((ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::DISTRIBUTION) && + (DfGraphManager::GetInstance().GetGraphByName(BROADCAST_GRAPH_NAME) != nullptr)) { + run_options.name = BROADCAST_GRAPH_NAME; + ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); + if (ret != Status::SUCCESS) { + MS_LOG(EXCEPTION) << "Exec BROADCAST_GRAPH_NAME failed."; + } + MS_LOG(INFO) << "Exec broadcast graph success."; + } + } +} + +py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) { + MS_EXCEPTION_IF_NULL(cnode_data); + + if (cnode_data->isa()) { + if (*count >= data.size()) { + MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() + << " less than the number of elements required. "; + } + + BaseShapePtr shape = cnode_data->BuildShape(); + if (!shape->isa()) { + MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString(); + } + auto shape_me = shape->cast()->shape(); + auto shape_ge = py::cast(data[*count]).shape(); + if (shape_ge != shape_me) { + MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge + << " is not the same as the shape of the tensor derived: " << shape_me; + } + + return data[(*count)++]; + } + + if (!cnode_data->isa()) { + MS_LOG(EXCEPTION) << "The output of operator in the final anf graph could " + << "only be a tensor or a tuple of tensor, but got " << cnode_data->BuildValue()->ToString() + << "."; + } + auto data_tp = cnode_data->cast(); + auto elements = data_tp->elements(); + size_t size = data_tp->size(); + auto tp = py::tuple(size); + for (size_t i = 0; i < size; i++) { + tp[i] = ExtractGeneralCnodeRet(elements[i], data, count); + } + return std::move(tp); +} + +py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) { + MS_EXCEPTION_IF_NULL(output_node); + + if (output_node->isa()) { + return ValuePtrToPyData(GetValueNode(output_node)); + } + + if (output_node->isa()) { + if (*count >= data.size()) { + MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() + << " less than the number of elements required. "; + } + return data[(*count)++]; + } + + auto output_c = output_node->cast(); + if (output_c == nullptr) { + MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got " + << output_node->ToString(); + } + + if (output_c->IsApply(prim::kPrimMakeTuple)) { + auto input_list = output_c->inputs(); + size_t size = input_list.size(); + auto tp = py::tuple(size - 1); + for (size_t i = 1; i < size; i++) { + tp[i - 1] = StructureOutput(input_list[i], data, count); + } + return std::move(tp); + } + if (output_c->IsApply(prim::kPrimDepend)) { + return StructureOutput(output_c->input(1), data, count); + } + + return ExtractGeneralCnodeRet(output_c->abstract(), data, count); +} + +std::shared_ptr DoExecGraph(const FuncGraphPtr &graph, const std::vector &inputs, + const std::string &phase) { + std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); + if (ge_tensors.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Convert me args to ge tensor error."; + } + + std::vector ge_outputs; + transform::RunOptions run_options; + run_options.name = phase; + auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(EXCEPTION) << "Can not found GraphRunner."; + } + + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size(); + Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); + MS_LOG(DEBUG) << "Run graph finish, outputs size is: " << ge_outputs.size(); + if (ret != Status::SUCCESS) { + MS_LOG(ERROR) << "Exec graph failed"; + return nullptr; + } + } + + std::vector me_outputs = TransformUtil::ConvertGeTensors(ge_outputs); + if (me_outputs.size() != ge_outputs.size()) { + MS_LOG(WARNING) << "Convert output Ge tensor to Me tensor failed"; + } + + py::tuple outputs(me_outputs.size()); + for (std::size_t i = 0; i < outputs.size(); i++) { + outputs[i] = *me_outputs[i]; + } + + std::shared_ptr ret = nullptr; + + AnfNodePtr output_node = graph->get_return()->input(1); + MS_EXCEPTION_IF_NULL(output_node); + size_t count = 0; + py::object oj = StructureOutput(output_node, outputs, &count); + ret = std::make_shared(oj); + + return ret; +} + +void ProcessGeArg(const std::map &info, const py::tuple &args, const std::string &phase, + std::vector *inputs) { + // check the arg and use the ExecutorPy args + std::size_t size = args.size(); + + if (info.count(phase) == 0) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); + } + + auto arg_size = info.at(phase)->arg_list_size; + if (size != arg_size) { + MS_LOG(EXCEPTION) << "The real arg num : size = " << size << ". graph_arg_size = " << arg_size; + } + + // process the first args of tensor + // only in dataset normal(non-sink) mode, fp_bp graph need input tensors + if (ConfigManager::GetInstance().dataset_mode() == DS_NORMAL_MODE) { + for (std::size_t i = 0; i < size; i++) { + ValuePtr converted = nullptr; + bool succ = parse::ConvertData(args[i], &converted); + if (!succ) { + MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; + } + if (converted->isa()) { + inputs->push_back(converted->cast()); + } else { + MS_EXCEPTION(TypeError) << "The " << i << "th arg: " << converted->ToString() << " is not tensor."; + } + } + } +} + +py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase) { + std::string phase_prefix = GetPhasePrefix(phase); + if (phase_prefix == "save") { + DoExecNonInputGraph(phase); + ConfigManager::GetInstance().ResetConfig(); + return py::none(); + } + + if (info.count(phase) == 0) { + MS_LOG(EXCEPTION) << "There is no phase:" << phase; + } + FuncGraphPtr anf_graph = info.at(phase)->func_graph; + +#ifdef ENABLE_INFER + // Now don't use the graph because the exec ge function don't take effect + MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph); + if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) { + MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries"; + ConfigManager::GetInstance().ResetConfig(); + return py::none(); + } +#endif + + std::shared_ptr ret_val = std::make_shared(); + // We will not execute graph when output is constant or just input itself. + if (IsGraphOutputValueNodeOrParameter(info.at(phase)->func_graph->output(), args, ret_val)) { + ConfigManager::GetInstance().ResetConfig(); + return *ret_val; + } + + std::vector inputs; + ProcessGeArg(info, args, phase, &inputs); + + std::shared_ptr ret = DoExecGraph(anf_graph, inputs, phase); + ConfigManager::GetInstance().ResetConfig(); + if (ret != nullptr) { + return *ret; + } else { + MS_LOG(EXCEPTION) << "Exec graph failed"; + } +} +void ExportDFGraph(const std::string &file_name, const std::string &phase) { + MS_LOG(DEBUG) << "ExportGraph Begin"; + transform::DfGraphWrapperPtr wrap_ptr = DfGraphManager::GetInstance().GetGraphByName(phase); + if (wrap_ptr == nullptr) { + MS_LOG(ERROR) << "Get graph form DfGraphManager failed!"; + return; + } + + transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_; + if (nullptr == ge_graph) { + MS_LOG(ERROR) << "The export graph is null"; + return; + } + + (void)ge_graph->SaveToFile(file_name); + + MS_LOG(DEBUG) << "ExportGraph End"; +} +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.h b/mindspore/ccsrc/pipeline/jit/pipeline_ge.h new file mode 100644 index 0000000000..f834125231 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ +#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pipeline/jit/base.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace pipeline { +namespace py = pybind11; + +void SetGeOption(const std::map &options); + +void RunGEInitGraph(const py::dict &init_params, const std::string &phase); + +py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase = "train"); + +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params = {}); + +// init and exec dataset sub graph for GE backend +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase); + +void ExportDFGraph(const std::string &file_name, const std::string &phase); +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc new file mode 100644 index 0000000000..e9467e4aeb --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/remove_value_node_dup.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "ir/manager.h" +#include "frontend/optimizer/cse.h" +#include "utils/log_adapter.h" +#include "utils/hashing.h" + +namespace mindspore { +namespace pipeline { +void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache, + HashValue *const hash_value) { + const auto &to_check_value = GetValueNode(node); + MS_EXCEPTION_IF_NULL(to_check_value); + + // Calculate hash value. + size_t h; + auto hash_iter = hash_value->find(node); + if (hash_iter == hash_value->end()) { + h = hash_combine(to_check_value->hash(), (opt::AbsOf(node)->hash())); + (*hash_value)[node] = h; + } else { + h = hash_iter->second; + } + + auto bucket_iter = hash_cache->find(h); + if (bucket_iter == hash_cache->end()) { + // Meet for the first time, add bucket. + (*hash_cache)[h] = {node}; + return; + } + + auto &bucket = bucket_iter->second; + // Check if need to replace node with value node already met. + for (const auto &v : bucket) { + // Already met and cached. + if (v == node) { + return; + } + const auto &existed_value = GetValueNode(v); + MS_EXCEPTION_IF_NULL(existed_value); + auto equal = [&]() -> bool { + if (existed_value->isa() && to_check_value->isa()) { + return existed_value->cast()->ValueEqual(*(to_check_value->cast())); + } + return *existed_value == *to_check_value; + }; + if (equal()) { + (void)manager->Replace(node, v); + return; + } + } + + // Meet for the first time, append node to bucket. + bucket.emplace_back(node); +} +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h new file mode 100644 index 0000000000..b36544bdba --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ +#define MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ + +#include +#include +#include "base/base.h" +#include "ir/manager.h" + +namespace mindspore { +namespace pipeline { +using HashCache = std::unordered_map>; +using HashValue = std::unordered_map; + +void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc new file mode 100644 index 0000000000..ece128b77b --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -0,0 +1,260 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/resource.h" +#include "pipeline/jit/pipeline.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "debug/draw.h" +#include "debug/trace.h" +#include "ir/dtype.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/operator/ops.h" +#include "utils/graph_utils.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "vm/segment_runner.h" + +namespace mindspore { +// namespace to support opmap definition +namespace pipeline { + +MethodMap &GetMethodMap() { + static MethodMap method_map = { + {kObjectTypeString, + { + {"__bool__", std::string("str_bool")} // C.str_bool + }}, + {kMetaTypeNone, + { + {"__bool__", std::string("none_bool")} // C.none_bool + }}, + {kNumberTypeBool, + { + {"__and__", prim::kPrimBoolAnd}, // P.bool_and + {"__or__", prim::kPrimBoolOr}, // P.bool_or + {"__eq__", prim::kPrimBoolEq}, // P.bool_eq + {"__ne__", std::string("bool_ne")}, // C.bool_ne + {"__bool__", prim::kPrimIdentity} // P.identity + }}, + {kNumberTypeInt, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul + {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv + {"__truediv__", std::string("int_truediv")}, // C.int_truediv + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow + {"__floor__", prim::kPrimIdentity}, // P.identity + {"__trunc__", prim::kPrimIdentity}, // P.identity + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt + {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt + {"__le__", prim::kPrimScalarLe}, // P.scalar_le + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge + {"__bool__", std::string("int_bool")}, // C.int_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array + }}, + {kNumberTypeUInt, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, + {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, + {"__truediv__", std::string("int_truediv")}, // C.int_truediv + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, + {"__floor__", prim::kPrimIdentity}, // P.identity, + {"__trunc__", prim::kPrimIdentity}, // P.identity, + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, + {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, + {"__le__", prim::kPrimScalarLe}, // P.scalar_le, + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, + {"__bool__", std::string("int_bool")}, // C.int_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, + }}, + {kNumberTypeFloat, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, + {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv + {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, + {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, + {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, + {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, + {"__le__", prim::kPrimScalarLe}, // P.scalar_le, + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, + {"__bool__", std::string("float_bool")}, // C.float_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, + }}, + {kObjectTypeTuple, + { + {"__len__", prim::kPrimTupleLen}, // P.tuple_len, + {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, + {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, + {"__ms_iter__", prim::kPrimIdentity}, // P.identity, + {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, + {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext + {"__bool__", std::string("tuple_bool")} // C.tuple_bool + }}, + {kObjectTypeList, + { + {"__len__", prim::kPrimListLen}, // P.list_len, + {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, + {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, + {"__ms_iter__", prim::kPrimIdentity}, // P.identity + {"__ms_next__", std::string("list_next")}, // C.list_next + {"append", std::string("list_append")}, // C.list_next + {"__bool__", std::string("list_bool")}, // C.list_bool + {"__ms_hasnext__", std::string("list_hasnext")}, + }}, + {kObjectTypeDictionary, + { + {"__len__", prim::kPrimDictLen}, // P.dict_len + {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem + {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, + {"__bool__", std::string("dict_bool")} // C.dict_bool + }}, + {kObjectTypeTensorType, + { + {"__add__", std::string("add")}, // C.add + {"__sub__", std::string("sub")}, // C.sub + {"__mul__", std::string("mul")}, // C.mul + {"__truediv__", std::string("truediv")}, // C.truediv + {"__floordiv__", std::string("floordiv")}, // C.floordiv + {"__mod__", std::string("mod")}, // C.mod + {"__pow__", std::string("pow_")}, // C.pow + {"__floor__", std::string("array_floor")}, // C.array_floor + {"__trunc__", std::string("array_trunc")}, // C.array_trunc + {"__pos__", std::string("array_uadd")}, // C.array_uadd + {"__neg__", std::string("array_usub")}, // C.array_usub + {"__eq__", std::string("eq")}, // C.eq + {"__ne__", std::string("ne")}, // C.ne + {"__lt__", std::string("lt")}, // C.lt + {"__gt__", std::string("gt")}, // C.gt + {"__le__", std::string("le")}, // C.le + {"__ge__", std::string("ge")}, // C.ge + {"__matmul__", prim::kPrimDot}, // P.dot, + {"__len__", prim::kPrimArrayLen}, // P.array_len, + {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, + {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, + {"__ms_iter__", std::string("array_iter")}, // C.array_iter + {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, + {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, + {"transpose", std::string("transpose")}, // P.transpose + {"__bool__", std::string("tensor_bool")}, // C.tensor_bool + }}, + {kObjectTypeIndexedSlicesType, + { + {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values + {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices + {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape + }}, + {kObjectTypeJTagged, {}}, + {kObjectTypeSymbolicKeyType, {}}, + {kObjectTypeEnvType, {}}}; + return method_map; +} + +Resource::Resource(const py::object &obj) + : engine_(std::make_shared(abstract::GetPrimEvaluatorConstructors(), manager_)), + input_(obj), + is_cleaned_(false) {} + +Resource::~Resource() { + MS_LOG(DEBUG) << "Resource clear"; + + // If exit normally, these global variables will be cleaned + // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION, + // these global variables may not being cleaned, it may + // cause segmentfault when free python object inside these global variables + // after python interpreter got freed, so these global variables + // are cleaned here. + // So if exit normally, these global variable will be cleaned twice, + // care be taken to prevent double free in the following functions. + if (!is_cleaned_) { + try { + Clean(); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what(); + } catch (...) { + MS_LOG(ERROR) << "Exception when cleaning resource."; + } + } +} + +bool Resource::IsTypeInMethodMap(const TypeId &type) { + TypeId type_id = NormalizeTypeId(type); + const MethodMap &method_map = GetMethodMap(); + auto iter = method_map.find(static_cast(type_id)); + if (iter != method_map.end()) { + return true; + } + return false; +} + +Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { + TypeId type_id = NormalizeTypeId(type); + const MethodMap &method_map = GetMethodMap(); + auto iter = method_map.find(static_cast(type_id)); + if (iter == method_map.end()) { + MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; + return Any(); + } + + auto iter_map = iter->second.find(name); + if (iter_map == iter->second.end()) { + MS_LOG(WARNING) << "Object type: " << type_id << " have no method: " << name; + return Any(); + } + return iter_map->second; +} + +void Resource::Clean() { + // AbstractTensor->elements() will be saved in AbstractBasePtrList + args_spec_.clear(); + input_ = py::none(); + // Context with AbstractBasePtrList may be saved in GraphEvaluator + // some Evaluator like ResolveEvaluator may save Python object in cache, + // it should be cleaned before Python Interpreter destructed. + MS_EXCEPTION_IF_NULL(engine_); + engine_->ClearEvaluatorCache(); + // clean static variable to prevent from crash. As static variable is released after + // Python threads is released. + parse::data_converter::ClearObjectCache(); + parse::Parser::CleanParserResource(); + parse::CleanDataClassToClassMap(); + trace::ClearTraceStack(); + is_cleaned_ = true; +} +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/resource.h b/mindspore/ccsrc/pipeline/jit/resource.h new file mode 100644 index 0000000000..819fdd3d20 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/resource.h @@ -0,0 +1,120 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ +#define MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ + +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +#include "utils/any.h" +#include "utils/profile.h" +#include "ir/manager.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "./common.h" + +namespace mindspore { +namespace pipeline { + +namespace py = pybind11; + +const char kBackend[] = "backend"; +const char kStepParallelGraph[] = "step_parallel"; +const char kOutput[] = "output"; + +class InferenceResource; + +using MethodMap = std::unordered_map>; + +MethodMap &GetMethodMap(); + +class ResourceBase { + public: + ResourceBase() { manager_ = MakeManager(); } + + virtual ~ResourceBase() = default; + + FuncGraphManagerPtr manager() { return manager_; } + // set a manager defined outside which will not manage the graphs. + void set_manager(const FuncGraphManagerPtr &manager) { manager_ = manager; } + + std::unordered_map &results() { return results_; } + + void SetResult(const std::string &key, const Any &value) { results_[key] = value; } + + Any GetResult(const std::string &key) { + if (results_.count(key) == 0) { + MS_LOG(EXCEPTION) << "this key is not in resource list:" << key; + } + return results_[key]; + } + + bool HasResult(const std::string &key) const { return results_.count(key) != 0; } + + std::unordered_map results_; + + protected: + FuncGraphManagerPtr manager_; +}; + +using ResourceBasePtr = std::shared_ptr; + +class Resource : public ResourceBase { + public: + explicit Resource(const py::object &obj = py::none()); + + ~Resource() override; + + abstract::AnalysisEnginePtr engine() { return engine_; } + + static bool IsTypeInMethodMap(const TypeId &type); + + static Any GetMethodPtr(const TypeId &type, const std::string &name); + + const py::object &input() const { return input_; } + + FuncGraphPtr func_graph() const { return func_graph_; } + void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } + + const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; } + void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; } + + // Reclaim resource and clear the cache. + // ExecutorPy::Compile() can be called multiple times, so cache + // should be cleared. + void Clean(); + + private: + abstract::AnalysisEnginePtr engine_; + FuncGraphPtr func_graph_; + abstract::AbstractBasePtrList args_spec_; + py::object input_; + bool is_cleaned_; +}; + +using ResourcePtr = std::shared_ptr; + +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc new file mode 100644 index 0000000000..8bdb2a0c6c --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc @@ -0,0 +1,361 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/abstract_function.h" + +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" + +namespace mindspore { +namespace abstract { +class Evaluator; +class AnalysisEngine; + +AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) { + if (func_list.size() == 1) { + return func_list[0]; + } + return std::make_shared(func_list); +} + +AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) { + auto this_func = shared_from_base(); + if (other->isa()) { + if (*this_func == *other) { + return this_func; + } + return std::make_shared(this_func, other); + } + auto other_union = dyn_cast(other); + if (other_union->IsSuperSet(this_func)) { + return other; + } + return std::make_shared(this_func, other); +} + +void AbstractFuncAtom::Visit(std::function visit_func) const { + visit_func(const_cast(this)->shared_from_base()); +} + +bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; } + +AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; } + +AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) { + AbstractFuncAtomPtrList new_func_list; + auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); }; + + first->Visit(build_func_list); + second->Visit(build_func_list); + func_list_ = new_func_list; +} + +std::string AbstractFuncUnion::ToString() const { + std::ostringstream buffer; + buffer << "AbstractFuncUnion({"; + int i = 0; + for (const auto &func : func_list_) { + MS_EXCEPTION_IF_NULL(func); + buffer << "[" << i << "]: " << func->ToString() << ", "; + i++; + } + buffer << "})"; + return buffer.str(); +} + +bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) { + MS_EXCEPTION_IF_NULL(other); + std::vector is_in_list; + auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) { + auto iter = find(func_list_.begin(), func_list_.end(), func); + if (iter == func_list_.end()) { + is_in_list.push_back(false); + } + return true; + }; + other->Visit(build_in_list); + return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; }); +} + +AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { + auto this_func = shared_from_base(); + if (other->isa()) { + if (IsSuperSet(other)) { + return this_func; + } + return std::make_shared(this_func, other); + } + auto other_union = dyn_cast(other); + if (other_union->IsSuperSet(this_func)) { + return other; + } + return std::make_shared(this_func, other); +} + +void AbstractFuncUnion::Visit(std::function visit_func) const { + for (AbstractFuncAtomPtr poss : func_list_) { + visit_func(poss); + } +} + +bool AbstractFuncUnion::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_union = static_cast(&other); + if (func_list_.size() != other_union->func_list_.size()) { + return false; + } + if (func_list_ == other_union->func_list_) { + return true; + } + return false; +} + +std::size_t AbstractFuncUnion::hash() const { + std::size_t hash_sum = 0; + for (auto f : func_list_) { + hash_sum = hash_combine(hash_sum, f->hash()); + } + return hash_sum; +} + +EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_prim = static_cast(&other); + if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) { + return true; + } + return false; +} + +std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } + +EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_fg = static_cast(&other); + if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { + return true; + } + return false; +} + +std::size_t FuncGraphAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), func_graph_->hash()); + hash_value = hash_combine(hash_value, context_->hash()); + return hash_value; +} + +std::string FuncGraphAbstractClosure::ToString() const { + std::stringstream ss; + ss << "FuncGraphAbstractClosure: " + << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString(); + return ss.str(); +} + +EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_meta_fg = static_cast(&other); + if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { + return true; + } + return false; +} + +std::size_t MetaFuncGraphAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); + return hash_value; +} + +std::string MetaFuncGraphAbstractClosure::ToString() const { + return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name(); +} + +bool PartialAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_partial = static_cast(&other); + if (fn_ != other_partial->fn_) { + return false; + } + if (args_spec_list_.size() != other_partial->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_partial->args_spec_list_) { + return true; + } + return false; +} + +std::size_t PartialAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), fn_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +std::string PartialAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; + for (auto arg : args_spec_list_) { + buffer << arg->ToString() << ", "; + } + buffer << "))"; + return buffer.str(); +} + +EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_transformed = static_cast(&other); + if (fn_ == other_transformed->fn_) { + return true; + } + return false; +} + +std::size_t JTransformedAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), fn_->hash()); + return hash_value; +} + +EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_virtual = static_cast(&other); + if (output_ != other_virtual->output_) { + return false; + } + if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_virtual->args_spec_list_) { + return true; + } + return false; +} + +std::size_t VirtualAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), output_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +std::string VirtualAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "VirtualAbstractClosure(args: {"; + int i = 0; + for (const auto &arg : args_spec_list_) { + MS_EXCEPTION_IF_NULL(arg); + buffer << "[" << i << "]: " << arg->ToString() << ", "; + i++; + } + buffer << "}, output: " << output_->ToString() << ")"; + return buffer.str(); +} + +EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_typed = static_cast(&other); + if (output_ != other_typed->output_) { + return false; + } + if (prim_ != other_typed->prim_) { + return false; + } + if (args_spec_list_.size() != other_typed->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_typed->args_spec_list_) { + return true; + } + return false; +} + +std::size_t TypedPrimitiveAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), prim_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +std::string TypedPrimitiveAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {"; + int i = 0; + for (const auto &arg : args_spec_list_) { + MS_EXCEPTION_IF_NULL(arg); + buffer << "[" << i << "]: " << arg->ToString() << ", "; + i++; + } + buffer << "}, output: " << output_->ToString() << ")"; + return buffer.str(); +} + +bool DummyAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + return true; +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h new file mode 100644 index 0000000000..0823b21cd7 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h @@ -0,0 +1,303 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ +#define PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ + +#include +#include + +#include "abstract/abstract_value.h" +#include "abstract/analysis_context.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +namespace abstract { +class AbstractFuncAtom : public AbstractFunction { + public: + AbstractFuncAtom() = default; + ~AbstractFuncAtom() override = default; + MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) + + AbstractFunctionPtr GetUnique() override { return shared_from_base(); } + EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { + MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom"; + } + + AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; + void Visit(std::function) const final; + bool operator==(const AbstractFunction &other) const override; + + std::size_t hash() const override { return tid(); } +}; + +class AbstractFuncUnion : public AbstractFunction { + public: + explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list); + AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second); + ~AbstractFuncUnion() override = default; + MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction) + + std::string ToString() const override; + + AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } + EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { + MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion"; + } + bool IsSuperSet(const AbstractFunctionPtr &other); + AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; + void Visit(std::function) const final; + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } + + private: + AbstractFuncAtomPtrList func_list_; +}; + +class PrimitiveAbstractClosure : public AbstractFuncAtom { + public: + // Represents a Primitive. + // prim: The primitive + // tracking_id: Identifies different uses of the same primitive. + explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_id = nullptr) + : prim_(prim), tracking_id_(AnfNodeWeakPtr(tracking_id)) {} + ~PrimitiveAbstractClosure() override = default; + MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) + + EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; + + PrimitivePtr prim() { return prim_; } + + AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + + void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); } + + AbstractFunctionPtr Copy() const override { return std::make_shared(prim_, tracking_id()); } + + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override { return "Prim: " + prim_->name(); } + + private: + PrimitivePtr prim_; + // store it as weak_ptr to break reference cycle. + // one reference cycle example is Graph::set_output() input0 local variable. + AnfNodeWeakPtr tracking_id_; +}; + +class FuncGraphAbstractClosure : public AbstractFuncAtom { + public: + // Represents a Graph in a certain Context. + // context: The context, or Context.empty() + FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) + : func_graph_(func_graph), context_(context) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(context); + } + ~FuncGraphAbstractClosure() override = default; + MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) + + EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; + + FuncGraphPtr func_graph() { return func_graph_; } + + AnalysisContextPtr context() const override { return context_; } + + AbstractFunctionPtr Copy() const override { + return std::make_shared(func_graph_, context_); + } + + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + FuncGraphPtr func_graph_; + AnalysisContextPtr context_; +}; +using FuncGraphAbstractClosurePtr = std::shared_ptr; + +class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { + public: + explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope) + : meta_func_graph_(meta_func_graph), scope_(scope) {} + ~MetaFuncGraphAbstractClosure() override = default; + MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) + + MetaFuncGraphPtr meta_func_graph() { return meta_func_graph_; } + + AnalysisContextPtr context() const override { return kDummyAnalysisContext; } + + EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; + + ScopePtr GetScope() { return scope_; } + + AbstractFunctionPtr Copy() const override { return std::make_shared(meta_func_graph_); } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + MetaFuncGraphPtr meta_func_graph_; + ScopePtr scope_; +}; +using MetaFuncGraphAbstractClosurePtr = std::shared_ptr; + +class PartialAbstractClosure : public AbstractFuncAtom { + public: + // Represents a partial application. + // args_spec_list: The first few arguments of that function + PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list, + const AnfNodePtr &node = nullptr) + : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {} + ~PartialAbstractClosure() override = default; + MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) + + EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; + + AbstractFunctionPtr fn() { return fn_; } + AbstractBasePtrList args() { return args_spec_list_; } + AnfNodePtr node() { return node_.lock(); } + void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } + AbstractFunctionPtr Copy() const override { + return std::make_shared(fn_, args_spec_list_, node_.lock()); + } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + AbstractFuncAtomPtr fn_; + AbstractBasePtrList args_spec_list_; + // The CNode which this PartialAbstractClosure evaluated from. + AnfNodeWeakPtr node_; +}; + +class JTransformedAbstractClosure : public AbstractFuncAtom { + public: + // Represents a Function transformed through the application of J. + explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} + ~JTransformedAbstractClosure() override = default; + MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) + EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; + + AbstractFuncAtomPtr fn() { return fn_; } + AbstractFunctionPtr Copy() const override { return std::make_shared(fn_); } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override { return "J(" + fn_->ToString() + ")"; } + + private: + AbstractFuncAtomPtr fn_; +}; + +class VirtualAbstractClosure : public AbstractFuncAtom { + public: + // Represents some function with an explicitly fixed type signature. + // args_spec_list: The arguments as abstract value given to the function + // output: The output which is abstract value. + VirtualAbstractClosure(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output_spec) + : args_spec_list_(args_spec_list), output_(output_spec) {} + VirtualAbstractClosure(const AbstractBasePtr &args_spec, const AbstractBasePtr &output_spec) + : args_spec_list_({args_spec}), output_(output_spec) {} + ~VirtualAbstractClosure() override = default; + MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) + + EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; + + AbstractBasePtrList args_spec_list() { return args_spec_list_; } + + AbstractBasePtr output() { return output_; } + AbstractFunctionPtr Copy() const override { + return std::make_shared(args_spec_list_, output_); + } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + AbstractBasePtrList args_spec_list_; + AbstractBasePtr output_; +}; +using VirtualAbstractClosurePtr = std::shared_ptr; + +class TypedPrimitiveAbstractClosure : public AbstractFuncAtom { + public: + // Represents a Primitive with an explicitly fixed type signature. + // args_spec_list: The arguments as abstract value given to the Primitive + // output: The output which is abstract value. + TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_spec_list, + const AbstractBasePtr &output_spec) + : prim_(prim), args_spec_list_(args_spec_list), output_(output_spec) {} + ~TypedPrimitiveAbstractClosure() override = default; + MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) + + EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; + + PrimitivePtr prim() { return prim_; } + AbstractBasePtrList args_spec_list() { return args_spec_list_; } + AbstractBasePtr output() { return output_; } + AbstractFunctionPtr Copy() const override { + return std::make_shared(prim_, args_spec_list_, output_); + } + bool operator==(const AbstractFunction &other) const override; + std::size_t hash() const override; + + std::string ToString() const override; + + private: + PrimitivePtr prim_; + AbstractBasePtrList args_spec_list_; + AbstractBasePtr output_; +}; + +// Represents a function that can't be called. +class DummyAbstractClosure : public AbstractFuncAtom { + public: + DummyAbstractClosure() = default; + ~DummyAbstractClosure() override = default; + MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom) + + EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; } + + AbstractFunctionPtr Copy() const override { return std::make_shared(); } + bool operator==(const AbstractFunction &other) const override; + + std::string ToString() const override { return "DummyAbstractClosure()"; } +}; + +struct AbstractFunctionHasher { + std::size_t operator()(const AbstractFunctionPtr &t) const { + std::size_t hash = t->hash(); + return hash; + } +}; + +struct AbstractFunctionEqual { + bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; } +}; +} // namespace abstract +} // namespace mindspore +#endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc new file mode 100644 index 0000000000..3e820eed3a --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -0,0 +1,404 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/evaluator.h" + +#include +#include + +#include "ir/func_graph_cloner.h" +#include "abstract/utils.h" +#include "debug/trace.h" + +namespace mindspore { +namespace abstract { +namespace { +string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list, + const AnfNodeConfigPtr &out_conf) { + MS_EXCEPTION_IF_NULL(evaluator); + std::stringstream ss; + if (out_conf != nullptr) { + ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name(); + } + for (size_t i = 0; i < arg_spec_list.size(); i++) { + ss << evaluator->ToString() << " input[" << i << "] abstract value: " << arg_spec_list[i]->ToString(); + } + return ss.str(); +} + +void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) { + MS_EXCEPTION_IF_NULL(evaluator); + if (out_conf != nullptr) { + auto node = out_conf->node(); + if (IsValueNode(node)) { + MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope() + << ", with debug info: " << trace::GetDebugInfo(node->debug_info()); + } else { + MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString() + << ", with debug info: " << trace::GetDebugInfo(node->debug_info()); + } + } +} +} // namespace + +AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, + const AbstractBasePtrList &args_spec_list) { + AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); + normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list); + FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); + MS_EXCEPTION_IF_NULL(parent_context_); + AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); + return context; +} + +static std::vector FastShadowSort(const AnfNodePtr &ret_node) { + auto current_func_graph = ret_node->func_graph(); + MS_EXCEPTION_IF_NULL(current_func_graph); + + std::vector sorted_nodes; + auto seen = NewSeenGeneration(); + std::size_t index = 0; + sorted_nodes.emplace_back(ret_node); + while (index < sorted_nodes.size()) { + auto current = sorted_nodes[index]; + index++; + MS_EXCEPTION_IF_NULL(current); + if (current->isa()) { + auto &inputs = current->cast()->inputs(); + for (auto it = inputs.begin(); it != inputs.end(); it++) { + AnfNodePtr input = *it; + if (input != nullptr && input->isa() && input->seen_ != seen && + input->func_graph() == current_func_graph) { + sorted_nodes.emplace_back(input); + input->seen_ = seen; + } + } + } + } + return sorted_nodes; +} + +EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { + FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); + MS_EXCEPTION_IF_NULL(fg); + std::size_t nargs = fg->parameters().size(); + if (args_spec_list.size() != nargs) { + MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is " + << fg->parameters().size() << ", but the number of provided arguments is " + << args_spec_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info()); + } + MS_EXCEPTION_IF_NULL(parent_context_); + MS_EXCEPTION_IF_NULL(engine); + graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list); + const auto ¶meters = fg->parameters(); + for (size_t i = 0; i < nargs; i++) { + const auto &arg = args_spec_list[i]; + const auto &node = parameters[i]; + AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); + engine->cache().set_value(conf, std::make_shared(arg, nullptr)); + } + const AnfNodePtr &func_node = fg->get_return(); + + MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString() + << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); + AbstractBasePtr ret_base = nullptr; + std::vector nodes = FastShadowSort(func_node); + for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { + const auto &node = *it; + AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); + MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); + ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); + MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() + << ", abstract: " << ret_base->ToString(); + } + + MS_EXCEPTION_IF_NULL(ret_base); + MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() + << ", is stub: " << fg->stub(); + if (fg->stub()) { + return std::make_shared(std::make_shared(), nullptr); + } + return std::make_shared(ret_base, nullptr); +} + +AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { + MS_EXCEPTION_IF_NULL(func_graph_); + if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + AbstractBasePtrList broaded_list; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(arg); + return arg->Broaden(); + }); + MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) + << ", broaded: " << mindspore::ToString(broaded_list); + return broaded_list; + } + return args_spec_list; +} + +AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(func_graph_); + if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + return args_spec_list; + } + if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { + if (parent_context_) { + MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() + << ", context: " << parent_context_->ToString(); + auto last_context = parent_context_->Filter(func_graph_); + if (last_context && last_context->func_graph() == func_graph_) { + MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString(); + MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); + MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); + // Join the last eval arguments and current arguments to check if there are loop variant. + auto joined_args_spec_list = AbstractJoin(args_spec_list, last_context->args_spec_list()); + MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list); + // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. + if (!(joined_args_spec_list == args_spec_list)) { + func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; + } + return joined_args_spec_list; + } + } + if (trace_.size() != 0) { + MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); + MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); + // Join the last eval arguments and current arguments to check if there are loop variant. + auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back()); + // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. + if (!(joined_args_spec_list == args_spec_list)) { + trace_.push_back(joined_args_spec_list); + func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; + } + MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); + return joined_args_spec_list; + } else { + trace_.push_back(args_spec_list); + } + } + return args_spec_list; +} + +FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { + auto iter = func_graph_cache_.find(args_spec_list); + FuncGraphPtr ret = nullptr; + if (iter == func_graph_cache_.end()) { + auto fg = func_graph(); + MS_EXCEPTION_IF_NULL(fg); + TraceManager::DebugTrace(std::make_shared(fg->debug_info())); + FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list); + TraceManager::EndTrace(); + func_graph_cache_[args_spec_list] = generated_graph; + MS_EXCEPTION_IF_NULL(engine); + engine->func_graph_manager()->AddFuncGraph(generated_graph); + ret = generated_graph; + } else { + ret = iter->second; + } + + // For the top graph, if it is replaced by generated graph, update the top graph to the new one. + if (parse::Parser::GetTopFuncGraph() == func_graph()) { + if (ret != func_graph()) { + parse::Parser::UpdateTopFuncGraph(ret); + } + } + return ret; +} + +FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { + auto iter = func_graph_cache_.find(args_spec_list); + if (iter != func_graph_cache_.end()) { + return iter->second; + } + + MS_EXCEPTION_IF_NULL(meta_func_graph_); + FuncGraphPtr generated_func_graph = nullptr; + if (this->bound_node() != nullptr) { + TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); + generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); + TraceManager::EndTrace(); + } else { + generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); + } + + FuncGraphPtr cloned_func_graph = BasicClone(generated_func_graph); + func_graph_cache_[args_spec_list] = cloned_func_graph; + MS_EXCEPTION_IF_NULL(engine); + engine->func_graph_manager()->AddFuncGraph(cloned_func_graph); + return cloned_func_graph; +} + +EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { + const std::string &evaluator_name = ToString(); + + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + args_spec_list = NormalizeArgs(args_spec_list); + args_spec_list = BroadenUndeterminedArgs(args_spec_list); + trace::TraceGraphEvalEnter(shared_from_base(), out_conf); + MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base(), args_spec_list, out_conf); + MS_EXCEPTION_IF_NULL(cache_); + auto iter = cache_->find(args_spec_list); + if (iter == cache_->end()) { + MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; + EvalResultPtr ret = Eval(engine, args_spec_list); + if (ret->abstract() == nullptr) { + EvalFailLogging(shared_from_base(), args_spec_list, out_conf); + MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; + } + MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; + (*cache_)[args_spec_list] = ret; + trace::TraceGraphEvalLeave(shared_from_base()); + return ret; + } else { + MS_EXCEPTION_IF_NULL(iter->second); + MS_EXCEPTION_IF_NULL(iter->second->abstract()); + MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << "."; + trace::TraceGraphEvalLeave(shared_from_base()); + return iter->second; + } +} + +EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + EvalResultPtr ret = EvalPrim(engine, args_spec_list); + return ret; +} + +EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + if (args_conf_list.size() == 0) { + MS_LOG(EXCEPTION) << "Size should greater than 0"; + } + EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); + // No need to cache. + return ret; +} + +EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { + EvalResultPtr ret = EvalPrim(args_conf_list); + return ret; +} + +EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); + // Don't lookup from cache, as different out_conf with same node but different context + // may add different entry to anfnode_config_map_, like getattr primitive. + (*cache_)[args_spec_list] = ret; + return ret; +} + +EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + MS_EXCEPTION_IF_NULL(cache_); + auto iter = cache_->find(args_spec_list); + if (iter != cache_->end()) { + return iter->second; + } + + ConfigPtrList partial_args_conf_list; + // Join arguments in partial and the rest arguments from args_conf_list. + (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(partial_args_conf_list), + [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); + + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), + [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); + EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); + + (*cache_)[args_spec_list] = ret; + return ret; +} + +EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + MS_EXCEPTION_IF_NULL(cache_); + auto iter = cache_->find(args_spec_list); + if (iter != cache_->end()) { + return iter->second; + } + + // Call the original evaluator, get the result: y = f(x) + EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr); + // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input + // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) + AbstractBasePtrList bparams; + bparams.push_back(SensitivityTransform(orig_func_)); + (void)std::transform( + args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), + [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); + AbstractBasePtr bparams_final = std::make_shared(bparams); + AbstractFunctionPtr bprop = + std::make_shared(SensitivityTransform(result->abstract()), bparams_final); + + // J(f)(J(x)) return a tuple (y, bprop_f) + AbstractBasePtrList jargs = {result->abstract(), bprop}; + AbstractBasePtr jtuple = std::make_shared(jargs); + auto infer_reuslt = std::make_shared(jtuple, std::make_shared()); + (*cache_)[args_spec_list] = infer_reuslt; + return infer_reuslt; +} + +EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.size() != args_spec_list_.size()) { + MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() + << ", arguments no: " << args_spec_list.size(); + } + // Check each parameter and argument match; + for (std::size_t i = 0; i < args_spec_list.size(); i++) { + MS_EXCEPTION_IF_NULL(args_spec_list[i]); + (void)args_spec_list[i]->Join(args_spec_list_[i]); + } + return std::make_shared(output_, std::make_shared()); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h new file mode 100644 index 0000000000..461574257d --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h @@ -0,0 +1,330 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ +#define PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ + +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace abstract { +using EvaluatorCacheMap = + std::unordered_map; +using EvaluatorCacheMapPtr = std::shared_ptr; + +using EvaluatorAttrMap = + std::unordered_map; +using EvaluatorAttrMapPtr = std::shared_ptr; + +class Evaluator : public Base { + public: + explicit Evaluator(const std::string &id) + : cache_(std::make_shared()), + attr_cache_(std::make_shared()), + identifier_(id) {} + ~Evaluator() override = default; + MS_DECLARE_PARENT(Evaluator, Base); + + // difference between Run() and Eval(): + // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. + // Run() will modify cache_ member, so it cannot marked as const; + virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); + + virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; + + virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } + + virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { + return args_spec_list; + } + + virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (!enable_sparse) { + return nullptr; + } + + auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { + if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { + return true; + } + return false; + }); + if (is_abstract) { + MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result"; + return std::make_shared(std::make_shared(), std::make_shared()); + } + return nullptr; + } + + std::string ToString() const override { return identifier_; } + + virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } + + virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } + + EvaluatorCacheMapPtr &cache() { return cache_; } + EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } + + EvaluatorCacheMapPtr cache_; + EvaluatorAttrMapPtr attr_cache_; + std::string identifier_; + + AnfNodeWeakPtr bound_node_; +}; + +class PrimEvaluator : public Evaluator { + public: + explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} + ~PrimEvaluator() override = default; + MS_DECLARE_PARENT(PrimEvaluator, Evaluator); + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } +}; + +class TrivialPrimEvaluator : public PrimEvaluator { + public: + explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} + ~TrivialPrimEvaluator() override = default; + MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; +}; + +class TransitionPrimEvaluator : public PrimEvaluator { + public: + explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} + ~TransitionPrimEvaluator() override = default; + MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + // Parameter in_conf0 : the first element in args_conf_list; + virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; +}; + +class SymbolicPrimEvaluator : public PrimEvaluator { + public: + explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} + ~SymbolicPrimEvaluator() override = default; + MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; +}; + +// Evaluator will be stored in AnalysisEngine.constructors_ +using EvaluatorPtrList = std::vector; + +class DummyEvaluator : public Evaluator { + public: + DummyEvaluator() : Evaluator("dummy") {} + ~DummyEvaluator() override = default; + MS_DECLARE_PARENT(DummyEvaluator, Evaluator); + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } +}; + +// Wrap another evaluator to track a subset of uses. +// A TrackedEvaluator has its own cache that maps possible calls to +// their results, but is ultimately backed by a different evaluator. +// Multiple TrackedEvaluators can be backed by the same Evaluator. +class TrackedEvaluator : public Evaluator { + public: + explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {} + ~TrackedEvaluator() override = default; + MS_DECLARE_PARENT(TrackedEvaluator, Evaluator); + AnfNodePtr bound_node() const override { + if (sub_evaluator_ != nullptr) { + return sub_evaluator_->bound_node(); + } + return bound_node_.lock(); + } + + void set_bound_node(const AnfNodePtr &node) override { + if (sub_evaluator_ != nullptr) { + sub_evaluator_->set_bound_node(node); + } + bound_node_ = AnfNodeWeakPtr(node); + } + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } + + private: + EvaluatorPtr sub_evaluator_; +}; + +class BaseFuncGraphEvaluator : public Evaluator { + public: + explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context) + : Evaluator("basegraph"), parent_context_(context) {} + + ~BaseFuncGraphEvaluator() override = default; + MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); + + EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + + virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; + + AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list); + AnalysisContextPtr graph_context() const { return graph_context_; } + + protected: + AnalysisContextPtr parent_context_; + + private: + AnalysisContextPtr graph_context_; +}; + +class FuncGraphEvaluator : public BaseFuncGraphEvaluator { + public: + FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) + : BaseFuncGraphEvaluator(context->Filter(func_graph)), func_graph_(func_graph) {} + + ~FuncGraphEvaluator() override = default; + MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); + + FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + + FuncGraphPtr func_graph() { return func_graph_; } + + AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; + AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override; + std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } + + private: + FuncGraphPtr func_graph_; + std::unordered_map + func_graph_cache_; + std::vector trace_; +}; +using FuncGraphEvaluatorPtr = std::shared_ptr; + +class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator { + public: + // Note: context parameter is not used; + MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, AnalysisContextPtr, const ScopePtr &scope) + : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {} + ~MetaFuncGraphEvaluator() override = default; + MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator); + + FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + + // Return normalized versions of the arguments. + AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { + return meta_func_graph_->NormalizeArgs(args_spec_list); + } + std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); } + + private: + MetaFuncGraphPtr meta_func_graph_; + std::unordered_map + func_graph_cache_; + ScopePtr scope_; +}; + +class PartialAppEvaluator : public Evaluator { + public: + PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args) + : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_spec_list_(args) {} + ~PartialAppEvaluator() override = default; + MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator); + AnfNodePtr bound_node() const override { + if (evaluator_ != nullptr) { + return evaluator_->bound_node(); + } + return bound_node_.lock(); + } + + void set_bound_node(const AnfNodePtr &node) override { + if (evaluator_ != nullptr) { + evaluator_->set_bound_node(node); + } + bound_node_ = AnfNodeWeakPtr(node); + } + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; + } + + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } + + private: + EvaluatorPtr evaluator_; + AbstractBasePtrList args_spec_list_; +}; + +class VirtualEvaluator : public Evaluator { + public: + VirtualEvaluator(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output) + : Evaluator("virtual"), args_spec_list_(args_spec_list), output_(output) {} + ~VirtualEvaluator() override = default; + MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); + + EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + std::string ToString() const override { return identifier_; } + + private: + AbstractBasePtrList args_spec_list_; + AbstractBasePtr output_; +}; + +class JEvaluator : public Evaluator { + public: + JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func) + : Evaluator("JEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {} + ~JEvaluator() override = default; + MS_DECLARE_PARENT(JEvaluator, Evaluator); + AnfNodePtr bound_node() const override { + if (evaluator_ != nullptr) { + return evaluator_->bound_node(); + } + return bound_node_.lock(); + } + + void set_bound_node(const AnfNodePtr &node) override { + if (evaluator_ != nullptr) { + evaluator_->set_bound_node(node); + } + bound_node_ = AnfNodeWeakPtr(node); + } + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; + } + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } + + private: + EvaluatorPtr evaluator_; + AbstractFunctionPtr orig_func_; +}; +} // namespace abstract +} // namespace mindspore +#endif // PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc new file mode 100644 index 0000000000..99e613395c --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -0,0 +1,1384 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/prim.h" + +#include +#include +#include +#include +#include +#include + +#include "frontend/operator/cc_implementations.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/do_signature.h" +#include "frontend/operator/prim_to_function.h" +#include "abstract/utils.h" +#include "utils/symbolic.h" +#include "./common.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/parse/resolve.h" +#include "ir/tensor.h" +#include "utils/convert_utils.h" +#include "utils/context/ms_context.h" +#include "pipeline/jit/parse/data_converter.h" +#include "abstract/param_validator.h" +#include "common/utils.h" + +namespace mindspore { +namespace abstract { +PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { + static PrimitiveEvalImplMap prim_eval_implement_map = { + // Statements + {prim::kPrimReturn, {InferImplReturn, true}}, + {prim::kPrimTypeOf, {InferImplTypeof, false}}, + {prim::kPrimHasType, {InferImplHasType, false}}, + {prim::kPrimDot, {InferImplDot, true}}, + {prim::kPrimSwitch, {InferImplSwitch, true}}, + {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, + {prim::kPrimIs_, {InferImplIs_, true}}, + {prim::kPrimIsNot, {InferImplIsNot, true}}, + {prim::kPrimInDict, {InferImplInDict, true}}, + {prim::kPrimNotInDict, {InferImplNotInDict, true}}, + {prim::kPrimIsConsant, {InferImplIsConstant, true}}, + // Maths + {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, + {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, + // Array + {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, + {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, + {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, + {prim::kPrimShape, {InferImplShape, true}}, + {prim::kPrimPack, {InferImplPack, true}}, + // Structure + {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, + {prim::kPrimMakeList, {InferImplMakeList, true}}, + {prim::kPrimMakeDict, {InferImplMakeDict, true}}, + {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, + {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, + {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, + {prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, + {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, + {prim::kPrimListGetItem, {InferImplListGetItem, true}}, + {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, + {prim::kPrimListSetItem, {InferImplListSetItem, true}}, + {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, + {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, + {prim::kPrimListAppend, {InferImplListAppend, true}}, + {prim::kPrimTupleLen, {InferImplTupleLen, true}}, + {prim::kPrimListLen, {InferImplListLen, true}}, + {prim::kPrimArrayLen, {InferImplArrayLen, true}}, + {prim::kPrimListMap, {InferImplListMap, false}}, + {prim::kPrimListReduce, {InferImplListReduce, false}}, + {prim::kPrimTupleReversed, {InferImplTupleReversed, false}}, + {prim::kPrimReducedShape, {InferImplReduceShape, false}}, + {prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, + {prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, + {prim::kPrimShapeMul, {InferImplShapeMul, false}}, + {prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, + {prim::kPrimListEqual, {InferImplListEqual, false}}, + {prim::kPrimMakeRange, {InferImplMakeRange, false}}, + {prim::kPrimStopGradient, {InferImplStopGradient, false}}, + {prim::kPrimStringEqual, {InferImplStringEqual, false}}, + {prim::kPrimStringConcat, {InferImplStringConcat, false}}, + {prim::kPrimDictLen, {InferImplDictLen, false}}, + // NN + {prim::kPrimPooling, {InferImplPooling, true}}, + {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, + {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, + {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, + {prim::kPrimReluGrad, {InferImplReluGrad, true}}, + {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, + {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, + {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, + {prim::kPrimRelu, {InferImplRelu, true}}, + {prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, + {prim::kPrimZerosLike, {InferImplZerosLike, true}}, + {prim::kPrimBpropCut, {InferImplBpropCut, true}}, + {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, + {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, + {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, + // Others + {prim::kPrimIdentity, {InferImplIdentity, true}}, + // Set impl to null as it will use PartialEvaluator; + {prim::kPrimPartial, {nullptr, true}}, + {prim::kPrimJ, {InferImplJ, false}}, + {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, + {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, + {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, + {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, + {prim::kPrimMakeRef, {InferImplMakeRef, true}}, + {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, + {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, + {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, + {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, + {prim::kPrimDepend, {InferImplDepend, true}}, + {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, + {prim::kPrimControlDepend, {InferImplControlDepend, true}}, + // Debug + {prim::kPrimDebug, {InferImplDebug, true}}, + // IndexedSlices + {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}}, + {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}}, + {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}}, + {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}}, + {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}}, + }; + return prim_eval_implement_map; +} + +using mindspore::parse::PyObjectWrapper; + +EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { + if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { + auto ret_abstract = AbstractEval(args); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; + return ret_abstract; + } + } + prim_->BeginRecordAddAttr(); + AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); + prim_->EndRecordAddAttr(); + auto added_attrs = prim_->evaluate_added_attrs(); + auto infer_result = std::make_shared(abs_base, std::make_shared(added_attrs)); + return infer_result; +} + +EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); + auto ret_abstract = AbstractEval(args_spec_list); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; + return ret_abstract; + } + + if (out_conf->node() == nullptr || !out_conf->node()->isa()) { + MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; + } + + auto do_signature = dyn_cast(prim_); + auto out_node = dyn_cast(out_conf->node()); + const auto &out_node_inputs = out_node->inputs(); + if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { + MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString() + << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() + << ", inputs size " << out_node_inputs.size(); + } + AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; + + ScopePtr scope = kDefaultScope; + if (out_conf != nullptr) { + scope = out_conf->node()->scope(); + } + ScopeGuard scope_guard(scope); + + AnfNodePtr new_cnode = nullptr; + if (bound_node() != nullptr) { + TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); + new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, + args_inputs); + TraceManager::EndTrace(); + } else { + new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, + args_inputs); + } + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context()); + + return engine->ForwardConfig(out_conf, fn_conf); +} + +static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) { + // arg[0] is the func graph to unpack, ignore it + AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end()); + AbstractBasePtrList graph_specialize_args; + if (need_unpack) { + for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) { + MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]); + if (specialize_args_before_unpack[index]->isa()) { + AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast(); + std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), + std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; }); + } else if (specialize_args_before_unpack[index]->isa()) { + AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast(); + auto dict_elems = arg_dict->elements(); + (void)std::transform( + dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args), + [](const AbstractAttribute &item) { return std::make_shared(item.first, item.second); }); + } else { + MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got " + << specialize_args_before_unpack[index]->ToString(); + } + } + } else { + graph_specialize_args = specialize_args_before_unpack; + } + return graph_specialize_args; +} + +EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + if (out_conf->node() == nullptr || !out_conf->node()->isa()) { + MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; + } + + auto unpack_graph = prim_->cast(); + auto out_node = out_conf->node()->cast(); + const auto &out_node_inputs = out_node->inputs(); + if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { + MS_LOG(EXCEPTION) << "UnpackGraphPrimitive" + << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() + << ", inputs size " << out_node_inputs.size(); + } + AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); + // get the forward graph + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + AbstractFunctionPtr fn = args_spec_list[0]->cast(); + if (fn == nullptr) { + MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); + } + auto real_fn = fn->cast(); + MS_EXCEPTION_IF_NULL(real_fn); + FuncGraphPtr forward_graph = real_fn->func_graph(); + MS_EXCEPTION_IF_NULL(forward_graph); + AbstractBasePtrList graph_specialize_args = + GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args()); + + AbstractBasePtrList graph_specialize_args_without_sens; + (void)std::transform(graph_specialize_args.begin(), + graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0), + std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; }); + auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens); + engine->func_graph_manager()->AddFuncGraph(new_graph); + ScopePtr scope = kDefaultScope; + if (out_conf != nullptr) { + scope = out_conf->node()->scope(); + } + ScopeGuard scope_guard(scope); + AnfNodePtr new_vnode = NewValueNode(new_graph); + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context()); + + return engine->ForwardConfig(out_conf, fn_conf); +} + +AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type, + FuncGraphPtr func_graph) { + AnfNodePtr target_node = source_node; + if (node_type->isa()) { + auto x = node_type->cast(); + if (x->element()->BuildType()->isa()) { + auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); + MS_EXCEPTION_IF_NULL(cast); + target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type}); + } + } else if (node_type->isa()) { + auto x = node_type->cast(); + auto &items = x->elements(); + std::vector nodes; + nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + int idx = 0; + for (const auto &item : items) { + AnfNodePtr tuple_node = + func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)}); + AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph); + nodes.emplace_back(node); + ++idx; + } + target_node = func_graph->NewCNode(nodes); + } else if (node_type->isa()) { + auto x = node_type->cast(); + auto &items = x->elements(); + std::vector dict_key_nodes; + std::vector dict_value_nodes; + dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + for (const auto &item : items) { + AnfNodePtr dict_value_node = + func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)}); + AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph); + dict_key_nodes.emplace_back(NewValueNode(item.first)); + dict_value_nodes.emplace_back(node); + } + target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes), + func_graph->NewCNode(dict_value_nodes)}); + } else if (node_type->isa()) { + auto x = node_type->cast(); + std::string kwarg_key = x->get_key(); + AnfNodePtr kwarg_value_node = + func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node}); + AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph); + target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node}); + } + return target_node; +} + +EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + if (out_conf->node() == nullptr || !out_conf->node()->isa()) { + MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; + } + auto out_node = out_conf->node()->cast(); + const auto &out_node_inputs = out_node->inputs(); + if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { + MS_LOG(EXCEPTION) << "MixedPrecisionCast" + << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() + << ", inputs size " << out_node_inputs.size(); + } + AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); + + ScopePtr scope = kDefaultScope; + if (out_conf != nullptr) { + scope = out_conf->node()->scope(); + } + ScopeGuard scope_guard(scope); + + FuncGraphPtr func_graph = out_conf->node()->func_graph(); + AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph); + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context()); + + return engine->ForwardConfig(out_conf, fn_conf); +} + +namespace { +py::object BuildValue(const ValuePtr &value_ptr) { + if (value_ptr == nullptr) { + return py::none(); + } else { + return ValuePtrToPyData(value_ptr); + } +} +} // end anonymous namespace + +py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { + MS_EXCEPTION_IF_NULL(abs_base); + py::dict dic; + if (abs_base->isa()) { + auto arg_tensor = dyn_cast(abs_base); + dic["shape"] = arg_tensor->shape()->shape(); + dic["dtype"] = arg_tensor->BuildType(); + dic["value"] = BuildValue(arg_tensor->BuildValue()); + } else if (abs_base->isa() || abs_base->isa() || abs_base->isa()) { + std::vector shape; + dic["shape"] = shape; + dic["dtype"] = abs_base->BuildType(); + dic["value"] = BuildValue(abs_base->BuildValue()); + } else if (abs_base->isa()) { + auto arg_slice = dyn_cast(abs_base); + std::vector shape; + dic["shape"] = shape; + dic["dtype"] = arg_slice->BuildType(); + dic["value"] = BuildValue(arg_slice->BuildValue()); + } else if (abs_base->isa()) { + auto value = abs_base->cast()->ref(); + dic = ConvertAbstractToPython(value); + } else if (abs_base->isa()) { + dic["shape"] = py::none(); + dic["dtype"] = py::ellipsis(); + dic["value"] = py::ellipsis(); + } else if (abs_base->isa()) { + auto arg_tuple = dyn_cast(abs_base); + size_t len = arg_tuple->size(); + py::tuple shape_tuple(len); + py::tuple dtype_tuple(len); + + for (size_t i = 0; i < len; i++) { + py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]); + shape_tuple[i] = out["shape"]; + dtype_tuple[i] = out["dtype"]; + } + dic["shape"] = shape_tuple; + dic["dtype"] = dtype_tuple; + dic["value"] = BuildValue(arg_tuple->BuildValue()); + } else if (abs_base->isa()) { + auto arg_list = dyn_cast(abs_base); + size_t len = arg_list->size(); + py::list shape_list(len); + py::list dtype_list(len); + + for (size_t i = 0; i < len; i++) { + py::dict out = ConvertAbstractToPython(arg_list->elements()[i]); + shape_list[i] = out["shape"]; + dtype_list[i] = out["dtype"]; + } + dic["shape"] = shape_list; + dic["dtype"] = dtype_list; + dic["value"] = BuildValue(arg_list->BuildValue()); + } else if (abs_base->isa()) { + dic["shape"] = py::none(); + dic["dtype"] = py::none(); + dic["value"] = py::none(); + } else if (abs_base->isa()) { + dic["shape"] = py::none(); + dic["dtype"] = abs_base->BuildType(); + dic["value"] = py::none(); + } else { + auto value = abs_base->BuildValue(); + if ((*value == *kAnyValue)) { + auto value_desc = abs_base->value_desc(); + MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc) + << " for python primitive." << abs_base->ToString(); + } + MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is " + << value->ToString(); + } + return dic; +} + +namespace { +py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrList &args) { + const AbstractBasePtrList *args_ptr; + + if (prim_py->is_tuple_input_) { + if (args.empty()) { + MS_LOG(EXCEPTION) << "Primitive args is empty"; + } + if (args[0] == nullptr || !args[0]->isa()) { + MS_LOG(EXCEPTION) << "Custom Primitive inputs should be packed into a Tuple after converting" + "prim convert pass for GE."; + } + args_ptr = &(args[0]->cast()->elements()); + } else { + args_ptr = &args; + } + + py::tuple py_args(args_ptr->size()); + for (size_t i = 0; i < args_ptr->size(); i++) { + auto arg_i = (*args_ptr)[i]; + py_args[i] = ConvertAbstractToPython(arg_i); + } + return py_args; +} + +AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) { + // Convert to AbstractValue based on type and shape + if (output["value"].is_none()) { + auto out_shape = output["shape"]; + auto out_dtype = output["dtype"]; + return PyListDtype2AbstractTensor(out_shape, out_dtype); + } + // Convert pyobject to Value, then to AbstractValue + ValuePtr converted_ret = nullptr; + bool converted = parse::ConvertData(output["value"], &converted_ret); + if (!converted) { + MS_LOG(EXCEPTION) << "Convert data failed"; + } + auto res_spec = FromValue(converted_ret); + MS_EXCEPTION_IF_NULL(res_spec); + if (res_spec->isa()) { + // Replace to tensor constant node in specialize + auto res_tensor = res_spec->cast(); + res_tensor->set_value(converted_ret); + } + if (prim_py->IsCustomPrim()) { + // Raise error if output_num is not match the infer result. + int output_num = GetValue(prim_py->GetAttr("output_num")); + if (res_spec->isa() && output_num != 1) { + MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num + << " not matches the infer result."; + } else if (res_spec->isa() && + (res_spec->cast()->size() != IntToSize(output_num))) { + MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num + << " not matches the infer result."; + } + } + return res_spec; +} +} // end anonymous namespace + +EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { + auto ret_abstract = AbstractEval(args); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined"; + return ret_abstract; + } + MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); + + const auto &iter = cache_->find(args); + if (iter != cache_->end()) { + return iter->second; + } + auto py_args = PreparePyInputs(prim_py_, args); + + auto pyobj = prim_py_->GetPyObj(); + if (pyobj == nullptr) { + MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; + } + auto infer_fuc = pyobj.attr("__infer__"); + prim_py_->BeginRecordAddAttr(); + py::dict output = infer_fuc(*py_args); + prim_py_->EndRecordAddAttr(); + auto added_attrs = prim_py_->evaluate_added_attrs(); + MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); + auto res_spec = PyInferRes2Abstract(prim_py_, output); + + MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; + auto infer_result = std::make_shared(res_spec, std::make_shared(added_attrs)); + (*cache_)[args] = infer_result; + return infer_result; +} + +EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { + auto ret_abstract = AbstractEval(args); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined"; + return ret_abstract; + } + // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. + if (nargs_ != args.size()) { + MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; + return nullptr; + } + TypePtr ret_value_type = return_value_type_; + ValuePtrList value_list; + for (const auto &arg : args) { + // Check if all arguments are scalar type. + MS_EXCEPTION_IF_NULL(arg); + if (arg->isa()) { + auto arg_scalar = dyn_cast(arg); + auto arg_value = arg_scalar->GetValueTrack(); + value_list.push_back(arg_value); + } else { + // Raise TypeError Expected Scalar. + MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives."; + } + } + for (const auto &item : type_map_) { + TypePtrList selections; + MS_EXCEPTION_IF_NULL(item.second); + (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections), + [&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); }); + TypePtr res = CheckTypeList(item.first, selections); + if (*return_value_type_ == *(item.first)) { + ret_value_type = res; + } + } + + ValuePtr evaluated_value = RunImpl(value_list); + if (!(*evaluated_value == *kAnyValue)) { + ret_value_type = evaluated_value->type(); + } + // for comparison primitives , return type shall have be specified to be bool. + if (specify_out_type_ != nullptr) { + ret_value_type = specify_out_type_; + } + + AbstractScalarPtr abs_base = std::make_shared(evaluated_value, ret_value_type); + return std::make_shared(abs_base, std::make_shared()); +} + +ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { + if (!eval_value_) { + return kAnyValue; + } else { + if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) { + MS_EXCEPTION_IF_NULL(arg); + return arg->isa(); + })) { + return kAnyValue; + } + return impl_(args); + } +} + +// Primitive implementation +// static function start +namespace { +EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) { + EvaluatorPtr prim_evaluator = std::make_shared(primitive, eval_impl); + return prim_evaluator; +} + +EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value, + const TypePtr &specify_out_type) { + FunctionPtr func = nullptr; + (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func); + MS_EXCEPTION_IF_NULL(func); + + EvaluatorPtr uniform_primitive_evaluator = + std::make_shared(func, prim_impl, eval_value, specify_out_type); + return uniform_primitive_evaluator; +} + +const int kResolveCaseUserDefineClass = 1; +const int kResolveCaseBuildinTypeMethod = 2; +const int kResolveCaseFunction = 3; +int GetResolveCase(const TypePtr &data_type) { + MS_EXCEPTION_IF_NULL(data_type); + if (data_type->type_id() == kObjectTypeClass) { + return kResolveCaseUserDefineClass; + } + + // try method map, if not in method map, the data_type should be External type. + if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) { + return kResolveCaseBuildinTypeMethod; + } + + return kResolveCaseFunction; +} + +FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) { + MS_EXCEPTION_IF_NULL(engine); + MS_EXCEPTION_IF_NULL(method); + if (!method->isa()) { + MS_LOG(EXCEPTION) << "Method type error: " << method->ToString(); + } + + std::shared_ptr obj = method->cast>(); + FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj()); + if (func_graph == nullptr) { + MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed"; + } + + FuncGraphManagerPtr manager = engine->func_graph_manager(); + manager->AddFuncGraph(func_graph); + return func_graph; +} + +inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) { + MS_EXCEPTION_IF_NULL(engine); + FuncGraphManagerPtr manager = engine->func_graph_manager(); + manager->AddFuncGraph(func_graph); +} + +EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, + const AnfNodeConfigPtr &old_conf) { + MS_EXCEPTION_IF_NULL(old_conf); + + AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); + AbstractFunctionPtr abs_func = dyn_cast(abs_ptr); + MS_EXCEPTION_IF_NULL(abs_func); + + // Create new cnode + std::vector input = {NewValueNode(prim::kPrimPartial)}; + auto func_graph_func = dyn_cast(abs_func); + if (func_graph_func != nullptr) { + FuncGraphPtr fg = func_graph_func->func_graph(); + input.push_back(NewValueNode(fg)); + } else { + auto prim_func = dyn_cast(abs_func); + MS_EXCEPTION_IF_NULL(prim_func); + PrimitivePtr prim = prim_func->prim(); + input.push_back(NewValueNode(prim)); + } + + AnfNodeConfigPtr conf = dyn_cast(data_conf); + MS_EXCEPTION_IF_NULL(conf); + input.push_back(conf->node()); + MS_EXCEPTION_IF_NULL(old_conf); + FuncGraphPtr func_graph = old_conf->node()->func_graph(); + CNodePtr new_cnode = func_graph->NewCNode(input); + AnalysisEnginePtr eng = old_conf->engine(); + AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context()); + return eng->ForwardConfig(old_conf, fn_conf); +} + +EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, + const AbstractBasePtrList &args_spec_list, + const AnfNodeConfigPtr &out_conf) { + // args_spec_list: same as StaticGetter + if (args_spec_list.size() < 2) { + MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2"; + } + MS_EXCEPTION_IF_NULL(out_conf); + // An external type. + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + MS_EXCEPTION_IF_NULL(args_spec_list[1]); + MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString(); + MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString(); + auto data_v = args_spec_list[0]->BuildValue(); + if (!data_v->isa()) { + MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString(); + } + + auto item_v = args_spec_list[1]->BuildValue(); + if (item_v->isa()) { + item_v = std::make_shared(item_v->cast()->value()); + } + + if (!item_v->isa()) { + MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString(); + } + + // item_name to func addr from obj_map + parse::SymbolPtr symbol = item_v->cast(); + parse::NameSpacePtr name_space = data_v->cast(); + FuncGraphPtr func_graph = out_conf->node()->func_graph(); + + auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node()); + if (new_node == nullptr) { + MS_LOG(EXCEPTION) << "Resolve node failed"; + } + + AnalysisEnginePtr eng = out_conf->engine(); + AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context()); + return eng->ForwardConfig(out_conf, fn_conf); +} + +EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, + const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, + const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << "args_spec_list is empty"; + } + AbstractClassPtr cls = CheckArg("__FUNC__", args_spec_list, 0); + + // If item_v is an attribute, get abstract value from AbstractClass + MS_EXCEPTION_IF_NULL(item_v); + if (!item_v->isa()) { + MS_LOG(EXCEPTION) << "Attribute type error"; + } + std::string item_name = item_v->cast()->value(); + MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name(); + MS_LOG(DEBUG) << "Resolve item: " << item_name; + + AbstractBasePtr attr = cls->GetAttribute(item_name); + if (attr != nullptr) { + return std::make_shared(attr, nullptr); + } + + ValuePtr method = cls->GetMethod(item_name); + if (method->isa()) { + MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() + << ", item value: " << item_v->ToString(); + } + + // Infer class method + ValuePtr converted_v = PyObjToGraph(engine, method); + return StaticGetterInferred(converted_v, data_conf, out_conf); +} + +EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, + const TypePtr &data_type, const ConfigPtr &data_conf, + const AnfNodeConfigPtr &out_conf) { + MS_EXCEPTION_IF_NULL(item_v); + MS_EXCEPTION_IF_NULL(data_type); + // The method maybe a Primitive or Composite + if (!item_v->isa()) { + MS_LOG(EXCEPTION) << "Error item is not string"; + } + + std::string item_name = item_v->cast()->value(); + Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); + if (method.empty()) { + MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name; + } + + ValuePtr converted_v = nullptr; + if (method.is()) { + // composite registered in standard_method_map go to this branch + converted_v = prim::GetPythonOps(method.cast()); + AddToManager(engine, converted_v->cast()); + } else if (method.is()) { + converted_v = method.cast(); + } else { + MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString(); + } + return StaticGetterInferred(converted_v, data_conf, out_conf); +} + +EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { + // Inputs: namespace and its static function; or class and its member function + CheckArgsSize("StaticGetter", args_spec_list, 2); + + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + MS_EXCEPTION_IF_NULL(args_spec_list[1]); + TypePtr data_type = args_spec_list[0]->BuildType(); + ValuePtr item_value = args_spec_list[1]->BuildValue(); + ScopePtr scope = kDefaultScope; + if (out_conf != nullptr) { + scope = out_conf->node()->scope(); + } + ScopeGuard scope_guard(scope); + if (item_value->isa()) { + MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString(); + } + + int case_v = GetResolveCase(data_type); + if (case_v == kResolveCaseUserDefineClass) { + return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); + } else if (case_v == kResolveCaseBuildinTypeMethod) { + return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf); + } else { + return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf); + } +} +} // end anonymous namespace + +// static variable start; +namespace { +class EmbedEvaluator : public SymbolicPrimEvaluator { + public: + EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} + ~EmbedEvaluator() override = default; + MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); + EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { + // arg: free variable to be embedded + if (args_conf_list.size() != 1) { + MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); + } + AnfNodeConfigPtr node_conf = dyn_cast(args_conf_list[0]); + MS_EXCEPTION_IF_NULL(node_conf); + + AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract(); + x = SensitivityTransform(x); + SymbolicKeyInstancePtr key = std::make_shared(node_conf->node(), x); + AbstractScalarPtr abs_scalar = std::make_shared(key, std::make_shared()); + return std::make_shared(abs_scalar, std::make_shared()); + } +}; + +static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) { + auto root_g_set = manager->roots(); + if (root_g_set.size() != 1) { + return nullptr; + } + const FuncGraphPtr &root_g = root_g_set.back(); + + for (auto ¶m_node : root_g->parameters()) { + auto param = param_node->cast(); + if (param && name == param->name()) { + return param; + } + } + return nullptr; +} + +class RefToEmbedEvaluator : public SymbolicPrimEvaluator { + public: + RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} + ~RefToEmbedEvaluator() override = default; + MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); + EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { + if (args_conf_list.size() != 1) { + MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); + return nullptr; + } + static TypePtr type = std::make_shared(); + auto node_conf = dyn_cast(args_conf_list[0]); + if (node_conf == nullptr) { + MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; + return nullptr; + } + AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); + AbstractRefPtr ref_abs = abs->cast(); + if (ref_abs == nullptr) { + MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); + return nullptr; + } + auto key_abs = ref_abs->ref_key(); + if (key_abs == nullptr) { + MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr."; + return nullptr; + } + auto key_value = key_abs->BuildValue(); + if (key_value == nullptr) { + MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr."; + return nullptr; + } + auto refkey = key_value->cast(); + if (refkey == nullptr) { + auto ret = std::make_shared(type); + auto ref_value = ref_abs->ref(); + MS_EXCEPTION_IF_NULL(ref_value); + return std::make_shared(ret, std::make_shared()); + } + + std::string name = refkey->tag(); + const auto &manager = node_conf->node()->func_graph()->manager(); + auto node = FindParameterNodeByString(manager, name); + if (node == nullptr) { + MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph."; + return nullptr; + } + AbstractBasePtr x = ref_abs->ref(); + x = SensitivityTransform(x); + std::shared_ptr key = std::make_shared(node, x); + std::shared_ptr abs_scalar = std::make_shared(key, type); + return std::make_shared(abs_scalar, std::make_shared()); + } +}; + +class GetAttrEvaluator : public TransitionPrimEvaluator { + public: + GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} + ~GetAttrEvaluator() override = default; + MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { + auto ret_abstract = AbstractEval(args_spec_list); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; + return ret_abstract; + } + // Inputs: data, item + if (args_spec_list.size() != 2) { + MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); + } + EvalResultPtr ret = nullptr; + if (bound_node() != nullptr) { + TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); + ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); + TraceManager::EndTrace(); + } else { + ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); + } + // don't lookup from cache, as different out_conf with same node but different context + // may add different entry to anfnode_config_map, like getattr primitive; + (*cache_)[args_spec_list] = ret; + return ret; + } +}; + +class ResolveEvaluator : public TransitionPrimEvaluator { + public: + ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} + ~ResolveEvaluator() override = default; + MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { + // Inputs: namespace, symbol + if (args_spec_list.size() != 2) { + MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); + } + EvalResultPtr ret = nullptr; + if (bound_node() != nullptr) { + TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); + ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); + TraceManager::EndTrace(); + } else { + ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); + } + return ret; + } +}; + +class CreateInstanceEvaluator : public TransitionPrimEvaluator { + public: + CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} + ~CreateInstanceEvaluator() override = default; + MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, + const AnfNodeConfigPtr &out_conf) override { + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; + } + + // get the type parameter + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + TypePtr type = args_spec_list[0]->GetTypeTrack(); + if (type->type_id() != kMetaTypeTypeType) { + MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got " + << type->ToString(); + } + + ValuePtr value_track = args_spec_list[0]->GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + + std::shared_ptr type_obj = dyn_cast(value_track); + if (type_obj == nullptr) { + MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; + } + + if (!type_obj->isa()) { + MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got " + << type_obj->ToString() << "."; + } + + auto class_type = type_obj->obj(); + MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; + + // get the create instance obj's parameters + pybind11::tuple params = GetParameters(args_spec_list); + + // create class instance + auto obj = parse::data_converter::CreatePythonObject(class_type, params); + if (py::isinstance(obj)) { + MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type"; + } + + // process the object + ValuePtr converted_ret = nullptr; + bool converted = parse::ConvertData(obj, &converted_ret, true); + if (!converted) { + MS_LOG(EXCEPTION) << "Convert the python object failed"; + } + MS_EXCEPTION_IF_NULL(converted_ret); + + if (converted_ret->isa()) { + AddToManager(engine, converted_ret->cast()); + } + + AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); + auto infer_result = std::make_shared(ret, nullptr); + (*cache_)[args_spec_list] = infer_result; + return infer_result; + } + + pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { + // Exclude class type by minus 1; + std::size_t params_size = args_spec_list.size() - 1; + auto params = py::tuple(params_size); + if (params_size > 0) { + for (size_t i = 0; i < params_size; i++) { + // Only support the Scalar parameters type. Bypass class type by offset with 1. + auto arg = args_spec_list[i + 1]; + MS_EXCEPTION_IF_NULL(arg); + // Because the Tensor's AbstractTensor can't get value from GetValueTrack. + ValuePtr param_value = arg->BuildValue(); + py::object param = ValuePtrToPyData(param_value); + params[i] = param; + } + } + return params; + } +}; + +class PartialEvaluator : public Evaluator { + public: + PartialEvaluator() : Evaluator("PartialEvaluator") {} + ~PartialEvaluator() override = default; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf = nullptr) override { + if (args_conf_list.size() == 0) { + MS_LOG(EXCEPTION) << "Args size should be greater than 0"; + } + + MS_EXCEPTION_IF_NULL(out_conf); + MS_EXCEPTION_IF_NULL(out_conf->node()); + auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract(); + AbstractBasePtrList args_spec_list{arg0_value}; + // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. + if (arg0_value->isa()) { + auto ret = std::make_shared(arg0_value->GetValueTrack()->cast(), out_conf->node()); + MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() + << " as func is: " << arg0_value->ToString(); + auto eval_result = std::make_shared(ret, std::make_shared()); + (*cache_)[args_spec_list] = eval_result; + return eval_result; + } + auto func = CheckArg("partial", args_spec_list, 0); + // Sometimes, node[0] in out_conf becomes phi0; + if (func->isa()) { + auto prim_func = dyn_cast(func); + if (prim_func->prim()->isa()) { + prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast(prim_func->prim()); + return HandleDoSignature(engine, do_signature_prim->function(), out_conf); + } + } + + (void)std::transform( + args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); }); + AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); + + auto cnode = out_conf->node()->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() != (args_conf_list.size() + 1)) { + MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString() + << ", args_conf_list: " << mindspore::ToString(args_conf_list); + } + + AbstractFuncAtomPtrList partial_funcs_list; + auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) { + auto new_func = std::make_shared(atom_func, args, cnode); + partial_funcs_list.push_back(new_func); + }; + func->Visit(build_partial); + + auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); + auto infer_result = std::make_shared(ret, std::make_shared()); + (*cache_)[args_spec_list] = infer_result; + return infer_result; + } + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } + + EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, + const AnfNodeConfigPtr &out_conf = nullptr) const { + MS_EXCEPTION_IF_NULL(out_conf); + MS_EXCEPTION_IF_NULL(out_conf->node()); + auto cnode = out_conf->node()->cast(); + if (cnode == nullptr) { + MS_LOG(EXCEPTION) << "Cnode is nullptr"; + } + std::vector new_nodes_inputs = cnode->inputs(); + auto new_signature_value = std::make_shared("signature", signature_value); + new_nodes_inputs[1] = NewValueNode(new_signature_value); + FuncGraphPtr func_graph = cnode->func_graph(); + + ScopePtr scope = out_conf->node()->scope(); + ScopeGuard scope_guard(scope); + + CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs); + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context()); + return engine->ForwardConfig(out_conf, fn_conf); + } +}; + +struct PrimitiveImplInferValue { + PrimitiveImpl impl_; // implement function of primitive + bool eval_value_; // whether evaluate value + TypePtr specify_out_type_; // whether specify return type + bool in_white_list_; // true if this Primitive in white list, else false. +}; + +using PrimitiveToImplMap = std::unordered_map; +PrimitiveToImplMap &GetUniformPrimitiveToImplMap() { + static PrimitiveToImplMap uniform_prim_implement_map = { + {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}}, + {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}}, + {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}}, + {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}}, + {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}}, + {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}}, + {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}}, + {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}}, + {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}}, + {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}}, + {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared(), true}}, + {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared(), true}}, + {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared(), true}}, + {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared(), true}}, + {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared(), true}}, + {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared(), true}}, + {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared(), true}}, + {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared(), true}}, + {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared(), true}}, + {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared(), true}}, + }; + return uniform_prim_implement_map; +} + +PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap(); +std::mutex PrimEvaluatorConstructorMutex; + +void InitPrimEvaluatorConstructors() { + PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; + + for (const auto &iter : GetPrimitiveToEvalImplMap()) { + constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_); + } + + for (const auto &iter : GetUniformPrimitiveToImplMap()) { + constructor[iter.first] = + InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_); + } + constructor[prim::kPrimEmbed] = std::make_shared(); + constructor[prim::kPrimRefToEmbed] = std::make_shared(); + constructor[prim::kPrimGetAttr] = std::make_shared(); + constructor[prim::kPrimResolve] = std::make_shared(); + constructor[prim::kPrimCreateInstance] = std::make_shared(); + constructor[prim::kPrimPartial] = std::make_shared(); +} +} // namespace + +void ClearPrimEvaluatorMap() { + PrimEvaluatorConstructors.clear(); + GetPrimitiveToEvalImplMap().clear(); + GetUniformPrimitiveToImplMap().clear(); +} + +bool IsInWhiteList(const PrimitivePtr primitive) { + MS_EXCEPTION_IF_NULL(primitive); + + auto iter = GetPrimitiveToEvalImplMap().find(primitive); + if (iter != GetPrimitiveToEvalImplMap().end()) { + return iter->second.in_white_list_; + } + + auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive); + if (uni_iter != GetUniformPrimitiveToImplMap().end()) { + return uni_iter->second.in_white_list_; + } + + return false; +} + +StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { + MS_EXCEPTION_IF_NULL(primitive); + auto iter = GetPrimitiveToEvalImplMap().find(primitive); + if (iter == GetPrimitiveToEvalImplMap().end()) { + return nullptr; + } + return iter->second.impl_; +} + +PrimEvaluatorMap &GetPrimEvaluatorConstructors() { + PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; + if (!constructor.empty()) { + return constructor; + } + std::lock_guard initLock(PrimEvaluatorConstructorMutex); + if (constructor.empty()) { + InitPrimEvaluatorConstructors(); + } + + return constructor; +} + +namespace { +bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) { + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(model); + auto x_tuple = dyn_cast(x); + auto model_tuple = dyn_cast(model); + + if (x_tuple == nullptr || model_tuple == nullptr) { + return false; + } + + if (model->IsGeneric()) { + return true; + } + + if (x_tuple->size() != model_tuple->size()) { + return false; + } + + for (size_t i = 0; i < x_tuple->size(); i++) { + bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]); + if (!is_subtype) { + return false; + } + } + return true; +} + +bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) { + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(model); + auto x_tensor = dyn_cast(x); + auto model_tensor = dyn_cast(model); + + if (x_tensor == nullptr || model_tensor == nullptr) { + return false; + } + + if (model->IsGeneric()) { + return true; + } + + return IsSubtype(x_tensor->element(), model_tensor->element()); +} + +bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) { + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(model); + auto x_list = dyn_cast(x); + auto model_list = dyn_cast(model); + + if (x_list == nullptr || model_list == nullptr) { + return false; + } + + if (model->IsGeneric()) { + return true; + } + + if (x_list->size() != model_list->size()) { + return false; + } + + bool is_subtype = true; + for (size_t i = 0; i < x_list->size(); i++) { + is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]); + if (!is_subtype) { + return false; + } + } + return is_subtype; +} + +bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) { + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(model); + auto x_class = dyn_cast(x); + auto model_class = dyn_cast(model); + if (x_class == nullptr) { + return false; + } + if (model->IsGeneric()) { + return true; + } + + if (x_class->tag() == model_class->tag()) { + auto m_attributes = model_class->GetAttributes(); + auto x_attributes = x_class->attributes(); + if (m_attributes.size() != x_attributes.size()) { + return false; + } + + for (size_t i = 0; i < m_attributes.size(); i++) { + if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) { + return false; + } + } + return true; + } + + return false; +} + +inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) { + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(model); + if (dyn_cast(x) == nullptr) { + return false; + } + TypePtr x_type = x->GetTypeTrack(); + return IsSubType(x_type, model); +} +} // namespace + +bool IsSubtype(const AbstractBasePtr x, const TypePtr model) { + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(model); + TypeId model_typeid = model->type_id(); + switch (model_typeid) { + case kMetaTypeObject: + return true; + case kObjectTypeTuple: + return IsSubtypeTuple(x, model); + case kObjectTypeTensorType: + return IsSubtypeArray(x, model); + case kObjectTypeList: + return IsSubtypeList(x, model); + case kObjectTypeClass: + return IsSubtypeClass(x, model); + default: + if (IsSubType(model, std::make_shared())) { + return IsSubtypeScalar(x, model); + } + MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << "."; + } +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h new file mode 100644 index 0000000000..692fbe66e8 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h @@ -0,0 +1,366 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_STATIC_ANALYSIS_PRIM_H_ +#define PIPELINE_STATIC_ANALYSIS_PRIM_H_ + +#include +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/evaluator.h" + +namespace mindspore { +namespace abstract { +using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &); +struct StandartPrimitiveImplReg { + StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. + bool in_white_list_; // true if this Primitive in white list, else false. +}; + +using PrimitiveEvalImplMap = + std::unordered_map; + +class StandardPrimEvaluator : public TrivialPrimEvaluator { + public: + StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl) + : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} + ~StandardPrimEvaluator() override = default; + MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; + PrimitivePtr prim() { return prim_; } + + std::string ToString() const override { return identifier_ + prim_->name(); } + + private: + PrimitivePtr prim_; + const StandardPrimitiveEvalImpl eval_impl_; +}; + +using StandardPrimEvaluatorPtr = std::shared_ptr; + +class PythonPrimEvaluator : public TrivialPrimEvaluator { + public: + explicit PythonPrimEvaluator(const PrimitivePyPtr primitive) + : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {} + ~PythonPrimEvaluator() override = default; + MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator); + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; + PrimitivePtr prim() { return dyn_cast(prim_py_); } + + std::string ToString() const override { return identifier_ + prim_py_->name(); } + + private: + PrimitivePyPtr prim_py_; +}; + +class DoSignatureEvaluator : public Evaluator { + public: + explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} + ~DoSignatureEvaluator() override = default; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, + AnfNodeConfigPtr out_config = nullptr) override; + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } + + private: + PrimitivePtr prim_; +}; + +class UnpackGraphEvaluator : public Evaluator { + public: + explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} + ~UnpackGraphEvaluator() override = default; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, + AnfNodeConfigPtr out_config = nullptr) override; + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } + + private: + PrimitivePtr prim_; +}; + +class MixedPrecisionCastEvaluator : public Evaluator { + public: + explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive) + : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {} + ~MixedPrecisionCastEvaluator() override = default; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, + AnfNodeConfigPtr out_config = nullptr) override; + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } + + private: + PrimitivePtr prim_; +}; + +bool IsInWhiteList(PrimitivePtr primitive); +StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); + +using ValuePtrList = std::vector; +using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &); + +class UniformPrimEvaluator : public TrivialPrimEvaluator { + public: + UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type) + : TrivialPrimEvaluator("UniformPrimEvaluator"), + impl_(impl), + eval_value_(eval_value), + func_desc_(func_desc), + nargs_(func_desc_->args().size()), + return_value_type_(func_desc_->retval()), + specify_out_type_(specify_out_type) { + for (size_t i = 0; i < nargs_; ++i) { + TypePtr type = func_desc_->args()[i]; + if (type_map_[type]) { + type_map_[type]->push_back(i); + } else { + type_map_[type] = std::make_shared>(); + type_map_[type]->push_back(i); + } + } + } + ~UniformPrimEvaluator() override = default; + MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); + + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; + ValuePtr RunImpl(const ValuePtrList &args) const; + + // If eval_value_ is False, return broadened arguments. + AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { + if (!eval_value_) { + AbstractBasePtrList broadened_args_spec_list; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); + return broadened_args_spec_list; + } + return args_spec_list; + } + + private: + PrimitiveImpl impl_; + bool eval_value_; + const FunctionPtr func_desc_; + const std::size_t nargs_; + const TypePtr return_value_type_; + const TypePtr specify_out_type_; + std::unordered_map>, TypeHasher, TypeEqual> type_map_; +}; + +PrimEvaluatorMap &GetPrimEvaluatorConstructors(); + +// Check whether type x is a subtype of model. +bool IsSubtype(const AbstractBasePtr x, const TypePtr model); + +void ClearPrimEvaluatorMap(); + +py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base); + +AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +} // namespace abstract +} // namespace mindspore + +#endif // PIPELINE_STATIC_ANALYSIS_PRIM_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc new file mode 100644 index 0000000000..ad39190dc3 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -0,0 +1,728 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/program_specialize.h" + +#include +#include +#include "./common.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/do_signature.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "utils/graph_utils.h" +#include "utils/log_adapter.h" +#include "utils/profile.h" +#include "debug/trace.h" + +namespace mindspore { +namespace abstract { +namespace { +inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) { + if (conf->node()->intermediate_abstract()) { + return conf->node()->intermediate_abstract(); + } + return conf->GetEvaluatedValue()->abstract(); +} + +AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) { + AnfNodePtr value_node = NewValueNode(v); + value_node->set_abstract(abs_base); + MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString(); + return value_node; +} + +bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) { + while (fg != nullptr && fg != parent) { + fg = fg->parent(); + } + return fg == parent; +} +} // namespace + +FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(context); + MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString(); + return SpecializeFuncGraph(fg, context); +} + +FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(context); + auto iter = specializations_.find(context->SpecializeKey()); + if (iter != specializations_.end()) { + return iter->second->specialized_func_graph(); + } + + std::shared_ptr fg_spec = std::make_shared(this, fg, context); + FuncGraphPtr fg2 = fg_spec->specialized_func_graph(); + specializations_[context->SpecializeKey()] = fg_spec; + fg_spec->Run(); + return fg2; +} + +std::shared_ptr ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) { + MS_EXCEPTION_IF_NULL(context); + auto iter = specializations_.find(context->SpecializeKey()); + if (iter != specializations_.end()) { + return iter->second; + } + return nullptr; +} + +std::string GetNextCounter() { + static int g_CloneCounter = 1; + std::string str_count = std::to_string(g_CloneCounter); + g_CloneCounter++; + return str_count; +} + +FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, + const AnalysisContextPtr &context) + : specializer_(s), func_graph_(fg), context_(context) { + parent_ = s->GetFuncGraphSpecializer(context->parent()); + engine_ = s->engine(); + cloner_ = SpecializerClone(fg, std::make_shared(GetNextCounter())); + repl_node_ = cloner_->cloned_node(); + specialized_func_graph_ = cloner_->cloned_func_graph()[fg]; + todo_.push_back(fg->get_return()); + auto ps = fg->parameters(); + (void)todo_.insert(todo_.end(), ps.begin(), ps.end()); +} + +AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + FuncGraphPtr fg = node->func_graph(); + + if (node->isa()) { + return node; + } + std::shared_ptr specializer = shared_from_this(); + while (fg != nullptr && fg != specializer->func_graph_) { + specializer = specializer->parent_; + } + // If had replicated, just return that. + auto iter = specializer->repl_node_->find(node); + if (iter != specializer->repl_node_->end()) { + return iter->second; + } + + auto new_node = specializer->cloner_->CloneDisconnected(node); + if (node->isa()) { + if (!new_node->isa()) { + MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << "."; + } + auto c_node = node->cast(); + MS_EXCEPTION_IF_NULL(c_node); + auto inputs = c_node->inputs(); + std::vector new_inputs; + (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs), + [this](const AnfNodePtr &inp) -> AnfNodePtr { + if (inp->isa()) { + return inp; + } + return ReplicateDisconnectedNode(inp); + }); + auto c_new_node = new_node->cast(); + MS_EXCEPTION_IF_NULL(c_new_node); + c_new_node->set_inputs(new_inputs); + } + + iter = specializer->repl_node_->find(node); + if (iter != specializer->repl_node_->end()) { + if (iter->second == node) { + MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString(); + } + } else { + MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString(); + } + return new_node; +} + +AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + FuncGraphPtr fg = node->func_graph(); + + std::shared_ptr specializer = shared_from_this(); + while (fg != nullptr && fg != specializer->func_graph_) { + specializer = specializer->parent_; + } + + MS_EXCEPTION_IF_NULL(specializer->repl_node_); + auto iter = specializer->repl_node_->find(node); + if (iter != specializer->repl_node_->end()) { + return iter->second; + } + return node; +} + +void FuncGraphSpecializer::Run() { + MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString() + << ", cloned func graph name: " << specialized_func_graph_->ToString() + << ", func graph: " << func_graph_->get_return()->DebugString(); + FirstPass(); + SecondPass(); + MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString() + << ", cloned func graph name: " << specialized_func_graph_->ToString() + << ", new func graph: " << specialized_func_graph_->get_return()->DebugString(); +} + +void FuncGraphSpecializer::FirstPass() { + while (todo_.size()) { + AnfNodePtr node = todo_.back(); + todo_.pop_back(); + if (node->func_graph() == nullptr) { + // do nothing for ValueNode + continue; + } + if (node->func_graph() != func_graph_) { + if (parent_ == nullptr) { + MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + parent_->AddTodoItem(node); + parent_->FirstPass(); + AnfNodePtr new_node = parent_->GetReplicatedNode(node); + if (node->isa()) { + parent_->ProcessCNode(new_node->cast()); + } + continue; + } + if (marked_.count(node) > 0) { + continue; + } + (void)marked_.insert(node); + ProcessNode(node); + } +} + +// Specialize CNode in func graphs +void FuncGraphSpecializer::SecondPass() { + for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) { + if (node->isa()) { + ProcessCNode(node->cast()); + } + } +} + +void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + ScopeGuard scope_guard(node->scope()); + AnfNodeConfigPtr conf = MakeConfig(node); + AnfNodePtr new_node = GetReplicatedNode(node); + MS_EXCEPTION_IF_NULL(new_node); + if (new_node->func_graph() != specialized_func_graph_) { + MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() + << ", new_node: " << new_node->DebugString() + << ", new_node->func_graph(): " << new_node->func_graph()->ToString() + << ", specialized_func_graph_: " << specialized_func_graph_->ToString(); + return; + } + new_node->set_abstract(GetEvaluatedValueWrap(conf)); + if (new_node->isa() && new_node->abstract()->isa()) { + auto partial_abstract = dyn_cast(new_node->abstract()); + if (partial_abstract->node() == node) { + partial_abstract->set_node(new_node); + } + } + + MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); + + if (node->isa()) { + auto attrs = conf->GetEvaluatedValue()->attribute(); + auto c_old = node->cast(); + auto c_new = new_node->cast(); + auto new_inputs = c_new->inputs(); + auto old_inputs = c_old->inputs(); + for (size_t i = 0; i < old_inputs.size(); ++i) { + auto node_input = old_inputs[i]; + AnfNodeConfigPtr iconf = MakeConfig(node_input); + AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); + // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if + // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. + AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); + if (replace_node == nullptr) { + replace_node = BuildReplacedNode(iconf); + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_abstract(ival); + MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString(); + } else { + MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString() + << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString(); + } + if (new_inputs[i] != replace_node) { + new_inputs[i] = replace_node; + MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString(); + } + } + c_new->set_inputs(new_inputs); + } +} + +AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + + auto conf_iter = engine_->anfnode_config_map().find(conf); + AnfNodeConfigPtr new_conf = conf; + while (conf_iter != engine_->anfnode_config_map().end()) { + MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node(" + << new_conf->node()->DebugString() << ")"; + new_conf = conf_iter->second; + MS_EXCEPTION_IF_NULL(new_conf); + MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node(" + << conf->node()->DebugString() << ")"; + (void)ReplicateDisconnectedNode(new_conf->node()); + conf_iter = engine_->anfnode_config_map().find(new_conf); + } + todo_.push_back(new_conf->node()); + auto repl = GetReplicatedNode(new_conf->node()); + if (repl->func_graph()) { + MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString() + << ") to replace origin:" << new_conf->node()->DebugString(); + } else { + MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString() + << ") to replace origin: " << new_conf->node()->DebugString(); + } + return repl; +} + +namespace { +const StringImmPtr kDeadNode = std::make_shared("Dead Node"); +const StringImmPtr kPolyNode = std::make_shared("Poly Node"); + +inline bool CanSpecializeNode(const AnfNodePtr &node) { + if (IsValueNode(node) || IsValueNode(node) || IsValueNode(node)) { + return true; + } + return false; +} +} // namespace + +AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, + const AbstractBasePtrList &argvals) { + MS_EXCEPTION_IF_NULL(abs); + AbstractFunctionPtr real_a = dyn_cast(abs); + MS_EXCEPTION_IF_NULL(real_a); + + AbstractFunctionPtr func = real_a->GetUnique(); + SpecializeStatusCode errcode; + ScopeGuard scope_guard(node->scope()); + AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode); + if (repl == nullptr) { + if (errcode == kSpecializeFindUniqueArgvalDead) { + const auto error_dead_node = std::make_shared(kDeadNode, node); + repl = BuildValueNode(kDeadNode, error_dead_node); + MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString(); + } else if (errcode == kSpecializeFindUniqueArgvalPoly) { + const auto error_poly_node = std::make_shared(kPolyNode, node); + repl = BuildValueNode(kPolyNode, error_poly_node); + MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString(); + } else { + MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString() + << ", abstract: " << abs->ToString(); + } + } + + return repl; +} + +AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs, + const AbstractFunctionPtr &func, + const AbstractBasePtrList &args, + SpecializeStatusCode *errcode) { + MS_EXCEPTION_IF_NULL(abs); + MS_EXCEPTION_IF_NULL(func); + MS_EXCEPTION_IF_NULL(errcode); + *errcode = kSpecializeSuccess; + + auto real_func = dyn_cast(func); + if (real_func != nullptr) { + return BuildValueNode(real_func->prim(), abs); + } + + EvaluatorPtr eval; + eval = engine_->GetEvaluatorFor(func); + MS_EXCEPTION_IF_NULL(eval); + AbstractBasePtrList argvals = eval->NormalizeArgs(args); + + std::pair result; + SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result); + if (status != kSpecializeSuccess) { + *errcode = status; + return nullptr; + } + argvals = result.first; + AbstractBasePtr unique_output = result.second; + + auto prim_func = dyn_cast(func); + if (prim_func != nullptr) { + auto type_func = std::make_shared(prim_func->prim(), argvals, unique_output); + return BuildValueNode(prim_func->prim(), type_func); + } + + if (!eval->isa()) { + MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString(); + } + auto real_eval = dyn_cast(eval); + + if (func->context() == nullptr) { + MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); + } + AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); + MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size() + << ", graph: " << context->func_graph()->get_return()->DebugString(); + if (context->func_graph()->stub()) { + MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString() + << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString() + << ", " << node->ToString(); + return node; + } + FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context); + v->set_flag(kFuncGraphFlagUndetermined, false); + return BuildValueNode(v, abs); +} + +AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) { + auto new_inputs = new_node->inputs(); + AnfNodePtr func = new_inputs[0]; + AbstractBasePtr fnval = new_inputs[0]->abstract(); + + AbstractBasePtrList args; + auto backed_fnval = fnval; + if (fnval->isa()) { + auto partial_closure = dyn_cast(fnval); + backed_fnval = partial_closure->fn(); + args = partial_closure->args(); + } + std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args), + [](const AnfNodePtr &inp) { return inp->abstract(); }); + + ScopeGuard scope_guard(new_node->scope()); + + auto specialized_node = BuildSpecializedNode(func, backed_fnval, args); + auto wrapped_node = specialized_node; + if (fnval->isa()) { + auto partial_closure = dyn_cast(fnval); + AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)), + specialized_node}; + auto anf_node = partial_closure->node(); + if (!anf_node->isa()) { + MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString(); + } + auto cnode = anf_node->cast(); + if (cnode->size() != partial_closure->args().size() + 2) { + MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() + << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); + } + auto attrs = std::make_shared(); + for (size_t i = 0; i < partial_closure->args().size(); i++) { + auto old_node = cnode->input(i + 2); + auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs); + if (possibile_value_node != nullptr) { + partial_node_list.push_back(possibile_value_node); + } else { + if (!(old_node->isa() || old_node->isa())) { + MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString(); + } + partial_node_list.push_back(old_node); + } + } + wrapped_node = new_node->func_graph()->NewCNode(partial_node_list); + wrapped_node->set_abstract(partial_closure); + } + return wrapped_node; +} + +const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) { + auto cache_iter = evalcaches_.find(eval); + if (cache_iter == evalcaches_.end()) { + evalcaches_[eval] = eval->cache(); + return eval->cache(); + } + return cache_iter->second; +} + +std::pair FuncGraphSpecializer::BuildFromBroadedArgsVal( + const EvaluatorPtr &eval) { + MS_EXCEPTION_IF_NULL(eval); + std::unordered_set choices; + EvalResultPtr ret = nullptr; + AbstractBasePtrList broaded_argvals; + for (auto &argvals_map : *evalcaches_[eval]) { + auto argvals = argvals_map.first; + broaded_argvals.clear(); + + (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); + (void)choices.insert(broaded_argvals); + MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals); + } + + if (1 == choices.size()) { + ConfigPtrList args_conf_list; + (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list), + [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared(v); }); + + // if broaden return null + ret = eval->Run(engine_, args_conf_list, nullptr); + EvaluatorCacheMapPtr real = std::make_shared(); + + (*real)[broaded_argvals] = ret; + evalcaches_[eval] = real; + return std::make_pair(broaded_argvals, ret->abstract()); + } else { + MS_LOG(DEBUG) << "Choices.size: " << choices.size(); + return std::make_pair(AbstractBasePtrList(), nullptr); + } +} + +void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { + MS_EXCEPTION_IF_NULL(new_node); + if (specializer_->seen().count(new_node) > 0) { + return; + } + specializer_->AddSeen(new_node); + auto new_inputs = new_node->inputs(); + if (new_inputs.empty()) { + MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; + } + AnfNodePtr func = new_inputs[0]; + MS_EXCEPTION_IF_NULL(func); + + // First element is func so arg start from 1 + std::vector args(new_inputs.begin() + 1, new_inputs.end()); + // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...) + while (IsPrimitiveCNode(func, prim::kPrimPartial)) { + std::vector inputs = func->cast()->inputs(); + // First element is partial, second is func so arg is start from 2 + (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end()); + func = inputs[1]; + } + new_inputs = args; + (void)new_inputs.insert(new_inputs.begin(), func); + + AbstractBasePtrList argvals; + MS_EXCEPTION_IF_NULL(new_inputs[0]); + AbstractBasePtr fnval = new_inputs[0]->abstract(); + MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", " + << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString(); + + // First element is func so function arguments start from 1 + for (size_t i = 1; i < new_inputs.size(); ++i) { + argvals.push_back(new_inputs[i]->abstract()); + MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", " + << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); + } + + if (!func->isa()) { + MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString(); + if (func->abstract()->isa() && !func->abstract()->isa()) { + auto func_abs = func->abstract()->cast(); + EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs); + std::pair result; + AbstractBasePtrList empty_args; + auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result); + MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status; + // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early + if (status == kSpecializeFindUniqueArgvalPoly || + (func->isa() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) || + func->abstract()->isa()))) { + auto wrapped_node = BuildSpecializedParameterNode(new_node); + new_inputs[0] = wrapped_node; + } + } + } + + if (CanSpecializeNode(func)) { + // for primitive node , we build the primitive node with infered attributes in the first pass + // so we do not build replaced node again here in second pass + if (IsValueNode(func)) { + new_inputs[0] = func; + } else { + new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); + } + } + + for (size_t i = 0; i < argvals.size();) { + size_t next = i + 1; + if (CanSpecializeNode(args[i])) { + new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector{}); + } + i = next; + } + new_node->set_inputs(new_inputs); +} + +namespace { +void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) { + MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items."; + int i = 0; + for (const auto &item : evaluator_cache_map) { + MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first; + } +} + +bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) { + if (func->isa() && argvals.empty()) { + MS_LOG(DEBUG) << "High order primitive return POLY."; + return true; + } + if (func->isa() && argvals.empty()) { + auto meta_func_graph_wrapper = dyn_cast(func); + auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph(); + if (meta_func_graph != nullptr && meta_func_graph->isa()) { + auto do_signature = dyn_cast(meta_func_graph); + if (do_signature != nullptr && do_signature->function()->isa()) { + MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY."; + return true; + } + } + } + return false; +} +} // end anonymous namespace + +SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval, + const AbstractBasePtrList &argvals, + std::pair *result) { + MS_EXCEPTION_IF_NULL(func); + MS_EXCEPTION_IF_NULL(eval); + MS_EXCEPTION_IF_NULL(result); + + EvaluatorCacheMap evaluator_cache_map = *eval->cache(); + if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) { + *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract()); + return kSpecializeSuccess; + } + DumpEvaluatorCache(evaluator_cache_map, argvals); + + const EvaluatorCacheMapPtr &choices = GetEvalCache(eval); + MS_EXCEPTION_IF_NULL(choices); + + if (choices->count(argvals)) { + *result = std::make_pair(argvals, (*choices)[argvals]->abstract()); + return kSpecializeSuccess; + } else if (choices->size() == 1) { + MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it."; + *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract()); + return kSpecializeSuccess; + } else if (choices->empty()) { + MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | " + << func->type_name(); + return kSpecializeFindUniqueArgvalDead; + } else { + if (IsPolyFunc(func, argvals)) { + return kSpecializeFindUniqueArgvalPoly; + } + + MS_LOG(DEBUG) << "Try to find generalized argvals."; + *result = BuildFromBroadedArgsVal(eval); + if (!result->first.empty()) { + return kSpecializeSuccess; + } + MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism."; + return kSpecializeFindUniqueArgvalPoly; + } +} +static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) { + auto &prim_attrs = prim->attrs(); + bool is_attr_same = true; + for (auto &item : *attrs) { + auto itr = prim_attrs.find(item.first); + if (itr != prim_attrs.end()) { + if (!(*(itr->second) == *(item.second))) { + is_attr_same = false; + break; + } + } else { + is_attr_same = false; + break; + } + } + if (!is_attr_same) { + if (prim->isa()) { + PrimitivePyPtr prim_py = prim->cast(); + auto clone_fn = prim_py->GetPyObj().attr("_clone"); + py::object new_obj = clone_fn(); + auto cloned_prim = new_obj.cast(); + for (auto &item : *attrs) { + cloned_prim->AddAttr(item.first, item.second); + } + return cloned_prim; + } + auto cloned_prim = std::make_shared(*prim); + for (auto &item : *attrs) { + cloned_prim->AddAttr(item.first, item.second); + } + return cloned_prim; + } + return prim; +} + +AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, + const AttrValueMapPtr &attrs) { + MS_EXCEPTION_IF_NULL(origin_node); + MS_EXCEPTION_IF_NULL(ival); + + AbstractFunctionPtr abs = dyn_cast(ival); + if (abs != nullptr) { + // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction. + if (abs->isa()) { + return nullptr; + } + ValuePtr value = nullptr; + if (abs->isa()) { + auto real_fn = dyn_cast(abs); + // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one + if (attrs != nullptr) { + value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); + } else { + value = real_fn->prim(); + } + } else if (abs->isa()) { + auto real_fn = dyn_cast(abs); + value = real_fn->meta_func_graph(); + } else if (abs->isa()) { + auto real_fn = dyn_cast(abs); + value = real_fn->func_graph(); + } else { + return nullptr; + } + if (!value->isa() || value->cast()->parent() == nullptr || + (IsValueNode(origin_node) && IsVisible(func_graph_, value->cast()->parent()))) { + return BuildValueNode(value, ival); + } else { + return nullptr; + } + } else { + ValuePtr val = ival->BuildValue(); + if (val->isa()) { + return nullptr; + } + // keep primitive 'depend' not to be optimized + if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) { + return nullptr; + } + return BuildValueNode(val, ival); + } +} + +AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) { + return engine_->MakeConfig(node, context_); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h new file mode 100644 index 0000000000..d7f95be4ca --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h @@ -0,0 +1,136 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ +#define PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph_cloner.h" +#include "pipeline/jit/static_analysis/evaluator.h" + +namespace mindspore { +namespace abstract { +enum SpecializeStatusCode { + kSpecializeSuccess = 0, + kSpecializeFindUniqueArgvalDead = 1, // Dead Node + kSpecializeFindUniqueArgvalPoly = 2, // Poly Node + kSpecializeFailure = 0xFF +}; + +class FuncGraphSpecializer; + +// Specialize a func graph using analyzed abstract values. +class ProgramSpecializer { + public: + explicit ProgramSpecializer(const std::shared_ptr &engine) : engine_(engine) { + mng_ = engine_->func_graph_manager(); + } + ~ProgramSpecializer() = default; + // Run the program specializer on the topmost graph in the given context. + FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context); + const std::unordered_set &seen() const { return seen_; } + void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); } + + std::shared_ptr GetFuncGraphSpecializer(const AnalysisContextPtr &context); + // Specialze one FuncGraph in a given context. + FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context); + + std::shared_ptr engine() { return engine_; } + + private: + std::shared_ptr engine_; + std::unordered_set seen_; + FuncGraphManagerPtr mng_; + std::unordered_map, ContextHasher, ContextEqual> + specializations_; +}; + +class FuncGraphSpecializer : public std::enable_shared_from_this { + public: + FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context); + virtual ~FuncGraphSpecializer() { + specializer_ = nullptr; + repl_node_ = nullptr; + } + void Run(); + FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; } + + private: + ProgramSpecializer *specializer_; + FuncGraphPtr func_graph_; + FuncGraphPtr specialized_func_graph_; + AnalysisContextPtr context_; + std::shared_ptr parent_; + std::shared_ptr engine_; + ClonerPtr cloner_; + // ProcessNode-> [cloner_->CloneDisconnected] will clone AnfNode again. + // So, repl_node_ should pointer to GraphCloner->repl_node_ other than a copy of that. + std::unordered_map *repl_node_; + std::vector todo_; + std::unordered_set marked_; + std::unordered_map evalcaches_; + + void FirstPass(); + void SecondPass(); + void ProcessNode(const AnfNodePtr &node); + void ProcessCNode(const CNodePtr &new_node); + + AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node); + inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); } + // Get node replicated by Cloner. + AnfNodePtr GetReplicatedNode(const AnfNodePtr &node); + // Replicated node which is not used directly by a func graph, so it's not searchable from it's return node + // (disconnected). + AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node); + + // Build a value node from parameter if the function graph has special flag to hint it can be done. + AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); + + // Build a value node if ival is constant and not any-value + AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, + const AttrValueMapPtr &attrs); + // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a + // replicated node. + AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); + // Build a specialized node from given argvals; + AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, + const AbstractBasePtrList &argvals); + AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs, + const AbstractFunctionPtr &func, const AbstractBasePtrList &args, + SpecializeStatusCode *errcode); + + // Find the unique argument values which can be used to specialize a primitive or graph function. + SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval, + const AbstractBasePtrList &argvals, + std::pair *result); + // Get cache, it may be eval's cache or cache built from broaded argument values. + const EvaluatorCacheMapPtr &GetEvalCache(const EvaluatorPtr &eval); + // Try to build unique argvals from the broaded arg vals if it is unique. + std::pair BuildFromBroadedArgsVal(const EvaluatorPtr &eval); +}; +} // namespace abstract +} // namespace mindspore +#endif // PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc new file mode 100644 index 0000000000..b9e747a70b --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -0,0 +1,679 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/static_analysis.h" + +#include +#include + +#include "abstract/utils.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" +#include "utils/symbolic.h" +#include "ir/tensor.h" +#include "ir/func_graph_cloner.h" +#include "./common.h" +#include "pipeline/jit/parse/data_converter.h" +#include "debug/draw.h" +#include "pipeline/jit/static_analysis/evaluator.h" +#include "debug/trace.h" + +namespace mindspore { +namespace abstract { +bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) { + if (dyn_cast(arg_spec)) { + auto v = arg_spec->GetValueTrack(); + if (v->isa()) { + return true; + } else { + return false; + } + } else { + return false; + } +} + +AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) { + if (dyn_cast(arg1) && dyn_cast(arg2)) { + return arg1->Join(arg2); + } + return nullptr; +} + +void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { + MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() + << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() + << ", Pointer: " << result->abstract().get(); + cache_[conf] = result; + + // Set intermediate abstract value. + if (IsIntermediateAbstract(result->abstract())) { + if (conf->node()->intermediate_abstract() == nullptr) { + conf->node()->set_intermediate_abstract(result->abstract()); + MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); + } else { + auto old_spec = conf->node()->intermediate_abstract(); + auto joined_spec = IntermediateJoin(result->abstract(), old_spec); + conf->node()->set_intermediate_abstract(joined_spec); + MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" + << result->abstract()->ToString() << "\njoined_spec:\t" + << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr"); + } + } +} + +EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { + auto value = cache_.find(conf); + if (value == cache_.end()) { + return nullptr; + } + return value->second; +} + +std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const { + MS_EXCEPTION_IF_NULL(conf); + MS_EXCEPTION_IF_NULL(conf->node()); + std::size_t hash_value = conf->node()->hash(); + if (!conf->context()->IsDummyContext()) { + hash_value = hash_combine(hash_value, std::hash{}(conf->context().get())); + } + if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) { + MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() + << ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value; + } else { + MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() << " ### , hash value: " << hash_value; + } + return hash_value; +} + +bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const { + if (lhs == nullptr || rhs == nullptr) { + return false; + } + if (lhs == rhs) { + return true; + } + return (*lhs == *rhs); +} + +AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) { + ConfigPtrList args_conf_list; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), + [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); + MS_EXCEPTION_IF_NULL(func_graph_manager_); + func_graph_manager_->AddFuncGraph(func_graph); + + AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); + + // Running the analyzer. + AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); + MS_EXCEPTION_IF_NULL(root_context); + MS_EXCEPTION_IF_NULL(root_context->func_graph()); + AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context); + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(INFO) << func_graph->ToString() << ": Run finished."; + + AnalysisResult result; + MS_EXCEPTION_IF_NULL(output_conf); + result.inferred = output_conf->GetEvaluatedValue(); + result.context = root_context; + return result; +} + +AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, + const ConfigPtrList &args_conf_list) { + std::shared_ptr eval = std::make_shared(func_graph, context); + (void)eval->Run(shared_from_this(), args_conf_list, nullptr); + return eval->graph_context(); +} + +EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + auto value = cache_.GetValue(conf); + if (value != nullptr) { + MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() + << ", " << value->abstract()->ToString(); + return value; + } + + MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString(); + value = Eval(conf); + if (value == nullptr) { + MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr"; + } + cache_.set_value(conf, value); + return value; +} + +EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + AnfNodePtr node = conf->node(); + EvalResultPtr eval_result = nullptr; +#ifdef DEBUG + compute_conf_stack_.push_back(node); + std::ostringstream buffer; + buffer << "Compute Config Begin:"; + for (auto iter : compute_conf_stack_) { + buffer << " -> " << iter->DebugString(); + } + MS_LOG(DEBUG) << buffer.str(); +#endif + MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString(); + MS_EXCEPTION_IF_NULL(node); + if (node->abstract() != nullptr) { + MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); + eval_result = std::make_shared(node->abstract(), std::make_shared()); + } else if (node->isa()) { + auto value_node = node->cast(); + eval_result = std::make_shared(EvalValueNode(value_node, conf), nullptr); + } else if (node->isa()) { + auto cnode = node->cast(); + trace::TraceEvalCNodeEnter(conf); + eval_result = EvalCNode(cnode, conf); + trace::TraceEvalCNodeLeave(); + } else { + MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() + << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + +#ifdef DEBUG + compute_conf_stack_.pop_back(); + if (eval_result == nullptr) { + MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() + << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } +#endif + MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString(); + return eval_result; +} + +AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + MS_EXCEPTION_IF_NULL(value_node); + return ToAbstract(value_node->value(), conf->context(), conf); +} + +EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString(); + } + + AnfNodePtr func_node = inputs[0]; + MS_EXCEPTION_IF_NULL(func_node); + MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString(); + AnalysisContextPtr context = conf->context(); + AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); + MS_EXCEPTION_IF_NULL(func_conf); + // Keep it in a local variable, otherwise smart pointer will free it. + AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract(); + if (maybe_func == nullptr) { + MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() + << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); + } + if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { + MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; + return std::make_shared(maybe_func->Clone(), std::make_shared()); + } + AbstractFunctionPtr func = dyn_cast(maybe_func); + if (func == nullptr) { + MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString() + << ", func_conf: " << func_conf->ToString() + << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); + } + + ConfigPtrList args_conf_list; + // ignore the first node which is function name + for (std::size_t i = 1; i < inputs.size(); i++) { + const AnfNodePtr &node = inputs[i]; + args_conf_list.push_back(MakeConfig(node, context)); + } + std::vector infs; + + auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) { + auto evaluator = this->GetEvaluatorFor(poss); + evaluator->set_bound_node(cnode); + infs.push_back(evaluator); + }; + func->Visit(build_evaluator); + + return ExecuteEvaluators(infs, conf, args_conf_list); +} + +EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { + ConfigPtrList args_conf_list; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), + [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); + std::vector infs; + MS_EXCEPTION_IF_NULL(func); + auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) { + auto evaluator = this->GetEvaluatorFor(poss); + infs.push_back(evaluator); + }; + func->Visit(build_evaluator); + return ExecuteEvaluators(infs, nullptr, args_conf_list); +} + +void AnalysisEngine::ClearEvaluatorCache() { + for (std::pair element : constructors_) { + EvaluatorPtr evaluator = element.second; + MS_EXCEPTION_IF_NULL(evaluator); + MS_EXCEPTION_IF_NULL(evaluator->cache()); + evaluator->cache()->clear(); + } + for (auto &element : prim_constructors_) { + EvaluatorPtr evaluator = element.second; + MS_EXCEPTION_IF_NULL(evaluator); + MS_EXCEPTION_IF_NULL(evaluator->cache()); + evaluator->cache()->clear(); + } + for (auto &element : prim_py_evaluators_) { + EvaluatorPtr evaluator = element.second; + MS_EXCEPTION_IF_NULL(evaluator); + MS_EXCEPTION_IF_NULL(evaluator->cache()); + evaluator->cache()->clear(); + } +} + +void AnalysisEngine::Clear() { + cache_.Clear(); + anfnode_config_map_.clear(); + eval_trace_.clear(); + constructors_.clear(); +} + +namespace { +EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { + // Custom Primitive with python infer_shape, infer_type + EvaluatorPtr evaluator = nullptr; + MS_EXCEPTION_IF_NULL(prim); + if (prim->isa()) { + evaluator = std::make_shared(prim); + return evaluator; + } + if (prim->isa()) { + evaluator = std::make_shared(prim); + return evaluator; + } + if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) { + evaluator = std::make_shared(prim); + return evaluator; + } + if (prim->HasPyEvaluator()) { + auto prim_py = dyn_cast(prim); + if (prim_py != nullptr) { + if (engine == nullptr) { + return std::make_shared(prim_py); + } + + const auto &iter = engine->prim_py_evaluators_.find(prim_py); + if (iter != engine->prim_py_evaluators_.end()) { + return iter->second; + } + evaluator = std::make_shared(prim_py); + engine->prim_py_evaluators_[prim_py] = evaluator; + return evaluator; + } + MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; + } + + if (prim->isa() || prim->HasAttr()) { + if (engine == nullptr) { + (void)GetPrimEvaluatorConstructors(); + } + // If a primitive may have attr, try to create a new evaluator. + StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); + if (eval_impl != nullptr) { + return std::make_shared(prim, eval_impl); + } + } + + if (engine == nullptr) { + // If engine is nullptr, get constructor from default. + const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); + auto iter = prim_evaluator_map.find(prim); + if (iter != prim_evaluator_map.end()) { + evaluator = iter->second; + } + } else { + // If engine is given, get constructor from engine resource. + const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors(); + auto iter = prim_evaluator_map.find(prim); + if (iter != prim_evaluator_map.end()) { + evaluator = iter->second; + } + } + if (evaluator == nullptr) { + MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ")."; + } + return evaluator; +} +} // namespace + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + auto inf_pair = constructors_.find(func); + if (inf_pair != constructors_.end()) { + return inf_pair->second; + } + MS_EXCEPTION_IF_NULL(func); + auto primitive = func->prim(); + auto evaluator = GetPrimEvaluator(primitive, shared_from_this()); + constructors_[func] = evaluator; + return evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + auto inf_pair = constructors_.find(func); + if (inf_pair != constructors_.end()) { + return inf_pair->second; + } + MS_EXCEPTION_IF_NULL(func); + std::shared_ptr func_graph_evaluator = + std::make_shared(func->func_graph(), func->context()); + constructors_[func] = func_graph_evaluator; + return func_graph_evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + auto inf_pair = constructors_.find(func); + if (inf_pair != constructors_.end()) { + return inf_pair->second; + } + MS_EXCEPTION_IF_NULL(func); + std::shared_ptr evaluator = + std::make_shared(func->meta_func_graph(), func->context(), func->GetScope()); + constructors_[func] = evaluator; + return evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + MS_EXCEPTION_IF_NULL(func); + AbstractFunctionPtr func_orig = func->fn(); + EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); + auto jevaluator = std::make_shared(evaluator_orig, func_orig); + return jevaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + MS_EXCEPTION_IF_NULL(func); + std::shared_ptr virtual_evaluator = + std::make_shared(func->args_spec_list(), func->output()); + return virtual_evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + MS_EXCEPTION_IF_NULL(func); + AbstractFunctionPtr func_orig = func->fn(); + EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); + std::shared_ptr partial_evaluator = + std::make_shared(evaluator_orig, func->args()); + return partial_evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &) { + MS_LOG(EXCEPTION) << "Should not be called "; +} + +// Forward to specific subclass of FunctionWrapper. +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) { + MS_EXCEPTION_IF_NULL(func); + EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this()); + return evaluator; +} + +EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { + MS_LOG(DEBUG) << "The func value: " << func->ToString(); + if (func->tracking_id() != nullptr) { + MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); + } + MS_EXCEPTION_IF_NULL(func); + if (func->tracking_id() == nullptr) { + EvaluatorPtr evaluator = _GetEvaluatorFor(func); + return evaluator; + } + auto inf_pair = constructors_.find(func); + if (inf_pair != constructors_.end()) { + return inf_pair->second; + } + + AbstractFunctionPtr func_generic = func->Copy(); + func_generic->set_tracking_id(nullptr); + EvaluatorPtr eval = _GetEvaluatorFor(func_generic); + auto tracked_eval = std::make_shared(eval); + constructors_[func] = tracked_eval; + + return tracked_eval; +} + +EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector &evaluators, + const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { + if (evaluators.size() == 1) { + EvaluatorPtr eval = evaluators[0]; + MS_EXCEPTION_IF_NULL(eval); + return eval->Run(shared_from_this(), args_conf_list, out_conf); + } + return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); +} + +void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { + auto fg_eval = evaluator->cast(); + if (fg_eval == nullptr) { + return; + } + auto fg = fg_eval->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto undetermined_fgs = fg->recursive_graphs(); + if (undetermined_fgs) { + auto fg_parent = fg->parent(); + MS_EXCEPTION_IF_NULL(fg_parent); + fg_parent->set_flag(kFuncGraphFlagUndetermined, true); + MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); + } +} + +EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector &evaluators, + const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, + const EvalTraceRevIter &it, bool *continue_flag) { + *continue_flag = false; + // Find latest entry function to handle nested recursion. + EvaluatorPtr latest_entry = eval; + auto latest_entry_iter = eval_trace_.rbegin(); + for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { + auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); + if (it_temp != evaluators.end()) { + latest_entry = *it_temp; + latest_entry_iter = r_it; + break; + } + latest_entry_iter = ++r_it; + } + if (latest_entry != eval) { + MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); + *continue_flag = true; + return latest_entry; + } + + bool has_undetermined = false; + // Check whether sub loop has untraced undetermined evaluator. + std::set> undetermined_evals; + for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { + undetermined_evals.insert(*r_it); + } + MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); + + for (auto u_eval : undetermined_evals) { + MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; + if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { + MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; + has_undetermined = true; + break; + } + } + if (has_undetermined == false) { + MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; + *continue_flag = true; + return latest_entry; + } + + return latest_entry; +} + +EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) { + if (out_specs.size() == 0) { + MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; + } + + if (out_specs.size() == 1) { + MS_EXCEPTION_IF_NULL(out_specs[0]); + // If only one result derived, then broaden it to avoid wrong constant propagation. + return std::make_shared(out_specs[0]->Broaden(), std::make_shared()); + } + auto joined_spec = AbstractJoin(out_specs); + MS_EXCEPTION_IF_NULL(joined_spec); + MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); + return std::make_shared(joined_spec, std::make_shared()); +} + +EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector &evaluators, + const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list) { + AbstractBasePtrList out_specs; + if (!multi_poss_.count(evaluators[0])) { + multi_poss_[evaluators[0]] = evaluators[1]; + multi_poss_[evaluators[1]] = evaluators[0]; + } + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + for (auto eval : evaluators) { + SetUndeterminedFlag(eval); + + auto current_inf = std::make_pair(eval, args_spec_list); + MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); + + // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. + auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); + if (it == eval_trace_.rend()) { + eval_trace_.push_back(current_inf); + MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); + MS_EXCEPTION_IF_NULL(eval); + auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); + MS_EXCEPTION_IF_NULL(eval_result->abstract()); + MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); + out_specs.push_back(eval_result->abstract()); + eval_trace_.pop_back(); + if (eval_trace_.empty()) { + multi_poss_.clear(); + } + } else if (it != eval_trace_.rbegin()) { + bool continue_flag = false; + auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); + if (continue_flag) { + continue; + } + + // Try to travel the latest undetermined. + if (latest_entry != eval_trace_.rbegin()->first) { + MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); + auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); + MS_EXCEPTION_IF_NULL(eval_result->abstract()); + MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() + << " return out_spec: " << eval_result->abstract()->ToString(); + return eval_result; + } + } + } + + return ProcessEvalResults(out_specs); +} + +EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { + AnfNodeConfigPtr self = shared_from_base(); + return engine_.lock()->GetEvaluatedValue(self); +} + +abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context) { + AnalysisContextPtr temp_context = context; + if (temp_context == nullptr) { + temp_context = abstract::AnalysisContext::DummyContext(); + } + return std::make_shared(func_graph, temp_context); +} + +abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) { + abstract::MetaFuncGraphAbstractClosurePtr meta_func_graph_fn; + if (anf_node == nullptr) { + meta_func_graph_fn = std::make_shared(meta_func_graph); + } else { + meta_func_graph_fn = std::make_shared(meta_func_graph, anf_node->scope()); + } + return meta_func_graph_fn; +} + +abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, const AnfNodePtr &anf_node) { + auto prim_func = std::make_shared(primitive, anf_node); + return prim_func; +} + +AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { + if (value->isa()) { + auto func_graph = value->cast(); + return MakeAbstractClosure(func_graph, context); + } + AnfNodePtr anf_node = nullptr; + if (conf != nullptr) { + anf_node = conf->node(); + } + if (value->isa()) { + auto meta_func_graph = value->cast(); + return MakeAbstractClosure(meta_func_graph, anf_node); + } + if (value->isa()) { + auto prim = value->cast(); + return MakeAbstractClosure(prim, anf_node); + } + return value->ToAbstract(); +} + +AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { + AbstractBasePtr a = ToAbstract(value, nullptr, nullptr); + if (broaden) { + a = a->Broaden(); + } + return a; +} + +EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { + auto evaluator = GetPrimEvaluator(primitive, nullptr); + MS_EXCEPTION_IF_NULL(evaluator); + if (!evaluator->isa()) { + MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but " + << evaluator->ToString(); + } + auto trivial_evaluator = dyn_cast(evaluator); + auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); + return eval_result; +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h new file mode 100644 index 0000000000..181696f756 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -0,0 +1,280 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ +#define PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#ifdef DEBUG +#include +#endif + +#include "utils/log_adapter.h" +#include "ir/anf.h" +#include "ir/primitive_py.h" +#include "abstract/analysis_context.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "pipeline/jit/parse/parse.h" + +namespace mindspore { +namespace abstract { +// define attribute value map +using AttrValueMap = std::unordered_map; +using AttrValueMapPtr = std::shared_ptr; + +// the class to save evaluated result: abstract value and modified attribute +class EvalResult : public Base { + public: + EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} + ~EvalResult() override = default; + MS_DECLARE_PARENT(EvalResult, Base); + AbstractBasePtr abstract() { return abstract_; } + AttrValueMapPtr attribute() { return attribute_; } + + private: + AbstractBasePtr abstract_; + AttrValueMapPtr attribute_; +}; + +using EvalResultPtr = std::shared_ptr; +// Superclass for AnfNodeConfig and VirtualConfig. +class Config : public Base { + public: + Config() = default; + ~Config() override = default; + MS_DECLARE_PARENT(Config, Base); + virtual EvalResultPtr GetEvaluatedValue() = 0; +}; + +// Config will be stored in AnalysisCache +using ConfigPtr = std::shared_ptr; +using ConfigPtrList = std::vector; + +// Config to a certain node in a certain context. +class AnfNodeConfig : public Config { + public: + AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context) + : Config(), engine_(std::weak_ptr(engine)), node_(node) { + FuncGraphPtr fg; + if (IsValueNode(node)) { + auto v = node->cast(); + fg = v->value()->cast(); + } else { + fg = node->func_graph(); + } + context_ = nullptr; + if (context != nullptr) { + context_ = context->Filter(fg); + } + } + + ~AnfNodeConfig() override = default; + MS_DECLARE_PARENT(AnfNodeConfig, Config); + + EvalResultPtr GetEvaluatedValue() override; + + AnalysisContextPtr context() const { return context_; } + + AnfNodePtr node() const { return node_; } + + AnalysisEnginePtr engine() const { return engine_.lock(); } + + // used by unordered_map; + bool operator==(const AnfNodeConfig &other) const { + // compare node with pointer, context with pointer except DummyContext as it's created by make_shared; + // context should not be nullptr; + if (context_->IsDummyContext() && other.context_->IsDummyContext()) { + return true; + } + return (node_ == other.node_) && (context_ == other.context_); + } + + std::string ToString() const override { + std::ostringstream buffer; + buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString(); + return buffer.str(); + } + + private: + // AnalysisEngine is global. + // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use + // weak_ptr to break Config cycle. + std::weak_ptr engine_; + AnfNodePtr node_; + AnalysisContextPtr context_; +}; + +using AnfNodeConfigPtr = std::shared_ptr; + +struct AnfNodeConfigHasher { + std::size_t operator()(const AnfNodeConfigPtr conf) const; +}; + +struct AnfNodeConfigEqual { + bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const; +}; + +class VirtualConfig : public Config { + public: + explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {} + + ~VirtualConfig() override = default; + MS_DECLARE_PARENT(VirtualConfig, Config); + EvalResultPtr GetEvaluatedValue() override { + return std::make_shared(abstract_, std::make_shared()); + } + + private: + AbstractBasePtr abstract_; +}; + +// AnalysisCache +class AnalysisCache { + public: + AnalysisCache() = default; + ~AnalysisCache() = default; + void Clear() { cache_.clear(); } + void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); + EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); + + private: + std::unordered_map cache_; +}; + +using PrimEvaluatorMap = std::unordered_map; +using AnfNodeConfigMap = + std::unordered_map; + +struct AnalysisResult { + EvalResultPtr inferred; + AnalysisContextPtr context; +}; + +using EvalTraceRevIter = std::list>::reverse_iterator; + +class AnalysisEngine : public std::enable_shared_from_this { + public: + AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) + : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {} + ~AnalysisEngine() = default; + + // func_graph: The func_graph to analyze. + // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. + AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); + EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); + // Return the Evaluator for the given function. + EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); + + AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); + EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); + // Infer the result of fn(args). + EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); + void Clear(); + void ClearEvaluatorCache(); + AnalysisCache &cache() { return cache_; } + AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) { + return std::make_shared(shared_from_this(), node, context); + } + // Overloaded function. + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + + FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; } + const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; } + + // Set the analysis result for orig to the result for new. + // This sets an entry in anfnode_config_map from orig to new. + EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { + // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. + (void)anfnode_config_map_.emplace(orig_conf, new_conf); + MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() + << ", to new_conf: " << new_conf->node()->DebugString(); + return GetEvaluatedValue(new_conf); + } + const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } + + AnalysisCache cache_; + std::unordered_map prim_py_evaluators_; + + private: + void SetUndeterminedFlag(const EvaluatorPtr &evaluator); + EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, + const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, + bool *continue_flag); + EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs); + + const PrimEvaluatorMap &prim_constructors_; + FuncGraphManagerPtr func_graph_manager_; + std::unordered_map constructors_; + AnfNodeConfigMap anfnode_config_map_; + // Use a list to trace multiple evaluators. + std::list> eval_trace_; + std::map multi_poss_; + + AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, + const ConfigPtrList &args_conf_list); + EvalResultPtr Eval(const AnfNodeConfigPtr &conf); + EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); + EvalResultPtr ExecuteEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list); + EvalResultPtr ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list); + +#ifdef DEBUG + std::vector compute_conf_stack_; +#endif +}; + +// Translate the value to an abstract value. +// Arguments: +// value: The value to convert. +// context: The context in which the value was found, used if the value is a Graph. +// conf: The Config to the valuenode we are converting, if there is one, +// so that we can generate a tracking_id. +AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr, + const AnfNodeConfigPtr &conf = nullptr); + +// Convert a value to an abstract value. +// Arguments: +// v: The value to convert. +// broaden: If True, concrete values will be made more abstract, so e.g. +// the value 1234 would become ANYTHING. +AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false); + +template +AbstractBasePtr FromValue(const T &value, bool broaden = false) { + return FromValueInside(MakeValue(value), broaden); +} + +EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); +} // namespace abstract +} // namespace mindspore + +#endif // PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc new file mode 100644 index 0000000000..04aa6efd05 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -0,0 +1,120 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/validator.h" + +#include +#include + +#include "ir/manager.h" +#include "ir/dtype.h" +#include "./common.h" +#include "pipeline/jit/static_analysis/prim.h" + +namespace mindspore { +namespace validator { +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractClass; +using mindspore::abstract::AbstractError; +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractIndexedSlices; +using mindspore::abstract::AbstractJTagged; +using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractTensor; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractType; + +void ValidateOperation(const AnfNodePtr &node) { + if (!IsValueNode(node)) { + return; + } + + // Primitive must in whitelist + PrimitivePtr prim = GetValueNode(node); + if (abstract::IsInWhiteList(prim)) { + return; + } + if (prim->HasPyEvaluator()) { + MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; + return; + } + if (prim->name() == "fake_bprop") { + MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue(prim->GetAttr("info")); + } + + MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name(); +} + +void ValidateAbstract(const AnfNodePtr &node) { + if (node == nullptr) { + MS_LOG(DEBUG) << "Node to validate is invalid"; + return; + } + AbstractBasePtr ptrBase = node->abstract(); + if (ptrBase == nullptr) { + MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString(); + return; + } + if (ptrBase->isa() || ptrBase->isa()) { + // Validate a type. + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); + } + if (ptrBase->isa()) { + TypePtr ptrType = ptrBase->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(ptrType); + if (ptrType->isa() || ptrType->isa()) { + // only send string in external + if (!IsValueNode(node)) { + // Validate a type. + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); + } + } + return; + } + if (ptrBase->isa()) { + // NOTICE: validate dead code? + MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString(); + return; + } + + if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa()) { + return; + } + + if (ptrBase->isa()) { + return; + } + + // Other types show exception + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); +} + +void Validate(const FuncGraphPtr &fg) { + FuncGraphManagerPtr mgr = Manage(fg, false); + MS_EXCEPTION_IF_NULL(mgr); + AnfNodeSet &all_nodes = mgr->all_nodes(); + for (const auto &anf_node : all_nodes) { + ValidateOperation(anf_node); + ValidateAbstract(anf_node); + } +} +} // namespace validator +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/validator.h b/mindspore/ccsrc/pipeline/jit/validator.h new file mode 100644 index 0000000000..041448aed9 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/validator.h @@ -0,0 +1,38 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_ +#define MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_ + +#include +#include +#include +#include +#include "frontend/operator/ops.h" +#include "ir/anf.h" +#include "utils/misc.h" + +namespace mindspore { +namespace validator { +void Validate(const FuncGraphPtr &func_graph); +void ValidateAbstract(const AnfNodePtr &node); +void ValidateOperation(const AnfNodePtr &node); +} // namespace validator +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H__ diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc deleted file mode 100644 index 330d03d11c..0000000000 --- a/mindspore/ccsrc/pipeline/parse/data_converter.cc +++ /dev/null @@ -1,559 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/parse/data_converter.h" -#include -#include -#include -#include -#include -#include -#include -#include "pipeline/parse/resolve.h" -#include "pipeline/parse/python_adapter.h" -#include "operator/ops.h" -#include "operator/composite/composite.h" -#include "ir/func_graph_cloner.h" -#include "utils/symbolic.h" -#include "utils/context/ms_context.h" -#include "debug/trace.h" -#include "optimizer/ad/grad.h" - -namespace mindspore { -namespace parse { -using Tensor = mindspore::tensor::Tensor; -using TensorPtr = mindspore::tensor::TensorPtr; -using MetaTensor = mindspore::tensor::MetaTensor; -using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; - -FuncGraphPtr ConvertToBpropCut(const py::object &obj) { - std::vector results = data_converter::GetObjKey(obj); - std::string obj_key = results[0]; - py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME); - - auto bprop_graph = std::make_shared(); - std::vector outputs; - - auto fake_bprop = std::make_shared("bprop_cut", py::object()); - fake_bprop->set_hook(bprop_func); - (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true)); - outputs.push_back(NewValueNode(fake_bprop)); - - py::object code_obj = py::getattr(bprop_func, "__code__"); - size_t inputs_num = py::cast(py::getattr(code_obj, "co_argcount")) - 3; - for (size_t i = 0; i < inputs_num; ++i) { - auto param = bprop_graph->add_parameter(); - outputs.push_back(param); - } - auto p1 = bprop_graph->add_parameter(); - auto p2 = bprop_graph->add_parameter(); - outputs.push_back(p1); - outputs.push_back(p2); - - bprop_graph->set_output(bprop_graph->NewCNode(outputs)); - data_converter::SetObjGraphValue(obj_key, bprop_graph); - return bprop_graph; -} - -namespace { -bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { - MS_LOG(DEBUG) << "Converting python tuple"; - py::tuple tuple = obj.cast(); - std::vector value_list; - for (size_t it = 0; it < tuple.size(); ++it) { - ValuePtr out = nullptr; - bool success = ConvertData(tuple[it], &out, use_signature); - if (!success) { - return false; - } - value_list.push_back(out); - } - *data = std::make_shared(value_list); - - return true; -} - -bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { - MS_LOG(DEBUG) << "Converting python list"; - - py::list list = obj.cast(); - std::vector value_list; - for (size_t it = 0; it < list.size(); ++it) { - ValuePtr out = nullptr; - bool success = ConvertData(list[it], &out, use_signature); - if (!success) { - return false; - } - value_list.push_back(out); - } - *data = std::make_shared(value_list); - return true; -} - -bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) { - MS_LOG(DEBUG) << "Converting cell list"; - py::sequence list = obj; - std::vector value_list; - for (size_t it = 0; it < list.size(); ++it) { - ValuePtr out = nullptr; - bool success = ConvertData(list[it], &out, use_signature); - if (!success) { - return false; - } - value_list.push_back(out); - } - *data = std::make_shared(value_list); - return true; -} - -bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { - MS_LOG(DEBUG) << "Converting python dict"; - - py::dict dict_values = obj.cast(); - std::vector> key_values; - for (auto item : dict_values) { - if (!py::isinstance(item.first)) { - MS_LOG(EXCEPTION) << "The key of dict is only support str."; - } - std::string key = py::str(item.first); - ValuePtr out = nullptr; - bool success = ConvertData(dict_values[item.first], &out, use_signature); - if (!success) { - return false; - } - key_values.emplace_back(key, out); - } - *data = std::make_shared(key_values); - return true; -} - -void ConvertNameSpace(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting python module"; - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj); - *data = std::make_shared(RESOLVE_NAMESPACE_NAME_MODULE, py::cast(module_namespace)); -} - -void ConvertDataClass(py::object obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting dataclass"; - // Maybe the obj is dataclass define - auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); - // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); -} - -bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { - MS_LOG(DEBUG) << "Converting primitive object"; - - // need check the primitive is class type or instance - auto obj_type = data_converter::GetObjType(obj); - if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { - auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); - // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); - } else { - auto primitive = obj.cast(); - if (primitive == nullptr) { - MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null"; - return false; - } - if (py::hasattr(obj, "__setattr_flag__")) { - if (py::hasattr(obj, "_clone")) { - auto clone_fn = obj.attr("_clone"); - py::object new_obj = clone_fn(); - primitive = new_obj.cast(); - } - } - if (use_signature) { - *data = std::make_shared(primitive->name(), primitive); - } else { - *data = primitive; - } - } - return true; -} - -bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) { - MS_LOG(DEBUG) << "Converting MetaFuncGraph object"; - auto meta = obj.cast(); - if (meta == nullptr) { - MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null"; - return false; - } - if (use_signature) { - *data = std::make_shared(meta->name(), meta); - } else { - *data = meta; - } - return true; -} - -bool ConvertDataType(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting type object"; - auto typeptr = obj.cast(); - if (typeptr == nullptr) { - MS_LOG(ERROR) << "Resolve TypePtr error, get ptr is null"; - return false; - } - *data = typeptr; - return true; -} - -bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting MetaTensor object."; - - auto m_tensor = obj.cast(); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Resolve MetaTensor error, get ptr is null."; - return false; - } - *data = m_tensor; - return true; -} - -bool ConvertTensor(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting tensor object"; - - auto m_tensor = obj.cast(); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Resolve Tensor error, get ptr is null"; - return false; - } - *data = m_tensor; - return true; -} - -bool ConvertSlice(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting slice object"; - - py::slice slice_obj = obj.cast(); - auto convert_func = [obj](std::string attr) -> ValuePtr { - auto py_attr = py::getattr(obj, attr.c_str()); - if (py::isinstance(py_attr)) { - return kNone; - } else if (py::isinstance(py_attr)) { - int value = py::cast(py_attr); - return MakeValue(value); - } else { - MS_LOG(EXCEPTION) << "Slice should contain only int or none"; - } - }; - ValuePtr start = convert_func("start"); - ValuePtr stop = convert_func("stop"); - ValuePtr step = convert_func("step"); - *data = std::make_shared(start, stop, step); - return true; -} - -bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { - FuncGraphPtr func_graph = ConvertToFuncGraph(obj); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse resolve function error."; - return false; - } - // if the cell object has specified bprop, it has user-defined bprop function parse and record it - if (py::hasattr(obj, CUSTOM_BPROP_NAME)) { - FuncGraphPtr bprop_graph = nullptr; - bool enable_bprop_debug = py::cast(py::getattr(obj, "bprop_debug")); - if (enable_bprop_debug) { - bprop_graph = ConvertToBpropCut(obj); - } else { - bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); - } - if (bprop_graph != nullptr) { - (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); - (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); - func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true); - } - } - *data = func_graph; - return true; -} - -bool ConvertOtherObj(py::object obj, ValuePtr *const data) { - auto obj_type = data_converter::GetObjType(obj); - MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; - if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { - MS_LOG(DEBUG) << "Resolve the class type, need create class instance."; - std::string desc = py::str(obj); - // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); - return true; - } - if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) { - MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type; - FuncGraphPtr func_graph = ConvertToFuncGraph(obj); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse resolve function error."; - return false; - } - *data = func_graph; - return true; - } - if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { - // Create the namespace for common class instance - // When the obj is Cell, default parse the 'construct' - if (data_converter::IsCellInstance(obj)) { - return ConvertCellObjToFuncGraph(obj, data); - } - - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); - *data = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); - return true; - } - MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); - return false; -} -} // namespace - -bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) { - // check parameter valid - if (data == nullptr) { - MS_LOG(ERROR) << "Data is null pointer"; - return false; - } - - bool ret = true; - ValuePtr converted = nullptr; - if (py::isinstance(obj)) { - converted = kNone; - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - ret = ConvertDict(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ret = ConvertSlice(obj, &converted); - } else if (py::isinstance(obj)) { - converted = kEllipsis; - } else if (py::isinstance(obj)) { - ret = ConvertTuple(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { - ret = ConvertCellList(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ret = ConvertList(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ConvertNameSpace(obj, &converted); - } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) { - ConvertDataClass(obj, &converted); - } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) { - ret = ConvertPrimitive(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_METAFUNCGRAPH_FLAG)) { - ret = ConvertMetaFuncGraph(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_DTYPE_FLAG)) { - ret = ConvertDataType(obj, &converted); - } else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) { - ret = ConvertTensor(obj, &converted); - } else if (py::hasattr(obj, PYTHON_META_TENSOR_FLAG)) { - ret = ConvertMetaTensor(obj, &converted); - } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) { - std::shared_ptr env = obj.cast>(); - converted = env; - } else if (py::hasattr(obj, "__parameter__")) { - auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); - ret = ConvertData(to_convert, &converted); - } else { - ret = ConvertOtherObj(obj, &converted); - } - - *data = converted; - return ret; -} - -// convert data to graph -FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) { - std::vector results = data_converter::GetObjKey(obj); - std::string obj_id = results[0] + python_mod_get_parse_method; - std::string obj_key = results[1]; - FuncGraphPtr func_graph = nullptr; - Any value = Any(); - bool is_cache = data_converter::GetObjectValue(obj_id, &value); - if (is_cache) { - if (value.is()) { - MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; - func_graph = value.cast(); - return func_graph; - } - } - - func_graph = ParsePythonCode(obj, python_mod_get_parse_method); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse resolve function error."; - return nullptr; - } - - data_converter::MakeProperNameToFuncGraph(func_graph, obj_id); - data_converter::CacheObjectValue(obj_id, func_graph); - if (obj_key != "") { - MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString(); - data_converter::SetObjGraphValue(obj_key, func_graph); - } - - return func_graph; -} -namespace data_converter { -static std::unordered_map object_map_ = std::unordered_map(); - -static std::unordered_map> object_graphs_map_ = - std::unordered_map>(); - -void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { - object_graphs_map_[obj_key].push_back(data); - MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size(); -} - -const std::unordered_map> &GetObjGraphs() { - MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size(); - return object_graphs_map_; -} - -void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; } -bool GetObjectValue(const std::string &obj_key, Any *const data) { - if (object_map_.count(obj_key)) { - *data = object_map_[obj_key]; - return true; - } - return false; -} -std::vector GetObjKey(const py::object &obj) { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj); - if (obj_tuple.size() != 2) { - MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements"; - } - return {py::cast(obj_tuple[0]), py::cast(obj_tuple[1])}; -} - -// get obj detail type -ResolveTypeDef GetObjType(const py::object &obj) { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - auto obj_type = - ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast()); - return obj_type; -} - -// get class instance detail type -ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - auto class_type = - ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast()); - return class_type; -} - -// check the object is Cell Instance -bool IsCellInstance(const py::object &obj) { - auto class_type = GetClassInstanceType(obj); - bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); - return isCell; -} - -// create the python class instance -py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object obj; - if (params.size() == 0) { - obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type); - } else { - obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); - } - return obj; -} - -// Generate an appropriate name and set to graph debuginfo -// character <> can not used in the dot file, so change to another symbol -void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(func_graph->debug_info()); - // set detail name info of function - std::ostringstream oss; - for (size_t i = 0; i < name.size(); i++) { - if (name[i] == '<') { - oss << "「"; - } else if (name[i] == '>') { - oss << "」"; - } else { - oss << name[i]; - } - } - func_graph->debug_info()->set_full_name(oss.str()); -} - -ValuePtr PyDataToValue(const py::object &obj) { - py::object to_convert = obj; - if (py::hasattr(obj, "__parameter__")) { - to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); - } - ValuePtr value = nullptr; - (void)ConvertData(to_convert, &value); - return value; -} - -void ClearObjectCache() { - object_map_.clear(); - object_graphs_map_.clear(); -} -} // namespace data_converter - -static std::unordered_map g_dataClassToClass = {}; - -// parse dataclass to mindspore Class type -ClassPtr ParseDataClass(const py::object &cls_obj) { - std::string cls_name = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__name__")); - std::string cls_module = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__module__")); - std::string cls = cls_module + "." + cls_name; - auto iterator = g_dataClassToClass.find(cls); - if (iterator != g_dataClassToClass.end()) { - return iterator->second; - } - - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - ClassAttrVector attributes; - py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); - for (auto &item : names) { - TypePtr type_value = item.second.cast(); - MS_EXCEPTION_IF_NULL(type_value); - MS_LOG(DEBUG) << "(Name: " << py::cast(item.first) << ", type: " << type_value->ToString() << ")"; - attributes.push_back(std::make_pair(py::cast(item.first), type_value)); - } - - std::unordered_map methods_map; - py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); - for (auto &item : methods) { - std::string fun_name = item.first.cast(); - py::object obj = py::cast(item.second); - std::shared_ptr method_obj = std::make_shared(obj, fun_name); - methods_map[fun_name] = method_obj; - } - - std::shared_ptr me_class = std::make_shared(Named(cls_name), attributes, methods_map); - // static Variable for cache - // cppcheck-suppress unreadVariable - g_dataClassToClass[cls] = me_class; - - return me_class; -} - -void CleanDataClassToClassMap() { g_dataClassToClass.clear(); } -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.h b/mindspore/ccsrc/pipeline/parse/data_converter.h deleted file mode 100644 index 0165b55363..0000000000 --- a/mindspore/ccsrc/pipeline/parse/data_converter.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_DATA_CONVERTER_H_ -#define PIPELINE_PARSE_DATA_CONVERTER_H_ - -#include -#include -#include -#include -#include -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/python_adapter.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parse { -// data convert for parse -namespace data_converter { -void CacheObjectValue(const std::string &obj_key, const Any &data); -bool GetObjectValue(const std::string &obj_key, Any *const data); - -void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data); - -const std::unordered_map> &GetObjGraphs(); - -std::vector GetObjKey(const py::object &obj); -ResolveTypeDef GetObjType(const py::object &obj); -ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); - -bool IsCellInstance(const py::object &obj); -py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms); -void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); -ValuePtr PyDataToValue(const py::object &obj); -void ClearObjectCache(); -} // namespace data_converter - -ClassPtr ParseDataClass(const py::object &cls_obj); -FuncGraphPtr ConvertToBpropCut(const py::object &obj); - -void CleanDataClassToClassMap(); - -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_DATA_CONVERTER_H_ diff --git a/mindspore/ccsrc/pipeline/parse/function_block.cc b/mindspore/ccsrc/pipeline/parse/function_block.cc deleted file mode 100644 index fbeeba94a1..0000000000 --- a/mindspore/ccsrc/pipeline/parse/function_block.cc +++ /dev/null @@ -1,369 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/parse/function_block.h" -#include -#include -#include -#include "pipeline/parse/resolve.h" -#include "pipeline/parse/parse.h" -#include "operator/ops.h" -#include "debug/info.h" -#include "debug/trace.h" -#include "pybind11/pybind11.h" - -namespace mindspore { -namespace py = pybind11; - -namespace parse { -FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) { - func_graph_ = std::make_shared(); - matured_ = false; -} - -void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); } - -// write variable records the variable name to corresponding node -void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) { - MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString(); - vars_[var_name] = node; -} - -// read variable from predecessors -AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { - // get var node if it is found - if (vars_.count(var)) { - AnfNodePtr node = vars_[var]; - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - return NewValueNode(GetValueNode(node)); - } else { - return node; - } - } - // get var from predecessor block ,if can't get the make a resolve node to it - if (matured_) { - // If only one predecessor block, read the definition of var from it. - if (prev_blocks_.size() == 1) { - auto block = prev_blocks_[0]; - MS_EXCEPTION_IF_NULL(block); - return block->ReadVariable(var); - } else if (prev_blocks_.empty()) { - // get namespace and make Reslove - return MakeResolveSymbol(var); - } - } - // If have more than one predecessor blocks then build a phi node. - auto debug_info = std::make_shared(); - debug_info->set_name(var); - TraceManager::DebugTrace(std::make_shared(debug_info)); - ParameterPtr phi_param = std::make_shared(func_graph()); - TraceManager::EndTrace(); - MS_LOG(DEBUG) << func_graph_->ToString() << " generate phi node " << phi_param->ToString() << " for " << var; - func_graph()->add_parameter(phi_param); - phi_nodes_[phi_param] = var; - WriteVariable(var, phi_param); - if (matured_) { - SetPhiArgument(phi_param); - } - return phi_param; -} - -// Resolve Ast operator node -AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) { - auto ast = parser_.ast(); - MS_EXCEPTION_IF_NULL(ast); - TraceGuard trace_guard(parser_.GetLocation(op)); - py::tuple namespace_var = ast->CallParserObjMethod(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op); - if (namespace_var.size() != 2) { - MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size(); - } - NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]); - SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); - return MakeResolve(name_space, symbol); -} - -// Resolve class member, two possible: method, member variable -AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { - py::object namespace_var = - parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj()); - NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); - SymbolPtr symbol = std::make_shared(attr); - return MakeResolve(name_space, symbol); -} - -// Make a resolve node for symbol string -AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { - if (value.compare(0, strlen("self."), "self.") == 0) { - auto start = value.find_first_of('.') + 1; - if (start >= value.size()) { - MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value; - return nullptr; - } - auto bits_str = value.substr(start); - return MakeResolveClassMember(bits_str); - } - py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value); - - NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_var[0]); - SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); - return MakeResolve(name_space, symbol); -} - -AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) { - py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value); - NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]); - SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); - return MakeResolve(name_space, symbol); -} - -AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) { - MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , " - << ((std::string)resolve_symbol->symbol()); - ValueNodePtr module_node = NewValueNode(name_space); - ValueNodePtr symbol_node = NewValueNode(resolve_symbol); - auto node = func_graph()->NewCNode({NewValueNode(prim::kPrimResolve), module_node, symbol_node}); - return node; -} - -// add input for the block's phi parameter -void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { - std::string var = phi_nodes_[phi]; - MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; - for (auto &pred : prev_blocks_) { - MS_EXCEPTION_IF_NULL(pred); - MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); - AnfNodePtr arg_node = pred->ReadVariable(var); - CNodePtr jump = pred->jumps_[this]; - jump->add_input(arg_node); - } - // If the phi node in the body part of a for/while loop is being removed, - // then the closure convert phase will generate a cycle in graph if the - // loop is kept after specialization. This should be investigate further. - // Just now user has to set a flag on a function to indicate the for loop - // will definitely can be unroll as the sequence in for statement is fixed - // size in compile time. - if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) || - parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) { - CollectRemovablePhi(phi); - } -} - -AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { - AnfNodePtr arg_node = nullptr; - for (auto &prev : prev_blocks_) { - MS_EXCEPTION_IF_NULL(prev); - AnfNodePtr temp_node = prev->ReadVariable(var); - MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var - << " is " << temp_node->DebugString(); - if (temp_node != phi) { - if (arg_node == nullptr) { - arg_node = temp_node; - MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() - << " may be replaced by node " << arg_node->DebugString(); - } else if (temp_node == arg_node) { - MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " is same as node " - << arg_node->DebugString(); - } else { - MS_LOG(DEBUG) << "phi " << phi->ToString() - << " cannot be removed as it assigns to different node. node1: " << arg_node->DebugString() - << ", node2: " << temp_node->DebugString(); - return nullptr; - } - } - } - return arg_node; -} - -// Check if there is removable unnecessary phi node in this graph. -// as per the FIRM TR 3.2, a phi node can be remove if: -// -// If all arguments of a φ-function are the same value s or the φfunction itself, -// then we remove the φ-function and let all users directly uses. We call such a -// φ-function obviously unnecessary. -// When we removed a φ-function p, then we recursively try to apply this simplification -// rule with all (former) users of p, because they may have become obviously unnecessary -// due to the removal of p -// -// phi node in graph will be removed after the whole function is parsed in a DFS visit -// of that graph.The reason is : -// 1. when this function is called, not all usage of this phi node had bound to the -// graph of this function block, some may stay in vars_ in other blocks. -// 2. it's costly to iterate the graph to replace the phi for each phi. -// Args : -// phi : This parameter node is functioning as a phi node. -void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { - MS_EXCEPTION_IF_NULL(phi); - std::string var = phi_nodes_[phi]; - MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); - if (prev_blocks_.size() == 0) { - MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); - return; - } - AnfNodePtr arg_node = SearchReplaceNode(var, phi); - if (arg_node != nullptr) { - MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with " - << arg_node->DebugString(); - // replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1." - WriteVariable(var, arg_node); - removable_phis_[phi] = arg_node; - // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized - // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. - for (auto &prev : prev_blocks_) { - MS_EXCEPTION_IF_NULL(prev); - if (!prev->matured_) { - continue; - } - for (auto &phi_iter : prev->removable_phis_) { - MS_EXCEPTION_IF_NULL(phi_iter.second); - if (phi_iter.second->isa()) { - const auto ¶m = phi_iter.second->cast(); - if (param == phi) { - MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() - << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); - prev->removable_phis_[phi_iter.first] = arg_node; - } - } - } - } - } -} - -// A block should be marked matured if its predecessor blocks have been processed -void FunctionBlock::Mature() { - const auto &graphParamVec = func_graph_->parameters(); - for (auto ¶mItr : graphParamVec) { - MS_EXCEPTION_IF_NULL(paramItr); - ParameterPtr param = paramItr->cast(); - if (phi_nodes_.find(param) != phi_nodes_.cend()) { - SetPhiArgument(param); - } - } - matured_ = true; -} - -// Force the conditIon node to bool using bool operation -CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) { - TraceManager::DebugTrace(std::make_shared(cond->debug_info())); - CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond}); - TraceManager::EndTrace(); - return op_apply_node; -} - -CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) { - TraceManager::DebugTrace(std::make_shared(cond->debug_info())); - CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond}); - TraceManager::EndTrace(); - return op_apply_node; -} - -// Perform a jump from this block to target block -void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) { - if (func_graph()->get_return() != nullptr) { - MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " - << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); - } - std::vector input_nodes; - input_nodes.emplace_back(NewValueNode(target_block->func_graph())); - if (node != nullptr) { - input_nodes.emplace_back(node); - } - - CNodePtr jump = func_graph()->NewCNode(input_nodes); - jumps_[target_block.get()] = jump; - target_block->AddPrevBlock(shared_from_this()); - func_graph()->set_output(jump); - InsertDependItemsBeforeReturn(); -} - -// Perform a conditional jump using switch operation. -// The first CNode select graph with condition, and than execute this graph -void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block, - const FunctionBlockPtr &false_block) { - if (func_graph()->get_return() != nullptr) { - MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " - << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); - } - CNodePtr switch_app = - func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()), - NewValueNode(false_block->func_graph())}); - CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); - func_graph()->set_output(switch_app_new); - InsertDependItemsBeforeReturn(); -} - -void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { - state_assign_[target] = readid; -} - -void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } - -void FunctionBlock::InsertDependItemsBeforeReturn() { - if (!prev_blocks_.empty()) { - for (auto &prev_block : prev_blocks_) { - MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get(); - } - } - - ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); - ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); - ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); - const std::string primitive_name("assign"); - const std::string module_name("mindspore.ops.functional"); - ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); - if (state_assign_.size() == 0 && auto_depends_.size() == 0) { - return; - } - AnfNodePtr state = nullptr; - std::vector vec_states; - vec_states.emplace_back(make_tuple_op); - for (auto &item : state_assign_) { - auto source = ReadVariable(item.second); - auto assign = func_graph()->NewCNode({assign_op, item.first, source}); - MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; - vec_states.emplace_back(assign); - } - for (auto &item : auto_depends_) { - MS_LOG(DEBUG) << "auto_depends " << item->ToString(); - vec_states.emplace_back(item); - } - // if there are only make_tuple_op and another node in vec_states(the vec_states size is 2) - // do not need to make_tuple, just use the node. - if (vec_states.size() == 2) { - state = vec_states[1]; - } else { - state = func_graph()->NewCNode(vec_states); - } - - AnfNodePtr old_ret = nullptr; - auto return_node = func_graph()->get_return(); - if (return_node) { - if (return_node->inputs().size() < 1) { - MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2"; - } - old_ret = return_node->input(1); - } else { - old_ret = NewValueNode(kNone); - } - AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state}); - AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped}); - func_graph()->set_output(ret, true); - state_assign_.clear(); -} -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/function_block.h b/mindspore/ccsrc/pipeline/parse/function_block.h deleted file mode 100644 index 346061430d..0000000000 --- a/mindspore/ccsrc/pipeline/parse/function_block.h +++ /dev/null @@ -1,117 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_FUNCTION_BLOCK_H_ -#define PIPELINE_PARSE_FUNCTION_BLOCK_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "pipeline/parse/parse_base.h" -#include "utils/log_adapter.h" -#include "utils/ordered_map.h" - -namespace mindspore { -namespace parse { - -class Parser; -class NameSpace; -class Symbol; -class FunctionBlock; -using FunctionBlockPtr = std::shared_ptr; - -// A function block is a straight-line code sequence with no branches, every block has one one exit point -// which is return. When parsing function, loop or branch , we use function block to track the structure of -// the original source code. -class FunctionBlock : public std::enable_shared_from_this { - public: - explicit FunctionBlock(const Parser &parser); - virtual ~FunctionBlock() {} - - FuncGraphPtr func_graph() { return func_graph_; } - void WriteVariable(const std::string &var_name, const AnfNodePtr &node); - AnfNodePtr ReadVariable(const std::string &var_name); - void AddPrevBlock(const FunctionBlockPtr &block); - void SetPhiArgument(const ParameterPtr &phi); - void CollectRemovablePhi(const ParameterPtr &phi); - // A block is matured if all its predecessors is generated - void Mature(); - CNodePtr ForceToBoolNode(const AnfNodePtr &cond); - CNodePtr ForceToWhileCond(const AnfNodePtr &cond); - void Jump(const FunctionBlockPtr &block, AnfNodePtr node); - AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi); - void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock); - // record the assign statement of self.xx weight parameter ,which will use state_setitem op - void SetStateAssgin(const AnfNodePtr &target, const std::string &readid); - void AddAutoDepend(const AnfNodePtr &target); - void InsertDependItemsBeforeReturn(); - void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } - bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } - AnfNodePtr MakeResolveAstOp(const py::object &op); - AnfNodePtr MakeResolveClassMember(std::string attr); - AnfNodePtr MakeResolveSymbol(const std::string &value); - AnfNodePtr MakeResolveOperation(const std::string &value); - AnfNodePtr MakeResolve(const std::shared_ptr &name_space, const std::shared_ptr &resolve_symbol); - const std::unordered_map &removable_phis() const { return removable_phis_; } - - private: - // block graph - FuncGraphPtr func_graph_; - - // the block's parser - const Parser &parser_; - - // A block is matured if all its prev_blocks is processed - bool matured_; - - // store the nest-level block - // refer to comments in Parser::func_block_list_; - std::vector prev_blocks_; - - // store args and variable's node - std::map vars_; - - // phi_nodes map the parameter node to variable, it can be resolved if the block's predecessors are processed - std::map phi_nodes_; - - // jumps map the successor block and the function call that perform jump - // refer to comments in Parser::func_block_list_ that how to break the cyclic reference - std::map jumps_; - - // keeps all removable phis which will be removed in one pass. - std::unordered_map removable_phis_; - - // set state nodes need to insert before function return nodes. - OrderedMap state_assign_; - - // hold declared global variables in function - std::set global_vars_; - - // other depend need to insert before function return nodes. - // summary or some other node - std::vector auto_depends_; -}; - -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_FUNCTION_BLOCK_H_ diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc deleted file mode 100644 index 77e865cee9..0000000000 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ /dev/null @@ -1,1476 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/parse/parse.h" -#include -#include -#include -#include -#include -#include "operator/ops.h" -#include "pipeline/parse/data_converter.h" -#include "operator/composite/composite.h" -#include "utils/context/ms_context.h" -#include "debug/trace.h" - -namespace mindspore { -namespace parse { - -FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) { - (void)python_adapter::set_python_scoped(); - - if (obj == nullptr || py::isinstance(obj)) { - MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none"; - return nullptr; - } - - auto ast = std::make_shared(obj); - bool success = ast->InitParseAstInfo(python_mod_get_parse_method); - if (!success) { - MS_LOG(ERROR) << "Parse code to ast tree failed."; - return nullptr; - } - - auto parser = std::make_shared(ast); - - FuncGraphPtr func_graph = parser->ParseFuncGraph(); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode(); - return nullptr; - } - - return func_graph; -} - -// if any mixed precision flag add a cast node after the parameter node. -AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { - TypePtr dst_type; - if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { - dst_type = kFloat32; - } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { - dst_type = kFloat16; - } else { - return param; - } - auto cast_helper = prim::kPrimMixedPrecisionCast; - auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param}); - return cast; -} - -FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr(); - -Parser::Parser(const std::shared_ptr &ast) : ast_(ast) { - errcode_ = PARSE_SUCCESS; - BuildMethodMap(); -} - -void Parser::BuildMethodMap() { - stmt_method_map_["Return"] = &Parser::ParseReturn; - stmt_method_map_["Expr"] = &Parser::ParseExpr; - stmt_method_map_["If"] = &Parser::ParseIf; - stmt_method_map_["Assign"] = &Parser::ParseAssign; - stmt_method_map_["While"] = &Parser::ParseWhile; - stmt_method_map_["For"] = &Parser::ParseFor; - stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; - stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; - stmt_method_map_["Global"] = &Parser::ParseGlobal; - stmt_method_map_["Break"] = &Parser::ParseBreak; - stmt_method_map_["Continue"] = &Parser::ParseContinue; - stmt_method_map_["Pass"] = &Parser::ParsePass; - expr_method_map_["NoneType"] = &Parser::ParseNone; - expr_method_map_["BinOp"] = &Parser::ParseBinOp; - expr_method_map_["Name"] = &Parser::ParseName; - expr_method_map_["Num"] = &Parser::ParseNum; - expr_method_map_["Str"] = &Parser::ParseStr; - expr_method_map_["NameConstant"] = &Parser::ParseNameConstant; - expr_method_map_["Call"] = &Parser::ParseCall; - expr_method_map_["IfExp"] = &Parser::ParseIfExp; - expr_method_map_["Attribute"] = &Parser::ParseAttribute; - expr_method_map_["Compare"] = &Parser::ParseCompare; - expr_method_map_["BoolOp"] = &Parser::ParseBoolOp; - expr_method_map_["Lambda"] = &Parser::ParseLambda; - expr_method_map_["Tuple"] = &Parser::ParseTuple; - expr_method_map_["List"] = &Parser::ParseList; - expr_method_map_["Subscript"] = &Parser::ParseSubscript; - expr_method_map_["Slice"] = &Parser::ParseSlice; - expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice; - expr_method_map_["Index"] = &Parser::ParseIndex; - expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; - expr_method_map_["Dict"] = &Parser::ParseDict; - expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; -} - -void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } - -void Parser::InitParserEnvironment(const py::object &obj) { - Parser::top_func_graph_ = FuncGraphWeakPtr(); - ScopeManager::GetInstance().ClearScope(); - (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj); -} - -void Parser::CleanParserResource() { - Parser::top_func_graph_ = FuncGraphWeakPtr(); - ScopeManager::GetInstance().ClearScope(); -} - -FuncGraphPtr Parser::ParseFuncGraph() { - // get ast FunctionDef node - py::object node = ast_->GetAstNode(); - FunctionBlockPtr pFnBlock = ParseFunction(node); - if (errcode() != PARSE_SUCCESS) { - MS_LOG(ERROR) << "Parse function error, code is " << errcode(); - return nullptr; - } - - RemoveUnnecessaryPhis(); - - MS_EXCEPTION_IF_NULL(pFnBlock); - return pFnBlock->func_graph(); -} - -void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { - py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args"); - py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg"); - block->func_graph()->set_has_vararg(!py::isinstance(var_arg_node)); - - py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg"); - block->func_graph()->set_has_kwarg(!py::isinstance(kw_arg_node)); - - py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs"); - block->func_graph()->set_kwonlyargs_count(SizeToInt(kwonly_args.size())); - - MS_EXCEPTION_IF_NULL(ast_); - py::list args = ast_->GetArgs(fn_node); - for (std::size_t i = 0; i < args.size(); i++) { - std::string arg_name = py::cast(args[i].attr("arg")); - if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { - if (arg_name == "self") { - continue; - } - } - TraceManager::DebugTrace(GetLocation(args[i])); - auto para_node = std::make_shared(block->func_graph()); - MS_EXCEPTION_IF_NULL(para_node); - TraceManager::EndTrace(); - para_node->set_name(arg_name); - para_node->debug_info()->set_name(arg_name); - block->func_graph()->add_parameter(para_node); - AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node); - block->WriteVariable(arg_name, para_after_cast); - MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name; - } -} - -void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { - py::list defaults = ast_->GetArgsDefaultValues(fn_node); - py::list args = ast_->GetArgs(fn_node); - std::vector namelist_for_default_value; - std::vector default_values; - for (std::size_t i = 0; i < args.size(); i++) { - std::string arg_name = py::cast(args[i].attr("arg")); - if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { - if (arg_name == "self") { - continue; - } - } - - namelist_for_default_value.push_back(arg_name); - if (py::isinstance(defaults[i])) { - default_values.push_back(NewValueNode(kNull)); - } else { - default_values.push_back(ParseExprNode(block, defaults[i])); - } - } - block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values); -} - -ScopePtr Parser::GetScopeForParseFunction() { - ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope(); - if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { - py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj()); - if (!py::isinstance(scope_str)) { - auto scope_name = py::cast(scope_str); - scope = std::make_shared(scope_name); - } - } - return scope; -} - -FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) { - ScopePtr scope = GetScopeForParseFunction(); - // the node created in the parsefunction context, will inherit the scope created using scope_guard - ScopeGuard scope_guard(scope); - TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node)); - FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this); - if (block != nullptr) { - pFunBlock->AddPrevBlock(block); - } else { - func_graph_ = pFunBlock->func_graph(); - } - pFunBlock->Mature(); - auto current_fg = pFunBlock->func_graph(); - auto function_name = py::cast(python_adapter::GetPyObjAttr(node, "name")); - MS_LOG(DEBUG) << "The function name is " << function_name; - - current_fg->debug_info()->set_name(function_name); - MS_EXCEPTION_IF_NULL(ast_); - py::list deco_list = node.attr("decorator_list"); - if (deco_list.size() > 0) { - current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); - } - - bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg); - if (ast_->obj() != ast_->function()) { - set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg); - } - - if (!set_flag) { - MS_LOG(ERROR) << "Set flags failed"; - return nullptr; - } - GenerateArgsNodeForFunction(pFunBlock, node); - - // when parsing the top graph of construct, save the top graph - if (GetTopFuncGraph() == nullptr) { - UpdateTopFuncGraph(pFunBlock->func_graph()); - } - - // save the function node to block - pFunBlock->WriteVariable(function_name, NewValueNode(current_fg)); - - py::object funcObj = python_adapter::GetPyObjAttr(node, "body"); - (void)ParseStatements(pFunBlock, funcObj); - - if (current_fg->get_return() == nullptr) { - MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString(); - errcode_ = PARSE_NO_RETURN; - return pFunBlock; - } - GenerateArgsDefaultValueForFunction(pFunBlock, node); - return pFunBlock; -} - -FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) { - py::int_ pcount = python_adapter::CallPyObjMethod(nodes, "__len__"); - size_t count = IntToSize(pcount); - MS_LOG(DEBUG) << "The nodes count is " << count; - for (size_t i = 0; i < count; i++) { - auto node = py::cast(nodes)[i]; - TraceManager::DebugTrace(GetLocation(node)); - fn_block = ParseStatement(fn_block, node); - TraceManager::EndTrace(); - // insert appropriate depended items for the function block if it has a return node - if (fn_block->func_graph()->get_return() != nullptr) { - fn_block->InsertDependItemsBeforeReturn(); - // Skip statements after 'return' (or 'break', 'continue'). - break; - } - } - return fn_block; -} - -FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) { - auto node_type = ast_->GetNodeType(node); - - // check the node type - AstMainType nodeType = node_type->main_type(); - if (nodeType != AST_MAIN_TYPE_STMT) { - MS_LOG(INFO) << "Node type is error : " << nodeType; - return block; - } - // call the process function - std::string node_name = node_type->node_name(); - MS_LOG(DEBUG) << "Ast node is " << node_name; - if (stmt_method_map_.count(node_name)) { - TraceManager::DebugTrace(GetLocation(node)); - auto stmt_block = (this->*stmt_method_map_[node_name])(block, node); - TraceManager::EndTrace(); - return stmt_block; - } else { - errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; - py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); - if (location.size() < 2) { - MS_LOG(EXCEPTION) << "List size should not be less than 2."; - } - auto filename = location[0].cast(); - auto line_no = location[1].cast(); - MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; - } -} - -AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast expr"; - auto node_type = ast_->GetNodeType(node); - // check the node type - AstMainType node_main_type = node_type->main_type(); - if (node_main_type != AST_MAIN_TYPE_EXPR) { - MS_LOG(ERROR) << "Node type is error : " << node_main_type; - errcode_ = PARSE_NODE_TYPE_NO_MATCH; - return nullptr; - } - // call the process function - std::string node_name = node_type->node_name(); - MS_LOG(DEBUG) << "Ast node is " << node_name; - if (expr_method_map_.count(node_name)) { - TraceManager::DebugTrace(GetLocation(node)); - auto expr_node = (this->*expr_method_map_[node_name])(block, node); - TraceManager::EndTrace(); - return expr_node; - } else { - errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; - py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); - auto filename = ret[0].cast(); - auto line_no = ret[1].cast(); - MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; - } -} - -// process the expr statement and expand it -// eg: x.append(y) -> x = x.append(y) -FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Expr"; - // Expr only have value , no target - py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node); - - // refer python function expand_expr_statement, expand_info is one of the following: - // True, expr.value, x - // True, expr.value - // False, None, None - // check the expand info result - auto is_expand = py::cast(expand_info[0]); - if (is_expand) { - // process the expr statement - py::object value_object = expand_info[1]; - AnfNodePtr value_node = ParseExprNode(block, value_object); - - if (py::len(expand_info) == 2) { - // add to depend list and insert before output - block->AddAutoDepend(value_node); - } else { - // expand the assign statement - py::object target_node = expand_info[2]; - WriteAssignVars(block, target_node, value_node); - } - } - return block; -} - -LocationPtr Parser::GetLocation(const py::object &node) const { - MS_EXCEPTION_IF_NULL(ast_); - py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); - if (ret.size() < 5) { - MS_LOG(EXCEPTION) << "List size should not be less than 5."; - } - // refer to Location::Location() for each member of ret: line, column, line_end, column_end. - auto location = std::make_shared(ret[0].cast(), ret[1].cast(), ret[2].cast(), - ret[3].cast(), ret[4].cast()); - return location; -} - -void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block, - const FunctionBlockPtr &false_block) { - true_block->AddPrevBlock(pre_block); - true_block->Mature(); - - false_block->AddPrevBlock(pre_block); - false_block->Mature(); -} - -FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast return"; - MS_EXCEPTION_IF_NULL(block); - // create return valuenode - AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn); - // parse the return Statements value - py::object value = python_adapter::GetPyObjAttr(node, "value"); - AnfNodePtr pReturnStatementNode = ParseExprNode(block, value); - // Create the cnode - CNodePtr pReturnCNode = block->func_graph()->NewCNode({pReturnValueNode, pReturnStatementNode}); - - block->func_graph()->set_return(pReturnCNode); - - return block; -} - -// Process binary operators,eg: `a + b`, `a | b`, etc. -AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast BinOP"; - - py::object left = python_adapter::GetPyObjAttr(node, "left"); - py::object right = python_adapter::GetPyObjAttr(node, "right"); - py::object op = python_adapter::GetPyObjAttr(node, "op"); - // create left and right ANF node - AnfNodePtr left_node = ParseExprNode(block, left); - if (left_node == nullptr) { - MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode(); - return nullptr; - } - AnfNodePtr right_node = ParseExprNode(block, right); - if (right_node == nullptr) { - MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode(); - return nullptr; - } - // resolve the op - AnfNodePtr op_node = block->MakeResolveAstOp(op); - // create apply node - return block->func_graph()->NewCNode({op_node, left_node, right_node}); -} - -AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Name"; - auto name_id = py::cast(python_adapter::GetPyObjAttr(node, "id")); - MS_LOG(DEBUG) << "The Name id is " << name_id; - TraceGuard trace_guard(GetLocation(node)); - if (block->IsGlobalVar(name_id)) { - return block->MakeResolveSymbol(name_id); - } - return block->ReadVariable(name_id); -} - -AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) { - MS_LOG(DEBUG) << "Process ast NoneType"; - return NewValueNode(kNone); -} - -AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) { - MS_LOG(DEBUG) << "Process ast Ellipsis"; - return NewValueNode(kEllipsis); -} - -AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Num"; - py::object obj = python_adapter::GetPyObjAttr(node, "n"); - TraceGuard trace_guard(GetLocation(node)); - if (py::isinstance(obj)) { - MS_LOG(INFO) << "The Num is int:" << (std::string)py::str(obj); - auto data = py::cast(obj); - return NewValueNode(data); - } else if (py::isinstance(obj)) { - MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj); - auto data = py::cast(obj); - return NewValueNode(data); - } else { - // no else actually - MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj) << GetLocation(node)->ToString(); - errcode_ = PARSE_NODE_TYPE_UNKOWN; - return nullptr; - } -} - -AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Str"; - auto str_s = py::cast(python_adapter::GetPyObjAttr(node, "s")); - return NewValueNode(str_s); -} - -AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) { - MS_LOG(DEBUG) << "Process ast NameConstant"; - py::object obj = python_adapter::GetPyObjAttr(node, "value"); - TraceGuard trace_guard(GetLocation(node)); - if (py::isinstance(obj)) { - MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj); - auto data = py::cast(obj); - return NewValueNode(data); - } else if (py::isinstance(obj)) { - MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj); - return NewValueNode(kNone); - } else { - // no else actually - MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj) << GetLocation(node)->ToString(); - errcode_ = PARSE_NODE_TYPE_UNKOWN; - return nullptr; - } -} -AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes) { - AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); - std::vector make_tuple_nodes; - make_tuple_nodes.push_back(make_tuple_op); - (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes), - [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); - return block->func_graph()->NewCNode(make_tuple_nodes); -} -// process function call, eg : f1(x, y) ... -AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Call"; - // process function call - py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func"); - AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node); - // function call arguments should be passed in as groups and unpacked later using unpack call - py::list args = python_adapter::GetPyObjAttr(node, "args"); - std::vector packed_arguments; - std::vector group_arguments; - - bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments); - bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments); - // if there is stared or keyword argument, unpack may be needed - bool need_unpack = need_unpack_args || need_unpack_keywords; - - return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack); -} - -AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, - const std::vector &packed_arguments, - const std::vector &group_arguments, bool need_unpack) const { - // if there is keyword arguments or starred, using an unpack_call op to unpack the argument - if (need_unpack) { - std::vector unpack_call_nodes; - auto unpack_call_op = NewValueNode(std::make_shared(NAMED_METAGRAPH_UNPACKCALL)); - unpack_call_nodes.push_back(unpack_call_op); - unpack_call_nodes.push_back(call_function_anf_node); - (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), - [](AnfNodePtr node) -> AnfNodePtr { return node; }); - CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes); - return unpack_call; - } - // else there is no keyword arguments and starred, parsed as normal arguments without unpack - std::vector func_call_nodes; - func_call_nodes.push_back(call_function_anf_node); - (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes), - [](AnfNodePtr node) -> AnfNodePtr { return node; }); - CNodePtr call_anf_node = block->func_graph()->NewCNode(func_call_nodes); - return call_anf_node; -} - -bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, - std::vector *packed_arguments, std::vector *group_arguments) { - bool need_unpack = false; - for (size_t i = 0; i < args.size(); i++) { - auto arg_node = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[i]))); - if (arg_node == AST_SUB_TYPE_STARRED) { - if (!group_arguments->empty()) { - packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); - } - packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value"))); - group_arguments->clear(); - need_unpack = true; - } else { - group_arguments->push_back(ParseExprNode(block, args[i])); - } - } - if (!group_arguments->empty()) { - packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); - } - return need_unpack; -} - -bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, - std::vector *packed_arguments) { - bool need_unpack = false; - py::list keywords = python_adapter::GetPyObjAttr(node, "keywords"); - if (!keywords.empty()) { - need_unpack = true; - std::vector keys; - std::vector values; - for (size_t index = 0; index < keywords.size(); index++) { - auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg"); - auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value"); - if (py::isinstance(kw_key)) { - packed_arguments->push_back(ParseExprNode(block, kw_value)); - } else { - auto kw_key_c = kw_key.cast(); - keys.push_back(NewValueNode(kw_key_c)); - values.push_back(ParseExprNode(block, kw_value)); - } - } - auto keys_tuple = GenerateMakeTuple(block, keys); - auto values_tuple = GenerateMakeTuple(block, values); - auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); - std::vector make_dict_nodes; - make_dict_nodes.push_back(make_dict_op); - make_dict_nodes.push_back(keys_tuple); - make_dict_nodes.push_back(values_tuple); - packed_arguments->push_back(block->func_graph()->NewCNode(make_dict_nodes)); - } - return need_unpack; -} - -// process call attributes of class type define, eg: x.y() -AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Attribute"; - - // process class value,eg: self.xx - if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { - if (ast_->IsClassMember(node)) { - std::string var_name = "self."; - std::string attr_name = node.attr("attr").cast(); - (void)var_name.append(attr_name); - auto attr_obj = ast()->obj().attr(attr_name.c_str()); - if (py::hasattr(ast()->obj(), attr_name.c_str()) && - (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance(attr_obj) || - py::isinstance(attr_obj) || py::isinstance(attr_obj) || - py::isinstance(attr_obj) || data_converter::IsCellInstance(attr_obj))) { - return block->MakeResolveSymbol(var_name); - } else { - return block->ReadVariable(var_name); - } - } - } - - // process the get attr - // Use the Primitive replace the operation resolve node (getattr) - // because the getattr will eventually be converted to Primitive node - AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr); - - // process the attr body - py::object value_body = python_adapter::GetPyObjAttr(node, "value"); - AnfNodePtr value_node = ParseExprNode(block, value_body); - if (value_node == nullptr) { - MS_LOG(WARNING) << "Parse attribute failed"; - return nullptr; - } - - // process the node attr - auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast(); - MS_LOG(DEBUG) << "Attr = " << attr_str; - TraceManager::DebugTrace(GetLocation(python_adapter::GetPyObjAttr(node, "attr"))); - AnfNodePtr attr_node = NewValueNode(attr_str); - TraceManager::EndTrace(); - - // create the apply node - return block->func_graph()->NewCNode({op_node, value_node, attr_node}); -} - -// Process comparison expression : a == b. a > b etc. -AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Compare"; - - // for python comparison ,there may be if x>y>5 , - // which there is two ops , but we only support one now - py::list ops = python_adapter::GetPyObjAttr(node, "ops"); - if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) { - MS_LOG(ERROR) << "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size(); - return nullptr; - } - - py::object left = python_adapter::GetPyObjAttr(node, "left"); - py::list comparators = python_adapter::GetPyObjAttr(node, "comparators"); - AnfNodePtr left_node = ParseExprNode(block, left); - AnfNodePtr right_node = ParseExprNode(block, comparators[0]); - - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]); - - return block->func_graph()->NewCNode({op_node, left_node, right_node}); -} - -AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, - const py::object &op) { - // if there is only one bool op now - if (value_list.size() == 1) { - AnfNodePtr first_node = ParseExprNode(block, value_list[0]); - return first_node; - } else { - py::object first = value_list[0]; - py::list rest; - for (size_t i = 1; i < value_list.size(); i++) { - rest.append(value_list[i]); - } - - AnfNodePtr first_node = ParseExprNode(block, first); - AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op); - auto op_node = block->MakeResolveAstOp(op); - return block->func_graph()->NewCNode({op_node, first_node, rest_node}); - } -} - -// Process comparison expression : a and b. a or b . -AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast BoolOp"; - py::object op_node = python_adapter::GetPyObjAttr(node, "op"); - py::list op_values = python_adapter::GetPyObjAttr(node, "values"); - return ProcessBoolOpValueList(block, op_values, op_node); -} - -// Process a function def -FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast FunctionDef"; - FunctionBlockPtr function_block = ParseFunction(node, block); - MS_EXCEPTION_IF_NULL(function_block); - - // get function name - py::str name = python_adapter::GetPyObjAttr(node, "name"); - std::string function_name = name; - ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph()); - block->WriteVariable(function_name, valuenode_graph); - return block; -} - -// Process a lambda expression . like lambda x,y: x + y -AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Lambda"; - FunctionBlockPtr func_block = MakeFunctionBlock(*this); - func_block->AddPrevBlock(block); - func_block->Mature(); - - // get lambda args - py::list args = ast_->GetArgs(node); - for (std::size_t i = 0; i < args.size(); i++) { - std::string arg = py::cast(args[i].attr("arg")); - TraceManager::DebugTrace(GetLocation(args[i])); - auto para_node = std::make_shared(func_block->func_graph()); - TraceManager::EndTrace(); - para_node->debug_info()->set_name(arg); - func_block->func_graph()->add_parameter(para_node); - func_block->WriteVariable(arg, para_node); - MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg; - } - - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node); - func_block->func_graph()->set_output(lambda_body_node); - ValueNodePtr const_graph = NewValueNode(func_block->func_graph()); - return const_graph; -} - -// process a tuple -AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Tuple"; - MS_EXCEPTION_IF_NULL(block); - py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); - if (elts.size() == 0) { - auto empty_tuple = std::vector(); - return NewValueNode(std::make_shared(empty_tuple)); - } - - std::vector tuple_vec; - AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); - tuple_vec.emplace_back(make_tuple_op); - for (size_t i = 0; i < elts.size(); i++) { - AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); - tuple_vec.emplace_back(node_ptr); - } - CNodePtr tuple_app = block->func_graph()->NewCNode(tuple_vec); - return tuple_app; -} - -// process a list -AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast List"; - MS_EXCEPTION_IF_NULL(block); - py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); - if (elts.size() == 0) { - auto empty_list = std::vector(); - return NewValueNode(std::make_shared(empty_list)); - } - - std::vector list_vec; - AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST); - list_vec.emplace_back(make_list_op); - for (size_t i = 0; i < elts.size(); i++) { - AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); - list_vec.emplace_back(node_ptr); - } - CNodePtr list_app = block->func_graph()->NewCNode(list_vec); - return list_app; -} - -// process a subscript, such as x[y] , node expressed as value[slice] -AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Subscript"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - py::object value_node = python_adapter::GetPyObjAttr(node, "value"); - py::object slice_node = python_adapter::GetPyObjAttr(node, "slice"); - AnfNodePtr value = ParseExprNode(block, value_node); - AnfNodePtr slice = ParseExprNode(block, slice_node); - - return block->func_graph()->NewCNode({op_getitem, value, slice}); -} - -// process a slice, get the slice value -AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Slice"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE); - py::object start = python_adapter::GetPyObjAttr(node, "lower"); - py::object stop = python_adapter::GetPyObjAttr(node, "upper"); - py::object step = python_adapter::GetPyObjAttr(node, "step"); - AnfNodePtr start_node = ParseExprNode(block, start); - AnfNodePtr stop_node = ParseExprNode(block, stop); - AnfNodePtr step_node = ParseExprNode(block, step); - - return block->func_graph()->NewCNode({op_makeslice, start_node, stop_node, step_node}); -} - -// process a extslice -AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast ExtSlice"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); - py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims"); - - std::vector node_vec; - node_vec.emplace_back(make_tuple_op); - for (size_t i = 0; i < slice_tuple.size(); i++) { - AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]); - node_vec.emplace_back(node_ptr); - } - CNodePtr tuple_conde = block->func_graph()->NewCNode(node_vec); - return tuple_conde; -} - -// process a index, get the index number -AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Index"; - py::object value_node = python_adapter::GetPyObjAttr(node, "value"); - return ParseExprNode(block, value_node); -} - -// process a UnaryOp, +a, -b -AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast UnaryOp"; - py::object op = python_adapter::GetPyObjAttr(node, "op"); - - MS_EXCEPTION_IF_NULL(block); - // resolve the op - AnfNodePtr op_node = block->MakeResolveAstOp(op); - - py::object operand = python_adapter::GetPyObjAttr(node, "operand"); - AnfNodePtr operand_node = ParseExprNode(block, operand); - return block->func_graph()->NewCNode({op_node, operand_node}); -} - -// process a dict ast node expression -AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Dict"; - py::list keys = node.attr("keys"); - py::list values = node.attr("values"); - std::vector key_nodes; - std::vector value_nodes; - for (size_t i = 0; i < keys.size(); i++) { - key_nodes.push_back(ParseExprNode(block, keys[i])); - value_nodes.push_back(ParseExprNode(block, values[i])); - } - auto keys_tuple = GenerateMakeTuple(block, key_nodes); - auto values_tuple = GenerateMakeTuple(block, value_nodes); - MS_EXCEPTION_IF_NULL(block); - auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); - return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple}); -} - -// process a augment assign such as a += b; -FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast AugAssign"; - py::object op = python_adapter::GetPyObjAttr(node, "op"); - - MS_EXCEPTION_IF_NULL(block); - // resolve the op - AnfNodePtr op_node = block->MakeResolveAstOp(op); - py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - MS_EXCEPTION_IF_NULL(ast_); - auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_node))); - AnfNodePtr read_node = nullptr; - if (ast_type == AST_SUB_TYPE_NAME) { - read_node = ParseName(block, target_node); - } else if (ast_->IsClassMember(target_node)) { - read_node = ParseAttribute(block, target_node); - } else { - MS_LOG(EXCEPTION) << "Not supported augassign"; - } - if (read_node == nullptr) { - MS_LOG(EXCEPTION) << "Can not get target node "; - } - - py::object value = python_adapter::GetPyObjAttr(node, "value"); - AnfNodePtr value_node = ParseExprNode(block, value); - CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, read_node, value_node}); - WriteAssignVars(block, target_node, augassign_app); - return block; -} - -// process global declaration such as 'global x'; -FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Global"; - MS_EXCEPTION_IF_NULL(block); - py::list vars = python_adapter::GetPyObjAttr(node, "names"); - for (auto &item : vars) { - block->AddGlobalVar(py::cast(item)); - } - return block; -} - -// process a if statement -FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast If"; - py::object test_node = python_adapter::GetPyObjAttr(node, "test"); - AnfNodePtr condition_node = ParseExprNode(block, test_node); - MS_EXCEPTION_IF_NULL(block); - CNodePtr bool_node = block->ForceToBoolNode(condition_node); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr true_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr false_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - MakeConditionBlocks(block, true_block, false_block); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr after_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - // process the if-true branch - py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode); - - // if the return_ is set ,it has its own continuation block - if (true_end->func_graph()->get_return() == nullptr) { - true_end->Jump(after_block, nullptr); - } - - // process the orelse branch - py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); - FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode); - - // if the return_ is set ,it has its own continuation block - if (false_end->func_graph()->get_return() == nullptr) { - false_end->Jump(after_block, nullptr); - } - - block->ConditionalJump(bool_node, true_block, false_block); - after_block->Mature(); - return after_block; -} - -FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast While"; - MS_EXCEPTION_IF_NULL(block); - MS_LOG(INFO) << "Parse while statement"; - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr header_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr body_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr after_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - body_block->AddPrevBlock(header_block); - after_block->AddPrevBlock(header_block); - block->Jump(header_block, nullptr); - - py::object test_node = python_adapter::GetPyObjAttr(node, "test"); - AnfNodePtr condition_node = ParseExprNode(header_block, test_node); - condition_node = header_block->ForceToWhileCond(condition_node); - body_block->Mature(); - header_block->ConditionalJump(condition_node, body_block, after_block); - - // Parse loop body statements with loop context. - LoopContext loop_context{&loops_, header_block, nullptr}; - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr after_body = ParseStatements(body_block, body_node); - if (after_body->func_graph()->get_return() == nullptr) { - after_body->Jump(header_block, nullptr); - } - - header_block->Mature(); - after_block->Mature(); - auto &end_block = loop_context.EndBlock(); - if (end_block) { - // end_block exists if we encounter 'break' in loop body. - after_block->Jump(end_block, nullptr); - end_block->Mature(); - return end_block; - } - // No 'break', no end_block. - return after_block; -} - -CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node, - const AnfNodePtr &op_iter) { - py::object iter_node = python_adapter::GetPyObjAttr(node, "iter"); - AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node); - return block->func_graph()->NewCNode({op_iter, iter_anf_node}); -} -CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, - const AnfNodePtr &op_hasnext) { - MS_EXCEPTION_IF_NULL(header_block); - return header_block->func_graph()->NewCNode({op_hasnext, iter_param}); -} - -FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { - TraceManager::DebugTrace(trace_info); - FunctionBlockPtr body_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - return body_block; -} - -// A for loop will generate 3 functions :the test, the body, and the continuation -// for x in xs: -// body -// it compiled to be following statement -// it = iter(xs) -// while hastnext(it) -// x, it = next(it) -// body -FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast For"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER); - AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT); - // generate the iterator apply - CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter); - MS_EXCEPTION_IF_NULL(iter_apply); - FunctionBlockPtr header_block = - GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(header_block); - // generate the hasnext apply which is a condition - ParameterPtr iter_param = header_block->func_graph()->add_parameter(); - CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext); - // generate the body of the for statement - FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(body_block); - body_block->AddPrevBlock(header_block); - // generate the iterator next apply - // process as following: `app = next(it); target = app[0]; it = app[1];` - CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param}); - CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)}); - py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - - CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)}); - WriteAssignVars(body_block, target_node, target_app); - - // link the variable name with the target - auto it_info = std::make_shared(target_app->debug_info()); - iter_param->debug_info()->set_trace_info(it_info); - iter2_app->debug_info()->set_trace_info(it_info); - iter_apply->debug_info()->set_trace_info(it_info); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr after_block = MakeFunctionBlock(*this); - MS_EXCEPTION_IF_NULL(after_block); - TraceManager::EndTrace(); - after_block->AddPrevBlock(header_block); - - block->Jump(header_block, iter_apply); - body_block->Mature(); - header_block->ConditionalJump(cond_apply, body_block, after_block); - - // Parse loop body statements with loop context. - LoopContext loop_context{&loops_, header_block, iter2_app}; - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); - if (after_body_block->func_graph()->get_return() == nullptr) { - after_body_block->Jump(header_block, iter2_app); - } - - header_block->Mature(); - after_block->Mature(); - auto &end_block = loop_context.EndBlock(); - if (end_block) { - // end_block exists if we encounter 'break' in loop body. - after_block->Jump(end_block, nullptr); - end_block->Mature(); - return end_block; - } - // No 'break', no end_block. - return after_block; -} -AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast IfExp"; - MS_EXCEPTION_IF_NULL(block); - py::object test_node = python_adapter::GetPyObjAttr(node, "test"); - AnfNodePtr condition_node = ParseExprNode(block, test_node); - CNodePtr bool_node = block->ForceToBoolNode(condition_node); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr true_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr false_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - MakeConditionBlocks(block, true_block, false_block); - - // process the if-true branch - py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); - true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode)); - AnfNodePtr true_node = ParseExprNode(true_block, bodyNode); - - // process the orelse branch - py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); - false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode)); - AnfNodePtr false_node = ParseExprNode(false_block, orelseNode); - - true_block->func_graph()->set_output(true_node); - false_block->func_graph()->set_output(false_node); - - // Use the Primitive replace the operation resolve node (switch) - // because the switch will eventually be converted to Primitive node - CNodePtr switch_app = - block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()), - NewValueNode(false_block->func_graph())}); - - std::vector call_graph_nodes{switch_app}; - CNodePtr switch_app_call = block->func_graph()->NewCNode(call_graph_nodes); - return switch_app_call; -} - -void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) { - MS_EXCEPTION_IF_NULL(block); - MS_EXCEPTION_IF_NULL(assigned_node); - py::str name = python_adapter::GetPyObjAttr(targ, "id"); - std::string name_id = name; - assigned_node->debug_info()->set_name(name_id); - // set the debug name of the constant graph - if (IsValueNode(assigned_node)) { - // the value should be graph - auto fg = GetValueNode(assigned_node); - if (fg->debug_info()->name().empty()) { - fg->debug_info()->set_name(name_id); - } - } - block->WriteVariable(name_id, assigned_node); -} - -void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) { - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - py::list items = python_adapter::GetPyObjAttr(targ, "elts"); - for (size_t i = 0; i < items.size(); i++) { - // Use the Primitive replace the operation resolve node (getitem) - // because the getitem will eventually be converted to Primitive node - CNodePtr item_apply = block->func_graph()->NewCNode({op_getitem, assigned_node, NewValueNode(static_cast(i))}); - - py::object elt = items[i]; - WriteAssignVars(block, elt, item_apply); - } -} - -void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, - const AnfNodePtr &assigned_node) { - // Now only support the self.xx = xxxxx, can't support x.y = xxxx - AnfNodePtr target_node = ParseExprNode(block, targ); - MS_EXCEPTION_IF_NULL(target_node); - - std::string attr_name = targ.attr("attr").cast(); - std::string var_name = "self."; - (void)var_name.append(attr_name); - MS_LOG(DEBUG) << "assign " << var_name; - - // Get targ location info for error printing - py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ); - if (location.size() < 2) { - MS_LOG(EXCEPTION) << "List size should not be less than 2."; - } - auto filename = location[0].cast(); - auto line_no = location[1].cast(); - // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type - if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { - MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":" - << line_no; - } - auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); - auto obj_type = obj.attr("__class__").attr("__name__"); - if (!py::hasattr(obj, "__parameter__")) { - MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" - << py::str(obj).cast() << "' with type '" - << py::str(obj_type).cast() << "' at " << filename << ":" << line_no; - } - - MS_EXCEPTION_IF_NULL(block); - block->WriteVariable(var_name, assigned_node); - MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString(); - block->SetStateAssgin(target_node, var_name); -} - -void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, - const AnfNodePtr &assigned_node) { - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM); - py::object value_obj = python_adapter::GetPyObjAttr(targ, "value"); - py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice"); - AnfNodePtr value_node = ParseExprNode(block, value_obj); - AnfNodePtr slice_node = ParseExprNode(block, slice_obj); - CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node}); - // getitem apply should return the sequence data structure itself - std::string var_name = ""; - if (ast_->IsClassMember(value_obj)) { - std::string attr_name = value_obj.attr("attr").cast(); - var_name = "self." + attr_name; - if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { - MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function."; - } - auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); - auto obj_type = obj.attr("__class__").attr("__name__"); - if (!py::hasattr(obj, "__parameter__")) { - MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" - << py::str(obj).cast() << "' with type '" - << py::str(obj_type).cast() << "'."; - } - } else { - var_name = value_obj.attr("id").cast(); - } - block->WriteVariable(var_name, setitem_app); -} - -void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - MS_LOG(DEBUG) << "Process WriteAssignVars"; - auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, targ))); - if (ast_type == AST_SUB_TYPE_NAME) { - HandleAssignName(block, targ, value_node); - } else if (ast_type == AST_SUB_TYPE_TUPLE) { - HandleAssignTuple(block, targ, value_node); - } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) { - HandleAssignSubscript(block, targ, value_node); - } else if (ast_->IsClassMember(targ)) { - HandleAssignClassMember(block, targ, value_node); - } else { - MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type - << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info()); - } -} - -// process a assign statement, such as a =b, a,b = tup -FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast assgin"; - py::object value_object = python_adapter::GetPyObjAttr(node, "value"); - AnfNodePtr value_node = ParseExprNode(block, value_object); - py::object targets_object = python_adapter::GetPyObjAttr(node, "targets"); - py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__"); - size_t count = IntToSize(pcount); - MS_LOG(DEBUG) << "The nodes count is " << count; - for (size_t i = 0; i < count; i++) { - auto target_node = py::cast(targets_object)[i]; - WriteAssignVars(block, target_node, value_node); - } - - return block; -} - -FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) { - if (loops_.empty()) { - // Report error if loop context not set for the 'break' statement. - py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); - if (location.size() < 2) { - MS_LOG(EXCEPTION) << "List size should not be less than 2."; - } - auto filename = location[0].cast(); - auto line_no = location[1].cast(); - MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no; - } - // Get current loop. - Loop &loop = loops_.top(); - if (loop.end == nullptr) { - // Create end_block if it is not existed. - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - loop.end = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - } - // Jump to the end_block. - block->Jump(loop.end, nullptr); - return block; -} - -FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) { - if (loops_.empty()) { - // Report error if loop context not set for the 'continue' statement. - py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); - if (location.size() < 2) { - MS_LOG(EXCEPTION) << "List size should not be less than 2."; - } - auto filename = location[0].cast(); - auto line_no = location[1].cast(); - MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no; - } - // Jump to the header of the loop with iterator called. - Loop &loop = loops_.top(); - block->Jump(loop.header, loop.iterator); - return block; -} - -FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) { - // We just bypass 'pass' statement. - return block; -} - -void Parser::RemoveUnnecessaryPhis() { - // merge all removable phis to one map; - std::unordered_map removable_phis; - for (FunctionBlockPtr &block : func_block_list_) { - MS_EXCEPTION_IF_NULL(block); - removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); - } - - if (removable_phis.size() == 0) { - return; - } - for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) { - if (node->isa()) { - const auto &cnode = node->cast(); - auto &inputs = cnode->inputs(); - for (std::size_t i = 0; i < inputs.size(); i++) { - if (inputs[i]->isa()) { - const auto &inp = inputs[i]->cast(); - const auto &iter = removable_phis.find(inp); - if (iter == removable_phis.end()) { - continue; - } - auto &argNode = iter->second; - MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in " - << cnode->DebugString() << " with " << argNode->DebugString(); - cnode->set_input(i, argNode); - } - } - } - } -} - -// ParseAst class code -bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) { - // init the type - target_type_ = PARSE_TARGET_UNKNOW; - - // call python parse, get the parser fn - module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD); - - // get the obj type - auto type = data_converter::GetObjType(obj_); - if (type == RESOLVE_TYPE_FUNCTION) { - target_type_ = PARSE_TARGET_FUNCTION; - function_ = obj_; - } else if (type == RESOLVE_TYPE_METHOD) { - // process the method ,need get the method's self obj - target_type_ = PARSE_TARGET_METHOD; - py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS); - if (py::isinstance(method_object)) { - MS_LOG(ERROR) << "Get method's self object instance failed."; - return false; - } - target_type_ = PARSE_TARGET_OBJECT_INSTANCE; - function_ = obj_; - obj_ = method_object; - } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) { - // obj is class instance, get the method to parse. - function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method); - if (py::isinstance(function_)) { - MS_LOG(ERROR) << "Get obj method function failed."; - return false; - } - target_type_ = PARSE_TARGET_OBJECT_INSTANCE; - // check the fn is method - auto obj_type = data_converter::GetObjType(function_); - if (obj_type != RESOLVE_TYPE_METHOD) { - MS_LOG(WARNING) << "Parse method function is invalid."; - return false; - } - } else { - MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type; - return false; - } - - // call python parse get ast tree - parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method); - ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse"); - - // get fn name and module - function_module_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_module")); - function_name_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_name")); - function_filename_ = py::cast(python_adapter::GetPyObjAttr(parser_, "filename")); - function_line_offset_ = py::cast(python_adapter::GetPyObjAttr(parser_, "line_offset")); - - return true; -} - -// Get ast tree node : is the tree bode list[0] -py::object ParseAst::GetAstNode() { - py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body"); - py::object ast_node = tree_body[0]; - return ast_node; -} - -py::list ParseAst::GetArgs(const py::object &func_node) { - py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS, func_node); - return ret; -} - -py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) { - py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node); - return ret; -} - -AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) { - py::list list_value = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_NODE_TYPE, node); - if (list_value.size() < 2) { - MS_LOG(ERROR) << "The node of python method must has 2 values."; - return nullptr; - } - auto node_name = py::cast(list_value[0]); - auto type = AstMainType(py::cast(list_value[1])); - return std::make_shared(node, node_name, type); -} - -AstSubType ParseAst::GetOpType(const py::object &node) { - auto op_type = AstSubType(python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_AST_TYPE, node).cast()); - return op_type; -} - -bool ParseAst::IsClassMember(const py::object &node) { - py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node); - if (!py::isinstance(ret)) { - MS_LOG(ERROR) << "The result of mod function parse, should be bool type."; - return false; - } - return ret.cast(); -} - -bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { - if (func_graph == nullptr) { - MS_LOG(ERROR) << "FuncGraph is null"; - return false; - } - - if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) { - MS_LOG(DEBUG) << "No flags"; - return true; - } - py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG); - for (auto &item : flags) { - if (!py::isinstance(item.first)) { - MS_LOG(ERROR) << "Type error in flags dict convert"; - return false; - } - auto name = py::cast(item.first); - if (py::isinstance(item.second)) { - auto value = py::cast(item.second); - MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; - func_graph->set_flag(name, value); - } else if (py::isinstance(item.second)) { - auto value = py::cast(item.second); - MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; - func_graph->set_attr(name, MakeValue(value)); - } else { - MS_LOG(ERROR) << "Type error in flags/attrs dict convert"; - return false; - } - } - return true; -} - -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/parse.h b/mindspore/ccsrc/pipeline/parse/parse.h deleted file mode 100644 index 19c503c6d0..0000000000 --- a/mindspore/ccsrc/pipeline/parse/parse.h +++ /dev/null @@ -1,358 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_PARSE_H_ -#define PIPELINE_PARSE_PARSE_H_ - -#include -#include -#include -#include -#include -#include -#include "utils/misc.h" -#include "ir/anf.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/function_block.h" - -namespace mindspore { -namespace parse { - -// Parse status define -enum ParseStatusCode : int { - PARSE_SUCCESS = 0, - PARSE_FUNCTION_IS_NULL, // python function is null - PARSE_PARAMETER_INVALID, // parameter is invalid - PARSE_NO_RETURN, // function no return node - PARSE_NODE_TYPE_NO_MATCH, // ast node type is error - PARSE_NODE_TYPE_UNKOWN, // node type is unkown - PARSE_NODE_METHOD_UNSUPPORTED, // no method to parse the node - PARSE_DONT_RESOLVE_SYMBOL, // can't resolve the string - PARSE_NOT_SUPPORTED_COMPARE_EXPR, // the comparison is not supported - PARSE_FAILURE = 0xFF -}; - -class AstNodeType; -class ParseAst; - -// Save loop info for 'continue' and 'break' statements. -struct Loop { - // Loop header block. - FunctionBlockPtr header; - // Loop iterator node, used in 'for loop'. - AnfNodePtr iterator; - // Loop end block. - FunctionBlockPtr end; - - Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end) - : header(header), iterator(iterator), end(end) {} - ~Loop() = default; -}; - -// Loop context for loop stack management. -class LoopContext { - public: - LoopContext(std::stack *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) { - loops_->emplace(header, iterator, nullptr); - } - ~LoopContext() { loops_->pop(); } - const FunctionBlockPtr &EndBlock() const { return loops_->top().end; } - - private: - std::stack *loops_; -}; - -// Parser to parse python function -class Parser { - public: - explicit Parser(const std::shared_ptr &ast); - - ~Parser() {} - FuncGraphPtr ParseFuncGraph(); - FuncGraphPtr func_graph() const { return func_graph_; } - ParseStatusCode errcode() const { return errcode_; } - std::shared_ptr ast() const { return ast_; } - // get location info from the ast node - LocationPtr GetLocation(const py::object &node) const; - static void InitParserEnvironment(const py::object &obj); - static void CleanParserResource(); - static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); } - static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph); - - private: - // process the stmt node method list - FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node); - // parse expression - FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node); - // process a if statement - FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node); - // process a while statement - FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node); - // process a for statement - FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node); - // process a function def statement - FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node); - // process a augment assign - FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node); - // process a global declaration - FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node); - // process assign statement - FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node); - // process break statement - FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node); - // process continue statement - FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node); - // process pass statement - FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node); - // process the expr and slice node method list - AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node); - // process a variable name - AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); - // process NoneType - AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); - // process Ellipsis - AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node); - // process a integer or float number - AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); - // process a string variable - AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node); - // process a name - AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node); - // process a function call - AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node); - // process the if expression - AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node); - // process class type define - AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node); - // process a compare expression - AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node); - // process a bool operation - AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node); - // process a lambda operation - AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node); - // process a tuple - AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node); - // process a tuple - AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node); - // process a tuple - AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node); - // process a slice - AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node); - - // process a extslice - AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node); - - // process a tuple - AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node); - - // process a unaryop - AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node); - - // process a dict ast node expression - AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node); - // generate argument nodes for ast function node - void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node); - // generate argument default value for ast function node - void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node); - // parse ast function node - FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr); - // parse ast statements - FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node); - // parse one ast statement node - FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node); - // parse an ast expresion node - AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node); - - void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock, - const FunctionBlockPtr &falseBlock); - void RemoveUnnecessaryPhis(); - // write a new var - void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node); - - // assign value to single variable name - void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); - - // assign value to tuple - void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); - - // assign value to class member - void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); - - // assign value to subscript - void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); - - // process a bool operation value list - AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op); - - CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node, - const AnfNodePtr &op_iter); - - CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, - const AnfNodePtr &op_hasnext); - - FunctionBlockPtr GenerateBlockInFor(const TraceInfoPtr &trace_info); - - bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, - std::vector *packed_arguments); - - bool ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector *packed_arguments, - std::vector *group_arguments); - - AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, - const std::vector &packed_arguments, - const std::vector &group_arguments, bool need_unpack) const; - ScopePtr GetScopeForParseFunction(); - void BuildMethodMap(); - FunctionBlockPtr MakeFunctionBlock(const Parser &parse) { - FunctionBlockPtr block = std::make_shared(parse); - // In order to keep effect order in the sub-graphs which generated by control flow. - // We copy the flags from the top graph to the sub-graphs. - if (func_graph_ && !func_graph_->attrs().empty()) { - block->func_graph()->set_attrs(func_graph_->attrs()); - } - func_block_list_.push_back(block); - return block; - } - // return a make tuple for input elements list - AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes); - - // shared_ptr will be hold by GraphManager, so just hold a weak ref here. - static FuncGraphWeakPtr top_func_graph_; - // Python function id, used to indicate whether two CNodes come from the same Python function - const std::shared_ptr &ast_; - FuncGraphPtr func_graph_; - // error code setwhen parsing ast tree - ParseStatusCode errcode_; - - // hold all reference for FunctionBlock in this round of parsing, - // so in FunctionBlock class we can use FunctionBlock* in member - // pre_blocks_ and jumps_ to break reference cycle. - std::vector func_block_list_; - using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); - using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); - // define the function map to parse ast Statement - std::map stmt_method_map_; - // define the function map to parse ast expression - std::map expr_method_map_; - // Save current loops to support 'continue', 'break' statement. - std::stack loops_; -}; - -// AST node type define code to ast -class AstNodeType { - public: - AstNodeType(const py::object &node, const std::string &name, AstMainType type) - : node_(node), node_name_(name), main_type_(type) {} - - ~AstNodeType() {} - - std::string node_name() const { return node_name_; } - - py::object node() const { return node_; } - - AstMainType main_type() const { return main_type_; } - - private: - const py::object &node_; - const std::string node_name_; - AstMainType main_type_; -}; - -using AstNodeTypePtr = std::shared_ptr; - -// A helper class to parse python function -class ParseAst { - public: - explicit ParseAst(const py::object &obj) : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {} - - ~ParseAst() = default; - - bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); - - py::object GetAstNode(); - - py::list GetArgs(const py::object &func_node); - - py::list GetArgsDefaultValues(const py::object &func_node); - - AstNodeTypePtr GetNodeType(const py::object &node); - - AstSubType GetOpType(const py::object &node); - - template - py::object CallParserObjMethod(const std::string &method, const T &... args) { - return python_adapter::CallPyObjMethod(parser_, method, args...); - } - - template - py::object CallParseModFunction(const std::string &function, const T &... args) { - return python_adapter::CallPyModFn(module_, function, args...); - } - - const std::string &function_name() const { return function_name_; } - - const std::string &function_module() const { return function_module_; } - - const std::string &function_filename() const { return function_filename_; } - - int function_line_offset() const { return function_line_offset_; } - - py::function function() { return function_; } - - ParseTargetTypeDef target_type() const { return target_type_; } - - py::object obj() { return obj_; } - - py::object parser() { return parser_; } - - py::object module() { return module_; } - - py::object ast_tree() { return ast_tree_; } - - bool IsClassMember(const py::object &node); - - private: - // save obj,eg: class instance or function - py::object obj_; - - // function or class method. - py::function function_; - - py::object ast_tree_; - py::object parser_; - py::module module_; - - // Is function or method - ParseTargetTypeDef target_type_; - - std::string function_name_; - std::string function_module_; - std::string function_filename_; - int function_line_offset_; -}; - -// update the graph flags -bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); - -AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); - -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_PARSE_H_ diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/parse/parse_base.h deleted file mode 100644 index 4961ab78c0..0000000000 --- a/mindspore/ccsrc/pipeline/parse/parse_base.h +++ /dev/null @@ -1,151 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_PARSE_BASE_H_ -#define PIPELINE_PARSE_PARSE_BASE_H_ -#include -#include -#include "pybind11/pybind11.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/manager.h" -#include "pybind_api/export_flags.h" - -namespace py = pybind11; -namespace mindspore { -namespace parse { -// define the node type -enum AstMainType : int { - AST_MAIN_TYPE_STMT = 0, // ast.Stmt - AST_MAIN_TYPE_EXPR = 1, // ast.Expr - AST_MAIN_TYPE_SLICE = 2, // ast.Slice - AST_MAIN_TYPE_UNKNOWN = 0xFF // Error -}; - -enum AstSubType : int { - AST_SUB_TYPE_AND = 3, // ast.And - AST_SUB_TYPE_OR = 4, // ast.Or - AST_SUB_TYPE_NAME = 5, // ast.Name - AST_SUB_TYPE_TUPLE = 6, // ast.Tuple - AST_SUB_TYPE_SUBSCRIPT = 7, // ast.Subscript - AST_SUB_TYPE_STARRED = 8, // ast.Starred - AST_SUB_TYPE_UNKNOWN = 0xFF // Error -}; - -// define the parse target type -enum ParseTargetTypeDef { - PARSE_TARGET_FUNCTION = 0, // function - PARSE_TARGET_METHOD = 1, // method - PARSE_TARGET_OBJECT_INSTANCE = 2, // object instance - PARSE_TARGET_UNKNOW = 0xFF // ERROR TYPE -}; - -// define python module name -const char PYTHON_MOD_PARSE_MODULE[] = "mindspore._extends.parse"; -const char PYTHON_MOD_PARSE_OBJECT_FUNCTION[] = "parse_cb"; -const char PYTHON_MOD_RESOLVE_FUNCTION[] = "resolve_symbol"; -const char PYTHON_MOD_RESOLVE_GET_OBJ_KEY[] = "get_object_key"; -const char PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER[] = "is_class_member"; -const char PYTHON_MOD_RESOLVE_GET_OBJ_TYPE[] = "get_obj_type"; -const char PYTHON_MOD_GET_OBJ_ID[] = "get_obj_id"; -const char PYTHON_MOD_GET_CLASS_INSTANCE_TYPE[] = "get_class_instance_type"; -const char PYTHON_MOD_CREATE_OBJ_INSTANCE[] = "create_obj_instance"; -const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes"; -const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods"; -const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; -const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; -const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; -const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; - -const char PYTHON_PARSE_GET_ARGS[] = "get_args"; -const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values"; -const char PYTHON_PARSE_GET_NODE_TYPE[] = "get_node_type"; -const char PYTHON_PARSE_GET_AST_TYPE[] = "get_ast_type"; -const char PYTHON_PARSE_GET_NAMESPACE_SYMBOL[] = "get_namespace_symbol"; -const char PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL[] = "get_ast_namespace_symbol"; -const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol"; -const char PYTHON_PARSE_GET_LOCATION[] = "get_location"; -const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement"; -const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope"; -const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; - -const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; -const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; -const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input"; - -// define the common name -const char NAMED_PRIMITIVE_ITER[] = "iter"; -const char NAMED_PRIMITIVE_NEXT[] = "next"; -const char NAMED_PRIMITIVE_GETITEM[] = "getitem"; -const char NAMED_PRIMITIVE_SETITEM[] = "setitem"; -const char NAMED_PRIMITIVE_HASNEXT[] = "hasnext"; -const char NAMED_PRIMITIVE_BOOL[] = "bool"; // bool: P.identity -const char NAMED_PRIMITIVE_MAKETUPLE[] = "make_tuple"; -const char NAMED_PRIMITIVE_MAKELIST[] = "make_list"; -const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice"; -const char NAMED_PRIMITIVE_MAKEDICT[] = "make_dict"; -const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call"; - -// define NAMED_PRIMITIVE_GETATTR "getattr" -// define python inline attr -const char PYTHON_GET_METHOD_SELF_CLASS[] = "__self__"; -const char PYTHON_GET_OBJ_DESC[] = "__str__"; - -const char PYTHON_EXTERN_PARSE_METHOD[] = "__parse_method__"; -const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags"; - -// define the parse constant -const int MAX_COMPARISON_OPS_SUPPORTED = 1; -const char CUSTOM_BPROP_NAME[] = "bprop"; - -// define the Namespace name -const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace -const char RESOLVE_NAMESPACE_NAME_CLASS_MEMBER[] = "ClassMember"; // for class member namespace -const char RESOLVE_NAMESPACE_NAME_SYMBOL_STR[] = "SymbolStr"; // for symbol str namespace -const char RESOLVE_NAMESPACE_NAME_COMMON_OPS[] = "CommonOPS"; // for common ops, eg: hasnext, next -const char RESOLVE_NAMESPACE_NAME_MODULE[] = "Module"; // fro Module namespace - -// define Resolve type -enum ResolveTypeDef : int { - RESOLVE_TYPE_NONE = 0, // resolve None - RESOLVE_TYPE_FUNCTION = 1, // reslove function - RESOLVE_TYPE_METHOD = 2, // resolve class method - RESOLVE_TYPE_CLASS_TYPE = 3, // resolve class type - RESOLVE_TYPE_CLASS_INSTANCE = 4, // resolve the class instance of common class - RESOLVE_TYPE_INVALID = 0xFF // resolve invalid -}; - -// define the class instance detail type When the type is RESOLVE_TYPE_CLASS_INSTANCE -enum ClassInstanceTypeDef { - CLASS_INSTANCE_TYPE_CELL = 0, // class instance type is Cell - CLASS_INSTANCE_TYPE_PRIMITIVE = 1, // class instance type is Primitive - CLASS_INSTANCE_TYPE_INVALID = 0xFF -}; - -// Convert python object to ValuePtr -bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false); - -// Convert python obj to graph -FuncGraphPtr ConvertToFuncGraph(const py::object &obj, - const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); - -// Parse the python object to graph -FuncGraphPtr ParsePythonCode(const py::object &obj, - const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_PARSE_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.cc b/mindspore/ccsrc/pipeline/parse/python_adapter.cc deleted file mode 100644 index df2f7d0d45..0000000000 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/parse/python_adapter.h" -#include -#include -#include - -namespace mindspore { -namespace parse { -namespace python_adapter { -// python scoped env, should only have one scoped_ instance -static std::shared_ptr scoped_ = nullptr; -// true: start process from python, false: start process from c++ -static bool python_env_ = false; -static bool use_signature_in_resolve_ = true; -void ResetPythonScope() { scoped_ = nullptr; } -void set_use_signature_in_resolve(bool use_signature) noexcept { use_signature_in_resolve_ = use_signature; } -bool UseSignatureInResolve() { return use_signature_in_resolve_; } -void set_python_env_flag(bool python_env) noexcept { python_env_ = python_env; } -bool IsPythonEnv() { return python_env_; } -void SetPythonPath(const std::string &path) { - // load the python module path - (void)python_adapter::set_python_scoped(); - py::module sys = py::module::import("sys"); - py::list sys_path = sys.attr("path"); - - // check the path is exist? - bool is_exist = false; - for (size_t i = 0; i < sys_path.size(); i++) { - std::string path_str = py::cast(sys_path[i]); - if (path_str == path) { - is_exist = true; - } - } - if (!is_exist) { - (void)sys_path.attr("append")(path.c_str()); - } -} - -std::shared_ptr set_python_scoped() { - // if start process from python, no need set the python scope. - if (!python_env_) { - if ((Py_IsInitialized() == 0) && (scoped_ == nullptr)) { - scoped_ = std::make_shared(); - } - } - return scoped_; -} - -// return the module of python -py::module GetPyModule(const std::string &module) { - if (!module.empty()) { - return py::module::import(module.c_str()); - } else { - return py::none(); - } -} - -// Get the obj of attr -py::object GetPyObjAttr(const py::object &obj, const std::string &attr) { - if (!attr.empty() && !py::isinstance(obj)) { - if (py::hasattr(obj, attr.c_str())) { - return obj.attr(attr.c_str()); - } - MS_LOG(DEBUG) << "Obj have not the attr: " << attr; - } - return py::none(); -} - -py::object GetPyFn(const std::string &module, const std::string &name) { - (void)python_adapter::set_python_scoped(); - if (!module.empty() && !name.empty()) { - py::module mod = py::module::import(module.c_str()); - py::object fn = mod.attr(name.c_str()); - return fn; - } - return py::none(); -} - -} // namespace python_adapter -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.h b/mindspore/ccsrc/pipeline/parse/python_adapter.h deleted file mode 100644 index 98adcd4f73..0000000000 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_PYTHON_ADAPTER_H_ -#define PIPELINE_PARSE_PYTHON_ADAPTER_H_ -#include -#include -#include - -#include "pybind11/embed.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -#include "pipeline/parse/parse_base.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parse { -// A utility to call python interface -namespace python_adapter { -py::module GetPyModule(const std::string &module); -py::object GetPyObjAttr(const py::object &obj, const std::string &attr); -template -py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) { - if (!method.empty() && !py::isinstance(obj)) { - return obj.attr(method.c_str())(args...); - } - return py::none(); -} - -// call python function of module -template -py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) { - if (!function.empty() && !py::isinstance(mod)) { - return mod.attr(function.c_str())(args...); - } - return py::none(); -} - -// turn off the signature when ut use parser to construct a graph. -void set_use_signature_in_resolve(bool use_signature) noexcept; -bool UseSignatureInResolve(); - -std::shared_ptr set_python_scoped(); -void ResetPythonScope(); -bool IsPythonEnv(); -void SetPythonPath(const std::string &path); -void set_python_env_flag(bool python_env) noexcept; -py::object GetPyFn(const std::string &module, const std::string &name); -// Call the python function -template -py::object CallPyFn(const std::string &module, const std::string &name, T... args) { - (void)set_python_scoped(); - if (!module.empty() && !name.empty()) { - py::module mod = py::module::import(module.c_str()); - py::object fn = mod.attr(name.c_str())(args...); - return fn; - } - return py::none(); -} -} // namespace python_adapter -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_PYTHON_ADAPTER_H_ diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc deleted file mode 100644 index 87c2f78b42..0000000000 --- a/mindspore/ccsrc/pipeline/parse/resolve.cc +++ /dev/null @@ -1,324 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/parse/resolve.h" - -#include -#include -#include -#include - -#include "ir/param_value_py.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/python_adapter.h" -#include "utils/any.h" -#include "operator/ops.h" -#include "optimizer/opt.h" -#include "optimizer/irpass.h" -#include "./common.h" - -namespace mindspore { -namespace parse { -abstract::AbstractBasePtr ClassObject::ToAbstract() { - ClassPtr cls_ptr = ParseDataClass(obj()); - auto abs_scalar = std::make_shared(); - abs_scalar->set_type(std::make_shared()); - abs_scalar->set_value(cls_ptr); - - AbstractBasePtrList args_spec_list = {abs_scalar}; - auto func_ptr = std::make_shared(prim::kPrimMakeRecord); - return std::make_shared(func_ptr, args_spec_list); -} - -abstract::AbstractBasePtr ClassType::ToAbstract() { - auto abs_scalar = - std::make_shared(shared_from_base(), std::make_shared()); - AbstractBasePtrList args_spec_list = {abs_scalar}; - - auto func_ptr = std::make_shared(prim::kPrimCreateInstance); - auto ret_val = std::make_shared(func_ptr, args_spec_list); - ret_val->set_value_desc(ToString()); - return ret_val; -} - -// call python PYTHON_MOD_RESOLVE_FUNCTION interface to resolve the symbol in corresponding namespace -bool SymbolResolver::Resolve() { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - - py::object obj = namespace_->obj(); - std::string symbol = symbol_->symbol(); - if (py::isinstance(obj)) { - MS_LOG(ERROR) << "Unresolved symbol: " << symbol; - return false; - } - result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol)); - return true; -} - -namespace { -// argument obj should be python Parameter object -// it will be converted to Parameter node here -AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { - MS_EXCEPTION_IF_NULL(func_graph); - - // parameter object should not be none - if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null."; - } - - if (!py::hasattr(obj, "name")) { - MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj"; - } - - // get the parameter name from parameter object - auto name_attr = python_adapter::GetPyObjAttr(obj, "name"); - if (py::isinstance(name_attr)) { - MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; - } - - std::string param_name = py::cast(name_attr); - auto top_graph = Parser::GetTopFuncGraph(); - // if the parameter node has been created , return it - AnfNodePtr para_node = nullptr; - for (auto const ¶m : top_graph->parameters()) { - auto param_node = dyn_cast(param); - if (param_node != nullptr && param_node->name() == param_name) { - para_node = param; - break; - } - } - if (para_node == nullptr) { - auto node = top_graph->AddWeightParameter(param_name); - auto param_value_new = std::make_shared(obj); - node->set_default_param(param_value_new); - - // set_abstract for parameter - auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); - ValuePtr converted = nullptr; - (void)ConvertData(to_convert, &converted); - bool broaden = true; - node->set_abstract(abstract::FromValue(converted, broaden)); - - para_node = node; - } - auto iter = func_graph->make_ref_params().find(para_node); - if (iter == func_graph->make_ref_params().end()) { - AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node); - - AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); - AnfNodePtr ref_key = NewValueNode(std::make_shared(param_name)); - AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node}); - func_graph->make_ref_params()[para_node] = ref_node; - func_graph->add_parameter_obj_node(ref_node); - return ref_node; - } else { - return iter->second; - } -} - -bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { - AnfNodePtr output = nullptr; - if (py::hasattr(obj, "__parameter__")) { - auto param = ResolveParameterObj(func_graph, obj); - if (param == nullptr) { - MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr"; - return false; - } - MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString(); - - output = param; - } else if (py::hasattr(obj, "__parameter_tuple__")) { - auto tuple = obj.cast(); - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t it = 0; it < tuple.size(); ++it) { - AnfNodePtr out = nullptr; - bool success = ResolveObjectToNode(func_graph, tuple[it], &out); - if (!success) { - MS_LOG(ERROR) << "Resolve object to node failed"; - return false; - } - args.push_back(out); - } - output = NewCNode(args, func_graph); - } else { - ValuePtr convert_result = nullptr; - bool converted = ConvertData(obj, &convert_result, parse::python_adapter::UseSignatureInResolve()); - if (!converted) { - MS_LOG(ERROR) << "Convert data failed"; - return false; - } - MS_EXCEPTION_IF_NULL(convert_result); - output = NewValueNode(convert_result); - if (convert_result->isa()) { - output = GetMixedPrecisionCastHelp(func_graph, output); - } - } - *node = output; - return true; -} - -bool IsAllGraphInValueSequence(const std::vector &value_vec) { - for (auto &elem : value_vec) { - if (elem->isa() || elem->isa()) { - const auto &vec = GetValue>(elem); - auto is_graph = IsAllGraphInValueSequence(vec); - if (!is_graph) { - return false; - } - } else if (!elem->isa()) { - return false; - } - } - return true; -} - -AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, - const std::vector &value_vec) { - std::vector nodes; - nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - for (auto &elem : value_vec) { - AnfNodePtr node = nullptr; - if (elem->isa() || elem->isa()) { - const auto &vec = GetValue>(elem); - node = TransformToMakeTupleNodes(manager, func_graph, vec); - } else if (elem->isa()) { - FuncGraphPtr new_fg = elem->cast(); - manager->AddFuncGraph(new_fg); - node = NewValueNode(new_fg); - } else { - MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString(); - } - nodes.emplace_back(node); - } - auto cnode = func_graph->NewCNode(nodes); - return cnode; -} - -// transform the ValueTuple or ValueList of graph node to make tuple of const graph node -bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, - const ValueNodePtr &value_node, AnfNodePtr *const transformed) { - MS_EXCEPTION_IF_NULL(value_node); - const auto &value_vec = GetValue>(value_node->value()); - if (!IsAllGraphInValueSequence(value_vec)) { - return false; - } - - // The celllist or ordered_cell will be parsed as valuetuple of const graph in it, - // So if has graph in list, try to replace the node with make tuple of graph value node. - // we do this because the graphmanger won't investigate the graph inside valuetuple, - // change the vector of graph to be make_tuple of graph value node - auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); - // replace the ret ptr to be make tuple of graph value node - *transformed = node_tuple_graphs; - - return true; -} -} // namespace - -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, - const AnfNodePtr &node) { - if (node->func_graph() == nullptr || manager == nullptr) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; - } - SymbolResolver symbol_resolver(name_space, symbol, node); - if (!symbol_resolver.Resolve()) { - MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - - py::object obj = symbol_resolver.result(); - ScopeGuard scope_guard(node->scope()); - AnfNodePtr resolved_node = nullptr; - TraceManager::DebugTrace(std::make_shared(node->debug_info())); - bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node); - if (!success) { - MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - if (IsValueNode(resolved_node)) { - auto new_fg = GetValueNode(resolved_node); - manager->AddFuncGraph(new_fg); - } - - // if the constant node is constant of vector of graph ,add graph to manager - if (IsValueNode(resolved_node) || IsValueNode(resolved_node)) { - (void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast(), - &resolved_node); - } - - TraceManager::EndTrace(); - return resolved_node; -} - -namespace { -opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { - opt::OptPassGroupMap map({ - {"resolve", - { - // for resolve and getattr primitive; - irpass.resolver_resolve_, - irpass.resolver_getattr_, - }}, - }); - return map; -} -} // namespace - -bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) { - if (func_graph == nullptr || res == nullptr) { - MS_LOG(ERROR) << "func_graph or resource is null"; - return false; - } - opt::irpass::ResolveIRPassLib irpass; - opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass)); - - (void)parse::python_adapter::set_python_scoped(); - - MS_EXCEPTION_IF_NULL(opt_resolve); - (void)opt_resolve->step(func_graph, use_profile); - return true; -} - -bool ResolveAll(const FuncGraphManagerPtr &manager) { - if (manager == nullptr) { - MS_LOG(ERROR) << "func graph manager is null"; - return false; - } - - if (manager->roots().size() > 1) { - MS_LOG(WARNING) - << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs" - "called from root graph, so it's not necessary to pass all graphs as roots. " - "Please ensure your usage."; - } - // should not use pipeline::Resource as Resource::Clean will clean some - // global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop - // fail as valid scope has been cleaned. - auto res = std::make_shared(); - res->set_manager(manager); - - auto roots = manager->roots(); - for (auto &fg : roots) { - bool ret = ResolveFuncGraph(fg, res, false); - if (!ret) { - MS_EXCEPTION_IF_NULL(fg); - MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed"; - } - } - return true; -} -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/resolve.h b/mindspore/ccsrc/pipeline/parse/resolve.h deleted file mode 100644 index df5c54855f..0000000000 --- a/mindspore/ccsrc/pipeline/parse/resolve.h +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_RESOLVE_H_ -#define PIPELINE_PARSE_RESOLVE_H_ - -#include -#include -#include "ir/anf.h" -#include "ir/manager.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "utils/log_adapter.h" - -// forward declaration of ResourceBase -namespace mindspore { -namespace pipeline { -class ResourceBase; -using ResourceBasePtr = std::shared_ptr; -} // namespace pipeline -} // namespace mindspore - -namespace mindspore { -namespace parse { - -// NameSpace class for resolving python code. -class NameSpace : public Named { - public: - NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {} - ~NameSpace() override = default; - MS_DECLARE_PARENT(NameSpace, Named); - - py::object obj() { return obj_; } - std::string module() { return module_; } - abstract::AbstractBasePtr ToAbstract() override { - return std::make_shared(shared_from_base(), std::make_shared()); - } - - private: - // namespace of the module - std::string module_; - // namespace object - py::object obj_; -}; -using NameSpacePtr = std::shared_ptr; - -// Symbol in NameSpace or Class which shall be resolved. -class Symbol : public Named { - public: - explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {} - explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {} - - ~Symbol() override = default; - MS_DECLARE_PARENT(Symbol, Named); - - std::string symbol() { return symbol_; } - abstract::AbstractBasePtr ToAbstract() override { - return std::make_shared(shared_from_base(), std::make_shared()); - } - - private: - std::string symbol_; -}; -using SymbolPtr = std::shared_ptr; - -// PyObjectWrapper class wrappers resolved python object for further processing. -class PyObjectWrapper : public Named { - public: - explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} - ~PyObjectWrapper() override = default; - MS_DECLARE_PARENT(PyObjectWrapper, Named); - py::object obj() { return obj_; } - - private: - // the object that needs to be resolved - py::object obj_; -}; - -// ClassObject class wrappers dataclass -class ClassObject : public PyObjectWrapper { - public: - explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") - : PyObjectWrapper(obj, name) {} - ~ClassObject() override = default; - MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); - abstract::AbstractBasePtr ToAbstract() override; -}; - -// ClassType class wrappers class name in python -class ClassType : public PyObjectWrapper { - public: - explicit ClassType(const py::object &obj, const std::string name = "Python class type") - : PyObjectWrapper(obj, name) {} - ~ClassType() override = default; - MS_DECLARE_PARENT(ClassType, PyObjectWrapper); - abstract::AbstractBasePtr ToAbstract() override; -}; - -// SymbolResolver class for resolving symbol extracted from AnfNode. -class SymbolResolver { - public: - SymbolResolver(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) - : namespace_(name_space), symbol_(symbol), resolved_node_(node) {} - - ~SymbolResolver() = default; - - // resolve symbol in namespace and save it in result_; - bool Resolve(); - - NameSpacePtr get_namespace() { return namespace_; } - - SymbolPtr symbol() { return symbol_; } - - py::object &result() { return result_; } - - AnfNodePtr resolved_node() { return resolved_node_; } - - // Resolve result - py::object result_; - - private: - // namespace where the symbol locates - NameSpacePtr namespace_; - // the symbol that needs to be resovled - SymbolPtr symbol_; - // the node that has been resolved - AnfNodePtr resolved_node_; -}; -using SymbolResolverPtr = std::shared_ptr; -// Resolve symbol in namespace. -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, - const AnfNodePtr &node); - -// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). -bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); - -// Resolve all graphs in manager which is defined outside of pipeline::Resource. -// Mainly used for test cases or resolve graphs which will not be managed by manager. -bool ResolveAll(const FuncGraphManagerPtr &manager); - -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_RESOLVE_H_ diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc deleted file mode 100644 index f6cfd6362c..0000000000 --- a/mindspore/ccsrc/pipeline/pass.cc +++ /dev/null @@ -1,342 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/pass.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/func_graph_cloner.h" -#include "debug/anf_ir_utils.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/resource.h" -#include "pipeline/validator.h" -#include "optimizer/optimizer.h" -#include "optimizer/cse.h" -#include "optimizer/graph_kernel_reuse.h" -#include "optimizer/clean.h" -#include "optimizer/irpass.h" -#include "optimizer/control_depend.h" -#include "parallel/step_parallel.h" -#include "parallel/step_auto_parallel.h" -#include "parallel/allreduce_fusion/step_allreduce_fusion.h" -#include "utils/any.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace pipeline { -using OptPassGroupMap = opt::OptPassGroupMap; -using Optimizer = opt::Optimizer; -using CompileGraphs = compile::CompileGraphs; -using abstract::AnalysisResult; -using mindspore::abstract::AnalysisContextPtr; -using mindspore::validator::Validate; - -bool SimplifyDataStructuresPass(const ResourcePtr &res) { - MS_EXCEPTION_IF_NULL(res->func_graph()); - - FuncGraphPtr func_graph = res->func_graph(); - bool changed = opt::SimplifyDataStructures(func_graph, res->manager()); - - abstract::AbstractBasePtrList args_spec; - auto parameters = func_graph->parameters(); - (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), - [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); - if (changed) { - FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); - res->set_func_graph(new_fg); - } - res->set_args_spec(args_spec); - return true; -} - -namespace { -OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig a_1 = opt::OptPassConfig({ - irpass.switch_simplify_, - - // Safe inlining - irpass.inline_, - irpass.partial_eliminate_, - irpass.replace_applicator_, - - // Specialization - irpass.specialize_transform_, - - // Miscellaneous - irpass.item_tuple_eliminate_, - irpass.env_get_item_eliminate_, - irpass.cast_eliminate_, - irpass.reshape_eliminate_, - irpass.reduce_eliminate_, - irpass.tile_eliminate_, - irpass.transpose_eliminate_, - irpass.minmaximum_grad_, - irpass.get_make_ref_eliminate_, - - // Arithmetic simplifications - irpass.arithmetic_simplify_, - irpass.addn_zero_filter_, - irpass.adjust_all_reduce_mul_add_, - - // Safe inlining - irpass.inline_, - }); - opt::OptPassConfig a_2 = opt::OptPassConfig({ - irpass.merge_addn_, - irpass.float_tuple_getitem_switch_, - irpass.float_env_getitem_switch_, - irpass.incorporate_getitem_set_, - irpass.incorporate_call_, - irpass.incorporate_call_switch_, - irpass.incorporate_env_getitem_, - irpass.incorporate_env_getitem_switch_, - irpass.new_env_get_item_, - irpass.depend_value_elim_, - }); - opt::OptPassConfig a_3 = opt::OptPassConfig({ - irpass.arithmetic_simplify2_, - irpass.same_eliminate_, - irpass.check_bprop_eliminate_, - irpass.replace_applicator_, - }); - opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); - opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); - opt::irpass::ResolveIRPassLib resolve_irpass; - - opt::OptPassConfig resolve_pass = - opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, - irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); - - OptPassGroupMap map_a({{"a_1", a_1}, - {"a_2", a_2}, - {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, - {"parallel", opt::OptPassConfig(parallel::StepParallel)}, - {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, - {"virtual_dataset", virtual_dataset}, - {"grad", grad}, - {"resolve", resolve_pass}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - {"cse", opt::OptPassConfig(opt::CSE(false))}, - {"a_3", a_3}}); - - return map_a; -} - -OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig b_1 = opt::OptPassConfig({ - irpass.zero_like_fill_zero_, - irpass.item_tuple_eliminate_, - irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, - irpass.inline_, - irpass.special_op_eliminate_, - irpass.get_make_ref_eliminate_, - }); - opt::OptPassConfig b_2 = opt::OptPassConfig({ - irpass.replace_refkey_by_param_, - irpass.make_ref_eliminate_, - irpass.get_ref_param_eliminate_, - irpass.indexed_slices_eliminate_, - }); - OptPassGroupMap map({ - {"b_1", b_1}, - {"b_2", b_2}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - {"cse", opt::OptPassConfig(opt::CSE(false))}, - }); - return map; -} - -OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig interface_fusion = opt::OptPassConfig({ - irpass.mark_interface_fusion_, - }); - OptPassGroupMap map({ - {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, - {"interface_fusion", interface_fusion}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - {"cse", opt::OptPassConfig(opt::CSE(false))}, - }); - return map; -} - -OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig elim_1 = opt::OptPassConfig({ - irpass.addn_eliminate_, - irpass.incorporate_getitem_from_param_, - }); - opt::OptPassConfig elim_2 = opt::OptPassConfig({ - irpass.unused_parameter_eliminate_, - irpass.unused_output_eliminate_, - }); - OptPassGroupMap map({ - {"elim_1", elim_1}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - {"elim_2", elim_2}, - }); - return map; -} - -OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) { - return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}}); -} - -OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true); - OptPassGroupMap map({ - {"control_group", control_group}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - }); - return map; -} - -OptPassGroupMap GetInferenceOptPreparePhases() { - opt::irpass::InferenceOptPrepareLib irpass; - auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); - opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}}); - return prepare_map; -} - -OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); - OptPassGroupMap map({{"prepare_group", prepare_group}}); - return map; -} - -static std::unordered_map> g_pass_opts = {}; - -void InitOpt(const ResourcePtr &res) { - if (g_pass_opts.size() == 0) { - opt::irpass::OptimizeIRPassLib irpass; - g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); - g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); - g_pass_opts["opt_graph_kernel_a"] = - Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); - g_pass_opts["opt_graph_kernel_b"] = - Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); - g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); - g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); - g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { - g_pass_opts["opt_graph_kernel_a"]->set_enable(false); - g_pass_opts["opt_graph_kernel_b"]->set_enable(false); - } - } -} -} // namespace - -void ReclaimOptimizer() { - for (auto &opt : g_pass_opts) { - opt.second = nullptr; - } - g_pass_opts.clear(); -} - -bool OptPassGroup(const ResourcePtr &res, const std::string &name) { - if (res->func_graph() == nullptr) { - MS_LOG(ERROR) << "Opt passes int error"; - return false; - } - - FuncGraphPtr func_graph = res->func_graph(); - MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", " - << func_graph->get_return()->DebugString(true); - InitOpt(res); - if (g_pass_opts.find(name) != g_pass_opts.end()) { - res->set_func_graph(g_pass_opts[name]->step(func_graph)); - } - // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to - // res->args_spec_ yet. So if any later pass or action want to use that variable, it should be set here. - return true; -} - -bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } -bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } -bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } -bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } -bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } -bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } - -bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } - -bool AddControlDependPass(const ResourcePtr &res) { - FuncGraphPtr func_graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - - if (func_graph->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { - opt::AddControlDepend(func_graph); - } - for (auto fg : func_graph->func_graphs_used_total()) { - MS_EXCEPTION_IF_NULL(fg); - if (fg->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { - opt::AddControlDepend(fg); - } - } - return true; -} - -bool CconvPass(const ResourcePtr &res) { - MS_EXCEPTION_IF_NULL(res->func_graph()); - FuncGraphPtr func_graph = res->func_graph(); - FuncGraphPtr new_fg = LiftingClone(func_graph); - res->set_func_graph(new_fg); - return true; -} - -bool ValidatePass(const ResourcePtr &res) { - MS_EXCEPTION_IF_NULL(res->func_graph()); - FuncGraphPtr func_graph = res->func_graph(); - Validate(func_graph); - return true; -} - -bool InferenceOptPreparePass(const ResourcePtr &res) { - FuncGraphPtr func_graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - auto prepare_map = GetInferenceOptPreparePhases(); - auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map); - (void)infer_opt_prepare->step(func_graph, false); - return true; -} - -std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, - {"opt_a", OptPassAGroup}, - {"opt_b", OptPassBGroup}, - {"cconv", CconvPass}, - {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, - {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, - {"add_control_depend", AddControlDependPass}}; - -std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, - {"opt_a", OptPassAGroup}, - {"opt_b", OptPassBGroup}, - {"add_control_depend", AddControlDependPass}, - {"opt_control", ControlGroup}, - {"opt_prepare", PrepareGroup}, - {"cconv", CconvPass}}; - -std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pass.h b/mindspore/ccsrc/pipeline/pass.h deleted file mode 100644 index 9064df52ee..0000000000 --- a/mindspore/ccsrc/pipeline/pass.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_PASS_H_ -#define MINDSPORE_CCSRC_PIPELINE_PASS_H_ - -#include -#include -#include -#include -#include "pipeline/resource.h" - -namespace mindspore { -namespace pipeline { -using PassItem = std::pair>; - -extern std::vector kGePasses; -extern std::vector kVmPasses; -extern std::vector kPynativePasses; - -bool CconvPass(const ResourcePtr &res); -bool ValidatePass(const ResourcePtr &res); -bool ConvertPrepareAdapt(const ResourcePtr &res); -bool AddControlDependPass(const ResourcePtr &res); -bool InferenceOptPreparePass(const ResourcePtr &res); -void ReclaimOptimizer(); -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_PASS_H_ diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc deleted file mode 100644 index 6abe198f5a..0000000000 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ /dev/null @@ -1,980 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/pipeline.h" - -#include -#include -#include -#include -#include - -#include "ir/param_value_py.h" -#include "pipeline/pass.h" -#include "pipeline/parse/data_converter.h" -#include "optimizer/ad/dfunctor.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" -#include "utils/config_manager.h" -#include "utils/convert_utils.h" -#include "utils/utils.h" -#include "vm/segment_runner.h" -#include "parallel/context.h" -#include "parallel/graph_util/get_parallel_info.h" -#include "device/kernel_runtime_manager.h" -#include "debug/trace.h" -#include "pynative/pynative_execute.h" -#include "optimizer/py_pass_manager.h" - -#if (ENABLE_GE || ENABLE_D) -#include "pipeline/pipeline_ge.h" -#include "transform/convert.h" -#include "transform/df_graph_manager.h" -#endif - -namespace mindspore { -// namespace to support intermediate representation definition -namespace pipeline { -using Tensor = mindspore::tensor::Tensor; -using MetaTensor = mindspore::tensor::MetaTensor; -using TensorOrderMap = std::map>; -using mindspore::abstract::AbstractTensor; -using mindspore::abstract::AbstractTensorPtr; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractTuplePtr; - -const char IR_TYPE_ANF[] = "anf_ir"; -const char IR_TYPE_ONNX[] = "onnx_ir"; -const char IR_TYPE_BINARY[] = "binary_ir"; - -ExecutorPyPtr ExecutorPy::executor_ = nullptr; -std::mutex ExecutorPy::instance_lock_; - -std::unordered_map - g_args_cache; - -namespace { -std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) { - std::ostringstream oss; - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(EXCEPTION) << "ms_context is nullptr"; - } - auto save_graphs_path = ms_context->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - oss << save_graphs_path << "/" << stage_idx << "_" << action_name; - return oss.str(); -} -} // namespace - -py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults) { - MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); - abstract::AbstractBasePtrList args_spec; - - for (auto arg : defaults) { - if (py::isinstance(arg.second)) { - MS_LOG(EXCEPTION) << "GenerateKey failed, argument input should not be py::module"; - } - ValuePtr converted = nullptr; - if (!parse::ConvertData(arg.second, &converted)) { - MS_LOG(EXCEPTION) << "GenerateKey convert arg failed"; - } - args_spec.push_back(abstract::FromValue(converted, true)); - } - if (g_args_cache.count(args_spec) == 0) { - static int key = 0; - MS_LOG(INFO) << "Start new args and compile key:" << key; - g_args_cache[args_spec] = key++; - } - auto argSpec = py::tuple(2); - argSpec[0] = name; - argSpec[1] = g_args_cache[args_spec]; - return argSpec; -} - -py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs) { - MS_LOG(DEBUG) << "Verify args size:" << inputs.size(); - if (inputs.size() != input_signature.size()) { - MS_LOG(ERROR) << "Signature size not equal to args size"; - return false; - } - - size_t count = 0; - for (auto arg_obj : inputs) { - if (py::hasattr(arg_obj, PYTHON_TENSOR_FLAG)) { - MS_LOG(DEBUG) << "Verify Tensor"; - std::shared_ptr m_tensor = arg_obj.cast>(); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Verify Tensor error, get ptr is null"; - return false; - } - std::shared_ptr sig = input_signature[count].cast>(); - std::vector sig_shape = sig->shape(); - TypePtr sig_type = sig->Dtype(); - - std::vector tensor_shape = m_tensor->shape_c(); - if (tensor_shape != sig_shape) { - MS_LOG(ERROR) << "Python input shape is incompatible with input_signature"; - return false; - } - - if (*m_tensor->Dtype() != *sig_type) { - MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString() << ") incompatible with input_signature(" - << sig_type->ToString() << ")"; - return false; - } - } - count++; - } - - return true; -} - -ExecutorPy::ExecutorPy() {} - -ResourcePtr ExecutorPy::GetResource(const std::string &phase) { - MS_LOG(DEBUG) << "Phase size:" << info_.size(); - if (info_.count(phase) == 0) { - return nullptr; - } - return info_[phase]->resource; -} - -FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { - if (info_.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - return info_[phase]->func_graph; -} - -std::size_t ExecutorPy::ArgListSize(const std::string &phase) { - if (info_.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - return info_[phase]->arg_list_size; -} - -compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { - ResourcePtr res = GetResource(phase); - MS_EXCEPTION_IF_NULL(res); - if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is()) { - return res->results()[kOutput].cast(); - } - MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput; - return nullptr; -} - -bool ExecutorPy::HasCompiled(const std::string &phase) const { - if (info_.count(phase) == 0) { - return false; - } - return true; -} - -py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) { - FuncGraphPtr fg_ptr = GetFuncGraph(phase); - if (fg_ptr == nullptr) { - for (auto &item : info_) { - MS_LOG(DEBUG) << "Phase key is: " << item.first; - } - MS_LOG(EXCEPTION) << "Can not find func graph " << phase; - } - - if (ir_type == IR_TYPE_ANF) { - std::string proto_str = GetFuncGraphProtoString(fg_ptr); - if (proto_str.empty()) { - MS_LOG(EXCEPTION) << "Graph proto is empty."; - } - return proto_str; - } - - if (ir_type == IR_TYPE_ONNX) { - std::string proto_str = GetOnnxProtoString(fg_ptr); - if (proto_str.empty()) { - MS_LOG(EXCEPTION) << "Graph proto is empty."; - } - return proto_str; - } - - if (ir_type == IR_TYPE_BINARY) { - std::string proto_str = GetBinaryProtoString(fg_ptr); - if (proto_str.empty()) { - MS_LOG(EXCEPTION) << "Graph proto is empty."; - } - return proto_str; - } - - MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; -} - -py::dict ExecutorPy::GetParameterLayout(const std::string &phase) { - MS_LOG(DEBUG) << "GetParameterLayout!"; - std::string layout_graph = phase + kStepParallelGraph; - auto graph = GetFuncGraph(layout_graph); - return mindspore::parallel::GetParameterLayout(graph); -} - -py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) { - MS_LOG(DEBUG) << "GetCNodeStrategy!"; - std::string layout_graph = phase + kStepParallelGraph; - auto graph = GetFuncGraph(layout_graph); - return mindspore::parallel::GetCNodeStrategy(graph); -} - -py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { - MS_LOG(INFO) << "GetAllreduceFusion!"; - auto graph = GetFuncGraph(phase); - return mindspore::parallel::GetAllreduceFusion(graph); -} - -void ExecutorPy::DelNetRes(const std::string &id) { -#ifdef ENABLE_GE - FinalizeBackend(); -#endif - if (executor_ != nullptr) { - bool flag = false; - auto tmp_info = info_; - for (auto &item : tmp_info) { - if (item.first.find(id) != string::npos) { - MS_LOG(DEBUG) << "Delete network res:" << item.first; - (void)info_.erase(item.first); - flag = true; - } - } - - MS_LOG(DEBUG) << "Delete flag:" << flag; -#ifdef ENABLE_GE - if (flag && info_.size() == 0) { - // because Ge only support one Session exist at the same time ,so we delete the old one - transform::DfGraphManager::GetInstance().DeleteGraphRunner(); - transform::DfGraphManager::GetInstance().EraseAnfGraph(); - transform::DfGraphManager::GetInstance().DeleteGeSession(); - } -#endif - } -} - -void ExecutorPy::ClearRes() { - MS_LOG(INFO) << "Clean executor resource!"; - executor_ = nullptr; -} - -ExecutorPy::~ExecutorPy() { - MS_LOG(INFO) << "Release Executor!"; - ConfigManager::GetInstance().ResetConfig(); -} - -std::map> ExecutorPy::FetchInfoForQuantExport( - const std::string &phase_s) { - FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; - std::map> fake_quant_table; - auto filter = [](AnfNodePtr node) { - return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul)); - }; - std::vector nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter); - auto is_quant_cnode = [](AnfNodePtr node) { - return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) || - IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel); - }; - for (auto node : nodes) { - auto cnode = node->cast(); - if (cnode == nullptr || cnode->size() != 3) { - continue; - } - auto x = cnode->input(1); - auto weight = cnode->input(2); - if (!is_quant_cnode(weight)) { - continue; - } - // get parameter weight's name - cnode = weight->cast(); - auto weight_node = cnode->input(2); - if (!weight_node->isa()) { - continue; - } - auto weight_name = weight_node->cast()->name(); - // find the fakequant from input - int count = 0; - const int max_depth = 5; - while (!is_quant_cnode(x)) { - if (count >= max_depth) { - break; - } - cnode = x->cast(); - if (cnode == nullptr || cnode->size() <= 1) { - break; - } - x = cnode->input(1); - count += 1; - } - // get the fakequant parameter minq's name - if (!is_quant_cnode(x)) { - continue; - } - cnode = x->cast(); - if (cnode == nullptr || cnode->size() != 4) { - continue; - } - auto fakequant_min_node = cnode->input(2); - if (!fakequant_min_node->isa()) { - continue; - } - auto fakequant_min_node_name = fakequant_min_node->cast()->name(); - auto quant_op_value = cnode->input(0)->cast()->value(); - if (!quant_op_value->isa()) { - continue; - } - auto quant_op = quant_op_value->cast(); - fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name); - } - - return fake_quant_table; -} - -void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { - // save the graph to ExecutorPy - FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); - std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); - - MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; - info_[phase_s]->func_graph = func_graph; - if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) && - ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) { - MS_LOG(DEBUG) << "Save model parallel parameter layout graph!"; - func_graph = info_[phase_s]->resource->results()[kStepParallelGraph].cast(); - ExecutorInfoPtr executor_info = std::make_shared(); - std::string layout_graph = phase_s + kStepParallelGraph; - executor_info->func_graph = func_graph; - info_[layout_graph] = executor_info; - } else { - MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!"; - } - MS_LOG(INFO) << "End save compiled func graph!"; -} - -void ExecutorPy::SaveCompiledGraphToPb(const std::string &phase_s) { -#ifdef ENABLE_DUMP_IR - // save the graph to file in protobuf format - FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - if (phase_s.empty()) { - MS_LOG(ERROR) << "`phase` is empty '" << phase_s << "'!"; - return; - } - std::string name_prefix = phase_s.substr(0, phase_s.find(".")); - std::string pb_filename = std::string("ms_output_") + name_prefix + ".pb"; - std::string filename = GetFilePathName(pb_filename); - - MS_LOG(INFO) << "Begin saving graph to file <<'" << filename << "' in protobuf formart."; - ChangeFileMode(filename, S_IRWXU); - std::ofstream ofs(filename); - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file '" << filename << "' failed!"; - return; - } - ofs << GetFuncGraphProtoString(func_graph); - ofs.close(); - // set file mode to read only by user - ChangeFileMode(filename, S_IRUSR); - MS_LOG(INFO) << "End saving graph to file in protobuf format"; -#endif -} - -bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { - std::string phase_prefix = GetPhasePrefix(phase_s); - - if (use_vm && phase_prefix == "export") { - MS_LOG(INFO) << "Use ge backend to export geir"; - use_vm = false; - } - return use_vm; -} - -void ExecutorPy::GetGeBackendPolicy() const { - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - std::string backend = ms_context->backend_policy(); - if (backend != "ge") { - MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!"; - } -} - -bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { - MS_LOG(DEBUG) << "Start ExecutorPy compile!"; - if ((!py::isinstance(phase))) { - MS_LOG(ERROR) << "Arg phase must be string."; - return false; - } - // check the arg valid? - if (py::isinstance(obj)) { - MS_LOG(ERROR) << "Find error: parse obj is None."; - return false; - } -#ifdef ENABLE_GE - GetGeBackendPolicy(); -#endif - ExecutorInfoPtr executor_info = std::make_shared(); - std::string phase_s = py::cast(phase); - MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!"; - ResourcePtr resource = std::make_shared(obj); - std::vector p_actions; - - use_vm = ChangeExportGeirUseVmFlag(use_vm, phase_s); - - std::string backend = MsContext::GetInstance()->backend_policy(); - if (use_vm && backend != "ge") { - // Create backend and session - auto backend_ptr = compile::CreateBackend(); - // Connect session to debugger - backend_ptr->SetDebugger(); - resource->results()[kBackend] = backend_ptr; - p_actions = VmPipeline(); - } else { - p_actions = GePipeline(); - } - - std::shared_ptr pip = std::make_shared(resource, FilterActions(p_actions, phase_s)); - - // get the parameters items and add the value to args_spec - abstract::AbstractBasePtrList args_spec; - std::size_t size = args.size(); - for (std::size_t i = 0; i < size; i++) { - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(args[i], &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "Args convert error"; - } - bool broaden = true; - args_spec.push_back(abstract::FromValue(converted, broaden)); - } - - resource->set_args_spec(args_spec); - executor_info->arg_list_size = size; - executor_info->resource = resource; - info_[phase_s] = executor_info; - pip->Run(); - - // save compile graph to file in protobuf format - SaveCompiledGraphToPb(phase_s); - // save the run graph func to MsPipeLine - SaveCompiledGraph(phase_s); - - resource->Clean(); - // Reclaim all resource used by optimizer; - ReclaimOptimizer(); - - MS_LOG(INFO) << "End ExecutorPy compile!"; - return true; -} - -std::vector ExecutorPy::FilterActions(const std::vector &actions, const std::string &phase) { - // phase does not contain 'export_onnx' - if (GetPhasePrefix(phase).find("export_onnx") == std::string::npos) { - return actions; - } - MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'"; - std::vector filtered_actions; - for (const auto &item : actions) { - filtered_actions.emplace_back(item); - if (item.first == "validate") { - break; - } - } - return filtered_actions; -} - -void ExecutorPy::ReleaseResource(const py::object &phase) { - ResourcePtr res = GetResource(py::cast(phase)); - if (res != nullptr) { - res->Clean(); - } - // Reclaim all resource used by optimizer; - ReclaimOptimizer(); -} - -static std::string PrintArgs(const py::tuple &args) { - py::print(args); - return ""; -} - -bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { - bool ret_value = false; - - try { - MS_LOG(DEBUG) << PrintArgs(args); - ret_value = CompileInner(obj, args, phase, use_vm); - } catch (const py::error_already_set &ex) { - // print function call stack info before release - std::ostringstream oss; - trace::TraceGraphEval(); - trace::GetEvalStackInfo(oss); - // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see - // these info from screen, no need to open log file to find these info - py::print(oss.str()); - MS_LOG(ERROR) << oss.str(); - ReleaseResource(phase); - - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - ReleaseResource(phase); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - ReleaseResource(phase); - throw py::value_error(ex); - } catch (const py::index_error &ex) { - ReleaseResource(phase); - throw py::index_error(ex); - } catch (const std::exception &ex) { - ReleaseResource(phase); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - ReleaseResource(phase); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; - } - - return ret_value; -} - -#ifdef ENABLE_LOAD_ANF_IR -// get MindSpore Intermediate Representation File -std::string GetMsIrFile(void) { - std::string file; - const char *path = getenv("MS_IR_FILE"); - if (path == nullptr) { - return file; - } - - char real_path[PATH_MAX] = {0}; - if (realpath(path, real_path) == nullptr) { - MS_LOG(ERROR) << "MS IR path error, " << path; - return file; - } - file = real_path; - return file; -} - -void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, bool *result) { - MS_EXCEPTION_IF_NULL(resource); - MS_EXCEPTION_IF_NULL(result); - - std::string ir_file = GetMsIrFile(); - (void)parse::python_adapter::set_python_scoped(); - if (ir_file.empty()) { - *result = action.second(resource); - return; - } - - // when in loading anf ir mode, action `parse` do nothing - if (action.first == "parse") { - return; - } - - // load MindSpore IR from file - if (action.first == "symbol_resolve") { - MS_LOG(DEBUG) << action.first << " read ir file: " << ir_file; - std::vector graphs = ImportIR(ir_file); - if (graphs.size() == 0) { - MS_LOG(EXCEPTION) << action.first << " read ir file " << ir_file << " failed as no graph found"; - } - auto manager = resource->manager(); - MS_EXCEPTION_IF_NULL(manager); - for (auto &graph : graphs) { - manager->AddFuncGraph(graph); - } - resource->set_func_graph(graphs[0]); - return; - } - - // do normal action when not in `parse` and `symbol_resolve` stage - *result = action.second(resource); -} -#endif - -void Pipeline::Run() { - MS_LOG(INFO) << "Pipeline run"; - MS_EXCEPTION_IF_NULL(resource_); - FuncGraphPtr user_graph = nullptr; - - WITH(MsProfile::GetProfile())[&user_graph, this]() { - int i = 0; - for (auto &action : actions_) { -#ifdef ENABLE_TIMELINE - DumpTime &dump_time = DumpTime::GetInstance(); - dump_time.Record(action.first, GetTime(), true); -#endif - bool result = true; - WITH(MsProfile::GetProfile()->Step(action.first))[&result, &action, this]() { - MS_LOG(DEBUG) << "Action " << action.first << " start ..."; -#ifdef ENABLE_LOAD_ANF_IR - RunPipelineAction(action, resource_, &result); -#else - result = action.second(resource_); -#endif - MS_LOG(DEBUG) << "Action " << action.first << " end."; - }; - if (!result) { - MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first; - } - if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) { - auto graph = resource_->func_graph(); - if (graph != nullptr) { - user_graph = graph; - std::string base_name = GetBaseNameForIR(i, action.first); - - // generate IR file in dot format, which can be converted to svg file using graphviz dot command - draw::Draw(base_name + ".dot", graph); - // generate IR file in human readable format - DumpIR(base_name + ".ir", graph); - - // generate IR file in a heavily commented format, which can also be reloaded - if (action.first != "parse") { - ExportIR(base_name + ".dat", std::to_string(i), graph); - } - } -#ifdef MS_DEBUG - // Dump graph cnode list - MS_LOG(INFO) << "Show CNode list after " << action.first; - graph->DumpCNodeList(); -#endif - } - if (resource_->func_graph() != nullptr) { - auto func_graph = resource_->func_graph(); - if (func_graph->has_flag(GRAPH_FLAG_HAS_EFFECT)) { - func_graph->EraseUnusedNodeInOrder(); - func_graph->CheckOrder(); - for (auto fg : func_graph->func_graphs_used_total()) { - MS_LOG(DEBUG) << "Check order graph " << fg->ToString() << "."; - fg->EraseUnusedNodeInOrder(); - fg->CheckOrder(); - } - } - } - i++; -#ifdef ENABLE_TIMELINE - dump_time.Record(action.first, GetTime(), false); -#endif - } - }; -#ifdef ENABLE_PROFILE - MsProfile::Print(); - MsProfile::Reset(); -#endif - - if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) { - std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); - MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file; - draw::DrawUserFuncGraph(user_graph_file, user_graph); - } - MS_LOG(INFO) << "End"; -} - -void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { - std::size_t size = args.size(); - - for (std::size_t i = 0; i < size; i++) { - py::object arg = args[i]; - auto ms_context = MsContext::GetInstance(); - if (ms_context->backend_policy() == kMsConvert && py::isinstance(arg)) { - MS_LOG(EXCEPTION) << "The " << i << "th arg is numpy array, not tensor."; - } - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(arg, &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; - } - if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa()) { - MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString() - << " is not tensor."; - } - arg_list->push_back(converted); - } - - MS_EXCEPTION_IF_NULL(res); - auto graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(graph); - std::vector graph_params = graph->parameters(); - std::size_t graph_params_size = graph_params.size(); - if ((*arg_list).size() != graph_params_size) { - // maybe some default parameter - for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) { - MS_EXCEPTION_IF_NULL(graph_params[i]); - auto param_ptr = (graph_params[i])->cast(); - if (!param_ptr->has_default()) { - MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param"; - } - auto param_value = std::dynamic_pointer_cast(param_ptr->default_param()); - py::object obj = param_value->value(); - py::object p_value = py::cast(parse::python_adapter::GetPyObjAttr(obj, "default_input")); - (*arg_list).push_back(p_value); - } - } -} - -void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) { - ProcessVmArgInner(args, GetResource(phase), arg_list); -} - -py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { - std::size_t size = args.size(); - if (!py::isinstance(phase)) { - MS_LOG(EXCEPTION) << "Run failed, phase input is not a str"; - } - auto phase_s = py::cast(phase); - std::string backend = MsContext::GetInstance()->backend_policy(); -#ifdef ENABLE_GE - if (backend == "ge") { - return ExecDFGraph(info_, args, phase_s); - } -#else - if (backend == "ms" || backend == "ge") { - auto ret_val = std::make_shared(); - if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) { - if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) { - return *ret_val; - } - } - if (backend == "ge") { - if (args.size() > 0) { - return args[0]; - } - return args; - } - } -#endif - std::size_t full_arg_size = ArgListSize(phase_s); - if (size > full_arg_size) { - MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << full_arg_size; - } - VectorRef arg_list; - ProcessVmArg(args, phase_s, &arg_list); - - compile::VmEvalFuncPtr run = GetVmEvalFunc(phase_s); - if (run == nullptr) { - MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s; - } - - MS_LOG(DEBUG) << "Eval run" << backend; - BaseRef value = (*run)(arg_list); - MS_LOG(DEBUG) << "Run end"; - return BaseRefToPyData(value); -} - -FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase, - const py::object &broadcast_params) { -#if (ENABLE_GE || ENABLE_D) - return BuildDFGraph(info_, init_params, phase, broadcast_params); -#else - return nullptr; -#endif -} - -void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { -#if ENABLE_GE - RunGEInitGraph(init_params, phase); -#endif -} - -bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase, bool need_run) { - std::string name = MsContext::GetInstance()->backend_policy(); -#ifndef NO_DLIB - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->IsTsdOpened() || !ms_context->IsGeInited()) { - (void)InitBackend(); - } -#endif - if (name == kMsConvert || name == kMsVm) { - return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run); - } -#if ENABLE_GE - return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase); -#else - std::string backend = MsContext::GetInstance()->backend_policy(); - if (backend == "ge") { - return true; - } -#endif - return false; -} - -bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, bool need_run) { - MS_LOG(INFO) << "Start InitDataSet Entry"; - std::vector int_input_indexes; - (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), - [](int64_t item) { return static_cast(item); }); - std::vector> int_shapes; - (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes), - [](const std::vector &item) { - std::vector vector_item; - (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item), - [](int64_t inner_item) { return static_cast(inner_item); }); - return vector_item; - }); - auto p_init = std::make_shared("InitDataSetQueue"); - p_init->set_attr("queue_name", MakeValue(queue_name)); - p_init->set_attr("size", MakeValue(static_cast(size))); - p_init->set_attr("batch_size", MakeValue(static_cast(batch_size))); - p_init->set_attr("types", MakeValue(types)); - p_init->set_attr("shapes", MakeValue(int_shapes)); - p_init->set_attr("input_indexes", MakeValue(int_input_indexes)); - - const std::vector empty_str_list; - p_init->set_attr("input_names", MakeValue(empty_str_list)); - p_init->set_attr("output_names", MakeValue(empty_str_list)); - - FuncGraphPtr func_graph = std::make_shared(); - auto app_init = std::make_shared(AnfNodePtrList{NewValueNode(p_init)}, func_graph); - func_graph->set_output(app_init); - auto manager = MakeManager(); - manager->AddFuncGraph(func_graph); - - // AbstractNone indicates there is no output for this apply node. - auto abstract_none = std::make_shared(); - app_init->set_abstract(abstract_none); - - auto backend = compile::CreateBackend(); - MS_EXCEPTION_IF_NULL(backend); - auto convert_fn = backend->convert_fn(); - MS_EXCEPTION_IF_NULL(convert_fn); - // Convert CNodeList to LinConvertResult. - ConfigManager::GetInstance().set_iter_num(1); - auto runner = convert_fn({app_init}, ""); - if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { - backend->Link(runner.graph_id); - } - ConfigManager::GetInstance().set_iter_num(size); - - if (!(*runner.run)) { - // empty function - MS_LOG(EXCEPTION) << "Backend " << backend->name() << " unsupported tdt dataset."; - } - - // launch init dataset runner without inputs and outputs - VectorRef args; - auto fn = runner.run; - if (need_run) { - (void)(*fn)(args); - } - MS_LOG(DEBUG) << "InitDataSetVm End."; - return true; -} - -void ResetOpId() { mindspore::id_generator::reset_id(); } - -void InitHccl() { -#ifdef ENABLE_GE - (void)InitBackend(); -#else - mindspore::parse::python_adapter::set_python_env_flag(true); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - (void)ms_context->OpenTsd(); - uint32_t device_id = ms_context->device_id(); - std::string device_name = ms_context->device_target(); - ms_context->set_enable_hccl(true); - if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) { - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id); - MS_EXCEPTION_IF_NULL(runtime_instance); - if (!runtime_instance->Init()) { - MS_LOG(ERROR) << "Kernel runtime init error."; - return; - } - } -#endif -} - -void FinalizeHccl() { -#ifdef ENABLE_GE - (void)FinalizeBackend(); -#else - device::KernelRuntimeManager::Instance().ClearRuntimeResource(); -#endif -} - -void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) { -#if (ENABLE_GE || ENABLE_D) - ExportDFGraph(file_name, phase); -#endif - MS_LOG(WARNING) << "In ut test no export_graph"; -} - -void ReleaseGeTsd() { - auto context_ptr = MsContext::GetInstance(); - if (context_ptr != nullptr) { - (void)context_ptr->FinalizeGe(true); - (void)context_ptr->CloseTsd(true); - } -} - -void InitBackend() { - // set python env flag - mindspore::parse::python_adapter::set_python_env_flag(true); - // open tsd before ge initialize - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->OpenTsd()) { - MS_LOG(EXCEPTION) << "Open tsd failed"; - } - (void)ms_context->InitGe(); -} - -void FinalizeBackend() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - (void)context_ptr->FinalizeGe(); - (void)context_ptr->CloseTsd(); -} - -void ClearResAtexit() { - MS_LOG(DEBUG) << "Pipeline clear all resource"; - pynative::ClearPyNativeSession(); - session::ClearPythonParasMap(); - device::KernelRuntimeManager::Instance().ClearRuntimeResource(); - - ad::g_k_prims.clear(); - - abstract::ClearPrimEvaluatorMap(); - compile::ClearConvertCache(); - pipeline::GetMethodMap().clear(); - pipeline::ExecutorPy::ClearRes(); - pipeline::ReclaimOptimizer(); - pynative::PynativeExecutor::GetInstance()->ClearRes(); - opt::python_pass::PyPassManager::GetInstance()->ClearRes(); -#ifdef ENABLE_GE - transform::DfGraphManager::GetInstance().ClearGraph(); - transform::DfGraphConvertor::get_adpt_map().clear(); -#endif - ReleaseGeTsd(); - parse::python_adapter::ResetPythonScope(); -} -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h deleted file mode 100644 index 3f1274c417..0000000000 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ /dev/null @@ -1,149 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ -#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "utils/base_ref_extends.h" -#include "debug/draw.h" -#include "ir/anf.h" -#include "ir/tensor.h" -#include "pipeline/action.h" -#include "vm/segment_runner.h" -#include "vm/transform.h" -#include "pipeline/base.h" - -namespace mindspore { -extern const char kMsConvert[]; -extern const char kMsVm[]; - -// namespace to support pipeline structures definition -namespace pipeline { - -namespace py = pybind11; - -class Pipeline { - public: - Pipeline(const ResourcePtr &res, const std::vector &actions) : resource_(res), actions_(actions) {} - - ~Pipeline() = default; - - void Run(); - - ResourcePtr resource() { return resource_; } - - private: - ResourcePtr resource_; - std::vector actions_; -}; - -// A function pipeline. -class ExecutorPy : public std::enable_shared_from_this { - public: - static std::shared_ptr GetInstance() { - std::lock_guard i_lock(instance_lock_); - if (executor_ == nullptr) { - executor_ = std::shared_ptr(new (std::nothrow) ExecutorPy()); - } - return executor_; - } - - ~ExecutorPy(); - - void SaveCompiledGraph(const std::string &phase_s); - void SaveCompiledGraphToPb(const std::string &phase_s); - bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); - bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); - - void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list); - - // for pynative mode when use_vm is on - py::object Run(const py::tuple &args, const py::object &phase); - ResourcePtr GetResource(const std::string &phase); - FuncGraphPtr GetFuncGraph(const std::string &phase); - py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); - std::size_t ArgListSize(const std::string &phase); - compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); - bool HasCompiled(const std::string &phase) const; - - FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, - const py::object &broadcast_params = {}); - void RunInitGraph(const py::dict &init_params, const std::string &phase); - py::dict GetParameterLayout(const std::string &phase); - py::dict GetCNodeStrategy(const std::string &phase); - py::dict GetAllreduceFusion(const std::string &phase); - void DelNetRes(const std::string &id); - void ReleaseResource(const py::object &phase); - static void ClearRes(); - - std::map> FetchInfoForQuantExport(const std::string &phase_s); - - private: - ExecutorPy(); - void ConvertObjectToTensors(const py::dict &dict, std::map *tensors); - bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const; - void GetGeBackendPolicy() const; - // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after - // 'validate' stage - static std::vector FilterActions(const std::vector &actions, const std::string &phase); - - std::map info_; - static std::shared_ptr executor_; - static std::mutex instance_lock_; -}; -using ExecutorPyPtr = std::shared_ptr; - -// Generate a key for mapping function graph -py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults); -py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs); - -bool InitDistribute(const std::map &options); - -void ResetOpId(); -void InitHccl(); -void FinalizeHccl(); -void InitBackend(); -void FinalizeBackend(); - -void ClearResAtexit(); -void ReleaseGeTsd(); - -void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); - -// init and exec dataset sub graph -bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase, bool need_run); - -// Build and run dataset subgraph for ms backend -bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, bool need_run); - -void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list); - -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.cc b/mindspore/ccsrc/pipeline/pipeline_ge.cc deleted file mode 100644 index 8ec1602315..0000000000 --- a/mindspore/ccsrc/pipeline/pipeline_ge.cc +++ /dev/null @@ -1,535 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/pipeline_ge.h" - -#include -#include -#include -#include -#include - -#include "debug/anf_ir_dump.h" -#include "ir/tensor.h" -#include "transform/convert.h" -#include "transform/df_graph_manager.h" -#include "transform/graph_builder.h" -#include "transform/graph_runner.h" -#include "debug/draw.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -namespace pipeline { -using Tensor = mindspore::tensor::Tensor; -using MetaTensor = mindspore::tensor::MetaTensor; -using TensorOrderMap = std::map>; -using mindspore::abstract::AbstractTensor; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractTuplePtr; -using mindspore::transform::DfGraphConvertor; -using mindspore::transform::DfGraphManager; -using mindspore::transform::GeTensorPtr; -using mindspore::transform::MeTensorPtr; -using mindspore::transform::Status; -using mindspore::transform::TransformUtil; - -void DoExecNonInputGraph(const std::string &phase) { - std::vector ge_tensors; - std::vector ge_outputs; - transform::RunOptions run_options; - run_options.name = phase; - auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); - if (graph_runner == nullptr) { - MS_LOG(ERROR) << "Can not found GraphRunner"; - return; - } - - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); - if (ret != Status::SUCCESS) { - MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed"; - return; - } - } -} - -void SetGeOption(const std::map &options) { - ConfigManager::GetInstance().set_ge_initialize_options(options); -} - -Status CreateSessionAndGraphRunner(bool is_training = true) { - std::shared_ptr sess = DfGraphManager::GetInstance().GetGeSession(); - if (sess == nullptr) { - transform::SessionOptions options; - if (is_training) { - options["ge.trainFlag"] = "1"; - options["ge.streamNum"] = "100"; - options["ge.enabledLocalFmkop"] = "1"; - options["ge.hcomParallel"] = "1"; - } else { - options["ge.trainFlag"] = "0"; - } - - options["ge.enablePrintOpPass"] = "0"; - sess = transform::GraphRunner::NewSession(options); - if (sess == nullptr) { - MS_LOG(ERROR) << "Init data graph failed, because of create Ge session failed"; - return Status::FAILED; - } else { - DfGraphManager::GetInstance().SetGeSession(sess); - } - } - - transform::GraphRunnerOptions options; - options.sess_ptr = sess; - auto graph_runner = std::make_shared(options); - if (graph_runner == nullptr) { - MS_LOG(ERROR) << "Create new graph runner failed"; - return Status::FAILED; - } else { - DfGraphManager::GetInstance().SetGraphRunner(graph_runner); - } - - return Status::SUCCESS; -} - -bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase) { - std::vector ge_types; - (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr &i) -> int64_t { - return transform::TransformUtil::ConvertDataType(i->type_id()); - }); - - ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE); - ConfigManager::GetInstance().set_iter_num(size); - ConfigManager::GetInstance().set_dataset_phase(phase); - - DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes); - ConfigManager::GetInstance().set_dataset_param(param); - - if (transform::BuildDatasetGraph(param, phase) != transform::SUCCESS) { - MS_LOG(ERROR) << "Build dateset graph failed."; - return false; - } - -#if ENABLE_TRAIN - (void)setenv("GE_TRAIN", "1", 1); -#else - (void)setenv("GE_TRAIN", "0", 1); -#endif - - if (CreateSessionAndGraphRunner(static_cast(ENABLE_TRAIN)) != Status::SUCCESS) { - MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; - return false; - } - - MS_LOG(INFO) << "DoExecNonInputGraph:" << phase; - DoExecNonInputGraph(phase); - - return true; -} - -void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) { - for (auto item : dict) { - if ((!py::isinstance(item.first))) { - MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it."; - continue; - } - std::shared_ptr tensor; - std::string name = py::cast(item.first); - if (py::isinstance(item.second.attr("default_input"))) { - // convert float to tensor with shape([1]) - tensor = std::make_shared(kNumberTypeFloat32, std::vector({1})); - *(static_cast(tensor->data_c())) = py::cast(item.second.attr("default_input")); - } else if (py::isinstance(item.second.attr("default_input"))) { - // convert int to tensor with shape([1]) - tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); - *(static_cast(tensor->data_c())) = py::cast(item.second.attr("default_input")); - } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { - // cast tensor - tensor = py::cast>(item.second.attr("default_input")); - } - - if (tensor == nullptr) { - MS_LOG(EXCEPTION) << "Get default value for " << name << " failed"; - } - (void)tensors->emplace(name, tensor); - } -} - -bool AddDFGraph(const std::map &info, const py::dict &init_params, - const std::string &phase, const py::object &broadcast_params) { - FuncGraphPtr anf_graph = info.at(phase)->func_graph; - DfGraphConvertor convertor(anf_graph); - - size_t pos = phase.find('.'); - std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1)); - std::string phase_prefix = phase.substr(0, pos); - if (phase_prefix == "export") { - MS_LOG(INFO) << "Set DfGraphConvertor training : false"; - convertor.set_training(false); - } - - TensorOrderMap init_tensors{}; - ConvertObjectToTensors(init_params, &init_tensors); - (void)convertor.ConvertAllNode().InitParam(init_tensors).BuildGraph(); - - if (broadcast_params != py::none()) { - if (!py::isinstance(broadcast_params)) { - MS_LOG(ERROR) << "Invalid broadcast params, it must be py::dict type"; - return false; - } - py::dict broadcast = broadcast_params.cast(); - if (broadcast.empty()) { - (void)convertor.GenerateBroadcastGraph(init_tensors); - } else { - TensorOrderMap broadcast_tensors{}; - ConvertObjectToTensors(broadcast, &broadcast_tensors); - (void)convertor.GenerateBroadcastGraph(broadcast_tensors); - } - MS_LOG(INFO) << "Generate broadcast graph with params and broadcast_empty is " << broadcast.empty(); - } - - (void)convertor.GenerateCheckpointGraph(); - if (convertor.ErrCode() != 0) { - DfGraphManager::GetInstance().ClearGraph(); - MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode(); - return false; - } - - if (MsContext::GetInstance()->save_graphs_flag()) { - convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug - convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug - convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug - } - std::string init_graph = "init_subgraph." + net_id; - std::string checkpoint_name = "save." + net_id; - if (phase.find("train") != std::string::npos) { - (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); - } else { - (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); - } - (void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); - (void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); - - Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); - if (ret == Status::SUCCESS) { - DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); - } - - return true; -} - -FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, - const std::string &phase, const py::object &broadcast_params) { - if (info.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - FuncGraphPtr anf_graph = info.at(phase)->func_graph; - - if (MsContext::GetInstance()->save_graphs_flag()) { - draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug - DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true); - } - - if (!AddDFGraph(info, init_params, phase, broadcast_params)) { - MS_LOG(ERROR) << "GenConvertor failed"; - return nullptr; - } - -#if ENABLE_TRAIN - (void)setenv("GE_TRAIN", "1", 1); -#else - (void)setenv("GE_TRAIN", "0", 1); -#endif - - if (CreateSessionAndGraphRunner(static_cast(ENABLE_TRAIN)) != Status::SUCCESS) { - MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; - return nullptr; - } - - return anf_graph; -} - -void RunGEInitGraph(const py::dict &init_params, const std::string &phase) { - MS_LOG(DEBUG) << "ExecInitGraph start."; - TensorOrderMap inputs_with_name{}; - ConvertObjectToTensors(init_params, &inputs_with_name); - std::vector inputs; - (void)std::transform(inputs_with_name.begin(), inputs_with_name.end(), std::back_inserter(inputs), - [](const std::pair &item) { return item.second; }); - - std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); - if (ge_tensors.size() != inputs.size()) { - MS_LOG(ERROR) << "Args convert to ge tensor error."; - return; - } - MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size() << "."; - - std::vector ge_outputs; - transform::RunOptions run_options; - - run_options.name = phase; - if (DfGraphManager::GetInstance().GetGraphByName(phase) == nullptr) { - MS_LOG(WARNING) << "Can not find " << phase << " sub graph, don't need data init subgraph in INFER mode."; - return; - } - auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); - if (graph_runner == nullptr) { - MS_LOG(EXCEPTION) << "Can not found GraphRunner."; - } - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); - if (ret != Status::SUCCESS) { - MS_LOG(EXCEPTION) << "Exec " << phase << " graph failed."; - } - - MS_LOG(INFO) << "Exec " << phase << " graph success."; - - if ((ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::DISTRIBUTION) && - (DfGraphManager::GetInstance().GetGraphByName(BROADCAST_GRAPH_NAME) != nullptr)) { - run_options.name = BROADCAST_GRAPH_NAME; - ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); - if (ret != Status::SUCCESS) { - MS_LOG(EXCEPTION) << "Exec BROADCAST_GRAPH_NAME failed."; - } - MS_LOG(INFO) << "Exec broadcast graph success."; - } - } -} - -py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) { - MS_EXCEPTION_IF_NULL(cnode_data); - - if (cnode_data->isa()) { - if (*count >= data.size()) { - MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() - << " less than the number of elements required. "; - } - - BaseShapePtr shape = cnode_data->BuildShape(); - if (!shape->isa()) { - MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString(); - } - auto shape_me = shape->cast()->shape(); - auto shape_ge = py::cast(data[*count]).shape(); - if (shape_ge != shape_me) { - MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge - << " is not the same as the shape of the tensor derived: " << shape_me; - } - - return data[(*count)++]; - } - - if (!cnode_data->isa()) { - MS_LOG(EXCEPTION) << "The output of operator in the final anf graph could " - << "only be a tensor or a tuple of tensor, but got " << cnode_data->BuildValue()->ToString() - << "."; - } - auto data_tp = cnode_data->cast(); - auto elements = data_tp->elements(); - size_t size = data_tp->size(); - auto tp = py::tuple(size); - for (size_t i = 0; i < size; i++) { - tp[i] = ExtractGeneralCnodeRet(elements[i], data, count); - } - return std::move(tp); -} - -py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) { - MS_EXCEPTION_IF_NULL(output_node); - - if (output_node->isa()) { - return ValuePtrToPyData(GetValueNode(output_node)); - } - - if (output_node->isa()) { - if (*count >= data.size()) { - MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() - << " less than the number of elements required. "; - } - return data[(*count)++]; - } - - auto output_c = output_node->cast(); - if (output_c == nullptr) { - MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got " - << output_node->ToString(); - } - - if (output_c->IsApply(prim::kPrimMakeTuple)) { - auto input_list = output_c->inputs(); - size_t size = input_list.size(); - auto tp = py::tuple(size - 1); - for (size_t i = 1; i < size; i++) { - tp[i - 1] = StructureOutput(input_list[i], data, count); - } - return std::move(tp); - } - if (output_c->IsApply(prim::kPrimDepend)) { - return StructureOutput(output_c->input(1), data, count); - } - - return ExtractGeneralCnodeRet(output_c->abstract(), data, count); -} - -std::shared_ptr DoExecGraph(const FuncGraphPtr &graph, const std::vector &inputs, - const std::string &phase) { - std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); - if (ge_tensors.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Convert me args to ge tensor error."; - } - - std::vector ge_outputs; - transform::RunOptions run_options; - run_options.name = phase; - auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); - if (graph_runner == nullptr) { - MS_LOG(EXCEPTION) << "Can not found GraphRunner."; - } - - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size(); - Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); - MS_LOG(DEBUG) << "Run graph finish, outputs size is: " << ge_outputs.size(); - if (ret != Status::SUCCESS) { - MS_LOG(ERROR) << "Exec graph failed"; - return nullptr; - } - } - - std::vector me_outputs = TransformUtil::ConvertGeTensors(ge_outputs); - if (me_outputs.size() != ge_outputs.size()) { - MS_LOG(WARNING) << "Convert output Ge tensor to Me tensor failed"; - } - - py::tuple outputs(me_outputs.size()); - for (std::size_t i = 0; i < outputs.size(); i++) { - outputs[i] = *me_outputs[i]; - } - - std::shared_ptr ret = nullptr; - - AnfNodePtr output_node = graph->get_return()->input(1); - MS_EXCEPTION_IF_NULL(output_node); - size_t count = 0; - py::object oj = StructureOutput(output_node, outputs, &count); - ret = std::make_shared(oj); - - return ret; -} - -void ProcessGeArg(const std::map &info, const py::tuple &args, const std::string &phase, - std::vector *inputs) { - // check the arg and use the ExecutorPy args - std::size_t size = args.size(); - - if (info.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - - auto arg_size = info.at(phase)->arg_list_size; - if (size != arg_size) { - MS_LOG(EXCEPTION) << "The real arg num : size = " << size << ". graph_arg_size = " << arg_size; - } - - // process the first args of tensor - // only in dataset normal(non-sink) mode, fp_bp graph need input tensors - if (ConfigManager::GetInstance().dataset_mode() == DS_NORMAL_MODE) { - for (std::size_t i = 0; i < size; i++) { - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(args[i], &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; - } - if (converted->isa()) { - inputs->push_back(converted->cast()); - } else { - MS_EXCEPTION(TypeError) << "The " << i << "th arg: " << converted->ToString() << " is not tensor."; - } - } - } -} - -py::object ExecDFGraph(const std::map &info, const py::tuple &args, - const std::string &phase) { - std::string phase_prefix = GetPhasePrefix(phase); - if (phase_prefix == "save") { - DoExecNonInputGraph(phase); - ConfigManager::GetInstance().ResetConfig(); - return py::none(); - } - - if (info.count(phase) == 0) { - MS_LOG(EXCEPTION) << "There is no phase:" << phase; - } - FuncGraphPtr anf_graph = info.at(phase)->func_graph; - -#ifdef ENABLE_INFER - // Now don't use the graph because the exec ge function don't take effect - MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph); - if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) { - MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries"; - ConfigManager::GetInstance().ResetConfig(); - return py::none(); - } -#endif - - std::shared_ptr ret_val = std::make_shared(); - // We will not execute graph when output is constant or just input itself. - if (IsGraphOutputValueNodeOrParameter(info.at(phase)->func_graph->output(), args, ret_val)) { - ConfigManager::GetInstance().ResetConfig(); - return *ret_val; - } - - std::vector inputs; - ProcessGeArg(info, args, phase, &inputs); - - std::shared_ptr ret = DoExecGraph(anf_graph, inputs, phase); - ConfigManager::GetInstance().ResetConfig(); - if (ret != nullptr) { - return *ret; - } else { - MS_LOG(EXCEPTION) << "Exec graph failed"; - } -} -void ExportDFGraph(const std::string &file_name, const std::string &phase) { - MS_LOG(DEBUG) << "ExportGraph Begin"; - transform::DfGraphWrapperPtr wrap_ptr = DfGraphManager::GetInstance().GetGraphByName(phase); - if (wrap_ptr == nullptr) { - MS_LOG(ERROR) << "Get graph form DfGraphManager failed!"; - return; - } - - transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_; - if (nullptr == ge_graph) { - MS_LOG(ERROR) << "The export graph is null"; - return; - } - - (void)ge_graph->SaveToFile(file_name); - - MS_LOG(DEBUG) << "ExportGraph End"; -} -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.h b/mindspore/ccsrc/pipeline/pipeline_ge.h deleted file mode 100644 index f3a363dbe8..0000000000 --- a/mindspore/ccsrc/pipeline/pipeline_ge.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ -#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "pipeline/base.h" -#include "operator/ops.h" - -namespace mindspore { -namespace pipeline { -namespace py = pybind11; - -void SetGeOption(const std::map &options); - -void RunGEInitGraph(const py::dict &init_params, const std::string &phase); - -py::object ExecDFGraph(const std::map &info, const py::tuple &args, - const std::string &phase = "train"); - -FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, - const std::string &phase, const py::object &broadcast_params = {}); - -// init and exec dataset sub graph for GE backend -bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase); - -void ExportDFGraph(const std::string &file_name, const std::string &phase); -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ diff --git a/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt b/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt new file mode 100644 index 0000000000..c15928ee76 --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt @@ -0,0 +1,9 @@ +file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc") + +if (ENABLE_GE) + file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc") + list(APPEND _PYNATIVE_SRC_LIST ${_GE_SRC_LIST}) +endif () + +set_property(SOURCE ${_PYNATIVE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PYNATIVE) +add_library(_mindspore_pipeline_pynative_obj OBJECT ${_PYNATIVE_SRC_LIST}) diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h new file mode 100644 index 0000000000..afb6d0982b --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PYNATIVE_BASE_H_ +#define MINDSPORE_CCSRC_PYNATIVE_BASE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "ir/primitive_py.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace pynative { +namespace py = pybind11; + +enum PynativeStatusCode { + PYNATIVE_SUCCESS = 0, + PYNATIVE_OP_NOT_IMPLEMENTED_ERR = 1, + PYNATIVE_OP_INPUTS_ERR = 2, + PYNATIVE_OP_PARAMS_ERR = 3, + PYNATIVE_OP_ATTRS_ERR = 4, + PYNATIVE_GRAPH_MANAGER_ERR = 5, + PYNATIVE_GRAPH_GE_BUILD_ERR = 6, + PYNATIVE_GRAPH_GE_RUN_ERR = 7, + PYNATIVE_UNKNOWN_STATE = 0XFF +}; + +enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM }; + +struct OpExecInfo { + PrimitivePyPtr py_primitive; + std::string op_name; + AbstractBasePtr abstract; + + py::tuple op_inputs; + py::tuple inputs_mask; + py::dict op_attrs; +}; +using OpExecInfoPtr = std::shared_ptr; +OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args); + +const std::set ignore_infer_prim = {"make_ref"}; +} // namespace pynative +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PYNATIVE_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc new file mode 100644 index 0000000000..5e3add1b5f --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -0,0 +1,1167 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/pynative/pynative_execute.h" + +#include +#include +#include +#include +#include + +#include "debug/trace.h" +#include "ir/tensor_py.h" +#include "ir/param_value.h" +#include "utils/any.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/composite/do_signature.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/resolve.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "backend/session/session_factory.h" +#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/helper.h" +#include "pipeline/jit/action.h" + +#include "pipeline/pynative/base.h" +#include "pybind_api/api_register.h" +#include "vm/transform.h" + +#include "frontend/optimizer/ad/grad.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/pipeline.h" +#include "pipeline/jit/pass.h" + +#ifdef ENABLE_GE +#include "pipeline/pynative/pynative_execute_ge.h" +#endif + +using mindspore::tensor::TensorPy; + +const char SINGLE_OP_GRAPH[] = "single_op_graph"; +// primitive unable to infer value for constant input in PyNative mode +const std::set vm_operators = {"make_ref", "HookBackward", "stop_gradient"}; + +namespace mindspore { +namespace pynative { + +static std::shared_ptr session = nullptr; +PynativeExecutorPtr PynativeExecutor::executor_ = nullptr; +std::mutex PynativeExecutor::instance_lock_; +ResourcePtr PynativeExecutor::resource_; + +template +void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) { + try { + (executor->*method)(args...); + } catch (const py::error_already_set &ex) { + // print function call stack info before release + std::ostringstream oss; + trace::TraceGraphEval(); + trace::GetEvalStackInfo(oss); + // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see + // these info from screen, no need to open log file to find these info + py::print(oss.str()); + MS_LOG(ERROR) << oss.str(); + PynativeExecutor::GetInstance()->Clean(); + // re-throw this exception to Python interpreter to handle it + throw(py::error_already_set(ex)); + } catch (const py::type_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::type_error(ex); + } catch (const py::value_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::value_error(ex); + } catch (const py::index_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::index_error(ex); + } catch (const std::exception &ex) { + PynativeExecutor::GetInstance()->Clean(); + // re-throw this exception to Python interpreter to handle it + throw(std::runtime_error(ex.what())); + } catch (...) { + PynativeExecutor::GetInstance()->Clean(); + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; + } +} + +inline ValuePtr PyAttrValue(const py::object &obj) { + ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj); + if (!converted_ret) { + MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); + } + return converted_ret; +} + +std::string GetId(const py::object &obj) { + py::object to_process = obj; + std::string prefix = ""; + if (py::isinstance(to_process)) { + auto p_list = py::cast(to_process); + if (p_list.size() == 0) { + return "empty"; + } + prefix = "tuple:"; + std::string key = ""; + for (size_t i = 0; i < p_list.size(); ++i) { + key += std::string(py::str(GetId(p_list[i]))) + ":"; + } + return prefix + key; + } + if (py::isinstance(to_process)) { + return prefix + std::string(py::str(to_process)); + } + if (py::isinstance(to_process)) { + return prefix + std::string(py::str(to_process)); + } + if (py::isinstance(to_process)) { + auto tensor_ptr = py::cast(to_process); + return prefix + tensor_ptr->id(); + } + + py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj); + return py::cast(ret); +} + +py::object GetTupleObj(const py::object &obj) { + py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); + py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj); + return obj_tuple; +} + +std::map> GetTypeIndex(const std::vector &dtypes) { + std::map> type_indexes; + for (size_t i = 0; i < dtypes.size(); ++i) { + auto it = type_indexes.find(dtypes[i]); + if (it == type_indexes.end()) { + (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector{i})); + } else { + it->second.push_back(i); + } + } + return type_indexes; +} + +std::map GetDstType(const py::tuple &py_args, + const std::map> &type_indexes) { + std::map dst_type; + for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) { + auto type = it->first; + auto indexes = it->second; + if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < 2) { + continue; + } + size_t priority = 0; + TypeId max_type = TypeId::kTypeUnknown; + bool has_float = false; + bool has_int = false; + for (size_t index : indexes) { + if (!has_float && py::isinstance(py_args[index])) { + has_float = true; + } + if (!has_int && !py::isinstance(py_args[index]) && py::isinstance(py_args[index])) { + has_int = true; + } + if (py::isinstance(py_args[index])) { + auto arg = py::cast(py_args[index]); + TypeId arg_type_id = arg->data_type(); + auto type_priority = prim::type_map.find(arg_type_id); + if (type_priority == prim::type_map.end()) { + continue; + } + if (type_priority->second > priority) { + max_type = type_priority->first; + priority = type_priority->second; + } + } + } + if (max_type == TypeId::kNumberTypeBool) { + if (has_int) { + max_type = TypeId::kNumberTypeInt32; + } + if (has_float) { + max_type = TypeId::kNumberTypeFloat32; + } + } + (void)dst_type.insert(std::make_pair(type, max_type)); + } + return dst_type; +} + +std::string TypeIdToMsTypeStr(const TypeId &type_id) { + auto type_name = type_name_map.find(type_id); + if (type_name == type_name_map.end()) { + MS_LOG(EXCEPTION) << "For implicit type conversion, not support convert to the type: " << TypeIdToType(type_id); + } + return type_name->second; +} + +py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { + py::tuple args(3); + std::string module_name = "mindspore.ops.functional"; + std::string op_name = "cast"; + args[0] = parse::python_adapter::GetPyFn(module_name, op_name); + args[1] = "Cast"; + + std::string dst_type_str = TypeIdToMsTypeStr(type_id); + module_name = "mindspore.common.dtype"; + py::object dst_type = parse::python_adapter::GetPyFn(module_name, dst_type_str); + py::tuple inputs(2); + inputs[0] = arg; + inputs[1] = dst_type; + args[2] = inputs; + + return RunOp(args)[0]; +} +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args, + py::list *const out_args_list) { + auto &py_args = *out_args; + py::tuple input_mask(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + input_mask[i] = py::hasattr(args[i], "__parameter__"); + py_args[i] = GetTupleObj(args[i]); + } + auto signature = prim->signatures(); + std::vector dtypes; + (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), + [](const Signature &sig) { return sig.dtype; }); + int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); + if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { + return input_mask; + } + auto type_indexes = GetTypeIndex(dtypes); + auto dst_type = GetDstType(py_args, type_indexes); + + for (size_t i = 0; i < dtypes.size(); ++i) { + if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { + continue; + } + auto it = dst_type.find(dtypes[i]); + if (it == dst_type.end() || it->second == kTypeUnknown) { + continue; + } + if (py::isinstance(py_args[i])) { + auto arg = py::cast(py_args[i]); + if (arg->data_type() == it->second) { + continue; + } + if (signature[i].rw == SignatureEnumRW::kRWWrite) { + prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg->data_type()), + TypeIdToMsTypeStr(it->second)); + } + } + py::object cast_output = DoAutoCast(py_args[i], it->second); + (*out_args)[i] = cast_output; + (*out_args_list)[i] = cast_output; + } + return input_mask; +} + +void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) { + size_t size = py_args.size(); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < size; i++) { + ValuePtr input_value = PyAttrValue(py_args[i]); + args_spec_list.emplace_back(abstract::FromValueInside( + input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa())); + } + AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); + op_exec_info->abstract = infer_res; +} + +OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) { + if (args.size() != PY_ARGS_NUM) { + MS_LOG(ERROR) << "Three args are needed by RunOp"; + return nullptr; + } + auto op_exec_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(op_exec_info); + op_exec_info->op_name = py::cast(args[PY_NAME]); + auto prim = py::cast(args[PY_PRIM]); + auto pyobj = prim->GetPyObj(); + if (pyobj == nullptr) { + MS_LOG(EXCEPTION) << "pyobj is empty"; + } + + py::list a = args[PY_INPUTS]; + size_t input_num = a.size(); + op_exec_info->op_inputs = py::tuple(input_num); + + op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args); + // use python infer method + if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { + PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get()); + } + op_exec_info->py_primitive = prim; + op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); + if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { + MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; + return nullptr; + } + return op_exec_info; +} + +std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(op_exec_info); + std::string graph_info; + // get input tensor info + size_t input_num = op_exec_info->op_inputs.size(); + for (size_t index = 0; index < input_num; ++index) { + auto input = op_exec_info->op_inputs[index]; + if (py::isinstance(input)) { + auto tensor_ptr = py::cast(input); + (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_"); + } + } + // get prim and abstract info + MS_EXCEPTION_IF_NULL(op_exec_info->abstract); + (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" + + op_exec_info->abstract->ToString()); + return graph_info; +} + +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { + MS_LOG(INFO) << "RunOpInVM start"; + + MS_EXCEPTION_IF_NULL(status); + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive); + if (op_exec_info->op_name == "HookBackward") { + auto op_inputs = op_exec_info->op_inputs; + py::tuple result(op_inputs.size()); + for (size_t i = 0; i < op_inputs.size(); i++) { + py::object input = op_inputs[i]; + if (py::hasattr(input, "__parameter__")) { + input = py::getattr(input, "data"); + } + auto tensor = py::cast(input); + auto new_tensor = std::make_shared(tensor->data_type(), tensor->shape(), tensor->data_ptr()); + new_tensor->set_device_address(tensor->device_address()); + new_tensor->set_dirty(tensor->is_dirty()); + result[i] = new_tensor; + } + *status = PYNATIVE_SUCCESS; + MS_LOG(INFO) << "RunOpInVM end"; + return std::move(result); + } + auto func = op_exec_info->py_primitive->GetComputeFunction(); + if (py::isinstance(func)) { + MS_LOG(ERROR) << "VM failed to get func"; + *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; + py::tuple err_ret(0); + return std::move(err_ret); + } + + // execute op + py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs)); + *status = PYNATIVE_SUCCESS; + MS_LOG(INFO) << "RunOpInVM end"; + return std::move(result); +} + +bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, + const std::unordered_set &input_attrs) { + MS_EXCEPTION_IF_NULL(op_prim); + auto input_names_value = op_prim->GetAttr(kAttrInputNames); + if (input_names_value == nullptr) { + return false; + } + auto input_names_vec = GetValue>(input_names_value); + if (input_index >= input_names_vec.size()) { + MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!"; + } + + if (input_attrs.find(input_index) != input_attrs.end()) { + ValuePtr value = parse::data_converter::PyDataToValue(input_object); + MS_EXCEPTION_IF_NULL(value); + auto input_name = input_names_vec[input_index]; + op_prim->set_attr(input_name, value); + return true; + } + return false; +} + +void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim, + std::vector *input_tensors) { + MS_EXCEPTION_IF_NULL(op_prim); + MS_EXCEPTION_IF_NULL(input_tensors); + for (const auto &input_object : tuple_inputs) { + if (!py::isinstance(input_object)) { + MS_LOG(EXCEPTION) << "The input object is not a tensor!"; + } + auto tensor = py::cast(input_object); + MS_EXCEPTION_IF_NULL(tensor); + input_tensors->push_back(tensor); + } + op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector{SizeToInt(tuple_inputs.size())})); +} + +void ConvertValueTupleToTensor(const py::object &input_object, std::vector *input_tensors) { + MS_EXCEPTION_IF_NULL(input_tensors); + ValuePtr input_value = parse::data_converter::PyDataToValue(input_object); + MS_EXCEPTION_IF_NULL(input_value); + if (!input_value->isa()) { + MS_LOG(EXCEPTION) << "The input object is not a value tuple!"; + } + auto value_tuple = input_value->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple); + MS_EXCEPTION_IF_NULL(tensor_ptr); + input_tensors->push_back(tensor_ptr); +} + +void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, + std::vector *input_tensors, int *tensor_mask) { + MS_EXCEPTION_IF_NULL(op_prim); + MS_EXCEPTION_IF_NULL(input_tensors); + MS_EXCEPTION_IF_NULL(tensor_mask); + + if (!py::isinstance(input_object)) { + MS_LOG(EXCEPTION) << "The input should be a tuple!"; + } + auto tuple_inputs = py::cast(input_object); + if (tuple_inputs.size() == 0) { + MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!"; + } + if (py::isinstance(tuple_inputs[0])) { + PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors); + } else { + ConvertValueTupleToTensor(input_object, input_tensors); + *tensor_mask = kValueNodeTensorMask; + } +} + +void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, + std::vector *input_tensors, int *tensor_mask) { + MS_EXCEPTION_IF_NULL(op_prim); + MS_EXCEPTION_IF_NULL(input_tensors); + MS_EXCEPTION_IF_NULL(tensor_mask); + tensor::TensorPtr tensor_ptr = nullptr; + if (py::isinstance(input_object)) { + tensor_ptr = py::cast(input_object); + } else if (py::isinstance(input_object)) { + double input_value = py::cast(input_object); + tensor_ptr = std::make_shared(input_value, kFloat32); + *tensor_mask = kValueNodeTensorMask; + } else if (py::isinstance(input_object)) { + tensor_ptr = std::make_shared(py::cast(input_object), kInt32); + *tensor_mask = kValueNodeTensorMask; + } else if (py::isinstance(input_object)) { + tensor_ptr = TensorPy::MakeTensor(py::cast(input_object), nullptr); + } else if (py::isinstance(input_object)) { + auto list_inputs = py::cast(input_object); + py::tuple tuple_inputs(list_inputs.size()); + for (size_t i = 0; i < tuple_inputs.size(); ++i) { + tuple_inputs[i] = list_inputs[i]; + } + ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask); + return; + } else if (py::isinstance(input_object)) { + ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask); + return; + } else if (py::isinstance(input_object)) { + return; + } else { + MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; + } + MS_EXCEPTION_IF_NULL(tensor_ptr); + input_tensors->push_back(tensor_ptr); +} + +void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *tensors_mask, + std::vector *input_tensors) { + MS_EXCEPTION_IF_NULL(op_run_info); + MS_EXCEPTION_IF_NULL(tensors_mask); + MS_EXCEPTION_IF_NULL(input_tensors); + PrimitivePtr op_prim = op_run_info->py_primitive; + MS_EXCEPTION_IF_NULL(op_prim); + + if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) { + MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size " + << op_run_info->inputs_mask.size(); + } + opt::ConstInputToAttrInfoRegister reg; + bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); + size_t input_num = op_run_info->op_inputs.size(); + for (size_t index = 0; index < input_num; ++index) { + // convert const input to attr + if (reg_exist && + RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) { + continue; + } + // convert const and tuple input to tensor + int tensor_mask = py::cast(op_run_info->inputs_mask[index]); + ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask); + // mark tensors, data : 0, weight : 1, valuenode: 2 + std::vector new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); + tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); + } +} + +void EraseValueNodeTensor(const std::vector &tensors_mask, std::vector *input_tensors) { + MS_EXCEPTION_IF_NULL(input_tensors); + if (input_tensors->size() != tensors_mask.size()) { + MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size " + << tensors_mask.size(); + } + std::vector new_input_tensors; + for (size_t index = 0; index < tensors_mask.size(); ++index) { + if (tensors_mask[index] != kValueNodeTensorMask) { + new_input_tensors.push_back(input_tensors->at(index)); + } + } + *input_tensors = new_input_tensors; +} + +py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->set_enable_pynative_infer(true); + std::string device_target = ms_context->device_target(); + if (device_target != kAscendDevice && device_target != kGPUDevice) { + MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; + } + + if (session == nullptr) { + session = session::SessionFactory::Get().Create(device_target); + } + MS_EXCEPTION_IF_NULL(session); + session->Init(ms_context->device_id()); + + std::vector input_tensors; + std::vector tensors_mask; + ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); + // get graph info for checking it whether existing in the cache + std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); + session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask); + EraseValueNodeTensor(tensors_mask, &input_tensors); + py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); + ms_context->set_enable_pynative_infer(false); + *status = PYNATIVE_SUCCESS; + return result; +} + +py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, + PynativeStatusCode *const status) { + MS_EXCEPTION_IF_NULL(status); + py::object result; + switch (backend_policy) { + case kMsBackendVmOnly: { + // use vm only + MS_LOG(INFO) << "RunOp use VM only backend"; + result = RunOpInVM(op_exec_info, status); + break; + } + case kMsBackendGePrior: { +#ifdef ENABLE_GE + // use GE first, use vm when GE fails + MS_LOG(INFO) << "RunOp use GE first backend"; + result = RunOpInGE(op_exec_info, status); + if (*status != PYNATIVE_SUCCESS) { + result = RunOpInVM(op_exec_info, status); + } +#endif + break; + } + case kMsBackendMsPrior: { + // use Ms fisrt,use others when ms failed + MS_LOG(INFO) << "RunOp use Ms first backend"; + result = RunOpInMs(op_exec_info, status); + if (*status != PYNATIVE_SUCCESS) { + MS_LOG(ERROR) << "RunOp use Ms backend failed!!!"; + } + break; + } + default: + MS_LOG(ERROR) << "No backend configured for run op"; + } + return result; +} + +AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) { + if (!grad_flag_ || graph_info_map_.empty()) { + return nullptr; + } + std::vector inputs; + auto prim = op_exec_info->py_primitive; + inputs.push_back(NewValueNode(prim)); + py::tuple op_masks = op_exec_info->inputs_mask; + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < args.size(); i++) { + auto node = GetInput(args[i], op_masks[i]); + args_spec_list.push_back(node->abstract()); + inputs.push_back(node); + } + + auto cnode = curr_g_->NewCNode(inputs); + MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4); + py::object out_real = out; + if (out.size() == 1) { + MS_LOG(DEBUG) << "MakeCnode out size is one."; + out_real = out[0]; + } + std::string obj_id = GetId(out_real); + if (py::isinstance(out_real)) { + auto value = py::cast(out_real); + if (value.size() > 1) { + for (int i = 0; i < static_cast(value.size()); i++) { + auto value_id = GetId(value[i]); + MS_LOG(DEBUG) << "MakeCnode set node id " << value_id; + set_obj_node_map(curr_g_, value_id, cnode, i); + } + } + } + MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id; + set_obj_node_map(curr_g_, obj_id, cnode); + set_pyobj(curr_g_, obj_id); + return cnode; +} + +AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { + auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)]; + if (out.second.size() == 1 && out.second[0] == -1) { + return out.first; + } + auto node = out.first; + MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString(); + for (auto &idx : out.second) { + std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)}; + node = curr_g_->NewCNode(tuple_get_item_inputs); + } + MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6); + return node; +} + +py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { + MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; + mindspore::parse::python_adapter::set_python_env_flag(true); + MsBackendPolicy backend_policy; +#if (!defined ENABLE_GE) + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->backend_policy() == "ms") { + backend_policy = kMsBackendMsPrior; + } else { + backend_policy = kMsBackendVmOnly; + } +#else + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->PynativeInitGe(); + backend_policy = kMsBackendGeOnly; +#endif + if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { + backend_policy = kMsBackendVmOnly; + } + PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; + // returns a null py::tuple on error + py::tuple err_ret(0); + py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); + if (status != PYNATIVE_SUCCESS) { + MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name; + return err_ret; + } + + auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); + if (node != nullptr) { + node->set_abstract(op_exec_info->abstract); + MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString(); + } + MS_LOG(DEBUG) << "RunOp end"; + return result; +} + +py::tuple RunOpInner(const py::args &args) { + MS_LOG(DEBUG) << "RunOp start" << args.size(); + py::list args_input = args[PY_INPUTS]; + + OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args, &args_input); + MS_EXCEPTION_IF_NULL(op_exec_info); + + if (op_exec_info->abstract != nullptr) { + py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); + if (!output["value"].is_none()) { + py::tuple value_ret(1); + value_ret[0] = output["value"]; + return value_ret; + } + if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { + py::tuple value_ret(1); + value_ret[0] = ""; + return value_ret; + } + } + return RunOpInner(op_exec_info, args_input); +} + +py::tuple RunOp(const py::args &args) { + try { + return RunOpInner(args); + } catch (const py::error_already_set &ex) { + // print function call stack info before release + std::ostringstream oss; + trace::TraceGraphEval(); + trace::GetEvalStackInfo(oss); + // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see + // these info from screen, no need to open log file to find these info + py::print(oss.str()); + MS_LOG(ERROR) << oss.str(); + PynativeExecutor::GetInstance()->Clean(); + // re-throw this exception to Python interpreter to handle it + throw(py::error_already_set(ex)); + } catch (const py::type_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::type_error(ex); + } catch (const py::value_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::value_error(ex); + } catch (const py::index_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::index_error(ex); + } catch (const std::exception &ex) { + PynativeExecutor::GetInstance()->Clean(); + // re-throw this exception to Python interpreter to handle it + throw(std::runtime_error(ex.what())); + } catch (...) { + PynativeExecutor::GetInstance()->Clean(); + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; + } +} + +void ClearPyNativeSession() { session = nullptr; } + +PynativeExecutor::~PynativeExecutor() { ClearRes(); } + +PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } + +void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { + auto cell_id = GetId(cell); + if (cell_graph_map_.count(cell_id) != 0) { + MS_LOG(DEBUG) << "Newgraph already compiled"; + return; + } + + auto g = std::make_shared(); + + if (top_g_ == nullptr) { + top_g_ = curr_g_ = g; + df_builder_ = std::make_shared(); + MS_LOG(DEBUG) << "First new graph" << top_g_.get(); + Pushp(); + } else { + Pushp(); + curr_g_ = g; + } + if (graph_info_map_.count(g) == 0) { + graph_info_map_[g] = GraphInfo(); + } + for (size_t i = 0; i < args.size(); i++) { + auto new_param = g->add_parameter(); + std::string param_obj = GetId(args[i]); + graph_info_map_[g].param_map[param_obj] = new_param; + } +} + +AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) { + ValuePtr converted_ret = nullptr; + parse::ConvertData(obj, &converted_ret); + auto node = NewValueNode(converted_ret); + set_obj_node_map(curr_g_, obj_id, node); + return node; +} + +AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) { + AnfNodePtr node = nullptr; + std::string obj_id = GetId(obj); + + if (op_mask != nullptr && py::cast(op_mask)) { + MS_LOG(DEBUG) << "Topgraph free parameter"; + // get the parameter name from parameter object + auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name"); + if (py::isinstance(name_attr)) { + MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; + } + auto param_name = py::cast(name_attr); + if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { + auto free_param = df_builder_->add_parameter(); + free_param->set_name(param_name); + auto free_param_new = py::cast(obj.attr("_value")); + free_param->set_default_param(free_param_new); + free_param->debug_info()->set_name(param_name); + MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; + graph_info_map_[df_builder_].param_map[obj_id] = free_param; + return free_param; + } + return graph_info_map_[df_builder_].param_map[obj_id]; + } + + // if input is graph output + if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) { + // op(x, y) + node = graph_info_map_[curr_g_].param_map[obj_id]; + } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) { + // out = op(op1(x, y)) + // out = op(cell1(x, y)) + // out = op(cell1(x, y)[0]) + node = GetObjNode(obj); + } else if (py::isinstance(obj)) { + // out = op((x, y)) + // out = cell((x, y)) + auto tuple = obj.cast(); + + // cell((1,2)): support not mix (scalar, tensor) + if (tuple.size() > 0 && !py::isinstance(tuple[0])) { + return MakeValueNode(obj, obj_id); + } + + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + + auto tuple_size = static_cast(tuple.size()); + for (int i = 0; i < tuple_size; i++) { + args.push_back(GetInput(tuple[i], py::object())); + } + auto cnode = curr_g_->NewCNode(args); + set_obj_node_map(curr_g_, GetId(obj), cnode); + node = cnode; + } else { + node = MakeValueNode(obj, obj_id); + } + + MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id; + return node; +} + +// for output[0][1] need getitem multi +void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx) { + if (py::isinstance(obj)) { + auto tuple = obj.cast(); + for (int i = 0; i < static_cast(tuple.size()); i++) { + std::vector tmp = idx; + tmp.push_back(i); + set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, tmp); + SetTupleOutput(tuple[i], cnode, tmp); + } + } +} + +void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); } + +void PynativeExecutor::Popp() { + if (graph_p_.empty()) { + MS_LOG(EXCEPTION) << "Stack graph_p_ is empty"; + } + curr_g_ = graph_p_.top(); + graph_p_.pop(); +} + +void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { + auto cell_id = GetId(cell); + if (cell_graph_map_.count(cell_id) != 0) { + MS_LOG(DEBUG) << "Endgraph already compiled"; + return; + } + cell_graph_map_[cell_id] = curr_g_; + auto out_id = GetId(out); + if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) { + // cell construct return x, y + if (py::isinstance(out)) { + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + + auto tuple = out.cast(); + MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size(); + auto tuple_size = static_cast(tuple.size()); + auto cnode = curr_g_->NewCNode(args); + for (int i = 0; i < tuple_size; i++) { + args.push_back(GetInput(tuple[i], py::object())); + set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i); + SetTupleOutput(tuple[i], cnode, std::vector{i}); + } + cnode->set_inputs(args); + set_obj_node_map(curr_g_, out_id, cnode); + } else { + MS_LOG(ERROR) << "Graph has no this out: " << out_id; + return; + } + } + EndGraphByOutId(out_id, cell, out, args); +} + +void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, + const py::args &args) { + AnfNodePtr output_node; + if (graph_info_map_[curr_g_].param_map.count(out_id)) { + output_node = graph_info_map_[curr_g_].param_map[out_id]; + } else { + output_node = GetObjNode(out); + } + curr_g_->set_output(output_node); + std::vector inputs; + inputs.push_back(NewValueNode(curr_g_)); + MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString(); + resource_->manager()->AddFuncGraph(curr_g_); + // custom bprop debug + if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { + MS_LOG(DEBUG) << "Use cell custom bprop function."; + FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell); + if (bprop_graph != nullptr) { + (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); + (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_))); + } + } + auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_); + if (curr_g_ != top_g_) { + Popp(); + for (size_t i = 0; i < args.size(); i++) { + auto input = GetInput(args[i], py::object()); + inputs.push_back(input); + } + auto out_cnode = curr_g_->NewCNode(inputs); + set_pyobj(curr_g_, GetId(cell)); + if (py::isinstance(out)) { + auto out_list = py::cast(out); + auto out_size = static_cast(out_list.size()); + for (int i = 0; i < out_size; i++) { + set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i); + SetTupleOutput(out_list[i], out_cnode, std::vector{i}); + } + } + set_obj_node_map(curr_g_, GetId(out), out_cnode); + } else { + parse::ResolveFuncGraph(newfg, resource_); + resource_->set_func_graph(newfg); + } +} + +std::vector PynativeExecutor::GetWeightsArgs(const py::object &weights) { + std::vector w_args; + if (py::hasattr(weights, "__parameter_tuple__")) { + auto tuple = weights.cast(); + MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size(); + w_args.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t it = 0; it < tuple.size(); ++it) { + auto param = tuple[it]; + auto param_id = GetId(param); + AnfNodePtr para_node = nullptr; + if (graph_info_map_[df_builder_].param_map.count(param_id)) { + para_node = graph_info_map_[df_builder_].param_map[param_id]; + + AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); + AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); + auto refkey = std::make_shared(para_node->cast()->name()); + AnfNodePtr ref_key_node = NewValueNode(refkey); + AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); + + w_args.push_back(ref_node); + } + } + } else { + MS_LOG(DEBUG) << "training not paramter_tuple"; + } + return w_args; +} + +abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) { + abstract::AbstractBasePtrList args_spec; + std::size_t size = args.size(); + for (std::size_t i = 0; i < size; i++) { + ValuePtr converted = nullptr; + bool succ = parse::ConvertData(args[i], &converted); + if (!succ) { + MS_LOG(EXCEPTION) << "Args convert error"; + } + bool broaden = true; + auto abs = abstract::FromValue(converted, broaden); + args_spec.push_back(abs); + auto param_node = std::static_pointer_cast(df_builder_->parameters()[i]); + param_node->set_abstract(abs); + } + + for (const auto ¶m : df_builder_->parameters()) { + auto param_node = std::static_pointer_cast(param); + if (param_node->has_default()) { + const auto ¶m_value = param_node->default_param(); + ValuePtr value = param_value->value(); + AbstractBasePtr ptr = abstract::FromValue(value, true); + if (ptr == nullptr) { + MS_LOG(EXCEPTION) << "Args convert error"; + } + args_spec.push_back(ptr); + param_node->set_abstract(ptr); + } + } + + return args_spec; +} + +void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args) { + MS_LOG(INFO) << "GradNet start" << args.size(); + + std::size_t size = args.size(); + auto cell_id = GetId(cell); + if (graph_map_.count(cell_id) != 0) { + MS_LOG(DEBUG) << "GradNet already compiled"; + return; + } + MS_LOG(DEBUG) << "GradNet first compiled"; + std::vector new_params; + for (size_t i = 0; i < size; i++) { + ParameterPtr p = std::make_shared(df_builder_); + new_params.push_back(p); + } + MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size(); + new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); + df_builder_->set_parameters(new_params); + resource_->manager()->SetParameters(df_builder_, new_params); + + std::vector w_args = GetWeightsArgs(weights); + MS_EXCEPTION_IF_NULL(resource_->func_graph()); + auto g = GradGraph(resource_->func_graph(), grad, w_args, size); + resource_->set_func_graph(g); + resource_->manager()->KeepRoots({g}); + + // get the parameters items and add the value to args_spec + abstract::AbstractBasePtrList args_spec = GetArgsSpec(args); + MS_LOG(DEBUG) << "Args_spec size" << args_spec.size(); + + resource_->set_args_spec(args_spec); + MS_LOG(DEBUG) << "Start opt"; + + // Create backend and session + resource_->results()[pipeline::kBackend] = compile::CreateBackend(); + + graph_map_[cell_id] = g; + PynativeOptimizeAction(resource_); + TaskEmitAction(resource_); + ExecuteAction(resource_); + resource_->Clean(); + ad::CleanRes(); + pipeline::ReclaimOptimizer(); +} + +void PynativeExecutor::Clear(const std::string &flag) { + if (!flag.empty()) { + MS_LOG(INFO) << "Clear res"; + (void)graph_map_.erase(flag); + (void)cell_graph_map_.erase(flag); + Clean(); + // Maybe exit in the pynative runing op, so need reset pynative flag. + auto ms_context = MsContext::GetInstance(); + if (ms_context != nullptr) { + ms_context->set_enable_pynative_infer(false); + } + return; + } + + MS_LOG(INFO) << "Clear"; + top_g_ = nullptr; + curr_g_ = nullptr; + graph_info_map_.clear(); + std::stack().swap(graph_p_); +} + +void PynativeExecutor::Clean() { + MS_LOG(INFO) << "Clean all res"; + Clear(); + grad_flag_ = false; + df_builder_ = nullptr; + ad::CleanRes(); + pipeline::ReclaimOptimizer(); +} + +void PynativeExecutor::ClearRes() { + Clean(); + resource_.reset(); +} + +py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { + VectorRef arg_list; + pipeline::ProcessVmArgInner(args, resource_, &arg_list); + if (resource_->results().find(pipeline::kOutput) == resource_->results().end() || + !resource_->results()[pipeline::kOutput].is()) { + MS_LOG(EXCEPTION) << "Can't find run graph func for "; + } + compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast(); + if (run == nullptr) { + MS_LOG(EXCEPTION) << "Can't find run graph func for "; + } + + std::string backend = MsContext::GetInstance()->backend_policy(); + + MS_LOG(DEBUG) << "Eval run" << backend; + BaseRef value = (*run)(arg_list); + MS_LOG(DEBUG) << "Run end" << value.ToString(); + return BaseRefToPyData(value); +} + +FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, + const std::vector &weights, size_t arg_size) { + auto nparam = top_g_->parameters().size(); + std::ostringstream ss; + ss << "grad{" << nparam << "}"; + df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true); + df_builder_->debug_info()->set_name(ss.str()); + + auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights); + std::vector inputs = {NewValueNode(df)}; + for (size_t i = 0; i < arg_size; ++i) { + inputs.push_back(df_builder_->parameters()[i]); + } + auto out = df_builder_->NewCNode(inputs); + df_builder_->set_output(out); + resource_->manager()->AddFuncGraph(df); + resource_->manager()->AddFuncGraph(df_builder_); + return df_builder_; +} + +void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { + PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args); +} + +void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { + PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); +} + +void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args) { + PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); +} + +REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { + (void)py::class_>(*m, "PynativeExecutor_") + .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") + .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.") + .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.") + .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.") + .def("clear", &PynativeExecutor::Clear, "pynative clear status.") + .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""), + "Executor run function.") + .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), + "Executor set grad flag."); + })); +} // namespace pynative +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h new file mode 100644 index 0000000000..152d58aca4 --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -0,0 +1,130 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ +#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/numpy.h" + +#include "pipeline/pynative/base.h" +#include "utils/context/ms_context.h" +#include "ir/anf.h" +#include "pipeline/jit/resource.h" +#include "frontend/operator/composite/composite.h" + +namespace mindspore { +namespace pynative { + +namespace py = pybind11; +using ResourcePtr = std::shared_ptr; +using GradOperationPtr = std::shared_ptr; + +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); + +py::tuple RunOp(const py::args &args); + +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args, + py::list *const out_args_list); + +void ClearPyNativeSession(); + +struct GraphInfo { + std::unordered_map param_map; + std::unordered_map>> obj_node_map; + AnfNodePtr output; + std::vector objects; +}; + +class PynativeExecutor : public std::enable_shared_from_this { + public: + static std::shared_ptr GetInstance() { + std::lock_guard i_lock(instance_lock_); + if (executor_ == nullptr) { + executor_ = std::shared_ptr(new (std::nothrow) PynativeExecutor()); + resource_ = std::make_shared(); + } + return executor_; + } + void NewGraph(const py::object &cell, const py::args &args); + void NewGraphInner(const py::object &cell, const py::args &args); + void EndGraph(const py::object &cell, const py::object &out, const py::args &args); + void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); + void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args); + std::vector GetWeightsArgs(const py::object &weights); + abstract::AbstractBasePtrList GetArgsSpec(const py::args &args); + void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); + void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args); + void Clear(const std::string &flag = ""); + void Clean(); + void ClearRes(); + bool grad_flag() { return grad_flag_; } + void set_grad_flag(bool flag) { grad_flag_ = flag; } + AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask); + AnfNodePtr GetObjNode(const py::object &obj); + FuncGraphPtr curr_g() { return curr_g_; } + void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } + void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { + graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{-1}); + } + void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { + graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{index}); + } + void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { + graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); + } + AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out); + py::object Run(const py::tuple &args, const py::object &phase); + + void Pushp(); + void Popp(); + FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector &weights, + size_t arg_size); + void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx); + AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); + + ~PynativeExecutor(); + + private: + PynativeExecutor(); + static std::shared_ptr executor_; + static std::mutex instance_lock_; + static ResourcePtr resource_; + bool grad_flag_; + std::unordered_map graph_map_; + std::unordered_map cell_graph_map_; + std::unordered_map graph_info_map_; + std::stack graph_p_; + FuncGraphPtr top_g_; + FuncGraphPtr df_builder_; + FuncGraphPtr curr_g_; +}; + +using PynativeExecutorPtr = std::shared_ptr; + +} // namespace pynative +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc new file mode 100644 index 0000000000..897c21fc90 --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc @@ -0,0 +1,312 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/pynative/pynative_execute_ge.h" + +#include +#include +#include +#include + +#include "utils/any.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "backend/session/session_factory.h" +#include "ir/tensor_py.h" + +const char SINGLE_OP_GRAPH[] = "single_op_graph"; + +using mindspore::tensor::TensorPy; + +namespace mindspore { +namespace pynative { +using MeTensor = mindspore::tensor::Tensor; +using MeTensorPtr = mindspore::tensor::TensorPtr; +using GeOperator = ge::Operator; +using GeOperatorPtr = std::shared_ptr; + +using transform::GraphRunner; +using transform::GraphRunnerOptions; +using transform::OperatorPtr; +static std::shared_ptr session = nullptr; +inline ValuePtr PyAttrValue(const py::object &obj) { + ValuePtr converted_ret = nullptr; + bool converted = parse::ConvertData(obj, &converted_ret); + if (!converted) { + MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); + } + return converted_ret; +} + +MeTensorPtr ConvertPyObjToTensor(const py::object &obj) { + MeTensorPtr me_tensor_ptr = nullptr; + if (py::isinstance(obj)) { + me_tensor_ptr = py::cast(obj); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::cast(obj), nullptr); + } else { + MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; + } + return me_tensor_ptr; +} + +bool SetInputsForSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const OperatorPtr &op, std::vector *graph_input_nodes) { + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_EXCEPTION_IF_NULL(graph_input_nodes); + auto op_inputs = op_exec_info->op_inputs; + std::string op_name = op_exec_info->op_name; + transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); + if (adapter == nullptr) { + return false; + } + + int op_input_idx = 1; + size_t size = inputs.size(); + for (size_t i = 0; i < size; i++) { + if (inputs[i] == nullptr) { + continue; + } + auto const_op = std::make_shared(); + MS_EXCEPTION_IF_NULL(const_op); + (void)const_op->set_attr_value(*inputs[i]); + MeTensorPtr me_tensor_ptr = ConvertPyObjToTensor(op_inputs[i]); + MS_EXCEPTION_IF_NULL(me_tensor_ptr); + auto const_op_desc = + transform::TransformUtil::GetGeTensorDesc(me_tensor_ptr->shape_c(), me_tensor_ptr->data_type(), kOpFormat_NCHW); + if (const_op_desc == nullptr) { + MS_LOG(ERROR) << "Create variable " << op_name << " output descriptor failed!"; + return false; + } + auto pointer_cast_const_op = std::static_pointer_cast(const_op); + MS_EXCEPTION_IF_NULL(pointer_cast_const_op); + (void)pointer_cast_const_op->update_output_desc_y(*const_op_desc); + auto &input_map = adapter->getInputMap(); + if (input_map.find(op_input_idx) == input_map.end()) { + continue; + } + if (adapter->setInput(op, op_input_idx++, const_op)) { + MS_LOG(ERROR) << "Failed to set params, index is " << op_input_idx; + return false; + } + graph_input_nodes->push_back(*const_op); + } + return true; +} + +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(op_exec_info); + std::string op_name = op_exec_info->op_name; + auto op_inputs = op_exec_info->op_inputs; + transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); + if (adapter == nullptr) { + MS_LOG(ERROR) << "Unable to find Adapter for " << ((std::string)py::str(op_name)); + return false; + } + OperatorPtr op = adapter->generate(op_name); + MS_EXCEPTION_IF_NULL(op); + + std::vector graph_input_nodes; + // hold param nodes after setting input and output for the graph + // set input + if (!SetInputsForSingleOpGraph(op_exec_info, inputs, op, &graph_input_nodes)) { + return false; + } + // set attributes + for (auto attr : attrs) { + (void)adapter->setAttr(op, attr.first, attr.second); + } + // set default attributes + auto extra_attrs = adapter->GetExtraAttr(); + for (auto attr : extra_attrs) { + (void)adapter->setAttr(op, attr.first, attr.second); + } + // set input attributes + auto &input_attr_map = adapter->getInputAttrMap(); + for (auto &it : input_attr_map) { + if (op_inputs.size() < it.first) { + continue; + } + auto const_value = PyAttrValue(op_inputs[it.first - 1]); + if (const_value->isa()) { + continue; + } + it.second.set_attr(op, const_value); + } + // construct output data nodes + std::vector graph_outputs{*op}; + // set input and output nodes for the graph + MS_EXCEPTION_IF_NULL(graph); + (void)graph->SetInputs(graph_input_nodes).SetOutputs(graph_outputs); + MS_LOG(INFO) << "BuildSingleOpGraph done"; + return true; +} + +void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector *const inputs) { + MS_EXCEPTION_IF_NULL(inputs); + MS_EXCEPTION_IF_NULL(op_exec_info); + auto op_inputs = op_exec_info->op_inputs; + size_t size = op_inputs.size(); + for (size_t i = 0; i < size; i++) { + if (py::isinstance(op_inputs[i])) { + inputs->emplace_back(nullptr); + continue; + } + MeTensorPtr me_tensor_ptr = ConvertPyObjToTensor(op_inputs[i]); + auto ge_tensor_ptr = transform::TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW); + if (ge_tensor_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Convert inputs to GE tensor failed in op " << op_exec_info->op_name << "."; + } + // set inputs for operator to build single node graph + inputs->push_back(ge_tensor_ptr); + } +} + +PynativeStatusCode ConvertAttributes(const OpExecInfoPtr &op_exec_info, const std::vector &inputs) { + MS_EXCEPTION_IF_NULL(op_exec_info); + auto op_attrs = op_exec_info->op_attrs; + std::unordered_map attrs{}; + + for (auto &item : op_attrs) { + if (!py::isinstance(item.first)) { + MS_LOG(ERROR) << "Type error in py dict convert"; + return PYNATIVE_OP_ATTRS_ERR; + } + std::string name = py::cast(item.first); + auto attr_value = PyAttrValue(py::cast(item.second)); + (void)attrs.emplace(name, attr_value); + } + + // build graph + GeGraphPtr graph = std::make_shared(op_exec_info->op_name); + if (BuildSingleOpGraph(op_exec_info, inputs, attrs, graph) == false) { + MS_LOG(ERROR) << "Failed to BuildSingleOpGraph"; + return PYNATIVE_GRAPH_GE_BUILD_ERR; + } + + // add the single op graph into the graph manager, which will be iterated by session. + transform::Status ret = + transform::DfGraphManager::GetInstance().AddGraph(SINGLE_OP_GRAPH, std::shared_ptr(graph)); + if (ret != transform::SUCCESS) { + MS_LOG(ERROR) << "Failed to AddGraph into graph manager"; + return PYNATIVE_GRAPH_MANAGER_ERR; + } + + return PYNATIVE_SUCCESS; +} + +std::vector ConvertOutputTensors(const OpExecInfoPtr &op_exec_info, + const std::vector &ge_tensors) { + std::vector outputs; + AbstractBasePtr abs_base = op_exec_info->abstract; + std::vector> shapes; + if (abs_base != nullptr && abs_base->isa()) { + auto arg_tensor = dyn_cast(abs_base); + shapes.emplace_back(arg_tensor->shape()->shape()); + outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); + return outputs; + } + if (abs_base != nullptr && abs_base->isa()) { + auto arg_tuple = dyn_cast(abs_base); + size_t len = arg_tuple->size(); + + for (size_t i = 0; i < len; i++) { + if (arg_tuple->elements()[i]->isa()) { + auto arg_tensor = dyn_cast(arg_tuple->elements()[i]); + shapes.emplace_back(arg_tensor->shape()->shape()); + } + } + outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); + return outputs; + } + for (auto &it : ge_tensors) { + auto tensor = transform::TransformUtil::ConvertGeTensor(it); + if (tensor != nullptr) { + outputs.emplace_back(tensor); + } + } + return outputs; +} + +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { + MS_LOG(INFO) << "RunOpInGe start"; + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_EXCEPTION_IF_NULL(status); + + // returns a null py::tuple on error + py::tuple err_ret(0); + auto op_name = op_exec_info->op_name; + transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); + if (adapter == nullptr) { + MS_LOG(ERROR) << "Unable to find GE Adapter for " << ((std::string)py::str(op_name)); + *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; + return std::move(err_ret); + } + + std::vector inputs{}; + ToTensorPtr(op_exec_info, &inputs); + // convert me attr to ge AttrValue + PynativeStatusCode ret = ConvertAttributes(op_exec_info, inputs); + if (ret != PYNATIVE_SUCCESS) { + *status = ret; + return std::move(err_ret); + } + // run graph + transform::RunOptions run_options; + run_options.name = SINGLE_OP_GRAPH; + std::vector ge_inputs; + std::vector ge_outputs; + transform::GraphRunnerOptions graph_runner_options; + graph_runner_options.options["ge.trainFlag"] = "1"; + auto graph_runner = std::make_shared(graph_runner_options); + transform::Status run_ret; + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + run_ret = graph_runner->RunGraph(run_options, ge_inputs, &ge_outputs); + } + if (run_ret != transform::Status::SUCCESS) { + MS_LOG(ERROR) << "GraphRunner fails to run graph"; + *status = PYNATIVE_GRAPH_GE_RUN_ERR; + return std::move(err_ret); + } + + std::vector graph_outputs = ConvertOutputTensors(op_exec_info, ge_outputs); + size_t output_size = graph_outputs.size(); + py::tuple result(output_size); + for (size_t i = 0; i < output_size; i++) { + MS_EXCEPTION_IF_NULL(graph_outputs[i]); + result[i] = *graph_outputs[i]; + } + + *status = PYNATIVE_SUCCESS; + MS_LOG(INFO) << "RunOpInGe end"; + return std::move(result); +} +} // namespace pynative +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h new file mode 100644 index 0000000000..2978278489 --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ +#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ + +#include +#include +#include +#include +#include + +#include "pipeline/pynative/base.h" +#include "transform/graph_ir/convert.h" +#include "transform/graph_ir/graph_runner.h" +#include "transform/graph_ir/types.h" +#include "utils/context/ms_context.h" + +using GeTensor = ge::Tensor; +using GeTensorPtr = std::shared_ptr; +using GeGraph = ge::Graph; +using GeGraphPtr = std::shared_ptr; + +namespace mindspore { +namespace pynative { +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph); + +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); +} // namespace pynative +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/remove_value_node_dup.cc deleted file mode 100644 index 47881e4b91..0000000000 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/remove_value_node_dup.h" -#include "ir/anf.h" -#include "ir/tensor.h" -#include "ir/manager.h" -#include "optimizer/cse.h" -#include "utils/log_adapter.h" -#include "utils/hashing.h" - -namespace mindspore { -namespace pipeline { -void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache, - HashValue *const hash_value) { - const auto &to_check_value = GetValueNode(node); - MS_EXCEPTION_IF_NULL(to_check_value); - - // Calculate hash value. - size_t h; - auto hash_iter = hash_value->find(node); - if (hash_iter == hash_value->end()) { - h = hash_combine(to_check_value->hash(), (opt::AbsOf(node)->hash())); - (*hash_value)[node] = h; - } else { - h = hash_iter->second; - } - - auto bucket_iter = hash_cache->find(h); - if (bucket_iter == hash_cache->end()) { - // Meet for the first time, add bucket. - (*hash_cache)[h] = {node}; - return; - } - - auto &bucket = bucket_iter->second; - // Check if need to replace node with value node already met. - for (const auto &v : bucket) { - // Already met and cached. - if (v == node) { - return; - } - const auto &existed_value = GetValueNode(v); - MS_EXCEPTION_IF_NULL(existed_value); - auto equal = [&]() -> bool { - if (existed_value->isa() && to_check_value->isa()) { - return existed_value->cast()->ValueEqual(*(to_check_value->cast())); - } - return *existed_value == *to_check_value; - }; - if (equal()) { - (void)manager->Replace(node, v); - return; - } - } - - // Meet for the first time, append node to bucket. - bucket.emplace_back(node); -} -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.h b/mindspore/ccsrc/pipeline/remove_value_node_dup.h deleted file mode 100644 index 8f670c7dcf..0000000000 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ -#define MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ - -#include -#include -#include "ir/base.h" -#include "ir/manager.h" - -namespace mindspore { -namespace pipeline { -using HashCache = std::unordered_map>; -using HashValue = std::unordered_map; - -void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_REMOVE_VALUE_NODE_DUP_H_ diff --git a/mindspore/ccsrc/pipeline/resource.cc b/mindspore/ccsrc/pipeline/resource.cc deleted file mode 100644 index faf1f2015d..0000000000 --- a/mindspore/ccsrc/pipeline/resource.cc +++ /dev/null @@ -1,262 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/resource.h" -#include "pipeline/pipeline.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "debug/draw.h" -#include "debug/trace.h" -#include "ir/dtype.h" -#include "pipeline/parse/data_converter.h" -#include "operator/ops.h" -#include "utils/graph_utils.h" -#include "optimizer/ad/dfunctor.h" -#include "vm/segment_runner.h" - -namespace mindspore { -// namespace to support opmap definition -namespace pipeline { - -MethodMap &GetMethodMap() { - static MethodMap method_map = { - {kObjectTypeString, - { - {"__bool__", std::string("str_bool")} // C.str_bool - }}, - {kMetaTypeNone, - { - {"__bool__", std::string("none_bool")} // C.none_bool - }}, - {kNumberTypeBool, - { - {"__and__", prim::kPrimBoolAnd}, // P.bool_and - {"__or__", prim::kPrimBoolOr}, // P.bool_or - {"__eq__", prim::kPrimBoolEq}, // P.bool_eq - {"__ne__", std::string("bool_ne")}, // C.bool_ne - {"__bool__", prim::kPrimIdentity} // P.identity - }}, - {kNumberTypeInt, - { - {"__add__", prim::kPrimScalarAdd}, // P.scalar_add - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul - {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv - {"__truediv__", std::string("int_truediv")}, // C.int_truediv - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow - {"__floor__", prim::kPrimIdentity}, // P.identity - {"__trunc__", prim::kPrimIdentity}, // P.identity - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt - {"__le__", prim::kPrimScalarLe}, // P.scalar_le - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge - {"__bool__", std::string("int_bool")}, // C.int_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array - }}, - {kNumberTypeUInt, - { - {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, - {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, - {"__truediv__", std::string("int_truediv")}, // C.int_truediv - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, - {"__floor__", prim::kPrimIdentity}, // P.identity, - {"__trunc__", prim::kPrimIdentity}, // P.identity, - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, - {"__le__", prim::kPrimScalarLe}, // P.scalar_le, - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, - {"__bool__", std::string("int_bool")}, // C.int_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, - }}, - {kNumberTypeFloat, - { - {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, - {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv - {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, - {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, - {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, - {"__le__", prim::kPrimScalarLe}, // P.scalar_le, - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, - {"__bool__", std::string("float_bool")}, // C.float_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, - }}, - {kObjectTypeTuple, - { - {"__len__", prim::kPrimTupleLen}, // P.tuple_len, - {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, - {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, - {"__ms_iter__", prim::kPrimIdentity}, // P.identity, - {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, - {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext - {"__bool__", std::string("tuple_bool")} // C.tuple_bool - }}, - {kObjectTypeList, - { - {"__len__", prim::kPrimListLen}, // P.list_len, - {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, - {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, - {"__ms_iter__", prim::kPrimIdentity}, // P.identity - {"__ms_next__", std::string("list_next")}, // C.list_next - {"append", std::string("list_append")}, // C.list_next - {"__bool__", std::string("list_bool")}, // C.list_bool - {"__ms_hasnext__", std::string("list_hasnext")}, - }}, - {kObjectTypeDictionary, - { - {"__len__", prim::kPrimDictLen}, // P.dict_len - {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem - {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, - {"__bool__", std::string("dict_bool")} // C.dict_bool - }}, - {kObjectTypeTensorType, - { - {"__add__", std::string("add")}, // C.add - {"__sub__", std::string("sub")}, // C.sub - {"__mul__", std::string("mul")}, // C.mul - {"__truediv__", std::string("truediv")}, // C.truediv - {"__floordiv__", std::string("floordiv")}, // C.floordiv - {"__mod__", std::string("mod")}, // C.mod - {"__pow__", std::string("pow_")}, // C.pow - {"__floor__", std::string("array_floor")}, // C.array_floor - {"__trunc__", std::string("array_trunc")}, // C.array_trunc - {"__pos__", std::string("array_uadd")}, // C.array_uadd - {"__neg__", std::string("array_usub")}, // C.array_usub - {"__eq__", std::string("eq")}, // C.eq - {"__ne__", std::string("ne")}, // C.ne - {"__lt__", std::string("lt")}, // C.lt - {"__gt__", std::string("gt")}, // C.gt - {"__le__", std::string("le")}, // C.le - {"__ge__", std::string("ge")}, // C.ge - {"__matmul__", prim::kPrimDot}, // P.dot, - {"__len__", prim::kPrimArrayLen}, // P.array_len, - {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, - {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, - {"__ms_iter__", std::string("array_iter")}, // C.array_iter - {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, - {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, - {"transpose", std::string("transpose")}, // P.transpose - {"__bool__", std::string("tensor_bool")}, // C.tensor_bool - {"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices - }}, - {kObjectTypeIndexedSlicesType, - { - {"is_indexed_slices", prim::kPrimIsIndexedSlices}, // F.is_indexed_slices - {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values - {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices - {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape - }}, - {kObjectTypeJTagged, {}}, - {kObjectTypeSymbolicKeyType, {}}, - {kObjectTypeEnvType, {}}}; - return method_map; -} - -Resource::Resource(const py::object &obj) - : engine_(std::make_shared(abstract::GetPrimEvaluatorConstructors(), manager_)), - input_(obj), - is_cleaned_(false) {} - -Resource::~Resource() { - MS_LOG(DEBUG) << "Resource clear"; - - // If exit normally, these global variables will be cleaned - // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION, - // these global variables may not being cleaned, it may - // cause segmentfault when free python object inside these global variables - // after python interpreter got freed, so these global variables - // are cleaned here. - // So if exit normally, these global variable will be cleaned twice, - // care be taken to prevent double free in the following functions. - if (!is_cleaned_) { - try { - Clean(); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what(); - } catch (...) { - MS_LOG(ERROR) << "Exception when cleaning resource."; - } - } -} - -bool Resource::IsTypeInMethodMap(const TypeId &type) { - TypeId type_id = NormalizeTypeId(type); - const MethodMap &method_map = GetMethodMap(); - auto iter = method_map.find(static_cast(type_id)); - if (iter != method_map.end()) { - return true; - } - return false; -} - -Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { - TypeId type_id = NormalizeTypeId(type); - const MethodMap &method_map = GetMethodMap(); - auto iter = method_map.find(static_cast(type_id)); - if (iter == method_map.end()) { - MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; - return Any(); - } - - auto iter_map = iter->second.find(name); - if (iter_map == iter->second.end()) { - MS_LOG(WARNING) << "Object type: " << type_id << " have no method: " << name; - return Any(); - } - return iter_map->second; -} - -void Resource::Clean() { - // AbstractTensor->elements() will be saved in AbstractBasePtrList - args_spec_.clear(); - input_ = py::none(); - // Context with AbstractBasePtrList may be saved in GraphEvaluator - // some Evaluator like ResolveEvaluator may save Python object in cache, - // it should be cleaned before Python Interpreter destructed. - MS_EXCEPTION_IF_NULL(engine_); - engine_->ClearEvaluatorCache(); - // clean static variable to prevent from crash. As static variable is released after - // Python threads is released. - parse::data_converter::ClearObjectCache(); - parse::Parser::CleanParserResource(); - parse::CleanDataClassToClassMap(); - trace::ClearTraceStack(); - is_cleaned_ = true; -} -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/resource.h b/mindspore/ccsrc/pipeline/resource.h deleted file mode 100644 index 0c1348fd94..0000000000 --- a/mindspore/ccsrc/pipeline/resource.h +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ -#define MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ - -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -#include "utils/any.h" -#include "utils/profile.h" -#include "ir/manager.h" -#include "pipeline/static_analysis/prim.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "./common.h" - -namespace mindspore { -namespace pipeline { - -namespace py = pybind11; - -const char kBackend[] = "backend"; -const char kStepParallelGraph[] = "step_parallel"; -const char kOutput[] = "output"; - -class InferenceResource; - -using MethodMap = std::unordered_map>; - -MethodMap &GetMethodMap(); - -class ResourceBase { - public: - ResourceBase() { manager_ = MakeManager(); } - - virtual ~ResourceBase() = default; - - FuncGraphManagerPtr manager() { return manager_; } - // set a manager defined outside which will not manage the graphs. - void set_manager(const FuncGraphManagerPtr &manager) { manager_ = manager; } - - std::unordered_map &results() { return results_; } - - void SetResult(const std::string &key, const Any &value) { results_[key] = value; } - - Any GetResult(const std::string &key) { - if (results_.count(key) == 0) { - MS_LOG(EXCEPTION) << "this key is not in resource list:" << key; - } - return results_[key]; - } - - bool HasResult(const std::string &key) const { return results_.count(key) != 0; } - - std::unordered_map results_; - - protected: - FuncGraphManagerPtr manager_; -}; - -using ResourceBasePtr = std::shared_ptr; - -class Resource : public ResourceBase { - public: - explicit Resource(const py::object &obj = py::none()); - - ~Resource() override; - - abstract::AnalysisEnginePtr engine() { return engine_; } - - static bool IsTypeInMethodMap(const TypeId &type); - - static Any GetMethodPtr(const TypeId &type, const std::string &name); - - const py::object &input() const { return input_; } - - FuncGraphPtr func_graph() const { return func_graph_; } - void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } - - const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; } - void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; } - - // Reclaim resource and clear the cache. - // ExecutorPy::Compile() can be called multiple times, so cache - // should be cleared. - void Clean(); - - private: - abstract::AnalysisEnginePtr engine_; - FuncGraphPtr func_graph_; - abstract::AbstractBasePtrList args_spec_; - py::object input_; - bool is_cleaned_; -}; - -using ResourcePtr = std::shared_ptr; - -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc deleted file mode 100644 index ced4a518cb..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc +++ /dev/null @@ -1,362 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/abstract_function.h" - -#include - -#include "pipeline/static_analysis/analysis_context.h" -#include "pipeline/static_analysis/static_analysis.h" - -namespace mindspore { -namespace abstract { -class Evaluator; -class AnalysisEngine; - -AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) { - if (func_list.size() == 1) { - return func_list[0]; - } - return std::make_shared(func_list); -} - -AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) { - auto this_func = shared_from_base(); - if (other->isa()) { - if (*this_func == *other) { - return this_func; - } - return std::make_shared(this_func, other); - } - auto other_union = dyn_cast(other); - if (other_union->IsSuperSet(this_func)) { - return other; - } - return std::make_shared(this_func, other); -} - -void AbstractFuncAtom::Visit(std::function visit_func) const { - visit_func(const_cast(this)->shared_from_base()); -} - -bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; } - -AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; } - -AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) { - AbstractFuncAtomPtrList new_func_list; - auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); }; - - first->Visit(build_func_list); - second->Visit(build_func_list); - func_list_ = new_func_list; -} - -std::string AbstractFuncUnion::ToString() const { - std::ostringstream buffer; - buffer << "AbstractFuncUnion({"; - int i = 0; - for (const auto &func : func_list_) { - MS_EXCEPTION_IF_NULL(func); - buffer << "[" << i << "]: " << func->ToString() << ", "; - i++; - } - buffer << "})"; - return buffer.str(); -} - -bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) { - MS_EXCEPTION_IF_NULL(other); - std::vector is_in_list; - auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) { - auto iter = find(func_list_.begin(), func_list_.end(), func); - if (iter == func_list_.end()) { - is_in_list.push_back(false); - } - return true; - }; - other->Visit(build_in_list); - return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; }); -} - -AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { - auto this_func = shared_from_base(); - if (other->isa()) { - if (IsSuperSet(other)) { - return this_func; - } - return std::make_shared(this_func, other); - } - auto other_union = dyn_cast(other); - if (other_union->IsSuperSet(this_func)) { - return other; - } - return std::make_shared(this_func, other); -} - -void AbstractFuncUnion::Visit(std::function visit_func) const { - for (AbstractFuncAtomPtr poss : func_list_) { - visit_func(poss); - } -} - -bool AbstractFuncUnion::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_union = static_cast(&other); - if (func_list_.size() != other_union->func_list_.size()) { - return false; - } - if (func_list_ == other_union->func_list_) { - return true; - } - return false; -} - -std::size_t AbstractFuncUnion::hash() const { - std::size_t hash_sum = 0; - for (auto f : func_list_) { - hash_sum = hash_combine(hash_sum, f->hash()); - } - return hash_sum; -} - -EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_prim = static_cast(&other); - if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) { - return true; - } - return false; -} - -std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } - -EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_fg = static_cast(&other); - if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { - return true; - } - return false; -} - -std::size_t FuncGraphAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), func_graph_->hash()); - hash_value = hash_combine(hash_value, context_->hash()); - return hash_value; -} - -std::string FuncGraphAbstractClosure::ToString() const { - std::stringstream ss; - ss << "FuncGraphAbstractClosure: " - << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString(); - return ss.str(); -} - -EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_meta_fg = static_cast(&other); - if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { - return true; - } - return false; -} - -std::size_t MetaFuncGraphAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); - return hash_value; -} - -std::string MetaFuncGraphAbstractClosure::ToString() const { - return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name(); -} - -bool PartialAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_partial = static_cast(&other); - if (fn_ != other_partial->fn_) { - return false; - } - if (args_spec_list_.size() != other_partial->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_partial->args_spec_list_) { - return true; - } - return false; -} - -std::size_t PartialAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), fn_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -std::string PartialAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; - for (auto arg : args_spec_list_) { - buffer << arg->ToString() << ", "; - } - buffer << "))"; - return buffer.str(); -} - -EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_transformed = static_cast(&other); - if (fn_ == other_transformed->fn_) { - return true; - } - return false; -} - -std::size_t JTransformedAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), fn_->hash()); - return hash_value; -} - -EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_virtual = static_cast(&other); - if (output_ != other_virtual->output_) { - return false; - } - if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_virtual->args_spec_list_) { - return true; - } - return false; -} - -std::size_t VirtualAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), output_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -std::string VirtualAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "VirtualAbstractClosure(args: {"; - int i = 0; - for (const auto &arg : args_spec_list_) { - MS_EXCEPTION_IF_NULL(arg); - buffer << "[" << i << "]: " << arg->ToString() << ", "; - i++; - } - buffer << "}, output: " << output_->ToString() << ")"; - return buffer.str(); -} - -EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_typed = static_cast(&other); - if (output_ != other_typed->output_) { - return false; - } - if (prim_ != other_typed->prim_) { - return false; - } - if (args_spec_list_.size() != other_typed->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_typed->args_spec_list_) { - return true; - } - return false; -} - -std::size_t TypedPrimitiveAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), prim_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -std::string TypedPrimitiveAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {"; - int i = 0; - for (const auto &arg : args_spec_list_) { - MS_EXCEPTION_IF_NULL(arg); - buffer << "[" << i << "]: " << arg->ToString() << ", "; - i++; - } - buffer << "}, output: " << output_->ToString() << ")"; - return buffer.str(); -} - -bool DummyAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - return true; -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h deleted file mode 100644 index 9e1cf9ba83..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h +++ /dev/null @@ -1,303 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ -#define PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ - -#include -#include - -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/analysis_context.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -namespace abstract { -class AbstractFuncAtom : public AbstractFunction { - public: - AbstractFuncAtom() = default; - ~AbstractFuncAtom() override = default; - MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) - - AbstractFunctionPtr GetUnique() override { return shared_from_base(); } - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { - MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncAtom"; - } - - AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; - void Visit(std::function) const final; - bool operator==(const AbstractFunction &other) const override; - - std::size_t hash() const override { return tid(); } -}; - -class AbstractFuncUnion : public AbstractFunction { - public: - explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list); - AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second); - ~AbstractFuncUnion() override = default; - MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction) - - std::string ToString() const override; - - AbstractFunctionPtr GetUnique() override { MS_LOG(EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { - MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFuncUnion"; - } - bool IsSuperSet(const AbstractFunctionPtr &other); - AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; - void Visit(std::function) const final; - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } - - private: - AbstractFuncAtomPtrList func_list_; -}; - -class PrimitiveAbstractClosure : public AbstractFuncAtom { - public: - // Represents a Primitive. - // prim: The primitive - // tracking_id: Identifies different uses of the same primitive. - explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_id = nullptr) - : prim_(prim), tracking_id_(AnfNodeWeakPtr(tracking_id)) {} - ~PrimitiveAbstractClosure() override = default; - MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - PrimitivePtr prim() { return prim_; } - - AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } - - void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); } - - AbstractFunctionPtr Copy() const override { return std::make_shared(prim_, tracking_id()); } - - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override { return "Prim: " + prim_->name(); } - - private: - PrimitivePtr prim_; - // store it as weak_ptr to break reference cycle. - // one reference cycle example is Graph::set_output() input0 local variable. - AnfNodeWeakPtr tracking_id_; -}; - -class FuncGraphAbstractClosure : public AbstractFuncAtom { - public: - // Represents a Graph in a certain Context. - // context: The context, or Context.empty() - FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) - : func_graph_(func_graph), context_(context) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(context); - } - ~FuncGraphAbstractClosure() override = default; - MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - FuncGraphPtr func_graph() { return func_graph_; } - - AnalysisContextPtr context() const override { return context_; } - - AbstractFunctionPtr Copy() const override { - return std::make_shared(func_graph_, context_); - } - - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - FuncGraphPtr func_graph_; - AnalysisContextPtr context_; -}; -using FuncGraphAbstractClosurePtr = std::shared_ptr; - -class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { - public: - explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope) - : meta_func_graph_(meta_func_graph), scope_(scope) {} - ~MetaFuncGraphAbstractClosure() override = default; - MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) - - MetaFuncGraphPtr meta_func_graph() { return meta_func_graph_; } - - AnalysisContextPtr context() const override { return kDummyAnalysisContext; } - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - ScopePtr GetScope() { return scope_; } - - AbstractFunctionPtr Copy() const override { return std::make_shared(meta_func_graph_); } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - MetaFuncGraphPtr meta_func_graph_; - ScopePtr scope_; -}; -using MetaFuncGraphAbstractClosurePtr = std::shared_ptr; - -class PartialAbstractClosure : public AbstractFuncAtom { - public: - // Represents a partial application. - // args_spec_list: The first few arguments of that function - PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list, - const AnfNodePtr &node = nullptr) - : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {} - ~PartialAbstractClosure() override = default; - MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - AbstractFunctionPtr fn() { return fn_; } - AbstractBasePtrList args() { return args_spec_list_; } - AnfNodePtr node() { return node_.lock(); } - void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } - AbstractFunctionPtr Copy() const override { - return std::make_shared(fn_, args_spec_list_, node_.lock()); - } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - AbstractFuncAtomPtr fn_; - AbstractBasePtrList args_spec_list_; - // The CNode which this PartialAbstractClosure evaluated from. - AnfNodeWeakPtr node_; -}; - -class JTransformedAbstractClosure : public AbstractFuncAtom { - public: - // Represents a Function transformed through the application of J. - explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} - ~JTransformedAbstractClosure() override = default; - MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - AbstractFuncAtomPtr fn() { return fn_; } - AbstractFunctionPtr Copy() const override { return std::make_shared(fn_); } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override { return "J(" + fn_->ToString() + ")"; } - - private: - AbstractFuncAtomPtr fn_; -}; - -class VirtualAbstractClosure : public AbstractFuncAtom { - public: - // Represents some function with an explicitly fixed type signature. - // args_spec_list: The arguments as abstract value given to the function - // output: The output which is abstract value. - VirtualAbstractClosure(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output_spec) - : args_spec_list_(args_spec_list), output_(output_spec) {} - VirtualAbstractClosure(const AbstractBasePtr &args_spec, const AbstractBasePtr &output_spec) - : args_spec_list_({args_spec}), output_(output_spec) {} - ~VirtualAbstractClosure() override = default; - MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - AbstractBasePtrList args_spec_list() { return args_spec_list_; } - - AbstractBasePtr output() { return output_; } - AbstractFunctionPtr Copy() const override { - return std::make_shared(args_spec_list_, output_); - } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - AbstractBasePtrList args_spec_list_; - AbstractBasePtr output_; -}; -using VirtualAbstractClosurePtr = std::shared_ptr; - -class TypedPrimitiveAbstractClosure : public AbstractFuncAtom { - public: - // Represents a Primitive with an explicitly fixed type signature. - // args_spec_list: The arguments as abstract value given to the Primitive - // output: The output which is abstract value. - TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_spec_list, - const AbstractBasePtr &output_spec) - : prim_(prim), args_spec_list_(args_spec_list), output_(output_spec) {} - ~TypedPrimitiveAbstractClosure() override = default; - MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) override; - - PrimitivePtr prim() { return prim_; } - AbstractBasePtrList args_spec_list() { return args_spec_list_; } - AbstractBasePtr output() { return output_; } - AbstractFunctionPtr Copy() const override { - return std::make_shared(prim_, args_spec_list_, output_); - } - bool operator==(const AbstractFunction &other) const override; - std::size_t hash() const override; - - std::string ToString() const override; - - private: - PrimitivePtr prim_; - AbstractBasePtrList args_spec_list_; - AbstractBasePtr output_; -}; - -// Represents a function that can't be called. -class DummyAbstractClosure : public AbstractFuncAtom { - public: - DummyAbstractClosure() = default; - ~DummyAbstractClosure() override = default; - MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom) - - EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; } - - AbstractFunctionPtr Copy() const override { return std::make_shared(); } - bool operator==(const AbstractFunction &other) const override; - - std::string ToString() const override { return "DummyAbstractClosure()"; } -}; - -struct AbstractFunctionHasher { - std::size_t operator()(const AbstractFunctionPtr &t) const { - std::size_t hash = t->hash(); - return hash; - } -}; - -struct AbstractFunctionEqual { - bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; } -}; -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_FUNCTION_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc deleted file mode 100644 index b59545e5ae..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc +++ /dev/null @@ -1,1120 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/abstract_value.h" - -#include - -#include "utils/symbolic.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "pipeline/static_analysis/utils.h" - -namespace mindspore { -namespace abstract { -bool AbstractBase::operator==(const AbstractBase &other) const { - if (tid() != other.tid()) { - return false; - } - if (BuildType()->type_id() == kObjectTypeUndeterminedType && - other.BuildType()->type_id() == kObjectTypeUndeterminedType) { - return true; - } - if (value_ == nullptr || other.value_ == nullptr) { - MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: " - << this->ToString() << ", other: " << other.ToString(); - } - - bool value_equal = *value_ == *other.value_; - bool type_equal = *type_ == *other.type_; - bool shape_equal = *shape_ == *other.shape_; - return value_equal && type_equal && shape_equal; -} - -ValuePtr AbstractBase::BuildValue() const { - if (value_ == nullptr) { - return RealBuildValue(); - } - return value_; -} - -AbstractBasePtr AbstractBase::Broaden() const { - AbstractBasePtr clone = Clone(); - clone->set_value(kAnyValue); - clone->set_sparse_grad(sparse_grad_); - return clone; -} - -std::string AbstractBase::ToString() const { - std::ostringstream buffer; - std::string value = std::string("value is null"); - if (value_ != nullptr) { - value = value_->ToString(); - } - MS_EXCEPTION_IF_NULL(type_); - MS_EXCEPTION_IF_NULL(shape_); - buffer << type_name() << "(" - << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() - << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")"; - return buffer.str(); -} - -AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden(); } - -AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { - MS_EXCEPTION_IF_NULL(other); - if (*this == *other) { - auto ret = shared_from_base(); - ret->set_sparse_grad(sparse_grad()); - ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return ret; - } - auto value_self = GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_self); - ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); - TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); - if (res_value == value_self) { - auto ret = shared_from_base(); - ret->set_sparse_grad(sparse_grad()); - ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return ret; - } - auto ret = std::make_shared(res_value, res_type); - ret->set_sparse_grad(sparse_grad()); - ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return ret; -} - -AbstractBasePtr AbstractType::Clone() const { - ValuePtr value_self = GetValueTrack(); - if (value_self == nullptr || !value_self->isa()) { - return nullptr; - } - TypePtr type_self = value_self->cast(); - return std::make_shared(type_self->Clone()); -} - -bool AbstractType::operator==(const AbstractBase &other) const { - if (tid() != other.tid()) { - return false; - } - // Have to compare TypePtr with value; - ValuePtr value_self = GetValueTrack(); - ValuePtr value_other = other.GetValueTrack(); - if (value_self == nullptr || value_other == nullptr) { - MS_LOG(EXCEPTION) << "AbstractType value should not be nullptr. this: " << this->ToString() - << ", other: " << other.ToString(); - } - if (!value_self->isa() || !value_other->isa()) { - return false; - } - TypePtr type_self = value_self->cast(); - TypePtr type_other = value_other->cast(); - bool value_equal = *type_self == *type_other; - return value_equal; -} - -std::string AbstractType::ToString() const { - std::ostringstream buffer; - ValuePtr value_self = GetValueTrack(); - if (value_self == nullptr) { - buffer << "AbstractType value: nullptr"; - return buffer.str(); - } - if (!value_self->isa()) { - buffer << type_name() << "(Value: nullptr)"; - return buffer.str(); - } - TypePtr type_self = value_self->cast(); - MS_EXCEPTION_IF_NULL(type_self); - buffer << type_name() << "(" - << "Value: " << type_self->ToString() << ")"; - return buffer.str(); -} - -std::string AbstractError::ToString() const { - std::ostringstream buffer; - auto value_track = GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - buffer << type_name() << "(" - << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")"; - return buffer.str(); -} - -AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) { - MS_EXCEPTION_IF_NULL(other); - auto other_func = dyn_cast(other); - if (other_func == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); - } - return Join(other_func); -} - -bool AbstractFunction::operator==(const AbstractBase &other) const { - if (!other.isa()) { - return false; - } - const auto &other_func = static_cast(other); - bool value_equal = (*this == other_func); - return value_equal; -} - -const AbstractBasePtr AbstractSequeue::operator[](const std::size_t &dim) const { - if (dim >= size()) { - MS_LOG(EXCEPTION) << "Index [" << dim << "] Out of the size [" << size() << "] of the list."; - } - return elements_[dim]; -} - -std::string AbstractSequeue::ToString() const { - std::ostringstream buffer; - int i = 0; - for (const auto &ele : elements_) { - MS_EXCEPTION_IF_NULL(ele); - buffer << "element[" << i << "]: " << ele->ToString() << ","; - i++; - } - return buffer.str(); -} - -TypePtrList AbstractSequeue::ElementsType() const { - TypePtrList element_type_list; - for (const auto &ele : elements_) { - MS_EXCEPTION_IF_NULL(ele); - TypePtr element_type = ele->BuildType(); - element_type_list.push_back(element_type); - } - return element_type_list; -} - -BaseShapePtrList AbstractSequeue::ElementsShape() const { - BaseShapePtrList element_shape_list; - for (const auto &ele : elements_) { - MS_EXCEPTION_IF_NULL(ele); - BaseShapePtr element_shape = ele->BuildShape(); - element_shape_list.push_back(element_shape); - } - return element_shape_list; -} - -AbstractBasePtrList AbstractSequeue::ElementsClone() const { - AbstractBasePtrList ele_list; - for (const auto &ele : elements_) { - MS_EXCEPTION_IF_NULL(ele); - AbstractBasePtr clone = ele->Clone(); - ele_list.push_back(clone); - } - return ele_list; -} - -AbstractBasePtrList AbstractSequeue::ElementsBroaden() const { - AbstractBasePtrList ele_list; - for (const auto &ele : elements_) { - MS_EXCEPTION_IF_NULL(ele); - AbstractBasePtr broadend = ele->Broaden(); - ele_list.push_back(broadend); - } - return ele_list; -} - -template -ValuePtr AbstractSequeue::ElementsBuildValue() const { - std::vector element_value_list; - for (const auto &ele : elements_) { - ValuePtr element_value = ele->BuildValue(); - if (element_value->isa()) { - return kAnyValue; - } - element_value_list.push_back(element_value); - } - return std::make_shared(element_value_list); -} -template ValuePtr AbstractSequeue::ElementsBuildValue() const; -template ValuePtr AbstractSequeue::ElementsBuildValue() const; - -template -AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &other) { - auto other_sequeue = dyn_cast(other); - if (other_sequeue == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); - } - auto joined_list = AbstractJoin(elements_, other_sequeue->elements_); - bool changes = false; - for (std::size_t i = 0; i < elements_.size(); i++) { - if (elements_[i] != joined_list[i]) { - changes = true; - break; - } - } - if (!changes) { - return shared_from_base(); - } - return std::make_shared(joined_list); -} -template AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &); -template AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &); - -std::size_t AbstractSequeue::hash() const { - std::size_t hash_sum = hash_combine(tid(), std::hash{}(elements_.size())); - // Hashing all elements is costly, so only take at most 4 elements into account based on - // some experiments. - for (size_t i = 0; (i < elements_.size()) && (i < 4); i++) { - hash_sum = hash_combine(hash_sum, elements_[i]->hash()); - } - return hash_sum; -} - -bool AbstractTuple::operator==(const AbstractTuple &other) const { - if (&other == this) { - return true; - } - - if (elements_.size() != other.elements_.size()) { - return false; - } - for (size_t i = 0; i < elements_.size(); i++) { - if (!(*(elements_[i]) == *(other.elements_[i]))) { - return false; - } - } - return true; -} - -bool AbstractTuple::operator==(const AbstractBase &other) const { - if (&other == this) { - return true; - } - - if (other.isa()) { - auto other_tuple = static_cast(&other); - return *this == *other_tuple; - } - - return false; -} - -bool AbstractList::operator==(const AbstractList &other) const { - if (&other == this) { - return true; - } - - if (elements_.size() != other.elements_.size()) { - return false; - } - for (size_t i = 0; i < elements_.size(); i++) { - if (!(*(elements_[i]) == *(other.elements_[i]))) { - return false; - } - } - return true; -} - -bool AbstractList::operator==(const AbstractBase &other) const { - if (&other == this) { - return true; - } - - if (other.isa()) { - auto other_list = static_cast(&other); - return *this == *other_list; - } - return false; -} - -TypePtr AbstractSlice::BuildType() const { - MS_EXCEPTION_IF_NULL(start_); - MS_EXCEPTION_IF_NULL(stop_); - MS_EXCEPTION_IF_NULL(step_); - TypePtr start = start_->BuildType(); - TypePtr stop = stop_->BuildType(); - TypePtr step = step_->BuildType(); - return std::make_shared(start, stop, step); -} - -bool AbstractSlice::operator==(const AbstractSlice &other) const { - if (&other == this) { - return true; - } - return (*start_ == *other.start_ && *stop_ == *other.stop_ && *step_ == *other.step_); -} - -bool AbstractSlice::operator==(const AbstractBase &other) const { - if (&other == this) { - return true; - } - if (!other.isa()) { - return false; - } - auto other_slice = static_cast(&other); - return *this == *other_slice; -} - -AbstractBasePtr AbstractSlice::Clone() const { - MS_EXCEPTION_IF_NULL(start_); - MS_EXCEPTION_IF_NULL(stop_); - MS_EXCEPTION_IF_NULL(step_); - AbstractBasePtr start = start_->Clone(); - AbstractBasePtr stop = stop_->Clone(); - AbstractBasePtr step = step_->Clone(); - return std::make_shared(start, stop, step); -} - -AbstractBasePtr AbstractSlice::Broaden() const { - MS_EXCEPTION_IF_NULL(start_); - MS_EXCEPTION_IF_NULL(stop_); - MS_EXCEPTION_IF_NULL(step_); - AbstractBasePtr start = start_->Broaden(); - AbstractBasePtr stop = stop_->Broaden(); - AbstractBasePtr step = step_->Broaden(); - return std::make_shared(start, stop, step); -} - -std::string AbstractSlice::ToString() const { - std::ostringstream buffer; - buffer << type_name() << "["; - MS_EXCEPTION_IF_NULL(start_); - buffer << start_->ToString() << " : "; - MS_EXCEPTION_IF_NULL(stop_); - buffer << stop_->ToString() << " : "; - MS_EXCEPTION_IF_NULL(step_); - buffer << step_->ToString(); - buffer << "]"; - return buffer.str(); -} - -ValuePtr AbstractSlice::RealBuildValue() const { - MS_EXCEPTION_IF_NULL(start_); - MS_EXCEPTION_IF_NULL(stop_); - MS_EXCEPTION_IF_NULL(step_); - ValuePtr start = start_->BuildValue(); - ValuePtr stop = stop_->BuildValue(); - ValuePtr step = step_->BuildValue(); - if (start->isa() || stop->isa() || step->isa()) { - return kAnyValue; - } - return std::make_shared(start, stop, step); -} - -std::size_t AbstractSlice::hash() const { - MS_EXCEPTION_IF_NULL(start_); - MS_EXCEPTION_IF_NULL(stop_); - MS_EXCEPTION_IF_NULL(step_); - return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); -} - -ShapePtr AbstractUndetermined::shape() const { - auto shp = dyn_cast(GetShapeTrack()); - if (shp == nullptr) { - MS_LOG(EXCEPTION) << "Tensor should have a shape."; - } - return shp; -} - -TypePtr AbstractTensor::BuildType() const { - MS_EXCEPTION_IF_NULL(element_); - TypePtr element_type = element_->BuildType(); - return std::make_shared(element_type); -} - -BaseShapePtr AbstractTensor::BuildShape() const { - auto shape = GetShapeTrack(); - // Guard from using set_shape(nullptr) - if (shape == nullptr) { - return kNoShape; - } - return shape; -} - -AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { - if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) { - auto other_tensor = dyn_cast(other); - auto element = element_->Join(other_tensor->element()); - auto shape = ShapeJoin(this->shape(), other_tensor->shape()); - auto ret = std::make_shared(element, shape); - return ret; - } - auto other_tensor = dyn_cast(other); - if (other_tensor == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); - } - if (*this == *other) { - if (sparse_grad() == other->sparse_grad()) { - return shared_from_base(); - } - } - auto element = element_->Join(other_tensor->element_); - auto shape = ShapeJoin(this->shape(), other_tensor->shape()); - auto ret = std::make_shared(element, shape); - ret->set_sparse_grad(sparse_grad()); - ret->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return ret; -} - -bool AbstractTensor::operator==(const AbstractTensor &other) const { - if (&other == this) { - return true; - } - - auto v1 = GetValueTrack(); - auto v2 = other.GetValueTrack(); - if (v1 == nullptr || v2 == nullptr) { - MS_LOG(EXCEPTION) << "The value of AbstractTensor is nullptr"; - } - - bool is_value_equal = (v1 == v2); - if (v1->isa() && v2->isa()) { - is_value_equal = true; - } - return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal; -} - -bool AbstractTensor::operator==(const AbstractBase &other) const { - if (&other == this) { - return true; - } - - if (other.isa()) { - auto other_tensor = static_cast(&other); - return *this == *other_tensor; - } else { - return false; - } -} - -AbstractBasePtr AbstractTensor::Clone() const { - MS_EXCEPTION_IF_NULL(element_); - auto clone = std::make_shared(element_->Clone()); - ShapePtr shp = shape(); - clone->set_shape(shp->Clone()); - clone->set_value(GetValueTrack()); - clone->set_sparse_grad(sparse_grad()); - clone->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return clone; -} - -AbstractBasePtr AbstractTensor::Broaden() const { - MS_EXCEPTION_IF_NULL(element_); - auto broaden = std::make_shared(element_->Broaden()); - auto shp = shape(); - broaden->set_shape(shp->Clone()); - broaden->set_value(kAnyValue); - broaden->set_sparse_grad(sparse_grad()); - broaden->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return broaden; -} - -AbstractBasePtr AbstractTensor::BroadenWithShape() const { - MS_EXCEPTION_IF_NULL(element_); - auto broaden = std::make_shared(element_->Broaden()); - auto shp = shape()->Clone(); - shp->Broaden(); - broaden->set_shape(shp); - broaden->set_value(kAnyValue); - broaden->set_sparse_grad(sparse_grad()); - broaden->set_has_indexed_slices_grad(has_indexed_slices_grad()); - return broaden; -} - -std::string AbstractTensor::ToString() const { - std::ostringstream buffer; - BaseShapePtr shape_track = GetShapeTrack(); - MS_EXCEPTION_IF_NULL(shape_track); - MS_EXCEPTION_IF_NULL(element_); - auto value_track = GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - buffer << type_name() << "(" - << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() - << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad() - << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")"; - return buffer.str(); -} - -TypePtr AbstractDictionary::BuildType() const { - std::vector> key_values; - for (const auto &item : key_values_) { - MS_EXCEPTION_IF_NULL(item.second); - TypePtr type = item.second->BuildType(); - key_values.emplace_back(item.first, type); - } - return std::make_shared(key_values); -} - -bool AbstractDictionary::operator==(const AbstractDictionary &other) const { - if (key_values_.size() != other.key_values_.size()) { - return false; - } - - for (size_t index = 0; index < key_values_.size(); index++) { - if (key_values_[index].first != other.key_values_[index].first) { - return false; - } - if (!(*key_values_[index].second == *other.key_values_[index].second)) { - return false; - } - } - return true; -} - -bool AbstractDictionary::operator==(const AbstractBase &other) const { - if (&other == this) { - return true; - } - if (other.isa()) { - auto other_class = static_cast(&other); - return *this == *other_class; - } - return false; -} - -AbstractBasePtr AbstractDictionary::Clone() const { - std::vector kv; - (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const AbstractAttribute &item) { - MS_EXCEPTION_IF_NULL(item.second); - return std::make_pair(item.first, item.second->Clone()); - }); - return std::make_shared(kv); -} - -AbstractBasePtr AbstractDictionary::Broaden() const { - std::vector kv; - (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const AbstractAttribute &item) { - MS_EXCEPTION_IF_NULL(item.second); - return std::make_pair(item.first, item.second->Broaden()); - }); - return std::make_shared(kv); -} - -std::string AbstractDictionary::ToString() const { - std::ostringstream buffer; - buffer << type_name() << "{ "; - for (const auto &kv : key_values_) { - MS_EXCEPTION_IF_NULL(kv.second); - buffer << "(" << kv.first << ": " << kv.second->ToString() << ") "; - } - buffer << "}"; - return buffer.str(); -} - -std::size_t AbstractDictionary::hash() const { - std::size_t hash_sum = std::accumulate(key_values_.begin(), key_values_.end(), tid(), - [](std::size_t hash_sum, const AbstractAttribute &item) { - hash_sum = hash_combine(hash_sum, std::hash()(item.first)); - MS_EXCEPTION_IF_NULL(item.second); - hash_sum = hash_combine(hash_sum, item.second->hash()); - return hash_sum; - }); - return hash_sum; -} - -ValuePtr AbstractDictionary::RealBuildValue() const { - std::vector> key_values; - for (const auto &item : key_values_) { - MS_EXCEPTION_IF_NULL(item.second); - auto element_value = item.second->BuildValue(); - MS_EXCEPTION_IF_NULL(element_value); - if (element_value->isa()) { - return kAnyValue; - } - key_values.emplace_back(item.first, element_value); - } - return std::make_shared(key_values); -} - -TypePtr AbstractClass::BuildType() const { - ClassAttrVector attributes_type; - for (auto attr : attributes_) { - MS_EXCEPTION_IF_NULL(attr.second); - TypePtr type = attr.second->BuildType(); - std::pair elem(attr.first, type); - attributes_type.push_back(elem); - } - - return std::make_shared(tag_, attributes_type, methods_); -} - -bool AbstractClass::operator==(const AbstractClass &other) const { - if (!(tag_ == other.tag_)) { - return false; - } - if (attributes_.size() != other.attributes_.size()) { - return false; - } - for (size_t i = 0; i < attributes_.size(); i++) { - MS_EXCEPTION_IF_NULL(attributes_[i].second); - MS_EXCEPTION_IF_NULL(other.attributes_[i].second); - if (!(*attributes_[i].second == *other.attributes_[i].second)) { - MS_LOG(DEBUG) << "attr " << attributes_[i].first << " not equal, arg1:" << attributes_[i].second->ToString() - << " arg2:" << other.attributes_[i].second->ToString(); - return false; - } - } - // method compare; - if (methods_.size() != other.methods_.size()) { - return false; - } - for (const auto &iter : methods_) { - auto iter_other = other.methods_.find(iter.first); - if (iter_other == other.methods_.end()) { - return false; - } - if (!(*iter.second == *iter_other->second)) { - return false; - } - } - return true; -} - -bool AbstractClass::operator==(const AbstractBase &other) const { - if (other.isa()) { - auto other_class = static_cast(&other); - return *this == *other_class; - } - return false; -} - -AbstractBasePtr AbstractClass::GetAttribute(const std::string &name) { - auto it = std::find_if(attributes_.begin(), attributes_.end(), - [name](const AbstractAttribute &pair) -> bool { return pair.first == name; }); - if (it != attributes_.end()) { - return it->second; - } - return nullptr; -} - -ValuePtr AbstractClass::GetMethod(const std::string &name) { - auto method_pair = methods_.find(name); - if (method_pair != methods_.end()) { - return method_pair->second; - } - return kAnyValue; -} - -AbstractBasePtr AbstractClass::Clone() const { - std::vector attributes_clone; - for (auto attr : attributes_) { - MS_EXCEPTION_IF_NULL(attr.second); - AbstractBasePtr clone = attr.second->Clone(); - AbstractAttribute elem(attr.first, clone); - attributes_clone.push_back(elem); - } - return std::make_shared(tag_, attributes_clone, methods_); -} - -AbstractBasePtr AbstractClass::Broaden() const { - std::vector attributes_clone; - for (auto attr : attributes_) { - MS_EXCEPTION_IF_NULL(attr.second); - AbstractBasePtr clone = attr.second->Broaden(); - AbstractAttribute elem(attr.first, clone); - attributes_clone.push_back(elem); - } - return std::make_shared(tag_, attributes_clone, methods_); -} - -std::string AbstractClass::ToString() const { - std::ostringstream buffer; - buffer << type_name() << "(tag: " << tag_ << ") attrs:("; - bool append_comma = false; - for (const auto &attr : attributes_) { - if (append_comma) { - buffer << ", "; - } else { - append_comma = true; - } - MS_EXCEPTION_IF_NULL(attr.second); - buffer << attr.first << ":" << attr.second->ToString(); - } - buffer << ") method:("; - append_comma = false; - for (const auto &iter : methods_) { - if (append_comma) { - buffer << ", "; - } else { - append_comma = true; - } - MS_EXCEPTION_IF_NULL(iter.second); - buffer << iter.first << ":" << iter.second->ToString(); - } - buffer << ")"; - return buffer.str(); -} - -std::size_t AbstractClass::hash() const { - std::size_t hash_sum = std::accumulate(attributes_.begin(), attributes_.end(), hash_combine(tid(), tag_.hash()), - [](std::size_t hash_sum, const AbstractAttribute &item) { - MS_EXCEPTION_IF_NULL(item.second); - return hash_combine(hash_sum, item.second->hash()); - }); - - return hash_sum; -} - -ValuePtr AbstractClass::RealBuildValue() const { - auto cls = BuildType()->cast(); - std::unordered_map attributes_value_map; - for (const auto &attr : attributes_) { - MS_EXCEPTION_IF_NULL(attr.second); - ValuePtr _value = attr.second->BuildValue(); - if (_value->isa()) { - return kAnyValue; - } - attributes_value_map[attr.first] = _value; - } - cls->set_value(attributes_value_map); - return cls; -} - -TypePtr AbstractJTagged::BuildType() const { - MS_EXCEPTION_IF_NULL(element_); - TypePtr subtype = element_->BuildType(); - return std::make_shared(subtype); -} - -AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) { - auto other_jtagged = dyn_cast(other); - if (other_jtagged == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); - } - auto joined_elem = element_->Join(other_jtagged->element_); - return std::make_shared(joined_elem); -} - -bool AbstractJTagged::operator==(const AbstractJTagged &other) const { - MS_EXCEPTION_IF_NULL(element_); - MS_EXCEPTION_IF_NULL(other.element_); - return (*element_ == *other.element_); -} - -bool AbstractJTagged::operator==(const AbstractBase &other) const { - if (other.isa()) { - auto other_jtagged = static_cast(&other); - return *this == *other_jtagged; - } - return false; -} - -std::string AbstractJTagged::ToString() const { - std::ostringstream buffer; - MS_EXCEPTION_IF_NULL(element_); - buffer << type_name() << "(" - << "element: " << element_->ToString() << ")"; - return buffer.str(); -} - -TypePtr AbstractRef::BuildType() const { - TypePtr subtype = ref_->BuildType(); - TypePtr subtype_origin = ref_origin_->BuildType(); - return std::make_shared(subtype, subtype_origin); -} - -bool AbstractRef::operator==(const AbstractRef &other) const { - return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_); -} - -bool AbstractRef::operator==(const AbstractBase &other) const { - if (other.isa()) { - auto other_conf = static_cast(&other); - return *this == *other_conf; - } - return false; -} - -AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { - auto other_ref = other->cast(); - if (other_ref == nullptr) { - MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); - } - if (*this == *other) { - return shared_from_base(); - } - auto ref_key = ref_key_->Join(other_ref->ref_key_); - auto ref = ref_->Join(other_ref->ref()); - auto ref_origin = ref_origin_->Join(other_ref->ref_origin_); - - return std::make_shared(ref_key, ref, ref_origin); -} - -std::string AbstractRef::ToString() const { - std::ostringstream buffer; - buffer << type_name() << "(" - << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString() - << " origin_value: " << ref_origin_->ToString(); - auto value = GetValueTrack(); - if (value) { - buffer << ", value: " << value->ToString(); - } - buffer << ")"; - return buffer.str(); -} - -bool AbstractNone::operator==(const AbstractNone &) const { return true; } - -bool AbstractNone::operator==(const AbstractBase &other) const { - if (other.isa()) { - auto other_none = static_cast(&other); - return *this == *other_none; - } - return false; -} - -std::string AbstractNone::ToString() const { - std::ostringstream buffer; - buffer << type_name() << "(Value: None)"; - return buffer.str(); -} - -ValuePtr AbstractNone::RealBuildValue() const { return kNone; } - -bool AbstractRefKey::operator==(const AbstractRefKey &other) const { - ValuePtr value_self = GetValueTrack(); - ValuePtr value_other = other.GetValueTrack(); - if (value_self != nullptr && value_other != nullptr) { - if (value_self->isa() && value_other->isa()) { - return true; - } - if (!value_self->isa() || !value_other->isa()) { - return false; - } - RefKeyPtr type_self = value_self->cast(); - RefKeyPtr type_other = value_other->cast(); - return *type_self == *type_other; - } else if (value_self != nullptr || value_other != nullptr) { - return false; - } - return true; -} - -bool AbstractRefKey::operator==(const AbstractBase &other) const { - if (other.isa()) { - auto other_confkey = static_cast(&other); - return *this == *other_confkey; - } else { - return false; - } -} - -std::string AbstractRefKey::ToString() const { - std::ostringstream buffer; - buffer << type_name(); - auto value = GetValueTrack(); - if (value) { - buffer << "(value: " << value->ToString() << ")"; - } - return buffer.str(); -} - -bool AbstractNull::operator==(const AbstractNull &) const { return true; } - -bool AbstractNull::operator==(const AbstractBase &other) const { - if (&other == this) { - return true; - } - if (other.isa()) { - auto other_none = static_cast(&other); - return *this == *other_none; - } else { - return false; - } -} - -std::string AbstractNull::ToString() const { - std::ostringstream buffer; - buffer << type_name() << "(Value: Null)"; - return buffer.str(); -} - -bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; } - -bool AbstractEllipsis::operator==(const AbstractBase &other) const { - if (&other == this) { - return true; - } - if (other.isa()) { - auto other_none = static_cast(&other); - return *this == *other_none; - } else { - return false; - } -} - -std::string AbstractEllipsis::ToString() const { - std::ostringstream buffer; - buffer << type_name() << "(Value: Ellipsis)"; - return buffer.str(); -} - -TypePtr AbstractKeywordArg::BuildType() const { - MS_EXCEPTION_IF_NULL(arg_value_); - TypePtr type = arg_value_->BuildType(); - return std::make_shared(arg_name_, type); -} - -AbstractBasePtr AbstractKeywordArg::Clone() const { - MS_EXCEPTION_IF_NULL(arg_value_); - return std::make_shared(arg_name_, arg_value_->Clone()); -} - -AbstractBasePtr AbstractKeywordArg::Broaden() const { - MS_EXCEPTION_IF_NULL(arg_value_); - return std::make_shared(arg_name_, arg_value_->Broaden()); -} - -std::size_t AbstractKeywordArg::hash() const { - MS_EXCEPTION_IF_NULL(arg_value_); - return hash_combine({tid(), std::hash{}(arg_name_), arg_value_->hash()}); -} - -std::string AbstractKeywordArg::ToString() const { - std::ostringstream buffer; - MS_EXCEPTION_IF_NULL(arg_value_); - buffer << type_name() << "("; - buffer << "key : " << arg_name_; - buffer << "value : " << arg_value_->ToString(); - buffer << ")"; - return buffer.str(); -} - -bool AbstractKeywordArg::operator==(const AbstractBase &other) const { - if (&other == this) { - return true; - } - - if (other.isa()) { - auto other_tuple = static_cast(&other); - return *this == *other_tuple; - } - return false; -} - -bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const { - if (&other == this) { - return true; - } - MS_EXCEPTION_IF_NULL(arg_value_); - MS_EXCEPTION_IF_NULL(other.arg_value_); - return other.arg_name_ == arg_name_ && *other.arg_value_ == *arg_value_; -} - -ValuePtr AbstractKeywordArg::RealBuildValue() const { - MS_EXCEPTION_IF_NULL(arg_value_); - ValuePtr value = arg_value_->BuildValue(); - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - return kAnyValue; - } - return std::make_shared(arg_name_, value); -} - -std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) { - std::size_t hash_value = 0; - // Hashing all elements is costly, so only take at most 4 elements into account based on - // some experiments. - for (size_t i = 0; (i < args_spec_list.size()) && (i < 4); i++) { - MS_EXCEPTION_IF_NULL(args_spec_list[i]); - hash_value = hash_combine(hash_value, args_spec_list[i]->hash()); - } - return hash_value; -} - -bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) { - if (lhs.size() != rhs.size()) { - return false; - } - std::size_t size = lhs.size(); - for (std::size_t i = 0; i < size; i++) { - MS_EXCEPTION_IF_NULL(lhs[i]); - MS_EXCEPTION_IF_NULL(rhs[i]); - if (lhs[i] == rhs[i]) { - continue; - } - if (!(*lhs[i] == *rhs[i])) { - return false; - } - } - return true; -} - -std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const { - return AbstractBasePtrListHash(args_spec_list); -} - -bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const { - return AbstractBasePtrListDeepEqual(lhs, rhs); -} - -// IndexedSlices -TypePtr AbstractIndexedSlices::BuildType() const { - MS_EXCEPTION_IF_NULL(element()); - TypePtr element_type = element()->BuildType(); - return std::make_shared(element_type); -} - -AbstractBasePtr AbstractIndexedSlices::Clone() const { - MS_EXCEPTION_IF_NULL(element()); - auto clone = std::make_shared(element()->Clone()); - ShapePtr shp = shape(); - clone->set_shape(shp->Clone()); - clone->set_value(GetValueTrack()); - clone->set_indices(indices_->Clone()->cast()); - clone->set_values(values_->Clone()->cast()); - clone->set_dense_shape(dense_shape_->Clone()->cast()); - return clone; -} - -AbstractBasePtr AbstractIndexedSlices::Broaden() const { - MS_EXCEPTION_IF_NULL(element()); - auto broaden = std::make_shared(element()->Broaden()); - auto shp = shape(); - broaden->set_shape(shp->Clone()); - broaden->set_value(kAnyValue); - broaden->set_indices(indices_->Clone()->cast()); - broaden->set_values(values_->Clone()->cast()); - broaden->set_dense_shape(dense_shape_->Clone()->cast()); - return broaden; -} - -AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const { - MS_EXCEPTION_IF_NULL(element()); - auto broaden = std::make_shared(element()->Broaden()); - auto shp = shape()->Clone(); - shp->Broaden(); - broaden->set_shape(shp); - broaden->set_value(kAnyValue); - broaden->set_indices(indices_->Clone()->cast()); - broaden->set_values(values_->Clone()->cast()); - broaden->set_dense_shape(dense_shape_->Clone()->cast()); - return broaden; -} - -std::string AbstractIndexedSlices::ToString() const { - std::ostringstream buffer; - BaseShapePtr shape_track = GetShapeTrack(); - MS_EXCEPTION_IF_NULL(shape_track); - MS_EXCEPTION_IF_NULL(element()); - auto value_track = GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - buffer << type_name() << "(" - << "shape: " << shape_track->ToString() << ", element: " << element()->ToString() - << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")" - << ", indices: " << indices_->ToString() << ", values" << values_->ToString() - << ", dense_shape: " << dense_shape_->ToString(); - return buffer.str(); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h deleted file mode 100644 index 3981a6eb23..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h +++ /dev/null @@ -1,634 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_ -#define PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_ - -#include -#include -#include -#include -#include - -#include "utils/log_adapter.h" -#include "utils/hashing.h" -#include "ir/base.h" -#include "ir/dtype.h" -#include "ir/value.h" -#include "ir/tensor.h" -#include "pipeline/static_analysis/dshape.h" - -namespace mindspore { -namespace abstract { -class AbstractBase; -using AbstractBasePtrList = std::vector; - -// The base class for abstract value. The abstract value is used in evaluating -// to express the type, shape, and value of the real value. -class AbstractBase : public Base { - public: - explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, - const BaseShapePtr &shape = kNoShape) - : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {} - ~AbstractBase() override = default; - MS_DECLARE_PARENT(AbstractBase, Base) - - std::size_t hash() const override { return tid(); } - std::string ToString() const override; - - virtual bool operator==(const AbstractBase &other) const; - void set_value(const ValuePtr &value) { value_ = value; } - void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; } - void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) { - has_indexed_slices_grad_ = has_indexed_slices_grad; - } - void set_type(const TypePtr &type) { type_ = type; } - void set_shape(const BaseShapePtr &shape) { shape_ = shape; } - void set_value_desc(const std::string &desc) { value_desc_ = desc; } - const std::string &value_desc() const { return value_desc_; } - ValuePtr GetValueTrack() const { return value_; } - const std::string &sparse_grad() const { return sparse_grad_; } - const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; } - TypePtr GetTypeTrack() const { return type_; } - BaseShapePtr GetShapeTrack() const { return shape_; } - - // Try build a real value from an abstract value. If the value cannot be built, - // a default value (AnyValue) is returned. - ValuePtr BuildValue() const; - - virtual TypePtr BuildType() const = 0; - virtual BaseShapePtr BuildShape() const { return kNoShape; } - virtual AbstractBasePtr Clone() const = 0; - virtual AbstractBasePtr Broaden() const; - virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base(); } - - friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &a) { - os << a->ToString(); - return os; - } - - protected: - // default implementation, it can be overwritten by subclass; - virtual ValuePtr RealBuildValue() const { return kAnyValue; } - - private: - ValuePtr value_; - TypePtr type_; - BaseShapePtr shape_; - std::string value_desc_; // store initial value description for error report - std::string sparse_grad_; - bool has_indexed_slices_grad_; -}; - -class AbstractScalar : public AbstractBase { - public: - AbstractScalar() : AbstractBase(kAnyValue, kAnyType) {} - explicit AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {} - explicit AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {} - explicit AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {} - explicit AbstractScalar(float value) : AbstractBase(MakeValue(value), kFloat32) {} - explicit AbstractScalar(double value) : AbstractBase(MakeValue(value), kFloat64) {} - explicit AbstractScalar(bool value) : AbstractBase(MakeValue(value), kBool) {} - explicit AbstractScalar(const std::string &value) : AbstractBase(MakeValue(value), kString) {} - explicit AbstractScalar(const TypePtr &type) : AbstractBase(kAnyValue, type) {} - ~AbstractScalar() override = default; - MS_DECLARE_PARENT(AbstractScalar, AbstractBase) - - std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); } - - TypePtr BuildType() const override { return GetTypeTrack(); } - AbstractBasePtr Clone() const override { - return std::make_shared(GetValueTrack(), GetTypeTrack()->Clone()); - } - AbstractBasePtr Broaden() const override; - AbstractBasePtr Join(const AbstractBasePtr &other) override; -}; -using AbstractScalarPtr = std::shared_ptr; - -class AbstractType : public AbstractBase { - public: - explicit AbstractType(const TypePtr &type) : AbstractBase(type, kTypeType) { - if (type == nullptr) { - MS_LOG(EXCEPTION) << "type is nullptr"; - } - } - ~AbstractType() override = default; - MS_DECLARE_PARENT(AbstractType, AbstractBase) - - std::string ToString() const override; - bool operator==(const AbstractBase &other) const override; - - TypePtr BuildType() const override { return std::make_shared(); } - AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override { return Clone(); } -}; -using AbstractTypePtr = std::shared_ptr; - -class AbstractError : public AbstractBase { - public: - explicit AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) { - if (err == nullptr || node == nullptr) { - MS_LOG(EXCEPTION) << "err or node is nullptr"; - } - } - ~AbstractError() override = default; - MS_DECLARE_PARENT(AbstractError, AbstractBase) - - TypePtr BuildType() const override { return std::make_shared(); } - AbstractBasePtr Broaden() const override { return Clone(); } - - AbstractBasePtr Clone() const override { - return std::make_shared(GetValueTrack()->cast(), node_); - } - - std::string ToString() const override; - - private: - // Origin node been specialized to AbstractError, for debug purpose only. - const AnfNodePtr node_; -}; - -class Evaluator; -using EvaluatorPtr = std::shared_ptr; -class AnalysisEngine; -using AnalysisEnginePtr = std::shared_ptr; - -class AbstractFunction; -using AbstractFunctionPtr = std::shared_ptr; -class AbstractFuncAtom; -using AbstractFuncAtomPtr = std::shared_ptr; -using AbstractFuncAtomPtrList = std::vector; - -class AbstractFunction : public AbstractBase { - public: - AbstractFunction() = default; - ~AbstractFunction() override = default; - MS_DECLARE_PARENT(AbstractFunction, AbstractBase) - - // If there is exactly one possible function, return it. Otherwise, raise an Exception. - // Caller should ensure the uniqueness. - virtual AbstractFunctionPtr GetUnique() = 0; - - TypePtr BuildType() const override { return std::make_shared(); } - AbstractBasePtr Clone() const override { return Copy(); } - // For Function, no need to broaden. - AbstractBasePtr Broaden() const override { - return const_cast(this)->shared_from_base(); - } - virtual AbstractFunctionPtr Copy() const = 0; - - AbstractBasePtr Join(const AbstractBasePtr &other) final; - virtual AbstractFunctionPtr Join(const AbstractFunctionPtr &other) = 0; - - virtual void Visit(std::function) const = 0; - bool operator==(const AbstractBase &other) const final; - virtual bool operator==(const AbstractFunction &other) const = 0; - - static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); - - virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0; - virtual AnfNodePtr tracking_id() const { return nullptr; } - virtual void set_tracking_id(AnfNodePtr) {} - virtual AnalysisContextPtr context() const { return nullptr; } -}; -using AbstractFunctionPtrList = std::vector; - -// Represents a key-value pair used in function's parameters. -class AbstractKeywordArg : public AbstractBase { - public: - AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument) : arg_name_(key), arg_value_(argument) {} - ~AbstractKeywordArg() override = default; - MS_DECLARE_PARENT(AbstractKeywordArg, AbstractBase) - - TypePtr BuildType() const override; - AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; - std::size_t hash() const override; - - bool operator==(const AbstractKeywordArg &other) const; - bool operator==(const AbstractBase &other) const override; - std::string get_key() const { return arg_name_; } - AbstractBasePtr get_arg() const { return arg_value_; } - - std::string ToString() const override; - - protected: - ValuePtr RealBuildValue() const override; - - private: - std::string arg_name_; - AbstractBasePtr arg_value_; -}; -using AbstractKeywordArgPtr = std::shared_ptr; - -class AbstractUndetermined : public AbstractBase { - public: - // shape and type are all unknown - AbstractUndetermined() : AbstractBase(kAnyValue) {} - // only element_ and value, shape track are valid member, type track are unknown. - explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) - : AbstractBase(kAnyValue), element_(element) { - if (element == nullptr) { - MS_LOG(EXCEPTION) << "element is nullptr"; - } - if (element->isa()) { - MS_LOG(EXCEPTION) << "element type error"; - } - set_shape(shape); - } - AbstractUndetermined(const TypePtr &element_type, const std::vector &shape) - : AbstractBase(kAnyValue), element_(std::make_shared(kAnyValue, element_type)) { - if (element_type == nullptr) { - MS_LOG(EXCEPTION) << "element_type is nullptr"; - } - set_shape(std::make_shared(shape)); - } - ~AbstractUndetermined() override = default; - MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase) - TypePtr BuildType() const override { return std::make_shared(); } - AbstractBasePtr Clone() const override { return std::make_shared(); } - const AbstractBasePtr element() const { return element_; } - ShapePtr shape() const; - - protected: - AbstractBasePtr element_; -}; - -class AbstractTensor : public AbstractUndetermined { - public: - // only element_ and value, shape track are valid member, type track are unknown. - explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) - : AbstractUndetermined(element, shape) {} - AbstractTensor(const TypePtr &element_type, const std::vector &shape) - : AbstractUndetermined(element_type, shape) {} - explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {} - ~AbstractTensor() override = default; - MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined) - - TypePtr BuildType() const override; - BaseShapePtr BuildShape() const override; - AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; - AbstractBasePtr BroadenWithShape() const; - AbstractBasePtr Join(const AbstractBasePtr &other) final; - - bool operator==(const AbstractTensor &other) const; - bool operator==(const AbstractBase &other) const override; - - std::string ToString() const override; - std::size_t hash() const override { - auto value = GetValueTrack(); - auto hash_sum = hash_combine(tid(), element_->hash()); - if (value != nullptr) { - auto tensor = value->cast(); - if (tensor != nullptr) { - hash_sum = hash_combine(hash_sum, IntToSize(tensor->DataSize())); - } - } - return hash_sum; - } -}; -using AbstractTensorPtr = std::shared_ptr; -using AbstractTensorPtrList = std::vector; - -class AbstractSequeue : public AbstractBase { - public: - explicit AbstractSequeue(const AbstractBasePtrList &elements) : elements_(elements) {} - ~AbstractSequeue() override = default; - MS_DECLARE_PARENT(AbstractSequeue, AbstractBase) - - TypePtrList ElementsType() const; - BaseShapePtrList ElementsShape() const; - AbstractBasePtrList ElementsClone() const; - AbstractBasePtrList ElementsBroaden() const; - - template - ValuePtr ElementsBuildValue() const; - - template - AbstractBasePtr ElementsJoin(const AbstractBasePtr &other); - - std::size_t size() const { return elements_.size(); } - const AbstractBasePtrList &elements() const { return elements_; } - - std::size_t hash() const override; - std::string ToString() const override; - const AbstractBasePtr operator[](const std::size_t &dim) const; - - protected: - AbstractBasePtrList elements_; -}; -using AbstractSequeuePtr = std::shared_ptr; - -class AbstractTuple : public AbstractSequeue { - public: - explicit AbstractTuple(const AbstractBasePtrList &elements) : AbstractSequeue(elements) {} - - ~AbstractTuple() override = default; - MS_DECLARE_PARENT(AbstractTuple, AbstractSequeue) - - TypePtr BuildType() const override { return std::make_shared(ElementsType()); } - - BaseShapePtr BuildShape() const override { return std::make_shared(ElementsShape()); } - - AbstractBasePtr Clone() const override { return std::make_shared(ElementsClone()); } - - AbstractBasePtr Broaden() const override { return std::make_shared(ElementsBroaden()); } - - AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin(other); } - - std::string ToString() const override { return type_name() + "(" + AbstractSequeue::ToString() + ")"; } - - bool operator==(const AbstractTuple &other) const; - bool operator==(const AbstractBase &other) const override; - - protected: - ValuePtr RealBuildValue() const override { return ElementsBuildValue(); } -}; -using AbstractTuplePtr = std::shared_ptr; - -class AbstractList : public AbstractSequeue { - public: - explicit AbstractList(const AbstractBasePtrList &elements) : AbstractSequeue(elements) {} - - ~AbstractList() override = default; - MS_DECLARE_PARENT(AbstractList, AbstractSequeue) - - TypePtr BuildType() const override { return std::make_shared(ElementsType()); } - - BaseShapePtr BuildShape() const override { return std::make_shared(ElementsShape()); } - - AbstractBasePtr Clone() const override { return std::make_shared(ElementsClone()); } - - AbstractBasePtr Broaden() const override { return std::make_shared(ElementsBroaden()); } - - AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin(other); } - - std::string ToString() const override { return type_name() + "[" + AbstractSequeue::ToString() + "]"; } - - bool operator==(const AbstractList &other) const; - bool operator==(const AbstractBase &other) const override; - - protected: - ValuePtr RealBuildValue() const override { return ElementsBuildValue(); } -}; -using AbstractListPtr = std::shared_ptr; - -class AbstractClass : public AbstractBase { - public: - AbstractClass(const Named &tag, const std::vector &attributes, - const std::unordered_map &methods) - : attributes_(attributes), tag_(tag), methods_(methods) {} - - ~AbstractClass() override = default; - MS_DECLARE_PARENT(AbstractClass, AbstractBase) - - TypePtr BuildType() const override; - bool operator==(const AbstractClass &other) const; - bool operator==(const AbstractBase &other) const override; - const std::vector &attributes() const { return attributes_; } - std::unordered_map methods() { return methods_; } - AbstractBasePtr GetAttribute(const std::string &name); - ValuePtr GetMethod(const std::string &name); - AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; - std::string ToString() const override; - Named tag() const { return tag_; } - std::size_t hash() const override; - - protected: - ValuePtr RealBuildValue() const override; - - private: - std::vector attributes_; - Named tag_; - std::unordered_map methods_; -}; -using AbstractClassPtr = std::shared_ptr; - -class AbstractDictionary : public AbstractBase { - public: - explicit AbstractDictionary(const std::vector &key_values) : key_values_(key_values) {} - ~AbstractDictionary() override = default; - MS_DECLARE_PARENT(AbstractDictionary, AbstractBase) - - TypePtr BuildType() const override; - bool operator==(const AbstractDictionary &other) const; - bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; - std::string ToString() const override; - std::size_t hash() const override; - std::size_t size() const { return key_values_.size(); } - const std::vector &elements() const { return key_values_; } - - std::vector key_values_; - - protected: - ValuePtr RealBuildValue() const override; -}; -using AbstractDictionaryPtr = std::shared_ptr; - -class AbstractSlice : public AbstractBase { - public: - AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step) - : start_(start), stop_(stop), step_(step) {} - ~AbstractSlice() override = default; - MS_DECLARE_PARENT(AbstractSlice, AbstractBase) - - TypePtr BuildType() const override; - bool operator==(const AbstractSlice &other) const; - bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; - std::string ToString() const override; - std::size_t hash() const override; - AbstractBasePtr start() const { return start_; } - AbstractBasePtr stop() const { return stop_; } - AbstractBasePtr step() const { return step_; } - - protected: - ValuePtr RealBuildValue() const override; - - private: - AbstractBasePtr start_; - AbstractBasePtr stop_; - AbstractBasePtr step_; -}; -using AbstractSlicePtr = std::shared_ptr; - -class AbstractJTagged : public AbstractBase { - public: - explicit AbstractJTagged(const AbstractBasePtr &element) : element_(element) {} - - ~AbstractJTagged() override = default; - MS_DECLARE_PARENT(AbstractJTagged, AbstractBase) - - TypePtr BuildType() const override; - AbstractBasePtr Clone() const override { return std::make_shared(element_->Clone()); } - AbstractBasePtr Broaden() const override { return std::make_shared(element_->Broaden()); } - AbstractBasePtr Join(const AbstractBasePtr &other) override; - - bool operator==(const AbstractJTagged &other) const; - bool operator==(const AbstractBase &other) const override; - std::string ToString() const override; - AbstractBasePtr element() { return element_; } - std::size_t hash() const override { return hash_combine(tid(), element_->hash()); } - - private: - AbstractBasePtr element_; -}; -using AbstractJTaggedPtr = std::shared_ptr; - -class AbstractNone : public AbstractBase { - public: - AbstractNone() : AbstractBase() { set_type(std::make_shared()); } - ~AbstractNone() override = default; - MS_DECLARE_PARENT(AbstractNone, AbstractBase) - - TypePtr BuildType() const override { return std::make_shared(); } - bool operator==(const AbstractNone &other) const; - bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override { return std::make_shared(); } - std::string ToString() const override; - - protected: - ValuePtr RealBuildValue() const override; -}; -using AbstractNonePtr = std::shared_ptr; - -// the un assigned state value for variable, which means the variable is not assigned -class AbstractNull : public AbstractBase { - public: - AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared()); } - ~AbstractNull() override = default; - MS_DECLARE_PARENT(AbstractNull, AbstractBase) - - TypePtr BuildType() const override { return std::make_shared(); } - bool operator==(const AbstractNull &other) const; - bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override { return std::make_shared(); } - std::string ToString() const override; -}; -using AbstractNullPtr = std::shared_ptr; - -class AbstractEllipsis : public AbstractBase { - public: - AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared()); } - ~AbstractEllipsis() override = default; - MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) - - TypePtr BuildType() const override { return std::make_shared(); } - bool operator==(const AbstractEllipsis &other) const; - bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override { return std::make_shared(); } - std::string ToString() const override; -}; -using AbstractEllipsisPtr = std::shared_ptr; - -class AbstractRefKey : public AbstractBase { - public: - AbstractRefKey() : AbstractBase() { set_type(std::make_shared()); } - ~AbstractRefKey() override = default; - MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) - - TypePtr BuildType() const override { return std::make_shared(); } - bool operator==(const AbstractRefKey &other) const; - bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override { return std::make_shared(); } - std::string ToString() const override; -}; -using AbstractRefKeyPtr = std::shared_ptr; - -class AbstractRef : public AbstractBase { - public: - AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin) - : ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) { - set_type(std::make_shared()); - } - - ~AbstractRef() override = default; - MS_DECLARE_PARENT(AbstractRef, AbstractBase) - - TypePtr BuildType() const override; - bool operator==(const AbstractRef &other) const; - bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override { - return std::make_shared(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone()); - } - std::string ToString() const override; - AbstractBasePtr ref() { return ref_; } - AbstractBasePtr ref_origin() { return ref_origin_; } - AbstractBasePtr ref_key() { return ref_key_; } - AbstractBasePtr Broaden() const override { - return std::make_shared(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden()); - } - AbstractBasePtr Join(const AbstractBasePtr &other) override; - std::size_t hash() const override { - return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash{}(this->tid()) << 1); - } - - private: - AbstractBasePtr ref_key_; - AbstractBasePtr ref_; - AbstractBasePtr ref_origin_; -}; -using AbstractRefPtr = std::shared_ptr; - -struct AbstractBasePtrListHasher { - std::size_t operator()(const AbstractBasePtrList &args_spec_list) const; -}; - -struct AbstractBasePtrListEqual { - bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const; -}; - -std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); -bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); - -// IndexedSlices -class AbstractIndexedSlices : public AbstractUndetermined { - public: - explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) - : AbstractUndetermined(element, shape) {} - AbstractIndexedSlices(const TypePtr &element_type, const std::vector &shape) - : AbstractUndetermined(element_type, shape) {} - ~AbstractIndexedSlices() override = default; - MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined) - - const AbstractTensorPtr indices() const { return indices_; } - const AbstractTensorPtr values() const { return values_; } - const AbstractTuplePtr dense_shape() const { return dense_shape_; } - void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } - void set_values(const AbstractTensorPtr &values) { values_ = values; } - void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } - TypePtr BuildType() const override; - AbstractBasePtr Clone() const override; - AbstractBasePtr Broaden() const override; - AbstractBasePtr BroadenWithShape() const; - - std::string ToString() const override; - - private: - AbstractTensorPtr indices_; - AbstractTensorPtr values_; - AbstractTuplePtr dense_shape_; -}; -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/analysis_context.cc b/mindspore/ccsrc/pipeline/static_analysis/analysis_context.cc deleted file mode 100644 index 4a43b14168..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/analysis_context.cc +++ /dev/null @@ -1,216 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/analysis_context.h" - -#include - -#include "utils/symbolic.h" -#include "debug/trace.h" - -namespace mindspore { -namespace abstract { -AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, - const AbstractBasePtrList &args_spec_list) { - auto children_context_map_iter = parent->children_cache_.find(fg); - if (children_context_map_iter != parent->children_cache_.end()) { - auto children_context_map = children_context_map_iter->second; - auto children_context_iter = children_context_map.find(args_spec_list); - if (children_context_iter != children_context_map.end()) { - return children_context_iter->second.lock(); - } - } - AnalysisContextPtr context_new = std::make_shared(parent, fg, args_spec_list); - // Reference to myself, so use weak_ptr to break reference cycle. - auto weak_context = std::weak_ptr(context_new); - context_new->parent_cache_[fg] = weak_context; - parent->children_cache_[fg][args_spec_list] = weak_context; - return context_new; -} - -AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph, - const AbstractBasePtrList &args_spec_list) { - FuncGraphPtr graph_parent = func_graph->parent(); - auto iter = parent_cache_.find(graph_parent); - AnalysisContextPtr parent_context = nullptr; - if (iter != parent_cache_.end()) { - parent_context = iter->second.lock(); - } - // if this happen, it will be bug in code. but we raise exception to keep the scene. - if (parent_context == nullptr) { - std::ostringstream oss; - oss << "BUG: cannot found parent_context in current context: " << this->ToString() - << ", func_graph: " << func_graph->ToString() << ", graph_parent: "; - if (graph_parent != nullptr) { - oss << graph_parent->ToString(); - } else { - oss << "nullptr"; - } - MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); - } - return NewContext(parent_context, func_graph, args_spec_list); -} - -AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) { - auto p_iter = parent_cache_.find(func_graph); - AnalysisContextPtr parent_context = nullptr; - if (p_iter != parent_cache_.end()) { - parent_context = p_iter->second.lock(); - } else { - auto iter_parent = parent_cache_.find(func_graph->parent()); - if (iter_parent != parent_cache_.end()) { - parent_context = iter_parent->second.lock(); - } - } - // if this happen, it will be bug in code. but we raise exception to keep the scene. - if (parent_context == nullptr) { - std::ostringstream oss; - oss << "BUG: Filter graph failed: " << func_graph->ToString() << ", graph_parent: "; - if (func_graph->parent() != nullptr) { - oss << func_graph->parent()->ToString(); - } else { - oss << "nullptr"; - } - oss << " parent_cache_: {"; - for (auto iter : parent_cache_) { - if (iter.first == nullptr) { - oss << " [graph: nullptr"; - } else { - oss << " [graph: " << iter.first->ToString(); - } - // iter.second cannot be nullptr even iter.first is nullptr as it will - // always be a Context() object. - oss << ", context: " << iter.second.lock()->ToString() << "]"; - } - oss << "}"; - MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); - } - return parent_context; -} - -AnalysisContextPtr AnalysisContext::DummyContext() { - AnalysisContextPtr dummy_context = std::make_shared(nullptr, nullptr, AbstractBasePtrList()); - dummy_context->parent_cache_[nullptr] = std::weak_ptr(dummy_context); - return dummy_context; -} - -bool AnalysisContext::IsDummyContext() { - if (parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty()) { - return true; - } - return false; -} - -const AnalysisContextPtr kDummyAnalysisContext = - std::make_shared(nullptr, nullptr, AbstractBasePtrList()); - -bool AnalysisContext::operator==(const AnalysisContext &other) const { - if (func_graph_ != other.func_graph_) { - return false; - } - - if (args_spec_list_.size() != other.args_spec_list_.size()) { - return false; - } - - if (((parent_ == nullptr) && (other.parent_ != nullptr)) || ((parent_ != nullptr) && (other.parent_ == nullptr))) { - return false; - } - // Compare parent with content. - bool is_parent_equal = false; - if (parent_ == other.parent_) { - is_parent_equal = true; - } else if (*parent_ == *other.parent_) { - is_parent_equal = true; - } else { - return false; - } - for (std::size_t i = 0; i < args_spec_list_.size(); i++) { - if (!(*args_spec_list_[i] == *other.args_spec_list_[i])) { - return false; - } - } - return is_parent_equal; -} - -// brief The key which controls the graph cloning in Specialize. -// -// Originally, specialize use context directly as the key for cloning graph. The graph will be cloned multiple times -// for different context, which means the graph is called from different node with different arguments and different -// free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what -// graph can be reused. -// The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined -// and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in eval, thus the reused -// graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize. -// The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies -// on correct shape to specialize a tensor constant. -AnalysisContextPtr AnalysisContext::SpecializeKey() const { - AbstractBasePtrList args_broad_shp; - (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(args_broad_shp), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { - if (arg->isa()) { - auto val = arg->GetValueTrack(); - if (val->isa()) { - auto scalar_spec = dyn_cast(arg); - auto ret_spec = scalar_spec->Broaden(); - return ret_spec; - } - } - if (arg->isa()) { - MS_LOG(DEBUG) << "refkey broaden"; - auto arg_spec = dyn_cast(arg); - auto ret_spec = arg_spec->Broaden(); - return ret_spec; - } - return arg; - }); - AnalysisContextPtr context_new = std::make_shared(nullptr, func_graph_, args_broad_shp); - context_new->parent_ = parent_; - return context_new; -} - -std::size_t AnalysisContext::hash() { - std::size_t hash_value = 0; - // hash() recursion exit condition. - if (parent_ != nullptr) { - hash_value = hash_combine(hash_value, parent_->hash()); - } - if (func_graph_ != nullptr) { - hash_value = hash_combine(hash_value, func_graph_->hash()); - } - return hash_value; -} - -std::string AnalysisContext::ToString() const { - std::ostringstream buffer; - buffer << "{"; - if (func_graph_ != nullptr) { - buffer << "Func Graph: " << func_graph_->ToString(); - } - buffer << " Args: "; - int i = 0; - for (const auto &arg : args_spec_list_) { - buffer << "[" << i << "]: " << arg->ToString() << ", "; - i++; - } - if (parent_ != nullptr) { - buffer << "Parent: " << parent_->ToString(); - } - buffer << "}"; - return buffer.str(); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/analysis_context.h b/mindspore/ccsrc/pipeline/static_analysis/analysis_context.h deleted file mode 100644 index c0b3403702..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/analysis_context.h +++ /dev/null @@ -1,88 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_ANALYSIS_CONTEXT_H_ -#define PIPELINE_STATIC_ANALYSIS_ANALYSIS_CONTEXT_H_ - -#include -#include -#include - -#include "pipeline/static_analysis/abstract_value.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -namespace abstract { -class AnalysisContext; -using AnalysisContextWeakPtr = std::weak_ptr; -using ArgsSpecToAnalysisContextMap = - std::unordered_map; - -// AnalysisContext will be stored in Config in AnalysisCache. -class AnalysisContext { - public: - AnalysisContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg, const AbstractBasePtrList &args_spec_list) - : parent_(parent), func_graph_(fg), args_spec_list_(args_spec_list) { - if (parent_ != nullptr) { - parent_cache_ = parent_->parent_cache_; - } - } - - ~AnalysisContext() = default; - - // Helper function to wrapper constructor to save shared_ptr in parent_cache. - AnalysisContextPtr NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, const AbstractBasePtrList &args_spec_list); - - // Extend this context with values for another graph. - AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); - - // Return a context restricted to a graph's dependencies. - AnalysisContextPtr Filter(const FuncGraphPtr &graph); - bool operator==(const AnalysisContext &other) const; - std::size_t hash(); - static AnalysisContextPtr DummyContext(); - bool IsDummyContext(); - FuncGraphPtr func_graph() const { return func_graph_; } - AnalysisContextPtr parent() const { return parent_; } - std::string ToString() const; - AnalysisContextPtr SpecializeKey() const; - AbstractBasePtrList args_spec_list() { return args_spec_list_; } - - private: - AnalysisContextPtr parent_; - FuncGraphPtr func_graph_; - AbstractBasePtrList args_spec_list_; - std::unordered_map parent_cache_; - std::unordered_map children_cache_; -}; - -struct ContextHasher { - std::size_t operator()(const AnalysisContextPtr &t) const { - std::size_t hash = t->hash(); - return hash; - } -}; - -struct ContextEqual { - bool operator()(const AnalysisContextPtr &lhs, const AnalysisContextPtr &rhs) const { return *lhs == *rhs; } -}; - -extern const AnalysisContextPtr kDummyAnalysisContext; -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_ANALYSIS_CONTEXT_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/dshape.cc b/mindspore/ccsrc/pipeline/static_analysis/dshape.cc deleted file mode 100644 index 183ec772ff..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/dshape.cc +++ /dev/null @@ -1,134 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/dshape.h" - -#include -#include - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace abstract { -// used for print BaseShape content -std::ostream &operator<<(std::ostream &os, const BaseShape &bs) { - os << bs.ToString(); - return os; -} - -std::ostream &operator<<(std::ostream &os, const std::shared_ptr bs) { - MS_EXCEPTION_IF_NULL(bs); - os << bs->ToString(); - return os; -} - -bool BaseShape::operator==(const BaseShape &other) const { - if (tid() != other.tid()) { - return false; - } - return true; -} - -bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == other); } - -std::string Shape::ToString() const { - std::ostringstream buffer; - bool f_begin = true; - buffer << "("; - for (auto &x : shape_) { - if (!f_begin) { - buffer << ", "; - } else { - f_begin = false; - } - buffer << x; - } - buffer << ")"; - return buffer.str(); -} - -std::string Shape::DumpText() const { - std::ostringstream buffer; - buffer << "["; - for (size_t i = 0; i < shape_.size(); i++) { - buffer << (i > 0 ? ", " : "") << shape_[i]; - } - buffer << "]"; - return buffer.str(); -} - -bool Shape::operator==(const BaseShape &other) const { - if (tid() != other.tid()) { - return false; - } - return shape_ == static_cast(other).shape_; -} - -const int Shape::SHP_ANY; -void Shape::Broaden() { - for (size_t i = 0; i < shape_.size(); i++) { - shape_[i] = SHP_ANY; - } -} - -std::string SequeueShape::ToString() const { - std::ostringstream buffer; - bool f_begin = true; - for (auto p_shp : p_shapes_) { - if (!f_begin) { - buffer << ", "; - } else { - f_begin = false; - } - MS_EXCEPTION_IF_NULL(p_shp); - buffer << p_shp->ToString(); - } - return buffer.str(); -} - -BaseShapePtrList SequeueShape::ElementsClone() const { - BaseShapePtrList ele_list; - for (auto p_shp : p_shapes_) { - MS_EXCEPTION_IF_NULL(p_shp); - ele_list.push_back(p_shp->Clone()); - } - return ele_list; -} - -template -bool SequeueShape::SequeueEqual(const BaseShape &other) const { - if (tid() != other.tid()) { - return false; - } - auto other_shapes = static_cast(other).p_shapes_; - if (other_shapes.size() != p_shapes_.size()) { - return false; - } - for (unsigned int i = 0; i < p_shapes_.size(); ++i) { - if (!(*p_shapes_[i] == *other_shapes[i])) { - return false; - } - } - return true; -} -template bool SequeueShape::SequeueEqual(const BaseShape &) const; -template bool SequeueShape::SequeueEqual(const BaseShape &) const; - -const std::shared_ptr kNoShape = std::make_shared(); -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/dshape.h b/mindspore/ccsrc/pipeline/static_analysis/dshape.h deleted file mode 100644 index 3e850e309b..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/dshape.h +++ /dev/null @@ -1,135 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_DSHAPE_H_ -#define PIPELINE_STATIC_ANALYSIS_DSHAPE_H_ - -#include -#include -#include -#include -#include -#include - -#include "utils/log_adapter.h" -#include "ir/base.h" - -namespace mindspore { -namespace abstract { -class BaseShape; -using BaseShapePtr = std::shared_ptr; -using BaseShapePtrList = std::vector; - -class BaseShape : public Base { - public: - BaseShape() = default; - ~BaseShape() override = default; - - MS_DECLARE_PARENT(BaseShape, Base) - virtual bool operator==(const BaseShape &other) const; - bool operator!=(const BaseShape &other) const; - std::size_t hash() const override { return tid(); } - - // return a deep copy - virtual BaseShapePtr Clone() const = 0; - virtual void Broaden() {} -}; - -class NoShape : public BaseShape { - public: - MS_DECLARE_PARENT(NoShape, BaseShape) - BaseShapePtr Clone() const override { return std::make_shared(); } - std::string ToString() const override { return type_name(); } -}; -extern const std::shared_ptr kNoShape; - -class Shape : public BaseShape { - public: - static const int SHP_ANY = -1; - Shape() : shape_() {} - Shape(const std::initializer_list &list) : shape_(list) {} - explicit Shape(const std::vector &list) : shape_(list) {} - ~Shape() override = default; - MS_DECLARE_PARENT(Shape, BaseShape) - std::string ToString() const override; - std::string DumpText() const override; - bool operator==(const BaseShape &other) const override; - BaseShapePtr Clone() const override { return std::make_shared(shape_); } - void Broaden() override; - std::vector &shape() { return shape_; } - - std::vector shape_; // use SHP_ANY to implement the any shape in python -}; -using ShapePtr = std::shared_ptr; -using ShapePtrList = std::vector; - -class SequeueShape : public BaseShape { - public: - SequeueShape() : p_shapes_() {} - explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {} - ~SequeueShape() override = default; - MS_DECLARE_PARENT(SequeueShape, BaseShape) - - std::string ToString() const override; - BaseShapePtrList ElementsClone() const; - - template - bool SequeueEqual(const BaseShape &other) const; - - const BaseShapePtrList &shape() const { return p_shapes_; } - size_t size() const { return p_shapes_.size(); } - const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; } - - protected: - BaseShapePtrList p_shapes_; // shape list of each elements -}; -using SequeueShapePtr = std::shared_ptr; - -class TupleShape : public SequeueShape { - public: - TupleShape() : SequeueShape() {} - explicit TupleShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} - ~TupleShape() override = default; - MS_DECLARE_PARENT(TupleShape, SequeueShape) - - std::string ToString() const override { return type_name() + "(" + SequeueShape::ToString() + ")"; } - - BaseShapePtr Clone() const override { return std::make_shared(ElementsClone()); } - - bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } -}; -using TupleShapePtr = std::shared_ptr; - -class ListShape : public SequeueShape { - public: - ListShape() : SequeueShape() {} - explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} - ~ListShape() override = default; - MS_DECLARE_PARENT(ListShape, SequeueShape) - - std::string ToString() const override { return type_name() + "[" + SequeueShape::ToString() + "]"; } - - BaseShapePtr Clone() const override { return std::make_shared(SequeueShape::ElementsClone()); } - - bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } -}; -using ListShapePtr = std::shared_ptr; -} // namespace abstract -} // namespace mindspore - -#endif // PIPELINE_STATIC_ANALYSIS_DSHAPE_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc deleted file mode 100644 index 34ecfc8980..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc +++ /dev/null @@ -1,400 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/evaluator.h" - -#include -#include - -#include "ir/func_graph_cloner.h" -#include "pipeline/static_analysis/utils.h" -#include "debug/trace.h" - -namespace mindspore { -namespace abstract { -namespace { -string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list, - const AnfNodeConfigPtr &out_conf) { - MS_EXCEPTION_IF_NULL(evaluator); - std::stringstream ss; - if (out_conf != nullptr) { - ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name(); - } - for (size_t i = 0; i < arg_spec_list.size(); i++) { - ss << evaluator->ToString() << " input[" << i << "] abstract value: " << arg_spec_list[i]->ToString(); - } - return ss.str(); -} - -void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) { - MS_EXCEPTION_IF_NULL(evaluator); - if (out_conf != nullptr) { - auto node = out_conf->node(); - if (IsValueNode(node)) { - MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope() - << ", with debug info: " << trace::GetDebugInfo(node->debug_info()); - } else { - MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString() - << ", with debug info: " << trace::GetDebugInfo(node->debug_info()); - } - } -} -} // namespace - -AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list) { - AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); - normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list); - FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); - MS_EXCEPTION_IF_NULL(parent_context_); - AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); - return context; -} - -static std::vector FastShadowSort(const AnfNodePtr &ret_node) { - auto current_func_graph = ret_node->func_graph(); - MS_EXCEPTION_IF_NULL(current_func_graph); - - std::vector sorted_nodes; - auto seen = NewSeenGeneration(); - std::size_t index = 0; - sorted_nodes.emplace_back(ret_node); - while (index < sorted_nodes.size()) { - auto current = sorted_nodes[index]; - index++; - MS_EXCEPTION_IF_NULL(current); - if (current->isa()) { - auto &inputs = current->cast()->inputs(); - for (auto it = inputs.begin(); it != inputs.end(); it++) { - AnfNodePtr input = *it; - if (input != nullptr && input->isa() && input->seen_ != seen && - input->func_graph() == current_func_graph) { - sorted_nodes.emplace_back(input); - input->seen_ = seen; - } - } - } - } - return sorted_nodes; -} - -EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { - FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); - MS_EXCEPTION_IF_NULL(fg); - std::size_t nargs = fg->parameters().size(); - if (args_spec_list.size() != nargs) { - MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is " - << fg->parameters().size() << ", but the number of provided arguments is " - << args_spec_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info()); - } - MS_EXCEPTION_IF_NULL(parent_context_); - MS_EXCEPTION_IF_NULL(engine); - graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list); - const auto ¶meters = fg->parameters(); - for (size_t i = 0; i < nargs; i++) { - const auto &arg = args_spec_list[i]; - const auto &node = parameters[i]; - AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); - engine->cache().set_value(conf, std::make_shared(arg, nullptr)); - } - const AnfNodePtr &func_node = fg->get_return(); - - MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString() - << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); - AbstractBasePtr ret_base = nullptr; - std::vector nodes = FastShadowSort(func_node); - for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { - const auto &node = *it; - AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); - MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); - ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); - MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() - << ", abstract: " << ret_base->ToString(); - } - - MS_EXCEPTION_IF_NULL(ret_base); - MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString(); - return std::make_shared(ret_base, nullptr); -} - -AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { - MS_EXCEPTION_IF_NULL(func_graph_); - if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { - AbstractBasePtrList broaded_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(arg); - return arg->Broaden(); - }); - MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) - << ", broaded: " << mindspore::ToString(broaded_list); - return broaded_list; - } - return args_spec_list; -} - -AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { - MS_EXCEPTION_IF_NULL(func_graph_); - if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { - return args_spec_list; - } - if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { - if (parent_context_) { - MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() - << ", context: " << parent_context_->ToString(); - auto last_context = parent_context_->Filter(func_graph_); - if (last_context && last_context->func_graph() == func_graph_) { - MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString(); - MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); - MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); - // Join the last eval arguments and current arguments to check if there are loop variant. - auto joined_args_spec_list = AbstractJoin(args_spec_list, last_context->args_spec_list()); - MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list); - // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. - if (!(joined_args_spec_list == args_spec_list)) { - func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); - MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; - } - return joined_args_spec_list; - } - } - if (trace_.size() != 0) { - MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); - MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); - // Join the last eval arguments and current arguments to check if there are loop variant. - auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back()); - // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. - if (!(joined_args_spec_list == args_spec_list)) { - trace_.push_back(joined_args_spec_list); - func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); - MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; - } - MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); - return joined_args_spec_list; - } else { - trace_.push_back(args_spec_list); - } - } - return args_spec_list; -} - -FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { - auto iter = func_graph_cache_.find(args_spec_list); - FuncGraphPtr ret = nullptr; - if (iter == func_graph_cache_.end()) { - auto fg = func_graph(); - MS_EXCEPTION_IF_NULL(fg); - TraceManager::DebugTrace(std::make_shared(fg->debug_info())); - FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list); - TraceManager::EndTrace(); - func_graph_cache_[args_spec_list] = generated_graph; - MS_EXCEPTION_IF_NULL(engine); - engine->func_graph_manager()->AddFuncGraph(generated_graph); - ret = generated_graph; - } else { - ret = iter->second; - } - - // For the top graph, if it is replaced by generated graph, update the top graph to the new one. - if (parse::Parser::GetTopFuncGraph() == func_graph()) { - if (ret != func_graph()) { - parse::Parser::UpdateTopFuncGraph(ret); - } - } - return ret; -} - -FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { - auto iter = func_graph_cache_.find(args_spec_list); - if (iter != func_graph_cache_.end()) { - return iter->second; - } - - MS_EXCEPTION_IF_NULL(meta_func_graph_); - FuncGraphPtr generated_func_graph = nullptr; - if (this->bound_node() != nullptr) { - TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); - generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); - TraceManager::EndTrace(); - } else { - generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); - } - - FuncGraphPtr cloned_func_graph = BasicClone(generated_func_graph); - func_graph_cache_[args_spec_list] = cloned_func_graph; - MS_EXCEPTION_IF_NULL(engine); - engine->func_graph_manager()->AddFuncGraph(cloned_func_graph); - return cloned_func_graph; -} - -EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { - const std::string &evaluator_name = ToString(); - - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - args_spec_list = NormalizeArgs(args_spec_list); - args_spec_list = BroadenUndeterminedArgs(args_spec_list); - trace::TraceGraphEvalEnter(shared_from_base(), out_conf); - MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base(), args_spec_list, out_conf); - MS_EXCEPTION_IF_NULL(cache_); - auto iter = cache_->find(args_spec_list); - if (iter == cache_->end()) { - MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; - EvalResultPtr ret = Eval(engine, args_spec_list); - if (ret->abstract() == nullptr) { - EvalFailLogging(shared_from_base(), args_spec_list, out_conf); - MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; - } - MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; - (*cache_)[args_spec_list] = ret; - trace::TraceGraphEvalLeave(shared_from_base()); - return ret; - } else { - MS_EXCEPTION_IF_NULL(iter->second); - MS_EXCEPTION_IF_NULL(iter->second->abstract()); - MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << "."; - trace::TraceGraphEvalLeave(shared_from_base()); - return iter->second; - } -} - -EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - EvalResultPtr ret = EvalPrim(engine, args_spec_list); - return ret; -} - -EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - if (args_conf_list.size() == 0) { - MS_LOG(EXCEPTION) << "Size should greater than 0"; - } - EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); - // No need to cache. - return ret; -} - -EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { - EvalResultPtr ret = EvalPrim(args_conf_list); - return ret; -} - -EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); - // Don't lookup from cache, as different out_conf with same node but different context - // may add different entry to anfnode_config_map_, like getattr primitive. - (*cache_)[args_spec_list] = ret; - return ret; -} - -EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - MS_EXCEPTION_IF_NULL(cache_); - auto iter = cache_->find(args_spec_list); - if (iter != cache_->end()) { - return iter->second; - } - - ConfigPtrList partial_args_conf_list; - // Join arguments in partial and the rest arguments from args_conf_list. - (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(partial_args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); - - (*cache_)[args_spec_list] = ret; - return ret; -} - -EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - MS_EXCEPTION_IF_NULL(cache_); - auto iter = cache_->find(args_spec_list); - if (iter != cache_->end()) { - return iter->second; - } - - // Call the original evaluator, get the result: y = f(x) - EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr); - // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input - // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) - AbstractBasePtrList bparams; - bparams.push_back(SensitivityTransform(orig_func_)); - (void)std::transform( - args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), - [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); - AbstractBasePtr bparams_final = std::make_shared(bparams); - AbstractFunctionPtr bprop = - std::make_shared(SensitivityTransform(result->abstract()), bparams_final); - - // J(f)(J(x)) return a tuple (y, bprop_f) - AbstractBasePtrList jargs = {result->abstract(), bprop}; - AbstractBasePtr jtuple = std::make_shared(jargs); - auto infer_reuslt = std::make_shared(jtuple, std::make_shared()); - (*cache_)[args_spec_list] = infer_reuslt; - return infer_reuslt; -} - -EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.size() != args_spec_list_.size()) { - MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() - << ", arguments no: " << args_spec_list.size(); - } - // Check each parameter and argument match; - for (std::size_t i = 0; i < args_spec_list.size(); i++) { - MS_EXCEPTION_IF_NULL(args_spec_list[i]); - (void)args_spec_list[i]->Join(args_spec_list_[i]); - } - return std::make_shared(output_, std::make_shared()); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h deleted file mode 100644 index f6430eda84..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h +++ /dev/null @@ -1,322 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ -#define PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ - -#include -#include -#include -#include - -#include "pipeline/static_analysis/static_analysis.h" - -namespace mindspore { -namespace abstract { -using EvaluatorCacheMap = - std::unordered_map; -using EvaluatorCacheMapPtr = std::shared_ptr; - -using EvaluatorAttrMap = - std::unordered_map; -using EvaluatorAttrMapPtr = std::shared_ptr; - -class Evaluator : public Base { - public: - explicit Evaluator(const std::string &id) - : cache_(std::make_shared()), - attr_cache_(std::make_shared()), - identifier_(id) {} - ~Evaluator() override = default; - MS_DECLARE_PARENT(Evaluator, Base); - - // difference between Run() and Eval(): - // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. - // Run() will modify cache_ member, so it cannot marked as const; - virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); - - virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; - - virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } - - virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { - return args_spec_list; - } - - virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { - auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { - if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { - return true; - } - return false; - }); - if (is_abstract) { - MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result"; - return std::make_shared(std::make_shared(), std::make_shared()); - } - return nullptr; - } - - std::string ToString() const override { return identifier_; } - - virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } - - virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } - - EvaluatorCacheMapPtr &cache() { return cache_; } - EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } - - EvaluatorCacheMapPtr cache_; - EvaluatorAttrMapPtr attr_cache_; - std::string identifier_; - - AnfNodeWeakPtr bound_node_; -}; - -class PrimEvaluator : public Evaluator { - public: - explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} - ~PrimEvaluator() override = default; - MS_DECLARE_PARENT(PrimEvaluator, Evaluator); - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } -}; - -class TrivialPrimEvaluator : public PrimEvaluator { - public: - explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} - ~TrivialPrimEvaluator() override = default; - MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; - virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; -}; - -class TransitionPrimEvaluator : public PrimEvaluator { - public: - explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} - ~TransitionPrimEvaluator() override = default; - MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; - // Parameter in_conf0 : the first element in args_conf_list; - virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; -}; - -class SymbolicPrimEvaluator : public PrimEvaluator { - public: - explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} - ~SymbolicPrimEvaluator() override = default; - MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; - virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; -}; - -// Evaluator will be stored in AnalysisEngine.constructors_ -using EvaluatorPtrList = std::vector; - -class DummyEvaluator : public Evaluator { - public: - DummyEvaluator() : Evaluator("dummy") {} - ~DummyEvaluator() override = default; - MS_DECLARE_PARENT(DummyEvaluator, Evaluator); - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } -}; - -// Wrap another evaluator to track a subset of uses. -// A TrackedEvaluator has its own cache that maps possible calls to -// their results, but is ultimately backed by a different evaluator. -// Multiple TrackedEvaluators can be backed by the same Evaluator. -class TrackedEvaluator : public Evaluator { - public: - explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {} - ~TrackedEvaluator() override = default; - MS_DECLARE_PARENT(TrackedEvaluator, Evaluator); - AnfNodePtr bound_node() const override { - if (sub_evaluator_ != nullptr) { - return sub_evaluator_->bound_node(); - } - return bound_node_.lock(); - } - - void set_bound_node(const AnfNodePtr &node) override { - if (sub_evaluator_ != nullptr) { - sub_evaluator_->set_bound_node(node); - } - bound_node_ = AnfNodeWeakPtr(node); - } - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; - std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } - - private: - EvaluatorPtr sub_evaluator_; -}; - -class BaseFuncGraphEvaluator : public Evaluator { - public: - explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context) - : Evaluator("basegraph"), parent_context_(context) {} - - ~BaseFuncGraphEvaluator() override = default; - MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); - - EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; - - virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; - - AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list); - AnalysisContextPtr graph_context() const { return graph_context_; } - - protected: - AnalysisContextPtr parent_context_; - - private: - AnalysisContextPtr graph_context_; -}; - -class FuncGraphEvaluator : public BaseFuncGraphEvaluator { - public: - FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) - : BaseFuncGraphEvaluator(context->Filter(func_graph)), func_graph_(func_graph) {} - - ~FuncGraphEvaluator() override = default; - MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); - - FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; - - FuncGraphPtr func_graph() { return func_graph_; } - - AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; - AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override; - std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } - - private: - FuncGraphPtr func_graph_; - std::unordered_map - func_graph_cache_; - std::vector trace_; -}; -using FuncGraphEvaluatorPtr = std::shared_ptr; - -class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator { - public: - // Note: context parameter is not used; - MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, AnalysisContextPtr, const ScopePtr &scope) - : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {} - ~MetaFuncGraphEvaluator() override = default; - MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator); - - FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; - - // Return normalized versions of the arguments. - AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { - return meta_func_graph_->NormalizeArgs(args_spec_list); - } - std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); } - - private: - MetaFuncGraphPtr meta_func_graph_; - std::unordered_map - func_graph_cache_; - ScopePtr scope_; -}; - -class PartialAppEvaluator : public Evaluator { - public: - PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args) - : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_spec_list_(args) {} - ~PartialAppEvaluator() override = default; - MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator); - AnfNodePtr bound_node() const override { - if (evaluator_ != nullptr) { - return evaluator_->bound_node(); - } - return bound_node_.lock(); - } - - void set_bound_node(const AnfNodePtr &node) override { - if (evaluator_ != nullptr) { - evaluator_->set_bound_node(node); - } - bound_node_ = AnfNodeWeakPtr(node); - } - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; - } - - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; - std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } - - private: - EvaluatorPtr evaluator_; - AbstractBasePtrList args_spec_list_; -}; - -class VirtualEvaluator : public Evaluator { - public: - VirtualEvaluator(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output) - : Evaluator("virtual"), args_spec_list_(args_spec_list), output_(output) {} - ~VirtualEvaluator() override = default; - MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); - - EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; - std::string ToString() const override { return identifier_; } - - private: - AbstractBasePtrList args_spec_list_; - AbstractBasePtr output_; -}; - -class JEvaluator : public Evaluator { - public: - JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func) - : Evaluator("JEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {} - ~JEvaluator() override = default; - MS_DECLARE_PARENT(JEvaluator, Evaluator); - AnfNodePtr bound_node() const override { - if (evaluator_ != nullptr) { - return evaluator_->bound_node(); - } - return bound_node_.lock(); - } - - void set_bound_node(const AnfNodePtr &node) override { - if (evaluator_ != nullptr) { - evaluator_->set_bound_node(node); - } - bound_node_ = AnfNodeWeakPtr(node); - } - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; - } - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; - std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } - - private: - EvaluatorPtr evaluator_; - AbstractFunctionPtr orig_func_; -}; -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/param_validator.cc b/mindspore/ccsrc/pipeline/static_analysis/param_validator.cc deleted file mode 100644 index 2cbd33c162..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/param_validator.cc +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/param_validator.h" - -#include -#include -#include -#include "utils/symbolic.h" -#include "pipeline/static_analysis/utils.h" - -namespace mindspore { -namespace abstract { -#define ABSTRACT_REPORT_NAME_DEC(abstract) constexpr char ReportNameTraits::name[]; - -ABSTRACT_REPORT_NAME_DEC(Tensor) -ABSTRACT_REPORT_NAME_DEC(Tuple) -ABSTRACT_REPORT_NAME_DEC(Scalar) -ABSTRACT_REPORT_NAME_DEC(List) -ABSTRACT_REPORT_NAME_DEC(Dictionary) -ABSTRACT_REPORT_NAME_DEC(Slice) -ABSTRACT_REPORT_NAME_DEC(Function) -ABSTRACT_REPORT_NAME_DEC(Type) -ABSTRACT_REPORT_NAME_DEC(KeywordArg) -ABSTRACT_REPORT_NAME_DEC(Class) - -TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) { - bool ok = std::any_of(accepts.begin(), accepts.end(), - [type](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(type, accept); }); - if (ok) { - return type; - } else { - MS_LOG(EXCEPTION) << error_message_prefix << accepts << " but is " << type->ToString(); - } -} - -TypePtr CheckTensorDType(const AbstractTensorPtr &tensor, const TypePtrList &accepts, - const std::string &error_message_prefix) { - MS_EXCEPTION_IF_NULL(tensor); - TypePtr type = tensor->BuildType(); - if (!type->isa()) { - MS_LOG(EXCEPTION) << error_message_prefix << "requires Tensor but got " << type->ToString(); - } - TypePtr ele_type = tensor->element()->BuildType(); - if (ele_type == nullptr) { - MS_LOG(EXCEPTION) << "Abstract tensor element type nullptr"; - } - return CheckType(ele_type, accepts, error_message_prefix); -} - -TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const TypePtrList &accepts, - const std::string &error_message_prefix) { - if (tensor_list.empty()) { - MS_LOG(EXCEPTION) << "Array list is empty"; - } - - auto sample_tensor = tensor_list[0]; - MS_EXCEPTION_IF_NULL(sample_tensor); - TypePtr sample_type = sample_tensor->element()->BuildType(); - std::ostringstream loginfoBuffer; - loginfoBuffer << "same type, got"; - // Check if other elements have the same type with the first element. - for (size_t index = 1; index < tensor_list.size(); ++index) { - MS_EXCEPTION_IF_NULL(tensor_list[index]); - auto aType = tensor_list[index]->element()->BuildType(); - loginfoBuffer << " " << aType->ToString(); - if (sample_type->type_id() != aType->type_id()) { - MS_LOG(EXCEPTION) << "Expected type " << sample_type->ToString() << ", but got " << aType->ToString() - << ", index " << index; - } - } - MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str(); - return CheckTensorDType(sample_tensor, accepts, error_message_prefix); -} - -TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts, - const std::string &error_message_prefix) { - if (scalar == nullptr) { - MS_LOG(EXCEPTION) << "Scalar nullptr"; - } - auto type = scalar->BuildType(); - if (type == nullptr) { - MS_LOG(EXCEPTION) << "Scalar value nullptr"; - } - - return CheckType(type, accepts, error_message_prefix); -} - -ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) { - ShapePtr shape_base = tensor_base->shape(); - ShapePtr shape = tensor->shape(); - if (*shape != *shape_base) { - MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << tensor->shape()->ToString() - << " are not consistent with second arg shape " << tensor_base->shape()->ToString(); - } - return shape_base; -} - -TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) { - TypePtr type_base = tensor_base->element()->BuildType(); - TypePtr type = tensor->element()->BuildType(); - if (*type != *type_base) { - MS_LOG(EXCEPTION) << op << " evaluator first arg dtype " << type_base->ToString() - << " are not consistent with second arg dtype " << type->ToString(); - } - return type_base; -} - -int CheckAxis(const std::string &op, const ValuePtr &axis, int minimum, int max) { - if (axis == nullptr) { - MS_LOG(EXCEPTION) << op << " evaluator axis is null"; - } - if (!axis->isa()) { - MS_LOG(EXCEPTION) << op << " evaluator axis should be int, but got " << axis->type_name(); - } - int axis_value = GetValue(axis); - if (axis_value > max || axis_value < minimum) { - MS_LOG(EXCEPTION) << op << " evaluator axis value should be in the range [" << minimum << ", " << max - << "], but get " << axis_value; - } - return axis_value; -} -void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list, - size_t size_expect) { - if (args_spec_list.size() != size_expect) { - MS_LOG(EXCEPTION) << op << " input args size should be " << size_expect << ", but got " << args_spec_list.size(); - } - - for (size_t i = 0; i < size_expect; i++) { - MS_EXCEPTION_IF_NULL(args_spec_list[i]); - } -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/param_validator.h b/mindspore/ccsrc/pipeline/static_analysis/param_validator.h deleted file mode 100644 index daa436d66d..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/param_validator.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_PARAM_VALIDATOR_H_ -#define PIPELINE_STATIC_ANALYSIS_PARAM_VALIDATOR_H_ - -#include -#include -#include -#include -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/static_analysis/utils.h" -#include "utils/any.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace abstract { -// check if variable's type is an instance of any of accepts or of a subclass of it. -TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix); - -TypePtr CheckTensorDType(const AbstractTensorPtr &tensor, const TypePtrList &accepts, - const std::string &error_message_prefix); - -TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const TypePtrList &accepts, - const std::string &error_message_prefix); - -TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts, - const std::string &error_message_prefix); - -ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor); - -TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor); - -int CheckAxis(const std::string &op, const ValuePtr &axis, int min, int max); - -void CheckArgsSize(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t size_expect); - -template -struct ReportNameTraits {}; - -#define ABSTRACT_REPORT_NAME_TRAITS(abstract) \ - template <> \ - struct ReportNameTraits { \ - static constexpr char name[] = #abstract; \ - }; -ABSTRACT_REPORT_NAME_TRAITS(Tensor) -ABSTRACT_REPORT_NAME_TRAITS(Tuple) -ABSTRACT_REPORT_NAME_TRAITS(Scalar) -ABSTRACT_REPORT_NAME_TRAITS(List) -ABSTRACT_REPORT_NAME_TRAITS(Dictionary) -ABSTRACT_REPORT_NAME_TRAITS(Slice) -ABSTRACT_REPORT_NAME_TRAITS(Function) -ABSTRACT_REPORT_NAME_TRAITS(Type) -ABSTRACT_REPORT_NAME_TRAITS(KeywordArg) -ABSTRACT_REPORT_NAME_TRAITS(Class) -ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices) -ABSTRACT_REPORT_NAME_TRAITS(Sequeue) - -template -std::shared_ptr CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) { - if (index >= args_spec_list.size()) { - MS_EXCEPTION(ValueError) << op << " evaluator args list index out of bound, size " << args_spec_list.size() - << ", index " << index; - } - auto arg = dyn_cast(args_spec_list[index]); - if (arg == nullptr) { - MS_EXCEPTION(TypeError) << "Operator " << op << " input[" << index << "] should be " << ReportNameTraits::name - << ", but got " << args_spec_list[index]->BuildType()->ToString() << "."; - } - return arg; -} - -// check if each element in args_spec is type T, and can be joined. -template -void CheckArgsSpec(const AbstractBasePtrList &args_list) { - for (const auto &arg : args_list) { - if (!arg->isa()) { - MS_EXCEPTION(TypeError) << "Expected type " << ReportNameTraits::name << ", but got " - << arg->BuildType()->ToString() << "."; - } - } - (void)AbstractJoin(args_list); -} -} // namespace abstract -} // namespace mindspore - -#endif // PIPELINE_STATIC_ANALYSIS_PARAM_VALIDATOR_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc deleted file mode 100644 index 99dc085989..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ /dev/null @@ -1,1394 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/prim.h" - -#include -#include -#include -#include -#include -#include - -#include "operator/cc_implementations.h" -#include "operator/ops.h" -#include "operator/composite/do_signature.h" -#include "operator/prim_to_function.h" -#include "pipeline/static_analysis/utils.h" -#include "utils/symbolic.h" -#include "./common.h" -#include "pipeline/resource.h" -#include "pipeline/parse/resolve.h" -#include "ir/tensor.h" -#include "utils/convert_utils.h" -#include "utils/context/ms_context.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/static_analysis/param_validator.h" -#include "common/utils.h" - -namespace mindspore { -namespace abstract { -PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { - static PrimitiveEvalImplMap prim_eval_implement_map = { - // Statements - {prim::kPrimReturn, {InferImplReturn, true}}, - {prim::kPrimTypeOf, {InferImplTypeof, false}}, - {prim::kPrimHasType, {InferImplHasType, false}}, - {prim::kPrimDot, {InferImplDot, true}}, - {prim::kPrimSwitch, {InferImplSwitch, true}}, - {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, - {prim::kPrimIs_, {InferImplIs_, true}}, - {prim::kPrimIsNot, {InferImplIsNot, true}}, - {prim::kPrimInDict, {InferImplInDict, true}}, - {prim::kPrimNotInDict, {InferImplNotInDict, true}}, - {prim::kPrimIsConsant, {InferImplIsConstant, true}}, - // Maths - {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, - // Array - {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, - {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, - {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, - {prim::kPrimShape, {InferImplShape, true}}, - {prim::kPrimPack, {InferImplPack, true}}, - // Structure - {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, - {prim::kPrimMakeList, {InferImplMakeList, true}}, - {prim::kPrimMakeDict, {InferImplMakeDict, true}}, - {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, - {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, - {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, - {prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, - {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, - {prim::kPrimListGetItem, {InferImplListGetItem, true}}, - {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, - {prim::kPrimListSetItem, {InferImplListSetItem, true}}, - {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, - {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, - {prim::kPrimListAppend, {InferImplListAppend, true}}, - {prim::kPrimTupleLen, {InferImplTupleLen, true}}, - {prim::kPrimListLen, {InferImplListLen, true}}, - {prim::kPrimArrayLen, {InferImplArrayLen, true}}, - {prim::kPrimListMap, {InferImplListMap, false}}, - {prim::kPrimListReduce, {InferImplListReduce, false}}, - {prim::kPrimTupleReversed, {InferImplTupleReversed, false}}, - {prim::kPrimReducedShape, {InferImplReduceShape, false}}, - {prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, - {prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, - {prim::kPrimShapeMul, {InferImplShapeMul, false}}, - {prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, - {prim::kPrimListEqual, {InferImplListEqual, false}}, - {prim::kPrimMakeRange, {InferImplMakeRange, false}}, - {prim::kPrimStopGradient, {InferImplStopGradient, false}}, - {prim::kPrimStringEqual, {InferImplStringEqual, false}}, - {prim::kPrimStringConcat, {InferImplStringConcat, false}}, - {prim::kPrimDictLen, {InferImplDictLen, false}}, - // NN - {prim::kPrimPooling, {InferImplPooling, true}}, - {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, - {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, - {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, - {prim::kPrimReluGrad, {InferImplReluGrad, true}}, - {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, - {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, - {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, - {prim::kPrimRelu, {InferImplRelu, true}}, - {prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, - {prim::kPrimZerosLike, {InferImplZerosLike, true}}, - {prim::kPrimBpropCut, {InferImplBpropCut, true}}, - {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, - {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, - {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, - // Others - {prim::kPrimIdentity, {InferImplIdentity, true}}, - // Set impl to null as it will use PartialEvaluator; - {prim::kPrimPartial, {nullptr, true}}, - {prim::kPrimJ, {InferImplJ, false}}, - {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, - {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, - {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, - {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, - {prim::kPrimMakeRef, {InferImplMakeRef, true}}, - {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, - {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, - {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, - {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, - {prim::kPrimDepend, {InferImplDepend, true}}, - {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, - {prim::kPrimControlDepend, {InferImplControlDepend, true}}, - // Debug - {prim::kPrimDebug, {InferImplDebug, true}}, - // IndexedSlices - {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}}, - {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}}, - {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}}, - {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}}, - {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}}, - }; - return prim_eval_implement_map; -} - -using mindspore::parse::PyObjectWrapper; - -EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { - auto ret_abstract = AbstractEval(args); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; - return ret_abstract; - } - } - prim_->BeginRecordAddAttr(); - AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); - prim_->EndRecordAddAttr(); - auto added_attrs = prim_->evaluate_added_attrs(); - auto infer_result = std::make_shared(abs_base, std::make_shared(added_attrs)); - return infer_result; -} - -EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - if (out_conf->node() == nullptr || !out_conf->node()->isa()) { - MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; - } - - auto do_signature = dyn_cast(prim_); - auto out_node = dyn_cast(out_conf->node()); - const auto &out_node_inputs = out_node->inputs(); - if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { - MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString() - << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() - << ", inputs size " << out_node_inputs.size(); - } - AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; - - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); - - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } - ScopeGuard scope_guard(scope); - - AnfNodePtr new_cnode = nullptr; - if (bound_node() != nullptr) { - TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); - new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, - args_inputs); - TraceManager::EndTrace(); - } else { - new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, - args_inputs); - } - AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context()); - - return engine->ForwardConfig(out_conf, fn_conf); -} - -static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) { - // arg[0] is the func graph to unpack, ignore it - AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end()); - AbstractBasePtrList graph_specialize_args; - if (need_unpack) { - for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) { - MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]); - if (specialize_args_before_unpack[index]->isa()) { - AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast(); - std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), - std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; }); - } else if (specialize_args_before_unpack[index]->isa()) { - AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast(); - auto dict_elems = arg_dict->elements(); - (void)std::transform( - dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args), - [](const AbstractAttribute &item) { return std::make_shared(item.first, item.second); }); - } else { - MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got " - << specialize_args_before_unpack[index]->ToString(); - } - } - } else { - graph_specialize_args = specialize_args_before_unpack; - } - return graph_specialize_args; -} - -EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - if (out_conf->node() == nullptr || !out_conf->node()->isa()) { - MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; - } - - auto unpack_graph = prim_->cast(); - auto out_node = out_conf->node()->cast(); - const auto &out_node_inputs = out_node->inputs(); - if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { - MS_LOG(EXCEPTION) << "UnpackGraphPrimitive" - << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() - << ", inputs size " << out_node_inputs.size(); - } - AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); - // get the forward graph - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - AbstractFunctionPtr fn = args_spec_list[0]->cast(); - if (fn == nullptr) { - MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); - } - auto real_fn = fn->cast(); - MS_EXCEPTION_IF_NULL(real_fn); - FuncGraphPtr forward_graph = real_fn->func_graph(); - MS_EXCEPTION_IF_NULL(forward_graph); - AbstractBasePtrList graph_specialize_args = - GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args()); - - AbstractBasePtrList graph_specialize_args_without_sens; - (void)std::transform(graph_specialize_args.begin(), - graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0), - std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; }); - auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens); - engine->func_graph_manager()->AddFuncGraph(new_graph); - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } - ScopeGuard scope_guard(scope); - AnfNodePtr new_vnode = NewValueNode(new_graph); - AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context()); - - return engine->ForwardConfig(out_conf, fn_conf); -} - -AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type, - FuncGraphPtr func_graph) { - AnfNodePtr target_node = source_node; - if (node_type->isa()) { - auto x = node_type->cast(); - if (x->element()->BuildType()->isa()) { - auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); - MS_EXCEPTION_IF_NULL(cast); - target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type}); - } - } else if (node_type->isa()) { - auto x = node_type->cast(); - auto &items = x->elements(); - std::vector nodes; - nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - int idx = 0; - for (const auto &item : items) { - AnfNodePtr tuple_node = - func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)}); - AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph); - nodes.emplace_back(node); - ++idx; - } - target_node = func_graph->NewCNode(nodes); - } else if (node_type->isa()) { - auto x = node_type->cast(); - auto &items = x->elements(); - std::vector dict_key_nodes; - std::vector dict_value_nodes; - dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - for (const auto &item : items) { - AnfNodePtr dict_value_node = - func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)}); - AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph); - dict_key_nodes.emplace_back(NewValueNode(item.first)); - dict_value_nodes.emplace_back(node); - } - target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes), - func_graph->NewCNode(dict_value_nodes)}); - } - return target_node; -} - -EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - if (out_conf->node() == nullptr || !out_conf->node()->isa()) { - MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; - } - auto out_node = out_conf->node()->cast(); - const auto &out_node_inputs = out_node->inputs(); - if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { - MS_LOG(EXCEPTION) << "MixedPrecisionCast" - << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() - << ", inputs size " << out_node_inputs.size(); - } - AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); - - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } - ScopeGuard scope_guard(scope); - - FuncGraphPtr func_graph = out_conf->node()->func_graph(); - AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph); - AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context()); - - return engine->ForwardConfig(out_conf, fn_conf); -} - -namespace { -py::object BuildValue(const ValuePtr &value_ptr) { - if (value_ptr == nullptr) { - return py::none(); - } else { - return ValuePtrToPyData(value_ptr); - } -} -} // end anonymous namespace - -py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { - MS_EXCEPTION_IF_NULL(abs_base); - py::dict dic; - if (abs_base->isa()) { - auto arg_tensor = dyn_cast(abs_base); - dic["shape"] = arg_tensor->shape()->shape(); - dic["dtype"] = arg_tensor->BuildType(); - dic["value"] = BuildValue(arg_tensor->BuildValue()); - } else if (abs_base->isa() || abs_base->isa() || abs_base->isa()) { - std::vector shape; - dic["shape"] = shape; - dic["dtype"] = abs_base->BuildType(); - dic["value"] = BuildValue(abs_base->BuildValue()); - } else if (abs_base->isa()) { - auto arg_slice = dyn_cast(abs_base); - std::vector shape; - dic["shape"] = shape; - dic["dtype"] = arg_slice->BuildType(); - dic["value"] = BuildValue(arg_slice->BuildValue()); - } else if (abs_base->isa()) { - auto value = abs_base->cast()->ref(); - dic = ConvertAbstractToPython(value); - } else if (abs_base->isa()) { - dic["shape"] = py::none(); - dic["dtype"] = py::ellipsis(); - dic["value"] = py::ellipsis(); - } else if (abs_base->isa()) { - auto arg_tuple = dyn_cast(abs_base); - size_t len = arg_tuple->size(); - py::tuple shape_tuple(len); - py::tuple dtype_tuple(len); - - for (size_t i = 0; i < len; i++) { - py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]); - shape_tuple[i] = out["shape"]; - dtype_tuple[i] = out["dtype"]; - } - dic["shape"] = shape_tuple; - dic["dtype"] = dtype_tuple; - dic["value"] = BuildValue(arg_tuple->BuildValue()); - } else if (abs_base->isa()) { - auto arg_list = dyn_cast(abs_base); - size_t len = arg_list->size(); - py::list shape_list(len); - py::list dtype_list(len); - - for (size_t i = 0; i < len; i++) { - py::dict out = ConvertAbstractToPython(arg_list->elements()[i]); - shape_list[i] = out["shape"]; - dtype_list[i] = out["dtype"]; - } - dic["shape"] = shape_list; - dic["dtype"] = dtype_list; - dic["value"] = BuildValue(arg_list->BuildValue()); - } else if (abs_base->isa()) { - dic["shape"] = py::none(); - dic["dtype"] = py::none(); - dic["value"] = py::none(); - } else if (abs_base->isa()) { - dic["shape"] = py::none(); - dic["dtype"] = abs_base->BuildType(); - dic["value"] = py::none(); - } else { - auto value = abs_base->BuildValue(); - if ((*value == *kAnyValue)) { - auto value_desc = abs_base->value_desc(); - MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc) - << " for python primitive." << abs_base->ToString(); - } - MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is " - << value->ToString(); - } - return dic; -} - -namespace { -py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrList &args) { - const AbstractBasePtrList *args_ptr; - - if (prim_py->is_tuple_input_) { - if (args.empty()) { - MS_LOG(EXCEPTION) << "Primitive args is empty"; - } - if (args[0] == nullptr || !args[0]->isa()) { - MS_LOG(EXCEPTION) << "Custom Primitive inputs should be packed into a Tuple after converting" - "prim convert pass for GE."; - } - args_ptr = &(args[0]->cast()->elements()); - } else { - args_ptr = &args; - } - - py::tuple py_args(args_ptr->size()); - for (size_t i = 0; i < args_ptr->size(); i++) { - auto arg_i = (*args_ptr)[i]; - py_args[i] = ConvertAbstractToPython(arg_i); - } - return py_args; -} - -AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) { - // Convert to AbstractValue based on type and shape - if (output["value"].is_none()) { - auto out_shape = output["shape"]; - auto out_dtype = output["dtype"]; - return PyListDtype2AbstractTensor(out_shape, out_dtype); - } - // Convert pyobject to Value, then to AbstractValue - ValuePtr converted_ret = nullptr; - bool converted = parse::ConvertData(output["value"], &converted_ret); - if (!converted) { - MS_LOG(EXCEPTION) << "Convert data failed"; - } - auto res_spec = FromValue(converted_ret); - MS_EXCEPTION_IF_NULL(res_spec); - if (res_spec->isa()) { - // Replace to tensor constant node in specialize - auto res_tensor = res_spec->cast(); - res_tensor->set_value(converted_ret); - } - if (prim_py->IsCustomPrim()) { - // Raise error if output_num is not match the infer result. - int output_num = GetValue(prim_py->GetAttr("output_num")); - if (res_spec->isa() && output_num != 1) { - MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num - << " not matches the infer result."; - } else if (res_spec->isa() && - (res_spec->cast()->size() != IntToSize(output_num))) { - MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num - << " not matches the infer result."; - } - } - return res_spec; -} -} // end anonymous namespace - -EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag) { - auto ret_abstract = AbstractEval(args); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined"; - return ret_abstract; - } - } - MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString(); - - const auto &iter = cache_->find(args); - if (iter != cache_->end()) { - return iter->second; - } - auto py_args = PreparePyInputs(prim_py_, args); - - auto pyobj = prim_py_->GetPyObj(); - if (pyobj == nullptr) { - MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; - } - auto infer_fuc = pyobj.attr("__infer__"); - prim_py_->BeginRecordAddAttr(); - py::dict output = infer_fuc(*py_args); - prim_py_->EndRecordAddAttr(); - auto added_attrs = prim_py_->evaluate_added_attrs(); - MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); - auto res_spec = PyInferRes2Abstract(prim_py_, output); - - MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << "."; - auto infer_result = std::make_shared(res_spec, std::make_shared(added_attrs)); - (*cache_)[args] = infer_result; - return infer_result; -} - -EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag) { - auto ret_abstract = AbstractEval(args); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined"; - return ret_abstract; - } - } - // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type. - if (nargs_ != args.size()) { - MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs"; - return nullptr; - } - TypePtr ret_value_type = return_value_type_; - ValuePtrList value_list; - for (const auto &arg : args) { - // Check if all arguments are scalar type. - MS_EXCEPTION_IF_NULL(arg); - if (arg->isa()) { - auto arg_scalar = dyn_cast(arg); - auto arg_value = arg_scalar->GetValueTrack(); - value_list.push_back(arg_value); - } else { - // Raise TypeError Expected Scalar. - MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives."; - } - } - for (const auto &item : type_map_) { - TypePtrList selections; - MS_EXCEPTION_IF_NULL(item.second); - (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections), - [&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); }); - TypePtr res = CheckTypeList(item.first, selections); - if (*return_value_type_ == *(item.first)) { - ret_value_type = res; - } - } - - ValuePtr evaluated_value = RunImpl(value_list); - if (!(*evaluated_value == *kAnyValue)) { - ret_value_type = evaluated_value->type(); - } - // for comparison primitives , return type shall have be specified to be bool. - if (specify_out_type_ != nullptr) { - ret_value_type = specify_out_type_; - } - - AbstractScalarPtr abs_base = std::make_shared(evaluated_value, ret_value_type); - return std::make_shared(abs_base, std::make_shared()); -} - -ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { - if (!eval_value_) { - return kAnyValue; - } else { - if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) { - MS_EXCEPTION_IF_NULL(arg); - return arg->isa(); - })) { - return kAnyValue; - } - return impl_(args); - } -} - -// Primitive implementation -// static function start -namespace { -EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) { - EvaluatorPtr prim_evaluator = std::make_shared(primitive, eval_impl); - return prim_evaluator; -} - -EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value, - const TypePtr &specify_out_type) { - FunctionPtr func = nullptr; - (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func); - MS_EXCEPTION_IF_NULL(func); - - EvaluatorPtr uniform_primitive_evaluator = - std::make_shared(func, prim_impl, eval_value, specify_out_type); - return uniform_primitive_evaluator; -} - -const int kResolveCaseUserDefineClass = 1; -const int kResolveCaseBuildinTypeMethod = 2; -const int kResolveCaseFunction = 3; -int GetResolveCase(const TypePtr &data_type) { - MS_EXCEPTION_IF_NULL(data_type); - if (data_type->type_id() == kObjectTypeClass) { - return kResolveCaseUserDefineClass; - } - - // try method map, if not in method map, the data_type should be External type. - if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) { - return kResolveCaseBuildinTypeMethod; - } - - return kResolveCaseFunction; -} - -FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) { - MS_EXCEPTION_IF_NULL(engine); - MS_EXCEPTION_IF_NULL(method); - if (!method->isa()) { - MS_LOG(EXCEPTION) << "Method type error: " << method->ToString(); - } - - std::shared_ptr obj = method->cast>(); - FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj()); - if (func_graph == nullptr) { - MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed"; - } - - FuncGraphManagerPtr manager = engine->func_graph_manager(); - manager->AddFuncGraph(func_graph); - return func_graph; -} - -inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) { - MS_EXCEPTION_IF_NULL(engine); - FuncGraphManagerPtr manager = engine->func_graph_manager(); - manager->AddFuncGraph(func_graph); -} - -EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &old_conf) { - MS_EXCEPTION_IF_NULL(old_conf); - - AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); - AbstractFunctionPtr abs_func = dyn_cast(abs_ptr); - MS_EXCEPTION_IF_NULL(abs_func); - - // Create new cnode - std::vector input = {NewValueNode(prim::kPrimPartial)}; - auto func_graph_func = dyn_cast(abs_func); - if (func_graph_func != nullptr) { - FuncGraphPtr fg = func_graph_func->func_graph(); - input.push_back(NewValueNode(fg)); - } else { - auto prim_func = dyn_cast(abs_func); - MS_EXCEPTION_IF_NULL(prim_func); - PrimitivePtr prim = prim_func->prim(); - input.push_back(NewValueNode(prim)); - } - - AnfNodeConfigPtr conf = dyn_cast(data_conf); - MS_EXCEPTION_IF_NULL(conf); - input.push_back(conf->node()); - MS_EXCEPTION_IF_NULL(old_conf); - FuncGraphPtr func_graph = old_conf->node()->func_graph(); - CNodePtr new_cnode = func_graph->NewCNode(input); - AnalysisEnginePtr eng = old_conf->engine(); - AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context()); - return eng->ForwardConfig(old_conf, fn_conf); -} - -EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list, - const AnfNodeConfigPtr &out_conf) { - // args_spec_list: same as StaticGetter - if (args_spec_list.size() < 2) { - MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2"; - } - MS_EXCEPTION_IF_NULL(out_conf); - // An external type. - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString(); - MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString(); - auto data_v = args_spec_list[0]->BuildValue(); - if (!data_v->isa()) { - MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString(); - } - - auto item_v = args_spec_list[1]->BuildValue(); - if (item_v->isa()) { - item_v = std::make_shared(item_v->cast()->value()); - } - - if (!item_v->isa()) { - MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString(); - } - - // item_name to func addr from obj_map - parse::SymbolPtr symbol = item_v->cast(); - parse::NameSpacePtr name_space = data_v->cast(); - FuncGraphPtr func_graph = out_conf->node()->func_graph(); - - auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node()); - if (new_node == nullptr) { - MS_LOG(EXCEPTION) << "Resolve node failed"; - } - - AnalysisEnginePtr eng = out_conf->engine(); - AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context()); - return eng->ForwardConfig(out_conf, fn_conf); -} - -EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, - const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << "args_spec_list is empty"; - } - AbstractClassPtr cls = CheckArg("__FUNC__", args_spec_list, 0); - - // If item_v is an attribute, get abstract value from AbstractClass - MS_EXCEPTION_IF_NULL(item_v); - if (!item_v->isa()) { - MS_LOG(EXCEPTION) << "Attribute type error"; - } - std::string item_name = item_v->cast()->value(); - MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name(); - MS_LOG(DEBUG) << "Resolve item: " << item_name; - - AbstractBasePtr attr = cls->GetAttribute(item_name); - if (attr != nullptr) { - return std::make_shared(attr, nullptr); - } - - ValuePtr method = cls->GetMethod(item_name); - if (method->isa()) { - MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() - << ", item value: " << item_v->ToString(); - } - - // Infer class method - ValuePtr converted_v = PyObjToGraph(engine, method); - return StaticGetterInferred(converted_v, data_conf, out_conf); -} - -EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, - const TypePtr &data_type, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &out_conf) { - MS_EXCEPTION_IF_NULL(item_v); - MS_EXCEPTION_IF_NULL(data_type); - // The method maybe a Primitive or Composite - if (!item_v->isa()) { - MS_LOG(EXCEPTION) << "Error item is not string"; - } - - std::string item_name = item_v->cast()->value(); - Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); - if (method.empty()) { - MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name; - } - - ValuePtr converted_v = nullptr; - if (method.is()) { - // composite registered in standard_method_map go to this branch - converted_v = prim::GetPythonOps(method.cast()); - AddToManager(engine, converted_v->cast()); - } else if (method.is()) { - converted_v = method.cast(); - } else { - MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString(); - } - return StaticGetterInferred(converted_v, data_conf, out_conf); -} - -EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { - // Inputs: namespace and its static function; or class and its member function - CheckArgsSize("StaticGetter", args_spec_list, 2); - - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - TypePtr data_type = args_spec_list[0]->BuildType(); - ValuePtr item_value = args_spec_list[1]->BuildValue(); - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } - ScopeGuard scope_guard(scope); - if (item_value->isa()) { - MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString(); - } - - int case_v = GetResolveCase(data_type); - if (case_v == kResolveCaseUserDefineClass) { - return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); - } else if (case_v == kResolveCaseBuildinTypeMethod) { - return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf); - } else { - return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf); - } -} -} // end anonymous namespace - -// static variable start; -namespace { -class EmbedEvaluator : public SymbolicPrimEvaluator { - public: - EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} - ~EmbedEvaluator() override = default; - MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); - EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { - // arg: free variable to be embedded - if (args_conf_list.size() != 1) { - MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); - } - AnfNodeConfigPtr node_conf = dyn_cast(args_conf_list[0]); - MS_EXCEPTION_IF_NULL(node_conf); - - AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract(); - x = SensitivityTransform(x); - SymbolicKeyInstancePtr key = std::make_shared(node_conf->node(), x); - AbstractScalarPtr abs_scalar = std::make_shared(key, std::make_shared()); - return std::make_shared(abs_scalar, std::make_shared()); - } -}; - -static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) { - auto root_g_set = manager->roots(); - if (root_g_set.size() != 1) { - return nullptr; - } - const FuncGraphPtr &root_g = root_g_set.back(); - - for (auto ¶m_node : root_g->parameters()) { - auto param = param_node->cast(); - if (param && name == param->name()) { - return param; - } - } - return nullptr; -} - -class RefToEmbedEvaluator : public SymbolicPrimEvaluator { - public: - RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} - ~RefToEmbedEvaluator() override = default; - MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); - EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { - if (args_conf_list.size() != 1) { - MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); - return nullptr; - } - static TypePtr type = std::make_shared(); - auto node_conf = dyn_cast(args_conf_list[0]); - if (node_conf == nullptr) { - MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; - return nullptr; - } - AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); - AbstractRefPtr ref_abs = abs->cast(); - if (ref_abs == nullptr) { - MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); - return nullptr; - } - auto key_abs = ref_abs->ref_key(); - if (key_abs == nullptr) { - MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr."; - return nullptr; - } - auto key_value = key_abs->BuildValue(); - if (key_value == nullptr) { - MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr."; - return nullptr; - } - auto refkey = key_value->cast(); - if (refkey == nullptr) { - auto ret = std::make_shared(type); - auto ref_value = ref_abs->ref(); - MS_EXCEPTION_IF_NULL(ref_value); - ret->set_sparse_grad(ref_value->sparse_grad()); - ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad()); - return std::make_shared(ret, std::make_shared()); - } - - std::string name = refkey->tag(); - const auto &manager = node_conf->node()->func_graph()->manager(); - auto node = FindParameterNodeByString(manager, name); - if (node == nullptr) { - MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph."; - return nullptr; - } - AbstractBasePtr x = ref_abs->ref(); - x = SensitivityTransform(x); - std::shared_ptr key = std::make_shared(node, x); - std::shared_ptr abs_scalar = std::make_shared(key, type); - abs_scalar->set_sparse_grad(x->sparse_grad()); - abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad()); - return std::make_shared(abs_scalar, std::make_shared()); - } -}; - -class GetAttrEvaluator : public TransitionPrimEvaluator { - public: - GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} - ~GetAttrEvaluator() override = default; - MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse_flag = context->enable_sparse_flag(); - if (enable_sparse_flag) { - auto ret_abstract = AbstractEval(args_spec_list); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; - return ret_abstract; - } - } - // Inputs: data, item - if (args_spec_list.size() != 2) { - MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); - } - EvalResultPtr ret = nullptr; - if (bound_node() != nullptr) { - TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); - ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); - TraceManager::EndTrace(); - } else { - ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); - } - // don't lookup from cache, as different out_conf with same node but different context - // may add different entry to anfnode_config_map, like getattr primitive; - (*cache_)[args_spec_list] = ret; - return ret; - } -}; - -class ResolveEvaluator : public TransitionPrimEvaluator { - public: - ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} - ~ResolveEvaluator() override = default; - MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { - // Inputs: namespace, symbol - if (args_spec_list.size() != 2) { - MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); - } - EvalResultPtr ret = nullptr; - if (bound_node() != nullptr) { - TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); - ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); - TraceManager::EndTrace(); - } else { - ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); - } - return ret; - } -}; - -class CreateInstanceEvaluator : public TransitionPrimEvaluator { - public: - CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} - ~CreateInstanceEvaluator() override = default; - MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, - const AnfNodeConfigPtr &out_conf) override { - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; - } - - // get the type parameter - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kMetaTypeTypeType) { - MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got " - << type->ToString(); - } - - ValuePtr value_track = args_spec_list[0]->GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - - std::shared_ptr type_obj = dyn_cast(value_track); - if (type_obj == nullptr) { - MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; - } - - if (!type_obj->isa()) { - MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got " - << type_obj->ToString() << "."; - } - - auto class_type = type_obj->obj(); - MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; - - // get the create instance obj's parameters - pybind11::tuple params = GetParameters(args_spec_list); - - // create class instance - auto obj = parse::data_converter::CreatePythonObject(class_type, params); - if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type"; - } - - // process the object - ValuePtr converted_ret = nullptr; - bool converted = parse::ConvertData(obj, &converted_ret, true); - if (!converted) { - MS_LOG(EXCEPTION) << "Convert the python object failed"; - } - MS_EXCEPTION_IF_NULL(converted_ret); - - if (converted_ret->isa()) { - AddToManager(engine, converted_ret->cast()); - } - - AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); - auto infer_result = std::make_shared(ret, nullptr); - (*cache_)[args_spec_list] = infer_result; - return infer_result; - } - - pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { - // Exclude class type by minus 1; - std::size_t params_size = args_spec_list.size() - 1; - auto params = py::tuple(params_size); - if (params_size > 0) { - for (size_t i = 0; i < params_size; i++) { - // Only support the Scalar parameters type. Bypass class type by offset with 1. - auto arg = args_spec_list[i + 1]; - MS_EXCEPTION_IF_NULL(arg); - // Because the Tensor's AbstractTensor can't get value from GetValueTrack. - ValuePtr param_value = arg->BuildValue(); - py::object param = ValuePtrToPyData(param_value); - params[i] = param; - } - } - return params; - } -}; - -class PartialEvaluator : public Evaluator { - public: - PartialEvaluator() : Evaluator("PartialEvaluator") {} - ~PartialEvaluator() override = default; - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf = nullptr) override { - if (args_conf_list.size() == 0) { - MS_LOG(EXCEPTION) << "Args size should be greater than 0"; - } - - MS_EXCEPTION_IF_NULL(out_conf); - MS_EXCEPTION_IF_NULL(out_conf->node()); - auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract(); - AbstractBasePtrList args_spec_list{arg0_value}; - // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. - if (arg0_value->isa()) { - auto ret = std::make_shared(arg0_value->GetValueTrack()->cast(), out_conf->node()); - MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() - << " as func is: " << arg0_value->ToString(); - auto eval_result = std::make_shared(ret, std::make_shared()); - (*cache_)[args_spec_list] = eval_result; - return eval_result; - } - auto func = CheckArg("partial", args_spec_list, 0); - // Sometimes, node[0] in out_conf becomes phi0; - if (func->isa()) { - auto prim_func = dyn_cast(func); - if (prim_func->prim()->isa()) { - prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast(prim_func->prim()); - return HandleDoSignature(engine, do_signature_prim->function(), out_conf); - } - } - - (void)std::transform( - args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); }); - AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); - - auto cnode = out_conf->node()->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() != (args_conf_list.size() + 1)) { - MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString() - << ", args_conf_list: " << mindspore::ToString(args_conf_list); - } - - AbstractFuncAtomPtrList partial_funcs_list; - auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) { - auto new_func = std::make_shared(atom_func, args, cnode); - partial_funcs_list.push_back(new_func); - }; - func->Visit(build_partial); - - auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); - auto infer_result = std::make_shared(ret, std::make_shared()); - (*cache_)[args_spec_list] = infer_result; - return infer_result; - } - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - - EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, - const AnfNodeConfigPtr &out_conf = nullptr) const { - MS_EXCEPTION_IF_NULL(out_conf); - MS_EXCEPTION_IF_NULL(out_conf->node()); - auto cnode = out_conf->node()->cast(); - if (cnode == nullptr) { - MS_LOG(EXCEPTION) << "Cnode is nullptr"; - } - std::vector new_nodes_inputs = cnode->inputs(); - auto new_signature_value = std::make_shared("signature", signature_value); - new_nodes_inputs[1] = NewValueNode(new_signature_value); - FuncGraphPtr func_graph = cnode->func_graph(); - - ScopePtr scope = out_conf->node()->scope(); - ScopeGuard scope_guard(scope); - - CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs); - AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context()); - return engine->ForwardConfig(out_conf, fn_conf); - } -}; - -struct PrimitiveImplInferValue { - PrimitiveImpl impl_; // implement function of primitive - bool eval_value_; // whether evaluate value - TypePtr specify_out_type_; // whether specify return type - bool in_white_list_; // true if this Primitive in white list, else false. -}; - -using PrimitiveToImplMap = std::unordered_map; -PrimitiveToImplMap &GetUniformPrimitiveToImplMap() { - static PrimitiveToImplMap uniform_prim_implement_map = { - {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}}, - {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}}, - {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}}, - {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}}, - {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}}, - {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}}, - {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}}, - {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}}, - {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}}, - {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}}, - {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared(), true}}, - {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared(), true}}, - {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared(), true}}, - {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared(), true}}, - {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared(), true}}, - {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared(), true}}, - {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared(), true}}, - {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared(), true}}, - {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared(), true}}, - {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared(), true}}, - }; - return uniform_prim_implement_map; -} - -PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap(); -std::mutex PrimEvaluatorConstructorMutex; - -void InitPrimEvaluatorConstructors() { - PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; - - for (const auto &iter : GetPrimitiveToEvalImplMap()) { - constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_); - } - - for (const auto &iter : GetUniformPrimitiveToImplMap()) { - constructor[iter.first] = - InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_); - } - constructor[prim::kPrimEmbed] = std::make_shared(); - constructor[prim::kPrimRefToEmbed] = std::make_shared(); - constructor[prim::kPrimGetAttr] = std::make_shared(); - constructor[prim::kPrimResolve] = std::make_shared(); - constructor[prim::kPrimCreateInstance] = std::make_shared(); - constructor[prim::kPrimPartial] = std::make_shared(); -} -} // namespace - -void ClearPrimEvaluatorMap() { - PrimEvaluatorConstructors.clear(); - GetPrimitiveToEvalImplMap().clear(); - GetUniformPrimitiveToImplMap().clear(); -} - -bool IsInWhiteList(const PrimitivePtr primitive) { - MS_EXCEPTION_IF_NULL(primitive); - - auto iter = GetPrimitiveToEvalImplMap().find(primitive); - if (iter != GetPrimitiveToEvalImplMap().end()) { - return iter->second.in_white_list_; - } - - auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive); - if (uni_iter != GetUniformPrimitiveToImplMap().end()) { - return uni_iter->second.in_white_list_; - } - - return false; -} - -StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) { - MS_EXCEPTION_IF_NULL(primitive); - auto iter = GetPrimitiveToEvalImplMap().find(primitive); - if (iter == GetPrimitiveToEvalImplMap().end()) { - return nullptr; - } - return iter->second.impl_; -} - -PrimEvaluatorMap &GetPrimEvaluatorConstructors() { - PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; - if (!constructor.empty()) { - return constructor; - } - std::lock_guard initLock(PrimEvaluatorConstructorMutex); - if (constructor.empty()) { - InitPrimEvaluatorConstructors(); - } - - return constructor; -} - -namespace { -bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) { - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(model); - auto x_tuple = dyn_cast(x); - auto model_tuple = dyn_cast(model); - - if (x_tuple == nullptr || model_tuple == nullptr) { - return false; - } - - if (model->IsGeneric()) { - return true; - } - - if (x_tuple->size() != model_tuple->size()) { - return false; - } - - for (size_t i = 0; i < x_tuple->size(); i++) { - bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]); - if (!is_subtype) { - return false; - } - } - return true; -} - -bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) { - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(model); - auto x_tensor = dyn_cast(x); - auto model_tensor = dyn_cast(model); - - if (x_tensor == nullptr || model_tensor == nullptr) { - return false; - } - - if (model->IsGeneric()) { - return true; - } - - return IsSubtype(x_tensor->element(), model_tensor->element()); -} - -bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) { - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(model); - auto x_list = dyn_cast(x); - auto model_list = dyn_cast(model); - - if (x_list == nullptr || model_list == nullptr) { - return false; - } - - if (model->IsGeneric()) { - return true; - } - - if (x_list->size() != model_list->size()) { - return false; - } - - bool is_subtype = true; - for (size_t i = 0; i < x_list->size(); i++) { - is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]); - if (!is_subtype) { - return false; - } - } - return is_subtype; -} - -bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) { - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(model); - auto x_class = dyn_cast(x); - auto model_class = dyn_cast(model); - if (x_class == nullptr) { - return false; - } - if (model->IsGeneric()) { - return true; - } - - if (x_class->tag() == model_class->tag()) { - auto m_attributes = model_class->GetAttributes(); - auto x_attributes = x_class->attributes(); - if (m_attributes.size() != x_attributes.size()) { - return false; - } - - for (size_t i = 0; i < m_attributes.size(); i++) { - if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) { - return false; - } - } - return true; - } - - return false; -} - -inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) { - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(model); - if (dyn_cast(x) == nullptr) { - return false; - } - TypePtr x_type = x->GetTypeTrack(); - return IsSubType(x_type, model); -} -} // namespace - -bool IsSubtype(const AbstractBasePtr x, const TypePtr model) { - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(model); - TypeId model_typeid = model->type_id(); - switch (model_typeid) { - case kMetaTypeObject: - return true; - case kObjectTypeTuple: - return IsSubtypeTuple(x, model); - case kObjectTypeTensorType: - return IsSubtypeArray(x, model); - case kObjectTypeList: - return IsSubtypeList(x, model); - case kObjectTypeClass: - return IsSubtypeClass(x, model); - default: - if (IsSubType(model, std::make_shared())) { - return IsSubtypeScalar(x, model); - } - MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << "."; - } -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h deleted file mode 100644 index 1346dba2a2..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ /dev/null @@ -1,367 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_PRIM_H_ -#define PIPELINE_STATIC_ANALYSIS_PRIM_H_ - -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/evaluator.h" - -namespace mindspore { -namespace abstract { -using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &); -struct StandartPrimitiveImplReg { - StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. - bool in_white_list_; // true if this Primitive in white list, else false. -}; - -using PrimitiveEvalImplMap = - std::unordered_map; - -class StandardPrimEvaluator : public TrivialPrimEvaluator { - public: - StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl) - : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} - ~StandardPrimEvaluator() override = default; - MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; - PrimitivePtr prim() { return prim_; } - - std::string ToString() const override { return identifier_ + prim_->name(); } - - private: - PrimitivePtr prim_; - const StandardPrimitiveEvalImpl eval_impl_; -}; - -using StandardPrimEvaluatorPtr = std::shared_ptr; - -class PythonPrimEvaluator : public TrivialPrimEvaluator { - public: - explicit PythonPrimEvaluator(const PrimitivePyPtr primitive) - : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {} - ~PythonPrimEvaluator() override = default; - MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; - PrimitivePtr prim() { return dyn_cast(prim_py_); } - - std::string ToString() const override { return identifier_ + prim_py_->name(); } - - private: - PrimitivePyPtr prim_py_; -}; - -class DoSignatureEvaluator : public Evaluator { - public: - explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} - ~DoSignatureEvaluator() override = default; - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - - private: - PrimitivePtr prim_; -}; - -class UnpackGraphEvaluator : public Evaluator { - public: - explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} - ~UnpackGraphEvaluator() override = default; - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - - private: - PrimitivePtr prim_; -}; - -class MixedPrecisionCastEvaluator : public Evaluator { - public: - explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive) - : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {} - ~MixedPrecisionCastEvaluator() override = default; - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - - private: - PrimitivePtr prim_; -}; - -bool IsInWhiteList(PrimitivePtr primitive); -StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); - -using ValuePtrList = std::vector; -using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &); - -class UniformPrimEvaluator : public TrivialPrimEvaluator { - public: - UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type) - : TrivialPrimEvaluator("UniformPrimEvaluator"), - impl_(impl), - eval_value_(eval_value), - func_desc_(func_desc), - nargs_(func_desc_->args().size()), - return_value_type_(func_desc_->retval()), - specify_out_type_(specify_out_type) { - for (size_t i = 0; i < nargs_; ++i) { - TypePtr type = func_desc_->args()[i]; - if (type_map_[type]) { - type_map_[type]->push_back(i); - } else { - type_map_[type] = std::make_shared>(); - type_map_[type]->push_back(i); - } - } - } - ~UniformPrimEvaluator() override = default; - MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); - - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; - ValuePtr RunImpl(const ValuePtrList &args) const; - - // If eval_value_ is False, return broadened arguments. - AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { - if (!eval_value_) { - AbstractBasePtrList broadened_args_spec_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); - return broadened_args_spec_list; - } - return args_spec_list; - } - - private: - PrimitiveImpl impl_; - bool eval_value_; - const FunctionPtr func_desc_; - const std::size_t nargs_; - const TypePtr return_value_type_; - const TypePtr specify_out_type_; - std::unordered_map>, TypeHasher, TypeEqual> type_map_; -}; - -PrimEvaluatorMap &GetPrimEvaluatorConstructors(); - -// Check whether type x is a subtype of model. -bool IsSubtype(const AbstractBasePtr x, const TypePtr model); - -void ClearPrimEvaluatorMap(); - -py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base); - -AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -void InitUndeterminedFromEnv(const std::string &sparse_shape_types); - -AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -} // namespace abstract -} // namespace mindspore - -#endif // PIPELINE_STATIC_ANALYSIS_PRIM_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc deleted file mode 100644 index e01b98841b..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc +++ /dev/null @@ -1,719 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/program_specialize.h" - -#include -#include -#include "./common.h" -#include "operator/ops.h" -#include "operator/composite/do_signature.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "utils/graph_utils.h" -#include "utils/log_adapter.h" -#include "utils/profile.h" -#include "debug/trace.h" - -namespace mindspore { -namespace abstract { -namespace { -inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) { - if (conf->node()->intermediate_abstract()) { - return conf->node()->intermediate_abstract(); - } - return conf->GetEvaluatedValue()->abstract(); -} - -AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) { - AnfNodePtr value_node = NewValueNode(v); - value_node->set_abstract(abs_base); - MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString(); - return value_node; -} - -bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) { - while (fg != nullptr && fg != parent) { - fg = fg->parent(); - } - return fg == parent; -} -} // namespace - -FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(context); - MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString(); - return SpecializeFuncGraph(fg, context); -} - -FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(context); - auto iter = specializations_.find(context->SpecializeKey()); - if (iter != specializations_.end()) { - return iter->second->specialized_func_graph(); - } - - std::shared_ptr fg_spec = std::make_shared(this, fg, context); - FuncGraphPtr fg2 = fg_spec->specialized_func_graph(); - specializations_[context->SpecializeKey()] = fg_spec; - fg_spec->Run(); - return fg2; -} - -std::shared_ptr ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) { - MS_EXCEPTION_IF_NULL(context); - auto iter = specializations_.find(context->SpecializeKey()); - if (iter != specializations_.end()) { - return iter->second; - } - return nullptr; -} - -std::string GetNextCounter() { - static int g_CloneCounter = 1; - std::string str_count = std::to_string(g_CloneCounter); - g_CloneCounter++; - return str_count; -} - -FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, - const AnalysisContextPtr &context) - : specializer_(s), func_graph_(fg), context_(context) { - parent_ = s->GetFuncGraphSpecializer(context->parent()); - engine_ = s->engine(); - cloner_ = SpecializerClone(fg, std::make_shared(GetNextCounter())); - repl_node_ = cloner_->cloned_node(); - specialized_func_graph_ = cloner_->cloned_func_graph()[fg]; - todo_.push_back(fg->get_return()); - auto ps = fg->parameters(); - (void)todo_.insert(todo_.end(), ps.begin(), ps.end()); -} - -AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - FuncGraphPtr fg = node->func_graph(); - - if (node->isa()) { - return node; - } - std::shared_ptr specializer = shared_from_this(); - while (fg != nullptr && fg != specializer->func_graph_) { - specializer = specializer->parent_; - } - // If had replicated, just return that. - auto iter = specializer->repl_node_->find(node); - if (iter != specializer->repl_node_->end()) { - return iter->second; - } - - auto new_node = specializer->cloner_->CloneDisconnected(node); - if (node->isa()) { - if (!new_node->isa()) { - MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << "."; - } - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - auto inputs = c_node->inputs(); - std::vector new_inputs; - (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs), - [this](const AnfNodePtr &inp) -> AnfNodePtr { - if (inp->isa()) { - return inp; - } - return ReplicateDisconnectedNode(inp); - }); - auto c_new_node = new_node->cast(); - MS_EXCEPTION_IF_NULL(c_new_node); - c_new_node->set_inputs(new_inputs); - } - - iter = specializer->repl_node_->find(node); - if (iter != specializer->repl_node_->end()) { - if (iter->second == node) { - MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString(); - } - } else { - MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString(); - } - return new_node; -} - -AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - FuncGraphPtr fg = node->func_graph(); - - std::shared_ptr specializer = shared_from_this(); - while (fg != nullptr && fg != specializer->func_graph_) { - specializer = specializer->parent_; - } - - MS_EXCEPTION_IF_NULL(specializer->repl_node_); - auto iter = specializer->repl_node_->find(node); - if (iter != specializer->repl_node_->end()) { - return iter->second; - } - return node; -} - -void FuncGraphSpecializer::Run() { - MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString() - << ", cloned func graph name: " << specialized_func_graph_->ToString() - << ", func graph: " << func_graph_->get_return()->DebugString(); - FirstPass(); - SecondPass(); - MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString() - << ", cloned func graph name: " << specialized_func_graph_->ToString() - << ", new func graph: " << specialized_func_graph_->get_return()->DebugString(); -} - -void FuncGraphSpecializer::FirstPass() { - while (todo_.size()) { - AnfNodePtr node = todo_.back(); - todo_.pop_back(); - if (node->func_graph() == nullptr) { - // do nothing for ValueNode - continue; - } - if (node->func_graph() != func_graph_) { - if (parent_ == nullptr) { - MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - parent_->AddTodoItem(node); - parent_->FirstPass(); - AnfNodePtr new_node = parent_->GetReplicatedNode(node); - if (node->isa()) { - parent_->ProcessCNode(new_node->cast()); - } - continue; - } - if (marked_.count(node) > 0) { - continue; - } - (void)marked_.insert(node); - ProcessNode(node); - } -} - -// Specialize CNode in func graphs -void FuncGraphSpecializer::SecondPass() { - for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) { - if (node->isa()) { - ProcessCNode(node->cast()); - } - } -} - -void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - ScopeGuard scope_guard(node->scope()); - AnfNodeConfigPtr conf = MakeConfig(node); - AnfNodePtr new_node = GetReplicatedNode(node); - MS_EXCEPTION_IF_NULL(new_node); - if (new_node->func_graph() != specialized_func_graph_) { - MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() - << ", new_node: " << new_node->DebugString() - << ", new_node->func_graph(): " << new_node->func_graph()->ToString() - << ", specialized_func_graph_: " << specialized_func_graph_->ToString(); - return; - } - new_node->set_abstract(GetEvaluatedValueWrap(conf)); - if (new_node->isa() && new_node->abstract()->isa()) { - auto partial_abstract = dyn_cast(new_node->abstract()); - if (partial_abstract->node() == node) { - partial_abstract->set_node(new_node); - } - } - - MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); - - if (node->isa()) { - auto attrs = conf->GetEvaluatedValue()->attribute(); - auto c_old = node->cast(); - auto c_new = new_node->cast(); - auto new_inputs = c_new->inputs(); - auto old_inputs = c_old->inputs(); - for (size_t i = 0; i < old_inputs.size(); ++i) { - auto node_input = old_inputs[i]; - AnfNodeConfigPtr iconf = MakeConfig(node_input); - AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); - // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if - // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. - AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); - if (replace_node == nullptr) { - replace_node = BuildReplacedNode(iconf); - MS_EXCEPTION_IF_NULL(replace_node); - replace_node->set_abstract(ival); - MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString(); - } else { - MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString() - << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString(); - } - if (new_inputs[i] != replace_node) { - new_inputs[i] = replace_node; - MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString(); - } - } - c_new->set_inputs(new_inputs); - } -} - -AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - - auto conf_iter = engine_->anfnode_config_map().find(conf); - AnfNodeConfigPtr new_conf = conf; - while (conf_iter != engine_->anfnode_config_map().end()) { - MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node(" - << new_conf->node()->DebugString() << ")"; - new_conf = conf_iter->second; - MS_EXCEPTION_IF_NULL(new_conf); - MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node(" - << conf->node()->DebugString() << ")"; - (void)ReplicateDisconnectedNode(new_conf->node()); - conf_iter = engine_->anfnode_config_map().find(new_conf); - } - todo_.push_back(new_conf->node()); - auto repl = GetReplicatedNode(new_conf->node()); - if (repl->func_graph()) { - MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString() - << ") to replace origin:" << new_conf->node()->DebugString(); - } else { - MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString() - << ") to replace origin: " << new_conf->node()->DebugString(); - } - return repl; -} - -namespace { -const StringImmPtr kDeadNode = std::make_shared("Dead Node"); -const StringImmPtr kPolyNode = std::make_shared("Poly Node"); - -inline bool CanSpecializeNode(const AnfNodePtr &node) { - if (IsValueNode(node) || IsValueNode(node) || IsValueNode(node)) { - return true; - } - return false; -} -} // namespace - -AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, - const AbstractBasePtrList &argvals) { - MS_EXCEPTION_IF_NULL(abs); - AbstractFunctionPtr real_a = dyn_cast(abs); - MS_EXCEPTION_IF_NULL(real_a); - - AbstractFunctionPtr func = real_a->GetUnique(); - SpecializeStatusCode errcode; - ScopeGuard scope_guard(node->scope()); - AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode); - if (repl == nullptr) { - if (errcode == kSpecializeFindUniqueArgvalDead) { - const auto error_dead_node = std::make_shared(kDeadNode, node); - repl = BuildValueNode(kDeadNode, error_dead_node); - MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString(); - } else if (errcode == kSpecializeFindUniqueArgvalPoly) { - const auto error_poly_node = std::make_shared(kPolyNode, node); - repl = BuildValueNode(kPolyNode, error_poly_node); - MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString(); - } else { - MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString() - << ", abstract: " << abs->ToString(); - } - } - - return repl; -} - -AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func, - const AbstractBasePtrList &args, - SpecializeStatusCode *errcode) { - MS_EXCEPTION_IF_NULL(abs); - MS_EXCEPTION_IF_NULL(func); - MS_EXCEPTION_IF_NULL(errcode); - *errcode = kSpecializeSuccess; - - auto real_func = dyn_cast(func); - if (real_func != nullptr) { - return BuildValueNode(real_func->prim(), abs); - } - - EvaluatorPtr eval; - eval = engine_->GetEvaluatorFor(func); - MS_EXCEPTION_IF_NULL(eval); - AbstractBasePtrList argvals = eval->NormalizeArgs(args); - - std::pair result; - SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result); - if (status != kSpecializeSuccess) { - *errcode = status; - return nullptr; - } - argvals = result.first; - AbstractBasePtr unique_output = result.second; - - auto prim_func = dyn_cast(func); - if (prim_func != nullptr) { - auto type_func = std::make_shared(prim_func->prim(), argvals, unique_output); - return BuildValueNode(prim_func->prim(), type_func); - } - - if (!eval->isa()) { - MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString(); - } - auto real_eval = dyn_cast(eval); - - if (func->context() == nullptr) { - MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); - } - AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); - MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size() - << ", graph: " << context->func_graph()->get_return()->DebugString(); - FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context); - return BuildValueNode(v, abs); -} - -AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) { - auto new_inputs = new_node->inputs(); - AnfNodePtr func = new_inputs[0]; - AbstractBasePtr fnval = new_inputs[0]->abstract(); - - AbstractBasePtrList args; - auto backed_fnval = fnval; - if (fnval->isa()) { - auto partial_closure = dyn_cast(fnval); - backed_fnval = partial_closure->fn(); - args = partial_closure->args(); - } - std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args), - [](const AnfNodePtr &inp) { return inp->abstract(); }); - - ScopeGuard scope_guard(new_node->scope()); - - auto specialized_node = BuildSpecializedNode(func, backed_fnval, args); - auto wrapped_node = specialized_node; - if (fnval->isa()) { - auto partial_closure = dyn_cast(fnval); - AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)), - specialized_node}; - auto anf_node = partial_closure->node(); - if (!anf_node->isa()) { - MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString(); - } - auto cnode = anf_node->cast(); - if (cnode->size() != partial_closure->args().size() + 2) { - MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() - << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); - } - auto attrs = std::make_shared(); - for (size_t i = 0; i < partial_closure->args().size(); i++) { - auto old_node = cnode->input(i + 2); - auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs); - if (possibile_value_node != nullptr) { - partial_node_list.push_back(possibile_value_node); - } else { - if (!(old_node->isa() || old_node->isa())) { - MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString(); - } - partial_node_list.push_back(old_node); - } - } - wrapped_node = new_node->func_graph()->NewCNode(partial_node_list); - wrapped_node->set_abstract(partial_closure); - } - return wrapped_node; -} - -const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) { - auto cache_iter = evalcaches_.find(eval); - if (cache_iter == evalcaches_.end()) { - evalcaches_[eval] = eval->cache(); - return eval->cache(); - } - return cache_iter->second; -} - -std::pair FuncGraphSpecializer::BuildFromBroadedArgsVal( - const EvaluatorPtr &eval) { - MS_EXCEPTION_IF_NULL(eval); - std::unordered_set choices; - EvalResultPtr ret = nullptr; - AbstractBasePtrList broaded_argvals; - for (auto &argvals_map : *evalcaches_[eval]) { - auto argvals = argvals_map.first; - broaded_argvals.clear(); - - (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); - (void)choices.insert(broaded_argvals); - MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals); - } - - if (1 == choices.size()) { - ConfigPtrList args_conf_list; - (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list), - [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared(v); }); - - // if broaden return null - ret = eval->Run(engine_, args_conf_list, nullptr); - EvaluatorCacheMapPtr real = std::make_shared(); - - (*real)[broaded_argvals] = ret; - evalcaches_[eval] = real; - return std::make_pair(broaded_argvals, ret->abstract()); - } else { - MS_LOG(DEBUG) << "Choices.size: " << choices.size(); - return std::make_pair(AbstractBasePtrList(), nullptr); - } -} - -void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { - MS_EXCEPTION_IF_NULL(new_node); - if (specializer_->seen().count(new_node) > 0) { - return; - } - specializer_->AddSeen(new_node); - auto new_inputs = new_node->inputs(); - if (new_inputs.empty()) { - MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; - } - AnfNodePtr func = new_inputs[0]; - MS_EXCEPTION_IF_NULL(func); - - // First element is func so arg start from 1 - std::vector args(new_inputs.begin() + 1, new_inputs.end()); - // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...) - while (IsPrimitiveCNode(func, prim::kPrimPartial)) { - std::vector inputs = func->cast()->inputs(); - // First element is partial, second is func so arg is start from 2 - (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end()); - func = inputs[1]; - } - new_inputs = args; - (void)new_inputs.insert(new_inputs.begin(), func); - - AbstractBasePtrList argvals; - MS_EXCEPTION_IF_NULL(new_inputs[0]); - AbstractBasePtr fnval = new_inputs[0]->abstract(); - MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", " - << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString(); - - // First element is func so function arguments start from 1 - for (size_t i = 1; i < new_inputs.size(); ++i) { - argvals.push_back(new_inputs[i]->abstract()); - MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", " - << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); - } - - if (!func->isa()) { - MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString(); - if (func->abstract()->isa() && !func->abstract()->isa()) { - auto func_abs = func->abstract()->cast(); - EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs); - std::pair result; - AbstractBasePtrList empty_args; - auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result); - MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status; - // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early - if (status == kSpecializeFindUniqueArgvalPoly || - (func->isa() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) || - func->abstract()->isa()))) { - auto wrapped_node = BuildSpecializedParameterNode(new_node); - new_inputs[0] = wrapped_node; - } - } - } - - if (CanSpecializeNode(func)) { - // for primitive node , we build the primitive node with infered attributes in the first pass - // so we do not build replaced node again here in second pass - if (IsValueNode(func)) { - new_inputs[0] = func; - } else { - new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); - } - } - - for (size_t i = 0; i < argvals.size();) { - size_t next = i + 1; - if (CanSpecializeNode(args[i])) { - new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector{}); - } - i = next; - } - new_node->set_inputs(new_inputs); -} - -namespace { -void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) { - MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items."; - int i = 0; - for (const auto &item : evaluator_cache_map) { - MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first; - } -} - -bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) { - if (func->isa() && argvals.empty()) { - MS_LOG(DEBUG) << "High order primitive return POLY."; - return true; - } - if (func->isa() && argvals.empty()) { - auto meta_func_graph_wrapper = dyn_cast(func); - auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph(); - if (meta_func_graph != nullptr && meta_func_graph->isa()) { - auto do_signature = dyn_cast(meta_func_graph); - if (do_signature != nullptr && do_signature->function()->isa()) { - MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY."; - return true; - } - } - } - return false; -} -} // end anonymous namespace - -SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval, - const AbstractBasePtrList &argvals, - std::pair *result) { - MS_EXCEPTION_IF_NULL(func); - MS_EXCEPTION_IF_NULL(eval); - MS_EXCEPTION_IF_NULL(result); - - EvaluatorCacheMap evaluator_cache_map = *eval->cache(); - if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) { - *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract()); - return kSpecializeSuccess; - } - DumpEvaluatorCache(evaluator_cache_map, argvals); - - const EvaluatorCacheMapPtr &choices = GetEvalCache(eval); - MS_EXCEPTION_IF_NULL(choices); - - if (choices->count(argvals)) { - *result = std::make_pair(argvals, (*choices)[argvals]->abstract()); - return kSpecializeSuccess; - } else if (choices->size() == 1) { - MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it."; - *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract()); - return kSpecializeSuccess; - } else if (choices->empty()) { - MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase."; - return kSpecializeFindUniqueArgvalDead; - } else { - if (IsPolyFunc(func, argvals)) { - return kSpecializeFindUniqueArgvalPoly; - } - - MS_LOG(DEBUG) << "Try to find generalized argvals."; - *result = BuildFromBroadedArgsVal(eval); - if (!result->first.empty()) { - return kSpecializeSuccess; - } - MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism."; - return kSpecializeFindUniqueArgvalPoly; - } -} -static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) { - auto &prim_attrs = prim->attrs(); - bool is_attr_same = true; - for (auto &item : *attrs) { - auto itr = prim_attrs.find(item.first); - if (itr != prim_attrs.end()) { - if (!(*(itr->second) == *(item.second))) { - is_attr_same = false; - break; - } - } else { - is_attr_same = false; - break; - } - } - if (!is_attr_same) { - if (prim->isa()) { - PrimitivePyPtr prim_py = prim->cast(); - auto clone_fn = prim_py->GetPyObj().attr("_clone"); - py::object new_obj = clone_fn(); - auto cloned_prim = new_obj.cast(); - for (auto &item : *attrs) { - cloned_prim->AddAttr(item.first, item.second); - } - return cloned_prim; - } - auto cloned_prim = std::make_shared(*prim); - for (auto &item : *attrs) { - cloned_prim->AddAttr(item.first, item.second); - } - return cloned_prim; - } - return prim; -} - -AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, - const AttrValueMapPtr &attrs) { - MS_EXCEPTION_IF_NULL(origin_node); - MS_EXCEPTION_IF_NULL(ival); - - AbstractFunctionPtr abs = dyn_cast(ival); - if (abs != nullptr) { - // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction. - if (abs->isa()) { - return nullptr; - } - ValuePtr value = nullptr; - if (abs->isa()) { - auto real_fn = dyn_cast(abs); - // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one - if (attrs != nullptr) { - value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); - } else { - value = real_fn->prim(); - } - } else if (abs->isa()) { - auto real_fn = dyn_cast(abs); - value = real_fn->meta_func_graph(); - } else if (abs->isa()) { - auto real_fn = dyn_cast(abs); - value = real_fn->func_graph(); - } else { - return nullptr; - } - if (!value->isa() || value->cast()->parent() == nullptr || - (IsValueNode(origin_node) && IsVisible(func_graph_, value->cast()->parent()))) { - return BuildValueNode(value, ival); - } else { - return nullptr; - } - } else { - ValuePtr val = ival->BuildValue(); - if (val->isa()) { - return nullptr; - } - // keep primitive 'depend' not to be optimized - if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) { - return nullptr; - } - return BuildValueNode(val, ival); - } -} - -AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) { - return engine_->MakeConfig(node, context_); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h deleted file mode 100644 index b04978586d..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h +++ /dev/null @@ -1,135 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ -#define PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph_cloner.h" -#include "pipeline/static_analysis/evaluator.h" - -namespace mindspore { -namespace abstract { -enum SpecializeStatusCode { - kSpecializeSuccess = 0, - kSpecializeFindUniqueArgvalDead = 1, // Dead Node - kSpecializeFindUniqueArgvalPoly = 2, // Poly Node - kSpecializeFailure = 0xFF -}; - -class FuncGraphSpecializer; - -// Specialize a func graph using analyzed abstract values. -class ProgramSpecializer { - public: - explicit ProgramSpecializer(const std::shared_ptr &engine) : engine_(engine) { - mng_ = engine_->func_graph_manager(); - } - ~ProgramSpecializer() = default; - // Run the program specializer on the topmost graph in the given context. - FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context); - const std::unordered_set &seen() const { return seen_; } - void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); } - - std::shared_ptr GetFuncGraphSpecializer(const AnalysisContextPtr &context); - // Specialze one FuncGraph in a given context. - FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context); - - std::shared_ptr engine() { return engine_; } - - private: - std::shared_ptr engine_; - std::unordered_set seen_; - FuncGraphManagerPtr mng_; - std::unordered_map, ContextHasher, ContextEqual> - specializations_; -}; - -class FuncGraphSpecializer : public std::enable_shared_from_this { - public: - FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context); - virtual ~FuncGraphSpecializer() { - specializer_ = nullptr; - repl_node_ = nullptr; - } - void Run(); - FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; } - - private: - ProgramSpecializer *specializer_; - FuncGraphPtr func_graph_; - FuncGraphPtr specialized_func_graph_; - AnalysisContextPtr context_; - std::shared_ptr parent_; - std::shared_ptr engine_; - ClonerPtr cloner_; - // ProcessNode-> [cloner_->CloneDisconnected] will clone AnfNode again. - // So, repl_node_ should pointer to GraphCloner->repl_node_ other than a copy of that. - std::unordered_map *repl_node_; - std::vector todo_; - std::unordered_set marked_; - std::unordered_map evalcaches_; - - void FirstPass(); - void SecondPass(); - void ProcessNode(const AnfNodePtr &node); - void ProcessCNode(const CNodePtr &new_node); - - AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node); - inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); } - // Get node replicated by Cloner. - AnfNodePtr GetReplicatedNode(const AnfNodePtr &node); - // Replicated node which is not used directly by a func graph, so it's not searchable from it's return node - // (disconnected). - AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node); - - // Build a value node from parameter if the function graph has special flag to hint it can be done. - AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); - - // Build a value node if ival is constant and not any-value - AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, - const AttrValueMapPtr &attrs); - // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a - // replicated node. - AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); - // Build a specialized node from given argvals; - AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, - const AbstractBasePtrList &argvals); - AnfNodePtr BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func, - const AbstractBasePtrList &args, SpecializeStatusCode *errcode); - - // Find the unique argument values which can be used to specialize a primitive or graph function. - SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval, - const AbstractBasePtrList &argvals, - std::pair *result); - // Get cache, it may be eval's cache or cache built from broaded argument values. - const EvaluatorCacheMapPtr &GetEvalCache(const EvaluatorPtr &eval); - // Try to build unique argvals from the broaded arg vals if it is unique. - std::pair BuildFromBroadedArgsVal(const EvaluatorPtr &eval); -}; -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc deleted file mode 100644 index 5416576680..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc +++ /dev/null @@ -1,655 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/static_analysis.h" - -#include -#include - -#include "pipeline/static_analysis/utils.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "utils/symbolic.h" -#include "ir/tensor.h" -#include "ir/func_graph_cloner.h" -#include "./common.h" -#include "pipeline/parse/data_converter.h" -#include "debug/draw.h" -#include "pipeline/static_analysis/evaluator.h" -#include "debug/trace.h" - -namespace mindspore { -namespace abstract { -bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) { - if (dyn_cast(arg_spec)) { - auto v = arg_spec->GetValueTrack(); - if (v->isa()) { - return true; - } else { - return false; - } - } else { - return false; - } -} - -AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) { - if (dyn_cast(arg1) && dyn_cast(arg2)) { - return arg1->Join(arg2); - } - return nullptr; -} - -void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { - MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() - << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() - << ", Pointer: " << result->abstract().get(); - cache_[conf] = result; - - // Set intermediate abstract value. - if (IsIntermediateAbstract(result->abstract())) { - if (conf->node()->intermediate_abstract() == nullptr) { - conf->node()->set_intermediate_abstract(result->abstract()); - MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); - } else { - auto old_spec = conf->node()->intermediate_abstract(); - auto joined_spec = IntermediateJoin(result->abstract(), old_spec); - conf->node()->set_intermediate_abstract(joined_spec); - MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" - << result->abstract()->ToString() << "\njoined_spec:\t" - << (joined_spec != nullptr ? joined_spec->ToString() : "nullptr"); - } - } -} - -EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { - auto value = cache_.find(conf); - if (value == cache_.end()) { - return nullptr; - } - return value->second; -} - -std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const { - MS_EXCEPTION_IF_NULL(conf); - MS_EXCEPTION_IF_NULL(conf->node()); - std::size_t hash_value = conf->node()->hash(); - if (!conf->context()->IsDummyContext()) { - hash_value = hash_combine(hash_value, std::hash{}(conf->context().get())); - } - if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) { - MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() - << ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value; - } else { - MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() << " ### , hash value: " << hash_value; - } - return hash_value; -} - -bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const { - if (lhs == nullptr || rhs == nullptr) { - return false; - } - if (lhs == rhs) { - return true; - } - return (*lhs == *rhs); -} - -AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) { - ConfigPtrList args_conf_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - MS_EXCEPTION_IF_NULL(func_graph_manager_); - func_graph_manager_->AddFuncGraph(func_graph); - - AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); - - // Running the analyzer. - AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); - MS_EXCEPTION_IF_NULL(root_context); - MS_EXCEPTION_IF_NULL(root_context->func_graph()); - AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context); - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(INFO) << func_graph->ToString() << ": Run finished."; - - AnalysisResult result; - MS_EXCEPTION_IF_NULL(output_conf); - result.inferred = output_conf->GetEvaluatedValue(); - result.context = root_context; - return result; -} - -AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, - const ConfigPtrList &args_conf_list) { - std::shared_ptr eval = std::make_shared(func_graph, context); - (void)eval->Run(shared_from_this(), args_conf_list, nullptr); - return eval->graph_context(); -} - -EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - auto value = cache_.GetValue(conf); - if (value != nullptr) { - MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() - << ", " << value->abstract()->ToString(); - return value; - } - - MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString(); - value = Eval(conf); - if (value == nullptr) { - MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr"; - } - cache_.set_value(conf, value); - return value; -} - -EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - AnfNodePtr node = conf->node(); - EvalResultPtr eval_result = nullptr; -#ifdef DEBUG - compute_conf_stack_.push_back(node); - std::ostringstream buffer; - buffer << "Compute Config Begin:"; - for (auto iter : compute_conf_stack_) { - buffer << " -> " << iter->DebugString(); - } - MS_LOG(DEBUG) << buffer.str(); -#endif - MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString(); - MS_EXCEPTION_IF_NULL(node); - if (node->abstract() != nullptr) { - MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); - eval_result = std::make_shared(node->abstract(), std::make_shared()); - } else if (node->isa()) { - auto value_node = node->cast(); - eval_result = std::make_shared(EvalValueNode(value_node, conf), nullptr); - } else if (node->isa()) { - auto cnode = node->cast(); - trace::TraceEvalCNodeEnter(conf); - eval_result = EvalCNode(cnode, conf); - trace::TraceEvalCNodeLeave(); - } else { - MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() - << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - -#ifdef DEBUG - compute_conf_stack_.pop_back(); - if (eval_result == nullptr) { - MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() - << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } -#endif - MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString(); - return eval_result; -} - -AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - MS_EXCEPTION_IF_NULL(value_node); - return ToAbstract(value_node->value(), conf->context(), conf); -} - -EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - MS_EXCEPTION_IF_NULL(cnode); - auto &inputs = cnode->inputs(); - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString(); - } - - AnfNodePtr func_node = inputs[0]; - MS_EXCEPTION_IF_NULL(func_node); - MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString(); - AnalysisContextPtr context = conf->context(); - AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); - MS_EXCEPTION_IF_NULL(func_conf); - // Keep it in a local variable, otherwise smart pointer will free it. - AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract(); - if (maybe_func == nullptr) { - MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() - << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); - } - if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { - MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; - return std::make_shared(maybe_func->Clone(), std::make_shared()); - } - AbstractFunctionPtr func = dyn_cast(maybe_func); - if (func == nullptr) { - MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString() - << ", func_conf: " << func_conf->ToString() - << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); - } - - ConfigPtrList args_conf_list; - // ignore the first node which is function name - for (std::size_t i = 1; i < inputs.size(); i++) { - const AnfNodePtr &node = inputs[i]; - args_conf_list.push_back(MakeConfig(node, context)); - } - std::vector infs; - - auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) { - auto evaluator = this->GetEvaluatorFor(poss); - evaluator->set_bound_node(cnode); - infs.push_back(evaluator); - }; - func->Visit(build_evaluator); - - return ExecuteEvaluators(infs, conf, args_conf_list); -} - -EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { - ConfigPtrList args_conf_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - std::vector infs; - MS_EXCEPTION_IF_NULL(func); - auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) { - auto evaluator = this->GetEvaluatorFor(poss); - infs.push_back(evaluator); - }; - func->Visit(build_evaluator); - return ExecuteEvaluators(infs, nullptr, args_conf_list); -} - -void AnalysisEngine::ClearEvaluatorCache() { - for (std::pair element : constructors_) { - EvaluatorPtr evaluator = element.second; - MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->cache()); - evaluator->cache()->clear(); - } - for (auto &element : prim_constructors_) { - EvaluatorPtr evaluator = element.second; - MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->cache()); - evaluator->cache()->clear(); - } - for (auto &element : prim_py_evaluators_) { - EvaluatorPtr evaluator = element.second; - MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->cache()); - evaluator->cache()->clear(); - } -} - -void AnalysisEngine::Clear() { - cache_.Clear(); - anfnode_config_map_.clear(); - eval_trace_.clear(); - constructors_.clear(); -} - -namespace { -EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { - // Custom Primitive with python infer_shape, infer_type - EvaluatorPtr evaluator = nullptr; - MS_EXCEPTION_IF_NULL(prim); - if (prim->isa()) { - evaluator = std::make_shared(prim); - return evaluator; - } - if (prim->isa()) { - evaluator = std::make_shared(prim); - return evaluator; - } - if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) { - evaluator = std::make_shared(prim); - return evaluator; - } - if (prim->HasPyEvaluator()) { - auto prim_py = dyn_cast(prim); - if (prim_py != nullptr) { - if (engine == nullptr) { - return std::make_shared(prim_py); - } - - const auto &iter = engine->prim_py_evaluators_.find(prim_py); - if (iter != engine->prim_py_evaluators_.end()) { - return iter->second; - } - evaluator = std::make_shared(prim_py); - engine->prim_py_evaluators_[prim_py] = evaluator; - return evaluator; - } - MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; - } - - if (prim->isa() || prim->HasAttr()) { - if (engine == nullptr) { - (void)GetPrimEvaluatorConstructors(); - } - // If a primitive may have attr, try to create a new evaluator. - StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); - if (eval_impl != nullptr) { - return std::make_shared(prim, eval_impl); - } - } - - if (engine == nullptr) { - // If engine is nullptr, get constructor from default. - const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); - auto iter = prim_evaluator_map.find(prim); - if (iter != prim_evaluator_map.end()) { - evaluator = iter->second; - } - } else { - // If engine is given, get constructor from engine resource. - const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors(); - auto iter = prim_evaluator_map.find(prim); - if (iter != prim_evaluator_map.end()) { - evaluator = iter->second; - } - } - if (evaluator == nullptr) { - MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ")."; - } - return evaluator; -} -} // namespace - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - auto inf_pair = constructors_.find(func); - if (inf_pair != constructors_.end()) { - return inf_pair->second; - } - MS_EXCEPTION_IF_NULL(func); - auto primitive = func->prim(); - auto evaluator = GetPrimEvaluator(primitive, shared_from_this()); - constructors_[func] = evaluator; - return evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - auto inf_pair = constructors_.find(func); - if (inf_pair != constructors_.end()) { - return inf_pair->second; - } - MS_EXCEPTION_IF_NULL(func); - std::shared_ptr func_graph_evaluator = - std::make_shared(func->func_graph(), func->context()); - constructors_[func] = func_graph_evaluator; - return func_graph_evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - auto inf_pair = constructors_.find(func); - if (inf_pair != constructors_.end()) { - return inf_pair->second; - } - MS_EXCEPTION_IF_NULL(func); - std::shared_ptr evaluator = - std::make_shared(func->meta_func_graph(), func->context(), func->GetScope()); - constructors_[func] = evaluator; - return evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - MS_EXCEPTION_IF_NULL(func); - AbstractFunctionPtr func_orig = func->fn(); - EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); - auto jevaluator = std::make_shared(evaluator_orig, func_orig); - return jevaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - MS_EXCEPTION_IF_NULL(func); - std::shared_ptr virtual_evaluator = - std::make_shared(func->args_spec_list(), func->output()); - return virtual_evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - MS_EXCEPTION_IF_NULL(func); - AbstractFunctionPtr func_orig = func->fn(); - EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); - std::shared_ptr partial_evaluator = - std::make_shared(evaluator_orig, func->args()); - return partial_evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &) { - MS_LOG(EXCEPTION) << "Should not be called "; -} - -// Forward to specific subclass of FunctionWrapper. -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) { - MS_EXCEPTION_IF_NULL(func); - EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this()); - return evaluator; -} - -EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { - MS_LOG(DEBUG) << "The func value: " << func->ToString(); - if (func->tracking_id() != nullptr) { - MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); - } - MS_EXCEPTION_IF_NULL(func); - if (func->tracking_id() == nullptr) { - EvaluatorPtr evaluator = _GetEvaluatorFor(func); - return evaluator; - } - auto inf_pair = constructors_.find(func); - if (inf_pair != constructors_.end()) { - return inf_pair->second; - } - - AbstractFunctionPtr func_generic = func->Copy(); - func_generic->set_tracking_id(nullptr); - EvaluatorPtr eval = _GetEvaluatorFor(func_generic); - auto tracked_eval = std::make_shared(eval); - constructors_[func] = tracked_eval; - - return tracked_eval; -} - -EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector &evaluators, - const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { - if (evaluators.size() == 1) { - EvaluatorPtr eval = evaluators[0]; - MS_EXCEPTION_IF_NULL(eval); - return eval->Run(shared_from_this(), args_conf_list, out_conf); - } - return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); -} - -void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { - auto fg_eval = evaluator->cast(); - if (fg_eval == nullptr) { - return; - } - auto fg = fg_eval->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - auto undetermined_fgs = fg->recursive_graphs(); - if (undetermined_fgs) { - auto fg_parent = fg->parent(); - MS_EXCEPTION_IF_NULL(fg_parent); - fg_parent->set_flag(kFuncGraphFlagUndetermined, true); - MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); - } -} - -EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector &evaluators, - const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, - const EvalTraceRevIter &it, bool *continue_flag) { - *continue_flag = false; - // Find latest entry function to handle nested recursion. - EvaluatorPtr latest_entry = eval; - auto latest_entry_iter = eval_trace_.rbegin(); - for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { - auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); - if (it_temp != evaluators.end()) { - latest_entry = *it_temp; - latest_entry_iter = r_it; - break; - } - latest_entry_iter = ++r_it; - } - if (latest_entry != eval) { - MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); - *continue_flag = true; - return latest_entry; - } - - bool has_undetermined = false; - // Check whether sub loop has untraced undetermined evaluator. - std::set> undetermined_evals; - for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { - undetermined_evals.insert(*r_it); - } - MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); - - for (auto u_eval : undetermined_evals) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; - if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; - has_undetermined = true; - break; - } - } - if (has_undetermined == false) { - MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; - *continue_flag = true; - return latest_entry; - } - - return latest_entry; -} - -EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) { - if (out_specs.size() == 0) { - MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; - } - - if (out_specs.size() == 1) { - MS_EXCEPTION_IF_NULL(out_specs[0]); - // If only one result derived, then broaden it to avoid wrong constant propagation. - return std::make_shared(out_specs[0]->Broaden(), std::make_shared()); - } - auto joined_spec = AbstractJoin(out_specs); - MS_EXCEPTION_IF_NULL(joined_spec); - MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); - return std::make_shared(joined_spec, std::make_shared()); -} - -EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector &evaluators, - const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list) { - AbstractBasePtrList out_specs; - if (!multi_poss_.count(evaluators[0])) { - multi_poss_[evaluators[0]] = evaluators[1]; - multi_poss_[evaluators[1]] = evaluators[0]; - } - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - for (auto eval : evaluators) { - SetUndeterminedFlag(eval); - - auto current_inf = std::make_pair(eval, args_spec_list); - MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); - - // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. - auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); - if (it == eval_trace_.rend()) { - eval_trace_.push_back(current_inf); - MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); - MS_EXCEPTION_IF_NULL(eval); - auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); - MS_EXCEPTION_IF_NULL(eval_result->abstract()); - MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); - out_specs.push_back(eval_result->abstract()); - eval_trace_.pop_back(); - if (eval_trace_.empty()) { - multi_poss_.clear(); - } - } else if (it != eval_trace_.rbegin()) { - bool continue_flag = false; - auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); - if (continue_flag) { - continue; - } - - // Try to travel the latest undetermined. - if (latest_entry != eval_trace_.rbegin()->first) { - MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); - auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); - MS_EXCEPTION_IF_NULL(eval_result->abstract()); - MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() - << " return out_spec: " << eval_result->abstract()->ToString(); - return eval_result; - } - } - } - - return ProcessEvalResults(out_specs); -} - -EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { - AnfNodeConfigPtr self = shared_from_base(); - return engine_.lock()->GetEvaluatedValue(self); -} - -AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { - if (value->isa()) { - auto func_graph = value->cast(); - return func_graph->MakeAbstractClosure(context); - } - AnfNodePtr anf_node = nullptr; - if (conf != nullptr) { - anf_node = conf->node(); - } - if (value->isa()) { - auto meta_func_graph = value->cast(); - return meta_func_graph->MakeAbstractClosure(anf_node); - } - if (value->isa()) { - auto prim = value->cast(); - return prim->ToPrimAbstract(anf_node); - } - return value->ToAbstract(); -} - -AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { - AbstractBasePtr a = ToAbstract(value, nullptr, nullptr); - if (broaden) { - a = a->Broaden(); - } - return a; -} - -EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { - auto evaluator = GetPrimEvaluator(primitive, nullptr); - MS_EXCEPTION_IF_NULL(evaluator); - if (!evaluator->isa()) { - MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but " - << evaluator->ToString(); - } - auto trivial_evaluator = dyn_cast(evaluator); - auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); - return eval_result; -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h deleted file mode 100644 index a0b7ee5478..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h +++ /dev/null @@ -1,280 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ -#define PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ - -#include -#include -#include -#include -#include -#include -#include - -#ifdef DEBUG -#include -#endif - -#include "utils/log_adapter.h" -#include "ir/anf.h" -#include "ir/primitive.h" -#include "pipeline/static_analysis/analysis_context.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "pipeline/parse/parse.h" - -namespace mindspore { -namespace abstract { -// define attribute value map -using AttrValueMap = std::unordered_map; -using AttrValueMapPtr = std::shared_ptr; - -// the class to save evaluated result: abstract value and modified attribute -class EvalResult : public Base { - public: - EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} - ~EvalResult() override = default; - MS_DECLARE_PARENT(EvalResult, Base); - AbstractBasePtr abstract() { return abstract_; } - AttrValueMapPtr attribute() { return attribute_; } - - private: - AbstractBasePtr abstract_; - AttrValueMapPtr attribute_; -}; - -using EvalResultPtr = std::shared_ptr; -// Superclass for AnfNodeConfig and VirtualConfig. -class Config : public Base { - public: - Config() = default; - ~Config() override = default; - MS_DECLARE_PARENT(Config, Base); - virtual EvalResultPtr GetEvaluatedValue() = 0; -}; - -// Config will be stored in AnalysisCache -using ConfigPtr = std::shared_ptr; -using ConfigPtrList = std::vector; - -// Config to a certain node in a certain context. -class AnfNodeConfig : public Config { - public: - AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context) - : Config(), engine_(std::weak_ptr(engine)), node_(node) { - FuncGraphPtr fg; - if (IsValueNode(node)) { - auto v = node->cast(); - fg = v->value()->cast(); - } else { - fg = node->func_graph(); - } - context_ = nullptr; - if (context != nullptr) { - context_ = context->Filter(fg); - } - } - - ~AnfNodeConfig() override = default; - MS_DECLARE_PARENT(AnfNodeConfig, Config); - - EvalResultPtr GetEvaluatedValue() override; - - AnalysisContextPtr context() const { return context_; } - - AnfNodePtr node() const { return node_; } - - AnalysisEnginePtr engine() const { return engine_.lock(); } - - // used by unordered_map; - bool operator==(const AnfNodeConfig &other) const { - // compare node with pointer, context with pointer except DummyContext as it's created by make_shared; - // context should not be nullptr; - if (context_->IsDummyContext() && other.context_->IsDummyContext()) { - return true; - } - return (node_ == other.node_) && (context_ == other.context_); - } - - std::string ToString() const override { - std::ostringstream buffer; - buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString(); - return buffer.str(); - } - - private: - // AnalysisEngine is global. - // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use - // weak_ptr to break Config cycle. - std::weak_ptr engine_; - AnfNodePtr node_; - AnalysisContextPtr context_; -}; - -using AnfNodeConfigPtr = std::shared_ptr; - -struct AnfNodeConfigHasher { - std::size_t operator()(const AnfNodeConfigPtr conf) const; -}; - -struct AnfNodeConfigEqual { - bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const; -}; - -class VirtualConfig : public Config { - public: - explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {} - - ~VirtualConfig() override = default; - MS_DECLARE_PARENT(VirtualConfig, Config); - EvalResultPtr GetEvaluatedValue() override { - return std::make_shared(abstract_, std::make_shared()); - } - - private: - AbstractBasePtr abstract_; -}; - -// AnalysisCache -class AnalysisCache { - public: - AnalysisCache() = default; - ~AnalysisCache() = default; - void Clear() { cache_.clear(); } - void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); - EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); - - private: - std::unordered_map cache_; -}; - -using PrimEvaluatorMap = std::unordered_map; -using AnfNodeConfigMap = - std::unordered_map; - -struct AnalysisResult { - EvalResultPtr inferred; - AnalysisContextPtr context; -}; - -using EvalTraceRevIter = std::list>::reverse_iterator; - -class AnalysisEngine : public std::enable_shared_from_this { - public: - AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) - : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {} - ~AnalysisEngine() = default; - - // func_graph: The func_graph to analyze. - // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. - AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); - EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); - // Return the Evaluator for the given function. - EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); - - AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); - EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); - // Infer the result of fn(args). - EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); - void Clear(); - void ClearEvaluatorCache(); - AnalysisCache &cache() { return cache_; } - AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) { - return std::make_shared(shared_from_this(), node, context); - } - // Overloaded function. - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - - FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; } - const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; } - - // Set the analysis result for orig to the result for new. - // This sets an entry in anfnode_config_map from orig to new. - EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { - // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. - (void)anfnode_config_map_.emplace(orig_conf, new_conf); - MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() - << ", to new_conf: " << new_conf->node()->DebugString(); - return GetEvaluatedValue(new_conf); - } - const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } - - AnalysisCache cache_; - std::unordered_map prim_py_evaluators_; - - private: - void SetUndeterminedFlag(const EvaluatorPtr &evaluator); - EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, - const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, - bool *continue_flag); - EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs); - - const PrimEvaluatorMap &prim_constructors_; - FuncGraphManagerPtr func_graph_manager_; - std::unordered_map constructors_; - AnfNodeConfigMap anfnode_config_map_; - // Use a list to trace multiple evaluators. - std::list> eval_trace_; - std::map multi_poss_; - - AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, - const ConfigPtrList &args_conf_list); - EvalResultPtr Eval(const AnfNodeConfigPtr &conf); - EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); - EvalResultPtr ExecuteEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list); - EvalResultPtr ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list); - -#ifdef DEBUG - std::vector compute_conf_stack_; -#endif -}; - -// Translate the value to an abstract value. -// Arguments: -// value: The value to convert. -// context: The context in which the value was found, used if the value is a Graph. -// conf: The Config to the valuenode we are converting, if there is one, -// so that we can generate a tracking_id. -AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr, - const AnfNodeConfigPtr &conf = nullptr); - -// Convert a value to an abstract value. -// Arguments: -// v: The value to convert. -// broaden: If True, concrete values will be made more abstract, so e.g. -// the value 1234 would become ANYTHING. -AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false); - -template -AbstractBasePtr FromValue(const T &value, bool broaden = false) { - return FromValueInside(MakeValue(value), broaden); -} - -EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); -} // namespace abstract -} // namespace mindspore - -#endif // PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/utils.cc b/mindspore/ccsrc/pipeline/static_analysis/utils.cc deleted file mode 100644 index 4c399f6ffc..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/utils.cc +++ /dev/null @@ -1,201 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/utils.h" - -#include -#include -#include -#include "utils/symbolic.h" -#include "pipeline/static_analysis/param_validator.h" - -namespace mindspore { -namespace abstract { -ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) { - MS_EXCEPTION_IF_NULL(value1); - MS_EXCEPTION_IF_NULL(value2); - if (*value1 == *value2) { - return value1; - } - return kAnyValue; -} - -TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) { - MS_EXCEPTION_IF_NULL(type1); - MS_EXCEPTION_IF_NULL(type2); - if (*type1 == *type2) { - return type1; - } - return kAnyType; -} - -ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { - MS_EXCEPTION_IF_NULL(shape1); - MS_EXCEPTION_IF_NULL(shape2); - if (*shape1 == *shape2) { - return shape1; - } - if (shape1->shape().size() != shape2->shape().size()) { - MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString(); - return shape1; - } - std::vector dims; - dims.resize(shape1->shape().size()); - for (std::size_t i = 0; i < shape1->shape().size(); i++) { - if (shape1->shape()[i] == shape2->shape()[i]) { - dims[i] = shape1->shape()[i]; - } else { - dims[i] = Shape::SHP_ANY; - } - } - return std::make_shared(dims); -} - -AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.size() < 1) { - MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is " << args_spec_list.size() - << "."; - } - AbstractBasePtr arg_spec_tmp = args_spec_list[0]; - MS_EXCEPTION_IF_NULL(arg_spec_tmp); - for (auto arg_spec : args_spec_list) { - arg_spec_tmp = arg_spec_tmp->Join(arg_spec); - MS_EXCEPTION_IF_NULL(arg_spec_tmp); - } - return arg_spec_tmp; -} - -AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2) { - if (spec1.size() != spec2.size()) { - MS_LOG(EXCEPTION) << "Join failed as list don't have the same size. spec1: " << ::mindspore::ToString(spec1) - << ", spec2: " << ::mindspore::ToString(spec2); - } - AbstractBasePtrList joined_list; - bool changes = false; - for (std::size_t i = 0; i < spec1.size(); i++) { - auto joined_elem = spec1[i]->Join(spec2[i]); - if (joined_elem != spec1[i]) { - changes = true; - } - joined_list.push_back(joined_elem); - } - if (!changes) { - return spec1; - } - return joined_list; -} - -AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) { - AbstractFunctionPtr f_spec = dyn_cast(spec); - if (f_spec != nullptr) { - return std::make_shared(kAnyValue, std::make_shared()); - } - return spec->Clone(); -} - -namespace { -// Join all types in args_type_list; -TypePtr TypeJoin(const TypePtrList &args_type_list) { - if (args_type_list.empty()) { - MS_LOG(EXCEPTION) << "args_type_list is empty"; - } - - TypePtr type_tmp = args_type_list[0]; - for (std::size_t i = 1; i < args_type_list.size(); i++) { - type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]); - } - return type_tmp; -} -} // namespace - -bool CheckType(const TypePtr &expected_type, const TypePtr &x) { - // As x and predicate both are mindspore type staticly, here we only to judge whether - // x is predicate or is a subclass of predicate. - return IsIdentidityOrSubclass(x, expected_type); -} - -TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) { - MS_EXCEPTION_IF_NULL(predicate); - for (auto arg_type : args_type_list) { - MS_EXCEPTION_IF_NULL(arg_type); - if (!CheckType(predicate, arg_type)) { - MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString(); - } - } - return TypeJoin(args_type_list); -} - -int GetPositiveAxis(int axis_value, size_t increment) { - if (axis_value < 0) { - axis_value = axis_value + SizeToInt(increment); - } - - if (axis_value < 0) { - MS_LOG(EXCEPTION) << "axis_value should not still <0"; - } - - return axis_value; -} - -// Return if two shapes can be broadcast. -// Broadcast shape is placed in broadcast_output_shape. -std::vector RealBroadcast(const std::string &op, std::vector x_shape, std::vector y_shape) { - std::reverse(x_shape.begin(), x_shape.end()); - std::reverse(y_shape.begin(), y_shape.end()); - // Fill a placeholder value 1 which will be replaced later. - size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size(); - y_shape.resize(std_len, 1); - x_shape.resize(std_len, 1); - - std::vector broadcast_shape; - for (size_t i = 0; i < std_len; i++) { - int x_i = x_shape[i]; // i-th dimension of x - int y_i = y_shape[i]; // i-th dimension of y - int output_i = 0; // i-th dimension of the output - if (x_i == y_i) { - output_i = x_i; - } else if (x_i == 1) { - output_i = y_i; - } else if (y_i == 1) { - output_i = x_i; - } else { - MS_LOG(EXCEPTION) - << op - << " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting " - "requirements"; - } - broadcast_shape.push_back(output_i); - } - std::reverse(broadcast_shape.begin(), broadcast_shape.end()); - return broadcast_shape; -} - -ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, - const AbstractTensorPtr &tensor_y) { - mindspore::abstract::ShapePtr tensor_x_shape = tensor_x->shape(); - mindspore::abstract::ShapePtr tensor_y_shape = tensor_y->shape(); - // if is the same shape ,just return the x_shape - if (*tensor_x_shape == *tensor_y_shape) { - return tensor_x_shape; - } - auto x_shape = tensor_x_shape->shape(); - auto y_shape = tensor_y_shape->shape(); - return std::make_shared(RealBroadcast(op, x_shape, y_shape)); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/utils.h b/mindspore/ccsrc/pipeline/static_analysis/utils.h deleted file mode 100644 index 6a709ea99c..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/utils.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_UTILS_H_ -#define PIPELINE_STATIC_ANALYSIS_UTILS_H_ - -#include -#include -#include -#include -#include "pipeline/static_analysis/abstract_value.h" -#include "utils/any.h" -#include "utils/misc.h" -#include "utils/convert_utils.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace abstract { -ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2); -TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2); -ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2); - -AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list); -AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2); - -// Return an abstract value for the sensitivity of x. -// The sensitivity of a function is an Env -// The sensitivity of J(x) is x -// else self.Clone; -AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec); - -TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list); - -bool CheckType(const TypePtr &expected_type, const TypePtr &x); - -int GetPositiveAxis(int axis_value, size_t increment); - -// Get broadcasted shape for binary element-wise operation -ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y); -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_UTILS_H_ diff --git a/mindspore/ccsrc/pipeline/validator.cc b/mindspore/ccsrc/pipeline/validator.cc deleted file mode 100644 index bbca3c8721..0000000000 --- a/mindspore/ccsrc/pipeline/validator.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/validator.h" - -#include -#include - -#include "ir/manager.h" -#include "ir/dtype.h" -#include "./common.h" -#include "pipeline/static_analysis/prim.h" - -namespace mindspore { -namespace validator { -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractClass; -using mindspore::abstract::AbstractError; -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractIndexedSlices; -using mindspore::abstract::AbstractJTagged; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractScalar; -using mindspore::abstract::AbstractTensor; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractType; - -void ValidateOperation(const AnfNodePtr &node) { - if (!IsValueNode(node)) { - return; - } - - // Primitive must in whitelist - PrimitivePtr prim = GetValueNode(node); - if (abstract::IsInWhiteList(prim)) { - return; - } - if (prim->HasPyEvaluator()) { - MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; - return; - } - if (prim->name() == "fake_bprop") { - MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue(prim->GetAttr("info")); - } - - MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name(); -} - -void ValidateAbstract(const AnfNodePtr &node) { - if (node == nullptr) { - MS_LOG(DEBUG) << "Node to validate is invalid"; - return; - } - AbstractBasePtr ptrBase = node->abstract(); - if (ptrBase == nullptr) { - MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString(); - return; - } - if (ptrBase->isa() || ptrBase->isa()) { - // Validate a type. - MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); - } - if (ptrBase->isa()) { - TypePtr ptrType = ptrBase->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(ptrType); - if (ptrType->isa() || ptrType->isa()) { - // only send string in external - if (!IsValueNode(node)) { - // Validate a type. - MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); - } - } - return; - } - if (ptrBase->isa()) { - // NOTICE: validate dead code? - MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString(); - return; - } - - if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa()) { - return; - } - - if (ptrBase->isa()) { - return; - } - - // Other types show exception - MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); -} - -void Validate(const FuncGraphPtr &fg) { - FuncGraphManagerPtr mgr = Manage(fg, false); - MS_EXCEPTION_IF_NULL(mgr); - AnfNodeSet &all_nodes = mgr->all_nodes(); - for (const auto &anf_node : all_nodes) { - ValidateOperation(anf_node); - ValidateAbstract(anf_node); - } -} -} // namespace validator -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/validator.h b/mindspore/ccsrc/pipeline/validator.h deleted file mode 100644 index 61f7470349..0000000000 --- a/mindspore/ccsrc/pipeline/validator.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_ -#define MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_ - -#include -#include -#include -#include -#include "operator/ops.h" -#include "ir/anf.h" -#include "utils/misc.h" - -namespace mindspore { -namespace validator { -void Validate(const FuncGraphPtr &func_graph); -void ValidateAbstract(const AnfNodePtr &node); -void ValidateOperation(const AnfNodePtr &node); -} // namespace validator -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H__ diff --git a/mindspore/ccsrc/pre_activate/CMakeLists.txt b/mindspore/ccsrc/pre_activate/CMakeLists.txt deleted file mode 100644 index 239757fb17..0000000000 --- a/mindspore/ccsrc/pre_activate/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "common/*.cc" - "mem_reuse/*.cc" - "pass/*.cc" - "gpu/*.cc" -) - -if (ENABLE_D) - file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc") - list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST}) -endif () - -set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT) -add_library(_mindspore_pre_activate_obj OBJECT ${_PREACTIVATE_SRC_LIST}) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h deleted file mode 100644 index 222c4b90b5..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ -#include -#include "session/kernel_graph.h" -namespace mindspore { -namespace opt { -void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph); -void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); -void AscendDataLayout(const std::shared_ptr &kernel_graph); -void AscendMixPrecision(const std::shared_ptr &kernel_graph); -void AscendBackendOptimization(const std::shared_ptr &kernel_graph); -void AscendGraphKernelCommonProcess(const std::shared_ptr &kernel_graph); -void AscendBackendGraphKernelOpt(const std::shared_ptr &kernel_graph, - bool is_before_kernel_select = false); -void AscendBackendFuseBasicOpt(const std::shared_ptr &kernel_graph, - bool is_before_kernel_select = false); -void AscendBackendAddAtomicClean(const std::shared_ptr &kernel_graph); -void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); -void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph); -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc deleted file mode 100644 index 9c498bd736..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ /dev/null @@ -1,345 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ascend_helper.h" -#include -#include "common/trans.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" -#include "utils/utils.h" -#include "device/kernel_info.h" -#include "kernel/oplib/oplib.h" -#include "kernel/common_utils.h" -#include "operator/ops.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; -namespace { -const std::set kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW}; -AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, - const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { - std::vector trans_inputs; - auto prim = std::make_shared(prim::kPrimReshape->name()); - trans_inputs.emplace_back(NewValueNode(prim)); - trans_inputs.emplace_back(input_node); - auto reshape = func_graph->NewCNode(trans_inputs); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape); - AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape); - reshape->set_scope(input_node->scope()); - kernel_select->SelectKernel(reshape); - return reshape; -} - -AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { - AnfNodePtr trans_node = nullptr; - AnfNodePtr input_node = node; - CNodePtr trans_data = nullptr; - std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); - std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; - std::vector padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); - MS_EXCEPTION_IF_NULL(node); - // if insert transdata for input we need to change the input - if (is_insert_input) { - if (!node->isa()) { - MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; - } - auto cnode = node->cast(); - dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); - input_node = AnfAlgo::GetInputNode(cnode, insert_index); - padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); - } - bool need_padding = false; - if (is_insert_input) { - need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); - } else { - need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); - } - if (!need_padding) { - // don't need padding insert transdata only - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - trans_node = trans_data; - } else if (is_insert_input) { - // if need padding & is input need insert a transdata - // reshape[padding shape] -> transdata[padding shape] -> node - auto padding_shape = - trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); - auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); - trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); - trans_node = trans_data; - } else { - // if need padding & is output need insert a transdata - // node -> transdata[padding shape] -> reshape[ori_shape] - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - auto reshape_node = - CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); - trans_node = reshape_node; - } - // refresh the transdata's format to ori format & dst format - RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); - return trans_node; -} - -AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(node); - auto input_node = AnfAlgo::GetInputNode(node, index); - auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); - MS_EXCEPTION_IF_NULL(node_with_index.first); - auto real_input = node_with_index.first; - if (real_input->isa() || real_input->isa()) { - input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); - MS_EXCEPTION_IF_NULL(input_node); - AnfAlgo::SetNodeInput(node, input_node, index); - } - std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); - std::string dest_format = AnfAlgo::GetInputFormat(node, index); - if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { - MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) - << " To DefaultFormat , index: " << index; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); - } - return input_node; -} - -AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(node); - std::string output_format = AnfAlgo::GetOutputFormat(node, 0); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, 0); - if (output_format == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " - << node->DebugString(); - } - if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { - MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); - } - return node; -} - -AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { - std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); - if (output_format == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node " - << node->DebugString(); - } - auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { - make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); - } else { - // No need insert trans op. - make_tuple_inputs.push_back(tuple_getitem); - } - } - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - return make_tuple; -} -} // namespace -void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type) { - MS_EXCEPTION_IF_NULL(trans_data); - auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); - MS_EXCEPTION_IF_NULL(ori_build_info); - auto builder = std::make_shared(ori_build_info); - builder->SetInputsFormat({input_format}); - builder->SetInputReshapeType({reshape_type}); - builder->SetOutputReshapeType({reshape_type}); - builder->SetOutputsFormat({output_format}); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); -} - -CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, - const bool need_padding, const std::string &op_name) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(input); - std::vector trans_inputs; - auto prim = std::make_shared(op_name); - trans_inputs.push_back(NewValueNode(prim)); - trans_inputs.push_back(input); - CNodePtr trans_node = func_graph->NewCNode(trans_inputs); - MS_EXCEPTION_IF_NULL(trans_node); - auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); - if (need_padding) { - // if need padding we should set the transdata node's shape to the padding shape - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, - {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, - trans_node.get()); - } else { - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, - {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); - } - // special handle for ut - if (trans_node->kernel_info() == nullptr) { - auto kernel_info = std::make_shared(); - trans_node->set_kernel_info(kernel_info); - } - MS_EXCEPTION_IF_NULL(kernel_select); - kernel_select->SelectKernel(trans_node); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); - MS_EXCEPTION_IF_NULL(trans_node); - trans_node->set_scope(input->scope()); - return trans_node; -} - -AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, - const TypeId &input_type, const TypeId &output_type, - const std::vector &origin_shape, const TypeId &origin_type) { - MS_EXCEPTION_IF_NULL(func_graph); - std::string input_format = format; - std::string output_format = format; - std::vector new_cast_inputs; - auto prim = std::make_shared(prim::kPrimCast->name()); - new_cast_inputs.push_back(NewValueNode(prim)); - new_cast_inputs.push_back(input); - CNodePtr cast = func_graph->NewCNode(new_cast_inputs); - MS_EXCEPTION_IF_NULL(cast); - // set kernel build info - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetInputsFormat({input_format}); - builder.SetOutputsFormat({output_format}); - builder.SetInputsDeviceType({input_type}); - builder.SetOutputsDeviceType({output_type}); - builder.SetFusionType(kernel::FusionType::OPAQUE); - builder.SetProcessor(kernel::Processor::AICORE); - if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) { - builder.SetKernelType(KernelType::TBE_KERNEL); - } else { - builder.SetKernelType(KernelType::AKG_KERNEL); - } - // if kernel info is null , it remarks this function is running ut - if (cast->kernel_info() == nullptr) { - auto kernel_info = std::make_shared(); - cast->set_kernel_info(kernel_info); - } - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); - AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); - AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); - return cast; -} - -AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - size_t outputs_num = AnfAlgo::GetOutputTensorNum(node); - if (outputs_num == 0) { - return node; - } - // Single output - if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) { - return InsertTransOpForSingleOutput(func_graph, node, kernel_select); - } - // Multiple output - return InsertTransOpForMultipleOutput(func_graph, node, kernel_select); -} - -AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { - AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select); - MS_EXCEPTION_IF_NULL(input_node); - new_inputs.push_back(input_node); - } - CNodePtr new_cnode = nullptr; - // cnode changed so make a new cnode to differ from original one. - auto kernel_graph = func_graph->cast>(); - if (kernel_graph == nullptr) { - new_cnode = std::make_shared(*cnode); - } else { - new_cnode = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_inputs(new_inputs); - return new_cnode; -} - -CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { - const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); - TypeId origin_type(kTypeUnknown); - auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); - auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); - auto real_input_node = kernel_with_index.first; - if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - // weight - origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index); - if (origin_type == kTypeUnknown) { - origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); - } - } else { - // feature map - origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); - } - const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); - const std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); - const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); - // In graph kernel, we check parameter, - // the eliminate pass will not eliminate this case, so we just do not insert the noused cast. - if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode(cur_input)) { - new_inputs.push_back(cur_input); - } else if (origin_type != device_type) { - auto cast = - AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type); - MS_EXCEPTION_IF_NULL(cast); - cast->set_scope(cnode->scope()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); - new_inputs.push_back(cast); - } else { - new_inputs.push_back(cur_input); - } - } - auto kernel_graph = func_graph->cast>(); - CNodePtr new_node = nullptr; - if (kernel_graph == nullptr) { - new_node = std::make_shared(*cnode); - } else { - new_node = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_inputs(new_inputs); - return new_node; -} - -AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto prim = std::make_shared(kMemCpyAsyncOpName); - std::vector new_node_inputs = {NewValueNode(prim), node}; - auto new_node = graph->NewCNode(new_node_inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_abstract(node->abstract()); - new_node->set_scope(node->scope()); - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h deleted file mode 100644 index ad48ca5291..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ - -#include -#include -#include -#include "device/ascend/kernel_select_ascend.h" -#include "kernel/kernel_query.h" -#include "kernel/oplib/oplib.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -class KernelSelect { - public: - KernelSelect() = default; - virtual ~KernelSelect() = default; - virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } -}; -using KernelSelectPtr = std::shared_ptr; - -class SupportedChecker { - public: - SupportedChecker() = default; - virtual ~SupportedChecker() = default; - virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, - const kernel::KernelBuildInfoPtr &select_kernel_build_info) { - return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); - } - virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, - const kernel::KernelBuildInfoPtr &select_kernel_build_info) { - return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); - } -}; -using SupportedCheckerPtr = std::shared_ptr; - -class KernelQuery { - public: - KernelQuery() = default; - virtual ~KernelQuery() = default; - virtual void Query(const CNodePtr &kernel_node, - std::vector> *kernel_info_list) { - kernel::KernelQuery(kernel_node, kernel_info_list); - } - virtual bool IsTbeRef(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); - if (op_info != nullptr) { - return op_info->is_ref(); - } - return false; - } -}; -using KernelQueryPtr = std::shared_ptr; -void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); - -CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, - const bool need_padding, const std::string &op_name); - -AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, - const TypeId &input_type, const TypeId &output_type, - const std::vector &origin_shape, const TypeId &origin_type); - -AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select); - -AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select); - -CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); - -AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node); -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc deleted file mode 100644 index 94318d63ca..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(relu_input); - auto add = relu_input->cast(); - MS_EXCEPTION_IF_NULL(add); - auto tuple_getitem = add->input(1); - MS_EXCEPTION_IF_NULL(tuple_getitem); - if (tuple_getitem->isa() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) { - auto getitem = tuple_getitem->cast(); - MS_EXCEPTION_IF_NULL(getitem); - auto bnupdate = getitem->input(1); - MS_EXCEPTION_IF_NULL(bnupdate); - if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { - std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); - for (auto out_getitem : manager->node_users()[bnupdate]) { - MS_EXCEPTION_IF_NULL(out_getitem.first); - auto out_getitem_ptr = out_getitem.first->cast(); - MS_EXCEPTION_IF_NULL(out_getitem_ptr); - auto input2 = out_getitem_ptr->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); - } - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); - std::unordered_set record{cnode, relu_input, bnupdate}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } -} - -void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { - auto eltwise_input = cnode->input(1); - if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { - MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); - } - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h deleted file mode 100644 index 6cdc5885f6..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass { - public: - explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {} - ~BnupdateEltwiseEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc deleted file mode 100644 index 1f7fef9e62..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(relu_input); - auto getitem = relu_input->cast(); - MS_EXCEPTION_IF_NULL(getitem); - auto bnupdate = getitem->input(1); - MS_EXCEPTION_IF_NULL(bnupdate); - if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { - std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); - for (auto out_getitem : manager->node_users()[bnupdate]) { - MS_EXCEPTION_IF_NULL(out_getitem.first); - auto out_getitem_ptr = out_getitem.first->cast(); - MS_EXCEPTION_IF_NULL(out_getitem_ptr); - auto input2 = out_getitem_ptr->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); - } - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); - std::unordered_set record{cnode, bnupdate}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { - auto eltwise_input = cnode->input(1); - if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { - MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); - } - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h deleted file mode 100644 index b5688f3a36..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class BnupdateEltwiseFusionPass : public FusionBasePass { - public: - explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {} - ~BnupdateEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc deleted file mode 100644 index 6091eb572d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise( - const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(eltwise_input); - if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - } else { - return; - } - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - auto double_in_eltwise_input = input_cnode->input(1); - MS_EXCEPTION_IF_NULL(double_in_eltwise_input); - if (!double_in_eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || - fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { - return; - } - if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput)) { - (void)record.insert(double_in_eltwise_input); - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && - (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { - MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h deleted file mode 100644 index 7d779d35f8..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass { - public: - explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {} - ~Conv2DBackpropEltwiseEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchConv2DBackpropInputEltwiseEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc deleted file mode 100644 index 963f1885fe..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(eltwise_input); - if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || - fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { - return; - } - if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) { - (void)record.insert(eltwise_input); - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && - (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { - MatchConv2DBackpropInputEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h deleted file mode 100644 index 171352de9b..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class Conv2DBackpropEltwiseFusionPass : public FusionBasePass { - public: - explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {} - ~Conv2DBackpropEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc deleted file mode 100644 index 63e7dcf6b8..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" - -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - auto conv = cnode->input(1); - MS_EXCEPTION_IF_NULL(conv); - if (conv->isa() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) { - std::vector output_used_num{SizeToInt(manager->node_users()[conv].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv); - std::unordered_set record{cnode, conv}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void ConvBnReduceFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { - MatchConvBnreduce(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h deleted file mode 100644 index 7a06faa624..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class ConvBnReduceFusionPass : public FusionBasePass { - public: - explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("ConvBnReduceFusionPass", idAllocator) {} - ~ConvBnReduceFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_CONV_BNREDUCE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc deleted file mode 100644 index a126143811..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(eltwise_input); - if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - } else { - return; - } - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - auto double_in_eltwise_input = input_cnode->input(1); - MS_EXCEPTION_IF_NULL(double_in_eltwise_input); - if (!double_in_eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || - fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { - return; - } - if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONVLUTION) { - (void)record.insert(double_in_eltwise_input); - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchConvDoubleInEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h deleted file mode 100644 index 062b8182fb..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class ConvDoubleInFusionPass : public FusionBasePass { - public: - explicit ConvDoubleInFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("ConvDoubleInFusionPass", idAllocator) {} - ~ConvDoubleInFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.cc deleted file mode 100644 index d83b32a888..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - while (CheckEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - if (record.size() == MAX_ELTWISE_NUM) { - break; - } - } - MS_EXCEPTION_IF_NULL(eltwise_input); - if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || - fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { - return; - } - if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONVLUTION) { - (void)record.insert(eltwise_input); - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void ConvSingleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchConvSingleInEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h deleted file mode 100644 index bf7e581dff..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class ConvSingleInFusionPass : public FusionBasePass { - public: - explicit ConvSingleInFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("ConvSingleInFusionPass", idAllocator) {} - ~ConvSingleInFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchConvSingleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc deleted file mode 100644 index 98a6838bed..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" - -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnode, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion, bool is_order) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - if (is_order) { - // DepthwiseConvolution--->Elemwise - auto depthwise_conv = cnode->input(1); - MS_EXCEPTION_IF_NULL(depthwise_conv); - if (cnode->isa() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) { - std::vector output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv); - std::unordered_set record{cnode, depthwise_conv}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } else { - // Elemwise-->DepthwiseConvolution - auto relu = cnode->input(1); - MS_EXCEPTION_IF_NULL(relu); - if (cnode->isa() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) { - std::vector output_used_num{SizeToInt(manager->node_users()[relu].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu); - std::unordered_set record{cnode, relu}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } -} - -void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { - auto eltwise_input = cnode->input(1); - if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { - MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); - } - } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { - MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h deleted file mode 100644 index c2e72f26ff..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class DepthwiseConvEltwiseFusionPass : public FusionBasePass { - public: - explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {} - ~DepthwiseConvEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion, bool is_order); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc deleted file mode 100644 index 2f04e16692..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(eltwise_input); - while (CheckEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - if (record.size() == MAX_ELTWISE_SIZE) { - break; - } - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - } - if (record.size() < MIN_ELTWISE_SIZE) { - return; - } - candidate_fusion->push_back(record); - SetRecordFusionId(record); -} - -void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - std::reverse(node_list.begin(), node_list.end()); - for (auto &node : node_list) { - MS_EXCEPTION_IF_NULL(node); - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h deleted file mode 100644 index 54ff0f5982..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class EltwiseFusionPass : public FusionBasePass { - public: - explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} - ~EltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc deleted file mode 100644 index a516f04442..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include -#include -#include "debug/anf_ir_dump.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(node); - if (!node->isa() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto user_nodes = manager->node_users()[node]; - return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && - cnode->inputs().size() == ELTWISE_INPUT_SIZE; -} - -bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(node); - if (!node->isa() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto user_nodes = manager->node_users()[node]; - return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && - cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE; -} - -bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(node); - if (!node->isa() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto user_nodes = manager->node_users()[node]; - return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_MULTI_USE && - cnode->inputs().size() == ELTWISE_INPUT_SIZE; -} - -void FusionBasePass::SetRecordFusionId(const std::unordered_set &record) { - auto id = fusion_id_allocator->AllocateFusionId(); - for (auto node : record) { - fusion_id_allocator->SetFusionId(node, id); - } -} - -bool FusionBasePass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) { - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - auto return_node = kernel_graph.get_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->inputs().size() <= 1) { - return false; - } - MS_LOG(DEBUG) << "MatchBufferFusionPattern start..."; - FusedNodeRecord candidate_fusion; - MatchSingleFusionPattern(kernel_graph, &candidate_fusion); - if (candidate_fusion.empty()) { - return false; - } - MS_LOG(DEBUG) << "MatchBufferFusionPattern Success..."; - return true; -} - -bool FusionBasePass::Run(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - auto kernel_graph = graph->cast>(); - MS_EXCEPTION_IF_NULL(kernel_graph); - return MatchUBFusionPattern(*kernel_graph); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h deleted file mode 100644 index 8d6eca774c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ -#include -#include -#include -#include - -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -const int8_t MAX_ELTWISE_NUM = 3; -const int8_t MIN_ELTWISE_SIZE = 2; -const int8_t ELTWISE_INPUT_SIZE = 2; -const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3; -const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3; -const int8_t CONV_QUART_IN_INPUT_SIZE = 5; -const int8_t ELTWISE_USE = 1; -const int8_t ELTWISE_MULTI_USE = 2; -const int8_t MAX_ELTWISE_SIZE = 6; -const int8_t MULTI_ELTWISE_SIZE = 4; -using FusedNodeRecord = std::vector>; - -struct BufferFusionInfo_t { - std::vector anf_nodes; - std::vector inputs_list; - std::vector outputs_list; - kernel::KernelBuildInfoPtr kernel_build_info; -}; - -class FusionBasePass : public Pass { - public: - FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator) - : Pass(name), fusion_id_allocator(idAllocator) {} - ~FusionBasePass() override = default; - bool Run(const FuncGraphPtr &graph) override; - bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph); - - protected: - virtual void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) = 0; - void SetRecordFusionId(const std::unordered_set &record); - bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); - bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); - bool CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); - FusionIdAllocatorPtr fusion_id_allocator; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc deleted file mode 100644 index d1ef5dc83b..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::vector output_used_num{SizeToInt(manager->node_users()[relu_input].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input); - std::unordered_set record{cnode, relu_input}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); -} - -void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { - auto eltwise_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(eltwise_input); - if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { - MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); - } - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h deleted file mode 100644 index 5baaa6db86..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class MatmulEltwiseFusionPass : public FusionBasePass { - public: - explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {} - ~MatmulEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc deleted file mode 100644 index be4d2af1cb..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(eltwise_input); - if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) { - std::vector output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); - (void)record.insert(eltwise_input); - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - } else { - return; - } - while (CheckEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - if (record.size() == MULTI_ELTWISE_SIZE) { - break; - } - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - } - if (record.size() != MULTI_ELTWISE_SIZE) { - return; - } - candidate_fusion->push_back(record); - SetRecordFusionId(record); -} - -void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - std::reverse(node_list.begin(), node_list.end()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchMultiOutputEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h deleted file mode 100644 index 0e2510128a..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class MultiOutputFusionPass : public FusionBasePass { - public: - explicit MultiOutputFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("MultiOutputFusionPass", idAllocator) {} - ~MultiOutputFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchMultiOutputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc deleted file mode 100644 index 623f0e3426..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - while (CheckEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - if (record.size() == MAX_ELTWISE_NUM) { - break; - } - } - MS_EXCEPTION_IF_NULL(eltwise_input); - if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || - fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { - return; - } - if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE) { - (void)record.insert(eltwise_input); - auto previous_input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(previous_input_cnode); - auto previous_eltwise_input = previous_input_cnode->input(1); - auto previous_size = record.size(); - while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { - (void)record.insert(previous_eltwise_input); - auto previous_node = previous_eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(previous_node); - previous_eltwise_input = previous_node->input(1); - if (record.size() - previous_size == MAX_ELTWISE_NUM) { - break; - } - } - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - std::reverse(node_list.begin(), node_list.end()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchReduceEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h deleted file mode 100644 index 42d896e96b..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class ReduceEltwiseFusionPass : public FusionBasePass { - public: - explicit ReduceEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("ReduceEltwiseFusionPass", idAllocator) {} - ~ReduceEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchReduceEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWSIE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc deleted file mode 100644 index 0dcf2362bc..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - while (CheckEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - if (record.size() == MAX_ELTWISE_NUM) { - break; - } - } - MS_EXCEPTION_IF_NULL(eltwise_input); - if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || - fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { - return; - } - if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::SEGMENT) { - (void)record.insert(eltwise_input); - auto previous_input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(previous_input_cnode); - auto previous_eltwise_input = previous_input_cnode->input(1); - auto previous_size = record.size(); - while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { - (void)record.insert(previous_eltwise_input); - auto previous_node = previous_eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(previous_node); - previous_eltwise_input = previous_node->input(1); - if (record.size() - previous_size == MAX_ELTWISE_NUM) { - break; - } - } - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - std::reverse(node_list.begin(), node_list.end()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchSegmentEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h deleted file mode 100644 index 41f06ba1f9..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class SegmentEltwiseFusionPass : public FusionBasePass { - public: - explicit SegmentEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("SegmentEltwiseFusionPass", idAllocator) {} - ~SegmentEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchSegmentEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWSIE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc deleted file mode 100644 index 5bc0fdced7..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h" - -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(const CNodePtr &cnode, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto write_input = cnode->input(1); - if (CheckEltWiseNode(manager.get(), write_input)) { - (void)record.insert(write_input); - auto input_cnode = write_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - write_input = input_cnode->input(1); - } - MS_EXCEPTION_IF_NULL(write_input); - if (!write_input->isa() || !AnfAlgo::IsRealCNodeKernel(write_input) || - fusion_id_allocator->HasFusionIdAttr(write_input)) { - return; - } - auto conv_cnode = write_input->cast(); - MS_EXCEPTION_IF_NULL(conv_cnode); - if (AnfAlgo::GetKernelType(conv_cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(conv_cnode) == kernel::FusionType::CONVLUTION && - conv_cnode->inputs().size() >= CONV_DOUBLE_IN_INPUT_SIZE && - conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) { - (void)record.insert(write_input); - auto conv_input = conv_cnode->input(1); - MS_EXCEPTION_IF_NULL(conv_input); - if (!conv_input->isa() || !AnfAlgo::IsRealCNodeKernel(conv_input) || - fusion_id_allocator->HasFusionIdAttr(conv_input)) { - return; - } - if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) { - (void)record.insert(conv_input); - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } -} - -void StridedReadConvStridedWriteFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == kStridedWriteOpName) { - MatchStridedReadConvStridedWrite(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h deleted file mode 100644 index c6c5fe88dc..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class StridedReadConvStridedWriteFusionPass : public FusionBasePass { - public: - explicit StridedReadConvStridedWriteFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("StridedReadConvStridedWriteFusionPass", idAllocator) {} - ~StridedReadConvStridedWriteFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchStridedReadConvStridedWrite(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.cc deleted file mode 100644 index faa5169c40..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.cc +++ /dev/null @@ -1,448 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -namespace { -const int8_t MAX_PATTERN_SIZE = 7; -const int8_t MIN_PATTERN_SIZE = 2; -const int8_t ELTWISE_INPUT_SIZE = 2; -const int8_t ELTWISE_USE = 1; -const int8_t MULTI_ELTWISE_USE = 2; -const int8_t MAX_MULTI_ELTWISE_SIZE = 4; -const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3; -constexpr auto kOpAttrFusionId = "fusion_id"; - -#ifdef DEBUG -std::string GetFusionTypeName(const kernel::FusionType &type) { - switch (type) { - case kernel::FusionType::COMMREDUCE: - return "COMMREDUCE"; - case kernel::FusionType::SEGMENT: - return "SEGMENT"; - case kernel::FusionType::ELEMWISE: - return "ELEMWISE"; - case kernel::FusionType::CONVLUTION: - return "CONVLUTION"; - case kernel::FusionType::OPAQUE: - return "OPAQUE"; - default: - return "OPAQUE"; - } -} - -void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) { - MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id; - for (auto &node : info.input_nodes) { - MS_LOG(INFO) << "=== Input: " << node->DebugString(); - } - for (auto &node : info.output_nodes) { - MS_LOG(INFO) << "=== Output: " << node->DebugString(); - } - for (auto &node : info.compute_nodes) { - MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-(" << GetFusionTypeName(AnfAlgo::GetFusionType(node)) - << ")"; - } - MS_LOG(INFO) << "=== Dump FusionScopeInfo end"; -} -#endif -CNodePtr CreateFusionOp(const std::vector &inputs_list, const std::vector &outputs_list, - const std::vector &anf_nodes, session::KernelGraph *kernel_graph) { - MS_LOG(DEBUG) << "Start Create FusionOp Kernel"; - MS_EXCEPTION_IF_NULL(kernel_graph); - std::string fusion_op_name = "FusionOp"; - for (auto node : anf_nodes) { - fusion_op_name += '_' + AnfAlgo::GetCNodeName(node); - } - auto fusion_op = std::make_shared(fusion_op_name); - MS_EXCEPTION_IF_NULL(fusion_op); - - std::vector input_names; - for (uint8_t i = 0; i < inputs_list.size(); i++) { - input_names.emplace_back("input" + std::to_string(i)); - } - std::vector output_names; - for (uint8_t i = 0; i < outputs_list.size(); i++) { - output_names.emplace_back("output" + std::to_string(i)); - } - - ValuePtr input_names_v = MakeValue(input_names); - ValuePtr output_names_v = MakeValue(output_names); - fusion_op->set_attr("input_names", input_names_v); - fusion_op->set_attr("output_names", output_names_v); - std::vector fusion_inputs_list = inputs_list; - auto value_node = std::make_shared(fusion_op); - (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node); - auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list); - if (buffer_fusion_kernel == nullptr) { - MS_LOG(EXCEPTION) << "New FusionOp kernel failed!"; - } - buffer_fusion_kernel->set_scope((anf_nodes.back())->scope()); - - return buffer_fusion_kernel; -} - -kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector &inputs_list, - const std::vector &outputs_list) { - MS_LOG(DEBUG) << "Start Create Kernel Info"; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - // inputs format and data type - std::vector inputs_format; - std::vector inputs_data_type; - for (const auto &input : inputs_list) { - auto real_input = AnfAlgo::VisitKernel(input, 0); - inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); - inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); - } - // outputs format and data type - std::vector outputs_format; - std::vector outputs_data_type; - for (const auto &output : outputs_list) { - if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { - auto tuple_getitem = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - outputs_format.push_back(AnfAlgo::GetOutputFormat( - tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); - outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( - tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); - } else { - outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0)); - outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); - } - } - builder.SetInputsFormat(inputs_format); - builder.SetInputsDeviceType(inputs_data_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_data_type); - builder.SetKernelType(KernelType::TBE_KERNEL); - return builder.Build(); -} - -AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph, - size_t output_index) { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector tuple_getitem_inputs_list; - auto value = std::make_shared(prim::kPrimTupleGetItem); - MS_EXCEPTION_IF_NULL(value); - auto idx = NewValueNode(SizeToInt(output_index)); - MS_EXCEPTION_IF_NULL(idx); - int temp = SizeToInt(output_index); - auto imm = std::make_shared(temp); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - tuple_getitem_inputs_list.push_back(value); - tuple_getitem_inputs_list.push_back(buffer_fusion_kernel); - tuple_getitem_inputs_list.push_back(idx); - auto tuple_item = kernel_graph->NewCNode(tuple_getitem_inputs_list); - MS_EXCEPTION_IF_NULL(tuple_item); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)}, - {AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index)}, - tuple_item.get()); - return tuple_item; -} - -void ReplaceInputNodeInOtherFusionScope(std::unordered_map *buffer_fusion_infos, - int32_t fusion_id, const AnfNodePtr &output_item, - const AnfNodePtr &replace_item) { - for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) { - auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(), - output_item); - if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) { - MS_LOG(DEBUG) << "replace input of other pattern, id = " << id; - *itr = replace_item; - } - } -} - -void ReplaceOldNode(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, - const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; - if (buffer_fusion_info.outputs_list.size() == 1) { // single output - (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); - ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0], - buffer_fusion_kernel); - } else { // multiple output - for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) { - auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index); - (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item); - ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index], - tuple_item); - } - } -} - -void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - MS_EXCEPTION_IF_NULL(kernel_graph); - auto nodes = TopoSort(kernel_graph->get_return()); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { - auto fusion_id = AnfAlgo::GetNodeAttr(cnode, kOpAttrFusionId); - (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode); - } - } -} - -void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - auto fusion_id = buffer_fusion_info.first; - auto fusion_info = buffer_fusion_info.second; - for (const auto &node : fusion_info.anf_nodes) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { - auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == - fusion_info.anf_nodes.end()) { - if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), - (*buffer_fusion_infos)[fusion_id].inputs_list.end(), - cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { - (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx)); - } - } - } - } - } -} - -bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { - MS_EXCEPTION_IF_NULL(node1); - MS_EXCEPTION_IF_NULL(node2); - auto getitem1 = node1->cast(); - auto getitem2 = node2->cast(); - MS_EXCEPTION_IF_NULL(getitem1); - MS_EXCEPTION_IF_NULL(getitem2); - if (getitem1->size() < kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" - << getitem1->DebugString() << "]"; - } - if (getitem2->size() < kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" - << getitem2->DebugString() << "]"; - } - auto output_idx1 = GetValue(GetValueNode(getitem1->input(2))); - auto output_idx2 = GetValue(GetValueNode(getitem2->input(2))); - return output_idx1 < output_idx2; -} - -void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - auto fusion_id = buffer_fusion_info.first; - auto fusion_info = buffer_fusion_info.second; - for (const auto &node : fusion_info.anf_nodes) { - if (AnfAlgo::GetOutputTensorNum(node) == 1) { - for (auto use_node : manager->node_users()[node]) { - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) == - fusion_info.anf_nodes.end()) { - (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node); - break; - } - } - } else { - int prev_idx = 0; - std::vector tuple_getitem_nodes; - std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), - std::back_inserter(tuple_getitem_nodes), - [](const std::pair &use_node) { return use_node.first; }); - std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); - for (auto getitem : tuple_getitem_nodes) { - MS_EXCEPTION_IF_NULL(getitem); - auto getitem_ptr = getitem->cast(); - auto input2 = getitem_ptr->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) { - auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx)); - (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node); - } - prev_idx = output_idx + 1; - for (auto item_use_node : manager->node_users()[getitem]) { - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) == - fusion_info.anf_nodes.end()) { - (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem); - break; - } - } - } - } - } - } -} - -void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector &outputs_list, - const AnfNodePtr &fusion_kernel) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - for (size_t idx = 0; idx < outputs_list.size(); ++idx) { - auto output = outputs_list[idx]; - MS_EXCEPTION_IF_NULL(output); - if (output->isa() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { - auto real_output = AnfAlgo::VisitKernel(output, 0); - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - auto input2 = output_cnode->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - session::AnfWithOutIndex out_pair(real_output.first, output_idx); - if (kernel_graph->IsInRefOutputMap(out_pair)) { - auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); - session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); - kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); - } - } else { - session::AnfWithOutIndex out_pair(output, 0); - if (kernel_graph->IsInRefOutputMap(out_pair)) { - auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); - session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); - kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); - } - } - } -} -} // namespace - -void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) const { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); - GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); - GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - buffer_fusion_info.second.kernel_build_info = - CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); - } -} - -bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - bool change = false; - std::unordered_map buffer_fusion_infos; - buffer_fusion_infos.clear(); - GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos); - - std::vector fusion_scope_infos; - for (auto &buffer_fusion_info : buffer_fusion_infos) { - mindspore::kernel::FusionScopeInfo fusion_scope_info; - fusion_scope_info.scope_id = buffer_fusion_info.first; - fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list; - fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes; - fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list; - fusion_scope_infos.push_back(fusion_scope_info); -#ifdef DEBUG - DumpFusionScopeInfo(fusion_scope_info); -#endif - } - auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos); - std::vector fusion_ids; - for (auto &buffer_fusion_info : buffer_fusion_infos) { - MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size() - << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size() - << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size(); - fusion_ids.push_back(buffer_fusion_info.first); - } - // Replace fusion op from return to head - std::sort(fusion_ids.begin(), fusion_ids.end()); - for (auto &fusion_id : fusion_ids) { - // Get kernel mod when supporting tbe - if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) { - MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; - continue; - } - change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); - } - MS_LOG(DEBUG) << "End Buffer Fusion"; - return change; -} - -bool UbPatternFusion::ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, - int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, - session::KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; - auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, - buffer_fusion_info.anf_nodes, kernel_graph); - AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get()); - // Set abstract of fusion_op node - std::vector types; - std::vector> shapes; - for (const auto &out_node : buffer_fusion_info.outputs_list) { - for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) { - types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx)); - shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx)); - } - } - if (types.empty() || shapes.empty()) { - MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty"; - return false; - } - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); - AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); - SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion); - ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); - return true; -} - -bool UbPatternFusion::Run(const FuncGraphPtr &graph) { - bool changed = false; - MS_EXCEPTION_IF_NULL(graph); - auto kernel_graph = graph->cast>(); - MS_EXCEPTION_IF_NULL(kernel_graph); - changed = FuseBufferFusionPattern(kernel_graph.get()); - // clear fusion_id attr - for (auto &node : graph->nodes()) { - if (node != nullptr && node->isa()) { - AnfAlgo::EraseNodeAttr(kAttrFusionId, node); - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h deleted file mode 100644 index 7099c92772..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ -#include -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class UbPatternFusion : public Pass { - public: - UbPatternFusion() : Pass("TbeBufferFusion") {} - ~UbPatternFusion() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - void GetBufferFusionInfo(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) const; - bool ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, - const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const; - bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc deleted file mode 100644 index 6d0906363e..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" - -namespace mindspore::opt { - -const BaseRef GetnextMemcpyElimination::DefinePattern() const { - auto prim_memcpy = std::make_shared(kMemCpyAsyncOpName); - VarPtr x = std::make_shared(); - VectorRef memcpy_async({prim_memcpy, x}); - return memcpy_async; -} - -const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (graph == nullptr || node == nullptr || equiv == nullptr) { - return nullptr; - } - auto memcpy_cnode = node->cast(); - if (memcpy_cnode == nullptr) { - return nullptr; - } - - // 1. memcpy has attr kAttrLabelForInsertStreamActive - if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, memcpy_cnode)) { - MS_LOG(DEBUG) << "node has no label_for_insert_stream_active attr"; - return nullptr; - } - - // 2. memcpy's output has only one user next_node - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(memcpy_cnode) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "memcpy has no output in manager"; - } - auto next_nodes = manager->node_users()[memcpy_cnode]; - if (next_nodes.size() > 1) { - MS_LOG(DEBUG) << "node's output has more than one users"; - return nullptr; - } - - // 3. next_node is not nop node and it has only one input which is memcpy's output - for (auto &item : next_nodes) { - auto next_node = item.first->cast(); - if (opt::IsNopNode(next_node)) { - return nullptr; - } - if (next_node->inputs().size() != 2) { - MS_LOG(DEBUG) << "next node has more than one input"; - return nullptr; - } - // add attr label_for_insert_stream_active for next_node - AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), next_node); - } - - return memcpy_cnode->input(1); -} -} // namespace mindspore::opt diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h deleted file mode 100644 index 523fc87a38..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class GetnextMemcpyElimination : public PatternProcessPass { - public: - explicit GetnextMemcpyElimination(bool multigraph = true) - : PatternProcessPass("getnext_memcpy_elimination", multigraph) {} - ~GetnextMemcpyElimination() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc deleted file mode 100644 index 01a3f789e7..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" -#include -#include -#include "pre_activate/ascend/ascend_helper.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - if (func_graph == nullptr || node == nullptr) { - return nullptr; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(node); - if (output_num == 0) { - MS_LOG(DEBUG) << "Output number is zero, no need to insert memcpy_async!"; - return node; - } - - // getnext output is tuple and dynamic - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - - for (size_t output_index = 0; output_index < output_num; ++output_index) { - auto tuple_get_item = CreatTupleGetItemNode(func_graph, node, output_index); - auto new_node = CreateMemcpyAsyncOp(func_graph, tuple_get_item); - if (new_node == nullptr) { - MS_LOG(EXCEPTION) << "Create memcpy_async op failed!"; - } - AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node); - make_tuple_inputs.push_back(new_node); - } - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - return make_tuple; -} - -const BaseRef InsertMemcpyAsyncForGetNext::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - auto prim = std::make_shared(kGetNextOpName); - - return VectorRef({prim, Xs}); -} - -const AnfNodePtr InsertMemcpyAsyncForGetNext::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (func_graph == nullptr || node == nullptr || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - - auto cnode = node->cast(); - if (AnfAlgo::HasNodeAttr(kAttrVisited, cnode)) { - MS_LOG(DEBUG) << "Node op_name[" << kGetNextOpName << "] has visited."; - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cnode); - - return InsertMemcpyAsyncForGetNextOutputs(func_graph, cnode); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h deleted file mode 100644 index eb3b78d33f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class InsertMemcpyAsyncForGetNext : public PatternProcessPass { - public: - explicit InsertMemcpyAsyncForGetNext(bool multigraph = true) - : PatternProcessPass("insert_memcpy_async_for_getnext", multigraph) {} - ~InsertMemcpyAsyncForGetNext() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc deleted file mode 100644 index 63ea59d744..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc +++ /dev/null @@ -1,144 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" -#include -#include -#include -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -namespace { -// insert memcpy for some cnode even if not a Ref cnode -const std::set kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName, - kLambUpdateWithLROpName}; - -bool IsParameterOrValueNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - return kernel_with_index.first->isa() || kernel_with_index.first->isa(); -} - -void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(hccl_node); - MS_EXCEPTION_IF_NULL(memcpy_async); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &node_users = manager->node_users(); - auto iter = node_users.find(hccl_node); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - // find hccl_node's output which is a control depend - for (const auto &node_index : iter->second) { - AnfNodePtr output = node_index.first; - int output_index = node_index.second; - if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { - CNodePtr control_depend = output->cast(); - MS_EXCEPTION_IF_NULL(control_depend); - std::vector new_inputs; - for (size_t i = 0; i < control_depend->size(); ++i) { - if (i == IntToSize(output_index)) { - new_inputs.push_back(memcpy_async); - } else { - new_inputs.push_back(control_depend->input(i)); - } - } - control_depend->set_inputs(new_inputs); - } - } -} -} // namespace - -bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(input); - // when input is a parameter or is a value node - if (IsParameterOrValueNode(input)) { - return true; - } - - // when input is a Ref or some special cnodes - if (kernel_query_->IsTbeRef(input) || - kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { - return true; - } - - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &node_users = manager->node_users(); - auto iter = node_users.find(input); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - // when input is used by others - if (iter->second.size() > 1) { - return true; - } - return false; -} - -void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(hccl_node); - bool has_insert_memcpy = false; - AnfNodePtr memcpy_async = nullptr; - std::vector new_inputs = {hccl_node->input(0)}; - for (size_t i = 1; i < hccl_node->size(); ++i) { - auto input = hccl_node->input(i); - if (NeedInsertMemcpy(graph, input)) { - memcpy_async = CreateMemcpyAsyncOp(graph, input); - has_insert_memcpy = true; - new_inputs.push_back(memcpy_async); - } else { - new_inputs.push_back(input); - } - } - - if (has_insert_memcpy) { - CNodePtr new_hccl_node = std::make_shared(*hccl_node); - new_hccl_node->set_inputs(new_inputs); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; - (void)manager->Replace(hccl_node, new_hccl_node); - MS_LOG(DEBUG) << "end replace"; - - // transer hccl op's control to the memcpy_async - if (hccl_node->size() == 2) { - TransferControl(new_hccl_node, memcpy_async, graph); - } - } -} - -const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (func_graph == nullptr || node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - if (!AnfAlgo::IsCommunicationOp(node)) { - return nullptr; - } - InsertMemcpyAsync(func_graph, cnode); - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h deleted file mode 100644 index e2f3b781ed..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ - -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { - public: - explicit InsertMemcpyAsyncForHcclOp(bool multigraph = true) - : PatternProcessPass("insert_memcpy_async_for_hccl_op", multigraph), - kernel_query_(std::make_shared()) {} - ~InsertMemcpyAsyncForHcclOp() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; - bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; - KernelQueryPtr kernel_query_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc deleted file mode 100644 index b73fe6c83c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" -#include -#include -#include -#include "pre_activate/ascend/ascend_helper.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "device/kernel_info.h" -#include "kernel//oplib/oplib.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -const BaseRef InsertPadForNMSWithMask::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimNMSWithMask, Xs}); -} - -AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type, - const std::vector &origin_shape) { - MS_EXCEPTION_IF_NULL(func_graph); - std::vector new_pad_inputs; - auto prim = std::make_shared(prim::kPrimPad->name()); - new_pad_inputs.push_back(NewValueNode(prim)); - new_pad_inputs.push_back(input); - CNodePtr pad = func_graph->NewCNode(new_pad_inputs); - MS_EXCEPTION_IF_NULL(pad); - AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get()); - return pad; -} - -const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - - size_t input_num = AnfAlgo::GetInputTensorNum(node); - if (input_num == 0) { - return nullptr; - } - std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) { - auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx); - auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx); - auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx); - if (!(origin_shape.size() == 2 && origin_shape[1] == 5)) { - return nullptr; - } - origin_shape[1] = 8; - auto pad = InsertPadToGraph(func_graph, cur_input, origin_type, origin_shape); - MS_EXCEPTION_IF_NULL(pad); - pad->set_scope(cnode->scope()); - AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector>{{0, 0}, {0, 3}}), pad); - new_inputs.push_back(pad); - } - auto kernel_graph = func_graph->cast>(); - CNodePtr new_node = nullptr; - if (kernel_graph == nullptr) { - new_node = std::make_shared(*cnode); - } else { - new_node = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_inputs(new_inputs); - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h deleted file mode 100644 index bfc201ed11..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass.h" - -namespace mindspore { -namespace opt { -class InsertPadForNMSWithMask : public PatternProcessPass { - public: - explicit InsertPadForNMSWithMask(bool multigraph = true) - : PatternProcessPass("insert_pad_for_nms_with_mask", multigraph) {} - ~InsertPadForNMSWithMask() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.cc deleted file mode 100644 index 7c8fb70fda..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.cc +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/format_type/check_consistency.h" - -#include -#include -#include - -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -namespace { -bool CheckFormatForConsistency(const CNodePtr &node, const size_t input_index) { - MS_EXCEPTION_IF_NULL(node); - // get prior node's device output format - string pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(node, input_index); - string selected_input_format = AnfAlgo::GetInputFormat(node, input_index); - if (pre_output_format == selected_input_format) { - return true; - } - auto input_origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, input_index); - if (pre_output_format == kOpFormat_DEFAULT || selected_input_format == kOpFormat_DEFAULT) { - string checking_format = (pre_output_format == kOpFormat_DEFAULT) ? selected_input_format : pre_output_format; - // when input shape size is 1D, default format and NC1HWC0 are compatible - if (input_origin_shape.size() == 1 && checking_format == kOpFormat_NC1HWC0) { - return true; - } - if (kDefaultCompatibleFormat.find(checking_format) != kDefaultCompatibleFormat.end()) { - return true; - } - } - if (input_origin_shape.size() == 0) { - return true; - } - MS_LOG(ERROR) << "Found inconsistent format! input format " << input_index << ": " << pre_output_format - << ", selected input format: " << selected_input_format; - return false; -} - -bool CheckDataTypeForConsistency(const CNodePtr &node, const size_t input_index) { - MS_EXCEPTION_IF_NULL(node); - TypeId input_data_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(node, input_index); - TypeId selected_data_type = AnfAlgo::GetInputDeviceDataType(node, input_index); - if (input_data_type == selected_data_type) { - return true; - } - MS_LOG(ERROR) << "Found inconsistent dtype! input dtype " << input_index << ": " << TypeIdLabel(input_data_type) - << ", selected dtype: " << TypeIdLabel(selected_data_type); - return false; -} -} // namespace - -const BaseRef CheckConsistency::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - - std::vector todos = {node}; - if (AnfAlgo::IsGraphKernel(node)) { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - kernel::GetValidKernelNodes(sub_graph, &todos); - } - - for (auto &t : todos) { - CNodePtr cnode = t->cast(); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { - if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { - MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "[" - << cnode->DebugString() << "]"; - } - } - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.h b/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.h deleted file mode 100644 index e134547dc8..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class CheckConsistency : public PatternProcessPass { - public: - explicit CheckConsistency(bool multigraph = true) : PatternProcessPass("check_consistency", multigraph) {} - ~CheckConsistency() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc deleted file mode 100644 index c0f99ed415..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel_build_info.h" -#include "kernel/kernel_query.h" -namespace mindspore { -namespace opt { -const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &, - const mindspore::AnfNodePtr &node, - const mindspore::EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto node_name = AnfAlgo::GetCNodeName(node); - if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) { - return nullptr; - } - auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); - if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) { - return nullptr; - } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) { - auto builder = std::make_shared(kernel_builder_info); - builder->SetKernelType(AICPU_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); - AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), node); - } else { - MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" - << node->DebugString() << "]"; - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h deleted file mode 100644 index 80cc8170ac..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" -#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H -#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H -namespace mindspore { -namespace opt { -class ConvertUnSupportNodeToAICPU : public PatternProcessPass { - public: - explicit ConvertUnSupportNodeToAICPU(bool multigraph = true) - : PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph), - supported_checker_(std::make_shared()) {} - ~ConvertUnSupportNodeToAICPU() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - SupportedCheckerPtr supported_checker_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h deleted file mode 100644 index 1b54a7b111..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class DealRefTransAndCast : public PatternProcessPass { - public: - explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} - ~DealRefTransAndCast() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc deleted file mode 100644 index 3d09233d99..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc +++ /dev/null @@ -1,204 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/format_type/insert_cast.h" - -#include -#include -#include -#include - -#include "device/kernel_info.h" -#include "pre_activate/ascend/ascend_helper.h" -#include "pre_activate/common/helper.h" -#include "kernel/kernel_build_info.h" -#include "kernel/oplib/oplib.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "utils/utils.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -namespace { -AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::vector &need_insert_cast) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - std::vector make_tuple_inputs; - AbstractBasePtrList abstract_list; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) { - AnfNodePtr replace_node = nullptr; - const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); - const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); - auto idx = NewValueNode(SizeToInt(output_idx)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(output_idx); - idx->set_abstract(std::make_shared(imm)); - auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); - AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get()); - if (need_insert_cast[output_idx]) { - const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); - TypeId origin_type(kTypeUnknown); - if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); - } - origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; - const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); - if (origin_type != device_type) { - replace_node = - AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type); - MS_EXCEPTION_IF_NULL(replace_node); - replace_node->set_scope(cnode->scope()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); - } else { - replace_node = getitem; - } - } else { - replace_node = getitem; - } - abstract_list.push_back(replace_node->abstract()); - make_tuple_inputs.push_back(replace_node); - } - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - make_tuple->set_abstract(std::make_shared(abstract_list)); - return make_tuple; -} // namespace - -AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::vector &need_insert_cast) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { - return cnode; - } - MS_EXCEPTION_IF_NULL(cnode->Type()); - // Single output - if (!cnode->Type()->isa()) { - if (!need_insert_cast[0]) { - return cnode; - } - - const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); - const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0); - TypeId origin_type(kTypeUnknown); - if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); - } - origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; - const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); - AnfNodePtr replace_node = cnode; - if (origin_type != device_type) { - replace_node = - AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type); - MS_EXCEPTION_IF_NULL(replace_node); - replace_node->set_scope(cnode->scope()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); - } - return replace_node; - } - // Multiple output - return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast); -} - -AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - // insert cast for ops in graph kernel. - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - auto mng = sub_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - std::vector todo; - std::vector> graph_rets; - kernel::GetValidKernelNodes(sub_graph, &todo); - kernel::GetGraphRealOutput(sub_graph, &graph_rets); - for (auto &t : todo) { - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t); - // process input - CNodePtr t_cnode = t->cast(); - MS_EXCEPTION_IF_NULL(t_cnode); - auto t_new_node = InsertCastForInput(sub_graph, t_cnode); - AnfNodePtr t_new_node_1 = nullptr; - std::vector need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true); - // process output - auto iter = std::find_if(graph_rets.begin(), graph_rets.end(), - [&t](const std::pair &ret) { return ret.first == t; }); - if (iter != graph_rets.end()) { - auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t); - auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second); - auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin()); - if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) { - need_insert_cast[iter->second] = false; - } else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) { - need_insert_cast[iter->second] = false; - } - t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); - } else { - t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); - } - - if (t_new_node_1 != nullptr && t_new_node_1 != t) { - (void)mng->Replace(t, t_new_node_1); - } - } - - // insert cast for graph kernel. - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - // process input - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto new_node = InsertCastForInput(func_graph, cnode); - // process output - return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); -} -} // namespace - -const BaseRef InsertCast::DefinePattern() const { - VarPtr V = std::make_shared(UnVisited); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { - return nullptr; - } - - if (AnfAlgo::IsGraphKernel(node)) { - return ProcessGraphKernelOp(func_graph, node); - } else { - // insert cast for single op. - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - // process input - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto new_node = InsertCastForInput(func_graph, cnode); - // process output - return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); - } - // insert cast for single op. - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - // process input - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto new_node = InsertCastForInput(func_graph, cnode); - // process output - return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.h deleted file mode 100644 index a7f93ec8f3..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ -#include - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pattern_engine.h" -#include "ir/anf.h" - -namespace mindspore { -namespace opt { -class InsertCast : public PatternProcessPass { - public: - explicit InsertCast(bool multigraph = true) : PatternProcessPass("insert_cast", multigraph) {} - ~InsertCast() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc deleted file mode 100644 index 3f77c68f86..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/insert_trans_op.h" -#include -#include -#include "utils/utils.h" -#include "pre_activate/ascend/ascend_helper.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "kernel/oplib/oplib.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -const BaseRef InsertTransOp::DefinePattern() const { - std::shared_ptr V = std::make_shared(UnVisited); - std::shared_ptr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -bool IsGraphOutput(const AnfNodePtr &node, const std::vector &outputs) { - auto iter = std::find(outputs.begin(), outputs.end(), node); - if (iter != outputs.end()) { - return true; - } - - return false; -} - -const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - AnfNodePtr front_node; - auto kernel_graph = func_graph->cast>(); - if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { - front_node = kernel_graph->GetFrontNodeByInternalOutput(node); - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - MS_LOG(DEBUG) << "====process op: " << node->DebugString(); - AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { - if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { - return new_node; - } - } - auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_); - if (kernel_graph != nullptr && front_node != nullptr) { - auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node); - kernel_graph->ReplaceInternalOutput(old_node, final_node); - } - return final_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.h deleted file mode 100644 index eb6cfa9542..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class InsertTransOp : public PatternProcessPass { - public: - explicit InsertTransOp(bool multigraph = true) - : PatternProcessPass("insert_trans_op", multigraph), kernel_select_(std::make_shared()) {} - ~InsertTransOp() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - KernelSelectPtr kernel_select_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.cc deleted file mode 100644 index 3df513a19f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" -#include -#include "utils/utils.h" -#include "pre_activate/ascend/ascend_helper.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "kernel/oplib/oplib.h" - -namespace mindspore { -namespace opt { -const BaseRef RunOpInsertTransData::DefinePattern() const { - std::shared_ptr V = std::make_shared(UnVisited); - MS_EXCEPTION_IF_NULL(V); - std::shared_ptr Xs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - return VectorRef({V, Xs}); -} - -const AnfNodePtr RunOpInsertTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - MS_LOG(DEBUG) << "====process op: " << node->DebugString(); - return InsertTransOpForInput(func_graph, node, kernel_select_); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h deleted file mode 100644 index f699cdd580..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class RunOpInsertTransData : public PatternProcessPass { - public: - explicit RunOpInsertTransData(bool multigraph = true) - : PatternProcessPass("insert_transdata_for_runop", multigraph), - kernel_select_(std::make_shared()) {} - ~RunOpInsertTransData() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - KernelSelectPtr kernel_select_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc deleted file mode 100644 index b1817cec3d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc +++ /dev/null @@ -1,282 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/merge_cast_to_op.h" - -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -const size_t kCastInputNum = 2; -const size_t kTupleGetitemInputNum = 3; -bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx, - const std::shared_ptr &candidate_kernel_info) { - if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == nullptr) { - return false; - } - - // checkout inputs' fmt and dtype except index equal change_idx - for (size_t i = 0; i < candidate_kernel_info->GetInputNum(); i++) { - if (i == change_idx) { - if (candidate_kernel_info->GetInputDeviceType(i) != dst_type || - candidate_kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(node, i)) { - return false; - } - } else if (candidate_kernel_info->GetInputDeviceType(i) != AnfAlgo::GetInputDeviceDataType(node, i) || - candidate_kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(node, i)) { - return false; - } - } - - // check outputs's fmt and dtype - for (size_t i = 0; i < candidate_kernel_info->GetOutputNum(); i++) { - if (candidate_kernel_info->GetOutputDeviceType(i) != AnfAlgo::GetOutputDeviceDataType(node, i) || - candidate_kernel_info->GetOutputFormat(i) != AnfAlgo::GetOutputFormat(node, i)) { - return false; - } - } - return true; -} - -bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node, - size_t *cast_index) { - auto output_node_list = GetRealNodeUsedList(graph, node); - MS_EXCEPTION_IF_NULL(output_node_list); - if (output_node_list->size() != 1) { - return false; - } - auto node_pair = output_node_list->at(0); - *next_node = node_pair.first; - *cast_index = node_pair.second - 1; - return true; -} - -bool CheckInputs(const CNodePtr &node, const std::shared_ptr &kernel_info) { - MS_EXCEPTION_IF_NULL(kernel_info); - if (AnfAlgo::GetInputTensorNum(node) != kernel_info->GetInputNum()) { - return false; - } - - for (size_t index = 0; index < kernel_info->GetInputNum(); ++index) { - if (AnfAlgo::GetInputFormat(node, index) != kernel_info->GetInputFormat(index) || - AnfAlgo::GetInputDeviceDataType(node, index) != kernel_info->GetInputDeviceType(index)) { - return false; - } - } - return true; -} - -bool CheckOtherOutputs(const CNodePtr &node, const std::shared_ptr &kernel_info, - const size_t idx) { - MS_EXCEPTION_IF_NULL(kernel_info); - if (AnfAlgo::GetOutputTensorNum(node) != kernel_info->GetOutputNum()) { - return false; - } - for (size_t index = 0; index < kernel_info->GetOutputNum(); ++index) { - if (idx == index) { - continue; - } - if (AnfAlgo::GetOutputFormat(node, index) != kernel_info->GetOutputFormat(index) || - AnfAlgo::GetOutputDeviceDataType(node, index) != kernel_info->GetOutputDeviceType(index)) { - return false; - } - } - return true; -} - -bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr &kernel_info, size_t index) { - if (kernel_info == nullptr) { - return false; - } - - if (AnfAlgo::GetOutputDeviceDataType(node, 0) != kernel_info->GetOutputDeviceType(index)) { - return false; - } - if (AnfAlgo::GetOutputInferShape(node, 0).size() == 4 && AnfAlgo::GetOutputFormat(node, 0) == kOpFormat_NCHW && - kernel_info->GetOutputFormat(index) == kOpFormat_DEFAULT) { - return true; - } - return AnfAlgo::GetOutputFormat(node, 0) == kernel_info->GetOutputFormat(index); -} - -void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) { - using Shape = std::vector; - auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0); - auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0); - std::vector shapes; - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { - if (cast_index == index) { - shapes.emplace_back(cast_shape); - types.emplace_back(cast_dtype); - continue; - } - shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index)); - types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index)); - } - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get()); -} - -AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(kernel_query); - AnfNodePtr next_node = nullptr; - size_t cast_index = 0; - if (!GetNextNodeAndCastIndex(graph, node, &next_node, &cast_index)) { - return nullptr; - } - MS_EXCEPTION_IF_NULL(next_node); - if (!next_node->isa() || !AnfAlgo::IsRealKernel(next_node)) { - return nullptr; - } - auto next_cnode = next_node->cast(); - if (AnfAlgo::IsGraphKernel(next_node)) { - return nullptr; - } - auto next_op_name = AnfAlgo::GetCNodeName(next_node); - std::vector> kernel_info_list; - kernel_query->Query(next_cnode, &kernel_info_list); - - auto dst_type_id = AnfAlgo::GetInputDeviceDataType(node, 0); - auto alternative_kernel_info = std::find_if( - kernel_info_list.begin(), kernel_info_list.end(), - [&next_cnode, &dst_type_id, &cast_index](const std::shared_ptr &candidate_kernel_info) { - return AlternativeKernelInfoForInput(next_cnode, dst_type_id, cast_index, candidate_kernel_info); - }); - if (alternative_kernel_info == kernel_info_list.end()) { - return nullptr; - } - auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node); - MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString() - << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" - << (*alternative_kernel_info)->ToString(); - AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); - ChangeNodeInferInfo(next_cnode, node, cast_index); - if (node->inputs().size() < kCastInputNum) { - MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:"; - } - return node->input(1); -} - -bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_output, size_t *output_idx) { - MS_EXCEPTION_IF_NULL(x_node); - if (x_node->isa()) { - auto x_cnode = x_node->cast(); - *prior_op = x_cnode; - // when x_node is tuple_getitem - if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) { - if (x_cnode->inputs().size() < kTupleGetitemInputNum) { - MS_LOG(EXCEPTION) << "tuple getitem node has wrong input num" << x_cnode->inputs().size(); - } - MS_EXCEPTION_IF_NULL(output_idx); - AnfNodePtr input1 = x_cnode->input(1); - MS_EXCEPTION_IF_NULL(input1); - if (!input1->isa()) { - return false; - } - *prior_op = input1->cast(); - MS_EXCEPTION_IF_NULL(*prior_op); - AnfNodePtr input2 = x_cnode->input(2); - MS_EXCEPTION_IF_NULL(input2); - auto value_ptr = input2->cast(); - MS_EXCEPTION_IF_NULL(value_ptr); - *output_idx = IntToSize(GetValue(value_ptr->value())); - *single_output = false; - } - return AnfAlgo::IsRealKernel(*prior_op); - } - return false; -} - -AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) { - MS_EXCEPTION_IF_NULL(cur_node); - MS_EXCEPTION_IF_NULL(kernel_query); - if (cur_node->inputs().size() < kCastInputNum) { - MS_LOG(EXCEPTION) << "op[Cast] has wrong input num:"; - } - AnfNodePtr x_node = cur_node->input(1); - if (IsUsedByOthers(graph, x_node)) { - return nullptr; - } - - CNodePtr prior_op = nullptr; - bool single_output = true; - size_t output_idx = 0; - if (!GetPriorOp(x_node, &prior_op, &single_output, &output_idx)) { - return nullptr; - } - MS_EXCEPTION_IF_NULL(prior_op); - if (AnfAlgo::IsGraphKernel(prior_op)) { - return nullptr; - } - - std::vector> kernel_info_list; - kernel_query->Query(prior_op, &kernel_info_list); - auto kernel_info_it = std::find_if( - kernel_info_list.begin(), kernel_info_list.end(), - [&prior_op, &cur_node, &output_idx](const std::shared_ptr &item_kernel_info) { - return CheckInputs(prior_op, item_kernel_info) && CheckOtherOutputs(prior_op, item_kernel_info, output_idx) && - CheckIndexOutput(cur_node, item_kernel_info, output_idx); - }); - if (kernel_info_it == kernel_info_list.end()) { - return nullptr; - } - auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op); - MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString() - << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" - << (*kernel_info_it)->ToString(); - AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get()); - ChangeNodeInferInfo(prior_op, cur_node, output_idx); - if (!single_output) { - MS_EXCEPTION_IF_NULL(x_node); - ChangeNodeInferInfo(x_node->cast(), cur_node, 0); - } - auto prior_name = AnfAlgo::GetCNodeName(prior_op); - if (prior_name == kFive2FourOpName) { - AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op); - } else if (prior_name == kFour2FiveOpName) { - AnfAlgo::CopyNodeAttr("dst_type", cur_node, prior_op); - } - return single_output ? prior_op : x_node; -} -} // namespace - -const BaseRef MergeCastToOp::DefinePattern() const { - VarPtr X = std::make_shared(); - return VectorRef({prim::kPrimCast, X}); -} - -const AnfNodePtr MergeCastToOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - auto new_node = MergeCastToNextOp(graph, cnode, kernel_query_); - if (new_node == nullptr) { - new_node = MergeCastToPriorOp(graph, cnode, kernel_query_); - } - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.h b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.h deleted file mode 100644 index 7e05c8a02a..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H - -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class MergeCastToOp : public PatternProcessPass { - public: - explicit MergeCastToOp(bool multigraph = true) - : PatternProcessPass("merge_cast_to_op", multigraph), kernel_query_(std::make_shared()) {} - ~MergeCastToOp() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - KernelQueryPtr kernel_query_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.cc deleted file mode 100644 index 42061957b9..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/modify_ops_attrs.h" -#include -#include -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "kernel/common_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); - auto input_format = AnfAlgo::GetInputFormat(cnode, 0); - if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) { - return nullptr; - } - if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { - return nullptr; - } - - AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode); - return cnode; -} - -AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) { - auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0); - if (input_shape.size() != 5) { - return nullptr; - } - if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) { - return nullptr; - } - - auto multiples = AnfAlgo::GetNodeAttr>(cnode, kAttrMultiples); - if (multiples.size() == 4 && multiples[1] == 1) { - multiples.push_back(1); - AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode); - } - - return cnode; -} - -AnfNodePtr ModifyAttrs(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - auto op_name = AnfAlgo::GetCNodeName(cnode); - if (op_name == prim::kPrimTile->name()) { - return ModifyTileOpAttrs(cnode); - } else if (op_name == prim::kPrimReduceSum->name()) { - // kPrimReduceMean - // kPrimReduceSum - // kPrimReduceAll - // kPrimReduceMax - // kPrimReduceMin - return ModifyReduceOpsAttrs(cnode); - } - return nullptr; -} -} // namespace - -const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsGraphKernel(node)) { - return nullptr; - } - MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node); - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - auto manager = fg->manager(); - MS_EXCEPTION_IF_NULL(manager); - std::vector todos; - kernel::GetValidKernelNodes(fg, &todos); - for (auto &t : todos) { - auto new_node = ModifyAttrs(t->cast()); - if (new_node != nullptr && new_node != t) { - (void)manager->Replace(t, new_node); - } - } - return node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.h b/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.h deleted file mode 100644 index 25ec94b6b4..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ModifyOpAttrs : public PatternProcessPass { - public: - explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {} - ~ModifyOpAttrs() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc deleted file mode 100644 index d81a8c90ce..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" - -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel_build_info.h" -#include "utils/utils.h" -#include "kernel/common_utils.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode) { - return RectifyKernelInfoInPynativeProcess(node); - } - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { - return nullptr; - } - std::vector do_mask_node_list; - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto node_map = manager->node_users(); - auto iter = node_map.find(node); - if (iter == node_map.end()) { - MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!"; - } - auto gen_mask_output_nodes = iter->second; - for (const auto &output_node : gen_mask_output_nodes) { - if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) { - auto output_cnode = output_node.first->cast(); - do_mask_node_list.push_back(output_cnode); - } - } - std::vector input_shape; - for (const auto &output_node : do_mask_node_list) { - if (input_shape.empty()) { - input_shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); - continue; - } - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); - if (!kernel::IsSameShape(shape, input_shape)) { - MS_LOG(EXCEPTION) << "The DropOutGenMask connected with same genmask's shape must be equal!" - << " GenMask " << node->DebugString(); - } - } - RectifyKernelInfo(do_mask_node_list); - return nullptr; -} - -void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_mask_node_list) const { - std::map format_counter; - std::string special_format; - std::string convert_format; - for (const auto &do_mask : do_mask_node_list) { - auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0); - if (special_format.empty() && kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end()) { - special_format = do_mask_data_format; - } - if (format_counter.find(do_mask_data_format) == format_counter.end()) { - format_counter[do_mask_data_format] = 1; - } else { - format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; - } - // if has two or more special format we need change all domask's format to default that can avoid insert more - // transdata - if (format_counter.size() > 2) { - convert_format = kOpFormat_DEFAULT; - break; - } - if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() && - special_format != do_mask_data_format) { - convert_format = kOpFormat_DEFAULT; - break; - } - } - if (format_counter.size() == 1) { - return; - } - if (convert_format.empty()) { - convert_format = GetConvertFormat(format_counter); - } - RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format); -} - -std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map &format_counter) const { - std::string convert_format; - const size_t counter = 0; - for (const auto &iter : format_counter) { - if (counter < iter.second) { - convert_format = iter.first; - } - if (counter == iter.second && kHWSpecialFormatSet.find(convert_format) == kHWSpecialFormatSet.end()) { - convert_format = iter.first; - } - } - return convert_format; -} - -void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, - const std::string &format) const { - for (const auto &do_mask : do_mask_node_list) { - auto builder = - std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); - builder->SetInputFormat(format, 0); - builder->SetOutputFormat(format, 0); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); - } -} - -AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode == nullptr) { - return nullptr; - } - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { - return nullptr; - } - auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); - if (do_mask_input_format != kOpFormat_DEFAULT) { - auto builder = - std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); - builder->SetInputFormat(kOpFormat_DEFAULT, 0); - builder->SetOutputFormat(kOpFormat_DEFAULT, 0); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h deleted file mode 100644 index 81bad4d8f8..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H -#include -#include -#include - -#include "pre_activate/common/optimizer.h" -namespace mindspore { -namespace opt { -class RectifyDoMaskKernelInfo : public PatternProcessPass { - public: - explicit RectifyDoMaskKernelInfo(bool multigraph = true) - : PatternProcessPass("batch_norm_bert_fission", multigraph) {} - ~RectifyDoMaskKernelInfo() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - void RectifyKernelInfo(const std::vector &do_mask_node_list) const; - AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const; - std::string GetConvertFormat(const std::map &format_counter) const; - void RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, const std::string &format) const; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.cc deleted file mode 100644 index dde40a5090..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" -#include -#include -#include "pre_activate/common/helper.h" -#include "kernel/common_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - auto op_name = AnfAlgo::GetCNodeName(cnode); - if (op_name != prim::kPrimReshape->name()) { - return nullptr; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); - auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0); - if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) { - return nullptr; - } - - return cnode->input(1); -} -} // namespace - -const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsGraphKernel(node)) { - return nullptr; - } - MS_LOG(DEBUG) << "====process op: " << AnfAlgo::GetCNodeName(node); - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - auto manager = fg->manager(); - MS_EXCEPTION_IF_NULL(manager); - std::vector todos; - kernel::GetValidKernelNodes(fg, &todos); - for (auto &t : todos) { - auto new_node = RemoveReshapeOp(t->cast()); - if (new_node != nullptr && new_node != t) { - (void)manager->Replace(t, new_node); - } - } - return node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.h b/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.h deleted file mode 100644 index 4942c2fc08..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class RemoveNoUseReshapeOp : public PatternProcessPass { - public: - explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {} - ~RemoveNoUseReshapeOp() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc deleted file mode 100644 index b9a86f7bcb..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/addn_fission.h" -#include -#include -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode, size_t begin_index, - size_t offset) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(origin_addn_cnode); - std::vector new_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; - for (size_t i = begin_index; i < begin_index + offset; ++i) { - new_addn_inputs.push_back(origin_addn_cnode->input(i)); - } - CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs); - MS_EXCEPTION_IF_NULL(new_addn); - new_addn->set_scope(origin_addn_cnode->scope()); - new_addn->set_abstract(origin_addn_cnode->abstract()); - AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); - std::vector dyn_input_sizes{SizeToInt(offset)}; - AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn); - return new_addn; -} -} // namespace - -const BaseRef AddnFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimAddN, Xs}); -} - -const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // The real input begins with index 1. - size_t origin_input_size = cnode->inputs().size() - 1; - if (origin_input_size <= inputs_divisor_) { - return nullptr; - } - CNodePtr new_cnode = cnode; - while (origin_input_size > inputs_divisor_) { - MS_EXCEPTION_IF_NULL(new_cnode); - std::vector base_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; - size_t cur_input_index = 1; - // Divide the inputs of addn by inputs_divisor_. - while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { - base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); - cur_input_index += inputs_divisor_; - } - for (size_t i = cur_input_index; i <= origin_input_size; i++) { - base_addn_inputs.push_back(new_cnode->input(i)); - } - CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); - MS_EXCEPTION_IF_NULL(base_addn); - base_addn->set_scope(new_cnode->scope()); - base_addn->set_abstract(new_cnode->abstract()); - AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); - std::vector dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)}; - AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn); - new_cnode = base_addn; - origin_input_size = base_addn->inputs().size() - 1; - } - - return new_cnode; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.h deleted file mode 100644 index 3c62391f9a..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -constexpr size_t kAddnInputsDivisor = 63; -class AddnFission : public PatternProcessPass { - public: - explicit AddnFission(bool multigraph = true) - : PatternProcessPass("addn_fission", multigraph), inputs_divisor_(kAddnInputsDivisor) {} - ~AddnFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - size_t inputs_divisor_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc deleted file mode 100644 index e6a8864e46..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc +++ /dev/null @@ -1,172 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -const std::vector kOutputIndex{0, 3, 4, 5}; -constexpr size_t kBatchNormRealOutputNum = 3; -constexpr size_t kBatchNormRealInputNum = 3; - -bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector *bn_outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(bn) == manager->node_users().end()) { - return false; - } - size_t output_num = 0; - for (const auto &node_index : manager->node_users()[bn]) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { - continue; - } - auto tuple_getiterm_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); - auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) { - return false; - } - bn_outputs->push_back(output); - output_num++; - } - return output_num == kBatchNormRealOutputNum; -} - -AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn); - auto bn_cnode = bn->cast(); - MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { - MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " - << kBatchNormRealInputNum + 1; - } - std::vector bn_training_reduce_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceOpName)), bn_cnode->input(1)}; - auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); - MS_EXCEPTION_IF_NULL(bn_training_reduce); - auto bn_input1 = bn_cnode->input(2); - MS_EXCEPTION_IF_NULL(bn_input1); - auto bn_input2 = bn_cnode->input(3); - MS_EXCEPTION_IF_NULL(bn_input2); - AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input2->abstract()}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_reduce->set_abstract(abstract_tuple); - bn_training_reduce->set_scope(bn->scope()); - AnfAlgo::CopyNodeAttrs(bn, bn_training_reduce); - return bn_training_reduce; -} - -AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, - const std::vector &bn_training_reduce_outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn); - auto bn_cnode = bn->cast(); - MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { - MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " - << kBatchNormRealInputNum + 1; - } - if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum - << ", but it is " << bn_training_reduce_outputs.size(); - } - std::vector bn_training_update_v2_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateV2OpName)), - bn_cnode->input(1), - bn_training_reduce_outputs[0], - bn_training_reduce_outputs[1], - bn_cnode->input(2), - bn_cnode->input(3)}; - auto bn_training_update_v2 = func_graph->NewCNode(bn_training_update_v2_inputs); - MS_EXCEPTION_IF_NULL(bn_training_update_v2); - - auto bn_abstract_tuple = dyn_cast(bn->abstract()); - MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { - MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " - << bn_abstract_tuple->elements().size(); - } - std::vector abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3], - bn_abstract_tuple->elements()[4]}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_update_v2->set_abstract(abstract_tuple); - bn_training_update_v2->set_scope(bn->scope()); - AnfAlgo::CopyNodeAttrs(bn, bn_training_update_v2); - return bn_training_update_v2; -} -} // namespace - -const BaseRef BatchNormBertFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimBatchNorm, Xs}); -} - -const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - std::vector bn_outputs; - if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) { - MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed"; - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != kBatchNormRealInputNum + 1) { - MS_LOG(INFO) << "The input size of BatchNorm should be " << kBatchNormRealInputNum - << ". The node should not be changed"; - return nullptr; - } - AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); - std::vector bn_training_reduce_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, - &bn_training_reduce_outputs); - - AnfNodePtr bn_training_update_v2 = CreateBNTrainingUpdateV2(func_graph, node, bn_training_reduce_outputs); - std::vector bn_training_update_v2_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v2, kBNTrainingUpdateV2OutputNum, - &bn_training_update_v2_outputs); - if (bn_training_update_v2_outputs.size() != kBNTrainingUpdateV2OutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingUpdateV2OutputNum - << ", but it is " << bn_training_update_v2_outputs.size(); - } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem); - size_t output_index = 0; - for (const auto &output : bn_outputs) { - (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); - output_index++; - } - // Return the new node for control depends. - return bn_training_update_v2; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.h deleted file mode 100644 index fc214817fc..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormBertFission : public PatternProcessPass { - public: - explicit BatchNormBertFission(bool multigraph = true) : PatternProcessPass("batch_norm_bert_fission", multigraph) {} - ~BatchNormBertFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc deleted file mode 100644 index 5e41111660..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc +++ /dev/null @@ -1,172 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" -#include -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kBatchNormGradInferOutputNum = 3; -bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(node) == manager->node_users().end()) { - MS_LOG(DEBUG) << "The node " << node->DebugString() << " should have some outputs"; - return false; - } - for (const auto &node_index : manager->node_users()[node]) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { - continue; - } - auto tuple_getiterm_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); - auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (index == kBatchNormGradInferOutputNum || index == kBatchNormGradInferOutputNum + 1) { - MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change"; - return false; - } - } - return true; -} -} // namespace - -AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn_grad); - MS_EXCEPTION_IF_NULL(equiv); - // Set inputs - auto iter_input0 = (*equiv).find(input0_var_); - if (iter_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; - } - auto iter_input2 = (*equiv).find(input2_var_); - if (iter_input2 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input2 var after matched."; - } - auto iter_input4 = (*equiv).find(input4_var_); - if (iter_input4 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; - } - std::vector bn_infer_grad_inputs = { - NewValueNode(std::make_shared(kBNInferGradOpName)), utils::cast(iter_input0->second), - utils::cast(iter_input2->second), utils::cast(iter_input4->second)}; - auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_infer_grad); - // Set abstract, the output of new node is taking the place of the 0th output of bn_grad. - auto bn_grad_abstract_tuple = dyn_cast(bn_grad->abstract()); - MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); - if (bn_grad_abstract_tuple->elements().empty()) { - MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be empty"; - } - bn_infer_grad->set_abstract(bn_grad_abstract_tuple->elements()[0]); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_infer_grad); - bn_infer_grad->set_scope(bn_grad->scope()); - return bn_infer_grad; -} - -AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, - const AnfNodePtr &bn_grad, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn_grad); - MS_EXCEPTION_IF_NULL(equiv); - // Set inputs - auto iter_input0 = (*equiv).find(input0_var_); - if (iter_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; - } - auto iter_input1 = (*equiv).find(input1_var_); - if (iter_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input1 var after matched."; - } - auto iter_input3 = (*equiv).find(input3_var_); - if (iter_input3 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input3 var after matched."; - } - auto iter_input4 = (*equiv).find(input4_var_); - if (iter_input4 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; - } - std::vector bn_training_update_grad_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), - utils::cast(iter_input0->second), utils::cast(iter_input1->second), - utils::cast(iter_input3->second), utils::cast(iter_input4->second)}; - auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_training_update_grad); - // Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad. - auto bn_grad_abstract_tuple = dyn_cast(bn_grad->abstract()); - MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); - if (bn_grad_abstract_tuple->elements().size() < kBatchNormGradInferOutputNum) { - MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3"; - } - std::vector abstract_list{bn_grad_abstract_tuple->elements()[1], - bn_grad_abstract_tuple->elements()[2]}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_update_grad->set_abstract(abstract_tuple); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad); - bn_training_update_grad->set_scope(bn_grad->scope()); - return bn_training_update_grad; -} - -const BaseRef BatchNormGradInferFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimBatchNormGrad, input0_var_, input1_var_, input2_var_, input3_var_, input4_var_, Xs}); -} - -const AnfNodePtr BatchNormGradInferFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, node->cast())) { - MS_LOG(DEBUG) << "The BatchNormGrad " << node->DebugString() << " has no is_training attr, should not be changed"; - return nullptr; - } - if (AnfAlgo::GetNodeAttr(node, kAttrIsTraining)) { - MS_LOG(DEBUG) << "The is_training attr value of " << node->DebugString() << " is true, no need change"; - return nullptr; - } - if (!CheckOutputsIndex(func_graph, node)) { - MS_LOG(DEBUG) << "The output 3 or 4 of BatchNormGrad is not null, no need change"; - return nullptr; - } - AnfNodePtr bn_infer_grad = CreateBNInferGrad(func_graph, node, equiv); - AnfNodePtr bn_training_update_grad = CreateBNTrainingUpdateGrad(func_graph, node, equiv); - std::vector bn_training_update_grad_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_grad, kBNTrainingUpdateGradOutputNum, - &bn_training_update_grad_outputs); - if (bn_training_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { - MS_LOG(EXCEPTION) << "The output size of " << bn_training_update_grad << " should be " - << kBNTrainingUpdateGradOutputNum << ", but it is " << bn_training_update_grad_outputs.size(); - } - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_infer_grad, - bn_training_update_grad_outputs[0], bn_training_update_grad_outputs[1]}; - auto make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - return make_tuple; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h deleted file mode 100644 index a8eefdaa85..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormGradInferFission : public PatternProcessPass { - public: - explicit BatchNormGradInferFission(bool multigraph = true) - : PatternProcessPass("batch_norm_grad_infer_fission", multigraph), - input0_var_(std::make_shared()), - input1_var_(std::make_shared()), - input2_var_(std::make_shared()), - input3_var_(std::make_shared()), - input4_var_(std::make_shared()) {} - ~BatchNormGradInferFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - AnfNodePtr CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, const EquivPtr &equiv) const; - AnfNodePtr CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, - const EquivPtr &equiv) const; - - VarPtr input0_var_; - VarPtr input1_var_; - VarPtr input2_var_; - VarPtr input3_var_; - VarPtr input4_var_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.cc deleted file mode 100644 index 270b02cb00..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" - -#include -#include -#include - -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, - std::vector *bn_update_grad_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_grad_node); - auto bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() < kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; - } - std::vector bn_update_grad_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], - bn_grad_inputs[4], bn_grad_inputs[5]}; - auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_update_grad); - bn_update_grad->set_kernel_info(std::make_shared()); - bn_update_grad->set_scope(bn_grad_node->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 1), AnfAlgo::GetOutputInferShape(bn_grad_node, 2)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_update_grad.get()); - - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad); - CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs); -} - -void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, - const std::vector &bn_update_grad_outputs, - std::vector *bn_reduce_grad_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_grad_node); - auto bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() < kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; - } - if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { - MS_LOG(EXCEPTION) << "BNTrainingReduceGrad_outputs has wrong size"; - } - std::vector bn_reduce_grad_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceGradOpName)), - bn_grad_inputs[1], - bn_grad_inputs[2], - bn_update_grad_outputs[0], - bn_update_grad_outputs[1], - bn_grad_inputs[3], - bn_grad_inputs[4], - bn_grad_inputs[5]}; - auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_reduce_grad); - bn_reduce_grad->set_kernel_info(std::make_shared()); - bn_reduce_grad->set_scope(bn_grad_node->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_reduce_grad.get()); - - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad); - (*bn_reduce_grad_outputs).push_back(bn_reduce_grad); -} -} // namespace -const BaseRef BatchNormGradSplit::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto prim = std::make_shared(kBatchNormGradOpName); - return VectorRef({prim, Xs}); -} - -const AnfNodePtr BatchNormGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(func_graph); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - if (!primitive->HasAttr(kAttrIsTraining)) { - MS_LOG(INFO) << "Op BatchNormGrad must have attrs of is_training"; - return nullptr; - } - if (!AnfAlgo::GetNodeAttr(cnode, kAttrIsTraining)) { - MS_LOG(INFO) << "is_training must be true"; - return nullptr; - } - - std::vector bn_update_grad_outputs; - CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs); - if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { - MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; - } - - std::vector bn_reduce_grad_outputs; - CreateOutputsOfReduceGrad(func_graph, cnode, bn_update_grad_outputs, &bn_reduce_grad_outputs); - if (bn_reduce_grad_outputs.size() != kSingleOutputNum) { - MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size"; - } - - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0], - bn_update_grad_outputs[0], bn_update_grad_outputs[1]}; - auto make_tuple = func_graph->NewCNode(make_tuple_inputs); - return make_tuple; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.h deleted file mode 100644 index e539fdb27c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -class BatchNormGradSplit : public PatternProcessPass { - public: - explicit BatchNormGradSplit(bool multigraph = true) : PatternProcessPass("batch_norm_grad_split", multigraph) {} - ~BatchNormGradSplit() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc deleted file mode 100644 index 6282ed4f76..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc +++ /dev/null @@ -1,123 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/bn_grad_split.h" - -#include -#include -#include - -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, - std::vector *bn_update_grad_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_grad_node); - auto bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() != kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; - } - std::vector bn_update_grad_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], - bn_grad_inputs[4], bn_grad_inputs[5]}; - auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_update_grad); - bn_update_grad->set_kernel_info(std::make_shared()); - bn_update_grad->set_scope(bn_grad_node->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 1), AnfAlgo::GetOutputInferShape(bn_grad_node, 2)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_update_grad.get()); - - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad); - CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs); -} - -void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, - const std::vector &bn_update_grad_outputs, - std::vector *bn_reduce_grad_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_grad_node); - auto bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() != kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; - } - if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { - MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; - } - std::vector bn_reduce_grad_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceGradOpName)), - bn_grad_inputs[1], - bn_grad_inputs[2], - bn_update_grad_outputs[0], - bn_update_grad_outputs[1], - bn_grad_inputs[3], - bn_grad_inputs[4], - bn_grad_inputs[5]}; - auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_reduce_grad); - bn_reduce_grad->set_kernel_info(std::make_shared()); - bn_reduce_grad->set_scope(bn_grad_node->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_reduce_grad.get()); - - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad); - (*bn_reduce_grad_outputs).push_back(bn_reduce_grad); -} - -CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(func_graph); - std::vector bn_update_grad_outputs; - CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs); - if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { - MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; - } - - std::vector bn_reduce_grad_outputs; - CreateOutputsOfReduceGrad(func_graph, cnode, bn_update_grad_outputs, &bn_reduce_grad_outputs); - if (bn_reduce_grad_outputs.size() != 1) { - MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size"; - } - - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0], - bn_update_grad_outputs[0], bn_update_grad_outputs[1]}; - auto make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - return make_tuple; -} -} // namespace - -const BaseRef BnGradSplit::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimFusedBatchNormGrad, Xs}); -} - -const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - return BNGradSplitForTBE(func_graph, cnode); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.h deleted file mode 100644 index 17e1f9b98e..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -class BnGradSplit : public PatternProcessPass { - public: - explicit BnGradSplit(bool multigraph = true) : PatternProcessPass("bn_grad_split", multigraph) {} - ~BnGradSplit() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc deleted file mode 100644 index 66ffa24bf1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/bn_split.h" - -#include -#include -#include - -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/helper.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, - std::vector *bn_training_reduce_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() != kBnInputNum) { - MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString(); - return false; - } - std::vector bn_training_reduce_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceOpName))}; - bn_training_reduce_inputs.push_back(bn_cnode->input(1)); - auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs); - MS_EXCEPTION_IF_NULL(bn_training_reduce); - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - bn_training_reduce->set_kernel_info(kernel_info); - std::vector bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0); - if (bn_shape_i0.size() < kShape2dDims) { - MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims; - return false; - } - std::vector bn_training_reduce_shape = {bn_shape_i0[1]}; - auto types = {kNumberTypeFloat32, kNumberTypeFloat32}; - auto shapes = {bn_training_reduce_shape, bn_training_reduce_shape}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_training_reduce.get()); - bn_training_reduce->set_scope(bn_cnode->scope()); - AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce); - - CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs); - return true; -} - -AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode, - const std::vector &bn_training_reduce_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() != kBnInputNum) { - MS_LOG(EXCEPTION) << "BN node has wrong input size"; - } - if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { - MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; - } - // the inputs of BNTrainingUpdate are from the outputs of BNTrainingReduce and the inputs of BN - std::vector bn_training_update_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateOpName))}; - bn_training_update_inputs.push_back(bn_cnode->input(1)); - bn_training_update_inputs.push_back(bn_training_reduce_outputs[0]); - bn_training_update_inputs.push_back(bn_training_reduce_outputs[1]); - bn_training_update_inputs.push_back(bn_cnode->input(2)); - bn_training_update_inputs.push_back(bn_cnode->input(3)); - bn_training_update_inputs.push_back(bn_cnode->input(4)); - bn_training_update_inputs.push_back(bn_cnode->input(5)); - auto bn_training_update = graph->NewCNode(bn_training_update_inputs); - MS_EXCEPTION_IF_NULL(bn_training_update); - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - bn_training_update->set_kernel_info(kernel_info); - bn_training_update->set_abstract(bn_cnode->abstract()); - bn_training_update->set_scope(bn_cnode->scope()); - auto factor = AnfAlgo::GetNodeAttr(bn_cnode, kAttrMomentum); - AnfAlgo::SetNodeAttr(kAttrFactor, MakeValue(factor), bn_training_update); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update); - AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update); - return bn_training_update; -} - -AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() < kBnInputNum) { - MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs."; - return nullptr; - } - // Create BNTrainingReduce node and get outputs of BNTrainingReduce - std::vector bn_training_reduce_outputs; - if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) { - MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split"; - return nullptr; - } - if (bn_training_reduce_outputs.size() != kBN1OutputNum) { - MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail"; - } - - // Create BNTrainingUpdate node - return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs); -} -} // namespace - -const BaseRef BnSplit::DefinePattern() const { - VarPtr Xs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - return VectorRef({prim::kPrimFusedBatchNorm, Xs}); -} - -const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - return SplitFusedBatchNormForTBE(func_graph, node); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.h deleted file mode 100644 index bc5975af17..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -class BnSplit : public PatternProcessPass { - public: - explicit BnSplit(bool multigraph = true) : PatternProcessPass("bn_split", multigraph) {} - ~BnSplit() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.cc deleted file mode 100644 index 479e00e4c0..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/lars_v2_fission.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -namespace { -void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2, - std::vector *square_sum_all_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(lars_v2); - if (lars_v2->size() != kLarsV2InputNum) { - MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum; - } - - std::vector inputs = {NewValueNode(std::make_shared(kSquareSumAllOpName))}; - inputs.push_back(lars_v2->input(1)); - inputs.push_back(lars_v2->input(2)); - auto square_sum_all = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(square_sum_all); - square_sum_all->set_scope(lars_v2->scope()); - - auto types = {kNumberTypeFloat32, kNumberTypeFloat32}; - std::vector shape; - auto shapes = {shape, shape}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sum_all.get()); - - CreateMultipleOutputsOfAnfNode(graph, square_sum_all, 2, square_sum_all_outputs); -} - -CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2, - const std::vector &square_sum_all_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(lars_v2); - if (square_sum_all_outputs.size() != 2) { - MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2"; - } - if (lars_v2->size() != kLarsV2InputNum) { - MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum; - } - std::vector inputs = {NewValueNode(std::make_shared(kLarsV2UpdateOpName))}; - inputs.push_back(lars_v2->input(1)); - inputs.push_back(lars_v2->input(2)); - inputs.push_back(square_sum_all_outputs[0]); - inputs.push_back(square_sum_all_outputs[1]); - inputs.push_back(lars_v2->input(3)); - inputs.push_back(lars_v2->input(4)); - auto lars_v2_update = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(lars_v2_update); - lars_v2_update->set_scope(lars_v2->scope()); - lars_v2_update->set_abstract(lars_v2->abstract()); - return lars_v2_update; -} -} // namespace - -const BaseRef LarsV2Fission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto lars_v2_prim = std::make_shared(kLarsV2OpName); - return VectorRef({lars_v2_prim, Xs}); -} - -const AnfNodePtr LarsV2Fission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto lars_v2 = node->cast(); - MS_EXCEPTION_IF_NULL(lars_v2); - - std::vector square_sum_all_outputs; - CreateOutputsOfSquareSumAll(graph, lars_v2, &square_sum_all_outputs); - return CreateLarsV2Update(graph, lars_v2, square_sum_all_outputs); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.h deleted file mode 100644 index 846d221c53..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class LarsV2Fission : public PatternProcessPass { - public: - explicit LarsV2Fission(bool multigraph = true) : PatternProcessPass("lars_v2_fission", multigraph) {} - ~LarsV2Fission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.cc deleted file mode 100644 index 1a25d83650..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "ir/primitive.h" -#include "common/utils.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop( - const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, - std::vector *layer_norm_x_backprop_outputs) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(layer_norm_grad); - auto prim = std::make_shared(kLayerNormXBackpropOpName); - std::vector layer_norm_x_backprop_inputs = {NewValueNode(prim)}; - for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) { - layer_norm_x_backprop_inputs.push_back(layer_norm_grad->input(i)); - } - auto layer_norm_x_backprop = graph->NewCNode(layer_norm_x_backprop_inputs); - MS_EXCEPTION_IF_NULL(layer_norm_x_backprop); - layer_norm_x_backprop->set_scope(layer_norm_grad->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_x_backprop.get()); - - (*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop); -} - -void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop( - const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, - std::vector *layer_norm_beta_gamma_backprop_outputs) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(layer_norm_grad); - auto prim = std::make_shared(kLayerNormBetaGammaBackpropOpName); - std::vector layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)}; - for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) { - layer_norm_beta_gamma_backprop_inputs.push_back(layer_norm_grad->input(i)); - } - auto layer_norm_beta_gamma_backprop = graph->NewCNode(layer_norm_beta_gamma_backprop_inputs); - MS_EXCEPTION_IF_NULL(layer_norm_beta_gamma_backprop); - auto kernel_info = std::make_shared(); - layer_norm_beta_gamma_backprop->set_kernel_info(kernel_info); - layer_norm_beta_gamma_backprop->set_scope(layer_norm_grad->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 1), - AnfAlgo::GetOutputInferDataType(layer_norm_grad, 2)}; - auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 1), AnfAlgo::GetOutputInferShape(layer_norm_grad, 2)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_beta_gamma_backprop.get()); - - // get device shape of LayerNormGrad's 5th Input, and convert it to attr - std::vector shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4); - AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Int(shape_gamma)), layer_norm_beta_gamma_backprop); - - CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum, - layer_norm_beta_gamma_backprop_outputs); -} - -const BaseRef LayerNormGradSplit::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VectorRef pattern({prim::kPrimLayerNormGrad, Xs}); - return pattern; -} - -const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode->inputs().size() != kLayerNormGradInputNum) { - return nullptr; - } - - // create layer_norm_x_backprop - std::vector layer_norm_x_backprop_outputs; - CreateOutputsOfLayerNormXBackprop(graph, cnode, &layer_norm_x_backprop_outputs); - if (layer_norm_x_backprop_outputs.size() != kSingleOutputNum) { - MS_LOG(EXCEPTION) << "layer_norm_grad_outputs has wrong size"; - } - - // create layer_norm_beta_gamma_backprop - std::vector layer_norm_beta_gamma_backprop_outputs; - CreateOutputsOfLayerNormBetaGammaBackprop(graph, cnode, &layer_norm_beta_gamma_backprop_outputs); - if (layer_norm_beta_gamma_backprop_outputs.size() != kLayerNormBetaGammaBackpropOutputNum) { - MS_LOG(EXCEPTION) << "layer_norm_beta_gamma_outputs has wrong size"; - } - - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), layer_norm_x_backprop_outputs[0], - layer_norm_beta_gamma_backprop_outputs[0], - layer_norm_beta_gamma_backprop_outputs[1]}; - auto make_tuple = graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - return make_tuple; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.h deleted file mode 100644 index f442446b01..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ - -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class LayerNormGradSplit : public PatternProcessPass { - public: - explicit LayerNormGradSplit(bool multigraph = true) : PatternProcessPass("layer_norm_grad_split", multigraph) {} - ~LayerNormGradSplit() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - void CreateOutputsOfLayerNormXBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, - std::vector *layer_norm_grad_outputs) const; - void CreateOutputsOfLayerNormBetaGammaBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, - std::vector *layer_norm_beta_gamma_outputs) const; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc deleted file mode 100644 index 159be2ac3b..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kBatchNormRealInputNum = 3; - -AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn); - auto bn_cnode = bn->cast(); - MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { - MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " - << kBatchNormRealInputNum + 1; - } - std::vector bn_training_reduce_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceOpName)), bn_cnode->input(1)}; - auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); - MS_EXCEPTION_IF_NULL(bn_training_reduce); - - // set abstract - auto bn_input1 = bn_cnode->input(2); - MS_EXCEPTION_IF_NULL(bn_input1); - AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input1->abstract()}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_reduce->set_abstract(abstract_tuple); - bn_training_reduce->set_scope(bn->scope()); - return bn_training_reduce; -} - -AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, - const std::vector &bn_training_reduce_outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn); - auto bn_cnode = bn->cast(); - MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { - MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " - << kBatchNormRealInputNum + 1; - } - if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum - << ", but it is " << bn_training_reduce_outputs.size(); - } - std::vector bn_training_update_v3_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateV3OpName)), - bn_cnode->input(1), - bn_training_reduce_outputs[0], - bn_training_reduce_outputs[1], - bn_cnode->input(2), - bn_cnode->input(3)}; - auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs); - MS_EXCEPTION_IF_NULL(bn_training_update_v3); - - auto bn_abstract_tuple = dyn_cast(bn->abstract()); - MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { - MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " - << bn_abstract_tuple->elements().size(); - } - bn_training_update_v3->set_abstract(bn->abstract()); - bn_training_update_v3->set_scope(bn->scope()); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update_v3); - return bn_training_update_v3; -} -} // namespace - -const BaseRef SingleBatchNormFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimBatchNorm, Xs}); -} - -const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() < kBatchNormRealInputNum + 1) { - MS_LOG(INFO) << "The input num of BatchNorm less than" << kBatchNormRealInputNum - << ". The node should not be changed"; - return nullptr; - } - if (!GetBoolAttr(cnode, kAttrIsTraining)) { - MS_LOG(INFO) << "is training should be true if do fusion"; - return nullptr; - } - AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); - std::vector bn_training_reduce_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, - &bn_training_reduce_outputs); - - return CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.h deleted file mode 100644 index 145603132b..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class SingleBatchNormFission : public PatternProcessPass { - public: - explicit SingleBatchNormFission(bool multigraph = true) - : PatternProcessPass("single_batch_norm_fission", multigraph) {} - ~SingleBatchNormFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc deleted file mode 100644 index c39a5e01e6..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.cc +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/split_fission.h" -#include -#include -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(input_node); - std::vector splitv_inputs{NewValueNode(std::make_shared(kSplitVOpName)), input_node}; - CNodePtr splitv = func_graph->NewCNode(splitv_inputs); - MS_EXCEPTION_IF_NULL(splitv); - splitv->set_scope(input_node->scope()); - return splitv; -} - -CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { - MS_EXCEPTION_IF_NULL(origin_cnode); - if (origin_cnode->inputs().size() < kSplitInputNum) { - MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " - << kSplitInputNum - 1; - } - return CreateSplitVNode(func_graph, origin_cnode->input(1)); -} - -void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector &size_splits, int split_dim, int num_split) { - AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv); - AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv); - AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv); -} - -size_t GetSmallSplitSize(const AnfNodePtr &split_node, int split_dim, int num_split) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0); - if (split_dim < 0) { - split_dim += input_shape.size(); - } - if (IntToSize(split_dim) >= input_shape.size()) { - MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0"; - } - return input_shape[split_dim] / num_split; -} - -void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int outputs_num, - std::vector *inputs) { - MS_EXCEPTION_IF_NULL(inputs); - std::vector new_splitv_output; - CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, outputs_num, &new_splitv_output); - inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end()); -} - -AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) { - MS_EXCEPTION_IF_NULL(func_graph); - auto idx = NewValueNode(SizeToInt(index)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(SizeToInt(index)); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx}); - return tuple_getitem; -} - -void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int split_size, int num_split, - std::vector *new_type_ids, - std::vector> *new_output_shapes) { - MS_EXCEPTION_IF_NULL(new_type_ids); - MS_EXCEPTION_IF_NULL(new_output_shapes); - auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); - output_shape[split_dim] = split_size; - TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); - for (int i = 0; i < num_split; ++i) { - new_type_ids->emplace_back(type_id); - new_output_shapes->emplace_back(output_shape); - } -} - -void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv, - const std::vector &size_splits_base, int split_dim, int num_split) { - SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split); - std::vector base_type_ids; - std::vector> base_output_shapes_base; - auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); - TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); - for (int i = 0; i < num_split; ++i) { - output_shape[split_dim] = size_splits_base[i]; - base_output_shapes_base.emplace_back(output_shape); - base_type_ids.emplace_back(type_id); - } - AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get()); -} - -AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int num_split, int divisor) { - MS_EXCEPTION_IF_NULL(func_graph); - auto split_dim = AnfAlgo::GetNodeAttr(cnode, kAttrAxis); - CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode); - - // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs. - auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split)); - std::vector size_splits_new; - for (int i = 0; i < divisor; ++i) { - size_splits_new.emplace_back(small_split_size); - } - // Create new output shape and new output type id for each new Splitv node which has full inputs. - std::vector new_type_ids; - std::vector> new_output_shapes; - CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes); - - // Create make_tuple input to create a make_tuple for replacing the old Split node. - std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; - // Start to divide the outputs of Split. - std::vector size_splits_base; - const auto base_split_size = divisor * small_split_size; - int nodes_num = 0; - int cur_output_index = 0; - while (num_split - cur_output_index > divisor) { - CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); - SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor); - AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get()); - AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs); - cur_output_index += divisor; - size_splits_base.emplace_back(base_split_size); - nodes_num++; - } - if (cur_output_index < num_split) { - auto last_node_num_split = num_split - cur_output_index; - if (last_node_num_split > 1) { - CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); - std::vector size_splits_new_last; - for (int i = 0; i < last_node_num_split; ++i) { - size_splits_new_last.emplace_back(small_split_size); - } - SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); - // Create new output shape and new output type id for the last Splitv node - std::vector last_new_type_ids; - std::vector> last_new_output_shapes; - CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids, - &last_new_output_shapes); - AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get()); - AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs); - size_splits_base.emplace_back(last_node_num_split * small_split_size); - } else { - make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num)); - size_splits_base.emplace_back(small_split_size); - } - nodes_num++; - } - // Set Attr and abstract for the base splitv - SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num); - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - return make_tuple; -} -} // namespace - -const BaseRef SplitFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto split_prim = std::make_shared(kSplitOpName); - return VectorRef({split_prim, Xs}); -} - -const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // Check output num - if (!AnfAlgo::HasNodeAttr(kAttrOutputNum, cnode)) { - return nullptr; - } - auto num_split = AnfAlgo::GetNodeAttr(cnode, kAttrOutputNum); - if (num_split <= outputs_divisor_) { - return nullptr; - } - return DoFission(func_graph, cnode, num_split, outputs_divisor_); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h deleted file mode 100644 index c2763bb714..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -constexpr int kSplitOutputsDivisor = 63; -class SplitFission : public PatternProcessPass { - public: - explicit SplitFission(bool multigraph = true) - : PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {} - ~SplitFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - int outputs_divisor_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc deleted file mode 100644 index 6e6cea5ae5..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(tensor_scatter_update); - std::vector inputs = {NewValueNode(std::make_shared(kTensorMoveOpName)), - tensor_scatter_update->input(1)}; - auto tensor_move = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(tensor_move); - tensor_move->set_scope(tensor_scatter_update->scope()); - tensor_move->set_abstract(tensor_scatter_update->abstract()); - AnfAlgo::SetNodeAttr(kAttrUseLocking, MakeValue(false), tensor_move); - return tensor_move; -} - -CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update, - const CNodePtr &tensor_move) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(tensor_scatter_update); - MS_EXCEPTION_IF_NULL(tensor_move); - std::vector inputs = {NewValueNode(std::make_shared(kScatterNdUpdateOpName)), tensor_move, - tensor_scatter_update->input(2), tensor_scatter_update->input(3)}; - auto scatter_nd_update = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(scatter_nd_update); - scatter_nd_update->set_scope(tensor_scatter_update->scope()); - scatter_nd_update->set_abstract(tensor_scatter_update->abstract()); - return scatter_nd_update; -} -} // namespace - -const BaseRef TensorScatterUpdateFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto prim = std::make_shared(kTensorScatterUpdateOpName); - return VectorRef({prim, Xs}); -} - -const AnfNodePtr TensorScatterUpdateFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto tensor_scatter_update = node->cast(); - if (tensor_scatter_update == nullptr || tensor_scatter_update->size() != 4) { - return nullptr; - } - auto tensor_move = CreateTensorMove(func_graph, tensor_scatter_update); - return CreateScatterNdUpdate(func_graph, tensor_scatter_update, tensor_move); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h deleted file mode 100644 index 0ada93ac70..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class TensorScatterUpdateFission : public PatternProcessPass { - public: - explicit TensorScatterUpdateFission(bool multigraph = true) - : PatternProcessPass("tensor_scatter_update_fission", multigraph) {} - ~TensorScatterUpdateFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc deleted file mode 100644 index c8477353f9..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ /dev/null @@ -1,182 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/topk_split.h" -#include -#include -#include -#include -#include "pre_activate/common/helper.h" -#include "kernel/kernel_build_info.h" -#include "utils/utils.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -constexpr size_t kFloat16Len = 2; // size of float16; -constexpr size_t kTopkIndexK = 1; -namespace { -tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { - // 1 create tensor - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - auto last_dim = shape[shape.size() - 1]; - std::vector indices_shape = {SizeToInt(last_dim * 2)}; - TensorTypePtr tensor_type = std::make_shared(kFloat16); - MS_EXCEPTION_IF_NULL(tensor_type); - tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type}; - tensor::TensorPtr indices_tensor = std::make_shared(kFloat16->type_id(), indices_shape); - MS_EXCEPTION_IF_NULL(indices_tensor); - indices_tensor->set_device_info(device_info); - - // 2 set value of tensor - auto data_ptr = indices_tensor->data_c(); - MS_EXCEPTION_IF_NULL(data_ptr); - std::vector half_data; - for (size_t i = 0; i < last_dim; ++i) { - half_data.emplace_back(Eigen::half(static_cast(i))); - } - for (size_t i = 0; i < last_dim; ++i) { - auto gap = static_cast(i) - static_cast(Eigen::half(static_cast(i))); - half_data.emplace_back(Eigen::half(static_cast(gap))); - } - auto elem_num = last_dim * kFloat16Len * 2; - auto ret_code = memcpy_s(data_ptr, static_cast(indices_tensor->data().nbytes()), half_data.data(), elem_num); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor."; - return nullptr; - } - return indices_tensor; -} - -ValueNodePtr CreateValueNode(const AnfNodePtr &node) { - tensor::TensorPtr indices_tensor = CreateTensor(node); - MS_EXCEPTION_IF_NULL(indices_tensor); - auto indices_const = std::make_shared(indices_tensor); - MS_EXCEPTION_IF_NULL(indices_const); - auto indices_abstract = indices_tensor->ToAbstract(); - indices_const->set_abstract(indices_abstract); - auto indices_kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(indices_kernel_info); - indices_const->set_kernel_info(indices_kernel_info); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1; - builder1.SetOutputsFormat({kOpFormat_DEFAULT}); - builder1.SetOutputsDeviceType({kNumberTypeFloat16}); - AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get()); - return indices_const; -} - -kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetKernelType(TBE_KERNEL); - builder.SetFusionType(kernel::OPAQUE); - builder.SetProcessor(kernel::AICORE); - builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); - builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); - builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); - builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32}); - return builder.Build(); -} - -bool CheckInputNamesSize(const CNodePtr &cnode) { - auto input_names_vec = AnfAlgo::GetNodeAttr>(cnode, kAttrInputNames); - if (input_names_vec.size() < kTopkIndexK + 1) { - MS_LOG(INFO) << "The input k of topk has been converted to attr"; - return false; - } - return true; -} - -bool CheckOutputShape(const AnfNodePtr &node) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - if (shape.empty()) { - MS_LOG(INFO) << "The output shape of topk to split must not be empty"; - return false; - } - auto last_dim = shape[shape.size() - 1]; - const size_t kMaxFloat16 = 65500; - if (last_dim > kMaxFloat16) { - MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops."; - return false; - } - return true; -} -} // namespace - -const BaseRef TopKSplit::DefinePattern() const { - VarPtr X1 = std::make_shared(); - VarPtr X2 = std::make_shared(); - auto prim = std::make_shared(kTopKOpName); - return VectorRef({prim, X1, X2}); -} - -const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto kernel_graph = func_graph->cast(); - // set value node as topk's input - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!CheckInputNamesSize(cnode)) { - return nullptr; - } - if (!CheckOutputShape(cnode)) { - return nullptr; - } - // Copy a new node to check supported. - std::vector new_inputs{NewValueNode(std::make_shared(kTopKOpName))}; - new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); - CNodePtr new_cnode = func_graph->NewCNode(new_inputs); - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_scope(cnode->scope()); - AnfAlgo::CopyNodeAttrs(cnode, new_cnode); - CheckCNodeInputSize(new_cnode, kTopkInputNum); - // Convert the tensor input to scalar and convert it to attr - auto input_k = new_cnode->input(kTopkIndexK + 1); - MS_EXCEPTION_IF_NULL(input_k); - if (!IsValueNode(input_k)) { - return nullptr; - } - ValuePtr value = GetValueNode(input_k); - MS_EXCEPTION_IF_NULL(value); - auto tensor = value->cast(); - MS_EXCEPTION_IF_NULL(tensor); - int32_t *data = reinterpret_cast(tensor->data_c()); - MS_EXCEPTION_IF_NULL(data); - auto new_value_node = std::make_shared(MakeValue(*data)); - new_cnode->set_input(kTopkIndexK + 1, new_value_node); - - std::unordered_set attr_index{kTopkIndexK}; - ConstInputToAttr(new_cnode, attr_index); - auto indices_const = CreateValueNode(new_cnode); - new_cnode->add_input(indices_const); - MS_EXCEPTION_IF_NULL(supported_checker_); - if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { - MS_LOG(INFO) << "split topk failed, check to aicpu."; - return nullptr; - } - - if (kernel_graph != nullptr) { - MS_LOG(INFO) << "split topk success. use tbe aicore."; - kernel_graph->AddValueNodeToGraph(indices_const); - } - - return new_cnode; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h deleted file mode 100644 index e7293e1fa3..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ - -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class TopKSplit : public PatternProcessPass { - public: - explicit TopKSplit(bool multigraph = true) - : PatternProcessPass("topk_split", multigraph), supported_checker_(std::make_shared()) {} - ~TopKSplit() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - SupportedCheckerPtr supported_checker_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc deleted file mode 100644 index bfb7e50486..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/transdata_split.h" -#include -#include "pre_activate/ascend/ascend_helper.h" -#include "session/anf_runtime_algorithm.h" -#include "debug/anf_ir_dump.h" - -namespace mindspore { -namespace opt { -const std::set> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, - {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, - {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, - {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; - -bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - bool changed = false; - std::vector node_list = TopoSort(func_graph->get_return()); - for (auto &node : node_list) { - if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { - CheckCNodeInputSize(node->cast(), kBackendTransDataInputNum); - if (IsFormatInvaild(node)) { - changed = DoSplit(func_graph, node); - } - } - } - return changed; -} -bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input_format = AnfAlgo::GetInputFormat(node, 0); - auto output_format = AnfAlgo::GetOutputFormat(node, 0); - auto format_pair = std::make_pair(input_format, output_format); - - return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); -} -// transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW) -bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input_node = node->cast()->input(1); - MS_EXCEPTION_IF_NULL(input_node); - - auto input_format = AnfAlgo::GetInputFormat(node, 0); - auto output_format = AnfAlgo::GetOutputFormat(node, 0); - AnfNodePtr new_transdata_node = nullptr; - AnfNodePtr new_transpose_node = nullptr; - AnfNodePtr new_replace_node = nullptr; - // if output_format=default transdata need split transdata->transpose else transpose->transdata - if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { - // trans input_format to hwcn - new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, - false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node); - // trans hwcn to default_format - new_transpose_node = - NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); - RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node); - AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{3, 2, 0, 1}), new_transpose_node); - new_replace_node = new_transpose_node; - } else { - // trans default to hwcn - new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, - false, prim::kPrimTranspose->name()); - AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); - RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node); - - // trans hwcn to output_format - new_transdata_node = - NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node); - new_replace_node = new_transdata_node; - } - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - - if (!manager->Replace(node, new_replace_node)) { - MS_LOG(EXCEPTION) << "Manager replace node failed"; - } - MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; - return true; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.h deleted file mode 100644 index f450897db1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ -#include -#include -#include -#include - -#include "pre_activate/common/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class TransDataSplit : public Pass { - public: - TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared()) {} - ~TransDataSplit() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node); - bool IsFormatInvaild(const AnfNodePtr &node); - KernelSelectPtr kernel_select_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc deleted file mode 100644 index 59be003b15..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc +++ /dev/null @@ -1,151 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto prim = std::make_shared(kAdamApplyOneOpName); - std::vector new_node_inputs = {NewValueNode(prim)}; - for (const auto &input_var : input_vars_) { - auto input_node = utils::cast((*equiv)[input_var]); - MS_EXCEPTION_IF_NULL(input_node); - new_node_inputs.push_back(input_node); - } - for (const auto &mul_x_input_var : mul_x_input_vars_) { - auto mul_x_input_node = utils::cast((*equiv)[mul_x_input_var]); - MS_EXCEPTION_IF_NULL(mul_x_input_node); - new_node_inputs.push_back(mul_x_input_node); - } - auto add2_y_node = utils::cast((*equiv)[add2_y_]); - MS_EXCEPTION_IF_NULL(add2_y_node); - new_node_inputs.push_back(add2_y_node); - auto new_node = func_graph->NewCNode(new_node_inputs); - return new_node; -} - -const BaseRef AdamApplyOneFusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); -} - -const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); -} - -const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); -} - -const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); -} - -const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); -} - -const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - auto new_node = CreateAdamApplyOneNode(func_graph, equiv); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(node->scope()); - // Set abstract of new node - AbstractBasePtrList new_node_abstract_list; - auto iter_add0 = (*equiv).find(add0_var_); - if (iter_add0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; - } - auto iter_add1 = (*equiv).find(add1_var_); - if (iter_add1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; - } - auto add0 = utils::cast(iter_add0->second); - MS_EXCEPTION_IF_NULL(add0); - auto add1 = utils::cast(iter_add1->second); - MS_EXCEPTION_IF_NULL(add1); - new_node_abstract_list.push_back(add1->abstract()); - new_node_abstract_list.push_back(add0->abstract()); - new_node_abstract_list.push_back(node->abstract()); - auto abstract_tuple = std::make_shared(new_node_abstract_list); - new_node->set_abstract(abstract_tuple); - // Create tuple_getitem node for outputs - std::vector new_node_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, new_node, kAdamApplyOneOutputNum, &new_node_outputs); - if (new_node_outputs.size() != kAdamApplyOneOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node " << new_node->DebugString() << " should be " - << kAdamApplyOneOutputNum; - } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(add1, new_node_outputs[0]); - (void)manager->Replace(add0, new_node_outputs[1]); - return new_node_outputs[2]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h deleted file mode 100644 index 5ee8a86cfb..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -constexpr size_t kAdamApplyOneInputVarNum = 5; -constexpr size_t kAdamApplyOneMulInputVarNum = 4; - -class AdamApplyOneFusion : public PatternProcessPass { - public: - explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true) - : PatternProcessPass(name, multigraph) { - for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) { - input_vars_.push_back(std::make_shared()); - } - for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) { - mul_x_input_vars_.push_back(std::make_shared()); - } - add2_y_ = std::make_shared(); - add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - } - - ~AdamApplyOneFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - protected: - AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; - std::vector input_vars_; - std::vector mul_x_input_vars_; - VarPtr add2_y_; - VarPtr add0_var_; - VarPtr add1_var_; -}; - -class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { - public: - explicit AdamApplyOneCond1Fusion(bool multigraph = true) - : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {} - - ~AdamApplyOneCond1Fusion() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneCond2Fusion : public AdamApplyOneFusion { - public: - explicit AdamApplyOneCond2Fusion(bool multigraph = true) - : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {} - - ~AdamApplyOneCond2Fusion() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneCond3Fusion : public AdamApplyOneFusion { - public: - explicit AdamApplyOneCond3Fusion(bool multigraph = true) - : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {} - - ~AdamApplyOneCond3Fusion() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { - public: - explicit AdamApplyOneCond4Fusion(bool multigraph = true) - : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {} - - ~AdamApplyOneCond4Fusion() override = default; - const BaseRef DefinePattern() const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc deleted file mode 100644 index f6077c95f2..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc +++ /dev/null @@ -1,189 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" - -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(equiv); - auto input0 = utils::cast((*equiv)[input0_]); - auto input1 = utils::cast((*equiv)[input1_]); - auto input2 = utils::cast((*equiv)[input2_]); - auto input3 = utils::cast((*equiv)[input3_]); - auto input4 = utils::cast((*equiv)[input4_]); - auto mul0_x = utils::cast((*equiv)[mul0_x_]); - auto mul1_x = utils::cast((*equiv)[mul1_x_]); - auto mul2_x = utils::cast((*equiv)[mul2_x_]); - auto mul3_x = utils::cast((*equiv)[mul3_x_]); - auto mul4_x = utils::cast((*equiv)[mul4_x_]); - auto add2_y = utils::cast((*equiv)[add2_y_]); - auto prim = std::make_shared(kAdamApplyOneWithDecayOpName); - return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; -} - -const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, input4_, add3}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, input2_, mul0_x_}); - VectorRef mul1({prim::kPrimMul, input0_, mul1_x_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, input1_, mul2_x_}); - VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); - VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, add3, input4_}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, add3, input4_}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, add3, input4_}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, add3, input4_}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (graph == nullptr || node == nullptr || equiv == nullptr) { - return nullptr; - } - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - std::vector inputs = GetFusionNodeInputs(equiv); - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(node->scope()); - - auto iter_add0 = (*equiv).find(add0_var_); - if (iter_add0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; - } - auto iter_add1 = (*equiv).find(add1_var_); - if (iter_add1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; - } - auto add0 = utils::cast(iter_add0->second); - MS_EXCEPTION_IF_NULL(add0); - auto add1 = utils::cast(iter_add1->second); - MS_EXCEPTION_IF_NULL(add1); - auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), - AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), - AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); - - std::vector fusion_node_outputs; - CreateMultipleOutputsOfAnfNode(graph, fusion_node, kAdamApplyOneWithDecayOutputNum, &fusion_node_outputs); - if (fusion_node_outputs.size() != kAdamApplyOneWithDecayOutputNum) { - MS_LOG(ERROR) << "create multiple outputs for fusion node fail!"; - return nullptr; - } - - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(add1, fusion_node_outputs[0]); - (void)manager->Replace(add0, fusion_node_outputs[1]); - return fusion_node_outputs[2]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h deleted file mode 100644 index 742295dd9c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "utils/utils.h" -namespace mindspore { -namespace opt { -class AdamApplyOneWithDecayRule : public PatternProcessPass { - public: - explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true) - : PatternProcessPass(name, multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - input2_ = std::make_shared(); - input3_ = std::make_shared(); - input4_ = std::make_shared(); - mul0_x_ = std::make_shared(); - mul1_x_ = std::make_shared(); - mul2_x_ = std::make_shared(); - mul3_x_ = std::make_shared(); - mul4_x_ = std::make_shared(); - add2_y_ = std::make_shared(); - add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - } - ~AdamApplyOneWithDecayRule() override = default; - const BaseRef DefinePattern() const override = 0; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - protected: - std::vector GetFusionNodeInputs(const EquivPtr &equiv) const; - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr input3_; - VarPtr input4_; - VarPtr mul0_x_; - VarPtr mul1_x_; - VarPtr mul2_x_; - VarPtr mul3_x_; - VarPtr mul4_x_; - VarPtr add2_y_; - VarPtr add0_var_; - VarPtr add1_var_; -}; - -class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond1() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond2() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond3() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond4() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond5() override = default; - const BaseRef DefinePattern() const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.cc deleted file mode 100644 index 1a62b7a5be..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.cc +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(batchnorm); - MS_EXCEPTION_IF_NULL(node); - auto prim = std::make_shared(kBNInferOpName); - std::vector inputs = {NewValueNode(prim)}; - for (size_t i = 1; i < batchnorm->size(); ++i) { - inputs.push_back(batchnorm->input(i)); - } - auto new_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(batchnorm->scope()); - new_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnorm, new_node); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnorm, new_node); - return new_node; -} - -bool CheckIndex(const AnfNodePtr &index_node) { - MS_EXCEPTION_IF_NULL(index_node); - if (!IsValueNode(index_node)) { - return false; - } - ValueNodePtr value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (index != 0) { - MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNorm"; - return false; - } - return true; -} - -bool CheckBatchNorm(const FuncGraphPtr &graph, const CNodePtr &batchnorm) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(batchnorm); - if (batchnorm->size() < kBatchNormInputNum + 1) { - MS_LOG(DEBUG) << "BatchNorm's input less than " << kBatchNormInputNum; - return false; - } - if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) { - return false; - } - auto is_training = AnfAlgo::GetNodeAttr(batchnorm, kAttrIsTraining); - if (is_training) { - MS_LOG(DEBUG) << "is_training is true, no need do fusion"; - return false; - } - - if (IsUsedByOthers(graph, batchnorm)) { - MS_LOG(DEBUG) << "Only the 0th output of BatchNorm is used, then do fusion"; - return false; - } - return true; -} - -bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto tuple_getitem = node->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); - AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - if (!CheckIndex(index_node)) { - return false; - } - - AnfNodePtr batchnorm_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(batchnorm_anf); - MS_EXCEPTION_IF_NULL(batchnorm); - *batchnorm = batchnorm_anf->cast(); - MS_EXCEPTION_IF_NULL(*batchnorm); - return CheckBatchNorm(graph, *batchnorm); -} -} // namespace - -const BaseRef BatchNorm2BNInfer::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VarPtr Y = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Y); - VectorRef batchnorm({prim::kPrimBatchNorm, Xs}); - VectorRef pattern({prim::kPrimTupleGetItem, batchnorm, Y}); - return pattern; -} - -const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - - CNodePtr batchnorm = nullptr; - if (!NeedFusion(graph, node, &batchnorm)) { - return nullptr; - } - return CreateBNInfer(graph, batchnorm, node); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h deleted file mode 100644 index 551fe0f6f9..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNorm2BNInfer : public PatternProcessPass { - public: - explicit BatchNorm2BNInfer(bool multigraph = true) : PatternProcessPass("batchnorm_to_bninfer", multigraph) {} - ~BatchNorm2BNInfer() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc deleted file mode 100644 index 424d3a12c1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateBNInferGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(batchnormgrad); - auto prim = std::make_shared(kBNInferGradOpName); - std::vector inputs = {NewValueNode(prim)}; - inputs.push_back(batchnormgrad->input(1)); - inputs.push_back(batchnormgrad->input(3)); - inputs.push_back(batchnormgrad->input(5)); - auto new_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(batchnormgrad->scope()); - new_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnormgrad, new_node); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnormgrad, new_node); - return new_node; -} - -bool CheckIndex(const AnfNodePtr &index_node) { - MS_EXCEPTION_IF_NULL(index_node); - if (!IsValueNode(index_node)) { - return false; - } - ValueNodePtr value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (index != 0) { - MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNormGrad"; - return false; - } - return true; -} - -bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(batchnormgrad); - if (batchnormgrad->size() < kBatchNormInputNum + 1) { - MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum; - return false; - } - if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) { - return false; - } - auto is_training = AnfAlgo::GetNodeAttr(batchnormgrad, kAttrIsTraining); - if (is_training) { - MS_LOG(DEBUG) << "is_training is true, no need do fusion"; - return false; - } - - if (IsUsedByOthers(graph, batchnormgrad)) { - MS_LOG(DEBUG) << "Only the 0th output of BatchNormGrad is used, then do fusion"; - return false; - } - return true; -} - -bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto tuple_getitem = node->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); - AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - if (!CheckIndex(index_node)) { - return false; - } - - AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(batchnormgrad_anf); - MS_EXCEPTION_IF_NULL(batchnormgrad); - *batchnormgrad = batchnormgrad_anf->cast(); - MS_EXCEPTION_IF_NULL(*batchnormgrad); - return CheckBatchNormGrad(graph, *batchnormgrad); -} -} // namespace - -const BaseRef BatchNormGrad2BNInferGrad::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VarPtr Y = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Y); - VectorRef batchnormgrad({prim::kPrimBatchNormGrad, Xs}); - VectorRef pattern({prim::kPrimTupleGetItem, batchnormgrad, Y}); - return pattern; -} - -const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - - CNodePtr batchnormgrad = nullptr; - if (!NeedFusion(graph, node, &batchnormgrad)) { - return nullptr; - } - return CreateBNInferGrad(graph, batchnormgrad, node); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h deleted file mode 100644 index 020dc1a999..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormGrad2BNInferGrad : public PatternProcessPass { - public: - explicit BatchNormGrad2BNInferGrad(bool multigraph = true) - : PatternProcessPass("batchnormgrad_to_bninfergrad", multigraph) {} - ~BatchNormGrad2BNInferGrad() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc deleted file mode 100644 index 2af3afbf19..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "common/utils.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -const BaseRef ClipByNormNoDivSquareSumFusion::DefinePattern() const { - auto greater = std::make_shared(kGreaterOpName); - MS_EXCEPTION_IF_NULL(greater); - auto sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(sqrt); - - VectorRef greater_pattern({greater, input_, constant_greater_}); - VectorRef pattern( - {prim::kPrimMaximum, - VectorRef({prim::kPrimSelect, greater_pattern, - VectorRef({sqrt, VectorRef({prim::kPrimSelect, greater_pattern, input_, constant_select_})}), input_}), - constant_maximum_}); - return pattern; -} - -const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - BaseRef &input_gnode = (*equiv)[input_]; - BaseRef &constant_select_gnode = (*equiv)[constant_select_]; - BaseRef &constant_greater_gnode = (*equiv)[constant_greater_]; - BaseRef &constant_maximum_gnode = (*equiv)[constant_maximum_]; - auto input = utils::cast(input_gnode); - auto constant_select = utils::cast(constant_select_gnode); - auto constant_greater = utils::cast(constant_greater_gnode); - auto constant_maximum = utils::cast(constant_maximum_gnode); - MS_EXCEPTION_IF_NULL(input); - MS_EXCEPTION_IF_NULL(constant_select); - MS_EXCEPTION_IF_NULL(constant_greater); - MS_EXCEPTION_IF_NULL(constant_maximum); - - auto prim = std::make_shared(kClipByNormNoDivSumOpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = {NewValueNode(prim), input, constant_select, constant_greater, constant_maximum}; - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); - fusion_node->set_scope(node->scope()); - return fusion_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h deleted file mode 100644 index 126480603e..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ - -#include -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -constexpr auto kInputVarName = "input"; -constexpr auto kConstantSelectVarName = "constant_select"; -constexpr auto kConstantGreaterVarName = "constant_greater"; -constexpr auto kConstantMaximumVarName = "constant_maximum"; - -class ClipByNormNoDivSquareSumFusion : public PatternProcessPass { - public: - explicit ClipByNormNoDivSquareSumFusion(bool multigraph = true) - : PatternProcessPass("clip_by_norm_no_div_square_sum_fusion", multigraph) { - input_ = std::make_shared(kInputVarName); - constant_select_ = std::make_shared(kConstantSelectVarName); - constant_greater_ = std::make_shared(kConstantGreaterVarName); - constant_maximum_ = std::make_shared(kConstantMaximumVarName); - } - ~ClipByNormNoDivSquareSumFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_; - VarPtr constant_select_; - VarPtr constant_greater_; - VarPtr constant_maximum_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.cc deleted file mode 100644 index df94e897ec..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/clip_by_value_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -bool GetMinimumOp(const AnfNodePtr &input0, const AnfNodePtr &input1, CNodePtr *minimum, bool *is_first_input) { - MS_EXCEPTION_IF_NULL(input0); - MS_EXCEPTION_IF_NULL(input1); - - CNodePtr cnode = nullptr; - if (input0->isa() && !input1->isa()) { - cnode = input0->cast(); - *is_first_input = true; - } else if (!input0->isa() && input1->isa()) { - cnode = input1->cast(); - *is_first_input = false; - } else if (input0->isa() && input1->isa()) { - if (AnfAlgo::GetCNodeName(input0) == prim::kPrimMinimum->name()) { - cnode = input0->cast(); - *is_first_input = true; - } else { - cnode = input1->cast(); - *is_first_input = false; - } - } else { - return false; - } - - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMinimum->name()) { - return false; - } - *minimum = cnode; - return true; -} -} // namespace - -const BaseRef ClipByValueFusion::DefinePattern() const { - VectorRef pattern({prim::kPrimMaximum, maximum_input0_, maximum_input1_}); - return pattern; -} - -const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - auto maximum_input0 = utils::cast((*equiv)[maximum_input0_]); - auto maximum_input1 = utils::cast((*equiv)[maximum_input1_]); - MS_EXCEPTION_IF_NULL(maximum_input0); - MS_EXCEPTION_IF_NULL(maximum_input1); - - CNodePtr minimum = nullptr; - bool is_first_input = true; - if (!GetMinimumOp(maximum_input0, maximum_input1, &minimum, &is_first_input)) { - return nullptr; - } - MS_EXCEPTION_IF_NULL(minimum); - if (minimum->inputs().size() != kMinimumInputNum) { - return nullptr; - } - - auto prim = std::make_shared(kClipByValueOpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = {NewValueNode(prim), minimum->input(1), - is_first_input ? maximum_input1 : maximum_input0, minimum->input(2)}; - auto clip_by_value = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(clip_by_value); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, clip_by_value.get()); - clip_by_value->set_scope(node->scope()); - return clip_by_value; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.h deleted file mode 100644 index 309b7cedd0..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ClipByValueFusion : public PatternProcessPass { - public: - explicit ClipByValueFusion(bool multigraph = true) : PatternProcessPass("clip_by_value_fusion", multigraph) { - maximum_input0_ = std::make_shared(); - maximum_input1_ = std::make_shared(); - } - ~ClipByValueFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr maximum_input0_; - VarPtr maximum_input1_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc deleted file mode 100644 index d49b2d47f3..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc +++ /dev/null @@ -1,151 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" -#include -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -const size_t kConfusionMulGradOutputNum = 2; - -CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf, - const AnfNodePtr &input3) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(reduce_sum); - MS_EXCEPTION_IF_NULL(mul0_anf); - MS_EXCEPTION_IF_NULL(input3); - auto mul0 = mul0_anf->cast(); - MS_EXCEPTION_IF_NULL(mul0); - - auto prim = std::make_shared(kConfusionMulGradOpName); - std::vector inputs = {NewValueNode(prim), mul0->input(1), mul0->input(2), input3}; - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(reduce_sum->scope()); - AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); - auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); - return fusion_node; -} - -AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const AnfNodePtr &mul1) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(input2); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(input2) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - - AnfNodePtr mul0 = nullptr; - const AnfNodeIndexSet &outputs_set = manager->node_users()[input2]; - // input2 must be the 2rd input of mul0 - auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul1](const std::pair &node_index) { - return node_index.first != mul1 && node_index.second == 2; - }); - if (it != outputs_set.end() && AnfAlgo::GetCNodeName(it->first) == prim::kPrimMul->name()) { - mul0 = it->first; - } - return mul0; -} - -bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, - const AnfNodePtr &reduce_sum, const AnfNodePtr &input2) { - MS_EXCEPTION_IF_NULL(mul0_anf); - MS_EXCEPTION_IF_NULL(mul1_anf); - MS_EXCEPTION_IF_NULL(reduce_sum); - MS_EXCEPTION_IF_NULL(input2); - auto addn = input2->cast(); - if (addn == nullptr || AnfAlgo::GetCNodeName(addn) != prim::kPrimAddN->name()) { - MS_LOG(INFO) << "mul's second input is not addn"; - return true; - } - std::vector shape = AnfAlgo::GetOutputInferShape(addn, 0); - if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { - MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]"; - return true; - } - if (!mul0_anf->isa() || !mul1_anf->isa()) { - return true; - } - auto mul1 = mul1_anf->cast(); - MS_EXCEPTION_IF_NULL(mul1); - auto mul0 = mul0_anf->cast(); - MS_EXCEPTION_IF_NULL(mul0); - - if (IsDepend(graph, mul0->input(1), reduce_sum)) { - MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; - return true; - } - if (IsDepend(graph, mul1->input(1), mul0)) { - MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; - return true; - } - return false; -} -} // namespace - -const BaseRef ConfusionMulGradFusion::DefinePattern() const { - VectorRef mul1({prim::kPrimMul, input3_, input2_}); - VectorRef reduce_sum({prim::kPrimReduceSum, mul1}); - return reduce_sum; -} - -const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - auto input2 = utils::cast((*equiv)[input2_]); - auto input3 = utils::cast((*equiv)[input3_]); - auto reduce_sum = node->cast(); - MS_EXCEPTION_IF_NULL(reduce_sum); - auto mul1 = reduce_sum->input(1); - if (IsUsedByOthers(graph, mul1)) { - MS_LOG(INFO) << "Mul1 is used by others, quit fusion!"; - return nullptr; - } - auto mul0 = GetMul0(graph, input2, mul1); - if (mul0 == nullptr) { - MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; - return nullptr; - } - if (QuitFusion(graph, mul0, mul1, node, input2)) { - return nullptr; - } - - auto fusion_node = CreateFusionNode(graph, reduce_sum, mul0, input3); - std::vector fusion_node_outputs; - CreateMultipleOutputsOfAnfNode(graph, fusion_node, kConfusionMulGradOutputNum, &fusion_node_outputs); - - auto manage = graph->manager(); - MS_EXCEPTION_IF_NULL(manage); - manage->Replace(mul0, fusion_node_outputs[0]); - return fusion_node_outputs[1]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h deleted file mode 100644 index 170df5b0e4..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConfusionMulGradFusion : public PatternProcessPass { - public: - explicit ConfusionMulGradFusion(bool multigraph = true) - : PatternProcessPass("confusion_mul_grad_fusion", multigraph) { - input2_ = std::make_shared(); - input3_ = std::make_shared(); - } - ~ConfusionMulGradFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input2_; - VarPtr input3_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc deleted file mode 100644 index 9e2c6374ce..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h" - -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const { - return VectorRef({prim::kPrimSub, input0_, VectorRef({reduce_sum_, VectorRef({prim::kPrimMul, input1_, input0_})})}); -} - -const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - AnfNodePtr input0 = GetAnfNodeByVar(equiv, input0_); - AnfNodePtr input1 = GetAnfNodeByVar(equiv, input1_); - AnfNodePtr sum_anf = GetAnfNodeByVar(equiv, reduce_sum_); - if (sum_anf == nullptr || !sum_anf->isa()) { - MS_LOG(WARNING) << "Matched ReduceSum is not a CNode!"; - return nullptr; - } - if (!GetBoolAttr(sum_anf, kAttrKeepDims)) { - MS_LOG(INFO) << "ReduceSum's attr keep_dims should be true if do fusion. Otherwise the calculation will be wrong"; - return nullptr; - } - - auto prim = std::make_shared(kConfusionSoftmaxGradOpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = {NewValueNode(prim), input0, input1}; - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_abstract(node->abstract()); - fusion_node->set_scope(node->scope()); - AnfAlgo::CopyNodeAttr(kAttrAxis, sum_anf, fusion_node); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum_anf, fusion_node); - return fusion_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h deleted file mode 100644 index a4d0d1ce7a..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConfusionSoftmaxGradRule : public PatternProcessPass { - public: - explicit ConfusionSoftmaxGradRule(bool multigraph = true) - : PatternProcessPass("confusion_softmax_grad_rule", multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - reduce_sum_ = std::make_shared(std::make_shared(prim::kPrimReduceSum->name())); - } - ~ConfusionSoftmaxGradRule() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input0_; - VarPtr input1_; - VarPtr reduce_sum_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc deleted file mode 100644 index 2f3c998bb8..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.cc +++ /dev/null @@ -1,121 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/derelu_fusion.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -const size_t kReluV2OutputNum = 2; - -CNodePtr GetRelu(const CNodePtr &relu_grad) { - MS_EXCEPTION_IF_NULL(relu_grad); - if (relu_grad->size() != kReluGradInputNum) { - MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); - } - auto relu_anf = relu_grad->input(2); - MS_EXCEPTION_IF_NULL(relu_anf); - return relu_anf->cast(); -} - -CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(relu); - if (relu->size() != kReluInputNum) { - MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); - } - - auto prim = std::make_shared(kReluV2OpName); - std::vector inputs = {NewValueNode(prim), relu->input(1)}; - auto new_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(relu->scope()); - - // ReluV2's 2rd output is mask whose data type is uint8 - TypeId mask_dtype = kNumberTypeUInt8; - std::vector mask_shape = AnfAlgo::GetOutputInferShape(relu, 0); - if (mask_shape.size() != 4) { - MS_LOG(DEBUG) << "relu's infer shape size not equal 4"; - return nullptr; - } - auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0); - if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) { - mask_shape[1] = (mask_shape[1] + 31) / 32; - mask_shape.push_back(4); - } else { - mask_shape[1] = (mask_shape[1] + 15) / 16; - mask_shape.push_back(2); - } - - auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype}; - auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); - return new_node; -} - -CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(relu_grad); - MS_EXCEPTION_IF_NULL(second_input); - - auto prim = std::make_shared(kReluGradV2OpName); - std::vector inputs = {NewValueNode(prim), relu_grad->input(1), second_input}; - auto new_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(relu_grad->scope()); - new_node->set_abstract(relu_grad->abstract()); - return new_node; -} -} // namespace - -const BaseRef DereluFusion::DefinePattern() const { - VarPtr i0 = std::make_shared(); - VarPtr i1 = std::make_shared(); - VectorRef relu({prim::kPrimRelu, i1}); - VectorRef relu_grad({prim::kPrimReluGrad, i0, relu}); - return relu_grad; -} - -const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto relu_grad = node->cast(); - MS_EXCEPTION_IF_NULL(relu_grad); - auto relu = GetRelu(relu_grad); - MS_EXCEPTION_IF_NULL(relu); - - auto relu_v2 = CreateReluV2(graph, relu); - if (relu_v2 == nullptr) { - return nullptr; - } - std::vector relu_v2_node_outputs; - CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); - - auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); - - auto manage = graph->manager(); - MS_EXCEPTION_IF_NULL(manage); - manage->Replace(relu, relu_v2_node_outputs[0]); - return relu_grad_v2; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h deleted file mode 100644 index e1811f4db4..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class DereluFusion : public PatternProcessPass { - public: - explicit DereluFusion(bool multigraph = true) : PatternProcessPass("derelu_fusion", multigraph) {} - ~DereluFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc deleted file mode 100644 index efc9ee7934..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc +++ /dev/null @@ -1,340 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" -#include -#include -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kReplaceOutputIndex0 = 3; -constexpr size_t kReplaceOutputIndex1 = 4; -bool IsC(const BaseRef &n) { - if (utils::isa(n)) { - AnfNodePtr in = utils::cast(n); - MS_EXCEPTION_IF_NULL(in); - return in->isa(); - } - return false; -} - -void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector *bn_outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn); - MS_EXCEPTION_IF_NULL(bn_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(bn) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs"; - } - for (const auto &node_index : manager->node_users()[bn]) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - bn_outputs->push_back(output); - } -} -} // namespace - -const BaseRef FusedBatchNormFusion::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - VarPtr index0 = std::make_shared(IsC); - VarPtr index1 = std::make_shared(IsC); - VarPtr index2 = std::make_shared(IsC); - VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); - VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); - VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); - VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); - VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1}); - VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2}); - VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); - VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); - VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); - return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); -} - -ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(equiv); - auto iter_constant_input0 = (*equiv).find(constant_input0_var_); - if (iter_constant_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched."; - } - auto constant_input = utils::cast(iter_constant_input0->second); - MS_EXCEPTION_IF_NULL(constant_input); - if (!constant_input->isa()) { - return nullptr; - } - auto value_node = constant_input->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - if (!value->isa()) { - return nullptr; - } - auto tensor_ptr = value->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - if (tensor_ptr->data_type() == kNumberTypeFloat16) { - auto *half_data = static_cast(tensor_ptr->data_c()); - MS_EXCEPTION_IF_NULL(half_data); - float float_data = Eigen::half_impl::half_to_float(half_data[0]); - return MakeValue(float_data); - } else if (tensor_ptr->data_type() == kNumberTypeFloat32) { - auto *tensor_data = static_cast(tensor_ptr->data_c()); - MS_EXCEPTION_IF_NULL(tensor_data); - return MakeValue(tensor_data[0]); - } else { - MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32"; - return nullptr; - } -} - -AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - // Set input to create node - auto iter_data_input0 = (*equiv).find(data_input0_var_); - if (iter_data_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; - } - std::vector bn_training_reduce_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceOpName)), - utils::cast(iter_data_input0->second)}; - auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); - MS_EXCEPTION_IF_NULL(bn_training_reduce); - bn_training_reduce->set_scope(node->scope()); - // Set abstract - auto iter_data_input1 = (*equiv).find(data_input1_var_); - if (iter_data_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; - } - auto data_input1 = utils::cast(iter_data_input1->second); - MS_EXCEPTION_IF_NULL(data_input1); - auto iter_data_input2 = (*equiv).find(data_input2_var_); - if (iter_data_input2 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; - } - auto data_input2 = utils::cast(iter_data_input2->second); - MS_EXCEPTION_IF_NULL(data_input2); - AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_reduce->set_abstract(abstract_tuple); - return bn_training_reduce; -} - -void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv, - const std::vector &bn_training_reduce_outputs, - std::vector *bn_training_update_inputs) const { - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(bn_training_update_inputs); - auto iter_data_input0 = (*equiv).find(data_input0_var_); - if (iter_data_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; - } - auto iter_data_input1 = (*equiv).find(data_input1_var_); - if (iter_data_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; - } - auto iter_data_input2 = (*equiv).find(data_input2_var_); - if (iter_data_input2 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; - } - auto iter_variable_input0 = (*equiv).find(variable_input0_var_); - if (iter_variable_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; - } - auto iter_variable_input1 = (*equiv).find(variable_input1_var_); - if (iter_variable_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; - } - if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum - << ", but it is " << bn_training_reduce_outputs.size(); - } - *bn_training_update_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateOpName)), - utils::cast(iter_data_input0->second), - bn_training_reduce_outputs[0], - bn_training_reduce_outputs[1], - utils::cast(iter_data_input1->second), - utils::cast(iter_data_input2->second), - utils::cast(iter_variable_input0->second), - utils::cast(iter_variable_input1->second), - }; -} - -void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn, - std::vector *abstract_list) const { - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(bn); - MS_EXCEPTION_IF_NULL(abstract_list); - auto bn_abstract_tuple = dyn_cast(bn->abstract()); - MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - if (bn_abstract_tuple->elements().size() < kBnOutputNum) { - MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is " - << bn_abstract_tuple->elements().size(); - } - auto iter_variable_input0 = (*equiv).find(variable_input0_var_); - if (iter_variable_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; - } - auto variable_input0 = utils::cast(iter_variable_input0->second); - MS_EXCEPTION_IF_NULL(variable_input0); - auto iter_variable_input1 = (*equiv).find(variable_input1_var_); - if (iter_variable_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; - } - auto variable_input1 = utils::cast(iter_variable_input1->second); - MS_EXCEPTION_IF_NULL(variable_input1); - *abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(), - bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]}; -} - -AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate( - const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, - const std::vector &bn_training_reduce_outputs) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - // Set input - std::vector bn_training_update_inputs; - GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs); - auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs); - MS_EXCEPTION_IF_NULL(bn_training_update); - // Set abstract - auto iter_batch_norm = (*equiv).find(batch_norm_var_); - if (iter_batch_norm == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."; - } - AnfNodePtr bn = utils::cast(iter_batch_norm->second); - MS_EXCEPTION_IF_NULL(bn); - AbstractBasePtrList abstract_list; - GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list); - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_update->set_abstract(abstract_tuple); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, bn_training_update); - ValuePtr factor = GetFactor(equiv); - if (factor == nullptr) { - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrFactor, factor, bn_training_update); - AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update); - bn_training_update->set_scope(node->scope()); - return bn_training_update; -} - -const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(node); - AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node, equiv); - std::vector bn_training_reduce_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, - &bn_training_reduce_outputs); - AnfNodePtr bn_training_update = CreateBNTrainingUpdate(func_graph, node, equiv, bn_training_reduce_outputs); - if (bn_training_update == nullptr) { - MS_LOG(DEBUG) << "Create BNTrainingUpdate failed for bn node " << node->DebugString(); - return nullptr; - } - std::vector bn_training_update_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update, kBNTrainingUpdateOutputNum, - &bn_training_update_outputs); - if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBNTrainingUpdateOutputNum << ", but it is " - << bn_training_update_outputs.size(); - } - // Replace old bn outputs with new outputs - auto iter_batch_norm = (*equiv).find(batch_norm_var_); - if (iter_batch_norm == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."; - } - AnfNodePtr bn = utils::cast(iter_batch_norm->second); - std::vector bn_outputs; - GetBNOutput(func_graph, bn, &bn_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - for (const auto &output : bn_outputs) { - MS_EXCEPTION_IF_NULL(output); - if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { - continue; - } - auto tuple_getitem_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); - AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) { - (void)manager->Replace(output, bn_training_update_outputs[index]); - } - } - return bn_training_update_outputs[0]; -} - -const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - VarPtr index0 = std::make_shared(IsC); - VarPtr index1 = std::make_shared(IsC); - VarPtr index2 = std::make_shared(IsC); - VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); - VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); - VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); - VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); - VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); - VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); - VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); - VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); - VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); - VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); - VectorRef cast2 = VectorRef({prim::kPrimCast, mul0}); - VectorRef cast3 = VectorRef({prim::kPrimCast, mul1}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3}); - VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); - return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); -} - -const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - VarPtr index0 = std::make_shared(IsC); - VarPtr index1 = std::make_shared(IsC); - VarPtr index2 = std::make_shared(IsC); - VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); - VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); - VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); - VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); - VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); - VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); - VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); - VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); - VectorRef cast0 = VectorRef({prim::kPrimCast, sub0}); - VectorRef cast1 = VectorRef({prim::kPrimCast, sub1}); - VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_}); - VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); - VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); - return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h deleted file mode 100644 index f476e96062..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -class FusedBatchNormFusion : public PatternProcessPass { - public: - explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true) - : PatternProcessPass(name, multigraph), - data_input0_var_(std::make_shared()), - data_input1_var_(std::make_shared()), - data_input2_var_(std::make_shared()), - variable_input0_var_(std::make_shared()), - variable_input1_var_(std::make_shared()), - constant_input0_var_(std::make_shared()), - constant_input1_var_(std::make_shared()), - batch_norm_var_(std::make_shared(std::make_shared(prim::kPrimBatchNorm->name()))) {} - ~FusedBatchNormFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - protected: - AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const; - void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector &bn_training_reduce_outputs, - std::vector *bn_training_update_inputs) const; - void GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn, - std::vector *abstract_list) const; - AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, - const std::vector &bn_training_reduce_outputs) const; - ValuePtr GetFactor(const EquivPtr &equiv) const; - - VarPtr data_input0_var_; - VarPtr data_input1_var_; - VarPtr data_input2_var_; - VarPtr variable_input0_var_; - VarPtr variable_input1_var_; - VarPtr constant_input0_var_; - VarPtr constant_input1_var_; - VarPtr batch_norm_var_; -}; - -class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion { - public: - explicit FusedBatchNormMixPrecisionFusion0(bool multigraph = true) - : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {} - - ~FusedBatchNormMixPrecisionFusion0() override = default; - const BaseRef DefinePattern() const override; -}; - -class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion { - public: - explicit FusedBatchNormMixPrecisionFusion1(bool multigraph = true) - : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {} - - ~FusedBatchNormMixPrecisionFusion1() override = default; - const BaseRef DefinePattern() const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc deleted file mode 100644 index 42e37df3e4..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.cc +++ /dev/null @@ -1,266 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h" -#include -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, - std::vector *old_pattern_outputs) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_); - auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_); - - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &users = manager->node_users(); - if (users.find(real_div0) == users.end() || users[real_div0].size() < 2) { - return false; - } - AnfNodeIndexSet real_div0_outputs = users[real_div0]; - auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(), - [&real_div2, &equiv, this](const std::pair &node_index) { - return node_index.first != real_div2 && node_index.second == 1 && - MatchAnotherPattern(node_index.first, equiv); - }); - if (iter == real_div0_outputs.end()) { - return false; - } - - (*old_pattern_outputs).push_back(node); - (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_)); - (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_)); - (*old_pattern_outputs).push_back(iter->first); - - return true; -} - -AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph, - const std::vector &old_pattern_outputs, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - auto prim = std::make_shared(kLambNextMVOpName); - std::vector lamb_next_mv_rule_inputs = {NewValueNode(prim)}; - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input0_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input1_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input2_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input3_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input4_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input5_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input6_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul0_x_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul1_sub_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul2_x_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul3_sub1_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul4_x_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[add2_y_])); - auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs); - MS_EXCEPTION_IF_NULL(lamb_next_mv_rule); - - // Set abstract of new node - AbstractBasePtrList new_abstracts; - (void)std::transform(old_pattern_outputs.begin(), old_pattern_outputs.end(), std::back_inserter(new_abstracts), - [](const AnfNodePtr &out) { return out->abstract(); }); - auto abstract_tuple = std::make_shared(new_abstracts); - MS_EXCEPTION_IF_NULL(abstract_tuple); - lamb_next_mv_rule->set_abstract(abstract_tuple); - - // Create tuple_getitem node for outputs - std::vector lamb_next_mv_rule_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, lamb_next_mv_rule, kLambNextMVRuleOutputNum, &lamb_next_mv_rule_outputs); - - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(old_pattern_outputs[1], lamb_next_mv_rule_outputs[1]); - (void)manager->Replace(old_pattern_outputs[2], lamb_next_mv_rule_outputs[2]); - (void)manager->Replace(old_pattern_outputs[3], lamb_next_mv_rule_outputs[3]); - - return lamb_next_mv_rule_outputs[0]; -} - -bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { - return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) && - IsSameNode(equiv1, equiv2, add2_y_); -} - -const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - std::vector old_pattern_outputs; - if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) { - return nullptr; - } - - return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv); -} - -const BaseRef LambNextMVRuleCond1::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - - auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); - auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); - auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); - auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); - auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); - auto add0 = VectorRef({add0_var_, mul0, mul1}); - auto add1 = VectorRef({add1_var_, mul2, mul3}); - - auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); - auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); - - auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); - auto sqrt0 = VectorRef({prim_rsqrt, add2}); - auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); - - return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); -} - -BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - // Two patterns share: real_div0, real_div1, add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1}); - VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); - return real_div4; -} - -const BaseRef LambNextMVRuleCond2::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - - auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); - auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); - auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); - auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); - auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); - auto add0 = VectorRef({add0_var_, mul0, mul1}); - auto add1 = VectorRef({add1_var_, mul2, mul3}); - - auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); - auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); - - auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); - auto sqrt0 = VectorRef({prim_rsqrt, add2}); - auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); - - return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); -} - -BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - // Two patterns share: real_div0, real_div1, add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); - VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); - return real_div4; -} - -const BaseRef LambNextMVRuleCond3::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - - auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); - auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); - auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); - auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_}); - auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); - auto add0 = VectorRef({add0_var_, mul0, mul1}); - auto add1 = VectorRef({add1_var_, mul2, mul3}); - - auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); - auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); - - auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); - auto sqrt0 = VectorRef({prim_rsqrt, add2}); - auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); - - return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); -} - -BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - // Two patterns share: real_div0, real_div1, add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); - VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); - return real_div4; -} - -const BaseRef LambNextMVRuleCond4::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - - auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); - auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); - auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); - auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); - auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); - auto add0 = VectorRef({add0_var_, mul0, mul1}); - auto add1 = VectorRef({add1_var_, mul2, mul3}); - - auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); - auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); - - auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); - auto sqrt0 = VectorRef({prim_rsqrt, add2}); - auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0}); - - return VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); -} - -BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - // Two patterns share: real_div0, real_div1, add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); - VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); - return real_div4; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h deleted file mode 100644 index 0089c33f87..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h +++ /dev/null @@ -1,128 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ - -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class LambNextMVRule : public MultipleOutputPatternProcessPass { - public: - explicit LambNextMVRule(const std::string &name = "", bool multigraph = true) - : MultipleOutputPatternProcessPass(name, multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - input2_ = std::make_shared(); - input3_ = std::make_shared(); - input4_ = std::make_shared(); - input5_ = std::make_shared(); - input6_ = std::make_shared(); - mul0_x_ = std::make_shared(); - mul1_sub_ = std::make_shared(); - mul2_x_ = std::make_shared(); - mul3_sub1_ = std::make_shared(); - mul4_x_ = std::make_shared(); - add2_y_ = std::make_shared(); - real_div0_var_ = std::make_shared(std::make_shared(kRealDivOpName)); - real_div1_var_ = std::make_shared(std::make_shared(kRealDivOpName)); - real_div2_var_ = std::make_shared(std::make_shared(prim::kPrimMul->name())); - add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - } - ~LambNextMVRule() override = default; - const BaseRef DefinePattern() const override = 0; - BaseRef DefineAnotherPattern() const override = 0; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override; - - protected: - bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, - std::vector *old_pattern_outputs) const; - AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector &old_pattern_outputs, - const EquivPtr &equiv) const; - - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr input3_; - VarPtr input4_; - VarPtr input5_; - VarPtr input6_; - VarPtr mul0_x_; - VarPtr mul1_sub_; - VarPtr mul2_x_; - VarPtr mul3_sub1_; - VarPtr mul4_x_; - VarPtr add2_y_; - // nodes which two patterns share, and add2_y_ also. - VarPtr real_div0_var_; - VarPtr real_div1_var_; - // part of output nodes - VarPtr add0_var_; - VarPtr add1_var_; - // other node - VarPtr real_div2_var_; -}; - -class LambNextMVRuleCond1 : public LambNextMVRule { - public: - explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {} - - ~LambNextMVRuleCond1() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVRuleCond2 : public LambNextMVRule { - public: - explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {} - - ~LambNextMVRuleCond2() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVRuleCond3 : public LambNextMVRule { - public: - explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {} - - ~LambNextMVRuleCond3() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVRuleCond4 : public LambNextMVRule { - public: - explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {} - - ~LambNextMVRuleCond4() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc deleted file mode 100644 index 0e3cd28a66..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc +++ /dev/null @@ -1,278 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" - -namespace mindspore { -namespace opt { -AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, - const AnfNodePtr &new_node, const AnfNodePtr &add3, - const AnfNodePtr &add5, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(new_node); - MS_EXCEPTION_IF_NULL(add3); - MS_EXCEPTION_IF_NULL(add5); - MS_EXCEPTION_IF_NULL(equiv); - auto add0 = GetAnfNodeByVar(equiv, add0_var_); - MS_EXCEPTION_IF_NULL(add0); - auto add1 = GetAnfNodeByVar(equiv, add1_var_); - MS_EXCEPTION_IF_NULL(add1); - - // Set abstract of new node - AbstractBasePtrList new_node_list; - new_node_list.push_back(add3->abstract()); - new_node_list.push_back(add0->abstract()); - new_node_list.push_back(add1->abstract()); - new_node_list.push_back(add5->abstract()); - auto abstract_tuple = std::make_shared(new_node_list); - MS_EXCEPTION_IF_NULL(abstract_tuple); - new_node->set_abstract(abstract_tuple); - // Create tuple_getitem node for outputs - std::vector new_node_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextMVWithDecayOutputNum, &new_node_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(add3, new_node_outputs[0]); - (void)manager->Replace(add0, new_node_outputs[1]); - (void)manager->Replace(add1, new_node_outputs[2]); - return new_node_outputs[3]; -} - -AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, - const AnfNodePtr &add3, const AnfNodePtr &add5, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(add3); - MS_EXCEPTION_IF_NULL(equiv); - // Create new node with all the inputs - auto prim = std::make_shared(kLambNextMVWithDecayOpName); - std::vector new_node_inputs = {NewValueNode(prim)}; - for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { - auto input_node = utils::cast((*equiv)[input_vars_[i]]); - MS_EXCEPTION_IF_NULL(input_node); - new_node_inputs.push_back(input_node); - } - for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) { - auto constant_mul_input_node = utils::cast((*equiv)[constant_mul_input_vars_[i]]); - MS_EXCEPTION_IF_NULL(constant_mul_input_node); - new_node_inputs.push_back(constant_mul_input_node); - } - auto constant_add2_y_node = utils::cast((*equiv)[constant_add2_y_]); - MS_EXCEPTION_IF_NULL(constant_add2_y_node); - new_node_inputs.push_back(constant_add2_y_node); - auto new_node = func_graph->NewCNode(new_node_inputs); - return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv); -} - -bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { - return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) && - IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_); -} - -const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_); - MS_EXCEPTION_IF_NULL(mul4); - // Get add3 and match the add3 pattern - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(mul4) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input"; - } - AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4]; - auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(), - [&node, &equiv, this](const std::pair &node_index) { - return node_index.first != node && MatchAnotherPattern(node_index.first, equiv); - }); - if (iter != mul4_outputs.end()) { - return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv); - } - return nullptr; -} - -BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - MS_EXCEPTION_IF_NULL(prim_rsqrt); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - VarPtr Zs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Ys); - MS_EXCEPTION_IF_NULL(Zs); - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - VectorRef mul4 = VectorRef({mul4_var_, Zs}); - - VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); - VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); - VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); - VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); - return add3; -} - -const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(prim_deal_div); - VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]}); - VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); - VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); - VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); - VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); - VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); - VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); - VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); - return add5; -} - -BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - MS_EXCEPTION_IF_NULL(prim_rsqrt); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - VarPtr Zs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Ys); - MS_EXCEPTION_IF_NULL(Zs); - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - VectorRef mul4 = VectorRef({mul4_var_, Zs}); - - VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); - VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); - VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); - VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); - return add3; -} - -const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(prim_deal_div); - VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); - VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); - VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1}); - VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); - VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); - VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); - VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); - VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); - return add5; -} - -BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - MS_EXCEPTION_IF_NULL(prim_rsqrt); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - VarPtr Zs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Ys); - MS_EXCEPTION_IF_NULL(Zs); - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - VectorRef mul4 = VectorRef({mul4_var_, Zs}); - - VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); - VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); - VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); - VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); - return add3; -} - -const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(prim_deal_div); - VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); - VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); - VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); - VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); - VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); - VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); - VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]}); - VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); - return add5; -} - -BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - MS_EXCEPTION_IF_NULL(prim_rsqrt); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - VarPtr Zs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Ys); - MS_EXCEPTION_IF_NULL(Zs); - // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - VectorRef mul4 = VectorRef({mul4_var_, Zs}); - - VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); - VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); - VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0}); - VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); - return add3; -} - -const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(prim_deal_div); - VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); - VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); - VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); - VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); - VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); - VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); - VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); - VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); - return add5; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h deleted file mode 100644 index 5d61975197..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass { - public: - explicit LambNextMVWithDecayRule(const std::string &name = "", bool multigraph = true) - : MultipleOutputPatternProcessPass(name, multigraph) { - for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { - input_vars_.push_back(std::make_shared()); - } - for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) { - constant_mul_input_vars_.push_back(std::make_shared()); - } - constant_add2_y_ = std::make_shared(); - mul4_var_ = std::make_shared(std::make_shared(prim::kPrimMul->name())); - real_div0_var_ = std::make_shared(std::make_shared(kRealDivOpName)); - real_div1_var_ = std::make_shared(std::make_shared(kRealDivOpName)); - add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - } - - ~LambNextMVWithDecayRule() override = default; - const BaseRef DefinePattern() const override = 0; - BaseRef DefineAnotherPattern() const override = 0; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override; - - protected: - AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node, - const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const; - AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3, - const AnfNodePtr &add5, const EquivPtr &equiv) const; - std::vector input_vars_; - std::vector constant_mul_input_vars_; - // nodes which two patterns share - VarPtr constant_add2_y_; - VarPtr mul4_var_; - VarPtr real_div0_var_; - VarPtr real_div1_var_; - // part of output nodes - VarPtr add0_var_; - VarPtr add1_var_; -}; - -class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule { - public: - explicit LambNextMVWithDecayRuleCond1(bool multigraph = true) - : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {} - - ~LambNextMVWithDecayRuleCond1() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule { - public: - explicit LambNextMVWithDecayRuleCond2(bool multigraph = true) - : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {} - - ~LambNextMVWithDecayRuleCond2() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule { - public: - explicit LambNextMVWithDecayRuleCond3(bool multigraph = true) - : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {} - - ~LambNextMVWithDecayRuleCond3() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule { - public: - explicit LambNextMVWithDecayRuleCond4(bool multigraph = true) - : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond4", multigraph) {} - - ~LambNextMVWithDecayRuleCond4() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc deleted file mode 100644 index 26828f2137..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h" - -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" - -namespace mindspore { -namespace opt { -namespace { -std::tuple GetSharedNodes(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto add3 = node->cast(); - MS_EXCEPTION_IF_NULL(add3); - if (add3->inputs().size() < kAddInputNum) { - MS_LOG(EXCEPTION) << "The input size of Add3 is less than " << kAddInputNum; - } - auto real_div2_anf = add3->input(1); - MS_EXCEPTION_IF_NULL(real_div2_anf); - auto real_div2 = real_div2_anf->cast(); - MS_EXCEPTION_IF_NULL(real_div2); - if (real_div2->inputs().size() < kRealDivInputNum) { - MS_LOG(EXCEPTION) << "The input size of RealDiv2 is less than " << kRealDivInputNum; - } - auto sqrt0_anf = real_div2->input(2); - MS_EXCEPTION_IF_NULL(sqrt0_anf); - auto sqrt0 = sqrt0_anf->cast(); - MS_EXCEPTION_IF_NULL(sqrt0); - if (sqrt0->inputs().size() < kRsqrtInputNum) { - MS_LOG(EXCEPTION) << "The input size of Sqrt0 is less than " << kSqrtInputNum; - } - auto add2_anf = sqrt0->input(1); - MS_EXCEPTION_IF_NULL(add2_anf); - auto add2 = add2_anf->cast(); - if (add2->inputs().size() < kAddInputNum) { - MS_LOG(EXCEPTION) << "The input size of Add2 is less than " << kAddInputNum; - } - return std::make_tuple(add3->input(2), real_div2->input(1), add2->input(1), add2->input(2)); -} - -bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfNodePtr &real_div0, - const AnfNodePtr &real_div1, const AnfNodePtr &add2_y) { - if (node == nullptr || !node->isa()) { - return false; - } - auto add5 = node->cast(); - if (AnfAlgo::GetCNodeName(add5) != prim::kPrimTensorAdd->name() || add5->inputs().size() != kAddInputNum) { - return false; - } - auto real_div4_anf = add5->input(1); - if (real_div4_anf == nullptr || !real_div4_anf->isa()) { - return false; - } - auto real_div4 = real_div4_anf->cast(); - if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || real_div4->inputs().size() != kRealDivInputNum) { - return false; - } - auto add4_anf = real_div4->input(2); - if (add4_anf == nullptr || !add4_anf->isa()) { - return false; - } - auto add4 = add4_anf->cast(); - if (AnfAlgo::GetCNodeName(add4) != prim::kPrimTensorAdd->name() || add4->inputs().size() != kAddInputNum) { - return false; - } - auto sqrt1_anf = add4->input(1); - if (sqrt1_anf == nullptr || !sqrt1_anf->isa()) { - return false; - } - auto sqrt1 = sqrt1_anf->cast(); - if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || sqrt1->inputs().size() != kSqrtInputNum) { - return false; - } - return add5->input(2) == mul4 && real_div4->input(1) == real_div0 && sqrt1->input(1) == real_div1 && - *add4->input(2) == *add2_y; -} - -std::tuple GetAdd0Add1Nodes(const AnfNodePtr &real_div0_anf, const AnfNodePtr &real_div1_anf) { - MS_EXCEPTION_IF_NULL(real_div0_anf); - MS_EXCEPTION_IF_NULL(real_div1_anf); - auto real_div0 = real_div0_anf->cast(); - auto real_div1 = real_div1_anf->cast(); - MS_EXCEPTION_IF_NULL(real_div0); - MS_EXCEPTION_IF_NULL(real_div1); - if (real_div0->inputs().size() != kRealDivInputNum) { - MS_LOG(EXCEPTION) << "RealDiv0 has wrong input size"; - } - if (real_div1->inputs().size() != kRealDivInputNum) { - MS_LOG(EXCEPTION) << "RealDiv1 has wrong input size"; - } - return std::make_tuple(real_div0->input(1), real_div1->input(1)); -} -} // namespace - -std::vector LambNextMVWithDecayV1Rule::GetFusionNodeInputs(const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(equiv); - auto i0 = utils::cast((*equiv)[input0_]); - auto i1 = utils::cast((*equiv)[input1_]); - auto i2 = utils::cast((*equiv)[input2_]); - auto i3 = utils::cast((*equiv)[input3_]); - auto i4 = utils::cast((*equiv)[input4_]); - auto i5 = utils::cast((*equiv)[input5_]); - auto i6 = utils::cast((*equiv)[input6_]); - auto i7 = utils::cast((*equiv)[mul0_x_]); - auto i8 = utils::cast((*equiv)[mul1_sub_]); - auto i9 = utils::cast((*equiv)[mul2_x_]); - auto i10 = utils::cast((*equiv)[mul3_sub1_]); - auto i11 = utils::cast((*equiv)[mul4_x_]); - auto i12 = utils::cast((*equiv)[add2_y_]); - auto prim = std::make_shared(kLambNextMVWithDecayV1OpName); - return {NewValueNode(prim), i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12}; -} - -const BaseRef LambNextMVWithDecayV1Rule::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef add1({prim::kPrimTensorAdd, mul2, mul3}); - VectorRef real_div1({prim_real_div, add1, input2_}); - VectorRef add2({prim::kPrimTensorAdd, real_div1, add2_y_}); - VectorRef mul0({prim::kPrimMul, mul0_x_, input4_}); - VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_}); - VectorRef sqrt0({prim_rsqrt, add2}); - VectorRef add0({prim::kPrimTensorAdd, mul0, mul1}); - VectorRef real_div0({prim_real_div, add0, input5_}); - VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input6_}); - VectorRef add3({prim::kPrimTensorAdd, real_div2, mul4}); - return add3; -} - -const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (func_graph == nullptr || node == nullptr || equiv == nullptr) { - return nullptr; - } - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - AnfNodePtr mul4 = nullptr; - AnfNodePtr real_div0 = nullptr; - AnfNodePtr real_div1 = nullptr; - AnfNodePtr add2_y = nullptr; - std::tie(mul4, real_div0, real_div1, add2_y) = GetSharedNodes(node); - - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(mul4) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input"; - } - AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4]; - auto iter = std::find_if( - mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(), - [&node, &mul4, &real_div0, &real_div1, &add2_y](const std::pair &node_index) { - return node_index.first != node && MatchAdd5Pattern(node_index.first, mul4, real_div0, real_div1, add2_y); - }); - if (iter == mul4_output_node_index_set.end()) { - return nullptr; - } - - std::vector inputs = GetFusionNodeInputs(equiv); - auto fusion_node = func_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(node->scope()); - - AnfNodePtr add0 = nullptr; - AnfNodePtr add1 = nullptr; - AnfNodePtr add5 = iter->first; - std::tie(add0, add1) = GetAdd0Add1Nodes(real_div0, real_div1); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(add0, 0), - AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add5, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0), AnfAlgo::GetOutputInferShape(add0, 0), - AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add5, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); - - std::vector fusion_node_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, fusion_node, kLambNextMVWithDecayV1OutputNum, &fusion_node_outputs); - if (fusion_node_outputs.size() != kLambNextMVWithDecayV1OutputNum) { - MS_LOG(ERROR) << "create multiple outputs for fusion node fail!"; - return nullptr; - } - - (void)manager->Replace(add0, fusion_node_outputs[1]); - (void)manager->Replace(add1, fusion_node_outputs[2]); - (void)manager->Replace(add5, fusion_node_outputs[3]); - return fusion_node_outputs[0]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h deleted file mode 100644 index ff14a253dd..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ - -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -class LambNextMVWithDecayV1Rule : public PatternProcessPass { - public: - explicit LambNextMVWithDecayV1Rule(bool multigraph = true) - : PatternProcessPass("lamb_next_mv_with_decay_v1_rule", multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - input2_ = std::make_shared(); - input3_ = std::make_shared(); - input4_ = std::make_shared(); - input5_ = std::make_shared(); - input6_ = std::make_shared(); - mul0_x_ = std::make_shared(); - mul1_sub_ = std::make_shared(); - mul2_x_ = std::make_shared(); - mul3_sub1_ = std::make_shared(); - mul4_x_ = std::make_shared(); - add2_y_ = std::make_shared(); - } - - ~LambNextMVWithDecayV1Rule() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - std::vector GetFusionNodeInputs(const EquivPtr &equiv) const; - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr input3_; - VarPtr input4_; - VarPtr input5_; - VarPtr input6_; - VarPtr mul0_x_; - VarPtr mul1_sub_; - VarPtr mul2_x_; - VarPtr mul3_sub1_; - VarPtr mul4_x_; - VarPtr add2_y_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc deleted file mode 100644 index 5065c4c5ba..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" -#include -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - std::vector new_node_inputs; - auto prim = std::make_shared(kLambNextRightOpName); - MS_EXCEPTION_IF_NULL(prim); - new_node_inputs.push_back(NewValueNode(prim)); - auto input0 = utils::cast((*equiv)[input0_]); - MS_EXCEPTION_IF_NULL(input0); - new_node_inputs.push_back(input0); - auto input1 = utils::cast((*equiv)[input1_]); - MS_EXCEPTION_IF_NULL(input1); - new_node_inputs.push_back(input1); - auto mul2_x = utils::cast((*equiv)[mul2_x_]); - MS_EXCEPTION_IF_NULL(mul2_x); - new_node_inputs.push_back(mul2_x); - auto mul3_x = utils::cast((*equiv)[mul3_x_]); - MS_EXCEPTION_IF_NULL(mul3_x); - new_node_inputs.push_back(mul3_x); - auto true_div1_recip = utils::cast((*equiv)[true_div1_recip_]); - MS_EXCEPTION_IF_NULL(true_div1_recip); - new_node_inputs.push_back(true_div1_recip); - auto add2_y = utils::cast((*equiv)[add2_y_]); - MS_EXCEPTION_IF_NULL(add2_y); - new_node_inputs.push_back(add2_y); - auto new_node = func_graph->NewCNode(new_node_inputs); - return new_node; -} - -const BaseRef LambNextRightRule::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); - VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); - return VectorRef( - {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); -} - -const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - auto new_node = CreateLambNextRightNode(func_graph, equiv); - MS_EXCEPTION_IF_NULL(new_node); - // Set abstract of new node - auto iter_add1 = (*equiv).find(add1_var_); - if (iter_add1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; - } - auto add1 = utils::cast(iter_add1->second); - MS_EXCEPTION_IF_NULL(add1); - AbstractBasePtrList new_node_abstract_list; - new_node_abstract_list.push_back(add1->abstract()); - new_node_abstract_list.push_back(node->abstract()); - auto abstract_tuple = std::make_shared(new_node_abstract_list); - MS_EXCEPTION_IF_NULL(abstract_tuple); - new_node->set_abstract(abstract_tuple); - // Create tuple_getitem node for outputs - std::vector new_node_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextRightOutputNum, &new_node_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(add1, new_node_outputs[0]); - return new_node_outputs[1]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h deleted file mode 100644 index 3d15001da2..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ - -#include -#include "pre_activate/common/optimizer.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -class LambNextRightRule : public PatternProcessPass { - public: - explicit LambNextRightRule(bool multigraph = true) - : PatternProcessPass("lamb_next_right_rule", multigraph), - input0_(std::make_shared()), - input1_(std::make_shared()), - mul2_x_(std::make_shared()), - mul3_x_(std::make_shared()), - true_div1_recip_(std::make_shared()), - add2_y_(std::make_shared()), - add1_var_(std::make_shared(std::make_shared(prim::kPrimTensorAdd->name()))) {} - - ~LambNextRightRule() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - AnfNodePtr CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; - - VarPtr input0_; - VarPtr input1_; - VarPtr mul2_x_; - VarPtr mul3_x_; - VarPtr true_div1_recip_; - VarPtr add2_y_; - VarPtr add1_var_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc deleted file mode 100644 index b5b6d2bb08..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "common/utils.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -const BaseRef LambUpdateWithLRRuleFusion::DefinePattern() const { - auto real_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(real_div); - auto greater = std::make_shared(kGreaterOpName); - MS_EXCEPTION_IF_NULL(greater); - - VectorRef pattern_real_div0({real_div, input1_, input2_}); - VectorRef pattern_greater0({greater, input0_, constant_greater_max_}); - VectorRef pattern_greater1({greater, input1_, constant_greater_max_}); - VectorRef pattern_select0({prim::kPrimSelect, pattern_greater0, pattern_real_div0, constant_select_}); - VectorRef pattern_select1({prim::kPrimSelect, pattern_greater1, pattern_select0, constant_select_}); - VectorRef pattern_minimum0({prim::kPrimMinimum, pattern_select1, constant_minimum_}); - VectorRef pattern_maximum0({prim::kPrimMaximum, pattern_minimum0, constant_greater_max_}); - VectorRef pattern_mul0({prim::kPrimMul, pattern_maximum0, input3_}); - VectorRef pattern_mul1({prim::kPrimMul, pattern_mul0, input4_}); - VectorRef pattern({prim::kPrimSub, input5_, pattern_mul1}); - return pattern; -} - -const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - auto input0 = utils::cast((*equiv)[input0_]); - auto input1 = utils::cast((*equiv)[input1_]); - auto input2 = utils::cast((*equiv)[input2_]); - auto input3 = utils::cast((*equiv)[input3_]); - auto input4 = utils::cast((*equiv)[input4_]); - auto input5 = utils::cast((*equiv)[input5_]); - auto input6 = utils::cast((*equiv)[constant_greater_max_]); - auto input7 = utils::cast((*equiv)[constant_select_]); - auto input8 = utils::cast((*equiv)[constant_minimum_]); - - auto prim = std::make_shared(kLambUpdateWithLROpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = { - NewValueNode(prim), input0, input1, input2, input3, input4, input5, input6, input7, input8}; - auto lamb_update_with_lr = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(lamb_update_with_lr); - - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lamb_update_with_lr.get()); - lamb_update_with_lr->set_scope(node->scope()); - return lamb_update_with_lr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h deleted file mode 100644 index cb3939549f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class LambUpdateWithLRRuleFusion : public PatternProcessPass { - public: - explicit LambUpdateWithLRRuleFusion(bool multigraph = true) - : PatternProcessPass("lamb_update_with_lr_rule_fusion", multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - input2_ = std::make_shared(); - input3_ = std::make_shared(); - input4_ = std::make_shared(); - input5_ = std::make_shared(); - constant_greater_max_ = std::make_shared(); - constant_select_ = std::make_shared(); - constant_minimum_ = std::make_shared(); - } - ~LambUpdateWithLRRuleFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr input3_; - VarPtr input4_; - VarPtr input5_; - VarPtr constant_greater_max_; - VarPtr constant_select_; - VarPtr constant_minimum_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc deleted file mode 100644 index 43e1872163..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h" -#include -#include -#include -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -const BaseRef LambUpdateWithLrV2::DefinePattern() const { - const auto prim_greater = std::make_shared(kGreaterOpName); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - - VectorRef greater0({prim_greater, input_varptr_[0], input_varptr_[5]}); - VectorRef greater1({prim_greater, input_varptr_[1], input_varptr_[5]}); - VectorRef real_div0({prim_deal_div, input_varptr_[0], input_varptr_[1]}); - VectorRef select0({prim::kPrimSelect, greater1, real_div0, input_varptr_[6]}); - VectorRef select1({prim::kPrimSelect, greater0, select0, input_varptr_[6]}); - VectorRef mul0({prim::kPrimMul, select1, input_varptr_[2]}); - VectorRef mul1({prim::kPrimMul, mul0, input_varptr_[3]}); - - return VectorRef({prim::kPrimSub, input_varptr_[4], mul1}); -} - -const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - auto prim = std::make_shared(kLambUpdateWithLrV2OpName); - std::vector inputs = {NewValueNode(prim)}; - (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs), - [&equiv](const VarPtr &in) { return utils::cast((*equiv)[in]); }); - auto lamb_update_with_lr_v2 = func_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(lamb_update_with_lr_v2); - lamb_update_with_lr_v2->set_abstract(node->abstract()); - - return lamb_update_with_lr_v2; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h deleted file mode 100644 index ea614d3d2d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ - -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class LambUpdateWithLrV2 : public PatternProcessPass { - public: - explicit LambUpdateWithLrV2(bool multigraph = true) : PatternProcessPass("lamb_update_with_lr_v2", multigraph) { - for (size_t i = 0; i < kLambUpdateWithLrV2InputNum - 1; ++i) { - input_varptr_.push_back(std::make_shared()); - } - } - ~LambUpdateWithLrV2() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - std::vector input_varptr_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc deleted file mode 100644 index b16387d8f1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc +++ /dev/null @@ -1,162 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" -#include -#include -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -using common::SafeCStr; -namespace { -void GetOutputCastNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &node, std::vector *cast_nodes) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(node) == manager->node_users().end()) { - return; - } - for (const auto &node_index : manager->node_users()[node]) { - AnfNodePtr output = node_index.first; - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - if (AnfAlgo::GetCNodeName(output_cnode) != prim::kPrimTupleGetItem->name()) { - MS_LOG(EXCEPTION) << "The output of node " << node->DebugString() << " should be " - << prim::kPrimTupleGetItem->name(); - } - if (manager->node_users().find(output) == manager->node_users().end() || - manager->node_users()[output].size() != 1) { - continue; - } - AnfNodePtr transitive_output = manager->node_users()[output].begin()->first; - MS_EXCEPTION_IF_NULL(transitive_output); - auto transitive_output_cnode = transitive_output->cast(); - MS_EXCEPTION_IF_NULL(transitive_output_cnode); - if (AnfAlgo::GetCNodeName(transitive_output_cnode) == prim::kPrimCast->name()) { - cast_nodes->push_back(transitive_output_cnode); - } - } -} - -bool CheckKernelBuildInfo(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_info) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(kernel_info); - for (size_t i = 0; i < kernel_info->GetInputNum(); ++i) { - if (kernel_info->GetInputDeviceType(i) != kNumberTypeFloat16 || - kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(cnode, i)) { - return false; - } - } - for (size_t i = 0; i < kernel_info->GetOutputNum(); ++i) { - if (kernel_info->GetOutputDeviceType(i) != kNumberTypeFloat32 || - kernel_info->GetOutputFormat(i) != AnfAlgo::GetOutputFormat(cnode, i)) { - return false; - } - } - return true; -} - -bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - std::vector *cast_nodes) { - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::HasNodeAttr(kAttrShapeGamma, cnode)) { - MS_LOG(INFO) << "The node " << cnode->DebugString() << " has no " << kAttrShapeGamma << " attr"; - return false; - } - if (cnode->inputs().size() != kLayerNormBetaGammaBackpropInputNum) { - MS_LOG(INFO) << "The node " << cnode->DebugString() << " inputs num is not equal to " - << kLayerNormBetaGammaBackpropInputNum; - return false; - } - if (AnfAlgo::GetOutputTensorNum(cnode) != kLayerNormBetaGammaBackpropOutputNum) { - MS_LOG(INFO) << "The node " << cnode->DebugString() << " outputs num is not equal to " - << kLayerNormBetaGammaBackpropOutputNum; - return false; - } - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) { - if (AnfAlgo::GetInputDeviceDataType(cnode, i) != kNumberTypeFloat16) { - MS_LOG(INFO) << "The data type of node " << cnode->DebugString() << " input " << i << " is not float16"; - return false; - } - } - GetOutputCastNodes(func_graph, cnode, cast_nodes); - if (cast_nodes->size() != kLayerNormBetaGammaBackpropOutputNum) { - MS_LOG(INFO) << "The num of cast node in node " << cnode->DebugString() << " outputs is not equal to " - << kLayerNormBetaGammaBackpropOutputNum; - return false; - } - for (const auto &cast : *cast_nodes) { - if (AnfAlgo::GetInputDeviceDataType(cast, 0) != kNumberTypeFloat16 || - AnfAlgo::GetOutputDeviceDataType(cast, 0) != kNumberTypeFloat32) { - MS_LOG(INFO) << "The cast " << cast->DebugString() << " should be fp16->fp32"; - return false; - } - } - return true; -} -} // namespace - -const BaseRef LayerNormBetaGammaBackpropFusion::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - const auto prim = std::make_shared(kLayerNormBetaGammaBackpropOpName); - return VectorRef({prim, Xs}); -} - -const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - if (AnfAlgo::IsGraphKernel(node)) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::vector cast_nodes; - if (!CheckLayernormBetaGammaBackprop(func_graph, cnode, &cast_nodes)) { - return nullptr; - } - std::vector> kernel_info_list; - MS_EXCEPTION_IF_NULL(kernel_query_); - kernel_query_->Query(cnode, &kernel_info_list); - auto alternative_kernel_build_info = - std::find_if(kernel_info_list.begin(), kernel_info_list.end(), - [&cnode](const kernel::KernelBuildInfoPtr &candidate_kernel_build_info) { - return CheckKernelBuildInfo(cnode, candidate_kernel_build_info); - }); - if (alternative_kernel_build_info == kernel_info_list.end()) { - MS_LOG(INFO) << "Can not find alternative kernel build info for node " << node->DebugString(); - return nullptr; - } - AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_build_info, cnode.get()); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - // The cast_nodes size has been checked above. - MS_EXCEPTION_IF_NULL(cast_nodes[0]); - MS_EXCEPTION_IF_NULL(cast_nodes[1]); - if (cast_nodes[0]->inputs().size() != kCastInputNum) { - MS_LOG(EXCEPTION) << "The cast0 " << cast_nodes[0]->DebugString() << " input size should be " << kCastInputNum; - } - (void)manager->Replace(cast_nodes[0], cast_nodes[0]->input(1)); - if (cast_nodes[1]->inputs().size() != kCastInputNum) { - MS_LOG(EXCEPTION) << "The cast1 " << cast_nodes[1]->DebugString() << " input size should be " << kCastInputNum; - } - (void)manager->Replace(cast_nodes[1], cast_nodes[1]->input(1)); - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h deleted file mode 100644 index 2655c0f14d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class LayerNormBetaGammaBackpropFusion : public PatternProcessPass { - public: - explicit LayerNormBetaGammaBackpropFusion(bool multigraph = true) - : PatternProcessPass("layer_norm_beta_gamma_backprop_fusion", multigraph), - kernel_query_(std::make_shared()) {} - - ~LayerNormBetaGammaBackpropFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - KernelQueryPtr kernel_query_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.cc deleted file mode 100644 index e81c804b71..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" -#include -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kMatMulInputIndex = 1; -constexpr size_t kBiasInputIndex = 2; -} // namespace - -const BaseRef MatmulBiasaddFusion::DefinePattern() const { - VarPtr X0 = std::make_shared(); - VarPtr X1 = std::make_shared(); - VarPtr X2 = std::make_shared(); - const auto prim_bias_add = std::make_shared(kBiasAddOpName); - return VectorRef({prim_bias_add, VectorRef({prim::kPrimMatMul, X0, X1}), X2}); -} - -const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - CheckCNodeInputSize(cnode, kBiasAddInputNum); - AnfNodePtr matmul = cnode->input(kMatMulInputIndex); - MS_EXCEPTION_IF_NULL(matmul); - auto matmul_cnode = matmul->cast(); - MS_EXCEPTION_IF_NULL(matmul_cnode); - matmul_cnode->add_input(cnode->input(kBiasInputIndex)); - AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul); - return matmul; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h deleted file mode 100644 index 56675243de..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class MatmulBiasaddFusion : public PatternProcessPass { - public: - explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {} - - ~MatmulBiasaddFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc deleted file mode 100644 index e7a73a9c7f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h" -#include -#include -#include -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kAccumIndex = 1; -bool CheckValueNodeInputOfMul(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - std::vector mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0); - return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1); -} -} // namespace - -const BaseRef MomentumLossscaleFusion::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VarPtr X0 = std::make_shared(); - VarPtr X1 = std::make_shared(); - VarPtr X2 = std::make_shared(); - VarPtr X4 = std::make_shared(); - return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4}); -} - -const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - CheckCNodeInputSize(cnode, kApplyMomentumInputNum); - AnfNodePtr mul = cnode->input(4); - MS_EXCEPTION_IF_NULL(mul); - auto mul_cnode = mul->cast(); - MS_EXCEPTION_IF_NULL(mul_cnode); - CheckCNodeInputSize(mul_cnode, kMulInputNum); - size_t value_node_index = 0; - for (size_t i = 1; i < kMulInputNum; ++i) { - if (CheckValueNodeInputOfMul(mul_cnode->input(i))) { - value_node_index = i; - break; - } - } - if (value_node_index == 0) { - MS_LOG(DEBUG) << "The Mul " << mul->DebugString() << " to be fused must has a scalar constant input"; - return nullptr; - } - auto new_prim = std::make_shared(kFusedMulApplyMomentumOpName); - std::vector new_node_inputs{NewValueNode(new_prim), - cnode->input(1), - cnode->input(2), - cnode->input(3), - mul_cnode->input(kMulInputNum - value_node_index), - cnode->input(5), - mul_cnode->input(value_node_index)}; - auto new_node = func_graph->NewCNode(new_node_inputs); - MS_EXCEPTION_IF_NULL(new_node); - AnfAlgo::CopyNodeAttrs(node, new_node); - auto input_names_value = AnfAlgo::GetNodeAttr>(new_node, kAttrInputNames); - input_names_value[3] = "x1"; - input_names_value.emplace_back("x2"); - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node); - new_node->set_abstract(node->abstract()); - new_node->set_scope(node->scope()); - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h deleted file mode 100644 index c092e0ca22..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class MomentumLossscaleFusion : public PatternProcessPass { - public: - explicit MomentumLossscaleFusion(bool multigraph = true) - : PatternProcessPass("momentum_lossscale_fusion", multigraph) {} - - ~MomentumLossscaleFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc deleted file mode 100644 index 2536255fc1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" -#include -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(add); - - for (size_t index = 1; index < add->size(); ++index) { - auto input = add->input(index); - MS_EXCEPTION_IF_NULL(input); - if (input->isa()) { - auto cnode = input->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) { - if (!opt::IsUsedByOthers(graph, cnode)) { - auto full_name = cnode->fullname_with_scope(); - // exclude lamb and adam, and only work in bert - if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") || - std::string::npos == full_name.find("bert")) { - MS_LOG(INFO) << "Mul is in adam or lamb or not a bert network, quit fusion"; - return false; - } - - *mul = cnode; - *mul_index = index; - return true; - } - } - } - } - return false; -} -} // namespace -const BaseRef MulAddFusion::DefinePattern() const { - VarPtr x = std::make_shared(); - VarPtr y = std::make_shared(); - VectorRef pattern({prim::kPrimTensorAdd, x, y}); - return pattern; -} - -const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - if (graph == nullptr || node == nullptr) { - return nullptr; - } - auto add = node->cast(); - if (add == nullptr || add->inputs().size() != kAddInputNum) { - return nullptr; - } - CNodePtr mul = nullptr; - size_t mul_index = 0; - if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) { - MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs"; - return nullptr; - } - - auto prim = std::make_shared(kFusedMulAddOpName); - std::vector inputs = {NewValueNode(prim)}; - for (size_t index = 1; index < mul->size(); ++index) { - inputs.push_back(mul->input(index)); - } - auto another_input_node = add->input(add->size() - mul_index); - if (another_input_node->isa() && - AnfAlgo::GetCNodeName(another_input_node) == prim::kPrimTupleGetItem->name()) { - MS_LOG(INFO) << "Add's another input node has multiple outputs, do not fuse"; - return nullptr; - } - inputs.push_back(another_input_node); - auto fusion_node = graph->NewCNode(inputs); - fusion_node->set_scope(add->scope()); - fusion_node->set_abstract(add->abstract()); - return fusion_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.h deleted file mode 100644 index 4b4db2b312..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class MulAddFusion : public PatternProcessPass { - public: - explicit MulAddFusion(bool multigraph = true) : PatternProcessPass("mul_add_fusion", multigraph) {} - ~MulAddFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc deleted file mode 100644 index a5e4675c8f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" -#include -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const CNodePtr &addn, - const size_t &lossscale_input_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mul); - MS_EXCEPTION_IF_NULL(addn); - auto prim = std::make_shared(kFusedMulAddNOpName); - std::vector inputs = {NewValueNode(prim)}; - inputs.push_back(mul->input(kMulInputNum - lossscale_input_index)); - inputs.push_back(addn->input(2)); - // scalar input should be 3rd input - inputs.push_back(mul->input(lossscale_input_index)); - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(addn->scope()); - fusion_node->set_abstract(addn->abstract()); - return fusion_node; -} -} // namespace - -const BaseRef MulAddNFusion::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Y = std::make_shared(); - VarPtr Z = std::make_shared(); - - VectorRef mul({prim::kPrimMul, X, Z}); - VectorRef addn({prim::kPrimAddN, mul, Y}); - return addn; -} - -const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (graph == nullptr || node == nullptr || equiv == nullptr) { - return nullptr; - } - - auto addn = node->cast(); - if (addn == nullptr || addn->inputs().size() != kAddNInputNum) { - return nullptr; - } - auto mul_anf = addn->input(1); - if (mul_anf == nullptr) { - return nullptr; - } - auto mul = mul_anf->cast(); - if (mul == nullptr || mul->inputs().size() != kMulInputNum) { - return nullptr; - } - if (IsUsedByOthers(graph, mul)) { - MS_LOG(DEBUG) << "Mul is used by more then two nodes, cannot fuse"; - return nullptr; - } - - size_t lossscale_input_index = 1; - for (size_t index = 1; index < mul->inputs().size(); ++index) { - auto input_node = mul->input(index); - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa()) { - lossscale_input_index = index; - break; - } - } - auto constant_shape = AnfAlgo::GetOutputInferShape(mul->input(lossscale_input_index), 0); - if (!(constant_shape.size() == 0 || (constant_shape.size() == 1 && constant_shape[0] == 1))) { - MS_LOG(DEBUG) << "The const input of Mul node must be scalar or shape=(1,), but shape size is " - << constant_shape.size() << " and shape[0] is " << constant_shape[0]; - return nullptr; - } - - return CreateFusionNode(graph, mul, addn, lossscale_input_index); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.h deleted file mode 100644 index d03309bf73..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class MulAddNFusion : public PatternProcessPass { - public: - explicit MulAddNFusion(bool multigraph = true) : PatternProcessPass("mul_addn_fusion", multigraph) {} - ~MulAddNFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc deleted file mode 100644 index a3c87dad5d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -namespace { -const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag, - std::vector *trans_road) { - if (node == nullptr) { - MS_LOG(ERROR) << "nullptr"; - return nullptr; - } - if (node->isa()) { - auto cnode = node->cast(); - auto op_name = AnfAlgo::GetCNodeName(cnode); - auto manager = func_graph->manager(); - if (manager == nullptr) { - return nullptr; - } - if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() || - op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) { - auto users = manager->node_users()[node]; - if (users.size() > 1 && !first_flag) { - return nullptr; - } - trans_road->push_back(cnode); - first_flag = false; - auto next_node = AnfAlgo::GetInputNode(cnode, 0); - if (next_node->isa() || next_node->isa()) { - return next_node; - } - return ParamTransRoad(func_graph, next_node, first_flag, trans_road); - } - } else if (node->isa() || node->isa()) { - return node; - } - return nullptr; -} - -kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type, - TypeId output_type) { - MS_EXCEPTION_IF_NULL(cast); - auto kernel_info = cast->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto cast_build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(cast_build_info); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetOutputsFormat({format}); - builder.SetInputsFormat({format}); - builder.SetInputsDeviceType({input_type}); - builder.SetOutputsDeviceType({output_type}); - builder.SetKernelType(cast_build_info->kernel_type()); - builder.SetFusionType(cast_build_info->fusion_type()); - builder.SetProcessor(cast_build_info->processor()); - return builder.Build(); -} -} // namespace -bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Func graph is nullptr"; - return false; - } - auto manager = func_graph->manager(); - if (manager == nullptr) { - return false; - } - std::vector node_list = TopoSort(func_graph->get_return()); - bool changed = false; - for (auto node : node_list) { - if (node == nullptr || !node->isa()) { - continue; - } - auto cnode = node->cast(); - auto node_name = AnfAlgo::GetCNodeName(cnode); - if (node_name == prim::kPrimCast->name() || node_name == prim::kPrimTranspose->name() || - node_name == prim::kPrimReshape->name() || node_name == kTransDataOpName) { - MS_LOG(DEBUG) << "Skip trans op"; - continue; - } - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { - std::vector trans_road; - bool first_flag = true; - auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road); - if (final_node != nullptr && trans_road.size() == 3 && AnfAlgo::GetCNodeName(trans_road[0]) == kTransDataOpName && - AnfAlgo::GetCNodeName(trans_road[1]) == prim::kPrimCast->name() && - AnfAlgo::GetCNodeName(trans_road[2]) == kTransDataOpName) { - auto cur_transop = trans_road[0]; - auto format = AnfAlgo::GetOutputFormat(cur_transop, 0); - auto dtype = AnfAlgo::GetOutputDeviceDataType(cur_transop, 0); - auto param_format = AnfAlgo::GetOutputFormat(final_node, 0); - auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0); - - auto cast = trans_road[1]; - if (param_format == format && param_dtype != dtype) { - AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get()); - manager->Replace(trans_road[2], final_node); - manager->Replace(cur_transop, cast); - } - changed = true; - } - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h deleted file mode 100644 index 823ec083b1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pass.h" - -namespace mindspore { -namespace opt { -class ParameterTransOpFusion : public Pass { - public: - explicit ParameterTransOpFusion(size_t groups = 1) : Pass("Parameter_and_transop_fusion"), groups_(groups) {} - ~ParameterTransOpFusion() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - size_t groups_ = 1; -}; -} // namespace opt -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.cc deleted file mode 100644 index 857670a384..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -void DoRefresh(const CNodePtr &cnode) { - if (cnode == nullptr) { - MS_LOG(EXCEPTION) << "node is nullptr"; - } - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { - auto input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index); - if (input_kernel_node->isa()) { - std::shared_ptr builder = - std::make_shared(); - auto cnode_input_format = AnfAlgo::GetInputFormat(cnode, input_index); - auto kernel_node_format = AnfAlgo::GetOutputFormat(input_kernel_node, 0); - auto dtype = AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0); - if (kernel_node_format != cnode_input_format) { - builder->SetOutputsFormat({cnode_input_format}); - builder->SetOutputsDeviceType({dtype}); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); - } - } - } -} - -bool RefreshParameterFormat::Run(const FuncGraphPtr &func_graph) { - if (func_graph == nullptr) { - MS_LOG(ERROR) << "func_graph is nullptr."; - return false; - } - std::vector node_list = TopoSort(func_graph->get_return()); - for (auto node : node_list) { - if (node == nullptr || !node->isa()) { - continue; - } - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } - auto node_name = AnfAlgo::GetCNodeName(cnode); - if (node_name == kBNTrainingUpdateOpName) { - DoRefresh(cnode); - } - } - return true; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.h deleted file mode 100644 index 0ba688b134..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ - -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pass.h" - -namespace mindspore { -namespace opt { -class RefreshParameterFormat : public Pass { - public: - explicit RefreshParameterFormat(size_t groups = 1) : Pass("refresh_parameter_format"), groups_(groups) {} - ~RefreshParameterFormat() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - size_t groups_ = 1; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc deleted file mode 100644 index fa2815ff62..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -const BaseRef RemoveReshapePair::DefinePattern() const { - VarPtr X = std::make_shared(); - MS_EXCEPTION_IF_NULL(X); - return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})}); -} - -const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_op_1); - // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly - if (IsUsedByOthers(func_graph, reshape_op_1)) { - return nullptr; - } - auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_op_2); - if (IsUsedByOthers(func_graph, reshape_op_2)) { - return nullptr; - } - auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0); - auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0); - if (input_shape == output_shape) { - auto input_node = reshape_op_2->input(1); - return input_node; - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h deleted file mode 100644 index ddb25df70c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ - -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class RemoveReshapePair : public PatternProcessPass { - public: - explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) {} - ~RemoveReshapePair() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc deleted file mode 100644 index 9b13002798..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -bool CheckShapeDimInfo(const std::vector &shape) { - if (shape.empty()) { - return false; - } - if (shape.size() == 1 && shape[0] % kCubeSize != 0) { - return false; - } - return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0)); -} -} // namespace - -const BaseRef ReshapeTransposeFusion::DefinePattern() const { - const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); - VectorRef reshape({prim_reshape, input_varptr_}); - - return VectorRef({prim::kPrimTranspose, reshape}); -} - -const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(transpose_cnode); - auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_cnode); - std::vector reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); - std::vector transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0); - if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) { - return nullptr; - } - auto prim = std::make_shared(kConfusionTransposeDOpName); - std::vector inputs = {NewValueNode(prim), utils::cast((*equiv)[input_varptr_])}; - auto new_node = func_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_abstract(node->abstract()); - - AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); - AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); - AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node); - auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); - AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); - - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h deleted file mode 100644 index 5abf3e0d53..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ReshapeTransposeFusion : public PatternProcessPass { - public: - explicit ReshapeTransposeFusion(bool multigraph = true) : PatternProcessPass("reshape_transpose_fusion", multigraph) { - input_varptr_ = std::make_shared(); - } - ~ReshapeTransposeFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_varptr_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc deleted file mode 100644 index f95406e5e1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -const BaseRef SoftmaxGradExtFusion::DefinePattern() const { - VectorRef mul({prim::kPrimMul, input1_, input0_}); - VectorRef sum({sum_var_, mul}); - VectorRef sub({prim::kPrimSub, input0_, sum}); - VectorRef mul1({prim::kPrimMul, input2_, input1_}); - VectorRef mul_grad({prim::kPrimMul, mul1, sub}); - return mul_grad; -} - -const BaseRef SoftmaxGradExtFusionV2::DefinePattern() const { - VectorRef mul({prim::kPrimMul, input1_, input0_}); - VectorRef sum({sum_var_, mul}); - VectorRef sub({prim::kPrimSub, input0_, sum}); - VectorRef mul1({prim::kPrimMul, input1_, sub}); - VectorRef mul_grad({prim::kPrimMul, input2_, mul1}); - return mul_grad; -} - -const BaseRef SoftmaxGradExtFusionV3::DefinePattern() const { - VectorRef mul({prim::kPrimMul, input1_, input0_}); - VectorRef sum({sum_var_, mul}); - VectorRef sub({prim::kPrimSub, input0_, sum}); - VectorRef mul1({prim::kPrimMul, input1_, sub}); - VectorRef mul_grad({prim::kPrimMul, mul1, input2_}); - return mul_grad; -} - -const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(node); - auto input0 = GetAnfNodeByVar(equiv, input0_); - auto input1 = GetAnfNodeByVar(equiv, input1_); - auto input2 = GetAnfNodeByVar(equiv, input2_); - auto sum = GetAnfNodeByVar(equiv, sum_var_); - if (!GetBoolAttr(sum, kAttrKeepDims)) { - MS_LOG(INFO) << "sum's attr keep_dims should be true if do fusion"; - return nullptr; - } - - auto prim = std::make_shared(kSoftmaxGradExtOpName); - auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2}); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(node->scope()); - fusion_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, "keepdims", sum, fusion_node); - AnfAlgo::CopyNodeAttr(kAttrAxis, sum, fusion_node); - return fusion_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h deleted file mode 100644 index 59032e6973..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ - -#include -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class SoftmaxGradExtFusion : public PatternProcessPass { - public: - explicit SoftmaxGradExtFusion(const std::string &name = "softmax_grad_ext_fusion", bool multigraph = true) - : PatternProcessPass(name, multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - input2_ = std::make_shared(); - sum_var_ = std::make_shared(std::make_shared(prim::kPrimReduceSum->name())); - } - ~SoftmaxGradExtFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - protected: - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr sum_var_; -}; - -class SoftmaxGradExtFusionV2 : public SoftmaxGradExtFusion { - public: - explicit SoftmaxGradExtFusionV2(bool multigraph = true) - : SoftmaxGradExtFusion("softmax_grad_ext_fusion_v2", multigraph) {} - ~SoftmaxGradExtFusionV2() override = default; - const BaseRef DefinePattern() const override; -}; - -class SoftmaxGradExtFusionV3 : public SoftmaxGradExtFusion { - public: - explicit SoftmaxGradExtFusionV3(bool multigraph = true) - : SoftmaxGradExtFusion("softmax_grad_ext_fusion_v3", multigraph) {} - ~SoftmaxGradExtFusionV3() override = default; - const BaseRef DefinePattern() const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.cc deleted file mode 100644 index 6261b63882..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.cc +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h" - -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "pre_activate/common/helper.h" -#include "device/kernel_info.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(square); - MS_EXCEPTION_IF_NULL(sum); - if (square->inputs().size() != kSquareNodeInputNum) { - MS_LOG(EXCEPTION) << "Square node has wrong input size"; - } - auto prim = std::make_shared(kSquareSumV1OpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector square_sumv1_inputs = {NewValueNode(prim), square->input(1)}; - auto square_sumv1 = graph->NewCNode(square_sumv1_inputs); - MS_EXCEPTION_IF_NULL(square_sumv1); - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - square_sumv1->set_kernel_info(kernel_info); - auto types = {AnfAlgo::GetOutputInferDataType(sum, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sumv1.get()); - square_sumv1->set_scope(sum->scope()); - AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv1); - auto names = MakeValue>({prim::kPrimSquare->name(), prim::kPrimReduceSum->name()}); - AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv1); - return square_sumv1; -} - -CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(square); - MS_EXCEPTION_IF_NULL(sum); - if (square->inputs().size() != kSquareNodeInputNum) { - MS_LOG(EXCEPTION) << "Square node has wrong input size"; - } - auto prim = std::make_shared(kSquareSumV2OpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector square_sumv2_inputs = {NewValueNode(prim), square->input(1)}; - auto square_sumv2 = graph->NewCNode(square_sumv2_inputs); - MS_EXCEPTION_IF_NULL(square_sumv2); - auto types = {AnfAlgo::GetOutputInferDataType(sum, 0), AnfAlgo::GetOutputInferDataType(square, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0), AnfAlgo::GetOutputInferShape(square, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sumv2.get()); - square_sumv2->set_scope(sum->scope()); - AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv2); - auto names = MakeValue>({prim::kPrimSquare->name(), prim::kPrimReduceSum->name()}); - AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv2); - return square_sumv2; -} - -std::tuple GetPrevNodes(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto sum = node->cast(); - MS_EXCEPTION_IF_NULL(sum); - if (sum->inputs().size() != kSumNodeInputNum) { - MS_LOG(EXCEPTION) << "ReduceSumD node has wrong input size"; - } - auto square_anf = sum->input(1); - MS_EXCEPTION_IF_NULL(square_anf); - auto square = square_anf->cast(); - MS_EXCEPTION_IF_NULL(square); - - return std::make_tuple(sum, square_anf, square); -} -} // namespace - -const BaseRef SquareSumFusion::DefinePattern() const { - VarPtr X = std::make_shared(); - MS_EXCEPTION_IF_NULL(X); - return VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimSquare, X})}); -} - -const AnfNodePtr SquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - CNodePtr sum = nullptr; - AnfNodePtr square_anf = nullptr; - CNodePtr square = nullptr; - std::tie(sum, square_anf, square) = GetPrevNodes(node); - - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(square_anf) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "Square node has no output in NodeUsersMap"; - } - AnfNodePtr ret_node = nullptr; - if (manager->node_users()[square_anf].size() == 1) { - ret_node = GenerateSquareSumV1(graph, square, sum); - } else if (manager->node_users()[square_anf].size() == 2) { - auto square_sumv2 = GenerateSquareSumV2(graph, square, sum); - - std::vector square_sumv2_outputs; - CreateMultipleOutputsOfAnfNode(graph, square_sumv2, kSquareSumv2OutputNum, &square_sumv2_outputs); - if (square_sumv2_outputs.size() != kSquareSumv2OutputNum) { - MS_LOG(EXCEPTION) << "make SquareSumV2 outputs fail"; - } - (void)manager->Replace(square, square_sumv2_outputs[1]); - ret_node = square_sumv2_outputs[0]; - } - return ret_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.h deleted file mode 100644 index 5a694a5585..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class SquareSumFusion : public PatternProcessPass { - public: - explicit SquareSumFusion(bool multigraph = true) : PatternProcessPass("square_sum_fusion", multigraph) {} - ~SquareSumFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc deleted file mode 100644 index 250f86d9b1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -bool CheckShapeDimInfo(const std::vector &shape) { - if (shape.empty()) { - return false; - } - if (shape.size() == 1 && shape[0] % kCubeSize != 0) { - return false; - } - return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0)); -} -} // namespace - -const BaseRef TransposeReshapeFusion::DefinePattern() const { - const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); - VectorRef transpose({prim::kPrimTranspose, input_varptr_}); - - return VectorRef({prim_reshape, transpose}); -} - -const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_cnode); - auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(transpose_cnode); - std::vector reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); - std::vector transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); - if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { - return nullptr; - } - auto prim = std::make_shared(kConfusionTransposeDOpName); - std::vector inputs = {NewValueNode(prim), utils::cast((*equiv)[input_varptr_])}; - auto new_node = func_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - - new_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); - AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); - AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node); - auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); - AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); - - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h deleted file mode 100644 index 8b979f869d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class TransposeReshapeFusion : public PatternProcessPass { - public: - explicit TransposeReshapeFusion(bool multigraph = true) : PatternProcessPass("transpose_reshape_fusion", multigraph) { - input_varptr_ = std::make_shared(); - } - ~TransposeReshapeFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_varptr_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc deleted file mode 100644 index e45fc2637f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -const BaseRef TransposeTransDataFusion::DefinePattern() const { - const auto prim_transdata = std::make_shared(prim::KPrimTransData->name()); - VectorRef transpose({prim::kPrimTranspose, input_varptr_}); - - return VectorRef({prim_transdata, transpose}); -} - -const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputNum); - MS_EXCEPTION_IF_NULL(transdata_cnode); - auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kBackendTransDataInputNum); - MS_EXCEPTION_IF_NULL(transpose_cnode); - auto transpose_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transpose_cnode); - auto transdata_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transdata_cnode); - MS_EXCEPTION_IF_NULL(transpose_kernel_build_info); - MS_EXCEPTION_IF_NULL(transdata_kernel_build_info); - - auto new_transdata_builder = std::make_shared(); - auto transpose_input_formats = transpose_kernel_build_info->GetAllInputFormats(); - new_transdata_builder->SetInputsFormat(transpose_input_formats); - new_transdata_builder->SetOutputsFormat(transdata_kernel_build_info->GetAllOutputFormats()); - new_transdata_builder->SetInputsDeviceType(transdata_kernel_build_info->GetAllInputDeviceTypes()); - new_transdata_builder->SetOutputsDeviceType(transdata_kernel_build_info->GetAllOutputDeviceTypes()); - new_transdata_builder->SetKernelType(transdata_kernel_build_info->kernel_type()); - new_transdata_builder->SetFusionType(transdata_kernel_build_info->fusion_type()); - new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); - - auto new_fusion_transdata = std::make_shared(kTransDataOpName); - if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) { - std::vector inputs = {NewValueNode(new_fusion_transdata), - utils::cast((*equiv)[input_varptr_])}; - auto new_node = func_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttrs(transdata_cnode, new_node); - AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(transpose_input_formats[0]), new_node); - AnfAlgo::SetSelectKernelBuildInfo(new_transdata_builder->Build(), new_node.get()); - MS_LOG(INFO) << "transpose transdata fusion node:" << node->fullname_with_scope() << " success"; - return new_node; - } else { - MS_LOG(INFO) << "transpose transdata fusion node:" << node->fullname_with_scope() << " failed"; - return node; - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h deleted file mode 100644 index 833588cf45..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class TransposeTransDataFusion : public PatternProcessPass { - public: - explicit TransposeTransDataFusion(bool multigraph = true) - : PatternProcessPass("transpose_transdata_fusion", multigraph) { - input_varptr_ = std::make_shared(); - supported_checker_ = std::make_shared(); - } - ~TransposeTransDataFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_varptr_; - - private: - SupportedCheckerPtr supported_checker_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc b/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc deleted file mode 100644 index b930ac69c9..0000000000 --- a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/common/common_backend_optimization.h" -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/pass/convert_const_input_to_attr.h" -#include "pre_activate/pass/convert_tuple_output_to_maketuple.h" -#include "pre_activate/pass/convert_const_input_to_tensor_input.h" -#include "pre_activate/pass/convert_tuple_input_to_dynamic_input.h" -#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" -#include "utils/context/ms_context.h" -#include "debug/anf_ir_dump.h" - -namespace mindspore { -namespace opt { -void BackendCommonOptimization(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = - save_graphs_path + "/hwopt_common_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } - auto optimizer = std::make_shared(); - auto common_pm = std::make_shared("common_pm"); - common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(common_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - if (save_graphs) { - std::string file_path = - save_graphs_path + "/hwopt_common_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.h b/mindspore/ccsrc/pre_activate/common/common_backend_optimization.h deleted file mode 100644 index 6ce92da0dc..0000000000 --- a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.h +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ -#include -#include "session/kernel_graph.h" -namespace mindspore { -namespace opt { -void BackendCommonOptimization(const std::shared_ptr &kernel_graph); -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_ diff --git a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc deleted file mode 100644 index 2b45fc6579..0000000000 --- a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/common/fusion_id_allocator.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -FusionIdAllocator::FusionIdAllocator() { fusion_id = 0; } - -FusionIdAllocator::~FusionIdAllocator() {} - -void FusionIdAllocator::Init() { fusion_id = 0; } - -int32_t FusionIdAllocator::AllocateFusionId() { - fusion_id++; - return fusion_id; -} - -bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - auto cnode = node->cast(); - return AnfAlgo::HasNodeAttr(kAttrFusionId, cnode); -} - -int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) { - if (HasFusionIdAttr(node)) { - return AnfAlgo::GetNodeAttr(node, kAttrFusionId); - } - return -1; -} - -void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int32_t id) { - ValuePtr fusion_id_v = MakeValue(id); - AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h deleted file mode 100644 index 91e83600f2..0000000000 --- a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ - -#include -#include "ir/base.h" - -namespace mindspore { -namespace opt { -class FusionIdAllocator { - public: - FusionIdAllocator(); - virtual ~FusionIdAllocator(); - FusionIdAllocator(const FusionIdAllocator &in) = delete; - FusionIdAllocator &operator=(const FusionIdAllocator &in) = delete; - - void Init(); - int32_t AllocateFusionId(); - bool HasFusionIdAttr(const AnfNodePtr &node); - int32_t GetFusionId(const AnfNodePtr &node); - void SetFusionId(const AnfNodePtr &node, int32_t id); - - private: - int32_t fusion_id; -}; -using FusionIdAllocatorPtr = std::shared_ptr; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc deleted file mode 100644 index e1db0ed6ed..0000000000 --- a/mindspore/ccsrc/pre_activate/common/helper.cc +++ /dev/null @@ -1,785 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/common/helper.h" -#include -#include -#include -#include -#include -#include -#include -#include "utils/utils.h" -#include "utils/base_ref.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "common/utils.h" -#include "device/kernel_info.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -constexpr size_t kType32Len = 4; -std::vector Convert2Int(const std::vector &v) { - std::vector result; - (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt); - return result; -} - -bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node1); - MS_EXCEPTION_IF_NULL(node2); - std::vector node_list = TopoSort(graph->get_return()); - std::map> control_depend_map; - for (auto &nd : node_list) { - MS_EXCEPTION_IF_NULL(nd); - if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { - auto control_depend = nd->cast(); - auto prior_node = control_depend->input(kControlDependPriorIndex); - auto behind_node = control_depend->input(kControlDependBehindIndex); - auto it = control_depend_map.find(behind_node); - if (it == control_depend_map.end()) { - control_depend_map[behind_node] = std::set{prior_node}; - } else { - it->second.insert(prior_node); - } - } - } - - FuncGraphManagerPtr manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - std::unordered_set seen_node; - std::deque todo{node1}; - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); - if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { - continue; - } - (void)seen_node.insert(node); - - if (node == node2) { - return true; - } - if (node->isa()) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto inputs = cnode->inputs(); - (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); - } - auto it = control_depend_map.find(node); - if (it != control_depend_map.end()) { - (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); - } - } - return false; -} - -bool UnVisited(const BaseRef &n) { - if (utils::isa(n)) { - AnfNodePtr in = utils::cast(n); - MS_EXCEPTION_IF_NULL(in); - if (IsValueNode(in)) { - auto value_node = in->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - auto prim_py = value->cast(); - MS_EXCEPTION_IF_NULL(prim_py); - return !prim_py->HasAttr(kAttrVisited); - } else if (IsValueNode(in)) { - auto func_graph = GetValueNode(in); - MS_EXCEPTION_IF_NULL(func_graph); - return !func_graph->has_flag(kAttrVisited); - } - return false; - } - return false; -} - -bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(ERROR) << "The node is expected to be a cnode"; - return false; - } - *cnode = node->cast(); - if (*cnode == nullptr) { - return false; - } - if ((*cnode)->inputs().size() < IntToSize(input_size)) { - auto op_name = AnfAlgo::GetCNodeName(*cnode); - MS_LOG(ERROR) << "op[" + op_name + "] has less than " << input_size << " inputs."; - return false; - } - return true; -} - -CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "The node is expected to be a cnode"; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != IntToSize(input_size)) { - auto op_name = AnfAlgo::GetCNodeName(cnode); - MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs."; - } - return cnode; -} - -void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != input_size) { - MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size; - } -} - -bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) { - MS_EXCEPTION_IF_NULL(node_x); - MS_EXCEPTION_IF_NULL(node_y); - return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) && - AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0)); -} - -const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); - - auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); - MS_EXCEPTION_IF_NULL(transop_cnode); - auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); - auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); - MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); - MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1)); - auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1); - MS_EXCEPTION_IF_NULL(transed_node); - - std::vector replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node, - depend_cnode->input(kDependInputNum - 1)}; - AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs); - MS_EXCEPTION_IF_NULL(replace_depend); - auto transed_abstract = transed_node->abstract(); - replace_depend->set_abstract(transed_abstract); - return replace_depend; -} - -bool Visited(const BaseRef &n) { - if (utils::isa(n)) { - AnfNodePtr in = utils::cast(n); - MS_EXCEPTION_IF_NULL(in); - if (IsValueNode(in)) { - auto value_node = in->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - auto prim_py = value->cast(); - MS_EXCEPTION_IF_NULL(prim_py); - return prim_py->HasAttr(kAttrVisited); - } else if (IsValueNode(in)) { - auto func_graph = GetValueNode(in); - MS_EXCEPTION_IF_NULL(func_graph); - return func_graph->has_flag(kAttrVisited); - } - return false; - } - return false; -} - -void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode, - std::vector *conv_bn1_outputs) { - auto prim = std::make_shared(kConvBN1OpName); - std::vector conv_bn1_inputs = {NewValueNode(prim)}; - MS_EXCEPTION_IF_NULL(conv_cnode); - // All the inputs of conv_bn1 are from the inputs of conv - for (size_t i = 1; i < conv_cnode->inputs().size(); i++) { - conv_bn1_inputs.push_back(conv_cnode->input(i)); - } - MS_EXCEPTION_IF_NULL(func_graph); - CNodePtr conv_bn1_cnode = func_graph->NewCNode(conv_bn1_inputs); - MS_EXCEPTION_IF_NULL(conv_bn1_cnode); - auto kernel_info = std::make_shared(); - conv_bn1_cnode->set_kernel_info(kernel_info); - // Set attr for conv_bn1 - AnfAlgo::CopyNodeAttrs(conv_cnode, conv_bn1_cnode); - // Set abstract of conv_bn1 - MS_EXCEPTION_IF_NULL(bn_cnode); - auto bn_abstract_tuple = dyn_cast(bn_cnode->abstract()); - MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - AbstractBasePtrList conv_bn1_abstract_list; - conv_bn1_abstract_list.push_back(conv_cnode->abstract()); - auto abstract_tensor = std::make_shared( - kFloat32, Convert2Int(AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, kVariance - 1))); - conv_bn1_abstract_list.push_back(abstract_tensor); - conv_bn1_abstract_list.push_back(bn_abstract_tuple->elements()[kSaveMean]); - auto abstract_tuple = std::make_shared(conv_bn1_abstract_list); - conv_bn1_cnode->set_abstract(abstract_tuple); - - CreateMultipleOutputsOfAnfNode(func_graph, conv_bn1_cnode, kConvBn1OutputNum, conv_bn1_outputs); -} - -void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector &fused_bn1_outputs, - const CNodePtr &bn_node, std::vector *fused_bn2_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_node); - MS_EXCEPTION_IF_NULL(fused_bn2_outputs); - if (bn_node->inputs().size() != kBnInputNum) { - MS_LOG(EXCEPTION) << "BN node has wrong input size"; - } - if (fused_bn1_outputs.size() != kBN1OutputNum) { - MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; - } - - // the inputs of fused_bn2 are from the outputs of fused_bn1 and the inputs of bn - std::vector fused_bn2_inputs = {NewValueNode(std::make_shared(kFusedBN2OpName))}; - fused_bn2_inputs.push_back(fused_bn1_outputs[0]); - fused_bn2_inputs.push_back(fused_bn1_outputs[1]); - fused_bn2_inputs.push_back(bn_node->input(4)); - fused_bn2_inputs.push_back(bn_node->input(5)); - auto fused_bn2 = graph->NewCNode(fused_bn2_inputs); - MS_EXCEPTION_IF_NULL(fused_bn2); - auto kernel_info = std::make_shared(); - fused_bn2->set_kernel_info(kernel_info); - auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 4), AnfAlgo::GetOutputInferDataType(bn_node, 1), - AnfAlgo::GetOutputInferDataType(bn_node, 2)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 4), AnfAlgo::GetOutputInferShape(bn_node, 1), - AnfAlgo::GetOutputInferShape(bn_node, 2)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn2.get()); - fused_bn2->set_scope(bn_node->scope()); - AnfAlgo::CopyNodeAttr(kAttrMomentum, bn_node, fused_bn2); - - CreateMultipleOutputsOfAnfNode(graph, fused_bn2, kBN2OutputNum, fused_bn2_outputs); -} - -void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input, - const std::vector &fused_bn1_outputs, - const std::vector &fused_bn2_outputs, const CNodePtr &bn_node, - std::vector *fused_bn3_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(data_input); - MS_EXCEPTION_IF_NULL(bn_node); - MS_EXCEPTION_IF_NULL(fused_bn3_outputs); - if (bn_node->inputs().size() != kBnInputNum) { - MS_LOG(EXCEPTION) << "BN node has wrong input size"; - } - - if (fused_bn1_outputs.size() != kBN1OutputNum) { - MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; - } - - if (fused_bn2_outputs.size() != kBN2OutputNum) { - MS_LOG(EXCEPTION) << "BN2 outputs has wrong input size"; - } - - // the inputs of fused_bn3 are from the outputs of fused_bn1 and the inputs of bn - std::vector fused_bn3_inputs = {NewValueNode(std::make_shared(kFusedBN3OpName))}; - fused_bn3_inputs.push_back(data_input); - fused_bn3_inputs.push_back(fused_bn1_outputs[0]); - fused_bn3_inputs.push_back(fused_bn2_outputs[0]); - fused_bn3_inputs.push_back(bn_node->input(2)); - fused_bn3_inputs.push_back(bn_node->input(3)); - auto fused_bn3 = graph->NewCNode(fused_bn3_inputs); - MS_EXCEPTION_IF_NULL(fused_bn3); - auto kernel_info = std::make_shared(); - fused_bn3->set_kernel_info(kernel_info); - auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn3.get()); - - fused_bn3->set_scope(bn_node->scope()); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, kAttrEps, bn_node, fused_bn3); - - (*fused_bn3_outputs).push_back(fused_bn3); -} - -void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num, - std::vector *outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(outputs); - for (size_t i = 0; i < output_num; i++) { - auto idx = NewValueNode(SizeToInt(i)); - MS_EXCEPTION_IF_NULL(idx); - int temp = SizeToInt(i); - auto imm = std::make_shared(temp); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); - MS_EXCEPTION_IF_NULL(tuple_getitem); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)}, - {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get()); - (*outputs).push_back(tuple_getitem); - } -} - -template -tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, - size_t data_length) { - MS_EXCEPTION_IF_NULL(value_tuple_ptr); - MS_EXCEPTION_IF_NULL(type_ptr); - std::vector values; - for (const auto &v : value_tuple_ptr->value()) { - MS_EXCEPTION_IF_NULL(v); - if (v->isa()) { - ScalarPtr scalar = v->cast(); - values.push_back(GetValue(scalar)); - } else { - MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; - return nullptr; - } - } - std::vector tensor_shape = {SizeToInt(values.size())}; - tensor::TensorPtr tensor = std::make_shared(type_ptr->type_id(), tensor_shape); - MS_EXCEPTION_IF_NULL(tensor); - tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr}; - tensor->set_device_info(device_info); - auto data_ptr = tensor->data_c(); - MS_EXCEPTION_IF_NULL(data_ptr); - auto elem_num = values.size() * data_length; - auto ret_code = memcpy_s(data_ptr, static_cast(tensor->data().nbytes()), values.data(), elem_num); - if (ret_code != 0) { - MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; - } - return tensor; -} - -tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) { - MS_EXCEPTION_IF_NULL(value_tuple); - tensor::TensorPtr tensor = nullptr; - if (value_tuple->value().empty()) { - MS_LOG(WARNING) << "The value tuple is empty."; - return nullptr; - } - ValuePtr v = *(value_tuple->value().begin()); - MS_EXCEPTION_IF_NULL(v); - // Currently we only deal with the scalar tuple - if (!v->isa()) { - MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; - return nullptr; - } - ScalarPtr scalar = v->cast(); - MS_EXCEPTION_IF_NULL(scalar); - if (scalar->isa()) { - tensor = CreateTensorWithValueTuple(value_tuple, kInt32, kType32Len); - } else if (scalar->isa()) { - tensor = CreateTensorWithValueTuple(value_tuple, kFloat32, kType32Len); - } else { - auto type = scalar->type(); - auto type_str = (type == nullptr) ? "nullptr" : type->ToString(); - MS_LOG(ERROR) << "Invalid scalar type: " << type_str; - return nullptr; - } - return tensor; -} - -bool IsNopNode(const AnfNodePtr &node) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) { - return false; - } - static std::unordered_set nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, - prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(), - kFlattenGradOpName}; - if (node == nullptr || !node->isa()) { - return false; - } - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end()) { - return false; - } - return true; -} - -bool IsAllNopNode(const session::KernelGraph *const graph) { - MS_EXCEPTION_IF_NULL(graph); - auto execution_order = graph->execution_order(); - for (auto &cnode : execution_order) { - MS_EXCEPTION_IF_NULL(cnode); - if (!IsNopNode(cnode)) { - return false; - } - } - return true; -} - -void HideNopNode(session::KernelGraph *const graph) { - MS_EXCEPTION_IF_NULL(graph); - if (IsAllNopNode(graph) == true) { - return; - } - auto execution_order = graph->execution_order(); - MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); - std::vector new_nodes; - for (auto &cnode : execution_order) { - MS_EXCEPTION_IF_NULL(cnode); - if (!IsNopNode(cnode)) { - new_nodes.push_back(cnode); - } - } - graph->set_execution_order(new_nodes); - MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size(); -} - -void RemoveNopNode(session::KernelGraph *const graph) { - MS_EXCEPTION_IF_NULL(graph); - if (IsAllNopNode(graph) == true) { - return; - } - bool changed = true; - while (changed) { - changed = false; - std::vector new_nodes; - for (auto &cnode : graph->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode); - // ignore nop node itself - if (IsNopNode(cnode)) { - continue; - } - // Replace the input which is nop node - std::vector new_inputs; - new_inputs.push_back(cnode->input(0)); - bool need_update = false; - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - auto input = cnode->input(i); - MS_EXCEPTION_IF_NULL(input); - auto cinput = input->cast(); - if (cinput == nullptr || !IsNopNode(cinput)) { - new_inputs.push_back(input); - continue; - } - if (cinput->inputs().size() == 2) { - new_inputs.push_back(cinput->input(1)); - need_update = true; - changed = true; - } else { - new_inputs.push_back(input); - } - } - if (need_update) { - cnode->set_inputs(new_inputs); - } - // push into new execution list - new_nodes.push_back(cnode); - } - graph->set_execution_order(new_nodes); - } -} - -std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, - const AnfNodePtr &node) { - auto output_node_list = std::make_shared>>(); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto iter = manager->node_users().find(node); - if (iter == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - auto output_info_list = iter->second; - for (const auto &output_info : output_info_list) { - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { - continue; - } - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && - output_info.second == kDependAttachNodeIndex) { - continue; - } - output_node_list->push_back(output_info); - } - return output_node_list; -} - -bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto output_node_list = GetRealNodeUsedList(graph, node); - MS_EXCEPTION_IF_NULL(output_node_list); - return output_node_list->size() > 1; -} - -AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { - auto idx = NewValueNode(SizeToInt(output_idx)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(SizeToInt(output_idx)); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); - MS_EXCEPTION_IF_NULL(tuple_getitem); - tuple_getitem->set_scope(node->scope()); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); - AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); - return tuple_getitem; -} - -void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs; - std::vector new_input_names; - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - auto input_names = primitive->GetAttr(kAttrInputNames); - if (input_names == nullptr) { - MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; - return; - } - auto input_names_vec = GetValue>(input_names); - auto inputs = cnode->inputs(); - new_inputs.push_back(inputs[0]); - bool need_update = false; - for (size_t i = 0; i < inputs.size() - 1; ++i) { - auto input_node = inputs[i + 1]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_attrs.find(i) != input_attrs.end() && input_node->isa()) { - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; - if (i >= input_names_vec.size()) { - MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; - } - primitive->set_attr(input_names_vec[i], value_node->value()); - need_update = true; - } else { - new_inputs.push_back(input_node); - if (i < input_names_vec.size()) { - new_input_names.push_back(input_names_vec[i]); - } - } - } - if (need_update) { - // Update cnode's inputs - cnode->set_inputs(new_inputs); - // Update cnode's input_names attr - primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); - } -} - -bool AnfEqual(const BaseRef &a, const BaseRef &b) { - if (utils::isa(a) && utils::isa(b)) { - auto a_node = utils::cast(a); - auto b_node = utils::cast(b); - MS_EXCEPTION_IF_NULL(a_node); - MS_EXCEPTION_IF_NULL(b_node); - if (IsValueNode(a_node) && IsValueNode(b_node)) { - auto a_value_node = a_node->cast(); - MS_EXCEPTION_IF_NULL(a_value_node); - auto a_value = a_value_node->value(); - MS_EXCEPTION_IF_NULL(a_value); - auto a_prim = a_value->cast(); - MS_EXCEPTION_IF_NULL(a_prim); - - auto b_value_node = b_node->cast(); - MS_EXCEPTION_IF_NULL(b_value_node); - auto b_value = b_value_node->value(); - MS_EXCEPTION_IF_NULL(b_value); - auto b_prim = b_value->cast(); - MS_EXCEPTION_IF_NULL(b_prim); - - return a_prim->name() == b_prim->name(); - } else if (a_node->isa() && b_node->isa()) { - auto a_value_node_ptr = a_node->cast(); - if (a_value_node_ptr == nullptr) { - MS_LOG(EXCEPTION) << "cast value node ptr fail"; - } - auto a_value_ptr = a_value_node_ptr->value(); - if (a_value_ptr == nullptr) { - MS_LOG(EXCEPTION) << "value ptr is nullptr"; - } - - auto b_value_node_ptr = b_node->cast(); - if (b_value_node_ptr == nullptr) { - MS_LOG(EXCEPTION) << "cast value node ptr fail"; - } - auto b_value_ptr = b_value_node_ptr->value(); - if (b_value_ptr == nullptr) { - MS_LOG(EXCEPTION) << "value ptr is nullptr"; - } - - return (*a_value_ptr) == (*b_value_ptr); - } - MS_LOG(DEBUG) << "check AnfNodePtr equal"; - } - if (utils::isa(a) && utils::isa(b)) { - MS_LOG(DEBUG) << "check GraphPtr equal"; - } - return a == b; -} - -bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { - // To matchCNode and Kernel's type - if (utils::isa(a) && utils::isa(b)) { - return true; - } - return a.type() == b.type(); -} - -namespace { -ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { - if (utils::isa(sexp)) { - return NewValueNode(utils::cast(sexp)); - } - if (utils::isa(sexp)) { - return NewValueNode(utils::cast(sexp)); - } - if (utils::isa(sexp)) { - return NewValueNode(utils::cast(sexp)); - } - if (utils::isa(sexp)) { - return NewValueNode(utils::cast(sexp)); - } - return nullptr; -} - -CNodePtr CreateCNodeWithGraph(const std::vector &input_nodes, const BaseRef &graph) { - if (utils::isa(graph)) { - return std::make_shared(input_nodes, utils::cast(graph)); - } - if (utils::isa(graph)) { - return std::make_shared(input_nodes, utils::cast(graph)); - } - return nullptr; -} - -VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { - if (utils::isa(graph)) { - MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); - return std::make_shared(utils::cast(sexp), nullptr); - } - if (utils::isa(graph)) { - MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); - return std::make_shared(utils::cast(sexp), utils::cast(graph)); - } - MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString(); - return nullptr; -} - -AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, - bool multigraph) { - MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); - std::vector input_nodes; - const auto &tuple = utils::cast(sexp); - if (multigraph && utils::isa(graph)) { - for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); - input_nodes.push_back(node); - } - VarPtr var_ptr = utils::cast(graph); - return std::make_shared(input_nodes, var_ptr); - } - - for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); - input_nodes.push_back(node); - } - return CreateCNodeWithGraph(input_nodes, graph); -} -} // namespace - -AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { - MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); - MS_EXCEPTION_IF_NULL(primitive_vars); - if (utils::isa(sexp)) { - return HandleSexpVector(sexp, graph, primitive_vars, multigraph); - } - if (utils::isa(sexp)) { - auto var_ptr = utils::cast(sexp); - MS_EXCEPTION_IF_NULL(var_ptr); - if (var_ptr->primitive()) { - (*primitive_vars)[var_ptr->primitive()] = var_ptr; - return NewValueNode(var_ptr->primitive()); - } - return CreateVarNodeWithSexp(sexp, graph); - } - if (utils::isa(sexp)) { - return utils::cast(sexp); - } - auto value_node = CreateValueNodeWithSexp(sexp); - if (value_node == nullptr) { - MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); - } - return value_node; -} - -bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) { - MS_EXCEPTION_IF_NULL(equiv1); - MS_EXCEPTION_IF_NULL(equiv2); - MS_EXCEPTION_IF_NULL(var_node); - auto equiv1_node = GetAnfNodeByVar(equiv1, var_node); - MS_EXCEPTION_IF_NULL(equiv1_node); - auto equiv2_node = GetAnfNodeByVar(equiv2, var_node); - MS_EXCEPTION_IF_NULL(equiv2_node); - return *equiv1_node == *equiv2_node; -} - -AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(var_node); - auto iter = (*equiv).find(var_node); - if (iter == (*equiv).end()) { - MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched."; - return nullptr; - } - auto res = utils::cast(iter->second); - if (res == nullptr) { - MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node"; - } - return res; -} - -bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { - MS_EXCEPTION_IF_NULL(n1); - MS_EXCEPTION_IF_NULL(n2); - auto n1_cnode = n1->cast(); - auto n2_cnode = n2->cast(); - MS_EXCEPTION_IF_NULL(n1_cnode); - MS_EXCEPTION_IF_NULL(n2_cnode); - auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_input1); - auto value_node1 = index_input1->cast(); - MS_EXCEPTION_IF_NULL(value_node1); - auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_input2); - auto value_node2 = index_input2->cast(); - MS_EXCEPTION_IF_NULL(value_node2); - return GetValue(value_node1->value()) < GetValue(value_node2->value()); -} - -bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(INFO) << "node is not a cnode"; - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr(node, attr_name); -} - -bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set) { - MS_EXCEPTION_IF_NULL(node); - TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0); - if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) { - return true; - } - MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString(); - return false; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h deleted file mode 100644 index 49a1d47d0c..0000000000 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ /dev/null @@ -1,199 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ - -#include -#include -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "session/kernel_graph.h" -#include "common/utils.h" -#include "pre_activate/common/pattern_engine.h" - -namespace mindspore { -namespace opt { -constexpr size_t kTransOpInputNum = 2; -constexpr size_t kCastInputNum = 2; -constexpr size_t kDependInputNum = 3; -constexpr size_t kReluInputNum = 2; -constexpr size_t kReluGradInputNum = 3; -constexpr size_t kAddInputNum = 3; -constexpr size_t kAddNInputNum = 3; -constexpr size_t kTupleGetitemInputNum = 3; -constexpr size_t kConvInputNum = 3; -constexpr size_t kRealDivInputNum = 3; -constexpr size_t kSqrtInputNum = 2; -constexpr size_t kMulInputNum = 3; -constexpr size_t kRsqrtInputNum = 2; -constexpr size_t kSubInputNum = 3; -constexpr size_t kAssignSubInputNum = 3; - -constexpr size_t kConvBn1OutputNum = 3; -constexpr size_t kBn2ReluOutputNum = 4; - -constexpr size_t kBnInputNum = 6; -constexpr size_t kBnOutputNum = 5; -constexpr size_t kBatchNormInputNum = 5; -constexpr size_t kBatchNormOutputNum = 5; - -constexpr size_t kBN1OutputNum = 2; -constexpr size_t kBN2OutputNum = 3; -constexpr size_t kBN3OutputNum = 1; - -constexpr size_t kBNGradInputNum = 6; -constexpr size_t kBNGradOutputNum = 3; - -constexpr size_t kBNGrad1OutputNum = 3; -constexpr size_t kBNGrad2OutputNum = 5; -constexpr size_t kBNGrad3OutputNum = 1; - -constexpr size_t kBNTrainingReduceOutputNum = 2; -constexpr size_t kBNTrainingUpdateOutputNum = 5; -constexpr size_t kBNTrainingUpdateV2OutputNum = 3; -constexpr size_t kBNTrainingUpdateV3OutputNum = 5; -constexpr size_t kBNTrainingUpdateGradOutputNum = 2; - -constexpr size_t kSingleOutputNum = 1; -constexpr size_t kSumNodeInputNum = 2; -constexpr size_t kSquareNodeInputNum = 2; -constexpr size_t kSquareSumv2OutputNum = 2; -constexpr size_t kMinimumInputNum = 3; - -constexpr size_t kLambNextMVWithDecayInputNum = 7; -constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5; -constexpr size_t kLambNextMVWithDecayOutputNum = 4; -constexpr size_t kLambNextMVWithDecayV1OutputNum = 4; -constexpr size_t kLambNextRightOutputNum = 2; -constexpr size_t kLambUpdateWithLrV2InputNum = 8; -constexpr size_t kLambNextMVRuleInputNum = 14; -constexpr size_t kLambNextMVRuleOutputNum = 4; -constexpr size_t kBackendReshapeInputNum = 2; -constexpr size_t kBackendTransposeInputNum = 2; -constexpr size_t kAdamApplyOneWithDecayOutputNum = 3; -constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5; -constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2; -constexpr size_t kLayerNormGradInputNum = 6; -constexpr size_t kAdamApplyOneOutputNum = 3; -constexpr size_t kBackendTransDataInputNum = 2; -constexpr size_t kApplyMomentumInputNum = 6; -constexpr size_t kBiasAddInputNum = 3; -constexpr size_t kTopkInputNum = 3; -constexpr size_t kLarsV2InputNum = 5; -constexpr size_t kFusedMulApplyMomentumOutputNum = 2; -constexpr size_t kSplitInputNum = 2; - -enum FusedBatchNormInput { - kX = 1, - kVariance = 5, -}; -enum FusedBatchNormOutput { - kY = 0, - kRunningMean, - kRunningVariance, - kSaveMean, - kSaveInvVariance, -}; -enum ConvBn1Output { - kData = 0, - kVarPart, - kMean, -}; - -std::vector Convert2Int(const std::vector &v); - -// check whether node1 depends on node2 or not -bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); - -bool UnVisited(const BaseRef &n); - -bool Visited(const BaseRef &n); - -// check if the input node is CNode, then check it's input_size, if meet condition above, return true, otherwise return -// false. cnode can only be used when return true. -bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode); - -// check if the input node is CNode, then check it's input_size, return CNodePtr if check success. -CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size); - -void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size); - -bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y); - -const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node); - -void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode, - std::vector *conv_bn1_outputs); - -void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector &fused_bn1_outputs, - const CNodePtr &bn_node, std::vector *fused_bn2_outputs); -void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input, - const std::vector &fused_bn1_outputs, - const std::vector &fused_bn2_outputs, const CNodePtr &bn_node, - std::vector *fused_bn3_outputs); - -void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num, - std::vector *outputs); - -tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, - size_t data_length); - -tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple); - -bool IsAllNopNode(const session::KernelGraph *const graph); - -bool IsNopNode(const AnfNodePtr &node); - -void HideNopNode(session::KernelGraph *const graph); - -void RemoveNopNode(session::KernelGraph *const graph); - -AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); - -bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); - -std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, - const AnfNodePtr &node); - -void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs); - -bool AnfEqual(const BaseRef &a, const BaseRef &b); - -bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); - -AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, - bool multigraph = false); - -// Check var_node in two equivs is the same node -bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node); - -// Get anf_node from equiv by var_node -AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node); - -// Compare tuple getitem's index, return bool[n1's index < n2's index] -bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2); - -// Get attr which is bool from cnode -bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name); - -// Check node's data type is in supported data type set -bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set); -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ diff --git a/mindspore/ccsrc/pre_activate/common/node_pass.cc b/mindspore/ccsrc/pre_activate/common/node_pass.cc deleted file mode 100644 index 876da8667b..0000000000 --- a/mindspore/ccsrc/pre_activate/common/node_pass.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/common/node_pass.h" - -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/manager.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -bool NodePass::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - - std::unordered_set seen_node; - std::deque todo{func_graph->output()}; - bool changes = false; - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); - if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { - continue; - } - (void)seen_node.insert(node); - AnfNodePtr new_node = Run(func_graph, node); - bool change = (new_node != nullptr); - if (new_node != nullptr && new_node != node) { - (void)manager->Replace(node, new_node); - (void)seen_node.erase(node); - } else if (new_node == nullptr) { - new_node = node; - } - if (new_node && IsValueNode(new_node)) { - auto const_func_graph = GetValueNode(new_node); - MS_EXCEPTION_IF_NULL(const_func_graph); - if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - todo.push_back(const_func_graph->output()); - } - } else if (new_node && new_node->isa()) { - if (AnfAlgo::IsGraphKernel(new_node)) { - todo.push_back(new_node); - } - auto cnode = new_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto inputs = cnode->inputs(); - (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); - } - changes = changes || change; - } - return changes; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/node_pass.h b/mindspore/ccsrc/pre_activate/common/node_pass.h deleted file mode 100644 index 7750a59e59..0000000000 --- a/mindspore/ccsrc/pre_activate/common/node_pass.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ -#include -#include - -#include "pre_activate/common/pass.h" - -namespace mindspore { -namespace opt { -// @brief ANF Node level optimization base pass -class NodePass : public Pass { - public: - explicit NodePass(const std::string &name) : Pass(name) {} - ~NodePass() override = default; - bool Run(const FuncGraphPtr &func_graph) final; - virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; -}; -using NodePassPtr = std::shared_ptr; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.cc b/mindspore/ccsrc/pre_activate/common/optimizer.cc deleted file mode 100644 index 71a523ea1d..0000000000 --- a/mindspore/ccsrc/pre_activate/common/optimizer.cc +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/common/optimizer.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "pre_activate/common/pass_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "ir/manager.h" - -namespace mindspore { -namespace opt { -PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) - : NodePass(name), - multigraph_(multigraph), - pattern_engine_(PatternEngine(std::make_shared(), - std::function(AnfEqual), - std::function(CNodeTypeEqual))), - primitive_vars_(std::make_shared()) {} - -const BaseRef PatternProcessPass::DefinePattern() const { - VarPtr X = std::make_shared(); - return BaseRef({X}); -} - -void PatternProcessPass::Build() { - VarPtr fg = std::make_shared("RootG"); - BaseRef pattern = std::move(DefinePattern()); - pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_); -} - -AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - if (pattern_ == nullptr) { - Build(); - } - - auto empty_equiv = std::make_shared(); - MS_EXCEPTION_IF_NULL(primitive_vars_); - EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv); - if (equiv != nullptr && !equiv->empty()) { - return Process(func_graph, node, equiv); - } - return nullptr; -} - -bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - VarPtr fg = std::make_shared("RootG"); - auto empty_equiv = std::make_shared(); - MS_EXCEPTION_IF_NULL(child_primitive_vars_); - EquivPtr another_equiv = - child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node, - *child_primitive_vars_, empty_equiv); - if (another_equiv != nullptr && !another_equiv->empty()) { - return IsShareNodes(equiv, another_equiv); - } - return false; -} - -void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) { - if (pass_manager != nullptr) { - pass_managers_.push_back(pass_manager); - } -} - -FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) { - MS_EXCEPTION_IF_NULL(func_graph); - run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once; - // Performance risk by creating new manager each time - auto manager = Manage(func_graph, true); - - bool changed = true; - while (changed) { - changed = false; - for (size_t i = 0; i < pass_managers_.size(); ++i) { - const PassManagerPtr &pm = pass_managers_[i]; - if (pm != nullptr && pm->Run(func_graph)) { - changed = true; - } - } - if (run_only_once_) { - break; - } - } - - std::vector func_graphs; - func_graphs.push_back(func_graph); - manager->KeepRoots(func_graphs); - (void)TopoSort(func_graph->get_return()); - return func_graph; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.h b/mindspore/ccsrc/pre_activate/common/optimizer.h deleted file mode 100644 index 1f9961df6b..0000000000 --- a/mindspore/ccsrc/pre_activate/common/optimizer.h +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ - -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/common/pattern_engine.h" -#include "utils/graph_utils.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -using PatternListType = std::initializer_list; - -class PatternProcessPass : public NodePass { - public: - explicit PatternProcessPass(const std::string &name = "", bool multigraph = true); - ~PatternProcessPass() override = default; - virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0; - virtual const BaseRef DefinePattern() const; - AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; - - private: - void Build(); - - AnfNodePtr pattern_ = nullptr; - bool multigraph_ = true; - PatternEngine pattern_engine_; - PrimitiveVarMapPtr primitive_vars_; -}; - -class MultipleOutputPatternProcessPass : public PatternProcessPass { - public: - explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true) - : PatternProcessPass(name, multigraph), - child_pattern_engine_(PatternEngine(std::make_shared(), - std::function(AnfEqual), - std::function(CNodeTypeEqual))), - child_primitive_vars_(std::make_shared()) {} - ~MultipleOutputPatternProcessPass() override = default; - virtual BaseRef DefineAnotherPattern() const = 0; - // check two patterns whether share the same nodes or not - virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0; - - protected: - bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const; - PatternEngine child_pattern_engine_; - PrimitiveVarMapPtr child_primitive_vars_; -}; - -class GraphOptimizer { - public: - explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {} - virtual ~GraphOptimizer() = default; - - void AddPassManager(const PassManagerPtr &pass_manager); - FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true); - - private: - const std::string name_ = "graph_optimizer"; - std::vector pass_managers_{}; - bool run_only_once_ = true; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/pre_activate/common/pass.h b/mindspore/ccsrc/pre_activate/common/pass.h deleted file mode 100644 index 3d2468cddb..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pass.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ -#include -#include - -#include "ir/anf.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -// @brief ANF Graph level optimization base pass -class Pass { - public: - explicit Pass(const std::string &name = "pass") : name_(name) {} - virtual ~Pass() = default; - virtual bool Run(const FuncGraphPtr &func_graph) = 0; - virtual std::string name() const { return name_; } - - private: - const std::string name_; -}; -using PassPtr = std::shared_ptr; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/common/pass_manager.cc b/mindspore/ccsrc/pre_activate/common/pass_manager.cc deleted file mode 100644 index 3213b8a6d2..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pass_manager.cc +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/common/pass_manager.h" - -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/manager.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "debug/anf_ir_dump.h" - -namespace mindspore { -namespace opt { -const std::vector &PassManager::Passes() const { return passes_; } - -void PassManager::AddPass(const PassPtr &pass) { - if (pass != nullptr) { - passes_.push_back(pass); - } -} - -bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { - if (func_graph == nullptr) { - return false; - } - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - bool changed = false; - size_t num = 0; - for (const auto &pass : passes) { - if (pass != nullptr) { -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time {}; - struct timeval end_time {}; - (void)gettimeofday(&start_time, nullptr); -#endif - if (pass->Run(func_graph)) { - changed = true; - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us"; -#endif - if (save_graphs) { - auto dump_file_path = - save_graphs_path + "/" + "hwopt_" + name() + "_" + std::to_string(num) + "_" + pass->name() + ".ir"; - DumpIR(dump_file_path, func_graph); - } - num++; - } - } - return changed; -} - -bool PassManager::Run(const FuncGraphPtr &func_graph) const { - bool changed = false; - // run all passes - bool change = true; - while (change) { - change = Run(func_graph, passes_); - changed = change || changed; - if (run_only_once_) { - break; - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/pass_manager.h b/mindspore/ccsrc/pre_activate/common/pass_manager.h deleted file mode 100644 index 38fe49b94c..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pass_manager.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ - -#include -#include -#include -#include - -#include "pre_activate/common/pass.h" -#include "pre_activate/common/node_pass.h" - -namespace mindspore { -namespace opt { -// @brief For optimization passes management -class PassManager { - public: - explicit PassManager(const std::string &name = "pm", bool run_only_once = true) - : name_(name), passes_{}, run_only_once_(run_only_once) {} - virtual ~PassManager() = default; - // Get all the passes added by AddPass - const std::vector &Passes() const; - // Add graph pass, the pass object will be freed when pass manager freed. - void AddPass(const PassPtr &pass); - // Run passes added in pass manager on the input graph - // @param [inout] graph The graph to be optimized - // @return true, graph changed - // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph) const; - // Run the given graph passes on the input graph - // @param [inout] graph The graph to be optimized - // @param [in] passes The given graph passes - // @return true, graph changed - // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; - std::string name() const { return name_; } - - private: - const std::string name_; - std::vector passes_; - bool run_only_once_; -}; -using PassManagerPtr = std::shared_ptr; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/pre_activate/common/pattern_engine.cc b/mindspore/ccsrc/pre_activate/common/pattern_engine.cc deleted file mode 100644 index 42f966aa3d..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pattern_engine.cc +++ /dev/null @@ -1,360 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/common/pattern_engine.h" - -#include -#include -#include -#include - -#include "optimizer/opt.h" - -#include "ir/anf.h" -#include "utils/convert_utils_base.h" -#include "utils/overload.h" - -namespace mindspore { -static int GetNextTag() { - static int kID = 0; - return kID++; -} - -void Var::EnsureTag() { - if (tag_.length() == 0) { - std::ostringstream buffer; - buffer << "_" << GetNextTag(); - tag_ = buffer.str(); - } -} - -bool operator==(const VarPtr &lhs, const VarPtr &rhs) { - if (lhs->isa() && rhs->isa()) { - CondVarPtr v1 = dyn_cast(lhs); - CondVarPtr v2 = dyn_cast(rhs); - return *v1 == *v2; - } - - if (lhs->isa() && rhs->isa()) { - SVarPtr v1 = dyn_cast(lhs); - SVarPtr v2 = dyn_cast(rhs); - return *v1 == *v2; - } - return (*lhs == *rhs); -} - -std::string SeqVar::ToString() const { - std::ostringstream buffer; - buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")"; - return buffer.str(); -} - -std::ostream &operator<<(std::ostream &os, const VarPtr &var) { - if (var == nullptr) { - os << ""; - } else { - os << var->ToString(); - } - return os; -} - -template <> -std::ostream &operator<<(std::ostream &os, const Equiv &equiv) { - os << "[Equiv]" - << "\n"; - for (auto &equiv_item : equiv) { - auto k = equiv_item.first; - os << k << ":"; - BaseRef x = equiv_item.second; - if (utils::isa(x)) { - auto node = utils::cast(x); - os << "TypeString[" << node->type_name() << "]"; - if (IsValueNode(node)) { - os << "IsValueNodeGraph "; - } - os << "type " << node->type_name(); - if (node->isa()) { - os << " value " << GetValueNode(node); - } - os << " addr: " << node; - } else if (utils::isa(x)) { - os << "Named " << x.ToString().c_str(); - } else if (utils::isa(x)) { - os << "TypeString[Var]"; - os << utils::cast(x); - } else if (utils::isa(x)) { - os << "TypeString[Graph]"; - } - os << "\n"; - } - return os; -} - -static BaseRef GetVar(const BaseRef &x) { - MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); - if (utils::isa(x)) { - auto node = utils::cast(x); - MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]"; - if (node->isa()) { - MS_LOG(DEBUG) << "IsVarNode " + node->cast()->var_->ToString(); - return node->cast()->var_; - } - if (node->isa()) { - MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString(); - } else { - MS_LOG(DEBUG) << "type " + node->type_name(); - } - } else if (utils::isa(x)) { - MS_LOG(DEBUG) << "Named " + x.ToString(); - } else if (utils::isa(x)) { - MS_LOG(DEBUG) << "VectorRef"; - } else if (utils::isa(x)) { - MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString(); - } - MS_LOG(DEBUG) << "GetVar end: " + x.ToString(); - return x; -} - -EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { - MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); - MS_EXCEPTION_IF_NULL(equiv); - if (utils::isa(pattern)) { - VarPtr var = utils::cast(pattern); - if (var->matches(expr)) { - (*equiv)[var] = expr; - MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString(); - return equiv; - } - } - - return nullptr; -} - -bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, - VectorRef *const values_expr) const { - MS_EXCEPTION_IF_NULL(values_expr); - if (utils::isa(pattern_ref)) { - *values_pattern = pattern_ref; - *values_expr = expr_ref; - return true; - } - return false; -} - -bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern, - VectorRef *const values_expr) const { - MS_EXCEPTION_IF_NULL(values_expr); - // visitor to visite the list - auto appender_pattern = [](VectorRef &values) { - std::function fn = [&](const BaseRef &u) { - values.push_back(GetVar(u)); - return u; - }; - return fn; - }; - - visitor_->SetFn(appender_pattern(*values_pattern)); - MS_LOG(DEBUG) << "visit pattern_ref"; - bool success = visitor_->Visit(pattern_ref, nullptr); - if (!success) { - return false; - } - - auto appender_expr = [](VectorRef &values) { - std::function fn = [&](const BaseRef &u) { - values.push_back(u); - return u; - }; - return fn; - }; - - visitor_->SetFn(appender_expr(*values_expr)); - MS_LOG(DEBUG) << "visit expr_ref"; - return visitor_->Visit(expr_ref, nullptr); -} - -static int GetSVarStartIndex(const VectorRef &values) { - int index = -1; - int count = 0; - for (auto &value : values) { - if (utils::isa(value) && utils::cast(value)->isa()) { - if (index != -1) { - MS_LOG(DEBUG) << "Multiple SVars in sequence"; - return kInvalidVarIndex; - } - index = count; - } - count++; - } - return index; -} - -void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars, - EquivPtr equiv) { - if (equiv == nullptr || values_pattern.empty() || !utils::isa(values_pattern[0]) || - !utils::isa(expr_ref)) { - return; - } - auto real_node = utils::cast(expr_ref); - MS_EXCEPTION_IF_NULL(real_node); - if (!real_node->isa()) { - return; - } - auto prim_node = utils::cast(values_pattern[0]); - MS_EXCEPTION_IF_NULL(prim_node); - if (!IsValueNode(prim_node)) { - return; - } - ValuePtr value = GetValueNode(prim_node); - MS_EXCEPTION_IF_NULL(value); - auto prim = value->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto iter = primitive_vars.find(prim); - if (iter == primitive_vars.end()) { - return; - } - (*equiv)[iter->second] = real_node; -} - -EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, - const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const { - int svar_index = GetSVarStartIndex(values_pattern); - if (svar_index == kInvalidVarIndex) { - return nullptr; - } - - size_t values_pattern_len = values_pattern.size(); - size_t values_expr_len = values_expr.size(); - - if (svar_index == -1) { - if (values_pattern_len != values_expr_len) { - MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len " - << values_expr_len; - return nullptr; - } - } - if (values_expr_len < values_pattern_len - 1) { - MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len; - return nullptr; - } - size_t diff = values_expr_len - values_pattern_len + 1; - for (size_t i = 0; i < values_pattern_len; i++) { - size_t expr_i = i; - if (svar_index != -1 && i == IntToSize(svar_index)) { - auto seq = - std::vector(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); - equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv); - } else { - if (svar_index != -1 && i > IntToSize(svar_index)) { - expr_i = i + diff - 1; - } - equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv); - } - if (equiv == nullptr) { - return nullptr; - } - } - return equiv; -} - -EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, - EquivPtr equiv) const { - MS_LOG(DEBUG) << "-----[in Match]"; - MS_LOG(DEBUG) << "GetVar w"; - BaseRef pattern_ref = GetVar(pattern); - MS_LOG(DEBUG) << "GetVar v"; - BaseRef expr_ref = expr; - - if (equiv == nullptr) { - MS_LOG(EXCEPTION) << "Equiv pointer is null"; - } - - MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString(); - // 1. if pattern_ref is var and already in equiv, replace it. - if (utils::isa(pattern_ref)) { - VarPtr var = utils::cast(pattern_ref); - auto iter = equiv->find(var); - if (iter != equiv->end()) { - pattern_ref = iter->second; - } - } - - // 2. check equal - if (eq_(pattern_ref, expr_ref)) { - return equiv; - } - - // 3. match var - EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv); - if (ret_equiv) { - return ret_equiv; - } - - // 4. here the type can be std:vector, std:list, - // or cnode. - if (!type_eq_(pattern_ref, expr_ref)) { - MS_LOG(DEBUG) << "Type mismatch"; - return nullptr; - } - - // 5. transfer the Containers by visitor to std::vector - VectorRef values_pattern; - VectorRef values_expr; - if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) { - return nullptr; - } - - // 6. if any svar in both side, find the SeqVar index, - // try to pack the Var s in std::vector to a Seq and match elements one by one. - // check svar - equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv); - UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv); - return equiv; -} - -BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(equiv); - MS_LOG(DEBUG) << "-----[in Replace]"; - BaseRef ref = GetVar(pattern); - BaseRef out; - bool is_match = false; - - // w is var - if (utils::isa(ref)) { - const VarPtr &var = utils::cast(ref); - auto iter = equiv->find(var); - if (iter != equiv->end()) { - out = iter->second; - is_match = true; - } - } - if (is_match) { - return out; - } - - // visitor to visit the list - std::function fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); }; - - visitor_->SetFn(fn); - BaseRef visit_out; - if (!visitor_->Visit(pattern, &visit_out)) { - return pattern; - } - return visit_out; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/pattern_engine.h b/mindspore/ccsrc/pre_activate/common/pattern_engine.h deleted file mode 100644 index 858b1aecb8..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pattern_engine.h +++ /dev/null @@ -1,204 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "pre_activate/common/visit.h" -#include "ir/base.h" -#include "utils/log_adapter.h" -#include "utils/base_ref.h" - -namespace mindspore { -class CondVar; -class SeqVar; -using CondVarPtr = std::shared_ptr; -using SVarPtr = std::shared_ptr; -const int kInvalidVarIndex = -2; - -using ConditionFunc = std::function; - -// Base wildcard variable which could match any anf node. -class Var : public Base { - friend class VarHasher; - - public: - explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); } - explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { - EnsureTag(); - } - Var(const Var &other) : Base(other), tag_(other.tag_) {} - virtual Var &operator=(const Var &other) { - if (&other == this) { - return *this; - } - this->tag_ = other.tag_; - return *this; - } - ~Var() override = default; - MS_DECLARE_PARENT(Var, Base); - - virtual bool matches(const BaseRef &) { return true; } - - virtual bool operator==(const Var &other) const { return tag_ == other.tag_; } - bool operator!=(const Var &other) const { return !(&other == this); } - - std::string tag() const { return tag_; } - PrimitivePtr primitive() const { return primitive_; } - std::string ToString() const override { - std::ostringstream buffer; - buffer << "Var(" << tag_ << ")"; - return buffer.str(); - } - std::size_t hash() const override { return std::hash()(tag_); } - - protected: - void EnsureTag(); - - std::string tag_; - PrimitivePtr primitive_; -}; - -// VarNode means variable node, a subclass of AnfNode -class VarNode : public AnfNode { - public: - VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {} - ~VarNode() override = default; - MS_DECLARE_PARENT(VarNode, AnfNode); - - const VarPtr var_; -}; -using VarNodePtr = std::shared_ptr; - -class VarHasher { - public: - std::size_t operator()(const Var &var) const { return var.hash(); } -}; - -// Condition Var, match an anf node when condition function return true. -class CondVar : public Var { - public: - explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {} - ~CondVar() override = default; - MS_DECLARE_PARENT(CondVar, Var); - bool matches(const BaseRef &value) override { - MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); - if (utils::isa(value)) { - return false; - } - return cond_fn_(value); - } - ConditionFunc cond_fn_; -}; - -using Seq = VectorRef; -using SeqPtr = std::shared_ptr; - -// Sequence Var which could match multiple consecutive input nodes of a CNode. -class SeqVar : public Var { - public: - SeqVar() { subvar_ = std::make_shared(); } - ~SeqVar() override = default; - MS_DECLARE_PARENT(SeqVar, Var); - explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; } - bool matches(const BaseRef &value) override { - // match Seq. - if (utils::isa(value)) { - const Seq &seq = utils::cast(value); - return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) { - auto eq = subvar_->matches(v); - return eq; - }); - } - return false; - } - bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; } - std::string ToString() const override; - - private: - VarPtr subvar_; -}; - -bool operator==(const VarPtr &lhs, const VarPtr &rhs); - -inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); } - -std::ostream &operator<<(std::ostream &os, const VarPtr &var); - -using Equiv = std::map; -using EquivPtr = std::shared_ptr; -using PrimitiveVarMap = std::unordered_map; -using PrimitiveVarMapPtr = std::shared_ptr; - -inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); } - -class PatternEngine { - public: - PatternEngine(const std::shared_ptr &visitor, - const std::function &eq, - const std::function &type_eq = DefaultTypeEq) - : visitor_(visitor), eq_(eq), type_eq_(type_eq) {} - ~PatternEngine() = default; - - EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, - EquivPtr equiv) const; - // Replace pattern with equivalent - BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const; - - private: - EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, - const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const; - bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern, - VectorRef *const values_expr) const; - bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, - VectorRef *const values_expr) const; - std::shared_ptr visitor_; - std::function eq_; - std::function type_eq_; -}; -} // namespace mindspore -namespace std { -using mindspore::ERROR; -using mindspore::LogStream; -using mindspore::NoExceptionType; -template <> -struct hash { - std::size_t operator()(const mindspore::VarPtr var) const { - if (var == nullptr) { - MS_LOG(ERROR) << "Invalid var ptr"; - return 0; - } - return std::hash{}(var->tag()); - } -}; -} // namespace std -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ diff --git a/mindspore/ccsrc/pre_activate/common/visit.cc b/mindspore/ccsrc/pre_activate/common/visit.cc deleted file mode 100644 index 179177dd67..0000000000 --- a/mindspore/ccsrc/pre_activate/common/visit.cc +++ /dev/null @@ -1,166 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/common/visit.h" - -#include -#include -#include -#include - -#include "pre_activate/common/pattern_engine.h" -#include "utils/any.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "utils/log_adapter.h" - -/* namespace to support utils definition */ -namespace mindspore { -bool CheckIfNeedExpand(const std::vector &list) { - return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa(any); }); -} - -std::shared_ptr ExpandList(const std::vector &list) { - std::shared_ptr new_list = std::make_shared(); - for (auto &item : list) { - if (utils::isa(item)) { - const Seq &seq = utils::cast(item); - new_list->insert(new_list->end(), seq.begin(), seq.end()); - } else { - new_list->push_back(item); - } - } - return new_list; -} - -bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const { - std::vector out; - (void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out), - [this](const BaseRef &item) { return fn_(item); }); - if (visit_out != nullptr) { - *visit_out = ExpandList(out); - } - return true; -} - -bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const { - if (utils::isa(any)) { - return Visit(utils::cast(any), visit_out); - } else if (utils::isa(any)) { - auto nodeptr = utils::cast(any); - AnfNodePtr output; - AnfNodePtr *p_output = &output; - if (visit_out == nullptr) { - p_output = nullptr; - } - Visit(nodeptr, fn_, p_output); - if (visit_out != nullptr) { - *visit_out = output; - } - return true; - } - MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString(); - return false; -} - -void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const { - if (node->isa()) { - Visit(node->cast(), fn, output); - return; - } - - if (node->isa()) { - Visit(node->cast(), fn, output); - return; - } - - if (output != nullptr) { - *output = node; - } -} - -void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const { - // if output is nullptr, it's not required to make the new CNode node. - if (output == nullptr) { - for (auto &inp : cnode->inputs()) { - (void)fn(inp); - } - - if (cnode->func_graph() != nullptr) { - (void)fn(cnode->func_graph()); - } else { - (void)fn(cnode->func_graph_as_var()); - } - return; - } - - std::vector new_inputs; - std::vector after_cnode_fn; - std::shared_ptr out; - (void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn); - if (CheckIfNeedExpand(after_cnode_fn)) { - out = ExpandList(after_cnode_fn); - } - - std::vector &outs = after_cnode_fn; - if (out != nullptr) { - outs = out->elements(); - } - - for (auto &any_item : outs) { - if (!utils::isa(any_item)) { - MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr"; - } - new_inputs.push_back(utils::cast(any_item)); - } - - BaseRef any_fg; - AnfNodePtr new_cnode = nullptr; - if (cnode->func_graph() != nullptr) { - any_fg = fn(cnode->func_graph()); - if (!utils::isa(any_fg)) { - MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr"; - } - new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); - } else { - any_fg = fn(cnode->func_graph_as_var()); - if (utils::isa(any_fg)) { - new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); - } else if (utils::isa(any_fg)) { - new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); - } else { - MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr"; - } - } - new_cnode->set_abstract(cnode->abstract()); - *output = new_cnode; -} - -void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const { - const BaseRef &value = utils::cast(fn(vnode->value())); - if (utils::isa(value)) { - if (output != nullptr) { - auto ct = NewValueNode(utils::cast(value)); - ct->set_abstract(vnode->abstract()); - *output = ct; - } - return; - } - MS_LOG(EXCEPTION) << "Visit result is not ValuePtr."; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/visit.h b/mindspore/ccsrc/pre_activate/common/visit.h deleted file mode 100644 index 2017b03b2f..0000000000 --- a/mindspore/ccsrc/pre_activate/common/visit.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ - -#include -#include -#include -#include -#include -#include - -#include "ir/base.h" -#include "utils/base_ref.h" - -// namespace to support utils definition -namespace mindspore { -using VisitFn = std::function; - -class Visitor { - public: - virtual void SetFn(VisitFn fn) = 0; - virtual bool Visit(const BaseRef &e, BaseRef *out) const = 0; - virtual bool Visit(const VectorRef &e, BaseRef *out) const = 0; - virtual ~Visitor() = default; -}; - -class DefaultVisitor : public Visitor { - public: - DefaultVisitor() : fn_(nullptr) {} - ~DefaultVisitor() override = default; - void SetFn(VisitFn fn) override { fn_ = fn; }; - bool Visit(const VectorRef &e, BaseRef *out) const override; - bool Visit(const BaseRef &e, BaseRef *out) const override; - void Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const; - void Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const; - void Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const; - - VisitFn fn_; -}; - -std::shared_ptr ExpandList(const std::vector &list); -bool CheckIfNeedExpand(const std::vector &list); -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_VISIT_H_ diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc b/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc deleted file mode 100644 index 8111ee429d..0000000000 --- a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/gpu/adam_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { - std::vector inputs_format; - std::vector outputs_format; - std::vector inputs_type; - std::vector outputs_type; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { - inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); - inputs_format.push_back(kOpFormat_DEFAULT); - } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { - outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); - outputs_format.push_back(kOpFormat_DEFAULT); - } - builder.SetInputsDeviceType(inputs_type); - builder.SetInputsFormat(inputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetOutputsFormat(outputs_format); - return builder.Build(); -} -} // namespace - -const BaseRef AdamFusion::DefinePattern() const { - VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), - VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); - VectorRef next_v = - VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), - VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); - VectorRef update = VectorRef( - {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); - VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); - VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); - VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); - VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); - return depend3; -} - -const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - auto beta1_input = utils::cast((*equiv)[beta1_]); - auto one_sub_beta1_input = utils::cast((*equiv)[one_sub_beta1_]); - auto beta2_input = utils::cast((*equiv)[beta2_]); - auto one_sub_beta2_input = utils::cast((*equiv)[one_sub_beta2_]); - auto eps_input = utils::cast((*equiv)[eps_]); - auto lr_input = utils::cast((*equiv)[lr_]); - auto param_input = utils::cast((*equiv)[param_]); - auto m_input = utils::cast((*equiv)[m_]); - auto v_input = utils::cast((*equiv)[v_]); - auto gradient_input = utils::cast((*equiv)[gradient_]); - MS_EXCEPTION_IF_NULL(beta1_input); - MS_EXCEPTION_IF_NULL(one_sub_beta1_input); - MS_EXCEPTION_IF_NULL(beta2_input); - MS_EXCEPTION_IF_NULL(one_sub_beta2_input); - MS_EXCEPTION_IF_NULL(eps_input); - MS_EXCEPTION_IF_NULL(lr_input); - MS_EXCEPTION_IF_NULL(param_input); - MS_EXCEPTION_IF_NULL(m_input); - MS_EXCEPTION_IF_NULL(v_input); - MS_EXCEPTION_IF_NULL(gradient_input); - - auto prim = std::make_shared(kFusedAdamName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = { - NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, - eps_input, lr_input, param_input, m_input, v_input, - gradient_input}; - auto adam = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(adam); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get()); - adam->set_scope(node->scope()); - - auto build_info = GenerateKernelBuildInfo(adam); - AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); - return adam; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h b/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h deleted file mode 100644 index d8c10a0986..0000000000 --- a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class AdamFusion : public PatternProcessPass { - public: - explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) { - beta1_ = std::make_shared(); - one_sub_beta1_ = std::make_shared(); - beta2_ = std::make_shared(); - one_sub_beta2_ = std::make_shared(); - eps_ = std::make_shared(); - lr_ = std::make_shared(); - param_ = std::make_shared(); - m_ = std::make_shared(); - v_ = std::make_shared(); - gradient_ = std::make_shared(); - } - ~AdamFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr beta1_; - VarPtr one_sub_beta1_; - VarPtr beta2_; - VarPtr one_sub_beta2_; - VarPtr eps_; - VarPtr lr_; - VarPtr param_; - VarPtr m_; - VarPtr v_; - VarPtr gradient_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc deleted file mode 100644 index c950cbd56f..0000000000 --- a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/gpu/adam_weight_decay_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { - std::vector inputs_format; - std::vector outputs_format; - std::vector inputs_type; - std::vector outputs_type; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { - inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); - inputs_format.push_back(kOpFormat_DEFAULT); - } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { - outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); - outputs_format.push_back(kOpFormat_DEFAULT); - } - builder.SetInputsDeviceType(inputs_type); - builder.SetInputsFormat(inputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetOutputsFormat(outputs_format); - return builder.Build(); -} -} // namespace - -const BaseRef AdamWeightDecayFusion::DefinePattern() const { - VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), - VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); - VectorRef next_v = - VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), - VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); - VectorRef update = VectorRef( - {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); - VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); - - VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); - VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); - VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); - VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); - return depend3; -} - -const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - auto beta1_input = utils::cast((*equiv)[beta1_]); - auto one_sub_beta1_input = utils::cast((*equiv)[one_sub_beta1_]); - auto beta2_input = utils::cast((*equiv)[beta2_]); - auto one_sub_beta2_input = utils::cast((*equiv)[one_sub_beta2_]); - auto eps_input = utils::cast((*equiv)[eps_]); - auto lr_input = utils::cast((*equiv)[lr_]); - auto weight_decay_input = utils::cast((*equiv)[weight_decay_]); - auto param_input = utils::cast((*equiv)[param_]); - auto m_input = utils::cast((*equiv)[m_]); - auto v_input = utils::cast((*equiv)[v_]); - auto gradient_input = utils::cast((*equiv)[gradient_]); - MS_EXCEPTION_IF_NULL(beta1_input); - MS_EXCEPTION_IF_NULL(one_sub_beta1_input); - MS_EXCEPTION_IF_NULL(beta2_input); - MS_EXCEPTION_IF_NULL(one_sub_beta2_input); - MS_EXCEPTION_IF_NULL(eps_input); - MS_EXCEPTION_IF_NULL(lr_input); - MS_EXCEPTION_IF_NULL(weight_decay_input); - MS_EXCEPTION_IF_NULL(param_input); - MS_EXCEPTION_IF_NULL(m_input); - MS_EXCEPTION_IF_NULL(v_input); - MS_EXCEPTION_IF_NULL(gradient_input); - - auto prim = std::make_shared(kFusedAdamWeightDecayName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = { - NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, - eps_input, lr_input, param_input, m_input, v_input, - gradient_input, weight_decay_input}; - auto adam_weight_decay = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(adam_weight_decay); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get()); - adam_weight_decay->set_scope(node->scope()); - - auto build_info = GenerateKernelBuildInfo(adam_weight_decay); - AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); - return adam_weight_decay; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h b/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h deleted file mode 100644 index 0ada5756e3..0000000000 --- a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class AdamWeightDecayFusion : public PatternProcessPass { - public: - explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) { - beta1_ = std::make_shared(); - one_sub_beta1_ = std::make_shared(); - beta2_ = std::make_shared(); - one_sub_beta2_ = std::make_shared(); - eps_ = std::make_shared(); - lr_ = std::make_shared(); - weight_decay_ = std::make_shared(); - param_ = std::make_shared(); - m_ = std::make_shared(); - v_ = std::make_shared(); - gradient_ = std::make_shared(); - } - ~AdamWeightDecayFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr beta1_; - VarPtr one_sub_beta1_; - VarPtr beta2_; - VarPtr one_sub_beta2_; - VarPtr eps_; - VarPtr lr_; - VarPtr weight_decay_; - VarPtr param_; - VarPtr m_; - VarPtr v_; - VarPtr gradient_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.cc b/mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.cc deleted file mode 100644 index c75860a8df..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.cc +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/mem_reuse/kernel_refcount.h" -#include -#include "utils/log_adapter.h" -namespace mindspore { -namespace memreuse { -/** - * Add some set && get function - */ -void KernelRefCount::SetKernelRefCountInfo(int index, size_t size, RefCountType reftype) { - index_ = index; - size_ = size; - reftype_ = reftype; -} - -std::vector KernelDef::GetInputRefIndexs() const { - std::vector input_ref_indexs; - if (input_refs_.empty()) { - return input_ref_indexs; - } - (void)std::transform(input_refs_.begin(), input_refs_.end(), std::back_inserter(input_ref_indexs), - [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); - return input_ref_indexs; -} - -std::vector KernelDef::GetOutputRefIndexs() const { - std::vector output_ref_indexs; - if (output_refs_.empty()) { - return output_ref_indexs; - } - (void)std::transform(output_refs_.begin(), output_refs_.end(), std::back_inserter(output_ref_indexs), - [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); - return output_ref_indexs; -} - -std::vector KernelDef::GetWorkspaceRefIndexs() const { - std::vector wk_ref_indexs; - if (wk_space_.empty()) { - return wk_ref_indexs; - } - // only one key - auto wk_refs_iter = wk_space_.begin(); - auto wk_refs = wk_refs_iter->second; - (void)std::transform(wk_refs.begin(), wk_refs.end(), std::back_inserter(wk_ref_indexs), - [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); - return wk_ref_indexs; -} -} // namespace memreuse -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_copy_manager.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_copy_manager.h deleted file mode 100644 index ea9947b41b..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_copy_manager.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ - -#include -#include -#include -#include -#include -#include "session/kernel_graph.h" -#include "kernel/kernel.h" - -using HostAddress = mindspore::kernel::Address; -namespace mindspore { -namespace device { -namespace memswap { -enum class SwapKind { kDeviceToHost = 0, kHostToDevice = 1 }; - -struct TensorInfo { - size_t tensor_size_{0}; - AnfNodePtr kernel_{nullptr}; - size_t output_idx_{0}; -}; - -struct KernelExecutionInfo { - size_t topo_order_{0}; - float execution_perform_{0.0}; - bool trigger_swap_{false}; - bool need_swap_{false}; - // output index to topo orders of node users - std::map> node_users_map_; - // kernel output idx to host addr - std::map host_addrs_; - - KernelExecutionInfo() : KernelExecutionInfo(0, 0.0, false, false) {} - explicit KernelExecutionInfo(size_t topo_order) - : topo_order_(topo_order), execution_perform_(0.0), trigger_swap_(false), need_swap_(false) {} - KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap, bool need_swap) - : topo_order_(topo_order), - execution_perform_(execution_perform), - trigger_swap_(trigger_swap), - need_swap_(need_swap) {} -}; - -// trigger swap -struct MemSwapInfo { - SwapKind swap_kind_; - // kernel need to be swapped - AnfNodePtr kernel_{nullptr}; - size_t output_idx_{0}; -}; - -class MemCopyManager { - public: - MemCopyManager() = default; - - virtual ~MemCopyManager() = default; - - virtual void Init() {} - - virtual void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {} - - virtual void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {} - - virtual bool SyncMemCopyStream(SwapKind swap_kind) { return true; } - - virtual DeviceAddressPtr UpdateSwapOutQueue() { return nullptr; } - - virtual DeviceAddressPtr UpdateSwapInQueue() { return nullptr; } - - virtual bool AllocHostPinnedMem(size_t size, void **addr) const { return true; } - - virtual void FreeHostPinnedMem(void *addr) const {} - - virtual void ClearSwapQueue() {} -}; -using MemCopyManagerPtr = std::shared_ptr; -} // namespace memswap -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc deleted file mode 100644 index 095f8f6495..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc +++ /dev/null @@ -1,324 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/mem_reuse/mem_dynamic_allocator.h" -#include "common/utils.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -DynamicMemPoolBestFit::~DynamicMemPoolBestFit() { - global_mem_block_list_.clear(); - global_idle_mem_buf_map_.clear(); -} - -DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) { - size_t align_size = AlignMemorySize(size); - // Find the idle memory buf by tensor size, if not find, then add new memory block and memory buf. - DeviceMemPtr device_addr = FindIdleMemBuf(align_size); - if (!device_addr) { - device_addr = AddMemBlockAndMemBuf(align_size); - } - return device_addr; -} - -std::vector DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size, - std::vector size_list) { - std::vector device_addr_list; - // Pre-alloc the one whole piece memory. - auto device_addr = AllocTensorMem(total_size); - if (!device_addr) { - return device_addr_list; - } - // Remove the pre-alloc memory. - auto mem_block = FindMemBlock(device_addr); - MS_EXCEPTION_IF_NULL(mem_block); - auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); - if (iter == mem_block->block_all_mem_buf_map_.end()) { - MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; - } - auto mem_buf = iter->second; - MS_EXCEPTION_IF_NULL(mem_buf); - auto rest_size = mem_buf->size_ - total_size; - (void)mem_block->block_all_mem_buf_map_.erase(iter); - // Split the pre-alloc memory into continuous memory by the size list. - DynamicMemBufPtr continuous_mem_buf; - auto buf_addr = device_addr; - for (size_t i = 0; i < size_list.size(); i++) { - continuous_mem_buf = std::make_shared(buf_addr, kMemBufUsed, size_list[i]); - (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf); - device_addr_list.emplace_back(buf_addr); - buf_addr = AddressOffset(buf_addr, size_list[i]); - } - // Update the size of the last memory buf. - continuous_mem_buf->size_ += rest_size; - return device_addr_list; -} - -size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { - if (size == 0) { - return DYNAMIC_MEM_ALIGN_SIZE; - } - return ((size + DYNAMIC_MEM_ALIGN_SIZE - 1) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE; -} - -DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) { - auto iter = global_idle_mem_buf_map_.lower_bound(size); - if (iter != global_idle_mem_buf_map_.end()) { - auto mem_buf = iter->second; - MS_EXCEPTION_IF_NULL(mem_buf); - if (mem_buf->status_ != kMemBufIdle) { - MS_LOG(EXCEPTION) << "Find the mem_buf is not idle, alloc_size[" << size << "] mem_buf_size[" << mem_buf->size_ - << "] mem_buf_address[" << mem_buf->device_addr_ << "]."; - } - mem_buf->status_ = kMemBufUsed; - // Remove map of old idle memory buf - (void)global_idle_mem_buf_map_.erase(iter); - // Divide memory buf - if (IsDivide(size, mem_buf->size_)) { - DivideMemBuf(size, mem_buf); - } - // Memory statistics - total_used_mem_statistics_ += mem_buf->size_; - if (total_used_mem_statistics_ > used_mem_peak_statistics_) { - used_mem_peak_statistics_ = total_used_mem_statistics_; - } - return mem_buf->device_addr_; - } - return nullptr; -} - -DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) { - size_t alloc_mem_size = CalMemBlockAllocSize(size); - if (alloc_mem_size == 0) { - return nullptr; - } - // Add new memory block - DeviceMemPtr device_addr = nullptr; - auto real_alloc_size = AllocDeviceMem(alloc_mem_size, &device_addr); - if (real_alloc_size < size) { - MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size - << "]."; - return nullptr; - } - auto mem_block = std::make_shared(device_addr, real_alloc_size); - MS_EXCEPTION_IF_NULL(mem_block); - auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock); - (void)global_mem_block_list_.insert(iter, mem_block); - // Add new memory buf - auto mem_buf = std::make_shared(device_addr, kMemBufUsed, real_alloc_size); - MS_EXCEPTION_IF_NULL(mem_buf); - // Add map of new memory buf in the block - (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, mem_buf); - // Divide memory buf - if (IsDivide(size, mem_buf->size_)) { - DivideMemBuf(size, mem_buf); - } - // Memory statistics - total_mem_statistics_ += real_alloc_size; - total_used_mem_statistics_ += mem_buf->size_; - if (total_used_mem_statistics_ > used_mem_peak_statistics_) { - used_mem_peak_statistics_ = total_used_mem_statistics_; - } - return mem_buf->device_addr_; -} - -size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) { - auto device_free_mem_size = free_mem_size(); - if (device_free_mem_size < size) { - MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size - << "] is smaller than required size[" << size << "]."; - return 0; - } - auto alloc_mem_size = mem_alloc_unit_size(); - // Growing at twice of alloc size - while (alloc_mem_size < size) { - alloc_mem_size = alloc_mem_size * 2; - } - alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size); - return alloc_mem_size; -} - -bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) const { - return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE; -} - -void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) { - MS_EXCEPTION_IF_NULL(mem_buf); - auto mem_block = FindMemBlock(mem_buf->device_addr_); - MS_EXCEPTION_IF_NULL(mem_block); - // Divide new memory buf - size_t newbuf_size = mem_buf->size_ - size; - mem_buf->size_ = size; - DeviceMemPtr newbuf_addr = AddressOffset(mem_buf->device_addr_, size); - auto new_mem_buf = std::make_shared(newbuf_addr, kMemBufIdle, newbuf_size); - // Add map of new memory buf in the block - (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf); - // Add map of new idle memory buf - (void)global_idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf); -} - -bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr device_addr, const DynamicMemBlockPtr mem_block) { - MS_EXCEPTION_IF_NULL(device_addr); - MS_EXCEPTION_IF_NULL(mem_block); - return device_addr < mem_block->device_addr(); -} - -DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr device_addr) { - MS_EXCEPTION_IF_NULL(device_addr); - auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock); - if (iter != global_mem_block_list_.begin()) { - return *(--iter); - } - MS_LOG(ERROR) << "Can't find the mem_block of the device address[" << device_addr << "]."; - return nullptr; -} - -void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) { - MS_EXCEPTION_IF_NULL(device_addr); - auto mem_block = FindMemBlock(device_addr); - MS_EXCEPTION_IF_NULL(mem_block); - CombineMemBuf(mem_block, device_addr); -} - -void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr) { - MS_EXCEPTION_IF_NULL(mem_block); - MS_EXCEPTION_IF_NULL(device_addr); - auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); - if (iter == mem_block->block_all_mem_buf_map_.end()) { - MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; - } - auto mem_buf = iter->second; - MS_EXCEPTION_IF_NULL(mem_buf); - if (mem_buf->status_ != kMemBufUsed) { - MS_LOG(EXCEPTION) << "Find the mem_buf is not used, mem_buf_address[" << mem_buf->device_addr_ << "]."; - } - mem_buf->status_ = kMemBufIdle; - total_used_mem_statistics_ -= mem_buf->size_; - // Combine backward(combine the next_mem_buf to mem_buf) - auto next_iter = iter; - (void)next_iter++; - if (next_iter != mem_block->block_all_mem_buf_map_.end()) { - auto next_mem_buf = next_iter->second; - MS_EXCEPTION_IF_NULL(next_mem_buf); - if (next_mem_buf->status_ == kMemBufIdle) { - mem_buf->size_ += next_mem_buf->size_; - EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_); - (void)mem_block->block_all_mem_buf_map_.erase(next_iter); - } - } - // Combine forward(combine the mem_buf to prev_mem_buf) - bool forward_combine = false; - DynamicMemBufPtr prev_mem_buf; - if (iter != mem_block->block_all_mem_buf_map_.begin()) { - auto prev_iter = iter; - (void)prev_iter--; - prev_mem_buf = prev_iter->second; - MS_EXCEPTION_IF_NULL(prev_mem_buf); - if (prev_mem_buf->status_ == kMemBufIdle) { - EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_); - prev_mem_buf->size_ += mem_buf->size_; - (void)mem_block->block_all_mem_buf_map_.erase(iter); - forward_combine = true; - } - } - // Add map of new idle memory - if (forward_combine) { - (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf); - } else { - (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf); - } -} - -void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr device_addr) { - MS_EXCEPTION_IF_NULL(device_addr); - auto iter = global_idle_mem_buf_map_.equal_range(size); - while (iter.first != iter.second) { - MS_EXCEPTION_IF_NULL(iter.first->second); - // Remove map of the idle memory buf by size and device address - if (iter.first->second->device_addr_ == device_addr) { - (void)global_idle_mem_buf_map_.erase(iter.first); - return; - } - (void)iter.first++; - } - MS_LOG(ERROR) << "Can't find the size[" << size << "] and device address[" << device_addr << "] in the idle mem_buf."; -} - -void DynamicMemPoolBestFit::ReleaseDeviceRes() { - MS_LOG(INFO) << "The dynamic memmory pool total size is " << total_mem_statistics_ << ", total used size is " - << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << "."; - for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) { - auto device_addr = (*iter)->device_addr(); - if (device_addr != nullptr) { - if (!FreeDeviceMem(device_addr)) { - MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error."; - } - } - } -} - -void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() { - MS_LOG(INFO) << "Start dump dynamic memory pool info."; - DeviceAddrMapMemBuf mem_block_map; - DynamicMemBufPtr mem_buf; - size_t total_mem = 0; - size_t total_used_mem = 0; - size_t total_idle_mem1 = 0; - size_t total_idle_mem2 = 0; - // Dump the memory block info and memory buf info - MS_LOG(INFO) << "Dump all mem_block info: counts[" << global_mem_block_list_.size() << "]."; - for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) { - total_mem += (*iter)->size(); - mem_block_map = (*iter)->block_all_mem_buf_map_; - MS_LOG(INFO) << "MemBlock info: number[" << iter - global_mem_block_list_.begin() << "] mem_buf_counts[" - << mem_block_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size[" - << (*iter)->size() << "]."; - for (auto iter_mem_buf = mem_block_map.begin(); iter_mem_buf != mem_block_map.end(); ++iter_mem_buf) { - mem_buf = iter_mem_buf->second; - MS_EXCEPTION_IF_NULL(mem_buf); - if (mem_buf->status_ == kMemBufIdle) { - total_idle_mem1 += mem_buf->size_; - } else { - total_used_mem += mem_buf->size_; - } - MS_LOG(INFO) << "MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status[" - << mem_buf->status_ << "]."; - } - } - // Dump all the idle memory buf info - MS_LOG(INFO) << "Dump all idle mem_buf info: counts[" << global_idle_mem_buf_map_.size() << "]."; - for (auto iter_idle = global_idle_mem_buf_map_.begin(); iter_idle != global_idle_mem_buf_map_.end(); ++iter_idle) { - mem_buf = iter_idle->second; - MS_EXCEPTION_IF_NULL(mem_buf); - total_idle_mem2 += mem_buf->size_; - MS_LOG(INFO) << "Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_ << "] status[" - << mem_buf->status_ << "]."; - } - // Dump the memory statistical info - MS_LOG(INFO) << "Total allocated memory[" << total_mem << "], used memory[" << total_used_mem << "], idle memory[" - << total_idle_mem1 << "]."; - if (total_idle_mem1 != total_idle_mem2) { - MS_LOG(ERROR) << "Check error: the idle memory in the mem_block is not equal the global idle memory."; - } - if (total_mem != total_used_mem + total_idle_mem1) { - MS_LOG(ERROR) << "Check error: the the total memory is not equal the sum of used memory and idle memory."; - } - MS_LOG(INFO) << "Finish dump dynamic memory pool info."; -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc deleted file mode 100644 index d550b77bba..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ /dev/null @@ -1,433 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/mem_reuse/mem_reuse.h" -#include -#include -#include "pre_activate/mem_reuse/mem_reuse_checker.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace memreuse { -bool MemReuseUtil::InitDynamicOutputKernelRef() { - int index = util_index_; - auto kernel_cnodes = graph_->execution_order(); - if (kernel_cnodes.empty()) { - return true; - } - int kernel_out_ref_num = 0; - for (auto &kernel_cnode : kernel_cnodes) { -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().CheckSignalOps(kernel_cnode); -#endif - if (kernel_cnode == nullptr) { - return false; - } - auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode); - if (kernel_mod == nullptr) { - return false; - } - auto key = kernel_cnode.get(); - // for every apply_kernel to set new output - auto iter = kernel_output_refs_.find(key); - if (iter == kernel_output_refs_.end()) { - auto output_sizes = kernel_mod->GetOutputSizeList(); - KernelRefCountPtrList kernel_refs; - for (auto size : output_sizes) { - total_dy_size_ += size; - // do not MallocDynamicMem just record this - KernelRefCountPtr kernel_ref = std::make_shared(); - index++; - auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode); - kernel_ref->stream_id_ = curr_stream_id; - kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount); - kernel_refs.push_back(kernel_ref); - kernel_out_ref_num++; - total_refs_list_.push_back(kernel_ref); - } - if (!kernel_refs.empty()) { - kernel_output_refs_[key] = kernel_refs; - } - } - } - return true; -} - -bool MemReuseUtil::InitDynamicWorkspaceKernelRef() { - int WkIndex = util_index_; - auto kernel_cnodes = graph_->execution_order(); - if (kernel_cnodes.empty()) { - return true; - } - for (auto &kernel_cnode : kernel_cnodes) { - if (kernel_cnode == nullptr) { - return false; - } - auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode); - if (kernel_mod == nullptr) { - return false; - } - auto key = kernel_cnode.get(); - auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); - KernelRefCountPtrList workspace_kernel_refs; - for (auto size : workspace_sizes) { - total_workspace_size_ += size; - ++WkIndex; - KernelRefCountPtr workspace_ref = std::make_shared(); - workspace_ref->SetKernelRefCountInfo(WkIndex, size, kDynamicRefCount); - workspace_kernel_refs.push_back(workspace_ref); - // total wk ref - total_wk_ref_list_.push_back(workspace_ref); - } - if (!workspace_kernel_refs.empty()) { - // every key index wk_refs - kernel_workspace_refs_[key] = workspace_kernel_refs; - } - } - return true; -} - -bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - graph_ = graph; - is_all_nop_node_ = opt::IsAllNopNode(graph); - if (!InitDynamicOutputKernelRef()) { - MS_LOG(INFO) << "InitDynamicOutputKernelRef fail"; - return false; - } - if (!InitDynamicWorkspaceKernelRef()) { - MS_LOG(INFO) << "InitDynamicWorkspaceKernelRef fail"; - return false; - } - return true; -} - -// set longest worspace list && largest workspace sizes -void MemReuseUtil::SetWorkSpaceList() { - int max_list_size = 0; - std::vector total_sizes; - std::vector max_list; - auto kernel_cnodes = graph_->execution_order(); - for (auto &kernel_cnode : kernel_cnodes) { - MS_EXCEPTION_IF_NULL(kernel_cnode); - auto cnode_key = kernel_cnode.get(); - auto cnode_iter = kernel_workspace_refs_.find(cnode_key); - if (cnode_iter != kernel_workspace_refs_.end()) { - auto kernel_refs = cnode_iter->second; - std::vector current_list; - for (size_t i = 0; i < kernel_refs.size(); ++i) { - auto size = kernel_refs[i]->size_; - current_list.push_back(size); - } - if (max_list_size < SizeToInt(current_list.size())) { - max_list_size = SizeToInt(current_list.size()); - } - (void)std::copy(current_list.begin(), current_list.end(), std::back_inserter(total_sizes)); - } - } - sort(total_sizes.rbegin(), total_sizes.rend()); - max_list.resize(IntToSize(max_list_size)); - if (SizeToInt(total_sizes.size()) < max_list_size) { - MS_LOG(EXCEPTION) << "total workspace size is less than required max list size"; - } - max_list.assign(total_sizes.begin(), total_sizes.begin() + max_list_size); - for (auto &ma : max_list) { - total_reuseworkspace_size_ += ma; - } - max_workspace_size_ = max_list_size; - max_workspace_list_ = max_list; -} - -void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_def_ptr); - auto key = kernel.get(); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto ref_ptr = GetKernelInputRef(kernel, i); - if (ref_ptr != nullptr) { - if (ref_ptr->reftype() == kStaticRefCount) { - continue; - } else if (ref_ptr->reftype() == kDynamicRefCount) { - auto iter = kernel_def_ptr->inputs_.find(key); - if (iter == kernel_def_ptr->inputs_.end()) { - kernel_def_ptr->inputs_[key].push_back(ref_ptr); - } else { - iter->second.push_back(ref_ptr); - } - } - } - } -} - -void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_def_ptr); - auto key = kernel.get(); - auto iter = kernel_def_ptr->outputs_.find(key); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t k = 0; k < kernel_mod->GetOutputSizeList().size(); ++k) { - KernelRefCountPtr kernel_ref = kernel_output_refs_[key][k]; - if (iter == kernel_def_ptr->outputs_.end()) { - kernel_def_ptr->outputs_[key].push_back(kernel_ref); - } else { - iter->second.push_back(kernel_ref); - } - } -} - -void MemReuseUtil::SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_def_ptr); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto key = kernel.get(); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - if (kernel_workspace_refs_.find(key) != kernel_workspace_refs_.end()) { - auto wk_refs = kernel_workspace_refs_[key]; - if (i < wk_refs.size()) { - auto wk_ref = wk_refs[i]; - kernel_def_ptr->wk_space_[key].push_back(wk_ref); - } else { - MS_LOG(EXCEPTION) << "current index: " << i << " larger than wk_refs size " << wk_refs.size(); - } - } else { - MS_LOG(EXCEPTION) << "kernel_workspace_refs_ init error"; - } - } -} - -KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { - if (node == nullptr) { - MS_LOG(EXCEPTION) << "The node pointer is a nullptr."; - } - if (node->isa()) { - auto ak_node = node->cast(); - auto key = ak_node.get(); - MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, IntToSize(output_idx)); - return kernel_output_refs_[key][IntToSize(output_idx)]; - } - return nullptr; -} - -KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) { - if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " - << AnfAlgo::GetInputTensorNum(kernel); - } - auto input_node = kernel->input(input_idx + 1); - // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. - session::KernelWithIndex kernel_input; - if (is_all_nop_node_) { - // The graph does not remove the nop node. - kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); - } else { - // The graph removes the nop node. - kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); - } - if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { - MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; - } - auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second)); - return result; -} - -void MemReuseUtil::SetKernelDefMap() { - auto kernel_cnodes = graph_->execution_order(); - for (auto &kernel : kernel_cnodes) { - KernelDefPtr kernel_def_ptr = std::make_shared(); - kernel_def_ptr->set_kernel_name(AnfAlgo::GetCNodeName(kernel)); - kernel_def_ptr->set_scope_full_name(kernel->fullname_with_scope()); - kernel_def_ptr->set_stream_id(AnfAlgo::GetStreamId(kernel)); - SetInputMap(kernel, kernel_def_ptr.get()); - SetOutputMap(kernel, kernel_def_ptr.get()); - SetWkMap(kernel, kernel_def_ptr.get()); - auto key = kernel.get(); - kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]); - kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]); - kernel_def_ptr_list_.push_back(kernel_def_ptr); - kernel_map_[key] = kernel_def_ptr; - } - SetKernelDefInputs(); -} - -void MemReuseUtil::SetKernelDefInputs() { - for (const auto &kernel : graph_->execution_order()) { - MS_EXCEPTION_IF_NULL(kernel); - auto key = kernel.get(); - // find kernel_def according to cnode addr - auto iter = kernel_map_.find(key); - if (iter == kernel_map_.end()) { - MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init."; - } - auto kernel_def = iter->second; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto ref_ptr = GetKernelInputRef(kernel, i); - if (ref_ptr != nullptr) { - // set the inputs of this kernel_def - auto input_node = AnfAlgo::GetInputNode(kernel, i); - // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. - session::KernelWithIndex input; - if (is_all_nop_node_) { - // The graph does not remove the nop node. - input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); - } else { - // The graph removes the nop node. - input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); - } - if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { - MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; - } - auto input_key = (input.first).get(); - auto input_iter = kernel_map_.find(input_key); - if (input_iter == kernel_map_.end()) { - MS_LOG(EXCEPTION) << "kernel [" << (input.first)->fullname_with_scope() << "] is not init."; - } - kernel_def->InsertInputKernel(input_iter->second); - } - } - } -} - -void MemReuseUtil::SetReuseRefCount() { - auto kernels = graph_->execution_order(); - for (auto &kernel : kernels) { - auto key = kernel.get(); - for (auto &def : kernel_def_ptr_list_) { - auto iter = def->inputs_.find(key); - if (iter != def->inputs_.end()) { - for (auto &input : iter->second) { - input->ref_count_++; - input->ref_count_dynamic_use_++; - } - } - } - } -} - -void MemReuseUtil::SetSummaryNodesRefCount() { - bool summary_exist = graph_->summary_node_exist(); - if (!summary_exist) { - return; - } - - auto summary_nodes = graph_->summary_nodes(); - if (summary_nodes.empty()) { - return; - } - - for (auto &node_item : summary_nodes) { - auto node = node_item.second.first; - size_t index = IntToSize(node_item.second.second); - MS_LOG(INFO) << "set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index; - if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) { - KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index]; - kernel_ref->ref_count_ = kMaxRefCount; - kernel_ref->ref_count_dynamic_use_ = kMaxRefCount; - } else { - MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope(); - } - } -#ifdef MEM_REUSE_DEBUG - auto graph = *graph_; - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); -#endif -} - -void MemReuseUtil::SetGraphOutputRefCount() { - auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); - for (const auto &node : nodes) { - session::KernelWithIndex kernel_input; - if (is_all_nop_node_) { - // The graph does not remove the nop node. - kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); - } else { - // The graph removes the nop node. - kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - } - MS_EXCEPTION_IF_NULL(kernel_input.first); - if (!kernel_input.first->isa() || !AnfAlgo::IsRealKernel(kernel_input.first)) { - continue; - } - auto ak_node = kernel_input.first->cast(); - auto key = ak_node.get(); - auto iter = kernel_output_refs_.find(key); - if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) { - auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second]; - MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr); - kernel_ref_count_ptr->ref_count_ = kMaxRefCount; - kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount; - } - } -#ifdef MEM_REUSE_DEBUG - auto graph = *graph_; - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); -#endif -} - -void MemReuseUtil::ResetDynamicUsedRefCount() { - for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) { - for (auto &ref_count : iter->second) { - MS_EXCEPTION_IF_NULL(ref_count); - ref_count->ref_count_dynamic_use_ = ref_count->ref_count_; - } - } -} - -void MemReuseUtil::SetAllInfo(KernelGraph *graph) { - if (!InitDynamicKernelRef(graph)) { - MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; - } - SetKernelDefMap(); - SetReuseRefCount(); - SetSummaryNodesRefCount(); - SetWorkSpaceList(); -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); -#endif -} - -uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const { - auto key = node.get(); - auto iter = kernel_output_refs_.find(key); - uint8_t *ptr = nullptr; - if (iter != kernel_output_refs_.end()) { - if (index >= iter->second.size()) { - MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]"; - } - auto output_ref = iter->second[index]; - ptr = mem_base_ + output_ref->offset_; - } else { - MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs"; - } - return ptr; -} - -uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const { - auto key = node.get(); - auto iter = kernel_workspace_refs_.find(key); - uint8_t *ptr = nullptr; - if (iter != kernel_workspace_refs_.end()) { - if (index >= iter->second.size()) { - MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]"; - } - auto wk_ref = iter->second[index]; - ptr = mem_base_ + wk_ref->offset_; - } - return ptr; -} -} // namespace memreuse -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h deleted file mode 100644 index 37281a7128..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ -#include -#include -#include -#include "pre_activate/mem_reuse/kernel_refcount.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "kernel/tbe/tbe_utils.h" -using mindspore::kernel::tbe::TbeUtils; -namespace mindspore { -namespace memreuse { -static constexpr int kMaxRefCount = 9999; -static constexpr size_t kDefaultMemAlignSize = 512; -static constexpr size_t kAttAlignSize = 31; -static constexpr int kInvalidIndex = -2; - -using KernelDefPtrMaps = std::vector; -using KernelRefs = std::map; - -using KernelGraph = mindspore::session::KernelGraph; - -class MemReuseUtil { - public: - KernelRefs kernel_output_refs_; - KernelRefCountPtrList total_refs_list_; - KernelRefCountPtrList total_wk_ref_list_; - KernelRefs kernel_workspace_refs_; - MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {} - ~MemReuseUtil() { - if (graph_ != nullptr) { - graph_ = nullptr; - } - MS_LOG(INFO) << "Total Dynamic Memory Size: " << total_dy_size_; - MS_LOG(INFO) << "Total WorkSpace Memory Size: " << total_workspace_size_; - MS_LOG(INFO) << "Total Reused WorkSpafce Memory Size: " << total_reuseworkspace_size_; - } - - void SetAllInfo(KernelGraph *graph); - bool InitDynamicOutputKernelRef(); - bool InitDynamicWorkspaceKernelRef(); - bool InitDynamicKernelRef(const KernelGraph *graph); - void SetWorkSpaceList(); - void SetKernelDefMap(); - void SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); - void SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); - void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); - void SetKernelDefInputs(); - void SetReuseRefCount(); - void SetSummaryNodesRefCount(); - // Set the reference count of graph output specially. - void SetGraphOutputRefCount(); - // Reset the dynamic used reference count by ref_count_. - void ResetDynamicUsedRefCount(); - - KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx); - KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx); - KernelRefCountPtrList total_refs_list() const { return total_refs_list_; } - KernelRefCountPtrList total_wk_ref_list() const { return total_wk_ref_list_; } - KernelDefPtrMaps kernel_def_ptr_list() const { return kernel_def_ptr_list_; } - int max_workspace_size() const { return max_workspace_size_; } - std::vector max_workspace_list() const { return max_workspace_list_; } - void set_total_refs_list(const KernelRefCountPtrList &total_refs_list) { total_refs_list_ = total_refs_list; } - void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) { - kernel_def_ptr_list_ = kernel_def_ptr_list; - } - void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; } - uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const; - uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const; - - private: - int util_index_; - const KernelGraph *graph_; - bool is_all_nop_node_; - KernelRefCountPtrList ref_list_; - KernelDefPtrMaps kernel_def_ptr_list_; - KernelRefCountPtrList last_ref_list_; - int max_workspace_size_ = 0; - std::vector max_workspace_list_; - size_t total_dy_size_ = 0; - size_t total_workspace_size_ = 0; - size_t total_reuseworkspace_size_ = 0; - uint8_t *mem_base_{nullptr}; - // kernel_map_: key is the AnfNodePtr addr, value is the KernelDef - std::map kernel_map_; -}; -using MemReuseUtilPtr = std::shared_ptr; -} // namespace memreuse -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc deleted file mode 100644 index b36147f9bb..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc +++ /dev/null @@ -1,368 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/mem_reuse/mem_reuse_allocator.h" -#include "pre_activate/mem_reuse/mem_reuse.h" -#include "pre_activate/mem_reuse/mem_reuse_checker.h" - -namespace mindspore { -namespace memreuse { -void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) { - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - set_tensor_ptr_list(mem_reuse_util_ptr->total_refs_list()); - set_workspace_ptr_list(mem_reuse_util_ptr->total_wk_ref_list()); - set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list()); - // check info Correctness - for (auto &tensor : tensor_ptr_list_) { - tensor->size_ = AlignMemorySize(tensor->size_); - } - // align wk size to 512 && refcount == 1 - for (auto &wk : wk_tensor_list_) { - wk->size_ = AlignMemorySize(wk->size_); - wk->ref_count_ = 1; - } -} - -void BestFitMemReuse::InitKernelDependence() { - for (const auto &kernel : op_ptr_list_) { - std::set front; - std::queue to_visit; - to_visit.push(kernel); - // find all kernels before current kernel - while (!to_visit.empty()) { - auto curr = to_visit.front(); - to_visit.pop(); - if (front.count(curr)) { - continue; - } - front.insert(curr); - auto iter = kernel_front_map_.find(curr); - if (iter != kernel_front_map_.end()) { - auto visited_front = iter->second; - front.insert(visited_front.begin(), visited_front.end()); - continue; - } - for (const auto &input : curr->input_kernels()) { - to_visit.push(input); - } - } - kernel_front_map_[kernel] = front; - } -} - -bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const KernelDefPtr &kernel_prev) { - // determine whether the kernel_curr can reuse kernel_prev's output tensor membuf - MS_EXCEPTION_IF_NULL(kernel_curr); - MS_EXCEPTION_IF_NULL(kernel_prev); - auto curr_stream_id = kernel_curr->stream_id(); - auto prev_stream_id = kernel_prev->stream_id(); - if (curr_stream_id == prev_stream_id) { - return true; - } - auto iter = kernel_front_map_.find(kernel_curr); - if (iter == kernel_front_map_.end()) { - MS_LOG(EXCEPTION) << kernel_curr->scope_full_name() << " is not init."; - } - auto kernel_curr_front = iter->second; - return kernel_curr_front.count(kernel_prev); -} - -void BestFitMemReuse::AssignNodeOutputOffset() { - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { - size_t index = GetTensorIndex(tensor_idx); - auto tensor_desc = tensor_ptr_list_[index]; - MS_EXCEPTION_IF_NULL(tensor_desc); - auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_); - if (!reusable_membuf_map.empty()) { - auto membuf_index = reusable_membuf_map.begin()->second; - // find the best suitable membuf in membuf list, and reuse it - ReuseExistMembuf(tensor_desc.get(), membuf_index, kDynamicMem); - } else { - // no membuf can reuse, add new membuf after the membuf_ptr_list - AddNewMembufPtr(tensor_desc.get(), kDynamicMem); -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().IsAddNewMembuf_ = true; -#endif - } - } -} - -void BestFitMemReuse::AssignNodeWorkspaceOffset() { - for (auto &wk_idx : current_kernel_->GetWorkspaceRefIndexs()) { - size_t index = GetWorkspaceIndex(wk_idx); - auto wk_ref = wk_tensor_list_[index]; - MS_EXCEPTION_IF_NULL(wk_ref); - auto re_wk_membuf_map = GetReusableMembufMap(wk_ref->size_); - if (!re_wk_membuf_map.empty()) { - auto membuf_index = re_wk_membuf_map.begin()->second; - ReuseExistMembuf(wk_ref.get(), membuf_index, kWorkspaceMem); - } else { - AddNewMembufPtr(wk_ref.get(), kWorkspaceMem); - } - } -} - -void BestFitMemReuse::ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag) { - MS_EXCEPTION_IF_NULL(tensor_desc); - CheckMembufIndx(membuf_index); - auto membuf = membuf_ptr_list_[membuf_index]; - MS_EXCEPTION_IF_NULL(membuf); - // first to split && then update membuf_info - if (IsSplit(tensor_desc->size_, membuf->size_)) { - // split the membuf, and insert a new membuf after this membuf - SplitMembuf(tensor_desc, membuf_index); - } - // update membuf status, and set tensor offset - UpdateMembufInfo(tensor_desc, membuf.get(), flag); -} - -std::map BestFitMemReuse::GetReusableMembufMap(size_t tensor_size) { - std::map size_map; - for (size_t i = 0; i < membuf_ptr_list_.size(); ++i) { - auto membuf = membuf_ptr_list_[i]; - auto index = i; - bool is_membuf_ok = membuf->status_ == kUnused && membuf->size_ >= tensor_size; - if (is_membuf_ok && IsUsable(current_kernel_, membuf->used_kernel_)) { - (void)size_map.insert(std::make_pair(membuf->size_, index)); - break; - } - } - return size_map; -} - -void BestFitMemReuse::UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag) { - MS_EXCEPTION_IF_NULL(tensor_desc); - MS_EXCEPTION_IF_NULL(membuf); - auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); - membuf->status_ = kReused; - membuf->index_ = real_index; - membuf->used_kernel_ = current_kernel_; - tensor_desc->offset_ = membuf->offset_; -} - -bool BestFitMemReuse::IsSplit(size_t tensor_size, size_t membuf_size) const { return tensor_size < membuf_size; } - -void BestFitMemReuse::SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index) { - MS_EXCEPTION_IF_NULL(tensor_desc); - CheckMembufIndx(membuf_index); - auto membuf = membuf_ptr_list_[membuf_index]; - MS_EXCEPTION_IF_NULL(membuf); - auto bias = membuf->size_ - tensor_desc->size_; - membuf->size_ = tensor_desc->size_; - // to check if spilt membuf can be merge - auto new_membuf = - std::make_shared(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex, current_kernel_); - (void)membuf_ptr_list_.insert(membuf_ptr_list_.begin() + SizeToInt(membuf_index + 1), new_membuf); -} - -void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) { - MS_EXCEPTION_IF_NULL(tensor_desc); - size_t membuf_offset = 0; - if (!membuf_ptr_list_.empty()) { - membuf_offset = membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_; - } - auto membuf_size = tensor_desc->size_; - auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); - auto membuf = std::make_shared(kReused, membuf_size, membuf_offset, real_index, current_kernel_); - membuf_ptr_list_.push_back(membuf); - tensor_desc->offset_ = membuf_offset; -} - -void BestFitMemReuse::UpdateNodeInputAndMembuf() { - // process node input tensor - for (const auto &tensor_idx : current_kernel_->GetInputRefIndexs()) { - size_t tensor_index = GetTensorIndex(tensor_idx); - auto tensor_desc = tensor_ptr_list_[tensor_index]; - MS_EXCEPTION_IF_NULL(tensor_desc); - tensor_desc->ref_count_--; - if (tensor_desc->ref_count_ == 0) { - ReleaseMembuf(tensor_index, kDynamicMem); - } else if (tensor_desc->ref_count_ < 0) { - MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_ - << " check error"; - } - } -} - -void BestFitMemReuse::ReleaseNodeUnusedOutput() { - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { - size_t tensor_index = GetTensorIndex(tensor_idx); - auto tensor_desc = tensor_ptr_list_[tensor_index]; - MS_EXCEPTION_IF_NULL(tensor_desc); - if (tensor_desc->ref_count_ == 0) { - ReleaseMembuf(tensor_index, kDynamicMem); - } else if (tensor_desc->ref_count_ < 0) { - MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_ - << " check error"; - } - } -} - -void BestFitMemReuse::ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr) { - for (auto &workspace_index : kernel_def_ptr->GetWorkspaceRefIndexs()) { - size_t index = GetWorkspaceIndex(workspace_index); - auto wk_tensor = wk_tensor_list_[index]; - wk_tensor->ref_count_--; - if (wk_tensor->ref_count_ == 0) { - ReleaseMembuf(index, kWorkspaceMem); - } else if (wk_tensor->ref_count_ < 0) { - MS_LOG(EXCEPTION) << "tensor: " << wk_tensor->index_ << " refcount: " << wk_tensor->ref_count_ << " check error"; - } - } -} - -void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) { - if (membuf_ptr_list_.empty()) { - return; - } - auto real_index = GetRealIndex(tensor_index, flag); - auto membuf_iter = std::find_if(membuf_ptr_list_.begin(), membuf_ptr_list_.end(), - [real_index](const MembufPtr &membuf) { return membuf->index_ == real_index; }); - if (membuf_iter == membuf_ptr_list_.end()) { - return; - } - auto membuf = (*membuf_iter); - MS_EXCEPTION_IF_NULL(membuf); - membuf->status_ = kUnused; - if (membuf_iter != membuf_ptr_list_.end() - 1) { - auto next_iter = membuf_iter + 1; - auto membuf_next = (*next_iter); - MS_EXCEPTION_IF_NULL(membuf_next); - if (membuf_next->status_ == kUnused) { - bool is_merge = IsUsable(current_kernel_, membuf_next->used_kernel_); - if (is_merge) { - membuf->size_ += membuf_next->size_; - (void)membuf_ptr_list_.erase(next_iter); - } - } - } - if (membuf_iter != membuf_ptr_list_.begin()) { - auto prev_iter = membuf_iter - 1; - auto membuf_prev = (*prev_iter); - MS_EXCEPTION_IF_NULL(membuf_prev); - if (membuf_prev->status_ == kUnused) { - bool is_merge = IsUsable(current_kernel_, membuf_prev->used_kernel_); - if (is_merge) { - membuf->size_ += membuf_prev->size_; - membuf->offset_ = membuf_prev->offset_; - (void)membuf_ptr_list_.erase(prev_iter); - } - } - } -} - -size_t BestFitMemReuse::AlignMemorySize(size_t size) const { - // memory size 512 align - return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize; -} - -size_t BestFitMemReuse::GetAllocatedSize() { - size_t AllocatedSize = kTotalSize; - if (membuf_ptr_list_.empty()) { - return AllocatedSize; - } - AllocatedSize = membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_; - MS_LOG(INFO) << "MemReuse Allocated Dynamic Size: " << AllocatedSize; - return AllocatedSize; -} - -bool BestFitMemReuse::IsRelease() { - // unable_used_node include the node type that output tensor cannot be released, - // even if its refcount is equal to zero. - std::unordered_set unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(), - prim::kPrimFusedBatchNorm->name(), - prim::kPrimFusedBatchNormGrad->name()}; - return unable_used_node.find(current_kernel_->kernel_name()) == unable_used_node.end(); -} - -size_t BestFitMemReuse::GetTensorIndex(int index) const { - if (index < 0 || IntToSize(index) >= tensor_ptr_list_.size()) { - MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); - MS_LOG(EXCEPTION) << "invalid tensor index"; - } - return IntToSize(index); -} - -size_t BestFitMemReuse::GetWorkspaceIndex(int index) const { - if (index < 0 || IntToSize(index) >= wk_tensor_list_.size()) { - MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); - MS_LOG(EXCEPTION) << "invalid tensor index"; - } - return IntToSize(index); -} - -int BestFitMemReuse::GetRealIndex(size_t index, int flag) const { - if (flag == kDynamicMem) { - return SizeToInt(index); - } else if (flag == kWorkspaceMem) { - return kWorkspaceIndexFactor * SizeToInt(index + 1); - } else { - MS_LOG(EXCEPTION) << "flag " << flag << " is invalid"; - } -} - -void BestFitMemReuse::CheckMembufIndx(size_t membuf_index) const { - if (membuf_index >= membuf_ptr_list_.size()) { - MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); - MS_LOG(EXCEPTION) << "invalid membuf index: " << membuf_index << ", real size: " << membuf_ptr_list_.size(); - } -} - -void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) { - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - InitMemReuseInfo(mem_reuse_util_ptr); - InitKernelDependence(); - KernelDefPtr pre_op = nullptr; -#ifdef MEM_REUSE_DEBUG - size_t op_num = 0; -#endif - for (const auto &op_def_ptr : op_ptr_list_) { - current_kernel_ = op_def_ptr; - // releas pre_op_def - if (pre_op != nullptr) { - ReleasePreNodeWorkspace(pre_op.get()); - } - MemReuseChecker::GetInstance().IsAddNewMembuf_ = false; - // process node output tensor - AssignNodeOutputOffset(); -#ifdef MEM_REUSE_DEBUG - if (MemReuseChecker::GetInstance().IsAddNewMembuf_) { - MemReuseChecker::GetInstance().SetAddNewMembuInfos(op_def_ptr.get(), membuf_ptr_list_, op_num); - } -#endif - // deal with current op'workspace - AssignNodeWorkspaceOffset(); - pre_op = op_def_ptr; - // update node input tensor refcount, and membuf list status - UpdateNodeInputAndMembuf(); - // check node output tensor which refcount is equal to zero - if (IsRelease()) { - ReleaseNodeUnusedOutput(); - } -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_); - ++op_num; -#endif - } -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().ExportMembufInfoIR(); - MemReuseChecker::GetInstance().ExportAddNewMmebufIR(); - MemReuseChecker::GetInstance().set_kernel_front_map(kernel_front_map_); - MemReuseChecker::GetInstance().ExportKernelDependence(); -#endif -} -} // namespace memreuse -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h deleted file mode 100644 index 9aeda05dc3..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "pre_activate/mem_reuse/kernel_refcount.h" -#include "pre_activate/mem_reuse/mem_reuse.h" - -namespace mindspore { -namespace memreuse { -static constexpr int kWorkspaceIndexFactor = -1000; -static constexpr int kDynamicMem = -1; -static constexpr int kWorkspaceMem = 1; -static constexpr size_t kTotalSize = 0; -enum Status { kUnused, kReused }; -class Membuf { - public: - Membuf() = default; - Membuf(Status status, size_t size, size_t offset, int index, const KernelDefPtr &used_kernel) - : status_(status), size_(size), offset_(offset), index_(index), used_kernel_(used_kernel) {} - ~Membuf() = default; - // Memory block status flags - Status status_ = kUnused; - size_t size_{0}; - size_t offset_{0}; - // Store the tensor index stored in this memory block at a certain moment - int index_{0}; - KernelDefPtr used_kernel_; -}; -using MembufPtr = std::shared_ptr; - -class BestFitMemReuse { - public: - BestFitMemReuse() = default; - ~BestFitMemReuse() { membuf_ptr_list_.clear(); } - /** - * Init all information need by memory reuse - * @param mem_reuse_util_ptr, initialize in the memreuse.cc - */ - void InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr); - void CheckMembufIndx(size_t check_idx) const; - void AssignNodeWorkspaceOffset(); - void ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr); - /** - * Assign output tensor memory offset of current kernel - */ - void AssignNodeOutputOffset(); - /** - * Update input tensor's status of current kernel, and the status of membuf used by current kernel - */ - void UpdateNodeInputAndMembuf(); - /** - * Check whether to release the kernel output tensor which refcount is equal to zero - */ - void ReleaseNodeUnusedOutput(); - /** - * Reuse the exist membuf if possible - * @param tensor_desc, the output tensor of current kernel - * @param membuf_index, the index of membuf to be reused - * @param flag - */ - void ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag); - /** - * Get the membuf that can be reused - * @param tensor_size, the size of the tensor ready to assign memory offset - * @return membuf map, key: the membuf size, value: the membuf index - */ - std::map GetReusableMembufMap(size_t tensor_size); - /** - * Update the status of the reused memory block - * @param tensor_desc, the tensor ready to assign memory - * @param membuf, the membuf to be reused - * @param flag, distinguish dynamic memory and workspace - */ - void UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag); - // If the size of the memory block is greater than the size of the tensor, split the extra memory - void SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index); - // Determine if the memory block needs to be split - bool IsSplit(size_t tensor_size, size_t membuf_size) const; - // If there is no memory block that can be reused, add a new memory block at the end - void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag); - // Merge unused membuf - void ReleaseMembuf(size_t tensor_index, int flag); - // Memory address alignment 512 - size_t AlignMemorySize(size_t size) const; - int GetRealIndex(size_t index, int flag = kDynamicMem) const; - size_t GetTensorIndex(int index) const; - size_t GetWorkspaceIndex(int index) const; - // Memory reuse main program entry - void Reuse(const MemReuseUtil *mem_reuse_util_ptr); - // Get the total memory that needs to be applied eventually - size_t GetAllocatedSize(); - // return false, when the node output cannot be released - bool IsRelease(); - /** - * determine if the kernel_curr can reuse the output tensor add of kernel_prev - * @param kernel_curr, current kernel - * @param kernel_prev, the membuf used by this kernel - * @return bool - */ - bool IsUsable(const KernelDefPtr &kernel_curr, const KernelDefPtr &kernel_prev); - /** - * init the dependence of all kernels in the graph - */ - void InitKernelDependence(); - // set tensor_def and op_def - void set_tensor_ptr_list(const std::vector &tensor_ptr_list) { - tensor_ptr_list_ = tensor_ptr_list; - } - void set_workspace_ptr_list(const std::vector &workspace_ptr_list) { - wk_tensor_list_ = workspace_ptr_list; - } - void set_op_ptr_list(const std::vector &op_ptr_list) { op_ptr_list_ = op_ptr_list; } - - private: - KernelDefPtr current_kernel_; - // Save all tensor information - std::vector tensor_ptr_list_; - std::vector wk_tensor_list_; - // Save all op information, including input and output tensor index - std::vector op_ptr_list_; - // Memory block information sequence, temporary variables - std::vector membuf_ptr_list_; - // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def - std::map> kernel_front_map_; -}; -} // namespace memreuse -} // namespace mindspore -#endif // #define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc deleted file mode 100644 index 5cd6a5f50e..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc +++ /dev/null @@ -1,569 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/mem_reuse/mem_reuse_checker.h" -#include -#include -#include -#include - -namespace mindspore { -namespace memreuse { -MemReuseChecker &MemReuseChecker::GetInstance() { - static MemReuseChecker instance; - return instance; -} - -void MemReuseChecker::CheckSignalOps(const CNodePtr &c_node) { - std::string node_name = AnfAlgo::GetCNodeName(c_node); - if (node_name == kSend || node_name == kRecv) { - MS_LOG(INFO) << "MemReuseChecker check op_name of Send or Send"; - // get op's info && check - MS_LOG(INFO) << "op: " << node_name << " in_num: " << AnfAlgo::GetInputTensorNum(c_node) - << " out_num: " << AnfAlgo::GetOutputTensorNum(c_node); - } -} - -void MemReuseChecker::CheckWorkSpace(const std::vector &max_list) { - for (auto &ma : max_list) { - total_re_wkspe_size_checker_ += ma; - } -} - -void MemReuseChecker::CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx) { - auto key = c_node.get(); - auto iter = kernel_refs.find(key); - auto node_name = AnfAlgo::GetCNodeName(c_node); - if (iter == kernel_refs.end()) { - MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor, node: " << c_node->DebugString() - << " output index: " << output_idx; - } - if (output_idx >= iter->second.size()) { - MS_LOG(INFO) << "invalid cnode: " << c_node->fullname_with_scope().c_str(); - MS_LOG(EXCEPTION) << "The index: " << output_idx - << " is out of the size of kernel_output_refs_:" << iter->second.size(); - } -} - -int64_t MemReuseChecker::CalculOriInput(const KernelGraph *graph) const { - MS_EXCEPTION_IF_NULL(graph); - int64_t static_input_size = 0; - for (auto &item : graph->inputs()) { - if (!item->isa()) { - continue; - } - auto output_size = AnfAlgo::GetOutputTensorNum(item); - for (size_t index = 0; index < output_size; index++) { - TypeId ou_type = AnfAlgo::GetOutputDeviceDataType(item, index); - // parameter has not init by a cnode - if (ou_type == kTypeUnknown) { - ou_type = AnfAlgo::GetOutputInferDataType(item, index); - } - size_t type_size = GetTypeByte(TypeIdToType(ou_type)); - std::vector shape = AnfAlgo::GetOutputDeviceShape(item, index); - size_t tensor_size = - shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - auto checker_size = SizeToLong(tensor_size); - static_input_size += checker_size; - } - } - return static_input_size; -} - -int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const { - MS_EXCEPTION_IF_NULL(graph); - int64_t static_value_size = 0; - for (auto &value_node : graph->graph_value_nodes()) { - MS_EXCEPTION_IF_NULL(value_node); - auto &node_value = value_node->value(); - MS_EXCEPTION_IF_NULL(node_value); - auto tensor = node_value->cast(); - if (tensor == nullptr) { - continue; - } - size_t tensor_size = tensor->data().nbytes(); - auto checker_size = SizeToLong(tensor_size); - static_value_size += checker_size; - } - return static_value_size; -} - -int64_t MemReuseChecker::CalculOriStatic(KernelGraph *graph) const { - // cal static inputs - auto static_input_size = CalculOriInput(graph); - // do not calcul outpput size - auto statica_value_size = CalculOriValue(graph); - auto total_ori_static_size = static_input_size + statica_value_size; - return total_ori_static_size; -} - -int64_t MemReuseChecker::CalculOriDy(const KernelGraph *graph) const { - MS_EXCEPTION_IF_NULL(graph); - int64_t ori_dy_size = 0; - auto kerenls = graph->execution_order(); - for (auto &kernel : kerenls) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (auto &dy_size : kernel_mod->GetOutputSizeList()) { - auto checker_size = SizeToLong(dy_size); - ori_dy_size += checker_size; - } - } - return ori_dy_size; -} - -int64_t MemReuseChecker::CalculOriWk(const KernelGraph *graph) const { - MS_EXCEPTION_IF_NULL(graph); - int64_t ori_wk_size = 0; - auto kerenls = graph->execution_order(); - for (auto &kernel : kerenls) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (auto &wk_size : kernel_mod->GetWorkspaceSizeList()) { - auto checker_size = SizeToLong(wk_size); - ori_wk_size += checker_size; - } - } - return ori_wk_size; -} - -std::string MemReuseChecker::GetSplitName(const std::string &scope_name) const { - auto indx = scope_name.rfind(kSplitC); - if (indx == std::string::npos) { - return scope_name; - } else { - if (indx < scope_name.size() - 1) { - auto split_name = scope_name.substr(indx + 1); - return split_name; - } - return scope_name; - } -} - -void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, - const KernelDefPtrMaps &kernel_def_ptr_list, KernelGraph *graph) { - total_ori_static_size_ = CalculOriStatic(graph); - total_ori_input_size_ = CalculOriInput(graph); - total_ori_value_size_ = CalculOriValue(graph); - total_ori_dy_size_ = CalculOriDy(graph); - total_ori_wkspace_size_ = CalculOriWk(graph); - std::string graph_id = std::to_string(graph->graph_id()); - std::string filename = "./memreuse_" + graph_id + ".ir"; - std::ofstream ofs(filename); - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; - return; - } - ofs << "all_tensor_refs:\n"; - ofs << "index:" - << "\tsize:" - << "\trefcount:\n"; - for (auto &ref : total_refs_list) { - ofs << "%" << ref->index_ << "T" - << "\t" - << "#" << ref->size_ << "S" - << "\t" << ref->ref_count_ << "C" - << "\n"; - } - ofs << "kernel_def exc_order:\n"; - int def_idx = 0; - for (auto &def : kernel_def_ptr_list) { - ExportMemOpIr(def.get(), ofs, def_idx); - def_idx++; - } - ofs.close(); -} - -void MemReuseChecker::ExportKernelDependence() { - std::string filename = "./memreuse_dependence.ir"; - std::ofstream ofs(filename); - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; - return; - } - size_t i = 0; - for (const auto &kernel_front : kernel_front_map_) { - auto kernel = kernel_front.first; - auto front = kernel_front.second; - ofs << "[" << i++ << "] " << kernel->scope_full_name() << "\n"; - for (const auto &node : front) { - ofs << node->scope_full_name() << "\n"; - } - ofs << "\n\n"; - } - - ofs.close(); -} - -bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph) { - // set real graph output node to be special who's refcount equal kMaxRefCount - for (const auto &output : graph->outputs()) { - MS_EXCEPTION_IF_NULL(output); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { - if (output->isa()) { - auto cnode = output->cast(); - auto input_node = cnode->input(i + 1); - auto kernel_input_with_idx = AnfAlgo::VisitKernel(input_node, 0); - auto kernel_input = kernel_input_with_idx.first; - MS_EXCEPTION_IF_NULL(kernel_input); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel_input); - if (kernel_mod == nullptr) { - continue; - } - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - continue; - } - for (size_t j = 0; j < output_sizes.size(); ++j) { - if (!AnfAlgo::OutputAddrExist(kernel_input, j)) { - return false; - } - } - } - } - } - return true; -} - -void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) { - auto scope_name = def->scope_full_name(); - std::string split_name = GetSplitName(scope_name); - ofs << "$" << def_idx << "\t" << split_name << "\t"; - ofs << "inputs["; - for (auto &in : def->inputs_) { - for (auto &in_ref : in.second) { - ofs << "%" << in_ref->index_ << "T" - << ","; - } - } - ofs << "]"; - ofs << "\toutpus["; - for (auto &ou : def->outputs_) { - for (auto &ou_ref : ou.second) { - ofs << "%" << ou_ref->index_ << "T" - << ","; - } - } - ofs << "]"; - ofs << "\tstreamID[" - << "@" << def->stream_id() << "]\n"; -} - -void MemReuseChecker::ExportNormalTensorIR(std::ofstream &ofs) { - ofs << "all_tensor_refs:\n"; - ofs << "index:" - << "\tsize:" - << "\trefcount:\n"; - size_t ou_idx = 0; - for (auto &ou : nor_output_tensors_) { - ofs << "%" << ou_idx << "T" - << "\t" - << "#" << nor_tensor_sizes_[ou_idx] << "S" - << "\t"; - auto iter_ref = ptr_refs_.find(ou); - if (iter_ref != ptr_refs_.end()) { - ofs << iter_ref->second << "C" - << "\n"; - } else { - MS_LOG(EXCEPTION) << "can not find refs for output"; - } - ou_idx++; - } - ofs << "kernel_def exc_order:\n"; -} - -int MemReuseChecker::GetTensorIdx(const void *in) const { - auto iter = ptr_idx_.find(in); - if (iter == ptr_idx_.end()) { - return kInvalidIndex; - } else { - return SizeToInt(iter->second); - } -} - -void MemReuseChecker::ExportNormalOpIr(const std::vector &cnodes) { - std::ofstream ofs("./normal_mem.ir"); - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file failed!"; - return; - } - ExportNormalTensorIR(ofs); - size_t node_idx = 0; - for (const auto &node : cnodes) { - MS_EXCEPTION_IF_NULL(node); - ofs << "$" << node_idx << "\t" << GetSplitName(node->fullname_with_scope()) << "\t"; - std::vector in_idx; - auto iter = node_ins_.find(node.get()); - if (iter != node_ins_.end()) { - for (auto &in : iter->second) { - if (GetTensorIdx(in) != kInvalidIndex) { - in_idx.push_back(GetTensorIdx(in)); - } - } - } - std::vector ou_idx; - iter = node_ous_.find(node.get()); - if (iter != node_ous_.end()) { - for (auto &ou : iter->second) { - if (GetTensorIdx(ou) != kInvalidIndex) { - ou_idx.push_back(GetTensorIdx(ou)); - } - } - } - ofs << "inputs["; - for (auto idx : in_idx) { - bool has_in_ou = std::any_of(ou_idx.begin(), ou_idx.end(), [idx](int odx) { return idx == odx; }); - if (!has_in_ou) { - ofs << "%" << idx << "T,"; - } - } - ofs << "]\toutpus["; - for (auto odx : ou_idx) { - ofs << "%" << odx << "T,"; - } - ofs << "]\tstreamID[@" << AnfAlgo::GetStreamId(node) << "]\n"; - node_idx++; - } - ofs.close(); -} - -void MemReuseChecker::SetTesnorFromAndToInfo(const KernelDef *op_def) { - auto split_name = GetSplitName(op_def->scope_full_name()); - for (auto &in : op_def->inputs_) { - auto in_tensors = in.second; - for (auto &tensor : in_tensors) { - auto indx = tensor->index_; - tensor_to_[indx].push_back(split_name); - } - } - for (auto &ou : op_def->outputs_) { - auto ou_tensors = ou.second; - for (auto &tensor : ou_tensors) { - auto indx = tensor->index_; - tensor_from_[indx].push_back(split_name); - } - } -} - -void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) { - const auto &cnodes = graph->execution_order(); - for (const auto &node : cnodes) { - std::vector curr_ous; - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) { - auto it = AnfAlgo::GetOutputAddr(node, i); - MS_EXCEPTION_IF_NULL(it); - auto ptr = it->GetPtr(); - nor_output_tensors_.push_back(ptr); - nor_tensor_sizes_.push_back(it->GetSize()); - curr_ous.push_back(it->GetPtr()); - } - (void)node_ous_.insert(std::make_pair(node.get(), curr_ous)); - std::vector curr_ins; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { - if (i + 1 >= node->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index: " << i - << " is larger than input number: " << AnfAlgo::GetInputTensorNum(node); - } - auto real_input_index = AnfAlgo::GetRealInputIndex(node, i); - auto input = node->input(real_input_index + 1); - MS_EXCEPTION_IF_NULL(input); - auto kernel_with_index = AnfAlgo::VisitKernel(input, 0); - if (kernel_with_index.first->isa()) { - continue; - } - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(node, real_input_index); - MS_EXCEPTION_IF_NULL(device_address); - nor_input_tensors_.push_back(device_address->GetPtr()); - curr_ins.push_back(device_address->GetPtr()); - } - (void)node_ins_.insert(std::make_pair(node.get(), curr_ins)); - } - size_t ou_idx = 0; - for (const auto &ou : nor_output_tensors_) { - (void)ptr_idx_.insert(std::make_pair(ou, ou_idx)); - (void)ptr_refs_.insert(std::make_pair(ou, 0)); - ou_idx++; - } - for (const auto &in : nor_input_tensors_) { - if (ptr_idx_.find(in) != ptr_idx_.end()) { - if (ptr_refs_.find(in) != ptr_refs_.end()) { - auto iter = ptr_refs_.find(in); - (iter->second)++; - } else { - MS_LOG(EXCEPTION) << "ptr_refs is not equal to ptr_idx"; - } - } - } - ExportNormalOpIr(cnodes); -} - -void MemReuseChecker::SetMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list) { - std::vector curr_mem_infos; - for (const auto &mem : membuf_ptr_list) { - auto mem_checker = std::make_shared(mem->status_, mem->size_, mem->offset_, mem->index_, mem->used_kernel_); - curr_mem_infos.push_back(mem_checker); - } - membuf_all_infos_.push_back(curr_mem_infos); - auto split_name = GetSplitName(op_def->scope_full_name()); - all_split_names_.push_back(split_name); - SetTesnorFromAndToInfo(op_def); -} - -void MemReuseChecker::SetAddNewMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list, - size_t op_idx) { - std::vector add_new_curr_mem; - - for (const auto &mem : membuf_ptr_list) { - auto mem_checker = std::make_shared(mem->status_, mem->size_, mem->offset_, mem->index_, mem->used_kernel_); - add_new_curr_mem.push_back(mem_checker); - } - add_new_mem_infos_.push_back(add_new_curr_mem); - auto split_name = GetSplitName(op_def->scope_full_name()); - add_new_names_.push_back(split_name); - add_new_op_indxs_.push_back(op_idx); - add_new_stream_ids_.push_back(op_def->stream_id()); -} - -void MemReuseChecker::ExportEachMembufInfo(std::ofstream &ofs) { - size_t i = 0; - std::vector each_node_used_size; - std::vector each_node_allocated_size; - for (const auto &curr_membuf_list : membuf_all_infos_) { - ofs << all_split_names_.at(i) << "\n"; - ++i; - ofs << "mem_num\t" - << "stream_id\t" - << "status\t" - << "tensor_idex\t" - << "mem_size\t" - << "mem_head\t" - << "mem_tail\t" - << "used_kernel\n"; - size_t curr_used = 0; - size_t curr_allocated = 0; - for (size_t j = 0; j < curr_membuf_list.size(); ++j) { - auto membuf = curr_membuf_list.at(j); - auto used_kernel = membuf->used_kernel_->scope_full_name(); - ofs << "&" << j << "\t" - << "streamID[@" << membuf->used_kernel_->stream_id() << "]" - << "\t" - << "#" << static_cast(membuf->status_) << "\t%" << membuf->index_ << "T" - << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t" - << GetSplitName(used_kernel) << "\n"; - if (membuf->status_ == kReused) { - curr_used += membuf->size_; - } - } - if (!curr_membuf_list.empty()) { - curr_allocated = curr_membuf_list.back()->offset_ + curr_membuf_list.back()->size_; - } - each_node_used_size.push_back(curr_used); - each_node_allocated_size.push_back(curr_allocated); - ofs << "curr real used size: \t" << curr_used << "\n"; - ofs << "curr allocated size: \t" << curr_allocated << "\n"; - ofs << "\n\n"; - } - auto optimal_iter = std::max_element(each_node_used_size.begin(), each_node_used_size.end()); - ofs << "theoretical optimal size: " << *optimal_iter << "\n"; - ofs << "each node used size: \n"; - for (auto size : each_node_used_size) { - ofs << size << "\t"; - } - ofs << "\n\n"; - ofs << "each node allocated size: \n"; - for (auto size : each_node_allocated_size) { - ofs << size << "\t"; - } - ofs << "\n\n"; -} - -void MemReuseChecker::ExportMembufInfoIR() { - std::string ir_file_name = "./mem_buf_info.ir"; - std::ofstream ofs(ir_file_name); - int64_t total_reuse_size = 0; - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!"; - } - ofs << "Total static size:\t" << total_ori_static_size_ << "\n"; - ofs << "Graph inputs size:\t" << total_ori_input_size_ << "\n"; - ofs << "Value nodes size:\t" << total_ori_value_size_ << "\n"; - ofs << "Total dynamic size:\t" << total_ori_dy_size_ << "\n"; - ofs << "Total workspace size:\t" << total_ori_wkspace_size_ << "\n"; - // get last membuf_list - if (membuf_all_infos_.empty()) { - return; - } - auto last_membuf_list = membuf_all_infos_.back(); - for (const auto &membuf : last_membuf_list) { - auto checker_size = SizeToLong(membuf->size_); - total_reuse_size += checker_size; - } - ofs << "After reuse size:\t" << total_reuse_size << "\n\n"; - ExportEachMembufInfo(ofs); - ofs.close(); -} - -void MemReuseChecker::ExportAddNewMmebufIR() { - std::string ir_file_name = "./AddNewMembuf.ir"; - std::ofstream ofs(ir_file_name); - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!"; - } - auto check_idx = add_new_mem_infos_.size(); - if (check_idx == add_new_op_indxs_.size() && check_idx == add_new_names_.size() && - check_idx == add_new_stream_ids_.size()) { - size_t i = 0; - for (const auto &curr_membuf_list : add_new_mem_infos_) { - ofs << "op_idx:$" << add_new_op_indxs_.at(i) << "\t" << add_new_names_.at(i) << "\t"; - ofs << "streamID[@" << add_new_stream_ids_.at(i) << "]" - << "\n"; - i++; - ofs << "mem_num\t" - << "status\t" - << "tensor_idex\t" - << "mem_size\t" - << "mem_head\t" - << "mem_tail\t" - << "FromOp\t" - << "ToOp\n"; - for (size_t j = 0; j < curr_membuf_list.size(); ++j) { - auto membuf = curr_membuf_list.at(j); - ofs << "&" << j << "\t" - << "\t" - << "#" << static_cast(membuf->status_) << "\t%" << membuf->index_ << "T" - << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t"; - auto in_idx_iter = tensor_from_.find(membuf->index_); - if (in_idx_iter != tensor_from_.end()) { - for (auto &in_name : in_idx_iter->second) { - ofs << in_name << ","; - } - ofs << "\t"; - } - auto ou_idx_iter = tensor_to_.find(membuf->index_); - if (ou_idx_iter != tensor_to_.end()) { - for (auto &ou_name : ou_idx_iter->second) { - ofs << ou_name << ","; - } - ofs << "\n"; - } - } - ofs << "\n"; - } - } - ofs.close(); -} -} // namespace memreuse -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.h deleted file mode 100644 index 5fd3d0f5ae..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ -#include -#include -#include -#include -#include -#include -#include "mindspore/ccsrc/ir/anf.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/mem_reuse/mem_reuse.h" -#include "kernel/common_utils.h" -#include "pre_activate/mem_reuse/mem_reuse_allocator.h" -namespace mindspore { -namespace memreuse { -constexpr auto kSend = "Send"; -constexpr auto kRecv = "Recv"; -constexpr auto kSplitC = '/'; -class MemReuseChecker { - public: - bool IsAddNewMembuf_ = false; - static MemReuseChecker &GetInstance(); - MemReuseChecker(const MemReuseChecker &) = delete; - MemReuseChecker &operator=(const MemReuseChecker &) = delete; - void CheckSignalOps(const CNodePtr &c_node); - void CheckWorkSpace(const std::vector &max_list); - void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx); - bool CheckGraphOutputAssigned(const session::KernelGraph *graph); - void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list, - KernelGraph *graph); - int64_t CalculOriStatic(KernelGraph *graph) const; - int64_t CalculOriInput(const KernelGraph *graph) const; - int64_t CalculOriValue(KernelGraph *graph) const; - int64_t CalculOriDy(const KernelGraph *graph) const; - int64_t CalculOriWk(const KernelGraph *graph) const; - std::string GetSplitName(const std::string &scope_name) const; - int GetTensorIdx(const void *in) const; - void SetMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list); - void SetTesnorFromAndToInfo(const KernelDef *op_def); - void ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx); - void ExportNormalOpIr(const std::vector &cnodes); - void ExportNormalTensorIR(std::ofstream &ofs); - void CheckNormalIR(const session::KernelGraph *graph); - void ExportMembufInfoIR(); - void ExportEachMembufInfo(std::ofstream &ofs); - void SetAddNewMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list, size_t op_idx); - void ExportAddNewMmebufIR(); - void set_kernel_front_map(const std::map> &kernel_front_map) { - kernel_front_map_ = kernel_front_map; - } - void ExportKernelDependence(); - - private: - MemReuseChecker() = default; - ~MemReuseChecker() {} - size_t total_re_wkspe_size_checker_{0}; - std::vector> membuf_all_infos_; - std::vector nor_output_tensors_; - std::vector nor_tensor_sizes_; - std::vector nor_input_tensors_; - std::map ptr_idx_; - std::map ptr_refs_; - std::map> node_ins_; - std::map> node_ous_; - std::vector> add_new_mem_infos_; - std::vector add_new_names_; - std::vector add_new_op_indxs_; - std::vector add_new_stream_ids_; - std::vector all_split_names_; - std::map> tensor_from_; - std::map> tensor_to_; - std::map> kernel_front_map_; - int64_t total_ori_static_size_ = 0; - int64_t total_ori_input_size_ = 0; - int64_t total_ori_value_size_ = 0; - int64_t total_ori_dy_size_ = 0; - int64_t total_ori_wkspace_size_ = 0; -}; -} // namespace memreuse -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc deleted file mode 100644 index 14073bfbc9..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc +++ /dev/null @@ -1,344 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/mem_reuse/mem_swap_manager.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace device { -namespace memswap { -void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - graph_manager_ = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(graph_manager_); - auto &kernels = kernel_graph->execution_order(); - for (const auto &kernel : kernels) { - if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { - execution_order_.push_back(kernel); - } - } - - size_t kernel_index = 0; - for (const auto &kernel : execution_order_) { - // parse topo order of kernel - (void)kernel_execution_info_.emplace(kernel.get(), kernel_index++); - // parse tensor info - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) { - TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx}; - ordered_tensors_.push_back(tensor_info); - } - } - - // parse topo order of user kernel - SaveUserKernelTopoOrder(); - - sort(ordered_tensors_.begin(), ordered_tensors_.end(), - [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); - - auto cur_tensor_size = ordered_tensors_.front().tensor_size_; - for (auto &tensor_info : ordered_tensors_) { - if (cur_tensor_size != tensor_info.tensor_size_) { - cur_tensor_size = tensor_info.tensor_size_; - tensor_size_num_++; - } - } - tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; - tensor_size_threshold_idx_ = 0; - - distance_threshold_ = kernel_index / kDistanceInitFactor; - mem_swap_initialized_ = true; - MS_EXCEPTION_IF_NULL(mem_copy_manager_); - mem_copy_manager_->Init(); -} - -bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { - MS_EXCEPTION_IF_NULL(kernel); - NodeUsersMap &user_map = graph_manager_->node_users(); - auto iter = user_map.find(kernel); - bool adjacent_with_communication_op = false; - if (iter != user_map.end()) { - AnfNodeIndexSet node_set = iter->second; - adjacent_with_communication_op = std::any_of( - node_set.begin(), node_set.end(), - [](const std::pair &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); - } - return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; -} - -void MemSwapManager::SaveUserKernelTopoOrder() { - NodeUsersMap &user_map = graph_manager_->node_users(); - for (const auto &kernel : execution_order_) { - auto iter = user_map.find(kernel); - if (iter == user_map.end()) { - continue; - } - AnfNodeIndexSet node_set = iter->second; - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - for (auto &node_pair : node_set) { - auto user_kernel = node_pair.first; - if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) { - continue; - } - - size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_; - auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1); - auto &output_idx = kernel_with_index.second; - if (kernel_with_index.first.get() != kernel.get()) { - MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort); - } - for (auto &node_user_pair : kernel_exec_info.node_users_map_) { - sort(node_user_pair.second.begin(), node_user_pair.second.end()); - } - } -} - -void MemSwapManager::AddSwapInfo() { - for (const auto &tensor : ordered_tensors_) { - size_t tensor_size = tensor.tensor_size_; - if (tensor_size < tensor_size_threshold_) { - break; - } - - size_t output_idx = tensor.output_idx_; - const AnfNodePtr &kernel = tensor.kernel_; - if (IsCommunicationRelevantOp(kernel)) { - continue; - } - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - auto &node_users_map = kernel_exec_info.node_users_map_; - - auto iter = node_users_map.find(output_idx); - if (iter == node_users_map.end()) { - continue; - } - auto &node_users = iter->second; - bool need_swap = (node_users.size() == 1 && node_users[0] - kernel_exec_info.topo_order_ >= distance_threshold_) || - (node_users.size() > 1 && node_users[1] - node_users[0] >= distance_threshold_); - if (!need_swap) { - continue; - } - AddKernelNeedSwap(kernel, true); - HostAddress host_addr; - host_addr.size = tensor_size; - auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast(&host_addr.addr)); - if (!ret) { - MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed."; - } - kernel_exec_info.host_addrs_[output_idx] = host_addr; - MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx}; - if (node_users.size() > 1) { - AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info); - AddKernelTriggerSwap(execution_order_[node_users[0]], true); - } else { - AddKernelMemSwapInfo(kernel, mem_swap_out_info); - AddKernelTriggerSwap(kernel, true); - } - - size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1; - if (swap_in_order <= kernel_exec_info.topo_order_) { - MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - auto swap_in_kernel = execution_order_[swap_in_order]; - MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx}; - AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info); - AddKernelTriggerSwap(swap_in_kernel, true); - - host_addrs_list_.push_back(host_addr); - } -} - -void MemSwapManager::AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, - const HostAddress &host_address) const { - if (swap_kind == SwapKind::kDeviceToHost) { - mem_copy_manager_->AddMemSwapOutTask(device_address, host_address); - } else if (swap_kind == SwapKind::kHostToDevice) { - mem_copy_manager_->AddMemSwapInTask(device_address, host_address); - } -} - -bool MemSwapManager::SyncMemCopyStream(SwapKind swap_kind) const { - return mem_copy_manager_->SyncMemCopyStream(swap_kind); -} - -DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const { - if (swap_kind == SwapKind::kDeviceToHost) { - return mem_copy_manager_->UpdateSwapOutQueue(); - } else { - return mem_copy_manager_->UpdateSwapInQueue(); - } -} - -// retreat to find a workable swap scheme -bool MemSwapManager::RetreatSwapInfo() { - if (!trigger_swap_) { - trigger_swap_ = true; - } - if (swap_info_already_set_) { - ResetSwapInfo(); - if (distance_threshold_ >= kDistanceLowerBound) { - auto distance_decay_step = execution_order_.size() / kDistanceInitFactor / tensor_size_num_; - distance_threshold_ -= (distance_decay_step > 1 ? distance_decay_step : 1); - } - - while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { - ++tensor_size_threshold_idx_; - if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { - tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; - break; - } - } - - if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) { - MS_LOG(ERROR) << "Retreat swap info failed"; - return false; - } - } else { - swap_info_already_set_ = true; - } - AddSwapInfo(); - return true; -} - -KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const { - MS_EXCEPTION_IF_NULL(kernel); - auto iter = kernel_execution_info_.find(kernel.get()); - if (iter == kernel_execution_info_.end()) { - MS_LOG(EXCEPTION) << "Can not find execution info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return const_cast(iter->second); -} - -void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.execution_perform_ = perform; -} - -void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.trigger_swap_ = trigger_swap; -} - -void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.need_swap_ = need_swap; -} - -void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, - const std::pair &perform) { - MS_EXCEPTION_IF_NULL(kernel); - kernel_swap_perform_[kernel.get()][output_idx] = perform; -} - -void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { - MS_EXCEPTION_IF_NULL(kernel); - mem_swap_info_[kernel.get()].push_back(mem_swap_info); -} - -float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const { - const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.execution_perform_; -} - -bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { - const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.trigger_swap_; -} - -bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const { - const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.need_swap_; -} - -const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const { - MS_EXCEPTION_IF_NULL(kernel); - auto iter_kernel = kernel_swap_perform_.find(kernel.get()); - if (iter_kernel == kernel_swap_perform_.end()) { - MS_LOG(EXCEPTION) << "Can not find swap performance data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - - auto &perform_map = iter_kernel->second; - auto iter_output = perform_map.find(output_idx); - if (iter_output == perform_map.end()) { - MS_LOG(EXCEPTION) << "Can not find swap performance data of output[" << output_idx << "] of op[" - << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return iter_output->second; -} - -const std::vector &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { - MS_EXCEPTION_IF_NULL(kernel); - auto iter = mem_swap_info_.find(kernel.get()); - if (iter == mem_swap_info_.end()) { - MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return iter->second; -} - -void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); } - -bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { - auto iter = swap_in_blacklist_.find(device_ptr); - return iter != swap_in_blacklist_.end(); -} - -const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - auto &host_addrs = kernel_exec_info.host_addrs_; - auto iter = host_addrs.find(output_idx); - if (iter == host_addrs.end()) { - MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return iter->second; -} - -bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const { - return mem_copy_manager_->AllocHostPinnedMem(size, addr); -} - -void MemSwapManager::ReleaseHostPinnedMem() { - for (const auto &host_addr : host_addrs_list_) { - if (host_addr.addr) { - mem_copy_manager_->FreeHostPinnedMem(host_addr.addr); - } - } - host_addrs_list_.clear(); -} - -void MemSwapManager::ClearSwapQueue() const { mem_copy_manager_->ClearSwapQueue(); } - -void MemSwapManager::ResetSwapInfo() { - ClearSwapQueue(); - for (auto &kernel_exec_info_pair : kernel_execution_info_) { - auto &kernel_exec_info = kernel_exec_info_pair.second; - kernel_exec_info.trigger_swap_ = false; - kernel_exec_info.need_swap_ = false; - kernel_exec_info.host_addrs_.clear(); - } - ReleaseHostPinnedMem(); - swap_in_blacklist_.clear(); - mem_swap_info_.clear(); -} -} // namespace memswap -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h deleted file mode 100644 index 1969dadb54..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include "pre_activate/mem_reuse/mem_copy_manager.h" - -using PerformPair = std::pair; -namespace mindspore { -namespace device { -namespace memswap { -class MemSwapManager { - public: - explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager) - : tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) { - mem_copy_manager_ = mem_copy_manager; - } - - MemSwapManager(const MemSwapManager &) = delete; - - MemSwapManager &operator=(const MemSwapManager &) = delete; - - ~MemSwapManager() = default; - - void Init(const mindspore::session::KernelGraph *kernel_graph); - - void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, - const HostAddress &host_address) const; - - bool SyncMemCopyStream(SwapKind swap_kind) const; - - DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const; - - // retreat to find a workable swap scheme - bool RetreatSwapInfo(); - - bool trigger_swap() const { return trigger_swap_; } - - bool mem_swap_init() const { return mem_swap_initialized_; } - - KernelExecutionInfo &SearchKernelExecutionInfo(const AnfNodePtr &kernel) const; - - void AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform); - - float QueryKernelExecutionPerform(const AnfNodePtr &kernel) const; - - void AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, const PerformPair &perform); - - const PerformPair &QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const; - - bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; - - bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const; - - const std::vector &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; - - void InsertSwapInBlackList(const void *device_ptr); - - bool FindInSwapInBlackList(const void *device_ptr) const; - - const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const; - - bool AllocHostPinnedMem(size_t size, void **addr) const; - - void ReleaseHostPinnedMem(); - - void ClearSwapQueue() const; - - private: - void AddSwapInfo(); - - void ResetSwapInfo(); - - void SaveUserKernelTopoOrder(); - - void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); - - void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap); - - void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); - - bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; - - std::vector execution_order_; - std::vector ordered_tensors_; - std::unordered_map kernel_execution_info_; - std::unordered_map> kernel_swap_perform_; - // trigger swap kernel key : MemSwapInfo of kernel need to be swapped - std::unordered_map> mem_swap_info_; - std::vector host_addrs_list_; - std::unordered_set swap_in_blacklist_; - - size_t tensor_size_threshold_; - size_t tensor_size_threshold_idx_; - size_t tensor_size_num_; - size_t distance_threshold_; - - MemCopyManagerPtr mem_copy_manager_{nullptr}; - FuncGraphManagerPtr graph_manager_{nullptr}; - bool mem_swap_initialized_{false}; - bool swap_info_already_set_{false}; - bool trigger_swap_{false}; - - static constexpr size_t kDistanceInitFactor = 3; - static constexpr size_t kDistanceLowerBound = 3; -}; -using MemSwapManagerPtr = std::shared_ptr; -} // namespace memswap -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.cc b/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.cc deleted file mode 100644 index 9df34a1c59..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.cc +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/pass/add_atomic_clean.h" -#include -#include -#include -#include "operator/ops.h" -#include "utils/utils.h" -#include "utils/graph_utils.h" -#include "utils/log_adapter.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "debug/anf_ir_dump.h" - -namespace mindspore { -namespace opt { -namespace { - -static std::vector g_output_idx; - -bool HasAtomic(const AnfNodePtr &input) { - if (IsPrimitiveCNode(input)) { - const auto &cnode = input->cast(); - const auto &prim = GetValueNode(cnode->input(0)); - return prim->HasAttr("atomic_add"); - } - return false; -} - -std::vector CalCleanSize(const CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(pre_node); - std::vector clean_size_list; - // clean output - for (auto &index : g_output_idx) { - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(pre_node, index); - size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); - std::vector shape = AnfAlgo::GetOutputDeviceShape(pre_node, index); - auto size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - clean_size_list.push_back((size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize); - } - MS_LOG(DEBUG) << "Clear output size: " << clean_size_list.size() << ", pre_node: " << pre_node->fullname_with_scope(); - return clean_size_list; -} - -CNodePtr CreateTbeAtomicCleanNode(const std::shared_ptr &kernel_graph, - const mindspore::CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(pre_node); - auto clean_zero_prim = std::make_shared(kAtomicAddrCleanOpName); - auto new_value_node = NewValueNode(clean_zero_prim); - std::vector inputs = {new_value_node}; - CNodePtr clean_zero = kernel_graph->NewCNode(inputs); - AbstractBasePtr abstract = std::make_shared(); - clean_zero->set_abstract(abstract); - auto builder = std::make_shared(); - builder->SetKernelType(KernelType::TBE_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clean_zero.get()); - auto clean_size = CalCleanSize(pre_node); - AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clean_zero); - AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(g_output_idx), clean_zero); - AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clean_zero.get()); - return clean_zero; -} -} // namespace - -void AddAtomicClean(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - if (mng == nullptr) { - mng = Manage(kernel_graph, true); - kernel_graph->set_manager(mng); - } - auto &todos = kernel_graph->execution_order(); - for (auto iter = todos.cbegin(); iter != todos.end(); ++iter) { - auto node = *iter; - if (AnfAlgo::IsGraphKernel(node) && kernel_graph->nodes().contains(node)) { - auto fg = GetValueNode(node->input(kAnfPrimitiveIndex)); - MS_EXCEPTION_IF_NULL(fg); - auto input = fg->get_return()->input(1); - if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { - const auto &cnode = input->cast(); - for (size_t i = 0; i < cnode->inputs().size(); ++i) { - if (HasAtomic(cnode->input(i))) { - g_output_idx.push_back(i - 1); - } - } - } else if (HasAtomic(input)) { - g_output_idx.push_back(0); - } - - if (!g_output_idx.empty()) { - auto zero_node = CreateTbeAtomicCleanNode(kernel_graph, node); - auto depend = kernel_graph->NewCNode({NewValueNode(prim::kPrimDepend), node->input(1), zero_node}); - std::vector new_input = node->inputs(); - new_input[1] = depend; - auto new_cnode = std::make_shared(new_input, kernel_graph); - // Set abstract - new_cnode->set_abstract(node->abstract()); - // Set kernel info - new_cnode->set_kernel_info(node->kernel_info_ptr()); - mng->Replace(node, new_cnode); - g_output_idx.clear(); - } - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.h b/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.h deleted file mode 100644 index bb1edb0e35..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ - -#include -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -void AddAtomicClean(const std::shared_ptr &kernel_graph); -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc deleted file mode 100644 index 297a167aa8..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/common_subexpression_elimination.h" -#include -#include "device/kernel_info.h" - -namespace mindspore { -namespace opt { -namespace { -bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(main); - MS_EXCEPTION_IF_NULL(node); - auto main_kernel_info = main->kernel_info(); - auto node_kernel_info = node->kernel_info(); - if (main_kernel_info == nullptr && node_kernel_info == nullptr) { - return true; - } - if (main_kernel_info != nullptr && node_kernel_info != nullptr) { - return *main_kernel_info == *node_kernel_info; - } - return false; -} -} // namespace - -bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { - MS_EXCEPTION_IF_NULL(main); - MS_EXCEPTION_IF_NULL(node); - - bool replace = false; - if (main->isa() && node->isa()) { - auto main_value = GetValueNode(main); - auto node_value = GetValueNode(node); - if (main_value->isa() && node_value->isa()) { - replace = false; - } else if (main_value->isa() && node_value->isa()) { - replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); - } else { - replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); - } - } else if (main->isa() && node->isa()) { - if (!CheckEqualKernelBuildInfo(main, node)) { - replace = false; - } else { - auto c_main = main->cast(); - MS_EXCEPTION_IF_NULL(c_main); - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - const auto &inp1 = c_main->inputs(); - const auto &inp2 = c_node->inputs(); - if (inp1.size() == inp2.size()) { - bool appsame = true; - for (size_t j = 0; j < inp1.size(); j++) { - MS_EXCEPTION_IF_NULL(inp1[j]); - MS_EXCEPTION_IF_NULL(inp2[j]); - if (!(*inp1[j] == *inp2[j])) { - appsame = false; - break; - } - } - replace = appsame; - } - } - } - return replace; -} - -bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto backend_cse = std::make_shared(); - return backend_cse->Cse(func_graph, func_graph->manager()); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h deleted file mode 100644 index 18f433ab95..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ -#include "pre_activate/common/pass.h" -#include "optimizer/cse.h" - -namespace mindspore { -namespace opt { -class CommonSubexpressionElimination : public Pass { - public: - CommonSubexpressionElimination() : Pass("cse") {} - ~CommonSubexpressionElimination() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; - -class BackendCSE : public CSE { - public: - BackendCSE() = default; - ~BackendCSE() override = default; - bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc deleted file mode 100644 index aa4690abcb..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc +++ /dev/null @@ -1,274 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/communication_op_fusion.h" - -#include -#include -#include - -#include "utils/graph_utils.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel_build_info.h" -#include "parallel/context.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr auto kAttrDefaultGroup = "default_group"; -constexpr auto kAttrDefaultOp = "default_op"; - -kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index, - size_t end_index) { - if (end_index >= communication_op_info.communication_op_nodes.size()) { - MS_LOG(EXCEPTION) << "end index out of vector size"; - } - std::vector inputs_device_format; - std::vector outputs_device_format; - std::vector inputs_device_type; - std::vector outputs_device_type; - std::vector> outputs_shape; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t idx = start_index; idx <= end_index; ++idx) { - auto cnode = communication_op_info.communication_op_nodes[idx]; - MS_EXCEPTION_IF_NULL(cnode); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { - inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); - inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); - } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { - outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); - outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); - outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); - } - builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); - builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); - builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); - } - builder.SetInputsFormat(inputs_device_format); - builder.SetOutputsFormat(outputs_device_format); - builder.SetInputsDeviceType(inputs_device_type); - builder.SetOutputsDeviceType(outputs_device_type); - return builder.Build(); -} - -std::string GetFusionGroupKey(const AnfNodePtr &node) { - auto primitive = AnfAlgo::GetCNodePrimitive(node); - MS_EXCEPTION_IF_NULL(primitive); - ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion); - if (attr_fusion == nullptr) { - return ""; - } - int fusion = GetValue(attr_fusion); - if (fusion == 0) { - return ""; - } - std::string group = kAttrDefaultGroup; - ValuePtr attr_group = primitive->GetAttr(kAttrGroup); - if (attr_group != nullptr) { - group = GetValue(attr_group); - } - std::string op = kAttrDefaultOp; - ValuePtr attr_op = primitive->GetAttr(kAttrOp); - if (attr_op != nullptr) { - op = GetValue(attr_op); - } - return group + op + std::to_string(fusion); -} -} // namespace - -bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, - std::vector *segment_index, const std::string &group) const { - MS_EXCEPTION_IF_NULL(segment_num); - MS_EXCEPTION_IF_NULL(segment_index); - size_t communication_op_node_size = communication_op_info.communication_op_nodes.size(); - MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size; - - auto parallel_context = parallel::ParallelContext::GetInstance(); - MS_EXCEPTION_IF_NULL(parallel_context); - const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); - - size_t segments = 0; - if (split_indices.size() != 0) { - uint32_t last_index = 0; - for (size_t i = 0; i < split_indices.size(); ++i) { - uint32_t index = split_indices[i]; - if (index <= last_index || index >= communication_op_node_size) { - MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index; - } - segment_index->push_back(index); - last_index = index; - segments++; - } - if (last_index != communication_op_node_size - 1) { - segment_index->push_back(communication_op_node_size - 1); - segments++; - } - } else { - segments = groups_; - for (size_t i = 0; i < segments - 1; ++i) { - segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1); - } - segment_index->push_back(communication_op_node_size - 1); - } - - if (segments >= communication_op_node_size) { - MS_LOG(INFO) << "fusion not changed: segment_num=" << segments - << ", communication_op_node_size=" << communication_op_node_size; - return false; - } - if (segment_index->at(segments - 1) != communication_op_node_size - 1) { - MS_LOG(EXCEPTION) << "the last segment index is invalid."; - } - for (size_t i = 0; i < segments - 1; ++i) { - if (segment_index->at(i) > segment_index->at(i + 1)) { - MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ " - << i + 1 << "]=" << segment_index->at(i + 1); - } - } - *segment_num = segments; - return true; -} - -AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, - const CommunicationOpInfo &communication_op_info, - size_t start_index, size_t end_index) const { - MS_EXCEPTION_IF_NULL(func_graph); - auto prim = std::make_shared(op_name_); - MS_EXCEPTION_IF_NULL(prim); - std::vector fusion_inputs = {NewValueNode(prim)}; - // get all inputs of current segment - if (end_index >= communication_op_info.communication_op_nodes.size()) { - MS_LOG(EXCEPTION) << "end index out of vector size"; - } - for (size_t idx = start_index; idx <= end_index; ++idx) { - auto cnode = communication_op_info.communication_op_nodes[idx]; - MS_EXCEPTION_IF_NULL(cnode); - fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); - } - AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs); - MS_EXCEPTION_IF_NULL(fused_node); - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - fused_node->set_kernel_info(kernel_info); - AbstractBasePtrList abstract_list; - for (size_t idx = start_index; idx <= end_index; ++idx) { - auto cnode = communication_op_info.communication_op_nodes[idx]; - MS_EXCEPTION_IF_NULL(cnode); - AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node); - AnfAlgo::CopyNodeAttr("op", cnode, fused_node); - AnfAlgo::CopyNodeAttr("group", cnode, fused_node); - abstract_list.push_back(cnode->abstract()); - } - auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get()); - auto abstract_tuple = std::make_shared(abstract_list); - MS_EXCEPTION_IF_NULL(abstract_tuple); - fused_node->set_abstract(abstract_tuple); - return fused_node; -} - -bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, - size_t segment_num, const std::vector &segment_index) const { - MS_EXCEPTION_IF_NULL(func_graph); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - bool changed = false; - size_t start_index = 0; - for (size_t segment_idx = 0; segment_idx < segment_num; ++segment_idx) { - size_t end_index = segment_index.at(segment_idx); - if (end_index - start_index < 1) { - start_index = end_index + 1; - continue; - } - AnfNodePtr new_communication_op = - CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index); - // replace old communication op with new communication op - for (auto idx = start_index; idx <= end_index; ++idx) { - std::vector tuple_getitem_input; - tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem)); - tuple_getitem_input.push_back(new_communication_op); - auto index = NewValueNode(SizeToInt(idx - start_index)); - MS_EXCEPTION_IF_NULL(index); - auto imm = std::make_shared(idx - start_index); - MS_EXCEPTION_IF_NULL(imm); - auto abstract_scalar = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_scalar); - index->set_abstract(abstract_scalar); - tuple_getitem_input.push_back(index); - AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input); - MS_EXCEPTION_IF_NULL(tuple_getitem); - auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx); - MS_EXCEPTION_IF_NULL(communication_op_node_item); - tuple_getitem->set_abstract(communication_op_node_item->abstract()); - if (!manager->Replace(communication_op_node_item, tuple_getitem)) { - MS_LOG(EXCEPTION) << "manager replace node failed"; - } - } - start_index = end_index + 1; - changed = true; - } - return changed; -} - -bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - const float input_grad_size_num = 0.0; - const float input_grad_time_num = 0.0; - // divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion - std::unordered_map candidate_groups; - std::vector node_list = TopoSort(func_graph->get_return()); - for (auto &node : node_list) { - if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == op_name_) { - std::string key = GetFusionGroupKey(node); - if (key.empty()) { - continue; - } - if (candidate_groups.find(key) == candidate_groups.end()) { - CommunicationOpInfo communication_op_info; - candidate_groups[key] = communication_op_info; - } - candidate_groups[key].communication_op_nodes.push_back(node->cast()); - candidate_groups[key].input_grad_size.push_back(input_grad_size_num); - candidate_groups[key].input_grad_time.push_back(input_grad_time_num); - } - } - // split candidate group to segments according to _group class member - bool changed = false; - for (auto &it : candidate_groups) { - if (it.second.communication_op_nodes.size() <= 1) { - continue; - } - auto first_node = it.second.communication_op_nodes[0]; - if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr(first_node, kAttrIndex) > 0) { - std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(), - [](const CNodePtr &a, const CNodePtr &b) { - return AnfAlgo::GetNodeAttr(a, kAttrIndex) < AnfAlgo::GetNodeAttr(b, kAttrIndex); - }); - } - size_t segment_num = 0; - std::vector segment_index; - if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) { - if (DoFusion(func_graph, it.second, segment_num, segment_index)) { - changed = true; - } - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h deleted file mode 100644 index d00180f97f..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ -#include -#include -#include - -#include "pre_activate/common/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -struct CommunicationOpInfo { - std::vector communication_op_nodes; - std::vector input_grad_size; - std::vector input_grad_time; -}; - -class CommunicationOpFusion : public Pass { - public: - explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1) - : Pass(name), op_name_(std::move(op_name)), groups_(groups) {} - ~CommunicationOpFusion() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num, - const std::vector &segment_index) const; - AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, - const CommunicationOpInfo &communication_op_info, size_t start_index, - size_t end_index) const; - bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, - std::vector *segment_index, const std::string &group) const; - std::string op_name_; - size_t groups_ = 1; -}; - -class AllReduceFusion : public CommunicationOpFusion { - public: - explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {} - ~AllReduceFusion() override = default; -}; - -class AllGatherFusion : public CommunicationOpFusion { - public: - explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} - ~AllGatherFusion() override = default; -}; - -class BroadcastFusion : public CommunicationOpFusion { - public: - explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} - ~BroadcastFusion() override = default; -}; - -class ReduceScatterFusion : public CommunicationOpFusion { - public: - explicit ReduceScatterFusion(size_t groups = 1) - : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} - ~ReduceScatterFusion() override = default; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc deleted file mode 100644 index 6a557388ad..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/const_input_to_attr_registry.h" - -#include - -#include "utils/utils.h" -#include "utils/log_adapter.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { - Register(prim::kPrimCast->name(), {1}); - Register(prim::kPrimAvgPoolGrad->name(), {0}); - Register(prim::kPrimConv2DBackpropInput->name(), {2}); - Register(prim::kPrimConv2DBackpropFilter->name(), {2}); - Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); - Register(prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), {0}); - Register(prim::kPrimReshape->name(), {1}); - Register(prim::kPrimReduceMax->name(), {1}); - Register(prim::kPrimReduceMin->name(), {1}); - Register(prim::kPrimReduceSum->name(), {1}); - Register(prim::kPrimReduceMean->name(), {1}); - Register(prim::kPrimGatherV2->name(), {2}); - Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5}); - Register(prim::kPrimEmbeddingLookupCommGrad->name(), {1}); - Register(prim::kPrimSubscalar->name(), {1}); - Register(prim::kPrimTranspose->name(), {1}); - Register(prim::kPrimUnsortedSegmentSum->name(), {2}); - Register(prim::kPrimOneHot->name(), {1}); - Register(prim::kPrimConcat->name(), {0}); - Register(prim::kPrimCumSum->name(), {1}); - Register(prim::kPrimCumProd->name(), {1}); - Register(prim::kPrimReduceAll->name(), {1}); - Register(prim::kPrimUnsortedSegmentMin->name(), {2}); - Register(kSparseGatherV2, {2}); - Register(kUnsortedSegmentProdOpName, {2}); - Register(kSimpleMeanGradOpName, {1}); - Register(kMeanGradOpName, {1}); - Register(kSliceOpName, {1, 2}); - Register(kSliceGradOpName, {2, 3}); - Register(kTileOpName, {1}); - Register(kScatterNdOpName, {2}); - Register(kStridedSliceAssignOpName, {1, 2, 3}); - Register(kStridedSliceOpName, {1, 2, 3}); - Register(kFlattenGradOpName, {1}); - Register(kExpandDimsOpName, {1}); - Register(kSplitOpName, {0}); - Register(kErfOpName, {1}); - Register(kSparseApplyAdagradOpName, {2}); - Register(kResizeNearestNeighborGradOpName, {1}); - Register(kResizeNearestNeighborV2OpName, {1}); - Register(kResizeNearestNeighborV2GradOpName, {1}); - Register(kApplyRMSPropOpname, {5, 6, 7}); - Register(kResizeBilinearV2OpName, {1}); - Register(kReduceProdOpName, {1}); - Register(kCumprodOpName, {1}); - Register(kSpaceToBatchOpName, {1}); - Register(kBatchToSpaceOpName, {1}); - Register(kPadOpName, {1}); -} - -ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() { - static ConstInputToAttrInfoRegistry instance; - return instance; -} - -void ConstInputToAttrInfoRegistry::Register(const ConstInputToAttrInfoRegister ®) { - auto op_name = reg.GetOpName(); - if (op_input_to_attr_map_.find(op_name) == op_input_to_attr_map_.end()) { - (void)op_input_to_attr_map_.insert(make_pair(op_name, reg)); - MS_LOG(DEBUG) << op_name << " const2attr register successfully!"; - } -} - -void ConstInputToAttrInfoRegistry::Register(const std::string &op_name, - const std::unordered_set &input_attr_set) { - if (op_input_to_attr_map_.find(op_name) == op_input_to_attr_map_.end()) { - ConstInputToAttrInfoRegister reg(op_name); - (void)reg.SetConstInputToAttr(input_attr_set); - (void)op_input_to_attr_map_.insert(make_pair(op_name, reg)); - MS_LOG(DEBUG) << op_name << " const2attr register successfully!"; - } -} - -bool ConstInputToAttrInfoRegistry::GetRegisterByOpName(const std::string &op_name, - ConstInputToAttrInfoRegister *reg) const { - if (op_input_to_attr_map_.find(op_name) != op_input_to_attr_map_.end()) { - *reg = op_input_to_attr_map_.at(op_name); - MS_LOG(DEBUG) << op_name << " const2attr find in registery."; - return true; - } - return false; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc b/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc deleted file mode 100644 index b0e2ab044c..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/context/ms_context.h" -#include "utils/utils.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -const size_t strides_index = 5; - -bool GetStridesValues(const CNodePtr &strided_slice_grad, ValuePtrList *strides_values) { - MS_EXCEPTION_IF_NULL(strided_slice_grad); - if (strided_slice_grad->size() < 6) { - MS_LOG(DEBUG) << "Op strided_slice_grad's inputs size less than 6, graph not changed"; - return false; - } - auto strides_input = strided_slice_grad->input(strides_index); - MS_EXCEPTION_IF_NULL(strides_input); - auto strides_value_node = strides_input->cast(); - if (strides_value_node == nullptr) { - MS_LOG(DEBUG) << "strides is not a value node."; - return false; - } - auto value = strides_value_node->value(); - if (value == nullptr) { - MS_LOG(DEBUG) << "strides has no value."; - return false; - } - auto value_tuple = value->cast(); - if (value_tuple == nullptr) { - MS_LOG(DEBUG) << "strides is not a value tuple."; - return false; - } - *strides_values = value_tuple->value(); - return true; -} - -bool CheckValues(const ValuePtrList &strides_values) { - if (strides_values.empty()) { - MS_LOG(DEBUG) << "strides_values is empty"; - return false; - } - for (auto &value : strides_values) { - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - auto scalar = value->cast(); - MS_EXCEPTION_IF_NULL(scalar); - if (!scalar->isa()) { - MS_LOG(DEBUG) << "strides value is not a Integer"; - return false; - } - if (GetValue(scalar) != 1) { - MS_LOG(DEBUG) << "StridedSliceGrad has no 1 value"; - return false; - } - } else { - MS_LOG(DEBUG) << "The value " << value << "of tuple is not a scalar"; - return false; - } - } - return true; -} - -bool CheckAttrs(const CNodePtr &strided_slice_grad) { - MS_EXCEPTION_IF_NULL(strided_slice_grad); - if (!AnfAlgo::HasNodeAttr(kAttrNewAxisMask, strided_slice_grad) || - !AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, strided_slice_grad)) { - MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not exist in cnode[" + strided_slice_grad->DebugString() + "]"; - return false; - } - auto new_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrNewAxisMask); - auto shrink_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrShrinkAxisMask); - if (new_axis_mask != 0 || shrink_axis_mask != 0) { - MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not equal 0"; - return false; - } - return true; -} -} // namespace - -const BaseRef ConstToAttrStridedSliceGradPass::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto strided_slice_grad_prim = std::make_shared(kStridedSliceGradOpName); - return VectorRef({strided_slice_grad_prim, Xs}); -} - -const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto strided_slice_grad = node->cast(); - MS_EXCEPTION_IF_NULL(strided_slice_grad); - - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - - if (ms_context->device_target() == kAscendDevice) { - if (!CheckAttrs(strided_slice_grad)) { - MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed"; - return nullptr; - } - - ValuePtrList strides_values; - if (!GetStridesValues(strided_slice_grad, &strides_values)) { - return nullptr; - } - - if (!CheckValues(strides_values)) { - MS_LOG(INFO) << "Check strides' values failed, graph not changed"; - return nullptr; - } - } - - ConstInputToAttr(strided_slice_grad, {1, 2, 3, 4}); - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h b/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h deleted file mode 100644 index 2e364244bf..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConstToAttrStridedSliceGradPass : public PatternProcessPass { - public: - explicit ConstToAttrStridedSliceGradPass(bool multigraph = true) - : PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {} - ~ConstToAttrStridedSliceGradPass() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h deleted file mode 100644 index e124ff8cf4..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ -#include -#include -#include - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConvertConstInputToAttr : public PatternProcessPass { - public: - explicit ConvertConstInputToAttr(bool multigraph = true) - : PatternProcessPass("convert_const_input_to_attr", multigraph) {} - ~ConvertConstInputToAttr() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - std::unordered_map> op_input_attr_map_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.cc deleted file mode 100644 index b4f98cc6d7..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.cc +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/convert_const_input_to_tensor_input.h" - -#include -#include -#include - -#include "utils/graph_utils.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "kernel/common_utils.h" -#include "device/kernel_info.h" - -namespace mindspore { -namespace opt { -namespace { -ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - ValueNodePtr new_value_node = std::make_shared(value_node->value()); - new_value_node->set_abstract(value_node->abstract()); - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { - types.push_back(kTypeUnknown); - } - kernel_build_info_builder->SetOutputsDeviceType(types); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - return new_value_node; -} - -AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) { - MS_EXCEPTION_IF_NULL(input_node); - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - tensor::TensorPtr tensor_ptr = nullptr; - if (value->isa()) { - tensor_ptr = ScalarToTensor(value->cast()); - } else if (value->isa()) { - tensor_ptr = CreateTupleTensor(value->cast()); - } else { - MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple"; - } - if (tensor_ptr == nullptr) { - MS_LOG(WARNING) << "Create tensor failed"; - return nullptr; - } - auto tensor_input = std::make_shared(tensor_ptr); - MS_EXCEPTION_IF_NULL(tensor_input); - tensor_input->set_abstract(tensor_ptr->ToAbstract()); - if (kernel_graph != nullptr) { - tensor_input = kernel_graph->NewValueNode(tensor_input); - kernel_graph->AddValueNodeToGraph(tensor_input); - } else { - tensor_input = MakeValueNode(tensor_input); - } - tensor_input->set_scope(input_node->scope()); - return tensor_input; -} - -AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs; - auto kernel_graph = func_graph->cast>(); - auto inputs = cnode->inputs(); - new_inputs.push_back(inputs[0]); - bool need_update = false; - // the first input is primitive node which is not the real input - for (size_t i = 0; i < inputs.size() - 1; ++i) { - auto input_node = inputs[i + 1]; - if (IsValueNode(input_node) || IsValueNode(input_node)) { - auto tensor_input = CreateTensorInput(kernel_graph, input_node); - if (tensor_input == nullptr) { - new_inputs.push_back(input_node); - continue; - } - new_inputs.push_back(tensor_input); - need_update = true; - } else { - new_inputs.push_back(input_node); - } - } - if (need_update) { - MS_EXCEPTION_IF_NULL(func_graph); - auto new_cnode = func_graph->NewCNode(new_inputs); - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_scope(cnode->scope()); - AnfAlgo::CopyNodeAttrs(cnode, new_cnode); - return new_cnode; - } - return nullptr; -} - -AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - auto mng = sub_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - std::vector todo; - std::vector> graph_rets; - kernel::GetValidKernelNodes(sub_graph, &todo); - kernel::GetGraphRealOutput(sub_graph, &graph_rets); - - for (auto &t : todo) { - auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast()); - if (t_new_node != nullptr && t_new_node != t) { - (void)mng->Replace(t, t_new_node); - } - } - - return node; -} -} // namespace - -const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { - return nullptr; - } - if (AnfAlgo::IsGraphKernel(node)) { - return ProcessGraphKernelOp(node); - } else { - return ConstInputToTensorInput(func_graph, node->cast()); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.h b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.h deleted file mode 100644 index 1cc2bdf0ec..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ -#include - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConvertConstInputToTensorInput : public PatternProcessPass { - public: - explicit ConvertConstInputToTensorInput(bool multigraph = true) - : PatternProcessPass("convert_const_input_to_tensor_input", multigraph) {} - ~ConvertConstInputToTensorInput() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc deleted file mode 100644 index a03087c1a4..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc +++ /dev/null @@ -1,148 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/convert_tuple_input_to_dynamic_input.h" - -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" -#include "session/kernel_graph.h" -#include "kernel/common_utils.h" -#include "device/kernel_info.h" - -namespace mindspore { -namespace opt { -namespace { -bool MakeValueNode(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return false; - } - - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - TypeId infer_data_type; - if (AnfAlgo::GetOutputTensorNum(value_node) == 0) { - infer_data_type = kTypeUnknown; - } else { - infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0); - } - kernel_build_info_builder->SetOutputsDeviceType(std::vector{infer_data_type}); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get()); - return true; -} - -void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node, - std::vector *plant_inputs, std::vector *dyn_input_sizes) { - MS_EXCEPTION_IF_NULL(plant_inputs); - MS_EXCEPTION_IF_NULL(dyn_input_sizes); - MS_EXCEPTION_IF_NULL(graph); - auto output_size = AnfAlgo::GetOutputTensorNum(input_node); - dyn_input_sizes->push_back(output_size); - std::vector convert_inputs; - auto kernel_graph = graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); - if (input_node->isa()) { - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node); - } else { - for (size_t index = 0; index < output_size; ++index) { - auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)}, - {AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get()); - convert_inputs.emplace_back(tuple_get_item); - } - } - (void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs)); -} - -void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - MS_EXCEPTION_IF_NULL(graph); - auto &ori_args = cnode_ptr->inputs(); - if (ori_args.size() < 1) { - return; - } - std::vector plant_inputs; - std::vector dyn_input_sizes; - plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]); - for (size_t i = 1; i < ori_args.size(); ++i) { - auto input_node = ori_args[i]; - if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) { - auto input_size = AnfAlgo::GetOutputTensorNum(input_node); - dyn_input_sizes.push_back(input_size); - auto cnode = input_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto inputs = cnode->inputs(); - for (size_t j = 1; j < inputs.size(); ++j) { - MS_EXCEPTION_IF_NULL(inputs[j]); - if (IsValueNode(inputs[j])) { - auto success = MakeValueNode(inputs[j]); - if (!success) { - MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString(); - } - } - plant_inputs.push_back(inputs[j]); - } - } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) { - ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes); - } else { - dyn_input_sizes.push_back(-1); - plant_inputs.push_back(input_node); - } - } - // If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs. - if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int s) { return s >= 0; })) { - AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode_ptr); - cnode_ptr->set_inputs(plant_inputs); - } -} -} // namespace - -const BaseRef ConvertTupleInputToDynamicInput::DefinePattern() const { - VarPtr V = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - if (AnfAlgo::IsGraphKernel(node)) { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - std::vector todos; - kernel::GetValidKernelNodes(sub_graph, &todos); - for (auto &t : todos) { - ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast()); - } - } else { - ConvertMakeTupleInputToPlantInputs(func_graph, node->cast()); - } - return node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.h b/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.h deleted file mode 100644 index b3d8e25d6e..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ - -#include -#include - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConvertTupleInputToDynamicInput : public PatternProcessPass { - public: - explicit ConvertTupleInputToDynamicInput(bool multigraph = true) - : PatternProcessPass("convert_tuple_input_to_dynamic_input", multigraph) {} - - ~ConvertTupleInputToDynamicInput() override = default; - - const BaseRef DefinePattern() const override; - - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc deleted file mode 100644 index a5e51411bc..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/convert_tuple_output_to_maketuple.h" - -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - MS_EXCEPTION_IF_NULL(graph); - std::vector convert_inputs = {cnode_ptr->input(0)}; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode_ptr); ++index) { - auto input_node = AnfAlgo::GetInputNode(cnode_ptr, index); - if (AnfAlgo::IsTupleOutput(input_node)) { - std::vector types; - std::vector> shapes; - std::vector make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; - for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { - make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); - types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); - shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); - } - auto make_tuple = graph->NewCNode(make_tuple_inputs_list); - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); - convert_inputs.emplace_back(make_tuple); - } else { - convert_inputs.push_back(input_node); - } - } - return graph->NewCNode(convert_inputs); -} -} // namespace - -const BaseRef ConvertTupleOutputToMaketuple::DefinePattern() const { - VarPtr V = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { - return nullptr; - } - if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { - return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); - })) { - return ConvertTupleInputToMakeTuple(func_graph, cnode); - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.h b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.h deleted file mode 100644 index a16ffaf674..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H -#define MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H -#include -#include - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConvertTupleOutputToMaketuple : public PatternProcessPass { - public: - explicit ConvertTupleOutputToMaketuple(bool multigraph = true) - : PatternProcessPass("convert_tuple_output_to_maketuple", multigraph) {} - - ~ConvertTupleOutputToMaketuple() override = default; - - const BaseRef DefinePattern() const override; - - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H diff --git a/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.cc b/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.cc deleted file mode 100644 index 4d3dcfccc0..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.cc +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/pass/eliminate_redundant_op.h" -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "operator/ops.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -using KernelWithIndex = std::pair; -namespace { -CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector *pass_vector) { - MS_EXCEPTION_IF_NULL(pass_vector); - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::IsRealCNodeKernel(cnode)) { - pass_vector->push_back(make_pair(cnode, IntToSize(1))); - return cnode; - } - - auto input0 = cnode->input(0); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimMakeTuple)) { - auto temp_node = cnode->input(index + IntToSize(1)); - MS_EXCEPTION_IF_NULL(temp_node); - pass_vector->push_back(make_pair(cnode, index + IntToSize(1))); - return GetRealPrevCNode(temp_node, 0, pass_vector); - } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - auto input2 = cnode->input(2); - MS_EXCEPTION_IF_NULL(input2); - auto value_node = input2->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - pass_vector->push_back(make_pair(cnode, IntToSize(1))); - return GetRealPrevCNode(cnode->input(1), IntToSize(item_idx), pass_vector); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - pass_vector->push_back(make_pair(cnode, IntToSize(1))); - return GetRealPrevCNode(cnode->input(1), 0, pass_vector); - } else { - return nullptr; - } -} - -bool TransOpEliminateCondition(const CNodePtr &, const CNodePtr &) { return true; } - -bool CastEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { - return HasSymmetricalKernelInfo(node1, node2); -} - -bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { - return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) && - AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0); -} - -const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode, - std::vector *pass_vector) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(pass_vector); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - bool has_depend_node = false; - bool has_node_used_more_than_once = false; - auto &users = manager->node_users(); - - auto pass_size = pass_vector->size(); - for (size_t idx = 1; idx <= pass_size - 1; ++idx) { - auto nd = (*pass_vector)[idx].first; - if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) || - AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { - has_depend_node = true; - } - if (users[nd].size() >= 2) { - has_node_used_more_than_once = true; - } - } - - // when no depend node and no node used more than once, no need to rebuild the pass nodes - if (!has_depend_node) { - return prev_cnode->input(1); - } else if (!has_node_used_more_than_once) { - (void)manager->Replace(prev_cnode, prev_cnode->input(1)); - return cnode->input(1); - } else { // rebuild the pass nodes - for (size_t idx = pass_size - 2; idx > 0; --idx) { - auto new_node = func_graph->NewCNode((*pass_vector)[idx].first->inputs()); - new_node->set_input((*pass_vector)[idx].second, - (*pass_vector)[idx + 1].first->input((*pass_vector)[idx + 1].second)); - (*pass_vector)[idx].first = new_node; - } - return (*pass_vector)[1].first; - } -} -} // namespace - -void EliminateRedundantOp::Init() { - (void)redundant_process_map_.emplace(std::pair( - kFour2FiveOpName, std::pair(kFive2FourOpName, TransOpEliminateCondition))); - (void)redundant_process_map_.emplace(std::pair( - kFive2FourOpName, std::pair(kFour2FiveOpName, TransOpEliminateCondition))); - (void)redundant_process_map_.emplace(std::pair( - prim::kPrimCast->name(), std::pair(prim::kPrimCast->name(), CastEliminateCondition))); - (void)redundant_process_map_.emplace(std::pair( - kTransDataOpName, std::pair(kTransDataOpName, TransDataOpEliminateCondition))); -} - -const AnfNodePtr EliminateRedundantOp::DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { - // match the first name - auto name1 = AnfAlgo::GetCNodeName(cnode); - auto it = redundant_process_map_.find(name1); - if (it == redundant_process_map_.end()) { - return nullptr; - } - std::vector pass_vector; - pass_vector.push_back(make_pair(cnode, 1)); - auto prev_cnode = GetRealPrevCNode(cnode->input(1), 0, &pass_vector); - if (prev_cnode == nullptr) { - return nullptr; - } - // match the second name - auto name2 = AnfAlgo::GetCNodeName(prev_cnode); - if (name2 != it->second.first) { - return nullptr; - } - // match condition - auto condition_func = it->second.second; - if (condition_func == nullptr) { - return nullptr; - } - if (!condition_func(cnode, prev_cnode)) { - return nullptr; - } - - return ProcessMatchedNodes(func_graph, cnode, prev_cnode, &pass_vector); -} - -const AnfNodePtr EliminateRedundantOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode == nullptr || func_graph == nullptr) { - return nullptr; - } - - if (AnfAlgo::IsGraphKernel(node)) { - // do eliminate for ops in graph kernel. - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - auto mng = sub_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - std::vector todo; - kernel::GetValidKernelNodes(sub_graph, &todo); - for (auto &t : todo) { - CNodePtr t_cnode = t->cast(); - MS_EXCEPTION_IF_NULL(t_cnode); - auto t_new_node = DoEliminate(sub_graph, t_cnode); - if (t_new_node != nullptr && t_new_node != t) { - (void)mng->Replace(t, t_new_node); - } - } - return node; - } - // do eliminate for single op. - return DoEliminate(func_graph, cnode); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.h b/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.h deleted file mode 100644 index c44190f645..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -using ConditionFunc = std::function; -using RedundantOpPair = std::pair; - -class EliminateRedundantOp : public PatternProcessPass { - public: - explicit EliminateRedundantOp(bool multigraph = true) : PatternProcessPass("eliminate_redundant_op", multigraph) { - Init(); - } - ~EliminateRedundantOp() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - void Init(); - const AnfNodePtr DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; - std::unordered_map redundant_process_map_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.cc b/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.cc deleted file mode 100644 index 3b566b4f7c..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/pass/erase_visit_attr.h" -#include -#include -#include "kernel/common_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -const BaseRef EraseVisitAttr::DefinePattern() const { - std::shared_ptr V = std::make_shared(Visited); - std::shared_ptr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - if (node != nullptr && AnfAlgo::IsRealCNodeKernel(node)) { - if (AnfAlgo::IsGraphKernel(node)) { - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - std::vector todos; - kernel::GetValidKernelNodes(fg, &todos); - for (auto &t : todos) { - AnfAlgo::EraseNodeAttr(kAttrVisited, t); - } - } - AnfAlgo::EraseNodeAttr(kAttrVisited, node); - } else { - AnfAlgo::EraseNodeAttr(kAttrVisited, node); - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.h b/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.h deleted file mode 100644 index a986aad83a..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class EraseVisitAttr : public PatternProcessPass { - public: - explicit EraseVisitAttr(bool multigraph = true) : PatternProcessPass("erase_visit_attr", multigraph) {} - ~EraseVisitAttr() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/fuse_basic.cc b/mindspore/ccsrc/pre_activate/pass/fuse_basic.cc deleted file mode 100644 index 84edd5c5e2..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/fuse_basic.cc +++ /dev/null @@ -1,222 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/fuse_basic.h" -#include "pre_activate/pass/fuse_graph_kernel.h" - -#include -#include -#include -#include -#include -#include - -#include "operator/ops.h" -#include "utils/utils.h" -#include "utils/graph_utils.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "vm/segment_runner.h" -#include "debug/draw.h" -#include "debug/anf_ir_dump.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace opt { -namespace { -std::vector get_fusable_basic_ops(bool is_before_kernel_select) { - std::vector fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, - prim::kPrimExpandDims}; - if (!is_before_kernel_select) { - fusable_basic_ops.push_back(prim::kPrimCast); - } - return fusable_basic_ops; -} - -IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, - const AnfNodePtr &node) { - if (cur_node == node) { - return FOLLOW; - } - if (!IsPrimitiveCNode(node)) { - return EXCLUDE; - } - - auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); - bool is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - - return is_fusable ? FOLLOW : EXCLUDE; -} - -std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { - GraphKernelInfo info; - info.is_before_kernel_select = is_before_kernel_select; - // Search fusable nodes according input direction. - auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); - auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); - if (used_nodes.size() > 1) { - used_nodes = RemoveCircle(used_nodes, false); - } - TopoSortForNodeList(&used_nodes); - return used_nodes; -} - -void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { - AnfNodeSet outputs_set; - for (auto out : *outputs) { - outputs_set.insert(out); - } - - AnfNodePtrList vir_outputs; - std::unordered_map eqv; - auto fg_outputs = fg->output(); - if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { - auto cnode = fg_outputs->cast(); - for (size_t i = 1; i < cnode->size(); ++i) { - vir_outputs.push_back(cnode->input(i)); - } - } else { - vir_outputs.push_back(fg_outputs); - } - - if (vir_outputs.size() != outputs->size()) { - MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; - } - bool has_erase_outs = false; - size_t index = -1; - for (auto it = outputs->begin(); it != outputs->end();) { - index++; - auto out = *it; - eqv[out] = vir_outputs[index]; - auto users = mng->node_users()[out]; - bool is_only_control_depend_use = true; - std::vector control_depend_use_index; - std::vector control_depend_nodes; - AnfNodePtr use_out = nullptr; - for (auto &user : users) { - auto use_node = user.first; - if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { - is_only_control_depend_use = false; - continue; - } - if (outputs_set.count(use_node) != 0) { - use_out = use_node; - } - - if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { - control_depend_nodes.push_back(use_node->cast()); - control_depend_use_index.push_back(user.second); - } - } - - if (is_only_control_depend_use && !control_depend_nodes.empty()) { - MS_EXCEPTION_IF_NULL(use_out); - it = outputs->erase(it); - for (size_t i = 0; i < control_depend_nodes.size(); ++i) { - auto control_depend_node = control_depend_nodes[i]; - std::vector new_control_depend_inputs; - for (size_t j = 0; j < control_depend_node->size(); ++j) { - if (j == control_depend_use_index[i]) { - new_control_depend_inputs.push_back(use_out); - } else { - new_control_depend_inputs.push_back(control_depend_node->input(j)); - } - } - auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs); - mng->Replace(control_depend_node, new_control_depend); - has_erase_outs = true; - } - } else { - it++; - } - } - - if (!has_erase_outs) { - return; - } - - AnfNodePtr fg_new_output; - if (outputs->size() > 1) { - std::vector output_args; - output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - (void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args), - [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); - // Set output for AnfGraph - fg_new_output = fg->NewCNode(output_args); - } else { - fg_new_output = eqv[(*outputs)[0]]; - } - fg->set_output(fg_new_output, true); -} - -void FuseBasic(const std::shared_ptr &kernel_graph, const std::vector &todos, - std::unordered_set *fused_ops, bool is_before_kernel_select) { - auto mng = kernel_graph->manager(); - for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { - auto node = (*iter)->cast(); - if (node == nullptr) { - continue; - } - if (fused_ops->count(node)) { - continue; - } - auto fusable_basic_ops = get_fusable_basic_ops(is_before_kernel_select); - bool is_basic_op = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - if (!is_basic_op || !kernel_graph->nodes().contains(node)) { - continue; - } - - auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select); - if (fuse_nodes.size() <= 1) { - continue; - } - - FuncGraphPtr fg; - AnfNodePtrList inputs; - AnfNodePtrList outputs; - std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); - RemoveControlDependOut(fg, &outputs, mng); - auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select); - - ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); - - // Set graph kernel attr - std::string fuse_op_name = ""; - for (auto &fuse_node : fuse_nodes) { - fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_"; - } - fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); - fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); - } -} -} // namespace - -void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - if (mng == nullptr) { - mng = Manage(kernel_graph, true); - kernel_graph->set_manager(mng); - } - std::unordered_set fused_ops; - auto todos = TopoSort(kernel_graph->get_return()); - std::reverse(todos.begin(), todos.end()); - FuseBasic(kernel_graph, todos, &fused_ops, is_before_kernel_select); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/fuse_basic.h b/mindspore/ccsrc/pre_activate/pass/fuse_basic.h deleted file mode 100644 index fbbf5d9937..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/fuse_basic.h +++ /dev/null @@ -1,29 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ - -#include -#include "pre_activate/common/optimizer.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select); -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.cc b/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.cc deleted file mode 100644 index 0e287587a2..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.cc +++ /dev/null @@ -1,562 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/fuse_graph_kernel.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "operator/ops.h" -#include "utils/utils.h" -#include "utils/graph_utils.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "vm/segment_runner.h" -#include "debug/draw.h" -#include "debug/anf_ir_dump.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace opt { -std::vector get_fusable_basic_ops(bool is_before_kernel_select) { - std::vector fusable_basic_ops = { - prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum, - prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt, - prim::kPrimReciprocal, prim::kPrimExpandDims, prim::kPrimLessEqual}; - if (!is_before_kernel_select) { - fusable_basic_ops.push_back(prim::kPrimCast); - } - return fusable_basic_ops; -} - -std::vector get_fusable_basic_ops_with_reduce(bool is_before_kernel_select) { - std::vector fusable_basic_ops_with_reduce; - if (!is_before_kernel_select) { - fusable_basic_ops_with_reduce.push_back(prim::kPrimCast); - } - return fusable_basic_ops_with_reduce; -} - -std::vector get_reduce_ops() { - std::vector reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin, - prim::kPrimReduceMax, prim::kPrimReduceAll}; - return reduce_ops; -} - -void GetGraphKernelInfo(const FuncGraphPtr fg, GraphKernelInfo *info) { - MS_EXCEPTION_IF_NULL(fg); - auto reduce_ops = get_reduce_ops(); - const auto &nodes = fg->nodes(); - info->op_type = ELEWISE; - info->cal_step = -1; - info->reduce_op_num = 0; - for (auto node : nodes) { - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } - info->cal_step++; - auto prim = GetValueNode(cnode->input(0)); - if (prim != nullptr) { - bool is_reudce = std::any_of(reduce_ops.begin(), reduce_ops.end(), [&prim](const PrimitivePtr &op) { - return op->hash() == prim->hash() && op->name() == prim->name(); - }); - if (is_reudce) { - info->op_type = REDUCE; - info->reduce_op_num++; - } - } - } -} - -bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) { - auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); - auto fusable_basic_ops_with_reduce = get_fusable_basic_ops_with_reduce(info.is_before_kernel_select); - bool is_fusable = false; - if (info.op_type == REDUCE && - (info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) { - is_fusable = std::any_of(fusable_basic_ops_with_reduce.begin(), fusable_basic_ops_with_reduce.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - } else { - is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - } - - return is_fusable; -} - -IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, - const AnfNodePtr &node) { - if (cur_node == node) { - return FOLLOW; - } - if (!IsPrimitiveCNode(node)) { - return EXCLUDE; - } - - bool is_fusable = IsFuse(info, node); - return is_fusable ? FOLLOW : EXCLUDE; -} - -IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, - const AnfNodePtr &node) { - if (cur_node == node) { - return FOLLOW; - } - if (AnfAlgo::IsGraphKernel(node)) { - auto cnode = node->cast(); - auto fg = GetValueNode(cnode->input(kAnfPrimitiveIndex)); - auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - MS_EXCEPTION_IF_NULL(fg_attr_val); - auto fg_attr = GetValue(fg_attr_val); - if (fg_attr == kApplyMomentumOpName) { - return FOLLOW; - } - return EXCLUDE; - } - if (!IsPrimitiveCNode(node)) { - return EXCLUDE; - } - - bool is_fusable = IsFuse(info, node); - return is_fusable ? FOLLOW : EXCLUDE; -} - -bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &check_node, - std::set *cached_unconnected_set) { - if (!check_node->isa() || AnfAlgo::IsGraphKernel(check_node)) { - return false; - } - - auto cnode = check_node->cast(); - const auto &inputs = cnode->inputs(); - // there is a input not in fused_op_set, but the input depends on the fused_op_set - bool has_circle = false; - for (auto input : inputs) { - if (input->isa() && !fused_op_set.count(input)) { - std::set done; - std::vector todos = {input}; - while (!todos.empty()) { - auto node = todos.back(); - todos.pop_back(); - if (done.count(node) || cached_unconnected_set->count(node)) { - continue; - } - - done.insert(node); - if (fused_op_set.count(node)) { - has_circle = true; - break; - } - - if (node->isa()) { - auto cnode_ptr = node->cast(); - for (auto it : cnode_ptr->inputs()) { - if (it->isa()) { - todos.push_back(it); - } - } - } - } - - if (has_circle) { - return true; - } - cached_unconnected_set->insert(done.begin(), done.end()); - } - } - - return false; -} - -bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { - if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { - auto &inputs = out->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - real_outs->push_back(inputs[i]); - } - return true; - } - - if (AnfAlgo::GetCNodeFuncGraphPtr(out) != nullptr) { - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); - auto fg_out = fg->output(); - if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) { - auto inputs = fg_out->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - real_outs->push_back(inputs[i]); - } - return true; - } - } - return false; -} - -std::vector RemoveCircle(const std::vector &fused_op, bool is_backward) { - std::set cached_unconnected_set; - std::set fused_op_set(fused_op.begin(), fused_op.end()); - auto include = [&fused_op_set](const AnfNodePtr &node) { - if (fused_op_set.count(node)) { - return FOLLOW; - } - return EXCLUDE; - }; - for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { - bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set); - // delete the circle node and the node which depend on the circle node in fused op - if (has_circle) { - auto mng = (*iter)->func_graph()->manager(); - std::vector erase_nodes; - if (is_backward) { - erase_nodes = DeepUsersSearch(*iter, include, mng); - } else { - erase_nodes = DeepLinkedGraphSearch(*iter, include); - } - for (auto erase_node : erase_nodes) { - fused_op_set.erase(erase_node); - } - } - } - - std::vector res; - for (auto node : fused_op) { - if (fused_op_set.count(node)) { - res.push_back(node); - } - } - return res; -} - -void TopoSortForNodeList(std::vector *lst) { - if (lst->size() < 2) { - return; - } - - std::vector res; - std::set node_sets(lst->begin(), lst->end()); - std::map> ins; - std::map> outs; - std::queue q; - for (auto node : *lst) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - for (auto input : cnode->inputs()) { - if (!node_sets.count(input)) { - continue; - } - // out_degree - outs[input].insert(node); - // in_degree - ins[node].insert(input); - } - if (!ins.count(node)) { - ins[node] = {}; - } - } - - for (auto p : ins) { - if (p.second.size() == 0) { - q.push(p.first); - } - } - - while (!q.empty()) { - auto node = q.front(); - q.pop(); - res.push_back(node); - if (!outs.count(node)) { - continue; - } - for (auto out : outs[node]) { - if (!ins.count(out)) { - continue; - } - ins[out].erase(node); - if (ins[out].size() == 0) { - q.push(out); - } - } - } - - lst->assign(res.begin(), res.end()); -} - -std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { - auto func_graph = cnode->func_graph(); - auto graph_kernel_g = GetValueNode(cnode->input(0)); - GraphKernelInfo info; - info.is_before_kernel_select = is_before_kernel_select; - GetGraphKernelInfo(graph_kernel_g, &info); - auto mng = func_graph->manager(); - // Search fusable nodes according input direction. - auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); - auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); - std::reverse(used_nodes.begin(), used_nodes.end()); - // Search fusable nodes according output direction. - auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, info, std::placeholders::_1); - auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng); - - used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end()); - if (used_nodes.size() > 1) { - used_nodes = RemoveCircle(used_nodes); - } - TopoSortForNodeList(&used_nodes); - return used_nodes; -} - -AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) { - auto out_spec = node->abstract(); - if (out_spec->isa()) { - return out_spec->cast()->elements()[output_idx]; - } - return out_spec; -} - -AnfNodePtr CreateNewFuseCNode(const std::shared_ptr &kernel_graph, const FuncGraphPtr &fg, - const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, - bool is_before_kernel_select) { - auto func_node = NewValueNode(fg); - std::vector fn_inputs; - fn_inputs.push_back(func_node); - fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); - auto fuse_cnode = kernel_graph->NewCNode(fn_inputs); - // Set output abstract - if (outputs.size() > 1) { - std::vector out_specs; - for (size_t i = 0; i < outputs.size(); ++i) { - out_specs.push_back(outputs[i]->abstract()); - } - auto out_spec = std::make_shared(out_specs); - fuse_cnode->set_abstract(out_spec); - } else { - fuse_cnode->set_abstract(outputs[0]->abstract()); - } - // Set parameter abstract. - for (size_t i = 0; i < inputs.size(); ++i) { - auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); - auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); - fg->parameters()[i]->set_abstract(input_abs); - if (is_before_kernel_select) { - fg->parameters()[i]->set_kernel_info(std::make_shared()); - } - } - // Set kernel info. - if (!is_before_kernel_select) { - std::vector graph_input_format; - std::vector graph_input_type; - std::vector graph_output_format; - std::vector graph_output_type; - for (size_t i = 0; i < inputs.size(); ++i) { - auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); - auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); - graph_input_format.push_back(input_format); - auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); - graph_input_type.push_back(input_type); - auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); - fg->parameters()[i]->set_abstract(input_abs); - } - auto new_outputs = outputs; - if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) { - std::vector real_outs; - if (IsMakeTupleOut(outputs[0], &real_outs)) { - new_outputs = real_outs; - } - } - for (size_t i = 0; i < new_outputs.size(); ++i) { - auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0); - auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); - auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); - graph_output_format.push_back(output_format); - graph_output_type.push_back(output_type); - } - kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; - graph_info_builder.SetInputsFormat(graph_input_format); - graph_info_builder.SetInputsDeviceType(graph_input_type); - graph_info_builder.SetOutputsFormat(graph_output_format); - graph_info_builder.SetOutputsDeviceType(graph_output_type); - graph_info_builder.SetProcessor(kernel::Processor::AICORE); - graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); - graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); - auto graph_selected_info = graph_info_builder.Build(); - AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, fuse_cnode.get()); - } - return fuse_cnode; -} - -void ReplaceNewFuseCNode(const std::shared_ptr &kernel_graph, const AnfNodePtr &new_fuse_cnode, - const AnfNodePtrList &outputs) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - // single out - if (outputs.size() == 1) { - mng->Replace(outputs[0], new_fuse_cnode); - return; - } - - std::vector fn_inputs; - for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) { - AnfNodePtrList real_outs; - // not make tuple out, replace - if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) { - fn_inputs.clear(); - fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); - fn_inputs.push_back(new_fuse_cnode); - fn_inputs.push_back(NewValueNode(MakeValue(SizeToInt(out_idx)))); - auto new_out = kernel_graph->NewCNode(fn_inputs); - new_out->set_abstract(outputs[out_idx]->abstract()); - mng->Replace(outputs[out_idx], new_out); - continue; - } - - // the out is make tuple , modify the get_item node's value - auto users = mng->node_users()[outputs[out_idx]]; - for (auto &user : users) { - auto use_node = user.first; - if (use_node->isa() && (IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem))) { - auto get_item_cnode = use_node->cast(); - auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(value_input); - auto value_node = value_input->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - int new_item_idx = SizeToInt(out_idx) + item_idx; - fn_inputs.clear(); - fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); - fn_inputs.push_back(new_fuse_cnode); - fn_inputs.push_back(NewValueNode(new_item_idx)); - auto new_out = kernel_graph->NewCNode(fn_inputs); - new_out->set_abstract(get_item_cnode->abstract()); - mng->Replace(get_item_cnode, new_out); - } - } - } -} - -AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) { - AnfNodePtrList outs; - auto out_node = (*fg)->output(); - if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { - std::vector output_args; - auto out_cnode = out_node->cast(); - for (auto out : out_cnode->inputs()) { - if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { - auto inputs = out->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - output_args.push_back(inputs[i]); - } - } else { - output_args.push_back(out); - } - } - if (output_args.size() != out_cnode->inputs().size()) { - auto new_out = (*fg)->NewCNode(output_args); - (*mng)->Replace(out_node, new_out); - } - - for (size_t i = 1; i < output_args.size(); ++i) { - outs.push_back(output_args[i]); - } - return outs; - } - - outs.push_back(out_node); - return outs; -} - -AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) { - AnfNodePtrList res; - if (outs.size() <= 1) { - return outs; - } - - for (auto out : outs) { - AnfNodePtrList real_outs; - if (IsMakeTupleOut(out, &real_outs)) { - res.insert(res.end(), real_outs.begin(), real_outs.end()); - continue; - } - res.push_back(out); - } - return res; -} - -void FuseGraphKernel(const std::shared_ptr &kernel_graph, bool is_before_kernel_select) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - if (mng == nullptr) { - mng = Manage(kernel_graph, true); - kernel_graph->set_manager(mng); - } - auto &todos = kernel_graph->execution_order(); - for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { - auto node = *iter; - if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) { - continue; - } - - auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (fg_attr != nullptr) { - auto fg_name = GetValue(fg_attr); - if (graph_kernel_black_list.count(fg_name) != 0) { - continue; - } - } - - auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select); - if (fuse_nodes.size() <= 1) { - continue; - } - - FuncGraphPtr fg; - AnfNodePtrList inputs; - AnfNodePtrList outputs; - std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); - - // Remove nest make tuple in outs - auto expand_out = GetExpandOuts(outputs); - auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select); - - ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); - - // Inline origin graphkernel - auto cnodes = fg->GetOrderedCnodes(); - for (const auto &n : cnodes) { - if (!AnfAlgo::IsGraphKernel(n)) { - continue; - } - auto graph_kernel_g = GetValueNode(n->input(0)); - AnfNodePtrList ins; - ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); - auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); - mng->Replace(n, out); - } - - EliminateMakeTuple(&fg, &mng); - // Set graphkernel flag - auto ori_fg = GetValueNode(node->input(kAnfPrimitiveIndex)); - fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, ori_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.h b/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.h deleted file mode 100644 index a5a26765a3..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.h +++ /dev/null @@ -1,63 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ - -#include -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -enum GraphKernelType { - ELEWISE = 0, // only contain elewise basic ops - REDUCE, // contain reduce ops - CUBE, // contain cube ops -}; -struct GraphKernelInfo { - GraphKernelType op_type = ELEWISE; - bool is_before_kernel_select = false; - int reduce_op_num = 0; - int cal_step = 0; -}; - -// when reduce graph kernel's cal step is greater than this number, not fuse -const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5; -// when reduce graph kernel contain reduce op num is greater than this number, not fuse -const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2; - -const std::set graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", - "LambNextMV", "LambUpdateWithLR"}; - -std::vector RemoveCircle(const std::vector &fused_op, bool is_backward = true); - -void TopoSortForNodeList(std::vector *lst); - -AnfNodePtr CreateNewFuseCNode(const std::shared_ptr &kernel_graph, const FuncGraphPtr &fg, - const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, - bool is_before_kernel_select); - -void ReplaceNewFuseCNode(const std::shared_ptr &kernel_graph, const AnfNodePtr &new_fuse_cnode, - const AnfNodePtrList &outputs); - -void FuseGraphKernel(const std::shared_ptr &kernel_graph, bool is_before_kernel_select = false); -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/getitem_tuple.cc b/mindspore/ccsrc/pre_activate/pass/getitem_tuple.cc deleted file mode 100644 index af16017a7c..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/getitem_tuple.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/getitem_tuple.h" - -#include -#include "operator/ops.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -bool IsC(const BaseRef &n) { - MS_EXCEPTION_IF_NULL(n); - if (utils::isa(n)) { - AnfNodePtr in = utils::cast(n); - MS_EXCEPTION_IF_NULL(in); - return in->isa(); - } else { - return false; - } -} -} // namespace - -const BaseRef GetitemTuple::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VarPtr C = std::make_shared(IsC); - return VectorRef({prim::kPrimTupleGetItem, VectorRef({prim::kPrimMakeTuple, Xs}), C}); -} - -const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - CNodePtr tuple_getitem = node->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) { - MS_LOG(EXCEPTION) << "tuple getitem's input num is wrong"; - } - AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(make_tuple_anf); - AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - if (IsValueNode(index_node)) { - ValueNodePtr value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - CNodePtr make_tuple = make_tuple_anf->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - if (make_tuple->inputs().size() > IntToSize(index + 1)) { - auto ret = make_tuple->input(IntToSize(index + 1)); - MS_EXCEPTION_IF_NULL(ret); - return ret; - } - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/getitem_tuple.h b/mindspore/ccsrc/pre_activate/pass/getitem_tuple.h deleted file mode 100644 index 0fc42a15dc..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/getitem_tuple.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class GetitemTuple : public PatternProcessPass { - public: - explicit GetitemTuple(bool multigraph = true) : PatternProcessPass("getitem_tuple", multigraph) {} - ~GetitemTuple() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc deleted file mode 100644 index 1d5f909e7d..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/pass/optimize_dependence.h" -#include -#include -#include -#include "pre_activate/common/helper.h" -#include "operator/ops.h" -#include "utils/utils.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -constexpr auto kSingleInputIndex = 1; -namespace { -AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - string op_name = AnfAlgo::GetCNodeName(cnode); - // Currently we only eliminate transdata or cast nodes. - if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { - return nullptr; - } - CheckCNodeInputSize(cnode, kSingleInputIndex + 1); - return cnode->input(kSingleInputIndex); -} - -AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { - return nullptr; - } - std::vector new_make_tuple_inputs; - bool need_update = false; - for (const auto &input : cnode->inputs()) { - AnfNodePtr replace_input = GetReplaceNode(input); - // If replace input is not null, it will be the input of the TransData or Cast. - if (replace_input == nullptr) { - new_make_tuple_inputs.push_back(input); - continue; - } - new_make_tuple_inputs.push_back(replace_input); - need_update = true; - } - if (need_update) { - auto kernel_graph = func_graph->cast>(); - CNodePtr new_make_tuple = nullptr; - if (kernel_graph == nullptr) { - new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs); - } else { - new_make_tuple = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_make_tuple); - new_make_tuple->set_inputs(new_make_tuple_inputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->Replace(cnode, new_make_tuple); - return new_make_tuple; - } - return nullptr; -} -} // namespace - -const BaseRef OptimizeDependence::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return nullptr; - } - auto node_name = AnfAlgo::GetCNodeName(node); - if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) { - return nullptr; - } - size_t index = 0; - auto depend_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(depend_cnode); - std::vector new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)}; - if (node_name == prim::kPrimDepend->name()) { - index = 1; - new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend)); - } - if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) { - MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got " - << AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString(); - } - auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode); - while (index < input_num) { - auto replace_node = GetConvertNode(func_graph, node, index); - MS_EXCEPTION_IF_NULL(replace_node); - new_depend_inputs.push_back(replace_node); - ++index; - } - auto kernel_graph = func_graph->cast>(); - CNodePtr new_depend = nullptr; - if (kernel_graph == nullptr) { - new_depend = func_graph->NewCNode(new_depend_inputs); - MS_EXCEPTION_IF_NULL(new_depend); - new_depend->set_abstract(node->abstract()); - new_depend->set_scope(node->scope()); - } else { - new_depend = kernel_graph->NewCNode(depend_cnode); - MS_EXCEPTION_IF_NULL(new_depend); - new_depend->set_inputs(new_depend_inputs); - } - return new_depend; -} - -const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, - const size_t index) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto depend_cnode = node->cast(); - auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); - MS_EXCEPTION_IF_NULL(replacing_node); - if (!replacing_node->isa()) { - return replacing_node; - } - auto replacing_cnode = replacing_node->cast(); - MS_EXCEPTION_IF_NULL(replacing_cnode); - // Deal with the make_tuple with TransData or Cast inputs. - auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode); - if (make_tuple_replace_node != nullptr) { - return make_tuple_replace_node; - } - AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); - if (replace_node == nullptr) { - MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); - return replacing_node; - } - return replace_node; -} - -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h deleted file mode 100644 index 30027b790a..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class OptimizeDependence : public PatternProcessPass { - public: - explicit OptimizeDependence(bool multigraph = true) : PatternProcessPass("optimize_dependence", multigraph) {} - ~OptimizeDependence() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ diff --git a/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h b/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h index 5c7551a190..612ccde1a5 100644 --- a/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h +++ b/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h @@ -25,7 +25,7 @@ #include #include #include "ir/tensor.h" -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" #include "predict/schema/inner/ms_generated.h" using TensorPtr = mindspore::tensor::TensorPtr; diff --git a/mindspore/ccsrc/predict/converter/kernel2ms.cc b/mindspore/ccsrc/predict/converter/kernel2ms.cc index 1b1277aade..04aceb62eb 100644 --- a/mindspore/ccsrc/predict/converter/kernel2ms.cc +++ b/mindspore/ccsrc/predict/converter/kernel2ms.cc @@ -18,7 +18,7 @@ #include #include "ir/anf.h" #include "predict/converter/lite_model/op_attr_packer.h" -#include "mindspore/ccsrc/operator/ops.h" +#include "mindspore/ccsrc/frontend/operator/ops.h" namespace mindspore { namespace executor { diff --git a/mindspore/ccsrc/predict/converter/kernel2ms.h b/mindspore/ccsrc/predict/converter/kernel2ms.h index 7013f88107..8cbc89ed6a 100644 --- a/mindspore/ccsrc/predict/converter/kernel2ms.h +++ b/mindspore/ccsrc/predict/converter/kernel2ms.h @@ -22,7 +22,7 @@ #include #include #include -#include "session/kernel_graph.h" +#include "backend/session/kernel_graph.h" #include "predict/converter/executor_tensor.h" #include "predict/schema/inner/ms_generated.h" #include "predict/converter/attr_utils/convert_util.h" diff --git a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h index 89e38d1871..31f14ef73a 100644 --- a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h +++ b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h @@ -20,7 +20,7 @@ #include #include #include -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" #include "predict/schema/inner/ms_generated.h" static constexpr size_t kNIndex = 0; diff --git a/mindspore/ccsrc/predict/predict.h b/mindspore/ccsrc/predict/predict.h index 7c65f16619..9125451492 100644 --- a/mindspore/ccsrc/predict/predict.h +++ b/mindspore/ccsrc/predict/predict.h @@ -19,7 +19,7 @@ #include #include -#include "session/session_basic.h" +#include "backend/session/session_basic.h" #include "predict/converter/kernel2ms.h" namespace mindspore { diff --git a/mindspore/ccsrc/pynative/CMakeLists.txt b/mindspore/ccsrc/pynative/CMakeLists.txt deleted file mode 100644 index 5139160774..0000000000 --- a/mindspore/ccsrc/pynative/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc") - -if (ENABLE_GE) - file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc") - list(APPEND _PYNATIVE_SRC_LIST ${_GE_SRC_LIST}) -endif () - -set_property(SOURCE ${_PYNATIVE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PYNATIVE) -add_library(_mindspore_pynative_obj OBJECT ${_PYNATIVE_SRC_LIST}) diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pynative/base.h deleted file mode 100644 index 60ae869227..0000000000 --- a/mindspore/ccsrc/pynative/base.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PYNATIVE_BASE_H_ -#define MINDSPORE_CCSRC_PYNATIVE_BASE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "ir/primitive.h" -#include "pipeline/static_analysis/abstract_value.h" - -namespace mindspore { -namespace pynative { -namespace py = pybind11; - -enum PynativeStatusCode { - PYNATIVE_SUCCESS = 0, - PYNATIVE_OP_NOT_IMPLEMENTED_ERR = 1, - PYNATIVE_OP_INPUTS_ERR = 2, - PYNATIVE_OP_PARAMS_ERR = 3, - PYNATIVE_OP_ATTRS_ERR = 4, - PYNATIVE_GRAPH_MANAGER_ERR = 5, - PYNATIVE_GRAPH_GE_BUILD_ERR = 6, - PYNATIVE_GRAPH_GE_RUN_ERR = 7, - PYNATIVE_UNKNOWN_STATE = 0XFF -}; - -enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM }; - -struct OpExecInfo { - PrimitivePyPtr py_primitive; - std::string op_name; - AbstractBasePtr abstract; - - py::tuple op_inputs; - py::tuple inputs_mask; - py::dict op_attrs; -}; -using OpExecInfoPtr = std::shared_ptr; -OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args); - -const std::set ignore_infer_prim = {"make_ref"}; -} // namespace pynative -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PYNATIVE_BASE_H_ diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc deleted file mode 100644 index f477bfbdcd..0000000000 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ /dev/null @@ -1,1113 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pynative/pynative_execute.h" - -#include -#include -#include -#include -#include - -#include "debug/trace.h" -#include "ir/tensor_py.h" -#include "ir/param_value_py.h" -#include "utils/any.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" -#include "operator/composite/composite.h" -#include "operator/composite/do_signature.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/resolve.h" -#include "pipeline/static_analysis/prim.h" -#include "session/session_factory.h" -#include "pre_activate/pass/const_input_to_attr_registry.h" -#include "pre_activate/common/helper.h" -#include "pipeline/action.h" - -#include "pynative/base.h" -#include "pybind_api/api_register.h" -#include "vm/transform.h" - -#include "optimizer/ad/grad.h" -#include "pipeline/resource.h" -#include "pipeline/pipeline.h" -#include "pipeline/pass.h" - -#ifdef ENABLE_GE -#include "pynative/pynative_execute_ge.h" -#endif - -using mindspore::tensor::TensorPy; - -const char SINGLE_OP_GRAPH[] = "single_op_graph"; -// primitive unable to infer value for constant input in PyNative mode -const std::set vm_operators = {"make_ref", "HookBackward", "stop_gradient"}; - -namespace mindspore { -namespace pynative { - -static std::shared_ptr session = nullptr; -PynativeExecutorPtr PynativeExecutor::executor_ = nullptr; -std::mutex PynativeExecutor::instance_lock_; -ResourcePtr PynativeExecutor::resource_; - -template -void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) { - try { - (executor->*method)(args...); - } catch (const py::error_already_set &ex) { - // print function call stack info before release - std::ostringstream oss; - trace::TraceGraphEval(); - trace::GetEvalStackInfo(oss); - // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see - // these info from screen, no need to open log file to find these info - py::print(oss.str()); - MS_LOG(ERROR) << oss.str(); - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::value_error(ex); - } catch (const py::index_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::index_error(ex); - } catch (const std::exception &ex) { - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - PynativeExecutor::GetInstance()->Clean(); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; - } -} - -inline ValuePtr PyAttrValue(const py::object &obj) { - ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj); - if (!converted_ret) { - MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); - } - return converted_ret; -} - -std::string GetId(const py::object &obj) { - py::object to_process = obj; - std::string prefix = ""; - if (py::isinstance(to_process)) { - auto p_list = py::cast(to_process); - if (p_list.size() == 0) { - return "empty"; - } - prefix = "tuple:"; - std::string key = ""; - for (size_t i = 0; i < p_list.size(); ++i) { - key += std::string(py::str(GetId(p_list[i]))) + ":"; - } - return prefix + key; - } - if (py::isinstance(to_process)) { - return prefix + std::string(py::str(to_process)); - } - if (py::isinstance(to_process)) { - return prefix + std::string(py::str(to_process)); - } - if (py::isinstance(to_process)) { - auto tensor_ptr = py::cast(to_process); - return prefix + tensor_ptr->id(); - } - - py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj); - return py::cast(ret); -} - -py::object GetTupleObj(const py::object &obj) { - py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); - py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj); - return obj_tuple; -} - -std::map> GetTypeIndex(const std::vector &dtypes) { - std::map> type_indexes; - for (size_t i = 0; i < dtypes.size(); ++i) { - auto it = type_indexes.find(dtypes[i]); - if (it == type_indexes.end()) { - (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector{i})); - } else { - it->second.push_back(i); - } - } - return type_indexes; -} - -std::map GetDstType(const py::tuple &py_args, - const std::map> &type_indexes) { - std::map dst_type; - for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) { - auto type = it->first; - auto indexes = it->second; - if (indexes.size() < 2) { - continue; - } - size_t m_index = indexes[0]; - for (size_t i = 1; i < indexes.size(); ++i) { - if (py::isinstance(py_args[indexes[i]])) { - m_index = indexes[i]; - } - } - (void)dst_type.insert(std::make_pair(type, m_index)); - } - return dst_type; -} - -py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args, - py::list *const out_args_list) { - auto &py_args = *out_args; - py::tuple input_mask(args.size()); - for (size_t i = 0; i < args.size(); ++i) { - if (py::hasattr(args[i], "__parameter__")) { - input_mask[i] = true; - } else { - input_mask[i] = false; - } - py_args[i] = GetTupleObj(args[i]); - } - auto signature = prim->signatures(); - std::vector dtypes; - (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature &sig) { return sig.dtype; }); - int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); - if (dtypes.size() == 0 || static_cast(dtypes.size()) == empty_dtype_count) { - return input_mask; - } - auto type_indexes = GetTypeIndex(dtypes); - auto dst_type = GetDstType(py_args, type_indexes); - for (size_t i = 0; i < py_args.size(); ++i) { - auto it = dst_type.find(dtypes[i]); - if (it != dst_type.end() && it->second != i && - (py::isinstance(py_args[i]) || py::isinstance(py_args[i]))) { - auto tensor_ptr = py::cast(py_args[it->second]); - if (py::isinstance(py_args[i])) { - py_args[i] = std::make_shared(py::cast(py_args[i]), tensor_ptr->Dtype()); - (*out_args_list)[i] = py_args[i]; - } else { - double arg_value = py::cast(py_args[i]); - py_args[i] = std::make_shared(arg_value, tensor_ptr->Dtype()); - (*out_args_list)[i] = py_args[i]; - } - continue; - } - } - return input_mask; -} - -void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) { - size_t size = py_args.size(); - AbstractBasePtrList args_spec_list; - for (size_t i = 0; i < size; i++) { - ValuePtr input_value = PyAttrValue(py_args[i]); - args_spec_list.emplace_back(abstract::FromValueInside( - input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa())); - } - AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); - op_exec_info->abstract = infer_res; -} - -OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) { - if (args.size() != PY_ARGS_NUM) { - MS_LOG(ERROR) << "Three args are needed by RunOp"; - return nullptr; - } - auto op_exec_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(op_exec_info); - op_exec_info->op_name = py::cast(args[PY_NAME]); - auto prim = py::cast(args[PY_PRIM]); - auto pyobj = prim->GetPyObj(); - if (pyobj == nullptr) { - MS_LOG(EXCEPTION) << "pyobj is empty"; - } - - py::list a = args[PY_INPUTS]; - size_t input_num = a.size(); - op_exec_info->op_inputs = py::tuple(input_num); - - op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args); - // use python infer method - if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { - PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get()); - } - op_exec_info->py_primitive = prim; - op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); - if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { - MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; - return nullptr; - } - return op_exec_info; -} - -std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, - const std::vector &input_tensors) { - MS_EXCEPTION_IF_NULL(op_exec_info); - std::string graph_info; - // get input tensor info - size_t input_num = op_exec_info->op_inputs.size(); - for (size_t index = 0; index < input_num; ++index) { - auto input = op_exec_info->op_inputs[index]; - if (py::isinstance(input)) { - auto tensor_ptr = py::cast(input); - (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_"); - } - } - // get prim and abstract info - MS_EXCEPTION_IF_NULL(op_exec_info->abstract); - (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" + - op_exec_info->abstract->ToString()); - return graph_info; -} - -py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { - MS_LOG(INFO) << "RunOpInVM start"; - - MS_EXCEPTION_IF_NULL(status); - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive); - if (op_exec_info->op_name == "HookBackward") { - auto op_inputs = op_exec_info->op_inputs; - py::tuple result(op_inputs.size()); - for (size_t i = 0; i < op_inputs.size(); i++) { - py::object input = op_inputs[i]; - if (py::hasattr(input, "__parameter__")) { - result[i] = py::getattr(input, "data"); - } else { - auto tensor = py::cast(input); - auto new_tensor = std::make_shared(tensor->data_type(), tensor->shape(), tensor->data_ptr()); - new_tensor->set_device_address(tensor->device_address()); - new_tensor->set_dirty(tensor->is_dirty()); - result[i] = new_tensor; - } - } - *status = PYNATIVE_SUCCESS; - MS_LOG(INFO) << "RunOpInVM end"; - return std::move(result); - } - auto func = op_exec_info->py_primitive->GetComputeFunction(); - if (py::isinstance(func)) { - MS_LOG(ERROR) << "VM failed to get func"; - *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; - py::tuple err_ret(0); - return std::move(err_ret); - } - - // execute op - py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs)); - *status = PYNATIVE_SUCCESS; - MS_LOG(INFO) << "RunOpInVM end"; - return std::move(result); -} - -bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, - const std::unordered_set &input_attrs) { - MS_EXCEPTION_IF_NULL(op_prim); - auto input_names_value = op_prim->GetAttr(kAttrInputNames); - if (input_names_value == nullptr) { - return false; - } - auto input_names_vec = GetValue>(input_names_value); - if (input_index >= input_names_vec.size()) { - MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!"; - } - - if (input_attrs.find(input_index) != input_attrs.end()) { - ValuePtr value = parse::data_converter::PyDataToValue(input_object); - MS_EXCEPTION_IF_NULL(value); - auto input_name = input_names_vec[input_index]; - op_prim->set_attr(input_name, value); - return true; - } - return false; -} - -void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim, - std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(op_prim); - MS_EXCEPTION_IF_NULL(input_tensors); - for (const auto &input_object : tuple_inputs) { - if (!py::isinstance(input_object)) { - MS_LOG(EXCEPTION) << "The input object is not a tensor!"; - } - auto tensor = py::cast(input_object); - MS_EXCEPTION_IF_NULL(tensor); - input_tensors->push_back(tensor); - } - op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector{SizeToInt(tuple_inputs.size())})); -} - -void ConvertValueTupleToTensor(const py::object &input_object, std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(input_tensors); - ValuePtr input_value = parse::data_converter::PyDataToValue(input_object); - MS_EXCEPTION_IF_NULL(input_value); - if (!input_value->isa()) { - MS_LOG(EXCEPTION) << "The input object is not a value tuple!"; - } - auto value_tuple = input_value->cast(); - MS_EXCEPTION_IF_NULL(value_tuple); - tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple); - MS_EXCEPTION_IF_NULL(tensor_ptr); - input_tensors->push_back(tensor_ptr); -} - -void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, - std::vector *input_tensors, int *tensor_mask) { - MS_EXCEPTION_IF_NULL(op_prim); - MS_EXCEPTION_IF_NULL(input_tensors); - MS_EXCEPTION_IF_NULL(tensor_mask); - - if (!py::isinstance(input_object)) { - MS_LOG(EXCEPTION) << "The input should be a tuple!"; - } - auto tuple_inputs = py::cast(input_object); - if (tuple_inputs.size() == 0) { - MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!"; - } - if (py::isinstance(tuple_inputs[0])) { - PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors); - } else { - ConvertValueTupleToTensor(input_object, input_tensors); - *tensor_mask = kValueNodeTensorMask; - } -} - -void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, - std::vector *input_tensors, int *tensor_mask) { - MS_EXCEPTION_IF_NULL(op_prim); - MS_EXCEPTION_IF_NULL(input_tensors); - MS_EXCEPTION_IF_NULL(tensor_mask); - tensor::TensorPtr tensor_ptr = nullptr; - if (py::isinstance(input_object)) { - tensor_ptr = py::cast(input_object); - } else if (py::isinstance(input_object)) { - double input_value = py::cast(input_object); - tensor_ptr = std::make_shared(input_value, kFloat32); - *tensor_mask = kValueNodeTensorMask; - } else if (py::isinstance(input_object)) { - tensor_ptr = std::make_shared(py::cast(input_object), kInt32); - *tensor_mask = kValueNodeTensorMask; - } else if (py::isinstance(input_object)) { - tensor_ptr = TensorPy::MakeTensor(py::cast(input_object), nullptr); - } else if (py::isinstance(input_object)) { - auto list_inputs = py::cast(input_object); - py::tuple tuple_inputs(list_inputs.size()); - for (size_t i = 0; i < tuple_inputs.size(); ++i) { - tuple_inputs[i] = list_inputs[i]; - } - ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask); - return; - } else if (py::isinstance(input_object)) { - ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask); - return; - } else if (py::isinstance(input_object)) { - return; - } else { - MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; - } - MS_EXCEPTION_IF_NULL(tensor_ptr); - input_tensors->push_back(tensor_ptr); -} - -void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *tensors_mask, - std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(op_run_info); - MS_EXCEPTION_IF_NULL(tensors_mask); - MS_EXCEPTION_IF_NULL(input_tensors); - PrimitivePtr op_prim = op_run_info->py_primitive; - MS_EXCEPTION_IF_NULL(op_prim); - - if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) { - MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size " - << op_run_info->inputs_mask.size(); - } - opt::ConstInputToAttrInfoRegister reg; - bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); - size_t input_num = op_run_info->op_inputs.size(); - for (size_t index = 0; index < input_num; ++index) { - // convert const input to attr - if (reg_exist && - RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) { - continue; - } - // convert const and tuple input to tensor - int tensor_mask = py::cast(op_run_info->inputs_mask[index]); - ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask); - // mark tensors, data : 0, weight : 1, valuenode: 2 - std::vector new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); - tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); - } -} - -void EraseValueNodeTensor(const std::vector &tensors_mask, std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(input_tensors); - if (input_tensors->size() != tensors_mask.size()) { - MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size " - << tensors_mask.size(); - } - std::vector new_input_tensors; - for (size_t index = 0; index < tensors_mask.size(); ++index) { - if (tensors_mask[index] != kValueNodeTensorMask) { - new_input_tensors.push_back(input_tensors->at(index)); - } - } - *input_tensors = new_input_tensors; -} - -py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_enable_pynative_infer(true); - std::string device_target = ms_context->device_target(); - if (device_target != kAscendDevice && device_target != kGPUDevice) { - MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; - } - - if (session == nullptr) { - session = session::SessionFactory::Get().Create(device_target); - } - MS_EXCEPTION_IF_NULL(session); - session->Init(ms_context->device_id()); - - std::vector input_tensors; - std::vector tensors_mask; - ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); - // get graph info for checking it whether existing in the cache - std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); - session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask); - EraseValueNodeTensor(tensors_mask, &input_tensors); - py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); - ms_context->set_enable_pynative_infer(false); - *status = PYNATIVE_SUCCESS; - return result; -} - -py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, - PynativeStatusCode *const status) { - MS_EXCEPTION_IF_NULL(status); - py::object result; - switch (backend_policy) { - case kMsBackendVmOnly: { - // use vm only - MS_LOG(INFO) << "RunOp use VM only backend"; - result = RunOpInVM(op_exec_info, status); - break; - } - case kMsBackendGePrior: { -#ifdef ENABLE_GE - // use GE first, use vm when GE fails - MS_LOG(INFO) << "RunOp use GE first backend"; - result = RunOpInGE(op_exec_info, status); - if (*status != PYNATIVE_SUCCESS) { - result = RunOpInVM(op_exec_info, status); - } -#endif - break; - } - case kMsBackendMsPrior: { - // use Ms fisrt,use others when ms failed - MS_LOG(INFO) << "RunOp use Ms first backend"; - result = RunOpInMs(op_exec_info, status); - if (*status != PYNATIVE_SUCCESS) { - MS_LOG(ERROR) << "RunOp use Ms backend failed!!!"; - } - break; - } - default: - MS_LOG(ERROR) << "No backend configured for run op"; - } - return result; -} - -AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) { - if (!grad_flag_ || graph_info_map_.empty()) { - return nullptr; - } - std::vector inputs; - auto prim = op_exec_info->py_primitive; - inputs.push_back(NewValueNode(prim)); - py::tuple op_masks = op_exec_info->inputs_mask; - AbstractBasePtrList args_spec_list; - for (size_t i = 0; i < args.size(); i++) { - auto node = GetInput(args[i], op_masks[i]); - args_spec_list.push_back(node->abstract()); - inputs.push_back(node); - } - - auto cnode = curr_g_->NewCNode(inputs); - MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4); - py::object out_real = out; - if (out.size() == 1) { - MS_LOG(DEBUG) << "MakeCnode out size is one."; - out_real = out[0]; - } - std::string obj_id = GetId(out_real); - if (py::isinstance(out_real)) { - auto value = py::cast(out_real); - if (value.size() > 1) { - for (int i = 0; i < static_cast(value.size()); i++) { - auto value_id = GetId(value[i]); - MS_LOG(DEBUG) << "MakeCnode set node id " << value_id; - set_obj_node_map(curr_g_, value_id, cnode, i); - } - } - } - MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id; - set_obj_node_map(curr_g_, obj_id, cnode); - set_pyobj(curr_g_, obj_id); - return cnode; -} - -AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { - auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)]; - if (out.second.size() == 1 && out.second[0] == -1) { - return out.first; - } - auto node = out.first; - MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString(); - for (auto &idx : out.second) { - std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)}; - node = curr_g_->NewCNode(tuple_get_item_inputs); - } - MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6); - return node; -} - -py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { - MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; - mindspore::parse::python_adapter::set_python_env_flag(true); - MsBackendPolicy backend_policy; -#if (!defined ENABLE_GE) - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->backend_policy() == "ms") { - backend_policy = kMsBackendMsPrior; - } else { - backend_policy = kMsBackendVmOnly; - } -#else - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - ms_context->PynativeInitGe(); - backend_policy = kMsBackendGeOnly; -#endif - if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { - backend_policy = kMsBackendVmOnly; - } - PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; - // returns a null py::tuple on error - py::tuple err_ret(0); - py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); - if (status != PYNATIVE_SUCCESS) { - MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name; - return err_ret; - } - - auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); - if (node != nullptr) { - node->set_abstract(op_exec_info->abstract); - MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString(); - } - MS_LOG(DEBUG) << "RunOp end"; - return result; -} - -py::tuple RunOpInner(const py::args &args) { - MS_LOG(DEBUG) << "RunOp start" << args.size(); - py::list args_input = args[PY_INPUTS]; - - OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args, &args_input); - MS_EXCEPTION_IF_NULL(op_exec_info); - - if (op_exec_info->abstract != nullptr) { - py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); - if (!output["value"].is_none()) { - py::tuple value_ret(1); - value_ret[0] = output["value"]; - return value_ret; - } - if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { - py::tuple value_ret(1); - value_ret[0] = ""; - return value_ret; - } - } - return RunOpInner(op_exec_info, args_input); -} - -py::tuple RunOp(const py::args &args) { - try { - return RunOpInner(args); - } catch (const py::error_already_set &ex) { - // print function call stack info before release - std::ostringstream oss; - trace::TraceGraphEval(); - trace::GetEvalStackInfo(oss); - // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see - // these info from screen, no need to open log file to find these info - py::print(oss.str()); - MS_LOG(ERROR) << oss.str(); - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::value_error(ex); - } catch (const py::index_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::index_error(ex); - } catch (const std::exception &ex) { - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - PynativeExecutor::GetInstance()->Clean(); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; - } -} - -void ClearPyNativeSession() { session = nullptr; } - -PynativeExecutor::~PynativeExecutor() { ClearRes(); } - -PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } - -void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { - auto cell_id = GetId(cell); - if (cell_graph_map_.count(cell_id) != 0) { - MS_LOG(DEBUG) << "Newgraph already compiled"; - return; - } - - auto g = std::make_shared(); - - if (top_g_ == nullptr) { - top_g_ = curr_g_ = g; - df_builder_ = std::make_shared(); - MS_LOG(DEBUG) << "First new graph" << top_g_.get(); - Pushp(); - } else { - Pushp(); - curr_g_ = g; - } - if (graph_info_map_.count(g) == 0) { - graph_info_map_[g] = GraphInfo(); - } - for (size_t i = 0; i < args.size(); i++) { - auto new_param = g->add_parameter(); - std::string param_obj = GetId(args[i]); - graph_info_map_[g].param_map[param_obj] = new_param; - } -} - -AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) { - ValuePtr converted_ret = nullptr; - parse::ConvertData(obj, &converted_ret); - auto node = NewValueNode(converted_ret); - set_obj_node_map(curr_g_, obj_id, node); - return node; -} - -AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) { - AnfNodePtr node = nullptr; - std::string obj_id = GetId(obj); - - if (op_mask != nullptr && py::cast(op_mask)) { - MS_LOG(DEBUG) << "Topgraph free parameter"; - // get the parameter name from parameter object - auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name"); - if (py::isinstance(name_attr)) { - MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; - } - auto param_name = py::cast(name_attr); - if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { - auto free_param = df_builder_->add_parameter(); - free_param->set_name(param_name); - auto free_param_new = std::make_shared(obj); - free_param->set_default_param(free_param_new); - free_param->debug_info()->set_name(param_name); - MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; - graph_info_map_[df_builder_].param_map[obj_id] = free_param; - return free_param; - } - return graph_info_map_[df_builder_].param_map[obj_id]; - } - - // if input is graph output - if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) { - // op(x, y) - node = graph_info_map_[curr_g_].param_map[obj_id]; - } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) { - // out = op(op1(x, y)) - // out = op(cell1(x, y)) - // out = op(cell1(x, y)[0]) - node = GetObjNode(obj); - } else if (py::isinstance(obj)) { - // out = op((x, y)) - // out = cell((x, y)) - auto tuple = obj.cast(); - - // cell((1,2)): support not mix (scalar, tensor) - if (tuple.size() > 0 && !py::isinstance(tuple[0])) { - return MakeValueNode(obj, obj_id); - } - - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - - auto tuple_size = static_cast(tuple.size()); - for (int i = 0; i < tuple_size; i++) { - args.push_back(GetInput(tuple[i], py::object())); - } - auto cnode = curr_g_->NewCNode(args); - set_obj_node_map(curr_g_, GetId(obj), cnode); - node = cnode; - } else { - node = MakeValueNode(obj, obj_id); - } - - MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id; - return node; -} - -// for output[0][1] need getitem multi -void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx) { - if (py::isinstance(obj)) { - auto tuple = obj.cast(); - for (int i = 0; i < static_cast(tuple.size()); i++) { - std::vector tmp = idx; - tmp.push_back(i); - set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, tmp); - SetTupleOutput(tuple[i], cnode, tmp); - } - } -} - -void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); } - -void PynativeExecutor::Popp() { - if (graph_p_.empty()) { - MS_LOG(EXCEPTION) << "Stack graph_p_ is empty"; - } - curr_g_ = graph_p_.top(); - graph_p_.pop(); -} - -void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { - auto cell_id = GetId(cell); - if (cell_graph_map_.count(cell_id) != 0) { - MS_LOG(DEBUG) << "Endgraph already compiled"; - return; - } - cell_graph_map_[cell_id] = curr_g_; - auto out_id = GetId(out); - if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) { - // cell construct return x, y - if (py::isinstance(out)) { - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - - auto tuple = out.cast(); - MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size(); - auto tuple_size = static_cast(tuple.size()); - auto cnode = curr_g_->NewCNode(args); - for (int i = 0; i < tuple_size; i++) { - args.push_back(GetInput(tuple[i], py::object())); - set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i); - SetTupleOutput(tuple[i], cnode, std::vector{i}); - } - cnode->set_inputs(args); - set_obj_node_map(curr_g_, out_id, cnode); - } else { - MS_LOG(ERROR) << "Graph has no this out: " << out_id; - return; - } - } - EndGraphByOutId(out_id, cell, out, args); -} - -void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, - const py::args &args) { - AnfNodePtr output_node; - if (graph_info_map_[curr_g_].param_map.count(out_id)) { - output_node = graph_info_map_[curr_g_].param_map[out_id]; - } else { - output_node = GetObjNode(out); - } - curr_g_->set_output(output_node); - std::vector inputs; - inputs.push_back(NewValueNode(curr_g_)); - MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString(); - resource_->manager()->AddFuncGraph(curr_g_); - // custom bprop debug - if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { - MS_LOG(DEBUG) << "Use cell custom bprop function."; - FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell); - if (bprop_graph != nullptr) { - (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); - (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_))); - } - } - auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_); - if (curr_g_ != top_g_) { - Popp(); - for (size_t i = 0; i < args.size(); i++) { - auto input = GetInput(args[i], py::object()); - inputs.push_back(input); - } - auto out_cnode = curr_g_->NewCNode(inputs); - set_pyobj(curr_g_, GetId(cell)); - if (py::isinstance(out)) { - auto out_list = py::cast(out); - auto out_size = static_cast(out_list.size()); - for (int i = 0; i < out_size; i++) { - set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i); - SetTupleOutput(out_list[i], out_cnode, std::vector{i}); - } - } - set_obj_node_map(curr_g_, GetId(out), out_cnode); - } else { - parse::ResolveFuncGraph(newfg, resource_); - resource_->set_func_graph(newfg); - } -} - -std::vector PynativeExecutor::GetWeightsArgs(const py::object &weights) { - std::vector w_args; - if (py::hasattr(weights, "__parameter_tuple__")) { - auto tuple = weights.cast(); - MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size(); - w_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t it = 0; it < tuple.size(); ++it) { - auto param = tuple[it]; - auto param_id = GetId(param); - AnfNodePtr para_node = nullptr; - if (graph_info_map_[df_builder_].param_map.count(param_id)) { - para_node = graph_info_map_[df_builder_].param_map[param_id]; - - AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); - AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); - auto refkey = std::make_shared(para_node->cast()->name()); - AnfNodePtr ref_key_node = NewValueNode(refkey); - AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); - - w_args.push_back(ref_node); - } - } - } else { - MS_LOG(EXCEPTION) << "training not paramter_tuple"; - } - return w_args; -} - -abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) { - abstract::AbstractBasePtrList args_spec; - std::size_t size = args.size(); - for (std::size_t i = 0; i < size; i++) { - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(args[i], &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "Args convert error"; - } - bool broaden = true; - auto abs = abstract::FromValue(converted, broaden); - args_spec.push_back(abs); - auto param_node = std::static_pointer_cast(df_builder_->parameters()[i]); - param_node->set_abstract(abs); - } - - for (const auto ¶m : df_builder_->parameters()) { - auto param_node = std::static_pointer_cast(param); - if (param_node->has_default()) { - auto param_value = std::dynamic_pointer_cast(param_node->default_param()); - AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true); - if (ptr == nullptr) { - MS_LOG(EXCEPTION) << "Args convert error"; - } - args_spec.push_back(ptr); - param_node->set_abstract(ptr); - } - } - - return args_spec; -} - -void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args) { - MS_LOG(INFO) << "GradNet start" << args.size(); - - std::size_t size = args.size(); - auto cell_id = GetId(cell); - if (graph_map_.count(cell_id) != 0) { - MS_LOG(DEBUG) << "GradNet already compiled"; - return; - } - MS_LOG(DEBUG) << "GradNet first compiled"; - std::vector new_params; - for (size_t i = 0; i < size; i++) { - ParameterPtr p = std::make_shared(df_builder_); - new_params.push_back(p); - } - MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size(); - new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); - df_builder_->set_parameters(new_params); - resource_->manager()->SetParameters(df_builder_, new_params); - - std::vector w_args = GetWeightsArgs(weights); - MS_EXCEPTION_IF_NULL(resource_->func_graph()); - auto g = GradGraph(resource_->func_graph(), grad, w_args, size); - resource_->set_func_graph(g); - resource_->manager()->KeepRoots({g}); - - // get the parameters items and add the value to args_spec - abstract::AbstractBasePtrList args_spec = GetArgsSpec(args); - MS_LOG(DEBUG) << "Args_spec size" << args_spec.size(); - - resource_->set_args_spec(args_spec); - MS_LOG(DEBUG) << "Start opt"; - - // Create backend and session - resource_->results()[pipeline::kBackend] = compile::CreateBackend(); - - graph_map_[cell_id] = g; - PynativeOptimizeAction(resource_); - TaskEmitAction(resource_); - ExecuteAction(resource_); - resource_->Clean(); - ad::CleanRes(); - pipeline::ReclaimOptimizer(); -} - -void PynativeExecutor::Clear(const std::string &flag) { - if (!flag.empty()) { - MS_LOG(INFO) << "Clear res"; - (void)graph_map_.erase(flag); - (void)cell_graph_map_.erase(flag); - Clean(); - // Maybe exit in the pynative runing op, so need reset pynative flag. - auto ms_context = MsContext::GetInstance(); - if (ms_context != nullptr) { - ms_context->set_enable_pynative_infer(false); - } - return; - } - - MS_LOG(INFO) << "Clear"; - top_g_ = nullptr; - curr_g_ = nullptr; - graph_info_map_.clear(); - std::stack().swap(graph_p_); -} - -void PynativeExecutor::Clean() { - MS_LOG(INFO) << "Clean all res"; - Clear(); - grad_flag_ = false; - df_builder_ = nullptr; - ad::CleanRes(); - pipeline::ReclaimOptimizer(); -} - -void PynativeExecutor::ClearRes() { - Clean(); - resource_.reset(); -} - -py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { - VectorRef arg_list; - pipeline::ProcessVmArgInner(args, resource_, &arg_list); - if (resource_->results().find(pipeline::kOutput) == resource_->results().end() || - !resource_->results()[pipeline::kOutput].is()) { - MS_LOG(EXCEPTION) << "Can't find run graph func for "; - } - compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast(); - if (run == nullptr) { - MS_LOG(EXCEPTION) << "Can't find run graph func for "; - } - - std::string backend = MsContext::GetInstance()->backend_policy(); - - MS_LOG(DEBUG) << "Eval run" << backend; - BaseRef value = (*run)(arg_list); - MS_LOG(DEBUG) << "Run end" << value.ToString(); - return BaseRefToPyData(value); -} - -FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, - const std::vector &weights, size_t arg_size) { - auto nparam = top_g_->parameters().size(); - std::ostringstream ss; - ss << "grad{" << nparam << "}"; - df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true); - df_builder_->debug_info()->set_name(ss.str()); - - auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights); - std::vector inputs = {NewValueNode(df)}; - for (size_t i = 0; i < arg_size; ++i) { - inputs.push_back(df_builder_->parameters()[i]); - } - auto out = df_builder_->NewCNode(inputs); - df_builder_->set_output(out); - resource_->manager()->AddFuncGraph(df); - resource_->manager()->AddFuncGraph(df_builder_); - return df_builder_; -} - -void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { - PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args); -} - -void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { - PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); -} - -void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args) { - PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); -} - -REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { - (void)py::class_>(*m, "PynativeExecutor_") - .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") - .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.") - .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.") - .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.") - .def("clear", &PynativeExecutor::Clear, "pynative clear status.") - .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""), - "Executor run function.") - .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), - "Executor set grad flag."); - })); -} // namespace pynative -} // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h deleted file mode 100644 index 83cbea88d4..0000000000 --- a/mindspore/ccsrc/pynative/pynative_execute.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ -#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "pybind11/numpy.h" - -#include "pynative/base.h" -#include "utils/context/ms_context.h" -#include "ir/anf.h" -#include "pipeline/resource.h" -#include "operator/composite/composite.h" - -namespace mindspore { -namespace pynative { - -namespace py = pybind11; -using ResourcePtr = std::shared_ptr; -using GradOperationPtr = std::shared_ptr; - -py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); - -py::tuple RunOp(const py::args &args); - -py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args, - py::list *const out_args_list); - -void ClearPyNativeSession(); - -struct GraphInfo { - std::unordered_map param_map; - std::unordered_map>> obj_node_map; - AnfNodePtr output; - std::vector objects; -}; - -class PynativeExecutor : public std::enable_shared_from_this { - public: - static std::shared_ptr GetInstance() { - std::lock_guard i_lock(instance_lock_); - if (executor_ == nullptr) { - executor_ = std::shared_ptr(new (std::nothrow) PynativeExecutor()); - resource_ = std::make_shared(); - } - return executor_; - } - void NewGraph(const py::object &cell, const py::args &args); - void NewGraphInner(const py::object &cell, const py::args &args); - void EndGraph(const py::object &cell, const py::object &out, const py::args &args); - void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); - void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args); - std::vector GetWeightsArgs(const py::object &weights); - abstract::AbstractBasePtrList GetArgsSpec(const py::args &args); - void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); - void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args); - void Clear(const std::string &flag = ""); - void Clean(); - void ClearRes(); - bool grad_flag() { return grad_flag_; } - void set_grad_flag(bool flag) { grad_flag_ = flag; } - AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask); - AnfNodePtr GetObjNode(const py::object &obj); - FuncGraphPtr curr_g() { return curr_g_; } - void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } - void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{-1}); - } - void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{index}); - } - void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); - } - AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out); - py::object Run(const py::tuple &args, const py::object &phase); - - void Pushp(); - void Popp(); - FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector &weights, - size_t arg_size); - void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx); - AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); - - ~PynativeExecutor(); - - private: - PynativeExecutor(); - static std::shared_ptr executor_; - static std::mutex instance_lock_; - static ResourcePtr resource_; - bool grad_flag_; - std::unordered_map graph_map_; - std::unordered_map cell_graph_map_; - std::unordered_map graph_info_map_; - std::stack graph_p_; - FuncGraphPtr top_g_; - FuncGraphPtr df_builder_; - FuncGraphPtr curr_g_; -}; - -using PynativeExecutorPtr = std::shared_ptr; - -} // namespace pynative -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.cc b/mindspore/ccsrc/pynative/pynative_execute_ge.cc deleted file mode 100644 index 8e10468236..0000000000 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.cc +++ /dev/null @@ -1,312 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pynative/pynative_execute_ge.h" - -#include -#include -#include -#include - -#include "utils/any.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/static_analysis/prim.h" -#include "session/session_factory.h" -#include "ir/tensor_py.h" - -const char SINGLE_OP_GRAPH[] = "single_op_graph"; - -using mindspore::tensor::TensorPy; - -namespace mindspore { -namespace pynative { -using MeTensor = mindspore::tensor::Tensor; -using MeTensorPtr = mindspore::tensor::TensorPtr; -using GeOperator = ge::Operator; -using GeOperatorPtr = std::shared_ptr; - -using transform::GraphRunner; -using transform::GraphRunnerOptions; -using transform::OperatorPtr; -static std::shared_ptr session = nullptr; -inline ValuePtr PyAttrValue(const py::object &obj) { - ValuePtr converted_ret = nullptr; - bool converted = parse::ConvertData(obj, &converted_ret); - if (!converted) { - MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); - } - return converted_ret; -} - -MeTensorPtr ConvertPyObjToTensor(const py::object &obj) { - MeTensorPtr me_tensor_ptr = nullptr; - if (py::isinstance(obj)) { - me_tensor_ptr = py::cast(obj); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::cast(obj), nullptr); - } else { - MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; - } - return me_tensor_ptr; -} - -bool SetInputsForSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, - const OperatorPtr &op, std::vector *graph_input_nodes) { - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_EXCEPTION_IF_NULL(graph_input_nodes); - auto op_inputs = op_exec_info->op_inputs; - std::string op_name = op_exec_info->op_name; - transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); - if (adapter == nullptr) { - return false; - } - - int op_input_idx = 1; - size_t size = inputs.size(); - for (size_t i = 0; i < size; i++) { - if (inputs[i] == nullptr) { - continue; - } - auto const_op = std::make_shared(); - MS_EXCEPTION_IF_NULL(const_op); - (void)const_op->set_attr_value(*inputs[i]); - MeTensorPtr me_tensor_ptr = ConvertPyObjToTensor(op_inputs[i]); - MS_EXCEPTION_IF_NULL(me_tensor_ptr); - auto const_op_desc = - transform::TransformUtil::GetGeTensorDesc(me_tensor_ptr->shape_c(), me_tensor_ptr->data_type(), kOpFormat_NCHW); - if (const_op_desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << op_name << " output descriptor failed!"; - return false; - } - auto pointer_cast_const_op = std::static_pointer_cast(const_op); - MS_EXCEPTION_IF_NULL(pointer_cast_const_op); - (void)pointer_cast_const_op->update_output_desc_y(*const_op_desc); - auto &input_map = adapter->getInputMap(); - if (input_map.find(op_input_idx) == input_map.end()) { - continue; - } - if (adapter->setInput(op, op_input_idx++, const_op)) { - MS_LOG(ERROR) << "Failed to set params, index is " << op_input_idx; - return false; - } - graph_input_nodes->push_back(*const_op); - } - return true; -} - -bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, - const std::unordered_map &attrs, const GeGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(op_exec_info); - std::string op_name = op_exec_info->op_name; - auto op_inputs = op_exec_info->op_inputs; - transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Unable to find Adapter for " << ((std::string)py::str(op_name)); - return false; - } - OperatorPtr op = adapter->generate(op_name); - MS_EXCEPTION_IF_NULL(op); - - std::vector graph_input_nodes; - // hold param nodes after setting input and output for the graph - // set input - if (!SetInputsForSingleOpGraph(op_exec_info, inputs, op, &graph_input_nodes)) { - return false; - } - // set attributes - for (auto attr : attrs) { - (void)adapter->setAttr(op, attr.first, attr.second); - } - // set default attributes - auto extra_attrs = adapter->GetExtraAttr(); - for (auto attr : extra_attrs) { - (void)adapter->setAttr(op, attr.first, attr.second); - } - // set input attributes - auto &input_attr_map = adapter->getInputAttrMap(); - for (auto &it : input_attr_map) { - if (op_inputs.size() < it.first) { - continue; - } - auto const_value = PyAttrValue(op_inputs[it.first - 1]); - if (const_value->isa()) { - continue; - } - it.second.set_attr(op, const_value); - } - // construct output data nodes - std::vector graph_outputs{*op}; - // set input and output nodes for the graph - MS_EXCEPTION_IF_NULL(graph); - (void)graph->SetInputs(graph_input_nodes).SetOutputs(graph_outputs); - MS_LOG(INFO) << "BuildSingleOpGraph done"; - return true; -} - -void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector *const inputs) { - MS_EXCEPTION_IF_NULL(inputs); - MS_EXCEPTION_IF_NULL(op_exec_info); - auto op_inputs = op_exec_info->op_inputs; - size_t size = op_inputs.size(); - for (size_t i = 0; i < size; i++) { - if (py::isinstance(op_inputs[i])) { - inputs->emplace_back(nullptr); - continue; - } - MeTensorPtr me_tensor_ptr = ConvertPyObjToTensor(op_inputs[i]); - auto ge_tensor_ptr = transform::TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW); - if (ge_tensor_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Convert inputs to GE tensor failed in op " << op_exec_info->op_name << "."; - } - // set inputs for operator to build single node graph - inputs->push_back(ge_tensor_ptr); - } -} - -PynativeStatusCode ConvertAttributes(const OpExecInfoPtr &op_exec_info, const std::vector &inputs) { - MS_EXCEPTION_IF_NULL(op_exec_info); - auto op_attrs = op_exec_info->op_attrs; - std::unordered_map attrs{}; - - for (auto &item : op_attrs) { - if (!py::isinstance(item.first)) { - MS_LOG(ERROR) << "Type error in py dict convert"; - return PYNATIVE_OP_ATTRS_ERR; - } - std::string name = py::cast(item.first); - auto attr_value = PyAttrValue(py::cast(item.second)); - (void)attrs.emplace(name, attr_value); - } - - // build graph - GeGraphPtr graph = std::make_shared(op_exec_info->op_name); - if (BuildSingleOpGraph(op_exec_info, inputs, attrs, graph) == false) { - MS_LOG(ERROR) << "Failed to BuildSingleOpGraph"; - return PYNATIVE_GRAPH_GE_BUILD_ERR; - } - - // add the single op graph into the graph manager, which will be iterated by session. - transform::Status ret = - transform::DfGraphManager::GetInstance().AddGraph(SINGLE_OP_GRAPH, std::shared_ptr(graph)); - if (ret != transform::SUCCESS) { - MS_LOG(ERROR) << "Failed to AddGraph into graph manager"; - return PYNATIVE_GRAPH_MANAGER_ERR; - } - - return PYNATIVE_SUCCESS; -} - -std::vector ConvertOutputTensors(const OpExecInfoPtr &op_exec_info, - const std::vector &ge_tensors) { - std::vector outputs; - AbstractBasePtr abs_base = op_exec_info->abstract; - std::vector> shapes; - if (abs_base != nullptr && abs_base->isa()) { - auto arg_tensor = dyn_cast(abs_base); - shapes.emplace_back(arg_tensor->shape()->shape()); - outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); - return outputs; - } - if (abs_base != nullptr && abs_base->isa()) { - auto arg_tuple = dyn_cast(abs_base); - size_t len = arg_tuple->size(); - - for (size_t i = 0; i < len; i++) { - if (arg_tuple->elements()[i]->isa()) { - auto arg_tensor = dyn_cast(arg_tuple->elements()[i]); - shapes.emplace_back(arg_tensor->shape()->shape()); - } - } - outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); - return outputs; - } - for (auto &it : ge_tensors) { - auto tensor = transform::TransformUtil::ConvertGeTensor(it); - if (tensor != nullptr) { - outputs.emplace_back(tensor); - } - } - return outputs; -} - -py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { - MS_LOG(INFO) << "RunOpInGe start"; - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_EXCEPTION_IF_NULL(status); - - // returns a null py::tuple on error - py::tuple err_ret(0); - auto op_name = op_exec_info->op_name; - transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Unable to find GE Adapter for " << ((std::string)py::str(op_name)); - *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; - return std::move(err_ret); - } - - std::vector inputs{}; - ToTensorPtr(op_exec_info, &inputs); - // convert me attr to ge AttrValue - PynativeStatusCode ret = ConvertAttributes(op_exec_info, inputs); - if (ret != PYNATIVE_SUCCESS) { - *status = ret; - return std::move(err_ret); - } - // run graph - transform::RunOptions run_options; - run_options.name = SINGLE_OP_GRAPH; - std::vector ge_inputs; - std::vector ge_outputs; - transform::GraphRunnerOptions graph_runner_options; - graph_runner_options.options["ge.trainFlag"] = "1"; - auto graph_runner = std::make_shared(graph_runner_options); - transform::Status run_ret; - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - run_ret = graph_runner->RunGraph(run_options, ge_inputs, &ge_outputs); - } - if (run_ret != transform::Status::SUCCESS) { - MS_LOG(ERROR) << "GraphRunner fails to run graph"; - *status = PYNATIVE_GRAPH_GE_RUN_ERR; - return std::move(err_ret); - } - - std::vector graph_outputs = ConvertOutputTensors(op_exec_info, ge_outputs); - size_t output_size = graph_outputs.size(); - py::tuple result(output_size); - for (size_t i = 0; i < output_size; i++) { - MS_EXCEPTION_IF_NULL(graph_outputs[i]); - result[i] = *graph_outputs[i]; - } - - *status = PYNATIVE_SUCCESS; - MS_LOG(INFO) << "RunOpInGe end"; - return std::move(result); -} -} // namespace pynative -} // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.h b/mindspore/ccsrc/pynative/pynative_execute_ge.h deleted file mode 100644 index 2dca3df018..0000000000 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ -#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ - -#include -#include -#include -#include -#include - -#include "pynative/base.h" -#include "transform/convert.h" -#include "transform/graph_runner.h" -#include "transform/types.h" -#include "utils/context/ms_context.h" - -using GeTensor = ge::Tensor; -using GeTensorPtr = std::shared_ptr; -using GeGraph = ge::Graph; -using GeGraphPtr = std::shared_ptr; - -namespace mindspore { -namespace pynative { -bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, - const std::unordered_map &attrs, const GeGraphPtr &graph); - -py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); -} // namespace pynative -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ diff --git a/mindspore/ccsrc/runtime/device/CMakeLists.txt b/mindspore/ccsrc/runtime/device/CMakeLists.txt new file mode 100644 index 0000000000..9c95aee0dc --- /dev/null +++ b/mindspore/ccsrc/runtime/device/CMakeLists.txt @@ -0,0 +1,65 @@ +file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" + "kernel_info.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" +) + +if (ENABLE_GPU) + list(APPEND DEVICE_SRC_LIST "gpu/distribution/collective_init.cc") +else () + list(APPEND DEVICE_SRC_LIST "gpu/distribution/collective_fake_init.cc") +endif () + +if (ENABLE_D) + file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc") +endif () + +if (ENABLE_CPU) + file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc") +endif () + +if (ENABLE_MPI) + # _ms_mpi + file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc") + set_property(SOURCE ${MPI_SRC_LIST} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + add_library(mpi_adapter SHARED ${MPI_SRC_LIST}) + target_link_libraries(mpi_adapter PRIVATE mindspore::ompi) + + set_property(SOURCE "gpu/mpi/mpi_initializer.cc" + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc") + target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi) +endif () + +# gpu +if (ENABLE_GPU) + file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu") + + set(GPU_QUEUE_SRCS "gpu/blocking_queue.cc" "gpu/gpu_buffer_mgr.cc") + set(GPU_COLLECTIVE_SRCS "gpu/distribution/collective_wrapper.cc" + "gpu/distribution/mpi_wrapper.cc" + "gpu/distribution/nccl_wrapper.cc") + + # gpu_queue + list(REMOVE_ITEM CUDA_SRC_LIST ${GPU_QUEUE_SRCS}) + set_property(SOURCE ${GPU_QUEUE_SRCS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + add_library(gpu_queue SHARED ${GPU_QUEUE_SRCS}) + target_link_libraries(gpu_queue ${CMAKE_THREAD_LIBS_INIT} ${CUDA_PATH}/lib64/libcudart.so) + + list(REMOVE_ITEM CUDA_SRC_LIST "gpu/mpi/mpi_initializer.cc" ${GPU_COLLECTIVE_SRCS}) + + if (ENABLE_MPI) + include(ExternalProject) + # gpu_collective + set_property(SOURCE ${GPU_COLLECTIVE_SRCS} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS}) + target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl) + endif () + + # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST}) +endif () + +set_property(SOURCE ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) +add_library(_mindspore_runtime_device_obj OBJECT ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST}) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc new file mode 100644 index 0000000000..1a87f3e6af --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -0,0 +1,408 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/ascend/ascend_device_address.h" +#include +#include +#include +#include +#include "runtime/mem.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "runtime/device/convert_tensor_utils.h" +#include "ir/dtype/type.h" +#include "ir/tensor.h" +#include "backend/kernel_compiler/common_utils.h" +#include "utils/utils.h" +#include "common/utils.h" +#include "common/trans.h" +#ifdef ENABLE_DUMP_E2E +#include "debug/e2e_dump.h" +#endif +#ifdef ENABLE_DEBUGGER +#include "debug/tensor_load.h" +#endif + +namespace mindspore { +namespace device { +namespace ascend { +const int FLOAT_LEN = sizeof(float); +const int FLOAT16_LEN = 2; // sizeof(float16); +const std::set kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, + kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, + kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; + +void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) { + auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); + if (ret_rt_memcpy != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed"; + } +} + +bool FloatToHalfAndSyncHostToDevice(void *dst, size_t dst_size, const void *src, size_t src_size) { + auto elem_num = src_size / FLOAT_LEN; + if (elem_num != (dst_size / FLOAT16_LEN)) { + MS_EXCEPTION(ArgumentError) << "FloatToHalf failed. size not match src_size[" << src_size << "], dst_size[" + << dst_size << "]"; + } + std::vector half_data(elem_num); + FloatToHalf(half_data.data(), src, elem_num); + SyncMemory(dst, half_data.data(), dst_size, RT_MEMCPY_HOST_TO_DEVICE); + return true; +} + +bool Float64ToFloatAndSyncHostToDevice(void *dst, size_t dst_size, const void *src, size_t src_size) { + if (src_size / 2 != dst_size) { + MS_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]"; + } + size_t elem_num = dst_size / sizeof(float); + auto host_tmp = std::vector(elem_num); + DoubleToFloat(host_tmp.data(), src, elem_num); + SyncMemory(dst, host_tmp.data(), dst_size, RT_MEMCPY_HOST_TO_DEVICE); + return true; +} + +bool SyncDeviceToHostAndHalfToFloat(void *dst, size_t dst_size, const void *src, size_t src_size) { + auto elem_num = src_size / FLOAT16_LEN; + if (elem_num != (dst_size / FLOAT_LEN)) { + MS_EXCEPTION(ArgumentError) << "HalfToFloat failed. size not match src_size[" << src_size << "], dst_size[" + << dst_size << "]"; + } + std::vector half_data(elem_num); + SyncMemory(half_data.data(), src, src_size, RT_MEMCPY_DEVICE_TO_HOST); + HalfToFloat(dst, half_data.data(), elem_num); + return true; +} + +bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *src, size_t src_size) { + if (src_size != dst_size / 2) { + MS_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]"; + } + size_t elem_num = src_size / sizeof(float); + auto host_tmp = std::vector(elem_num); + SyncMemory(host_tmp.data(), src, src_size, RT_MEMCPY_DEVICE_TO_HOST); + FloatToDouble(dst, host_tmp.data(), elem_num); + return true; +} + +void AscendDeviceAddress::SyncStream() const { + MS_LOG(INFO) << "Start!"; + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() != kPynativeMode) { + MS_LOG(INFO) << "Finish!"; + return; + } + auto device_id = ms_context->device_id(); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); + MS_EXCEPTION_IF_NULL(runtime_instance); + auto ret = runtime_instance->SyncStream(); + if (!ret) { + MS_LOG(EXCEPTION) << "Sync stream error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t size, mindspore::TypeId type, + void *host_ptr) const { + MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) + << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + SyncStream(); + bool sync_ok = false; + std::vector host_shape; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); + if (host_shape.empty()) { + host_shape.emplace_back(1); + } + if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { + if (type_id_ == type) { + SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); + sync_ok = true; + } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { + sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); + } else { + auto shape_size = trans::ShapeSize(host_shape); + auto host = std::vector(size_); + SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); + const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; + sync_ok = trans::TransDataType(type_args, host_ptr); + if (!sync_ok) { + MS_LOG(ERROR) << "trans data type failed."; + return false; + } + } + } else { + auto iter = kOpNeedTransFormat.find(format_); + if (iter != kOpNeedTransFormat.end()) { + sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr); + } else { + MS_LOG(INFO) << "Can not find format transfer for :" << format_; + } + } + if (!sync_ok) { + MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) + << ", host_type:" << TypeIdLabel(type); + return false; + } + return sync_ok; +} + +bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, + mindspore::TypeId type, void *host_ptr) const { + MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) + << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + bool sync_ok = false; + auto host_tmp = std::vector(size_); + SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); + std::vector host_shape; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); + std::vector device_shape; + if (host_shape.empty()) { + host_shape.emplace_back(1); + } + if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { + device_shape = trans::TransShapeToDevice(host_shape, format_); + } else { + if (host_shape_.empty()) { + host_shape = trans::PaddingShapeTo4d(host_shape); + } else { + host_shape.clear(); + (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize); + } + + device_shape = trans::TransShapeToDevice(host_shape, format_); + } + if (type_id_ != type) { + const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, + host_shape, device_shape, type_id_}; + auto host = std::vector(size_); + sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + auto shape_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; + sync_ok = trans::TransDataType(type_args, host_ptr); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + } else { + const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, + host_shape, device_shape, type_id_}; + sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + } + return sync_ok; +} + +bool AscendDeviceAddress::SyncHostToDevice(const std::vector &shape, size_t size, mindspore::TypeId type, + const void *host_ptr) const { + MS_LOG(INFO) << "SyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) + << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + SyncStream(); + bool sync_ok = false; + std::vector host_shape; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); + if (host_shape.empty()) { + host_shape.emplace_back(1); + } + if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { + if (type_id_ == type) { + SyncMemory(ptr_, host_ptr, size_, RT_MEMCPY_HOST_TO_DEVICE); + sync_ok = true; + } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { + sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); + } else { + auto shape_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; + auto host_tmp = std::vector(size_); + sync_ok = trans::TransDataType(type_args, host_tmp.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans data type failed."; + return false; + } + SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); + } + } else { + auto iter = kOpNeedTransFormat.find(format_); + if (iter != kOpNeedTransFormat.end()) { + sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); + } else { + MS_LOG(INFO) << "Can not find format transfer for :" << format_; + } + } + if (!sync_ok) { + MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) + << ", host_type:" << TypeIdLabel(type); + return false; + } + return sync_ok; +} + +bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, + mindspore::TypeId type, const void *host_ptr) const { + bool sync_ok = false; + MS_LOG(INFO) << "ConvertFormatAndSyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) + << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + std::vector host_shape; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); + if (host_shape.empty()) { + host_shape.emplace_back(1); + } + std::vector device_shape; + if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { + device_shape = trans::TransShapeToDevice(host_shape, format_); + } else { + host_shape = trans::PaddingShapeTo4d(host_shape); + device_shape = trans::TransShapeToDevice(host_shape, format_); + } + if (type_id_ != type) { + auto shape_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; + auto host_tmp = std::vector(size_); + sync_ok = trans::TransDataType(type_args, host_tmp.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans datatype failed."; + return false; + } + const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, + host_shape, device_shape, type_id_}; + auto dst_tmp = std::vector(size_); + sync_ok = trans::TransFormat(format_args, dst_tmp.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); + } else { + const trans::FormatArgs format_args{host_ptr, size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; + auto host_tmp = std::vector(size_); + sync_ok = trans::TransFormat(format_args, host_tmp.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); + } + return sync_ok; +} + +void AscendDeviceAddress::UpdateCommunicationAddress() { + MS_EXCEPTION_IF_NULL(ptr_); + communication_ptr_ = reinterpret_cast(ptr_) - kMemAlignSize; +} + +AscendDeviceAddress::~AscendDeviceAddress() { + if (ptr_ == nullptr) { + return; + } + if (from_mem_pool_) { + if (communication_ptr_ != nullptr) { + AscendMemoryPool::GetInstance().FreeTensorMem(communication_ptr_); + communication_ptr_ = nullptr; + } else { + AscendMemoryPool::GetInstance().FreeTensorMem(ptr_); + } + ptr_ = nullptr; + } +} + +#ifdef ENABLE_DUMP_E2E +bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &filepath, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type) const { + bool ret = false; + if (filepath.empty()) { + MS_LOG(ERROR) << "Dump file path is null!"; + return ret; + } + std::string shape = "shape"; + if (host_shape.size()) { + for (auto &value : host_shape) { + shape = shape + '_' + std::to_string(value); + } + } else { + shape = shape + "_0"; + } + std::string file_extension = ".bin"; + if (trans_flag) { + std::string path = filepath + '_' + shape + '_' + TypeIdLabel(host_type) + '_' + host_fmt + file_extension; + MS_LOG(INFO) << "E2E Dump path is " << path; + mindspore::tensor::TensorPtr out_tensor = std::make_shared(host_type, host_shape); + size_t host_size = out_tensor->data().nbytes(); + ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); + if (!ret) { + MS_LOG(ERROR) << "Copy device mem to host failed"; + return ret; + } + ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(), host_size); + } else { + auto host_tmp = std::vector(size_); + auto ret_rt_memcpy = rtMemcpy(host_tmp.data(), size_, ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); + if (ret_rt_memcpy != RT_ERROR_NONE) { + MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; + } + std::string path = + filepath + '_' + shape + '_' + TypeIdToType(type_id_)->ToString() + '_' + format_ + file_extension; + MS_LOG(INFO) << "E2E Dump path is " << path; + ret = mindspore::Dump::DumpToFile(path, host_tmp.data(), size_); + } + + return ret; +} +#endif + +#ifdef ENABLE_DEBUGGER +bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order, + const std::string &host_fmt, const std::vector &host_shape, + TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const { + bool ret = false; + DebugServices *debug_services = debugger->debug_services(); + TensorLoader *tensor_loader = debug_services->tensor_loader(); + // TensorData is freed up in AscendSession class + auto tensor_data = std::make_shared(); + tensor_data->SetName(tensor_name); + tensor_data->SetExecutionOrder(execution_order); + tensor_data->SetSlot(slot); + if (trans_flag) { + MS_LOG(INFO) << "E2E tensor name is " << tensor_name; + mindspore::tensor::TensorPtr out_tensor = std::make_shared(host_type, host_shape); + size_t host_size = out_tensor->data().nbytes(); + ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); + if (!ret) { + MS_LOG(ERROR) << "Copy device mem to host failed"; + return ret; + } + tensor_data->SetTensor(out_tensor); + } else { + mindspore::tensor::TensorPtr out_tensor = std::make_shared(type_id_, host_shape); + size_t host_size = out_tensor->data().nbytes(); + auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST); + if (ret_rt_memcpy != RT_ERROR_NONE) { + MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; + } + MS_LOG(INFO) << "E2E tensor name is " << tensor_name; + tensor_data->SetTensor(out_tensor); + } + ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev); + return ret; +} +#endif +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h new file mode 100644 index 0000000000..78d7006b56 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ + +#include +#include +#include +#include "runtime/device/device_address.h" +#include "runtime/device/ascend/ascend_memory_pool.h" +#include "ir/dtype.h" + +namespace mindspore { +#ifdef ENABLE_DEBUGGER +class Debugger; +#endif +namespace device { +namespace ascend { +class AscendDeviceAddress : public DeviceAddress { + public: + explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} + explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id) + : DeviceAddress(ptr, size, format, type_id) {} + ~AscendDeviceAddress() override; + bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; + bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; + DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } + void UpdateCommunicationAddress() override; +#ifdef ENABLE_DUMP_E2E + bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type) const; +#endif +#ifdef ENABLE_DEBUGGER + bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type, size_t slot, Debugger *debugger, + bool keep_prev) const; +#endif + + private: + bool SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const; + bool ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, TypeId type, + const void *host_ptr) const; + void SyncStream() const; + uint8_t *communication_ptr_{nullptr}; +}; +using AscendDeviceAddressPtr = std::shared_ptr; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc new file mode 100644 index 0000000000..3ab3a52d42 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -0,0 +1,723 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#define PATH_MAX 0x3ffff +#include "runtime/device/ascend/ascend_kernel_runtime.h" +#include +#include +#include +#include +#include +#include +#include "runtime/device/ascend/ascend_device_address.h" +#include "runtime/device/cpu/mpi/mpi_adapter.h" +#include "utils/context/ms_context.h" +#include "utils/mpi/mpi_config.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" +#include "hccl/hcom.h" +#include "common/trans.h" +#include "runtime/context.h" +#include "runtime/device/ascend/ascend_label_assign.h" +#include "runtime/device/ascend/ascend_stream_assign.h" +#include "runtime/device/ascend/ascend_memory_pool.h" +#include "framework/ge_runtime/model_runner.h" +#include "runtime/device/ascend/tasksink/task_generator.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/optimizer/mem_reuse/mem_reuse_checker.h" +#include "runtime/device/ascend/ascend_memory_manager.h" +#include "debug/tensor_load.h" + +using ge::model_runner::ModelRunner; +using mindspore::device::ascend::ProfilingManager; +using mindspore::device::ascend::ProfilingUtils; +using mindspore::device::ascend::tasksink::TaskGenerator; +using mindspore::kernel::tbe::TbeUtils; +using std::vector; + +namespace mindspore { +namespace device { +namespace ascend { +static const size_t PRAMATER_OUTPUT_INDEX = 0; +namespace { +std::string GetRankId() { + std::string rank_id_str; +#ifdef ENABLE_MPI + auto mpi_config_ptr = MpiConfig::GetInstance(); + MS_EXCEPTION_IF_NULL(mpi_config_ptr); + if (mpi_config_ptr->enable_mpi()) { + auto mpi_instance = device::cpu::MPIAdapter::Instance(); + MS_EXCEPTION_IF_NULL(mpi_instance); + int rank_id = mpi_instance->GetRankId(); + const char *offset = std::getenv("RANK_OFFSET"); + if (offset != nullptr) { + try { + int rank_offset = std::stoi(offset); + rank_id += rank_offset; + } catch (std::invalid_argument) { + MS_LOG(EXCEPTION) << "Call stoi invalid argument:" << offset; + } catch (std::out_of_range) { + MS_LOG(EXCEPTION) << "Call stoi out_of_range:" << offset; + } + } + rank_id_str = std::to_string(rank_id); + } else { + rank_id_str = std::getenv("RANK_ID"); + } +#else + rank_id_str = std::getenv("RANK_ID"); +#endif + if (rank_id_str.empty()) { + MS_LOG(ERROR) << "Get hccl rankid failed, please set env RANK_ID"; + } + return rank_id_str; +} +} // namespace + +AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } + +void AscendKernelRuntime::ClearGraphModelMap() { +#ifdef ENABLE_DATA_DUMP + for (auto &iter : graph_data_dumper_) { + MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first; + iter.second->UnloadDumpInfo(); + } + graph_data_dumper_.clear(); +#endif + for (auto &iter : graph_model_map_) { + MS_LOG(INFO) << "Ge UnloadModel " << iter.first; + auto ret = ModelRunner::Instance().UnloadModel(iter.first); + if (!ret) { + MS_LOG(ERROR) << "UnloadModel failed"; + } + } +} + +void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { + MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; + auto iter = graph_model_map_.find(graph_id); + if (iter == graph_model_map_.end()) { + MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; + return; + } + MS_LOG(DEBUG) << "Ge UnloadModel " << iter->first; + auto ret = ModelRunner::Instance().UnloadModel(iter->first); + if (!ret) { + MS_LOG(ERROR) << "UnloadModel failed"; + } + graph_model_map_.erase(iter); +} + +bool AscendKernelRuntime::NeedDestroyHccl() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!context_ptr->enable_hccl()) { + MS_LOG(INFO) << "Hccl is not enabled"; + return false; + } + // Note: make sure hcom_connectivity_detection api never be used. + return true; +} + +void AscendKernelRuntime::ReleaseDeviceRes() { + MS_LOG(INFO) << "Ascend finalize start"; + // release ge runtime + ClearGraphModelMap(); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + auto ret = rtSetDevice(context_ptr->device_id()); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast(ret) << "]"; + } + + if (mem_manager_ != nullptr) { + mem_manager_->FreeDeviceMemory(); + } + + (void)DestroyHccl(); + (void)ResetDevice(); + (void)ProfilingManager::GetInstance().StopProfiling(); + MS_LOG(INFO) << "Ascend finalize end"; +} + +bool AscendKernelRuntime::Init() { + if (initialized_) { + return true; + } + bool ret = false; +#ifdef ENABLE_DUMP_E2E + ret = SetDumpConf(); + if (!ret) { + MS_LOG(INFO) << "No dump conf to set!"; + } +#endif + +#ifdef ENABLE_DATA_DUMP + DataDumpParser::GetInstance().ParseDumpConfig(); +#endif + + // Start up profiling before rtSetDevice + ret = ProfilingManager::GetInstance().StartupProfiling(device_id_); + if (!ret) { + MS_EXCEPTION(DeviceProcessError) << "StartupProfiling failed."; + } + + ret = InitDevice(); + if (!ret) { + return ret; + } + mem_manager_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(mem_manager_); + mem_manager_->MallocDeviceMemory(); + + initialized_ = true; + return ret; +} + +#ifdef ENABLE_DUMP_E2E +namespace { +void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(dump_conf); + bool trans_flag = dump_conf->trans_flag(); + const auto &apply_kernels = graph->execution_order(); + for (const auto &node : apply_kernels) { + MS_EXCEPTION_IF_NULL(node); + auto node_name = AnfAlgo::GetCNodeName(node); + std::string kernel_name = node->fullname_with_scope(); + if (!dump_conf->IsKernelNeedDump(kernel_name)) { + continue; + } + const std::string strsrc = "/"; + const std::string strdst = "--"; + std::string::size_type pos = 0; + std::string::size_type srclen = strsrc.size(); + std::string::size_type dstlen = strdst.size(); + while ((pos = kernel_name.find(strsrc, pos)) != std::string::npos) { + kernel_name.replace(pos, srclen, strdst); + pos += dstlen; + } + auto output_size = AnfAlgo::GetOutputTensorNum(node); + for (size_t j = 0; j < output_size; ++j) { + auto addr = AnfAlgo::GetOutputAddr(node, j); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(node, j); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(node, j); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + auto type = AnfAlgo::GetOutputInferDataType(node, j); + auto format = kOpFormat_DEFAULT; + string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j); + auto ascend_addr = dynamic_cast(addr); + auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type); + if (!ret) { + MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath + << ", host_format:" << format << ".!"; + } + } + } +} + +void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(dump_conf); + bool trans_flag = dump_conf->trans_flag(); + const auto ¶meters = graph->inputs(); + for (auto &item : parameters) { + if (!item->isa()) { + continue; + } + std::string parameter_name = item->fullname_with_scope(); + if (!dump_conf->IsKernelNeedDump(parameter_name)) { + continue; + } + auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); + auto format = kOpFormat_DEFAULT; + string filepath = dump_path + '/' + parameter_name + '_' + "output_0"; + auto ascend_addr = dynamic_cast(addr); + auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type); + if (!ret) { + MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath + << ", host_format:" << format << ".!"; + } + } +} +} // namespace +#endif + +bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); +#ifdef ENABLE_DUMP_E2E + MS_LOG(INFO) << "Start dump step"; + DumpConfPtr dump_conf = GetDumpConf(); + MS_EXCEPTION_IF_NULL(dump_conf); + dump_conf->UpdataCurIter(); + bool dump_flag = dump_conf->dump_enable(); + if (!dump_flag) { + MS_LOG(INFO) << "Dump flag is disable, pass dump step"; + return true; + } + uint32_t cur_iter = dump_conf->cur_iter(); + if (dump_conf->dump_iter() != 0) { + if (cur_iter != dump_conf->dump_iter()) { + return true; + } + } + MS_LOG(INFO) << "Cur iter is " << cur_iter; + std::string net_name = dump_conf->dump_net_name(); + std::string iterator = to_string(cur_iter); + std::string dump_path = dump_conf->dump_path(); + if (dump_path.back() == '/') { + dump_path = dump_path + net_name + '/' + iterator; + } else { + dump_path = dump_path + '/' + net_name + '/' + iterator; + } + // dump output + DumpOutput(graph, dump_path, dump_conf); + // dump parameters + DumpParameters(graph, dump_path, dump_conf); +#endif + return true; +} + +#ifdef ENABLE_DEBUGGER +namespace { +void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); + // trans_flag: "true" means tensor values will be transfered to host format, otherwise not. + bool trans_flag = false; + const auto &apply_kernels = graph->execution_order(); + // for kernels, execution order starts from 1 + int exec_order = 1; + auto debugger_ = mindspore::Debugger::GetInstance(); + DebugServices *debug_services = debugger_->debug_services(); + auto watchpoint_table = debug_services->GetWatchpointTable(); + for (const auto &node : apply_kernels) { + MS_EXCEPTION_IF_NULL(node); + auto node_name = AnfAlgo::GetCNodeName(node); + std::string kernel_name = node->fullname_with_scope(); + auto output_size = AnfAlgo::GetOutputTensorNum(node); + if (debugger_->partial_memory()) { + if (!debug_services->IsWatchPoint(kernel_name, watchpoint_table)) { + continue; + } + } + for (size_t j = 0; j < output_size; ++j) { + auto addr = AnfAlgo::GetOutputAddr(node, j); + auto type = AnfAlgo::GetOutputInferDataType(node, j); + auto format = kOpFormat_DEFAULT; + string tensor_name = kernel_name + ':' + std::to_string(j); + auto ascend_addr = dynamic_cast(addr); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(node, j); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(node, j); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + auto ret = + ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger, false); + if (!ret) { + MS_LOG(ERROR) << "LoadMemToHost: flag:" << trans_flag << ", tensor_name:" << tensor_name + << ", host_format:" << format << ".!"; + } + } + exec_order = exec_order + 1; + } +} + +void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); + // trans_flag: "true" means tensor values will be transfered to host format, otherwise not. + bool trans_flag = false; + const auto ¶meters = graph->inputs(); + // for parameters, set its execution order to be 0; + int exec_order = 0; + for (auto &item : parameters) { + if (!item->isa()) { + continue; + } + std::string parameter_name = item->fullname_with_scope(); + auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); + auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); + auto format = kOpFormat_DEFAULT; + string tensor_name = parameter_name + ':' + "0"; + auto ascend_addr = dynamic_cast(addr); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + auto ret = + ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger, true); + if (!ret) { + MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", path:" << tensor_name + << ", host_format:" << format << ".!"; + } + } +} +} // namespace +#endif + +bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); +#ifdef ENABLE_DEBUGGER + MS_LOG(INFO) << "Start load step"; + uint32_t cur_iter = 0; + MS_LOG(INFO) << "Cur iter is " << cur_iter; + // load output + LoadOutput(graph, debugger); + // load parameters + LoadParameters(graph, debugger); +#endif + return true; +} + +bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { + if (AnfAlgo::OutputAddrExist(kernel, index)) { + auto address = AnfAlgo::GetOutputAddr(kernel, index); + MS_EXCEPTION_IF_NULL(address); + return address->DeviceType() == DeviceAddressType::kAscend; + } + return false; +} + +DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) { + return std::make_shared(device_ptr, device_size, format, type_id); +} + +bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { + if (graph == nullptr) { + MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; + } + MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_task_sink = context_ptr->enable_task_sink(); + if (!is_task_sink) { + return true; + } +#ifdef MEM_REUSE_DEBUG + if (!context_ptr->enable_mem_reuse()) { + // Get normal graph ir for memreuse + mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); + } +#endif + vector> task_info_list; + auto anf_node_list = graph->execution_order(); + TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id()); + // Store the task_info_list + auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list)); + if (!insert_ret.second) { + MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; + } + // Graph may have no compute node, such TensorAddGrad. + if (task_info_list.empty()) { + MS_LOG(WARNING) << "Graph " << graph->graph_id() << " have no compute node"; + return true; + } + AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); + // the streams' flag not HEAD_STREAM + std::vector wait_active_stream_list; + assign_instance.GetWaitStreams(&wait_active_stream_list); + std::vector force_copy_stream_list; + assign_instance.GetHcomStreams(&force_copy_stream_list); + MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() + << ", total event num:" << resource_manager.get_cur_event_num() + << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) + << ", wait_active_stream_list size:" << wait_active_stream_list.size() + << ", force_copy_stream_list size:" << force_copy_stream_list.size(); + std::vector> empty_list; + auto model = std::make_shared( + task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, + 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), + resource_manager.get_cur_event_num(), 0); + auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); + if (!ret.second) { + MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; + } + MS_LOG(INFO) << "TaskGenerator GetTaskInfo end..."; + return true; +} + +bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { + if (graph == nullptr) { + MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; + } + MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_task_sink = context_ptr->enable_task_sink(); + if (!is_task_sink) { + return true; + } + + if (GraphWithEmptyTaskList(graph)) { + MS_LOG(WARNING) << "LoadTask end, task list is empty"; + return true; + } + + auto model_iter = graph_model_map_.find(graph->graph_id()); + if (model_iter == graph_model_map_.end()) { + MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph LoadTask without GenTask."; + return false; + } + + std::shared_ptr listener; + MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first; + bool status = + ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second, listener); + if (!status) { + MS_LOG(EXCEPTION) << "Load Task Failed"; + } + if (ProfilingManager::GetInstance().IsProfiling()) { + auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first); + auto stream_ids = ModelRunner::Instance().GetStreamIdList(model_iter->first); + ProfilingUtils::ReportProfilingData(task_ids, stream_ids, NOT_NULL(graph)); + } + +#ifdef ENABLE_DATA_DUMP + LaunchDataDump(NOT_NULL(graph)); +#endif + if (!ModelRunner::Instance().LoadModelComplete(model_iter->first)) { + MS_LOG(ERROR) << "Call ge runtime LoadModelComplete failed"; + return false; + } + return true; +} + +#ifdef ENABLE_DATA_DUMP +void AscendKernelRuntime::LaunchDataDump(NotNull graph) { + if (!DataDumpParser::GetInstance().DumpEnabled()) { + return; + } + auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(graph->graph_id()); + auto data_dumper = std::make_shared(graph.get(), runtime_info_map); + MS_EXCEPTION_IF_NULL(data_dumper); + data_dumper->LoadDumpInfo(); + auto ret = graph_data_dumper_.try_emplace(graph->graph_id(), data_dumper); + if (!ret.second) { + MS_LOG(WARNING) << "[DataDump] Insert graphId:" << graph->graph_id() << " data dumper failed"; + } +} +#endif + +void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) { + auto task_ids = ModelRunner::Instance().GetTaskIdList(graph_id); + auto graph_task_names = ProfilingUtils::graph_kernel_name(); + auto iter = graph_task_names.find(graph_id); + if (iter != graph_task_names.end()) { + const auto &task_names = iter->second; + if (task_ids.size() != task_names.size()) { + MS_LOG(WARNING) << "Task_ids and task_names size not match"; + return; + } + for (size_t i = 0; i < task_ids.size(); ++i) { + MS_LOG(INFO) << "Task_id:" << task_ids[i] << " task_name:" << task_names[i]; + } + } +} + +bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id(); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + ge::InputData input_tensors = ge::InputData(); + ge::OutputData *output_tensors = nullptr; + if (GraphWithEmptyTaskList(graph)) { + MS_LOG(WARNING) << "RunTask end, no task info found"; + return true; + } + + if (!CheckGraphIdValid(graph->graph_id())) { + MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph RunTask without GenTask."; + return false; + } + + bool status = ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors); + if (!status) { + MS_LOG(ERROR) << "Run task failed"; + DebugTaskIdName(graph->graph_id()); + return false; + } + return true; +} + +bool AscendKernelRuntime::SyncStream() { + if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { // o for switch stream + MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; + return false; + } + return true; +} + +bool AscendKernelRuntime::InitDevice() { + int device_count = 0; + auto ret = rtGetDeviceCount(&device_count); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast(ret) << "]"; + } + + ret = rtSetDevice(device_id_); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast(ret) << "]"; + } + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr == nullptr) { + MS_LOG(ERROR) << "Get MsContext instance failed"; + return false; + } + if (context_ptr->enable_hccl()) { + if (!HcclInit()) { + MS_LOG(ERROR) << "HcclInit init failed"; + return false; + } + } + + ret = rtCtxCreate(&rt_context_, 0, device_id_); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast(ret) << "]"; + } + + ret = rtCtxSetCurrent(rt_context_); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; + } + + ret = rtStreamCreate(&stream_, 0); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]"; + } + + return true; +} + +bool AscendKernelRuntime::ResetDevice() { + auto ret = rtCtxSetCurrent(rt_context_); + if (ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Call rtCtxSetCurrent failed"; + return false; + } + + if (stream_ != nullptr) { + ret = rtStreamDestroy(stream_); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]"; + } + stream_ = nullptr; + } + + if (rt_context_ != nullptr) { + ret = rtCtxDestroy(rt_context_); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]"; + } + rt_context_ = nullptr; + } + return true; +} + +bool AscendKernelRuntime::HcclInit() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!context_ptr->IsTsdOpened()) { + MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open"; + } + MS_LOG(INFO) << "Do hcom init"; + auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH"); + if (config_path_str == nullptr) { + config_path_str = std::getenv("RANK_TABLE_FILE"); + if (config_path_str == nullptr) { + MS_LOG(ERROR) << "Get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE"; + return false; + } + } + if (strlen(config_path_str) > PATH_MAX) { + MS_LOG(ERROR) << "File path oversize"; + return false; + } + std::string rank_id_str = GetRankId(); + auto full_path = realpath(config_path_str, nullptr); + if (full_path == nullptr) { + MS_LOG(ERROR) << "File path " << config_path_str << " does not exist"; + return false; + } + MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; + hcclResult_t res = hcom_init(full_path, rank_id_str.c_str()); + free(full_path); + if (res != HCCL_SUCCESS) { + MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast(res); + return false; + } + return true; +} + +bool AscendKernelRuntime::DestroyHccl() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!NeedDestroyHccl()) { + MS_LOG(INFO) << "Hccl is not enable, no need to close."; + return true; + } + hcclResult_t res = hcom_destroy(); + if (res != HCCL_SUCCESS) { + MS_LOG(ERROR) << "Hccl destroy failed"; + return false; + } + MS_LOG(INFO) << "Hccl destroy successful, status = " << res << "."; + context_ptr->set_enable_hccl(false); + return true; +} + +bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const { + auto iter = task_map_.find(graph->graph_id()); + if (iter == task_map_.end()) { + MS_LOG(EXCEPTION) << "Unknown graph ptr"; + } + return iter->second.empty(); +} + +bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const { + return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end(); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h new file mode 100644 index 0000000000..4f1663d4d5 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -0,0 +1,83 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ +#include +#include +#include +#include +#include "runtime/device/kernel_runtime.h" +#include "runtime/context.h" +#include "framework/ge_runtime/davinci_model.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "backend/session/session_basic.h" +#ifdef ENABLE_DATA_DUMP +#include "debug/data_dump_parser.h" +#include "runtime/device/ascend/dump/data_dumper.h" +#endif + +using ge::model_runner::TaskInfo; +using std::unordered_map; +using std::vector; +namespace mindspore { +namespace device { +namespace ascend { +class AscendKernelRuntime : public KernelRuntime { + public: + AscendKernelRuntime() = default; + ~AscendKernelRuntime() override; + bool Init() override; + bool DumpData(session::KernelGraph *graph) override; + bool LoadData(session::KernelGraph *graph, Debugger *debugger) override; + bool GenTask(const session::KernelGraph *graph) override; + bool RunTask(const session::KernelGraph *graph) override; + bool LoadTask(const session::KernelGraph *graph) override; + void ClearGraphRuntimeResource(uint32_t graph_id) override; + bool SyncStream() override; + + protected: + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) override; + bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override; + + private: + bool InitDevice(); + bool ResetDevice(); + bool HcclInit(); + bool NeedDestroyHccl(); + bool DestroyHccl(); + + void ClearGraphModelMap(); + void ReleaseDeviceRes() override; + bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const; + bool CheckGraphIdValid(GraphId graph_id) const; + static void DebugTaskIdName(GraphId graph_id); + + rtContext_t rt_context_{nullptr}; + bool initialized_{false}; + unordered_map>> task_map_; + unordered_map> graph_model_map_; +#ifdef ENABLE_DATA_DUMP + void LaunchDataDump(NotNull graph); + unordered_map> graph_data_dumper_; +#endif +}; + +MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime); +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc new file mode 100644 index 0000000000..035f4dd8e3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "runtime/device/ascend/ascend_label_assign.h" +#include "backend/session/anf_runtime_algorithm.h" + +static constexpr uint32_t kLabelGotoLabelId = 1; +static constexpr uint32_t kLabelSwitchLabelId = 2; + +namespace mindspore { +namespace device { +namespace ascend { +static void UpdateLabelGoto(NotNull node) { + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { + return; + } + if (node->size() <= kLabelGotoLabelId) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); + } + + auto input = node->input(kLabelGotoLabelId); + uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); + AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(goto_label_id), node.get()); + MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; + node->set_inputs({node->input(0)}); +} + +static void UpdateLabelSwitch(NotNull node) { + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { + return; + } + if (node->size() <= kLabelGotoLabelId) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); + } + std::vector label_list; + for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) { + auto input = node->input(i); + if (!input->isa() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) { + break; + } + + uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); + label_list.push_back(goto_label_id); + MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; + } + AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue>(label_list), node.get()); + node->set_inputs({node->input(kAnfPrimitiveIndex), node->input(kFirstDataInputIndex)}); +} + +static void AssignLabelForLabelSet(NotNull> graph, NotNull label_id, + NotNull> *> memo) { + if (memo->find(graph.get()) != memo->end()) { + return; + } + memo->insert(graph.get()); + + MS_LOG(INFO) << "Assign label for " << graph->ToString(); + graph->SetExecOrderByDefault(); + auto nodes = graph->execution_order(); + + for (auto &node : nodes) { + if (!node->isa()) { + continue; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::string node_name = AnfAlgo::GetCNodeName(node); + if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { + AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(*label_id), node); + MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id; + ++(*label_id); + } + } + + for (auto &cg : graph->child_graph_order()) { + AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo); + } +} + +static void AssignLabelForGotoSwitch(NotNull> graph, + NotNull> *> memo) { + if (memo->find(graph.get()) != memo->end()) { + return; + } + memo->insert(graph.get()); + + MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); + + auto nodes = graph->execution_order(); + auto end_goto = graph->get_end_goto(); + if (end_goto != nullptr) { + nodes.push_back(end_goto); + } + for (auto &node : nodes) { + if (!node->isa()) { + continue; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + std::string node_name = AnfAlgo::GetCNodeName(node); + if (node_name == kLabelGotoOpName) { + UpdateLabelGoto(NOT_NULL(cnode)); + cnode->set_abstract(nullptr); + } + + if (node_name == kLabelSwitchOpName) { + UpdateLabelSwitch(NOT_NULL(cnode)); + } + } + for (auto &cg : graph->child_graph_order()) { + AssignLabelForGotoSwitch(NOT_NULL(cg), memo); + } + graph->SetExecOrderByDefault(); +} + +void AscendLabelAssign::AssignLabel(NotNull> graph) { + MS_LOG(INFO) << "Assign label start."; + std::set> memo; + uint32_t label_id = 0; + AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo)); + memo.clear(); + { + std::lock_guard lock(label_num_mutex_); + label_num_[graph.get().get()] = label_id; + } + AssignLabelForGotoSwitch(graph, NOT_NULL(&memo)); + MS_LOG(INFO) << "Assign label end."; +} + +uint32_t AscendLabelAssign::GetLabelNum(NotNull graph) { + std::lock_guard lock(label_num_mutex_); + auto iter = label_num_.find(graph.get()); + if (iter == label_num_.end()) { + MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 0."; + return 0; + } + return iter->second; +} + +uint32_t AscendLabelAssign::GetLabelNum(NotNull> graph) { + return GetLabelNum(NOT_NULL(graph.get().get())); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h new file mode 100644 index 0000000000..6b09f2940e --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ + +#include +#include +#include "backend/session/kernel_graph.h" +#include "utils/contract.h" + +namespace mindspore { +namespace device { +namespace ascend { +class AscendLabelAssign { + public: + static AscendLabelAssign &GetInstance() { + static AscendLabelAssign instance; // Guaranteed to be destroyed. + return instance; + } + + AscendLabelAssign(const AscendLabelAssign &) = delete; + AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; + + void AssignLabel(NotNull> graph); + uint32_t GetLabelNum(NotNull graph); + uint32_t GetLabelNum(NotNull> graph); + + private: + AscendLabelAssign() = default; + ~AscendLabelAssign() = default; + + std::map label_num_; + std::mutex label_num_mutex_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc new file mode 100644 index 0000000000..f9da0850c6 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc @@ -0,0 +1,137 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "runtime/device/ascend/ascend_memory_manager.h" +#include "runtime/device/ascend/ascend_memory_pool.h" +#include "utils/context/ms_context.h" +#include "runtime/mem.h" +namespace mindspore { +namespace device { +namespace ascend { +constexpr uint64_t kAscendDeviceMemGB = 30; +constexpr uint64_t kMemSizeGB = 30; +constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB); + +void AscendMemoryManager::MallocDeviceMemory() { + auto context_mem = GetDeviceMemSizeFromContext(); + device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem; + dynamic_mem_offset_ = device_mem_size_; + auto ret = rtMalloc(reinterpret_cast(&device_mem_base_), dynamic_mem_offset_, RT_MEMORY_HBM); + + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << dynamic_mem_offset_ << "] fail, ret[" << ret << "]"; + } + + AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_base_); + AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); +} + +uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + auto variable_memory_max_size = context->variable_memory_max_size(); + if (variable_memory_max_size == "0") { + return 0; + } + MS_LOG(INFO) << "context variable_memory_max_size:" << variable_memory_max_size; + auto pos = variable_memory_max_size.find('*'); + if (pos == std::string::npos) { + MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size"; + } + auto gb_str = variable_memory_max_size.substr(0, pos); + auto gb_var = std::stoull(gb_str); + MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var; + if (gb_var > kAscendDeviceMemGB || gb_var == 0) { + MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB"; + } + return gb_var << kMemSizeGB; +} + +void AscendMemoryManager::FreeDeviceMemory() { + if (device_mem_base_ != nullptr) { + auto ret = rtFree(device_mem_base_); + if (ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "rtFree mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]"; + } + device_mem_base_ = nullptr; + } + if (device_mem_pool_base_ != nullptr) { + auto ret = rtFree(device_mem_pool_base_); + if (ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "rtFree mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]"; + } + device_mem_pool_base_ = nullptr; + } +} + +void AscendMemoryManager::ResetDynamicMemory() { + total_dynamic_size_ = 0; + dynamic_mem_offset_ = device_mem_size_; + AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); +} + +void *AscendMemoryManager::MallocMemFromMemPool(size_t size) { + auto align_size = GetCommonAlignSize(size); + return AscendMemoryPool::GetInstance().AllocTensorMem(align_size); +} + +uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem) { + size_t align_size = 0; + if (communication_mem) { + align_size = GetCommunicationAlignSize(size); + } else { + align_size = GetCommonAlignSize(size); + } + if (communication_mem) { + // create protect area [kMemAlignSize -- data -- kMemAlignSize] + uint8_t *alloc_address = reinterpret_cast(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); + return alloc_address + kMemAlignSize; + } else { + return reinterpret_cast(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); + } +} + +uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { + size_t align_size = 0; + if (communication_mem) { + align_size = GetCommunicationAlignSize(size); + } else { + align_size = GetCommonAlignSize(size); + } + if (dynamic_mem_offset_ < align_size) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ + << "]) malloc [" << align_size << "] failed!"; + } + auto new_offset = dynamic_mem_offset_ - align_size; + auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); + if (new_offset <= device_mem_pool_offset) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ + << "] memory pool[" << device_mem_pool_offset << "])" + << " malloc [" << align_size << "] failed!"; + } + total_dynamic_size_ += align_size; + dynamic_mem_offset_ = new_offset; + AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); + if (communication_mem) { + // create protect area [kMemAlignSize -- data -- kMemAlignSize] + return device_mem_base_ + new_offset + kMemAlignSize; + } else { + return device_mem_base_ + new_offset; + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h new file mode 100644 index 0000000000..720f15be00 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ +#include "runtime/device/memory_manager.h" +namespace mindspore { +namespace device { +namespace ascend { +class AscendMemoryManager : public MemoryManager { + public: + AscendMemoryManager() = default; + ~AscendMemoryManager() override = default; + + void MallocDeviceMemory() override; + void FreeDeviceMemory() override; + void ResetDynamicMemory() override; + void *MallocMemFromMemPool(size_t size) override; + + protected: + uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; + uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override; + + private: + uint8_t *device_mem_pool_base_{nullptr}; + uint64_t device_mem_pool_size_{0}; + + uint64_t GetDeviceMemSizeFromContext(); +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc new file mode 100644 index 0000000000..fe71ba43fc --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/ascend_memory_pool.h" +#include "runtime/device/ascend/ascend_kernel_runtime.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +namespace ascend { +size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { + if (size == 0) { + MS_LOG(EXCEPTION) << "Can not alloc memory size(0) in memory pool !"; + } + if (device_mem_pool_offset_ + size >= graph_dynamic_mem_offset_) { + MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ [" + << device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ " << graph_dynamic_mem_offset_ + << "], need memory size [" << size << "]"; + } + *addr = device_mem_pool_base_ + device_mem_pool_offset_; + device_mem_pool_offset_ += size; + if (*addr == nullptr) { + MS_LOG(EXCEPTION) << "Alloc device address is nullptr, failed to alloc memory pool memory!"; + } + return size; +} + +bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { + MS_EXCEPTION_IF_NULL(addr); + return true; +} + +size_t AscendMemoryPool::AlignMemorySize(size_t size) const { + if (size == 0) { + MS_LOG(EXCEPTION) << "The align memory size is a zero !"; + } + return size; +} + +void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { + MS_EXCEPTION_IF_NULL(device_mem_pool_base); + device_mem_pool_base_ = device_mem_pool_base; +} + +void AscendMemoryPool::set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset) { + graph_dynamic_mem_offset_ = graph_dynamic_mem_offset; +} + +uint64_t AscendMemoryPool::device_mem_pool_offset() const { return device_mem_pool_offset_; } + +size_t AscendMemoryPool::free_mem_size() { + if (graph_dynamic_mem_offset_ < device_mem_pool_offset_) { + MS_LOG(EXCEPTION) << "graph dynamic mem offset [" << graph_dynamic_mem_offset_ + << "] less than device mem pool offset [" << device_mem_pool_offset_ << "]!"; + } + return graph_dynamic_mem_offset_ - device_mem_pool_offset_; +} + +size_t AscendMemoryPool::total_mem_size() { return graph_dynamic_mem_offset_ == 0 ? 0 : graph_dynamic_mem_offset_ - 1; } +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h new file mode 100644 index 0000000000..7a75198ab4 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ + +#include +#include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h" + +namespace mindspore { +namespace device { +namespace ascend { +class AscendMemoryPool : public DynamicMemPoolBestFit { + public: + ~AscendMemoryPool() override = default; + AscendMemoryPool(const AscendMemoryPool &) = delete; + AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; + + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; + void set_device_mem_pool_base(uint8_t *device_mem_pool_base); + void set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset); + + uint64_t device_mem_pool_offset() const; + size_t free_mem_size() override; + size_t total_mem_size() override; + + static AscendMemoryPool &GetInstance() { + static AscendMemoryPool instance; + return instance; + } + + protected: + // The real size by memory alloc aligned. + size_t AlignMemorySize(size_t size) const override; + + private: + AscendMemoryPool() = default; + uint8_t *device_mem_pool_base_{nullptr}; + uint64_t device_mem_pool_offset_{0}; + uint64_t graph_dynamic_mem_offset_{0}; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc new file mode 100644 index 0000000000..7cf5b94d45 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -0,0 +1,1268 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/ascend_stream_assign.h" + +#include +#include + +#include "ir/manager.h" +#include "utils/context/ms_context.h" +#include "common/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_adjust.h" +#include "predict/generator/utils/ir_model_util.h" +#include "backend/optimizer/common/helper.h" +#include "utils/utils.h" + +namespace mindspore { +namespace device { +namespace ascend { +const uint32_t kHcomMaxTask = 5; +const uint32_t kCommonMaxTask = 350; + +void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) { + if (IsTaskSink()) { + Reset(); + ReorderIndependentOrders(graph_ptr); + AssignAllNodesStream(graph_ptr); + UpdateAtomicAddrCleanStreamId(graph_ptr); + InsertStreamActive(graph_ptr); + InsertEventForHcomParallel(graph_ptr); + InsertEventForIndependentParallel(graph_ptr); + GetNeedActiveStreams(graph_ptr); + graph_ptr->PrintGraphExecuteOrder(); + CheckResourceAssign(graph_ptr); + MS_LOG(INFO) << "After finish stream assign"; + + FindStreamRelations(graph_ptr); + PrintStreamRelations(); + GetStreamRelations(); + PrintStreamGroups(); + FindEventRelations(graph_ptr); + + // Get info for D Model + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num()); + generator::IRModelUtil::GetInstance().set_stream_num(resource_manager.get_cur_stream_num()); + // Init to 1,temporarily + generator::IRModelUtil::GetInstance().set_batch_num(1); + } +} + +// section 1 +void AscendStreamAssign::ReorderIndependentOrders(const NotNull &graph_ptr) { + std::vector exe_orders; + std::vector independents; + std::vector others; + + auto cnode_ptr_list = graph_ptr->execution_order(); + MS_LOG(INFO) << "Before reorder, graph orders size:" << cnode_ptr_list.size(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + auto cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (IsIndependentNode(cur_cnode_ptr)) { + independents.emplace_back(cur_cnode_ptr); + } else { + others.emplace_back(cur_cnode_ptr); + } + } + + if (others.empty() || independents.empty()) { + MS_LOG(INFO) << "Independent or others is empty, no need reorder"; + return; + } + + std::set processed; + for (size_t i = 0; i < others.size(); i++) { + auto begin = others.begin() + i; + auto end = begin + 1; + bool flag = false; + for (size_t j = 0; j < independents.size(); j++) { + auto cur_independent = independents[j]; + auto it = std::find(processed.begin(), processed.end(), cur_independent.get()); + if (it != processed.end()) { + continue; + } + + auto res = FindTargetOp(begin, end, cur_independent); + if (res != end) { + flag = true; + exe_orders.emplace_back(cur_independent); + exe_orders.emplace_back(*begin); + processed.emplace(cur_independent.get()); + break; + } + } + + if (!flag) { + exe_orders.emplace_back(*begin); + } + } + + MS_LOG(INFO) << "After reorder, graph orders size:" << exe_orders.size(); + if (processed.size() != independents.size()) { + MS_LOG(WARNING) << "Processed independent nodes size is not equal to exiting independent nodes size"; + return; + } + + graph_ptr->set_execution_order(exe_orders); +} + +// section 2 +void AscendStreamAssign::AssignAllNodesStream(const NotNull &graph_ptr) { + auto cnode_ptr_list = graph_ptr->execution_order(); + bool exit_independent = false; + bool exit_hcom = false; + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + // node has been assigned stream before + if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { + continue; + } + + if (IsHcom(cur_cnode_ptr)) { + exit_hcom = true; + continue; + } + + if (IsIndependentNode(cur_cnode_ptr)) { + exit_independent = true; + continue; + } + + AssignCommonStreamId(cur_cnode_ptr); + } + MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num(); + + if (exit_hcom) { + uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + // node has been assigned stream before + if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { + continue; + } + + if (IsHcom(cur_cnode_ptr)) { + AssignHcomStreamId(cur_cnode_ptr); + } + } + MS_LOG(INFO) << "Hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size(); + } + + if (exit_independent) { + uint32_t first_independ = resource_manager.ApplyNewStream(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { + continue; + } + if (IsIndependentNode(cur_cnode_ptr)) { + AssignIndependentStreamId(cur_cnode_ptr); + } + } + MS_LOG(INFO) << "Independ start from:" << first_independ << ", stream nums:" << independent_stream_map_.size(); + } + + MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num(); +} + +void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + uint32_t cur_common_stream_id = 0; + uint32_t cur_stream_num = resource_manager.get_cur_stream_num(); + if (cur_stream_num == 0) { + cur_common_stream_id = resource_manager.ApplyNewStream(); + } else { + cur_common_stream_id = resource_manager.GetCurAllocStreamId(); + } + + auto it = common_stream_map_.find(cur_common_stream_id); + if (it == common_stream_map_.end()) { + AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get()); + common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1)); + } else { + if (it->second < kCommonMaxTask) { + AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); + it->second++; + } else { + cur_common_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get()); + common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1)); + } + } +} + +void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) { + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + uint32_t cur_hcom_stream_id = resource_manager.GetCurAllocStreamId(); + auto it = hcom_stream_map_.find(cur_hcom_stream_id); + if (it == hcom_stream_map_.end()) { + AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); + hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); + } else { + if (it->second < kHcomMaxTask) { + AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); + it->second++; + } else { + cur_hcom_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); + hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); + } + } +} + +void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr) { + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + uint32_t cur_independent_id = resource_manager.GetCurAllocStreamId(); + auto it = independent_stream_map_.find(cur_independent_id); + if (it == independent_stream_map_.end()) { + AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); + independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); + } else { + if (it->second < kCommonMaxTask) { + AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); + it->second++; + } else { + cur_independent_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); + independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); + } + } +} + +bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { + MS_EXCEPTION_IF_NULL(node_ptr); + if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) { + return false; + } + + if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { + MS_LOG(INFO) << "GetNext should not be independent node"; + return false; + } + + uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr); + if (input_nums == 0) { + MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero"; + return true; + } + + auto inputs = node_ptr->inputs(); + for (size_t i = 1; i < inputs.size(); i++) { + if (!inputs[i]->isa()) { + return false; + } + } + MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node"; + return true; +} + +// section 3: +void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr) { + MS_LOG(INFO) << "Start"; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + // update AtomicAddrClean stream same witch the next node + if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) { + AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get()); + } + } + MS_LOG(INFO) << "End"; +} + +// section 4 +void AscendStreamAssign::InsertStreamActive(const NotNull &graph_ptr) { + MS_LOG(INFO) << "Start"; + GetProcessedStream(graph_ptr); + std::vector update_cnode_list; + CNodePtr cur_cnode_ptr = nullptr; + CNodePtr pre_cnode_ptr = nullptr; + uint32_t pre_stream_id = UINT32_MAX; + + bool independent_flag = !(independent_stream_map_.empty()); + bool hcom_flag = !(hcom_stream_map_.empty()); + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (IsIndependentNode(cur_cnode_ptr)) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + + if (IsHcom(cur_cnode_ptr)) { + update_cnode_list.emplace_back(cur_cnode_ptr); + continue; + } + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + bool processed = IsProcessedStream(cur_stream_id); + // 1)inner stream assign, need insert active op + if (!processed) { + MS_LOG(INFO) << "Common stream active info:" << pre_stream_id << "->active" << cur_stream_id; + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); + // 1.set stream id + AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); + // 2.set active stream ids + std::vector active_index_list{cur_stream_id}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); + update_cnode_list.emplace_back(active_ptr); + } + + if ((independent_flag || hcom_flag) && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) { + MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; + UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list); + } else { + update_cnode_list.emplace_back(cur_cnode_ptr); + } + + processed_streams_.emplace(cur_stream_id); + pre_stream_id = cur_stream_id; + pre_cnode_ptr = cur_cnode_ptr; + } + graph_ptr->set_execution_order(update_cnode_list); + MS_LOG(INFO) << "End"; +} + +void AscendStreamAssign::GetProcessedStream(const NotNull &graph_ptr) { + // 0 stream is activated at first + processed_streams_.emplace(0); + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + auto cur_cnode_ptr = cnode_ptr_list[i]; + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { + auto true_stream_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrTrueBranchStream); + processed_streams_.emplace(true_stream_id); + + if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { + continue; + } + auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); + if (need_active) { + processed_streams_.emplace(cur_stream_id); + } + } + } + for (const auto &item : processed_streams_) { + MS_LOG(INFO) << "Before active:" << item << " is been processed"; + } +} + +void AscendStreamAssign::UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, + vector *orders) { + orders->emplace_back(switch_ptr); + if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) { + return; + } + + auto need_active = AnfAlgo::GetNodeAttr(switch_ptr, kStreamNeedActivedFirst); + if (!need_active) { + return; + } + + MS_EXCEPTION_IF_NULL(switch_ptr); + auto true_stream_id = AnfAlgo::GetNodeAttr(switch_ptr, kAttrTrueBranchStream); + MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) + << "; active stream id:" << true_stream_id; + + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); + AnfAlgo::SetStreamId(true_stream_id, active_ptr.get()); + vector active_ids; + // active indepdent stream + for (const auto &item : independent_stream_map_) { + active_ids.emplace_back(item.first); + } + // active hcom stream + for (const auto &item : hcom_stream_map_) { + active_ids.emplace_back(item.first); + } + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_ids), active_ptr); + + // update processed stream + independent_stream_activated_ = true; + for (const auto &item : independent_stream_map_) { + processed_streams_.emplace(item.first); + } + + hcom_stream_activated_ = true; + for (const auto &item : hcom_stream_map_) { + processed_streams_.emplace(item.first); + } + + orders->emplace_back(active_ptr); +} + +bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { + auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id); + if (it != processed_streams_.end()) { + return true; + } + return false; +} + +// section5 +void AscendStreamAssign::InsertEventForHcomParallel(const NotNull &graph_ptr) { + MS_LOG(INFO) << "Start"; + InsertEventCommonDependHcom(graph_ptr); + InsertEventHcomDependCommon(graph_ptr); + InsertEventHcomDependHcom(graph_ptr); + MS_LOG(INFO) << "End"; +} + +void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes = cnode_ptr_list; + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto it = cnodes.begin(); + while (it != cnodes.end() && (it + 1) != cnodes.end()) { + MS_EXCEPTION_IF_NULL(*it); + MS_EXCEPTION_IF_NULL(*(it + 1)); + if (IsHcom(*it) && !IsHcom(*(it + 1))) { + CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); + it = cnodes.insert(it + 1, send_cnode_ptr); + + auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); + if (target == cnodes.end()) { + MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope() + << ", can't find target for insert recv op, no insert send/recv"; + it = cnodes.erase(it); + continue; + } + + if (IsHcom(*target)) { + it = cnodes.erase(it); + continue; + } + + // deal recv op + uint32_t stream_id = AnfAlgo::GetStreamId(*target); + CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); + (void)cnodes.insert(target, recv_cnode_ptr); + cur_event_id = resource_manager.ApplyNewEvent(); + } + ++it; + } + // one event allocated additional, should delete + resource_manager.DeleteEvent(); + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num(); +} + +void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes; + CNodePtr cur_cnode_ptr = nullptr; + uint32_t pre_stream_id = UINT32_MAX; + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (i == 0) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; + continue; + } + + if (!IsHcom(cur_cnode_ptr)) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; + continue; + } + + if (cur_stream_id == pre_stream_id) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; + continue; + } + + if (!IsHcom(cnode_ptr_list[i - 1])) { + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id); + cnodes.emplace_back(send); + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); + cnodes.emplace_back(recv); + cnodes.emplace_back(cur_cnode_ptr); + } else { + cnodes.emplace_back(cur_cnode_ptr); + } + pre_stream_id = cur_stream_id; + } + + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); +} + +void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + uint32_t first_hcom_stream = kInvalidStreamId; + uint32_t last_hcom_stream = kInvalidStreamId; + // key: stream id, value:hcom index + std::map> hcom_index; + for (size_t i = 0; i < cnode_ptr_list.size(); i++) { + auto cur_cnode = cnode_ptr_list[i]; + if (!IsHcom(cur_cnode)) { + continue; + } + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + auto it = hcom_index.find(cur_stream_id); + if (it != hcom_index.end()) { + hcom_index[cur_stream_id].emplace_back(i); + } else { + hcom_index[cur_stream_id] = {i}; + } + + // record first hcom stream id + if (first_hcom_stream == kInvalidStreamId) { + first_hcom_stream = cur_stream_id; + } + + // record last hcom stream id + if (cur_stream_id != last_hcom_stream) { + last_hcom_stream = cur_stream_id; + } + } + + if (hcom_index.size() < 2) { + MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; + return; + } + InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream); + MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); +} + +void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &graph_ptr, + const map> &hcom_index, + uint32_t first_hcom_stream, uint32_t last_hcom_stream) { + vector orders; + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back(); + size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front(); + std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders)); + for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) { + auto cur_cnode = cnode_ptr_list[i]; + if (!IsSatisfiedHcom(hcom_index, cur_cnode, i)) { + orders.emplace_back(cur_cnode); + continue; + } + auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode); + if (i == first_stream_last_index) { + // first fusion hcom + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else if (i == last_stream_first_index) { + // last fusion hcom + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + orders.emplace_back(cur_cnode); + } else { + auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size(); + if (cur_stream_hcom_size == 1) { + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + cur_event_id = resource_manager.ApplyNewEvent(); + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else { + // current stream, first hcom:add recv op + if (i == hcom_index.at(cur_hcom_stream_id).front()) { + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + cur_event_id = resource_manager.ApplyNewEvent(); + orders.emplace_back(cur_cnode); + } else if (i == hcom_index.at(cur_hcom_stream_id).back()) { + // current stream, last hcom:add send op + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else { + // current stream, not first and last op + orders.emplace_back(cur_cnode); + } + } + } + } + std::copy(cnode_ptr_list.begin() + last_stream_first_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); + graph_ptr->set_execution_order(orders); +} + +bool AscendStreamAssign::IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, + size_t index) { + MS_EXCEPTION_IF_NULL(node_ptr); + auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr); + auto it = hcom_index.find(cur_hcom_stream_id); + if (it == hcom_index.end()) { + return false; + } + auto iter = std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index); + if (iter == hcom_index.at(cur_hcom_stream_id).end()) { + return false; + } + return true; +} + +// section6 +void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull &graph_ptr) { + MS_LOG(INFO) << "Start"; + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes = cnode_ptr_list; + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto it = cnodes.begin(); + while (it != cnodes.end()) { + MS_EXCEPTION_IF_NULL(*it); + if (IsIndependentNode(*it)) { + MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]"; + CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); + it = cnodes.insert(it + 1, send_cnode_ptr); + + auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); + if (target == cnodes.end()) { + MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope() + << "] can't find target for insert recv op, no insert send/recv"; + it = cnodes.erase(it); + continue; + } + + // deal recv op + uint32_t stream_id = AnfAlgo::GetStreamId(*target); + CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); + (void)cnodes.insert(target, recv_cnode_ptr); + cur_event_id = resource_manager.ApplyNewEvent(); + } + ++it; + } + // one event allocated additional, should delete + resource_manager.DeleteEvent(); + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num(); + MS_LOG(INFO) << "End"; +} + +// section7 +void AscendStreamAssign::GetNeedActiveStreams(const NotNull &graph_ptr) { + CNodePtr cur_cnode_ptr = nullptr; + auto cnode_ptr_list = graph_ptr->execution_order(); + // 1)first stream 0 should be actived first; + need_first_active_streams_.emplace_back(0); + + // 2)stream witch kStreamNeedActivedFirst attr should be actived; + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { + continue; + } + + auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); + if (need_active) { + auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first"; + need_first_active_streams_.push_back(stream_id); + } + } + + // 3)independent stream:if has not been activate, push to need active vector + if (!independent_stream_activated_) { + for (auto &item : independent_stream_map_) { + need_first_active_streams_.emplace_back(item.first); + } + } + + // 4)hcom stream:if has not been activate, push to need active vector + if (!hcom_stream_activated_) { + for (auto &item : hcom_stream_map_) { + need_first_active_streams_.emplace_back(item.first); + } + } +} + +// section8 +void AscendStreamAssign::CheckResourceAssign(const NotNull &graph_ptr) { + CheckStreamAssign(graph_ptr); + CheckEventAssign(graph_ptr); +} + +void AscendStreamAssign::CheckStreamAssign(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + std::set streams; + uint32_t max_stream = 0; + uint32_t min_stream = kInvalidStreamId; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + if (stream_id == kInvalidStreamId) { + MS_LOG(EXCEPTION) << "Node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "had not been assigned stream"; + } + + (void)streams.emplace(stream_id); + if (stream_id > max_stream) { + max_stream = stream_id; + } + if (stream_id < min_stream) { + min_stream = stream_id; + } + } + + // check stream assign + if (!streams.empty()) { + if (min_stream != 0) { + MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream; + } + uint32_t assigned_stream_num = resource_manager.get_cur_stream_num(); + if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) { + MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream + << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size(); + } + } +} + +void AscendStreamAssign::CheckEventAssign(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + std::map> event_map; + uint32_t max_event_id = 0; + uint32_t min_event_id = kInvalidEventId; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + auto name = AnfAlgo::GetCNodeName(cur_cnode_ptr); + if (name == kSendOpName || name == kRecvOpName) { + uint32_t event_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId); + if (event_id > max_event_id) { + max_event_id = event_id; + } + + if (event_id < min_event_id) { + min_event_id = event_id; + } + auto it = event_map.find(event_id); + if (it == event_map.end()) { + event_map[event_id] = {cur_cnode_ptr}; + } else { + event_map[event_id].emplace_back(cur_cnode_ptr); + } + } + } + // check event assign + if (!event_map.empty()) { + if (min_event_id != 0) { + MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id; + } + uint32_t assigned_event_num = resource_manager.get_cur_event_num(); + if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) { + MS_LOG(EXCEPTION) << "Event should be consecutive"; + } + for (const auto &item : event_map) { + if (item.second.size() != 2) { + MS_LOG(EXCEPTION) << "Send/recv should be in pair and share one event id"; + } + auto first_name = AnfAlgo::GetCNodeName(item.second[0]); + auto second_name = AnfAlgo::GetCNodeName(item.second[1]); + if (!(first_name == kSendOpName && second_name == kRecvOpName)) { + MS_LOG(EXCEPTION) << "Send should be before recv"; + } + } + } +} + +// section9 +CNodePtr AscendStreamAssign::CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, + uint32_t stream_id) { + auto send_op = std::make_shared(kSendOpName); + MS_EXCEPTION_IF_NULL(send_op); + auto send_apply = std::make_shared(send_op); + MS_EXCEPTION_IF_NULL(send_apply); + std::vector send_input_list = {send_apply}; + CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); + MS_EXCEPTION_IF_NULL(send_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + send_node_ptr->set_abstract(abstract_none); + AnfAlgo::SetStreamId(stream_id, send_node_ptr.get()); + return send_node_ptr; +} + +CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, + uint32_t stream_id) { + auto recv_op = std::make_shared(kRecvOpName); + MS_EXCEPTION_IF_NULL(recv_op); + auto recv_apply = std::make_shared(recv_op); + MS_EXCEPTION_IF_NULL(recv_apply); + std::vector recv_input_list = {recv_apply}; + CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); + MS_EXCEPTION_IF_NULL(recv_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); + AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get()); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + recv_node_ptr->set_abstract(abstract_none); + return recv_node_ptr; +} + +vector::iterator AscendStreamAssign::FindTargetOp(vector::iterator begin, + vector::iterator end, const CNodePtr &node) { + while (begin != end) { + auto inputs = (*begin)->inputs(); + for (size_t i = 1; i < inputs.size(); i++) { + auto input = inputs[i]; + if (opt::IsNopNode(input)) { + CNodePtr cnode = input->cast(); + auto new_inputs = cnode->inputs(); + for (size_t j = 1; j < new_inputs.size(); j++) { + auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0); + if (node == new_real_input.first) { + MS_LOG(INFO) << "Nop node find target op[" << (*begin)->DebugString() << "]"; + return begin; + } + } + } else { + auto real_input = AnfAlgo::VisitKernel(input, 0); + if (node == real_input.first) { + MS_LOG(INFO) << "Find target op[" << (*begin)->DebugString() << "]"; + return begin; + } + } + } + ++begin; + } + return end; +} + +bool AscendStreamAssign::IsTaskSink() { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!ms_context->enable_task_sink()) { + MS_LOG(INFO) << "Task sink mode is not enable"; + return false; + } else { + MS_LOG(INFO) << "Task sink mode is enable"; + return true; + } +} + +void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { + MS_EXCEPTION_IF_NULL(wait_active_stream_list); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + uint32_t total_stream_num = resource_manager.get_cur_stream_num(); + if (total_stream_num == 0) { + MS_LOG(INFO) << "The total_common_stream_num is zero"; + return; + } + + // common stream:active first common stream + for (uint32_t i = 0; i < total_stream_num; i++) { + auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); + if (it == need_first_active_streams_.end()) { + MS_LOG(INFO) << "Wait common stream id = " << i; + wait_active_stream_list->push_back(i); + } + } +} + +bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) { + MS_EXCEPTION_IF_NULL(apply_kernel); + return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL; +} + +void AscendStreamAssign::GetHcomStreams(std::vector *streams) { + MS_EXCEPTION_IF_NULL(streams); + for (const auto &item : hcom_stream_map_) { + streams->emplace_back(item.first); + } +} + +void AscendStreamAssign::Reset() { + independent_stream_activated_ = false; + hcom_stream_activated_ = false; + independent_stream_map_.clear(); + hcom_stream_map_.clear(); + common_stream_map_.clear(); + processed_streams_.clear(); + need_first_active_streams_.clear(); + stream_groups_.clear(); + stream_relations_.clear(); + event_map_.clear(); +} + +// section 10 +bool AscendStreamAssign::IsVecExist(std::vector *group) { + auto group_size = group->size(); + if (group_size == 0) { + return false; + } + for (const auto &item : stream_groups_) { + if (item.size() < group->size()) { + continue; + } + + bool flag = true; + for (size_t i = 0; i < group_size; i++) { + if (item[i] != group->at(i)) { + flag = false; + break; + } + } + + if (flag) { + return true; + } else { + continue; + } + } + + return false; +} + +void AscendStreamAssign::DFS(uint32_t start, std::vector *group) { + auto it = stream_relations_.find(start); + if (it == stream_relations_.end()) { + if (!IsVecExist(group)) { + stream_groups_.emplace_back(*group); + } else { + MS_LOG(WARNING) << "DFS should not print this log"; + } + return; + } + + vector active_streams = stream_relations_[start]; + + for (const auto &item : active_streams) { + group->emplace_back(item); + DFS(item, group); + group->pop_back(); + } +} + +void AscendStreamAssign::GetStreamRelations() { + for (const auto &start : need_first_active_streams_) { + vector group{start}; + DFS(start, &group); + } +} + +void AscendStreamAssign::FindStreamRelations(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto stream_num = resource_manager.get_cur_stream_num(); + if (stream_num <= 1) { + return; + } + + auto exe_orders = graph_ptr->execution_order(); + for (size_t i = 0; i < exe_orders.size(); i++) { + auto cur_cnode = exe_orders[i]; + auto name = AnfAlgo::GetCNodeName(cur_cnode); + if (name != kStreamSwitchOpName && name != kStreamActiveOpName) { + continue; + } + + // support:streamswitch is begin of the stream + if (name == kStreamSwitchOpName) { + GetStreamSwitchStreamRelation(cur_cnode); + } + + if (name == kStreamActiveOpName) { + GetStreamActiveStreamRelation(graph_ptr, i); + } + } +} + +void AscendStreamAssign::GetStreamSwitchStreamRelation(const CNodePtr &node_ptr) { + MS_EXCEPTION_IF_NULL(node_ptr); + auto cur_stream_id = AnfAlgo::GetStreamId(node_ptr); + auto true_stream_id = AnfAlgo::GetNodeAttr(node_ptr, kAttrTrueBranchStream); + if (true_stream_id <= cur_stream_id) { + MS_LOG(ERROR) << "StreamSwitch self stream id " << cur_stream_id + << " is greater than true branch stream id:" << true_stream_id; + } + auto it = stream_relations_.find(cur_stream_id); + if (it == stream_relations_.end()) { + stream_relations_[cur_stream_id] = {true_stream_id}; + } else { + auto iter = + std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), true_stream_id); + if (iter == stream_relations_[cur_stream_id].end()) { + stream_relations_[cur_stream_id].emplace_back(true_stream_id); + } + } +} + +void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull &graph_ptr, size_t index) { + StreamActiveKind kind = GetStreamActiveKind(graph_ptr, index); + if (kind == kInvalid) { + MS_LOG(INFO) << "Invalid streamActive kind"; + return; + } + + auto orders = graph_ptr->execution_order(); + auto cur_cnode = orders[index]; + auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + auto active_list = AnfAlgo::GetNodeAttr>(cur_cnode, kAttrActiveStreamList); + if (kind == kHead) { + uint32_t active_current_node = GetStreamByActivedStream(cur_stream_id); + if (active_current_node == kInvalidStreamId) { + MS_LOG(EXCEPTION) << "No stream to active streamactive stream"; + } + + for (const auto &item : active_list) { + if (item <= active_current_node) { + MS_LOG(WARNING) << "Actived stream is less than activing stream"; + continue; + } + auto it = + std::find(stream_relations_[active_current_node].begin(), stream_relations_[active_current_node].end(), item); + if (it == stream_relations_[active_current_node].end()) { + stream_relations_[active_current_node].emplace_back(item); + } + } + } + + if (kind == kMiddle) { + for (const auto &stream : active_list) { + if (stream <= cur_stream_id) { + MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal"; + } else { + MS_LOG(ERROR) << "MIDDLE StreamActive active stream is greater than self stream, should not be exit now"; + } + } + } + + if (kind == kTail) { + auto it = stream_relations_.find(cur_stream_id); + if (it == stream_relations_.end()) { + stream_relations_[cur_stream_id] = active_list; + } else { + for (const auto &stream : active_list) { + if (stream <= cur_stream_id) { + MS_LOG(WARNING) << "Actived stream is less than activing stream"; + continue; + } + auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream); + if (iter == stream_relations_[cur_stream_id].end()) { + stream_relations_[cur_stream_id].emplace_back(stream); + } + } + } + } +} + +StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull &graph_ptr, size_t index) { + auto exe_orders = graph_ptr->execution_order(); + if (index >= exe_orders.size()) { + MS_LOG(EXCEPTION) << "Invalid op index:" << index; + } + + auto cur_cnode = exe_orders[index]; + auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + if (AnfAlgo::GetCNodeName(cur_cnode) != kStreamActiveOpName) { + MS_LOG(EXCEPTION) << "Current node name is not StreamActive"; + } + + if (index == 0) { + return kInvalid; + } + + if (index == exe_orders.size() - 1) { + return kInvalid; + } + + uint32_t pre_stream_id = UINT32_MAX; + uint32_t next_stream_id = UINT32_MAX; + int32_t start = SizeToInt(index) - 1; + for (int32_t i = start; i >= 0; i--) { + auto cnode = exe_orders[IntToSize(i)]; + auto name = AnfAlgo::GetCNodeName(cnode); + if (name == kSendOpName || name == kRecvOpName) { + continue; + } + + pre_stream_id = AnfAlgo::GetStreamId(cnode); + break; + } + + for (size_t i = index + 1; i < exe_orders.size(); i++) { + auto cnode = exe_orders[i]; + auto name = AnfAlgo::GetCNodeName(cnode); + if (name == kSendOpName || name == kRecvOpName) { + continue; + } + + next_stream_id = AnfAlgo::GetStreamId(cnode); + break; + } + + // pre_stream_id = UINT32_MAX:means no node active current StreamActive + // next_stream_id = UINT32_MAX:means current StreamActive active no node + if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) { + return kInvalid; + } + + if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) { + return kMiddle; + } + + if (cur_stream_id == pre_stream_id) { + return kTail; + } + + if (cur_stream_id == next_stream_id) { + return kHead; + } + + return kInvalid; +} + +uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) { + if (stream_relations_.empty()) { + return kInvalidStreamId; + } + + for (const auto &item : stream_relations_) { + auto it = std::find(item.second.begin(), item.second.end(), actived_stream_id); + if (it != item.second.end()) { + return item.first; + } + } + + return kInvalidStreamId; +} + +void AscendStreamAssign::PrintStreamRelations() { + MS_LOG(INFO) << "Stream relations size:" << stream_relations_.size(); + for (const auto &item : stream_relations_) { + MS_LOG(INFO) << "Stream:" << item.first; + for (const auto &stream : item.second) { + MS_LOG(INFO) << "--actived stream id:" << stream; + } + } +} + +void AscendStreamAssign::PrintStreamGroups() { + MS_LOG(INFO) << "Stream group size:" << stream_groups_.size(); + for (const auto &item : stream_groups_) { + MS_LOG(INFO) << "Group:"; + for (const auto &stream : item) { + MS_LOG(INFO) << "Stream id:" << stream; + } + } +} + +// section 11 +bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const { + size_t send_group = 0; + size_t recv_group = 0; + bool send_flag = true; + bool recv_flag = true; + for (size_t i = 0; i < stream_groups_.size(); i++) { + auto group = stream_groups_[i]; + if (send_flag) { + auto it = std::find(group.begin(), group.end(), send_stream_id); + if (it != group.end()) { + send_group = i; + send_flag = false; + } + } + + if (recv_flag) { + auto it = std::find(group.begin(), group.end(), recv_stream_id); + if (it != group.end()) { + recv_group = i; + recv_flag = false; + } + } + } + + if (!(send_flag || recv_flag)) { + return (send_group != recv_group); + } + + return false; +} + +void AscendStreamAssign::FindEventRelations(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto event_nums = resource_manager.get_cur_event_num(); + if (event_nums == 0) { + return; + } + auto exe_orders = graph_ptr->execution_order(); + // find all event info + for (size_t i = 0; i < exe_orders.size(); i++) { + auto cur_cnode = exe_orders[i]; + auto name = AnfAlgo::GetCNodeName(cur_cnode); + if (name == kSendOpName) { + event_map_[cur_cnode] = {}; + } + + if (name == kRecvOpName) { + auto recv_event_id = AnfAlgo::GetNodeAttr(cur_cnode, kAttrEventId); + for (auto &item : event_map_) { + auto send_event_id = AnfAlgo::GetNodeAttr(item.first, kAttrEventId); + if (recv_event_id == send_event_id) { + item.second = cur_cnode; + break; + } + } + } + } + + // delete useless event info + auto begin = event_map_.begin(); + while (begin != event_map_.end()) { + auto send_stream_id = AnfAlgo::GetStreamId(begin->first); + auto recv_stream_id = AnfAlgo::GetStreamId(begin->second); + bool flag = IsSatisfiedEvent(send_stream_id, recv_stream_id); + if (!flag) { + begin = event_map_.erase(begin); + } else { + begin++; + } + } + + MS_LOG(INFO) << "Satisfied event info"; + for (const auto &item : event_map_) { + MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr(item.first, kAttrEventId); + } +} + +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h new file mode 100644 index 0000000000..00fca60e8d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -0,0 +1,185 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "runtime/base.h" +#include "runtime/rt_model.h" +#include "runtime/stream.h" +#include "backend/session/kernel_graph.h" +#include "utils/contract.h" + +namespace mindspore { +namespace device { +namespace ascend { +using std::map; +using std::shared_ptr; +using std::unordered_map; +using std::unordered_set; +using std::vector; +const uint32_t kInvalidStreamId = UINT32_MAX; +const uint32_t kInvalidEventId = UINT32_MAX; +class AscendResourceMng { + public: + static AscendResourceMng &GetInstance() { + static AscendResourceMng instance; + return instance; + } + + void ResetResource() { + cur_stream_num_ = 0; + cur_event_num_ = 0; + } + uint32_t ApplyNewStream() { + if (!cur_stream_num_) { + uint32_t cur_stream_id = cur_stream_num_; + cur_stream_num_++; + return cur_stream_id; + } + uint32_t cur_stream_id = cur_stream_num_; + cur_stream_num_++; + return cur_stream_id; + } + uint32_t ApplyNewEvent() { + if (!cur_event_num_) { + uint32_t cur_event_id = cur_event_num_; + cur_event_num_++; + return cur_event_id; + } + uint32_t cur_event_id = cur_event_num_; + cur_event_num_++; + return cur_event_id; + } + + void DeleteEvent() { + if (!cur_event_num_) { + MS_LOG(WARNING) << "total event num is 0, no event to delete"; + } else { + --cur_event_num_; + } + } + uint32_t get_cur_stream_num() { return cur_stream_num_; } + uint32_t GetCurAllocStreamId() { + if (!cur_stream_num_) { + MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; + } + return cur_stream_num_ - 1; + } + uint32_t get_cur_event_num() { return cur_event_num_; } + + private: + uint32_t cur_stream_num_{0}; + uint32_t cur_event_num_{0}; +}; + +enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail }; +class AscendStreamAssign { + public: + static AscendStreamAssign &GetInstance() { + static AscendStreamAssign instance; // Guaranteed to be destroyed. + return instance; + } + + AscendStreamAssign(const AscendStreamAssign &) = delete; + AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; + + void AssignStream(const NotNull &graph_ptr); + void GetHcomStreams(std::vector *streams); + void GetWaitStreams(vector *wait_active_stream_list); + CNodePtr CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); + CNodePtr CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); + const std::vector> &get_stream_group() const { return stream_groups_; } + const std::map &get_event_map() const { return event_map_; } + + private: + AscendStreamAssign() = default; + ~AscendStreamAssign() = default; + void Reset(); + void CheckResourceAssign(const NotNull &graph_ptr); + void CheckStreamAssign(const NotNull &graph_ptr); + void CheckEventAssign(const NotNull &graph_ptr); + void AssignAllNodesStream(const NotNull &graph_ptr); + void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); + void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr); + void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); + void UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr); + void FindHcomParallelStreams(const NotNull &graph_ptr); + void InsertStreamActive(const NotNull &graph_ptr); + void UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, + vector *orders); + void InsertEventForIndependentParallel(const NotNull &graph_ptr); + void InsertEventForHcomParallel(const NotNull &graph_ptr); + void InsertEventCommonDependHcom(const NotNull &graph_ptr); + void InsertEventHcomDependCommon(const NotNull &graph_ptr); + void InsertEventHcomDependHcom(const NotNull &graph_ptr); + void InsertEventBetweenHcom(const NotNull &graph_ptr, const map> &hcom_index, + uint32_t first_hcom_stream, uint32_t last_hcom_stream); + bool IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, size_t index); + + void GetProcessedStream(const NotNull &graph_ptr); + void GetNeedActiveStreams(const NotNull &graph_ptr); + void ReorderIndependentOrders(const NotNull &graph_ptr); + + bool IsTaskSink(); + bool IsHcom(const CNodePtr &cur_cnode_ptr); + bool IsIndependentNode(const CNodePtr &node_ptr); + bool IsProcessedStream(uint32_t stream_id); + vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, + const CNodePtr &node); + void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); + + // function for memory resue + void GetStreamRelations(); + void DFS(uint32_t start, std::vector *group); + bool IsVecExist(std::vector *group); + void FindStreamRelations(const NotNull &graph_ptr); + void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); + void GetStreamActiveStreamRelation(const NotNull &graph_ptr, size_t index); + StreamActiveKind GetStreamActiveKind(const NotNull &graph_ptr, size_t index); + uint32_t GetStreamByActivedStream(uint32_t actived_stream_id); + void PrintStreamRelations(); + void PrintStreamGroups(); + void FindEventRelations(const NotNull &graph_ptr); + bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const; + + bool independent_stream_activated_{false}; + bool hcom_stream_activated_{false}; + std::map independent_stream_map_{}; + std::map hcom_stream_map_{}; + std::map common_stream_map_{}; + std::set processed_streams_{}; + std::vector need_first_active_streams_{}; + + // attr for memory copy reuse + std::map> stream_relations_{}; + std::vector> stream_groups_{}; + std::map event_map_; + // new policy end +}; +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc new file mode 100644 index 0000000000..ab2c6b2748 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc @@ -0,0 +1,282 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_DATA_DUMP +#include "runtime/device/ascend/dump/data_dumper.h" + +#include +#include +#include +#include "utility" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/mem.h" +#include "runtime/kernel.h" +#include "runtime/device/ascend/dump/ge_dump.h" +#include "proto/op_mapping_info.pb.h" +#include "utils/context/ms_context.h" +#include "debug/data_dump_parser.h" + +constexpr uint32_t kAicpuLoadFlag = 1; +constexpr uint32_t kAicpuUnloadFlag = 0; +constexpr uint32_t kTupleTaskId = 0; +constexpr uint32_t kTupleStreamId = 1; +constexpr uint32_t kTupleArgs = 2; +constexpr uint32_t kCurrentStepTensorIndex = 0; +constexpr uint32_t kCurrentEpochTensorIndex = 1; +constexpr uint32_t kStepsPerEpochTensorIndex = 2; + +namespace mindspore { +namespace device { +namespace ascend { +void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task); +void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull task); +void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr); + +DataDumper::~DataDumper() { + ReleaseDevMem(&dev_load_mem_); + ReleaseDevMem(&dev_unload_mem_); +} + +void DataDumper::LoadDumpInfo() { + MS_LOG(INFO) << "[DataDump] LoadDumpInfo start"; + MS_EXCEPTION_IF_NULL(kernel_graph_); + aicpu::dump::OpMappingInfo dump_info; + SetOpMappingInfo(NOT_NULL(&dump_info)); + + auto kernels = kernel_graph_->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + if (!KernelNeedDump(kernel)) { + continue; + } + MS_LOG(INFO) << "[DataDump] LoadDumpInfo kernel:" << kernel->fullname_with_scope(); + dump_kernel_names_.emplace_back(kernel->fullname_with_scope()); + + aicpu::dump::Task task; + ConstructDumpTask(NOT_NULL(kernel), NOT_NULL(&task)); + MS_EXCEPTION_IF_NULL(dump_info.mutable_task()); + dump_info.mutable_task()->Add(std::move(task)); + } + RtLoadDumpData(dump_info, &dev_load_mem_); + load_flag_ = true; + MS_LOG(INFO) << "[DataDump] LoadDumpInfo end"; +} + +void DataDumper::SetOpMappingInfo(NotNull dump_info) const { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(kernel_graph_); + auto dump_path = DataDumpParser::GetInstance().GetDumpPath(); + if (!dump_path.has_value()) { + MS_LOG(EXCEPTION) << "Dump path invalid"; + } + auto device_id = context_ptr->device_id(); + dump_info->set_dump_path(dump_path.value() + "_" + std::to_string(device_id) + "/"); + MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value(); + + dump_info->set_model_name(DataDumpParser::GetInstance().net_name() + "_" + std::to_string(kernel_graph_->graph_id())); + dump_info->set_dump_step(std::to_string(DataDumpParser::GetInstance().dump_step())); + dump_info->set_model_id(kernel_graph_->graph_id()); + dump_info->set_flag(kAicpuLoadFlag); + + const auto &input_ctrl_tensors = kernel_graph_->input_ctrl_tensors(); + if (input_ctrl_tensors == nullptr || input_ctrl_tensors->size() < 3) { + MS_LOG(INFO) << "[DataDump] Not data sink mode, input_ctrl_tensor"; + return; + } + const auto ¤t_step_tensor = input_ctrl_tensors->at(kCurrentStepTensorIndex); + const auto &currnet_epoch_tensor = input_ctrl_tensors->at(kCurrentEpochTensorIndex); + const auto &steps_per_epoch_tensor = input_ctrl_tensors->at(kStepsPerEpochTensorIndex); + + MS_EXCEPTION_IF_NULL(current_step_tensor); + MS_EXCEPTION_IF_NULL(currnet_epoch_tensor); + MS_EXCEPTION_IF_NULL(steps_per_epoch_tensor); + MS_EXCEPTION_IF_NULL(current_step_tensor->device_address()); + MS_EXCEPTION_IF_NULL(currnet_epoch_tensor->device_address()); + MS_EXCEPTION_IF_NULL(steps_per_epoch_tensor->device_address()); + + void *current_step = current_step_tensor->device_address()->ptr_; + void *current_epoch = currnet_epoch_tensor->device_address()->ptr_; + void *steps_per_epoch = steps_per_epoch_tensor->device_address()->ptr_; + + if (current_epoch != nullptr && current_step != nullptr && steps_per_epoch != nullptr) { + dump_info->set_step_id_addr(reinterpret_cast(current_epoch)); + dump_info->set_loop_cond_addr(reinterpret_cast(current_step)); + dump_info->set_iterations_per_loop_addr(reinterpret_cast(steps_per_epoch)); + } else { + MS_LOG(INFO) << "Invalid ctrl tensor device address"; + } +} + +bool DataDumper::KernelNeedDump(const CNodePtr &kernel) const { + if (AnfAlgo::GetKernelType(kernel) != TBE_KERNEL && AnfAlgo::GetKernelType(kernel) != AICPU_KERNEL && + AnfAlgo::GetKernelType(kernel) != AKG_KERNEL) { + return false; + } + MS_EXCEPTION_IF_NULL(kernel); + // dump all kernel if mode is set 0 in data_dump.json + return DataDumpParser::GetInstance().NeedDump(kernel->fullname_with_scope()); +} + +void DataDumper::UnloadDumpInfo() { + if (!load_flag_) { + MS_LOG(WARNING) << "Load not success, no need to unload"; + return; + } + MS_EXCEPTION_IF_NULL(kernel_graph_); + MS_LOG(INFO) << "[DataDump] UnloadDumpInfo start. graphId:" << kernel_graph_->graph_id(); + + aicpu::dump::OpMappingInfo op_mapping_info; + op_mapping_info.set_model_id(kernel_graph_->graph_id()); + op_mapping_info.set_flag(kAicpuUnloadFlag); + + for (const auto &kernel_name : dump_kernel_names_) { + aicpu::dump::Task task; + auto iter = runtime_info_map_.find(kernel_name); + if (iter == runtime_info_map_.end()) { + MS_LOG(EXCEPTION) << "[DataDump] kernel name not found in runtime_info_map"; + } + MS_EXCEPTION_IF_NULL(iter->second); + auto task_id = std::get(*iter->second); + task.set_task_id(task_id); + MS_EXCEPTION_IF_NULL(op_mapping_info.mutable_task()); + op_mapping_info.mutable_task()->Add(std::move(task)); + } + + RtLoadDumpData(op_mapping_info, &dev_unload_mem_); +} + +void DataDumper::ReleaseDevMem(void **ptr) const { + if (ptr == nullptr) { + return; + } + if (*ptr != nullptr) { + rtError_t rt_error = rtFree(*ptr); + if (rt_error != RT_ERROR_NONE) { + MS_LOG(ERROR) << "[DataDump] Call rtFree failed, ret:" << rt_error; + } + *ptr = nullptr; + } +} + +void DataDumper::ConstructDumpTask(NotNull kernel, NotNull dump_task) const { + dump_task->set_end_graph(false); + auto iter = runtime_info_map_.find(kernel->fullname_with_scope()); + if (iter == runtime_info_map_.end()) { + MS_LOG(EXCEPTION) << "[DataDump] kernel name not found in runtime_info_map"; + } + MS_EXCEPTION_IF_NULL(iter->second); + auto task_id = std::get(*iter->second); + auto stream_id = std::get(*iter->second); + auto args = std::get(*iter->second); + MS_LOG(INFO) << "[DataDump] Get runtime info task_id:" << task_id << " stream_id:" << stream_id; + + dump_task->set_task_id(task_id); + dump_task->set_stream_id(stream_id); + MS_EXCEPTION_IF_NULL(dump_task->mutable_op()); + dump_task->mutable_op()->set_op_name(kernel->fullname_with_scope()); + dump_task->mutable_op()->set_op_type(AnfAlgo::GetCNodeName(kernel.get())); + + DumpKernelOutput(kernel, args, dump_task); + DumpKernelInput(kernel, args, dump_task); +} + +void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr) { + std::string proto_str; + size_t proto_size = dump_info.ByteSizeLong(); + bool ret = dump_info.SerializeToString(&proto_str); + if (!ret || proto_size == 0) { + MS_LOG(EXCEPTION) << "[DataDump] Protobuf SerializeToString failed, proto size %zu."; + } + + rtError_t rt_ret = rtMalloc(ptr, proto_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtMalloc failed"; + } + + if (ptr == nullptr) { + MS_LOG(ERROR) << "[DataDump] rtMalloc failed, ptr is nullptr"; + return; + } + rt_ret = rtMemcpy(*ptr, proto_size, proto_str.c_str(), proto_size, RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtMemcpy failed"; + } + + MS_LOG(INFO) << "[DataDump] rtDatadumpInfoLoad start"; + rt_ret = rtDatadumpInfoLoad(*ptr, proto_size); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtDatadumpInfoLoad failed"; + } +} + +void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task) { + MS_LOG(INFO) << "[DataDump] DumpKernelOutput start. Kernel:" << kernel->fullname_with_scope(); + auto input_size = AnfAlgo::GetInputTensorNum(kernel); + auto output_size = AnfAlgo::GetOutputTensorNum(kernel); + uint64_t offset = sizeof(void *) * input_size; + for (size_t i = 0; i < output_size; ++i) { + auto data_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + auto output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel, i); + + aicpu::dump::Output output; + output.set_data_type(GetGeDataType(data_type)); + output.set_format(GetGeFormat(output_format, output_shape.size())); + MS_EXCEPTION_IF_NULL(output.mutable_shape()); + for (auto dim : output_shape) { + output.mutable_shape()->add_dim(dim); + } + output.set_original_output_format(GetGeFormat(output_format, output_shape.size())); + output.set_address(static_cast(reinterpret_cast(args)) + offset); + MS_EXCEPTION_IF_NULL(task->mutable_output()); + task->mutable_output()->Add(std::move(output)); + offset += sizeof(void *); + } +} + +void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull task) { + MS_LOG(INFO) << "[DataDump] DumpKernelInput start. Kernel:" << kernel->fullname_with_scope(); + auto input_size = AnfAlgo::GetInputTensorNum(kernel); + uint64_t offset = 0; + for (size_t i = 0; i < input_size; ++i) { + aicpu::dump::Input input; + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); + auto input_node = input_node_with_index.first; + auto input_index = input_node_with_index.second; + std::string output_format = AnfAlgo::GetOutputFormat(input_node, input_index); + auto output_type = AnfAlgo::GetOutputDeviceDataType(input_node, input_index); + if (output_type == kTypeUnknown) { + MS_LOG(WARNING) << "[DataDump] It is not suggested to use a lonely weight parameter as the output of graph"; + output_type = AnfAlgo::GetOutputInferDataType(input_node, input_index); + } + auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index); + + input.set_data_type(GetGeDataType(output_type)); + input.set_format(GetGeFormat(output_format, output_shape.size())); + MS_EXCEPTION_IF_NULL(input.mutable_shape()); + for (auto dim : output_shape) { + input.mutable_shape()->add_dim(dim); + } + input.set_address(static_cast(reinterpret_cast(args)) + offset); + MS_EXCEPTION_IF_NULL(task->mutable_input()); + task->mutable_input()->Add(std::move(input)); + offset += sizeof(void *); + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.h b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.h new file mode 100644 index 0000000000..d99eb4db68 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ +#ifdef ENABLE_DATA_DUMP +#include +#include +#include +#include +#include +#include "backend/session/kernel_graph.h" + +namespace aicpu { +namespace dump { +class OpMappingInfo; +class Task; +} // namespace dump +} // namespace aicpu +namespace mindspore { +namespace device { +namespace ascend { +// tuple(op_name, task_id, stream_id, args) +using RuntimeInfo = std::tuple; +class DataDumper { + public: + DataDumper(const session::KernelGraph *kernel_graph, + const std::map> &runtime_info_map) + : load_flag_(false), + dev_load_mem_(nullptr), + dev_unload_mem_(nullptr), + kernel_graph_(kernel_graph), + runtime_info_map_(runtime_info_map) {} + ~DataDumper(); + void LoadDumpInfo(); + + void UnloadDumpInfo(); + + private: + void ReleaseDevMem(void **ptr) const; + bool KernelNeedDump(const CNodePtr &kernel) const; + void SetOpMappingInfo(NotNull dump_info) const; + void ConstructDumpTask(NotNull kernel, NotNull dump_task) const; + + bool load_flag_; + void *dev_load_mem_; + void *dev_unload_mem_; + std::vector dump_kernel_names_; + const session::KernelGraph *kernel_graph_; + std::map> runtime_info_map_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h b/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h new file mode 100644 index 0000000000..eae70c4b0b --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_GE_DUMP_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_GE_DUMP_H_ + +#include +#include +#include "proto/ge_dtype.pb.h" +#include "ir/dtype/type_id.h" +#include "utils/utils.h" + +namespace mindspore { +namespace device { +namespace ascend { +static ge::proto::DataType GetGeDataType(TypeId type_id) { + static const std::map data_type_map = { + {TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED}, {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT}, + {TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::proto::DT_INT8}, + {TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8}, {TypeId::kNumberTypeInt16, ge::proto::DT_INT16}, + {TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16}, {TypeId::kNumberTypeInt32, ge::proto::DT_INT32}, + {TypeId::kNumberTypeInt64, ge::proto::DT_INT64}, {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32}, + {TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64}, {TypeId::kNumberTypeBool, ge::proto::DT_BOOL}, + {TypeId::kNumberTypeFloat64, ge::proto::DT_DOUBLE}, + }; + MS_LOG(INFO) << "Vm origin type_id:" << type_id; + auto iter = data_type_map.find(type_id); + if (iter == data_type_map.end()) { + MS_LOG(EXCEPTION) << "Invalid data type:" << type_id; + } + return iter->second; +} + +enum GeFormat { + kFormat_NCHW = 0, // NCHW + kFormat_NHWC, // NHWC + kFormat_ND, // Nd Tensor + kFormat_NC1HWC0, // NC1HWC0 + kFormat_FRACTAL_Z, // FRACTAL_Z + kFormat_NC1C0HWPAD, + kFormat_NHWC1C0, + kFormat_FSR_NCHW, + kFormat_FRACTAL_DECONV, + kFormat_C1HWNC0, + kFormat_FRACTAL_DECONV_TRANSPOSE, + kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS, + kFormat_NC1HWC0_C04, // NC1HWC0, C0 =4 + kFormat_FRACTAL_Z_C04, // FRACZ, C0 =4 + kFormat_CHWN, + kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS, + kFormat_HWCN, + kFormat_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format + kFormat_BN_WEIGHT, + kFormat_FILTER_HWCK, // filter input tensor format + kFormat_HASHTABLE_LOOKUP_LOOKUPS = 20, + kFormat_HASHTABLE_LOOKUP_KEYS, + kFormat_HASHTABLE_LOOKUP_VALUE, + kFormat_HASHTABLE_LOOKUP_OUTPUT, + kFormat_HASHTABLE_LOOKUP_HITS = 24, + kFormat_C1HWNCoC0, + kFormat_MD, + kFormat_NDHWC, + kFormat_FRACTAL_ZZ, + kFormat_FRACTAL_NZ, + kFormat_NCDHW, + kFormat_DHWCN, // 3D filter input tensor format + kFormat_NDC1HWC0, + kFormat_FRACTAL_Z_3D, + kFormat_CN, + kFormat_NC, + kFormat_DHWNC, + kFormat_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format + kFormat_RESERVED, + kFormat_ALL +}; + +static GeFormat GetGeFormat(const std::string &format, size_t shape_size) { + static const std::map format_map = { + // default format: nchw, fractal_nz? + {kOpFormat_DEFAULT, kFormat_NCHW}, + {kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0}, + {kOpFormat_ND, kFormat_ND}, + {kOpFormat_NCHW, kFormat_NCHW}, + {kOpFormat_NHWC, kFormat_NHWC}, + {kOpFormat_HWCN, kFormat_HWCN}, + {kOpFormat_NC1HWC0, kFormat_NC1HWC0}, + {kOpFormat_FRAC_Z, kFormat_FRACTAL_Z}, + {kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ}, + {kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0}, + {kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04}, + {kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04}, + {kOpFormat_NDHWC, kFormat_NDHWC}, + }; + MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size; + if (format == kOpFormat_DEFAULT) { + return shape_size == 4 ? kFormat_NCHW : kFormat_ND; + } + auto iter = format_map.find(format); + if (iter == format_map.end()) { + MS_LOG(EXCEPTION) << "Invalid format:" << format; + } + return iter->second; +} +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_GE_DUMP_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/proto/ge_dtype.proto b/mindspore/ccsrc/runtime/device/ascend/dump/proto/ge_dtype.proto new file mode 100644 index 0000000000..7c690524d9 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/dump/proto/ge_dtype.proto @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} \ No newline at end of file diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto b/mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto new file mode 100644 index 0000000000..d3377c655d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +package aicpu.dump; + +message Shape { + repeated uint64 dim = 1; +} + +message Output { + int32 data_type = 1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + string original_name = 5; + int32 original_output_index = 6; + int32 original_output_data_type = 7; + int32 original_output_format = 8; + uint64 size = 9; +}; + +message Input { + int32 data_type = 1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + uint64 size = 5; +} + +message Op { + string op_name = 1; + string op_type = 2; +}; + +message Task { + uint32 task_id = 1; + uint32 stream_id = 2; + Op op = 3; + repeated Output output = 4; + bool end_graph = 5; + repeated Input input = 6; +}; + +message OpMappingInfo { + string dump_path = 1; + oneof model_name_param { + string model_name = 2; + } + oneof model_id_param { + uint32 model_id = 3; + } + oneof step_id { + uint64 step_id_addr = 4; + } + oneof iterations_per_loop { + uint64 iterations_per_loop_addr = 5; + } + oneof loop_cond { + uint64 loop_cond_addr = 6; + } + uint32 flag = 7; // 0x01 load, 0x00 unload + repeated Task task = 8; + string dump_step = 9; +}; diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc new file mode 100644 index 0000000000..39cefcb020 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc @@ -0,0 +1,286 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/kernel_build_ascend.h" + +#include +#include +#include +#include + +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" +#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h" +#include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h" +#include "backend/kernel_compiler/hccl/hccl_kernel_build.h" +#include "backend/kernel_compiler/rts/rt_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/common_utils.h" +#include "frontend/operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "./common.h" + +namespace mindspore { +namespace device { +namespace ascend { +using mindspore::kernel::tbe::TbeUtils; +using std::make_shared; +static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) { + kernel::KernelModPtr kernel_mod_ptr = nullptr; + KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); + switch (kernel_type) { + case KernelType::AICPU_KERNEL: { + kernel_mod_ptr = kernel::AicpuOpBuild(anf_node); + break; + } + case KernelType::RT_KERNEL: { + kernel_mod_ptr = kernel::RtOpBuild(anf_node); + break; + } + case KernelType::HCCL_KERNEL: { + kernel_mod_ptr = kernel::HcclOpBuild(anf_node); + break; + } + default: { + MS_LOG(EXCEPTION) << "node [" << anf_node->DebugString() << "] Unsupported kernel_type:" << kernel_type; + } + } + return kernel_mod_ptr; +} + +static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + std::vector tbe_nodes; + for (const auto &anf_node : kernel_graph_ptr->execution_order()) { + MS_EXCEPTION_IF_NULL(anf_node); + if (!AnfAlgo::IsRealKernel(anf_node)) { + continue; + } + KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); + switch (kernel_type) { + case KernelType::TBE_KERNEL: { + if (AnfAlgo::GetKernelMod(anf_node) == nullptr && + AnfAlgo::GetFusionType(anf_node) == kernel::FusionType::DYNAMIC) { + tbe_nodes.push_back(anf_node); + } + break; + } + default: { + break; + } + } + } + bool ret = kernel::TbeOpParallelPreBuild(tbe_nodes); + return ret; +} + +static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + std::vector tbe_nodes; + std::vector akg_nodes; + std::vector other_nodes; + for (const auto &anf_node : kernel_graph_ptr->execution_order()) { + MS_EXCEPTION_IF_NULL(anf_node); + if (!AnfAlgo::IsRealKernel(anf_node)) { + continue; + } + KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); + switch (kernel_type) { + case KernelType::TBE_KERNEL: { + if (AnfAlgo::GetKernelMod(anf_node) == nullptr) { + tbe_nodes.push_back(anf_node); + } + break; + } + case KernelType::AKG_KERNEL: { + akg_nodes.push_back(anf_node); + break; + } + default: { + other_nodes.push_back(anf_node); + break; + } + } + } + bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes); + bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes); + auto bin_map = kernel::tbe::KernelMeta::GetInstance(); + (void)bin_map->ReadIndex(kernel::kCceKernelMeta); + for (const auto &anf_node : other_nodes) { + kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); + } + return tbe_ret && akg_ret; +} + +static std::vector CalCleanZerosSize(const CNodePtr &pre_node) { + MS_EXCEPTION_IF_NULL(pre_node); + auto kernel_mod = AnfAlgo::GetKernelMod(pre_node); + MS_EXCEPTION_IF_NULL(kernel_mod); + std::vector clean_size_list; + // clean output + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { + auto output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); + auto output_men_size = kernel_mod->GetOutputSizeList(); + for (auto index : output_indexs) { + auto clean_item = (output_men_size.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; + clean_size_list.emplace_back(clean_item); + } + } + // clean workspace + if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { + auto workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); + auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList(); + for (const auto &index : workspace_indexs) { + auto clean_item = (workspace_men_sizes.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; + clean_size_list.emplace_back(clean_item); + } + } + MS_LOG(INFO) << "clear output size:" << clean_size_list.size() << ",pre_node:" << pre_node->fullname_with_scope(); + return clean_size_list; +} + +static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph, + const mindspore::CNodePtr &pre_node, std::vector *new_nodes) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(pre_node); + MS_EXCEPTION_IF_NULL(new_nodes); + auto clear_zero_prim = std::make_shared(kAtomicAddrCleanOpName); + MS_EXCEPTION_IF_NULL(clear_zero_prim); + auto new_value_node = NewValueNode(clear_zero_prim); + MS_EXCEPTION_IF_NULL(new_value_node); + std::vector inputs = {new_value_node}; + inputs.push_back(pre_node); + CNodePtr clear_zero = kernel_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(clear_zero); + AbstractBasePtr abstract = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract); + clear_zero->set_abstract(abstract); + auto builder = std::make_shared(); + builder->SetKernelType(KernelType::TBE_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get()); + auto clean_size = CalCleanZerosSize(pre_node); + AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero); + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get()); + new_nodes->push_back(clear_zero); +} + +static bool IsAtomicNode(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto parameters_indexs = kernel_mod->GenParameters(); + if (parameters_indexs.empty()) { + return false; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size(); + size_t param_num = parameters_indexs.size(); + size_t total_num = input_num + workspace_num + output_num; + MS_LOG(INFO) << "parameters size: " << param_num << ", input & workspace & output num: " << total_num; + size_t pad_index = param_num; + for (; pad_index < total_num; ++pad_index) { + parameters_indexs.emplace_back(0); + } + // process input + for (size_t j = 0; j < input_num; ++j) { + if (parameters_indexs.at(j) == 1) { + MS_LOG(EXCEPTION) << "Atomic addr clean does't support clean input address, input index: " << j; + } + } + // process output + std::vector output_indexs = {}; + for (size_t i = 0; i < output_num; ++i) { + auto param_output = parameters_indexs.at(input_num + workspace_num + i); + if (param_output == 1) { + output_indexs.emplace_back(i); + MS_LOG(INFO) << "Atomic clear output index: " << i; + } + } + if (!output_indexs.empty()) { + AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node); + } + // process workspace + std::vector workspace_indexs = {}; + for (size_t k = 0; k < workspace_num; ++k) { + auto param_workspace = parameters_indexs.at(input_num + k); + if (param_workspace == 1) { + workspace_indexs.emplace_back(k); + MS_LOG(INFO) << "Atomic clear workspace index: " << k; + } + } + if (!workspace_indexs.empty()) { + AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexs), kernel_node); + } + return !(workspace_indexs.empty() && output_indexs.empty()); +} + +bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + bool ret = device::ascend::KernelPreBuildParallelCompile(kernel_graph_ptr); + return ret; +} + +bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + TbeUtils::LoadCache(); + bool ret; + ret = device::ascend::KernelBuildParallelCompile(kernel_graph_ptr); + return ret; +} + +void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector new_nodes; + for (const auto &anf_node : kernel_graph->execution_order()) { + std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); + if (apply_function_name == prim::kPrimMaxPoolGrad->name() && + AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { + auto clear_zero_prim = std::make_shared(kClearZeroOpName); + MS_EXCEPTION_IF_NULL(clear_zero_prim); + auto new_value_node = NewValueNode(clear_zero_prim); + MS_EXCEPTION_IF_NULL(new_value_node); + std::vector inputs = {new_value_node}; + inputs.push_back(anf_node); + CNodePtr clear_zero = kernel_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(clear_zero); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + clear_zero->set_kernel_info(kernel_info); + AbstractBasePtr abstract = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract); + AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector({"x"})), clear_zero); + SelectKernelInfo(clear_zero); + // set the distinction label of clear same with anf + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get()); + new_nodes.push_back(clear_zero); + } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) { + if (IsAtomicNode(anf_node)) { + AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); + } + } + new_nodes.push_back(anf_node); + } + kernel_graph->set_execution_order(new_nodes); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h new file mode 100644 index 0000000000..0d2870eb0a --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ + +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace device { +namespace ascend { +/** + * @brief kernel pre build for ascend. + */ +bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr); +/** + * @brief kernel build for ascend. + */ +bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr); +/** + * @brief preporcess of kernel build for ascend, e.g. inserting clear_zero node for maxpool, bn. + * Must DO these changes just before kernel build, and after all of other optimizations on AnfGraph + */ +void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph); +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc new file mode 100644 index 0000000000..e8fc6c7a98 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -0,0 +1,584 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/kernel_select_ascend.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "debug/anf_ir_dump.h" +#include "frontend/operator/ops.h" +#include "ir/func_graph.h" +#include "utils/context/ms_context.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/kernel_query.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace { +const float kWegihtBaseScore = 1; +const float kFeatureMapBaseScore = 10; +constexpr auto kPriChoosenFormat = "pri_format"; +enum MatchCountPriority : int { + MATCH_COUNT_PRIORITY_BEGIN = 0, + MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, + MATCH_FORMAT_COUNT, + MATCH_SPECIAL_FORMAT_COUNT, + MATCH_DEFAULT_FORMAT_COUNT, + MATCH_OUTPUT_DTYPE_COUNT, + MATCH_COUNT_PRIORITY_END +}; + +const int kUnSupportMixedDataTypeIndex = -1; + +bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { + MS_EXCEPTION_IF_NULL(cnode); + // Check input data type + for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { + TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { + return false; + } + } + // Check output data type + for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { + if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) { + return false; + } + } + return true; +} + +string GetPriorityMatchFormat(const CNodePtr &cnode) { + string priority_matched_format = kOpFormat_NC1HWC0; + bool is_init = false; + bool need_change_nd = false; + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { + auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); + if (AnfAlgo::IsFeatureMapInput(cnode, index) && + kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) { + priority_matched_format = !is_init ? pre_output_format : priority_matched_format; + is_init = true; + } + // feature map has two or more special format; + if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) { + priority_matched_format = kOpFormat_DEFAULT; + } + auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size(); + need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1)); + } + if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) { + priority_matched_format = kOpFormat_DEFAULT; + } + AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode); + return priority_matched_format; +} +/** + * Compare two vector by priority, select a better vector, like compare two num, first compare highest num location, + * if equal then next num location + * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3] + */ +bool PriorityChooseItem(const std::vector &cur_item, std::vector *best_item) { + MS_EXCEPTION_IF_NULL(best_item); + if (cur_item.size() != best_item->size()) { + MS_LOG(ERROR) << "Item size should be same!"; + return false; + } + // Update the best_item by comparing the cur_item and best_item + for (size_t i = 0; i < cur_item.size(); i++) { + if (cur_item[i] > best_item->at(i)) { + *best_item = cur_item; + return true; + } else if (cur_item[i] == best_item->at(i)) { + continue; + } else { + return false; + } + } + return false; +} + +void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr &kernel_node, + std::vector *const cur_kernelinfo_match_counts) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts); + if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) { + MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END; + } + auto pri_match_format = GetPriorityMatchFormat(kernel_node); + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + auto input_anf_node = kernel_node->input(input_index + 1); + // we do not take ValueNode into consideration in graph kernel. + if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) { + if (input_anf_node->isa() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) { + continue; + } + } + auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore; + if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) { + (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score; + } + // we match output fix precision first. + auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index); + if (prev_device_type == kTypeUnknown) { + prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); + } + if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) { + (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score; + } + if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { + (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; + } + if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) { + (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score; + } + } + + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + // cal count of same output dtype between abstract and kernel info + if (kernel_build_info.GetOutputDeviceType(output_index) == + AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) { + (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1; + } + } +} + +void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector *support_index) { + MS_EXCEPTION_IF_NULL(support_index); + int index = kUnSupportMixedDataTypeIndex; + switch (data_type) { + case kNumberTypeFloat16: + index = 0; + break; + case kNumberTypeFloat32: + case kNumberTypeFloat: + index = 1; + break; + default: + break; + } + support_index->push_back(index); +} + +void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index, + std::vector *support_datatype_index, std::vector *support_datatype) { + MS_EXCEPTION_IF_NULL(support_datatype); + auto data_type = kernel_build_info.GetInputDeviceType(input_index); + support_datatype->push_back(data_type); + AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); +} + +void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index, + std::vector *support_datatype_index, std::vector *support_datatype) { + MS_EXCEPTION_IF_NULL(support_datatype); + auto data_type = kernel_build_info.GetOutputDeviceType(output_index); + support_datatype->push_back(data_type); + AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); +} + +void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, + std::vector *node_mix_precision_datatype_index, + std::vector *node_mix_precision_datatype) { + AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); + MS_EXCEPTION_IF_NULL(cur_input); + MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); + TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); + AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); + node_mix_precision_datatype->push_back(input_origin_type); +} + +void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, + std::vector *node_mix_precision_datatype_index, + std::vector *node_mix_precision_datatype) { + MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); + auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index); + AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index); + node_mix_precision_datatype->push_back(output_origin_type); +} + +void CheckDataTypeInputs(const std::vector &node_mix_precision_datatype_index, + const std::vector &node_mix_precision_datatype, + const std::map> &kernel_support_datatypes, + std::map> *kernel_match_datatype_idx) { + if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) { + MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size " + << node_mix_precision_datatype.size(); + } + MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); + if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) { + MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size " + << kernel_support_datatypes.size(); + } +} + +bool RaiseDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, + const std::vector &node_mix_precision_datatype, + const std::map> &kernel_support_datatypes, + std::map> *kernel_match_datatype_idx) { + MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); + CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, + kernel_match_datatype_idx); + for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { + if (node_mix_precision_datatype[i] == kTypeUnknown) { + continue; + } + auto iter = kernel_match_datatype_idx->begin(); + while (iter != kernel_match_datatype_idx->end()) { + if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { + auto find_iter = kernel_support_datatypes.find(iter->first); + if (find_iter == kernel_support_datatypes.end()) { + MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; + } + if (i >= find_iter->second.size()) { + MS_LOG(EXCEPTION) << "Node index " << i << "kernel datatype size " << find_iter->second.size(); + } + if (node_mix_precision_datatype[i] != find_iter->second[i]) { + iter = kernel_match_datatype_idx->erase(iter); + } else { + ++iter; + } + continue; + } + auto datatype_indexes = iter->second; + if (i >= datatype_indexes.size()) { + MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size(); + } + if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) { + iter = kernel_match_datatype_idx->erase(iter); + } else { + ++iter; + } + } + } + return !kernel_match_datatype_idx->empty(); +} + +bool CanDataTypeReduce(const std::vector &datatype_indexes, int check_index, + const std::vector &node_mix_precision_datatype_index) { + auto check_index_tmp = IntToSize(check_index); + if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) { + return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex && + datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index]; + } + MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range"; +} + +bool RaiseOrReduceDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, + const std::vector &node_mix_precision_datatype, + const std::map> &kernel_support_datatypes, + std::map> *kernel_match_datatype_idx) { + MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); + CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, + kernel_match_datatype_idx); + for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { + if (node_mix_precision_datatype[i] == kTypeUnknown) { + continue; + } + auto iter = kernel_match_datatype_idx->begin(); + while (iter != kernel_match_datatype_idx->end()) { + if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { + auto find_iter = kernel_support_datatypes.find(iter->first); + if (find_iter == kernel_support_datatypes.end()) { + MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; + } + if (i >= find_iter->second.size()) { + MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size(); + } + if (node_mix_precision_datatype[i] != find_iter->second[i]) { + iter = kernel_match_datatype_idx->erase(iter); + } else { + ++iter; + } + continue; + } + auto datatype_indexes = iter->second; + if (i >= datatype_indexes.size()) { + MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); + } + if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) { + iter = kernel_match_datatype_idx->erase(iter); + } else { + ++iter; + } + } + } + return !kernel_match_datatype_idx->empty(); +} + +void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info, + std::vector *support_indexes, std::vector *node_mix_precision_datatype, + std::vector *support_datatypes, + std::vector *node_mix_precision_datatype_index) { + MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); + bool add_node_datatype_flag = false; + if (node_mix_precision_datatype->empty()) { + add_node_datatype_flag = true; + } + for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { + AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes); + if (add_node_datatype_flag) { + AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype); + } + } + // Check output data type + for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { + AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes); + if (add_node_datatype_flag) { + AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype); + } + } +} + +void PrecisionReduce(const std::vector &node_mix_precision_datatype_index, + const std::vector &node_mix_precision_datatype, + const std::map> &kernel_support_datatype, + std::map> *kernel_match_datatype_idx, bool *precision_reduce) { + MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(precision_reduce); + std::map> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx; + // raise precision + bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, + kernel_support_datatype, kernel_match_datatype_idx); + if (selected_ret) { + *precision_reduce = false; + return; + } + if (context_ptr->enable_reduce_precision()) { + selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, + kernel_support_datatype, &kernel_match_datatype_idx_copy); + } + if (selected_ret) { + *precision_reduce = true; + *kernel_match_datatype_idx = kernel_match_datatype_idx_copy; + } +} + +void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, + const std::shared_ptr &selected_kernel_build_info, + bool precision_reduce) { + MS_EXCEPTION_IF_NULL(selected_kernel_build_info); + MS_EXCEPTION_IF_NULL(cnode); + std::ostringstream buffer; + buffer << cnode->DebugString(); + if (precision_reduce) { + buffer << " Reduce precision, node datatype: \n"; + } else { + buffer << " Raise precision, node datatype: \n"; + } + PrintInputAndOutputInferType(buffer, cnode); + buffer << ", select kernel:" << selected_kernel_build_info->ToString(); + MS_LOG(INFO) << buffer.str(); +} + +std::shared_ptr ChooseMatchedKernelInfo( + const CNodePtr &kernel_node, const std::vector> &kernel_info_list) { + if (kernel_info_list.empty()) { + return nullptr; + } + std::vector most_match_counts = {-1, -1, -1, -1, -1}; + size_t selected_index = 0; + for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { + std::vector cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; + auto kernel_info_ptr = kernel_info_list[info_index]; + MS_EXCEPTION_IF_NULL(kernel_info_ptr); + UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); + // Currently the selection policy is the match format count first, and then is datatype counts. + if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) { + selected_index = SizeToInt(info_index); + } + } + return kernel_info_list[selected_index]; +} + +std::vector> FilteredKernelInfoByDtype( + const CNodePtr &cnode, const std::vector> &kernel_info_list) { + std::vector> result; + for (const auto &kernel_build_info : kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_build_info); + if (!MatchInferOutputDataType(cnode, *kernel_build_info)) { + continue; + } + result.push_back(kernel_build_info); + } + return result; +} + +std::vector> FilterRaisedOrReducePrecisionMatchedKernelInfo( + const CNodePtr &cnode, const std::vector> &kernel_info_list, + bool *precision_reduce) { + std::vector> filtered_kernel_info_list; + std::map> kernel_match_datatype_idx; + std::map> kernel_support_datatype; + std::vector node_mix_precision_datatype_index; + std::vector node_mix_precision_datatype; + for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { + std::vector support_indexes; + std::vector support_datatypes; + MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); + AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype, + &support_datatypes, &node_mix_precision_datatype_index); + kernel_match_datatype_idx[info_index] = support_indexes; + kernel_support_datatype[info_index] = support_datatypes; + } + PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, + &kernel_match_datatype_idx, precision_reduce); + std::transform( + kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list), + [&](const std::pair> &matched_idx) -> std::shared_ptr { + return kernel_info_list[matched_idx.first]; + }); + return filtered_kernel_info_list; +} +} // namespace + +void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); + MS_EXCEPTION_IF_NULL(input_kernel_node); + auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); + MS_EXCEPTION_IF_NULL(input_with_index.first); + auto real_input_node = input_with_index.first; + if (real_input_node->isa()) { + continue; + } + if (real_input_node->isa() && !AnfAlgo::IsParameterWeight(real_input_node->cast())) { + continue; + } + auto builder = std::make_shared(); + if (IsValueNode(input_kernel_node) && + AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { + std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; + builder->SetOutputsFormat(output_format); + std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; + builder->SetOutputsDeviceType(output_type); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); + continue; + } + // we set special device info of a input tensor. + bool is_ref = false; + auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); + if (op_info != nullptr) { + is_ref = op_info->is_ref(); + } + MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); + if (MsContext::GetInstance()->execution_mode() == kPynativeMode && + AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { + continue; + } + if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { + std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; + builder->SetOutputsFormat(output_format); + std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; + builder->SetOutputsDeviceType(output_type); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); + } + } +} + +KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, + const std::vector> &kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + KernelSelectStatus select_status = kNoMatched; + bool precision_reduce = false; + std::shared_ptr selected_kernel_info = nullptr; + // Matched kernel info + // Filter kernel info matched with me infered type + auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list); + if (!filtered_kernel_info_list.empty()) { + selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); + select_status = kStatusAllMatched; + } else { + // selected kernel info using raised precision or reduce precision + filtered_kernel_info_list = + FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); + selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); + if (selected_kernel_info == nullptr) { + return select_status; + } else { + PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); + select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; + } + } + // Set kernel info to the anfnode + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); + // Set format and data type for input tensor. + SetTensorDeviceInfo(*selected_kernel_info, kernel_node); + return select_status; +} + +KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { + std::vector> kernel_info_list; + std::vector> aicpu_kernel_info_list; + MS_EXCEPTION_IF_NULL(kernel_node); + if (AnfAlgo::IsGraphKernel(kernel_node)) { + auto func_graph = GetValueNode(kernel_node->input(kAnfPrimitiveIndex)); + MS_EXCEPTION_IF_NULL(func_graph); + SelectGraphKernelInfo(kernel_node, func_graph); + return kStatusAllMatched; + } + kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type); + auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list); + // If aicore not find valid kernel info reloading aicpu kernel info list to find it + if (select_status == kNoMatched) { + MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() + << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; + kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list); + select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list); + AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); + } + // The kernel info not finded both in the aicpu kernel list & aicore kernel list + if (select_status == kNoMatched) { + std::ostringstream buffer; + PrintInputAndOutputInferType(buffer, kernel_node); + MS_LOG(WARNING) << ">>> Candidates kernel info list:"; + for (size_t index = 0; index < kernel_info_list.size(); ++index) { + MS_LOG(WARNING) << "Kernel [" << index << "] :" << kernel_info_list[index]->ToString(); + } + for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) { + MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index) + << "] :" << aicpu_kernel_info_list[index]->ToString(); + } + if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) { + auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); + // Set format and data type for input tensor. + SetTensorDeviceInfo(*selected_kernel_info, kernel_node); + } else { + MS_LOG(WARNING) << " <<<"; + MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString() + << "] cannot find valid kernel info, not supported the type:" << buffer.str() + << ", please refer to the supported dtypes in candidates kernel info list"; + } + } + return select_status; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h new file mode 100644 index 0000000000..8a93b77cec --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ +#include "ir/anf.h" +#include "backend/kernel_compiler/kernel_build_info.h" +namespace mindspore { +namespace device { +namespace ascend { +enum KernelSelectStatus { + kNoMatched = -1, + kStatusAllMatched = 0, + kStatusReducePrecision = 1, + kStatusRaisePrecision = 2, +}; +KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, + KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); +void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node); +void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph); +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc new file mode 100644 index 0000000000..c76f96728f --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc @@ -0,0 +1,530 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "ir/func_graph.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/kernel_query.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace { +// sort format according the number of occurrences. +bool cmp_format_num(const std::pair &a, const std::pair &b) { + if (a.second != b.second) { + return a.second > b.second; + } else if (a.first == kOpFormat_DEFAULT) { + return a.second + 1 > b.second; + } else if (b.first == kOpFormat_DEFAULT) { + return a.second > b.second + 1; + } + return a.second > b.second; +} + +TypeId GetPrimitivePrecision(const CNodePtr &cnode) { + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + + TypeId except_type = kTypeUnknown; + if (primitive->GetAttr(kAttrFixPrecision) != nullptr) { + auto strExceptDtype = GetValue(primitive->GetAttr(kAttrFixPrecision)); + if (strExceptDtype == "float16") { + except_type = kNumberTypeFloat16; + } else if (strExceptDtype == "float32") { + except_type = kNumberTypeFloat32; + } else { + MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got" << strExceptDtype; + } + } + + return except_type; +} +} // namespace + +void ResetKernelBuildInfo(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { + auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); + MS_EXCEPTION_IF_NULL(input_kernel_node); + auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); + if (!kernel::IsWeightBoundary(kernel_with_index.first)) { + continue; + } + // reset format and dtype. + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + builder.SetOutputsDeviceType(std::vector{kTypeUnknown}); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get()); + } +} + +void UpdateKernelInfo(const std::vector &node_list) { + for (size_t i = 0; i < node_list.size(); ++i) { + // select nodes in subgraph. + auto anf_node = node_list[i]; + MS_EXCEPTION_IF_NULL(anf_node); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto fix_precision_type = GetPrimitivePrecision(cnode); + if (fix_precision_type != kTypeUnknown) { + std::vector> kernel_info_list; + kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL); + + for (size_t index = 0; index < kernel_info_list.size(); ++index) + // only math the first input + if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type && + kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) && + AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) { + auto selected_kernel_info_ptr = kernel_info_list[index]; + ResetKernelBuildInfo(cnode); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get()); + SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode); + break; + } + } + } +} + +bool CanConvertDefaultShapeToNZ(const std::vector &shape) { + for (size_t i = 1; i <= shape.size(); ++i) { + if (i > 2) { + break; + } + if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) { + return false; + } + } + return true; +} + +std::vector DefaultToFracNZAxis(const std::vector &ori_shape, const std::vector &axis) { + std::vector frac_nz_axis = axis; + auto shape_len = ori_shape.size(); + for (size_t i = 0; i < axis.size(); ++i) { + auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len; + if (axis_idx == shape_len - 1) { + frac_nz_axis[i] = axis_idx - 1; + frac_nz_axis.push_back(axis_idx + 2); + } else if (axis_idx == shape_len - 2) { + frac_nz_axis[i] = axis_idx + 1; + frac_nz_axis.push_back(axis_idx + 2); + } else { + frac_nz_axis[i] = axis_idx; + } + } + return frac_nz_axis; +} + +std::vector GetReducedFracNZShape(const std::vector &ori_shape, const std::vector &axis, + bool keep_dims) { + std::vector result; + std::set positive_idx; + for (const auto &a : axis) { + positive_idx.insert(a >= 0 ? a : ori_shape.size() + a); + } + for (size_t i = 0; i < ori_shape.size(); ++i) { + if (positive_idx.count(i) == 0) { + result.push_back(ori_shape[i]); + } else if (keep_dims) { + result.push_back(1); + } + } + return result; +} + +void UpdateFracNZReduceOp(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0); + if (input_format == kOpFormat_FRAC_NZ) { + // Clone primitive to modify it + auto prim = GetCNodePrimitive(cnode); + auto new_prim = std::make_shared(*prim); + auto new_prim_node = NewValueNode(new_prim); + cnode->set_input(0, new_prim_node); + + auto axis_value = new_prim->GetAttr(kAttrAxis); + std::vector default_axis; + if (axis_value->isa()) { + auto value_list = dyn_cast(axis_value); + for (const auto &item : value_list->value()) { + if (item->isa()) { + default_axis.push_back(GetValue(item)); + } + } + } else if (axis_value->isa()) { + auto value_tuple = dyn_cast(axis_value); + for (const auto &item : value_tuple->value()) { + if (item->isa()) { + default_axis.push_back(GetValue(item)); + } + } + } else { + MS_LOG(ERROR) << "Axis attr type is not correct!"; + } + auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); + std::vector frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue>(frac_nz_axis), cnode); + auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0); + if (output_shape.size() == 1) { + AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue(true), cnode); + } + } +} + +void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(default_format); + MS_EXCEPTION_IF_NULL(use_same_format); + std::unordered_map all_input_formats; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; ++i) { + auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; + MS_EXCEPTION_IF_NULL(input_kernel_node); + if (!input_kernel_node->isa()) { + ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)]; + continue; + } + auto para = input_kernel_node->cast(); + if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { + ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)]; + continue; + } + *use_same_format = false; + } + + if (all_input_formats.empty()) { + // all inputs are parameter. + *default_format = kOpFormat_NC1HWC0; + } else { + std::vector> pairs; + for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) { + pairs.push_back(std::make_pair(iter->first, iter->second)); + } + + std::sort(pairs.begin(), pairs.end(), cmp_format_num); + *default_format = pairs.begin()->first; + } + + for (size_t i = 0; i < input_num; ++i) { + auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; + MS_EXCEPTION_IF_NULL(input_kernel_node); + if (!input_kernel_node->isa() || + AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) { + continue; + } + auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0); + if (weight_infer_shape.size() < 2 && *default_format == kOpFormat_FRAC_NZ) { + *default_format = kOpFormat_DEFAULT; + *use_same_format = true; + break; + } + } +} + +void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector &input_list, + const std::string &default_format, bool use_same_format, + std::vector *graph_input_format, std::vector *graph_input_type) { + MS_EXCEPTION_IF_NULL(graph_input_format); + MS_EXCEPTION_IF_NULL(graph_input_type); + // We set same format to all inputs of graph kernel subgraph, and process this latter. + // We set dtype to inputs of graph kernel subgraph same as infer dtypes. + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; ++i) { + auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; + MS_EXCEPTION_IF_NULL(input_kernel_node); + if (use_same_format) { + bool can_convert = true; + if (default_format == kOpFormat_FRAC_NZ) { + auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); + if (!CanConvertDefaultShapeToNZ(infer_shape)) { + MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead"; + can_convert = false; + } + } + if (can_convert) { + graph_input_format->push_back(default_format); + } else { + graph_input_format->push_back(kOpFormat_DEFAULT); + } + graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i)); + continue; + } + + if (!input_kernel_node->isa()) { + // subgraph parameter from output of other nodes. + graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)); + graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i)); + continue; + } + + auto para = input_kernel_node->cast(); + MS_EXCEPTION_IF_NULL(para); + if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { + // parameter already selected. + graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0)); + graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0)); + continue; + } + + // weight parameter. + graph_input_format->push_back(default_format); + graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)); + } + + for (size_t i = 0; i < input_num; ++i) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + std::vector outputs_format = {(*graph_input_format)[i]}; + std::vector outputs_device_type = {(*graph_input_type)[i]}; + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_device_type); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); + } +} + +void UpdateEquivFormat(const std::vector> &output_index, + const std::vector &node_list, const FuncGraphPtr &func_graph, + const FuncGraphManagerPtr &mng) { + MS_EXCEPTION_IF_NULL(mng); + for (size_t i = 0; i < node_list.size(); ++i) { + // select nodes in subgraph. + auto anf_node = node_list[i]; + MS_EXCEPTION_IF_NULL(anf_node); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + cnode->set_kernel_info(std::make_shared()); + SelectKernelInfo(cnode, KernelType::AKG_KERNEL); + // Update ReduceSum + if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) { + continue; + } + UpdateFracNZReduceOp(cnode); + // If ReduceSum's output is 1d and not Default format, convert it to Default format + auto out_format = AnfAlgo::GetOutputFormat(cnode, 0); + if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) { + continue; + } + auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0); + // Insert EquivFormat node, then select kernel info again + std::vector trans_inputs; + trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat)); + trans_inputs.push_back(cnode); + CNodePtr trans_node = func_graph->NewCNode(trans_inputs); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)}, + {AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get()); + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue>({"x"}), trans_node); + + if (trans_node->kernel_info() == nullptr) { + trans_node->set_kernel_info(std::make_shared()); + } + SelectKernelInfo(trans_node, KernelType::AKG_KERNEL); + mng->Replace(cnode, trans_node); + } +} + +void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector &input_list, + const FuncGraphManagerPtr &mng, const std::string &default_format, + std::vector *graph_input_format, std::vector *graph_input_type, + std::vector *need_update) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(mng); + MS_EXCEPTION_IF_NULL(graph_input_format); + MS_EXCEPTION_IF_NULL(graph_input_type); + MS_EXCEPTION_IF_NULL(need_update); + // check graph input format and dtype use inner ops. + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (graph_input_format->size() != input_num || graph_input_type->size() != input_num || + need_update->size() != input_num) { + MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() + << "], [" << graph_input_format->size() << "] != [" << input_num << "]"; + } + auto &node_users = mng->node_users(); + for (size_t i = 0; i < input_num; ++i) { + auto &input = input_list[i]; + auto iter = node_users.find(input); + if (iter == node_users.end() || iter->second.empty()) { + continue; + } + for (auto &node_user : iter->second) { + if (node_user.first->kernel_info() == nullptr || !node_user.first->kernel_info()->has_build_info()) { + // maybe not a real kernel. + continue; + } + auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1)); + if (user_format != (*graph_input_format)[i]) { + MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" + << kernel_node->DebugString() + << "] selected different format. we use defult: " << default_format; + (*graph_input_format)[i] = default_format; + (*need_update)[i] = true; + } + + if (kernel_node->input(i + 1)->isa() || + AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) { + continue; + } + + TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0); + MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of [" + << kernel_node->DebugString() + << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype); + (*graph_input_type)[i] = default_dtype; + (*need_update)[i] = true; + } + } +} + +void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector &node_list, + const std::vector &input_list, const std::vector &need_update, + const std::vector &graph_input_format, + const std::vector &graph_input_type) { + MS_EXCEPTION_IF_NULL(kernel_node); + // update graph input format and dtype use inner ops. + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (graph_input_format.size() != input_num || graph_input_type.size() != input_num || + need_update.size() != input_num) { + MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString() + << "], [" << graph_input_format.size() << "] != [" << input_num << "]"; + } + for (size_t i = 0; i < input_num; ++i) { + if (!need_update[i]) { + continue; + } + + MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString() + << "] to: " << graph_input_format[i]; + MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString() + << "] to: " << TypeIdLabel(graph_input_type[i]); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + std::vector outputs_format = {graph_input_format[i]}; + std::vector outputs_device_type = {graph_input_type[i]}; + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_device_type); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get()); + } + + ResetKernelBuildInfo(kernel_node); + // select nodes in subgraph again. + for (size_t i = 0; i < node_list.size(); ++i) { + auto anf_node = node_list[i]; + MS_EXCEPTION_IF_NULL(anf_node); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t j = 0; j < cnode_input_num; ++j) { + auto input_node = cnode->input(j + 1); + MS_EXCEPTION_IF_NULL(input_node); + if (!IsValueNode(input_node)) { + continue; + } + // reset format and dtype of const tensor. + builder.SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + builder.SetOutputsDeviceType(std::vector{kTypeUnknown}); + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get()); + } + SelectKernelInfo(node_list[i]->cast(), KernelType::AKG_KERNEL); + } +} + +void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector> &output_index, + const std::vector &graph_input_format, + const std::vector &graph_input_type) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector graph_output_format; + std::vector graph_output_type; + for (size_t i = 0; i < output_index.size(); ++i) { + auto const &output = output_index[i]; + graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second)); + TypeId output_type(kTypeUnknown); + if (output.first->isa()) { + output_type = AnfAlgo::GetCNodeOutputPrecision(output.first); + } + if (output_type == kTypeUnknown) { + output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second); + } + graph_output_type.push_back(output_type); + } + + kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; + graph_info_builder.SetInputsFormat(graph_input_format); + graph_info_builder.SetInputsDeviceType(graph_input_type); + graph_info_builder.SetOutputsFormat(graph_output_format); + graph_info_builder.SetOutputsDeviceType(graph_output_type); + graph_info_builder.SetProcessor(kernel::Processor::AICORE); + graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); + graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); + auto graph_selected_info = graph_info_builder.Build(); + MS_EXCEPTION_IF_NULL(graph_selected_info); + AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get()); + SetTensorDeviceInfo(*graph_selected_info, kernel_node); +} + +void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(func_graph); + + // collect input info of funcgraph + std::vector node_list; + std::vector input_list; + std::vector output_list; + kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); + if (input_list.size() != kernel_node->inputs().size() - 1) { + MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode[" + << kernel_node->DebugString() << "], [%" << input_list.size() << "] != [" + << kernel_node->inputs().size() << "]"; + } + + std::string default_format; + bool use_same_format = true; + GetDefaultFormat(kernel_node, &default_format, &use_same_format); + MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format + << "] for ParameterWeight."; + + std::vector graph_input_format; + std::vector graph_input_type; + UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, + &graph_input_type); + + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + } + auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list); + UpdateEquivFormat(output_index, node_list, func_graph, mng); + node_list.clear(); + input_list.clear(); + output_list.clear(); + kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); + + // update graph input format and dtype use inner ops. + std::vector need_update(AnfAlgo::GetInputTensorNum(kernel_node), false); + CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type, + &need_update); + UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type); + + // set fix_precision for kernel when the me prim has fix_precision attr + UpdateKernelInfo(node_list); + + output_index = kernel::GetOutputIndex(node_list, input_list, output_list); + SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.cc new file mode 100644 index 0000000000..4886c00a8e --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/ascend/profiling/plugin_impl.h" +#include +#include "utils/log_adapter.h" +using std::string; + +namespace mindspore { +namespace device { +namespace ascend { +Reporter *PluginImpl::reporter_ = nullptr; + +PluginImpl::PluginImpl(const std::string &module) : module_(module) { MS_LOG(INFO) << "Create PluginImpl."; } + +int PluginImpl::Init(const Reporter *reporter) { + MS_LOG(INFO) << "PluginImpl init"; + MS_EXCEPTION_IF_NULL(reporter); + reporter_ = const_cast(reporter); + return 0; +} + +int PluginImpl::UnInit() { + MS_LOG(INFO) << " PluginImpl Uninit "; + reporter_ = nullptr; + return 0; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h b/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.h similarity index 100% rename from mindspore/ccsrc/device/ascend/profiling/plugin_impl.h rename to mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.h diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.cc new file mode 100644 index 0000000000..1f35cba0f7 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/ascend/profiling/profiling_engine_impl.h" +#include "utils/log_adapter.h" +#include "runtime/device/ascend/profiling/plugin_impl.h" + +namespace mindspore { +namespace device { +namespace ascend { +PluginIntf *ProfilingEngineImpl::CreatePlugin() { + MS_LOG(INFO) << "Create Plugin."; + return new (std::nothrow) PluginImpl("Framework"); +} + +int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { + if (plugin != nullptr) { + delete plugin; + plugin = nullptr; + } + return 0; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.h similarity index 100% rename from mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h rename to mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.h diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc new file mode 100644 index 0000000000..6117fe5ecf --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc @@ -0,0 +1,207 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/profiling/profiling_manager.h" +#include +#include +#include "securec/include/securec.h" +#include "./prof_mgr_core.h" +#include "runtime/device/ascend/profiling/plugin_impl.h" +#include "runtime/device/ascend/profiling/profiling_engine_impl.h" +#include "utils/log_adapter.h" +#include "utils/context/ms_context.h" +#include "common/utils.h" +#include "utils/convert_utils.h" +#include "runtime/base.h" + +namespace mindspore { +namespace device { +namespace ascend { +ProfilingManager &ProfilingManager::GetInstance() { + static ProfilingManager inst; + return inst; +} + +ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { + engine_0_ = std::make_shared(); +} + +uint64_t ProfilingManager::GetJobId() const { + const char *job_id = std::getenv("JOB_ID"); + return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); +} + +bool ProfilingManager::ReportProfilingData(const map &op_taskId_map) const { + if (!IsProfiling()) { + MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; + return false; + } + if (op_taskId_map.empty()) { + MS_LOG(WARNING) << "op_taskId_map is empty."; + return false; + } + auto reporter = PluginImpl::GetPluginReporter(); + if (reporter == nullptr) { + MS_LOG(ERROR) << "No profiling data report!"; + return false; + } + MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); + + Msprof::Engine::ReporterData reporter_data = {}; + for (const auto &iter : op_taskId_map) { + auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; + reporter_data.deviceId = UintToInt(device_id_); + reporter_data.data = (unsigned char *)(const_cast(data.c_str())); + reporter_data.dataLen = data.size(); + auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); + if (ret != 0) { + MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; + return false; + } + ret = reporter->Report(&reporter_data); + if (ret != 0) { + MS_LOG(ERROR) << "reporter data fail, errorno(" << ret << ")"; + return false; + } + } + return true; +} + +static std::vector Split(const std::string &str, const char delim) { + std::vector elems; + + if (str.empty()) { + elems.emplace_back(""); + return elems; + } + + std::stringstream ss(str); + std::string item; + + while (getline(ss, item, delim)) { + elems.push_back(item); + } + auto str_size = str.size(); + if (str_size > 0 && str[str_size - 1] == delim) { + elems.emplace_back(""); + } + + return elems; +} + +bool ProfilingManager::StartupProfiling(uint32_t device_id) { + auto is_profiling = IsProfiling(); + if (!is_profiling) { + MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; + return true; + } + device_id_ = device_id; + // register Framework to profiling + int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); + if (result != 0) { + MS_LOG(ERROR) << "Register profiling Engine failed."; + return false; + } + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + const string prof_options_str = context->profiling_options(); + std::vector opts = Split(prof_options_str, ':'); + if (opts.empty()) { + MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!"; + return true; + } + // current one docker only use one device` + nlohmann::json p_device; + // JOBID + auto job_id = GetJobId(); + p_device["jobID"] = std::to_string(job_id); + // device_id + p_device["deviceID"] = std::to_string(device_id); + // features:'training_trace', 'task_trace' etc + nlohmann::json features; + for (std::vector::size_type i = 0; i < opts.size(); i++) { + nlohmann::json f; + f["name"] = opts[i]; + features[i] = f; + } + p_device["features"] = features; + // only one device, but sProfMgrStartUp API require for device list + nlohmann::json devices; + devices[0] = p_device; + nlohmann::json startCfg; + startCfg["startCfg"] = devices; + + if (!ProfStartUp(NOT_NULL(&startCfg))) { + MS_LOG(ERROR) << "ProfMgrStartUp failed."; + return false; + } + return true; +} + +bool ProfilingManager::ProfStartUp(NotNull startCfg) { + // convert json to string + std::stringstream ss; + ss << *startCfg; + std::string cfg = ss.str(); + MS_LOG(INFO) << "profiling config " << cfg; + auto ret = rtProfilerStart(); + if (ret != RT_ERROR_NONE) { + MS_LOG(INFO) << "Call rtProfilerStart failed, ret:" << ret; + return false; + } + + // call profiling startup API + ProfMgrCfg prof_cfg = {cfg}; + prof_handle_ = ProfMgrStartUp(&prof_cfg); + if (prof_handle_ == nullptr) { + MS_LOG(ERROR) << "Startup profiling failed."; + return false; + } + return true; +} + +bool ProfilingManager::StopProfiling() { + MS_LOG(INFO) << "StopProfiling"; + if (!IsProfiling()) { + MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; + return true; + } + Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); + if (reporter != nullptr) { + MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); + } + + auto rt_ret = rtProfilerStop(); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Call rtProfilerStop failed"; + return false; + } + + if (prof_handle_ != nullptr) { + int result = ProfMgrStop(prof_handle_); + if (result != 0) { + MS_LOG(ERROR) << "ProfMgr stop return fail:" << result << "."; + prof_handle_ = nullptr; + return false; + } + prof_handle_ = nullptr; + } + + return true; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h similarity index 100% rename from mindspore/ccsrc/device/ascend/profiling/profiling_manager.h rename to mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc new file mode 100644 index 0000000000..5b1db6a404 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc @@ -0,0 +1,367 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/profiling/reporter/graph_desc_reporter.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "utils/utils.h" +#include "runtime/device/ascend/profiling/reporter/task_desc_reporter.h" +#include "utils/context/ms_context.h" +#include "runtime/device/ascend/profiling/reporter/point_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +constexpr uint32_t kMaxProfilingNodeNum = 100; +constexpr char kCustomNode[] = "PROFILING_CUSTOM_"; +constexpr char kFpStartNode[] = "PROFILING_FP_START"; +constexpr char kBpEndNode[] = "PROFILING_BP_END"; +constexpr char kIterEndNode[] = "PROFILING_ITER_END"; +// PROFILING_CUSTOM_LOGID_START 3 +constexpr uint64_t kProfilingFpStartLogId = 1; +constexpr uint64_t kProfilingBpEndLogId = 2; +constexpr uint64_t kProfilingIterEndLogId = 255; +std::map> ProfilingUtils::graph_profiling_cnode_; +std::map> ProfilingUtils::graph_kernel_name_; +std::map>> ProfilingUtils::graph_point_; +uint32_t ProfilingUtils::custom_node_index_ = 1; + +ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNull graph_ptr) { + MS_LOG(INFO) << "get env start"; + custom_node_index_ = 1; + auto &cnode_exec_order = graph_ptr->execution_order(); + ProfilingTraceInfo profiling_trace; + profiling_trace.trace_begin = GetTraceBegin(cnode_exec_order); + profiling_trace.trace_bp_end = GetTraceBpEnd(cnode_exec_order); + profiling_trace.trace_netoutput = GetTraceNetoutput(cnode_exec_order); + + for (uint32_t i = 1; i <= kMaxProfilingNodeNum; ++i) { + std::string env_str = std::string(kCustomNode) + std::to_string(i); + const char *node_full_name = std::getenv(env_str.c_str()); + if (node_full_name == nullptr) { + break; + } + MS_LOG(INFO) << "Get profiling node:" << node_full_name; + profiling_trace.trace_custom_node.insert(node_full_name); + } + MS_LOG(INFO) << "get env end"; + GetTraceHccl(cnode_exec_order, NOT_NULL(&profiling_trace)); + + MS_LOG(INFO) << "[profiling]trace_begin:" << profiling_trace.trace_begin + << " trace_bp_end:" << profiling_trace.trace_bp_end + << " trace_netoutput:" << profiling_trace.trace_netoutput; + return profiling_trace; +} + +void ProfilingUtils::GetTraceHccl(const std::vector &cnode_exec_order, + NotNull profiling_trace) { + for (const auto &node : cnode_exec_order) { + if (AnfAlgo::IsCommunicationOp(node)) { + MS_EXCEPTION_IF_NULL(node); + profiling_trace->trace_custom_node.insert(node->fullname_with_scope()); + MS_LOG(INFO) << "[profiling]Get hccl node:" << node->fullname_with_scope(); + } + } +} + +std::string ProfilingUtils::GetTraceBegin(const std::vector &cnode_exec_order) { + const char *trace_begin = std::getenv(kFpStartNode); + if (trace_begin != nullptr) { + return std::string(trace_begin); + } + + std::string fp_start_str; + std::set getnext_outputs; + GetCNodeOutputRealNode(kGetNextOpName, cnode_exec_order, NOT_NULL(&getnext_outputs)); + if (getnext_outputs.empty()) { + auto first_node = cnode_exec_order.front(); + MS_EXCEPTION_IF_NULL(first_node); + fp_start_str = first_node->fullname_with_scope(); + } else { + for (auto &cnode : cnode_exec_order) { + if (getnext_outputs.count(cnode->fullname_with_scope()) != 0) { + fp_start_str = cnode->fullname_with_scope(); + break; + } + } + } + return fp_start_str; +} + +void ProfilingUtils::GetCNodeOutputRealNode(const std::string &node_name, const std::vector &cnode_exec_order, + NotNull *> getnext_outputs) { + for (const auto &cnode : cnode_exec_order) { + MS_EXCEPTION_IF_NULL(cnode); + for (const auto &input : cnode->inputs()) { + auto prev_cnode = AnfAlgo::VisitKernel(input, 0); + if (!prev_cnode.first->isa()) { + continue; + } + if (AnfAlgo::GetCNodeName(prev_cnode.first) == node_name) { + getnext_outputs->insert(cnode->fullname_with_scope()); + MS_LOG(INFO) << "Find GetNext Output CNode:" << cnode->fullname_with_scope(); + } + } + } + if (getnext_outputs->empty()) { + MS_LOG(WARNING) << "GetNext not found"; + } +} + +std::string ProfilingUtils::GetTraceBpEnd(const std::vector &cnode_exec_order) { + const char *trace_bp_end = std::getenv(kBpEndNode); + + if (trace_bp_end != nullptr) { + return std::string(trace_bp_end); + } + std::string bp_end_str; + // Contain hccl kernel + auto iter = cnode_exec_order.rbegin(); + while (iter != cnode_exec_order.rend()) { + if (AnfAlgo::IsCommunicationOp(*iter)) { + // store communication op input nodes' name + std::set ar_input_node_names; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(*iter); ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(*iter, i); + auto input_node = input_node_with_index.first; + ar_input_node_names.insert(input_node->fullname_with_scope()); + } + // start from previous node + ++iter; + // find input names in previous node + while (iter != cnode_exec_order.rend()) { + if (ar_input_node_names.find((*iter)->fullname_with_scope()) != ar_input_node_names.end()) { + bp_end_str = (*iter)->fullname_with_scope(); + break; + } + ++iter; + } + break; + } + ++iter; + } + + if (bp_end_str.empty()) { + bp_end_str = GetGraphLastTbeKernelName(cnode_exec_order); + } + return bp_end_str; +} + +std::string ProfilingUtils::GetGraphLastTbeKernelName(const std::vector &cnode_exec_order) { + std::string last_tbe_kernel_name; + // find last tbe_kernel + for (auto iter = cnode_exec_order.rbegin(); iter != cnode_exec_order.rend(); ++iter) { + if (AnfAlgo::GetKernelType(*iter) == TBE_KERNEL) { + last_tbe_kernel_name = (*iter)->fullname_with_scope(); + break; + } + } + if (last_tbe_kernel_name.empty()) { + MS_LOG(WARNING) << "tbe kernel not found in graph"; + } + return last_tbe_kernel_name; +} + +std::string ProfilingUtils::GetTraceNetoutput(const std::vector &cnode_exec_order) { + const char *trace_netoutput = std::getenv(kIterEndNode); + return trace_netoutput == nullptr ? GetGraphLastTbeKernelName(cnode_exec_order) : std::string(trace_netoutput); +} + +NotNull ProfilingUtils::CreateProfilingCNode(const ProfilingContent &profiling_content, + NotNull graph_ptr) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); + selected_kernel_builder.SetInputsDeviceType({TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); + selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); + selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + abstract::AbstractBasePtr type_none_abstract = std::make_shared(); + auto primitive = std::make_shared(ProfilingUtils::kProfiling); + std::vector inputs; + inputs.emplace_back(NewValueNode(primitive)); + CNodePtr cnode_ptr = graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode_ptr); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), cnode_ptr.get()); + cnode_ptr->set_abstract(type_none_abstract); + // set attr + ValuePtr notify_value = MakeValue(profiling_content.notify); + ValuePtr trace_id_value = MakeValue(profiling_content.profiler_trace_id); + ValuePtr flags_value = MakeValue(profiling_content.flags); + AnfAlgo::SetNodeAttr(ProfilingUtils::kNotify, notify_value, cnode_ptr); + AnfAlgo::SetNodeAttr(ProfilingUtils::kProfilerTraceId, trace_id_value, cnode_ptr); + AnfAlgo::SetNodeAttr(ProfilingUtils::kFlags, flags_value, cnode_ptr); + return NOT_NULL(cnode_ptr); +} + +void ProfilingUtils::SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id) { + std::shared_ptr prof_desc_ptr = std::make_shared(node_name, point_id); + auto iter = graph_point_.find(graph_id); + if (iter == graph_point_.end()) { + std::vector> tmp_vect = {prof_desc_ptr}; + graph_point_.insert({graph_id, tmp_vect}); + } else { + iter->second.emplace_back(prof_desc_ptr); + } +} + +void ProfilingUtils::ProfilingTraceFpStart(const mindspore::AnfNodePtr &anf_node, + const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list) { + if (profiling_trace_info.trace_begin == anf_node->fullname_with_scope()) { + MS_LOG(INFO) << "Profiling Match FpStart:" << profiling_trace_info.trace_begin; + ProfilingTraceJobId(anf_node, graph_ptr, kernel_list); + ProfilingContent fp_profiling_content = {false, kProfilingFpStartLogId, 0}; + auto fp_profiling_node = CreateProfilingCNodeWithStream(anf_node, fp_profiling_content, graph_ptr); + kernel_list->emplace_back(fp_profiling_node); + // insert ProfDesc + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingFpStartLogId); + } +} + +void ProfilingUtils::ProfilingTraceJobId(const AnfNodePtr &anf_node, NotNull graph_ptr, + NotNull *> kernel_list) { + MS_LOG(INFO) << "Profiling Match start"; + auto job_id = ProfilingManager::GetInstance().GetJobId(); + ProfilingContent job_profiling_context = {false, job_id, 0}; + auto job_profiling_node = CreateProfilingCNodeWithStream(anf_node, job_profiling_context, graph_ptr); + kernel_list->emplace_back(job_profiling_node); +} + +CNodePtr ProfilingUtils::CreateProfilingCNodeWithStream(const mindspore::AnfNodePtr &anf_node, + const ProfilingContent &profiling_content, + NotNull graph_ptr) { + CNodePtr profiling_node = CreateProfilingCNode(profiling_content, graph_ptr); + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), profiling_node.get()); + AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(anf_node), profiling_node.get()); + return profiling_node; +} + +void ProfilingUtils::ProfilingCustomOp(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto iter = profiling_trace_info.trace_custom_node.find(anf_node->fullname_with_scope()); + if (iter == profiling_trace_info.trace_custom_node.end()) { + return; + } + MS_LOG(INFO) << "Profiling Match CustomOp:" << anf_node->fullname_with_scope(); + // custom op profiling job start from 3. + auto custom_point_id = 2 * custom_node_index_ + 1; + ProfilingContent front_profiling_content = {false, custom_point_id, 0}; + CNodePtr front_node = CreateProfilingCNodeWithStream(anf_node, front_profiling_content, graph_ptr); + kernel_list->insert(kernel_list->end() - 1, front_node); + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id); + + ProfilingContent back_profiling_content = {false, custom_point_id + 1, 0}; + CNodePtr back_node = CreateProfilingCNodeWithStream(anf_node, back_profiling_content, graph_ptr); + kernel_list->insert(kernel_list->end(), back_node); + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id + 1); + ++custom_node_index_; +} + +void ProfilingUtils::ProfilingTraceBpEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list) { + MS_EXCEPTION_IF_NULL(anf_node); + if (profiling_trace_info.trace_bp_end == anf_node->fullname_with_scope()) { + MS_LOG(INFO) << "Profiling Match BpEnd:" << profiling_trace_info.trace_bp_end; + ProfilingContent bp_end_profiling_content = {false, kProfilingBpEndLogId, 0}; + CNodePtr bp_end_node = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); + kernel_list->emplace_back(bp_end_node); + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingBpEndLogId); + } +} + +void ProfilingUtils::ProfilingTraceEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto full_scope_name = anf_node->fullname_with_scope(); + if (profiling_trace_info.trace_netoutput == full_scope_name) { + MS_LOG(INFO) << "Profiling Match IterEnd:" << profiling_trace_info.trace_netoutput; + ProfilingContent bp_end_profiling_content = {true, kProfilingIterEndLogId, 0}; + CNodePtr bp_kernel_ptr = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); + kernel_list->emplace_back(bp_kernel_ptr); + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingIterEndLogId); + } +} + +void ProfilingUtils::SetGraphKernelName(uint32_t graph_id, const std::vector &kernel_names) { + auto ret = graph_kernel_name_.try_emplace(graph_id, kernel_names); + if (!ret.second) { + MS_LOG(ERROR) << "[profiling]graph " << graph_id << " kernel names already exist"; + } +} + +void ProfilingUtils::SetGraphProfilingCNode(uint32_t graph_id, const std::vector &profiling_cnode_list) { + auto ret = graph_profiling_cnode_.try_emplace(graph_id, profiling_cnode_list); + if (!ret.second) { + MS_LOG(ERROR) << "[profiling]graph " << graph_id << " profiling cnode list already exist"; + } +} + +bool ProfilingUtils::ValidComputeGraph(NotNull graph_ptr) { + for (const auto &node : graph_ptr->execution_order()) { + if (AnfAlgo::GetKernelType(node) == TBE_KERNEL) { + return true; + } + } + return false; +} + +void ProfilingUtils::ReportProfilingData(const std::vector &task_ids, const std::vector &stream_ids, + NotNull graph) { + if (!ValidComputeGraph(graph)) { + MS_LOG(WARNING) << "Not a valid compute graph:" << graph->graph_id(); + return; + } + + auto ret = graph_profiling_cnode_.find(graph->graph_id()); + if (ret == graph_profiling_cnode_.end()) { + MS_LOG(ERROR) << "Graph id not found"; + return; + } + + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second); + task_reporter.set_task_ids(task_ids); + task_reporter.set_stream_ids(stream_ids); + task_reporter.ReportData(); + + GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second); + graph_profiling_cnode_.erase(ret); + graph_reporter.ReportData(); + + // Report profiling point + auto point_iter = graph_point_.find(graph->graph_id()); + if (point_iter == graph_point_.end()) { + MS_LOG(ERROR) << "Graph id not found in graph_point"; + return; + } + PointReporter point_reporter(context->device_id(), "vm.point"); + for (const auto &point : point_iter->second) { + point_reporter.AddReportData(point); + } + point_reporter.ReportData(); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h new file mode 100644 index 0000000000..de8ff2ac39 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h @@ -0,0 +1,142 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include "backend/session/kernel_graph.h" +#include "utils/contract.h" +#include "runtime/device/ascend/profiling/reporter/profiling_desc.h" + +namespace mindspore { +namespace device { +namespace ascend { +struct ProfilingTraceInfo { + // execute order's first execute op(like: Cast or Four2Five ...), except tdt op(GetNext ...) + std::string trace_begin; + // get first net_output(apply kernel) from graph outputs: fp ->net_output<- bp + std::string trace_bp_end; + // execute order's end execute (like: Conv2DBackpropFilter) + std::string trace_netoutput; + + // profiling specific op, such as AllReduce; + std::set trace_custom_node; + + // 1. insert profiling_trace_begin if profiling_trace_bp_end is not empty. + // 2. op lanuch get task info with callback func. + // 3. insert profiling_trace_bp_end. + // 4. insert profiling_trace_net_output if profiling_trace_bp_end is not empty. + + bool IsValid() const { return !(trace_begin.empty() || trace_netoutput.empty()); } +}; + +struct ProfilingContent { + // true -send data from device to host and finish profiling + bool notify; + uint64_t profiler_trace_id; + uint32_t flags; +}; + +class ProfilingUtils { + public: + ProfilingUtils() = default; + ~ProfilingUtils() = default; + + // Insert job_id profiling node and fp_start profiling node. + // Job_id is got from envs, which shound be a number greater than 255 + // Fp_start node should been inserted in the start of a network, and the log_id is hard code to 1. + static void ProfilingTraceFpStart(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list); + + static void ProfilingTraceJobId(const AnfNodePtr &anf_node, NotNull graph_ptr, + NotNull *> kernel_list); + + // Insert net output profiling node, which tells the device to stop profiling. + // The notify in struct ProfilingContent should be 'true', which tells the device to send data to host. + static void ProfilingTraceEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list); + + // Insert bp_end profiling node, which should been inserted after the last backpropagation CNode in the network. + static void ProfilingTraceBpEnd(const mindspore::AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list); + + // Mapping graph id and the kernels' name in the graph + static void SetGraphProfilingCNode(uint32_t graph_id, const std::vector &profiling_cnode_list); + + static void SetGraphKernelName(uint32_t graph_id, const std::vector &kernel_names); + + // Mapping task_id and kernel name for device to generate the time cost of specific kernel. + // Device calculate the time cost of the task which is marked by task id. + // But we need data of (kernel name , time cost) + static void ReportProfilingData(const std::vector &task_ids, const std::vector &stream_ids, + NotNull graph); + + // Get profiling trace point from envs. + // export PROFILING_FP_START='full name of the first cnode to execute' + // export PROFILING_BP_END='full name of the last backpropagation cnode to execute' + // export PROFILING_ITER_END='full name of last cnode in graph to execute' + // And other cnode, like AllReduce, export PROFILING_CUSTOM_1='full name of AllReduce cnode' + // GetNext, export PROFIFLING_CUSTOM_2='full name fo GetNext cnode' + // The variable i in PROFILING_CUSTOM_i should start from 1 without interruption. + static ProfilingTraceInfo GetProfilingTraceFromEnv(NotNull graph_ptr); + + // Insert two profiling trace points, one in front and one behind + static void ProfilingCustomOp(const mindspore::AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list); + + static std::map> graph_kernel_name() { return graph_kernel_name_; } + + inline static constexpr char kProfiling[] = "Profiling"; + inline static constexpr char kNotify[] = "notify"; + inline static constexpr char kProfilerTraceId[] = "profiler_trace_id"; + inline static constexpr char kFlags[] = "flags"; + + private: + static NotNull CreateProfilingCNode(const ProfilingContent &profiling_content, + NotNull graph_ptr); + static CNodePtr CreateProfilingCNodeWithStream(const AnfNodePtr &anf_node, const ProfilingContent &profiling_content, + NotNull graph_ptr); + static std::string GetTraceBegin(const std::vector &cnode_exec_order); + static std::string GetTraceBpEnd(const std::vector &cnode_exec_order); + static std::string GetTraceNetoutput(const std::vector &cnode_exec_order); + static std::string GetGraphLastTbeKernelName(const std::vector &cnode_exec_order); + static void GetTraceHccl(const std::vector &cnode_exec_order, + NotNull profiling_trace); + static void GetCNodeOutputRealNode(const std::string &node_name, const std::vector &cnode_exec_order, + NotNull *> getnext_outputs); + + static bool ValidComputeGraph(NotNull graph_ptr); + static void SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id); + + // graph id --> (kernel name list) + static std::map> graph_profiling_cnode_; + static std::map> graph_kernel_name_; + static std::map>> graph_point_; + static uint32_t custom_node_index_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.cc new file mode 100644 index 0000000000..87e2bbcb06 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "runtime/device/ascend/profiling/reporter/desc_reporter.h" +#include "runtime/device/ascend/profiling/plugin_impl.h" +#include "utils/log_adapter.h" + +constexpr size_t kReportMaxLen = 2048; + +namespace mindspore { +namespace device { +namespace ascend { +DescReporter::~DescReporter() = default; + +void DescReporter::ReportByLine(const std::string &data, const std::string &file_name) const { + auto reporter = PluginImpl::GetPluginReporter(); + MS_EXCEPTION_IF_NULL(reporter); + + auto tot_size = data.size(); + size_t cur_size = 0; + while (cur_size < tot_size) { + size_t remain_size = tot_size - cur_size; + size_t report_size = std::min(remain_size, kReportMaxLen); + + Msprof::Engine::ReporterData report_data{}; + report_data.deviceId = device_id_; + report_data.dataLen = report_size; + report_data.data = (unsigned char *)data.c_str() + cur_size; + auto ret = memcpy_s(report_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, file_name.c_str(), file_name.length()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "Memcpy_s report data tag failed"; + } + auto report_ret = reporter->Report(&report_data); + if (report_ret != 0) { + MS_LOG(EXCEPTION) << "Report data failed"; + } + if (report_size == 0) { + MS_LOG(WARNING) << "Report_size is 0"; + break; + } + cur_size += report_size; + } +} + +void DescReporter::ReportAllLine() { + for (const auto &desc : prof_desc_list_) { + auto data = desc->ToString(); + ReportByLine(data, file_name_); + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h new file mode 100644 index 0000000000..f25c64ce05 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ + +#include +#include +#include +#include +#include "toolchain/prof_reporter.h" +#include "runtime/device/ascend/profiling/reporter/profiling_desc.h" +#include "utils/contract.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace device { +namespace ascend { +class DescReporter { + public: + virtual ~DescReporter() = 0; + DescReporter(int device_id, std::string file_name) : device_id_(device_id), file_name_(std::move(file_name)) {} + + virtual void ReportData() = 0; + + protected: + void ReportByLine(const std::string &data, const std::string &file_name) const; + void ReportAllLine(); + + int device_id_; + std::string file_name_; + std::vector> prof_desc_list_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.cc new file mode 100644 index 0000000000..5c028986d4 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "runtime/device/ascend/profiling/reporter/graph_desc_reporter.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace ascend { +void GraphDescReporter::ReportData() { + for (const auto &node : cnode_list_) { + if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { + MS_LOG(WARNING) << "Skip non tbe kernel"; + continue; + } + std::vector input_data_list; + std::vector output_data_list; + MS_EXCEPTION_IF_NULL(node); + auto op_name = node->fullname_with_scope(); + auto op_type = AnfAlgo::GetCNodeName(node); + auto input_size = AnfAlgo::GetInputTensorNum(node); + for (size_t i = 0; i < input_size; ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); + auto input_node = input_node_with_index.first; + auto input_index = input_node_with_index.second; + DataElement element{}; + element.index_ = i; + element.data_type_ = AnfAlgo::GetOutputDeviceDataType(input_node, input_index); + element.data_format_ = AnfAlgo::GetOutputFormat(input_node, input_index); + element.data_shape_ = AnfAlgo::GetOutputDeviceShape(input_node, input_index); + input_data_list.emplace_back(element); + } + + auto output_size = AnfAlgo::GetOutputTensorNum(node); + for (size_t i = 0; i < output_size; ++i) { + DataElement element{}; + element.index_ = i; + element.data_type_ = AnfAlgo::GetOutputDeviceDataType(node, i); + element.data_format_ = AnfAlgo::GetOutputFormat(node, i); + element.data_shape_ = AnfAlgo::GetOutputDeviceShape(node, i); + output_data_list.emplace_back(element); + } + + auto graph_desc = std::make_shared(op_name, op_type, input_data_list, output_data_list); + prof_desc_list_.emplace_back(graph_desc); + } + ReportAllLine(); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h new file mode 100644 index 0000000000..531f122cde --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ + +#include +#include +#include +#include "runtime/device/ascend/profiling/reporter/desc_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +class GraphDescReporter : public DescReporter { + public: + GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector cnode_list) + : DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {} + ~GraphDescReporter() override = default; + void ReportData() override; + + private: + std::vector cnode_list_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.cc new file mode 100644 index 0000000000..42a1b4c286 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/profiling/reporter/point_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +void PointReporter::ReportData() { ReportAllLine(); } + +void PointReporter::AddReportData(const std::shared_ptr &prof_desc) { + prof_desc_list_.emplace_back(prof_desc); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h new file mode 100644 index 0000000000..c24535f4ec --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ + +#include +#include +#include "runtime/device/ascend/profiling/reporter/desc_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +class PointReporter : public DescReporter { + public: + PointReporter(uint32_t device_id, const std::string &file_name) : DescReporter(device_id, file_name) {} + ~PointReporter() override = default; + void ReportData() override; + void AddReportData(const std::shared_ptr &prof_desc); +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.cc new file mode 100644 index 0000000000..4aec72472c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "runtime/device/ascend/profiling/reporter/profiling_desc.h" + +namespace mindspore { +namespace device { +namespace ascend { +std::string TaskDesc::ToString() { + std::string out = op_name_; + out.append(" ") + .append(std::to_string(block_dim_)) + .append(" ") + .append(std::to_string(task_id_)) + .append(" ") + .append(std::to_string(stream_id_)) + .append("\n"); + return out; +} + +std::string GraphDesc::ToString() { + std::string desc; + desc.append("op_name:").append(op_name_).append(" op_type:").append(op_type_); + int input_id = 0; + for (const auto &element : input_data_list_) { + desc.append(" input_id:") + .append(std::to_string(input_id++)) + .append(" input_format:") + .append(element.data_format_) + .append(" input_data_type:") + .append(std::to_string(element.data_type_)) + .append(" input_shape:") + .append(DataShapeToString(element.data_shape_)); + } + + input_id = 0; + for (const auto &element : output_data_list_) { + desc.append(" output_id:") + .append(std::to_string(input_id++)) + .append(" output_format:") + .append(element.data_format_) + .append(" output_data_type:") + .append(std::to_string(element.data_type_)) + .append(" output_shape:") + .append((DataShapeToString(element.data_shape_))); + } + + desc.append("\n"); + + return desc; +} + +std::string PointDesc::ToString() { + std::string desc; + desc.append(std::to_string(point_id_)).append(" ").append(op_name_).append("\n"); + return desc; +} + +std::string GraphDesc::DataShapeToString(const std::vector &shape) { + std::ostringstream oss; + oss << "\""; + if (!shape.empty()) { + std::copy(shape.begin(), shape.end() - 1, std::ostream_iterator(oss, ",")); + oss << shape.back(); + } + oss << "\""; + return oss.str(); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.h similarity index 100% rename from mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.h rename to mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.h diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.cc new file mode 100644 index 0000000000..26d722aa1a --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "runtime/device/ascend/profiling/reporter/task_desc_reporter.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/ascend_kernel_mod.h" + +namespace mindspore { +namespace device { +namespace ascend { +void TaskDescReporter::ReportData() { + MS_LOG(INFO) << "cnode_list.size()=" << cnode_list_.size() << " task_ids_.size()=" << task_ids_.size(); + if (cnode_list_.size() != task_ids_.size()) { + MS_LOG(ERROR) << "cnode list size not equal task ids size"; + return; + } + + size_t task_index = 0; + for (const auto &node : cnode_list_) { + if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { + MS_LOG(WARNING) << "Skip non tbe kernel"; + ++task_index; + continue; + } + auto kernel_mod = AnfAlgo::GetKernelMod(node); + auto ascend_kernel_mod = dynamic_cast(kernel_mod); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(ascend_kernel_mod); + // Check task_id and stream_id valid + CheckStreamTaskValid(task_index, task_index); + auto desc_ptr = std::make_shared(node->fullname_with_scope(), task_ids_[task_index], + ascend_kernel_mod->block_dim(), stream_ids_[task_index]); + prof_desc_list_.emplace_back(desc_ptr); + ++task_index; + } + ReportAllLine(); +} + +void TaskDescReporter::CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id) { + if (task_id >= task_ids_.size() || stream_id >= stream_ids_.size()) { + MS_LOG(EXCEPTION) << "Index invalid. task_id:" << task_id << ", task_ids.size:" << task_ids_.size() + << ", stream_id:" << stream_id << ", stream_ids.size:" << stream_ids_.size(); + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h new file mode 100644 index 0000000000..51526735a9 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ + +#include +#include +#include +#include "runtime/device/ascend/profiling/reporter/desc_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +class TaskDescReporter : public DescReporter { + public: + TaskDescReporter(int device_id, const std::string &file_name, std::vector cnode_list) + : DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {} + ~TaskDescReporter() override = default; + void ReportData() override; + void set_task_ids(const std::vector &task_ids) { task_ids_ = task_ids; } + void set_stream_ids(const std::vector &stream_ids) { stream_ids_ = stream_ids; } + + private: + std::vector task_ids_; + std::vector stream_ids_; + void CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id); + std::vector cnode_list_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/readme.md b/mindspore/ccsrc/runtime/device/ascend/readme.md similarity index 100% rename from mindspore/ccsrc/device/ascend/readme.md rename to mindspore/ccsrc/runtime/device/ascend/readme.md diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc new file mode 100644 index 0000000000..dba71edfd3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/tasksink/runtime_utils.h" + +#include + +#include "hccl/hcom.h" +#include "utils/log_adapter.h" +#include "utils/utils.h" + +constexpr auto kHcomBroadcast = "hcom_broadcast_"; +constexpr auto kHcomAllGather = "hcom_all_gather_"; +constexpr auto kHcomAllReduce = "hcom_all_reduce_"; +constexpr auto kHcomReduceScatter = "hcom_reduce_scatter_"; +constexpr auto kUnderline = "_"; +namespace mindspore { +namespace device { +namespace ascend { +namespace tasksink { +bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { + hcclResult_t ret = hcom_bind_model(model, stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast(ret); + return false; + } + return true; +} + +bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { + hcclResult_t ret = hcom_unbind_model(model); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast(ret); + return false; + } + return true; +} + +bool RuntimeUtils::HcomDistribute(const std::shared_ptr &task_info, rtStream_t stream) { + MS_LOG(INFO) << "hccl distribute start"; + MS_EXCEPTION_IF_NULL(task_info); + hcclResult_t ret; + static uint32_t task_counter = 0; + auto hccl_group = task_info->group(); + if (task_info->hccl_type() == kBroadcastOpName) { + // call hcom broadcast interface to run op + const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0); + ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast(task_info->count()), + static_cast(task_info->data_type()), static_cast(task_info->root_id()), + hccl_group.c_str(), stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast(ret); + return false; + } + } else if (task_info->hccl_type() == kAllGatherOpName) { + // call hcom allgather interface to run op + const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0); + ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), + static_cast(task_info->count()), static_cast(task_info->data_type()), + hccl_group.c_str(), stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret; + return false; + } + } else if (task_info->hccl_type() == kAllReduceOpName) { + // call hcom allreduce interface to run op + const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0); + ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), + static_cast(task_info->count()), static_cast(task_info->data_type()), + static_cast(task_info->op_type()), hccl_group.c_str(), stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret; + return false; + } + } else if (task_info->hccl_type() == kReduceScatterOpName) { + // call hcom reducescatter interface to run op + const string tag_reduce_scatter = + kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0); + ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), + static_cast(task_info->count()), static_cast(task_info->data_type()), + static_cast(task_info->op_type()), hccl_group.c_str(), stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret; + return false; + } + } + return true; +} +} // namespace tasksink +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.h b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.h similarity index 100% rename from mindspore/ccsrc/device/ascend/tasksink/runtime_utils.h rename to mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.h diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc new file mode 100644 index 0000000000..5aeb932105 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/tasksink/task_generator.h" + +#include +#include "backend/kernel_compiler/task_stream.h" +#include "utils/context/ms_context.h" +#include "common/utils.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace tasksink { +bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *task_info_list, + uint32_t graph_id) { + MS_LOG(INFO) << "GenTasks start..."; + MS_EXCEPTION_IF_NULL(task_info_list); + // Traverse graph applykernel list and run + if (!LaunchAllKernel(anf_node_list, task_info_list, graph_id)) { + MS_LOG(ERROR) << "LaunchAllKernel failed"; + return false; + } + MS_LOG(INFO) << "GenTasks end..."; + return true; +} + +void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { + MS_EXCEPTION_IF_NULL(anf_node_ptr); + MS_EXCEPTION_IF_NULL(kernel_inputs); + // akg process + // set atomic clean addr + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, anf_node_ptr)) { + auto clean_output_indexs = AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicOutputIndexs); + auto graph = anf_node_ptr->func_graph(); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto node_users = manager->node_users(); + if (node_users[anf_node_ptr].empty()) { + MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty."; + } + auto depend_node = node_users[anf_node_ptr].pop().first; + if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) { + MS_LOG(EXCEPTION) << "Checking Depend node failed"; + } + if (node_users[depend_node].empty()) { + MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty."; + } + auto post_node = node_users[depend_node].pop().first; + for (auto index : clean_output_indexs) { + auto device_address = AnfAlgo::GetOutputAddr(post_node, index); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + input->size = device_address->size_; + kernel_inputs->push_back(input); + } + MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size(); + } +} + +void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { + MS_EXCEPTION_IF_NULL(anf_node_ptr); + MS_EXCEPTION_IF_NULL(kernel_inputs); + if (anf_node_ptr->inputs().size() != 2) { + LaunchAddrCleanAkgKernel(anf_node_ptr, kernel_inputs); + return; + } + MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); + auto pre_node = (anf_node_ptr->inputs()[1])->cast(); + // set clean output addr + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { + auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); + for (auto index : clean_output_indexs) { + auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(input->addr); + input->size = device_address->size_; + kernel_inputs->push_back(input); + } + MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); + } + // set clean workspace address + if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { + auto clean_workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); + for (const auto &index : clean_workspace_indexs) { + auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); + kernel::AddressPtr workspace = std::make_shared(); + MS_EXCEPTION_IF_NULL(workspace); + workspace->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(workspace->addr); + workspace->size = device_address->size_; + kernel_inputs->push_back(workspace); + } + } + auto clear_mems = AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicAddMemSize); + if (kernel_inputs->size() != clear_mems.size()) { + MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size,kerenl_inputs size:" + << kernel_inputs->size() << ",clean mem size" << clear_mems.size(); + } +} + +bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id, + std::vector *task_info_list) { + MS_EXCEPTION_IF_NULL(task_info_list); + MS_EXCEPTION_IF_NULL(anf_node_ptr); + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr); + MS_EXCEPTION_IF_NULL(kernel_mod); + kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope()); + if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) { + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) { + auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i); + auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index); + AddressPtr input = std::make_shared
(); + input->addr = device_address->ptr_; + input->size = device_address->size_; + kernel_inputs.push_back(input); + } + + for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node_ptr); ++i) { + auto it = AnfAlgo::GetOutputAddr(anf_node_ptr, i); + AddressPtr output = std::make_shared
(); + output->addr = it->ptr_; + output->size = it->size_; + kernel_outputs.push_back(output); + } + + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetWorkspaceAddr(anf_node_ptr, i); + kernel::AddressPtr workspace = std::make_shared(); + MS_EXCEPTION_IF_NULL(workspace); + workspace->addr = device_address->ptr_; + workspace->size = device_address->size_; + kernel_workspaces.push_back(workspace); + } + } else { + LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs); + } + + auto ascend_kernel_mod = dynamic_cast(kernel_mod); + MS_EXCEPTION_IF_NULL(ascend_kernel_mod); + std::vector task_info_ptrs = + ascend_kernel_mod->GenTask(kernel_inputs, kernel_workspaces, kernel_outputs, stream_id); + task_info_list->insert(task_info_list->end(), task_info_ptrs.begin(), task_info_ptrs.end()); + return true; +} + +bool TaskGenerator::LaunchAllKernel(const std::vector &anf_node_list, + std::vector *task_info_list, uint32_t graph_id) { + uint32_t current_op_index = 0; + std::vector profiling_cnode_list; + std::vector kernel_name_list; + for (const auto &anf_node_ptr : anf_node_list) { + size_t old_size = task_info_list->size(); + uint32_t stream_id = AnfAlgo::GetStreamId(anf_node_ptr); + MS_EXCEPTION_IF_NULL(anf_node_ptr); + MS_LOG(INFO) << "Task gen launch begin, current_op_idx:" << current_op_index + << " name:" << anf_node_ptr->fullname_with_scope() << ", stream id:" << stream_id; + if (!LaunchKernel(anf_node_ptr, stream_id, task_info_list)) { + MS_LOG(ERROR) << "LaunchKernel failed."; + return false; + } + for (size_t i = old_size; i < task_info_list->size(); ++i) { + profiling_cnode_list.emplace_back(anf_node_ptr); + kernel_name_list.emplace_back(anf_node_ptr->fullname_with_scope()); + } + current_op_index++; + } + + ProfilingUtils::SetGraphKernelName(graph_id, kernel_name_list); + if (ProfilingManager::GetInstance().IsProfiling()) { + ProfilingUtils::SetGraphProfilingCNode(graph_id, profiling_cnode_list); + } + return true; +} +} // namespace tasksink +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h new file mode 100644 index 0000000000..134dec48b6 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ + +#include +#include +#include +#include +#include +#include +#include "runtime/device/kernel_runtime.h" +#include "ir/anf.h" +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "framework/ge_runtime/task_info.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace tasksink { +using mindspore::kernel::Address; +using mindspore::kernel::AddressPtr; +using AddressPtrList = std::vector; +using ge::model_runner::TaskInfo; +using TaskInfoPtr = std::shared_ptr; +class TaskGenerator { + public: + TaskGenerator() = default; + ~TaskGenerator() = default; + TaskGenerator(const TaskGenerator &in) = delete; + TaskGenerator &operator=(const TaskGenerator &in) = delete; + + static bool GenTasks(const std::vector &anf_node_list, std::vector *task_info_list, + uint32_t graph_id); + + private: + static void LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs); + static void LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs); + static bool LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id, std::vector *task_info_list); + static bool LaunchAllKernel(const std::vector &anf_node_list, std::vector *task_info_list, + uint32_t graph_id); +}; +} // namespace tasksink +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ diff --git a/mindspore/ccsrc/runtime/device/convert_tensor_utils.cc b/mindspore/ccsrc/runtime/device/convert_tensor_utils.cc new file mode 100644 index 0000000000..cfd9b0fbdf --- /dev/null +++ b/mindspore/ccsrc/runtime/device/convert_tensor_utils.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/convert_tensor_utils.h" +#include +namespace mindspore { +namespace device { +void HalfToFloat(void *dst, const void *src, size_t elem_num) { + auto half_data = static_cast(src); + auto float_data = static_cast(dst); + for (size_t i = 0; i < elem_num; ++i) { + float tmp = Eigen::half_impl::half_to_float(half_data[i]); + float_data[i] = tmp; + } +} + +void FloatToHalf(void *dst, const void *src, size_t elem_num) { + auto float_data = static_cast(src); + auto half_data = static_cast(dst); + for (size_t i = 0; i < elem_num; ++i) { + half_data[i] = Eigen::half(float_data[i]); + } +} + +void DoubleToFloat(void *dst, const void *src, size_t elem_num) { + auto double_data = static_cast(src); + auto float_data = static_cast(dst); + for (size_t i = 0; i < elem_num; ++i) { + float_data[i] = static_cast(double_data[i]); + } +} + +void FloatToDouble(void *dst, const void *src, size_t elem_num) { + auto float_data = static_cast(src); + auto double_data = static_cast(dst); + for (size_t i = 0; i < elem_num; ++i) { + double_data[i] = static_cast(float_data[i]); + } +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/convert_tensor_utils.h b/mindspore/ccsrc/runtime/device/convert_tensor_utils.h similarity index 100% rename from mindspore/ccsrc/device/convert_tensor_utils.h rename to mindspore/ccsrc/runtime/device/convert_tensor_utils.h diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc new file mode 100644 index 0000000000..92269233bd --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/cpu/cpu_device_address.h" +#include +#include "runtime/device/convert_tensor_utils.h" + +namespace mindspore { +namespace device { +namespace cpu { +bool CPUDeviceAddress::SyncDeviceToHost(const std::vector & /*shape*/, size_t size, TypeId type, + void *host_ptr) const { + if (ptr_ == nullptr) { + MS_LOG(ERROR) << "The pointer ptr_ is null!"; + return false; + } + + if (host_ptr == ptr_) { + MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; + return true; + } + + if (type == type_id_) { + auto ret_code = memcpy_s(host_ptr, size, ptr_, size_); + if (ret_code != EOK) { + MS_LOG(ERROR) << "Failed to copy tensor!"; + return false; + } + } else if (type == kNumberTypeFloat16) { + FloatToHalf(host_ptr, ptr_, size / 2); + } else if (type == kNumberTypeFloat64) { + FloatToDouble(host_ptr, ptr_, size / sizeof(double)); + } else { + MS_LOG(ERROR) << "Types not match. Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type) + << "!"; + return false; + } + return true; +} + +bool CPUDeviceAddress::SyncHostToDevice(const std::vector & /*shape*/, size_t size, TypeId type, + const void *host_ptr) const { + if (type == kNumberTypeFloat16) { + HalfToFloat(ptr_, host_ptr, size / 2); + } else if (type == kNumberTypeFloat64) { + DoubleToFloat(ptr_, host_ptr, size / sizeof(double)); + } + return true; +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h new file mode 100644 index 0000000000..63cf171fa2 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ +#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ + +#include +#include +#include "runtime/device/device_address.h" + +namespace mindspore { +namespace device { +namespace cpu { +class CPUDeviceAddress : public DeviceAddress { + public: + CPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} + + CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) + : DeviceAddress(ptr, size, format, type_id) {} + + ~CPUDeviceAddress() override = default; + + bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; + bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; + DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; } +}; +} // namespace cpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc new file mode 100644 index 0000000000..d2e41a1fbd --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -0,0 +1,324 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/cpu/cpu_kernel_runtime.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "utils/context/ms_context.h" +#include "utils/config_manager.h" +#include "utils/profile.h" +#include "common/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/session_basic.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace device { +namespace cpu { +const size_t INIT_NODE_REF = 1; +namespace { +TypeId GetCPUSupportOutputTypeId(const TypeId type_id) { + TypeId support_type_id = type_id; + if (type_id == kNumberTypeUInt32) { + support_type_id = kNumberTypeInt32; + } + if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 || + type_id == kNumberTypeFloat64) { + support_type_id = kNumberTypeFloat32; + } + if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) { + MS_LOG(EXCEPTION) << "Check output type failed."; + } + return support_type_id; +} +} // namespace + +void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) { + AssignValueNodeAddress(kernel_graph); + AssignInputNodeAddress(kernel_graph); + AssignKernelOutputAddress(kernel_graph); + resource_manager_.MemPlan(kernel_graph); + resource_manager_.MemMalloc(kernel_graph); +} + +void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + size_t type_size = sizeof(float); + for (auto &item_node : kernel_graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(item_node); + if (item_node->isa()) { + auto value_node = item_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + if (!node_value->isa()) { + continue; + } + auto tensor = node_value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + std::vector data_shape = tensor->shape(); + size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); + DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32); + MS_EXCEPTION_IF_NULL(address); + if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { + address->ptr_ = tensor->data_c(); + } else { + address->ptr_ = resource_manager_.MemMalloc(tensor_size); + if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "Value node sync host to device failed!"; + } + } + address->ref_count_ = INIT_NODE_REF; + AnfAlgo::SetOutputAddr(address, 0, item_node.get()); + } + } +} + +void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + size_t type_size = sizeof(float); + for (auto &item : kernel_graph->inputs()) { + MS_EXCEPTION_IF_NULL(item); + if (item->isa()) { + auto output_num = AnfAlgo::GetOutputTensorNum(item); + for (size_t index = 0; index < output_num; index++) { + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); + std::vector fmt_shape = AnfAlgo::GetOutputDeviceShape(item, index); + size_t tensor_size = + fmt_shape.empty() ? type_size + : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies()); + auto format = AnfAlgo::GetOutputFormat(item, index); + auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id); + AnfAlgo::SetOutputAddr(address, index, item.get()); + } + } + } +} + +void CPUKernelRuntime::AssignKernelOutputAddress(const session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto kernels = kernel_graph->execution_order(); + for (auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + auto output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i, + kernel.get()); + } + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_sizes.size(); ++i) { + AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32), + i, kernel.get()); + } + } +} + +DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) { + return std::make_shared(device_ptr, device_size, format, type_id); +} + +tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index, + std::set *bound_addresses, + std::vector *need_sync_outputs) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(bound_addresses); + MS_EXCEPTION_IF_NULL(need_sync_outputs); + size_t output_size = AnfAlgo::GetOutputTensorNum(node); + if (index >= output_size) { + MS_LOG(EXCEPTION) << "Invalid input index " << index; + } + auto address = AnfAlgo::GetMutableOutputAddr(node, index); + MS_EXCEPTION_IF_NULL(address); + auto shape = AnfAlgo::GetOutputInferShape(node, index); + std::vector temp_shape; + (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); + TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); + type_id = GetCPUSupportOutputTypeId(type_id); + tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + MS_EXCEPTION_IF_NULL(tensor); + if (bound_addresses->find(address) != bound_addresses->end()) { + tensor->set_device_address(address); + need_sync_outputs->emplace_back(tensor); + } else { + address->ptr_ = tensor->data_c(); + address->ref_count_ = INIT_NODE_REF; + (void)bound_addresses->insert(address); + } + tensor->set_dirty(false); + return tensor; +} + +BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, + const std::unordered_map &input_map, + std::set *bound_addresses, + std::vector *need_sync_outputs) { + auto &input_node = kernel_with_index.first; + auto index = kernel_with_index.second; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa()) { + auto node = input_node->cast(); + MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimMakeTuple->name()) { + VectorRef ret; + for (size_t i = 1; i < node->inputs().size(); i++) { + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0); + auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs); + ret.push_back(out); + } + return ret; + } + return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs); + } else if (input_node->isa() || input_node->isa()) { + auto iter = input_map.find(input_node.get()); + if (iter != input_map.end()) { + return iter->second; + } + } + return BaseRef(); +} + +void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, + const std::vector &inputs, VectorRef *outputs, + std::vector *need_sync_outputs) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(outputs); + // bind input ptr + auto &input_nodes = kernel_graph->inputs(); + if (input_nodes.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; + } + std::unordered_map input_map; + size_t input_idx = 0; + for (auto &item : input_nodes) { + MS_EXCEPTION_IF_NULL(item); + input_map[item.get()] = inputs[input_idx]; + if (item->isa()) { + auto address = AnfAlgo::GetMutableOutputAddr(item, 0); + auto tensor = inputs[input_idx]; + auto tensor_address = tensor->device_address(); + MS_EXCEPTION_IF_NULL(address); + MS_EXCEPTION_IF_NULL(tensor); + if (tensor_address != nullptr && tensor_address != address) { + (void)tensor->data_sync(); + } + std::vector data_shape = tensor->shape(); + size_t tensor_size = + std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies()); + if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { + address->ptr_ = tensor->data_c(); + } else { + address->ptr_ = resource_manager_.MemMalloc(tensor_size); + if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; + } + tensor->set_dirty(true); + } + address->ref_count_ = INIT_NODE_REF; + tensor->set_device_address(address); + } + input_idx++; + } + // new output and bind ptr + std::set bound_addresses; + auto output_nodes = kernel_graph->outputs(); + for (const auto &item : output_nodes) { + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); + auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs); + outputs->push_back(std::move(out)); + } +} + +void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector *input_list) { + MS_EXCEPTION_IF_NULL(address); + MS_EXCEPTION_IF_NULL(input_list); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + if (address->ptr_ == nullptr) { + address->ptr_ = resource_manager_.MemMalloc(address->size_); + } + MS_EXCEPTION_IF_NULL(address->ptr_); + input->addr = address->ptr_; + input->size = address->size_; + input_list->push_back(input); +} + +void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + resource_manager_.IncreaseSummaryRefCount(summary_outputs); +} + +void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + resource_manager_.DecreaseSummaryRefCount(summary_outputs); +} + +bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + resource_manager_.IncreaseAddressRefCount(kernel_graph); + + auto kernels = kernel_graph->execution_order(); + for (const auto &kernel : kernels) { +#ifdef ENABLE_PROFILE + double start_time = GetTime(); +#endif + std::vector kernel_inputs; + std::vector kernel_workspaces; + std::vector kernel_outputs; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i).get(); + MS_EXCEPTION_IF_NULL(device_address); + AddRuntimeAddress(device_address, &kernel_inputs); + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t i = 0; i < output_num; ++i) { + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i).get(); + MS_EXCEPTION_IF_NULL(device_address); + AddRuntimeAddress(device_address, &kernel_outputs); + } + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(device_address); + AddRuntimeAddress(device_address, &kernel_workspaces); + } + auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, 0); + resource_manager_.DecreaseAddressRefCount(kernel); + if (!ret) { + MS_LOG(EXCEPTION) << "Launch kernel failed."; + } +#ifdef ENABLE_PROFILE + double cost_time = GetTime() - start_time; + MS_LOG(INFO) << "cpu kernel: " << kernel->fullname_with_scope() << " costs " << cost_time * 1e6 << " us"; +#endif + } + return true; +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h new file mode 100644 index 0000000000..a29f840bfd --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h @@ -0,0 +1,70 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include "runtime/device/kernel_runtime.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/session_basic.h" +#include "runtime/device/cpu/cpu_resource_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/any.h" +namespace mindspore { +namespace device { +namespace cpu { +class CPUKernelRuntime : public KernelRuntime { + public: + CPUKernelRuntime() = default; + ~CPUKernelRuntime() override = default; + + bool Init() override { return true; } + bool Run(session::KernelGraph *graph) override; + void AssignKernelAddress(session::KernelGraph *kernel_graph); + void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector &inputs, + VectorRef *outputs, std::vector *need_sync_outputs); + void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + + protected: + bool SyncStream() override { return true; }; + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) override; + + private: + tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index, + std::set *bound_addresses, + std::vector *need_sync_outputs); + + BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, + const std::unordered_map &input_map, + std::set *bound_addresses, + std::vector *need_sync_outputs); + void AssignValueNodeAddress(session::KernelGraph *kernel_graph); + void AssignInputNodeAddress(const session::KernelGraph *kernel_graph); + void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph); + void AddRuntimeAddress(DeviceAddress *address, std::vector *input_list); + CPUResourceManager resource_manager_; +}; +} // namespace cpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc new file mode 100644 index 0000000000..c607260ab3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/cpu/cpu_resource_manager.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace cpu { +CPUResourceManager::~CPUResourceManager() { MemFree(); } + +void CPUResourceManager::MemFree() { + if (mem_ptr_ != nullptr) { + free(mem_ptr_); + mem_ptr_ = nullptr; + mem_size_ = 0; + } + + for (auto &&iter : dynamic_mem_) { + free(iter.first); + } + dynamic_mem_.clear(); +} + +void CPUResourceManager::MemPlan(const session::KernelGraph *graph) { + mem_plan_.MemPlan(graph); + size_t graph_mem_size = mem_plan_.GetGraphMemSize(graph); + if (graph_mem_size > mem_size_) { + MemFree(); + mem_ptr_ = reinterpret_cast(malloc(graph_mem_size)); + if (mem_ptr_ != nullptr) { + mem_size_ = graph_mem_size; + dynamic_malloc_ = false; + } else { + MS_LOG(INFO) << "Switch to dynamic malloc"; + dynamic_malloc_ = true; + } + } +} + +void CPUResourceManager::MemMalloc(const session::KernelGraph *graph) { + if (dynamic_malloc_) { + return; + } + mem_plan_.MemAssign(graph, mem_ptr_); +} + +void *CPUResourceManager::MemMalloc(size_t mem_size) { + void *ptr = malloc(mem_size); + if (ptr != nullptr) { + memset_s(ptr, mem_size, 0, mem_size); + dynamic_mem_[ptr] = mem_size; + return ptr; + } else { + MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size; + } +} + +void CPUResourceManager::MemFree(void *ptr) { + auto iter = dynamic_mem_.find(ptr); + if (iter != dynamic_mem_.end()) { + (void)dynamic_mem_.erase(iter); + free(ptr); + } +} + +void CPUResourceManager::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + if (!dynamic_malloc_) { + return; + } + + if (summary_outputs.empty()) { + return; + } + + for (auto &output_item : summary_outputs) { + auto node = output_item.second.first; + size_t index = IntToSize(output_item.second.second); + auto address = AnfAlgo::GetMutableOutputAddr(node, index); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_++; + } +} + +void CPUResourceManager::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + if (!dynamic_malloc_) { + return; + } + + if (summary_outputs.empty()) { + return; + } + + for (auto &output_item : summary_outputs) { + auto node = output_item.second.first; + size_t index = IntToSize(output_item.second.second); + auto address = AnfAlgo::GetMutableOutputAddr(node, index); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_--; + if (address->ref_count_ == 0 && address->ptr_ != nullptr) { + MemFree(address->ptr_); + address->ptr_ = nullptr; + } + } +} + +void CPUResourceManager::IncreaseAddressRefCount(const session::KernelGraph *graph) { + if (!dynamic_malloc_) { + return; + } + MS_EXCEPTION_IF_NULL(graph); + auto kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_++; + } + + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_++; + } + } +} + +void CPUResourceManager::DecreaseAddressRefCount(const AnfNodePtr &kernel) { + if (!dynamic_malloc_) { + return; + } + MS_EXCEPTION_IF_NULL(kernel); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_--; + if (address->ref_count_ == 0 && address->ptr_ != nullptr) { + MemFree(address->ptr_); + address->ptr_ = nullptr; + } + } + + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_--; + if (address->ref_count_ == 0 && address->ptr_ != nullptr) { + MemFree(address->ptr_); + address->ptr_ = nullptr; + } + } +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h new file mode 100644 index 0000000000..d251760dd2 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ +#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ + +#include +#include +#include "backend/session/kernel_graph.h" +#include "backend/session/session_basic.h" +#include "runtime/device/device_address.h" +#include "runtime/device/cpu/cpu_simple_mem_plan.h" +namespace mindspore { +namespace device { +namespace cpu { +class CPUResourceManager { + public: + CPUResourceManager() = default; + ~CPUResourceManager(); + + void MemPlan(const session::KernelGraph *graph); + void MemMalloc(const session::KernelGraph *graph); + void IncreaseAddressRefCount(const session::KernelGraph *graph); + void DecreaseAddressRefCount(const AnfNodePtr &kernel); + void *MemMalloc(size_t mem_size); + void MemFree(void *ptr); + void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + + private: + void MemFree(); + CPUSimpleMemPlan mem_plan_; + + size_t mem_size_{0}; + uint8_t *mem_ptr_{nullptr}; + bool dynamic_malloc_{false}; + std::unordered_map dynamic_mem_; +}; +} // namespace cpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc new file mode 100644 index 0000000000..7838e66984 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/cpu/cpu_simple_mem_plan.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace cpu { +void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + size_t total_mem_size = 0; + auto kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); + MS_EXCEPTION_IF_NULL(kernel_with_index.first); + if (kernel_with_index.first->isa()) { + continue; + } + auto address = AnfAlgo::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, true); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + total_mem_size += address->size_; + } + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t i = 0; i < output_num; ++i) { + auto address = AnfAlgo::GetOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + total_mem_size += address->size_; + } + } + + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + total_mem_size += address->size_; + } + } + } + graph_mem_size_[graph] = total_mem_size; +} + +size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) const { + auto iter = graph_mem_size_.find(graph); + if (iter != graph_mem_size_.end()) { + return iter->second; + } + return 0; +} + +void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(base_ptr); + uint8_t *mem_ptr = base_ptr; + auto kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); + MS_EXCEPTION_IF_NULL(kernel_with_index.first); + if (kernel_with_index.first->isa()) { + continue; + } + auto address = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, true); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + address->ptr_ = mem_ptr; + mem_ptr = mem_ptr + address->size_; + } + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t i = 0; i < output_num; ++i) { + auto address = AnfAlgo::GetMutableOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + address->ptr_ = mem_ptr; + mem_ptr = mem_ptr + address->size_; + } + } + + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + address->ptr_ = mem_ptr; + mem_ptr = mem_ptr + address->size_; + } + } + } +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h new file mode 100644 index 0000000000..123e29fbe5 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ +#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ + +#include +#include +#include "backend/session/kernel_graph.h" +#include "runtime/device/device_address.h" + +namespace mindspore { +namespace device { +namespace cpu { +class CPUSimpleMemPlan { + public: + CPUSimpleMemPlan() = default; + ~CPUSimpleMemPlan() = default; + + void MemPlan(const session::KernelGraph *graph); + void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr); + size_t GetGraphMemSize(const session::KernelGraph *graph) const; + + private: + std::unordered_map graph_mem_size_; +}; +} // namespace cpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc new file mode 100644 index 0000000000..9528e61ee9 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc @@ -0,0 +1,170 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/cpu/kernel_select_cpu.h" + +#include +#include +#include + +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace device { +namespace cpu { +using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; +using mindspore::kernel::KernelBuildInfo; +namespace { +bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) { + auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() || input_node->isa()) { + return true; + } + return false; +} + +void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector &input_not_cnode_indexes, + const CNodePtr kernel_node) { + for (auto &input_index : input_not_cnode_indexes) { + auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; + MS_EXCEPTION_IF_NULL(input_node); + std::vector output_types; + output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first); + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + builder->SetOutputsFormat({kOpFormat_DEFAULT}); + builder->SetOutputsDeviceType(output_types); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get()); + } +} + +void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector *input_formats, + std::vector *input_types, std::vector *input_no_cnode_indexes) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { + TypeId dtype = kTypeUnknown; + if (IsInputNotCNode(kernel_node, input_index)) { + input_no_cnode_indexes->emplace_back(input_index); + dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); + } else { + dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); + } + input_formats->emplace_back(kOpFormat_DEFAULT); + input_types->emplace_back(dtype); + } +} + +void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr, + std::vector *output_formats, std::vector *output_types) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { + output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second); + auto dtype = kernel_attr.GetOutputAttr(output_index).first; + output_types->emplace_back(dtype); + } +} + +bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector &input_formats, + const std::vector &input_types, + const std::vector &input_not_cnode_indexes) { + if (kernel_attr.GetInputSize() != input_types.size()) { + MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); + return false; + } + auto input_num = input_types.size(); + for (size_t i = 0; i < input_num; ++i) { + bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), + [i](size_t index) { return index == i; }); + bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); + if (have_cnode_input && is_not_cnode_idx) { + continue; + } + if (kernel_attr.GetInputAttr(i).first != input_types[i]) { + MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first + << ", actual input dtype:" << input_types[i]; + return false; + } + if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { + MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second + << ", actual input format:" << input_formats[i]; + return false; + } + } + return true; +} + +void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { + MS_EXCEPTION_IF_NULL(kernel_attr); + TypeId input_dtype = kernel_attr->GetInputAttr(0).first; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 1; i < input_num; ++i) { + kernel_attr->AddInputAttr(input_dtype); + } + + TypeId output_dtype = kernel_attr->GetOutputAttr(0).first; + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t i = 1; i < output_num; ++i) { + kernel_attr->AddOutputAttr(output_dtype); + } +} +} // namespace + +void SetKernelInfo(const CNodePtr &kernel_node) { + std::vector input_formats; + std::vector input_types; + std::vector input_not_cnode_indexes; + std::vector output_formats; + std::vector output_types; + + MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node); + GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); + + auto kernel_attrs = + kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); + + for (size_t index = 0; index < kernel_attrs.size(); ++index) { + auto kernel_attr = kernel_attrs[index]; + if (kernel_attr.GetAllSame()) { + ExpandKernelAttr(kernel_node, &kernel_attr); + } + if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (kernel_attr.GetOutputSize() != output_num) { + MS_LOG(DEBUG) << "Output num is not equal!"; + continue; + } + MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; + GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); + UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); + for (auto &input_index : input_not_cnode_indexes) { + input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; + } + break; + } + } + + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + builder->SetInputsFormat(input_formats); + builder->SetInputsDeviceType(input_types); + builder->SetOutputsFormat(output_formats); + builder->SetOutputsDeviceType(output_types); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/kernel_select_cpu.h b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h similarity index 100% rename from mindspore/ccsrc/device/cpu/kernel_select_cpu.h rename to mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h diff --git a/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc b/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc new file mode 100644 index 0000000000..c124523d59 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc @@ -0,0 +1,277 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/cpu/mpi/mpi_adapter.h" +#ifdef ENABLE_MPI +#include +#include +#include "pybind11/pybind11.h" +#endif // ENABLE_MPI +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +namespace cpu { +std::shared_ptr MPIAdapter::instance_ = nullptr; +std::shared_ptr MPIAdapter::Instance() { + if (instance_ == nullptr) { + MS_LOG(DEBUG) << "Create new mpi adapter instance."; + instance_.reset(new (std::nothrow) MPIAdapter()); + } + return instance_; +} + +#ifdef ENABLE_MPI + +#define RAISE_EXCEPTION(message) \ + { \ + std::ostringstream oss; \ + oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message; \ + pybind11::pybind11_fail(oss.str()); \ + } + +#define RAISE_EXCEPTION_WITH_PARAM(message, param) \ + { \ + std::ostringstream oss; \ + oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \ + pybind11::pybind11_fail(oss.str()); \ + } + +namespace { +MPI_Op GetMpiOp(const std::string &op_type) { + if (op_type == "sum") { + return MPI_SUM; + } else if (op_type == "max") { + return MPI_MAX; + } else if (op_type == "min") { + return MPI_MIN; + } else if (op_type == "prod") { + return MPI_PROD; + } + + RAISE_EXCEPTION_WITH_PARAM("unsupport op_type: ", op_type); + return MPI_SUM; +} + +int GetScatterIndex(int rankid, const std::vector &ranks_group) { + int scatter_index = -1; + for (size_t i = 0; i < ranks_group.size(); ++i) { + if (ranks_group[i] == rankid) { + scatter_index = static_cast(i); + break; + } + } + if (scatter_index == -1) { + RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rankid); + } + return scatter_index; +} +} // namespace + +MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); } + +MPIAdapter::~MPIAdapter() { + int finalized; + MPI_Finalized(&finalized); + if (finalized != 0) { + return; + } + + for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) { + MPI_Group_free(&iter->second); + } + ranks_group_.clear(); + if (comm_group_world_ != MPI_GROUP_NULL) { + MPI_Group_free(&comm_group_world_); + comm_group_world_ = MPI_GROUP_NULL; + } + MPI_Finalize(); +} + +void MPIAdapter::Init() { + static bool init = false; + if (init) { + return; + } + + int init_flag = 0; + if (MPI_Initialized(&init_flag) != MPI_SUCCESS) { + RAISE_EXCEPTION("Check mpi initialized fail!"); + } + if (init_flag == 0) { + auto ret = MPI_Init(nullptr, nullptr); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION("Failed to init mpi!"); + } + } + + MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_); + if (comm_group_world_ == MPI_GROUP_NULL) { + RAISE_EXCEPTION("comm_group_world_ init fail!"); + } + auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION("Failed to init mpi rank id!"); + } + + ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_) + } + init = true; +} + +MPI_Group MPIAdapter::AddGroup(const std::vector &ranks) { + if (ranks.size() > static_cast(rank_size_) || ranks.empty()) { + RAISE_EXCEPTION_WITH_PARAM("input rank size:", ranks.size()); + } + + if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) { + RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rank_id_); + } + std::lock_guard lock(group_mutex_); + auto iter = ranks_group_.find(ranks); + if (iter != ranks_group_.end()) { + return iter->second; + } + const auto ranks_size = ranks.size(); + std::vector ranks_input(ranks_size, 0); + for (size_t i = 0; i < ranks_size; ++i) { + ranks_input[i] = ranks[i]; + } + + MPI_Group group = MPI_GROUP_NULL; + MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group); + if (group == MPI_GROUP_NULL) { + RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_) + } + + ranks_group_[ranks] = group; + return group; +} + +bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector &ranks_group, size_t data_num, + const std::string &op_type) { + if (ranks_group.empty()) { + RAISE_EXCEPTION("input rank group is empty!"); + return false; + } + + auto group = AddGroup(ranks_group); + if (group == MPI_GROUP_NULL) { + RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_) + } + MPI_Comm comm; + MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); + if (comm == MPI_COMM_NULL) { + RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_); + } + std::vector receive_count(ranks_group.size(), 0); + for (size_t i = 0; i < ranks_group.size(); ++i) { + receive_count[i] = data_num; + } + + auto op = GetMpiOp(op_type); + auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm); + bool result = true; + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret); + result = false; + } + + ret = MPI_Comm_free(&comm); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret); + } + return result; +} + +bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t input_data_num, + size_t output_size, const std::string &op_type, float *output) { + int scatter_index = GetScatterIndex(rank_id_, ranks_group); + auto group = AddGroup(ranks_group); + if (group == MPI_GROUP_NULL) { + RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_); + } + MPI_Comm comm; + MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); + if (comm == MPI_COMM_NULL) { + RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_); + } + + MPI_Win window; + auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! ret = ", ret); + } + MPI_Win_fence(0, window); + for (size_t i = 0; i < ranks_group.size(); ++i) { + int remote_rank = ranks_group[i]; + if (rank_id_ == remote_rank) { + continue; + } + auto op = GetMpiOp(op_type); + ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num, + input_data_num, MPI_FLOAT, op, window); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret); + } + } + MPI_Win_fence(0, window); + if (output != nullptr) { + auto data_size = input_data_num * sizeof(float); + if (output_size < data_size) { + std::ostringstream exception_msg; + exception_msg << "output buffer size " << output_size << " < input size " << data_size; + RAISE_EXCEPTION(exception_msg.str()) + } + auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size); + if (copy_ret != 0) { + RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret); + } + } + MPI_Win_free(&window); + MPI_Comm_free(&comm); + return true; +} + +bool MPIAdapter::AllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num) { + if (ranks_group.empty()) { + RAISE_EXCEPTION("input rank group is empty!"); + return false; + } + auto group = AddGroup(ranks_group); + if (group == MPI_GROUP_NULL) { + RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_); + } + MPI_Comm comm; + MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); + if (comm == MPI_COMM_NULL) { + RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_); + } + auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi allgater fail!ret = ", ret); + } + ret = MPI_Comm_free(&comm); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret); + } + return true; +} +#endif // ENABLE_MPI +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h b/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.h similarity index 100% rename from mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h rename to mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.h diff --git a/mindspore/ccsrc/device/cpu/readme.md b/mindspore/ccsrc/runtime/device/cpu/readme.md similarity index 100% rename from mindspore/ccsrc/device/cpu/readme.md rename to mindspore/ccsrc/runtime/device/cpu/readme.md diff --git a/mindspore/ccsrc/runtime/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h new file mode 100644 index 0000000000..32f5fcced9 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/device_address.h @@ -0,0 +1,97 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_DEVICE_TENSOR_H +#define MINDSPORE_DEVICE_TENSOR_H + +#include +#include +#include +#include "ir/dtype.h" +#include "ir/device_sync.h" + +namespace mindspore { +namespace device { +namespace cpu { +class CPUSimpleMemPlan; +class CPUResourceManager; +class CPUKernelRuntime; +} // namespace cpu +namespace ascend { +class AscendKernelRuntime; +class AscendMemoryManager; +class DataDumper; +namespace tasksink { +class TaskGenerator; +} // namespace tasksink +} // namespace ascend +namespace gpu { +class GPUKernelRuntime; +class GPUMemoryManager; +} // namespace gpu +} // namespace device +} // namespace mindspore + +namespace mindspore { +namespace device { +enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice }; +enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU }; + +class DeviceAddress : public mindspore::DeviceSync { + public: + explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {} + explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) + : ptr_(ptr), size_(size), format_(format), type_id_(type_id) {} + virtual ~DeviceAddress() { ptr_ = nullptr; } + const void *GetPtr() const { return ptr_; } + size_t GetSize() const { return size_; } + std::string format() const { return format_; } + TypeId type_id() const { return type_id_; } + void set_host_shape(const std::vector &shape) { host_shape_ = shape; } + virtual void UpdateCommunicationAddress() {} + virtual void set_status(DeviceAddressStatus status) {} + virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } + virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } + + protected: + const void *ptr() const { return ptr_; } + size_t size() const { return size_; } + void set_ptr(void *ptr) { ptr_ = ptr; } + void *ptr_{nullptr}; + size_t size_{0}; + size_t ref_count_{0}; + string format_{"DefaultFormat"}; + TypeId type_id_{kNumberTypeFloat16}; + bool from_mem_pool_{false}; + std::vector host_shape_{}; + friend class KernelRuntime; + friend class MemoryManager; + friend class mindspore::device::ascend::tasksink::TaskGenerator; + friend class mindspore::device::cpu::CPUSimpleMemPlan; + friend class mindspore::device::cpu::CPUResourceManager; + friend class mindspore::device::cpu::CPUKernelRuntime; + friend class mindspore::device::gpu::GPUKernelRuntime; + friend class mindspore::device::gpu::GPUMemoryManager; + friend class mindspore::device::ascend::AscendKernelRuntime; + friend class mindspore::device::ascend::AscendMemoryManager; + friend class mindspore::device::ascend::DataDumper; +}; + +using DeviceAddressPtr = std::shared_ptr; +using DeviceAddressPtrList = std::vector; +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_DEVICE_TENSOR_H diff --git a/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc new file mode 100644 index 0000000000..547c2fbe64 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/blocking_queue.h" +#include +#include "runtime/device/gpu/gpu_common.h" +#include "common/utils.h" + +namespace mindspore { +namespace device { +GpuQueue::GpuQueue(void *addr, const std::vector &shape, const size_t &capacity) + : buffer_(addr), head_(0), tail_(0), shape_(shape), len_(0), capacity_(capacity), stream_(0), node_info_(nullptr) { + CHECK_CUDA_RET_WITH_ERROR(cudaStreamCreate(&stream_), "Cuda Create Stream Failed"); + node_info_ = std::make_unique(capacity); + for (auto item : shape) { + len_ += item; + } +} + +GpuQueue::~GpuQueue() { buffer_ = nullptr; } + +BlockQueueStatus_T GpuQueue::Push(const std::vector &data) { + int offset = 0; + for (size_t i = 0; i < data.size(); i++) { + auto item = data[i]; + if (item.data_ptr_ == nullptr || item.data_len_ != shape_[i]) { + MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_; + return ERROR_INPUT; + } + + void *addr = reinterpret_cast(buffer_) + tail_ * len_ + offset; + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(addr, item.data_ptr_, item.data_len_, cudaMemcpyHostToDevice, stream_), + "Cuda Memcpy Error"); + + offset += item.data_len_; + } + + node_info_[tail_].event_.reset(new cudaEvent_t()); + CHECK_CUDA_RET_WITH_ERROR(cudaEventCreate(&(*(node_info_[tail_].event_))), "Cuda Create Event Failed"); + node_info_[tail_].data_ = data; + tail_ = (tail_ + 1) % (capacity_); + return SUCCESS; +} + +BlockQueueStatus_T GpuQueue::Front(void **addr, size_t *len) const { + CHECK_CUDA_RET_WITH_ERROR(cudaEventSynchronize(*(node_info_[head_].event_)), "Cuda Event Syn Failed"); + CHECK_CUDA_RET_WITH_ERROR(cudaEventDestroy(*(node_info_[head_].event_)), "Cuda Destroy Event Failed"); + *addr = (unsigned char *)buffer_ + head_ * len_; + *len = len_; + + for (auto item : node_info_[head_].data_) { + host_release_(item.data_ptr_); + } + return SUCCESS; +} + +BlockQueueStatus_T GpuQueue::Pop() { + head_ = (head_ + 1) % (capacity_); + return SUCCESS; +} + +bool GpuQueue::Destroy() { + if (stream_ != nullptr) { + auto ret = cudaStreamDestroy(stream_); + if (ret == cudaSuccess) { + return true; + } else { + return false; + } + } else { + return true; + } +} + +BlockQueueStatus_T BlockingQueue::Create(void *addr, const std::vector &shape, const size_t &capacity) { + if (addr == nullptr) { + MS_LOG(ERROR) << "addr is nullptr"; + return INTERNAL_ERROR; + } + queue_ = std::make_shared(addr, shape, capacity); + return SUCCESS; +} + +void BlockingQueue::RegisterRelease(const std::function &func) { queue_->RegisterRelease(func); } + +BlockQueueStatus_T BlockingQueue::Push(const std::vector &data, unsigned int timeout_in_sec) { + std::unique_lock locker(mutex_); + if (queue_->IsFull()) { + if (not_full_cond_.wait_for(locker, std::chrono::seconds(timeout_in_sec)) == std::cv_status::timeout) { + return TIMEOUT; + } + } + auto ret = queue_->Push(data); + if (ret) { + return ret; + } + not_empty_cond_.notify_one(); + return SUCCESS; +} + +BlockQueueStatus_T BlockingQueue::Front(void **addr, size_t *len) { + std::unique_lock locker(mutex_); + bool timeout = not_empty_cond_.wait_for(locker, std::chrono::seconds(30), [this] { return !queue_->IsEmpty(); }); + if (!timeout) { + return TIMEOUT; + } + + return queue_->Front(addr, len); +} + +BlockQueueStatus_T BlockingQueue::Pop() { + std::unique_lock locker(mutex_); + not_empty_cond_.wait(locker, [this] { return !queue_->IsEmpty(); }); + auto ret = queue_->Pop(); + if (ret) { + return ret; + } + not_full_cond_.notify_one(); + return SUCCESS; +} + +bool BlockingQueue::Destroy() { + if (queue_ != nullptr) { + return queue_->Destroy(); + } else { + return true; + } +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/blocking_queue.h b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h similarity index 100% rename from mindspore/ccsrc/device/gpu/blocking_queue.h rename to mindspore/ccsrc/runtime/device/gpu/blocking_queue.h diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_common.h b/mindspore/ccsrc/runtime/device/gpu/cuda_common.h new file mode 100644 index 0000000000..2689fdbaca --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_common.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ + +#include +#include "runtime/device/gpu/gpu_device_manager.h" + +namespace mindspore { +namespace device { +namespace gpu { +class CudaCommon { + public: + inline int threads_num() const { return threads_per_block_; } + inline int major_sm() const { return major_sm_; } + inline int blocks_num(const int total_threads) const { + return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_); + } + + static CudaCommon &GetInstance() { + static CudaCommon instance; + return instance; + } + + private: + CudaCommon() { + uint32_t device_id = GPUDeviceManager::GetInstance().cur_device_id(); + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, device_id); + threads_per_block_ = prop.maxThreadsPerBlock; + max_blocks_ = prop.multiProcessorCount; + major_sm_ = prop.major; + } + ~CudaCommon() = default; + CudaCommon(const CudaCommon &) = delete; + CudaCommon &operator=(const CudaCommon &) = delete; + + int max_blocks_; + int threads_per_block_; + int major_sm_; +}; +#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) +#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() +#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() +#define MINIUM_SM 6 +#define RECOMMEND_SM 7 +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc new file mode 100644 index 0000000000..1f5e5e3c22 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc @@ -0,0 +1,231 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/cuda_driver.h" +#include +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace device { +namespace gpu { +size_t CudaDriver::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { + size_t retreat_count = 0; + auto ret = cudaMalloc(reinterpret_cast(addr), size); + // If free memory is not enough, then retry with mem_malloc_retry_rate_. + while (ret == cudaErrorMemoryAllocation) { + size = FloatToSize(size * mem_malloc_retry_rate_); + size = (size / mem_malloc_align_size_) * mem_malloc_align_size_; + ret = cudaMalloc(reinterpret_cast(addr), size); + retreat_count++; + if (retreat_count > mem_malloc_retry_conut_max_) { + break; + } + } + + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaMalloc failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return 0; + } + return size; +} + +bool CudaDriver::FreeDeviceMem(const DeviceMemPtr &addr) { + auto ret = cudaFree(addr); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaFree failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +size_t CudaDriver::AllocHostPinnedMem(size_t size, void **addr) { + if (size == 0) { + MS_LOG(EXCEPTION) << "The memory allocate size is 0"; + } + auto ret = cudaHostAlloc(addr, size, cudaHostAllocDefault); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaHostAlloc failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return 0; + } + return size; +} + +void CudaDriver::FreeHostPinnedMem(void *addr) { + if (addr) { + auto ret = cudaFreeHost(addr); + if (ret != cudaSuccess) { + MS_LOG(EXCEPTION) << "cudaFreeHost failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + } + } +} + +bool CudaDriver::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) { + auto ret = cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaMemcpy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) { + auto ret = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaMemcpy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, DeviceStream stream) { + auto ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, (cudaStream_t)stream); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaMemcpyAsync failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, + DeviceStream stream) { + auto ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, (cudaStream_t)stream); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaMemcpyAsync failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +size_t CudaDriver::total_mem_size() { + size_t free; + size_t total; + auto ret = cudaMemGetInfo(&free, &total); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaMemGetInfo failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return 0; + } + return total; +} + +size_t CudaDriver::free_mem_size() { + size_t free; + size_t total; + auto ret = cudaMemGetInfo(&free, &total); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaMemGetInfo failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return 0; + } + + return free; +} + +bool CudaDriver::CreateStream(DeviceStream *stream) { + auto ret = cudaStreamCreateWithFlags(reinterpret_cast(stream), cudaStreamNonBlocking); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaStreamCreate failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::DestroyStream(const DeviceStream &stream) { + auto ret = cudaStreamDestroy((cudaStream_t)stream); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaStreamDestroy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::SyncStream(const DeviceStream &stream) { + auto ret = cudaStreamSynchronize((cudaStream_t)stream); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaStreamSynchronize failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::CreateEvent(DeviceEvent *event, unsigned int flag) { + auto ret = cudaEventCreateWithFlags(reinterpret_cast(event), flag); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaEventCreateWithFlags failed, ret[" << static_cast(ret) << "], " + << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::DestroyEvent(const DeviceEvent &event) { + auto ret = cudaEventDestroy((cudaEvent_t)event); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaEventDestroy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::RecordEvent(DeviceEvent event, DeviceStream stream) { + auto ret = cudaEventRecord((cudaEvent_t)event, (cudaStream_t)stream); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaEventRecord failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::SyncEvent(const DeviceEvent &event) { + auto ret = cudaEventSynchronize((cudaEvent_t)event); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaEventSynchronize failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} + +bool CudaDriver::QueryEvent(const DeviceEvent &event) { + auto ret = cudaEventQuery((cudaEvent_t)event); + if (ret == cudaSuccess) { + return true; + } else if (ret == cudaErrorNotReady) { + return false; + } else { + MS_LOG(ERROR) << "cudaEventQuery failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } +} + +int CudaDriver::device_count() { + int dev_count; + auto ret = cudaGetDeviceCount(&dev_count); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaGetDeviceCount failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + } + return dev_count; +} + +bool CudaDriver::set_current_device(int index) { + auto ret = cudaSetDevice(index); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaSetDevice failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/cuda_driver.h b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h similarity index 100% rename from mindspore/ccsrc/device/gpu/cuda_driver.h rename to mindspore/ccsrc/runtime/device/gpu/cuda_driver.h diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h new file mode 100644 index 0000000000..5373f21d70 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ + +#include +#include "pybind11/pybind11.h" + +namespace mindspore { +namespace device { +namespace gpu { +constexpr int MAX_HOSTNAME_LEN = 1024; +constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group"; +#define CHECK_RET(expression, result, message) \ + { \ + auto ret = (expression); \ + if (ret != result) { \ + std::ostringstream oss; \ + oss << "Error in file " << __FILE__ << " | Error on line " << __LINE__ << " | GPU collective Error: " << message \ + << " | Error Number " << ret; \ + pybind11::pybind11_fail(oss.str()); \ + } \ + } +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_COLLECTIVE_COMMON_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.cc new file mode 100644 index 0000000000..80793042fd --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/distribution/collective_fake_init.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +namespace gpu { +void CollectiveFakeInitializer::InitCollective() { MS_LOG(EXCEPTION) << "build without enable gpu!"; } + +void CollectiveFakeInitializer::FinalizeCollective() { MS_LOG(EXCEPTION) << "build without enable gpu!"; } +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_fake_init.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.h similarity index 100% rename from mindspore/ccsrc/device/gpu/distribution/collective_fake_init.h rename to mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.h diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc new file mode 100644 index 0000000000..cba789b38d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/distribution/collective_init.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +namespace gpu { +CollectiveInitializer &CollectiveInitializer::instance() { + static CollectiveInitializer instance = {}; + return instance; +} + +bool CollectiveInitializer::collective_inited() const { return collective_inited_; } + +const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } + +void CollectiveInitializer::InitCollective() { + void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); + if (handle == nullptr) { + MS_LOG(EXCEPTION) + << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " + "installed.\n2.nccl is not " + "installed or found.\n3.mpi is not installed or found"; + } + auto mpi_init_funcptr = reinterpret_cast(dlsym(handle, "InitMPI")); + MS_EXCEPTION_IF_NULL(mpi_init_funcptr); + (*mpi_init_funcptr)(); + + CollectiveInitializer::instance().collective_inited_ = true; + CollectiveInitializer::instance().collective_handle_ = handle; +} + +void CollectiveInitializer::FinalizeCollective() { + if (CollectiveInitializer::instance().collective_handle_ != nullptr) { + if (dlclose(CollectiveInitializer::instance().collective_handle_) != 0) { + MS_LOG(EXCEPTION) << "Closing libgpu_collective.so handle failed."; + } + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h new file mode 100644 index 0000000000..464492d50f --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ + +#include +#include +#include + +namespace mindspore { +namespace device { +namespace gpu { +using InitMPI = void (*)(); +using InitNCCLComm = void (*)(); +using GetLocalRankId = int (*)(); +using CreateCommGroupFunc = bool (*)(const std::string &, const std::vector &); +using GetRankIDByGroupFunc = int (*)(const std::string &); +using GetGroupSizeFunc = int (*)(const std::string &); +using DestroyGroupFunc = bool (*)(const std::string &); + +class CollectiveInitializer { + public: + CollectiveInitializer(CollectiveInitializer const &) = delete; + CollectiveInitializer &operator=(const CollectiveInitializer &) = delete; + static CollectiveInitializer &instance(); + bool collective_inited() const; + const void *collective_handle() const; + static void InitCollective(); + static void FinalizeCollective(); + + private: + CollectiveInitializer() : collective_inited_(false) {} + ~CollectiveInitializer() = default; + + bool collective_inited_; + void *collective_handle_{nullptr}; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_COLLECTIVE_INIT_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc new file mode 100644 index 0000000000..f427905afa --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "runtime/device/gpu/distribution/mpi_wrapper.h" +#include "runtime/device/gpu/distribution/nccl_wrapper.h" + +#ifndef EXPORT_WRAPPER +#define EXPORT_WRAPPER __attribute__((visibility("default"))) +#endif + +using MPIWrapper = mindspore::device::gpu::MPIWrapper; +using NCCLWrapper = mindspore::device::gpu::NCCLWrapper; + +extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); } + +extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); } + +extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); } + +extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector &ranks) { + return MPIWrapper::instance().CreateCommGroup(group_name, ranks); +} + +extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name) { + return MPIWrapper::instance().GetRankIDByGroup(group_name); +} + +extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name) { + return MPIWrapper::instance().GetGroupSize(group_name); +} + +extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name) { + return MPIWrapper::instance().DestroyGroup(group_name); +} + +extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, ncclRedOp_t reduce_type, + cudaStream_t stream) { + return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream); +} + +extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, cudaStream_t stream) { + return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream); +} + +extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, ncclRedOp_t reduce_type, + cudaStream_t stream) { + return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream); +} diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc new file mode 100644 index 0000000000..08ec320cab --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/distribution/mpi_wrapper.h" +#include +#include +#include +#include "runtime/device/gpu/distribution/nccl_wrapper.h" + +namespace mindspore { +namespace device { +namespace gpu { +MPIWrapper::MPIWrapper() : rank_id_(0), rank_size_(0), local_rank_id_(0) { Init(); } + +MPIWrapper::~MPIWrapper() { + int finalized; + MPI_Finalized(&finalized); + if (finalized == 0) { + MPI_Finalize(); + } +} + +MPIWrapper &MPIWrapper::instance() { + static MPIWrapper instance; + return instance; +} + +int MPIWrapper::local_rank_id() const { return local_rank_id_; } + +bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vector &group_ranks) { + std::vector ranks(group_ranks.begin(), group_ranks.end()); + MPI_Group mpi_group; + CHECK_RET(MPI_Group_incl(world_group_, ranks.size(), ranks.data(), &mpi_group), MPI_SUCCESS, + "Failed to produce a new group from MPI_COMM_WORLD group for " + group_name); + SetGroupNameToMPIGroup(group_name, mpi_group); + + MPI_Comm mpi_group_comm; + CHECK_RET(MPI_Comm_create(MPI_COMM_WORLD, mpi_group, &mpi_group_comm), MPI_SUCCESS, + "Failed to create MPI communicator."); + if (mpi_group_comm == MPI_COMM_NULL) { + return false; + } + + ncclUniqueId group_unique_id; + if (rank_id_ == ranks[0]) { + group_unique_id = NCCLWrapper::instance().nccl_unique_id(); + } + MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, ranks[0], mpi_group_comm); + + int group_rank[1]; + int global_rank[1] = {rank_id_}; + CHECK_RET(MPI_Group_translate_ranks(world_group_, 1, global_rank, mpi_group, group_rank), MPI_SUCCESS, + "Failed to translate global rank to group rank."); + if (group_rank[0] == MPI_UNDEFINED) { + return false; + } + + ncclComm_t nccl_group_comm; + NCCLWrapper::instance().InitNCCLComm(&nccl_group_comm, ranks.size(), group_unique_id, group_rank[0]); + NCCLWrapper::instance().SetGroupNameToNCCLComm(group_name, nccl_group_comm); + return true; +} + +int MPIWrapper::GetRankIDByGroup(const std::string &group_name) { + CHECK_RET(group_name_to_mpi_group_map_.count(group_name), 1, "Failed to get MPI group by group name " + group_name); + MPI_Group mpi_group = group_name_to_mpi_group_map_[group_name]; + int rank; + CHECK_RET(MPI_Group_rank(mpi_group, &rank), MPI_SUCCESS, "Failed to get rank id by group name." + group_name); + return rank; +} + +int MPIWrapper::GetGroupSize(const std::string &group_name) { + CHECK_RET(group_name_to_mpi_group_map_.count(group_name), 1, "Failed to get MPI group by group name" + group_name); + MPI_Group mpi_group = group_name_to_mpi_group_map_[group_name]; + int size; + CHECK_RET(MPI_Group_size(mpi_group, &size), MPI_SUCCESS, "Failed to get group size by group name." + group_name); + return size; +} + +bool MPIWrapper::DestroyGroup(const std::string &group_name) { + auto group_iter = group_name_to_mpi_group_map_.find(group_name); + if (group_iter == group_name_to_mpi_group_map_.end()) { + return false; + } + group_name_to_mpi_group_map_.erase(group_name); + MPI_Group mpi_group = group_iter->second; + CHECK_RET(MPI_Group_free(&mpi_group), MPI_SUCCESS, "Failed to free MPI group for " + group_name); + NCCLWrapper::instance().DestroyGroup(group_name); + return true; +} + +void MPIWrapper::Init() { + int initialized; + CHECK_RET(MPI_Initialized(&initialized), MPI_SUCCESS, "Failed to check mpi initialization status."); + if (initialized == 0) { + MPI_Init(nullptr, nullptr); + } + + CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id."); + CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size."); + NCCLWrapper::instance().set_rank(rank_id_, rank_size_); + AssignLocalRankID(); + + CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD"); + SetGroupNameToMPIGroup(NCCL_WORLD_GROUP, world_group_); + + ncclUniqueId unique_id; + if (rank_id_ == 0) { + unique_id = NCCLWrapper::instance().nccl_unique_id(); + } + CHECK_RET(MPI_Bcast(reinterpret_cast(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD), + MPI_SUCCESS, "Failed to broadcast nccl unique id."); + NCCLWrapper::instance().set_nccl_unique_id(unique_id); + return; +} + +void MPIWrapper::AssignLocalRankID() { + char host_name[MAX_HOSTNAME_LEN] = {0}; + CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), 0, "Getting host name failed."); + size_t host_hash = std::hash()(host_name); + + const int kRankSize = rank_size_; + size_t all_host_hashs[kRankSize]; + all_host_hashs[rank_id_] = host_hash; + CHECK_RET(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs, sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD), + MPI_SUCCESS, "MPI_Allgather host hashs failed."); + for (int global_rank = 0; global_rank < kRankSize; global_rank++) { + if (global_rank == rank_id_) { + break; + } + if (all_host_hashs[global_rank] == all_host_hashs[rank_id_]) { + local_rank_id_++; + } + } + return; +} + +void MPIWrapper::SetGroupNameToMPIGroup(const std::string &group_name, const MPI_Group mpi_group) { + group_name_to_mpi_group_map_[group_name] = mpi_group; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h new file mode 100644 index 0000000000..19d06b32d3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "runtime/device/gpu/distribution/collective_common.h" + +namespace mindspore { +namespace device { +namespace gpu { +class MPIWrapper { + public: + MPIWrapper(MPIWrapper const &) = delete; + MPIWrapper &operator=(const MPIWrapper &) = delete; + static MPIWrapper &instance(); + int local_rank_id() const; + bool CreateCommGroup(const std::string &group_name, const std::vector &ranks); + int GetRankIDByGroup(const std::string &group_name); + int GetGroupSize(const std::string &group_name); + bool DestroyGroup(const std::string &group_name); + + private: + MPIWrapper(); + ~MPIWrapper(); + void Init(); + void AssignLocalRankID(); + void SetGroupNameToMPIGroup(const std::string &group_name, const MPI_Group mpi_group); + + int rank_id_; + int rank_size_; + int local_rank_id_; + MPI_Group world_group_; + std::map group_name_to_mpi_group_map_; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc new file mode 100644 index 0000000000..bcba538309 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/distribution/nccl_wrapper.h" + +namespace mindspore { +namespace device { +namespace gpu { +NCCLWrapper &NCCLWrapper::instance() { + static NCCLWrapper instance; + return instance; +} + +ncclUniqueId NCCLWrapper::nccl_unique_id() const { + ncclUniqueId unique_id; + CHECK_RET(ncclGetUniqueId(&unique_id), ncclSuccess, "Failed to create nccl unique id."); + return unique_id; +} + +void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; } + +void NCCLWrapper::set_rank(int rank_id, int rank_size) { + rank_id_ = rank_id; + rank_size_ = rank_size; +} + +void NCCLWrapper::InitNCCLComm() { + CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess, + "Failed to init nccl communicator."); + group_to_comm_map_[NCCL_WORLD_GROUP] = comm_; +} + +void NCCLWrapper::InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank) { + CHECK_RET(ncclCommInitRank(comm, rank_size, unique_id, rank), ncclSuccess, "Failed to init nccl communicator."); +} + +ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) { + CHECK_RET(group_to_comm_map_.count(group_name), 1, + "Failed to find NCCL communicator for AllReduce by the group name " + group_name); + ncclComm_t group_comm = group_to_comm_map_[group_name]; + return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream); +} + +ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + cudaStream_t stream, const std::string &group_name) { + CHECK_RET(group_to_comm_map_.count(group_name), 1, + "Failed to find NCCL communicator for AllGather by the group name " + group_name); + ncclComm_t group_comm = group_to_comm_map_[group_name]; + return ncclAllGather(input_addr, output_addr, count, data_type, group_comm, stream); +} + +ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream, + const std::string &group_name) { + CHECK_RET(group_to_comm_map_.count(group_name), 1, + "Failed to find NCCL communicator for ReduceScatter by the group name " + group_name); + ncclComm_t group_comm = group_to_comm_map_[group_name]; + return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream); +} + +void NCCLWrapper::SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm) { + group_to_comm_map_[group_name] = comm; +} + +void NCCLWrapper::DestroyGroup(const std::string &group_name) { + auto group_iter = group_to_comm_map_.find(group_name); + if (group_iter == group_to_comm_map_.end()) { + return; + } + group_to_comm_map_.erase(group_iter); + ncclComm_t group_comm = group_iter->second; + CHECK_RET(ncclCommDestroy(group_comm), ncclSuccess, "Failed to destroy NCCL communicator for " + group_name); + return; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h new file mode 100644 index 0000000000..9cea338c41 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ + +#include +#include +#include +#include +#include +#include "runtime/device/gpu/distribution/collective_common.h" + +namespace mindspore { +namespace device { +namespace gpu { +class NCCLWrapper { + public: + NCCLWrapper(NCCLWrapper const &) = delete; + NCCLWrapper &operator=(const NCCLWrapper &) = delete; + static NCCLWrapper &instance(); + ncclUniqueId nccl_unique_id() const; + void set_nccl_unique_id(ncclUniqueId unique_id); + void set_rank(int rank_id, int rank_size); + void InitNCCLComm(); + void InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank); + ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, + ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); + ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, + cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); + ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, + ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP); + void SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm); + void DestroyGroup(const std::string &group_name); + + private: + NCCLWrapper() : rank_id_(-1), rank_size_(0) {} + ~NCCLWrapper() = default; + + private: + int rank_id_; + int rank_size_; + ncclUniqueId unique_id_; + ncclComm_t comm_; + std::map group_to_comm_map_; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc new file mode 100644 index 0000000000..a1b1fa9b79 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc @@ -0,0 +1,191 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_buffer_mgr.h" +#include +#include +#include "utils/log_adapter.h" +#include "common/utils.h" + +namespace mindspore { +namespace device { +unsigned int HandleMgr::AllocHandle() { + for (size_t i = 0; i < MAX_HANDLE_NUM; ++i) { + if (!handle_list_[i]) { + handle_list_[i] = true; + return (unsigned int)i; + } + } + return INVALID_HANDLE; +} + +void HandleMgr::FreeHandle(unsigned int handle_id) { + if (handle_id >= MAX_HANDLE_NUM) { + return; + } + handle_list_[handle_id] = false; +} + +GpuBufferMgr &GpuBufferMgr::GetInstance() noexcept { + static GpuBufferMgr instance; + return instance; +} + +BlockQueueStatus_T GpuBufferMgr::Create(unsigned int device_id, const std::string &channel_name, void *addr, + const std::vector &shape, const size_t &capacity) { + std::string name = std::to_string(device_id) + std::string("_") + channel_name; + if (name_queue_map_.count(name)) { + MS_LOG(ERROR) << "Queue not exist " << name; + return QUEUE_NOT_EXIST; + } + std::shared_ptr queue = std::make_shared(); + BlockQueueStatus_T rt = queue->Create(addr, shape, capacity); + if (rt != SUCCESS) { + return rt; + } + (void)name_queue_map_.insert(std::make_pair(name, queue)); + init_ = true; + return SUCCESS; +} + +unsigned int GpuBufferMgr::Open(unsigned int device_id, const std::string &channel_name, + const std::vector &shape, const std::function func) { + set_device(); + std::string name = std::to_string(device_id) + std::string("_") + channel_name; + if (!name_queue_map_.count(name)) { + MS_LOG(ERROR) << "Queue not exist " << name; + return HandleMgr::INVALID_HANDLE; + } + unsigned int handle = handle_mgr_.AllocHandle(); + if (handle == HandleMgr::INVALID_HANDLE) { + MS_LOG(ERROR) << "handle is invalid"; + return HandleMgr::INVALID_HANDLE; + } + (void)handle_queue_map_.insert(std::make_pair(handle, name_queue_map_[name])); + name_queue_map_[name]->RegisterRelease(func); + open_by_dataset_++; + return handle; +} + +unsigned int GpuBufferMgr::Open(unsigned int device_id, const std::string &channel_name, + const std::vector &shape) { + set_device(); + std::string name = std::to_string(device_id) + std::string("_") + channel_name; + if (!name_queue_map_.count(name)) { + MS_LOG(ERROR) << "Queue not exist " << name; + return HandleMgr::INVALID_HANDLE; + } + unsigned int handle = handle_mgr_.AllocHandle(); + if (handle == HandleMgr::INVALID_HANDLE) { + MS_LOG(ERROR) << "handle is invalid"; + return HandleMgr::INVALID_HANDLE; + } + (void)handle_queue_map_.insert(std::make_pair(handle, name_queue_map_[name])); + return handle; +} + +void GpuBufferMgr::set_device_id(int device_id) { cur_dev_id_ = device_id; } + +void GpuBufferMgr::set_device() const { + auto ret = cudaSetDevice(cur_dev_id_); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaSetDevice, ret[" << static_cast(ret) << "]"; + } +} + +BlockQueueStatus_T GpuBufferMgr::Push(unsigned int handle, const std::vector &data, + unsigned int timeout_in_sec) { + auto iter = handle_queue_map_.find(handle); + if (iter == handle_queue_map_.end()) { + return HANDLE_NOT_EXIST; + } + return iter->second->Push(data, timeout_in_sec); +} + +BlockQueueStatus_T GpuBufferMgr::Front(unsigned int handle, void **addr, size_t *len) { + auto iter = handle_queue_map_.find(handle); + if (iter == handle_queue_map_.end()) { + return HANDLE_NOT_EXIST; + } + return iter->second->Front(addr, len); +} + +BlockQueueStatus_T GpuBufferMgr::Pop(unsigned int handle) { + auto iter = handle_queue_map_.find(handle); + if (iter == handle_queue_map_.end()) { + return HANDLE_NOT_EXIST; + } + return iter->second->Pop(); +} + +void GpuBufferMgr::Close(unsigned int handle) noexcept { + if (!handle_queue_map_.count(handle)) { + return; + } + (void)handle_queue_map_.erase(handle); + handle_mgr_.FreeHandle(handle); + return; +} + +bool GpuBufferMgr::IsInit() const { return init_; } + +bool GpuBufferMgr::IsClosed() const { return closed_; } + +bool GpuBufferMgr::Destroy() { + for (auto iter = name_queue_map_.begin(); iter != name_queue_map_.end(); ++iter) { + std::shared_ptr queue = iter->second; + if (queue != nullptr) { + if (!queue->Destroy()) { + return false; + } + queue.reset(); + } + } + name_queue_map_.clear(); + return true; +} + +inline bool GpuBufferMgr::isCreated(unsigned int device_id, const std::string &channel_name) { + std::string name = std::to_string(device_id) + std::string("_") + channel_name; + if (name_queue_map_.count(name) != 0) { + return true; + } + return false; +} + +bool GpuBufferMgr::CloseNotify() { + bool result = true; + // lock scope + { + std::lock_guard lk(close_mutex_); + // set closed_ to be true, all the dataset retry can be jumped out of the while + closed_ = true; + } + + // wati for the dataset threads' ack + for (int i = 0; i < open_by_dataset_; i++) { + if (sema.Wait() == false) { + MS_LOG(ERROR) << "time out of receiving signals"; + result = false; + } + MS_LOG(DEBUG) << "receive one signal (" << i + 1 << "/" << open_by_dataset_ << ")"; + } + return result; +} + +void GpuBufferMgr::CloseConfirm() { sema.Signal(); } +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h new file mode 100644 index 0000000000..722a36c4ed --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h @@ -0,0 +1,139 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "runtime/device/gpu/blocking_queue.h" + +#define EXPORT __attribute__((visibility("default"))) + +namespace mindspore { +namespace device { +static const unsigned int MAX_WAIT_TIME_IN_SEC = 60; + +class Semaphore { + public: + explicit Semaphore(int count = 0) : count_(count) {} + + inline void Signal() { + std::unique_lock lock(mutex_); + ++count_; + cv_.notify_one(); + } + + inline bool Wait() { + std::unique_lock lock(mutex_); + while (count_ == 0) { + if (cv_.wait_for(lock, std::chrono::seconds(MAX_WAIT_TIME_IN_SEC)) == std::cv_status::timeout) { + return false; + } + } + --count_; + return true; + } + + private: + std::mutex mutex_; + std::condition_variable cv_; + int count_; +}; + +class HandleMgr { + public: + static const unsigned int MAX_HANDLE_NUM = 32; + static const unsigned int INVALID_HANDLE = 0xffffffffUL; + + unsigned int AllocHandle(); + void FreeHandle(unsigned int); + + private: + bool handle_list_[MAX_HANDLE_NUM]; +}; + +class GpuBufferMgr { + public: + EXPORT GpuBufferMgr() : cur_dev_id_(0), init_(false), closed_(false), open_by_dataset_(0) {} + + EXPORT virtual ~GpuBufferMgr() = default; + + EXPORT static GpuBufferMgr &GetInstance() noexcept; + + EXPORT BlockQueueStatus_T Create(unsigned int device_id, const std::string &channel_name, void *addr, + const std::vector &shape, const size_t &capacity); + + // call for Push thread + EXPORT unsigned int Open(unsigned int device_id, const std::string &channel_name, const std::vector &shape, + std::function func); + + // call for Front/Pop thread + EXPORT unsigned int Open(unsigned int device_id, const std::string &channel_name, const std::vector &shape); + + EXPORT BlockQueueStatus_T Push(unsigned int handle, const std::vector &data, + unsigned int timeout_in_sec); + EXPORT BlockQueueStatus_T Front(unsigned int handle, void **addr, size_t *len); + EXPORT BlockQueueStatus_T Pop(unsigned int handle); + + EXPORT void set_device_id(int device_id); + + EXPORT void Close(unsigned int handle) noexcept; + + EXPORT bool IsInit() const; + + EXPORT bool IsClosed() const; + + EXPORT bool Destroy(); + + // call for Release GPU Resources + EXPORT bool CloseNotify(); + + // call for dataset send thread + EXPORT void CloseConfirm(); + + private: + void set_device() const; + + int cur_dev_id_; + bool init_; + bool closed_; + std::mutex mutex_; + std::mutex close_mutex_; + // how many queues opened by dataset + int open_by_dataset_; + Semaphore sema; + + HandleMgr handle_mgr_; + + std::map> handle_queue_map_; + std::map> name_queue_map_; + + inline bool isCreated(unsigned int device_id, const std::string &channel_name); + + GpuBufferMgr(const GpuBufferMgr &) = delete; + GpuBufferMgr &operator=(const GpuBufferMgr &) = delete; +}; +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_common.h b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h similarity index 100% rename from mindspore/ccsrc/device/gpu/gpu_common.h rename to mindspore/ccsrc/runtime/device/gpu/gpu_common.h diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc new file mode 100644 index 0000000000..a20a6a9a3c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_device_address.h" +#include +#include "runtime/device/gpu/gpu_device_manager.h" +#include "utils/log_adapter.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" + +namespace mindspore { +namespace device { +namespace gpu { +bool GPUDeviceAddress::SyncDeviceToHost(const std::vector &, size_t size, TypeId, void *host_ptr) const { + MS_EXCEPTION_IF_NULL(host_ptr); + auto &stream = GPUDeviceManager::GetInstance().default_stream(); + MS_EXCEPTION_IF_NULL(stream); + auto ret = GPUDeviceManager::GetInstance().SyncStream(stream); + if (!ret) { + MS_LOG(ERROR) << "SyncStream failed"; + return ret; + } + if (size != size_) { + MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_; + return true; + } + return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size_); +} + +bool GPUDeviceAddress::SyncHostToDevice(const std::vector &, size_t, TypeId, const void *host_ptr) const { + MS_EXCEPTION_IF_NULL(host_ptr); + auto &stream = GPUDeviceManager::GetInstance().default_stream(); + MS_EXCEPTION_IF_NULL(stream); + if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(ptr_, host_ptr, size_, stream)) { + MS_LOG(ERROR) << "CopyHostMemToDeviceAsync failed"; + return false; + } + return GPUDeviceManager::GetInstance().SyncStream(stream); +} + +GPUDeviceAddress::~GPUDeviceAddress() { + if (ptr_ == nullptr) { + return; + } + if (from_mem_pool_) { + GPUMemoryAllocator::GetInstance().FreeTensorMem(ptr_); + ptr_ = nullptr; + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h new file mode 100644 index 0000000000..ade738deed --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ + +#include +#include +#include "runtime/device/device_address.h" + +namespace mindspore { +namespace device { +namespace gpu { +class GPUDeviceAddress : public DeviceAddress { + public: + GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} + GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) + : DeviceAddress(ptr, size, format, type_id) {} + ~GPUDeviceAddress() override; + + bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; + bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; + void set_status(DeviceAddressStatus status) { status_ = status; } + DeviceAddressStatus status() const { return status_; } + DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; } + + private: + DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc new file mode 100644 index 0000000000..8f17fc20b5 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_device_manager.h" +#include "runtime/device/gpu/gpu_common.h" +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" +#include "runtime/device/gpu/gpu_buffer_mgr.h" + +namespace mindspore { +namespace device { +namespace gpu { +void GPUDeviceManager::InitDevice() { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::set_current_device(SizeToInt(cur_dev_id_)), "Failed to set current device id"); + CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetStream(cudnn_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cuDNN handle."); + CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle."); + CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cuBLAS handle."); + CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator") +} + +void GPUDeviceManager::ReleaseDevice() { + for (DeviceStream stream : gpu_streams_) { + if (stream != nullptr) { + CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream."); + } + } + if (cudnn_handle_ != nullptr) { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle"); + } + if (cublas_handle_ != nullptr) { + CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle."); + } + CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); +} + +bool GPUDeviceManager::CreateStream(DeviceStream *stream) { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); + gpu_streams_.emplace_back(*stream); + return true; +} + +const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } + +int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } + +bool GPUDeviceManager::set_cur_device_id(uint32_t device_id) { + if (!dev_id_init_) { + dev_id_init_ = true; + cur_dev_id_ = device_id; + mindspore::device::GpuBufferMgr::GetInstance().set_device_id(UintToInt(device_id)); + return true; + } else { + MS_LOG(ERROR) << "Device already been set."; + return false; + } +} + +uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } + +bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } + +const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } + +const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } + +bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } + +bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { + return CudaDriver::CopyDeviceMemToHost(dst, src, size); +} + +bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { + return CudaDriver::CopyHostMemToDevice(dst, src, size); +} + +bool GPUDeviceManager::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, + DeviceStream stream) const { + return CudaDriver::CopyDeviceMemToHostAsync(dst, src, size, stream); +} + +bool GPUDeviceManager::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, + DeviceStream stream) const { + return CudaDriver::CopyHostMemToDeviceAsync(dst, src, size, stream); +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h new file mode 100644 index 0000000000..002806675c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h @@ -0,0 +1,83 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ + +#include +#include +#include +#include +#include "runtime/device/gpu/cuda_driver.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" + +namespace mindspore { +namespace device { +namespace gpu { +class GPUDeviceManager { + public: + void InitDevice(); + void ReleaseDevice(); + + int device_count() const; + bool set_cur_device_id(uint32_t device_id); + uint32_t cur_device_id() const; + bool is_device_id_init() const; + + bool CreateStream(DeviceStream *stream); + bool SyncStream(const DeviceStream &stream) const; + const DeviceStream &default_stream() const; + + const cudnnHandle_t &GetCudnnHandle() const; + const cublasHandle_t &GetCublasHandle() const; + + bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; + bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; + + bool CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, DeviceStream stream) const; + bool CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, DeviceStream stream) const; + + static GPUDeviceManager &GetInstance() { + static GPUDeviceManager instance; + return instance; + } + + private: + GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} + ~GPUDeviceManager() = default; + GPUDeviceManager(const GPUDeviceManager &) = delete; + GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; + + // default CUDA stream used for all the kernels. + DeviceStream default_stream_{nullptr}; + + // all gpu CUDA streams including default_stream_. + std::vector gpu_streams_; + + // handle used for cuDNN kernels. + cudnnHandle_t cudnn_handle_{nullptr}; + + // handle used for cuBLAS kernels. + cublasHandle_t cublas_handle_{nullptr}; + + bool dev_id_init_; + uint32_t cur_dev_id_; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc new file mode 100644 index 0000000000..9d88a205bc --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/gpu/gpu_kernel_build.h" +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "frontend/operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +namespace mindspore { +namespace device { +namespace gpu { +void GpuBuild(const KernelGraphPtr &kernel_graph) { + kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); + MS_EXCEPTION_IF_NULL(bin_map); + bin_map->Initialize(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto kernels = kernel_graph->execution_order(); + for (const auto &kernel : kernels) { + std::string kernel_name = session::AnfRuntimeAlgorithm::GetCNodeName(kernel); + if (kernel_name == prim::kPrimTupleGetItem->name() || kernel_name == prim::kPrimMakeTuple->name() || + kernel_name == prim::kPrimDepend->name() || kernel_name == prim::kPrimStateSetItem->name()) { + continue; + } + + if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) { + auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); + if (!gpu_kernel_ptr) { + MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; + } + session::AnfRuntimeAlgorithm::SetKernelMod(gpu_kernel_ptr, kernel.get()); + } else { + auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel); + if (!gpu_kernel_ptr) { + MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel_name << "] failed"; + } + if (!gpu_kernel_ptr->Init(kernel)) { + MS_LOG(EXCEPTION) << "Initialize gpu kernel op[" << kernel_name << "] failed."; + } + session::AnfRuntimeAlgorithm::SetKernelMod((kernel::KernelModPtr)gpu_kernel_ptr, kernel.get()); + } + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h new file mode 100644 index 0000000000..831c4e9511 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ + +#include +#include "backend/session/kernel_graph.h" +namespace mindspore { +namespace device { +namespace gpu { +void GpuBuild(const std::shared_ptr &kernel_graph); +} // namespace gpu +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc new file mode 100644 index 0000000000..ddf73841b7 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -0,0 +1,646 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_kernel_runtime.h" +#include "runtime/device/gpu/gpu_device_address.h" +#include "runtime/device/gpu/cuda_driver.h" +#include "runtime/device/gpu/gpu_buffer_mgr.h" +#include "runtime/device/gpu/gpu_device_manager.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" +#include "runtime/device/gpu/distribution/collective_init.h" +#include "utils/convert_utils.h" +#include "utils/context/ms_context.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "runtime/device/gpu/gpu_common.h" +#include "common/utils.h" +#include "runtime/device/gpu/gpu_memory_manager.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/gpu/gpu_memory_copy_manager.h" + +namespace mindspore { +namespace device { +namespace gpu { +using mindspore::device::memswap::MemSwapManager; +using mindspore::device::memswap::SwapKind; +bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } + +bool GPUKernelRuntime::Init() { + if (device_init_ == true) { + GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); + return true; + } + auto ret = InitDevice(); + if (!ret) { + MS_LOG(ERROR) << "InitDevice error."; + return ret; + } + mem_manager_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(mem_manager_); + mem_manager_->MallocDeviceMemory(); + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + bool collective_inited = CollectiveInitializer::instance().collective_inited(); + if (collective_inited && collective_handle_ != nullptr) { + auto init_nccl_comm_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "InitNCCLComm")); + MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr); + (*init_nccl_comm_funcptr)(); + } + device_init_ = true; + return ret; +} + +DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) { + return std::make_shared(device_ptr, device_size, format, type_id); +} + +bool GPUKernelRuntime::InitDevice() { + if (GPUDeviceManager::GetInstance().device_count() <= 0) { + MS_LOG(ERROR) << "No GPU device found."; + return false; + } + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + bool collective_inited = CollectiveInitializer::instance().collective_inited(); + if (collective_inited && collective_handle_ != nullptr) { + auto get_local_rank_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); + MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); + device_id_ = IntToUint((*get_local_rank_funcptr)()); + } + if (!GPUDeviceManager::GetInstance().is_device_id_init()) { + if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_id_)) { + MS_LOG(ERROR) << "Failed to set current device to " << SizeToInt(device_id_); + return false; + } + } + GPUDeviceManager::GetInstance().InitDevice(); + stream_ = GPUDeviceManager::GetInstance().default_stream(); + if (stream_ == nullptr) { + MS_LOG(ERROR) << "No default CUDA stream found."; + return false; + } + return true; +} + +void GPUKernelRuntime::ReleaseDeviceRes() { + // For dataset mode. + if (GpuBufferMgr::GetInstance().IsInit()) { + if (!GpuBufferMgr::GetInstance().IsClosed()) { + if (!GpuBufferMgr::GetInstance().CloseNotify()) { + MS_LOG(EXCEPTION) << "Could not close gpu data queue."; + } + } + CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue."); + } + + // Destroy remaining memory swap events and free host memory. + for (auto &item : mem_swap_map_) { + auto &mem_swap_manager = item.second; + MS_EXCEPTION_IF_NULL(mem_swap_manager); + if (mem_swap_manager->trigger_swap()) { + mem_swap_manager->ClearSwapQueue(); + mem_swap_manager->ReleaseHostPinnedMem(); + } + } + + GPUDeviceManager::GetInstance().ReleaseDevice(); + if (mem_manager_ != nullptr) { + mem_manager_->FreeDeviceMemory(); + } + + kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); + MS_EXCEPTION_IF_NULL(bin_map); + bin_map->RemoveKernelCache(); +} + +void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(mem_manager_); + mem_manager_->ResetDynamicMemory(); + AssignStaticMemoryInput(graph); + AssignStaticMemoryValueNode(graph); + bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); + if (is_enable_dynamic_mem) { + // Use the dynamic memory pool. + InitKernelRefCount(graph); + InitMemorySwapInfo(graph); + InitKernelOutputAddress(graph); + } else { + AssignDynamicMemory(graph); + } +} + +bool GPUKernelRuntime::Run(session::KernelGraph *graph) { + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); + bool ret = true; + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); + bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); + if (is_enable_dynamic_mem && !is_enable_pynative_infer) { + auto graph_id = graph->graph_id(); + auto iter = mem_swap_map_.find(graph_id); + if (iter == mem_swap_map_.end()) { + MS_LOG(EXCEPTION) << "Find memory swap map failed."; + } + mem_swap_manager_ = iter->second; + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + while (!LaunchKernelDynamic(graph)) { + MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment."; + if (!UpdateMemorySwapInfo(graph)) { + return false; + } + } + } else { + ret = LaunchKernel(graph); + } + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(DEBUG) << "GPU kernel runtime run graph in " << cost << " us"; + return ret; +} + +void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + // Init the kernel reference count. + if (!mem_reuse_util_ptr->InitDynamicKernelRef(graph)) { + MS_LOG(EXCEPTION) << "Init kernel reference count failed"; + } + mem_reuse_util_ptr->SetKernelDefMap(); + mem_reuse_util_ptr->SetReuseRefCount(); + // Can't free the device address of graph output, so set the reference count of graph output specially. + mem_reuse_util_ptr->SetGraphOutputRefCount(); + // Can't free the device address of summary nodes, so set the reference count of summary nodes specially. + mem_reuse_util_ptr->SetSummaryNodesRefCount(); + auto graph_id = graph->graph_id(); + mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr; +} + +void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared(); + MS_EXCEPTION_IF_NULL(gpu_mem_copy_manager); + MemSwapManagerPtr mem_swap_manager = std::make_shared(gpu_mem_copy_manager); + MS_EXCEPTION_IF_NULL(mem_swap_manager); + auto graph_id = graph->graph_id(); + mem_swap_map_[graph_id] = mem_swap_manager; +} + +void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + if (AnfAlgo::OutputAddrExist(kernel, i)) { + continue; + } + std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); + AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); + } + } +} + +void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + if (!AnfAlgo::OutputAddrExist(kernel, i)) { + continue; + } + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + if (device_address->ptr_) { + mem_manager_->FreeMemFromMemPool(device_address); + } + device_address->set_status(DeviceAddressStatus::kInDevice); + } + } +} + +bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto graph_id = graph->graph_id(); + auto iter = mem_reuse_util_map_.find(graph_id); + if (iter == mem_reuse_util_map_.end()) { + MS_LOG(EXCEPTION) << "Find memory reuse map failed."; + } + auto mem_reuse_util_ptr = iter->second; + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + // Reset the reference count. + mem_reuse_util_ptr->ResetDynamicUsedRefCount(); + // The inputs and outputs memory of communication kernel need be continuous, so separate processing. + AllocCommunicationOpDynamicRes(graph); + + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + if (!ret) { + return false; + } + if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { + MS_LOG(EXCEPTION) << "Launch kernel failed."; + } + FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id); + UpdateMemorySwapTask(kernel); + } + CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); + ClearSwapQueue(); + return true; +} + +bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); + for (auto &mem_swap_info : mem_swap_info_list) { + auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); + const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; + auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); + + if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { + mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); + } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + auto status = device_address->status(); + if (status == DeviceAddressStatus::kInDeviceToHost) { + mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); + device_address->set_status(DeviceAddressStatus::kInDevice); + } else if (status == DeviceAddressStatus::kInHost) { + if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) { + return false; + } + if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) { + mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address); + } + } + } + } + return true; +} + +bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + ClearKernelOutputAddress(graph); + if (!mem_swap_manager_->mem_swap_init()) { + mem_swap_manager_->Init(graph); + } + return mem_swap_manager_->RetreatSwapInfo(); +} + +bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + if (!mem_swap_manager_->trigger_swap()) { + return true; + } + if (mem_swap_manager_->QueryKernelTriggerSwap(kernel)) { + CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); + if (!AddMemorySwapTask(kernel)) { + return false; + } + } + CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed."); + return true; +} + +void GPUKernelRuntime::UpdateHostSwapQueue(const DeviceAddressPtr device_address) { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + if (!mem_swap_manager_->trigger_swap()) { + return; + } + while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { + device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); + } + auto status = device_address->status(); + switch (status) { + case DeviceAddressStatus::kInDevice: + break; + case DeviceAddressStatus::kInDeviceToHost: { + mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); + device_address->set_status(DeviceAddressStatus::kInDevice); + break; + } + case DeviceAddressStatus::kInHostToDevice: { + while (device_address->status() != DeviceAddressStatus::kInDevice) { + while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { + device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); + } + } + break; + } + case DeviceAddressStatus::kInHost: + MS_LOG(ERROR) << "Invaild device address status:" << status; + break; + default: + MS_LOG(EXCEPTION) << "Invaild device address status:" << status; + } +} + +void GPUKernelRuntime::UpdateDeviceSwapQueue() { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + if (!mem_swap_manager_->trigger_swap()) { + return; + } + while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { + if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { + device_address_swap_out->set_status(DeviceAddressStatus::kInHost); + mem_manager_->FreeMemFromMemPool(device_address_swap_out); + } + } +} + +void GPUKernelRuntime::ClearSwapQueue() { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + if (!mem_swap_manager_->trigger_swap()) { + return; + } + mem_swap_manager_->ClearSwapQueue(); +} + +bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) { + MS_EXCEPTION_IF_NULL(mem_manager_); + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + auto ret = mem_manager_->MallocMemFromMemPool(device_address, size); + if (!ret) { + if (!mem_swap_manager_->trigger_swap()) { + return false; + } + mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); + while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { + if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { + device_address_swap_out->set_status(DeviceAddressStatus::kInHost); + mem_manager_->FreeMemFromMemPool(device_address_swap_out); + } + } + ret = mem_manager_->MallocMemFromMemPool(device_address, size); + if (!ret) { + return false; + } + } + return true; +} + +void *GPUKernelRuntime::AttemptMallocMem(size_t size) { + MS_EXCEPTION_IF_NULL(mem_manager_); + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + auto device_ptr = mem_manager_->MallocMemFromMemPool(size); + if (!device_ptr) { + if (!mem_swap_manager_->trigger_swap()) { + return nullptr; + } + mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); + while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { + if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { + device_address_swap_out->set_status(DeviceAddressStatus::kInHost); + mem_manager_->FreeMemFromMemPool(device_address_swap_out); + } + } + device_ptr = mem_manager_->MallocMemFromMemPool(size); + if (!device_ptr) { + return nullptr; + } + } + return device_ptr; +} + +bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, + const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, + AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) { + if (!AllocKernelInputDynamicRes(kernel, kernel_inputs)) { + return false; + } + if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs)) { + return false; + } + if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces)) { + return false; + } + return true; +} + +bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_inputs); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + MS_EXCEPTION_IF_NULL(device_address); + UpdateHostSwapQueue(device_address); + MS_EXCEPTION_IF_NULL(device_address->ptr_); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + input->size = device_address->size_; + kernel_inputs->emplace_back(input); + } + return true; +} + +bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, + const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_outputs) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_outputs); + UpdateDeviceSwapQueue(); + auto output_sizes = kernel_mod.GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + MS_EXCEPTION_IF_NULL(device_address); + if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) { + return false; + } + kernel::AddressPtr output = std::make_shared(); + MS_EXCEPTION_IF_NULL(output); + output->addr = device_address->ptr_; + output->size = output_sizes[i]; + kernel_outputs->emplace_back(output); + } + return true; +} + +bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, + const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_workspaces) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_workspaces); + auto workspace_sizes = kernel_mod.GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_sizes.size(); ++i) { + if (workspace_sizes[i] == 0) { + kernel_workspaces->emplace_back(nullptr); + continue; + } + auto device_ptr = AttemptMallocMem(workspace_sizes[i]); + if (!device_ptr) { + return false; + } + kernel::AddressPtr workspace = std::make_shared(); + MS_EXCEPTION_IF_NULL(workspace); + workspace->addr = device_ptr; + workspace->size = workspace_sizes[i]; + kernel_workspaces->emplace_back(workspace); + } + return true; +} + +void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + if (AnfAlgo::IsCommunicationOp(kernel)) { + AllocCommunicationOpInputDynamicRes(kernel); + AllocCommunicationOpOutputDynamicRes(kernel); + } + } +} + +void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(kernel); + bool is_need_alloc_memory = false; + bool is_need_free_memory = false; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + MS_EXCEPTION_IF_NULL(device_address); + if (device_address->ptr_ == nullptr) { + is_need_alloc_memory = true; + } else { + is_need_free_memory = true; + } + total_size += device_address->size_; + size_list.emplace_back(device_address->size_); + addr_list.emplace_back(device_address); + } + AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list); +} + +void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(kernel); + bool is_need_alloc_memory = false; + bool is_need_free_memory = false; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + MS_EXCEPTION_IF_NULL(device_address); + if (device_address->ptr_ == nullptr) { + is_need_alloc_memory = true; + } else { + is_need_free_memory = true; + } + total_size += output_sizes[i]; + size_list.emplace_back(output_sizes[i]); + addr_list.emplace_back(device_address); + } + AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list); +} + +void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, + const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list) { + MS_EXCEPTION_IF_NULL(mem_manager_); + if (!is_need_alloc_memory) { + return; + } + if (is_need_free_memory) { + for (const auto &iter : addr_list) { + MS_EXCEPTION_IF_NULL(iter); + // Free the inputs/outputs of communication kernel which are not released. + if (iter->ptr_ != nullptr) { + mem_manager_->FreeMemFromMemPool(iter); + } + } + } + auto ret = mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); + if (!ret) { + MS_LOG(EXCEPTION) << "Malloc device memory failed."; + } +} + +void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, + const AddressPtrList &kernel_workspaces, uint32_t graph_id) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + auto cnode = kernel->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::IsCommunicationOp(kernel)) { + return; + } + // Free the input of kernel by reference count. + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i); + if (kernel_ref_count_ptr == nullptr) { + continue; + } + kernel_ref_count_ptr->ref_count_dynamic_use_--; + if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) { + MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; + } + if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + mem_manager_->FreeMemFromMemPool(device_address); + device_address->set_status(DeviceAddressStatus::kInDevice); + } + } + // Free the output of kernel, if output has no reference. + for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) { + auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i); + if (kernel_ref_count_ptr == nullptr) { + continue; + } + if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + mem_manager_->FreeMemFromMemPool(device_address); + device_address->set_status(DeviceAddressStatus::kInDevice); + } + } + // Free the workspace of kernel. + for (size_t i = 0; i < kernel_workspaces.size(); ++i) { + auto workspace = kernel_workspaces[i]; + if (workspace != nullptr) { + MS_EXCEPTION_IF_NULL(workspace->addr); + mem_manager_->FreeMemFromMemPool(workspace->addr); + workspace->addr = nullptr; + } + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h new file mode 100644 index 0000000000..2b1f8198ce --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "backend/optimizer/mem_reuse/mem_swap_manager.h" + +namespace mindspore { +namespace device { +namespace gpu { +using mindspore::device::memswap::MemSwapManagerPtr; +class GPUKernelRuntime : public KernelRuntime { + public: + GPUKernelRuntime() = default; + ~GPUKernelRuntime() override = default; + bool Init() override; + void ReleaseDeviceRes() override; + void AssignMemory(session::KernelGraph *graph) override; + bool Run(session::KernelGraph *graph) override; + + protected: + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) override; + bool SyncStream() override; + + private: + GPUKernelRuntime(const GPUKernelRuntime &); + GPUKernelRuntime &operator=(const GPUKernelRuntime &); + bool InitDevice(); + bool device_init_{false}; + + // The related functions and members for using dynamic memory pool. + void InitKernelRefCount(const session::KernelGraph *graph); + void InitKernelOutputAddress(const session::KernelGraph *graph); + void InitMemorySwapInfo(const session::KernelGraph *graph); + void ClearKernelOutputAddress(const session::KernelGraph *graph); + bool LaunchKernelDynamic(const session::KernelGraph *graph); + bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); + void *AttemptMallocMem(size_t size); + bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, + AddressPtrList *kernel_outputs); + bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs); + bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_outputs); + bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, + const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces); + void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); + void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); + void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); + void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, + const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list); + void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, + uint32_t graph_id); + bool AddMemorySwapTask(const AnfNodePtr &kernel); + bool UpdateMemorySwapInfo(const session::KernelGraph *graph); + bool UpdateMemorySwapTask(const AnfNodePtr &kernel); + void UpdateHostSwapQueue(const DeviceAddressPtr device_address); + void UpdateDeviceSwapQueue(); + void ClearSwapQueue(); + std::unordered_map mem_reuse_util_map_; + std::unordered_map mem_swap_map_; + MemSwapManagerPtr mem_swap_manager_{nullptr}; +}; +MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); +} // namespace gpu +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc new file mode 100644 index 0000000000..e2395bbaf2 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "runtime/device/gpu/gpu_memory_allocator.h" +#include "runtime/device/gpu/cuda_driver.h" +#include "utils/log_adapter.h" +#include "utils/context/ms_context.h" +#include "utils/convert_utils_base.h" + +namespace mindspore { +namespace device { +namespace gpu { +bool GPUMemoryAllocator::Init() { + size_t total_size = total_mem_size(); + size_t free_size = CudaDriver::free_mem_size(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + limited_device_memory_ = context_ptr->max_device_memory(); + available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024); + if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) { + MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size + << ", set max available memory size " << available_device_memory_ << "."; + } else { + MS_LOG(EXCEPTION) << "GPU device memory error, total memory size " << total_size << ", current free memory size " + << free_size << ", set max available memory size " << available_device_memory_ << "."; + } + return true; +} + +void GPUMemoryAllocator::CheckMaxDeviceMemory() const { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + auto max_device_memory = context_ptr->max_device_memory(); + // Currently not support modifying the max device memory. + if (limited_device_memory_ != max_device_memory) { + MS_LOG(EXCEPTION) + << "Can't change context param max_device_memory in runtime, currently effective max_device_memory(" + << limited_device_memory_ << "GB), set new max_device_memory(" << max_device_memory << "GB) failed."; + } +} + +bool GPUMemoryAllocator::Finalize() { + if (buffer_q_addr_ != nullptr) { + if (!CudaDriver::FreeDeviceMem(buffer_q_addr_)) { + MS_LOG(ERROR) << "Could not free buffer queue memory."; + return false; + } + } + return true; +} + +bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { + auto alloc_size = AllocDeviceMem(size, addr); + buffer_q_addr_ = *addr; + // Buffer queue needs to ensure that the alloc_size and size is equal. + return (alloc_size == size) ? true : false; +} + +size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { + if (size == 0) { + MS_LOG(EXCEPTION) << "The memory alloc size is 0."; + } + auto free_size = free_mem_size(); + if (size > free_size) { + MS_LOG(EXCEPTION) << "Memory not enough: current free memory size[" << free_size + << "] is smaller than required size[" << size << "]."; + } + + auto alloc_size = CudaDriver::AllocDeviceMem(size, addr); + if (alloc_size == 0) { + MS_LOG(EXCEPTION) << "Alloc device memory[" << size << "] failed."; + } + total_used_device_memory_ += alloc_size; + available_device_memory_ -= alloc_size; + MS_LOG(INFO) << "Current free memory size[" << free_size - alloc_size << "], current alloc size[" << alloc_size + << "], total used size[" << total_used_device_memory_ << "]."; + return alloc_size; +} + +bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } + +size_t GPUMemoryAllocator::free_mem_size() { return std::min(CudaDriver::free_mem_size(), available_device_memory_); } + +size_t GPUMemoryAllocator::total_mem_size() { return CudaDriver::total_mem_size(); } +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h new file mode 100644 index 0000000000..4b6eaa4e14 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ + +#include +#include "runtime/device/gpu/cuda_driver.h" +#include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h" + +namespace mindspore { +namespace device { +namespace gpu { +class GPUMemoryAllocator : public DynamicMemPoolBestFit { + public: + ~GPUMemoryAllocator() override = default; + bool Init(); + void CheckMaxDeviceMemory() const; + bool Finalize(); + bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); + + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; + size_t free_mem_size() override; + size_t total_mem_size() override; + + static GPUMemoryAllocator &GetInstance() { + static GPUMemoryAllocator instance; + return instance; + } + + private: + GPUMemoryAllocator() = default; + GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; + GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; + + // Used to track address of data buffer queue. + DeviceMemPtr buffer_q_addr_{nullptr}; + + float limited_device_memory_{0.0}; + size_t total_used_device_memory_{0}; + size_t available_device_memory_{0}; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc new file mode 100644 index 0000000000..0406c0f151 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_memory_copy_manager.h" +#include "runtime/device/gpu/gpu_common.h" +#include "runtime/device/gpu/gpu_device_manager.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace gpu { +void GPUMemCopyManager::Init() { + CHECK_OP_RET_WITH_EXCEPT(GPUDeviceManager::GetInstance().CreateStream(&swap_out_stream_), + "Failed to create CUDA stream of memory swap out."); + CHECK_OP_RET_WITH_EXCEPT(GPUDeviceManager::GetInstance().CreateStream(&swap_in_stream_), + "Failed to create CUDA stream of memory swap in."); +} + +void GPUMemCopyManager::AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { + MS_EXCEPTION_IF_NULL(device_address); + MS_EXCEPTION_IF_NULL(host_addr.addr); + DeviceEvent event = nullptr; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); + DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); + MS_EXCEPTION_IF_NULL(device_ptr); + device_address->set_status(DeviceAddressStatus::kInDeviceToHost); + + CHECK_OP_RET_WITH_EXCEPT( + CudaDriver::CopyDeviceMemToHostAsync(host_addr.addr, device_ptr, host_addr.size, swap_out_stream_), + "Failed to copy device memory to host."); + + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_out_stream_), + "Failed to record CUDA event to swap out stream."); + swap_out_queue_.emplace(device_address, event); +} + +void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { + MS_EXCEPTION_IF_NULL(device_address); + MS_EXCEPTION_IF_NULL(host_addr.addr); + DeviceEvent event = nullptr; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); + DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); + MS_EXCEPTION_IF_NULL(device_ptr); + device_address->set_status(DeviceAddressStatus::kInHostToDevice); + + CHECK_OP_RET_WITH_EXCEPT( + CudaDriver::CopyHostMemToDeviceAsync(device_ptr, host_addr.addr, host_addr.size, swap_in_stream_), + "Failed to copy host memory to device."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_in_stream_), + "Failed to record CUDA event to swap in stream."); + swap_in_queue_.emplace(device_address, event); +} + +bool GPUMemCopyManager::SyncMemCopyStream(SwapKind swap_kind) { + if (swap_kind == SwapKind::kDeviceToHost) { + return GPUDeviceManager::GetInstance().SyncStream(swap_out_stream_); + } else { + return GPUDeviceManager::GetInstance().SyncStream(swap_in_stream_); + } +} + +DeviceAddressPtr GPUMemCopyManager::UpdateSwapOutQueue() { + if (swap_out_queue_.empty()) { + return nullptr; + } + auto &task = swap_out_queue_.front(); + auto device_address = task.first; + auto &event = task.second; + bool finish_swap = CudaDriver::QueryEvent(event); + if (!finish_swap) { + return nullptr; + } + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap out."); + swap_out_queue_.pop(); + return device_address; +} + +DeviceAddressPtr GPUMemCopyManager::UpdateSwapInQueue() { + if (swap_in_queue_.empty()) { + return nullptr; + } + auto &task = swap_in_queue_.front(); + auto device_address = task.first; + auto &event = task.second; + bool finish_swap = CudaDriver::QueryEvent(event); + if (!finish_swap) { + return nullptr; + } + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap in."); + swap_in_queue_.pop(); + return device_address; +} + +bool GPUMemCopyManager::AllocHostPinnedMem(size_t size, void **addr) const { + auto alloc_size = CudaDriver::AllocHostPinnedMem(size, addr); + return alloc_size == size; +} + +void GPUMemCopyManager::FreeHostPinnedMem(void *addr) const { CudaDriver::FreeHostPinnedMem(addr); } + +void GPUMemCopyManager::ClearSwapQueue() { + CHECK_OP_RET_WITH_EXCEPT(SyncMemCopyStream(SwapKind::kDeviceToHost), "Failed to sync swap out stream"); + CHECK_OP_RET_WITH_EXCEPT(SyncMemCopyStream(SwapKind::kHostToDevice), "Failed to sync swap in stream"); + + while (!swap_out_queue_.empty()) { + auto &event = swap_out_queue_.front().second; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap out."); + swap_out_queue_.pop(); + } + while (!swap_in_queue_.empty()) { + auto &event = swap_in_queue_.front().second; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap in."); + swap_in_queue_.pop(); + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h new file mode 100644 index 0000000000..dc99b7f7d0 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ + +#include +#include +#include +#include "backend/optimizer/mem_reuse/mem_copy_manager.h" +#include "runtime/device/device_address.h" +#include "runtime/device/gpu/cuda_driver.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace device { +namespace gpu { +using mindspore::device::memswap::MemCopyManager; +using mindspore::device::memswap::SwapKind; +class GPUMemCopyManager : public MemCopyManager { + public: + GPUMemCopyManager() = default; + + ~GPUMemCopyManager() override = default; + + void Init() override; + + void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; + + void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; + + bool SyncMemCopyStream(SwapKind swap_kind) override; + + DeviceAddressPtr UpdateSwapOutQueue() override; + + DeviceAddressPtr UpdateSwapInQueue() override; + + bool AllocHostPinnedMem(size_t size, void **addr) const override; + + void FreeHostPinnedMem(void *addr) const override; + + void ClearSwapQueue() override; + + private: + DeviceStream swap_out_stream_{nullptr}; + DeviceStream swap_in_stream_{nullptr}; + std::queue> swap_out_queue_; + std::queue> swap_in_queue_; +}; +using GPUMemCopyManagerPtr = std::shared_ptr; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc new file mode 100644 index 0000000000..ffa07eea0d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_memory_manager.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" +#include "utils/context/ms_context.h" +#include "utils/convert_utils.h" +namespace mindspore { +namespace device { +namespace gpu { +void *GPUMemoryManager::MallocMemFromMemPool(size_t size) { + return GPUMemoryAllocator::GetInstance().AllocTensorMem(size); +} + +void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { + GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); +} + +std::vector GPUMemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); +} + +void GPUMemoryManager::MallocDeviceMemory() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + // If use the dynamic memory pool, then alloc the first memory block to init. + if (context_ptr->enable_dynamic_mem_pool()) { + auto device_addr = MallocMemFromMemPool(1); + if (!device_addr) { + MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; + } + } else { + // Need to reserve 20% space for dynamic memory + const float init_gpu_mem_ratio = 0.8; + size_t mem_size = FloatToSize(GPUMemoryAllocator::GetInstance().free_mem_size() * init_gpu_mem_ratio); + auto alloc_size = + GPUMemoryAllocator::GetInstance().AllocDeviceMem(mem_size, reinterpret_cast(&device_mem_base_)); + device_mem_size_ = alloc_size; + static_mem_offset_ = device_mem_size_; + } +} + +void GPUMemoryManager::FreeDeviceMemory() { + if (device_mem_base_ != nullptr) { + if (!GPUMemoryAllocator::GetInstance().FreeDeviceMem(device_mem_base_)) { + MS_LOG(EXCEPTION) << "Could not free gpu device memory."; + } + } + GPUMemoryAllocator::GetInstance().ReleaseDeviceRes(); +} + +uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_dynamic_mem_pool()) { + auto device_ptr = MallocMemFromMemPool(size); + MS_EXCEPTION_IF_NULL(device_ptr); + return AddressOffset(device_ptr, 0); + } + + auto align_size = GetCommonAlignSize(size); + if (static_mem_offset_ < align_size) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + auto offset = static_mem_offset_ - align_size; + if (dynamic_mem_offset_ > offset) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + total_static_size_ += align_size; + static_mem_offset_ = offset; + return device_mem_base_ + offset; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h new file mode 100644 index 0000000000..533116cefc --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#include +#include "runtime/device/memory_manager.h" +namespace mindspore { +namespace device { +namespace gpu { +class GPUMemoryManager : public MemoryManager { + public: + GPUMemoryManager() = default; + virtual ~GPUMemoryManager() = default; + + void MallocDeviceMemory() override; + void FreeDeviceMemory() override; + + void *MallocMemFromMemPool(size_t size) override; + void FreeMemFromMemPool(void *device_ptr) override; + std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); + + protected: + uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc new file mode 100644 index 0000000000..78915f10d7 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_stream_assign.h" +#include +#include +#include +#include +#include "runtime/device/gpu/gpu_common.h" +#include "runtime/device/gpu/kernel_info_setter.h" +#include "runtime/device/gpu/gpu_device_manager.h" + +namespace mindspore { +namespace device { +namespace gpu { +void AssignGpuStream(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector allreduce_kernels; + auto execution_kernels = kernel_graph->execution_order(); + for (auto kernel_node : execution_kernels) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == kAllReduceOpName) { + allreduce_kernels.emplace_back(kernel_node); + } else { + DeviceStream compute_stream = GPUDeviceManager::GetInstance().default_stream(); + MS_EXCEPTION_IF_NULL(compute_stream); + AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(compute_stream)), kernel_node); + } + } + if (allreduce_kernels.size() > 1) { + // Assign multiple streams only when there're multiple AllReduce nodes. + std::vector send_recv_pairs; + if (FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs)) { + DeviceStream comm_stream = nullptr; + GPUDeviceManager::GetInstance().CreateStream(&comm_stream); + std::transform( + allreduce_kernels.begin(), allreduce_kernels.end(), allreduce_kernels.begin(), [&](CNodePtr allreduce_kernel) { + AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(comm_stream)), allreduce_kernel); + return allreduce_kernel; + }); + InsertStreamSwitchNode(kernel_graph, send_recv_pairs); + } else { + return; + } + } +} + +bool FindAllReduceStreamSwitchPos(const std::shared_ptr &kernel_graph, + std::vector *send_recv_pairs) { + auto execution_kernels = kernel_graph->execution_order(); + std::vector::iterator iter, iter_begin; + iter = iter_begin = execution_kernels.begin(); + std::vector::iterator iter_end = execution_kernels.end(); + for (; iter != execution_kernels.end(); ++iter) { + std::string kernel_name = AnfAlgo::GetCNodeName(*iter); + if (kernel_name == kAllReduceOpName) { + // Find AllReduce node's last input node. + std::vector::iterator mock_send_node_iter = + FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch); + if (mock_send_node_iter == iter + 1) { + MS_LOG(WARNING) << "Can't find send node place before AllReduce node."; + continue; + } + SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter, + IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)}; + send_recv_pairs->push_back(pair1); + // Find node which uses AllReduce as input[0]. + std::vector::iterator mock_recv_node_iter = + FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch); + if (mock_recv_node_iter == iter_end) { + MS_LOG(WARNING) << "Can't find recv node place after AllReduce node."; + return false; + } + SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1), + IntToSize(mock_recv_node_iter - iter_begin)}; + send_recv_pairs->push_back(pair2); + } + } + return true; +} + +std::vector::iterator FindSendNodePos(std::vector::iterator begin, + std::vector::iterator end, const CNodePtr mock_recv_node, + StreamSwitchType stream_switch_type) { + MS_EXCEPTION_IF_NULL(mock_recv_node); + if (stream_switch_type == kAllReduceStreamSwitch) { + for (auto iter = begin; iter != end; iter++) { + if (*(iter + 1) == mock_recv_node) { + return iter; + } + } + } + return end; +} + +std::vector::iterator FindRecvNodePos(std::vector::iterator begin, + std::vector::iterator end, const CNodePtr mock_send_node, + StreamSwitchType stream_switch_type) { + MS_EXCEPTION_IF_NULL(mock_send_node); + for (auto iter = begin; iter != end; iter++) { + auto node = *iter; + if (stream_switch_type == kAllReduceStreamSwitch) { + for (auto input : node->inputs()) { + if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) { + return iter; + } + } + } + } + return end; +} + +void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, + const std::vector &send_recv_pairs) { + std::set ordered_stream_switch_nodes; + for (SendRecvPair pair : send_recv_pairs) { + StreamSwitchType stream_switch_type = pair.stream_switch_type; + CNodePtr mock_send_node = pair.mock_send_node; + CNodePtr mock_recv_node = pair.mock_recv_node; + size_t send_node_offset = pair.send_node_offset; + size_t recv_node_offset = pair.recv_node_offset; + CNodePtr send_node = nullptr; + CNodePtr recv_node = nullptr; + // Step 1: generate Send and Recv CNodes. + if (stream_switch_type == kAllReduceStreamSwitch) { + if (!GenSendRecvCNodesForAllReduce(kernel_graph, mock_send_node, mock_recv_node, &send_node, &recv_node)) { + MS_LOG(EXCEPTION) << "Generating CNodes for send and recv failed. Stream switch type: kAllReduceStreamSwitch"; + } + } + // Step 2: sort send and recv CNodes by offset. + ordered_stream_switch_nodes.insert({send_node_offset, send_node}); + ordered_stream_switch_nodes.insert({recv_node_offset, recv_node}); + } + // Step 3: insert stream switch CNodes into execution kernel list. + auto execution_kernels = kernel_graph->execution_order(); + for (auto node = ordered_stream_switch_nodes.rbegin(); node != ordered_stream_switch_nodes.rend(); node++) { + execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode); + } + kernel_graph->set_execution_order(execution_kernels); +} + +bool GenSendRecvCNodesForAllReduce(const std::shared_ptr &kernel_graph, + const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node, + CNodePtr *recv_node) { + *send_node = CreateStreamSwitchNode(kernel_graph, kSendOpName); + MS_EXCEPTION_IF_NULL(*send_node); + *recv_node = CreateStreamSwitchNode(kernel_graph, kRecvOpName); + MS_EXCEPTION_IF_NULL(*recv_node); + + cudaEvent_t event = nullptr; + CHECK_CUDA_RET_WITH_EXCEPT(cudaEventCreate(&event, cudaEventDisableTiming), "Creating cuda event failed."); + AnfAlgo::SetNodeAttr(kAttrRecordEvent, MakeValue(reinterpret_cast(event)), *send_node); + AnfAlgo::SetNodeAttr(kAttrWaitEvent, MakeValue(reinterpret_cast(event)), *recv_node); + + uintptr_t send_stream = AnfAlgo::GetNodeAttr(mock_send_node, kAttrStreamId); + AnfAlgo::SetNodeAttr(kAttrRecordEventStream, MakeValue(send_stream), *send_node); + uintptr_t recv_stream = AnfAlgo::GetNodeAttr(mock_recv_node, kAttrStreamId); + AnfAlgo::SetNodeAttr(kAttrWaitEventStream, MakeValue(recv_stream), *recv_node); + return true; +} + +CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name) { + auto op = std::make_shared(name); + MS_EXCEPTION_IF_NULL(op); + auto apply = std::make_shared(op); + MS_EXCEPTION_IF_NULL(apply); + std::vector input_list = {apply}; + CNodePtr node = kernel_graph->NewCNode(input_list); + MS_EXCEPTION_IF_NULL(node); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), node.get()); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + node->set_abstract(abstract_none); + SetKernelInfo(node); + return node; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h new file mode 100644 index 0000000000..f22ce8fe38 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ + +#include +#include +#include +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace gpu { +enum StreamSwitchType { kAllReduceStreamSwitch, kStreamSwitchInvalidType = 255 }; +struct SendRecvPair { + StreamSwitchType stream_switch_type; + CNodePtr mock_send_node; + CNodePtr mock_recv_node; + size_t send_node_offset; + size_t recv_node_offset; +}; +struct StreamSwitchNode { + size_t offset; + CNodePtr cnode; + bool operator<(const StreamSwitchNode &n) const { + if (offset < n.offset) { + return true; + } else if (offset == n.offset) { + return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false; + } else { + return false; + } + } +}; +void AssignGpuStream(const std::shared_ptr &kernel_graph); +bool FindAllReduceStreamSwitchPos(const std::shared_ptr &kernel_graph, + std::vector *send_recv_pairs); +// Find Send node position according to "mock" recv node. +// "mock" recv node is a gpu kernel node after a real Recv node, e.g. AllReduce node. +std::vector::iterator FindSendNodePos(std::vector::iterator begin, + std::vector::iterator end, const CNodePtr mock_recv_node, + StreamSwitchType stream_switch_type); +// Find Recv node position according to "mock" send node. +// "mock" send node is a gpu kernel node before a real send node, e.g. AllReduce node. +std::vector::iterator FindRecvNodePos(std::vector::iterator begin, + std::vector::iterator end, const CNodePtr mock_send_node, + StreamSwitchType stream_switch_type); +void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, + const std::vector &send_recv_pairs); +bool GenSendRecvCNodesForAllReduce(const std::shared_ptr &kernel_graph, + const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node, + CNodePtr *recv_node); +CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name); +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc new file mode 100644 index 0000000000..4326987784 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/kernel_info_setter.h" +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" +#include "common/utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/oplib/opinfo.h" + +namespace mindspore { +namespace device { +namespace gpu { +using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; +using mindspore::kernel::KernelBuildInfo; +namespace { +bool CheckKernelInfo(const std::shared_ptr &alternative_kernel_info, + const std::shared_ptr &selected_kernel_info) { + MS_EXCEPTION_IF_NULL(selected_kernel_info); + MS_EXCEPTION_IF_NULL(alternative_kernel_info); + size_t selected_input_num = selected_kernel_info->GetInputNum(); + size_t alternative_input_num = alternative_kernel_info->GetInputNum(); + if (selected_input_num != alternative_input_num) { + return false; + } + for (size_t i = 0; i < selected_input_num; i++) { + if (selected_kernel_info->GetInputFormat(i) != alternative_kernel_info->GetInputFormat(i)) { + return false; + } + if (selected_kernel_info->GetInputDeviceType(i) != alternative_kernel_info->GetInputDeviceType(i)) { + return false; + } + } + + size_t selected_output_num = selected_kernel_info->GetOutputNum(); + size_t alternative_output_num = alternative_kernel_info->GetOutputNum(); + if (selected_output_num != alternative_output_num) { + return false; + } + for (size_t i = 0; i < selected_output_num; i++) { + if (selected_kernel_info->GetOutputFormat(i) != alternative_kernel_info->GetOutputFormat(i)) { + return false; + } + if (selected_kernel_info->GetOutputDeviceType(i) != alternative_kernel_info->GetOutputDeviceType(i)) { + return false; + } + } + return true; +} + +std::string SupportedTypeList(const CNodePtr &kernel_node) { + std::string supported_type_lists = + kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); + if (!supported_type_lists.empty()) { + return supported_type_lists; + } + std::vector> kernel_info_list; + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG); + if (op_info_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Unsupported op [" << op_name << "]"; + } + (void)ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list); + for (size_t i = 0; i < kernel_info_list.size(); i++) { + auto supported_akg_type = kernel_info_list[i]->GetAllInputDeviceTypes(); + auto supported_akg_type_out = kernel_info_list[i]->GetAllOutputDeviceTypes(); + std::string supported_akg_type_list = "in["; + for (auto type : supported_akg_type) { + supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); + } + supported_type_lists = supported_type_lists + supported_akg_type_list + "], out["; + supported_akg_type_list.clear(); + for (auto type : supported_akg_type_out) { + supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); + } + supported_type_lists = supported_type_lists + supported_akg_type_list + "]; "; + } + return supported_type_lists; +} + +bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr &selected_kernel_info) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(selected_kernel_info); + std::vector> kernel_info_list; + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG); + if (op_info_ptr == nullptr) { + MS_LOG(ERROR) << "Not find op[" << op_name << "] in akg"; + return false; + } + if (!ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list)) { + MS_LOG(EXCEPTION) << "Parsed metadata of op[" << op_name << "] failed."; + } + if (kernel_info_list.empty()) { + MS_LOG(EXCEPTION) << "Akg dose not has metadata of op[" << op_name << "]."; + } + + bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(), + [&](const std::shared_ptr &alternative_kernel_info) { + return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); + }); + if (!match) { + MS_LOG(ERROR) << "Not find op[" << op_name << "] in akg"; + return false; + } + return true; +} + +void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + auto input_kernel_node = kernel_node->input(input_index + 1); + MS_EXCEPTION_IF_NULL(input_kernel_node); + if (!input_kernel_node->isa()) { + continue; + } + std::shared_ptr builder = + std::make_shared(); + + auto param = input_kernel_node->cast(); + MS_EXCEPTION_IF_NULL(param); + if (!AnfAlgo::IsParameterWeight(param)) { + std::vector output_format = {kOpFormat_DEFAULT}; + builder->SetOutputsFormat(output_format); + std::vector output_type = {AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)}; + builder->SetOutputsDeviceType(output_type); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); + continue; + } + if ((AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) || + (AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) { + std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; + builder->SetOutputsFormat(output_format); + std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; + builder->SetOutputsDeviceType(output_type); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); + } + } +} +} // namespace + +void SetKernelInfo(const CNodePtr &kernel_node) { + std::vector inputs_format; + std::vector inputs_type; + std::shared_ptr builder = + std::make_shared(); + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + inputs_format.emplace_back(kOpFormat_DEFAULT); + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); + } + builder->SetInputsFormat(inputs_format); + builder->SetInputsDeviceType(inputs_type); + std::vector outputs_format; + std::vector outputs_type; + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + outputs_format.emplace_back(kOpFormat_DEFAULT); + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); + } + builder->SetOutputsFormat(outputs_format); + builder->SetOutputsDeviceType(outputs_type); + + bool result = + kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build()); + KernelType kernel_type = UNKNOWN_KERNEL_TYPE; + + if (!result) { + result = SelectAkgKernel(kernel_node, builder->Build()); + kernel_type = AKG_KERNEL; + } + + if (!result) { + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + std::string build_type = "in ["; + std::for_each(std::begin(inputs_type), std::end(inputs_type), + [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); + build_type += "] out ["; + std::for_each(std::begin(outputs_type), std::end(outputs_type), + [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); + build_type += "]"; + auto supported_type_lists = SupportedTypeList(kernel_node); + MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name + << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists + << ", but get " << build_type; + } + builder->SetKernelType(kernel_type); + builder->SetProcessor(kernel::Processor::CUDA); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); + SetTensorDeviceInfo(*(builder->Build()), kernel_node); +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h similarity index 100% rename from mindspore/ccsrc/device/gpu/kernel_info_setter.h rename to mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h diff --git a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc new file mode 100644 index 0000000000..4605a0eb4e --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/mpi/mpi_initializer.h" + +#include +#include +#include + +namespace mindspore { +namespace device { +namespace gpu { +MPIInitializer::MPIInitializer() { + int init_flag = 0; + if (MPI_Initialized(&init_flag) != MPI_SUCCESS) { + return; + } + if (init_flag == 0) { + auto ret = MPI_Init(nullptr, nullptr); + if (ret != MPI_SUCCESS) { + return; + } + } + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_); + MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); +} + +MPIInitializer::~MPIInitializer() { + int finalized_flag = 0; + (void)MPI_Finalized(&finalized_flag); + if (finalized_flag == 0) { + (void)MPI_Finalize(); + } +} + +MPIInitializer &MPIInitializer::GetInstance() { + static MPIInitializer instance; + return instance; +} + +int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; } + +int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; } + +PYBIND11_MODULE(_ms_mpi, mpi_initializer) { + mpi_initializer.doc() = "mindspore mpi python wrapper"; + mpi_initializer.def("get_rank_id", &MPIInitializer::get_rank_id, "get rank id"); + mpi_initializer.def("get_rank_size", &MPIInitializer::get_rank_size, "get rank size"); +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/mpi/mpi_initializer.h b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h similarity index 100% rename from mindspore/ccsrc/device/gpu/mpi/mpi_initializer.h rename to mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h diff --git a/mindspore/ccsrc/device/gpu/readme.md b/mindspore/ccsrc/runtime/device/gpu/readme.md similarity index 100% rename from mindspore/ccsrc/device/gpu/readme.md rename to mindspore/ccsrc/runtime/device/gpu/readme.md diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc new file mode 100644 index 0000000000..bb1f7f723e --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc @@ -0,0 +1,591 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/kernel_adjust.h" + +#include +#include +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" +#include "common/trans.h" +#include "utils/config_manager.h" +#include "common/utils.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "utils/utils.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "runtime/base.h" +#include "runtime/device/ascend/ascend_stream_assign.h" + +namespace mindspore { +namespace device { +using device::ascend::ProfilingUtils; +void KernelAdjust::ReorderGetNext(const std::shared_ptr &kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + const std::vector &origin_cnode_list = kernel_graph_ptr->execution_order(); + std::vector getnext_list; + std::vector other_list; + for (const auto &cnode : origin_cnode_list) { + if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { + getnext_list.emplace_back(cnode); + } else { + other_list.emplace_back(cnode); + } + } + std::vector new_order_list; + new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end()); + new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end()); + kernel_graph_ptr->set_execution_order(new_order_list); +} + +bool KernelAdjust::NeedInsertSwitch() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && + ConfigManager::GetInstance().iter_num() > 1); +} + +CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto send_op = std::make_shared(kSendOpName); + MS_EXCEPTION_IF_NULL(send_op); + auto send_apply = std::make_shared(send_op); + MS_EXCEPTION_IF_NULL(send_apply); + std::vector send_input_list = {send_apply}; + CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); + MS_EXCEPTION_IF_NULL(send_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + send_node_ptr->set_abstract(abstract_none); + return send_node_ptr; +} + +CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto recv_op = std::make_shared(kRecvOpName); + MS_EXCEPTION_IF_NULL(recv_op); + auto recv_apply = std::make_shared(recv_op); + MS_EXCEPTION_IF_NULL(recv_apply); + std::vector recv_input_list = {recv_apply}; + CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); + MS_EXCEPTION_IF_NULL(recv_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + recv_node_ptr->set_abstract(abstract_none); + return recv_node_ptr; +} + +void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { + device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); + resource_manager.ResetResource(); + if (!NeedInsertSwitch()) { + return; + } + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX; + ReorderGetNext(kernel_graph_ptr); + std::map switch_loop_input; + CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); + + std::vector *mute_inputs = kernel_graph_ptr->MutableInputs(); + MS_EXCEPTION_IF_NULL(mute_inputs); + mute_inputs->push_back(switch_loop_input[kLoopCountParamName]); + mute_inputs->push_back(switch_loop_input[kEpochParamName]); + mute_inputs->push_back(switch_loop_input[kIterLoopParamName]); + mute_inputs->push_back(switch_loop_input[kZeroParamName]); + mute_inputs->push_back(switch_loop_input[kOneParamName]); + for (const auto &input : kernel_graph_ptr->inputs()) { + MS_EXCEPTION_IF_NULL(input); + if (input->isa()) { + ParameterPtr param_ptr = input->cast(); + if (param_ptr == nullptr) { + MS_EXCEPTION(NotSupportError) << "Cast to parameter point failed !"; + } + } + } + + const std::vector &orders = kernel_graph_ptr->execution_order(); + if (orders.empty()) { + MS_LOG(EXCEPTION) << "graph execution order is empty"; + } + + std::vector exec_order; + std::vector getnext_active_streams; + std::vector fpbp_active_streams; + CNodePtr getnext_cnode; + uint32_t eos_done_event_id = UINT32_MAX; + + // getnext loop process + // getnext loop stream switch op + CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(getnext_switch_app); + uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); + exec_order.push_back(getnext_switch_app); + + // getnext op + uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); + size_t i = 0; + for (; i < orders.size(); i++) { + auto node = orders[i]; + exec_order.push_back(node); + AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); + if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { + getnext_cnode = node; + break; + } + } + + // update getnext loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); + + // getnext loop fpbp start send + uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); + CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id); + AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get()); + exec_order.push_back(fpbp_start_send); + + if (eos_mode) { + // getnext loop eos start send + uint32_t eos_start_event_id = resource_manager.ApplyNewEvent(); + CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id); + AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get()); + exec_order.push_back(eos_start_send); + + // End Of Sequence loop process + // eos loop stream switch + CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(eos_switch_app); + uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), eos_switch_app); + exec_order.push_back(eos_switch_app); + + // eos loop eos start recv + CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id); + uint32_t eos_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get()); + exec_order.push_back(eos_start_recv); + + // update eos loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(eos_stream_id), eos_switch_app); + + // EndOfSequence op + CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); + MS_EXCEPTION_IF_NULL(end_of_sequence_op); + AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get()); + exec_order.push_back(end_of_sequence_op); + + // eos loop eos done send + eos_done_event_id = resource_manager.ApplyNewEvent(); + CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id); + AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get()); + exec_order.push_back(eos_done_send); + + // eos loop stream active + fpbp_active_streams.push_back(eos_switch_stream_id); + } + + // fpbp loop process + // fpbp loop stream switch + CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(fpbp_switch_app); + uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), fpbp_switch_app); + exec_order.push_back(fpbp_switch_app); + + // fpbp loop fpbp start recv + CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id); + uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get()); + exec_order.push_back(fpbp_start_recv); + + // update fpbp loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(fpbp_stream_id), fpbp_switch_app); + + // fpbp loop AssignAdd + CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(assign_add_one); + AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get()); + exec_order.push_back(assign_add_one); + + // fpbp memcpy + std::vector memcpy_list; + std::vector other_list; + CNodePtr cur_cnode = nullptr; + for (size_t idx = i + 1; idx < orders.size(); idx++) { + cur_cnode = orders[idx]; + if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { + memcpy_list.emplace_back(cur_cnode); + } else { + other_list.emplace_back(cur_cnode); + } + } + + (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); + + // fpbp loop eos done recv + if (eos_mode) { + CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id); + AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get()); + exec_order.push_back(eos_done_recv); + } + + // stream active to activate getnext loop + CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(getnext_active_app); + getnext_active_streams.push_back(getnext_switch_stream_id); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(getnext_active_streams), + getnext_active_app); + exec_order.push_back(getnext_active_app); + + // fpbp loop other ops + (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); + + // stream active to activate fpbp loop and eos loop + CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(fpbp_active_app); + fpbp_active_streams.push_back(fpbp_switch_stream_id); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(fpbp_active_streams), fpbp_active_app); + exec_order.push_back(fpbp_active_app); + + kernel_graph_ptr->set_execution_order(exec_order); +} + +void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, + std::map *switch_loop_input) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(switch_loop_input); + std::vector shp = {1}; + tensor::TensorPtr tensor_ptr = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(tensor_ptr); + mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract(); + if (paremeter_abstract_ptr == nullptr) { + MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!"; + } + + ParameterPtr loop_count = std::make_shared(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(loop_count); + loop_count->set_name(kLoopCountParamName); + loop_count->set_abstract(paremeter_abstract_ptr); + ParameterPtr loop_count_new = kernel_graph_ptr->NewParameter(loop_count); + + (*switch_loop_input)[kLoopCountParamName] = loop_count_new; + + ParameterPtr iter_loop = std::make_shared(kernel_graph_ptr); + iter_loop->set_name(kIterLoopParamName); + iter_loop->set_abstract(paremeter_abstract_ptr); + ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop); + (*switch_loop_input)[kIterLoopParamName] = iter_loop_new; + + ParameterPtr zero = std::make_shared(kernel_graph_ptr); + zero->set_name(kZeroParamName); + zero->set_abstract(paremeter_abstract_ptr); + ParameterPtr zero_new = kernel_graph_ptr->NewParameter(zero); + (*switch_loop_input)[kZeroParamName] = zero_new; + + ParameterPtr one = std::make_shared(kernel_graph_ptr); + one->set_name(kOneParamName); + one->set_abstract(paremeter_abstract_ptr); + ParameterPtr one_new = kernel_graph_ptr->NewParameter(one); + (*switch_loop_input)[kOneParamName] = one_new; + + ParameterPtr epoch = std::make_shared(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(epoch); + epoch->set_name(kEpochParamName); + epoch->set_abstract(paremeter_abstract_ptr); + ParameterPtr epoch_new = kernel_graph_ptr->NewParameter(epoch); + (*switch_loop_input)[kEpochParamName] = epoch_new; +} + +kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBuilder( + const std::vector &formats, const std::vector &type_ids) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetInputsFormat(formats); + selected_kernel_builder.SetInputsDeviceType(type_ids); + + selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); + selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + return selected_kernel_builder; +} + +CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, + const std::map &switch_loop_input) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( + {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); + auto typeNone_abstract = std::make_shared(); + auto stream_switch = std::make_shared(kStreamSwitchOpName); + std::vector inputs; + inputs.push_back(NewValueNode(stream_switch)); + inputs.push_back(switch_loop_input.at(kLoopCountParamName)); + inputs.push_back(switch_loop_input.at(kIterLoopParamName)); + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(stream_switch_app); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_switch_app.get()); + stream_switch_app->set_abstract(typeNone_abstract); + // set attr: cond_ RT_LESS + int condition = static_cast(RT_LESS); + ValuePtr cond = MakeValue(condition); + AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); + // set attr:data_type + int data_type = static_cast(RT_SWITCH_INT64); + ValuePtr dt = MakeValue(data_type); + AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); + // set distinction label and graph id + return stream_switch_app; +} + +CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( + {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); + abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); + auto stream_active_others = std::make_shared(kStreamActiveOpName); + std::vector inputs; + inputs.push_back(NewValueNode(stream_active_others)); + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(stream_active_others_app); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); + stream_active_others_app->set_abstract(typeNone_abstract); + return stream_active_others_app; +} + +CNodePtr KernelAdjust::CreatTupleGetItemNode(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &node, size_t output_idx) { + auto idx = NewValueNode(SizeToInt(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(output_idx)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + tuple_getitem->set_scope(node->scope()); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); + return tuple_getitem; +} + +CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &getnext_cnode) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT}); + selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8}); + + selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); + selected_kernel_builder.SetProcessor(kernel::Processor::AICPU); + selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL); + + selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); + selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8}); + // EndOfSequence + auto end_of_sequence = std::make_shared(kEndOfSequence); + std::vector inputs; + inputs.push_back(NewValueNode(end_of_sequence)); + // GetNext output 0 is EndOfSequence's input + auto tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0); + inputs.push_back(tuple_get_item); + CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(end_of_sequence_node); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get()); + std::vector input_names = {"x"}; + ValuePtr input_names_v = MakeValue(input_names); + AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node); + std::vector output_names = {"y"}; + ValuePtr output_names_v = MakeValue(output_names); + AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node); + end_of_sequence_node->set_abstract(tuple_get_item->abstract()); + return end_of_sequence_node; +} + +CNodePtr KernelAdjust::CreateStreamAssignAddnOP( + const std::shared_ptr &kernel_graph_ptr, + const std::map &switch_loop_input) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( + {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); + selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); + selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32}); + // AssignAdd + auto assign_add = std::make_shared(kAssignAddOpName); + std::vector inputs; + inputs.push_back(NewValueNode(assign_add)); + inputs.push_back(switch_loop_input.at(kLoopCountParamName)); + inputs.push_back(switch_loop_input.at(kOneParamName)); + CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(assign_add_one); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_add_one.get()); + std::vector input_names = {"ref", "value"}; + std::vector output_names = {"output"}; + ValuePtr input_names_v = MakeValue(input_names); + ValuePtr output_names_v = MakeValue(output_names); + AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one); + AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one); + selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); + MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); + assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); + return assign_add_one; +} + +bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr) { + if (!NeedInsertSwitch()) { + return true; + } + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + auto input_nodes = kernel_graph_ptr->inputs(); + std::vector inputs; + LoadSwitchInputs(&inputs); + std::shared_ptr> inputsPtr = std::make_shared>(inputs); + kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr); + size_t input_ctrl_size = inputs.size(); + // inputs_node:include four ctrl nodes in the back. such as:conv,loop_cnt, ites_loop, zero, one. + // deal four ctrl nodes. + for (size_t i = 0; i < inputs.size(); ++i) { + auto tensor = inputs[i]; + size_t deal_index = input_nodes.size() - input_ctrl_size + i; + if (deal_index >= input_nodes.size()) { + MS_LOG(EXCEPTION) << "deal_index[" << deal_index << "] out of range"; + } + auto input_node = input_nodes[deal_index]; + bool need_sync = false; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa()) { + auto pk_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(tensor); + MS_EXCEPTION_IF_NULL(pk_node); + if (tensor->is_dirty() || !pk_node->has_default()) { + need_sync = true; + } + } + if (need_sync) { + auto pk_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(pk_node); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + MS_EXCEPTION_IF_NULL(device_address); + tensor->set_device_address(device_address); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(INFO) << "SyncHostToDevice failed."; + return false; + } + } + tensor->set_dirty(false); + } + return true; +} + +void KernelAdjust::LoadSwitchInputs(std::vector *inputs) { + MS_LOG(INFO) << "---------------- LoadSwitchInputs---"; + MS_EXCEPTION_IF_NULL(inputs); + std::vector shp = {1}; + tensor::TensorPtr loop_count_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(loop_count_tensor); + int32_t *val = nullptr; + val = static_cast(loop_count_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + inputs->push_back(loop_count_tensor); + + // Epoch in device + tensor::TensorPtr epoch_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(epoch_tensor); + val = static_cast(epoch_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + inputs->push_back(epoch_tensor); + + tensor::TensorPtr iter_loop_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(iter_loop_tensor); + val = static_cast(iter_loop_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num())); + MS_LOG(INFO) << "iter_loop_tensor = " << *val; + inputs->push_back(iter_loop_tensor); + + tensor::TensorPtr zero_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(zero_tensor); + val = static_cast(zero_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + inputs->push_back(zero_tensor); + + tensor::TensorPtr one_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(one_tensor); + val = static_cast(one_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 1; + inputs->push_back(one_tensor); + + MS_LOG(INFO) << "---------------- LoadSwitchInputs End--"; +} + +void KernelAdjust::Profiling(NotNull kernel_graph_ptr) { + if (!ascend::ProfilingManager::GetInstance().IsProfiling()) { + MS_LOG(INFO) << "No need to profiling"; + return; + } + ProfilingTraceInfo profiling_trace_info = ProfilingUtils::GetProfilingTraceFromEnv(kernel_graph_ptr); + if (!profiling_trace_info.IsValid()) { + MS_LOG(WARNING) << "[profiling] no profiling node found!"; + return; + } + InsertProfilingKernel(profiling_trace_info, kernel_graph_ptr); +} + +void KernelAdjust::InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, + NotNull kernel_graph_ptr) { + MS_LOG(INFO) << "[profiling] Insert profiling kernel start"; + if (!profiling_trace_info.IsValid()) { + MS_LOG(WARNING) << "Profiling trace point not found"; + return; + } + std::vector new_cnode_list; + std::vector cnode_ptr_list = kernel_graph_ptr->execution_order(); + if (cnode_ptr_list.empty()) { + MS_LOG(ERROR) << "No CNode in graph"; + return; + } + for (const auto &cnode_ptr : cnode_ptr_list) { + ProfilingUtils::ProfilingTraceFpStart(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); + new_cnode_list.emplace_back(cnode_ptr); + ProfilingUtils::ProfilingCustomOp(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); + ProfilingUtils::ProfilingTraceBpEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); + ProfilingUtils::ProfilingTraceEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); + } + kernel_graph_ptr->set_execution_order(new_cnode_list); +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.h b/mindspore/ccsrc/runtime/device/kernel_adjust.h new file mode 100644 index 0000000000..dbd6f226af --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ + +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/session/session_context.h" +#include "ir/tensor.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "runtime/device/kernel_info.h" + +using mindspore::device::ascend::ProfilingTraceInfo; +using mindspore::device::ascend::ProfilingUtils; +namespace mindspore { +constexpr auto kLoopCountParamName = "loop_count"; +constexpr auto kIterLoopParamName = "iter_loop"; +constexpr auto kZeroParamName = "zero"; +constexpr auto kOneParamName = "one"; +constexpr auto kEpochParamName = "loop_epoch"; +constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; +constexpr uint32_t kSecondStreamSwitchLabel = 2; + +namespace device { +class KernelAdjust { + public: + static KernelAdjust &GetInstance() { + static KernelAdjust instance; + return instance; + } + + void InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr); + bool StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr); + void Profiling(NotNull kernel_graph_ptr); + static bool NeedInsertSwitch(); + CNodePtr CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr); + + private: + KernelAdjust() = default; + ~KernelAdjust() = default; + + void ReorderGetNext(const std::shared_ptr &kernel_graph_ptr); + CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + void CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, + std::map *switch_loop_input); + CNodePtr CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, + const std::map &switch_loop_input); + CNodePtr CreatTupleGetItemNode(const std::shared_ptr &kernel_graph_ptr, const CNodePtr &node, + size_t output_idx); + CNodePtr CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &getnext_cnode); + CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr &kernel_graph_ptr, + const std::map &switch_loop_input); + kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector &formats, + const std::vector &type_ids); + void LoadSwitchInputs(std::vector *inputs); + void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, + NotNull kernel_graph_ptr); +}; +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_info.cc b/mindspore/ccsrc/runtime/device/kernel_info.cc new file mode 100644 index 0000000000..692532e70b --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_info.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace device { +const kernel::KernelBuildInfo *KernelInfo::select_kernel_build_info() const { return select_kernel_build_info_.get(); } + +kernel::KernelBuildInfoPtr KernelInfo::GetMutableSelectKernelBuildInfo() const { return select_kernel_build_info_; } + +const DeviceAddress *KernelInfo::GetOutputAddr(size_t index) const { + if (index >= output_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return nullptr; + } + return output_address_list_[index].get(); +} + +DeviceAddressPtr KernelInfo::GetMutableOutputAddr(size_t index) const { + if (index >= output_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return nullptr; + } + return output_address_list_[index]; +} + +bool KernelInfo::OutputAddrExist(size_t index) const { + if (index >= output_address_list_.size()) { + return false; + } + return output_address_list_[index] != nullptr; +} + +bool KernelInfo::SetOutputAddr(const DeviceAddressPtr &output_address, size_t index) { + // parameter and valuenode + if (kernel_mod_ == nullptr && index >= output_address_list_.size()) { + for (size_t i = output_address_list_.size(); i <= index; i++) { + output_address_list_.emplace_back(nullptr); + } + } else if (output_address_list_.empty()) { + // set cnode + for (size_t i = 0; i < kernel_mod_->GetOutputSizeList().size(); i++) { + output_address_list_.emplace_back(nullptr); + } + } + if (index >= output_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return false; + } + output_address_list_[index] = output_address; + return true; +} + +DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const { + if (index >= workspace_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return nullptr; + } + return workspace_address_list_[index].get(); +} + +bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { + if (workspace_address_list_.empty()) { + // parameter and valuenode + if (kernel_mod_ == nullptr) { + workspace_address_list_.emplace_back(nullptr); + } else { + // set cnode + for (size_t i = 0; i < kernel_mod_->GetWorkspaceSizeList().size(); i++) { + workspace_address_list_.emplace_back(nullptr); + } + } + } + if (index >= workspace_address_list_.size()) { + MS_LOG(ERROR) << "Index" << index << " out of range"; + return false; + } + workspace_address_list_[index] = output_address; + return true; +} + +void KernelInfo::set_kernel_mod(const kernel::KernelModPtr &kernel_mod) { kernel_mod_ = kernel_mod; } + +kernel::KernelMod *KernelInfo::MutableKernelMod() const { return kernel_mod_.get(); } + +const kernel::KernelMod *KernelInfo::kernel_mod() const { return kernel_mod_.get(); } + +bool KernelInfo::operator==(const KernelInfo &other) const { + if (stream_id_ != other.stream_id_ || stream_distinction_label_ != other.stream_distinction_label_ || + graph_id_ != other.graph_id_) { + return false; + } + if ((select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ == nullptr) || + (select_kernel_build_info_ == nullptr && other.select_kernel_build_info_ != nullptr)) { + return false; + } + if (select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ != nullptr) { + if (!(*select_kernel_build_info_ == *(other.select_kernel_build_info_))) { + return false; + } + } + // Currently we only check whether both the kernel_mod_ are initialized or uninitialized. + if ((kernel_mod_ == nullptr && other.kernel_mod_ != nullptr) || + (kernel_mod_ != nullptr && other.kernel_mod_ == nullptr)) { + return false; + } + // Currently we only check whether both the sizes are equal of output_address_list_ and workspace_address_list_ or + // not. We can complete this check in the future. + if (output_address_list_.size() != other.output_address_list_.size() || + workspace_address_list_.size() != other.workspace_address_list_.size()) { + return false; + } + return true; +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_info.h b/mindspore/ccsrc/runtime/device/kernel_info.h new file mode 100644 index 0000000000..baded9d9a3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_info.h @@ -0,0 +1,87 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_DEVICE_KERNEL_INFO_H_ +#define MINDSPORE_DEVICE_KERNEL_INFO_H_ + +#include +#include +#include "ir/kernel_info_dev.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "runtime/device/ascend/ascend_device_address.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +const uint32_t kInvalidGraphId = UINT32_MAX; +const uint32_t kInvalidDistincLabel = UINT32_MAX; +namespace device { +class KernelInfo : public KernelInfoDevice { + public: + KernelInfo() { + kernel_mod_ = nullptr; + is_feature_map_ = false; + select_kernel_build_info_ = nullptr; + output_address_list_ = {}; + workspace_address_list_ = {}; + stream_id_ = UINT32_MAX; + stream_distinction_label_ = kInvalidDistincLabel; + graph_id_ = kInvalidGraphId; + } + virtual ~KernelInfo() = default; + + bool has_build_info() const override { return select_kernel_build_info() != nullptr; } + const kernel::KernelBuildInfo *select_kernel_build_info() const; + kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const; + void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { + select_kernel_build_info_ = select_kernel_build_info; + } + void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; } + const DeviceAddress *GetOutputAddr(size_t index) const; + DeviceAddressPtr GetMutableOutputAddr(size_t index) const; + bool OutputAddrExist(size_t index) const; + bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); + DeviceAddress *GetWorkspaceAddr(size_t index) const; + bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); + void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); + kernel::KernelMod *MutableKernelMod() const; + const kernel::KernelMod *kernel_mod() const; + uint32_t stream_id() const { return stream_id_; } + void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; } + uint32_t stream_distinction_label() const { return stream_distinction_label_; } + void set_stream_distinction_label(uint32_t stream_distinction_label) { + stream_distinction_label_ = stream_distinction_label; + } + void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } + uint32_t graph_id() const { return graph_id_; } + bool operator==(const KernelInfo &other) const; + bool is_feature_map() const { return is_feature_map_; } + + private: + bool is_feature_map_; + kernel::KernelBuildInfoPtr select_kernel_build_info_; + std::vector> output_address_list_; + std::vector> workspace_address_list_; + kernel::KernelModPtr kernel_mod_; + // stream_id_ is the index of stream object vector + uint32_t stream_id_; + // stream_distinction_label_ is used mark different op in different stream + uint32_t stream_distinction_label_; + // record which graph the node belong to + uint32_t graph_id_; +}; +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_DEVICE_KERNEL_INFO_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc new file mode 100644 index 0000000000..3de9af8c23 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -0,0 +1,775 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/kernel_runtime.h" +#include +#include +#include +#include +#include "common/utils.h" +#include "common/trans.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "ir/value.h" +using mindspore::kernel::Address; +using mindspore::kernel::AddressPtr; + +namespace mindspore { +namespace device { +KernelRuntime::~KernelRuntime() { +#ifdef ENABLE_DUMP_E2E + dump_conf_ptr_ = nullptr; +#endif +} + +bool KernelRuntime::Run(session::KernelGraph *graph) { + bool ret = false; + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); +#endif + bool is_task_sink = context_ptr->enable_task_sink(); + if (is_task_sink) { + ret = RunTask(graph); + } else { + ret = LaunchKernel(graph); + } +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Call MS Run Success in " << cost << " us"; +#endif + return ret; +} + +// for D to impl +bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { + if (graph != nullptr) { + return true; + } + return false; +} + +// for D to impl +bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { + if (graph != nullptr) { + return true; + } + return false; +} + +// for D to impl +bool KernelRuntime::GenTask(const session::KernelGraph *graph) { + if (graph != nullptr) { + return true; + } + return false; +} + +bool KernelRuntime::LoadTask(const session::KernelGraph *graph) { + if (graph != nullptr) { + return true; + } + return false; +} + +// for D to impl +bool KernelRuntime::RunTask(const session::KernelGraph *graph) { + if (graph != nullptr) { + return true; + } + return false; +} + +bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { + MS_EXCEPTION_IF_NULL(kernel); + if (AnfAlgo::OutputAddrExist(kernel, index)) { + return true; + } + return false; +} + +size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { + MS_EXCEPTION_IF_NULL(node); + if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { + MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" + << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; + } + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); + } + size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); + std::vector shape = AnfAlgo::GetOutputDeviceShape(node, output_index); + auto format = AnfAlgo::GetOutputFormat(node, output_index); + if (shape.empty() && format != kOpFormat_DEFAULT) { + shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index)); + shape = trans::TransShapeToDevice(shape, format); + } + // scalar's output shape is a empty vector + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + return tensor_size; +} + +void KernelRuntime::AssignMemory(session::KernelGraph *graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(mem_manager_); + mem_manager_->ResetDynamicMemory(); + AssignStaticMemory(graph); + AssignDynamicMemory(graph); + UpdateRefNodeOutputMem(graph); +} + +void KernelRuntime::RunOpAssignMemory(const std::vector &input_tensors, + session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + RunOpAssignInputMemory(input_tensors, graph); + AssignStaticMemoryValueNode(graph); + for (const auto &cnode : graph->execution_order()) { + RunOpAssignOutputMemory(cnode); + RunOpAssignWorkSpaceMemory(cnode); + } + UpdateRefNodeOutputMem(graph); +} + +void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + // clear input parameter memory resource + for (const auto &input_node : graph->inputs()) { + MS_EXCEPTION_IF_NULL(input_node); + AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get()); + } + // clear input value node memory resource + for (const auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get()); + } + for (const auto &cnode : graph->execution_order()) { + MS_EXCEPTION_IF_NULL(cnode); + // clear output memory resource + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { + AnfAlgo::SetOutputAddr(nullptr, index, cnode.get()); + } + // clear workspace memory resource + auto kernel_mod = AnfAlgo::GetKernelMod(cnode); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_lists = kernel_mod->GetWorkspaceSizeList(); + for (size_t index = 0; index < workspace_lists.size(); ++index) { + AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get()); + } + } +} + +void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) { + AssignStaticMemoryInput(graph); + AssignStaticMemoryValueNode(graph); + AssignStaticMemoryOutput(graph); +} + +void KernelRuntime::RunOpAssignInputMemory(const std::vector &input_tensors, + const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_manager_); + if (input_tensors.size() != graph->inputs().size()) { + MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() + << " should be equal to graph input parameter size " << graph->inputs().size(); + } + + for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) { + auto item = graph->inputs()[input_index]; + MS_EXCEPTION_IF_NULL(item); + if (!item->isa()) { + continue; + } + auto output_size = AnfAlgo::GetOutputTensorNum(item); + for (size_t index = 0; index < output_size; index++) { + MS_EXCEPTION_IF_NULL(input_tensors[input_index]); + auto output_address = + std::dynamic_pointer_cast(input_tensors[input_index]->device_address()); + if (output_address != nullptr) { + AnfAlgo::SetOutputAddr(output_address, index, item.get()); + continue; + } + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(item, index); + } + auto tensor_size = CountNodeDeviceMemorySize(item, index); + auto device_address = + CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + MS_EXCEPTION_IF_NULL(device_address); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size); + if (!ret) { + MS_LOG(EXCEPTION) << "Malloc device memory failed."; + } + AnfAlgo::SetOutputAddr(device_address, index, item.get()); + } + } +} + +void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + return; + } + + for (size_t i = 0; i < output_sizes.size(); ++i) { + if (AnfAlgo::OutputAddrExist(kernel, i)) { + continue; + } + if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); + continue; + } + std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); + device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i)); + MS_EXCEPTION_IF_NULL(device_address); + auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); + if (!ret) { + MS_LOG(EXCEPTION) << "Malloc device memory failed."; + } + AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); + } +} + +void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(mem_manager_); + if (kernel->isa()) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_lists = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_lists.size(); ++i) { + auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown); + MS_EXCEPTION_IF_NULL(device_address); + auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]); + if (!ret) { + MS_LOG(EXCEPTION) << "Malloc device memory failed."; + } + AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); + } + } +} + +void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto graph_inputs = graph->inputs(); + auto graph_valid_input = graph->valid_inputs(); + graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); + std::vector need_alloc_nodes; + for (size_t i = 0; i < graph_inputs.size(); ++i) { + auto item = graph_inputs[i]; + MS_EXCEPTION_IF_NULL(item); + if (i < graph_valid_input.size() && !graph_valid_input[i]) { + continue; + } + + if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { + auto outs = AnfAlgo::GetAllOutput(item); + for (auto &out : outs) { + MS_EXCEPTION_IF_NULL(out); + if (!out->isa()) { + continue; + } + if (NodeOutputDeviceAddressExist(out, 0)) { + continue; + } + need_alloc_nodes.push_back(out); + } + } + if (!item->isa()) { + continue; + } + if (NodeOutputDeviceAddressExist(item, 0)) { + continue; + } + need_alloc_nodes.push_back(item); + } + + for (auto &item : need_alloc_nodes) { + auto output_size = AnfAlgo::GetOutputTensorNum(item); + for (size_t index = 0; index < output_size; index++) { + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); + // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown + if (output_type_id == kTypeUnknown) { + MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; + output_type_id = AnfAlgo::GetOutputInferDataType(item, index); + } + auto tensor_size = CountNodeDeviceMemorySize(item, index); + auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); + auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + AnfAlgo::SetOutputAddr(address, index, item.get()); + } + } +} + +void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); + std::vector non_communication_op; + // Assign Communicate Op Memory firstly. + for (const auto &node : nodes) { + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + MS_EXCEPTION_IF_NULL(item_with_index.first); + if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { + continue; + } + graph->AddFinalOutputKernel(item_with_index.first); + if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { + AssignCommunicationNodeMem(kStaticMem, item_with_index.first); + } else { + non_communication_op.emplace_back(item_with_index); + } + } + + for (const auto &item_with_index : non_communication_op) { + AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second)); + } +} + +void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + MS_LOG(INFO) << "This kernel has no output size."; + continue; + } + for (size_t i = 0; i < output_sizes.size(); ++i) { + session::AnfWithOutIndex out_pair(kernel, i); + if (graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = graph->GetRefCorrespondOutput(out_pair); + MS_EXCEPTION_IF_NULL(origin_pair.first); + auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second); + MS_EXCEPTION_IF_NULL(origin_node_output_addr); + auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i); + if (origin_node_output_addr.get() != cur_node_output_addr.get()) { + MS_LOG(INFO) << "REF address is not same, ref node output need address update"; + MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is " + << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i; + AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get()); + } + } + } + } +} + +void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) { + AssignCommunicationNodeInputMem(node); + AssignCommunicationNodeOutputMem(flag, node); +} + +void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto kernel_mod = AnfAlgo::GetKernelMod(node); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size."; + return; + } + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + size_t total_size = 0; + size_t output_index = 0; + std::vector align_size_list; + for (uint64_t mem_size : output_sizes) { + if (AnfAlgo::OutputAddrExist(node, output_index++)) { + MS_LOG(INFO) << "communication op addr exist"; + continue; + } + if (context_ptr->enable_hccl()) { + mem_size = mem_manager_->GetCommonAlignSize(mem_size); + } + total_size += mem_size; + align_size_list.emplace_back(mem_size); + } + uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size); + for (size_t j = 0; j < align_size_list.size(); ++j) { + std::string output_format = AnfAlgo::GetOutputFormat(node, j); + auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j); + auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type); + MS_EXCEPTION_IF_NULL(address); + if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) { + address->UpdateCommunicationAddress(); + } + AnfAlgo::SetOutputAddr(address, j, node.get()); + output_ptr += align_size_list[j]; + } +} + +DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + auto kernel_mod = AnfAlgo::GetKernelMod(anf_node); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.size() <= index) { + MS_LOG(EXCEPTION) << "Previous node output size < node index"; + } + std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index); + auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index); + auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type); + AnfAlgo::SetOutputAddr(address, index, anf_node.get()); + return address; +} + +void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + size_t total_size = 0; + std::vector> addr_size; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); + auto input_node = input_node_with_index.first; + DeviceAddressPtr address = nullptr; + if (input_node->isa()) { + address = PreAssignCNodeMemory(input_node, input_node_with_index.second); + } else { + MS_LOG(EXCEPTION) << "Communication node inputs only support CNode"; + } + MS_EXCEPTION_IF_NULL(address); + auto mem_size = mem_manager_->GetCommonAlignSize(address->size()); + total_size += mem_size; + addr_size.emplace_back(address.get(), mem_size); + } + uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size); + for (const auto &iter : addr_size) { + MS_EXCEPTION_IF_NULL(iter.first); + iter.first->set_ptr(input_ptr); + input_ptr += iter.second; + } +} + +void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { + MS_LOG(INFO) << "GetNext disable mem_reuse"; + flag = kDynamicMem; + } + auto kernel_mod = AnfAlgo::GetKernelMod(node); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size."; + return; + } + for (size_t i = 0; i < output_sizes.size(); ++i) { + if ((kGetAllOuts != index) && (SizeToInt(i) != index)) { + continue; + } + if (NodeOutputDeviceAddressExist(node, i)) { + MS_LOG(INFO) << "Already malloc index:" << i; + continue; + } + auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]); + if (ptr == nullptr) { + // reused ptr, no need alloc, continue; + continue; + } + std::string output_format = AnfAlgo::GetOutputFormat(node, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); + auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type); + MS_EXCEPTION_IF_NULL(device_address); + device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i)); + if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) { + device_address->UpdateCommunicationAddress(); + } + AnfAlgo::SetOutputAddr(device_address, i, node.get()); + } +} + +void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, + size_t output_idx) { + MS_EXCEPTION_IF_NULL(value_node); + MS_EXCEPTION_IF_NULL(node_value); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto tensor = node_value->cast(); + if (tensor == nullptr) { + MS_LOG(WARNING) << "Tensor is null"; + return; + } + size_t tensor_size = tensor->data().nbytes(); + auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); + } + auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); + DeviceAddressPtr address = nullptr; + if (ms_context->enable_pynative_infer()) { + address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); + MS_EXCEPTION_IF_NULL(address); + if (!mem_manager_->MallocMemFromMemPool(address, node_size)) { + MS_LOG(EXCEPTION) << "Malloc value node device memory failed !"; + } + } else { + auto ptr = mem_manager_->MallocMem(kStaticMem, node_size); + address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id); + MS_EXCEPTION_IF_NULL(address); + } + AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); + if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), + tensor->data_c())) { + MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is" + << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is " + << AnfAlgo::GetOutputInferDataType(value_node, output_idx); + } +} + +void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + for (auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + if (NodeOutputDeviceAddressExist(value_node, 0)) { + MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist"; + continue; + } + auto &node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + if (node_value->isa()) { + AssignValueNodeTensor(value_node, node_value, 0); + } else if (node_value->isa()) { + auto value = GetValue(node_value); + size_t tensor_size = value.size(); + DeviceAddressPtr address = nullptr; + if (ms_context->enable_pynative_infer()) { + address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); + MS_EXCEPTION_IF_NULL(address); + if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) { + MS_LOG(EXCEPTION) << "Malloc value node device memory failed !"; + } + } else { + auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); + address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); + MS_EXCEPTION_IF_NULL(address); + } + AnfAlgo::SetOutputAddr(address, 0, value_node.get()); + std::vector shape = {1, SizeToInt(tensor_size)}; + if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) { + MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!"; + } + } + } +} + +void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); + auto mem_flag = kDynamicMem; + if (is_enable_mem_reuse) { + mem_manager_->MallocReusedDynamicMem(graph); + mem_flag = kReuseDynamicMem; + } + auto &execution_nodes = graph->execution_order(); + std::vector compute_nodes; + // communication nodes first + for (auto &node : execution_nodes) { + if (AnfAlgo::IsCommunicationOp(node)) { + // skip if the memory is already alocated + AssignCommunicationNodeMem(mem_flag, node); + } else { + compute_nodes.emplace_back(node); + } + } + + // then compute nodes + for (auto &node : compute_nodes) { + AssignNodeOutputMem(mem_flag, node, kGetAllOuts); + AssignWorkSpaceMem(mem_flag, node); + } +} + +void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto kernel_mod = AnfAlgo::GetKernelMod(node); + MS_EXCEPTION_IF_NULL(kernel_mod); + size_t index = 0; + for (auto &size : kernel_mod->GetWorkspaceSizeList()) { + auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size); + AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); + index++; + } +} + +void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, + AddressPtrList *kernel_outputs) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_inputs); + MS_EXCEPTION_IF_NULL(kernel_workspaces); + MS_EXCEPTION_IF_NULL(kernel_outputs); + auto cnode = kernel->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { + return GenAddrCleanLaunchArgs(cnode, kernel_inputs); + } + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); + auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); + MS_EXCEPTION_IF_NULL(device_address); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(input->addr); + input->size = device_address->size_; + kernel_inputs->emplace_back(input); + } + + for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetOutputAddr(kernel, i); + kernel::AddressPtr output = std::make_shared(); + MS_EXCEPTION_IF_NULL(output); + output->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(output->addr); + output->size = device_address->size_; + kernel_outputs->emplace_back(output); + } + + for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); + kernel::AddressPtr workspace = std::make_shared(); + MS_EXCEPTION_IF_NULL(workspace); + workspace->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(workspace->addr); + workspace->size = device_address->size_; + kernel_workspaces->emplace_back(workspace); + } +} + +void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) { + if (cnode->inputs().size() != 2) { + MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2."; + } + MS_EXCEPTION_IF_NULL(cnode->inputs()[1]); + auto pre_node = (cnode->inputs()[1])->cast(); + // set clean output address + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { + auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); + for (auto index : clean_output_indexs) { + auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(input->addr); + input->size = device_address->size_; + kernel_inputs->emplace_back(input); + } + MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); + } + // set clean workspace address + if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { + auto clean_workspaces_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); + for (const auto &index : clean_workspaces_indexs) { + auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); + kernel::AddressPtr workspace = std::make_shared(); + MS_EXCEPTION_IF_NULL(workspace); + workspace->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(workspace->addr); + workspace->size = device_address->size_; + kernel_inputs->emplace_back(workspace); + } + } +} + +bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { + auto &kernels = graph.execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); + if (!ret) { + MS_LOG(ERROR) << "Launch kernel failed."; + return false; + } + } + return true; +} + +bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + if (!LaunchKernelMod(*graph)) { + MS_LOG(ERROR) << "LaunchKernelMod failed!"; + return false; + } + return true; +} + +void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { + MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; +} + +#ifdef ENABLE_DUMP_E2E +bool KernelRuntime::SetDumpConf() { + dump_conf_ptr_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(dump_conf_ptr_); + bool ret = dump_conf_ptr_->SetDumpConfFromJsonFile(); + return ret; +} + +DumpConfPtr KernelRuntime::GetDumpConf() { return dump_conf_ptr_; } +#endif +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h new file mode 100644 index 0000000000..8320355b82 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -0,0 +1,122 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ +#include +#include +#include +#include + +#include "runtime/device/device_address.h" +#include "ir/tensor.h" +#include "predict/generator/utils/ir_model_util.h" +#ifdef ENABLE_DUMP_E2E +#include "debug/e2e_dump.h" +#endif +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel.h" +#include "utils/context/ms_context.h" +#include "runtime/device/memory_manager.h" + +using mindspore::tensor::Tensor; +using std::vector; +using TensorPtr = std::shared_ptr; +using mindspore::kernel::AddressPtr; +using AddressPtrList = std::vector; + +namespace mindspore { +#ifndef ENABLE_DEBUGGER +class Debugger; +#endif +namespace device { +class KernelRuntime { + public: + KernelRuntime() = default; + virtual ~KernelRuntime(); + virtual bool Init() = 0; + virtual void AssignMemory(session::KernelGraph *graph); + void RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph); + void RunOpClearMemory(const session::KernelGraph *graph); + virtual bool Run(session::KernelGraph *graph); + virtual bool DumpData(session::KernelGraph *graph); + virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger); + virtual bool RunTask(const session::KernelGraph *graph); + virtual bool GenTask(const session::KernelGraph *graph); + bool LaunchKernel(const session::KernelGraph *graph); + virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); + virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); + virtual void ClearGraphRuntimeResource(uint32_t graph_id); + virtual bool SyncStream() = 0; + +#ifdef ENABLE_DUMP_E2E + DumpConfPtr GetDumpConf(); +#endif + virtual bool LoadTask(const session::KernelGraph *graph); + // for GPU and D to impl + virtual void ReleaseDeviceRes() {} + void set_device_id(uint32_t device_id) { device_id_ = device_id; } + + protected: + virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) = 0; + virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index); + void AssignStaticMemory(session::KernelGraph *graph); + void AssignDynamicMemory(session::KernelGraph *graph); + void ReuseAssignDynamicMemory(session::KernelGraph *graph); + void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); + void AssignWorkSpaceMem(int flag, const AnfNodePtr &node); + void AssignReuseWorkSpaceMem(const AnfNodePtr &node); + + void UpdateRefNodeOutputMem(const session::KernelGraph *graph); + + void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node); + void AssignCommunicationNodeInputMem(const AnfNodePtr &node); + void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node); +#ifdef ENABLE_DUMP_E2E + bool SetDumpConf(); +#endif + + private: + void AssignStaticMemoryOutput(session::KernelGraph *graph); + void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, + AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); + bool LaunchKernelMod(const session::KernelGraph &graph); + void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); + size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); + void RunOpAssignInputMemory(const std::vector &input_tensors, const session::KernelGraph *graph); + void RunOpAssignOutputMemory(const AnfNodePtr &kernel); + void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); + void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); + DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); + + protected: + uint32_t device_id_{0}; +#ifdef ENABLE_DUMP_E2E + DumpConfPtr dump_conf_ptr_; +#endif + void *stream_ = nullptr; + std::shared_ptr mem_manager_{nullptr}; +}; +using KernelRuntimePtr = std::shared_ptr; +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc new file mode 100644 index 0000000000..626259f9ce --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/kernel_runtime_manager.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +void KernelRuntimeManager::ClearRuntimeResource() { + std::lock_guard guard(lock_); + for (auto &iter : runtime_map_) { + MS_LOG(INFO) << "Release device " << iter.first; + MS_EXCEPTION_IF_NULL(iter.second); + iter.second->ReleaseDeviceRes(); + } + runtime_map_.clear(); +} + +void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) { + std::lock_guard guard(lock_); + for (auto &iter : runtime_map_) { + MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource"; + if (!iter.second) { + MS_LOG(ERROR) << "Kernel runtime is nullptr"; + continue; + } + iter.second->ClearGraphRuntimeResource(graph_id); + } +} + +void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { + if (runtime_creators_.find(device_name) == runtime_creators_.end()) { + (void)runtime_creators_.emplace(device_name, runtime_creator); + } +} + +KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) { + std::string runtime_key = device_name + "_" + std::to_string(device_id); + auto runtime_iter = runtime_map_.find(runtime_key); + if (runtime_iter != runtime_map_.end()) { + return runtime_iter->second.get(); + } else if (runtime_map_.size() > 0) { + auto cur_runtime_key = runtime_map_.begin()->first; + auto find_pos = cur_runtime_key.rfind('_'); + if (find_pos != std::string::npos) { + if (cur_runtime_key.size() > find_pos + 1) { + auto cur_device_id = cur_runtime_key.substr(find_pos + 1); + MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id + << ", set device id: " << device_id << " failed"; + } else { + MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: " + << device_id << " failed"; + } + } + } + return GetKernelRuntime(device_name, device_id); +} + +KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) { + std::lock_guard guard(lock_); + std::string runtime_key = device_name + "_" + std::to_string(device_id); + auto runtime_iter = runtime_map_.find(runtime_key); + if (runtime_iter != runtime_map_.end()) { + return runtime_iter->second.get(); + } + std::shared_ptr kernel_runtime; + auto creator_iter = runtime_creators_.find(device_name); + if (creator_iter != runtime_creators_.end()) { + MS_EXCEPTION_IF_NULL(creator_iter->second); + kernel_runtime = (creator_iter->second)(); + kernel_runtime->set_device_id(device_id); + MS_EXCEPTION_IF_NULL(kernel_runtime); + runtime_map_[runtime_key] = kernel_runtime; + } else { + MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id; + } + + return kernel_runtime.get(); +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h new file mode 100644 index 0000000000..7fcb40ae67 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ +#include +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "runtime/device/kernel_runtime.h" +namespace mindspore { +namespace device { +using KernelRuntimeCreator = std::function()>; + +class KernelRuntimeManager { + public: + static KernelRuntimeManager &Instance() { + static KernelRuntimeManager instance; + return instance; + } + void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator); + KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id); + KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id); + void ClearRuntimeResource(); + void ClearGraphResource(uint32_t graph_id); + + private: + KernelRuntimeManager() = default; + ~KernelRuntimeManager() = default; + DISABLE_COPY_AND_ASSIGN(KernelRuntimeManager); + std::map > runtime_map_; + std::map runtime_creators_; + std::mutex lock_; +}; + +class KernelRuntimeRegistrar { + public: + KernelRuntimeRegistrar(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { + KernelRuntimeManager::Instance().Register(device_name, std::move(runtime_creator)); + } + ~KernelRuntimeRegistrar() = default; +}; + +#define MS_REG_KERNEL_RUNTIME(DEVICE_NAME, RUNTIME_CLASS) \ + static const KernelRuntimeRegistrar g_kernel_runtime_##DEVICE_NAME##_reg( \ + DEVICE_NAME, []() { return std::make_shared(); }); +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/memory_manager.cc b/mindspore/ccsrc/runtime/device/memory_manager.cc new file mode 100644 index 0000000000..563d5f0f50 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/memory_manager.cc @@ -0,0 +1,213 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/memory_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" +using mindspore::memreuse::BestFitMemReuse; +using mindspore::memreuse::MemReuseUtilPtr; +namespace mindspore { +namespace device { +size_t MemoryManager::GetCommonAlignSize(size_t input_size) const { + return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; +} + +size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const { + return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize; +} + +void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + // set all infos + mem_reuse_util_ptr->SetAllInfo(graph); + auto bestfit_mem_reuse = std::make_shared(); + MS_EXCEPTION_IF_NULL(bestfit_mem_reuse); + bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get()); + size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize(); + MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]"; + mem_reuse_util_ptr_ = mem_reuse_util_ptr; + auto base_ptr = MallocDynamicMem(total_allocated_size, false); + mem_reuse_util_ptr_->set_mem_base(base_ptr); +} + +uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { + MS_EXCEPTION_IF_NULL(node); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + uint8_t *ptr = nullptr; + if (AnfAlgo::IsCommunicationOp(node)) { + bool communication_mem = false; + if (context_ptr->enable_hccl()) { + communication_mem = true; + } + if (flag == kStaticMem) { + ptr = MallocStaticMem(size, communication_mem); + } else { + ptr = MallocDynamicMem(size, communication_mem); + } + return ptr; + } + + if (flag == kStaticMem) { + ptr = MallocStaticMem(size, false); + } else if (flag == kDynamicMem) { + ptr = MallocDynamicMem(size, false); + } else if (flag == kReuseDynamicMem) { + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); + ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); + } + return ptr; +} + +uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { + if (flag == kReuseDynamicMem) { + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); + return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index); + } + return MallocDynamicMem(size, false); +} + +uint8_t *MemoryManager::MallocMem(int flag, size_t size) { + uint8_t *ptr = nullptr; + if (flag == kStaticMem) { + ptr = MallocStaticMem(size, false); + } else if (flag == kDynamicMem) { + ptr = MallocDynamicMem(size, false); + } + return ptr; +} + +uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem) { + size_t align_size = 0; + if (communication_mem) { + align_size = GetCommunicationAlignSize(size); + } else { + align_size = GetCommonAlignSize(size); + } + + MS_LOG(INFO) << "Malloc Memory for Static: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] communication_mem: " << communication_mem; + + if (static_mem_offset_ < align_size) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + total_static_size_ += align_size; + auto offset = static_mem_offset_ - align_size; + if (dynamic_mem_offset_ > offset) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + static_mem_offset_ = offset; + if (communication_mem) { + return device_mem_base_ + offset + kMemAlignSize; + } else { + return device_mem_base_ + offset; + } +} + +uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { + size_t align_size = 0; + if (communication_mem) { + align_size = GetCommunicationAlignSize(size); + } else { + align_size = GetCommonAlignSize(size); + } + + MS_LOG(INFO) << "Malloc Memory for Dynamic: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] communication_mem: " << communication_mem; + + uint64_t offset = dynamic_mem_offset_; + auto new_offset = dynamic_mem_offset_ + align_size; + if (new_offset > static_mem_offset_) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + total_dynamic_size_ += align_size; + dynamic_mem_offset_ = new_offset; + + if (communication_mem) { + return device_mem_base_ + offset + kMemAlignSize; + } else { + return device_mem_base_ + offset; + } +} + +bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) { + auto device_ptr = MallocMemFromMemPool(size); + if (!device_ptr) { + return false; + } + address->ptr_ = device_ptr; + address->from_mem_pool_ = true; + return true; +} + +void *MemoryManager::MallocMemFromMemPool(size_t size) { + if (size == 0) { + MS_LOG(ERROR) << "MallocMemFromMemPool size is 0."; + } + return nullptr; +} + +void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) { + MS_EXCEPTION_IF_NULL(address); + MS_EXCEPTION_IF_NULL(address->ptr_); + FreeMemFromMemPool(address->ptr_); + address->ptr_ = nullptr; +} + +void MemoryManager::FreeMemFromMemPool(void *device_ptr) { + if (device_ptr == nullptr) { + MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; + } +} + +bool MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list) { + auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); + if (device_ptr_list.size() == 0) { + return false; + } + if (addr_list.size() != device_ptr_list.size()) { + MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; + } + for (size_t i = 0; i < addr_list.size(); i++) { + MS_EXCEPTION_IF_NULL(device_ptr_list[i]); + MS_EXCEPTION_IF_NULL(addr_list[i]); + addr_list[i]->ptr_ = device_ptr_list[i]; + addr_list[i]->from_mem_pool_ = true; + } + return true; +} + +std::vector MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + if (total_size == 0) { + MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0."; + } + std::vector device_ptr_list; + device_ptr_list.emplace_back(nullptr); + return device_ptr_list; +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/memory_manager.h b/mindspore/ccsrc/runtime/device/memory_manager.h new file mode 100644 index 0000000000..3c6fb1b39a --- /dev/null +++ b/mindspore/ccsrc/runtime/device/memory_manager.h @@ -0,0 +1,73 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ +#include +#include +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" +namespace mindspore { +namespace device { +const int kStaticMem = 0; +const int kDynamicMem = 1; +const int kReuseDynamicMem = 2; +const int kGetAllOuts = -1; +const uint64_t kMemAlignSize = 512; +using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr; + +class MemoryManager { + public: + MemoryManager() = default; + virtual ~MemoryManager() = default; + + virtual void MallocDeviceMemory() = 0; + virtual void FreeDeviceMemory() = 0; + virtual void ResetDynamicMemory() { + total_dynamic_size_ = 0; + dynamic_mem_offset_ = 0; + } + + void MallocReusedDynamicMem(session::KernelGraph *graph); + uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size); + uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size); + virtual uint8_t *MallocMem(int flag, size_t size); + + virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); + virtual void *MallocMemFromMemPool(size_t size); + virtual void FreeMemFromMemPool(const DeviceAddressPtr address); + virtual void FreeMemFromMemPool(void *device_ptr); + virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list); + virtual std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); + + size_t GetCommonAlignSize(size_t input_size) const; + size_t GetCommunicationAlignSize(size_t input_size) const; + + protected: + virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem); + virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem); + uint8_t *device_mem_base_{nullptr}; + uint64_t device_mem_size_{0}; + uint64_t dynamic_mem_offset_{0}; + uint64_t static_mem_offset_{0}; + size_t total_static_size_ = 0; + size_t total_dynamic_size_ = 0; + MemReuseUtilPtr mem_reuse_util_ptr_{nullptr}; +}; +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/session/CMakeLists.txt b/mindspore/ccsrc/session/CMakeLists.txt deleted file mode 100644 index 782eb51183..0000000000 --- a/mindspore/ccsrc/session/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "kernel_graph.cc" - "session_basic.cc" - "session_factory.cc" - "anf_runtime_algorithm.cc" -) - -if (ENABLE_GPU) - file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "gpu_session.cc" - ) - list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) -endif () - -if (ENABLE_CPU) - file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "cpu_session.cc" - ) - list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) -endif () - -if (ENABLE_D) - file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "ascend_session.cc" - "ascend_control_parser.cc" - "ascend_inference_session.cc" - ) - list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) -endif () - -set_property(SOURCE ${_SESSION_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_SESSION) -add_library(_mindspore_session_obj OBJECT ${_SESSION_SRC_LIST}) diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc deleted file mode 100644 index 81ad02e787..0000000000 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ /dev/null @@ -1,1121 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/anf_runtime_algorithm.h" -#include -#include -#include -#include -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "operator/ops.h" -#include "utils/utils.h" -#include "device/kernel_info.h" -#include "device/device_address.h" -#include "pre_activate/common/helper.h" -#include "kernel/kernel.h" -#include "kernel/kernel_build_info.h" -#include "common/utils.h" -#include "common/trans.h" - -namespace mindspore { -namespace session { -using abstract::AbstractTensor; -using abstract::AbstractTuple; -using device::KernelInfo; -using device::ascend::AscendDeviceAddress; -using kernel::KernelBuildInfoPtr; -using kernel::KernelMod; -using kernel::KernelModPtr; -namespace { -std::vector TransShapeToSizet(const abstract::ShapePtr &shape) { - MS_EXCEPTION_IF_NULL(shape); - std::vector shape_size_t; - std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize); - return shape_size_t; -} -} // namespace - -KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) { - MS_EXCEPTION_IF_NULL(anf_node); - if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input0 = cnode->input(0); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimMakeTuple)) { - auto node = cnode->input(index + IntToSize(1)); - MS_EXCEPTION_IF_NULL(node); - return VisitKernel(node, 0); - } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - if (cnode->inputs().size() != kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; - } - auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(input2); - auto value_node = input2->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx)); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - return VisitKernel(cnode->input(kRealInputIndexInDepend), 0); - } else { - return std::make_pair(anf_node, index); - } - } else { - MS_LOG(EXCEPTION) << "The input is invalid"; - } -} - -KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, - bool visit_nop_node, - const std::vector &return_types) { - MS_EXCEPTION_IF_NULL(anf_node); - for (const auto &prim_type : return_types) { - if (CheckPrimitiveType(anf_node, prim_type)) { - return std::make_pair(anf_node, index); - } - } - if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input0 = cnode->input(0); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - if (cnode->inputs().size() != kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; - } - auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(input2); - auto value_node = input2->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx), - visit_nop_node, return_types); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types); - } else if (opt::IsNopNode(cnode) && visit_nop_node) { - if (cnode->inputs().size() == 2) { - return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types); - } else { - MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; - } - } else { - return std::make_pair(anf_node, index); - } - } else { - MS_LOG(EXCEPTION) << "The input is invalid"; - } -} - -std::vector AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node, - const std::vector &return_types) { - std::vector ret; - auto return_prim_type = return_types; - // if visited make_tuple should return back - return_prim_type.push_back(prim::kPrimMakeTuple); - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type); - if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { - MS_EXCEPTION_IF_NULL(item_with_index.first); - auto make_tuple = item_with_index.first->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - for (size_t i = 1; i < make_tuple->inputs().size(); i++) { - auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types); - (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret)); - } - return ret; - } - ret.push_back(item_with_index.first); - return ret; -} - -AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - return node->input(kAnfPrimitiveIndex); -} - -PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto attr_input = GetCNodePrimitiveNode(cnode); - MS_EXCEPTION_IF_NULL(attr_input); - auto value_node = attr_input->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - auto primitive = value->cast(); - return primitive; -} - -bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); -} - -FuncGraphPtr AnfRuntimeAlgorithm::GetCNodeFuncGraphPtr(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto attr_input = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(attr_input); - auto value_node = attr_input->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - return value->cast(); -} - -std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto primitive = AnfAlgo::GetCNodePrimitive(node); - if (primitive != nullptr) { - return primitive->name(); - } - auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(func_graph); - return func_graph->ToString(); - } - MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString(); -} - -std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - return node->DebugString(); -} - -void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString(); - } - // single op cnode. - auto primitive = AnfAlgo::GetCNodePrimitive(node); - if (primitive != nullptr) { - primitive->set_attr(key, value); - return; - } - // graph kernel cnode. - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - fg->set_attr(key, value); -} - -void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) { - CopyNodeAttr(key, key, from, to); -} - -void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, - const AnfNodePtr &to) { - MS_EXCEPTION_IF_NULL(from); - MS_EXCEPTION_IF_NULL(to); - if (!from->isa() || !to->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is " - << to->DebugString(); - } - auto from_primitive = AnfAlgo::GetCNodePrimitive(from); - MS_EXCEPTION_IF_NULL(from_primitive); - auto to_primitive = AnfAlgo::GetCNodePrimitive(to); - MS_EXCEPTION_IF_NULL(to_primitive); - to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key)); -} - -void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) { - MS_EXCEPTION_IF_NULL(from); - MS_EXCEPTION_IF_NULL(to); - if (!from->isa() || !to->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is " - << from->DebugString(); - } - auto from_primitive = AnfAlgo::GetCNodePrimitive(from); - MS_EXCEPTION_IF_NULL(from_primitive); - auto to_primitive = AnfAlgo::GetCNodePrimitive(to); - MS_EXCEPTION_IF_NULL(to_primitive); - (void)to_primitive->SetAttrs(from_primitive->attrs()); -} - -void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString(); - } - // single op cnode. - auto primitive = AnfAlgo::GetCNodePrimitive(node); - if (primitive != nullptr) { - primitive->EraseAttr(key); - return; - } - // graph kernel cnode. - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - fg->erase_flag(key); -} - -bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString(); - return false; - } - // single op cnode. - auto primitive = AnfAlgo::GetCNodePrimitive(node); - if (primitive != nullptr) { - return primitive->HasAttr(key); - } - // graph kernel cnode. - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - return fg->has_attr(key); -} - -size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString(); - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - size_t input_num = cnode->inputs().size(); - if (input_num == 0) { - MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"; - } - // exclude intputs[0],which is value_node storing attr,inputs left are real input - return input_num - 1; -} - -size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - TypePtr type = node->Type(); - if (type == nullptr) { - return 0; - } - if (type->isa()) { - auto tuple_type = type->cast(); - MS_EXCEPTION_IF_NULL(tuple_type); - return tuple_type->size(); - } else if (type->isa() || type->isa()) { - return 1; - } else if (type->isa()) { - return 0; - } else { - return 1; - } -} - -std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "Output index:" << output_idx - << " is out of the node output range :" << GetOutputTensorNum(node) << " #node [" - << node->DebugString() << "]"; - } - if (!AnfAlgo::IsRealKernel(node)) { - return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - auto format = build_info->GetOutputFormat(output_idx); - if (format == kernel::KernelBuildInfo::kInvalidFormat) { - MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" - << " has a invalid output format"; - } - return format; -} - -std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(node); - if (input_idx > GetInputTensorNum(node)) { - MS_LOG(EXCEPTION) << "Input index :" << input_idx - << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node [" - << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - GetPrevNodeOutputFormat(node, input_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - auto format = build_info->GetInputFormat(input_idx); - if (format == kernel::KernelBuildInfo::kInvalidFormat) { - MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" - << " has a invalid input format"; - } - return format; -} - -KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(anf_node); - if (!anf_node->isa()) { - MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; - } - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); - } - auto node = cnode->input(input_idx + 1); - MS_EXCEPTION_IF_NULL(node); - return VisitKernel(node, 0); -} - -std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); - return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); -} - -std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); - return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second); -} - -std::vector AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - abstract::BaseShapePtr base_shape = node->Shape(); - MS_EXCEPTION_IF_NULL(base_shape); - if (base_shape->isa() && output_idx == 0) { - return TransShapeToSizet(base_shape->cast()); - } else if (base_shape->isa()) { - auto tuple_shape = base_shape->cast(); - MS_EXCEPTION_IF_NULL(tuple_shape); - if (output_idx >= tuple_shape->size()) { - MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size() - << "."; - } - auto b_shp = (*tuple_shape)[output_idx]; - if (b_shp->isa()) { - return TransShapeToSizet(b_shp->cast()); - } else if (b_shp->isa()) { - return std::vector(); - } else { - MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx - << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString(); - } - } else if (base_shape->isa()) { - return std::vector(); - } - MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is " - << base_shape->ToString(); -} - -std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); - return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second); -} - -std::vector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { - auto format = GetOutputFormat(node, output_idx); - auto infer_shape = GetOutputInferShape(node, output_idx); - if (infer_shape.empty()) { - return infer_shape; - } - // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (trans::IsNeedPadding(format, infer_shape.size())) { - infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); - } - return trans::TransShapeToDevice(infer_shape, format); -} - -std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { - auto format = GetInputFormat(node, input_idx); - auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx); - if (infer_shape.empty()) { - return infer_shape; - } - // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (trans::IsNeedPadding(format, infer_shape.size())) { - infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); - } - return trans::TransShapeToDevice(infer_shape, format); -} - -std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(node); - if (input_idx > GetInputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index:" << input_idx - << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node[" - << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - return GetPrevNodeOutputReshapeType(node, input_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - if (build_info->IsInputDefaultPadding()) { - return {}; - } - return build_info->GetInputReshapeType(input_idx); -} - -std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - return GetPrevNodeOutputReshapeType(node, output_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - if (build_info->IsOutputDefaultPadding()) { - return {}; - } - return build_info->GetOutputReshapeType(output_idx); -} - -TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - TypePtr type_ptr = node->Type(); - MS_EXCEPTION_IF_NULL(type_ptr); - if (type_ptr->isa() && output_idx == 0) { - auto tensor_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - TypePtr elem = tensor_ptr->element(); - MS_EXCEPTION_IF_NULL(elem); - return elem->type_id(); - } else if (type_ptr->isa()) { - auto tuple_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(tuple_ptr); - if (output_idx >= tuple_ptr->size()) { - MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size(); - } - auto tuple_i = (*tuple_ptr)[output_idx]; - MS_EXCEPTION_IF_NULL(tuple_i); - if (tuple_i->isa()) { - auto tensor_ptr = tuple_i->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - TypePtr elem = tensor_ptr->element(); - MS_EXCEPTION_IF_NULL(elem); - return elem->type_id(); - } else if (tuple_i->isa()) { - return tuple_i->type_id(); - } else { - MS_LOG(WARNING) << "Not support type " << tuple_i->ToString(); - return tuple_i->type_id(); - } - } else if (type_ptr->isa()) { - return type_ptr->type_id(); - } - return type_ptr->type_id(); -} - -TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); - return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second); -} - -TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - return GetPrevNodeOutputDeviceDataType(node, output_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - auto dtype = build_info->GetOutputDeviceType(output_idx); - if (dtype == TypeId::kNumberTypeEnd) { - MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" - << " has a invalid dtype"; - } - return dtype; -} - -TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(node); - if (input_idx > GetInputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ " - << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - return GetPrevNodeOutputDeviceDataType(node, 0); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - auto dtype = build_info->GetInputDeviceType(input_idx); - if (dtype == TypeId::kNumberTypeEnd) { - MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" - << " has a invalid dtype"; - } - return dtype; -} - -TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); - return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); -} - -// get output device addr of anf_node -const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, - bool visit_nop_node) { - MS_EXCEPTION_IF_NULL(node); - if (opt::IsNopNode(node) && visit_nop_node) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() == 2) { - return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0); - } else { - MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; - } - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto addr = kernel_info->GetOutputAddr(output_idx); - if (addr == nullptr) { - MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() - << " output addr is not exist"; - } - return addr; -} - -DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, - bool visit_nop_node) { - MS_EXCEPTION_IF_NULL(node); - if (opt::IsNopNode(node) && visit_nop_node) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() == 2) { - return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0); - } else { - MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; - } - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto addr = kernel_info->GetMutableOutputAddr(output_idx); - if (addr == nullptr) { - MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString() - << " output addr is not exist"; - } - return addr; -} - -// get output device addr of anf_node -bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->OutputAddrExist(output_idx); -} - -const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, - bool visit_nop_node) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); - return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); -} - -DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, - bool visit_nop_node) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); - return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); -} - -// set output device addr of anf_node -void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - if (!kernel_info->SetOutputAddr(addr, output_idx)) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; - } -} - -// set workspace device addr of anf_node -void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; - } -} - -// get workspace device addr of anf_node -DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto addr = kernel_info->GetWorkspaceAddr(output_idx); - if (addr == nullptr) { - MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() - << "] workspace addr is not exist"; - } - return addr; -} - -// set infer shapes and types of anf node -void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector &types, - const std::vector> &shapes, AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - if (types.size() != shapes.size()) { - MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size(); - } - if (shapes.empty()) { - node->set_abstract(std::make_shared()); - } else if (shapes.size() == 1) { - // single output handle - std::vector shape_int; - std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt); - auto abstract = std::make_shared(TypeIdToType(types[0]), shape_int); - node->set_abstract(abstract); - } else { - // multiple output handle - std::vector abstract_list; - for (size_t i = 0; i < types.size(); ++i) { - std::vector shape_int; - std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt); - abstract_list.push_back(std::make_shared(TypeIdToType(types[i]), shape_int)); - } - auto abstract_tuple = std::make_shared(abstract_list); - node->set_abstract(abstract_tuple); - } -} -// copy an abstract of a node to another node -void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) { - to_node->set_abstract(from_node->abstract()); -} - -kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - // select_kernel_build_info() has checked whether return pointer is null - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - return build_info->op_pattern(); -} - -// get KernelBuildType of node, such as ATT,RT,FWK and so on -KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - // select_kernel_build_info() has checked whether return pointer is null - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - return build_info->kernel_type(); -} - -kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - return build_info->processor(); -} - -kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - return build_info->fusion_type(); -} - -// set select kernel_build_info -void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->set_select_kernel_build_info(select_kernel_build_info); -} - -// get select kernel_build_info -KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->GetMutableSelectKernelBuildInfo(); -} - -// get kernelMode -KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->MutableKernelMod(); -} - -// set kernel mod -void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - kernel_info->set_kernel_mod(kernel_mod); -} - -bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - // parameter and value node is not a real kernel too - if (!node->isa()) { - return true; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().empty()) { - MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString(); - } - auto input = cnode->inputs()[0]; - bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) || - IsPrimitive(input, prim::kPrimTensorSummary) || - IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || - IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || - IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || - IsPrimitive(input, prim::kPrimReturn); - return !is_virtual_node; -} - -bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - // parameter and value node is not a real cnode kernel - if (!node->isa()) { - return false; - } - // return considered as a real node - if (CheckPrimitiveType(node, prim::kPrimReturn)) { - return true; - } - return IsRealKernel(node); -} - -bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - // graph kernel should be a real cnode kernel. - if (!IsRealCNodeKernel(node)) { - return false; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input = cnode->input(kAnfPrimitiveIndex); - // graph kernel should has func_graph as first input. - if (!IsValueNode(input)) { - return false; - } - - auto func_graph = GetValueNode(input); - MS_EXCEPTION_IF_NULL(func_graph); - return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); -} - -bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) { - MS_EXCEPTION_IF_NULL(node); - return node->has_default(); -} - -void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - kernel_info->set_stream_id(stream_id); -} - -uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->stream_id(); -} - -void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - kernel_info->set_stream_distinction_label(stream_label); -} - -uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->stream_distinction_label(); -} - -void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - kernel_info->set_graph_id(graph_id); -} - -uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->graph_id(); -} - -bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) { - MS_EXCEPTION_IF_NULL(anf); - TypePtr type = anf->Type(); - MS_EXCEPTION_IF_NULL(type); - return type->isa(); -} - -AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) { - MS_EXCEPTION_IF_NULL(node); - auto get_input_index = index + 1; - if (index + 1 > node->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just" - << node->inputs().size(); - } - // input 0 is primitive node - return node->input(get_input_index); -} - -bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - return false; - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->is_feature_map(); -} - -bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) { - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map"; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input_node = cnode->input(input_index + 1); - return IsFeatureMapOutput(input_node); -} - -size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) { - MS_EXCEPTION_IF_NULL(anf_node); - static std::map> spec_node_list = { - {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}}, - {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}}, - {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}, - {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}}, - {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}}, - {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}, - {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}, - {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}, - {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}, - {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}, - {prim::kPrimApplyCenteredRMSProp->name(), - {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}}}}; - size_t ret = cur_index; - auto node_name = AnfAlgo::GetCNodeName(anf_node); - if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) { - auto find = spec_node_list.find(node_name); - if (find != spec_node_list.end()) { - ret = find->second[cur_index]; - MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name; - } - } - return ret; -} - -void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(input_node); - node->set_input(index + 1, input_node); -} - -bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - auto kernel_name = AnfAlgo::GetCNodeName(node); - if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName || - kernel_name == kReduceScatterOpName) { - return true; - } - return false; -} - -bool AnfRuntimeAlgorithm::IsGetNext(const NotNull &node) { - auto kernel_name = AnfAlgo::GetCNodeName(node); - return kernel_name == kGetNextOpName; -} - -FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto value_node = node->cast(); - if (value_node == nullptr) { - return nullptr; - } - auto value = value_node->value(); - if (value == nullptr) { - return nullptr; - } - auto func_graph = value->cast(); - return func_graph; -} - -std::vector AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) { - MS_EXCEPTION_IF_NULL(call_node); - if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared("call"))) { - MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << "is not a call node."; - } - auto input1 = call_node->input(1); - MS_EXCEPTION_IF_NULL(input1); - if (input1->isa()) { - auto value_node = input1->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto kernel_graph = value_node->value(); - MS_EXCEPTION_IF_NULL(kernel_graph); - return {kernel_graph->cast()}; - } else if (input1->isa() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { - auto switch_node = input1->cast(); - MS_EXCEPTION_IF_NULL(switch_node); - auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr { - auto partial = switch_node->input(input_index); - MS_EXCEPTION_IF_NULL(partial); - if (IsValueNode(partial)) { - return GetValueNode(partial); - } - auto partial_cnode = partial->cast(); - MS_EXCEPTION_IF_NULL(partial_cnode); - auto graph_node = partial_cnode->input(1); - MS_EXCEPTION_IF_NULL(graph_node); - auto graph_value_node = graph_node->cast(); - MS_EXCEPTION_IF_NULL(graph_value_node); - auto graph_value = graph_value_node->value(); - MS_EXCEPTION_IF_NULL(graph_value); - auto child_graph = graph_value->cast(); - return child_graph; - }; - return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)}; - } - return {}; -} - -bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { - MS_EXCEPTION_IF_NULL(call_node); - if (!CheckPrimitiveType(call_node, prim::kPrimCall)) { - MS_LOG(EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString(); - } - auto input1 = call_node->input(1); - if (input1->isa()) { - return false; - } else if (input1->isa() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { - return true; - } - MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); -} - -bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); - if (shape.empty()) { - return true; - } - return shape.size() == kShape1dDims && shape[0] == 1; -} - -bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); - if (shape.empty()) { - return true; - } - return shape.size() == kShape1dDims && shape[0] == 1; -} - -void AnfRuntimeAlgorithm::ReorderExecList(NotNull *> node_list) { - std::vector all_opt_list; - std::vector non_opt_list; - - for (const auto &node : *node_list) { - MS_EXCEPTION_IF_NULL(node); - if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) { - all_opt_list.emplace_back(node); - } else { - non_opt_list.emplace_back(node); - } - } - node_list->clear(); - std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list)); - std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); -} - -TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto prim = AnfAlgo::GetCNodePrimitive(node); - if (prim == nullptr) { - return kTypeUnknown; - } - - TypeId except_type = kTypeUnknown; - if (prim->GetAttr(kAttrOutputPrecision) != nullptr) { - auto output_type_str = GetValue(prim->GetAttr(kAttrOutputPrecision)); - if (output_type_str == "float16") { - except_type = kNumberTypeFloat16; - } else if (output_type_str == "float32") { - except_type = kNumberTypeFloat32; - } else { - MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got " << output_type_str; - } - } - - return except_type; -} - -TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) { - if (!node->isa()) { - MS_LOG(EXCEPTION) << node->DebugString() << ", input node is not CNode."; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); - } - auto input_node = cnode->input(input_idx + 1); - MS_EXCEPTION_IF_NULL(input_node); - auto kernel_with_index = VisitKernel(input_node, 0); - if (!kernel_with_index.first->isa()) { - return kTypeUnknown; - } - return GetCNodeOutputPrecision(kernel_with_index.first); -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h deleted file mode 100644 index 8205619793..0000000000 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ /dev/null @@ -1,210 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H -#define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H -#include -#include -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "ir/dtype.h" -#include "ir/base.h" -#include "ir/primitive.h" -#include "device/device_address.h" -#include "kernel/kernel.h" -#include "kernel/kernel_build_info.h" -#include "operator/ops.h" -#include "utils/contract.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace session { -using AnfVisitFuncion = std::function; -using KernelWithIndex = std::pair; -class AnfRuntimeAlgorithm { - public: - // get input_anf_node's real kernel by recurse - static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); - static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, - bool visit_nop_node = false, - const std::vector &return_types = { - prim::kPrimMakeTuple}); - static std::vector GetAllOutput(const AnfNodePtr &node, - const std::vector &return_types = {}); - // get cnode primitive - static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node); - static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index); - static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); - // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple - static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type); - // get cnode primitive - static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node); - // get kernel_name of anf node - static std::string GetCNodeName(const AnfNodePtr &node); - // get detail info of anf node - static std::string GetNodeDebugString(const AnfNodePtr &node); - // get attr of anf node - template - static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - std::string node_debug_log = node->DebugString(); - MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str(); - } - // single op cnode. - if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) { - return GetValue(primitive->GetAttr(key)); - } - // graph kernel cnode. - auto fg = GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - return GetValue(fg->get_attr(key)); - } - static bool IsTupleOutput(const AnfNodePtr &anf); - // set attr of anf node - static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); - // set attr of key from 'from' node to 'to' node - static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to); - // set a new key for attr from 'from' node to 'to' node - static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, - const AnfNodePtr &to); - // set all attrs from 'from' node to 'to' node - static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to); - // check whether a cnode has the specified attr. - static bool HasNodeAttr(const std::string &key, const CNodePtr &node); - // delete attr of anf node - static void EraseNodeAttr(const std::string &key, AnfNodePtr node); - // get the num of input real_kernel(which can be build and run in device) - static size_t GetInputTensorNum(const AnfNodePtr &node); - // get the num of output real_kernel(which can be build and run in device) - static size_t GetOutputTensorNum(const AnfNodePtr &node); - // get output format select of anf node - static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); - // get input format select of anf node - static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); - // get prev node output width output index - static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx); - // get output format from prev node,input_index is the input index of current node related to prev node - static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); - // get reshape_type of from the output of input node. - static std::vector GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); - // get output shapes inferred by ME from input nodes. - static std::vector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx); - // get input shapes inferred by ME from input nodes. - static std::vector GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx); - // get output shapes which will built and run in device - static std::vector GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); - // get input shapes which will built and run in device - static std::vector GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); - // Get Input Padding Axis - static std::vector GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); - // Get Output Padding Axis - static std::vector GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); - // get output data type inferred by ME of anf node - static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); - // get output original data type from prev node,input_index is the input index of current node related to prev node - static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx); - // get output select data type of anf node - static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx); - // get input select data type of anf node - static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx); - // get output select data type from prev node,input_index is the input index of current node related to prev node - static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); - // get output device addr of anf_node - static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); - // get mutable output device addr of anf_node - static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); - // check whether output addr is exist or not - static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); - // get address from prev node,input_index is the input index of current node related to prev node - static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, - bool visit_nop_node = true); - static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, - bool visit_nop_node = true); - // set output device addr of anf_node - static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); - // set workspace device addr of anf_node - static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); - // get workspace device addr of anf_node - static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); - // set infer shapes and types of anf node - static void SetOutputInferTypeAndShape(const std::vector &types, - const std::vector> &shapes, AnfNode *node); - static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); - // get op pattern of the node - static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); - // get KernelBuildType of node ,such as ATT,RT,FWK and so on - static KernelType GetKernelType(const AnfNodePtr &node); - // get processor type:AICORE,AICPU... - static kernel::Processor GetProcessor(const AnfNodePtr &node); - // get fusion type:AICORE,AICPU... - static kernel::FusionType GetFusionType(const AnfNodePtr &node); - // set select kernel_build_info - static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node); - // get select kernel_build_info - static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node); - // get kernelMode - static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node); - // set kernel mod - static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node); - // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too - static bool IsRealKernel(const AnfNodePtr &node); - // checkout whether the anf node is a real kernel that is a cnode and can run on device - static bool IsRealCNodeKernel(const AnfNodePtr &node); - // checkout whether the anf node is a graph kernel. - static bool IsGraphKernel(const AnfNodePtr &node); - // check parameter is weight or data - static bool IsParameterWeight(const ParameterPtr &node); - // set stream id of kernel,which will be set in stream assign and be used in stream generate - static void SetStreamId(uint32_t stream_id, AnfNode *node); - // get stream id - static uint32_t GetStreamId(const AnfNodePtr &node); - // set stream distinction label to distinguish different ops in different streams - static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node); - // get stream distinction label - static uint32_t GetStreamDistinctionLabel(const AnfNode *node); - // set graph id - static void SetGraphId(uint32_t graph_id, AnfNode *node); - // get graph id - static uint32_t GetGraphId(const AnfNode *node); - static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); - // charge if the node's output is a feature map output - static bool IsFeatureMapOutput(const AnfNodePtr &node); - // charge if the node's input is from a feature map output - static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); - // get real input index for some tbe ops which input order is different between me and tbe impl - static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); - static bool IsCommunicationOp(const AnfNodePtr &node); - static bool IsGetNext(const NotNull &node); - static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); - static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); - static bool IsSwitchCall(const CNodePtr &call_node); - static bool IsScalarInput(const CNodePtr &cnode, size_t index); - static bool IsScalarOutput(const CNodePtr &cnode, size_t index); - static void ReorderExecList(NotNull *> node_list); - // get fix output precision of cnode. - static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); - // get fix output precision from prev node, input_idx is the input index of current node related to prev node. - static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); -}; -} // namespace session -using AnfAlgo = session::AnfRuntimeAlgorithm; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc deleted file mode 100644 index 0c97116c6e..0000000000 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ /dev/null @@ -1,643 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "session/ascend_control_parser.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/union_find_set.h" -#include "device/ascend/ascend_label_assign.h" - -static constexpr size_t kCNodePrim = 0; -static constexpr size_t kCNodeCallArg = 1; -static constexpr size_t kCNodeSwitchCond = 1; -static constexpr size_t kCNodeSwitchTrue = 2; -static constexpr size_t kCNodeSwitchFalse = 3; -static constexpr size_t kCNodeSwitchLength = 4; -static constexpr size_t kCNodePartialLength = 2; -static constexpr size_t kCNodePartialFunc = 1; -static constexpr size_t kCNodeSwitchLayerBranch = 2; -static constexpr size_t kCNodeSwitchLayerLength = 3; - -namespace mindspore { -namespace session { -static CNodePtr GetJumpNode(NotNull parent_graph, NotNull child_graph) { - auto &nodes = parent_graph->execution_order(); - CNodePtr last_jump_node = nullptr; - for (auto &node : nodes) { - if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) { - if (child_graph->get_start_label() == node->input(kCNodeCallArg)) { - return node; - } - last_jump_node = node; - } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) { - if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || - child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) { - return node; - } - last_jump_node = node; - } - } - if (last_jump_node == nullptr) { - MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); - } - return last_jump_node; -} - -static void InitUnionFindSet(NotNull kg, const NotNull *> union_find_set, - const NotNull *> memo) { - if (memo->find(kg.get()) != memo->end()) { - return; - } - memo->insert(kg.get()); - - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &iter : real_inputs) { - auto ¶ = iter.first; - MS_EXCEPTION_IF_NULL(para); - if (para->isa()) { - union_find_set->Add(para); - } - for (auto &arg : iter.second) { - MS_EXCEPTION_IF_NULL(arg); - if (!arg->isa()) { - continue; - } - union_find_set->Add(arg); - } - } - for (auto &child : kg->child_graph_order()) { - InitUnionFindSet(NOT_NULL(child), union_find_set, memo); - } -} - -static void UnionParentParameter(NotNull kg, const NotNull *> union_find_set, - const NotNull *> memo) { - if (memo->find(kg.get()) != memo->end()) { - return; - } - memo->insert(kg.get()); - - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &iter : real_inputs) { - auto ¶ = iter.first; - for (auto &arg : iter.second) { - MS_EXCEPTION_IF_NULL(arg); - if (!arg->isa()) { - continue; - } - if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) { - continue; - } - union_find_set->Union(arg, para); - } - } - for (auto &child : kg->child_graph_order()) { - UnionParentParameter(NOT_NULL(child), union_find_set, memo); - } -} - -static UnionFindSet MakeUnionFindSet(NotNull root_kg) { - UnionFindSet result; - std::set memo; - InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); - memo.clear(); - UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); - return result; -} - -static void RecursiveReplaceNode(NotNull kg, NotNull main_parameter, - const std::set ¶meter_reuse_set, - const NotNull *> memo) { - if (parameter_reuse_set.empty()) { - MS_LOG(EXCEPTION) << "Parameter_reuse_set is empty."; - } - if (memo->find(kg.get()) != memo->end()) { - return; - } - memo->insert(kg.get()); - - for (auto ¶ : parameter_reuse_set) { - if (para == main_parameter.get()) { - continue; - } - MS_EXCEPTION_IF_NULL(para); - MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to " - << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get()); - kg->ReplaceNode(NOT_NULL(para), main_parameter); - } - - for (auto &child : kg->child_graph_order()) { - RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo); - } -} - -static AnfNodePtr GetMainParameter(NotNull root_kg, const AnfNodePtr key, - const std::set ¶meter_reuse_set) { - AnfNodePtr main_parameter = key; - std::set root_inputs_set; - const auto &root_inputs_vector = root_kg->inputs(); - root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); - for (auto &node : parameter_reuse_set) { - if (root_inputs_set.find(node) != root_inputs_set.end()) { - main_parameter = node; - break; - } - } - return main_parameter; -} - -static void ReuseParameter(NotNull root_kg, NotNull *> parameter_set) { - auto parameter_reuse_sets = parameter_set->GetSets(); - for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { - if (parameter_reuse_set.size() <= 1) { - continue; - } - auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set); - std::set memo; - RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo)); - } -} - -CNodePtr GetNextRealKernel(const std::vector &list, size_t start) { - for (size_t i = start; i < list.size() - 1; ++i) { - if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { - return list[i]; - } - } - return nullptr; -} - -void AscendControlParser::LinkGraph(NotNull kg) { - std::set memo; - (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); - device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); - std::map graph_id_map; - for (auto &g : memo) { - MS_EXCEPTION_IF_NULL(g); - if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) { - MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id() - << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString(); - } - graph_id_map[g->graph_id()] = g; - } - - // Insert Assign - ChildGraphDataAssign(graph_id_map); - // Make UnionFindSet - UnionFindSet parameter_set = MakeUnionFindSet(kg); - // Reuse Parameter - ReuseParameter(kg, NOT_NULL(¶meter_set)); -} - -void AscendControlParser::ExecutorValidate(NotNull root_graph) { - std::set memo; - (void)RecurseGraph(root_graph, NOT_NULL(&memo)); -} - -void AscendControlParser::ChildGraphDataAssign(const std::map &graph_id_map) { - for (auto &iter : graph_id_map) { - auto &kg = iter.second; - MS_LOG(INFO) << "Data assign graph:" << kg->graph_id(); - MS_EXCEPTION_IF_NULL(kg); - std::set> memo; - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &it : real_inputs) { - auto ¶meter = it.first; - auto &args = it.second; - for (auto &arg : args) { - MS_EXCEPTION_IF_NULL(arg); - if (memo.find({parameter, arg}) != memo.end()) { - continue; - } else { - memo.emplace(parameter, arg); - } - auto unreuse_args_map = kg->unreuse_args(); - auto unreuse_arg_iter = unreuse_args_map.find(arg); - if (unreuse_arg_iter == unreuse_args_map.end()) { - MS_EXCEPTION_IF_NULL(arg); - MS_EXCEPTION_IF_NULL(parameter); - if (!arg->isa()) { - MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << "."; - } - MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() - << ", arg:" << arg->DebugString(); - continue; - } - auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); - if (target_graph_iter == graph_id_map.end()) { - MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; - } - InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), - NOT_NULL(parameter)); - } - } - kg->SetExecOrderByDefault(); - } -} - -NotNull AscendControlParser::GetStartLabel(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label) { - CNodePtr start_label; - if (last_node != nullptr && last_label != nullptr) { - start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); - kg->set_start_label(start_label); - } else { - // no goto node will jump to start label of root graph, so return a fake label - start_label = std::make_shared(std::vector(), FuncGraphPtr(nullptr)); - } - return NOT_NULL(start_label); -} - -NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label, - const NotNull *> memo) { - MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); - - // 1. recursive condition - if (memo->find(kg) != memo->end()) { - MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); - return NOT_NULL(kg->get_start_label()); - } - memo->insert(kg.get()); - - // 2. args replace placeholder - LinkParentGraph(kg, last_node, last_label); - - // 3. topological sort - kg->SetExecOrderByDefault(); - const std::vector &nodes = kg->execution_order(); - // 4. insert first_label - CNodePtr start_label = GetStartLabel(kg, last_node, last_label); - - // 5. traverse - for (size_t i = 0; i < nodes.size(); ++i) { - auto &cnode = nodes[i]; - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() < kCNodePrim + 1) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex); - if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { - MS_LOG(DEBUG) << "Continue node " << cnode->DebugString(); - continue; - } - AnfNodePtr arg = cnode->input(kFirstDataInputIndex); - MS_EXCEPTION_IF_NULL(arg); - if (IsValueNode(arg)) { - RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (!arg->isa()) { - MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); - } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitch)) { - auto arg_cnode = arg->cast(); - MS_EXCEPTION_IF_NULL(arg_cnode); - cnode->set_inputs(arg_cnode->inputs()); - RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitchLayer)) { - auto arg_cnode = arg->cast(); - MS_EXCEPTION_IF_NULL(arg_cnode); - cnode->set_inputs(arg_cnode->inputs()); - RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } - } - kg->SetExecOrderByDefault(); - MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); - return NOT_NULL(start_label); -} - -void AscendControlParser::InsertDependToGraph(NotNull kg, NotNull attch_node) { - auto return_node = kg->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), - return_node->input(kFirstDataInputIndex), attch_node.get()}; - auto depend_node = kg->NewCNode(inputs); - return_node->set_input(1, depend_node); -} - -void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull first_node, - NotNull second_node) { - MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() - << ", the second node is " << second_node->DebugString(); - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimControlDepend->name())), - first_node, second_node}; - auto control_depend = kg->NewCNode(inputs); - InsertDependToGraph(kg, NOT_NULL(control_depend)); -} - -void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, - const CNodePtr &last_label) { - // if not entry graph, replace return with label_goto - if (from_graph_call_node != nullptr && last_label != nullptr) { - auto label_goto = - kg->NewCNode({std::make_shared(std::make_shared(kLabelGotoOpName)), last_label}); - MS_EXCEPTION_IF_NULL(label_goto); - MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString(); - kg->set_end_goto(label_goto); - } -} - -void AscendControlParser::RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, - const NotNull *> memo) { - MS_LOG(INFO) << "Process call func " << cur_node->DebugString(); - - // 1 get kernel graph - const std::vector &origin_inputs = cur_node->inputs(); - if (kCNodeCallArg >= origin_inputs.size()) { - MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size(); - } - std::vector new_inputs = {std::make_shared(std::make_shared(kLabelGotoOpName))}; - if (!IsValueNode(origin_inputs[kCNodeCallArg])) { - MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; - return; - } - // 2 return label - auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node " - << cur_node->DebugString(); - // 3 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - if (next_node != nullptr && next_node != kg->get_return()) { - InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); - } - auto call_kg = GetValueNode(origin_inputs[kCNodeCallArg]); - // 4 modify call op to goto op - cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]); - // 5 recurse sub graph - CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo); - new_inputs.push_back(sub_label); - cur_node->set_inputs(new_inputs); - cur_node->set_abstract(nullptr); - MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); -} - -void AscendControlParser::RecurseSwitch(NotNull kg, NotNull cur_node, - const CNodePtr &next_node, const NotNull *> memo) { - MS_LOG(INFO) << "Process switch node " << cur_node->DebugString(); - - if (cur_node->size() < kCNodeSwitchLength) { - MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; - } - // 1 return label - auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - MS_EXCEPTION_IF_NULL(back_label); - MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node " - << cur_node->DebugString(); - // 2 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - if (next_node != nullptr && next_node != kg->get_return()) { - InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); - } - // 3 recurse sub graph - const std::vector &origin_switch_inputs = cur_node->inputs(); - if (kCNodeSwitchCond >= origin_switch_inputs.size()) { - MS_LOG(EXCEPTION) << "The size of origin_switch_inputs is not more than " << kCNodeSwitchCond; - } - std::vector new_switch_inputs = { - std::make_shared(std::make_shared(kLabelSwitchOpName)), - origin_switch_inputs[kCNodeSwitchCond]}; - for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { - // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); - // 3.2 recurse sub graph - CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); - new_switch_inputs.push_back(branch_label); - } - std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); - - cur_node->set_inputs(new_switch_inputs); - cur_node->set_abstract(nullptr); - MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString(); -} - -void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull cur_node, - const CNodePtr &next_node, - const NotNull *> memo) { - MS_LOG(INFO) << "Process switch node " << cur_node->DebugString(); - - if (cur_node->size() < kCNodeSwitchLayerLength) { - MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; - } - - auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); - MS_EXCEPTION_IF_NULL(branch_tuple); - if (!branch_tuple->isa()) { - MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode"; - } - const std::vector &branch_partial = utils::cast(branch_tuple)->inputs(); - // 1 return label - auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - // 2 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - if (next_node != nullptr && next_node != kg->get_return()) { - InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); - } - // 3 recurse sub graph - const std::vector &origin_switch_inputs = cur_node->inputs(); - if (kCNodeSwitchCond >= origin_switch_inputs.size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << "."; - } - std::vector new_switch_inputs = { - std::make_shared(std::make_shared(kLabelSwitchOpName)), - origin_switch_inputs[kCNodeSwitchCond]}; - for (size_t i = 0; i < branch_partial.size(); ++i) { - // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); - // 3.2 recurse sub graph - CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); - new_switch_inputs.push_back(branch_label); - } - new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); - cur_node->set_inputs(new_switch_inputs); - cur_node->set_abstract(nullptr); - MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); -} - -KernelGraphPtr AscendControlParser::ParsePartial(NotNull node) { - if (!node.get()->isa()) { - if (IsValueNode(node)) { - return GetValueNode(node); - } - MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); - } - // 2.1 branch kernel graph and args - auto partial_cnode = utils::cast(node.get()); - MS_EXCEPTION_IF_NULL(partial_cnode); - if (partial_cnode->size() < kCNodePartialLength) { - MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; - } - - const auto &partial_inputs = partial_cnode->inputs(); - if (kCNodePartialFunc >= partial_inputs.size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; - } - auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); - return branch_kg; -} - -void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, - NotNull to_graph, NotNull from, - NotNull to) { - std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); - std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); - MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]"; - if (from_outputs.size() != to_outputs.size()) { - MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" - << to_outputs.size() << "]"; - } - for (size_t i = 0; i < from_outputs.size(); i++) { - auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); - if (assign_node != nullptr) { - auto jump_node = GetJumpNode(from_graph, to_graph); - const auto &from_graph_exe_order = from_graph->execution_order(); - auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); - if (jump_node_iter == from_graph_exe_order.end()) { - MS_EXCEPTION_IF_NULL(jump_node); - MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id(); - } - // insert assign between jump_node -1 and jump_node - if (jump_node_iter != from_graph_exe_order.begin()) { - InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); - } - if (jump_node != nullptr) { - InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); - } - } - } -} - -AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, - NotNull to) { - if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && - AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { - return nullptr; - } - if (from.get() == to.get()) { - return nullptr; - } - MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " - << to->DebugString(); - // config inputs of assign node - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimAssign->name())), to, from}; - // generate a new cnode - auto assign_node = kg->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(assign_node); - assign_node->set_abstract(to->abstract()); - return assign_node; -} - -std::vector AscendControlParser::RecurseGraph(NotNull graph, - const NotNull *> memo) { - MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start"; - if (memo->find(graph) != memo->end()) { - return {}; - } - memo->insert(graph.get()); - graph->SetExecOrderByDefault(); - std::vector cnodes = graph->execution_order(); - - auto end_label_goto = graph->get_end_goto(); - if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) { - cnodes.pop_back(); - } - AnfAlgo::ReorderExecList(NOT_NULL(&cnodes)); - if (end_label_goto != nullptr) { - cnodes.push_back(end_label_goto); - } - - std::vector execution_order; - uint32_t child_order_index = 0; - for (auto &node : cnodes) { - execution_order.push_back(node); - if (node == graph->get_end_goto()) { - continue; - } - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { - std::vector label_switch_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); - for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { - if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { - MS_LOG(EXCEPTION) << "Check label index fail"; - } - if (child_order_index >= graph->child_graph_order().size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); - } - auto child_graph = graph->child_graph_order()[child_order_index++]; - auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); - execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); - } - } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { - uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); - if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { - MS_LOG(EXCEPTION) << "Check label index fail"; - } - auto child_graph = graph->child_graph_order()[child_order_index++]; - auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); - execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); - } - } - graph->set_execution_order(execution_order); - graph->PrintGraphExecuteOrder(); - return execution_order; -} - -bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, - NotNull graph) { - const std::vector> &child_graph_order = graph->child_graph_order(); - // check index and child order size - if (child_graph_order.size() <= IntToSize(order_index)) { - MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " - << child_graph_order.size() << " goto index " << order_index; - } - auto child_graph = child_graph_order[order_index]; - MS_EXCEPTION_IF_NULL(child_graph); - - // get start_label_set_index of child graph - auto start_label_set = child_graph->get_start_label(); - uint32_t start_label_set_index = AnfAlgo::GetNodeAttr(start_label_set, kAttrLabelIndex); - if (label_index != start_label_set_index) { - MS_EXCEPTION_IF_NULL(cur_label); - MS_EXCEPTION_IF_NULL(start_label_set); - MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() - << " index " << start_label_set_index << " current child graph order : " << order_index; - return false; - } else { - return true; - } -} - -void AscendControlParser::UpdateChildGraphOrder(NotNull kg) { - MS_LOG(INFO) << "Graph id:" << kg->graph_id(); - kg->SetExecOrderByDefault(); - auto call_nodes = kg->FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); - std::vector child_graph_order; - for (auto &call_node : call_nodes) { - MS_EXCEPTION_IF_NULL(call_node); - auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); - for (const auto &child_graph : call_child_graphs) { - MS_EXCEPTION_IF_NULL(child_graph); - if (child_graph != kg->parent_graph()) { - child_graph->set_parent_graph(kg.get()); - } - child_graph_order.push_back(child_graph); - } - } - for (size_t i = 0; i < child_graph_order.size(); i++) { - MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; - } - kg->set_child_graph_order(child_graph_order); -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h deleted file mode 100644 index 7530f2019e..0000000000 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H - -#include -#include -#include -#include -#include "session/kernel_graph.h" -#include "utils/base_ref.h" -#include "utils/contract.h" -#include "utils/union_find_set.h" - -namespace mindspore { -namespace session { -class AscendControlParser { - public: - static void ChildGraphDataAssign(const std::map &graph_id_map); - static void LinkGraph(NotNull kg); - - static void InsertDependToGraph(NotNull kg, NotNull attch_node); - static void InsertControlDependToGraph(NotNull kg, NotNull first_node, - NotNull second_node); - static void ExecutorValidate(NotNull root_graph); - static void UpdateChildGraphOrder(NotNull kg); - - private: - static NotNull GetStartLabel(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label); - static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label, - const NotNull *> memo); - static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, - const NotNull *> memo); - static void RecurseSwitch(NotNull kg, NotNull cur_node, const CNodePtr &next_node, - const NotNull *> memo); - static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, const CNodePtr &next_node, - const NotNull *> memo); - - static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, - const CNodePtr &last_label); - static KernelGraphPtr ParsePartial(NotNull node); - - static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, - NotNull from, NotNull to); - static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); - - // root graph order - static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, - NotNull graph); - static std::vector RecurseGraph(NotNull graph, - const NotNull *> memo); -}; -} // namespace session -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H diff --git a/mindspore/ccsrc/session/ascend_inference_session.cc b/mindspore/ccsrc/session/ascend_inference_session.cc deleted file mode 100644 index aef7738d0b..0000000000 --- a/mindspore/ccsrc/session/ascend_inference_session.cc +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/ascend_inference_session.h" -#include "operator/ops.h" -#include "ir/tensor.h" -#include "ir/tensor_py.h" -#include "ir/anf.h" -#include "ir/param_value_py.h" -#include "device/kernel_runtime.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "common/trans.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "utils/config_manager.h" -#include "utils/base_ref_extends.h" - -using mindspore::tensor::TensorPy; - -namespace mindspore { -namespace session { -namespace { -std::set weight_infos; -static TypeId GetDataType(const py::buffer_info &buf) { - if (buf.format.size() == 1) { - switch (buf.format.front()) { - case 'e': - case 'f': - case 'd': - switch (buf.itemsize) { - case 2: - return TypeId::kNumberTypeFloat16; - case 4: - return TypeId::kNumberTypeFloat32; - case 8: - return TypeId::kNumberTypeFloat64; - } - break; - case 'b': - case 'h': - case 'i': - case 'l': - case 'q': - switch (buf.itemsize) { - case 1: - return TypeId::kNumberTypeInt8; - case 2: - return TypeId::kNumberTypeInt16; - case 4: - return TypeId::kNumberTypeInt32; - case 8: - return TypeId::kNumberTypeInt64; - } - break; - case 'B': - case 'H': - case 'I': - case 'L': - case 'Q': - switch (buf.itemsize) { - case 1: - return TypeId::kNumberTypeUInt8; - case 2: - return TypeId::kNumberTypeUInt16; - case 4: - return TypeId::kNumberTypeUInt32; - case 8: - return TypeId::kNumberTypeUInt64; - } - break; - case '?': - return TypeId::kNumberTypeBool; - } - } - MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize; - return TypeId::kTypeUnknown; -} -} // namespace -void AscendInferenceSession::LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector inputs(inputs_const); - auto input_nodes = kernel_graph->inputs(); - - size_t no_weight_input = 0; - for (size_t i = 0; i < input_nodes.size(); ++i) { - tensor::TensorPtr tensor = nullptr; - if (!input_nodes[i]->isa()) { - MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; - continue; - } - auto pk_node = input_nodes[i]->cast(); - MS_EXCEPTION_IF_NULL(pk_node); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - MS_EXCEPTION_IF_NULL(device_address); - if (AnfAlgo::IsParameterWeight(pk_node)) { - if (weight_infos.count(pk_node) != 0) { - continue; - } - auto param_value = std::dynamic_pointer_cast(pk_node->default_param()); - MS_EXCEPTION_IF_NULL(param_value); - auto py_param = param_value->value(); - MS_EXCEPTION_IF_NULL(py_param); - py::array py_array = py_param.cast(); - py::buffer_info buf = py_array.request(); - auto buf_type = GetDataType(buf); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - weight_infos.insert(pk_node); - } else { - tensor = inputs[no_weight_input++]; - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - } - } -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_inference_session.h b/mindspore/ccsrc/session/ascend_inference_session.h deleted file mode 100644 index 53be881f93..0000000000 --- a/mindspore/ccsrc/session/ascend_inference_session.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "session/ascend_session.h" -#include "session/kernel_graph.h" -#include "kernel/kernel.h" -#include "session/session_factory.h" -#include "session/ascend_control_parser.h" - -namespace mindspore { -namespace session { -class AscendInferenceSession : public AscendSession { - public: - AscendInferenceSession() = default; - ~AscendInferenceSession() = default; - void LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const; -}; -MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc deleted file mode 100644 index f361cb26ca..0000000000 --- a/mindspore/ccsrc/session/ascend_session.cc +++ /dev/null @@ -1,1755 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/ascend_session.h" -#include -#include -#include -#include -#include -#include -#include "operator/ops.h" -#include "ir/tensor.h" -#include "ir/anf.h" -#include "common/trans.h" -#include "device/kernel_runtime.h" -#include "device/ascend/kernel_select_ascend.h" -#include "device/ascend/kernel_build_ascend.h" -#include "device/ascend/ascend_kernel_runtime.h" -#include "device/ascend/ascend_device_address.h" -#include "pre_activate/ascend/ascend_backend_optimization.h" -#include "pre_activate/common/common_backend_optimization.h" -#include "device/kernel_adjust.h" -#include "device/ascend/ascend_stream_assign.h" -#include "device/ascend/ascend_label_assign.h" -#include "predict/predict.h" -#include "session/anf_runtime_algorithm.h" -#include "ir/scalar.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" -#include "debug/draw.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" -#include "device/kernel_runtime_manager.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "utils/config_manager.h" -#include "utils/base_ref_extends.h" -#include "debug/tensor_load.h" - -namespace mindspore { -namespace session { -const size_t kInvalidIndex = SIZE_MAX; -namespace { -void DumpGraphExeOrder(const std::vector &execution_order, const std::string &tag = "") { - MS_LOG(INFO) << "Dump execution_order size " << execution_order.size(); - MS_LOG(INFO) << "[index][stream_label][graph_id][node string]"; - int i = 0; - for (auto &cnode : execution_order) { - MS_EXCEPTION_IF_NULL(cnode); - MS_LOG(INFO) << "[ " << i << "]" - << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]" - << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]" - << "[" << cnode->DebugString() << "]"; - i++; - } - - std::stringstream buf; - buf << "================== execution order ==================\n"; - if (!tag.empty()) { - buf << tag << "\n"; - } - buf << "execution_order size: " << execution_order.size() << "\n"; - i = 0; - for (auto &cnode : execution_order) { - MS_EXCEPTION_IF_NULL(cnode); - buf << i << ":\n"; - buf << "\t" << cnode->DebugString() << "\n"; - buf << "\t" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "\n"; - buf << "\t" << AnfAlgo::GetGraphId(cnode.get()) << "\n"; - i++; - } - buf << "================== execution order ==================\n"; - // std::cout << buf.str() << std::endl; -} - -void DumpGraphInputArgs(const VectorRef &args) { - MS_LOG(INFO) << "Args size[%lu]" << args.size(); - for (size_t i = 0; i < args.size(); i++) { - if (utils::isa(args[i])) { - auto anf = utils::cast(args[i]); - MS_EXCEPTION_IF_NULL(anf); - MS_LOG(INFO) << "Parameter arg" << i << " = [%s]" << anf->DebugString(); - } else if (utils::isa(args[i])) { - auto value = utils::cast(args[i]); - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Tensor arg" << i << " = " << value->ToString(); - } else { - MS_LOG(INFO) << "Unknown arg" << i << " = " << args[i].ToString(); - } - } -} - -void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) { - MS_EXCEPTION_IF_NULL(graph); - if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) { - graph->set_stream_distinction_label(label); - } -} - -std::vector GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) { - MS_EXCEPTION_IF_NULL(graph); - std::vector graph_inputs = graph->inputs(); - auto valid_inputs = graph->valid_inputs(); - size_t real_args_size = 0; - std::vector real_args = {}; - for (size_t i = 0; i < args.size(); i++) { - if (utils::isa(args[i])) { - auto tmp_args = AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem}); - for (auto &real_arg : tmp_args) { - auto anf_node = utils::cast(real_arg); - MS_EXCEPTION_IF_NULL(anf_node); - auto abstract = anf_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa() && - !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - real_args_size += tuple_abstract->size(); - continue; - } - real_args_size += 1; - real_args.push_back(real_arg); - } - } else { - real_args_size += 1; - real_args.push_back(args[i]); - } - } - if (graph_inputs.size() != valid_inputs.size()) { - MS_LOG(EXCEPTION) << "Graph_inputs.size(): " << graph_inputs.size() - << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; - } - if (real_args_size != graph_inputs.size()) { - for (size_t j = 0; j < valid_inputs.size(); j++) { - if (valid_inputs[j]) { - MS_LOG(INFO) << "Index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); - } - } - MS_LOG(WARNING) << "Real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() - << " not equal"; - } - return real_args; -} - -std::vector GetCNodes(const std::vector &anf_nodes) { - std::vector cnodes = {}; - size_t i = 0; - for (const auto &anf : anf_nodes) { - MS_LOG(INFO) << "Apply_list[" << i++ << "] = " << anf->DebugString(); - MS_EXCEPTION_IF_NULL(anf); - if (anf->isa()) { - cnodes.push_back(anf->cast()); - } - } - return cnodes; -} - -static std::vector> GetChildList(const std::vector &cnodes, - const std::set &cut_prims) { - size_t after_cut_index = 0; - std::vector> ret; - for (size_t i = 0; i < cnodes.size(); ++i) { - bool is_cut_node = false; - for (auto &prim : cut_prims) { - if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) { - is_cut_node = true; - break; - } - } - if (is_cut_node) { - // is call and not switch call,cut to 3 lists - if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) { - // if is not a call,cut to 2 lists - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); - after_cut_index = i; - } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) { - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); - ret.emplace_back(1, cnodes[i]); - after_cut_index = i + 1; - continue; - } - } - // get last child graph list - if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end()); - continue; - } - } - return ret; -} - -static void BindCallArgsWithParameter(const std::vector ¶meters, const std::vector &args, - const KernelGraphPtr &graph, KernelGraphPtr child_graph, - const NotNull *> memo) { - MS_EXCEPTION_IF_NULL(child_graph); - MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id(); - if (args.empty()) { - return; - } - if (parameters.size() != args.size()) { - MS_LOG(EXCEPTION) << "Graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() - << " and args size:" << args.size() << " not equal!"; - } - child_graph->SetExecOrderByDefault(); - for (size_t i = 0; i < parameters.size(); i++) { - MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ",args[" << i << "]" - << args[i]->DebugString(); - if (args[i] == parameters[i]) { - MS_LOG(INFO) << "Parameter and arg are same."; - continue; - } - child_graph->SetRealInput(parameters[i], args[i]); - if (memo->find(child_graph) != memo->end() || !args[i]->isa()) { - MS_LOG(INFO) << "Add unreused arg,graph:" << graph->graph_id(); - child_graph->AddUnreuseArgs(args[i], graph); - } - } -} - -// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of -// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] -static void UpdateRealInput(NotNull graph, bool split_flag, - const NotNull *> memo) { - MS_EXCEPTION_IF_NULL(memo.get()); - auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); - for (auto &call_node : call_nodes) { - MS_EXCEPTION_IF_NULL(call_node); - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); - if (child_graphs.size() == 1) { - MS_EXCEPTION_IF_NULL(child_graphs[0]); - std::vector real_args = - std::vector(call_node->inputs().begin() + 2, call_node->inputs().end()); - std::vector child_inputs = child_graphs[0]->inputs(); - BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo); - if (split_flag) { - call_node->set_inputs(std::vector(call_node->inputs().begin(), call_node->inputs().begin() + 2)); - } - } else if (child_graphs.size() == 2) { - auto get_partial_args = [&](size_t input_index) -> std::vector { - auto switch_node = call_node->input(1); - MS_EXCEPTION_IF_NULL(switch_node); - auto switch_cnode = switch_node->cast(); - MS_EXCEPTION_IF_NULL(switch_cnode); - auto partial = switch_cnode->input(input_index); - MS_EXCEPTION_IF_NULL(partial); - if (IsValueNode(partial)) { - return {}; - } - auto partial_cnode = partial->cast(); - MS_EXCEPTION_IF_NULL(partial_cnode); - auto ret = std::vector(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); - if (split_flag) { - partial_cnode->set_inputs( - std::vector(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); - } - return ret; - }; - BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo); - BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo); - } - } -} - -static void RecurseToUpdateCallRealInput(NotNull graph, - const NotNull *> memo) { - memo->insert(graph.get()); - MS_LOG(INFO) << "Start graph id:" << graph->graph_id(); - for (auto &child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) != memo->end()) { - MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() - << ",parent graph:" << graph->parent_graph()->graph_id(); - continue; - } - RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo); - } - // this action should from bottom to top - graph->UpdateCallRealInput(); -} -} // namespace - -GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - MS_LOG(INFO) << "Start"; - // construct graph, if successfully, graph_sum_ + 1 - auto graph = ConstructKernelGraph(lst, outputs); - auto graph_id = graph->graph_id(); - MS_LOG(INFO) << "Compile graph " << graph_id << " success"; - return graph_id; -} - -GraphId AscendSession::CompileGraph(NotNull func_graph) { - MS_LOG(INFO) << "Start"; - std::vector all_graphs; - auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); - BackendOptimization(all_graphs); - // empty graph dont entry to backend - if (root_graph->execution_order().empty()) { - MS_LOG(INFO) << root_graph->ToString() << " is empty graph."; - root_graph->set_executable(false); - InitRuntimeResource(); - return root_graph->graph_id(); - } - // split switch - SplitGraphs(NOT_NULL(root_graph)); - // insert goto labels and label_sets - LinkChildGraphs(NOT_NULL(root_graph)); - // resource initialize - InitRuntimeResource(); - // recurse compile child root_graph - std::set memo; - RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo)); - // root root_graph valiate,include genearte execute order and so on - RootGraphExecutorValidate(NOT_NULL(root_graph)); - // adjust kernel - AdjustKernel(root_graph); - // assign stream - AssignStream(NOT_NULL(root_graph)); - // insert profiling point - device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); - // build kernel - BuildKernel(root_graph); - // alloc mem - MemoryAlloc(root_graph.get()); - // task generate - GenerateTaskInfo(root_graph); - // load task into device - LoadTask(root_graph); - // return the root_graph id to backend - auto graph_id = root_graph->graph_id(); - return graph_id; -} - -void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto graph_order = GetGraphOrder(kernel_graph->graph_id()); - for (auto graph_id : graph_order) { - auto child_graph = GetGraph(graph_id); - if (child_graph == nullptr) { - continue; - } - if (child_graph->summary_node_exist()) { - kernel_graph->set_summary_node_exist(true); - return; - } - } - kernel_graph->set_summary_node_exist(false); -} - -void AscendSession::BuildGraph(GraphId graph_id) { - MS_LOG(INFO) << "Start"; - auto graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(graph); - // resource initialize - InitRuntimeResource(); - // multiple graph handle - if (graph_id == final_graph_id_) { - if (!graph->executable()) { - return; - } - // insert assigns to child graph - InsertAllAssigns(); - // insert switch and active to child graph - MergeSwitchCompile(); - SetFinalGraphSummaryFlag(graph); - // OptChildGraphs - auto graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - for (size_t i = 0; i < graph_order.size(); i++) { - if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { - continue; - } - MS_LOG(INFO) << "Start build child graph " << graph_order[i]; - auto child_graph = GetGraph(graph_order[i]); - CompileChildGraph(child_graph); - } - GetSummaryNodes(graph.get()); - // merge child graph - MergeGraphExecOrder(); - } else { - auto single_graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(single_graph); - CompileChildGraph(single_graph); - // set the distinction label of single graph - single_graph->set_stream_distinction_label(graph_id); - single_graph->UpdateExecuteKernelStreamLabel(); - } - // adjust execution order because merge child graph and other special operations - AdjustKernel(graph); - // Assign streams for control sink and hccl and so on - AssignStream(NOT_NULL(graph)); - - device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); - // build kernel if node is cnode - BuildKernel(graph); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->precompile_only()) { - MS_LOG(INFO) << "Precompile only, stop in build kernel step"; - } else { - // alloc memory, including static memory and dynamic memory - MemoryAlloc(graph.get()); - // generate task info for task sink mode - GenerateTaskInfo(graph); - // load task info to device if it is sink mode - LoadTask(graph); - } - // sync the inital const tensor to device - SyncInitialTenosrToDevice(); - ExportChildGraphs(graph_id); - MS_LOG(INFO) << "End"; -} - -void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { - MS_EXCEPTION_IF_NULL(child_graph); - MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString(); - opt::AscendBackendIRFusionOptimization(child_graph); - opt::AscendBackendFuseBasicOpt(child_graph, true); - opt::AscendBackendGraphKernelOpt(child_graph, true); - child_graph->SetExecOrderByDefault(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; - DumpIR(file_path, child_graph); - } - // select kernel build info - SelectKernel(*child_graph); - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; - DumpIR(file_path, child_graph); - } - // convert kernel Graph to model - predictmodel::StepConvertGraph(child_graph); - // optimize graph - HardwareOptimize(child_graph); - // assign static memory of parameters - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->AssignStaticMemoryInput(child_graph.get()); - runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); -} - -void AscendSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, - VectorRef *const outputs) { - MS_LOG(INFO) << "Start"; - auto kernel_graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(kernel_graph); - // if none of child graph and no anf output exists - if (!kernel_graph->executable()) { - MS_LOG(INFO) << "No child graph has anf output"; - UpdateOutputs(kernel_graph, outputs, inputs); - return; - } - // load input data from user input - LoadInputData(kernel_graph, inputs); - // convert inputs to model - predictmodel::StepConvertWeight(inputs); -#ifdef ENABLE_DEBUGGER - // debugger pre-execution processing - if (debugger_) { - debugger_->PreExecute(kernel_graph); - } -#endif - { - py::gil_scoped_release release; - // run task on device - ExecTask(kernel_graph); - } - // get result from device - UpdateOutputs(kernel_graph, outputs, inputs); - // summary - Summary(kernel_graph.get()); -#ifdef ENABLE_DEBUGGER - // load tensor from device for debugger - if (debugger_ && debugger_->debugger_enabled()) { - LoadTensor(kernel_graph); - } -#endif - // dump used for debug - Dump(kernel_graph); -#ifdef ENABLE_DEBUGGER - // debugger post-execution processing - if (debugger_) { - debugger_->PostExecute(); - } -#endif - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start"; - // data layout optimization - opt::RunOpAscendDataLayout(kernel_graph); - // mixed precision optimization - opt::AscendMixPrecision(kernel_graph); - MS_LOG(INFO) << "Finish"; -} - -void AscendSession::RunOpExecTask(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "Run task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { - if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { - return true; - } - - return false; -} - -void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors, const std::vector &tensors_mask) { - MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; - if (GraphCacheExist(graph_info)) { - MS_LOG(INFO) << "Build op " << op_run_info.op_name << " graph cache has existed !"; - return; - } - - // construct graph include one op - auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); - MS_EXCEPTION_IF_NULL(graph); - opt::RunOpAscendBackendIRFusionOptimization(graph); - // kernel select - SelectKernel(*graph); - // optimize - RunOpHardwareOptimize(graph); - // init runtime resource - InitRuntimeResource(); - // build kernel - RunOpAdjustKernel(graph); - BuildKernel(graph); - run_op_graphs_[graph_info] = graph; - MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !"; -} - -py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) { - auto graph = run_op_graphs_[graph_info]; - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; - // malloc mem - RunOpMemoryAlloc(input_tensors, graph.get()); - // load input data to device - LoadInputData(graph, input_tensors); - // run op - RunOpExecTask(graph); - // get output - VectorRef outputs; - UpdateOutputs(graph, &outputs, input_tensors); - // trans output to tuple - auto output_tensors = TransformBaseRefListToTuple(outputs); - if (!utils::isa(output_tensors) || - !py::isinstance(utils::cast(output_tensors).object_)) { - MS_LOG(EXCEPTION) << "The output tensors should be a tuple !"; - } - py::object tuple_obj = utils::cast(output_tensors).object_; - py::tuple tuple_tensors = py::cast(tuple_obj); - RunOpMemoryClear(graph.get()); - MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; - return tuple_tensors; -} - -// compile graph steps -void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - size_t raise_precision_count = 0; - size_t reduce_precision_count = 0; - for (const auto &cnode : kernel_graph.execution_order()) { - auto status = device::ascend::SelectKernelInfo(cnode); - if (status == device::ascend::kStatusRaisePrecision) { - raise_precision_count++; - } else if (status == device::ascend::kStatusReducePrecision) { - reduce_precision_count++; - } - MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); - } - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kGraphMode) { - if (raise_precision_count > 0) { - MS_LOG(WARNING) << "There has " << raise_precision_count - << " node/nodes used raise precision to selected the kernel!"; - } - if (reduce_precision_count > 0) { - MS_LOG(WARNING) << "There has " << reduce_precision_count - << " node/nodes used reduce precision to selected the kernel!"; - } - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::InitRuntimeResource() { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - if (!runtime_instance->Init()) { - MS_LOG(EXCEPTION) << "Kernel runtime init error."; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::HardwareOptimize(const std::shared_ptr &kernel_graph) const { - device::ascend::KernelPreBuild(kernel_graph.get()); - MS_LOG(INFO) << "HardwareOptimize start!"; - opt::AscendBackendOptimization(kernel_graph); - opt::AscendGraphKernelCommonProcess(kernel_graph); - opt::AscendBackendFuseBasicOpt(kernel_graph, false); - opt::AscendBackendAddAtomicClean(kernel_graph); - MS_EXCEPTION_IF_NULL(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - MS_LOG(INFO) << "HardwareOptimize Finish!"; -} - -void AscendSession::AdjustKernel(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - opt::HideNopNode(kernel_graph.get()); - // Insert CLearZero op - // prepare for next step from json get atomic info - BuildKernel(kernel_graph); - device::ascend::KernelBuildPreprocess(kernel_graph.get()); - device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "after_adjust_kernel.ir"; - DumpIR(file_path, kernel_graph); - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - opt::HideNopNode(kernel_graph.get()); - // Insert CLearZero op - // prepare for next step from json get atomic info - BuildKernel(kernel_graph); - device::ascend::KernelBuildPreprocess(kernel_graph.get()); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::AssignStream(NotNull kernel_graph) const { - MS_LOG(INFO) << "Start!"; - device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::BuildKernel(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - auto ret = device::ascend::KernelBuild(kernel_graph.get()); - if (!ret) { - MS_LOG(EXCEPTION) << "Kernel build error."; - } - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "KernelBuild run in " << PRIu64 << " us " << cost; - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(kernel_graph); - opt::RemoveNopNode(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->AssignMemory(kernel_graph); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::RunOpMemoryAlloc(const std::vector &input_tensors, - KernelGraph *kernel_graph) const { - MS_LOG(INFO) << "Start memory alloc!"; - MS_EXCEPTION_IF_NULL(kernel_graph); - opt::RemoveNopNode(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpClearMemory(kernel_graph); -} - -void AscendSession::GenerateTaskInfo(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->GenTask(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "Generate task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::LoadTask(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->LoadTask(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "Load task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::ExecTask(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->Run(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "run task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::Dump(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - (void)runtime_instance->DumpData(kernel_graph.get()); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::ExportChildGraphs(const GraphId graph_id) { -#ifdef ENABLE_DUMP_IR - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - if (!save_graphs) { - return; - } - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (graph_id == final_graph_id_) { - const auto &graph_order = GetGraphOrder(final_graph_id_); - const auto &graph_type = GetGraphOrderType(final_graph_id_); - for (size_t i = 0; i < graph_order.size(); i++) { - if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { - continue; - } - const auto child_graph = GetGraph(graph_order[i]); - MS_LOG(DEBUG) << "Start export child graph " << graph_order[i]; - MS_EXCEPTION_IF_NULL(child_graph); - std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(child_graph->graph_id()) + ".ir"; - DumpIR(file_path, child_graph, true); - DumpIRProto(child_graph, "vm_build_" + std::to_string(child_graph->graph_id())); - MS_LOG(DEBUG) << "End export child graph " << graph_order[i]; - } - } -#endif -} - -void AscendSession::LoadTensor(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(kernel_graph); -#ifdef ENABLE_DEBUGGER - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - DebugServices *debug_services = debugger_->get_debug_services(); - TensorLoader *tensor_loader = debug_services->get_tensor_loader(); - tensor_loader->EmptyTensor(); - uint32_t iter_num = tensor_loader->GetIterNum(); - tensor_loader->set_iter_num(++iter_num); - (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get()); -#endif - MS_LOG(INFO) << "Finish!"; -} - -GraphId AscendSession::SetFinalGraphInput(const std::vector &args) { - MS_LOG(INFO) << "Start! Args size " << args.size(); - auto final_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(final_graph); - final_graph_id_ = final_graph->graph_id(); - MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success"; - // init private variables and bind them with final_graph_id - graph_execute_orders_[final_graph_id_] = std::vector(); - graph_order_types_[final_graph_id_] = std::vector(); - for (const auto ¶meter : args) { - MS_EXCEPTION_IF_NULL(parameter); - if (!parameter->isa()) { - MS_LOG(EXCEPTION) << parameter->DebugString() << " is not a parameter type!"; - } - AnfNodePtr parameter_backend = nullptr; - // if function return UINT_MAX,the parameter is not exist in child graph - auto parameter_belong_graph_id = GetGraphIdByNode(parameter); - if (parameter_belong_graph_id == kInvalidGraphId) { - parameter_backend = CreateNewParameterFromParameter(parameter, true, final_graph.get()); - final_graph->FrontBackendlMapAdd(parameter, parameter_backend); - MS_LOG(INFO) << "New parameter" << parameter->DebugString() << "in final_graph"; - } else { - // parametr is a parameter of child graph - auto graph = GetGraph(parameter_belong_graph_id); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Reuse parameter [" << parameter->DebugString() << "] of child graph [" - << parameter_belong_graph_id << "]"; - parameter_backend = graph->GetBackendAnfByFrontAnf(parameter); - // add parameter in backend to final graph inputs - auto final_graph_inputs = final_graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(final_graph_inputs); - final_graph_inputs->push_back(parameter_backend); - } - MS_EXCEPTION_IF_NULL(parameter_backend); - MS_LOG(INFO) << "Parameter backend " << parameter_backend->DebugString() << " belong_graph_id " - << AnfAlgo::GetGraphId(parameter_backend.get()); - } - MS_LOG(INFO) << "End final_graph_id " << final_graph_id_; - return final_graph_id_; -} - -void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph, - std::map> *summary) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(summary); - // if final graph have no child graph - auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); - if (graph_order_iter == graph_execute_orders_.end()) { - SessionBasic::GetSummaryNodes(graph); - auto summary_nodes = graph->summary_nodes(); - summary->insert(summary_nodes.begin(), summary_nodes.end()); - return; - } - // for every child graph, find summary nodes - auto graph_order = GetGraphOrder(graph->graph_id()); - for (size_t i = 0; i < graph_order.size(); i++) { - auto child_graph = GetGraph(graph_order[i]); - if (child_graph == nullptr) { - continue; - } - SessionBasic::GetSummaryNodes(child_graph.get()); - auto child_graph_summary = child_graph->summary_nodes(); - summary->insert(child_graph_summary.begin(), child_graph_summary.end()); - RecurseGetSummaryNodes(child_graph.get(), summary); - } - graph->set_summary_nodes(*summary); -} - -void AscendSession::GetSummaryNodes(KernelGraph *graph) { - MS_LOG(DEBUG) << "Update summary Start"; - MS_EXCEPTION_IF_NULL(graph); - auto summary_nodes = graph->summary_nodes(); - std::map> summary; - summary.insert(summary_nodes.begin(), summary_nodes.end()); - RecurseGetSummaryNodes(graph, &summary); - graph->set_summary_nodes(summary); - MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); -} - -AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { - auto fake_graph = GetGraph(fake_graph_id); - MS_EXCEPTION_IF_NULL(fake_graph); - auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0); - auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr { - auto parameter = fake_graph->NewParameter(); - MS_EXCEPTION_IF_NULL(parameter); - parameter->set_abstract(abstract); - auto new_parameter = fake_graph->NewParameter(parameter); - // Add new parameter to the graph input of fake_graph to sure that all parameters will be allocated memory. - auto graph_inputs = fake_graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - graph_inputs->push_back(new_parameter); - return new_parameter; - }; - auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr { - MS_EXCEPTION_IF_NULL(cnode); - auto abstract = cnode->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa()) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - MS_LOG(INFO) << "Tuple size [" << tuple_abstract->size() << "]"; - return create_parameter((*tuple_abstract)[output_idx]); - } - return create_parameter(cnode->abstract()); - }; - if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) { - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; - auto make_tuple = output_item_with_index.first->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - for (size_t i = 1; i < make_tuple->inputs().size(); i++) { - auto input = make_tuple->inputs()[i]; - make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input)); - } - return fake_graph->NewCNode(make_tuple_inputs); - } - return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second); -} - -void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) { - // get the backend anf node related to the output node of front - auto output_from_graph_id = GetGraphIdByNode(node); - auto output_from_graph = GetGraph(output_from_graph_id); - MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id - << "] to final graph"; - MS_EXCEPTION_IF_NULL(output_from_graph); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - // if output is from final graph,it remarks no child graph exist - if (final_graph_id_ == output_from_graph_id) { - MS_LOG(INFO) << "No child graph,output is " << node->DebugString(); - final_graph->set_output(ConstructOutput({node}, final_graph)); - final_graph->set_executable(false); - return; - } - final_graph->set_output(output_from_graph->output()); -} - -void AscendSession::SetFinalGraphOutput(const ValuePtr &value) { - auto value_node = NewValueNode(value); - auto kernel_info = std::make_shared(); - value_node->set_kernel_info(kernel_info); - value_node->set_abstract(abstract::FromValue(value)); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); - final_graph->set_executable(false); - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]"; -} - -void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) { - for (auto &output : vec_output) { - if (utils::isa(output)) { - auto output_anf_node = utils::cast(output); - SetFinalGraphOutput(output_anf_node); - } else if (utils::isa(output)) { - auto value = utils::cast(output); - SetFinalGraphOutput(value); - } else { - MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); - } - } -} - -void AscendSession::SetFinalGraphOutput(const BaseRef &output) { - if (utils::isa(output)) { - auto output_anf_node = utils::cast(output); - SetFinalGraphOutput(output_anf_node); - } else if (utils::isa(output)) { - auto value = utils::cast(output); - SetFinalGraphOutput(value); - } else if (utils::isa(output)) { - auto vec_output = utils::cast(output); - SetFinalGraphOutput(vec_output); - } else { - MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); - } -} - -void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) { - MS_LOG(INFO) << "Start!"; - MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]"; - auto condition_graph = GetGraph(condition_graph_id); - MS_EXCEPTION_IF_NULL(condition_graph); - tensor::TensorPtr tensor = std::make_shared(kNumberTypeInt32, std::vector{1}); - int32_t *val = nullptr; - val = static_cast(tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = 0; - auto value_node = std::make_shared(tensor); - value_node->set_abstract(abstract::FromValue(tensor, false)); - auto counter_const = condition_graph->NewValueNode(value_node); - condition_graph->AddValueNodeToGraph(counter_const); - // create a new switch op - auto switch_primitive = std::make_shared("StreamSwitch"); - auto cond_output_it = condition_output_.find(condition_graph_id); - if (cond_output_it == condition_output_.end()) { - MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; - } - auto cond_output_kernel = - AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first; - MS_EXCEPTION_IF_NULL(cond_output_kernel); - std::vector inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; - CNodePtr switch_node = condition_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(switch_node); - switch_node->set_abstract(std::make_shared()); - AnfAlgo::SetGraphId(condition_graph_id, switch_node.get()); - // set attr: cond_ RT_GREATER - AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue(static_cast(RT_GREATER)), switch_node); - // set attr:data_type - AnfAlgo::SetNodeAttr(kAttrDataType, MakeValue(static_cast(RT_SWITCH_INT64)), switch_node); - // set attr:true branch graph id ,which is same to stream distinction label - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(true_graph_id), switch_node); - // append switch at the end of condition graph - auto return_node = condition_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - InsertControlDependToGraph(condition_graph_id, return_node->input(1), switch_node); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { - auto &graph_execute_order = GetGraphOrder(final_graph_id_); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - auto false_index = ExecOrderOfChildGraph(final_graph_id_, false_graph_id); - if (false_index == kInvalidIndex || false_index == 0) { - return; - } - for (int i = SizeToInt(false_index) - 1; i >= 0; i--) { - size_t graph_index = IntToSize(i); - if (graph_index >= graph_execute_order.size()) { - MS_LOG(EXCEPTION) << "Graph index[" << graph_index << "] out of range[" << graph_execute_order.size() << "]"; - } - if (graph_order_type[graph_index] == COMMON_GRAPH) { - auto true_last_id = graph_execute_order[graph_index]; - MS_LOG(INFO) << "The last graph of if true branch is " << true_last_id; - auto true_last = GetGraph(true_last_id); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - auto false_last = GetGraph(false_graph_id); - MS_EXCEPTION_IF_NULL(true_last); - MS_EXCEPTION_IF_NULL(false_last); - MS_LOG(INFO) << "The last graph of false branch is " << false_graph_id; - // create fake output - auto fake_output_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(fake_output_graph); - graph_execute_order.push_back(fake_output_graph->graph_id()); - graph_order_type.push_back(COMMON_GRAPH); - fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output())); - final_graph->set_output(fake_output_graph->output()); - InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output()); - InsertMultipleAssignToGraph(false_graph_id, false_last->output(), final_graph->output()); - // insert stream active for loop sink - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && - ConfigManager::GetInstance().iter_num() > 1) { - // insert active in true graph, another active will be inserted in kernel adjust - InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); - } - break; - } - } -} - -void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id, - const AnfNodePtr &output) { - if (switches_.find(cond_graph_id) != switches_.end()) { - MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; - return; - } - switches_[cond_graph_id] = std::pair(true_graph_id, false_graph_id); - condition_output_[cond_graph_id] = output; - MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; - // set the type of condition graph - auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - if (cond_graph_index >= graph_order_type.size()) { - MS_LOG(EXCEPTION) << "Cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size(); - } - graph_order_type[cond_graph_index] = CONDITION_GRAPH; - // update distinction label of false graph,update before merge to sure the distinction - if (false_graph_id != kInvalidGraphId) { - // false graph and condition in graph same stream - auto condition_graph = GetGraph(cond_graph_id); - MS_EXCEPTION_IF_NULL(condition_graph); - SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); - // if false graph is a condition graph and has been switch compiled before,it's false should be updated again - auto cond_it = switches_.find(false_graph_id); - while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) { - cond_graph_id = cond_it->first; - false_graph_id = cond_it->second.second; - condition_graph = GetGraph(cond_graph_id); - if (condition_graph == nullptr) { - continue; - } - SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); - cond_it = switches_.find(false_graph_id); - } - } -} // namespace session - -void AscendSession::MergeSwitchCompile() { - auto graph_execute_order = GetGraphOrder(final_graph_id_); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - for (auto switch_compile : switches_) { - auto cond_graph_id = switch_compile.first; - auto true_graph_id = switch_compile.second.first; - auto false_graph_id = switch_compile.second.second; - MS_LOG(INFO) << "Switch compile: " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; - auto condition_graph = GetGraph(cond_graph_id); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(condition_graph); - MS_EXCEPTION_IF_NULL(final_graph); - // insert switch to condition graph - InsertSwitchToGraph(cond_graph_id, true_graph_id); - auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); - auto prev_graph_id = kInvalidGraphId; - // if condition graph is the first graph and final graph has assign op,then the final graph is the common graph - if (cond_graph_index == 0 && !final_graph->execution_order().empty()) { - prev_graph_id = final_graph_id_; - // set the distinction label of final graph - SetStreamDistinctionLabel(final_graph, final_graph_id_, true); - // if condition graph is not the first graph - } else if ((cond_graph_index - 1 < graph_execute_order.size()) && - (graph_order_type[cond_graph_index - 1] == COMMON_GRAPH)) { - prev_graph_id = graph_execute_order[cond_graph_index - 1]; - } - // insert stream active to common graph - if (prev_graph_id != kInvalidGraphId) { - InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label()); - } - // if this is a 'if' condition - auto it = while_condition_graphs_.find(cond_graph_id); - if (it == while_condition_graphs_.end()) { - CopyOutputOfIf(false_graph_id); - } else { - // if it is a while,insert a stream active to true graph - GraphId from_graph = it->second; - InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label()); - } - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::InsertAllAssigns() { - std::vector> assigns; - for (auto assign : assigns_) { - auto front_anf = std::get<0>(assign); - auto to_graph_id = std::get<1>(assign); - auto input_idx = std::get<2>(assign); - auto to_graph = GetGraph(to_graph_id); - MS_EXCEPTION_IF_NULL(to_graph); - std::vector graph_inputs = to_graph->inputs(); - if (input_idx >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); - } - auto backend_parameter = graph_inputs[input_idx]; - assigns.emplace_back(std::pair(front_anf, backend_parameter)); - } - // erase the repeat assign - std::set> inserted_nodes; - for (auto &assign : assigns) { - auto front_anf = assign.first; - auto backend_parameter = assign.second; - auto from_graph_id = GetGraphIdByNode(front_anf); - auto from_graph = GetGraph(from_graph_id); - MS_EXCEPTION_IF_NULL(from_graph); - auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); - if (inserted_nodes.find(assign) == inserted_nodes.end()) { - InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter); - (void)inserted_nodes.insert(assign); - } - } -} - -// insert active to graph -void AscendSession::SetActive(GraphId from, GraphId to) { - if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) { - MS_LOG(WARNING) << "To " << to << " has been exits in map,from " << from << ",exist from " - << while_condition_graphs_[to]; - return; - } - MS_LOG(INFO) << "From " << from << " to " << to; - auto &graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - std::vector graph_order_new; - std::vector graph_type_new; - for (size_t i = 0; i < graph_order.size(); i++) { - auto graph_id = graph_order[i]; - graph_order_new.push_back(graph_id); - graph_type_new.push_back(graph_type[i]); - if (from == graph_id) { - graph_order_new.push_back(kInvalidGraphId); - graph_type_new.push_back(BRANCH_END); - } - } - graph_order = graph_order_new; - graph_type = graph_type_new; - // set the graph type of condition graph - graph_type[ExecOrderOfChildGraph(final_graph_id_, to)] = CONDITION_GRAPH; - // record the condition graph into while condition set - while_condition_graphs_[to] = from; -} - -void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(front_anf); - auto from_graph_id = GetGraphIdByNode(front_anf); - auto from_graph = GetGraph(from_graph_id); - MS_EXCEPTION_IF_NULL(from_graph); - auto to_graph = GetGraph(to_graph_id); - MS_EXCEPTION_IF_NULL(to_graph); - std::vector graph_inputs = to_graph->inputs(); - if (input_idx >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); - } - auto backend_parameter = graph_inputs[input_idx]; - MS_EXCEPTION_IF_NULL(backend_parameter); - auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); - MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" - << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) - << "]"; - // a node should not assign to itself - if (backend_arg.get() == backend_parameter.get()) { - return; - } - // if arg is the the parameter of child graph,it is parameter of final graph too - if (front_anf->isa()) { - MS_EXCEPTION_IF_NULL(backend_arg); - MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString() - << "] will be replaced."; - to_graph->ReplaceNode(NOT_NULL(backend_parameter), NOT_NULL(backend_arg)); - return; - } - MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node" - << backend_parameter->DebugString() << "of graph " << to_graph_id; - assigns_.emplace_back(std::tuple(front_anf, to_graph_id, input_idx)); -} - -void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, - size_t input_idx) { - MS_LOG(INFO) << "Start!"; - std::pair graph_input_pair(to_graph_id, input_idx); - initial_tenosrs_[graph_input_pair] = front_tensor; - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::UpdateGraphOrder(GraphId to_graph_id) { - MS_LOG(INFO) << "To_graph_id " << to_graph_id; - auto &graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - for (size_t i = 0; i < graph_order.size(); i++) { - if (graph_order[i] == to_graph_id) { - return; - } - } - // if graph is not in graph order,add it to graph order - SetStreamDistinctionLabel(GetGraph(to_graph_id), to_graph_id, false); - graph_order.push_back(to_graph_id); - graph_type.push_back(COMMON_GRAPH); - for (size_t i = 0; i < graph_order.size(); i++) { - MS_LOG(INFO) << "Index " << i << ",graph_id " << graph_order[i] << ",graph_type" << graph_type[i]; - } -} - -size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto output_num = AnfAlgo::GetOutputTensorNum(node); - if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { - return input_index + output_num; - } - auto valid_inputs = graph->valid_inputs(); - if (valid_inputs[input_index]) { - SetChildGraphParameter(node, graph->graph_id(), input_index); - } else { - MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString(); - } - return ++input_index; -} - -size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(value); - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString(); - } - SetChildGraphParameter(value->cast(), graph->graph_id(), input_index); - return ++input_index; -} - -size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) { - auto index = input_index; - for (auto &arg : vec_args) { - if (utils::isa(arg)) { - // arg is a anf node - auto node = utils::cast(arg); - index = SetChildGraphInput(graph, node, input_index); - } else if (utils::isa(arg)) { - // arg is a tensor - auto value = utils::cast(arg); - index = SetChildGraphInput(graph, value, input_index); - } else { - MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString(); - } - } - return index; -} - -void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { - MS_LOG(INFO) << "Set input of graph " << g; - auto to_graph = GetGraph(g); - MS_EXCEPTION_IF_NULL(to_graph); - DumpGraphInputArgs(args); - UpdateGraphOrder(g); - auto &graph_inputs = to_graph->inputs(); - auto real_args = GetRealArgs(to_graph, args); - size_t input_index = 0; - for (size_t i = 0; i < real_args.size(); i++) { - if (input_index >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_index << " out of range size " << graph_inputs.size(); - } - auto &real_arg = real_args[i]; - if (utils::isa(real_arg)) { - // arg is a anf node - auto node = utils::cast(real_arg); - input_index = SetChildGraphInput(to_graph, node, input_index); - } else if (utils::isa(real_arg)) { - // arg is a tensor - auto value = utils::cast(real_arg); - input_index = SetChildGraphInput(to_graph, value, input_index); - } else if (utils::isa(real_arg)) { - // arg is a VectorRef - auto vec_args = utils::cast(real_arg); - input_index = SetChildGraphInput(to_graph, vec_args, input_index); - } else { - MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString(); - } - } - MS_LOG(INFO) << "Finish!"; -} - -GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const { - for (const auto &graph_item : graphs_) { - auto graph = graph_item.second; - MS_EXCEPTION_IF_NULL(graph); - // if front_anf is a parameter,the backend parameter may have two - if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) { - return graph_item.first; - } - } - MS_EXCEPTION_IF_NULL(front_anf); - MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph"; - return kInvalidGraphId; -} - -void AscendSession::MergeGraphExecOrder() { - MS_LOG(INFO) << "Start!"; - // merge graph order - auto &graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - if (graph_order.empty()) { - MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!"; - return; - } - if (graph_order.size() > 1) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->enable_task_sink()) { - MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!"; - } - } - // if first graph is common,the final graph has no label,then set the stream of final graph same with the first graph - SetStreamDistinctionLabel(final_graph, graph_order[0], false); - std::vector final_exec_order = final_graph->execution_order(); - KernelGraphPtr last_graph = nullptr; - for (size_t i = 0; i < graph_order.size(); i++) { - auto graph_id = graph_order[i]; - if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { - continue; - } - auto child_graph = GetGraph(graph_id); - last_graph = child_graph; - MS_EXCEPTION_IF_NULL(child_graph); - auto exec_order = child_graph->execution_order(); - MS_LOG(INFO) << "Merge graph,graph_id " << graph_id; - (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order), - [&](CNodePtr node) -> CNodePtr { - AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get()); - return node; - }); - // add all value nodes of child graphs to final graph - for (auto &value_node : child_graph->graph_value_nodes()) { - final_graph->AddValueNodeToGraph(value_node); - } - // copy ref map to final graph - auto child_ref_map = child_graph->GetRefMap(); - for (auto &item : child_ref_map) { - if (final_graph->IsInRefOutputMap(item.first)) { - MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; - } - final_graph->AddRefCorrespondPairs(item.first, item.second); - } - } - // set final_exec_order into final graph - MS_EXCEPTION_IF_NULL(final_graph); - DumpGraphExeOrder(final_exec_order); - final_graph->set_execution_order(final_exec_order); -} - -void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { - MS_EXCEPTION_IF_NULL(from); - MS_EXCEPTION_IF_NULL(to); - if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && - AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { - return; - } - if (from.get() == to.get()) { - return; - } - MS_LOG(INFO) << "Insert assign to graph " << graph_id << " from " << from->DebugString() << " to " - << to->DebugString(); - auto graph = graphs_[graph_id]; - MS_EXCEPTION_IF_NULL(graph); - // config inputs of assign node - std::vector inputs = {NewValueNode(std::make_shared("Assign")), to, from}; - // generate a new cnode - auto assign_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(assign_node); - assign_node->set_abstract(to->abstract()); - // append the assign at the end of from graph - InsertDependToGraph(graph_id, assign_node); -} - -void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { - std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); - std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); - MS_LOG(INFO) << "Insert assigns from [" << AnfAlgo::GetGraphId(from.get()) << "] to [" - << AnfAlgo::GetGraphId(to.get()) << "]"; - if (from_outputs.size() != to_outputs.size()) { - MS_LOG(INFO) << "From[" << from->DebugString(5) << "] to[" << to->DebugString(5) << "]"; - MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" - << to_outputs.size() << "]"; - } - for (size_t i = 0; i < from_outputs.size(); i++) { - InsertAssignToGraph(graph_id, from_outputs[i], to_outputs[i]); - } -} - -void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) { - MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream; - auto from_graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(from_graph); - std::vector inputs = {NewValueNode(std::make_shared("StreamActive"))}; - auto active_node = from_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(active_node); - active_node->set_abstract(std::make_shared()); - // set the active stream id into the attr of active node - std::vector active_index_value = {}; - active_index_value.push_back(actived_stream); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_value), active_node); - // append the active node at the end of from graph - auto return_node = from_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - InsertControlDependToGraph(graph_id, return_node->input(1), active_node); -} - -void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { - AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node)); -} - -void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, - const AnfNodePtr &second_node) { - AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node), - NOT_NULL(second_node)); -} - -size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { - auto &graph_order = GetGraphOrder(final_graph); - for (size_t i = 0; i < graph_order.size(); i++) { - if (child_graph == graph_order[i]) { - return i; - } - } - return kInvalidIndex; -} - -std::vector &AscendSession::GetGraphOrder(GraphId final_graph_id) { - auto graph_order_iter = graph_execute_orders_.find(final_graph_id); - if (graph_order_iter == graph_execute_orders_.end()) { - MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no child graph"; - } - return graph_order_iter->second; -} - -// get graph order type vector by graph id -std::vector &AscendSession::GetGraphOrderType(GraphId final_graph_id) { - auto graph_type_iter = graph_order_types_.find(final_graph_id); - if (graph_type_iter == graph_order_types_.end()) { - MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no graph_order_types_"; - } - return graph_type_iter->second; -} - -void AscendSession::SyncInitialTenosrToDevice() { - for (auto &item : initial_tenosrs_) { - auto to_graph_id = item.first.first; - auto input_idx = item.first.second; - auto front_tensor = item.second; - auto to_graph = GetGraph(to_graph_id); - MS_EXCEPTION_IF_NULL(to_graph); - std::vector graph_inputs = to_graph->inputs(); - if (input_idx >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); - } - auto backend_parameter = graph_inputs[input_idx]; - // sync data from host to device - MS_EXCEPTION_IF_NULL(front_tensor); - size_t tensor_size = front_tensor->data().nbytes(); - auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); - MS_EXCEPTION_IF_NULL(addr); - if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size, - front_tensor->data_type(), front_tensor->data_c())) { - MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; - } - } -} - -static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector &list) { - // count the output of every anf node - std::set has_output_nodes; - for (auto &anf_node : list) { - MS_EXCEPTION_IF_NULL(anf_node); - for (auto &input : anf_node->inputs()) { - (void)has_output_nodes.insert(input); - } - } - - auto make_tuple_primitve = NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())); - std::vector make_tuple_inputs = {make_tuple_primitve}; - int output_idx = 0; - MS_EXCEPTION_IF_NULL(new_kernel_graph); - for (auto &anf_node : list) { - if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { - new_kernel_graph->set_return(anf_node); - } - if (has_output_nodes.find(anf_node) == has_output_nodes.end()) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString(); - make_tuple_inputs.push_back(anf_node); - } - } - if (new_kernel_graph->get_return() == nullptr) { - new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs)); - } -} - -std::vector AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, - const std::vector &list) { - MS_EXCEPTION_IF_NULL(new_kernel_graph); - MS_LOG(INFO) << "Start contruct splited kernel graph:" << new_kernel_graph->graph_id(); - MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); - std::vector call_node_inputs; - std::vector new_graph_inputs; - // create new parameter from cnode - for (auto &anf_node : list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto cnode = anf_node->cast(); - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { - auto input = cnode->inputs()[input_idx]; - MS_EXCEPTION_IF_NULL(input); - AnfNodePtr new_parameter = nullptr; - // check whether input has been put into args of call, if mulptiple use of one parameter or cnode, only set one - // parameter in graph inputs and one arg in call node - auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input); - if (call_input_it != call_node_inputs.end()) { - cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]); - continue; - } - // value node consider move to new graph - if (input->isa()) { - cnode->set_input(input_idx, input); - continue; - } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { - // if is cnode and not in current child graph - new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); - cnode->set_input(input_idx, new_parameter); - } else { - // if is a cnode and in current graph - continue; - } - new_graph_inputs.push_back(new_parameter); - call_node_inputs.push_back(input); - } - } - // set graph inputs of new graph - auto graph_inputs = new_kernel_graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - graph_inputs->clear(); - std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs)); - - MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); - ConstructSplitedGraphOutput(new_kernel_graph, list); - MS_LOG(INFO) << "End"; - return call_node_inputs; -} - -void AscendSession::BackendOptimization(const std::vector &all_graphs) { - MS_LOG(INFO) << "Start BackendCommonOptimization"; - for (auto &graph : all_graphs) { - opt::BackendCommonOptimization(graph); - } - MS_LOG(INFO) << "End."; -} - -void AscendSession::SplitGraphs(NotNull root_graph) { - std::set memo; - // if root graph output is a call node ,the root graph is condition graph of 'if' sentence - auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first; - if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) { - SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo)); - for (auto &child_graph : root_graph->child_graph_order()) { - RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo)); - } - } else { - RecurseSplitGraph(root_graph, NOT_NULL(&memo)); - } - memo.clear(); - // add maketuple to the end of the last child graph to suit old process - auto output_graph = root_graph->child_graph_order().empty() ? root_graph : root_graph->child_graph_order().back(); - auto make_tuple = output_graph->NewCNode( - {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), output_graph->output()}); - output_graph->set_output(make_tuple); - // replace the real input if the real input is a call - RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo)); -} - -AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull graph, - const std::vector &child_graph_list) { - // if child graph list only has a call ,then return the exist call - if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { - return child_graph_list[0]; - } - // create new child graph - auto child_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(child_graph); - // create new value node to bind child graph - auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph)); - std::vector new_call_input = {NewValueNode(std::make_shared(prim::kPrimCall->name())), - graph_value_node}; - // set the graph id of all node of child graph - for (auto &child_graph_node : child_graph_list) { - AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); - } - auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list); - std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input)); - auto new_call = graph->NewCNode(new_call_input); - AnfAlgo::SetNodeAttr("graph_id", MakeValue(graph->graph_id()), new_call); - return new_call; -} - -void AscendSession::SplitGraph(NotNull graph, const std::set &cut_prims, - const NotNull *> memo) { - MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); - bool split_flag = false; - auto apply_list = GetCNodes(TopoSort(graph->get_return())); - // update the root graph child graph order - AscendControlParser::UpdateChildGraphOrder(graph); - // get child list from current graph - std::vector> child_graph_lists = GetChildList(apply_list, cut_prims); - if (child_graph_lists.size() > 1) { - std::list depend_input = {}; - for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { - auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]); - MS_EXCEPTION_IF_NULL(call_node); - // if call node is the last call of true graph,no need create child graph after that - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); - depend_input.push_front(call_node); - if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) { - break; - } - } - depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name())))); - auto depend = graph->NewCNode(std::vector(depend_input.begin(), depend_input.end())); - auto new_return_primitive = - graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimReturn->name()))); - graph->set_return(graph->NewCNode({new_return_primitive, depend})); - AnfNodePtr pre_call_node = nullptr; - AnfNodePtr cur_call_node = nullptr; - auto iter = depend_input.begin(); - for (++iter; iter != depend_input.end(); ++iter) { - pre_call_node = cur_call_node; - cur_call_node = *iter; - if (pre_call_node != nullptr && cur_call_node != nullptr) { - AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node)); - } - } - split_flag = true; - } - AscendControlParser::UpdateChildGraphOrder(graph); - UpdateRealInput(graph, split_flag, memo); - MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; -} - -void AscendSession::RecurseSplitGraph(NotNull graph, const NotNull *> memo) { - memo->insert(graph.get()); - SplitGraph(graph, {prim::kPrimCall}, memo); - for (auto &child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) == memo->end()) { - RecurseSplitGraph(NOT_NULL(child_graph), memo); - } - } -} - -void AscendSession::LinkChildGraphs(NotNull graph) { AscendControlParser::LinkGraph(graph); } - -void AscendSession::RootGraphExecutorValidate(NotNull graph) { - AscendControlParser::ExecutorValidate(graph); -} - -void AscendSession::RecurseCompileGraph(NotNull graph, const NotNull *> memo) { - memo->insert(graph.get()); - CompileChildGraph(graph); - for (auto child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) != memo->end()) { - continue; - } - RecurseCompileGraph(NOT_NULL(child_graph), memo); - // copy ref map to final graph - auto child_ref_map = child_graph->GetRefMap(); - for (auto &item : child_ref_map) { - if (graph->IsInRefOutputMap(item.first)) { - MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; - } - graph->AddRefCorrespondPairs(item.first, item.second); - } - } -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h deleted file mode 100755 index 531860c379..0000000000 --- a/mindspore/ccsrc/session/ascend_session.h +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "session/session_basic.h" -#include "session/kernel_graph.h" -#include "kernel/kernel.h" -#include "session/session_factory.h" -#include "session/ascend_control_parser.h" - -namespace mindspore { -namespace session { -enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 }; - -class AscendSession : public SessionBasic { - public: - AscendSession() { final_graph_id_ = kInvalidGraphId; } - ~AscendSession() override = default; - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kAscendDevice, device_id); - } - GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - GraphId CompileGraph(NotNull func_graph) override; - void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; - void BuildGraph(GraphId) override; - void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors, const std::vector &tensors_mask) override; - py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) override; - - // set parameters of final graph - GraphId SetFinalGraphInput(const std::vector &args) override; - // set output of final graph - void SetFinalGraphOutput(const BaseRef &output) override; - // insert switch and set the relative active ops - void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; - // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter - void SetChildGraphInput(GraphId g, const VectorRef &args) override; - // get graph id in child graphs by ME front anf node pointer - GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; - // get graph id of final graph - GraphId GetFinalRunGraph() const override { return final_graph_id_; } - // insert active to graph - void SetActive(GraphId, GraphId) override; - // compile child graph when session have multiple child graphs - void CompileChildGraph(const KernelGraphPtr &child_graph); - void RecurseGetSummaryNodes(KernelGraph *graph, std::map> *summary); - void GetSummaryNodes(KernelGraph *graph); - - private: - void InitRuntimeResource(); - void SelectKernel(const KernelGraph &kernel_graph) const; - void HardwareOptimize(const std::shared_ptr &kernel_graph) const; - void AdjustKernel(const std::shared_ptr &kernel_graph) const; - void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; - void AssignStream(NotNull kernel_graph) const; - void BuildKernel(const std::shared_ptr &kernel_graph) const; - void MemoryAlloc(KernelGraph *kernel_graph) const; - void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; - void RunOpMemoryClear(const KernelGraph *kernel_graph) const; - void GenerateTaskInfo(const std::shared_ptr &kernel_graph) const; - void LoadTask(const std::shared_ptr &kernel_graph) const; - void ExecTask(const std::shared_ptr &kernel_graph) const; - void Dump(const std::shared_ptr &kernel_graph) const; - void ExportChildGraphs(const GraphId graph_id); - void LoadTensor(const std::shared_ptr &kernel_graph) const; - // below functions are used for run op - void RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const; - void RunOpExecTask(const std::shared_ptr &kernel_graph) const; - - size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index); - size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index); - size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index); - - void SetFinalGraphOutput(const AnfNodePtr &node); - void SetFinalGraphOutput(const ValuePtr &value); - void SetFinalGraphOutput(const VectorRef &vec_output); - - void SplitGraph(NotNull graph, const std::set &cut_prims, - const NotNull *> memo); - // split graphs with recurse from root graph - void SplitGraphs(NotNull root_graph); - void BackendOptimization(const std::vector &all_graphs); - void LinkChildGraphs(NotNull graph); - void RootGraphExecutorValidate(NotNull graph); - std::vector ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, - const std::vector &list); - void RecurseCompileGraph(NotNull graph, const NotNull *> memo); - void RecurseSplitGraph(NotNull graph, const NotNull *> memo); - AnfNodePtr BindNewCallToNewGraph(NotNull graph, const std::vector &child_graph_list); - - // merge execution order list of child graphs - void MergeGraphExecOrder(); - // insert assion op to sync data bettween different graphs - void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); - // insert mutiple assigns to graph - void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to); - // insert active op to graph - void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream); - // get execute index of graph - size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph); - // handle condition graph from vm - void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id); - // insert depend to graph, used to attch control nodes to graph - void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node); - // insert depend to graph, used to attch control nodes to graph - void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node); - // set child graph parameter if front arg is a anf - void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx); - // set child graph parameter if front arg is a tensor - void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx); - // update the execution order of all child graphs - void UpdateGraphOrder(GraphId to_graph); - // handle switch when merge - void MergeSwitchCompile(); - // get graph order vector by graph id - std::vector &GetGraphOrder(GraphId final_graph_id); - // get graph order type vector by graph id - std::vector &GetGraphOrderType(GraphId final_graph_id); - // copy output of if and else - void CopyOutputOfIf(GraphId false_graph_id); - // check if graph cache exist - bool GraphCacheExist(const GraphInfo &graph_info) const; - // insert all assign to child graph - void InsertAllAssigns(); - // create fake output of final graph - AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output); - // sync intial tensors' data to device - void SyncInitialTenosrToDevice(); - void SetFinalGraphSummaryFlag(const std::shared_ptr &kernel_graph); - - // member variables - // key is final_graph_id,value is child graph execute order of final graph - std::unordered_map> graph_execute_orders_; - // key is final_graph_id,value is the graph types of child graphs - std::unordered_map> graph_order_types_; - // record condition graph of while - std::unordered_map while_condition_graphs_; - // record all conditions - std::unordered_map> switches_; - std::unordered_map condition_output_; - // share parameters - std::vector> assigns_; - // initial tensors, these tensor will sync data to device before run graph - std::map, tensor::TensorPtr> initial_tenosrs_; - // final_graph_id is used in every root graph has it's own session situation - GraphId final_graph_id_; -}; -MS_REG_SESSION(kAscendDevice, AscendSession); -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H diff --git a/mindspore/ccsrc/session/cpu_session.cc b/mindspore/ccsrc/session/cpu_session.cc deleted file mode 100644 index 1927df2f49..0000000000 --- a/mindspore/ccsrc/session/cpu_session.cc +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "session/cpu_session.h" -#include -#include "ir/tensor.h" -#include "ir/anf.h" -#include "kernel/kernel.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_runtime.h" -#include "predict/predict.h" -#include "kernel/cpu/cpu_kernel_factory.h" -#include "device/cpu/kernel_select_cpu.h" -#ifdef ENABLE_DEBUGGER -#include "debug/debugger/debugger.h" -#endif - -namespace mindspore { -namespace session { -ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - if (!anf->isa()) { - MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; - } - auto valid_inputs = graph->MutableValidInputs(); - MS_EXCEPTION_IF_NULL(valid_inputs); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - TraceManager::DebugTrace(std::make_shared(anf->debug_info())); - ParameterPtr new_parameter = graph->NewParameter(anf->cast()); - TraceManager::EndTrace(); - graph_inputs->push_back(new_parameter); - valid_inputs->push_back(valid_input); - return new_parameter; -} - -GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - auto graph_id = graph_sum_; - auto graph = ConstructKernelGraph(lst, outputs); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Set kernel info"; - SetKernelInfo(graph.get()); - predictmodel::StepConvertGraph(graph); - MS_LOG(INFO) << "Build kernel"; - BuildKernel(graph.get()); - MS_LOG(INFO) << "Assign kernel address"; - runtime_.AssignKernelAddress(graph.get()); - return graph_id; -} - -void CPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { - auto &kernel_graph = graphs_[graph_id]; - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_LOG(INFO) << "Bind input output address"; - std::vector need_sync_outputs; - runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs, &need_sync_outputs); - MS_LOG(INFO) << "Run graph start"; - predictmodel::StepConvertWeight(inputs); - auto execution_order = kernel_graph->execution_order(); - Reorder(&execution_order); - - bool enable_summary = summary_callback_ != nullptr; - kernel_graph->set_execution_order(execution_order); - NamedSummaryOutputs summary_outputs; - if (enable_summary) { - GetSummaryNodes(kernel_graph.get()); - summary_outputs = kernel_graph->summary_nodes(); - runtime_.IncreaseSummaryRefCount(summary_outputs); - } -#ifdef ENABLE_DEBUGGER - // debugger pre-execution processing - if (debugger_) { - debugger_->PreExecute(kernel_graph); - } -#endif - bool ret = runtime_.Run(kernel_graph.get()); - if (!ret) { - MS_LOG(EXCEPTION) << "Run graph failed"; - } - for (auto output : need_sync_outputs) { - (void)output->data_sync(); - } - - if (enable_summary) { - Summary(kernel_graph.get()); - runtime_.DecreaseSummaryRefCount(summary_outputs); - } - -#ifdef ENABLE_DEBUGGER - // debugger post-execution processing - if (debugger_) { - debugger_->PostExecute(); - } -#endif - MS_LOG(INFO) << "Run graph end"; -} - -void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto &kernel_nodes = kernel_graph->execution_order(); - for (const auto &kernel_node : kernel_nodes) { - MS_EXCEPTION_IF_NULL(kernel_node); - device::cpu::SetKernelInfo(kernel_node); - } -} - -void CPUSession::BuildKernel(const KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto &kernel_nodes = kernel_graph->execution_order(); - for (const auto &kernel_node : kernel_nodes) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - MS_LOG(INFO) << "Cpu building operator[" << kernel_name << "]."; - std::shared_ptr cpu_kernel = - kernel::CPUKernelFactory::GetInstance().Create(kernel_name, kernel_node); - if (cpu_kernel == nullptr) { - MS_LOG(EXCEPTION) << "Operator[" << kernel_name << "] is not support."; - } - cpu_kernel->Init(kernel_node); - AnfAlgo::SetKernelMod(cpu_kernel, kernel_node.get()); - MS_LOG(INFO) << "Cpu build success operator[" << kernel_name << "]."; - } -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/cpu_session.h b/mindspore/ccsrc/session/cpu_session.h deleted file mode 100644 index 36b987e840..0000000000 --- a/mindspore/ccsrc/session/cpu_session.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_CPU_SESSION_H -#define MINDSPORE_CCSRC_SESSION_CPU_SESSION_H -#include -#include -#include -#include "session/session_basic.h" -#include "session/kernel_graph.h" -#include "device/cpu/cpu_kernel_runtime.h" -#include "session/session_factory.h" -namespace mindspore { -namespace session { -class CPUSession : public SessionBasic { - public: - CPUSession() = default; - ~CPUSession() override = default; - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kCPUDevice, device_id); - } - GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; - - protected: - ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; - - private: - void SetKernelInfo(const KernelGraph *kernel_graph); - void BuildKernel(const KernelGraph *kernel_graph); - device::cpu::CPUKernelRuntime runtime_; -}; -MS_REG_SESSION(kCPUDevice, CPUSession); -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_CPU_SESSION_H diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc deleted file mode 100644 index 7765e93758..0000000000 --- a/mindspore/ccsrc/session/gpu_session.cc +++ /dev/null @@ -1,267 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/gpu_session.h" -#include "device/gpu/kernel_info_setter.h" -#include "device/gpu/gpu_kernel_build.h" -#include "device/gpu/gpu_kernel_runtime.h" -#include "device/gpu/gpu_stream_assign.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/pass/communication_op_fusion.h" -#include "pre_activate/pass/getitem_tuple.h" -#include "pre_activate/gpu/adam_weight_decay_fusion.h" -#include "pre_activate/gpu/adam_fusion.h" -#include "device/kernel_runtime_manager.h" -#include "predict/predict.h" -#include "common/utils.h" -#include "common/trans.h" -#include "utils/context/ms_context.h" -#include "utils/base_ref_extends.h" - -namespace mindspore { -namespace session { -namespace gpu { -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; - -void GPUSession::SelectKernel(const std::shared_ptr &kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - for (const auto &kernel_node : kernel_graph->execution_order()) { - MS_EXCEPTION_IF_NULL(kernel_node); - device::gpu::SetKernelInfo(kernel_node); - } -} - -void GPUSession::StartKernelRT() const { - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - if (!runtime_instance->Init()) { - MS_LOG(EXCEPTION) << "GPU start kernel runtime failed"; - } -} - -void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto pm = std::make_shared(); - pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared()); - optimizer->AddPassManager(pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void GPUSession::HardwareOptimize(const std::shared_ptr &kernel_graph) { - auto optimizer = std::make_shared(); - auto pm = std::make_shared(); - pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared()); - optimizer->AddPassManager(pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void GPUSession::AssignStream(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - device::gpu::AssignGpuStream(kernel_graph); -} - -void GPUSession::BuildKernel(const std::shared_ptr &kernel_graph) const { - device::gpu::GpuBuild(kernel_graph); -} - -void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->AssignMemory(kernel_graph); -} - -void GPUSession::RunOpAllocateMemory(const std::vector &input_tensors, - KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); -} - -void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpClearMemory(kernel_graph); -} - -void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const { - std::vector inputs(inputs_const); - MS_EXCEPTION_IF_NULL(kernel_graph); - auto input_nodes = kernel_graph->inputs(); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - - for (size_t i = 0; i < inputs.size(); ++i) { - auto tensor = inputs[i]; - MS_EXCEPTION_IF_NULL(tensor); - auto input_node = input_nodes[i]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { - auto pk_node = input_node->cast(); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - auto tensor_address = tensor->device_address(); - bool need_sync = false; - if (ms_context->enable_pynative_infer()) { - if (tensor_address == nullptr || tensor_address != device_address) { - need_sync = true; - } - } else if (tensor->is_dirty() || tensor_address == nullptr) { - need_sync = true; - } else if (tensor_address != device_address) { - if (tensor_address->DeviceType() == device_address->DeviceType()) { - AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get()); - } else { - need_sync = true; - } - } - if (need_sync) { - tensor->set_device_address(device_address); - MS_EXCEPTION_IF_NULL(device_address); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - } - } - tensor->set_dirty(false); - } -} - -void GPUSession::Execute(const std::shared_ptr &kernel_graph) const { - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - if (!runtime_instance->Run(kernel_graph.get())) { - MS_LOG(EXCEPTION) << "GPU execute graph failed!"; - } -} - -GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - // Construct graph, if successfully, graph_sum_ + 1 - auto graph_id = graph_sum_; - auto graph = ConstructKernelGraph(lst, outputs); - MS_EXCEPTION_IF_NULL(graph); - // Optimize - Optimize(graph); - // Select kernel build info - SelectKernel(graph); - // Convert kernel Graph to model - predictmodel::StepConvertGraph(graph); - // Start gpu kernel runtime - StartKernelRT(); - // HardwareOptimize - HardwareOptimize(graph); - // Assign CUDA streams - AssignStream(graph); - // Hide NoOp from execution graph - opt::HideNopNode(graph.get()); - // Build kernel if node is cnode - BuildKernel(graph); - // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph - auto execution_order = graph->execution_order(); - Reorder(&execution_order); - graph->set_execution_order(execution_order); - // Get summary nodes. - GetSummaryNodes(graph.get()); - // Remove NoOp from execution graph - opt::RemoveNopNode(graph.get()); - // Alloc memory, including static memory and dynamic memory - AllocateMemory(graph.get()); - MS_EXCEPTION_IF_NULL(context_); - FuncGraphManagerPtr manager = MakeManager({graph}); - context_->AddManager(manager); - if (manager) { - manager->AddFuncGraph(graph); - graph->set_manager(manager); - } - return graph_id; -} - -void GPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { - auto &kernel_graph = graphs_[graph_id]; - // Load input data from user input - LoadInputData(kernel_graph, inputs); - MS_EXCEPTION_IF_NULL(kernel_graph); - // Convert inputs to model - predictmodel::StepConvertWeight(inputs); - { - py::gil_scoped_release gil_release; - // Run graph on GPU - Execute(kernel_graph); - } - // Get result from GPU - UpdateOutputs(kernel_graph, outputs, inputs); - // Summary - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_gpu_summary()) { - Summary(kernel_graph.get()); - } -} - -void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors, const std::vector &tensors_mask) { - // Check if the graph cache exists. - if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { - return; - } - // Prepare the graph - auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); - MS_EXCEPTION_IF_NULL(kernel_graph); - SelectKernel(kernel_graph); - StartKernelRT(); - // Hide NoOp from execution graph - opt::HideNopNode(kernel_graph.get()); - BuildKernel(kernel_graph); - run_op_graphs_[graph_info] = kernel_graph; -} - -py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) { - auto kernel_graph = run_op_graphs_[graph_info]; - MS_EXCEPTION_IF_NULL(kernel_graph); - // Remove NoOp from execution graph - opt::RemoveNopNode(kernel_graph.get()); - RunOpAllocateMemory(input_tensors, kernel_graph.get()); - // Execute the computation - LoadInputData(kernel_graph, input_tensors); - Execute(kernel_graph); - // Fetch outputs - VectorRef outputs; - UpdateOutputs(kernel_graph, &outputs, input_tensors); - // Trans output to tuple - auto output_tensors = TransformBaseRefListToTuple(outputs); - if (!utils::isa(output_tensors) || - !py::isinstance(utils::cast(output_tensors).object_)) { - MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; - } - py::object tuple_obj = utils::cast(output_tensors).object_; - py::tuple tuple_tensors = py::cast(tuple_obj); - RunOpClearMemory(kernel_graph.get()); - return tuple_tensors; -} -} // namespace gpu -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/gpu_session.h b/mindspore/ccsrc/session/gpu_session.h deleted file mode 100644 index 4e46c2138d..0000000000 --- a/mindspore/ccsrc/session/gpu_session.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_GPU_SESSION_H -#define MINDSPORE_CCSRC_SESSION_GPU_SESSION_H - -#include -#include -#include "session/session_basic.h" -#include "session/kernel_graph.h" -#include "session/session_factory.h" -using KernelGraph = mindspore::session::KernelGraph; - -namespace mindspore { -namespace session { -namespace gpu { -class GPUSession : public SessionBasic { - public: - GPUSession() = default; - ~GPUSession() override = default; - - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kGPUDevice, device_id); - } - - GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - - void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; - void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors, const std::vector &tensors_mask) override; - py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) override; - - private: - void SelectKernel(const std::shared_ptr &kernel_graph) const; - - void StartKernelRT() const; - - void Optimize(const std::shared_ptr &kernel_graph); - - void HardwareOptimize(const std::shared_ptr &kernel_graph); - - void AssignStream(const std::shared_ptr &kernel_graph); - - void BuildKernel(const std::shared_ptr &kernel_graph) const; - - void AllocateMemory(KernelGraph *kernel_graph) const; - - void RunOpAllocateMemory(const std::vector &input_tensors, KernelGraph *kernel_graph) const; - - void RunOpClearMemory(KernelGraph *kernel_graph) const; - - void LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const override; - - void Execute(const std::shared_ptr &kernel_graph) const; -}; -using GPUSessionPtr = std::shared_ptr; -MS_REG_SESSION(kGPUDevice, GPUSession); -} // namespace gpu -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_GPU_SESSION_H diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc deleted file mode 100644 index 264e2c661b..0000000000 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ /dev/null @@ -1,994 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/kernel_graph.h" -#include -#include -#include -#include -#include "operator/ops.h" -#include "ir/param_value_py.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "kernel/kernel_build_info.h" -#include "device/kernel_runtime_manager.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace session { -namespace { -constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; -constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; -void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, - std::unordered_set *visited_nodes) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(que); - MS_EXCEPTION_IF_NULL(visited_nodes); - if (visited_nodes->find(node) == visited_nodes->end()) { - que->push(node); - (void)visited_nodes->insert(node); - MS_LOG(DEBUG) << "Push que:" << node->DebugString(); - } -} - -std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { - auto item_with_index = - AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple}); - AnfNodePtr node = item_with_index.first; - MS_EXCEPTION_IF_NULL(node); - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { - auto outputs = AnfAlgo::GetAllOutput(node); - std::set memo; - std::vector new_output; - for (auto &output : outputs) { - if (memo.find(output) != memo.end()) { - continue; - } - memo.insert(output); - new_output.push_back(output); - } - if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) { - node = new_output[0]; - } - } - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { - return {node}; - } - std::vector real_inputs; - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast()); - for (const auto &child_graph : child_graphs) { - if (child_graph->get_output_null()) { - continue; - } - auto real_input = child_graph->output(); - auto child_real_inputs = GetCallRealOutputs(real_input); - std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs)); - } - return real_inputs; -} - -AnfNodePtr MakeValueNode(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return nullptr; - } - - ValueNodePtr new_value_node = std::make_shared(value_node->value()); - new_value_node->set_abstract(value_node->abstract()); - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { - types.push_back(kTypeUnknown); - } - kernel_build_info_builder->SetOutputsDeviceType(types); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - return new_value_node; -} - -bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { - if (left == right) { - return true; - } - if (left == nullptr || right == nullptr) { - return false; - } - if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) { - return false; - } - if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) { - return AnfAlgo::GetNodeAttr(left, kAttrLabelIndex) == - AnfAlgo::GetNodeAttr(right, kAttrLabelIndex); - } - return false; -} -} // namespace -std::vector KernelGraph::outputs() const { - auto graph_output = output(); - if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { - auto make_tuple = output()->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - auto &inputs = make_tuple->inputs(); - return std::vector(inputs.begin() + 1, inputs.end()); - } - return std::vector(1, graph_output); -} - -void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, - std::unordered_set *visited_nodes) { - MS_EXCEPTION_IF_NULL(visit_queue); - MS_EXCEPTION_IF_NULL(visited_nodes); - auto it = node_output_edges_.find(node); - if (it == node_output_edges_.end()) { - // value node and parameter has no input,no need to print log - if (node->isa()) { - MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; - } - return; - } - - // visit all reduce node first, then other nodes - std::vector active_nodes; - for (const auto &output_edge : it->second) { - auto next_node = output_edge.first; - MS_EXCEPTION_IF_NULL(next_node); - if (node_input_num_.find(next_node) == node_input_num_.end()) { - MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; - } - MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() - << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; - if (node_input_num_[next_node] < output_edge.second) { - MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node] - << ",depend edge:" << output_edge.second; - } - node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; - // allreduce first - if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { - (void)visited_nodes->insert(next_node); - if (AnfAlgo::IsCommunicationOp(next_node)) { - MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); - visit_queue->push(next_node); - } else { - active_nodes.emplace_back(next_node); - } - } - } - - for (auto &node : active_nodes) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(DEBUG) << "Visit node:" << node->DebugString(); - visit_queue->push(node); - } -} - -void KernelGraph::SetExecOrderByDefault() { - std::queue seed_nodes; - UpdateNodeEdgeList(&seed_nodes); - execution_order_.clear(); - std::unordered_set visited_nodes; - std::queue zero_input_nodes; - AnfNodePtr last_communication_node = nullptr; - std::queue communication_descendants; - while (!seed_nodes.empty() || last_communication_node != nullptr) { - // seed nodes first, then visit last all reduce node descendant - if (seed_nodes.empty()) { - VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); - last_communication_node = nullptr; - } else { - zero_input_nodes.push(seed_nodes.front()); - seed_nodes.pop(); - } - // all reduce node descendant first, then common queue - while (!zero_input_nodes.empty() || !communication_descendants.empty()) { - AnfNodePtr node = nullptr; - bool is_communication_descendant = false; - if (communication_descendants.empty()) { - node = zero_input_nodes.front(); - zero_input_nodes.pop(); - } else { - node = communication_descendants.front(); - communication_descendants.pop(); - is_communication_descendant = true; - } - // add execute node - MS_EXCEPTION_IF_NULL(node); - if (node->isa() && AnfAlgo::IsRealKernel(node)) { - execution_order_.push_back(node->cast()); - } - // for all reduce node, visit last all reduce node descendant - if (AnfAlgo::IsCommunicationOp(node)) { - if (last_communication_node != nullptr) { - VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); - } - last_communication_node = node; - } else if (is_communication_descendant) { - VisitNodeDescendants(node, &communication_descendants, &visited_nodes); - } else { - VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); - } - } - } - CheckLoop(); - // resort start label / end goto - std::vector re_order; - if (start_label_ != nullptr) { - re_order.push_back(start_label_); - } - for (auto &node : execution_order_) { - if (node == start_label_ || node == end_goto_) { - continue; - } - - if (IsSameLabel(node, end_goto_)) { - end_goto_ = node; - MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id(); - continue; - } - - if (IsSameLabel(node, start_label_)) { - start_label_ = node; - MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id(); - continue; - } - - re_order.push_back(node); - } - if (end_goto_ != nullptr) { - re_order.push_back(end_goto_); - } - execution_order_ = re_order; -} - -void KernelGraph::CheckLoop() { - std::map none_zero_nodes; - if (node_input_edges_.size() != node_input_num_.size()) { - MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size() - << "not equal to node_input_num_ size:" << node_input_num_.size(); - } - for (auto &it : node_input_num_) { - MS_EXCEPTION_IF_NULL(it.first); - string str; - auto node_input_it = node_input_edges_.find(it.first); - if (node_input_it == node_input_edges_.end()) { - MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]"; - } - for (const auto &input_edge : node_input_edges_[it.first]) { - MS_EXCEPTION_IF_NULL(input_edge.first); - str = str.append(input_edge.first->DebugString()).append("|"); - } - if (it.second != 0) { - MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second; - none_zero_nodes[it.first] = it.second; - } - } - // if don't consider control depend and loop exit,a exception will be throw - if (!none_zero_nodes.empty()) { - MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); - } -} - -CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { - auto cnode = FuncGraph::NewCNode(inputs); - MS_EXCEPTION_IF_NULL(cnode); - cnode->set_abstract(std::make_shared()); - CreateKernelInfoFromNewParameter(cnode); - - auto kernel_info = std::make_shared(); - std::vector feature_map_input_indexs; - // if the node only has the primitive(such as getNext) or the node's input has a feature map input - // then the node's output is a feature map output - for (size_t index = 1; index < inputs.size(); ++index) { - auto node = inputs[index]; - if (AnfAlgo::IsFeatureMapOutput(node)) { - feature_map_input_indexs.push_back(index); - } - } - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { - AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); - } - if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { - kernel_info->SetFeatureMapFlag(true); - } - if (AnfAlgo::IsRealCNodeKernel(cnode)) { - AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); - AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); - } - cnode->set_kernel_info(kernel_info); - AnfAlgo::SetGraphId(graph_id_, cnode.get()); - return cnode; -} - -void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { - if (!AnfAlgo::IsGraphKernel(cnode)) { - return; - } - auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector node_list; - std::vector input_list; - std::vector output_list; - kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); - for (auto &anf_node : node_list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto kernel_info = std::make_shared(); - anf_node->set_kernel_info(kernel_info); - auto anf_cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(anf_cnode); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_cnode); ++i) { - auto input_node = anf_cnode->input(i + 1); - MS_EXCEPTION_IF_NULL(input_node); - if (IsValueNode(input_node)) { - auto new_input_node = MakeValueNode(input_node); - if (new_input_node != nullptr) { - anf_cnode->set_input(i + 1, new_input_node); - } - } - } - } - for (auto &anf_node : input_list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto kernel_info = std::make_shared(); - anf_node->set_kernel_info(kernel_info); - } -} - -CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - auto new_cnode = std::make_shared(*cnode); - // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map - if (BackendNodeExistInFrontBackendMap(cnode)) { - FrontBackendlMapUpdate(cnode, new_cnode); - } - AnfAlgo::SetGraphId(graph_id_, cnode.get()); - if (IsInternalOutput(cnode)) { - ReplaceInternalOutput(cnode, new_cnode); - } - return new_cnode; -} - -ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { - ParameterPtr new_parameter = add_parameter(); - MS_EXCEPTION_IF_NULL(new_parameter); - // create kernel_info form new parameter - auto kernel_info = std::make_shared(); - size_t output_tensor_num = 1; - // if use default parameter = nullptr,it remarks create a new parameter from no parameter - if (parameter == nullptr) { - new_parameter->set_abstract(std::make_shared()); - kernel_info->SetFeatureMapFlag(true); - } else { - // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter - new_parameter->set_abstract(parameter->abstract()); - new_parameter->set_name(parameter->name()); - if (AnfAlgo::IsParameterWeight(parameter)) { - auto param_value = std::dynamic_pointer_cast(parameter->default_param()); - auto param_value_new = std::make_shared(param_value->value()); - new_parameter->set_default_param(param_value_new); - kernel_info->SetFeatureMapFlag(false); - } else { - kernel_info->SetFeatureMapFlag(true); - } - } - new_parameter->set_kernel_info(kernel_info); - // create kernel_build_info for new parameter - auto kernel_build_info_builder = std::make_shared(); - // create init data type, - std::vector init_data_type = {}; - - TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); - init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type); - - // set the format of parameter to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector(output_tensor_num, kOpFormat_DEFAULT)); - // set parameter initaial device data type - kernel_build_info_builder->SetOutputsDeviceType(init_data_type); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get()); - AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); - return new_parameter; -} - -std::vector KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - auto node_value = value_node->value(); - auto output_size = AnfAlgo::GetOutputTensorNum(value_node); - std::vector convert_inputs; - if (!node_value->isa()) { - MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); - } - auto value_tuple = node_value->cast(); - MS_EXCEPTION_IF_NULL(value_tuple); - if (value_tuple->size() != output_size) { - MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size() - << " is not mathced with the value node's output size" << output_size; - } - for (size_t index = 0; index < value_tuple->value().size(); ++index) { - auto new_value_node = std::make_shared(value_tuple->value()[index]); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)}, - {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get()); - AddValueNodeToGraph(new_value_node); - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - kernel_info->SetFeatureMapFlag(false); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown}); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); - AddValueNodeToGraph(new_value_node); - convert_inputs.emplace_back(new_value_node); - } - if (!RemoveValueNodeFromGraph(value_node)) { - MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); - } - return convert_inputs; -} - -ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - auto new_value_node = MakeValueNode(value_node)->cast(); - AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); - return new_value_node; -} - -const std::vector &KernelGraph::inputs() const { - MS_EXCEPTION_IF_NULL(inputs_); - return *inputs_; -} - -void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) { - MS_EXCEPTION_IF_NULL(front_anf); - MS_EXCEPTION_IF_NULL(backend_anf); - if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) { - MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_"; - } - if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) { - MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_"; - } - front_backend_anf_map_[front_anf] = backend_anf; - backend_front_anf_map_[backend_anf] = front_anf; -} - -void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) { - MS_EXCEPTION_IF_NULL(old_backend_anf); - MS_EXCEPTION_IF_NULL(new_backend_anf); - if (old_backend_anf == new_backend_anf) { - MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString(); - return; - } - if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) { - MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map"; - return; - } - if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) { - MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString(); - } - front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf; - backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf]; - // delete old kernel - (void)backend_front_anf_map_.erase(old_backend_anf); -} -// get kernel by anf -AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) { - if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) { - return nullptr; - } - return front_backend_anf_map_[front_anf]; -} - -bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) { - return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end(); -} - -ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) { - if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) { - return nullptr; - } - return tensor_to_value_node_map_[tensor]; -} - -void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(tensor); - MS_EXCEPTION_IF_NULL(value_node); - tensor_to_value_node_map_[tensor] = value_node; -} - -void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(input); - MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num; - auto output_depend_edge = std::pair(node, depend_edge_num); - // add output depend edge of input - auto output_it = node_output_edges_.find(input); - if (output_it == node_output_edges_.end()) { - node_output_edges_[input] = std::vector>{output_depend_edge}; - } else { - output_it->second.push_back(output_depend_edge); - } - // add input depend edge of output - auto input_depend_edge = std::pair(input, depend_edge_num); - auto input_it = node_input_edges_.find(node); - if (input_it == node_input_edges_.end()) { - node_input_edges_[node] = std::vector>{input_depend_edge}; - } else { - input_it->second.push_back(input_depend_edge); - } - // add node input depend num - auto depend_it = node_input_num_.find(node); - if (depend_it == node_input_num_.end()) { - node_input_num_[node] = depend_edge_num; - } else { - depend_it->second += depend_edge_num; - } -} - -std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto it = node_output_edges_.find(node); - if (it == node_output_edges_.end()) { - MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]"; - } - std::vector output_nodes; - auto trans = [](const std::pair &pair) -> AnfNodePtr { return pair.first; }; - (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans); - return output_nodes; -} - -// Find control_depend real input nodes. -void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector *result, std::set *visited) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(result); - MS_EXCEPTION_IF_NULL(visited); - if (visited->find(anf_node) != visited->end()) { - MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited"; - return; - } - visited->insert(anf_node); - if (AnfAlgo::IsRealKernel(anf_node)) { - result->emplace_back(anf_node); - return; - } - if (!anf_node->isa()) { - return; - } - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().empty()) { - MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString(); - } - auto input0 = cnode->input(0); - if (IsPrimitive(input0, prim::kPrimMakeTuple)) { - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - GetAllFatherRealNode(cnode->input(i), result, visited); - } - } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - if (cnode->inputs().size() != kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; - } - GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited); - } else if (IsPrimitive(input0, prim::kPrimDepend)) { - if (cnode->inputs().size() != kDependInputSize) { - MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!"; - } - GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited); - GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited); - } -} - -// update the depend relations of control depend -void KernelGraph::UpdateControlDependRelations(const std::vector &depends) { - for (const auto &node : depends) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { - MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend"; - } - auto prior_node = cnode->input(kControlDependPriorIndex); - auto depend_node = cnode->input(kControlDependBehindIndex); - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(depend_node); - std::vector prior_nodes = {prior_node}; - std::vector depend_nodes = {depend_node}; - int depend_mode = 0; - if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { - depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); - } - MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() - << "], depend_mode :" << depend_mode << "."; - if (prior_node->isa() && depend_mode == 1) { - prior_nodes = GetOutputNodes(prior_node); - } - if (depend_node->isa()) { - depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector{}; - } - - std::vector real_prior_nodes; - std::set prior_visited; - for (const auto &tmp : prior_nodes) { - GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); - } - - std::vector real_depend_nodes; - std::set depend_visited; - for (const auto &tmp : depend_nodes) { - GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); - } - - for (auto &first_node : real_prior_nodes) { - if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { - continue; - } - for (auto &second_node : real_depend_nodes) { - if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { - continue; - } - MS_EXCEPTION_IF_NULL(first_node); - MS_EXCEPTION_IF_NULL(second_node); - MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); - AddDependEdge(second_node, first_node, 1); - } - } - } -} - -bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *que, - std::unordered_set *visited_nodes) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(que); - MS_EXCEPTION_IF_NULL(visited_nodes); - if (!node->isa()) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { - return false; - } - // set the control depend visited but don't push it into the que - if (visited_nodes->find(node) != visited_nodes->end()) { - return true; - } - (void)visited_nodes->insert(cnode); - // add a 0 depend num to keep the link relations to prepare for finding zero output nodes - auto prior_node = cnode->input(kControlDependPriorIndex); - auto depend_node = cnode->input(kControlDependBehindIndex); - for (const auto &input : cnode->inputs()) { - AddDependEdge(node, input, 0); - } - PushNoVisitedNode(depend_node, que, visited_nodes); - PushNoVisitedNode(prior_node, que, visited_nodes); - return true; -} - -void KernelGraph::UpdateNodeEdgeList(std::queue *seed_nodes) { - MS_EXCEPTION_IF_NULL(seed_nodes); - node_output_edges_.clear(); - node_input_num_.clear(); - node_input_edges_.clear(); - std::vector control_depends; - std::unordered_set visited_nodes; - std::queue que; - que.push(get_return()); - while (!que.empty()) { - auto node = que.front(); - que.pop(); - MS_EXCEPTION_IF_NULL(node); - if (node->isa() || node->isa()) { - seed_nodes->push(node); - continue; - } - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // handle data links - for (const auto &input : cnode->inputs()) { - size_t depend_edge_num = 1; - // handle control depend,all inputs of control depend has no depend edge - if (HandleControlDependNode(input, &que, &visited_nodes)) { - control_depends.push_back(input); - depend_edge_num = 0; - } - PushNoVisitedNode(input, &que, &visited_nodes); - AddDependEdge(node, input, depend_edge_num); - } - } - UpdateControlDependRelations(control_depends); -} - -void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); } - -bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; } - -AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const { - if (!IsInRefOutputMap(out_pair)) { - MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap"; - } - return ref_out_in_map_.at(out_pair); -} - -void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) { - if (IsInRefOutputMap(final_pair)) { - MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap"; - } - (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair)); -} - -bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) { - if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) { - (void)graph_value_nodes_.erase(value_node); - return true; - } - return false; -} - -void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull new_anf_node) { - MS_EXCEPTION_IF_NULL(inputs_); - auto it = node_output_edges_.find(old_anf_node); - if (it != node_output_edges_.end()) { - const auto &outputs = it->second; - for (auto &output_node : outputs) { - MS_EXCEPTION_IF_NULL(output_node.first); - auto output_cnode = output_node.first->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - auto &output_node_inputs = output_cnode->inputs(); - // don't replace node if it is a control edge => output_node.second == 0 - if (output_node.second == 0) { - continue; - } - for (size_t i = 1; i < output_node_inputs.size(); i++) { - if (output_node_inputs[i] == old_anf_node.get()) { - output_cnode->set_input(i, new_anf_node); - } - } - // update graph inputs - for (size_t i = 0; i < inputs_->size(); i++) { - if ((*inputs_)[i] == old_anf_node.get()) { - MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() - << ",new graph input:" << new_anf_node->DebugString(); - (*inputs_)[i] = new_anf_node.get(); - break; - } - } - } - // update front to backend map - FrontBackendlMapUpdate(old_anf_node, new_anf_node); - } - // if change the ir of graph, regenerate execution order of graph - SetExecOrderByDefault(); - // update graph inputs in child graph - auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(), - [&old_anf_node](const std::pair> &n) -> bool { - return n.first == old_anf_node.get(); - }); - if (it_real_inputs != real_inputs_.end()) { - // erase old parameter in map - auto old_args = it_real_inputs->second; - real_inputs_.erase(it_real_inputs); - // insert new parameter to map - auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(), - [&new_anf_node](const std::pair> &n) -> bool { - return n.first == new_anf_node.get(); - }); - if (iter != real_inputs_.end()) { - MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited."; - iter->second = old_args; - } else { - real_inputs_.emplace_back(new_anf_node, old_args); - } - } -} - -void KernelGraph::UpdateExecuteKernelStreamLabel() { - for (auto &kernel : execution_order_) { - AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get()); - } -} - -std::vector> KernelGraph::GetLeafGraphOrder() { - std::vector> leaf_graph_order; - if (IsLeafGraph()) { - leaf_graph_order.push_back(shared_from_this()->cast()); - } else { - for (const auto &child_graph : child_graph_order_) { - MS_EXCEPTION_IF_NULL(child_graph); - auto child_leaf_graph_order = child_graph->GetLeafGraphOrder(); - std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order)); - } - } - return leaf_graph_order; -} - -bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); } - -std::vector KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const { - std::vector result; - for (const auto &anf : execution_order_) { - if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { - result.push_back(anf->cast()); - } - } - return result; -} - -void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) { - MS_EXCEPTION_IF_NULL(parameter); - MS_EXCEPTION_IF_NULL(arg); - MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString(); - MS_EXCEPTION_IF_NULL(parameter); - MS_EXCEPTION_IF_NULL(arg); - auto iter = std::find_if( - real_inputs_.begin(), real_inputs_.end(), - [¶meter](const std::pair> &n) -> bool { return n.first == parameter; }); - if (iter != real_inputs_.end()) { - auto &args = iter->second; - args.push_back(arg); - } else { - real_inputs_.emplace_back(parameter, std::vector(1, arg)); - } -} - -void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph) { - unreuse_args_[arg] = from_graph; -} - -void KernelGraph::UpdateCallRealInput() { - MS_LOG(INFO) << "Update graph id: " << graph_id_; - std::vector>> real_inputs_map; - for (auto &it : real_inputs_) { - auto parameter = it.first; - MS_EXCEPTION_IF_NULL(parameter); - auto real_inputs = it.second; - std::vector new_real_inputs; - for (auto &real_input : real_inputs) { - // if real input is a call node ,find the child graph output act as the new real input - auto tmp_real_input = GetCallRealOutputs(real_input); - std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs)); - // replace the call in unreuse_args_ - auto unreuse_arg_it = unreuse_args_.find(real_input); - if (unreuse_arg_it != unreuse_args_.end()) { - auto old_graph = unreuse_arg_it->second; - for (auto new_real_input : new_real_inputs) { - // if call reference graph output is parameter, it will be allowed to reuse - if (!new_real_input->isa()) { - unreuse_args_[new_real_input] = old_graph; - } - } - } - } - real_inputs_map.emplace_back(parameter, new_real_inputs); - } - real_inputs_ = real_inputs_map; -} - -void KernelGraph::PrintGraphExecuteOrder() const { - MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; - for (size_t i = 0; i < execution_order_.size(); i++) { - CNodePtr cur_cnode_ptr = execution_order_[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - std::string event_str; - std::string label_str; - if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) { - event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId)) + "]"; - } - - if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) { - label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrLabelIndex)) + "]"; - } - - if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) { - auto label_list = AnfAlgo::GetNodeAttr>(cur_cnode_ptr, kAttrLabelSwitchList); - label_str = ", label_id["; - for (size_t j = 0; j < label_list.size(); ++j) { - label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]"); - } - } - - MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id[" - << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" - << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]" - << event_str << label_str; - } -} - -void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) { - if (front_node == nullptr || node == nullptr) { - MS_LOG(INFO) << "Front node or node is nullptr"; - return; - } - MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString(); - front_to_internal_outputs_map_[front_node] = node; - internal_outputs_to_front_map_[node] = front_node; -} - -void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) { - if (new_node == nullptr || node == nullptr) { - MS_LOG(INFO) << "New node or node is nullptr"; - return; - } - if (node == new_node) { - MS_LOG(INFO) << "New node and node is the same"; - return; - } - auto iter = internal_outputs_to_front_map_.find(node); - if (iter == internal_outputs_to_front_map_.end()) { - MS_LOG(INFO) << "Node is not internal output"; - return; - } - MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString(); - internal_outputs_to_front_map_[new_node] = iter->second; - front_to_internal_outputs_map_[iter->second] = new_node; - internal_outputs_to_front_map_.erase(iter); -} - -AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const { - auto iter = front_to_internal_outputs_map_.find(front_node); - if (iter != front_to_internal_outputs_map_.end()) { - return iter->second; - } - return nullptr; -} - -bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const { - if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) { - return true; - } - return false; -} - -AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const { - auto iter = internal_outputs_to_front_map_.find(node); - if (iter != internal_outputs_to_front_map_.end()) { - return iter->second; - } - return nullptr; -} - -void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) { - if (node == nullptr) { - return; - } - (void)final_output_kernels_.insert(node); -} - -bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { - if (node == nullptr) { - return false; - } - if (final_output_kernels_.find(node) != final_output_kernels_.end()) { - return true; - } - return false; -} - -std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } - -KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); } -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h deleted file mode 100644 index 6861d43de0..0000000000 --- a/mindspore/ccsrc/session/kernel_graph.h +++ /dev/null @@ -1,223 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H -#define MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "utils/graph_utils.h" -#include "utils/contract.h" -#include "device/kernel_info.h" - -namespace mindspore { -namespace session { -using AnfWithOutIndex = std::pair; -class KernelGraph : public FuncGraph { - public: - KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false) { - inputs_ = std::make_shared>(); - execution_order_ = {}; - executable_ = true; - summary_node_exist_ = false; - stream_distinction_label_ = kInvalidDistincLabel; - } - ~KernelGraph() override; - - MS_DECLARE_PARENT(KernelGraph, FuncGraph); - - const std::vector &inputs() const; - std::vector *MutableInputs() const { return inputs_.get(); } - std::vector outputs() const; - CNodePtr NewCNode(const std::vector &inputs) override; - void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); - CNodePtr NewCNode(const CNodePtr &cnode); - ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); - ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); - std::vector SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); - void set_execution_order(const std::vector &order) { execution_order_ = order; } - const std::vector &execution_order() const { return execution_order_; } - void SetExecOrderByDefault(); - uint32_t graph_id() const { return graph_id_; } - void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } - - // and a new front to backend anf relation to maop - void FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf); - // replace old backend anf with new backend anf - void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf); - // get backend anf by front anf - AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf); - // check backend node whether exist in map - bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf); - // get value node by tensor - ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor); - // add value node tensor relation map - void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node); - // get all value nodes of graph - const std::unordered_set graph_value_nodes() const { return graph_value_nodes_; } - // add value node to graph - void AddValueNodeToGraph(const ValueNodePtr &value_node); - // ref output is in map - bool IsInRefOutputMap(const AnfWithOutIndex &pair) const; - // get ref correspond pairs - AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const; - // add ref correspond pairs - void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair); - // get map - std::map GetRefMap() const { return ref_out_in_map_; } - // checkout whether loop exist in graph - void CheckLoop(); - // check whether graph is executable - bool executable() const { return executable_; } - // set executable of graph - void set_executable(bool executable) { executable_ = executable; } - // set summary_node of graph - void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; } - // check whether exist summary node in graph - bool summary_node_exist() const { return summary_node_exist_; } - // set invalid inputs for control sink - std::vector *MutableValidInputs() { return &valid_inputs_; } - std::vector valid_inputs() const { return valid_inputs_; } - // replace node in graph - void ReplaceNode(NotNull old_anf_node, NotNull new_anf_node); - // set stream label of graph - void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; } - // get stream label of graph - uint32_t stream_distinction_label() { return stream_distinction_label_; } - // refresh execute kernel stream label - void UpdateExecuteKernelStreamLabel(); - // calculate the leaf graph order of root graph - std::vector> GetLeafGraphOrder(); - // the child graph of current graph - const std::vector> &child_graph_order() const { return child_graph_order_; } - void set_child_graph_order(const std::vector> &order) { child_graph_order_ = order; } - // checkout whether current graph is leaf graph - bool IsLeafGraph() const; - - // set input_tensors pointer of control parameter - void set_input_ctrl_tensors(const std::shared_ptr> &input_tensors_ptr) { - input_ctrl_tensors_ = input_tensors_ptr; - } - // get input_tensors pointer of control parameter - std::shared_ptr> input_ctrl_tensors() const { return input_ctrl_tensors_; } - // get parent kernel graph - std::shared_ptr parent_graph() const { return parent_graph_; } - // set parent kernel graph - void set_parent_graph(const std::shared_ptr &parent_graph) { parent_graph_ = parent_graph; } - // find anf node in graph - std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; - // get real inputs - const std::vector>> &real_inputs() const { return real_inputs_; } - void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); - // mark unreused args - void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph); - const std::map> &unreuse_args() const { return unreuse_args_; } - // used to dump ir - std::string ToString() const override; - // update the real input if the node is a call - void UpdateCallRealInput(); - - void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } - CNodePtr get_start_label() { return start_label_; } - void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } - CNodePtr get_end_goto() { return end_goto_; } - bool get_output_null() { return null_output_; } - void set_output_null(bool is_output_null) { null_output_ = is_output_null; } - void PrintGraphExecuteOrder() const; - const std::map> &summary_nodes() const { return summary_nodes_; } - void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; } - void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node); - void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node); - AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; - bool IsInternalOutput(const AnfNodePtr &node) const; - AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const; - void AddFinalOutputKernel(const AnfNodePtr &node); - bool IsFinalOutputKernel(const AnfNodePtr &node) const; - - private: - // remove value node form graph - bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); - void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, - std::unordered_set *visited_nodes); - // update node edge list - void UpdateNodeEdgeList(std::queue *seed_nodes); - // add node depend edge by data edge or control depend - void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); - // handle control depend - std::vector GetOutputNodes(const AnfNodePtr &node); - bool HandleControlDependNode(const AnfNodePtr &node, std::queue *que, - std::unordered_set *visited_nodes); - void UpdateControlDependRelations(const std::vector &depends); - - std::shared_ptr> inputs_; - std::vector execution_order_; - uint32_t graph_id_; - uint32_t stream_distinction_label_; - - // record map bettween front anf and backend anf,use two map implement bidirectional map - std::unordered_map front_backend_anf_map_; - std::unordered_map backend_front_anf_map_; - // there may be a tensor from ME backend ,a value ndoe will be create according the tensor,map record - std::unordered_map tensor_to_value_node_map_; - // include all value nodes - std::unordered_set graph_value_nodes_; - std::unordered_map node_input_num_; - std::unordered_map>> node_input_edges_; - // record map between ref final output anf with index and ref origin input with index - std::map ref_out_in_map_; - std::unordered_map>> node_output_edges_; - std::map> summary_nodes_; - // graph needn't execute - bool executable_; - // exist summary node in graph - bool summary_node_exist_; - // valid inputs - std::vector valid_inputs_; - - // new members for control sink process - // all child grahs refers to partial node - std::map> node_to_child_graphs_; - // child graph execute order in root graph - std::vector> child_graph_order_; - - // input_tensors of control parameter - std::shared_ptr> input_ctrl_tensors_; - - // parameter graph - std::shared_ptr parent_graph_; - // record real parameters,inputs_ is the formal parameters - std::vector>> real_inputs_; - std::map> unreuse_args_; - - CNodePtr start_label_; - CNodePtr end_goto_; - bool null_output_; - std::unordered_map front_to_internal_outputs_map_; - std::unordered_map internal_outputs_to_front_map_; - std::set final_output_kernels_; -}; -} // namespace session -using KernelGraphPtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H diff --git a/mindspore/ccsrc/session/session.cc b/mindspore/ccsrc/session/session.cc deleted file mode 100644 index ae70fc77aa..0000000000 --- a/mindspore/ccsrc/session/session.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "include/inference.h" -#include "session/session.h" -#include "utils/load_onnx/anf_converter.h" -#include "session/session_basic.h" -#include "session/session_factory.h" -#include "utils/base_ref_utils.h" -#include "kernel/oplib/oplib.h" -#ifdef ENABLE_D -#include "utils/context/ms_context.h" -#include "session/ascend_session.h" -#else -#include "session/cpu_session.h" -#endif - -namespace py = pybind11; -namespace mindspore::inference { -std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device) { - try { - inference::Session::RegAllOp(); - auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); - return anf_graph; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference LoadModel failed"; - return nullptr; - } -} - -void ExitInference() { - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return; - } - if (!ms_context->CloseTsd()) { - MS_LOG(ERROR) << "Inference CloseTsd failed!"; - return; - } -} - -std::shared_ptr MSSession::CreateSession(const std::string &device, uint32_t device_id) { - try { - auto session = std::make_shared(); - auto ret = session->Init(device, device_id); - if (ret != 0) { - return nullptr; - } - return session; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference CreatSession failed"; - return nullptr; - } -} - -void Session::RegAllOp() { - static std::mutex init_mutex; - static bool Initialized = false; - - std::lock_guard lock(init_mutex); - if (Initialized) { - return; - } - Initialized = true; - MsContext::GetInstance()->set_execution_mode(kGraphMode); - Py_Initialize(); - auto c_expression = PyImport_ImportModule("mindspore._c_expression"); - if (c_expression == nullptr) { - MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module."; - return; - } - PyObject *c_expression_dict = PyModule_GetDict(c_expression); - - PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); - if (op_info_loader_class == nullptr) { - MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression."; - return; - } - PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); - if (op_info_loader == nullptr) { - MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance."; - return; - } - PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); - if (op_info_loader_ins == nullptr) { - MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance."; - return; - } - auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); - if (all_ops_info_vector_addr_ul == nullptr) { - MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr."; - return; - } - auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); - auto all_ops_info = static_cast *>(all_ops_info_vector_addr); - for (auto op_info : *all_ops_info) { - kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); - } - all_ops_info->clear(); - delete all_ops_info; - Py_DECREF(op_info_loader); - Py_DECREF(op_info_loader_class); - Py_DECREF(c_expression_dict); - Py_DECREF(c_expression); - return; -} - -uint32_t Session::CompileGraph(std::shared_ptr funcGraphPtr) { - MS_ASSERT(session_impl_ != nullptr); - try { - auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); - py::gil_scoped_release gil_release; - return graph_id; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference CompileGraph failed"; - return static_cast(-1); - } -} - -MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector> &inputs) { - try { - std::vector inTensors; - inTensors.resize(inputs.size()); - bool has_error = false; - std::transform(inputs.begin(), inputs.end(), inTensors.begin(), - [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { - if (tensor_ptr == nullptr) { - MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; - has_error = true; - return nullptr; - } - auto tensor = static_cast(tensor_ptr.get()); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; - has_error = true; - return nullptr; - } - return tensor->tensor(); - }); - if (has_error) { - MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; - std::vector> multiTensor; - return multiTensor; - } - VectorRef outputs; - session_impl_->RunGraph(graph_id, inTensors, &outputs); - - return TransformVectorRefToMultiTensor(outputs); - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference Rungraph failed"; - return MultiTensor(); - } -} -namespace { -string AjustTargetName(const std::string &device) { - if (device == kAscendDevice) { - return std::string(kAscendDevice) + "Inference"; - } else { - MS_LOG(ERROR) << "Only support device Ascend right now"; - return ""; - } -} -} // namespace -int Session::Init(const std::string &device, uint32_t device_id) { - RegAllOp(); - auto ms_context = MsContext::GetInstance(); - ms_context->set_execution_mode(kGraphMode); - ms_context->set_device_id(device_id); - auto ajust_device = AjustTargetName(device); - if (ajust_device == "") { - return -1; - } - ms_context->set_device_target(device); - session_impl_ = session::SessionFactory::Get().Create(ajust_device); - if (session_impl_ == nullptr) { - MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; - return -1; - } - session_impl_->Init(device_id); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return -1; - } - if (!ms_context->OpenTsd()) { - MS_LOG(ERROR) << "Session init OpenTsd failed!"; - return -1; - } - return 0; -} - -Session::Session() = default; -} // namespace mindspore::inference diff --git a/mindspore/ccsrc/session/session.h b/mindspore/ccsrc/session/session.h deleted file mode 100644 index b608163067..0000000000 --- a/mindspore/ccsrc/session/session.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H -#define MINDSPORE_CCSRC_SESSION_SESSION_H - -#include -#include -#include -#include -#include -#include - -#include "session/session_basic.h" -#include "ir/anf.h" -#include "include/inference.h" - -namespace mindspore { -namespace inference { -class Session : public MSSession { - public: - Session(); - - uint32_t CompileGraph(std::shared_ptr funcGraphPtr) override; - - MultiTensor RunGraph(uint32_t graph_id, const std::vector> &inputs) override; - - int Init(const std::string &device, uint32_t device_id); - - static void RegAllOp(); - - private: - std::shared_ptr session_impl_ = nullptr; - std::vector graph_id_; -}; -} // namespace inference -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc deleted file mode 100644 index 91e430182c..0000000000 --- a/mindspore/ccsrc/session/session_basic.cc +++ /dev/null @@ -1,1073 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/session_basic.h" -#include -#include -#include -#include -#include "pipeline/parse/data_converter.h" -#include "ir/manager.h" -#include "ir/param_value_py.h" -#include "kernel/common_utils.h" -#include "operator/ops.h" -#include "common/trans.h" -#include "utils/context/ms_context.h" -#include "utils/config_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" -#include "pre_activate/common/common_backend_optimization.h" -#include "pre_activate/pass/const_input_to_attr_registry.h" -#include "pre_activate/common/helper.h" -#include "common/utils.h" -#include "ir/dtype.h" -#include "ir/anf.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace session { -static std::shared_ptr> python_paras_; -void ClearPythonParasMap() { python_paras_ = nullptr; } -namespace { -const int kSummaryGetItem = 2; - -PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) { - if (node == nullptr) { - return nullptr; - } - auto parameter = node->cast(); - if (parameter == nullptr || !parameter->has_default()) { - return nullptr; - } - auto param_value = std::dynamic_pointer_cast(parameter->default_param()); - MS_EXCEPTION_IF_NULL(param_value); - auto py_param = param_value->value(); - return py_param.ptr(); -} - -BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph, - const std::vector &input_tensors) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; - // if node is a value node, no need sync addr from device to host - if (!AnfAlgo::OutputAddrExist(node, output_index)) { - if (node->isa()) { - auto value_node = node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - return value_node->value(); - } - if (node->isa()) { - for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) { - if (input_idx >= input_tensors.size()) { - MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size(); - } - if (graph.inputs()[input_idx] == node) { - return input_tensors[input_idx]; - } - } - MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr"; - } - } - // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) - auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); - MS_EXCEPTION_IF_NULL(address); - auto shape = AnfAlgo::GetOutputInferShape(node, output_index); - TypeId type_id = kNumberTypeFloat32; - type_id = AnfAlgo::GetOutputInferDataType(node, output_index); - std::vector temp_shape; - if (graph.IsInternalOutput(node)) { - temp_shape.emplace_back(1); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); - tensor->set_device_address(address); - tensor->set_dirty(false); - return tensor; - } - (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); - // if in paynative mode,data only copyed to host when user want to print data - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { - tensor->set_device_address(address); - tensor->set_dirty(false); - } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), - LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { - MS_LOG(INFO) << "Output sync device to host error!!!"; - tensor->set_dirty(false); - } - return tensor; -} - -BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, - const std::vector &input_tensors) { - MS_EXCEPTION_IF_NULL(anf); - MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]"; - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); - MS_EXCEPTION_IF_NULL(item_with_index.first); - MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString(); - // special handle for maketuple - if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { - auto cnode = item_with_index.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - VectorRef ret; - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors); - ret.push_back(out); - } - return ret; - } - // if is graph return nothing ,the function should return a null anylist - size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first); - if (size == 0) { - return VectorRef(); - } - return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors); -} - -BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph, - const std::vector &input_tensors) { - MS_EXCEPTION_IF_NULL(anf); - if (!AnfAlgo::IsRealKernel(anf)) { - MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel"; - } - if (anf->isa()) { - return CreateOneTensor(anf, 0, graph, input_tensors); - } - VectorRef ret; - if (anf->isa() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) { - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) { - auto out = CreateOneTensor(anf, i, graph, input_tensors); - ret.emplace_back(out); - } - } - return ret; -} - -ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - auto value_node = anf->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - return nullptr; - } - auto new_value_node = graph->NewValueNode(value_node); - graph->FrontBackendlMapAdd(anf, new_value_node); - graph->AddValueNodeToGraph(new_value_node); - return new_value_node; -} - -size_t LoadCtrlInputTensor(const std::shared_ptr &graph, std::vector *inputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Load kInputCtrlTensors"; - auto inputs_params = graph->input_ctrl_tensors(); - if (inputs_params == nullptr) { - return 0; - } - if (inputs_params->empty()) { - MS_LOG(EXCEPTION) << "Illegal empty inputs_params"; - } - auto tensor = (*inputs_params)[0]; - MS_EXCEPTION_IF_NULL(tensor); - auto *val = static_cast(tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = 0; - tensor->set_dirty(true); - // set loop_count to zero - MS_EXCEPTION_IF_NULL(inputs); - inputs->push_back(tensor); - return inputs_params->size(); -} - -ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr &graph, const tensor::TensorPtr &input_tensor) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(input_tensor); - auto value_node = std::make_shared(input_tensor); - MS_EXCEPTION_IF_NULL(value_node); - // construct abstract of value node - auto type_of_tensor = input_tensor->Dtype(); - auto shape_of_tensor = input_tensor->shape(); - auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); - value_node->set_abstract(abstract); - // add value node to graph - auto input_value_node = graph->NewValueNode(value_node); - graph->AddValueNodeToGraph(input_value_node); - return input_value_node; -} - -ParameterPtr ConstructRunOpParameter(const std::shared_ptr &graph, const tensor::TensorPtr &input_tensor, - int tensor_mask) { - MS_EXCEPTION_IF_NULL(graph); - auto param = graph->NewParameter(); - MS_EXCEPTION_IF_NULL(param); - if (tensor_mask == kParameterWeightTensorMask) { - py::object obj; - auto param_value_new = std::make_shared(obj); - param->set_default_param(param_value_new); - } - // set the kernel info of parameter - auto kernel_build_info_builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(input_tensor); - if (input_tensor->device_address().get() == nullptr) { - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type(); - kernel_build_info_builder->SetOutputsDeviceType(std::vector{param_init_data_type}); - } else { - kernel_build_info_builder->SetOutputsFormat(std::vector{input_tensor->device_address()->format()}); - kernel_build_info_builder->SetOutputsDeviceType(std::vector{input_tensor->device_address()->type_id()}); - } - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); - // construct abstract of parameter - auto type_of_tensor = input_tensor->Dtype(); - auto shape_of_tensor = input_tensor->shape(); - auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); - param->set_abstract(abstract); - return param; -} - -void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { - MS_LOG(INFO) << "Graph outputs:"; - const size_t max_deep = 10; - if (recurse_level > max_deep) { - MS_LOG(INFO) << "Recurse too deep"; - return; - } - std::string tab_str; - for (size_t i = 0; i < recurse_level; i++) { - tab_str = tab_str.append(" "); - } - if (any.is()) { - (void)tab_str.append("{"); - MS_LOG(INFO) << tab_str; - auto any_list = any.cast(); - for (auto &it : any_list) { - DumpGraphOutput(it, recurse_level + 1); - } - (void)tab_str.append("}"); - MS_LOG(INFO) << tab_str; - } - (void)tab_str.append(any.ToString()); - MS_LOG(INFO) << tab_str; -} - -bool ExistSummaryNode(const KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto ret = graph->get_return(); - MS_EXCEPTION_IF_NULL(ret); - auto all_nodes = DeepLinkedGraphSearch(ret); - for (auto &n : all_nodes) { - if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || - IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { - return true; - } - } - return false; -} -} // namespace - -GraphId SessionBasic::graph_sum_ = 0; - -KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) { - auto it = graphs_.find(graph_id); - if (it == graphs_.end()) { - MS_LOG(WARNING) << "Can't find graph " << graph_id; - return nullptr; - } - return it->second; -} - -void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) { - auto graph_id = GetGraphIdByNode(out_node); - if (graph_id == kInvalidGraphId) { - return; - } - auto node_graph = GetGraph(graph_id); - if (node_graph == nullptr) { - return; - } - MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString(); - auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node); - if (ref_node == nullptr) { - MS_LOG(INFO) << "No corresponding internal output for output node"; - return; - } - auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0); - auto ref_real_node = real_kernel.first; - auto ref_real_node_index = real_kernel.second; - if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node) && - node_graph->IsFinalOutputKernel(ref_real_node)) { - auto kernel_info = ref_real_node->kernel_info(); - if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { - MS_LOG(INFO) << "No kernel info"; - return; - } - auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index); - if (address == nullptr) { - MS_LOG(INFO) << "No kernel address"; - return; - } - auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index); - auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); - parameter->set_kernel_info(std::make_shared()); - auto d_kernel_info = parameter->kernel_info(); - MS_EXCEPTION_IF_NULL(d_kernel_info); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetOutputsDeviceType({type}); - builder.SetOutputsFormat({format}); - d_kernel_info->set_select_kernel_build_info(builder.Build()); - AnfAlgo::SetOutputAddr(address, 0, parameter.get()); - } -} - -std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, - KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(graph); - std::vector parameters; - std::vector pre_graph_out = {node}; - // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive - if (!AnfAlgo::IsRealKernel(node)) { - pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); - } - auto valid_inputs = graph->MutableValidInputs(); - MS_EXCEPTION_IF_NULL(valid_inputs); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { - auto parameter = graph->NewParameter(); - MS_EXCEPTION_IF_NULL(parameter); - parameter->set_abstract(abstract); - auto new_parameter = graph->NewParameter(parameter); - parameters.push_back(new_parameter); - valid_inputs->push_back(valid_input); - graph_inputs->push_back(new_parameter); - }; - for (const auto &out_node : pre_graph_out) { - MS_EXCEPTION_IF_NULL(out_node); - auto abstract = out_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; - for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { - create_parameter((*tuple_abstract)[output_idx]); - } - continue; - } - // create single parameter if is a abstract real kernel - create_parameter(out_node->abstract()); - InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]); - } - return parameters; -} - -ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, - KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - if (!anf->isa()) { - MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; - } - MS_EXCEPTION_IF_NULL(graph); - auto m_tensor = GetParamDefaultInputTensor(anf); - auto valid_inputs = graph->MutableValidInputs(); - MS_EXCEPTION_IF_NULL(valid_inputs); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - ParameterPtr new_parameter = nullptr; - // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter - if (python_paras_ == nullptr) { - python_paras_ = std::make_shared>(); - } - auto iter = python_paras_->find(m_tensor); - if (iter != python_paras_->end()) { - new_parameter = iter->second; - } else { - TraceManager::DebugTrace(std::make_shared(anf->debug_info())); - new_parameter = graph->NewParameter(anf->cast()); - if (m_tensor != nullptr) { - (*python_paras_)[m_tensor] = new_parameter; - } - TraceManager::EndTrace(); - } - graph_inputs->push_back(new_parameter); - valid_inputs->push_back(valid_input); - return new_parameter; -} - -AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; - auto parameters = CreateParameterFromTuple(anf, valid_input, graph); - if (parameters.empty()) { - MS_LOG(EXCEPTION) << "No parameter exist!!"; - } - if (parameters.size() == 1) { - return parameters[0]; - } - std::vector make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; - (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); - auto make_tuple = graph->NewCNode(make_tuple_input); - MS_EXCEPTION_IF_NULL(make_tuple); - MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; - return make_tuple; -} - -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, - bool *from_other_graph, - std::unordered_map *other_graph_cnode) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(from_other_graph); - MS_EXCEPTION_IF_NULL(other_graph_cnode); - *from_other_graph = false; - // get primitive of old node - std::vector cnode_inputs; - auto prim = AnfAlgo::GetCNodePrimitive(cnode); - if (prim != nullptr) { - // push attr to inputs[0] of new cnode - cnode_inputs.push_back(std::make_shared(std::make_shared(*prim))); - } else { - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - MS_EXCEPTION_IF_NULL(fg); - auto new_fg = BasicClone(fg); - cnode_inputs.push_back(std::make_shared(new_fg)); - } - auto origin_inputs = cnode->inputs(); - bool optimize_depend = false; - if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && - origin_inputs[kRealInputIndexInDepend]->isa()) { - optimize_depend = true; - } - // if has multiple depends,only select first depend as parameter - for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { - auto anf = origin_inputs[input_idx]; - MS_EXCEPTION_IF_NULL(anf); - // anf has been created before - if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); - continue; - } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { - cnode_inputs.push_back((*other_graph_cnode)[anf]); - continue; - } else if (anf->isa() && !IsValueNode(anf)) { - // if input is a value node, - auto new_value_node = CreateNewValueNode(anf, graph); - if (new_value_node != nullptr) { - cnode_inputs.emplace_back(new_value_node); - } - continue; - } else if (anf->isa() && AnfAlgo::GetOutputTensorNum(anf) == 1) { - auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); - cnode_inputs.push_back(new_parameter); - if (GetGraphIdByNode(anf) == kInvalidGraphId) { - graph->FrontBackendlMapAdd(anf, new_parameter); - } else { - (*other_graph_cnode)[anf] = new_parameter; - } - continue; - } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { - cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); - continue; - } else if (anf->isa()) { - *from_other_graph = true; - // the input node is a cnode from other graph - auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); - cnode_inputs.push_back(parameter_from_cnode); - (*other_graph_cnode)[anf] = parameter_from_cnode; - continue; - } - MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; - } - TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); - auto new_cnode = graph->NewCNode(cnode_inputs); - TraceManager::EndTrace(); - return new_cnode; -} - -static std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(graph); - // create primitive of cnode:call(partial or switch) - std::vector cnode_inputs = { - graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; - auto attr_input = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(attr_input); - auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); - if (cnode_input == nullptr) { - MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() - << ", but input[0] has not been created."; - } - // if the node is partial, insert the inputs of partial to the call - if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) { - auto partial_node = attr_input->cast(); - MS_EXCEPTION_IF_NULL(partial_node); - auto partial_inputs = partial_node->inputs(); - std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(), - std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node)); - return graph->GetBackendAnfByFrontAnf(node); - }); - return cnode_inputs; - } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { - cnode_inputs.emplace_back(cnode_input); - return cnode_inputs; - } - MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; -} - -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(graph); - std::vector cnode_inputs; - auto attr_input = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(attr_input); - if (AnfAlgo::IsGraphKernel(cnode)) { - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - MS_EXCEPTION_IF_NULL(fg); - auto new_fg = BasicClone(fg); - cnode_inputs.push_back(std::make_shared(new_fg)); - } else if (IsValueNode(attr_input)) { - // create primitive of cnode:call - cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; - // create a ValueNode as input of cnode:call - if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)); - } else { - auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph); - if (new_value_node != nullptr) { - cnode_inputs.emplace_back(new_value_node); - } - } - } else if (attr_input->isa()) { - cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); - } else { - // get primitive of old node - auto prim = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(prim); - // push attr to inputs[0] of new cnode - cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(*prim)))}; - } - - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { - auto anf = cnode->input(input_idx); - MS_EXCEPTION_IF_NULL(anf); - // anf has been created before - if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); - continue; - } else if (IsValueNode(anf)) { - continue; - } - MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; - } - TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); - auto new_cnode = graph->NewCNode(cnode_inputs); - TraceManager::EndTrace(); - return new_cnode; -} - -ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - auto value_node = anf->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); - MS_EXCEPTION_IF_NULL(sub_func_graph); - if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) { - MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph."; - } - auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph]; - - ValueNodePtr new_value_node = std::make_shared(sub_kernel_graph); - new_value_node->set_abstract(value_node->abstract()); - // create new kernel_info of new value_node - auto kernel_info = std::make_shared(); - kernel_info->SetFeatureMapFlag(false); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get()); - - graph->FrontBackendlMapAdd(anf, new_value_node); - - return new_value_node; -} - -ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - if (!anf->isa()) { - MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; - } - - auto m_tensor = GetParamDefaultInputTensor(anf); - ParameterPtr new_parameter = nullptr; - if (python_paras_ == nullptr) { - python_paras_ = std::make_shared>(); - } - auto iter = python_paras_->find(m_tensor); - if (iter != python_paras_->end()) { - new_parameter = iter->second; - } else { - TraceManager::DebugTrace(std::make_shared(anf->debug_info())); - new_parameter = graph->NewParameter(anf->cast()); - if (m_tensor != nullptr) { - (*python_paras_)[m_tensor] = new_parameter; - } - TraceManager::EndTrace(); - } - - return new_parameter; -} - -KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - std::unordered_map other_graph_cnode; - auto graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Create graph: " << graph->graph_id(); - size_t from_other_graph_depend_num = 0; - for (const auto &node : lst) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode"; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // create a new cnode object - bool from_other_graph = false; - // only first depend from other graph can create - bool valid_input = true; - if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { - valid_input = false; - } - auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode); - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) { - from_other_graph_depend_num++; - } - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_scope(cnode->scope()); - // record map relations between anf from ME and new anf node used in backend - graph->FrontBackendlMapAdd(node, new_cnode); - } - // add a make_tuple at the end of graph as output - graph->set_output(ConstructOutput(outputs, graph)); - MS_EXCEPTION_IF_NULL(context_); - FuncGraphManagerPtr manager = MakeManager({graph}); - if (manager) { - manager->AddFuncGraph(graph); - graph->set_manager(manager); - } - graph->SetExecOrderByDefault(); - if (ExistSummaryNode(graph.get())) { - graph->set_summary_node_exist(true); - } - opt::BackendCommonOptimization(graph); - return graph; -} - -void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(graph); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // create a new cnode object - auto new_cnode = CreateNewCNode(cnode, graph.get()); - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); - new_cnode->set_scope(cnode->scope()); - graph->FrontBackendlMapAdd(node, new_cnode); - if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) { - graph->set_return(new_cnode); - } -} -std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph, - std::vector *all_out_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(all_out_graph); - auto node_list = TopoSort(func_graph->get_return()); - auto graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(graph); - front_backend_graph_map_[func_graph] = graph; - MS_LOG(INFO) << "Create graph: " << graph->graph_id(); - - bool is_trace_back = false; - for (const auto &node : node_list) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); - if (node->isa()) { - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - auto new_parameter = CreateNewParameter(node, graph.get()); - graph_inputs->push_back(new_parameter); - graph->FrontBackendlMapAdd(node, new_parameter); - continue; - } else if (node->isa()) { - if (!IsValueNode(node)) { - // if input is a common value node, - (void)CreateNewValueNode(node, graph.get()); - } else { - // if input is a ValueNode - FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node); - if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) { - is_trace_back = true; - } else { - (void)ConstructKernelGraph(child_graph, all_out_graph); - } - (void)CreateValueNodeKernelGraph(node, graph.get()); - } - continue; - } else { - CreateCNodeKernelGraph(node, graph); - } - } - // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null. - graph->set_output_null(is_trace_back); - AddParameterToGraphInputs(func_graph->parameters(), graph.get()); - graph->SetExecOrderByDefault(); - if (ExistSummaryNode(graph.get())) { - graph->set_summary_node_exist(true); - } - all_out_graph->push_back(graph); - return graph; -} - -void SessionBasic::AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - graph_inputs->clear(); - for (auto ¶meter : parameters) { - MS_EXCEPTION_IF_NULL(parameter); - auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); - if (backend_parameter == nullptr) { - // for example "def f(x,y,z) {return x + y}", parameter z in unused - auto new_parameter = CreateNewParameter(parameter, graph); - graph_inputs->push_back(new_parameter); - MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); - continue; - } - MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString(); - graph_inputs->push_back(backend_parameter); - } -} - -// run graph steps -void SessionBasic::LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const { - std::vector inputs(inputs_const); - size_t input_ctrl_size = 1; - MS_EXCEPTION_IF_NULL(kernel_graph); - if (kernel_graph->input_ctrl_tensors()) { - input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); - } - auto input_nodes = kernel_graph->inputs(); - if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) { - MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() - << ", input_ctrl_size:" << input_ctrl_size; - } - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - for (size_t i = 0; i < inputs.size(); ++i) { - auto tensor = inputs[i]; - MS_EXCEPTION_IF_NULL(tensor); - auto input_node = input_nodes[i]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { - auto pk_node = input_node->cast(); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - bool need_sync = false; - if (ms_context->enable_pynative_infer()) { - if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { - need_sync = true; - } - } else { - if (tensor->is_dirty()) { - need_sync = true; - } else if (tensor->device_address() != device_address) { - (void)tensor->data_sync(); - need_sync = true; - } - } - if (need_sync) { - if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { - tensor->set_device_address(device_address); - } - MS_EXCEPTION_IF_NULL(device_address); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - } - } - tensor->set_dirty(false); - } -} - -void SessionBasic::UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, - const std::vector &input_tensors) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(outputs); - if (!kernel_graph->child_graph_order().empty()) { - // use the last child graph output as the root graph output - UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors); - return; - } - auto anf_outputs = kernel_graph->outputs(); - for (auto &item : anf_outputs) { - MS_EXCEPTION_IF_NULL(item); - MS_LOG(INFO) << "Update output[" << item->DebugString() << "]"; - if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) { - outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors)); - continue; - } - outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors)); - } -} - -void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { - MS_EXCEPTION_IF_NULL(callback); - summary_callback_ = callback; -} - -void SessionBasic::Reorder(std::vector *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } - -void SessionBasic::GetSummaryNodes(KernelGraph *graph) { - MS_LOG(DEBUG) << "Update summary Start"; - MS_EXCEPTION_IF_NULL(graph); - if (!graph->summary_node_exist()) { - return; - } - auto summary = graph->summary_nodes(); - auto apply_list = TopoSort(graph->get_return()); - for (auto &n : apply_list) { - MS_EXCEPTION_IF_NULL(n); - if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || - IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { - auto cnode = n->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() <= kSummaryGetItem) { - MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!"; - } - auto node = cnode->input(kSummaryGetItem); - MS_EXCEPTION_IF_NULL(node); - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - MS_EXCEPTION_IF_NULL(item_with_index.first); - if (!AnfAlgo::IsRealKernel(item_with_index.first)) { - MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); - } - summary[n->fullname_with_scope()] = item_with_index; - } - } - graph->set_summary_nodes(summary); - MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); -} - -void SessionBasic::Summary(KernelGraph *graph) { - if (summary_callback_ == nullptr) { - return; - } - MS_EXCEPTION_IF_NULL(graph); - bool exist_summary = graph->summary_node_exist(); - if (!exist_summary) { - return; - } - GetSummaryNodes(graph); - auto summary_outputs = graph->summary_nodes(); - std::map params_list; - // fetch outputs apply kernel in session & run callback functions - for (auto &output_item : summary_outputs) { - auto node = output_item.second.first; - size_t index = IntToSize(output_item.second.second); - auto address = AnfAlgo::GetOutputAddr(node, index); - auto shape = AnfAlgo::GetOutputInferShape(node, index); - TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); - std::vector temp_shape; - (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); - MS_EXCEPTION_IF_NULL(address); - if (!address->GetPtr()) { - continue; - } - if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()), - tensor->data_type(), tensor->data_c())) { - MS_LOG(ERROR) << "Failed to sync output from device to host."; - } - tensor->set_dirty(false); - params_list[output_item.first] = tensor; - } - // call callback function here - summary_callback_(0, params_list); -} - -CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph) { - MS_EXCEPTION_IF_NULL(graph); - std::vector output_args; - for (const auto &output : outputs) { - MS_EXCEPTION_IF_NULL(output); - MS_LOG(INFO) << "Output:" << output->DebugString(); - } - auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { - auto backend_anf = graph->GetBackendAnfByFrontAnf(out); - if (backend_anf != nullptr) { - auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); - auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); - MS_EXCEPTION_IF_NULL(out); - auto out_func_graph = out->func_graph(); - MS_EXCEPTION_IF_NULL(out_func_graph); - auto out_func_graph_manager = out_func_graph->manager(); - if (out_func_graph_manager == nullptr) { - return backend_anf; - } - auto node_users = out_func_graph_manager->node_users(); - auto users = node_users[out]; - bool internal_output = true; - std::string kernel_target = GetCNodeTarget(front_real_kernel.first); - for (auto user : users) { - if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { - internal_output = false; - break; - } - } - if (internal_output) { - MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); - graph->AddInternalOutput(out, backend_real_kernel.first); - } - return backend_anf; - } - MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; - }; - output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args), - [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); }); - return graph->NewCNode(output_args); -} - -void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr &graph) { - MS_LOG(INFO) << "Start!"; - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - MS_EXCEPTION_IF_NULL(graph); - if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) { - for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) { - auto idx = NewValueNode(SizeToInt(output_index)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(output_index); - idx->set_abstract(std::make_shared(imm)); - auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); - std::vector types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)}; - std::vector> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get()); - make_tuple_inputs.push_back(getitem); - } - } else { - make_tuple_inputs.push_back(cnode); - } - // create output - auto g_output = graph->NewCNode(make_tuple_inputs); - graph->set_output(g_output); - // set graph manager,which now is only used to get valuenodes and hardware optimizing - MS_EXCEPTION_IF_NULL(context_); - FuncGraphManagerPtr manager = context_->manager(); - if (manager != nullptr) { - manager->AddFuncGraph(graph); - graph->set_manager(manager); - } - MS_LOG(INFO) << "Finish!"; -} - -std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info, - const std::vector &input_tensors, - const std::vector &tensors_mask) { - auto graph = std::make_shared(); - std::vector inputs; - // set input[0] - PrimitivePtr op_prim = op_run_info.py_primitive; - MS_EXCEPTION_IF_NULL(op_prim); - inputs.push_back(std::make_shared(op_prim)); - // set input parameter - MS_LOG(INFO) << "Input tensor size: " << input_tensors.size(); - if (input_tensors.size() != tensors_mask.size()) { - MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size " - << tensors_mask.size(); - } - for (size_t i = 0; i < input_tensors.size(); ++i) { - if (tensors_mask[i] == kValueNodeTensorMask) { - auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]); - inputs.push_back(value_node); - continue; - } - auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]); - inputs.push_back(parameter); - auto mutable_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(mutable_inputs); - mutable_inputs->push_back(parameter); - } - // set execution order - auto cnode = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(cnode); - // set abstract,which include inferred shapes and types - cnode->set_abstract(op_run_info.abstract); - // set execution order - std::vector exe_order = {cnode}; - graph->set_execution_order(exe_order); - // set output - CreateOutputNode(cnode, graph); - return graph; -} - -BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) { - if (utils::isa(base_ref)) { - auto ref_list = utils::cast(base_ref); - py::tuple output_tensors(ref_list.size()); - for (size_t i = 0; i < ref_list.size(); ++i) { - auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef - if (utils::isa(output)) { - auto tensor_ptr = utils::cast(output); - MS_EXCEPTION_IF_NULL(tensor_ptr); - output_tensors[i] = tensor_ptr; - } else if (utils::isa(output)) { - py::object obj = utils::cast(output).object_; - py::tuple tensor_tuple = py::cast(obj); - output_tensors[i] = tensor_tuple; - } else { - MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; - } - } - return output_tensors; // turn tuple to py::object and store in PyObjectRef - } else if (utils::isa(base_ref)) { - return base_ref; - } else { - MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; - } -} - -KernelGraphPtr SessionBasic::NewKernelGraph() { - auto graph = std::make_shared(); - graph->set_graph_id(graph_sum_); - graphs_[graph_sum_++] = graph; - return graph; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h deleted file mode 100755 index cf85dd0225..0000000000 --- a/mindspore/ccsrc/session/session_basic.h +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H -#define MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H - -#include -#include -#include -#include -#include -#include - -#include "utils/base_ref_extends.h" -#include "session/session_context.h" -#include "session/kernel_graph.h" -#include "ir/anf.h" -#include "ir/tensor.h" -#include "utils/any.h" -#include "utils/contract.h" -#include "pynative/pynative_execute.h" -#include "device/kernel_info.h" -#ifdef ENABLE_DEBUGGER -#include "debug/debugger/debugger.h" -#endif - -namespace mindspore { -using GraphId = uint32_t; -using GraphInfo = std::string; -namespace session { -void ClearPythonParasMap(); -using CallBackFunc = uint32_t (*)(uint32_t graph_id, - const std::map ¶ms_list); -using AnyList = std::vector; -using AnyListPtr = std::shared_ptr; - -using OpRunInfo = pynative::OpExecInfo; -using OpRunInfoPtr = std::shared_ptr; - -class SessionBasic { - public: - SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) { -#ifdef ENABLE_DEBUGGER - debugger_ = nullptr; -#endif - } - - virtual void Init(uint32_t device_id) { device_id_ = device_id; } - - virtual ~SessionBasic() { summary_callback_ = nullptr; } - - virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; - virtual GraphId CompileGraph(NotNull func_graph) { return kInvalidGraphId; } - // build graph, used to handle multiple child graphs - virtual void BuildGraph(GraphId) {} - - virtual void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) = 0; - - virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector &input_tensors, - const std::vector &tensors_mask) {} - - virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector &input_tensors) { - return py::tuple(); - } - - virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); - - void CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph); - - std::shared_ptr ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); - std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph, - std::vector *all_out_graph); - - CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, - std::unordered_map *other_graph_cnode); - CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); - - // set parameters of final graph - virtual GraphId SetFinalGraphInput(const std::vector &) { return kInvalidGraphId; } - // set output of final graph - virtual void SetFinalGraphOutput(const BaseRef &) {} - // insert switch and set the relative active ops - virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {} - // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter - virtual void SetChildGraphInput(GraphId, const VectorRef &) {} - // get graph id in child graphs by ME front anf node pointer - virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } - virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } - virtual void SetActive(GraphId, GraphId) {} - virtual void GetSummaryNodes(KernelGraph *graph); - -#ifdef ENABLE_DEBUGGER - // set debugger - void SetDebugger() { - debugger_ = Debugger::GetInstance(); - debugger_->Init(device_id_); - } -#endif - - protected: - // Get graph by graph id ,if not exist return null ptr - KernelGraphPtr GetGraph(GraphId graph_id); - virtual void LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const; - void UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, - const std::vector &input_tensors) const; - void Reorder(std::vector *node_list); - void Summary(KernelGraph *graph); - // create graph output for RunOp - void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr &graph); - CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph); - // create a single run op graph - std::shared_ptr ConstructSingleOpGraph(const OpRunInfo &op_run_info, - const std::vector &input_tensors, - const std::vector &tensors_mask); - // trans BaseRef list to py::tuple - BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); - // create a new kernel graph and update the graph sum - KernelGraphPtr NewKernelGraph(); - std::vector CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph); - virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); - ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); - ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); - AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); - void AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph); - void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); - - std::unordered_map> graphs_; - std::unordered_map> run_op_graphs_; - std::unordered_map front_backend_graph_map_; - std::shared_ptr context_; - CallBackFunc summary_callback_; - static GraphId graph_sum_; - uint32_t device_id_; -#ifdef ENABLE_DEBUGGER - std::shared_ptr debugger_; -#endif -}; - -using SessionPtr = std::shared_ptr; -using NamedSummaryOutputs = std::map>; -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/session/session_context.cc b/mindspore/ccsrc/session/session_context.cc deleted file mode 100644 index 2b6ebf6b84..0000000000 --- a/mindspore/ccsrc/session/session_context.cc +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/session_context.h" -namespace mindspore { -namespace session { -std::shared_ptr Context::GetInstance() { - static std::shared_ptr context_singleton = std::make_shared(); - return context_singleton; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/session_context.h b/mindspore/ccsrc/session/session_context.h deleted file mode 100644 index 78794c348e..0000000000 --- a/mindspore/ccsrc/session/session_context.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H -#define MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H -#include -#include -#include -#include -#include -#include - -#include "ir/tensor.h" -#include "pipeline/resource.h" -#include "utils/context/ms_context.h" -namespace mindspore { -namespace session { -const char kInputCtrlTensors[] = "input_ctrl_tensors"; - -class Context : public pipeline::ResourceBase { - public: - explicit Context(std::string target = kAscendDevice, uint32_t device_id = 0) - : target_(std::move(target)), device_id_(device_id) {} - ~Context() override = default; - - uint32_t device_id() const { return device_id_; } - static std::shared_ptr GetInstance(); - void AddManager(const FuncGraphManagerPtr &m) { manager_list_.push_back(m); } - - private: - std::vector manager_list_; - std::string target_; - uint32_t device_id_; -}; -} // namespace session -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H diff --git a/mindspore/ccsrc/session/session_factory.cc b/mindspore/ccsrc/session/session_factory.cc deleted file mode 100644 index 4cd0481f8c..0000000000 --- a/mindspore/ccsrc/session/session_factory.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/session_factory.h" -#include -#include -#include -namespace mindspore { -namespace session { -SessionFactory &SessionFactory::Get() { - static SessionFactory instance; - return instance; -} - -void SessionFactory::Register(const std::string &device_name, SessionCreator &&session_creator) { - if (session_creators_.end() == session_creators_.find(device_name)) { - (void)session_creators_.emplace(device_name, session_creator); - } -} - -std::shared_ptr SessionFactory::Create(const std::string &device_name) { - auto iter = session_creators_.find(device_name); - if (session_creators_.end() != iter) { - MS_EXCEPTION_IF_NULL(iter->second); - return (iter->second)(); - } - return nullptr; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/session_factory.h b/mindspore/ccsrc/session/session_factory.h deleted file mode 100644 index 99db0afeb7..0000000000 --- a/mindspore/ccsrc/session/session_factory.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ -#define MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ - -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "session/session_basic.h" -namespace mindspore { -namespace session { -using SessionCreator = std::function()>; -class SessionFactory { - public: - static SessionFactory &Get(); - void Register(const std::string &device_name, SessionCreator &&session_creator); - std::shared_ptr Create(const std::string &device_name); - - private: - SessionFactory() = default; - ~SessionFactory() = default; - DISABLE_COPY_AND_ASSIGN(SessionFactory) - std::map session_creators_; -}; - -class SessionRegistrar { - public: - SessionRegistrar(const std::string &device_name, SessionCreator &&session_creator) { - SessionFactory::Get().Register(device_name, std::move(session_creator)); - } - ~SessionRegistrar() = default; -}; - -#define MS_REG_SESSION(DEVICE_NAME, SESSION_CLASS) \ - static const SessionRegistrar g_session_registrar__##DEVICE_NAME##_##_reg( \ - DEVICE_NAME, []() { return std::make_shared(); }); -} // namespace session -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ diff --git a/mindspore/ccsrc/transform/CMakeLists.txt b/mindspore/ccsrc/transform/CMakeLists.txt deleted file mode 100644 index c783cc0060..0000000000 --- a/mindspore/ccsrc/transform/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -if (ENABLE_GE OR ENABLE_D) - file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") - set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT) - add_library(_mindspore_transform_obj OBJECT ${_TRANSFORM_SRC_LIST}) - - if (NOT ENABLE_GE) - target_compile_definitions(_mindspore_transform_obj PRIVATE NO_GE_CLIENT) - endif() -endif () diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc deleted file mode 100644 index f88e31fcd2..0000000000 --- a/mindspore/ccsrc/transform/convert.cc +++ /dev/null @@ -1,1898 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/convert.h" - -#include -#include -#include -#include "utils/utils.h" - -#include "operator/ops.h" -#include "utils/log_adapter.h" -#include "utils/graph_utils.h" -#include "utils/symbolic.h" -#include "utils/config_manager.h" -#include "utils/convert_utils.h" -#include "./common.h" - -namespace mindspore { -namespace transform { -using std::endl; - -#define ADPT_DESC_ONE(T) std::make_shared(std::make_shared>()) -#define ADPT_DESC_TWO(T, I) \ - std::make_shared(std::make_shared>(), std::make_shared>()) -#define GET_MACRO(_1, _2, DESC, ...) DESC -#define ADPT_DESC(...) GET_MACRO(__VA_ARGS__, ADPT_DESC_TWO, ADPT_DESC_ONE, ...)(__VA_ARGS__) - -using ge::Operator; -using mindspore::kAnyValue; -using std::make_shared; -using std::shared_ptr; -using std::string; -using std::vector; - -const char kNameCustomOp[] = "CustomOp"; -const char kNameConst[] = "Const"; -const char kNameParam[] = "parameter"; -const char kNameRandomUniform[] = "RandomUniform"; -const char kNameSimpleMean[] = "SimpleMean"; -const char kNameSimpleMeanGrad[] = "SimpleMeanGrad"; -const char kNameAllReduce[] = "AllReduce"; -const char kNameBroadcast[] = "Broadcast"; -const char kNameAllgather[] = "AllGather"; -const char kNameReduceScatter[] = "ReduceScatter"; -const char kNameReduceSum[] = "ReduceSum"; -const char kNameIsFinite[] = "isFinite"; -const char kNameReciprocal[] = "Reciprocal"; -const char kNameRsqrt[] = "Rsqrt"; -const char kNameRsqrtGrad[] = "RsqrtGrad"; -const char kNameSqrt[] = "Sqrt"; -const char kNameSquare[] = "Square"; -const char kNameSquaredDifference[] = "SquaredDifference"; -const char kNamePow[] = "Pow"; -const char kNameBatchMatMul[] = "BatchMatMul"; -const char kNameStridedSlice[] = "StridedSlice"; -const char kNameStridedSliceGrad[] = "StridedSliceGrad"; -const char kNameExpandDims[] = "ExpandDims"; -const char kNameLog[] = "Log"; -const char kNameLogicalAnd[] = "LogicalAnd"; -const char kNameLogicalNot[] = "LogicalNot"; -const char kNameLogicalOr[] = "LogicalOr"; -const char kNameExp[] = "Exp"; -const char kNameLessEqual[] = "LessEqual"; -const char kNameGreaterEqual[] = "GreaterEqual"; -const char kNameEqual[] = "Equal"; -const char kNameNotEqual[] = "NotEqual"; -const char kNameFlattenGrad[] = "FlattenGrad"; -const char kNameConvolution[] = "Convolution"; -const char kNameBiasAdd[] = "BiasAdd"; -const char kNameMaxPoolGrad[] = "MaxPoolGrad"; -const char kNameAvgPoolGrad[] = "AvgPoolGrad"; -const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; -const char kNameApplyMomentum[] = "ApplyMomentum"; -const char kNameDropoutDoMask[] = "DropoutDoMask"; -const char kNameResizeBilinear[] = "ResizeBilinear"; -const char kNameResizeBilinearGrad[] = "ResizeBilinearGrad"; -const char kNameZerosLike[] = "ZerosLike"; -const char kNameOnesLike[] = "OnesLike"; -const char kNameTruncatedNormal[] = "TruncatedNormal"; -const char kNameSpaceToBatchNd[] = "SpaceToBatchNd"; -const char kNameConfusionMatrix[] = "ConfusionMatrix"; -const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor"; -const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad"; -const char kNameApplyAdam[] = "Adam"; -const char kNameExtractImagePatches[] = "ExtractImagePatches"; -const char kNameReLU6[] = "ReLU6"; -const char kNameReLU6Grad[] = "ReLU6Grad"; -const char kNameElu[] = "Elu"; -const char kNameEluGrad[] = "EluGrad"; -const char kNameTensorScatterUpdate[] = "TensorScatterUpdate"; -const char kNameScatterUpdate[] = "ScatterUpdate"; -const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; -const char kNameScatterMax[] = "ScatterMax"; -const char kNameNMSWithMask[] = "NMSWithMask"; -const char kNameCheckValid[] = "CheckValid"; -const char kNameSmoothL1Loss[] = "SmoothL1Loss"; -const char kNameSmoothL1LossGrad[] = "SmoothL1LossGrad"; -const char kNameSGD[] = "SGD"; -const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits"; -const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad"; -const char kNameScatterNdD[] = "ScatterNd"; -const char kNamePadD[] = "Pad"; -const char kNameMirrorPad[] = "MirrorPad"; -const char kNameMirrorPadGrad[] = "MirrorPadGrad"; -const char kNameGatherNd[] = "GatherNd"; -const char kNameArgmax[] = "Argmax"; -const char kNameArgmin[] = "Argmin"; -const char kNameArgMaxWithValue[] = "ArgMaxWithValue"; -const char kNameArgMinWithValue[] = "ArgMinWithValue"; -const char kNameReduceProd[] = "ReduceProd"; -const char kNameCumProd[] = "CumProd"; -const char kNameDiagpart[] = "Diagpart"; -const char kNameSplitD[] = "Split"; -const char kNameBatchToSpaceNd[] = "BatchToSpaceNd"; -const char kNameFloor[] = "Floor"; -const char kNameNPUGetFloatStatus[] = "NPUGetFloatStatus"; -const char kNameAssign[] = "Assign"; -const char kNameAssignAdd[] = "AssignAdd"; -const char kNameAssignSub[] = "AssignSub"; -const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus"; -const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus"; -const char kNameReshape[] = "Reshape"; -const char kNameTransShape[] = "TransShape"; -const char kNameRealDiv[] = "RealDiv"; -const char kNameTile[] = "Tile"; -const char kNameCos[] = "Cos"; -const char kNameACos[] = "ACos"; -const char kNameACosGrad[] = "ACosGrad"; -const char kNameFloorDiv[] = "FloorDiv"; -const char kNameSin[] = "Sin"; -const char kNamePrelu[] = "PReLU"; -const char kNamePreluGrad[] = "PReLUGrad"; -const char kNameSigmoid[] = "Sigmoid"; -const char kNameSigmoidGrad[] = "SigmoidGrad"; -const char kNameL2Normalize[] = "L2Normalize"; -const char kNameL2NormalizeGrad[] = "L2NormalizeGrad"; -const char kNameSoftmax[] = "Softmax"; -const char kNameIOU[] = "IOU"; -const char kNameBoundingBoxDecode[] = "BoundingBoxDecode"; -const char kNameBoundingBoxEncode[] = "BoundingBoxEncode"; -const char kNameSlice[] = "Slice"; -const char kNameAddN[] = "AddN"; -const char kNameLess[] = "Less"; -const char kNameGreater[] = "Greater"; -const char kNamePack[] = "Pack"; -const char kNameUnpack[] = "Unpack"; -const char kNameMerge[] = "Merge"; -const char kNameGeSwitch[] = "GeSwitch"; - -const char kNameHuberLoss[] = "HuberLoss"; -const char kNameCumSum[] = "CumSum"; -const char kNameHuberLossGrad[] = "HuberLossGrad"; -const char kNameSparseSoftmaxCrossEntropy[] = "SparseSoftmaxCrossEntropy"; -const char kNameSparseSoftmaxCrossEntropyGrad[] = "SparseSoftmaxCrossEntropyGrad"; -const char kNameTopK[] = "TopK"; -const char kNameSoftmaxGrad[] = "SoftmaxGrad"; -const char kNameMaxPool[] = "MaxPool"; -const char kNameAvgPool[] = "AvgPool"; -const char kNameMaxPoolWithArgmax[] = "MaxPoolWithArgmax"; -const char kNameBatchNorm[] = "BatchNorm"; -const char kNameBatchNormGrad[] = "BatchNormGrad"; -const char kNameROIAlign[] = "ROIAlign"; -const char kNameROIAlignGrad[] = "ROIAlignGrad"; -const char kNameRandomChoiceWithMask[] = "RandomChoiceWithMask"; -const char kNameAbs[] = "Abs"; -const char kNameAbsGrad[] = "AbsGrad"; -const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy"; -const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad"; -const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad"; -const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD"; -const char kNameApplyProximalAdagrad[] = "ApplyProximalAdagrad"; -const char kNameAcosh[] = "Acosh"; -const char kNameAcoshGrad[] = "AcoshGrad"; -const char kNameFloorMod[] = "FloorMod"; -const char kNameSpaceToDepth[] = "SpaceToDepth"; -const char kNameDepthToSpace[] = "DepthToSpace"; -const char kNameSign[] = "Sign"; -const char kNameLARSUpdate[] = "LARSUpdate"; -const char kNameRound[] = "Round"; -const char kNamePrint[] = "Print"; -const char kNameApplyFtrl[] = "ApplyFtrl"; -const char kNameDiag[] = "Diag"; -const char kNameDiagPart[] = "DiagPart"; -const char kNameSpaceToBatch[] = "SpaceToBatch"; -const char kNameBatchToSpace[] = "BatchToSpace"; -const char kNameAtan2[] = "Atan2"; -const char kNameApplyRMSProp[] = "ApplyRMSProp"; -const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; -const char kNameL2Loss[] = "L2Loss"; -const char kNameCTCLoss[] = "CTCLoss"; -const char kNameRange[] = "Range"; -const char kNameSquareSumAll[] = "SquareSumAll"; -const char kNameAscendQuant[] = "AscendQuant"; -const char kNameAscendDequant[] = "AscendDequant"; - -// -----------------OpAdapter initialization-------------- -std::unordered_map &DfGraphConvertor::get_adpt_map() { - static std::unordered_map adpt_map = { - {string(kNameCustomOp), ADPT_DESC(Operator)}, - {string(kNameIOU), ADPT_DESC(Iou)}, - {string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)}, - {string(kNameSlice), ADPT_DESC(SliceD)}, - {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentumD)}, - {string(kNameMaxPool), ADPT_DESC(MaxPool)}, - {string(kNameAvgPool), ADPT_DESC(AvgPool)}, - {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)}, - {string(kNameTopK), ADPT_DESC(TopK)}, - {string(kNamePack), ADPT_DESC(Pack)}, - {string(kNameUnpack), ADPT_DESC(Unpack)}, - {string(kNameSplitD), ADPT_DESC(SplitD)}, - {string(kNameAllReduce), ADPT_DESC(HcomAllReduce)}, - {string(kNameBroadcast), ADPT_DESC(HcomBroadcast)}, - {string(kNameAllgather), ADPT_DESC(HcomAllGather)}, - {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, - {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, - {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, - {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, - {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, - {prim::kPrimAssign->name(), ADPT_DESC(Assign)}, - {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)}, - {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)}, - {prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)}, - {prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)}, - {prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)}, - {prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilterD)}, - {prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)}, - {prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)}, - {prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)}, - {string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, - {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, - {string(kNameReshape), ADPT_DESC(Reshape)}, - {string(kNameTransShape), ADPT_DESC(TransShape)}, - {string(kNameFlattenGrad), ADPT_DESC(Reshape)}, - {prim::kPrimFlatten->name(), ADPT_DESC(Flatten)}, - {string(kNameAddN), ADPT_DESC(AddN)}, - {string(kNameLess), ADPT_DESC(Less)}, - {string(kNameSqrt), ADPT_DESC(Sqrt)}, - {string(kNameRsqrt), ADPT_DESC(Rsqrt)}, - {string(kNameSquare), ADPT_DESC(Square)}, - {prim::kPrimTanh->name(), ADPT_DESC(Tanh)}, - {prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)}, - {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)}, - {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)}, - {string(kNameApplyAdam), ADPT_DESC(ApplyAdam)}, - {string(kNameReLU6), ADPT_DESC(Relu6)}, - {string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)}, - {string(kNameElu), ADPT_DESC(Elu)}, - {string(kNameEluGrad), ADPT_DESC(EluGrad)}, - {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearV2Grad)}, - {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, - {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, - {string(kNameOnesLike), ADPT_DESC(OnesLike)}, - {string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)}, - {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, - {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, - {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, - {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, - {string(kNameCheckValid), ADPT_DESC(CheckValid)}, - {string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)}, - {string(kNameSmoothL1LossGrad), ADPT_DESC(SmoothL1LossGrad)}, - {string(kNameSigmoidCrossEntropyWithLogits), ADPT_DESC(SigmoidCrossEntropyWithLogits)}, - {string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)}, - {string(kNameScatterNdD), ADPT_DESC(ScatterNdD)}, - {string(kNamePadD), ADPT_DESC(PadD)}, - {string(kNameMirrorPad), ADPT_DESC(MirrorPad)}, - {string(kNameMirrorPadGrad), ADPT_DESC(MirrorPadGrad)}, - {string(kNameGatherNd), ADPT_DESC(GatherNd)}, - {string(kNameArgmax), ADPT_DESC(ArgMaxD)}, - {string(kNameArgmin), ADPT_DESC(ArgMinD)}, - {string(kNameArgMaxWithValue), ADPT_DESC(ArgMaxWithValue)}, - {string(kNameArgMinWithValue), ADPT_DESC(ArgMinWithValue)}, - {prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSumD)}, - {prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMeanD)}, - {prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAllD)}, - {prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMinD)}, - {prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD)}, - {string(kNameLARSUpdate), ADPT_DESC(LarsV2Update)}, - {string(kNameReduceProd), ADPT_DESC(ReduceProdD)}, - {string(kNameCumProd), ADPT_DESC(CumprodD)}, - {string(kNameMerge), ADPT_DESC(Merge)}, - {string(kNameGeSwitch), ADPT_DESC(Switch)}, - {string(kNameCumSum), ADPT_DESC(CumsumD)}, - - {prim::kPrimMul->name(), ADPT_DESC(Mul)}, - {string(kNameTile), ADPT_DESC(TileD)}, - {prim::kPrimOneHot->name(), ADPT_DESC(OneHot)}, - - {prim::kPrimGatherV2->name(), ADPT_DESC(GatherV2D)}, - {string(kNameCos), ADPT_DESC(Cos)}, - {string(kNameACos), ADPT_DESC(Acos)}, - {string(kNameACosGrad), ADPT_DESC(AcosGrad)}, - {string(kNameFloor), ADPT_DESC(Floor)}, - {string(kNameFloorDiv), ADPT_DESC(FloorDiv)}, - {string(kNameSin), ADPT_DESC(Sin)}, - {string(kNameExp), ADPT_DESC(Exp)}, - {string(kNameBoundingBoxEncode), ADPT_DESC(BoundingBoxEncode)}, - {string(kNameBoundingBoxDecode), ADPT_DESC(BoundingBoxDecode)}, - - {prim::kPrimCast->name(), ADPT_DESC(Cast)}, - {string(kNameRealDiv), ADPT_DESC(RealDiv)}, - {prim::kPrimNeg->name(), ADPT_DESC(Neg)}, - {prim::kPrimTranspose->name(), ADPT_DESC(TransposeD)}, - {prim::kPrimSub->name(), ADPT_DESC(Sub)}, - {string(kNameReciprocal), ADPT_DESC(Reciprocal)}, - {prim::kPrimDropoutGenMask->name(), ADPT_DESC(DropOutGenMask)}, - {string(kNameAssignAdd), ADPT_DESC(AssignAdd)}, - {string(kNameAssignSub), ADPT_DESC(AssignSub)}, - {prim::kPrimConcat->name(), ADPT_DESC(ConcatD)}, - {string(kNamePow), ADPT_DESC(Pow)}, - {string(kNameExp), ADPT_DESC(Exp)}, - {string(kNameEqual), ADPT_DESC(Equal)}, - {string(kNameNotEqual), ADPT_DESC(NotEqual)}, - {string(kNameLog), ADPT_DESC(Log)}, - {string(kNameLogicalAnd), ADPT_DESC(LogicalAnd)}, - {string(kNameLogicalNot), ADPT_DESC(LogicalNot)}, - {string(kNameLogicalOr), ADPT_DESC(LogicalOr)}, - {string(kNameGreater), ADPT_DESC(Greater)}, - {prim::kPrimMaximum->name(), ADPT_DESC(Maximum)}, - {prim::kPrimRelu->name(), ADPT_DESC(Relu)}, - {string(kNamePrelu), ADPT_DESC(PRelu)}, - {string(kNamePreluGrad), ADPT_DESC(PReluGrad)}, - {string(kNameSigmoid), ADPT_DESC(Sigmoid)}, - {string(kNameSigmoidGrad), ADPT_DESC(SigmoidGrad)}, - {string(kNameSGD), ADPT_DESC(SGD)}, - {prim::kPrimLogSoftmaxGrad->name(), ADPT_DESC(LogSoftmaxGrad)}, - {prim::kPrimMaximumGrad->name(), ADPT_DESC(MaximumGrad)}, - {prim::kPrimMinimumGrad->name(), ADPT_DESC(MinimumGrad)}, - {string(kNameL2Normalize), ADPT_DESC(L2Normalize)}, - {string(kNameL2NormalizeGrad), ADPT_DESC(L2NormalizeGrad)}, - - {prim::kPrimMinimum->name(), ADPT_DESC(Minimum)}, - {prim::kPrimSelect->name(), ADPT_DESC(Select)}, - {string(kNameLessEqual), ADPT_DESC(LessEqual)}, - {prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmaxV2)}, - {string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)}, - {string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)}, - {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, - {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, - {string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, - {prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMin)}, - {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, - {string(kNameExpandDims), ADPT_DESC(ExpandDims)}, - {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, - {prim::kPrimLayerNorm->name(), ADPT_DESC(LayerNorm)}, - {prim::kPrimLayerNormGrad->name(), ADPT_DESC(LayerNormGrad)}, - {string(kNameBatchMatMul), ADPT_DESC(BatchMatMul)}, - {string(kNameDropoutDoMask), ADPT_DESC(DropOutDoMask)}, - - {string(kNameNPUGetFloatStatus), ADPT_DESC(NPUGetFloatStatus)}, - {string(kNameNPUAllocFloatStatus), ADPT_DESC(NPUAllocFloatStatus)}, - {string(kNameNPUClearFloatStatus), ADPT_DESC(NPUClearFloatStatus)}, - - {string(kNameRandomChoiceWithMask), ADPT_DESC(RandomChoiceWithMask)}, - {prim::kPrimSoftmaxCrossEntropyWithLogits->name(), ADPT_DESC(SoftmaxCrossEntropyWithLogits)}, - - {prim::kPrimScalarSummary->name(), ADPT_DESC(Summary)}, - {prim::kPrimImageSummary->name(), ADPT_DESC(Summary)}, - {prim::kPrimTensorSummary->name(), ADPT_DESC(Summary)}, - {prim::kPrimHistogramSummary->name(), ADPT_DESC(Summary)}, - {prim::kPrimDebug->name(), ADPT_DESC(Summary)}, - {prim::kPrimTensorAdd->name(), - std::make_shared(std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})), - std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})))}, - {string(kNameBiasAdd), ADPT_DESC(BiasAdd)}, - {prim::kPrimRelu->name(), ADPT_DESC(Relu)}, - - {prim::kPrimMatMul->name(), ADPT_DESC(MatMul)}, - - {string(kNameConst), ADPT_DESC(Constant, Const)}, - {string(kNameSoftmax), ADPT_DESC(SoftmaxV2)}, - {string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)}, - {string(kNameParam), ADPT_DESC(Data)}, - {string(kNameROIAlign), ADPT_DESC(ROIAlign)}, - {string(kNameROIAlignGrad), ADPT_DESC(ROIAlignGrad)}, - {string(kNameAbs), ADPT_DESC(Abs)}, - {string(kNameAbsGrad), ADPT_DESC(AbsGrad)}, - {string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)}, - {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, - {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, - {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)}, - {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)}, - {string(kNameAcosh), ADPT_DESC(Acosh)}, - {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, - {string(kNameFloorMod), ADPT_DESC(FloorMod)}, - {string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)}, - {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)}, - {string(kNameSign), ADPT_DESC(Sign)}, - {string(kNameRound), ADPT_DESC(Round)}, - {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)}, - {string(kNameDiag), ADPT_DESC(Diag)}, - {string(kNameDiagPart), ADPT_DESC(DiagPart)}, - {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)}, - {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, - {string(kNameAtan2), ADPT_DESC(Atan2)}, - {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, - {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}, - {string(kNameL2Loss), ADPT_DESC(L2Loss)}, - {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, - {string(kNameRange), ADPT_DESC(RangeD)}, - {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}, - {string(kNameAscendQuant), ADPT_DESC(AscendQuant)}, - {string(kNameAscendDequant), ADPT_DESC(AscendDequant)}}; -#ifdef ENABLE_GE - adpt_map[string(kNamePrint)] = ADPT_DESC(Print); - adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); -#endif - return adpt_map; -} - -// ---------------implement of DfGraphConvertor------------- -PrimType GetCNodeFuncType(const CNodePtr cnode) { - if (cnode->inputs().empty()) { - return kPrimTypeUnknown; - } - - AnfNodePtr valuenode = cnode->input(0); - if (IsValueNode(valuenode)) { - // check whether the valuenode is primitive - return GetValueNode(valuenode)->prim_type(); - } - return kPrimTypeUnknown; -} - -OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) { - if (node->isa()) { - auto cnode = node->cast(); - - std::string name = kNameCustomOp; - if (!IsCustomCNode(cnode)) { - name = GetCNodeFuncName(cnode); - } - - auto it_adpt = get_adpt_map().find(name); - if (it_adpt != get_adpt_map().end()) { - return it_adpt->second->Get(train); - } - MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name; - } - - if (node->isa()) { - return get_adpt_map()[kNameConst]->Get(train); - } - if (node->isa()) { - return get_adpt_map()[kNameParam]->Get(train); - } - return OpAdapterPtr(nullptr); -} - -void DfGraphConvertor::InitLoopVar(std::vector *init_input) { - if (this->training_) { - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); - auto var_iter_num = std::make_shared("npu_runconfig/iterations_per_loop"); - auto var_loop_cond = std::make_shared("npu_runconfig/loop_cond"); - auto var_one = std::make_shared("npu_runconfig/one"); - auto var_zero = std::make_shared("npu_runconfig/zero"); - (void)var_iter_num->update_output_desc_y(desc); - (void)var_loop_cond->update_output_desc_y(desc); - (void)var_one->update_output_desc_y(desc); - (void)var_zero->update_output_desc_y(desc); - vars_["npu_runconfig/iterations_per_loop"] = var_iter_num; - vars_["npu_runconfig/loop_cond"] = var_loop_cond; - vars_["npu_runconfig/one"] = var_one; - vars_["npu_runconfig/zero"] = var_zero; - - int64_t value = 0; - auto const_iter_num = std::make_shared("const/npu_runconfig/iterations_per_loop"); - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - value = ConfigManager::GetInstance().iter_num(); - } else { - MS_LOG(INFO) << "Run with normal(non-sink) mode, the iterator number will always be 1"; - value = 1; - ConfigManager::GetInstance().set_iter_num(value); - } - value -= 1; // iteration start from 0, the max iteration number for n loop should be n-1 - (void)const_iter_num->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); - - auto const_loop_cond = std::make_shared("const/npu_runconfig/loop_cond"); - value = 0; - (void)const_loop_cond->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); - - auto const_one = std::make_shared("const/npu_runconfig/one"); - value = 1; - (void)const_one->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); - - auto const_zero = std::make_shared("const/npu_runconfig/zero"); - value = 0; - (void)const_zero->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); - - (void)const_iter_num->update_output_desc_y(desc); - (void)const_loop_cond->update_output_desc_y(desc); - (void)const_one->update_output_desc_y(desc); - (void)const_zero->update_output_desc_y(desc); - - auto assign_iter_num = std::make_shared("assign/npu_runconfig/iterations_per_loop"); - (void)assign_iter_num->set_input_ref(*var_iter_num).set_input_value(*const_iter_num); - auto assign_loop_cond = std::make_shared("assign/npu_runconfig/loop_cond"); - (void)assign_loop_cond->set_input_ref(*var_loop_cond).set_input_value(*const_loop_cond); - auto assign_one = std::make_shared("assign/npu_runconfig/one"); - (void)assign_one->set_input_ref(*var_one).set_input_value(*const_one); - auto assign_zero = std::make_shared("assign/npu_runconfig/zero"); - (void)assign_zero->set_input_ref(*var_zero).set_input_value(*const_zero); - - init_input->push_back(*var_iter_num); - init_input->push_back(*var_loop_cond); - init_input->push_back(*var_one); - init_input->push_back(*var_zero); - init_ops_.push_back(var_iter_num); - init_ops_.push_back(var_loop_cond); - init_ops_.push_back(var_one); - init_ops_.push_back(var_zero); - init_ops_.push_back(const_iter_num); - init_ops_.push_back(const_loop_cond); - init_ops_.push_back(const_one); - init_ops_.push_back(const_zero); - init_ops_.push_back(assign_iter_num); - init_ops_.push_back(assign_loop_cond); - init_ops_.push_back(assign_one); - init_ops_.push_back(assign_zero); - } -} - -OpAdapterPtr DfGraphConvertor::FindAdapter(const std::string &name, bool train) { - auto it = get_adpt_map().find(name); - if (it != get_adpt_map().end()) { - return it->second->Get(train); - } - MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name; -} - -void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it) { - // draw init subgraph - init_sout_ << "op_assign" << it.get() << "[label=<"; - init_sout_ << "" << endl; - init_sout_ << ""; - init_sout_ << ""; - init_sout_ << ""; - init_sout_ << "" << endl; - init_sout_ << "" << endl; - init_sout_ << "
resourcevalue
" - << "\"assign_" << name << "\"
> shape=plaintext]" << endl; - init_sout_ << "param" << it.get() << "[shape=octagon, label=\"" << name << "\"]" << endl; - init_sout_ << "const" << it.get() << "[label= \"" << name << "_const" - << "\" shape=ellipse]" << endl; - init_sout_ << "param" << it.get() << "->" - << "op_assign" << it.get() << ":1" << endl; - init_sout_ << "const" << it.get() << "->" - << "op_assign" << it.get() << ":2" << endl; -} - -void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input) { - DfGraphPtr init_graph = std::make_shared("init"); - std::vector nodes = TopoSort(anf_graph_->get_return()); - - for (auto &it : nodes) { - if (it->isa()) { - if (IsValueNode(it)) { - auto symbolic = GetValueNode(it); - auto name = std::static_pointer_cast(symbolic->node())->name(); - auto iter = vars_.find(name); // get correspoding varaible op - if (iter != vars_.end()) { - op_cache_[it.get()] = iter->second; - // #ifdef DRAW_GE_GRAPH - compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()] - << "[style=\"dotted\"]" << endl; - // #endif - } - } else if (IsValueNode(it)) { - auto refkey = GetValueNode(it); - auto name = refkey->tag(); - auto iter = vars_.find(name); // get correspoding varaible op - if (iter != vars_.end()) { - op_cache_[it.get()] = iter->second; - compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()] - << "[style=\"dotted\"]" << endl; - } - } - } - } - - for (auto &it : tensors) { - if (vars_.find(it.first) == vars_.end()) { - MS_LOG(WARNING) << "Init parameter " << it.first << " didn't appear in graph."; - vars_[it.first] = nullptr; - } - } - - // set up init sub graph - if (init_input->size()) { - // init sub graph needs no input - MS_LOG(INFO) << "Build data init subgraph."; - (void)init_graph->SetInputs(*init_input); - this->init_graph_ = init_graph; - } else { - this->init_graph_ = nullptr; - } -} - -void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) { - MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input"; - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - auto getnext_idx = static_cast(input_idx); - DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); - if (!param.input_indexes().empty() && input_idx <= param.input_indexes().size()) { - getnext_idx = param.input_indexes()[input_idx] - 1; // input_idx start from 0. - MS_LOG(INFO) << "remap input_index:" << input_idx << " to getnext_index:" << getnext_idx << "."; - } - // use iterator_getnext op with output_name instead of data op in BuildGraph. - out_handle_cache_[it.get()] = OutHandler(dataset_iter_getnext_, "y" + std::to_string(getnext_idx)); - } -} - -void DfGraphConvertor::SetupBroadcast(const std::shared_ptr &broadcast, - const std::vector &broadcast_desc, - const DfGraphPtr &broadcast_graph, std::vector broadcast_input) { - MS_LOG(INFO) << "build broadcast subgraph"; - if (broadcast_desc.size() != broadcast_input.size()) { - MS_LOG(EXCEPTION) << "Desc number of BroadCast is not equal to number of Input"; - } - (void)broadcast->create_dynamic_input_x(static_cast(broadcast_input.size())); - (void)broadcast->create_dynamic_output_y(static_cast(broadcast_desc.size())); - for (unsigned int i = 0; i < broadcast_input.size(); i++) { - (void)broadcast->set_dynamic_input_x(i, broadcast_input[i]); - (void)broadcast->update_dynamic_output_desc_y(i, broadcast_desc[i]); - } - (void)broadcast_graph->SetInputs(broadcast_input); - this->broadcast_graph_ = broadcast_graph; -} - -void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) { - int index = 0; - std::vector init_input; - for (auto it : tensors) { - std::string name = it.first; - auto node_itor = params_.find(name); - // if name not in params_, create a node in graph - if (node_itor == params_.end()) { - MS_LOG(WARNING) << name << " is not in params, and create a new node."; - ParameterPtr param = std::make_shared(nullptr); - name = name + "_temp"; - param->set_name(name); - (void)ConvertParameter(param); - node_itor = params_.find(name); - } - auto node = node_itor->second; - auto op_itor = op_cache_.find(node.get()); - if (op_itor == op_cache_.end()) { - MS_LOG(EXCEPTION) << "Can not find op for node " << node->ToString() << "."; - } - auto adpt = FindAdapter(kNameParam, training_); - if (adpt == nullptr) continue; - auto param_op = adpt->generate(name + "_data"); - MS_LOG(INFO) << "Add parameter " << name << " as input, index " << index << "."; - - if (!training_) { - auto adpt_const = FindAdapter(kNameConst, training_); - if (adpt_const == nullptr) continue; - auto const_op = adpt_const->generate(name + "_const"); - (void)adpt_const->setAttr(const_op, "value", it.second); - - auto const_op_desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); - if (const_op_desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; - continue; - } - (void)std::static_pointer_cast(const_op)->update_output_desc_y(*const_op_desc); - - vars_[name] = const_op; - op_itor->second = const_op; - continue; - } - - // create tensor descriptor for output descriptor - auto desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; - continue; - } - - // we need three variable ops for each graph with same name - // build init subgraph - if (it.second->is_init() == 0) { - (void)std::static_pointer_cast(param_op)->set_attr_index(index++); - auto init_var = std::make_shared(name); - auto assign_op = std::make_shared("assign_" + name); - (void)init_var->update_output_desc_y(*desc); - (void)assign_op->set_input_ref(*init_var).set_input_value(*param_op); - init_input.push_back(*init_var); - init_ops_.push_back(param_op); - init_ops_.push_back(assign_op); - init_ops_.push_back(init_var); - } - - auto variable = std::make_shared(name); - (void)variable->update_output_desc_y(*desc); - // do not use read variable while variable sink - MS_LOG(DEBUG) << "InitParam, op_name = " << name << ", var = " << variable->GetName() << "."; - op_itor->second = variable; // replace parameter with variable - vars_[name] = variable; // prevent the variable operator from being freed - DrawParamInitSubGraph(name, node); - } - InitLoopVar(&init_input); - SetupParamInitSubGraph(tensors, &init_input); -} - -// convert all parameter need initialize to variable -DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) { - size_t input_idx = 0; - if (error_ != 0) { - return *this; - } - if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { - error_ = INVALID_ARGUMENT; - MS_LOG(ERROR) << "Invalid AnfGraph in InitParam."; - return *this; - } - - // Processing input with MakeDatasetHandler - for (auto &it : anf_graph_->parameters()) { - auto op_itor = op_cache_.find(it.get()); // converted node - if (it->isa() && op_itor != op_cache_.end()) { - string name = std::static_pointer_cast(it)->name(); - auto tensor_itor = tensors.find(name); // in init value map - if (tensor_itor == tensors.end()) { - DfGraphConvertor::MakeDatasetHandler(name, input_idx, it); - input_idx++; - } - } - } - InitParamWithData(tensors); - init_sout_ << "}" << endl; - return *this; -} - -#if (defined ENABLE_GE) -void DfGraphConvertor::BuildSaveCheckpointGraph() { - std::vector graph_inputs; - ge::op::Save save_op("save_parms"); - int save_op_is_active = 0; - size_t index = 0; - string name; - - int32_t count_size = std::count_if(vars_.begin(), vars_.end(), [](const std::pair &it) { - return (it.second == nullptr || it.first.find("/") != std::string::npos); - }); - - (void)save_op.create_dynamic_input_tensors(vars_.size() - static_cast(count_size)); - - // for each "parameter" in anf graph excluding "input" - for (const auto &it : vars_) { - name = it.first; - if (it.second == nullptr || name.find("/") != std::string::npos) continue; - Variable variable(name); - (void)variable.update_output_desc_y(it.second->GetOutputDesc(0)); - (void)save_op.set_dynamic_input_tensors(index++, variable); - - graph_inputs.push_back(variable); - - if (save_op_is_active == 0) { - checkpoint_sout_ << "op_save" << &save_op << "[label=<"; - checkpoint_sout_ << "" << endl; - checkpoint_sout_ << "" << endl; - checkpoint_sout_ << "" << endl; - checkpoint_sout_ << "
tensor
" - << "\"saveop" - << "\"
> shape=plaintext]" << endl; - } - - checkpoint_sout_ << "param" << it.second << "[shape=octagon, label=\"" << name << "\"]" << endl; - - checkpoint_sout_ << "param" << it.second << "->" - << "op_save" << &save_op << ":1" << endl; - save_op_is_active = 1; - } - if (save_op_is_active) { - std::vector graph_output; - graph_output.emplace_back(save_op); - DfGraphPtr checkpoint_graph = std::make_shared("checkpoint"); - (void)checkpoint_graph->SetInputs(graph_inputs); - (void)checkpoint_graph->SetOutputs(graph_output); - this->save_ckp_graph_ = checkpoint_graph; - } else { - this->save_ckp_graph_ = nullptr; - } - - checkpoint_sout_ << "}" << endl; - return; -} -#endif - -DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) { - if (error_ != 0) { - return *this; - } - if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { - error_ = INVALID_ARGUMENT; - MS_LOG(ERROR) << "Invalid AnfGraph in generate broadcast graph"; - return *this; - } - - DfGraphPtr broadcast_graph = std::make_shared("broadcast"); - // collect the operators create for broadcast sub graph, in order to avoid auto release - std::vector broadcast_input; - std::vector broadcast_desc; - auto broadcast = std::make_shared("broadcast_parameter"); - (void)broadcast->set_attr_root_rank(0); - (void)broadcast->set_attr_group("hccl_world_group"); - broadcast_ops_.push_back(broadcast); - - // find every parameter, build broadcast subgraph (or initialize the parameter with constant) - for (auto &it : anf_graph_->parameters()) { - auto op_itor = op_cache_.find(it.get()); // converted node - if (it->isa() && op_itor != op_cache_.end()) { - string name = std::static_pointer_cast(it)->name(); - auto tensor_itor = tensors.find(name); // in init tensor map - if (tensor_itor != tensors.end()) { - auto tensor = tensor_itor->second; - auto shape_ge = tensor->shape_c(); - - // create tensor descriptor for output descriptor - auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; - continue; - } - - // build broadcast subgraph - if (distribute_) { - auto broadcast_var = std::make_shared(name); - (void)broadcast_var->update_output_desc_y(*desc); - broadcast_input.push_back(*broadcast_var); - broadcast_desc.push_back(*desc); - broadcast_ops_.push_back(broadcast_var); - } - } - } - } - - // set up broadcast sub graph - if (!broadcast_input.empty()) { - DfGraphConvertor::SetupBroadcast(broadcast, broadcast_desc, broadcast_graph, broadcast_input); - } else { - this->broadcast_graph_ = nullptr; - } - return *this; -} - -DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() { - if (error_ != 0) { - MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << "."; - return *this; - } - if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { - error_ = INVALID_ARGUMENT; - MS_LOG(ERROR) << "Invalid AnfGraph in GenerateCheckpointGraph"; - return *this; - } -#if (defined ENABLE_GE) - BuildSaveCheckpointGraph(); - // Restoring from checkpoint file is done by pyfront, not in graph now. -#endif - return *this; -} - -DfGraphConvertor &DfGraphConvertor::ConvertAllNode() { - if (error_ != 0) { - return *this; - } - if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { - MS_LOG(ERROR) << "Invalid AnfGraph"; - error_ = FAILED; - return *this; - } - - compute_sout_.clear(); - compute_sout_ << "digraph {" << endl; - init_sout_.clear(); - init_sout_ << "digraph {" << endl; - checkpoint_sout_.clear(); - checkpoint_sout_ << "digraph {" << endl; - restore_checkpoint_sout_.clear(); - restore_checkpoint_sout_ << "digraph {" << endl; - - // Convert all anf node to Operator - MS_LOG(DEBUG) << "convert all node"; - std::vector nodes = TopoSort(anf_graph_->get_return()); - for (auto &it : nodes) { - (void)Convert(it); - if (this->error_ != 0) { - MS_LOG(ERROR) << "failed to convert node: " << it->DebugString() << "."; - } - } - - // Create dataset iterator and iterator_getnext node - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); - MS_LOG(INFO) << "Dataset param is " << param.ToString() << "."; - // GetNext - auto iter_getnext_op = make_shared("get_next_tmp"); - (void)iter_getnext_op->set_attr_output_types(param.ge_types()); - (void)iter_getnext_op->set_attr_output_shapes(param.shapes()); - (void)iter_getnext_op->set_attr_channel_name(param.queue_name()); - - // save iter_getnext_op for later use - dataset_iter_getnext_ = iter_getnext_op; - } - - // return the data flow graph - return *this; -} - -void DfGraphConvertor::TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out) { - auto it = out_handle_cache_.find(anf_out.get()); - if (it != out_handle_cache_.end()) { - OutHandler handle = it->second; - auto op = handle.op; - if (op != nullptr) { - MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out; - graph_outputs_.emplace_back(std::make_pair(*op, handle.out)); - } else { - MS_LOG(EXCEPTION) << "tuple_getitem: " << anf_out->fullname_with_scope() << " is not converted"; - } - } else { - // invalid tuple_getitem e.g. tuple_getitem(tuple_getitem())/tuple_getitem(depend())/tuple_getitem(make_tuple()) - MS_LOG(WARNING) << "Invalid tuple_getitem: " << anf_out->fullname_with_scope(); - } -} - -void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { - AnfNodePtr anf_out = node; - AnfNodePtr pre_node = nullptr; - - // trace Parameter node - TraceOutputFromParameter(anf_out); - // then trace cnode - if (!node->isa()) { - return; - } - - // trace tuple_getitem - while (anf_out->isa() && IsPrimitiveCNode(anf_out, prim::kPrimTupleGetItem)) { - pre_node = anf_out; - anf_out = anf_out->cast()->input(1); - } - // trace every element of make_tuple - auto c = anf_out->cast(); - std::string name = ""; - if (anf_out->isa()) { - name = GetCNodeFuncName(c); - } - - if (name == "make_tuple") { - for (unsigned int i = 1; i < c->inputs().size(); i++) { - TraceOutput(c->input(i)); - } - } else if (name == "Depend") { - if (c->inputs().size() < 3) { // "Depend" primitive have 3 inputs - MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3"; - } - TraceOutput(c->input(1)); - } else if (name == "tuple_getitem") { - TraceOutputFromTupleGetItem(anf_out); - } else { - // add outputs; - auto op = Convert(anf_out); - std::string index; - if (op != nullptr) { - if ((pre_node != nullptr) && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) { - auto item = out_handle_cache_.find(pre_node.get()); - if (item != out_handle_cache_.end()) { - index = item->second.out; - } else { - MS_LOG(WARNING) << "Can't get operater: " << anf_out->fullname_with_scope() << " 's output item"; - } - } - MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope() << ":" << index; - graph_outputs_.emplace_back(make_pair(*op, index)); - } - } -} - -void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) { - if (anf_out->isa()) { - MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope(); - auto it = out_handle_cache_.find(anf_out.get()); - if (it != out_handle_cache_.end()) { - // For dataset graph mode, input parameter is converted to a "iterator_get_next:yn" OutHandler. - OutHandler handle = it->second; - auto op = handle.op; - MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out; - graph_outputs_.emplace_back(make_pair(*op, handle.out)); - } else { - // common parameter case - auto op = Convert(anf_out); - if (op != nullptr) { - MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType(); - graph_outputs_.emplace_back(std::make_pair(*op, "")); - } - } - } -} - -void SetupDatasetIterGetNextNode(const OperatorPtr &op) { - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); - size_t output_num = param.ge_types().size(); - MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << "."; - // set iterator_getnext op's output num - shared_ptr iter_getnext = std::static_pointer_cast(op); - (void)iter_getnext->create_dynamic_output_y(static_cast(output_num)); - - for (uint32_t i = 0; i < output_num; i++) { - ge::TensorDesc desc(GeShape(param.shapes()[i]), ge::FORMAT_NCHW, (ge::DataType)param.ge_types()[i]); - // we don't SetRealDimCnt here since GE do not use this output's real-dim - (void)iter_getnext->update_dynamic_output_desc_y((i), desc); - } - } - return; -} - -DfGraphConvertor &DfGraphConvertor::BuildGraph() { - SetupDatasetIterGetNextNode(dataset_iter_getnext_); - - if (error_ != 0) { - return *this; - } - - // update tuple_out_handle_cache_ - for (auto it : tuple_out_handle_cache_) { - std::size_t len = it.second->size(); - for (std::size_t i = 0; i < len; i++) { - OutHandler handle = (*it.second)[i]; - if (handle.op) { - string name = handle.op->GetName(); - if (vars_.count(name)) { - OperatorPtr new_op = vars_[name]; - if (new_op != nullptr) { - MS_LOG(INFO) << "update tuple_out_handle_cache_ " << name; - (*it.second)[i] = OutHandler(new_op, handle.out); - } - } - } - } - } - - // set up dependices - MS_LOG(DEBUG) << "set up dependices"; - std::vector nodes = ::mindspore::TopoSort(anf_graph_->get_return()); - for (auto &it : nodes) { - SetNodeInput(it); - SetOpControlInput(it); - UpdateOpDesc(it); - } - - if (error_ == 0) { - df_graph_ = make_shared(anf_graph_->ToString()); - } else { - return *this; - } - - // set graph input according to the order from anf graph - std::vector inputs; - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - inputs.push_back(*dataset_iter_getnext_); - } else { - auto params = anf_graph_->parameters(); - int index = 0; - for (auto &it : params) { - auto name = std::static_pointer_cast(it)->name(); - // the parameters which has not been converted to var - if (vars_.find(name) == vars_.end()) { - auto op = Convert(it); - MS_EXCEPTION_IF_NULL(op); - MS_LOG(INFO) << "add not var input " << it->ToString() << ", index " << index; - if (op == nullptr) { - MS_LOG(ERROR) << "Convert graph failed!"; - return *this; - } - UpdateDataOpDesc(it, op); - - MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index; - (void)std::static_pointer_cast(op)->set_attr_index(index++); - inputs.push_back(*op); - } else if (vars_[name] != nullptr) { - MS_LOG(INFO) << "add var input " << it->ToString(); - auto op = Convert(it); - MS_EXCEPTION_IF_NULL(op); - inputs.push_back(*op); - } - } - } - - // Add const nodes as graph input for some operator work with constant - std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs), - [](OperatorPtr x) { return *x; }); - - MS_LOG(INFO) << "set graph input num: " << inputs.size(); - (void)df_graph_->SetInputs(inputs); - - // set graph output - // set the value of finale return apply node as the output of dataflow graph - MS_LOG(DEBUG) << "set output"; - graph_outputs_.clear(); - TraceOutput(anf_graph_->get_return()->input(1)); - MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size(); - (void)df_graph_->SetOutputs(graph_outputs_); - - compute_sout_ << "}" << endl; - // For the graph(e.g. eval_subgraph) whose IterNum is 1, donot set NeedIteration flag. - if (ConfigManager::GetInstance().iter_num() > 1) { - df_graph_->SetNeedIteration(true); - } - return *this; -} - -void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const { - auto node = std::static_pointer_cast(it); - if (node == nullptr) { - MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node."; - return; - } - auto normal_shape_ptr = dyn_cast(node->Shape()); - vector shape; - if (normal_shape_ptr == nullptr) { - MS_LOG(INFO) << "Invalid shape to update data op descriptor."; - return; - } - shape = normal_shape_ptr->shape(); - if (node->Type() == nullptr) { - MS_LOG(INFO) << "Invalid type to update data op descriptor."; - return; - } - TypeId me_type = node->Type()->type_id(); - if (kObjectTypeTensorType == me_type) { - me_type = dyn_cast(node->Type())->element()->type_id(); - } - std::ostringstream buf; - buf << "[" << shape << "]"; - MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type; - auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); - if (desc == nullptr) { - MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null."; - } else { - (void)std::static_pointer_cast(op)->update_input_desc_x(*desc); - (void)std::static_pointer_cast(op)->update_output_desc_y(*desc); - } -} - -DfGraphPtr DfGraphConvertor::GetComputeGraph() { return df_graph_; } - -DfGraphPtr DfGraphConvertor::GetInitGraph() { return init_graph_; } - -DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; } - -DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; } - -void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { - if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { - return; - } - - std::vector control_edges = control_depend_cache_[node.get()]; - if ((control_edges.empty())) { - MS_LOG(ERROR) << "Get control depend node's src or dest operator failed"; - return; - } - - for (auto &item : control_edges) { - (void)item.dest_op->AddControlInput(*item.src_op); - } -} - -const std::vector trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)}; - -void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { - OperatorPtr src = Convert(node); - auto &inputs = node->inputs(); - for (size_t i = 1; i < inputs.size(); i++) { - auto pred = inputs[i]; - while (pred->isa() && GetCNodeFuncName(pred->cast()) == "Depend") { - pred = pred->cast()->input(1); - } - // skip the None input - if (IsValueNode(pred)) { - continue; - } - // transform "Const" op to "Variable" op when the next node is "Assign" op. - std::string c_name = GetCNodeFuncName(node); - auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); - if (!training_ && pos != trans_var_list.end() && pred->isa()) { - std::string name = std::static_pointer_cast(pred)->name(); - auto op_itor = op_cache_.find(pred.get()); - if (op_itor == op_cache_.end()) { - MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; - } - if (op_itor->second != nullptr && - (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && - vars_.find(name) != vars_.end()) { - auto variable = std::make_shared(name); - auto desc = vars_[name]->GetOutputDesc("y"); - (void)variable->update_output_desc_y(desc); - MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; - op_itor->second = variable; // replace parameter with variable - vars_[name] = variable; - } - } - // find in out_hadnle_cache_ first - auto it = out_handle_cache_.find(pred.get()); - if (it != out_handle_cache_.end()) { - int ret = adpt->setInput(src, SizeToInt(i), it->second); - if (ret == 0) { - if (pred->isa() && GetCNodeFuncName(pred->cast()) == "tuple_getitem") { - compute_sout_ << op_draw_name_[pred->cast()->input(1).get()] << " -> " << op_draw_name_[node.get()] - << ":" << i << endl; - } else if (pred->isa()) { - compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; - } else { - // don't draw anything. - MS_LOG(INFO) << "DRAW_GE_GRAPH: Shouldn't have this case."; - } - AddGraphConstInput(it->second.op); - } - } else if (tuple_out_handle_cache_.find(pred.get()) != tuple_out_handle_cache_.end()) { - std::shared_ptr> handler_vec = tuple_out_handle_cache_[pred.get()]; - int ret = adpt->setInput(src, SizeToInt(i), handler_vec); - if ((ret == 0) && pred->isa() && (pred->cast()->inputs().size() == handler_vec->size() + 1)) { - for (unsigned int j = 0; j < handler_vec->size(); j++) { - compute_sout_ << op_draw_name_[pred->cast()->input(j + 1).get()] << " -> " - << op_draw_name_[node.get()] << ":" << i << endl; - AddGraphConstInput(handler_vec->at(j).op); - } - } else { - MS_LOG(WARNING) << "Convert tuple node setInput failed : " << node->ToString(); - } - } else { - auto op = Convert(pred); - int ret = adpt->setInput(src, SizeToInt(i), op); - if (ret == 0) { - compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; - AddGraphConstInput(op); - } - } - } -} - -void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) { - if (op->GetOpType() == "Constant") { - graph_const_inputs_.push_back(op); - } -} - -void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) { - if (!node->isa()) { - return; - } - if (op_cache_.find(node.get()) == op_cache_.end()) { - return; - } - auto cnode = node->cast(); - OpAdapterPtr adpt = FindAdapter(cnode, training_); - if (adpt == nullptr) { - error_ = NOT_FOUND; - return; - } - - // get Operator from op_cache_, use adapter to set Inputs - DfGraphConvertor::SetOpInput(adpt, cnode); -} - -// Update GE op's shape and type info -void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) { - if (nullptr == node || !node->isa()) { - return; - } - - if (op_cache_.find(node.get()) == op_cache_.end()) { - return; - } - - OpAdapterPtr adpt = FindAdapter(node, training_); - if (adpt == nullptr) { - error_ = NOT_FOUND; - return; - } - - // get Operator from op_cache_ - OperatorPtr op = Convert(node); - - adpt->updateOutputDesc(op, node->Shape(), node->Type(), node); -} - -OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) { - if (node == nullptr) { - MS_LOG(ERROR) << "node is nullptr"; - error_ = NOT_FOUND; - return nullptr; - } - // find in cache - if (op_cache_.count(node.get())) { - return op_cache_[node.get()]; - } - - // do not convert primitive node - if (IsValueNode(node)) { - return nullptr; - } - - // convert a new one - if (node->isa()) { - return ConvertCNode(node->cast()); - } - if (node->isa()) { - return ConvertParameter(node); - } - if (node->isa()) { - return ConvertValueNode(node->cast()); - } - - MS_LOG(ERROR) << "Invalide AnfNode"; - error_ = INVALID_ARGUMENT; - return nullptr; -} - -void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) { - std::shared_ptr> tuple_items = std::make_shared>(); - // convert each tuple item to a OutHandler - for (size_t i = 1; i < node->inputs().size(); i++) { - AnfNodePtr item = node->input(i); - OperatorPtr op = Convert(item); - if (op != nullptr) { - tuple_items->emplace_back(OutHandler(op, "")); - } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { - tuple_items->push_back(out_handle_cache_[item.get()]); - } else { - MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << item->ToString(); - return; - } - } - - tuple_out_handle_cache_[node.get()] = tuple_items; -} - -AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, unsigned int *index) { - const int TUPLE_GET_ITEM_INDEX = 2; - if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs - MS_LOG(EXCEPTION) << "length of inputs of TupleGetItem is less than 3"; - } - auto index_node = node->inputs()[TUPLE_GET_ITEM_INDEX]; - if (!index_node->isa()) { - error_ = INVALID_ARGUMENT; - MS_LOG(EXCEPTION) << "can't convert get item with non-constant index"; - } - *index = IntToUint(GetValue(GetValueNode(index_node))); - return node->inputs()[1]; -} - -AnfNodePtr DfGraphConvertor::TraceDepend(const CNodePtr &node) { - auto cnode = node->cast(); - if (cnode->inputs().size() < 3) { // "Depend" primitive have 3 inputs - MS_LOG(EXCEPTION) << "length of inputs of depend is less than 3"; - } - return cnode->inputs()[1]; -} - -AnfNodePtr DfGraphConvertor::TraceMakeTuple(const CNodePtr &node, unsigned int index) { - if (index + 1 >= node->inputs().size()) { - MS_LOG(EXCEPTION) << "length of make_tuple is less than index: " << index; - } - return node->inputs()[index + 1]; -} - -OutHandler DfGraphConvertor::GetHandler(const AnfNodePtr &node, const std::stack &index_stack, - AnfNode *const draw_index) { - if (node == nullptr) { - MS_LOG(ERROR) << "Get nullptr while trace real op"; - return OutHandler(nullptr, ""); - } - std::ostringstream ss; - ss << "op" << node.get(); - if (index_stack.empty()) { - op_draw_name_[draw_index] = ss.str(); - return OutHandler(Convert(node), ""); - } else { - OpAdapterPtr adpt = FindAdapter(node, training_); - if (nullptr == adpt) { - MS_LOG(ERROR) << "Can not get node output as adpt is nullptr!"; - error_ = NOT_FOUND; - return OutHandler(nullptr, ""); - } - OperatorPtr op = Convert(node); - if (op == nullptr) { - error_ = NOT_FOUND; - MS_LOG(ERROR) << "Can not convert node for trace real op"; - return OutHandler(nullptr, ""); - } - op_draw_name_[draw_index] = ss.str(); - return adpt->getOutput(Convert(node), UintToInt(index_stack.top())); - } -} - -// get the real operator through maketuple tuple_getitem depend -OutHandler DfGraphConvertor::TraceRealOp(AnfNodePtr node) { - bool flag = IsPrimitiveCNode(node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) || - IsPrimitiveCNode(node, prim::kPrimDepend); - std::stack index_stack; - auto draw_index = node.get(); - while (flag) { - flag = false; - if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - unsigned int index; - node = TraceTupleGetItem(node->cast(), &index); - index_stack.push(index); - flag = true; - } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - if (index_stack.empty()) { - MS_LOG(ERROR) << "TraceRealOp find a make_tuple node"; - return OutHandler(nullptr, ""); - } else { - node = TraceMakeTuple(node->cast(), index_stack.top()); - index_stack.pop(); - flag = true; - } - } else if (IsPrimitiveCNode(node, prim::kPrimDepend)) { - node = TraceDepend(node->cast()); - flag = true; - } - } - return GetHandler(node, index_stack, draw_index); -} - -void DfGraphConvertor::ConvertTupleGetItem(const CNodePtr node) { - auto handle = TraceRealOp(node); - if (handle.op == nullptr) { - MS_LOG(ERROR) << "Failed to trace tuple get item"; - return; - } - out_handle_cache_[node.get()] = handle; -} - -// Get the real op for tuple_getitem through make tuple, or depend -AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) { - const int TUPLE_GET_ITEM_INDEX = 2; - if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - auto node_inputs = node->cast()->inputs(); - if (node_inputs.size() != 3) { // "tuple_getitem" primitive must have 3 inputs - MS_LOG(ERROR) << "tuple get item node not correct!"; - error_ = FAILED; - return node; - } - MS_EXCEPTION_IF_NULL(node_inputs[TUPLE_GET_ITEM_INDEX]); - if (!node_inputs[TUPLE_GET_ITEM_INDEX]->isa()) { - error_ = INVALID_ARGUMENT; - MS_LOG(EXCEPTION) << "can't convert get item with non-constant index"; - } - auto value_ptr = GetValueNode(node_inputs[TUPLE_GET_ITEM_INDEX])->cast(); - if (value_ptr == nullptr) { - MS_LOG(ERROR) << "Can not convert get item as value is nullptr!"; - error_ = FAILED; - return node; - } - int index = value_ptr->value(); - - // make_tuple apply inputs:make_tuple, [tuple_items,] - if (IsPrimitiveCNode(node_inputs[1], prim::kPrimMakeTuple)) { - auto tuple_inputs = node->cast()->inputs(); - if (tuple_inputs.size() < IntToSize(index + 1)) { - MS_LOG(ERROR) << "make tuple input items node not correct! size:" << tuple_inputs.size() - << ", item index:" << index; - error_ = FAILED; - return node; - } - return GetRealOpNode(tuple_inputs[IntToSize(index + 1)]); - } - return GetRealOpNode(node_inputs[1]); - } - - // depend apply inputs: depend,output,depended_node - if (IsPrimitiveCNode(node, prim::kPrimDepend)) { - auto depend_inputs = node->cast()->inputs(); - if (depend_inputs.size() != 3) { // "Depend" primitive have 3 inputs - MS_LOG(ERROR) << "depend input items not correct"; - error_ = FAILED; - return node; - } - return GetRealOpNode(depend_inputs[1]); - } - return node; -} - -// convert the anf node to corresponding operator list -std::vector DfGraphConvertor::ConvertDependNode(const AnfNodePtr node) { - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - std::vector op_lists; - auto node_inputs = node->cast()->inputs(); - for (size_t index = 1; index < node_inputs.size(); index++) { - auto op = Convert(GetRealOpNode(node_inputs[index])); - if (op == nullptr) { - MS_LOG(ERROR) << "Convert control depend node to operator failed"; - error_ = FAILED; - return std::vector({}); - } - op_lists.push_back(op); - } - return op_lists; - } - - auto op = Convert(GetRealOpNode(node)); - if (op == nullptr) { - MS_LOG(ERROR) << "Convert control depend node to operator failed"; - error_ = FAILED; - return std::vector({}); - } - return std::vector({op}); -} - -// get the anf node list for depend -std::vector DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) { - std::vector nodes; - // for make tuple, should control depend on the tuple items - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - auto node_inputs = node->cast()->inputs(); - for (size_t index = 1; index < node_inputs.size(); index++) { - nodes.push_back(GetRealOpNode(node_inputs[index])); - } - return nodes; - } - - // for parameter ,find the apply that used the parameter as the control depended node - if (node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[node]; - for (auto &use : uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { - nodes.push_back(GetRealOpNode(use_node)); - } - } - return nodes; - } - nodes.push_back(GetRealOpNode(node)); - return nodes; -} - -void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) { -#ifdef DRAW_GE_GRAPH - auto src_depend_nodes = GetDependNodes(src_node); - auto dst_depend_nodes = GetDependNodes(dest_node); - if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) { - for (auto &item : dst_depend_nodes) { - compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()] - << "[style=\"dotted\"]" << endl; - } - } else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) { - for (auto &item : src_depend_nodes) { - compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] - << "[style=\"dotted\"]" << endl; - } - } else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) { - compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] - << "[style=\"dotted\"]" << endl; - } -#endif -} - -void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, - const AnfNodePtr &dest_node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list) { - if (src_node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[src_node]; - for (auto &use : uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && - (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { - auto converted_list = ConvertDependNode(use_node); - src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); - } - } - } - - if (dest_node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[dest_node]; - for (auto &use : uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && - (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { - auto converted_list = ConvertDependNode(use_node); - dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); - } - } - } -} - -bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list) { - const int CONTROL_DEPEND_INDEX = 0; - const int SRC_NODE_INDEX = 1; - const int DEST_NODE_INDEX = 2; - const int DEPEND_MODE_NORMAL_USE = 0; - const int DEPEND_MODE_ON_PARAMETER_USE = 1; - - auto node_inputs = node->inputs(); - if (node_inputs.size() <= DEST_NODE_INDEX) { - MS_LOG(WARNING) << "Control depend node input size error"; - return false; - } - auto src_node = node_inputs[SRC_NODE_INDEX]; - auto dest_node = node_inputs[DEST_NODE_INDEX]; - if ((src_node == nullptr) || (dest_node == nullptr)) { - MS_LOG(ERROR) << "Control depend node miss src or dest node"; - error_ = FAILED; - return false; - } - AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX]; - PrimitivePtr prim_ptr = GetValueNode(fn); - ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); - int depend_mode = DEPEND_MODE_NORMAL_USE; - if (mode_ptr != nullptr) { - auto mode_int = mode_ptr->cast(); - MS_EXCEPTION_IF_NULL(mode_int); - depend_mode = mode_int->value(); - MS_LOG(DEBUG) << "depend_mode = " << depend_mode; - } - if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) { - GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list); - } - - if (src_node->isa()) { - auto converted_list = ConvertDependNode(src_node); - src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); - } - - if (dest_node->isa()) { - auto converted_list = ConvertDependNode(dest_node); - dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); - } - if (src_ops_list->empty() || dst_ops_list->empty()) { - MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; - error_ = SUCCESS; - } - return true; -} - -void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { - const int SRC_NODE_INDEX = 1; - const int DEST_NODE_INDEX = 2; - if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) { - return; - } - auto node_inputs = node->inputs(); - if (node_inputs.size() <= DEST_NODE_INDEX) { - MS_LOG(WARNING) << "Control depend node input size error"; - return; - } - auto src_node = node_inputs[SRC_NODE_INDEX]; - auto dest_node = node_inputs[DEST_NODE_INDEX]; - if ((src_node == nullptr) || (dest_node == nullptr)) { - MS_LOG(ERROR) << "Control depend node miss src or dest node"; - error_ = FAILED; - return; - } - std::shared_ptr> src_ops_list = std::make_shared>(); - std::shared_ptr> dst_ops_list = std::make_shared>(); - if (!GetControlDependList(node, src_ops_list, dst_ops_list)) { - MS_LOG(ERROR) << "Get depend list failed"; - error_ = FAILED; - return; - } - std::vector control_edges; - if (src_ops_list->size() == 1 && dst_ops_list->size() > 1) { - (void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges), - [src_ops_list](const OperatorPtr &op) -> ControlEdge { - return {(*src_ops_list)[0], op}; - }); - } else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) { - (void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges), - [dst_ops_list](const OperatorPtr &op) -> ControlEdge { - return {op, (*dst_ops_list)[0]}; - }); - } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { - control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); - } else if (src_ops_list->empty() || dst_ops_list->empty()) { - MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; - } else { - MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() - << " -> dst:" << dst_ops_list->size(); - error_ = FAILED; - return; - } - control_depend_cache_[node.get()] = control_edges; - -#ifdef DRAW_GE_GRAPH - DrawControlDepend(src_node, dest_node); -#endif -} - -bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { - // ignore apply node of return - if (name == "return" || name == "Depend") { - return false; - } - - // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers - if (name == "make_tuple") { - ConvertMakeTuple(node); - return false; - } - - // As for nodes with multi outputs, convert tuple_getitem to OutHandle - if (name == "tuple_getitem") { - ConvertTupleGetItem(node); - return false; - } - - if (name == "ControlDepend") { - ConvertControlDependNode(node); - return false; - } - - return true; -} - -OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) { - std::string name = GetCNodeFuncName(node); - if (!CheckCNode(name, node)) { - return nullptr; - } - - // get corresponding OpAdapter - OpAdapterPtr adpt = FindAdapter(node, training_); - if (adpt == nullptr) { - error_ = NOT_FOUND; - return nullptr; - } - - // get operator - OperatorPtr op = nullptr; - auto it_op = op_cache_.find(node.get()); - if (it_op != op_cache_.end()) { - op = it_op->second; - } else { - op = adpt->generate(node); - } - - // set attribute for primitive - (void)adpt->setAttr(op, node); - - // add into cache - (void)op_cache_.insert(std::make_pair(node.get(), op)); - - DrawCNode(node, adpt); - - return op_cache_[node.get()]; -} - -OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) { - // convert Parameter in ANF to variable in DataFlow - auto op = FindAdapter(node, training_)->generate(node); - op_cache_[node.get()] = op; - - // build index for parameter using name - std::string name = std::static_pointer_cast(node)->name(); - params_[name] = node; - - std::ostringstream ss; - ss << "op" << node.get(); - op_draw_name_[node.get()] = ss.str(); - compute_sout_ << ss.str() << "[shape=octagon, label=\"" << name << "\"]" << endl; - return op_cache_[node.get()]; -} - -Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) { - MS_EXCEPTION_IF_NULL(node); - ValuePtr value = node->value(); - MS_EXCEPTION_IF_NULL(value); - if (!value->isa() && !value->isa()) { - return FAILED; - } - - auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); - if (vec.empty()) { - return FAILED; - } - - std::shared_ptr> tuple_items = std::make_shared>(); - for (size_t i = 0; i < vec.size(); i++) { - MS_EXCEPTION_IF_NULL(vec[i]); - if (vec[i]->isa()) { - GeTensorPtr ge_tensor = transform::TransformUtil::ConvertTensor(vec[i]->cast(), kOpFormat_NCHW); - auto const_op = std::make_shared(node->fullname_with_scope() + "/const/inputs/" + std::to_string(i)); - (void)const_op->set_attr_value(*ge_tensor); - (void)const_op->update_output_desc_y(ge_tensor->GetTensorDesc()); - tuple_items->emplace_back(OutHandler(const_op, "")); - } else { - return FAILED; - } - } - if (tuple_items->empty()) { - return FAILED; - } - - tuple_out_handle_cache_[node.get()] = tuple_items; - return SUCCESS; -} - -OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) { - // convert valuenode in ANF to Const in DataFlow - // find paramerte referenced by SymbolicKeyInstance of valuenode - std::ostringstream ss; - ss << "op" << node.get(); - op_draw_name_[node.get()] = ss.str(); - compute_sout_ << ss.str() << "[label= \"" << node->value()->ToString() << "\" shape=ellipse]" << endl; - - if (TryConvertValueNodeToMultiConst(node) == SUCCESS) { - MS_LOG(INFO) << "Convert value node to multi Constant OP success"; - return nullptr; - } - - OpAdapterPtr adpt = FindAdapter(node, training_); - if (adpt == nullptr) { - error_ = NOT_FOUND; - return nullptr; - } - auto op = adpt->generate(node); - // set const's attrs - if (adpt->setAttr(op, "value", node->value()) != 0) { - MS_LOG(WARNING) << "set attr value for const failed"; - } - -#if (defined ENABLE_GE) - auto const_op = std::static_pointer_cast(op); - if (const_op == nullptr) { - MS_LOG(ERROR) << "Get Constant operator failed"; - return nullptr; - } - auto ge_tensor = const_op->get_attr_value(); - auto ge_desc = ge_tensor.GetTensorDesc(); - (void)const_op->update_output_desc_y(ge_desc); -#endif - - op_cache_[node.get()] = op; - return op_cache_[node.get()]; -} - -void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) { - if (nullptr == adpt || nullptr == node) { - MS_LOG(ERROR) << "Failed to draw apply node as adpt or node is nullptr!"; - return; - } - std::ostringstream ss; - ss << "op" << node.get(); - op_draw_name_[node.get()] = ss.str(); - - compute_sout_ << ss.str() << "[label=<"; - compute_sout_ << "" << endl; - - auto input_map = adpt->getInputMap(); - auto dyn_input_map = adpt->getDynInputMap(); - if (input_map.size() + dyn_input_map.size() > 0) { - compute_sout_ << ""; - for (auto &it : input_map) { - compute_sout_ << ""; - } - for (auto &it : dyn_input_map) { - compute_sout_ << ""; - } - compute_sout_ << "" << endl; - } - - compute_sout_ << "" << endl; - - // print attrs' values - auto atts = adpt->GetAttrsFromDrawGraph(); - for (auto &it : atts) { - compute_sout_ << ""; - } - - adpt->clearAttrVect(); - - compute_sout_ << "
" << it.second.name << "" << it.second.name << "
\"" << node->ToString() - << ":" << GetCNodeFuncName(node) << "\"
\"" << it - << "\"
> shape=plaintext]" << endl; -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/convert.h b/mindspore/ccsrc/transform/convert.h deleted file mode 100644 index 2f6c9bb0ad..0000000000 --- a/mindspore/ccsrc/transform/convert.h +++ /dev/null @@ -1,251 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ -#define MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ - -#define DRAW_GE_GRAPH - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "transform/util.h" -#include "ir/tensor.h" -#include "transform/df_graph_manager.h" -#include "utils/config_manager.h" -#include "transform/op_declare.h" -#include "graph/operator_reg.h" -#ifdef OPEN_SOURCE -#include "ge/client/ge_api.h" -#else -#include "external/ge/ge_api.h" -#endif -#include "graph/tensor.h" -#include "ops/all_ops.h" - -namespace mindspore { -namespace transform { -class OpAdapterDesc { - public: - OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} - - OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} - - explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} - - OpAdapterDesc(const OpAdapterDesc &desc) { - this->train_ = desc.train_; - this->infer_ = desc.infer_; - } - - OpAdapterDesc(OpAdapterDesc &&desc) { - this->train_ = desc.train_; - this->infer_ = desc.infer_; - desc.train_ = nullptr; - desc.infer_ = nullptr; - } - - ~OpAdapterDesc() = default; - - OpAdapterPtr Get(bool train) const { return train ? train_ : infer_; } - - OpAdapterDesc &operator=(const OpAdapterDesc &desc) { - if (this != &desc) { - this->train_ = desc.train_; - this->infer_ = desc.infer_; - } - return *this; - } - - OpAdapterDesc &operator=(OpAdapterDesc &&desc) { - if (this != &desc) { - this->train_ = desc.train_; - this->infer_ = desc.infer_; - desc.train_ = nullptr; - desc.infer_ = nullptr; - } - return *this; - } - - private: - OpAdapterPtr train_; - OpAdapterPtr infer_; -}; - -using OpAdapterDescPtr = std::shared_ptr; -using TensorOrderMap = std::map>; - -class DfGraphConvertor { - public: - explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) - : anf_graph_(anf_graph), df_graph_(std::make_shared(anf_graph_->ToString())) { -#if (!defined ENABLE_GE) || (defined ENABLE_INFER) - training_ = anf_graph->has_flag("training"); -#else - training_ = ENABLE_TRAIN; -#endif - distribute_ = anf_graph->has_flag("broadcast_flag"); - if (anf_graph->has_flag("broadcast_flag")) { - ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION); - } else { - ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE); - } - - MS_LOG(INFO) << "Create DfGraphConvertor with training: " << training_ << ", distribute: " << distribute_; - } - - ~DfGraphConvertor() {} - - static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt) { - get_adpt_map()[name] = std::make_shared(adpt); - } - static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { - get_adpt_map()[name] = std::make_shared(train_adpt, infer_adpt); - } - - void DrawComputeGraph(const std::string &name) { - std::ofstream fout(name); - if (!fout.is_open()) { - MS_LOG(ERROR) << "Open file '" << name << "' failed!"; - return; - } - fout << compute_sout_.str(); - fout.close(); - } - void DrawInitGraph(const std::string &name) { - std::ofstream fout(name); - if (!fout.is_open()) { - MS_LOG(ERROR) << "Open file '" << name << "' failed!"; - return; - } - fout << init_sout_.str(); - fout.close(); - } - void DrawSaveCheckpointGraph(const std::string &name) { - std::ofstream fout(name); - if (!fout.is_open()) { - MS_LOG(ERROR) << "Open file '" << name << "' failed!"; - return; - } - fout << checkpoint_sout_.str(); - fout.close(); - } - - DfGraphConvertor &ConvertAllNode(); - DfGraphConvertor &BuildGraph(); - DfGraphConvertor &InitParam(const TensorOrderMap &tensors); - DfGraphConvertor &GenerateCheckpointGraph(); - DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); - void InitParamWithData(const TensorOrderMap &tensors); - void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); - void SetupBroadcast(const std::shared_ptr &broadcast, const std::vector &broadcast_desc, - const DfGraphPtr &broadcast_graph, std::vector broadcast_input); - void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); - void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input); - void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); - - DfGraphPtr GetComputeGraph(); - DfGraphPtr GetInitGraph(); - DfGraphPtr GetSaveCheckpointGraph(); - DfGraphPtr GetBroadcastGraph(); - static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); - static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); - int ErrCode() const { return static_cast(error_); } - - static std::unordered_map &get_adpt_map(); - bool is_training() const { return training_; } - void set_training(bool is_training) { training_ = is_training; } - - protected: - void InitLoopVar(std::vector *init_input); - - private: - std::ostringstream compute_sout_; - std::ostringstream init_sout_; - std::ostringstream checkpoint_sout_; - std::ostringstream restore_checkpoint_sout_; - std::unordered_map op_draw_name_; - - AnfNodePtr TraceTupleGetItem(const CNodePtr &node, unsigned int *index); - AnfNodePtr TraceMakeTuple(const CNodePtr &node, unsigned int index); - AnfNodePtr TraceDepend(const CNodePtr &node); - OutHandler TraceRealOp(AnfNodePtr node); - OutHandler GetHandler(const AnfNodePtr &node, const std::stack &index_stack, AnfNode *const draw_index); - OperatorPtr Convert(AnfNodePtr node); - OperatorPtr ConvertCNode(CNodePtr node); - std::vector ConvertDependNode(AnfNodePtr node); - AnfNodePtr GetRealOpNode(AnfNodePtr node); - std::vector GetDependNodes(const AnfNodePtr &node); - OperatorPtr ConvertParameter(AnfNodePtr node); - Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); - OperatorPtr ConvertValueNode(ValueNodePtr node); - void ConvertTupleGetItem(const CNodePtr node); - void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list); - bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list); - void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); - void ConvertControlDependNode(const CNodePtr node); - void ConvertMakeTuple(const CNodePtr node); - bool CheckCNode(const std::string &name, const CNodePtr node); - void TraceOutput(AnfNodePtr node); - void TraceOutputFromParameter(const AnfNodePtr &anf_out); - void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); - void SetNodeInput(AnfNodePtr node); - void SetOpControlInput(const AnfNodePtr node); - void UpdateOpDesc(AnfNodePtr node); - void BuildSaveCheckpointGraph(); - void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); - void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; - void AddGraphConstInput(const OperatorPtr &op); - - std::shared_ptr anf_graph_{nullptr}; - std::shared_ptr df_graph_{nullptr}; - std::shared_ptr init_graph_{nullptr}; - std::shared_ptr save_ckp_graph_{nullptr}; - std::shared_ptr restore_ckp_graph_{nullptr}; - std::shared_ptr broadcast_graph_{nullptr}; - std::unordered_map op_cache_; - std::unordered_map> control_depend_cache_; - /* record "tuple_getitem"<->"out_handler" mapping */ - std::unordered_map out_handle_cache_; - /* record "make_tuple"<->"out_handler vector" mapping */ - std::unordered_map>> tuple_out_handle_cache_; - std::unordered_map params_; - std::unordered_map vars_; - std::vector> graph_outputs_; - std::vector graph_const_inputs_; - std::vector init_ops_; - std::vector broadcast_ops_; - OperatorPtr dataset_iter_getnext_; - Status error_ = SUCCESS; - bool training_ = false; - bool distribute_ = false; -}; -} // namespace transform -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ diff --git a/mindspore/ccsrc/transform/df_graph_manager.cc b/mindspore/ccsrc/transform/df_graph_manager.cc deleted file mode 100644 index f62c386587..0000000000 --- a/mindspore/ccsrc/transform/df_graph_manager.cc +++ /dev/null @@ -1,214 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/df_graph_manager.h" - -#include -#include -#include -#include - -#include "securec/include/securec.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/pipeline.h" -#include "utils/config_manager.h" -#ifndef NO_DLIB -#include "tdt/tsd_client.h" -#endif - -namespace mindspore { -namespace transform { -DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, - const OptionMap &options) - : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {} - -DfGraphManager::DfGraphManager() { - graph_id_ = 0; - graph_runner_ptr_ = nullptr; - sess_ptr_ = nullptr; -} - -DfGraphManager::~DfGraphManager() { - // in python fisrt destroy after atexit but in c++ destoy before atexit - DeleteGraphRunner(); - DeleteGeSession(); - ClearGraph(); - parse::python_adapter::set_python_env_flag(false); -} - -DfGraphManager &DfGraphManager::GetInstance() { - static DfGraphManager instance; - return instance; -} - -int DfGraphManager::GenerateId() { - graph_id_++; - if (graph_id_ <= 0) { - graph_id_ = 1; - } - MS_LOG(INFO) << "Generate graph Id : " << graph_id_; - return graph_id_; -} - -Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options) { - std::lock_guard lg(lock_); - if (name.empty()) { - MS_LOG(ERROR) << "The graph name is null, add graph failed"; - return Status::INVALID_ARGUMENT; - } - - if (graph_ptr == nullptr) { - MS_LOG(WARNING) << "The new graph {" << name << "}'s pointer is null, add graph failed"; - return Status::INVALID_ARGUMENT; - } - - int id = GenerateId(); - DfGraphWrapperPtr wrap_ptr = std::make_shared(name, id, graph_ptr, options); - auto ret = graphs_.emplace(name, wrap_ptr); - if (ret.second == false) { - MS_LOG(WARNING) << "The graph name:{ " << name << " }is already exists! The old graph will be overwritten!!"; - ret.first->second = wrap_ptr; - } - MS_LOG(INFO) << "Add graph " << name << " to GraphManager success!"; - return Status::SUCCESS; -} - -std::vector DfGraphManager::GetAllGraphs() { - std::lock_guard lg(lock_); - std::vector ret; - std::stringstream ss; - ss << "{ "; - for (auto it = graphs_.begin(); it != graphs_.end(); ++it) { - ss << it->first << ", "; - ret.emplace_back(it->second); - } - ss << "}"; - MS_LOG(INFO) << "Return graphs: " << ss.str(); - return ret; -} -std::set DfGraphManager::GetSavedGraphs() { return saved_graphs_; } - -void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); } - -DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) { - std::lock_guard lg(lock_); - if (name.empty()) { - MS_LOG(ERROR) << "The graph name is null"; - return nullptr; - } - - auto it = graphs_.find(name); - if (it == graphs_.end()) { - MS_LOG(INFO) << "Can't found graph name: " << name; - return nullptr; - } - MS_LOG(INFO) << "Return graph: " << name; - return it->second; -} - -void DfGraphManager::ClearGraph() noexcept { - std::lock_guard lg(lock_); - graphs_.clear(); - anf_graphs_.clear(); - MS_LOG(INFO) << "Remove all graphs in GraphManager"; -} - -void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) { - DfGraphWrapperPtr df_graph = GetGraphByName(name); - if (df_graph == nullptr) { - MS_LOG(ERROR) << "Can't found graph name: " << name; - return; - } - std::lock_guard lg(lock_); - anf_graphs_[df_graph->id_] = anf_graph_ptr; -} - -AnfGraphPtr DfGraphManager::GetAnfGraph(uint32_t graph_id) { - std::lock_guard lg(lock_); - auto iter = anf_graphs_.find(graph_id); - if (iter == anf_graphs_.end()) { - MS_LOG(ERROR) << "Can't found anf graph, graph_id = " << graph_id; - return nullptr; - } - - return iter->second; -} - -void DfGraphManager::EraseAnfGraph() { - std::lock_guard lg(lock_); - anf_graphs_.clear(); -} - -void DfGraphManager::SetGeSession(const std::shared_ptr &sess_ptr) { - std::lock_guard lg(lock_); - if (sess_ptr == nullptr) { - MS_LOG(WARNING) << "You are adding a empty Ge Session"; - } - - if (sess_ptr_ == nullptr) { - MS_LOG(INFO) << "Add a new Ge Session success"; - } else { - MS_LOG(INFO) << "Add a new Ge Session success, the old Ge Session will be overwritten!!"; - } - sess_ptr_ = sess_ptr; -} - -std::shared_ptr DfGraphManager::GetGeSession() { - std::lock_guard lg(lock_); - return sess_ptr_; -} - -void DfGraphManager::DeleteGeSession() noexcept { - std::lock_guard lg(lock_); - if (sess_ptr_ == nullptr) { - MS_LOG(INFO) << "Ge Session is not exist"; - } else { - sess_ptr_ = nullptr; - saved_graphs_.clear(); - MS_LOG(INFO) << "Delete Ge Session success"; - } -} - -void DfGraphManager::SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept { - std::lock_guard lg(lock_); - if (graph_runner_ptr == nullptr) { - MS_LOG(WARNING) << "You are adding a empty GraphRunner"; - } - - if (graph_runner_ptr_ == nullptr) { - MS_LOG(INFO) << "Add a new GraphRunner success"; - } else { - MS_LOG(INFO) << "Add a new GraphRunner success, the old GraphRunner will be overwritten!!"; - } - graph_runner_ptr_ = graph_runner_ptr; -} - -std::shared_ptr DfGraphManager::GetGraphRunner() { - std::lock_guard lg(lock_); - return graph_runner_ptr_; -} - -void DfGraphManager::DeleteGraphRunner() noexcept { - std::lock_guard lg(lock_); - if (graph_runner_ptr_ == nullptr) { - MS_LOG(INFO) << "GraphRunner is not exist"; - } else { - graph_runner_ptr_ = nullptr; - MS_LOG(INFO) << "Delete GraphRunner success"; - } -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/df_graph_manager.h b/mindspore/ccsrc/transform/df_graph_manager.h deleted file mode 100644 index 2ca43d1f07..0000000000 --- a/mindspore/ccsrc/transform/df_graph_manager.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_DF_GRAPH_MANAGER_H_ -#define TRANSFORM_DF_GRAPH_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "transform/types.h" -#include "ir/anf.h" - -namespace mindspore { -const char BROADCAST_GRAPH_NAME[] = "broadcast_subgraph"; - -namespace transform { -class GraphRunner; -using OptionMap = std::map; - -struct DfGraphWrapper { - public: - DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options); - ~DfGraphWrapper() {} - - std::string name_; - int id_; - DfGraphPtr graph_ptr_; - OptionMap options_ = {}; -}; - -using DfGraphWrapperPtr = std::shared_ptr; - -class DfGraphManager { - public: - ~DfGraphManager(); - void ClearGraph() noexcept; - - static DfGraphManager &GetInstance(); - Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options = {}); - std::vector GetAllGraphs(); - std::set GetSavedGraphs(); - void AddSavedGraphs(const std::string &id); - DfGraphWrapperPtr GetGraphByName(const std::string &name); - DfGraphManager(const DfGraphManager &) = delete; - void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr); - AnfGraphPtr GetAnfGraph(uint32_t graph_id); - std::shared_ptr GetGraphRunner(); - void SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept; - void DeleteGraphRunner() noexcept; - void SetGeSession(const std::shared_ptr &sess_ptr); - std::shared_ptr GetGeSession(); - void DeleteGeSession() noexcept; - void EraseAnfGraph(); - - private: - DfGraphManager(); - int GenerateId(); - - std::mutex lock_; - std::map graphs_; - std::set saved_graphs_; - int graph_id_; - std::map anf_graphs_; - std::shared_ptr graph_runner_ptr_; - std::shared_ptr sess_ptr_; -}; -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_DF_GRAPH_MANAGER_H_ diff --git a/mindspore/ccsrc/transform/graph_builder.cc b/mindspore/ccsrc/transform/graph_builder.cc deleted file mode 100644 index 785c5c7f3a..0000000000 --- a/mindspore/ccsrc/transform/graph_builder.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/graph_builder.h" - -#include -#include - -namespace mindspore { -namespace transform { -DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam ¶m) { - MS_LOG(INFO) << "BuildMDDatasetGraph."; - - // InitData - auto d = ge::op::InitData("init_data_tmp").set_attr_channel_name(param.queue_name()); - - // set graph inputs & outputs - std::vector inputs{d}; - std::vector outputs{d}; - DfGraphPtr dataset_graph = std::make_shared("dataset"); - (void)dataset_graph->SetInputs(inputs); - (void)dataset_graph->SetOutputs(outputs); - - return dataset_graph; -} - -Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase) { - Status ret; - std::string graph_name = phase; - - MS_LOG(INFO) << "BuildDatasetGraph begin. phase is " << phase; - MS_LOG(INFO) << "param is " << param.ToString() << "."; - - DfGraphPtr dataset_graph = BuildMDDatasetGraph(param); - ret = DfGraphManager::GetInstance().AddGraph(graph_name, dataset_graph); - if (ret != Status::SUCCESS) { - MS_LOG(ERROR) << "BuildDatasetGraph failed."; - } else { - MS_LOG(INFO) << "BuildDatasetGraph end."; - } - return ret; -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_builder.h b/mindspore/ccsrc/transform/graph_builder.h deleted file mode 100644 index 3d959f5a85..0000000000 --- a/mindspore/ccsrc/transform/graph_builder.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_GRAPH_BUILDER_H_ -#define TRANSFORM_GRAPH_BUILDER_H_ - -#include -#include -#include -#include -#include -#include "transform/types.h" -#include "transform/convert.h" - -namespace mindspore { -namespace transform { -Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase = "dataset"); -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_GRAPH_BUILDER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt b/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt new file mode 100644 index 0000000000..3f062609d5 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt @@ -0,0 +1,9 @@ +if (ENABLE_GE OR ENABLE_D) + file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") + set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT) + add_library(_mindspore_transform_graph_ir_obj OBJECT ${_TRANSFORM_SRC_LIST}) + + if (NOT ENABLE_GE) + target_compile_definitions(_mindspore_transform_graph_ir_obj PRIVATE NO_GE_CLIENT) + endif() +endif () diff --git a/mindspore/ccsrc/transform/all_ops.h b/mindspore/ccsrc/transform/graph_ir/all_ops.h similarity index 100% rename from mindspore/ccsrc/transform/all_ops.h rename to mindspore/ccsrc/transform/graph_ir/all_ops.h diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc new file mode 100644 index 0000000000..7419dd2cc9 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -0,0 +1,2073 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/convert.h" + +#include +#include +#include +#include "utils/utils.h" + +#include "frontend/operator/ops.h" +#include "utils/log_adapter.h" +#include "utils/graph_utils.h" +#include "utils/symbolic.h" +#include "utils/config_manager.h" +#include "utils/convert_utils.h" +#include "./common.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace transform { +using std::endl; + +#define ADPT_DESC_ONE(T) std::make_shared(std::make_shared>()) +#define ADPT_DESC_TWO(T, I) \ + std::make_shared(std::make_shared>(), std::make_shared>()) +#define GET_MACRO(_1, _2, DESC, ...) DESC +#define ADPT_DESC(...) GET_MACRO(__VA_ARGS__, ADPT_DESC_TWO, ADPT_DESC_ONE, ...)(__VA_ARGS__) + +using ge::Operator; +using mindspore::kAnyValue; +using std::make_shared; +using std::shared_ptr; +using std::string; +using std::vector; + +const char kNameCustomOp[] = "CustomOp"; +const char kNameConst[] = "Const"; +const char kNameParam[] = "parameter"; +const char kNameRandomUniform[] = "RandomUniform"; +const char kNameSimpleMean[] = "SimpleMean"; +const char kNameSimpleMeanGrad[] = "SimpleMeanGrad"; +const char kNameAllReduce[] = "AllReduce"; +const char kNameBroadcast[] = "Broadcast"; +const char kNameAllgather[] = "AllGather"; +const char kNameReduceScatter[] = "ReduceScatter"; +const char kNameReduceSum[] = "ReduceSum"; +const char kNameIsFinite[] = "isFinite"; +const char kNameReciprocal[] = "Reciprocal"; +const char kNameRsqrt[] = "Rsqrt"; +const char kNameRsqrtGrad[] = "RsqrtGrad"; +const char kNameSqrt[] = "Sqrt"; +const char kNameSquare[] = "Square"; +const char kNameSquaredDifference[] = "SquaredDifference"; +const char kNamePow[] = "Pow"; +const char kNameBatchMatMul[] = "BatchMatMul"; +const char kNameStridedSlice[] = "StridedSlice"; +const char kNameStridedSliceGrad[] = "StridedSliceGrad"; +const char kNameExpandDims[] = "ExpandDims"; +const char kNameLog[] = "Log"; +const char kNameLogicalAnd[] = "LogicalAnd"; +const char kNameLogicalNot[] = "LogicalNot"; +const char kNameLogicalOr[] = "LogicalOr"; +const char kNameExp[] = "Exp"; +const char kNameLessEqual[] = "LessEqual"; +const char kNameGreaterEqual[] = "GreaterEqual"; +const char kNameEqual[] = "Equal"; +const char kNameNotEqual[] = "NotEqual"; +const char kNameFlattenGrad[] = "FlattenGrad"; +const char kNameConvolution[] = "Convolution"; +const char kNameBiasAdd[] = "BiasAdd"; +const char kNameMaxPoolGrad[] = "MaxPoolGrad"; +const char kNameAvgPoolGrad[] = "AvgPoolGrad"; +const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; +const char kNameApplyMomentum[] = "ApplyMomentum"; +const char kNameDropoutDoMask[] = "DropoutDoMask"; +const char kNameResizeBilinear[] = "ResizeBilinear"; +const char kNameResizeBilinearGrad[] = "ResizeBilinearGrad"; +const char kNameZerosLike[] = "ZerosLike"; +const char kNameOnesLike[] = "OnesLike"; +const char kNameTruncatedNormal[] = "TruncatedNormal"; +const char kNameSpaceToBatchNd[] = "SpaceToBatchNd"; +const char kNameConfusionMatrix[] = "ConfusionMatrix"; +const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor"; +const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad"; +const char kNameApplyAdam[] = "Adam"; +const char kNameExtractImagePatches[] = "ExtractImagePatches"; +const char kNameReLU6[] = "ReLU6"; +const char kNameReLU6Grad[] = "ReLU6Grad"; +const char kNameElu[] = "Elu"; +const char kNameEluGrad[] = "EluGrad"; +const char kNameTensorScatterUpdate[] = "TensorScatterUpdate"; +const char kNameScatterUpdate[] = "ScatterUpdate"; +const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; +const char kNameScatterMax[] = "ScatterMax"; +const char kNameNMSWithMask[] = "NMSWithMask"; +const char kNameCheckValid[] = "CheckValid"; +const char kNameSmoothL1Loss[] = "SmoothL1Loss"; +const char kNameSmoothL1LossGrad[] = "SmoothL1LossGrad"; +const char kNameSGD[] = "SGD"; +const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits"; +const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad"; +const char kNameScatterNdD[] = "ScatterNd"; +const char kNamePadD[] = "Pad"; +const char kNameMirrorPad[] = "MirrorPad"; +const char kNameMirrorPadGrad[] = "MirrorPadGrad"; +const char kNameGatherNd[] = "GatherNd"; +const char kNameArgmax[] = "Argmax"; +const char kNameArgmin[] = "Argmin"; +const char kNameArgMaxWithValue[] = "ArgMaxWithValue"; +const char kNameArgMinWithValue[] = "ArgMinWithValue"; +const char kNameReduceProd[] = "ReduceProd"; +const char kNameCumProd[] = "CumProd"; +const char kNameDiagpart[] = "Diagpart"; +const char kNameSplitD[] = "Split"; +const char kNameBatchToSpaceNd[] = "BatchToSpaceNd"; +const char kNameFloor[] = "Floor"; +const char kNameNPUGetFloatStatus[] = "NPUGetFloatStatus"; +const char kNameAssign[] = "Assign"; +const char kNameAssignAdd[] = "AssignAdd"; +const char kNameAssignSub[] = "AssignSub"; +const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus"; +const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus"; +const char kNameReshape[] = "Reshape"; +const char kNameTransShape[] = "TransShape"; +const char kNameRealDiv[] = "RealDiv"; +const char kNameTile[] = "Tile"; +const char kNameCos[] = "Cos"; +const char kNameACos[] = "ACos"; +const char kNameACosGrad[] = "ACosGrad"; +const char kNameFloorDiv[] = "FloorDiv"; +const char kNameSin[] = "Sin"; +const char kNamePrelu[] = "PReLU"; +const char kNamePreluGrad[] = "PReLUGrad"; +const char kNameSigmoid[] = "Sigmoid"; +const char kNameSigmoidGrad[] = "SigmoidGrad"; +const char kNameL2Normalize[] = "L2Normalize"; +const char kNameL2NormalizeGrad[] = "L2NormalizeGrad"; +const char kNameSoftmax[] = "Softmax"; +const char kNameIOU[] = "IOU"; +const char kNameBoundingBoxDecode[] = "BoundingBoxDecode"; +const char kNameBoundingBoxEncode[] = "BoundingBoxEncode"; +const char kNameSlice[] = "Slice"; +const char kNameAddN[] = "AddN"; +const char kNameLess[] = "Less"; +const char kNameGreater[] = "Greater"; +const char kNamePack[] = "Pack"; +const char kNameUnpack[] = "Unpack"; +const char kNameMerge[] = "Merge"; +const char kNameGeSwitch[] = "GeSwitch"; + +const char kNameHuberLoss[] = "HuberLoss"; +const char kNameCumSum[] = "CumSum"; +const char kNameHuberLossGrad[] = "HuberLossGrad"; +const char kNameSparseSoftmaxCrossEntropy[] = "SparseSoftmaxCrossEntropy"; +const char kNameSparseSoftmaxCrossEntropyGrad[] = "SparseSoftmaxCrossEntropyGrad"; +const char kNameTopK[] = "TopK"; +const char kNameSoftmaxGrad[] = "SoftmaxGrad"; +const char kNameMaxPool[] = "MaxPool"; +const char kNameAvgPool[] = "AvgPool"; +const char kNameMaxPoolWithArgmax[] = "MaxPoolWithArgmax"; +const char kNameBatchNorm[] = "BatchNorm"; +const char kNameBatchNormGrad[] = "BatchNormGrad"; +const char kNameROIAlign[] = "ROIAlign"; +const char kNameROIAlignGrad[] = "ROIAlignGrad"; +const char kNameRandomChoiceWithMask[] = "RandomChoiceWithMask"; +const char kNameAbs[] = "Abs"; +const char kNameAbsGrad[] = "AbsGrad"; +const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy"; +const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad"; +const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad"; +const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD"; +const char kNameApplyProximalAdagrad[] = "ApplyProximalAdagrad"; +const char kNameAcosh[] = "Acosh"; +const char kNameAcoshGrad[] = "AcoshGrad"; +const char kNameFloorMod[] = "FloorMod"; +const char kNameSpaceToDepth[] = "SpaceToDepth"; +const char kNameDepthToSpace[] = "DepthToSpace"; +const char kNameSign[] = "Sign"; +const char kNameLARSUpdate[] = "LARSUpdate"; +const char kNameRound[] = "Round"; +const char kNamePrint[] = "Print"; +const char kNameApplyFtrl[] = "ApplyFtrl"; +const char kNameDiag[] = "Diag"; +const char kNameDiagPart[] = "DiagPart"; +const char kNameSpaceToBatch[] = "SpaceToBatch"; +const char kNameBatchToSpace[] = "BatchToSpace"; +const char kNameAtan2[] = "Atan2"; +const char kNameApplyRMSProp[] = "ApplyRMSProp"; +const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; +const char kNameL2Loss[] = "L2Loss"; +const char kNameCTCLoss[] = "CTCLoss"; +const char kNameRange[] = "Range"; +const char kNameSquareSumAll[] = "SquareSumAll"; +const char kNameAscendQuant[] = "AscendQuant"; +const char kNameAscendDequant[] = "AscendDequant"; +const char kNameCase[] = "Case"; + +// -----------------OpAdapter initialization-------------- +std::unordered_map &DfGraphConvertor::get_adpt_map() { + static std::unordered_map adpt_map = { + {string(kNameCustomOp), ADPT_DESC(Operator)}, + {string(kNameIOU), ADPT_DESC(Iou)}, + {string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)}, + {string(kNameSlice), ADPT_DESC(SliceD)}, + {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentumD)}, + {string(kNameMaxPool), ADPT_DESC(MaxPool)}, + {string(kNameAvgPool), ADPT_DESC(AvgPool)}, + {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)}, + {string(kNameTopK), ADPT_DESC(TopK)}, + {string(kNamePack), ADPT_DESC(Pack)}, + {string(kNameUnpack), ADPT_DESC(Unpack)}, + {string(kNameSplitD), ADPT_DESC(SplitD)}, + {string(kNameAllReduce), ADPT_DESC(HcomAllReduce)}, + {string(kNameBroadcast), ADPT_DESC(HcomBroadcast)}, + {string(kNameAllgather), ADPT_DESC(HcomAllGather)}, + {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, + {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, + {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, + {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, + {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, + {prim::kPrimAssign->name(), ADPT_DESC(Assign)}, + {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)}, + {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)}, + {prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)}, + {prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)}, + {prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)}, + {prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilterD)}, + {prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)}, + {prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)}, + {prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)}, + {string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, + {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, + {string(kNameReshape), ADPT_DESC(Reshape)}, + {string(kNameTransShape), ADPT_DESC(TransShape)}, + {string(kNameFlattenGrad), ADPT_DESC(Reshape)}, + {prim::kPrimFlatten->name(), ADPT_DESC(Flatten)}, + {string(kNameAddN), ADPT_DESC(AddN)}, + {string(kNameLess), ADPT_DESC(Less)}, + {string(kNameSqrt), ADPT_DESC(Sqrt)}, + {string(kNameRsqrt), ADPT_DESC(Rsqrt)}, + {string(kNameSquare), ADPT_DESC(Square)}, + {prim::kPrimTanh->name(), ADPT_DESC(Tanh)}, + {prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)}, + {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)}, + {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)}, + {string(kNameApplyAdam), ADPT_DESC(ApplyAdam)}, + {string(kNameReLU6), ADPT_DESC(Relu6)}, + {string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)}, + {string(kNameElu), ADPT_DESC(Elu)}, + {string(kNameEluGrad), ADPT_DESC(EluGrad)}, + {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearV2Grad)}, + {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, + {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, + {string(kNameOnesLike), ADPT_DESC(OnesLike)}, + {string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)}, + {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, + {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, + {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, + {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, + {string(kNameCheckValid), ADPT_DESC(CheckValid)}, + {string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)}, + {string(kNameSmoothL1LossGrad), ADPT_DESC(SmoothL1LossGrad)}, + {string(kNameSigmoidCrossEntropyWithLogits), ADPT_DESC(SigmoidCrossEntropyWithLogits)}, + {string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)}, + {string(kNameScatterNdD), ADPT_DESC(ScatterNdD)}, + {string(kNamePadD), ADPT_DESC(PadD)}, + {string(kNameMirrorPad), ADPT_DESC(MirrorPad)}, + {string(kNameMirrorPadGrad), ADPT_DESC(MirrorPadGrad)}, + {string(kNameGatherNd), ADPT_DESC(GatherNd)}, + {string(kNameArgmax), ADPT_DESC(ArgMaxD)}, + {string(kNameArgmin), ADPT_DESC(ArgMinD)}, + {string(kNameArgMaxWithValue), ADPT_DESC(ArgMaxWithValue)}, + {string(kNameArgMinWithValue), ADPT_DESC(ArgMinWithValue)}, + {prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSumD)}, + {prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMeanD)}, + {prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAllD)}, + {prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMinD)}, + {prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD)}, + {string(kNameLARSUpdate), ADPT_DESC(LarsV2Update)}, + {string(kNameReduceProd), ADPT_DESC(ReduceProdD)}, + {string(kNameCumProd), ADPT_DESC(CumprodD)}, + {string(kNameMerge), ADPT_DESC(Merge)}, + {string(kNameGeSwitch), ADPT_DESC(Switch)}, + {string(kNameCumSum), ADPT_DESC(CumsumD)}, + + {prim::kPrimMul->name(), ADPT_DESC(Mul)}, + {string(kNameTile), ADPT_DESC(TileD)}, + {prim::kPrimOneHot->name(), ADPT_DESC(OneHot)}, + + {prim::kPrimGatherV2->name(), ADPT_DESC(GatherV2D)}, + {string(kNameCos), ADPT_DESC(Cos)}, + {string(kNameACos), ADPT_DESC(Acos)}, + {string(kNameACosGrad), ADPT_DESC(AcosGrad)}, + {string(kNameFloor), ADPT_DESC(Floor)}, + {string(kNameFloorDiv), ADPT_DESC(FloorDiv)}, + {string(kNameSin), ADPT_DESC(Sin)}, + {string(kNameExp), ADPT_DESC(Exp)}, + {string(kNameBoundingBoxEncode), ADPT_DESC(BoundingBoxEncode)}, + {string(kNameBoundingBoxDecode), ADPT_DESC(BoundingBoxDecode)}, + + {prim::kPrimCast->name(), ADPT_DESC(Cast)}, + {string(kNameRealDiv), ADPT_DESC(RealDiv)}, + {prim::kPrimNeg->name(), ADPT_DESC(Neg)}, + {prim::kPrimTranspose->name(), ADPT_DESC(TransposeD)}, + {prim::kPrimSub->name(), ADPT_DESC(Sub)}, + {string(kNameReciprocal), ADPT_DESC(Reciprocal)}, + {prim::kPrimDropoutGenMask->name(), ADPT_DESC(DropOutGenMask)}, + {string(kNameAssignAdd), ADPT_DESC(AssignAdd)}, + {string(kNameAssignSub), ADPT_DESC(AssignSub)}, + {prim::kPrimConcat->name(), ADPT_DESC(ConcatD)}, + {string(kNamePow), ADPT_DESC(Pow)}, + {string(kNameExp), ADPT_DESC(Exp)}, + {string(kNameEqual), ADPT_DESC(Equal)}, + {string(kNameNotEqual), ADPT_DESC(NotEqual)}, + {string(kNameLog), ADPT_DESC(Log)}, + {string(kNameLogicalAnd), ADPT_DESC(LogicalAnd)}, + {string(kNameLogicalNot), ADPT_DESC(LogicalNot)}, + {string(kNameLogicalOr), ADPT_DESC(LogicalOr)}, + {string(kNameGreater), ADPT_DESC(Greater)}, + {prim::kPrimMaximum->name(), ADPT_DESC(Maximum)}, + {prim::kPrimRelu->name(), ADPT_DESC(Relu)}, + {string(kNamePrelu), ADPT_DESC(PRelu)}, + {string(kNamePreluGrad), ADPT_DESC(PReluGrad)}, + {string(kNameSigmoid), ADPT_DESC(Sigmoid)}, + {string(kNameSigmoidGrad), ADPT_DESC(SigmoidGrad)}, + {string(kNameSGD), ADPT_DESC(SGD)}, + {prim::kPrimLogSoftmaxGrad->name(), ADPT_DESC(LogSoftmaxGrad)}, + {prim::kPrimMaximumGrad->name(), ADPT_DESC(MaximumGrad)}, + {prim::kPrimMinimumGrad->name(), ADPT_DESC(MinimumGrad)}, + {string(kNameL2Normalize), ADPT_DESC(L2Normalize)}, + {string(kNameL2NormalizeGrad), ADPT_DESC(L2NormalizeGrad)}, + + {prim::kPrimMinimum->name(), ADPT_DESC(Minimum)}, + {prim::kPrimSelect->name(), ADPT_DESC(Select)}, + {string(kNameLessEqual), ADPT_DESC(LessEqual)}, + {prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmaxV2)}, + {string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)}, + {string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)}, + {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, + {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, + {string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, + {prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMin)}, + {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, + {string(kNameExpandDims), ADPT_DESC(ExpandDims)}, + {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, + {prim::kPrimLayerNorm->name(), ADPT_DESC(LayerNorm)}, + {prim::kPrimLayerNormGrad->name(), ADPT_DESC(LayerNormGrad)}, + {string(kNameBatchMatMul), ADPT_DESC(BatchMatMul)}, + {string(kNameDropoutDoMask), ADPT_DESC(DropOutDoMask)}, + + {string(kNameNPUGetFloatStatus), ADPT_DESC(NPUGetFloatStatus)}, + {string(kNameNPUAllocFloatStatus), ADPT_DESC(NPUAllocFloatStatus)}, + {string(kNameNPUClearFloatStatus), ADPT_DESC(NPUClearFloatStatus)}, + + {string(kNameRandomChoiceWithMask), ADPT_DESC(RandomChoiceWithMask)}, + {prim::kPrimSoftmaxCrossEntropyWithLogits->name(), ADPT_DESC(SoftmaxCrossEntropyWithLogits)}, + + {prim::kPrimScalarSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimImageSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimTensorSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimHistogramSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimDebug->name(), ADPT_DESC(Summary)}, + {prim::kPrimTensorAdd->name(), + std::make_shared(std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})), + std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})))}, + {string(kNameBiasAdd), ADPT_DESC(BiasAdd)}, + {prim::kPrimRelu->name(), ADPT_DESC(Relu)}, + + {prim::kPrimMatMul->name(), ADPT_DESC(MatMulV2)}, + + {string(kNameConst), ADPT_DESC(Constant, Const)}, + {string(kNameSoftmax), ADPT_DESC(SoftmaxV2)}, + {string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)}, + {string(kNameParam), ADPT_DESC(Data)}, + {string(kNameROIAlign), ADPT_DESC(ROIAlign)}, + {string(kNameROIAlignGrad), ADPT_DESC(ROIAlignGrad)}, + {string(kNameAbs), ADPT_DESC(Abs)}, + {string(kNameAbsGrad), ADPT_DESC(AbsGrad)}, + {string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)}, + {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, + {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, + {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)}, + {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)}, + {string(kNameAcosh), ADPT_DESC(Acosh)}, + {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, + {string(kNameFloorMod), ADPT_DESC(FloorMod)}, + {string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)}, + {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)}, + {string(kNameSign), ADPT_DESC(Sign)}, + {string(kNameRound), ADPT_DESC(Round)}, + {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)}, + {string(kNameDiag), ADPT_DESC(Diag)}, + {string(kNameDiagPart), ADPT_DESC(DiagPart)}, + {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)}, + {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, + {string(kNameAtan2), ADPT_DESC(Atan2)}, + {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, + {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}, + {string(kNameL2Loss), ADPT_DESC(L2Loss)}, + {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, + {string(kNameRange), ADPT_DESC(RangeD)}, + {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}, + {string(kNameAscendQuant), ADPT_DESC(AscendQuant)}, + {string(kNameAscendDequant), ADPT_DESC(AscendDequant)}, + {string(kNameCase), ADPT_DESC(Case)}}; +#ifdef ENABLE_GE + adpt_map[string(kNamePrint)] = ADPT_DESC(Print); + adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); +#endif + return adpt_map; +} + +// ---------------implement of DfGraphConvertor------------- +PrimType GetCNodeFuncType(const CNodePtr cnode) { + if (cnode->inputs().empty()) { + return kPrimTypeUnknown; + } + + AnfNodePtr valuenode = cnode->input(0); + if (IsValueNode(valuenode)) { + // check whether the valuenode is primitive + return GetValueNode(valuenode)->prim_type(); + } + return kPrimTypeUnknown; +} + +bool IsCaseNode(const CNodePtr node) { + if (!node->inputs().empty() && node->input(0)->isa() && + GetCNodeFuncName(node->input(0)->cast()) == "switch_layer") { + return true; + } + return false; +} + +std::string GetCNodeTargetFuncName(const CNodePtr cnode) { + if (IsCaseNode(cnode)) { + return string(kNameCase); + } + auto name = GetCNodeFuncName(cnode); + if (name == "switch_layer") { + name = ""; + } + return name; +} + +OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) { + if (node->isa()) { + auto cnode = node->cast(); + + std::string name = kNameCustomOp; + if (!IsCustomCNode(cnode)) { + name = GetCNodeTargetFuncName(cnode); + } + + auto it_adpt = get_adpt_map().find(name); + if (it_adpt != get_adpt_map().end()) { + return it_adpt->second->Get(train); + } + MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name; + } + + if (node->isa()) { + return get_adpt_map()[kNameConst]->Get(train); + } + if (node->isa()) { + return get_adpt_map()[kNameParam]->Get(train); + } + return OpAdapterPtr(nullptr); +} + +void DfGraphConvertor::InitLoopVar(std::vector *init_input) { + if (this->training_) { + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + auto var_iter_num = std::make_shared("npu_runconfig/iterations_per_loop"); + auto var_loop_cond = std::make_shared("npu_runconfig/loop_cond"); + auto var_one = std::make_shared("npu_runconfig/one"); + auto var_zero = std::make_shared("npu_runconfig/zero"); + (void)var_iter_num->update_output_desc_y(desc); + (void)var_loop_cond->update_output_desc_y(desc); + (void)var_one->update_output_desc_y(desc); + (void)var_zero->update_output_desc_y(desc); + vars_["npu_runconfig/iterations_per_loop"] = var_iter_num; + vars_["npu_runconfig/loop_cond"] = var_loop_cond; + vars_["npu_runconfig/one"] = var_one; + vars_["npu_runconfig/zero"] = var_zero; + + int64_t value = 0; + auto const_iter_num = std::make_shared("const/npu_runconfig/iterations_per_loop"); + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + value = ConfigManager::GetInstance().iter_num(); + } else { + MS_LOG(INFO) << "Run with normal(non-sink) mode, the iterator number will always be 1"; + value = 1; + ConfigManager::GetInstance().set_iter_num(value); + } + value -= 1; // iteration start from 0, the max iteration number for n loop should be n-1 + (void)const_iter_num->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); + + auto const_loop_cond = std::make_shared("const/npu_runconfig/loop_cond"); + value = 0; + (void)const_loop_cond->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); + + auto const_one = std::make_shared("const/npu_runconfig/one"); + value = 1; + (void)const_one->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); + + auto const_zero = std::make_shared("const/npu_runconfig/zero"); + value = 0; + (void)const_zero->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); + + (void)const_iter_num->update_output_desc_y(desc); + (void)const_loop_cond->update_output_desc_y(desc); + (void)const_one->update_output_desc_y(desc); + (void)const_zero->update_output_desc_y(desc); + + auto assign_iter_num = std::make_shared("assign/npu_runconfig/iterations_per_loop"); + (void)assign_iter_num->set_input_ref(*var_iter_num).set_input_value(*const_iter_num); + auto assign_loop_cond = std::make_shared("assign/npu_runconfig/loop_cond"); + (void)assign_loop_cond->set_input_ref(*var_loop_cond).set_input_value(*const_loop_cond); + auto assign_one = std::make_shared("assign/npu_runconfig/one"); + (void)assign_one->set_input_ref(*var_one).set_input_value(*const_one); + auto assign_zero = std::make_shared("assign/npu_runconfig/zero"); + (void)assign_zero->set_input_ref(*var_zero).set_input_value(*const_zero); + + init_input->push_back(*var_iter_num); + init_input->push_back(*var_loop_cond); + init_input->push_back(*var_one); + init_input->push_back(*var_zero); + init_ops_.push_back(var_iter_num); + init_ops_.push_back(var_loop_cond); + init_ops_.push_back(var_one); + init_ops_.push_back(var_zero); + init_ops_.push_back(const_iter_num); + init_ops_.push_back(const_loop_cond); + init_ops_.push_back(const_one); + init_ops_.push_back(const_zero); + init_ops_.push_back(assign_iter_num); + init_ops_.push_back(assign_loop_cond); + init_ops_.push_back(assign_one); + init_ops_.push_back(assign_zero); + } +} + +OpAdapterPtr DfGraphConvertor::FindAdapter(const std::string &name, bool train) { + auto it = get_adpt_map().find(name); + if (it != get_adpt_map().end()) { + return it->second->Get(train); + } + MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name; +} + +void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it) { + // draw init subgraph + init_sout_ << "op_assign" << it.get() << "[label=<"; + init_sout_ << "" << endl; + init_sout_ << ""; + init_sout_ << ""; + init_sout_ << ""; + init_sout_ << "" << endl; + init_sout_ << "" << endl; + init_sout_ << "
resourcevalue
" + << "\"assign_" << name << "\"
> shape=plaintext]" << endl; + init_sout_ << "param" << it.get() << "[shape=octagon, label=\"" << name << "\"]" << endl; + init_sout_ << "const" << it.get() << "[label= \"" << name << "_const" + << "\" shape=ellipse]" << endl; + init_sout_ << "param" << it.get() << "->" + << "op_assign" << it.get() << ":1" << endl; + init_sout_ << "const" << it.get() << "->" + << "op_assign" << it.get() << ":2" << endl; +} + +void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input) { + DfGraphPtr init_graph = std::make_shared("init"); + std::vector nodes = TopoSort(anf_graph_->get_return()); + + for (auto &it : nodes) { + if (it->isa()) { + if (IsValueNode(it)) { + auto symbolic = GetValueNode(it); + auto name = std::static_pointer_cast(symbolic->node())->name(); + auto iter = vars_.find(name); // get correspoding varaible op + if (iter != vars_.end()) { + op_cache_[it.get()] = iter->second; + // #ifdef DRAW_GE_GRAPH + compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()] + << "[style=\"dotted\"]" << endl; + // #endif + } + } else if (IsValueNode(it)) { + auto refkey = GetValueNode(it); + auto name = refkey->tag(); + auto iter = vars_.find(name); // get correspoding varaible op + if (iter != vars_.end()) { + op_cache_[it.get()] = iter->second; + compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()] + << "[style=\"dotted\"]" << endl; + } + } + } + } + + for (auto &it : tensors) { + if (vars_.find(it.first) == vars_.end()) { + MS_LOG(WARNING) << "Init parameter " << it.first << " didn't appear in graph."; + vars_[it.first] = nullptr; + } + } + + // set up init sub graph + if (init_input->size()) { + // init sub graph needs no input + MS_LOG(INFO) << "Build data init subgraph."; + (void)init_graph->SetInputs(*init_input); + this->init_graph_ = init_graph; + } else { + this->init_graph_ = nullptr; + } +} + +void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) { + MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input"; + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + auto getnext_idx = static_cast(input_idx); + DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); + if (!param.input_indexes().empty() && input_idx <= param.input_indexes().size()) { + getnext_idx = param.input_indexes()[input_idx] - 1; // input_idx start from 0. + MS_LOG(INFO) << "remap input_index:" << input_idx << " to getnext_index:" << getnext_idx << "."; + } + // use iterator_getnext op with output_name instead of data op in BuildGraph. + out_handle_cache_[it.get()] = OutHandler(dataset_iter_getnext_, "y" + std::to_string(getnext_idx)); + } +} + +void DfGraphConvertor::SetupBroadcast(const std::shared_ptr &broadcast, + const std::vector &broadcast_desc, + const DfGraphPtr &broadcast_graph, std::vector broadcast_input) { + MS_LOG(INFO) << "build broadcast subgraph"; + if (broadcast_desc.size() != broadcast_input.size()) { + MS_LOG(EXCEPTION) << "Desc number of BroadCast is not equal to number of Input"; + } + (void)broadcast->create_dynamic_input_x(static_cast(broadcast_input.size())); + (void)broadcast->create_dynamic_output_y(static_cast(broadcast_desc.size())); + for (unsigned int i = 0; i < broadcast_input.size(); i++) { + (void)broadcast->set_dynamic_input_x(i, broadcast_input[i]); + (void)broadcast->update_dynamic_output_desc_y(i, broadcast_desc[i]); + } + (void)broadcast_graph->SetInputs(broadcast_input); + this->broadcast_graph_ = broadcast_graph; +} + +void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) { + int index = 0; + std::vector init_input; + for (auto it : tensors) { + std::string name = it.first; + auto node_itor = params_.find(name); + // if name not in params_, create a node in graph + if (node_itor == params_.end()) { + MS_LOG(WARNING) << name << " is not in params, and create a new node."; + ParameterPtr param = std::make_shared(nullptr); + name = name + "_temp"; + param->set_name(name); + (void)ConvertParameter(param); + node_itor = params_.find(name); + } + auto node = node_itor->second; + auto op_itor = op_cache_.find(node.get()); + if (op_itor == op_cache_.end()) { + MS_LOG(EXCEPTION) << "Can not find op for node " << node->ToString() << "."; + } + auto adpt = FindAdapter(kNameParam, training_); + if (adpt == nullptr) continue; + auto param_op = adpt->generate(name + "_data"); + MS_LOG(INFO) << "Add parameter " << name << " as input, index " << index << "."; + + if (!training_) { + auto adpt_const = FindAdapter(kNameConst, training_); + if (adpt_const == nullptr) continue; + auto const_op = adpt_const->generate(name + "_const"); + (void)adpt_const->setAttr(const_op, "value", it.second); + + auto const_op_desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); + if (const_op_desc == nullptr) { + MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; + continue; + } + (void)std::static_pointer_cast(const_op)->update_output_desc_y(*const_op_desc); + + vars_[name] = const_op; + op_itor->second = const_op; + continue; + } + + // create tensor descriptor for output descriptor + auto desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; + continue; + } + + // we need three variable ops for each graph with same name + // build init subgraph + if (it.second->is_init() == 0) { + (void)std::static_pointer_cast(param_op)->set_attr_index(index++); + auto init_var = std::make_shared(name); + auto assign_op = std::make_shared("assign_" + name); + (void)init_var->update_output_desc_y(*desc); + (void)assign_op->set_input_ref(*init_var).set_input_value(*param_op); + init_input.push_back(*init_var); + init_ops_.push_back(param_op); + init_ops_.push_back(assign_op); + init_ops_.push_back(init_var); + } + + auto variable = std::make_shared(name); + (void)variable->update_output_desc_y(*desc); + // do not use read variable while variable sink + MS_LOG(DEBUG) << "InitParam, op_name = " << name << ", var = " << variable->GetName() << "."; + op_itor->second = variable; // replace parameter with variable + vars_[name] = variable; // prevent the variable operator from being freed + DrawParamInitSubGraph(name, node); + } + InitLoopVar(&init_input); + SetupParamInitSubGraph(tensors, &init_input); +} + +// convert all parameter need initialize to variable +DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) { + size_t input_idx = 0; + if (error_ != 0) { + return *this; + } + if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { + error_ = INVALID_ARGUMENT; + MS_LOG(ERROR) << "Invalid AnfGraph in InitParam."; + return *this; + } + + // Processing input with MakeDatasetHandler + for (auto &it : anf_graph_->parameters()) { + auto op_itor = op_cache_.find(it.get()); // converted node + if (it->isa() && op_itor != op_cache_.end()) { + string name = std::static_pointer_cast(it)->name(); + auto tensor_itor = tensors.find(name); // in init value map + if (tensor_itor == tensors.end()) { + DfGraphConvertor::MakeDatasetHandler(name, input_idx, it); + input_idx++; + } + } + } + InitParamWithData(tensors); + init_sout_ << "}" << endl; + return *this; +} + +#if (defined ENABLE_GE) +void DfGraphConvertor::BuildSaveCheckpointGraph() { + std::vector graph_inputs; + ge::op::Save save_op("save_parms"); + int save_op_is_active = 0; + size_t index = 0; + string name; + + int32_t count_size = std::count_if(vars_.begin(), vars_.end(), [](const std::pair &it) { + return (it.second == nullptr || it.first.find("/") != std::string::npos); + }); + + (void)save_op.create_dynamic_input_tensors(vars_.size() - static_cast(count_size)); + + // for each "parameter" in anf graph excluding "input" + for (const auto &it : vars_) { + name = it.first; + if (it.second == nullptr || name.find("/") != std::string::npos) continue; + Variable variable(name); + (void)variable.update_output_desc_y(it.second->GetOutputDesc(0)); + (void)save_op.set_dynamic_input_tensors(index++, variable); + + graph_inputs.push_back(variable); + + if (save_op_is_active == 0) { + checkpoint_sout_ << "op_save" << &save_op << "[label=<"; + checkpoint_sout_ << "" << endl; + checkpoint_sout_ << "" << endl; + checkpoint_sout_ << "" << endl; + checkpoint_sout_ << "
tensor
" + << "\"saveop" + << "\"
> shape=plaintext]" << endl; + } + + checkpoint_sout_ << "param" << it.second << "[shape=octagon, label=\"" << name << "\"]" << endl; + + checkpoint_sout_ << "param" << it.second << "->" + << "op_save" << &save_op << ":1" << endl; + save_op_is_active = 1; + } + if (save_op_is_active) { + std::vector graph_output; + graph_output.emplace_back(save_op); + DfGraphPtr checkpoint_graph = std::make_shared("checkpoint"); + (void)checkpoint_graph->SetInputs(graph_inputs); + (void)checkpoint_graph->SetOutputs(graph_output); + this->save_ckp_graph_ = checkpoint_graph; + } else { + this->save_ckp_graph_ = nullptr; + } + + checkpoint_sout_ << "}" << endl; + return; +} +#endif + +DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) { + if (error_ != 0) { + return *this; + } + if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { + error_ = INVALID_ARGUMENT; + MS_LOG(ERROR) << "Invalid AnfGraph in generate broadcast graph"; + return *this; + } + + DfGraphPtr broadcast_graph = std::make_shared("broadcast"); + // collect the operators create for broadcast sub graph, in order to avoid auto release + std::vector broadcast_input; + std::vector broadcast_desc; + auto broadcast = std::make_shared("broadcast_parameter"); + (void)broadcast->set_attr_root_rank(0); + (void)broadcast->set_attr_group("hccl_world_group"); + broadcast_ops_.push_back(broadcast); + + // find every parameter, build broadcast subgraph (or initialize the parameter with constant) + for (auto &it : anf_graph_->parameters()) { + auto op_itor = op_cache_.find(it.get()); // converted node + if (it->isa() && op_itor != op_cache_.end()) { + string name = std::static_pointer_cast(it)->name(); + auto tensor_itor = tensors.find(name); // in init tensor map + if (tensor_itor != tensors.end()) { + auto tensor = tensor_itor->second; + auto shape_ge = tensor->shape_c(); + + // create tensor descriptor for output descriptor + auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; + continue; + } + + // build broadcast subgraph + if (distribute_) { + auto broadcast_var = std::make_shared(name); + (void)broadcast_var->update_output_desc_y(*desc); + broadcast_input.push_back(*broadcast_var); + broadcast_desc.push_back(*desc); + broadcast_ops_.push_back(broadcast_var); + } + } + } + } + + // set up broadcast sub graph + if (!broadcast_input.empty()) { + DfGraphConvertor::SetupBroadcast(broadcast, broadcast_desc, broadcast_graph, broadcast_input); + } else { + this->broadcast_graph_ = nullptr; + } + return *this; +} + +DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() { + if (error_ != 0) { + MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << "."; + return *this; + } + if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { + error_ = INVALID_ARGUMENT; + MS_LOG(ERROR) << "Invalid AnfGraph in GenerateCheckpointGraph"; + return *this; + } +#if (defined ENABLE_GE) + BuildSaveCheckpointGraph(); + // Restoring from checkpoint file is done by pyfront, not in graph now. +#endif + return *this; +} + +DfGraphConvertor &DfGraphConvertor::ConvertAllNode() { + if (error_ != 0) { + return *this; + } + if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { + MS_LOG(ERROR) << "Invalid AnfGraph"; + error_ = FAILED; + return *this; + } + + compute_sout_.clear(); + compute_sout_ << "digraph {" << endl; + init_sout_.clear(); + init_sout_ << "digraph {" << endl; + checkpoint_sout_.clear(); + checkpoint_sout_ << "digraph {" << endl; + restore_checkpoint_sout_.clear(); + restore_checkpoint_sout_ << "digraph {" << endl; + + // Convert all anf node to Operator + MS_LOG(DEBUG) << "convert all node"; + std::vector nodes = TopoSort(anf_graph_->get_return()); + for (auto &it : nodes) { + (void)Convert(it); + if (this->error_ != 0) { + MS_LOG(ERROR) << "failed to convert node: " << it->DebugString() << "."; + } + } + + // Create dataset iterator and iterator_getnext node + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); + MS_LOG(INFO) << "Dataset param is " << param.ToString() << "."; + // GetNext + auto iter_getnext_op = make_shared("get_next_tmp"); + (void)iter_getnext_op->set_attr_output_types(param.ge_types()); + (void)iter_getnext_op->set_attr_output_shapes(param.shapes()); + (void)iter_getnext_op->set_attr_channel_name(param.queue_name()); + + // save iter_getnext_op for later use + dataset_iter_getnext_ = iter_getnext_op; + } + + // return the data flow graph + return *this; +} + +void DfGraphConvertor::TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out) { + auto it = out_handle_cache_.find(anf_out.get()); + if (it != out_handle_cache_.end()) { + OutHandler handle = it->second; + auto op = handle.op; + if (op != nullptr) { + MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out; + graph_outputs_.emplace_back(std::make_pair(*op, handle.out)); + } else { + MS_LOG(EXCEPTION) << "tuple_getitem: " << anf_out->fullname_with_scope() << " is not converted"; + } + } else { + // invalid tuple_getitem e.g. tuple_getitem(tuple_getitem())/tuple_getitem(depend())/tuple_getitem(make_tuple()) + MS_LOG(WARNING) << "Invalid tuple_getitem: " << anf_out->fullname_with_scope(); + } +} + +void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { + AnfNodePtr anf_out = node; + AnfNodePtr pre_node = nullptr; + + // trace Parameter node + TraceOutputFromParameter(anf_out); + // then trace cnode + if (!node->isa()) { + return; + } + + // trace tuple_getitem + while (anf_out->isa() && IsPrimitiveCNode(anf_out, prim::kPrimTupleGetItem)) { + pre_node = anf_out; + anf_out = anf_out->cast()->input(1); + } + // trace every element of make_tuple + auto c = anf_out->cast(); + std::string name = ""; + if (anf_out->isa()) { + name = GetCNodeTargetFuncName(c); + } + + if (name == "make_tuple") { + for (unsigned int i = 1; i < c->inputs().size(); i++) { + TraceOutput(c->input(i)); + } + } else if (name == "Depend") { + if (c->inputs().size() < 3) { // "Depend" primitive have 3 inputs + MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3"; + } + TraceOutput(c->input(1)); + } else if (name == "tuple_getitem") { + TraceOutputFromTupleGetItem(anf_out); + } else { + // add outputs; + auto op = Convert(anf_out); + std::string index; + if (op != nullptr) { + if ((pre_node != nullptr) && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) { + auto item = out_handle_cache_.find(pre_node.get()); + if (item != out_handle_cache_.end()) { + index = item->second.out; + } else { + MS_LOG(WARNING) << "Can't get operater: " << anf_out->fullname_with_scope() << " 's output item"; + } + } + MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope() << ":" << index; + graph_outputs_.emplace_back(make_pair(*op, index)); + } + } +} + +void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) { + if (anf_out->isa()) { + MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope(); + auto it = out_handle_cache_.find(anf_out.get()); + if (it != out_handle_cache_.end()) { + // For dataset graph mode, input parameter is converted to a "iterator_get_next:yn" OutHandler. + OutHandler handle = it->second; + auto op = handle.op; + MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out; + graph_outputs_.emplace_back(make_pair(*op, handle.out)); + } else { + // common parameter case + auto op = Convert(anf_out); + if (op != nullptr) { + MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType(); + graph_outputs_.emplace_back(std::make_pair(*op, "")); + } + } + } +} + +void SetupDatasetIterGetNextNode(const OperatorPtr &op) { + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); + size_t output_num = param.ge_types().size(); + MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << "."; + // set iterator_getnext op's output num + shared_ptr iter_getnext = std::static_pointer_cast(op); + (void)iter_getnext->create_dynamic_output_y(static_cast(output_num)); + + for (uint32_t i = 0; i < output_num; i++) { + ge::TensorDesc desc(GeShape(param.shapes()[i]), ge::FORMAT_NCHW, (ge::DataType)param.ge_types()[i]); + // we don't SetRealDimCnt here since GE do not use this output's real-dim + (void)iter_getnext->update_dynamic_output_desc_y((i), desc); + } + } + return; +} + +void DfGraphConvertor::SetSubgraph(AnfNodePtr node) { + if (!node->isa()) { + return; + } + auto cnode = node->cast(); + if (!IsCaseNode(cnode)) { + return; + } + std::vector case_inputs; + for (size_t i = 1; i < cnode->inputs().size(); i++) { + case_inputs.emplace_back(cnode->input(i)); + } + std::shared_ptr> branches = std::make_shared>(); + auto bnode = cnode->input(0)->cast()->input(2)->cast(); + + for (size_t i = 1; i < bnode->inputs().size(); i++) { + auto branch_node = bnode->input(i)->cast(); + for (size_t j = 2; j < branch_node->inputs().size(); j++) { + if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { + case_inputs.emplace_back(branch_node->input(j)); + } + } + } + + for (size_t i = 1; i < bnode->inputs().size(); i++) { + ProcessSubgraph(bnode->input(i), case_inputs); + } + + for (size_t i = 1; i < bnode->inputs().size(); i++) { + branches->emplace_back(branches_map_[bnode->input(i).get()]); + } + + if (op_cache_.find(node.get()) == op_cache_.end()) { + return; + } + + OpAdapterPtr adpt = FindAdapter(node, training_); + if (nullptr == adpt) { + MS_LOG(DEBUG) << "Not found adapter"; + return; + } + + OperatorPtr op = Convert(node); + adpt->setSubgraph(op, 0, branches); + return; +} + +void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) { + std::vector case_inputs; + for (size_t i = 1; i < node->inputs().size(); i++) { + case_inputs.emplace_back(node->input(i)); + } + std::shared_ptr> branches = std::make_shared>(); + auto bnode = input_node->input(2)->cast(); + + for (size_t i = 1; i < bnode->inputs().size(); i++) { + auto branch_node = bnode->input(i)->cast(); + for (size_t j = 2; j < branch_node->inputs().size(); j++) { + if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { + case_inputs.emplace_back(branch_node->input(j)); + } + } + } + + const size_t case_index = 1; + const size_t make_tuple_index = 2; + + AnfNodePtr case_index_iter = input_node->input(case_index); + AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index); + auto make_tuple_node = make_tuple_iter->cast(); + std::shared_ptr> tuple_items = std::make_shared>(); + + for (size_t i = 0; i < case_inputs.size(); i++) { + auto item = case_inputs[i]; + auto op = Convert(item); + if (op != nullptr) { + tuple_items->emplace_back(OutHandler(op, "")); + } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { + tuple_items->push_back(out_handle_cache_[item.get()]); + } else { + MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString(); + continue; + } + } + + tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items; + + std::shared_ptr> case_input_items = std::make_shared>(); + case_input_items->emplace_back(case_index_iter); + case_input_items->emplace_back(make_tuple_iter); + case_input_handle_cache_[node.get()] = case_input_items; +} + +DfGraphConvertor &DfGraphConvertor::BuildGraph() { + SetupDatasetIterGetNextNode(dataset_iter_getnext_); + + if (error_ != 0) { + return *this; + } + + // Case node set input. + std::vector nodes = ::mindspore::TopoSort(anf_graph_->get_return()); + for (auto &it : nodes) { + if (it->isa() && IsCaseNode(it->cast())) { + auto node = it->cast(); + auto input_node = node->input(0)->cast(); + GetCaseNodeInput(node, input_node); + } + } + + // update tuple_out_handle_cache_ + for (auto it : tuple_out_handle_cache_) { + std::size_t len = it.second->size(); + for (std::size_t i = 0; i < len; i++) { + OutHandler handle = (*it.second)[i]; + if (handle.op) { + string name = handle.op->GetName(); + if (vars_.count(name)) { + OperatorPtr new_op = vars_[name]; + if (new_op != nullptr) { + MS_LOG(INFO) << "update tuple_out_handle_cache_ " << name; + (*it.second)[i] = OutHandler(new_op, handle.out); + } + } + } + } + } + + // set up dependices + MS_LOG(DEBUG) << "set up dependices"; + nodes = ::mindspore::TopoSort(anf_graph_->get_return()); + for (auto &it : nodes) { + SetNodeInput(it); + SetOpControlInput(it); + SetSubgraph(it); + UpdateOpDesc(it); + } + + if (error_ == 0) { + df_graph_ = make_shared(anf_graph_->ToString()); + } else { + return *this; + } + + // set graph input according to the order from anf graph + std::vector inputs; + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + inputs.push_back(*dataset_iter_getnext_); + } else { + auto params = anf_graph_->parameters(); + if (use_inputs_) { + params = inputs_; + auto anf_params = anf_graph_->parameters(); + for (size_t i = 0; i < params.size(); i++) { + for (size_t j = 0; j < anf_params.size(); j++) { + if (params[i]->ToString() == anf_params[j]->ToString()) { + params[i] = anf_params[j]; + } + } + } + } + + int index = 0; + for (auto &it : params) { + auto name = std::static_pointer_cast(it)->name(); + // the parameters which has not been converted to var + if (vars_.find(name) == vars_.end()) { + auto op = Convert(it); + MS_EXCEPTION_IF_NULL(op); + MS_LOG(INFO) << "add not var input " << it->ToString() << ", index " << index; + if (op == nullptr) { + MS_LOG(ERROR) << "Convert graph failed!"; + return *this; + } + UpdateDataOpDesc(it, op); + + MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index; + (void)std::static_pointer_cast(op)->set_attr_index(index++); + inputs.push_back(*op); + } else if (vars_[name] != nullptr) { + MS_LOG(INFO) << "add var input " << it->ToString(); + auto op = Convert(it); + MS_EXCEPTION_IF_NULL(op); + inputs.push_back(*op); + } + } + } + + // Add const nodes as graph input for some operator work with constant + std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs), + [](OperatorPtr x) { return *x; }); + + MS_LOG(INFO) << "set graph input num: " << inputs.size(); + (void)df_graph_->SetInputs(inputs); + + // set graph output + // set the value of finale return apply node as the output of dataflow graph + MS_LOG(DEBUG) << "set output"; + graph_outputs_.clear(); + TraceOutput(anf_graph_->get_return()->input(1)); + MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size(); + (void)df_graph_->SetOutputs(graph_outputs_); + + compute_sout_ << "}" << endl; + // For the graph(e.g. eval_subgraph) whose IterNum is 1, donot set NeedIteration flag. + if (ConfigManager::GetInstance().iter_num() > 1) { + df_graph_->SetNeedIteration(true); + } + return *this; +} + +void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const { + auto node = std::static_pointer_cast(it); + if (node == nullptr) { + MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node."; + return; + } + auto normal_shape_ptr = dyn_cast(node->Shape()); + vector shape; + if (normal_shape_ptr == nullptr) { + MS_LOG(INFO) << "Invalid shape to update data op descriptor."; + return; + } + shape = normal_shape_ptr->shape(); + if (node->Type() == nullptr) { + MS_LOG(INFO) << "Invalid type to update data op descriptor."; + return; + } + TypeId me_type = node->Type()->type_id(); + if (kObjectTypeTensorType == me_type) { + me_type = dyn_cast(node->Type())->element()->type_id(); + } + std::ostringstream buf; + buf << "[" << shape << "]"; + MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type; + auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); + if (desc == nullptr) { + MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null."; + } else { + (void)std::static_pointer_cast(op)->update_input_desc_x(*desc); + (void)std::static_pointer_cast(op)->update_output_desc_y(*desc); + } +} + +DfGraphPtr DfGraphConvertor::GetComputeGraph() { return df_graph_; } + +DfGraphPtr DfGraphConvertor::GetInitGraph() { return init_graph_; } + +DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; } + +DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; } + +void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { + if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { + return; + } + + std::vector control_edges = control_depend_cache_[node.get()]; + if ((control_edges.empty())) { + MS_LOG(ERROR) << "Get control depend node's src or dest operator failed"; + return; + } + + for (auto &item : control_edges) { + (void)item.dest_op->AddControlInput(*item.src_op); + } +} + +const std::vector trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)}; + +void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { + OperatorPtr src = Convert(node); + int case_flag = 0; + auto &inputs = node->inputs(); + size_t input_size = inputs.size(); + if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) { + case_flag = 1; + input_size = case_input_handle_cache_[node.get()]->size() + 1; + } + + for (size_t i = 1; i < input_size; i++) { + auto pred = inputs[i]; + if (case_flag != 0) { + pred = case_input_handle_cache_[node.get()]->at(i - 1); + } + + while (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == "Depend") { + pred = pred->cast()->input(1); + } + // skip the None input + if (IsValueNode(pred)) { + continue; + } + // transform "Const" op to "Variable" op when the next node is "Assign" op. + std::string c_name = GetCNodeTargetFuncName(node); + auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); + if (!training_ && pos != trans_var_list.end() && pred->isa()) { + std::string name = std::static_pointer_cast(pred)->name(); + auto op_itor = op_cache_.find(pred.get()); + if (op_itor == op_cache_.end()) { + MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; + } + if (op_itor->second != nullptr && + (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && + vars_.find(name) != vars_.end()) { + auto variable = std::make_shared(name); + auto desc = vars_[name]->GetOutputDesc("y"); + (void)variable->update_output_desc_y(desc); + MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; + op_itor->second = variable; // replace parameter with variable + vars_[name] = variable; + } + } + // find in out_hadnle_cache_ first + auto it = out_handle_cache_.find(pred.get()); + if (it != out_handle_cache_.end()) { + int ret = adpt->setInput(src, SizeToInt(i), it->second); + if (ret == 0) { + if (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == "tuple_getitem") { + compute_sout_ << op_draw_name_[pred->cast()->input(1).get()] << " -> " << op_draw_name_[node.get()] + << ":" << i << endl; + } else if (pred->isa()) { + compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; + } else { + // don't draw anything. + MS_LOG(INFO) << "DRAW_GE_GRAPH: Shouldn't have this case."; + } + AddGraphConstInput(it->second.op); + } + } else if (tuple_out_handle_cache_.find(pred.get()) != tuple_out_handle_cache_.end()) { + std::shared_ptr> handler_vec = tuple_out_handle_cache_[pred.get()]; + int ret = adpt->setInput(src, SizeToInt(i), handler_vec); + if ((ret == 0) && pred->isa() && (pred->cast()->inputs().size() == handler_vec->size() + 1)) { + for (unsigned int j = 0; j < handler_vec->size(); j++) { + compute_sout_ << op_draw_name_[pred->cast()->input(j + 1).get()] << " -> " + << op_draw_name_[node.get()] << ":" << i << endl; + AddGraphConstInput(handler_vec->at(j).op); + } + } else { + MS_LOG(WARNING) << "Convert tuple node setInput failed : " << node->ToString(); + } + } else { + auto op = Convert(pred); + int ret = adpt->setInput(src, SizeToInt(i), op); + if (ret == 0) { + compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; + AddGraphConstInput(op); + } + } + } +} + +void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) { + if (op->GetOpType() == "Constant") { + graph_const_inputs_.push_back(op); + } +} + +void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) { + if (!node->isa()) { + return; + } + if (op_cache_.find(node.get()) == op_cache_.end()) { + return; + } + auto cnode = node->cast(); + OpAdapterPtr adpt = FindAdapter(cnode, training_); + if (adpt == nullptr) { + error_ = NOT_FOUND; + return; + } + + // get Operator from op_cache_, use adapter to set Inputs + DfGraphConvertor::SetOpInput(adpt, cnode); +} + +void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector &inputs) { + if (!node->isa() || GetCNodeFuncName(node->cast()) != "Partial") { + return; + } + auto graph_node = node->cast()->input(1)->cast(); + FuncGraphPtr anf_graph = graph_node->value()->cast(); + DfGraphConvertor convertor(anf_graph); + convertor.use_inputs_ = true; + convertor.inputs_ = inputs; + (void)convertor.ConvertAllNode().BuildGraph(); + std::string name = graph_node->ToString() + "_ge_graph.dot"; + if (MsContext::GetInstance()->save_graphs_flag()) { + convertor.DrawComputeGraph(name); + } + branches_map_[node.get()] = *(convertor.df_graph_); +} + +// Update GE op's shape and type info +void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) { + if (nullptr == node || !node->isa()) { + return; + } + + if (op_cache_.find(node.get()) == op_cache_.end()) { + return; + } + + OpAdapterPtr adpt = FindAdapter(node, training_); + if (adpt == nullptr) { + error_ = NOT_FOUND; + return; + } + + // get Operator from op_cache_ + OperatorPtr op = Convert(node); + + adpt->updateOutputDesc(op, node->Shape(), node->Type(), node); +} + +OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) { + if (node == nullptr) { + MS_LOG(ERROR) << "node is nullptr"; + error_ = NOT_FOUND; + return nullptr; + } + // find in cache + if (op_cache_.count(node.get())) { + return op_cache_[node.get()]; + } + + // do not convert primitive node + if (IsValueNode(node)) { + return nullptr; + } + + // convert a new one + if (node->isa()) { + return ConvertCNode(node->cast()); + } + if (node->isa()) { + return ConvertParameter(node); + } + if (node->isa()) { + return ConvertValueNode(node->cast()); + } + + MS_LOG(ERROR) << "Invalide AnfNode"; + error_ = INVALID_ARGUMENT; + return nullptr; +} + +void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) { + std::shared_ptr> tuple_items = std::make_shared>(); + // convert each tuple item to a OutHandler + for (size_t i = 1; i < node->inputs().size(); i++) { + AnfNodePtr item = node->input(i); + OperatorPtr op = Convert(item); + if (op != nullptr) { + tuple_items->emplace_back(OutHandler(op, "")); + } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { + tuple_items->push_back(out_handle_cache_[item.get()]); + } else { + MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << item->ToString(); + return; + } + } + + MS_LOG(WARNING) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size(); + tuple_out_handle_cache_[node.get()] = tuple_items; +} + +AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, unsigned int *index) { + const int TUPLE_GET_ITEM_INDEX = 2; + if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs + MS_LOG(EXCEPTION) << "length of inputs of TupleGetItem is less than 3"; + } + auto index_node = node->inputs()[TUPLE_GET_ITEM_INDEX]; + if (!index_node->isa()) { + error_ = INVALID_ARGUMENT; + MS_LOG(EXCEPTION) << "can't convert get item with non-constant index"; + } + *index = IntToUint(GetValue(GetValueNode(index_node))); + return node->inputs()[1]; +} + +AnfNodePtr DfGraphConvertor::TraceDepend(const CNodePtr &node) { + auto cnode = node->cast(); + if (cnode->inputs().size() < 3) { // "Depend" primitive have 3 inputs + MS_LOG(EXCEPTION) << "length of inputs of depend is less than 3"; + } + return cnode->inputs()[1]; +} + +AnfNodePtr DfGraphConvertor::TraceMakeTuple(const CNodePtr &node, unsigned int index) { + if (index + 1 >= node->inputs().size()) { + MS_LOG(EXCEPTION) << "length of make_tuple is less than index: " << index; + } + return node->inputs()[index + 1]; +} + +OutHandler DfGraphConvertor::GetHandler(const AnfNodePtr &node, const std::stack &index_stack, + AnfNode *const draw_index) { + if (node == nullptr) { + MS_LOG(ERROR) << "Get nullptr while trace real op"; + return OutHandler(nullptr, ""); + } + std::ostringstream ss; + ss << "op" << node.get(); + if (index_stack.empty()) { + op_draw_name_[draw_index] = ss.str(); + return OutHandler(Convert(node), ""); + } else { + OpAdapterPtr adpt = FindAdapter(node, training_); + if (nullptr == adpt) { + MS_LOG(ERROR) << "Can not get node output as adpt is nullptr!"; + error_ = NOT_FOUND; + return OutHandler(nullptr, ""); + } + OperatorPtr op = Convert(node); + if (op == nullptr) { + error_ = NOT_FOUND; + MS_LOG(ERROR) << "Can not convert node for trace real op"; + return OutHandler(nullptr, ""); + } + op_draw_name_[draw_index] = ss.str(); + return adpt->getOutput(Convert(node), UintToInt(index_stack.top())); + } +} + +// get the real operator through maketuple tuple_getitem depend +OutHandler DfGraphConvertor::TraceRealOp(AnfNodePtr node) { + bool flag = IsPrimitiveCNode(node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) || + IsPrimitiveCNode(node, prim::kPrimDepend); + std::stack index_stack; + auto draw_index = node.get(); + while (flag) { + flag = false; + if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + unsigned int index; + node = TraceTupleGetItem(node->cast(), &index); + index_stack.push(index); + flag = true; + } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + if (index_stack.empty()) { + MS_LOG(ERROR) << "TraceRealOp find a make_tuple node"; + return OutHandler(nullptr, ""); + } else { + node = TraceMakeTuple(node->cast(), index_stack.top()); + index_stack.pop(); + flag = true; + } + } else if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + node = TraceDepend(node->cast()); + flag = true; + } + } + return GetHandler(node, index_stack, draw_index); +} + +void DfGraphConvertor::ConvertTupleGetItem(const CNodePtr node) { + auto handle = TraceRealOp(node); + if (handle.op == nullptr) { + MS_LOG(ERROR) << "Failed to trace tuple get item"; + return; + } + out_handle_cache_[node.get()] = handle; +} + +// Get the real op for tuple_getitem through make tuple, or depend +AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) { + const int TUPLE_GET_ITEM_INDEX = 2; + if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + auto node_inputs = node->cast()->inputs(); + if (node_inputs.size() != 3) { // "tuple_getitem" primitive must have 3 inputs + MS_LOG(ERROR) << "tuple get item node not correct!"; + error_ = FAILED; + return node; + } + MS_EXCEPTION_IF_NULL(node_inputs[TUPLE_GET_ITEM_INDEX]); + if (!node_inputs[TUPLE_GET_ITEM_INDEX]->isa()) { + error_ = INVALID_ARGUMENT; + MS_LOG(EXCEPTION) << "can't convert get item with non-constant index"; + } + auto value_ptr = GetValueNode(node_inputs[TUPLE_GET_ITEM_INDEX])->cast(); + if (value_ptr == nullptr) { + MS_LOG(ERROR) << "Can not convert get item as value is nullptr!"; + error_ = FAILED; + return node; + } + int index = value_ptr->value(); + + // make_tuple apply inputs:make_tuple, [tuple_items,] + if (IsPrimitiveCNode(node_inputs[1], prim::kPrimMakeTuple)) { + auto tuple_inputs = node->cast()->inputs(); + if (tuple_inputs.size() < IntToSize(index + 1)) { + MS_LOG(ERROR) << "make tuple input items node not correct! size:" << tuple_inputs.size() + << ", item index:" << index; + error_ = FAILED; + return node; + } + return GetRealOpNode(tuple_inputs[IntToSize(index + 1)]); + } + return GetRealOpNode(node_inputs[1]); + } + + // depend apply inputs: depend,output,depended_node + if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + auto depend_inputs = node->cast()->inputs(); + if (depend_inputs.size() != 3) { // "Depend" primitive have 3 inputs + MS_LOG(ERROR) << "depend input items not correct"; + error_ = FAILED; + return node; + } + return GetRealOpNode(depend_inputs[1]); + } + return node; +} + +// convert the anf node to corresponding operator list +std::vector DfGraphConvertor::ConvertDependNode(const AnfNodePtr node) { + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + std::vector op_lists; + auto node_inputs = node->cast()->inputs(); + for (size_t index = 1; index < node_inputs.size(); index++) { + auto op = Convert(GetRealOpNode(node_inputs[index])); + if (op == nullptr) { + MS_LOG(ERROR) << "Convert control depend node to operator failed"; + error_ = FAILED; + return std::vector({}); + } + op_lists.push_back(op); + } + return op_lists; + } + + auto op = Convert(GetRealOpNode(node)); + if (op == nullptr) { + MS_LOG(ERROR) << "Convert control depend node to operator failed"; + error_ = FAILED; + return std::vector({}); + } + return std::vector({op}); +} + +// get the anf node list for depend +std::vector DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) { + std::vector nodes; + // for make tuple, should control depend on the tuple items + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + auto node_inputs = node->cast()->inputs(); + for (size_t index = 1; index < node_inputs.size(); index++) { + nodes.push_back(GetRealOpNode(node_inputs[index])); + } + return nodes; + } + + // for parameter ,find the apply that used the parameter as the control depended node + if (node->isa()) { + auto uses = node->func_graph()->manager()->node_users()[node]; + for (auto &use : uses) { + auto use_node = use.first; + if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { + nodes.push_back(GetRealOpNode(use_node)); + } + } + return nodes; + } + nodes.push_back(GetRealOpNode(node)); + return nodes; +} + +void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) { +#ifdef DRAW_GE_GRAPH + auto src_depend_nodes = GetDependNodes(src_node); + auto dst_depend_nodes = GetDependNodes(dest_node); + if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) { + for (auto &item : dst_depend_nodes) { + compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()] + << "[style=\"dotted\"]" << endl; + } + } else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) { + for (auto &item : src_depend_nodes) { + compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] + << "[style=\"dotted\"]" << endl; + } + } else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) { + compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] + << "[style=\"dotted\"]" << endl; + } +#endif +} + +void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, + const AnfNodePtr &dest_node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list) { + if (src_node->isa()) { + auto uses = node->func_graph()->manager()->node_users()[src_node]; + for (auto &use : uses) { + auto use_node = use.first; + if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && + (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { + auto converted_list = ConvertDependNode(use_node); + src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); + } + } + } + + if (dest_node->isa()) { + auto uses = node->func_graph()->manager()->node_users()[dest_node]; + for (auto &use : uses) { + auto use_node = use.first; + if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && + (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { + auto converted_list = ConvertDependNode(use_node); + dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); + } + } + } +} + +bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list) { + const int CONTROL_DEPEND_INDEX = 0; + const int SRC_NODE_INDEX = 1; + const int DEST_NODE_INDEX = 2; + const int DEPEND_MODE_NORMAL_USE = 0; + const int DEPEND_MODE_ON_PARAMETER_USE = 1; + + auto node_inputs = node->inputs(); + if (node_inputs.size() <= DEST_NODE_INDEX) { + MS_LOG(WARNING) << "Control depend node input size error"; + return false; + } + auto src_node = node_inputs[SRC_NODE_INDEX]; + auto dest_node = node_inputs[DEST_NODE_INDEX]; + if ((src_node == nullptr) || (dest_node == nullptr)) { + MS_LOG(ERROR) << "Control depend node miss src or dest node"; + error_ = FAILED; + return false; + } + AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX]; + PrimitivePtr prim_ptr = GetValueNode(fn); + ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); + int depend_mode = DEPEND_MODE_NORMAL_USE; + if (mode_ptr != nullptr) { + auto mode_int = mode_ptr->cast(); + MS_EXCEPTION_IF_NULL(mode_int); + depend_mode = mode_int->value(); + MS_LOG(DEBUG) << "depend_mode = " << depend_mode; + } + if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) { + GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list); + } + + if (src_node->isa()) { + auto converted_list = ConvertDependNode(src_node); + src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); + } + + if (dest_node->isa()) { + auto converted_list = ConvertDependNode(dest_node); + dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); + } + if (src_ops_list->empty() || dst_ops_list->empty()) { + MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; + error_ = SUCCESS; + } + return true; +} + +void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { + const int SRC_NODE_INDEX = 1; + const int DEST_NODE_INDEX = 2; + if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) { + return; + } + auto node_inputs = node->inputs(); + if (node_inputs.size() <= DEST_NODE_INDEX) { + MS_LOG(WARNING) << "Control depend node input size error"; + return; + } + auto src_node = node_inputs[SRC_NODE_INDEX]; + auto dest_node = node_inputs[DEST_NODE_INDEX]; + if ((src_node == nullptr) || (dest_node == nullptr)) { + MS_LOG(ERROR) << "Control depend node miss src or dest node"; + error_ = FAILED; + return; + } + std::shared_ptr> src_ops_list = std::make_shared>(); + std::shared_ptr> dst_ops_list = std::make_shared>(); + if (!GetControlDependList(node, src_ops_list, dst_ops_list)) { + MS_LOG(ERROR) << "Get depend list failed"; + error_ = FAILED; + return; + } + std::vector control_edges; + if (src_ops_list->size() == 1 && dst_ops_list->size() > 1) { + (void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges), + [src_ops_list](const OperatorPtr &op) -> ControlEdge { + return {(*src_ops_list)[0], op}; + }); + } else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) { + (void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges), + [dst_ops_list](const OperatorPtr &op) -> ControlEdge { + return {op, (*dst_ops_list)[0]}; + }); + } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { + control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); + } else if (src_ops_list->empty() || dst_ops_list->empty()) { + MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; + } else { + MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() + << " -> dst:" << dst_ops_list->size(); + error_ = FAILED; + return; + } + control_depend_cache_[node.get()] = control_edges; + +#ifdef DRAW_GE_GRAPH + DrawControlDepend(src_node, dest_node); +#endif +} + +bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { + // ignore apply node of return + if (name == "return" || name == "Depend") { + return false; + } + + if (name == "" && GetCNodeFuncName(node) == "switch_layer") { + return false; + } + + if (name == "Partial") { + return false; + } + + // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers + if (name == "make_tuple") { + ConvertMakeTuple(node); + return false; + } + + // As for nodes with multi outputs, convert tuple_getitem to OutHandle + if (name == "tuple_getitem") { + ConvertTupleGetItem(node); + return false; + } + + if (name == "ControlDepend") { + ConvertControlDependNode(node); + return false; + } + + return true; +} + +OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) { + std::string name = GetCNodeTargetFuncName(node); + if (!CheckCNode(name, node)) { + return nullptr; + } + + // get corresponding OpAdapter + OpAdapterPtr adpt = FindAdapter(node, training_); + if (adpt == nullptr) { + error_ = NOT_FOUND; + return nullptr; + } + + // get operator + OperatorPtr op = nullptr; + auto it_op = op_cache_.find(node.get()); + if (it_op != op_cache_.end()) { + op = it_op->second; + } else { + op = adpt->generate(node); + } + + // set attribute for primitive + (void)adpt->setAttr(op, node); + + // add into cache + (void)op_cache_.insert(std::make_pair(node.get(), op)); + + DrawCNode(node, adpt); + + return op_cache_[node.get()]; +} + +OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) { + // convert Parameter in ANF to variable in DataFlow + auto op = FindAdapter(node, training_)->generate(node); + op_cache_[node.get()] = op; + + // build index for parameter using name + std::string name = std::static_pointer_cast(node)->name(); + params_[name] = node; + + std::ostringstream ss; + ss << "op" << node.get(); + op_draw_name_[node.get()] = ss.str(); + compute_sout_ << ss.str() << "[shape=octagon, label=\"" << name << "\"]" << endl; + return op_cache_[node.get()]; +} + +Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) { + MS_EXCEPTION_IF_NULL(node); + ValuePtr value = node->value(); + MS_EXCEPTION_IF_NULL(value); + if (!value->isa() && !value->isa()) { + return FAILED; + } + + auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); + if (vec.empty()) { + return FAILED; + } + + std::shared_ptr> tuple_items = std::make_shared>(); + for (size_t i = 0; i < vec.size(); i++) { + MS_EXCEPTION_IF_NULL(vec[i]); + if (vec[i]->isa()) { + GeTensorPtr ge_tensor = transform::TransformUtil::ConvertTensor(vec[i]->cast(), kOpFormat_NCHW); + auto const_op = std::make_shared(node->fullname_with_scope() + "/const/inputs/" + std::to_string(i)); + (void)const_op->set_attr_value(*ge_tensor); + (void)const_op->update_output_desc_y(ge_tensor->GetTensorDesc()); + tuple_items->emplace_back(OutHandler(const_op, "")); + } else { + return FAILED; + } + } + if (tuple_items->empty()) { + return FAILED; + } + + tuple_out_handle_cache_[node.get()] = tuple_items; + return SUCCESS; +} + +OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) { + // convert valuenode in ANF to Const in DataFlow + // find paramerte referenced by SymbolicKeyInstance of valuenode + std::ostringstream ss; + ss << "op" << node.get(); + op_draw_name_[node.get()] = ss.str(); + compute_sout_ << ss.str() << "[label= \"" << node->value()->ToString() << "\" shape=ellipse]" << endl; + + if (TryConvertValueNodeToMultiConst(node) == SUCCESS) { + MS_LOG(INFO) << "Convert value node to multi Constant OP success"; + return nullptr; + } + + OpAdapterPtr adpt = FindAdapter(node, training_); + if (adpt == nullptr) { + error_ = NOT_FOUND; + return nullptr; + } + auto op = adpt->generate(node); + // set const's attrs + if (adpt->setAttr(op, "value", node->value()) != 0) { + MS_LOG(WARNING) << "set attr value for const failed"; + } + +#if (defined ENABLE_GE) + auto const_op = std::static_pointer_cast(op); + if (const_op == nullptr) { + MS_LOG(ERROR) << "Get Constant operator failed"; + return nullptr; + } + auto ge_tensor = const_op->get_attr_value(); + auto ge_desc = ge_tensor.GetTensorDesc(); + (void)const_op->update_output_desc_y(ge_desc); +#endif + + op_cache_[node.get()] = op; + return op_cache_[node.get()]; +} + +void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) { + if (nullptr == adpt || nullptr == node) { + MS_LOG(ERROR) << "Failed to draw apply node as adpt or node is nullptr!"; + return; + } + std::ostringstream ss; + ss << "op" << node.get(); + op_draw_name_[node.get()] = ss.str(); + + compute_sout_ << ss.str() << "[label=<"; + compute_sout_ << "" << endl; + + auto input_map = adpt->getInputMap(); + auto dyn_input_map = adpt->getDynInputMap(); + if (input_map.size() + dyn_input_map.size() > 0) { + compute_sout_ << ""; + for (auto &it : input_map) { + compute_sout_ << ""; + } + for (auto &it : dyn_input_map) { + compute_sout_ << ""; + } + compute_sout_ << "" << endl; + } + + compute_sout_ << "" << endl; + + // print attrs' values + auto atts = adpt->GetAttrsFromDrawGraph(); + for (auto &it : atts) { + compute_sout_ << ""; + } + + adpt->clearAttrVect(); + + compute_sout_ << "
" << it.second.name << "" << it.second.name << "
\"" << node->ToString() + << ":" << GetCNodeTargetFuncName(node) << "\"
\"" << it + << "\"
> shape=plaintext]" << endl; +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/convert.h b/mindspore/ccsrc/transform/graph_ir/convert.h new file mode 100644 index 0000000000..6fa27831bf --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/convert.h @@ -0,0 +1,258 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ +#define MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ + +#define DRAW_GE_GRAPH + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "transform/graph_ir/util.h" +#include "ir/tensor.h" +#include "transform/graph_ir/df_graph_manager.h" +#include "utils/config_manager.h" +#include "transform/graph_ir/op_declare.h" +#include "graph/operator_reg.h" +#ifdef OPEN_SOURCE +#include "ge/client/ge_api.h" +#else +#include "external/ge/ge_api.h" +#endif +#include "graph/tensor.h" +#include "ops/all_ops.h" + +namespace mindspore { +namespace transform { +class OpAdapterDesc { + public: + OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} + + OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} + + explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} + + OpAdapterDesc(const OpAdapterDesc &desc) { + this->train_ = desc.train_; + this->infer_ = desc.infer_; + } + + OpAdapterDesc(OpAdapterDesc &&desc) { + this->train_ = desc.train_; + this->infer_ = desc.infer_; + desc.train_ = nullptr; + desc.infer_ = nullptr; + } + + ~OpAdapterDesc() = default; + + OpAdapterPtr Get(bool train) const { return train ? train_ : infer_; } + + OpAdapterDesc &operator=(const OpAdapterDesc &desc) { + if (this != &desc) { + this->train_ = desc.train_; + this->infer_ = desc.infer_; + } + return *this; + } + + OpAdapterDesc &operator=(OpAdapterDesc &&desc) { + if (this != &desc) { + this->train_ = desc.train_; + this->infer_ = desc.infer_; + desc.train_ = nullptr; + desc.infer_ = nullptr; + } + return *this; + } + + private: + OpAdapterPtr train_; + OpAdapterPtr infer_; +}; + +using OpAdapterDescPtr = std::shared_ptr; +using TensorOrderMap = std::map>; + +class DfGraphConvertor { + public: + explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) + : anf_graph_(anf_graph), df_graph_(std::make_shared(anf_graph_->ToString())) { +#if (!defined ENABLE_GE) || (defined ENABLE_INFER) + training_ = anf_graph->has_flag("training"); +#else + training_ = ENABLE_TRAIN; +#endif + distribute_ = anf_graph->has_flag("broadcast_flag"); + if (anf_graph->has_flag("broadcast_flag")) { + ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION); + } else { + ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE); + } + + MS_LOG(INFO) << "Create DfGraphConvertor with training: " << training_ << ", distribute: " << distribute_; + } + + ~DfGraphConvertor() {} + + static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt) { + get_adpt_map()[name] = std::make_shared(adpt); + } + static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { + get_adpt_map()[name] = std::make_shared(train_adpt, infer_adpt); + } + + void DrawComputeGraph(const std::string &name) { + std::ofstream fout(name); + if (!fout.is_open()) { + MS_LOG(ERROR) << "Open file '" << name << "' failed!"; + return; + } + fout << compute_sout_.str(); + fout.close(); + } + void DrawInitGraph(const std::string &name) { + std::ofstream fout(name); + if (!fout.is_open()) { + MS_LOG(ERROR) << "Open file '" << name << "' failed!"; + return; + } + fout << init_sout_.str(); + fout.close(); + } + void DrawSaveCheckpointGraph(const std::string &name) { + std::ofstream fout(name); + if (!fout.is_open()) { + MS_LOG(ERROR) << "Open file '" << name << "' failed!"; + return; + } + fout << checkpoint_sout_.str(); + fout.close(); + } + + DfGraphConvertor &ConvertAllNode(); + DfGraphConvertor &BuildGraph(); + DfGraphConvertor &InitParam(const TensorOrderMap &tensors); + DfGraphConvertor &GenerateCheckpointGraph(); + DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); + void InitParamWithData(const TensorOrderMap &tensors); + void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); + void SetupBroadcast(const std::shared_ptr &broadcast, const std::vector &broadcast_desc, + const DfGraphPtr &broadcast_graph, std::vector broadcast_input); + void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); + void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input); + void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); + + DfGraphPtr GetComputeGraph(); + DfGraphPtr GetInitGraph(); + DfGraphPtr GetSaveCheckpointGraph(); + DfGraphPtr GetBroadcastGraph(); + static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); + static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); + int ErrCode() const { return static_cast(error_); } + + static std::unordered_map &get_adpt_map(); + bool is_training() const { return training_; } + void set_training(bool is_training) { training_ = is_training; } + + protected: + void InitLoopVar(std::vector *init_input); + + private: + std::ostringstream compute_sout_; + std::ostringstream init_sout_; + std::ostringstream checkpoint_sout_; + std::ostringstream restore_checkpoint_sout_; + std::unordered_map op_draw_name_; + + AnfNodePtr TraceTupleGetItem(const CNodePtr &node, unsigned int *index); + AnfNodePtr TraceMakeTuple(const CNodePtr &node, unsigned int index); + AnfNodePtr TraceDepend(const CNodePtr &node); + OutHandler TraceRealOp(AnfNodePtr node); + OutHandler GetHandler(const AnfNodePtr &node, const std::stack &index_stack, AnfNode *const draw_index); + OperatorPtr Convert(AnfNodePtr node); + OperatorPtr ConvertCNode(CNodePtr node); + std::vector ConvertDependNode(AnfNodePtr node); + AnfNodePtr GetRealOpNode(AnfNodePtr node); + std::vector GetDependNodes(const AnfNodePtr &node); + OperatorPtr ConvertParameter(AnfNodePtr node); + Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); + OperatorPtr ConvertValueNode(ValueNodePtr node); + void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); + void ConvertTupleGetItem(const CNodePtr node); + void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); + void ConvertControlDependNode(const CNodePtr node); + void ConvertMakeTuple(const CNodePtr node); + bool CheckCNode(const std::string &name, const CNodePtr node); + void TraceOutput(AnfNodePtr node); + void TraceOutputFromParameter(const AnfNodePtr &anf_out); + void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); + void SetNodeInput(AnfNodePtr node); + void SetOpControlInput(const AnfNodePtr node); + void UpdateOpDesc(AnfNodePtr node); + void SetSubgraph(AnfNodePtr node); + void ProcessSubgraph(AnfNodePtr node, const std::vector &inputs); + void BuildSaveCheckpointGraph(); + void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); + void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; + void AddGraphConstInput(const OperatorPtr &op); + + std::shared_ptr anf_graph_{nullptr}; + std::shared_ptr df_graph_{nullptr}; + std::shared_ptr init_graph_{nullptr}; + std::shared_ptr save_ckp_graph_{nullptr}; + std::shared_ptr restore_ckp_graph_{nullptr}; + std::shared_ptr broadcast_graph_{nullptr}; + std::unordered_map branches_map_; + std::unordered_map op_cache_; + std::unordered_map> control_depend_cache_; + /* record "tuple_getitem"<->"out_handler" mapping */ + std::unordered_map out_handle_cache_; + /* record "make_tuple"<->"out_handler vector" mapping */ + std::unordered_map>> tuple_out_handle_cache_; + std::unordered_map>> case_input_handle_cache_; + std::unordered_map params_; + std::unordered_map vars_; + std::vector> graph_outputs_; + std::vector graph_const_inputs_; + std::vector init_ops_; + std::vector broadcast_ops_; + std::vector inputs_; + OperatorPtr dataset_iter_getnext_; + Status error_ = SUCCESS; + bool training_ = false; + bool distribute_ = false; + bool use_inputs_ = false; +}; +} // namespace transform +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/df_graph_manager.cc b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.cc new file mode 100644 index 0000000000..29985d6784 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.cc @@ -0,0 +1,214 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/df_graph_manager.h" + +#include +#include +#include +#include + +#include "securec/include/securec.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/pipeline.h" +#include "utils/config_manager.h" +#ifndef NO_DLIB +#include "tdt/tsd_client.h" +#endif + +namespace mindspore { +namespace transform { +DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, + const OptionMap &options) + : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {} + +DfGraphManager::DfGraphManager() { + graph_id_ = 0; + graph_runner_ptr_ = nullptr; + sess_ptr_ = nullptr; +} + +DfGraphManager::~DfGraphManager() { + // in python fisrt destroy after atexit but in c++ destoy before atexit + DeleteGraphRunner(); + DeleteGeSession(); + ClearGraph(); + parse::python_adapter::set_python_env_flag(false); +} + +DfGraphManager &DfGraphManager::GetInstance() { + static DfGraphManager instance; + return instance; +} + +int DfGraphManager::GenerateId() { + graph_id_++; + if (graph_id_ <= 0) { + graph_id_ = 1; + } + MS_LOG(INFO) << "Generate graph Id : " << graph_id_; + return graph_id_; +} + +Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options) { + std::lock_guard lg(lock_); + if (name.empty()) { + MS_LOG(ERROR) << "The graph name is null, add graph failed"; + return Status::INVALID_ARGUMENT; + } + + if (graph_ptr == nullptr) { + MS_LOG(WARNING) << "The new graph {" << name << "}'s pointer is null, add graph failed"; + return Status::INVALID_ARGUMENT; + } + + int id = GenerateId(); + DfGraphWrapperPtr wrap_ptr = std::make_shared(name, id, graph_ptr, options); + auto ret = graphs_.emplace(name, wrap_ptr); + if (ret.second == false) { + MS_LOG(WARNING) << "The graph name:{ " << name << " }is already exists! The old graph will be overwritten!!"; + ret.first->second = wrap_ptr; + } + MS_LOG(INFO) << "Add graph " << name << " to GraphManager success!"; + return Status::SUCCESS; +} + +std::vector DfGraphManager::GetAllGraphs() { + std::lock_guard lg(lock_); + std::vector ret; + std::stringstream ss; + ss << "{ "; + for (auto it = graphs_.begin(); it != graphs_.end(); ++it) { + ss << it->first << ", "; + ret.emplace_back(it->second); + } + ss << "}"; + MS_LOG(INFO) << "Return graphs: " << ss.str(); + return ret; +} +std::set DfGraphManager::GetSavedGraphs() { return saved_graphs_; } + +void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); } + +DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) { + std::lock_guard lg(lock_); + if (name.empty()) { + MS_LOG(ERROR) << "The graph name is null"; + return nullptr; + } + + auto it = graphs_.find(name); + if (it == graphs_.end()) { + MS_LOG(INFO) << "Can't found graph name: " << name; + return nullptr; + } + MS_LOG(INFO) << "Return graph: " << name; + return it->second; +} + +void DfGraphManager::ClearGraph() noexcept { + std::lock_guard lg(lock_); + graphs_.clear(); + anf_graphs_.clear(); + MS_LOG(INFO) << "Remove all graphs in GraphManager"; +} + +void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) { + DfGraphWrapperPtr df_graph = GetGraphByName(name); + if (df_graph == nullptr) { + MS_LOG(ERROR) << "Can't found graph name: " << name; + return; + } + std::lock_guard lg(lock_); + anf_graphs_[df_graph->id_] = anf_graph_ptr; +} + +AnfGraphPtr DfGraphManager::GetAnfGraph(uint32_t graph_id) { + std::lock_guard lg(lock_); + auto iter = anf_graphs_.find(graph_id); + if (iter == anf_graphs_.end()) { + MS_LOG(ERROR) << "Can't found anf graph, graph_id = " << graph_id; + return nullptr; + } + + return iter->second; +} + +void DfGraphManager::EraseAnfGraph() { + std::lock_guard lg(lock_); + anf_graphs_.clear(); +} + +void DfGraphManager::SetGeSession(const std::shared_ptr &sess_ptr) { + std::lock_guard lg(lock_); + if (sess_ptr == nullptr) { + MS_LOG(WARNING) << "You are adding a empty Ge Session"; + } + + if (sess_ptr_ == nullptr) { + MS_LOG(INFO) << "Add a new Ge Session success"; + } else { + MS_LOG(INFO) << "Add a new Ge Session success, the old Ge Session will be overwritten!!"; + } + sess_ptr_ = sess_ptr; +} + +std::shared_ptr DfGraphManager::GetGeSession() { + std::lock_guard lg(lock_); + return sess_ptr_; +} + +void DfGraphManager::DeleteGeSession() noexcept { + std::lock_guard lg(lock_); + if (sess_ptr_ == nullptr) { + MS_LOG(INFO) << "Ge Session is not exist"; + } else { + sess_ptr_ = nullptr; + saved_graphs_.clear(); + MS_LOG(INFO) << "Delete Ge Session success"; + } +} + +void DfGraphManager::SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept { + std::lock_guard lg(lock_); + if (graph_runner_ptr == nullptr) { + MS_LOG(WARNING) << "You are adding a empty GraphRunner"; + } + + if (graph_runner_ptr_ == nullptr) { + MS_LOG(INFO) << "Add a new GraphRunner success"; + } else { + MS_LOG(INFO) << "Add a new GraphRunner success, the old GraphRunner will be overwritten!!"; + } + graph_runner_ptr_ = graph_runner_ptr; +} + +std::shared_ptr DfGraphManager::GetGraphRunner() { + std::lock_guard lg(lock_); + return graph_runner_ptr_; +} + +void DfGraphManager::DeleteGraphRunner() noexcept { + std::lock_guard lg(lock_); + if (graph_runner_ptr_ == nullptr) { + MS_LOG(INFO) << "GraphRunner is not exist"; + } else { + graph_runner_ptr_ = nullptr; + MS_LOG(INFO) << "Delete GraphRunner success"; + } +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h new file mode 100644 index 0000000000..8a574b7a04 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h @@ -0,0 +1,86 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_DF_GRAPH_MANAGER_H_ +#define TRANSFORM_DF_GRAPH_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "transform/graph_ir/types.h" +#include "ir/anf.h" + +namespace mindspore { +const char BROADCAST_GRAPH_NAME[] = "broadcast_subgraph"; + +namespace transform { +class GraphRunner; +using OptionMap = std::map; + +struct DfGraphWrapper { + public: + DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options); + ~DfGraphWrapper() {} + + std::string name_; + int id_; + DfGraphPtr graph_ptr_; + OptionMap options_ = {}; +}; + +using DfGraphWrapperPtr = std::shared_ptr; + +class DfGraphManager { + public: + ~DfGraphManager(); + void ClearGraph() noexcept; + + static DfGraphManager &GetInstance(); + Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options = {}); + std::vector GetAllGraphs(); + std::set GetSavedGraphs(); + void AddSavedGraphs(const std::string &id); + DfGraphWrapperPtr GetGraphByName(const std::string &name); + DfGraphManager(const DfGraphManager &) = delete; + void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr); + AnfGraphPtr GetAnfGraph(uint32_t graph_id); + std::shared_ptr GetGraphRunner(); + void SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept; + void DeleteGraphRunner() noexcept; + void SetGeSession(const std::shared_ptr &sess_ptr); + std::shared_ptr GetGeSession(); + void DeleteGeSession() noexcept; + void EraseAnfGraph(); + + private: + DfGraphManager(); + int GenerateId(); + + std::mutex lock_; + std::map graphs_; + std::set saved_graphs_; + int graph_id_; + std::map anf_graphs_; + std::shared_ptr graph_runner_ptr_; + std::shared_ptr sess_ptr_; +}; +} // namespace transform +} // namespace mindspore + +#endif // TRANSFORM_DF_GRAPH_MANAGER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/graph_builder.cc b/mindspore/ccsrc/transform/graph_ir/graph_builder.cc new file mode 100644 index 0000000000..6ee45feef8 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/graph_builder.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/graph_builder.h" + +#include +#include + +namespace mindspore { +namespace transform { +DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam ¶m) { + MS_LOG(INFO) << "BuildMDDatasetGraph."; + + // InitData + auto d = ge::op::InitData("init_data_tmp").set_attr_channel_name(param.queue_name()); + + // set graph inputs & outputs + std::vector inputs{d}; + std::vector outputs{d}; + DfGraphPtr dataset_graph = std::make_shared("dataset"); + (void)dataset_graph->SetInputs(inputs); + (void)dataset_graph->SetOutputs(outputs); + + return dataset_graph; +} + +Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase) { + Status ret; + std::string graph_name = phase; + + MS_LOG(INFO) << "BuildDatasetGraph begin. phase is " << phase; + MS_LOG(INFO) << "param is " << param.ToString() << "."; + + DfGraphPtr dataset_graph = BuildMDDatasetGraph(param); + ret = DfGraphManager::GetInstance().AddGraph(graph_name, dataset_graph); + if (ret != Status::SUCCESS) { + MS_LOG(ERROR) << "BuildDatasetGraph failed."; + } else { + MS_LOG(INFO) << "BuildDatasetGraph end."; + } + return ret; +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/graph_builder.h b/mindspore/ccsrc/transform/graph_ir/graph_builder.h new file mode 100644 index 0000000000..5162674242 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/graph_builder.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_GRAPH_BUILDER_H_ +#define TRANSFORM_GRAPH_BUILDER_H_ + +#include +#include +#include +#include +#include +#include "transform/graph_ir/types.h" +#include "transform/graph_ir/convert.h" + +namespace mindspore { +namespace transform { +Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase = "dataset"); +} // namespace transform +} // namespace mindspore + +#endif // TRANSFORM_GRAPH_BUILDER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/graph_runner.cc b/mindspore/ccsrc/transform/graph_ir/graph_runner.cc new file mode 100644 index 0000000000..d20c49a381 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/graph_runner.cc @@ -0,0 +1,213 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * Limitations under the License. + */ + +#include "transform/graph_ir/graph_runner.h" +#include +#include +#include +#include "utils/log_adapter.h" +#include "utils/config_manager.h" +#include "sys/time.h" +#include "utils/callbacks.h" +#include "utils/utils.h" +#include "./common.h" +#ifdef ENABLE_GE +#include "utils/callbacks_ge.h" +#endif + +#ifdef NO_GE_CLIENT +namespace ge { +Session::Session(const std::map &options) { + if (options.empty()) { + MS_LOG(ERROR) << "session input options is empty"; + } + sessionId_ = 0; +} +Session::~Session() {} +} // namespace ge +#endif + +namespace mindspore { +namespace transform { +std::shared_ptr GraphRunner::NewSession(const SessionOptions &sess_options) { + std::shared_ptr ret = std::make_shared(sess_options); + if (ret == nullptr) { + MS_LOG(ERROR) << "Create GE session failed"; + return nullptr; + } + MS_LOG(INFO) << "Create new GE session success"; + return ret; +} + +GraphRunner::GraphRunner(const GraphRunnerOptions &options) + : options_(options), graph_manager_(DfGraphManager::GetInstance()) { + if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) { + MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode"; + } + + if (options.sess_ptr != nullptr) { + sess_ = options.sess_ptr; + } else { + sess_ = NewSession(options.options); + if (sess_ == nullptr) { + MS_LOG(EXCEPTION) << "GraphRunner initialize failed!!"; + return; + } + } + +#if (defined ENABLE_GE) + // register the callback function + if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ge::GRAPH_SUCCESS) { + MS_LOG(EXCEPTION) << "register callback failed!"; + return; + } + + if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ge::GRAPH_SUCCESS) { + MS_LOG(EXCEPTION) << "register summary callback failed!"; + return; + } +#endif + + std::vector wrappers = graph_manager_.GetAllGraphs(); + if (wrappers.empty()) { + MS_LOG(INFO) << "The GraphManager is empty!!"; + return; + } + +#ifdef ENABLE_GE + for (auto &it : wrappers) { + std::set saved_graph = graph_manager_.GetSavedGraphs(); + auto iter_find = saved_graph.find(std::to_string(it->id_)); + if (iter_find != saved_graph.end()) { + continue; + } + MS_LOG(INFO) << "Add the graph " << (*it).name_ << " to GE, it's id is: " << (*it).id_; + graph_manager_.AddSavedGraphs(std::to_string(it->id_)); + (void)sess_->AddGraph(it->id_, *(it->graph_ptr_), it->options_); + } +#endif +} + +Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, + std::vector *outputs) { + std::string name = options.name; + if (name.empty()) { + MS_LOG(ERROR) << "The graph name is null"; + return Status::INVALID_ARGUMENT; + } + + DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name); + if (wrap_ptr == nullptr) { + MS_LOG(ERROR) << "Get graph form DfGraphManager failed!"; + return Status::NOT_FOUND; + } + + if (wrap_ptr->graph_ptr_ == nullptr) { + MS_LOG(WARNING) << "The graph is null"; + return Status::NOT_FOUND; + } + + // call ge::RunGraph() to exec a graph; + std::vector ge_inputs; + std::vector ge_outputs; + + (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs), + [](const GeTensorPtr &i) { return *i; }); + + MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs"; + + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); + +#ifdef ENABLE_GE + if (sess_ == nullptr) { + MS_LOG(ERROR) << "The GE session is null, can't run the graph!"; + return Status::FAILED; + } + + // The information of some nodes could be changed after fusion in some cases + // Therefore a graph needs to be rebuilt in above situation + if (sess_->IsGraphNeedRebuild(wrap_ptr->id_)) { + sess_->RemoveGraph(wrap_ptr->id_); + sess_->AddGraph(wrap_ptr->id_, *(wrap_ptr->graph_ptr_), wrap_ptr->options_); + } + + ge::Status ret = sess_->RunGraph(wrap_ptr->id_, ge_inputs, ge_outputs); + if (ret != ge::GRAPH_SUCCESS) { + MS_LOG(ERROR) << "Call GE RunGraph Failed, ret is: " << ret; + return Status::FAILED; + } +#else + ge_outputs.swap(ge_inputs); +#endif + + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Call GE RunGraph Success in " << cost << " us, the GE outputs num is: " << ge_outputs.size(); + + (void)std::transform(ge_outputs.begin(), ge_outputs.end(), std::back_inserter(*outputs), + [](const GeTensor &ge_tensor) { return std::make_shared(ge_tensor); }); + + return Status::SUCCESS; +} + +Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, + std::vector *const outputs) { + std::vector ge_inputs; + for (auto it : inputs) { + MS_LOG(INFO) << "inputs tensor's data size is: " << (*it).DataSize(); + auto shape = (*it).shape(); + std::string shape_str; + for (const auto &elem : shape) { + shape_str += std::to_string(elem); + shape_str += " "; + } + MS_LOG(INFO) << "inputs tensor's shape is: { " << shape_str << "}"; + + auto ge_tensor_ptr = TransformUtil::ConvertTensor(it, kOpFormat_NCHW); + if (ge_tensor_ptr != nullptr) { + ge_inputs.emplace_back(ge_tensor_ptr); + } else { + MS_LOG(INFO) << "Convert input Me tensor to Ge tensor failed. Abort this graph"; + return Status::FAILED; + } + } + + std::vector ge_outputs; + Status ret; + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + ret = RunGraph(options, ge_inputs, &ge_outputs); + } + if (ret != Status::SUCCESS) { + return ret; + } else { + // conver GeTensor to MeTensor + for (auto &it : ge_outputs) { + auto tensor = TransformUtil::ConvertGeTensor(it); + if (tensor != nullptr) { + outputs->emplace_back(tensor); + } + } + MS_LOG(INFO) << "Return Me tensor outputs num is: " << outputs->size(); + return Status::SUCCESS; + } +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/graph_runner.h b/mindspore/ccsrc/transform/graph_ir/graph_runner.h new file mode 100644 index 0000000000..92db9e1413 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/graph_runner.h @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_GRAPH_RUNNER_H_ +#define TRANSFORM_GRAPH_RUNNER_H_ + +#include +#include +#include +#include +#include + +#include "transform/graph_ir/types.h" +#include "transform/graph_ir/util.h" +#include "ir/tensor.h" +#include "transform/graph_ir/df_graph_manager.h" + +namespace mindspore { +namespace transform { +using SessionOptions = std::map; + +struct GraphRunnerOptions { + std::string target{"default_graph_runner"}; + SessionOptions options; + // if sess_ptr is nullptr, GraphRunner will create a new ge session + std::shared_ptr sess_ptr{nullptr}; +}; + +struct RunOptions { + // graph's name + std::string name; +}; + +class GraphRunner { + public: + explicit GraphRunner(const GraphRunnerOptions &options); + ~GraphRunner() { sess_ = nullptr; } + Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); + Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); + static std::shared_ptr NewSession(const SessionOptions &sess_options); + + private: + std::shared_ptr sess_; + transform::GraphRunnerOptions options_; + DfGraphManager &graph_manager_; +}; +} // namespace transform +} // namespace mindspore + +#endif // TRANSFORM_GRAPH_RUNNER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter.h b/mindspore/ccsrc/transform/graph_ir/op_adapter.h new file mode 100644 index 0000000000..358cbd20a1 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter.h @@ -0,0 +1,913 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_OP_ADAPTER_H_ +#define TRANSFORM_OP_ADAPTER_H_ + +#include +#include +#include +#include + +#include "transform/graph_ir/op_adapter_util.h" +#include "utils/utils.h" +namespace mindspore { +namespace transform { +static uint32_t CustomInferFunc(const Operator &) { return 0; } + +template +class OpAdapter : public BaseOpAdapter { + public: + using OpType = T; + OpAdapter() {} + explicit OpAdapter(const ExtraAttr &extra_attr) : extra_attr_(extra_attr) {} + ~OpAdapter() override {} + + bool IsCustomOp(const OperatorPtr &op) { + MS_EXCEPTION_IF_NULL(op); + auto it = cus_input_map_.find(op->GetOpType()); + if (it == cus_input_map_.end()) { + return false; + } + return true; + } + + Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { + MS_EXCEPTION_IF_NULL(op); + MS_EXCEPTION_IF_NULL(prim); + // Create the map of custom op from input index to input name. + std::unordered_map input_map; + auto value = prim->GetAttr("input_names"); + if (value == nullptr) { + cus_output_map_[prim->name()] = input_map; + return NOT_FOUND; + } + + auto input_names = GetValue>(value); + for (size_t i = 0; i < input_names.size(); ++i) { + // input_map begin form 1 + input_map[i + 1] = input_names[i]; + op->CustomInputRegister(input_names[i]); + } + + if (cus_input_map_.find(prim->name()) == cus_input_map_.end()) { + cus_input_map_[prim->name()] = input_map; + } + return SUCCESS; + } + + Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { + MS_EXCEPTION_IF_NULL(op); + MS_EXCEPTION_IF_NULL(prim); + // Create the map of custom op from output index to output name. + std::unordered_map output_map; + auto value = prim->GetAttr("output_names"); + if (value == nullptr) { + // generate a empty output_map for it + cus_output_map_[prim->name()] = output_map; + return NOT_FOUND; + } + + auto output_names = GetValue>(value); + for (size_t i = 0; i < output_names.size(); ++i) { + // output_map begin form 0 + output_map[i] = output_names[i]; + op->CustomOutputRegister(output_names[i]); + } + + if (cus_output_map_.find(prim->name()) == cus_output_map_.end()) { + cus_output_map_[prim->name()] = output_map; + } + return SUCCESS; + } + + // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs. + OperatorPtr GenerateCustomOp(const AnfNodePtr anf) { + MS_EXCEPTION_IF_NULL(anf); + auto node = anf->cast(); + if (node == nullptr) { + return nullptr; + } + + if (node->inputs().empty()) { + MS_LOG(EXCEPTION) << "length of node inputs is empty"; + } + + auto prim = GetValueNode(node->inputs()[0]); + MS_EXCEPTION_IF_NULL(prim); + auto op = std::make_shared(node->fullname_with_scope(), prim->name()); + if (GenerateCustomOpInputMap(op, prim) != SUCCESS) { + MS_LOG(WARNING) << "Custom op node has no input_names, op[" << prim->name() << "]."; + } + + if (GenerateCustomOpOutputMap(op, prim) != SUCCESS) { + MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "]."; + } + + op->CustomInferFuncRegister(CustomInferFunc); + + return op; + } + + OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) { + OperatorPtr op = nullptr; + // There are duplicate names in ANF graph, do not assign ANF node name to GE + // GE will generate unique name automatically + if (anf != nullptr && anf->fullname_with_scope() != "") { + MS_LOG(DEBUG) << anf->fullname_with_scope(); + op = std::make_shared(anf->fullname_with_scope()); + } else { + MS_LOG(DEBUG) << "no fullname_with_scope"; + op = std::make_shared(); + } + + // set dynamic output num if op use DYNAMIC_OUTPUT + if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) { + TypePtr type = anf->Type(); + if (type == nullptr) { + MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!"; + } + size_t num = type->isa() ? (type->cast>()->size()) : 1; + MS_LOG(INFO) << "create_dyn_output for node:" << anf->ToString() << ", type:" << type->ToString() + << ", num:" << num; + dyn_output_map_.begin()->second.create_dyn_output(op, static_cast(num)); + } + return op; + } + + OperatorPtr generate(const AnfNodePtr &anf) override { + OperatorPtr op = nullptr; + if (IsCustomCNode(anf)) { + op = GenerateCustomOp(anf); + } else { + op = GenerateNormalOp(anf); + } + return op; + } + + OperatorPtr generate(const std::string &op_name) override { return std::make_shared(op_name); } + + const std::unordered_map &getInputMap() override { return input_map_; } + const std::unordered_map &getInputAttrMap() override { return input_attr_map_; } + const std::unordered_map &getDynInputMap() override { return dyn_input_map_; } + const std::unordered_map &getOutputMap() override { return output_map_; } + const std::unordered_map &getDynSubgraphMap() override { return dyn_subgraph_map_; } + + Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr> branches) { + MS_EXCEPTION_IF_NULL(op); + auto it = dyn_subgraph_map_.find(index); + if (it != dyn_subgraph_map_.end()) { + auto size = branches->size(); + it->second.create_dyn_subgraph(op, static_cast(size)); + for (size_t i = 0; i < size; i++) { + it->second.set_subgraph(op, static_cast(i), std::make_shared((*branches)[i])); + } + return SUCCESS; + } + return NOT_FOUND; + } + + int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr> branches) override { + return static_cast(SetOpSubgraphFunc(op, index, branches)); + } + + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { + MS_EXCEPTION_IF_NULL(op); + MS_EXCEPTION_IF_NULL(input); + auto it = cus_input_map_.find(op->GetOpType()); + if (it == cus_input_map_.end()) { + return NOT_FOUND; + } + std::unordered_map &input_map = it->second; + + if ((input_map.find(index) != input_map.end())) { + MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index]; + (void)op->SetInput(input_map[index], *input); + return SUCCESS; + } + return NOT_FOUND; + } + + Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { + MS_EXCEPTION_IF_NULL(op); + auto it = input_map_.find(index); + if (it != input_map_.end()) { + MS_EXCEPTION_IF_NULL(input); + MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << it->second.name; + it->second.set_op(op, input); + return SUCCESS; + } + return NOT_FOUND; + } + + int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { + if (IsCustomOp(op)) { + auto cus_op = std::dynamic_pointer_cast(op); + return static_cast(SetCustomOpInput(cus_op, index, input)); + } else { + return static_cast(SetNormalOpInput(op, index, input)); + } + } + + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { + MS_EXCEPTION_IF_NULL(op); + auto it = cus_input_map_.find(op->GetOpType()); + if (it == cus_input_map_.end()) { + return NOT_FOUND; + } + + std::unordered_map &input_map = it->second; + if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) { + if (handle.out.empty()) { + MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index]; + (void)op->SetInput(input_map[index], *(handle.op)); + } else { + MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":" + << input_map[index]; + (void)op->SetInput(input_map[index], *(handle.op), handle.out); + } + return SUCCESS; + } + return NOT_FOUND; + } + + Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { + MS_EXCEPTION_IF_NULL(op); + auto it = input_map_.find(index); + if ((handle.op != nullptr) && (it != input_map_.end())) { + if (handle.out.empty()) { + MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << it->second.name; + it->second.set_op(op, handle.op); + } else { + MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":" + << it->second.name; + it->second.set_handle(op, handle); + } + return SUCCESS; + } + return NOT_FOUND; + } + + int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { + if (IsCustomOp(op)) { + auto cus_op = std::dynamic_pointer_cast(op); + return static_cast(SetCustomOpInput(cus_op, index, handle)); + } else { + return static_cast(SetNormalOpInput(op, index, handle)); + } + } + + int setInput(const OperatorPtr &op, int index, const std::shared_ptr> &handler_vec) override { + MS_EXCEPTION_IF_NULL(handler_vec); + if (IsCustomOp(op)) { + MS_LOG(ERROR) << "Custom Op do not support dynamic input"; + return static_cast(FAILED); + } + MS_EXCEPTION_IF_NULL(op); + auto it = dyn_input_map_.find(index); + if (it != dyn_input_map_.end()) { + it->second.create_dyn_input(op, static_cast(handler_vec->size())); + for (unsigned int i = 0; i < handler_vec->size(); ++i) { + OutHandler h = (*handler_vec)[i]; + MS_EXCEPTION_IF_NULL(h.op); + if (h.out.empty()) { + MS_LOG(DEBUG) << "Link op " << h.op->GetName() << " to " << op->GetName() << ":" << it->second.name; + it->second.set_op(op, (i) /* index start from 0 */, h.op); + } else { + MS_LOG(DEBUG) << "Link op " << h.op->GetName() << ":" << h.out << " to " << op->GetName() << ":" + << it->second.name; + it->second.set_handle(op, i, h); + } + } + return 0; + } + return static_cast(NOT_FOUND); + } + + OutHandler getOutput(const OperatorPtr &op, int index) override { + MS_EXCEPTION_IF_NULL(op); + if (IsCustomOp(op)) { + return getCustomOutput(op, index); + } + return getNormalOutput(op, index); + } + + OutHandler getCustomOutput(const OperatorPtr &op, int index) { + MS_EXCEPTION_IF_NULL(op); + auto it = cus_output_map_.find(op->GetOpType()); + if (it == cus_output_map_.end()) { + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT is not supported!"; + return OutHandler(); + } + + std::unordered_map &output_map = it->second; + + if ((output_map.find(index) != output_map.end())) { + return OutHandler(op, output_map[index]); + } + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT index(" << index << ")!"; + return OutHandler(); + } + + OutHandler getNormalOutput(const OperatorPtr &op, int index) { + MS_EXCEPTION_IF_NULL(op); + if (!dyn_output_map_.empty() && !output_map_.empty()) { + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!"; + return OutHandler(); + } + auto it = output_map_.find(index); + if (it != output_map_.end()) { + return OutHandler(op, it->second.name); + } else if (!dyn_output_map_.empty()) { + return OutHandler(op, dyn_output_map_.begin()->second.name + std::to_string(index)); + } else { + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT and DYN_OUTPUT index(" << index << ")!"; + return OutHandler(); + } + } + + Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { + MS_EXCEPTION_IF_NULL(type); + std::string format = "NCHW"; + if (op->GetOpType() == kExtractImagePatchesOpName) { + format = "NHWC"; + } + + auto desc = CreateOutputDesc(dyn_cast(shp), type, format); + if (desc == nullptr) { + MS_LOG(ERROR) << "Update output descriptor failed!"; + return FAILED; + } + + if (IsCustomOp(op)) { + if (cus_output_map_.find(op->GetOpType()) == cus_output_map_.end() || + (cus_output_map_[op->GetOpType()].empty())) { + MS_LOG(ERROR) << "This op does not create custom output map"; + return FAILED; + } + auto cus_op = std::dynamic_pointer_cast(op); + MS_EXCEPTION_IF_NULL(cus_op); + std::unordered_map output_map = cus_output_map_[op->GetOpType()]; + (void)cus_op->UpdateOutputDesc(output_map[0], *desc); + } else { + if (output_map_.empty()) { + MS_LOG(INFO) << "This op does not have output map"; + return FAILED; + } + output_map_.begin()->second.update_out_desc(op, *desc); + } + return SUCCESS; + } + + size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { + MS_EXCEPTION_IF_NULL(cus_op); + if (cus_output_map_.find(cus_op->GetOpType()) == cus_output_map_.end()) { + MS_LOG(ERROR) << "This op does not create custom output map"; + return 0; + } + size_t output_size = cus_output_map_[cus_op->GetOpType()].size(); + return output_size; + } + + std::shared_ptr CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, + const std::string &format) { + if (shape_ptr == nullptr) { + MS_LOG(ERROR) << "Shape ptr is nullptr"; + return nullptr; + } + + if (type == nullptr) { + MS_LOG(ERROR) << "Type ptr is nullptr"; + return nullptr; + } + + TypeId me_type = type->type_id(); + if (kObjectTypeTensorType == me_type) { + me_type = dyn_cast(type)->element()->type_id(); + } + auto desc = TransformUtil::GetGeTensorDesc(shape_ptr->shape(), me_type, format); + return desc; + } + + Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { + auto tuple_shp = dyn_cast(shp); + MS_EXCEPTION_IF_NULL(tuple_shp); + + size_t output_size = 0; + bool is_custom_op = IsCustomOp(op); + if (is_custom_op) { + output_size = GetCustomOpOutputSize(std::dynamic_pointer_cast(op)); + } else { + output_size = output_map_.size(); + } + + if (output_size == 0) { + MS_LOG(INFO) << "This op does not have output map"; + return FAILED; + } + + if (output_size != tuple_shp->shape().size()) { + MS_LOG(ERROR) << "output_map is not equal tuple_shape size"; + return FAILED; + } + std::string format = "NCHW"; + if (op->GetOpType() == kTopKOpName) { + format = "NHWC"; + } + for (size_t i = 0; i < tuple_shp->shape().size(); ++i) { + auto tuple_type = dyn_cast(type); + MS_EXCEPTION_IF_NULL(tuple_type); + TypePtr type_elem = tuple_type->elements()[i]; + + auto desc = CreateOutputDesc(dyn_cast(tuple_shp->shape()[i]), type_elem, format); + if (desc == nullptr) { + MS_LOG(ERROR) << "Create output descriptor failed!"; + return FAILED; + } + + if (is_custom_op) { + (void)std::dynamic_pointer_cast(op)->UpdateOutputDesc(cus_output_map_[op->GetOpType()][i], + *desc); + } else { + auto it = output_map_.find(i); + if (it != output_map_.end()) { + it->second.update_out_desc(op, *desc); + } + } + } + return SUCCESS; + } + + std::shared_ptr CreateNodeDesc(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + TypeId me_type = node->Type()->type_id(); + if (kObjectTypeTensorType == me_type) { + me_type = dyn_cast(node->Type())->element()->type_id(); + } + if (me_type <= kNumberTypeBegin || me_type >= kNumberTypeEnd) { + return nullptr; + } + + std::vector shape; + auto shape_ptr = dyn_cast(node->Shape()); + if (nullptr != shape_ptr) { + shape = shape_ptr->shape(); + } + + auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); + if (desc == nullptr) { + MS_LOG(ERROR) << "Update output descriptor failed!"; + return nullptr; + } + return desc; + } + + void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) { + if (op == nullptr) { + MS_LOG(ERROR) << "op is nullptr"; + return; + } + MS_EXCEPTION_IF_NULL(node); + + auto inputs = node->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + auto it = input_map_.find(i); + if (it != input_map_.end()) { + auto desc = CreateNodeDesc(inputs[i]); + if (desc == nullptr) { + continue; + } + if (op->GetOpType() == kExtractImagePatchesOpName) { + desc->SetFormat(ge::Format::FORMAT_NHWC); + } + it->second.update_input_desc(op, *desc); + } + } + } + + void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { + if (op == nullptr) { + MS_LOG(ERROR) << "op is nullptr"; + return; + } + MS_EXCEPTION_IF_NULL(node); + + if (cus_input_map_.find(op->GetOpType()) == cus_input_map_.end() || (cus_input_map_[op->GetOpType()].empty())) { + MS_LOG(ERROR) << "This op does not create custom input map"; + return; + } + + std::unordered_map &input_map = cus_input_map_[op->GetOpType()]; + auto inputs = node->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + if (input_map.find(i) != input_map.end()) { + auto desc = CreateNodeDesc(inputs[i]); + if (desc == nullptr) { + continue; + } + (void)op->UpdateInputDesc(input_map[i], *desc); + } + } + } + + void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(op); + MS_EXCEPTION_IF_NULL(node); + if (IsCustomOp(op)) { + auto cus_op = std::dynamic_pointer_cast(op); + UpdateCustomOpInputDesc(cus_op, node); + } else { + UpdateNormalOpInputDesc(op, node); + } + } + + void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) override { + if (op == nullptr) { + MS_LOG(ERROR) << "op is nullptr"; + return; + } + MS_EXCEPTION_IF_NULL(node); + MS_LOG(INFO) << "Op name is " << op->GetName(); + + auto normal_shape_ptr = dyn_cast(shp); + auto no_shape_ptr = dyn_cast(shp); + + if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) { + if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) { + return; + } + } else if (nullptr != dyn_cast(shp)) { + if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) { + return; + } + } else { + MS_LOG(WARNING) << "Update output desc failed, unknow output shape type"; + return; + } + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return; + } + + // Need to update input_desc while the output_desc is updated + updateInputDesc(op, node); + } + + int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override { + auto it = attr_map_.find(attrKey); + if (it != attr_map_.end()) { + // switch case for each avalilable attribute type + MS_LOG(INFO) << "Set attr: " << attrKey << "(" << it->second.name << "), value: " << attrValue->ToString(); + AddAttrToDrawGraph(attrKey + std::string("=") + attrValue->ToString()); + it->second.set_attr(op, attrValue); + return 0; + } + return static_cast(NOT_FOUND); + } + + int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { + enum ValueType { + SINGLE_VALUE = 0, + SEQUEUE_VALUE, + UNKNOWN_VALUE, + }; + + MS_EXCEPTION_IF_NULL(prim); + MS_EXCEPTION_IF_NULL(op); + + ValueType value_type = SINGLE_VALUE; + for (auto item : prim->attrs()) { + if (item.second->isa()) { + (void)op->SetAttr(item.first, GetValue(item.second)); + } else if (item.second->isa()) { + (void)op->SetAttr(item.first, GetValue(item.second)); + } else if (item.second->isa()) { + (void)op->SetAttr(item.first, GetValue(item.second)); + } else if (item.second->isa()) { + (void)op->SetAttr(item.first, GetValue(item.second)); + } else if (item.second->isa()) { + value_type = SEQUEUE_VALUE; + auto val_seq = item.second->cast(); + if ((*val_seq)[0]->isa()) { + (void)op->SetAttr(item.first, GetValue>(item.second)); + } else if ((*val_seq)[0]->isa()) { + (void)op->SetAttr(item.first, GetValue>(item.second)); + } else if ((*val_seq)[0]->isa()) { + (void)op->SetAttr(item.first, GetValue>(item.second)); + } else if ((*val_seq)[0]->isa()) { + (void)op->SetAttr(item.first, GetValue>(item.second)); + } else { + MS_LOG(EXCEPTION) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name() + << ", attr name: " << item.first << ", value: " << item.second->ToString(); + } + } else { + value_type = UNKNOWN_VALUE; + MS_LOG(WARNING) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name() + << ", attr name: " << item.first << ", value: " << item.second->ToString(); + return static_cast(NOT_FOUND); + } + + if (value_type == SINGLE_VALUE) { + AddAttrToDrawGraph(item.first + std::string("=") + item.second->ToString()); + } else if (value_type == SEQUEUE_VALUE) { + AddAttrToDrawGraph(item.first + std::string("=") + "[...]"); + } + } + return 0; + } + + int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { + int ret = 0; + MS_EXCEPTION_IF_NULL(prim); + MS_EXCEPTION_IF_NULL(op); + for (auto &it : attr_map_) { + auto value = prim->GetAttr(it.first); + if (value != nullptr) { + // set attr from primitive + ret = setAttr(op, it.first, value); + if (ret) { + return ret; + } + } else { + // set attr from extra_attr + auto it_extra = extra_attr_.find(it.first); + if (it_extra != extra_attr_.end()) { + ret = setAttr(op, it.first, it_extra->second); + if (ret) { + return ret; + } + } + } + } + return 0; + } + + int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { + int ret = 0; + if (IsCustomPrim(prim)) { + auto cus_op = std::dynamic_pointer_cast(op); + ret = SetCustomOpAttr(cus_op, prim); + } else { + ret = SetNormalOpAttr(op, prim); + } + return ret; + } + + int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { + // no attribute for lonely node + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return 0; + } + + auto cnode = node->cast(); + if (cnode == nullptr) { + return 0; + } + + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + return 0; + } + + // get Attr T from abstract of anfnode first, + // if attr "T" appears in primitive, the primitive T will cover this one + if (attr_map_.find("T") != attr_map_.end()) { + // get dtype from inputs[1], if the node has no inputs, set the attr T with output dtype + TypePtr type; + if (inputs.size() > 1) { + type = inputs[1]->Type(); + } else { + type = node->Type(); + } + if (type != nullptr) { + (void)setAttr(op, "T", MakeValue(type)); + } + } + + // set attr from primitive and ExtraAttr + if (IsValueNode(inputs[0])) { + // set attr from primitive + PrimitivePtr prim = GetValueNode(inputs[0]); + int ret = setAttr(op, prim); + if (ret != 0) { + return ret; + } + } + + // set attr from const input + for (auto &it : input_attr_map_) { + if (inputs.size() <= it.first || !inputs[it.first]->isa()) { + continue; + } + auto const_value = GetValueNode(inputs[it.first]); + MS_LOG(INFO) << "Set attr: input_" << it.first << "(" << it.second.name + << "), value: " << const_value->ToString(); + if (const_value->isa()) { + continue; + } + AddAttrToDrawGraph(it.second.name + std::string("=") + const_value->ToString()); + it.second.set_attr(op, const_value); + } + return 0; + } + + std::unordered_map GetExtraAttr() override { return extra_attr_; } + + private: + template + static S ConvertAny(const ValuePtr &value, const AnyTraits &) { + return GetValue(value); + } + + // specialization for reverse bool + static bool ConvertAny(const ValuePtr &value, const AnyTraits &, bool reverse) { + return reverse != GetValue(value); + } + + template + static Q ConvertAny(const ValuePtr &value, const AnyTraits

&traits_from, const AnyTraits &traits_to) { + return ConvertAnyUtil(value, traits_from, traits_to); + } + + // specialization for tensor + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits &traits) { + // To-DO the format may read from ME tensor + return ConvertAnyUtil(value, traits); + } + + // specialization for int + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { + return static_cast(GetValue(value)); + } + + // specialization for int or tuple broadcast to Vector + static std::vector ConvertAny(const ValuePtr &value, const std::string &name, + const AnyTraits> anyTraitsInt) { + return ConvertAnyUtil(value, name, anyTraitsInt); + } + + static std::vector> ConvertAny(const ValuePtr &value, + const AnyTraits>>) { + MS_EXCEPTION_IF_NULL(value); + MS_LOG(INFO) << "Value: " << value->type_name(); + std::vector> list; + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got " << value->type_name(); + } + auto vec = value->cast(); + MS_EXCEPTION_IF_NULL(vec); + for (auto &it : vec->value()) { + MS_EXCEPTION_IF_NULL(it); + if (!it->isa()) { + MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name(); + } + auto sub_vector = it->cast(); + std::vector sublist; + for (auto &item : sub_vector->value()) { + sublist.push_back(static_cast(GetValue(item))); + } + list.push_back(sublist); + } + return list; + } + + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>>, + const AnyTraits>) { + MS_EXCEPTION_IF_NULL(value); + MS_LOG(DEBUG) << "Value: " << value->type_name(); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueList, but got " << value->type_name(); + } + auto vec = value->cast(); + std::vector list; + for (auto &it : vec->value()) { + MS_EXCEPTION_IF_NULL(it); + if (!it->isa()) { + MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name(); + } + auto sub_vector = it->cast(); + for (auto &item : sub_vector->value()) { + list.push_back(static_cast(GetValue(item))); + } + } + return list; + } + + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>, + const AnyTraits>) { + MS_EXCEPTION_IF_NULL(value); + MS_LOG(INFO) << "Value: " << value->type_name(); + std::vector list; + if (value->isa()) { + auto vec = value->cast(); + MS_EXCEPTION_IF_NULL(vec); + for (auto &it : vec->value()) { + list.push_back(static_cast(GetValue(it))); + } + return list; + } + if (value->isa()) { + list.push_back(static_cast(GetValue(value))); + return list; + } + MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name(); + } + + static std::string ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, + const AnyTraits anyTraitsStr) { + return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr); + } + + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, + const AnyTraits anyTraitsFlo) { + return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo); + } + + static std::vector ConvertAny(const ValuePtr &value, const std::string &format, + const AnyTraits> anyTraitsVec, + const AnyTraits anyTraitsInt) { + return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt); + } + + // convert value list for value tuple to vector + template + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits

&anyTraitsP, + const AnyTraits> anyTraitsQ) { + return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ); + } + + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { + auto name = GetValue(value); + auto it = enum_map_.find(name); + int v = 0; + if (it != enum_map_.end()) { + v = it->second; + } + return v; + } + + static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsGE) { + return ConvertAnyUtil(value, anyTraitsGE); + } + + // convert any value to tensor + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsValue) { + return ConvertAnyUtil(value, anyTraitsValue); + } + + static const std::unordered_map input_map_; + static const std::unordered_map dyn_input_map_; + static const std::unordered_map output_map_; + static const std::unordered_map dyn_output_map_; + static const std::unordered_map dyn_subgraph_map_; + static const std::unordered_map attr_map_; + static const std::unordered_map enum_map_; + // convert input from anf graph to Attr in Operators + static const std::unordered_map input_attr_map_; + static std::unordered_map> cus_input_map_; + static std::unordered_map> cus_output_map_; + std::unordered_map extra_attr_; + std::unordered_map name_counts_; +}; + +template +const std::unordered_map OpAdapter::input_map_; +template +const std::unordered_map OpAdapter::dyn_input_map_; +template +const std::unordered_map OpAdapter::output_map_; +template +const std::unordered_map OpAdapter::dyn_output_map_; +template +const std::unordered_map OpAdapter::dyn_subgraph_map_; +template +const std::unordered_map OpAdapter::attr_map_; +template +const std::unordered_map OpAdapter::enum_map_; +template +const std::unordered_map OpAdapter::input_attr_map_; +template +std::unordered_map> OpAdapter::cus_input_map_; +template +std::unordered_map> OpAdapter::cus_output_map_; + +// specialization for method +} // namespace transform +} // namespace mindspore + +#endif // TRANSFORM_OP_ADAPTER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h new file mode 100644 index 0000000000..77e28dda94 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h @@ -0,0 +1,198 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_OP_ADAPTER_BASE_H_ +#define TRANSFORM_OP_ADAPTER_BASE_H_ + +#include +#include +#include +#include +#include +#include + +#include "transform/graph_ir/util.h" +#include "ir/anf.h" +#include "ir/primitive.h" +#include "ir/value.h" +#include "transform/graph_ir/types.h" +#ifdef ENABLE_GE +#ifdef OPEN_SOURCE +#include "graph/types.h" +#endif +#endif + +#include "graph/operator_reg.h" +#ifdef OPEN_SOURCE +#include "ge/client/ge_api.h" +#else +#include "external/ge/ge_api.h" +#endif +#include "graph/tensor.h" +#include "transform/graph_ir/all_ops.h" + +namespace ge { +class CustomOperator : public Operator { + public: + CustomOperator(const string &name, const string &type) : Operator(name, type) {} + + ~CustomOperator() override{}; + + void CustomInputRegister(const string &name) { Operator::InputRegister(name); } + + void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } + + void CustomInferFuncRegister(const std::function &func) { + Operator::InferFuncRegister(func); + } +}; +} // namespace ge + +namespace mindspore { +namespace transform { +using CusOperatorPtr = std::shared_ptr; +using CustomOperator = ge::CustomOperator; + +struct OutHandler { + OperatorPtr op; + std::string out; + OutHandler() : op(nullptr), out("") {} + OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} +}; + +struct ControlEdge { + OperatorPtr src_op; + OperatorPtr dest_op; +}; + +using AttrFunc = std::function; +using OutputFunc = std::function; +using InputOpFunc = std::function; +using InputHandleFunc = std::function; +using CreateDynInputOpFunc = std::function; +using DynInputOpFunc = std::function; +using DynInputHandleFunc = std::function; +using UpdateOutputDescFunc = std::function; +using CreateDynOutputOpFunc = std::function; +using CreateDynSubGraphFunc = std::function; +using DynSubGraphFunc = std::function; + +struct AttrDesc { + std::string name; + AttrFunc set_attr; +}; + +struct InputDesc { + std::string name; + InputOpFunc set_op; + InputHandleFunc set_handle; + UpdateOutputDescFunc update_input_desc; +}; + +struct DynInputDesc { + std::string name; + CreateDynInputOpFunc create_dyn_input; + DynInputOpFunc set_op; + DynInputHandleFunc set_handle; +}; + +struct DynSubGraphDesc { + std::string name; + CreateDynSubGraphFunc create_dyn_subgraph; + DynSubGraphFunc set_subgraph; +}; + +struct OutputDesc { + std::string name; + UpdateOutputDescFunc update_out_desc; +}; + +struct DynOutputDesc { + std::string name; + CreateDynOutputOpFunc create_dyn_output; +}; + +class BaseOpAdapter { + public: + virtual ~BaseOpAdapter() {} + virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; + virtual OperatorPtr generate(const std::string &type) { return std::make_shared(type); } + virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr> branches) = 0; + virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; + virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; + virtual int setInput(const OperatorPtr &op, int index, + const std::shared_ptr> &handler_vec) = 0; + virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; + virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; + virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; + virtual std::unordered_map GetExtraAttr() = 0; + template ::value>::type> + int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr &attrValue) { + return setAttr(op, attrKey, MakeValue(attrValue)); + } + template ::value>::type> + int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { + return setAttr(op, attrKey, MakeValue(attrValue)); + } + virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; + virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) = 0; + virtual const std::unordered_map &getInputMap() = 0; + virtual const std::unordered_map &getInputAttrMap() = 0; + virtual const std::unordered_map &getDynInputMap() = 0; + virtual const std::unordered_map &getOutputMap() = 0; + virtual const std::unordered_map &getDynSubgraphMap() = 0; + void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } + const std::vector &GetAttrsFromDrawGraph() const { return attrs_vec_; } + void clearAttrVect() { attrs_vec_.clear(); } + + private: + std::vector attrs_vec_; +}; + +using OpAdapterPtr = std::shared_ptr; + +enum AttrType { + ATTR_INT = 0, + ATTR_FLOAT, + ATTR_DOUBLE, + ATTR_STRING, + ATTR_TENSOR, + ATTR_BOOL, + ATTR_LIST_INT, + ATTR_LIST_ANY_INT, + ATTR_ENUM +}; + +struct GeEnum {}; +struct TFType {}; +struct GEType {}; + +// declare Any type +template +struct AnyTraits { + using type = T; +}; + +template <> +struct AnyTraits { + using type = int64_t; +}; + +using ExtraAttr = std::unordered_map; +} // namespace transform +} // namespace mindspore +#endif // TRANSFORM_OP_ADAPTER_BASE_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc new file mode 100644 index 0000000000..78f1f263de --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc @@ -0,0 +1,264 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/op_adapter_util.h" + +#include +#include +#include + +#include "utils/utils.h" +#include "transform/graph_ir/op_adapter_base.h" + +namespace mindspore { +namespace transform { +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &) { + // To-DO the format may read from ME tensor + MS_EXCEPTION_IF_NULL(value); + auto me_tensor = value->cast(); + auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_NCHW); + return ge_tensor == nullptr ? GeTensor() : *ge_tensor; +} + +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, + const AnyTraits>) { + MS_EXCEPTION_IF_NULL(value); + std::vector list; + if (name == "pad") { + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); + } + auto vec = value->cast(); + list.resize(vec->value().size() + 2); + list[0] = 1; + list[1] = 1; + (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2, + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); + } else { + int64_t data = GetValue(value); + int size = 2; // 2 int in list + list = TransformUtil::ConvertIntToList(data, size); + } + + return list; +} + +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + auto vec = value->cast(); + if (nullptr == vec) { + MS_LOG(EXCEPTION) << "not ValueTuplePtr"; + } + std::ostringstream buffer; + int i = 0; + for (auto &it : vec->value()) { + if (i != 0) { + buffer << ","; + } + buffer << GetValue(it); + i++; + } + return buffer.str(); +} + +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + auto vec = value->cast(); + if (nullptr == vec) { + MS_LOG(EXCEPTION) << "not ValueTuplePtr"; + } + std::vector list; + list.resize(vec->value().size()); + (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); + return list; +} + +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, + const AnyTraits>, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + auto vec = value->cast(); + if (nullptr == vec) { + MS_LOG(EXCEPTION) << "not ValueTuplePtr"; + } + std::vector list; + list.resize(vec->value().size()); + (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); + if (format == kOpFormat_NHWC) { + if (list.size() < 4) { + MS_LOG(EXCEPTION) << "The size of list is less than 4"; + } else { + int64_t temp = list[1]; + list[1] = list[2]; + list[2] = list[3]; + list[3] = temp; + } + } + return list; +} + +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString() + << ", type: " << value->type_name() << ", value should be a Typeptr"; + } + auto type = value->cast(); + MS_EXCEPTION_IF_NULL(type); + TypeId me_type = type->type_id(); + if (kObjectTypeTensorType == me_type) { + me_type = dyn_cast(type)->element()->type_id(); + } + return TransformUtil::ConvertDataType(me_type); +} + +GeTensor VectorToTensorUtil(const ValuePtr &value) { + // convert tuple or list to ge tensor, only supported one dim for now + MS_EXCEPTION_IF_NULL(value); + auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); + if (vec.empty()) { + MS_LOG(WARNING) << "Convert a none tuple to an empty ge tensor"; + return GeTensor(); + } + MS_EXCEPTION_IF_NULL(vec[0]); + if (vec[0]->isa()) { + MS_LOG(INFO) << "convert value to tensor with data type = Int32"; + auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); + auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeInt32, kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; + } + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); + } else if (vec[0]->isa()) { + MS_LOG(INFO) << "convert value to tensor with data type = Float32"; + auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); + auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeFloat32, kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; + } + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); + } else if (vec[0]->isa()) { + MS_LOG(INFO) << "convert value to tensor with data type = Bool"; + // We use uint8_t to save bool type data + auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); + auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeBool, kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; + } + return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); + } else { + MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name(); + } + + return GeTensor(); +} + +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + // convert me tensor to ge tensor + return ConvertAnyUtil(value, AnyTraits()); + } else if (value->isa() || value->isa()) { + return VectorToTensorUtil(value); + } else if (value->isa()) { + // convert scalar Int to GeTensor + MS_LOG(INFO) << "convert scalar to tensor with data type = Int32"; + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + auto v = GetValue(value); + desc.SetRealDimCnt(0); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); + } else if (value->isa()) { + // convert scalar Int64 to GeTensor + MS_LOG(INFO) << "convert scalar to tensor with data type = Int64"; + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + auto v = GetValue(value); + desc.SetRealDimCnt(0); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); + } else if (value->isa()) { + // convert scalar FP32 to GeTensor + MS_LOG(INFO) << "convert scalar to tensor with data type = FP32"; + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); + auto v = GetValue(value); + desc.SetRealDimCnt(0); + return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); + } else if (value->isa()) { + // convert scalar FP32 to GeTensor + MS_LOG(INFO) << "convert scalar to tensor with data type = Bool"; + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL); + auto v = GetValue(value); + desc.SetRealDimCnt(0); + return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); + } else if (value->isa()) { + // convert String to GeTensor + MS_LOG(INFO) << "convert string to tensor with data type = String"; + std::string v = GetValue(value); + std::vector ge_shape; + GeShape shape(ge_shape); + GeTensorDesc desc(shape, ge::FORMAT_NCHW, ge::DT_STRING); + GeTensor str_tensor(desc); + str_tensor.SetData(v); + return str_tensor; + } else { + MS_LOG(WARNING) << "Unsupported value type: " << value->type_name() + << " to convert to tensor. Value: " << value->ToString(); + } + return GeTensor(); +} + +bool IsCustomPrim(const PrimitivePtr &prim) { + if (prim == nullptr) { + return false; + } + + ValuePtr flag = prim->GetAttr("_custom_op_flag"); + if (flag == nullptr) { + return false; + } + + bool is_custom_op = GetValue(flag); + if (!is_custom_op && prim->GetAttr("_custom_op_impl_config_path") != nullptr) { + MS_LOG(EXCEPTION) << "The custom op flag is false, but the op information config path is not null, non-custom op " + "can not assign the op information config path."; + } + + return is_custom_op; +} + +bool IsCustomCNode(const AnfNodePtr &anf) { + if (anf == nullptr) { + return false; + } + auto node = anf->cast(); + if (node == nullptr) { + return false; + } + if (node->inputs().empty()) { + MS_LOG(EXCEPTION) << "length of node inputs is empty"; + } + MS_EXCEPTION_IF_NULL(node->inputs()[0]); + if (!node->inputs()[0]->isa()) { + return false; + } + auto cus_prim = GetValueNode(node->inputs()[0]); + if (cus_prim == nullptr) { + return false; + } + + return IsCustomPrim(cus_prim); +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h new file mode 100644 index 0000000000..0a0d745ba2 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h @@ -0,0 +1,66 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_OP_ADAPTER_UTIL_H_ +#define TRANSFORM_OP_ADAPTER_UTIL_H_ + +#include +#include + +#include "transform/graph_ir/op_adapter_base.h" + +namespace mindspore { +namespace transform { +template +static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits

&, const AnyTraits &) { + return static_cast(GetValue

(value)); +} + +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &traits); + +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, + const AnyTraits>); + +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); + +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); + +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, + const AnyTraits>, const AnyTraits); + +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits); + +template +std::vector ConvertAnyUtil(const ValuePtr &value, AnyTraits

, const AnyTraits>) { + if (!value->isa() && !value->isa()) { + MS_LOG(EXCEPTION) << "error convert Value to vector for value: " << value->ToString() + << ", type: " << value->type_name() << ", value should be a tuple or list"; + } + auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); + std::vector data; + for (auto &it : vec) { + data.push_back(ConvertAnyUtil(it, AnyTraits

(), AnyTraits())); + } + return data; +} + +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits); + +bool IsCustomPrim(const PrimitivePtr &prim); +bool IsCustomCNode(const AnfNodePtr &node); +} // namespace transform +} // namespace mindspore +#endif // TRANSFORM_OP_ADAPTER_UTIL_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare.cc new file mode 100644 index 0000000000..e3751e0c92 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.cc @@ -0,0 +1,1330 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/op_declare.h" + +#include + +#include "transform/graph_ir/all_ops.h" +#include "utils/utils.h" + +namespace mindspore { +namespace transform { +#define INPUT_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::input_map_ +#define EMPTY_INPUT_MAP std::unordered_map() +#define INPUT_DESC(name) \ + { \ +#name, \ + [](const OperatorPtr op, const OperatorPtr input) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_input_##name(*input); \ + }, \ + [](const OperatorPtr op, const OutHandler& handle) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_input_##name(*(handle.op), handle.out); \ + }, \ + [](const OperatorPtr op, const GeTensorDesc desc) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->update_input_desc_##name(desc); \ + } \ + } + +#define DYN_INPUT_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_input_map_ +#define DYN_INPUT_DESC(name) \ + { \ +#name, \ + [](const OperatorPtr op, unsigned int num) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->create_dynamic_input_##name(num); \ + }, \ + [](const OperatorPtr op, unsigned int index, const OperatorPtr input) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_dynamic_input_##name(index, *input); \ + }, \ + [](const OperatorPtr op, unsigned int index, const OutHandler& handle) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_dynamic_input_##name(index, *(handle.op), handle.out); \ + } \ + } + +#define DYN_SUBGRAPH_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_subgraph_map_ +#define DYN_SUBGRAPH_DESC(name) \ + { \ +#name, \ + [](const OperatorPtr op, unsigned int num) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->create_dynamic_subgraph_##name(num); \ + }, \ + [](const OperatorPtr op, unsigned int index, const DfGraphPtr graph) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_dynamic_subgraph_builder_##name(index, [graph](){return *graph;}); \ + } \ + } + +#define ATTR_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::attr_map_ +#define EMPTY_ATTR_MAP std::unordered_map() +#define ATTR_DESC(name, ...) \ + { \ +#name, \ + [](const OperatorPtr op, const ValuePtr& value) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_attr_##name(ConvertAny(value, __VA_ARGS__)); \ + } \ + } + +#define INPUT_ATTR_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::input_attr_map_ + +#define OUTPUT_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::output_map_ +#define OUTPUT_DESC(name) \ + { \ +#name, \ + [](const OperatorPtr op, const GeTensorDesc desc) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->update_output_desc_##name(desc); \ + } \ + } + +#define DYN_OUTPUT_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_output_map_ + +#define DYN_OUTPUT_DESC(name) \ + { \ +#name, \ + [](const OperatorPtr op, unsigned int num) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->create_dynamic_output_##name(num); \ + } \ + } + +template <> +std::unordered_map> OpAdapter::cus_input_map_{}; +template <> +std::unordered_map> OpAdapter::cus_output_map_{}; + +// --------------specialization for each operator---------- +// const +INPUT_MAP(Const) = EMPTY_INPUT_MAP; +ATTR_MAP(Const) = {{"value", ATTR_DESC(value, AnyTraits())}}; +OUTPUT_MAP(Const) = {{0, OUTPUT_DESC(y)}}; + +// Assign +INPUT_MAP(Assign) = {{1, INPUT_DESC(ref)}, {2, INPUT_DESC(value)}}; +ATTR_MAP(Assign) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Assign) = {{0, OUTPUT_DESC(ref)}}; + +// Constant +INPUT_MAP(Constant) = EMPTY_INPUT_MAP; +ATTR_MAP(Constant) = {{"value", ATTR_DESC(value, AnyTraits())}}; +OUTPUT_MAP(Constant) = {{0, OUTPUT_DESC(y)}}; + +// ApplyMomentumD +INPUT_MAP(ApplyMomentumD) = { + {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}}; +ATTR_MAP(ApplyMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits())}, + {"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; + +// ScalarSummary +INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}}; +ATTR_MAP(Summary) = EMPTY_ATTR_MAP; + +// Data +INPUT_MAP(Data) = EMPTY_INPUT_MAP; +ATTR_MAP(Data) = EMPTY_ATTR_MAP; + +// BatchNorm +INPUT_MAP(BatchNorm) = {{1, INPUT_DESC(x)}, + {2, INPUT_DESC(scale)}, + {3, INPUT_DESC(offset)}, + {4, INPUT_DESC(mean)}, + {5, INPUT_DESC(variance)}}; +ATTR_MAP(BatchNorm) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"epsilon", ATTR_DESC(epsilon, AnyTraits())}, + {"is_training", ATTR_DESC(is_training, AnyTraits())}}; +OUTPUT_MAP(BatchNorm) = {{0, OUTPUT_DESC(y)}, + {1, OUTPUT_DESC(batch_mean)}, + {2, OUTPUT_DESC(batch_variance)}, + {3, OUTPUT_DESC(reserve_space_1)}, + {4, OUTPUT_DESC(reserve_space_2)}}; + +// BatchNormGrad +INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)}, + {2, INPUT_DESC(x)}, + {3, INPUT_DESC(scale)}, + {4, INPUT_DESC(reserve_space_1)}, + {5, INPUT_DESC(reserve_space_2)}}; +ATTR_MAP(BatchNormGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"epsilon", ATTR_DESC(epsilon, AnyTraits())}, + {"is_training", ATTR_DESC(is_training, AnyTraits())}}; +OUTPUT_MAP(BatchNormGrad) = {{0, OUTPUT_DESC(x_backprop)}, + {1, OUTPUT_DESC(scale_backprop)}, + {2, OUTPUT_DESC(offset_backprop)}, + {3, OUTPUT_DESC(reserve_space_4)}, + {4, OUTPUT_DESC(reserve_space_5)}}; + +// Relu +INPUT_MAP(Relu) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Relu) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Relu) = {{0, OUTPUT_DESC(y)}}; + +// Elu +INPUT_MAP(Elu) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Elu) = {{"alpha", ATTR_DESC(alpha, AnyTraits())}}; +OUTPUT_MAP(Elu) = {{0, OUTPUT_DESC(y)}}; + +// EluGrad +INPUT_MAP(EluGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(activations)}}; +ATTR_MAP(EluGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(EluGrad) = {{0, OUTPUT_DESC(y)}}; + +// PRelu +INPUT_MAP(PRelu) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(weight)}}; +ATTR_MAP(PRelu) = EMPTY_ATTR_MAP; +OUTPUT_MAP(PRelu) = {{0, OUTPUT_DESC(y)}}; + +// PReluGrad +INPUT_MAP(PReluGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(features)}, {3, INPUT_DESC(weights)}}; +ATTR_MAP(PReluGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(PReluGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(da)}}; + +// Sigmoid +INPUT_MAP(Sigmoid) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Sigmoid) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sigmoid) = {{0, OUTPUT_DESC(y)}}; + +// SigmoidGrad +INPUT_MAP(SigmoidGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(SigmoidGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SigmoidGrad) = {{0, OUTPUT_DESC(z)}}; + +// L2NormalizeGrad +INPUT_MAP(L2NormalizeGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(dy)}}; +ATTR_MAP(L2NormalizeGrad) = { + {"axis", ATTR_DESC(dim, AnyTraits>(), AnyTraits>())}, + {"epsilon", ATTR_DESC(eps, AnyTraits())}}; +OUTPUT_MAP(L2NormalizeGrad) = {{0, OUTPUT_DESC(dx)}}; + +// LarsV2Update +INPUT_MAP(LarsV2Update) = {{1, INPUT_DESC(w)}, + {2, INPUT_DESC(g)}, + {3, INPUT_DESC(w_square_sum)}, + {4, INPUT_DESC(g_square_sum)}, + {5, INPUT_DESC(weight_decay)}, + {6, INPUT_DESC(learning_rate)}}; +ATTR_MAP(LarsV2Update) = {{"epsilon", ATTR_DESC(epsilon, AnyTraits())}, + {"hyperpara", ATTR_DESC(hyperpara, AnyTraits())}, + {"use_clip", ATTR_DESC(use_clip, AnyTraits())}}; +OUTPUT_MAP(LarsV2Update) = {{0, OUTPUT_DESC(g_new)}}; + +// L2Normalize +INPUT_MAP(L2Normalize) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(L2Normalize) = { + {"axis", ATTR_DESC(axis, AnyTraits>(), AnyTraits>())}, + {"epsilon", ATTR_DESC(eps, AnyTraits())}}; +OUTPUT_MAP(L2Normalize) = {{0, OUTPUT_DESC(y)}}; + +// CumsumD +INPUT_MAP(CumsumD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(CumsumD) = {{2, ATTR_DESC(axis, AnyTraits())}}; +ATTR_MAP(CumsumD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits())}, + {"reverse", ATTR_DESC(reverse, AnyTraits())}}; +OUTPUT_MAP(CumsumD) = {{0, OUTPUT_DESC(y)}}; + +// SoftmaxV2 +INPUT_MAP(SoftmaxV2) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(SoftmaxV2) = { + {"axis", ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}, +}; +OUTPUT_MAP(SoftmaxV2) = {{0, OUTPUT_DESC(y)}}; + +// SoftmaxGrad +INPUT_MAP(SoftmaxGrad) = {{1, INPUT_DESC(softmax)}, {2, INPUT_DESC(grad_softmax)}}; +OUTPUT_MAP(SoftmaxGrad) = {{0, OUTPUT_DESC(grad_x)}}; +ATTR_MAP(SoftmaxGrad) = EMPTY_ATTR_MAP; + +// Flatten +INPUT_MAP(Flatten) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Flatten) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Flatten) = {{0, OUTPUT_DESC(y)}}; + +// add +INPUT_MAP(Add) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Add) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Add) = {{0, OUTPUT_DESC(y)}}; + +// GatherV2 +INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(axis)}}; +ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}}; + +// ReduceSumD +INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceSumD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceSumD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceSumD) = {{0, OUTPUT_DESC(y)}}; + +// ReduceProdD +INPUT_MAP(ReduceProdD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceProdD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceProdD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceProdD) = {{0, OUTPUT_DESC(y)}}; + +// CumprodD +INPUT_MAP(CumprodD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(CumprodD) = {{2, ATTR_DESC(axis, AnyTraits())}}; +ATTR_MAP(CumprodD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits())}, + {"reverse", ATTR_DESC(reverse, AnyTraits())}}; +OUTPUT_MAP(CumprodD) = {{0, OUTPUT_DESC(y)}}; + +// SoftmaxCrossEntropyWithLogits +INPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(labels)}}; +ATTR_MAP(SoftmaxCrossEntropyWithLogits) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(backprop)}}; + +// MeanGrad +INPUT_MAP(MeanGrad) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(MeanGrad) = {{2, ATTR_DESC(mean_grad_output_shape_value, kOpFormat_NHWC, + AnyTraits>(), AnyTraits())}}; +ATTR_MAP(MeanGrad) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; + +INPUT_MAP(SliceD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(SliceD) = {{2, ATTR_DESC(offsets, AnyTraits(), AnyTraits>())}, + {3, ATTR_DESC(size, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(SliceD) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SliceD) = {{0, OUTPUT_DESC(y)}}; + +// MaxPool +INPUT_MAP(MaxPool) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(MaxPool) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(MaxPool) = {{0, OUTPUT_DESC(y)}}; + +// AvgPool +INPUT_MAP(AvgPool) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(AvgPool) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(AvgPool) = {{0, OUTPUT_DESC(y)}}; + +// GreaterEqual +INPUT_MAP(GreaterEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(GreaterEqual) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GreaterEqual) = {{0, OUTPUT_DESC(y)}}; + +// AssignAdd +INPUT_MAP(AssignAdd) = {{1, INPUT_DESC(ref)}, {2, INPUT_DESC(value)}}; +ATTR_MAP(AssignAdd) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AssignAdd) = {{0, OUTPUT_DESC(ref)}}; + +// AssignSub +INPUT_MAP(AssignSub) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(value)}}; +ATTR_MAP(AssignSub) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AssignSub) = {{0, OUTPUT_DESC(var)}}; + +// Cos +INPUT_MAP(Cos) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Cos) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Cos) = {{0, OUTPUT_DESC(y)}}; + +// Acos +INPUT_MAP(Acos) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Acos) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Acos) = {{0, OUTPUT_DESC(y)}}; + +// AcosGrad +INPUT_MAP(AcosGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(AcosGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AcosGrad) = {{0, OUTPUT_DESC(z)}}; + +// Acosh +INPUT_MAP(Acosh) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Acosh) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Acosh) = {{0, OUTPUT_DESC(y)}}; + +// AcoshGrad +INPUT_MAP(AcoshGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}}; + +// Floor +INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Floor) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Floor) = {{0, OUTPUT_DESC(y)}}; + +// FloorDiv +INPUT_MAP(FloorDiv) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(FloorDiv) = EMPTY_ATTR_MAP; +OUTPUT_MAP(FloorDiv) = {{0, OUTPUT_DESC(y)}}; + +// FloorMod +INPUT_MAP(FloorMod) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(FloorMod) = EMPTY_ATTR_MAP; +OUTPUT_MAP(FloorMod) = {{0, OUTPUT_DESC(y)}}; + +// Sin +INPUT_MAP(Sin) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Sin) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sin) = {{0, OUTPUT_DESC(y)}}; + +// Exp +INPUT_MAP(Exp) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Exp) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Exp) = {{0, OUTPUT_DESC(y)}}; + +// BoundingBoxEncode +INPUT_MAP(BoundingBoxEncode) = { + {1, INPUT_DESC(anchor_box)}, + {2, INPUT_DESC(ground_truth_box)}, +}; +ATTR_MAP(BoundingBoxEncode) = { + {"means", ATTR_DESC(means, AnyTraits>(), AnyTraits())}, + {"stds", ATTR_DESC(stds, AnyTraits>(), AnyTraits())}, +}; +OUTPUT_MAP(BoundingBoxEncode) = {{0, OUTPUT_DESC(delats)}}; + +// BoundingBoxDecode +INPUT_MAP(BoundingBoxDecode) = { + {1, INPUT_DESC(rois)}, + {2, INPUT_DESC(deltas)}, +}; +ATTR_MAP(BoundingBoxDecode) = { + {"means", ATTR_DESC(means, AnyTraits>(), AnyTraits())}, + {"stds", ATTR_DESC(stds, AnyTraits>(), AnyTraits())}, + {"max_shape", ATTR_DESC(max_shape, AnyTraits>(), AnyTraits>())}, + {"wh_ratio_clip", ATTR_DESC(wh_ratio_clip, AnyTraits())}, +}; +OUTPUT_MAP(BoundingBoxDecode) = {{0, OUTPUT_DESC(bboxes)}}; + +// TopK +INPUT_MAP(TopK) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(k)}}; +ATTR_MAP(TopK) = {{"sorted", ATTR_DESC(sorted, AnyTraits())}}; +OUTPUT_MAP(TopK) = {{0, OUTPUT_DESC(values)}, {1, OUTPUT_DESC(indices)}}; + +// Multiply +INPUT_MAP(Multiply) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}}; +ATTR_MAP(Multiply) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Multiply) = {{0, OUTPUT_DESC(z)}}; + +// TileD +INPUT_MAP(TileD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(TileD) = {{2, ATTR_DESC(multiples, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(TileD) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TileD) = {{0, OUTPUT_DESC(y)}}; + +// OneHot +INPUT_MAP(OneHot) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(depth)}, {3, INPUT_DESC(on_value)}, {4, INPUT_DESC(off_value)}}; +ATTR_MAP(OneHot) = {{"axis", ATTR_DESC(axis, AnyTraits())}}; +OUTPUT_MAP(OneHot) = {{0, OUTPUT_DESC(y)}}; + +// GatherV2D +INPUT_MAP(GatherV2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}}; +INPUT_ATTR_MAP(GatherV2D) = {{3, ATTR_DESC(axis, AnyTraits())}}; +ATTR_MAP(GatherV2D) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GatherV2D) = {{0, OUTPUT_DESC(y)}}; + +// Reshape +INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}}; +ATTR_MAP(Reshape) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}}; + +// TransShape +INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(TransShape) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}}; + +// BiasAdd +INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}}; +ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(BiasAdd) = {{0, OUTPUT_DESC(y)}}; + +// Iou +INPUT_MAP(Iou) = {{1, INPUT_DESC(bboxes)}, {2, INPUT_DESC(gtboxes)}}; +ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; +OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}}; + +// ResizeNearestNeighborV2D +INPUT_MAP(ResizeNearestNeighborV2D) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ResizeNearestNeighborV2D) = { + {"size", ATTR_DESC(size, AnyTraits>(), AnyTraits>())}, + {"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeNearestNeighborV2D) = {{0, OUTPUT_DESC(y)}}; + +// ResizeNearestNeighborV2Grad +INPUT_MAP(ResizeNearestNeighborV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}}; +ATTR_MAP(ResizeNearestNeighborV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeNearestNeighborV2Grad) = {{0, OUTPUT_DESC(y)}}; + +// ApplyAdam +INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, + {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)}, + {10, INPUT_DESC(grad)}}; +ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}, + {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits())}}; +OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}}; + +// ApplyAdamD +INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, + {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)}, + {10, INPUT_DESC(grad)}}; +ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}, + {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits())}}; +OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}}; + +// Relu6 +INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Relu6) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Relu6) = {{0, OUTPUT_DESC(y)}}; + +// Relu6Grad +INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}}; +ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}}; + +// ResizeBilinearV2Grad +INPUT_MAP(ResizeBilinearV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}}; +ATTR_MAP(ResizeBilinearV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeBilinearV2Grad) = {{0, OUTPUT_DESC(y)}}; + +// ResizeBilinearV2D +INPUT_MAP(ResizeBilinearV2D) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ResizeBilinearV2D) = { + {"size", ATTR_DESC(size, AnyTraits>(), AnyTraits>())}, + {"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}}; + +// ZerosLike +INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ZerosLike) = EMPTY_ATTR_MAP; +OUTPUT_MAP(ZerosLike) = {{0, OUTPUT_DESC(y)}}; + +// OnesLike +INPUT_MAP(OnesLike) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(OnesLike) = EMPTY_ATTR_MAP; +OUTPUT_MAP(OnesLike) = {{0, OUTPUT_DESC(y)}}; + +// NMSWithMask +INPUT_MAP(NMSWithMask) = {{1, INPUT_DESC(box_scores)}}; +ATTR_MAP(NMSWithMask) = {{"iou_threshold", ATTR_DESC(iou_threshold, AnyTraits())}}; +OUTPUT_MAP(NMSWithMask) = { + {0, OUTPUT_DESC(selected_boxes)}, {1, OUTPUT_DESC(selected_idx)}, {2, OUTPUT_DESC(selected_mask)}}; + +// Unpack +INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits())}, {"num", ATTR_DESC(num, AnyTraits())}}; +DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; + +// TensorScatterUpdate +INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}}; + +// ScatterUpdate +INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}}; + +// ScatterNdUpdate +INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ScatterNdUpdate) = {{0, OUTPUT_DESC(var)}}; + +// ScatterMax +INPUT_MAP(ScatterMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(ScatterMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ScatterMax) = {{0, OUTPUT_DESC(var)}}; + +// CheckValid +INPUT_MAP(CheckValid) = {{1, INPUT_DESC(bbox_tensor)}, {2, INPUT_DESC(img_metas)}}; +ATTR_MAP(CheckValid) = EMPTY_ATTR_MAP; +OUTPUT_MAP(CheckValid) = {{0, OUTPUT_DESC(valid_tensor)}}; + +// SmoothL1Loss +INPUT_MAP(SmoothL1Loss) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(label)}}; +ATTR_MAP(SmoothL1Loss) = {{"sigma", ATTR_DESC(sigma, AnyTraits())}}; +OUTPUT_MAP(SmoothL1Loss) = {{0, OUTPUT_DESC(loss)}}; + +// SmoothL1LossGrad +INPUT_MAP(SmoothL1LossGrad) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(label)}, {3, INPUT_DESC(dout)}}; +ATTR_MAP(SmoothL1LossGrad) = {{"sigma", ATTR_DESC(sigma, AnyTraits())}}; +OUTPUT_MAP(SmoothL1LossGrad) = {{0, OUTPUT_DESC(gradient)}}; + +// SigmoidCrossEntropyWithLogits +INPUT_MAP(SigmoidCrossEntropyWithLogits) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(target)}}; +ATTR_MAP(SigmoidCrossEntropyWithLogits) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SigmoidCrossEntropyWithLogits) = {{0, OUTPUT_DESC(loss)}}; + +// SigmoidCrossEntropyWithLogitsGrad +INPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = { + {1, INPUT_DESC(predict)}, {2, INPUT_DESC(target)}, {3, INPUT_DESC(dout)}}; +ATTR_MAP(SigmoidCrossEntropyWithLogitsGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = {{0, OUTPUT_DESC(gradient)}}; + +// ScatterNdD +INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ScatterNdD) = { + {3, ATTR_DESC(shape, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ScatterNdD) = EMPTY_ATTR_MAP; +OUTPUT_MAP(ScatterNdD) = {{0, OUTPUT_DESC(y)}}; + +// PadD +INPUT_MAP(PadD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(PadD) = {{"paddings", ATTR_DESC(paddings, AnyTraits>>())}}; +OUTPUT_MAP(PadD) = {{0, OUTPUT_DESC(y)}}; + +// MirrorPad +INPUT_MAP(MirrorPad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}}; +ATTR_MAP(MirrorPad) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; +OUTPUT_MAP(MirrorPad) = {{0, OUTPUT_DESC(y)}}; + +// MirrorPadGrad +INPUT_MAP(MirrorPadGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}}; +ATTR_MAP(MirrorPadGrad) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; +OUTPUT_MAP(MirrorPadGrad) = {{0, OUTPUT_DESC(y)}}; + +// GatherNd +INPUT_MAP(GatherNd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}}; +ATTR_MAP(GatherNd) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GatherNd) = {{0, OUTPUT_DESC(y)}}; + +// ROIAlign +INPUT_MAP(ROIAlign) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(rois)}}; +OUTPUT_MAP(ROIAlign) = {{0, OUTPUT_DESC(y)}}; +ATTR_MAP(ROIAlign) = {{"pooled_height", ATTR_DESC(pooled_height, AnyTraits())}, + {"pooled_width", ATTR_DESC(pooled_width, AnyTraits())}, + {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits())}, + {"sample_num", ATTR_DESC(sample_num, AnyTraits())}, + {"roi_end_mode", ATTR_DESC(roi_end_mode, AnyTraits())}}; + +// ROIAlignGrad +INPUT_MAP(ROIAlignGrad) = {{1, INPUT_DESC(ydiff)}, {2, INPUT_DESC(rois)}}; +OUTPUT_MAP(ROIAlignGrad) = {{0, OUTPUT_DESC(xdiff)}}; +ATTR_MAP(ROIAlignGrad) = { + {"xdiff_shape", ATTR_DESC(xdiff_shape, AnyTraits>(), AnyTraits>())}, + {"pooled_height", ATTR_DESC(pooled_height, AnyTraits())}, + {"pooled_width", ATTR_DESC(pooled_width, AnyTraits())}, + {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits())}, + {"sample_num", ATTR_DESC(sample_num, AnyTraits())}}; + +// ArgMaxD +INPUT_MAP(ArgMaxD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ArgMaxD) = {{"axis", ATTR_DESC(dimension, AnyTraits())}, + {"output_type", ATTR_DESC(dtype, AnyTraits())}}; +OUTPUT_MAP(ArgMaxD) = {{0, OUTPUT_DESC(y)}}; + +// ArgMinD +INPUT_MAP(ArgMinD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ArgMinD) = {{"axis", ATTR_DESC(dimension, AnyTraits())}, + {"output_type", ATTR_DESC(dtype, AnyTraits())}}; +OUTPUT_MAP(ArgMinD) = {{0, OUTPUT_DESC(y)}}; + +// ArgMaxWithValue +INPUT_MAP(ArgMaxWithValue) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ArgMaxWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits())}, + {"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ArgMaxWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}}; + +// ArgMinWithValue +INPUT_MAP(ArgMinWithValue) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ArgMinWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits())}, + {"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}}; + +// ReduceAllD +INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceAllD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}}; + +// ReduceMeanD +INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceMeanD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceMeanD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceMeanD) = {{0, OUTPUT_DESC(y)}}; + +// HCOMAllreduce +INPUT_MAP(HcomAllReduce) = {{1, INPUT_DESC(x)}}; +OUTPUT_MAP(HcomAllReduce) = {{0, OUTPUT_DESC(y)}}; +ATTR_MAP(HcomAllReduce) = {{"op", ATTR_DESC(reduction, AnyTraits())}, + {"group", ATTR_DESC(group, AnyTraits())}, + {"fusion", ATTR_DESC(fusion, AnyTraits())}}; + +// HCOMBraodcast +INPUT_MAP(HcomBroadcast) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(HcomBroadcast) = {{1, DYN_INPUT_DESC(x)}}; +DYN_OUTPUT_MAP(HcomBroadcast) = {{0, DYN_OUTPUT_DESC(y)}}; +ATTR_MAP(HcomBroadcast) = {{"root_rank", ATTR_DESC(root_rank, AnyTraits())}, + {"group", ATTR_DESC(group, AnyTraits())}}; + +// HCOMAllreduce +INPUT_MAP(HcomAllGather) = {{1, INPUT_DESC(x)}}; +OUTPUT_MAP(HcomAllGather) = {{0, OUTPUT_DESC(y)}}; +ATTR_MAP(HcomAllGather) = {{"group", ATTR_DESC(group, AnyTraits())}, + {"rank_size", ATTR_DESC(rank_size, AnyTraits())}}; + +// HCOMReduceScatter +INPUT_MAP(HcomReduceScatter) = {{1, INPUT_DESC(x)}}; +OUTPUT_MAP(HcomReduceScatter) = {{0, OUTPUT_DESC(y)}}; +ATTR_MAP(HcomReduceScatter) = {{"group", ATTR_DESC(group, AnyTraits())}, + {"op", ATTR_DESC(reduction, AnyTraits())}, + {"rank_size", ATTR_DESC(rank_size, AnyTraits())}}; + +// Variable +INPUT_MAP(Variable) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Variable) = EMPTY_ATTR_MAP; + +// ReluGrad +INPUT_MAP(ReluGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}}; +ATTR_MAP(ReluGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(ReluGrad) = {{0, OUTPUT_DESC(backprops)}}; + +// BiasAddGrad +INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(BiasAddGrad) = {{0, OUTPUT_DESC(y)}}; + +// MaxPoolGrad +INPUT_MAP(MaxPoolGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grad)}}; +ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}}; + +// avgpoolgrad +INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}}; +ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(AvgPoolGrad) = {{0, OUTPUT_DESC(out_grad)}}; + +// MaxPoolWithArgmax +INPUT_MAP(MaxPoolWithArgmax) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(MaxPoolWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}}; +OUTPUT_MAP(MaxPoolWithArgmax) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(argmax)}}; + +// MaxPoolGradWithArgmax +INPUT_MAP(MaxPoolGradWithArgmax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(grad)}, {3, INPUT_DESC(argmax)}}; +ATTR_MAP(MaxPoolGradWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}}; +OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}}; + +// ExtractImagePatches +INPUT_MAP(ExtractImagePatches) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"rates", ATTR_DESC(rates, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}}; +OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}}; + +// Conv2D +INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; +ATTR_MAP(Conv2D) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, +}; +OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}}; + +// Conv2DBackpropInputD +INPUT_MAP(Conv2DBackpropInputD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}}; +INPUT_ATTR_MAP(Conv2DBackpropInputD) = { + {3, ATTR_DESC(input_size, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(Conv2DBackpropInputD) = { + {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, +}; +OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; + +// Conv2DBackpropFilterD +INPUT_MAP(Conv2DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { + {3, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(Conv2DBackpropFilterD) = { + {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, +}; +OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; + +// DepthwiseConv2D +INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; +ATTR_MAP(DepthwiseConv2D) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, +}; +OUTPUT_MAP(DepthwiseConv2D) = {{0, OUTPUT_DESC(y)}}; + +// DepthwiseConv2DBackpropInputD +INPUT_MAP(DepthwiseConv2DBackpropInputD) = {{2, INPUT_DESC(filter)}, {3, INPUT_DESC(out_backprop)}}; +INPUT_ATTR_MAP(DepthwiseConv2DBackpropInputD) = { + {1, ATTR_DESC(input_size, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(DepthwiseConv2DBackpropInputD) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, +}; +OUTPUT_MAP(DepthwiseConv2DBackpropInputD) = {{0, OUTPUT_DESC(input_grad)}}; + +// DepthwiseConv2DBackpropFilterD +INPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{1, INPUT_DESC(input)}, {3, INPUT_DESC(out_backprop)}}; +INPUT_ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { + {2, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, +}; +OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}}; + +// MatMulV2 +INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits())}, + {"transpose_b", ATTR_DESC(transpose_x2, AnyTraits())}}; +OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}}; + +// Merge +INPUT_MAP(Merge) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(Merge) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(Merge) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Merge) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(value_index)}}; + +// Switch +INPUT_MAP(Switch) = {{1, INPUT_DESC(data)}, {2, INPUT_DESC(pred)}}; +OUTPUT_MAP(Switch) = {{0, OUTPUT_DESC(output_false)}, {1, OUTPUT_DESC(output_true)}}; +ATTR_MAP(Switch) = EMPTY_ATTR_MAP; + +// AddN +INPUT_MAP(AddN) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(AddN) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(AddN) = {{"n", ATTR_DESC(N, AnyTraits())}}; +OUTPUT_MAP(AddN) = {{0, OUTPUT_DESC(y)}}; + +// Mul +INPUT_MAP(Mul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Mul) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Mul) = {{0, OUTPUT_DESC(y)}}; + +// RealDiv +INPUT_MAP(RealDiv) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(RealDiv) = EMPTY_ATTR_MAP; +OUTPUT_MAP(RealDiv) = {{0, OUTPUT_DESC(y)}}; + +// Cast +INPUT_MAP(Cast) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits())}}; +ATTR_MAP(Cast) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}}; + +// Case +INPUT_MAP(Case) = {{1, INPUT_DESC(branch_index)}}; +DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}}; +ATTR_MAP(Case) = EMPTY_ATTR_MAP; +DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}}; +DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}}; + +// Reciprocal +INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Reciprocal) = {{0, OUTPUT_DESC(y)}}; + +// Sub +INPUT_MAP(Sub) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Sub) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sub) = {{0, OUTPUT_DESC(y)}}; + +// SplitD +INPUT_MAP(SplitD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits())}, + {"output_num", ATTR_DESC(num_split, AnyTraits())}}; +DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}}; + +// Range +INPUT_MAP(RangeD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(RangeD) = {{"start", ATTR_DESC(start, AnyTraits())}, + {"limit", ATTR_DESC(limit, AnyTraits())}, + {"delta", ATTR_DESC(delta, AnyTraits())}}; +OUTPUT_MAP(RangeD) = {{0, OUTPUT_DESC(y)}}; + +// Neg +INPUT_MAP(Neg) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Neg) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Neg) = {{0, OUTPUT_DESC(y)}}; + +// Transpose +INPUT_MAP(TransposeD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(TransposeD) = {{2, ATTR_DESC(perm, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(TransposeD) = EMPTY_ATTR_MAP; +// Do not set Transpose operator output descriptor + +// DropOutGenMask +INPUT_MAP(DropOutGenMask) = {{1, INPUT_DESC(shape)}, {2, INPUT_DESC(prob)}}; +ATTR_MAP(DropOutGenMask) = {{"Seed0", ATTR_DESC(seed, AnyTraits())}, + {"Seed1", ATTR_DESC(seed2, AnyTraits())}}; +OUTPUT_MAP(DropOutGenMask) = {{0, OUTPUT_DESC(y)}}; + +// Pack +INPUT_MAP(Pack) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(Pack) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(Pack) = {{"num", ATTR_DESC(N, AnyTraits())}, {"axis", ATTR_DESC(axis, AnyTraits())}}; +OUTPUT_MAP(Pack) = {{0, OUTPUT_DESC(y)}}; + +// ConcatD +INPUT_MAP(ConcatD) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(ConcatD) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(ConcatD) = { + {"axis", ATTR_DESC(concat_dim, AnyTraits())}, + {"inputNums", ATTR_DESC(N, AnyTraits())}, +}; +OUTPUT_MAP(ConcatD) = {{0, OUTPUT_DESC(y)}}; + +// Less +INPUT_MAP(Less) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Less) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Less) = {{0, OUTPUT_DESC(y)}}; + +// Rsqrt +INPUT_MAP(Rsqrt) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Rsqrt) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Rsqrt) = {{0, OUTPUT_DESC(y)}}; + +// Sqrt +INPUT_MAP(Sqrt) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Sqrt) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sqrt) = {{0, OUTPUT_DESC(y)}}; + +// Square +INPUT_MAP(Square) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Square) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Square) = {{0, OUTPUT_DESC(y)}}; + +// SquareSumAll +INPUT_MAP(SquareSumAll) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(SquareSumAll) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SquareSumAll) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}}; + +// Tanh +INPUT_MAP(Tanh) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Tanh) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Tanh) = {{0, OUTPUT_DESC(y)}}; + +// TanhGrad +INPUT_MAP(TanhGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(TanhGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TanhGrad) = {{0, OUTPUT_DESC(z)}}; + +// ReduceMinD +INPUT_MAP(ReduceMinD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceMinD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceMinD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceMinD) = {{0, OUTPUT_DESC(y)}}; + +// ReduceMaxD +INPUT_MAP(ReduceMaxD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceMaxD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceMaxD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceMaxD) = {{0, OUTPUT_DESC(y)}}; + +// Maximum +INPUT_MAP(Maximum) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Maximum) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Maximum) = {{0, OUTPUT_DESC(y)}}; + +// Minimum +INPUT_MAP(Minimum) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Minimum) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Minimum) = {{0, OUTPUT_DESC(y)}}; + +// MaximumGrad +INPUT_MAP(MaximumGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grads)}}; +ATTR_MAP(MaximumGrad) = {{"grad_x", ATTR_DESC(grad_x, AnyTraits())}, + {"grad_y", ATTR_DESC(grad_y, AnyTraits())}}; +OUTPUT_MAP(MaximumGrad) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}}; + +// MinimumGrad +INPUT_MAP(MinimumGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grads)}}; +ATTR_MAP(MinimumGrad) = {{"grad_x", ATTR_DESC(grad_x, AnyTraits())}, + {"grad_y", ATTR_DESC(grad_y, AnyTraits())}}; +OUTPUT_MAP(MinimumGrad) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}}; + +// Pow +INPUT_MAP(Pow) = { + {1, INPUT_DESC(x1)}, + {2, INPUT_DESC(x2)}, +}; +ATTR_MAP(Pow) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Pow) = {{0, OUTPUT_DESC(y)}}; + +// Equal +INPUT_MAP(Equal) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Equal) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Equal) = {{0, OUTPUT_DESC(y)}}; + +// NotEqual +INPUT_MAP(NotEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(NotEqual) = EMPTY_ATTR_MAP; +OUTPUT_MAP(NotEqual) = {{0, OUTPUT_DESC(y)}}; + +// Log +INPUT_MAP(Log) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Log) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Log) = {{0, OUTPUT_DESC(y)}}; + +// LogicalAnd +INPUT_MAP(LogicalAnd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(LogicalAnd) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LogicalAnd) = {{0, OUTPUT_DESC(y)}}; + +// LogicalOr +INPUT_MAP(LogicalOr) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(LogicalOr) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LogicalOr) = {{0, OUTPUT_DESC(y)}}; + +// LogicalNot +INPUT_MAP(LogicalNot) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(LogicalNot) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LogicalNot) = {{0, OUTPUT_DESC(y)}}; + +// Greater +INPUT_MAP(Greater) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Greater) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Greater) = {{0, OUTPUT_DESC(y)}}; + +// LogSoftmaxGrad +INPUT_MAP(LogSoftmaxGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(grad)}}; +ATTR_MAP(LogSoftmaxGrad) = { + {"axis", ATTR_DESC(axis, AnyTraits>(), AnyTraits>())}}; +OUTPUT_MAP(LogSoftmaxGrad) = {{0, OUTPUT_DESC(y)}}; + +// Select +INPUT_MAP(Select) = {{1, INPUT_DESC(condition)}, {2, INPUT_DESC(x1)}, {3, INPUT_DESC(x2)}}; +ATTR_MAP(Select) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Select) = {{0, OUTPUT_DESC(y)}}; + +// LessEqual +INPUT_MAP(LessEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(LessEqual) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LessEqual) = {{0, OUTPUT_DESC(y)}}; + +// LogSoftmaxV2 +INPUT_MAP(LogSoftmaxV2) = {{1, INPUT_DESC(logits)}}; +ATTR_MAP(LogSoftmaxV2) = { + {"axis", ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +OUTPUT_MAP(LogSoftmaxV2) = {{0, OUTPUT_DESC(logsoftmax)}}; + +// RandomChoiceWithMask +INPUT_MAP(RandomChoiceWithMask) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(RandomChoiceWithMask) = {{"count", ATTR_DESC(count, AnyTraits())}, + {"seed", ATTR_DESC(seed, AnyTraits())}, + {"seed2", ATTR_DESC(seed2, AnyTraits())}}; +OUTPUT_MAP(RandomChoiceWithMask) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(mask)}}; + +// TruncatedNormal +INPUT_MAP(TruncatedNormal) = {{1, INPUT_DESC(shape)}}; +ATTR_MAP(TruncatedNormal) = {{"seed", ATTR_DESC(seed, AnyTraits())}, + {"seed2", ATTR_DESC(seed2, AnyTraits())}}; +OUTPUT_MAP(TruncatedNormal) = {{0, OUTPUT_DESC(y)}}; + +// StridedSliceGrad +INPUT_MAP(StridedSliceGrad) = { + {1, INPUT_DESC(dy)}, {2, INPUT_DESC(shape)}, {3, INPUT_DESC(begin)}, {4, INPUT_DESC(end)}, {5, INPUT_DESC(strides)}}; +ATTR_MAP(StridedSliceGrad) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits())}, + {"end_mask", ATTR_DESC(end_mask, AnyTraits())}, + {"ellipsis_mask", ATTR_DESC(ellipsis_mask, AnyTraits())}, + {"new_axis_mask", ATTR_DESC(new_axis_mask, AnyTraits())}, + {"shrink_axis_mask", ATTR_DESC(shrink_axis_mask, AnyTraits())}}; +OUTPUT_MAP(StridedSliceGrad) = {{0, OUTPUT_DESC(output)}}; + +// Gelu +INPUT_MAP(Gelu) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Gelu) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Gelu) = {{0, OUTPUT_DESC(y)}}; + +// GeluGrad +INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y)}}; +ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}}; + +// StridedSlice +INPUT_MAP(StridedSlice) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(begin)}, {3, INPUT_DESC(end)}, {4, INPUT_DESC(strides)}}; +ATTR_MAP(StridedSlice) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits())}, + {"end_mask", ATTR_DESC(end_mask, AnyTraits())}, + {"ellipsis_mask", ATTR_DESC(ellipsis_mask, AnyTraits())}, + {"new_axis_mask", ATTR_DESC(new_axis_mask, AnyTraits())}, + {"shrink_axis_mask", ATTR_DESC(shrink_axis_mask, AnyTraits())}}; +OUTPUT_MAP(StridedSlice) = {{0, OUTPUT_DESC(y)}}; + +// UnsortedSegmentSum +INPUT_MAP(UnsortedSegmentSumD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}}; +INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits())}}; +ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP; +OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}}; + +// UnsortedSegmentMin +INPUT_MAP(UnsortedSegmentMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}}; +ATTR_MAP(UnsortedSegmentMin) = EMPTY_ATTR_MAP; +OUTPUT_MAP(UnsortedSegmentMin) = {{0, OUTPUT_DESC(y)}}; + +// ExpandDims +INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; +ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP; +OUTPUT_MAP(ExpandDims) = {{0, OUTPUT_DESC(y)}}; + +// Squeeze +INPUT_MAP(Squeeze) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Squeeze) = {{"axis", ATTR_DESC(axis, AnyTraits(), AnyTraits>())}}; +OUTPUT_MAP(Squeeze) = {{0, OUTPUT_DESC(y)}}; + +// SGD +INPUT_MAP(SGD) = {{1, INPUT_DESC(parameters)}, {2, INPUT_DESC(gradient)}, {3, INPUT_DESC(learning_rate)}, + {4, INPUT_DESC(accum)}, {5, INPUT_DESC(momentum)}, {6, INPUT_DESC(stat)}}; +ATTR_MAP(SGD) = {{"dampening", ATTR_DESC(dampening, AnyTraits())}, + {"weight_decay", ATTR_DESC(weight_decay, AnyTraits())}, + {"nesterov", ATTR_DESC(nesterov, AnyTraits())}}; +OUTPUT_MAP(SGD) = {{0, OUTPUT_DESC(parameters)}}; + +// LayerNorm +INPUT_MAP(LayerNorm) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(gamma)}, {3, INPUT_DESC(beta)}}; +ATTR_MAP(LayerNorm) = {{"begin_norm_axis", ATTR_DESC(begin_norm_axis, AnyTraits())}, + {"begin_params_axis", ATTR_DESC(begin_params_axis, AnyTraits())}, + {"epsilon", ATTR_DESC(epsilon, AnyTraits())}}; +OUTPUT_MAP(LayerNorm) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(mean)}, {2, OUTPUT_DESC(variance)}}; + +// LayerNormGrad +INPUT_MAP(LayerNormGrad) = { + {1, INPUT_DESC(x)}, {2, INPUT_DESC(dy)}, {3, INPUT_DESC(variance)}, {4, INPUT_DESC(mean)}, {5, INPUT_DESC(gamma)}}; +ATTR_MAP(LayerNormGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LayerNormGrad) = {{0, OUTPUT_DESC(pd_x)}, {1, OUTPUT_DESC(pd_gamma)}, {2, OUTPUT_DESC(pd_beta)}}; + +// BatchMatMul +INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(BatchMatMul) = {{"transpose_x1", ATTR_DESC(adj_x1, AnyTraits())}, + {"transpose_x2", ATTR_DESC(adj_x2, AnyTraits())}}; +OUTPUT_MAP(BatchMatMul) = {{0, OUTPUT_DESC(y)}}; + +// DropoutDoMask +INPUT_MAP(DropOutDoMask) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(mask)}, {3, INPUT_DESC(keep_prob)}}; +ATTR_MAP(DropOutDoMask) = EMPTY_ATTR_MAP; +OUTPUT_MAP(DropOutDoMask) = {{0, OUTPUT_DESC(y)}}; + +// NPUGetFloatStatus +INPUT_MAP(NPUGetFloatStatus) = {{1, INPUT_DESC(addr)}}; +OUTPUT_MAP(NPUGetFloatStatus) = {{0, OUTPUT_DESC(data)}}; +ATTR_MAP(NPUGetFloatStatus) = EMPTY_ATTR_MAP; + +// NPUAllocFloatStatus +INPUT_MAP(NPUAllocFloatStatus) = EMPTY_INPUT_MAP; +ATTR_MAP(NPUAllocFloatStatus) = EMPTY_ATTR_MAP; +OUTPUT_MAP(NPUAllocFloatStatus) = {{0, OUTPUT_DESC(data)}}; + +// NPUClearFloatStatus +INPUT_MAP(NPUClearFloatStatus) = {{1, INPUT_DESC(addr)}}; +OUTPUT_MAP(NPUClearFloatStatus) = {{0, OUTPUT_DESC(data)}}; +ATTR_MAP(NPUClearFloatStatus) = EMPTY_ATTR_MAP; + +// Abs +INPUT_MAP(Abs) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Abs) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Abs) = {{0, OUTPUT_DESC(y)}}; + +// AbsGrad +INPUT_MAP(AbsGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(AbsGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AbsGrad) = {{0, OUTPUT_DESC(z)}}; + +// BinaryCrossEntropy +INPUT_MAP(BinaryCrossEntropy) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(weight)}}; +ATTR_MAP(BinaryCrossEntropy) = {{"reduction", ATTR_DESC(reduction, AnyTraits())}}; +OUTPUT_MAP(BinaryCrossEntropy) = {{0, OUTPUT_DESC(output)}}; + +// BinaryCrossEntropyGrad +INPUT_MAP(BinaryCrossEntropyGrad) = { + {1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(grad_output)}, {4, INPUT_DESC(weight)}}; +ATTR_MAP(BinaryCrossEntropyGrad) = {{"reduction", ATTR_DESC(reduction, AnyTraits())}}; +OUTPUT_MAP(BinaryCrossEntropyGrad) = {{0, OUTPUT_DESC(output)}}; + +// SparseApplyAdagradD +INPUT_MAP(SparseApplyAdagradD) = { + {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}}; +ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits())}, + {"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}}; + +// ApplyProximalAdagradD +INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, + {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}}; +ATTR_MAP(ApplyProximalAdagradD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyProximalAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; + +// SparseApplyFtrlD +INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)}, + {2, INPUT_DESC(accum)}, + {3, INPUT_DESC(linear)}, + {4, INPUT_DESC(grad)}, + {5, INPUT_DESC(indices)}}; +ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}, + {"lr", ATTR_DESC(lr, AnyTraits())}, + {"l1", ATTR_DESC(l1, AnyTraits())}, + {"l2", ATTR_DESC(l2, AnyTraits())}, + {"lr_power", ATTR_DESC(lr_power, AnyTraits())}}; +OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}}; + +// SpaceToDepth +INPUT_MAP(SpaceToDepth) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(SpaceToDepth) = {{"block_size", ATTR_DESC(block_size, AnyTraits())}}; +OUTPUT_MAP(SpaceToDepth) = {{0, OUTPUT_DESC(y)}}; + +// DepthToSpace +INPUT_MAP(DepthToSpace) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(DepthToSpace) = {{"block_size", ATTR_DESC(block_size, AnyTraits())}}; +OUTPUT_MAP(DepthToSpace) = {{0, OUTPUT_DESC(y)}}; + +// Sign +INPUT_MAP(Sign) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Sign) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sign) = {{0, OUTPUT_DESC(y)}}; + +// Round +INPUT_MAP(Round) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Round) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}}; + +// ApplyFtrlD +INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)}, + {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)}, + {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}}; +ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}}; + +// Diag +INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Diag) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Diag) = {{0, OUTPUT_DESC(y)}}; + +// DiagPart +INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP; +OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}}; + +// SpaceToBatchD +INPUT_MAP(SpaceToBatchD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(SpaceToBatchD) = { + {"block_size", ATTR_DESC(block_size, AnyTraits())}, + {"paddings", ATTR_DESC(paddings, AnyTraits>>(), AnyTraits>())}}; +OUTPUT_MAP(SpaceToBatchD) = {{0, OUTPUT_DESC(y)}}; + +// BatchToSpaceD +INPUT_MAP(BatchToSpaceD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(BatchToSpaceD) = { + {"block_size", ATTR_DESC(block_size, AnyTraits())}, + {"crops", ATTR_DESC(crops, AnyTraits>>(), AnyTraits>())}}; +OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}}; + +// Atan2 +INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Atan2) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}}; + +// ApplyRMSPropD +INPUT_MAP(ApplyRMSPropD) = { + {1, INPUT_DESC(var)}, {2, INPUT_DESC(ms)}, {3, INPUT_DESC(mom)}, {4, INPUT_DESC(lr)}, {5, INPUT_DESC(grad)}}; +INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits())}, + {7, ATTR_DESC(momentum, AnyTraits())}, + {8, ATTR_DESC(epsilon, AnyTraits())}}; +ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}}; + +// ApplyCenteredRMSProp +INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, + {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; +ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; + +// L2Loss +INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP; +OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}}; + +// CTCLoss +INPUT_MAP(CTCLoss) = {{1, INPUT_DESC(inputs)}, + {2, INPUT_DESC(labels_indices)}, + {3, INPUT_DESC(labels_values)}, + {4, INPUT_DESC(sequence_length)}}; +ATTR_MAP(CTCLoss) = { + {"preprocess_collapse_repeated", ATTR_DESC(preprocess_collapse_repeated, AnyTraits())}, + {"ctc_merge_repeated", ATTR_DESC(ctc_merge_repeated, AnyTraits())}, + {"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits())}}; +OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}}; + +// AscendQuant +INPUT_MAP(AscendQuant) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(AscendQuant) = {{"scale", ATTR_DESC(scale, AnyTraits())}, + {"offset", ATTR_DESC(offset, AnyTraits())}, + {"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits())}, + {"round_mode", ATTR_DESC(round_mode, AnyTraits())}}; +OUTPUT_MAP(AscendQuant) = {{0, OUTPUT_DESC(y)}}; + +// AscendDequant +INPUT_MAP(AscendDequant) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(deq_scale)}}; +ATTR_MAP(AscendDequant) = {{"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits())}, + {"relu_flag", ATTR_DESC(relu_flag, AnyTraits())}}; +OUTPUT_MAP(AscendDequant) = {{0, OUTPUT_DESC(y)}}; +#ifdef ENABLE_GE +// Print +INPUT_MAP(Print) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(Print) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(Print) = EMPTY_ATTR_MAP; +#endif +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare.h new file mode 100755 index 0000000000..e493ea0e52 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.h @@ -0,0 +1,505 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_OP_DECLARE_H_ +#define TRANSFORM_OP_DECLARE_H_ + +#include +#include +#include "transform/graph_ir/op_adapter.h" + +namespace mindspore { +namespace transform { +#define DECLARE_OP_ADAPTER(T) \ + using T = ge::op::T; \ + template <> \ + const std::unordered_map OpAdapter::input_map_; \ + template <> \ + const std::unordered_map OpAdapter::attr_map_; + +#define DECLARE_OP_USE_OUTPUT(T) \ + template <> \ + const std::unordered_map OpAdapter::output_map_; + +#define DECLARE_OP_USE_ENUM(T) \ + template <> \ + const std::unordered_map OpAdapter::enum_map_; + +#define DECLARE_OP_USE_INPUT_ATTR(T) \ + template <> \ + const std::unordered_map OpAdapter::input_attr_map_; + +#define DECLARE_OP_USE_DYN_INPUT(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_input_map_; + +#define DECLARE_OP_USE_DYN_SUBGRAPH(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_subgraph_map_; + +#define DECLARE_OP_USE_DYN_OUTPUT(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_output_map_; + +template <> +std::unordered_map> OpAdapter::cus_input_map_; +template <> +std::unordered_map> OpAdapter::cus_output_map_; + +DECLARE_OP_ADAPTER(GreaterEqual) +DECLARE_OP_USE_OUTPUT(GreaterEqual) +DECLARE_OP_ADAPTER(SliceD) +DECLARE_OP_USE_INPUT_ATTR(SliceD) +DECLARE_OP_USE_OUTPUT(SliceD) +DECLARE_OP_ADAPTER(AssignAdd) +DECLARE_OP_USE_OUTPUT(AssignAdd) +DECLARE_OP_ADAPTER(AssignSub) +DECLARE_OP_USE_OUTPUT(AssignSub) + +DECLARE_OP_ADAPTER(ReduceMean) +DECLARE_OP_ADAPTER(Multiply) +DECLARE_OP_USE_OUTPUT(Multiply) + +// ** Distributed Operations ** +DECLARE_OP_ADAPTER(HcomReduceScatter) +DECLARE_OP_USE_OUTPUT(HcomReduceScatter) +DECLARE_OP_ADAPTER(HcomBroadcast) +DECLARE_OP_USE_DYN_INPUT(HcomBroadcast) +DECLARE_OP_USE_DYN_OUTPUT(HcomBroadcast) +DECLARE_OP_ADAPTER(HcomAllReduce) +DECLARE_OP_USE_OUTPUT(HcomAllReduce) +DECLARE_OP_ADAPTER(HcomAllGather) +DECLARE_OP_USE_OUTPUT(HcomAllGather) +DECLARE_OP_ADAPTER(Variable) +DECLARE_OP_ADAPTER(ReluGrad) +DECLARE_OP_USE_OUTPUT(ReluGrad) +DECLARE_OP_ADAPTER(BiasAddGrad) +DECLARE_OP_USE_OUTPUT(BiasAddGrad) +DECLARE_OP_ADAPTER(MaxPoolWithArgmax) +DECLARE_OP_USE_OUTPUT(MaxPoolWithArgmax) +DECLARE_OP_ADAPTER(MaxPoolGradWithArgmax) +DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax) +DECLARE_OP_ADAPTER(Conv2D) +DECLARE_OP_USE_ENUM(Conv2D) +DECLARE_OP_USE_OUTPUT(Conv2D) +DECLARE_OP_ADAPTER(ExtractImagePatches) +DECLARE_OP_USE_OUTPUT(ExtractImagePatches) +DECLARE_OP_ADAPTER(Conv2DBackpropInputD) +DECLARE_OP_USE_ENUM(Conv2DBackpropInputD) +DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD) +DECLARE_OP_USE_OUTPUT(Conv2DBackpropInputD) +DECLARE_OP_ADAPTER(Conv2DBackpropFilterD) +DECLARE_OP_USE_ENUM(Conv2DBackpropFilterD) +DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropFilterD) +DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilterD) +DECLARE_OP_ADAPTER(DepthwiseConv2D) +DECLARE_OP_USE_ENUM(DepthwiseConv2D) +DECLARE_OP_USE_OUTPUT(DepthwiseConv2D) +DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropFilterD) +DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropFilterD) +DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropFilterD) +DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropInputD) +DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD) +DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD) +DECLARE_OP_ADAPTER(Reshape) +DECLARE_OP_USE_OUTPUT(Reshape) +DECLARE_OP_ADAPTER(TransShape) +DECLARE_OP_USE_INPUT_ATTR(TransShape) +DECLARE_OP_USE_OUTPUT(TransShape) +DECLARE_OP_ADAPTER(Iou) +DECLARE_OP_USE_OUTPUT(Iou) +DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D) +DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D) +DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad) +DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad) +DECLARE_OP_ADAPTER(ApplyAdam) +DECLARE_OP_USE_OUTPUT(ApplyAdam) +DECLARE_OP_ADAPTER(ApplyAdamD) +DECLARE_OP_USE_OUTPUT(ApplyAdamD) +DECLARE_OP_ADAPTER(Relu6) +DECLARE_OP_USE_OUTPUT(Relu6) +DECLARE_OP_ADAPTER(Relu6Grad) +DECLARE_OP_USE_OUTPUT(Relu6Grad) +DECLARE_OP_ADAPTER(ResizeBilinearV2D) +DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D) +DECLARE_OP_ADAPTER(ResizeBilinearV2Grad) +DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad) +DECLARE_OP_ADAPTER(ZerosLike) +DECLARE_OP_USE_OUTPUT(ZerosLike) +DECLARE_OP_ADAPTER(OnesLike) +DECLARE_OP_USE_OUTPUT(OnesLike) +DECLARE_OP_ADAPTER(TensorScatterUpdate) +DECLARE_OP_USE_OUTPUT(TensorScatterUpdate) +DECLARE_OP_ADAPTER(ScatterUpdate) +DECLARE_OP_USE_OUTPUT(ScatterUpdate) +DECLARE_OP_ADAPTER(ScatterNdUpdate) +DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) +DECLARE_OP_ADAPTER(ScatterMax) +DECLARE_OP_USE_OUTPUT(ScatterMax) +DECLARE_OP_ADAPTER(NMSWithMask) +DECLARE_OP_USE_OUTPUT(NMSWithMask) +DECLARE_OP_ADAPTER(Unpack) +DECLARE_OP_USE_DYN_OUTPUT(Unpack) +DECLARE_OP_ADAPTER(CheckValid) +DECLARE_OP_USE_OUTPUT(CheckValid) +DECLARE_OP_ADAPTER(SmoothL1Loss) +DECLARE_OP_USE_OUTPUT(SmoothL1Loss) +DECLARE_OP_ADAPTER(SmoothL1LossGrad) +DECLARE_OP_USE_OUTPUT(SmoothL1LossGrad) +DECLARE_OP_ADAPTER(SigmoidCrossEntropyWithLogits) +DECLARE_OP_USE_OUTPUT(SigmoidCrossEntropyWithLogits) +DECLARE_OP_ADAPTER(SigmoidCrossEntropyWithLogitsGrad) +DECLARE_OP_USE_OUTPUT(SigmoidCrossEntropyWithLogitsGrad) +DECLARE_OP_ADAPTER(ScatterNdD) +DECLARE_OP_USE_INPUT_ATTR(ScatterNdD) +DECLARE_OP_USE_OUTPUT(ScatterNdD) +DECLARE_OP_ADAPTER(PadD) +DECLARE_OP_USE_OUTPUT(PadD) +DECLARE_OP_ADAPTER(MirrorPad) +DECLARE_OP_USE_OUTPUT(MirrorPad) +DECLARE_OP_ADAPTER(MirrorPadGrad) +DECLARE_OP_USE_OUTPUT(MirrorPadGrad) +DECLARE_OP_ADAPTER(BoundingBoxEncode) +DECLARE_OP_USE_OUTPUT(BoundingBoxEncode) +DECLARE_OP_ADAPTER(BoundingBoxDecode) +DECLARE_OP_USE_OUTPUT(BoundingBoxDecode) +DECLARE_OP_ADAPTER(GatherNd) +DECLARE_OP_USE_OUTPUT(GatherNd) +DECLARE_OP_ADAPTER(ArgMaxD) +DECLARE_OP_USE_OUTPUT(ArgMaxD) +DECLARE_OP_ADAPTER(ArgMinD) +DECLARE_OP_USE_OUTPUT(ArgMinD) +DECLARE_OP_ADAPTER(ArgMaxWithValue) +DECLARE_OP_USE_OUTPUT(ArgMaxWithValue) +DECLARE_OP_ADAPTER(ArgMinWithValue) +DECLARE_OP_USE_OUTPUT(ArgMinWithValue) +DECLARE_OP_ADAPTER(Mul) +DECLARE_OP_USE_OUTPUT(Mul) +DECLARE_OP_ADAPTER(AddN) +DECLARE_OP_USE_DYN_INPUT(AddN) +DECLARE_OP_USE_OUTPUT(AddN) +DECLARE_OP_ADAPTER(Less) +DECLARE_OP_USE_OUTPUT(Less) +DECLARE_OP_ADAPTER(Rsqrt) +DECLARE_OP_USE_OUTPUT(Rsqrt) +DECLARE_OP_ADAPTER(Sqrt) +DECLARE_OP_USE_OUTPUT(Sqrt) +DECLARE_OP_ADAPTER(Square) +DECLARE_OP_USE_OUTPUT(Square) +DECLARE_OP_ADAPTER(SplitD) +DECLARE_OP_USE_DYN_OUTPUT(SplitD) +DECLARE_OP_ADAPTER(SGD) +DECLARE_OP_USE_OUTPUT(SGD) +DECLARE_OP_ADAPTER(SquareSumAll) +DECLARE_OP_USE_OUTPUT(SquareSumAll) + +DECLARE_OP_ADAPTER(Tanh) +DECLARE_OP_USE_OUTPUT(Tanh) +DECLARE_OP_ADAPTER(TanhGrad) +DECLARE_OP_USE_OUTPUT(TanhGrad) +DECLARE_OP_ADAPTER(Maximum) +DECLARE_OP_USE_OUTPUT(Maximum) +DECLARE_OP_ADAPTER(Minimum) +DECLARE_OP_USE_OUTPUT(Minimum) +DECLARE_OP_ADAPTER(MaximumGrad) +DECLARE_OP_USE_OUTPUT(MaximumGrad) +DECLARE_OP_ADAPTER(MinimumGrad) +DECLARE_OP_USE_OUTPUT(MinimumGrad) +DECLARE_OP_ADAPTER(ReduceMinD) +DECLARE_OP_USE_INPUT_ATTR(ReduceMinD) +DECLARE_OP_USE_OUTPUT(ReduceMinD) +DECLARE_OP_ADAPTER(ReduceMaxD) +DECLARE_OP_USE_INPUT_ATTR(ReduceMaxD) +DECLARE_OP_USE_OUTPUT(ReduceMaxD) +DECLARE_OP_ADAPTER(Merge) +DECLARE_OP_USE_DYN_INPUT(Merge) +DECLARE_OP_USE_OUTPUT(Merge) +DECLARE_OP_ADAPTER(Switch) +DECLARE_OP_USE_OUTPUT(Switch) + +DECLARE_OP_ADAPTER(TopK) +DECLARE_OP_USE_OUTPUT(TopK) + +DECLARE_OP_ADAPTER(RealDiv) +DECLARE_OP_USE_OUTPUT(RealDiv) + +DECLARE_OP_ADAPTER(Cast) +DECLARE_OP_USE_INPUT_ATTR(Cast) +DECLARE_OP_USE_OUTPUT(Cast) +DECLARE_OP_ADAPTER(Case) +DECLARE_OP_USE_DYN_INPUT(Case) +DECLARE_OP_USE_DYN_SUBGRAPH(Case) +DECLARE_OP_USE_DYN_OUTPUT(Case) +DECLARE_OP_ADAPTER(Reciprocal) +DECLARE_OP_USE_OUTPUT(Reciprocal) +DECLARE_OP_ADAPTER(Neg) +DECLARE_OP_USE_OUTPUT(Neg) +DECLARE_OP_ADAPTER(TransposeD) +DECLARE_OP_USE_INPUT_ATTR(TransposeD) +// Do not set Transpose operator output descriptor +DECLARE_OP_ADAPTER(Sub) +DECLARE_OP_USE_OUTPUT(Sub) +DECLARE_OP_ADAPTER(DropOutGenMask) +DECLARE_OP_USE_OUTPUT(DropOutGenMask) +DECLARE_OP_ADAPTER(ConcatD) +DECLARE_OP_USE_DYN_INPUT(ConcatD) +DECLARE_OP_USE_OUTPUT(ConcatD) +DECLARE_OP_ADAPTER(Pack) +DECLARE_OP_USE_DYN_INPUT(Pack) +DECLARE_OP_USE_OUTPUT(Pack) + +DECLARE_OP_ADAPTER(Pow) +DECLARE_OP_USE_OUTPUT(Pow) +DECLARE_OP_ADAPTER(Equal) +DECLARE_OP_USE_OUTPUT(Equal) +DECLARE_OP_ADAPTER(NotEqual) +DECLARE_OP_USE_OUTPUT(NotEqual) +DECLARE_OP_ADAPTER(Log) +DECLARE_OP_USE_OUTPUT(Log) +DECLARE_OP_ADAPTER(LogicalAnd) +DECLARE_OP_USE_OUTPUT(LogicalAnd) +DECLARE_OP_ADAPTER(LogicalOr) +DECLARE_OP_USE_OUTPUT(LogicalOr) +DECLARE_OP_ADAPTER(LogicalNot) +DECLARE_OP_USE_OUTPUT(LogicalNot) +DECLARE_OP_ADAPTER(LogSoftmaxGrad) +DECLARE_OP_USE_OUTPUT(LogSoftmaxGrad) + +DECLARE_OP_ADAPTER(RandomChoiceWithMask) +DECLARE_OP_USE_OUTPUT(RandomChoiceWithMask) + +DECLARE_OP_ADAPTER(Select) +DECLARE_OP_USE_OUTPUT(Select) +DECLARE_OP_ADAPTER(LessEqual) +DECLARE_OP_USE_OUTPUT(LessEqual) +DECLARE_OP_ADAPTER(LogSoftmaxV2) +DECLARE_OP_USE_OUTPUT(LogSoftmaxV2) +DECLARE_OP_ADAPTER(TruncatedNormal) +DECLARE_OP_USE_OUTPUT(TruncatedNormal) +DECLARE_OP_ADAPTER(StridedSliceGrad) +DECLARE_OP_USE_OUTPUT(StridedSliceGrad) +DECLARE_OP_ADAPTER(Gelu) +DECLARE_OP_USE_OUTPUT(Gelu) +DECLARE_OP_ADAPTER(GeluGrad) +DECLARE_OP_USE_OUTPUT(GeluGrad) +DECLARE_OP_ADAPTER(StridedSlice) +DECLARE_OP_USE_OUTPUT(StridedSlice) +DECLARE_OP_ADAPTER(UnsortedSegmentSumD) +DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) +DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) +DECLARE_OP_ADAPTER(UnsortedSegmentMin) +DECLARE_OP_USE_OUTPUT(UnsortedSegmentMin) +DECLARE_OP_ADAPTER(ExpandDims) +DECLARE_OP_USE_OUTPUT(ExpandDims) +DECLARE_OP_ADAPTER(Squeeze) +DECLARE_OP_USE_OUTPUT(Squeeze) +DECLARE_OP_ADAPTER(LayerNorm) +DECLARE_OP_USE_OUTPUT(LayerNorm) +DECLARE_OP_ADAPTER(LayerNormGrad) +DECLARE_OP_USE_OUTPUT(LayerNormGrad) +DECLARE_OP_ADAPTER(BatchMatMul) +DECLARE_OP_USE_OUTPUT(BatchMatMul) +DECLARE_OP_ADAPTER(DropOutDoMask) +DECLARE_OP_USE_OUTPUT(DropOutDoMask) +// ** Mix-precision Operations ** +DECLARE_OP_ADAPTER(NPUGetFloatStatus) +DECLARE_OP_USE_OUTPUT(NPUGetFloatStatus) +DECLARE_OP_ADAPTER(NPUAllocFloatStatus) +DECLARE_OP_USE_OUTPUT(NPUAllocFloatStatus) +DECLARE_OP_ADAPTER(NPUClearFloatStatus) +DECLARE_OP_USE_OUTPUT(NPUClearFloatStatus) +DECLARE_OP_ADAPTER(MatMulV2) +DECLARE_OP_USE_OUTPUT(MatMulV2) + +DECLARE_OP_ADAPTER(SoftmaxCrossEntropyWithLogits) +DECLARE_OP_USE_OUTPUT(SoftmaxCrossEntropyWithLogits) + +DECLARE_OP_ADAPTER(MeanGrad) +DECLARE_OP_USE_INPUT_ATTR(MeanGrad) + +DECLARE_OP_ADAPTER(Assign) +DECLARE_OP_USE_OUTPUT(Assign) +DECLARE_OP_ADAPTER(Constant) +DECLARE_OP_USE_OUTPUT(Constant) +DECLARE_OP_ADAPTER(ApplyMomentumD) +DECLARE_OP_USE_OUTPUT(ApplyMomentumD) +// ** Summary Operations ** +DECLARE_OP_ADAPTER(Summary) + +// fully supported +DECLARE_OP_ADAPTER(Add) +DECLARE_OP_USE_OUTPUT(Add) +DECLARE_OP_ADAPTER(Const) +DECLARE_OP_USE_OUTPUT(Const) +DECLARE_OP_ADAPTER(Cos) +DECLARE_OP_USE_OUTPUT(Cos) + +DECLARE_OP_ADAPTER(Acos) +DECLARE_OP_USE_OUTPUT(Acos) +DECLARE_OP_ADAPTER(AcosGrad) +DECLARE_OP_USE_OUTPUT(AcosGrad) +DECLARE_OP_ADAPTER(Acosh) +DECLARE_OP_USE_OUTPUT(Acosh) +DECLARE_OP_ADAPTER(AcoshGrad) +DECLARE_OP_USE_OUTPUT(AcoshGrad) + +DECLARE_OP_ADAPTER(Floor) +DECLARE_OP_USE_OUTPUT(Floor) +DECLARE_OP_ADAPTER(FloorDiv) +DECLARE_OP_USE_OUTPUT(FloorDiv) +DECLARE_OP_ADAPTER(FloorMod) +DECLARE_OP_USE_OUTPUT(FloorMod) +DECLARE_OP_ADAPTER(Sin) +DECLARE_OP_USE_OUTPUT(Sin) +DECLARE_OP_ADAPTER(Exp) +DECLARE_OP_USE_OUTPUT(Exp) + +DECLARE_OP_ADAPTER(ReduceAllD) +DECLARE_OP_USE_INPUT_ATTR(ReduceAllD) +DECLARE_OP_USE_OUTPUT(ReduceAllD) +DECLARE_OP_ADAPTER(ReduceSumD) +DECLARE_OP_USE_INPUT_ATTR(ReduceSumD) +DECLARE_OP_USE_OUTPUT(ReduceSumD) +DECLARE_OP_ADAPTER(ReduceMeanD) +DECLARE_OP_USE_INPUT_ATTR(ReduceMeanD) +DECLARE_OP_USE_OUTPUT(ReduceMeanD) +DECLARE_OP_ADAPTER(ReduceProdD) +DECLARE_OP_USE_INPUT_ATTR(ReduceProdD) +DECLARE_OP_USE_OUTPUT(ReduceProdD) +DECLARE_OP_ADAPTER(CumprodD) +DECLARE_OP_USE_INPUT_ATTR(CumprodD) +DECLARE_OP_USE_OUTPUT(CumprodD) + +DECLARE_OP_ADAPTER(TileD) +DECLARE_OP_USE_INPUT_ATTR(TileD) +DECLARE_OP_USE_OUTPUT(TileD) +DECLARE_OP_ADAPTER(OneHot) +DECLARE_OP_USE_OUTPUT(OneHot) +DECLARE_OP_ADAPTER(GatherV2D) +DECLARE_OP_USE_INPUT_ATTR(GatherV2D) +DECLARE_OP_USE_OUTPUT(GatherV2D) +DECLARE_OP_ADAPTER(RangeD) +DECLARE_OP_USE_OUTPUT(RangeD) + +DECLARE_OP_ADAPTER(Data) +DECLARE_OP_ADAPTER(BiasAdd) +DECLARE_OP_USE_OUTPUT(BiasAdd) +DECLARE_OP_ADAPTER(BatchNorm) +DECLARE_OP_USE_OUTPUT(BatchNorm) +DECLARE_OP_ADAPTER(BatchNormGrad) +DECLARE_OP_USE_OUTPUT(BatchNormGrad) +DECLARE_OP_ADAPTER(Relu) +DECLARE_OP_USE_OUTPUT(Relu) +DECLARE_OP_ADAPTER(PRelu) +DECLARE_OP_USE_OUTPUT(PRelu) +DECLARE_OP_ADAPTER(Elu) +DECLARE_OP_USE_OUTPUT(Elu) + +DECLARE_OP_ADAPTER(EluGrad) +DECLARE_OP_USE_OUTPUT(EluGrad) +DECLARE_OP_ADAPTER(PReluGrad) +DECLARE_OP_USE_OUTPUT(PReluGrad) + +DECLARE_OP_ADAPTER(L2Normalize) +DECLARE_OP_USE_OUTPUT(L2Normalize) + +DECLARE_OP_ADAPTER(CumsumD) +DECLARE_OP_USE_INPUT_ATTR(CumsumD) +DECLARE_OP_USE_OUTPUT(CumsumD) +DECLARE_OP_ADAPTER(L2NormalizeGrad) +DECLARE_OP_USE_OUTPUT(L2NormalizeGrad) +DECLARE_OP_ADAPTER(Sigmoid) +DECLARE_OP_USE_OUTPUT(Sigmoid) +DECLARE_OP_ADAPTER(SigmoidGrad) +DECLARE_OP_USE_OUTPUT(SigmoidGrad) +DECLARE_OP_ADAPTER(SoftmaxV2) +DECLARE_OP_USE_OUTPUT(SoftmaxV2) +DECLARE_OP_ADAPTER(SoftmaxGrad) +DECLARE_OP_USE_OUTPUT(SoftmaxGrad) +DECLARE_OP_ADAPTER(Greater) +DECLARE_OP_USE_OUTPUT(Greater) +DECLARE_OP_ADAPTER(Flatten) +DECLARE_OP_USE_OUTPUT(Flatten) +DECLARE_OP_ADAPTER(GatherV2) +DECLARE_OP_USE_OUTPUT(GatherV2) +DECLARE_OP_ADAPTER(MaxPool) +DECLARE_OP_USE_OUTPUT(MaxPool) +DECLARE_OP_ADAPTER(MaxPoolGrad) +DECLARE_OP_USE_OUTPUT(MaxPoolGrad) +DECLARE_OP_ADAPTER(AvgPool) +DECLARE_OP_USE_OUTPUT(AvgPool) +DECLARE_OP_ADAPTER(AvgPoolGrad) +DECLARE_OP_USE_OUTPUT(AvgPoolGrad) +DECLARE_OP_ADAPTER(ROIAlign) +DECLARE_OP_USE_OUTPUT(ROIAlign) +DECLARE_OP_ADAPTER(ROIAlignGrad) +DECLARE_OP_USE_OUTPUT(ROIAlignGrad) +DECLARE_OP_ADAPTER(Abs) +DECLARE_OP_USE_OUTPUT(Abs) +DECLARE_OP_ADAPTER(AbsGrad) +DECLARE_OP_USE_OUTPUT(AbsGrad) +DECLARE_OP_ADAPTER(BinaryCrossEntropy) +DECLARE_OP_USE_OUTPUT(BinaryCrossEntropy) +DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad) +DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad) +DECLARE_OP_ADAPTER(SparseApplyAdagradD) +DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD) +DECLARE_OP_ADAPTER(ApplyProximalAdagradD) +DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD) +DECLARE_OP_ADAPTER(SpaceToDepth) +DECLARE_OP_USE_OUTPUT(SpaceToDepth) +DECLARE_OP_ADAPTER(DepthToSpace) +DECLARE_OP_USE_OUTPUT(DepthToSpace) +DECLARE_OP_ADAPTER(Sign) +DECLARE_OP_USE_OUTPUT(Sign) +DECLARE_OP_ADAPTER(LarsV2Update) +DECLARE_OP_USE_OUTPUT(LarsV2Update) +DECLARE_OP_ADAPTER(Round) +DECLARE_OP_USE_OUTPUT(Round) +DECLARE_OP_ADAPTER(ApplyFtrlD) +DECLARE_OP_USE_OUTPUT(ApplyFtrlD) +DECLARE_OP_ADAPTER(SparseApplyFtrlD) +DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD) +DECLARE_OP_ADAPTER(Diag) +DECLARE_OP_USE_OUTPUT(Diag) +DECLARE_OP_ADAPTER(DiagPart) +DECLARE_OP_USE_OUTPUT(DiagPart) +DECLARE_OP_ADAPTER(SpaceToBatchD) +DECLARE_OP_USE_OUTPUT(SpaceToBatchD) +DECLARE_OP_ADAPTER(BatchToSpaceD) +DECLARE_OP_USE_OUTPUT(BatchToSpaceD) +DECLARE_OP_ADAPTER(Atan2) +DECLARE_OP_USE_OUTPUT(Atan2) +DECLARE_OP_ADAPTER(ApplyRMSPropD) +DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) +DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) +DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) +DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) +DECLARE_OP_ADAPTER(L2Loss) +DECLARE_OP_USE_OUTPUT(L2Loss) +DECLARE_OP_ADAPTER(CTCLoss) +DECLARE_OP_USE_OUTPUT(CTCLoss) +DECLARE_OP_ADAPTER(AscendQuant) +DECLARE_OP_USE_OUTPUT(AscendQuant) +DECLARE_OP_ADAPTER(AscendDequant) +DECLARE_OP_USE_OUTPUT(AscendDequant) +#ifdef ENABLE_GE +DECLARE_OP_ADAPTER(Print) +DECLARE_OP_USE_DYN_INPUT(Print) +#endif +} // namespace transform +} // namespace mindspore +#endif // TRANSFORM_OP_DECLARE_H_ diff --git a/mindspore/ccsrc/transform/types.h b/mindspore/ccsrc/transform/graph_ir/types.h similarity index 100% rename from mindspore/ccsrc/transform/types.h rename to mindspore/ccsrc/transform/graph_ir/types.h diff --git a/mindspore/ccsrc/transform/graph_ir/util.cc b/mindspore/ccsrc/transform/graph_ir/util.cc new file mode 100644 index 0000000000..6ae665d69f --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/util.cc @@ -0,0 +1,452 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/util.h" + +#include +#include +#include + +#include "securec/include/securec.h" +#include "utils/convert_utils.h" +#include "utils/utils.h" + +namespace mindspore { +namespace transform { +using std::make_shared; +using std::shared_ptr; +using std::string; +using std::vector; + +const size_t kErrorSize = 0; + +vector TransformUtil::ConvertIntToList(int64_t data, int size) { + vector list{}; + if (size <= 0) { + MS_LOG(WARNING) << "size <= 0"; + return list; + } + for (int i = 0; i < size; ++i) { + list.push_back(data); + } + return list; +} + +static std::map datatype_trans_map = { + {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT}, + {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE}, {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8}, + {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16}, {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32}, + {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64}, {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8}, + {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32}, + {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}}; + +GeDataType TransformUtil::ConvertDataType(const MeDataType &type) { + MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type"; + if (datatype_trans_map.find(type) != datatype_trans_map.end()) { + return datatype_trans_map[type]; + } else { + return GeDataType::DT_UNDEFINED; + } +} + +static std::map datatype_size_map = { + {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)}, // 1/2 of float + {MeDataType::kNumberTypeFloat64, sizeof(double)}, {MeDataType::kNumberTypeInt8, sizeof(int8_t)}, + {MeDataType::kNumberTypeInt16, sizeof(int16_t)}, {MeDataType::kNumberTypeInt32, sizeof(int32_t)}, + {MeDataType::kNumberTypeInt64, sizeof(int64_t)}, {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)}, + {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)}, + {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}}; + +size_t TransformUtil::GetDataTypeSize(const MeDataType &type) { + if (datatype_size_map.find(type) != datatype_size_map.end()) { + return datatype_size_map[type]; + } else { + MS_LOG(ERROR) << "Illegal tensor data type!"; + return kErrorSize; + } +} + +GeFormat TransformUtil::ConvertFormat(const string &format) { + if (format == kOpFormat_NCHW) { + return GeFormat::FORMAT_NCHW; + } else if (format == kOpFormat_NC1HWC0) { + return GeFormat::FORMAT_NC1HWC0; + } else if (format == kOpFormat_NHWC) { + return GeFormat::FORMAT_NHWC; + } else if (format == kOpFormat_HWCN) { + return GeFormat::FORMAT_HWCN; + } else { + return GeFormat::FORMAT_ND; + } +} + +static int64_t IntegerCastFunc(size_t temp) { return static_cast(temp); } + +std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector &me_shape, + const MeDataType &me_type, const std::string &format) { + // convert me shape to ge shape + std::vector ge_shape; + + if (me_shape.size() == 1) { + ge_shape.push_back(static_cast(me_shape[0])); + } else { + ge_shape.resize(me_shape.size()); + (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc); + } + + GeShape shape(ge_shape); + if (shape.GetDimNum() == 0) { + MS_LOG(INFO) << "The dims size of Ge tensor is zero"; + } + // convert me format to ge format + GeFormat ge_format = ConvertFormat(format); + if (ge_format == GeFormat::FORMAT_ND) { + MS_LOG(ERROR) << "undefined data format : " << static_cast(ge_format); + return nullptr; + } + // convert me datatype to ge datatype + GeDataType data_type = ConvertDataType(me_type); + if (data_type == GeDataType::DT_UNDEFINED) { + MS_LOG(ERROR) << "undefined data type :" << me_type; + return nullptr; + } + + auto desc = std::make_shared(shape, ge_format, data_type); + if (desc == nullptr) { + MS_LOG(ERROR) << "Create GeTensorDesc failed!"; + return nullptr; + } + MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size(); + desc->SetRealDimCnt(SizeToInt(me_shape.size())); + return desc; +} + +// if failed, return empty vector. +std::vector TransformUtil::ConvertInputTensors(const std::vector &me_tensors, + const std::string &format) { + std::vector ge_tensors; + + for (size_t index = 0; index < me_tensors.size(); index++) { + MS_EXCEPTION_IF_NULL(me_tensors[index]); + MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize(); + auto shape = me_tensors[index]->shape(); + std::string shape_str; + for (size_t i = 0; i < shape.size(); i++) { + shape_str += std::to_string(shape[i]); + shape_str += " "; + } + MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}"; + MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type(); + + auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format); + if (ge_tensor_ptr != nullptr) { + ge_tensors.emplace_back(ge_tensor_ptr); + } else { + MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!"; + ge_tensors.clear(); + return ge_tensors; + } + } + return ge_tensors; +} + +GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) { + // get tensor data type size + MS_EXCEPTION_IF_NULL(tensor); + size_t type_size = GetDataTypeSize(tensor->data_type()); + if (type_size == kErrorSize) { + MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size; + return nullptr; + } + size_t elements_num = IntToSize(tensor->ElementsNum()); + if (UINT_MAX / type_size < elements_num) { + MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size + << " overflowed UINT_MAX: " << UINT_MAX << "."; + return nullptr; + } + + // get tensor buff size + size_t data_buff_size = elements_num * type_size; + if (data_buff_size == 0) { + MS_LOG(INFO) << "The Me Tensor data buff size is 0."; + } + // create ge tensor + auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format); + if (desc == nullptr) { + MS_LOG(ERROR) << "Failed to get Tensor Desc"; + return nullptr; + } + GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); + if (tensor_ptr != nullptr) { + MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!"; + } + return tensor_ptr; +} + +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims) { + std::vector outputs; + + for (size_t index = 0; index < ge_tensors.size(); index++) { + MeTensorPtr me_tensor_ptr = nullptr; + if (index < request_dims.size()) { + me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]); + } else { + std::vector empty_shape; + me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape); + } + + if (me_tensor_ptr != nullptr) { + outputs.emplace_back(me_tensor_ptr); + } else { + MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!"; + return outputs; + } + } + return outputs; +} + +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors) { + std::vector outputs; + + for (size_t index = 0; index < ge_tensors.size(); index++) { + MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]); + if (me_tensor_ptr != nullptr) { + outputs.emplace_back(me_tensor_ptr); + } else { + MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!"; + return outputs; + } + } + return outputs; +} + +MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) { + switch (type) { + case GeDataType::DT_FLOAT16: + return MeDataType::kNumberTypeFloat16; + case GeDataType::DT_FLOAT: + return MeDataType::kNumberTypeFloat32; + case GeDataType::DT_DOUBLE: + return MeDataType::kNumberTypeFloat64; + case GeDataType::DT_INT64: + return MeDataType::kNumberTypeInt64; + case GeDataType::DT_INT32: + return MeDataType::kNumberTypeInt32; + case GeDataType::DT_INT16: + return MeDataType::kNumberTypeInt16; + case GeDataType::DT_INT8: + return MeDataType::kNumberTypeInt8; + case GeDataType::DT_BOOL: + return MeDataType::kNumberTypeBool; + case GeDataType::DT_UINT8: + return MeDataType::kNumberTypeUInt8; + case GeDataType::DT_UINT16: + return MeDataType::kNumberTypeUInt16; + case GeDataType::DT_UINT32: + return MeDataType::kNumberTypeUInt32; + case GeDataType::DT_UINT64: + return MeDataType::kNumberTypeUInt64; + case GeDataType::DT_UNDEFINED: + case GeDataType::DT_DUAL_SUB_UINT8: + case GeDataType::DT_DUAL_SUB_INT8: + case GeDataType::DT_DUAL: + return MeDataType::kTypeUnknown; + default: + return MeDataType::kTypeUnknown; + } +} + +namespace { +bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &request_dims) { + MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims()); + MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims); + + const int GE_DIMS = 4; + std::vector ge_dims = ge_shape.GetDims(); + if (request_dims.size() > ge_dims.size()) { + MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's"; + return false; + } + + // convert NHWC to NCHW + if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) && + (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) { + MS_LOG(INFO) << "Ge tensor shape and request shape is compatible"; + return true; + } + + std::string::size_type i = 0; + for (; i < request_dims.size(); i++) { + if (ge_dims[i] != request_dims[i]) { + MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's"; + return false; + } + } + + for (; i < ge_dims.size(); i++) { + if (ge_dims[i] != 1) { + MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1"; + return false; + } + } + MS_LOG(INFO) << "Ge tensor shape and request shape is compatible"; + return true; +} +} // namespace + +GeShape TransformUtil::ConvertMeShape(const std::vector &me_dims) { + std::vector ge_dims; + (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims)); + return GeShape(ge_dims); +} + +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { + std::vector me_dims; + std::vector ge_dims = ge_shape.GetDims(); + (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims)); + return me_dims; +} + +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims) { + vector ret; + if (ge_shape.GetDimNum() == 0) { + MS_LOG(DEBUG) << "GeTensor's shape is scalar"; + return ret; + } + + if (IsGeShapeCompatible(ge_shape, request_dims) == true) { + ret = request_dims; + } else { + MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape"; + ret = ConvertGeShape(ge_shape); + } + return ret; +} + +MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type) { + MeTensor me_tensor(me_type, me_dims); + + // Get the writable data pointer of the tensor and cast it to its data type + auto me_data_ptr = reinterpret_cast(me_tensor.data_c()); + size_t me_data_size = static_cast(me_tensor.data().nbytes()); + MS_EXCEPTION_IF_NULL(me_data_ptr); + MS_EXCEPTION_IF_NULL(ge_tensor); + if (me_data_size < ge_tensor->GetSize()) { + MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor [" + << ge_tensor->GetSize() << " bytes]"; + return nullptr; + } + + // Copy or use the writable data pointer of the ME tensor + MS_EXCEPTION_IF_NULL(ge_tensor->GetData()); + if (ge_tensor->GetSize() == 0) { + MS_LOG(ERROR) << "GE tensor data size is zero!"; + return nullptr; + } + + // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB + // which is the size limit of memcpy_s + memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize()); + + return make_shared(me_tensor); +} + +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) { + MS_EXCEPTION_IF_NULL(ge_tensor); + GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); + vector me_dims = ConvertGeShape(ge_shape); + + TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType()); + if (type_id == MeDataType::kTypeUnknown) { + MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " + << static_cast(ge_tensor->GetTensorDesc().GetDataType()); + return nullptr; + } + return GenerateMeTensor(ge_tensor, me_dims, type_id); +} + +// if request_dims is empty, use ge tensor's shape,otherwise convert to request shape +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector &request_dims) { + MS_EXCEPTION_IF_NULL(ge_tensor); + GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); + vector me_dims = ConvertGeShape(ge_shape, request_dims); + MS_LOG(INFO) << "GE tensor type is " << static_cast(ge_tensor->GetTensorDesc().GetDataType()); + // Create a tensor with wanted data type and shape + TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType()); + if (type_id == MeDataType::kTypeUnknown) { + MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " + << static_cast(ge_tensor->GetTensorDesc().GetDataType()); + return nullptr; + } + return GenerateMeTensor(ge_tensor, me_dims, type_id); +} + +std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) { + std::string ret; + if (ge_tensor == nullptr) { + MS_LOG(ERROR) << "Input ge tensor is nullptr"; + return ret; + } + + MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast(ge_tensor->GetTensorDesc().GetDataType()); + switch (ge_tensor->GetTensorDesc().GetDataType()) { + case GeDataType::DT_UINT32: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_FLOAT: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_INT32: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_DOUBLE: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_INT64: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_UINT64: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_INT16: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_UINT16: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_DUAL_SUB_INT8: + case GeDataType::DT_INT8: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_UINT8: + case GeDataType::DT_DUAL_SUB_UINT8: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_FLOAT16: + case GeDataType::DT_BOOL: + case GeDataType::DT_UNDEFINED: + case GeDataType::DT_DUAL: + default: + MS_LOG(ERROR) << "Unsupported to print type:" << static_cast(ge_tensor->GetTensorDesc().GetDataType()) + << " ge tensor"; + break; + } + return ret; +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/util.h b/mindspore/ccsrc/transform/graph_ir/util.h new file mode 100644 index 0000000000..32d4242c4f --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/util.h @@ -0,0 +1,241 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_UTIL_H_ +#define TRANSFORM_UTIL_H_ + +#include +#include +#include +#include +#include "securec/include/securec.h" +#include "ir/anf.h" +#include "ir/dtype.h" +#include "ir/tensor.h" +#include "transform/graph_ir/types.h" + +#include "graph/tensor.h" + +namespace mindspore { +namespace transform { +class TransformUtil { + public: + /* + * Parameters: + * type: [MeDataType] the data type for ME tensor + * Return: + * [GeDataType] the data type for ge tensor + * */ + static std::vector ConvertIntToList(int64_t data, int size); + + /* + * Parameters: + * type: [MeDataType] the data type for ME tensor + * Return: + * [GeDataType] the data type for ge tensor + * */ + static GeDataType ConvertDataType(const MeDataType &type); + + /* + * Parameters: + * type: [string] the data format in ME op + * Return: + * [GeFormat] the data format for ge tensor + * */ + static GeFormat ConvertFormat(const std::string &format); + + /* + * Parameters: + * type: [MeDataType] the data type for ME tensor + * Return: + * [size_t] the buff size for the type in ME + * */ + static size_t GetDataTypeSize(const MeDataType &type); + + /* + * Parameters: + * tensor: [MeTensorPtr] the me tensor to get description from + * format: [string] the data format in ME + * is_input: [bool] whether the tensor is used as input, default:false + * Return: + * [shared_ptr] the shared pointer of ge tensor description + * */ + static std::shared_ptr GetGeTensorDesc(const std::vector &shape, const MeDataType &me_type, + const std::string &format); + + /* + * Parameters: + * tensor: [MeTensor] the data tensor in ME + * format: [string] the data format in ME op + * is_input: [bool] whether the tensor is used as input, default:false + * Return: + * [GeTensor] the data tensor in GE + * */ + static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format); + + /* + * Parameters: + * me_tensors: [vector] the data tensors in ME + * format: [string] the data format in ME op + * Return: + * [std::vector] the data tensors in GE + * */ + static std::vector ConvertInputTensors(const std::vector &me_tensors, + const std::string &format); + + /* + * Parameters: + * tensor: [GeTensor] the data tensor in GE + * Return: + * [MeTensor] the data tensor in ME + * */ + static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor); + + /* + * Parameters: + * tensor: [GeTensor] the data tensor in GE + * request_dims [std::vector] the output Me tensors must adjust to this shapes + * Return: + * [MeTensor] the data tensor in ME + * */ + static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector &request_dims); + /* + * Parameters: + * ge_tensors: [std::vector] the data tensor in GE + * request_dims [std::vector>] the output Me tensors must adjust to this shapes + * Return: + * [std::vector] the data tensor in ME + * */ + static std::vector ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims); + /* + * Parameters: + * ge_tensors: [std::vector] the data tensor in GE + * Return: + * [std::vector] the data tensor in ME + * */ + static std::vector ConvertGeTensors(const std::vector &ge_tensors); + /* + * Parameters: + * ge_tensor: [GeTensor] the data tensor in GE + * me_dims: [std::vector] the shape of created Me tensor + * me_type: [TypeId] the type of created Me tensor + * Return: + * [MeTensor] the data tensor in ME + * */ + static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type); + /* + * Parameters: + * type: [GeDataType] the ge tensor data type + * Return: + * [MeDataType] the me tensor data type + * */ + static MeDataType ConvertGeDataType(const GeDataType &type); + + /* + * Parameters: + * me_dims: [std::vector] the me shape + * Return: + * [GeShape] the ge shape + * */ + static GeShape ConvertMeShape(const std::vector &me_dims); + + /* + * Parameters: + * ge_shape: [GeShape] the ge shape + * Return: + * [vector] the me shape + * */ + static std::vector ConvertGeShape(const GeShape &ge_shape); + + /* Function: + * Convert GeShape to Me request shape, Support pattern: + * {1, x, 1, 1} --> {x} + * {x, 1, 1, 1} --> {x} + * {x, x, 1, 1} --> {x, x} + * {x, x, x, 1} --> {x, x, x} + * {x, x, x, x} --> {x, x, x, x} + * If unmatch upon patterns, return original ge dims + * Parameters: + * ge_shape: [GeShape] the ge shape + * request_dims: [vector] request dims + * Return: + * [vector] the me shape + * */ + static std::vector ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims); + + /* + * Parameters: + * vec: [std::vector] the vector to print + * Return: + * [string] value string + * */ + template ::value>::type> + static std::string PrintVector(const std::vector &vec) { + const int MAX_PRINT_NUM = 100; + std::stringstream ss; + ss << "{ "; + int i = 0; + for (auto it = vec.begin(); it != vec.end(); ++it) { + ss << std::to_string(*it) << ", "; + i++; + if (i >= MAX_PRINT_NUM) { + break; + } + } + + if (i >= MAX_PRINT_NUM) { + ss << "... to be continue}"; + } else { + ss << "}"; + } + return ss.str(); + } + + /* + * Parameters: + * ge_tensor: [GeTensorPtr] the ge tensor + * Return: + * [stringstream] value string + * */ + static std::string PrintGeTensor(const GeTensorPtr ge_tensor); + + /* + * Parameters: + * data: [uint8_t *] the ge tensor data pointer + * size: [size_t] the ge tensor data bytes + * Return: + * [shared_ptr] vector pointer + * */ + template ::value>::type> + static std::vector MakeVector(const uint8_t *const data, size_t size) { + auto dest = std::vector(size / sizeof(T)); + if (data == nullptr) { + return dest; + } + + errno_t ret = memcpy_s(dest.data(), dest.size() * sizeof(T), data, size); + if (EOK != ret) { + return std::vector(); + } + return dest; + } +}; +} // namespace transform +} // namespace mindspore + +#endif // TRANSFORM_UTIL_H_ diff --git a/mindspore/ccsrc/transform/graph_runner.cc b/mindspore/ccsrc/transform/graph_runner.cc deleted file mode 100644 index 52d0d8e17f..0000000000 --- a/mindspore/ccsrc/transform/graph_runner.cc +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * Limitations under the License. - */ - -#include "transform/graph_runner.h" -#include -#include -#include -#include "utils/log_adapter.h" -#include "utils/config_manager.h" -#include "sys/time.h" -#include "utils/callbacks.h" -#include "utils/utils.h" -#include "./common.h" -#ifdef ENABLE_GE -#include "utils/callbacks_ge.h" -#endif - -#ifdef NO_GE_CLIENT -namespace ge { -Session::Session(const std::map &options) { - if (options.empty()) { - MS_LOG(ERROR) << "session input options is empty"; - } - sessionId_ = 0; -} -Session::~Session() {} -} // namespace ge -#endif - -namespace mindspore { -namespace transform { -std::shared_ptr GraphRunner::NewSession(const SessionOptions &sess_options) { - std::shared_ptr ret = std::make_shared(sess_options); - if (ret == nullptr) { - MS_LOG(ERROR) << "Create GE session failed"; - return nullptr; - } - MS_LOG(INFO) << "Create new GE session success"; - return ret; -} - -GraphRunner::GraphRunner(const GraphRunnerOptions &options) - : options_(options), graph_manager_(DfGraphManager::GetInstance()) { - if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) { - MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode"; - } - - if (options.sess_ptr != nullptr) { - sess_ = options.sess_ptr; - } else { - sess_ = NewSession(options.options); - if (sess_ == nullptr) { - MS_LOG(EXCEPTION) << "GraphRunner initialize failed!!"; - return; - } - } - -#if (defined ENABLE_GE) - // register the callback function - if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ge::GRAPH_SUCCESS) { - MS_LOG(EXCEPTION) << "register callback failed!"; - return; - } - - if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ge::GRAPH_SUCCESS) { - MS_LOG(EXCEPTION) << "register summary callback failed!"; - return; - } -#endif - - std::vector wrappers = graph_manager_.GetAllGraphs(); - if (wrappers.empty()) { - MS_LOG(INFO) << "The GraphManager is empty!!"; - return; - } - -#ifdef ENABLE_GE - for (auto &it : wrappers) { - std::set saved_graph = graph_manager_.GetSavedGraphs(); - auto iter_find = saved_graph.find(std::to_string(it->id_)); - if (iter_find != saved_graph.end()) { - continue; - } - MS_LOG(INFO) << "Add the graph " << (*it).name_ << " to GE, it's id is: " << (*it).id_; - graph_manager_.AddSavedGraphs(std::to_string(it->id_)); - (void)sess_->AddGraph(it->id_, *(it->graph_ptr_), it->options_); - } -#endif -} - -Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, - std::vector *outputs) { - std::string name = options.name; - if (name.empty()) { - MS_LOG(ERROR) << "The graph name is null"; - return Status::INVALID_ARGUMENT; - } - - DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name); - if (wrap_ptr == nullptr) { - MS_LOG(ERROR) << "Get graph form DfGraphManager failed!"; - return Status::NOT_FOUND; - } - - if (wrap_ptr->graph_ptr_ == nullptr) { - MS_LOG(WARNING) << "The graph is null"; - return Status::NOT_FOUND; - } - - // call ge::RunGraph() to exec a graph; - std::vector ge_inputs; - std::vector ge_outputs; - - (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs), - [](const GeTensorPtr &i) { return *i; }); - - MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs"; - - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - -#ifdef ENABLE_GE - if (sess_ == nullptr) { - MS_LOG(ERROR) << "The GE session is null, can't run the graph!"; - return Status::FAILED; - } - - // The information of some nodes could be changed after fusion in some cases - // Therefore a graph needs to be rebuilt in above situation - if (sess_->IsGraphNeedRebuild(wrap_ptr->id_)) { - sess_->RemoveGraph(wrap_ptr->id_); - sess_->AddGraph(wrap_ptr->id_, *(wrap_ptr->graph_ptr_), wrap_ptr->options_); - } - - ge::Status ret = sess_->RunGraph(wrap_ptr->id_, ge_inputs, ge_outputs); - if (ret != ge::GRAPH_SUCCESS) { - MS_LOG(ERROR) << "Call GE RunGraph Failed, ret is: " << ret; - return Status::FAILED; - } -#else - ge_outputs.swap(ge_inputs); -#endif - - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Call GE RunGraph Success in " << cost << " us, the GE outputs num is: " << ge_outputs.size(); - - (void)std::transform(ge_outputs.begin(), ge_outputs.end(), std::back_inserter(*outputs), - [](const GeTensor &ge_tensor) { return std::make_shared(ge_tensor); }); - - return Status::SUCCESS; -} - -Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, - std::vector *const outputs) { - std::vector ge_inputs; - for (auto it : inputs) { - MS_LOG(INFO) << "inputs tensor's data size is: " << (*it).DataSize(); - auto shape = (*it).shape(); - std::string shape_str; - for (const auto &elem : shape) { - shape_str += std::to_string(elem); - shape_str += " "; - } - MS_LOG(INFO) << "inputs tensor's shape is: { " << shape_str << "}"; - - auto ge_tensor_ptr = TransformUtil::ConvertTensor(it, kOpFormat_NCHW); - if (ge_tensor_ptr != nullptr) { - ge_inputs.emplace_back(ge_tensor_ptr); - } else { - MS_LOG(INFO) << "Convert input Me tensor to Ge tensor failed. Abort this graph"; - return Status::FAILED; - } - } - - std::vector ge_outputs; - Status ret; - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - ret = RunGraph(options, ge_inputs, &ge_outputs); - } - if (ret != Status::SUCCESS) { - return ret; - } else { - // conver GeTensor to MeTensor - for (auto &it : ge_outputs) { - auto tensor = TransformUtil::ConvertGeTensor(it); - if (tensor != nullptr) { - outputs->emplace_back(tensor); - } - } - MS_LOG(INFO) << "Return Me tensor outputs num is: " << outputs->size(); - return Status::SUCCESS; - } -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_runner.h b/mindspore/ccsrc/transform/graph_runner.h deleted file mode 100644 index 30769c8310..0000000000 --- a/mindspore/ccsrc/transform/graph_runner.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_GRAPH_RUNNER_H_ -#define TRANSFORM_GRAPH_RUNNER_H_ - -#include -#include -#include -#include -#include - -#include "transform/types.h" -#include "transform/util.h" -#include "ir/tensor.h" -#include "transform/df_graph_manager.h" - -namespace mindspore { -namespace transform { -using SessionOptions = std::map; - -struct GraphRunnerOptions { - std::string target{"default_graph_runner"}; - SessionOptions options; - // if sess_ptr is nullptr, GraphRunner will create a new ge session - std::shared_ptr sess_ptr{nullptr}; -}; - -struct RunOptions { - // graph's name - std::string name; -}; - -class GraphRunner { - public: - explicit GraphRunner(const GraphRunnerOptions &options); - ~GraphRunner() { sess_ = nullptr; } - Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); - Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); - static std::shared_ptr NewSession(const SessionOptions &sess_options); - - private: - std::shared_ptr sess_; - transform::GraphRunnerOptions options_; - DfGraphManager &graph_manager_; -}; -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_GRAPH_RUNNER_H_ diff --git a/mindspore/ccsrc/transform/onnx/CMakeLists.txt b/mindspore/ccsrc/transform/onnx/CMakeLists.txt new file mode 100644 index 0000000000..0d2f6c947b --- /dev/null +++ b/mindspore/ccsrc/transform/onnx/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _ONNX_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX) +add_library(_mindspore_transform_onnx_obj OBJECT ${_ONNX_SRC_FILES}) diff --git a/mindspore/ccsrc/transform/onnx/ir_exporter.cc b/mindspore/ccsrc/transform/onnx/ir_exporter.cc new file mode 100644 index 0000000000..78858eea8a --- /dev/null +++ b/mindspore/ccsrc/transform/onnx/ir_exporter.cc @@ -0,0 +1,618 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/tensor.h" +#include "ir/param_value.h" +#include "debug/anf_ir_utils.h" +#include "frontend/operator/ops.h" +#include "proto/onnx.pb.h" + +namespace mindspore { +using FloatPtr = std::shared_ptr; +using IntPtr = std::shared_ptr; + +// anf type to onnx type map +static std::unordered_map g_data_type_map = { + {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, + {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, + {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, + {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, + {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, + {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, + {kObjectTypeString, onnx::TensorProto_DataType_STRING}, +}; + +static std::unordered_map g_data_bits_int_map = { + {8, onnx::TensorProto_DataType_INT8}, + {16, onnx::TensorProto_DataType_INT16}, + {32, onnx::TensorProto_DataType_INT32}, + {64, onnx::TensorProto_DataType_INT64}, +}; + +static std::unordered_map g_data_bits_float_map = { + {16, onnx::TensorProto_DataType_FLOAT16}, + {32, onnx::TensorProto_DataType_FLOAT}, +}; + +// Can build different builder according to format +class IrExportBuilder; +using IrExportBuilderPtr = std::shared_ptr; + +class IrExporter { + public: + explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {} + virtual ~IrExporter() = default; + std::string GetDumpString(const FuncGraphPtr &func_graph); + + private: + IrExportBuilderPtr builder_; +}; + +class IrExportBuilder { + public: + IrExportBuilder() = default; + ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); } + std::string GetProtoString(const FuncGraphPtr &func_graph); + void BuildModelInfo(); + void BuildModel(const FuncGraphPtr &func_graph); + + private: + void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto); + void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto); + std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto); + + void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto); + void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto); + void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto); + void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); + void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); + void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); + void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, + std::string suffix = "0"); + void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); + void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); + + onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); + onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); + onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits); + std::string GetNodeName(const AnfNodePtr &node); + std::string GetUniqueNodeName(const AnfNodePtr &node); + std::string GetOpTypeName(const AnfNodePtr &node); + size_t AllocateIndex() { return ++node_index_; } + void ResetIndex() { node_index_ = 0; } + + private: + onnx::ModelProto model_; + onnx::NodeProto *last_node_{nullptr}; + std::list todo_; + std::map node_index_map_; + size_t node_index_{0}; +}; + +using IrExporterPtr = std::shared_ptr; + +std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) { + if ((builder_ == nullptr) || (func_graph == nullptr)) { + MS_LOG(EXCEPTION) << "Input params is null."; + } + + // Export model info + builder_->BuildModelInfo(); + + // Export model and return string + builder_->BuildModel(func_graph); + + return builder_->GetProtoString(func_graph); +} + +std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) { + MS_LOG(DEBUG) << "BuildModel complete!"; + return model_.SerializeAsString(); +} + +void IrExportBuilder::BuildModelInfo() { + model_.set_ir_version(onnx::IR_VERSION_2019_1_22); + model_.set_producer_name("MindSpore"); + model_.set_model_version(1); +} + +void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { + onnx::GraphProto *graph_proto = model_.mutable_graph(); + graph_proto->set_name(func_graph->ToString()); + ResetIndex(); + todo_.clear(); + todo_.push_back(func_graph); + while (!todo_.empty()) { + FuncGraphPtr fg = todo_.back(); + todo_.pop_back(); + BuildFuncGraph(fg, graph_proto); + } +} + +void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + // Export parameters + // 1. parameters should be mapped to ValueInfoProto + // 2. parameters with default value should be mapped to Initializer + BuildParameters(func_graph, graph_proto); + + // Export operator nodes(include output) + BuildNodes(func_graph, graph_proto); +} + +void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + for (auto &item : func_graph->parameters()) { + auto param = item->cast(); + if (param == nullptr) { + MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; + } + onnx::ValueInfoProto *input_proto = graph_proto->add_input(); + std::string param_name = GetUniqueNodeName(param); + input_proto->set_name(param_name); + SetValueInfoProto(param, input_proto); + if (!param->has_default()) { + MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; + continue; + } + + // Using ONNX initializer to set parameter's default value + onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); + initializer_proto->set_name(param_name); + SetParamToTensorProto(param, initializer_proto); + auto tensor = std::dynamic_pointer_cast(param->default_param()->value()); + if (tensor) { + initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); + } + } +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) { + auto iter = g_data_type_map.find(type_id); + if (iter == g_data_type_map.end()) { + MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id; + } + return iter->second; +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) { + auto iter = g_data_bits_int_map.find(bits); + if (iter == g_data_bits_int_map.end()) { + MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits; + } + return iter->second; +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) { + auto iter = g_data_bits_float_map.find(bits); + if (iter == g_data_bits_float_map.end()) { + MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits; + } + return iter->second; +} + +void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) { + if (node == nullptr || value_proto == nullptr) { + MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; + } + MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); + SetValueInfoProto(node->Type(), node->Shape(), value_proto); +} + +void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::ValueInfoProto *const value_proto) { + onnx::TypeProto *type_proto = value_proto->mutable_type(); + if (type->isa() && shape->isa()) { + auto tensor = type->cast(); + auto elem_type = tensor->element(); + const auto &dims = shape->cast()->shape(); + type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); + for (const auto &dim : dims) { + MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + } else if (type->isa()) { + auto tup_shape = shape->cast(); + type_proto->set_denotation(std::to_string(tup_shape->shape().size())); + } else { + MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; + } +} + +void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("tensor"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + auto data = value->cast(); + tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); + auto dtype = data->data_type(); + auto shape = data->shape_c(); + tensor_proto->set_data_type(GetOnnxDataType(dtype)); + for (const auto &dim : shape) { + tensor_proto->add_dims(dim); + } +} + +void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::TensorProto *const tensor_proto) { + if (!type->isa() || !shape->isa()) { + MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString(); + } + auto tensor = type->cast(); + const auto &dims = shape->cast()->shape(); + tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id())); + for (const auto &dim : dims) { + tensor_proto->add_dims(dim); + } +} + +void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { + if (param == nullptr || tensor_proto == nullptr) { + MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; + } + MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString(); + SetTensorProto(param->Type(), param->Shape(), tensor_proto); +} + +void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); + for (const AnfNodePtr &node : nodes) { + if (!node->isa()) { + MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; + continue; + } + auto cnode = node->cast(); + if (cnode == func_graph->get_return()) { + BuildOutput(cnode, graph_proto); + } else { + BuildCNode(cnode, graph_proto); + } + } +} + +void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) { + if (node->size() != 2) { + MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; + } + AnfNodePtr arg = node->input(1); + // Using make_tuple to set multi-output + if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { + auto tuple_node = arg->cast(); + for (size_t i = 1; i < tuple_node->size(); i++) { + auto input_node = arg->cast()->input(i); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + auto output_name = GetUniqueNodeName(tuple_node->input(i)); + output_proto->set_name(output_name); + last_node_->add_output(output_name); + SetValueInfoProto(tuple_node->input(i), output_proto); + } + } else { + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + std::string output_name = GetUniqueNodeName(node); + output_proto->set_name(output_name); + last_node_->add_output(output_name); + SetValueInfoProto(arg, output_proto); + } +} + +std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { + // May be ValueNode/CNode/Parameter + std::string type_name = ""; + if (IsValueNode(node)) { + PrimitivePtr prim = GetValueNode(node); + type_name = prim->ToString(); + } else if (IsValueNode(node)) { + FuncGraphPtr fg = GetValueNode(node); + todo_.push_back(fg); + type_name = fg->ToString(); + } else if (node->isa() || node->isa()) { + type_name = node->ToString(); + } else { + MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name(); + } + MS_LOG(DEBUG) << "ExportType: " << type_name; + return type_name; +} + +void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::NodeProto *const node_proto, std::string suffix) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_ref_attr_name("shape"); + if (suffix.compare("0") != 0) { + attr_proto->set_name("shape" + suffix); + } else { + attr_proto->set_name("shape"); + } + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + SetTensorProto(type, shape, tensor_proto); +} + +void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { + // Get shape of cnode + // 1. prim ArgMaxWithValue need to get shape from tuple element + // 2. some cnode doesn't has shape, such as LayerNorm + // 3. other cnodes have shape + if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { + auto type = node->Type(); + auto shape = node->Shape(); + if (!type->isa()) { + MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); + } + auto elements = type->cast()->elements(); + auto tuple_shape = shape->cast()->shape(); + for (size_t i = 0; i < elements.size(); i++) { + SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); + } + } else { + auto type = node->Type(); + auto shape = node->Shape(); + if (!type->isa() || !shape->isa()) { + MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); + return; + } + SetShapeToNodeProto(type, shape, node_proto); + } +} + +void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { + auto inputs_size = node->size(); + if (inputs_size < 1) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + + // Need to build input node before dealing with cnode + std::vector op_inputs; + std::vector input_names; + for (size_t i = 1; i < inputs_size; i++) { + auto input = node->input(i); + op_inputs.push_back(input); + input_names.push_back(BuildInputNode(input, graph_proto)); + } + + // Build cnode + onnx::NodeProto *node_proto = graph_proto->add_node(); + std::string output_name = GetUniqueNodeName(node); + node_proto->add_output(output_name); + node_proto->set_name(output_name); + node_proto->set_domain(node->fullname_with_scope()); + AnfNodePtr op = node->input(0); + std::string type_name = GetOpTypeName(op); + node_proto->set_op_type(type_name); + last_node_ = node_proto; + SetShapeToNodeProto(node, node_proto); + (void)std::for_each(input_names.begin(), input_names.end(), + [&node_proto](const string &name) { node_proto->add_input(name); }); + + // Add primitive attrs + if (IsValueNode(op)) { + auto prim = GetValueNode(op); + for (auto attr : prim->attrs()) { + MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name(attr.first); + SetValueToAttributeProto(attr.second, attr_proto); + } + } else { + MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name(); + } +} + +std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) { + std::string node_name = GetUniqueNodeName(node); + if (node->isa()) { + // When node input is a ValueNode, need to create a Constant Node + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->add_output(node_name); + SetAttributeProto(node, node_proto); + } + return node_name; +} + +std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { + // Naming anfnode + // 1. parameter is unique in one func_graph + // 2. cnode and valuenode may be reduplicative, so add index to identify. + std::string node_name = ""; + if (node->isa()) { + node_name = GetNodeName(node); + } else if (node->isa() || node->isa()) { + auto iter = node_index_map_.find(node); + if (iter != node_index_map_.end()) { + node_name = GetNodeName(node) + ":" + std::to_string(iter->second); + } else { + auto node_idx = AllocateIndex(); + node_index_map_[node] = node_idx; + node_name = GetNodeName(node) + ":" + std::to_string(node_idx); + } + } else { + MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); + } + MS_LOG(DEBUG) << "Node name: " << node_name; + return node_name; +} + +std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { + std::string node_name = ""; + if ((node != nullptr) && (node->func_graph() != nullptr)) { + node_name = node->func_graph()->ToString() + ":"; + } + node_name += node->ToString(); + MS_LOG(DEBUG) << "GetNodeName: " << node_name; + return node_name; +} + +void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) { + if (node == nullptr || node_proto == nullptr) { + MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; + } + auto value = node->cast()->value(); + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString(); + SetValueToAttributeProto(value, attr_proto); +} + +void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("type"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + if (value->isa()) { + auto int_value = value->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); + } else if (value->isa()) { + auto float_value = value->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); + } else if (value->isa()) { + tensor_proto->set_name("tensor"); + auto elem_type = value->cast()->element(); + if (elem_type->isa()) { + auto int_value = elem_type->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); + } else if (elem_type->isa()) { + auto float_value = elem_type->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); + } else { + MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name(); + } + } else { + MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); + } +} + +void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + if (value->isa() || value->isa()) { + SetScalarToAttributeProto(value, attr_proto); + } else if (value->isa() || value->isa()) { + SetTypeToAttributeProto(value, attr_proto); + } else if (value->isa()) { + SetSequenceToAttributeProto(value->cast(), attr_proto); + } else if (value->isa()) { + SetTensorToAttributeProto(value, attr_proto); + } else { + MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); + } +} + +void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("scalar"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + SetScalarToProto(value, tensor_proto); +} + +void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { + if (value == nullptr || tensor_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; + } + if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); + tensor_proto->add_string_data(GetValue(value)); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); + tensor_proto->add_int32_data(GetValue(value)); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + tensor_proto->add_int64_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); + tensor_proto->add_float_data(GetValue(value)); + } else { + MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); + } +} + +void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, + onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("scalar"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + if (value->isa()) { + const ValueTuplePtr &tuple_value = value->cast(); + if (tuple_value->value().size() == 0) { + MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; + return; + } + auto type_id = tuple_value->value()[0]->type()->type_id(); + tensor_proto->set_data_type(GetOnnxDataType(type_id)); + for (const auto &item : tuple_value->value()) { + SetScalarToProto(item, tensor_proto); + } + } else if (value->isa()) { + const ValueListPtr &list_value = value->cast(); + if (list_value->value().size() == 0) { + MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; + return; + } + auto type_id = list_value->value()[0]->type()->type_id(); + tensor_proto->set_data_type(GetOnnxDataType(type_id)); + for (const auto &item : list_value->value()) { + SetScalarToProto(item, tensor_proto); + } + } +} + +std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { + auto builder = std::make_shared(); + if (builder == nullptr) { + MS_LOG(ERROR) << "Create ir exporter failed!"; + return ""; + } + auto exporter = std::make_shared(builder); + if (exporter == nullptr) { + return ""; + } + return exporter->GetDumpString(func_graph); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/onnx/onnx_exporter.cc b/mindspore/ccsrc/transform/onnx/onnx_exporter.cc new file mode 100644 index 0000000000..f69fb81a7e --- /dev/null +++ b/mindspore/ccsrc/transform/onnx/onnx_exporter.cc @@ -0,0 +1,1207 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "debug/anf_ir_utils.h" +#include "proto/onnx.pb.h" +#include "frontend/operator/ops.h" +#include "ir/tensor.h" +#include "ir/param_value.h" + +namespace mindspore { +enum OpMergeMode { + OP_MERGE_UNDEFINED = 0, // undefined behavior + OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list + OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` + OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` + OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` + OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` +}; + +struct OpMergedInfo { + OpMergeMode mode = OP_MERGE_UNDEFINED; + int referred_count = 0; +}; + +using GenAttrFuncType = + std::function; + +template +void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { + auto casted_value = dyn_cast(value); + if (casted_value == nullptr) { + MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; + } + auto attr_value = casted_value->value(); + switch (attr_type) { + case onnx::AttributeProto_AttributeType_INT: + attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); + break; + case onnx::AttributeProto_AttributeType_FLOAT: + attr_proto->set_f(static_cast(attr_value)); + break; + case onnx::AttributeProto_AttributeType_INTS: + for (size_t i = 0; i < rep_cnt; ++i) { + attr_proto->add_ints(static_cast<::google::protobuf::int64>(attr_value)); + } + break; + case onnx::AttributeProto_AttributeType_FLOATS: + for (size_t i = 0; i < rep_cnt; ++i) { + attr_proto->add_floats(static_cast(attr_value)); + } + break; + default: + MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type; + } + attr_proto->set_type(attr_type); +} + +template +void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { + auto tuple_ptr = dyn_cast(value); + if (tuple_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; + } + switch (attr_type) { + case onnx::AttributeProto_AttributeType_INTS: + for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) { + attr_proto->add_ints(GetValue((*tuple_ptr)[i])); + } + break; + case onnx::AttributeProto_AttributeType_FLOATS: + for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) { + attr_proto->add_floats(GetValue((*tuple_ptr)[i])); + } + break; + default: + MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type; + } + attr_proto->set_type(attr_type); +} + +void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); + auto attr_value = GetValue(value); + if (attr_value == "VALID") { + attr_proto->set_s("VALID"); + } else { + attr_proto->set_s("SAME_UPPER"); + } +} + +class OpAttrInfo { + public: + OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) + : attr_name_(attr_name), + onnx_attr_name_(onnx_attr_name), + onnx_attr_type_(onnx_attr_type), + fn_gen_attr_(fn_gen_attr) {} + ~OpAttrInfo() {} + + const std::string &attr_name() const { return attr_name_; } + const std::string &onnx_attr_name() const { return onnx_attr_name_; } + onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } + GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } + + private: + std::string attr_name_; // attribute name of MindSpore + std::string onnx_attr_name_; // corresponding attribute name of ONNX + onnx::AttributeProto_AttributeType onnx_attr_type_; // corresponding attribute type of ONNX + GenAttrFuncType fn_gen_attr_; // function used convert +}; + +class OpNameInfo { + public: + OpNameInfo &set_op_type(const std::string &op_type) { + op_type_ = op_type; + return *this; + } + + const std::string &op_type() const { return op_type_; } + + OpNameInfo &set_onnx_type(const std::string &onnx_type) { + onnx_type_ = onnx_type; + return *this; + } + + const std::string &onnx_type() const { return onnx_type_; } + + OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { + op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); + return *this; + } + + const std::vector &op_attrs() const { return op_attrs_; } + + private: + std::string op_type_; // operator type of MindSpore + std::string onnx_type_; // corresponding ONNX operator type + std::vector op_attrs_; // operator attributes map info +}; + +#define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \ + OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); } + +OPERATOR_ONNX_CONVERT_DEFINE(TensorAdd, Add, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo()) + +OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo()) + +OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Squeeze, Squeeze, + OpNameInfo().Attr("axis", "axes", onnx::AttributeProto_AttributeType_INTS, + SetAttrTupleValueToProto<0>)) + +OPERATOR_ONNX_CONVERT_DEFINE( + Conv2D, Conv, + OpNameInfo() + .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto) + .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) + .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, + [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, + const PrimitivePtr &prim) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); + auto attr_value = GetValue(value); + if (attr_value == "valid") { + attr_proto->set_s("VALID"); + } else if (attr_value == "same") { + attr_proto->set_s("SAME_UPPER"); + } else { // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads' + attr_proto->set_name("pads"); + SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, attr_proto, + prim); + } + }) + .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) +OPERATOR_ONNX_CONVERT_DEFINE(BiasAdd, Add, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(MatMul, Gemm, + OpNameInfo() + .Attr("transpose_a", "transA", onnx::AttributeProto_AttributeType_INT, + SetAttrValueToProto) + .Attr("transpose_b", "transB", onnx::AttributeProto_AttributeType_INT, + SetAttrValueToProto)) + +OPERATOR_ONNX_CONVERT_DEFINE(BatchNorm, BatchNormalization, + OpNameInfo().Attr("epsilon", "epsilon", onnx::AttributeProto_AttributeType_FLOAT, + SetAttrValueToProto)) + +OPERATOR_ONNX_CONVERT_DEFINE(Reshape, Reshape, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(ReduceMean, ReduceMean, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Cast, Cast, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(PReLU, PRelu, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, + OpNameInfo() + .Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT, + SetAttrValueToProto) + .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, + [](ValuePtr, onnx::AttributeProto_AttributeType, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto->set_i(0); + })) + +OPERATOR_ONNX_CONVERT_DEFINE(SimpleMean, AveragePool, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE( + MaxPool, MaxPool, + OpNameInfo() + .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) + .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) + +OPERATOR_ONNX_CONVERT_DEFINE( + MaxPoolWithArgmax, MaxPool, + OpNameInfo() + .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) + .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) + +OPERATOR_ONNX_CONVERT_DEFINE( + AvgPool, AveragePool, + OpNameInfo() + .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) + .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) + +OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) + +#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name + +void RegisterOpConverters(const std::function &fn) { + fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); + fn(OP_CONVERT_FUNCTION_NAME(Mul)()); + + fn(OP_CONVERT_FUNCTION_NAME(ReLU)()); + fn(OP_CONVERT_FUNCTION_NAME(Sigmoid)()); + + fn(OP_CONVERT_FUNCTION_NAME(Conv2D)()); + fn(OP_CONVERT_FUNCTION_NAME(Argmax)()); + + fn(OP_CONVERT_FUNCTION_NAME(Flatten)()); + fn(OP_CONVERT_FUNCTION_NAME(MaxPool)()); + fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)()); + fn(OP_CONVERT_FUNCTION_NAME(AvgPool)()); + + fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); + fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)()); + fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); + + fn(OP_CONVERT_FUNCTION_NAME(make_tuple)()); + fn(OP_CONVERT_FUNCTION_NAME(Concat)()); + fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); + fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); + fn(OP_CONVERT_FUNCTION_NAME(Sub)()); +} + +class OpConvertRegistry { + public: + ~OpConvertRegistry() { Clear(); } + + static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } + + static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } + + static OpConvertRegistry &GetSingleton() { + static OpConvertRegistry registry = OpConvertRegistry(); + return registry; + } + + static const std::unordered_map &GetOpConvertMap() { return GetSingleton().op_map_; } + + void Clear() noexcept { op_map_.clear(); } + + private: + OpConvertRegistry() {} + + std::unordered_map op_map_; +}; + +class OnnxExporter { + public: + OnnxExporter() {} + ~OnnxExporter() {} + + std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); + + private: + void InitModelInfo(); + + void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); + + size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *graph_proto); + + static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); + void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); + void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); + + void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr); + void ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + + void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + + void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + std::string GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto); + + void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); + void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); + + size_t AllocateNodeIndex() { return ++onnx_node_index_; } + + void ResetNodeIndex() { onnx_node_index_ = 0; } + + static int GetInt32Value(const AnfNodePtr &node) { + auto value_node_ptr = dyn_cast(node); + MS_EXCEPTION_IF_NULL(value_node_ptr); + return GetValue(value_node_ptr->value()); + } + + onnx::ModelProto model_; + + size_t onnx_node_index_ = 0; +}; + +std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + return ""; + } + ResetNodeIndex(); + OpConvertRegistry::GetSingleton().Clear(); + OpConvertRegistry::RegisterAllOpConverters(); + InitModelInfo(); + onnx::GraphProto *graph_proto = model_.mutable_graph(); + ExportFuncGraph(func_graph, graph_proto); + return model_.SerializeAsString(); +} + +void OnnxExporter::InitModelInfo() { + model_.set_ir_version(onnx::IR_VERSION_2019_1_22); + model_.set_producer_name("MindSpore"); + model_.set_producer_version("1.0"); + onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); + opset_proto->set_version(9); +} + +void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + std::map node_map; + + MS_LOG(INFO) << "Begin exporting onnx model for graph " << func_graph->ToString(); + + onnx_node_index_ = func_graph->parameters().size(); + + // set graph name + graph_proto->set_name(func_graph->ToString()); + + // export parameters + // 1. all parameters (with or without default value) will be mapped to ONNX parameters + // 2. parameters with default value will mapped to ONNX initializers + ExportParameters(func_graph, graph_proto); + + // export computational nodes and output nodes + ExportNodes(func_graph, &node_map, graph_proto); + + MS_LOG(INFO) << "End exporting onnx model for graph " << func_graph->ToString(); +} + +void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + for (auto ¶m : func_graph->parameters()) { + const ParameterPtr param_ptr = dyn_cast(param); + if (param_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; + } + + onnx::ValueInfoProto *input_proto = graph_proto->add_input(); + input_proto->set_name(param_ptr->ToString()); + SetValueInfoType(param_ptr, input_proto); + + if (!param_ptr->has_default()) { + continue; + } + // parameter with default value is an ONNX initializer + onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); + initializer_proto->set_name(param_ptr->ToString()); + SetTensorProtoInfo(param_ptr, initializer_proto); + // set value for initializer + auto tensor = std::dynamic_pointer_cast(param_ptr->default_param()->value()); + if (tensor) { + initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); + } + } +} + +onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) { + // clang-format off + static std::unordered_map type_map = { + {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, + {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, + {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, + {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, + {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, + {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, + {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, + {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, + {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, + {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, + {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, + {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, + }; + // clang-format on + + auto iter = type_map.find(type_id); + if (iter == type_map.end()) { + MS_LOG(EXCEPTION) << "Convert type error, unsupported type " << type_id; + } + + return iter->second; +} + +void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) { + auto dtype = node->Type(); + auto shape = node->Shape(); + onnx::TypeProto *type_proto = value_proto->mutable_type(); + if (dtype->isa() && shape->isa()) { + auto tensor = dyn_cast(dtype); + auto elem_type = tensor->element(); + const auto &dims = dyn_cast(shape)->shape(); + // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 + auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); + type_proto->mutable_tensor_type()->set_elem_type(type); + + for (const auto &dim : dims) { + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + } +} + +void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { + auto dtype = param->Type(); + auto shape = param->Shape(); + if (!dtype->isa() || !shape->isa()) { + MS_LOG(EXCEPTION) << "Parameter " << param->name() << " is not a regular tensor, with value " << param->ToString(); + } + + auto tensor = dyn_cast(dtype); + auto elem_type = tensor->element(); + const auto &dims = dyn_cast(shape)->shape(); + tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); + for (const auto &dim : dims) { + tensor_proto->add_dims(dim); + } +} + +void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr) { + std::unordered_map &op_merged_infos = *op_merged_infos_ptr; + + for (auto &node : nodes) { + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (cnode == func_graph->get_return()) { + // if the key `input` does not exist, just create a new one + op_merged_infos[cnode].referred_count += 1; + } + for (auto &input : cnode->inputs()) { + if (!input->isa()) { + continue; + } + // if the key `input` does not exist, just create a new one + op_merged_infos[input].referred_count += 1; + } + // MindSpore Conv + BiasAdd --> ONNX Conv + if (cnode->IsApply(std::make_shared("BiasAdd")) && + IsPrimitiveCNode(cnode->input(1), prim::kPrimConv2D)) { + op_merged_infos[cnode].mode = OP_MERGE_CONV; + op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; + op_merged_infos[cnode->input(1)].referred_count -= 1; + } else if (cnode->IsApply(std::make_shared("BiasAdd")) && + IsPrimitiveCNode(cnode->input(1), prim::kPrimMatMul)) { + op_merged_infos[cnode].mode = OP_MERGE_GEMM; + op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; + op_merged_infos[cnode->input(1)].referred_count -= 1; + } else if (cnode->IsApply(prim::kPrimTupleGetItem) && + IsPrimitiveCNode(cnode->input(1), std::make_shared("BatchNorm")) && + GetInt32Value(cnode->input(2)) == 0) { + op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM; + op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; + op_merged_infos[cnode->input(1)].referred_count -= 1; + } else if (cnode->IsApply(prim::kPrimTupleGetItem) && + IsPrimitiveCNode(cnode->input(1), std::make_shared("MaxPoolWithArgmax")) && + GetInt32Value(cnode->input(2)) == 0) { + op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX; + op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; + op_merged_infos[cnode->input(1)].referred_count -= 1; + } + } +} + +/** + * AnfNode + * +-- CNode + * +-- ANode + * | +-- Parameter + * | `-- ValueNode + */ +void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); + + std::unordered_map op_merged_infos; + MatchAndMark(func_graph, nodes, &op_merged_infos); + + for (const AnfNodePtr &node : nodes) { + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + auto iter = op_merged_infos.find(cnode); + // the node is not referenced by any other nodes, skip it + if (iter == op_merged_infos.end()) { + continue; + } + auto merged_info = iter->second; + // the op node is merged with other node and not used any more, skip it + if (merged_info.mode == OP_MERGE_IGNORE && merged_info.referred_count == 0) { + continue; + } + if (cnode == func_graph->get_return()) { + ExportOutput(func_graph, cnode, node_map_ptr, graph_proto); + continue; + } + switch (merged_info.mode) { + case OP_MERGE_CONV: + ExportMergeConv(func_graph, cnode, node_map_ptr, graph_proto); + break; + case OP_MERGE_GEMM: + ExportMergeGemm(func_graph, cnode, node_map_ptr, graph_proto); + break; + case OP_MERGE_BATCH_NORM: + ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto); + break; + case OP_MERGE_MAXPOOL_WITH_ARGMAX: + ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto); + break; + default: + ExportCNode(func_graph, cnode, node_map_ptr, graph_proto); + break; + } + } +} + +void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_shape = node->input(2); + std::string name_shape; + if (input_shape->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[input_shape] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_shape = std::to_string(const_node_idx); + node_proto->add_output(name_shape); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(input_shape)->value(), attr_proto->mutable_t()); + } else { + name_shape = GetNodeInputName(input_shape, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Reshape."; + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type(prim::kPrimReshape->name()); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_shape); +} + +void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_axis = node->input(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + auto name = prim::kPrimReduceMean->name(); + if (node->IsApply(prim::kPrimReduceSum)) { + name = prim::kPrimReduceSum->name(); + } + node_proto->set_op_type(name); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + + if (input_axis->isa()) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("axes"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); + auto axis_value = dyn_cast(input_axis)->value(); + auto int_ptr = dyn_cast(axis_value); + if (int_ptr == nullptr) { + auto tuple_ptr = dyn_cast(axis_value); + MS_EXCEPTION_IF_NULL(tuple_ptr); + for (size_t i = 0; i < tuple_ptr->size(); ++i) { + attr_proto->add_ints(GetValue((*tuple_ptr)[i])); + } + } else { + attr_proto->add_ints(int_ptr->value()); + } + } else { + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name; + } +} + +void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_type = node->input(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type(prim::kPrimCast->name()); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + + if (input_type->isa()) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("to"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + auto type_value = dyn_cast(input_type)->value(); + auto type_ptr = dyn_cast(type_value); + MS_EXCEPTION_IF_NULL(type_ptr); + attr_proto->set_i(GetOnnxDataType(type_ptr->type_id())); + } else { + MS_LOG(EXCEPTION) << "Need to convert MindSpore Cast input(1) to ONNX Cast to attribute."; + } +} + +void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); + + auto x_shape = dyn_cast(node->input(1)->Shape()); + auto slope_shape = dyn_cast(node->input(2)->Shape()); + MS_EXCEPTION_IF_NULL(x_shape); + MS_EXCEPTION_IF_NULL(slope_shape); + + // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] + if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { + auto node_idx = AllocateNodeIndex(); + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Unsqueeze"); + node_proto->add_output(std::to_string(node_idx)); + + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); + attr_proto->set_name("axes"); + attr_proto->add_ints(1); + attr_proto->add_ints(2); + + node_proto->add_input(input_slope); + input_slope = std::to_string(node_idx); + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("PRelu"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_x); + node_proto->add_input(input_slope); +} + +void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Clip"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_x); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto->set_name("min"); + attr_proto->set_f(0.f); + attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto->set_name("max"); + attr_proto->set_f(6.f); +} + +void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_w = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); + auto x_shape = dyn_cast(node->input(1)->Shape()); + auto w_shape = dyn_cast(node->input(2)->Shape()); + MS_EXCEPTION_IF_NULL(x_shape); + MS_EXCEPTION_IF_NULL(w_shape); + if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) { + MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d."; + } + if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) { + MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape"; + } + // create w_shape constant node + auto node_idx = AllocateNodeIndex(); + onnx::NodeProto *node_proto = graph_proto->add_node(); + std::string name_w_shape = std::to_string(node_idx); + node_proto->add_output(name_w_shape); + node_proto->set_op_type("Constant"); + // create Value Tensor + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size())); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + // reshape + tensor_proto->add_int64_data(w_shape->shape()[1]); + tensor_proto->add_int64_data(w_shape->shape()[0]); + tensor_proto->add_int64_data(w_shape->shape()[2]); + tensor_proto->add_int64_data(w_shape->shape()[3]); + + // add reshape node + node_idx = AllocateNodeIndex(); + node_proto = graph_proto->add_node(); + node_proto->set_op_type(prim::kPrimReshape->name()); + node_proto->add_input(input_w); + node_proto->add_input(name_w_shape); + input_w = std::to_string(node_idx); + node_proto->add_output(input_w); + + // add conv node + node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + node_proto = graph_proto->add_node(); + node_proto->set_op_type("Conv"); + node_proto->add_input(input_x); + node_proto->add_input(input_w); + node_proto->add_output(std::to_string(node_idx)); + // set attributes + AnfNodePtr op = node->input(0); + auto op_value = dyn_cast(op); + auto prim = dyn_cast(op_value->value()); + // set dilations + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("dilations"); + SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, + prim); + // set group + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("group"); + onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + onnx_attr_proto->set_i(x_shape->shape()[1]); + // set kernel_shape + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("kernel_shape"); + SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, + prim); + + // set pad + onnx_attr_proto = node_proto->add_attribute(); + auto attr_value = GetValue(prim->GetAttr("pad_mode")); + onnx_attr_proto->set_name("auto_pad"); + onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); + if (attr_value == "valid") { + onnx_attr_proto->set_s("VALID"); + } else if (attr_value == "same") { + onnx_attr_proto->set_s("SAME_UPPER"); + } else { + onnx_attr_proto->set_name("pads"); + SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); + } + // set strides + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("strides"); + SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); +} + +void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto multiples = node->input(2); + std::string name_multiples; + if (multiples->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[multiples] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_multiples = std::to_string(const_node_idx); + node_proto->add_output(name_multiples); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("repeat"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(multiples)->value(), attr_proto->mutable_t()); + } else { + name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile."; + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Tile"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_multiples); +} + +void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + std::string name_exponent; + auto const_node_idx = AllocateNodeIndex(); + onnx::NodeProto *node_proto_exp = graph_proto->add_node(); + name_exponent = std::to_string(const_node_idx); + node_proto_exp->add_output(name_exponent); + + node_proto_exp->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + tensor_proto->set_name("exponent"); + tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1)); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + tensor_proto->add_int64_data(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Pow"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_exponent); +} + +void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); + auto axis = node->input(3)->cast()->value(); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Gather"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_indices); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast(axis)->value())); +} + +void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert + if (node->IsApply(prim::kPrimReshape)) { + return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); + } + + if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) { + return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore Cast(x, T) --> ONNX Cast[to=T](x) + if (node->IsApply(prim::kPrimCast)) { + return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto); + } + + // ONNX PRelu requires unidirectional broadcasting, here need some process + if (node->IsApply(std::make_shared("PReLU"))) { + return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x) + if (node->IsApply(std::make_shared("ReLU6"))) { + return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w)) + if (node->IsApply(std::make_shared("DepthwiseConv2dNative"))) { + return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore Tile(x) --> ONNX Tile(x, repeat) + if (node->IsApply(prim::kPrimTile)) { + return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore Square(x) --> ONNX Pow(x, 2) + if (node->IsApply(prim::kPrimSquare)) { + return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto); + } + + // MindSpore GatherV2(x, indices, axis) --> ONNX Pow(x, indices) + if (node->IsApply(prim::kPrimGatherV2)) { + return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto); + } + + auto inputs = node->inputs(); + if (inputs.size() < 1) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + + AnfNodePtr op = inputs[0]; + std::vector op_inputs; + // first process node input 1,2,..., since when node input is a ValueNode, here need to create a Constant Operator + for (size_t i = 1; i < inputs.size(); i++) { + op_inputs.push_back(inputs[i]); + } + auto op_value = dyn_cast(op); + if (op_value == nullptr) { + MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name(); + } + auto prim = dyn_cast(op_value->value()); + if (prim == nullptr) { + MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name(); + } + + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); +} + +size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *const graph_proto) { + auto op_map = OpConvertRegistry::GetOpConvertMap(); + auto op_iter = op_map.find(prim->name()); + if (op_iter == op_map.end()) { + MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; + } + const OpNameInfo &op_convert_info = op_iter->second; + + auto node_idx = AllocateNodeIndex(); + + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->add_output(std::to_string(node_idx)); + node_proto->set_op_type(op_convert_info.onnx_type()); + + // Set inputs + for (const auto &input : inputs) { + auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); + node_proto->add_input(input_name); + } + + // Set node attribute + for (const OpAttrInfo &attr : op_convert_info.op_attrs()) { + const std::string &attr_name = attr.attr_name(); + ValuePtr attr_value = nullptr; + if (!attr_name.empty()) { + attr_value = prim->GetAttr(attr_name); + if (attr_value == nullptr) { + MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; + } + } + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name(attr.onnx_attr_name()); + attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); + } + return node_idx; +} + +void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto conv_node = dyn_cast(node->input(1)); + auto input_x = conv_node->input(1); // conv input x + auto input_w = conv_node->input(2); // conv weight(filter) + auto input_b = node->input(2); // conv bias + + PrimitivePtr prim_conv = dyn_cast((dyn_cast(conv_node->input(0)))->value()); + std::vector inputs{input_x, input_w, input_b}; + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); +} + +void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto matmul_node = dyn_cast(node->input(1)); + auto input_x = matmul_node->input(1); // matmul input x + auto input_y = matmul_node->input(2); // matmul input y + auto input_b = node->input(2); // matmul bias + + PrimitivePtr prim_matmul = dyn_cast((dyn_cast(matmul_node->input(0)))->value()); + std::vector inputs{input_x, input_y, input_b}; + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); +} + +void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto batch_norm_node = dyn_cast(node->input(1)); + + PrimitivePtr prim_batch_norm = dyn_cast((dyn_cast(batch_norm_node->input(0)))->value()); + std::vector inputs; + for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) { + inputs.push_back(batch_norm_node->input(i)); + } + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); +} + +void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto maxpool_with_argmax_node = dyn_cast(node->input(1)); + + PrimitivePtr prim_maxpool_with_argmax = + dyn_cast((dyn_cast(maxpool_with_argmax_node->input(0)))->value()); + std::vector inputs; + for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) { + inputs.push_back(maxpool_with_argmax_node->input(i)); + } + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto); +} + +void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + if (node->inputs().size() != 2) { + MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; + } + AnfNodePtr arg = node->input(1); + std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + output_proto->set_name(name); + SetValueInfoType(arg, output_proto, false); +} + +std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + if (node->isa()) { + auto iter = node_map_ptr->find(node); + if (iter == node_map_ptr->end()) { + MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in node_map"; + } + return std::to_string(iter->second); + } + + if (node->isa()) { + return node->ToString(); + } + + // for ValueNode input, create a Constant Operator + if (node->isa()) { + auto iter = node_map_ptr->find(node); + if (iter != node_map_ptr->end()) { + return std::to_string(iter->second); + } + // the id number starts at 1, so the id of created node should be size of map plus one + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + std::string node_name = std::to_string(node_idx); + + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->add_output(node_name); + + SetNodeAttribute(node->cast()->value(), node_proto); + + return node_name; + } + + MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name(); +} + +void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { + auto tuple_ptr = dyn_cast(value); + MS_EXCEPTION_IF_NULL(tuple_ptr); + if (tuple_ptr->size() == 0) { + MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, the size of converted tuple is 0."; + } + auto type_id = (*tuple_ptr)[0]->type()->type_id(); + for (size_t i = 1; i < tuple_ptr->size(); ++i) { + if ((*tuple_ptr)[i]->type()->type_id() != type_id) { + MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, type of tuple elements is not same."; + } + } + + tensor_proto->add_dims(static_cast<::google::protobuf::int64>(tuple_ptr->size())); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + for (size_t i = 0; i < tuple_ptr->size(); ++i) { + ValuePtr elem = (*tuple_ptr)[i]; + if (elem->isa()) { + tensor_proto->add_int64_data(dyn_cast(elem)->value()); + } else if (elem->isa()) { + tensor_proto->add_int64_data(dyn_cast(elem)->value()); + } else if (elem->isa()) { + tensor_proto->add_int64_data(dyn_cast(elem)->value()); + } else if (elem->isa()) { + tensor_proto->add_int64_data(dyn_cast(elem)->value()); + } else { + MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, unexpected tuple element type " << elem->type()->type_name() + << "."; + } + } +} + +void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + if (value->isa()) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + auto casted_value = dyn_cast(value); + if (casted_value == nullptr) { + MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; + } + auto attr_value = casted_value->value(); + attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + } else if (value->isa()) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + auto data = dyn_cast(value); + tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); + auto dtype = data->data_type(); + auto shape = data->shape_c(); + + tensor_proto->set_data_type(GetOnnxDataType(dtype)); + for (const auto &dim : shape) { + tensor_proto->add_dims(dim); + } + } else { + MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; + } +} + +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { + OnnxExporter exporter; + return exporter.GetOnnxProtoString(func_graph); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h deleted file mode 100644 index ae678606a4..0000000000 --- a/mindspore/ccsrc/transform/op_adapter.h +++ /dev/null @@ -1,891 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_OP_ADAPTER_H_ -#define TRANSFORM_OP_ADAPTER_H_ - -#include -#include -#include -#include - -#include "transform/op_adapter_util.h" -#include "utils/utils.h" -namespace mindspore { -namespace transform { -static uint32_t CustomInferFunc(const Operator &) { return 0; } - -template -class OpAdapter : public BaseOpAdapter { - public: - using OpType = T; - OpAdapter() {} - explicit OpAdapter(const ExtraAttr &extra_attr) : extra_attr_(extra_attr) {} - ~OpAdapter() override {} - - bool IsCustomOp(const OperatorPtr &op) { - MS_EXCEPTION_IF_NULL(op); - auto it = cus_input_map_.find(op->GetOpType()); - if (it == cus_input_map_.end()) { - return false; - } - return true; - } - - Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(prim); - // Create the map of custom op from input index to input name. - std::unordered_map input_map; - auto value = prim->GetAttr("input_names"); - if (value == nullptr) { - cus_output_map_[prim->name()] = input_map; - return NOT_FOUND; - } - - auto input_names = GetValue>(value); - for (size_t i = 0; i < input_names.size(); ++i) { - // input_map begin form 1 - input_map[i + 1] = input_names[i]; - op->CustomInputRegister(input_names[i]); - } - - if (cus_input_map_.find(prim->name()) == cus_input_map_.end()) { - cus_input_map_[prim->name()] = input_map; - } - return SUCCESS; - } - - Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(prim); - // Create the map of custom op from output index to output name. - std::unordered_map output_map; - auto value = prim->GetAttr("output_names"); - if (value == nullptr) { - // generate a empty output_map for it - cus_output_map_[prim->name()] = output_map; - return NOT_FOUND; - } - - auto output_names = GetValue>(value); - for (size_t i = 0; i < output_names.size(); ++i) { - // output_map begin form 0 - output_map[i] = output_names[i]; - op->CustomOutputRegister(output_names[i]); - } - - if (cus_output_map_.find(prim->name()) == cus_output_map_.end()) { - cus_output_map_[prim->name()] = output_map; - } - return SUCCESS; - } - - // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs. - OperatorPtr GenerateCustomOp(const AnfNodePtr anf) { - MS_EXCEPTION_IF_NULL(anf); - auto node = anf->cast(); - if (node == nullptr) { - return nullptr; - } - - if (node->inputs().empty()) { - MS_LOG(EXCEPTION) << "length of node inputs is empty"; - } - - auto prim = GetValueNode(node->inputs()[0]); - MS_EXCEPTION_IF_NULL(prim); - auto op = std::make_shared(node->fullname_with_scope(), prim->name()); - if (GenerateCustomOpInputMap(op, prim) != SUCCESS) { - MS_LOG(WARNING) << "Custom op node has no input_names, op[" << prim->name() << "]."; - } - - if (GenerateCustomOpOutputMap(op, prim) != SUCCESS) { - MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "]."; - } - - op->CustomInferFuncRegister(CustomInferFunc); - - return op; - } - - OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) { - OperatorPtr op = nullptr; - // There are duplicate names in ANF graph, do not assign ANF node name to GE - // GE will generate unique name automatically - if (anf != nullptr && anf->fullname_with_scope() != "") { - MS_LOG(DEBUG) << anf->fullname_with_scope(); - op = std::make_shared(anf->fullname_with_scope()); - } else { - MS_LOG(DEBUG) << "no fullname_with_scope"; - op = std::make_shared(); - } - - // set dynamic output num if op use DYNAMIC_OUTPUT - if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) { - TypePtr type = anf->Type(); - if (type == nullptr) { - MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!"; - } - size_t num = type->isa() ? (type->cast>()->size()) : 1; - MS_LOG(INFO) << "create_dyn_output for node:" << anf->ToString() << ", type:" << type->ToString() - << ", num:" << num; - dyn_output_map_.begin()->second.create_dyn_output(op, static_cast(num)); - } - return op; - } - - OperatorPtr generate(const AnfNodePtr &anf) override { - OperatorPtr op = nullptr; - if (IsCustomCNode(anf)) { - op = GenerateCustomOp(anf); - } else { - op = GenerateNormalOp(anf); - } - return op; - } - - OperatorPtr generate(const std::string &op_name) override { return std::make_shared(op_name); } - - const std::unordered_map &getInputMap() override { return input_map_; } - const std::unordered_map &getInputAttrMap() override { return input_attr_map_; } - const std::unordered_map &getDynInputMap() override { return dyn_input_map_; } - const std::unordered_map &getOutputMap() override { return output_map_; } - - Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(input); - auto it = cus_input_map_.find(op->GetOpType()); - if (it == cus_input_map_.end()) { - return NOT_FOUND; - } - std::unordered_map &input_map = it->second; - - if ((input_map.find(index) != input_map.end())) { - MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index]; - (void)op->SetInput(input_map[index], *input); - return SUCCESS; - } - return NOT_FOUND; - } - - Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { - MS_EXCEPTION_IF_NULL(op); - auto it = input_map_.find(index); - if (it != input_map_.end()) { - MS_EXCEPTION_IF_NULL(input); - MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << it->second.name; - it->second.set_op(op, input); - return SUCCESS; - } - return NOT_FOUND; - } - - int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { - if (IsCustomOp(op)) { - auto cus_op = std::dynamic_pointer_cast(op); - return static_cast(SetCustomOpInput(cus_op, index, input)); - } else { - return static_cast(SetNormalOpInput(op, index, input)); - } - } - - Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { - MS_EXCEPTION_IF_NULL(op); - auto it = cus_input_map_.find(op->GetOpType()); - if (it == cus_input_map_.end()) { - return NOT_FOUND; - } - - std::unordered_map &input_map = it->second; - if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) { - if (handle.out.empty()) { - MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index]; - (void)op->SetInput(input_map[index], *(handle.op)); - } else { - MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":" - << input_map[index]; - (void)op->SetInput(input_map[index], *(handle.op), handle.out); - } - return SUCCESS; - } - return NOT_FOUND; - } - - Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { - MS_EXCEPTION_IF_NULL(op); - auto it = input_map_.find(index); - if ((handle.op != nullptr) && (it != input_map_.end())) { - if (handle.out.empty()) { - MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << it->second.name; - it->second.set_op(op, handle.op); - } else { - MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":" - << it->second.name; - it->second.set_handle(op, handle); - } - return SUCCESS; - } - return NOT_FOUND; - } - - int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { - if (IsCustomOp(op)) { - auto cus_op = std::dynamic_pointer_cast(op); - return static_cast(SetCustomOpInput(cus_op, index, handle)); - } else { - return static_cast(SetNormalOpInput(op, index, handle)); - } - } - - int setInput(const OperatorPtr &op, int index, const std::shared_ptr> &handler_vec) override { - MS_EXCEPTION_IF_NULL(handler_vec); - if (IsCustomOp(op)) { - MS_LOG(ERROR) << "Custom Op do not support dynamic input"; - return static_cast(FAILED); - } - MS_EXCEPTION_IF_NULL(op); - auto it = dyn_input_map_.find(index); - if (it != dyn_input_map_.end()) { - it->second.create_dyn_input(op, static_cast(handler_vec->size())); - for (unsigned int i = 0; i < handler_vec->size(); ++i) { - OutHandler h = (*handler_vec)[i]; - MS_EXCEPTION_IF_NULL(h.op); - if (h.out.empty()) { - MS_LOG(DEBUG) << "Link op " << h.op->GetName() << " to " << op->GetName() << ":" << it->second.name; - it->second.set_op(op, (i) /* index start from 0 */, h.op); - } else { - MS_LOG(DEBUG) << "Link op " << h.op->GetName() << ":" << h.out << " to " << op->GetName() << ":" - << it->second.name; - it->second.set_handle(op, i, h); - } - } - return 0; - } - return static_cast(NOT_FOUND); - } - - OutHandler getOutput(const OperatorPtr &op, int index) override { - MS_EXCEPTION_IF_NULL(op); - if (IsCustomOp(op)) { - return getCustomOutput(op, index); - } - return getNormalOutput(op, index); - } - - OutHandler getCustomOutput(const OperatorPtr &op, int index) { - MS_EXCEPTION_IF_NULL(op); - auto it = cus_output_map_.find(op->GetOpType()); - if (it == cus_output_map_.end()) { - MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT is not supported!"; - return OutHandler(); - } - - std::unordered_map &output_map = it->second; - - if ((output_map.find(index) != output_map.end())) { - return OutHandler(op, output_map[index]); - } - MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT index(" << index << ")!"; - return OutHandler(); - } - - OutHandler getNormalOutput(const OperatorPtr &op, int index) { - MS_EXCEPTION_IF_NULL(op); - if (!dyn_output_map_.empty() && !output_map_.empty()) { - MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!"; - return OutHandler(); - } - auto it = output_map_.find(index); - if (it != output_map_.end()) { - return OutHandler(op, it->second.name); - } else if (!dyn_output_map_.empty()) { - return OutHandler(op, dyn_output_map_.begin()->second.name + std::to_string(index)); - } else { - MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT and DYN_OUTPUT index(" << index << ")!"; - return OutHandler(); - } - } - - Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { - MS_EXCEPTION_IF_NULL(type); - std::string format = "NCHW"; - if (op->GetOpType() == kExtractImagePatchesOpName) { - format = "NHWC"; - } - - auto desc = CreateOutputDesc(dyn_cast(shp), type, format); - if (desc == nullptr) { - MS_LOG(ERROR) << "Update output descriptor failed!"; - return FAILED; - } - - if (IsCustomOp(op)) { - if (cus_output_map_.find(op->GetOpType()) == cus_output_map_.end() || - (cus_output_map_[op->GetOpType()].empty())) { - MS_LOG(ERROR) << "This op does not create custom output map"; - return FAILED; - } - auto cus_op = std::dynamic_pointer_cast(op); - MS_EXCEPTION_IF_NULL(cus_op); - std::unordered_map output_map = cus_output_map_[op->GetOpType()]; - (void)cus_op->UpdateOutputDesc(output_map[0], *desc); - } else { - if (output_map_.empty()) { - MS_LOG(INFO) << "This op does not have output map"; - return FAILED; - } - output_map_.begin()->second.update_out_desc(op, *desc); - } - return SUCCESS; - } - - size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { - MS_EXCEPTION_IF_NULL(cus_op); - if (cus_output_map_.find(cus_op->GetOpType()) == cus_output_map_.end()) { - MS_LOG(ERROR) << "This op does not create custom output map"; - return 0; - } - size_t output_size = cus_output_map_[cus_op->GetOpType()].size(); - return output_size; - } - - std::shared_ptr CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, - const std::string &format) { - if (shape_ptr == nullptr) { - MS_LOG(ERROR) << "Shape ptr is nullptr"; - return nullptr; - } - - if (type == nullptr) { - MS_LOG(ERROR) << "Type ptr is nullptr"; - return nullptr; - } - - TypeId me_type = type->type_id(); - if (kObjectTypeTensorType == me_type) { - me_type = dyn_cast(type)->element()->type_id(); - } - auto desc = TransformUtil::GetGeTensorDesc(shape_ptr->shape(), me_type, format); - return desc; - } - - Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { - auto tuple_shp = dyn_cast(shp); - MS_EXCEPTION_IF_NULL(tuple_shp); - - size_t output_size = 0; - bool is_custom_op = IsCustomOp(op); - if (is_custom_op) { - output_size = GetCustomOpOutputSize(std::dynamic_pointer_cast(op)); - } else { - output_size = output_map_.size(); - } - - if (output_size == 0) { - MS_LOG(INFO) << "This op does not have output map"; - return FAILED; - } - - if (output_size != tuple_shp->shape().size()) { - MS_LOG(ERROR) << "output_map is not equal tuple_shape size"; - return FAILED; - } - std::string format = "NCHW"; - if (op->GetOpType() == kTopKOpName) { - format = "NHWC"; - } - for (size_t i = 0; i < tuple_shp->shape().size(); ++i) { - auto tuple_type = dyn_cast(type); - MS_EXCEPTION_IF_NULL(tuple_type); - TypePtr type_elem = tuple_type->elements()[i]; - - auto desc = CreateOutputDesc(dyn_cast(tuple_shp->shape()[i]), type_elem, format); - if (desc == nullptr) { - MS_LOG(ERROR) << "Create output descriptor failed!"; - return FAILED; - } - - if (is_custom_op) { - (void)std::dynamic_pointer_cast(op)->UpdateOutputDesc(cus_output_map_[op->GetOpType()][i], - *desc); - } else { - auto it = output_map_.find(i); - if (it != output_map_.end()) { - it->second.update_out_desc(op, *desc); - } - } - } - return SUCCESS; - } - - std::shared_ptr CreateNodeDesc(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - TypeId me_type = node->Type()->type_id(); - if (kObjectTypeTensorType == me_type) { - me_type = dyn_cast(node->Type())->element()->type_id(); - } - if (me_type <= kNumberTypeBegin || me_type >= kNumberTypeEnd) { - return nullptr; - } - - std::vector shape; - auto shape_ptr = dyn_cast(node->Shape()); - if (nullptr != shape_ptr) { - shape = shape_ptr->shape(); - } - - auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); - if (desc == nullptr) { - MS_LOG(ERROR) << "Update output descriptor failed!"; - return nullptr; - } - return desc; - } - - void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) { - if (op == nullptr) { - MS_LOG(ERROR) << "op is nullptr"; - return; - } - MS_EXCEPTION_IF_NULL(node); - - auto inputs = node->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - auto it = input_map_.find(i); - if (it != input_map_.end()) { - auto desc = CreateNodeDesc(inputs[i]); - if (desc == nullptr) { - continue; - } - if (op->GetOpType() == kExtractImagePatchesOpName) { - desc->SetFormat(ge::Format::FORMAT_NHWC); - } - it->second.update_input_desc(op, *desc); - } - } - } - - void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { - if (op == nullptr) { - MS_LOG(ERROR) << "op is nullptr"; - return; - } - MS_EXCEPTION_IF_NULL(node); - - if (cus_input_map_.find(op->GetOpType()) == cus_input_map_.end() || (cus_input_map_[op->GetOpType()].empty())) { - MS_LOG(ERROR) << "This op does not create custom input map"; - return; - } - - std::unordered_map &input_map = cus_input_map_[op->GetOpType()]; - auto inputs = node->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - if (input_map.find(i) != input_map.end()) { - auto desc = CreateNodeDesc(inputs[i]); - if (desc == nullptr) { - continue; - } - (void)op->UpdateInputDesc(input_map[i], *desc); - } - } - } - - void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(node); - if (IsCustomOp(op)) { - auto cus_op = std::dynamic_pointer_cast(op); - UpdateCustomOpInputDesc(cus_op, node); - } else { - UpdateNormalOpInputDesc(op, node); - } - } - - void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, - const AnfNodePtr &node) override { - if (op == nullptr) { - MS_LOG(ERROR) << "op is nullptr"; - return; - } - MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Op name is " << op->GetName(); - - auto normal_shape_ptr = dyn_cast(shp); - auto no_shape_ptr = dyn_cast(shp); - - if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) { - if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) { - return; - } - } else if (nullptr != dyn_cast(shp)) { - if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) { - return; - } - } else { - MS_LOG(WARNING) << "Update output desc failed, unknow output shape type"; - return; - } - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return; - } - - // Need to update input_desc while the output_desc is updated - updateInputDesc(op, node); - } - - int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override { - auto it = attr_map_.find(attrKey); - if (it != attr_map_.end()) { - // switch case for each avalilable attribute type - MS_LOG(INFO) << "Set attr: " << attrKey << "(" << it->second.name << "), value: " << attrValue->ToString(); - AddAttrToDrawGraph(attrKey + std::string("=") + attrValue->ToString()); - it->second.set_attr(op, attrValue); - return 0; - } - return static_cast(NOT_FOUND); - } - - int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { - enum ValueType { - SINGLE_VALUE = 0, - SEQUEUE_VALUE, - UNKNOWN_VALUE, - }; - - MS_EXCEPTION_IF_NULL(prim); - MS_EXCEPTION_IF_NULL(op); - - ValueType value_type = SINGLE_VALUE; - for (auto item : prim->attrs()) { - if (item.second->isa()) { - (void)op->SetAttr(item.first, GetValue(item.second)); - } else if (item.second->isa()) { - (void)op->SetAttr(item.first, GetValue(item.second)); - } else if (item.second->isa()) { - (void)op->SetAttr(item.first, GetValue(item.second)); - } else if (item.second->isa()) { - (void)op->SetAttr(item.first, GetValue(item.second)); - } else if (item.second->isa()) { - value_type = SEQUEUE_VALUE; - auto val_seq = item.second->cast(); - if ((*val_seq)[0]->isa()) { - (void)op->SetAttr(item.first, GetValue>(item.second)); - } else if ((*val_seq)[0]->isa()) { - (void)op->SetAttr(item.first, GetValue>(item.second)); - } else if ((*val_seq)[0]->isa()) { - (void)op->SetAttr(item.first, GetValue>(item.second)); - } else if ((*val_seq)[0]->isa()) { - (void)op->SetAttr(item.first, GetValue>(item.second)); - } else { - MS_LOG(EXCEPTION) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name() - << ", attr name: " << item.first << ", value: " << item.second->ToString(); - } - } else { - value_type = UNKNOWN_VALUE; - MS_LOG(WARNING) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name() - << ", attr name: " << item.first << ", value: " << item.second->ToString(); - return static_cast(NOT_FOUND); - } - - if (value_type == SINGLE_VALUE) { - AddAttrToDrawGraph(item.first + std::string("=") + item.second->ToString()); - } else if (value_type == SEQUEUE_VALUE) { - AddAttrToDrawGraph(item.first + std::string("=") + "[...]"); - } - } - return 0; - } - - int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { - int ret = 0; - MS_EXCEPTION_IF_NULL(prim); - MS_EXCEPTION_IF_NULL(op); - for (auto &it : attr_map_) { - auto value = prim->GetAttr(it.first); - if (value != nullptr) { - // set attr from primitive - ret = setAttr(op, it.first, value); - if (ret) { - return ret; - } - } else { - // set attr from extra_attr - auto it_extra = extra_attr_.find(it.first); - if (it_extra != extra_attr_.end()) { - ret = setAttr(op, it.first, it_extra->second); - if (ret) { - return ret; - } - } - } - } - return 0; - } - - int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { - int ret = 0; - if (IsCustomPrim(prim)) { - auto cus_op = std::dynamic_pointer_cast(op); - ret = SetCustomOpAttr(cus_op, prim); - } else { - ret = SetNormalOpAttr(op, prim); - } - return ret; - } - - int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { - // no attribute for lonely node - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return 0; - } - - auto cnode = node->cast(); - if (cnode == nullptr) { - return 0; - } - - auto &inputs = cnode->inputs(); - if (inputs.empty()) { - return 0; - } - - // get Attr T from abstract of anfnode first, - // if attr "T" appears in primitive, the primitive T will cover this one - if (attr_map_.find("T") != attr_map_.end()) { - // get dtype from inputs[1], if the node has no inputs, set the attr T with output dtype - TypePtr type; - if (inputs.size() > 1) { - type = inputs[1]->Type(); - } else { - type = node->Type(); - } - if (type != nullptr) { - (void)setAttr(op, "T", MakeValue(type)); - } - } - - // set attr from primitive and ExtraAttr - if (IsValueNode(inputs[0])) { - // set attr from primitive - PrimitivePtr prim = GetValueNode(inputs[0]); - int ret = setAttr(op, prim); - if (ret != 0) { - return ret; - } - } - - // set attr from const input - for (auto &it : input_attr_map_) { - if (inputs.size() <= it.first || !inputs[it.first]->isa()) { - continue; - } - auto const_value = GetValueNode(inputs[it.first]); - MS_LOG(INFO) << "Set attr: input_" << it.first << "(" << it.second.name - << "), value: " << const_value->ToString(); - if (const_value->isa()) { - continue; - } - AddAttrToDrawGraph(it.second.name + std::string("=") + const_value->ToString()); - it.second.set_attr(op, const_value); - } - return 0; - } - - std::unordered_map GetExtraAttr() override { return extra_attr_; } - - private: - template - static S ConvertAny(const ValuePtr &value, const AnyTraits &) { - return GetValue(value); - } - - // specialization for reverse bool - static bool ConvertAny(const ValuePtr &value, const AnyTraits &, bool reverse) { - return reverse != GetValue(value); - } - - template - static Q ConvertAny(const ValuePtr &value, const AnyTraits

&traits_from, const AnyTraits &traits_to) { - return ConvertAnyUtil(value, traits_from, traits_to); - } - - // specialization for tensor - static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits &traits) { - // To-DO the format may read from ME tensor - return ConvertAnyUtil(value, traits); - } - - // specialization for int - static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { - return static_cast(GetValue(value)); - } - - // specialization for int or tuple broadcast to Vector - static std::vector ConvertAny(const ValuePtr &value, const std::string &name, - const AnyTraits> anyTraitsInt) { - return ConvertAnyUtil(value, name, anyTraitsInt); - } - - static std::vector> ConvertAny(const ValuePtr &value, - const AnyTraits>>) { - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Value: " << value->type_name(); - std::vector> list; - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got " << value->type_name(); - } - auto vec = value->cast(); - MS_EXCEPTION_IF_NULL(vec); - for (auto &it : vec->value()) { - MS_EXCEPTION_IF_NULL(it); - if (!it->isa()) { - MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name(); - } - auto sub_vector = it->cast(); - std::vector sublist; - for (auto &item : sub_vector->value()) { - sublist.push_back(static_cast(GetValue(item))); - } - list.push_back(sublist); - } - return list; - } - - static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>>, - const AnyTraits>) { - MS_EXCEPTION_IF_NULL(value); - MS_LOG(DEBUG) << "Value: " << value->type_name(); - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value should be ValueList, but got " << value->type_name(); - } - auto vec = value->cast(); - std::vector list; - for (auto &it : vec->value()) { - MS_EXCEPTION_IF_NULL(it); - if (!it->isa()) { - MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name(); - } - auto sub_vector = it->cast(); - for (auto &item : sub_vector->value()) { - list.push_back(static_cast(GetValue(item))); - } - } - return list; - } - - static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>, - const AnyTraits>) { - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Value: " << value->type_name(); - std::vector list; - if (value->isa()) { - auto vec = value->cast(); - MS_EXCEPTION_IF_NULL(vec); - for (auto &it : vec->value()) { - list.push_back(static_cast(GetValue(it))); - } - return list; - } - if (value->isa()) { - list.push_back(static_cast(GetValue(value))); - return list; - } - MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name(); - } - - static std::string ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, - const AnyTraits anyTraitsStr) { - return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr); - } - - static std::vector ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, - const AnyTraits anyTraitsFlo) { - return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo); - } - - static std::vector ConvertAny(const ValuePtr &value, const std::string &format, - const AnyTraits> anyTraitsVec, - const AnyTraits anyTraitsInt) { - return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt); - } - - // convert value list for value tuple to vector - template - static std::vector ConvertAny(const ValuePtr &value, const AnyTraits

&anyTraitsP, - const AnyTraits> anyTraitsQ) { - return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ); - } - - static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { - auto name = GetValue(value); - auto it = enum_map_.find(name); - int v = 0; - if (it != enum_map_.end()) { - v = it->second; - } - return v; - } - - static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsGE) { - return ConvertAnyUtil(value, anyTraitsGE); - } - - // convert any value to tensor - static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsValue) { - return ConvertAnyUtil(value, anyTraitsValue); - } - - static const std::unordered_map input_map_; - static const std::unordered_map dyn_input_map_; - static const std::unordered_map output_map_; - static const std::unordered_map dyn_output_map_; - static const std::unordered_map attr_map_; - static const std::unordered_map enum_map_; - // convert input from anf graph to Attr in Operators - static const std::unordered_map input_attr_map_; - static std::unordered_map> cus_input_map_; - static std::unordered_map> cus_output_map_; - std::unordered_map extra_attr_; - std::unordered_map name_counts_; -}; - -template -const std::unordered_map OpAdapter::input_map_; -template -const std::unordered_map OpAdapter::dyn_input_map_; -template -const std::unordered_map OpAdapter::output_map_; -template -const std::unordered_map OpAdapter::dyn_output_map_; -template -const std::unordered_map OpAdapter::attr_map_; -template -const std::unordered_map OpAdapter::enum_map_; -template -const std::unordered_map OpAdapter::input_attr_map_; -template -std::unordered_map> OpAdapter::cus_input_map_; -template -std::unordered_map> OpAdapter::cus_output_map_; - -// specialization for method -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_OP_ADAPTER_H_ diff --git a/mindspore/ccsrc/transform/op_adapter_base.h b/mindspore/ccsrc/transform/op_adapter_base.h deleted file mode 100644 index 01f96e251d..0000000000 --- a/mindspore/ccsrc/transform/op_adapter_base.h +++ /dev/null @@ -1,189 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_OP_ADAPTER_BASE_H_ -#define TRANSFORM_OP_ADAPTER_BASE_H_ - -#include -#include -#include -#include -#include -#include - -#include "transform/util.h" -#include "ir/anf.h" -#include "ir/primitive.h" -#include "ir/value.h" -#include "transform/types.h" - -#ifdef ENABLE_GE -#ifdef OPEN_SOURCE -#include "graph/types.h" -#endif -#endif - -#include "graph/operator_reg.h" -#ifdef OPEN_SOURCE -#include "ge/client/ge_api.h" -#else -#include "external/ge/ge_api.h" -#endif -#include "graph/tensor.h" -#include "transform/all_ops.h" - -namespace ge { -class CustomOperator : public Operator { - public: - CustomOperator(const string &name, const string &type) : Operator(name, type) {} - - ~CustomOperator() override{}; - - void CustomInputRegister(const string &name) { Operator::InputRegister(name); } - - void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } - - void CustomInferFuncRegister(const std::function &func) { - Operator::InferFuncRegister(func); - } -}; -} // namespace ge - -namespace mindspore { -namespace transform { -using CusOperatorPtr = std::shared_ptr; -using CustomOperator = ge::CustomOperator; - -struct OutHandler { - OperatorPtr op; - std::string out; - OutHandler() : op(nullptr), out("") {} - OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} -}; - -struct ControlEdge { - OperatorPtr src_op; - OperatorPtr dest_op; -}; - -using AttrFunc = std::function; -using OutputFunc = std::function; -using InputOpFunc = std::function; -using InputHandleFunc = std::function; -using CreateDynInputOpFunc = std::function; -using DynInputOpFunc = std::function; -using DynInputHandleFunc = std::function; -using UpdateOutputDescFunc = std::function; -using CreateDynOutputOpFunc = std::function; - -struct AttrDesc { - std::string name; - AttrFunc set_attr; -}; - -struct InputDesc { - std::string name; - InputOpFunc set_op; - InputHandleFunc set_handle; - UpdateOutputDescFunc update_input_desc; -}; - -struct DynInputDesc { - std::string name; - CreateDynInputOpFunc create_dyn_input; - DynInputOpFunc set_op; - DynInputHandleFunc set_handle; -}; - -struct OutputDesc { - std::string name; - UpdateOutputDescFunc update_out_desc; -}; - -struct DynOutputDesc { - std::string name; - CreateDynOutputOpFunc create_dyn_output; -}; - -class BaseOpAdapter { - public: - virtual ~BaseOpAdapter() {} - virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; - virtual OperatorPtr generate(const std::string &type) { return std::make_shared(type); } - virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; - virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; - virtual int setInput(const OperatorPtr &op, int index, - const std::shared_ptr> &handler_vec) = 0; - virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; - virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; - virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; - virtual std::unordered_map GetExtraAttr() = 0; - template ::value>::type> - int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr &attrValue) { - return setAttr(op, attrKey, MakeValue(attrValue)); - } - template ::value>::type> - int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { - return setAttr(op, attrKey, MakeValue(attrValue)); - } - virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; - virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, - const AnfNodePtr &node) = 0; - virtual const std::unordered_map &getInputMap() = 0; - virtual const std::unordered_map &getInputAttrMap() = 0; - virtual const std::unordered_map &getDynInputMap() = 0; - virtual const std::unordered_map &getOutputMap() = 0; - void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } - const std::vector &GetAttrsFromDrawGraph() const { return attrs_vec_; } - void clearAttrVect() { attrs_vec_.clear(); } - - private: - std::vector attrs_vec_; -}; - -using OpAdapterPtr = std::shared_ptr; - -enum AttrType { - ATTR_INT = 0, - ATTR_FLOAT, - ATTR_DOUBLE, - ATTR_STRING, - ATTR_TENSOR, - ATTR_BOOL, - ATTR_LIST_INT, - ATTR_LIST_ANY_INT, - ATTR_ENUM -}; - -struct GeEnum {}; -struct TFType {}; -struct GEType {}; - -// declare Any type -template -struct AnyTraits { - using type = T; -}; - -template <> -struct AnyTraits { - using type = int64_t; -}; - -using ExtraAttr = std::unordered_map; -} // namespace transform -} // namespace mindspore -#endif // TRANSFORM_OP_ADAPTER_BASE_H_ diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc deleted file mode 100644 index cae43c13dc..0000000000 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ /dev/null @@ -1,264 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/op_adapter_util.h" - -#include -#include -#include - -#include "utils/utils.h" -#include "transform/op_adapter_base.h" - -namespace mindspore { -namespace transform { -GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &) { - // To-DO the format may read from ME tensor - MS_EXCEPTION_IF_NULL(value); - auto me_tensor = value->cast(); - auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_NCHW); - return ge_tensor == nullptr ? GeTensor() : *ge_tensor; -} - -std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, - const AnyTraits>) { - MS_EXCEPTION_IF_NULL(value); - std::vector list; - if (name == "pad") { - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); - } - auto vec = value->cast(); - list.resize(vec->value().size() + 2); - list[0] = 1; - list[1] = 1; - (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2, - [](const ValuePtr &val) { return static_cast(GetValue(val)); }); - } else { - int64_t data = GetValue(value); - int size = 2; // 2 int in list - list = TransformUtil::ConvertIntToList(data, size); - } - - return list; -} - -std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - auto vec = value->cast(); - if (nullptr == vec) { - MS_LOG(EXCEPTION) << "not ValueTuplePtr"; - } - std::ostringstream buffer; - int i = 0; - for (auto &it : vec->value()) { - if (i != 0) { - buffer << ","; - } - buffer << GetValue(it); - i++; - } - return buffer.str(); -} - -std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - auto vec = value->cast(); - if (nullptr == vec) { - MS_LOG(EXCEPTION) << "not ValueTuplePtr"; - } - std::vector list; - list.resize(vec->value().size()); - (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr &val) { return static_cast(GetValue(val)); }); - return list; -} - -std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, - const AnyTraits>, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - auto vec = value->cast(); - if (nullptr == vec) { - MS_LOG(EXCEPTION) << "not ValueTuplePtr"; - } - std::vector list; - list.resize(vec->value().size()); - (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr &val) { return static_cast(GetValue(val)); }); - if (format == kOpFormat_NHWC) { - if (list.size() < 4) { - MS_LOG(EXCEPTION) << "The size of list is less than 4"; - } else { - int64_t temp = list[1]; - list[1] = list[2]; - list[2] = list[3]; - list[3] = temp; - } - } - return list; -} - -GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - if (!value->isa()) { - MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString() - << ", type: " << value->type_name() << ", value should be a Typeptr"; - } - auto type = value->cast(); - MS_EXCEPTION_IF_NULL(type); - TypeId me_type = type->type_id(); - if (kObjectTypeTensorType == me_type) { - me_type = dyn_cast(type)->element()->type_id(); - } - return TransformUtil::ConvertDataType(me_type); -} - -GeTensor VectorToTensorUtil(const ValuePtr &value) { - // convert tuple or list to ge tensor, only supported one dim for now - MS_EXCEPTION_IF_NULL(value); - auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); - if (vec.empty()) { - MS_LOG(WARNING) << "Convert a none tuple to an empty ge tensor"; - return GeTensor(); - } - MS_EXCEPTION_IF_NULL(vec[0]); - if (vec[0]->isa()) { - MS_LOG(INFO) << "convert value to tensor with data type = Int32"; - auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); - auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeInt32, kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; - } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); - } else if (vec[0]->isa()) { - MS_LOG(INFO) << "convert value to tensor with data type = Float32"; - auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); - auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeFloat32, kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; - } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); - } else if (vec[0]->isa()) { - MS_LOG(INFO) << "convert value to tensor with data type = Bool"; - // We use uint8_t to save bool type data - auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); - auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeBool, kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; - } - return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); - } else { - MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name(); - } - - return GeTensor(); -} - -GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - // convert me tensor to ge tensor - return ConvertAnyUtil(value, AnyTraits()); - } else if (value->isa() || value->isa()) { - return VectorToTensorUtil(value); - } else if (value->isa()) { - // convert scalar Int to GeTensor - MS_LOG(INFO) << "convert scalar to tensor with data type = Int32"; - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); - auto v = GetValue(value); - desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); - } else if (value->isa()) { - // convert scalar Int64 to GeTensor - MS_LOG(INFO) << "convert scalar to tensor with data type = Int64"; - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); - auto v = GetValue(value); - desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); - } else if (value->isa()) { - // convert scalar FP32 to GeTensor - MS_LOG(INFO) << "convert scalar to tensor with data type = FP32"; - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); - auto v = GetValue(value); - desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); - } else if (value->isa()) { - // convert scalar FP32 to GeTensor - MS_LOG(INFO) << "convert scalar to tensor with data type = Bool"; - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL); - auto v = GetValue(value); - desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); - } else if (value->isa()) { - // convert String to GeTensor - MS_LOG(INFO) << "convert string to tensor with data type = String"; - std::string v = GetValue(value); - std::vector ge_shape; - GeShape shape(ge_shape); - GeTensorDesc desc(shape, ge::FORMAT_NCHW, ge::DT_STRING); - GeTensor str_tensor(desc); - str_tensor.SetData(v); - return str_tensor; - } else { - MS_LOG(WARNING) << "Unsupported value type: " << value->type_name() - << " to convert to tensor. Value: " << value->ToString(); - } - return GeTensor(); -} - -bool IsCustomPrim(const PrimitivePtr &prim) { - if (prim == nullptr) { - return false; - } - - ValuePtr flag = prim->GetAttr("_custom_op_flag"); - if (flag == nullptr) { - return false; - } - - bool is_custom_op = GetValue(flag); - if (!is_custom_op && prim->GetAttr("_custom_op_impl_config_path") != nullptr) { - MS_LOG(EXCEPTION) << "The custom op flag is false, but the op information config path is not null, non-custom op " - "can not assign the op information config path."; - } - - return is_custom_op; -} - -bool IsCustomCNode(const AnfNodePtr &anf) { - if (anf == nullptr) { - return false; - } - auto node = anf->cast(); - if (node == nullptr) { - return false; - } - if (node->inputs().empty()) { - MS_LOG(EXCEPTION) << "length of node inputs is empty"; - } - MS_EXCEPTION_IF_NULL(node->inputs()[0]); - if (!node->inputs()[0]->isa()) { - return false; - } - auto cus_prim = GetValueNode(node->inputs()[0]); - if (cus_prim == nullptr) { - return false; - } - - return IsCustomPrim(cus_prim); -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/op_adapter_util.h b/mindspore/ccsrc/transform/op_adapter_util.h deleted file mode 100644 index fcabc732d5..0000000000 --- a/mindspore/ccsrc/transform/op_adapter_util.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_OP_ADAPTER_UTIL_H_ -#define TRANSFORM_OP_ADAPTER_UTIL_H_ - -#include -#include - -#include "transform/op_adapter_base.h" - -namespace mindspore { -namespace transform { -template -static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits

&, const AnyTraits &) { - return static_cast(GetValue

(value)); -} - -GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &traits); - -std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, - const AnyTraits>); - -std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); - -std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); - -std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, - const AnyTraits>, const AnyTraits); - -GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits); - -template -std::vector ConvertAnyUtil(const ValuePtr &value, AnyTraits

, const AnyTraits>) { - if (!value->isa() && !value->isa()) { - MS_LOG(EXCEPTION) << "error convert Value to vector for value: " << value->ToString() - << ", type: " << value->type_name() << ", value should be a tuple or list"; - } - auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); - std::vector data; - for (auto &it : vec) { - data.push_back(ConvertAnyUtil(it, AnyTraits

(), AnyTraits())); - } - return data; -} - -GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits); - -bool IsCustomPrim(const PrimitivePtr &prim); -bool IsCustomCNode(const AnfNodePtr &node); -} // namespace transform -} // namespace mindspore -#endif // TRANSFORM_OP_ADAPTER_UTIL_H_ diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h deleted file mode 100755 index baa819f71f..0000000000 --- a/mindspore/ccsrc/transform/op_declare.h +++ /dev/null @@ -1,497 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_OP_DECLARE_H_ -#define TRANSFORM_OP_DECLARE_H_ - -#include -#include -#include "transform/op_adapter.h" - -namespace mindspore { -namespace transform { -#define DECLARE_OP_ADAPTER(T) \ - using T = ge::op::T; \ - template <> \ - const std::unordered_map OpAdapter::input_map_; \ - template <> \ - const std::unordered_map OpAdapter::attr_map_; - -#define DECLARE_OP_USE_OUTPUT(T) \ - template <> \ - const std::unordered_map OpAdapter::output_map_; - -#define DECLARE_OP_USE_ENUM(T) \ - template <> \ - const std::unordered_map OpAdapter::enum_map_; - -#define DECLARE_OP_USE_INPUT_ATTR(T) \ - template <> \ - const std::unordered_map OpAdapter::input_attr_map_; - -#define DECLARE_OP_USE_DYN_INPUT(T) \ - template <> \ - const std::unordered_map OpAdapter::dyn_input_map_; - -#define DECLARE_OP_USE_DYN_OUTPUT(T) \ - template <> \ - const std::unordered_map OpAdapter::dyn_output_map_; - -template <> -std::unordered_map> OpAdapter::cus_input_map_; -template <> -std::unordered_map> OpAdapter::cus_output_map_; - -DECLARE_OP_ADAPTER(GreaterEqual) -DECLARE_OP_USE_OUTPUT(GreaterEqual) -DECLARE_OP_ADAPTER(SliceD) -DECLARE_OP_USE_INPUT_ATTR(SliceD) -DECLARE_OP_USE_OUTPUT(SliceD) -DECLARE_OP_ADAPTER(AssignAdd) -DECLARE_OP_USE_OUTPUT(AssignAdd) -DECLARE_OP_ADAPTER(AssignSub) -DECLARE_OP_USE_OUTPUT(AssignSub) - -DECLARE_OP_ADAPTER(ReduceMean) -DECLARE_OP_ADAPTER(Multiply) -DECLARE_OP_USE_OUTPUT(Multiply) - -// ** Distributed Operations ** -DECLARE_OP_ADAPTER(HcomReduceScatter) -DECLARE_OP_USE_OUTPUT(HcomReduceScatter) -DECLARE_OP_ADAPTER(HcomBroadcast) -DECLARE_OP_USE_DYN_INPUT(HcomBroadcast) -DECLARE_OP_USE_DYN_OUTPUT(HcomBroadcast) -DECLARE_OP_ADAPTER(HcomAllReduce) -DECLARE_OP_USE_OUTPUT(HcomAllReduce) -DECLARE_OP_ADAPTER(HcomAllGather) -DECLARE_OP_USE_OUTPUT(HcomAllGather) -DECLARE_OP_ADAPTER(Variable) -DECLARE_OP_ADAPTER(ReluGrad) -DECLARE_OP_USE_OUTPUT(ReluGrad) -DECLARE_OP_ADAPTER(BiasAddGrad) -DECLARE_OP_USE_OUTPUT(BiasAddGrad) -DECLARE_OP_ADAPTER(MaxPoolWithArgmax) -DECLARE_OP_USE_OUTPUT(MaxPoolWithArgmax) -DECLARE_OP_ADAPTER(MaxPoolGradWithArgmax) -DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax) -DECLARE_OP_ADAPTER(Conv2D) -DECLARE_OP_USE_ENUM(Conv2D) -DECLARE_OP_USE_OUTPUT(Conv2D) -DECLARE_OP_ADAPTER(ExtractImagePatches) -DECLARE_OP_USE_OUTPUT(ExtractImagePatches) -DECLARE_OP_ADAPTER(Conv2DBackpropInputD) -DECLARE_OP_USE_ENUM(Conv2DBackpropInputD) -DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD) -DECLARE_OP_USE_OUTPUT(Conv2DBackpropInputD) -DECLARE_OP_ADAPTER(Conv2DBackpropFilterD) -DECLARE_OP_USE_ENUM(Conv2DBackpropFilterD) -DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropFilterD) -DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilterD) -DECLARE_OP_ADAPTER(DepthwiseConv2D) -DECLARE_OP_USE_ENUM(DepthwiseConv2D) -DECLARE_OP_USE_OUTPUT(DepthwiseConv2D) -DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropFilterD) -DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropFilterD) -DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropFilterD) -DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropInputD) -DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD) -DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD) -DECLARE_OP_ADAPTER(Reshape) -DECLARE_OP_USE_OUTPUT(Reshape) -DECLARE_OP_ADAPTER(TransShape) -DECLARE_OP_USE_INPUT_ATTR(TransShape) -DECLARE_OP_USE_OUTPUT(TransShape) -DECLARE_OP_ADAPTER(Iou) -DECLARE_OP_USE_OUTPUT(Iou) -DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D) -DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D) -DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad) -DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad) -DECLARE_OP_ADAPTER(ApplyAdam) -DECLARE_OP_USE_OUTPUT(ApplyAdam) -DECLARE_OP_ADAPTER(ApplyAdamD) -DECLARE_OP_USE_OUTPUT(ApplyAdamD) -DECLARE_OP_ADAPTER(Relu6) -DECLARE_OP_USE_OUTPUT(Relu6) -DECLARE_OP_ADAPTER(Relu6Grad) -DECLARE_OP_USE_OUTPUT(Relu6Grad) -DECLARE_OP_ADAPTER(ResizeBilinearV2D) -DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D) -DECLARE_OP_ADAPTER(ResizeBilinearV2Grad) -DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad) -DECLARE_OP_ADAPTER(ZerosLike) -DECLARE_OP_USE_OUTPUT(ZerosLike) -DECLARE_OP_ADAPTER(OnesLike) -DECLARE_OP_USE_OUTPUT(OnesLike) -DECLARE_OP_ADAPTER(TensorScatterUpdate) -DECLARE_OP_USE_OUTPUT(TensorScatterUpdate) -DECLARE_OP_ADAPTER(ScatterUpdate) -DECLARE_OP_USE_OUTPUT(ScatterUpdate) -DECLARE_OP_ADAPTER(ScatterNdUpdate) -DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) -DECLARE_OP_ADAPTER(ScatterMax) -DECLARE_OP_USE_OUTPUT(ScatterMax) -DECLARE_OP_ADAPTER(NMSWithMask) -DECLARE_OP_USE_OUTPUT(NMSWithMask) -DECLARE_OP_ADAPTER(Unpack) -DECLARE_OP_USE_DYN_OUTPUT(Unpack) -DECLARE_OP_ADAPTER(CheckValid) -DECLARE_OP_USE_OUTPUT(CheckValid) -DECLARE_OP_ADAPTER(SmoothL1Loss) -DECLARE_OP_USE_OUTPUT(SmoothL1Loss) -DECLARE_OP_ADAPTER(SmoothL1LossGrad) -DECLARE_OP_USE_OUTPUT(SmoothL1LossGrad) -DECLARE_OP_ADAPTER(SigmoidCrossEntropyWithLogits) -DECLARE_OP_USE_OUTPUT(SigmoidCrossEntropyWithLogits) -DECLARE_OP_ADAPTER(SigmoidCrossEntropyWithLogitsGrad) -DECLARE_OP_USE_OUTPUT(SigmoidCrossEntropyWithLogitsGrad) -DECLARE_OP_ADAPTER(ScatterNdD) -DECLARE_OP_USE_INPUT_ATTR(ScatterNdD) -DECLARE_OP_USE_OUTPUT(ScatterNdD) -DECLARE_OP_ADAPTER(PadD) -DECLARE_OP_USE_OUTPUT(PadD) -DECLARE_OP_ADAPTER(MirrorPad) -DECLARE_OP_USE_OUTPUT(MirrorPad) -DECLARE_OP_ADAPTER(MirrorPadGrad) -DECLARE_OP_USE_OUTPUT(MirrorPadGrad) -DECLARE_OP_ADAPTER(BoundingBoxEncode) -DECLARE_OP_USE_OUTPUT(BoundingBoxEncode) -DECLARE_OP_ADAPTER(BoundingBoxDecode) -DECLARE_OP_USE_OUTPUT(BoundingBoxDecode) -DECLARE_OP_ADAPTER(GatherNd) -DECLARE_OP_USE_OUTPUT(GatherNd) -DECLARE_OP_ADAPTER(ArgMaxD) -DECLARE_OP_USE_OUTPUT(ArgMaxD) -DECLARE_OP_ADAPTER(ArgMinD) -DECLARE_OP_USE_OUTPUT(ArgMinD) -DECLARE_OP_ADAPTER(ArgMaxWithValue) -DECLARE_OP_USE_OUTPUT(ArgMaxWithValue) -DECLARE_OP_ADAPTER(ArgMinWithValue) -DECLARE_OP_USE_OUTPUT(ArgMinWithValue) -DECLARE_OP_ADAPTER(Mul) -DECLARE_OP_USE_OUTPUT(Mul) -DECLARE_OP_ADAPTER(AddN) -DECLARE_OP_USE_DYN_INPUT(AddN) -DECLARE_OP_USE_OUTPUT(AddN) -DECLARE_OP_ADAPTER(Less) -DECLARE_OP_USE_OUTPUT(Less) -DECLARE_OP_ADAPTER(Rsqrt) -DECLARE_OP_USE_OUTPUT(Rsqrt) -DECLARE_OP_ADAPTER(Sqrt) -DECLARE_OP_USE_OUTPUT(Sqrt) -DECLARE_OP_ADAPTER(Square) -DECLARE_OP_USE_OUTPUT(Square) -DECLARE_OP_ADAPTER(SplitD) -DECLARE_OP_USE_DYN_OUTPUT(SplitD) -DECLARE_OP_ADAPTER(SGD) -DECLARE_OP_USE_OUTPUT(SGD) -DECLARE_OP_ADAPTER(SquareSumAll) -DECLARE_OP_USE_OUTPUT(SquareSumAll) - -DECLARE_OP_ADAPTER(Tanh) -DECLARE_OP_USE_OUTPUT(Tanh) -DECLARE_OP_ADAPTER(TanhGrad) -DECLARE_OP_USE_OUTPUT(TanhGrad) -DECLARE_OP_ADAPTER(Maximum) -DECLARE_OP_USE_OUTPUT(Maximum) -DECLARE_OP_ADAPTER(Minimum) -DECLARE_OP_USE_OUTPUT(Minimum) -DECLARE_OP_ADAPTER(MaximumGrad) -DECLARE_OP_USE_OUTPUT(MaximumGrad) -DECLARE_OP_ADAPTER(MinimumGrad) -DECLARE_OP_USE_OUTPUT(MinimumGrad) -DECLARE_OP_ADAPTER(ReduceMinD) -DECLARE_OP_USE_INPUT_ATTR(ReduceMinD) -DECLARE_OP_USE_OUTPUT(ReduceMinD) -DECLARE_OP_ADAPTER(ReduceMaxD) -DECLARE_OP_USE_INPUT_ATTR(ReduceMaxD) -DECLARE_OP_USE_OUTPUT(ReduceMaxD) -DECLARE_OP_ADAPTER(Merge) -DECLARE_OP_USE_DYN_INPUT(Merge) -DECLARE_OP_USE_OUTPUT(Merge) -DECLARE_OP_ADAPTER(Switch) -DECLARE_OP_USE_OUTPUT(Switch) - -DECLARE_OP_ADAPTER(TopK) -DECLARE_OP_USE_OUTPUT(TopK) - -DECLARE_OP_ADAPTER(RealDiv) -DECLARE_OP_USE_OUTPUT(RealDiv) - -DECLARE_OP_ADAPTER(Cast) -DECLARE_OP_USE_INPUT_ATTR(Cast) -DECLARE_OP_USE_OUTPUT(Cast) -DECLARE_OP_ADAPTER(Reciprocal) -DECLARE_OP_USE_OUTPUT(Reciprocal) -DECLARE_OP_ADAPTER(Neg) -DECLARE_OP_USE_OUTPUT(Neg) -DECLARE_OP_ADAPTER(TransposeD) -DECLARE_OP_USE_INPUT_ATTR(TransposeD) -// Do not set Transpose operator output descriptor -DECLARE_OP_ADAPTER(Sub) -DECLARE_OP_USE_OUTPUT(Sub) -DECLARE_OP_ADAPTER(DropOutGenMask) -DECLARE_OP_USE_OUTPUT(DropOutGenMask) -DECLARE_OP_ADAPTER(ConcatD) -DECLARE_OP_USE_DYN_INPUT(ConcatD) -DECLARE_OP_USE_OUTPUT(ConcatD) -DECLARE_OP_ADAPTER(Pack) -DECLARE_OP_USE_DYN_INPUT(Pack) -DECLARE_OP_USE_OUTPUT(Pack) - -DECLARE_OP_ADAPTER(Pow) -DECLARE_OP_USE_OUTPUT(Pow) -DECLARE_OP_ADAPTER(Equal) -DECLARE_OP_USE_OUTPUT(Equal) -DECLARE_OP_ADAPTER(NotEqual) -DECLARE_OP_USE_OUTPUT(NotEqual) -DECLARE_OP_ADAPTER(Log) -DECLARE_OP_USE_OUTPUT(Log) -DECLARE_OP_ADAPTER(LogicalAnd) -DECLARE_OP_USE_OUTPUT(LogicalAnd) -DECLARE_OP_ADAPTER(LogicalOr) -DECLARE_OP_USE_OUTPUT(LogicalOr) -DECLARE_OP_ADAPTER(LogicalNot) -DECLARE_OP_USE_OUTPUT(LogicalNot) -DECLARE_OP_ADAPTER(LogSoftmaxGrad) -DECLARE_OP_USE_OUTPUT(LogSoftmaxGrad) - -DECLARE_OP_ADAPTER(RandomChoiceWithMask) -DECLARE_OP_USE_OUTPUT(RandomChoiceWithMask) - -DECLARE_OP_ADAPTER(Select) -DECLARE_OP_USE_OUTPUT(Select) -DECLARE_OP_ADAPTER(LessEqual) -DECLARE_OP_USE_OUTPUT(LessEqual) -DECLARE_OP_ADAPTER(LogSoftmaxV2) -DECLARE_OP_USE_OUTPUT(LogSoftmaxV2) -DECLARE_OP_ADAPTER(TruncatedNormal) -DECLARE_OP_USE_OUTPUT(TruncatedNormal) -DECLARE_OP_ADAPTER(StridedSliceGrad) -DECLARE_OP_USE_OUTPUT(StridedSliceGrad) -DECLARE_OP_ADAPTER(Gelu) -DECLARE_OP_USE_OUTPUT(Gelu) -DECLARE_OP_ADAPTER(GeluGrad) -DECLARE_OP_USE_OUTPUT(GeluGrad) -DECLARE_OP_ADAPTER(StridedSlice) -DECLARE_OP_USE_OUTPUT(StridedSlice) -DECLARE_OP_ADAPTER(UnsortedSegmentSumD) -DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) -DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) -DECLARE_OP_ADAPTER(UnsortedSegmentMin) -DECLARE_OP_USE_OUTPUT(UnsortedSegmentMin) -DECLARE_OP_ADAPTER(ExpandDims) -DECLARE_OP_USE_OUTPUT(ExpandDims) -DECLARE_OP_ADAPTER(Squeeze) -DECLARE_OP_USE_OUTPUT(Squeeze) -DECLARE_OP_ADAPTER(LayerNorm) -DECLARE_OP_USE_OUTPUT(LayerNorm) -DECLARE_OP_ADAPTER(LayerNormGrad) -DECLARE_OP_USE_OUTPUT(LayerNormGrad) -DECLARE_OP_ADAPTER(BatchMatMul) -DECLARE_OP_USE_OUTPUT(BatchMatMul) -DECLARE_OP_ADAPTER(DropOutDoMask) -DECLARE_OP_USE_OUTPUT(DropOutDoMask) -// ** Mix-precision Operations ** -DECLARE_OP_ADAPTER(NPUGetFloatStatus) -DECLARE_OP_USE_OUTPUT(NPUGetFloatStatus) -DECLARE_OP_ADAPTER(NPUAllocFloatStatus) -DECLARE_OP_USE_OUTPUT(NPUAllocFloatStatus) -DECLARE_OP_ADAPTER(NPUClearFloatStatus) -DECLARE_OP_USE_OUTPUT(NPUClearFloatStatus) -DECLARE_OP_ADAPTER(MatMul) -DECLARE_OP_USE_OUTPUT(MatMul) - -DECLARE_OP_ADAPTER(SoftmaxCrossEntropyWithLogits) -DECLARE_OP_USE_OUTPUT(SoftmaxCrossEntropyWithLogits) - -DECLARE_OP_ADAPTER(MeanGrad) -DECLARE_OP_USE_INPUT_ATTR(MeanGrad) - -DECLARE_OP_ADAPTER(Assign) -DECLARE_OP_USE_OUTPUT(Assign) -DECLARE_OP_ADAPTER(Constant) -DECLARE_OP_USE_OUTPUT(Constant) -DECLARE_OP_ADAPTER(ApplyMomentumD) -DECLARE_OP_USE_OUTPUT(ApplyMomentumD) -// ** Summary Operations ** -DECLARE_OP_ADAPTER(Summary) - -// fully supported -DECLARE_OP_ADAPTER(Add) -DECLARE_OP_USE_OUTPUT(Add) -DECLARE_OP_ADAPTER(Const) -DECLARE_OP_USE_OUTPUT(Const) -DECLARE_OP_ADAPTER(Cos) -DECLARE_OP_USE_OUTPUT(Cos) - -DECLARE_OP_ADAPTER(Acos) -DECLARE_OP_USE_OUTPUT(Acos) -DECLARE_OP_ADAPTER(AcosGrad) -DECLARE_OP_USE_OUTPUT(AcosGrad) -DECLARE_OP_ADAPTER(Acosh) -DECLARE_OP_USE_OUTPUT(Acosh) -DECLARE_OP_ADAPTER(AcoshGrad) -DECLARE_OP_USE_OUTPUT(AcoshGrad) - -DECLARE_OP_ADAPTER(Floor) -DECLARE_OP_USE_OUTPUT(Floor) -DECLARE_OP_ADAPTER(FloorDiv) -DECLARE_OP_USE_OUTPUT(FloorDiv) -DECLARE_OP_ADAPTER(FloorMod) -DECLARE_OP_USE_OUTPUT(FloorMod) -DECLARE_OP_ADAPTER(Sin) -DECLARE_OP_USE_OUTPUT(Sin) -DECLARE_OP_ADAPTER(Exp) -DECLARE_OP_USE_OUTPUT(Exp) - -DECLARE_OP_ADAPTER(ReduceAllD) -DECLARE_OP_USE_INPUT_ATTR(ReduceAllD) -DECLARE_OP_USE_OUTPUT(ReduceAllD) -DECLARE_OP_ADAPTER(ReduceSumD) -DECLARE_OP_USE_INPUT_ATTR(ReduceSumD) -DECLARE_OP_USE_OUTPUT(ReduceSumD) -DECLARE_OP_ADAPTER(ReduceMeanD) -DECLARE_OP_USE_INPUT_ATTR(ReduceMeanD) -DECLARE_OP_USE_OUTPUT(ReduceMeanD) -DECLARE_OP_ADAPTER(ReduceProdD) -DECLARE_OP_USE_INPUT_ATTR(ReduceProdD) -DECLARE_OP_USE_OUTPUT(ReduceProdD) -DECLARE_OP_ADAPTER(CumprodD) -DECLARE_OP_USE_INPUT_ATTR(CumprodD) -DECLARE_OP_USE_OUTPUT(CumprodD) - -DECLARE_OP_ADAPTER(TileD) -DECLARE_OP_USE_INPUT_ATTR(TileD) -DECLARE_OP_USE_OUTPUT(TileD) -DECLARE_OP_ADAPTER(OneHot) -DECLARE_OP_USE_OUTPUT(OneHot) -DECLARE_OP_ADAPTER(GatherV2D) -DECLARE_OP_USE_INPUT_ATTR(GatherV2D) -DECLARE_OP_USE_OUTPUT(GatherV2D) -DECLARE_OP_ADAPTER(RangeD) -DECLARE_OP_USE_OUTPUT(RangeD) - -DECLARE_OP_ADAPTER(Data) -DECLARE_OP_ADAPTER(BiasAdd) -DECLARE_OP_USE_OUTPUT(BiasAdd) -DECLARE_OP_ADAPTER(BatchNorm) -DECLARE_OP_USE_OUTPUT(BatchNorm) -DECLARE_OP_ADAPTER(BatchNormGrad) -DECLARE_OP_USE_OUTPUT(BatchNormGrad) -DECLARE_OP_ADAPTER(Relu) -DECLARE_OP_USE_OUTPUT(Relu) -DECLARE_OP_ADAPTER(PRelu) -DECLARE_OP_USE_OUTPUT(PRelu) -DECLARE_OP_ADAPTER(Elu) -DECLARE_OP_USE_OUTPUT(Elu) - -DECLARE_OP_ADAPTER(EluGrad) -DECLARE_OP_USE_OUTPUT(EluGrad) -DECLARE_OP_ADAPTER(PReluGrad) -DECLARE_OP_USE_OUTPUT(PReluGrad) - -DECLARE_OP_ADAPTER(L2Normalize) -DECLARE_OP_USE_OUTPUT(L2Normalize) - -DECLARE_OP_ADAPTER(CumsumD) -DECLARE_OP_USE_INPUT_ATTR(CumsumD) -DECLARE_OP_USE_OUTPUT(CumsumD) -DECLARE_OP_ADAPTER(L2NormalizeGrad) -DECLARE_OP_USE_OUTPUT(L2NormalizeGrad) -DECLARE_OP_ADAPTER(Sigmoid) -DECLARE_OP_USE_OUTPUT(Sigmoid) -DECLARE_OP_ADAPTER(SigmoidGrad) -DECLARE_OP_USE_OUTPUT(SigmoidGrad) -DECLARE_OP_ADAPTER(SoftmaxV2) -DECLARE_OP_USE_OUTPUT(SoftmaxV2) -DECLARE_OP_ADAPTER(SoftmaxGrad) -DECLARE_OP_USE_OUTPUT(SoftmaxGrad) -DECLARE_OP_ADAPTER(Greater) -DECLARE_OP_USE_OUTPUT(Greater) -DECLARE_OP_ADAPTER(Flatten) -DECLARE_OP_USE_OUTPUT(Flatten) -DECLARE_OP_ADAPTER(GatherV2) -DECLARE_OP_USE_OUTPUT(GatherV2) -DECLARE_OP_ADAPTER(MaxPool) -DECLARE_OP_USE_OUTPUT(MaxPool) -DECLARE_OP_ADAPTER(MaxPoolGrad) -DECLARE_OP_USE_OUTPUT(MaxPoolGrad) -DECLARE_OP_ADAPTER(AvgPool) -DECLARE_OP_USE_OUTPUT(AvgPool) -DECLARE_OP_ADAPTER(AvgPoolGrad) -DECLARE_OP_USE_OUTPUT(AvgPoolGrad) -DECLARE_OP_ADAPTER(ROIAlign) -DECLARE_OP_USE_OUTPUT(ROIAlign) -DECLARE_OP_ADAPTER(ROIAlignGrad) -DECLARE_OP_USE_OUTPUT(ROIAlignGrad) -DECLARE_OP_ADAPTER(Abs) -DECLARE_OP_USE_OUTPUT(Abs) -DECLARE_OP_ADAPTER(AbsGrad) -DECLARE_OP_USE_OUTPUT(AbsGrad) -DECLARE_OP_ADAPTER(BinaryCrossEntropy) -DECLARE_OP_USE_OUTPUT(BinaryCrossEntropy) -DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad) -DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad) -DECLARE_OP_ADAPTER(SparseApplyAdagradD) -DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD) -DECLARE_OP_ADAPTER(ApplyProximalAdagradD) -DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD) -DECLARE_OP_ADAPTER(SpaceToDepth) -DECLARE_OP_USE_OUTPUT(SpaceToDepth) -DECLARE_OP_ADAPTER(DepthToSpace) -DECLARE_OP_USE_OUTPUT(DepthToSpace) -DECLARE_OP_ADAPTER(Sign) -DECLARE_OP_USE_OUTPUT(Sign) -DECLARE_OP_ADAPTER(LarsV2Update) -DECLARE_OP_USE_OUTPUT(LarsV2Update) -DECLARE_OP_ADAPTER(Round) -DECLARE_OP_USE_OUTPUT(Round) -DECLARE_OP_ADAPTER(ApplyFtrlD) -DECLARE_OP_USE_OUTPUT(ApplyFtrlD) -DECLARE_OP_ADAPTER(SparseApplyFtrlD) -DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD) -DECLARE_OP_ADAPTER(Diag) -DECLARE_OP_USE_OUTPUT(Diag) -DECLARE_OP_ADAPTER(DiagPart) -DECLARE_OP_USE_OUTPUT(DiagPart) -DECLARE_OP_ADAPTER(SpaceToBatchD) -DECLARE_OP_USE_OUTPUT(SpaceToBatchD) -DECLARE_OP_ADAPTER(BatchToSpaceD) -DECLARE_OP_USE_OUTPUT(BatchToSpaceD) -DECLARE_OP_ADAPTER(Atan2) -DECLARE_OP_USE_OUTPUT(Atan2) -DECLARE_OP_ADAPTER(ApplyRMSPropD) -DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) -DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) -DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) -DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) -DECLARE_OP_ADAPTER(L2Loss) -DECLARE_OP_USE_OUTPUT(L2Loss) -DECLARE_OP_ADAPTER(CTCLoss) -DECLARE_OP_USE_OUTPUT(CTCLoss) -DECLARE_OP_ADAPTER(AscendQuant) -DECLARE_OP_USE_OUTPUT(AscendQuant) -DECLARE_OP_ADAPTER(AscendDequant) -DECLARE_OP_USE_OUTPUT(AscendDequant) -#ifdef ENABLE_GE -DECLARE_OP_ADAPTER(Print) -DECLARE_OP_USE_DYN_INPUT(Print) -#endif -} // namespace transform -} // namespace mindspore -#endif // TRANSFORM_OP_DECLARE_H_ diff --git a/mindspore/ccsrc/transform/util.cc b/mindspore/ccsrc/transform/util.cc deleted file mode 100644 index b848ec117b..0000000000 --- a/mindspore/ccsrc/transform/util.cc +++ /dev/null @@ -1,452 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/util.h" - -#include -#include -#include - -#include "securec/include/securec.h" -#include "utils/convert_utils.h" -#include "utils/utils.h" - -namespace mindspore { -namespace transform { -using std::make_shared; -using std::shared_ptr; -using std::string; -using std::vector; - -const size_t kErrorSize = 0; - -vector TransformUtil::ConvertIntToList(int64_t data, int size) { - vector list{}; - if (size <= 0) { - MS_LOG(WARNING) << "size <= 0"; - return list; - } - for (int i = 0; i < size; ++i) { - list.push_back(data); - } - return list; -} - -static std::map datatype_trans_map = { - {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT}, - {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE}, {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8}, - {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16}, {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32}, - {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64}, {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8}, - {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32}, - {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}}; - -GeDataType TransformUtil::ConvertDataType(const MeDataType &type) { - MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type"; - if (datatype_trans_map.find(type) != datatype_trans_map.end()) { - return datatype_trans_map[type]; - } else { - return GeDataType::DT_UNDEFINED; - } -} - -static std::map datatype_size_map = { - {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)}, // 1/2 of float - {MeDataType::kNumberTypeFloat64, sizeof(double)}, {MeDataType::kNumberTypeInt8, sizeof(int8_t)}, - {MeDataType::kNumberTypeInt16, sizeof(int16_t)}, {MeDataType::kNumberTypeInt32, sizeof(int32_t)}, - {MeDataType::kNumberTypeInt64, sizeof(int64_t)}, {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)}, - {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)}, - {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}}; - -size_t TransformUtil::GetDataTypeSize(const MeDataType &type) { - if (datatype_size_map.find(type) != datatype_size_map.end()) { - return datatype_size_map[type]; - } else { - MS_LOG(ERROR) << "Illegal tensor data type!"; - return kErrorSize; - } -} - -GeFormat TransformUtil::ConvertFormat(const string &format) { - if (format == kOpFormat_NCHW) { - return GeFormat::FORMAT_NCHW; - } else if (format == kOpFormat_NC1HWC0) { - return GeFormat::FORMAT_NC1HWC0; - } else if (format == kOpFormat_NHWC) { - return GeFormat::FORMAT_NHWC; - } else if (format == kOpFormat_HWCN) { - return GeFormat::FORMAT_HWCN; - } else { - return GeFormat::FORMAT_ND; - } -} - -static int64_t IntegerCastFunc(size_t temp) { return static_cast(temp); } - -std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector &me_shape, - const MeDataType &me_type, const std::string &format) { - // convert me shape to ge shape - std::vector ge_shape; - - if (me_shape.size() == 1) { - ge_shape.push_back(static_cast(me_shape[0])); - } else { - ge_shape.resize(me_shape.size()); - (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc); - } - - GeShape shape(ge_shape); - if (shape.GetDimNum() == 0) { - MS_LOG(INFO) << "The dims size of Ge tensor is zero"; - } - // convert me format to ge format - GeFormat ge_format = ConvertFormat(format); - if (ge_format == GeFormat::FORMAT_ND) { - MS_LOG(ERROR) << "undefined data format : " << static_cast(ge_format); - return nullptr; - } - // convert me datatype to ge datatype - GeDataType data_type = ConvertDataType(me_type); - if (data_type == GeDataType::DT_UNDEFINED) { - MS_LOG(ERROR) << "undefined data type :" << me_type; - return nullptr; - } - - auto desc = std::make_shared(shape, ge_format, data_type); - if (desc == nullptr) { - MS_LOG(ERROR) << "Create GeTensorDesc failed!"; - return nullptr; - } - MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size(); - desc->SetRealDimCnt(SizeToInt(me_shape.size())); - return desc; -} - -// if failed, return empty vector. -std::vector TransformUtil::ConvertInputTensors(const std::vector &me_tensors, - const std::string &format) { - std::vector ge_tensors; - - for (size_t index = 0; index < me_tensors.size(); index++) { - MS_EXCEPTION_IF_NULL(me_tensors[index]); - MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize(); - auto shape = me_tensors[index]->shape(); - std::string shape_str; - for (size_t i = 0; i < shape.size(); i++) { - shape_str += std::to_string(shape[i]); - shape_str += " "; - } - MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}"; - MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type(); - - auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format); - if (ge_tensor_ptr != nullptr) { - ge_tensors.emplace_back(ge_tensor_ptr); - } else { - MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!"; - ge_tensors.clear(); - return ge_tensors; - } - } - return ge_tensors; -} - -GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) { - // get tensor data type size - MS_EXCEPTION_IF_NULL(tensor); - size_t type_size = GetDataTypeSize(tensor->data_type()); - if (type_size == kErrorSize) { - MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size; - return nullptr; - } - size_t elements_num = IntToSize(tensor->ElementsNum()); - if (UINT_MAX / type_size < elements_num) { - MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size - << " overflowed UINT_MAX: " << UINT_MAX << "."; - return nullptr; - } - - // get tensor buff size - size_t data_buff_size = elements_num * type_size; - if (data_buff_size == 0) { - MS_LOG(INFO) << "The Me Tensor data buff size is 0."; - } - // create ge tensor - auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format); - if (desc == nullptr) { - MS_LOG(ERROR) << "Failed to get Tensor Desc"; - return nullptr; - } - GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); - if (tensor_ptr != nullptr) { - MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!"; - } - return tensor_ptr; -} - -std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors, - const std::vector> &request_dims) { - std::vector outputs; - - for (size_t index = 0; index < ge_tensors.size(); index++) { - MeTensorPtr me_tensor_ptr = nullptr; - if (index < request_dims.size()) { - me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]); - } else { - std::vector empty_shape; - me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape); - } - - if (me_tensor_ptr != nullptr) { - outputs.emplace_back(me_tensor_ptr); - } else { - MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!"; - return outputs; - } - } - return outputs; -} - -std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors) { - std::vector outputs; - - for (size_t index = 0; index < ge_tensors.size(); index++) { - MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]); - if (me_tensor_ptr != nullptr) { - outputs.emplace_back(me_tensor_ptr); - } else { - MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!"; - return outputs; - } - } - return outputs; -} - -MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) { - switch (type) { - case GeDataType::DT_FLOAT16: - return MeDataType::kNumberTypeFloat16; - case GeDataType::DT_FLOAT: - return MeDataType::kNumberTypeFloat32; - case GeDataType::DT_DOUBLE: - return MeDataType::kNumberTypeFloat64; - case GeDataType::DT_INT64: - return MeDataType::kNumberTypeInt64; - case GeDataType::DT_INT32: - return MeDataType::kNumberTypeInt32; - case GeDataType::DT_INT16: - return MeDataType::kNumberTypeInt16; - case GeDataType::DT_INT8: - return MeDataType::kNumberTypeInt8; - case GeDataType::DT_BOOL: - return MeDataType::kNumberTypeBool; - case GeDataType::DT_UINT8: - return MeDataType::kNumberTypeUInt8; - case GeDataType::DT_UINT16: - return MeDataType::kNumberTypeUInt16; - case GeDataType::DT_UINT32: - return MeDataType::kNumberTypeUInt32; - case GeDataType::DT_UINT64: - return MeDataType::kNumberTypeUInt64; - case GeDataType::DT_UNDEFINED: - case GeDataType::DT_DUAL_SUB_UINT8: - case GeDataType::DT_DUAL_SUB_INT8: - case GeDataType::DT_DUAL: - return MeDataType::kTypeUnknown; - default: - return MeDataType::kTypeUnknown; - } -} - -namespace { -bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &request_dims) { - MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims()); - MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims); - - const int GE_DIMS = 4; - std::vector ge_dims = ge_shape.GetDims(); - if (request_dims.size() > ge_dims.size()) { - MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's"; - return false; - } - - // convert NHWC to NCHW - if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) && - (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) { - MS_LOG(INFO) << "Ge tensor shape and request shape is compatible"; - return true; - } - - std::string::size_type i = 0; - for (; i < request_dims.size(); i++) { - if (ge_dims[i] != request_dims[i]) { - MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's"; - return false; - } - } - - for (; i < ge_dims.size(); i++) { - if (ge_dims[i] != 1) { - MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1"; - return false; - } - } - MS_LOG(INFO) << "Ge tensor shape and request shape is compatible"; - return true; -} -} // namespace - -GeShape TransformUtil::ConvertMeShape(const std::vector &me_dims) { - std::vector ge_dims; - (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims)); - return GeShape(ge_dims); -} - -std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { - std::vector me_dims; - std::vector ge_dims = ge_shape.GetDims(); - (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims)); - return me_dims; -} - -std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims) { - vector ret; - if (ge_shape.GetDimNum() == 0) { - MS_LOG(DEBUG) << "GeTensor's shape is scalar"; - return ret; - } - - if (IsGeShapeCompatible(ge_shape, request_dims) == true) { - ret = request_dims; - } else { - MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape"; - ret = ConvertGeShape(ge_shape); - } - return ret; -} - -MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, - const TypeId &me_type) { - MeTensor me_tensor(me_type, me_dims); - - // Get the writable data pointer of the tensor and cast it to its data type - auto me_data_ptr = reinterpret_cast(me_tensor.data_c()); - size_t me_data_size = static_cast(me_tensor.data().nbytes()); - MS_EXCEPTION_IF_NULL(me_data_ptr); - MS_EXCEPTION_IF_NULL(ge_tensor); - if (me_data_size < ge_tensor->GetSize()) { - MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor [" - << ge_tensor->GetSize() << " bytes]"; - return nullptr; - } - - // Copy or use the writable data pointer of the ME tensor - MS_EXCEPTION_IF_NULL(ge_tensor->GetData()); - if (ge_tensor->GetSize() == 0) { - MS_LOG(ERROR) << "GE tensor data size is zero!"; - return nullptr; - } - - // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB - // which is the size limit of memcpy_s - memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize()); - - return make_shared(me_tensor); -} - -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) { - MS_EXCEPTION_IF_NULL(ge_tensor); - GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); - vector me_dims = ConvertGeShape(ge_shape); - - TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType()); - if (type_id == MeDataType::kTypeUnknown) { - MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " - << static_cast(ge_tensor->GetTensorDesc().GetDataType()); - return nullptr; - } - return GenerateMeTensor(ge_tensor, me_dims, type_id); -} - -// if request_dims is empty, use ge tensor's shape,otherwise convert to request shape -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector &request_dims) { - MS_EXCEPTION_IF_NULL(ge_tensor); - GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); - vector me_dims = ConvertGeShape(ge_shape, request_dims); - MS_LOG(INFO) << "GE tensor type is " << static_cast(ge_tensor->GetTensorDesc().GetDataType()); - // Create a tensor with wanted data type and shape - TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType()); - if (type_id == MeDataType::kTypeUnknown) { - MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " - << static_cast(ge_tensor->GetTensorDesc().GetDataType()); - return nullptr; - } - return GenerateMeTensor(ge_tensor, me_dims, type_id); -} - -std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) { - std::string ret; - if (ge_tensor == nullptr) { - MS_LOG(ERROR) << "Input ge tensor is nullptr"; - return ret; - } - - MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast(ge_tensor->GetTensorDesc().GetDataType()); - switch (ge_tensor->GetTensorDesc().GetDataType()) { - case GeDataType::DT_UINT32: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_FLOAT: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_INT32: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_DOUBLE: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_INT64: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_UINT64: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_INT16: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_UINT16: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_DUAL_SUB_INT8: - case GeDataType::DT_INT8: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_UINT8: - case GeDataType::DT_DUAL_SUB_UINT8: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_FLOAT16: - case GeDataType::DT_BOOL: - case GeDataType::DT_UNDEFINED: - case GeDataType::DT_DUAL: - default: - MS_LOG(ERROR) << "Unsupported to print type:" << static_cast(ge_tensor->GetTensorDesc().GetDataType()) - << " ge tensor"; - break; - } - return ret; -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/util.h b/mindspore/ccsrc/transform/util.h deleted file mode 100644 index 5d8db26ad1..0000000000 --- a/mindspore/ccsrc/transform/util.h +++ /dev/null @@ -1,241 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_UTIL_H_ -#define TRANSFORM_UTIL_H_ - -#include -#include -#include -#include -#include "securec/include/securec.h" -#include "ir/anf.h" -#include "ir/dtype.h" -#include "ir/tensor.h" -#include "transform/types.h" - -#include "graph/tensor.h" - -namespace mindspore { -namespace transform { -class TransformUtil { - public: - /* - * Parameters: - * type: [MeDataType] the data type for ME tensor - * Return: - * [GeDataType] the data type for ge tensor - * */ - static std::vector ConvertIntToList(int64_t data, int size); - - /* - * Parameters: - * type: [MeDataType] the data type for ME tensor - * Return: - * [GeDataType] the data type for ge tensor - * */ - static GeDataType ConvertDataType(const MeDataType &type); - - /* - * Parameters: - * type: [string] the data format in ME op - * Return: - * [GeFormat] the data format for ge tensor - * */ - static GeFormat ConvertFormat(const std::string &format); - - /* - * Parameters: - * type: [MeDataType] the data type for ME tensor - * Return: - * [size_t] the buff size for the type in ME - * */ - static size_t GetDataTypeSize(const MeDataType &type); - - /* - * Parameters: - * tensor: [MeTensorPtr] the me tensor to get description from - * format: [string] the data format in ME - * is_input: [bool] whether the tensor is used as input, default:false - * Return: - * [shared_ptr] the shared pointer of ge tensor description - * */ - static std::shared_ptr GetGeTensorDesc(const std::vector &shape, const MeDataType &me_type, - const std::string &format); - - /* - * Parameters: - * tensor: [MeTensor] the data tensor in ME - * format: [string] the data format in ME op - * is_input: [bool] whether the tensor is used as input, default:false - * Return: - * [GeTensor] the data tensor in GE - * */ - static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format); - - /* - * Parameters: - * me_tensors: [vector] the data tensors in ME - * format: [string] the data format in ME op - * Return: - * [std::vector] the data tensors in GE - * */ - static std::vector ConvertInputTensors(const std::vector &me_tensors, - const std::string &format); - - /* - * Parameters: - * tensor: [GeTensor] the data tensor in GE - * Return: - * [MeTensor] the data tensor in ME - * */ - static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor); - - /* - * Parameters: - * tensor: [GeTensor] the data tensor in GE - * request_dims [std::vector] the output Me tensors must adjust to this shapes - * Return: - * [MeTensor] the data tensor in ME - * */ - static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector &request_dims); - /* - * Parameters: - * ge_tensors: [std::vector] the data tensor in GE - * request_dims [std::vector>] the output Me tensors must adjust to this shapes - * Return: - * [std::vector] the data tensor in ME - * */ - static std::vector ConvertGeTensors(const std::vector &ge_tensors, - const std::vector> &request_dims); - /* - * Parameters: - * ge_tensors: [std::vector] the data tensor in GE - * Return: - * [std::vector] the data tensor in ME - * */ - static std::vector ConvertGeTensors(const std::vector &ge_tensors); - /* - * Parameters: - * ge_tensor: [GeTensor] the data tensor in GE - * me_dims: [std::vector] the shape of created Me tensor - * me_type: [TypeId] the type of created Me tensor - * Return: - * [MeTensor] the data tensor in ME - * */ - static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, - const TypeId &me_type); - /* - * Parameters: - * type: [GeDataType] the ge tensor data type - * Return: - * [MeDataType] the me tensor data type - * */ - static MeDataType ConvertGeDataType(const GeDataType &type); - - /* - * Parameters: - * me_dims: [std::vector] the me shape - * Return: - * [GeShape] the ge shape - * */ - static GeShape ConvertMeShape(const std::vector &me_dims); - - /* - * Parameters: - * ge_shape: [GeShape] the ge shape - * Return: - * [vector] the me shape - * */ - static std::vector ConvertGeShape(const GeShape &ge_shape); - - /* Function: - * Convert GeShape to Me request shape, Support pattern: - * {1, x, 1, 1} --> {x} - * {x, 1, 1, 1} --> {x} - * {x, x, 1, 1} --> {x, x} - * {x, x, x, 1} --> {x, x, x} - * {x, x, x, x} --> {x, x, x, x} - * If unmatch upon patterns, return original ge dims - * Parameters: - * ge_shape: [GeShape] the ge shape - * request_dims: [vector] request dims - * Return: - * [vector] the me shape - * */ - static std::vector ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims); - - /* - * Parameters: - * vec: [std::vector] the vector to print - * Return: - * [string] value string - * */ - template ::value>::type> - static std::string PrintVector(const std::vector &vec) { - const int MAX_PRINT_NUM = 100; - std::stringstream ss; - ss << "{ "; - int i = 0; - for (auto it = vec.begin(); it != vec.end(); ++it) { - ss << std::to_string(*it) << ", "; - i++; - if (i >= MAX_PRINT_NUM) { - break; - } - } - - if (i >= MAX_PRINT_NUM) { - ss << "... to be continue}"; - } else { - ss << "}"; - } - return ss.str(); - } - - /* - * Parameters: - * ge_tensor: [GeTensorPtr] the ge tensor - * Return: - * [stringstream] value string - * */ - static std::string PrintGeTensor(const GeTensorPtr ge_tensor); - - /* - * Parameters: - * data: [uint8_t *] the ge tensor data pointer - * size: [size_t] the ge tensor data bytes - * Return: - * [shared_ptr] vector pointer - * */ - template ::value>::type> - static std::vector MakeVector(const uint8_t *const data, size_t size) { - auto dest = std::vector(size / sizeof(T)); - if (data == nullptr) { - return dest; - } - - errno_t ret = memcpy_s(dest.data(), dest.size() * sizeof(T), data, size); - if (EOK != ret) { - return std::vector(); - } - return dest; - } -}; -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_UTIL_H_ diff --git a/mindspore/ccsrc/utils/anf_ir.proto b/mindspore/ccsrc/utils/anf_ir.proto index 145751e7f0..2ea0511fa8 100644 --- a/mindspore/ccsrc/utils/anf_ir.proto +++ b/mindspore/ccsrc/utils/anf_ir.proto @@ -227,6 +227,9 @@ message NodeProto { // other fields for debug optional uint64 output_i = 7; + + // The full_name_with_scope of CNode + optional string full_name = 8; } // Models diff --git a/mindspore/ccsrc/utils/callbacks.cc b/mindspore/ccsrc/utils/callbacks.cc index 427cc5e568..ceb95d5c8c 100644 --- a/mindspore/ccsrc/utils/callbacks.cc +++ b/mindspore/ccsrc/utils/callbacks.cc @@ -20,8 +20,8 @@ #include #include #include "pybind11/pybind11.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/python_adapter.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/python_adapter.h" #include "utils/visible.h" namespace mindspore { diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc index 3174ec4b15..6001b295ad 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.cc +++ b/mindspore/ccsrc/utils/callbacks_ge.cc @@ -16,11 +16,11 @@ #include "utils/callbacks_ge.h" #include "pybind11/pybind11.h" -#include "ir/param_value_py.h" -#include "transform/df_graph_manager.h" -#include "transform/util.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/python_adapter.h" +#include "ir/param_value.h" +#include "transform/graph_ir/df_graph_manager.h" +#include "transform/graph_ir/util.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/python_adapter.h" #include "utils/visible.h" namespace mindspore { @@ -50,13 +50,10 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string ¶m_name, return false; } if (param_node->name() == param_name) { - py::object parameter; + TensorPtr tensor; if (param_node->has_default()) { - auto param_value = std::dynamic_pointer_cast(param_node->default_param()); - parameter = param_value->value(); + tensor = std::dynamic_pointer_cast(param_node->default_param()->value()); } - ValuePtr value = parse::data_converter::PyDataToValue(parameter); - TensorPtr tensor = std::dynamic_pointer_cast(value); if (tensor == nullptr) { shape->push_back(ONE_SHAPE); } else { diff --git a/mindspore/ccsrc/utils/callbacks_ge.h b/mindspore/ccsrc/utils/callbacks_ge.h index 9735c3000a..f0ef583aaa 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.h +++ b/mindspore/ccsrc/utils/callbacks_ge.h @@ -20,8 +20,8 @@ #include #include #include -#include "transform/types.h" -#include "transform/util.h" +#include "transform/graph_ir/types.h" +#include "transform/graph_ir/util.h" #include "ir/tensor.h" namespace mindspore { diff --git a/mindspore/ccsrc/utils/comm_manager.cc b/mindspore/ccsrc/utils/comm_manager.cc index 70adfb7467..de165c4aac 100644 --- a/mindspore/ccsrc/utils/comm_manager.cc +++ b/mindspore/ccsrc/utils/comm_manager.cc @@ -16,17 +16,27 @@ #include "utils/comm_manager.h" #include "utils/convert_utils.h" + #ifndef NO_DLIB #include "hccl/hcom.h" #endif +#if defined(ENABLE_GPU) +#include "runtime/device/gpu/distribution/collective_init.h" +using CollectiveInitializer = mindspore::device::gpu::CollectiveInitializer; +using CreateCommGroupFunc = mindspore::device::gpu::CreateCommGroupFunc; +using GetRankIDByGroupFunc = mindspore::device::gpu::GetRankIDByGroupFunc; +using GetGroupSizeFunc = mindspore::device::gpu::GetGroupSizeFunc; +using DestroyGroupFunc = mindspore::device::gpu::DestroyGroupFunc; +#endif + namespace mindspore { +#ifndef NO_DLIB CommManager &CommManager::GetInstance() noexcept { static CommManager instance("hccl"); return instance; } -#ifndef NO_DLIB #define HCCL_RUN_CHECK(op_name, group, op) \ do { \ auto hccl_result = (op); \ @@ -79,7 +89,79 @@ bool CommManager::DestroyGroup(const string &group) const { HCCL_RUN_CHECK(string("destroy communicate group"), group, hcom_destroy_group(group.c_str())); return true; } +#elif defined(ENABLE_GPU) +CommManager &CommManager::GetInstance() noexcept { + static CommManager instance("nccl"); + return instance; +} + +bool CommManager::CreateGroupSync(const string &group, const vector &rank_id_list) const { + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + if (!collective_handle_) { + MS_LOG(EXCEPTION) << "GPU collective handle is not initialized."; + } + MS_LOG(INFO) << "Create communication group " << group << " by rank id list " << rank_id_list; + auto create_comm_group_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "CreateCommGroup")); + MS_EXCEPTION_IF_NULL(create_comm_group_funcptr); + bool ret = (*create_comm_group_funcptr)(group, rank_id_list); + if (!ret) { + MS_LOG(ERROR) << "Creating group " << group << "for rank id list" << rank_id_list << "failed."; + return ret; + } + return ret; +} + +bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + if (!collective_handle_) { + MS_LOG(EXCEPTION) << "GPU collective handle is not initialized."; + } + auto get_rank_id_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "GetRankIDByGroup")); + MS_EXCEPTION_IF_NULL(get_rank_id_funcptr); + int rank = (*get_rank_id_funcptr)(group); + *rank_id = static_cast(rank); + MS_LOG(INFO) << "This process rank id is " << *rank_id << " in group " << group; + return true; +} + +bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const { + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + if (!collective_handle_) { + MS_LOG(EXCEPTION) << "GPU collective handle is not initialized."; + } + auto get_group_size_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "GetGroupSize")); + MS_EXCEPTION_IF_NULL(get_group_size_funcptr); + int size = (*get_group_size_funcptr)(group); + *rank_size = static_cast(size); + MS_LOG(INFO) << "Group " << group << " size is " << *rank_size; + return true; +} + +bool CommManager::DestroyGroup(const string &group) const { + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + if (!collective_handle_) { + MS_LOG(EXCEPTION) << "GPU collective handle is not initialized."; + } + auto destroy_group_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "DestroyGroup")); + MS_EXCEPTION_IF_NULL(destroy_group_funcptr); + + bool ret = (*destroy_group_funcptr)(group); + if (!ret) { + MS_LOG(ERROR) << "Destroying group " << group << " failed."; + return ret; + } + return ret; +} #else +CommManager &CommManager::GetInstance() noexcept { + static CommManager instance("hccl"); + return instance; +} + bool CommManager::CreateGroupSync(const string &, const vector &) const { return true; } bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { return true; } diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 2f2471f460..d6381ec7e8 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -27,9 +27,10 @@ #include "tdt/data_common.h" #endif #ifdef ENABLE_GE -#include "transform/df_graph_manager.h" +#include "transform/graph_ir/df_graph_manager.h" #endif #include "ir/tensor.h" +#include "common/utils.h" namespace mindspore { #ifdef ENABLE_GE @@ -89,7 +90,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { max_device_memory_ = kDefaultMaxDeviceMemory; print_file_path_ = ""; enable_graph_kernel_ = false; - enable_sparse_flag_ = false; + enable_sparse_ = false; } std::shared_ptr MsContext::GetInstance() { @@ -168,6 +169,11 @@ bool MsContext::OpenTsd() { return true; } + auto role = common::GetEnv("MS_ROLE"); + if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) { + return true; + } + unsigned int device_id; unsigned int rank_size = 1; diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 3bca16f8ee..19205cccb8 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -161,8 +161,8 @@ class MsContext { void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } bool enable_graph_kernel() const { return enable_graph_kernel_; } - bool enable_sparse_flag() const { return enable_sparse_flag_; } - void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; } + bool enable_sparse() const { return enable_sparse_; } + void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; } private: MsContext(const std::string &backend_policy, const std::string &target); @@ -207,7 +207,7 @@ class MsContext { float max_device_memory_; std::string print_file_path_; bool enable_graph_kernel_; - bool enable_sparse_flag_; + bool enable_sparse_; }; } // namespace mindspore diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 8cb071b769..b1847d1df5 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -25,12 +25,12 @@ #include #include "pybind11/pybind11.h" -#include "pipeline/static_analysis/abstract_value.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/parse_base.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/parse_base.h" #include "ir/value.h" #include "ir/tensor.h" -#include "ir/param_value_py.h" +#include "ir/param_value.h" #include "utils/base_ref_extends.h" namespace mindspore { @@ -230,6 +230,20 @@ bool ValueToBool(const ValuePtr &v, bool *value) { return true; } +bool BaseRefToInt(const ValuePtr &v, int *value) { + MS_EXCEPTION_IF_NULL(v); + if (v->isa()) { + auto tensor = v->cast(); + (void)tensor->data_sync(); + int *tensor_data = static_cast(tensor->data_c()); + auto vb = tensor_data[0]; + *value = vb; + return true; + } + MS_LOG(ERROR) << "Index must be tensor type."; + return false; +} + bool BaseRefToBool(const BaseRef &v, bool *value) { if (utils::isa(v)) { return ValueToBool(utils::cast(v), value); @@ -435,8 +449,8 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple if (!param->has_default()) { MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")"; } - auto param_value = std::dynamic_pointer_cast(param->default_param()); - *ret_val = param_value->value().attr("data"); + auto tensor = param->default_param()->value(); + *ret_val = py::cast(tensor); } return true; } diff --git a/mindspore/ccsrc/utils/convert_utils.h b/mindspore/ccsrc/utils/convert_utils.h index 40c3e88c5c..d4ecbf4408 100644 --- a/mindspore/ccsrc/utils/convert_utils.h +++ b/mindspore/ccsrc/utils/convert_utils.h @@ -28,7 +28,7 @@ #include "utils/convert_utils_base.h" #include "utils/any.h" #include "utils/base_ref.h" -#include "ir/base.h" +#include "base/base.h" #include "ir/anf.h" namespace py = pybind11; @@ -42,6 +42,7 @@ using TensorPtr = std::shared_ptr; py::object AnyToPyData(const Any &value); py::object BaseRefToPyData(const BaseRef &value); bool BaseRefToBool(const BaseRef &in, bool *out); +bool BaseRefToInt(const ValuePtr &v, int *value); bool ValueToBool(const ValuePtr &in, bool *out); py::object ValuePtrToPyData(const ValuePtr &value); diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h index 93edda3e34..2a9240ac84 100644 --- a/mindspore/ccsrc/utils/graph_utils.h +++ b/mindspore/ccsrc/utils/graph_utils.h @@ -29,7 +29,7 @@ #include #include "ir/anf.h" -#include "ir/primitive_base.h" +#include "ir/primitive.h" #include "ir/scalar.h" #include "ir/tensor.h" #include "debug/label.h" diff --git a/mindspore/ccsrc/utils/graph_utils_extends.cc b/mindspore/ccsrc/utils/graph_utils_extends.cc index 0740c24236..852dd0e3f2 100644 --- a/mindspore/ccsrc/utils/graph_utils_extends.cc +++ b/mindspore/ccsrc/utils/graph_utils_extends.cc @@ -31,8 +31,8 @@ #include "debug/label.h" #include "utils/log_adapter.h" #include "common/utils.h" -#include "pipeline/parse/function_block.h" -#include "pipeline/parse/python_adapter.h" +#include "pipeline/jit/parse/function_block.h" +#include "pipeline/jit/parse/python_adapter.h" namespace mindspore { namespace { diff --git a/mindspore/ccsrc/utils/load_onnx/anf_converter.cc b/mindspore/ccsrc/utils/load_onnx/anf_converter.cc index ad87d6ae8f..9e8e51a46b 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_converter.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_converter.cc @@ -60,6 +60,9 @@ int AnfConverter::ValidateFileStr(const std::string &modelFile, std::string file bool AnfConverter::ReadOnnxFromBinary(const std::string &modelFile, google::protobuf::Message *onnx_model) { std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); int fd = open(onnx_file.get(), O_RDONLY); + if (fd < 0) { + MS_LOG(EXCEPTION) << "failed to open file"; + } google::protobuf::io::FileInputStream input(fd); google::protobuf::io::CodedInputStream code_input(&input); code_input.SetTotalBytesLimit(INT_MAX, 536870912); @@ -85,7 +88,7 @@ std::shared_ptr AnfConverter::RunAnfConverter(const std::string &file MS_LOG(ERROR) << "Trans data not support input format!"; } else { modelFile = flagItem.substr(pos + 1); - std::cout << "input protobuf file path is: " << flagItem.substr(pos + 1) << std::endl; + std::cout << "input protobuf file path is: " << modelFile << std::endl; } if (ValidateFileStr(modelFile, ".pb") != 0) { diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc index c3dfa5194f..fa1137e3f6 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc @@ -22,14 +22,12 @@ #include #include "google/protobuf/io/zero_copy_stream_impl.h" #include "ir/tensor.h" -#include "ir/tensor_py.h" -#include "ir/param_value_py.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/abstract_value.h" +#include "ir/param_value.h" +#include "frontend/operator/ops.h" +#include "abstract/abstract_value.h" #include "proto/onnx.pb.h" #include "utils/log_adapter.h" -using mindspore::tensor::TensorPy; using std::string; namespace mindspore { @@ -121,13 +119,15 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons std::string initial_data = initialize_proto.raw_data(); auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); MS_EXCEPTION_IF_NULL(tensor_data_buf); - memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + } - py::array array_data = TensorPy::AsNumpy(*tensor_info); - ParamValuePyPtr para_value_ptr = std::make_shared(); - MS_EXCEPTION_IF_NULL(para_value_ptr); - para_value_ptr->set_value(array_data); - node->set_default_param(para_value_ptr); + auto param_value = std::make_shared(); + MS_EXCEPTION_IF_NULL(param_value); + param_value->set_value(tensor_info); + node->set_default_param(param_value); } anfnode_build_map_[value_proto.name()] = node; return true; @@ -252,7 +252,11 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); const std::string &tensor_buf = attr_tensor.raw_data(); auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c()); - memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + } + auto new_value_node = NewValueNode(MakeValue(tensor_info)); MS_EXCEPTION_IF_NULL(new_value_node); auto tensor_abstract = tensor_info->ToAbstract(); @@ -339,7 +343,6 @@ bool MSANFModelParser::GetAttrValueForValueNode(const std::string &ref_attr_name MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; return false; } - return true; } bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h index 11b9cd101f..58fbd1bc70 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h @@ -32,7 +32,7 @@ using uint64 = uint64_t; using float16 = Eigen::half; class MSANFModelParser { public: - MSANFModelParser() = default; + MSANFModelParser() : producer_name_(""), model_version_(0), ir_version_(0) {} ~MSANFModelParser() = default; FuncGraphPtr Parse(const onnx::ModelProto &model_proto); diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc index 3588754dae..702deefcb4 100644 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ b/mindspore/ccsrc/utils/log_adapter.cc @@ -18,7 +18,6 @@ #include #include -#include "pybind11/pybind11.h" #include "debug/trace.h" // namespace to support utils module definition @@ -158,6 +157,7 @@ static std::string ExceptionTypeToString(ExceptionType type) { static const char *GetSubModuleName(SubModuleId module_id) { static const char *sub_module_names[NUM_SUBMODUES] = { "UNKNOWN", // SM_UNKNOWN + "BASE", // SM_BASE "ANALYZER", // SM_ANALYZER "COMMON", // SM_COMMON "DEBUG", // SM_DEBUG @@ -176,7 +176,8 @@ static const char *GetSubModuleName(SubModuleId module_id) { "PYNATIVE", // SM_PYNATIVE "SESSION", // SM_SESSION "UTILS", // SM_UTILS - "VM" // SM_VM + "VM", // SM_VM + "ABSTRACT" // SM_ABSTRACT }; return sub_module_names[module_id % NUM_SUBMODUES]; @@ -219,16 +220,10 @@ void LogWriter::operator^(const LogStream &stream) const { trace::TraceGraphEval(); trace::GetEvalStackInfo(oss); - if (exception_type_ == IndexError) { - throw pybind11::index_error(oss.str()); + if (exception_handler_ != nullptr) { + exception_handler_(exception_type_, oss.str()); } - if (exception_type_ == ValueError) { - throw pybind11::value_error(oss.str()); - } - if (exception_type_ == TypeError) { - throw pybind11::type_error(oss.str()); - } - pybind11::pybind11_fail(oss.str()); + throw std::runtime_error(oss.str()); } static std::string GetEnv(const std::string &envvar) { diff --git a/mindspore/ccsrc/utils/log_adapter.h b/mindspore/ccsrc/utils/log_adapter.h index dfd463ee1d..a0e9bfc6d6 100644 --- a/mindspore/ccsrc/utils/log_adapter.h +++ b/mindspore/ccsrc/utils/log_adapter.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "./overload.h" #include "./securec.h" #ifdef USE_GLOG @@ -99,6 +100,7 @@ enum MsLogLevel : int { DEBUG = 0, INFO, WARNING, ERROR, EXCEPTION }; enum SubModuleId : int { SM_UNKNOWN = 0, // unknown submodule + SM_BASE, // base SM_ANALYZER, // static analyzer SM_COMMON, // common SM_DEBUG, // debug @@ -118,6 +120,7 @@ enum SubModuleId : int { SM_SESSION, // session SM_UTILS, // utils SM_VM, // VM + SM_ABSTRACT, // abstract NUM_SUBMODUES // number of submodules }; @@ -133,6 +136,8 @@ extern int g_ms_submodule_log_levels[] __attribute__((visibility("default"))); class LogWriter { public: + using ExceptionHandler = std::function; + LogWriter(const LocationInfo &location, MsLogLevel log_level, SubModuleId submodule, ExceptionType excp_type = NoExceptionType) : location_(location), log_level_(log_level), submodule_(submodule), exception_type_(excp_type) {} @@ -141,6 +146,8 @@ class LogWriter { void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))); void operator^(const LogStream &stream) const __attribute__((noreturn, visibility("default"))); + static void set_exception_handler(ExceptionHandler exception_handler) { exception_handler_ = exception_handler; } + private: void OutputLog(const std::ostringstream &msg) const; @@ -148,6 +155,8 @@ class LogWriter { MsLogLevel log_level_; SubModuleId submodule_; ExceptionType exception_type_; + + inline static ExceptionHandler exception_handler_ = nullptr; }; #define MSLOG_IF(level, condition, excp_type) \ diff --git a/mindspore/ccsrc/utils/log_adapter_py.cc b/mindspore/ccsrc/utils/log_adapter_py.cc new file mode 100644 index 0000000000..c4793b960b --- /dev/null +++ b/mindspore/ccsrc/utils/log_adapter_py.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/log_adapter.h" + +#include +#include "pybind11/pybind11.h" + +namespace py = pybind11; +namespace mindspore { +class PyExceptionInitializer { + public: + PyExceptionInitializer() { mindspore::LogWriter::set_exception_handler(HandleExceptionPy); } + + ~PyExceptionInitializer() = default; + + private: + static void HandleExceptionPy(ExceptionType exception_type, const std::string &str) { + if (exception_type == IndexError) { + throw py::index_error(str); + } + if (exception_type == ValueError) { + throw py::value_error(str); + } + if (exception_type == TypeError) { + throw py::type_error(str); + } + py::pybind11_fail(str); + } +}; + +static PyExceptionInitializer py_exception_initializer; +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/primitive_utils.cc b/mindspore/ccsrc/utils/primitive_utils.cc index 97fa954e12..490e2517a9 100644 --- a/mindspore/ccsrc/utils/primitive_utils.cc +++ b/mindspore/ccsrc/utils/primitive_utils.cc @@ -15,7 +15,7 @@ */ #include "utils/primitive_utils.h" -#include "pipeline/parse/python_adapter.h" +#include "pipeline/jit/parse/python_adapter.h" #include "utils/log_adapter.h" #include "common/utils.h" diff --git a/mindspore/ccsrc/utils/symbolic.h b/mindspore/ccsrc/utils/symbolic.h index 1b7a212610..ca68b2c877 100644 --- a/mindspore/ccsrc/utils/symbolic.h +++ b/mindspore/ccsrc/utils/symbolic.h @@ -26,7 +26,7 @@ #include #include "ir/anf.h" -#include "pipeline/static_analysis/abstract_value.h" +#include "abstract/abstract_value.h" #include "utils/any.h" namespace mindspore { diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index ee53345f31..08cd4e4291 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -21,7 +21,7 @@ #include #include #include "ir/tensor.h" -#include "device/convert_tensor_utils.h" +#include "runtime/device/convert_tensor_utils.h" #include "./securec.h" #ifndef NO_DLIB #include "tdt/tsd_client.h" @@ -256,6 +256,7 @@ bool SaveDataItem2File(const std::vector &items, const std::strin if (!print.SerializeToOstream(output)) { MS_LOG(ERROR) << "Save print file:" << print_file_path << " fail."; ret_end_thread = true; + break; } print.Clear(); } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index e28adb6e21..3e82aaff2d 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -176,6 +176,10 @@ constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad"; constexpr auto kTensorMoveOpName = "TensorMove"; constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate"; constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate"; +constexpr auto kPushOpName = "Push"; +constexpr auto kPullOpName = "Pull"; +constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup"; +constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy"; // attr key name constexpr auto kAttrInputNames = "input_names"; @@ -236,9 +240,12 @@ constexpr auto kAttrOutputNum = "output_num"; constexpr auto kAttrSizeSplits = "size_splits"; constexpr auto kAttrOutputDefault = "output_default"; constexpr auto kAttrPrimitiveTarget = "primitive_target"; +constexpr auto kAttrUseLocking = "use_locking"; constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag"; constexpr auto kAttrOffset = "offset"; -constexpr auto kAttrUseLocking = "use_locking"; +constexpr auto kAttrPsKey = "ps_key"; +constexpr auto kAttrOptimizerType = "optim_type"; +constexpr auto kAttrChildGraph = "child_graph"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; @@ -262,6 +269,7 @@ constexpr auto kAnfPartialFuncGraphIndex = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kTupleGetItemInputSize = 3; +constexpr auto kSwitchInputSize = 4; // index define of control depend constexpr auto kControlDependPriorIndex = 1; constexpr auto kControlDependBehindIndex = 2; @@ -290,12 +298,24 @@ const std::set kOpFormatList = { kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC}; const std::set kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; const std::set kOptOperatorSet = { - kMomentumOpName, kApplyMomentumOpName, kApplyAdadeltaOpName, - kApplyAdagradOpName, kApplyAdagradDAName, kApplyAdamOpName, - kApplyAdaMaxOpName, kApplyAddSignOpName, kApplyCenteredRMSPOpName, - kApplyFtrlOpName, kApplyFtrlV2OpName, kApplyGradientDescentOpName, - kApplyPowerSignOpName, kApplyProximalAdagradOpName, kApplyProximalGradientDescentOpName, + kMomentumOpName, + kApplyMomentumOpName, + kApplyAdadeltaOpName, + kApplyAdagradOpName, + kApplyAdagradDAName, + kApplyAdamOpName, + kApplyAdaMaxOpName, + kApplyAddSignOpName, + kApplyCenteredRMSPOpName, + kApplyFtrlOpName, + kApplyFtrlV2OpName, + kApplyGradientDescentOpName, + kApplyPowerSignOpName, + kApplyProximalAdagradOpName, + kApplyProximalGradientDescentOpName, kApplyRMSPropOpName, + kPushOpName, + kPullOpName, }; const std::set kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 47bc69bbbb..0290ee57fc 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -23,7 +23,7 @@ #include "utils/callbacks.h" #include "utils/graph_utils.h" #include "utils/base_ref_extends.h" -#include "session/session_factory.h" +#include "backend/session/session_factory.h" #include "common/utils.h" #ifdef ENABLE_GE #include "utils/callbacks_ge.h" @@ -32,6 +32,7 @@ namespace mindspore { namespace compile { bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } +bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast(c), value); } LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) { // multi_graph merge to one, big graph have paramters in begin and only have one output diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index 3a93cf930f..208c4010fb 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -26,7 +26,7 @@ #include "ir/anf.h" #include "vm/segment_runner.h" #include "vm/vm.h" -#include "session/session_basic.h" +#include "backend/session/session_basic.h" namespace mindspore { namespace compile { @@ -46,6 +46,7 @@ class Backend { virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {} virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; } virtual bool GetCond(const BaseRef &c, bool *value); + virtual bool GetIndex(const BaseRef &c, int *value); virtual void SetSwitchGraph() {} virtual void SetSwitchActive(const BaseRef &, bool) {} virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index db27506134..540b77bcaf 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -31,7 +31,7 @@ #include "utils/utils.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" namespace mindspore { const char kMsConvert[] = "ms"; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 80d2fc9df9..2cf6ead813 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -26,9 +26,9 @@ #include #include -#include "pipeline/static_analysis/abstract_value.h" +#include "abstract/abstract_value.h" #ifdef ENABLE_GE -#include "transform/convert.h" +#include "transform/graph_ir/convert.h" #endif #include "utils/graph_utils.h" #include "utils/context/ms_context.h" @@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, prim::kPrimMakeTuple, prim::kPrimBpropCut}; const std::vector &GetMsNonlinearOps() { - static const std::vector ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, - prim::kPrimBpropCut}; + static const std::vector ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, + prim::kPrimSwitch, prim::kPrimMakeTuple, + prim::kPrimBpropCut, prim::kPrimSwitchLayer}; return ms_nonlinear_ops; } @@ -187,6 +188,29 @@ std::vector SplitSort(const FuncGraphPtr &graph, const std::string & std::reverse(result.begin(), result.end()); return result; } + +bool IsSubGraph(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + + AnfNodePtr fn = inputs[0]; + if (!IsValueNode(fn)) { + return false; + } + auto node_prim = GetValueNode(fn); + if (node_prim->name() == prim::kPrimPartial->name()) { + return true; + } + } else if (IsValueNode(node)) { + return true; + } + return false; +} } // namespace CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector &cut_list) @@ -214,7 +238,6 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { } AnfNodePtr fn = inputs[0]; - MS_EXCEPTION_IF_NULL(fn); if (IsValueNode(fn)) { auto fg = GetValueNode(fn); if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { @@ -235,6 +258,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(ms_context); ms_context->set_enable_pynative_hook(true); } + + if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { + if (inputs.size() < 2) { + return false; + } + auto ret = IsSubGraph(inputs[1]); + return ret; + } + return true; } } @@ -466,6 +498,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) } else if (IsPrimitive(fn, prim::kPrimSwitch)) { AddSwitch(node); AddSinkSwitch(node); + } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) { + AddSwitchLayer(node); } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) { AddMakeTuple(node); } else { @@ -622,6 +656,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) { AddInst(Instruction::kSwitch, args); } +void CompileGraph::AddSwitchLayer(const CNodePtr &node) { + auto inputs = node->inputs(); + if (inputs.size() != 3) { + MS_LOG(EXCEPTION) << "Switch layer must have index and branches."; + } + VectorRef args; + args.emplace_back(Ref(inputs[1])); + args.emplace_back(Ref(inputs[2])); + AddInst(Instruction::kSwitchLayer, args); +} + void CompileGraph::AddReturn(const CNodePtr &node) { VectorRef args; if (backend_->simu_flag()) { diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index a02478fc1b..d08a24d188 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -28,7 +28,7 @@ #include "vm/vm.h" #include "ir/anf.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "vm/segment_runner.h" #include "vm/backend.h" @@ -90,6 +90,7 @@ class CompileGraph { void AddPartial(const CNodePtr &node); void AddMakeTuple(const CNodePtr &node); void AddSwitch(const CNodePtr &node); + void AddSwitchLayer(const CNodePtr &node); void AddReturn(const CNodePtr &node); void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); void AddInput(const AnfNodePtr &node); diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index c73d41df6c..baa5b0ea11 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -23,7 +23,7 @@ #include "vm/vmimpl.h" #include "vm/backend.h" #include "vm/transform.h" -#include "pipeline/parse/data_converter.h" +#include "pipeline/jit/parse/data_converter.h" #include "utils/base_ref_extends.h" namespace mindspore { @@ -480,6 +480,36 @@ void FinalVM::InstSwitch(const VectorRef &args) { MS_LOG(DEBUG) << "End"; } +void FinalVM::InstSwitchLayer(const VectorRef &args) { + MS_LOG(DEBUG) << "Start"; + const size_t args_size = 2; + if (args.size() != args_size) { + MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size() + << "."; + return; + } + + int idx = utils::cast(args[0]); + VectorRef branches = utils::cast(Ref(utils::cast(args[1]))); + int size = static_cast(branches.size()); + + BaseRef index = Ref(idx); + int idx_value = 0; + if (!backend_->GetIndex(index, &idx_value)) { + MS_LOG(EXCEPTION) << "Not supported type to be casted to int."; + } + if (idx_value < 0) { + // Add support negative index range [-size, -1]. + idx_value += size; + } + if (idx_value < 0 || idx_value >= size) { + MS_LOG(EXCEPTION) << __FUNCTION__ << " given index " << idx_value << " out of range. Please make sure the value " + << "of index in [" << -size << ", " << size << "), and the type is int32."; + } + Push(branches[idx_value]); + MS_LOG(DEBUG) << "End"; +} + void FinalVM::InstTuple(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; VectorRef tuple; @@ -618,57 +648,8 @@ void FinalVM::SyncData(const py::object &arg) { BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { MS_LOG(DEBUG) << "input for operation:"; - auto prim_py = dyn_cast(prim); - std::size_t args_size = args.size(); - auto py_args = py::tuple(args_size); - size_t i = 0; - for (auto &arg : args) { - py_args[i] = BaseRefToPyData(arg); - MS_LOG(DEBUG) << "arg: " << i << ":"; - i++; - } - // Hook operator for execute cell custom bprop function - py::object obj; - bool is_bprop = prim->HasAttr("bprop"); - if (is_bprop) { - SyncData(py_args); - py::function fn_bprop = prim_py->hook(); - obj = fn_bprop(*py_args); - return obj; - } - // Sync gradient data from device to host - SyncData(py_args[2]); - bool is_cell = prim->HasAttr("cell_hook"); - if (is_cell) { - // Hook operator for execute cell hook function - std::string cell_id = GetValue(prim->GetAttr("cell_id")); - if (_hook_grad.find(cell_id) != _hook_grad.end()) { - std::size_t hook_args_size = 3; - auto hook_args = py::tuple(hook_args_size); - hook_args[0] = cell_id; - hook_args[1] = py::make_tuple(_hook_grad[cell_id]); - hook_args[2] = py::make_tuple(py_args[2]); - py::function fn_hook = prim_py->hook(); - obj = fn_hook(*hook_args); - if (py::isinstance(obj)) { - obj = py_args[2]; - } - _hook_grad.erase(cell_id); - } else { - _hook_grad[cell_id] = py_args[2]; - obj = py_args[2]; - } - } else { - // Hook operator for execute variable hook function - py::function fn_hook = prim_py->hook(); - obj = fn_hook(py::make_tuple(py_args[2])); - if (py::isinstance(obj)) { - obj = py_args[2]; - } - } - obj = py::make_tuple(obj); - return obj; + MS_EXCEPTION_IF_NULL(prim); + return prim->RunHookFunction(args); } - } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index 6a078c9baf..02a1ad4ddb 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -51,15 +51,17 @@ enum Instruction { kPush, kPrim, kGraph, - kPadStack + kPadStack, + kSwitchLayer }; using InstType = std::pair; using InstSet = std::vector; using InstFunctionMap = std::map>; -const std::vector inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple", - "input", "external", "push", "primitive", "graph", "pad_stack"}; +const std::vector inst_str{"call", "tail_call", "return", "partial", "switch", + "switch_return", "tuple", "input", "external", "push", + "primitive", "graph", "pad_stack", "switch_layer"}; class StructPartial : public Base { public: // Initialize StructPartial. @@ -114,6 +116,7 @@ class FinalVM { void InstExternal(const VectorRef &args); void InstPushPrim(const VectorRef &args); void InstSwitchReturn(const VectorRef &args); + void InstSwitchLayer(const VectorRef &args); void set_insts(const InstSet &value) { insts_ = value; } BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); @@ -157,8 +160,7 @@ class FinalVM { {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, - }; - std::map _hook_grad; + {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}}; }; using FinalVMPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index 51b2c9b3d5..2aebf8ad0d 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -27,10 +27,10 @@ #include #include "ir/tensor.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" -#include "ir/primitive.h" +#include "ir/primitive_py.h" #include "utils/convert_utils.h" #include "utils/primitive_utils.h" #include "debug/draw.h" diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 571cc9cb40..1605ee4bc5 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -17,11 +17,11 @@ import numbers from copy import copy from mindspore import context +from .._c_expression import ParamValue from . import dtype as mstype from .initializer import initializer, Initializer from .tensor import Tensor, MetaTensor from .._checkparam import _check_str_by_regular -from ..parallel._utils import _set_clone_info, _CloneInfo from ..parallel._tensor import _get_slice_index __all__ = ['Parameter', 'ParameterTuple'] @@ -51,34 +51,33 @@ class Parameter: requires_grad (bool): True if the parameter requires gradient. Default: True. layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode, broadcast and gradients communication would not be applied on parameters. Default: False. - sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty. - has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false. """ - def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, - sparse_grad="", has_indexed_slices_grad=False): + def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): + self._value = ParamValue() self.set_parameter_data(default_input) self.name = name self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel - self.sparse_grad = sparse_grad - self.has_indexed_slices_grad = has_indexed_slices_grad self._is_init = False self._sliced = False - self.clone_info = _CloneInfo() + self.is_param_ps = False if context.get_context("mode") == context.PYNATIVE_MODE: self.init_data() def __repr__(self): format_str = 'Parameter (name={name})' - return format_str.format(name=self._name) + return format_str.format(name=self._value.name) def __parameter__(self): """For parse check.""" + def set_param_ps(self): + self.is_param_ps = True + @property def name(self): """Get the name of the parameter.""" - return self._name + return self._value.name @name.setter def name(self, name_): @@ -100,7 +99,7 @@ class Parameter: format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) else: raise ValueError("The type of the name should be `str` or `None`.") - self._name = name_ + self._value.name = name_ @property def sliced(self): @@ -140,7 +139,9 @@ class Parameter: """ _check_str_by_regular(prefix) x = copy(self) - x.name = prefix + '.' + x.name + # pylint: disable=protected-access + x._value = self._value.clone() + x._value.name = prefix + '.' + self._value.name x.is_init = False if init != 'same': shape = self.default_input.shape @@ -152,57 +153,41 @@ class Parameter: x.init_data() else: x.default_input = initializer(init, shape=shape, dtype=dtype) - - x.clone_info = copy(self.clone_info) - _set_clone_info(self.clone_info, x.clone_info) return x @property def layerwise_parallel(self): - return self._layerwise_parallel + return self._value.layerwise_parallel @layerwise_parallel.setter def layerwise_parallel(self, value=True): if not isinstance(value, bool): raise TypeError("`layerwise_parallel` parameter must be bool type") - self._layerwise_parallel = value + self._value.layerwise_parallel = value @property def requires_grad(self): """Return whether the parameter requires gradient.""" - return self._requires_grad + return self._value.requires_grad @requires_grad.setter def requires_grad(self, value=True): if not isinstance(value, bool): raise TypeError("`requires_grad` parameter must be bool type") - self._requires_grad = value + self._value.requires_grad = value @property - def sparse_grad(self): - """Return whether the parameter's gradient is sparse.""" - return self._sparse_grad - - @sparse_grad.setter - def sparse_grad(self, value=""): - if not isinstance(value, str): - raise TypeError("`sparse_grad` parameter must be str type") - self._sparse_grad = value + def data(self): + return self.default_input @property - def has_indexed_slices_grad(self): - """Return whether the parameter's gradient is indexed_slices.""" - return self._has_indexed_slices_grad - - @has_indexed_slices_grad.setter - def has_indexed_slices_grad(self, value=False): - if not isinstance(value, bool): - raise TypeError("`has_indexed_slices_grad` parameter must be bool type") - self._has_indexed_slices_grad = value + def default_input(self): + return self._data - @property - def data(self): - return self.default_input + @default_input.setter + def default_input(self, data): + self._data = data + self._value.data = data def __add__(self, other): return self.default_input + other @@ -223,11 +208,12 @@ class Parameter: def set_parameter_data(self, data): """Set `default_input` of current `Parameter`.""" + self.init_mode = None if isinstance(data, bool): raise ValueError('Parameter data can not be `bool`') if isinstance(data, Tensor): # make a copy of Tensor to init the parameter - data = Tensor(data.asnumpy().copy()) + data = Tensor(data.asnumpy()) data.init_flag = False elif isinstance(data, Initializer): self.init_mode = data @@ -242,7 +228,6 @@ class Parameter: self.default_input = data - def init_data(self, layout=None, set_sliced=False): """ Init data of the parameter. @@ -256,7 +241,7 @@ class Parameter: set_sliced (bool): True if should set parameter sliced after init the data of initializer. Default: False. """ - if not isinstance(self.default_input, MetaTensor): + if self.init_mode is None: return if layout is not None: if not isinstance(layout, list): diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 043ab4f6cf..64a8eb4637 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -73,7 +73,6 @@ class Tensor(Tensor_): else: Tensor_.__init__(self, input_data, dtype) self._virtual_flag = False - self._init_flag = False def __repr__(self): return str(self.__str__()) @@ -182,6 +181,9 @@ class Tensor(Tensor_): def __imod__(self, other): return self.__mod__(other) + def __pow__(self, other): + return tensor_operator_registry.get('__pow__')(self, other) + def __floordiv__(self, other): return tensor_operator_registry.get('__floordiv__')(self, other) @@ -205,19 +207,6 @@ class Tensor(Tensor_): raise TypeError("virtual_flag must be bool.") self._virtual_flag = value - @property - def init_flag(self): - """whether the tensor is init.""" - return self._init_flag - - @init_flag.setter - def init_flag(self, value): - """Set the tensor is init_flag.""" - if not isinstance(value, bool): - raise TypeError("init_flag must be bool.") - self.set_init_flag(value) - self._init_flag = value - class IndexedSlices: def __init__(self, indices, values, dense_shape): diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py index 508aa2e7a9..5e1f7d06e7 100644 --- a/mindspore/communication/_comm_helper.py +++ b/mindspore/communication/_comm_helper.py @@ -14,7 +14,7 @@ # ============================================================================ """comm_helper""" - +import os from ._hccl_management import load_lib as hccl_load_lib _HCCL_AVAILABLE = False @@ -44,7 +44,7 @@ else: HCCL_WORLD_COMM_GROUP = "hccl_world_group" NCCL_WORLD_COMM_GROUP = "nccl_world_group" - +MS_ROLE = os.getenv("MS_ROLE") class Backend: """ @@ -152,6 +152,9 @@ def _get_rank_helper(group, backend): Integer. The local rank id of the calling process. """ rank_id = None + if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): + rank_id = 0 + return rank_id if backend == Backend.HCCL: if group == HCCL_WORLD_COMM_GROUP: rank_id = hccl.get_rank_id() @@ -211,6 +214,9 @@ def _get_size_helper(group, backend): Integer. The rank size of specified group. """ size = None + if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): + size = 1 + return size if backend == Backend.HCCL: if group == HCCL_WORLD_COMM_GROUP: size = hccl.get_rank_size() diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py index 1cd60fe2e5..3fb4e7b947 100755 --- a/mindspore/communication/management.py +++ b/mindspore/communication/management.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """Communication management API""" +import os from mindspore.parallel._auto_parallel_context import auto_parallel_context from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ @@ -28,6 +29,7 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP DEFAULT_BACKEND = Backend("hccl") +MS_ROLE = os.getenv("MS_ROLE") def _get_group(group): @@ -58,6 +60,8 @@ def init(backend_name="hccl"): TypeError: If backend name is not a string. RuntimeError: If backend is invalid or distributed init fails. """ + if MS_ROLE in ("MS_PSERVER", "MS_SCHED"): + return if not isinstance(backend_name, str): raise TypeError("Backend name must be a string, but got {}".format(type(backend_name))) diff --git a/mindspore/context.py b/mindspore/context.py index b5be6c3213..0de6084caf 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -17,6 +17,7 @@ The context of mindspore, used to configure the current execution environment, including execution mode, execution backend and other feature switches. """ import os +import time import threading from collections import namedtuple from types import FunctionType @@ -55,12 +56,20 @@ def _make_directory(path): os.makedirs(path) real_path = path except PermissionError as e: - logger.error( - f"No write permission on the directory `{path}, error = {e}") + logger.error(f"No write permission on the directory `{path}, error = {e}") raise ValueError(f"No write permission on the directory `{path}`.") return real_path +def _get_print_file_name(file_name): + """Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds).""" + time_second = str(int(time.time())) + file_name = file_name + "." + time_second + if os.path.exists(file_name): + ValueError("This file {} already exists.".format(file_name)) + return file_name + + class _ThreadLocalInfo(threading.local): """ Thread local Info used for store thread local attributes. @@ -209,6 +218,8 @@ class _Context: success = self._context_handle.set_device_target(target) if not success: raise ValueError("Target device name is invalid!!!") + if self.enable_debug_runtime and self.device_target == "CPU": + self.set_backend_policy("vm") @property def device_id(self): @@ -355,14 +366,6 @@ class _Context: def check_bprop(self, check_bprop_flag): self._context_handle.set_check_bprop_flag(check_bprop_flag) - @property - def enable_sparse(self): - return self._context_handle.get_enable_sparse_flag() - - @enable_sparse.setter - def enable_sparse(self, enable_sparse_flag): - self._context_handle.set_enable_sparse_flag(enable_sparse_flag) - @property def max_device_memory(self): return self._context_handle.get_max_device_memory() @@ -381,9 +384,28 @@ class _Context: return None @print_file_path.setter - def print_file_path(self, file): - self._context_handle.set_print_file_path(file) + def print_file_path(self, file_path): + """Add timestamp suffix to file name. Sets print file path.""" + print_file_path = os.path.realpath(file_path) + if os.path.isdir(print_file_path): + raise IOError("Print_file_path should be file path, but got {}.".format(file_path)) + + if os.path.exists(print_file_path): + _path, _file_name = os.path.split(print_file_path) + path = _make_directory(_path) + file_name = _get_print_file_name(_file_name) + full_file_name = os.path.join(path, file_name) + else: + full_file_name = print_file_path + self._context_handle.set_print_file_path(full_file_name) + + @property + def enable_sparse(self): + return self._context_handle.get_enable_sparse() + @enable_sparse.setter + def enable_sparse(self, enable_sparse): + self._context_handle.set_enable_sparse(enable_sparse) def check_input_format(x): import re @@ -575,8 +597,9 @@ def set_context(**kwargs): max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU. The format is "xxGB". Default: "1024GB". print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to - a file by default, and turn off printing to the screen. - enable_sparse (bool): Whether to enable sparse feature. Default: False. + a file by default, and turn off printing to the screen. If the file already exists, add a timestamp + suffix to the file. + enable_sparse (bool): Whether to enable sparsity feature. Default: False. Raises: ValueError: If input key is not an attribute in context. diff --git a/mindspore/core/abstract/CMakeLists.txt b/mindspore/core/abstract/CMakeLists.txt new file mode 100644 index 0000000000..fa331776b3 --- /dev/null +++ b/mindspore/core/abstract/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _ABSTRACT_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_ABSTRACT_ALL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ABSTRACT) +add_library(_mindspore_abstract_obj OBJECT ${_ABSTRACT_ALL_SRC_FILES}) diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc new file mode 100644 index 0000000000..7bef3829a6 --- /dev/null +++ b/mindspore/core/abstract/abstract_value.cc @@ -0,0 +1,1097 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/abstract_value.h" + +#include + +#include "utils/symbolic.h" +#include "abstract/utils.h" + +namespace mindspore { +namespace abstract { +bool AbstractBase::operator==(const AbstractBase &other) const { + if (tid() != other.tid()) { + return false; + } + if (BuildType()->type_id() == kObjectTypeUndeterminedType && + other.BuildType()->type_id() == kObjectTypeUndeterminedType) { + return true; + } + if (value_ == nullptr || other.value_ == nullptr) { + MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: " + << this->ToString() << ", other: " << other.ToString(); + } + + bool value_equal = *value_ == *other.value_; + bool type_equal = *type_ == *other.type_; + bool shape_equal = *shape_ == *other.shape_; + return value_equal && type_equal && shape_equal; +} + +ValuePtr AbstractBase::BuildValue() const { + if (value_ == nullptr) { + return RealBuildValue(); + } + return value_; +} + +AbstractBasePtr AbstractBase::Broaden() const { + AbstractBasePtr clone = Clone(); + clone->set_value(kAnyValue); + return clone; +} + +std::string AbstractBase::ToString() const { + std::ostringstream buffer; + std::string value = std::string("value is null"); + if (value_ != nullptr) { + value = value_->ToString(); + } + MS_EXCEPTION_IF_NULL(type_); + MS_EXCEPTION_IF_NULL(shape_); + buffer << type_name() << "(" + << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")"; + return buffer.str(); +} + +AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden(); } + +AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { + MS_EXCEPTION_IF_NULL(other); + if (*this == *other) { + return shared_from_base(); + } + auto value_self = GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_self); + ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); + TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); + if (res_value == value_self) { + return shared_from_base(); + } + return std::make_shared(res_value, res_type); +} + +AbstractBasePtr AbstractType::Clone() const { + ValuePtr value_self = GetValueTrack(); + if (value_self == nullptr || !value_self->isa()) { + return nullptr; + } + TypePtr type_self = value_self->cast(); + return std::make_shared(type_self->Clone()); +} + +bool AbstractType::operator==(const AbstractBase &other) const { + if (tid() != other.tid()) { + return false; + } + // Have to compare TypePtr with value; + ValuePtr value_self = GetValueTrack(); + ValuePtr value_other = other.GetValueTrack(); + if (value_self == nullptr || value_other == nullptr) { + MS_LOG(EXCEPTION) << "AbstractType value should not be nullptr. this: " << this->ToString() + << ", other: " << other.ToString(); + } + if (!value_self->isa() || !value_other->isa()) { + return false; + } + TypePtr type_self = value_self->cast(); + TypePtr type_other = value_other->cast(); + bool value_equal = *type_self == *type_other; + return value_equal; +} + +std::string AbstractType::ToString() const { + std::ostringstream buffer; + ValuePtr value_self = GetValueTrack(); + if (value_self == nullptr) { + buffer << "AbstractType value: nullptr"; + return buffer.str(); + } + if (!value_self->isa()) { + buffer << type_name() << "(Value: nullptr)"; + return buffer.str(); + } + TypePtr type_self = value_self->cast(); + MS_EXCEPTION_IF_NULL(type_self); + buffer << type_name() << "(" + << "Value: " << type_self->ToString() << ")"; + return buffer.str(); +} + +std::string AbstractError::ToString() const { + std::ostringstream buffer; + auto value_track = GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + buffer << type_name() << "(" + << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")"; + return buffer.str(); +} + +AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) { + MS_EXCEPTION_IF_NULL(other); + auto other_func = dyn_cast(other); + if (other_func == nullptr) { + MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); + } + return Join(other_func); +} + +bool AbstractFunction::operator==(const AbstractBase &other) const { + if (!other.isa()) { + return false; + } + const auto &other_func = static_cast(other); + bool value_equal = (*this == other_func); + return value_equal; +} + +const AbstractBasePtr AbstractSequeue::operator[](const std::size_t &dim) const { + if (dim >= size()) { + MS_LOG(EXCEPTION) << "Index [" << dim << "] Out of the size [" << size() << "] of the list."; + } + return elements_[dim]; +} + +std::string AbstractSequeue::ToString() const { + std::ostringstream buffer; + int i = 0; + for (const auto &ele : elements_) { + MS_EXCEPTION_IF_NULL(ele); + buffer << "element[" << i << "]: " << ele->ToString() << ","; + i++; + } + return buffer.str(); +} + +TypePtrList AbstractSequeue::ElementsType() const { + TypePtrList element_type_list; + for (const auto &ele : elements_) { + MS_EXCEPTION_IF_NULL(ele); + TypePtr element_type = ele->BuildType(); + element_type_list.push_back(element_type); + } + return element_type_list; +} + +BaseShapePtrList AbstractSequeue::ElementsShape() const { + BaseShapePtrList element_shape_list; + for (const auto &ele : elements_) { + MS_EXCEPTION_IF_NULL(ele); + BaseShapePtr element_shape = ele->BuildShape(); + element_shape_list.push_back(element_shape); + } + return element_shape_list; +} + +AbstractBasePtrList AbstractSequeue::ElementsClone() const { + AbstractBasePtrList ele_list; + for (const auto &ele : elements_) { + MS_EXCEPTION_IF_NULL(ele); + AbstractBasePtr clone = ele->Clone(); + ele_list.push_back(clone); + } + return ele_list; +} + +AbstractBasePtrList AbstractSequeue::ElementsBroaden() const { + AbstractBasePtrList ele_list; + for (const auto &ele : elements_) { + MS_EXCEPTION_IF_NULL(ele); + AbstractBasePtr broadend = ele->Broaden(); + ele_list.push_back(broadend); + } + return ele_list; +} + +template +ValuePtr AbstractSequeue::ElementsBuildValue() const { + std::vector element_value_list; + for (const auto &ele : elements_) { + ValuePtr element_value = ele->BuildValue(); + if (element_value->isa()) { + return kAnyValue; + } + element_value_list.push_back(element_value); + } + return std::make_shared(element_value_list); +} +template ValuePtr AbstractSequeue::ElementsBuildValue() const; +template ValuePtr AbstractSequeue::ElementsBuildValue() const; + +template +AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &other) { + auto other_sequeue = dyn_cast(other); + if (other_sequeue == nullptr) { + MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); + } + auto joined_list = AbstractJoin(elements_, other_sequeue->elements_); + bool changes = false; + for (std::size_t i = 0; i < elements_.size(); i++) { + if (elements_[i] != joined_list[i]) { + changes = true; + break; + } + } + if (!changes) { + return shared_from_base(); + } + return std::make_shared(joined_list); +} +template AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &); +template AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &); + +std::size_t AbstractSequeue::hash() const { + std::size_t hash_sum = hash_combine(tid(), std::hash{}(elements_.size())); + // Hashing all elements is costly, so only take at most 4 elements into account based on + // some experiments. + for (size_t i = 0; (i < elements_.size()) && (i < 4); i++) { + hash_sum = hash_combine(hash_sum, elements_[i]->hash()); + } + return hash_sum; +} + +bool AbstractTuple::operator==(const AbstractTuple &other) const { + if (&other == this) { + return true; + } + + if (elements_.size() != other.elements_.size()) { + return false; + } + for (size_t i = 0; i < elements_.size(); i++) { + if (!(*(elements_[i]) == *(other.elements_[i]))) { + return false; + } + } + return true; +} + +bool AbstractTuple::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + + if (other.isa()) { + auto other_tuple = static_cast(&other); + return *this == *other_tuple; + } + + return false; +} + +bool AbstractList::operator==(const AbstractList &other) const { + if (&other == this) { + return true; + } + + if (elements_.size() != other.elements_.size()) { + return false; + } + for (size_t i = 0; i < elements_.size(); i++) { + if (!(*(elements_[i]) == *(other.elements_[i]))) { + return false; + } + } + return true; +} + +bool AbstractList::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + + if (other.isa()) { + auto other_list = static_cast(&other); + return *this == *other_list; + } + return false; +} + +TypePtr AbstractSlice::BuildType() const { + MS_EXCEPTION_IF_NULL(start_); + MS_EXCEPTION_IF_NULL(stop_); + MS_EXCEPTION_IF_NULL(step_); + TypePtr start = start_->BuildType(); + TypePtr stop = stop_->BuildType(); + TypePtr step = step_->BuildType(); + return std::make_shared(start, stop, step); +} + +bool AbstractSlice::operator==(const AbstractSlice &other) const { + if (&other == this) { + return true; + } + return (*start_ == *other.start_ && *stop_ == *other.stop_ && *step_ == *other.step_); +} + +bool AbstractSlice::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + if (!other.isa()) { + return false; + } + auto other_slice = static_cast(&other); + return *this == *other_slice; +} + +AbstractBasePtr AbstractSlice::Clone() const { + MS_EXCEPTION_IF_NULL(start_); + MS_EXCEPTION_IF_NULL(stop_); + MS_EXCEPTION_IF_NULL(step_); + AbstractBasePtr start = start_->Clone(); + AbstractBasePtr stop = stop_->Clone(); + AbstractBasePtr step = step_->Clone(); + return std::make_shared(start, stop, step); +} + +AbstractBasePtr AbstractSlice::Broaden() const { + MS_EXCEPTION_IF_NULL(start_); + MS_EXCEPTION_IF_NULL(stop_); + MS_EXCEPTION_IF_NULL(step_); + AbstractBasePtr start = start_->Broaden(); + AbstractBasePtr stop = stop_->Broaden(); + AbstractBasePtr step = step_->Broaden(); + return std::make_shared(start, stop, step); +} + +std::string AbstractSlice::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "["; + MS_EXCEPTION_IF_NULL(start_); + buffer << start_->ToString() << " : "; + MS_EXCEPTION_IF_NULL(stop_); + buffer << stop_->ToString() << " : "; + MS_EXCEPTION_IF_NULL(step_); + buffer << step_->ToString(); + buffer << "]"; + return buffer.str(); +} + +ValuePtr AbstractSlice::RealBuildValue() const { + MS_EXCEPTION_IF_NULL(start_); + MS_EXCEPTION_IF_NULL(stop_); + MS_EXCEPTION_IF_NULL(step_); + ValuePtr start = start_->BuildValue(); + ValuePtr stop = stop_->BuildValue(); + ValuePtr step = step_->BuildValue(); + if (start->isa() || stop->isa() || step->isa()) { + return kAnyValue; + } + return std::make_shared(start, stop, step); +} + +std::size_t AbstractSlice::hash() const { + MS_EXCEPTION_IF_NULL(start_); + MS_EXCEPTION_IF_NULL(stop_); + MS_EXCEPTION_IF_NULL(step_); + return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); +} + +ShapePtr AbstractUndetermined::shape() const { + auto shp = dyn_cast(GetShapeTrack()); + if (shp == nullptr) { + MS_LOG(EXCEPTION) << "Tensor should have a shape."; + } + return shp; +} + +TypePtr AbstractTensor::BuildType() const { + MS_EXCEPTION_IF_NULL(element_); + TypePtr element_type = element_->BuildType(); + return std::make_shared(element_type); +} + +BaseShapePtr AbstractTensor::BuildShape() const { + auto shape = GetShapeTrack(); + // Guard from using set_shape(nullptr) + if (shape == nullptr) { + return kNoShape; + } + return shape; +} + +AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { + if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) { + auto other_tensor = dyn_cast(other); + auto element = element_->Join(other_tensor->element()); + auto shape = ShapeJoin(this->shape(), other_tensor->shape()); + auto ret = std::make_shared(element, shape); + return ret; + } + auto other_tensor = dyn_cast(other); + if (other_tensor == nullptr) { + MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); + } + if (*this == *other) { + return shared_from_base(); + } + auto element = element_->Join(other_tensor->element_); + auto shape = ShapeJoin(this->shape(), other_tensor->shape()); + return std::make_shared(element, shape); +} + +bool AbstractTensor::operator==(const AbstractTensor &other) const { + if (&other == this) { + return true; + } + + auto v1 = GetValueTrack(); + auto v2 = other.GetValueTrack(); + if (v1 == nullptr || v2 == nullptr) { + MS_LOG(EXCEPTION) << "The value of AbstractTensor is nullptr"; + } + + bool is_value_equal = (v1 == v2); + if (v1->isa() && v2->isa()) { + is_value_equal = true; + } + return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal; +} + +bool AbstractTensor::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + + if (other.isa()) { + auto other_tensor = static_cast(&other); + return *this == *other_tensor; + } else { + return false; + } +} + +AbstractBasePtr AbstractTensor::Clone() const { + MS_EXCEPTION_IF_NULL(element_); + auto clone = std::make_shared(element_->Clone()); + ShapePtr shp = shape(); + clone->set_shape(shp->Clone()); + clone->set_value(GetValueTrack()); + return clone; +} + +AbstractBasePtr AbstractTensor::Broaden() const { + MS_EXCEPTION_IF_NULL(element_); + auto broaden = std::make_shared(element_->Broaden()); + auto shp = shape(); + broaden->set_shape(shp->Clone()); + broaden->set_value(kAnyValue); + return broaden; +} + +AbstractBasePtr AbstractTensor::BroadenWithShape() const { + MS_EXCEPTION_IF_NULL(element_); + auto broaden = std::make_shared(element_->Broaden()); + auto shp = shape()->Clone(); + shp->Broaden(); + broaden->set_shape(shp); + broaden->set_value(kAnyValue); + return broaden; +} + +std::string AbstractTensor::ToString() const { + std::ostringstream buffer; + BaseShapePtr shape_track = GetShapeTrack(); + MS_EXCEPTION_IF_NULL(shape_track); + MS_EXCEPTION_IF_NULL(element_); + auto value_track = GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + buffer << type_name() << "(" + << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() + << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"; + return buffer.str(); +} + +TypePtr AbstractDictionary::BuildType() const { + std::vector> key_values; + for (const auto &item : key_values_) { + MS_EXCEPTION_IF_NULL(item.second); + TypePtr type = item.second->BuildType(); + key_values.emplace_back(item.first, type); + } + return std::make_shared(key_values); +} + +bool AbstractDictionary::operator==(const AbstractDictionary &other) const { + if (key_values_.size() != other.key_values_.size()) { + return false; + } + + for (size_t index = 0; index < key_values_.size(); index++) { + if (key_values_[index].first != other.key_values_[index].first) { + return false; + } + if (!(*key_values_[index].second == *other.key_values_[index].second)) { + return false; + } + } + return true; +} + +bool AbstractDictionary::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + if (other.isa()) { + auto other_class = static_cast(&other); + return *this == *other_class; + } + return false; +} + +AbstractBasePtr AbstractDictionary::Clone() const { + std::vector kv; + (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv), + [](const AbstractAttribute &item) { + MS_EXCEPTION_IF_NULL(item.second); + return std::make_pair(item.first, item.second->Clone()); + }); + return std::make_shared(kv); +} + +AbstractBasePtr AbstractDictionary::Broaden() const { + std::vector kv; + (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv), + [](const AbstractAttribute &item) { + MS_EXCEPTION_IF_NULL(item.second); + return std::make_pair(item.first, item.second->Broaden()); + }); + return std::make_shared(kv); +} + +std::string AbstractDictionary::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "{ "; + for (const auto &kv : key_values_) { + MS_EXCEPTION_IF_NULL(kv.second); + buffer << "(" << kv.first << ": " << kv.second->ToString() << ") "; + } + buffer << "}"; + return buffer.str(); +} + +std::size_t AbstractDictionary::hash() const { + std::size_t hash_sum = std::accumulate(key_values_.begin(), key_values_.end(), tid(), + [](std::size_t hash_sum, const AbstractAttribute &item) { + hash_sum = hash_combine(hash_sum, std::hash()(item.first)); + MS_EXCEPTION_IF_NULL(item.second); + hash_sum = hash_combine(hash_sum, item.second->hash()); + return hash_sum; + }); + return hash_sum; +} + +ValuePtr AbstractDictionary::RealBuildValue() const { + std::vector> key_values; + for (const auto &item : key_values_) { + MS_EXCEPTION_IF_NULL(item.second); + auto element_value = item.second->BuildValue(); + MS_EXCEPTION_IF_NULL(element_value); + if (element_value->isa()) { + return kAnyValue; + } + key_values.emplace_back(item.first, element_value); + } + return std::make_shared(key_values); +} + +TypePtr AbstractClass::BuildType() const { + ClassAttrVector attributes_type; + for (auto attr : attributes_) { + MS_EXCEPTION_IF_NULL(attr.second); + TypePtr type = attr.second->BuildType(); + std::pair elem(attr.first, type); + attributes_type.push_back(elem); + } + + return std::make_shared(tag_, attributes_type, methods_); +} + +bool AbstractClass::operator==(const AbstractClass &other) const { + if (!(tag_ == other.tag_)) { + return false; + } + if (attributes_.size() != other.attributes_.size()) { + return false; + } + for (size_t i = 0; i < attributes_.size(); i++) { + MS_EXCEPTION_IF_NULL(attributes_[i].second); + MS_EXCEPTION_IF_NULL(other.attributes_[i].second); + if (!(*attributes_[i].second == *other.attributes_[i].second)) { + MS_LOG(DEBUG) << "attr " << attributes_[i].first << " not equal, arg1:" << attributes_[i].second->ToString() + << " arg2:" << other.attributes_[i].second->ToString(); + return false; + } + } + // method compare; + if (methods_.size() != other.methods_.size()) { + return false; + } + for (const auto &iter : methods_) { + auto iter_other = other.methods_.find(iter.first); + if (iter_other == other.methods_.end()) { + return false; + } + if (!(*iter.second == *iter_other->second)) { + return false; + } + } + return true; +} + +bool AbstractClass::operator==(const AbstractBase &other) const { + if (other.isa()) { + auto other_class = static_cast(&other); + return *this == *other_class; + } + return false; +} + +AbstractBasePtr AbstractClass::GetAttribute(const std::string &name) { + auto it = std::find_if(attributes_.begin(), attributes_.end(), + [name](const AbstractAttribute &pair) -> bool { return pair.first == name; }); + if (it != attributes_.end()) { + return it->second; + } + return nullptr; +} + +ValuePtr AbstractClass::GetMethod(const std::string &name) { + auto method_pair = methods_.find(name); + if (method_pair != methods_.end()) { + return method_pair->second; + } + return kAnyValue; +} + +AbstractBasePtr AbstractClass::Clone() const { + std::vector attributes_clone; + for (auto attr : attributes_) { + MS_EXCEPTION_IF_NULL(attr.second); + AbstractBasePtr clone = attr.second->Clone(); + AbstractAttribute elem(attr.first, clone); + attributes_clone.push_back(elem); + } + return std::make_shared(tag_, attributes_clone, methods_); +} + +AbstractBasePtr AbstractClass::Broaden() const { + std::vector attributes_clone; + for (auto attr : attributes_) { + MS_EXCEPTION_IF_NULL(attr.second); + AbstractBasePtr clone = attr.second->Broaden(); + AbstractAttribute elem(attr.first, clone); + attributes_clone.push_back(elem); + } + return std::make_shared(tag_, attributes_clone, methods_); +} + +std::string AbstractClass::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "(tag: " << tag_ << ") attrs:("; + bool append_comma = false; + for (const auto &attr : attributes_) { + if (append_comma) { + buffer << ", "; + } else { + append_comma = true; + } + MS_EXCEPTION_IF_NULL(attr.second); + buffer << attr.first << ":" << attr.second->ToString(); + } + buffer << ") method:("; + append_comma = false; + for (const auto &iter : methods_) { + if (append_comma) { + buffer << ", "; + } else { + append_comma = true; + } + MS_EXCEPTION_IF_NULL(iter.second); + buffer << iter.first << ":" << iter.second->ToString(); + } + buffer << ")"; + return buffer.str(); +} + +std::size_t AbstractClass::hash() const { + std::size_t hash_sum = std::accumulate(attributes_.begin(), attributes_.end(), hash_combine(tid(), tag_.hash()), + [](std::size_t hash_sum, const AbstractAttribute &item) { + MS_EXCEPTION_IF_NULL(item.second); + return hash_combine(hash_sum, item.second->hash()); + }); + + return hash_sum; +} + +ValuePtr AbstractClass::RealBuildValue() const { + auto cls = BuildType()->cast(); + std::unordered_map attributes_value_map; + for (const auto &attr : attributes_) { + MS_EXCEPTION_IF_NULL(attr.second); + ValuePtr _value = attr.second->BuildValue(); + if (_value->isa()) { + return kAnyValue; + } + attributes_value_map[attr.first] = _value; + } + cls->set_value(attributes_value_map); + return cls; +} + +TypePtr AbstractJTagged::BuildType() const { + MS_EXCEPTION_IF_NULL(element_); + TypePtr subtype = element_->BuildType(); + return std::make_shared(subtype); +} + +AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) { + auto other_jtagged = dyn_cast(other); + if (other_jtagged == nullptr) { + MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); + } + auto joined_elem = element_->Join(other_jtagged->element_); + return std::make_shared(joined_elem); +} + +bool AbstractJTagged::operator==(const AbstractJTagged &other) const { + MS_EXCEPTION_IF_NULL(element_); + MS_EXCEPTION_IF_NULL(other.element_); + return (*element_ == *other.element_); +} + +bool AbstractJTagged::operator==(const AbstractBase &other) const { + if (other.isa()) { + auto other_jtagged = static_cast(&other); + return *this == *other_jtagged; + } + return false; +} + +std::string AbstractJTagged::ToString() const { + std::ostringstream buffer; + MS_EXCEPTION_IF_NULL(element_); + buffer << type_name() << "(" + << "element: " << element_->ToString() << ")"; + return buffer.str(); +} + +TypePtr AbstractRef::BuildType() const { + TypePtr subtype = ref_->BuildType(); + TypePtr subtype_origin = ref_origin_->BuildType(); + return std::make_shared(subtype, subtype_origin); +} + +bool AbstractRef::operator==(const AbstractRef &other) const { + return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_); +} + +bool AbstractRef::operator==(const AbstractBase &other) const { + if (other.isa()) { + auto other_conf = static_cast(&other); + return *this == *other_conf; + } + return false; +} + +AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { + auto other_ref = other->cast(); + if (other_ref == nullptr) { + auto new_ref = ref_->Join(other); + return std::make_shared(ref_key_, new_ref, ref_origin_); + } + if (*this == *other) { + return shared_from_base(); + } + auto ref_key = ref_key_->Join(other_ref->ref_key_); + auto ref = ref_->Join(other_ref->ref()); + auto ref_origin = ref_origin_->Join(other_ref->ref_origin_); + + return std::make_shared(ref_key, ref, ref_origin); +} + +std::string AbstractRef::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "(" + << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString() + << " origin_value: " << ref_origin_->ToString(); + auto value = GetValueTrack(); + if (value) { + buffer << ", value: " << value->ToString(); + } + buffer << ")"; + return buffer.str(); +} + +bool AbstractNone::operator==(const AbstractNone &) const { return true; } + +bool AbstractNone::operator==(const AbstractBase &other) const { + if (other.isa()) { + auto other_none = static_cast(&other); + return *this == *other_none; + } + return false; +} + +std::string AbstractNone::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "(Value: None)"; + return buffer.str(); +} + +ValuePtr AbstractNone::RealBuildValue() const { return kNone; } + +bool AbstractRefKey::operator==(const AbstractRefKey &other) const { + ValuePtr value_self = GetValueTrack(); + ValuePtr value_other = other.GetValueTrack(); + if (value_self != nullptr && value_other != nullptr) { + if (value_self->isa() && value_other->isa()) { + return true; + } + if (!value_self->isa() || !value_other->isa()) { + return false; + } + RefKeyPtr type_self = value_self->cast(); + RefKeyPtr type_other = value_other->cast(); + return *type_self == *type_other; + } else if (value_self != nullptr || value_other != nullptr) { + return false; + } + return true; +} + +bool AbstractRefKey::operator==(const AbstractBase &other) const { + if (other.isa()) { + auto other_confkey = static_cast(&other); + return *this == *other_confkey; + } else { + return false; + } +} + +std::string AbstractRefKey::ToString() const { + std::ostringstream buffer; + buffer << type_name(); + auto value = GetValueTrack(); + if (value) { + buffer << "(value: " << value->ToString() << ")"; + } + return buffer.str(); +} + +bool AbstractNull::operator==(const AbstractNull &) const { return true; } + +bool AbstractNull::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + if (other.isa()) { + auto other_none = static_cast(&other); + return *this == *other_none; + } else { + return false; + } +} + +std::string AbstractNull::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "(Value: Null)"; + return buffer.str(); +} + +bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; } + +bool AbstractEllipsis::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + if (other.isa()) { + auto other_none = static_cast(&other); + return *this == *other_none; + } else { + return false; + } +} + +std::string AbstractEllipsis::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "(Value: Ellipsis)"; + return buffer.str(); +} + +TypePtr AbstractKeywordArg::BuildType() const { + MS_EXCEPTION_IF_NULL(arg_value_); + TypePtr type = arg_value_->BuildType(); + return std::make_shared(arg_name_, type); +} + +AbstractBasePtr AbstractKeywordArg::Clone() const { + MS_EXCEPTION_IF_NULL(arg_value_); + return std::make_shared(arg_name_, arg_value_->Clone()); +} + +AbstractBasePtr AbstractKeywordArg::Broaden() const { + MS_EXCEPTION_IF_NULL(arg_value_); + return std::make_shared(arg_name_, arg_value_->Broaden()); +} + +std::size_t AbstractKeywordArg::hash() const { + MS_EXCEPTION_IF_NULL(arg_value_); + return hash_combine({tid(), std::hash{}(arg_name_), arg_value_->hash()}); +} + +std::string AbstractKeywordArg::ToString() const { + std::ostringstream buffer; + MS_EXCEPTION_IF_NULL(arg_value_); + buffer << type_name() << "("; + buffer << "key : " << arg_name_; + buffer << "value : " << arg_value_->ToString(); + buffer << ")"; + return buffer.str(); +} + +bool AbstractKeywordArg::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + + if (other.isa()) { + auto other_tuple = static_cast(&other); + return *this == *other_tuple; + } + return false; +} + +bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const { + if (&other == this) { + return true; + } + MS_EXCEPTION_IF_NULL(arg_value_); + MS_EXCEPTION_IF_NULL(other.arg_value_); + return other.arg_name_ == arg_name_ && *other.arg_value_ == *arg_value_; +} + +ValuePtr AbstractKeywordArg::RealBuildValue() const { + MS_EXCEPTION_IF_NULL(arg_value_); + ValuePtr value = arg_value_->BuildValue(); + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + return kAnyValue; + } + return std::make_shared(arg_name_, value); +} + +std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) { + std::size_t hash_value = 0; + // Hashing all elements is costly, so only take at most 4 elements into account based on + // some experiments. + for (size_t i = 0; (i < args_spec_list.size()) && (i < 4); i++) { + MS_EXCEPTION_IF_NULL(args_spec_list[i]); + hash_value = hash_combine(hash_value, args_spec_list[i]->hash()); + } + return hash_value; +} + +bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + std::size_t size = lhs.size(); + for (std::size_t i = 0; i < size; i++) { + MS_EXCEPTION_IF_NULL(lhs[i]); + MS_EXCEPTION_IF_NULL(rhs[i]); + if (lhs[i] == rhs[i]) { + continue; + } + if (!(*lhs[i] == *rhs[i])) { + return false; + } + } + return true; +} + +std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const { + return AbstractBasePtrListHash(args_spec_list); +} + +bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const { + return AbstractBasePtrListDeepEqual(lhs, rhs); +} + +// IndexedSlices +TypePtr AbstractIndexedSlices::BuildType() const { + MS_EXCEPTION_IF_NULL(element()); + TypePtr element_type = element()->BuildType(); + return std::make_shared(element_type); +} + +AbstractBasePtr AbstractIndexedSlices::Clone() const { + MS_EXCEPTION_IF_NULL(element()); + auto clone = std::make_shared(element()->Clone()); + ShapePtr shp = shape(); + clone->set_shape(shp->Clone()); + clone->set_value(GetValueTrack()); + clone->set_indices(indices_->Clone()->cast()); + clone->set_values(values_->Clone()->cast()); + clone->set_dense_shape(dense_shape_->Clone()->cast()); + return clone; +} + +AbstractBasePtr AbstractIndexedSlices::Broaden() const { + MS_EXCEPTION_IF_NULL(element()); + auto broaden = std::make_shared(element()->Broaden()); + auto shp = shape(); + broaden->set_shape(shp->Clone()); + broaden->set_value(kAnyValue); + broaden->set_indices(indices_->Clone()->cast()); + broaden->set_values(values_->Clone()->cast()); + broaden->set_dense_shape(dense_shape_->Clone()->cast()); + return broaden; +} + +AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const { + MS_EXCEPTION_IF_NULL(element()); + auto broaden = std::make_shared(element()->Broaden()); + auto shp = shape()->Clone(); + shp->Broaden(); + broaden->set_shape(shp); + broaden->set_value(kAnyValue); + broaden->set_indices(indices_->Clone()->cast()); + broaden->set_values(values_->Clone()->cast()); + broaden->set_dense_shape(dense_shape_->Clone()->cast()); + return broaden; +} + +std::string AbstractIndexedSlices::ToString() const { + std::ostringstream buffer; + BaseShapePtr shape_track = GetShapeTrack(); + MS_EXCEPTION_IF_NULL(shape_track); + MS_EXCEPTION_IF_NULL(element()); + auto value_track = GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + buffer << type_name() << "(" + << "shape: " << shape_track->ToString() << ", element: " << element()->ToString() + << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")" + << ", indices: " << indices_->ToString() << ", values" << values_->ToString() + << ", dense_shape: " << dense_shape_->ToString(); + return buffer.str(); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h new file mode 100644 index 0000000000..d922f93e70 --- /dev/null +++ b/mindspore/core/abstract/abstract_value.h @@ -0,0 +1,626 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_ABSTRACT_ABSTRACT_VALUE_H_ +#define MINDSPORE_CCSRC_ABSTRACT_ABSTRACT_VALUE_H_ + +#include +#include +#include +#include +#include + +#include "utils/log_adapter.h" +#include "utils/hashing.h" +#include "base/base.h" +#include "ir/dtype.h" +#include "ir/value.h" +#include "ir/tensor.h" +#include "abstract/dshape.h" + +namespace mindspore { +namespace abstract { +class AbstractBase; +using AbstractBasePtrList = std::vector; + +// The base class for abstract value. The abstract value is used in evaluating +// to express the type, shape, and value of the real value. +class AbstractBase : public Base { + public: + explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, + const BaseShapePtr &shape = kNoShape) + : value_(value), type_(type), shape_(shape) {} + ~AbstractBase() override = default; + MS_DECLARE_PARENT(AbstractBase, Base) + + std::size_t hash() const override { return tid(); } + std::string ToString() const override; + + virtual bool operator==(const AbstractBase &other) const; + void set_value(const ValuePtr &value) { value_ = value; } + void set_type(const TypePtr &type) { type_ = type; } + void set_shape(const BaseShapePtr &shape) { shape_ = shape; } + void set_value_desc(const std::string &desc) { value_desc_ = desc; } + const std::string &value_desc() const { return value_desc_; } + ValuePtr GetValueTrack() const { return value_; } + TypePtr GetTypeTrack() const { return type_; } + BaseShapePtr GetShapeTrack() const { return shape_; } + + // Try build a real value from an abstract value. If the value cannot be built, + // a default value (AnyValue) is returned. + ValuePtr BuildValue() const; + + virtual TypePtr BuildType() const = 0; + virtual BaseShapePtr BuildShape() const { return kNoShape; } + virtual AbstractBasePtr Clone() const = 0; + virtual AbstractBasePtr Broaden() const; + virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base(); } + + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &a) { + os << a->ToString(); + return os; + } + + protected: + // default implementation, it can be overwritten by subclass; + virtual ValuePtr RealBuildValue() const { return kAnyValue; } + + private: + ValuePtr value_; + TypePtr type_; + BaseShapePtr shape_; + std::string value_desc_; // store initial value description for error report +}; + +class AbstractScalar : public AbstractBase { + public: + AbstractScalar() : AbstractBase(kAnyValue, kAnyType) {} + explicit AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {} + explicit AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {} + explicit AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {} + explicit AbstractScalar(float value) : AbstractBase(MakeValue(value), kFloat32) {} + explicit AbstractScalar(double value) : AbstractBase(MakeValue(value), kFloat64) {} + explicit AbstractScalar(bool value) : AbstractBase(MakeValue(value), kBool) {} + explicit AbstractScalar(const std::string &value) : AbstractBase(MakeValue(value), kString) {} + explicit AbstractScalar(const TypePtr &type) : AbstractBase(kAnyValue, type) {} + ~AbstractScalar() override = default; + MS_DECLARE_PARENT(AbstractScalar, AbstractBase) + + std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); } + + TypePtr BuildType() const override { return GetTypeTrack(); } + AbstractBasePtr Clone() const override { + return std::make_shared(GetValueTrack(), GetTypeTrack()->Clone()); + } + AbstractBasePtr Broaden() const override; + AbstractBasePtr Join(const AbstractBasePtr &other) override; +}; +using AbstractScalarPtr = std::shared_ptr; + +class AbstractType : public AbstractBase { + public: + explicit AbstractType(const TypePtr &type) : AbstractBase(type, kTypeType) { + if (type == nullptr) { + MS_LOG(EXCEPTION) << "type is nullptr"; + } + } + ~AbstractType() override = default; + MS_DECLARE_PARENT(AbstractType, AbstractBase) + + std::string ToString() const override; + bool operator==(const AbstractBase &other) const override; + + TypePtr BuildType() const override { return std::make_shared(); } + AbstractBasePtr Clone() const override; + AbstractBasePtr Broaden() const override { return Clone(); } +}; +using AbstractTypePtr = std::shared_ptr; + +class AbstractError : public AbstractBase { + public: + explicit AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) { + if (err == nullptr || node == nullptr) { + MS_LOG(EXCEPTION) << "err or node is nullptr"; + } + } + ~AbstractError() override = default; + MS_DECLARE_PARENT(AbstractError, AbstractBase) + + TypePtr BuildType() const override { return std::make_shared(); } + AbstractBasePtr Broaden() const override { return Clone(); } + + AbstractBasePtr Clone() const override { + return std::make_shared(GetValueTrack()->cast(), node_); + } + + std::string ToString() const override; + + private: + // Origin node been specialized to AbstractError, for debug purpose only. + const AnfNodePtr node_; +}; + +class Evaluator; +using EvaluatorPtr = std::shared_ptr; +class AnalysisEngine; +using AnalysisEnginePtr = std::shared_ptr; + +class AbstractFunction; +using AbstractFunctionPtr = std::shared_ptr; +class AbstractFuncAtom; +using AbstractFuncAtomPtr = std::shared_ptr; +using AbstractFuncAtomPtrList = std::vector; + +class AbstractFunction : public AbstractBase { + public: + AbstractFunction() = default; + ~AbstractFunction() override = default; + MS_DECLARE_PARENT(AbstractFunction, AbstractBase) + + // If there is exactly one possible function, return it. Otherwise, raise an Exception. + // Caller should ensure the uniqueness. + virtual AbstractFunctionPtr GetUnique() = 0; + + TypePtr BuildType() const override { return std::make_shared(); } + AbstractBasePtr Clone() const override { return Copy(); } + // For Function, no need to broaden. + AbstractBasePtr Broaden() const override { + return const_cast(this)->shared_from_base(); + } + virtual AbstractFunctionPtr Copy() const = 0; + + AbstractBasePtr Join(const AbstractBasePtr &other) final; + virtual AbstractFunctionPtr Join(const AbstractFunctionPtr &other) = 0; + + virtual void Visit(std::function) const = 0; + bool operator==(const AbstractBase &other) const final; + virtual bool operator==(const AbstractFunction &other) const = 0; + + static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); + + virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0; + virtual AnfNodePtr tracking_id() const { return nullptr; } + virtual void set_tracking_id(AnfNodePtr) {} + virtual AnalysisContextPtr context() const { return nullptr; } +}; +using AbstractFunctionPtrList = std::vector; + +// Represents a key-value pair used in function's parameters. +class AbstractKeywordArg : public AbstractBase { + public: + AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument) : arg_name_(key), arg_value_(argument) {} + ~AbstractKeywordArg() override = default; + MS_DECLARE_PARENT(AbstractKeywordArg, AbstractBase) + + TypePtr BuildType() const override; + AbstractBasePtr Clone() const override; + AbstractBasePtr Broaden() const override; + std::size_t hash() const override; + + bool operator==(const AbstractKeywordArg &other) const; + bool operator==(const AbstractBase &other) const override; + std::string get_key() const { return arg_name_; } + AbstractBasePtr get_arg() const { return arg_value_; } + + std::string ToString() const override; + + protected: + ValuePtr RealBuildValue() const override; + + private: + std::string arg_name_; + AbstractBasePtr arg_value_; +}; +using AbstractKeywordArgPtr = std::shared_ptr; + +class AbstractUndetermined : public AbstractBase { + public: + // shape and type are all unknown + AbstractUndetermined() : AbstractBase(kAnyValue) {} + // only element_ and value, shape track are valid member, type track are unknown. + explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) + : AbstractBase(kAnyValue), element_(element) { + if (element == nullptr) { + MS_LOG(EXCEPTION) << "element is nullptr"; + } + if (element->isa()) { + MS_LOG(EXCEPTION) << "element type error"; + } + set_shape(shape); + } + AbstractUndetermined(const TypePtr &element_type, const std::vector &shape) + : AbstractBase(kAnyValue), element_(std::make_shared(kAnyValue, element_type)) { + if (element_type == nullptr) { + MS_LOG(EXCEPTION) << "element_type is nullptr"; + } + set_shape(std::make_shared(shape)); + } + ~AbstractUndetermined() override = default; + MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase) + TypePtr BuildType() const override { return std::make_shared(); } + AbstractBasePtr Clone() const override { return std::make_shared(); } + const AbstractBasePtr element() const { return element_; } + ShapePtr shape() const; + + protected: + AbstractBasePtr element_; +}; + +class AbstractTensor : public AbstractUndetermined { + public: + // only element_ and value, shape track are valid member, type track are unknown. + explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) + : AbstractUndetermined(element, shape) {} + AbstractTensor(const TypePtr &element_type, const std::vector &shape) + : AbstractUndetermined(element_type, shape) {} + explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {} + ~AbstractTensor() override = default; + MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined) + + TypePtr BuildType() const override; + BaseShapePtr BuildShape() const override; + AbstractBasePtr Clone() const override; + AbstractBasePtr Broaden() const override; + AbstractBasePtr BroadenWithShape() const; + AbstractBasePtr Join(const AbstractBasePtr &other) final; + + bool operator==(const AbstractTensor &other) const; + bool operator==(const AbstractBase &other) const override; + + std::string ToString() const override; + std::size_t hash() const override { + auto value = GetValueTrack(); + auto hash_sum = hash_combine(tid(), element_->hash()); + if (value != nullptr) { + auto tensor = value->cast(); + if (tensor != nullptr) { + hash_sum = hash_combine(hash_sum, IntToSize(tensor->DataSize())); + } + } + return hash_sum; + } +}; +using AbstractTensorPtr = std::shared_ptr; +using AbstractTensorPtrList = std::vector; + +class AbstractSequeue : public AbstractBase { + public: + explicit AbstractSequeue(const AbstractBasePtrList &elements) : elements_(elements) {} + ~AbstractSequeue() override = default; + MS_DECLARE_PARENT(AbstractSequeue, AbstractBase) + + TypePtrList ElementsType() const; + BaseShapePtrList ElementsShape() const; + AbstractBasePtrList ElementsClone() const; + AbstractBasePtrList ElementsBroaden() const; + + template + ValuePtr ElementsBuildValue() const; + + template + AbstractBasePtr ElementsJoin(const AbstractBasePtr &other); + + std::size_t size() const { return elements_.size(); } + const AbstractBasePtrList &elements() const { return elements_; } + + std::size_t hash() const override; + std::string ToString() const override; + const AbstractBasePtr operator[](const std::size_t &dim) const; + + protected: + AbstractBasePtrList elements_; +}; +using AbstractSequeuePtr = std::shared_ptr; + +class AbstractTuple : public AbstractSequeue { + public: + explicit AbstractTuple(const AbstractBasePtrList &elements) : AbstractSequeue(elements) {} + + ~AbstractTuple() override = default; + MS_DECLARE_PARENT(AbstractTuple, AbstractSequeue) + + TypePtr BuildType() const override { return std::make_shared(ElementsType()); } + + BaseShapePtr BuildShape() const override { return std::make_shared(ElementsShape()); } + + AbstractBasePtr Clone() const override { return std::make_shared(ElementsClone()); } + + AbstractBasePtr Broaden() const override { return std::make_shared(ElementsBroaden()); } + + AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin(other); } + + std::string ToString() const override { return type_name() + "(" + AbstractSequeue::ToString() + ")"; } + + bool operator==(const AbstractTuple &other) const; + bool operator==(const AbstractBase &other) const override; + + protected: + ValuePtr RealBuildValue() const override { return ElementsBuildValue(); } +}; +using AbstractTuplePtr = std::shared_ptr; + +class AbstractList : public AbstractSequeue { + public: + explicit AbstractList(const AbstractBasePtrList &elements) : AbstractSequeue(elements) {} + + ~AbstractList() override = default; + MS_DECLARE_PARENT(AbstractList, AbstractSequeue) + + TypePtr BuildType() const override { return std::make_shared(ElementsType()); } + + BaseShapePtr BuildShape() const override { return std::make_shared(ElementsShape()); } + + AbstractBasePtr Clone() const override { return std::make_shared(ElementsClone()); } + + AbstractBasePtr Broaden() const override { return std::make_shared(ElementsBroaden()); } + + AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin(other); } + + std::string ToString() const override { return type_name() + "[" + AbstractSequeue::ToString() + "]"; } + + bool operator==(const AbstractList &other) const; + bool operator==(const AbstractBase &other) const override; + + protected: + ValuePtr RealBuildValue() const override { return ElementsBuildValue(); } +}; +using AbstractListPtr = std::shared_ptr; + +class AbstractClass : public AbstractBase { + public: + AbstractClass(const Named &tag, const std::vector &attributes, + const std::unordered_map &methods) + : attributes_(attributes), tag_(tag), methods_(methods) {} + + ~AbstractClass() override = default; + MS_DECLARE_PARENT(AbstractClass, AbstractBase) + + TypePtr BuildType() const override; + bool operator==(const AbstractClass &other) const; + bool operator==(const AbstractBase &other) const override; + const std::vector &attributes() const { return attributes_; } + std::unordered_map methods() { return methods_; } + AbstractBasePtr GetAttribute(const std::string &name); + ValuePtr GetMethod(const std::string &name); + AbstractBasePtr Clone() const override; + AbstractBasePtr Broaden() const override; + std::string ToString() const override; + Named tag() const { return tag_; } + std::size_t hash() const override; + + protected: + ValuePtr RealBuildValue() const override; + + private: + std::vector attributes_; + Named tag_; + std::unordered_map methods_; +}; +using AbstractClassPtr = std::shared_ptr; + +class AbstractDictionary : public AbstractBase { + public: + explicit AbstractDictionary(const std::vector &key_values) : key_values_(key_values) {} + ~AbstractDictionary() override = default; + MS_DECLARE_PARENT(AbstractDictionary, AbstractBase) + + TypePtr BuildType() const override; + bool operator==(const AbstractDictionary &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override; + AbstractBasePtr Broaden() const override; + std::string ToString() const override; + std::size_t hash() const override; + std::size_t size() const { return key_values_.size(); } + const std::vector &elements() const { return key_values_; } + + std::vector key_values_; + + protected: + ValuePtr RealBuildValue() const override; +}; +using AbstractDictionaryPtr = std::shared_ptr; + +class AbstractSlice : public AbstractBase { + public: + AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step) + : start_(start), stop_(stop), step_(step) {} + ~AbstractSlice() override = default; + MS_DECLARE_PARENT(AbstractSlice, AbstractBase) + + TypePtr BuildType() const override; + bool operator==(const AbstractSlice &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override; + AbstractBasePtr Broaden() const override; + std::string ToString() const override; + std::size_t hash() const override; + AbstractBasePtr start() const { return start_; } + AbstractBasePtr stop() const { return stop_; } + AbstractBasePtr step() const { return step_; } + + protected: + ValuePtr RealBuildValue() const override; + + private: + AbstractBasePtr start_; + AbstractBasePtr stop_; + AbstractBasePtr step_; +}; +using AbstractSlicePtr = std::shared_ptr; + +class AbstractJTagged : public AbstractBase { + public: + explicit AbstractJTagged(const AbstractBasePtr &element) : element_(element) {} + + ~AbstractJTagged() override = default; + MS_DECLARE_PARENT(AbstractJTagged, AbstractBase) + + TypePtr BuildType() const override; + AbstractBasePtr Clone() const override { return std::make_shared(element_->Clone()); } + AbstractBasePtr Broaden() const override { return std::make_shared(element_->Broaden()); } + AbstractBasePtr Join(const AbstractBasePtr &other) override; + + bool operator==(const AbstractJTagged &other) const; + bool operator==(const AbstractBase &other) const override; + std::string ToString() const override; + AbstractBasePtr element() { return element_; } + std::size_t hash() const override { return hash_combine(tid(), element_->hash()); } + + private: + AbstractBasePtr element_; +}; +using AbstractJTaggedPtr = std::shared_ptr; + +class AbstractNone : public AbstractBase { + public: + AbstractNone() : AbstractBase() { set_type(std::make_shared()); } + ~AbstractNone() override = default; + MS_DECLARE_PARENT(AbstractNone, AbstractBase) + + TypePtr BuildType() const override { return std::make_shared(); } + bool operator==(const AbstractNone &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override { return std::make_shared(); } + std::string ToString() const override; + + protected: + ValuePtr RealBuildValue() const override; +}; +using AbstractNonePtr = std::shared_ptr; + +// the un assigned state value for variable, which means the variable is not assigned +class AbstractNull : public AbstractBase { + public: + AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared()); } + ~AbstractNull() override = default; + MS_DECLARE_PARENT(AbstractNull, AbstractBase) + + TypePtr BuildType() const override { return std::make_shared(); } + bool operator==(const AbstractNull &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override { return std::make_shared(); } + std::string ToString() const override; +}; +using AbstractNullPtr = std::shared_ptr; + +class AbstractEllipsis : public AbstractBase { + public: + AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared()); } + ~AbstractEllipsis() override = default; + MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) + + TypePtr BuildType() const override { return std::make_shared(); } + bool operator==(const AbstractEllipsis &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override { return std::make_shared(); } + std::string ToString() const override; +}; +using AbstractEllipsisPtr = std::shared_ptr; + +class AbstractRefKey : public AbstractBase { + public: + AbstractRefKey() : AbstractBase() { set_type(std::make_shared()); } + ~AbstractRefKey() override = default; + MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) + + TypePtr BuildType() const override { return std::make_shared(); } + bool operator==(const AbstractRefKey &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override { return std::make_shared(); } + std::string ToString() const override; +}; +using AbstractRefKeyPtr = std::shared_ptr; + +class AbstractRef : public AbstractBase { + public: + AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin) + : ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) { + set_type(std::make_shared()); + } + + ~AbstractRef() override = default; + MS_DECLARE_PARENT(AbstractRef, AbstractBase) + + TypePtr BuildType() const override; + bool operator==(const AbstractRef &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override { + return std::make_shared(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone()); + } + std::string ToString() const override; + AbstractBasePtr ref() { return ref_; } + AbstractBasePtr ref_origin() { return ref_origin_; } + AbstractBasePtr ref_key() { return ref_key_; } + AbstractBasePtr Broaden() const override { + return std::make_shared(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden()); + } + AbstractBasePtr Join(const AbstractBasePtr &other) override; + std::size_t hash() const override { + return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash{}(this->tid()) << 1); + } + + private: + AbstractBasePtr ref_key_; + AbstractBasePtr ref_; + AbstractBasePtr ref_origin_; +}; +using AbstractRefPtr = std::shared_ptr; + +struct AbstractBasePtrListHasher { + std::size_t operator()(const AbstractBasePtrList &args_spec_list) const; +}; + +struct AbstractBasePtrListEqual { + bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const; +}; + +std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); +bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); + +// IndexedSlices +class AbstractIndexedSlices : public AbstractUndetermined { + public: + explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) + : AbstractUndetermined(element, shape) {} + AbstractIndexedSlices(const TypePtr &element_type, const std::vector &shape) + : AbstractUndetermined(element_type, shape) {} + ~AbstractIndexedSlices() override = default; + MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined) + + const AbstractTensorPtr indices() const { return indices_; } + const AbstractTensorPtr values() const { return values_; } + const AbstractTuplePtr dense_shape() const { return dense_shape_; } + void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } + void set_values(const AbstractTensorPtr &values) { values_ = values; } + void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } + TypePtr BuildType() const override; + AbstractBasePtr Clone() const override; + AbstractBasePtr Broaden() const override; + AbstractBasePtr BroadenWithShape() const; + + std::string ToString() const override; + + private: + AbstractTensorPtr indices_; + AbstractTensorPtr values_; + AbstractTuplePtr dense_shape_; +}; +} // namespace abstract +} // namespace mindspore +#endif // MINDSPORE_CCSRC_ABSTRACT_ABSTRACT_VALUE_H_ diff --git a/mindspore/core/abstract/analysis_context.cc b/mindspore/core/abstract/analysis_context.cc new file mode 100644 index 0000000000..1ae6125838 --- /dev/null +++ b/mindspore/core/abstract/analysis_context.cc @@ -0,0 +1,216 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/analysis_context.h" + +#include + +#include "utils/symbolic.h" +#include "debug/trace.h" + +namespace mindspore { +namespace abstract { +AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, + const AbstractBasePtrList &args_spec_list) { + auto children_context_map_iter = parent->children_cache_.find(fg); + if (children_context_map_iter != parent->children_cache_.end()) { + auto children_context_map = children_context_map_iter->second; + auto children_context_iter = children_context_map.find(args_spec_list); + if (children_context_iter != children_context_map.end()) { + return children_context_iter->second.lock(); + } + } + AnalysisContextPtr context_new = std::make_shared(parent, fg, args_spec_list); + // Reference to myself, so use weak_ptr to break reference cycle. + auto weak_context = std::weak_ptr(context_new); + context_new->parent_cache_[fg] = weak_context; + parent->children_cache_[fg][args_spec_list] = weak_context; + return context_new; +} + +AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph, + const AbstractBasePtrList &args_spec_list) { + FuncGraphPtr graph_parent = func_graph->parent(); + auto iter = parent_cache_.find(graph_parent); + AnalysisContextPtr parent_context = nullptr; + if (iter != parent_cache_.end()) { + parent_context = iter->second.lock(); + } + // if this happen, it will be bug in code. but we raise exception to keep the scene. + if (parent_context == nullptr) { + std::ostringstream oss; + oss << "BUG: cannot found parent_context in current context: " << this->ToString() + << ", func_graph: " << func_graph->ToString() << ", graph_parent: "; + if (graph_parent != nullptr) { + oss << graph_parent->ToString(); + } else { + oss << "nullptr"; + } + MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); + } + return NewContext(parent_context, func_graph, args_spec_list); +} + +AnalysisContextPtr AnalysisContext::Filter(const FuncGraphPtr &func_graph) { + auto p_iter = parent_cache_.find(func_graph); + AnalysisContextPtr parent_context = nullptr; + if (p_iter != parent_cache_.end()) { + parent_context = p_iter->second.lock(); + } else { + auto iter_parent = parent_cache_.find(func_graph->parent()); + if (iter_parent != parent_cache_.end()) { + parent_context = iter_parent->second.lock(); + } + } + // if this happen, it will be bug in code. but we raise exception to keep the scene. + if (parent_context == nullptr) { + std::ostringstream oss; + oss << "BUG: Filter graph failed: " << func_graph->ToString() << ", graph_parent: "; + if (func_graph->parent() != nullptr) { + oss << func_graph->parent()->ToString(); + } else { + oss << "nullptr"; + } + oss << " parent_cache_: {"; + for (auto iter : parent_cache_) { + if (iter.first == nullptr) { + oss << " [graph: nullptr"; + } else { + oss << " [graph: " << iter.first->ToString(); + } + // iter.second cannot be nullptr even iter.first is nullptr as it will + // always be a Context() object. + oss << ", context: " << iter.second.lock()->ToString() << "]"; + } + oss << "}"; + MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); + } + return parent_context; +} + +AnalysisContextPtr AnalysisContext::DummyContext() { + AnalysisContextPtr dummy_context = std::make_shared(nullptr, nullptr, AbstractBasePtrList()); + dummy_context->parent_cache_[nullptr] = std::weak_ptr(dummy_context); + return dummy_context; +} + +bool AnalysisContext::IsDummyContext() { + if (parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty()) { + return true; + } + return false; +} + +const AnalysisContextPtr kDummyAnalysisContext = + std::make_shared(nullptr, nullptr, AbstractBasePtrList()); + +bool AnalysisContext::operator==(const AnalysisContext &other) const { + if (func_graph_ != other.func_graph_) { + return false; + } + + if (args_spec_list_.size() != other.args_spec_list_.size()) { + return false; + } + + if (((parent_ == nullptr) && (other.parent_ != nullptr)) || ((parent_ != nullptr) && (other.parent_ == nullptr))) { + return false; + } + // Compare parent with content. + bool is_parent_equal = false; + if (parent_ == other.parent_) { + is_parent_equal = true; + } else if (*parent_ == *other.parent_) { + is_parent_equal = true; + } else { + return false; + } + for (std::size_t i = 0; i < args_spec_list_.size(); i++) { + if (!(*args_spec_list_[i] == *other.args_spec_list_[i])) { + return false; + } + } + return is_parent_equal; +} + +// brief The key which controls the graph cloning in Specialize. +// +// Originally, specialize use context directly as the key for cloning graph. The graph will be cloned multiple times +// for different context, which means the graph is called from different node with different arguments and different +// free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what +// graph can be reused. +// The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined +// and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in eval, thus the reused +// graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize. +// The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies +// on correct shape to specialize a tensor constant. +AnalysisContextPtr AnalysisContext::SpecializeKey() const { + AbstractBasePtrList args_broad_shp; + (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(args_broad_shp), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { + if (arg->isa()) { + auto val = arg->GetValueTrack(); + if (val->isa()) { + auto scalar_spec = dyn_cast(arg); + auto ret_spec = scalar_spec->Broaden(); + return ret_spec; + } + } + if (arg->isa()) { + MS_LOG(DEBUG) << "refkey broaden"; + auto arg_spec = dyn_cast(arg); + auto ret_spec = arg_spec->Broaden(); + return ret_spec; + } + return arg; + }); + AnalysisContextPtr context_new = std::make_shared(nullptr, func_graph_, args_broad_shp); + context_new->parent_ = parent_; + return context_new; +} + +std::size_t AnalysisContext::hash() { + std::size_t hash_value = 0; + // hash() recursion exit condition. + if (parent_ != nullptr) { + hash_value = hash_combine(hash_value, parent_->hash()); + } + if (func_graph_ != nullptr) { + hash_value = hash_combine(hash_value, func_graph_->hash()); + } + return hash_value; +} + +std::string AnalysisContext::ToString() const { + std::ostringstream buffer; + buffer << "{"; + if (func_graph_ != nullptr) { + buffer << "Func Graph: " << func_graph_->ToString(); + } + buffer << " Args: "; + int i = 0; + for (const auto &arg : args_spec_list_) { + buffer << "[" << i << "]: " << arg->ToString() << ", "; + i++; + } + if (parent_ != nullptr) { + buffer << "Parent: " << parent_->ToString(); + } + buffer << "}"; + return buffer.str(); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/core/abstract/analysis_context.h b/mindspore/core/abstract/analysis_context.h new file mode 100644 index 0000000000..c0293d7e91 --- /dev/null +++ b/mindspore/core/abstract/analysis_context.h @@ -0,0 +1,88 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_ABSTRACT_ANALYSIS_CONTEXT_H_ +#define MINDSPORE_CCSRC_ABSTRACT_ANALYSIS_CONTEXT_H_ + +#include +#include +#include + +#include "abstract/abstract_value.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +namespace abstract { +class AnalysisContext; +using AnalysisContextWeakPtr = std::weak_ptr; +using ArgsSpecToAnalysisContextMap = + std::unordered_map; + +// AnalysisContext will be stored in Config in AnalysisCache. +class AnalysisContext { + public: + AnalysisContext(const AnalysisContextPtr &parent, const FuncGraphPtr &fg, const AbstractBasePtrList &args_spec_list) + : parent_(parent), func_graph_(fg), args_spec_list_(args_spec_list) { + if (parent_ != nullptr) { + parent_cache_ = parent_->parent_cache_; + } + } + + ~AnalysisContext() = default; + + // Helper function to wrapper constructor to save shared_ptr in parent_cache. + AnalysisContextPtr NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, const AbstractBasePtrList &args_spec_list); + + // Extend this context with values for another graph. + AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); + + // Return a context restricted to a graph's dependencies. + AnalysisContextPtr Filter(const FuncGraphPtr &graph); + bool operator==(const AnalysisContext &other) const; + std::size_t hash(); + static AnalysisContextPtr DummyContext(); + bool IsDummyContext(); + FuncGraphPtr func_graph() const { return func_graph_; } + AnalysisContextPtr parent() const { return parent_; } + std::string ToString() const; + AnalysisContextPtr SpecializeKey() const; + AbstractBasePtrList args_spec_list() { return args_spec_list_; } + + private: + AnalysisContextPtr parent_; + FuncGraphPtr func_graph_; + AbstractBasePtrList args_spec_list_; + std::unordered_map parent_cache_; + std::unordered_map children_cache_; +}; + +struct ContextHasher { + std::size_t operator()(const AnalysisContextPtr &t) const { + std::size_t hash = t->hash(); + return hash; + } +}; + +struct ContextEqual { + bool operator()(const AnalysisContextPtr &lhs, const AnalysisContextPtr &rhs) const { return *lhs == *rhs; } +}; + +extern const AnalysisContextPtr kDummyAnalysisContext; +} // namespace abstract +} // namespace mindspore +#endif // MINDSPORE_CCSRC_ABSTRACT_ANALYSIS_CONTEXT_H_ diff --git a/mindspore/core/abstract/dshape.cc b/mindspore/core/abstract/dshape.cc new file mode 100644 index 0000000000..74ea1ff7bf --- /dev/null +++ b/mindspore/core/abstract/dshape.cc @@ -0,0 +1,134 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/dshape.h" + +#include +#include + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace abstract { +// used for print BaseShape content +std::ostream &operator<<(std::ostream &os, const BaseShape &bs) { + os << bs.ToString(); + return os; +} + +std::ostream &operator<<(std::ostream &os, const std::shared_ptr bs) { + MS_EXCEPTION_IF_NULL(bs); + os << bs->ToString(); + return os; +} + +bool BaseShape::operator==(const BaseShape &other) const { + if (tid() != other.tid()) { + return false; + } + return true; +} + +bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == other); } + +std::string Shape::ToString() const { + std::ostringstream buffer; + bool f_begin = true; + buffer << "("; + for (auto &x : shape_) { + if (!f_begin) { + buffer << ", "; + } else { + f_begin = false; + } + buffer << x; + } + buffer << ")"; + return buffer.str(); +} + +std::string Shape::DumpText() const { + std::ostringstream buffer; + buffer << "["; + for (size_t i = 0; i < shape_.size(); i++) { + buffer << (i > 0 ? ", " : "") << shape_[i]; + } + buffer << "]"; + return buffer.str(); +} + +bool Shape::operator==(const BaseShape &other) const { + if (tid() != other.tid()) { + return false; + } + return shape_ == static_cast(other).shape_; +} + +const int Shape::SHP_ANY; +void Shape::Broaden() { + for (size_t i = 0; i < shape_.size(); i++) { + shape_[i] = SHP_ANY; + } +} + +std::string SequeueShape::ToString() const { + std::ostringstream buffer; + bool f_begin = true; + for (auto p_shp : p_shapes_) { + if (!f_begin) { + buffer << ", "; + } else { + f_begin = false; + } + MS_EXCEPTION_IF_NULL(p_shp); + buffer << p_shp->ToString(); + } + return buffer.str(); +} + +BaseShapePtrList SequeueShape::ElementsClone() const { + BaseShapePtrList ele_list; + for (auto p_shp : p_shapes_) { + MS_EXCEPTION_IF_NULL(p_shp); + ele_list.push_back(p_shp->Clone()); + } + return ele_list; +} + +template +bool SequeueShape::SequeueEqual(const BaseShape &other) const { + if (tid() != other.tid()) { + return false; + } + auto other_shapes = static_cast(other).p_shapes_; + if (other_shapes.size() != p_shapes_.size()) { + return false; + } + for (unsigned int i = 0; i < p_shapes_.size(); ++i) { + if (!(*p_shapes_[i] == *other_shapes[i])) { + return false; + } + } + return true; +} +template bool SequeueShape::SequeueEqual(const BaseShape &) const; +template bool SequeueShape::SequeueEqual(const BaseShape &) const; + +const std::shared_ptr kNoShape = std::make_shared(); +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/core/abstract/dshape.h b/mindspore/core/abstract/dshape.h new file mode 100644 index 0000000000..b9b8e93292 --- /dev/null +++ b/mindspore/core/abstract/dshape.h @@ -0,0 +1,135 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_ABSTRACT_DSHAPE_H_ +#define MINDSPORE_CCSRC_ABSTRACT_DSHAPE_H_ + +#include +#include +#include +#include +#include +#include + +#include "utils/log_adapter.h" +#include "base/base.h" + +namespace mindspore { +namespace abstract { +class BaseShape; +using BaseShapePtr = std::shared_ptr; +using BaseShapePtrList = std::vector; + +class BaseShape : public Base { + public: + BaseShape() = default; + ~BaseShape() override = default; + + MS_DECLARE_PARENT(BaseShape, Base) + virtual bool operator==(const BaseShape &other) const; + bool operator!=(const BaseShape &other) const; + std::size_t hash() const override { return tid(); } + + // return a deep copy + virtual BaseShapePtr Clone() const = 0; + virtual void Broaden() {} +}; + +class NoShape : public BaseShape { + public: + MS_DECLARE_PARENT(NoShape, BaseShape) + BaseShapePtr Clone() const override { return std::make_shared(); } + std::string ToString() const override { return type_name(); } +}; +extern const std::shared_ptr kNoShape; + +class Shape : public BaseShape { + public: + static const int SHP_ANY = -1; + Shape() : shape_() {} + Shape(const std::initializer_list &list) : shape_(list) {} + explicit Shape(const std::vector &list) : shape_(list) {} + ~Shape() override = default; + MS_DECLARE_PARENT(Shape, BaseShape) + std::string ToString() const override; + std::string DumpText() const override; + bool operator==(const BaseShape &other) const override; + BaseShapePtr Clone() const override { return std::make_shared(shape_); } + void Broaden() override; + std::vector &shape() { return shape_; } + + std::vector shape_; // use SHP_ANY to implement the any shape in python +}; +using ShapePtr = std::shared_ptr; +using ShapePtrList = std::vector; + +class SequeueShape : public BaseShape { + public: + SequeueShape() : p_shapes_() {} + explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {} + ~SequeueShape() override = default; + MS_DECLARE_PARENT(SequeueShape, BaseShape) + + std::string ToString() const override; + BaseShapePtrList ElementsClone() const; + + template + bool SequeueEqual(const BaseShape &other) const; + + const BaseShapePtrList &shape() const { return p_shapes_; } + size_t size() const { return p_shapes_.size(); } + const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; } + + protected: + BaseShapePtrList p_shapes_; // shape list of each elements +}; +using SequeueShapePtr = std::shared_ptr; + +class TupleShape : public SequeueShape { + public: + TupleShape() : SequeueShape() {} + explicit TupleShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} + ~TupleShape() override = default; + MS_DECLARE_PARENT(TupleShape, SequeueShape) + + std::string ToString() const override { return type_name() + "(" + SequeueShape::ToString() + ")"; } + + BaseShapePtr Clone() const override { return std::make_shared(ElementsClone()); } + + bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } +}; +using TupleShapePtr = std::shared_ptr; + +class ListShape : public SequeueShape { + public: + ListShape() : SequeueShape() {} + explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} + ~ListShape() override = default; + MS_DECLARE_PARENT(ListShape, SequeueShape) + + std::string ToString() const override { return type_name() + "[" + SequeueShape::ToString() + "]"; } + + BaseShapePtr Clone() const override { return std::make_shared(SequeueShape::ElementsClone()); } + + bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } +}; +using ListShapePtr = std::shared_ptr; +} // namespace abstract +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_ABSTRACT_DSHAPE_H_ diff --git a/mindspore/core/abstract/param_validator.cc b/mindspore/core/abstract/param_validator.cc new file mode 100644 index 0000000000..69fe88b4a3 --- /dev/null +++ b/mindspore/core/abstract/param_validator.cc @@ -0,0 +1,147 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/param_validator.h" + +#include +#include +#include +#include "utils/symbolic.h" +#include "abstract/utils.h" + +namespace mindspore { +namespace abstract { +#define ABSTRACT_REPORT_NAME_DEC(abstract) constexpr char ReportNameTraits::name[]; + +ABSTRACT_REPORT_NAME_DEC(Tensor) +ABSTRACT_REPORT_NAME_DEC(Tuple) +ABSTRACT_REPORT_NAME_DEC(Scalar) +ABSTRACT_REPORT_NAME_DEC(List) +ABSTRACT_REPORT_NAME_DEC(Dictionary) +ABSTRACT_REPORT_NAME_DEC(Slice) +ABSTRACT_REPORT_NAME_DEC(Function) +ABSTRACT_REPORT_NAME_DEC(Type) +ABSTRACT_REPORT_NAME_DEC(KeywordArg) +ABSTRACT_REPORT_NAME_DEC(Class) + +TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix) { + bool ok = std::any_of(accepts.begin(), accepts.end(), + [type](const TypePtr &accept) -> bool { return IsIdentidityOrSubclass(type, accept); }); + if (ok) { + return type; + } else { + MS_LOG(EXCEPTION) << error_message_prefix << accepts << " but is " << type->ToString(); + } +} + +TypePtr CheckTensorDType(const AbstractTensorPtr &tensor, const TypePtrList &accepts, + const std::string &error_message_prefix) { + MS_EXCEPTION_IF_NULL(tensor); + TypePtr type = tensor->BuildType(); + if (!type->isa()) { + MS_LOG(EXCEPTION) << error_message_prefix << "requires Tensor but got " << type->ToString(); + } + TypePtr ele_type = tensor->element()->BuildType(); + if (ele_type == nullptr) { + MS_LOG(EXCEPTION) << "Abstract tensor element type nullptr"; + } + return CheckType(ele_type, accepts, error_message_prefix); +} + +TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const TypePtrList &accepts, + const std::string &error_message_prefix) { + if (tensor_list.empty()) { + MS_LOG(EXCEPTION) << "Array list is empty"; + } + + auto sample_tensor = tensor_list[0]; + MS_EXCEPTION_IF_NULL(sample_tensor); + TypePtr sample_type = sample_tensor->element()->BuildType(); + std::ostringstream loginfoBuffer; + loginfoBuffer << "same type, got"; + // Check if other elements have the same type with the first element. + for (size_t index = 1; index < tensor_list.size(); ++index) { + MS_EXCEPTION_IF_NULL(tensor_list[index]); + auto aType = tensor_list[index]->element()->BuildType(); + loginfoBuffer << " " << aType->ToString(); + if (sample_type->type_id() != aType->type_id()) { + MS_LOG(EXCEPTION) << "Expected type " << sample_type->ToString() << ", but got " << aType->ToString() + << ", index " << index; + } + } + MS_LOG(DEBUG) << error_message_prefix << loginfoBuffer.str(); + return CheckTensorDType(sample_tensor, accepts, error_message_prefix); +} + +TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts, + const std::string &error_message_prefix) { + if (scalar == nullptr) { + MS_LOG(EXCEPTION) << "Scalar nullptr"; + } + auto type = scalar->BuildType(); + if (type == nullptr) { + MS_LOG(EXCEPTION) << "Scalar value nullptr"; + } + + return CheckType(type, accepts, error_message_prefix); +} + +ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) { + ShapePtr shape_base = tensor_base->shape(); + ShapePtr shape = tensor->shape(); + if (*shape != *shape_base) { + MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << tensor->shape()->ToString() + << " are not consistent with second arg shape " << tensor_base->shape()->ToString(); + } + return shape_base; +} + +TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) { + TypePtr type_base = tensor_base->element()->BuildType(); + TypePtr type = tensor->element()->BuildType(); + if (*type != *type_base) { + MS_LOG(EXCEPTION) << op << " evaluator first arg dtype " << type_base->ToString() + << " are not consistent with second arg dtype " << type->ToString(); + } + return type_base; +} + +int CheckAxis(const std::string &op, const ValuePtr &axis, int minimum, int max) { + if (axis == nullptr) { + MS_LOG(EXCEPTION) << op << " evaluator axis is null"; + } + if (!axis->isa()) { + MS_LOG(EXCEPTION) << op << " evaluator axis should be int, but got " << axis->type_name(); + } + int axis_value = GetValue(axis); + if (axis_value > max || axis_value < minimum) { + MS_LOG(EXCEPTION) << op << " evaluator axis value should be in the range [" << minimum << ", " << max + << "], but get " << axis_value; + } + return axis_value; +} +void CheckArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list, + size_t size_expect) { + if (args_spec_list.size() != size_expect) { + MS_LOG(EXCEPTION) << op << " input args size should be " << size_expect << ", but got " << args_spec_list.size(); + } + + for (size_t i = 0; i < size_expect; i++) { + MS_EXCEPTION_IF_NULL(args_spec_list[i]); + } +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/core/abstract/param_validator.h b/mindspore/core/abstract/param_validator.h new file mode 100644 index 0000000000..434235abda --- /dev/null +++ b/mindspore/core/abstract/param_validator.h @@ -0,0 +1,100 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_ABSTRACT_PARAM_VALIDATOR_H_ +#define MINDSPORE_CCSRC_ABSTRACT_PARAM_VALIDATOR_H_ + +#include +#include +#include +#include +#include "abstract/abstract_value.h" +#include "abstract/utils.h" +#include "utils/any.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace abstract { +// check if variable's type is an instance of any of accepts or of a subclass of it. +TypePtr CheckType(TypePtr type, const TypePtrList &accepts, const std::string &error_message_prefix); + +TypePtr CheckTensorDType(const AbstractTensorPtr &tensor, const TypePtrList &accepts, + const std::string &error_message_prefix); + +TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const TypePtrList &accepts, + const std::string &error_message_prefix); + +TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts, + const std::string &error_message_prefix); + +ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor); + +TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor); + +int CheckAxis(const std::string &op, const ValuePtr &axis, int min, int max); + +void CheckArgsSize(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t size_expect); + +template +struct ReportNameTraits {}; + +#define ABSTRACT_REPORT_NAME_TRAITS(abstract) \ + template <> \ + struct ReportNameTraits { \ + static constexpr char name[] = #abstract; \ + }; +ABSTRACT_REPORT_NAME_TRAITS(Tensor) +ABSTRACT_REPORT_NAME_TRAITS(Tuple) +ABSTRACT_REPORT_NAME_TRAITS(Scalar) +ABSTRACT_REPORT_NAME_TRAITS(List) +ABSTRACT_REPORT_NAME_TRAITS(Dictionary) +ABSTRACT_REPORT_NAME_TRAITS(Slice) +ABSTRACT_REPORT_NAME_TRAITS(Function) +ABSTRACT_REPORT_NAME_TRAITS(Type) +ABSTRACT_REPORT_NAME_TRAITS(KeywordArg) +ABSTRACT_REPORT_NAME_TRAITS(Class) +ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices) +ABSTRACT_REPORT_NAME_TRAITS(Sequeue) + +template +std::shared_ptr CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) { + if (index >= args_spec_list.size()) { + MS_EXCEPTION(ValueError) << op << " evaluator args list index out of bound, size " << args_spec_list.size() + << ", index " << index; + } + auto arg = dyn_cast(args_spec_list[index]); + if (arg == nullptr) { + MS_EXCEPTION(TypeError) << "Operator " << op << " input[" << index << "] should be " << ReportNameTraits::name + << ", but got " << args_spec_list[index]->BuildType()->ToString() << "."; + } + return arg; +} + +// check if each element in args_spec is type T, and can be joined. +template +void CheckArgsSpec(const AbstractBasePtrList &args_list) { + for (const auto &arg : args_list) { + if (!arg->isa()) { + MS_EXCEPTION(TypeError) << "Expected type " << ReportNameTraits::name << ", but got " + << arg->BuildType()->ToString() << "."; + } + } + (void)AbstractJoin(args_list); +} +} // namespace abstract +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_ABSTRACT_PARAM_VALIDATOR_H_ diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc new file mode 100644 index 0000000000..16497c74a9 --- /dev/null +++ b/mindspore/core/abstract/utils.cc @@ -0,0 +1,201 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "abstract/utils.h" + +#include +#include +#include +#include "utils/symbolic.h" +#include "abstract/param_validator.h" + +namespace mindspore { +namespace abstract { +ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) { + MS_EXCEPTION_IF_NULL(value1); + MS_EXCEPTION_IF_NULL(value2); + if (*value1 == *value2) { + return value1; + } + return kAnyValue; +} + +TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2) { + MS_EXCEPTION_IF_NULL(type1); + MS_EXCEPTION_IF_NULL(type2); + if (*type1 == *type2) { + return type1; + } + return kAnyType; +} + +ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { + MS_EXCEPTION_IF_NULL(shape1); + MS_EXCEPTION_IF_NULL(shape2); + if (*shape1 == *shape2) { + return shape1; + } + if (shape1->shape().size() != shape2->shape().size()) { + MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString(); + return shape1; + } + std::vector dims; + dims.resize(shape1->shape().size()); + for (std::size_t i = 0; i < shape1->shape().size(); i++) { + if (shape1->shape()[i] == shape2->shape()[i]) { + dims[i] = shape1->shape()[i]; + } else { + dims[i] = Shape::SHP_ANY; + } + } + return std::make_shared(dims); +} + +AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.size() < 1) { + MS_LOG(EXCEPTION) << "AbstractJoin requires at least 1 params, while the input size is " << args_spec_list.size() + << "."; + } + AbstractBasePtr arg_spec_tmp = args_spec_list[0]; + MS_EXCEPTION_IF_NULL(arg_spec_tmp); + for (auto arg_spec : args_spec_list) { + arg_spec_tmp = arg_spec_tmp->Join(arg_spec); + MS_EXCEPTION_IF_NULL(arg_spec_tmp); + } + return arg_spec_tmp; +} + +AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2) { + if (spec1.size() != spec2.size()) { + MS_LOG(EXCEPTION) << "Join failed as list don't have the same size. spec1: " << ::mindspore::ToString(spec1) + << ", spec2: " << ::mindspore::ToString(spec2); + } + AbstractBasePtrList joined_list; + bool changes = false; + for (std::size_t i = 0; i < spec1.size(); i++) { + auto joined_elem = spec1[i]->Join(spec2[i]); + if (joined_elem != spec1[i]) { + changes = true; + } + joined_list.push_back(joined_elem); + } + if (!changes) { + return spec1; + } + return joined_list; +} + +AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec) { + AbstractFunctionPtr f_spec = dyn_cast(spec); + if (f_spec != nullptr) { + return std::make_shared(kAnyValue, std::make_shared()); + } + return spec->Clone(); +} + +namespace { +// Join all types in args_type_list; +TypePtr TypeJoin(const TypePtrList &args_type_list) { + if (args_type_list.empty()) { + MS_LOG(EXCEPTION) << "args_type_list is empty"; + } + + TypePtr type_tmp = args_type_list[0]; + for (std::size_t i = 1; i < args_type_list.size(); i++) { + type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]); + } + return type_tmp; +} +} // namespace + +bool CheckType(const TypePtr &expected_type, const TypePtr &x) { + // As x and predicate both are mindspore type staticly, here we only to judge whether + // x is predicate or is a subclass of predicate. + return IsIdentidityOrSubclass(x, expected_type); +} + +TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) { + MS_EXCEPTION_IF_NULL(predicate); + for (auto arg_type : args_type_list) { + MS_EXCEPTION_IF_NULL(arg_type); + if (!CheckType(predicate, arg_type)) { + MS_LOG(EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << arg_type->ToString(); + } + } + return TypeJoin(args_type_list); +} + +int GetPositiveAxis(int axis_value, size_t increment) { + if (axis_value < 0) { + axis_value = axis_value + SizeToInt(increment); + } + + if (axis_value < 0) { + MS_LOG(EXCEPTION) << "axis_value should not still <0"; + } + + return axis_value; +} + +// Return if two shapes can be broadcast. +// Broadcast shape is placed in broadcast_output_shape. +std::vector RealBroadcast(const std::string &op, std::vector x_shape, std::vector y_shape) { + std::reverse(x_shape.begin(), x_shape.end()); + std::reverse(y_shape.begin(), y_shape.end()); + // Fill a placeholder value 1 which will be replaced later. + size_t std_len = x_shape.size() > y_shape.size() ? x_shape.size() : y_shape.size(); + y_shape.resize(std_len, 1); + x_shape.resize(std_len, 1); + + std::vector broadcast_shape; + for (size_t i = 0; i < std_len; i++) { + int x_i = x_shape[i]; // i-th dimension of x + int y_i = y_shape[i]; // i-th dimension of y + int output_i = 0; // i-th dimension of the output + if (x_i == y_i) { + output_i = x_i; + } else if (x_i == 1) { + output_i = y_i; + } else if (y_i == 1) { + output_i = x_i; + } else { + MS_LOG(EXCEPTION) + << op + << " evaluator the shape of first tensor and the shape of second tensor do not meet the broadcasting " + "requirements"; + } + broadcast_shape.push_back(output_i); + } + std::reverse(broadcast_shape.begin(), broadcast_shape.end()); + return broadcast_shape; +} + +ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, + const AbstractTensorPtr &tensor_y) { + mindspore::abstract::ShapePtr tensor_x_shape = tensor_x->shape(); + mindspore::abstract::ShapePtr tensor_y_shape = tensor_y->shape(); + // if is the same shape ,just return the x_shape + if (*tensor_x_shape == *tensor_y_shape) { + return tensor_x_shape; + } + auto x_shape = tensor_x_shape->shape(); + auto y_shape = tensor_y_shape->shape(); + return std::make_shared(RealBroadcast(op, x_shape, y_shape)); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/core/abstract/utils.h b/mindspore/core/abstract/utils.h new file mode 100644 index 0000000000..be38ae860d --- /dev/null +++ b/mindspore/core/abstract/utils.h @@ -0,0 +1,56 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_ABSTRACT_UTILS_H_ +#define MINDSPORE_CCSRC_ABSTRACT_UTILS_H_ + +#include +#include +#include +#include +#include "abstract/abstract_value.h" +#include "utils/any.h" +#include "utils/misc.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace abstract { +ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2); +TypePtr TypeJoin(const TypePtr &type1, const TypePtr &type2); +ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2); + +AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list); +AbstractBasePtrList AbstractJoin(const AbstractBasePtrList &spec1, const AbstractBasePtrList &spec2); + +// Return an abstract value for the sensitivity of x. +// The sensitivity of a function is an Env +// The sensitivity of J(x) is x +// else self.Clone; +AbstractBasePtr SensitivityTransform(const AbstractBasePtr &spec); + +TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list); + +bool CheckType(const TypePtr &expected_type, const TypePtr &x); + +int GetPositiveAxis(int axis_value, size_t increment); + +// Get broadcasted shape for binary element-wise operation +ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y); +} // namespace abstract +} // namespace mindspore +#endif // MINDSPORE_CCSRC_ABSTRACT_UTILS_H_ diff --git a/mindspore/core/base/CMakeLists.txt b/mindspore/core/base/CMakeLists.txt new file mode 100644 index 0000000000..d65b91a824 --- /dev/null +++ b/mindspore/core/base/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _BASE_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_BASE_ALL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_BASE) +add_library(_mindspore_base_obj OBJECT ${_BASE_ALL_SRC_FILES}) diff --git a/mindspore/core/base/base.cc b/mindspore/core/base/base.cc new file mode 100644 index 0000000000..07ed252e96 --- /dev/null +++ b/mindspore/core/base/base.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "base/base.h" +#include +#include +#include + +namespace mindspore { +const bool Base::IsFromTypeId(uint32_t tid) const { + static const uint32_t node_id = GetTypeId(typeid(Base).name()); + return tid == node_id; +} + +uint32_t Base::GetTypeId(const char *const type_name) { + TypeIdManager *t = TypeIdManager::Get(); + std::lock_guard(t->mutex); + auto it = t->map.find(type_name); + if (it != t->map.end()) { + return it->second; + } + uint32_t tid = ++(t->type_counter); + t->map[type_name] = tid; + return tid; +} +} // namespace mindspore diff --git a/mindspore/core/base/base.h b/mindspore/core/base/base.h new file mode 100644 index 0000000000..8e1a447c0d --- /dev/null +++ b/mindspore/core/base/base.h @@ -0,0 +1,152 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BASE_BASE_H_ +#define MINDSPORE_CCSRC_BASE_BASE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "utils/visible.h" +#include "utils/log_adapter.h" +#include "utils/ordered_set.h" +#include "utils/ordered_map.h" + +namespace mindspore { +template +struct is_shared_ptr : public std::false_type {}; +template +struct is_shared_ptr> : public std::true_type {}; + +class Base : public std::enable_shared_from_this { + public: + constexpr Base() = default; + Base(const Base &other) : std::enable_shared_from_this(other) {} + virtual bool operator==(const Base &rhs) { + if (this == &rhs) { + return true; + } + return false; + } + + virtual Base &operator=(const Base &) { return *this; } + virtual ~Base() = default; + virtual std::size_t hash() const { return tid(); } + virtual std::string ToString() const { return type_name(); } + virtual void dump() const { std::cout << ToString() << std::endl; } + + virtual std::string DumpText() const { return ToString(); } + + virtual const bool IsFromTypeId(uint32_t tid) const; + virtual std::string type_name() const { return "Base"; } + static uint32_t GetTypeId(const char *const type_key); + virtual uint32_t tid() const { + static const uint32_t tid = GetTypeId(typeid(Base).name()); + return tid; + } + + template ::value && std::is_base_of::value, T>::type * = nullptr> + inline bool isa() const { + static const uint32_t tid = GetTypeId(typeid(T).name()); + return this->IsFromTypeId(tid); + } + + template ::value, typename T::element_type>::type> + inline T cast() { + if (isa()) { + return std::static_pointer_cast(shared_from_this()); + } else { + return nullptr; + } + } + + protected: + template + std::shared_ptr shared_from_base() { + return std::static_pointer_cast(shared_from_this()); + } +}; + +using BasePtr = std::shared_ptr; +using BaseWeakPtr = std::weak_ptr; + +template +inline T *cast(U *source) { + if (source != nullptr && source->template isa()) { + return static_cast(source); + } else { + return nullptr; + } +} + +template < + typename T, typename U, + typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> +inline std::shared_ptr dyn_cast(const std::shared_ptr r) { + if (r != nullptr && r->template isa()) { + return std::static_pointer_cast(r); + } else { + return std::shared_ptr(); + } +} + +#define MS_DECLARE_PARENT(current_t, parent_t) \ + uint32_t tid() const override { \ + static const uint32_t tid = GetTypeId(typeid(current_t).name()); \ + return tid; \ + } \ + const bool IsFromTypeId(uint32_t from_tid) const override { \ + static const uint32_t tid = Base::GetTypeId(typeid(current_t).name()); \ + if (tid == from_tid) { \ + return true; \ + } \ + return parent_t::IsFromTypeId(from_tid); \ + } \ + std::string type_name() const override { return #current_t; } + +class Type; +using TypePtr = std::shared_ptr; + +class AnfNode; +using AnfNodePtr = std::shared_ptr; +using AnfNodePtrList = std::vector; +using AnfNodeSet = OrderedSet; + +namespace abstract { +class AbstractBase; +using AbstractBasePtr = std::shared_ptr; +using AbstractAttribute = std::pair; +class AnalysisContext; +using AnalysisContextPtr = std::shared_ptr; +} // namespace abstract + +struct MS_EXPORT TypeIdManager { + std::mutex mutex; + std::atomic type_counter{0}; + std::unordered_map map; + static TypeIdManager *Get(); + TypeIdManager() : mutex(), type_counter(0), map() {} +}; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BASE_BASE_H_ diff --git a/mindspore/ccsrc/ir/CMakeLists.txt b/mindspore/core/ir/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/ir/CMakeLists.txt rename to mindspore/core/ir/CMakeLists.txt diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc new file mode 100644 index 0000000000..0d96ddf263 --- /dev/null +++ b/mindspore/core/ir/anf.cc @@ -0,0 +1,221 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/anf.h" + +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/primitive.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +// namespace to support intermediate representation definition +CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) + : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} + +// Check if CNode is an apply with the specific Primitive. +bool CNode::IsApply(const PrimitivePtr &value) const { + if (value == nullptr) { + return false; + } + + if (inputs_.size() != 0 && IsValueNode(inputs_[0])) { + PrimitivePtr fn_value = GetValueNode(inputs_[0]); + if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) { + return true; + } + } + + return false; +} + +void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } + +std::string CNode::DebugString(int recursive_level) const { + std::ostringstream buffer; + if (recursive_level > 0) { + if (func_graph() != nullptr) { + buffer << func_graph()->ToString() << ":"; + } + buffer << ToString() << "{"; + bool is_first_node = true; + int idx = 0; + for (auto &node : inputs_) { + MS_EXCEPTION_IF_NULL(node); + if (is_first_node) { + is_first_node = false; + } else { + buffer << ", "; + } + buffer << "[" << idx << "]: " << node->DebugString(recursive_level - 1); + idx++; + } + buffer << "}"; + } else { + buffer << ToString(); + } + return buffer.str(); +} + +std::string ValueNode::ToString() const { + MS_EXCEPTION_IF_NULL(value_); + if (value_->isa()) { + return value_->cast()->ToString(); + } + std::ostringstream buffer; + buffer << AnfNode::ToString(); + buffer << "(" << value_->ToString() << ")"; + return buffer.str(); +} + +std::string ValueNode::DebugString(int) const { + MS_EXCEPTION_IF_NULL(value_); + std::ostringstream buffer; + buffer << "ValueNode<" << value_->type_name() << "> " << value_->ToString(); + return buffer.str(); +} + +std::string ValueNode::fullname_with_scope() { + if (!fullname_with_scope_.empty()) { + return fullname_with_scope_; + } + + MS_EXCEPTION_IF_NULL(scope()); + fullname_with_scope_ = scope()->name() + "/" + "data-" + id_generator::get_id(shared_from_base()); + return fullname_with_scope_; +} + +bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode == nullptr) { + return false; + } + if (value != nullptr) { + return cnode->IsApply(value); + } + const auto &prim = GetValueNode(cnode->input(0)); + return prim != nullptr; +} + +PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { + if (node == nullptr) { + return nullptr; + } + auto cnode = node->cast(); + if (cnode != nullptr) { + if (cnode->size() > 0) { + auto prim = GetValueNode(cnode->input(0)); + return prim; + } + } + return nullptr; +} + +std::string GetCNodeFuncName(const CNodePtr cnode) { + if (cnode->inputs().empty()) { + return ""; + } + + AnfNodePtr valuenode = cnode->input(0); + if (valuenode->isa()) { + auto value = GetValueNode(valuenode); + // check whether the valuenode is primitive + if (value->isa()) { + return value->cast()->name(); + } + return value->ToString(); + } + return ""; +} + +bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { + if (IsValueNode(node)) { + PrimitivePtr fn_value = GetValueNode(node); + MS_EXCEPTION_IF_NULL(value); + if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) { + return true; + } + } + return false; +} + +size_t NewSeenGeneration() { + static size_t seen_generation = 0; + return ++seen_generation; +} + +namespace id_generator { +static std::unordered_map node_ids; +std::string get_id(const AnfNodePtr &node) { + auto type_name = node->type_name(); + if (node_ids.find(type_name) == node_ids.end()) { + node_ids[type_name] = 0; + } else { + node_ids[type_name]++; + } + return std::to_string(node_ids[type_name]); +} + +void reset_id() { node_ids.clear(); } +} // namespace id_generator + +std::string GetCNodeTarget(const AnfNodePtr &node) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->device_target(); + if (!node->isa()) { + return default_target; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto attr_input = cnode->input(0); + if (attr_input == nullptr) { + return default_target; + } + auto value_node = attr_input->cast(); + if (value_node == nullptr) { + return default_target; + } + auto value = value_node->value(); + if (value == nullptr) { + return default_target; + } + if (!value->isa()) { + return default_target; + } + auto primitive = value->cast(); + auto att_target = primitive->GetAttr("primitive_target"); + if (att_target != nullptr) { + if (!att_target->isa()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } + auto target = GetValue(att_target); + if (kTargetSet.find(target) == kTargetSet.end()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } + return target; + } + return default_target; +} +} // namespace mindspore diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h new file mode 100644 index 0000000000..c1a28d57f1 --- /dev/null +++ b/mindspore/core/ir/anf.h @@ -0,0 +1,445 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_ANF_H_ +#define MINDSPORE_CCSRC_IR_ANF_H_ + +#include +#include +#include +#include +#include +#include + +#include "base/base.h" +#include "ir/kernel_info_dev.h" +#include "ir/scope.h" +#include "debug/info.h" + +// A MindSpore ANF IR defined here. +// with BNF followed: +// ::= Scalar | Named | Tensor | Var | +// Prim | MetaFuncGraph | FuncGraph | Type| +// Shape | Param +// ::= ( ...) +// ::= | +// ANode: Atomic Node +// CNode: Complex Node +namespace mindspore { +namespace parallel { +class TensorLayout; +class OperatorInfo; +} // namespace parallel +using OperatorInfoPtr = std::shared_ptr; + +namespace abstract { +class BaseShape; +class AbstractBase; +} // namespace abstract +using BaseShapePtr = std::shared_ptr; +using AbstractBasePtr = std::shared_ptr; +using AbstractBasePtrList = std::vector; + +class ValueNode; +using ValueNodePtr = std::shared_ptr; +class CNode; +using CNodePtr = std::shared_ptr; + +class FuncGraph; +using FuncGraphSet = OrderedSet; +using FuncGraphPtrList = std::vector; + +class Primitive; +using PrimitivePtr = std::shared_ptr; + +class BaseRef; + +class Var; +using VarPtr = std::shared_ptr; + +class AnfVisitor; + +class ParamValue; +using ParamValuePtr = std::shared_ptr; + +// AnfNode is the basic class of the IR definition derived from Base. +// Only two types of nodes are derived: CNode and ANode. +// Methods: +// func_graph: return FuncGraph that this AnfNode belongs to. +// scope: return the scope namespace of this AnfNode. Set it using set_scope. +// abstract: return the cached inferred abstract value. It contains type, shape +// value. Set New cache using set_abstract. +// intermediate_abstract: return the cached inferring abstract value. +// Type/Shape: return the related info of this AnfNode. When this AnfNode is an +// input of other CNodes, you can get the related info by this method. +// debug_info: return the information retrived from parser. Set it using set_debug_info. +// fullname_with_scope: return the detailed debug info. +class AnfNode : public Base { + public: + explicit AnfNode(const FuncGraphPtr &func_graph) + : func_graph_(FuncGraphWeakPtr(func_graph)), + abstract_(nullptr), + intermediate_abstract_(nullptr), + debug_info_(std::make_shared()), + fullname_with_scope_(""), + hash_(std::hash()), + kernel_info_(nullptr) { + scope_ = ScopeManager::GetInstance().GetCurrentScope(); + } + + ~AnfNode() override = default; + MS_DECLARE_PARENT(AnfNode, Base); + + virtual void accept(AnfVisitor *) {} + FuncGraphPtr func_graph() const { return func_graph_.lock(); } + + void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } + + ScopePtr scope() { return scope_; } + void set_scope(const ScopePtr &scope) { scope_ = scope; } + + const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); } + KernelInfoDevice *kernel_info() { return kernel_info_.get(); } + const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; } + void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; } + + AbstractBasePtr abstract() const { return abstract_; } + void set_abstract(const AbstractBasePtr &abs) { abstract_ = abs; } + + AbstractBasePtr intermediate_abstract() { return intermediate_abstract_; } + void set_intermediate_abstract(const AbstractBasePtr &abs) { intermediate_abstract_ = abs; } + + NodeDebugInfoPtr debug_info() { + MS_EXCEPTION_IF_NULL(debug_info_); + if (debug_info_->get_node() == nullptr) { + debug_info_->set_node(shared_from_base()); + } + return debug_info_; + } + void set_debug_info(const NodeDebugInfoPtr &debug_info) { + debug_info_ = debug_info; + if (debug_info_->get_node() == nullptr) { + debug_info_->set_node(shared_from_base()); + } + } + + TypePtr Type() const; + BaseShapePtr Shape() const; + + std::size_t hash() const override { return this->hash_(this); } + virtual std::string fullname_with_scope() { return ""; } + + virtual std::string DebugString(int recursive_level = 1) const { return ToString(); } + virtual std::string DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); } + std::string ToString() const override; + void dump() const override { std::cout << DebugString() << std::endl; } + std::string UniqueId() { return std::to_string(debug_info()->unique_id()); } + std::string UniqueIdThroughCopy() { return std::to_string(debug_info()->unique_id_through_copy()); } + virtual bool operator==(const AnfNode &other) const { return &other == this; } + friend std::ostream &operator<<(std::ostream &os, const AnfNode &node) { + os << node.ToString(); + return os; + } + size_t seen_{0}; + + protected: + // Hold a weak ref to Graph as Graph also hold ref to AnfNode. + // Otherwise, func_graph_ and AnfNode will make a reference cycle. + FuncGraphWeakPtr func_graph_; + AbstractBasePtr abstract_; + AbstractBasePtr intermediate_abstract_; + NodeDebugInfoPtr debug_info_; + std::string fullname_with_scope_; + + private: + std::hash hash_; + ScopePtr scope_; + KernelInfoDevicePtr kernel_info_; +}; + +// CNode represents the complex node with a set of arguments. +// Fields: +// inputs_: represents all of the inputs for this CNode. +// Using input(i) to get the index i input. +// Using inputs() to get all the inputs as a vector. +// Using add_input(input) to append a new input for a CNode. +// Using set_input(i, input) to change some input of these inputs. +// Using set_inputs(inputs) to refresh all of the inputs of a CNode. +// func_graph_as_var_: used in opt pattern matching to match a real FuncGraph. +// stop_gradient_: a flag used to stop gradient. +// Using stop_gradient() to get this flag, mainly used in ad. +// Using set_stop_gradient() to set this flag. +class CNode : public AnfNode { + public: + CNode(const std::vector &inputs, const FuncGraphPtr &func_graph); + CNode(const std::vector &inputs, const VarPtr &func_graph_as_var) + : AnfNode(nullptr), inputs_(inputs), func_graph_as_var_(func_graph_as_var), stop_gradient_(false) {} + + ~CNode() override = default; + MS_DECLARE_PARENT(CNode, AnfNode); + + void accept(AnfVisitor *v) override; + // check whether this cnode has some primitive value as the first input. + bool IsApply(const PrimitivePtr &) const; + + const size_t size() const { return inputs_.size(); } + const AnfNodePtr input(size_t i) const { return inputs_[i]; } + const std::vector &inputs() const { return inputs_; } + void add_input(const AnfNodePtr &input) { inputs_.push_back(input); } + void set_input(size_t i, const AnfNodePtr &input); + void set_inputs(const std::vector &inputs) { inputs_ = inputs; } + + bool stop_gradient() const { return stop_gradient_; } + void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; } + + std::string fullname_with_scope() override; + void set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; } + std::string DebugString(int recursive_level = 1) const override; + std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } + + OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info); + OperatorInfoPtr operator_info() { return operator_info_; } + + void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; } + bool in_forward_flag() const { return in_forward_flag_; } + + VarPtr func_graph_as_var() const { return func_graph_as_var_; } + + private: + std::vector inputs_; + VarPtr func_graph_as_var_; + bool stop_gradient_; + OperatorInfoPtr operator_info_ = nullptr; + bool in_forward_flag_ = false; +}; + +// ANode represents the atomic node. It's derived Parameter and ValueNode. +class ANode : public AnfNode { + public: + ANode() : AnfNode(nullptr) {} + explicit ANode(const FuncGraphPtr &func_graph) : AnfNode(func_graph) {} + virtual ~ANode() = default; + + MS_DECLARE_PARENT(ANode, AnfNode); +}; + +// Parameter represents the parameter inputs of a function. They have no value. +// Attributes: +// default_param_value_: used to hold the inputting tensor of the model. +class Parameter : public ANode { + public: + explicit Parameter(const FuncGraphPtr &func_graph) + : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {} + ~Parameter() override = default; + MS_DECLARE_PARENT(Parameter, ANode); + + void accept(AnfVisitor *v) override; + + std::string name() const { return name_; } + void set_name(const std::string &name) { name_ = name; } + std::string fullname_with_scope() override { return name(); }; + + bool has_default() const { return has_default_; } + void set_default_param(ParamValuePtr param) { + default_param_ = param; + has_default_ = true; + } + ParamValuePtr default_param() const { return default_param_; } + + std::shared_ptr tensor_layout() const { return tensor_layout_; } + void set_tensor_layout(const std::shared_ptr &tensor_layout) { + tensor_layout_ = tensor_layout; + } + + bool operator==(const AnfNode &other) const override { + if (!other.isa()) { + return false; + } + auto p = static_cast(other); + if (name_.length() > 0 && p.name_.length() > 0) { + return p.name_ == name_; + } + return shared_from_this() == other.shared_from_this(); + } + + private: + std::string name_; + bool has_default_; + ParamValuePtr default_param_; + std::shared_ptr tensor_layout_; +}; +using ParameterPtr = std::shared_ptr; + +// Value is used to represent the atomic expression mentioned in BNF. +// It mainly be stored in ValueNode. Value and ValueNode is related definition. +class Value : public Base { + public: + Value() = default; + explicit Value(const TypePtr t) : type_(t) {} + Value(const Value &other) : Base(other) { this->type_ = other.type_; } + ~Value() override = default; + MS_DECLARE_PARENT(Value, Base) + + TypePtr type() const { return type_; } + virtual abstract::AbstractBasePtr ToAbstract() { MS_LOG(EXCEPTION) << "ToAbstract error"; } + + virtual bool operator==(const Value &rhs) const = 0; + virtual Value &operator=(const Value &other) { + if (&other == this) { + return *this; + } + this->type_ = other.type_; + return *this; + } + + protected: + TypePtr type_{nullptr}; +}; +using ValuePtr = std::shared_ptr; +using ValuePtrList = std::vector; + +// ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode +// does not belong to any particular function graph. +class ValueNode : public ANode { + public: + explicit ValueNode(const ValuePtr &value) : value_(value) {} + ~ValueNode() override = default; + MS_DECLARE_PARENT(ValueNode, ANode); + + void accept(AnfVisitor *v) override; + const ValuePtr &value() const { return value_; } + std::string fullname_with_scope() override; + + std::string ToString() const override; + std::string DebugString(int recursive_level = 1) const override; + std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } + + bool operator==(const AnfNode &other) const override { + if (!other.isa()) { + return false; + } + auto v = static_cast(other); + return *v.value() == *value(); + } + friend std::ostream &operator<<(std::ostream &os, const ValueNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + os << node->ToString(); + return os; + } + + private: + ValuePtr value_; +}; + +template +struct ImmTraits {}; + +#define IMM_TRAITS(typeimm, prototype) \ + template <> \ + struct ImmTraits { \ + using type = typeimm; \ + }; + +inline ValuePtr MakeValue(const ValuePtr &value) { return value; } + +template ::type::element_type> +inline ValuePtr MakeValue(S v) { + return std::make_shared(v); +} + +template ::type> +static S GetValue(const ValuePtr &value) { + MS_EXCEPTION_IF_NULL(value); + + U imm = value->cast(); + if (imm == nullptr) { + MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name(); + } + return imm->value(); +} + +template ::value && std::is_base_of::value, + S>::type * = nullptr> +static S GetValue(const ValuePtr &value) { + MS_EXCEPTION_IF_NULL(value); + S v = value->cast(); + if (v == nullptr) { + MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name(); + } + return v; +} + +std::string GetCNodeFuncName(CNodePtr cnode); + +// used to check whether an AnfNode is a cnode with a kind of Primitive as first input +bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr); + +// used to get PrimitivePtr from a cnode first input +PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); + +// used to check whether an AnfNode is a valuenode having some Primitive value +bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value); + +// used to check whether a ValueNode has some kind of value +template +static bool IsValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto anode = node->cast(); + if (anode != nullptr) { + auto value = anode->value(); + if (value == nullptr) { + MS_LOG(EXCEPTION) << "Const value is nullptr."; + } + return value->isa(); + } + return false; +} + +inline ValuePtr GetValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return nullptr; + } + return node->cast()->value(); +} + +template ::value && std::is_base_of::value, + S>::type * = nullptr> +inline S GetValueNode(const AnfNodePtr &node) { + auto value = GetValueNode(node); + if (value == nullptr) { + return nullptr; + } + auto s = value->cast(); + return s; +} + +size_t NewSeenGeneration(); + +namespace id_generator { +std::string get_id(const AnfNodePtr &node); +void reset_id(); +} // namespace id_generator +using TaggedNodeMap = std::unordered_map; +using TaggedGraph = std::pair; +std::string GetCNodeTarget(const AnfNodePtr &node); +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_ANF_H_ diff --git a/mindspore/core/ir/anf_extends.cc b/mindspore/core/ir/anf_extends.cc new file mode 100644 index 0000000000..b70a660aae --- /dev/null +++ b/mindspore/core/ir/anf_extends.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/anf.h" + +#include +#include +#include +#include + +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "frontend/operator/ops.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "debug/label.h" + +namespace mindspore { +// namespace to support intermediate representation definition +// Methods of AnfNode +TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); } +BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } + +std::string AnfNode::ToString() const { + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); +} + +OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { + if (operator_info_ != nullptr) { + MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() + << ", using the new one: " << operator_info->name(); + auto old_ptr = operator_info_; + operator_info_ = operator_info; + return old_ptr; + } + operator_info_ = operator_info; + return nullptr; +} + +std::string CNode::fullname_with_scope() { + // if full name is set, return its name immediately + if (!fullname_with_scope_.empty()) { + return fullname_with_scope_; + } + + if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || + IsApply(prim::kPrimHistogramSummary)) { + std::string tag = GetValue(GetValueNode(input(1))); + std::string name; + if (IsApply(prim::kPrimScalarSummary)) { + name = tag + "[:Scalar]"; + } else if (IsApply(prim::kPrimImageSummary)) { + name = tag + "[:Image]"; + } else if (IsApply(prim::kPrimHistogramSummary)) { + name = tag + "[:Histogram]"; + } else { + name = tag + "[:Tensor]"; + } + fullname_with_scope_ = name; + } else { + // cnode input 0 should be primitive ptr or funcgraph ptr + auto value_ptr = input(0)->cast(); + if (value_ptr == nullptr) { + MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; + fullname_with_scope_ = id_generator::get_id(shared_from_base()); + return fullname_with_scope_; + } + auto input_value = value_ptr->value(); + if (input_value == nullptr) { + MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr."; + fullname_with_scope_ = id_generator::get_id(shared_from_base()); + return fullname_with_scope_; + } + + auto prim = input_value->cast(); + MS_EXCEPTION_IF_NULL(scope()); + fullname_with_scope_ = scope()->name() + "/"; + if (prim != nullptr) { + fullname_with_scope_ += prim->name(); + } else { + auto func_graph = input_value->cast(); + MS_EXCEPTION_IF_NULL(func_graph); + auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (fg_flag != nullptr) { + auto fg_name = GetValue(fg_flag); + fullname_with_scope_ += "GraphKernel_" + fg_name; + } else { + fullname_with_scope_ += func_graph->ToString(); + } + } + fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base()); + } + + return fullname_with_scope_; +} + +void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +} // namespace mindspore diff --git a/mindspore/core/ir/anf_py.cc b/mindspore/core/ir/anf_py.cc new file mode 100644 index 0000000000..d033dfff5a --- /dev/null +++ b/mindspore/core/ir/anf_py.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "ir/anf.h" + +#include "pybind_api/api_register.h" + +namespace mindspore { +// Define python 'RefKey' class. +REGISTER_PYBIND_DEFINE(CNode, ([](const pybind11::module *m) { + (void)py::class_(*m, "CNode") + .def("expanded_str", (std::string(CNode::*)(int) const) & CNode::DebugString, + "Get CNode string representation with specified expansion level."); + })); +} // namespace mindspore diff --git a/mindspore/core/ir/device_sync.h b/mindspore/core/ir/device_sync.h new file mode 100644 index 0000000000..a6bbe92233 --- /dev/null +++ b/mindspore/core/ir/device_sync.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ +#define MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ + +#include +#include +#include + +#include "ir/dtype/type.h" + +using std::string; + +namespace mindspore { +// Interface for data synchornize between device and host. +class DeviceSync { + public: + virtual bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const = 0; + virtual bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, + const void *host_ptr) const = 0; +}; +using DeviceSyncPtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/core/ir/dtype.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype.cc rename to mindspore/core/ir/dtype.cc diff --git a/mindspore/core/ir/dtype.h b/mindspore/core/ir/dtype.h new file mode 100644 index 0000000000..dc277c031c --- /dev/null +++ b/mindspore/core/ir/dtype.h @@ -0,0 +1,335 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DTYPE_H_ +#define MINDSPORE_CCSRC_IR_DTYPE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "base/base.h" +#include "ir/named.h" + +#include "ir/dtype/type.h" +#include "ir/dtype/ref.h" +#include "ir/dtype/number.h" +#include "ir/dtype/container.h" +#include "ir/dtype/empty.h" + +/* namespace to support intermediate representation definition */ +namespace mindspore { +// Only few type supported now. +TypePtr TypeIdToType(TypeId id); + +class String : public Object { + public: + String() : Object(kObjectTypeString, false) {} + ~String() override = default; + MS_DECLARE_PARENT(String, Object) + + TypeId generic_type_id() const override { return kObjectTypeString; } + + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToString() const override { return std::string("String"); } + std::string ToReprString() const override { return "string"; } + std::string DumpText() const override { return "String"; } +}; +using StringPtr = std::shared_ptr; + +class Keyword : public Object { + public: + Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} + Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} + + ~Keyword() override = default; + MS_DECLARE_PARENT(Keyword, Object) + + TypeId generic_type_id() const override { return kObjectTypeKeyword; } + TypePtr DeepCopy() const override; + + std::string ToString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + std::string GetKey() const { return key_; } + TypePtr GetValue() const { return value_; } + + private: + std::string key_; + TypePtr value_; +}; +using KeywordPtr = std::shared_ptr; + +class Slice : public Object { + public: + Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} + Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step) + : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} + + ~Slice() override = default; + MS_DECLARE_PARENT(Slice, Object) + + TypeId generic_type_id() const override { return kObjectTypeSlice; } + TypePtr DeepCopy() const override; + + std::string ToString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + TypePtr get_start() const { return start_; } + TypePtr get_stop() const { return stop_; } + TypePtr get_step() const { return step_; } + + private: + TypePtr start_; + TypePtr stop_; + TypePtr step_; +}; +using SlicePtr = std::shared_ptr; + +class UndeterminedType : public Object { + public: + UndeterminedType() : Object(kObjectTypeUndeterminedType) {} + explicit UndeterminedType(const TypePtr &ele) + : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {} + ~UndeterminedType() override = default; + MS_DECLARE_PARENT(UndeterminedType, Object) + + TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; } + const TypePtr element() const { return element_type_; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } + + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string ToReprString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + protected: + TypePtr element_type_; +}; +using MetaTensorTypePtr = std::shared_ptr; + +class TensorType : public Object { + public: + TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {} + explicit TensorType(const TypePtr &ele) + : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {} + ~TensorType() override = default; + MS_DECLARE_PARENT(TensorType, Object) + + TypeId generic_type_id() const override { return kObjectTypeTensorType; } + const TypePtr element() const { return element_type_; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } + + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string ToReprString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + private: + TypePtr element_type_; +}; +using TensorTypePtr = std::shared_ptr; + +class IndexedSlicesType : public Object { + public: + IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {} + explicit IndexedSlicesType(const TypePtr &ele) + : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {} + ~IndexedSlicesType() override = default; + MS_DECLARE_PARENT(IndexedSlicesType, Object) + + TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; } + const TypePtr element() const { return element_type_; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } + + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string ToReprString() const override; + std::string DumpText() const override; + bool operator==(const Type &other) const override; + + private: + TypePtr element_type_; +}; +using IndexedSlicesTypePtr = std::shared_ptr; + +class Function : public Object { + public: + Function(); + Function(const std::vector &args, const TypePtr retval); + ~Function() override = default; + MS_DECLARE_PARENT(Function, Object) + + TypeId generic_type_id() const override { return kObjectTypeFunction; } + + // Add temporarily for return abstraction to avoid type checking. + bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } + const std::vector &args() const { return args_; } + const TypePtr &retval() const { return retval_; } + + TypePtr DeepCopy() const override; + bool operator==(const Type &other) const override; + std::string ToString() const override; + std::string ToReprString() const override { return "function"; } + + private: + std::vector args_; + TypePtr retval_; +}; +using FunctionPtr = std::shared_ptr; + +class JTagged : public Object { + public: + JTagged() : Object(kObjectTypeJTagged) {} + explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} + ~JTagged() override = default; + MS_DECLARE_PARENT(JTagged, Object) + + TypeId generic_type_id() const override { return kObjectTypeJTagged; } + + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string DumpText() const override; + + private: + TypePtr subtype_; +}; +using JTaggedPtr = std::shared_ptr; + +class SymbolicKeyType : public Object { + public: + SymbolicKeyType() : Object(kObjectTypeSymbolicKeyType) {} + ~SymbolicKeyType() override = default; + MS_DECLARE_PARENT(SymbolicKeyType, Object) + + TypeId generic_type_id() const override { return kObjectTypeSymbolicKeyType; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToReprString() const override { return "symbolic_key"; } + std::string DumpText() const override { return "SymType"; } +}; + +class EnvType : public Object { + public: + EnvType() : Object(kObjectTypeEnvType) {} + ~EnvType() override = default; + MS_DECLARE_PARENT(EnvType, Object) + + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToReprString() const override { return "env_type"; } + std::string DumpText() const override { return "EnvType"; } +}; +using EnvTypePtr = std::shared_ptr; + +class TypeType : public Type { + public: + TypeType() : Type(kMetaTypeTypeType) {} + ~TypeType() override = default; + MS_DECLARE_PARENT(TypeType, Type) + + TypeId generic_type_id() const override { return kMetaTypeTypeType; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToReprString() const override { return "type_type"; } + std::string DumpText() const override { return "TypeType"; } +}; +using TypeTypePtr = std::shared_ptr; + +class Problem : public Type { + public: + Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} + explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {} + ~Problem() override = default; + MS_DECLARE_PARENT(Problem, Type) + + TypeId generic_type_id() const override { return kMetaTypeProblem; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToString() const override { return kind_.name(); } + std::string DumpText() const override { return "ProblemType"; } + + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr problem); + + private: + Named kind_; +}; +using ProblemPtr = std::shared_ptr; + +class External : public Type { + public: + External() : Type(kMetaTypeExternal) {} + ~External() override = default; + MS_DECLARE_PARENT(External, Type) + + TypeId generic_type_id() const override { return kMetaTypeExternal; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string DumpText() const override { return "ExternalType"; } + + private: + TypePtr kind; +}; +using ExternalPtr = std::shared_ptr; + +// helper template +template +TypePtr Clone(const T &t) { + return t.Clone(); +} + +TypePtr StringToType(const std::string &type_name); + +// Judge whether x is predicate or is a subclass of predicate. +bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); + +bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type); + +// Whether t1 is identity or a subclass of t2. +bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); + +struct TypeHasher { + std::size_t operator()(TypePtr const &type) const; +}; +struct TypeListHasher { + std::size_t operator()(const TypePtrList &type_list) const; +}; +struct TypeEqual { + bool operator()(TypePtr const &t1, TypePtr const &t2) const; +}; +struct TypeListEqual { + bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const; +}; + +extern const TypePtr kTypeExternal; +extern const TypePtr kTypeEnv; +extern const TypePtr kTypeType; +extern const TypePtr kString; +extern const TypePtr kList; +extern const TypePtr kTuple; +extern const TypePtr kDict; +extern const TypePtr kSlice; +extern const TypePtr kKeyword; +extern const TypePtr kTensorType; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_DTYPE_H_ diff --git a/mindspore/ccsrc/ir/dtype/container.cc b/mindspore/core/ir/dtype/container.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/container.cc rename to mindspore/core/ir/dtype/container.cc diff --git a/mindspore/core/ir/dtype/container.h b/mindspore/core/ir/dtype/container.h new file mode 100644 index 0000000000..29579fe73c --- /dev/null +++ b/mindspore/core/ir/dtype/container.h @@ -0,0 +1,150 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ +#define MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "base/base.h" +#include "ir/named.h" +#include "ir/dtype/type.h" + +namespace mindspore { +// TypeRefKey type + +// List +class List : public Object { + public: + List() : Object(kObjectTypeList) {} + List(const std::initializer_list &objs) + : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} + // Shadow copy; + explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {} + ~List() override {} + MS_DECLARE_PARENT(List, Object) + + const TypePtr operator[](size_t dim) const; + TypeId generic_type_id() const override { return kObjectTypeList; } + TypePtr DeepCopy() const override; + + bool operator==(const Type &other) const override; + std::size_t size() const { return elements_.size(); } + TypePtrList elements() const { return elements_; } + std::string ToString() const override; + std::string ToReprString() const override { return "list_"; } + std::string DumpText() const override; + + private: + TypePtrList elements_; +}; +using ListPtr = std::shared_ptr; + +using ClassAttrVector = std::vector>; + +class Class : public Object { + public: + Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} + Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map &methods); + ~Class() override {} + MS_DECLARE_PARENT(Class, Object) + + TypeId generic_type_id() const override { return kObjectTypeClass; } + + bool operator==(const Type &other) const override; + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string DumpText() const override; + void set_value(const std::unordered_map &v) { attributes_value_ = v; } + + Named tag() { return tag_; } + std::unordered_map GetValue() { return attributes_value_; } + std::unordered_map methods() { return methods_; } + ClassAttrVector &GetAttributes() { return attributes_; } + + ClassAttrVector attributes_; + + private: + Named tag_; + std::unordered_map methods_; + // For AbstractClass build value + std::unordered_map attributes_value_; +}; +using ClassPtr = std::shared_ptr; + +class Tuple : public Object { + public: + Tuple() : Object(kObjectTypeTuple) {} + // usage : Tuple t = {std::make_shared(), std::make_shared(32)}; + Tuple(const std::initializer_list &objs) + : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} + + // Shadow copy + explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} + + ~Tuple() override {} + MS_DECLARE_PARENT(Tuple, Object) + + TypeId generic_type_id() const override { return kObjectTypeTuple; } + TypePtr DeepCopy() const override; + + std::string ToString() const override; + std::string ToReprString() const override { return "tuple_"; } + std::string DumpText() const override; + const TypePtr operator[](size_t dim) const; + bool operator==(const Type &other) const override; + + TypePtrList elements() const { return elements_; } + std::size_t size() const { return elements_.size(); } + + private: + TypePtrList elements_; +}; +using TuplePtr = std::shared_ptr; + +class Dictionary : public Object { + public: + Dictionary() : Object(kObjectTypeDictionary) {} + explicit Dictionary(const std::vector> &key_values) + : Object(kObjectTypeDictionary, false), key_values_(key_values) {} + + ~Dictionary() override {} + MS_DECLARE_PARENT(Dictionary, Object) + + TypeId generic_type_id() const override { return kObjectTypeDictionary; } + + bool operator==(const Type &other) const override; + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string DumpText() const override; + + private: + std::vector> key_values_; +}; +using DictionaryPtr = std::shared_ptr; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_DTYPE_CONTAINER_H_ diff --git a/mindspore/ccsrc/ir/dtype/empty.cc b/mindspore/core/ir/dtype/empty.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/empty.cc rename to mindspore/core/ir/dtype/empty.cc diff --git a/mindspore/core/ir/dtype/empty.h b/mindspore/core/ir/dtype/empty.h new file mode 100644 index 0000000000..e6149a1fce --- /dev/null +++ b/mindspore/core/ir/dtype/empty.h @@ -0,0 +1,93 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ +#define MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "base/base.h" +#include "ir/named.h" +#include "ir/dtype/type.h" + +namespace mindspore { +class TypeAnything : public Type { + public: + TypeAnything() : Type(kMetaTypeAnything) {} + ~TypeAnything() override {} + MS_DECLARE_PARENT(TypeAnything, Type) + + TypeId generic_type_id() const override { return kMetaTypeAnything; } + TypePtr DeepCopy() const override; + std::string DumpText() const override { return "AnythingType"; } +}; +using TypeAnythingPtr = std::shared_ptr; + +class TypeNone : public Type { + public: + TypeNone() : Type(kMetaTypeNone) {} + ~TypeNone() override {} + MS_DECLARE_PARENT(TypeNone, Type) + + TypeId generic_type_id() const override { return kMetaTypeNone; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToReprString() const override { return "type_none"; } + std::string DumpText() const override { return "NoneType"; } +}; +using TypeNonePtr = std::shared_ptr; + +class TypeNull : public Type { + public: + TypeNull() : Type(kMetaTypeNull) {} + ~TypeNull() override {} + MS_DECLARE_PARENT(TypeNull, Type) + + TypeId generic_type_id() const override { return kMetaTypeNull; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string DumpText() const override { return "NullType"; } +}; +using TypeNullPtr = std::shared_ptr; + +class TypeEllipsis : public Type { + public: + TypeEllipsis() : Type(kMetaTypeEllipsis) {} + ~TypeEllipsis() override {} + MS_DECLARE_PARENT(TypeEllipsis, Type) + + TypeId generic_type_id() const override { return kMetaTypeEllipsis; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToReprString() const override { return "Ellipsis"; } + std::string DumpText() const override { return "Ellipsis"; } +}; +using TypeEllipsisPtr = std::shared_ptr; + +extern const TypePtr kTypeNone; +extern const TypePtr kTypeNull; +extern const TypePtr kTypeEllipsis; +extern const TypePtr kAnyType; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_DTYPE_EMPTY_H_ diff --git a/mindspore/ccsrc/ir/dtype/number.cc b/mindspore/core/ir/dtype/number.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/number.cc rename to mindspore/core/ir/dtype/number.cc diff --git a/mindspore/core/ir/dtype/number.h b/mindspore/core/ir/dtype/number.h new file mode 100644 index 0000000000..8997ddc4df --- /dev/null +++ b/mindspore/core/ir/dtype/number.h @@ -0,0 +1,154 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ +#define MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "base/base.h" +#include "ir/named.h" +#include "ir/dtype/type.h" + +namespace mindspore { +// Number, abstract class. +class Number : public Object { + public: + Number() : Object(kObjectTypeNumber), number_type_(kObjectTypeNumber), nbits_(0) {} + Number(const TypeId number_type, const int nbits, bool is_generic = true) + : Object(kObjectTypeNumber, is_generic), number_type_(number_type), nbits_(nbits) {} + ~Number() override = default; + MS_DECLARE_PARENT(Number, Object) + + int nbits() const { return nbits_; } + + TypeId number_type() const override { return number_type_; } + TypeId type_id() const override { return number_type_; } + TypeId generic_type_id() const override { return kObjectTypeNumber; } + + bool operator==(const Type &other) const override; + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToString() const override { return "Number"; } + std::string ToReprString() const override { return "number"; } + std::string DumpText() const override { return "Number"; } + std::string GetTypeName(const std::string &type_name) const { + std::ostringstream oss; + oss << type_name; + if (nbits() != 0) { + oss << nbits(); + } + return oss.str(); + } + + private: + const TypeId number_type_; + const int nbits_; +}; + +// Bool +class Bool : public Number { + public: + Bool() : Number(kNumberTypeBool, 8) {} + ~Bool() override = default; + MS_DECLARE_PARENT(Bool, Number) + + TypeId generic_type_id() const override { return kNumberTypeBool; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToString() const override { return "Bool"; } + std::string ToReprString() const override { return "bool"; } + std::string DumpText() const override { return "Bool"; } +}; + +// Int +class Int : public Number { + public: + Int() : Number(kNumberTypeInt, 0) {} + explicit Int(const int nbits); + ~Int() override = default; + MS_DECLARE_PARENT(Int, Number) + TypeId generic_type_id() const override { return kNumberTypeInt; } + TypePtr DeepCopy() const override { return std::make_shared(nbits()); } + std::string ToString() const override { return GetTypeName("Int"); } + std::string ToReprString() const override { return nbits() == 0 ? "int_" : GetTypeName("int"); } + std::string DumpText() const override { + return nbits() == 0 ? std::string("Int") : std::string("I") + std::to_string(nbits()); + } +}; + +// UInt +class UInt : public Number { + public: + UInt() : Number(kNumberTypeUInt, 0) {} + explicit UInt(const int nbits); + TypeId generic_type_id() const override { return kNumberTypeUInt; } + + ~UInt() override {} + MS_DECLARE_PARENT(UInt, Number) + + TypePtr DeepCopy() const override { return std::make_shared(nbits()); } + std::string ToString() const override { return GetTypeName("UInt"); } + std::string ToReprString() const override { return GetTypeName("uint"); } + std::string DumpText() const override { + return nbits() == 0 ? std::string("UInt") : std::string("U") + std::to_string(nbits()); + } +}; + +// Float +class Float : public Number { + public: + Float() : Number(kNumberTypeFloat, 0) {} + explicit Float(const int nbits); + ~Float() override {} + MS_DECLARE_PARENT(Float, Number) + + TypeId generic_type_id() const override { return kNumberTypeFloat; } + TypePtr DeepCopy() const override { return std::make_shared(nbits()); } + std::string ToString() const override { return GetTypeName("Float"); } + std::string ToReprString() const override { return nbits() == 0 ? "float_" : GetTypeName("float"); } + std::string DumpText() const override { + return nbits() == 0 ? std::string("Float") : std::string("F") + std::to_string(nbits()); + } +}; + +extern const TypePtr kBool; +extern const TypePtr kInt8; +extern const TypePtr kInt16; +extern const TypePtr kInt32; +extern const TypePtr kInt64; +extern const TypePtr kUInt8; +extern const TypePtr kUInt16; +extern const TypePtr kUInt32; +extern const TypePtr kUInt64; +extern const TypePtr kFloat16; +extern const TypePtr kFloat32; +extern const TypePtr kFloat64; +extern const TypePtr kInt; +extern const TypePtr kUInt; +extern const TypePtr kFloat; +extern const TypePtr kNumber; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_DTYPE_NUMBER_H_ diff --git a/mindspore/ccsrc/ir/dtype/ref.cc b/mindspore/core/ir/dtype/ref.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/ref.cc rename to mindspore/core/ir/dtype/ref.cc diff --git a/mindspore/core/ir/dtype/ref.h b/mindspore/core/ir/dtype/ref.h new file mode 100644 index 0000000000..e798d72af5 --- /dev/null +++ b/mindspore/core/ir/dtype/ref.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DTYPE_REF_H_ +#define MINDSPORE_CCSRC_IR_DTYPE_REF_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "base/base.h" +#include "ir/named.h" +#include "ir/dtype/type.h" + +namespace mindspore { +// TypeRefKey type +class RefKeyType : public Object { + public: + RefKeyType() : Object(kObjectTypeRefKey) {} + ~RefKeyType() override {} + MS_DECLARE_PARENT(RefKeyType, Object) + + TypeId generic_type_id() const override { return kObjectTypeRefKey; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToReprString() const override { return "type_refkey"; } + std::string DumpText() const override { return "RefKeyType"; } +}; + +// TypeRef type +class RefType : public Object { + public: + RefType() : Object(kObjectTypeRef) {} + RefType(const TypePtr &subtype, const TypePtr &subtype_origin) + : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} + ~RefType() override {} + MS_DECLARE_PARENT(RefType, Object) + + TypePtr subtype() const { return subtype_; } + TypeId generic_type_id() const override { return kObjectTypeRef; } + TypePtr DeepCopy() const override; + std::string ToString() const override; + std::string DumpText() const override; + + private: + TypePtr subtype_; + TypePtr subtype_origin_; +}; +using RefTypePtr = std::shared_ptr; + +extern const TypePtr kRefKeyType; +extern const TypePtr kRefType; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_DTYPE_REF_H_ diff --git a/mindspore/ccsrc/ir/dtype/type.cc b/mindspore/core/ir/dtype/type.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/type.cc rename to mindspore/core/ir/dtype/type.cc diff --git a/mindspore/core/ir/dtype/type.h b/mindspore/core/ir/dtype/type.h new file mode 100644 index 0000000000..2e38e8ffb6 --- /dev/null +++ b/mindspore/core/ir/dtype/type.h @@ -0,0 +1,127 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ +#define MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "base/base.h" +#include "ir/named.h" +#include "ir/dtype/type_id.h" + +namespace mindspore { + +TypeId IntBitsToTypeId(const int nbits); +TypeId UIntBitsToTypeId(const int nbits); +TypeId FloatBitsToTypeId(const int nbits); +const char *TypeIdLabel(const TypeId &v); +TypeId NormalizeTypeId(const TypeId type_id); +bool IsSameObjectType(const Type &lhs, const Type &rhs); +size_t GetTypeByte(const TypePtr &type_ptr); + +// Base class for all types +// forward declaration. + +class Type : public Value { + public: + Type() : meta_type_(kMetaTypeType), is_generic_(true) {} + explicit Type(TypeId t, bool is_generic = true) : meta_type_(t), is_generic_(is_generic) {} + ~Type() override = default; + MS_DECLARE_PARENT(Type, Value) + + bool operator==(const Value &other) const override; + TypeId meta_type() const { return meta_type_; } + + virtual TypeId type_id() const { return meta_type_; } + virtual TypeId generic_type_id() const { return kMetaTypeType; } + + virtual bool operator!=(const Type &other) const { return !(*this == other); } + virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); } + virtual bool equal(const TypePtr other) const { return *this == *other; } + + virtual TypeId object_type() const { return kTypeUnknown; } + virtual TypeId parent_type() const { return kTypeUnknown; } + virtual TypeId number_type() const { return kTypeUnknown; } + virtual TypePtr DeepCopy() const = 0; + virtual TypePtr Clone() const { return DeepCopy(); } + + std::size_t hash() const override { return std::hash{}(static_cast(type_id())); } + + std::string ToString() const override { return TypeIdLabel(meta_type_); } + virtual std::string ToReprString() const { return ToString(); } + std::string ReprString() const { return "mindspore." + ToReprString(); } + void dump() const override { std::cout << ToString() << std::endl; } + bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } + bool IsGeneric() const { return is_generic_; } + abstract::AbstractBasePtr ToAbstract() override; + friend std::ostream &operator<<(std::ostream &os, const Type &type); + friend std::ostream &operator<<(std::ostream &os, const TypePtr type); + + const bool parse_info_ = true; + + private: + TypeId meta_type_; + bool is_generic_; +}; + +using TypePtrList = std::vector; + +// +// Base class for normal objects +// +class Object : public Type { + public: + Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {} + explicit Object(const TypeId object_type, bool is_generic = true) + : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {} + explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true) + : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {} + ~Object() override = default; + MS_DECLARE_PARENT(Object, Type) + + TypeId object_type() const override { return object_type_; } + TypeId parent_type() const override { return parent_type_; } + TypeId type_id() const override { return object_type_; } + TypeId generic_type_id() const override { return kMetaTypeObject; } + bool equal(const TypePtr other) const override; + std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } + + friend std::ostream &operator<<(std::ostream &os, const Object &obj); + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr obj); + + private: + const TypeId object_type_; + const TypeId parent_type_; +}; + +std::ostream &operator<<(std::ostream &os, const TypePtrList &types); +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ diff --git a/mindspore/core/ir/dtype/type_extends.cc b/mindspore/core/ir/dtype/type_extends.cc new file mode 100644 index 0000000000..771a460c17 --- /dev/null +++ b/mindspore/core/ir/dtype/type_extends.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/dtype/type.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +abstract::AbstractBasePtr Type::ToAbstract() { + auto ptr = std::make_shared(shared_from_base()); + return ptr; +} +} // namespace mindspore diff --git a/mindspore/core/ir/dtype/type_id.h b/mindspore/core/ir/dtype/type_id.h new file mode 100644 index 0000000000..6fb2a354c1 --- /dev/null +++ b/mindspore/core/ir/dtype/type_id.h @@ -0,0 +1,93 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ +#define MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ + +#include +#include + +namespace mindspore { +// +// Supported meta type +// +enum TypeId : int { + kTypeUnknown = 0, + kMetaTypeBegin = kTypeUnknown, + kMetaTypeType, // Type + kMetaTypeAnything, + kMetaTypeObject, + kMetaTypeTypeType, // TypeType + kMetaTypeProblem, + kMetaTypeExternal, + kMetaTypeNone, + kMetaTypeNull, + kMetaTypeEllipsis, + kMetaTypeEnd, + // + // Object types + // + kObjectTypeBegin = kMetaTypeEnd, + kObjectTypeNumber, + kObjectTypeString, + kObjectTypeList, + kObjectTypeTuple, + kObjectTypeSlice, + kObjectTypeKeyword, + kObjectTypeTensorType, + kObjectTypeIndexedSlicesType, + kObjectTypeUndeterminedType, + kObjectTypeClass, + kObjectTypeDictionary, + kObjectTypeFunction, + kObjectTypeJTagged, + kObjectTypeSymbolicKeyType, + kObjectTypeEnvType, + kObjectTypeRefKey, + kObjectTypeRef, + kObjectTypeEnd, + // + // Number Types + // + kNumberTypeBegin = kObjectTypeEnd, + kNumberTypeBool, + kNumberTypeInt, + kNumberTypeInt8, + kNumberTypeInt16, + kNumberTypeInt32, + kNumberTypeInt64, + kNumberTypeUInt, + kNumberTypeUInt8, + kNumberTypeUInt16, + kNumberTypeUInt32, + kNumberTypeUInt64, + kNumberTypeFloat, + kNumberTypeFloat16, + kNumberTypeFloat32, + kNumberTypeFloat64, + kNumberTypeEnd +}; +// +// TypeId name map +// +const std::unordered_map type_name_map = { + {kNumberTypeBool, "bool_"}, {kNumberTypeInt8, "int8"}, {kNumberTypeUInt8, "uint8"}, + {kNumberTypeInt16, "int16"}, {kNumberTypeInt32, "int32"}, {kNumberTypeInt64, "int64"}, + {kNumberTypeFloat16, "float16"}, {kNumberTypeFloat32, "float32"}, {kNumberTypeFloat64, "float64"}}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_ID_H_ diff --git a/mindspore/core/ir/dtype_extends.cc b/mindspore/core/ir/dtype_extends.cc new file mode 100644 index 0000000000..099748217e --- /dev/null +++ b/mindspore/core/ir/dtype_extends.cc @@ -0,0 +1,438 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/dtype.h" +#include +#include +#include +#include "utils/log_adapter.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +TypePtr TypeAnything::DeepCopy() const { return kAnyType; } + +std::size_t TypeHasher::operator()(TypePtr const &type) const { + MS_EXCEPTION_IF_NULL(type); + std::size_t hash = std::hash()(type->type_id()); + return hash; +} + +std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { + std::size_t hash_sum = 0; + for (auto &type : type_list) { + auto type_id = static_cast(type->type_id()); + hash_sum = hash_combine(hash_sum, type_id); + } + return hash_sum; +} + +bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { + MS_EXCEPTION_IF_NULL(t1); + MS_EXCEPTION_IF_NULL(t2); + return t1->type_id() == t2->type_id(); +} + +bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { + if (lhs.size() != rhs.size()) { + return false; + } + std::size_t size = lhs.size(); + for (std::size_t i = 0; i < size; ++i) { + MS_EXCEPTION_IF_NULL(lhs[i]); + MS_EXCEPTION_IF_NULL(rhs[i]); + if (*lhs[i] != *rhs[i]) { + return false; + } + } + return true; +} + +TypePtr TypeIdToType(TypeId id) { + switch (id) { + case kNumberTypeFloat16: + return kFloat16; + case kNumberTypeFloat: + case kNumberTypeFloat32: + return kFloat32; + case kNumberTypeFloat64: + return kFloat64; + case kNumberTypeInt8: + return kInt8; + case kNumberTypeInt16: + return kInt16; + case kNumberTypeInt32: + return kInt32; + case kNumberTypeInt64: + return kInt64; + case kNumberTypeUInt8: + return kUInt8; + case kNumberTypeUInt16: + return kUInt16; + case kNumberTypeUInt32: + return kUInt32; + case kNumberTypeUInt64: + return kUInt64; + case kNumberTypeBool: + return kBool; + case kMetaTypeExternal: + return kTypeExternal; + case kMetaTypeAnything: + return kAnyType; + case kMetaTypeNone: + return kTypeNone; + case kMetaTypeNull: + return kTypeNull; + case kMetaTypeEllipsis: + return kTypeEllipsis; + case kObjectTypeEnvType: + return kTypeEnv; + case kObjectTypeRefKey: + return kRefKeyType; + case kObjectTypeRef: + return kRefType; + case kMetaTypeTypeType: + return kTypeType; + case kObjectTypeString: + return kString; + case kObjectTypeList: + return kList; + case kObjectTypeTuple: + return kTuple; + case kObjectTypeDictionary: + return kDict; + case kObjectTypeSlice: + return kSlice; + case kObjectTypeKeyword: + return kKeyword; + case kTypeUnknown: + return kTypeNone; + default: + MS_LOG(EXCEPTION) << "Not support the type: " << id; + } +} + +namespace { +template +TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { + TypePtr type = nullptr; + if (type_name == num_type_name) { + type = std::make_shared(); + } else { + try { + if (num_type_name.size() >= type_name.size()) { + MS_LOG(EXCEPTION) << "Convert type is error, type_name(" << type_name << "), num_type_name(" << num_type_name + << ")"; + } + auto bits = std::stoi(type_name.substr(num_type_name.size())); + type = std::make_shared(bits); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << num_type_name << " convert from string error " << e.what(); + } + } + return type; +} + +std::vector StringToVectorOfType(const std::string &type_names) { + std::vector types; + if (type_names.length() == 0) { + return types; + } + std::string::size_type start = 0; + std::string::size_type end = type_names.find_first_of(','); + while (end != std::string::npos) { + types.push_back(StringToType(type_names.substr(start, end))); + // Skip ',' to find the next element. + start = end + 1; + end = type_names.find_first_of(',', start); + } + if (start >= type_names.size()) { + MS_LOG(EXCEPTION) << "Type name is empty string."; + } + types.push_back(StringToType(type_names.substr(start))); + return types; +} + +TypePtr TensorStrToType(const std::string &type_name) { + TypePtr type = nullptr; + if (type_name == "Tensor") { + type = std::make_shared(); + } else { + try { + auto start = type_name.find_first_of('[') + 1; + auto end = type_name.find_last_of(']'); + if (start >= type_name.size()) { + return nullptr; + } + auto element_str = type_name.substr(start, end - start); + auto element_type = StringToType(element_str); + if (element_type == nullptr) { + return nullptr; + } + type = std::make_shared(element_type); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); + } + } + + return type; +} + +TypePtr IndexedSlicesStrToType(const std::string &type_name) { + if (type_name == "IndexedSlices") { + return std::make_shared(); + } + auto start = type_name.find_first_of('[') + 1; + auto end = type_name.find_last_of(']'); + if (start >= type_name.size()) { + return nullptr; + } + auto element_str = type_name.substr(start, end - start); + auto element_type = StringToType(element_str); + if (element_type == nullptr) { + return nullptr; + } + return std::make_shared(element_type); +} + +TypePtr UndeterminedStrToType(const std::string &type_name) { + if (type_name == "Undetermined") { + return std::make_shared(); + } + auto start = type_name.find_first_of('[') + 1; + auto end = type_name.find_last_of(']'); + if (start >= type_name.size()) { + return nullptr; + } + auto element_str = type_name.substr(start, end - start); + auto element_type = StringToType(element_str); + if (element_type == nullptr) { + return nullptr; + } + return std::make_shared(element_type); +} + +TypePtr ListStrToType(const std::string &type_name) { + TypePtr type = nullptr; + if (type_name == "List") { + type = std::make_shared(); + } else { + try { + auto start = type_name.find_first_of('[') + 1; + auto end = type_name.find_last_of(']'); + if (start >= type_name.size()) { + return nullptr; + } + std::string element_strs = type_name.substr(start, end - start); + std::vector element_types = StringToVectorOfType(element_strs); + bool wrong = + std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); + if (wrong) { + return nullptr; + } + type = std::make_shared(element_types); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); + } + } + + return type; +} + +TypePtr TupleStrToType(const std::string &type_name) { + TypePtr type = nullptr; + if (type_name == "Tuple") { + type = std::make_shared(); + } else { + try { + size_t start = type_name.find_first_of('[') + 1; + size_t end = type_name.find_last_of(']'); + if (start >= type_name.size()) { + return nullptr; + } + std::string element_strs = type_name.substr(start, end - start); + std::vector element_types = StringToVectorOfType(element_strs); + bool wrong = + std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); + if (wrong) { + return nullptr; + } + type = std::make_shared(element_types); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); + } + } + return type; +} + +TypePtr FunctionStrToType(const std::string &type_name) { + TypePtr type = nullptr; + + if (type_name == "Function") { + type = std::make_shared(); + } else { + try { + // format: [(para1, para2, para3, ...) retval] + size_t start = type_name.find_first_of('[') + 1; + size_t end = type_name.find_last_of(']'); + if (start >= type_name.size()) { + return nullptr; + } + std::string str_all = type_name.substr(start, end - start); + size_t start_a = str_all.find_first_of('(') + 1; + size_t end_a = str_all.find_last_of(')'); + if (start_a >= str_all.size()) { + return nullptr; + } + std::string str_args = str_all.substr(start_a, end_a - start_a); + // bypass " " between ")" and retval + start = end_a + 2; + if (start >= str_all.size()) { + return nullptr; + } + std::string str_retval = str_all.substr(start); + + std::vector args_type = StringToVectorOfType(str_args); + TypePtr retval = StringToType(str_retval); + bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); + if (retval == nullptr || wrong) { + return nullptr; + } + type = std::make_shared(args_type, retval); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << type_name << " convert from string error " << e.what(); + } + } + return type; +} +} // namespace + +TypePtr StringToType(const std::string &type_name) { + TypePtr type = nullptr; + if (type_name.compare("None") == 0) { + type = std::make_shared(); + } else if (type_name.compare("Ellipsis") == 0) { + type = std::make_shared(); + } else if (type_name.compare("TypeType") == 0) { + type = std::make_shared(); + } else if (type_name.compare("SymbolicKeyType") == 0) { + type = std::make_shared(); + } else if (type_name.compare("RefKeyType") == 0) { + type = std::make_shared(); + } else if (type_name.compare("EnvType") == 0) { + type = std::make_shared(); + } else if (type_name.compare("Number") == 0) { + type = std::make_shared(); + } else if (type_name.compare("Bool") == 0) { + type = std::make_shared(); + } else if (type_name.compare(0, strlen("Int"), "Int") == 0) { + type = StringToNumberType(type_name, "Int"); + } else if (type_name.compare(0, strlen("UInt"), "UInt") == 0) { + type = StringToNumberType(type_name, "UInt"); + } else if (type_name.compare(0, strlen("Float"), "Float") == 0) { + type = StringToNumberType(type_name, "Float"); + } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) { + type = TensorStrToType(type_name); + } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) { + type = UndeterminedStrToType(type_name); + } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) { + type = IndexedSlicesStrToType(type_name); + } else if (type_name.compare(0, strlen("List"), "List") == 0) { + type = ListStrToType(type_name); + } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) { + type = TupleStrToType(type_name); + } else if (type_name.compare("Slice") == 0) { + type = std::make_shared(); + } else if (type_name.compare("Dictionary") == 0) { + type = std::make_shared(); + } else if (type_name.compare("String") == 0) { + type = std::make_shared(); + } else if (type_name.compare("Problem") == 0) { + type = std::make_shared(); + } else if (type_name.compare(0, strlen("Function"), "Function") == 0) { + type = FunctionStrToType(type_name); + } else { + // - unsupported to convert + // Class + // SymbolicType + // JTagged + // Anything + // External + // Problem + MS_LOG(EXCEPTION) << "Unsupported type name: " << type_name << "!"; + } + return type; +} + +bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) { + if (x == nullptr || base_type == nullptr) { + MS_LOG(ERROR) << "Type is nullptr."; + return false; + } + if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { + return false; + } + if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) { + return true; + } + return false; +} + +bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { + if (x == nullptr || base_type == nullptr) { + MS_LOG(ERROR) << "Type is nullptr."; + return false; + } + if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) { + return false; + } else if (!(base_type->IsGeneric())) { + return *(base_type) == *(x); + } else if (base_type->type_id() == x->type_id()) { + return true; + } else if (base_type->type_id() == x->generic_type_id()) { + return true; + } else if (base_type->type_id() == x->object_type()) { + return true; + } else if (base_type->type_id() == x->meta_type()) { + return true; + } else { + return false; + } +} + +bool IsSubType(TypePtr const &t1, TypePtr const &t2) { + MS_EXCEPTION_IF_NULL(t1); + if (t1->type_id() == kTypeUnknown) { + return false; + } else if (t2 != nullptr) { + return IsIdentidityOrSubclass(t1, t2); + } else { + return true; + } +} + +const TypePtr kTypeExternal = std::make_shared(); +const TypePtr kTypeEnv = std::make_shared(); +const TypePtr kTypeType = std::make_shared(); +const TypePtr kTensorType = std::make_shared(); +const TypePtr kIndexedSlicesType = std::make_shared(); +const TypePtr kUndeterminedType = std::make_shared(); +const TypePtr kString = std::make_shared(); +const TypePtr kList = std::make_shared(); +const TypePtr kTuple = std::make_shared(); +const TypePtr kDict = std::make_shared(); +const TypePtr kSlice = std::make_shared(); +const TypePtr kKeyword = std::make_shared(); +} // namespace mindspore diff --git a/mindspore/core/ir/dtype_py.cc b/mindspore/core/ir/dtype_py.cc new file mode 100644 index 0000000000..66bd8ba5f6 --- /dev/null +++ b/mindspore/core/ir/dtype_py.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/dtype.h" +#include +#include +#include +#include "utils/log_adapter.h" +#include "abstract/abstract_value.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +// Define python wrapper to handle data types. +REGISTER_PYBIND_DEFINE( + typing, ([](py::module *const m) { + auto m_sub = m->def_submodule("typing", "submodule for dtype"); + py::enum_(m_sub, "TypeId"); + (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); + (void)m_sub.def("load_type", &TypeIdToType, "load type"); + (void)m_sub.def( + "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); + (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); + (void)py::class_>(m_sub, "Type") + .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) + .def("__eq__", + [](const TypePtr &t1, const TypePtr &t2) { + if (t1 != nullptr && t2 != nullptr) { + return *t1 == *t2; + } + return false; + }) + .def("__hash__", &Type::hash) + .def("__str__", &Type::ToString) + .def("__repr__", &Type::ReprString) + .def("__deepcopy__", [](const TypePtr &t, py::dict) { + if (t == nullptr) { + return static_cast(nullptr); + } + return t->DeepCopy(); + }); + (void)py::class_>(m_sub, "Number").def(py::init()); + (void)py::class_>(m_sub, "Bool") + .def(py::init()) + .def(py::pickle( + [](const Bool &) { // __getstate__ + return py::make_tuple(); + }, + [](const py::tuple &) { // __setstate__ + return std::make_shared(); + })); + (void)py::class_>(m_sub, "Int") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const Int &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + Int data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "UInt") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const UInt &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + UInt data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "Float") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const Float &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + Float data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "List") + .def(py::init()) + .def(py::init>(), py::arg("elements")); + (void)py::class_>(m_sub, "Tuple") + .def(py::init()) + .def(py::init>(), py::arg("elements")); + (void)py::class_>(m_sub, "TensorType") + .def(py::init()) + .def(py::init(), py::arg("element")) + .def("element_type", &TensorType::element) + .def(py::pickle( + [](const TensorType &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + TensorType data(TypeIdToType(TypeId(static_cast(t[0].cast())))); + return data; + })); + (void)py::class_>(m_sub, "IndexedSlicesType") + .def(py::init()); + (void)py::class_>(m_sub, "UndeterminedType") + .def(py::init()); + (void)py::class_>(m_sub, "Function") + .def(py::init()) + .def(py::init, TypePtr>(), py::arg("args"), py::arg("retval")); + (void)py::class_>(m_sub, "Class").def(py::init()); + (void)py::class_>(m_sub, "SymbolicKeyType").def(py::init()); + (void)py::class_>(m_sub, "EnvType").def(py::init()); + (void)py::class_>(m_sub, "TypeNone").def(py::init()); + (void)py::class_>(m_sub, "TypeType").def(py::init()); + (void)py::class_>(m_sub, "String").def(py::init()); + (void)py::class_>(m_sub, "RefKeyType").def(py::init()); + (void)py::class_>(m_sub, "RefType").def(py::init()); + (void)py::class_>(m_sub, "TypeAnything").def(py::init()); + (void)py::class_>(m_sub, "Slice").def(py::init()); + (void)py::class_>(m_sub, "TypeEllipsis").def(py::init()); + })); +} // namespace mindspore diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc new file mode 100644 index 0000000000..fabdd3e7d3 --- /dev/null +++ b/mindspore/core/ir/func_graph.cc @@ -0,0 +1,628 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/func_graph.h" + +#include +#include +#include + +#include "debug/trace.h" +#include "ir/manager.h" +#include "frontend/operator/ops.h" +#include "utils/ordered_set.h" +#include "utils/convert_utils_base.h" + +namespace mindspore { +/* + * Methods of Graph + */ +FuncGraph::FuncGraph() + : attrs_(), + transforms_(), + parameter_default_value_(), + seen_(0), + parameters_(), + has_vararg_(false), + has_kwarg_(false), + kwonlyargs_count_(0), + hyper_param_count_(0), + is_generated_(false), + return_(nullptr), + manager_(std::weak_ptr()), + stub_(false) { + debug_info_ = std::make_shared(); +} + +AnfNodePtr FuncGraph::output() const { + // If return value is set, return should have two inputs. + if (return_ != nullptr && return_->inputs().size() == 2) { + return return_->input(1); + } else { + // If not set yet, return nullptr. + return nullptr; + } +} + +ParameterPtr FuncGraph::add_parameter() { + FuncGraphPtr this_func_graph = shared_from_base(); + ParameterPtr p = std::make_shared(this_func_graph); + add_parameter(p); + return p; +} + +void FuncGraph::add_parameter(const ParameterPtr &p) { + if (manager_.lock()) { + manager_.lock()->AddParameter(shared_from_base(), p); + } else { + parameters_.push_back(p); + } +} + +ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { + FuncGraphPtr this_graph = shared_from_base(); + ParameterPtr p = std::make_shared(this_graph); + p->set_name(name); + p->debug_info()->set_name(name); + + if (manager_.lock()) { + manager_.lock()->AddParameter(shared_from_base(), p); + } else { + parameters_.push_back(p); + } + hyper_param_count_++; + return p; +} + +bool FuncGraph::has_flag(const std::string &key) { + auto iter = attrs_.find(key); + if (iter != attrs_.cend()) { + if (iter->second->isa()) { + return GetValue(iter->second); + } + MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function."; + } + return false; +} + +bool FuncGraph::has_attr(const std::string &key) { + auto iter = attrs_.find(key); + return !(iter == attrs_.cend()); +} + +ValuePtr FuncGraph::get_attr(const std::string &key) { + auto iter = attrs_.find(key); + return iter == attrs_.cend() ? nullptr : iter->second; +} + +CNodePtr FuncGraph::NewCNode(const std::vector &inputs) { + CNodePtr cnode = std::make_shared(inputs, shared_from_base()); + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + order_.push_back(cnode); + MS_LOG(INFO) << "Graph: " << ToString() << ", push back " << cnode->DebugString() << " in order."; + } + return cnode; +} + +CNodePtr FuncGraph::NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope) { + CNodePtr app = NewCNode(inputs); + app->set_scope(scope); + return app; +} + +void FuncGraph::DumpCNodeList() { + MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; + for (const auto &cnode : order_) { + MS_LOG(INFO) << cnode->DebugString(); + } +} + +std::string FuncGraph::ToString() const { + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); +} + +GraphDebugInfoPtr FuncGraph::debug_info() { + MS_EXCEPTION_IF_NULL(this->debug_info_); + if (this->debug_info_->get_graph() == nullptr) { + this->debug_info_->set_graph(shared_from_base()); + } + return this->debug_info_; +} + +const AnfNodeSet &FuncGraph::nodes() { return nodes_; } + +void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); } + +void FuncGraph::ClearNodes() { nodes_.clear(); } + +void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); } + +void FuncGraph::DropNode(AnfNodePtr node) { + nodes_.erase(node); + auto graph = node->func_graph(); + // Remove the node from order list. + if (graph) { + graph->EraseUnusedNodeInOrder(node); + } +} + +const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } + +void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) { + auto &others = source->value_nodes(); + for (auto it = others.begin(); it != others.end(); it++) { + AddValueNode(it->first, it->second); + } +} + +void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } + +void FuncGraph::AddValueNode(AnfNodePtr node, int count) { + if (value_nodes_.count(node) == 0) { + value_nodes_[node] = count; + } else { + value_nodes_[node] += count; + } +} + +void FuncGraph::DropValueNode(AnfNodePtr node) { + if (value_nodes_.count(node) != 0) { + if (value_nodes_[node] == 1) { + (void)value_nodes_.erase(node); + } else { + value_nodes_[node]--; + if (value_nodes_[node] < 0) { + MS_LOG(EXCEPTION) << "Count of ValueNode '" << node + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } +} + +const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } + +void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) { + auto &others = source->free_variables(); + for (auto it = others.begin(); it != others.end(); it++) { + if (it->first->func_graph().get() != this) { + (void)AddFreeVariable(it->first, it->second); + } + } +} + +void FuncGraph::ClearFreeVariables() { free_variables_.clear(); } + +bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) { + if (free_variables_.count(node) == 0) { + free_variables_[node] = count; + return true; + } else { + free_variables_[node] += count; + return false; + } +} + +bool FuncGraph::DropFreeVariable(AnfNodePtr node) { + if (free_variables_.count(node) != 0) { + if (free_variables_[node] == 1) { + (void)free_variables_.erase(node); + return true; + } else { + free_variables_[node]--; + if (free_variables_[node] < 0) { + MS_LOG(EXCEPTION) << "Count of free variable '" << node + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } + return false; +} + +const BaseRefCounterMap &FuncGraph::free_variables_total() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + auto &fv_total = mng->free_variables_total(); + return fv_total[shared_from_base()]; +} + +std::vector FuncGraph::free_variables_nodes() { + std::vector nodes; + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { + auto key = p.first; + if (utils::isa(key)) { + nodes.push_back(utils::cast(key)); + } + } + + return nodes; +} + +std::vector FuncGraph::free_variables_func_graphs() { + std::vector func_graphs; + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { + auto key = p.first; + if (utils::isa(key)) { + func_graphs.push_back(utils::cast(key)); + } + } + + return func_graphs; +} + +const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; } + +void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) { + auto &others = source->func_graphs_used(); + for (auto it = others.begin(); it != others.end(); it++) { + (void)AddFuncGraphUsed(it->first, it->second); + } + func_graphs_used_.erase(source); +} + +void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); } + +bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) { + if (func_graphs_used_.count(fg) == 0) { + func_graphs_used_[fg] = count; + return true; + } else { + func_graphs_used_[fg] += count; + return false; + } +} + +bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) { + if (func_graphs_used_.count(fg) != 0) { + if (func_graphs_used_[fg] == 1) { + (void)func_graphs_used_.erase(fg); + return true; + } else { + func_graphs_used_[fg]--; + if (func_graphs_used_[fg] < 0) { + MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } + return false; +} + +const FuncGraphSet &FuncGraph::func_graphs_used_total() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + auto &used = mng->func_graphs_used_total(shared_from_base()); + return used; +} + +const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } + +void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) { + auto &others = source->func_graph_cnodes_index(); + for (auto it = others.begin(); it != others.end(); it++) { + // Ignore the user graph who may own itself. + auto fg = it->first->first->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + if (fg.get() != this) { + AddFuncGraphCNodeIndex(it->first, it->second); + } + } +} + +void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); } + +void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) { + if (func_graph_cnodes_index_.count(pair) == 0) { + func_graph_cnodes_index_[pair] = count; + } else { + func_graph_cnodes_index_[pair] += count; + } +} + +void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { + if (func_graph_cnodes_index_.count(pair) != 0) { + if (func_graph_cnodes_index_[pair] == 1) { + (void)func_graph_cnodes_index_.erase(pair); + } else { + func_graph_cnodes_index_[pair]--; + if (func_graph_cnodes_index_[pair] < 0) { + MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } +} + +const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; } + +void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) { + auto &others = source->j_func_graphs(); + for (auto it = others.begin(); it != others.end(); it++) { + AddJFuncGraph(it->first, it->second); + } +} + +void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); } + +void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) { + if (j_func_graphs_.count(fg) == 0) { + j_func_graphs_[fg] = count; + } else { + j_func_graphs_[fg] += count; + } +} + +void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) { + if (j_func_graphs_.count(fg) != 0) { + if (j_func_graphs_[fg] == 1) { + (void)j_func_graphs_.erase(fg); + } else { + j_func_graphs_[fg]--; + if (j_func_graphs_[fg] < 0) { + MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } +} + +FuncGraphPtr FuncGraph::parent() { + // report the bug early. + if (manager_.lock() == nullptr) { + MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() + << " NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->parent(shared_from_base()); +} + +const FuncGraphSet &FuncGraph::children() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->children(shared_from_base()); +} + +const FuncGraphSet &FuncGraph::scope() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->scopes(shared_from_base()); +} + +bool FuncGraph::recursive() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->recursive(shared_from_base()); +} + +std::shared_ptr> FuncGraph::recursive_graphs() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->recursive_graphs(shared_from_base()); +} + +AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { + auto itr = this->parameter_default_value_.find(name); + if (itr == parameter_default_value_.end()) { + return nullptr; + } + auto default_value = itr->second; + if (default_value == nullptr) { + MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist"; + } + if (IsValueNode(default_value)) { + return nullptr; + } + return default_value; +} + +// set the default values +void FuncGraph::SetDefaultValues(const std::vector &name_list, const std::vector &value_list) { + auto all_is_null = + std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode(node); }); + if (value_list.empty()) { + all_is_null = true; + } + for (size_t i = 0; i < name_list.size(); ++i) { + if (!all_is_null) { + this->parameter_default_value_[name_list[i]] = value_list[i]; + } + } +} + +void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } + +size_t FuncGraph::GetDefaultValueCount() { + int null_count = + std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), + [](const std::pair &pair) { return IsValueNode(pair.second); }); + return parameter_default_value_.size() - IntToSize(null_count); +} + +AnfNodePtr FuncGraph::GetVariableArgParameter() { + if (!has_vararg_) { + return nullptr; + } + + if (has_kwarg_) { + if (parameters_.size() < hyper_param_count_ + 2) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 2]; + } + + if (parameters_.size() < hyper_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 1]; +} + +std::string FuncGraph::GetVariableArgName() { + if (!has_vararg_) { + return ""; + } + + if (has_kwarg_) { + if (parameters_.size() < hyper_param_count_ + 2) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 2]->cast()->name(); + } + + if (parameters_.size() < hyper_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast()->name(); +} + +AnfNodePtr FuncGraph::GetVariableKwargParameter() { + if (has_kwarg_) { + if (parameters_.size() < hyper_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 1]; + } + return nullptr; +} + +std::string FuncGraph::GetVariableKwargName() { + if (has_kwarg_) { + if (parameters_.size() < hyper_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast()->name(); + } + return ""; +} + +int FuncGraph::GetPositionalArgsCount() const { + int count = SizeToInt(parameters_.size()); + if (has_kwarg_) { + count--; + } + if (has_vararg_) { + count--; + } + return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); +} + +AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { + for (size_t i = 0; i < parameters_.size(); ++i) { + MS_EXCEPTION_IF_NULL(parameters_[i]); + auto param_cast = parameters_[i]->cast(); + MS_EXCEPTION_IF_NULL(param_cast); + if (param_cast->name() == name) { + return parameters_[i]; + } + } + return nullptr; +} + +void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } + +std::list FuncGraph::GetOrderedCnodes() { + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + MS_LOG(DEBUG) << "Return ordered cnodes."; + return order_; + } else { + auto this_ptr = shared_from_base(); + auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1); + auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1); + + std::list cnodes; + auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); + for (const auto &node : nodes) { + auto cnode = dyn_cast(node); + if (cnode) { + cnodes.push_back(cnode); + } + } + return cnodes; + } +} + +void FuncGraph::EraseUnusedNodeInOrder() { + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + auto mng = manager_.lock(); + if (mng) { + auto &all_nodes = nodes(); + // Erase unused cnode. + for (auto it = order_.begin(); it != order_.end();) { + if (all_nodes.count(*it)) { + (void)it++; + } else { + MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; + it = order_.erase(it); + } + } + } + } +} + +void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { + if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa()) { + order_.remove(n->cast()); + MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; + } +} + +void FuncGraph::CheckOrder() { + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + MS_LOG(DEBUG) << "Check graph " << ToString(); + for (auto it = order_.begin(); it != order_.end(); (void)it++) { + for (const auto &input_node : (*it)->inputs()) { + if (input_node && input_node->isa() && input_node->func_graph() == shared_from_base()) { + // Need to reorder the wrong order node. + auto found = std::find(order_.begin(), it, input_node); + if (found == it) { + DumpCNodeList(); + MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString() + << " doesn't obey the input dependency, " + << "as input " << input_node->DebugString() << " is not ahead of itself."; + } + } + } + } + auto mng = manager_.lock(); + if (mng != nullptr) { + const auto &all_nodes = nodes(); + if (all_nodes.size() != (order_.size() + parameters_.size())) { + DumpCNodeList(); + MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " + << all_nodes.size() - parameters_.size() << "."; + } + } + MS_LOG(DEBUG) << "Check order okay."; + } +} + +size_t NewFgSeenGeneration() { + static size_t fg_seen_generation = 0; + return ++fg_seen_generation; +} + +const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared("FuncGraph"); +const char kFuncGraphFlagUndetermined[] = "Undeterminate"; +} // namespace mindspore diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h new file mode 100644 index 0000000000..712c75b431 --- /dev/null +++ b/mindspore/core/ir/func_graph.h @@ -0,0 +1,423 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ +#define MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/manager.h" +#include "utils/ordered_set.h" +#include "utils/ordered_map.h" +#include "utils/base_ref.h" + +namespace mindspore { +using BaseRefCounterMap = OrderedMap; +using FuncGraphCounterMap = OrderedMap; + +struct CNodeIndexHasher { + std::size_t operator()(const CNodeIndexPairPtr pair) const { + MS_EXCEPTION_IF_NULL(pair); + MS_EXCEPTION_IF_NULL(pair->first); + return hash_combine(pair->first->hash(), std::hash()(pair->second)); + } +}; + +struct CNodeIndexEqual { + bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const { + if (lhs == nullptr || rhs == nullptr) { + return false; + } + if (lhs == rhs) { + return true; + } + if (lhs->first != rhs->first) { + return false; + } + if (lhs->second != rhs->second) { + return false; + } + return true; + } +}; + +template , class CounterEqual = std::equal_to> +using CounterOrderedMap = OrderedMap; +using AnfNodeCounterMap = CounterOrderedMap; +using CNodeIndexCounterMap = CounterOrderedMap; + +using FuncGraphMap = OrderedMap; + +const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; +const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; +const char FUNC_GRAPH_FLAG_CORE[] = "core"; +const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; +const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; + +namespace abstract { +class AbstractKeywordArg; +using AbstractKeywordArgPtr = std::shared_ptr; +class AbstractFunction; +using AbstractFunctionPtr = std::shared_ptr; +} // namespace abstract + +// ANF transform class +// either a primitive or a func_graph +class FuncGraphTransform { + public: + enum Type { kGtPrimitive, kGtFuncGraph }; + + explicit FuncGraphTransform(const PrimitivePtr prim, const FuncGraphPtr func_graph = nullptr) + : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {} + + explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_) + : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {} + + FuncGraphTransform(const FuncGraphTransform &t) : prim_(t.prim_), func_graph_(t.func_graph_) {} + + ~FuncGraphTransform() = default; + + Type type() const { + if (IsFuncGraph()) { + return kGtFuncGraph; + } else { + return kGtPrimitive; + } + } + + bool IsPrimitive() const { return (func_graph_.lock() == nullptr); } + bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); } + FuncGraphPtr func_graph() const { return func_graph_.lock(); } + PrimitivePtr primitive() const { return prim_; } + + FuncGraphTransform &operator=(const FuncGraphTransform &t) { + if (this != &t) { + prim_ = t.prim_; + func_graph_ = t.func_graph_; + } + return *this; + } + + private: + PrimitivePtr prim_; + // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here. + // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in + // FPropRemapper::FinalizeGraph(). + FuncGraphWeakPtr func_graph_; + static const PrimitivePtr func_graph_prim_; +}; + +class FuncGraphBase : public Value { + public: + FuncGraphBase() = default; + + ~FuncGraphBase() override = default; + MS_DECLARE_PARENT(FuncGraphBase, Value); +}; + +extern const char kFuncGraphFlagUndetermined[]; + +class FuncGraph : public FuncGraphBase { + public: + FuncGraph(); + + ~FuncGraph() override = default; + MS_DECLARE_PARENT(FuncGraph, FuncGraphBase); + + // get the graph's abstract + abstract::AbstractFunctionPtr abstract(); + + // return the graph's output, or nullptr if not yet deduced + AnfNodePtr output() const; + void set_output(const AnfNodePtr &value, bool force_new_ret = false); + + const std::vector ¶meters() const { return parameters_; } + virtual ParameterPtr add_parameter(); + void add_parameter(const ParameterPtr &p); + void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); } + void set_parameters(const std::vector ¶ms) { parameters_ = params; } + // add a weight parameter with specific name + ParameterPtr AddWeightParameter(const std::string &name); + + // create a cnode with given inputs, bound to this graph + virtual CNodePtr NewCNode(const std::vector &inputs = std::vector()); + + // create a cnode with given inputs, bound to this graph, and set to specific scope + CNodePtr NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope); + + // Functions for handling variable argument, keyword-only arguments and variable keyword argument + AnfNodePtr GetDefaultValueByName(const std::string &name); + void set_param_default_value(const std::string &name, const AnfNodePtr &node) { + parameter_default_value_[name] = node; + } + void SetDefaultValues(const std::vector &name_list, const std::vector &value_list); + void ClearDefaultValues(); + size_t GetDefaultValueCount(); + std::map ¶meter_default_value() { return parameter_default_value_; } + void set_has_vararg(bool has_) { has_vararg_ = has_; } + bool has_vararg() const { return has_vararg_; } + AnfNodePtr GetVariableArgParameter(); + std::string GetVariableArgName(); + void set_has_kwarg(bool has_) { has_kwarg_ = has_; } + bool has_kwarg() const { return has_kwarg_; } + void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; } + int kwonlyargs_count() const { return kwonlyargs_count_; } + AnfNodePtr GetVariableKwargParameter(); + std::string GetVariableKwargName(); + void set_hyper_param_count(size_t count) { hyper_param_count_ = count; } + size_t hyper_param_count() const { return hyper_param_count_; } + int GetPositionalArgsCount() const; + AnfNodePtr GetParameterByName(const std::string &name); + bool NeedGenerate(const std::vector &kwarg_list); + FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list); + void set_is_generate(bool generated) { is_generated_ = generated; } + bool is_generated() const { return is_generated_; } + + std::unordered_map &attrs() { return attrs_; } + void set_attrs(const std::unordered_map &attrs) { + for (auto &attr : attrs) { + attrs_[attr.first] = attr.second; + } + } + bool has_flag(const std::string &key); + void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); } + void erase_flag(const std::string &key) { (void)attrs_.erase(key); } + + bool has_attr(const std::string &key); + ValuePtr get_attr(const std::string &key); + void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } + + std::unordered_map &transforms() { return transforms_; } + void set_transforms(const std::unordered_map &transforms) { + transforms_ = transforms; + } + + CNodePtr get_return() const { return return_; } + void set_return(const CNodePtr &cnode) { return_ = cnode; } + + FuncGraphManagerPtr manager() const { return manager_.lock(); } + void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr(m); } + + std::string ToString() const override; + GraphDebugInfoPtr debug_info(); + void set_debug_info(const GraphDebugInfoPtr &info) { + if (info == nullptr) { + MS_LOG(EXCEPTION) << "Graph set null debug info"; + } + this->debug_info_ = info; + } + + // get all nodes belonging to this func graph + const AnfNodeSet &nodes(); + void CopyNodes(const FuncGraphPtr &source); + void ClearNodes(); + void AddNode(AnfNodePtr node); + void DropNode(AnfNodePtr node); + + // get all value_nodes belonging to this func graph + const AnfNodeCounterMap &value_nodes(); + void CopyValueNodes(const FuncGraphPtr &source); + void ClearValueNodes(); + void AddValueNode(AnfNodePtr node, int count = 1); + void DropValueNode(AnfNodePtr node); + + // get all free vars directly used in this func graph + const AnfNodeCounterMap &free_variables(); + void CopyFreeVariables(const FuncGraphPtr &source); + void ClearFreeVariables(); + bool AddFreeVariable(AnfNodePtr node, int count = 1); + bool DropFreeVariable(AnfNodePtr node); + + // get all vars required by this func graph + const BaseRefCounterMap &free_variables_total(); + + // Return the set of graphs free_variables_total belong to. + std::vector free_variables_nodes(); + + // get all vars that are func graphs + std::vector free_variables_func_graphs(); + + // get all value nodes of func graph directly used by this func graph + const FuncGraphCounterMap &func_graphs_used(); + void CopyFuncGraphsUsed(const FuncGraphPtr &source); + void ClearFuncGraphsUsed(); + bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1); + bool DropFuncGraphUsed(FuncGraphPtr fg); + + // get all value nodes of J func graph directly used by this func graph + const FuncGraphCounterMap &j_func_graphs(); + void CopyJFuncGraphs(const FuncGraphPtr &source); + void ClearJFuncGraphs(); + void AddJFuncGraph(FuncGraphPtr fg, int count = 1); + void DropJFuncGraph(FuncGraphPtr fg); + + // get all func graphs nested used by this func graph + const FuncGraphSet &func_graphs_used_total(); + + // get all user value nodes of this func graph, by CNode and its input's index + const CNodeIndexCounterMap &func_graph_cnodes_index(); + void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source); + void ClearFuncGraphCNodesIndex(); + void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1); + void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node); + + // Return the parent of this graph. + FuncGraphPtr parent(); + + // Return the children of this graph. + const FuncGraphSet &children(); + + // Return the scope of this graph, scope have graph self but children not have. + const FuncGraphSet &scope(); + + // Return whether this graph is recursive + bool recursive(); + + // Return graphs which forms a recursive loop + std::shared_ptr> recursive_graphs(); + + std::size_t hash() const override { return std::hash{}(this); } + + void DumpFuncGraph(const std::string &path = "./func_graph.dot"); + + bool operator==(const Value &other) const override { + if (other.isa()) { + return &other == this; + } else { + return false; + } + } + void GenerateVarParams(const FuncGraphPtr &specialized_graph, std::vector *specialized_parameter_list, + std::unordered_map *repl_nodes, int variable_args_count, + int pos_args_input_count); + + void GenerateKwParams(const FuncGraphPtr &specialized_graph, std::vector *specialized_parameter_list, + const std::vector &kwarg_list, + std::unordered_map *repl_nodes); + + void GenerateDefaultValue(const FuncGraphPtr &specialized_graph, + const std::vector &specialized_parameter_list, + std::unordered_map *repl_nodes); + + const std::vector ¶mter_obj_nodes() const { return paramter_obj_nodes_; } + void add_parameter_obj_node(const AnfNodePtr &p); + + std::unordered_map &make_ref_params() { return make_ref_params_; } + + std::unordered_map attrs_; + std::unordered_map transforms_; + // parameter default value + std::map parameter_default_value_; + std::unordered_map make_ref_params_; + size_t seen_; + + std::list GetOrderedCnodes(); + void EraseUnusedNodeInOrder(const AnfNodePtr &n); + void EraseUnusedNodeInOrder(); + void CheckOrder(); + void DumpCNodeList(); + void ReleaseFullOrderToEffectOrder(); + void SetEffectDepends(const std::vector &depend_inputs); + bool HasEffect(const CNodePtr &cnode); + + bool stub() const { return stub_; } + void set_stub(bool stub) { stub_ = stub; } + + private: + // graph is manipulated by manager and others + friend FuncGraphManager; + + // all nodes of the function + AnfNodeSet nodes_; + + // all value nodes of the function + AnfNodeCounterMap value_nodes_; + + // all func graph value nodes of the function + FuncGraphCounterMap func_graphs_used_; + + // all free variables of the function + AnfNodeCounterMap free_variables_; + + // all value nodes calling J in the function + FuncGraphCounterMap j_func_graphs_; + + // all user value nodes of this func graph, recording by CNode and its input's index + CNodeIndexCounterMap func_graph_cnodes_index_; + + // parameters of this function + std::vector parameters_; + std::vector paramter_obj_nodes_; + + // whether there is a *args and **kwargs, and count kwonlyargs'number + bool has_vararg_; + bool has_kwarg_; + int kwonlyargs_count_; + // the hyper param is placed on the top graph, + // and positioned in the end of the param list, so we record the number to trace the position + size_t hyper_param_count_; + // the argument input list for the graph used to generate this graph + bool is_generated_; + + // the cnode that calls 'return' primitive + // we use shared pointer to manage it. + CNodePtr return_; + + // back-ref to its manager + // hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph. + // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles. + // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs. + // In some ut test cases, they may use local FuncGraphManager in function which + // generating the func graph, when go outside of that function, func graph will have no + // FuncGraphManager. In that special case, Manage() should be called to make the func graph + // managed. + std::weak_ptr manager_; + + GraphDebugInfoPtr debug_info_; + void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, + std::unordered_map *repl_nodes, + const std::vector &kwarg_keys_tuple_nodes, + const std::vector &kwarg_values_tuple_nodes); + + // CNode order which relates to origin code order + std::list order_; + bool stub_; +}; + +inline CNodePtr NewCNode(const std::vector &inputs, const FuncGraphPtr &fg) { + MS_EXCEPTION_IF_NULL(fg); + return fg->NewCNode(inputs); +} + +size_t NewFgSeenGeneration(); + +// Find the root cnodes of a segment of cnodes. +std::shared_ptr> FindRoots(const std::vector &segment); +// Find the leaf cnodes of a segment of cnodes. +std::shared_ptr> FindLeaves(const std::vector &segment); +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_H_ diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc new file mode 100644 index 0000000000..0857770cad --- /dev/null +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -0,0 +1,650 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/func_graph_cloner.h" + +#include + +#include "ir/manager.h" +#include "ir/param_value.h" +#include "frontend/operator/ops.h" +#include "utils/convert_utils_base.h" +#include "utils/log_adapter.h" +#include "utils/profile.h" +#include "utils/context/ms_context.h" + +// namespace to support intermediate representation definition +namespace mindspore { +Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, + bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) + : clone_all_valuenodes_(clone_all_valuenodes), + clone_all_child_graphs_(clone_all_child_graphs), + clone_all_used_graphs_(clone_all_used_graphs), + relation_(relation), + target_relation_(target_relation == nullptr ? relation : target_relation) { + for (auto &func_graph : func_graphs) { + AddClone(func_graph); + } + scope_ = kDefaultScope; + type_ = kBasic; +} + +void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList ¶ms, CloneType type) { + if (func_graph != nullptr) { + todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); + type_ = type; + } +} + +void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { + MS_EXCEPTION_IF_NULL(node); + if (repl_node_.find(node) != repl_node_.end() || node->isa()) { + return; + } + if (node->isa()) { + CloneParameter(node, target); + } else if (node->isa()) { + CloneCNode(node, target); + } +} + +void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(target); + TraceManager::DebugTrace(node->debug_info(), relation_); + auto new_param = (is_add) ? target->add_parameter() : std::make_shared(target); + auto old_param = node->cast(); + new_param->set_abstract(old_param->abstract()); + new_param->set_name(old_param->name()); + if (old_param->has_default()) { + // Default parameter can be shared since it is readonly. + new_param->set_default_param(old_param->default_param()); + } + ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); + new_param->set_scope(scope); + repl_node_[node] = new_param; + TraceManager::EndTrace(); +} + +void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(target); + TraceManager::DebugTrace(node->debug_info(), relation_); + CNodePtr new_node = std::make_shared(AnfNodePtrList{}, target); + auto old_node = node->cast(); + new_node->set_abstract(old_node->abstract()); + ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); + new_node->set_scope(scope); + new_node->set_kernel_info(old_node->kernel_info_ptr()); + repl_node_[old_node] = new_node; + nodes_.emplace_back(old_node, new_node); + TraceManager::EndTrace(); +} + +void Cloner::CloneValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + TraceManager::DebugTrace(node->debug_info(), relation_); + ValueNodePtr new_const = NewValueNode(GetValueNode(node)); + ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); + new_const->set_scope(scope); + new_const->set_abstract(node->abstract()); + repl_node_[node] = new_const; + TraceManager::EndTrace(); +} + +void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(target); + TraceManager::DebugTrace(node->debug_info(), relation_); + ValueNodePtr new_const = NewValueNode(target); + ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); + new_const->set_scope(scope); + new_const->set_abstract(node->abstract()); + repl_node_[node] = new_const; + TraceManager::EndTrace(); +} + +void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(manager_); + if (!clone_all_valuenodes_) { + return; + } + auto &value_nodes = func_graph->value_nodes(); + for (auto &value_node : value_nodes) { + auto old_node = value_node.first; + MS_EXCEPTION_IF_NULL(old_node); + if (repl_node_.count(old_node) == 0) { + CloneValueNode(old_node); + } + } +} + +void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(manager_); + if (!clone_all_child_graphs_) { + return; + } + auto &scopes = manager_->scopes(func_graph); + for (auto &graph : scopes) { + if (graph != func_graph) { + todo_.push_back({graph, nullptr, {}}); + } + } +} + +void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(manager_); + if (!clone_all_used_graphs_) { + return; + } + auto &used = func_graph->func_graphs_used(); + for (auto &fg : used) { + todo_.push_back({fg.first, nullptr, {}}); + } +} + +void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + for (auto &item : func_graph->parameter_default_value()) { + auto nodes = DeepLinkedGraphSearch(item.second); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + CloneNode(node, target_func_graph); + } else if (node->isa()) { + CloneValueNode(node); + } + } + } +} + +void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + MS_EXCEPTION_IF_NULL(manager_); + auto return_node = repl_node_[func_graph->get_return()]->cast(); + if (return_node == nullptr) { + MS_LOG(EXCEPTION) << "Can't find replicate node for return."; + } + target_func_graph->set_return(return_node); + + auto &cnodes = func_graph->func_graph_cnodes_index(); + for (auto &cnode : cnodes) { + auto parent = cnode.first->first->cast(); + auto valuenode = parent->input(cnode.first->second); + CloneValueNode(valuenode, target_func_graph); + } +} + +void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { + MS_EXCEPTION_IF_NULL(func_graph); + auto &old_params = func_graph->parameters(); + if (old_params.size() != params.size()) { + MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; + return; + } + for (size_t i = 0; i < old_params.size(); ++i) { + repl_node_[old_params[i]] = params[i]; + } +} + +void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); + *target_func_graph = std::make_shared(); + (*target_func_graph)->set_attrs(func_graph->attrs()); + (*target_func_graph)->set_transforms(func_graph->transforms()); + (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); + (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); + (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count()); + (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); + (*target_func_graph)->set_is_generate(func_graph->is_generated()); + (*target_func_graph)->set_stub(func_graph->stub()); + TraceManager::EndTrace(); +} + +void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + auto ¶ms = func_graph->parameters(); + for (auto ¶m : params) { + CloneParameter(param, target_func_graph, true); + } + repl_func_graph_[func_graph] = target_func_graph; +} + +void Cloner::GenParameters(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto &free_vars = manager_->free_variables_total(); + auto iter = free_vars.find(func_graph); + if (iter == free_vars.end()) { + return; + } + + for (auto &fv_map : iter->second) { + auto &free_var = fv_map.first; + if (utils::isa(free_var)) { + repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast(free_var))); + } + } +} + +void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { + param->set_abstract(node->abstract()); + if (node->isa()) { + ParameterPtr old_param = dyn_cast(node); + if (old_param->has_default()) { + // Default parameter can be shared since it is readonly. + param->set_default_param(old_param->default_param()); + } + param->set_name(old_param->name()); + } +} + +ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { + TraceManager::DebugTrace(std::make_shared(node->debug_info())); + ParameterPtr param = std::make_shared(func_graph); + TraceManager::EndTrace(); + CloneParameter(param, node); + if (is_add) { + func_graph->add_parameter(param); + } + repl_node_[param] = node; + repl_map_node_[func_graph][node] = param; + return param; +} + +void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, + AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { + AnfNodePtrList parameters; + std::unordered_set old_params; + for (auto ¶m : func_graph->parameters()) { + auto iter = repl_node_.find(param); + if (iter != repl_node_.end()) { + (void)old_params.insert(iter->second); + parameters.push_back(param); + } else { + parameters.push_back(AddParameter(func_graph, param, false)); + (void)old_params.insert(param); + } + } + AnfNodePtr new_param = nullptr; + for (auto ¶m : params) { + auto old_param = repl_node_[param]; + if (old_param->isa() && old_param->func_graph() == func_graph) { + repl_node_[old_param] = old_param; + repl_map_node_[func_graph][old_param] = old_param; + input_params->push_back(old_param); + continue; + } + if (old_params.find(old_param) != old_params.end()) { + new_param = repl_map_node_[func_graph][old_param]; + input_params->push_back(new_param); + continue; + } + new_param = AddParameter(func_graph, old_param, false); + parameters.push_back(new_param); + lift_params->push_back(new_param); + input_params->push_back(new_param); + } + func_graph->set_parameters(parameters); +} + +void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { + AnfNodePtr node = nullptr; + auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; + auto iter = repl_func_graph.find(func_graph); + if (iter == repl_func_graph.end()) { + node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); + repl_func_graph[func_graph] = node; + } else { + node = iter->second; + } + if (node == nullptr || !node->isa()) { + return; + } + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); + cnode->set_inputs(inputs); + OrderParameters(func_graph, inputs); +} + +void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { + std::unordered_set old_params; + for (auto ¶m : func_graph->parameters()) { + (void)old_params.insert(repl_node_[param]); + } + std::unordered_set new_params; + AnfNodePtrList parameters; + // Ignore the 1st and 2nd param of inputs(such as. partial graph) + for (size_t i = 2; i < inputs.size(); ++i) { + auto input = inputs[i]; + auto param = repl_node_[input]; + if (old_params.find(param) != old_params.end()) { + auto new_param = repl_map_node_[func_graph][param]; + parameters.push_back(new_param); + (void)new_params.insert(new_param); + } + } + for (auto ¶m : func_graph->parameters()) { + if (new_params.find(param) == new_params.end()) { + parameters.push_back(param); + } + } + func_graph->set_parameters(parameters); +} + +void Cloner::SetEdges(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + for (auto &node : func_graph->nodes()) { + if (node == nullptr) { + continue; + } + // Only cnode needed to be handled + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + auto &input = inputs[i]; + if (IsValueNode(input)) { + auto graph = GetValueNode(input); + auto &repl_func_graph = repl_map_func_graph_[func_graph]; + if (repl_func_graph.find(graph) != repl_func_graph.end()) { + transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); + } + } else { + auto &repl_node = repl_map_node_[func_graph]; + if (repl_node.find(input) != repl_node.end()) { + transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); + } + } + } + } +} + +void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { + AnfNodePtrList lift_params; + AnfNodePtrList input_params; + AddParameters(func_graph_user, params, &lift_params, &input_params); + AddInputs(func_graph_user, func_graph, input_params); + if (lift_params.empty()) { + return; + } + for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); + } +} + +void Cloner::Lift() { + for (auto &func_graph_params : repl_func_graph_params_) { + auto &func_graph = func_graph_params.first; + auto ¶ms = func_graph_params.second; + for (auto &cnode : func_graph->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph, params); + } + } +} + +void Cloner::LiftParameters() { + MS_EXCEPTION_IF_NULL(manager_); + transaction_ = manager_->Transact(); + const FuncGraphSet &func_graphs = manager_->func_graphs(); + for (auto &func_graph : func_graphs) { + GenParameters(func_graph); + } + Lift(); + for (auto &func_graph : func_graphs) { + SetEdges(func_graph); + } + transaction_.Commit(); +} + +bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { + MS_EXCEPTION_IF_NULL(func_graph); + // Make sure only inline once + if (status_.count(func_graph) != 0) { + if (is_inline == status_[func_graph]) { + return false; + } + if (clone_all_used_graphs_) { + MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False."; + return false; + } + } + return true; +} + +void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + MS_EXCEPTION_IF_NULL(manager_); + const AnfNodeSet &nodes = func_graph->nodes(); + for (auto &node : nodes) { + CloneNode(node, target_func_graph); + } +} + +void Cloner::Run() { + if (todo_.empty()) { + return; + } + + if (type_ < kLifting) { + // Basic and Inline Clone + FuncGraphPtrList func_graphs; + (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), + [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); + manager_ = Manage(func_graphs, false); + CloneNodes(); + LinkEdges(); + SetDefaults(); + } else { + // Lifting Clone + CloneInfo item = todo_.back(); + manager_ = Manage(item.origin); + LiftParameters(); + } +} + +void Cloner::CloneNodes() { + while (!todo_.empty()) { + CloneInfo item = todo_.back(); + todo_.pop_back(); + + bool is_inline = (item.target != nullptr); + FuncGraphPtr func_graph = item.origin; + FuncGraphPtr target_func_graph = item.target; + (void)graph_set_.insert(func_graph); + + if (!CheckStatus(func_graph, is_inline)) { + continue; + } + + if (is_inline) { + InlineCloneParameters(func_graph, item.params); + CloneAllNodes(func_graph, target_func_graph); + } else { + SetFuncGraphInfo(func_graph, &target_func_graph); + CloneParameters(func_graph, target_func_graph); + CloneAllNodes(func_graph, target_func_graph); + CloneFuncGraphValueNodes(func_graph, target_func_graph); + CloneFuncGraphDefaultValues(func_graph, target_func_graph); + } + + CloneValueNodes(func_graph); + AddChildGraphs(func_graph); + AddTotalGraphs(func_graph); + status_[func_graph] = is_inline; + } +} + +void Cloner::LinkEdges() { + for (auto &node_pair : nodes_) { + CNodePtr old_node = node_pair.first; + CNodePtr new_node = node_pair.second; + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + for (auto &input : old_node->inputs()) { + auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; + new_node->add_input(new_input); + } + } +} + +// For the graphs cloned, update its default value map to the cloned nodes +void Cloner::SetDefaults() { + for (auto &item : graph_set_) { + MS_EXCEPTION_IF_NULL(item); + if (repl_func_graph_.count(item) != 0) { + for (auto ¶m_def : item->parameter_default_value()) { + MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); + if (repl_node_.count(param_def.second) != 0) { + repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); + } else { + repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second); + } + } + } + } +} + +AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { + MS_EXCEPTION_IF_NULL(root); + if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { + MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; + } + CloneNode(root, repl_func_graph_[root->func_graph()]); + auto iter = repl_node_.find(root); + if (iter != repl_node_.end()) { + return iter->second; + } + MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; +} + +AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { +#ifdef ENABLE_PROFILE + double time = GetTime(); +#endif + Run(); +#ifdef ENABLE_PROFILE + MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - time); +#endif + return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); +} + +FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { +#ifdef ENABLE_PROFILE + double time = GetTime(); +#endif + Run(); +#ifdef ENABLE_PROFILE + MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - time); +#endif + return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]); +} + +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + Cloner cloner({func_graph}, false, true, true, std::make_shared(), nullptr); + return cloner[func_graph]; +} + +AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + Cloner cloner({}, false); + if (scope != nullptr) { + cloner.set_scope(scope); + } + cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline); + return cloner[func_graph->output()]; +} + +FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + Cloner cloner({}, false); + cloner.AddClone(func_graph, nullptr, {}, kLifting); + return cloner[func_graph]; +} + +ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphPtrList func_graphs = {func_graph}; + ClonerPtr cloner = + std::make_shared(func_graphs, false, false, false, std::make_shared(), relation); +#ifdef ENABLE_PROFILE + double time = GetTime(); +#endif + cloner->Run(); +#ifdef ENABLE_PROFILE + MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time); +#endif + return cloner; +} + +FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { + MS_EXCEPTION_IF_NULL(func_graph); + TraceManager::DebugTrace(func_graph->debug_info(), relation); + auto new_func_graph = std::make_shared(); + TraceManager::EndTrace(); + + auto ¶meters = func_graph->parameters(); + (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { + MS_EXCEPTION_IF_NULL(param); + TraceManager::DebugTrace(std::make_shared(param->debug_info())); + (void)new_func_graph->add_parameter(); + TraceManager::EndTrace(); + }); + + Cloner cloner = Cloner(); + cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters()); + AnfNodePtr output = cloner[func_graph->output()]; + new_func_graph->set_output(output); + new_func_graph->set_has_vararg(func_graph->has_vararg()); + new_func_graph->set_has_kwarg(func_graph->has_kwarg()); + new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); + new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); + new_func_graph->set_is_generate(func_graph->is_generated()); + new_func_graph->set_stub(func_graph->stub()); + for (auto &item : func_graph->parameter_default_value()) { + new_func_graph->set_param_default_value(item.first, cloner[item.second]); + } + + if (MsContext::GetInstance()->is_multi_graph_sink()) { + if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + } + } + + if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + } + + return new_func_graph; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph_cloner.h b/mindspore/core/ir/func_graph_cloner.h similarity index 100% rename from mindspore/ccsrc/ir/func_graph_cloner.h rename to mindspore/core/ir/func_graph_cloner.h diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc new file mode 100644 index 0000000000..579409b05e --- /dev/null +++ b/mindspore/core/ir/func_graph_extends.cc @@ -0,0 +1,411 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/func_graph.h" + +#include +#include +#include + +#include "ir/manager.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" +#include "utils/ordered_set.h" +#include "abstract/abstract_value.h" +#include "debug/anf_ir_dump.h" +#include "debug/trace.h" +#include "debug/draw.h" +#include "debug/label.h" + +namespace mindspore { +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractFunctionPtr; +using mindspore::abstract::AnalysisContextPtr; +using mindspore::abstract::PrimitiveAbstractClosure; +using mindspore::abstract::VirtualAbstractClosure; + +AbstractFunctionPtr FuncGraph::abstract() { + AbstractBasePtrList args_spec_list; + + for (auto &p : parameters_) { + MS_EXCEPTION_IF_NULL(p); + if (p->abstract() == nullptr) { + MS_LOG(ERROR) << "Error!!"; + return nullptr; + } + args_spec_list.push_back(p->abstract()); + } + + if (nullptr == output()) { + MS_LOG(ERROR) << "Error func graph no output"; + return nullptr; + } + + return std::make_shared(args_spec_list, output()->abstract()); +} + +void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { + if (force_new_ret || return_ == nullptr) { + std::vector params({NewValueNode(prim::kPrimReturn), value}); + FuncGraphPtr this_graph = shared_from_base(); + return_ = this_graph->NewCNode(params); + } else { + if (manager_.lock()) { + manager_.lock()->SetEdge(return_, 1, value); + } else { + return_->set_input(1, value); + } + } + + return_->set_abstract(value->abstract()); + + AnfNodePtr input0 = return_->input(0); + + PrimitivePtr return_prim = prim::kPrimReturn; + auto f = std::make_shared(return_prim, input0); + input0->set_abstract(f); +} + +void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base()); } + +void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + std::unordered_map *repl_nodes, int variable_args_count, + int pos_args_input_count) { + // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple + if (specialized_graph->has_vararg()) { + TraceManager::DebugTrace( + std::make_shared(specialized_graph->GetVariableArgParameter()->debug_info())); + std::vector var_param_tuple_nodes; + var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); + + if (variable_args_count < 0) { + MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count + << " were given."; + } + // for python variable argument input , there is no upper limit + for (int i = 0; i < variable_args_count; ++i) { + ParameterPtr p = std::make_shared(specialized_graph); + std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i); + p->set_name(param_name); + MS_EXCEPTION_IF_NULL(p->debug_info()); + p->debug_info()->set_name(param_name); + var_param_tuple_nodes.push_back(p); + MS_EXCEPTION_IF_NULL(specialized_parameter_list); + specialized_parameter_list->push_back(p); + } + auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes); + (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param); + TraceManager::EndTrace(); + } else if (variable_args_count > 0) { + MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount() + << " positional arguments, but " << pos_args_input_count << " were given."; + } +} + +void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + const std::vector &kwarg_list, + std::unordered_map *repl_nodes) { + std::vector kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; + std::vector kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; + + for (const auto &kwarg : kwarg_list) { + MS_EXCEPTION_IF_NULL(kwarg); + std::string kw_param_name = kwarg->get_key(); + MS_EXCEPTION_IF_NULL(specialized_graph); + AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name); + // if not find correspoding parameter node + if (param_node == nullptr) { + if (!has_kwarg()) { + MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name; + } else { + ParameterPtr p = std::make_shared(specialized_graph); + std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; + MS_EXCEPTION_IF_NULL(specialized_parameter_list); + auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), + [param_name](const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto param = node->cast(); + return param != nullptr && param->name() == param_name; + }); + if (find_kw_arg_in_list) { + MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name; + } + p->set_name(param_name); + p->debug_info()->set_name(param_name); + kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name)); + auto extract_node = + specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), p}); + kwarg_values_tuple_nodes.push_back(extract_node); + specialized_parameter_list->push_back(p); + } + } else { + auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); + // multiply values found given for parameter + if (node_itr != specialized_parameter_list->end()) { + MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name; + } else { + specialized_parameter_list->push_back(param_node); + auto extract_node = specialized_graph->NewCNode( + {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); + (void)repl_nodes->emplace(param_node, extract_node); + } + } + } + + GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); +} + +void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, + std::unordered_map *repl_nodes, + const std::vector &kwarg_keys_tuple_nodes, + const std::vector &kwarg_values_tuple_nodes) { + if (has_kwarg()) { + MS_EXCEPTION_IF_NULL(specialized_graph); + TraceManager::DebugTrace( + std::make_shared(specialized_graph->GetVariableKwargParameter()->debug_info())); + auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes); + auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes); + auto make_dict_node = + specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values}); + MS_EXCEPTION_IF_NULL(repl_nodes); + (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node); + TraceManager::EndTrace(); + } +} + +bool FuncGraph::NeedGenerate(const std::vector &kwarg_list) { + // if the function does not have any vararg/kwarg/kwonly/default value/kw args input + // return the original graph + if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { + return false; + } + + // if the graph is generated for specific input, do not need to generate again + if (is_generated()) { + return false; + } + return true; +} + +void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, + const std::vector &specialized_parameter_list, + std::unordered_map *repl_nodes) { + MS_EXCEPTION_IF_NULL(specialized_graph); + for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { + auto param_node = specialized_graph->parameters()[i]; + MS_EXCEPTION_IF_NULL(param_node); + auto param_name = param_node->cast()->name(); + auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node); + if (node_itr != specialized_parameter_list.end()) { + continue; + } + if (param_name == specialized_graph->GetVariableArgName() || + param_name == specialized_graph->GetVariableKwargName()) { + continue; + } + auto default_value = specialized_graph->GetDefaultValueByName(param_name); + if (default_value == nullptr) { + MS_LOG(EXCEPTION) << "Miss argument input for parameter:" << param_name; + } + MS_EXCEPTION_IF_NULL(repl_nodes); + (void)repl_nodes->emplace(param_node, default_value); + } +} + +FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { + std::vector kwarg_list; + size_t arguments_count = args_spec_list.size(); + for (const auto &arg : args_spec_list) { + // if it is a keyword argument + MS_EXCEPTION_IF_NULL(arg); + if (arg->isa()) { + kwarg_list.push_back(dyn_cast(arg)); + } + } + if (!NeedGenerate(kwarg_list)) { + return shared_from_base(); + } + FuncGraphPtr specialized_graph = BasicClone(shared_from_base()); + size_t kwarg_count = kwarg_list.size(); + int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count()); + int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); + int variable_args_count = pos_args_input_count - pos_args_count; + std::vector specialized_parameter_list; + std::unordered_map repl_nodes; + // the parameters that has arg input, copy from original parameters + for (size_t i = 0; i < IntToSize(pos_args_count); ++i) { + specialized_parameter_list.push_back(specialized_graph->parameters()[i]); + } + + GenerateVarParams(specialized_graph, &specialized_parameter_list, &repl_nodes, variable_args_count, + pos_args_input_count); + + GenerateKwParams(specialized_graph, &specialized_parameter_list, kwarg_list, &repl_nodes); + + GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes); + + // append hyper parameter to specialized_parameter_list + MS_EXCEPTION_IF_NULL(specialized_graph); + auto params = specialized_graph->parameters(); + (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), + std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); + + std::shared_ptr manager = mindspore::Manage(specialized_graph, false); + auto tr = manager->Transact(); + for (auto &node_pair : repl_nodes) { + MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" + << node_pair.second->DebugString(); + (void)tr.Replace(node_pair.first, node_pair.second); + } + tr.SetParameters(specialized_graph, specialized_parameter_list); + tr.Commit(); + specialized_graph->set_has_kwarg(false); + specialized_graph->set_has_vararg(false); + specialized_graph->set_kwonlyargs_count(0); + specialized_graph->ClearDefaultValues(); + specialized_graph->set_is_generate(true); + return specialized_graph; +} + +const char kPrimHasEffect[] = "_side_effect_flag"; + +bool FuncGraph::HasEffect(const CNodePtr &cnode) { + auto prim = GetCNodePrimitive(cnode); + if (prim != nullptr && prim->isa()) { + auto do_sig = prim->cast(); + auto prim_val = do_sig->function(); + if (prim_val != nullptr && prim_val->isa()) { + prim = prim_val->cast(); + } else { + prim = nullptr; + } + } + if (prim != nullptr) { + auto effect_val = prim->GetAttr(kPrimHasEffect); + if (effect_val && effect_val->isa()) { + auto effect_bool = GetValue(effect_val); + return effect_bool; + } + } + return false; +} + +std::shared_ptr> FindRoots(const std::vector &segment) { + std::shared_ptr> roots = std::make_shared>(segment); + for (const auto &node : segment) { + if (roots->size() == 1) { + return roots; + } + auto input_size = node->size(); + for (size_t i = 0; i < input_size; i++) { + auto in_node = node->input(i); + auto in_cnode = in_node->cast(); + if (in_cnode != nullptr) { + (void)roots->erase(in_cnode); + } + } + } + return roots; +} + +std::shared_ptr> FindLeaves(const std::vector &segment) { + std::shared_ptr> nodes = std::make_shared>(segment); + for (const auto &node : segment) { + if (nodes->size() == 1) { + return nodes; + } + if (IsPrimitiveCNode(node, prim::kPrimSwitch)) { + (void)nodes->erase(node); + continue; + } + auto input_size = node->size(); + for (size_t i = 0; i < input_size; i++) { + auto in_node = node->input(i); + if (!in_node->isa()) { + continue; + } + auto in_cnode = in_node->cast(); + if (in_cnode != nullptr) { + if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) { + (void)nodes->erase(node); + break; + } + } + } + } + return nodes; +} + +void FuncGraph::ReleaseFullOrderToEffectOrder() { + MS_LOG(DEBUG) << "Flag has_effect " << has_flag(GRAPH_FLAG_HAS_EFFECT) << "."; + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + std::list depends_order; + std::vector segment; + for (const auto &cnode : order_) { + if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { + continue; + } + if (HasEffect(cnode)) { + MS_LOG(DEBUG) << "Meet a effect node " << cnode->DebugString() << "."; + if (segment.size() > 0) { + auto roots = FindRoots(segment); + for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { + depends_order.push_back(*iter); + } + } + segment.clear(); + depends_order.push_back(cnode); + } else { + MS_LOG(DEBUG) << "Meet a general node " << cnode->DebugString() << "."; + segment.push_back(cnode); + } + } + if (segment.size() > 1) { + auto roots = FindRoots(segment); + for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { + depends_order.push_back(*iter); + } + } + std::vector depend_inputs; + auto old_ret = output(); + for (auto iter = depends_order.rbegin(); iter != depends_order.rend(); (void)iter++) { + if (*iter != old_ret) { + depend_inputs.push_back(*iter); + } + } + set_flag(GRAPH_FLAG_HAS_EFFECT, false); + set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); + if (!depend_inputs.empty()) { + SetEffectDepends(depend_inputs); + } + } +} + +void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { + auto old_ret = output(); + std::vector inputs{NewValueNode(prim::kPrimDepend), old_ret}; + (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); + auto new_ret = NewCNode(inputs); + auto mng = manager(); + if (mng) { + (void)mng->Replace(old_ret, new_ret); + } else { + return_->set_input(1, new_ret); + } +} +} // namespace mindspore diff --git a/mindspore/core/ir/func_graph_py.cc b/mindspore/core/ir/func_graph_py.cc new file mode 100644 index 0000000000..cff25b5aa1 --- /dev/null +++ b/mindspore/core/ir/func_graph_py.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "ir/meta_func_graph.h" +#include "ir/func_graph.h" + +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { + // Define python "MetaFuncGraph_" class + (void)py::class_>(*m, "MetaFuncGraph_") + .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_) + .def(py::init()); + // Define python "FuncGraph" class + (void)py::class_(*m, "FuncGraph") + .def(py::init()) + .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") + .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph"); + })); +} // namespace mindspore diff --git a/mindspore/core/ir/kernel_info_dev.h b/mindspore/core/ir/kernel_info_dev.h new file mode 100644 index 0000000000..87c717bdcb --- /dev/null +++ b/mindspore/core/ir/kernel_info_dev.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ +#define MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ + +#include + +namespace mindspore { +// Interface for device kernel program information. +class KernelInfoDevice { + public: + // If kernel program was built and build info is set. + virtual bool has_build_info() const = 0; +}; +using KernelInfoDevicePtr = std::shared_ptr; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ diff --git a/mindspore/core/ir/lite/param_value_lite.h b/mindspore/core/ir/lite/param_value_lite.h new file mode 100644 index 0000000000..1da9b915c2 --- /dev/null +++ b/mindspore/core/ir/lite/param_value_lite.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ +#define MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ + +#include + +#include "ir/param_value.h" + +namespace mindspore { +class ParamValueLite : public ParamValue { + public: + ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} + virtual ~ParamValueLite() = default; + + size_t tensor_size() const { return tensor_size_; } + void set_tensor_size(size_t size) { tensor_size_ = size; } + + void *tensor_addr() const { return tensor_addr_; } + void set_tensor_addr(void *addr) { tensor_addr_ = addr; } + + private: + void *tensor_addr_; + size_t tensor_size_; +}; + +using ParamValueLitePtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINNIE_PARAM_VALUE_LITE_H_ diff --git a/mindspore/ccsrc/ir/lite/tensor.cc b/mindspore/core/ir/lite/tensor.cc similarity index 100% rename from mindspore/ccsrc/ir/lite/tensor.cc rename to mindspore/core/ir/lite/tensor.cc diff --git a/mindspore/ccsrc/ir/lite/tensor.h b/mindspore/core/ir/lite/tensor.h similarity index 100% rename from mindspore/ccsrc/ir/lite/tensor.h rename to mindspore/core/ir/lite/tensor.h diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc new file mode 100644 index 0000000000..00c39679cd --- /dev/null +++ b/mindspore/core/ir/manager.cc @@ -0,0 +1,914 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/manager.h" + +#include +#include +#include + +#include "debug/trace_base.h" +#include "ir/func_graph.h" +#include "utils/profile.h" +#include "utils/convert_utils_base.h" +#include "frontend/operator/ops.h" + +namespace mindspore { + +FuncGraphManagerPtr MakeManager(const std::vector &func_graphs, bool manage) { + auto m = std::make_shared(func_graphs, manage); + m->Init(); + return m; +} + +FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage) { + FuncGraphManagerPtr m = nullptr; + bool root = false; + + for (auto &fg : func_graphs) { + if (fg == nullptr) { + continue; + } + if (fg->manager() != nullptr) { + m = fg->manager(); + break; + } + } + + if (m == nullptr) { + std::vector tmp; + m = MakeManager(tmp, manage); + root = true; + } + + for (auto &fg : func_graphs) { + if (fg == nullptr) { + continue; + } + m->AddFuncGraph(fg, root); + } + return m; +} + +FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { + std::vector func_graphs = {func_graph}; + return Manage(func_graphs, manage); +} + +FuncGraphManager::FuncGraphManager(const std::vector &roots, bool manage) + : roots_(roots), is_manage_(manage) { + Reset(); +} + +void FuncGraphManager::Reset() { + func_graphs_ = FuncGraphSet(); + all_nodes_ = AnfNodeSet(); + node_users_ = NodeUsersMap(); + + signals_ = std::make_shared(); + + func_graph_parents_total_ = std::make_shared(this); + func_graph_parent_ = std::make_shared(this); + children_ = std::make_shared(this); + scopes_ = std::make_shared(this); + free_variables_total_ = std::make_shared(this); + func_graphs_used_total_ = std::make_shared(this); + recursive_ = std::make_shared(this); + j_total_ = std::make_shared(this); + + limit_ = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); +} + +void FuncGraphManager::Init() { + auto roots = roots_; + roots_ = FuncGraphSet(); + + for (auto &fg : roots) { + AddFuncGraph(fg, true); + } +} + +FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); + func_graph_parents_total_->Recompute(fg); + MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString(); + return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; +} + +FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(func_graph_parent_); + MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); + func_graph_parent_->Recompute(fg); + if (func_graph_parent_->parent_analysis().count(fg) == 0) { + MS_LOG(WARNING) << "This func graph is not in manager:" << fg->ToString(); + return nullptr; + } + MS_LOG(DEBUG) << "End parents func graph " << fg->ToString(); + return func_graph_parent_->parent_analysis()[fg]; +} + +FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(children_); + MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); + children_->Recompute(fg); + return children_->children_analysis()[fg]; +} + +FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(scopes_); + MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); + scopes_->Recompute(fg); + MS_LOG(DEBUG) << "End scopes func graph:" << fg->ToString(); + return scopes_->scope_analysis()[fg]; +} + +FVTotalMap &FuncGraphManager::free_variables_total() const { + MS_EXCEPTION_IF_NULL(free_variables_total_); + free_variables_total_->Recompute(); + return free_variables_total_->fv_total_analysis(); +} + +FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(func_graphs_used_total_); + func_graphs_used_total_->Recompute(fg); + return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; +} + +bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + recursive_->Recompute(fg); + if (recursive_->recursive_analysis().count(fg) == 0) { + MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); + return false; + } + return recursive_->recursive_analysis()[fg]; +} + +std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + if (recursive(fg)) { + if (!recursive_->recursive_map().count(fg)) { + auto trace = std::list(); + recursive_->CheckRecursiveGraphs(fg, &trace); + } + if (recursive_->recursive_map().count(fg) == 0) { + MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); + return nullptr; + } + return recursive_->recursive_map()[fg]; + } else { + return nullptr; + } +} + +bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(j_total_); + MS_EXCEPTION_IF_NULL(fg); + j_total_->Recompute(fg); + if (j_total_->j_total_analysis().count(fg) == 0) { + MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); + return false; + } + return j_total_->j_total_analysis()[fg]; +} + +// add a func graph to this manager, optionally as a root func graph. +void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { + MS_EXCEPTION_IF_NULL(func_graph); + if (is_root) { + roots_.add(func_graph); + } + if (func_graphs_.contains(func_graph)) { + return; + } + AddIntoManaged(func_graph); + std::vector para = func_graph->parameters(); + AcquireNodes(para); + std::vector return_vec({func_graph->get_return()}); + AcquireNodes(return_vec); +} + +// clear the all information in manager +void FuncGraphManager::Clear() { + func_graphs_.clear(); + all_nodes_.clear(); + node_users_.clear(); + roots_.clear(); + + signals_->InvalidateComputer(); +} + +void FuncGraphManager::KeepRoots(const std::vector &func_graphs) { + MS_LOG(DEBUG) << "Start keep roots"; + bool root_exist = false; + for (auto &item : func_graphs) { + if (roots_.contains(item)) { + root_exist = true; + break; + } + } + + // if the new_root in roots_, we add new_root first, then calculate the func_graphs + // relation to new_root, remove the func_graphs not relation to new_root + // if the new_root not in roots_, we clear the all func_graphs in manager + // then add the new_root + if (root_exist || func_graphs.empty()) { + FuncGraphSet roots(func_graphs); + if (roots.empty()) { + roots = roots_; + } else { + roots_.clear(); + for (auto &item : roots) { + AddFuncGraph(item, true); + } + } + + FuncGraphSet keep; + for (auto &item : roots) { + MS_LOG(DEBUG) << "roots: " << item->ToString(); + keep.update(func_graphs_used_total(item)); +#ifdef DEBUG + for (auto &k : keep) { + MS_LOG(DEBUG) << "keep: " << k->ToString(); + } +#endif + } + MaybeDropFuncGraphs(func_graphs_ - keep, true); + } else { + Clear(); + FuncGraphSet roots(func_graphs); + for (auto &item : roots) { + AddFuncGraph(item, true); + } + } +} + +void FuncGraphManager::RemoveRoots() { + MS_LOG(DEBUG) << "Start remove roots"; + roots_.clear(); + MaybeDropFuncGraphs(func_graphs_, true); +} + +void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { + MS_EXCEPTION_IF_NULL(fg); + if (is_manage_) { + if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { + MS_LOG(WARNING) << "A func graph can only have one manager."; + } + FuncGraphManagerPtr this_manager = shared_from_this(); + fg->set_manager(this_manager); + } + func_graphs_.add(fg); +} + +void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { + FuncGraphSet todo(func_graphs); + std::set dropped; + // int count = 0; + while (!todo.empty()) { + FuncGraphPtr func_graph = todo.pop(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString(); + if (roots_.contains(func_graph)) { + MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); + continue; + } + auto &users_cnode_index = func_graph->func_graph_cnodes_index(); + if (!users_cnode_index.empty() && !ignore_users) { + MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); + continue; + } + if (dropped.find(func_graph) != dropped.end()) { + MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString(); + continue; + } + (void)dropped.insert(func_graph); + std::vector return_vec = {func_graph->get_return()}; + todo.update(MaybeDropNodes(return_vec)); + } + for (auto &fg : dropped) { + MS_EXCEPTION_IF_NULL(fg); + all_nodes_.difference_update(fg->parameters()); + (void)func_graphs_.erase(fg); + if (fg->manager().get() == this) { + fg->set_manager(nullptr); + } + MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString(); + } +} + +void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) { + MS_EXCEPTION_IF_NULL(inp); + if (direction == kDecEdge) { + MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); + auto &users_node = node_users_[inp]; + if (!users_node.contains(make_pair(node, index))) { + return; + } + (void)users_node.erase(make_pair(node, index)); + DropEdge(node, index, inp); + } else { + MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); + if (IsValueNode(inp)) { + MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); + AddFuncGraph(GetValueNode(inp)); + } + auto &users_node = node_users_[inp]; + users_node.add(make_pair(node, index)); + AddEdge(node, index, inp); + } +} + +void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + int index = 0; + for (auto &inp : cnode->inputs()) { + ProcessEdge(cnode, index, inp, direction); + ++index; + } + } +} + +IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { + if (all_nodes_.contains(node)) { + return EXCLUDE; + } else { + return FOLLOW; + } +} + +void FuncGraphManager::AcquireNodes(const std::vector &nodes) { + AnfNodeSet acq; + for (auto &node : nodes) { + AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit_)); + + all_nodes_.update(new_nodes); + acq.update(new_nodes); + } + + for (auto &node : acq) { + MS_EXCEPTION_IF_NULL(node); + auto fg = node->func_graph(); + if (fg != nullptr) { + fg->AddNode(node); + } + ProcessInputs(node, kIncEdge); + } +} + +FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { + AnfNodeSet nodes_ordered(nodes); + FuncGraphSetPtr func_graphs_to_check = std::make_shared(); + while (!nodes_ordered.empty()) { + AnfNodePtr node = nodes_ordered.pop(); + MS_EXCEPTION_IF_NULL(node); + if (!all_nodes_.contains(node)) { + continue; + } + AnfNodeIndexSet &users = node_users_[node]; + + std::vector parameters; + if (!users.empty() || + (node->isa() && parameters.end() != std::find(parameters.begin(), parameters.end(), node))) { + continue; + } + if (IsValueNode(node)) { + auto fg = GetValueNode(node); + func_graphs_to_check->add(fg); + MS_LOG(DEBUG) << "Set value of node " << node->DebugString() << " from func graph " << fg->ToString() + << " to null"; + } + ProcessInputs(node, kDecEdge); + (void)all_nodes_.erase(node); + if (node->func_graph() != nullptr) { + node->func_graph()->DropNode(node); + } + + if (node->isa()) { + auto cnode = node->cast(); + nodes_ordered.update(cnode->inputs()); + } + (void)node_users_.erase(node); + } + return func_graphs_to_check; +} + +void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters) { + auto tr = Transact(); + tr.SetParameters(fg, parameters); + tr.Commit(); +} + +void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) { + auto tr = Transact(); + tr.AddParameter(fg, parameter); + tr.Commit(); +} + +bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + auto tr = Transact(); + bool success = tr.Replace(old_node, new_node); + if (success) { + tr.Commit(); + } + return success; +} + +void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { + auto tr = Transact(); + tr.SetEdge(node, index, value); + tr.Commit(); +} + +void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { + AnfNodePtr source_return = source->get_return(); + AnfNodePtr source_output = source->output(); + AnfNodePtr source_prim = source_return->cast()->input(0); + + int index = 0; + (void)node_users_[source_prim].erase(make_pair(source_return, index)); + DropEdge(source_return, index, source_prim); + index = 1; + (void)node_users_[source_output].erase(make_pair(source_return, index)); + DropEdge(source_return, index, source_output); + (void)all_nodes_.erase(source_return); + (void)node_users_.erase(source_return); + source->DropNode(source_return); + for (auto &node : source->nodes()) { + node->set_func_graph(target); + if (node->scope() == kDefaultScope) { + node->set_scope(scope); + } + } + + MoveAllNodes(source, target); + all_nodes_.difference_update(source->parameters()); + (void)func_graphs_.erase(source); + if (source->manager().get() == this) { + source->set_manager(nullptr); + } +} + +void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { + auto fg = node->func_graph(); + if (input->isa()) { + fg->AddValueNode(input); + if (IsValueNode(input)) { + auto used = GetValueNode(input); + used->AddFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (fg->AddFuncGraphUsed(used)) { + signals_->InvalidateComputer(); + } + if (IsPrimitiveCNode(node, prim::kPrimJ)) { + fg->AddJFuncGraph(used); + } + } + } else if (fg != nullptr && fg != input->func_graph()) { + if (fg->AddFreeVariable(input)) { + signals_->InvalidateComputer(); + } + } +} + +void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { + auto fg = node->func_graph(); + if (input->isa()) { + fg->DropValueNode(input); + if (IsValueNode(input)) { + auto used = GetValueNode(input); + used->DropFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (fg->DropFuncGraphUsed(used)) { + signals_->InvalidateComputer(); + } + if (IsPrimitiveCNode(node, prim::kPrimJ)) { + fg->DropJFuncGraph(used); + } + } + } else if (fg != nullptr && fg != input->func_graph()) { + if (fg->DropFreeVariable(input)) { + signals_->InvalidateComputer(); + } + } +} + +void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { + target->CopyNodes(source); + target->CopyValueNodes(source); + target->CopyFuncGraphCNodesIndex(source); + target->CopyFreeVariables(source); + target->CopyFuncGraphsUsed(source); + target->CopyJFuncGraphs(source); + signals_->InvalidateComputer(); + source->ClearNodes(); + source->ClearValueNodes(); + source->ClearFuncGraphCNodesIndex(); + source->ClearFreeVariables(); + source->ClearFuncGraphsUsed(); + source->ClearJFuncGraphs(); +} + +FuncGraphTransaction FuncGraphManager::Transact() { + auto tr = FuncGraphTransaction(this); + return tr; +} + +void FuncGraphManager::ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, + EdgeTupleCounter *rm_edges, Counter *adds, Counter *rms) { + for (auto &iter : changes) { + auto operation = iter.op; + auto args = iter.args; + switch (operation) { + case Change::kTxSetEdge: { + auto edge = args.cast(); + auto old_node = edge.root_node->input(edge.index); + (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; + (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; + (*rms)[old_node] += 1; + (*adds)[edge.new_node] += 1; + edge.root_node->set_input(edge.index, edge.new_node); + } break; + case Change::kTxSetParams: { + auto param = args.cast(); + MS_EXCEPTION_IF_NULL(param.func_graph); + auto old_parameters = param.func_graph->parameters(); + for (auto &p : param.params) { + (*adds)[p] += 1; + } + for (auto &p : old_parameters) { + (*rms)[p] += 1; + } + param.func_graph->set_parameters(param.params); + } break; + case Change::kTxAddParam: { + auto param = args.cast(); + MS_EXCEPTION_IF_NULL(param.func_graph); + (*adds)[param.param] += 1; + auto param_node = param.param->cast(); + param.func_graph->append_parameter(param_node); + } break; + default: + break; + } + } +} + +void FuncGraphManager::CommitChanges(const std::vector &changes) { + EdgeTupleCounter add_edges; + EdgeTupleCounter rm_edges; + Counter adds; + Counter rms; + ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); + + auto sub_edges = add_edges - rm_edges; + for (auto &iter : sub_edges) { + auto root_node = iter.first.first; + int index = iter.first.second.first; + auto new_node = iter.first.second.second; + ProcessEdge(root_node, index, new_node, kIncEdge); + } + + auto sub_nodes = adds - rms; + std::vector nodes; + (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); + + AcquireNodes(nodes); + + auto sub_edges_reverse = rm_edges - add_edges; + for (auto &iter : sub_edges_reverse) { + auto root_node = iter.first.first; + int index = iter.first.second.first; + auto old_node = iter.first.second.second; + ProcessEdge(root_node, index, old_node, kDecEdge); + } + + auto sub_nodes_reverse = rms - adds; + std::vector nodes_reverse; + + (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); + + auto drop_func_graphs = MaybeDropNodes(nodes_reverse); + MaybeDropFuncGraphs(*drop_func_graphs); +} + +void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector ¶ms) { + changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); +} + +void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) { + changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param}); +} + +bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + FuncGraphPtr old_func_graph = old_node->func_graph(); + if (old_func_graph != nullptr && old_func_graph->get_return() == old_node) { + MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString(); + return false; + } + auto users = manager_->node_users()[old_node]; + for (auto &node : users) { + SetEdge(node.first, node.second, new_node); + } + + return true; +} + +void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { + if (k < 0) { + MS_LOG(EXCEPTION) << "Invalid value k = " << k; + } + MS_EXCEPTION_IF_NULL(src_node); + auto cnode = src_node->cast(); + if (cnode == nullptr) { + MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed."; + } + changes_.emplace_back(Change::kTxSetEdge, ArgsOfSetEdge{cnode, v, IntToSize(k)}); +} + +void FuncGraphTransaction::Commit() { + std::vector changes; + changes_.swap(changes); + manager_->CommitChanges(changes); +} + +DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) { + MS_EXCEPTION_IF_NULL(manager_); + manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); + validate_ = false; +} + +void DepComputer::Recompute() { + if (!validate_) { + RealRecompute(); + validate_ = true; + } +} + +void DepComputer::Recompute(const FuncGraphPtr &fg) { + if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { + RealRecompute(fg); + func_graphs_validate_[fg] = true; + } +} + +FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) { + if (fg->seen_ == seen_num) { + return std::make_shared(); + } + FuncGraphSetPtr parents = std::make_shared(); + + // Append all the fvs in fg. + auto &fvs = fg->free_variables(); + for (auto fv : fvs) { + parents->add(fv.first->func_graph()); + } + + // Search the fv in fg's child func graph. + auto &fgs = fg->func_graphs_used(); + for (auto &item : fgs) { + fg->seen_ = seen_num; + auto gt = item.first; + parents->update(SeekParents(gt, seen_num)); + } + (void)parents->erase(fg); + return parents; +} + +void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { + MS_EXCEPTION_IF_NULL(fg); + func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration())); +} + +bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { + auto l1 = lhs.second.size(); + auto l2 = rhs.second.size(); + return l1 < l2; +} + +void ParentComputer::RealRecompute(FuncGraphPtr fg) { + this->parent_analysis_[fg] = nullptr; + // Note: must be a copy other than reference as it is modified thereafter. + auto deps = this->manager_->func_graph_parents_total(fg); + + if (deps.empty()) { + this->parent_analysis_[fg] = nullptr; + return; + } else if (deps.size() == 1) { + this->parent_analysis_[fg] = deps.pop(); + return; + } else { + // return nearest parent as parent + FuncGraphSet deps_copy(deps); + for (auto &dep : deps) { + auto parent_deps = this->manager_->func_graph_parents_total(dep); + for (auto &p_d : parent_deps) { + if (deps_copy.count(p_d)) { + (void)deps_copy.erase(p_d); + } + } + if (deps_copy.size() == 1) { + this->parent_analysis_[fg] = deps_copy.pop(); + return; + } + } + } +} + +void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { + MS_EXCEPTION_IF_NULL(manager_); + auto used_fg_total = manager_->func_graphs_used_total(fg); + for (auto &used_fg : used_fg_total) { + if (manager_->parent(used_fg) == fg) { + children_analysis_[fg].add(used_fg); + } + } +} + +void ScopeComputer::RealRecompute(FuncGraphPtr fg) { + MS_EXCEPTION_IF_NULL(manager_); + auto &children = manager_->children(fg); + + scope_analysis_[fg] = FuncGraphSet(); + scope_analysis_[fg].add(fg); + for (auto &child : children) { + scope_analysis_[fg].add(child); + } +} + +void FVTotalComputer::RealRecompute() { + auto manager = DepComputer::manager_; + MS_EXCEPTION_IF_NULL(manager); + + for (auto &fg : manager->func_graphs()) { + fv_total_analysis_[fg] = OrderedMap(); + } + + for (auto &fg : manager->func_graphs()) { + // add all free variable nodes + AnfNodeCounterMap items = fg->free_variables(); + for (auto &iter : items) { + auto curr = fg; + while (curr != nullptr) { + fv_total_analysis_[curr][iter.first] = iter.second; + curr = manager->parent(curr); + if (curr != nullptr) { + const AnfNodeSet &all_nodes = curr->nodes(); + if (all_nodes.contains(iter.first)) { + break; + } + } + } + } + + // add all FGs of free variables + auto &used = fg->func_graphs_used(); + for (auto &iter : used) { + auto p = manager->parent(iter.first); + if (p == nullptr) { + continue; + } + auto curr = fg; + while (curr != p) { + fv_total_analysis_[curr][iter.first] = iter.second; + curr = manager->parent(curr); + } + } + } +} + +void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { + MS_EXCEPTION_IF_NULL(manager_); + std::vector todo; + std::vector todo_new; + + todo.push_back(fg); + while (!todo.empty()) { + todo_new.clear(); + for (auto > : todo) { + for (auto &item : gt->func_graphs_used()) { + auto used_fg = item.first; + if (used_fg == fg) { + func_graph_used_total_analysis_[fg].add(used_fg); + continue; + } + if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) { + todo_new.push_back(used_fg); + } + MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString(); + func_graph_used_total_analysis_[fg].add(used_fg); + } + } + todo = todo_new; + } +} + +bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { + MS_EXCEPTION_IF_NULL(manager); + std::vector todo; + std::vector todo_new; + todo.push_back(fg); + FuncGraphSet used_total; + while (!todo.empty()) { + todo_new.clear(); + for (auto > : todo) { + for (auto &item : gt->func_graphs_used()) { + auto used_g = item.first; + if (used_g == fg) { + return true; + } + if (used_total.count(used_g) == 0) { + todo_new.push_back(used_g); + } + used_total.add(used_g); + } + } + todo = todo_new; + } + return false; +} + +void RecursiveComputer::RealRecompute(FuncGraphPtr fg) { + this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); +} + +void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list *trace) { + MS_EXCEPTION_IF_NULL(trace); + auto res = std::find(trace->begin(), trace->end(), fg); + // find recursive + if (res != trace->end()) { + auto recur_ptr = std::make_shared>(res, trace->end()); + for (auto iter = res; iter != trace->end(); (void)iter++) { + MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString(); + recursive_map_[*iter] = recur_ptr; + } + } else { + trace->push_back(fg); + auto &items = fg->func_graphs_used(); + for (auto iter = items.begin(); iter != items.end(); (void)iter++) { + CheckRecursiveGraphs(iter->first, trace); + } + trace->pop_back(); + if (!recursive_map_.count(fg)) { + recursive_map_[fg] = nullptr; + } + } +} + +bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { + if (fg->seen_ == seen_num) { + MS_LOG(DEBUG) << fg->ToString() << " had been checked"; + return false; + } + auto &j_fgs = fg->j_func_graphs(); + if (!j_fgs.empty()) { + // check g1->J(fg)->g2->g cycle; + auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair iter) { + return iter.first->seen_ != seen_num; + }); + if (contains_j != j_fgs.end()) { + MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; + return true; + } + } + fg->seen_ = seen_num; + + // check if func graphs used contains J(func_graph); + for (auto &item : fg->func_graphs_used()) { + auto used_g = item.first; + if (SeekJ(used_g, seen_num)) { + MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; + return true; + } + } + MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph)"; + return false; +} + +void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { + this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration()); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/core/ir/manager.h similarity index 100% rename from mindspore/ccsrc/ir/manager.h rename to mindspore/core/ir/manager.h diff --git a/mindspore/core/ir/meta_func_graph.cc b/mindspore/core/ir/meta_func_graph.cc new file mode 100644 index 0000000000..c0cf9d4d2f --- /dev/null +++ b/mindspore/core/ir/meta_func_graph.cc @@ -0,0 +1,45 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/meta_func_graph.h" + +// namespace to support intermediate representation definition +namespace mindspore { +FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) { + TypePtrList types; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), + [](const AbstractBasePtr &arg) -> TypePtr { + MS_EXCEPTION_IF_NULL(arg); + return arg->BuildType(); + }); + // filter unsafe characters in log print since name_ is from outside + auto iter = cache_.find(types); + if (iter == cache_.end()) { + FuncGraphPtr fg = GenerateFromTypes(types); + MS_EXCEPTION_IF_NULL(fg); + MS_LOG(INFO) << "MetaFuncgraph: cache miss for types: " << mindspore::ToString(args_spec_list) + << ", g: " << fg->ToString(); + cache_[types] = fg; + return fg; + } else { + MS_LOG(DEBUG) << "MetaFuncgraph: cache hit for types: " << mindspore::ToString(args_spec_list) + << ", g: " << iter->second->ToString(); + return iter->second; + } +} +} // namespace mindspore diff --git a/mindspore/core/ir/meta_func_graph.h b/mindspore/core/ir/meta_func_graph.h new file mode 100644 index 0000000000..933c3f700d --- /dev/null +++ b/mindspore/core/ir/meta_func_graph.h @@ -0,0 +1,90 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ +#define MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ + +#include +#include +#include +#include +#include +#include + +#include "ir/dtype.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/signature.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +// namespace to support intermediate representation definition +// Graph generator. +// Can be called with a pipeline's resources and a list of argument types to +// generate a graph corresponding to these types. +class MetaFuncGraph : public FuncGraphBase { + public: + explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); } + + ~MetaFuncGraph() override = default; + + MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); + // Return normalized versions of the arguments. + // By default, this returns args unchanged. + virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { + return args_spec_list; + } + + const std::vector &signatures() const { return signatures_; } + void set_signatures(const std::vector &signatures) { signatures_ = signatures; } + // Generate a Graph for the given abstract arguments. + virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list); + + // Generate a Graph for this type signature. + virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) { + MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; + } + + std::string name() { return name_; } + std::string ToString() const override { return name_; } + std::size_t hash() const override { return tid(); } + + virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; } + bool operator==(const Value &other) const override { + if (other.isa()) { + return &other == this; + } else { + return false; + } + } + const bool parse_info_ = true; + + protected: + template + std::shared_ptr shared_from_base() { + return std::static_pointer_cast(shared_from_this()); + } + std::string name_; + std::vector signatures_; + std::unordered_map cache_; +}; + +using MetaFuncGraphPtr = std::shared_ptr; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_META_FUNC_GRAPH_H_ diff --git a/mindspore/ccsrc/ir/meta_tensor.cc b/mindspore/core/ir/meta_tensor.cc similarity index 100% rename from mindspore/ccsrc/ir/meta_tensor.cc rename to mindspore/core/ir/meta_tensor.cc diff --git a/mindspore/core/ir/meta_tensor.h b/mindspore/core/ir/meta_tensor.h new file mode 100644 index 0000000000..00106215e8 --- /dev/null +++ b/mindspore/core/ir/meta_tensor.h @@ -0,0 +1,195 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_META_TENSOR_H_ +#define MINDSPORE_CCSRC_IR_META_TENSOR_H_ + +#include +#include +#include +#include + +#include "base/base.h" +#include "ir/dtype.h" +#include "utils/convert_utils.h" +#include "utils/hashing.h" + +// brief mindspore namespace. +// +// mindspore namespace is the top level namespace of MindSpore project. +// Other namespace should be a sub namespace of mindspore namespace in the ME project. +namespace mindspore { + +// brief mindspore::tensor namespace +// +// A sub namespace in ME to support tensor related definition. +namespace tensor { + +// brief Device info of Tensor +// +// Includes the format and data type of a tensor. +struct DeviceInfo { + explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr) + : format_(std::move(format)), data_type_(std::move(data_type)) {} + std::string format_ = "DefaultFormat"; + TypePtr data_type_ = nullptr; +}; + +// brief Metadata of Tensor +// +// Includes the metadata information of a tensor, such as data type, shape +// and so on. But it does not contain values of a tensor. +class MetaTensor : public Value { + public: + // Construction + MetaTensor(); + + // brief Constructs a meta tensor of a tensor having data_type data and shape. + // + // The constructed MetaTensor is not a Tensor, but it has the data type and shape + // information of a Tensor. The following codes will create a 2x3 float + // param data_type The data type of the tensor. + // param shape The shape of the tensor. + MetaTensor(const TypeId data_type, const std::vector &shape); + + MetaTensor(const TypePtr &type_ptr, const std::vector &shape); + // brief Constructs a MetaTensor object from an existing MetaTensor instance. + // + // The constructed MetaTensor object will have the same data type and shape as the + // meta_tensor. + // + // param meta_tensor An existing MetaTensor object. + MetaTensor(const MetaTensor &meta_tensor); + ~MetaTensor() override = default; + MS_DECLARE_PARENT(MetaTensor, Value) + + // brief Overloads operator = for MetaTensor. + // + // The constructed MetaTensor object has the same type and shape with meta_tensor. + // + // param meta_tensor An existing MetaTensor object. + virtual MetaTensor &operator=(const MetaTensor &meta_tensor); + + // brief Compares two MetaTensor objects. + // + // The constructed MetaTensor object has the same type and shape with meta_tensor. + // + // param meta_tensor The MetaTensor object to be compared. + // return true: If having same type and shape, return true, or return false. + virtual bool operator==(const MetaTensor &meta_tensor) const; + + // brief Returns the data type of the tensor in its MetaTensor. + // + // All the types are defined in "ir/dtype.h". + TypePtr Dtype() const; + abstract::AbstractBasePtr ToAbstract() override; + TypeId data_type() const { return data_type_; } + std::string ToString() const override; + std::string DumpText() const override; + // brief Sets the data type of a tensor in its MetaTensor. + // + // param data_type The data type of the tensor to be set. + virtual TypeId set_data_type(const TypeId data_type) { + data_type_ = data_type; + return data_type_; + } + virtual TypePtr SetDtype(const TypePtr type_ptr); + // brief Get tensor's shape. + // + // The shape of a tensor is stored in a vector. Each + // element of the vector represents the size of a dimension of the tensor. + // The order of each element in the vector is as same as the the dimension's + // order it represents. + // + // return A const vector which represents the shape of the tensor. + const std::vector &shape() const { return shape_; } + + // brief Sets the shape of a tensor. + // + // The shape of a tensor is stored in a vector. Each + // element of the vector represents the size of a dimension of the tensor. + // The order of each element in the vector is as same as the the dimension's + // order it represents. + // + // param shape The shape of the tensor. + // return The shape's size. + size_t set_shape(const std::vector &shape) { + this->shape_ = shape; + return shape_.size(); + } + + // Get tensor's device info. + DeviceInfo device_info() const { return device_info_; } + + // Set tensor's device info. + void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } + + void SetDeviceInfo(const std::string &format, const TypePtr &data_type); + + // Get the size of a given dimension by its index number. + int DimensionSize(size_t index) const; + + // Get total number of elements in a tensor. + int ElementsNum() const; + + std::size_t hash() const override { + std::size_t hash_value = std::hash{}(SizeToInt(data_type_)); + hash_value = hash_combine(hash_value, std::hash{}(shape_.size())); + // hash all elements may costly, so only take at most 4 elements into account based on + // some experiments. + for (size_t i = 0; (i < shape_.size()) && (i < 4); ++i) { + hash_value = hash_combine(hash_value, (std::hash{}(shape_[i]))); + } + return hash_value; + } + bool operator==(const Value &other) const override { + if (other.isa()) { + auto other_ = static_cast(other); + return *this == other_; + } else { + return false; + } + } + const bool parse_info_ = true; + + protected: + // brief Data type of the tensor. + // + // All support data type is in Number Types of [TypeId], + // including [kNumberTypeBool], [kNumberTypeInt], + // [kNumberTypeUInt32], [kNumberTypeFloat32] and [kNumberTypeFloat64]. + TypeId data_type_; + + // brief Shape of the tensor. + // + // A std::vector container is used to store the shape of a tensor. + // Each element of the vector represents the size of a dimension of the tensor. + // The order of each element in the vector is as same as the the dimension's + // order it represents. If the dimension size is not set, its value will be -1. + std::vector shape_; + + // brief Device info of Tensor + // + // Includes the format and data type of a tensor on device. + DeviceInfo device_info_; +}; + +using MetaTensorPtr = std::shared_ptr; + +} // namespace tensor +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_META_TENSOR_H_ diff --git a/mindspore/core/ir/meta_tensor_extends.cc b/mindspore/core/ir/meta_tensor_extends.cc new file mode 100644 index 0000000000..d73aa19374 --- /dev/null +++ b/mindspore/core/ir/meta_tensor_extends.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/meta_tensor.h" + +#include +#include +#include +#include +#include + +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace tensor { +abstract::AbstractBasePtr MetaTensor::ToAbstract() { + auto tens = shared_from_base(); + auto dtype = tens->Dtype(); + if (!IsSubType(dtype, kNumber)) { + MS_LOG(EXCEPTION) << "Expect MetaTensor type kNumber but got: " << dtype->ToString() << "."; + } + auto tensor_shape = tens->shape(); + auto abs_tensor = std::make_shared(dtype, tensor_shape); + abs_tensor->set_value(shared_from_base()); + return abs_tensor; +} +} // namespace tensor +} // namespace mindspore diff --git a/mindspore/core/ir/named.cc b/mindspore/core/ir/named.cc new file mode 100644 index 0000000000..802f0c8693 --- /dev/null +++ b/mindspore/core/ir/named.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/named.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +bool Named::operator==(const Value &other) const { + if (other.isa()) { + auto other_named = static_cast(other); + return *this == other_named; + } else { + return false; + } +} + +abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared(); } +const NamedPtr kNone = std::make_shared(); + +abstract::AbstractBasePtr Null::ToAbstract() { return std::make_shared(); } +const NamedPtr kNull = std::make_shared(); + +abstract::AbstractBasePtr Ellipsis::ToAbstract() { return std::make_shared(); } +const NamedPtr kEllipsis = std::make_shared(); +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/named.h b/mindspore/core/ir/named.h similarity index 100% rename from mindspore/ccsrc/ir/named.h rename to mindspore/core/ir/named.h diff --git a/mindspore/ccsrc/ir/optimizer_caller.h b/mindspore/core/ir/optimizer_caller.h similarity index 100% rename from mindspore/ccsrc/ir/optimizer_caller.h rename to mindspore/core/ir/optimizer_caller.h diff --git a/mindspore/core/ir/param_value.h b/mindspore/core/ir/param_value.h new file mode 100644 index 0000000000..00b79ae91c --- /dev/null +++ b/mindspore/core/ir/param_value.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_H_ +#define MINDSPORE_CCSRC_IR_PARAM_VALUE_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "ir/tensor.h" + +namespace mindspore { + +class ParamValue { + public: + ParamValue() {} + + ParamValue(const ParamValue &other) = default; + + ~ParamValue() = default; + + tensor::MetaTensorPtr value() const { return value_; } + void set_value(const tensor::MetaTensorPtr &value) { value_ = value; } + + const std::string &name() const { return name_; } + void set_name(const std::string &name) { name_ = name; } + + const std::string &sparse_grad() const { return sparse_grad_; } + void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; } + + bool requires_grad() const { return requires_grad_; } + void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; } + + bool layerwise_parallel() const { return layerwise_parallel_; } + void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; } + + bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; } + void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; } + + // Whether the parameter clone from other parameter. + bool cloned() const { return cloned_; } + + // Whether the parameter is cloned. + bool be_cloned() const { return be_cloned_; } + + // If the parameter is cloned, generate one index per clone. + const std::vector &be_cloned_index() const { return be_cloned_index_; } + + // If the parameter clone from other parameter, it has a unique index. + int32_t cloned_index() const { return cloned_index_; } + + // Make a cloned parameter and update clone info. + ParamValuePtr Clone() { + static std::atomic parameter_cloned_index{1}; + int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed); + auto clone = std::make_shared(*this); + clone->be_cloned_ = false; + clone->cloned_ = true; + clone->be_cloned_index_ = {}; + clone->cloned_index_ = index; + this->be_cloned_ = true; + this->be_cloned_index_.push_back(index); + return clone; + } + + private: + tensor::MetaTensorPtr value_; + std::string name_{"Parameter"}; + std::string sparse_grad_; + bool requires_grad_{true}; + bool layerwise_parallel_{false}; + bool has_indexed_slices_grad_{false}; + bool be_cloned_{false}; + bool cloned_{false}; + std::vector be_cloned_index_; + int32_t cloned_index_{0}; +}; + +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_H_ diff --git a/mindspore/core/ir/param_value_py.cc b/mindspore/core/ir/param_value_py.cc new file mode 100644 index 0000000000..fb4b313c22 --- /dev/null +++ b/mindspore/core/ir/param_value_py.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ir/param_value.h" +#include "pybind11/pybind11.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +namespace py = pybind11; + +REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) { + (void)py::class_(*m, "ParamValue") + .def(py::init()) + .def("clone", &ParamValue::Clone) + .def_property("data", &ParamValue::value, &ParamValue::set_value) + .def_property("name", &ParamValue::name, &ParamValue::set_name) + .def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad) + .def_property("layerwise_parallel", &ParamValue::layerwise_parallel, + &ParamValue::set_layerwise_parallel) + .def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad, + &ParamValue::set_has_indexed_slices_grad) + .def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad) + .def(py::pickle( + [](const ParamValue &p) { // __getstate__ + return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(), + p.layerwise_parallel(), p.has_indexed_slices_grad(), + p.sparse_grad()); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 6) { + std::runtime_error("Invalid state for ParamValue!"); + } + ParamValuePtr p = std::make_shared(); + p->set_value(t[0].cast()); + p->set_name(t[1].cast()); + p->set_requires_grad(t[2].cast()); + p->set_layerwise_parallel(t[3].cast()); + p->set_has_indexed_slices_grad(t[4].cast()); + p->set_sparse_grad(t[5].cast()); + return p; + })); + })); +} // namespace mindspore diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h new file mode 100644 index 0000000000..94ba4a381a --- /dev/null +++ b/mindspore/core/ir/pattern_matcher.h @@ -0,0 +1,310 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ +#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ + +#include +#include + +#include "ir/anf.h" +#include "frontend/operator/ops.h" + +namespace mindspore { + +/// +/// Base class for all recognizable patterns. +/// We implement an Expression Template approach using static polymorphism based on +/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect +/// to the use of virtual functions without the costs..." as described in: +/// https://en.wikipedia.org/wiki/Expression_templates and +/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern +/// The TryCapture function tries to capture the pattern with the given node. +/// The GetNode function builds a new node using the captured values. +/// + +template +class PBase { + public: + bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) { + return func(get_object().GetNode(node)); + } + + const T &get_object() const { return *static_cast(this); } + + template + bool TryCapture(const TN &value) const { + get_object().Reset(); + return get_object().TryCapture_(value); + } + + using Internal = T; +}; + +template +class PIsEqual { + public: + bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } +}; + +template +class PatternNode : public PBase > { + public: + T GetNode(const AnfNodePtr &node) const { + if (!captured_) { + MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode."; + } + return captured_node_; + } + + bool TryCapture_(const T &node) const { + if (!captured_) { + captured_node_ = node; + captured_ = true; + return true; + } + return PIsEqual()(captured_node_, node); + } + + void Reset() const { captured_ = false; } + using Internal = const PatternNode &; + + protected: + mutable T captured_node_; + mutable bool captured_{false}; +}; + +template +class PBinOperation : public PBase > { + public: + PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {} + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + AnfNodePtr lhs = x_.GetNode(node->func_graph()); + AnfNodePtr rhs = y_.GetNode(node->func_graph()); + AnfNodePtrList list = {prim_->cast(), lhs, rhs}; + return NewCNode(list, node->func_graph()); + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (IsPrimitiveCNode(node, prim_)) { + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + if (inputs.size() == 3) { + // Binary Prim assumes only two inputs + if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { + return false; + } + return true; + } + } + return false; + } + + void Reset() const { + x_.Reset(); + y_.Reset(); + } + + private: + const PrimitivePtr prim_; + typename T::Internal x_; + typename T2::Internal y_; +}; + +/// +/// Helper functions to apply a pattern function on all elements of a tuple +/// +namespace tuple_utils { +template +struct apply_func_tuple_item { + template + static void apply(Func *func, const TTuple &tuple) { + (*func)(Index, std::get(tuple)); + apply_func_tuple_item<(Index + 1) == std::tuple_size::value, (Index + 1), Func>::apply(func, tuple); + } +}; + +template +struct apply_func_tuple_item { + template + static void apply(Func *func, const TTuple &tuple) {} +}; + +template +inline void apply_func_tuple(Func *func, const TTuple &tuple) { + apply_func_tuple_item::value == 0, 0, Func>::apply(func, tuple); +} + +struct PTupleResetCapture { + template + void operator()(size_t i, const T &pattern) const { + pattern.Reset(); + } +}; + +struct PTupleCapture { + explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {} + + template + void operator()(size_t i, const TPattern &pattern) { + // Check if the first node is a Primitive + if (i == 0 && tuple_[i]->isa()) { + auto prim = tuple_[i]->cast(); + if (tuple_[i] != pattern.GetNode(tuple_[i])) { + captured_ = false; + } + } else { + captured_ = captured_ && pattern.TryCapture_(tuple_[i]); + } + } + + const AnfNodePtrList tuple_; + bool captured_{true}; +}; + +struct PTupleGetNode { + explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {} + + template + void operator()(size_t, const TPattern &pattern) { + args_.push_back(pattern.GetNode(node_)); + } + + const AnfNodePtr &node_; + std::vector args_; +}; +} // namespace tuple_utils + +template +class PCNode : public PBase > { + public: + explicit PCNode(const TArgs &... args) : args_(args...) {} + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + tuple_utils::PTupleGetNode get_node(node); + tuple_utils::apply_func_tuple(&get_node, args_); + return NewCNode(get_node.args_, node->func_graph()); + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (node->isa()) { + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + if (inputs.size() != sizeof...(TArgs)) { + return false; + } + tuple_utils::PTupleCapture capture_func(inputs); + tuple_utils::apply_func_tuple(&capture_func, args_); + return capture_func.captured_; + } + + return false; + } + + void Reset() const { + tuple_utils::PTupleResetCapture reset; + tuple_utils::apply_func_tuple(&reset, args_); + } + + private: + std::tuple args_; +}; + +template +class PPrimitive : public PBase > { + public: + explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {} + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + tuple_utils::PTupleGetNode get_node(node); + tuple_utils::apply_func_tuple(&get_node, args_); + auto prim_cnode = get_node.args_; + prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_)); + return NewCNode(prim_cnode, node->func_graph()); + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (IsPrimitiveCNode(node, prim_)) { + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + if ((inputs.size() - 1) != sizeof...(TArgs)) { + return false; + } + + AnfNodePtrList rest(inputs.begin() + 1, inputs.end()); + tuple_utils::PTupleCapture capture_func(rest); + tuple_utils::apply_func_tuple(&capture_func, args_); + + return capture_func.captured_; + } + + return false; + } + + void Reset() const { + tuple_utils::PTupleResetCapture reset; + tuple_utils::apply_func_tuple(&reset, args_); + } + + private: + const PrimitivePtr prim_; + std::tuple args_; +}; + +// Macro for binary operation functions +#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \ + template \ + inline PBinOperation Operator(const PBase &x, const PBase &y) { \ + return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \ + } + +// Arithmetic operations +BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd); +BIN_OPERATION_PATTERN(operator*, prim::kPrimMul); + +// Macros for match and replace +#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ + if ((CaptureNode).TryCapture(OrigNode)) { \ + return (ReplaceWith).GetNode(OrigNode); \ + } + +#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ + if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ + return (ReplaceWith).GetNode(OrigNode); \ + } + +#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ + if ((CaptureNode).TryCapture(OrigNode)) { \ + if ((Condition)) { \ + return (ReplaceWith).GetNode(OrigNode); \ + } \ + return (ElseNode).GetNode(OrigNode); \ + } + +#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ + if ((CaptureNode).TryCapture(OrigNode)) { \ + return (Lambda)(); \ + } + +#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ + if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ + return (Lambda)(); \ + } + +} // namespace mindspore + +#endif // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ diff --git a/mindspore/core/ir/primitive.cc b/mindspore/core/ir/primitive.cc new file mode 100644 index 0000000000..352c0f31ae --- /dev/null +++ b/mindspore/core/ir/primitive.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/primitive.h" + +#include + +namespace mindspore { +bool Primitive::operator==(const Value &other) const { + if (other.isa()) { + auto other_prim = static_cast(other); + return *this == other_prim; + } else { + return false; + } +} + +bool Primitive::operator==(const Primitive &other) const { + if (name() != other.name()) { + return false; + } + if (attrs_.size() != other.attrs_.size()) { + return false; + } + auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { + if (item.second == nullptr) { + return false; + } + auto iter = other.attrs_.find(item.first); + if (iter == other.attrs_.end()) { + return false; + } + return *item.second == *iter->second; + }); + return all; +} + +std::string Primitive::GetAttrsText() const { + if (attrs_.empty()) { + return ""; + } + + std::ostringstream oss; + oss << "["; + bool is_first = true; + for (auto &attr : attrs_) { + if (is_first) { + is_first = false; + } else { + oss << ", "; + } + oss << attr.first << "=" << attr.second->DumpText(); + } + oss << "]"; + + return oss.str(); +} +} // namespace mindspore diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h new file mode 100644 index 0000000000..5471b58063 --- /dev/null +++ b/mindspore/core/ir/primitive.h @@ -0,0 +1,152 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_H_ +#define MINDSPORE_CCSRC_IR_PRIMITIVE_H_ + +#include +#include +#include +#include +#include + +#include "ir/dtype/type.h" +#include "abstract/abstract_value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "utils/base_ref_extends.h" + +namespace mindspore { +// Supported meta type +enum PrimType { + kPrimTypeUnknown = 0, + kPrimTypeBegin = kTypeUnknown, + kPrimTypeBuiltIn, // Built-in primitive operator + kPrimTypePyInferShape, // Primitive operator defined by custom + kPrimTypePyInferTensor, // Primitive operator defined by custom + kPrimTypeUserCustom +}; + +class Primitive : public Named { + public: + explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn) + : Named(name), + is_base_(is_base), + has_signature_(false), + prim_type_(prim_type), + record_evaluate_add_attr_(false) {} + + Primitive(const Primitive &prim) + : Named(prim), + attrs_(prim.attrs_), + instance_name_(prim.instance_name_), + is_base_(prim.is_base_), + has_signature_(prim.has_signature_), + prim_type_(prim.prim_type_), + record_evaluate_add_attr_(false) {} + + MS_DECLARE_PARENT(Primitive, Named); + + abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); + std::string ToString() const override { return name(); } + void BeginRecordAddAttr() { + evaluate_added_attrs_.clear(); + record_evaluate_add_attr_ = true; + } + void EndRecordAddAttr() { record_evaluate_add_attr_ = false; } + Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { + attrs_[name] = attr; + if (record_evaluate_add_attr_) { + evaluate_added_attrs_[name] = attr; + } + return *this; + } + + Primitive &SetAttrs(const std::unordered_map &attrs) { + for (auto &attr : attrs) { + attrs_[attr.first] = attr.second; + } + return *this; + } + + void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } + void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } + + ValuePtr GetAttr(const std::string &attrName) const { + auto iter = attrs_.find(attrName); + return iter == attrs_.cend() ? nullptr : iter->second; + } + + const std::unordered_map &attrs() const { return attrs_; } + const std::unordered_map &evaluate_added_attrs() const { return evaluate_added_attrs_; } + + // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. + bool HasAttr() const { return !attrs_.empty(); } + bool HasAttr(const std::string &attrName) const { + auto iter = attrs_.find(attrName); + return !(iter == attrs_.cend()); + } + void set_prim_type(const PrimType t) { prim_type_ = t; } + void set_instance_name(const std::string s) { instance_name_ = s; } + bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } + bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } + bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; } + + PrimType prim_type() const { return prim_type_; } + std::string instance_name() const { return instance_name_; } + std::string GetAttrsText() const; + bool operator==(const Value &other) const override; + bool operator==(const Primitive &other) const; + ~Primitive() override = default; + + void set_has_signature(bool has_signature) { has_signature_ = has_signature; } + bool has_signature() const { return has_signature_; } + bool is_base() const { return is_base_; } + virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call a empty function!"; } + virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call a empty function!"; } + + protected: + std::unordered_map attrs_; + std::unordered_map evaluate_added_attrs_; + + private: + std::string instance_name_; + bool is_base_; + bool has_signature_; + PrimType prim_type_; + bool record_evaluate_add_attr_; +}; + +inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { + os << *p; + return os; +} + +struct PrimitiveEqual { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { + MS_EXCEPTION_IF_NULL(t1); + MS_EXCEPTION_IF_NULL(t2); + return t1->name() == t2->name(); + } +}; + +struct PrimitiveHasher { + std::size_t operator()(PrimitivePtr const &prim) const { + MS_EXCEPTION_IF_NULL(prim); + return prim->Hash(); + } +}; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ diff --git a/mindspore/core/ir/primitive_py.cc b/mindspore/core/ir/primitive_py.cc new file mode 100644 index 0000000000..1a97487ddc --- /dev/null +++ b/mindspore/core/ir/primitive_py.cc @@ -0,0 +1,195 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/primitive_py.h" +#include +#include +#include "ir/signature.h" +#include "frontend/operator/ops.h" +#include "./common.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pybind11/pytypes.h" +#include "utils/convert_utils_base.h" +#include "utils/primitive_utils.h" +#include "utils/base_ref_py.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +namespace { +constexpr auto kBpropAttrName = "bprop"; +constexpr auto kCellHookAttrName = "cell_hook"; +constexpr auto kCellIDAttrName = "cell_id"; +void SyncData(const py::object &arg) { + if (py::isinstance(arg)) { + py::tuple arg_list = py::cast(arg); + for (size_t i = 0; i < arg_list.size(); i++) { + SyncData(arg_list[i]); + } + } + if (py::isinstance(arg)) { + auto tensor = py::cast(arg); + (void)tensor->data_sync(); + } +} +} // namespace +std::map PrimitivePy::hook_grad_; +static ValuePtr PyArgToValue(const py::object &arg) { + if (py::isinstance(arg) && + py::cast(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { + return nullptr; + } + return parse::data_converter::PyDataToValue(arg); +} + +void PrimitivePy::set_signatures( + std::vector> signatures) { + signatures_.clear(); + for (auto &signature : signatures) { + auto [name, rw, kind, arg_default, dtype] = signature; + auto default_value = PyArgToValue(arg_default); + signatures_.emplace_back(name, rw, kind, default_value, dtype); + } + set_has_signature(true); +} + +py::function PrimitivePy::GetBpropFunction() { + static const char *const get_bprop_func_name = "get_bprop"; + if (py::hasattr(python_obj_, get_bprop_func_name)) { + py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); + return fn; + } else { + auto fn = GetBpropFunctionByObj(python_obj_); + return fn; + } +} + +BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { + auto py_args = py::tuple(args.size()); + size_t i = 0; + for (auto &arg : args) { + py_args[i] = BaseRefToPyData(arg); + MS_LOG(DEBUG) << "arg:" << i << ":"; + i++; + } + py::object obj; + bool is_bprop = this->HasAttr(kBpropAttrName); + if (is_bprop) { + SyncData(py_args); + obj = hook_(*py_args); + return std::make_shared(obj); + } + SyncData(py_args[2]); + bool is_cell = this->HasAttr(kCellHookAttrName); + if (is_cell) { + auto cell_id = GetValue(this->GetAttr(kCellIDAttrName)); + auto iter = hook_grad_.find(cell_id); + if (iter != hook_grad_.end()) { + auto hook_args = py::tuple(3); + hook_args[0] = cell_id; + hook_args[1] = py::make_tuple(iter->second); + hook_args[2] = py::make_tuple(py_args[2]); + obj = hook_(*hook_args); + if (py::isinstance(obj)) { + obj = py_args[2]; + } + hook_grad_.erase(cell_id); + } else { + hook_grad_[cell_id] = py_args[2]; + obj = py_args[2]; + } + } else { + // Hook operator for execute variable hook function + obj = hook_(py::make_tuple(py_args[2])); + if (py::isinstance(obj)) { + obj = py_args[2]; + } + } + obj = py::make_tuple(obj); + return std::make_shared(obj); +} + +py::function PrimitivePy::GetComputeFunction() { + static const char *const compute_func_name = "vm_impl"; + + if (py::hasattr(python_obj_, compute_func_name)) { + MS_LOG(INFO) << name() << " compute_func_name"; + py::function fn = python_obj_.attr(compute_func_name).cast(); + return fn; + } + + static const std::string vm_module = "mindspore.ops.vm_impl_registry"; + static const std::string get_vm_impl_fn = "get_vm_impl_fn"; + MS_LOG(INFO) << name() << ": get_vm_impl_fn"; + py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn); + py::function vm_fn = get_fn(python_obj_); + + if (py::isinstance(vm_fn)) { + MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast(); + vm_fn = mindspore::GetComputeFunction(Primitive::name()); + } + return vm_fn; +} + +void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { + std::string attr_name = name; + ValuePtr converted_ret = nullptr; + if (py::isinstance(obj)) { + MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module"; + } + bool converted = parse::ConvertData(obj, &converted_ret); + if (!converted) { + MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); + } + (void)this->AddAttr(attr_name, converted_ret); +} + +py::dict PrimitivePy::GetAttrDict() { + py::dict attr_dict; + for (auto &attr : attrs_) { + attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); + } + return attr_dict; +} + +void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { + MS_EXCEPTION_IF_NULL(primitive); + if (!primitive->isa()) { + MS_LOG(EXCEPTION) << "Cannot copy a primtive which is not python primitive hook function to python primitive!"; + } + auto primitive_py = primitive->cast(); + MS_EXCEPTION_IF_NULL(primitive_py); + this->set_hook(primitive_py->hook()); +} + +REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { + (void)py::enum_(*m, "prim_type", py::arithmetic()) + .value("unknown", PrimType::kPrimTypeUnknown) + .value("builtin", PrimType::kPrimTypeBuiltIn) + .value("py_infer_shape", PrimType::kPrimTypePyInferShape) + .value("user_custom", PrimType::kPrimTypeUserCustom); + (void)py::class_>(*m, "Primitive_") + .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) + .def(py::init()) + .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") + .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") + .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") + .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.") + .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.") + .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name."); + })); +} // namespace mindspore diff --git a/mindspore/core/ir/primitive_py.h b/mindspore/core/ir/primitive_py.h new file mode 100644 index 0000000000..2dc45ac341 --- /dev/null +++ b/mindspore/core/ir/primitive_py.h @@ -0,0 +1,73 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ +#define MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ + +#include +#include +#include +#include +#include +#include + +#include "abstract/abstract_value.h" +#include "utils/misc.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" +#include "ir/primitive.h" +#include "ir/signature.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace py = pybind11; +namespace mindspore { +class PrimitivePy : public Primitive { + public: + PrimitivePy(const py::str &name, const py::object &python_obj) + : Primitive(name, false), python_obj_(python_obj), signatures_() {} + ~PrimitivePy() override = default; + MS_DECLARE_PARENT(PrimitivePy, Primitive); + py::function GetBpropFunction(); + py::function GetComputeFunction(); + + void set_signatures( + std::vector> + signatures); + + const std::vector &signatures() const { return signatures_; } + + void CopyHookFunction(const PrimitivePtr &primitive) override; + + void AddPyAttr(const py::str &name, const py::object &obj); + + py::dict GetAttrDict(); + void set_hook(const py::function &hook) { hook_ = hook; } + py::function hook() const { return hook_; } + BaseRef RunHookFunction(const VectorRef &args) const override; + const bool parse_info_ = true; + const py::object &GetPyObj() const { return python_obj_; } + bool is_tuple_input_ = false; + + private: + py::object python_obj_; + py::function hook_; + std::vector signatures_; + static std::map hook_grad_; +}; + +using PrimitivePyPtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_ diff --git a/mindspore/core/ir/scalar.h b/mindspore/core/ir/scalar.h new file mode 100644 index 0000000000..adae8c65f9 --- /dev/null +++ b/mindspore/core/ir/scalar.h @@ -0,0 +1,362 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_SCALAR_H_ +#define MINDSPORE_CCSRC_IR_SCALAR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "base/base.h" +#include "ir/dtype.h" +#include "ir/dtype/number.h" + +using std::fabs; + +namespace mindspore { +class Scalar : public Value { + public: + Scalar() = default; + explicit Scalar(const TypePtr t) : Value(t) {} + ~Scalar() override = default; + MS_DECLARE_PARENT(Scalar, Value) + virtual bool IsZero() = 0; + virtual bool IsOne() = 0; + abstract::AbstractBasePtr ToAbstract() override; + + protected: + std::size_t hash_ = 0; +}; +using ScalarPtr = std::shared_ptr; + +class BoolImm : public Scalar { + public: + explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash{}(v_); } + ~BoolImm() override = default; + MS_DECLARE_PARENT(BoolImm, Scalar) + std::size_t hash() const override { return hash_; } + bool value() const { return v_; } + bool IsZero() override { return v_ == false; } + bool IsOne() override { return v_ == true; } + bool operator==(const Value &other) const override; + bool operator==(const BoolImm &other) const; + std::string ToString() const override { + if (v_) { + return "true"; + } else { + return "false"; + } + } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "Bool(" << v_ << ")"; + return oss.str(); + } + + private: + bool v_; +}; +using BoolImmPtr = std::shared_ptr; +IMM_TRAITS(BoolImmPtr, bool) + +class IntergerImm : public Scalar { + public: + IntergerImm() = default; + explicit IntergerImm(const TypePtr &t) : Scalar(t) {} + ~IntergerImm() override = default; + MS_DECLARE_PARENT(IntergerImm, Scalar) +}; + +class Int8Imm : public IntergerImm { + public: + Int8Imm() : IntergerImm(kInt8), v_(0) {} + explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash{}(v_); } + ~Int8Imm() override = default; + MS_DECLARE_PARENT(Int8Imm, IntergerImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return v_ == 0; } + bool IsOne() override { return v_ == 1; } + int8_t value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const Int8Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "I8(" << v_ << ")"; + return oss.str(); + } + + private: + int8_t v_; +}; +using Int8ImmPtr = std::shared_ptr; +IMM_TRAITS(Int8ImmPtr, int8_t) + +class Int16Imm : public IntergerImm { + public: + Int16Imm() : IntergerImm(kInt16), v_(0) {} + explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash{}(v_); } + ~Int16Imm() override = default; + MS_DECLARE_PARENT(Int16Imm, IntergerImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return v_ == 0; } + bool IsOne() override { return v_ == 1; } + int16_t value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const Int16Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "I16(" << v_ << ")"; + return oss.str(); + } + + private: + int16_t v_; +}; +using Int16ImmPtr = std::shared_ptr; +IMM_TRAITS(Int16ImmPtr, int16_t) + +class Int32Imm : public IntergerImm { + public: + Int32Imm() : IntergerImm(kInt32), v_(0) {} + explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash{}(v_); } + ~Int32Imm() override = default; + MS_DECLARE_PARENT(Int32Imm, IntergerImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return v_ == 0; } + bool IsOne() override { return v_ == 1; } + int32_t value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const Int32Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "I32(" << v_ << ")"; + return oss.str(); + } + + private: + int32_t v_; +}; +using Int32ImmPtr = std::shared_ptr; +IMM_TRAITS(Int32ImmPtr, int32_t) + +class Int64Imm : public IntergerImm { + public: + Int64Imm() : IntergerImm(kInt64), v_(0) {} + explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash{}(v_); } + ~Int64Imm() override = default; + MS_DECLARE_PARENT(Int64Imm, IntergerImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return v_ == 0; } + bool IsOne() override { return v_ == 1; } + int64_t value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const Int64Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "I64(" << v_ << ")"; + return oss.str(); + } + + private: + int64_t v_; +}; +using Int64ImmPtr = std::shared_ptr; +IMM_TRAITS(Int64ImmPtr, int64_t) + +class UInt8Imm : public IntergerImm { + public: + UInt8Imm() : IntergerImm(kUInt8), v_(0) {} + explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash{}(v_); } + ~UInt8Imm() override = default; + MS_DECLARE_PARENT(UInt8Imm, IntergerImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return v_ == 0; } + bool IsOne() override { return v_ == 1; } + uint8_t value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const UInt8Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "U8(" << v_ << ")"; + return oss.str(); + } + + private: + uint8_t v_; +}; +using UInt8ImmPtr = std::shared_ptr; +IMM_TRAITS(UInt8ImmPtr, uint8_t); + +class UInt16Imm : public IntergerImm { + public: + UInt16Imm() : IntergerImm(kUInt16), v_(0) {} + explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash{}(v_); } + ~UInt16Imm() override = default; + MS_DECLARE_PARENT(UInt16Imm, IntergerImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return v_ == 0; } + bool IsOne() override { return v_ == 1; } + uint16_t value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const UInt16Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "U16(" << v_ << ")"; + return oss.str(); + } + + private: + uint16_t v_; +}; +using UInt16ImmPtr = std::shared_ptr; +IMM_TRAITS(UInt16ImmPtr, uint16_t); + +class UInt32Imm : public IntergerImm { + public: + UInt32Imm() : IntergerImm(kUInt32), v_(0) {} + explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash{}(v_); } + ~UInt32Imm() override = default; + MS_DECLARE_PARENT(UInt32Imm, IntergerImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return v_ == 0; } + bool IsOne() override { return v_ == 1; } + uint32_t value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const UInt32Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "U32(" << v_ << ")"; + return oss.str(); + } + + private: + uint32_t v_; +}; +using UInt32ImmPtr = std::shared_ptr; +IMM_TRAITS(UInt32ImmPtr, uint32_t); + +class UInt64Imm : public IntergerImm { + public: + UInt64Imm() : IntergerImm(kUInt64), v_(0) {} + explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash{}(v); } + ~UInt64Imm() override = default; + MS_DECLARE_PARENT(UInt64Imm, IntergerImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return v_ == 0; } + bool IsOne() override { return v_ == 1; } + uint64_t value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const UInt64Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "U64(" << v_ << ")"; + return oss.str(); + } + + private: + uint64_t v_; +}; +using UInt64ImmPtr = std::shared_ptr; +IMM_TRAITS(UInt64ImmPtr, uint64_t); + +class FloatImm : public Scalar { + public: + FloatImm() = default; + explicit FloatImm(const TypePtr &t) : Scalar(t) {} + ~FloatImm() override = default; + MS_DECLARE_PARENT(FloatImm, Scalar) +}; +using FloatImmPtr = std::shared_ptr; + +class FP32Imm : public FloatImm { + public: + FP32Imm() : FloatImm(kFloat32), v_(0.0) {} + explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash{}(v_); } + ~FP32Imm() override = default; + MS_DECLARE_PARENT(FP32Imm, FloatImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } + bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } + float value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const FP32Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "F32(" << v_ << ")"; + return oss.str(); + } + + private: + float v_; +}; +using FP32ImmPtr = std::shared_ptr; +IMM_TRAITS(FP32ImmPtr, float) + +class FP64Imm : public FloatImm { + public: + FP64Imm() : FloatImm(kFloat64), v_(0.0) {} + explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash{}(v_); } + ~FP64Imm() override = default; + MS_DECLARE_PARENT(FP64Imm, FloatImm) + std::size_t hash() const override { return hash_; } + bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } + bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } + double value() const { return v_; } + bool operator==(const Value &other) const override; + bool operator==(const FP64Imm &other) const; + std::string ToString() const override { return std::to_string(v_); } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "F64(" << v_ << ")"; + return oss.str(); + } + + private: + double v_; +}; +using FP64ImmPtr = std::shared_ptr; +IMM_TRAITS(FP64ImmPtr, double) + +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_SCALAR_H_ diff --git a/mindspore/ccsrc/ir/scope.cc b/mindspore/core/ir/scope.cc similarity index 100% rename from mindspore/ccsrc/ir/scope.cc rename to mindspore/core/ir/scope.cc diff --git a/mindspore/ccsrc/ir/scope.h b/mindspore/core/ir/scope.h similarity index 100% rename from mindspore/ccsrc/ir/scope.h rename to mindspore/core/ir/scope.h diff --git a/mindspore/core/ir/signature.h b/mindspore/core/ir/signature.h new file mode 100644 index 0000000000..e9a5a2e1ca --- /dev/null +++ b/mindspore/core/ir/signature.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_SIGNATURE_H_ +#define MINDSPORE_CCSRC_IR_SIGNATURE_H_ + +#include +#include +#include "ir/value.h" + +namespace mindspore { +// Input signature, support type +enum SignatureEnumRW { + // describe the arguments action on read and write + kRWRead = 0, // use the value of the input + kRWWrite, // use the key of the input + kRWRef, // use the ref of the input + kRWEmptyDefaultValue, + kRWDefault = kRWRead +}; +enum SignatureEnumKind { + kKindPositionalKeyword = 0, // use value of the input start from this arg + kKindVarPositional, // use key of the input start from this arg + kKindKeywordOnly, + kKindVarKeyword, // use ref of the input start from this arg + kKindEmptyDefaultValue, + kKindDefault = kKindPositionalKeyword +}; +enum SignatureEnumDType { + kDType = 0, + kDType1, + kDType2, + kDType3, + kDType4, + kDType5, + kDType6, + kDType7, + kDType8, + kDType9, + kDTypeEmptyDefaultValue +}; +struct Signature { + std::string name; + SignatureEnumRW rw; + SignatureEnumKind kind; + ValuePtr default_value; // nullptr for no default value + SignatureEnumDType dtype; + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, + const ValuePtr &arg_default, const SignatureEnumDType &arg_dtype) + : name(arg_name), rw(rw_tag), kind(arg_kind), default_value(arg_default), dtype(arg_dtype) {} + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) + : Signature(arg_name, rw_tag, arg_kind, nullptr, SignatureEnumDType::kDTypeEmptyDefaultValue) {} +}; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_SIGNATURE_H_ diff --git a/mindspore/core/ir/signature_py.cc b/mindspore/core/ir/signature_py.cc new file mode 100644 index 0000000000..f513df8533 --- /dev/null +++ b/mindspore/core/ir/signature_py.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/signature.h" +#include "pybind11/operators.h" +#include "pybind_api/api_register.h" +#include "pipeline/jit/parse/data_converter.h" + +namespace py = pybind11; + +namespace mindspore { +// Bind SignatureEnumRW as a python class. +REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { + (void)py::enum_(*m, "signature_rw", py::arithmetic()) + .value("RW_READ", SignatureEnumRW::kRWRead) + .value("RW_WRITE", SignatureEnumRW::kRWWrite) + .value("RW_REF", SignatureEnumRW::kRWRef) + .value("RW_EMPTY_DEFAULT_VALUE", SignatureEnumRW::kRWEmptyDefaultValue); + (void)py::enum_(*m, "signature_kind", py::arithmetic()) + .value("KIND_POSITIONAL_KEYWORD", SignatureEnumKind::kKindPositionalKeyword) + .value("KIND_VAR_POSITIONAL", SignatureEnumKind::kKindVarPositional) + .value("KIND_KEYWORD_ONLY", SignatureEnumKind::kKindKeywordOnly) + .value("KIND_VAR_KEYWARD", SignatureEnumKind::kKindVarKeyword) + .value("KIND_EMPTY_DEFAULT_VALUE", SignatureEnumKind::kKindEmptyDefaultValue); + (void)py::enum_(*m, "signature_dtype", py::arithmetic()) + .value("T", SignatureEnumDType::kDType) + .value("T1", SignatureEnumDType::kDType1) + .value("T2", SignatureEnumDType::kDType2) + .value("T3", SignatureEnumDType::kDType3) + .value("T4", SignatureEnumDType::kDType4) + .value("T5", SignatureEnumDType::kDType5) + .value("T6", SignatureEnumDType::kDType6) + .value("T7", SignatureEnumDType::kDType7) + .value("T8", SignatureEnumDType::kDType8) + .value("T9", SignatureEnumDType::kDType9) + .value("T_EMPTY_DEFAULT_VALUE", SignatureEnumDType::kDTypeEmptyDefaultValue); + })); +} // namespace mindspore diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc new file mode 100644 index 0000000000..c04c2cca96 --- /dev/null +++ b/mindspore/core/ir/tensor.cc @@ -0,0 +1,520 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/tensor.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "runtime/device/device_address.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace tensor { +constexpr auto kEllipsis = "..."; +constexpr auto kThreshold = 6; + +constexpr auto kThreshold1DFloat = kThreshold * 2; +constexpr auto kThreshold1DInt = kThreshold * 4; +constexpr auto kThreshold1DBool = kThreshold * 2; + +static std::string MakeId() { + // Use atomic to make id generator thread safe. + static std::atomic last_id{1}; + return "T" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed)); +} + +static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) { + return data_type ? data_type->type_id() : defaultTypeId; +} + +static size_t SizeOf(const std::vector &shape) { + return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); +} + +template +std::vector CopyData(const std::vector &shape, void *data, TypeId data_type) { + const size_t count = SizeOf(shape); + switch (data_type) { + case kNumberTypeBool: + case kNumberTypeUInt8: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeInt8: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeInt16: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeInt32: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeInt64: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeUInt16: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeUInt32: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeUInt64: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeFloat16: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeFloat32: { + const float *buf = static_cast(data); + return std::vector(buf, buf + count); + } + case kNumberTypeFloat64: { + auto buf = static_cast(data); + return std::vector(buf, buf + count); + } + default: + break; + } + MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; +} + +template +std::vector CopyData(const std::vector &shape, void *data, size_t data_len) { + size_t size = SizeOf(shape); + if (size * sizeof(T) != data_len) { + MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T) + << " item size " << sizeof(T); + } + auto buf = static_cast(data); + return {buf, buf + size}; +} + +// Tensor data implementation. +template +class TensorDataImpl : public TensorData { + public: + explicit TensorDataImpl(const std::vector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {} + ~TensorDataImpl() = default; + + TensorDataImpl(const std::vector &shape, void *data, size_t data_len) + : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData(shape, data, data_len)) {} + + TensorDataImpl(const std::vector &shape, void *data, TypeId data_type) + : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData(shape, data, data_type)) {} + + template + TensorDataImpl(const std::vector &shape, InputIt first, InputIt last) + : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(first, last) {} + + template + TensorDataImpl(const std::vector &shape, Scalar scalar) + : ndim_(shape.size()), data_size_(SizeOf(shape)), data_({static_cast(scalar)}) {} + + ssize_t size() const override { return static_cast(data_size_); } + + ssize_t itemsize() const override { return static_cast(sizeof(T)); } + + ssize_t nbytes() const override { return size() * itemsize(); } + + ssize_t ndim() const override { return static_cast(ndim_); } + + void *data() override { + static std::vector empty_data(1); + if (data_size_ == 0) { + // Prevent null pointer for empty shape. + return empty_data.data(); + } + // Lazy allocation. + if (data_.empty()) { + data_.resize(data_size_); + } + return data_.data(); + } + + bool equals(const TensorData &other) const override { + auto ptr = dynamic_cast *>(&other); + if (ptr) { + return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) && (data_ == ptr->data_)); + } + return false; + } + + std::string ToString(const TypeId type, const std::vector &shape) const override { + constexpr auto valid = + std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value; + static_assert(valid, "Type is invalid"); + if (data_size_ == 0) { + return ""; + } + if (data_.empty()) { + return ""; + } + + std::ostringstream ss; + if (data_size_ == 1 && ndim_ == 0) { // Scalar + OutputDataString(ss, type, 0, 0, 1); + return ss.str(); + } + ssize_t cursor = 0; + SummaryStringRecursive(ss, type, shape, &cursor, 0); + return ss.str(); + } + + private: + void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const { + bool isScalar = ndim_ == 0 && end - start == 1; + int linefeedThreshold; + constexpr auto isFloat = + std::is_same::value || std::is_same::value || std::is_same::value; + for (ssize_t i = start; i < end && (cursor + i) < static_cast(data_size_); i++) { + const auto value = data_[cursor + i]; + if constexpr (isFloat) { + if (isScalar) { + ss << value; + } else { + ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right) + << value; + } + linefeedThreshold = kThreshold1DFloat; + } else if (type == kNumberTypeBool) { + if (isScalar) { + ss << (value == 0 ? "False" : "True"); + } else { + ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True"); + } + linefeedThreshold = kThreshold1DBool; + } else { + constexpr auto isSigned = std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value; + if constexpr (isSigned) { + if (!isScalar && static_cast(value) >= 0) { + ss << ' '; + } + } + if constexpr (std::is_same::value) { + ss << static_cast(value); + } else if constexpr (std::is_same::value) { + ss << static_cast(value); + } else { + ss << value; + } + linefeedThreshold = kThreshold1DInt; + } + if (!isScalar && i != end - 1) { + ss << ' '; + } + if (!isScalar && ndim_ == 1 && (i + 1) % linefeedThreshold == 0) { + // Add a line feed every {threshold of type} for 1D tensor. + ss << '\n' << ' '; + } + } + } + + void SummaryStringRecursive(std::ostringstream &ss, const TypeId type, const std::vector &shape, ssize_t *cursor, + ssize_t depth) const { + if (depth >= static_cast(ndim_)) { + return; + } + ss << '['; + if (depth == static_cast(ndim_) - 1) { // Bottom dimension + ssize_t num = shape[depth]; + if (num > kThreshold && ndim_ > 1) { + OutputDataString(ss, type, *cursor, 0, kThreshold / 2); + ss << ' ' << kEllipsis << ' '; + OutputDataString(ss, type, *cursor, num - kThreshold / 2, num); + } else { + OutputDataString(ss, type, *cursor, 0, num); + } + *cursor += num; + } else { // Middle dimension + ssize_t num = shape[depth]; + // Handle the first half. + for (ssize_t i = 0; i < std::min(static_cast(kThreshold / 2), num); i++) { + if (i > 0) { + ss << '\n'; + ss << std::setw(depth + 1) << ' '; // Add the indent. + } + SummaryStringRecursive(ss, type, shape, cursor, depth + 1); + } + // Handle the ignored part. + if (num > kThreshold) { + ss << '\n'; + ss << std::setw(depth + 1) << ' '; // Add the indent. + ss << kEllipsis; + // Ignored at this layer. + ssize_t ignored = shape[depth + 1]; + for (ssize_t i = depth + 2; i < static_cast(ndim_); i++) { + ignored *= shape[i]; + } + // Multiple with ignored layers number. + ignored *= num - kThreshold; + + *cursor += ignored; + } + // Handle the second half. + if (num > kThreshold / 2) { + for (ssize_t i = num - kThreshold / 2; i < num; i++) { + ss << '\n'; + ss << std::setw(depth + 1) << ' '; // Add the indent. + SummaryStringRecursive(ss, type, shape, cursor, depth + 1); + } + } + } + ss << ']'; + } + + size_t ndim_{0}; + size_t data_size_{0}; + std::vector data_; +}; + +template +TensorDataPtr MakeTensorData(TypeId data_type, const std::vector &shape, const Args... args) { + switch (data_type) { + case kNumberTypeBool: + case kNumberTypeUInt8: + return std::make_shared>(shape, args...); + case kNumberTypeInt8: + return std::make_shared>(shape, args...); + case kNumberTypeInt16: + return std::make_shared>(shape, args...); + case kNumberTypeInt32: + return std::make_shared>(shape, args...); + case kNumberTypeInt64: + return std::make_shared>(shape, args...); + case kNumberTypeUInt16: + return std::make_shared>(shape, args...); + case kNumberTypeUInt32: + return std::make_shared>(shape, args...); + case kNumberTypeUInt64: + return std::make_shared>(shape, args...); + case kNumberTypeFloat16: + return std::make_shared>(shape, args...); + case kNumberTypeFloat32: + return std::make_shared>(shape, args...); + case kNumberTypeFloat64: + return std::make_shared>(shape, args...); + default: + break; + } + MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; +} + +Tensor::Tensor(const Tensor &tensor) + : MetaTensor(tensor), + init_flag_(tensor.init_flag_), + data_(tensor.data_), + dirty_(tensor.dirty_), + id_(tensor.id_), + device_sync_(tensor.device_sync_) {} + +Tensor::Tensor(const Tensor &tensor, TypeId data_type) + : MetaTensor(data_type, tensor.shape_), + init_flag_(tensor.init_flag_), + data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)), + dirty_(tensor.dirty_), + id_(tensor.id_), + device_sync_(tensor.device_sync_) {} + +Tensor::Tensor(TypeId data_type, const std::vector &shape, TensorDataPtr data) + : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {} + +Tensor::Tensor(TypeId data_type, const std::vector &shape) + : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {} + +Tensor::Tensor(TypeId data_type, const std::vector &shape, void *data, size_t data_len) + : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {} + +Tensor::Tensor(TypeId data_type, const std::vector &shape, void *data, TypeId src_data_type) + : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {} + +Tensor::Tensor(const std::vector &input, const TypePtr &data_type) + : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast(input.size())}), + data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())), + id_(MakeId()) {} + +Tensor::Tensor(const std::vector &input, const TypePtr &data_type) + : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast(input.size())}), + data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())), + id_(MakeId()) {} + +Tensor::Tensor(int64_t input, const TypePtr &data_type) + : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}), + data_(MakeTensorData(data_type_, {}, input)), + id_(MakeId()) {} + +Tensor::Tensor(double input, const TypePtr &data_type) + : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {}), + data_(MakeTensorData(data_type_, {}, input)), + id_(MakeId()) {} + +bool Tensor::operator==(const Tensor &tensor) const { + return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_)); +} + +bool Tensor::ValueEqual(const Tensor &tensor) const { + return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_))); +} +// assgin value to this tensor +Tensor &Tensor::AssignValue(const Tensor &tensor) { + if (this != &tensor) { + MetaTensor::operator=(tensor); + dirty_ = tensor.dirty_; + device_sync_ = tensor.device_sync_; + data_ = tensor.data_; + id_ = tensor.id_; + } + return *this; +} +abstract::AbstractBasePtr Tensor::ToAbstract() { + auto tens = shared_from_base(); + auto dtype = tens->Dtype(); + if (!IsSubType(dtype, kNumber)) { + MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << "."; + } + auto tensor_shape = tens->shape(); + auto abs_tensor = std::make_shared(dtype, tensor_shape); + abs_tensor->set_value(shared_from_base()); + return abs_tensor; +} + +std::string Tensor::GetShapeAndDataTypeInfo() const { + std::ostringstream buf; + buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString(); + return buf.str(); +} + +std::string Tensor::ToString() const { + const int small_tensor_size = 30; + std::ostringstream buf; + buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString(); + // only print small tensor + if (DataSize() < small_tensor_size) { + buf << ", value:" << data().ToString(data_type_, shape()); + } + return buf.str(); +} + +std::string Tensor::ToStringRepr() const { + std::ostringstream buf; + auto type_ptr = this->Dtype(); + MS_EXCEPTION_IF_NULL(type_ptr); + buf << "Tensor shape:[" << shape() << "]" << type_ptr->ToString(); + buf << "\nvalue:" << data().ToString(data_type_, shape()); + return buf.str(); +} + +void Tensor::data_sync() const { + if (device_sync_ != nullptr) { + if (!device_sync_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { + MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy."; + } + } +} + +TypeId Tensor::set_data_type(const TypeId data_type) { + if (data_type != data_type_) { + data_ = MakeTensorData(data_type, shape_, data_->data(), data_type_); + return MetaTensor::set_data_type(data_type); + } + return data_type; +} +} // namespace tensor + +namespace inference { +MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector &shape) { + return new Tensor(data_type, shape); +} + +Tensor::Tensor(TypeId data_type, const std::vector &shape) { + this->tensor_impl_ = std::make_shared(data_type, shape); +} + +Tensor::Tensor(std::shared_ptr tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); } + +TypeId Tensor::data_type() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->data_type(); +} + +TypeId Tensor::set_data_type(TypeId data_type) { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->set_data_type(data_type); +} + +std::vector Tensor::shape() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->shape(); +} + +size_t Tensor::set_shape(const std::vector &shape) { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->set_shape(shape); +} + +int Tensor::DimensionSize(size_t index) const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->DimensionSize(index); +} + +int Tensor::ElementsNum() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->ElementsNum(); +} + +std::size_t Tensor::hash() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->hash(); +} + +std::shared_ptr Tensor::tensor() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_; +} + +size_t Tensor::Size() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->data().nbytes(); +} + +void *Tensor::MutableData() const { + MS_ASSERT(this->tensor_impl_ != nullptr); + return this->tensor_impl_->data_c(); +} + +} // namespace inference +} // namespace mindspore diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h new file mode 100644 index 0000000000..727fb0fdd8 --- /dev/null +++ b/mindspore/core/ir/tensor.h @@ -0,0 +1,276 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_TENSOR_H_ +#define MINDSPORE_CCSRC_IR_TENSOR_H_ + +#include +#include +#include +#include + +#include "Eigen/Core" +#include "ir/device_sync.h" +#include "ir/meta_tensor.h" +#include "include/ms_tensor.h" +#include "utils/log_adapter.h" + +using float16 = Eigen::half; + +// brief mindspore namespace. +// +// mindspore namespace is the top level namespace of MindSpore project. +// Other namespace should be a sub namespace of mindspore namespace in the ME project. +namespace mindspore { +// brief mindspore::tensor namespace +// +// A sub namespace in ME to support tensor related definition. +namespace tensor { +// Tensor data interface. +class TensorData { + public: + /// Total number of elements. + virtual ssize_t size() const = 0; + /// Byte size of a single element. + virtual ssize_t itemsize() const = 0; + /// Total number of bytes. + virtual ssize_t nbytes() const = 0; + /// Number of dimensions. + virtual ssize_t ndim() const = 0; + /// Data pointer. + virtual void *data() = 0; + /// Is data equals. + virtual bool equals(const TensorData &other) const = 0; + /// To string. + virtual std::string ToString(const TypeId type, const std::vector &shape) const = 0; +}; + +using TensorDataPtr = std::shared_ptr; + +// Tensor entity class +class Tensor : public MetaTensor { + public: + abstract::AbstractBasePtr ToAbstract() override; + + // brief Create tensor from another tensor, data is shared. + // + // param tensor [Tensor] The input tensor. + explicit Tensor(const Tensor &tensor); + + // brief Create tensor with given data type from another tensor. + // + // param tensor [Tensor] The input tensor. + // param data_type [TypeId] The new tensor data type. + Tensor(const Tensor &tensor, TypeId data_type); + + // brief Create tensor with the given shared tensor data. + // + // param data_type [TypeId] Data type of the tensor. + // param shape The shape represented by std::vector of the tensor. + // param data The shared tensor data. + Tensor(TypeId data_type, const std::vector &shape, TensorDataPtr data); + + // brief Create an all zero tensor. + // + // param data_type [TypeId] Data type of the tensor. + // param shape The shape represented by std::vector of the tensor. + Tensor(TypeId data_type, const std::vector &shape); + + // brief Create a tensor with input data buffer. + // + // param data_type [TypeId] Data type of the tensor. + // param shape The shape represented by std::vector of the tensor. + // param data The input data to be copied into tensor. + // param data_len The length of data in bytes. + Tensor(TypeId data_type, const std::vector &shape, void *data, size_t data_len); + + // brief Create a tensor with input data buffer and given source data type. + // + // param data_type [TypeId] Data type of the tensor. + // param shape The shape represented by std::vector of the tensor. + // param data The input data to be copied into tensor. + // param src_data_type The source data type. + Tensor(TypeId data_type, const std::vector &shape, void *data, TypeId src_data_type); + + // brief Create 1 dimension tensor from an int vector. + // + // param input [std::vector] the data for tensor + // param data_type [TypeId] data type + explicit Tensor(const std::vector &input, const TypePtr &data_type = nullptr); + + // brief Create 1 dimension tensor from a float vector. + // + // param input [std::vector] the data for tensor + // param data_type [TypeId] data type + explicit Tensor(const std::vector &input, const TypePtr &data_type = nullptr); + + // brief Create 0 dimension tensor from an int scalar. + // + // param input [int64] the data for tensor + // param data_type [TypeId] data type + explicit Tensor(int64_t input, const TypePtr &data_type = nullptr); + + // brief Create 0 dimension tensor from a float scalar. + // + // param input [double] the data for tensor + // param data_type [TypeId] data type + explicit Tensor(double input, const TypePtr &data_type = nullptr); + + ~Tensor() override = default; + + MS_DECLARE_PARENT(Tensor, MetaTensor); + + // brief Compares two Tensor objects. + // + // Compare two tensor objects to see if they have same data type, shape and data address. + // + // param tensor The Tensor object to be compared. + // return true: If having same type, shape and data address, return true, or return false. + bool operator==(const Tensor &tensor) const; + + // It is different from 'operator==' which just compare shape/type/address, + // it do real value comparison. + bool ValueEqual(const Tensor &tensor) const; + + // assgin value to this tensor + Tensor &AssignValue(const Tensor &tensor); + + bool operator==(const Value &other) const override { + if (other.isa()) { + auto &other_ = static_cast(other); + return *this == other_; + } + return false; + } + + // brief Gets tensor's dimension + // + // return The number of dimensions of the tensor data. + int DataDim() const { return static_cast(data().ndim()); } + + // brief Getting tensor data size + // + // return The total number of elements of the tensor data. + int DataSize() const { return static_cast(data().size()); } + + // brief Get the data type fo the tensor for C++ + // + // return [int] The tensor's data type will be cast to int to return. + int data_type_c() const { return static_cast(data_type_); } + + // brief Get the tensor's shape for C++ + // + // return [std::vector] + std::vector shape_c(void) const { return shape(); } + + // brief Get Tensor data pointer for c++ type + // + // return The pointer to the object + void *data_c() { return data().data(); } + + // brief Get Tensor data byte-size for c++ type + // + // return byte size of Tensor data + size_t Size() const { return data().nbytes(); } + + void *data_c() const { return data_->data(); } + + // brief Sync data with device. + void data_sync() const; + + // brief Get the internal data object. + // + // return The reference to internal data object. + TensorData &data() { return *data_; } + + // brief Get the internal data shared pointer. + // + // return The reference to internal data object. + const TensorDataPtr &data_ptr() const { return data_; } + + // brief Get the internal data object. + // + // return The reference to internal data object. + const TensorData &data() const { return *data_; } + + TypeId set_data_type(const TypeId data_type) override; + + std::string GetShapeAndDataTypeInfo() const; + + std::string ToString() const override; + + std::string ToStringRepr() const; + + bool is_init() const { return init_flag_; } + void set_init_flag(bool flag) { init_flag_ = flag; } + + bool is_dirty() const { return dirty_; } + void set_dirty(const bool dirty) { dirty_ = dirty; } + + DeviceSyncPtr device_address() const { return device_sync_; } + void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; } + + std::string id() const { return id_; } + + const bool parse_info_ = true; + + private: + bool init_flag_{false}; + TensorDataPtr data_{nullptr}; + bool dirty_{true}; + std::string id_{""}; + DeviceSyncPtr device_sync_{nullptr}; +}; +using TensorPtr = std::shared_ptr; +using TensorPtrList = std::vector>; +} // namespace tensor + +namespace inference { +class Tensor : public MSTensor { + public: + Tensor(TypeId data_type, const std::vector &shape); + + explicit Tensor(std::shared_ptr tensor_ptr); + + ~Tensor() = default; + + TypeId data_type() const override; + + TypeId set_data_type(const TypeId data_type) override; + + std::vector shape() const override; + + size_t set_shape(const std::vector &shape) override; + + int DimensionSize(size_t index) const override; + + int ElementsNum() const override; + + std::size_t hash() const override; + + std::shared_ptr tensor() const; + + size_t Size() const override; + + void *MutableData() const override; + + protected: + std::shared_ptr tensor_impl_; +}; +} // namespace inference +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_TENSOR_H_ diff --git a/mindspore/core/ir/tensor_py.cc b/mindspore/core/ir/tensor_py.cc new file mode 100644 index 0000000000..ef78d2720e --- /dev/null +++ b/mindspore/core/ir/tensor_py.cc @@ -0,0 +1,389 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/tensor_py.h" + +#include +#include +#include +#include +#include + +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace tensor { + +static TypeId GetDataType(const py::buffer_info &buf) { + if (buf.format.size() == 1) { + switch (buf.format.front()) { + case 'e': + case 'f': + case 'd': + switch (buf.itemsize) { + case 2: + return TypeId::kNumberTypeFloat16; + case 4: + return TypeId::kNumberTypeFloat32; + case 8: + return TypeId::kNumberTypeFloat64; + } + break; + case 'b': + case 'h': + case 'i': + case 'l': + case 'q': + switch (buf.itemsize) { + case 1: + return TypeId::kNumberTypeInt8; + case 2: + return TypeId::kNumberTypeInt16; + case 4: + return TypeId::kNumberTypeInt32; + case 8: + return TypeId::kNumberTypeInt64; + } + break; + case 'B': + case 'H': + case 'I': + case 'L': + case 'Q': + switch (buf.itemsize) { + case 1: + return TypeId::kNumberTypeUInt8; + case 2: + return TypeId::kNumberTypeUInt16; + case 4: + return TypeId::kNumberTypeUInt32; + case 8: + return TypeId::kNumberTypeUInt64; + } + break; + case '?': + return TypeId::kNumberTypeBool; + } + } + MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize; + return TypeId::kTypeUnknown; +} + +static std::string GetPyTypeFormat(TypeId data_type) { + switch (data_type) { + case TypeId::kNumberTypeFloat16: + return "e"; + case TypeId::kNumberTypeFloat32: + return py::format_descriptor::format(); + case TypeId::kNumberTypeFloat64: + return py::format_descriptor::format(); + case TypeId::kNumberTypeUInt8: + return py::format_descriptor::format(); + case TypeId::kNumberTypeUInt16: + return py::format_descriptor::format(); + case TypeId::kNumberTypeUInt32: + return py::format_descriptor::format(); + case TypeId::kNumberTypeUInt64: + return py::format_descriptor::format(); + case TypeId::kNumberTypeInt8: + return py::format_descriptor::format(); + case TypeId::kNumberTypeInt16: + return py::format_descriptor::format(); + case TypeId::kNumberTypeInt32: + return py::format_descriptor::format(); + case TypeId::kNumberTypeInt64: + return py::format_descriptor::format(); + case TypeId::kNumberTypeBool: + return py::format_descriptor::format(); + default: + MS_LOG(WARNING) << "Unsupported DataType " << data_type << "."; + return ""; + } +} + +static bool IsCContiguous(const py::array &input) { + auto flags = static_cast(input.flags()); + return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0; +} + +TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) { + // Get input buffer info. + py::buffer_info buf = input.request(); + // Check data types. + auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown; + auto buf_type = GetDataType(buf); + if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) { + MS_LOG(EXCEPTION) << "Unsupported tensor type!"; + } + // Use buf type as data type if type_ptr not set. + if (data_type == TypeId::kTypeUnknown) { + data_type = buf_type; + } + // Convert input array to C contiguous if need. + std::unique_ptr tmp_buf; + if (!IsCContiguous(input)) { + Py_buffer pybuf; + if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) { + MS_LOG(EXCEPTION) << "Failed to get buffer from the input!"; + } + tmp_buf = std::make_unique(pybuf.len); + if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) { + MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer."; + } + PyBuffer_Release(&pybuf); + buf.ptr = tmp_buf.get(); + } + // Get tensor shape. + std::vector shape(buf.shape.begin(), buf.shape.end()); + if (data_type == buf_type) { + // Use memory copy if input data type is same as the required type. + return std::make_shared(data_type, shape, buf.ptr, buf.size * buf.itemsize); + } + // Create tensor with data type converted. + return std::make_shared(data_type, shape, buf.ptr, buf_type); +} + +static std::vector GetStrides(const std::vector &shape, ssize_t item_size) { + std::vector strides; + strides.reserve(shape.size()); + const auto ndim = shape.size(); + for (size_t i = 0; i < ndim; ++i) { + auto stride = item_size; + for (size_t j = i + 1; j < ndim; ++j) { + stride *= shape[j]; + } + strides.push_back(stride); + } + return strides; +} + +static py::buffer_info GetPyBufferInfo(const Tensor &tensor) { + std::vector shape(tensor.shape().begin(), tensor.shape().end()); + std::vector strides = GetStrides(shape, tensor.data().itemsize()); + return py::buffer_info{ + tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides}; +} + +py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) { + auto &shape = tensor.shape(); + py::tuple dims(shape.size()); + for (size_t i = 0; i < dims.size(); ++i) { + dims[i] = py::int_(shape[i]); + } + return dims; +} + +py::array TensorPy::SyncAsNumpy(const Tensor &tensor) { + tensor.data_sync(); + auto info = GetPyBufferInfo(tensor); + py::object self = py::cast(&tensor); + return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); +} + +py::array TensorPy::AsNumpy(const Tensor &tensor) { + auto info = GetPyBufferInfo(tensor); + py::object self = py::cast(&tensor); + return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self); +} + +static std::vector GetShapeFromTuple(const py::tuple &tuple) { + std::vector shape; + const size_t size = tuple.size(); + shape.reserve(tuple.size()); + for (size_t i = 0; i < size; ++i) { + shape.push_back(py::int_(tuple[i])); + } + return shape; +} + +REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { + // Define python MetaTensor class. + (void)py::class_>(*m, "MetaTensor") + .def(py::init>(), py::arg("dtype"), py::arg("shape")) + .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_) + .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.") + .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.") + .def(py::pickle( + [](const MetaTensor &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(static_cast(t.data_type()), t.shape()); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 2) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + MetaTensor tensor(TypeId(t[0].cast()), t[1].cast>()); + return tensor; + })); + // Define python Tensor class. + // dtype should define before Tensor, because Tensor init depend dtype + (void)py::class_>(*m, "Tensor") + .def(py::init([](const Tensor &tensor) { return std::make_shared(tensor); }), + py::arg("input")) + .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) { + TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown; + if (data_type == kTypeUnknown || tensor.data_type() == data_type) { + return std::make_shared(tensor); + } + return std::make_shared(tensor, data_type); + }), + py::arg("input"), py::arg("dtype")) + .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) { + auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64; + return std::make_shared(data_type, GetShapeFromTuple(shape)); + }), + py::arg("dtype"), py::arg("shape")) + .def(py::init([](const py::array &input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(input, type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def(py::init([](py::float_ input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def(py::init([](py::int_ input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def(py::init([](py::list input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def(py::init([](py::tuple input, const TypePtr &type_ptr) { + return TensorPy::MakeTensor(py::array(input), type_ptr); + }), + py::arg("input"), py::arg("dtype") = nullptr) + .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_) + .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag) + .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter( + Get the tensor's data type. + + Returns: + type, the data type of tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 1), np.int32)) + >>> data.dtype + Int32 + )mydelimiter") + .def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter( + Get the tensor's shape. + + Returns: + tuple[int], the shape of tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((3, 3))) + >>> data.shape() + (3, 3) + )mydelimiter") + .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter( + Convert tensor to numpy.ndarray. + + Returns: + numpy.ndarray. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> array = data.asnumpy() + >>> array + array([[1., 1., 1.], + [1., 1., 1.]]) + )mydelimiter") + .def("size", &Tensor::DataSize, R"mydelimiter( + Get tensor's data size. + + Returns: + int, the size of tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> data.size() + 6 + )mydelimiter") + .def("is_init", &Tensor::is_init, R"mydelimiter( + Get tensor init_flag. + + Returns: + bool, whether the tensor init. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> data.is_init() + False + )mydelimiter") + .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter( + Set tensor init_flag. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> data.set_init_flag(True) + )mydelimiter") + .def("dim", &Tensor::DataDim, R"mydelimiter( + Get tensor's data dimension. + + Returns: + int, the dimension of tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((2, 3))) + >>> data.dim() + 2 + )mydelimiter") + .def("assign_value", &Tensor::AssignValue, R"mydelimiter( + Assign another tensor value to this. + + Arg: + value (:class:`mindspore.tensor`): The value tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) + >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32)) + >>> data.assign_value(data2) + >>> data.shape + (2, 2) + )mydelimiter") + .def("set_dtype", &Tensor::SetDtype, R"mydelimiter( + Set the tensor's data type. + + Arg: + dtype (:class:`mindspore.dtype`): The type of output tensor. + + Examples: + >>> data = mindspore.Tensor(np.ones((1, 2), np.float32)) + >>> data.set_dtype(mindspore.int32) + mindspore.int32 + )mydelimiter") + .def("__str__", &Tensor::ToString) + .def("__repr__", &Tensor::ToStringRepr) + .def(py::pickle( + [](const Tensor &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(TensorPy::AsNumpy(t)); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + return TensorPy::MakeTensor(t[0].cast()); + })); + })); +} // namespace tensor +} // namespace mindspore diff --git a/mindspore/core/ir/tensor_py.h b/mindspore/core/ir/tensor_py.h new file mode 100644 index 0000000000..f917584977 --- /dev/null +++ b/mindspore/core/ir/tensor_py.h @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_TENSOR_PY_H_ +#define MINDSPORE_CCSRC_IR_TENSOR_PY_H_ + +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/numpy.h" + +#include "ir/tensor.h" + +namespace py = pybind11; + +namespace pybind11 { +namespace detail { +// Similar to enums in `pybind11/numpy.h`. Determined by doing: +// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)' +constexpr int NPY_FLOAT16 = 23; + +template +struct npy_scalar_caster { + PYBIND11_TYPE_CASTER(T, _("PleaseOverride")); + using Array = array_t; + + bool load(handle src, bool convert) { + // Taken from Eigen casters. Permits either scalar dtype or scalar array. + handle type = dtype::of().attr("type"); + if (!convert && !isinstance(src) && !isinstance(src, type)) return false; + + Array tmp = Array::ensure(src); + if (tmp && tmp.size() == 1 && tmp.ndim() == 0) { + this->value = *tmp.data(); + return true; + } + + return false; + } + + static handle cast(T src, return_value_policy, handle) { + Array tmp({1}); + tmp.mutable_at(0) = src; + tmp.resize({}); + + // You could also just return the array if you want a scalar array. + object scalar = tmp[tuple()]; + return scalar.release(); + } +}; + +template <> +struct npy_format_descriptor { + static constexpr auto name = "float16"; + static pybind11::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16); + return reinterpret_borrow(ptr); + } + virtual ~npy_format_descriptor() {} +}; + +template <> +struct type_caster : public npy_scalar_caster { + static constexpr auto name = "float16"; +}; +} // namespace detail +} // namespace pybind11 + +// brief mindspore namespace. +// +// mindspore namespace is the top level namespace of Mindsporeession project. +// Other namespace should be a sub namespace of mindspore namespace in the ME project. +namespace mindspore { +// brief mindspore::tensor namespace +// +// A sub namespace in ME to support tensor related definition. +namespace tensor { +// Tensor python wrapper and adapter class. +class TensorPy { + public: + // brief Create Tensor from a numpy array object. + // + // param input [py::array] Data value of the tensor. + // param data_type [TypeId] Data type of the tensor. + static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr); + + static py::array SyncAsNumpy(const Tensor &tensor); + + static py::array AsNumpy(const Tensor &tensor); + + static py::tuple GetPyTupleShape(const Tensor &tensor); +}; + +} // namespace tensor +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_TENSOR_PY_H_ diff --git a/mindspore/ccsrc/ir/value.cc b/mindspore/core/ir/value.cc similarity index 100% rename from mindspore/ccsrc/ir/value.cc rename to mindspore/core/ir/value.cc diff --git a/mindspore/core/ir/value.h b/mindspore/core/ir/value.h new file mode 100644 index 0000000000..535de81adf --- /dev/null +++ b/mindspore/core/ir/value.h @@ -0,0 +1,306 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_VALUE_H_ +#define MINDSPORE_CCSRC_IR_VALUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "base/base.h" +#include "ir/anf.h" +#include "ir/dtype.h" +#include "ir/scalar.h" +#include "ir/dtype/ref.h" +#include "utils/hashing.h" +#include "common/utils.h" + +namespace mindspore { +class ValueSequeue : public Value { + public: + explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) { + TypePtrList t_list; + (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) { + MS_EXCEPTION_IF_NULL(ele); + return ele->type(); + }); + TypePtr t = std::make_shared(t_list); + type_ = t; + } + ValueSequeue(const std::initializer_list &elements) : elements_(elements.begin(), elements.end()) { + TypePtrList t_list; + (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), + [](const ValuePtr &ele) { return ele->type(); }); + TypePtr t = std::make_shared(t_list); + type_ = t; + } + ~ValueSequeue() override = default; + MS_DECLARE_PARENT(ValueSequeue, Value) + std::size_t hash() const override { return hash_combine(tid(), std::hash{}(elements_.size())); } + std::size_t size() const { return elements_.size(); } + bool erase(size_t idx); + const ValuePtr operator[](const std::size_t &dim) const; + const ValuePtrList &value() const { return elements_; } + bool operator==(const Value &other) const override; + bool operator==(const ValueSequeue &other) const; + std::string ToString() const override; + std::string DumpText() const override; + + protected: + ValuePtrList elements_; +}; +using ValueSequeuePtr = std::shared_ptr; + +class ValueTuple : public ValueSequeue { + public: + explicit ValueTuple(const std::vector &elements) : ValueSequeue(elements) {} + ValueTuple(const std::initializer_list &elements) : ValueSequeue(elements) {} + ~ValueTuple() override = default; + MS_DECLARE_PARENT(ValueTuple, ValueSequeue) + abstract::AbstractBasePtr ToAbstract() override; + + std::string DumpText() const override { return "(" + ValueSequeue::DumpText() + ")"; } + std::string ToString() const override { return "(" + ValueSequeue::ToString() + ")"; } +}; +using ValueTuplePtr = std::shared_ptr; + +class ValueList : public ValueSequeue { + public: + explicit ValueList(const std::vector &elements) : ValueSequeue(elements) {} + ValueList(const std::initializer_list &elements) : ValueSequeue(elements) {} + ~ValueList() override = default; + MS_DECLARE_PARENT(ValueList, ValueSequeue) + abstract::AbstractBasePtr ToAbstract() override; + + std::string DumpText() const override { return "[" + ValueSequeue::DumpText() + "]"; } + std::string ToString() const override { return "[" + ValueSequeue::ToString() + "]"; } +}; +using ValueListPtr = std::shared_ptr; + +inline ValuePtr MakeValue(const std::vector &v) { return std::make_shared(v); } +inline ValuePtr MakeValue(std::initializer_list v) { return std::make_shared(v); } + +template +struct is_vector : public std::false_type {}; +template +struct is_vector> : public std::true_type {}; + +template ::value, typename T::value_type>::type> +ValuePtr MakeValue(const T &vec) { + std::vector list; + (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); + return std::make_shared(list); +} + +class ValueSlice : public Value { + public: + ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step) + : start_(start), stop_(stop), step_(step) {} + ~ValueSlice() override = default; + MS_DECLARE_PARENT(ValueSlice, Value) + std::size_t hash() const override; + bool operator==(const Value &other) const override; + bool operator==(const ValueSlice &other) const; + + std::string ToString() const override; + + abstract::AbstractBasePtr ToAbstract() override; + std::string DumpText() const override { return ToString(); } + ValuePtr start() const { return start_; } + ValuePtr stop() const { return stop_; } + ValuePtr step() const { return step_; } + + private: + ValuePtr start_; + ValuePtr stop_; + ValuePtr step_; +}; +using ValueSlicePtr = std::shared_ptr; + +class KeywordArg : public Value { + public: + KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {} + ~KeywordArg() override = default; + MS_DECLARE_PARENT(KeywordArg, Value) + std::size_t hash() const override; + ValuePtr get_value() const { return value_; } + bool operator==(const Value &other) const override; + bool operator==(const KeywordArg &other) const; + + std::string ToString() const override; + + abstract::AbstractBasePtr ToAbstract() override; + std::string DumpText() const override { return ToString(); } + + private: + std::string key_; + ValuePtr value_; +}; +using KeywordArgPtr = std::shared_ptr; + +class ValueDictionary : public Value { + public: + explicit ValueDictionary(const std::vector> &key_values) : key_values_(key_values) {} + ~ValueDictionary() override = default; + MS_DECLARE_PARENT(ValueDictionary, Value) + std::size_t hash() const override { return hash_combine(tid(), std::hash{}(key_values_.size())); } + std::size_t size() const { return key_values_.size(); } + const ValuePtr operator[](const std::string &key) const; + const std::vector> &value() const { return key_values_; } + bool operator==(const Value &other) const override; + bool operator==(const ValueDictionary &other) const; + + std::string ToString() const override { + std::ostringstream buffer; + std::vector keys; + std::vector values; + for (const auto &kv : key_values_) { + keys.push_back(kv.first); + values.push_back(kv.second); + } + buffer << "(Dict: " + << " keys:("; + for (const auto &key : keys) { + buffer << key << ", "; + } + buffer << ") values:("; + for (const auto &value : values) { + MS_EXCEPTION_IF_NULL(value); + buffer << value->DumpText() << ", "; + } + buffer << ")"; + return buffer.str(); + } + abstract::AbstractBasePtr ToAbstract() override; + std::string DumpText() const override { return ToString(); } + + private: + std::vector> key_values_; +}; +using ValueDictionaryPtr = std::shared_ptr; + +class StringImm : public Value { + public: + explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash{}(str_)) {} + + ~StringImm() override = default; + MS_DECLARE_PARENT(StringImm, Value) + std::size_t hash() const override { return hash_; } + const std::string &value() const { return str_; } + bool operator==(const Value &other) const override; + bool operator==(const StringImm &other) const; + abstract::AbstractBasePtr ToAbstract() override; + std::string ToString() const override { return str_; } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "\"" << str_ << "\""; + return oss.str(); + } + + private: + std::string str_; + std::size_t hash_ = 0; +}; +using StringImmPtr = std::shared_ptr; +IMM_TRAITS(StringImmPtr, std::string) +IMM_TRAITS(StringImmPtr, const char *) + +class RefKey : public Value { + public: + explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} + + ~RefKey() override = default; + MS_DECLARE_PARENT(RefKey, Value) + std::size_t hash() const override { return hash_; } + const std::string &tag() const { return tag_; } + bool operator==(const Value &other) const override; + bool operator==(const RefKey &other) const; + abstract::AbstractBasePtr ToAbstract() override; + std::string ToString() const override { return "RefKey[" + tag_ + "]"; } + + std::string DumpText() const override { + std::ostringstream oss; + oss << "RefKey[\"" << tag_ << "\"]"; + return oss.str(); + } + + private: + std::string tag_; + std::size_t hash_ = 0; +}; +using RefKeyPtr = std::shared_ptr; + +class AnyValue : public Value { + public: + AnyValue() = default; + ~AnyValue() override = default; + MS_DECLARE_PARENT(AnyValue, Value) + std::size_t hash() const override { return tid(); } + bool operator==(const Value &other) const override; + abstract::AbstractBasePtr ToAbstract() override; +}; +extern const ValuePtr kAnyValue; + +template <> +inline const char *GetValue(const ValuePtr &value) { + if (value == nullptr) { + MS_LOG(EXCEPTION) << "Value is nullptr"; + } + auto imm = value->cast(); + if (imm == nullptr) { + MS_LOG(EXCEPTION) << "GetValue:" << value->ToString() << ", Type:" << value->type_name(); + } + return common::SafeCStr(imm->value()); +} + +template ::type, + typename U = typename std::enable_if::value, typename S::value_type>::type> +std::vector GetValue(const ValuePtr &value) { + if (value == nullptr) { + MS_LOG(EXCEPTION) << "Value is nullptr"; + } + + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Error GetValue for value: " << value->ToString() << ", type: vector<" << typeid(U).name() + << ">"; + } + std::vector rets; + const std::vector &vals = value->cast()->value(); + (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), + [](const ValuePtr &v) { return GetValue(v); }); + return rets; +} + +inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared(t); } + +template ::value>::type> +inline ValueNodePtr NewValueNode(const std::shared_ptr &x) { + return NewValueNode(MakeValue(x)); +} + +template ::value>::type> +inline ValueNodePtr NewValueNode(const T &x) { + return NewValueNode(MakeValue(x)); +} +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_VALUE_H_ diff --git a/mindspore/core/ir/value_extends.cc b/mindspore/core/ir/value_extends.cc new file mode 100644 index 0000000000..c75da80665 --- /dev/null +++ b/mindspore/core/ir/value_extends.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/value.h" +#include +#include +#include +#include + +#include "abstract/abstract_value.h" + +namespace mindspore { +using ContextPtr = abstract::AnalysisContextPtr; + +abstract::AbstractBasePtr Scalar::ToAbstract() { + return std::make_shared(shared_from_base()); +} + +abstract::AbstractBasePtr StringImm::ToAbstract() { + return std::make_shared(shared_from_base(), std::make_shared()); +} + +abstract::AbstractBasePtr RefKey::ToAbstract() { + auto refkey = std::make_shared(); + refkey->set_value(shared_from_base()); + return refkey; +} + +abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared(); } + +abstract::AbstractBasePtr ValueTuple::ToAbstract() { + abstract::AbstractBasePtrList a_list; + (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { + MS_EXCEPTION_IF_NULL(ele); + return ele->ToAbstract(); + }); + return std::make_shared(a_list); +} + +abstract::AbstractBasePtr ValueList::ToAbstract() { + abstract::AbstractBasePtrList a_list; + (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { + MS_EXCEPTION_IF_NULL(ele); + return ele->ToAbstract(); + }); + return std::make_shared(a_list); +} + +abstract::AbstractBasePtr ValueSlice::ToAbstract() { + MS_EXCEPTION_IF_NULL(start_); + MS_EXCEPTION_IF_NULL(stop_); + MS_EXCEPTION_IF_NULL(step_); + abstract::AbstractBasePtr start = start_->ToAbstract(); + abstract::AbstractBasePtr end = stop_->ToAbstract(); + abstract::AbstractBasePtr step = step_->ToAbstract(); + return std::make_shared(start, end, step); +} + +abstract::AbstractBasePtr KeywordArg::ToAbstract() { + MS_EXCEPTION_IF_NULL(value_); + abstract::AbstractBasePtr argument = value_->ToAbstract(); + return std::make_shared(key_, argument); +} + +abstract::AbstractBasePtr ValueDictionary::ToAbstract() { + std::vector> kv; + (void)std::transform( + key_values_.begin(), key_values_.end(), std::back_inserter(kv), + [](const std::pair &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); + return std::make_shared(kv); +} +} // namespace mindspore diff --git a/mindspore/core/ir/value_py.cc b/mindspore/core/ir/value_py.cc new file mode 100644 index 0000000000..1d80c74c4d --- /dev/null +++ b/mindspore/core/ir/value_py.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/value.h" +#include + +#include "pybind_api/api_register.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +// Define python 'RefKey' class. +REGISTER_PYBIND_DEFINE( + RefKey, ([](const py::module *m) { + (void)py::class_>(*m, "RefKey").def(py::init(), py::arg("tag")); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/visitor.cc b/mindspore/core/ir/visitor.cc similarity index 100% rename from mindspore/ccsrc/ir/visitor.cc rename to mindspore/core/ir/visitor.cc diff --git a/mindspore/ccsrc/ir/visitor.h b/mindspore/core/ir/visitor.h similarity index 100% rename from mindspore/ccsrc/ir/visitor.h rename to mindspore/core/ir/visitor.h diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index f0070b428d..b2d26b41ee 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -18,12 +18,13 @@ datasets in special format, including mindrecord, tfrecord, manifest. Users can also create samplers with this module to sample data. """ -from .core.configuration import config +from .core import config from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \ GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\ TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ WeightedRandomSampler, Sampler +from .engine.cache_client import DatasetCache from .engine.serializer_deserializer import serialize, deserialize, show from .engine.graphdata import GraphData diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py new file mode 100644 index 0000000000..c863186d97 --- /dev/null +++ b/mindspore/dataset/core/config.py @@ -0,0 +1,195 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +The configuration manager. +""" +import random +import numpy +import mindspore._c_dataengine as cde + +__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers', + 'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load'] + +INT32_MAX = 2147483647 +UINT32_MAX = 4294967295 + +_config = cde.GlobalContext.config_manager() + + +def set_seed(seed): + """ + Set the seed to be used in any random generator. This is used to produce deterministic results. + + Note: + This set_seed function sets the seed in the python random library and numpy.random library + for deterministic python augmentations using randomness. This set_seed function should + be called with every iterator created to reset the random seed. In our pipeline this + does not guarantee deterministic results with num_parallel_workers > 1. + + Args: + seed(int): seed to be set. + + Raises: + ValueError: If seed is invalid (< 0 or > MAX_UINT_32). + + Examples: + >>> import mindspore.dataset as ds + >>> # sets the new seed value, now operators with a random seed will use new seed value. + >>> ds.config.set_seed(1000) + """ + if seed < 0 or seed > UINT32_MAX: + raise ValueError("Seed given is not within the required range.") + _config.set_seed(seed) + random.seed(seed) + # numpy.random isn't thread safe + numpy.random.seed(seed) + + +def get_seed(): + """ + Get the seed. + + Returns: + Int, seed. + """ + return _config.get_seed() + + +def set_prefetch_size(size): + """ + Set the number of rows to be prefetched. + + Args: + size (int): total number of rows to be prefetched. + + Raises: + ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32). + + Examples: + >>> import mindspore.dataset as ds + >>> # sets the new prefetch value. + >>> ds.config.set_prefetch_size(1000) + """ + if size <= 0 or size > INT32_MAX: + raise ValueError("Prefetch size given is not within the required range.") + _config.set_op_connector_size(size) + + +def get_prefetch_size(): + """ + Get the prefetch size in number of rows. + + Returns: + Size, total number of rows to be prefetched. + """ + return _config.get_op_connector_size() + + +def set_num_parallel_workers(num): + """ + Set the default number of parallel workers. + + Args: + num (int): number of parallel workers to be used as a default for each operation. + + Raises: + ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32). + + Examples: + >>> import mindspore.dataset as ds + >>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers. + >>> ds.config.set_num_parallel_workers(8) + """ + if num <= 0 or num > INT32_MAX: + raise ValueError("Num workers given is not within the required range.") + _config.set_num_parallel_workers(num) + + +def get_num_parallel_workers(): + """ + Get the default number of parallel workers. + + Returns: + Int, number of parallel workers to be used as a default for each operation + """ + return _config.get_num_parallel_workers() + + +def set_monitor_sampling_interval(interval): + """ + Set the default interval(ms) of monitor sampling. + + Args: + interval (int): interval(ms) to be used to performance monitor sampling. + + Raises: + ValueError: If interval is invalid (<= 0 or > MAX_INT_32). + + Examples: + >>> import mindspore.dataset as ds + >>> # sets the new interval value. + >>> ds.config.set_monitor_sampling_interval(100) + """ + if interval <= 0 or interval > INT32_MAX: + raise ValueError("Interval given is not within the required range.") + _config.set_monitor_sampling_interval(interval) + + +def get_monitor_sampling_interval(): + """ + Get the default interval of performance monitor sampling. + + Returns: + Interval: interval(ms) of performance monitor sampling. + """ + return _config.get_monitor_sampling_interval() + + +def __str__(): + """ + String representation of the configurations. + + Returns: + Str, configurations. + """ + return str(_config) + + +def load(file): + """ + Load configuration from a file. + + Args: + file (str): path the config file to be loaded. + + Raises: + RuntimeError: If file is invalid and parsing fails. + + Examples: + >>> import mindspore.dataset as ds + >>> # sets the default value according to values in configuration file. + >>> ds.config.load("path/to/config/file") + >>> # example config file: + >>> # { + >>> # "logFilePath": "/tmp", + >>> # "rowsPerBuffer": 32, + >>> # "numParallelWorkers": 4, + >>> # "workerConnectorSize": 16, + >>> # "opConnectorSize": 16, + >>> # "seed": 5489, + >>> # "monitorSamplingInterval": 30 + >>> # } + """ + _config.load(file) diff --git a/mindspore/dataset/core/configuration.py b/mindspore/dataset/core/configuration.py deleted file mode 100644 index 5376c668c4..0000000000 --- a/mindspore/dataset/core/configuration.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -The configuration manager. -""" -import random -import numpy -import mindspore._c_dataengine as cde - -INT32_MAX = 2147483647 -UINT32_MAX = 4294967295 - - -class ConfigurationManager: - """The configuration manager""" - - def __init__(self): - self.config = cde.GlobalContext.config_manager() - - def set_seed(self, seed): - """ - Set the seed to be used in any random generator. This is used to produce deterministic results. - - Note: - This set_seed function sets the seed in the python random library and numpy.random library - for deterministic python augmentations using randomness. This set_seed function should - be called with every iterator created to reset the random seed. In our pipeline this - does not guarantee deterministic results with num_parallel_workers > 1. - - Args: - seed(int): seed to be set - - Raises: - ValueError: If seed is invalid (< 0 or > MAX_UINT_32). - - Examples: - >>> import mindspore.dataset as ds - >>> con = ds.engine.ConfigurationManager() - >>> # sets the new seed value, now operators with a random seed will use new seed value. - >>> con.set_seed(1000) - """ - if seed < 0 or seed > UINT32_MAX: - raise ValueError("Seed given is not within the required range") - self.config.set_seed(seed) - random.seed(seed) - # numpy.random isn't thread safe - numpy.random.seed(seed) - - def get_seed(self): - """ - Get the seed - - Returns: - Int, seed. - """ - return self.config.get_seed() - - def set_prefetch_size(self, size): - """ - Set the number of rows to be prefetched. - - Args: - size: total number of rows to be prefetched. - - Raises: - ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32). - - Examples: - >>> import mindspore.dataset as ds - >>> con = ds.engine.ConfigurationManager() - >>> # sets the new prefetch value. - >>> con.set_prefetch_size(1000) - """ - if size <= 0 or size > INT32_MAX: - raise ValueError("Prefetch size given is not within the required range") - self.config.set_op_connector_size(size) - - def get_prefetch_size(self): - """ - Get the prefetch size in number of rows. - - Returns: - Size, total number of rows to be prefetched. - """ - return self.config.get_op_connector_size() - - def set_num_parallel_workers(self, num): - """ - Set the default number of parallel workers - - Args: - num: number of parallel workers to be used as a default for each operation - - Raises: - ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32). - - Examples: - >>> import mindspore.dataset as ds - >>> con = ds.engine.ConfigurationManager() - >>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers. - >>> con.set_num_parallel_workers(8) - """ - if num <= 0 or num > INT32_MAX: - raise ValueError("Num workers given is not within the required range") - self.config.set_num_parallel_workers(num) - - def get_num_parallel_workers(self): - """ - Get the default number of parallel workers. - - Returns: - Int, number of parallel workers to be used as a default for each operation - """ - return self.config.get_num_parallel_workers() - - def set_monitor_sampling_interval(self, interval): - """ - Set the default interval(ms) of monitor sampling. - - Args: - interval: interval(ms) to be used to performance monitor sampling. - - Raises: - ValueError: If interval is invalid (<= 0 or > MAX_INT_32). - - Examples: - >>> import mindspore.dataset as ds - >>> con = ds.engine.ConfigurationManager() - >>> # sets the new interval value. - >>> con.set_monitor_sampling_interval(100) - """ - if interval <= 0 or interval > INT32_MAX: - raise ValueError("Interval given is not within the required range") - self.config.set_monitor_sampling_interval(interval) - - def get_monitor_sampling_interval(self): - """ - Get the default interval of performance monitor sampling. - - Returns: - Interval: interval(ms) of performance monitor sampling. - """ - return self.config.get_monitor_sampling_interval() - - def __str__(self): - """ - String representation of the configurations. - - Returns: - Str, configurations. - """ - return str(self.config) - - def load(self, file): - """ - Load configuration from a file. - - Args: - file: path the config file to be loaded - - Raises: - RuntimeError: If file is invalid and parsing fails. - - Examples: - >>> import mindspore.dataset as ds - >>> con = ds.engine.ConfigurationManager() - >>> # sets the default value according to values in configuration file. - >>> con.load("path/to/config/file") - >>> # example config file: - >>> # { - >>> # "logFilePath": "/tmp", - >>> # "rowsPerBuffer": 32, - >>> # "numParallelWorkers": 4, - >>> # "workerConnectorSize": 16, - >>> # "opConnectorSize": 16, - >>> # "seed": 5489, - >>> # "monitorSamplingInterval": 30 - >>> # } - """ - self.config.load(file) - - -config = ConfigurationManager() diff --git a/mindspore/dataset/core/validator_helpers.py b/mindspore/dataset/core/validator_helpers.py new file mode 100644 index 0000000000..8806babd63 --- /dev/null +++ b/mindspore/dataset/core/validator_helpers.py @@ -0,0 +1,360 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +General Validators. +""" +import inspect +from multiprocessing import cpu_count +import os +import numpy as np +from ..engine import samplers + +# POS_INT_MIN is used to limit values from starting from 0 +POS_INT_MIN = 1 +UINT8_MAX = 255 +UINT8_MIN = 0 +UINT32_MAX = 4294967295 +UINT32_MIN = 0 +UINT64_MAX = 18446744073709551615 +UINT64_MIN = 0 +INT32_MAX = 2147483647 +INT32_MIN = -2147483648 +INT64_MAX = 9223372036854775807 +INT64_MIN = -9223372036854775808 +FLOAT_MAX_INTEGER = 16777216 +FLOAT_MIN_INTEGER = -16777216 +DOUBLE_MAX_INTEGER = 9007199254740992 +DOUBLE_MIN_INTEGER = -9007199254740992 + +valid_detype = [ + "bool", "int8", "int16", "int32", "int64", "uint8", "uint16", + "uint32", "uint64", "float16", "float32", "float64", "string" +] + + +def pad_arg_name(arg_name): + if arg_name != "": + arg_name = arg_name + " " + return arg_name + + +def check_value(value, valid_range, arg_name=""): + arg_name = pad_arg_name(arg_name) + if value < valid_range[0] or value > valid_range[1]: + raise ValueError( + "Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0], + valid_range[1])) + + +def check_range(values, valid_range, arg_name=""): + arg_name = pad_arg_name(arg_name) + if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]: + raise ValueError( + "Input {0}is not within the required interval of ({1} to {2}).".format(arg_name, valid_range[0], + valid_range[1])) + + +def check_positive(value, arg_name=""): + arg_name = pad_arg_name(arg_name) + if value <= 0: + raise ValueError("Input {0}must be greater than 0.".format(arg_name)) + + +def check_positive_float(value, arg_name=""): + arg_name = pad_arg_name(arg_name) + type_check(value, (float,), arg_name) + check_positive(value, arg_name) + + +def check_2tuple(value, arg_name=""): + if not (isinstance(value, tuple) and len(value) == 2): + raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name)) + + +def check_uint8(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [UINT8_MIN, UINT8_MAX]) + + +def check_uint32(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [UINT32_MIN, UINT32_MAX]) + + +def check_pos_int32(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [POS_INT_MIN, INT32_MAX], arg_name) + + +def check_uint64(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [UINT64_MIN, UINT64_MAX]) + + +def check_pos_int64(value, arg_name=""): + type_check(value, (int,), arg_name) + check_value(value, [UINT64_MIN, INT64_MAX]) + + +def check_pos_float32(value, arg_name=""): + check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER], arg_name) + + +def check_pos_float64(value, arg_name=""): + check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER], arg_name) + + +def check_valid_detype(type_): + if type_ not in valid_detype: + raise ValueError("Unknown column type") + return True + + +def check_columns(columns, name): + """ + Validate strings in column_names. + + Args: + columns (list): list of column_names. + name (str): name of columns. + + Returns: + Exception: when the value is not correct, otherwise nothing. + """ + type_check(columns, (list, str), name) + if isinstance(columns, list): + if not columns: + raise ValueError("{0} should not be empty".format(name)) + for i, column_name in enumerate(columns): + if not column_name: + raise ValueError("{0}[{1}] should not be empty".format(name, i)) + + col_names = ["{0}[{1}]".format(name, i) for i in range(len(columns))] + type_check_list(columns, (str,), col_names) + if len(set(columns)) != len(columns): + raise ValueError("Every column name should not be same with others in column_names.") + + +def parse_user_args(method, *args, **kwargs): + """ + Parse user arguments in a function. + + Args: + method (method): a callable function. + *args: user passed args. + **kwargs: user passed kwargs. + + Returns: + user_filled_args (list): values of what the user passed in for the arguments. + ba.arguments (Ordered Dict): ordered dict of parameter and argument for what the user has passed. + """ + sig = inspect.signature(method) + if 'self' in sig.parameters or 'cls' in sig.parameters: + ba = sig.bind(method, *args, **kwargs) + ba.apply_defaults() + params = list(sig.parameters.keys())[1:] + else: + ba = sig.bind(*args, **kwargs) + ba.apply_defaults() + params = list(sig.parameters.keys()) + + user_filled_args = [ba.arguments.get(arg_value) for arg_value in params] + return user_filled_args, ba.arguments + + +def type_check_list(args, types, arg_names): + """ + Check the type of each parameter in the list. + + Args: + args (list, tuple): a list or tuple of any variable. + types (tuple): tuple of all valid types for arg. + arg_names (list, tuple of str): the names of args. + + Returns: + Exception: when the type is not correct, otherwise nothing. + """ + type_check(args, (list, tuple,), arg_names) + if len(args) != len(arg_names): + raise ValueError("List of arguments is not the same length as argument_names.") + for arg, arg_name in zip(args, arg_names): + type_check(arg, types, arg_name) + + +def type_check(arg, types, arg_name): + """ + Check the type of the parameter. + + Args: + arg : any variable. + types (tuple): tuple of all valid types for arg. + arg_name (str): the name of arg. + + Returns: + Exception: when the type is not correct, otherwise nothing. + """ + # handle special case of booleans being a subclass of ints + print_value = '\"\"' if repr(arg) == repr('') else arg + + if int in types and bool not in types: + if isinstance(arg, bool): + raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types)) + if not isinstance(arg, types): + raise TypeError("Argument {0} with value {1} is not of type {2}.".format(arg_name, print_value, types)) + + +def check_filename(path): + """ + check the filename in the path. + + Args: + path (str): the path. + + Returns: + Exception: when error. + """ + if not isinstance(path, str): + raise TypeError("path: {} is not string".format(path)) + filename = os.path.basename(path) + + # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`', + # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>', + # '*', '(', '%', ')', '-', '=', '{', '?', '$' + forbidden_symbols = set(r'\/:*?"<>|`&\';') + + if set(filename) & forbidden_symbols: + raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'") + + if filename.startswith(' ') or filename.endswith(' '): + raise ValueError("filename should not start/end with space") + + return True + + +def check_dir(dataset_dir): + if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK): + raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir)) + + +def check_file(dataset_file): + check_filename(dataset_file) + if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK): + raise ValueError("The file {} does not exist or permission denied!".format(dataset_file)) + + +def check_sampler_shuffle_shard_options(param_dict): + """ + Check for valid shuffle, sampler, num_shards, and shard_id inputs. + Args: + param_dict (dict): param_dict. + + Returns: + Exception: ValueError or RuntimeError if error. + """ + shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') + num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') + + type_check(sampler, (type(None), samplers.BuiltinSampler, samplers.Sampler), "sampler") + + if sampler is not None: + if shuffle is not None: + raise RuntimeError("sampler and shuffle cannot be specified at the same time.") + if num_shards is not None: + raise RuntimeError("sampler and sharding cannot be specified at the same time.") + + if num_shards is not None: + check_pos_int32(num_shards) + if shard_id is None: + raise RuntimeError("num_shards is specified and currently requires shard_id as well.") + check_value(shard_id, [0, num_shards - 1], "shard_id") + + if num_shards is None and shard_id is not None: + raise RuntimeError("shard_id is specified but num_shards is not.") + + +def check_padding_options(param_dict): + """ + Check for valid padded_sample and num_padded of padded samples. + + Args: + param_dict (dict): param_dict. + + Returns: + Exception: ValueError or RuntimeError if error. + """ + + columns_list = param_dict.get('columns_list') + block_reader = param_dict.get('block_reader') + padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded') + if padded_sample is not None: + if num_padded is None: + raise RuntimeError("padded_sample is specified and requires num_padded as well.") + if num_padded < 0: + raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded)) + if columns_list is None: + raise RuntimeError("padded_sample is specified and requires columns_list as well.") + for column in columns_list: + if column not in padded_sample: + raise ValueError("padded_sample cannot match columns_list.") + if block_reader: + raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.") + + if padded_sample is None and num_padded is not None: + raise RuntimeError("num_padded is specified but padded_sample is not.") + + +def check_num_parallel_workers(value): + type_check(value, (int,), "num_parallel_workers") + if value < 1 or value > cpu_count(): + raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count())) + + +def check_num_samples(value): + type_check(value, (int,), "num_samples") + check_value(value, [0, INT32_MAX], "num_samples") + + +def validate_dataset_param_value(param_list, param_dict, param_type): + for param_name in param_list: + if param_dict.get(param_name) is not None: + if param_name == 'num_parallel_workers': + check_num_parallel_workers(param_dict.get(param_name)) + if param_name == 'num_samples': + check_num_samples(param_dict.get(param_name)) + else: + type_check(param_dict.get(param_name), (param_type,), param_name) + + +def check_gnn_list_or_ndarray(param, param_name): + """ + Check if the input parameter is list or numpy.ndarray. + + Args: + param (list, nd.ndarray): param. + param_name (str): param_name. + + Returns: + Exception: TypeError if error. + """ + + type_check(param, (list, np.ndarray), param_name) + if isinstance(param, list): + param_names = ["param_{0}".format(i) for i in range(len(param))] + type_check_list(param, (int,), param_names) + + elif isinstance(param, np.ndarray): + if not param.dtype == np.int32: + raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( + param_name, param.dtype)) diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index 674848f156..b3624e1ca3 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -26,10 +26,9 @@ from .datasets import * from .iterators import * from .serializer_deserializer import serialize, deserialize, show, compare from .samplers import * -from ..core.configuration import config, ConfigurationManager +from ..core import config -__all__ = ["config", "ConfigurationManager", "zip", - "ImageFolderDatasetV2", "MnistDataset", +__all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py new file mode 100644 index 0000000000..800c0dab1d --- /dev/null +++ b/mindspore/dataset/engine/cache_client.py @@ -0,0 +1,49 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cache client +""" + +import copy +from mindspore._c_dataengine import CacheClient + +class DatasetCache: + """ + A client to interface with tensor caching service + """ + + def __init__(self, session_id=None, size=None, spilling=False): + if session_id is None: + raise RuntimeError("Session generation is not implemented yet. session id required") + self.size = size if size is not None else 0 + if size < 0: + raise ValueError("cache size should be 0 or positive integer value but got: size={}".format(size)) + if not isinstance(spilling, bool): + raise ValueError( + "spilling argument for cache should be a boolean value but got: spilling={}".format(spilling)) + self.session_id = session_id + self.spilling = spilling + self.cache_client = CacheClient(session_id, size, spilling) + + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + cls = self.__class__ + new_cache = cls.__new__(cls) + memodict[id(self)] = new_cache + new_cache.session_id = copy.deepcopy(self.session_id, memodict) + new_cache.spilling = copy.deepcopy(self.spilling, memodict) + new_cache.size = copy.deepcopy(self.size, memodict) + new_cache.cache_client = self.cache_client + return new_cache diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ae0dc6789e..846e7e0a56 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -38,13 +38,13 @@ from mindspore._c_expression import typing from mindspore import log as logger from . import samplers -from .iterators import DictIterator, TupleIterator +from .iterators import DictIterator, TupleIterator, DummyIterator from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \ check_rename, check_numpyslicesdataset, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \ check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \ - check_split, check_bucket_batch_by_length, check_cluedataset + check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -146,6 +146,12 @@ class Dataset: self._num_classes = None self._repeat_count = None self._sync = False + self.ms_role = os.getenv("MS_ROLE") + + def _noop_mode(self): + if self.ms_role in ("MS_PSERVER", "MS_SCHED"): + return True + return False def __add__(self, datasets): return self.concat(datasets) @@ -386,7 +392,7 @@ class Dataset: @check_map def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None, python_multiprocessing=False): + num_parallel_workers=None, python_multiprocessing=False, cache=None): """ Apply each operation in operations to this dataset. @@ -427,6 +433,7 @@ class Dataset: parallel (default=None, the value from the config will be used). python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This option could be beneficial if the python operation is computational heavy (default=False). + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) Returns: MapDataset, dataset after mapping operation. @@ -541,7 +548,7 @@ class Dataset: >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order) """ return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers, - python_multiprocessing) + python_multiprocessing, cache) @check_filter def filter(self, predicate, input_columns=None, num_parallel_workers=1): @@ -939,6 +946,7 @@ class Dataset: raise TypeError("apply_func must return a dataset.") return dataset + @check_positive_int32 def device_que(self, prefetch_size=None): """ Return a transferredDataset that transfer data through device. @@ -956,6 +964,7 @@ class Dataset: """ return self.to_device() + @check_positive_int32 def to_device(self, num_batch=None): """ Transfer data through CPU, GPU or Ascend devices. @@ -973,10 +982,14 @@ class Dataset: Raises: TypeError: If device_type is empty. ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'. - ValueError: If num_batch is None or 0 or larger than int_max. + ValueError: If num_batch is not positive or larger than int_max. + ValueError: If dataset size is None or 0. RuntimeError: If dataset is unknown. RuntimeError: If distribution file path is given but failed to read. """ + if self.get_dataset_size() is None or 0: + raise ValueError("dataset size is None or 0.") + if num_batch is None: num_batch = self.get_dataset_size() repeat_count = self.get_repeat_count() @@ -995,8 +1008,8 @@ class Dataset: if device_type not in ('Ascend', 'GPU', 'CPU'): raise ValueError("Only support CPU, Ascend, GPU") - if num_batch is None or num_batch == 0: - raise ValueError("num_batch is None or 0.") + if num_batch == 0: + raise ValueError("num_batch is 0.") def get_distribution(output_dataset): dev_id = 0 @@ -1055,6 +1068,8 @@ class Dataset: >>> # convert the returned tuple to a list and print >>> print(list(item)) """ + if self._noop_mode(): + return DummyIterator(self, 'tuple') return TupleIterator(self, columns) def create_dict_iterator(self): @@ -1078,6 +1093,8 @@ class Dataset: >>> print(item["column1"]) """ + if self._noop_mode(): + return DummyIterator(self, 'dict') return DictIterator(self) def __iter__(self): @@ -1556,7 +1573,7 @@ class BatchDataset(DatasetOp): Number, number of batches. """ child_size = self.children[0].get_dataset_size() - if child_size is not None: + if child_size is not None and isinstance(self.batch_size, int): if self.drop_remainder: return math.floor(child_size / self.batch_size) return math.ceil(child_size / self.batch_size) @@ -1862,13 +1879,14 @@ class MapDataset(DatasetOp): in parallel (default=None). python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This option could be beneficial if the python operation is computational heavy (default=False). + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) Raises: ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified. """ def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None, python_multiprocessing=False): + num_parallel_workers=None, python_multiprocessing=False, cache=None): super().__init__(num_parallel_workers) self.children.append(input_dataset) if input_columns is not None and not isinstance(input_columns, list): @@ -1880,6 +1898,7 @@ class MapDataset(DatasetOp): if output_columns is not None and not isinstance(output_columns, list): output_columns = [output_columns] self.output_columns = output_columns + self.cache = cache self.columns_order = columns_order if self.input_columns and self.output_columns \ @@ -1898,6 +1917,7 @@ class MapDataset(DatasetOp): args["operations"] = self.operations args["output_columns"] = self.output_columns args["columns_order"] = self.columns_order + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -1923,6 +1943,7 @@ class MapDataset(DatasetOp): new_op.parent = copy.deepcopy(self.parent, memodict) new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict) new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict) + new_op.cache = copy.deepcopy(self.cache, memodict) new_op.operations = self.operations return new_op @@ -2307,6 +2328,8 @@ class TransferDataset(DatasetOp): def send(self): # need to keep iterator alive so the executionTree is not destroyed + if self._noop_mode(): + return self.iterator = TupleIterator(self) @@ -2340,7 +2363,7 @@ class RangeDataset(MappableDataset): return False -def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): +def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False): """ Create sampler based on user input. @@ -2350,7 +2373,11 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): shuffle (bool): Shuffle. num_shards (int): Number of shard for sharding. shard_id (int): Shard ID. + non_mappable (bool, optional): Indicate if caller is non-mappable dataset for special handling (default=False). """ + if non_mappable is True and all(arg is None for arg in [num_samples, shuffle, num_shards, shard_id, input_sampler]): + return None + if input_sampler is not None: # If the user provided a sampler, then it doesn't matter what the other args are because # we are being asked specifically to use the given sampler. @@ -2363,7 +2390,7 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, samplers.WeightedRandomSampler, samplers.Sampler)) and - (num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)): + (any(arg is not None for arg in [num_shards, shard_id, shuffle, num_samples]))): raise ValueError( 'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},' ' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle)) @@ -2452,6 +2479,7 @@ class ImageFolderDatasetV2(MappableDataset): into (default=None). shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only when num_shards is also specified. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) Raises: RuntimeError: If sampler and shuffle are specified at the same time. @@ -2476,7 +2504,7 @@ class ImageFolderDatasetV2(MappableDataset): @check_imagefolderdatasetv2 def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, extensions=None, class_indexing=None, - decode=False, num_shards=None, shard_id=None): + decode=False, num_shards=None, shard_id=None, cache=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir @@ -2488,6 +2516,7 @@ class ImageFolderDatasetV2(MappableDataset): self.decode = decode self.num_shards = num_shards self.shard_id = shard_id + self.cache = cache def get_args(self): args = super().get_args() @@ -2500,6 +2529,7 @@ class ImageFolderDatasetV2(MappableDataset): args["decode"] = self.decode args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id + args["cache"] = self.cache.cache_client if self.cache is not None else None return args def get_dataset_size(self): @@ -3245,6 +3275,7 @@ class TFRecordDataset(SourceDataset): argument should be specified only when num_shards is also specified. shard_equal_rows (bool): Get equal rows for all shards(default=False). If shard_equal_rows is false, number of rows of each shard may be not equal. + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) Examples: >>> import mindspore.dataset as ds >>> import mindspore.common.dtype as mstype @@ -3262,7 +3293,7 @@ class TFRecordDataset(SourceDataset): @check_tfrecorddataset def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, - shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False): + shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None): super().__init__(num_parallel_workers) self.dataset_files = self._find_files(dataset_files) self.dataset_files.sort() @@ -3274,6 +3305,7 @@ class TFRecordDataset(SourceDataset): self.schema = schema self.columns_list = columns_list self.num_samples = num_samples + self.cache = cache if schema_obj is not None and num_samples is None: self.num_samples = schema_obj.num_rows @@ -3289,6 +3321,14 @@ class TFRecordDataset(SourceDataset): else: self.shuffle_level = shuffle self.shuffle_files = True + + # The TF record dataset does not directly support a sampler. It has provided sampling arguments + # (shuffle, num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in + # the pipeline contains a cache. If there is no cache above it, then this sampler is not used. + sampler_shuffle = self.shuffle_files + sampler = None + self.sampler = _select_sampler(self.num_samples, sampler, sampler_shuffle, num_shards, shard_id, + non_mappable=True) self.shard_equal_rows = shard_equal_rows def get_args(self): @@ -3312,6 +3352,8 @@ class TFRecordDataset(SourceDataset): args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id args["shard_equal_rows"] = self.shard_equal_rows + args["cache"] = self.cache.cache_client if self.cache is not None else None + args["sampler"] = self.sampler return args def get_dataset_size(self, estimate=False): @@ -3797,43 +3839,61 @@ class RandomDataset(SourceDataset): A source dataset that generates random data. Args: - num_samples (int): number of samples to generate. + total_rows (int): number of rows for the dataset to generate (default=None, number of rows is random) schema (str or Schema, optional): Path to the json schema file or schema object (default=None). If the schema is not provided, the random dataset generates a random schema. columns_list (list[str], optional): List of columns to be read (default=None, read all columns) + num_samples (int): number of samples to draw from the total. (default=None, which means all rows) num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). + cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used) + shuffle (bool, optional): Whether or not to perform shuffle on the dataset + (default=None, expected order behavior shown in the table). + num_shards (int, optional): Number of shards that the dataset should be divided + into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. """ - def __init__(self, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None): + @check_random_dataset + def __init__(self, total_rows=None, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, + cache=None, shuffle=None, num_shards=None, shard_id=None): super().__init__(num_parallel_workers) schema_obj = None if (schema is not None) and (not isinstance(schema, Schema)): schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it self.schema = schema self.columns_list = columns_list - if schema_obj is not None and num_samples is None: - self.num_samples = schema_obj.num_rows - elif num_samples is None: - self.num_samples = 0 + sampler = None + self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id, non_mappable=True) + self.num_samples = num_samples + self.cache = cache + if schema_obj is not None and total_rows is None: + self.total_rows = schema_obj.num_rows + elif total_rows is None: + self.total_rows = 0 else: - self.num_samples = num_samples + self.total_rows = total_rows + self.num_shards = num_shards + self.shard_id = shard_id + self.shuffle_level = shuffle def get_args(self): args = super().get_args() if self.schema is not None: if isinstance(self.schema, Schema): self.schema.datasetType = 'Random' - if self.num_samples is not None: - self.schema.num_rows = self.num_samples + if self.total_rows is not None: + self.schema.num_rows = self.total_rows args["schema_json_string"] = self.schema.to_json() else: args["schema_file_path"] = self.schema args["schema"] = self.schema - if self.columns_list is not None: - args["columns_list"] = self.columns_list - if self.num_samples is not None: - args["num_samples"] = self.num_samples + args["columns_list"] = self.columns_list + args["num_samples"] = self.num_samples + args["total_rows"] = self.total_rows + args["cache"] = self.cache.cache_client if self.cache is not None else None + args["sampler"] = self.sampler return args def get_dataset_size(self): @@ -3843,18 +3903,28 @@ class RandomDataset(SourceDataset): Return: Number, number of batches. """ + + num_rows = CifarOp.get_num_rows(self.dataset_dir, True) + + rows_per_shard = get_num_rows(num_rows, self.num_shards) rows_from_sampler = self._get_sampler_dataset_size() if rows_from_sampler is None: - return self.num_samples + return rows_per_shard - return min(rows_from_sampler, self.num_samples) + return min(rows_from_sampler, rows_per_shard) def is_shuffled(self): - return True + if self.shuffle_level is None: + return True + + return self.shuffle_level or self.sampler.is_shuffled() def is_sharded(self): - return False + if self.num_shards is not None: + return self.num_shards > 1 + + return self.sampler.is_sharded() class Schema: diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 472819784e..81314b4373 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -22,7 +22,8 @@ from mindspore._c_dataengine import Tensor from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ - check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk + check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \ + check_gnn_random_walk class GraphData: @@ -127,7 +128,13 @@ class GraphData: @check_gnn_get_sampled_neighbors def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): """ - Get sampled neighbor information, maximum support 6-hop sampling. + Get sampled neighbor information. + + The api supports multi-hop neighbor sampling. That is, the previous sampling result is used as the input of + next-hop sampling. A maximum of 6-hop are allowed. + + The sampling result is tiled into a list in the format of [input node, 1-hop sampling result, + 2-hop samling result ...] Args: node_list (list or numpy.ndarray): The given list of nodes. @@ -207,6 +214,35 @@ class GraphData: Tensor(node_list), feature_types)] + @check_gnn_get_edge_feature + def get_edge_feature(self, edge_list, feature_types): + """ + Get `feature_types` feature of the edges in `edge_list`. + + Args: + edge_list (list or numpy.ndarray): The given list of edges. + feature_types (list or ndarray): The given list of feature types. + + Returns: + numpy.ndarray: array of features. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> edges = data_graph.get_all_edges(0) + >>> features = data_graph.get_edge_feature(edges, [1]) + + Raises: + TypeError: If `edge_list` is not list or ndarray. + TypeError: If `feature_types` is not list or ndarray. + """ + if isinstance(edge_list, list): + edge_list = np.array(edge_list, dtype=np.int32) + return [ + t.as_array() for t in self._graph.get_edge_feature( + Tensor(edge_list), + feature_types)] + def graph_info(self): """ Get the meta information of the graph, including the number of nodes, the type of nodes, @@ -232,9 +268,10 @@ class GraphData: Args: target_nodes (list[int]): Start node list in random walk meta_path (list[int]): node type for each walk step - step_home_param (float): return hyper parameter in node2vec algorithm - step_away_param (float): inout hyper parameter in node2vec algorithm - default_node (int): default node if no more neighbors found + step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0). + step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0). + default_node (int, optional): default node if no more neighbors found (Default = -1). + A default value of -1 indicates that no node is given. Returns: numpy.ndarray: array of nodes. diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 1d2d28c1c0..a2a23cbb44 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -17,7 +17,9 @@ from abc import abstractmethod import copy import weakref +import numpy as np +from mindspore.common.tensor import Tensor from mindspore._c_dataengine import DEPipeline from mindspore._c_dataengine import OpName @@ -287,3 +289,32 @@ class TupleIterator(Iterator): """ return [t.as_array() for t in self.depipeline.GetNextAsList()] + + +class DummyIterator(): + """ + A DummyIterator only work when env MS_ROLE="MS_PSERVER" or MS_ROLE="MS_SCHED" + """ + def __init__(self, dataset, mode): + self.mode = mode + self.shapes = dataset.output_shapes() + self.types = dataset.output_types() + self.fetched_first = False + + def __get_tensor(self): + tensor_row = [] + for np_shape, np_type in zip(self.shapes, self.types): + input_np = np.zeros(np_shape, np_type) + tensor = Tensor(input_np) + tensor_row.append(tensor) + return tensor_row + + def __iter__(self): + return self + + def __next__(self): + if self.mode == "tuple": + if not self.fetched_first: + self.fetched_first = True + return self.__get_tensor() + raise StopIteration() diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index 9d3339e26d..8fd3a2bb9b 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -22,7 +22,7 @@ import sys from mindspore import log as logger from . import datasets as de from ..transforms.vision.utils import Inter, Border -from ..core.configuration import config +from ..core import config def serialize(dataset, json_filepath=None): """ @@ -173,7 +173,9 @@ def traverse(node): # num_samples, shard_id, num_shards, shuffle # These arguments get moved into the sampler itself, so they are no longer needed to # be set at the dataset level. - if 'sampler' in node_args.keys(): + # TF Record is a special case because it uses both the dataset and sampler arguments + # which is not decided until later during tree preparation phase. + if node_repr['op_type'] != 'TFRecordDataset' and 'sampler' in node_args.keys(): if 'num_samples' in node_repr.keys(): node_repr['num_samples'] = None if 'shuffle' in node_repr.keys(): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 744a9b94be..29904f1a9e 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -9,335 +9,151 @@ # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and +# See the License foNtest_resr the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Built-in validators. +""" +Built-in validators. """ import inspect as ins import os from functools import wraps -from multiprocessing import cpu_count import numpy as np from mindspore._c_expression import typing +from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ + INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ + validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ + check_columns, check_pos_int32 from . import datasets from . import samplers +from . import cache_client -INT32_MAX = 2147483647 -valid_detype = [ - "bool", "int8", "int16", "int32", "int64", "uint8", "uint16", - "uint32", "uint64", "float16", "float32", "float64", "string" -] - - -def check_valid_detype(type_): - if type_ not in valid_detype: - raise ValueError("Unknown column type") - return True - - -def check_filename(path): - """ - check the filename in the path - - Args: - path (str): the path - - Returns: - Exception: when error - """ - if not isinstance(path, str): - raise TypeError("path: {} is not string".format(path)) - filename = os.path.basename(path) - - # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`', - # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>', - # '*', '(', '%', ')', '-', '=', '{', '?', '$' - forbidden_symbols = set(r'\/:*?"<>|`&\';') - - if set(filename) & forbidden_symbols: - raise ValueError(r"filename should not contains \/:*?\"<>|`&;\'") - - if filename.startswith(' ') or filename.endswith(' '): - raise ValueError("filename should not start/end with space") - - return True - - -def make_param_dict(method, args, kwargs): - """Return a dictionary of the method's args and kwargs.""" - sig = ins.signature(method) - params = sig.parameters - keys = list(params.keys()) - param_dict = dict() - try: - for name, value in enumerate(args): - param_dict[keys[name]] = value - except IndexError: - raise TypeError("{0}() expected {1} arguments, but {2} were given".format( - method.__name__, len(keys) - 1, len(args) - 1)) - - param_dict.update(zip(params.keys(), args)) - param_dict.update(kwargs) - - for name, value in params.items(): - if name not in param_dict: - param_dict[name] = value.default - return param_dict - - -def check_type(param, param_name, valid_type): - if (not isinstance(param, valid_type)) or (valid_type == int and isinstance(param, bool)): - raise TypeError("Wrong input type for {0}, should be {1}, got {2}".format(param_name, valid_type, type(param))) - - -def check_param_type(param_list, param_dict, param_type): - for param_name in param_list: - if param_dict.get(param_name) is not None: - if param_name == 'num_parallel_workers': - check_num_parallel_workers(param_dict.get(param_name)) - if param_name == 'num_samples': - check_num_samples(param_dict.get(param_name)) - else: - check_type(param_dict.get(param_name), param_name, param_type) - - -def check_positive_int32(param, param_name): - check_interval_closed(param, param_name, [1, INT32_MAX]) - - -def check_interval_closed(param, param_name, valid_range): - if param < valid_range[0] or param > valid_range[1]: - raise ValueError("The value of {0} exceeds the closed interval range {1}.".format(param_name, valid_range)) - - -def check_num_parallel_workers(value): - check_type(value, 'num_parallel_workers', int) - if value < 1 or value > cpu_count(): - raise ValueError("num_parallel_workers exceeds the boundary between 1 and {}!".format(cpu_count())) - - -def check_num_samples(value): - check_type(value, 'num_samples', int) - if value < 0: - raise ValueError("num_samples cannot be less than 0!") - - -def check_dataset_dir(dataset_dir): - if not os.path.isdir(dataset_dir) or not os.access(dataset_dir, os.R_OK): - raise ValueError("The folder {} does not exist or permission denied!".format(dataset_dir)) - - -def check_dataset_file(dataset_file): - check_filename(dataset_file) - if not os.path.isfile(dataset_file) or not os.access(dataset_file, os.R_OK): - raise ValueError("The file {} does not exist or permission denied!".format(dataset_file)) - - -def check_sampler_shuffle_shard_options(param_dict): - """check for valid shuffle, sampler, num_shards, and shard_id inputs.""" - shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') - num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') - - if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)): - raise TypeError("sampler is not a valid Sampler type.") - - if sampler is not None: - if shuffle is not None: - raise RuntimeError("sampler and shuffle cannot be specified at the same time.") - - if num_shards is not None: - raise RuntimeError("sampler and sharding cannot be specified at the same time.") - - if num_shards is not None: - check_positive_int32(num_shards, "num_shards") - if shard_id is None: - raise RuntimeError("num_shards is specified and currently requires shard_id as well.") - if shard_id < 0 or shard_id >= num_shards: - raise ValueError("shard_id is invalid, shard_id={}".format(shard_id)) - - if num_shards is None and shard_id is not None: - raise RuntimeError("shard_id is specified but num_shards is not.") - - -def check_padding_options(param_dict): - """ check for valid padded_sample and num_padded of padded samples""" - columns_list = param_dict.get('columns_list') - block_reader = param_dict.get('block_reader') - padded_sample, num_padded = param_dict.get('padded_sample'), param_dict.get('num_padded') - if padded_sample is not None: - if num_padded is None: - raise RuntimeError("padded_sample is specified and requires num_padded as well.") - if num_padded < 0: - raise ValueError("num_padded is invalid, num_padded={}.".format(num_padded)) - if columns_list is None: - raise RuntimeError("padded_sample is specified and requires columns_list as well.") - for column in columns_list: - if column not in padded_sample: - raise ValueError("padded_sample cannot match columns_list.") - if block_reader: - raise RuntimeError("block_reader and padded_sample cannot be specified at the same time.") - - if padded_sample is None and num_padded is not None: - raise RuntimeError("num_padded is specified but padded_sample is not.") def check_imagefolderdatasetv2(method): - """A wrapper that wrap a parameter checker to the original Dataset(ImageFolderDatasetV2).""" + """A wrapper that wraps a parameter checker to the original Dataset(ImageFolderDatasetV2).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] nreq_param_list = ['extensions'] nreq_param_dict = ['class_indexing'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) - - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_list, param_dict, list) - - check_param_type(nreq_param_dict, param_dict, dict) + check_dir(dataset_dir) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_mnist_cifar_dataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset, Cifar10/100Dataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) - - check_param_type(nreq_param_int, param_dict, int) + check_dir(dataset_dir) - check_param_type(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_manifestdataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(ManifestDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(ManifestDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] nreq_param_str = ['usage'] nreq_param_dict = ['class_indexing'] - # check dataset_file; required argument dataset_file = param_dict.get('dataset_file') - if dataset_file is None: - raise ValueError("dataset_file is not provided.") - check_dataset_file(dataset_file) - - check_param_type(nreq_param_int, param_dict, int) + check_file(dataset_file) - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_str, param_dict, str) - - check_param_type(nreq_param_dict, param_dict, dict) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_str, param_dict, str) + validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_tfrecorddataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(TFRecordDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(TFRecordDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_list = ['columns_list'] nreq_param_bool = ['shard_equal_rows'] - # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') - if dataset_files is None: - raise ValueError("dataset_files is not provided.") if not isinstance(dataset_files, (str, list)): raise TypeError("dataset_files should be of type str or a list of strings.") - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_list, param_dict, list) - - check_param_type(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_vocdataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(VOCDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(VOCDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] nreq_param_dict = ['class_indexing'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) - # check task; required argument + check_dir(dataset_dir) + task = param_dict.get('task') - if task is None: - raise ValueError("task is not provided.") - if not isinstance(task, str): - raise TypeError("task is not str type.") - # check mode; required argument + type_check(task, (str,), "task") + mode = param_dict.get('mode') - if mode is None: - raise ValueError("mode is not provided.") - if not isinstance(mode, str): - raise TypeError("mode is not str type.") + type_check(mode, (str,), "mode") - imagesets_file = "" if task == "Segmentation": imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt") if param_dict.get('class_indexing') is not None: @@ -347,92 +163,74 @@ def check_vocdataset(method): else: raise ValueError("Invalid task : " + task) - check_dataset_file(imagesets_file) - - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_dict, param_dict, dict) + check_file(imagesets_file) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_cocodataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(CocoDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(CocoDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) + check_dir(dataset_dir) - # check annotation_file; required argument annotation_file = param_dict.get('annotation_file') - if annotation_file is None: - raise ValueError("annotation_file is not provided.") - check_dataset_file(annotation_file) + check_file(annotation_file) - # check task; required argument task = param_dict.get('task') - if task is None: - raise ValueError("task is not provided.") - if not isinstance(task, str): - raise TypeError("task is not str type.") + type_check(task, (str,), "task") if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}: raise ValueError("Invalid task type") - check_param_type(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_int, param_dict, int) - check_param_type(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) sampler = param_dict.get('sampler') if sampler is not None and isinstance(sampler, samplers.PKSampler): raise ValueError("CocoDataset doesn't support PKSampler") check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_celebadataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(CelebADataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(CelebADataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] nreq_param_list = ['extensions'] nreq_param_str = ['dataset_type'] - # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') - if dataset_dir is None: - raise ValueError("dataset_dir is not provided.") - check_dataset_dir(dataset_dir) - check_param_type(nreq_param_int, param_dict, int) + check_dir(dataset_dir) - check_param_type(nreq_param_bool, param_dict, bool) - - check_param_type(nreq_param_list, param_dict, list) - - check_param_type(nreq_param_str, param_dict, str) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_str, param_dict, str) dataset_type = param_dict.get('dataset_type') if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'): @@ -444,67 +242,58 @@ def check_celebadataset(method): if sampler is not None and isinstance(sampler, samplers.PKSampler): raise ValueError("CelebADataset does not support PKSampler.") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_minddataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(MindDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(MindDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded'] nreq_param_list = ['columns_list'] nreq_param_bool = ['block_reader'] nreq_param_dict = ['padded_sample'] - # check dataset_file; required argument dataset_file = param_dict.get('dataset_file') - if dataset_file is None: - raise ValueError("dataset_file is not provided.") if isinstance(dataset_file, list): for f in dataset_file: - check_dataset_file(f) + check_file(f) else: - check_dataset_file(dataset_file) - - check_param_type(nreq_param_int, param_dict, int) - - check_param_type(nreq_param_list, param_dict, list) - - check_param_type(nreq_param_bool, param_dict, bool) + check_file(dataset_file) - check_param_type(nreq_param_dict, param_dict, dict) + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_dict, param_dict, dict) check_sampler_shuffle_shard_options(param_dict) check_padding_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_generatordataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(GeneratorDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(GeneratorDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) - # check generator_function; required argument source = param_dict.get('source') - if source is None: - raise ValueError("source is not provided.") + if not callable(source): try: iter(source) except TypeError: raise TypeError("source should be callable, iterable or random accessible") - # check column_names or schema; required argument column_names = param_dict.get('column_names') if column_names is not None: check_columns(column_names, "column_names") @@ -518,11 +307,11 @@ def check_generatordataset(method): # check optional argument nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"] - check_param_type(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_int, param_dict, int) nreq_param_list = ["column_types"] - check_param_type(nreq_param_list, param_dict, list) + validate_dataset_param_value(nreq_param_list, param_dict, list) nreq_param_bool = ["shuffle"] - check_param_type(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) num_shards = param_dict.get("num_shards") shard_id = param_dict.get("shard_id") @@ -530,9 +319,9 @@ def check_generatordataset(method): # These two parameters appear together. raise ValueError("num_shards and shard_id need to be passed in together") if num_shards is not None: - check_positive_int32(num_shards, "num_shards") + check_pos_int32(num_shards, "num_shards") if shard_id >= num_shards: - raise ValueError("shard_id should be less than num_shards") + raise ValueError("shard_id should be less than num_shards.") sampler = param_dict.get("sampler") if sampler is not None: @@ -551,81 +340,73 @@ def check_generatordataset(method): if num_shards is not None and not hasattr(source, "__getitem__"): raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method +def check_random_dataset(method): + """A wrapper that wraps a parameter checker to the original Dataset(RandomDataset).""" -def check_batch_size(batch_size): - if not (isinstance(batch_size, int) or (callable(batch_size))): - raise TypeError("batch_size should either be an int or a callable.") - if callable(batch_size): - sig = ins.signature(batch_size) - if len(sig.parameters) != 1: - raise ValueError("batch_size callable should take one parameter (BatchInfo).") + @wraps(method) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows'] + nreq_param_bool = ['shuffle'] + nreq_param_list = ['columns_list'] -def check_count(count): - check_type(count, 'count', int) - if (count <= 0 and count != -1) or count > INT32_MAX: - raise ValueError("count should be either -1 or positive integer.") + validate_dataset_param_value(nreq_param_int, param_dict, int) + validate_dataset_param_value(nreq_param_bool, param_dict, bool) + validate_dataset_param_value(nreq_param_list, param_dict, list) + check_sampler_shuffle_shard_options(param_dict) + + return method(self, *args, **kwargs) -def check_columns(columns, name): - if isinstance(columns, list): - for column in columns: - if not isinstance(column, str): - raise TypeError("Each column in {0} should be of type str. Got {1}.".format(name, type(column))) - elif not isinstance(columns, str): - raise TypeError("{} should be either a list of strings or a single string.".format(name)) + return new_method def check_pad_info(key, val): """check the key and value pair of pad_info in batch""" - check_type(key, "key in pad_info", str) + type_check(key, (str,), "key in pad_info") + if val is not None: assert len(val) == 2, "value of pad_info should be a tuple of size 2" - check_type(val, "value in pad_info", tuple) + type_check(val, (tuple,), "value in pad_info") + if val[0] is not None: - check_type(val[0], "pad_shape", list) + type_check(val[0], (list,), "pad_shape") + for dim in val[0]: if dim is not None: - check_type(dim, "dim in pad_shape", int) + type_check(dim, (int,), "dim in pad_shape") assert dim > 0, "pad shape should be positive integers" if val[1] is not None: - check_type(val[1], "pad_value", (int, float, str, bytes)) + type_check(val[1], (int, float, str, bytes), "pad_value") def check_bucket_batch_by_length(method): """check the input arguments of bucket_batch_by_length.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info, + pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs) nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'] - check_param_type(nreq_param_list, param_dict, list) + + type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,), nreq_param_list) nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder'] - check_param_type(nbool_param_list, param_dict, bool) + type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list) # check column_names: must be list of string. - column_names = param_dict.get("column_names") - - if not column_names: - raise ValueError("column_names cannot be empty") + check_columns(column_names, "column_names") - all_string = all(isinstance(item, str) for item in column_names) - if not all_string: - raise TypeError("column_names should be a list of str.") - - element_length_function = param_dict.get("element_length_function") if element_length_function is None and len(column_names) != 1: raise ValueError("If element_length_function is not specified, exactly one column name should be passed.") # check bucket_boundaries: must be list of int, positive and strictly increasing - bucket_boundaries = param_dict.get('bucket_boundaries') - if not bucket_boundaries: raise ValueError("bucket_boundaries cannot be empty.") @@ -633,16 +414,15 @@ def check_bucket_batch_by_length(method): if not all_int: raise TypeError("bucket_boundaries should be a list of int.") - all_non_negative = all(item >= 0 for item in bucket_boundaries) + all_non_negative = all(item > 0 for item in bucket_boundaries) if not all_non_negative: - raise ValueError("bucket_boundaries cannot contain any negative numbers.") + raise ValueError("bucket_boundaries must only contain positive numbers.") for i in range(len(bucket_boundaries) - 1): if not bucket_boundaries[i + 1] > bucket_boundaries[i]: raise ValueError("bucket_boundaries should be strictly increasing.") # check bucket_batch_sizes: must be list of int and positive - bucket_batch_sizes = param_dict.get('bucket_batch_sizes') if len(bucket_batch_sizes) != len(bucket_boundaries) + 1: raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.") @@ -654,12 +434,13 @@ def check_bucket_batch_by_length(method): if not all_non_negative: raise ValueError("bucket_batch_sizes should be a list of positive numbers.") - if param_dict.get('pad_info') is not None: - check_type(param_dict["pad_info"], "pad_info", dict) - for k, v in param_dict.get('pad_info').items(): + if pad_info is not None: + type_check(pad_info, (dict,), "pad_info") + + for k, v in pad_info.items(): check_pad_info(k, v) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -668,37 +449,33 @@ def check_batch(method): """check the input arguments of batch.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - nreq_param_int = ['num_parallel_workers'] - nreq_param_bool = ['drop_remainder'] - nreq_param_columns = ['input_columns'] + def new_method(self, *args, **kwargs): + [batch_size, drop_remainder, num_parallel_workers, per_batch_map, + input_columns, pad_info], param_dict = parse_user_args(method, *args, **kwargs) - # check batch_size; required argument - batch_size = param_dict.get("batch_size") - if batch_size is None: - raise ValueError("batch_size is not provided.") - check_batch_size(batch_size) + if not (isinstance(batch_size, int) or (callable(batch_size))): + raise TypeError("batch_size should either be an int or a callable.") - check_param_type(nreq_param_int, param_dict, int) + if callable(batch_size): + sig = ins.signature(batch_size) + if len(sig.parameters) != 1: + raise ValueError("batch_size callable should take one parameter (BatchInfo).") - check_param_type(nreq_param_bool, param_dict, bool) + if num_parallel_workers is not None: + check_num_parallel_workers(num_parallel_workers) + type_check(drop_remainder, (bool,), "drop_remainder") - if (param_dict.get('pad_info') is not None) and (param_dict.get('per_batch_map') is not None): + if (pad_info is not None) and (per_batch_map is not None): raise ValueError("pad_info and per_batch_map can't both be set") - if param_dict.get('pad_info') is not None: - check_type(param_dict["pad_info"], "pad_info", dict) + if pad_info is not None: + type_check(param_dict["pad_info"], (dict,), "pad_info") for k, v in param_dict.get('pad_info').items(): check_pad_info(k, v) - for param_name in nreq_param_columns: - param = param_dict.get(param_name) - if param is not None: - check_columns(param, param_name) + if input_columns is not None: + check_columns(input_columns, "input_columns") - per_batch_map, input_columns = param_dict.get('per_batch_map'), param_dict.get('input_columns') if (per_batch_map is None) != (input_columns is None): # These two parameters appear together. raise ValueError("per_batch_map and input_columns need to be passed in together.") @@ -709,43 +486,38 @@ def check_batch(method): if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1): raise ValueError("the signature of per_batch_map should match with input columns") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method + def check_sync_wait(method): """check the input arguments of sync_wait.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - nreq_param_str = ['condition_name'] - nreq_param_int = ['step_size'] + def new_method(self, *args, **kwargs): + [condition_name, num_batch, _], _ = parse_user_args(method, *args, **kwargs) - check_param_type(nreq_param_int, param_dict, int) + type_check(condition_name, (str,), "condition_name") + type_check(num_batch, (int,), "num_batch") - check_param_type(nreq_param_str, param_dict, str) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method + def check_shuffle(method): """check the input arguments of shuffle.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [buffer_size], _ = parse_user_args(method, *args, **kwargs) - # check buffer_size; required argument - buffer_size = param_dict.get("buffer_size") - if buffer_size is None: - raise ValueError("buffer_size is not provided.") - check_type(buffer_size, 'buffer_size', int) - check_interval_closed(buffer_size, 'buffer_size', [2, INT32_MAX]) + type_check(buffer_size, (int,), "buffer_size") - return method(*args, **kwargs) + check_value(buffer_size, [2, INT32_MAX], "buffer_size") + + return method(self, *args, **kwargs) return new_method @@ -754,23 +526,25 @@ def check_map(method): """check the input arguments of map.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \ + parse_user_args(method, *args, **kwargs) - nreq_param_list = ['columns_order'] - nreq_param_int = ['num_parallel_workers'] nreq_param_columns = ['input_columns', 'output_columns'] - nreq_param_bool = ['python_multiprocessing'] - check_param_type(nreq_param_list, param_dict, list) - check_param_type(nreq_param_int, param_dict, int) - check_param_type(nreq_param_bool, param_dict, bool) - for param_name in nreq_param_columns: - param = param_dict.get(param_name) + if columns_order is not None: + type_check(columns_order, (list,), "columns_order") + if num_parallel_workers is not None: + check_num_parallel_workers(num_parallel_workers) + type_check(python_multiprocessing, (bool,), "python_multiprocessing") + if cache is not None: + type_check(cache, (cache_client.DatasetCache,), "cache") + + for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]): if param is not None: check_columns(param, param_name) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -779,19 +553,20 @@ def check_filter(method): """"check the input arguments of filter.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - predicate = param_dict.get("predicate") + def new_method(self, *args, **kwargs): + [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) if not callable(predicate): raise TypeError("Predicate should be a python function or a callable python object.") - nreq_param_int = ['num_parallel_workers'] - check_param_type(nreq_param_int, param_dict, int) - param_name = "input_columns" - param = param_dict.get(param_name) - if param is not None: - check_columns(param, param_name) - return method(*args, **kwargs) + check_num_parallel_workers(num_parallel_workers) + + if num_parallel_workers is not None: + check_num_parallel_workers(num_parallel_workers) + + if input_columns is not None: + check_columns(input_columns, "input_columns") + + return method(self, *args, **kwargs) return new_method @@ -800,14 +575,13 @@ def check_repeat(method): """check the input arguments of repeat.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [count], _ = parse_user_args(method, *args, **kwargs) - count = param_dict.get('count') - if count is not None: - check_count(count) - - return method(*args, **kwargs) + type_check(count, (int, type(None)), "repeat") + if isinstance(count, int): + check_value(count, (-1, INT32_MAX), "count") + return method(self, *args, **kwargs) return new_method @@ -816,15 +590,13 @@ def check_skip(method): """check the input arguments of skip.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [count], _ = parse_user_args(method, *args, **kwargs) - count = param_dict.get('count') - check_type(count, 'count', int) - if count < 0: - raise ValueError("Skip count must be positive integer or 0.") + type_check(count, (int,), "count") + check_value(count, (-1, INT32_MAX), "count") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -833,13 +605,32 @@ def check_take(method): """check the input arguments of take.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [count], _ = parse_user_args(method, *args, **kwargs) + type_check(count, (int,), "count") + if (count <= 0 and count != -1) or count > INT32_MAX: + raise ValueError("count should be either -1 or positive integer.") - count = param_dict.get('count') - check_count(count) + return method(self, *args, **kwargs) - return method(*args, **kwargs) + return new_method + + +def check_positive_int32(method): + """check whether the input argument is positive and int, only works for functions with one input.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [count], param_dict = parse_user_args(method, *args, **kwargs) + para_name = None + for key in list(param_dict.keys()): + if key not in ['self', 'cls']: + para_name = key + # Need to get default value of param + if count is not None: + check_pos_int32(count, para_name) + + return method(self, *args, **kwargs) return new_method @@ -849,13 +640,8 @@ def check_zip(method): @wraps(method) def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check datasets; required argument - ds = param_dict.get("datasets") - if ds is None: - raise ValueError("datasets is not provided.") - check_type(ds, 'datasets', tuple) + [ds], _ = parse_user_args(method, *args, **kwargs) + type_check(ds, (tuple,), "datasets") return method(*args, **kwargs) @@ -866,18 +652,11 @@ def check_zip_dataset(method): """check the input arguments of zip method in `Dataset`.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check datasets; required argument - ds = param_dict.get("datasets") - if ds is None: - raise ValueError("datasets is not provided.") + def new_method(self, *args, **kwargs): + [ds], _ = parse_user_args(method, *args, **kwargs) + type_check(ds, (tuple, datasets.Dataset), "datasets") - if not isinstance(ds, (tuple, datasets.Dataset)): - raise TypeError("datasets is not tuple or of type Dataset.") - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -886,18 +665,13 @@ def check_concat(method): """check the input arguments of concat method in `Dataset`.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check datasets; required argument - ds = param_dict.get("datasets") - if ds is None: - raise ValueError("datasets is not provided.") - - if not isinstance(ds, (list, datasets.Dataset)): - raise TypeError("datasets is not list or of type Dataset.") - - return method(*args, **kwargs) + def new_method(self, *args, **kwargs): + [ds], _ = parse_user_args(method, *args, **kwargs) + type_check(ds, (list, datasets.Dataset), "datasets") + if isinstance(ds, list): + dataset_names = ["dataset[{0}]".format(i) for i in range(len(ds)) if isinstance(ds, list)] + type_check_list(ds, (datasets.Dataset,), dataset_names) + return method(self, *args, **kwargs) return new_method @@ -906,26 +680,23 @@ def check_rename(method): """check the input arguments of rename.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + values, _ = parse_user_args(method, *args, **kwargs) req_param_columns = ['input_columns', 'output_columns'] - # check req_param_list; required arguments - for param_name in req_param_columns: - param = param_dict.get(param_name) - if param is None: - raise ValueError("{} is not provided.".format(param_name)) + for param_name, param in zip(req_param_columns, values): check_columns(param, param_name) input_size, output_size = 1, 1 - if isinstance(param_dict.get(req_param_columns[0]), list): - input_size = len(param_dict.get(req_param_columns[0])) - if isinstance(param_dict.get(req_param_columns[1]), list): - output_size = len(param_dict.get(req_param_columns[1])) + input_columns, output_columns = values + if isinstance(input_columns, list): + input_size = len(input_columns) + if isinstance(output_columns, list): + output_size = len(output_columns) if input_size != output_size: raise ValueError("Number of column in input_columns and output_columns is not equal.") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -934,75 +705,54 @@ def check_project(method): """check the input arguments of project.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check columns; required argument - columns = param_dict.get("columns") - if columns is None: - raise ValueError("columns is not provided.") + def new_method(self, *args, **kwargs): + [columns], _ = parse_user_args(method, *args, **kwargs) check_columns(columns, 'columns') - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method -def check_shape(shape, name): - if isinstance(shape, list): - for element in shape: - if not isinstance(element, int): - raise TypeError( - "Each element in {0} should be of type int. Got {1}.".format(name, type(element))) - else: - raise TypeError("Expected int list.") - - def check_add_column(method): """check the input arguments of add_column.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs) + + type_check(name, (str,), "name") - # check name; required argument - name = param_dict.get("name") - if not isinstance(name, str) or not name: + if not name: raise TypeError("Expected non-empty string.") - # check type; required argument - de_type = param_dict.get("de_type") if de_type is not None: if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type): raise TypeError("Unknown column type.") else: raise TypeError("Expected non-empty string.") - # check shape - shape = param_dict.get("shape") if shape is not None: - check_shape(shape, "shape") + type_check(shape, (list,), "shape") + shape_names = ["shape[{0}]".format(i) for i in range(len(shape))] + type_check_list(shape, (int,), shape_names) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_cluedataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(CLUEDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(CLUEDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] - # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') - if dataset_files is None: - raise ValueError("dataset_files is not provided.") - if not isinstance(dataset_files, (str, list)): - raise TypeError("dataset_files should be of type str or a list of strings.") + type_check(dataset_files, (str, list), "dataset files") # check task task_param = param_dict.get('task') @@ -1014,36 +764,29 @@ def check_cluedataset(method): if usage_param not in ['train', 'test', 'eval']: raise ValueError("usage should be train, test or eval") - check_param_type(nreq_param_int, param_dict, int) - + validate_dataset_param_value(nreq_param_int, param_dict, int) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_textfiledataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] - # check dataset_files; required argument dataset_files = param_dict.get('dataset_files') - if dataset_files is None: - raise ValueError("dataset_files is not provided.") - if not isinstance(dataset_files, (str, list)): - raise TypeError("dataset_files should be of type str or a list of strings.") - - check_param_type(nreq_param_int, param_dict, int) - + type_check(dataset_files, (str, list), "dataset files") + validate_dataset_param_value(nreq_param_int, param_dict, int) check_sampler_shuffle_shard_options(param_dict) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1052,19 +795,16 @@ def check_split(method): """check the input arguments of split.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [sizes, randomize], _ = parse_user_args(method, *args, **kwargs) - nreq_param_list = ['sizes'] - nreq_param_bool = ['randomize'] - check_param_type(nreq_param_list, param_dict, list) - check_param_type(nreq_param_bool, param_dict, bool) + type_check(sizes, (list,), "sizes") + type_check(randomize, (bool,), "randomize") # check sizes: must be list of float or list of int - sizes = param_dict.get('sizes') - if not sizes: raise ValueError("sizes cannot be empty.") + all_int = all(isinstance(item, int) for item in sizes) all_float = all(isinstance(item, float) for item in sizes) @@ -1085,7 +825,7 @@ def check_split(method): if not abs(sum(sizes) - 1) < epsilon: raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1094,123 +834,85 @@ def check_gnn_graphdata(method): """check the input arguments of graphdata.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check dataset_file; required argument - dataset_file = param_dict.get('dataset_file') - if dataset_file is None: - raise ValueError("dataset_file is not provided.") - check_dataset_file(dataset_file) - - nreq_param_int = ['num_parallel_workers'] + def new_method(self, *args, **kwargs): + [dataset_file, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) + check_file(dataset_file) - check_param_type(nreq_param_int, param_dict, int) - - return method(*args, **kwargs) + if num_parallel_workers is not None: + check_num_parallel_workers(num_parallel_workers) + return method(self, *args, **kwargs) return new_method -def check_gnn_list_or_ndarray(param, param_name): - """Check if the input parameter is list or numpy.ndarray.""" - - if isinstance(param, list): - for m in param: - if not isinstance(m, int): - raise TypeError( - "Each member in {0} should be of type int. Got {1}.".format(param_name, type(m))) - elif isinstance(param, np.ndarray): - if not param.dtype == np.int32: - raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( - param_name, param.dtype)) - else: - raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( - param_name, type(param))) - - def check_gnn_get_all_nodes(method): - """A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_all_nodes` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check node_type; required argument - check_type(param_dict.get("node_type"), 'node_type', int) + def new_method(self, *args, **kwargs): + [node_type], _ = parse_user_args(method, *args, **kwargs) + type_check(node_type, (int,), "node_type") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_gnn_get_all_edges(method): - """A wrapper that wrap a parameter checker to the GNN `get_all_edges` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_all_edges` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [edge_type], _ = parse_user_args(method, *args, **kwargs) + type_check(edge_type, (int,), "edge_type") - # check node_type; required argument - check_type(param_dict.get("edge_type"), 'edge_type', int) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_gnn_get_nodes_from_edges(method): - """A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_nodes_from_edges` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [edge_list], _ = parse_user_args(method, *args, **kwargs) + check_gnn_list_or_ndarray(edge_list, "edge_list") - # check edge_list; required argument - check_gnn_list_or_ndarray(param_dict.get("edge_list"), 'edge_list') - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_gnn_get_all_neighbors(method): - """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_all_neighbors` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_list, neighbour_type], _ = parse_user_args(method, *args, **kwargs) - # check node_list; required argument - check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') + check_gnn_list_or_ndarray(node_list, 'node_list') + type_check(neighbour_type, (int,), "neighbour_type") - # check neighbor_type; required argument - check_type(param_dict.get("neighbor_type"), 'neighbor_type', int) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_gnn_get_sampled_neighbors(method): - """A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_sampled_neighbors` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs) - # check node_list; required argument - check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') + check_gnn_list_or_ndarray(node_list, 'node_list') - # check neighbor_nums; required argument - neighbor_nums = param_dict.get("neighbor_nums") check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums') if not neighbor_nums or len(neighbor_nums) > 6: raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format( 'neighbor_nums', len(neighbor_nums))) - # check neighbor_types; required argument - neighbor_types = param_dict.get("neighbor_types") check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types') if not neighbor_types or len(neighbor_types) > 6: raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format( @@ -1220,47 +922,41 @@ def check_gnn_get_sampled_neighbors(method): raise ValueError( "The number of members of neighbor_nums and neighbor_types is inconsistent") - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_gnn_get_neg_sampled_neighbors(method): - """A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_neg_sampled_neighbors` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_list, neg_neighbor_num, neg_neighbor_type], _ = parse_user_args(method, *args, **kwargs) - # check node_list; required argument - check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') + check_gnn_list_or_ndarray(node_list, 'node_list') + type_check(neg_neighbor_num, (int,), "neg_neighbor_num") + type_check(neg_neighbor_type, (int,), "neg_neighbor_type") - # check neg_neighbor_num; required argument - check_type(param_dict.get("neg_neighbor_num"), 'neg_neighbor_num', int) - - # check neg_neighbor_type; required argument - check_type(param_dict.get("neg_neighbor_type"), - 'neg_neighbor_type', int) - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method def check_gnn_random_walk(method): - """A wrapper that wrap a parameter checker to the GNN `random_walk` function.""" + """A wrapper that wraps a parameter checker to the GNN `random_walk` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check node_list; required argument - check_gnn_list_or_ndarray(param_dict.get("target_nodes"), 'target_nodes') + def new_method(self, *args, **kwargs): + [target_nodes, meta_path, step_home_param, step_away_param, default_node], _ = parse_user_args(method, *args, + **kwargs) + check_gnn_list_or_ndarray(target_nodes, 'target_nodes') + check_gnn_list_or_ndarray(meta_path, 'meta_path') + type_check(step_home_param, (float,), "step_home_param") + type_check(step_away_param, (float,), "step_away_param") + type_check(default_node, (int,), "default_node") - # check meta_path; required argument - check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path') - - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -1268,8 +964,7 @@ def check_gnn_random_walk(method): def check_aligned_list(param, param_name, member_type): """Check whether the structure of each member of the list is the same.""" - if not isinstance(param, list): - raise TypeError("Parameter {0} is not a list".format(param_name)) + type_check(param, (list,), "param") if not param: raise TypeError( "Parameter {0} or its members are empty".format(param_name)) @@ -1278,6 +973,7 @@ def check_aligned_list(param, param_name, member_type): for member in param: if isinstance(member, list): check_aligned_list(member, param_name, member_type) + if member_have_list not in (None, True): raise TypeError("The type of each member of the parameter {0} is inconsistent".format( param_name)) @@ -1287,9 +983,7 @@ def check_aligned_list(param, param_name, member_type): member_have_list = True list_len = len(member) else: - if not isinstance(member, member_type): - raise TypeError("Each member in {0} should be of type int. Got {1}.".format( - param_name, type(member))) + type_check(member, (member_type,), param_name) if member_have_list not in (None, False): raise TypeError("The type of each member of the parameter {0} is inconsistent".format( param_name)) @@ -1297,53 +991,65 @@ def check_aligned_list(param, param_name, member_type): def check_gnn_get_node_feature(method): - """A wrapper that wrap a parameter checker to the GNN `get_node_feature` function.""" + """A wrapper that wraps a parameter checker to the GNN `get_node_feature` function.""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) + def new_method(self, *args, **kwargs): + [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs) - # check node_list; required argument - node_list = param_dict.get("node_list") + type_check(node_list, (list, np.ndarray), "node_list") if isinstance(node_list, list): check_aligned_list(node_list, 'node_list', int) elif isinstance(node_list, np.ndarray): if not node_list.dtype == np.int32: raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( node_list, node_list.dtype)) - else: - raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( - 'node_list', type(node_list))) - # check feature_types; required argument - check_gnn_list_or_ndarray(param_dict.get( - "feature_types"), 'feature_types') + check_gnn_list_or_ndarray(feature_types, 'feature_types') - return method(*args, **kwargs) + return method(self, *args, **kwargs) + + return new_method + + +def check_gnn_get_edge_feature(method): + """A wrapper that wrap a parameter checker to the GNN `get_edge_feature` function.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs) + + type_check(edge_list, (list, np.ndarray), "edge_list") + if isinstance(edge_list, list): + check_aligned_list(edge_list, 'edge_list', int) + elif isinstance(edge_list, np.ndarray): + if not edge_list.dtype == np.int32: + raise TypeError("Each member in {0} should be of type int32. Got {1}.".format( + edge_list, edge_list.dtype)) + + check_gnn_list_or_ndarray(feature_types, 'feature_types') + + return method(self, *args, **kwargs) return new_method def check_numpyslicesdataset(method): - """A wrapper that wrap a parameter checker to the original Dataset(NumpySlicesDataset).""" + """A wrapper that wraps a parameter checker to the original Dataset(NumpySlicesDataset).""" @wraps(method) - def new_method(*args, **kwargs): - param_dict = make_param_dict(method, args, kwargs) - - # check data; required argument - data = param_dict.get('data') - if not isinstance(data, (list, tuple, dict, np.ndarray)): - raise TypeError("Unsupported data type: {}, only support some common python data type, " - "like list, tuple, dict, and numpy array.".format(type(data))) - if isinstance(data, tuple) and not isinstance(data[0], (list, np.ndarray)): - raise TypeError("Unsupported data type: when input is tuple, only support some common python " - "data type, like tuple of lists and tuple of numpy arrays.") + def new_method(self, *args, **kwargs): + _, param_dict = parse_user_args(method, *args, **kwargs) + + data = param_dict.get("data") + column_names = param_dict.get("column_names") if not data: - raise ValueError("Input data is empty.") + raise ValueError("Argument data cannot be empty") + type_check(data, (list, tuple, dict, np.ndarray), "data") + if isinstance(data, tuple): + type_check(data[0], (list, np.ndarray), "data[0]") # check column_names - column_names = param_dict.get('column_names') if column_names is not None: check_columns(column_names, "column_names") @@ -1364,6 +1070,6 @@ def check_numpyslicesdataset(method): raise ValueError("Num of input column names is {0}, but required is {1} as data is list." .format(column_num, 1)) - return method(*args, **kwargs) + return method(self, *args, **kwargs) return new_method diff --git a/mindspore/dataset/text/transforms.py b/mindspore/dataset/text/transforms.py index 8b0d47df25..30fa2b8f42 100644 --- a/mindspore/dataset/text/transforms.py +++ b/mindspore/dataset/text/transforms.py @@ -52,8 +52,9 @@ import mindspore._c_dataengine as cde from .utils import JiebaMode, NormalizeForm, to_str from .validators import check_lookup, check_jieba_add_dict, \ - check_jieba_add_word, check_jieba_init, check_ngram, check_pair_truncate, \ - check_to_number, check_python_tokenizer + check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer,\ + check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate,\ + check_to_number, check_bert_tokenizer, check_python_tokenizer from ..core.datatypes import mstype_to_detype @@ -63,17 +64,13 @@ class Lookup(cde.LookupOp): Args: vocab(Vocab): a Vocab object. - unknown(int, optional): default id to lookup a word that is out of vocab. If no argument is passed, 1 will be - used to be the default id which is the convention for unknown_token . Otherwise, user is strongly - encouraged to pass in the id for (default=None). + unknown_token(str, optional): word to use for lookup if the word being looked up is out of Vocabulary (oov). + If unknown_token is oov, runtime error will be thrown (default=None). """ @check_lookup - def __init__(self, vocab, unknown=None): - if unknown is None: - super().__init__(vocab) - else: - super().__init__(vocab, unknown) + def __init__(self, vocab, unknown_token=None): + super().__init__(vocab, unknown_token) class Ngram(cde.NgramOp): @@ -98,7 +95,7 @@ class Ngram(cde.NgramOp): """ @check_ngram - def __init__(self, n, left_pad=None, right_pad=None, separator=None): + def __init__(self, n, left_pad=("", 0), right_pad=("", 0), separator=" "): super().__init__(ngrams=n, l_pad_len=left_pad[1], r_pad_len=right_pad[1], l_pad_token=left_pad[0], r_pad_token=right_pad[0], separator=separator) @@ -125,15 +122,31 @@ class JiebaTokenizer(cde.JiebaTokenizerOp): - JiebaMode.MP, tokenize with MPSegment algorithm. - JiebaMode.HMM, tokenize with Hiddel Markov Model Segment algorithm. - JiebaMode.MIX, tokenize with a mix of MPSegment and HMMSegment algorithm. + with_offsets (bool, optional): If or not output offsets of tokens (default=False). + + Examples: + >>> # If with_offsets=False, default output one column {["text", dtype=str]} + >>> tokenizer_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=False) + >>> data = data.map(operations=tokenizer_op) + >>> # If with_offsets=False, then output three columns {["token", dtype=str], ["offsets_start", dtype=uint32], + >>> # ["offsets_limit", dtype=uint32]} + >>> tokenizer_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + >>> data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + >>> columns_order=["token", "offsets_start", "offsets_limit"], operations=tokenizer_op) """ @check_jieba_init - def __init__(self, hmm_path, mp_path, mode=JiebaMode.MIX): + def __init__(self, hmm_path, mp_path, mode=JiebaMode.MIX, with_offsets=False): + if not isinstance(mode, JiebaMode): + raise TypeError("Wrong input type for mode, should be JiebaMode.") + self.mode = mode self.__check_path__(hmm_path) self.__check_path__(mp_path) + self.with_offsets = with_offsets super().__init__(hmm_path, mp_path, - DE_C_INTER_JIEBA_MODE[mode]) + DE_C_INTER_JIEBA_MODE[mode], + self.with_offsets) @check_jieba_add_word def add_word(self, word, freq=None): @@ -226,8 +239,26 @@ class JiebaTokenizer(cde.JiebaTokenizerOp): class UnicodeCharTokenizer(cde.UnicodeCharTokenizerOp): """ Tokenize a scalar tensor of UTF-8 string to Unicode characters. + + Args: + with_offsets (bool, optional): If or not output offsets of tokens (default=False). + + Examples: + >>> # If with_offsets=False, default output one column {["text", dtype=str]} + >>> tokenizer_op = text.UnicodeCharTokenizer() + >>> dataset = dataset.map(operations=tokenizer_op) + >>> # If with_offsets=False, then output three columns {["token", dtype=str], ["offsets_start", dtype=uint32], + >>> # ["offsets_limit", dtype=uint32]} + >>> tokenizer_op = text.UnicodeCharTokenizer(True) + >>> data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + >>> columns_order=["token", "offsets_start", "offsets_limit"], operations=tokenizer_op) """ + @check_with_offsets + def __init__(self, with_offsets=False): + self.with_offsets = with_offsets + super().__init__(self.with_offsets) + class WordpieceTokenizer(cde.WordpieceTokenizerOp): """ @@ -239,22 +270,58 @@ class WordpieceTokenizer(cde.WordpieceTokenizerOp): max_bytes_per_token (int, optional): Tokens exceeding this length will not be further split(default=100). unknown_token (str, optional): When we can not found the token: if 'unknown_token' is empty string, return the token directly, else return 'unknown_token'(default='[UNK]'). + with_offsets (bool, optional): If or not output offsets of tokens (default=False). + + Examples: + >>> # If with_offsets=False, default output one column {["text", dtype=str]} + >>> tokenizer_op = text.WordpieceTokenizer(vocab=vocab, unknown_token=['UNK'], + >>> max_bytes_per_token=100, with_offsets=False) + >>> dataset = dataset.map(operations=tokenizer_op) + >>> # If with_offsets=False, then output three columns {["token", dtype=str], ["offsets_start", dtype=uint32], + >>> # ["offsets_limit", dtype=uint32]} + >>> tokenizer_op = text.WordpieceTokenizer(vocab=vocab, unknown_token=['UNK'], + >>> max_bytes_per_token=100, with_offsets=True) + >>> data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + >>> columns_order=["token", "offsets_start", "offsets_limit"], operations=tokenizer_op) """ - def __init__(self, vocab, suffix_indicator='##', max_bytes_per_token=100, unknown_token='[UNK]'): + @check_wordpiece_tokenizer + def __init__(self, vocab, suffix_indicator='##', max_bytes_per_token=100, + unknown_token='[UNK]', with_offsets=False): self.vocab = vocab self.suffix_indicator = suffix_indicator self.max_bytes_per_token = max_bytes_per_token self.unknown_token = unknown_token - super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token, self.unknown_token) + self.with_offsets = with_offsets + super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token, + self.unknown_token, self.with_offsets) if platform.system().lower() != 'windows': class WhitespaceTokenizer(cde.WhitespaceTokenizerOp): """ Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces(such as: ' ', '\\\\t', '\\\\r', '\\\\n'). + + Args: + with_offsets (bool, optional): If or not output offsets of tokens (default=False). + + Examples: + >>> # If with_offsets=False, default output one column {["text", dtype=str]} + >>> tokenizer_op = text.WhitespaceTokenizer() + >>> dataset = dataset.map(operations=tokenizer_op) + >>> # If with_offsets=False, then output three columns {["token", dtype=str], + >>> # ["offsets_start", dtype=uint32], + >>> # ["offsets_limit", dtype=uint32]} + >>> tokenizer_op = text.WhitespaceTokenizer(True) + >>> data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + >>> columns_order=["token", "offsets_start", "offsets_limit"], operations=tokenizer_op) """ + @check_with_offsets + def __init__(self, with_offsets=False): + self.with_offsets = with_offsets + super().__init__(self.with_offsets) + class UnicodeScriptTokenizer(cde.UnicodeScriptTokenizerOp): """ @@ -262,11 +329,25 @@ if platform.system().lower() != 'windows': Args: keep_whitespace (bool, optional): If or not emit whitespace tokens (default=False). + with_offsets (bool, optional): If or not output offsets of tokens (default=False). + + Examples: + >>> # If with_offsets=False, default output one column {["text", dtype=str]} + >>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=False) + >>> dataset = dataset.map(operations=tokenizer_op) + >>> # If with_offsets=False, then output three columns {["token", dtype=str], + >>> # ["offsets_start", dtype=uint32], + >>> # ["offsets_limit", dtype=uint32]} + >>> tokenizer_op = text.UnicodeScriptTokenizerOp(keep_whitespace=True, with_offsets=True) + >>> data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + >>> columns_order=["token", "offsets_start", "offsets_limit"], operations=tokenizer_op) """ - def __init__(self, keep_whitespace=False): + @check_unicode_script_tokenizer + def __init__(self, keep_whitespace=False, with_offsets=False): self.keep_whitespace = keep_whitespace - super().__init__(self.keep_whitespace) + self.with_offsets = with_offsets + super().__init__(self.keep_whitespace, self.with_offsets) class CaseFold(cde.CaseFoldOp): @@ -302,6 +383,9 @@ if platform.system().lower() != 'windows': """ def __init__(self, normalize_form=NormalizeForm.NFKC): + if not isinstance(normalize_form, NormalizeForm): + raise TypeError("Wrong input type for normalization_form, should be NormalizeForm.") + self.normalize_form = DE_C_INTER_NORMALIZE_FORM[normalize_form] super().__init__(self.normalize_form) @@ -338,12 +422,26 @@ if platform.system().lower() != 'windows': keep_delim_pattern(str, optional): The string matched by 'delim_pattern' can be kept as a token if it can be matched by 'keep_delim_pattern'. And the default value is empty str(''), in this situation, delimiters will not kept as a output token(default=''). + with_offsets (bool, optional): If or not output offsets of tokens (default=False). + + Examples: + >>> # If with_offsets=False, default output one column {["text", dtype=str]} + >>> tokenizer_op = text.RegexTokenizer(delim_pattern, keep_delim_pattern, with_offsets=False) + >>> dataset = dataset.map(operations=tokenizer_op) + >>> # If with_offsets=False, then output three columns {["token", dtype=str], + >>> # ["offsets_start", dtype=uint32], + >>> # ["offsets_limit", dtype=uint32]} + >>> tokenizer_op = text.RegexTokenizer(delim_pattern, keep_delim_pattern, with_offsets=True) + >>> data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + >>> columns_order=["token", "offsets_start", "offsets_limit"], operations=tokenizer_op) """ - def __init__(self, delim_pattern, keep_delim_pattern=''): + @check_regex_tokenizer + def __init__(self, delim_pattern, keep_delim_pattern='', with_offsets=False): self.delim_pattern = delim_pattern self.keep_delim_pattern = keep_delim_pattern - super().__init__(self.delim_pattern, self.keep_delim_pattern) + self.with_offsets = with_offsets + super().__init__(self.delim_pattern, self.keep_delim_pattern, self.with_offsets) class BasicTokenizer(cde.BasicTokenizerOp): @@ -359,16 +457,41 @@ if platform.system().lower() != 'windows': only effective when 'lower_case' is False. See NormalizeUTF8 for details(default='NONE'). preserve_unused_token(bool, optional): If True, do not split special tokens like '[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]'(default=True). + with_offsets (bool, optional): If or not output offsets of tokens (default=False). + + Examples: + >>> # If with_offsets=False, default output one column {["text", dtype=str]} + >>> tokenizer_op = text.BasicTokenizer(lower_case=False, + >>> keep_whitespace=False, + >>> normalization_form=NormalizeForm.NONE, + >>> preserve_unused_token=True, + >>> with_offsets=False) + >>> dataset = dataset.map(operations=tokenizer_op) + >>> # If with_offsets=False, then output three columns {["token", dtype=str], + >>> # ["offsets_start", dtype=uint32], + >>> # ["offsets_limit", dtype=uint32]} + >>> tokenizer_op = text.BasicTokenizer(lower_case=False, + >>> keep_whitespace=False, + >>> normalization_form=NormalizeForm.NONE, + >>> preserve_unused_token=True, + >>> with_offsets=True) + >>> data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + >>> columns_order=["token", "offsets_start", "offsets_limit"], operations=tokenizer_op) """ - def __init__(self, lower_case=False, keep_whitespace=False, - normalization_form=NormalizeForm.NONE, preserve_unused_token=True): + @check_basic_tokenizer + def __init__(self, lower_case=False, keep_whitespace=False, normalization_form=NormalizeForm.NONE, + preserve_unused_token=True, with_offsets=False): + if not isinstance(normalization_form, NormalizeForm): + raise TypeError("Wrong input type for normalization_form, should be NormalizeForm.") + self.lower_case = lower_case self.keep_whitespace = keep_whitespace self.normalization_form = DE_C_INTER_NORMALIZE_FORM[normalization_form] self.preserve_unused_token = preserve_unused_token - super().__init__(self.lower_case, self.keep_whitespace, - self.normalization_form, self.preserve_unused_token) + self.with_offsets = with_offsets + super().__init__(self.lower_case, self.keep_whitespace, self.normalization_form, + self.preserve_unused_token, self.with_offsets) class BertTokenizer(cde.BertTokenizerOp): @@ -389,11 +512,33 @@ if platform.system().lower() != 'windows': only effective when 'lower_case' is False. See NormalizeUTF8 for details(default='NONE'). preserve_unused_token(bool, optional): If True, do not split special tokens like '[CLS]', '[SEP]', '[UNK]', '[PAD]', '[MASK]'(default=True). + with_offsets (bool, optional): If or not output offsets of tokens (default=False). + + Examples: + >>> # If with_offsets=False, default output one column {["text", dtype=str]} + >>> tokenizer_op = text.BertTokenizer(vocab=vocab, suffix_indicator='##', max_bytes_per_token=100, + >>> unknown_token=100, lower_case=False, keep_whitespace=False, + >>> normalization_form=NormalizeForm.NONE, preserve_unused_token=True, + >>> with_offsets=False) + >>> dataset = dataset.map(operations=tokenizer_op) + >>> # If with_offsets=False, then output three columns {["token", dtype=str], + >>> # ["offsets_start", dtype=uint32], + >>> # ["offsets_limit", dtype=uint32]} + >>> tokenizer_op = text.BertTokenizer(vocab=vocab, suffix_indicator='##', max_bytes_per_token=100, + >>> unknown_token=100, lower_case=False, keep_whitespace=False, + >>> normalization_form=NormalizeForm.NONE, preserve_unused_token=True, + >>> with_offsets=True) + >>> data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + >>> columns_order=["token", "offsets_start", "offsets_limit"], operations=tokenizer_op) """ - def __init__(self, vocab, suffix_indicator='##', max_bytes_per_token=100, - unknown_token='[UNK]', lower_case=False, keep_whitespace=False, - normalization_form=NormalizeForm.NONE, preserve_unused_token=True): + @check_bert_tokenizer + def __init__(self, vocab, suffix_indicator='##', max_bytes_per_token=100, unknown_token='[UNK]', + lower_case=False, keep_whitespace=False, normalization_form=NormalizeForm.NONE, + preserve_unused_token=True, with_offsets=False): + if not isinstance(normalization_form, NormalizeForm): + raise TypeError("Wrong input type for normalization_form, should be NormalizeForm.") + self.vocab = vocab self.suffix_indicator = suffix_indicator self.max_bytes_per_token = max_bytes_per_token @@ -402,8 +547,10 @@ if platform.system().lower() != 'windows': self.keep_whitespace = keep_whitespace self.normalization_form = DE_C_INTER_NORMALIZE_FORM[normalization_form] self.preserve_unused_token = preserve_unused_token + self.with_offsets = with_offsets super().__init__(self.vocab, self.suffix_indicator, self.max_bytes_per_token, self.unknown_token, - self.lower_case, self.keep_whitespace, self.normalization_form, self.preserve_unused_token) + self.lower_case, self.keep_whitespace, self.normalization_form, + self.preserve_unused_token, self.with_offsets) class TruncateSequencePair(cde.TruncateSequencePairOp): diff --git a/mindspore/dataset/text/utils.py b/mindspore/dataset/text/utils.py index 7347a4b854..ef1d0e6fc5 100644 --- a/mindspore/dataset/text/utils.py +++ b/mindspore/dataset/text/utils.py @@ -28,6 +28,7 @@ __all__ = [ "Vocab", "to_str", "to_bytes" ] + class Vocab(cde.Vocab): """ Vocab object that is used to lookup a word. @@ -38,7 +39,7 @@ class Vocab(cde.Vocab): @classmethod @check_from_dataset def from_dataset(cls, dataset, columns=None, freq_range=None, top_k=None, special_tokens=None, - special_first=None): + special_first=True): """ Build a vocab from a dataset. @@ -62,13 +63,21 @@ class Vocab(cde.Vocab): special_tokens(list, optional): a list of strings, each one is a special token. for example special_tokens=["",""] (default=None, no special tokens will be added). special_first(bool, optional): whether special_tokens will be prepended/appended to vocab. If special_tokens - is specified and special_first is set to None, special_tokens will be prepended (default=None). + is specified and special_first is set to True, special_tokens will be prepended (default=True). Returns: Vocab, Vocab object built from dataset. """ vocab = Vocab() + if columns is None: + columns = [] + if not isinstance(columns, list): + columns = [columns] + if freq_range is None: + freq_range = (None, None) + if special_tokens is None: + special_tokens = [] root = copy.deepcopy(dataset).build_vocab(vocab, columns, freq_range, top_k, special_tokens, special_first) for d in root.create_dict_iterator(): if d is not None: @@ -77,7 +86,7 @@ class Vocab(cde.Vocab): @classmethod @check_from_list - def from_list(cls, word_list, special_tokens=None, special_first=None): + def from_list(cls, word_list, special_tokens=None, special_first=True): """ Build a vocab object from a list of word. @@ -86,29 +95,33 @@ class Vocab(cde.Vocab): special_tokens(list, optional): a list of strings, each one is a special token. for example special_tokens=["",""] (default=None, no special tokens will be added). special_first(bool, optional): whether special_tokens will be prepended/appended to vocab, If special_tokens - is specified and special_first is set to None, special_tokens will be prepended (default=None). + is specified and special_first is set to True, special_tokens will be prepended (default=True). """ - + if special_tokens is None: + special_tokens = [] return super().from_list(word_list, special_tokens, special_first) @classmethod @check_from_file - def from_file(cls, file_path, delimiter=None, vocab_size=None, special_tokens=None, special_first=None): + def from_file(cls, file_path, delimiter="", vocab_size=None, special_tokens=None, special_first=True): """ Build a vocab object from a list of word. Args: file_path (str): path to the file which contains the vocab list. delimiter (str, optional): a delimiter to break up each line in file, the first element is taken to be - the word (default=None). + the word (default=""). vocab_size (int, optional): number of words to read from file_path (default=None, all words are taken). special_tokens (list, optional): a list of strings, each one is a special token. for example special_tokens=["",""] (default=None, no special tokens will be added). special_first (bool, optional): whether special_tokens will be prepended/appended to vocab, - If special_tokens is specified and special_first is set to None, - special_tokens will be prepended (default=None). + If special_tokens is specified and special_first is set to True, + special_tokens will be prepended (default=True). """ - + if vocab_size is None: + vocab_size = -1 + if special_tokens is None: + special_tokens = [] return super().from_file(file_path, delimiter, vocab_size, special_tokens, special_first) @classmethod diff --git a/mindspore/dataset/text/validators.py b/mindspore/dataset/text/validators.py index afab8665cd..b0327f5609 100644 --- a/mindspore/dataset/text/validators.py +++ b/mindspore/dataset/text/validators.py @@ -17,23 +17,22 @@ validators for text ops """ from functools import wraps - -import mindspore._c_dataengine as cde import mindspore.common.dtype as mstype +import mindspore._c_dataengine as cde from mindspore._c_expression import typing -from ..transforms.validators import check_uint32, check_pos_int64 + +from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_uint32, \ + INT32_MAX, check_value, check_positive def check_unique_list_of_words(words, arg_name): """Check that words is a list and each element is a str without any duplication""" - if not isinstance(words, list): - raise ValueError(arg_name + " needs to be a list of words of type string.") + type_check(words, (list,), arg_name) words_set = set() for word in words: - if not isinstance(word, str): - raise ValueError("each word in " + arg_name + " needs to be type str.") + type_check(word, (str,), arg_name) if word in words_set: raise ValueError(arg_name + " contains duplicate word: " + word + ".") words_set.add(word) @@ -41,161 +40,100 @@ def check_unique_list_of_words(words, arg_name): def check_lookup(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): - vocab, unknown = (list(args) + 2 * [None])[:2] - if "vocab" in kwargs: - vocab = kwargs.get("vocab") - if "unknown" in kwargs: - unknown = kwargs.get("unknown") - if unknown is not None: - if not (isinstance(unknown, int) and unknown >= 0): - raise ValueError("unknown needs to be a non-negative integer.") + [vocab, unknown_token], _ = parse_user_args(method, *args, **kwargs) - if not isinstance(vocab, cde.Vocab): - raise ValueError("vocab is not an instance of cde.Vocab.") + if unknown_token is not None: + type_check(unknown_token, (str,), "unknown_token") - kwargs["vocab"] = vocab - kwargs["unknown"] = unknown - return method(self, **kwargs) + type_check(vocab, (cde.Vocab,), "vocab is not an instance of cde.Vocab.") + + return method(self, *args, **kwargs) return new_method def check_from_file(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): - file_path, delimiter, vocab_size, special_tokens, special_first = (list(args) + 5 * [None])[:5] - if "file_path" in kwargs: - file_path = kwargs.get("file_path") - if "delimiter" in kwargs: - delimiter = kwargs.get("delimiter") - if "vocab_size" in kwargs: - vocab_size = kwargs.get("vocab_size") - if "special_tokens" in kwargs: - special_tokens = kwargs.get("special_tokens") - if "special_first" in kwargs: - special_first = kwargs.get("special_first") - - if not isinstance(file_path, str): - raise ValueError("file_path needs to be str.") - - if delimiter is not None: - if not isinstance(delimiter, str): - raise ValueError("delimiter needs to be str.") - else: - delimiter = "" + [file_path, delimiter, vocab_size, special_tokens, special_first], _ = parse_user_args(method, *args, + **kwargs) + if special_tokens is not None: + check_unique_list_of_words(special_tokens, "special_tokens") + type_check_list([file_path, delimiter], (str,), ["file_path", "delimiter"]) if vocab_size is not None: - if not (isinstance(vocab_size, int) and vocab_size > 0): - raise ValueError("vocab size needs to be a positive integer.") - else: - vocab_size = -1 - - if special_first is None: - special_first = True - - if not isinstance(special_first, bool): - raise ValueError("special_first needs to be a boolean value") - - if special_tokens is None: - special_tokens = [] + check_value(vocab_size, (-1, INT32_MAX), "vocab_size") + type_check(special_first, (bool,), special_first) - check_unique_list_of_words(special_tokens, "special_tokens") - - kwargs["file_path"] = file_path - kwargs["delimiter"] = delimiter - kwargs["vocab_size"] = vocab_size - kwargs["special_tokens"] = special_tokens - kwargs["special_first"] = special_first - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method def check_from_list(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): - word_list, special_tokens, special_first = (list(args) + 3 * [None])[:3] - if "word_list" in kwargs: - word_list = kwargs.get("word_list") - if "special_tokens" in kwargs: - special_tokens = kwargs.get("special_tokens") - if "special_first" in kwargs: - special_first = kwargs.get("special_first") - if special_tokens is None: - special_tokens = [] - word_set = check_unique_list_of_words(word_list, "word_list") - token_set = check_unique_list_of_words(special_tokens, "special_tokens") + [word_list, special_tokens, special_first], _ = parse_user_args(method, *args, **kwargs) - intersect = word_set.intersection(token_set) + word_set = check_unique_list_of_words(word_list, "word_list") + if special_tokens is not None: + token_set = check_unique_list_of_words(special_tokens, "special_tokens") - if intersect != set(): - raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") + intersect = word_set.intersection(token_set) - if special_first is None: - special_first = True + if intersect != set(): + raise ValueError("special_tokens and word_list contain duplicate word :" + str(intersect) + ".") - if not isinstance(special_first, bool): - raise ValueError("special_first needs to be a boolean value.") + type_check(special_first, (bool,), "special_first") - kwargs["word_list"] = word_list - kwargs["special_tokens"] = special_tokens - kwargs["special_first"] = special_first - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method def check_from_dict(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): - word_dict, = (list(args) + [None])[:1] - if "word_dict" in kwargs: - word_dict = kwargs.get("word_dict") - if not isinstance(word_dict, dict): - raise ValueError("word_dict needs to be a list of word,id pairs.") + [word_dict], _ = parse_user_args(method, *args, **kwargs) + + type_check(word_dict, (dict,), "word_dict") + for word, word_id in word_dict.items(): - if not isinstance(word, str): - raise ValueError("Each word in word_dict needs to be type string.") - if not (isinstance(word_id, int) and word_id >= 0): - raise ValueError("Each word id needs to be positive integer.") - kwargs["word_dict"] = word_dict - return method(self, **kwargs) + type_check(word, (str,), "word") + type_check(word_id, (int,), "word_id") + check_value(word_id, (0, INT32_MAX), "word_id") + return method(self, *args, **kwargs) return new_method def check_jieba_init(method): - """Wrapper method to check the parameters of jieba add word.""" + """Wrapper method to check the parameters of jieba init.""" @wraps(method) def new_method(self, *args, **kwargs): - hmm_path, mp_path, model = (list(args) + 3 * [None])[:3] + [hmm_path, mp_path, _, with_offsets], _ = parse_user_args(method, *args, **kwargs) - if "hmm_path" in kwargs: - hmm_path = kwargs.get("hmm_path") - if "mp_path" in kwargs: - mp_path = kwargs.get("mp_path") if hmm_path is None: - raise ValueError( - "The dict of HMMSegment in cppjieba is not provided.") - kwargs["hmm_path"] = hmm_path + raise ValueError("The dict of HMMSegment in cppjieba is not provided.") + if not isinstance(hmm_path, str): + raise TypeError("Wrong input type for hmm_path, should be string.") if mp_path is None: - raise ValueError( - "The dict of MPSegment in cppjieba is not provided.") - kwargs["mp_path"] = mp_path - if model is not None: - kwargs["model"] = model - return method(self, **kwargs) + raise ValueError("The dict of MPSegment in cppjieba is not provided.") + if not isinstance(mp_path, str): + raise TypeError("Wrong input type for mp_path, should be string.") + if not isinstance(with_offsets, bool): + raise TypeError("Wrong input type for with_offsets, should be boolean.") + return method(self, *args, **kwargs) return new_method @@ -205,19 +143,12 @@ def check_jieba_add_word(method): @wraps(method) def new_method(self, *args, **kwargs): - word, freq = (list(args) + 2 * [None])[:2] - - if "word" in kwargs: - word = kwargs.get("word") - if "freq" in kwargs: - freq = kwargs.get("freq") + [word, freq], _ = parse_user_args(method, *args, **kwargs) if word is None: raise ValueError("word is not provided.") - kwargs["word"] = word if freq is not None: check_uint32(freq) - kwargs["freq"] = freq - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -227,104 +158,183 @@ def check_jieba_add_dict(method): @wraps(method) def new_method(self, *args, **kwargs): - user_dict = (list(args) + [None])[0] - if "user_dict" in kwargs: - user_dict = kwargs.get("user_dict") - if user_dict is None: - raise ValueError("user_dict is not provided.") - kwargs["user_dict"] = user_dict - return method(self, **kwargs) + parse_user_args(method, *args, **kwargs) + return method(self, *args, **kwargs) return new_method -def check_from_dataset(method): - """A wrapper that wrap a parameter checker to the original function.""" +def check_with_offsets(method): + """Wrapper method to check if with_offsets is the only one parameter.""" @wraps(method) def new_method(self, *args, **kwargs): + [with_offsets], _ = parse_user_args(method, *args, **kwargs) + if not isinstance(with_offsets, bool): + raise TypeError("Wrong input type for with_offsets, should be boolean.") + return method(self, *args, **kwargs) - dataset, columns, freq_range, top_k, special_tokens, special_first = (list(args) + 6 * [None])[:6] - if "dataset" in kwargs: - dataset = kwargs.get("dataset") - if "columns" in kwargs: - columns = kwargs.get("columns") - if "freq_range" in kwargs: - freq_range = kwargs.get("freq_range") - if "top_k" in kwargs: - top_k = kwargs.get("top_k") - if "special_tokens" in kwargs: - special_tokens = kwargs.get("special_tokens") - if "special_first" in kwargs: - special_first = kwargs.get("special_first") + return new_method - if columns is None: - columns = [] - if not isinstance(columns, list): - columns = [columns] +def check_unicode_script_tokenizer(method): + """Wrapper method to check the parameter of UnicodeScriptTokenizer.""" - for column in columns: - if not isinstance(column, str): - raise ValueError("columns need to be a list of strings.") + @wraps(method) + def new_method(self, *args, **kwargs): + [keep_whitespace, with_offsets], _ = parse_user_args(method, *args, **kwargs) + if not isinstance(keep_whitespace, bool): + raise TypeError("Wrong input type for keep_whitespace, should be boolean.") + if not isinstance(with_offsets, bool): + raise TypeError("Wrong input type for with_offsets, should be boolean.") + return method(self, *args, **kwargs) - if freq_range is None: - freq_range = (None, None) + return new_method - if not isinstance(freq_range, tuple) or len(freq_range) != 2: - raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") - for num in freq_range: - if num is not None and (not isinstance(num, int)): - raise ValueError("freq_range needs to be either None or a tuple of 2 integers or an int and a None.") +def check_wordpiece_tokenizer(method): + """Wrapper method to check the parameter of WordpieceTokenizer.""" - if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): - if freq_range[0] > freq_range[1] or freq_range[0] < 0: - raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") + @wraps(method) + def new_method(self, *args, **kwargs): + [vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets], _ = \ + parse_user_args(method, *args, **kwargs) + if vocab is None: + raise ValueError("vocab is not provided.") + if not isinstance(vocab, cde.Vocab): + raise TypeError("Wrong input type for vocab, should be Vocab object.") + if not isinstance(suffix_indicator, str): + raise TypeError("Wrong input type for suffix_indicator, should be string.") + if not isinstance(unknown_token, str): + raise TypeError("Wrong input type for unknown_token, should be string.") + if not isinstance(with_offsets, bool): + raise TypeError("Wrong input type for with_offsets, should be boolean.") + check_uint32(max_bytes_per_token) + return method(self, *args, **kwargs) - if top_k is not None and (not isinstance(top_k, int)): - raise ValueError("top_k needs to be a positive integer.") + return new_method - if isinstance(top_k, int) and top_k <= 0: - raise ValueError("top_k needs to be a positive integer.") - if special_first is None: - special_first = True +def check_regex_tokenizer(method): + """Wrapper method to check the parameter of RegexTokenizer.""" - if special_tokens is None: - special_tokens = [] + @wraps(method) + def new_method(self, *args, **kwargs): + [delim_pattern, keep_delim_pattern, with_offsets], _ = parse_user_args(method, *args, **kwargs) + if delim_pattern is None: + raise ValueError("delim_pattern is not provided.") + if not isinstance(delim_pattern, str): + raise TypeError("Wrong input type for delim_pattern, should be string.") + if not isinstance(keep_delim_pattern, str): + raise TypeError("Wrong input type for keep_delim_pattern, should be string.") + if not isinstance(with_offsets, bool): + raise TypeError("Wrong input type for with_offsets, should be boolean.") + return method(self, *args, **kwargs) - if not isinstance(special_first, bool): - raise ValueError("special_first needs to be a boolean value.") + return new_method - check_unique_list_of_words(special_tokens, "special_tokens") - kwargs["dataset"] = dataset - kwargs["columns"] = columns - kwargs["freq_range"] = freq_range - kwargs["top_k"] = top_k - kwargs["special_tokens"] = special_tokens - kwargs["special_first"] = special_first +def check_basic_tokenizer(method): + """Wrapper method to check the parameter of RegexTokenizer.""" - return method(self, **kwargs) + @wraps(method) + def new_method(self, *args, **kwargs): + [lower_case, keep_whitespace, _, preserve_unused, with_offsets], _ = \ + parse_user_args(method, *args, **kwargs) + if not isinstance(lower_case, bool): + raise TypeError("Wrong input type for lower_case, should be boolean.") + if not isinstance(keep_whitespace, bool): + raise TypeError("Wrong input type for keep_whitespace, should be boolean.") + if not isinstance(preserve_unused, bool): + raise TypeError("Wrong input type for preserve_unused_token, should be boolean.") + if not isinstance(with_offsets, bool): + raise TypeError("Wrong input type for with_offsets, should be boolean.") + return method(self, *args, **kwargs) + + return new_method + + +def check_bert_tokenizer(method): + """Wrapper method to check the parameter of BertTokenizer.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [vocab, suffix_indicator, max_bytes_per_token, unknown_token, lower_case, keep_whitespace, _, + preserve_unused_token, with_offsets], _ = parse_user_args(method, *args, **kwargs) + if vocab is None: + raise ValueError("vacab is not provided.") + if not isinstance(vocab, cde.Vocab): + raise TypeError("Wrong input type for vocab, should be Vocab object.") + if not isinstance(suffix_indicator, str): + raise TypeError("Wrong input type for suffix_indicator, should be string.") + if not isinstance(max_bytes_per_token, int): + raise TypeError("Wrong input type for max_bytes_per_token, should be int.") + check_uint32(max_bytes_per_token) + + if not isinstance(unknown_token, str): + raise TypeError("Wrong input type for unknown_token, should be string.") + if not isinstance(lower_case, bool): + raise TypeError("Wrong input type for lower_case, should be boolean.") + if not isinstance(keep_whitespace, bool): + raise TypeError("Wrong input type for keep_whitespace, should be boolean.") + if not isinstance(preserve_unused_token, bool): + raise TypeError("Wrong input type for preserve_unused_token, should be boolean.") + if not isinstance(with_offsets, bool): + raise TypeError("Wrong input type for with_offsets, should be boolean.") + return method(self, *args, **kwargs) + + return new_method + + +def check_from_dataset(method): + """A wrapper that wraps a parameter checker to the original function.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + + [_, columns, freq_range, top_k, special_tokens, special_first], _ = parse_user_args(method, *args, + **kwargs) + if columns is not None: + if not isinstance(columns, list): + columns = [columns] + col_names = ["col_{0}".format(i) for i in range(len(columns))] + type_check_list(columns, (str,), col_names) + + if freq_range is not None: + type_check(freq_range, (tuple,), "freq_range") + + if len(freq_range) != 2: + raise ValueError("freq_range needs to be a tuple of 2 integers or an int and a None.") + + for num in freq_range: + if num is not None and (not isinstance(num, int)): + raise ValueError( + "freq_range needs to be either None or a tuple of 2 integers or an int and a None.") + + if isinstance(freq_range[0], int) and isinstance(freq_range[1], int): + if freq_range[0] > freq_range[1] or freq_range[0] < 0: + raise ValueError("frequency range [a,b] should be 0 <= a <= b (a,b are inclusive).") + + type_check(top_k, (int, type(None)), "top_k") + + if isinstance(top_k, int): + check_positive(top_k, "top_k") + type_check(special_first, (bool,), "special_first") + + if special_tokens is not None: + check_unique_list_of_words(special_tokens, "special_tokens") + + return method(self, *args, **kwargs) return new_method def check_ngram(method): - """A wrapper that wrap a parameter checker to the original function.""" + """A wrapper that wraps a parameter checker to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): - n, left_pad, right_pad, separator = (list(args) + 4 * [None])[:4] - if "n" in kwargs: - n = kwargs.get("n") - if "left_pad" in kwargs: - left_pad = kwargs.get("left_pad") - if "right_pad" in kwargs: - right_pad = kwargs.get("right_pad") - if "separator" in kwargs: - separator = kwargs.get("separator") + [n, left_pad, right_pad, separator], _ = parse_user_args(method, *args, **kwargs) if isinstance(n, int): n = [n] @@ -332,15 +342,9 @@ def check_ngram(method): if not (isinstance(n, list) and n != []): raise ValueError("n needs to be a non-empty list of positive integers.") - for gram in n: - if not (isinstance(gram, int) and gram > 0): - raise ValueError("n in ngram needs to be a positive number.") - - if left_pad is None: - left_pad = ("", 0) - - if right_pad is None: - right_pad = ("", 0) + for i, gram in enumerate(n): + type_check(gram, (int,), "gram[{0}]".format(i)) + check_positive(gram, "gram_{}".format(i)) if not (isinstance(left_pad, tuple) and len(left_pad) == 2 and isinstance(left_pad[0], str) and isinstance( left_pad[1], int)): @@ -353,11 +357,7 @@ def check_ngram(method): if not (left_pad[1] >= 0 and right_pad[1] >= 0): raise ValueError("padding width need to be positive numbers.") - if separator is None: - separator = " " - - if not isinstance(separator, str): - raise ValueError("separator needs to be a string.") + type_check(separator, (str,), "separator") kwargs["n"] = n kwargs["left_pad"] = left_pad @@ -374,16 +374,8 @@ def check_pair_truncate(method): @wraps(method) def new_method(self, *args, **kwargs): - max_length = (list(args) + [None])[0] - if "max_length" in kwargs: - max_length = kwargs.get("max_length") - if max_length is None: - raise ValueError("max_length is not provided.") - - check_pos_int64(max_length) - kwargs["max_length"] = max_length - - return method(self, **kwargs) + parse_user_args(method, *args, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -393,22 +385,13 @@ def check_to_number(method): @wraps(method) def new_method(self, *args, **kwargs): - data_type = (list(args) + [None])[0] - if "data_type" in kwargs: - data_type = kwargs.get("data_type") - - if data_type is None: - raise ValueError("data_type is a mandatory parameter but was not provided.") - - if not isinstance(data_type, typing.Type): - raise TypeError("data_type is not a MindSpore data type.") + [data_type], _ = parse_user_args(method, *args, **kwargs) + type_check(data_type, (typing.Type,), "data_type") if data_type not in mstype.number_type: raise TypeError("data_type is not numeric data type.") - kwargs["data_type"] = data_type - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -418,18 +401,11 @@ def check_python_tokenizer(method): @wraps(method) def new_method(self, *args, **kwargs): - tokenizer = (list(args) + [None])[0] - if "tokenizer" in kwargs: - tokenizer = kwargs.get("tokenizer") - - if tokenizer is None: - raise ValueError("tokenizer is a mandatory parameter.") + [tokenizer], _ = parse_user_args(method, *args, **kwargs) if not callable(tokenizer): raise TypeError("tokenizer is not a callable python function") - kwargs["tokenizer"] = tokenizer - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py index 48e986202c..62496822e5 100644 --- a/mindspore/dataset/transforms/c_transforms.py +++ b/mindspore/dataset/transforms/c_transforms.py @@ -197,7 +197,7 @@ class PadEnd(cde.PadEndOp): class Concatenate(cde.ConcatenateOp): """ - Tensor operation to prepend and append to a tensor. + Tensor operation that concatenates all columns into a single tensor. Args: axis (int, optional): axis to concatenate the tensors along (Default=0). diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py index 6b5760e0c5..9fe0fa5f10 100644 --- a/mindspore/dataset/transforms/validators.py +++ b/mindspore/dataset/transforms/validators.py @@ -18,6 +18,7 @@ from functools import wraps import numpy as np from mindspore._c_expression import typing +from ..core.validator_helpers import parse_user_args, type_check, check_pos_int64, check_value, check_positive # POS_INT_MIN is used to limit values from starting from 0 POS_INT_MIN = 1 @@ -37,106 +38,33 @@ DOUBLE_MAX_INTEGER = 9007199254740992 DOUBLE_MIN_INTEGER = -9007199254740992 -def check_type(value, valid_type): - if not isinstance(value, valid_type): - raise ValueError("Wrong input type") - - -def check_value(value, valid_range): - if value < valid_range[0] or value > valid_range[1]: - raise ValueError("Input is not within the required range") - - -def check_range(values, valid_range): - if not valid_range[0] <= values[0] <= values[1] <= valid_range[1]: - raise ValueError("Input range is not valid") - - -def check_positive(value): - if value <= 0: - raise ValueError("Input must greater than 0") - - -def check_positive_float(value, valid_max=None): - if value <= 0 or not isinstance(value, float) or (valid_max is not None and value > valid_max): - raise ValueError("Input need to be a valid positive float.") - - -def check_bool(value): - if not isinstance(value, bool): - raise ValueError("Value needs to be a boolean.") - - -def check_2tuple(value): - if not (isinstance(value, tuple) and len(value) == 2): - raise ValueError("Value needs to be a 2-tuple.") - - -def check_list(value): - if not isinstance(value, list): - raise ValueError("The input needs to be a list.") - - -def check_uint8(value): - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [UINT8_MIN, UINT8_MAX]) - - -def check_uint32(value): - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [UINT32_MIN, UINT32_MAX]) - - -def check_pos_int32(value): - """Checks for int values starting from 1""" - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [POS_INT_MIN, INT32_MAX]) - - -def check_uint64(value): - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [UINT64_MIN, UINT64_MAX]) - - -def check_pos_int64(value): - if not isinstance(value, int): - raise ValueError("The input needs to be a integer") - check_value(value, [UINT64_MIN, INT64_MAX]) - +def check_fill_value(method): + """Wrapper method to check the parameters of fill_value.""" -def check_pos_float32(value): - check_value(value, [UINT32_MIN, FLOAT_MAX_INTEGER]) + @wraps(method) + def new_method(self, *args, **kwargs): + [fill_value], _ = parse_user_args(method, *args, **kwargs) + type_check(fill_value, (str, float, bool, int, bytes), "fill_value") + return method(self, *args, **kwargs) -def check_pos_float64(value): - check_value(value, [UINT64_MIN, DOUBLE_MAX_INTEGER]) + return new_method def check_one_hot_op(method): - """Wrapper method to check the parameters of one hot op.""" + """Wrapper method to check the parameters of one_hot_op.""" @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - num_classes, smoothing_rate = args - if "num_classes" in kwargs: - num_classes = kwargs.get("num_classes") - if "smoothing_rate" in kwargs: - smoothing_rate = kwargs.get("smoothing_rate") - - if num_classes is None: - raise ValueError("num_classes") - check_pos_int32(num_classes) - kwargs["num_classes"] = num_classes + [num_classes, smoothing_rate], _ = parse_user_args(method, *args, **kwargs) + + type_check(num_classes, (int,), "num_classes") + check_positive(num_classes) + if smoothing_rate is not None: - check_value(smoothing_rate, [0., 1.]) - kwargs["smoothing_rate"] = smoothing_rate + check_value(smoothing_rate, [0., 1.], "smoothing_rate") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -146,35 +74,12 @@ def check_num_classes(method): @wraps(method) def new_method(self, *args, **kwargs): - num_classes = (list(args) + [None])[0] - if "num_classes" in kwargs: - num_classes = kwargs.get("num_classes") - if num_classes is None: - raise ValueError("num_classes is not provided.") - - check_pos_int32(num_classes) - kwargs["num_classes"] = num_classes - - return method(self, **kwargs) - - return new_method - + [num_classes], _ = parse_user_args(method, *args, **kwargs) -def check_fill_value(method): - """Wrapper method to check the parameters of fill value.""" - - @wraps(method) - def new_method(self, *args, **kwargs): - fill_value = (list(args) + [None])[0] - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - if fill_value is None: - raise ValueError("fill_value is not provided.") - if not isinstance(fill_value, (str, float, bool, int, bytes)): - raise TypeError("fill_value must be either a primitive python str, float, bool, bytes or int") - kwargs["fill_value"] = fill_value + type_check(num_classes, (int,), "num_classes") + check_positive(num_classes) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -184,17 +89,11 @@ def check_de_type(method): @wraps(method) def new_method(self, *args, **kwargs): - data_type = (list(args) + [None])[0] - if "data_type" in kwargs: - data_type = kwargs.get("data_type") + [data_type], _ = parse_user_args(method, *args, **kwargs) - if data_type is None: - raise ValueError("data_type is not provided.") - if not isinstance(data_type, typing.Type): - raise TypeError("data_type is not a MindSpore data type.") - kwargs["data_type"] = data_type + type_check(data_type, (typing.Type,), "data_type") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -204,13 +103,11 @@ def check_slice_op(method): @wraps(method) def new_method(self, *args): - for i, arg in enumerate(args): - if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)): - raise TypeError("Indexing of dim " + str(i) + "is not of valid type") + for _, arg in enumerate(args): + type_check(arg, (int, slice, list, type(None), type(Ellipsis)), "arg") if isinstance(arg, list): for a in arg: - if not isinstance(a, int): - raise TypeError("Index " + a + " is not an int") + type_check(a, (int,), "a") return method(self, *args) return new_method @@ -221,36 +118,14 @@ def check_mask_op(method): @wraps(method) def new_method(self, *args, **kwargs): - operator, constant, dtype = (list(args) + 3 * [None])[:3] - if "operator" in kwargs: - operator = kwargs.get("operator") - if "constant" in kwargs: - constant = kwargs.get("constant") - if "dtype" in kwargs: - dtype = kwargs.get("dtype") - - if operator is None: - raise ValueError("operator is not provided.") - - if constant is None: - raise ValueError("constant is not provided.") + [operator, constant, dtype], _ = parse_user_args(method, *args, **kwargs) from .c_transforms import Relational - if not isinstance(operator, Relational): - raise TypeError("operator is not a Relational operator enum.") + type_check(operator, (Relational,), "operator") + type_check(constant, (str, float, bool, int, bytes), "constant") + type_check(dtype, (typing.Type,), "dtype") - if not isinstance(constant, (str, float, bool, int, bytes)): - raise TypeError("constant must be either a primitive python str, float, bool, bytes or int") - - if dtype is not None: - if not isinstance(dtype, typing.Type): - raise TypeError("dtype is not a MindSpore data type.") - kwargs["dtype"] = dtype - - kwargs["operator"] = operator - kwargs["constant"] = constant - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -260,22 +135,12 @@ def check_pad_end(method): @wraps(method) def new_method(self, *args, **kwargs): - pad_shape, pad_value = (list(args) + 2 * [None])[:2] - if "pad_shape" in kwargs: - pad_shape = kwargs.get("pad_shape") - if "pad_value" in kwargs: - pad_value = kwargs.get("pad_value") - if pad_shape is None: - raise ValueError("pad_shape is not provided.") + [pad_shape, pad_value], _ = parse_user_args(method, *args, **kwargs) if pad_value is not None: - if not isinstance(pad_value, (str, float, bool, int, bytes)): - raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes") - kwargs["pad_value"] = pad_value - - if not isinstance(pad_shape, list): - raise TypeError("pad_shape must be a list") + type_check(pad_value, (str, float, bool, int, bytes), "pad_value") + type_check(pad_shape, (list,), "pad_end") for dim in pad_shape: if dim is not None: @@ -284,9 +149,7 @@ def check_pad_end(method): else: raise TypeError("a value in the list is not an integer.") - kwargs["pad_shape"] = pad_shape - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -296,31 +159,24 @@ def check_concat_type(method): @wraps(method) def new_method(self, *args, **kwargs): - axis, prepend, append = (list(args) + 3 * [None])[:3] - if "prepend" in kwargs: - prepend = kwargs.get("prepend") - if "append" in kwargs: - append = kwargs.get("append") - if "axis" in kwargs: - axis = kwargs.get("axis") + + [axis, prepend, append], _ = parse_user_args(method, *args, **kwargs) if axis is not None: - if not isinstance(axis, int): - raise TypeError("axis type is not valid, must be an integer.") + type_check(axis, (int,), "axis") if axis not in (0, -1): raise ValueError("only 1D concatenation supported.") - kwargs["axis"] = axis if prepend is not None: - if not isinstance(prepend, (type(None), np.ndarray)): - raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.") - kwargs["prepend"] = prepend + type_check(prepend, (np.ndarray,), "prepend") + if len(prepend.shape) != 1: + raise ValueError("can only prepend 1D arrays.") if append is not None: - if not isinstance(append, (type(None), np.ndarray)): - raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.") - kwargs["append"] = append + type_check(append, (np.ndarray,), "append") + if len(append.shape) != 1: + raise ValueError("can only append 1D arrays.") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 43ac037541..8e3b7c7214 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -40,12 +40,14 @@ Examples: >>> dataset = dataset.map(input_columns="image", operations=transforms_list) >>> dataset = dataset.map(input_columns="label", operations=onehot_op) """ +import numbers import mindspore._c_dataengine as cde from .utils import Inter, Border from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ - check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ - check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp + check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \ + check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \ + FLOAT_MAX_INTEGER DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, @@ -57,6 +59,18 @@ DE_C_BORDER_TYPE = {Border.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT, Border.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC} +def parse_padding(padding): + if isinstance(padding, numbers.Number): + padding = [padding] * 4 + if len(padding) == 2: + left = right = padding[0] + top = bottom = padding[1] + padding = (left, top, right, bottom,) + if isinstance(padding, list): + padding = tuple(padding) + return padding + + class Decode(cde.DecodeOp): """ Decode the input image in RGB mode. @@ -136,16 +150,22 @@ class RandomCrop(cde.RandomCropOp): @check_random_crop def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): - self.size = size - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill_value = fill_value - self.padding_mode = padding_mode.value + if isinstance(size, int): + size = (size, size) if padding is None: padding = (0, 0, 0, 0) + else: + padding = parse_padding(padding) if isinstance(fill_value, int): # temporary fix fill_value = tuple([fill_value] * 3) border_type = DE_C_BORDER_TYPE[padding_mode] + + self.size = size + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill_value = fill_value + self.padding_mode = padding_mode.value + super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) @@ -184,16 +204,23 @@ class RandomCropWithBBox(cde.RandomCropWithBBoxOp): @check_random_crop def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): - self.size = size - self.padding = padding - self.pad_if_needed = pad_if_needed - self.fill_value = fill_value - self.padding_mode = padding_mode.value + if isinstance(size, int): + size = (size, size) if padding is None: padding = (0, 0, 0, 0) + else: + padding = parse_padding(padding) + if isinstance(fill_value, int): # temporary fix fill_value = tuple([fill_value] * 3) border_type = DE_C_BORDER_TYPE[padding_mode] + + self.size = size + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill_value = fill_value + self.padding_mode = padding_mode.value + super().__init__(*size, *padding, border_type, pad_if_needed, *fill_value) @@ -292,6 +319,8 @@ class Resize(cde.ResizeOp): @check_resize_interpolation def __init__(self, size, interpolation=Inter.LINEAR): + if isinstance(size, int): + size = (size, size) self.size = size self.interpolation = interpolation interpoltn = DE_C_INTER_MODE[interpolation] @@ -359,6 +388,8 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp): @check_random_resize_crop def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Inter.BILINEAR, max_attempts=10): + if isinstance(size, int): + size = (size, size) self.size = size self.scale = scale self.ratio = ratio @@ -396,6 +427,8 @@ class RandomResizedCrop(cde.RandomCropAndResizeOp): @check_random_resize_crop def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Inter.BILINEAR, max_attempts=10): + if isinstance(size, int): + size = (size, size) self.size = size self.scale = scale self.ratio = ratio @@ -417,6 +450,8 @@ class CenterCrop(cde.CenterCropOp): @check_crop def __init__(self, size): + if isinstance(size, int): + size = (size, size) self.size = size super().__init__(*size) @@ -442,12 +477,26 @@ class RandomColorAdjust(cde.RandomColorAdjustOp): @check_random_color_adjust def __init__(self, brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)): + brightness = self.expand_values(brightness) + contrast = self.expand_values(contrast) + saturation = self.expand_values(saturation) + hue = self.expand_values(hue, center=0, bound=(-0.5, 0.5), non_negative=False) + self.brightness = brightness self.contrast = contrast self.saturation = saturation self.hue = hue + super().__init__(*brightness, *contrast, *saturation, *hue) + def expand_values(self, value, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True): + if isinstance(value, numbers.Number): + value = [center - value, center + value] + if non_negative: + value[0] = max(0, value[0]) + check_range(value, bound) + return (value[0], value[1]) + class RandomRotation(cde.RandomRotationOp): """ @@ -485,6 +534,8 @@ class RandomRotation(cde.RandomRotationOp): self.expand = expand self.center = center self.fill_value = fill_value + if isinstance(degrees, numbers.Number): + degrees = (-degrees, degrees) if center is None: center = (-1, -1) if isinstance(fill_value, int): # temporary fix @@ -584,6 +635,8 @@ class RandomCropDecodeResize(cde.RandomCropDecodeResizeOp): @check_random_resize_crop def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Inter.BILINEAR, max_attempts=10): + if isinstance(size, int): + size = (size, size) self.size = size self.scale = scale self.ratio = ratio @@ -623,12 +676,14 @@ class Pad(cde.PadOp): @check_pad def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): - self.padding = padding - self.fill_value = fill_value - self.padding_mode = padding_mode + padding = parse_padding(padding) if isinstance(fill_value, int): # temporary fix fill_value = tuple([fill_value] * 3) padding_mode = DE_C_BORDER_TYPE[padding_mode] + + self.padding = padding + self.fill_value = fill_value + self.padding_mode = padding_mode super().__init__(*padding, padding_mode, *fill_value) diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index b252c3434b..3bfd6b0644 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -28,6 +28,7 @@ import numpy as np from PIL import Image from . import py_transforms_util as util +from .c_transforms import parse_padding from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \ check_transforms_list, check_random_apply, check_ten_crop, check_num_channels, check_pad, \ @@ -295,6 +296,10 @@ class RandomCrop: @check_random_crop def __init__(self, size, padding=None, pad_if_needed=False, fill_value=0, padding_mode=Border.CONSTANT): + if padding is None: + padding = (0, 0, 0, 0) + else: + padding = parse_padding(padding) self.size = size self.padding = padding self.pad_if_needed = pad_if_needed @@ -753,6 +758,8 @@ class TenCrop: @check_ten_crop def __init__(self, size, use_vertical_flip=False): + if isinstance(size, int): + size = (size, size) self.size = size self.use_vertical_flip = use_vertical_flip @@ -877,6 +884,8 @@ class Pad: @check_pad def __init__(self, padding, fill_value=0, padding_mode=Border.CONSTANT): + parse_padding(padding) + self.padding = padding self.fill_value = fill_value self.padding_mode = DE_PY_BORDER_TYPE[padding_mode] @@ -1129,56 +1138,23 @@ class RandomAffine: def __init__(self, degrees, translate=None, scale=None, shear=None, resample=Inter.NEAREST, fill_value=0): # Parameter checking # rotation - if isinstance(degrees, numbers.Number): - if degrees < 0: - raise ValueError("If degrees is a single number, it must be positive.") - self.degrees = (-degrees, degrees) - elif isinstance(degrees, (tuple, list)) and len(degrees) == 2: - self.degrees = degrees - else: - raise TypeError("If degrees is a list or tuple, it must be of length 2.") - - # translation - if translate is not None: - if isinstance(translate, (tuple, list)) and len(translate) == 2: - for t in translate: - if t < 0.0 or t > 1.0: - raise ValueError("translation values should be between 0 and 1") - else: - raise TypeError("translate should be a list or tuple of length 2.") - self.translate = translate - - # scale - if scale is not None: - if isinstance(scale, (tuple, list)) and len(scale) == 2: - for s in scale: - if s <= 0: - raise ValueError("scale values should be positive") - else: - raise TypeError("scale should be a list or tuple of length 2.") - self.scale_ranges = scale - - # shear if shear is not None: if isinstance(shear, numbers.Number): - if shear < 0: - raise ValueError("If shear is a single number, it must be positive.") - self.shear = (-1 * shear, shear) - elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4): - # X-Axis shear with [min, max] + shear = (-1 * shear, shear) + else: if len(shear) == 2: - self.shear = [shear[0], shear[1], 0., 0.] + shear = [shear[0], shear[1], 0., 0.] elif len(shear) == 4: - self.shear = [s for s in shear] - else: - raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.") - else: - self.shear = shear + shear = [s for s in shear] - # resample - self.resample = DE_PY_INTER_MODE[resample] + if isinstance(degrees, numbers.Number): + degrees = (-degrees, degrees) - # fill_value + self.degrees = degrees + self.translate = translate + self.scale_ranges = scale + self.shear = shear + self.resample = DE_PY_INTER_MODE[resample] self.fill_value = fill_value def __call__(self, img): diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index b49116349b..4cb6613359 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -16,47 +16,35 @@ """ import numbers from functools import wraps - +import numpy as np from mindspore._c_dataengine import TensorOp from .utils import Inter, Border -from ...transforms.validators import check_pos_int32, check_pos_float32, check_value, check_uint8, FLOAT_MAX_INTEGER, \ - check_bool, check_2tuple, check_range, check_list, check_type, check_positive, INT32_MAX - - -def check_inter_mode(mode): - if not isinstance(mode, Inter): - raise ValueError("Invalid interpolation mode.") - - -def check_border_type(mode): - if not isinstance(mode, Border): - raise ValueError("Invalid padding mode.") +from ...core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \ + check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list def check_crop_size(size): """Wrapper method to check the parameters of crop size.""" + type_check(size, (int, list, tuple), "size") if isinstance(size, int): - size = (size, size) + check_value(size, (1, FLOAT_MAX_INTEGER)) elif isinstance(size, (tuple, list)) and len(size) == 2: - size = size + for value in size: + check_value(value, (1, FLOAT_MAX_INTEGER)) else: raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.") - for value in size: - check_pos_int32(value) - return size def check_resize_size(size): """Wrapper method to check the parameters of resize.""" if isinstance(size, int): - check_pos_int32(size) + check_value(size, (1, FLOAT_MAX_INTEGER)) elif isinstance(size, (tuple, list)) and len(size) == 2: - for value in size: - check_value(value, (1, INT32_MAX)) + for i, value in enumerate(size): + check_value(value, (1, INT32_MAX), "size at dim {0}".format(i)) else: raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.") - return size def check_normalize_c_param(mean, std): @@ -72,9 +60,9 @@ def check_normalize_py_param(mean, std): if len(mean) != len(std): raise ValueError("Length of mean and std must be equal") for mean_value in mean: - check_value(mean_value, [0., 1.]) + check_value(mean_value, [0., 1.], "mean_value") for std_value in std: - check_value(std_value, [0., 1.]) + check_value(std_value, [0., 1.], "std_value") def check_fill_value(fill_value): @@ -85,66 +73,37 @@ def check_fill_value(fill_value): check_uint8(value) else: raise TypeError("fill_value should be a single integer or a 3-tuple.") - return fill_value def check_padding(padding): """Parsing the padding arguments and check if it is legal.""" - if isinstance(padding, numbers.Number): - top = bottom = left = right = padding - - elif isinstance(padding, (tuple, list)): - if len(padding) == 2: - left = right = padding[0] - top = bottom = padding[1] - elif len(padding) == 4: - left = padding[0] - top = padding[1] - right = padding[2] - bottom = padding[3] - else: + type_check(padding, (tuple, list, numbers.Number), "padding") + if isinstance(padding, (tuple, list)): + if len(padding) not in (2, 4): raise ValueError("The size of the padding list or tuple should be 2 or 4.") - else: - raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.") - if not (isinstance(left, int) and isinstance(top, int) and isinstance(right, int) and isinstance(bottom, int)): - raise TypeError("Padding value should be integer.") - if left < 0 or top < 0 or right < 0 or bottom < 0: - raise ValueError("Padding value could not be negative.") - return left, top, right, bottom + for i, pad_value in enumerate(padding): + type_check(pad_value, (int,), "padding[{}]".format(i)) + check_value(pad_value, (0, INT32_MAX), "pad_value") def check_degrees(degrees): """Check if the degrees is legal.""" + type_check(degrees, (numbers.Number, list, tuple), "degrees") if isinstance(degrees, numbers.Number): - if degrees < 0: - raise ValueError("If degrees is a single number, it cannot be negative.") - degrees = (-degrees, degrees) + check_value(degrees, (0, float("inf")), "degrees") elif isinstance(degrees, (list, tuple)): if len(degrees) != 2: raise TypeError("If degrees is a sequence, the length must be 2.") - else: - raise TypeError("Degrees must be a single non-negative number or a sequence") - return degrees def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True): """Check the parameters in random color adjust operation.""" + type_check(value, (numbers.Number, list, tuple), input_name) if isinstance(value, numbers.Number): if value < 0: raise ValueError("The input value of {} cannot be negative.".format(input_name)) - # convert value into a range - value = [center - value, center + value] - if non_negative: - value[0] = max(0, value[0]) elif isinstance(value, (list, tuple)) and len(value) == 2: - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError("Please check your value range of {} is valid and " - "within the bound {}".format(input_name, bound)) - else: - raise TypeError("Input of {} should be either a single value, or a list/tuple of " - "length 2.".format(input_name)) - factor = (value[0], value[1]) - return factor + check_range(value, bound) def check_erasing_value(value): @@ -155,173 +114,105 @@ def check_erasing_value(value): def check_crop(method): - """A wrapper that wrap a parameter checker to the original function(crop operation).""" + """A wrapper that wraps a parameter checker to the original function(crop operation).""" @wraps(method) def new_method(self, *args, **kwargs): - size = (list(args) + [None])[0] - if "size" in kwargs: - size = kwargs.get("size") - if size is None: - raise ValueError("size is not provided.") - size = check_crop_size(size) - kwargs["size"] = size + [size], _ = parse_user_args(method, *args, **kwargs) + check_crop_size(size) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method def check_resize_interpolation(method): - """A wrapper that wrap a parameter checker to the original function(resize interpolation operation).""" + """A wrapper that wraps a parameter checker to the original function(resize interpolation operation).""" @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - size, interpolation = args - if "size" in kwargs: - size = kwargs.get("size") - if "interpolation" in kwargs: - interpolation = kwargs.get("interpolation") - - if size is None: - raise ValueError("size is not provided.") - size = check_resize_size(size) - kwargs["size"] = size - + [size, interpolation], _ = parse_user_args(method, *args, **kwargs) + check_resize_size(size) if interpolation is not None: - check_inter_mode(interpolation) - kwargs["interpolation"] = interpolation + type_check(interpolation, (Inter,), "interpolation") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method def check_resize(method): - """A wrapper that wrap a parameter checker to the original function(resize operation).""" + """A wrapper that wraps a parameter checker to the original function(resize operation).""" @wraps(method) def new_method(self, *args, **kwargs): - size = (list(args) + [None])[0] - if "size" in kwargs: - size = kwargs.get("size") - - if size is None: - raise ValueError("size is not provided.") - size = check_resize_size(size) - kwargs["size"] = size + [size], _ = parse_user_args(method, *args, **kwargs) + check_resize_size(size) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method def check_random_resize_crop(method): - """A wrapper that wrap a parameter checker to the original function(random resize crop operation).""" + """A wrapper that wraps a parameter checker to the original function(random resize crop operation).""" @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 5 * [None])[:5] - size, scale, ratio, interpolation, max_attempts = args - if "size" in kwargs: - size = kwargs.get("size") - if "scale" in kwargs: - scale = kwargs.get("scale") - if "ratio" in kwargs: - ratio = kwargs.get("ratio") - if "interpolation" in kwargs: - interpolation = kwargs.get("interpolation") - if "max_attempts" in kwargs: - max_attempts = kwargs.get("max_attempts") - - if size is None: - raise ValueError("size is not provided.") - size = check_crop_size(size) - kwargs["size"] = size + [size, scale, ratio, interpolation, max_attempts], _ = parse_user_args(method, *args, **kwargs) + check_crop_size(size) if scale is not None: check_range(scale, [0, FLOAT_MAX_INTEGER]) - kwargs["scale"] = scale if ratio is not None: check_range(ratio, [0, FLOAT_MAX_INTEGER]) - check_positive(ratio[0]) - kwargs["ratio"] = ratio + check_positive(ratio[0], "ratio[0]") if interpolation is not None: - check_inter_mode(interpolation) - kwargs["interpolation"] = interpolation + type_check(interpolation, (Inter,), "interpolation") if max_attempts is not None: - check_pos_int32(max_attempts) - kwargs["max_attempts"] = max_attempts + check_value(max_attempts, (1, FLOAT_MAX_INTEGER)) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method def check_prob(method): - """A wrapper that wrap a parameter checker(check the probability) to the original function.""" + """A wrapper that wraps a parameter checker(check the probability) to the original function.""" @wraps(method) def new_method(self, *args, **kwargs): - prob = (list(args) + [None])[0] - if "prob" in kwargs: - prob = kwargs.get("prob") - if prob is not None: - check_value(prob, [0., 1.]) - kwargs["prob"] = prob + [prob], _ = parse_user_args(method, *args, **kwargs) + type_check(prob, (float, int,), "prob") + check_value(prob, [0., 1.], "prob") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method def check_normalize_c(method): - """A wrapper that wrap a parameter checker to the original function(normalize operation written in C++).""" + """A wrapper that wraps a parameter checker to the original function(normalize operation written in C++).""" @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - mean, std = args - if "mean" in kwargs: - mean = kwargs.get("mean") - if "std" in kwargs: - std = kwargs.get("std") - - if mean is None: - raise ValueError("mean is not provided.") - if std is None: - raise ValueError("std is not provided.") + [mean, std], _ = parse_user_args(method, *args, **kwargs) check_normalize_c_param(mean, std) - kwargs["mean"] = mean - kwargs["std"] = std - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method def check_normalize_py(method): - """A wrapper that wrap a parameter checker to the original function(normalize operation written in Python).""" + """A wrapper that wraps a parameter checker to the original function(normalize operation written in Python).""" @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - mean, std = args - if "mean" in kwargs: - mean = kwargs.get("mean") - if "std" in kwargs: - std = kwargs.get("std") - - if mean is None: - raise ValueError("mean is not provided.") - if std is None: - raise ValueError("std is not provided.") + [mean, std], _ = parse_user_args(method, *args, **kwargs) check_normalize_py_param(mean, std) - kwargs["mean"] = mean - kwargs["std"] = std - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -331,38 +222,17 @@ def check_random_crop(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 5 * [None])[:5] - size, padding, pad_if_needed, fill_value, padding_mode = args - - if "size" in kwargs: - size = kwargs.get("size") - if "padding" in kwargs: - padding = kwargs.get("padding") - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - if "padding_mode" in kwargs: - padding_mode = kwargs.get("padding_mode") - if "pad_if_needed" in kwargs: - pad_if_needed = kwargs.get("pad_if_needed") - - if size is None: - raise ValueError("size is not provided.") - size = check_crop_size(size) - kwargs["size"] = size - + [size, padding, pad_if_needed, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs) + check_crop_size(size) + type_check(pad_if_needed, (bool,), "pad_if_needed") if padding is not None: - padding = check_padding(padding) - kwargs["padding"] = padding + check_padding(padding) if fill_value is not None: - fill_value = check_fill_value(fill_value) - kwargs["fill_value"] = fill_value + check_fill_value(fill_value) if padding_mode is not None: - check_border_type(padding_mode) - kwargs["padding_mode"] = padding_mode - if pad_if_needed is not None: - kwargs["pad_if_needed"] = pad_if_needed + type_check(padding_mode, (Border,), "padding_mode") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -372,27 +242,13 @@ def check_random_color_adjust(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 4 * [None])[:4] - brightness, contrast, saturation, hue = args - if "brightness" in kwargs: - brightness = kwargs.get("brightness") - if "contrast" in kwargs: - contrast = kwargs.get("contrast") - if "saturation" in kwargs: - saturation = kwargs.get("saturation") - if "hue" in kwargs: - hue = kwargs.get("hue") - - if brightness is not None: - kwargs["brightness"] = check_random_color_adjust_param(brightness, "brightness") - if contrast is not None: - kwargs["contrast"] = check_random_color_adjust_param(contrast, "contrast") - if saturation is not None: - kwargs["saturation"] = check_random_color_adjust_param(saturation, "saturation") - if hue is not None: - kwargs["hue"] = check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False) - - return method(self, **kwargs) + [brightness, contrast, saturation, hue], _ = parse_user_args(method, *args, **kwargs) + check_random_color_adjust_param(brightness, "brightness") + check_random_color_adjust_param(contrast, "contrast") + check_random_color_adjust_param(saturation, "saturation") + check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False) + + return method(self, *args, **kwargs) return new_method @@ -402,38 +258,19 @@ def check_random_rotation(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 5 * [None])[:5] - degrees, resample, expand, center, fill_value = args - if "degrees" in kwargs: - degrees = kwargs.get("degrees") - if "resample" in kwargs: - resample = kwargs.get("resample") - if "expand" in kwargs: - expand = kwargs.get("expand") - if "center" in kwargs: - center = kwargs.get("center") - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - - if degrees is None: - raise ValueError("degrees is not provided.") - degrees = check_degrees(degrees) - kwargs["degrees"] = degrees + [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs) + check_degrees(degrees) if resample is not None: - check_inter_mode(resample) - kwargs["resample"] = resample + type_check(resample, (Inter,), "resample") if expand is not None: - check_bool(expand) - kwargs["expand"] = expand + type_check(expand, (bool,), "expand") if center is not None: - check_2tuple(center) - kwargs["center"] = center + check_2tuple(center, "center") if fill_value is not None: - fill_value = check_fill_value(fill_value) - kwargs["fill_value"] = fill_value + check_fill_value(fill_value) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -443,16 +280,11 @@ def check_transforms_list(method): @wraps(method) def new_method(self, *args, **kwargs): - transforms = (list(args) + [None])[0] - if "transforms" in kwargs: - transforms = kwargs.get("transforms") - if transforms is None: - raise ValueError("transforms is not provided.") + [transforms], _ = parse_user_args(method, *args, **kwargs) - check_list(transforms) - kwargs["transforms"] = transforms + type_check(transforms, (list,), "transforms") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -462,21 +294,14 @@ def check_random_apply(method): @wraps(method) def new_method(self, *args, **kwargs): - transforms, prob = (list(args) + 2 * [None])[:2] - if "transforms" in kwargs: - transforms = kwargs.get("transforms") - if transforms is None: - raise ValueError("transforms is not provided.") - check_list(transforms) - kwargs["transforms"] = transforms - - if "prob" in kwargs: - prob = kwargs.get("prob") + [transforms, prob], _ = parse_user_args(method, *args, **kwargs) + type_check(transforms, (list,), "transforms") + if prob is not None: - check_value(prob, [0., 1.]) - kwargs["prob"] = prob + type_check(prob, (float, int,), "prob") + check_value(prob, [0., 1.], "prob") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -486,23 +311,13 @@ def check_ten_crop(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - size, use_vertical_flip = args - if "size" in kwargs: - size = kwargs.get("size") - if "use_vertical_flip" in kwargs: - use_vertical_flip = kwargs.get("use_vertical_flip") - - if size is None: - raise ValueError("size is not provided.") - size = check_crop_size(size) - kwargs["size"] = size + [size, use_vertical_flip], _ = parse_user_args(method, *args, **kwargs) + check_crop_size(size) if use_vertical_flip is not None: - check_bool(use_vertical_flip) - kwargs["use_vertical_flip"] = use_vertical_flip + type_check(use_vertical_flip, (bool,), "use_vertical_flip") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -512,16 +327,13 @@ def check_num_channels(method): @wraps(method) def new_method(self, *args, **kwargs): - num_output_channels = (list(args) + [None])[0] - if "num_output_channels" in kwargs: - num_output_channels = kwargs.get("num_output_channels") + [num_output_channels], _ = parse_user_args(method, *args, **kwargs) if num_output_channels is not None: if num_output_channels not in (1, 3): raise ValueError("Number of channels of the output grayscale image" "should be either 1 or 3. Got {0}".format(num_output_channels)) - kwargs["num_output_channels"] = num_output_channels - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -531,28 +343,12 @@ def check_pad(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 3 * [None])[:3] - padding, fill_value, padding_mode = args - if "padding" in kwargs: - padding = kwargs.get("padding") - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - if "padding_mode" in kwargs: - padding_mode = kwargs.get("padding_mode") - - if padding is None: - raise ValueError("padding is not provided.") - padding = check_padding(padding) - kwargs["padding"] = padding + [padding, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs) + check_padding(padding) + check_fill_value(fill_value) + type_check(padding_mode, (Border,), "padding_mode") - if fill_value is not None: - fill_value = check_fill_value(fill_value) - kwargs["fill_value"] = fill_value - if padding_mode is not None: - check_border_type(padding_mode) - kwargs["padding_mode"] = padding_mode - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -562,26 +358,13 @@ def check_random_perspective(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 3 * [None])[:3] - distortion_scale, prob, interpolation = args - if "distortion_scale" in kwargs: - distortion_scale = kwargs.get("distortion_scale") - if "prob" in kwargs: - prob = kwargs.get("prob") - if "interpolation" in kwargs: - interpolation = kwargs.get("interpolation") - - if distortion_scale is not None: - check_value(distortion_scale, [0., 1.]) - kwargs["distortion_scale"] = distortion_scale - if prob is not None: - check_value(prob, [0., 1.]) - kwargs["prob"] = prob - if interpolation is not None: - check_inter_mode(interpolation) - kwargs["interpolation"] = interpolation + [distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs) - return method(self, **kwargs) + check_value(distortion_scale, [0., 1.], "distortion_scale") + check_value(prob, [0., 1.], "prob") + type_check(interpolation, (Inter,), "interpolation") + + return method(self, *args, **kwargs) return new_method @@ -591,28 +374,13 @@ def check_mix_up(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 3 * [None])[:3] - batch_size, alpha, is_single = args - if "batch_size" in kwargs: - batch_size = kwargs.get("batch_size") - if "alpha" in kwargs: - alpha = kwargs.get("alpha") - if "is_single" in kwargs: - is_single = kwargs.get("is_single") - - if batch_size is None: - raise ValueError("batch_size") - check_pos_int32(batch_size) - kwargs["batch_size"] = batch_size - if alpha is None: - raise ValueError("alpha") - check_positive(alpha) - kwargs["alpha"] = alpha - if is_single is not None: - check_type(is_single, bool) - kwargs["is_single"] = is_single - - return method(self, **kwargs) + [batch_size, alpha, is_single], _ = parse_user_args(method, *args, **kwargs) + + check_value(batch_size, (1, FLOAT_MAX_INTEGER)) + check_positive(alpha, "alpha") + type_check(is_single, (bool,), "is_single") + + return method(self, *args, **kwargs) return new_method @@ -622,41 +390,16 @@ def check_random_erasing(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 6 * [None])[:6] - prob, scale, ratio, value, inplace, max_attempts = args - if "prob" in kwargs: - prob = kwargs.get("prob") - if "scale" in kwargs: - scale = kwargs.get("scale") - if "ratio" in kwargs: - ratio = kwargs.get("ratio") - if "value" in kwargs: - value = kwargs.get("value") - if "inplace" in kwargs: - inplace = kwargs.get("inplace") - if "max_attempts" in kwargs: - max_attempts = kwargs.get("max_attempts") + [prob, scale, ratio, value, inplace, max_attempts], _ = parse_user_args(method, *args, **kwargs) - if prob is not None: - check_value(prob, [0., 1.]) - kwargs["prob"] = prob - if scale is not None: - check_range(scale, [0, FLOAT_MAX_INTEGER]) - kwargs["scale"] = scale - if ratio is not None: - check_range(ratio, [0, FLOAT_MAX_INTEGER]) - kwargs["ratio"] = ratio - if value is not None: - check_erasing_value(value) - kwargs["value"] = value - if inplace is not None: - check_bool(inplace) - kwargs["inplace"] = inplace - if max_attempts is not None: - check_pos_int32(max_attempts) - kwargs["max_attempts"] = max_attempts + check_value(prob, [0., 1.], "prob") + check_range(scale, [0, FLOAT_MAX_INTEGER]) + check_range(ratio, [0, FLOAT_MAX_INTEGER]) + check_erasing_value(value) + type_check(inplace, (bool,), "inplace") + check_value(max_attempts, (1, FLOAT_MAX_INTEGER)) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -666,23 +409,12 @@ def check_cutout(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - length, num_patches = args - if "length" in kwargs: - length = kwargs.get("length") - if "num_patches" in kwargs: - num_patches = kwargs.get("num_patches") - - if length is None: - raise ValueError("length") - check_pos_int32(length) - kwargs["length"] = length + [length, num_patches], _ = parse_user_args(method, *args, **kwargs) - if num_patches is not None: - check_pos_int32(num_patches) - kwargs["num_patches"] = num_patches + check_value(length, (1, FLOAT_MAX_INTEGER)) + check_value(num_patches, (1, FLOAT_MAX_INTEGER)) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -692,17 +424,9 @@ def check_linear_transform(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 2 * [None])[:2] - transformation_matrix, mean_vector = args - if "transformation_matrix" in kwargs: - transformation_matrix = kwargs.get("transformation_matrix") - if "mean_vector" in kwargs: - mean_vector = kwargs.get("mean_vector") - - if transformation_matrix is None: - raise ValueError("transformation_matrix is not provided.") - if mean_vector is None: - raise ValueError("mean_vector is not provided.") + [transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs) + type_check(transformation_matrix, (np.ndarray,), "transformation_matrix") + type_check(mean_vector, (np.ndarray,), "mean_vector") if transformation_matrix.shape[0] != transformation_matrix.shape[1]: raise ValueError("transformation_matrix should be a square matrix. " @@ -711,10 +435,7 @@ def check_linear_transform(method): raise ValueError("mean_vector length {0} should match either one dimension of the square" "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape)) - kwargs["transformation_matrix"] = transformation_matrix - kwargs["mean_vector"] = mean_vector - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -724,67 +445,40 @@ def check_random_affine(method): @wraps(method) def new_method(self, *args, **kwargs): - args = (list(args) + 6 * [None])[:6] - degrees, translate, scale, shear, resample, fill_value = args - if "degrees" in kwargs: - degrees = kwargs.get("degrees") - if "translate" in kwargs: - translate = kwargs.get("translate") - if "scale" in kwargs: - scale = kwargs.get("scale") - if "shear" in kwargs: - shear = kwargs.get("shear") - if "resample" in kwargs: - resample = kwargs.get("resample") - if "fill_value" in kwargs: - fill_value = kwargs.get("fill_value") - - if degrees is None: - raise ValueError("degrees is not provided.") - degrees = check_degrees(degrees) - kwargs["degrees"] = degrees + [degrees, translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs) + check_degrees(degrees) if translate is not None: - if isinstance(translate, (tuple, list)) and len(translate) == 2: - for t in translate: - if t < 0.0 or t > 1.0: - raise ValueError("translation values should be between 0 and 1") - else: + if type_check(translate, (list, tuple), "translate"): + translate_names = ["translate_{0}".format(i) for i in range(len(translate))] + type_check_list(translate, (int, float), translate_names) + if len(translate) != 2: raise TypeError("translate should be a list or tuple of length 2.") - kwargs["translate"] = translate + for i, t in enumerate(translate): + check_value(t, [0.0, 1.0], "translate at {0}".format(i)) if scale is not None: - if isinstance(scale, (tuple, list)) and len(scale) == 2: - for s in scale: - if s <= 0: - raise ValueError("scale values should be positive") + type_check(scale, (tuple, list), "scale") + if len(scale) == 2: + for i, s in enumerate(scale): + check_positive(s, "scale[{}]".format(i)) else: raise TypeError("scale should be a list or tuple of length 2.") - kwargs["scale"] = scale if shear is not None: + type_check(shear, (numbers.Number, tuple, list), "shear") if isinstance(shear, numbers.Number): - if shear < 0: - raise ValueError("If shear is a single number, it must be positive.") - shear = (-1 * shear, shear) - elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4): - # X-Axis shear with [min, max] - if len(shear) == 2: - shear = [shear[0], shear[1], 0., 0.] - elif len(shear) == 4: - shear = [s for s in shear] + check_positive(shear, "shear") else: - raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.") - kwargs["shear"] = shear + if len(shear) not in (2, 4): + raise TypeError("shear must be of length 2 or 4.") + + type_check(resample, (Inter,), "resample") - if resample is not None: - check_inter_mode(resample) - kwargs["resample"] = resample if fill_value is not None: - fill_value = check_fill_value(fill_value) - kwargs["fill_value"] = fill_value + check_fill_value(fill_value) - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -794,24 +488,11 @@ def check_rescale(method): @wraps(method) def new_method(self, *args, **kwargs): - rescale, shift = (list(args) + 2 * [None])[:2] - if "rescale" in kwargs: - rescale = kwargs.get("rescale") - if "shift" in kwargs: - shift = kwargs.get("shift") - - if rescale is None: - raise ValueError("rescale is not provided.") + [rescale, shift], _ = parse_user_args(method, *args, **kwargs) check_pos_float32(rescale) - kwargs["rescale"] = rescale - - if shift is None: - raise ValueError("shift is not provided.") - if not isinstance(shift, numbers.Number): - raise TypeError("shift is not a number.") - kwargs["shift"] = shift + type_check(shift, (numbers.Number,), "shift") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -821,33 +502,16 @@ def check_uniform_augment_cpp(method): @wraps(method) def new_method(self, *args, **kwargs): - operations, num_ops = (list(args) + 2 * [None])[:2] - if "operations" in kwargs: - operations = kwargs.get("operations") - else: - raise ValueError("operations list required") - if "num_ops" in kwargs: - num_ops = kwargs.get("num_ops") - else: - num_ops = 2 - - if not isinstance(num_ops, int): - raise ValueError("Number of operations should be an integer.") - - if num_ops <= 0: - raise ValueError("num_ops should be greater than zero") + [operations, num_ops], _ = parse_user_args(method, *args, **kwargs) + type_check(num_ops, (int,), "num_ops") + check_positive(num_ops, "num_ops") + if num_ops > len(operations): raise ValueError("num_ops is greater than operations list size") - if not isinstance(operations, list): - raise TypeError("operations is not a python list") - for op in operations: - if not isinstance(op, TensorOp): - raise ValueError("operations list only accepts C++ operations.") + tensor_ops = ["tensor_op_{0}".format(i) for i in range(len(operations))] + type_check_list(operations, (TensorOp,), tensor_ops) - kwargs["num_ops"] = num_ops - kwargs["operations"] = operations - - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -857,23 +521,11 @@ def check_bounding_box_augment_cpp(method): @wraps(method) def new_method(self, *args, **kwargs): - transform, ratio = (list(args) + 2 * [None])[:2] - if "transform" in kwargs: - transform = kwargs.get("transform") - if "ratio" in kwargs: - ratio = kwargs.get("ratio") - if not isinstance(ratio, float) and not isinstance(ratio, int): - raise ValueError("Ratio should be an int or float.") - if ratio is not None: - check_value(ratio, [0., 1.]) - kwargs["ratio"] = ratio - else: - ratio = 0.3 - if not isinstance(transform, TensorOp): - raise ValueError("Transform can only be a C++ operation.") - kwargs["transform"] = transform - kwargs["ratio"] = ratio - return method(self, **kwargs) + [transform, ratio], _ = parse_user_args(method, *args, **kwargs) + type_check(ratio, (float, int), "ratio") + check_value(ratio, [0., 1.], "ratio") + type_check(transform, (TensorOp,), "transform") + return method(self, *args, **kwargs) return new_method @@ -883,29 +535,22 @@ def check_uniform_augment_py(method): @wraps(method) def new_method(self, *args, **kwargs): - transforms, num_ops = (list(args) + 2 * [None])[:2] - if "transforms" in kwargs: - transforms = kwargs.get("transforms") - if transforms is None: - raise ValueError("transforms is not provided.") + [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs) + type_check(transforms, (list,), "transforms") + if not transforms: raise ValueError("transforms list is empty.") - check_list(transforms) + for transform in transforms: if isinstance(transform, TensorOp): raise ValueError("transform list only accepts Python operations.") - kwargs["transforms"] = transforms - if "num_ops" in kwargs: - num_ops = kwargs.get("num_ops") - if num_ops is not None: - check_type(num_ops, int) - check_positive(num_ops) - if num_ops > len(transforms): - raise ValueError("num_ops cannot be greater than the length of transforms list.") - kwargs["num_ops"] = num_ops + type_check(num_ops, (int,), "num_ops") + check_positive(num_ops, "num_ops") + if num_ops > len(transforms): + raise ValueError("num_ops cannot be greater than the length of transforms list.") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -915,22 +560,16 @@ def check_positive_degrees(method): @wraps(method) def new_method(self, *args, **kwargs): - degrees = (list(args) + [None])[0] - if "degrees" in kwargs: - degrees = kwargs.get("degrees") - - if degrees is not None: - if isinstance(degrees, (list, tuple)): - if len(degrees) != 2: - raise ValueError("Degrees must be a sequence with length 2.") - if degrees[0] < 0: - raise ValueError("Degrees range must be non-negative.") - if degrees[0] > degrees[1]: - raise ValueError("Degrees should be in (min,max) format. Got (max,min).") - else: - raise TypeError("Degrees must be a sequence in (min,max) format.") + [degrees], _ = parse_user_args(method, *args, **kwargs) + + if isinstance(degrees, (list, tuple)): + if len(degrees) != 2: + raise ValueError("Degrees must be a sequence with length 2.") + check_positive(degrees[0], "degrees[0]") + if degrees[0] > degrees[1]: + raise ValueError("Degrees should be in (min,max) format. Got (max,min).") - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method @@ -940,18 +579,12 @@ def check_compose_list(method): @wraps(method) def new_method(self, *args, **kwargs): - transforms = (list(args) + [None])[0] - if "transforms" in kwargs: - transforms = kwargs.get("transforms") - if transforms is None: - raise ValueError("transforms is not provided.") + [transforms], _ = parse_user_args(method, *args, **kwargs) + + type_check(transforms, (list,), transforms) if not transforms: raise ValueError("transforms list is empty.") - if not isinstance(transforms, list): - raise TypeError("transforms is not a python list") - - kwargs["transforms"] = transforms - return method(self, **kwargs) + return method(self, *args, **kwargs) return new_method diff --git a/mindspore/model_zoo/__init__.py b/mindspore/model_zoo/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/mindspore/nn/__init__.py b/mindspore/nn/__init__.py index 8d5e7d3b0a..e5c133a9a6 100644 --- a/mindspore/nn/__init__.py +++ b/mindspore/nn/__init__.py @@ -17,13 +17,15 @@ Neural Networks Cells. Pre-defined building blocks or computing units to construct Neural Networks. """ -from . import layer, loss, optim, metrics, wrap +from . import layer, loss, optim, metrics, wrap, distribution from .cell import Cell, GraphKernel from .layer import * from .loss import * from .optim import * from .metrics import * from .wrap import * +from .distribution import * + __all__ = ["Cell", "GraphKernel"] __all__.extend(layer.__all__) @@ -31,5 +33,7 @@ __all__.extend(loss.__all__) __all__.extend(optim.__all__) __all__.extend(metrics.__all__) __all__.extend(wrap.__all__) +__all__.extend(distribution.__all__) + __all__.sort() diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index cffe00a920..3eec96f0b5 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -16,6 +16,7 @@ import time import gc from collections import OrderedDict +import numpy from mindspore import log as logger from .. import context from ..common import dtype as mstype @@ -211,6 +212,9 @@ class Cell: if context.get_context("mode") == context.GRAPH_MODE: out = self.compile_and_run(*inputs) return out + for item in inputs: + if isinstance(item, numpy.ndarray): + raise TypeError("cell inputs should not be numpy array.") self.init_parameters_data() orign_grad = [] if self.requires_grad is True: @@ -827,6 +831,20 @@ class Cell: self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") self.enable_hook = True + def set_param_ps(self, recurse=True): + """ + Set whether the trainable parameter is updated by parameter server. + + Note: + This only works when running task in parameter server mode. + + Args: + recurse (bool): Whether sets the trainable parameters of subcells. Default: True. + """ + params = self.trainable_params(recurse) + for param in params: + param.set_param_ps() + class GraphKernel(Cell): """ Base class for GraphKernel. diff --git a/mindspore/nn/distribution/__init__.py b/mindspore/nn/distribution/__init__.py new file mode 100644 index 0000000000..55b4b03ef7 --- /dev/null +++ b/mindspore/nn/distribution/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Distribution. + +The high-level components(Distributions) used to construct the probabilistic network. +""" + +from .distribution import Distribution +from .normal import Normal +from .bernoulli import Bernoulli + +__all__ = ['Distribution', + 'Normal', + 'Bernoulli',] diff --git a/mindspore/nn/distribution/_utils/__init__.py b/mindspore/nn/distribution/_utils/__init__.py new file mode 100644 index 0000000000..816485643a --- /dev/null +++ b/mindspore/nn/distribution/_utils/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Distribution operation utility functions. +""" +from .utils import * + +__all__ = ['check_scalar', 'convert_to_batch', 'cast_to_tensor', + 'calc_batch_size', 'check_greater', + 'check_greater_equal_zero', + 'calc_broadcast_shape_from_param', + 'check_scalar_from_param', 'check_prob'] diff --git a/mindspore/nn/distribution/_utils/utils.py b/mindspore/nn/distribution/_utils/utils.py new file mode 100644 index 0000000000..c790a66f25 --- /dev/null +++ b/mindspore/nn/distribution/_utils/utils.py @@ -0,0 +1,199 @@ + +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utitly functions to help distribution class.""" +import numpy as np +from mindspore.ops import _utils as utils +from ....common.tensor import Tensor +from ....common.parameter import Parameter +from ....common import dtype as mstype + + +def check_scalar(value): + """ + Check if input value is a scalar. + """ + return np.isscalar(value) + + +def cast_to_tensor(t, dtype=mstype.float32): + """ + Cast an user input value into a Tensor of dtype. + + Args: + t (int, float, list, numpy.ndarray, Tensor, Parameter): object to be cast to Tensor. + dtype (mindspore.dtype): dtype of the Tensor. Default: mstype.float32. + + Raises: + RuntimeError: if t cannot be cast to Tensor. + + Returns: + Tensor. + """ + if isinstance(t, Parameter): + return t + if isinstance(t, Tensor): + #check if the Tensor in shape of Tensor(4) + if t.dim() == 0: + value = t.asnumpy() + return Tensor([t], dtype=dtype) + #convert the type of tensor to dtype + t.set_dtype(dtype) + return t + if isinstance(t, (list, np.ndarray)): + return Tensor(t, dtype=dtype) + if check_scalar(t): + return Tensor([t], dtype=dtype) + raise RuntimeError("Input type is not supported.") + +def calc_batch_size(batch_shape): + """ + Calculate the size of a given batch_shape. + + Args: + batch_shape (tuple): batch shape to be calculated. + + Returns: + int. + """ + return int(np.prod(batch_shape)) + +def convert_to_batch(t, batch_shape, dtype): + """ + Convert a Tensor to a given batch shape. + + Args: + t (Tensor, Parameter): Tensor to be converted. + batch_shape (tuple): desired batch shape. + dtype (mindspore.dtype): desired dtype. + + Raises: + RuntimeError: if the converison cannot be done. + + Returns: + Tensor, with shape of batch_shape. + """ + if isinstance(t, Parameter): + return t + t = cast_to_tensor(t, dtype) + if t.shape != batch_shape: + mul = calc_batch_size(batch_shape) // t.size() + if (calc_batch_size(batch_shape) % t.size()) != 0: + raise RuntimeError("Cannot cast the tensor to the given batch shape.") + temp = list(t.asnumpy()) * mul + temp = np.reshape(temp, batch_shape) + return Tensor(temp, dtype) + return t + +def check_scalar_from_param(params): + """ + Check if params are all scalars. + + Args: + params (dict): parameters used to initialize distribution. + + Notes: String parameters are excluded. + """ + for value in params.values(): + if isinstance(value, (str, type(params['dtype']))): + continue + elif check_scalar(value): + continue + else: + return False + return True + + +def calc_broadcast_shape_from_param(params): + """ + Calculate the broadcast shape from params. + + Args: + params (dict): parameters used to initialize distribution. + + Returns: + tuple. + """ + broadcast_shape = [] + for value in params.values(): + if isinstance(value, (str, type(params['dtype']))): + continue + if value is None: + return None + if isinstance(value, Parameter): + value_t = value.default_input + else: + value_t = cast_to_tensor(value, params['dtype']) + broadcast_shape = utils.get_broadcast_shape(broadcast_shape, list(value_t.shape), params['name']) + return tuple(broadcast_shape) + +def check_greater_equal_zero(value, name): + """ + Check if the given Tensor is greater zero. + + Args: + value (Tensor, Parameter): value to be checked. + name (str) : name of the value. + + Raises: + ValueError: if the input value is less than zero. + + """ + if isinstance(value, Parameter): + if not isinstance(value.default_input, Tensor): + return + value = value.default_input + comp = np.less(value.asnumpy(), np.zeros(value.shape)) + if comp.any(): + raise ValueError(f'{name} should be greater than zero.') + +def check_greater(a, b, name_a, name_b): + """ + Check if Tensor b is strictly greater than Tensor a. + + Args: + a (Tensor): input tensor a. + b (Tensor): input tensor b. + name_a (str): name of Tensor_a. + name_b (str): name of Tensor_b. + + Raises: + ValueError: if b is less than or equal to a + """ + comp = np.less(a.asnumpy(), b.asnumpy()) + if not comp.all(): + raise ValueError(f'{name_a} should be less than {name_b}') + + +def check_prob(p): + """ + Check if p is a proper probability, i.e. 0 <= p <=1. + + Args: + p (Tensor, Parameter): value to be checked. + + Raises: + ValueError: if p is not a proper probability. + """ + if isinstance(p, Parameter): + if not isinstance(p.default_input, Tensor): + return + p = p.default_input + comp = np.less(p.asnumpy(), np.zeros(p.shape)) + if comp.any(): + raise ValueError('Probabilities should be greater than or equal to zero') + comp = np.greater(p.asnumpy(), np.ones(p.shape)) + if comp.any(): + raise ValueError('Probabilities should be less than or equal to one') diff --git a/mindspore/nn/distribution/bernoulli.py b/mindspore/nn/distribution/bernoulli.py new file mode 100644 index 0000000000..9aa20d668f --- /dev/null +++ b/mindspore/nn/distribution/bernoulli.py @@ -0,0 +1,168 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bernoulli Distribution""" +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from .distribution import Distribution +from ._utils.utils import cast_to_tensor, check_prob +from ...common import dtype as mstype + +class Bernoulli(Distribution): + """ + Example class: Bernoulli Distribution. + + Args: + probs (int, float, list, numpy.ndarray, Tensor, Parameter): probability of 1 as outcome. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. + name (str): name of the distribution. Default: Bernoulli. + + Note: + probs should be proper probabilities (0 <= p <= 1). + + Examples: + >>> # To initialize a Bernoulli distribution which has equal probability of getting 1 and 0 + >>> b = nn.Bernoulli(0.5, dtype = mstype.int32) + >>> # The following create two independent Bernoulli distributions + >>> b = nn.Bernoulli([0.7, 0.2], dtype = mstype.int32) + """ + + def __init__(self, + probs=None, + seed=0, + dtype=mstype.int32, + name="Bernoulli"): + """ + Constructor of Bernoulli distribution. + """ + param = dict(locals()) + super(Bernoulli, self).__init__(dtype, name, param) + if probs is not None: + self._probs = cast_to_tensor(probs) + check_prob(self._probs) + else: + self._probs = probs + self.seed = seed + + # ops needed for the class + self.log = P.Log() + self.add = P.TensorAdd() + self.mul = P.Mul() + self.sqrt = P.Sqrt() + self.realdiv = P.RealDiv() + self.shape = P.Shape() + self.const = P.ScalarToArray() + self.less = P.Less() + self.cast = P.Cast() + self.erf = P.Erf() + self.sqrt = P.Sqrt() + + def extend_repr(self): + str_info = f'probs = {self._probs}' + return str_info + + def probs(self): + """ + Returns the probability for the outcome is 1. + """ + return self._probs + + def _mean(self, name='mean', probs1=None): + r""" + .. math:: + MEAN(B) = probs1 + """ + if name == 'mean': + return self._probs if probs1 is None else probs1 + return None + + def _var(self, name='var', probs1=None): + r""" + .. math:: + VAR(B) = probs1 * probs0 + """ + if name in ('sd', 'var'): + probs1 = self._probs if probs1 is None else probs1 + probs0 = self.add(1, -1 * probs1) + return self.mul(probs0, probs1) + return None + + def _prob(self, name, value, probs=None): + r""" + pmf of Bernoulli distribution. + + Args: + name (str): name of the function. Should be "prob" when passed in from construct. + value (Tensor): a Tensor composed of only zeros and ones. + probs (Tensor): probability of outcome is 1. Default: self._probs. + + .. math:: + pmf(k) = probs1 if k = 1; + pmf(k) = probs0 if k = 0; + """ + if name in ('prob', 'log_prob'): + probs1 = self._probs if probs is None else probs + probs0 = self.add(1, -1 * probs1) + return self.add(self.mul(probs1, value), + self.mul(probs0, self.add(1, -1 * value))) + return None + + def _kl_loss(self, name, dist, probs1_b, probs1_a=None): + r""" + Evaluate bernoulli-bernoulli kl divergence, i.e. KL(a||b). + + Args: + name (str): name of the funtion. Should always be "kl_loss" when passed in from construct. + dist (str): type of the distributions. Should be "Bernoulli" in this case. + probs1_b (Tensor): probs1 of distribution b. + probs1_a (Tensor): probs1 of distribution a. Default: self._probs. + + .. math:: + KL(a||b) = probs1_a * \log(\fract{probs1_a}{probs1_b}) + + probs0_a * \log(\fract{probs0_a}{probs0_b}) + """ + if name == 'kl_loss' and dist == 'Bernoulli': + probs1_a = self._probs if probs1_a is None else probs1_a + probs0_a = self.add(1, -1 * probs1_a) + probs0_b = self.add(1, -1 * probs1_b) + return self.add(probs1_a * self.log(self.realdiv(probs1_a, probs1_b)), + probs0_a * self.log(self.realdiv(probs0_a, probs0_b))) + return None + + def _sample(self, name, shape=(), probs=None): + """ + Sampling. + + Args: + name (str): name of the function. Should always be 'sample' when passed in from construct. + shape (tuple): shape of the sample. Default: (). + probs (Tensor): probs1 of the samples. Default: self._probs. + + Returns: + Tensor, shape is shape + batch_shape. + """ + if name == 'sample': + probs1 = self._probs if probs is None else probs + batch_shape = self.shape(probs1) + sample_shape = shape + batch_shape + mean_zero = self.const(0.0) + sd_one = self.const(1.0) + sqrt_two = self.sqrt(self.const(2.0)) + sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) + sample_uniform = 0.5 * (1 + self.erf(self.realdiv(sample_norm, sqrt_two))) + sample = self.less(sample_uniform, probs1) + sample = self.cast(sample, self._dtype) + return sample + return None diff --git a/mindspore/nn/distribution/distribution.py b/mindspore/nn/distribution/distribution.py new file mode 100644 index 0000000000..1ed7906a9e --- /dev/null +++ b/mindspore/nn/distribution/distribution.py @@ -0,0 +1,200 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""basic""" +from ..cell import Cell +from ._utils.utils import calc_broadcast_shape_from_param + + +class Distribution(Cell): + """ + Base class for all mathematical distributions. + + Args: + dtype (mindspore.dtype): type of the distribution. + name (str): name of the distribution. + param (dict): parameters used to initialize the distribution. + + Note: + Derived class should override operations such as ,_mean, _prob, + and _log_prob. Functions should be called through construct when + used inside a network in the form of function name followed by + arguments. + + Examples: + >>> class MyNormalDistribution(Distribution): + >>> def __init__(self): + >>> super(MyDistribution, self).__init__() + >>> self._mean_value = Tensor([2.0,3.0]) + >>> self._sd_value = Tensor([2.0,3.0]) + >>> + >>> def _mean(self): + >>> return self._mean_value + + """ + def __init__(self, + dtype, + name, + param): + + """ + Constructor of distribution class. + """ + super(Distribution, self).__init__() + self._name = name + self._dtype = dtype + self._parameters = {} + # parsing parameters + for k in param.keys(): + if not(k == 'self' or k.startswith('_')): + self._parameters[k] = param[k] + # some attributes + self._broadcast_shape = calc_broadcast_shape_from_param( + self._parameters) + + # set the function to call according to the derived class's attributes + self._set_prob() + self._set_log_prob() + self._set_sd() + + def _set_prob(self): + """ + Set probability funtion based on the availability of _prob and _log_likehood. + """ + if hasattr(self, '_prob'): + self._call_prob = self._prob + elif hasattr(self, '_log_likelihood'): + self._call_prob = self._calc_prob_from_log_likelihood + + def _set_sd(self): + """ + Set standard deviation based on the availability of _sd and _var. + """ + if hasattr(self, '_sd'): + self._call_sd = self._sd + elif hasattr(self, '_var'): + self._call_sd = self._calc_sd_from_var + + def _set_log_prob(self): + """ + Set log probability based on the availability of _prob and _log_likelihood. + """ + if hasattr(self, '_log_likelihood'): + self._call_log_prob = self._log_likelihood + if hasattr(self, '_prob'): + self._call_log_prob = self._calc_log_prob_from_prob + + def log_likelihood(self, *args): + """ + Evaluate the log probability at the given value. + + Note: + value is casted to Tensor for further calculation. + + Returns: + Tensor, shape is the broadcast_shape of the distribution. + """ + return self._call_log_prob(*args) + + def _calc_prob_from_log_likelihood(self, *args): + r""" + Evaluate prob from log probability. + + .. math:: + probability(x) = \exp(log_likehood(x)) + """ + return self.exp(self._log_likelihood(*args)) + + def prob(self, *args): + """ + Evaluate the prob (pdf or pmf) at given value. + + Note: + value is casted to Tensor for further calculation. + + Returns: + Tensor, shape is the broadcast_shape of the distribution. + """ + return self._call_prob(*args) + + def _calc_log_prob_from_prob(self, *args): + r""" + Evaluate log probability from probability. + + .. math:: + log_prob(x) = \log(prob(x)) + """ + return self.log(self._prob(*args)) + + def kl_loss(self, **kwargs): + """ + Evaluate the KL divergence. Parameters of the second distribution should be + passed in through **kwargs. + + Returns: + Tensor, shape is the broadcast_shape of the distribution and input distribution. + """ + return self._kl_loss(**kwargs) + + def mean(self, **kwargs): + """ + Evaluate the mean. + + Returns: + Tensor, shape is the broadcast_shape of the distribution. + """ + return self._mean(**kwargs) + + def sd(self, **kwargs): + """ + Evaluate the standard deviation. + + Returns: + Tensor, shape is the broadcast_shape of the distribution. + """ + return self._call_sd(**kwargs) + + def _calc_sd_from_var(self, *args): + r""" + Evaluate log probability from probability. + + .. math:: + STD(x) = \sqrt(VAR(x)) + """ + return self.sqrt(self._var(*args)) + + def construct(self, *inputs): + """ + Override construct in Cell. + + Args: + *inputs: inputs[0] is always the name of the function. + + Notes: + Always raise RuntimeError as Distribution should not be called directly. + """ + + if inputs[0] == 'log_prob': + return self._call_log_prob(*inputs) + if inputs[0] == 'prob': + return self._call_prob(*inputs) + if inputs[0] == 'kl_loss': + return self._kl_loss(*inputs) + if inputs[0] == 'mean': + return self._mean(*inputs) + if inputs[0] == 'sd': + return self._call_sd(*inputs) + if inputs[0] == 'sample': + return self._sample(*inputs) + return None diff --git a/mindspore/nn/distribution/normal.py b/mindspore/nn/distribution/normal.py new file mode 100644 index 0000000000..61cec6d810 --- /dev/null +++ b/mindspore/nn/distribution/normal.py @@ -0,0 +1,170 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Normal Distribution""" +import numpy as np +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from .distribution import Distribution +from ._utils.utils import convert_to_batch, check_greater_equal_zero +from ...common import dtype as mstype +from ...context import get_context + +class Normal(Distribution): + """ + Example class: Normal distribution. + + Args: + mean (int, float, list, numpy.ndarray, Tensor, Parameter): mean of the Gaussian distribution. + sd (int, float, list, numpy.ndarray, Tensor, Parameter): stddev of the Gaussian distribution. + seed (int): seed to use in sampling. Default: 0. + dtype (mindspore.dtype): type of the distribution. Default: mstype.float32. + name (str): name of the distribution. Default: Normal. + + + Note: + Standard deviation should be greater than zero. + + Examples: + >>> # To initialize a normal distribution of mean 3.0 and standard deviation 4.0 + >>> n = nn.Normal(3.0, 4.0, dtype=mstype.float32) + >>> # The following create two independent normal distributions + >>> n = nn.Normal([3.0, 3.0], [4.0, 4.0], dtype=mstype.float32) + """ + + def __init__(self, + mean=None, + sd=None, + seed=0, + dtype=mstype.float32, + name="Normal"): + """ + Constructor of normal distribution. + """ + param = dict(locals()) + super(Normal, self).__init__(dtype, name, param) + if mean is not None and sd is not None: + self._mean_value = convert_to_batch(mean, self._broadcast_shape, dtype) + self._sd_value = convert_to_batch(sd, self._broadcast_shape, dtype) + check_greater_equal_zero(self._sd_value, "Standard deviation") + else: + self._mean_value = mean + self._sd_value = sd + self.seed = seed + + #ops needed for the class + self.exp = P.Exp() + self.add = P.TensorAdd() + self.mul = P.Mul() + self.sq = P.Square() + self.log = P.Log() + self.sqrt = P.Sqrt() + self.realdiv = P.RealDiv() + self.expm1 = P.Expm1() if get_context('device_target') == 'Ascend' else self._expm1_by_step + self.shape = P.Shape() + self.zeroslike = P.ZerosLike() + self.const = P.ScalarToArray() + + def extend_repr(self): + str_info = f'mean = {self._mean_value}, standard deviation = {self._sd_value}' + return str_info + + def _expm1_by_step(self, x): + """ + Expm1 ops under GPU context. + """ + return self.add(self.exp(x), -1) + + def _mean(self, name='mean', mean=None, sd=None): + """ + Mean of the distribution. + """ + if name == 'mean': + mean = self._mean_value if mean is None or sd is None else mean + return mean + return None + + def _sd(self, name='sd', mean=None, sd=None): + """ + Standard deviation of the distribution. + """ + if name in ('sd', 'var'): + sd = self._sd_value if mean is None or sd is None else sd + return sd + return None + + def _log_likelihood(self, name, value, mean=None, sd=None): + r""" + Evaluate log probability. + + .. math:: + L(x) = -1* \fract{(x - \mu)^2}{2. * \sigma^2} - \log(\sqrt(2* \pi * \sigma^2)) + """ + if name in ('prob', 'log_prob'): + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + unnormalized_log_prob = -1. * self.realdiv(self.sq(self.add(value, -1. * mean)), + 2. * self.sq(sd)) + neg_normalization = -1. * self.log(self.sqrt(2. * np.pi * self.sq(sd))) + return self.add(unnormalized_log_prob, neg_normalization) + return None + + def _kl_loss(self, name, dist, mean_b, sd_b, mean_a=None, sd_a=None): + r""" + Evaluate Normal-Normal kl divergence, i.e. KL(a||b). + + Args: + name (str): name of the funtion passed in from construct. Should always be "kl_loss". + dist (str): type of the distributions. Should be "Normal" in this case. + mean_b (Tensor): mean of distribution b. + sd_b (Tensor): standard deviation distribution b. + mean_a (Tensor): mean of distribution a. Default: self._mean_value. + sd_a (Tensor): standard deviation distribution a. Default: self._sd_value. + + .. math:: + KL(a||b) = 0.5 * (\fract{MEAN(a)}{STD(b)} - \fract{MEAN(b)}{STD(b)}) ^ 2 + + 0.5 * EXPM1(2 * (\log(STD(a)) - \log(STD(b))) - (\log(STD(a)) - \log(STD(b))) + """ + if name == 'kl_loss' and dist == 'Normal': + mean_a = self._mean_value if mean_a is None else mean_a + sd_a = self._sd_value if sd_a is None else sd_a + diff_log_scale = self.add(self.log(sd_a), - self.log(sd_b)) + squared_diff = self.sq(self.add(self.realdiv(mean_a, sd_b), - self.realdiv(mean_b, sd_b))) + return self.add(self.add(0.5 * squared_diff, 0.5 * self.expm1(2 * diff_log_scale)), - diff_log_scale) + return None + + def _sample(self, name, shape=(), mean=None, sd=None): + """ + Sampling. + + Args: + name (str): name of the function. Should always be 'sample' when passed in from construct. + shape (tuple): shape of the sample. Default: (). + mean (Tensor): mean of the samples. Default: self._mean_value. + sd (Tensor): standard deviation of the samples. Default: self._sd_value. + + Returns: + Tensor, shape is shape + batch_shape. + """ + if name == 'sample': + mean = self._mean_value if mean is None else mean + sd = self._sd_value if sd is None else sd + batch_shape = self.shape(self.add(self.zeroslike(mean), self.zeroslike(sd))) + sample_shape = shape + batch_shape + mean_zero = self.const(0.0) + sd_one = self.const(1.0) + sample_norm = C.normal(sample_shape, mean_zero, sd_one, self.seed) + sample = self.add(mean, self.mul(sample_norm, sd)) + return sample + return None diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index 14a1aa8554..384f625133 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -530,6 +530,7 @@ _activation = { 'relu6': ReLU6, 'tanh': Tanh, 'gelu': GELU, + 'elu': ELU, 'sigmoid': Sigmoid, 'prelu': PReLU, 'leakyrelu': LeakyReLU, diff --git a/mindspore/nn/layer/container.py b/mindspore/nn/layer/container.py index 48871401bf..ed36a1dd5f 100644 --- a/mindspore/nn/layer/container.py +++ b/mindspore/nn/layer/container.py @@ -69,7 +69,7 @@ class SequentialCell(Cell): Alternatively, an ordered dict of cells can also be passed in. Args: - args (list, optional): List of subclass of Cell. + args (list, OrderedDict): List of subclass of Cell. Raises: TypeError: If arg is not of type list or OrderedDict. diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index c8873039ab..3c4245d702 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer from ..cell import Cell from ..._checkparam import Validator as validator -__all__ = ['Embedding'] +__all__ = ['Embedding', 'EmbeddingLookup'] class Embedding(Cell): r""" @@ -105,3 +105,49 @@ class Embedding(Cell): self.embedding_table, self.dtype) return s + +class EmbeddingLookup(Cell): + r""" + Returns a slice of input tensor based on the specified indices. + + Note: + When 'target' is set to 'CPU', this module will use + P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') which + specified 'offset = 0' to lookup table. + when 'target' is set to 'DEVICE', this module will use P.GatherV2() which + specified 'axis = 0' to lookup table. + + Args: + target (str): Specify the target where the op is executed. Default: 'CPU'. + + Inputs: + - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + The Tensor slice, instead of the entire Tensor. + - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. + Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, + and the exceeding part will be filled with 0 in the output. + + Outputs: + Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. + + Examples: + >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) + >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) + >>> out = nn.EmbeddingLookup()(input_params, input_indices) + [[[10, 11], [8 ,9]], [[14, 15], [12, 13]]] + """ + def __init__(self, target='CPU'): + super(EmbeddingLookup, self).__init__() + self.target = target + if target not in ('CPU', 'DEVICE'): + raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed ' + + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') + self.gatherv2 = P.GatherV2() + self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') + + def construct(self, params, indices): + if self.target == "CPU": + out = self.embeddinglookup(params, indices, 0) + else: + out = self.gatherv2(params, indices, 0) + return out diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 3721bc3c44..63ae7a94ac 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -21,9 +21,13 @@ from mindspore.ops import functional as F from mindspore.ops.primitive import constexpr from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel +from .conv import Conv2d +from .container import CellList +from .pooling import AvgPool2d +from .activation import ReLU from ..cell import Cell -__all__ = ['ImageGradients', 'SSIM', 'PSNR', 'CentralCrop'] +__all__ = ['ImageGradients', 'SSIM', 'MSSSIM', 'PSNR', 'CentralCrop'] class ImageGradients(Cell): r""" @@ -83,21 +87,6 @@ def _convert_img_dtype_to_float32(img, max_val): ret = ret * scale return ret - -@constexpr -def _gauss_kernel_helper(filter_size): - """gauss kernel helper""" - filter_size = F.scalar_cast(filter_size, mstype.int32) - coords = () - for i in range(filter_size): - i_cast = F.scalar_cast(i, mstype.float32) - offset = F.scalar_cast(filter_size-1, mstype.float32)/2.0 - element = i_cast-offset - coords = coords+(element,) - g = np.square(coords).astype(np.float32) - g = Tensor(g) - return filter_size, g - @constexpr def _check_input_4d(input_shape, param_name, func_name): if len(input_shape) != 4: @@ -110,9 +99,65 @@ def _check_input_filter_size(input_shape, param_name, filter_size, func_name): validator.check(param_name + " shape[2]", input_shape[2], "filter_size", filter_size, Rel.GE, func_name) validator.check(param_name + " shape[3]", input_shape[3], "filter_size", filter_size, Rel.GE, func_name) -@constexpr -def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name): - validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name) +def _conv2d(in_channels, out_channels, kernel_size, weight, stride=1, padding=0): + return Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + weight_init=weight, padding=padding, pad_mode="valid") + +def _create_window(size, sigma): + x_data, y_data = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] + x_data = np.expand_dims(x_data, axis=-1).astype(np.float32) + x_data = np.expand_dims(x_data, axis=-1) ** 2 + y_data = np.expand_dims(y_data, axis=-1).astype(np.float32) + y_data = np.expand_dims(y_data, axis=-1) ** 2 + sigma = 2 * sigma ** 2 + g = np.exp(-(x_data + y_data) / sigma) + return np.transpose(g / np.sum(g), (2, 3, 0, 1)) + +def _split_img(x): + _, c, _, _ = F.shape(x) + img_split = P.Split(1, c) + output = img_split(x) + return output, c + +def _compute_per_channel_loss(c1, c2, img1, img2, conv): + """computes ssim index between img1 and img2 per single channel""" + dot_img = img1 * img2 + mu1 = conv(img1) + mu2 = conv(img2) + mu1_sq = mu1 * mu1 + mu2_sq = mu2 * mu2 + mu1_mu2 = mu1 * mu2 + sigma1_tmp = conv(img1 * img1) + sigma1_sq = sigma1_tmp - mu1_sq + sigma2_tmp = conv(img2 * img2) + sigma2_sq = sigma2_tmp - mu2_sq + sigma12_tmp = conv(dot_img) + sigma12 = sigma12_tmp - mu1_mu2 + a = (2 * mu1_mu2 + c1) + b = (mu1_sq + mu2_sq + c1) + v1 = 2 * sigma12 + c2 + v2 = sigma1_sq + sigma2_sq + c2 + ssim = (a * v1) / (b * v2) + cs = v1 / v2 + return ssim, cs + +def _compute_multi_channel_loss(c1, c2, img1, img2, conv, concat, mean): + """computes ssim index between img1 and img2 per color channel""" + split_img1, c = _split_img(img1) + split_img2, _ = _split_img(img2) + multi_ssim = () + multi_cs = () + for i in range(c): + ssim_per_channel, cs_per_channel = _compute_per_channel_loss(c1, c2, split_img1[i], split_img2[i], conv) + multi_ssim += (ssim_per_channel,) + multi_cs += (cs_per_channel,) + + multi_ssim = concat(multi_ssim) + multi_cs = concat(multi_cs) + + ssim = mean(multi_ssim, (2, 3)) + cs = mean(multi_cs, (2, 3)) + return ssim, cs class SSIM(Cell): r""" @@ -157,67 +202,126 @@ class SSIM(Cell): self.max_val = max_val self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) - validator.check_value_type('k1', k1, [float], self.cls_name) - self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) - validator.check_value_type('k2', k2, [float], self.cls_name) - self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) - self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) + self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) + self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) + window = _create_window(filter_size, filter_sigma) + self.conv = _conv2d(1, 1, filter_size, Tensor(window)) + self.conv.weight.requires_grad = False + self.reduce_mean = P.ReduceMean() + self.concat = P.Concat(axis=1) def construct(self, img1, img2): - _check_input_dtype(F.dtype(img1), "img1", [mstype.float32, mstype.float16], self.cls_name) _check_input_filter_size(F.shape(img1), "img1", self.filter_size, self.cls_name) P.SameTypeShape()(img1, img2) max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) img1 = _convert_img_dtype_to_float32(img1, self.max_val) img2 = _convert_img_dtype_to_float32(img2, self.max_val) - kernel = self._fspecial_gauss(self.filter_size, self.filter_sigma) - kernel = P.Tile()(kernel, (1, P.Shape()(img1)[1], 1, 1)) + c1 = (self.k1 * max_val) ** 2 + c2 = (self.k2 * max_val) ** 2 + + ssim_ave_channel, _ = _compute_multi_channel_loss(c1, c2, img1, img2, self.conv, self.concat, self.reduce_mean) + loss = self.reduce_mean(ssim_ave_channel, -1) + + return loss + +def _downsample(img1, img2, op): + a = op(img1) + b = op(img2) + return a, b + +class MSSSIM(Cell): + r""" + Returns MS-SSIM index between img1 and img2. + + Its implementation is based on Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. `Multiscale structural similarity + for image quality assessment `_. + Signals, Systems and Computers, 2004. - mean_ssim = self._calculate_mean_ssim(img1, img2, kernel, max_val, self.k1, self.k2) + .. math:: - return mean_ssim + l(x,y)&=\frac{2\mu_x\mu_y+C_1}{\mu_x^2+\mu_y^2+C_1}, C_1=(K_1L)^2.\\ + c(x,y)&=\frac{2\sigma_x\sigma_y+C_2}{\sigma_x^2+\sigma_y^2+C_2}, C_2=(K_2L)^2.\\ + s(x,y)&=\frac{\sigma_{xy}+C_3}{\sigma_x\sigma_y+C_3}, C_3=C_2/2.\\ + MSSSIM(x,y)&=l^alpha_M*{\prod_{1\leq j\leq M} (c^beta_j*s^gamma_j)}. - def _calculate_mean_ssim(self, x, y, kernel, max_val, k1, k2): - """calculate mean ssim""" - c1 = (k1 * max_val) * (k1 * max_val) - c2 = (k2 * max_val) * (k2 * max_val) + Args: + max_val (Union[int, float]): The dynamic range of the pixel values (255 for 8-bit grayscale images). + Default: 1.0. + power_factors (Union[tuple, list]): Iterable of weights for each of the scales. + Default: (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). Default values obtained by Wang et al. + filter_size (int): The size of the Gaussian filter. Default: 11. + filter_sigma (float): The standard deviation of Gaussian kernel. Default: 1.5. + k1 (float): The constant used to generate c1 in the luminance comparison function. Default: 0.01. + k2 (float): The constant used to generate c2 in the contrast comparison function. Default: 0.03. - # SSIM luminance formula - # (2 * mean_{x} * mean_{y} + c1) / (mean_{x}**2 + mean_{y}**2 + c1) - mean_x = self.mean(x, kernel) - mean_y = self.mean(y, kernel) - square_sum = F.square(mean_x)+F.square(mean_y) - luminance = (2*mean_x*mean_y+c1)/(square_sum+c1) + Inputs: + - **img1** (Tensor) - The first image batch with format 'NCHW'. It should be the same shape and dtype as img2. + - **img2** (Tensor) - The second image batch with format 'NCHW'. It should be the same shape and dtype as img1. - # SSIM contrast*structure formula (when c3 = c2/2) - # (2 * conv_{xy} + c2) / (conv_{xx} + conv_{yy} + c2), equals to - # (2 * (mean_{xy} - mean_{x}*mean_{y}) + c2) / (mean_{xx}-mean_{x}**2 + mean_{yy}-mean_{y}**2 + c2) - mean_xy = self.mean(x*y, kernel) - mean_square_add = self.mean(F.square(x)+F.square(y), kernel) + Outputs: + Tensor, has the same dtype as img1. It is a 1-D tensor with shape N, where N is the batch num of img1. - cs = (2*(mean_xy-mean_x*mean_y)+c2)/(mean_square_add-square_sum+c2) + Examples: + >>> net = nn.MSSSIM(power_factors=(0.033, 0.033, 0.033)) + >>> img1 = Tensor(np.random.random((1,3,128,128))) + >>> img2 = Tensor(np.random.random((1,3,128,128))) + >>> msssim = net(img1, img2) + """ + def __init__(self, max_val=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), filter_size=11, + filter_sigma=1.5, k1=0.01, k2=0.03): + super(MSSSIM, self).__init__() + validator.check_value_type('max_val', max_val, [int, float], self.cls_name) + validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) + self.max_val = max_val + validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name) + self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) + self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) + self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) + self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) + window = _create_window(filter_size, filter_sigma) + self.level = len(power_factors) + self.conv = [] + for i in range(self.level): + self.conv.append(_conv2d(1, 1, filter_size, Tensor(window))) + self.conv[i].weight.requires_grad = False + self.multi_convs_list = CellList(self.conv) + self.weight_tensor = Tensor(power_factors, mstype.float32) + self.avg_pool = AvgPool2d(kernel_size=2, stride=2, pad_mode='valid') + self.relu = ReLU() + self.reduce_mean = P.ReduceMean() + self.prod = P.ReduceProd() + self.pow = P.Pow() + self.pack = P.Pack(axis=-1) + self.concat = P.Concat(axis=1) - # SSIM formula - # luminance * cs - ssim = luminance*cs + def construct(self, img1, img2): + _check_input_4d(F.shape(img1), "img1", self.cls_name) + _check_input_4d(F.shape(img2), "img2", self.cls_name) + P.SameTypeShape()(img1, img2) + max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) + img1 = _convert_img_dtype_to_float32(img1, self.max_val) + img2 = _convert_img_dtype_to_float32(img2, self.max_val) - mean_ssim = P.ReduceMean()(ssim, (-3, -2, -1)) + c1 = (self.k1 * max_val) ** 2 + c2 = (self.k2 * max_val) ** 2 - return mean_ssim + sim = () + mcs = () - def _fspecial_gauss(self, filter_size, filter_sigma): - """get gauss kernel""" - filter_size, g = _gauss_kernel_helper(filter_size) + for i in range(self.level): + sim, cs = _compute_multi_channel_loss(c1, c2, img1, img2, + self.multi_convs_list[i], self.concat, self.reduce_mean) + mcs += (self.relu(cs),) + img1, img2 = _downsample(img1, img2, self.avg_pool) - square_sigma_scale = -0.5/(filter_sigma * filter_sigma) - g = g*square_sigma_scale - g = F.reshape(g, (1, -1))+F.reshape(g, (-1, 1)) - g = F.reshape(g, (1, -1)) - g = P.Softmax()(g) - ret = F.reshape(g, (1, 1, filter_size, filter_size)) - return ret + mcs = mcs[0:-1:1] + mcs_and_ssim = self.pack(mcs + (self.relu(sim),)) + mcs_and_ssim = self.pow(mcs_and_ssim, self.weight_tensor) + ms_ssim = self.prod(mcs_and_ssim, -1) + loss = self.reduce_mean(ms_ssim, -1) + return loss class PSNR(Cell): r""" diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 1ecb20056e..ddcaf2da6b 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -55,7 +55,7 @@ class ReduceLogSumExp(Cell): Examples: >>> input_x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32)) - >>> op = P.ReduceLogSumExp(keep_dims=True) + >>> op = nn.ReduceLogSumExp(keep_dims=True) >>> output = op(input_x, 1) """ @@ -132,23 +132,19 @@ class Range(Cell): class LinSpace(Cell): r""" - Generates values in an interval. And return the corresponding interpolation accroding to assist. + Generates values in an interval. Args: - - **start** (Union[int, float]) - The start of interval, With shape of 0-D. - - **stop** (Union[int, float]) - The end of interval, With shape of 0-D. - - **num** (int) - ticks number in the interval, the ticks include start and stop value. - With shape of 0-D. + start (Union[int, float]): The start of interval. With shape of 0-D. + stop (Union[int, float]): The end of interval. With shape of 0-D. + num (int): ticks number in the interval, the ticks include start and stop value. With shape of 0-D. Outputs: Tensor, With type same as `start`. The shape is 1-D with length of `num`. Examples: - >>> linspace = nn.LinSpace() - >>> start = Tensor(1, mindspore.float32) - >>> stop = Tensor(10, mindspore.float32) - >>> num = Tensor(5, mindspore.int32) - >>> output = linspace(start, stop, num) + >>> linspace = nn.LinSpace(1, 10, 5) + >>> output = linspace() [1, 3.25, 5.5, 7.75, 10] """ diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index d6c920b620..05e5e54b96 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -84,13 +84,14 @@ class _BatchNorm(Cell): self.dtype = P.DType() self.reshape = P.Reshape() self.is_ascend = context.get_context("device_target") == "Ascend" + self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE self.momentum = 1.0 - momentum if context.get_context("enable_ge"): self.is_ge_backend = True else: self.is_ge_backend = False - if self.is_ge_backend or self.is_ascend: + if self.is_graph_mode and (self.is_ge_backend or self.is_ascend): self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps) else: @@ -152,7 +153,7 @@ class _BatchNorm(Cell): if self.is_ge_backend and self.is_global: axes, re_shape = _shape_infer(F.shape(x), self.num_features) y = self._global_sync(x, axes, re_shape) - elif self.is_ge_backend or self.is_ascend: + elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend): if self.is_global: axes, re_shape = _shape_infer(F.shape(x), self.num_features) y = self._global_sync(x, axes, re_shape) @@ -587,7 +588,7 @@ class GroupNorm(Cell): """calculate groupnorm output""" batch, channel, height, width = self.shape(x) _channel_check(channel, self.num_channels) - x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups)) + x = self.reshape(x, (batch, self.num_groups, -1)) mean = self.reduce_mean(x, 2) var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1) std = self.sqrt(var + self.eps) diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index f0c82937c5..63cdedbfe9 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -17,6 +17,7 @@ from functools import partial import numpy as np +from mindspore import nn import mindspore.common.dtype as mstype from mindspore.ops import operations as P from mindspore.ops import functional as F @@ -41,8 +42,7 @@ __all__ = [ 'Conv2dBatchNormQuant', 'Conv2dQuant', 'DenseQuant', - 'ReLUQuant', - 'ReLU6Quant', + 'ActQuant', 'HSwishQuant', 'HSigmoidQuant', 'TensorAddQuant', @@ -375,9 +375,10 @@ class FakeQuantWithMinMax(Cell): def extend_repr(self): s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ - 'quant_delay={}, min_init={}, max_init={}'.format( - self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel, - self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init) + 'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range, + self.ema, self.ema_decay, self.per_channel, + self.channel_axis, self.num_channels, self.quant_delay, + self.min_init, self.max_init) return s def construct(self, x): @@ -540,10 +541,12 @@ class Conv2dBatchNormQuant(Cell): def extend_repr(self): s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ 'pad_mode={}, padding={}, dilation={}, group={}, ' \ - 'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format( - self.in_channels, self.out_channels, self.kernel_size, self.stride, - self.pad_mode, self.padding, self.dilation, self.group, - self.fake, self.freeze_bn, self.momentum, self.quant_delay) + 'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels, + self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, + self.group, + self.fake, self.freeze_bn, self.momentum, + self.quant_delay) return s def construct(self, x): @@ -685,10 +688,9 @@ class Conv2dQuant(Cell): def extend_repr(self): s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ 'pad_mode={}, padding={}, dilation={}, group={}, ' \ - 'has_bias={}, quant_delay={}'.format( - self.in_channels, self.out_channels, self.kernel_size, self.stride, - self.pad_mode, self.padding, self.dilation, self.group, - self.has_bias, self.quant_delay) + 'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, self.group, + self.has_bias, self.quant_delay) return s @@ -799,76 +801,23 @@ class DenseQuant(Cell): class _QuantActivation(Cell): r""" - Base class for Quant activation function. Add Fake Quant OP after activation OP. + Base class for quantization aware training activation function. Add Fake Quant OP after activation OP. """ def get_origin(self): raise NotImplementedError -class ReLUQuant(_QuantActivation): +class ActQuant(_QuantActivation): r""" - ReLUQuant activation function. Add Fake Quant OP after Relu OP. + Quantization aware training activation function. - For a more Detailed overview of ReLU op. - - Args: - ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. - symmetric (bool): Quantization algorithm use symmetric or not. Default: False. - narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according by global step. Default: 0. - - Inputs: - - **x** (Tensor) - The input of ReLUQuant. - - Outputs: - Tensor, with the same type and shape as the `x`. - - Examples: - >>> relu_quant = nn.ReLUQuant() - >>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32) - >>> result = relu_quant(input_x) - """ - - def __init__(self, - ema_decay=0.999, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): - super(ReLUQuant, self).__init__() - self.fake_quant_act = FakeQuantWithMinMax(min_init=0, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) - self.relu = P.ReLU() - - def construct(self, x): - x = self.relu(x) - x = self.fake_quant_act(x) - return x - - def get_origin(self): - return self.relu - - -class ReLU6Quant(_QuantActivation): - r""" - ReLU6Quant activation function. - - Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op + Add Fake Quant OP after activation. Not Recommand to used these cell for Fake Quant Op Will climp the max range of the activation and the relu6 do the same operation. For a more Detailed overview of ReLU6 op. Args: + activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. @@ -883,19 +832,20 @@ class ReLU6Quant(_QuantActivation): Tensor, with the same type and shape as the `x`. Examples: - >>> relu6_quant = nn.ReLU6Quant(4, 1) + >>> act_quant = nn.ActQuant(4, 1) >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32) - >>> result = relu6_quant(input_x) + >>> result = act_quant(input_x) """ def __init__(self, + activation, ema_decay=0.999, per_channel=False, num_bits=8, symmetric=False, narrow_range=False, quant_delay=0): - super(ReLU6Quant, self).__init__() + super(ActQuant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, max_init=6, ema=True, @@ -905,15 +855,15 @@ class ReLU6Quant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - self.relu6 = P.ReLU6() + self.act = activation() def construct(self, x): - x = self.relu6(x) + x = self.act(x) x = self.fake_quant_act(x) return x def get_origin(self): - return self.relu6 + return self.act class HSwishQuant(_QuantActivation): @@ -923,6 +873,7 @@ class HSwishQuant(_QuantActivation): For a more Detailed overview of HSwish op. Args: + activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. @@ -943,6 +894,7 @@ class HSwishQuant(_QuantActivation): """ def __init__(self, + activation, ema_decay=0.999, per_channel=False, num_bits=8, @@ -968,7 +920,10 @@ class HSwishQuant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - self.act = P.HSwish() + if issubclass(activation, nn.HSwish): + self.act = activation() + else: + raise ValueError("Activation should be `nn.HSwish`") def construct(self, x): x = self.fake_quant_act_before(x) @@ -987,6 +942,7 @@ class HSigmoidQuant(_QuantActivation): For a more Detailed overview of HSigmoid op. Args: + activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. per_channel (bool): Quantization granularity based on layer or on channel. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. @@ -1007,6 +963,7 @@ class HSigmoidQuant(_QuantActivation): """ def __init__(self, + activation, ema_decay=0.999, per_channel=False, num_bits=8, @@ -1032,7 +989,10 @@ class HSigmoidQuant(_QuantActivation): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - self.act = P.HSigmoid() + if issubclass(activation, nn.HSwish): + self.act = activation() + else: + raise ValueError("Activation should be `nn.HSigmoid`") def construct(self, x): x = self.fake_quant_act_before(x) @@ -1209,9 +1169,9 @@ class QuantBlock(Cell): return x def extend_repr(self): - str_info = f'quant={self.quant}, core_op={type(self.core_op)}' + str_info = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]' if self.has_bias: - str_info = str_info + f', bias={self.bias}' + str_info = str_info + f', bias=shape[{self.bias.shape}]' if self.has_act: str_info = str_info + f', activation={self.activation}' str_info = str_info + f', dequant={self.dequant}' diff --git a/mindspore/nn/optim/__init__.py b/mindspore/nn/optim/__init__.py index f1dac586bc..538c400067 100644 --- a/mindspore/nn/optim/__init__.py +++ b/mindspore/nn/optim/__init__.py @@ -20,14 +20,14 @@ The optimizer is used to calculate and update the gradients. """ from .optimizer import Optimizer from .momentum import Momentum -from .adam import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR +from .adam import Adam, PSAdam, AdamWeightDecay, AdamWeightDecayDynamicLR from .lamb import Lamb from .sgd import SGD from .lars import LARS -from .ftrl import FTRL +from .ftrl import FTRL, PSFTRL from .rmsprop import RMSProp from .proximal_ada_grad import ProximalAdagrad from .lazyadam import LazyAdam -__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', - 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad'] +__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'PSAdam', 'AdamWeightDecay', 'LazyAdam', + 'AdamWeightDecayDynamicLR', 'Lamb', 'SGD', 'FTRL', 'PSFTRL', 'RMSProp', 'ProximalAdagrad'] diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index b73c284aab..eb6e64074f 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -27,6 +27,7 @@ from mindspore._checkparam import Rel from .optimizer import Optimizer _adam_opt = C.MultitypeFuncGraph("adam_opt") +_adam_push_pull_opt = C.MultitypeFuncGraph("_adam_push_pull_opt") @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", @@ -129,6 +130,31 @@ def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, b eps, gradient)) return success +@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tuple", "Tensor", "Tensor", "Tensor") +def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, + moment1, moment2): + """Apply sparse adam optimizer by push and pull to the weight parameter when the gradient is sparse.""" + success = True + op_shape = P.Shape() + shapes = (op_shape(params), op_shape(moment1), op_shape(moment2), + op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), + op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0])) + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, + eps, gradient[1], gradient[0]), shapes), params)) + return success + + +@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") +def _run_push_pull_opt_with_one_number(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, + moment1, moment2): + """Apply adam optimizer by push and pull to the weight parameter using Tensor.""" + success = True + op_shape = P.Shape() + success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), + (op_shape(params), op_shape(moment1), op_shape(moment2))), params)) + return success class Adam(Optimizer): r""" @@ -162,8 +188,8 @@ class Adam(Optimizer): To improve parameter groups performance, the customized order of parameters can be supported. - The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the - `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. + The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. Args: @@ -274,6 +300,51 @@ class Adam(Optimizer): gradients, params, moment1, moment2) return success +class PSAdam(Optimizer): + '''The same usage as Adam optimizer except the parameters are set PS mode.''' + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, + use_nesterov=False, weight_decay=0.0, loss_scale=1.0): + super(PSAdam, self).__init__(learning_rate, params, weight_decay, loss_scale) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) + + self.beta1 = Tensor(beta1, mstype.float32) + self.beta2 = Tensor(beta2, mstype.float32) + self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power") + self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power") + self.eps = Tensor(eps, mstype.float32) + + self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') + self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') + + self.hyper_map = C.HyperMap() + self.push = P.Push("Adam", [0, 1, 2]) + self.push.add_prim_attr("primitive_target", "CPU") + self.pull = P.Pull() + self.pull.add_prim_attr("primitive_target", "CPU") + + def construct(self, gradients): + params = self.parameters + moment1 = self.moment1 + moment2 = self.moment2 + gradients = self.decay_weight(gradients) + gradients = self.scale_grad(gradients) + lr = self.get_lr() + + beta1_power = self.beta1_power * self.beta1 + self.beta1_power = beta1_power + beta2_power = self.beta2_power * self.beta2 + self.beta2_power = beta2_power + if self.is_group_lr: + success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power, + self.beta1, self.beta2, self.eps), + lr, gradients, params, moment1, moment2) + else: + success = self.map_(F.partial(_adam_push_pull_opt, self.push, self.pull, beta1_power, beta2_power, + self.beta1, self.beta2, self.eps, lr), + gradients, params, moment1, moment2) + return success class AdamWeightDecay(Optimizer): """ @@ -388,7 +459,7 @@ class AdamWeightDecayDynamicLR(Optimizer): beta2=0.999, eps=1e-6, weight_decay=0.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): + decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): super(AdamWeightDecayDynamicLR, self).__init__(0.0, params) if self.is_group: raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index b2954430b4..dd2ebddfa7 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -22,6 +22,7 @@ from mindspore._checkparam import Rel from .optimizer import Optimizer, _apply_decay, _grad_scale _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") +_ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt") @_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor", @@ -41,6 +42,26 @@ def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gra success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) return success +@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", + "Tensor", "Tensor") +def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient, + weight, moment): + success = True + op_shape = P.Shape() + shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0])) + success = F.depend(success, pull(push((gradient[1], gradient[0]), shapes), weight)) + return success + + +@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", + "Tensor", "Tensor") +def _tensor_run_push_pull_opt_with_one_number(push, pull, learning_rate, l1, l2, lr_power, linear, gradient, + weight, moment): + success = True + op_shape = P.Shape() + success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power), + (op_shape(weight), op_shape(moment), op_shape(linear))), weight)) + return success def _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay=0.0, prim_name=None): """Check param.""" @@ -72,8 +93,8 @@ class FTRL(Optimizer): `_ for engineering document. Note: - The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the - `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. + The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. Args: @@ -131,3 +152,37 @@ class FTRL(Optimizer): success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, lr, self.l1, self.l2, self.lr_power), linear, grads, params, moments) return success + +class PSFTRL(Optimizer): + def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0, + use_locking=False, loss_scale=1.0, weight_decay=0.0): + super(PSFTRL, self).__init__(learning_rate, params, loss_scale=loss_scale) + if self.is_group: + raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") + _check_param(initial_accum, lr_power, l1, l2, use_locking, weight_decay, self.cls_name) + self.moments = self.parameters.clone(prefix="moments", init=initial_accum) + self.linear = self.parameters.clone(prefix="linear", init='zeros') + self.l1 = l1 + self.l2 = l2 + self.lr_power = lr_power + self.weight_decay = weight_decay + self.decay_tf = tuple((lambda: True)() for x in self.parameters) + + self.hyper_map = C.HyperMap() + self.push = P.Push("Ftrl", [0, 1, 2]) + self.push.add_prim_attr("primitive_target", "CPU") + self.pull = P.Pull() + self.pull.add_prim_attr("primitive_target", "CPU") + + def construct(self, grads): + params = self.parameters + moments = self.moments + linear = self.linear + lr = self.learning_rate + if self.weight_decay > 0.0: + grads = self.hyper_map(F.partial(_apply_decay, self.weight_decay), self.decay_tf, params, grads) + + grads = self.scale_grad(grads) + success = self.map_(F.partial(_ftrl_push_pull_opt, self.push, self.pull, lr, self.l1, self.l2, self.lr_power), + linear, grads, params, moments) + return success diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py index 4b97d2eb20..7905398437 100644 --- a/mindspore/nn/optim/lazyadam.py +++ b/mindspore/nn/optim/lazyadam.py @@ -91,8 +91,8 @@ class LazyAdam(Optimizer): value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters. - The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the - `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. + The sparse behavior, to be notice, is not equivalent to the original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py index 75f3994e2a..25cf438034 100644 --- a/mindspore/nn/optim/proximal_ada_grad.py +++ b/mindspore/nn/optim/proximal_ada_grad.py @@ -59,8 +59,8 @@ class ProximalAdagrad(Optimizer): `_. Note: - The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the - `sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse + The sparse strategy is applied while the SparseGatherV2 operator being used for forward network. + The sparse feature is under continuous development. The sparse behavior is currently performed on the CPU. Args: @@ -71,7 +71,7 @@ class ProximalAdagrad(Optimizer): l1 (float): l1 regularization strength, must be greater than or equal to zero. Default: 0.0. l2 (float): l2 regularization strength, must be greater than or equal to zero. Default: 0.0. use_locking (bool): If True use locks for update operation. Default: False. - loss_scale (float): Value for the loss scale. It should be equal to or greater than 1.0. Default: 1.0. + loss_scale (float): Value for the loss scale. It should be greater than 0.0. Default: 1.0. wegith_decay (float): Weight decay value to multiply weight, must be zero or positive value. Default: 0.0. Inputs: diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index 8e8885aff7..c4d3347038 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -171,7 +171,7 @@ class RMSProp(Optimizer): self.opt = P.ApplyRMSProp(use_locking) self.momentum = momentum - self.ms = self.parameters.clone(prefix="mean_square", init='zeros') + self.ms = self.parameters.clone(prefix="mean_square", init='ones') self.moment = self.parameters.clone(prefix="moment", init='zeros') self.hyper_map = C.HyperMap() self.epsilon = epsilon diff --git a/mindspore/ops/__init__.py b/mindspore/ops/__init__.py index b73d683284..7265b3c98b 100644 --- a/mindspore/ops/__init__.py +++ b/mindspore/ops/__init__.py @@ -32,7 +32,7 @@ Note: from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry -from .op_info_register import op_info_register, AkgRegOp, AiCPURegOp, TBERegOp, DataType +from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType from .primitive import constexpr from .._c_expression import signature_rw, signature_kind @@ -42,6 +42,6 @@ __primitive__ = [ ] __all__ = ["get_vm_impl_fn", "vm_impl_registry", - "op_info_register", "AkgRegOp", "AiCPURegOp", "TBERegOp", "DataType", + "op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType", "constexpr"] __all__.extend(__primitive__) diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index d1494bc051..b1a3e1d98b 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -191,13 +191,12 @@ def get_bprop_tile(self): return bprop -@bprop_getters.register(inner.EmbeddingLookup) +@bprop_getters.register(P.EmbeddingLookup) def get_bprop_embedding_lookup(self): """Generate bprop for EmbeddingLookup""" sub_op = P.Sub() reshape_op = P.Reshape() - host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU') - def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout): + def bprop_sparse(x, indices, offset, out, dout): x_shp = shape_op(x) new_indices = sub_op(indices, offset) # Reshape the 'new_indices' @@ -205,17 +204,9 @@ def get_bprop_embedding_lookup(self): new_indices = reshape_op(new_indices, new_indices_shape_changed) x_shp_tail = x_shp[1:] actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail - if reduce_scatter_flag is True: - # On host - elu_grad = G.EmbeddingLookupCommGrad() - actual_dout = elu_grad(dout, split_num) - # Reshape the 'actual_dout' on host - actual_dout = host_reshape(actual_dout, actual_dout_shape_changed) - else: - # Reshape the 'actual_dout' on device - actual_dout = reshape_op(dout, actual_dout_shape_changed) - return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \ - zeros_like(reduce_scatter_flag), zeros_like(split_num) + # Reshape the 'actual_dout' on device + actual_dout = reshape_op(dout, actual_dout_shape_changed) + return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) return bprop_sparse @@ -248,19 +239,37 @@ def get_bprop_transpose(self): return bprop +@constexpr +def _concat_grad_uniform(input_shapes, input_nums): + """Helper function for bprop of Concat""" + is_uniform = True + for i in range(1, input_nums): + if input_shapes[i-1] != input_shapes[i]: + is_uniform = False + break + return is_uniform + @bprop_getters.register(P.Concat) def get_bprop_concat(self): """Generate bprop for Concat""" axis = self.axis + is_ascend = context.get_context('device_target') == "Ascend" def bprop(x, out, dout): dx = () out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x) - for i in range(F.tuple_len(x)): - slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i])) - dx = dx + (slice_out,) + input_nums = F.tuple_len(x) + input_shapes = () + for i in range(input_nums): + input_shapes = input_shapes + (shape_op(x[i]),) + is_uniform = _concat_grad_uniform(input_shapes, input_nums) + if is_uniform and is_ascend: + dx = P.Split(axis, input_nums)(dout) + else: + for i in range(input_nums): + slice_out = P.Slice()(dout, out_offset[i], input_shapes[i]) + dx = dx + (slice_out,) return (dx,) - return bprop @@ -644,6 +653,36 @@ def get_bprop_unsorted_segment_min(self): return bprop +@bprop_getters.register(P.UnsortedSegmentProd) +def get_bprop_unsorted_segment_prod(self): + """Generate bprop for UnsortedSegmentProd""" + equal = P.Equal() + cast = P.Cast() + select = P.Select() + gather = P.GatherV2() + greater = P.Greater() + ones_like = P.OnesLike() + maximum = P.Maximum() + unsorted_segment_prod = P.UnsortedSegmentProd() + + def bprop(x, segment_ids, num_segments, out, dout): + is_zero = equal(x, 0) + num_zero = unsorted_segment_sum(cast(is_zero, mstype.int32), segment_ids, num_segments) + grad = select(greater(num_zero, 1), zeros_like(dout), dout) + non_zero_data = select(is_zero, ones_like(x), x) + non_zero_prod = unsorted_segment_prod(non_zero_data, segment_ids, num_segments) + zero_clipped_indices = maximum(segment_ids, zeros_like(segment_ids)) + gathered_prod = gather(out, zero_clipped_indices, 0) + gathered_non_zero_prod = gather(non_zero_prod, zero_clipped_indices, 0) + prod_divided_by_x = gathered_prod / x + partial_derivative = select(is_zero, gathered_non_zero_prod, prod_divided_by_x) + gathered_grad, _, _ = _GatherDropNegatives(grad, segment_ids, zero_clipped_indices) + dx = gathered_grad * partial_derivative + return dx, zeros_like(segment_ids), zeros_like(num_segments) + + return bprop + + @bprop_getters.register(P.SpaceToBatch) def get_bprop_space_to_batch(self): """Generate bprop for SpaceToBatch""" diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 3a86a05943..61c7e40960 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -760,6 +760,19 @@ def get_bprop_ctc_loss(self): return bprop +@bprop_getters.register(P.CTCLossV2) +def get_bprop_ctc_loss_v2(self): + """Grad definition for `CTCLossV2` operation""" + expand = P.ExpandDims() + + def bprop(inputs, labels, input_lengths, labels_lengths, out, dout): + grad_loss = out[1] + grad = grad_loss * expand(dout[0], -1) + return grad, zeros_like(labels), zeros_like(input_lengths), zeros_like(labels_lengths) + + return bprop + + @bprop_getters.register(P.BasicLSTMCell) def get_bprop_basic_lstm_cell(self): """Grad definition for `BasicLSTMCell` operation.""" diff --git a/mindspore/ops/_op_impl/__init__.py b/mindspore/ops/_op_impl/__init__.py index 65a12cd73c..59729f833f 100644 --- a/mindspore/ops/_op_impl/__init__.py +++ b/mindspore/ops/_op_impl/__init__.py @@ -17,7 +17,7 @@ import platform from .aicpu import * if "Windows" not in platform.system(): - from .akg.gpu import * + from .akg import * from .tbe import * __all__ = [] diff --git a/mindspore/ops/_op_impl/akg/__init__.py b/mindspore/ops/_op_impl/akg/__init__.py index fd86dbf999..c4c70b7aa1 100644 --- a/mindspore/ops/_op_impl/akg/__init__.py +++ b/mindspore/ops/_op_impl/akg/__init__.py @@ -13,77 +13,6 @@ # limitations under the License. # ============================================================================ -"""autodiff ops""" -from .abs import _abs_akg -from .add_n import _add_n_akg -from .add import _add_akg -from .apply_momentum import _apply_momentum_akg -from .assign import _assign_akg -from .inplace_assign import _inplace_assign_akg -from .assign_add import _assign_add_akg -from .bias_add_grad import _bias_add_grad_akg -from .bias_add import _bias_add_akg -from .cast import _cast_akg -from .clear_zero import _clear_zero_akg -from .conv_bn1 import _conv_bn1_akg -from .conv2d_backprop_filter import _conv2d_backprop_filter_akg -from .conv2d_backprop_input import _conv2d_backprop_input_akg -from .conv2d import _conv2d_akg -from .div import _div_akg -from .equal_count import _equal_count_akg -from .exp import _exp_akg -from .five2four import _five2four_akg -from .four2five import _four2five_akg -from .fused_batch_norm_grad import _fused_batch_norm_grad_akg -from .fused_batch_norm_infer import _fused_batch_norm_infer_akg -from .fused_batch_norm import _fused_batch_norm_akg -from .fused_bn1_grad import _bn1_grad_akg -from .fused_bn1 import _fused_bn1_akg -from .fused_bn2_grad import _bn2_grad_akg -from .fused_bn2 import _fused_bn2_akg -from .fused_bn3_grad import _bn3_grad_akg -from .fused_bn3 import _fused_bn3_akg -from .gather_v2 import _gather_v2_akg -from .less import _less_akg -from .log import _log_akg -from .matmul import _matmul_akg -from .batchmatmul import _batchmatmul_akg -from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg -from .max_pool_with_argmax import _max_pool_with_argmax_akg -from .max import _max_akg -from .maximum import _maximum_akg -from .mean_grad import _mean_grad_akg -from .mean import _mean_akg -from .minimum import _minimum_akg -from .mul import _mul_akg -from .neg import _neg_akg -from .one_hot import _one_hot_akg -from .pow import _power_akg -from .real_div import _real_div_akg -from .reciprocal import _reciprocal_akg -from .reduce_max import _reduce_max_akg -from .reduce_mean import _reduce_mean_akg -from .reduce_sum import _reduce_sum_akg -from .relu_grad import _relu_grad_akg -from .relu import _relu_akg -from .reshape import _reshape_akg -from .round import _round_akg -from .rsqrt import _rsqrt_akg -from .select import _select_akg -from .softmax import _softmax_akg -from .sparse_softmax_cross_entropy_with_logits import _sparse_softmax_cross_entropy_with_logits_akg -from .sqrt import _sqrt_akg -from .strided_slice import _strided_slice_akg -from .sub import _sub_akg -from .sum import _sum_akg -from .tile import _tile_akg -from .zeros_like import _zeros_like_akg -from .argmax import _argmax_akg -from .floordiv import _floor_div_akg -from .equal import _equal_akg -from .greater_equal import _greater_equal_akg -from .less_equal import _less_equal_akg -from .expand_dims import _expand_dims_akg -from .greater import _greater_akg -from .equiv_format import _equiv_format_akg +"""akg ops""" +from . import ascend from . import gpu diff --git a/mindspore/ops/_op_impl/akg/abs.py b/mindspore/ops/_op_impl/akg/abs.py deleted file mode 100644 index 8c08f405da..0000000000 --- a/mindspore/ops/_op_impl/akg/abs.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Abs op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Abs", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _abs_akg(): - """Abs AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/add.py b/mindspore/ops/_op_impl/akg/add.py deleted file mode 100644 index 60544ea1c7..0000000000 --- a/mindspore/ops/_op_impl/akg/add.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""TensorAdd op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "TensorAdd", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32", - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0", - "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32", - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0", - "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32", - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0", - "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _add_akg(): - """TensorAdd AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/add_n.py b/mindspore/ops/_op_impl/akg/add_n.py deleted file mode 100644 index 53320f752e..0000000000 --- a/mindspore/ops/_op_impl/akg/add_n.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""AddN op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "AddN", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16","float32","float16","float32", "float16", "float32", - "float16","float32" - ], - "format": [ - "DefaultFormat","DefaultFormat","NC1HWC0","NC1HWC0", "FracZ", "FracZ", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "dynamic", - "name": "inputs" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16","float32","float16","float32", "float16", "float32", - "float16","float32" - ], - "format": [ - "DefaultFormat","DefaultFormat","NC1HWC0","NC1HWC0", "FracZ", "FracZ", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _add_n_akg(): - """AddN AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/apply_momentum.py b/mindspore/ops/_op_impl/akg/apply_momentum.py deleted file mode 100644 index 7160571882..0000000000 --- a/mindspore/ops/_op_impl/akg/apply_momentum.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ApplyMomentum op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ApplyMomentum", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "use_nesterov", - "param_type": "optional", - "type": "bool" - }, - { - "name": "gradient_scale", - "param_type": "optional", - "type": "float" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32","float32","float32" - ], - "format": [ - "DefaultFormat","NC1HWC0","FracZ" - ], - "name": "variable" - }, - { - "index": 1, - "dtype": [ - "float32","float32","float32" - ], - "format": [ - "DefaultFormat","NC1HWC0","FracZ" - ], - "name": "accumulation" - }, - { - "index": 2, - "dtype": [ - "float32","float32","float32" - ], - "format": [ - "DefaultFormat","DefaultFormat","DefaultFormat" - ], - "name": "learning_rate" - }, - { - "index": 3, - "dtype": [ - "float32","float32","float32" - ], - "format": [ - "DefaultFormat","NC1HWC0","FracZ" - ], - "name": "gradient" - }, - { - "index": 4, - "dtype": [ - "float32","float32","float32" - ], - "format": [ - "DefaultFormat","DefaultFormat","DefaultFormat" - ], - "name": "momentum" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32","float32","float32" - ], - "format": [ - "DefaultFormat","NC1HWC0","FracZ" - ], - "name": "output" - } - ] -}""") -def _apply_momentum_akg(): - """ApplyMomentum AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/argmax.py b/mindspore/ops/_op_impl/akg/argmax.py deleted file mode 100644 index b04862cbeb..0000000000 --- a/mindspore/ops/_op_impl/akg/argmax.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Argmax op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Argmax", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "axis", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32", "int32", "int32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _argmax_akg(): - """Argmax AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/ascend/__init__.py b/mindspore/ops/_op_impl/akg/ascend/__init__.py new file mode 100644 index 0000000000..a4d7aec7d0 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""__init__""" + +from .add import _add_akg +from .batchmatmul import _batchmatmul_akg +from .cast import _cast_akg +from .expand_dims import _expand_dims_akg +from .greater import _greater_akg +from .inplace_assign import _inplace_assign_akg +from .maximum import _maximum_akg +from .minimum import _minimum_akg +from .mul import _mul_akg +from .real_div import _real_div_akg +from .rsqrt import _rsqrt_akg +from .select import _select_akg +from .sqrt import _sqrt_akg +from .sub import _sub_akg diff --git a/mindspore/ops/_op_impl/akg/ascend/add.py b/mindspore/ops/_op_impl/akg/ascend/add.py new file mode 100644 index 0000000000..d8689eed6d --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/add.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""TensorAdd op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("TensorAdd") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ + .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \ + .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \ + .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \ + .dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _add_akg(): + """TensorAdd Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/batchmatmul.py b/mindspore/ops/_op_impl/akg/ascend/batchmatmul.py new file mode 100644 index 0000000000..d7815c15e6 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/batchmatmul.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""BatchMatMul op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("BatchMatMul") \ + .fusion_type("OPAQUE") \ + .input(0, "x1") \ + .input(1, "x2") \ + .output(0, "output") \ + .attr("transpose_a", "optional", "bool") \ + .attr("transpose_b", "optional", "bool") \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _batchmatmul_akg(): + """BatchMatMul AKG register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/cast.py b/mindspore/ops/_op_impl/akg/ascend/cast.py new file mode 100644 index 0000000000..1b874352f8 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/cast.py @@ -0,0 +1,46 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cast op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Cast") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .output(0, "output") \ + .attr("dst_type", "required", "str") \ + .dtype_format(DT.F16_Default, DT.F32_Default) \ + .dtype_format(DT.F16_Default, DT.I32_Default) \ + .dtype_format(DT.F32_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.I32_Default) \ + .dtype_format(DT.I32_Default, DT.F16_Default) \ + .dtype_format(DT.I32_Default, DT.F32_Default) \ + .dtype_format(DT.BOOL_Default, DT.F16_Default) \ + .dtype_format(DT.BOOL_Default, DT.F32_Default) \ + .dtype_format(DT.BOOL_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F32_5HD) \ + .dtype_format(DT.F32_5HD, DT.F16_5HD) \ + .dtype_format(DT.BOOL_5HD, DT.I32_5HD) \ + .dtype_format(DT.BOOL_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F32_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.BOOL_FracNZ, DT.I32_FracNZ) \ + .dtype_format(DT.BOOL_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _cast_akg(): + """Cast Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/expand_dims.py b/mindspore/ops/_op_impl/akg/ascend/expand_dims.py new file mode 100644 index 0000000000..24faf241aa --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/expand_dims.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ExpandDims op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("ExpandDims") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .output(0, "y") \ + .attr("axis", "required", "int") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default) \ + .get_op_info() + + +@op_info_register(op_info) +def _expand_dims_akg(): + """ExpandDims Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/greater.py b/mindspore/ops/_op_impl/akg/ascend/greater.py new file mode 100644 index 0000000000..14164c895b --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/greater.py @@ -0,0 +1,34 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Greater op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Greater") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.BOOL_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.BOOL_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.BOOL_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.BOOL_5HD) \ + .get_op_info() + + +@op_info_register(op_info) +def _greater_akg(): + """Greater Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/inplace_assign.py b/mindspore/ops/_op_impl/akg/ascend/inplace_assign.py new file mode 100644 index 0000000000..9f76706440 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/inplace_assign.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""InplaceAssign op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("InplaceAssign") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .input(2, "z") \ + .output(0, "output") \ + .attr("fake_output", "optional", "bool") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ + .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \ + .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \ + .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _inplace_assign_akg(): + """InplaceAssign Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/maximum.py b/mindspore/ops/_op_impl/akg/ascend/maximum.py new file mode 100644 index 0000000000..b57de7d15a --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/maximum.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Maximum op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Maximum") \ + .fusion_type("COMMREDUCE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ + .get_op_info() + + +@op_info_register(op_info) +def _maximum_akg(): + """Maximum Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/minimum.py b/mindspore/ops/_op_impl/akg/ascend/minimum.py new file mode 100644 index 0000000000..cdc0abfc6d --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/minimum.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Minimum op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Minimum") \ + .fusion_type("COMMREDUCE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \ + .dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _minimum_akg(): + """Minimum Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/mul.py b/mindspore/ops/_op_impl/akg/ascend/mul.py new file mode 100644 index 0000000000..ea21888b84 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/mul.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Mul op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Mul") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .attr("x_shape", "required", "listInt") \ + .attr("y_shape", "required", "listInt") \ + .attr("data_format", "required", "listStr") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \ + .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _mul_akg(): + """Mul Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/real_div.py b/mindspore/ops/_op_impl/akg/ascend/real_div.py new file mode 100644 index 0000000000..c7c3ad9eb6 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/real_div.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RealDiv op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("RealDiv") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _real_div_akg(): + """RealDiv Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/rsqrt.py b/mindspore/ops/_op_impl/akg/ascend/rsqrt.py new file mode 100644 index 0000000000..55cf876951 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/rsqrt.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Rsqrt op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Rsqrt") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD) \ + .get_op_info() + + +@op_info_register(op_info) +def _rsqrt_akg(): + """Rsqrt Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/select.py b/mindspore/ops/_op_impl/akg/ascend/select.py new file mode 100644 index 0000000000..67fee114ca --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/select.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Select op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Select") \ + .fusion_type("ELEMWISE") \ + .input(0, "condition") \ + .input(1, "x") \ + .input(2, "y") \ + .output(0, "output") \ + .dtype_format(DT.BOOL_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.BOOL_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.BOOL_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.BOOL_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.BOOL_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.BOOL_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ + .get_op_info() + + +@op_info_register(op_info) +def _select_akg(): + """Select Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/sqrt.py b/mindspore/ops/_op_impl/akg/ascend/sqrt.py new file mode 100644 index 0000000000..43f64b8973 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/sqrt.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Sqrt op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Sqrt") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD) \ + .get_op_info() + + +@op_info_register(op_info) +def _sqrt_akg(): + """Sqrt Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/ascend/sub.py b/mindspore/ops/_op_impl/akg/ascend/sub.py new file mode 100644 index 0000000000..62001b3f44 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/ascend/sub.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Sub op""" +from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT + +op_info = AkgAscendRegOp("Sub") \ + .fusion_type("ELEMWISE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \ + .dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \ + .dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \ + .dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \ + .dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \ + .dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \ + .dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \ + .dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \ + .dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \ + .dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \ + .dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \ + .dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \ + .get_op_info() + + +@op_info_register(op_info) +def _sub_akg(): + """Sub Akg register""" + return diff --git a/mindspore/ops/_op_impl/akg/assign.py b/mindspore/ops/_op_impl/akg/assign.py deleted file mode 100644 index e7c5a082bd..0000000000 --- a/mindspore/ops/_op_impl/akg/assign.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Assign op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Assign", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ" - ], - "name": "ref" - }, - { - "index": 1, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ" - ], - "name": "value" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ" - ], - "name": "output" - } - ] -}""") -def _assign_akg(): - """Assign AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/assign_add.py b/mindspore/ops/_op_impl/akg/assign_add.py deleted file mode 100644 index 7d0d345764..0000000000 --- a/mindspore/ops/_op_impl/akg/assign_add.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""AssignAdd op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "AssignAdd", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "ref" - }, - { - "index": 1, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "value" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _assign_add_akg(): - """AssignAdd AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/batchmatmul.py b/mindspore/ops/_op_impl/akg/batchmatmul.py deleted file mode 100644 index f5da71aa25..0000000000 --- a/mindspore/ops/_op_impl/akg/batchmatmul.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""BatchMatMul op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "BatchMatMul", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "transpose_a", - "param_type": "optional", - "type": "bool" - }, - { - "name": "transpose_b", - "param_type": "optional", - "type": "bool" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "x1" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "x2" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _batchmatmul_akg(): - """BatchMatMul AKG register""" - return diff --git a/mindspore/ops/_op_impl/akg/bias_add.py b/mindspore/ops/_op_impl/akg/bias_add.py deleted file mode 100644 index 74f2bf7bcf..0000000000 --- a/mindspore/ops/_op_impl/akg/bias_add.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""BiasAdd op""" - -from mindspore.ops.op_info_register import op_info_register - -@op_info_register("""{ - "op_name": "BiasAdd", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "data_format", - "param_type": "optional", - "type": "listStr" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16","float32","float16","float32","float16","float32" - ], - "format": [ - "NHWC","NHWC","NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16","float32","float16","float32","float16","float32" - ], - "format": [ - "NHWC","NHWC","NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat" - ], - "name": "b" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16","float32","float16","float32","float16","float32" - ], - "format": [ - "DefaultFormat","DefaultFormat","NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat" - ], - "name": "output" - } - ] -}""") -def _bias_add_akg(): - """BiasAddGrad AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/bias_add_grad.py b/mindspore/ops/_op_impl/akg/bias_add_grad.py deleted file mode 100644 index 7726af6692..0000000000 --- a/mindspore/ops/_op_impl/akg/bias_add_grad.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""BiasAddGrad op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "BiasAddGrad", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "data_format", - "param_type": "optional", - "type": "listStr" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16","float32","float16","float32","float16","float32" - ], - "format": [ - "NHWC","NHWC","NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat" - ], - "name": "dout" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16","float32","float16","float32","float16","float32" - ], - "format": [ - "DefaultFormat","DefaultFormat","NC1HWC0","NC1HWC0","DefaultFormat","DefaultFormat" - ], - "name": "output" - } - ] -}""") -def _bias_add_grad_akg(): - """BiasAddGrad AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/cast.py b/mindspore/ops/_op_impl/akg/cast.py deleted file mode 100644 index a78d4d87e4..0000000000 --- a/mindspore/ops/_op_impl/akg/cast.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Cast op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Cast", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "dst_type", - "param_type": "required", - "type": "str" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "bool", "bool", - "float16", "float32", "int32", "int32", - "bool", - "float16", "float32", "bool", "bool", - "float16", "float32", "bool", "bool" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", - "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", - "DefaultFormat", - "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32", "float16", "int32", "float16", - "int32", "int32", "float16", "float32", - "float32", - "float32", "float16", "int32", "float32", - "float32", "float16", "int32", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", - "DefaultFormat", "DefaultFormat", "DefaultFormat", "DefaultFormat", - "DefaultFormat", - "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _cast_akg(): - """Cast AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/clear_zero.py b/mindspore/ops/_op_impl/akg/clear_zero.py deleted file mode 100644 index 38bf35044f..0000000000 --- a/mindspore/ops/_op_impl/akg/clear_zero.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ClearZero op""" - -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ClearZero", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "pad_mod", - "param_type": "optional", - "type": "string" - }, - { - "name": "window", - "param_type": "optional", - "type": "int" - }, - { - "name": "pad", - "param_type": "optional", - "type": "int" - }, - { - "name": "stride", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - ] -}""") -def _clear_zero_akg(): - """MaxPoolGradWithArgmax AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/conv2d.py b/mindspore/ops/_op_impl/akg/conv2d.py deleted file mode 100644 index 709aca7001..0000000000 --- a/mindspore/ops/_op_impl/akg/conv2d.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Conv2D op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Conv2D", - "imply_type": "AutoDiff", - "fusion_type": "CONVLUTION", - "attr": [ - { - "name": "x_shape", - "param_type": "required", - "type": "listInt" - }, - { - "name": "w_shape", - "param_type": "required", - "type": "listInt" - }, - { - "name": "pad_list", - "param_type": "required", - "type": "listInt" - }, - { - "name": "stride", - "param_type": "optional", - "type": "int" - }, - { - "name": "dilation", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FracZ" - ], - "name": "w" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _conv2d_akg(): - """Conv2D AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/conv2d_backprop_filter.py b/mindspore/ops/_op_impl/akg/conv2d_backprop_filter.py deleted file mode 100644 index 1e4e4f1a1e..0000000000 --- a/mindspore/ops/_op_impl/akg/conv2d_backprop_filter.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Conv2DBackpropFilter op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Conv2DBackpropFilter", - "imply_type": "AutoDiff", - "fusion_type": "CONVLUTION", - "attr": [ - { - "name": "input_shape", - "param_type": "required", - "type": "listInt" - }, - { - "name": "filter_sizes", - "param_type": "required", - "type": "listInt" - }, - { - "name": "stride", - "param_type": "optional", - "type": "int" - }, - { - "name": "pad_list", - "param_type": "required", - "type": "listInt" - }, - { - "name": "dilation", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "out_backprop" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "input" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "FracZ" - ], - "name": "output" - } - ] -}""") -def _conv2d_backprop_filter_akg(): - """Conv2DBackpropFilter AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/conv2d_backprop_input.py b/mindspore/ops/_op_impl/akg/conv2d_backprop_input.py deleted file mode 100644 index 52c7f2e7b3..0000000000 --- a/mindspore/ops/_op_impl/akg/conv2d_backprop_input.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Conv2DBackpropInput op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Conv2DBackpropInput", - "imply_type": "AutoDiff", - "fusion_type": "CONVLUTION", - "attr": [ - { - "name": "input_sizes", - "param_type": "required", - "type": "listInt" - }, - { - "name": "filter_shape", - "param_type": "required", - "type": "listInt" - }, - { - "name": "stride", - "param_type": "optional", - "type": "int" - }, - { - "name": "pad_list", - "param_type": "required", - "type": "listInt" - }, - { - "name": "dilation", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "out_backprop" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FracZ" - ], - "name": "filter" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _conv2d_backprop_input_akg(): - """Conv2DBackpropInput AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/conv_bn1.py b/mindspore/ops/_op_impl/akg/conv_bn1.py deleted file mode 100644 index 118c94e6fc..0000000000 --- a/mindspore/ops/_op_impl/akg/conv_bn1.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ConvBN1 op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ConvBN1", - "imply_type": "AutoDiff", - "fusion_type": "CONVLUTION", - "attr": [ - { - "name": "x_shape", - "param_type": "required", - "type": "listInt" - }, - { - "name": "w_shape", - "param_type": "required", - "type": "listInt" - }, - { - "name": "pad_list", - "param_type": "required", - "type": "listInt" - }, - { - "name": "stride", - "param_type": "optional", - "type": "int" - }, - { - "name": "dilation", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "FracZ" - ], - "name": "w" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "conv_res_16" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "var_part" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "mean" - } - ] -}""") -def _conv_bn1_akg(): - """ConvBN1 AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/div.py b/mindspore/ops/_op_impl/akg/div.py deleted file mode 100644 index 56cdcca868..0000000000 --- a/mindspore/ops/_op_impl/akg/div.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Div op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Div", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _div_akg(): - """Div AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/equal.py b/mindspore/ops/_op_impl/akg/equal.py deleted file mode 100644 index 35874c62bb..0000000000 --- a/mindspore/ops/_op_impl/akg/equal.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Equal op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Equal", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "bool", "bool", "bool", "bool", "bool", "bool" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _equal_akg(): - """Equal AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/equal_count.py b/mindspore/ops/_op_impl/akg/equal_count.py deleted file mode 100644 index 9c575db7b3..0000000000 --- a/mindspore/ops/_op_impl/akg/equal_count.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""EqualCount op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "EqualCount", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32" - ], - "format": [ - "DefaultFormat" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "int32" - ], - "format": [ - "DefaultFormat" - ], - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32" - ], - "format": [ - "DefaultFormat" - ], - "name": "output" - } - ] -}""") -def _equal_count_akg(): - """EqualCount AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/equiv_format.py b/mindspore/ops/_op_impl/akg/equiv_format.py deleted file mode 100644 index 111451b15c..0000000000 --- a/mindspore/ops/_op_impl/akg/equiv_format.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""EquivFormat op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "EquivFormat", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "FRACTAL_NZ", "FRACTAL_NZ", "DefaultFormat", "DefaultFormat" - ], - "name": "output" - } - ] -}""") -def _equiv_format_akg(): - """EquivFormat AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/exp.py b/mindspore/ops/_op_impl/akg/exp.py deleted file mode 100644 index 273b3348a4..0000000000 --- a/mindspore/ops/_op_impl/akg/exp.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Exp op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Exp", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _exp_akg(): - """Exp AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/expand_dims.py b/mindspore/ops/_op_impl/akg/expand_dims.py deleted file mode 100644 index 9e1b18153a..0000000000 --- a/mindspore/ops/_op_impl/akg/expand_dims.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ExpandDims op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ExpandDims", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "axis", - "param_type": "required", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "y" - } - ] -}""") -def _expand_dims_akg(): - """ExpandDims AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/five2four.py b/mindspore/ops/_op_impl/akg/five2four.py deleted file mode 100644 index 1dac2c3628..0000000000 --- a/mindspore/ops/_op_impl/akg/five2four.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Five2Four op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Five2Four", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "shape4d", - "param_type": "required", - "type": "listInt" - }, - { - "name": "dstType", - "param_type": "required", - "type": "str" - }, - { - "name": "output_format", - "param_type": "required", - "type": "str" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16","float16","float16","float32","float16","float32" - ], - "format": [ - "NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0","NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16","float16","float32","float32","float32","float32" - ], - "format": [ - "DefaultFormat","NHWC","DefaultFormat","DefaultFormat","NHWC","NHWC" - ], - "name": "output" - } - ] -}""") -def _five2four_akg(): - """Five2Four AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/floordiv.py b/mindspore/ops/_op_impl/akg/floordiv.py deleted file mode 100644 index 99e577b4be..0000000000 --- a/mindspore/ops/_op_impl/akg/floordiv.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FloorDiv op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "FloorDiv", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32", "int32", "int32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _floor_div_akg(): - """FloorDiv AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/four2five.py b/mindspore/ops/_op_impl/akg/four2five.py deleted file mode 100644 index 01b6f85715..0000000000 --- a/mindspore/ops/_op_impl/akg/four2five.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Four2Five op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Four2Five", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "data_format", - "param_type": "optional", - "type": "listStr" - }, - { - "name": "dst_type", - "param_type": "required", - "type": "str" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float32", "float16","float32", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NHWC", "NHWC", "NHWC" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float16", "float32", "float16", "float16", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _four2five_akg(): - """Four2Five AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_batch_norm.py b/mindspore/ops/_op_impl/akg/fused_batch_norm.py deleted file mode 100644 index 5ce9839328..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_batch_norm.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FusedBatchNorm op""" - -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "FusedBatchNorm", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "momentum", - "param_type": "optional", - "type": "float" - }, - { - "name": "epsilon", - "param_type": "optional", - "type": "float" - }, - { - "name": "data_format", - "param_type": "optional", - "type": "listStr" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "scale" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "b" - }, - { - "index": 3, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "mean" - }, - { - "index": 4, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "variance" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "y" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "running_mean" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "running_variance" - }, - { - "index": 3, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "save_mean" - }, - { - "index": 4, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "save_inv_variance" - } - ] -}""") -def _fused_batch_norm_akg(): - """FusedBatchNorm AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_batch_norm_grad.py b/mindspore/ops/_op_impl/akg/fused_batch_norm_grad.py deleted file mode 100644 index 9191548f73..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_batch_norm_grad.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FusedBatchNormGrad op""" - -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "FusedBatchNormGrad", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "data_format", - "param_type": "optional", - "type": "listStr" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "dy" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "x" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "scale" - }, - { - "index": 3, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "save_mean" - }, - { - "index": 4, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "save_inv_variance" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "dx" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "bn_scale" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "bn_bias" - } - ] -}""") -def _fused_batch_norm_grad_akg(): - """BiasAddGrad AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_batch_norm_infer.py b/mindspore/ops/_op_impl/akg/fused_batch_norm_infer.py deleted file mode 100644 index 1e7743fa8f..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_batch_norm_infer.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FusedBatchNormInfer op""" - -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "FusedBatchNormInfer", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "momentum", - "param_type": "optional", - "type": "float" - }, - { - "name": "epsilon", - "param_type": "optional", - "type": "float" - }, - { - "name": "data_format", - "param_type": "optional", - "type": "listStr" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "scale" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "b" - }, - { - "index": 3, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "mean" - }, - { - "index": 4, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "variance" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "y" - } - ] -}""") -def _fused_batch_norm_infer_akg(): - """FusedBatchNormInfer AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_bn1.py b/mindspore/ops/_op_impl/akg/fused_bn1.py deleted file mode 100644 index fdaa673f25..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_bn1.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FusedBN1 op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "FusedBN1", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "data" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "output" - }, - { - "index": 1, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _fused_bn1_akg(): - """FusedBN1 AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_bn1_grad.py b/mindspore/ops/_op_impl/akg/fused_bn1_grad.py deleted file mode 100644 index 8de6796d6f..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_bn1_grad.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""BNGrad1 op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "BNGrad1", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "dy" - }, - { - "index": 1, - "dtype": [ - "float16", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "data" - },{ - "index": 2, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "mean" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "output" - }, - { - "index": 1, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "output" - }, - { - "index": 2, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _bn1_grad_akg(): - """BNGrad1 AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_bn2.py b/mindspore/ops/_op_impl/akg/fused_bn2.py deleted file mode 100644 index e26a5ad8a0..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_bn2.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FusedBN2 op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "FusedBN2", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "momentum", - "param_type": "optional", - "type": "float" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "mean" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "var_part" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "running_mean" - }, - { - "index": 3, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "running_var" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _fused_bn2_akg(): - """FusedBN2 AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_bn2_grad.py b/mindspore/ops/_op_impl/akg/fused_bn2_grad.py deleted file mode 100644 index e29a9177b6..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_bn2_grad.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""BNGrad1 op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "BNGrad2", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "eps", - "param_type": "optional", - "type": "float" - }, - { - "name": "data_shape", - "param_type": "optional", - "type": "listInt" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "dgamma_red_hw" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "dbeta_red_hw" - },{ - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "variance" - }, - { - "index": 3, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "gamma" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - }, - { - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - }, - { - "index": 3, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - }, - { - "index": 4, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _bn2_grad_akg(): - """BNGrad2 AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_bn3.py b/mindspore/ops/_op_impl/akg/fused_bn3.py deleted file mode 100644 index 74f3f652f3..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_bn3.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""FusedBN3 op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "FusedBN3", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "eps", - "param_type": "optional", - "type": "float" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "data" - }, - { - "index": 1, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "mean" - },{ - "index": 2, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "variance" - },{ - "index": 3, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "gamma" - },{ - "index": 4, - "dtype": [ - "float32" - ], - "format": [ - "NC1HWC0" - ], - "name": "beta" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _fused_bn3_akg(): - """FusedBN3 AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/fused_bn3_grad.py b/mindspore/ops/_op_impl/akg/fused_bn3_grad.py deleted file mode 100644 index 5ffc57a68e..0000000000 --- a/mindspore/ops/_op_impl/akg/fused_bn3_grad.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""BNGrad3 op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "BNGrad3", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "dy" - }, - { - "index": 1, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "rs" - },{ - "index": 2, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "dgamma_dx" - }, - { - "index": 3, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "dbeta_dx" - }, - { - "index": 4, - "dtype": [ - "float32", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "data_minus_mean" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _bn3_grad_akg(): - """BNGrad3 AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/gather_v2.py b/mindspore/ops/_op_impl/akg/gather_v2.py deleted file mode 100644 index 84ab7eb669..0000000000 --- a/mindspore/ops/_op_impl/akg/gather_v2.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""GatherV2 op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "GatherV2", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "axis", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "params" - }, - { - "index": 1, - "dtype": [ - "int32", "int32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "indices" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "output" - } - ] -}""") -def _gather_v2_akg(): - """GatherV2 AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/gpu/cast.py b/mindspore/ops/_op_impl/akg/gpu/cast.py index 2f31dab1ba..c8aef249cd 100644 --- a/mindspore/ops/_op_impl/akg/gpu/cast.py +++ b/mindspore/ops/_op_impl/akg/gpu/cast.py @@ -13,15 +13,16 @@ # limitations under the License. """Cast op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -cast_op_info = AkgRegOp("Cast") \ +cast_op_info = AkgGpuRegOp("Cast") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .output(0, "output") \ .attr("dst_type", "required", "str") \ .dtype_format(DataType.F16_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_Default, DataType.F32_Default) \ .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/akg/gpu/equal.py b/mindspore/ops/_op_impl/akg/gpu/equal.py index fa20392411..40a3590f61 100644 --- a/mindspore/ops/_op_impl/akg/gpu/equal.py +++ b/mindspore/ops/_op_impl/akg/gpu/equal.py @@ -13,9 +13,9 @@ # limitations under the License. """Equal op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -equal_op_info = AkgRegOp("Equal") \ +equal_op_info = AkgGpuRegOp("Equal") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .input(1, "y") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/greater_equal.py b/mindspore/ops/_op_impl/akg/gpu/greater_equal.py index b000cbd0e3..666c939b4b 100644 --- a/mindspore/ops/_op_impl/akg/gpu/greater_equal.py +++ b/mindspore/ops/_op_impl/akg/gpu/greater_equal.py @@ -13,9 +13,9 @@ # limitations under the License. """GreaterEqual op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -greater_equal_op_info = AkgRegOp("GreaterEqual") \ +greater_equal_op_info = AkgGpuRegOp("GreaterEqual") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .input(1, "y") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/hsigmoid.py b/mindspore/ops/_op_impl/akg/gpu/hsigmoid.py index 4e802c1cad..34e1e7f14a 100644 --- a/mindspore/ops/_op_impl/akg/gpu/hsigmoid.py +++ b/mindspore/ops/_op_impl/akg/gpu/hsigmoid.py @@ -13,9 +13,9 @@ # limitations under the License. """HSigmoid op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -hsigmoid_op_info = AkgRegOp("HSigmoid") \ +hsigmoid_op_info = AkgGpuRegOp("HSigmoid") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .output(0, "output") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/hsigmoid_grad.py b/mindspore/ops/_op_impl/akg/gpu/hsigmoid_grad.py index 39b819138e..5e08ffb41c 100644 --- a/mindspore/ops/_op_impl/akg/gpu/hsigmoid_grad.py +++ b/mindspore/ops/_op_impl/akg/gpu/hsigmoid_grad.py @@ -13,9 +13,9 @@ # limitations under the License. """HSigmoidGrad op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -hsigmoidgrad_op_info = AkgRegOp("HSigmoidGrad") \ +hsigmoidgrad_op_info = AkgGpuRegOp("HSigmoidGrad") \ .fusion_type("OPAQUE") \ .input(0, "y_grad") \ .input(1, "x") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/hswish.py b/mindspore/ops/_op_impl/akg/gpu/hswish.py index 29f20bafae..77d2c3b50c 100644 --- a/mindspore/ops/_op_impl/akg/gpu/hswish.py +++ b/mindspore/ops/_op_impl/akg/gpu/hswish.py @@ -13,9 +13,9 @@ # limitations under the License. """HSwish op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -hswish_op_info = AkgRegOp("HSwish") \ +hswish_op_info = AkgGpuRegOp("HSwish") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .output(0, "output") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/hswish_grad.py b/mindspore/ops/_op_impl/akg/gpu/hswish_grad.py index 38e8c78e28..3857486f0c 100644 --- a/mindspore/ops/_op_impl/akg/gpu/hswish_grad.py +++ b/mindspore/ops/_op_impl/akg/gpu/hswish_grad.py @@ -13,9 +13,9 @@ # limitations under the License. """HSwishGrad op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -hswish_grad_op_info = AkgRegOp("HSwishGrad") \ +hswish_grad_op_info = AkgGpuRegOp("HSwishGrad") \ .fusion_type("OPAQUE") \ .input(0, "y_grad") \ .input(1, "x") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/lessequal.py b/mindspore/ops/_op_impl/akg/gpu/lessequal.py index a8babf7ae4..58c9c7f90a 100644 --- a/mindspore/ops/_op_impl/akg/gpu/lessequal.py +++ b/mindspore/ops/_op_impl/akg/gpu/lessequal.py @@ -13,9 +13,9 @@ # limitations under the License. """LessEqual op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -lessequal_op_info = AkgRegOp("LessEqual") \ +lessequal_op_info = AkgGpuRegOp("LessEqual") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .input(1, "y") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_and.py b/mindspore/ops/_op_impl/akg/gpu/logical_and.py index da5b696512..58abcd8064 100644 --- a/mindspore/ops/_op_impl/akg/gpu/logical_and.py +++ b/mindspore/ops/_op_impl/akg/gpu/logical_and.py @@ -13,9 +13,9 @@ # limitations under the License. """LogicalAnd op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -logicaland_op_info = AkgRegOp("LogicalAnd") \ +logicaland_op_info = AkgGpuRegOp("LogicalAnd") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .input(1, "y") \ @@ -23,6 +23,7 @@ logicaland_op_info = AkgRegOp("LogicalAnd") \ .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ .get_op_info() + @op_info_register(logicaland_op_info) def _logical_and_akg(): """LogicalAnd register""" diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_not.py b/mindspore/ops/_op_impl/akg/gpu/logical_not.py index 4b3c7bf647..33815f489a 100644 --- a/mindspore/ops/_op_impl/akg/gpu/logical_not.py +++ b/mindspore/ops/_op_impl/akg/gpu/logical_not.py @@ -13,15 +13,16 @@ # limitations under the License. """LogicalNot op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -logical_not_op_info = AkgRegOp("LogicalNot") \ +logical_not_op_info = AkgGpuRegOp("LogicalNot") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .output(0, "output") \ .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ .get_op_info() + @op_info_register(logical_not_op_info) def _logical_not_akg(): """LogicalNot AutoDiff register""" diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_or.py b/mindspore/ops/_op_impl/akg/gpu/logical_or.py index 3a642511c6..163674ac2a 100644 --- a/mindspore/ops/_op_impl/akg/gpu/logical_or.py +++ b/mindspore/ops/_op_impl/akg/gpu/logical_or.py @@ -13,9 +13,9 @@ # limitations under the License. """LogicalOr op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -logicalor_op_info = AkgRegOp("LogicalOr") \ +logicalor_op_info = AkgGpuRegOp("LogicalOr") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .input(1, "y") \ @@ -23,6 +23,7 @@ logicalor_op_info = AkgRegOp("LogicalOr") \ .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ .get_op_info() + @op_info_register(logicalor_op_info) def _logical_or_akg(): """LogicalOr register""" diff --git a/mindspore/ops/_op_impl/akg/gpu/mean.py b/mindspore/ops/_op_impl/akg/gpu/mean.py index b46b701b91..dd997ec0f1 100644 --- a/mindspore/ops/_op_impl/akg/gpu/mean.py +++ b/mindspore/ops/_op_impl/akg/gpu/mean.py @@ -13,9 +13,9 @@ # limitations under the License. """SimpleMean op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -mean_op_info = AkgRegOp("SimpleMean") \ +mean_op_info = AkgGpuRegOp("SimpleMean") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .output(0, "output") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/mean_grad.py b/mindspore/ops/_op_impl/akg/gpu/mean_grad.py index e3e0121c20..ae4620305a 100644 --- a/mindspore/ops/_op_impl/akg/gpu/mean_grad.py +++ b/mindspore/ops/_op_impl/akg/gpu/mean_grad.py @@ -13,9 +13,9 @@ # limitations under the License. """SimpleMeanGrad op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -mean_grad_op_info = AkgRegOp("SimpleMeanGrad") \ +mean_grad_op_info = AkgGpuRegOp("SimpleMeanGrad") \ .fusion_type("OPAQUE") \ .input(0, "HEAD") \ .output(0, "output") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/mul.py b/mindspore/ops/_op_impl/akg/gpu/mul.py index db5b1460ed..0da7b3fb6c 100644 --- a/mindspore/ops/_op_impl/akg/gpu/mul.py +++ b/mindspore/ops/_op_impl/akg/gpu/mul.py @@ -13,9 +13,9 @@ # limitations under the License. """Mul op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -mul_op_info = AkgRegOp("Mul") \ +mul_op_info = AkgGpuRegOp("Mul") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .input(1, "y") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/notequal.py b/mindspore/ops/_op_impl/akg/gpu/notequal.py index dc13449fc1..b9c9c55faf 100644 --- a/mindspore/ops/_op_impl/akg/gpu/notequal.py +++ b/mindspore/ops/_op_impl/akg/gpu/notequal.py @@ -13,9 +13,9 @@ # limitations under the License. """NotEqual op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -notequal_op_info = AkgRegOp("NotEqual") \ +notequal_op_info = AkgGpuRegOp("NotEqual") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .input(1, "y") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/relu6.py b/mindspore/ops/_op_impl/akg/gpu/relu6.py index 31bfebcd8d..33ae7f4dad 100644 --- a/mindspore/ops/_op_impl/akg/gpu/relu6.py +++ b/mindspore/ops/_op_impl/akg/gpu/relu6.py @@ -13,9 +13,9 @@ # limitations under the License. """ReLU6 op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -relu_op_info = AkgRegOp("ReLU6") \ +relu_op_info = AkgGpuRegOp("ReLU6") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .output(0, "output") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/relu6_grad.py b/mindspore/ops/_op_impl/akg/gpu/relu6_grad.py index 83d93f3077..c6ed702247 100644 --- a/mindspore/ops/_op_impl/akg/gpu/relu6_grad.py +++ b/mindspore/ops/_op_impl/akg/gpu/relu6_grad.py @@ -13,9 +13,9 @@ # limitations under the License. """ReLU6Grad op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -relu_grad_op_info = AkgRegOp("ReLU6Grad") \ +relu_grad_op_info = AkgGpuRegOp("ReLU6Grad") \ .fusion_type("OPAQUE") \ .input(0, "y_grad") \ .input(1, "x") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/squeeze.py b/mindspore/ops/_op_impl/akg/gpu/squeeze.py index cebf6ff1f3..8761b64890 100644 --- a/mindspore/ops/_op_impl/akg/gpu/squeeze.py +++ b/mindspore/ops/_op_impl/akg/gpu/squeeze.py @@ -13,9 +13,9 @@ # limitations under the License. """Squeeze op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -squeeze_op_info = AkgRegOp("Squeeze") \ +squeeze_op_info = AkgGpuRegOp("Squeeze") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .output(0, "output") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/squeeze_grad.py b/mindspore/ops/_op_impl/akg/gpu/squeeze_grad.py index 17e45a327a..41eacbf18f 100644 --- a/mindspore/ops/_op_impl/akg/gpu/squeeze_grad.py +++ b/mindspore/ops/_op_impl/akg/gpu/squeeze_grad.py @@ -13,9 +13,9 @@ # limitations under the License. """SqueezeGrad op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -squeeze_grad_op_info = AkgRegOp("SqueezeGrad") \ +squeeze_grad_op_info = AkgGpuRegOp("SqueezeGrad") \ .fusion_type("OPAQUE") \ .input(0, "y_grad") \ .output(0, "output") \ diff --git a/mindspore/ops/_op_impl/akg/gpu/sub.py b/mindspore/ops/_op_impl/akg/gpu/sub.py index 06b92fb49e..eaa8124067 100644 --- a/mindspore/ops/_op_impl/akg/gpu/sub.py +++ b/mindspore/ops/_op_impl/akg/gpu/sub.py @@ -13,9 +13,9 @@ # limitations under the License. """Sub op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -sub_op_info = AkgRegOp("Sub") \ +sub_op_info = AkgGpuRegOp("Sub") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .input(1, "y") \ @@ -25,6 +25,7 @@ sub_op_info = AkgRegOp("Sub") \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .get_op_info() + @op_info_register(sub_op_info) def _sub_akg(): """Sub AutoDiff register""" diff --git a/mindspore/ops/_op_impl/akg/gpu/tile.py b/mindspore/ops/_op_impl/akg/gpu/tile.py index 8c9de00979..e8e634d9a1 100644 --- a/mindspore/ops/_op_impl/akg/gpu/tile.py +++ b/mindspore/ops/_op_impl/akg/gpu/tile.py @@ -13,9 +13,9 @@ # limitations under the License. """Tile op""" -from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType -tile_op_info = AkgRegOp("Tile") \ +tile_op_info = AkgGpuRegOp("Tile") \ .fusion_type("OPAQUE") \ .input(0, "x") \ .output(0, "output") \ diff --git a/mindspore/ops/_op_impl/akg/greater.py b/mindspore/ops/_op_impl/akg/greater.py deleted file mode 100644 index 941946163a..0000000000 --- a/mindspore/ops/_op_impl/akg/greater.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Greater op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Greater", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float16", "float32", "float32" - ], - "format": [ - "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float16", "float32", "float32" - ], - "format": [ - "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "bool", "bool", "bool", "bool" - ], - "format": [ - "DefaultFormat", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _greater_akg(): - """Greater AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/greater_equal.py b/mindspore/ops/_op_impl/akg/greater_equal.py deleted file mode 100644 index 11642baa86..0000000000 --- a/mindspore/ops/_op_impl/akg/greater_equal.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""GreaterEqual op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "GreaterEqual", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "bool", "bool", "bool", "bool", "bool", "bool" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _greater_equal_akg(): - """Equal AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/inplace_assign.py b/mindspore/ops/_op_impl/akg/inplace_assign.py deleted file mode 100644 index 1cc40abe9b..0000000000 --- a/mindspore/ops/_op_impl/akg/inplace_assign.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""InplaceAssign op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "InplaceAssign", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "fake_output", - "param_type": "optional", - "type": "bool" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ" - ], - "name": "y" - }, - { - "index": 2, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ" - ], - "name": "z" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", "FracZ", "FracZ", "FracZ" - ], - "name": "output" - } - ] -}""") -def _inplace_assign_akg(): - """InplaceAssign AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/less.py b/mindspore/ops/_op_impl/akg/less.py deleted file mode 100644 index 499ed2e8fc..0000000000 --- a/mindspore/ops/_op_impl/akg/less.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Less op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Less", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float16" - ], - "format": [ - "DefaultFormat", "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float16" - ], - "format": [ - "DefaultFormat", "NC1HWC0" - ], - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "bool", "bool" - ], - "format": [ - "DefaultFormat", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _less_akg(): - """Less AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/less_equal.py b/mindspore/ops/_op_impl/akg/less_equal.py deleted file mode 100644 index 97fbdec090..0000000000 --- a/mindspore/ops/_op_impl/akg/less_equal.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""LessEqual op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "LessEqual", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "bool", "bool", "bool", "bool", "bool", "bool" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _less_equal_akg(): - """Equal AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/log.py b/mindspore/ops/_op_impl/akg/log.py deleted file mode 100644 index 526538d17d..0000000000 --- a/mindspore/ops/_op_impl/akg/log.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Log op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Log", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _log_akg(): - """Log AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/matmul.py b/mindspore/ops/_op_impl/akg/matmul.py deleted file mode 100644 index 084ba754fa..0000000000 --- a/mindspore/ops/_op_impl/akg/matmul.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""MatMul op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "MatMul", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "transpose_a", - "param_type": "optional", - "type": "bool" - }, - { - "name": "transpose_b", - "param_type": "optional", - "type": "bool" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "x1" - }, - { - "index": 1, - "dtype": [ - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "x2" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "output" - } - ] -}""") -def _matmul_akg(): - """MatMul AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/max.py b/mindspore/ops/_op_impl/akg/max.py deleted file mode 100644 index 21fd4ef9c4..0000000000 --- a/mindspore/ops/_op_impl/akg/max.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Max op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Max", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "axis", - "param_type": "required", - "type": "listInt" - }, - { - "name": "keep_dims", - "param_type": "required", - "type": "bool" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _max_akg(): - """Max AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/max_pool_grad_with_argmax.py b/mindspore/ops/_op_impl/akg/max_pool_grad_with_argmax.py deleted file mode 100644 index 4adad3eb88..0000000000 --- a/mindspore/ops/_op_impl/akg/max_pool_grad_with_argmax.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""MaxPoolGradWithArgmax op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "MaxPoolGradWithArgmax", - "imply_type": "AutoDiff", - "fusion_type": "CONVLUTION", - "attr": [ - { - "name": "pad_mode", - "param_type": "optional", - "type": "str" - }, - { - "name": "window", - "param_type": "optional", - "type": "int" - }, - { - "name": "pad", - "param_type": "optional", - "type": "int" - }, - { - "name": "stride", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float16" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat" - ], - "name": "argmax" - }, - { - "index": 2, - "dtype": [ - "float16", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "grad" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32" - ], - "format": [ - "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _max_pool_grad_with_argmax_akg(): - """MaxPoolGradWithArgmax AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/max_pool_with_argmax.py b/mindspore/ops/_op_impl/akg/max_pool_with_argmax.py deleted file mode 100644 index 3ae36d4793..0000000000 --- a/mindspore/ops/_op_impl/akg/max_pool_with_argmax.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""MaxPoolWithArgmax op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "MaxPoolWithArgmax", - "imply_type": "AutoDiff", - "fusion_type": "CONVLUTION", - "attr": [ - { - "name": "pad_mode", - "param_type": "optional", - "type": "str" - }, - { - "name": "window", - "param_type": "optional", - "type": "int" - }, - { - "name": "pad", - "param_type": "optional", - "type": "int" - }, - { - "name": "stride", - "param_type": "optional", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16" - ], - "format": [ - "NC1HWC0" - ], - "name": "output" - }, - { - "index": 1, - "dtype": [ - "float16" - ], - "format": [ - "DefaultFormat" - ], - "name": "argmax" - } - ] -}""") -def _max_pool_with_argmax_akg(): - """MaxPoolWithArgmax AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/maximum.py b/mindspore/ops/_op_impl/akg/maximum.py deleted file mode 100644 index 8d8de5270a..0000000000 --- a/mindspore/ops/_op_impl/akg/maximum.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Maximum op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Maximum", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "param_type": "required", - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "param_type": "required", - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _maximum_akg(): - """Maximum AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/mean.py b/mindspore/ops/_op_impl/akg/mean.py deleted file mode 100644 index 0b49e76865..0000000000 --- a/mindspore/ops/_op_impl/akg/mean.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SimpleMean op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "SimpleMean", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _mean_akg(): - """SimpleMean AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/mean_grad.py b/mindspore/ops/_op_impl/akg/mean_grad.py deleted file mode 100644 index 3b8379d1f0..0000000000 --- a/mindspore/ops/_op_impl/akg/mean_grad.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SimpleMeanGrad op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "SimpleMeanGrad", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "input_shape", - "param_type": "required", - "type": "listInt" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "HEAD" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _mean_grad_akg(): - """SimpleMeanGrad AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/minimum.py b/mindspore/ops/_op_impl/akg/minimum.py deleted file mode 100644 index 759df2085f..0000000000 --- a/mindspore/ops/_op_impl/akg/minimum.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Minimum op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Minimum", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32", - "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32", - "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32", - "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _minimum_akg(): - """Minimum AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/mul.py b/mindspore/ops/_op_impl/akg/mul.py deleted file mode 100644 index ab02c2d89e..0000000000 --- a/mindspore/ops/_op_impl/akg/mul.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Mul op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Mul", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "x_shape", - "param_type": "required", - "type": "listInt" - }, - { - "name": "y_shape", - "param_type": "required", - "type": "listInt" - }, - { - "name": "data_format", - "param_type": "required", - "type": "listStr" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "FracZ", "FracZ", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float32", "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "FracZ", "FracZ", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "FracZ", "FracZ", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _mul_akg(): - """Mul AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/neg.py b/mindspore/ops/_op_impl/akg/neg.py deleted file mode 100644 index bc00d60271..0000000000 --- a/mindspore/ops/_op_impl/akg/neg.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Neg op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Neg", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32", - "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32", - "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _neg_akg(): - """Neg AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/one_hot.py b/mindspore/ops/_op_impl/akg/one_hot.py deleted file mode 100644 index c5034dbbd4..0000000000 --- a/mindspore/ops/_op_impl/akg/one_hot.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""OneHot op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "OneHot", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "depth", - "param_type": "required", - "type": "int" - }, - { - "name": "axis", - "param_type": "required", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "int32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "indices" - }, - { - "index": 1, - "dtype": [ - "int32", "float32", "float16" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "on_value" - }, - { - "index": 2, - "dtype": [ - "int32", "float32", "float16" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "off_value" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32", "float32", "float16" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat" - ], - "name": "output" - } - ] -}""") -def _one_hot_akg(): - """OneHot AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/pow.py b/mindspore/ops/_op_impl/akg/pow.py deleted file mode 100644 index d782968c05..0000000000 --- a/mindspore/ops/_op_impl/akg/pow.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Pow op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Pow", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "param_type": "required", - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "param_type": "required", - "name": "power" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _power_akg(): - """Pow AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/real_div.py b/mindspore/ops/_op_impl/akg/real_div.py deleted file mode 100644 index 9fa37a24e3..0000000000 --- a/mindspore/ops/_op_impl/akg/real_div.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""RealDiv op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "RealDiv", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - }, - { - "index": 1, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _real_div_akg(): - """RealDiv AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/reciprocal.py b/mindspore/ops/_op_impl/akg/reciprocal.py deleted file mode 100644 index 9fd7cc40b4..0000000000 --- a/mindspore/ops/_op_impl/akg/reciprocal.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Reciprocal op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Reciprocal", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _reciprocal_akg(): - """Reciprocal AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/reduce_max.py b/mindspore/ops/_op_impl/akg/reduce_max.py deleted file mode 100644 index b9db8ea83a..0000000000 --- a/mindspore/ops/_op_impl/akg/reduce_max.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ReduceMax op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ReduceMax", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "axis", - "param_type": "required", - "type": "listInt" - }, - { - "name": "keep_dims", - "param_type": "required", - "type": "bool" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float16" - ], - "format": [ - "DefaultFormat", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float16" - ], - "format": [ - "DefaultFormat", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _reduce_max_akg(): - """ReduceMax AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/reduce_mean.py b/mindspore/ops/_op_impl/akg/reduce_mean.py deleted file mode 100644 index 0a4ffdf221..0000000000 --- a/mindspore/ops/_op_impl/akg/reduce_mean.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ReduceMean op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ReduceMean", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "axis", - "param_type": "required", - "type": "listInt" - }, - { - "name": "keep_dims", - "param_type": "required", - "type": "bool" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _reduce_mean_akg(): - """ReduceMean AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/reduce_sum.py b/mindspore/ops/_op_impl/akg/reduce_sum.py deleted file mode 100644 index 20d091ac76..0000000000 --- a/mindspore/ops/_op_impl/akg/reduce_sum.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ReduceSum op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ReduceSum", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "axis", - "param_type": "required", - "type": "listInt" - }, - { - "name": "keep_dims", - "param_type": "required", - "type": "bool" - }, - { - "name": "atomic_add", - "param_type": "optional", - "type": "str" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _reduce_sum_akg(): - """ReduceSum AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/relu.py b/mindspore/ops/_op_impl/akg/relu.py deleted file mode 100644 index b32725f885..0000000000 --- a/mindspore/ops/_op_impl/akg/relu.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ReLU op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ReLU", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _relu_akg(): - """ReLU AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/relu_grad.py b/mindspore/ops/_op_impl/akg/relu_grad.py deleted file mode 100644 index c785b750fe..0000000000 --- a/mindspore/ops/_op_impl/akg/relu_grad.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ReluGrad op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ReluGrad", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0" - ], - "name": "y_backprop" - }, - { - "index": 1, - "dtype": [ - "float16", "float32", "float16" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _relu_grad_akg(): - """ReluGrad AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/reshape.py b/mindspore/ops/_op_impl/akg/reshape.py deleted file mode 100644 index d200b66fa2..0000000000 --- a/mindspore/ops/_op_impl/akg/reshape.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Reshape op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Reshape", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "shape", - "param_type": "required", - "type": "listInt" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "tensor" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _reshape_akg(): - """Reshape AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/round.py b/mindspore/ops/_op_impl/akg/round.py deleted file mode 100644 index 0625c3ceda..0000000000 --- a/mindspore/ops/_op_impl/akg/round.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Round op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Round", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _round_akg(): - """Round AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/rsqrt.py b/mindspore/ops/_op_impl/akg/rsqrt.py deleted file mode 100644 index 9264864f91..0000000000 --- a/mindspore/ops/_op_impl/akg/rsqrt.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Rsqrt op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Rsqrt", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "param_type": "required", - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _rsqrt_akg(): - """Rsqrt AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/select.py b/mindspore/ops/_op_impl/akg/select.py deleted file mode 100644 index 006c6a5444..0000000000 --- a/mindspore/ops/_op_impl/akg/select.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Select op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Select", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "bool", "bool", "bool", "bool", "bool", "bool" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "param_type": "required", - "name": "condition" - }, - { - "index": 1, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "param_type": "required", - "name": "x" - }, - { - "index": 2, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "param_type": "required", - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "int32", "float16", "int32", "float32", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "DefaultFormat", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _select_akg(): - """Select AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/softmax.py b/mindspore/ops/_op_impl/akg/softmax.py deleted file mode 100644 index a41c2aef36..0000000000 --- a/mindspore/ops/_op_impl/akg/softmax.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Softmax op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Softmax", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - { - "name": "axis", - "param_type": "required", - "type": "listInt" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _softmax_akg(): - """Softmax AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/sparse_softmax_cross_entropy_with_logits.py b/mindspore/ops/_op_impl/akg/sparse_softmax_cross_entropy_with_logits.py deleted file mode 100644 index e9e828f312..0000000000 --- a/mindspore/ops/_op_impl/akg/sparse_softmax_cross_entropy_with_logits.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""SparseSoftmaxCrossEntropyWithLogits op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "SparseSoftmaxCrossEntropyWithLogits", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "is_grad", - "param_type": "optional", - "type": "bool" - }, - { - "name": "sens", - "param_type": "optional", - "type": "float" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "features" - }, - { - "index": 1, - "dtype": [ - "int32" - ], - "format": [ - "DefaultFormat" - ], - "name": "labels" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float32" - ], - "format": [ - "DefaultFormat" - ], - "name": "output" - } - ] -}""") -def _sparse_softmax_cross_entropy_with_logits_akg(): - """SparseSoftmaxCrossEntropyWithLogits AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/sqrt.py b/mindspore/ops/_op_impl/akg/sqrt.py deleted file mode 100644 index fcaa84b3d4..0000000000 --- a/mindspore/ops/_op_impl/akg/sqrt.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Sqrt op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Sqrt", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "param_type": "required", - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _sqrt_akg(): - """Sqrt AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/strided_slice.py b/mindspore/ops/_op_impl/akg/strided_slice.py deleted file mode 100644 index bdbd8dfc2f..0000000000 --- a/mindspore/ops/_op_impl/akg/strided_slice.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""StridedSlice op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "StridedSlice", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "begin", - "param_type": "required", - "type": "listInt" - }, - { - "name": "end", - "param_type": "required", - "type": "listInt" - }, - { - "name": "strides", - "param_type": "required", - "type": "listInt" - }, - { - "name": "begin_mask", - "param_type": "required", - "type": "int" - }, - { - "name": "end_mask", - "param_type": "required", - "type": "int" - }, - { - "name": "ellipsis_mask", - "param_type": "required", - "type": "int" - }, - { - "name": "new_axis_mask", - "param_type": "required", - "type": "int" - }, - { - "name": "shrink_axis_mask", - "param_type": "required", - "type": "int" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _strided_slice_akg(): - """StridedSlice AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/sub.py b/mindspore/ops/_op_impl/akg/sub.py deleted file mode 100644 index 846aa280bb..0000000000 --- a/mindspore/ops/_op_impl/akg/sub.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Sub op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Sub", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - }, - { - "index": 1, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "y" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "int32", "float16", "float32", "int32", "float16", "float32", - "int32", "float16", "float32", "int32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0", - "FracZ", "FracZ", "FracZ", "FRACTAL_NZ", "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _sub_akg(): - """Sub AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/sum.py b/mindspore/ops/_op_impl/akg/sum.py deleted file mode 100644 index 501b387b25..0000000000 --- a/mindspore/ops/_op_impl/akg/sum.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Sum op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Sum", - "imply_type": "AutoDiff", - "fusion_type": "COMMREDUCE", - "attr": [ - { - "name": "axis", - "param_type": "required", - "type": "listInt" - }, - { - "name": "keepdims", - "param_type": "required", - "type": "bool" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "param_type": "required", - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32", - "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", - "FRACTAL_NZ", "FRACTAL_NZ" - ], - "name": "output" - } - ] -}""") -def _sum_akg(): - """Sum AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/tile.py b/mindspore/ops/_op_impl/akg/tile.py deleted file mode 100644 index bd13978fe7..0000000000 --- a/mindspore/ops/_op_impl/akg/tile.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Tile op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "Tile", - "imply_type": "AutoDiff", - "fusion_type": "OPAQUE", - "attr": [ - { - "name": "multiples", - "param_type": "required", - "type": "listInt" - } - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "int32", "float16", "float32", "int32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _tile_akg(): - """Tile AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/akg/zeros_like.py b/mindspore/ops/_op_impl/akg/zeros_like.py deleted file mode 100644 index a02ece22d7..0000000000 --- a/mindspore/ops/_op_impl/akg/zeros_like.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""ZerosLike op""" -from mindspore.ops.op_info_register import op_info_register - - -@op_info_register("""{ - "op_name": "ZerosLike", - "imply_type": "AutoDiff", - "fusion_type": "ELEMWISE", - "attr": [ - - ], - "inputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "x" - } - ], - "outputs": [ - { - "index": 0, - "dtype": [ - "float16", "float32", "float16", "float32" - ], - "format": [ - "DefaultFormat", "DefaultFormat", "NC1HWC0", "NC1HWC0" - ], - "name": "output" - } - ] -}""") -def _zeros_like_akg(): - """ZerosLike AutoDiff register""" - return diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 8009280ab8..317509b5a9 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -133,6 +133,7 @@ from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad from .apply_proximal_adagrad import _apply_proximal_adagrad from .transpose_d import _transpose_d_tbe from .unsorted_segment_sum import _unsorted_segment_sum_tbe +from .unsorted_segment_prod import _unsorted_segment_prod_tbe from .logsoftmax_grad import _logsoftmax_grad_tbe from .logsoftmax import _logsoftmax_tbe from .select import _select_tbe @@ -285,3 +286,5 @@ from .mod import _mod_tbe from .max_pool_grad_grad import _max_pool_grad_grad_tbe from .max_pool_grad_grad_with_argmax import _max_pool_grad_grad_with_argmax_tbe from .tensor_move import _tensor_move_tbe +from .population_count import _population_count_tbe +from .parallel_concat import _parallel_concat_tbe diff --git a/mindspore/ops/_op_impl/tbe/parallel_concat.py b/mindspore/ops/_op_impl/tbe/parallel_concat.py new file mode 100644 index 0000000000..46d8736fab --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/parallel_concat.py @@ -0,0 +1,80 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ParallelConcat op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +parallel_concat_op_info = TBERegOp("ParallelConcat") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("parallel_concat.so") \ + .compute_cost(10) \ + .kernel_name("parallel_concat") \ + .partial_flag(True) \ + .attr("shape", "required", "listInt", "all") \ + .attr("N", "required", "int", "all") \ + .input(0, "values", False, "dynamic", "all") \ + .output(0, "output_data", False, "required", "all") \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.I16_Default, DataType.I16_Default) \ + .dtype_format(DataType.I16_5HD, DataType.I16_5HD) \ + .dtype_format(DataType.U16_Default, DataType.U16_Default) \ + .dtype_format(DataType.U16_5HD, DataType.U16_5HD) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.U32_Default, DataType.U32_Default) \ + .dtype_format(DataType.U32_5HD, DataType.U32_5HD) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.I64_5HD, DataType.I64_5HD) \ + .dtype_format(DataType.U64_Default, DataType.U64_Default) \ + .dtype_format(DataType.U64_5HD, DataType.U64_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.BOOL_NHWC, DataType.BOOL_NHWC) \ + .dtype_format(DataType.BOOL_NCHW, DataType.BOOL_NCHW) \ + .dtype_format(DataType.I8_NHWC, DataType.I8_NHWC) \ + .dtype_format(DataType.I8_NCHW, DataType.I8_NCHW) \ + .dtype_format(DataType.U8_NHWC, DataType.U8_NHWC) \ + .dtype_format(DataType.U8_NCHW, DataType.U8_NCHW) \ + .dtype_format(DataType.I16_NHWC, DataType.I16_NHWC) \ + .dtype_format(DataType.I16_NCHW, DataType.I16_NCHW) \ + .dtype_format(DataType.U16_NHWC, DataType.U16_NHWC) \ + .dtype_format(DataType.U16_NCHW, DataType.U16_NCHW) \ + .dtype_format(DataType.I32_NHWC, DataType.I32_NHWC) \ + .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW) \ + .dtype_format(DataType.U32_NHWC, DataType.U32_NHWC) \ + .dtype_format(DataType.U32_NCHW, DataType.U32_NCHW) \ + .dtype_format(DataType.I64_NHWC, DataType.I64_NHWC) \ + .dtype_format(DataType.I64_NCHW, DataType.I64_NCHW) \ + .dtype_format(DataType.U64_NHWC, DataType.U64_NHWC) \ + .dtype_format(DataType.U64_NCHW, DataType.U64_NCHW) \ + .dtype_format(DataType.F16_NHWC, DataType.F16_NHWC) \ + .dtype_format(DataType.F16_NCHW, DataType.F16_NCHW) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW) \ + .get_op_info() + + +@op_info_register(parallel_concat_op_info) +def _parallel_concat_tbe(): + """ParallelConcat TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/population_count.py b/mindspore/ops/_op_impl/tbe/population_count.py new file mode 100644 index 0000000000..14feded367 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/population_count.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""PopulationCount op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +population_count_op_info = TBERegOp("PopulationCount") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("population_count.so") \ + .compute_cost(10) \ + .kernel_name("population_count") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I16_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.I16_Default, DataType.U8_Default) \ + .dtype_format(DataType.U16_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U16_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(population_count_op_info) +def _population_count_tbe(): + """PopulationCount TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/roi_align.py b/mindspore/ops/_op_impl/tbe/roi_align.py index bc4eed80ce..d392651217 100644 --- a/mindspore/ops/_op_impl/tbe/roi_align.py +++ b/mindspore/ops/_op_impl/tbe/roi_align.py @@ -27,7 +27,7 @@ roi_align_op_info = TBERegOp("ROIAlign") \ .attr("pooled_height", "required", "int", "all") \ .attr("pooled_width", "required", "int", "all") \ .attr("sample_num", "optional", "int", "all", "2") \ - .attr("roi_end_mode", "optional", "0,1", "1") \ + .attr("roi_end_mode", "optional", "int", "0,1", "1") \ .input(0, "features", False, "required", "all") \ .input(1, "rois", False, "required", "all") \ .input(2, "rois_n", False, "optional", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py b/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py new file mode 100644 index 0000000000..40b04d17c3 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""UnsortedSegmentProdD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +unsorted_segment_prod_d_op_info = TBERegOp("UnsortedSegmentProd") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("unsorted_segment_prod_d.so") \ + .compute_cost(10) \ + .kernel_name("unsorted_segment_prod_d") \ + .partial_flag(True) \ + .attr("num_segments", "required", "int", "all") \ + .input(0, "data", False, "required", "all") \ + .input(1, "segment_ids", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.I32_Default, DataType.F16_5HD) \ + .dtype_format(DataType.F16_FracZ, DataType.I32_Default, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.I32_Default, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.I32_Default, DataType.F32_5HD) \ + .dtype_format(DataType.F32_FracZ, DataType.I32_Default, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.I32_Default, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_Default, DataType.I32_5HD) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_Default, DataType.I32_FracZ) \ + .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_Default, DataType.I32_C1HWNCoC0) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(unsorted_segment_prod_d_op_info) +def _unsorted_segment_prod_tbe(): + """UnsortedSegmentProdD TBE register""" + return diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index b0f16d82bf..0f28d9572f 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -17,6 +17,7 @@ """Basic composite operations.""" from functools import partial +from types import FunctionType from mindspore import context from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \ @@ -25,6 +26,7 @@ from ...common import dtype as mstype from ...common.api import ms_function, _pynative_exec, _wrap_func from .. import functional as F from ...common.parameter import Parameter +from ...common.tensor import Tensor __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] @@ -114,37 +116,48 @@ class GradOperation(GradOperation_): self.fn = None self.need_forward = False + def _pynative_forward_run(self, args, fn): + """ Pynative forward run to build grad graph. """ + if self.sens_param: + args = args[:-1] + for arg in args: + if not isinstance(arg, Tensor): + raise TypeError("grad inputs should be tensor in pynative mode") + if isinstance(fn, FunctionType): + _pynative_exec.set_grad_flag(True) + _pynative_exec.new_graph(fn, *args) + output = fn(*args) + _pynative_exec.end_graph(fn, output, *args) + else: + if fn.is_run and not fn.requires_grad: + raise ValueError("obj must set_grad.") + if not fn.is_run: + self.need_forward = True + print("already has forward run before grad by user") + if self.need_forward: + fn.set_grad() + fn(*args) + def __call__(self, fn, weights=None): grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) if self.grad_fn is None or self.fn != fn: - if self.get_by_list: - if context.get_context("mode") == context.GRAPH_MODE: + if context.get_context("mode") == context.GRAPH_MODE: + if self.get_by_list: @ms_function(obj=fn) def after_grad(*args): return grad_(fn, weights)(*args) else: - @_wrap_func + @ms_function(obj=fn) def after_grad(*args): - if fn.is_run and not fn.requires_grad: - raise ValueError("obj must set_grad.") - if not fn.is_run: - self.need_forward = True - print("already has forward run before grad by user") - if self.need_forward: - fn.set_grad() - if self.sens_param: - f_args = args[:-1] - fn(*f_args) - else: - fn(*args) - _pynative_exec.grad(grad_, fn, weights, *args) - out = _pynative_exec(*args) - _pynative_exec.clear() - return out + return grad_(fn)(*args) else: - @ms_function(obj=fn) + @_wrap_func def after_grad(*args): - return grad_(fn)(*args) + self._pynative_forward_run(args, fn) + _pynative_exec.grad(grad_, fn, weights, *args) + out = _pynative_exec(*args) + _pynative_exec.clear() + return out self.grad_fn = after_grad self.fn = fn return self.grad_fn diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index a5c3165ab1..2be011cb77 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -158,7 +158,6 @@ make_indexed_slices = Primitive('MakeIndexedSlices') indexed_slices_get_values = Primitive('IndexedSlicesGetValues') indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') -is_indexed_slices = Primitive('IsIndexedSlices') tensor_operator_registry.register('__add__', tensor_add) @@ -166,6 +165,7 @@ tensor_operator_registry.register('__sub__', tensor_sub) tensor_operator_registry.register('__mul__', tensor_mul) tensor_operator_registry.register('__truediv__', tensor_div) tensor_operator_registry.register('__mod__', tensor_mod) +tensor_operator_registry.register('__pow__', tensor_pow) tensor_operator_registry.register('__floordiv__', tensor_floordiv) #ms cannot support Tensor(True) compare tensor_operator_registry.register('__eq__', equal) diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index a7a60b7181..6ab915e369 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -215,10 +215,10 @@ class RegOp: class AkgRegOp(RegOp): """Class for Akg op info register.""" - def __init__(self, op_name): + def __init__(self, op_name, processor): super(AkgRegOp, self).__init__(op_name) - self.imply_type = "AutoDiff" - self.processor = "cuda" + self.imply_type = "AKG" + self.processor = processor def input(self, index=None, name=None, **kwargs): """ @@ -270,6 +270,16 @@ class AkgRegOp(RegOp): return self +class AkgGpuRegOp(AkgRegOp): + def __init__(self, op_name): + super(AkgGpuRegOp, self).__init__(op_name, "CUDA") + + +class AkgAscendRegOp(AkgRegOp): + def __init__(self, op_name): + super(AkgAscendRegOp, self).__init__(op_name, "AiCore") + + class AiCPURegOp(RegOp): """Class for AiCPU op info register""" diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 8806487579..1602f2594d 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -27,11 +27,11 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, - Shape, Size, Slice, Split, TransShape, EmbeddingLookup, + Shape, Size, Slice, Split, TransShape, ParallelConcat, Squeeze, StridedSlice, Tile, TensorScatterUpdate, - Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, + Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace, - SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence) + SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup) from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast, _MirrorOperator, ReduceOp, _VirtualDataset, _VirtualDiv, _GetTensorSlice, @@ -62,7 +62,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl DropoutDoMask, DropoutGrad, Dropout, DropoutGenMask, Flatten, FusedBatchNorm, BNTrainingReduce, BNTrainingUpdate, Gelu, Elu, - GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, + GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, LogSoftmax, MaxPool, DataFormatDimMap, AvgPool, Conv2DBackpropInput, ConfusionMulGrad, @@ -77,10 +77,10 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) -from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, - CheckValid, MakeRefKey, Partial, Depend, CheckBprop) from . import _quant_ops from ._quant_ops import * +from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount, + CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull) from .thor_ops import * __all__ = [ @@ -260,6 +260,7 @@ __all__ = [ 'DepthwiseConv2dNative', 'UnsortedSegmentSum', 'UnsortedSegmentMin', + 'UnsortedSegmentProd', "AllGather", "AllReduce", "ReduceScatter", @@ -341,7 +342,12 @@ __all__ = [ "InTopK", "CropAndResize", "LRN", - "Mod" + "Mod", + "PopulationCount", + "ParallelConcat", + "EmbeddingLookup", + "Push", + "Pull" ] __all__.sort() diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index be7e901757..2d17da0028 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -394,76 +394,6 @@ class AscendDequant(PrimitiveWithInfer): return mstype.float16 -class EmbeddingLookup(PrimitiveWithInfer): - """ - Returns a slice of input tensor based on the specified indices. - - This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has three more inputs: - `offset`, `reduce_scatter_flag` and `split_num`. This primitive runs on the host instead of devices. - - Inputs: - - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. - The Tensor slice, instead of the entire Tensor. - - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. - Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, - and the exceeding part will be filled with 0 in the output. - - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices - are equal to `input_indices` minus `offset`. - - **reduce_scatter_flag** (bool) - Specifies whether perform reduce_scatter on host or not. - Only constant value is allowed. - - **split_num** (int) - Specifies the number of partitions of the reduce_scatter produces. This variable - is used only if `reduce_scatter_flag` is True. Only constant value is allowed. - - - Outputs: - Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. - - Examples: - >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) - >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) - >>> offset = 4 - >>> reduce_scatter_flag = False - >>> split_num = 1 - >>> out = P.EmbeddingLookup()(input_params, input_indices, offset, reduce_scatter_flag, split_num) - [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] - """ - @prim_attr_register - def __init__(self): - """init index_select""" - self.__setattr_flag__ = True - self.init_prim_io_names(inputs=['params', 'indices', 'offset', 'reduce_scatter_flag', 'split_num'], - outputs=['output']) - self.add_prim_attr('primitive_target', 'CPU') - - def __infer__(self, params, indices, offset, reduce_scatter_flag=False, split_num=2): - validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) - validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) - validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) - validator.check_subclass("split_num", split_num['dtype'], mstype.int_, self.name) - if split_num['value'] < 1: - raise ValueError("The parameter 'split_num' must be positive, but got %d." % split_num) - params_shp = params['shape'] - out_shape = indices['shape'] + params_shp[1:] - if reduce_scatter_flag is None: - raise ValueError("The value of 'reduce_scatter_flag' is None.") - reduce_scatter_flag_value = reduce_scatter_flag['value'] - if split_num is None: - raise ValueError("The value of 'split_num_value' is None.") - split_num_value = split_num['value'] - if reduce_scatter_flag_value is True: - # Partition the tensor along the dimension 0. The shape size of dimension 0 should be divisible by - # (split_num * 8) - if out_shape[0] % (split_num_value * 8) != 0: - raise ValueError("The dimension 0 of the shape: %d, is not divisible by: %d." % - (out_shape[0], (split_num_value * 8))) - # After 'Concat' on host, the shape size of dimension 0 is: out_shape[0] // 8 - out_shape[0] = out_shape[0] // 8 - out = {'shape': out_shape, - 'dtype': params['dtype'], - 'value': None} - return out - - class SparseApplyFtrlNoReturn(PrimitiveWithInfer): """ Update relevant entries according to the FTRL-proximal scheme. @@ -747,7 +677,7 @@ class MatrixDiagPart(PrimitiveWithInfer): Tensor, data type same as input `x`. The shape should be x.shape[:-2] + [min(x.shape[-2:])]. Examples: - >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) + >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) >>> assist = Tensor(np.arange(-12, 0).reshape(3, 2, 2), mindspore.float32) >>> matrix_diag_part = P.MatrixDiagPart() >>> result = matrix_diag_part(x, assist) @@ -789,11 +719,11 @@ class MatrixSetDiag(PrimitiveWithInfer): Tensor, data type same as input `x`. The shape same as `x`. Examples: - >>> x = Tensor([[[-1, 0], [0, 1]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) + >>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) >>> matrix_set_diag = P.MatrixSetDiag() >>> result = matrix_set_diag(x, diagonal) - [[[-1, 0], [0, 2]], [-1, 0], [0, 1]], [[-1, 0], [0, 1]]] + [[[-1, 0], [0, 2]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]] """ @@ -812,10 +742,10 @@ class MatrixSetDiag(PrimitiveWithInfer): validator.check("x shape", x_shape, "assist shape", assist_shape, Rel.EQ, self.name) if x_shape[-2] < x_shape[-1]: - validator.check("x shape excluding the last dimension", x_shape[:-1], "diagnoal shape", - diagonal_shape, Rel.EQ, self.name) + validator.check("diagnoal shape", diagonal_shape, "x shape excluding the last dimension", + x_shape[:-1], Rel.EQ, self.name) else: - validator.check("x shape excluding the second to last dimension", x_shape[:-2]+x_shape[-1:], - "diagonal shape", diagonal_shape, Rel.EQ, self.name) + validator.check("diagonal shape", diagonal_shape, "x shape excluding the second last dimension", + x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name) return assist_shape diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 7b7e8b2b64..1e28a56db1 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -601,51 +601,6 @@ class SparseGatherV2(GatherV2): >>> out = P.SparseGatherV2()(input_params, input_indices, axis) """ -class EmbeddingLookup(PrimitiveWithInfer): - """ - Returns a slice of input tensor based on the specified indices and axis. This Primitive has the similar - functionality as GatherV2, but has one more inputs: `offset`. - This primitive runs on the acipu devices. - - Inputs: - - **params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. - The Tensor slice, instead of the entire Tensor. - - **indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. - Specifies the indices of elements of the original Tensor. Values can be out of range of `params`, - and the exceeding part will be filled with 0 in the output. - The indices to do lookup operation whose data type should be mindspore.int32 or mindspore.int64. - - **offset** (int) - Specifies the offset value of this `params` slice. Thus the real indices - are equal to `indices` minus `offset`. - - - Outputs: - Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. - - Examples: - >>> params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) - >>> indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) - >>> offset = 4 - >>> out = P.EmbeddingLookup()(params, indices, offset) - [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] - """ - @prim_attr_register - def __init__(self): - """init index_select""" - self.init_prim_io_names(inputs=['params', 'indices', 'offset'], - outputs=['output']) - - def __infer__(self, params, indices, offset): - validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) - valid_types = (mstype.int32, mstype.int64) - validator.check_tensor_type_same({"indices": indices['dtype']}, valid_types, self.name) - validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) - params_shp = params['shape'] - out_shape = indices['shape'] + params_shp[1:] - out = {'shape': out_shape, - 'dtype': params['dtype'], - 'value': None} - return out - class Split(PrimitiveWithInfer): """ @@ -688,8 +643,10 @@ class Split(PrimitiveWithInfer): validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name) output_valid_check = x_shape[self.axis] % self.output_num - validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ, - self.name) + if output_valid_check != 0: + raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by" + f" output_num {self.output_num}") + x_shape[self.axis] = int(x_shape[self.axis] / self.output_num) out_shapes = [] out_dtypes = [] @@ -1031,7 +988,7 @@ class InvertPermutation(PrimitiveWithInfer): values can not be negative. Inputs: - - **input_x** (Union(tuple[int]) - The input tuple is constructed by multiple + - **input_x** (Union(tuple[int], list[int]) - The input is constructed by multiple integers, i.e., :math:`(y_1, y_2, ..., y_S)` representing the indices. The values must include 0. There can be no duplicate values or negative values. Only constant value is allowed. @@ -1059,6 +1016,12 @@ class InvertPermutation(PrimitiveWithInfer): validator.check_value_type("shape", x_shp, [tuple, list], self.name) if mstype.issubclass_(x['dtype'], mstype.tensor): raise ValueError(f'For \'{self.name}\' the input value must be non-Tensor.') + for shp in x_shp: + if shp != []: + x_rank = len(np.array(x_value, np.int64).shape) + raise ValueError(f'For \'{self.name}\' the rank of input must be 1, but got {x_rank}.') + for i, value in enumerate(x_value): + validator.check_value_type("input[%d]" % i, value, [int], self.name) z = [x_value[i] for i in range(len(x_value))] z.sort() @@ -1457,6 +1420,58 @@ class UnsortedSegmentMin(PrimitiveWithInfer): return out +class UnsortedSegmentProd(PrimitiveWithInfer): + """ + Computes the product along segments of a tensor. + + Inputs: + - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`. + With float16, float32 or int32 data type. + - **segment_ids** (Tensor) - A `1-D` tensor whose shape is :math:`(x_1)`. Data type must be int32. + - **num_segments** (int) - The value spcifies the number of distinct `segment_ids`, + should be greater than 0. + + Outputs: + Tensor, Set the number of `num_segments` as `N`, the shape is :math:`(N, x_2, ..., x_R)`. + + Examples: + >>> input_x = Tensor(np.array([[1, 2, 3], [4, 5, 6], [4, 2, 1]]).astype(np.float32)) + >>> segment_ids = Tensor(np.array([0, 1, 0]).astype(np.int32)) + >>> num_segments = 2 + >>> unsorted_segment_prod = P.UnsortedSegmentProd() + >>> unsorted_segment_prod(input_x, segment_ids, num_segments) + [[4., 4., 3.], [4., 5., 6.]] + """ + + @prim_attr_register + def __init__(self): + """init UnsortedSegmentProd""" + self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) + + def __infer__(self, x, segment_ids, num_segments): + x_type = x['dtype'] + x_shape = x['shape'] + segment_ids_shape = segment_ids['shape'] + validator.check_subclass("input_x", x_type, mstype.tensor, self.name) + validator.check_value_type("x_shape", x_shape, [list], self.name) + valid_type = [mstype.float16, mstype.float32, mstype.int32] + validator.check_tensor_type_same({"x": x['dtype']}, valid_type, self.name) + validator.check_tensor_type_same({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name) + validator.check_integer("rank of segment_ids_shape", len(segment_ids_shape), 1, Rel.EQ, self.name) + validator.check(f'first shape of input_x', x_shape[0], + 'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name) + num_segments_v = num_segments['value'] + validator.check_value_type('num_segments', num_segments_v, [int], self.name) + validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name) + segment_ids_shape_len = len(segment_ids_shape) + out_shape = [num_segments_v] + out_shape += x_shape[segment_ids_shape_len:] + out = {'shape': out_shape, + 'dtype': mstype.tensor_type(x_type.element_type()), + 'value': None} + return out + + class Concat(PrimitiveWithInfer): r""" Concat tensor in specified axis. @@ -1508,6 +1523,60 @@ class Concat(PrimitiveWithInfer): return out +class ParallelConcat(PrimitiveWithInfer): + r""" + Concat tensor in the first dimension. + + Concat input tensors along with the first dimension. + + Note: + The input tensors are all required to have size 1 in the first dimension. + + Inputs: + - **values** (tuple, list) - Tuple or list of input tensors. The data type and shape of these + tensors must be same. + + Outputs: + Tensor, data type same as `values`. + + Examples: + >>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32)) + >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32)) + >>> op = P.ParallelConcat() + >>> output = op((data1, data2)) + [[0, 1], [2, 1]] + """ + + @prim_attr_register + def __init__(self): + """init ParallelConcat""" + + def __infer__(self, values): + x_shp = values['shape'] + x_type = values['dtype'] + + validator.check_integer(f'x_shp length', len(x_shp), 1, Rel.GE, self.name) + + args = {f"x_type[{i}]": elem for i, elem in enumerate(x_type)} + validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name) + + first_elem = x_shp[0] + for i, elem in enumerate(x_shp[1:]): + j = i + 1 + validator.check_integer(f'x_shp[{j}][0]', elem[0], 1, Rel.EQ, self.name) + validator.check(f"x_shp[0] shape", first_elem, f"x_shp[{j}] shape", elem, Rel.EQ, self.name) + + ret_shp = x_shp[0].copy() + ret_shp[0] = len(x_shp) + self.add_prim_attr('shape', ret_shp) + self.add_prim_attr('N', len(x_shp)) + + out = {'shape': ret_shp, + 'dtype': x_type[0], + 'value': None} + return out + + def _get_pack_shape(x_shape, x_type, axis, prim_name): """for pack output shape""" validator.check_value_type("shape", x_shape, [tuple, list], prim_name) @@ -3176,3 +3245,50 @@ class TransShape(PrimitiveWithInfer): return {'shape': shp, 'dtype': dtype, 'value': None} + + +class EmbeddingLookup(PrimitiveWithInfer): + """ + Returns a slice of input tensor based on the specified indices. + + This Primitive has the similar functionality as GatherV2 operating on `axis = 0`, but has one more inputs: + `offset`. + + Inputs: + - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + The Tensor slice, instead of the entire Tensor. + - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. + Specifies the indices of elements of the original Tensor. Values can be out of range of `input_params`, + and the exceeding part will be filled with 0 in the output. + - **offset** (int) - Specifies the offset value of this `input_params` slice. Thus the real indices + are equal to `input_indices` minus `offset`. + + Outputs: + Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`. + + Examples: + >>> input_params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mindspore.float32) + >>> input_indices = Tensor(np.array([[5, 2], [8, 5]]), mindspore.int32) + >>> offset = 4 + >>> out = P.EmbeddingLookup()(input_params, input_indices, offset) + [[[10, 11], [0 ,0]], [[0, 0], [10, 11]]] + """ + @prim_attr_register + def __init__(self): + """init index_select""" + self.__setattr_flag__ = True + self.init_prim_io_names(inputs=['params', 'indices', 'offset'], + outputs=['output']) + + def __infer__(self, params, indices, offset): + validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) + validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name) + validator.check_subclass("offset", offset['dtype'], mstype.int_, self.name) + params_shp = params['shape'] + if len(params_shp) != 2: + raise ValueError("The dimension of 'params' in EmbeddingLookup must be 2, but got %d." % len(params_shp)) + out_shape = indices['shape'] + params_shp[1:] + out = {'shape': out_shape, + 'dtype': params['dtype'], + 'value': None} + return out diff --git a/mindspore/ops/operations/image_ops.py b/mindspore/ops/operations/image_ops.py index 1e366b5ea6..437cda3301 100644 --- a/mindspore/ops/operations/image_ops.py +++ b/mindspore/ops/operations/image_ops.py @@ -117,8 +117,8 @@ class CropAndResize(PrimitiveWithInfer): validator.check("crop_height", crop_size_value[0], "minimum", 0, Rel.GT, self.name) validator.check("crop_width", crop_size_value[1], "minimum", 0, Rel.GT, self.name) # check crop_size element type - validator.check("crop_height dtype", crop_size_dtype[0], mstype.int32, self.name) - validator.check("crop_width dtype", crop_size_dtype[1], mstype.int32, self.name) + validator.check("crop_height dtype", crop_size_dtype[0], "expected", mstype.int32, Rel.EQ, self.name) + validator.check("crop_width dtype", crop_size_dtype[1], "expected", mstype.int32, Rel.EQ, self.name) num_boxes = boxes_shape[0] crop_height = crop_size_value[0] diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 9acd75d8e4..a9bdf07d28 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -234,7 +234,7 @@ class Softsign(PrimitiveWithInfer): \text{output} = \frac{\text{input_x}}{1 + \abs{\text{input_x}}}, Inputs: - - **input_x** (Tensor) - The input tensor whose data type should be float. + - **input_x** (Tensor) - The input tensor whose data type should be float16 or float32. Outputs: Tensor, with the same type and shape as the `input_x`. @@ -255,7 +255,7 @@ class Softsign(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) + validator.check_tensor_type_same({'input_x': input_x}, [mstype.float16, mstype.float32], self.name) return input_x @@ -1014,6 +1014,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): def infer_dtype(self, x_dtype, w_dtype): args = {'x': x_dtype, 'w': w_dtype} validator.check_tensor_type_same(args, mstype.number_type, self.name) + if x_dtype.element_type() == mstype.int8: + return mstype.tensor_type(mstype.int32) return x_dtype @@ -1930,7 +1932,7 @@ class ApplyRMSProp(PrimitiveWithInfer): >>> decay = 0.0 >>> momentum = 1e-10 >>> epsilon = 0.001 - >>> result = apply_rms(input_x, mean_square, moment, grad, learning_rate, decay, momentum, epsilon) + >>> result = apply_rms(input_x, mean_square, moment, learning_rate, grad, decay, momentum, epsilon) (-2.9977674, 0.80999994, 1.9987665) """ @@ -2772,6 +2774,7 @@ class ROIAlign(PrimitiveWithInfer): feature map coordinates. Suppose the height of a RoI is `ori_h` in the raw image and `fea_h` in the input feature map, the `spatial_scale` should be `fea_h / ori_h`. sample_num (int): Number of sampling points. Default: 2. + roi_end_mode (int): Number must be 0 or 1. Default: 1. Inputs: - **features** (Tensor) - The input features, whose shape should be `(N, C, H, W)`. @@ -2788,22 +2791,25 @@ class ROIAlign(PrimitiveWithInfer): Examples: >>> input_tensor = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32) >>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32) - >>> roi_align = P.ROIAlign(1, 1, 0.5, 2) + >>> roi_align = P.ROIAlign(2, 2, 0.5, 2) >>> output_tensor = roi_align(input_tensor, rois) >>> assert output_tensor == Tensor(np.array([[[[2.15]]]]), mindspore.float32) """ @prim_attr_register - def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2): + def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2, roi_end_mode=1): """init ROIAlign""" validator.check_value_type("pooled_height", pooled_height, [int], self.name) validator.check_value_type("pooled_width", pooled_width, [int], self.name) validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) validator.check_value_type("sample_num", sample_num, [int], self.name) + validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name) + validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name) self.pooled_height = pooled_height self.pooled_width = pooled_width self.spatial_scale = spatial_scale self.sample_num = sample_num + self.roi_end_mode = roi_end_mode def infer_shape(self, inputs_shape, rois_shape): return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width] @@ -4803,19 +4809,19 @@ class CTCLoss(PrimitiveWithInfer): preprocess_collapse_repeated (bool): If True, repeated labels are collapsed prior to the CTC calculation. Default: False. ctc_merge_repeated (bool): If False, during CTC calculation, repeated non-blank labels will not be merged - and are interpreted as individual labels. This is a simplfied version if CTC. + and are interpreted as individual labels. This is a simplfied version of CTC. Default: True. ignore_longer_outputs_than_inputs (bool): If True, sequences with longer outputs than inputs will be ignored. Default: False. Inputs: - **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is - :math:`(max_time, batch_size, num_class)`. `num_class` should be `num_labels + 1` classes, `num_labels` - indicates the number of actual labels. Blank labels are reserved. + :math:`(max_time, batch_size, num_classes)`. `num_classes` should be `num_labels + 1` classes, `num_labels` + indicates the number of actual labels. Blank labels are reserved. Default blank label is `num_classes - 1`. - **labels_indices** (Tensor) - The indices of labels. `labels_indices[i, :] == [b, t]` means `labels_values[i]` stores the id for `(batch b, time t)`. The type must be int64 and rank must be 2. - **labels_values** (Tensor) - A `1-D` input tensor. The values associated with the given batch and time. The - type must be int32. `labels_values[i]` must in the range of `[0, num_class)`. + type must be int32. `labels_values[i]` must in the range of `[0, num_classes)`. - **sequence_length** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`. The type must be int32. Each value in the tensor should not greater than `max_time`. @@ -4849,6 +4855,7 @@ class CTCLoss(PrimitiveWithInfer): def infer_shape(self, inputs, labels_indices, labels_values, sequence_length): validator.check_integer("inputs rank", len(inputs), 3, Rel.EQ, self.name) validator.check_integer("labels_indices rank", len(labels_indices), 2, Rel.EQ, self.name) + validator.check_integer("labels_indices dim one", labels_indices[1], 2, Rel.EQ, self.name) validator.check_integer("labels_values rank", len(labels_values), 1, Rel.EQ, self.name) validator.check_integer("sequence_length rank", len(sequence_length), 1, Rel.EQ, self.name) validator.check('labels_indices size', labels_indices[0], 'labels_values size', @@ -5027,8 +5034,7 @@ class LRN(PrimitiveWithInfer): bias (float): An offset (usually positive to avoid dividing by 0). alpha (float): A scale factor, usually positive. beta (float): An exponent. - norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS", "WITHIN_CHANNEL". - Default: "ACROSS_CHANNELS". + norm_region (str): Specify normalization region. Options: "ACROSS_CHANNELS". Default: "ACROSS_CHANNELS". Inputs: - **x** (Tensor) - A 4D Tensor with float16 or float32 data type. @@ -5050,10 +5056,66 @@ class LRN(PrimitiveWithInfer): validator.check_value_type("alpha", alpha, [float], self.name) validator.check_value_type("beta", beta, [float], self.name) validator.check_value_type("norm_region", norm_region, [str], self.name) + validator.check_string('norm_region', norm_region, ['ACROSS_CHANNELS'], self.name) + validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name) def infer_dtype(self, x_dtype): validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name) return x_dtype def infer_shape(self, x_shape): + validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name) return x_shape + +class CTCLossV2(PrimitiveWithInfer): + r""" + Calculates the CTC(Connectionist Temporal Classification) loss. Also calculates the gradient. + Note: + - Cudnn Uses label value of for the `blank` + + Inputs: + - **inputs** (Tensor) - The input Tensor should be a `3-D` tensor whose shape is + :math:`(max_time, batch_size, num_class)`. `num_class` should be `num_labels + 1` classes, `num_labels` + indicates the number of actual labels. Blank labels are reserved. + - **labels** (Tensor) - The labels Tensor should be a `1-D` tensor whose shape is + :math:`(\sigma{label_lengths})` + or `2-D` tensor whose shape is + :math:`(max_time, max{label_lengths})` + The type must be int32. + - **input_lengths** (Tensor) - A `1-D` input tensor whose shape is + :math:`(batch_size,)`. The values should be batch. The type must be int32. + - **label_lengths** (Tensor) - A tensor containing sequence lengths with the shape of :math:`(batch_size)`. + The type must be int32. Each value in the tensor should not greater than `max_time`. + + Outputs: + - **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch_size)`. Has the same + type with `inputs`. + - **gradient** (Tensor) - The gradient of `loss`. Has the same type and shape with `inputs`. + + Examples: + >>> inputs = Tensor(np.random.random((2, 2, 3)), mindspore.float32) + >>> labels = Tensor(np.array([[0, 0], [1, 0]]), mindspore.int32) + >>> input_lengths = Tensor(np.array([3, 3, 3]), mindspore.int32) + >>> label_lengths = Tensor(np.array([3, 3, 3]), mindspore.int32) + >>> ctc_loss = P.CTCLossV2() + >>> output = ctc_loss(inputs, labels, input_lengths, label_lengths) + """ + @prim_attr_register + def __init__(self): + pass + + def infer_dtype(self, input_dtype, labels_dtype, input_lengths_dtype, label_lengths_dtype): + validator.check_tensor_type_same({"input": input_dtype}, (mstype.float32,), self.name) + validator.check_tensor_type_same({"labels": labels_dtype}, (mstype.int32,), self.name) + validator.check_tensor_type_same({"input_lengths": input_lengths_dtype}, (mstype.int32,), self.name) + validator.check_tensor_type_same({"target_lengths": label_lengths_dtype}, (mstype.int32,), self.name) + return mstype.float32, mstype.float32 + + def infer_shape(self, input_shape, labels_shape, input_lengths_shape, label_lengths_shape): + validator.check_integer("input shape", len(input_shape), 3, Rel.EQ, self.name) + validator.check_number_range("labels shape", len(labels_shape), 1, 2, Rel.INC_BOTH, self.name) + validator.check_integer("input lengths shape", len(input_lengths_shape), 1, Rel.EQ, self.name) + validator.check_integer("label lengths shape", len(label_lengths_shape), 1, Rel.EQ, self.name) + validator.check_integer("input[1]", input_shape[1], input_lengths_shape[0], Rel.EQ, self.name) + validator.check_integer("input[1]", input_shape[1], label_lengths_shape[0], Rel.EQ, self.name) + return (input_shape[1],), input_shape diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index b6b938d800..a58403f883 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -51,6 +51,7 @@ class Assign(PrimitiveWithInfer): ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) ) + @prim_attr_register def __init__(self): self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) @@ -59,7 +60,9 @@ class Assign(PrimitiveWithInfer): return variable def infer_dtype(self, variable, value): - # Add a type validation later when we don't have to assign a value to RefKey. + if variable != mstype.type_refkey: + validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name) + validator.check_scalar_or_tensor_type_same({"value": value}, mstype.number_type, self.name) return variable @@ -324,6 +327,7 @@ class Partial(Primitive): partial_func = functools.partial(func, *args[1:]) return partial_func + class Depend(Primitive): """ Depend is used for process side-effect operations. @@ -457,3 +461,83 @@ class ConfusionMatrix(PrimitiveWithInfer): args = {"labels": labels, "predictions": predictions} validator.check_tensor_type_same(args, (mstype.number_type), self.name) return labels + + +class PopulationCount(PrimitiveWithInfer): + r""" + Calculate population count. + + Inputs: + - **input** (Tensor) - The data type should be int16 or uint16. + + Outputs: + Tensor, with shape same as the input. + + Examples: + >>> population_count = P.PopulationCount() + >>> x_input = Tensor([0, 1, 3], mindspore.int16) + >>> population_count(x_input) + """ + + @prim_attr_register + def __init__(self): + pass + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_dtype): + args = {"x": x_dtype} + validator.check_tensor_type_same(args, (mstype.int16, mstype.uint16,), self.name) + return mstype.tensor_type(mstype.uint8) + +class Push(PrimitiveWithInfer): + """ + Pushing the inputs of the corresponding optimizer to parameter server. + + Args: + optim_type (string): The optimizer type. Default: 'ApplyMomentum'. + only_shape_indices (list): The indices of input of which only shape + will be pushed to parameter server. Default: None. + + Inputs: + - **optim_inputs** (tuple) - The inputs for this kind of optimizer. + - **optim_input_shapes** (tuple) - The shapes of the inputs. + + Outputs: + Tensor, the key of the weight which needs to be updated. + """ + + @prim_attr_register + def __init__(self, optim_type='ApplyMomentum', only_shape_indices=None): + """init Push""" + self.init_prim_io_names(inputs=['optim_inputs', 'optim_input_shapes'], outputs=['key']) + + def infer_shape(self, inputs, shapes): + return [1] + + def infer_dtype(self, inputs, shapes): + return mstype.uint64 + +class Pull(PrimitiveWithInfer): + """ + Pulling weight from parameter server. + + Inputs: + - **key** (Tensor) - The key of the weight. + - **weight** (Tensor) - The weight to be updated. + + Outputs: + None. + """ + + @prim_attr_register + def __init__(self): + """init Pull""" + self.init_prim_io_names(inputs=['key', 'weight'], outputs=['output']) + + def infer_shape(self, key_shape, weight_shape): + return [1] + + def infer_dtype(self, key_dtype, weight_dtype): + return mstype.float32 diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 7ceb687778..cb34e9ff24 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -146,7 +146,7 @@ class Primitive(Primitive_): Check whether or not certain inputs should go into backend. Subclass in need should override this method. Args: - Same as arguments of current Primitive + *args(Primitive args): Same as arguments of current Primitive. Returns: A tuple of two elements, first element indicates whether or not we should filter out current arguments; @@ -237,12 +237,14 @@ class PrimitiveWithInfer(Primitive): """ Infer output shape based on input shape. - Args: - inputs (tuple(int)): dimensions of input tensors. - outputs (tuple(int)): dimensions of output tensors. - Note: The shape of scalar is an empty tuple. + + Args: + args (tuple(int)): shapes of input tensors. + + Return: + `tuple(int)`, shapes of output tensors. """ return None @@ -251,8 +253,10 @@ class PrimitiveWithInfer(Primitive): Infer output dtype based on input dtype. Args: - inputs (mstype): data type of inputs. - outputs (mstype): data type of outputs. + args (:class:`mindspore.dtype`): data type of inputs. + + Return: + :class:`mindspore.dtype`, data type of outputs. """ return None @@ -261,8 +265,10 @@ class PrimitiveWithInfer(Primitive): Infer output value based on input value at compile time. Args: - inputs (any): value of inputs. - outputs (any): value of outputs. + args (Any): value of inputs. + + Return: + Value of outputs. Return `None` for, cat not infer the value at compile time. """ return None diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index c5b4d57702..68f070d4a5 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -122,47 +122,6 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast): "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}" .format(parallel_mode, parameter_broadcast)) - -PARAMETER_CLONED_INDEX = 0 - - -class _CloneInfo(): - """ - The clone info of parameter. - - Attributes: - be_cloned (bool): Whether the parameter is cloned. - cloned (bool): Whether the parameter clone from other parameter. - be_cloned_index (tuple): If the parameter is cloned, generate one index per clone. - cloned_index (int): If the parameter clone from other parameter, it has a unique index. - """ - def __init__(self): - self.be_cloned = False - self.cloned = False - self.be_cloned_index = [] - self.cloned_index = None - - -def _set_clone_info(clone_from, clone_to): - """ - Set the clone info. - - Args: - clone_from (_CloneInfo): The clone info of be_cloned parameter. - clone_to (_CloneInfo): The clone info of cloned parameter. - """ - global PARAMETER_CLONED_INDEX - clone_to.be_cloned = False - clone_to.cloned = True - clone_to.be_cloned_index = [] - clone_to.cloned_index = PARAMETER_CLONED_INDEX - - clone_from.be_cloned = True - clone_from.be_cloned_index.append(PARAMETER_CLONED_INDEX) - - PARAMETER_CLONED_INDEX = PARAMETER_CLONED_INDEX + 1 - - def _get_python_op(op_name, op_path, instance_name, arglist): """Get python operator.""" module = __import__(op_path, fromlist=["None"]) diff --git a/mindspore/train/callback/_loss_monitor.py b/mindspore/train/callback/_loss_monitor.py index 766777e878..15a095c5cb 100644 --- a/mindspore/train/callback/_loss_monitor.py +++ b/mindspore/train/callback/_loss_monitor.py @@ -14,7 +14,6 @@ # ============================================================================ """LossMonitor Callback class.""" -import time import numpy as np from mindspore.common.tensor import Tensor @@ -32,62 +31,32 @@ class LossMonitor(Callback): Args: per_print_times (int): Print loss every times. Default: 1. - lr_init (numpy array): train learning rate. Default: None. Raises: ValueError: If print_step is not int or less than zero. - - Examples: - >>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy()) """ - def __init__(self, per_print_times=1, lr_init=None): + def __init__(self, per_print_times=1): super(LossMonitor, self).__init__() if not isinstance(per_print_times, int) or per_print_times < 0: raise ValueError("print_step must be int and >= 0.") self._per_print_times = per_print_times - self.lr_init = lr_init - - def epoch_begin(self, run_context): - self.losses = [] - self.epoch_time = time.time() - - def epoch_end(self, run_context): - cb_params = run_context.original_args() - epoch_mseconds = (time.time() - self.epoch_time) * 1000 - per_step_mseconds = epoch_mseconds / cb_params.batch_num - print("Epoch time: {:5.3f}, per step time: {:5.3f}, " - "avg loss: {:5.3f}".format(epoch_mseconds, - per_step_mseconds, - np.mean(self.losses))) - print("*" * 60) - - def step_begin(self, run_context): - self.step_time = time.time() def step_end(self, run_context): cb_params = run_context.original_args() - step_mseconds = (time.time() - self.step_time) * 1000 - step_loss = cb_params.net_outputs + loss = cb_params.net_outputs - if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): - step_loss = step_loss[0] - if isinstance(step_loss, Tensor): - step_loss = np.mean(step_loss.asnumpy()) + if isinstance(loss, (tuple, list)): + if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray): + loss = loss[0] - self.losses.append(step_loss) - cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1 + if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray): + loss = np.mean(loss.asnumpy()) - if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): - raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " - "Invalid loss, terminating training.".format( - cb_params.cur_epoch_num - 1, cb_params.epoch_num, - cur_step_in_epoch, cb_params.batch_num)) + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)): + raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( + cb_params.cur_epoch_num, cur_step_in_epoch)) if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: - print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " - "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}ms]".format( - cb_params.cur_epoch_num, cb_params.epoch_num, - cur_step_in_epoch, int(cb_params.batch_num), - step_loss, np.mean(self.losses), - step_mseconds), flush=True) + print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True) diff --git a/mindspore/train/callback/_summary_collector.py b/mindspore/train/callback/_summary_collector.py index 1550c3c55c..ded0e9a650 100644 --- a/mindspore/train/callback/_summary_collector.py +++ b/mindspore/train/callback/_summary_collector.py @@ -126,10 +126,12 @@ class SummaryCollector(Callback): >>> >>> # Only collect metric, custom lineage data and record data that collected by the summary operator, >>> # others are not collected - >>> specified = {'collect_metric':True, 'custom_lineage_data': {'version': 'resnet50_v1'}} + >>> specified = {'collect_metric': True} >>> summary_collector = SummaryCollector('./summary_dir', >>> collect_specified_data=specified, - >>> keep_default_action=False) + >>> keep_default_action=False, + >>> custom_lineage_data={'version': 'resnet50_v1'} + >>> ) >>> model.train(epoch, dataset, callbacks=summary_collector) """ diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 14797e568b..75e1deabc4 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -14,6 +14,7 @@ # ============================================================================ """Dataset help for minddata dataset""" import math +import os from mindspore._checkparam import check_bool from .. import context @@ -60,7 +61,11 @@ class DatasetHelper: if context.get_context("device_target") == "Ascend": iterclass = _DatasetIterMSLoopSink elif context.get_context("device_target") == "GPU": - iterclass = _DatasetIterMS + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + iterclass = _DatasetIterPSLite + else: + iterclass = _DatasetIterMS elif context.get_context("device_target") == "CPU": raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.") else: @@ -131,6 +136,9 @@ class _DatasetIterMSLoopSink(_DatasetIter): def __init__(self, dataset): super(_DatasetIterMSLoopSink, self).__init__(dataset) self.loop_count = self.get_loop_count(dataset) + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + self.loop_count = 1 # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. @@ -154,6 +162,18 @@ class _DatasetIterMS(_DatasetIter): self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) +class _DatasetIterPSLite(_DatasetIter): + """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED""" + def __init__(self, dataset): + super(_DatasetIterPSLite, self).__init__(dataset) + self.loop_count = 1 + self.loop_size = 1 + self.op = None + def op(): + return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1) + self.op = op + + class _DatasetIterGE(_DatasetIter): """Iter for ge""" def __init__(self, dataset): diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 79bd6bc90b..74fd668e82 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -15,6 +15,7 @@ """Model.""" from collections.abc import Iterable +import os import numpy as np from mindspore import log as logger @@ -350,6 +351,9 @@ class Model: cb_params.train_dataset = train_dataset cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.train_dataset_element = None + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + epoch = 1 # build callback list with _CallbackManager(callbacks) as list_callback: diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index bc44ba22c2..b553373f10 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -33,8 +33,10 @@ from ...ops.operations import _inner_ops as inner from ...train import serialization from . import quant_utils -_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, - nn.ReLU6: quant.ReLU6Quant, +_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant, + nn.ReLU6: quant.ActQuant, + nn.LeakyReLU: quant.ActQuant, + nn.Sigmoid: quant.ActQuant, nn.HSigmoid: quant.HSigmoidQuant, nn.HSwish: quant.HSwishQuant} @@ -112,7 +114,6 @@ class ConvertToQuantNetwork: def run(self): self.network.update_cell_prefix() network = self._convert_subcells2quant(self.network) - network = _AddFakeQuantInput(network) self.network.update_cell_type("quant") return network @@ -257,9 +258,9 @@ class ConvertToQuantNetwork: def _convert_activation(self, activation): act_class = activation.__class__ if act_class not in _ACTIVATION_MAP: - raise ValueError( - "Unsupported activation in auto quant: ", act_class) - return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, + raise ValueError("Unsupported activation in auto quant: ", act_class) + return _ACTIVATION_MAP[act_class](activation=act_class, + num_bits=self.act_bits, quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, @@ -273,16 +274,20 @@ class ExportToQuantInferNetwork: Args: network (Cell): MindSpore network API `convert_quant_network`. inputs (Tensor): Input tensors of the `quantization aware training network`. + mean (int): Input data mean. Default: 127.5. + std_dev (int, float): Input data variance. Default: 127.5. Returns: Cell, GEIR backend Infer network. """ __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] - def __init__(self, - network, - *inputs): + def __init__(self, network, mean, std_dev, *inputs): network = validator.check_isinstance('network', network, (nn.Cell,)) + # quantize for inputs: q = f / scale + zero_point + # dequantize for outputs: f = (q - zero_point) * scale + self.input_scale = round(mean) + self.input_zero_point = 1 / std_dev self.data_type = mstype.int8 self.network = copy.deepcopy(network) self.all_parameters = {p.name: p for p in self.network.get_parameters()} @@ -313,11 +318,14 @@ class ExportToQuantInferNetwork: info = self.quant_info_table.get(w_minq_name, None) if info: fack_quant_a_in_op, minq_name = info - maxq = self.all_parameters[minq_name[:-4] + "maxq"] - minq = self.all_parameters[minq_name] - scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) + if minq_name == 'input': + scale_a_in, zp_a_in = self.input_scale, self.input_zero_point + else: + maxq = self.all_parameters[minq_name[:-4] + "maxq"] + minq = self.all_parameters[minq_name] + scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) else: - logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}") + logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}") return None # Build the `Quant` `Dequant` op. @@ -325,7 +333,7 @@ class ExportToQuantInferNetwork: quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in)) sqrt_mode = False scale_deq = scale_a_out * scale_w - if scale_deq < 2 ** -14: + if (scale_deq < 2 ** -14).all(): scale_deq = np.sqrt(scale_deq) sqrt_mode = True dequant_op = inner.AscendDequant(sqrt_mode) @@ -393,7 +401,7 @@ class ExportToQuantInferNetwork: return network -def export(network, *inputs, file_name, file_format='GEIR'): +def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='GEIR'): """ Exports MindSpore quantization predict model to deploy with GEIR. @@ -401,16 +409,27 @@ def export(network, *inputs, file_name, file_format='GEIR'): network (Cell): MindSpore network produced by `convert_quant_network`. inputs (Tensor): Inputs of the `quantization aware training network`. file_name (str): File name of model to export. + mean (int): Input data mean. Default: 127.5. + std_dev (int, float): Input data variance. Default: 127.5. file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model. - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model. """ + supported_device = ["Ascend"] supported_formats = ['GEIR'] + mean = validator.check_type("mean", mean, (int, float)) + std_dev = validator.check_type("std_dev", std_dev, (int, float)) + + if context.get_context('device_target') not in supported_device: + raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) + if file_format not in supported_formats: raise ValueError('Illegal file format {}.'.format(file_format)) + network.set_train(False) + if file_format == 'GEIR': - exporter = ExportToQuantInferNetwork(network, *inputs) + exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs) deploy_net = exporter.run() serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format) diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index c4a8004012..69505970fd 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -45,7 +45,7 @@ def cal_quantization_params(input_min, raise ValueError("input min shape should equal to input max.") if len(input_min.shape) > 1: raise ValueError("input min and max shape should be one dim.") - if input_min > input_max: + if (input_min > input_max).all(): raise ValueError("input_min min should less than input max.") if (input_max == input_min).all(): # scale = 1.0, zp = 0.0 @@ -85,9 +85,7 @@ def cal_quantization_params(input_min, return scale, zp -def weight2int(data, - scale, - zero_point): +def weight2int(data, scale, zero_point): r""" Calculate int8/uint8 weight from fp32. the formula is defined as: @@ -103,12 +101,25 @@ def weight2int(data, weight (numpy.ndarray): The dimension of channel or 1. """ if scale.shape != zero_point.shape: - raise ValueError("scale and zero_point should have the same shape.") - if scale.shape[0] > 0: - scale = scale.reshape(1, -1) - zero_point = zero_point.reshape(1, -1) + raise ValueError("`scale` and `zero_point` should have the same shape.") + if scale.shape[0] < 0: + raise ValueError("`scale` and `zero_point` shape should greater than zero.") + if len(scale.shape) > 1: + # for perchannel + if scale.shape[0] == data.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(data.shape[1:]) + scale = scale.reshape(shape_list) + zero_point = zero_point.reshape(shape_list) + elif scale.shape[0] == data.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(data.shape[2:]) + scale = scale.reshape(shape_list) + zero_point = zero_point.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(data.shape)) - return np.round((data/scale) + zero_point) + return np.round((data / scale) + zero_point) def scale_zp_from_fack_quant_cell(cell, data_type): @@ -183,9 +194,20 @@ def fold_batchnorm(weight, cell_quant): beta = cell_quant.beta.data.asnumpy() epsilon = cell_quant.eps sigma = np.sqrt(variance + epsilon) - gamma = gamma.reshape(-1, 1, 1, 1) - sigma = sigma.reshape(-1, 1, 1, 1) - mean = mean.reshape(-1, 1, 1, 1) - weight = weight * gamma / sigma + + if gamma.shape[0] == weight.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(weight.shape[1:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + elif gamma.shape[0] == weight.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(weight.shape[2:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(weight.shape)) + + weight = weight * _gamma / _sigma bias = beta - gamma * mean / sigma return weight, bias diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index d74bee2706..bc74986321 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -302,7 +302,7 @@ def _save_graph(network, file_name): if graph_proto: with open(file_name, "wb") as f: f.write(graph_proto) - os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) + os.chmod(file_name, stat.S_IRUSR) def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): @@ -424,6 +424,7 @@ def export(net, *inputs, file_name, file_format='GEIR'): if is_training: net.set_train(mode=False) # export model + net.init_parameters_data() if file_format == 'GEIR': _executor.compile(net, *inputs, phase='export') _executor.export(net, file_name, file_format) @@ -462,19 +463,18 @@ def parse_print(print_file_name): List, element of list is Tensor. Raises: - ValueError: Print file is incorrect. + ValueError: The print file may be empty, please make sure enter the correct file name. """ - if not os.path.realpath(print_file_name): - raise ValueError("Please input the correct print file name.") + print_file_path = os.path.realpath(print_file_name) - if os.path.getsize(print_file_name) == 0: + if os.path.getsize(print_file_path) == 0: raise ValueError("The print file may be empty, please make sure enter the correct file name.") logger.info("Execute load print process.") print_list = Print() try: - with open(print_file_name, "rb") as f: + with open(print_file_path, "rb") as f: pb_content = f.read() print_list.ParseFromString(pb_content) except BaseException as e: diff --git a/model_zoo/README.md b/model_zoo/README.md index 2dde985679..1e392445af 100644 --- a/model_zoo/README.md +++ b/model_zoo/README.md @@ -134,43 +134,41 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework a | Parameters | AlexNet | | -------------------------- | ------- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Accuracy | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | +| Published Year | 2012 | +| Paper | [ImageNet Classification with Deep Convolutional Neural Networks](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-) | +| Resource | Ascend 910 | +| Features | support with Ascend, GPU | +| MindSpore Version | 0.5.0-beta | +| Dataset | CIFAR10 | +| Training Parameters | epoch=30, batch_size=32 | +| Optimizer | Momentum | +| Loss Function | SoftmaxCrossEntropyWithLogits | +| Accuracy | 88.23% | +| Speed | 1481fps | +| Loss | 0.108 | +| Params (M) | 61.10 | +| Checkpoint for Fine tuning | 445MB(.ckpt file) | +| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/alexnet| #### [LeNet](#table-of-contents) | Parameters | LeNet | | -------------------------- | ----- | -| Published Year | | -| Paper | | -| Resource | | -| Features | | -| MindSpore Version | | -| Dataset | | -| Training Parameters | | -| Optimizer | | -| Loss Function | | -| Accuracy | | -| Speed | | -| Loss | | -| Params (M) | | -| Checkpoint for Fine tuning | | -| Model for inference | | -| Scripts | | +| Published Year | 1998 | +| Paper | [Gradient-Based Learning Applied to Document Recognition](https://ieeexplore.ieee.org/abstract/document/726791) | +| Resource | Ascend 910 | +| Features | support with Ascend, GPU, CPU | +| MindSpore Version | 0.5.0-beta | +| Dataset | MNIST | +| Training Parameters | epoch=10, batch_size=32 | +| Optimizer | Momentum | +| Loss Function | SoftmaxCrossEntropyWithLogits | +| Accuracy | 98.52% | +| Speed | 18680fps | +| Loss | 0.004 | +| Params (M) | 0.06 | +| Checkpoint for Fine tuning | 483KB(.ckpt file) | +| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/lenet| ### Object Detection and Segmentation diff --git a/model_zoo/Transformer/train.py b/model_zoo/Transformer/train.py index 23c0eb78fd..ffd6b8c714 100644 --- a/model_zoo/Transformer/train.py +++ b/model_zoo/Transformer/train.py @@ -147,10 +147,11 @@ def run_transformer_train(): callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()] if args.enable_save_ckpt == "true": - ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps, - keep_checkpoint_max=args.save_checkpoint_num) - ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config) - callbacks.append(ckpoint_cb) + if device_num == 1 or (device_num > 1 and rank_id == 0): + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_steps, + keep_checkpoint_max=args.save_checkpoint_num) + ckpoint_cb = ModelCheckpoint(prefix='transformer', directory=args.save_checkpoint_path, config=ckpt_config) + callbacks.append(ckpoint_cb) if args.enable_lossscale == "true": scale_manager = DynamicLossScaleManager(init_loss_scale=cfg.init_loss_scale_value, diff --git a/model_zoo/alexnet/eval.py b/model_zoo/alexnet/eval.py index 4190451632..6a091aedd8 100644 --- a/model_zoo/alexnet/eval.py +++ b/model_zoo/alexnet/eval.py @@ -20,7 +20,7 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt import argparse from src.config import alexnet_cfg as cfg -from src.dataset import create_dataset_mnist +from src.dataset import create_dataset_cifar10 from src.alexnet import AlexNet import mindspore.nn as nn from mindspore import context @@ -50,8 +50,8 @@ if __name__ == "__main__": print("============== Starting Testing ==============") param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) - ds_eval = create_dataset_mnist(args.data_path, - cfg.batch_size, - status="test") + ds_eval = create_dataset_cifar10(args.data_path, + cfg.batch_size, + status="test") acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) print("============== {} ==============".format(acc)) diff --git a/model_zoo/alexnet/src/dataset.py b/model_zoo/alexnet/src/dataset.py index 6e9f310bed..651c76d6e3 100644 --- a/model_zoo/alexnet/src/dataset.py +++ b/model_zoo/alexnet/src/dataset.py @@ -23,7 +23,7 @@ from mindspore.common import dtype as mstype from .config import alexnet_cfg as cfg -def create_dataset_mnist(data_path, batch_size=32, repeat_size=1, status="train"): +def create_dataset_cifar10(data_path, batch_size=32, repeat_size=1, status="train"): """ create dataset for train or test """ diff --git a/model_zoo/alexnet/train.py b/model_zoo/alexnet/train.py index 184290c26c..df038d62a2 100644 --- a/model_zoo/alexnet/train.py +++ b/model_zoo/alexnet/train.py @@ -20,7 +20,7 @@ python train.py --data_path /YourDataPath import argparse from src.config import alexnet_cfg as cfg -from src.dataset import create_dataset_mnist +from src.dataset import create_dataset_cifar10 from src.generator_lr import get_lr from src.alexnet import AlexNet import mindspore.nn as nn @@ -43,7 +43,7 @@ if __name__ == "__main__": context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) - ds_train = create_dataset_mnist(args.data_path, cfg.batch_size, cfg.epoch_size) + ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, cfg.epoch_size) network = AlexNet(cfg.num_classes) loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size())) diff --git a/model_zoo/bert/README.md b/model_zoo/bert/README.md index 3ed2bf6783..45928da4e3 100644 --- a/model_zoo/bert/README.md +++ b/model_zoo/bert/README.md @@ -5,9 +5,9 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base]( ## Requirements - Install [MindSpore](https://www.mindspore.cn/install/en). - Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path. -- Download the CLUE/SQuAD v1.1 dataset for fine-tuning and evaluation. +- Download dataset for fine-tuning and evaluation such as CLUENER, TNEWS, SQuAD v1.1, etc. > Notes: - If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file. + If you are running a fine-tuning or evaluation task, prepare a checkpoint from pre-train. ## Running the Example ### Pre-Training @@ -24,31 +24,15 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base]( sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH ``` -### Fine-Tuning -- Set options in `finetune_config.py`. Make sure the 'data_file', 'schema_file' and 'pre_training_file' are set to your own path. Set the 'pre_training_ckpt' to a saved checkpoint file generated after pre-training. +### Fine-Tuning and Evaluation +- Set bert network config and optimizer hyperparameters in `finetune_eval_config.py`. -- Run `finetune.py` for fine-tuning of BERT-base and BERT-NEZHA model. +- Set task related hyperparameters in scripts/run_XXX.sh. - ```bash - python finetune.py - ``` - -### Evaluation -- Set options in `evaluation_config.py`. Make sure the 'data_file', 'schema_file' and 'finetune_ckpt' are set to your own path. - -- NER: Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model. - - ```bash - python evaluation.py - ``` -- SQuAD v1.1: Run `squadeval.py` and `SQuAD_postprocess.py` for evaluation of BERT-base and BERT-NEZHA model. - - ```bash - python squadeval.py - ``` +- Run `bash scripts/run_XXX.py` for fine-tuning of BERT-base and BERT-NEZHA model. ```bash - python SQuAD_postprocess.py + bash scripts/run_XXX.sh ``` ## Usage @@ -88,26 +72,56 @@ config.py: scale_window steps for once updatation of loss scale: N, default is 1000 optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum, default is "Lamb" -finetune_config.py: - task task type: SeqLabeling | Regression | Classification | COLA | SQUAD - num_labels number of labels to do classification - data_file dataset file to load: PATH, default is "/your/path/train.tfrecord" - schema_file dataset schema file to load: PATH, default is "/your/path/schema.json" - epoch_num repeat counts of training: N, default is 5 - ckpt_prefix prefix used to save checkpoint files: PREFIX, default is "bert" - ckpt_dir path to save checkpoint files: PATH, default is None - pre_training_ckpt checkpoint file to load: PATH, default is "/your/path/pre_training.ckpt" - use_crf whether to use crf for evaluation. use_crf takes effect only when task type is NER, default is False - optimizer optimizer used in fine-tune network: AdamWeigtDecayDynamicLR | Lamb | Momentum, default is "Lamb" - -evaluation_config.py: - task task type: SeqLabeling | Regression | Classification | COLA - num_labels number of labels to do classsification - data_file dataset file to load: PATH, default is "/your/path/evaluation.tfrecord" - schema_file dataset schema file to load: PATH, default is "/your/path/schema.json" - finetune_ckpt checkpoint file to load: PATH, default is "/your/path/your.ckpt" - use_crf whether to use crf for evaluation. use_crf takes effect only when task type is NER, default is False - clue_benchmark whether to use clue benchmark. clue_benchmark takes effect only when task type is NER, default is False +scripts/run_ner.sh: + device_target targeted device to run task: Ascend | GPU + do_train whether to run training on training set: true | false + do_eval whether to run eval on dev set: true | false + assessment_method assessment method to do evaluation: f1 | clue_benchmark + use_crf whether to use crf to calculate loss: true | false + device_id device id to run task + epoch_num total number of training epochs to perform + num_class number of classes to do labeling + vocab_file_path the vocabulary file that the BERT model was trained on + label2id_file_path label to id json file + save_finetune_checkpoint_path path to save generated finetuning checkpoint + load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) + load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval + train_data_file_path ner tfrecord for training. E.g., train.tfrecord + eval_data_file_path ner tfrecord for predictions if f1 is used to evaluate result, ner json for predictions if clue_benchmark is used to evaluate result + schema_file_path path to datafile schema file + +scripts/run_squad.sh: + device_target targeted device to run task: Ascend | GPU + do_train whether to run training on training set: true | false + do_eval whether to run eval on dev set: true | false + device_id device id to run task + epoch_num total number of training epochs to perform + num_class number of classes to classify, usually 2 for squad task + vocab_file_path the vocabulary file that the BERT model was trained on + eval_json_path path to squad dev json file + save_finetune_checkpoint_path path to save generated finetuning checkpoint + load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) + load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval + train_data_file_path squad tfrecord for training. E.g., train1.1.tfrecord + eval_data_file_path squad tfrecord for predictions. E.g., dev1.1.tfrecord + schema_file_path path to datafile schema file + +scripts/run_classifier.sh + device_target targeted device to run task: Ascend | GPU + do_train whether to run training on training set: true | false + do_eval whether to run eval on dev set: true | false + assessment_method assessment method to do evaluation: accuracy | f1 | mcc | spearman_correlation + device_id device id to run task + epoch_num total number of training epochs to perform + num_class number of classes to do labeling + save_finetune_checkpoint_path path to save generated finetuning checkpoint + load_pretrain_checkpoint_path initial checkpoint (usually from a pre-trained BERT model) + load_finetune_checkpoint_path give a finetuning checkpoint path if only do eval + train_data_file_path tfrecord for training. E.g., train.tfrecord + eval_data_file_path tfrecord for predictions. E.g., dev.tfrecord + schema_file_path path to datafile schema file + + ``` ### Parameters: @@ -115,7 +129,7 @@ evaluation_config.py: Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation): batch_size batch size of input dataset: N, default is 16 seq_length length of input sequence: N, default is 128 - vocab_size size of each embedding vector: N, default is 21136 + vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 21136 hidden_size size of bert encoder layers: N, default is 768 num_hidden_layers number of hidden layers: N, default is 12 num_attention_heads number of attention heads: N, default is 12 diff --git a/model_zoo/bert/evaluation.py b/model_zoo/bert/evaluation.py deleted file mode 100644 index 4e8b2a3aea..0000000000 --- a/model_zoo/bert/evaluation.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -Bert evaluation script. -""" - -import os -import argparse -import math -import numpy as np -import mindspore.common.dtype as mstype -from mindspore import context -from mindspore import log as logger -from mindspore.common.tensor import Tensor -import mindspore.dataset as de -import mindspore.dataset.transforms.c_transforms as C -from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.evaluation_config import cfg, bert_net_cfg -from src.utils import BertNER, BertCLS, BertReg -from src.CRF import postprocess -from src.cluener_evaluation import submit -from src.finetune_config import tag_to_index - - -class Accuracy(): - """ - calculate accuracy - """ - def __init__(self): - self.acc_num = 0 - self.total_num = 0 - - def update(self, logits, labels): - """ - Update accuracy - """ - labels = labels.asnumpy() - labels = np.reshape(labels, -1) - logits = logits.asnumpy() - logit_id = np.argmax(logits, axis=-1) - self.acc_num += np.sum(labels == logit_id) - self.total_num += len(labels) - print("=========================accuracy is ", self.acc_num / self.total_num) - - -class F1(): - """ - calculate F1 score - """ - def __init__(self): - self.TP = 0 - self.FP = 0 - self.FN = 0 - - def update(self, logits, labels): - """ - update F1 score - """ - labels = labels.asnumpy() - labels = np.reshape(labels, -1) - if cfg.use_crf: - backpointers, best_tag_id = logits - best_path = postprocess(backpointers, best_tag_id) - logit_id = [] - for ele in best_path: - logit_id.extend(ele) - else: - logits = logits.asnumpy() - logit_id = np.argmax(logits, axis=-1) - logit_id = np.reshape(logit_id, -1) - pos_eva = np.isin(logit_id, [i for i in range(1, cfg.num_labels)]) - pos_label = np.isin(labels, [i for i in range(1, cfg.num_labels)]) - self.TP += np.sum(pos_eva&pos_label) - self.FP += np.sum(pos_eva&(~pos_label)) - self.FN += np.sum((~pos_eva)&pos_label) - - -class MCC(): - """ - Calculate Matthews Correlation Coefficient. - """ - def __init__(self): - self.TP = 0 - self.FP = 0 - self.FN = 0 - self.TN = 0 - - def update(self, logits, labels): - """ - Update MCC score - """ - labels = labels.asnumpy() - labels = np.reshape(labels, -1) - labels = labels.astype(np.bool) - logits = logits.asnumpy() - logit_id = np.argmax(logits, axis=-1) - logit_id = np.reshape(logit_id, -1) - logit_id = logit_id.astype(np.bool) - ornot = logit_id ^ labels - - self.TP += (~ornot & labels).sum() - self.FP += (ornot & ~labels).sum() - self.FN += (ornot & labels).sum() - self.TN += (~ornot & ~labels).sum() - - -class Spearman_Correlation(): - """ - calculate Spearman Correlation coefficient - """ - def __init__(self): - self.label = [] - self.logit = [] - - def update(self, logits, labels): - """ - Update Spearman Correlation - """ - labels = labels.asnumpy() - labels = np.reshape(labels, -1) - logits = logits.asnumpy() - logits = np.reshape(logits, -1) - self.label.append(labels) - self.logit.append(logits) - - def cal(self): - """ - Calculate Spearman Correlation - """ - label = np.concatenate(self.label) - logit = np.concatenate(self.logit) - sort_label = label.argsort()[::-1] - sort_logit = logit.argsort()[::-1] - n = len(label) - d_acc = 0 - for i in range(n): - d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0] - d_acc += d**2 - ps = 1 - 6*d_acc/n/(n**2-1) - return ps - - -def get_dataset(batch_size=1, repeat_count=1, distribute_file=''): - """ - get dataset - """ - _ = distribute_file - - ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", - "segment_ids", "label_ids"]) - type_cast_op = C.TypeCast(mstype.int32) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - if cfg.task == "Regression": - type_cast_op_float = C.TypeCast(mstype.float32) - ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) - else: - ds = ds.map(input_columns="label_ids", operations=type_cast_op) - ds = ds.repeat(repeat_count) - - # apply shuffle operation - buffer_size = 960 - ds = ds.shuffle(buffer_size=buffer_size) - - # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) - return ds - - -def bert_predict(Evaluation): - """ - prediction function - """ - target = args_opt.device_target - if target == "Ascend": - devid = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) - elif target == "GPU": - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - if bert_net_cfg.compute_type != mstype.float32: - logger.warning('GPU only support fp32 temporarily, run with fp32.') - bert_net_cfg.compute_type = mstype.float32 - else: - raise Exception("Target error, GPU or Ascend is supported.") - dataset = get_dataset(bert_net_cfg.batch_size, 1) - if cfg.use_crf: - net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True, - tag_to_index=tag_to_index, dropout_prob=0.0) - else: - net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels) - net_for_pretraining.set_train(False) - param_dict = load_checkpoint(cfg.finetune_ckpt) - load_param_into_net(net_for_pretraining, param_dict) - model = Model(net_for_pretraining) - return model, dataset - -def test_eval(): - """ - evaluation function - """ - if cfg.task == "SeqLabeling": - task_type = BertNER - elif cfg.task == "Regression": - task_type = BertReg - elif cfg.task == "Classification": - task_type = BertCLS - elif cfg.task == "COLA": - task_type = BertCLS - else: - raise ValueError("Task not supported.") - model, dataset = bert_predict(task_type) - - if cfg.clue_benchmark: - submit(model, cfg.data_file, bert_net_cfg.seq_length) - else: - if cfg.task == "SeqLabeling": - callback = F1() - elif cfg.task == "COLA": - callback = MCC() - elif cfg.task == "Regression": - callback = Spearman_Correlation() - else: - callback = Accuracy() - - columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] - for data in dataset.create_dict_iterator(): - input_data = [] - for i in columns_list: - input_data.append(Tensor(data[i])) - input_ids, input_mask, token_type_id, label_ids = input_data - logits = model.predict(input_ids, input_mask, token_type_id, label_ids) - callback.update(logits, label_ids) - print("==============================================================") - if cfg.task == "SeqLabeling": - print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) - print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) - print("F1 {:.6f} ".format(2*callback.TP / (2*callback.TP + callback.FP + callback.FN))) - elif cfg.task == "COLA": - TP = callback.TP - TN = callback.TN - FP = callback.FP - FN = callback.FN - mcc = (TP*TN-FP*FN)/math.sqrt((TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)) - print("MCC: {:.6f}".format(mcc)) - elif cfg.task == "Regression": - print("Spearman Correlation is {:.6f}".format(callback.cal()[0])) - else: - print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, - callback.acc_num / callback.total_num)) - print("==============================================================") - -parser = argparse.ArgumentParser(description='Bert eval') -parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') -args_opt = parser.parse_args() -if __name__ == "__main__": - num_labels = cfg.num_labels - test_eval() diff --git a/model_zoo/bert/finetune.py b/model_zoo/bert/finetune.py deleted file mode 100644 index eb1880b9cc..0000000000 --- a/model_zoo/bert/finetune.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -Bert finetune script. -""" - -import os -import argparse -from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell, BertReg -from src.finetune_config import cfg, bert_net_cfg, tag_to_index -import mindspore.common.dtype as mstype -from mindspore import context -from mindspore import log as logger -import mindspore.dataset as de -import mindspore.dataset.transforms.c_transforms as C -from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum -from mindspore.train.model import Model -from mindspore.train.callback import Callback -from mindspore.train.callback import CheckpointConfig, ModelCheckpoint -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -class LossCallBack(Callback): - """ - Monitor the loss in training. - If the loss is NAN or INF, terminate training. - Note: - If per_print_times is 0, do not print loss. - Args: - per_print_times (int): Print loss every times. Default: 1. - """ - def __init__(self, per_print_times=1): - super(LossCallBack, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("print_step must be in and >= 0.") - self._per_print_times = per_print_times - - def step_end(self, run_context): - cb_params = run_context.original_args() - with open("./loss.log", "a+") as f: - f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, - str(cb_params.net_outputs))) - f.write("\n") - -def get_dataset(batch_size=1, repeat_count=1, distribute_file=''): - """ - get dataset - """ - ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", - "segment_ids", "label_ids"]) - type_cast_op = C.TypeCast(mstype.int32) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - if cfg.task == "Regression": - type_cast_op_float = C.TypeCast(mstype.float32) - ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) - else: - ds = ds.map(input_columns="label_ids", operations=type_cast_op) - ds = ds.repeat(repeat_count) - - # apply shuffle operation - buffer_size = 960 - ds = ds.shuffle(buffer_size=buffer_size) - - # apply batch operations - ds = ds.batch(batch_size, drop_remainder=True) - return ds - -def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''): - """ - get SQuAD dataset - """ - ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids", - "start_positions", "end_positions", - "unique_ids", "is_impossible"]) - type_cast_op = C.TypeCast(mstype.int32) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.map(input_columns="start_positions", operations=type_cast_op) - ds = ds.map(input_columns="end_positions", operations=type_cast_op) - ds = ds.repeat(repeat_count) - - buffer_size = 960 - ds = ds.shuffle(buffer_size=buffer_size) - ds = ds.batch(batch_size, drop_remainder=True) - return ds - -def test_train(): - """ - finetune function - """ - target = args_opt.device_target - if target == "Ascend": - devid = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) - elif target == "GPU": - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - if bert_net_cfg.compute_type != mstype.float32: - logger.warning('GPU only support fp32 temporarily, run with fp32.') - bert_net_cfg.compute_type = mstype.float32 - else: - raise Exception("Target error, GPU or Ascend is supported.") - #BertCLSTrain for classification - #BertNERTrain for sequence labeling - if cfg.task == 'SeqLabeling': - if cfg.use_crf: - netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True, - tag_to_index=tag_to_index, dropout_prob=0.1) - else: - netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1) - elif cfg.task == 'SQUAD': - netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) - elif cfg.task == 'Regression': - netwithloss = BertReg(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1) - elif cfg.task == 'Classification': - netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1) - else: - raise Exception("Target error, GPU or Ascend is supported.") - if cfg.task == 'SQUAD': - dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num) - else: - dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num) - # optimizer - steps_per_epoch = dataset.get_dataset_size() - if cfg.optimizer == 'AdamWeightDecayDynamicLR': - optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), - decay_steps=steps_per_epoch * cfg.epoch_num, - learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate, - end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate, - power=cfg.AdamWeightDecayDynamicLR.power, - warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1), - weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay, - eps=cfg.AdamWeightDecayDynamicLR.eps) - elif cfg.optimizer == 'Lamb': - optimizer = Lamb(netwithloss.trainable_params(), decay_steps=steps_per_epoch * cfg.epoch_num, - start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate, - power=cfg.Lamb.power, weight_decay=cfg.Lamb.weight_decay, - warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1), decay_filter=cfg.Lamb.decay_filter) - elif cfg.optimizer == 'Momentum': - optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate, - momentum=cfg.Momentum.momentum) - else: - raise Exception("Optimizer not supported.") - # load checkpoint into network - ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) - ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config) - param_dict = load_checkpoint(cfg.pre_training_ckpt) - load_param_into_net(netwithloss, param_dict) - - update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) - if cfg.task == 'SQUAD': - netwithgrads = BertSquadCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) - else: - netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell) - model = Model(netwithgrads) - model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb]) - - -parser = argparse.ArgumentParser(description='Bert finetune') -parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') -args_opt = parser.parse_args() -if __name__ == "__main__": - test_train() diff --git a/model_zoo/bert/run_classifier.py b/model_zoo/bert/run_classifier.py new file mode 100644 index 0000000000..4b2801f87c --- /dev/null +++ b/model_zoo/bert/run_classifier.py @@ -0,0 +1,201 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert finetune and evaluation script. +''' + +import os +import argparse +from src.bert_for_finetune import BertFinetuneCell, BertCLS +from src.finetune_eval_config import optimizer_cfg, bert_net_cfg +from src.dataset import create_classification_dataset +from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation +from src.utils import make_directory, LossCallBack, LoadNewestCkpt +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore import log as logger +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +_cur_dir = os.getcwd() + +def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): + """ do train """ + if load_checkpoint_path == "": + raise ValueError("Pretrain model missed, finetune task must load pretrain model!") + steps_per_epoch = dataset.get_dataset_size() + epoch_num = dataset.get_repeat_count() + # optimizer + if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': + optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), + decay_steps=steps_per_epoch * epoch_num, + learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, + power=optimizer_cfg.AdamWeightDecayDynamicLR.power, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, + eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) + elif optimizer_cfg.optimizer == 'Lamb': + optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, + start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_filter=optimizer_cfg.Lamb.decay_filter) + elif optimizer_cfg.optimizer == 'Momentum': + optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, + momentum=optimizer_cfg.Momentum.momentum) + else: + raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") + + # load checkpoint into network + ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="classifier", directory=save_checkpoint_path, config=ckpt_config) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(network, param_dict) + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) + netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] + model.train(epoch_num, dataset, callbacks=callbacks) + +def eval_result_print(assessment_method="accuracy", callback=None): + """ print eval result """ + if assessment_method == "accuracy": + print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, + callback.acc_num / callback.total_num)) + elif assessment_method == "f1": + print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) + print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) + print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) + elif assessment_method == "mcc": + print("MCC {:.6f} ".format(callback.cal())) + elif assessment_method == "spearman_correlation": + print("Spearman Correlation is {:.6f} ".format(callback.cal()[0])) + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") + +def do_eval(dataset=None, network=None, num_class=2, assessment_method="accuracy", load_checkpoint_path=""): + """ do eval """ + if load_checkpoint_path == "": + raise ValueError("Finetune model missed, evaluation task must load finetune model!") + net_for_pretraining = network(bert_net_cfg, False, num_class) + net_for_pretraining.set_train(False) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(net_for_pretraining, param_dict) + model = Model(net_for_pretraining) + + if assessment_method == "accuracy": + callback = Accuracy() + elif assessment_method == "f1": + callback = F1(False, num_class) + elif assessment_method == "mcc": + callback = MCC() + elif assessment_method == "spearman_correlation": + callback = Spearman_Correlation() + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") + + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + for data in dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, token_type_id, label_ids = input_data + logits = model.predict(input_ids, input_mask, token_type_id, label_ids) + callback.update(logits, label_ids) + print("==============================================================") + eval_result_print(assessment_method, callback) + print("==============================================================") + +def run_classifier(): + """run classifier task""" + parser = argparse.ArgumentParser(description="run classifier") + parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") + parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: " + "[MCC, Spearman_correlation, " + "Accuracy], default is accuracy") + parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") + parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") + parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") + parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--train_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--eval_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--schema_file_path", type=str, default="", + help="Schema path, it is better to use absolute path") + args_opt = parser.parse_args() + epoch_num = args_opt.epoch_num + assessment_method = args_opt.assessment_method.lower() + load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path + save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path + load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path + + if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": + raise ValueError("At least one of 'do_train' or 'do_eval' must be true") + if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": + raise ValueError("'train_data_file_path' must be set when do finetune task") + if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": + raise ValueError("'eval_data_file_path' must be set when do evaluation task") + + target = args_opt.device_target + if target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + elif target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + if bert_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + else: + raise Exception("Target error, GPU or Ascend is supported.") + + netwithloss = BertCLS(bert_net_cfg, True, num_labels=args_opt.num_class, dropout_prob=0.1, + assessment_method=assessment_method) + + if args_opt.do_train.lower() == "true": + ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, + assessment_method=assessment_method, + data_file_path=args_opt.train_data_file_path, + schema_file_path=args_opt.schema_file_path) + do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) + + if args_opt.do_eval.lower() == "true": + if save_finetune_checkpoint_path == "": + load_finetune_checkpoint_dir = _cur_dir + else: + load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) + load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, + ds.get_dataset_size(), epoch_num, "classifier") + + if args_opt.do_eval.lower() == "true": + ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, + assessment_method=assessment_method, + data_file_path=args_opt.eval_data_file_path, + schema_file_path=args_opt.schema_file_path) + do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path) + +if __name__ == "__main__": + run_classifier() diff --git a/model_zoo/bert/run_ner.py b/model_zoo/bert/run_ner.py new file mode 100644 index 0000000000..a61c96066e --- /dev/null +++ b/model_zoo/bert/run_ner.py @@ -0,0 +1,228 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert finetune and evaluation script. +''' + +import os +import json +import argparse +from src.bert_for_finetune import BertFinetuneCell, BertNER +from src.finetune_eval_config import optimizer_cfg, bert_net_cfg +from src.dataset import create_ner_dataset +from src.utils import make_directory, LossCallBack, LoadNewestCkpt +from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore import log as logger +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +_cur_dir = os.getcwd() + + +def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): + """ do train """ + if load_checkpoint_path == "": + raise ValueError("Pretrain model missed, finetune task must load pretrain model!") + steps_per_epoch = dataset.get_dataset_size() + epoch_num = dataset.get_repeat_count() + # optimizer + if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': + optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), + decay_steps=steps_per_epoch * epoch_num, + learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, + power=optimizer_cfg.AdamWeightDecayDynamicLR.power, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, + eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) + elif optimizer_cfg.optimizer == 'Lamb': + optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, + start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_filter=optimizer_cfg.Lamb.decay_filter) + elif optimizer_cfg.optimizer == 'Momentum': + optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, + momentum=optimizer_cfg.Momentum.momentum) + else: + raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") + + # load checkpoint into network + ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="ner", directory=save_checkpoint_path, config=ckpt_config) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(network, param_dict) + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) + netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] + model.train(epoch_num, dataset, callbacks=callbacks) + +def eval_result_print(assessment_method="accuracy", callback=None): + """print eval result""" + if assessment_method == "accuracy": + print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num, + callback.acc_num / callback.total_num)) + elif assessment_method == "f1": + print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP))) + print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN))) + print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN))) + elif assessment_method == "mcc": + print("MCC {:.6f} ".format(callback.cal())) + elif assessment_method == "spearman_correlation": + print("Spearman Correlation is {:.6f} ".format(callback.cal()[0])) + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") + +def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="", + load_checkpoint_path="", vocab_file="", label2id_file="", tag_to_index=None): + """ do eval """ + if load_checkpoint_path == "": + raise ValueError("Finetune model missed, evaluation task must load finetune model!") + if assessment_method == "clue_benchmark": + bert_net_cfg.batch_size = 1 + net_for_pretraining = network(bert_net_cfg, False, num_class, use_crf=(use_crf.lower() == "true"), + tag_to_index=tag_to_index) + net_for_pretraining.set_train(False) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(net_for_pretraining, param_dict) + model = Model(net_for_pretraining) + + if assessment_method == "clue_benchmark": + from src.cluener_evaluation import submit + submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf, label2id_file=label2id_file) + else: + if assessment_method == "accuracy": + callback = Accuracy() + elif assessment_method == "f1": + callback = F1((use_crf.lower() == "true"), num_class) + elif assessment_method == "mcc": + callback = MCC() + elif assessment_method == "spearman_correlation": + callback = Spearman_Correlation() + else: + raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]") + + columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"] + for data in dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, token_type_id, label_ids = input_data + logits = model.predict(input_ids, input_mask, token_type_id, label_ids) + callback.update(logits, label_ids) + print("==============================================================") + eval_result_print(assessment_method, callback) + print("==============================================================") + +def run_ner(): + """run ner task""" + parser = argparse.ArgumentParser(description="run classifier") + parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") + parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: " + "[F1, clue_benchmark], default is F1") + parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") + parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") + parser.add_argument("--use_crf", type=str, default="false", help="Use crf, default is false") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") + parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") + parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark") + parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark") + parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--train_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--eval_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--schema_file_path", type=str, default="", + help="Schema path, it is better to use absolute path") + args_opt = parser.parse_args() + epoch_num = args_opt.epoch_num + assessment_method = args_opt.assessment_method.lower() + load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path + save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path + load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path + + if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": + raise ValueError("At least one of 'do_train' or 'do_eval' must be true") + if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": + raise ValueError("'train_data_file_path' must be set when do finetune task") + if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "": + raise ValueError("'eval_data_file_path' must be set when do evaluation task") + if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.vocab_file_path == "": + raise ValueError("'vocab_file_path' must be set to do clue benchmark") + if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "": + raise ValueError("'label2id_file_path' must be set to use crf") + if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "": + raise ValueError("'label2id_file_path' must be set to do clue benchmark") + + target = args_opt.device_target + if target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + elif target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + if bert_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + else: + raise Exception("Target error, GPU or Ascend is supported.") + + tag_to_index = None + if args_opt.use_crf.lower() == "true": + with open(args_opt.label2id_file_path) as json_file: + tag_to_index = json.load(json_file) + max_val = max(tag_to_index.values()) + tag_to_index[""] = max_val + 1 + tag_to_index[""] = max_val + 2 + number_labels = len(tag_to_index) + else: + number_labels = args_opt.num_class + netwithloss = BertNER(bert_net_cfg, True, num_labels=number_labels, + use_crf=(args_opt.use_crf.lower() == "true"), + tag_to_index=tag_to_index, dropout_prob=0.1) + if args_opt.do_train.lower() == "true": + ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, + assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path, + schema_file_path=args_opt.schema_file_path) + do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) + + if args_opt.do_eval.lower() == "true": + if save_finetune_checkpoint_path == "": + load_finetune_checkpoint_dir = _cur_dir + else: + load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) + load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, + ds.get_dataset_size(), epoch_num, "ner") + + if args_opt.do_eval.lower() == "true": + ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, + assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path, + schema_file_path=args_opt.schema_file_path) + do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path, + load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index) + +if __name__ == "__main__": + run_ner() diff --git a/model_zoo/bert/run_pretrain.py b/model_zoo/bert/run_pretrain.py index 65768946c1..7123c942f3 100644 --- a/model_zoo/bert/run_pretrain.py +++ b/model_zoo/bert/run_pretrain.py @@ -26,33 +26,16 @@ from mindspore import context from mindspore.train.model import Model from mindspore.train.parallel_utils import ParallelMode from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell -from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR from mindspore import log as logger from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell from src.dataset import create_bert_dataset from src.config import cfg, bert_net_cfg +from src.utils import LossCallBack _current_dir = os.path.dirname(os.path.realpath(__file__)) -class LossCallBack(Callback): - """ - Monitor the loss in training. - If the loss in NAN or INF terminating training. - Note: - if per_print_times is 0 do not print loss. - Args: - per_print_times (int): Print loss every times. Default: 1. - """ - def __init__(self, per_print_times=1): - super(LossCallBack, self).__init__() - if not isinstance(per_print_times, int) or per_print_times < 0: - raise ValueError("print_step must be int and >= 0") - self._per_print_times = per_print_times - def step_end(self, run_context): - cb_params = run_context.original_args() - print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, - str(cb_params.net_outputs))) def run_pretrain(): """pre-train bert_clue""" diff --git a/model_zoo/bert/run_squad.py b/model_zoo/bert/run_squad.py new file mode 100644 index 0000000000..083cedac1d --- /dev/null +++ b/model_zoo/bert/run_squad.py @@ -0,0 +1,204 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert finetune and evaluation script. +''' +import os +import argparse +import collections +from src.bert_for_finetune import BertSquadCell, BertSquad +from src.finetune_eval_config import optimizer_cfg, bert_net_cfg +from src.dataset import create_squad_dataset +from src import tokenization +from src.create_squad_data import read_squad_examples, convert_examples_to_features +from src.run_squad import write_predictions +from src.utils import make_directory, LossCallBack, LoadNewestCkpt +import mindspore.common.dtype as mstype +from mindspore import context +from mindspore import log as logger +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +_cur_dir = os.getcwd() + +def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): + """ do train """ + if load_checkpoint_path == "": + raise ValueError("Pretrain model missed, finetune task must load pretrain model!") + steps_per_epoch = dataset.get_dataset_size() + epoch_num = dataset.get_repeat_count() + # optimizer + if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': + optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), + decay_steps=steps_per_epoch * epoch_num, + learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, + end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, + power=optimizer_cfg.AdamWeightDecayDynamicLR.power, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, + eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) + elif optimizer_cfg.optimizer == 'Lamb': + optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, + start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, + end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, + power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, + warmup_steps=int(steps_per_epoch * epoch_num * 0.1), + decay_filter=optimizer_cfg.Lamb.decay_filter) + elif optimizer_cfg.optimizer == 'Momentum': + optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, + momentum=optimizer_cfg.Momentum.momentum) + else: + raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") + + # load checkpoint into network + ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) + ckpoint_cb = ModelCheckpoint(prefix="squad", directory=save_checkpoint_path, config=ckpt_config) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(network, param_dict) + + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) + netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell) + model = Model(netwithgrads) + callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] + model.train(epoch_num, dataset, callbacks=callbacks) + + +def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="", seq_length=384): + """ do eval """ + if load_checkpoint_path == "": + raise ValueError("Finetune model missed, evaluation task must load finetune model!") + tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True) + eval_examples = read_squad_examples(eval_json, False) + eval_features = convert_examples_to_features( + examples=eval_examples, + tokenizer=tokenizer, + max_seq_length=seq_length, + doc_stride=128, + max_query_length=64, + is_training=False, + output_fn=None, + verbose_logging=False) + + net = BertSquad(bert_net_cfg, False, 2) + net.set_train(False) + param_dict = load_checkpoint(load_checkpoint_path) + load_param_into_net(net, param_dict) + model = Model(net) + output = [] + RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) + columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"] + for data in dataset.create_dict_iterator(): + input_data = [] + for i in columns_list: + input_data.append(Tensor(data[i])) + input_ids, input_mask, segment_ids, unique_ids = input_data + start_positions = Tensor([1], mstype.float32) + end_positions = Tensor([1], mstype.float32) + is_impossible = Tensor([1], mstype.float32) + logits = model.predict(input_ids, input_mask, segment_ids, start_positions, + end_positions, unique_ids, is_impossible) + ids = logits[0].asnumpy() + start = logits[1].asnumpy() + end = logits[2].asnumpy() + + for i in range(bert_net_cfg.batch_size): + unique_id = int(ids[i]) + start_logits = [float(x) for x in start[i].flat] + end_logits = [float(x) for x in end[i].flat] + output.append(RawResult( + unique_id=unique_id, + start_logits=start_logits, + end_logits=end_logits)) + write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json", None, None) + +def run_squad(): + """run squad task""" + parser = argparse.ArgumentParser(description="run classifier") + parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") + parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") + parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") + parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") + parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") + parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") + parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path") + parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json") + parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") + parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") + parser.add_argument("--train_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--eval_data_file_path", type=str, default="", + help="Data path, it is better to use absolute path") + parser.add_argument("--schema_file_path", type=str, default="", + help="Schema path, it is better to use absolute path") + args_opt = parser.parse_args() + epoch_num = args_opt.epoch_num + load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path + save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path + load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path + + if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": + raise ValueError("At least one of 'do_train' or 'do_eval' must be true") + if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": + raise ValueError("'train_data_file_path' must be set when do finetune task") + if args_opt.do_eval.lower() == "true": + if args_opt.eval_data_file_path == "": + raise ValueError("'eval_data_file_path' must be set when do evaluation task") + if args_opt.vocab_file_path == "": + raise ValueError("'vocab_file_path' must be set when do evaluation task") + if args_opt.eval_json_path == "": + raise ValueError("'tokenization_file_path' must be set when do evaluation task") + + + target = args_opt.device_target + if target == "Ascend": + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + elif target == "GPU": + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + if bert_net_cfg.compute_type != mstype.float32: + logger.warning('GPU only support fp32 temporarily, run with fp32.') + bert_net_cfg.compute_type = mstype.float32 + else: + raise Exception("Target error, GPU or Ascend is supported.") + + netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) + + if args_opt.do_train.lower() == "true": + ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, + data_file_path=args_opt.train_data_file_path, + schema_file_path=args_opt.schema_file_path) + do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) + if args_opt.do_eval.lower() == "true": + if save_finetune_checkpoint_path == "": + load_finetune_checkpoint_dir = _cur_dir + else: + load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) + load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, + ds.get_dataset_size(), epoch_num, "squad") + + if args_opt.do_eval.lower() == "true": + ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, + data_file_path=args_opt.eval_data_file_path, + schema_file_path=args_opt.schema_file_path, is_training=False) + do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path, + load_finetune_checkpoint_path, bert_net_cfg.seq_length) + +if __name__ == "__main__": + run_squad() diff --git a/model_zoo/bert/scripts/run_classifier.sh b/model_zoo/bert/scripts/run_classifier.sh new file mode 100644 index 0000000000..275324b950 --- /dev/null +++ b/model_zoo/bert/scripts/run_classifier.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_classifier.sh" +echo "for example: bash scripts/run_classifier.sh" +echo "assessment_method include: [MCC, Spearman_correlation ,Accuracy]" +echo "==============================================================================================================" + +mkdir -p ms_log +CUR_DIR=`pwd` +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_classifier.py \ + --device_target="Ascend" \ + --do_train="true" \ + --do_eval="false" \ + --assessment_method="Accuracy" \ + --device_id=0 \ + --epoch_num=1 \ + --num_class=2 \ + --save_finetune_checkpoint_path="" \ + --load_pretrain_checkpoint_path="" \ + --load_finetune_checkpoint_path="" \ + --train_data_file_path="" \ + --eval_data_file_path="" \ + --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/bert/scripts/run_distribute_pretrain.sh b/model_zoo/bert/scripts/run_distribute_pretrain.sh index 5a9f8735aa..eb3a0979d1 100644 --- a/model_zoo/bert/scripts/run_distribute_pretrain.sh +++ b/model_zoo/bert/scripts/run_distribute_pretrain.sh @@ -24,8 +24,7 @@ echo "========================================================================== EPOCH_SIZE=$2 DATA_DIR=$3 SCHEMA_DIR=$4 - -export MINDSPORE_HCCL_CONFIG_PATH=$5 +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) export RANK_TABLE_FILE=$5 export RANK_SIZE=$1 cores=`cat /proc/cpuinfo|grep "processor" |wc -l` @@ -54,7 +53,7 @@ do export GLOG_log_dir=${CUR_DIR}/ms_log export GLOG_logtostderr=0 env > env.log - taskset -c $cmdopt python ../run_pretrain.py \ + taskset -c $cmdopt python ${PROJECT_DIR}/../run_pretrain.py \ --distribute="true" \ --epoch_size=$EPOCH_SIZE \ --device_id=$DEVICE_ID \ diff --git a/model_zoo/bert/scripts/run_ner.sh b/model_zoo/bert/scripts/run_ner.sh new file mode 100644 index 0000000000..ae401b2462 --- /dev/null +++ b/model_zoo/bert/scripts/run_ner.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_ner.sh" +echo "for example: bash scripts/run_ner.sh" +echo "assessment_method include: [F1, clue_benchmark]" +echo "==============================================================================================================" + +mkdir -p ms_log +CUR_DIR=`pwd` +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_ner.py \ + --device_target="Ascend" \ + --do_train="true" \ + --do_eval="false" \ + --assessment_method="F1" \ + --use_crf="false" \ + --device_id=0 \ + --epoch_num=1 \ + --num_class=2 \ + --vocab_file_path="" \ + --label2id_file_path="" \ + --save_finetune_checkpoint_path="" \ + --load_pretrain_checkpoint_path="" \ + --load_finetune_checkpoint_path="" \ + --train_data_file_path="" \ + --eval_data_file_path="" \ + --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/bert/scripts/run_squad.sh b/model_zoo/bert/scripts/run_squad.sh new file mode 100644 index 0000000000..a33950cadb --- /dev/null +++ b/model_zoo/bert/scripts/run_squad.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the scipt as: " +echo "bash scripts/run_squad.sh" +echo "for example: bash scripts/run_squad.sh" +echo "assessment_method include: [Accuracy]" +echo "==============================================================================================================" + +mkdir -p ms_log +CUR_DIR=`pwd` +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +python ${PROJECT_DIR}/../run_squad.py \ + --device_target="Ascend" \ + --do_train="true" \ + --do_eval="false" \ + --device_id=0 \ + --epoch_num=1 \ + --num_class=2 \ + --vocab_file_path="" \ + --eval_json_path="" \ + --save_finetune_checkpoint_path="" \ + --load_pretrain_checkpoint_path="" \ + --load_finetune_checkpoint_path="" \ + --train_data_file_path="" \ + --eval_data_file_path="" \ + --schema_file_path="" > log.txt 2>&1 & diff --git a/model_zoo/bert/scripts/run_standalone_pretrain.sh b/model_zoo/bert/scripts/run_standalone_pretrain.sh index 3cd9545f7f..f59eb69601 100644 --- a/model_zoo/bert/scripts/run_standalone_pretrain.sh +++ b/model_zoo/bert/scripts/run_standalone_pretrain.sh @@ -26,10 +26,11 @@ DATA_DIR=$3 SCHEMA_DIR=$4 mkdir -p ms_log +PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) CUR_DIR=`pwd` export GLOG_log_dir=${CUR_DIR}/ms_log export GLOG_logtostderr=0 -python run_pretrain.py \ +python ${PROJECT_DIR}/../run_pretrain.py \ --distribute="false" \ --epoch_size=$EPOCH_SIZE \ --device_id=$DEVICE_ID \ diff --git a/model_zoo/bert/squadeval.py b/model_zoo/bert/squadeval.py deleted file mode 100644 index 49027acd6d..0000000000 --- a/model_zoo/bert/squadeval.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -"""Evaluation script for SQuAD task""" - -import os -import collections -import mindspore.dataset as de -import mindspore.dataset.transforms.c_transforms as C -import mindspore.common.dtype as mstype -from mindspore import context -from mindspore.common.tensor import Tensor -from mindspore.train.model import Model -from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src import tokenization -from src.evaluation_config import cfg, bert_net_cfg -from src.utils import BertSquad -from src.create_squad_data import read_squad_examples, convert_examples_to_features -from src.run_squad import write_predictions - -def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''): - """get SQuAD dataset from tfrecord""" - ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", - "segment_ids", "unique_ids"], - shuffle=False) - type_cast_op = C.TypeCast(mstype.int32) - ds = ds.map(input_columns="segment_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_ids", operations=type_cast_op) - ds = ds.map(input_columns="input_mask", operations=type_cast_op) - ds = ds.repeat(repeat_count) - ds = ds.batch(batch_size, drop_remainder=True) - return ds - -def test_eval(): - """Evaluation function for SQuAD task""" - tokenizer = tokenization.FullTokenizer(vocab_file="./vocab.txt", do_lower_case=True) - input_file = "dataset/v1.1/dev-v1.1.json" - eval_examples = read_squad_examples(input_file, False) - eval_features = convert_examples_to_features( - examples=eval_examples, - tokenizer=tokenizer, - max_seq_length=384, - doc_stride=128, - max_query_length=64, - is_training=False, - output_fn=None, - verbose_logging=False) - - device_id = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id) - dataset = get_squad_dataset(bert_net_cfg.batch_size, 1) - net = BertSquad(bert_net_cfg, False, 2) - net.set_train(False) - param_dict = load_checkpoint(cfg.finetune_ckpt) - load_param_into_net(net, param_dict) - model = Model(net) - output = [] - RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) - columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"] - for data in dataset.create_dict_iterator(): - input_data = [] - for i in columns_list: - input_data.append(Tensor(data[i])) - input_ids, input_mask, segment_ids, unique_ids = input_data - start_positions = Tensor([1], mstype.float32) - end_positions = Tensor([1], mstype.float32) - is_impossible = Tensor([1], mstype.float32) - logits = model.predict(input_ids, input_mask, segment_ids, start_positions, - end_positions, unique_ids, is_impossible) - ids = logits[0].asnumpy() - start = logits[1].asnumpy() - end = logits[2].asnumpy() - - for i in range(bert_net_cfg.batch_size): - unique_id = int(ids[i]) - start_logits = [float(x) for x in start[i].flat] - end_logits = [float(x) for x in end[i].flat] - output.append(RawResult( - unique_id=unique_id, - start_logits=start_logits, - end_logits=end_logits)) - write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json", - None, None, False, False) - - -if __name__ == "__main__": - test_eval() diff --git a/model_zoo/bert/src/assessment_method.py b/model_zoo/bert/src/assessment_method.py new file mode 100644 index 0000000000..ca6579cabf --- /dev/null +++ b/model_zoo/bert/src/assessment_method.py @@ -0,0 +1,134 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert evaluation assessment method script. +''' +import math +import numpy as np +from .CRF import postprocess + +class Accuracy(): + ''' + calculate accuracy + ''' + def __init__(self): + self.acc_num = 0 + self.total_num = 0 + def update(self, logits, labels): + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + self.acc_num += np.sum(labels == logit_id) + self.total_num += len(labels) + print("=========================accuracy is ", self.acc_num / self.total_num) + +class F1(): + ''' + calculate F1 score + ''' + def __init__(self, use_crf=False, num_labels=2): + self.TP = 0 + self.FP = 0 + self.FN = 0 + self.use_crf = use_crf + self.num_labels = num_labels + + def update(self, logits, labels): + ''' + update F1 score + ''' + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + if self.use_crf: + backpointers, best_tag_id = logits + best_path = postprocess(backpointers, best_tag_id) + logit_id = [] + for ele in best_path: + logit_id.extend(ele) + else: + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + logit_id = np.reshape(logit_id, -1) + pos_eva = np.isin(logit_id, [i for i in range(1, self.num_labels)]) + pos_label = np.isin(labels, [i for i in range(1, self.num_labels)]) + self.TP += np.sum(pos_eva&pos_label) + self.FP += np.sum(pos_eva&(~pos_label)) + self.FN += np.sum((~pos_eva)&pos_label) + +class MCC(): + ''' + Calculate Matthews Correlation Coefficient + ''' + def __init__(self): + self.TP = 0 + self.FP = 0 + self.FN = 0 + self.TN = 0 + def update(self, logits, labels): + ''' + MCC update + ''' + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + labels = labels.astype(np.bool) + logits = logits.asnumpy() + logit_id = np.argmax(logits, axis=-1) + logit_id = np.reshape(logit_id, -1) + logit_id = logit_id.astype(np.bool) + ornot = logit_id ^ labels + + self.TP += (~ornot & labels).sum() + self.FP += (ornot & ~labels).sum() + self.FN += (ornot & labels).sum() + self.TN += (~ornot & ~labels).sum() + + def cal(self): + mcc = (self.TP*self.TN - self.FP*self.FN)/math.sqrt((self.TP+self.FP)*(self.TP+self.FN) * + (self.TN+self.FP)*(self.TN+self.FN)) + return mcc + +class Spearman_Correlation(): + ''' + Calculate Spearman Correlation Coefficient + ''' + def __init__(self): + self.label = [] + self.logit = [] + + def update(self, logits, labels): + labels = labels.asnumpy() + labels = np.reshape(labels, -1) + logits = logits.asnumpy() + logits = np.reshape(logits, -1) + self.label.append(labels) + self.logit.append(logits) + + def cal(self): + ''' + Calculate Spearman Correlation + ''' + label = np.concatenate(self.label) + logit = np.concatenate(self.logit) + sort_label = label.argsort()[::-1] + sort_logit = logit.argsort()[::-1] + n = len(label) + d_acc = 0 + for i in range(n): + d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0] + d_acc += d**2 + ps = 1 - 6*d_acc/n/(n**2-1) + return ps diff --git a/model_zoo/bert/src/bert_for_finetune.py b/model_zoo/bert/src/bert_for_finetune.py new file mode 100644 index 0000000000..32ac0823b9 --- /dev/null +++ b/model_zoo/bert/src/bert_for_finetune.py @@ -0,0 +1,327 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert for finetune script. +''' + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.common import dtype as mstype +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.train.parallel_utils import ParallelMode +from mindspore.communication.management import get_group_size +from mindspore import context +from .bert_for_pre_training import clip_grad +from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel +from .utils import CrossEntropyCalculation + + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 1.0 +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + +_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") +grad_overflow = P.FloatStatus() +@_grad_overflow.register("Tensor") +def _tensor_grad_overflow(grad): + return grad_overflow(grad) + +class BertFinetuneCell(nn.Cell): + """ + Especifically defined for finetuning where only four inputs tensor are needed. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + + super(BertFinetuneCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', + get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.gpu_target = False + if context.get_context("device_target") == "GPU": + self.gpu_target = True + self.float_status = P.FloatStatus() + self.addn = P.AddN() + self.reshape = P.Reshape() + else: + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + + def construct(self, + input_ids, + input_mask, + token_type_id, + label_ids, + sens=None): + + + weights = self.weights + init = False + loss = self.network(input_ids, + input_mask, + token_type_id, + label_ids) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + + if not self.gpu_target: + init = self.alloc_status() + clear_before_grad = self.clear_before_grad(init) + F.control_depend(loss, init) + self.depend_parameter_use(clear_before_grad, scaling_sens) + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + label_ids, + self.cast(scaling_sens, + mstype.float32)) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + grads = self.grad_reducer(grads) + if not self.gpu_target: + flag = self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + F.control_depend(grads, flag) + F.control_depend(flag, flag_sum) + else: + flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) + flag_sum = self.addn(flag_sum) + flag_sum = self.reshape(flag_sum, (())) + if self.is_distributed: + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond) + return F.depend(ret, succ) + +class BertSquadCell(nn.Cell): + """ + specifically defined for finetuning where only four inputs tensor are needed. + """ + def __init__(self, network, optimizer, scale_update_cell=None): + super(BertSquadCell, self).__init__(auto_prefix=False) + self.network = network + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = None + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + def construct(self, + input_ids, + input_mask, + token_type_id, + start_position, + end_position, + unique_id, + is_impossible, + sens=None): + weights = self.weights + init = self.alloc_status() + loss = self.network(input_ids, + input_mask, + token_type_id, + start_position, + end_position, + unique_id, + is_impossible) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + grads = self.grad(self.network, weights)(input_ids, + input_mask, + token_type_id, + start_position, + end_position, + unique_id, + is_impossible, + self.cast(scaling_sens, + mstype.float32)) + clear_before_grad = self.clear_before_grad(init) + F.control_depend(loss, init) + self.depend_parameter_use(clear_before_grad, scaling_sens) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + if self.reducer_flag: + grads = self.grad_reducer(grads) + flag = self.get_status(init) + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + F.control_depend(grads, flag) + F.control_depend(flag, flag_sum) + overflow = cond + if sens is None: + overflow = self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, cond) + return F.depend(ret, succ) + +class BertCLS(nn.Cell): + """ + Train interface for classification finetuning task. + """ + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False, + assessment_method=""): + super(BertCLS, self).__init__() + self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings, + assessment_method) + self.loss = CrossEntropyCalculation(is_training) + self.num_labels = num_labels + self.assessment_method = assessment_method + self.is_training = is_training + def construct(self, input_ids, input_mask, token_type_id, label_ids): + logits = self.bert(input_ids, input_mask, token_type_id) + if self.assessment_method == "spearman_correlation": + if self.is_training: + loss = self.loss(logits, label_ids) + else: + loss = logits + else: + loss = self.loss(logits, label_ids, self.num_labels) + return loss + + +class BertNER(nn.Cell): + """ + Train interface for sequence labeling finetuning task. + """ + def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0, + use_one_hot_embeddings=False): + super(BertNER, self).__init__() + self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings) + if use_crf: + if not tag_to_index: + raise Exception("The dict for tag-index mapping should be provided for CRF.") + from src.CRF import CRF + self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training) + else: + self.loss = CrossEntropyCalculation(is_training) + self.num_labels = num_labels + self.use_crf = use_crf + def construct(self, input_ids, input_mask, token_type_id, label_ids): + logits = self.bert(input_ids, input_mask, token_type_id) + if self.use_crf: + loss = self.loss(logits, label_ids) + else: + loss = self.loss(logits, label_ids, self.num_labels) + return loss + +class BertSquad(nn.Cell): + ''' + Train interface for SQuAD finetuning task. + ''' + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): + super(BertSquad, self).__init__() + self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) + self.loss = CrossEntropyCalculation(is_training) + self.num_labels = num_labels + self.seq_length = config.seq_length + self.is_training = is_training + self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num') + self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num') + self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num') + self.sum = P.ReduceSum() + self.equal = P.Equal() + self.argmax = P.ArgMaxWithValue(axis=1) + self.squeeze = P.Squeeze(axis=-1) + + def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible): + logits = self.bert(input_ids, input_mask, token_type_id) + if self.is_training: + unstacked_logits_0 = self.squeeze(logits[:, :, 0:1]) + unstacked_logits_1 = self.squeeze(logits[:, :, 1:2]) + start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length) + end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length) + total_loss = (start_loss + end_loss) / 2.0 + else: + start_logits = self.squeeze(logits[:, :, 0:1]) + end_logits = self.squeeze(logits[:, :, 1:2]) + total_loss = (unique_id, start_logits, end_logits) + return total_loss diff --git a/model_zoo/bert/src/clue_classification_dataset_process.py b/model_zoo/bert/src/clue_classification_dataset_process.py new file mode 100755 index 0000000000..1e27fe0352 --- /dev/null +++ b/model_zoo/bert/src/clue_classification_dataset_process.py @@ -0,0 +1,153 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +sample script of processing CLUE classification dataset using mindspore.dataset.text for fine-tuning bert +""" + +import os +import numpy as np + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.text as text +import mindspore.dataset.transforms.c_transforms as ops + + +def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, + data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64): + """Process TNEWS dataset""" + ### Loading TNEWS from CLUEDataset + assert data_usage in ['train', 'eval', 'test'] + if data_usage == 'train': + dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='TNEWS', + usage=data_usage, shuffle=shuffle_dataset) + elif data_usage == 'eval': + dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='TNEWS', + usage=data_usage, shuffle=shuffle_dataset) + else: + dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='TNEWS', + usage=data_usage, shuffle=shuffle_dataset) + ### Processing label + if data_usage == 'test': + dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"], + columns_order=["id", "label_id", "sentence"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0)) + else: + label_vocab = text.Vocab.from_list(label_list) + label_lookup = text.Lookup(label_vocab) + dataset = dataset.map(input_columns="label_desc", output_columns="label_id", operations=label_lookup) + ### Processing sentence + vocab = text.Vocab.from_file(bert_vocab_path) + tokenizer = text.BertTokenizer(vocab, lower_case=True) + lookup = text.Lookup(vocab, unknown_token='[UNK]') + dataset = dataset.map(input_columns=["sentence"], operations=tokenizer) + dataset = dataset.map(input_columns=["sentence"], operations=ops.Slice(slice(0, max_seq_len))) + dataset = dataset.map(input_columns=["sentence"], + operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'), + append=np.array(["[SEP]"], dtype='S'))) + dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup) + dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0)) + dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"], + columns_order=["label_id", "text_ids", "mask_ids"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32)) + dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"], + columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0)) + dataset = dataset.batch(batch_size) + label = [] + text_ids = [] + mask_ids = [] + segment_ids = [] + for data in dataset: + label.append(data[0]) + text_ids.append(data[1]) + mask_ids.append(data[2]) + segment_ids.append(data[3]) + return label, text_ids, mask_ids, segment_ids + + +def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, + data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64): + """Process CMNLI dataset""" + ### Loading CMNLI from CLUEDataset + assert data_usage in ['train', 'eval', 'test'] + if data_usage == 'train': + dataset = ds.CLUEDataset(os.path.join(data_dir, "train.json"), task='CMNLI', + usage=data_usage, shuffle=shuffle_dataset) + elif data_usage == 'eval': + dataset = ds.CLUEDataset(os.path.join(data_dir, "dev.json"), task='CMNLI', + usage=data_usage, shuffle=shuffle_dataset) + else: + dataset = ds.CLUEDataset(os.path.join(data_dir, "test.json"), task='CMNLI', + usage=data_usage, shuffle=shuffle_dataset) + ### Processing label + if data_usage == 'test': + dataset = dataset.map(input_columns=["id"], output_columns=["id", "label_id"], + columns_order=["id", "label_id", "sentence1", "sentence2"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["label_id"], operations=ops.Fill(0)) + else: + label_vocab = text.Vocab.from_list(label_list) + label_lookup = text.Lookup(label_vocab) + dataset = dataset.map(input_columns="label", output_columns="label_id", operations=label_lookup) + ### Processing sentence pairs + vocab = text.Vocab.from_file(bert_vocab_path) + tokenizer = text.BertTokenizer(vocab, lower_case=True) + lookup = text.Lookup(vocab, unknown_token='[UNK]') + ### Tokenizing sentences and truncate sequence pair + dataset = dataset.map(input_columns=["sentence1"], operations=tokenizer) + dataset = dataset.map(input_columns=["sentence2"], operations=tokenizer) + dataset = dataset.map(input_columns=["sentence1", "sentence2"], + operations=text.TruncateSequencePair(max_seq_len-3)) + ### Adding special tokens + dataset = dataset.map(input_columns=["sentence1"], + operations=ops.Concatenate(prepend=np.array(["[CLS]"], dtype='S'), + append=np.array(["[SEP]"], dtype='S'))) + dataset = dataset.map(input_columns=["sentence2"], + operations=ops.Concatenate(append=np.array(["[SEP]"], dtype='S'))) + ### Generating segment_ids + dataset = dataset.map(input_columns=["sentence1"], output_columns=["sentence1", "type_sentence1"], + columns_order=["sentence1", "type_sentence1", "sentence2", "label_id"], + operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["sentence2"], output_columns=["sentence2", "type_sentence2"], + columns_order=["sentence1", "type_sentence1", "sentence2", "type_sentence2", "label_id"], + operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["type_sentence1"], operations=[lookup, ops.Fill(0)]) + dataset = dataset.map(input_columns=["type_sentence2"], operations=[lookup, ops.Fill(1)]) + dataset = dataset.map(input_columns=["type_sentence1", "type_sentence2"], output_columns=["segment_ids"], + columns_order=["sentence1", "sentence2", "segment_ids", "label_id"], + operations=ops.Concatenate()) + dataset = dataset.map(input_columns=["segment_ids"], operations=ops.PadEnd([max_seq_len], 0)) + ### Generating text_ids + dataset = dataset.map(input_columns=["sentence1", "sentence2"], output_columns=["text_ids"], + columns_order=["text_ids", "segment_ids", "label_id"], + operations=ops.Concatenate()) + dataset = dataset.map(input_columns=["text_ids"], operations=lookup) + dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0)) + ### Generating mask_ids + dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"], + columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate()) + dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32)) + dataset = dataset.batch(batch_size) + label = [] + text_ids = [] + mask_ids = [] + segment_ids = [] + for data in dataset: + label.append(data[0]) + text_ids.append(data[1]) + mask_ids.append(data[2]) + segment_ids.append(data[3]) + return label, text_ids, mask_ids, segment_ids diff --git a/model_zoo/bert/src/cluener_evaluation.py b/model_zoo/bert/src/cluener_evaluation.py index 09de6bf0b3..f4c747ac38 100644 --- a/model_zoo/bert/src/cluener_evaluation.py +++ b/model_zoo/bert/src/cluener_evaluation.py @@ -19,15 +19,13 @@ import json import numpy as np import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor -from . import tokenization -from .sample_process import label_generation, process_one_example_p -from .evaluation_config import cfg -from .CRF import postprocess +from src import tokenization +from src.sample_process import label_generation, process_one_example_p +from src.CRF import postprocess +from src.finetune_eval_config import bert_net_cfg -vocab_file = "./vocab.txt" -tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file) -def process(model, text, sequence_length): +def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""): """ process text. """ @@ -36,13 +34,13 @@ def process(model, text, sequence_length): res = [] ids = [] for i in data: - feature = process_one_example_p(tokenizer_, i, max_seq_len=sequence_length) + feature = process_one_example_p(tokenizer_, i, max_seq_len=bert_net_cfg.seq_length) features.append(feature) input_ids, input_mask, token_type_id = feature input_ids = Tensor(np.array(input_ids), mstype.int32) input_mask = Tensor(np.array(input_mask), mstype.int32) token_type_id = Tensor(np.array(token_type_id), mstype.int32) - if cfg.use_crf: + if use_crf.lower() == "true": backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1)) best_path = postprocess(backpointers, best_tag_id) logits = [] @@ -54,19 +52,21 @@ def process(model, text, sequence_length): ids = logits.asnumpy() ids = np.argmax(ids, axis=-1) ids = list(ids) - res = label_generation(text, ids) + res = label_generation(text=text, probs=ids, label2id_file=label2id_file) return res -def submit(model, path, sequence_length): +def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""): """ submit task """ + tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file) data = [] for line in open(path): if not line.strip(): continue oneline = json.loads(line.strip()) - res = process(model, oneline["text"], sequence_length) + res = process(model=model, text=oneline["text"], tokenizer_=tokenizer_, + use_crf=use_crf, label2id_file=label2id_file) print("text", oneline["text"]) print("res:", res) data.append(json.dumps({"label": res}, ensure_ascii=False)) diff --git a/model_zoo/bert/src/dataset.py b/model_zoo/bert/src/dataset.py index 7985ca8559..e530718d4f 100644 --- a/model_zoo/bert/src/dataset.py +++ b/model_zoo/bert/src/dataset.py @@ -36,8 +36,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], - shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, - shard_equal_rows=True) + shuffle=de.Shuffle.FILES if do_shuffle == "true" else False, + num_shards=device_num, shard_id=rank, shard_equal_rows=True) ori_dataset_size = ds.get_dataset_size() print('origin dataset size: ', ori_dataset_size) new_size = ori_dataset_size @@ -58,3 +58,77 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e logger.info("data size: {}".format(ds.get_dataset_size())) logger.info("repeatcount: {}".format(ds.get_repeat_count())) return ds, new_repeat_count + + +def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", + data_file_path=None, schema_file_path=None): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) + if assessment_method == "Spearman_correlation": + type_cast_op_float = C.TypeCast(mstype.float32) + ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) + else: + ds = ds.map(input_columns="label_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", + data_file_path=None, schema_file_path=None): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) + if assessment_method == "Spearman_correlation": + type_cast_op_float = C.TypeCast(mstype.float32) + ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) + else: + ds = ds.map(input_columns="label_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None, is_training=True): + """create finetune or evaluation dataset""" + type_cast_op = C.TypeCast(mstype.int32) + if is_training: + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", + "start_positions", "end_positions", + "unique_ids", "is_impossible"]) + ds = ds.map(input_columns="start_positions", operations=type_cast_op) + ds = ds.map(input_columns="end_positions", operations=type_cast_op) + else: + ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, + columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"]) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + ds = ds.repeat(repeat_count) + # apply shuffle operation + buffer_size = 960 + ds = ds.shuffle(buffer_size=buffer_size) + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + return ds diff --git a/model_zoo/bert/src/evaluation_config.py b/model_zoo/bert/src/evaluation_config.py deleted file mode 100644 index b18c5643b0..0000000000 --- a/model_zoo/bert/src/evaluation_config.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -config settings, will be used in finetune.py -""" - -from easydict import EasyDict as edict -import mindspore.common.dtype as mstype -from .bert_model import BertConfig - -cfg = edict({ - 'task': 'NER', - 'num_labels': 41, - 'data_file': '/your/path/evaluation.tfrecord', - 'schema_file': '/your/path/schema.json', - 'finetune_ckpt': '/your/path/your.ckpt', - 'use_crf': False, - 'clue_benchmark': False, -}) - -bert_net_cfg = BertConfig( - batch_size=16 if not cfg.clue_benchmark else 1, - seq_length=128, - vocab_size=21128, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, - dtype=mstype.float32, - compute_type=mstype.float16, -) diff --git a/model_zoo/bert/src/finetune_config.py b/model_zoo/bert/src/finetune_config.py deleted file mode 100644 index 6241d06994..0000000000 --- a/model_zoo/bert/src/finetune_config.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -config settings, will be used in finetune.py -""" - -from easydict import EasyDict as edict -import mindspore.common.dtype as mstype -from .bert_model import BertConfig - -cfg = edict({ - 'task': 'NER', - 'num_labels': 41, - 'data_file': '/your/path/train.tfrecord', - 'schema_file': '/your/path/schema.json', - 'epoch_num': 5, - 'ckpt_prefix': 'bert', - 'ckpt_dir': None, - 'pre_training_ckpt': '/your/path/pre_training.ckpt', - 'use_crf': False, - 'optimizer': 'Lamb', - 'AdamWeightDecayDynamicLR': edict({ - 'learning_rate': 2e-5, - 'end_learning_rate': 1e-7, - 'power': 1.0, - 'weight_decay': 1e-5, - 'eps': 1e-6, - }), - 'Lamb': edict({ - 'start_learning_rate': 2e-5, - 'end_learning_rate': 1e-7, - 'power': 1.0, - 'weight_decay': 0.01, - 'decay_filter': lambda x: False, - }), - 'Momentum': edict({ - 'learning_rate': 2e-5, - 'momentum': 0.9, - }), -}) - -bert_net_cfg = BertConfig( - batch_size=16, - seq_length=128, - vocab_size=21128, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - use_relative_positions=False, - input_mask_from_dataset=True, - token_type_ids_from_dataset=True, - dtype=mstype.float32, - compute_type=mstype.float16, -) - -tag_to_index = { - "O": 0, - "S_address": 1, - "B_address": 2, - "M_address": 3, - "E_address": 4, - "S_book": 5, - "B_book": 6, - "M_book": 7, - "E_book": 8, - "S_company": 9, - "B_company": 10, - "M_company": 11, - "E_company": 12, - "S_game": 13, - "B_game": 14, - "M_game": 15, - "E_game": 16, - "S_government": 17, - "B_government": 18, - "M_government": 19, - "E_government": 20, - "S_movie": 21, - "B_movie": 22, - "M_movie": 23, - "E_movie": 24, - "S_name": 25, - "B_name": 26, - "M_name": 27, - "E_name": 28, - "S_organization": 29, - "B_organization": 30, - "M_organization": 31, - "E_organization": 32, - "S_position": 33, - "B_position": 34, - "M_position": 35, - "E_position": 36, - "S_scene": 37, - "B_scene": 38, - "M_scene": 39, - "E_scene": 40, - "": 41, - "": 42 -} diff --git a/model_zoo/bert/src/finetune_eval_config.py b/model_zoo/bert/src/finetune_eval_config.py new file mode 100644 index 0000000000..4b8e121e09 --- /dev/null +++ b/model_zoo/bert/src/finetune_eval_config.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +config settings, will be used in finetune.py +""" + +from easydict import EasyDict as edict +import mindspore.common.dtype as mstype +from .bert_model import BertConfig + +optimizer_cfg = edict({ + 'optimizer': 'Lamb', + 'AdamWeightDecayDynamicLR': edict({ + 'learning_rate': 2e-5, + 'end_learning_rate': 1e-7, + 'power': 1.0, + 'weight_decay': 1e-5, + 'eps': 1e-6, + }), + 'Lamb': edict({ + 'start_learning_rate': 2e-5, + 'end_learning_rate': 1e-7, + 'power': 1.0, + 'weight_decay': 0.01, + 'decay_filter': lambda x: False, + }), + 'Momentum': edict({ + 'learning_rate': 2e-5, + 'momentum': 0.9, + }), +}) + +bert_net_cfg = BertConfig( + batch_size=16, + seq_length=128, + vocab_size=21128, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=False, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, +) diff --git a/model_zoo/bert/src/finetune_eval_model.py b/model_zoo/bert/src/finetune_eval_model.py new file mode 100644 index 0000000000..047decc377 --- /dev/null +++ b/model_zoo/bert/src/finetune_eval_model.py @@ -0,0 +1,123 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +''' +Bert finetune and evaluation model script. +''' + +import mindspore.nn as nn +from mindspore.common.initializer import TruncatedNormal +from mindspore.ops import operations as P +from .bert_model import BertModel + +class BertCLSModel(nn.Cell): + """ + This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3), + LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final + logits as the results of log_softmax is propotional to that of softmax. + """ + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False, + assessment_method=""): + super(BertCLSModel, self).__init__() + if not is_training: + config.hidden_dropout_prob = 0.0 + config.hidden_probs_dropout_prob = 0.0 + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cast = P.Cast() + self.weight_init = TruncatedNormal(config.initializer_range) + self.log_softmax = P.LogSoftmax(axis=-1) + self.dtype = config.dtype + self.num_labels = num_labels + self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.assessment_method = assessment_method + + def construct(self, input_ids, input_mask, token_type_id): + _, pooled_output, _ = \ + self.bert(input_ids, token_type_id, input_mask) + cls = self.cast(pooled_output, self.dtype) + cls = self.dropout(cls) + logits = self.dense_1(cls) + logits = self.cast(logits, self.dtype) + if self.assessment_method != "spearman_correlation": + logits = self.log_softmax(logits) + return logits + +class BertSquadModel(nn.Cell): + ''' + This class is responsible for SQuAD + ''' + def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): + super(BertSquadModel, self).__init__() + if not is_training: + config.hidden_dropout_prob = 0.0 + config.hidden_probs_dropout_prob = 0.0 + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.weight_init = TruncatedNormal(config.initializer_range) + self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) + self.num_labels = num_labels + self.dtype = config.dtype + self.log_softmax = P.LogSoftmax(axis=1) + self.is_training = is_training + + def construct(self, input_ids, input_mask, token_type_id): + sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask) + batch_size, seq_length, hidden_size = P.Shape()(sequence_output) + sequence = P.Reshape()(sequence_output, (-1, hidden_size)) + logits = self.dense1(sequence) + logits = P.Cast()(logits, self.dtype) + logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels)) + logits = self.log_softmax(logits) + return logits + +class BertNERModel(nn.Cell): + """ + This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11). + The returned output represents the final logits as the results of log_softmax is propotional to that of softmax. + """ + def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0, + use_one_hot_embeddings=False): + super(BertNERModel, self).__init__() + if not is_training: + config.hidden_dropout_prob = 0.0 + config.hidden_probs_dropout_prob = 0.0 + self.bert = BertModel(config, is_training, use_one_hot_embeddings) + self.cast = P.Cast() + self.weight_init = TruncatedNormal(config.initializer_range) + self.log_softmax = P.LogSoftmax(axis=-1) + self.dtype = config.dtype + self.num_labels = num_labels + self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, + has_bias=True).to_float(config.compute_type) + self.dropout = nn.Dropout(1 - dropout_prob) + self.reshape = P.Reshape() + self.shape = (-1, config.hidden_size) + self.use_crf = use_crf + self.origin_shape = (config.batch_size, config.seq_length, self.num_labels) + + def construct(self, input_ids, input_mask, token_type_id): + sequence_output, _, _ = \ + self.bert(input_ids, token_type_id, input_mask) + seq = self.dropout(sequence_output) + seq = self.reshape(seq, self.shape) + logits = self.dense_1(seq) + logits = self.cast(logits, self.dtype) + if self.use_crf: + return_value = self.reshape(logits, self.origin_shape) + else: + return_value = self.log_softmax(logits) + return return_value diff --git a/model_zoo/bert/src/sample_process.py b/model_zoo/bert/src/sample_process.py index 59f3e76a31..c7cf29c510 100644 --- a/model_zoo/bert/src/sample_process.py +++ b/model_zoo/bert/src/sample_process.py @@ -52,12 +52,12 @@ def process_one_example_p(tokenizer, text, max_seq_len=128): feature = (input_ids, input_mask, segment_ids) return feature -def label_generation(text, probs): +def label_generation(text="", probs=None, label2id_file=""): """generate label""" data = [text] probs = [probs] result = [] - label2id = json.loads(open("./label2id.json").read()) + label2id = json.loads(open(label2id_file).read()) id2label = [k for k, v in label2id.items()] for index, prob in enumerate(probs): diff --git a/model_zoo/bert/src/utils.py b/model_zoo/bert/src/utils.py index ec5651b205..dfb6ffa5fe 100644 --- a/model_zoo/bert/src/utils.py +++ b/model_zoo/bert/src/utils.py @@ -17,347 +17,13 @@ Functional Cells used in Bert finetune and evaluation. """ +import os import mindspore.nn as nn -from mindspore.common.initializer import TruncatedNormal from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore.ops import composite as C from mindspore.common.tensor import Tensor -from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common import dtype as mstype -from mindspore.nn.wrap.grad_reducer import DistributedGradReducer -from mindspore.train.parallel_utils import ParallelMode -from mindspore.communication.management import get_group_size -from mindspore import context -from .bert_model import BertModel -from .bert_for_pre_training import clip_grad -from .CRF import CRF +from mindspore.train.callback import Callback -GRADIENT_CLIP_TYPE = 1 -GRADIENT_CLIP_VALUE = 1.0 -grad_scale = C.MultitypeFuncGraph("grad_scale") -reciprocal = P.Reciprocal() - -@grad_scale.register("Tensor", "Tensor") -def tensor_grad_scale(scale, grad): - return grad * reciprocal(scale) - -_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") -grad_overflow = P.FloatStatus() - -@_grad_overflow.register("Tensor") -def _tensor_grad_overflow(grad): - return grad_overflow(grad) - -class BertFinetuneCell(nn.Cell): - """ - Especifically defined for finetuning where only four inputs tensor are needed. - """ - def __init__(self, network, optimizer, scale_update_cell=None): - - super(BertFinetuneCell, self).__init__(auto_prefix=False) - self.network = network - self.weights = ParameterTuple(network.trainable_params()) - self.optimizer = optimizer - self.grad = C.GradOperation('grad', - get_by_list=True, - sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("mirror_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.cast = P.Cast() - self.gpu_target = False - if context.get_context("device_target") == "GPU": - self.gpu_target = True - self.float_status = P.FloatStatus() - self.addn = P.AddN() - self.reshape = P.Reshape() - else: - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") - - def construct(self, - input_ids, - input_mask, - token_type_id, - label_ids, - sens=None): - - - weights = self.weights - init = False - loss = self.network(input_ids, - input_mask, - token_type_id, - label_ids) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - - if not self.gpu_target: - init = self.alloc_status() - clear_before_grad = self.clear_before_grad(init) - F.control_depend(loss, init) - self.depend_parameter_use(clear_before_grad, scaling_sens) - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - label_ids, - self.cast(scaling_sens, - mstype.float32)) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - if self.reducer_flag: - grads = self.grad_reducer(grads) - if not self.gpu_target: - flag = self.get_status(init) - flag_sum = self.reduce_sum(init, (0,)) - F.control_depend(grads, flag) - F.control_depend(flag, flag_sum) - else: - flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) - flag_sum = self.addn(flag_sum) - flag_sum = self.reshape(flag_sum, (())) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if overflow: - succ = False - else: - succ = self.optimizer(grads) - ret = (loss, cond) - return F.depend(ret, succ) - -class BertSquadCell(nn.Cell): - """ - specifically defined for finetuning where only four inputs tensor are needed. - """ - def __init__(self, network, optimizer, scale_update_cell=None): - super(BertSquadCell, self).__init__(auto_prefix=False) - self.network = network - self.weights = ParameterTuple(network.trainable_params()) - self.optimizer = optimizer - self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) - self.reducer_flag = False - self.allreduce = P.AllReduce() - self.parallel_mode = context.get_auto_parallel_context("parallel_mode") - if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: - self.reducer_flag = True - self.grad_reducer = None - if self.reducer_flag: - mean = context.get_auto_parallel_context("mirror_mean") - degree = get_group_size() - self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) - self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) - self.cast = P.Cast() - self.alloc_status = P.NPUAllocFloatStatus() - self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) - self.base = Tensor(1, mstype.float32) - self.less_equal = P.LessEqual() - self.hyper_map = C.HyperMap() - self.loss_scale = None - self.loss_scaling_manager = scale_update_cell - if scale_update_cell: - self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), - name="loss_scale") - def construct(self, - input_ids, - input_mask, - token_type_id, - start_position, - end_position, - unique_id, - is_impossible, - sens=None): - weights = self.weights - init = self.alloc_status() - loss = self.network(input_ids, - input_mask, - token_type_id, - start_position, - end_position, - unique_id, - is_impossible) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens - grads = self.grad(self.network, weights)(input_ids, - input_mask, - token_type_id, - start_position, - end_position, - unique_id, - is_impossible, - self.cast(scaling_sens, - mstype.float32)) - clear_before_grad = self.clear_before_grad(init) - F.control_depend(loss, init) - self.depend_parameter_use(clear_before_grad, scaling_sens) - grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) - grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - if self.reducer_flag: - grads = self.grad_reducer(grads) - flag = self.get_status(init) - flag_sum = self.reduce_sum(init, (0,)) - if self.is_distributed: - flag_reduce = self.allreduce(flag_sum) - cond = self.less_equal(self.base, flag_reduce) - else: - cond = self.less_equal(self.base, flag_sum) - F.control_depend(grads, flag) - F.control_depend(flag, flag_sum) - overflow = cond - if sens is None: - overflow = self.loss_scaling_manager(self.loss_scale, cond) - if overflow: - succ = False - else: - succ = self.optimizer(grads) - ret = (loss, cond) - return F.depend(ret, succ) - - -class BertRegressionModel(nn.Cell): - """ - Bert finetune model for regression task - """ - def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): - super(BertRegressionModel, self).__init__() - self.bert = BertModel(config, is_training, use_one_hot_embeddings) - self.cast = P.Cast() - self.weight_init = TruncatedNormal(config.initializer_range) - self.log_softmax = P.LogSoftmax(axis=-1) - self.dtype = config.dtype - self.num_labels = num_labels - self.dropout = nn.Dropout(1 - dropout_prob) - self.dense_1 = nn.Dense(config.hidden_size, 1, weight_init=self.weight_init, - has_bias=True).to_float(mstype.float16) - - def construct(self, input_ids, input_mask, token_type_id): - _, pooled_output, _ = self.bert(input_ids, token_type_id, input_mask) - cls = self.cast(pooled_output, self.dtype) - cls = self.dropout(cls) - logits = self.dense_1(cls) - logits = self.cast(logits, self.dtype) - return logits - - -class BertCLSModel(nn.Cell): - """ - This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3), - LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final - logits as the results of log_softmax is propotional to that of softmax. - """ - def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): - super(BertCLSModel, self).__init__() - self.bert = BertModel(config, is_training, use_one_hot_embeddings) - self.cast = P.Cast() - self.weight_init = TruncatedNormal(config.initializer_range) - self.log_softmax = P.LogSoftmax(axis=-1) - self.dtype = config.dtype - self.num_labels = num_labels - self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, - has_bias=True).to_float(config.compute_type) - self.dropout = nn.Dropout(1 - dropout_prob) - - def construct(self, input_ids, input_mask, token_type_id): - _, pooled_output, _ = \ - self.bert(input_ids, token_type_id, input_mask) - cls = self.cast(pooled_output, self.dtype) - cls = self.dropout(cls) - logits = self.dense_1(cls) - logits = self.cast(logits, self.dtype) - log_probs = self.log_softmax(logits) - return log_probs - -class BertSquadModel(nn.Cell): - """ - Bert finetune model for SQuAD v1.1 task - """ - def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): - super(BertSquadModel, self).__init__() - self.bert = BertModel(config, is_training, use_one_hot_embeddings) - self.weight_init = TruncatedNormal(config.initializer_range) - self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init, - has_bias=True).to_float(config.compute_type) - self.num_labels = num_labels - self.dtype = config.dtype - self.log_softmax = P.LogSoftmax(axis=1) - self.is_training = is_training - - def construct(self, input_ids, input_mask, token_type_id): - sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask) - batch_size, seq_length, hidden_size = P.Shape()(sequence_output) - sequence = P.Reshape()(sequence_output, (-1, hidden_size)) - logits = self.dense1(sequence) - logits = P.Cast()(logits, self.dtype) - logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels)) - logits = self.log_softmax(logits) - return logits - -class BertNERModel(nn.Cell): - """ - This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11). - The returned output represents the final logits as the results of log_softmax is propotional to that of softmax. - """ - def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0, - use_one_hot_embeddings=False): - super(BertNERModel, self).__init__() - self.bert = BertModel(config, is_training, use_one_hot_embeddings) - self.cast = P.Cast() - self.weight_init = TruncatedNormal(config.initializer_range) - self.log_softmax = P.LogSoftmax(axis=-1) - self.dtype = config.dtype - self.num_labels = num_labels - self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, - has_bias=True).to_float(config.compute_type) - self.dropout = nn.Dropout(1 - dropout_prob) - self.reshape = P.Reshape() - self.shape = (-1, config.hidden_size) - self.use_crf = use_crf - self.origin_shape = (config.batch_size, config.seq_length, self.num_labels) - - def construct(self, input_ids, input_mask, token_type_id): - sequence_output, _, _ = \ - self.bert(input_ids, token_type_id, input_mask) - seq = self.dropout(sequence_output) - seq = self.reshape(seq, self.shape) - logits = self.dense_1(seq) - logits = self.cast(logits, self.dtype) - if self.use_crf: - return_value = self.reshape(logits, self.origin_shape) - else: - return_value = self.log_softmax(logits) - return return_value class CrossEntropyCalculation(nn.Cell): """ @@ -387,95 +53,73 @@ class CrossEntropyCalculation(nn.Cell): return_value = logits * 1.0 return return_value -class BertCLS(nn.Cell): - """ - Train interface for classification finetuning task. - """ - def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): - super(BertCLS, self).__init__() - self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) - self.loss = CrossEntropyCalculation(is_training) - self.num_labels = num_labels - def construct(self, input_ids, input_mask, token_type_id, label_ids): - log_probs = self.bert(input_ids, input_mask, token_type_id) - loss = self.loss(log_probs, label_ids, self.num_labels) - return loss - -class BertNER(nn.Cell): - """ - Train interface for sequence labeling finetuning task. - """ - def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0, - use_one_hot_embeddings=False): - super(BertNER, self).__init__() - self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings) - if use_crf: - if not tag_to_index: - raise Exception("The dict for tag-index mapping should be provided for CRF.") - self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training) - else: - self.loss = CrossEntropyCalculation(is_training) - self.num_labels = num_labels - self.use_crf = use_crf - def construct(self, input_ids, input_mask, token_type_id, label_ids): - logits = self.bert(input_ids, input_mask, token_type_id) - if self.use_crf: - loss = self.loss(logits, label_ids) - else: - loss = self.loss(logits, label_ids, self.num_labels) - return loss - -class BertSquad(nn.Cell): - """ - Train interface for SQuAD finetuning task. - """ - def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): - super(BertSquad, self).__init__() - self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) - self.loss = CrossEntropyCalculation(is_training) - self.num_labels = num_labels - self.seq_length = config.seq_length - self.is_training = is_training - self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num') - self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num') - self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num') - self.sum = P.ReduceSum() - self.equal = P.Equal() - self.argmax = P.ArgMaxWithValue(axis=1) - self.squeeze = P.Squeeze(axis=-1) - - def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible): - logits = self.bert(input_ids, input_mask, token_type_id) - if self.is_training: - unstacked_logits_0 = self.squeeze(logits[:, :, 0:1]) - unstacked_logits_1 = self.squeeze(logits[:, :, 1:2]) - start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length) - end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length) - total_loss = (start_loss + end_loss) / 2.0 - else: - start_logits = self.squeeze(logits[:, :, 0:1]) - end_logits = self.squeeze(logits[:, :, 1:2]) - total_loss = (unique_id, start_logits, end_logits) - return total_loss - - -class BertReg(nn.Cell): - """ - Bert finetune model with loss for regression task - """ - def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): - super(BertReg, self).__init__() - self.bert = BertRegressionModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) - self.loss = nn.MSELoss() - self.is_training = is_training - self.sigmoid = P.Sigmoid() - self.cast = P.Cast() - self.mul = P.Mul() - def construct(self, input_ids, input_mask, token_type_id, labels): - logits = self.bert(input_ids, input_mask, token_type_id) - if self.is_training: - loss = self.loss(logits, labels) - else: - loss = logits - return loss +def make_directory(path: str): + """Make directory.""" + if path is None or not isinstance(path, str) or path.strip() == "": + logger.error("The path(%r) is invalid type.", path) + raise TypeError("Input path is invaild type") + + # convert the relative paths + path = os.path.realpath(path) + logger.debug("The abs path is %r", path) + + # check the path is exist and write permissions? + if os.path.exists(path): + real_path = path + else: + # All exceptions need to be caught because create directory maybe have some limit(permissions) + logger.debug("The directory(%s) doesn't exist, will create it", path) + try: + os.makedirs(path, exist_ok=True) + real_path = path + except PermissionError as e: + logger.error("No write permission on the directory(%r), error = %r", path, e) + raise TypeError("No write permission on the directory.") + return real_path + +class LossCallBack(Callback): + """ + Monitor the loss in training. + If the loss in NAN or INF terminating training. + Note: + if per_print_times is 0 do not print loss. + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0") + self._per_print_times = per_print_times + def step_end(self, run_context): + cb_params = run_context.original_args() + print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, + str(cb_params.net_outputs))) + +def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): + """ + Find the ckpt finetune generated and load it into eval network. + """ + files = os.listdir(load_finetune_checkpoint_dir) + pre_len = len(prefix) + max_num = 0 + for filename in files: + name_ext = os.path.splitext(filename) + if name_ext[-1] != ".ckpt": + continue + #steps_per_epoch = ds.get_dataset_size() + if filename.find(prefix) == 0 and not filename[pre_len].isalpha(): + index = filename[pre_len:].find("-") + if index == 0 and max_num == 0: + load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) + elif index not in (0, -1): + name_split = name_ext[-2].split('_') + if (steps_per_epoch != int(name_split[len(name_split)-1])) \ + or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])): + continue + num = filename[pre_len + 1:pre_len + index] + if int(num) > max_num: + max_num = int(num) + load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) + return load_finetune_checkpoint_path diff --git a/model_zoo/faster_rcnn/eval.py b/model_zoo/faster_rcnn/eval.py index e0b4e2d0ea..d8dd2ed79a 100644 --- a/model_zoo/faster_rcnn/eval.py +++ b/model_zoo/faster_rcnn/eval.py @@ -40,7 +40,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") args_opt = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) def FasterRcnn_eval(dataset_path, ckpt_path, ann_file): """FasterRcnn evaluation.""" diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py b/model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py index 05d6d1c9d1..bcf0536f5b 100644 --- a/model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py +++ b/model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py @@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor from mindspore.common import dtype as mstype from mindspore.common.initializer import initializer -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") def bias_init_zeros(shape): """Bias init method.""" diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py b/model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py index 9428b20914..f9bcc47df4 100644 --- a/model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py +++ b/model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py @@ -22,7 +22,7 @@ from mindspore import Tensor from mindspore import context -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") class Proposal(nn.Cell): diff --git a/model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py b/model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py index 20d9ee1f34..002ea08d0c 100644 --- a/model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py +++ b/model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py @@ -22,7 +22,7 @@ from mindspore.ops import functional as F from mindspore import context -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") def weight_init_ones(shape): diff --git a/model_zoo/faster_rcnn/train.py b/model_zoo/faster_rcnn/train.py index 3cc86c7cc1..7d5f190bab 100644 --- a/model_zoo/faster_rcnn/train.py +++ b/model_zoo/faster_rcnn/train.py @@ -52,7 +52,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums, parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.") args_opt = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) if __name__ == '__main__': if not args_opt.do_eval and args_opt.run_distribute: diff --git a/model_zoo/gat/README.md b/model_zoo/gat/README.md index 7c30e08851..0c46aebbaf 100644 --- a/model_zoo/gat/README.md +++ b/model_zoo/gat/README.md @@ -72,9 +72,9 @@ sh run_process_data.sh [SRC_PATH] [DATASET_NAME] >> Launch ``` #Generate dataset in mindrecord format for cora -sh run_process_data.sh cora +./run_process_data.sh ./data cora #Generate dataset in mindrecord format for citeseer -sh run_process_data.sh citeseer +./run_process_data.sh ./data citeseer ``` # Features diff --git a/model_zoo/gat/train.py b/model_zoo/gat/train.py index af1808b995..acfbb05b78 100644 --- a/model_zoo/gat/train.py +++ b/model_zoo/gat/train.py @@ -96,6 +96,8 @@ def train(): if eval_acc >= val_acc_max and eval_loss < val_loss_min: val_acc_model = eval_acc val_loss_model = eval_loss + if os.path.exists("ckpts/gat.ckpt"): + os.remove("ckpts/gat.ckpt") _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt") val_acc_max = np.max((val_acc_max, eval_acc)) val_loss_min = np.min((val_loss_min, eval_loss)) diff --git a/model_zoo/googlenet/scripts/run_train.sh b/model_zoo/googlenet/scripts/run_train.sh index c21c2f04b6..e8c045c8b1 100644 --- a/model_zoo/googlenet/scripts/run_train.sh +++ b/model_zoo/googlenet/scripts/run_train.sh @@ -33,10 +33,12 @@ MINDSPORE_HCCL_CONFIG_PATH=$(realpath $1) export MINDSPORE_HCCL_CONFIG_PATH echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}" +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) for((i=0; i<${DEVICE_NUM}; i++)) do export DEVICE_ID=$i - export RANK_ID=$i + export RANK_ID=$((rank_start + i)) rm -rf ./train_parallel$i mkdir ./train_parallel$i cp -r ./src ./train_parallel$i diff --git a/model_zoo/googlenet/src/dataset.py b/model_zoo/googlenet/src/dataset.py index a1cbc2cdab..a3f74a0617 100644 --- a/model_zoo/googlenet/src/dataset.py +++ b/model_zoo/googlenet/src/dataset.py @@ -31,8 +31,7 @@ def create_dataset(data_home, repeat_num=1, training=True): if not training: data_dir = os.path.join(data_home, "cifar-10-verify-bin") - rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None - rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None + rank_size, rank_id = _get_rank_info() data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) resize_height = cfg.image_height @@ -65,3 +64,19 @@ def create_dataset(data_home, repeat_num=1, training=True): data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) return data_set + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + from mindspore.communication.management import get_rank, get_group_size + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = rank_id = None + + return rank_size, rank_id diff --git a/model_zoo/lenet_quant/src/loss_monitor.py b/model_zoo/lenet_quant/src/loss_monitor.py new file mode 100644 index 0000000000..59c222d23d --- /dev/null +++ b/model_zoo/lenet_quant/src/loss_monitor.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""LossMonitor Callback class.""" + +import time +import numpy as np +from mindspore.common.tensor import Tensor +from mindspore.train.callback import Callback + + +class LossMonitor(Callback): + """ + Monitor the loss in training. + + If the loss is NAN or INF, it will terminate training. + + Note: + If per_print_times is 0 do not print loss. + + Args: + per_print_times (int): Print loss every times. Default: 1. + lr_init (numpy array): train learning rate. Default: None. + + Raises: + ValueError: If print_step is not int or less than zero. + + Examples: + >>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, per_print_times=1, lr_init=None): + super(LossMonitor, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + self.lr_init = lr_init + + def epoch_begin(self, run_context): + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("Epoch time: {:5.3f}, per step time: {:5.3f}, " + "avg loss: {:5.3f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses))) + print("*" * 60) + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + cb_params = run_context.original_args() + step_mseconds = (time.time() - self.step_time) * 1000 + step_loss = cb_params.net_outputs + + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) + + self.losses.append(step_loss) + cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1 + + if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): + raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " + "Invalid loss, terminating training.".format( + cb_params.cur_epoch_num - 1, cb_params.epoch_num, + cur_step_in_epoch, cb_params.batch_num)) + + if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: + print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " + "loss: [{:5.4f}], avg loss: [{:5.4f}], time: [{:5.4f}ms]".format( + cb_params.cur_epoch_num, cb_params.epoch_num, + cur_step_in_epoch, int(cb_params.batch_num), + step_loss, np.mean(self.losses), + step_mseconds), flush=True) diff --git a/model_zoo/lenet_quant/train.py b/model_zoo/lenet_quant/train.py index 2cff465832..03e9ff62bd 100644 --- a/model_zoo/lenet_quant/train.py +++ b/model_zoo/lenet_quant/train.py @@ -22,12 +22,13 @@ import os import argparse import mindspore.nn as nn from mindspore import context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train import Model from mindspore.nn.metrics import Accuracy from src.dataset import create_dataset from src.config import mnist_cfg as cfg from src.lenet_fusion import LeNet5 as LeNet5Fusion +from src.loss_monitor import LossMonitor parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", diff --git a/model_zoo/lenet_quant/train_quant.py b/model_zoo/lenet_quant/train_quant.py index 6f27cec1e3..3a87ccc70d 100644 --- a/model_zoo/lenet_quant/train_quant.py +++ b/model_zoo/lenet_quant/train_quant.py @@ -23,13 +23,14 @@ import argparse import mindspore.nn as nn from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore.train.quant import quant from src.dataset import create_dataset from src.config import mnist_cfg as cfg from src.lenet_fusion import LeNet5 as LeNet5Fusion +from src.loss_monitor import LossMonitor parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser.add_argument('--device_target', type=str, default="Ascend", diff --git a/model_zoo/mass/eval.py b/model_zoo/mass/eval.py index 4da63a7333..bb844e9102 100644 --- a/model_zoo/mass/eval.py +++ b/model_zoo/mass/eval.py @@ -15,15 +15,13 @@ """Evaluation api.""" import argparse import pickle -import numpy as np from mindspore.common import dtype as mstype from config import TransformerConfig -from src.transformer import infer -from src.utils import ngram_ppl +from src.transformer import infer, infer_ppl from src.utils import Dictionary -from src.utils import rouge +from src.utils import get_score parser = argparse.ArgumentParser(description='Evaluation MASS.') parser.add_argument("--config", type=str, required=True, @@ -32,6 +30,8 @@ parser.add_argument("--vocab", type=str, required=True, help="Vocabulary to use.") parser.add_argument("--output", type=str, required=True, help="Result file path.") +parser.add_argument("--metric", type=str, default='rouge', + help='Set eval method.') def get_config(config): @@ -45,31 +45,15 @@ if __name__ == '__main__': args, _ = parser.parse_known_args() vocab = Dictionary.load_from_persisted_dict(args.vocab) _config = get_config(args.config) - result = infer(_config) + + if args.metric == 'rouge': + result = infer(_config) + else: + result = infer_ppl(_config) + with open(args.output, "wb") as f: pickle.dump(result, f, 1) - ppl_score = 0. - preds = [] - tgts = [] - _count = 0 - for sample in result: - sentence_prob = np.array(sample['prediction_prob'], dtype=np.float32) - sentence_prob = sentence_prob[:, 1:] - _ppl = [] - for path in sentence_prob: - _ppl.append(ngram_ppl(path, log_softmax=True)) - ppl = np.min(_ppl) - preds.append(' '.join([vocab[t] for t in sample['prediction']])) - tgts.append(' '.join([vocab[t] for t in sample['target']])) - print(f" | source: {' '.join([vocab[t] for t in sample['source']])}") - print(f" | target: {tgts[-1]}") - print(f" | prediction: {preds[-1]}") - print(f" | ppl: {ppl}.") - if np.isinf(ppl): - continue - ppl_score += ppl - _count += 1 - - print(f" | PPL={ppl_score / _count}.") - rouge(preds, tgts) + # get score by given metric + score = get_score(result, vocab, metric=args.metric) + print(score) diff --git a/model_zoo/mass/scripts/run.sh b/model_zoo/mass/scripts/run.sh index 91bed510ea..132e38dae2 100644 --- a/model_zoo/mass/scripts/run.sh +++ b/model_zoo/mass/scripts/run.sh @@ -18,7 +18,7 @@ export DEVICE_ID=0 export RANK_ID=0 export RANK_SIZE=1 -options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"` +options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"` eval set -- "$options" echo $options @@ -35,6 +35,7 @@ echo_help() echo " -c --config set the configuration file" echo " -o --output set the output file of inference" echo " -v --vocab set the vocabulary" + echo " -m --metric set the metric" } set_hccl_json() @@ -43,8 +44,8 @@ set_hccl_json() do if [[ "$1" == "-j" || "$1" == "--hccl_json" ]] then - export MINDSPORE_HCCL_CONFIG_PATH=$2 #/data/wsc/hccl_2p_01.json - export RANK_TABLE_FILE=$2 #/data/wsc/hccl_2p_01.json + export MINDSPORE_HCCL_CONFIG_PATH=$2 + export RANK_TABLE_FILE=$2 break fi shift @@ -119,6 +120,11 @@ do vocab=$2 shift 2 ;; + -m|--metric) + echo "metric"; + metric=$2 + shift 2 + ;; --) shift break @@ -163,7 +169,7 @@ do python train.py --config ${configurations##*/} >>log.log 2>&1 & elif [ "$task" == "infer" ] then - python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} >>log_infer.log 2>&1 & + python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >>log_infer.log 2>&1 & fi cd ../ done diff --git a/model_zoo/mass/src/transformer/__init__.py b/model_zoo/mass/src/transformer/__init__.py index 7912e7f0dd..36db26d360 100644 --- a/model_zoo/mass/src/transformer/__init__.py +++ b/model_zoo/mass/src/transformer/__init__.py @@ -19,10 +19,11 @@ from .decoder import TransformerDecoder from .beam_search import BeamSearchDecoder from .transformer_for_train import TransformerTraining, LabelSmoothedCrossEntropyCriterion, \ TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell -from .infer_mass import infer +from .infer_mass import infer, infer_ppl __all__ = [ "infer", + "infer_ppl", "TransformerTraining", "LabelSmoothedCrossEntropyCriterion", "TransformerTrainOneStepWithLossScaleCell", diff --git a/model_zoo/mass/src/transformer/embedding.py b/model_zoo/mass/src/transformer/embedding.py index bdce540416..22887b0a3e 100644 --- a/model_zoo/mass/src/transformer/embedding.py +++ b/model_zoo/mass/src/transformer/embedding.py @@ -41,7 +41,7 @@ class EmbeddingLookup(nn.Cell): self.vocab_size = vocab_size self.use_one_hot_embeddings = use_one_hot_embeddings - init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim]) + init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim]).astype(np.float32) # 0 is Padding index, thus init it as 0. init_weight[0, :] = 0 self.embedding_table = Parameter(Tensor(init_weight), diff --git a/model_zoo/mass/src/transformer/infer_mass.py b/model_zoo/mass/src/transformer/infer_mass.py index 54a0b4e54f..b887e3a7b5 100644 --- a/model_zoo/mass/src/transformer/infer_mass.py +++ b/model_zoo/mass/src/transformer/infer_mass.py @@ -17,13 +17,16 @@ import time import mindspore.nn as nn import mindspore.common.dtype as mstype +from mindspore.ops import operations as P from mindspore.common.tensor import Tensor from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore import context from src.dataset import load_dataset from .transformer_for_infer import TransformerInferModel +from .transformer_for_train import TransformerTraining from ..utils.load_weights import load_infer_weights context.set_context( @@ -156,3 +159,129 @@ def infer(config): shuffle=False) if config.test_dataset else None prediction = transformer_infer(config, eval_dataset) return prediction + + +class TransformerInferPPLCell(nn.Cell): + """ + Encapsulation class of transformer network infer for PPL. + + Args: + config(TransformerConfig): Config. + + Returns: + Tuple[Tensor, Tensor], predicted log prob and label lengths. + """ + def __init__(self, config): + super(TransformerInferPPLCell, self).__init__() + self.transformer = TransformerTraining(config, is_training=False, use_one_hot_embeddings=False) + self.batch_size = config.batch_size + self.vocab_size = config.vocab_size + self.one_hot = P.OneHot() + self.on_value = Tensor(float(1), mstype.float32) + self.off_value = Tensor(float(0), mstype.float32) + self.reduce_sum = P.ReduceSum() + self.reshape = P.Reshape() + self.cast = P.Cast() + self.flat_shape = (config.batch_size * config.seq_length,) + self.batch_shape = (config.batch_size, config.seq_length) + self.last_idx = (-1,) + + def construct(self, + source_ids, + source_mask, + target_ids, + target_mask, + label_ids, + label_mask): + """Defines the computation performed.""" + + predicted_log_probs = self.transformer(source_ids, source_mask, target_ids, target_mask) + label_ids = self.reshape(label_ids, self.flat_shape) + label_mask = self.cast(label_mask, mstype.float32) + one_hot_labels = self.one_hot(label_ids, self.vocab_size, self.on_value, self.off_value) + + label_log_probs = self.reduce_sum(predicted_log_probs * one_hot_labels, self.last_idx) + label_log_probs = self.reshape(label_log_probs, self.batch_shape) + log_probs = label_log_probs * label_mask + lengths = self.reduce_sum(label_mask, self.last_idx) + + return log_probs, lengths + + +def transformer_infer_ppl(config, dataset): + """ + Run infer with Transformer for PPL. + + Args: + config (TransformerConfig): Config. + dataset (Dataset): Dataset. + + Returns: + List[Dict], prediction, each example has 4 keys, "source", + "target", "log_prob" and "length". + """ + tfm_infer = TransformerInferPPLCell(config=config) + tfm_infer.init_parameters_data() + + parameter_dict = load_checkpoint(config.existed_ckpt) + load_param_into_net(tfm_infer, parameter_dict) + + model = Model(tfm_infer) + + log_probs = [] + lengths = [] + source_sentences = [] + target_sentences = [] + for batch in dataset.create_dict_iterator(): + source_sentences.append(batch["source_eos_ids"]) + target_sentences.append(batch["target_eos_ids"]) + + source_ids = Tensor(batch["source_eos_ids"], mstype.int32) + source_mask = Tensor(batch["source_eos_mask"], mstype.int32) + target_ids = Tensor(batch["target_sos_ids"], mstype.int32) + target_mask = Tensor(batch["target_sos_mask"], mstype.int32) + label_ids = Tensor(batch["target_eos_ids"], mstype.int32) + label_mask = Tensor(batch["target_eos_mask"], mstype.int32) + + start_time = time.time() + log_prob, length = model.predict(source_ids, source_mask, target_ids, target_mask, label_ids, label_mask) + print(f" | Batch size: {config.batch_size}, " + f"Time cost: {time.time() - start_time}.") + + log_probs.append(log_prob.asnumpy()) + lengths.append(length.asnumpy()) + + output = [] + for inputs, ref, log_prob, length in zip(source_sentences, + target_sentences, + log_probs, + lengths): + for i in range(config.batch_size): + example = { + "source": inputs[i].tolist(), + "target": ref[i].tolist(), + "log_prob": log_prob[i].tolist(), + "length": length[i] + } + output.append(example) + + return output + + +def infer_ppl(config): + """ + Transformer infer PPL api. + + Args: + config (TransformerConfig): Config. + + Returns: + list, result with + """ + eval_dataset = load_dataset(data_files=config.test_dataset, + batch_size=config.batch_size, + epoch_count=1, + sink_mode=config.dataset_sink_mode, + shuffle=False) if config.test_dataset else None + prediction = transformer_infer_ppl(config, eval_dataset) + return prediction diff --git a/model_zoo/mass/src/utils/__init__.py b/model_zoo/mass/src/utils/__init__.py index f78be57b22..efb9f6f4b6 100644 --- a/model_zoo/mass/src/utils/__init__.py +++ b/model_zoo/mass/src/utils/__init__.py @@ -20,6 +20,7 @@ from .loss_monitor import LossCallBack from .byte_pair_encoding import bpe_encode from .initializer import zero_weight, one_weight, normal_weight, weight_variable from .rouge_score import rouge +from .eval_score import get_score __all__ = [ "Dictionary", @@ -31,5 +32,6 @@ __all__ = [ "one_weight", "zero_weight", "normal_weight", - "weight_variable" + "weight_variable", + "get_score" ] diff --git a/model_zoo/mass/src/utils/eval_score.py b/model_zoo/mass/src/utils/eval_score.py new file mode 100644 index 0000000000..30ff0b2208 --- /dev/null +++ b/model_zoo/mass/src/utils/eval_score.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Get score by given metric.""" +from .ppl_score import ngram_ppl +from .rouge_score import rouge + + +def get_ppl_score(result): + """ + Calculate Perplexity(PPL) score. + + Args: + List[Dict], prediction, each example has 4 keys, "source", + "target", "log_prob" and "length". + + Returns: + Float, ppl score. + """ + log_probs = [] + total_length = 0 + + for sample in result: + log_prob = sample['log_prob'] + length = sample['length'] + log_probs.extend(log_prob) + total_length += length + + print(f" | log_prob:{log_prob}") + print(f" | length:{length}") + + ppl = ngram_ppl(log_probs, total_length, log_softmax=True) + print(f" | final PPL={ppl}.") + return ppl + + +def get_rouge_score(result, vocab): + """ + Calculate ROUGE score. + + Args: + List[Dict], prediction, each example has 4 keys, "source", + "target", "prediction" and "prediction_prob". + Dictionary, dict instance. + + retur: + Str, rouge score. + """ + + predictions = [] + targets = [] + for sample in result: + predictions.append(' '.join([vocab[t] for t in sample['prediction']])) + targets.append(' '.join([vocab[t] for t in sample['target']])) + print(f" | source: {' '.join([vocab[t] for t in sample['source']])}") + print(f" | target: {targets[-1]}") + + return rouge(predictions, targets) + + +def get_score(result, vocab=None, metric='rouge'): + """ + Get eval score. + + Args: + List[Dict], prediction. + Dictionary, dict instance. + Str, metric function, default is rouge. + + Return: + Str, Score. + """ + score = None + if metric == 'rouge': + score = get_rouge_score(result, vocab) + elif metric == 'ppl': + score = get_ppl_score(result) + else: + print(f" |metric not in (rouge, ppl)") + + return score diff --git a/model_zoo/mass/src/utils/ppl_score.py b/model_zoo/mass/src/utils/ppl_score.py index 2e5d6e6642..4a9139ced0 100644 --- a/model_zoo/mass/src/utils/ppl_score.py +++ b/model_zoo/mass/src/utils/ppl_score.py @@ -17,10 +17,7 @@ from typing import Union import numpy as np -NINF = -1.0 * 1e9 - - -def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = np.e): +def ngram_ppl(prob: Union[np.ndarray, list], length: int, log_softmax=False, index: float = np.e): """ Calculate Perplexity(PPL) score under N-gram language model. @@ -39,7 +36,8 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n Returns: float, ppl score. """ - eps = 1e-8 + if not length: + return np.inf if not isinstance(prob, (np.ndarray, list)): raise TypeError("`prob` must be type of list or np.ndarray.") if not isinstance(prob, np.ndarray): @@ -47,18 +45,17 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n if prob.shape[0] == 0: raise ValueError("`prob` length must greater than 0.") - p = 1.0 - sen_len = 0 - for t in range(prob.shape[0]): - s = prob[t] - if s <= NINF: - break - if log_softmax: - s = np.power(index, s) - p *= (1 / (s + eps)) - sen_len += 1 + print(f'length:{length}, log_prob:{prob}') - if sen_len == 0: - return np.inf + if log_softmax: + prob = np.sum(prob) / length + ppl = 1. / np.power(index, prob) + print(f'avg log prob:{prob}') + else: + p = 1. + for i in range(prob.shape[0]): + p *= (1. / prob[i]) + ppl = pow(p, 1 / length) - return pow(p, 1 / sen_len) + print(f'ppl val:{ppl}') + return ppl diff --git a/model_zoo/mobilenetv2/Readme.md b/model_zoo/mobilenetv2/Readme.md index 5b36a63fe4..1687d2cbdc 100644 --- a/model_zoo/mobilenetv2/Readme.md +++ b/model_zoo/mobilenetv2/Readme.md @@ -60,14 +60,14 @@ Dataset used: [imagenet](http://www.image-net.org/) ### Usage -- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] +- Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [CKPT_PATH] - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] ### Launch ``` # training example - Ascend: sh run_train.sh Ascend 8 192.168.0.1 0,1,2,3,4,5,6,7 ~/imagenet/train/ mobilenet_199.ckpt + Ascend: sh run_train.sh Ascend 8 0,1,2,3,4,5,6,7 hccl_config.json ~/imagenet/train/ mobilenet_199.ckpt GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ ``` diff --git a/model_zoo/mobilenetv2/scripts/run_train.sh b/model_zoo/mobilenetv2/scripts/run_train.sh index f1d80aeac6..a6e2a79477 100644 --- a/model_zoo/mobilenetv2/scripts/run_train.sh +++ b/model_zoo/mobilenetv2/scripts/run_train.sh @@ -22,14 +22,16 @@ run_ascend() exit 1 fi - if [ ! -d $5 ] + if [ ! -d $5 ] && [ ! -f $5 ] then - echo "error: DATASET_PATH=$5 is not a directory" + echo "error: DATASET_PATH=$5 is not a directory or file" exit 1 fi BASEPATH=$(cd "`dirname $0`" || exit; pwd) export PYTHONPATH=${BASEPATH}:$PYTHONPATH + export MINDSPORE_HCCL_CONFIG_PATH=$4 + export RANK_TABLE_FILE=$4 if [ -d "../train" ]; then rm -rf ../train @@ -38,8 +40,7 @@ run_ascend() cd ../train || exit python ${BASEPATH}/../src/launch.py \ --nproc_per_node=$2 \ - --visible_devices=$4 \ - --server_id=$3 \ + --visible_devices=$3 \ --training_script=${BASEPATH}/../train.py \ --dataset_path=$5 \ --pre_trained=$6 \ @@ -80,7 +81,7 @@ run_gpu() if [ $# -gt 6 ] || [ $# -lt 4 ] then echo "Usage:\n \ - Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \ + Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [CKPT_PATH]\n \ GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ " exit 1 diff --git a/model_zoo/mobilenetv2/src/launch.py b/model_zoo/mobilenetv2/src/launch.py index 48c8159664..f5c97b0bd7 100644 --- a/model_zoo/mobilenetv2/src/launch.py +++ b/model_zoo/mobilenetv2/src/launch.py @@ -15,7 +15,6 @@ """launch train script""" import os import sys -import json import subprocess import shutil from argparse import ArgumentParser @@ -42,8 +41,6 @@ def parse_args(): "each process can be bound to a single D.") parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the visible devices sequentially") - parser.add_argument("--server_id", type=str, default="", - help="server ip") parser.add_argument("--training_script", type=str, help="The full path to the single D training " "program/script to be launched in parallel, " @@ -63,66 +60,6 @@ def main(): assert os.path.isfile(args.training_script) assert len(visible_devices) >= args.nproc_per_node print('visible_devices:{}'.format(visible_devices)) - if not args.server_id: - print('pleaser input server ip!!!') - exit(0) - print('server_id:{}'.format(args.server_id)) - - # construct hccn_table - hccn_configs = open('/etc/hccn.conf', 'r').readlines() - device_ips = {} - for hccn_item in hccn_configs: - hccn_item = hccn_item.strip() - if hccn_item.startswith('address_'): - device_id, device_ip = hccn_item.split('=') - device_id = device_id.split('_')[1] - device_ips[device_id] = device_ip - print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) - hccn_table = {} - hccn_table['board_id'] = '0x0000' - hccn_table['chip_info'] = '910' - hccn_table['deploy_mode'] = 'lab' - hccn_table['group_count'] = '1' - hccn_table['group_list'] = [] - instance_list = [] - usable_dev = '' - for instance_id in range(args.nproc_per_node): - instance = {} - instance['devices'] = [] - device_id = visible_devices[instance_id] - device_ip = device_ips[device_id] - usable_dev += str(device_id) - instance['devices'].append({ - 'device_id': device_id, - 'device_ip': device_ip, - }) - instance['rank_id'] = str(instance_id) - instance['server_id'] = args.server_id - instance_list.append(instance) - hccn_table['group_list'].append({ - 'device_num': str(args.nproc_per_node), - 'server_num': '1', - 'group_name': '', - 'instance_count': str(args.nproc_per_node), - 'instance_list': instance_list, - }) - hccn_table['para_plane_nic_location'] = 'device' - hccn_table['para_plane_nic_name'] = [] - for instance_id in range(args.nproc_per_node): - eth_id = visible_devices[instance_id] - hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) - hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) - hccn_table['status'] = 'completed' - - # save hccn_table to file - table_path = os.getcwd() - if not os.path.exists(table_path): - os.mkdir(table_path) - table_fn = os.path.join(table_path, - 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) - with open(table_fn, 'w') as table_fp: - json.dump(hccn_table, table_fp, indent=4) - sys.stdout.flush() # spawn the processes processes = [] @@ -137,9 +74,6 @@ def main(): device_dir = os.path.join(cur_path, 'device{}'.format(rank_id)) env['RANK_ID'] = str(rank_id) env['DEVICE_ID'] = str(device_id) - if args.nproc_per_node > 1: - env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn - env['RANK_TABLE_FILE'] = table_fn if os.path.exists(device_dir): shutil.rmtree(device_dir) os.mkdir(device_dir) diff --git a/model_zoo/mobilenetv2/train.py b/model_zoo/mobilenetv2/train.py index 2c211b375a..4ae743f540 100644 --- a/model_zoo/mobilenetv2/train.py +++ b/model_zoo/mobilenetv2/train.py @@ -18,6 +18,7 @@ import time import argparse import random import numpy as np + from mindspore import context from mindspore import Tensor from mindspore import nn @@ -32,8 +33,9 @@ from mindspore.train.model import Model, ParallelMode from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.communication.management import init, get_group_size +from mindspore.communication.management import init, get_group_size, get_rank import mindspore.dataset.engine as de + from src.dataset import create_dataset from src.lr_generator import get_lr from src.config import config_gpu, config_ascend @@ -60,9 +62,14 @@ if args_opt.platform == "Ascend": device_id=device_id, save_graphs=False) elif args_opt.platform == "GPU": context.set_context(mode=context.GRAPH_MODE, - device_target="GPU", save_graphs=False) + device_target="GPU", + save_graphs=False) + init("nccl") + context.set_auto_parallel_context(device_num=get_group_size(), + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) else: - raise ValueError("Unsupport platform.") + raise ValueError("Unsupported device target.") class CrossEntropyWithLabelSmooth(_Loss): @@ -155,12 +162,8 @@ class Monitor(Callback): if __name__ == '__main__': if args_opt.platform == "GPU": # train on gpu - print("train args: ", args_opt, "\ncfg: ", config_gpu) - - init('nccl') - context.set_auto_parallel_context(parallel_mode="data_parallel", - mirror_mean=True, - device_num=get_group_size()) + print("train args: ", args_opt) + print("cfg: ", config_gpu) # define net net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU") @@ -201,13 +204,13 @@ if __name__ == '__main__': loss_scale_manager=loss_scale) cb = [Monitor(lr_init=lr.asnumpy())] + ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" if config_gpu.save_checkpoint: config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, keep_checkpoint_max=config_gpu.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint( - prefix="mobilenetV2", directory=config_gpu.save_checkpoint_path, config=config_ck) + ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) cb += [ckpt_cb] - # begine train + # begin train model.train(epoch_size, dataset, callbacks=cb) elif args_opt.platform == "Ascend": # train on ascend diff --git a/model_zoo/mobilenetv2_quant/export.py b/model_zoo/mobilenetv2_quant/export.py new file mode 100644 index 0000000000..00e377cece --- /dev/null +++ b/model_zoo/mobilenetv2_quant/export.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Export MobilenetV2 on ImageNet""" + +import argparse +import numpy as np + +import mindspore +from mindspore import Tensor +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.quant import quant + +from src.mobilenetV2 import mobilenetV2 +from src.config import config_ascend + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--device_target', type=str, default=None, help='Run device target') +args_opt = parser.parse_args() + +if __name__ == '__main__': + cfg = None + if args_opt.device_target == "Ascend": + cfg = config_ascend + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) + else: + raise ValueError("Unsupported device target: {}.".format(args_opt.device_target)) + + # define fusion network + network = mobilenetV2(num_classes=cfg.num_classes) + # convert fusion network to quantization aware network + network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + # load checkpoint + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(network, param_dict) + + # export network + print("============== Starting export ==============") + inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32) + quant.export(network, inputs, file_name="mobilenet_quant", file_format='GEIR') + print("============== End export ==============") diff --git a/model_zoo/mobilenetv3/train.py b/model_zoo/mobilenetv3/train.py index 578893ab75..57199ec1a7 100644 --- a/model_zoo/mobilenetv3/train.py +++ b/model_zoo/mobilenetv3/train.py @@ -18,6 +18,7 @@ import time import argparse import random import numpy as np + from mindspore import context from mindspore import Tensor from mindspore import nn @@ -33,7 +34,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset.engine as de -from mindspore.communication.management import init, get_group_size +from mindspore.communication.management import init, get_group_size, get_rank + from src.dataset import create_dataset from src.lr_generator import get_lr from src.config import config_gpu, config_ascend @@ -57,10 +59,16 @@ if args_opt.platform == "Ascend": device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", - device_id=device_id, save_graphs=False) + device_id=device_id, + save_graphs=False) elif args_opt.platform == "GPU": context.set_context(mode=context.GRAPH_MODE, - device_target="GPU", save_graphs=False) + device_target="GPU", + save_graphs=False) + init("nccl") + context.set_auto_parallel_context(device_num=get_group_size(), + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) else: raise ValueError("Unsupport platform.") @@ -155,12 +163,8 @@ class Monitor(Callback): if __name__ == '__main__': if args_opt.platform == "GPU": # train on gpu - print("train args: ", args_opt, "\ncfg: ", config_gpu) - - init('nccl') - context.set_auto_parallel_context(parallel_mode="data_parallel", - mirror_mean=True, - device_num=get_group_size()) + print("train args: ", args_opt) + print("cfg: ", config_gpu) # define net net = mobilenet_v3_large(num_classes=config_gpu.num_classes) @@ -201,11 +205,11 @@ if __name__ == '__main__': loss_scale_manager=loss_scale) cb = [Monitor(lr_init=lr.asnumpy())] + ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" if config_gpu.save_checkpoint: config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, keep_checkpoint_max=config_gpu.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint( - prefix="mobilenetV3", directory=config_gpu.save_checkpoint_path, config=config_ck) + ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck) cb += [ckpt_cb] # begine train model.train(epoch_size, dataset, callbacks=cb) diff --git a/model_zoo/utils/hccl_tools/README.md b/model_zoo/utils/hccl_tools/README.md new file mode 100644 index 0000000000..b73a99e592 --- /dev/null +++ b/model_zoo/utils/hccl_tools/README.md @@ -0,0 +1,14 @@ +# description + +mindspore distributed training launch helper utilty that will generate hccl config file. + +# use + +``` +python hccl_tools.py --device_num [1,8] +``` + +output: +``` +hccl_[device_num]p_[which device]_[server_ip].json +``` \ No newline at end of file diff --git a/model_zoo/utils/hccl_tools/hccl_tools.py b/model_zoo/utils/hccl_tools/hccl_tools.py new file mode 100644 index 0000000000..ac4114c0a8 --- /dev/null +++ b/model_zoo/utils/hccl_tools/hccl_tools.py @@ -0,0 +1,165 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""generate hccl config file script""" +import os +import sys +import json +import socket +import platform +from argparse import ArgumentParser +from typing import Dict, Any + + +def parse_args(): + """ + parse args . + + Args: + + Returns: + args. + + Examples: + >>> parse_args() + """ + parser = ArgumentParser(description="mindspore distributed training launch " + "helper utilty that will generate hccl" + " config file") + parser.add_argument("--device_num", type=str, default="[0,8]", + help="The number of the D chip used. please note that the D chips" + "used must be continuous, such [0,4] means to use four chips " + "0,1,2,3; [0,1] means to use chip 0; The first four chips are" + "a group, and the last four chips are a group. In addition to" + "the [0,8] chips are allowed, other cross-group such as [3,6]" + "are prohibited.") + parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", + help="will use the visible devices sequentially") + parser.add_argument("--server_ip", type=str, default="", + help="server ip") + args = parser.parse_args() + return args + + +def get_host_ip(): + """ + get host ip + """ + ip = None + + try: + hostname = socket.gethostname() + ip = socket.gethostbyname(hostname) + except EOFError: + pass + + return ip + + +def main(): + print("start", __file__) + args = parse_args() + + # visible_devices + visible_devices = args.visible_devices.split(',') + print('visible_devices:{}'.format(visible_devices)) + + # server_id + ip = get_host_ip() + if args.server_ip: + server_id = args.server_ip + elif ip: + server_id = ip + else: + raise ValueError("please input server ip!") + print('server_id:{}'.format(server_id)) + + # device_num + first_num = int(args.device_num[1]) + last_num = int(args.device_num[3]) + if first_num < 0 or last_num > 8: + raise ValueError("device num {} must be in range [0,8] !".format(args.device_num)) + if first_num > last_num: + raise ValueError("First num {} of device num {} must less than last num {} !".format(first_num, args.device_num, + last_num)) + if first_num < 4: + if last_num > 4: + if first_num == 0 and last_num == 8: + pass + else: + raise ValueError("device num {} must be in the same group of [0,4] or [4,8] !".format(args.device_num)) + + device_num_list = list(range(first_num, last_num)) + print("device_num_list:", device_num_list) + + assert len(visible_devices) >= len(device_num_list) + + # construct hccn_table + device_ips: Dict[Any, Any] = {} + with open('/etc/hccn.conf', 'r') as fin: + for hccn_item in fin.readlines(): + if hccn_item.strip().startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip.strip() + + arch = platform.processor() + hccn_table = {'board_id': {'aarch64': '0x002f', 'x86_64': '0x0000'}[arch], + 'chip_info': '910', + 'deploy_mode': 'lab', + 'group_count': '1', + 'group_list': []} + instance_list = [] + rank_id = 0 + for instance_id in device_num_list: + instance = {'devices': []} + device_id = visible_devices[instance_id] + device_ip = device_ips[device_id] + instance['devices'].append({ + 'device_id': device_id, + 'device_ip': device_ip, + }) + print('rank_id:{}, device_id:{}, device_ip:{}'.format(rank_id, device_id, device_ip)) + instance['rank_id'] = str(rank_id) + rank_id += 1 + instance['server_id'] = server_id + instance_list.append(instance) + hccn_table['group_list'].append({ + 'device_num': str(len(device_num_list)), + 'server_num': '1', + 'group_name': '', + 'instance_count': str(len(device_num_list)), + 'instance_list': instance_list, + }) + hccn_table['para_plane_nic_location'] = 'device' + hccn_table['para_plane_nic_name'] = [] + for instance_id in device_num_list: + eth_id = visible_devices[instance_id] + hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) + hccn_table['para_plane_nic_num'] = str(len(device_num_list)) + hccn_table['status'] = 'completed' + + # save hccn_table to file + table_path = os.getcwd() + table_fn = os.path.join(table_path, + 'hccl_{}p_{}_{}.json'.format(len(device_num_list), "".join(map(str, device_num_list)), + server_id)) + with open(table_fn, 'w') as table_fp: + json.dump(hccn_table, table_fp, indent=4) + sys.stdout.flush() + print("Completed: hccl file was save in :", table_fn) + + +if __name__ == "__main__": + main() diff --git a/model_zoo/wide_and_deep/src/wide_and_deep.py b/model_zoo/wide_and_deep/src/wide_and_deep.py index 16102039a8..048bf3c66d 100644 --- a/model_zoo/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/wide_and_deep/src/wide_and_deep.py @@ -188,7 +188,7 @@ class WideDeepModel(nn.Cell): self.deep_layer_act, use_activation=False, convert_dtype=True, drop_out=config.dropout_flag) - self.gather_v2 = P.GatherV2() + self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE') self.mul = P.Mul() self.reduce_sum = P.ReduceSum(keep_dims=False) self.reshape = P.Reshape() @@ -206,11 +206,11 @@ class WideDeepModel(nn.Cell): """ mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) # Wide layer - wide_id_weight = self.gather_v2(self.wide_w, id_hldr, 0) + wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr) wx = self.mul(wide_id_weight, mask) wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) # Deep layer - deep_id_embs = self.gather_v2(self.embedding_table, id_hldr, 0) + deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr) vx = self.mul(deep_id_embs, mask) deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim)) deep_in = self.dense_layer_1(deep_in) diff --git a/scripts/build_icu4c.sh b/scripts/build_icu4c.sh new file mode 100755 index 0000000000..c7f21b756f --- /dev/null +++ b/scripts/build_icu4c.sh @@ -0,0 +1,8 @@ +#!/bin/bash +echo '{ + "strategy": "additive", + "featureFilters": { + "normalization": "include" + } +}' > filter.json +./icu4c/source/runConfigureICU Linux --enable-rpath --disable-tests --disable-samples --disable-icuio --disable-extras ICU_DATA_FILTER_FILE=filter.json "$@" diff --git a/serving/CMakeLists.txt b/serving/CMakeLists.txt index 3c1c08ece0..4529323fe1 100644 --- a/serving/CMakeLists.txt +++ b/serving/CMakeLists.txt @@ -13,7 +13,6 @@ add_library(protobuf::libprotobuf ALIAS protobuf::protobuf) add_executable(protobuf::libprotoc ALIAS protobuf::protoc) set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf) -set(_REFLECTION gRPC::grpc++_reflection) if(CMAKE_CROSSCOMPILING) find_program(_PROTOBUF_PROTOC protoc) else() @@ -22,10 +21,19 @@ endif() # Find gRPC installation # Looks for gRPCConfig.cmake file installed by gRPC's cmake installation. +if (EXISTS ${grpc_ROOT}/lib64) + set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc") +else() + set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc") +endif() +message("serving using grpc_DIR : " ${gPRC_DIR}) + find_package(gRPC CONFIG REQUIRED) message(STATUS "Using gRPC ${gRPC_VERSION}") set(_GRPC_GRPCPP gRPC::grpc++) +set(_REFLECTION gRPC::grpc++_reflection) + if(CMAKE_CROSSCOMPILING) find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) else() diff --git a/setup.py b/setup.py index 2840eb3b14..bf16c9106b 100644 --- a/setup.py +++ b/setup.py @@ -103,6 +103,7 @@ package_data = { 'lib/*.so*', 'lib/*.a', '.commit_id', + 'ms_serving' ] } @@ -125,6 +126,8 @@ def update_permissions(path): for filename in filenames: file_fullpath = os.path.join(dirpath, filename) os.chmod(file_fullpath, stat.S_IREAD) + if filename == "ms_serving": + os.chmod(file_fullpath, stat.S_IREAD | stat.S_IEXEC) class EggInfo(egg_info): diff --git a/tests/st/control/test_switch_layer.py b/tests/st/control/test_switch_layer.py new file mode 100644 index 0000000000..4accb44f1a --- /dev/null +++ b/tests/st/control/test_switch_layer.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore import Tensor, nn +from mindspore.common import dtype as mstype + + +class CaseNet(nn.Cell): + def __init__(self): + super(CaseNet, self).__init__() + self.conv = nn.Conv2d(1, 3, 3) + self.relu = nn.ReLU() + self.softmax = nn.Softmax() + self.layers1 = (self.relu, self.softmax) + self.layers2 = (self.conv, self.relu) + + def construct(self, x, index1, index2): + x = self.layers1[index1](x) + x = self.layers2[index2](x) + return x + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_switch_layer(): + context.set_context(mode=context.GRAPH_MODE) + net = CaseNet() + data = Tensor(np.ones((1, 1, 224, 224)), mstype.float32) + idx = Tensor(0, mstype.int32) + idx2 = Tensor(-1, mstype.int32) + value = net(data, idx, idx2) + relu = nn.ReLU() + true_value = relu(data) + ret = np.allclose(value.asnumpy(), true_value.asnumpy()) + assert ret + + idx3 = Tensor(3, mstype.int32) + with pytest.raises(RuntimeError): + value = net(data, idx3, idx2) diff --git a/mindspore/model_zoo/resnet.py b/tests/st/networks/models/resnet50/src/resnet.py similarity index 100% rename from mindspore/model_zoo/resnet.py rename to tests/st/networks/models/resnet50/src/resnet.py diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index c88af6bcf7..e721b62c58 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -27,10 +27,10 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.model import Model, ParallelMode from mindspore.train.callback import Callback from mindspore.train.loss_scale_manager import FixedLossScaleManager -from mindspore.model_zoo.resnet import resnet50 import mindspore.nn as nn import mindspore.dataset as ds +from tests.st.networks.models.resnet50.src.resnet import resnet50 from tests.st.networks.models.resnet50.src.dataset import create_dataset from tests.st.networks.models.resnet50.src.lr_generator import get_learning_rate from tests.st.networks.models.resnet50.src.config import config diff --git a/tests/st/ops/ascend/test_autocast.py b/tests/st/ops/ascend/test_autocast.py index 448dc9b4d6..35690ce2c4 100644 --- a/tests/st/ops/ascend/test_autocast.py +++ b/tests/st/ops/ascend/test_autocast.py @@ -246,3 +246,21 @@ def test_tensor_auto_cast(): bnet(t_fp32) with pytest.raises(TypeError): bnet(t_fp64) +def test_bool_tensor_and_float(): + context.set_context(mode=context.GRAPH_MODE) + t_bool = Tensor(np.ones([2, 1, 2, 2]).astype(np.bool), mstype.bool_) + t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32) + t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16) + t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32) + net = TensorFPAutoCast() + out = net(t_bool) + assert out.dtype == mstype.float32 + net = TensorIntAutoCast() + out = net(t_bool) + assert out.dtype == mstype.int32 + out = net(t_fp16) + assert out.dtype == mstype.float16 + out = net(t_fp32) + assert out.dtype == mstype.float32 + out = net(t_int32) + assert out.dtype == mstype.int32 diff --git a/tests/st/ops/ascend/test_distribution/test_bernoulli.py b/tests/st/ops/ascend/test_distribution/test_bernoulli.py new file mode 100644 index 0000000000..5652d536c7 --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_bernoulli.py @@ -0,0 +1,147 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for bernoulli distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(nn.Cell): + """ + Test class: probability of bernoulli distribution. + """ + def __init__(self): + super(Net, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('prob', x_) + +class Net1(nn.Cell): + """ + Test class: log probability of bernoulli distribution. + """ + def __init__(self): + super(Net1, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('log_prob', x_) + +class Net2(nn.Cell): + """ + Test class: kl_loss between bernoulli distributions. + """ + def __init__(self): + super(Net2, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + @ms_function + def construct(self, x_): + return self.b('kl_loss', 'Bernoulli', x_) + +class Net3(nn.Cell): + """ + Test class: mean/sd of bernoulli distribution. + """ + def __init__(self): + super(Net3, self).__init__() + self.b = nn.Bernoulli([0.5, 0.5], dtype=dtype.int32) + + @ms_function + def construct(self): + return self.b('mean'), self.b('sd') + +class Net4(nn.Cell): + """ + Test class: log probability of bernoulli distribution. + """ + def __init__(self, shape, seed=0): + super(Net4, self).__init__() + self.b = nn.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32) + self.shape = shape + + @ms_function + def construct(self, probs=None): + return self.b('sample', self.shape, probs) + +def test_pmf(): + """ + Test pmf. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32) + pdf = Net() + x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) + output = pdf(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pmf) < tol).all() + +def test_log_likelihood(): + """ + Test log_pmf. + """ + bernoulli_benchmark = stats.bernoulli(0.7) + expect_logpmf = bernoulli_benchmark.logpmf([0, 1, 0, 1, 1]).astype(np.float32) + logprob = Net1() + x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32) + output = logprob(x_) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all() + +def test_kl_loss(): + """ + Test kl_loss. + """ + probs1_a = 0.7 + probs1_b = 0.5 + probs0_a = 1 - probs1_a + probs0_b = 1 - probs1_b + expect_kl_loss = probs1_a * np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b) + kl_loss = Net2() + output = kl_loss(Tensor([probs1_b], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +def test_basics(): + """ + Test mean/standard deviation and probs. + """ + basics = Net3() + mean, sd = basics() + expect_mean = [0.5, 0.5] + assert (mean.asnumpy() == expect_mean).all() + assert (sd.asnumpy() == expect_mean).all() + b = nn.Bernoulli([0.7, 0.5], dtype=dtype.int32) + probs = b.probs() + expect_probs = [0.7, 0.5] + tol = 1e-6 + assert (np.abs(probs.asnumpy() - expect_probs) < tol).all() + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + sample = Net4(shape) + output = sample() + assert output.shape == (2, 3, 2) diff --git a/tests/st/ops/ascend/test_distribution/test_normal.py b/tests/st/ops/ascend/test_distribution/test_normal.py new file mode 100644 index 0000000000..52bb1173ee --- /dev/null +++ b/tests/st/ops/ascend/test_distribution/test_normal.py @@ -0,0 +1,152 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cases for normal distribution""" +import numpy as np +from scipy import stats +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import ms_function +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(nn.Cell): + """ + Test class: probability of normal distribution. + """ + def __init__(self): + super(Net, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('prob', x_) + +class Net1(nn.Cell): + """ + Test class: log probability of normal distribution. + """ + def __init__(self): + super(Net1, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32) + + @ms_function + def construct(self, x_): + return self.n('log_prob', x_) + +class Net2(nn.Cell): + """ + Test class: kl_loss of normal distribution. + """ + def __init__(self): + super(Net2, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + @ms_function + def construct(self, x_, y_): + return self.n('kl_loss', 'Normal', x_, y_) + +class Net3(nn.Cell): + """ + Test class: mean/sd of normal distribution. + """ + def __init__(self): + super(Net3, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32) + + @ms_function + def construct(self): + return self.n('mean'), self.n('sd') + +class Net4(nn.Cell): + """ + Test class: mean/sd of normal distribution. + """ + def __init__(self, shape, seed=0): + super(Net4, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32) + self.shape = shape + + @ms_function + def construct(self, mean=None, sd=None): + return self.n('sample', self.shape, mean, sd) + +def test_pdf(): + """ + Test pdf. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_pdf = norm_benchmark.pdf([1.0, 2.0]).astype(np.float32) + pdf = Net() + output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_pdf) < tol).all() + +def test_log_likelihood(): + """ + Test log_pdf. + """ + norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]])) + expect_logpdf = norm_benchmark.logpdf([1.0, 2.0]).astype(np.float32) + logprob = Net1() + output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32)) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all() + +def test_kl_loss(): + """ + Test kl_loss. + """ + mean_a = np.array([3.0]).astype(np.float32) + sd_a = np.array([4.0]).astype(np.float32) + + mean_b = np.array([1.0]).astype(np.float32) + sd_b = np.array([1.0]).astype(np.float32) + + diff_log_scale = np.log(sd_a) - np.log(sd_b) + squared_diff = np.square(mean_a / sd_b - mean_b / sd_b) + expect_kl_loss = 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale + + kl_loss = Net2() + mean = Tensor(mean_b, dtype=dtype.float32) + sd = Tensor(sd_b, dtype=dtype.float32) + output = kl_loss(mean, sd) + tol = 1e-6 + assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all() + +def test_basics(): + """ + Test mean/standard deviation. + """ + basics = Net3() + mean, sd = basics() + expect_mean = [3.0, 3.0] + expect_sd = [2.0, 4.0] + tol = 1e-6 + assert (np.abs(mean.asnumpy() - expect_mean) < tol).all() + assert (np.abs(sd.asnumpy() - expect_sd) < tol).all() + +def test_sample(): + """ + Test sample. + """ + shape = (2, 3) + seed = 10 + mean = Tensor([2.0], dtype=dtype.float32) + sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32) + sample = Net4(shape, seed=seed) + output = sample(mean, sd) + assert output.shape == (2, 3, 3) diff --git a/tests/st/ops/gpu/test_ctcloss_op.py b/tests/st/ops/gpu/test_ctcloss_op.py new file mode 100644 index 0000000000..b9a88e7e70 --- /dev/null +++ b/tests/st/ops/gpu/test_ctcloss_op.py @@ -0,0 +1,119 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from mindspore.ops.composite import GradOperation + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.loss = P.CTCLossV2() + self.div = P.RealDiv() + self.cast = P.Cast() + self.mean = P.ReduceMean() + + def construct(self, probs, label, input_length, label_length): + x, _ = self.loss(probs, label, input_length, label_length) + x = self.div(x, self.cast(label_length, mstype.float32)) + x = self.mean(x) + return x + +class GradData(nn.Cell): + def __init__(self, network): + super(GradData, self).__init__() + self.grad = GradOperation(name="get_all", get_all=True, sens_param=False) + self.network = network + + def construct(self, probs, labels, input_lengths, label_lengths): + return self.grad(self.network)(probs, labels, input_lengths, label_lengths) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ctcloss(): + probs = Tensor([[[-4.4131, -4.6093, -3.4333, -3.9268, -2.8917, -3.4093, -4.2243, -1.1379, -7.1046, -0.6902], + [-2.5109, -3.3397, -4.9384, -1.2723, -1.1443, -2.4683, -2.6768, -4.1282, -2.7062, -3.1906], + [-2.5092, -1.6392, -2.0864, -4.0059, -1.5610, -2.3223, -2.4816, -2.9922, -3.1412, -2.3311]], + + [[-2.1243, -3.5773, -3.1108, -4.4253, -2.7080, -1.9653, -2.0499, -2.4418, -1.8620, -1.5229], + [-2.2479, -3.5128, -1.4189, -2.8701, -1.8562, -2.2752, -2.7019, -2.1865, -2.5634, -2.9869], + [-3.2144, -1.3986, -3.1083, -3.9634, -3.5131, -3.2317, -2.6200, -1.7938, -1.8159, -1.7255]], + + [[-3.1301, -2.1649, -0.9286, -2.9452, -2.5992, -2.0263, -2.9201, -3.2155, -2.8302, -3.3636], + [-1.4661, -3.6311, -2.4781, -4.6180, -2.7308, -1.7019, -1.5570, -2.6012, -4.0788, -2.3073], + [-2.6833, -1.5033, -3.6922, -2.6360, -2.6974, -2.6847, -2.7579, -2.1396, -1.4093, -2.9630]], + + [[-2.0094, -2.3024, -3.3673, -1.0220, -2.8326, -2.2613, -3.0535, -2.9879, -3.7015, -2.4510], + [-1.9071, -3.2603, -2.3229, -2.0572, -4.3450, -2.1284, -2.6306, -1.3824, -2.9815, -2.5061], + [-2.7931, -3.7631, -3.2440, -4.3887, -1.0271, -3.8851, -1.2418, -4.5123, -2.2993, -2.4607]], + + [[-1.5763, -2.7539, -3.6941, -3.8166, -1.2599, -2.6903, -2.5826, -4.8208, -2.9562, -1.6321], + [-3.3031, -3.0087, -1.9982, -1.9081, -3.8731, -2.8764, -2.2485, -2.3808, -1.4283, -2.1625], + [-2.4516, -3.2394, -4.2053, -4.3541, -2.5229, -4.0717, -1.4894, -2.3151, -1.1098, -2.3465]]], + dtype=mstype.float32) + labels = Tensor([9, 4, 6, 4, 7, 1, 4, 6, 6, 8], dtype=mstype.int32) + input_lengths = Tensor([5, 5, 5], dtype=mstype.int32) + label_lengths = Tensor([3, 3, 4], dtype=mstype.int32) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + net = Net() + ctc_loss = net(probs, labels, input_lengths, label_lengths) + expect_loss = [2.4099] + assert np.allclose(ctc_loss.asnumpy(), expect_loss) + + grad = GradData(net)(probs, labels, input_lengths, label_lengths) + expect_grad = [[[8.8442e-05, 1.1065e-03, 3.5867e-03, 2.1896e-03, 6.1646e-03, + 3.6738e-03, 1.6262e-03, 3.5610e-02, 9.1258e-05, -5.4134e-02], + [-3.7523e-03, 3.9386e-03, 7.9623e-04, 3.1132e-02, -6.2954e-02, + 9.4143e-03, 7.6425e-03, 1.7902e-03, 7.4211e-03, 4.5719e-03], + [6.7778e-03, 1.6178e-02, 1.0344e-02, 1.5173e-03, -6.5840e-02, + 8.1707e-03, 6.9674e-03, 4.1814e-03, 3.6026e-03, 8.0991e-03]], + + [[-1.2581e-02, 3.1057e-03, 4.9517e-03, 1.3301e-03, -2.6320e-02, + 1.5568e-02, 1.4305e-02, 9.6671e-03, 1.7262e-02, -2.7292e-02], + [-1.5566e-02, 3.3126e-03, 2.6887e-02, 6.2993e-03, -3.9716e-02, + 1.1420e-02, 7.4531e-03, -1.4252e-02, 8.5603e-03, 5.6048e-03], + [3.3483e-03, 2.0579e-02, 3.7231e-03, 1.5832e-03, 2.4837e-03, + 3.2909e-03, -7.7267e-02, 1.3861e-02, 1.3558e-02, 1.4840e-02]], + + [[-8.0007e-03, 1.2751e-02, 4.3901e-02, 5.8435e-03, -7.2627e-02, + 1.4647e-02, -8.0584e-03, 4.4595e-03, 6.5557e-03, 5.2891e-04], + [-3.6006e-02, 1.5308e-03, 9.3225e-03, 1.0969e-03, -2.5098e-03, + 2.0260e-02, 2.3419e-02, -3.0053e-02, 1.8809e-03, 1.1059e-02], + [-7.7639e-02, 1.8533e-02, 2.0764e-03, 5.9706e-03, 5.6150e-03, + 5.6868e-03, 5.2854e-03, 9.8085e-03, 2.0360e-02, 4.3053e-03]], + + [[-2.6776e-02, 1.1113e-02, 3.8314e-03, 3.9986e-02, -1.6020e-02, + 1.1579e-02, -4.1635e-02, 5.5992e-03, 2.7429e-03, 9.5786e-03], + [-6.8619e-03, -6.4066e-03, 1.0888e-02, 1.4201e-02, 1.4413e-03, + 1.3225e-02, 8.0039e-03, -4.9191e-02, 5.6352e-03, 9.0651e-03], + [5.1026e-03, 1.9343e-03, 3.2506e-03, 1.0347e-03, 2.9837e-02, + 1.7121e-03, -5.9261e-02, 9.1443e-04, 8.3608e-03, 7.1146e-03]], + + [[-2.0848e-02, 7.0754e-03, 2.7633e-03, 2.4447e-03, 3.1520e-02, + 7.5401e-03, -5.8895e-02, 8.9559e-04, 5.7796e-03, 2.1724e-02], + [-1.3499e-03, -1.0019e-01, 1.5064e-02, 1.6485e-02, 2.3104e-03, + 6.2597e-03, 1.1729e-02, 1.0275e-02, 2.6635e-02, 1.2782e-02], + [7.1796e-03, 3.2656e-03, 1.2430e-03, 1.0712e-03, 6.6856e-03, + 1.4207e-03, 1.8792e-02, 8.2297e-03, -5.5865e-02, 7.9753e-03]]] + assert np.allclose(grad[0].asnumpy(), expect_grad, atol=1e-5) diff --git a/tests/st/ops/gpu/test_dense_op.py b/tests/st/ops/gpu/test_dense_op.py index 220f7ae051..e9c010ea77 100644 --- a/tests/st/ops/gpu/test_dense_op.py +++ b/tests/st/ops/gpu/test_dense_op.py @@ -228,6 +228,7 @@ def test_biasadd_3d(): error = np.ones(shape=[3, 4, 8]) * 1.0e-6 context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") net = BiasAdd() + net.set_grad() result = net(x, b) diff = result.asnumpy() - expect assert np.all(diff < error) diff --git a/tests/st/ops/gpu/test_normal.py b/tests/st/ops/gpu/test_normal.py new file mode 100644 index 0000000000..0c4866f6f0 --- /dev/null +++ b/tests/st/ops/gpu/test_normal.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class Net(nn.Cell): + def __init__(self, shape, seed=0): + super(Net, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, mean, stddev): + return C.normal(self.shape, mean, stddev, self.seed) + + +def test_net_1D(): + seed = 10 + shape = (3, 2, 4) + mean = 1.0 + stddev = 1.0 + net = Net(shape, seed) + tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) + output = net(tmean, tstddev) + assert output.shape == (3, 2, 4) + + +def test_net_ND(): + seed = 10 + shape = (3, 1, 2) + mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32) + stddev = np.array([1.0]).astype(np.float32) + net = Net(shape, seed) + tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32) + output = net(tmean, tstddev) + assert output.shape == (3, 2, 2) diff --git a/tests/st/ops/gpu/test_smoothl1loss_op.py b/tests/st/ops/gpu/test_smoothl1loss_op.py new file mode 100644 index 0000000000..040f404eb0 --- /dev/null +++ b/tests/st/ops/gpu/test_smoothl1loss_op.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_smoothl1loss(): + np.random.seed(42) + prediction = np.random.randn(20).astype(np.float32) + target = np.random.randn(20).astype(np.float32) + sigma = 1.0 + + net = nn.SmoothL1Loss(sigma) + loss = net(Tensor(prediction), Tensor(target)) + expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304, + 2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008, + 0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174, + 0.08826803, 1.109165] + assert np.allclose(loss.asnumpy(), expect) + + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) + self.network = network + + def construct(self, x1, x2, sens): + gout = self.grad(self.network)(x1, x2, sens) + return gout + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_smoothl1loss_grad(): + np.random.seed(42) + prediction = np.random.randn(20).astype(np.float32) + target = np.random.randn(20).astype(np.float32) + sens = np.random.randn(20).astype(np.float32) + sigma = 1.0 + + net = nn.SmoothL1Loss(sigma) + grad = Grad(net) + dx = grad(Tensor(prediction), Tensor(target), Tensor(sens)) + + dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093, + 0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229, + 0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995, + 0.61330026, 0.83921754, -0.3092124, 0.1391843, -0.9755451] + + dx2_expect = [0.71552587, -0.01499678, 0.06709455, 0.30110368, 0.45868093, + -0.24838912, 0.46063876, -0.41411355, -0.04507046, 1.4708229, + -0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995, + -0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451] + + assert np.allclose(dx[0].asnumpy(), dx1_expect) + assert np.allclose(dx[1].asnumpy(), dx2_expect) diff --git a/tests/st/pynative/test_ascend_lenet.py b/tests/st/pynative/test_ascend_lenet.py deleted file mode 100644 index 021c71d9cd..0000000000 --- a/tests/st/pynative/test_ascend_lenet.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import time -import numpy as np -import pytest - -import mindspore.nn as nn -from mindspore import context, Tensor, ParameterTuple -from mindspore.common import dtype as mstype -from mindspore.common.initializer import TruncatedNormal -from mindspore.nn.optim import Momentum -from mindspore.nn.wrap.cell_wrapper import WithLossCell -from mindspore.ops import composite as C -from mindspore.ops import functional as F -from mindspore.ops import operations as P - -np.random.seed(1) - - -def weight_variable(): - """weight initial""" - return TruncatedNormal(0.02) - - -def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): - """weight initial for conv layer""" - weight = weight_variable() - return nn.Conv2d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, padding=padding, - weight_init=weight, has_bias=False, pad_mode="valid") - - -def fc_with_initialize(input_channels, out_channels): - """weight initial for fc layer""" - weight = weight_variable() - bias = weight_variable() - return nn.Dense(input_channels, out_channels, weight, bias) - - -class LeNet(nn.Cell): - """ - Lenet network - Args: - num_class (int): Num classes, Default: 10. - Returns: - Tensor, output tensor - Examples: - >>> LeNet(num_class=10) - """ - - def __init__(self, num_class=10): - super(LeNet, self).__init__() - self.num_class = num_class - self.batch_size = 32 - self.conv1 = conv(1, 6, 5) - self.conv2 = conv(6, 16, 5) - self.fc1 = fc_with_initialize(16 * 5 * 5, 120) - self.fc2 = fc_with_initialize(120, 84) - self.fc3 = fc_with_initialize(84, self.num_class) - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() - - def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.reshape(x, (self.batch_size, -1)) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) - x = self.fc3(x) - return x - - -class CrossEntropyLoss(nn.Cell): - """ - Define loss for network - """ - - def __init__(self): - super(CrossEntropyLoss, self).__init__() - self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean() - self.one_hot = P.OneHot() - self.on_value = Tensor(1.0, mstype.float32) - self.off_value = Tensor(0.0, mstype.float32) - self.num = Tensor(32.0, mstype.float32) - - def construct(self, logits, label): - label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value) - loss = self.cross_entropy(logits, label)[0] - loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num) - return loss - - -class GradWrap(nn.Cell): - """ - GradWrap definition - """ - - def __init__(self, network): - super(GradWrap, self).__init__() - self.network = network - self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) - - def construct(self, x, label): - weights = self.weights - return C.grad_by_list(self.network, weights)(x, label) - - -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard -def test_ascend_pynative_lenet(): - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") - - epoch_size = 20 - batch_size = 32 - inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32)) - labels = Tensor(np.ones([batch_size]).astype(np.int32)) - - net = LeNet() - criterion = CrossEntropyLoss() - optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9) - - net_with_criterion = WithLossCell(net, criterion) - train_network = GradWrap(net_with_criterion) - train_network.set_train() - total_time = 0 - - for epoch in range(0, epoch_size): - start_time = time.time() - fw_output = net(inputs) - loss_output = criterion(fw_output, labels) - grads = train_network(inputs, labels) - optimizer(grads) - end_time = time.time() - cost_time = end_time - start_time - total_time = total_time + cost_time - - print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) - assert loss_output.asnumpy() < 0.1 diff --git a/tests/st/pynative/test_implicit_conversion.py b/tests/st/pynative/test_implicit_conversion.py new file mode 100644 index 0000000000..fce6c24cbb --- /dev/null +++ b/tests/st/pynative/test_implicit_conversion.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test implicit conversion """ +import numpy as np + +from mindspore import Tensor + + +def test_float_tensor_and_int_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = 2 + ret_actual = x + y + ret_expect = Tensor(np.array([[2.1, 2.2, 2.3], [2.4, 2.5, 2.6]], dtype=np.float32)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_bool_tensor_and_float_add(): + x = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) + y = 3.3 + ret_actual = x + y + ret_expect = Tensor(np.array([[4.3, 3.3], [3.3, 4.3]], dtype=np.float32)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_bool_tensor_and_int_add(): + x = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) + y = 3 + ret_actual = x + y + ret_expect = Tensor(np.array([[4, 3], [3, 4]], dtype=np.int32)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_bool_and_int_tensor_add(): + x = True + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) + ret_actual = x + y + ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + +def test_float_tensor_and_int_tensor_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) + ret_actual = x + y + ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_float_tensor_and_float_tensor_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64)) + y = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)) + ret_actual = x + y + ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_int_tensor_and_int_tensor_add(): + x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int16)) + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) + ret_actual = x + y + ret_expect = Tensor(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.int32)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_float_tensor_and_bool_tensors_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = Tensor(np.array([[True, True, True], [False, False, False]], dtype=np.bool_)) + ret_actual = x + y + ret_expect = Tensor(np.array([[1.1, 1.2, 1.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() diff --git a/tests/st/pynative/test_pynative_hook.py b/tests/st/pynative/test_pynative_hook.py new file mode 100644 index 0000000000..0ce4ba4f69 --- /dev/null +++ b/tests/st/pynative/test_pynative_hook.py @@ -0,0 +1,198 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype + +from mindspore import Tensor +from mindspore import context +from mindspore import ParameterTuple +from mindspore.nn import Momentum +from mindspore.nn import WithLossCell +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.common.initializer import TruncatedNormal + +context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +class test_custom_hook_function_base(): + def __init__(self): + pass + + def test_custom_hook_function(self, hook_function, cell_hook_function): + return hook_function, cell_hook_function + + +def cell_hook_function_print_grad(cell_id, grad_input, grad_output): + assert grad_output[0].asnumpy().shape == (32, 6, 14, 14) + assert grad_input[0].asnumpy().shape == (32, 16, 10, 10) + + +def custom_hook_function_print_and_save_grad(grad_out): + assert grad_out[0].asnumpy().shape == (32, 6, 28, 28) + + +class LeNet5(nn.Cell): + def __init__(self, hook_function, cell_hook_function, num_class=10): + super(LeNet5, self).__init__() + self.num_class = num_class + self.batch_size = 32 + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.conv1.register_backward_hook(cell_hook_function) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.reshape = P.Reshape() + self.hook = P.HookBackward(hook_function) + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.hook(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.reshape(x, (self.batch_size, -1)) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +class GradWrap(nn.Cell): + """ GradWrap definition """ + def __init__(self, network): + super(GradWrap, self).__init__(auto_prefix=False) + self.network = network + self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) + + def construct(self, x, label): + weights = self.weights + return C.GradOperation('get_by_list', get_by_list=True)(self.network, weights)(x, label) + + +class test_custom_cell_base(): + def __init__(self): + pass + + def test_custom_cell_function(self, cell): + return cell + + +class MulAdd(nn.Cell): + def __init__(self): + super(MulAdd, self).__init__() + + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + assert x.asnumpy() == 1.0 + assert y.asnumpy() == 2.0 + assert out.asnumpy() == 4.0 + assert dout.asnumpy() == 1.0 + return dout, y + + +class Ms_Cell(nn.Cell): + def __init__(self): + super(Ms_Cell, self).__init__() + self.relu = P.ReLU() + + def construct(self, x): + return self.relu(x) + + def bprop(self, x, out, dout): + dout = Tensor(np.ones([5, 5]).astype(np.float32)) + assert dout.shape == (5, 5) + return dout + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_pynative_lenet_train_hook_function_print_and_save_grad(): + hook = test_custom_hook_function_base() + function = hook.test_custom_hook_function(custom_hook_function_print_and_save_grad, + cell_hook_function_print_grad) + net = LeNet5(hook_function=function[0], cell_hook_function=function[1]) + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9) + criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=False) + net_with_criterion = WithLossCell(net, criterion) + train_network = GradWrap(net_with_criterion) + train_network.set_train() + + input_data = Tensor(np.ones([net.batch_size, 1, 32, 32]).astype(np.float32) * 0.01) + label = Tensor(np.ones([net.batch_size, net.num_class]).astype(np.float32)) + output = net(Tensor(input_data)) + criterion(output, label) + grads = train_network(input_data, label) + success = optimizer(grads) + assert success + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_pynative_custom_bprop_and_Cell_MulAdd(): + custom_cell = test_custom_cell_base() + mul_add = custom_cell.test_custom_cell_function(MulAdd()) + mul_add.bprop_debug = True + C.grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) + assert C.grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) == \ + (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32)) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_pynative_custom_bprop_and_Cell_Ms_Cell(): + custom_cell = test_custom_cell_base() + ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell()) + ms_Cell.bprop_debug = True + assert C.grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(1.0, mstype.float32),) + \ No newline at end of file diff --git a/tests/st/pynative/test_pynative_lenet.py b/tests/st/pynative/test_pynative_lenet.py new file mode 100644 index 0000000000..c6166d0517 --- /dev/null +++ b/tests/st/pynative/test_pynative_lenet.py @@ -0,0 +1,161 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import time +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import context, Tensor, ParameterTuple +from mindspore.common import dtype as mstype +from mindspore.common.initializer import TruncatedNormal +from mindspore.nn.optim import Momentum +from mindspore.nn.wrap.cell_wrapper import WithLossCell +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P + +np.random.seed(1) + + +def weight_variable(): + """weight initial""" + return TruncatedNormal(0.02) + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """weight initial for conv layer""" + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + """weight initial for fc layer""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +class LeNet(nn.Cell): + """ + Lenet network + Args: + num_class (int): Num classes, Default: 10. + Returns: + Tensor, output tensor + Examples: + >>> LeNet(num_class=10) + """ + + def __init__(self, num_class=10): + super(LeNet, self).__init__() + self.num_class = num_class + self.batch_size = 32 + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, self.num_class) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.reshape = P.Reshape() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.reshape(x, (self.batch_size, -1)) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +class CrossEntropyLoss(nn.Cell): + """ + Define loss for network + """ + + def __init__(self): + super(CrossEntropyLoss, self).__init__() + self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean() + self.one_hot = P.OneHot() + self.on_value = Tensor(1.0, mstype.float32) + self.off_value = Tensor(0.0, mstype.float32) + self.num = Tensor(32.0, mstype.float32) + + def construct(self, logits, label): + label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value) + loss = self.cross_entropy(logits, label)[0] + loss = P.RealDiv()(P.ReduceSum()(loss, -1), self.num) + return loss + + +class GradWrap(nn.Cell): + """ + GradWrap definition + """ + + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters())) + + def construct(self, x, label): + weights = self.weights + return C.grad_by_list(self.network, weights)(x, label) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ascend_pynative_lenet(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + + epoch_size = 20 + batch_size = 32 + inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32)) + labels = Tensor(np.ones([batch_size]).astype(np.int32)) + + net = LeNet() + criterion = CrossEntropyLoss() + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9) + + net_with_criterion = WithLossCell(net, criterion) + train_network = GradWrap(net_with_criterion) + train_network.set_train() + total_time = 0 + + for epoch in range(0, epoch_size): + start_time = time.time() + fw_output = net(inputs) + loss_output = criterion(fw_output, labels) + grads = train_network(inputs, labels) + optimizer(grads) + end_time = time.time() + cost_time = end_time - start_time + total_time = total_time + cost_time + + print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) + assert loss_output.asnumpy() < 0.004 + assert loss_output.asnumpy() > 0.003 diff --git a/tests/st/pynative/test_pynative_resnet50.py b/tests/st/pynative/test_pynative_resnet50.py new file mode 100644 index 0000000000..de9ecebb9c --- /dev/null +++ b/tests/st/pynative/test_pynative_resnet50.py @@ -0,0 +1,432 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import time +import random +import numpy as np +import pytest + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.transforms.vision.c_transforms as vision +import mindspore.nn as nn +import mindspore.ops.functional as F + +from mindspore import Tensor +from mindspore import context +from mindspore import ParameterTuple +from mindspore.nn import Cell +from mindspore.ops import operations as P +from mindspore.ops import composite as CP +from mindspore.nn.optim.momentum import Momentum +from mindspore.common.initializer import initializer +from mindspore.nn.wrap.cell_wrapper import WithLossCell + +random.seed(1) +np.random.seed(1) +ds.config.set_seed(1) + + +def weight_variable(shape): + return initializer('XavierUniform', shape=shape, dtype=mstype.float32) + + +def weight_variable_uniform(shape): + return initializer('Uniform', shape=shape, dtype=mstype.float32) + + +def weight_variable_0(shape): + zeros = np.zeros(shape).astype(np.float32) + return Tensor(zeros) + + +def weight_variable_1(shape): + ones = np.ones(shape).astype(np.float32) + return Tensor(ones) + + +def conv3x3(in_channels, out_channels, stride=1, padding=0): + """3x3 convolution """ + weight_shape = (out_channels, in_channels, 3, 3) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def conv1x1(in_channels, out_channels, stride=1, padding=0): + """1x1 convolution""" + weight_shape = (out_channels, in_channels, 1, 1) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def conv7x7(in_channels, out_channels, stride=1, padding=0): + """1x1 convolution""" + weight_shape = (out_channels, in_channels, 7, 7) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def bn_with_initialize(out_channels): + shape = (out_channels) + mean = weight_variable_0(shape) + var = weight_variable_1(shape) + beta = weight_variable_0(shape) + gamma = weight_variable_uniform(shape) + bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma, + beta_init=beta, moving_mean_init=mean, moving_var_init=var) + return bn + + +def bn_with_initialize_last(out_channels): + shape = (out_channels) + mean = weight_variable_0(shape) + var = weight_variable_1(shape) + beta = weight_variable_0(shape) + gamma = weight_variable_uniform(shape) + bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma, + beta_init=beta, moving_mean_init=mean, moving_var_init=var) + return bn + + +def fc_with_initialize(input_channels, out_channels): + weight_shape = (out_channels, input_channels) + weight = weight_variable(weight_shape) + bias_shape = (out_channels) + bias = weight_variable_uniform(bias_shape) + return nn.Dense(input_channels, out_channels, weight, bias) + + +class ResidualBlock(nn.Cell): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + stride=1): + super(ResidualBlock, self).__init__() + + out_chls = out_channels // self.expansion + self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0) + self.bn1 = bn_with_initialize(out_chls) + + self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0) + self.bn2 = bn_with_initialize(out_chls) + + self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0) + self.bn3 = bn_with_initialize_last(out_channels) + + self.relu = P.ReLU() + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResidualBlockWithDown(nn.Cell): + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + stride=1, + down_sample=False): + super(ResidualBlockWithDown, self).__init__() + + out_chls = out_channels // self.expansion + self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0) + self.bn1 = bn_with_initialize(out_chls) + + self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0) + self.bn2 = bn_with_initialize(out_chls) + + self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0) + self.bn3 = bn_with_initialize_last(out_channels) + + self.relu = P.ReLU() + self.downSample = down_sample + + self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0) + self.bn_down_sample = bn_with_initialize(out_channels) + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + identity = self.conv_down_sample(identity) + identity = self.bn_down_sample(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class MakeLayer0(nn.Cell): + + def __init__(self, block, in_channels, out_channels, stride): + super(MakeLayer0, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True) + self.b = block(out_channels, out_channels, stride=stride) + self.c = block(out_channels, out_channels, stride=1) + + def construct(self, x): + x = self.a(x) + x = self.b(x) + x = self.c(x) + + return x + + +class MakeLayer1(nn.Cell): + + def __init__(self, block, in_channels, out_channels, stride): + super(MakeLayer1, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + self.d = block(out_channels, out_channels, stride=1) + + def construct(self, x): + x = self.a(x) + x = self.b(x) + x = self.c(x) + x = self.d(x) + + return x + + +class MakeLayer2(nn.Cell): + + def __init__(self, block, in_channels, out_channels, stride): + super(MakeLayer2, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + self.d = block(out_channels, out_channels, stride=1) + self.e = block(out_channels, out_channels, stride=1) + self.f = block(out_channels, out_channels, stride=1) + + def construct(self, x): + x = self.a(x) + x = self.b(x) + x = self.c(x) + x = self.d(x) + x = self.e(x) + x = self.f(x) + + return x + + +class MakeLayer3(nn.Cell): + + def __init__(self, block, in_channels, out_channels, stride): + super(MakeLayer3, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + + def construct(self, x): + x = self.a(x) + x = self.b(x) + x = self.c(x) + + return x + + +class ResNet(nn.Cell): + + def __init__(self, block, num_classes=100, batch_size=32): + super(ResNet, self).__init__() + self.batch_size = batch_size + self.num_classes = num_classes + + self.conv1 = conv7x7(3, 64, stride=2, padding=0) + + self.bn1 = bn_with_initialize(64) + self.relu = P.ReLU() + self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME") + + self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1) + self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2) + self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2) + self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2) + + self.pool = P.ReduceMean(keep_dims=True) + self.squeeze = P.Squeeze(axis=(2, 3)) + self.fc = fc_with_initialize(512 * block.expansion, num_classes) + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x)[0] + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.pool(x, (2, 3)) + x = self.squeeze(x) + x = self.fc(x) + return x + + +def resnet50(batch_size, num_classes): + return ResNet(ResidualBlock, num_classes, batch_size) + + +def create_dataset(repeat_num=1, training=True, batch_size=32): + data_home = "/home/workspace/mindspore_dataset" + data_dir = data_home + "/cifar-10-batches-bin" + if not training: + data_dir = data_home + "/cifar-10-verify-bin" + data_set = ds.Cifar10Dataset(data_dir) + + resize_height = 224 + resize_width = 224 + rescale = 1.0 / 255.0 + shift = 0.0 + + # define map operations + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_horizontal_op = vision.RandomHorizontalFlip() + # interpolation default BILINEAR + resize_op = vision.Resize((resize_height, resize_width)) + rescale_op = vision.Rescale(rescale, shift) + normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) + changeswap_op = vision.HWC2CHW() + type_cast_op = C.TypeCast(mstype.int32) + + c_trans = [] + if training: + c_trans = [random_crop_op, random_horizontal_op] + c_trans += [resize_op, rescale_op, normalize_op, + changeswap_op] + + # apply map operations on images + data_set = data_set.map(input_columns="label", operations=type_cast_op) + data_set = data_set.map(input_columns="image", operations=c_trans) + + # apply shuffle operations + data_set = data_set.shuffle(buffer_size=1000) + + # apply batch operations + data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) + + # apply repeat operations + data_set = data_set.repeat(repeat_num) + + return data_set + + +class CrossEntropyLoss(nn.Cell): + def __init__(self): + super(CrossEntropyLoss, self).__init__() + self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean() + self.one_hot = P.OneHot() + self.one = Tensor(1.0, mstype.float32) + self.zero = Tensor(0.0, mstype.float32) + + def construct(self, logits, label): + label = self.one_hot(label, F.shape(logits)[1], self.one, self.zero) + loss = self.cross_entropy(logits, label)[0] + loss = self.mean(loss, (-1,)) + return loss + + +class GradWrap(Cell): + """ GradWrap definition """ + + def __init__(self, network): + super(GradWrap, self).__init__() + self.network = network + self.weights = ParameterTuple(network.trainable_params()) + + def construct(self, x, label): + weights = self.weights + return CP.grad_by_list(self.network, weights)(x, label) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_pynative_resnet50(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + + batch_size = 32 + num_classes = 10 + net = resnet50(batch_size, num_classes) + criterion = CrossEntropyLoss() + optimizer = Momentum(learning_rate=0.01, momentum=0.9, + params=filter(lambda x: x.requires_grad, net.get_parameters())) + + net_with_criterion = WithLossCell(net, criterion) + net_with_criterion.set_grad() + train_network = GradWrap(net_with_criterion) + train_network.set_train() + + step = 0 + max_step = 20 + data_set = create_dataset(repeat_num=1, training=True, batch_size=batch_size) + for element in data_set.create_dict_iterator(): + step = step + 1 + if step > max_step: + break + start_time = time.time() + input_data = Tensor(element["image"]) + input_label = Tensor(element["label"]) + loss_output = net_with_criterion(input_data, input_label) + grads = train_network(input_data, input_label) + optimizer(grads) + end_time = time.time() + cost_time = end_time - start_time + print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) + if step > 1: + assert cost_time < 0.3 + \ No newline at end of file diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index dcc798165b..880a281037 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -17,6 +17,7 @@ message("PYTHON_INCLUDE_DIRS = ${PYTHON_INCLUDE_DIRS}") message("PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}") include_directories(${PYTHON_INCLUDE_DIRS}) include_directories(${MS_CCSRC_PATH}) +include_directories(${CMAKE_SOURCE_DIR}/mindspore/core) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/) include_directories(${CMAKE_BINARY_DIR}) @@ -27,12 +28,20 @@ link_directories(${MS_CCSRC_BUILD_PATH}) if(ENABLE_MINDDATA) add_definitions(-D ENABLE_MINDDATA) - link_directories(${MS_CCSRC_BUILD_PATH}/dataset) - link_directories(${MS_CCSRC_BUILD_PATH}/mindrecord) + link_directories(${MS_CCSRC_BUILD_PATH}/minddata/dataset) + link_directories(${MS_CCSRC_BUILD_PATH}/minddata/mindrecord) endif() # fetch ut test files if(ENABLE_MINDDATA) - file(GLOB_RECURSE UT_SRCS ./*.cc) + file(GLOB_RECURSE UT_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ./*.cc) + if(NOT ENABLE_PYTHON) + set(PYTHON_RELATED_SRCS + dataset/filter_op_test.cc + dataset/voc_op_test.cc + dataset/manifest_op_test.cc + ) + list(REMOVE_ITEM UT_SRCS ${PYTHON_RELATED_SRCS}) + endif() else() file(GLOB_RECURSE TEMP_UT_SRCS ./*.cc) foreach(OBJ ${TEMP_UT_SRCS}) @@ -43,78 +52,83 @@ else() endif() file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "../../../mindspore/ccsrc/ir/*.cc" + "../../../mindspore/core/base/*.cc" + "../../../mindspore/core/abstract/*.cc" + "../../../mindspore/core/ir/*.cc" "../../../mindspore/ccsrc/common/*.cc" "../../../mindspore/ccsrc/utils/*.cc" - "../../../mindspore/ccsrc/parallel/*.cc" - "../../../mindspore/ccsrc/pipeline/parse/*.cc" - "../../../mindspore/ccsrc/pipeline/static_analysis/*.cc" - "../../../mindspore/ccsrc/pipeline/pipeline.cc" - "../../../mindspore/ccsrc/pipeline/resource.cc" - "../../../mindspore/ccsrc/pipeline/pass.cc" - "../../../mindspore/ccsrc/pipeline/action.cc" - "../../../mindspore/ccsrc/pipeline/validator.cc" - "../../../mindspore/ccsrc/pipeline/remove_value_node_dup.cc" - "../../../mindspore/ccsrc/optimizer/*.cc" + "../../../mindspore/ccsrc/pipeline/jit/parse/*.cc" + "../../../mindspore/ccsrc/pipeline/jit/static_analysis/*.cc" + "../../../mindspore/ccsrc/pipeline/jit/pipeline.cc" + "../../../mindspore/ccsrc/pipeline/jit/resource.cc" + "../../../mindspore/ccsrc/pipeline/jit/pass.cc" + "../../../mindspore/ccsrc/pipeline/jit/action.cc" + "../../../mindspore/ccsrc/pipeline/jit/validator.cc" + "../../../mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc" + "../../../mindspore/ccsrc/frontend/optimizer/*.cc" + "../../../mindspore/ccsrc/frontend/parallel/*.cc" "../../../mindspore/ccsrc/debug/*.cc" - "../../../mindspore/ccsrc/operator/*.cc" - "../../../mindspore/ccsrc/transform/*.cc" - "../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc" - "../../../mindspore/ccsrc/session/ascend_session.cc" - "../../../mindspore/ccsrc/session/ascend_control_parser.cc" - "../../../mindspore/ccsrc/session/kernel_graph.cc" - "../../../mindspore/ccsrc/session/session_basic.cc" - "../../../mindspore/ccsrc/session/session_factory.cc" + "../../../mindspore/ccsrc/frontend/operator/*.cc" + "../../../mindspore/ccsrc/transform/graph_ir/*.cc" + "../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc" + "../../../mindspore/ccsrc/backend/session/ascend_session.cc" + "../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc" + "../../../mindspore/ccsrc/backend/session/kernel_graph.cc" + "../../../mindspore/ccsrc/backend/session/session_basic.cc" + "../../../mindspore/ccsrc/backend/session/session_factory.cc" "../../../mindspore/ccsrc/vm/*.cc" - "../../../mindspore/ccsrc/pynative/*.cc" + "../../../mindspore/ccsrc/pipeline/pynative/*.cc" "../../../mindspore/ccsrc/pybind_api/*.cc" - "../../../mindspore/ccsrc/kernel/akg/*.cc" - "../../../mindspore/ccsrc/kernel/kash/*.cc" - "../../../mindspore/ccsrc/kernel/cce/*.cc" - "../../../mindspore/ccsrc/kernel/rts/*.cc" - "../../../mindspore/ccsrc/kernel/hccl/*.cc" - "../../../mindspore/ccsrc/kernel/kernel_query.cc" - "../../../mindspore/ccsrc/kernel/kernel_build_info.cc" - "../../../mindspore/ccsrc/pre_activate/ascend/*.cc" - "../../../mindspore/ccsrc/pre_activate/common/*.cc" - "../../../mindspore/ccsrc/pre_activate/gpu/*.cc" - "../../../mindspore/ccsrc/pre_activate/mem_reuse/*.cc" - "../../../mindspore/ccsrc/pre_activate/pass/*.cc" - "../../../mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.cc" - "../../../mindspore/ccsrc/kernel/rts/rt_kernel_info.cc" - "../../../mindspore/ccsrc/kernel/common_utils.cc" - "../../../mindspore/ccsrc/kernel/oplib/*.cc" - "../../../mindspore/ccsrc/kernel/tbe/*.cc" - "../../../mindspore/ccsrc/device/kernel_runtime.cc" - "../../../mindspore/ccsrc/device/memory_manager.cc" - "../../../mindspore/ccsrc/device/kernel_runtime_manager.cc" - "../../../mindspore/ccsrc/device/kernel_info.cc" - "../../../mindspore/ccsrc/device/ascend/profiling/*.cc" - "../../../mindspore/ccsrc/device/ascend/kernel_select_ascend.cc" - "../../../mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc" - "../../../mindspore/ccsrc/device/convert_tensor_utils.cc" - "../../../mindspore/ccsrc/device/ascend/kernel_build_ascend.cc" - "../../../mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc" - "../../../mindspore/ccsrc/device/ascend/ascend_memory_manager.cc" - "../../../mindspore/ccsrc/device/ascend/ascend_device_address.cc" - "../../../mindspore/ccsrc/device/ascend/ascend_memory_pool.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/akg/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/kash/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/rts/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/hccl/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc" + "../../../mindspore/ccsrc/backend/optimizer/ascend/*.cc" + "../../../mindspore/ccsrc/backend/optimizer/common/*.cc" + "../../../mindspore/ccsrc/backend/optimizer/gpu/*.cc" + "../../../mindspore/ccsrc/backend/optimizer/mem_reuse/*.cc" + "../../../mindspore/ccsrc/backend/optimizer/pass/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/common_utils.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/oplib/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/tbe/*.cc" + "../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc" + "../../../mindspore/ccsrc/runtime/device/memory_manager.cc" + "../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc" + "../../../mindspore/ccsrc/runtime/device/kernel_info.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc" + "../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc" "../../../mindspore/ccsrc/predict/generator/utils/ir_model_util.cc" "../../../mindspore/ccsrc/predict/predict.cc" "../../../mindspore/ccsrc/predict/converter/*.cc" "../../../mindspore/ccsrc/predict/converter/attr_utils/*.cc" "../../../mindspore/ccsrc/predict/converter/lite_model/*.cc" "../../../mindspore/ccsrc/predict/converter/lite_model/operations/*.cc" - "../../../mindspore/ccsrc/kernel/cpu/cpu_kernel.cc" - "../../../mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc" - "../../../mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc" - "../../../mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc" - "../../../mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc" - "../../../mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc" ) list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/dump_proto.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ir/lite/tensor.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/core/ir/lite/tensor.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/util.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/scheduler.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/anf_ir.pb.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/node_strategy.pb.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc") diff --git a/tests/ut/cpp/abstract/abstract_test.cc b/tests/ut/cpp/abstract/abstract_test.cc new file mode 100644 index 0000000000..2e3a2a8d1a --- /dev/null +++ b/tests/ut/cpp/abstract/abstract_test.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "common/common_test.h" + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "abstract/utils.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/resolve.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace abstract { + +class TestAbstract : public UT::Common { + public: + TestAbstract() {} + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(TestAbstract, TestParseDataClass) { + py::object fn = parse::python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "TestFoo"); + + ClassPtr cls_ptr = parse::ParseDataClass(fn); + ASSERT_TRUE(nullptr != cls_ptr); + std::shared_ptr cls = dyn_cast(cls_ptr); + ASSERT_TRUE(nullptr != cls); + + MS_LOG(INFO) << "" << cls->ToString(); + ASSERT_EQ(cls->tag(), Named(std::string("TestFoo"))); + + ClassAttrVector attributes = cls->GetAttributes(); + ASSERT_EQ(attributes.size(), 2); + for (auto &v : attributes) { + if (v.first == std::string("x")) { + ASSERT_TRUE(nullptr != dyn_cast(v.second)); + } + if (v.first == std::string("y")) { + ASSERT_TRUE(nullptr != dyn_cast(v.second)); + } + } + + std::unordered_map methods = cls->methods(); + ASSERT_EQ(methods.size(), 4); + int counts = 0; + for (auto &v : methods) { + if (v.first == std::string("inf")) { + counts++; + } + MS_LOG(INFO) << "" << v.first; + } + ASSERT_EQ(counts, 1); + + ValuePtr obj = std::make_shared(fn, "TestFoo"); + + ValueNodePtr fn_node = NewValueNode(obj); + AnfNodeConfigPtr fn_conf = std::make_shared(nullptr, fn_node, nullptr); + AbstractBasePtr foo = ToAbstract(obj, nullptr, fn_conf); + ASSERT_TRUE(foo != nullptr); + + AbstractBasePtr abstract_x = FromValue(1.1, true); + AbstractBasePtr abstract_y = FromValue(5, true); + + auto partical_func = dyn_cast(foo); + AbstractBasePtrList args_spec_list = partical_func->args(); + ASSERT_GT(args_spec_list.size(), 0); + AbstractScalarPtr abs_scalar = dyn_cast(args_spec_list[0]); + + AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; + + StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); + ASSERT_TRUE(nullptr != eval_impl); + + AbstractBasePtr new_cls = eval_impl(nullptr, prim::kPrimMakeRecord, args_list); + ASSERT_TRUE(nullptr != new_cls); +} + +} // namespace abstract +} // namespace mindspore diff --git a/tests/ut/cpp/abstract/dshape_test.cc b/tests/ut/cpp/abstract/dshape_test.cc new file mode 100644 index 0000000000..da0e9ed3ee --- /dev/null +++ b/tests/ut/cpp/abstract/dshape_test.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "common/common_test.h" + +#include "abstract/dshape.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace abstract { +class TestDShape : public UT::Common { + public: + Shape shp_1; + Shape shp_2; + Shape shp_3; + Shape shp_4; + + NoShape shp_noshp_1; + NoShape shp_noshp_2; + + TupleShape shp_tuple_1; + TupleShape shp_tuple_2; + TupleShape shp_tuple_3; + TupleShape shp_tuple_4; + TestDShape() + : shp_1({1, 1}), + shp_2({1, 1}), + shp_3({1, 2}), + shp_4({1}), + + shp_noshp_1(), + shp_noshp_2(), + + shp_tuple_1({NoShape().Clone(), Shape({1, 1}).Clone()}), + shp_tuple_2({NoShape().Clone(), Shape({1, 1, 1}).Clone()}), + shp_tuple_3({NoShape().Clone(), Shape({1, 2, 1}).Clone()}), + shp_tuple_4({NoShape().Clone()}) {} +}; + +TEST_F(TestDShape, EqualTest) { + ASSERT_TRUE(shp_1 == shp_2); + ASSERT_FALSE(shp_1 == shp_3); + ASSERT_FALSE(shp_1 == shp_noshp_1); + + ASSERT_TRUE(shp_noshp_1 == shp_noshp_2); + + ASSERT_FALSE(shp_tuple_1 == shp_1); + ASSERT_FALSE(shp_tuple_1 == shp_tuple_2); + ASSERT_FALSE(shp_tuple_1 == shp_tuple_4); +} +TEST_F(TestDShape, ToString) { + ASSERT_EQ(shp_3.ToString(), "(1, 2)"); + ASSERT_EQ(shp_noshp_1.ToString(), "NoShape"); + ASSERT_EQ(shp_tuple_2.ToString(), "TupleShape(NoShape, (1, 1, 1))"); +} + +TEST_F(TestDShape, Clone) { + ASSERT_EQ(*shp_3.Clone(), shp_3); + ASSERT_EQ(*shp_noshp_1.Clone(), shp_noshp_1); + ASSERT_EQ(*shp_tuple_2.Clone(), shp_tuple_2); +} + +} // namespace abstract +} // namespace mindspore diff --git a/tests/ut/cpp/abstract/utils_test.cc b/tests/ut/cpp/abstract/utils_test.cc new file mode 100644 index 0000000000..33cada28d7 --- /dev/null +++ b/tests/ut/cpp/abstract/utils_test.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "abstract/utils.h" + +#include "common/common_test.h" +#include "pipeline/jit/static_analysis/static_analysis.h" + +namespace mindspore { +namespace abstract { +class TestUtils : public UT::Common { + public: + TestUtils() {} + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(TestUtils, test_join) { + // AbstractScalar + AbstractBasePtr abs_s1 = FromValue(1, false); + AbstractBasePtr abs_s2 = FromValue(2, false); + AbstractBasePtr abs_s_anything = FromValue(2, true); + + AbstractBasePtr res_s1 = abs_s1->Join(abs_s2); + ASSERT_EQ(*res_s1, *abs_s_anything); + + // AbstractTuple join; + std::vector list1 = {1, 2, 3, 4, 5}; + std::vector list2 = {5, 4, 3, 2, 1}; + AbstractBasePtr abs_t1 = FromValue(list1, true); + AbstractBasePtr abs_t2 = FromValue(list2, true); + + AbstractBasePtr res_t1 = abs_t1->Join(abs_t2); + ASSERT_EQ(res_t1, abs_t1); + + abs_s1 = FromValue(1, false); + + AbstractBasePtr t1 = std::make_shared(AbstractBasePtrList({abs_s1, abs_s_anything})); + AbstractBasePtr t2 = std::make_shared(AbstractBasePtrList({abs_s1, abs_s_anything})); + AbstractBasePtr t3 = std::make_shared(AbstractBasePtrList({abs_s_anything, abs_s_anything})); + + res_t1 = t1->Join(t2); + ASSERT_EQ(res_t1, t1); + + res_t1 = t1->Join(t3); + ASSERT_EQ(*res_t1, *t3); + + res_t1 = t3->Join(t1); + ASSERT_EQ(res_t1, t3); +} + +} // namespace abstract +} // namespace mindspore diff --git a/tests/ut/cpp/base/base_test.cc b/tests/ut/cpp/base/base_test.cc new file mode 100644 index 0000000000..71a7999e0f --- /dev/null +++ b/tests/ut/cpp/base/base_test.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "common/common_test.h" +#include "utils/any.h" +#include "base/base.h" +#include "ir/anf.h" +#include "utils/log_adapter.h" + +namespace mindspore { + +class TestNode : public UT::Common { + public: + TestNode() {} +}; + +class ChildA : public Base { + public: + ChildA() {} + ~ChildA() {} + MS_DECLARE_PARENT(ChildA, Base); + std::string name() { return "ChildA"; } + std::size_t hash() const override { return 1; } +}; +class ChildAA : public ChildA { + public: + ChildAA() {} + ~ChildAA() {} + MS_DECLARE_PARENT(ChildAA, ChildA); + std::size_t hash() const override { return 1; } + std::string name() { return "ChildAA"; } +}; + +class ChildB : public Base { + public: + ChildB() {} + ~ChildB() {} + MS_DECLARE_PARENT(ChildB, Base); + std::size_t hash() const override { return 1; } + std::string name() { return "ChildB"; } +}; + +TEST_F(TestNode, test_dyn_cast) { + auto aa = std::make_shared(); + std::shared_ptr n = aa; + MS_LOG(INFO) << "aa ptr_name: " << aa->name(); + MS_LOG(INFO) << "aa type_name: " << aa->type_name(); + MS_LOG(INFO) << "n ptr_name: " << demangle(typeid(n).name()); + MS_LOG(INFO) << "n type_name: " << n->type_name(); + ASSERT_TRUE(n != nullptr); + ASSERT_EQ(std::string(n->type_name().c_str()), "ChildAA"); + auto a = dyn_cast(n); + MS_LOG(INFO) << "a ptr_name: " << a->name(); + MS_LOG(INFO) << "a type_name: " << a->type_name(); + ASSERT_TRUE(a != nullptr); + ASSERT_EQ(std::string(a->name()), "ChildA"); + ASSERT_EQ(std::string(a->type_name().c_str()), "ChildAA"); + auto b_null = dyn_cast(n); + ASSERT_TRUE(b_null == nullptr); + + ChildA* pa = cast(n.get()); + ASSERT_TRUE(pa != nullptr); + MS_LOG(INFO) << "a ptr_name: " << pa->name(); + MS_LOG(INFO) << "a type_name: " << pa->type_name(); +} + +TEST_F(TestNode, test_isa) { + auto a = std::make_shared(); + BasePtr n = a; + ASSERT_TRUE(n->isa() == true); + ASSERT_TRUE(n->isa() == false); + + auto aa = std::make_shared(); + n = aa; + ASSERT_TRUE(n->isa() == true); + ASSERT_TRUE(n->isa() == true); + + auto b = std::make_shared(); + n = b; + ASSERT_TRUE(n->isa() == true); + ASSERT_TRUE(n->isa() == false); + ASSERT_TRUE(n->isa() == false); +} + +} // namespace mindspore diff --git a/tests/ut/cpp/common/backend_common_test.cc b/tests/ut/cpp/common/backend_common_test.cc index 060b170a8c..3710349298 100644 --- a/tests/ut/cpp/common/backend_common_test.cc +++ b/tests/ut/cpp/common/backend_common_test.cc @@ -20,11 +20,11 @@ #include #include "utils/log_adapter.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "debug/anf_ir_dump.h" -#include "session/ascend_session.h" -#include "pipeline/resource.h" -#include "pipeline/action.h" +#include "backend/session/ascend_session.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/action.h" #include "ir/anf.h" #include "ir/manager.h" diff --git a/tests/ut/cpp/common/backend_common_test.h b/tests/ut/cpp/common/backend_common_test.h index fb3334182a..f5bfc9d6dd 100644 --- a/tests/ut/cpp/common/backend_common_test.h +++ b/tests/ut/cpp/common/backend_common_test.h @@ -17,7 +17,7 @@ #define TESTS_UT_CPP_COMMON_UT_BACKEND_COMMON_H_ #include "common/common_test.h" #include "utils/context/ms_context.h" -#include "session/kernel_graph.h" +#include "backend/session/kernel_graph.h" namespace mindspore { class BackendCommon : public UT::Common { diff --git a/tests/ut/cpp/common/py_func_graph_fetcher.h b/tests/ut/cpp/common/py_func_graph_fetcher.h index 98552a96b5..d864842760 100644 --- a/tests/ut/cpp/common/py_func_graph_fetcher.h +++ b/tests/ut/cpp/common/py_func_graph_fetcher.h @@ -22,8 +22,8 @@ #include "ir/primitive.h" #include "ir/manager.h" #include "ir/func_graph.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/parse.h" #include "./common.h" namespace UT { diff --git a/tests/ut/cpp/common/test_main.cc b/tests/ut/cpp/common/test_main.cc index f0cfc1778c..fa456ed260 100644 --- a/tests/ut/cpp/common/test_main.cc +++ b/tests/ut/cpp/common/test_main.cc @@ -16,8 +16,8 @@ #include #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "pipeline/pipeline.h" -#include "pipeline/resource.h" +#include "pipeline/jit/pipeline.h" +#include "pipeline/jit/resource.h" namespace mindspore { extern void InitSubModulesLogLevel(); diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 129864ca0f..8bbf42a640 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -11,6 +11,7 @@ SET(DE_UT_SRCS interrupt_test.cc image_folder_op_test.cc buddy_test.cc + bounding_box_augment_op_test.cc arena_test.cc btree_test.cc center_crop_op_test.cc @@ -35,20 +36,26 @@ SET(DE_UT_SRCS project_op_test.cc queue_test.cc random_crop_op_test.cc + random_crop_with_bbox_op_test.cc random_crop_decode_resize_op_test.cc random_crop_and_resize_op_test.cc + random_crop_and_resize_with_bbox_op_test.cc random_color_adjust_op_test.cc random_horizontal_flip_op_test.cc + random_horizontal_flip_with_bbox_test.cc random_resize_op_test.cc + random_resize_with_bbox_op_test.cc random_rotation_op_test.cc random_vertical_flip_op_test.cc + random_vertical_flip_with_bbox_op_test.cc rename_op_test.cc repeat_op_test.cc skip_op_test.cc rescale_op_test.cc resize_bilinear_op_test.cc resize_op_test.cc - schema_test.cc + resize_with_bbox_op_test.cc + schema_test.cc shuffle_op_test.cc stand_alone_samplers_test.cc status_test.cc @@ -83,6 +90,8 @@ SET(DE_UT_SRCS concatenate_op_test.cc cyclic_array_test.cc perf_data_test.cc + c_api_test.cc + tensor_op_fusion_pass_test.cc ) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/arena_test.cc b/tests/ut/cpp/dataset/arena_test.cc index e8698ad979..10d27b51c6 100644 --- a/tests/ut/cpp/dataset/arena_test.cc +++ b/tests/ut/cpp/dataset/arena_test.cc @@ -15,7 +15,7 @@ */ #include -#include "dataset/util/arena.h" +#include "minddata/dataset/util/arena.h" #include "common/common.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/batch_op_test.cc b/tests/ut/cpp/dataset/batch_op_test.cc index a04da06e4e..3e1f3c0b32 100644 --- a/tests/ut/cpp/dataset/batch_op_test.cc +++ b/tests/ut/cpp/dataset/batch_op_test.cc @@ -16,14 +16,14 @@ #include #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" -#include "dataset/core/global_context.h" +#include "minddata/dataset/core/global_context.h" #include "utils/log_adapter.h" #include "securec.h" -#include "dataset/util/status.h" +#include "minddata/dataset/util/status.h" namespace common = mindspore::common; namespace de = mindspore::dataset; diff --git a/tests/ut/cpp/dataset/bit_functions_test.cc b/tests/ut/cpp/dataset/bit_functions_test.cc index 02b6a25f76..cf1c1562db 100644 --- a/tests/ut/cpp/dataset/bit_functions_test.cc +++ b/tests/ut/cpp/dataset/bit_functions_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/core/constants.h" +#include "minddata/dataset/core/constants.h" #include "common/common.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/bounding_box_augment_op_test.cc b/tests/ut/cpp/dataset/bounding_box_augment_op_test.cc new file mode 100644 index 0000000000..dc59d39fac --- /dev/null +++ b/tests/ut/cpp/dataset/bounding_box_augment_op_test.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/bboxop_common.h" +#include "minddata/dataset/kernels/image/bounding_box_augment_op.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +const bool kSaveExpected = false; +const char kOpName[] = "BoundingBoxAugmentOp"; + +class MindDataTestBoundingBoxAugmentOp : public UT::CVOP::BBOXOP::BBoxOpCommon { + protected: + MindDataTestBoundingBoxAugmentOp() : UT::CVOP::BBOXOP::BBoxOpCommon() {} +}; + +TEST_F(MindDataTestBoundingBoxAugmentOp, TestOp) { + MS_LOG(INFO) << "Doing testBoundingBoxAugment."; + TensorTable results; + std::unique_ptr op = + std::make_unique(std::make_shared(90, 90), 1); + for (const auto &row : images_and_annotations_) { + TensorRow output_row; + Status s = op->Compute(row, &output_row); + EXPECT_TRUE(s.IsOk()); + results.push_back(output_row); + } + if (kSaveExpected) { + SaveImagesWithAnnotations(FileType::kExpected, std::string(kOpName), results); + } + SaveImagesWithAnnotations(FileType::kActual, std::string(kOpName), results); + if (!kSaveExpected) { + CompareActualAndExpected(std::string(kOpName)); + } +} diff --git a/tests/ut/cpp/dataset/btree_test.cc b/tests/ut/cpp/dataset/btree_test.cc index 67b6c4e6c7..9fa4fce812 100644 --- a/tests/ut/cpp/dataset/btree_test.cc +++ b/tests/ut/cpp/dataset/btree_test.cc @@ -15,10 +15,10 @@ */ #include -#include "dataset/util/btree.h" -#include "dataset/util/auto_index.h" -#include "dataset/util/system_pool.h" -#include "dataset/util/task_manager.h" +#include "minddata/dataset/util/btree.h" +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/util/system_pool.h" +#include "minddata/dataset/util/task_manager.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/c_api_test.cc b/tests/ut/cpp/dataset/c_api_test.cc new file mode 100644 index 0000000000..902bc9a43b --- /dev/null +++ b/tests/ut/cpp/dataset/c_api_test.cc @@ -0,0 +1,771 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include "utils/log_adapter.h" +#include "common/utils.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "securec.h" +#include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/include/status.h" +#include "minddata/dataset/include/transforms.h" +#include "minddata/dataset/include/iterator.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/include/samplers.h" + +using namespace mindspore::dataset::api; +using mindspore::MsLogLevel::ERROR; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; +using mindspore::dataset::Tensor; +using mindspore::dataset::Status; +using mindspore::dataset::BorderType; + + +class MindDataTestPipeline : public UT::DatasetOpTesting { + protected: +}; + + +TEST_F(MindDataTestPipeline, TestBatchAndRepeat) { + // Create a Mnist Dataset + std::string folder_path = datasets_root_path_ + "/testMnistData/"; + std::shared_ptr ds = Mnist(folder_path, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 2; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) { + // Create a Mnist Dataset + std::string folder_path = datasets_root_path_ + "/testMnistData/"; + std::shared_ptr ds = Mnist(folder_path, RandomSampler(false, 20)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr resize_op = vision::Resize({30, 30}); + EXPECT_TRUE(resize_op != nullptr); + + std::shared_ptr center_crop_op = vision::CenterCrop({16, 16}); + EXPECT_TRUE(center_crop_op != nullptr); + + // Create a Map operation on ds + ds = ds->Map({resize_op, center_crop_op}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 40); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestUniformAugWithOps) { + // Create a Mnist Dataset + std::string folder_path = datasets_root_path_ + "/testMnistData/"; + std::shared_ptr ds = Mnist(folder_path, RandomSampler(false, 20)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 1; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr resize_op = vision::Resize({30, 30}); + EXPECT_TRUE(resize_op != nullptr); + + std::shared_ptr random_crop_op = vision::RandomCrop({28, 28}); + EXPECT_TRUE(random_crop_op != nullptr); + + std::shared_ptr center_crop_op = vision::CenterCrop({16, 16}); + EXPECT_TRUE(center_crop_op != nullptr); + + std::shared_ptr uniform_aug_op = vision::UniformAugment({random_crop_op, center_crop_op}, 2); + EXPECT_TRUE(uniform_aug_op != nullptr); + + // Create a Map operation on ds + ds = ds->Map({resize_op, uniform_aug_op}); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestRandomFlip) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr random_vertical_flip_op = vision::RandomVerticalFlip(0.5); + EXPECT_TRUE(random_vertical_flip_op != nullptr); + + std::shared_ptr random_horizontal_flip_op = vision::RandomHorizontalFlip(0.5); + EXPECT_TRUE(random_horizontal_flip_op != nullptr); + + // Create a Map operation on ds + ds = ds->Map({random_vertical_flip_op, random_horizontal_flip_op}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestImageFolderBatchAndRepeat) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 2; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { + std::shared_ptr sampl = DistributedSampler(2, 1); + EXPECT_NE(sampl, nullptr); + + sampl = PKSampler(3); + EXPECT_NE(sampl, nullptr); + + sampl = RandomSampler(false, 12); + EXPECT_NE(sampl, nullptr); + + sampl = SequentialSampler(0, 12); + EXPECT_NE(sampl, nullptr); + + std::vector weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1}; + sampl = WeightedRandomSampler(weights, 12); + EXPECT_NE(sampl, nullptr); + + std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; + sampl = SubsetRandomSampler(indices); + EXPECT_NE(sampl, nullptr); + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, false, sampl); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 2; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 12); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestPad) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr pad_op1 = vision::Pad({1, 2, 3, 4}, {0}, BorderType::kSymmetric); + EXPECT_TRUE(pad_op1 != nullptr); + + std::shared_ptr pad_op2 = vision::Pad({1}, {1, 1, 1}, BorderType::kEdge); + EXPECT_TRUE(pad_op2 != nullptr); + + std::shared_ptr pad_op3 = vision::Pad({1, 4}); + EXPECT_TRUE(pad_op3 != nullptr); + + // Create a Map operation on ds + ds = ds->Map({pad_op1, pad_op2, pad_op3}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCutOut) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr cut_out1 = vision::CutOut(30, 5); + EXPECT_TRUE(cut_out1!= nullptr); + + std::shared_ptr cut_out2 = vision::CutOut(30); + EXPECT_TRUE(cut_out2 != nullptr); + + // Create a Map operation on ds + ds = ds->Map({cut_out1, cut_out2}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestNormalize) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr normalize = vision::Normalize({121.0, 115.0, 100.0}, {70.0, 68.0, 71.0}); + EXPECT_TRUE(normalize != nullptr); + + // Create a Map operation on ds + ds = ds->Map({normalize}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestDecode) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, false, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr decode = vision::Decode(true); + EXPECT_TRUE(decode != nullptr); + + // Create a Map operation on ds + ds = ds->Map({decode}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + EXPECT_EQ(i, 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestShuffleDataset) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Shuffle operation on ds + int32_t shuffle_size = 10; + ds = ds->Shuffle(shuffle_size); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 2; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestCifar10Dataset) { + + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, 0, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 2; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 10); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestRandomColorAdjust) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr random_color_adjust1 = vision::RandomColorAdjust({1.0}, {0.0}, {0.5}, {0.5}); + EXPECT_TRUE(random_color_adjust1 != nullptr); + + std::shared_ptr random_color_adjust2 = vision::RandomColorAdjust({1.0, 1.0}, {0.0, 0.0}, {0.5, 0.5}, + {0.5, 0.5}); + EXPECT_TRUE(random_color_adjust2 != nullptr); + + std::shared_ptr random_color_adjust3 = vision::RandomColorAdjust({0.5, 1.0}, {0.0, 0.5}, {0.25, 0.5}, + {0.25, 0.5}); + EXPECT_TRUE(random_color_adjust3 != nullptr); + + std::shared_ptr random_color_adjust4 = vision::RandomColorAdjust(); + EXPECT_TRUE(random_color_adjust4 != nullptr); + + // Create a Map operation on ds + ds = ds->Map({random_color_adjust1, random_color_adjust2, random_color_adjust3, random_color_adjust4}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestRandomRotation) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr random_rotation_op = vision::RandomRotation({-180, 180}); + EXPECT_TRUE(random_rotation_op != nullptr); + + // Create a Map operation on ds + ds = ds->Map({random_rotation_op}); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 20); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestProjectMap) { + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_TRUE(ds != nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 2; + ds = ds->Repeat(repeat_num); + EXPECT_TRUE(ds != nullptr); + + // Create objects for the tensor ops + std::shared_ptr random_vertical_flip_op = vision::RandomVerticalFlip(0.5); + EXPECT_TRUE(random_vertical_flip_op != nullptr); + + // Create a Map operation on ds + ds = ds->Map({random_vertical_flip_op}, {}, {}, {"image", "label"}); + EXPECT_TRUE(ds != nullptr); + + // Create a Project operation on ds + std::vector column_project = {"image"}; + ds = ds->Project(column_project); + EXPECT_TRUE(ds != nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 1; + ds = ds->Batch(batch_size); + EXPECT_TRUE(ds != nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_TRUE(iter != nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_TRUE(i == 20); + + // Manually terminate the pipeline + iter->Stop(); +} \ No newline at end of file diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc new file mode 100644 index 0000000000..bdb7c861b2 --- /dev/null +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -0,0 +1,579 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/util/storage_container.h" // lint !e322 +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/data_schema.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::dataset::CacheClient; +using mindspore::dataset::TaskGroup; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +class MindDataTestCacheOp : public UT::DatasetOpTesting { + public: + void SetUp() override { + DatasetOpTesting::SetUp(); + GlobalInit(); + } +}; + +TEST_F(MindDataTestCacheOp, TestCacheServer) { + Status rc; + CacheClient myClient(1, 0, true); // use arbitrary session of 1, size of 0, spilling is true + // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated. + rc = myClient.CreateCache(1, true); + EXPECT_TRUE(rc.IsOk()); + std::cout << myClient << std::endl; + + // Create a schema using the C api's + int32_t rank = 0; // not used + std::unique_ptr testSchema = std::make_unique(); + // 2 columns. First column is an "image" 640,480,3 + TensorShape c1Shape({640, 480, 3}); + ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, + rank, // not used + &c1Shape); + // Column 2 will just be a scalar label number + TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor + ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); + + testSchema->AddColumn(c1); + testSchema->AddColumn(c2); + + std::unordered_map map; + rc = testSchema->GetColumnNameMap(&map); + EXPECT_TRUE(rc.IsOk()); + + // Test the CacheSchema api + rc = myClient.CacheSchema(map); + EXPECT_TRUE(rc.IsOk()); + + // Create a tensor, take a snapshot and restore it back, and compare. + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + t->SetItemAt({0, 0}, 1); + t->SetItemAt({0, 1}, 2); + t->SetItemAt({0, 2}, 3); + t->SetItemAt({1, 0}, 4); + t->SetItemAt({1, 1}, 5); + t->SetItemAt({1, 2}, 6); + std::cout << *t << std::endl; + TensorTable tbl; + TensorRow row; + row.push_back(t); + int64_t row_id; + rc = myClient.WriteRow(row, &row_id); + EXPECT_TRUE(rc.IsOk()); + + // Switch off build phase. + rc = myClient.BuildPhaseDone(); + EXPECT_TRUE(rc.IsOk()); + + // Now restore from cache. + row.clear(); + rc = myClient.GetRows({row_id}, &tbl); + row = tbl.front(); + EXPECT_TRUE(rc.IsOk()); + auto r = row.front(); + std::cout << *r << std::endl; + // Compare + bool cmp = (*t == *r); + EXPECT_TRUE(cmp); + + // Get back the schema and verify + std::unordered_map map_out; + rc = myClient.FetchSchema(&map_out); + EXPECT_TRUE(rc.IsOk()); + cmp = (map_out == map); + EXPECT_TRUE(cmp); + + // Test Purge and Destroy + rc = myClient.PurgeCache(); + EXPECT_TRUE(rc.IsOk()); + rc = myClient.DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +TEST_F(MindDataTestCacheOp, TestConcurrencyRequest) { + // Clear the rc of the master thread if any + (void)TaskManager::GetMasterThreadRc(); + TaskGroup vg; + Status rc; + CacheClient myClient(1, 1, true); // use arbitrary session of 1, size 1, spilling is true + // cksum value of 1 for CreateCache here...normally you do not directly create a cache and the cksum arg is generated. + rc = myClient.CreateCache(1, true); + EXPECT_TRUE(rc.IsOk()); + std::cout << myClient << std::endl; + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + t->SetItemAt({0, 0}, 1); + t->SetItemAt({0, 1}, 2); + t->SetItemAt({0, 2}, 3); + t->SetItemAt({1, 0}, 4); + t->SetItemAt({1, 1}, 5); + t->SetItemAt({1, 2}, 6); + TensorTable tbl; + TensorRow row; + row.push_back(t); + // Cache tensor row t 5000 times using 10 threads. + for (auto k = 0; k < 10; ++k) { + Status vg_rc = vg.CreateAsyncTask("Test agent", [&myClient, &row]() -> Status { + TaskManager::FindMe()->Post(); + for (auto i = 0; i < 500; i++) { + RETURN_IF_NOT_OK(myClient.WriteRow(row)); + } + return Status::OK(); + }); + EXPECT_TRUE(vg_rc.IsOk()); + } + ASSERT_TRUE(vg.join_all().IsOk()); + ASSERT_TRUE(vg.GetTaskErrorIfAny().IsOk()); + rc = myClient.BuildPhaseDone(); + ASSERT_TRUE(rc.IsOk()); + // Get statistics from the server. + CacheClient::ServiceStat stat{}; + rc = myClient.GetStat(&stat); + ASSERT_TRUE(rc.IsOk()); + std::cout << stat.min_row_id << ":" << stat.max_row_id << ":" << stat.num_mem_cached << ":" << stat.num_disk_cached + << "\n"; + // Expect there are 5000 rows there. + EXPECT_EQ(5000, stat.max_row_id - stat.min_row_id + 1); + // Get them all back using row id and compare with tensor t. + for (auto i = stat.min_row_id; i <= stat.max_row_id; ++i) { + tbl.clear(); + row.clear(); + rc = myClient.GetRows({i}, &tbl); + EXPECT_TRUE(rc.IsOk()); + row = tbl.front(); + auto r = row.front(); + bool cmp = (*t == *r); + EXPECT_TRUE(cmp); + } + rc = myClient.DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +// Simple test with a repeated cache op over random data producer +// +// RepeatOp +// | +// CacheOp +// | +// RandomDataOp +// +TEST_F(MindDataTestCacheOp, TestRandomDataCache1) { + Status rc; + int32_t rank = 0; // not used + MS_LOG(INFO) << "UT test TestRandomDataCache1"; + // Start with an empty execution tree + auto myTree = std::make_shared(); + + // Create a schema using the C api's + std::unique_ptr testSchema = std::make_unique(); + + // 2 columns. First column is an "image" 640,480,3 + TensorShape c1Shape({640, 480, 3}); + ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, + rank, // not used + &c1Shape); + + // Column 2 will just be a scalar label number + TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor + ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); + + testSchema->AddColumn(c1); + testSchema->AddColumn(c2); + + // RandomDataOp + std::shared_ptr myRandomDataOp; + rc = RandomDataOp::Builder() + .SetRowsPerBuffer(4) + .SetNumWorkers(4) + .SetDataSchema(std::move(testSchema)) + .SetTotalRows(50) // 50 samples for now + .Build(&myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + + // CacheOp + // size of 0, spilling is true + std::shared_ptr myClient = std::make_shared(1, 0, true); + std::shared_ptr myCacheOp; + + int64_t num_samples = 0; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + rc = CacheOp::Builder() + .SetNumWorkers(5) + .SetClient(myClient) + .SetRowsPerBuffer(4) + .SetSampler(std::move(seq_sampler)) + .Build(&myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + + // RepeatOp + uint32_t numRepeats = 4; + std::shared_ptr myRepeatOp; + rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + // Assign tree relations and root + rc = myRepeatOp->AddChild(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myCacheOp->AddChild(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssignRoot(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration"; + rc = myTree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + + // quick check to see what tree looks like + std::ostringstream ss; + ss << *myTree; // some funny const error if I try to write directly to ms log stream + MS_LOG(INFO) << "Here's the tree:\n" << ss.str(); + + std::cout << *myClient << std::endl; + + rc = myTree->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator dI(myTree); + TensorRow tensorList; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + int rowCount = 0; + while (!tensorList.empty()) { + // Don't display these rows, just count them + MS_LOG(INFO) << "Row fetched #: " << rowCount; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + rowCount++; + } + ASSERT_EQ(rowCount, 200); + rc = myClient->DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +//// Simple test with a repeated cache op over random data producer. +//// This one will exceed memory and require a spill. +//// +//// RepeatOp +//// | +//// CacheOp +//// | +//// RandomDataOp +//// +TEST_F(MindDataTestCacheOp, TestRandomDataCacheSpill) { + Status rc; + int32_t rank = 0; // not used + MS_LOG(INFO) << "UT test TestRandomDataCacheSpill"; + // Start with an empty execution tree + auto myTree = std::make_shared(); + + // Create a schema using the C api's + std::unique_ptr testSchema = std::make_unique(); + + // 2 columns. First column is an "image" 640,480,3 + TensorShape c1Shape({640, 480, 3}); + ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, + rank, // not used + &c1Shape); + + // Column 2 will just be a scalar label number + TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor + ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); + + testSchema->AddColumn(c1); + testSchema->AddColumn(c2); + + // RandomDataOp + std::shared_ptr myRandomDataOp; + rc = RandomDataOp::Builder() + .SetRowsPerBuffer(2) + .SetNumWorkers(4) + .SetDataSchema(std::move(testSchema)) + .SetTotalRows(10) + .Build(&myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + + // CacheOp + int64_t num_samples = 0; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + std::shared_ptr myClient = std::make_shared(1, 4, true); + std::shared_ptr myCacheOp; + rc = CacheOp::Builder() + .SetNumWorkers(4) + .SetClient(myClient) + .SetRowsPerBuffer(3) + .SetSampler(std::move(seq_sampler)) + .Build(&myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + + // RepeatOp + uint32_t numRepeats = 4; + std::shared_ptr myRepeatOp; + rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + // Assign tree relations and root + rc = myRepeatOp->AddChild(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myCacheOp->AddChild(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssignRoot(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration"; + rc = myTree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + + std::cout << *myClient << std::endl; + + rc = myTree->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator dI(myTree); + TensorRow tensorList; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + int rowCount = 0; + while (!tensorList.empty()) { + // Don't display these rows, just count them + MS_LOG(INFO) << "Row fetched #: " << rowCount; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + rowCount++; + } + ASSERT_EQ(rowCount, 40); + rc = myClient->DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) { + Status rc; + int64_t num_samples = 0; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + + std::shared_ptr myClient = std::make_shared(1, 0, true); + + std::shared_ptr myMergeOp; + rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build( + &myMergeOp); + EXPECT_TRUE(rc.IsOk()); + + std::shared_ptr myLookupOp; + rc = CacheLookupOp::Builder() + .SetNumWorkers(3) + .SetOpConnectorSize(3) + .SetClient(myClient) + .SetSampler(seq_sampler) + .Build(&myLookupOp); + EXPECT_TRUE(rc.IsOk()); + + std::shared_ptr so; + ImageFolderOp::Builder builder; + builder.SetSampler(myLookupOp) + .SetOpConnectorSize(3) + .SetNumWorkers(3) + .SetRowsPerBuffer(2) + .SetExtensions({".jpg", ".JPEG"}) + .SetRecursive(true) + .SetImageFolderDir(datasets_root_path_ + "/testPK/data"); + rc = builder.Build(&so); + EXPECT_TRUE(rc.IsOk()); + + // RepeatOp + uint32_t numRepeats = 4; + std::shared_ptr myRepeatOp; + rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + auto myTree = std::make_shared(); + rc = myTree->AssociateNode(so); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myLookupOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myMergeOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssignRoot(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + rc = myRepeatOp->AddChild(myMergeOp); + EXPECT_TRUE(rc.IsOk()); + rc = myMergeOp->AddChild(myLookupOp); + EXPECT_TRUE(rc.IsOk()); + rc = myMergeOp->AddChild(so); + EXPECT_TRUE(rc.IsOk()); + + rc = myTree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->Launch(); + EXPECT_TRUE(rc.IsOk()); + // Start the loop of reading tensors from our pipeline + DatasetIterator dI(myTree); + TensorRow tensorList; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + int rowCount = 0; + while (!tensorList.empty()) { + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + if (rc.IsError()) { + std::cout << rc << std::endl; + break; + } + rowCount++; + } + ASSERT_EQ(rowCount, 176); + std::cout << "Row count : " << rowCount << std::endl; + rc = myClient->DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} + +//// Simple test with a repeated cache op over random data producer. +//// The difference in this one is that you do not add the sampler to the cache op directly. +//// Instead, the sampler is added as part of the leaf op construction. Then, the prepare +//// phase will pull this up from the leaf and into the cache. +//// It removes the sampler from the leaf op, which doesn't make sense there anyway for +//// the RandomDataOp which doesn't support sampling without a cache. +//// +//// RepeatOp +//// | +//// CacheOp +//// | +//// RandomDataOp +//// +TEST_F(MindDataTestCacheOp, TestCacheInheritSampler) { + Status rc; + int32_t rank = 0; // not used + MS_LOG(INFO) << "UT test TestCacheInheritSampler"; + + int64_t num_samples = 0; + int64_t start_index = 0; + auto seq_sampler = std::make_shared(num_samples, start_index); + + // Start with an empty execution tree + auto myTree = std::make_shared(); + + // Create a schema using the C api's + std::unique_ptr testSchema = std::make_unique(); + + // 2 columns. First column is an "image" 640,480,3 + TensorShape c1Shape({640, 480, 3}); + ColDescriptor c1("image", DataType(DataType::DE_INT8), TensorImpl::kFlexible, + rank, // not used + &c1Shape); + + // Column 2 will just be a scalar label number + TensorShape c2Shape({}); // empty shape is a 1-value scalar Tensor + ColDescriptor c2("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, rank, &c2Shape); + + testSchema->AddColumn(c1); + testSchema->AddColumn(c2); + + // RandomDataOp + std::shared_ptr myRandomDataOp; + rc = RandomDataOp::Builder() + .SetRowsPerBuffer(2) + .SetNumWorkers(4) + .SetDataSchema(std::move(testSchema)) + .SetTotalRows(10) + .SetSampler(std::move(seq_sampler)) + .Build(&myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + + // CacheOp + std::shared_ptr myClient = std::make_shared(1, 4, true); + std::shared_ptr myCacheOp; + rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetRowsPerBuffer(3).Build(&myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + + // RepeatOp + uint32_t numRepeats = 4; + std::shared_ptr myRepeatOp; + rc = RepeatOp::Builder(numRepeats).Build(&myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssociateNode(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + // Assign tree relations and root + rc = myRepeatOp->AddChild(myCacheOp); + EXPECT_TRUE(rc.IsOk()); + rc = myCacheOp->AddChild(myRandomDataOp); + EXPECT_TRUE(rc.IsOk()); + rc = myTree->AssignRoot(myRepeatOp); + EXPECT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration"; + rc = myTree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + + std::cout << *myClient << std::endl; + + rc = myTree->Launch(); + EXPECT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator dI(myTree); + TensorRow tensorList; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + int rowCount = 0; + while (!tensorList.empty()) { + // Don't display these rows, just count them + MS_LOG(INFO) << "Row fetched #: " << rowCount; + rc = dI.FetchNextTensorRow(&tensorList); + EXPECT_TRUE(rc.IsOk()); + rowCount++; + } + ASSERT_EQ(rowCount, 40); + rc = myClient->DestroyCache(); + EXPECT_TRUE(rc.IsOk()); +} diff --git a/tests/ut/cpp/dataset/celeba_op_test.cc b/tests/ut/cpp/dataset/celeba_op_test.cc index a109739fda..ccaed122f4 100644 --- a/tests/ut/cpp/dataset/celeba_op_test.cc +++ b/tests/ut/cpp/dataset/celeba_op_test.cc @@ -19,11 +19,11 @@ #include #include "common/common.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/celeba_op.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/center_crop_op_test.cc b/tests/ut/cpp/dataset/center_crop_op_test.cc index 54c45c957e..cd0f362f64 100644 --- a/tests/ut/cpp/dataset/center_crop_op_test.cc +++ b/tests/ut/cpp/dataset/center_crop_op_test.cc @@ -15,8 +15,8 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/center_crop_op.h" -#include "dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/core/cv_tensor.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/channel_swap_test.cc b/tests/ut/cpp/dataset/channel_swap_test.cc index f1dc1396ca..2000de15b2 100644 --- a/tests/ut/cpp/dataset/channel_swap_test.cc +++ b/tests/ut/cpp/dataset/channel_swap_test.cc @@ -15,8 +15,8 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/hwc_to_chw_op.h" -#include "dataset/core/data_type.h" +#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" +#include "minddata/dataset/core/data_type.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/cifar_op_test.cc b/tests/ut/cpp/dataset/cifar_op_test.cc index b37b9acaee..ed22f4f347 100644 --- a/tests/ut/cpp/dataset/cifar_op_test.cc +++ b/tests/ut/cpp/dataset/cifar_op_test.cc @@ -20,14 +20,14 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/circular_pool_test.cc b/tests/ut/cpp/dataset/circular_pool_test.cc index c42b08ddcd..d06f846684 100644 --- a/tests/ut/cpp/dataset/circular_pool_test.cc +++ b/tests/ut/cpp/dataset/circular_pool_test.cc @@ -15,9 +15,9 @@ */ #include #include -#include "dataset/util/task_manager.h" -#include "dataset/util/circular_pool.h" -#include "dataset/util/services.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/util/services.h" #include "common/common.h" #include "common/utils.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/client_config_test.cc b/tests/ut/cpp/dataset/client_config_test.cc index a907d50134..5cc9600b4e 100644 --- a/tests/ut/cpp/dataset/client_config_test.cc +++ b/tests/ut/cpp/dataset/client_config_test.cc @@ -20,11 +20,11 @@ #include #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "gtest/gtest.h" -#include "dataset/core/global_context.h" -#include "dataset/util/status.h" -#include "dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/clue_op_test.cc b/tests/ut/cpp/dataset/clue_op_test.cc index ff2f01a9ff..0935434a06 100644 --- a/tests/ut/cpp/dataset/clue_op_test.cc +++ b/tests/ut/cpp/dataset/clue_op_test.cc @@ -17,13 +17,13 @@ #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/util/status.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/util/status.h" namespace common = mindspore::common; diff --git a/tests/ut/cpp/dataset/coco_op_test.cc b/tests/ut/cpp/dataset/coco_op_test.cc index bcb82f8ec1..6e6d3c26e5 100644 --- a/tests/ut/cpp/dataset/coco_op_test.cc +++ b/tests/ut/cpp/dataset/coco_op_test.cc @@ -20,18 +20,18 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/common/bboxop_common.cc b/tests/ut/cpp/dataset/common/bboxop_common.cc index 70e6b5a339..62c9f85348 100644 --- a/tests/ut/cpp/dataset/common/bboxop_common.cc +++ b/tests/ut/cpp/dataset/common/bboxop_common.cc @@ -26,9 +26,9 @@ #include "./tinyxml2.h" #include "opencv2/opencv.hpp" #include "common/utils.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/util/path.h" -#include "dataset/core/constants.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/core/constants.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; @@ -66,17 +66,16 @@ void BBoxOpCommon::GetInputImagesAndAnnotations(const std::string &dir, std::siz MS_LOG(ERROR) << "Images folder was not found : " + images_path; EXPECT_TRUE(dir_path.Exists()); } - std::size_t files_fetched = 0; // get image file paths - while (image_dir_itr->hasNext() && files_fetched < num_of_samples) { + while (image_dir_itr->hasNext()) { Path image_path = image_dir_itr->next(); if (image_path.Extension() == std::string(kImageExt)) { paths_to_fetch.push_back(image_path.toString()); - files_fetched++; } } // sort fetched files std::sort(paths_to_fetch.begin(), paths_to_fetch.end()); + std::size_t files_fetched = 0; for (const auto &image_file : paths_to_fetch) { std::string image_ext = std::string(kImageExt); std::string annot_file = image_file; @@ -100,6 +99,10 @@ void BBoxOpCommon::GetInputImagesAndAnnotations(const std::string &dir, std::siz // add image and annotation to the tensor table TensorRow row_data({std::move(input_tensor_), std::move(annotation_tensor)}); images_and_annotations_.push_back(row_data); + files_fetched++; + if (files_fetched == num_of_samples) { + break; + } } } @@ -118,14 +121,11 @@ void BBoxOpCommon::SaveImagesWithAnnotations(BBoxOpCommon::FileType type, const bool passing_data_fetch = true; // For each bounding box draw on the image. for (uint32_t i = 0; i < num_of_boxes; i++) { - uint32_t x = 0; - uint32_t y = 0; - uint32_t w = 0; - uint32_t h = 0; - passing_data_fetch &= row[1]->GetUnsignedIntAt(&x, {i, 0}).IsOk(); - passing_data_fetch &= row[1]->GetUnsignedIntAt(&y, {i, 1}).IsOk(); - passing_data_fetch &= row[1]->GetUnsignedIntAt(&w, {i, 2}).IsOk(); - passing_data_fetch &= row[1]->GetUnsignedIntAt(&h, {i, 3}).IsOk(); + float x = 0.0, y = 0.0, w = 0.0, h = 0.0; + passing_data_fetch &= row[1]->GetItemAt(&x, {i, 0}).IsOk(); + passing_data_fetch &= row[1]->GetItemAt(&y, {i, 1}).IsOk(); + passing_data_fetch &= row[1]->GetItemAt(&w, {i, 2}).IsOk(); + passing_data_fetch &= row[1]->GetItemAt(&h, {i, 3}).IsOk(); if (!passing_data_fetch) { MS_LOG(ERROR) << "Fetching bbox coordinates failed in SaveImagesWithAnnotations."; EXPECT_TRUE(passing_data_fetch); @@ -193,24 +193,24 @@ bool BBoxOpCommon::LoadAnnotationFile(const std::string &path, std::shared_ptr return_value_list; + std::vector return_value_list; dsize_t bbox_count = 0; // keep track of number of bboxes in file dsize_t bbox_val_count = 4; // creating bboxes of size 4 to test function // FILE OK TO READ while (object != nullptr) { bbox_count += 1; std::string label_name; - uint32_t xmin = 0, ymin = 0, xmax = 0, ymax = 0; + float xmin = 0.0, ymin = 0.0, xmax = 0.0, ymax = 0.0; XMLElement *bbox_node = object->FirstChildElement("bndbox"); if (bbox_node != nullptr) { XMLElement *xmin_node = bbox_node->FirstChildElement("xmin"); - if (xmin_node != nullptr) xmin = xmin_node->UnsignedText(); + if (xmin_node != nullptr) xmin = xmin_node->FloatText(); XMLElement *ymin_node = bbox_node->FirstChildElement("ymin"); - if (ymin_node != nullptr) ymin = ymin_node->UnsignedText(); + if (ymin_node != nullptr) ymin = ymin_node->FloatText(); XMLElement *xmax_node = bbox_node->FirstChildElement("xmax"); - if (xmax_node != nullptr) xmax = xmax_node->UnsignedText(); + if (xmax_node != nullptr) xmax = xmax_node->FloatText(); XMLElement *ymax_node = bbox_node->FirstChildElement("ymax"); - if (ymax_node != nullptr) ymax = ymax_node->UnsignedText(); + if (ymax_node != nullptr) ymax = ymax_node->FloatText(); } else { MS_LOG(ERROR) << "bndbox dismatch in " + path; return false; diff --git a/tests/ut/cpp/dataset/common/bboxop_common.h b/tests/ut/cpp/dataset/common/bboxop_common.h index ba3ceb62d9..243908e7a3 100644 --- a/tests/ut/cpp/dataset/common/bboxop_common.h +++ b/tests/ut/cpp/dataset/common/bboxop_common.h @@ -17,7 +17,7 @@ #define TESTS_DATASET_UT_CORE_COMMON_DE_UT_BBOXOP_COMMON_H_ #include "cvop_common.h" -#include "dataset/util/path.h" +#include "minddata/dataset/util/path.h" namespace UT { namespace CVOP { diff --git a/tests/ut/cpp/dataset/common/cvop_common.cc b/tests/ut/cpp/dataset/common/cvop_common.cc index 6f66229e80..48d69564fd 100644 --- a/tests/ut/cpp/dataset/common/cvop_common.cc +++ b/tests/ut/cpp/dataset/common/cvop_common.cc @@ -18,9 +18,9 @@ #include #include #include "cvop_common.h" -#include "dataset/core/constants.h" +#include "minddata/dataset/core/constants.h" #include "common/utils.h" -#include "dataset/core/cv_tensor.h" +#include "minddata/dataset/core/cv_tensor.h" #include "utils/log_adapter.h" #include #include diff --git a/tests/ut/cpp/dataset/common/cvop_common.h b/tests/ut/cpp/dataset/common/cvop_common.h index 02c079fd68..59134091fd 100644 --- a/tests/ut/cpp/dataset/common/cvop_common.h +++ b/tests/ut/cpp/dataset/common/cvop_common.h @@ -19,7 +19,7 @@ #include #include #include "common.h" -#include "dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/image_utils.h" namespace UT { namespace CVOP { diff --git a/tests/ut/cpp/dataset/concat_op_test.cc b/tests/ut/cpp/dataset/concat_op_test.cc index 70d0268ec7..9e991ce0d3 100644 --- a/tests/ut/cpp/dataset/concat_op_test.cc +++ b/tests/ut/cpp/dataset/concat_op_test.cc @@ -19,7 +19,7 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/concatenate_op_test.cc b/tests/ut/cpp/dataset/concatenate_op_test.cc index 1ceedbac38..dc2fc69266 100644 --- a/tests/ut/cpp/dataset/concatenate_op_test.cc +++ b/tests/ut/cpp/dataset/concatenate_op_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common.h" -#include "dataset/kernels/data/concatenate_op.h" +#include "minddata/dataset/kernels/data/concatenate_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/connector_test.cc b/tests/ut/cpp/dataset/connector_test.cc index 7ee36cc2c0..0fc5b100d7 100644 --- a/tests/ut/cpp/dataset/connector_test.cc +++ b/tests/ut/cpp/dataset/connector_test.cc @@ -23,8 +23,8 @@ #include "common/common.h" -#include "dataset/engine/connector.h" -#include "dataset/util/task_manager.h" +#include "minddata/dataset/engine/connector.h" +#include "minddata/dataset/util/task_manager.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/cut_out_op_test.cc b/tests/ut/cpp/dataset/cut_out_op_test.cc index 462fb3a875..5d24d9c3f9 100644 --- a/tests/ut/cpp/dataset/cut_out_op_test.cc +++ b/tests/ut/cpp/dataset/cut_out_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/cut_out_op.h" +#include "minddata/dataset/kernels/image/cut_out_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/cyclic_array_test.cc b/tests/ut/cpp/dataset/cyclic_array_test.cc index 55f75c403f..380436de1b 100644 --- a/tests/ut/cpp/dataset/cyclic_array_test.cc +++ b/tests/ut/cpp/dataset/cyclic_array_test.cc @@ -19,7 +19,7 @@ #include "common/cvop_common.h" #include "gtest/gtest.h" #include "securec.h" -#include "dataset/engine/perf/cyclic_array.h" +#include "minddata/dataset/engine/perf/cyclic_array.h" #include using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/datatype_test.cc b/tests/ut/cpp/dataset/datatype_test.cc index a55853c4c5..b81618dc24 100644 --- a/tests/ut/cpp/dataset/datatype_test.cc +++ b/tests/ut/cpp/dataset/datatype_test.cc @@ -15,16 +15,14 @@ */ #include #include "./securec.h" -#include "dataset/core/data_type.h" +#include "minddata/dataset/core/data_type.h" #include "common/common.h" #include "gtest/gtest.h" #include -#include "dataset/core/constants.h" +#include "minddata/dataset/core/constants.h" using namespace mindspore::dataset; -namespace py = pybind11; - class MindDataTestDatatype : public UT::Common { public: MindDataTestDatatype() = default; diff --git a/tests/ut/cpp/dataset/decode_op_test.cc b/tests/ut/cpp/dataset/decode_op_test.cc index 7f3e129ac0..1cd03099ce 100644 --- a/tests/ut/cpp/dataset/decode_op_test.cc +++ b/tests/ut/cpp/dataset/decode_op_test.cc @@ -16,7 +16,7 @@ #include #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/duplicate_op_test.cc b/tests/ut/cpp/dataset/duplicate_op_test.cc index b7ce32f655..93779b084d 100644 --- a/tests/ut/cpp/dataset/duplicate_op_test.cc +++ b/tests/ut/cpp/dataset/duplicate_op_test.cc @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/duplicate_op.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/duplicate_op.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/execution_tree_test.cc b/tests/ut/cpp/dataset/execution_tree_test.cc index 529644331a..b871dd00d8 100644 --- a/tests/ut/cpp/dataset/execution_tree_test.cc +++ b/tests/ut/cpp/dataset/execution_tree_test.cc @@ -14,11 +14,11 @@ * limitations under the License. */ #include -#include "dataset/util/circular_pool.h" -#include "dataset/core/client.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/fill_op_test.cc b/tests/ut/cpp/dataset/fill_op_test.cc index d43b7d7548..20e323cc8d 100644 --- a/tests/ut/cpp/dataset/fill_op_test.cc +++ b/tests/ut/cpp/dataset/fill_op_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common.h" -#include "dataset/kernels/data/fill_op.h" +#include "minddata/dataset/kernels/data/fill_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/filter_op_test.cc b/tests/ut/cpp/dataset/filter_op_test.cc index 45ee714337..3e5be8dc04 100644 --- a/tests/ut/cpp/dataset/filter_op_test.cc +++ b/tests/ut/cpp/dataset/filter_op_test.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/util/circular_pool.h" -#include "dataset/core/client.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/global_context_test.cc b/tests/ut/cpp/dataset/global_context_test.cc index bb75d941aa..cd4c970ae6 100644 --- a/tests/ut/cpp/dataset/global_context_test.cc +++ b/tests/ut/cpp/dataset/global_context_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/core/global_context.h" +#include "minddata/dataset/core/global_context.h" #include "common/common.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index dc74e66b0c..c4dd7b055c 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -20,9 +20,9 @@ #include "common/common.h" #include "gtest/gtest.h" -#include "dataset/util/status.h" -#include "dataset/engine/gnn/node.h" -#include "dataset/engine/gnn/graph_loader.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/engine/gnn/graph_loader.h" using namespace mindspore::dataset; using namespace mindspore::dataset::gnn; @@ -49,9 +49,10 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) { EdgeTypeMap e_type_map; NodeFeatureMap n_feature_map; EdgeFeatureMap e_feature_map; - DefaultFeatureMap default_feature_map; + DefaultNodeFeatureMap default_node_feature_map; + DefaultEdgeFeatureMap default_edge_feature_map; EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map, - &default_feature_map) + &default_node_feature_map, &default_edge_feature_map) .IsOk()); EXPECT_EQ(n_id_map.size(), 20); EXPECT_EQ(e_id_map.size(), 40); @@ -119,6 +120,17 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { std::transform(edges->begin(), edges->end(), edge_list.begin(), [](const EdgeIdType edge) { return edge; }); + TensorRow edge_features; + s = graph.GetEdgeFeature(edges, meta_info.edge_feature_type, &edge_features); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(edge_features[0]->ToString() == + "Tensor (shape: <40>, Type: int32)\n" + "[0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0]"); + EXPECT_TRUE(edge_features[1]->ToString() == + "Tensor (shape: <40>, Type: float32)\n" + "[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2." + "7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4]"); + std::shared_ptr nodes; s = graph.GetNodesFromEdges(edge_list, &nodes); EXPECT_TRUE(s.IsOk()); @@ -247,4 +259,30 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) { s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path); EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); -} \ No newline at end of file +} + +TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) { + std::string path = "data/mindrecord/testGraphData/sns"; + Graph graph(path, 1); + Status s = graph.Init(); + EXPECT_TRUE(s.IsOk()); + + MetaInfo meta_info; + s = graph.GetMetaInfo(&meta_info); + EXPECT_TRUE(s.IsOk()); + + std::shared_ptr nodes; + s = graph.GetAllNodes(meta_info.node_type[0], &nodes); + EXPECT_TRUE(s.IsOk()); + std::vector node_list; + for (auto itr = nodes->begin(); itr != nodes->end(); ++itr) { + node_list.push_back(*itr); + } + + print_int_vec(node_list, "node list "); + std::vector meta_path(59, 1); + std::shared_ptr walk_path; + s = graph.RandomWalk(node_list, meta_path, 1.0, 1.0, -1, &walk_path); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>"); +} diff --git a/tests/ut/cpp/dataset/image_folder_op_test.cc b/tests/ut/cpp/dataset/image_folder_op_test.cc index 576c5abbfc..3168efa196 100644 --- a/tests/ut/cpp/dataset/image_folder_op_test.cc +++ b/tests/ut/cpp/dataset/image_folder_op_test.cc @@ -19,18 +19,18 @@ #include #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/interrupt_test.cc b/tests/ut/cpp/dataset/interrupt_test.cc index 7ab608b9ae..8a06413175 100644 --- a/tests/ut/cpp/dataset/interrupt_test.cc +++ b/tests/ut/cpp/dataset/interrupt_test.cc @@ -15,10 +15,10 @@ */ #include "common/common.h" #include "utils/log_adapter.h" -#include "dataset/util/services.h" -#include "dataset/util/intrp_service.h" -#include "dataset/util/task_manager.h" -#include "dataset/util/queue.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/intrp_service.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/queue.h" using namespace mindspore::dataset; using mindspore::MsLogLevel::INFO; diff --git a/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc b/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc index c5a733f285..85b3384d36 100644 --- a/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc +++ b/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc @@ -18,7 +18,7 @@ #include #include "common/common.h" -#include "dataset/text/kernels/jieba_tokenizer_op.h" +#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" @@ -39,21 +39,22 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opFuntions) { std::string dataset_path = datasets_root_path_ + "/jiebadict"; std::string hmm_path = dataset_path + "/hmm_model.utf8"; std::string mp_path = dataset_path + "/jieba.dict.utf8"; - std::shared_ptr output_tensor; + TensorRow input, output; std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); std::shared_ptr input_tensor = std::make_shared("今天天气太好了我们一起去外面玩吧"); - Status s = op->Compute(input_tensor, &output_tensor); + input.push_back(input_tensor); + Status s = op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output_tensor->Rank(), 1); - EXPECT_EQ(output_tensor->Size(), 7); - CheckEqual(output_tensor, {0}, "今天天气"); - CheckEqual(output_tensor, {1}, "太好了"); - CheckEqual(output_tensor, {2}, "我们"); - CheckEqual(output_tensor, {3}, "一起"); - CheckEqual(output_tensor, {4}, "去"); - CheckEqual(output_tensor, {5}, "外面"); - CheckEqual(output_tensor, {6}, "玩吧"); + EXPECT_EQ(output[0]->Rank(), 1); + EXPECT_EQ(output[0]->Size(), 7); + CheckEqual(output[0], {0}, "今天天气"); + CheckEqual(output[0], {1}, "太好了"); + CheckEqual(output[0], {2}, "我们"); + CheckEqual(output[0], {3}, "一起"); + CheckEqual(output[0], {4}, "去"); + CheckEqual(output[0], {5}, "外面"); + CheckEqual(output[0], {6}, "玩吧"); } TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opAdd) { @@ -61,16 +62,17 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opAdd) { std::string dataset_path = datasets_root_path_ + "/jiebadict"; std::string hmm_path = dataset_path + "/hmm_model.utf8"; std::string mp_path = dataset_path + "/jieba.dict.utf8"; - std::shared_ptr output_tensor; + TensorRow input, output; std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); op->AddWord("男默女泪"); std::shared_ptr input_tensor = std::make_shared("男默女泪"); - Status s = op->Compute(input_tensor, &output_tensor); + input.push_back(input_tensor); + Status s = op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output_tensor->Rank(), 1); - EXPECT_EQ(output_tensor->Size(), 1); - CheckEqual(output_tensor, {0}, "男默女泪"); + EXPECT_EQ(output[0]->Rank(), 1); + EXPECT_EQ(output[0]->Size(), 1); + CheckEqual(output[0], {0}, "男默女泪"); } TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opEmpty) { @@ -78,14 +80,15 @@ TEST_F(MindDataTestJiebaTokenizerOp, TestJieba_opEmpty) { std::string dataset_path = datasets_root_path_ + "/jiebadict"; std::string hmm_path = dataset_path + "/hmm_model.utf8"; std::string mp_path = dataset_path + "/jieba.dict.utf8"; - std::shared_ptr output_tensor; + TensorRow input, output; std::unique_ptr op(new JiebaTokenizerOp(hmm_path, mp_path)); op->AddWord("男默女泪"); std::shared_ptr input_tensor = std::make_shared(""); - Status s = op->Compute(input_tensor, &output_tensor); + input.push_back(input_tensor); + Status s = op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output_tensor->Rank(), 1); - EXPECT_EQ(output_tensor->Size(), 1); - CheckEqual(output_tensor, {0}, ""); + EXPECT_EQ(output[0]->Rank(), 1); + EXPECT_EQ(output[0]->Size(), 1); + CheckEqual(output[0], {0}, ""); } \ No newline at end of file diff --git a/tests/ut/cpp/dataset/manifest_op_test.cc b/tests/ut/cpp/dataset/manifest_op_test.cc index 6317a6a345..a6eef4aaa2 100644 --- a/tests/ut/cpp/dataset/manifest_op_test.cc +++ b/tests/ut/cpp/dataset/manifest_op_test.cc @@ -20,12 +20,12 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/map_op_test.cc b/tests/ut/cpp/dataset/map_op_test.cc index 8b6a152488..4e9cfe9ec9 100644 --- a/tests/ut/cpp/dataset/map_op_test.cc +++ b/tests/ut/cpp/dataset/map_op_test.cc @@ -17,13 +17,14 @@ #include #include + #include "common/common.h" -#include "dataset/core/client.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/kernels/image/decode_op.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/tensor_op.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/tensor_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; @@ -35,93 +36,99 @@ namespace dataset { namespace test { class NoOp : public TensorOp { public: - NoOp() {}; + NoOp(){}; + + ~NoOp(){}; - ~NoOp() {}; + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override { + *output = std::move(input); + return Status::OK(); + }; - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override { - *output = std::move(input); - return Status::OK(); - }; + void Print(std::ostream &out) const override { out << "NoOp"; }; - void Print(std::ostream &out) const override { out << "NoOp"; }; + std::string Name() const override { return kNoOp; } }; class ThreeToOneOp : public TensorOp { public: - ThreeToOneOp() {}; + ThreeToOneOp(){}; + + ~ThreeToOneOp(){}; - ~ThreeToOneOp() {}; + uint32_t NumInput() override { return 3; } + // Compute function that holds the actual implementation of the operation. + Status Compute(const TensorRow &input, TensorRow *output) override { + output->push_back(input[0]); + return Status::OK(); + }; - uint32_t NumInput() override { return 3; } - // Compute function that holds the actual implementation of the operation. - Status Compute(const TensorRow &input, TensorRow *output) override { - output->push_back(input[0]); - return Status::OK(); - }; + void Print(std::ostream &out) const override { out << "ThreeToOneOp"; }; - void Print(std::ostream &out) const override { out << "ThreeToOneOp"; }; + std::string Name() const override { return "ThreeToOneOp"; } }; class OneToThreeOp : public TensorOp { public: - OneToThreeOp() {}; + OneToThreeOp(){}; - ~OneToThreeOp() {}; + ~OneToThreeOp(){}; uint32_t NumOutput() override { return 3; } - // Compute function that holds the actual implementation of the operation. - // Simply pushing the same shared pointer of the first element of input vector three times. - Status Compute(const TensorRow &input, TensorRow *output) override { - output->push_back(input[0]); - output->push_back(input[0]); - output->push_back(input[0]); - return Status::OK(); - }; + // Compute function that holds the actual implementation of the operation. + // Simply pushing the same shared pointer of the first element of input vector three times. + Status Compute(const TensorRow &input, TensorRow *output) override { + output->push_back(input[0]); + output->push_back(input[0]); + output->push_back(input[0]); + return Status::OK(); + }; - void Print(std::ostream &out) const override { out << "OneToThreeOp"; }; + void Print(std::ostream &out) const override { out << "OneToThreeOp"; }; + + std::string Name() const override { return "OneToThreeOp"; }; }; } // namespace test } // namespace dataset } // namespace mindspore - class MindDataTestMapOp : public UT::DatasetOpTesting { public: - void SetUp() override { - DatasetOpTesting::SetUp(); - dataset_path_ = datasets_root_path_ + "" + "/testDataset2/testDataset2.data"; - schema_path_ = datasets_root_path_ + "" + "/testDataset2/datasetSchema.json"; + void SetUp() override { + DatasetOpTesting::SetUp(); + dataset_path_ = datasets_root_path_ + "" + "/testDataset2/testDataset2.data"; + schema_path_ = datasets_root_path_ + "" + "/testDataset2/datasetSchema.json"; - GlobalInit(); + GlobalInit(); - // Start with an empty execution tree - my_tree_ = std::make_shared(); - } + // Start with an empty execution tree + my_tree_ = std::make_shared(); + } - std::shared_ptr CreateTFReaderOp() { - std::shared_ptr my_tfreader_op; - TFReaderOp::Builder builder; - builder.SetDatasetFilesList({dataset_path_}) - .SetColumnsToLoad({"image", "label", "A", "B"}) - .SetRowsPerBuffer(2) - .SetWorkerConnectorSize(2) - .SetNumWorkers(2); - - std::unique_ptr schema = std::make_unique(); - schema->LoadSchemaFile(schema_path_, {}); - builder.SetDataSchema(std::move(schema)); - - Status rc = builder.Build(&my_tfreader_op); - EXPECT_TRUE(rc.IsOk()); - return my_tfreader_op; - } + std::shared_ptr CreateTFReaderOp() { + std::shared_ptr my_tfreader_op; + TFReaderOp::Builder builder; + builder.SetDatasetFilesList({dataset_path_}) + .SetColumnsToLoad({"image", "label", "A", "B"}) + .SetRowsPerBuffer(2) + .SetWorkerConnectorSize(2) + .SetNumWorkers(2); + + std::unique_ptr schema = std::make_unique(); + schema->LoadSchemaFile(schema_path_, {}); + builder.SetDataSchema(std::move(schema)); + + Status rc = builder.Build(&my_tfreader_op); + EXPECT_TRUE(rc.IsOk()); + return my_tfreader_op; + } + + std::shared_ptr my_tree_; - std::shared_ptr my_tree_; private: - std::string dataset_path_; - std::string schema_path_; + std::string dataset_path_; + std::string schema_path_; }; std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, @@ -148,10 +155,7 @@ TEST_F(MindDataTestMapOp, TestAsMap) { my_func_list.push_back(my_no_op); std::shared_ptr my_map_op; MapOp::Builder builder; - builder.SetInColNames({"image"}) - .SetOutColNames({"X"}) - .SetTensorFuncs(std::move(my_func_list)) - .SetNumWorkers(1); + builder.SetInColNames({"image"}).SetOutColNames({"X"}).SetTensorFuncs(std::move(my_func_list)).SetNumWorkers(1); rc = builder.Build(&my_map_op); rc = my_tree_->AssociateNode(my_map_op); EXPECT_TRUE(rc.IsOk()); @@ -200,9 +204,9 @@ TEST_F(MindDataTestMapOp, Test3to1) { std::shared_ptr my_map_op; MapOp::Builder builder; builder.SetInColNames({"image", "A", "B"}) - .SetOutColNames({"X"}) - .SetTensorFuncs(std::move(my_func_list)) - .SetNumWorkers(1); + .SetOutColNames({"X"}) + .SetTensorFuncs(std::move(my_func_list)) + .SetNumWorkers(1); rc = builder.Build(&my_map_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree_->AssociateNode(my_map_op); @@ -252,10 +256,9 @@ TEST_F(MindDataTestMapOp, Test1to3) { std::shared_ptr my_map_op; MapOp::Builder builder; builder.SetInColNames({"image"}) - .SetOutColNames({"X", "Y", "Z"}) - .SetTensorFuncs(std::move(my_func_list)) - .SetNumWorkers(1); - + .SetOutColNames({"X", "Y", "Z"}) + .SetTensorFuncs(std::move(my_func_list)) + .SetNumWorkers(1); // ProjectOp std::vector columns_to_project = {"X", "Y", "Z", "label", "A", "B"}; @@ -296,19 +299,18 @@ TEST_F(MindDataTestMapOp, Test1to3) { // Getting the next row as vector (by position). TensorRow tensor_list; - rc =di.FetchNextTensorRow(&tensor_list); + rc = di.FetchNextTensorRow(&tensor_list); EXPECT_TRUE(rc.IsOk()); // Based on the schema file, create the golden result to compare with. std::vector golden_types({DataType::Type::DE_UINT8, DataType::Type::DE_UINT8, DataType::Type::DE_UINT8, DataType::Type::DE_INT64, - DataType::Type::DE_FLOAT32, DataType::Type::DE_INT64} - ); + DataType::Type::DE_FLOAT32, DataType::Type::DE_INT64}); std::vector golden_ranks({3, 3, 3, 1, 4, 1}); std::vector golden_shapes({TensorShape({3, 4, 2}), TensorShape({3, 4, 2}), TensorShape({3, 4, 2}), - TensorShape({7}), TensorShape({1, 13, 14, 12}), TensorShape({9})} ); + TensorShape({7}), TensorShape({1, 13, 14, 12}), TensorShape({9})}); while (!tensor_list.empty()) { for (uint32_t i = 0; i < tensor_list.size(); i++) { @@ -343,9 +345,9 @@ TEST_F(MindDataTestMapOp, TestMultiTensorOp) { std::shared_ptr my_map_op; MapOp::Builder builder; builder.SetInColNames({"image", "A", "B"}) - .SetOutColNames({"X", "Y", "Z"}) - .SetTensorFuncs(std::move(my_func_list)) - .SetNumWorkers(1); + .SetOutColNames({"X", "Y", "Z"}) + .SetTensorFuncs(std::move(my_func_list)) + .SetNumWorkers(1); rc = builder.Build(&my_map_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree_->AssociateNode(my_map_op); @@ -405,10 +407,7 @@ TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) { std::shared_ptr my_map_op; MapOp::Builder builder; - builder.SetInColNames({"label"}) - .SetOutColNames({}) - .SetTensorFuncs(std::move(my_func_list)) - .SetNumWorkers(5); + builder.SetInColNames({"label"}).SetOutColNames({}).SetTensorFuncs(std::move(my_func_list)).SetNumWorkers(5); rc = builder.Build(&my_map_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree_->AssociateNode(my_map_op); @@ -440,7 +439,6 @@ TEST_F(MindDataTestMapOp, TestTFReaderRepeatMap) { MS_LOG(INFO) << "row_count: " << row_count << "."; rc = di.FetchNextTensorRow(&tensor_list); EXPECT_TRUE(rc.IsOk()); - } ASSERT_EQ(row_count, 10 * num_repeats); } @@ -467,10 +465,7 @@ TEST_F(MindDataTestMapOp, TestTFReaderMapRepeat) { std::shared_ptr my_map_op; MapOp::Builder builder; - builder.SetInColNames({"label"}) - .SetOutColNames({}) - .SetTensorFuncs(std::move(my_func_list)) - .SetNumWorkers(50); + builder.SetInColNames({"label"}).SetOutColNames({}).SetTensorFuncs(std::move(my_func_list)).SetNumWorkers(50); rc = builder.Build(&my_map_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree_->AssociateNode(my_map_op); @@ -536,25 +531,18 @@ TEST_F(MindDataTestMapOp, TFReader_Decode_Repeat_Resize) { std::shared_ptr my_map_decode_op; MapOp::Builder builder; - builder.SetInColNames({"image"}) - .SetOutColNames({}) - .SetTensorFuncs(std::move(my_func_list)) - .SetNumWorkers(4); + builder.SetInColNames({"image"}).SetOutColNames({}).SetTensorFuncs(std::move(my_func_list)).SetNumWorkers(4); rc = builder.Build(&my_map_decode_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree_->AssociateNode(my_map_decode_op); EXPECT_TRUE(rc.IsOk()); - auto resize_op = std::make_shared(300, 300); std::vector> my_func_list2; my_func_list2.push_back(resize_op); std::shared_ptr my_map_resize_op; MapOp::Builder builder2; - builder2.SetInColNames({"image"}) - .SetOutColNames({}) - .SetTensorFuncs(std::move(my_func_list2)) - .SetNumWorkers(5); + builder2.SetInColNames({"image"}).SetOutColNames({}).SetTensorFuncs(std::move(my_func_list2)).SetNumWorkers(5); rc = builder2.Build(&my_map_resize_op); EXPECT_TRUE(rc.IsOk()); rc = my_tree_->AssociateNode(my_map_resize_op); @@ -610,10 +598,7 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) { std::shared_ptr map_decode_map; MapOp::Builder map_decode_builder; - map_decode_builder.SetInColNames({"image"}) - .SetOutColNames({}) - .SetTensorFuncs(func_list) - .SetNumWorkers(4); + map_decode_builder.SetInColNames({"image"}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4); rc = map_decode_builder.Build(&map_decode_map); EXPECT_TRUE(rc.IsOk()); @@ -622,10 +607,7 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) { func_list2.push_back(resize_op); std::shared_ptr map_resize_op; MapOp::Builder map_resize_builder; - map_resize_builder.SetInColNames({"image"}) - .SetOutColNames({}) - .SetTensorFuncs(func_list2) - .SetNumWorkers(5); + map_resize_builder.SetInColNames({"image"}).SetOutColNames({}).SetTensorFuncs(func_list2).SetNumWorkers(5); rc = map_resize_builder.Build(&map_resize_op); EXPECT_TRUE(rc.IsOk()); @@ -704,7 +686,6 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize) { EXPECT_EQ(result, result2); } - TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) { Status rc; MS_LOG(INFO) << "Doing ImageFolder_Decode_Repeat_Resize_NoInputColumns."; @@ -722,10 +703,7 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) { std::shared_ptr map_decode_map; MapOp::Builder map_decode_builder; - map_decode_builder.SetInColNames({}) - .SetOutColNames({}) - .SetTensorFuncs(func_list) - .SetNumWorkers(4); + map_decode_builder.SetInColNames({}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4); rc = map_decode_builder.Build(&map_decode_map); EXPECT_TRUE(rc.IsOk()); @@ -761,3 +739,5 @@ TEST_F(MindDataTestMapOp, ImageFolder_Decode_Repeat_Resize_NoInputColumns) { } EXPECT_TRUE(i == 88); } + + diff --git a/tests/ut/cpp/dataset/mask_test.cc b/tests/ut/cpp/dataset/mask_test.cc index 9ff5f51fce..609d5bf447 100644 --- a/tests/ut/cpp/dataset/mask_test.cc +++ b/tests/ut/cpp/dataset/mask_test.cc @@ -15,15 +15,15 @@ */ #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "securec.h" -#include "dataset/core/tensor.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/data_type.h" -#include "dataset/kernels/data/mask_op.h" -#include "dataset/kernels/data/data_utils.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/kernels/data/mask_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/memory_pool_test.cc b/tests/ut/cpp/dataset/memory_pool_test.cc index 136f3fe1b8..b5907655dc 100644 --- a/tests/ut/cpp/dataset/memory_pool_test.cc +++ b/tests/ut/cpp/dataset/memory_pool_test.cc @@ -14,10 +14,10 @@ * limitations under the License. */ -#include "dataset/util/memory_pool.h" -#include "dataset/util/circular_pool.h" -#include "dataset/util/system_pool.h" -#include "dataset/util/allocator.h" +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/util/system_pool.h" +#include "minddata/dataset/util/allocator.h" #include "common/common.h" #include "gtest/gtest.h" diff --git a/tests/ut/cpp/dataset/mind_record_op_test.cc b/tests/ut/cpp/dataset/mind_record_op_test.cc index b2cbdf027e..c9067535d6 100644 --- a/tests/ut/cpp/dataset/mind_record_op_test.cc +++ b/tests/ut/cpp/dataset/mind_record_op_test.cc @@ -16,14 +16,14 @@ #include #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" -#include "mindrecord/include/shard_category.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_shuffle.h" +#include "minddata/mindrecord/include/shard_category.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" #include "utils/log_adapter.h" namespace common = mindspore::common; diff --git a/tests/ut/cpp/dataset/mnist_op_test.cc b/tests/ut/cpp/dataset/mnist_op_test.cc index da78cb6f7f..dfceeaa06a 100644 --- a/tests/ut/cpp/dataset/mnist_op_test.cc +++ b/tests/ut/cpp/dataset/mnist_op_test.cc @@ -20,18 +20,18 @@ #include "common/utils.h" #include "common/common.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/normalize_op_test.cc b/tests/ut/cpp/dataset/normalize_op_test.cc index 05ac3f6289..31791e0e66 100644 --- a/tests/ut/cpp/dataset/normalize_op_test.cc +++ b/tests/ut/cpp/dataset/normalize_op_test.cc @@ -15,8 +15,8 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/normalize_op.h" -#include "dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/normalize_op.h" +#include "minddata/dataset/core/cv_tensor.h" #include "utils/log_adapter.h" #include diff --git a/tests/ut/cpp/dataset/one_hot_op_test.cc b/tests/ut/cpp/dataset/one_hot_op_test.cc index c414e371e5..2617ae4536 100644 --- a/tests/ut/cpp/dataset/one_hot_op_test.cc +++ b/tests/ut/cpp/dataset/one_hot_op_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common.h" -#include "dataset/kernels/data/one_hot_op.h" +#include "minddata/dataset/kernels/data/one_hot_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/pad_end_op_test.cc b/tests/ut/cpp/dataset/pad_end_op_test.cc index 2787501aa9..1c838da8e8 100644 --- a/tests/ut/cpp/dataset/pad_end_op_test.cc +++ b/tests/ut/cpp/dataset/pad_end_op_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common.h" -#include "dataset/kernels/data/pad_end_op.h" +#include "minddata/dataset/kernels/data/pad_end_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/pad_op_test.cc b/tests/ut/cpp/dataset/pad_op_test.cc index b659d009f3..e2bd822d02 100644 --- a/tests/ut/cpp/dataset/pad_op_test.cc +++ b/tests/ut/cpp/dataset/pad_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/pad_op.h" +#include "minddata/dataset/kernels/image/pad_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/path_test.cc b/tests/ut/cpp/dataset/path_test.cc index 4cf3b17968..b36b38bbc7 100644 --- a/tests/ut/cpp/dataset/path_test.cc +++ b/tests/ut/cpp/dataset/path_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/util/path.h" +#include "minddata/dataset/util/path.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/perf_data_test.cc b/tests/ut/cpp/dataset/perf_data_test.cc index 048ee1f21a..486209be21 100644 --- a/tests/ut/cpp/dataset/perf_data_test.cc +++ b/tests/ut/cpp/dataset/perf_data_test.cc @@ -17,8 +17,8 @@ #include "common/cvop_common.h" #include "gtest/gtest.h" #include "securec.h" -#include "dataset/engine/perf/cyclic_array.h" -#include "dataset/engine/perf/perf_data.h" +#include "minddata/dataset/engine/perf/cyclic_array.h" +#include "minddata/dataset/engine/perf/perf_data.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/project_op_test.cc b/tests/ut/cpp/dataset/project_op_test.cc index 484396321c..45ef11b88f 100644 --- a/tests/ut/cpp/dataset/project_op_test.cc +++ b/tests/ut/cpp/dataset/project_op_test.cc @@ -19,7 +19,7 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/queue_test.cc b/tests/ut/cpp/dataset/queue_test.cc index 578405e537..ec40cc2ae4 100644 --- a/tests/ut/cpp/dataset/queue_test.cc +++ b/tests/ut/cpp/dataset/queue_test.cc @@ -16,9 +16,11 @@ #include "common/common.h" #include "gtest/gtest.h" -#include "dataset/util/task_manager.h" -#include "dataset/util/queue.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/queue.h" #include +#include +#include #include "utils/log_adapter.h" using namespace mindspore::dataset; @@ -39,7 +41,7 @@ class RefCount { public: RefCount() : v_(nullptr) {} explicit RefCount(int x) : v_(std::make_shared(x)) {} - explicit RefCount(const RefCount &o) : v_(o.v_) {} + RefCount(const RefCount &o) : v_(o.v_) {} ~RefCount() { MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl; gRefCountDestructorCalled++; @@ -167,3 +169,70 @@ TEST_F(MindDataTestQueue, Test6) { MS_LOG(INFO) << "Popped value " << *pepped_value << " from queue index " << chosen_queue_index; ASSERT_EQ(*pepped_value, 99); } +using namespace std::chrono; +template +void Perf(int n, int p, std::string name) { + auto payload = std::vector(n, PayloadType(p)); + auto queue = QueueType(n); + auto t0 = high_resolution_clock::now(); + auto check = 0; + for (int i = 0; i < queue.capacity(); i++) { + queue.Add(PayloadType(p)); + } + check = queue.size(); + for (int i = 0; i < queue.capacity(); i++) { + queue.PopFront(&payload[i]); + } + auto t1 = high_resolution_clock::now(); + std::cout << name << " queue filled size: " << queue.size() << " " << check << std::endl; + auto t2 = high_resolution_clock::now(); + for (int i = 0; i < queue.capacity(); i++) { + queue.Add(PayloadType(p)); + } + check = queue.size(); + for (int i = 0; i < queue.capacity(); i++) { + queue.PopFront(&payload[i]); + } + auto t3 = high_resolution_clock::now(); + auto d = duration_cast(t3 - t2 + t1 - t0).count(); + std::cout << name << " queue emptied size: " << queue.size() << " " << check << std::endl; + std::cout << name << " " + << " ran in " << d << "ms" << std::endl; +} + +template +void Fuzz(int n, int p, std::string name) { + std::mt19937 gen(1); + auto payload = std::vector(n, PayloadType(p)); + auto queue = QueueType(n); + auto dist = std::uniform_int_distribution(0, 2); + std::cout << "###" << std::endl; + for (auto i = 0; i < n; i++) { + auto v = dist(gen); + if (v == 0 && queue.size() < n - 1) { + queue.Add(std::move(payload[i])); + } + if (v == 1 && queue.size() > 0) { + queue.PopFront(&payload[i]); + } else { + queue.Reset(); + } + } + std::cout << name << " fuzz ran " << queue.size() << std::endl; +} +TEST_F(MindDataTestQueue, TestPerf) { + try { + int kSz = 1000000; + // std::cout << "enter size" << std::endl; + // std::cin >> kSz; + Perf>, std::vector>(kSz, 1, "old queue, vector of size 1"); + } catch (const std::exception &e) { + std::cout << e.what() << std::endl; + } + + std::cout << "Test Reset" << std::endl; + std::cout << "Enter fuzz size" << std::endl; + int fs = 1000; +// std::cin >> fs; + Fuzz>, std::vector>(fs, 1, "New queue"); +} diff --git a/tests/ut/cpp/dataset/random_color_adjust_op_test.cc b/tests/ut/cpp/dataset/random_color_adjust_op_test.cc index 82df108ad1..96f4dd8145 100644 --- a/tests/ut/cpp/dataset/random_color_adjust_op_test.cc +++ b/tests/ut/cpp/dataset/random_color_adjust_op_test.cc @@ -15,8 +15,8 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/random_color_adjust_op.h" -#include "dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/random_color_adjust_op.h" +#include "minddata/dataset/core/cv_tensor.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/random_crop_and_resize_op_test.cc b/tests/ut/cpp/dataset/random_crop_and_resize_op_test.cc index 3d5298b071..fd59a90117 100644 --- a/tests/ut/cpp/dataset/random_crop_and_resize_op_test.cc +++ b/tests/ut/cpp/dataset/random_crop_and_resize_op_test.cc @@ -16,7 +16,7 @@ #include "common/common.h" #include "common/cvop_common.h" #include -#include "dataset/kernels/image/random_crop_and_resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/random_crop_and_resize_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_crop_and_resize_with_bbox_op_test.cc new file mode 100644 index 0000000000..4efdcb8b78 --- /dev/null +++ b/tests/ut/cpp/dataset/random_crop_and_resize_with_bbox_op_test.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/bboxop_common.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" +#include "utils/log_adapter.h" + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +const bool kSaveExpected = false; +const char kOpName[] = "RandomResizedCropWithBBox_C"; + +class MindDataTestRandomCropAndResizeWithBBoxOp : public UT::CVOP::BBOXOP::BBoxOpCommon { + protected: + MindDataTestRandomCropAndResizeWithBBoxOp() : BBoxOpCommon() {} +}; + +TEST_F(MindDataTestRandomCropAndResizeWithBBoxOp, TestOp1) { + MS_LOG(INFO) << "Doing testRandomCropAndResizeWithBBoxOp1."; + // setting seed here + uint32_t current_seed = GlobalContext::config_manager()->seed(); + GlobalContext::config_manager()->set_seed(327362); + TensorRow output_tensor_row_; + TensorTable results; + int h_out = 1024; + int w_out = 2048; + float aspect_lb = 2; + float aspect_ub = 2.5; + float scale_lb = 0.2; + float scale_ub = 2.0; + auto op = std::make_unique(h_out, w_out, scale_lb, scale_ub, aspect_lb, aspect_ub); + Status s; + for (auto tensor_row_ : images_and_annotations_) { + s = op->Compute(tensor_row_, &output_tensor_row_); + EXPECT_TRUE(s.IsOk()); + results.push_back(output_tensor_row_); + } + if (kSaveExpected) { + SaveImagesWithAnnotations(FileType::kExpected, std::string(kOpName), results); + } + SaveImagesWithAnnotations(FileType::kActual, std::string(kOpName), results); + if (!kSaveExpected) { + CompareActualAndExpected(std::string(kOpName)); + } + GlobalContext::config_manager()->set_seed(current_seed); +} + +TEST_F(MindDataTestRandomCropAndResizeWithBBoxOp, TestOp2) { + MS_LOG(INFO) << "Doing testRandomCropAndResizeWithBBoxOp2."; + TensorRow output_tensor_row_; + int h_out = 1024; + int w_out = 2048; + float aspect_lb = 1; + float aspect_ub = 1.5; + float scale_lb = 0.2; + float scale_ub = 2.0; + auto op = std::make_unique(h_out, w_out, scale_lb, scale_ub, aspect_lb, aspect_ub); + Status s; + for (auto tensor_row_ : images_and_annotations_) { + s = op->Compute(tensor_row_, &output_tensor_row_); + EXPECT_TRUE(s.IsOk()); + } +} + +TEST_F(MindDataTestRandomCropAndResizeWithBBoxOp, TestOp3) { + MS_LOG(INFO) << "Doing testRandomCropAndResizeWithBBoxOp3."; + TensorRow output_tensor_row_; + int h_out = 1024; + int w_out = 2048; + float aspect_lb = 0.2; + float aspect_ub = 3; + float scale_lb = 0.2; + float scale_ub = 2.0; + auto op = std::make_unique(h_out, w_out, scale_lb, scale_ub, aspect_lb, aspect_ub); + Status s; + for (auto tensor_row_ : images_and_annotations_) { + s = op->Compute(tensor_row_, &output_tensor_row_); + EXPECT_TRUE(s.IsOk()); + } + MS_LOG(INFO) << "testRandomCropAndResizeWithBBoxOp end."; +} \ No newline at end of file diff --git a/tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc b/tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc index 1c9f3a98dc..170525b4e7 100644 --- a/tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc +++ b/tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc @@ -16,10 +16,10 @@ #include #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/decode_op.h" -#include "dataset/kernels/image/random_crop_and_resize_op.h" -#include "dataset/kernels/image/random_crop_decode_resize_op.h" -#include "dataset/core/config_manager.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" +#include "minddata/dataset/core/config_manager.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; @@ -54,7 +54,7 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp2) { auto decode_and_crop = static_cast(crop_and_decode_copy); EXPECT_TRUE(crop_and_decode.OneToOne()); GlobalContext::config_manager()->set_seed(42); - for (int k = 0; k < 100; k++) { + for (int k = 0; k < 10; k++) { (void)crop_and_decode.Compute(raw_input_tensor_, &crop_and_decode_output); (void)decode_and_crop.Compute(input_tensor_, &decode_and_crop_output); cv::Mat output1 = CVTensor::AsCVTensor(crop_and_decode_output)->mat().clone(); @@ -104,10 +104,10 @@ TEST_F(MindDataTestRandomCropDecodeResizeOp, TestOp1) { int mse_sum, m1, m2, count; double mse; - for (int k = 0; k < 100; ++k) { + for (int k = 0; k < 10; ++k) { mse_sum = 0; count = 0; - for (auto i = 0; i < 100; i++) { + for (auto i = 0; i < 10; i++) { scale = rd_scale(rd); aspect = rd_aspect(rd); crop_width = std::round(std::sqrt(h * w * scale / aspect)); diff --git a/tests/ut/cpp/dataset/random_crop_op_test.cc b/tests/ut/cpp/dataset/random_crop_op_test.cc index 2f3b19e2f4..9c8f1f31ed 100644 --- a/tests/ut/cpp/dataset/random_crop_op_test.cc +++ b/tests/ut/cpp/dataset/random_crop_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/random_crop_op.h" +#include "minddata/dataset/kernels/image/random_crop_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc new file mode 100644 index 0000000000..fcf8ba2605 --- /dev/null +++ b/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/bboxop_common.h" +#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h" +#include "utils/log_adapter.h" + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +const bool kSaveExpected = false; +const char kOpName[] = "RandomCropWithBBox_C"; + +class MindDataTestRandomCropWithBBoxOp : public UT::CVOP::BBOXOP::BBoxOpCommon { + protected: + MindDataTestRandomCropWithBBoxOp() : BBoxOpCommon() {} + TensorRow output_tensor_row_; +}; + +TEST_F(MindDataTestRandomCropWithBBoxOp, TestOp1) { + MS_LOG(INFO) << "Doing testRandomCropWithBBoxOp1."; + TensorTable results; + unsigned int crop_height = 128; + unsigned int crop_width = 128; + // setting seed here + uint32_t current_seed = GlobalContext::config_manager()->seed(); + GlobalContext::config_manager()->set_seed(327362); + std::unique_ptr op( + new RandomCropWithBBoxOp(crop_height, crop_width, 0, 0, 0, 0, BorderType::kConstant, false)); + for (auto tensor_row_ : images_and_annotations_) { + Status s = op->Compute(tensor_row_, &output_tensor_row_); + size_t actual = 0; + if (s == Status::OK()) { + TensorShape get_shape = output_tensor_row_[0]->shape(); + actual = get_shape[0] * get_shape[1] * get_shape[2]; + results.push_back(output_tensor_row_); + } + EXPECT_EQ(actual, crop_height * crop_width * 3); + EXPECT_EQ(s, Status::OK()); + EXPECT_EQ(4, output_tensor_row_[1]->shape()[1]); // check for existence of 4 columns + // Compare Code + if (kSaveExpected) { + SaveImagesWithAnnotations(FileType::kExpected, std::string(kOpName), results); + } + SaveImagesWithAnnotations(FileType::kActual, std::string(kOpName), results); + if (!kSaveExpected) { + CompareActualAndExpected(std::string(kOpName)); + } + GlobalContext::config_manager()->set_seed(current_seed); + } +} + +TEST_F(MindDataTestRandomCropWithBBoxOp, TestOp2) { + MS_LOG(INFO) << "Doing testRandomCropWithBBoxOp2."; + // Crop params + unsigned int crop_height = 1280; + unsigned int crop_width = 1280; + std::unique_ptr op( + new RandomCropWithBBoxOp(crop_height, crop_width, 513, 513, 513, 513, BorderType::kConstant, false)); + + for (auto tensor_row_ : images_and_annotations_) { + Status s = op->Compute(tensor_row_, &output_tensor_row_); + size_t actual = 0; + if (s == Status::OK()) { + TensorShape get_shape = output_tensor_row_[0]->shape(); + actual = get_shape[0] * get_shape[1] * get_shape[2]; + } + EXPECT_EQ(actual, crop_height * crop_width * 3); + EXPECT_EQ(s, Status::OK()); + EXPECT_EQ(4, output_tensor_row_[1]->shape()[1]); // check for existence of 4 columns + } + MS_LOG(INFO) << "testRandomCropWithBBoxOp end."; +} diff --git a/tests/ut/cpp/dataset/random_data_op_test.cc b/tests/ut/cpp/dataset/random_data_op_test.cc index f8a7440c03..3cb7b57ad6 100644 --- a/tests/ut/cpp/dataset/random_data_op_test.cc +++ b/tests/ut/cpp/dataset/random_data_op_test.cc @@ -14,15 +14,15 @@ * limitations under the License. */ -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include #include #include -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/data_schema.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/data_schema.h" using namespace mindspore::dataset; using mindspore::MsLogLevel::INFO; diff --git a/tests/ut/cpp/dataset/random_horizontal_flip_op_test.cc b/tests/ut/cpp/dataset/random_horizontal_flip_op_test.cc index eb2f753554..bb4ba7498d 100644 --- a/tests/ut/cpp/dataset/random_horizontal_flip_op_test.cc +++ b/tests/ut/cpp/dataset/random_horizontal_flip_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/random_horizontal_flip_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/random_horizontal_flip_with_bbox_test.cc b/tests/ut/cpp/dataset/random_horizontal_flip_with_bbox_test.cc new file mode 100644 index 0000000000..ed4e866478 --- /dev/null +++ b/tests/ut/cpp/dataset/random_horizontal_flip_with_bbox_test.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/bboxop_common.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +const bool kSaveExpected = false; +const char kOpName[] = "RandomHorizontalFlipWithBBox"; + +class MindDataTestRandomHorizontalFlipWithBBoxOp : public UT::CVOP::BBOXOP::BBoxOpCommon { + protected: + MindDataTestRandomHorizontalFlipWithBBoxOp() : UT::CVOP::BBOXOP::BBoxOpCommon() {} +}; + +TEST_F(MindDataTestRandomHorizontalFlipWithBBoxOp, TestOp) { + MS_LOG(INFO) << "Doing testRandomHorizontalFlipWithBBox."; + TensorTable results; + std::unique_ptr op(new RandomHorizontalFlipWithBBoxOp(1)); + for (const auto &row: images_and_annotations_) { + TensorRow output_row; + Status s = op->Compute(row, &output_row); + EXPECT_TRUE(s.IsOk()); + results.push_back(output_row); + } + if (kSaveExpected) { + SaveImagesWithAnnotations(FileType::kExpected, std::string(kOpName), results); + } + SaveImagesWithAnnotations(FileType::kActual , std::string(kOpName), results); + if (!kSaveExpected) { + CompareActualAndExpected(std::string(kOpName)); + } +} diff --git a/tests/ut/cpp/dataset/random_resize_op_test.cc b/tests/ut/cpp/dataset/random_resize_op_test.cc index ee185f2fc6..d9e85de6e5 100644 --- a/tests/ut/cpp/dataset/random_resize_op_test.cc +++ b/tests/ut/cpp/dataset/random_resize_op_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/kernels/image/random_resize_op.h" +#include "minddata/dataset/kernels/image/random_resize_op.h" #include "common/common.h" #include "common/cvop_common.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/random_resize_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_resize_with_bbox_op_test.cc new file mode 100644 index 0000000000..e106f57375 --- /dev/null +++ b/tests/ut/cpp/dataset/random_resize_with_bbox_op_test.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/bboxop_common.h" +#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" +#include "utils/log_adapter.h" + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +const bool kSaveExpected = false; +const char kOpName[] = "RandomResizeWithBBox_C"; + +class MindDataTestRandomResizeWithBBoxOp : public UT::CVOP::BBOXOP::BBoxOpCommon { + protected: + MindDataTestRandomResizeWithBBoxOp() : BBoxOpCommon() {} +}; +TEST_F(MindDataTestRandomResizeWithBBoxOp, TestOp) { + MS_LOG(INFO) << "Doing testRandomResizeWithBBox."; + //setting seed here + u_int32_t curr_seed = GlobalContext::config_manager()->seed(); + GlobalContext::config_manager()->set_seed(120); + TensorTable results; + std::unique_ptr op(new RandomResizeWithBBoxOp(500)); + for (const auto &tensor_row_ : images_and_annotations_) { + // selected a tensorRow + TensorRow output_row; + Status s = op->Compute(tensor_row_, &output_row); + EXPECT_TRUE(s.IsOk()); + results.push_back(output_row); + } + if (kSaveExpected) { + SaveImagesWithAnnotations(FileType::kExpected, std::string(kOpName), results); + } + SaveImagesWithAnnotations(FileType::kActual, std::string(kOpName), results); + if (!kSaveExpected) { + CompareActualAndExpected(std::string(kOpName)); + } + GlobalContext::config_manager()->set_seed(curr_seed); + MS_LOG(INFO) << "testRandomResizeWithBBox end."; +} diff --git a/tests/ut/cpp/dataset/random_rotation_op_test.cc b/tests/ut/cpp/dataset/random_rotation_op_test.cc index 8b82ef1dcd..a6eb5a1ff3 100644 --- a/tests/ut/cpp/dataset/random_rotation_op_test.cc +++ b/tests/ut/cpp/dataset/random_rotation_op_test.cc @@ -16,8 +16,8 @@ #include #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/random_rotation_op.h" -#include "dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" +#include "minddata/dataset/core/cv_tensor.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/random_vertical_flip_op_test.cc b/tests/ut/cpp/dataset/random_vertical_flip_op_test.cc index a2583cab96..db8cc89893 100644 --- a/tests/ut/cpp/dataset/random_vertical_flip_op_test.cc +++ b/tests/ut/cpp/dataset/random_vertical_flip_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/random_vertical_flip_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/random_vertical_flip_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_vertical_flip_with_bbox_op_test.cc new file mode 100644 index 0000000000..d1946ef700 --- /dev/null +++ b/tests/ut/cpp/dataset/random_vertical_flip_with_bbox_op_test.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/bboxop_common.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +const bool kSaveExpected = false; +const char kOpName[] = "RandomVerticalFlipWithBBox_C"; + +class MindDataTestRandomVerticalFlipWithBBoxOp : public UT::CVOP::BBOXOP::BBoxOpCommon { + protected: + MindDataTestRandomVerticalFlipWithBBoxOp() : BBoxOpCommon() {} +}; +TEST_F(MindDataTestRandomVerticalFlipWithBBoxOp, TestOp) { + MS_LOG(INFO) << "Doing testRandomVerticalFlipWithBBoxOp."; + TensorTable results; + std::unique_ptr op(new RandomVerticalFlipWithBBoxOp(1)); + for (const auto &tensor_row_ : images_and_annotations_) { + TensorRow output_row; + Status s = op->Compute(tensor_row_, &output_row); + EXPECT_TRUE(s.IsOk()); + results.push_back(output_row); + } + if (kSaveExpected) { + SaveImagesWithAnnotations(FileType::kExpected, std::string(kOpName), results); + } + SaveImagesWithAnnotations(FileType::kActual, std::string(kOpName), results); + if (!kSaveExpected) { + CompareActualAndExpected(std::string(kOpName)); + } + MS_LOG(INFO) << "testRandomVerticalFlipWithBBoxOp end."; +} diff --git a/tests/ut/cpp/dataset/rename_op_test.cc b/tests/ut/cpp/dataset/rename_op_test.cc index b6849ec53e..ac64346c26 100644 --- a/tests/ut/cpp/dataset/rename_op_test.cc +++ b/tests/ut/cpp/dataset/rename_op_test.cc @@ -17,15 +17,15 @@ #include #include #include -#include "dataset/core/client.h" -#include "dataset/core/constants.h" -#include "dataset/engine/datasetops/map_op.h" -#include "dataset/engine/datasetops/rename_op.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/rename_op.h" #include "common/common.h" #include "common/utils.h" -#include "dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_buffer.h" #include "gtest/gtest.h" -#include "dataset/core/global_context.h" +#include "minddata/dataset/core/global_context.h" #include "utils/log_adapter.h" namespace common = mindspore::common; @@ -51,7 +51,7 @@ TEST_F(MindDataTestRenameOp, TestRenameOpDefault) { auto my_tree = std::make_shared(); // Creating TFReaderOp - std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data"; + std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data"; std::shared_ptr my_tfreader_op; rc = TFReaderOp::Builder() .SetDatasetFilesList({dataset_path}) diff --git a/tests/ut/cpp/dataset/repeat_op_test.cc b/tests/ut/cpp/dataset/repeat_op_test.cc index 42549546ba..74d494c0dc 100644 --- a/tests/ut/cpp/dataset/repeat_op_test.cc +++ b/tests/ut/cpp/dataset/repeat_op_test.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/util/circular_pool.h" -#include "dataset/core/client.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/rescale_op_test.cc b/tests/ut/cpp/dataset/rescale_op_test.cc index 86abbe972e..5d9bf32a9f 100644 --- a/tests/ut/cpp/dataset/rescale_op_test.cc +++ b/tests/ut/cpp/dataset/rescale_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/rescale_op.h" +#include "minddata/dataset/kernels/image/rescale_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/resize_bilinear_op_test.cc b/tests/ut/cpp/dataset/resize_bilinear_op_test.cc index 8642484149..910c8af2a2 100644 --- a/tests/ut/cpp/dataset/resize_bilinear_op_test.cc +++ b/tests/ut/cpp/dataset/resize_bilinear_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/resize_bilinear_op.h" +#include "minddata/dataset/kernels/image/resize_bilinear_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/resize_op_test.cc b/tests/ut/cpp/dataset/resize_op_test.cc index e23320a65a..807668dde4 100644 --- a/tests/ut/cpp/dataset/resize_op_test.cc +++ b/tests/ut/cpp/dataset/resize_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/resize_with_bbox_op_test.cc b/tests/ut/cpp/dataset/resize_with_bbox_op_test.cc new file mode 100644 index 0000000000..f9eaf85a55 --- /dev/null +++ b/tests/ut/cpp/dataset/resize_with_bbox_op_test.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/bboxop_common.h" +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; + +const bool kSaveExpected = false; +const char kOpName[] = "ResizeWithBBox_C"; + +class MindDataTestResizeWithBBoxOp : public UT::CVOP::BBOXOP::BBoxOpCommon { + protected: + MindDataTestResizeWithBBoxOp() : BBoxOpCommon() {} +}; +TEST_F(MindDataTestResizeWithBBoxOp, TestOp) { + MS_LOG(INFO) << "Doing testResizeWithBBox."; + // resize + TensorTable results; + std::unique_ptr op(new ResizeWithBBoxOp(500)); + for (const auto &tensor_row_ : images_and_annotations_) { + // selected a tensorRow + TensorRow output_row; + Status s = op->Compute(tensor_row_, &output_row); + EXPECT_TRUE(s.IsOk()); + results.push_back(output_row); + } + if (kSaveExpected) { + SaveImagesWithAnnotations(FileType::kExpected, std::string(kOpName), results); + } + SaveImagesWithAnnotations(FileType::kActual, std::string(kOpName), results); + if (!kSaveExpected) { + CompareActualAndExpected(std::string(kOpName)); + } + + MS_LOG(INFO) << "testResizeWithBBox end."; +} diff --git a/tests/ut/cpp/dataset/schema_test.cc b/tests/ut/cpp/dataset/schema_test.cc index 2da61bc047..95b9c75d9e 100644 --- a/tests/ut/cpp/dataset/schema_test.cc +++ b/tests/ut/cpp/dataset/schema_test.cc @@ -19,11 +19,11 @@ #include #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/data_schema.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/shuffle_op_test.cc b/tests/ut/cpp/dataset/shuffle_op_test.cc index c9bcb24c4e..98b4878efb 100644 --- a/tests/ut/cpp/dataset/shuffle_op_test.cc +++ b/tests/ut/cpp/dataset/shuffle_op_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" diff --git a/tests/ut/cpp/dataset/skip_op_test.cc b/tests/ut/cpp/dataset/skip_op_test.cc index 697745512d..387d2f69ff 100644 --- a/tests/ut/cpp/dataset/skip_op_test.cc +++ b/tests/ut/cpp/dataset/skip_op_test.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/util/circular_pool.h" -#include "dataset/core/client.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc index dfe15a8f15..96e9652bbc 100644 --- a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc +++ b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc @@ -15,13 +15,13 @@ */ #include "common/common.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/status_test.cc b/tests/ut/cpp/dataset/status_test.cc index c64a86b8ba..195da1c119 100644 --- a/tests/ut/cpp/dataset/status_test.cc +++ b/tests/ut/cpp/dataset/status_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/util/status.h" +#include "minddata/dataset/util/status.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/subset_random_sampler_test.cc b/tests/ut/cpp/dataset/subset_random_sampler_test.cc index 22200ccbac..c389686014 100644 --- a/tests/ut/cpp/dataset/subset_random_sampler_test.cc +++ b/tests/ut/cpp/dataset/subset_random_sampler_test.cc @@ -16,11 +16,11 @@ #include "common/common.h" #include "gtest/gtest.h" -#include "dataset/core/constants.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include #include diff --git a/tests/ut/cpp/dataset/take_op_test.cc b/tests/ut/cpp/dataset/take_op_test.cc index b7be066d6c..a8bfe40b10 100644 --- a/tests/ut/cpp/dataset/take_op_test.cc +++ b/tests/ut/cpp/dataset/take_op_test.cc @@ -19,7 +19,7 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/task_manager_test.cc b/tests/ut/cpp/dataset/task_manager_test.cc index 3d34ec9ec5..7b8101fa56 100644 --- a/tests/ut/cpp/dataset/task_manager_test.cc +++ b/tests/ut/cpp/dataset/task_manager_test.cc @@ -16,7 +16,7 @@ #include "common/common.h" #include "gtest/gtest.h" -#include "dataset/util/task_manager.h" +#include "minddata/dataset/util/task_manager.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc b/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc new file mode 100644 index 0000000000..70832c04b5 --- /dev/null +++ b/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "minddata/dataset/core/client.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/execution_tree.h" + + +using namespace mindspore::dataset; +using mindspore::LogStream; +using mindspore::MsLogLevel::INFO; + +class MindDataTestTensorOpFusionPass : public UT::DatasetOpTesting { + public: + MindDataTestTensorOpFusionPass() = default; + void SetUp() override { GlobalInit(); } +}; + +TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_disabled) { + MS_LOG(INFO) << "Doing RandomCropDecodeResize_fusion"; + std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, + bool shuf = false, std::shared_ptr sampler = nullptr, + std::map map = {}, bool decode = false); + std::shared_ptr Build(std::vector> ops); + auto rcar_op = std::make_shared(); + auto decode_op = std::make_shared(); + Status rc; + std::vector> func_list; + func_list.push_back(decode_op); + func_list.push_back(rcar_op); + std::shared_ptr map_op; + MapOp::Builder map_decode_builder; + map_decode_builder.SetInColNames({}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4); + rc = map_decode_builder.Build(&map_op); + EXPECT_TRUE(rc.IsOk()); + auto tree = std::make_shared(); + tree = Build({ImageFolder(16, 2, 32, "./", false), map_op}); + rc = tree->SetOptimize(false); + EXPECT_TRUE(rc); + rc = tree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + rc = tree->SetOptimize(false); + EXPECT_TRUE(rc.IsError()); + auto it = tree->begin(); + ++it; + auto *m_op = &(*it); + auto tfuncs = static_cast(m_op)->TFuncs(); + auto func_it = tfuncs.begin(); + EXPECT_EQ((*func_it)->Name(), kDecodeOp); + ++func_it; + EXPECT_EQ((*func_it)->Name(), kRandomCropAndResizeOp); +} + +TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_enabled) { + MS_LOG(INFO) << "Doing RandomCropDecodeResize_fusion"; + std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, + bool shuf = false, std::shared_ptr sampler = nullptr, + std::map map = {}, bool decode = false); + std::shared_ptr Build(std::vector> ops); + auto rcar_op = std::make_shared(); + auto decode_op = std::make_shared(); + Status rc; + std::vector> func_list; + func_list.push_back(decode_op); + func_list.push_back(rcar_op); + std::shared_ptr map_op; + MapOp::Builder map_decode_builder; + map_decode_builder.SetInColNames({}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4); + rc = map_decode_builder.Build(&map_op); + EXPECT_TRUE(rc.IsOk()); + auto tree = std::make_shared(); + tree = Build({ImageFolder(16, 2, 32, "./", false), map_op}); + rc = tree->SetOptimize(true); + EXPECT_TRUE(rc); + rc = tree->Prepare(); + EXPECT_TRUE(rc.IsOk()); + rc = tree->SetOptimize(false); + EXPECT_TRUE(rc.IsError()); + auto it = tree->begin(); + ++it; + auto *m_op = &(*it); + auto tfuncs = static_cast(m_op)->TFuncs(); + auto func_it = tfuncs.begin(); + EXPECT_EQ((*func_it)->Name(), kRandomCropDecodeResizeOp); + EXPECT_EQ(++func_it, tfuncs.end()); +} \ No newline at end of file diff --git a/tests/ut/cpp/dataset/tensor_string_test.cc b/tests/ut/cpp/dataset/tensor_string_test.cc index 43b235304d..fe336a34c5 100644 --- a/tests/ut/cpp/dataset/tensor_string_test.cc +++ b/tests/ut/cpp/dataset/tensor_string_test.cc @@ -15,13 +15,13 @@ */ #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "securec.h" -#include "dataset/core/tensor.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc index 1aa3cad2fa..fce4652b47 100644 --- a/tests/ut/cpp/dataset/tensor_test.cc +++ b/tests/ut/cpp/dataset/tensor_test.cc @@ -15,13 +15,13 @@ */ #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "securec.h" -#include "dataset/core/tensor.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" using namespace mindspore::dataset; @@ -432,3 +432,17 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) { s = t1->Concatenate({5}, t2); EXPECT_FALSE(s.IsOk()); } + +TEST_F(MindDataTestTensorDE, TensorEmpty) { + std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); + ASSERT_TRUE(t->HasData()); +} + +TEST_F(MindDataTestTensorDE, TensorEmptyInvalidate) { + std::vector values1 = {1, 2, 3, 0, 0, 0}; + std::shared_ptr t; + Tensor::CreateTensor(&t, values1); + t->Invalidate(); + ASSERT_TRUE(t->HasData()); +} + diff --git a/tests/ut/cpp/dataset/tensorshape_test.cc b/tests/ut/cpp/dataset/tensorshape_test.cc index 1af0bf9c82..65ab386db0 100644 --- a/tests/ut/cpp/dataset/tensorshape_test.cc +++ b/tests/ut/cpp/dataset/tensorshape_test.cc @@ -15,10 +15,10 @@ */ #include #include "./securec.h" -#include "dataset/core/client.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/data_schema.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/data_schema.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" diff --git a/tests/ut/cpp/dataset/text_file_op_test.cc b/tests/ut/cpp/dataset/text_file_op_test.cc index 7887eda955..bc2674a6a3 100644 --- a/tests/ut/cpp/dataset/text_file_op_test.cc +++ b/tests/ut/cpp/dataset/text_file_op_test.cc @@ -17,13 +17,13 @@ #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/util/status.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/util/status.h" namespace common = mindspore::common; diff --git a/tests/ut/cpp/dataset/tfReader_op_test.cc b/tests/ut/cpp/dataset/tfReader_op_test.cc index 9b312296d8..30fde33ff9 100644 --- a/tests/ut/cpp/dataset/tfReader_op_test.cc +++ b/tests/ut/cpp/dataset/tfReader_op_test.cc @@ -17,8 +17,8 @@ #include #include -#include "dataset/core/client.h" -#include "dataset/engine/data_schema.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/data_schema.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" diff --git a/tests/ut/cpp/dataset/to_float16_op_test.cc b/tests/ut/cpp/dataset/to_float16_op_test.cc index 9c49c67b2c..5c886690c9 100644 --- a/tests/ut/cpp/dataset/to_float16_op_test.cc +++ b/tests/ut/cpp/dataset/to_float16_op_test.cc @@ -15,9 +15,9 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/random_rotation_op.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/data/to_float16_op.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/data/to_float16_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/tokenizer_op_test.cc b/tests/ut/cpp/dataset/tokenizer_op_test.cc index 8a18f0da0c..cc2d7473ff 100644 --- a/tests/ut/cpp/dataset/tokenizer_op_test.cc +++ b/tests/ut/cpp/dataset/tokenizer_op_test.cc @@ -18,14 +18,14 @@ #include #include "common/common.h" -#include "dataset/text/kernels/basic_tokenizer_op.h" -#include "dataset/text/kernels/case_fold_op.h" -#include "dataset/text/kernels/normalize_utf8_op.h" -#include "dataset/text/kernels/regex_replace_op.h" -#include "dataset/text/kernels/regex_tokenizer_op.h" -#include "dataset/text/kernels/unicode_char_tokenizer_op.h" -#include "dataset/text/kernels/unicode_script_tokenizer_op.h" -#include "dataset/text/kernels/whitespace_tokenizer_op.h" +#include "minddata/dataset/text/kernels/basic_tokenizer_op.h" +#include "minddata/dataset/text/kernels/case_fold_op.h" +#include "minddata/dataset/text/kernels/normalize_utf8_op.h" +#include "minddata/dataset/text/kernels/regex_replace_op.h" +#include "minddata/dataset/text/kernels/regex_tokenizer_op.h" +#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" +#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h" +#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" @@ -45,227 +45,245 @@ class MindDataTestTokenizerOp : public UT::Common { TEST_F(MindDataTestTokenizerOp, TestUnicodeCharTokenizerOp) { MS_LOG(INFO) << "Doing TestUnicodeCharTokenizerOp."; - std::unique_ptr op(new UnicodeCharTokenizerOp()); + std::unique_ptr op(new UnicodeCharTokenizerOp(true)); std::shared_ptr input = std::make_shared("Hello World!"); - std::shared_ptr output; - Status s = op->Compute(input, &output); + TensorRow output; + Status s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 12); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor1: " << output->ToString(); - CheckEqual(output, {0}, "H"); - CheckEqual(output, {1}, "e"); - CheckEqual(output, {2}, "l"); - CheckEqual(output, {3}, "l"); - CheckEqual(output, {4}, "o"); - CheckEqual(output, {5}, " "); - CheckEqual(output, {6}, "W"); - CheckEqual(output, {7}, "o"); - CheckEqual(output, {8}, "r"); - CheckEqual(output, {9}, "l"); - CheckEqual(output, {10}, "d"); - CheckEqual(output, {11}, "!"); + EXPECT_EQ(output[0]->Size(), 12); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor1: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "H"); + CheckEqual(output[0], {1}, "e"); + CheckEqual(output[0], {2}, "l"); + CheckEqual(output[0], {3}, "l"); + CheckEqual(output[0], {4}, "o"); + CheckEqual(output[0], {5}, " "); + CheckEqual(output[0], {6}, "W"); + CheckEqual(output[0], {7}, "o"); + CheckEqual(output[0], {8}, "r"); + CheckEqual(output[0], {9}, "l"); + CheckEqual(output[0], {10}, "d"); + CheckEqual(output[0], {11}, "!"); input = std::make_shared("中国 你好!"); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 6); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor2: " << output->ToString(); - CheckEqual(output, {0}, "中"); - CheckEqual(output, {1}, "国"); - CheckEqual(output, {2}, " "); - CheckEqual(output, {3}, "你"); - CheckEqual(output, {4}, "好"); - CheckEqual(output, {5}, "!"); + EXPECT_EQ(output[0]->Size(), 6); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor2: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "中"); + CheckEqual(output[0], {1}, "国"); + CheckEqual(output[0], {2}, " "); + CheckEqual(output[0], {3}, "你"); + CheckEqual(output[0], {4}, "好"); + CheckEqual(output[0], {5}, "!"); input = std::make_shared("中"); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor3: " << output->ToString(); - CheckEqual(output, {0}, "中"); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "中"); input = std::make_shared("H"); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor4: " << output->ToString(); - CheckEqual(output, {0}, "H"); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor4: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "H"); input = std::make_shared(" "); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 2); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor5: " << output->ToString(); - CheckEqual(output, {0}, " "); - CheckEqual(output, {1}, " "); + EXPECT_EQ(output[0]->Size(), 2); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); + CheckEqual(output[0], {0}, " "); + CheckEqual(output[0], {1}, " "); input = std::make_shared(""); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor6: " << output->ToString(); - CheckEqual(output, {0}, ""); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor6: " << output[0]->ToString(); + CheckEqual(output[0], {0}, ""); } TEST_F(MindDataTestTokenizerOp, TestWhitespaceTokenizerOp) { MS_LOG(INFO) << "Doing TestWhitespaceTokenizerOp."; - std::unique_ptr op(new WhitespaceTokenizerOp()); + std::unique_ptr op(new WhitespaceTokenizerOp(true)); std::shared_ptr input = std::make_shared("Welcome to China."); - std::shared_ptr output; - Status s = op->Compute(input, &output); + TensorRow output; + Status s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 3); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor1: " << output->ToString(); - CheckEqual(output, {0}, "Welcome"); - CheckEqual(output, {1}, "to"); - CheckEqual(output, {2}, "China."); + EXPECT_EQ(output[0]->Size(), 3); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor1: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "Welcome"); + CheckEqual(output[0], {1}, "to"); + CheckEqual(output[0], {2}, "China."); input = std::make_shared(" hello"); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor2: " << output->ToString(); - CheckEqual(output, {0}, "hello"); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor2: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "hello"); input = std::make_shared("hello"); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor3: " << output->ToString(); - CheckEqual(output, {0}, "hello"); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "hello"); input = std::make_shared("hello "); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor4: " << output->ToString(); - CheckEqual(output, {0}, "hello"); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor4: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "hello"); input = std::make_shared(" "); - s = op->Compute(input, &output); + output.clear(); + s = op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor5: " << output->ToString(); - CheckEqual(output, {0}, ""); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); + CheckEqual(output[0], {0}, ""); } TEST_F(MindDataTestTokenizerOp, TestUnicodeScriptTokenizer) { MS_LOG(INFO) << "Doing TestUnicodeScriptTokenizer."; - std::unique_ptr keep_whitespace_op(new UnicodeScriptTokenizerOp(true)); - std::unique_ptr skip_whitespace_op(new UnicodeScriptTokenizerOp(false)); + std::unique_ptr keep_whitespace_op(new UnicodeScriptTokenizerOp(true, true)); + std::unique_ptr skip_whitespace_op(new UnicodeScriptTokenizerOp(false, true)); std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); - std::shared_ptr output; - Status s = keep_whitespace_op->Compute(input, &output); + TensorRow output; + Status s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 10); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor1: " << output->ToString(); - CheckEqual(output, {0}, "Welcome"); - CheckEqual(output, {1}, " "); - CheckEqual(output, {2}, "to"); - CheckEqual(output, {3}, " "); - CheckEqual(output, {4}, "China"); - CheckEqual(output, {5}, "."); - CheckEqual(output, {6}, " \n "); - CheckEqual(output, {7}, "中国"); - CheckEqual(output, {8}, "\t"); - CheckEqual(output, {9}, "北京"); - s = skip_whitespace_op->Compute(input, &output); + EXPECT_EQ(output[0]->Size(), 10); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor1: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "Welcome"); + CheckEqual(output[0], {1}, " "); + CheckEqual(output[0], {2}, "to"); + CheckEqual(output[0], {3}, " "); + CheckEqual(output[0], {4}, "China"); + CheckEqual(output[0], {5}, "."); + CheckEqual(output[0], {6}, " \n "); + CheckEqual(output[0], {7}, "中国"); + CheckEqual(output[0], {8}, "\t"); + CheckEqual(output[0], {9}, "北京"); + output.clear(); + s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 6); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor2: " << output->ToString(); - CheckEqual(output, {0}, "Welcome"); - CheckEqual(output, {1}, "to"); - CheckEqual(output, {2}, "China"); - CheckEqual(output, {3}, "."); - CheckEqual(output, {4}, "中国"); - CheckEqual(output, {5}, "北京"); + EXPECT_EQ(output[0]->Size(), 6); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor2: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "Welcome"); + CheckEqual(output[0], {1}, "to"); + CheckEqual(output[0], {2}, "China"); + CheckEqual(output[0], {3}, "."); + CheckEqual(output[0], {4}, "中国"); + CheckEqual(output[0], {5}, "北京"); input = std::make_shared(" Welcome to 中国. "); - s = skip_whitespace_op->Compute(input, &output); + output.clear(); + s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 4); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor3: " << output->ToString(); - CheckEqual(output, {0}, "Welcome"); - CheckEqual(output, {1}, "to"); - CheckEqual(output, {2}, "中国"); - CheckEqual(output, {3}, "."); - s = keep_whitespace_op->Compute(input, &output); + EXPECT_EQ(output[0]->Size(), 4); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor3: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "Welcome"); + CheckEqual(output[0], {1}, "to"); + CheckEqual(output[0], {2}, "中国"); + CheckEqual(output[0], {3}, "."); + output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 8); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor4: " << output->ToString(); - CheckEqual(output, {0}, " "); - CheckEqual(output, {1}, "Welcome"); - CheckEqual(output, {2}, " "); - CheckEqual(output, {3}, "to"); - CheckEqual(output, {4}, " "); - CheckEqual(output, {5}, "中国"); - CheckEqual(output, {6}, "."); - CheckEqual(output, {7}, " "); + EXPECT_EQ(output[0]->Size(), 8); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor4: " << output[0]->ToString(); + CheckEqual(output[0], {0}, " "); + CheckEqual(output[0], {1}, "Welcome"); + CheckEqual(output[0], {2}, " "); + CheckEqual(output[0], {3}, "to"); + CheckEqual(output[0], {4}, " "); + CheckEqual(output[0], {5}, "中国"); + CheckEqual(output[0], {6}, "."); + CheckEqual(output[0], {7}, " "); input = std::make_shared("Hello"); - s = keep_whitespace_op->Compute(input, &output); + output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor5: " << output->ToString(); - CheckEqual(output, {0}, "Hello"); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor5: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "Hello"); input = std::make_shared("H"); - s = keep_whitespace_op->Compute(input, &output); + output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor6: " << output->ToString(); - CheckEqual(output, {0}, "H"); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor6: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "H"); input = std::make_shared(""); - s = keep_whitespace_op->Compute(input, &output); + output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor7: " << output->ToString(); - CheckEqual(output, {0}, ""); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor7: " << output[0]->ToString(); + CheckEqual(output[0], {0}, ""); input = std::make_shared("Hello中国Hello世界"); - s = keep_whitespace_op->Compute(input, &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 4); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor8: " << output->ToString(); - CheckEqual(output, {0}, "Hello"); - CheckEqual(output, {1}, "中国"); - CheckEqual(output, {2}, "Hello"); - CheckEqual(output, {3}, "世界"); + output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); + EXPECT_EQ(output[0]->Size(), 4); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor8: " << output[0]->ToString(); + CheckEqual(output[0], {0}, "Hello"); + CheckEqual(output[0], {1}, "中国"); + CheckEqual(output[0], {2}, "Hello"); + CheckEqual(output[0], {3}, "世界"); input = std::make_shared(" "); - s = keep_whitespace_op->Compute(input, &output); + output.clear(); + s = keep_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor10: " << output->ToString(); - CheckEqual(output, {0}, " "); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor10: " << output[0]->ToString(); + CheckEqual(output[0], {0}, " "); input = std::make_shared(" "); - s = skip_whitespace_op->Compute(input, &output); + output.clear(); + s = skip_whitespace_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); - EXPECT_EQ(output->Size(), 1); - EXPECT_EQ(output->Rank(), 1); - MS_LOG(INFO) << "Out tensor11: " << output->ToString(); - CheckEqual(output, {0}, ""); + EXPECT_EQ(output[0]->Size(), 1); + EXPECT_EQ(output[0]->Rank(), 1); + MS_LOG(INFO) << "Out tensor11: " << output[0]->ToString(); + CheckEqual(output[0], {0}, ""); } TEST_F(MindDataTestTokenizerOp, TestCaseFold) { @@ -321,10 +339,10 @@ TEST_F(MindDataTestTokenizerOp, TestRegexReplace) { TEST_F(MindDataTestTokenizerOp, TestRegexTokenizer) { MS_LOG(INFO) << "Doing TestRegexTokenizerOp."; - std::unique_ptr regex_tokenizer_op(new RegexTokenizerOp("\\p{Cc}|\\p{Cf}|\\s+", "")); + std::unique_ptr regex_tokenizer_op(new RegexTokenizerOp("\\p{Cc}|\\p{Cf}|\\s+", "", true)); std::shared_ptr input = std::make_shared("Welcome to China. \n 中国\t北京"); - std::shared_ptr output; - Status s = regex_tokenizer_op->Compute(input, &output); + TensorRow output; + Status s = regex_tokenizer_op->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); } @@ -332,9 +350,10 @@ TEST_F(MindDataTestTokenizerOp, TestBasicTokenizer) { MS_LOG(INFO) << "Doing TestBasicTokenizer."; //bool lower_case, bool keep_whitespace, // NormalizeForm normalization_form, bool preserve_unused_token - std::unique_ptr basic_tokenizer(new BasicTokenizerOp(true, true, NormalizeForm::kNone, false)); + std::unique_ptr basic_tokenizer(new BasicTokenizerOp(true, true, NormalizeForm::kNone, false, + true)); std::shared_ptr input = std::make_shared("Welcome to China. 中国\t北京"); - std::shared_ptr output; - Status s = basic_tokenizer->Compute(input, &output); + TensorRow output; + Status s = basic_tokenizer->Compute(TensorRow(0, {input}), &output); EXPECT_TRUE(s.IsOk()); } \ No newline at end of file diff --git a/tests/ut/cpp/dataset/treap_test.cc b/tests/ut/cpp/dataset/treap_test.cc index b454ab108e..b9c534719c 100644 --- a/tests/ut/cpp/dataset/treap_test.cc +++ b/tests/ut/cpp/dataset/treap_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/util/treap.h" +#include "minddata/dataset/util/treap.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/trucate_pair_test.cc b/tests/ut/cpp/dataset/trucate_pair_test.cc index 95e2aaa11b..af7e61c16a 100644 --- a/tests/ut/cpp/dataset/trucate_pair_test.cc +++ b/tests/ut/cpp/dataset/trucate_pair_test.cc @@ -15,12 +15,12 @@ */ #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "securec.h" -#include "dataset/core/tensor.h" -#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" +#include "minddata/dataset/core/tensor.h" +#include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/type_cast_op_test.cc b/tests/ut/cpp/dataset/type_cast_op_test.cc index 543eb71637..a94a7fedba 100644 --- a/tests/ut/cpp/dataset/type_cast_op_test.cc +++ b/tests/ut/cpp/dataset/type_cast_op_test.cc @@ -17,12 +17,12 @@ #include #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/core/client.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/pybind_support.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/pybind_support.h" #include "gtest/gtest.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/voc_op_test.cc b/tests/ut/cpp/dataset/voc_op_test.cc index 05dc28b487..4bb212ffc7 100644 --- a/tests/ut/cpp/dataset/voc_op_test.cc +++ b/tests/ut/cpp/dataset/voc_op_test.cc @@ -20,18 +20,18 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc index d146ed10ac..bb3079aec8 100644 --- a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc +++ b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc @@ -16,11 +16,11 @@ #include "common/common.h" #include "gtest/gtest.h" -#include "dataset/core/constants.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" #include "utils/log_adapter.h" #include diff --git a/tests/ut/cpp/dataset/zip_op_test.cc b/tests/ut/cpp/dataset/zip_op_test.cc index b387341398..8d74cb0969 100644 --- a/tests/ut/cpp/dataset/zip_op_test.cc +++ b/tests/ut/cpp/dataset/zip_op_test.cc @@ -21,17 +21,17 @@ #include #include #include -#include "dataset/core/client.h" -#include "dataset/core/constants.h" -#include "dataset/engine/datasetops/map_op.h" -#include "dataset/engine/datasetops/zip_op.h" -#include "dataset/core/tensor.h" -#include "dataset/core/config_manager.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/zip_op.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/config_manager.h" #include "common/common.h" #include "common/utils.h" -#include "dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_buffer.h" #include "gtest/gtest.h" -#include "dataset/core/global_context.h" +#include "minddata/dataset/core/global_context.h" #include "utils/log_adapter.h" namespace common = mindspore::common; @@ -58,7 +58,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpDefault) { auto my_tree = std::make_shared(); // Creating TFReaderOp - std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data"; + std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data"; std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data"; std::shared_ptr my_tfreader_op; rc = TFReaderOp::Builder() @@ -142,7 +142,7 @@ TEST_F(MindDataTestZipOp, MindDataTestZipOpRepeat) { MS_LOG(INFO) << "UT test TestZipRepeat."; auto my_tree = std::make_shared(); - std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images_1/train-0000-of-0001.data"; + std::string dataset_path = datasets_root_path_ + "/test_tf_file_3_images/train-0000-of-0001.data"; std::string dataset_path2 = datasets_root_path_ + "/testBatchDataset/test.data"; std::shared_ptr my_tfreader_op; rc = TFReaderOp::Builder() diff --git a/tests/ut/cpp/device/ascend_kernel_runtime_test.cc b/tests/ut/cpp/device/ascend_kernel_runtime_test.cc index effa0b212d..2aa9512808 100644 --- a/tests/ut/cpp/device/ascend_kernel_runtime_test.cc +++ b/tests/ut/cpp/device/ascend_kernel_runtime_test.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" -#include "device/kernel_runtime.h" +#include "runtime/device/kernel_runtime.h" #include "./common.h" namespace mindspore { diff --git a/tests/ut/cpp/device/ascend_profiling_test.cc b/tests/ut/cpp/device/ascend_profiling_test.cc index 2829a5fd4a..f862d84c4a 100644 --- a/tests/ut/cpp/device/ascend_profiling_test.cc +++ b/tests/ut/cpp/device/ascend_profiling_test.cc @@ -18,12 +18,12 @@ #include "./prof_reporter.h" #include "common/common_test.h" -#include "device/ascend/profiling/profiling_manager.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" #include "./common.h" #define private public -#include "device/ascend/profiling/plugin_impl.h" +#include "runtime/device/ascend/profiling/plugin_impl.h" #undef private -#include "device/ascend/profiling/profiling_engine_impl.h" +#include "runtime/device/ascend/profiling/profiling_engine_impl.h" namespace mindspore { namespace device { diff --git a/tests/ut/cpp/ir/anf_test.cc b/tests/ut/cpp/ir/anf_test.cc index c649518e21..9b217a2321 100644 --- a/tests/ut/cpp/ir/anf_test.cc +++ b/tests/ut/cpp/ir/anf_test.cc @@ -19,7 +19,7 @@ #include "common/common_test.h" #include "ir/anf.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "./common.h" namespace mindspore { diff --git a/tests/ut/cpp/ir/base_test.cc b/tests/ut/cpp/ir/base_test.cc deleted file mode 100644 index 0b4e8a637b..0000000000 --- a/tests/ut/cpp/ir/base_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include "common/common_test.h" -#include "utils/any.h" -#include "ir/base.h" -#include "ir/anf.h" -#include "utils/log_adapter.h" - -namespace mindspore { - -class TestNode : public UT::Common { - public: - TestNode() {} -}; - -class ChildA : public Base { - public: - ChildA() {} - ~ChildA() {} - MS_DECLARE_PARENT(ChildA, Base); - std::string name() { return "ChildA"; } - std::size_t hash() const override { return 1; } -}; -class ChildAA : public ChildA { - public: - ChildAA() {} - ~ChildAA() {} - MS_DECLARE_PARENT(ChildAA, ChildA); - std::size_t hash() const override { return 1; } - std::string name() { return "ChildAA"; } -}; - -class ChildB : public Base { - public: - ChildB() {} - ~ChildB() {} - MS_DECLARE_PARENT(ChildB, Base); - std::size_t hash() const override { return 1; } - std::string name() { return "ChildB"; } -}; - -TEST_F(TestNode, test_dyn_cast) { - auto aa = std::make_shared(); - std::shared_ptr n = aa; - MS_LOG(INFO) << "aa ptr_name: " << aa->name(); - MS_LOG(INFO) << "aa type_name: " << aa->type_name(); - MS_LOG(INFO) << "n ptr_name: " << demangle(typeid(n).name()); - MS_LOG(INFO) << "n type_name: " << n->type_name(); - ASSERT_TRUE(n != nullptr); - ASSERT_EQ(std::string(n->type_name().c_str()), "ChildAA"); - auto a = dyn_cast(n); - MS_LOG(INFO) << "a ptr_name: " << a->name(); - MS_LOG(INFO) << "a type_name: " << a->type_name(); - ASSERT_TRUE(a != nullptr); - ASSERT_EQ(std::string(a->name()), "ChildA"); - ASSERT_EQ(std::string(a->type_name().c_str()), "ChildAA"); - auto b_null = dyn_cast(n); - ASSERT_TRUE(b_null == nullptr); - - ChildA* pa = cast(n.get()); - ASSERT_TRUE(pa != nullptr); - MS_LOG(INFO) << "a ptr_name: " << pa->name(); - MS_LOG(INFO) << "a type_name: " << pa->type_name(); -} - -TEST_F(TestNode, test_isa) { - auto a = std::make_shared(); - BasePtr n = a; - ASSERT_TRUE(n->isa() == true); - ASSERT_TRUE(n->isa() == false); - - auto aa = std::make_shared(); - n = aa; - ASSERT_TRUE(n->isa() == true); - ASSERT_TRUE(n->isa() == true); - - auto b = std::make_shared(); - n = b; - ASSERT_TRUE(n->isa() == true); - ASSERT_TRUE(n->isa() == false); - ASSERT_TRUE(n->isa() == false); -} - -} // namespace mindspore diff --git a/tests/ut/cpp/ir/clone_test.cc b/tests/ut/cpp/ir/clone_test.cc index bb8cae7fbb..20da3fb8b5 100644 --- a/tests/ut/cpp/ir/clone_test.cc +++ b/tests/ut/cpp/ir/clone_test.cc @@ -21,7 +21,7 @@ #include "ir/manager.h" #include "utils/log_adapter.h" #include "ir/func_graph_cloner.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "utils/graph_utils.h" #include "debug/draw.h" #include "./common.h" diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc index 04b584ec10..3e6d1a312c 100644 --- a/tests/ut/cpp/ir/manager_test.cc +++ b/tests/ut/cpp/ir/manager_test.cc @@ -18,8 +18,8 @@ #include "ir/dtype.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" -#include "pipeline/parse/parse.h" -#include "operator/ops.h" +#include "pipeline/jit/parse/parse.h" +#include "frontend/operator/ops.h" #include "utils/log_adapter.h" #include "debug/draw.h" #include "debug/label.h" diff --git a/tests/ut/cpp/ir/value_test.cc b/tests/ut/cpp/ir/value_test.cc index a71ef7a57f..b4ed5f438e 100644 --- a/tests/ut/cpp/ir/value_test.cc +++ b/tests/ut/cpp/ir/value_test.cc @@ -21,7 +21,7 @@ #include "common/common_test.h" #include "ir/value.h" -#include "pipeline/static_analysis/abstract_value.h" +#include "abstract/abstract_value.h" #include "utils/log_adapter.h" namespace mindspore { diff --git a/tests/ut/cpp/kernel/common_utils_test.cc b/tests/ut/cpp/kernel/common_utils_test.cc index 4bc05b5c05..83f7c59e52 100644 --- a/tests/ut/cpp/kernel/common_utils_test.cc +++ b/tests/ut/cpp/kernel/common_utils_test.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" -#include "kernel/common_utils.h" +#include "backend/kernel_compiler/common_utils.h" namespace mindspore { namespace kernel { diff --git a/tests/ut/cpp/kernel/cpu/sparse_apply_adam_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/sparse_apply_adam_cpu_kernel_test.cc index 2a6b80f9e7..e5cba86230 100644 --- a/tests/ut/cpp/kernel/cpu/sparse_apply_adam_cpu_kernel_test.cc +++ b/tests/ut/cpp/kernel/cpu/sparse_apply_adam_cpu_kernel_test.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" #define private public #define protected public -#include "kernel/cpu/sparse_apply_adam_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h" #undef private #undef protected @@ -58,9 +58,12 @@ class SparseApplyAdamCpuKernelTest : public UT::Common { inputs_.push_back(CreateKernelAddress(indices.data())); } - void CreateWorkspaceAddress(std::vector &new_grad, std::vector &new_indices, std::vector &m_t) { + void CreateWorkspaceAddress(std::vector &new_grad, std::vector &new_indices, std::vector &tmp_grad, + std::vector &tmp_indices, std::vector &m_t) { workspace_.push_back(CreateKernelAddress(new_grad.data())); workspace_.push_back(CreateKernelAddress(new_indices.data())); + workspace_.push_back(CreateKernelAddress(tmp_grad.data())); + workspace_.push_back(CreateKernelAddress(tmp_indices.data())); workspace_.push_back(CreateKernelAddress(m_t.data())); } @@ -95,8 +98,10 @@ TEST_F(SparseApplyAdamCpuKernelTest, dense_test) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); std::vector m_t(3 * 3 * 3); - CreateWorkspaceAddress(new_grad, new_indices, m_t); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices, m_t); sparse_adam_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.999684) < 1e-6); @@ -120,8 +125,10 @@ TEST_F(SparseApplyAdamCpuKernelTest, sparse_test1) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); std::vector m_t(3 * 3 * 3); - CreateWorkspaceAddress(new_grad, new_indices, m_t); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices, m_t); sparse_adam_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.999684) < 1e-6); @@ -149,8 +156,10 @@ TEST_F(SparseApplyAdamCpuKernelTest, sparse_test2) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); std::vector m_t(3 * 3 * 3); - CreateWorkspaceAddress(new_grad, new_indices, m_t); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices, m_t); sparse_adam_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.999715) < 1e-6); diff --git a/tests/ut/cpp/kernel/cpu/sparse_apply_ftrl_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/sparse_apply_ftrl_cpu_kernel_test.cc index c5c2394538..230c8cbf9e 100644 --- a/tests/ut/cpp/kernel/cpu/sparse_apply_ftrl_cpu_kernel_test.cc +++ b/tests/ut/cpp/kernel/cpu/sparse_apply_ftrl_cpu_kernel_test.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" #define private public #define protected public -#include "kernel/cpu/sparse_apply_ftrl_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h" #undef private #undef protected @@ -56,9 +56,12 @@ class SparseApplyFtrlCpuKernelTest : public UT::Common { inputs_.push_back(CreateKernelAddress(indices.data())); } - void CreateWorkspaceAddress(std::vector &new_grad, std::vector &new_indices) { + void CreateWorkspaceAddress(std::vector &new_grad, std::vector &new_indices, std::vector &tmp_grad, + std::vector &tmp_indices) { workspace_.push_back(CreateKernelAddress(new_grad.data())); workspace_.push_back(CreateKernelAddress(new_indices.data())); + workspace_.push_back(CreateKernelAddress(tmp_grad.data())); + workspace_.push_back(CreateKernelAddress(tmp_indices.data())); } std::vector var_; @@ -86,7 +89,9 @@ TEST_F(SparseApplyFtrlCpuKernelTest, dense_test) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_ftrl_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.291479) < 1e-6); @@ -110,7 +115,9 @@ TEST_F(SparseApplyFtrlCpuKernelTest, sparse_test1) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_ftrl_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.291479) < 1e-6); @@ -138,7 +145,9 @@ TEST_F(SparseApplyFtrlCpuKernelTest, sparse_test2) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_ftrl_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3; ++i) { EXPECT_EQ(var_[i], 1.0); diff --git a/tests/ut/cpp/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel_test.cc index 1765ed896f..a829ead90e 100644 --- a/tests/ut/cpp/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel_test.cc +++ b/tests/ut/cpp/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel_test.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" #define private public #define protected public -#include "kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h" #undef private #undef protected @@ -58,9 +58,12 @@ class SparseApplyLazyAdamCpuKernelTest : public UT::Common { inputs_.push_back(CreateKernelAddress(indices.data())); } - void CreateWorkspaceAddress(std::vector &new_grad, std::vector &new_indices) { + void CreateWorkspaceAddress(std::vector &new_grad, std::vector &new_indices, std::vector &tmp_grad, + std::vector &tmp_indices) { workspace_.push_back(CreateKernelAddress(new_grad.data())); workspace_.push_back(CreateKernelAddress(new_indices.data())); + workspace_.push_back(CreateKernelAddress(tmp_grad.data())); + workspace_.push_back(CreateKernelAddress(tmp_indices.data())); } std::vector var_; @@ -94,7 +97,9 @@ TEST_F(SparseApplyLazyAdamCpuKernelTest, dense_test) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_lazy_adam_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.999684) < 1e-6); @@ -118,7 +123,9 @@ TEST_F(SparseApplyLazyAdamCpuKernelTest, sparse_test1) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_lazy_adam_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.999684) < 1e-6); @@ -146,7 +153,9 @@ TEST_F(SparseApplyLazyAdamCpuKernelTest, sparse_test2) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_lazy_adam_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3; ++i) { EXPECT_EQ(var_[i], 1.0); diff --git a/tests/ut/cpp/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel_test.cc index 23f66db58c..64bd5d3ef3 100644 --- a/tests/ut/cpp/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel_test.cc +++ b/tests/ut/cpp/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel_test.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" #define private public #define protected public -#include "kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h" #undef private #undef protected @@ -54,9 +54,12 @@ class SparseApplyProximalAdagradCpuKernelTest : public UT::Common { inputs_.push_back(CreateKernelAddress(indices.data())); } - void CreateWorkspaceAddress(std::vector &new_grad, std::vector &new_indices) { + void CreateWorkspaceAddress(std::vector &new_grad, std::vector &new_indices, std::vector &tmp_grad, + std::vector &tmp_indices) { workspace_.push_back(CreateKernelAddress(new_grad.data())); workspace_.push_back(CreateKernelAddress(new_indices.data())); + workspace_.push_back(CreateKernelAddress(tmp_grad.data())); + workspace_.push_back(CreateKernelAddress(tmp_indices.data())); } std::vector var_; @@ -85,7 +88,9 @@ TEST_F(SparseApplyProximalAdagradCpuKernelTest, dense_test) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_proximal_adagrad_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.9929289) < 1e-6); @@ -108,7 +113,9 @@ TEST_F(SparseApplyProximalAdagradCpuKernelTest, sparse_test1) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_proximal_adagrad_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3; ++i) { EXPECT_TRUE(std::fabs(var_[i] - 0.9929289) < 1e-6); @@ -135,7 +142,9 @@ TEST_F(SparseApplyProximalAdagradCpuKernelTest, sparse_test2) { CreateInputAddress(indices); std::vector new_grad(3 * 3 * 3); std::vector new_indices(3); - CreateWorkspaceAddress(new_grad, new_indices); + std::vector tmp_grad(3 * 3 * 3); + std::vector tmp_indices(3); + CreateWorkspaceAddress(new_grad, new_indices, tmp_grad, tmp_indices); sparse_proximal_adagrad_->Launch(inputs_, workspace_, outputs_); for (size_t i = 0; i < 3 * 3; ++i) { EXPECT_EQ(var_[i], 1.0); diff --git a/tests/ut/cpp/mindrecord/ut_common.h b/tests/ut/cpp/mindrecord/ut_common.h index 8b244bf87a..ee943ab88e 100644 --- a/tests/ut/cpp/mindrecord/ut_common.h +++ b/tests/ut/cpp/mindrecord/ut_common.h @@ -25,10 +25,10 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_index.h" -#include "mindrecord/include/shard_header.h" -#include "mindrecord/include/shard_index_generator.h" -#include "mindrecord/include/shard_writer.h" +#include "minddata/mindrecord/include/shard_index.h" +#include "minddata/mindrecord/include/shard_header.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_writer.h" using json = nlohmann::json; using std::ifstream; using std::pair; diff --git a/tests/ut/cpp/mindrecord/ut_shard.cc b/tests/ut/cpp/mindrecord/ut_shard.cc index b8c229e82f..11492e9f28 100644 --- a/tests/ut/cpp/mindrecord/ut_shard.cc +++ b/tests/ut/cpp/mindrecord/ut_shard.cc @@ -23,10 +23,10 @@ #include "configuration.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index.h" -#include "mindrecord/include/shard_header.h" -#include "mindrecord/include/shard_statistics.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_index.h" +#include "minddata/mindrecord/include/shard_header.h" +#include "minddata/mindrecord/include/shard_statistics.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_header_test.cc b/tests/ut/cpp/mindrecord/ut_shard_header_test.cc index cea71c34b7..2ff3d1655d 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_header_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_header_test.cc @@ -29,13 +29,13 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_writer.h" -#include "mindrecord/include/shard_index.h" -#include "mindrecord/include/shard_header.h" -#include "mindrecord/include/shard_schema.h" -#include "mindrecord/include/shard_statistics.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_writer.h" +#include "minddata/mindrecord/include/shard_index.h" +#include "minddata/mindrecord/include/shard_header.h" +#include "minddata/mindrecord/include/shard_schema.h" +#include "minddata/mindrecord/include/shard_statistics.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc index 140fff4166..8e264aafa0 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc @@ -29,10 +29,10 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index_generator.h" -#include "mindrecord/include/shard_index.h" -#include "mindrecord/include/shard_statistics.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_index.h" +#include "minddata/mindrecord/include/shard_statistics.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 7fe60c3bfa..4501ea0800 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -24,11 +24,11 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_category.h" -#include "mindrecord/include/shard_pk_sample.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_shuffle.h" +#include "minddata/mindrecord/include/shard_category.h" +#include "minddata/mindrecord/include/shard_pk_sample.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" #include "ut_common.h" using mindspore::LogStream; diff --git a/tests/ut/cpp/mindrecord/ut_shard_page_test.cc b/tests/ut/cpp/mindrecord/ut_shard_page_test.cc index dabd3d819f..a7e444c80f 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_page_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_page_test.cc @@ -21,7 +21,7 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_page.h" +#include "minddata/mindrecord/include/shard_page.h" #include "ut_common.h" using json = nlohmann::json; diff --git a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc index c532fe28b8..8b5eb2cf69 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc @@ -24,8 +24,8 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_sample.h" #include "ut_common.h" using mindspore::LogStream; diff --git a/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc b/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc index 8d9654a5ef..6863a25791 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc @@ -29,9 +29,9 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_page.h" -#include "mindrecord/include/shard_schema.h" -#include "mindrecord/include/shard_statistics.h" +#include "minddata/mindrecord/include/shard_page.h" +#include "minddata/mindrecord/include/shard_schema.h" +#include "minddata/mindrecord/include/shard_statistics.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc index 3fa6812352..6b99e44d89 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc @@ -30,7 +30,7 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_segment.h" +#include "minddata/mindrecord/include/shard_segment.h" #include "ut_common.h" using mindspore::LogStream; diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc index 159efbf2f8..046b4f93d5 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc @@ -24,9 +24,9 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_writer.h" -#include "mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_writer.h" +#include "minddata/mindrecord/include/shard_index_generator.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/operator/cc_implementations_test.cc b/tests/ut/cpp/operator/cc_implementations_test.cc index bac885db88..4bc5aea964 100644 --- a/tests/ut/cpp/operator/cc_implementations_test.cc +++ b/tests/ut/cpp/operator/cc_implementations_test.cc @@ -18,7 +18,7 @@ #include #include "common/common_test.h" -#include "operator/cc_implementations.h" +#include "frontend/operator/cc_implementations.h" namespace mindspore { namespace prim { diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index 8ca318300a..a2108998bc 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -18,10 +18,10 @@ #include "common/common_test.h" #include "ir/anf.h" #include "ir/value.h" -#include "operator/composite/composite.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/prim.h" -#include "pipeline/static_analysis/abstract_function.h" +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/abstract_function.h" #include "debug/trace.h" namespace mindspore { @@ -127,11 +127,17 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { try { trace::ClearTraceStack(); engine_->Run(tupleSliceGraphPtr, args_spec_list); - FAIL() << "Excepted exception :Args type is wrong"; + FAIL() << "Excepted exception: Args type is wrong"; } catch (pybind11::type_error const &err) { ASSERT_TRUE(true); + } catch (std::runtime_error const &err) { + if (std::strstr(err.what(), "TypeError") != nullptr) { + ASSERT_TRUE(true); + } else { + FAIL() << "Excepted exception: Args type is wrong, message: " << err.what(); + } } catch (...) { - FAIL() << "Excepted exception :Args type is wrong"; + FAIL() << "Excepted exception: Args type is wrong"; } } diff --git a/tests/ut/cpp/operator/grad_implementations_test.cc b/tests/ut/cpp/operator/grad_implementations_test.cc index e9035e63b6..f55553ab72 100644 --- a/tests/ut/cpp/operator/grad_implementations_test.cc +++ b/tests/ut/cpp/operator/grad_implementations_test.cc @@ -20,7 +20,7 @@ #include "ir/value.h" #include "ir/manager.h" #include "common/common_test.h" -#include "optimizer/ad/dfunctor.h" +#include "frontend/optimizer/ad/dfunctor.h" #include "debug/draw.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index 1d1389b54a..789b1cab25 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -19,8 +19,8 @@ #include "common/common_test.h" #include "ir/value.h" -#include "ir/primitive.h" -#include "operator/ops.h" +#include "ir/primitive_py.h" +#include "frontend/operator/ops.h" #include "./common.h" namespace mindspore { diff --git a/tests/ut/cpp/operator/prim2func_test.cc b/tests/ut/cpp/operator/prim2func_test.cc index 8f7c73a064..3952128b52 100644 --- a/tests/ut/cpp/operator/prim2func_test.cc +++ b/tests/ut/cpp/operator/prim2func_test.cc @@ -21,7 +21,7 @@ #include "ir/anf.h" #include "ir/dtype.h" -#include "operator/prim_to_function.h" +#include "frontend/operator/prim_to_function.h" namespace mindspore { namespace prim { diff --git a/tests/ut/cpp/optimizer/ad/ad_test.cc b/tests/ut/cpp/optimizer/ad/ad_test.cc index 34612b5474..3f861d3604 100644 --- a/tests/ut/cpp/optimizer/ad/ad_test.cc +++ b/tests/ut/cpp/optimizer/ad/ad_test.cc @@ -16,7 +16,7 @@ #include #include -#include "optimizer/ad/grad.h" +#include "frontend/optimizer/ad/grad.h" #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "ir/manager.h" @@ -24,10 +24,10 @@ #include "ir/func_graph_cloner.h" #include "utils/log_adapter.h" #include "utils/graph_utils.h" -#include "pipeline/resource.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" namespace mindspore { namespace ad { diff --git a/tests/ut/cpp/optimizer/cconv_test.cc b/tests/ut/cpp/optimizer/cconv_test.cc index 8bd6957e85..c004409058 100644 --- a/tests/ut/cpp/optimizer/cconv_test.cc +++ b/tests/ut/cpp/optimizer/cconv_test.cc @@ -20,7 +20,7 @@ #include "ir/func_graph_cloner.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/optimizer/clean_test.cc b/tests/ut/cpp/optimizer/clean_test.cc index c4f393c233..82bec1b5a8 100644 --- a/tests/ut/cpp/optimizer/clean_test.cc +++ b/tests/ut/cpp/optimizer/clean_test.cc @@ -19,9 +19,9 @@ #include "common/py_func_graph_fetcher.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" -#include "optimizer/clean.h" +#include "frontend/optimizer/clean.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index bc8561f171..751b301283 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -25,11 +25,11 @@ #include "ir/manager.h" #include "ir/value.h" #include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "pipeline/resource.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "pipeline/jit/resource.h" #include "debug/draw.h" -#include "pipeline/parse/data_converter.h" +#include "pipeline/jit/parse/data_converter.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc index 2428d0dddb..c329adc4a5 100644 --- a/tests/ut/cpp/optimizer/opt_test.cc +++ b/tests/ut/cpp/optimizer/opt_test.cc @@ -22,13 +22,13 @@ #include "ir/anf.h" #include "ir/visitor.h" #include "ir/func_graph_cloner.h" -#include "optimizer/opt.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/arithmetic_simplify.h" +#include "frontend/optimizer/opt.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/arithmetic_simplify.h" #include "debug/draw.h" -#include "operator/ops.h" -#include "optimizer/cse.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/cse.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/optimizer/optimizer_test.cc b/tests/ut/cpp/optimizer/optimizer_test.cc index ca7c589d47..c5c99531e4 100644 --- a/tests/ut/cpp/optimizer/optimizer_test.cc +++ b/tests/ut/cpp/optimizer/optimizer_test.cc @@ -20,10 +20,10 @@ #include "common/py_func_graph_fetcher.h" #include "ir/anf.h" -#include "operator/ops.h" -#include "optimizer/cse.h" -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/cse.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc b/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc index 0462993672..a500afc859 100644 --- a/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc @@ -15,12 +15,12 @@ */ #include "common/common_test.h" -#include "parallel/device_manager.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/ops_info/matmul_info.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/ops_info/tmp_identity_info.h" -#include "parallel/auto_parallel/dp_algo_costmodel.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/ops_info/matmul_info.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/ops_info/tmp_identity_info.h" +#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc index 291539c27d..190a189a2d 100644 --- a/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc @@ -16,9 +16,9 @@ #include "common/common_test.h" #include "ir/dtype/number.h" -#include "parallel/device_manager.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/ops_info/matmul_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" +#include "frontend/parallel/ops_info/matmul_info.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc index 78d05c7235..7d63f03179 100644 --- a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc @@ -15,9 +15,9 @@ */ #include "common/common_test.h" -#include "parallel/device_manager.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/ops_info/matmul_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/ops_info/matmul_info.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc index 919c5b43ec..b9b6bb67d9 100644 --- a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc @@ -15,10 +15,10 @@ */ #include -#include "parallel/tensor_layout/tensor_layout.h" -#include "parallel/tensor_layout/tensor_info.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/tensor_layout/tensor_info.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/device_manager.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc b/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc index 1eb65b468f..7942fa2a10 100644 --- a/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc @@ -15,9 +15,9 @@ */ #include "common/common_test.h" -#include "parallel/auto_parallel/rec_core/rec_tensor.h" -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/auto_parallel/rec_core/rec_partition.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" #include #include "ir/value.h" diff --git a/tests/ut/cpp/parallel/device_manager_test.cc b/tests/ut/cpp/parallel/device_manager_test.cc index 056896f514..0c048d647b 100644 --- a/tests/ut/cpp/parallel/device_manager_test.cc +++ b/tests/ut/cpp/parallel/device_manager_test.cc @@ -15,9 +15,9 @@ */ #include #include "common/common_test.h" -#include "parallel/device.h" -#include "parallel/device_manager.h" -#include "parallel/group_manager.h" +#include "frontend/parallel/device.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/group_manager.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/device_matrix_test.cc b/tests/ut/cpp/parallel/device_matrix_test.cc index 877a211df8..57a438e76e 100644 --- a/tests/ut/cpp/parallel/device_matrix_test.cc +++ b/tests/ut/cpp/parallel/device_matrix_test.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/device_matrix.h" +#include "frontend/parallel/device_matrix.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/group_manager_test.cc b/tests/ut/cpp/parallel/group_manager_test.cc index e3d2b3a364..fa4abfcb7e 100644 --- a/tests/ut/cpp/parallel/group_manager_test.cc +++ b/tests/ut/cpp/parallel/group_manager_test.cc @@ -14,10 +14,10 @@ * limitations under the License. */ #include -#include "parallel/device_manager.h" +#include "frontend/parallel/device_manager.h" #include "common/common_test.h" -#include "parallel/device.h" -#include "parallel/group_manager.h" +#include "frontend/parallel/device.h" +#include "frontend/parallel/group_manager.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/activation_info_test.cc b/tests/ut/cpp/parallel/ops_info/activation_info_test.cc index a9fe9b4c48..5f09de9e48 100644 --- a/tests/ut/cpp/parallel/ops_info/activation_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/activation_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/activation_test.cc b/tests/ut/cpp/parallel/ops_info/activation_test.cc index 9af7203799..9d129b7a18 100644 --- a/tests/ut/cpp/parallel/ops_info/activation_test.cc +++ b/tests/ut/cpp/parallel/ops_info/activation_test.cc @@ -18,9 +18,9 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc b/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc index e54d1f2423..e49ed4e79d 100644 --- a/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc b/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc index 947ad60cca..125723868a 100644 --- a/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc +++ b/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc b/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc index 503edf2eda..029e0f2dc6 100644 --- a/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/get_next_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/get_next_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc b/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc index b59481e1f6..7037a85699 100644 --- a/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/l2_normalize_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/l2_normalize_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc b/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc index cf5a4239a2..8de5c07226 100644 --- a/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc index f710f51265..2d5676f211 100644 --- a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc @@ -18,11 +18,11 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/matmul_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" -#include "parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/matmul_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc index 07d150a294..074e4582f0 100644 --- a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/onehot_info.h" -#include "parallel/device_manager.h" -#include "parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/onehot_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc index c89bf97fb3..769d5bec45 100644 --- a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc +++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/onehot_info.h" -#include "parallel/device_manager.h" -#include "parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/onehot_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/pow_info_test.cc b/tests/ut/cpp/parallel/ops_info/pow_info_test.cc index 7b37a90fd8..f582640db8 100644 --- a/tests/ut/cpp/parallel/ops_info/pow_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/pow_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/prelu_test.cc b/tests/ut/cpp/parallel/ops_info/prelu_test.cc index d6db1b8460..1d4cf5eff0 100644 --- a/tests/ut/cpp/parallel/ops_info/prelu_test.cc +++ b/tests/ut/cpp/parallel/ops_info/prelu_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/prelu_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/prelu_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc b/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc index a1fe46ca33..64ba6af70b 100644 --- a/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc +++ b/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc @@ -18,11 +18,11 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/reduce_method_info.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/reduce_method_info.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/reshape_test.cc b/tests/ut/cpp/parallel/ops_info/reshape_test.cc index fb60c6d250..8cc8390e9a 100644 --- a/tests/ut/cpp/parallel/ops_info/reshape_test.cc +++ b/tests/ut/cpp/parallel/ops_info/reshape_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/reshape_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/reshape_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc b/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc index 03634b9a6f..d370c168c9 100644 --- a/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/loss_info.h" -#include "parallel/device_manager.h" -#include "parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/loss_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc b/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc index bba6e89626..9c4205672b 100644 --- a/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc b/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc index a892c5c84a..2be6c5bf7f 100644 --- a/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc index 42d292c605..b523652fcb 100644 --- a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc index eabac51e17..461a27d4ed 100644 --- a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc @@ -15,10 +15,10 @@ */ #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/device_manager.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/ops_info/tmp_identity_info.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/ops_info/tmp_identity_info.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/transpose_test.cc b/tests/ut/cpp/parallel/ops_info/transpose_test.cc index 991ec47820..fe5cbb01b3 100644 --- a/tests/ut/cpp/parallel/ops_info/transpose_test.cc +++ b/tests/ut/cpp/parallel/ops_info/transpose_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/transpose_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/transpose_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/step_auto_parallel_test.cc b/tests/ut/cpp/parallel/step_auto_parallel_test.cc index a1474ca244..6cf7ec66c6 100644 --- a/tests/ut/cpp/parallel/step_auto_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_auto_parallel_test.cc @@ -14,12 +14,12 @@ * limitations under the License. */ #include "common/common_test.h" -#include "parallel/step_parallel.h" -#include "parallel/step_auto_parallel.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/static_analysis.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/step_auto_parallel.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/static_analysis/static_analysis.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc index d8f8681a34..5657db8790 100644 --- a/tests/ut/cpp/parallel/step_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_parallel_test.cc @@ -14,12 +14,12 @@ * limitations under the License. */ #include "common/common_test.h" -#include "parallel/step_parallel.h" -#include "parallel/graph_util/generate_graph.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/graph_util/generate_graph.h" #include "common/py_func_graph_fetcher.h" #include "debug/draw.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/static_analysis.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/static_analysis/static_analysis.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/strategy_test.cc b/tests/ut/cpp/parallel/strategy_test.cc index 9a2f92f018..c13b71944e 100644 --- a/tests/ut/cpp/parallel/strategy_test.cc +++ b/tests/ut/cpp/parallel/strategy_test.cc @@ -17,7 +17,7 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" +#include "frontend/parallel/strategy.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc b/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc index 2ba8cc9dfc..b80f199035 100644 --- a/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc @@ -17,10 +17,10 @@ #include #include "common/common_test.h" #include "ir/value.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/matmul_info.h" -#include "parallel/device_manager.h" -#include "parallel/tensor_layout/construct_operator.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/matmul_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/construct_operator.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc b/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc index 5291e2f48d..4ddc130a45 100644 --- a/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc @@ -17,8 +17,8 @@ #include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/tensor_layout/tensor_layout.h" -#include "parallel/tensor_layout/redistribution_layout_transfer.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h" #include "util_layout_gen_test.h" namespace mindspore { diff --git a/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc b/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc index 1b1dd4af04..f6caad2f9d 100644 --- a/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc @@ -16,8 +16,8 @@ #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/tensor_layout/redistribution_operator_infer.h" -#include "parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h" +#include "frontend/parallel/device_manager.h" #include "util_layout_gen_test.h" namespace mindspore { diff --git a/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc b/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc index 9d6152721e..11f471ea33 100644 --- a/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc @@ -17,8 +17,8 @@ #include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/tensor_layout/tensor_layout.h" -#include "parallel/tensor_layout/reshape_layout_transfer.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" #include "util_layout_gen_test.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc b/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc index b5e2ea3e5b..824ab876cd 100644 --- a/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/tensor_layout/shape_util.h" +#include "frontend/parallel/tensor_layout/shape_util.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc b/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc index bae05d650a..15fb16f088 100644 --- a/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc @@ -17,7 +17,7 @@ #include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc b/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc index 572763faa3..40a4017c4b 100644 --- a/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc @@ -17,7 +17,7 @@ #include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc index 6f5c1e49ed..330b571ae7 100644 --- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc +++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc @@ -21,7 +21,7 @@ #include #include #include -#include "parallel/tensor_layout/shape_util.h" +#include "frontend/parallel/tensor_layout/shape_util.h" #include "common/common_test.h" using std::pow; diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h index a359cadbea..c16a1fc6d4 100644 --- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h +++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h @@ -20,7 +20,7 @@ #include #include -#include "parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/virtual_dataset_test.cc b/tests/ut/cpp/parallel/virtual_dataset_test.cc index 1d3ff081c7..4cafdebc17 100644 --- a/tests/ut/cpp/parallel/virtual_dataset_test.cc +++ b/tests/ut/cpp/parallel/virtual_dataset_test.cc @@ -17,10 +17,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/virtual_dataset_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/virtual_dataset_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/pipeline/parse/parser_abnormal_test.cc b/tests/ut/cpp/pipeline/parse/parser_abnormal_test.cc index 3c97cfb203..2d21b591ea 100644 --- a/tests/ut/cpp/pipeline/parse/parser_abnormal_test.cc +++ b/tests/ut/cpp/pipeline/parse/parser_abnormal_test.cc @@ -19,7 +19,7 @@ #include "common/py_func_graph_fetcher.h" #include "utils/log_adapter.h" #include "utils/profile.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/pipeline/parse/parser_class_test.cc b/tests/ut/cpp/pipeline/parse/parser_class_test.cc index dcedc32b1b..8d9cc8ebc8 100644 --- a/tests/ut/cpp/pipeline/parse/parser_class_test.cc +++ b/tests/ut/cpp/pipeline/parse/parser_class_test.cc @@ -19,7 +19,7 @@ #include "common/py_func_graph_fetcher.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/pipeline/parse/parser_integrate_test.cc b/tests/ut/cpp/pipeline/parse/parser_integrate_test.cc index fd8438503f..1f54298a81 100644 --- a/tests/ut/cpp/pipeline/parse/parser_integrate_test.cc +++ b/tests/ut/cpp/pipeline/parse/parser_integrate_test.cc @@ -18,7 +18,7 @@ #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/pipeline/parse/parser_primitive_test.cc b/tests/ut/cpp/pipeline/parse/parser_primitive_test.cc index adc09cca32..937ad1fe5e 100644 --- a/tests/ut/cpp/pipeline/parse/parser_primitive_test.cc +++ b/tests/ut/cpp/pipeline/parse/parser_primitive_test.cc @@ -19,7 +19,7 @@ #include "common/py_func_graph_fetcher.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/pipeline/parse/parser_test.cc b/tests/ut/cpp/pipeline/parse/parser_test.cc index 4d7731dfd1..f1d9087110 100644 --- a/tests/ut/cpp/pipeline/parse/parser_test.cc +++ b/tests/ut/cpp/pipeline/parse/parser_test.cc @@ -19,7 +19,7 @@ #include "common/py_func_graph_fetcher.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/pipeline/parse/resolve_test.cc b/tests/ut/cpp/pipeline/parse/resolve_test.cc index 8ade92bb34..5a2d0ebd7f 100644 --- a/tests/ut/cpp/pipeline/parse/resolve_test.cc +++ b/tests/ut/cpp/pipeline/parse/resolve_test.cc @@ -19,7 +19,7 @@ #include "common/py_func_graph_fetcher.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/pipeline/resource_test.cc b/tests/ut/cpp/pipeline/resource_test.cc index 09bd2060dc..b6be393652 100644 --- a/tests/ut/cpp/pipeline/resource_test.cc +++ b/tests/ut/cpp/pipeline/resource_test.cc @@ -18,9 +18,9 @@ #include "common/common_test.h" #include "utils/log_adapter.h" -#include "pipeline/resource.h" +#include "pipeline/jit/resource.h" #include "ir/primitive.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" namespace mindspore { namespace pipeline { diff --git a/tests/ut/cpp/pipeline/static_analysis/abstract_test.cc b/tests/ut/cpp/pipeline/static_analysis/abstract_test.cc deleted file mode 100644 index 93baf86c3e..0000000000 --- a/tests/ut/cpp/pipeline/static_analysis/abstract_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include - -#include "common/common_test.h" - -#include "pipeline/static_analysis/static_analysis.h" -#include "pipeline/static_analysis/utils.h" -#include "pipeline/static_analysis/prim.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/resolve.h" -#include "pipeline/parse/data_converter.h" -#include "operator/ops.h" - -namespace mindspore { -namespace abstract { - -class TestAbstract : public UT::Common { - public: - TestAbstract() {} - virtual void SetUp() {} - virtual void TearDown() {} -}; - -TEST_F(TestAbstract, TestParseDataClass) { - py::object fn = parse::python_adapter::GetPyFn("gtest_input.pipeline.parse.parser_test", "TestFoo"); - - ClassPtr cls_ptr = parse::ParseDataClass(fn); - ASSERT_TRUE(nullptr != cls_ptr); - std::shared_ptr cls = dyn_cast(cls_ptr); - ASSERT_TRUE(nullptr != cls); - - MS_LOG(INFO) << "" << cls->ToString(); - ASSERT_EQ(cls->tag(), Named(std::string("TestFoo"))); - - ClassAttrVector attributes = cls->GetAttributes(); - ASSERT_EQ(attributes.size(), 2); - for (auto &v : attributes) { - if (v.first == std::string("x")) { - ASSERT_TRUE(nullptr != dyn_cast(v.second)); - } - if (v.first == std::string("y")) { - ASSERT_TRUE(nullptr != dyn_cast(v.second)); - } - } - - std::unordered_map methods = cls->methods(); - ASSERT_EQ(methods.size(), 4); - int counts = 0; - for (auto &v : methods) { - if (v.first == std::string("inf")) { - counts++; - } - MS_LOG(INFO) << "" << v.first; - } - ASSERT_EQ(counts, 1); - - ValuePtr obj = std::make_shared(fn, "TestFoo"); - - ValueNodePtr fn_node = NewValueNode(obj); - AnfNodeConfigPtr fn_conf = std::make_shared(nullptr, fn_node, nullptr); - AbstractBasePtr foo = ToAbstract(obj, nullptr, fn_conf); - ASSERT_TRUE(foo != nullptr); - - AbstractBasePtr abstract_x = FromValue(1.1, true); - AbstractBasePtr abstract_y = FromValue(5, true); - - auto partical_func = dyn_cast(foo); - AbstractBasePtrList args_spec_list = partical_func->args(); - ASSERT_GT(args_spec_list.size(), 0); - AbstractScalarPtr abs_scalar = dyn_cast(args_spec_list[0]); - - AbstractBasePtrList args_list = {abs_scalar, abstract_x, abstract_y}; - - StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim::kPrimMakeRecord); - ASSERT_TRUE(nullptr != eval_impl); - - AbstractBasePtr new_cls = eval_impl(nullptr, prim::kPrimMakeRecord, args_list); - ASSERT_TRUE(nullptr != new_cls); -} - -} // namespace abstract -} // namespace mindspore diff --git a/tests/ut/cpp/pipeline/static_analysis/data_test.cc b/tests/ut/cpp/pipeline/static_analysis/data_test.cc index 61a22bbe5f..fb9d8b1f7e 100644 --- a/tests/ut/cpp/pipeline/static_analysis/data_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/data_test.cc @@ -18,9 +18,9 @@ #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/utils.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" +#include "abstract/utils.h" namespace mindspore { namespace abstract { diff --git a/tests/ut/cpp/pipeline/static_analysis/dshape_test.cc b/tests/ut/cpp/pipeline/static_analysis/dshape_test.cc deleted file mode 100644 index ae18f7730b..0000000000 --- a/tests/ut/cpp/pipeline/static_analysis/dshape_test.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include - -#include "common/common_test.h" - -#include "pipeline/static_analysis/dshape.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace abstract { -class TestDShape : public UT::Common { - public: - Shape shp_1; - Shape shp_2; - Shape shp_3; - Shape shp_4; - - NoShape shp_noshp_1; - NoShape shp_noshp_2; - - TupleShape shp_tuple_1; - TupleShape shp_tuple_2; - TupleShape shp_tuple_3; - TupleShape shp_tuple_4; - TestDShape() - : shp_1({1, 1}), - shp_2({1, 1}), - shp_3({1, 2}), - shp_4({1}), - - shp_noshp_1(), - shp_noshp_2(), - - shp_tuple_1({NoShape().Clone(), Shape({1, 1}).Clone()}), - shp_tuple_2({NoShape().Clone(), Shape({1, 1, 1}).Clone()}), - shp_tuple_3({NoShape().Clone(), Shape({1, 2, 1}).Clone()}), - shp_tuple_4({NoShape().Clone()}) {} -}; - -TEST_F(TestDShape, EqualTest) { - ASSERT_TRUE(shp_1 == shp_2); - ASSERT_FALSE(shp_1 == shp_3); - ASSERT_FALSE(shp_1 == shp_noshp_1); - - ASSERT_TRUE(shp_noshp_1 == shp_noshp_2); - - ASSERT_FALSE(shp_tuple_1 == shp_1); - ASSERT_FALSE(shp_tuple_1 == shp_tuple_2); - ASSERT_FALSE(shp_tuple_1 == shp_tuple_4); -} -TEST_F(TestDShape, ToString) { - ASSERT_EQ(shp_3.ToString(), "(1, 2)"); - ASSERT_EQ(shp_noshp_1.ToString(), "NoShape"); - ASSERT_EQ(shp_tuple_2.ToString(), "TupleShape(NoShape, (1, 1, 1))"); -} - -TEST_F(TestDShape, Clone) { - ASSERT_EQ(*shp_3.Clone(), shp_3); - ASSERT_EQ(*shp_noshp_1.Clone(), shp_noshp_1); - ASSERT_EQ(*shp_tuple_2.Clone(), shp_tuple_2); -} - -} // namespace abstract -} // namespace mindspore diff --git a/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc b/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc index eebe6c252b..664f353faa 100644 --- a/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "pipeline/static_analysis/evaluator.h" -#include "pipeline/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/evaluator.h" +#include "pipeline/jit/static_analysis/prim.h" #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pipeline/static_analysis/helper.cc b/tests/ut/cpp/pipeline/static_analysis/helper.cc index db697e95e0..ebf8c233e2 100644 --- a/tests/ut/cpp/pipeline/static_analysis/helper.cc +++ b/tests/ut/cpp/pipeline/static_analysis/helper.cc @@ -16,7 +16,7 @@ #include "pipeline/static_analysis/helper.h" -#include "pipeline/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/prim.h" namespace mindspore { namespace abstract { diff --git a/tests/ut/cpp/pipeline/static_analysis/helper.h b/tests/ut/cpp/pipeline/static_analysis/helper.h index 7ca902a1e9..44c647779e 100644 --- a/tests/ut/cpp/pipeline/static_analysis/helper.h +++ b/tests/ut/cpp/pipeline/static_analysis/helper.h @@ -17,7 +17,7 @@ #ifndef TESTS_UT_PIPELINE_STATIC_ANALYSIS_HELPER_H_ #define TESTS_UT_PIPELINE_STATIC_ANALYSIS_HELPER_H_ -#include "pipeline/static_analysis/evaluator.h" +#include "pipeline/jit/static_analysis/evaluator.h" namespace mindspore { namespace abstract { diff --git a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc index 04a14a0f29..8ebea4d212 100644 --- a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc @@ -21,9 +21,9 @@ #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "ir/manager.h" -#include "pipeline/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/prim.h" #include "pipeline/static_analysis/helper.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "debug/draw.h" #include "ir/tensor.h" #include "utils/symbolic.h" diff --git a/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc b/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc index 23ea55f8f7..e32a86d9be 100644 --- a/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc @@ -20,8 +20,8 @@ #include "common/py_func_graph_fetcher.h" #include "ir/manager.h" -#include "pipeline/static_analysis/prim.h" -#include "pipeline/static_analysis/program_specialize.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/program_specialize.h" #include "pipeline/static_analysis/helper.h" #include "utils/log_adapter.h" #include "utils/graph_utils.h" diff --git a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc index 8a58969e12..78d3a7083a 100644 --- a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc @@ -16,16 +16,16 @@ #include #include -#include "pipeline/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/prim.h" #include "pipeline/static_analysis/helper.h" #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "ir/manager.h" #include "ir/tensor.h" -#include "operator/ops.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/resource.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/resource.h" #include "debug/draw.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/pipeline/static_analysis/utils_test.cc b/tests/ut/cpp/pipeline/static_analysis/utils_test.cc deleted file mode 100644 index dceef71b02..0000000000 --- a/tests/ut/cpp/pipeline/static_analysis/utils_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pipeline/static_analysis/utils.h" - -#include "common/common_test.h" -#include "pipeline/static_analysis/static_analysis.h" - -namespace mindspore { -namespace abstract { -class TestUtils : public UT::Common { - public: - TestUtils() {} - virtual void SetUp() {} - virtual void TearDown() {} -}; - -TEST_F(TestUtils, test_join) { - // AbstractScalar - AbstractBasePtr abs_s1 = FromValue(1, false); - AbstractBasePtr abs_s2 = FromValue(2, false); - AbstractBasePtr abs_s_anything = FromValue(2, true); - - AbstractBasePtr res_s1 = abs_s1->Join(abs_s2); - ASSERT_EQ(*res_s1, *abs_s_anything); - - // AbstractTuple join; - std::vector list1 = {1, 2, 3, 4, 5}; - std::vector list2 = {5, 4, 3, 2, 1}; - AbstractBasePtr abs_t1 = FromValue(list1, true); - AbstractBasePtr abs_t2 = FromValue(list2, true); - - AbstractBasePtr res_t1 = abs_t1->Join(abs_t2); - ASSERT_EQ(res_t1, abs_t1); - - abs_s1 = FromValue(1, false); - - AbstractBasePtr t1 = std::make_shared(AbstractBasePtrList({abs_s1, abs_s_anything})); - AbstractBasePtr t2 = std::make_shared(AbstractBasePtrList({abs_s1, abs_s_anything})); - AbstractBasePtr t3 = std::make_shared(AbstractBasePtrList({abs_s_anything, abs_s_anything})); - - res_t1 = t1->Join(t2); - ASSERT_EQ(res_t1, t1); - - res_t1 = t1->Join(t3); - ASSERT_EQ(*res_t1, *t3); - - res_t1 = t3->Join(t1); - ASSERT_EQ(res_t1, t3); -} - -} // namespace abstract -} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc index 483c144930..58b810a3e1 100644 --- a/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc @@ -17,23 +17,23 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" -#include "kernel/kernel.h" -#include "device/kernel_info.h" -#include "pre_activate/common/optimizer.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" -#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h" +#include "backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc index e4ab2431b7..ba64c206af 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc @@ -15,14 +15,14 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" -#include "mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc index 56bf0ae4e0..2be25212e8 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc @@ -15,16 +15,16 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/ascend_session.h" -#include "session/anf_runtime_algorithm.h" -#include "pipeline/resource.h" -#include "operator/ops.h" +#include "backend/session/ascend_session.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "pipeline/jit/resource.h" +#include "frontend/operator/ops.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc index 22cf70ded3..103d0f21a4 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc @@ -15,16 +15,16 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" #undef private #undef protected namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/check_consistency_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/check_consistency_test.cc index 72ce73e20f..89d680f442 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/check_consistency_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/check_consistency_test.cc @@ -16,18 +16,18 @@ #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "common/backend_common_test.h" -#include "session/ascend_session.h" -#include "session/anf_runtime_algorithm.h" -#include "pipeline/resource.h" -#include "pipeline/action.h" -#include "operator/ops.h" +#include "backend/session/ascend_session.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/action.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/format_type/check_consistency.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/format_type/check_consistency.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc index 317eace6c6..2b61a49048 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc @@ -14,17 +14,17 @@ * limitations under the License. */ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "device/kernel_info.h" -#include "pre_activate/ascend/format_type/insert_cast.h" -#include "kernel/kernel_build_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/ascend/format_type/insert_cast.h" +#include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" #include "utils/context/ms_context.h" diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc index 8c57238e0a..0a5cf3dd9e 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc @@ -14,18 +14,18 @@ * limitations under the License. */ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" #include "utils/context/ms_context.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/merge_cast_to_op_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/merge_cast_to_op_test.cc index c0017c2deb..69e7fa8b27 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/merge_cast_to_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/merge_cast_to_op_test.cc @@ -15,17 +15,17 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/merge_cast_to_op.h" +#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/addn_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/addn_fission_test.cc index 90174636b1..8ec2b22a79 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/addn_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/addn_fission_test.cc @@ -18,7 +18,7 @@ #include "common/py_func_graph_fetcher.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/addn_fission.h" +#include "backend/optimizer/ascend/ir_fission/addn_fission.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc index 06895cb081..f793e0371b 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc index ea4a5c0d5d..80f30c8938 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc index dc437221f8..f0a5a857b9 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc @@ -15,17 +15,17 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/bn_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc index c5ebc28b48..9f4f31bf82 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc @@ -15,20 +15,20 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/ascend_session.h" -#include "session/anf_runtime_algorithm.h" -#include "pipeline/resource.h" -#include "operator/ops.h" +#include "backend/session/ascend_session.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "pipeline/jit/resource.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/bn_split.h" +#include "backend/optimizer/ascend/ir_fission/bn_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/lars_v2_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/lars_v2_fission_test.cc index c0a0cc455e..c726142e99 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/lars_v2_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/lars_v2_fission_test.cc @@ -16,7 +16,7 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fission/lars_v2_fission.h" +#include "backend/optimizer/ascend/ir_fission/lars_v2_fission.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/layer_norm_grad_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/layer_norm_grad_split_test.cc index 1df87960e3..4303485d85 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/layer_norm_grad_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/layer_norm_grad_split_test.cc @@ -15,17 +15,17 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/single_batch_norm_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/single_batch_norm_fission_test.cc index b0aa455a0a..9f84f22678 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/single_batch_norm_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/single_batch_norm_fission_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" +#include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc index ab70e83480..30de43be4e 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc @@ -18,7 +18,7 @@ #include "common/py_func_graph_fetcher.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/split_fission.h" +#include "backend/optimizer/ascend/ir_fission/split_fission.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc index faebe0e4a0..1c928b581d 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/tensor_scatter_update_fission_test.cc @@ -16,7 +16,7 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fission/tensor_scatter_update_fission.h" +#include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index b09268aa66..2ab614d4c2 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -16,13 +16,13 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "device/kernel_info.h" -#include "pre_activate/pass/convert_const_input_to_attr.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/pass/convert_const_input_to_attr.h" #include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/topk_split.h" +#include "backend/optimizer/ascend/ir_fission/topk_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc index f2b975a08e..220e45f10a 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc @@ -16,16 +16,16 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" #include "debug/anf_ir_dump.h" #include "utils/context/ms_context.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/insert_trans_op.h" -#include "pre_activate/ascend/ir_fission/transdata_split.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/ir_fission/transdata_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc index c2ee7b6519..2759864037 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc @@ -15,7 +15,7 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc index 014e60f579..78c815bf50 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc new file mode 100644 index 0000000000..5d42ff7069 --- /dev/null +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/backend_common_test.h" +#include "common/py_func_graph_fetcher.h" +#include "debug/anf_ir_dump.h" + +#define private public +#define protected public +#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class TestHWAddInputToOutput : public BackendCommon { + public: + TestHWAddInputToOutput() : getPyFun_("gtest_input.pre_activate.add_input_to_output_test", true) {} + ~TestHWAddInputToOutput() override = default; + + public: + UT::PyFuncGraphFetcher getPyFun_; +}; + +class MockOpFinder : public OpFinder { + public: + MockOpFinder() = default; + ~MockOpFinder() override = default; + int GetOpRegisteredOutputNum(const std::string &op_name) override { return 2; } +}; + +TEST_F(TestHWAddInputToOutput, test_add_input_to_output) { + FuncGraphPtr g = getPyFun_.CallAndParseRet("test_add_input_to_output", "before"); + EXPECT_NE(g, nullptr); + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 5; ++i) { + args_spec_list.push_back(x_abstract); + } + auto kg = GetKernelGraph(g, args_spec_list); + EXPECT_NE(kg, nullptr); + auto ret = kg->get_return(); + EXPECT_NE(ret, nullptr); + auto make_tuple = ret->input(1); + EXPECT_NE(make_tuple, nullptr); + auto momentum = make_tuple->cast()->input(1); + EXPECT_NE(momentum, nullptr); + EXPECT_NE(momentum->abstract(), nullptr); + EXPECT_FALSE(momentum->abstract()->isa()); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + auto pass = std::make_shared(); + pass->op_finder_ = std::make_shared(); + pm->AddPass(pass); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kg); + EXPECT_TRUE(momentum->abstract()->isa()); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer_test.cc index 466cba8e67..d9d0baf7be 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad_test.cc index d1fc2783ac..1b64e5fd00 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion_test.cc index 0c8bf67391..aa56d79239 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_value_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_value_fusion_test.cc index 4160c3a8e4..ac01f9b1dd 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_value_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_value_fusion_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/clip_by_value_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc index 2044857841..be6bd95b02 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_softmax_grad_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_softmax_grad_test.cc index 05fa2c65df..068cc0d12e 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_softmax_grad_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_softmax_grad_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc index ffa5a42b4d..663ed309ee 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/derelu_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc index 597b7b18ff..f7cbfdc678 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc index 6ea622d030..64c004ff27 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc @@ -17,7 +17,7 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc index 36f0321511..776ce625b7 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc @@ -16,7 +16,7 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule_test.cc index fbb1f5e913..bf21649672 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule_test.cc @@ -16,7 +16,7 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_right_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_right_rule_test.cc index f1ca92c811..6a7c866ab4 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_right_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_right_rule_test.cc @@ -15,7 +15,7 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion_test.cc index 7a2806162b..4de2de2700 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2_test.cc index 05262e72ab..5be6195da2 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2_test.cc @@ -17,7 +17,7 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc index 44b9b3df69..7392d05b98 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc @@ -15,13 +15,13 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "device/kernel_info.h" +#include "runtime/device/kernel_info.h" #include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion_test.cc index c8f97be290..f67eda9776 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion_test.cc index 114fcf4233..50dfd66f54 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc index 87bb21f89a..b293cdeecb 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc @@ -15,7 +15,7 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_addn_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_addn_fusion_test.cc index ab9718d80a..8ac106f81c 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_addn_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_addn_fusion_test.cc @@ -15,7 +15,7 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.h" +#include "mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc index 59140e91a1..6792f4720a 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc @@ -17,8 +17,8 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion_test.cc index 5f02f0e9c1..f6e8a1194c 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/square_sum_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/square_sum_fusion_test.cc index 2dd858a0fc..efe5433d75 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/square_sum_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/square_sum_fusion_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc index 3290acd42f..6ec407d2ea 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc @@ -17,8 +17,8 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index 98dc9e9efc..d156959c4c 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -16,14 +16,14 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" #include "utils/context/ms_context.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/insert_trans_op.h" -#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc b/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc index 7b0e2cc9db..12030433fc 100644 --- a/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc +++ b/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc @@ -20,8 +20,8 @@ #include #include "common/common_test.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/visit.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/visit.h" #include "utils/base_ref.h" #include "ir/anf.h" diff --git a/tests/ut/cpp/pre_activate/mem_reuse/kernel_ref_test.cc b/tests/ut/cpp/pre_activate/mem_reuse/kernel_ref_test.cc index 5b237fda58..8b6d3e061a 100644 --- a/tests/ut/cpp/pre_activate/mem_reuse/kernel_ref_test.cc +++ b/tests/ut/cpp/pre_activate/mem_reuse/kernel_ref_test.cc @@ -18,7 +18,7 @@ #include #include -#include "pre_activate/mem_reuse/kernel_refcount.h" +#include "backend/optimizer/mem_reuse/kernel_refcount.h" #include "utils/utils.h" #include "common/common_test.h" diff --git a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc index e0966d2d12..2a6904658e 100644 --- a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc +++ b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc @@ -17,9 +17,9 @@ #include #include #include -#include "operator/ops.h" -#include "pre_activate/mem_reuse/mem_reuse.h" -#include "pre_activate/mem_reuse/mem_reuse_allocator.h" +#include "frontend/operator/ops.h" +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc index a36463d297..31ae923c0a 100644 --- a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc +++ b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc @@ -16,19 +16,19 @@ #include #include #include -#include "session/kernel_graph.h" -#include "session/session_basic.h" -#include "session/ascend_session.h" -#include "pre_activate/mem_reuse/kernel_refcount.h" -#include "pre_activate/mem_reuse/mem_reuse_allocator.h" -#include "device/kernel_info.h" -#include "kernel/tbe/tbe_kernel_mod.h" -#include "operator/ops.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/session_basic.h" +#include "backend/session/ascend_session.h" +#include "backend/optimizer/mem_reuse/kernel_refcount.h" +#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" +#include "frontend/operator/ops.h" #include "utils/log_adapter.h" -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" #include "common/utils.h" -#include "pipeline/resource.h" -#include "pre_activate/mem_reuse/mem_reuse.h" +#include "pipeline/jit/resource.h" +#include "backend/optimizer/mem_reuse/mem_reuse.h" #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc index 69a330614e..02e1865a82 100644 --- a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc @@ -15,16 +15,16 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/pass/communication_op_fusion.h" -#include "pre_activate/common/optimizer.h" -#include "device/kernel_info.h" -#include "pre_activate/common/pass_manager.h" -#include "kernel/kernel_build_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/pass/communication_op_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" #include "utils/context/ms_context.h" diff --git a/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc b/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc index 12c4d35db5..cfcc34970b 100644 --- a/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc +++ b/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc @@ -14,17 +14,17 @@ * limitations under the License. */ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "device/kernel_info.h" -#include "pre_activate/pass/common_subexpression_elimination.h" -#include "kernel/kernel_build_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/pass/common_subexpression_elimination.h" +#include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" #include "utils/context/ms_context.h" diff --git a/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc b/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc index 8fc709433e..25e4b3c111 100644 --- a/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc +++ b/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc @@ -14,13 +14,13 @@ * limitations under the License. */ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h" #include "utils/utils.h" #include "common/utils.h" diff --git a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc index fcb3b19a24..ac3272317a 100644 --- a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc +++ b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc @@ -14,13 +14,13 @@ * limitations under the License. */ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/pass/convert_const_input_to_attr.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/pass/convert_const_input_to_attr.h" #include "utils/utils.h" #include "common/utils.h" diff --git a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc index 1749e54d94..5b303d15a5 100644 --- a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc +++ b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc @@ -18,10 +18,10 @@ #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/pass/convert_const_input_to_tensor_input.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/pass/convert_const_input_to_tensor_input.h" #include "utils/utils.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/pass/convert_tuple_input_to_dynamic_input_test.cc b/tests/ut/cpp/pre_activate/pass/convert_tuple_input_to_dynamic_input_test.cc index aded376536..2c1dfc1c6c 100644 --- a/tests/ut/cpp/pre_activate/pass/convert_tuple_input_to_dynamic_input_test.cc +++ b/tests/ut/cpp/pre_activate/pass/convert_tuple_input_to_dynamic_input_test.cc @@ -18,10 +18,10 @@ #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/pass/convert_tuple_input_to_dynamic_input.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h" #include "utils/utils.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/pass/convert_tuple_output_to_maketuple_test.cc b/tests/ut/cpp/pre_activate/pass/convert_tuple_output_to_maketuple_test.cc index eeb01270e2..458c854218 100644 --- a/tests/ut/cpp/pre_activate/pass/convert_tuple_output_to_maketuple_test.cc +++ b/tests/ut/cpp/pre_activate/pass/convert_tuple_output_to_maketuple_test.cc @@ -18,10 +18,10 @@ #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/pass/convert_tuple_output_to_maketuple.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/pass/convert_tuple_output_to_maketuple.h" #include "utils/utils.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc index 3e43155011..07bef7a042 100644 --- a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc +++ b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc @@ -15,26 +15,26 @@ */ #include "common/backend_common_test.h" -#include "kernel/kernel.h" -#include "operator/ops.h" +#include "backend/kernel_compiler/kernel.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -// #include "device/optimizer/pass/insert_trans_op.h" -#include "pre_activate/ascend/format_type/insert_cast.h" -#include "pre_activate/pass/eliminate_redundant_op.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" +// #include "runtime/device/optimizer/pass/insert_trans_op.h" +#include "backend/optimizer/ascend/format_type/insert_cast.h" +#include "backend/optimizer/pass/eliminate_redundant_op.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" #include "utils/utils.h" #include "utils/context/ms_context.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" #include "utils/context/ms_context.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/pass/getitem_tuple_test.cc b/tests/ut/cpp/pre_activate/pass/getitem_tuple_test.cc index b172e1b351..555dd95426 100644 --- a/tests/ut/cpp/pre_activate/pass/getitem_tuple_test.cc +++ b/tests/ut/cpp/pre_activate/pass/getitem_tuple_test.cc @@ -15,14 +15,14 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/ascend_session.h" -#include "pipeline/resource.h" -#include "operator/ops.h" +#include "backend/session/ascend_session.h" +#include "pipeline/jit/resource.h" +#include "frontend/operator/ops.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/pass/getitem_tuple.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/pass/getitem_tuple.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc index 04461e6602..f9cfe273bc 100644 --- a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc +++ b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/pass/optimize_dependence.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/pass/optimize_dependence.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pynative/pynative_execute_test.cc b/tests/ut/cpp/pynative/pynative_execute_test.cc index a0d1516b58..c5f25ca484 100644 --- a/tests/ut/cpp/pynative/pynative_execute_test.cc +++ b/tests/ut/cpp/pynative/pynative_execute_test.cc @@ -16,10 +16,10 @@ #include #include #include "common/common_test.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/data_converter.h" -#include "operator/ops.h" -#include "pynative/pynative_execute.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/operator/ops.h" +#include "pipeline/pynative/pynative_execute.h" #include "utils/context/ms_context.h" #include "utils/utils.h" diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py index e38c61f16e..bcfa077ea5 100644 --- a/tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py +++ b/tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py @@ -17,8 +17,8 @@ import numpy as np import mindspore as ms from mindspore.common.tensor import Tensor -from mindspore.model_zoo.resnet import resnet50 from mindspore.ops import Primitive +from tests.ut.python.model.resnet import resnet50 scala_add = Primitive('scalar_add') diff --git a/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_integrate.py b/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_integrate.py index fa5b1b9055..28bded6401 100644 --- a/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_integrate.py +++ b/tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_integrate.py @@ -22,9 +22,9 @@ from mindspore.common import dtype from mindspore.common.api import ms_function, _executor from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore.model_zoo.resnet import resnet50 from mindspore.ops import functional as F from mindspore.train.model import Model +from tests.ut.python.model.resnet import resnet50 def test_high_order_function(a): diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py new file mode 100644 index 0000000000..4d4fa1fe96 --- /dev/null +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/add_input_to_output_test.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +from mindspore.ops import operations as P + +ApplyMomentum = P.ApplyMomentum() + + +class FnDict: + def __init__(self): + self.fnDict = {} + + def __call__(self, fn): + self.fnDict[fn.__name__] = fn + + def __getitem__(self, name): + return self.fnDict[name] + + +def test_add_input_to_output(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4): + return ApplyMomentum(input0, input1, input2, input3, input4) + + return fns[tag] diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc index 4c94cdde57..ac38e5427e 100644 --- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc +++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc @@ -15,12 +15,12 @@ */ #include "common/common_test.h" -#include "ir/param_value_py.h" -#include "operator/ops.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "mindspore/ccsrc/device/kernel_info.h" -#include "mindspore/ccsrc/device/ascend/ascend_device_address.h" +#include "ir/param_value.h" +#include "frontend/operator/ops.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "mindspore/ccsrc/runtime/device/kernel_info.h" +#include "mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h" #include "utils/utils.h" namespace mindspore { @@ -255,7 +255,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) { AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get()); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); @@ -274,7 +274,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); @@ -293,7 +293,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputFormat) { auto pre_add = kernel_graph->NewCNode(pre_node_inputs); MS_EXCEPTION_IF_NULL(pre_add); pre_add->set_kernel_info(std::make_shared()); - auto d_kernel_info = pre_add->kernel_info(); + auto d_kernel_info = dynamic_cast(pre_add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsDeviceType({kFloat32->type_id()}); @@ -373,7 +373,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceShape) { MS_EXCEPTION_IF_NULL(add); add->set_abstract(tuple_abstract); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_NZ}); @@ -404,7 +404,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetInputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC}); @@ -457,7 +457,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsDeviceType({kFloat32->type_id()}); @@ -474,7 +474,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); @@ -492,7 +492,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputDeviceDataType) { auto pre_add = kernel_graph->NewCNode(pre_add_inputs); MS_EXCEPTION_IF_NULL(pre_add); pre_add->set_kernel_info(std::make_shared()); - auto d_kernel_info = pre_add->kernel_info(); + auto d_kernel_info = dynamic_cast(pre_add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsDeviceType({kFloat32->type_id()}); @@ -513,7 +513,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputAddr) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); int *addr = nullptr; auto device_address = std::make_shared(addr, 1); @@ -528,7 +528,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputAddr) { auto pre_add = kernel_graph->NewCNode(pre_add_inputs); MS_EXCEPTION_IF_NULL(pre_add); pre_add->set_kernel_info(std::make_shared()); - auto d_kernel_info = pre_add->kernel_info(); + auto d_kernel_info = dynamic_cast(pre_add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); int *addr = nullptr; auto device_address = std::make_shared(addr, 1); @@ -561,7 +561,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetWorkspaceAddr) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); int *addr = nullptr; auto device_address = std::make_shared(addr, 1); @@ -643,7 +643,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelType) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetKernelType(AKG_KERNEL); @@ -659,7 +659,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetProcessor) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetProcessor(kernel::AICORE); @@ -675,7 +675,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetFusionType) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetFusionType(kernel::CONVLUTION); @@ -703,7 +703,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelMod) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); d_kernel_info->set_kernel_mod(nullptr); EXPECT_EQ(AnfAlgo::GetKernelMod(add), nullptr); @@ -764,10 +764,9 @@ TEST_F(AnfRuntimeAlgorithmTest, IsRealCNodeKernel) { TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) { auto kernel_graph = std::make_shared(); - py::object obj; auto parameter_node = kernel_graph->add_parameter(); MS_EXCEPTION_IF_NULL(parameter_node); - auto param_value_new = std::make_shared(obj); + auto param_value_new = std::make_shared(); parameter_node->set_default_param(param_value_new); EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node)); EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error); @@ -780,7 +779,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetStreamId) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); d_kernel_info->set_stream_id(0); EXPECT_EQ(AnfAlgo::GetStreamId(add), 0); diff --git a/tests/ut/cpp/session/kernel_graph_test.cc b/tests/ut/cpp/session/kernel_graph_test.cc index 75e653c26c..f24036b4aa 100644 --- a/tests/ut/cpp/session/kernel_graph_test.cc +++ b/tests/ut/cpp/session/kernel_graph_test.cc @@ -15,11 +15,11 @@ */ #include "common/common_test.h" -#include "ir/param_value_py.h" -#include "operator/ops.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "mindspore/ccsrc/device/kernel_info.h" +#include "ir/param_value.h" +#include "frontend/operator/ops.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "mindspore/ccsrc/runtime/device/kernel_info.h" #include "utils/utils.h" namespace mindspore { @@ -42,7 +42,7 @@ TEST_F(KernelGraphTest, NewValueNode) { auto x_abstract = std::make_shared(kFloat32, shape); add_value->set_abstract(x_abstract); add_value->set_kernel_info(std::make_shared()); - auto mutable_kernel_info = add_value->kernel_info(); + auto mutable_kernel_info = dynamic_cast(add_value->kernel_info()); MS_EXCEPTION_IF_NULL(mutable_kernel_info); std::shared_ptr builder = std::make_shared(); builder->SetOutputsFormat({kOpFormat_FRAC_Z}); @@ -82,8 +82,7 @@ TEST_F(KernelGraphTest, NewParameter) { // test weight parameter node as input auto weight_parameter_node = anf_graph->add_parameter(); MS_EXCEPTION_IF_NULL(weight_parameter_node); - py::object obj; - auto param_value_new = std::make_shared(obj); + auto param_value_new = std::make_shared(); weight_parameter_node->set_default_param(param_value_new); weight_parameter_node->set_abstract(x_abstract); auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node); diff --git a/tests/ut/cpp/session/session_basic_test.cc b/tests/ut/cpp/session/session_basic_test.cc index 1a7ca68065..c438c92b52 100644 --- a/tests/ut/cpp/session/session_basic_test.cc +++ b/tests/ut/cpp/session/session_basic_test.cc @@ -15,10 +15,10 @@ */ #include "common/common_test.h" -#include "operator/ops.h" -#include "session/ascend_session.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "backend/session/ascend_session.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" #include "utils/utils.h" namespace mindspore { diff --git a/tests/ut/cpp/stub/aicpu/aicpu_stub.cc b/tests/ut/cpp/stub/aicpu/aicpu_stub.cc index 78ada6de18..5516d1fdc8 100644 --- a/tests/ut/cpp/stub/aicpu/aicpu_stub.cc +++ b/tests/ut/cpp/stub/aicpu/aicpu_stub.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "kernel/kernel.h" +#include "backend/kernel_compiler/kernel.h" namespace mindspore { namespace kernel { diff --git a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc index a3a991247c..234ffdaf6b 100644 --- a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc +++ b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc @@ -15,7 +15,7 @@ */ #include #include "framework/ge_runtime/model_runner.h" -#include "device/ascend/tasksink/runtime_utils.h" +#include "runtime/device/ascend/tasksink/runtime_utils.h" namespace ge { namespace model_runner { @@ -32,6 +32,8 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint bool ModelRunner::UnloadModel(uint32_t model_id) { return true; } +bool ModelRunner::LoadModelComplete(uint32_t model_id) { return true; } + bool ModelRunner::RunModel(uint32_t model_id, const ge::InputData &input_data, ge::OutputData *output_data) { return true; } @@ -45,6 +47,11 @@ const std::vector &ModelRunner::GetStreamIdList(uint32_t model_id) con static std::vector stream_id_list; return stream_id_list; } + +const std::map> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { + static std::map> runtime_info_map; + return runtime_info_map; +} } // namespace model_runner } // namespace ge diff --git a/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc b/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc index ba642dfe18..87ab543c7c 100755 --- a/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc +++ b/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "kernel/kernel_fusion.h" -#include "kernel/tbe/tbe_kernel_mod.h" +#include "backend/kernel_compiler/kernel_fusion.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" #include "common/utils.h" namespace mindspore { diff --git a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc index 43d0dd4b3f..f6f2f45092 100644 --- a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc +++ b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc @@ -15,7 +15,7 @@ */ #include #include -#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" +#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include "utils/log_adapter.h" namespace mindspore { diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index a6ec3a50b5..85470e2315 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -13,10 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "device/ascend/ascend_stream_assign.h" -#include "device/ascend/ascend_label_assign.h" -#include "device/ascend/tasksink/task_generator.h" -#include "device/kernel_adjust.h" +#include "runtime/device/ascend/ascend_stream_assign.h" +#include "runtime/device/ascend/ascend_label_assign.h" +#include "runtime/device/kernel_adjust.h" namespace mindspore { namespace device { @@ -31,13 +30,6 @@ void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { return; } void AscendStreamAssign::GetHcomStreams(std::vector *streams) { return; } - -namespace tasksink { -bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *const task_info_list, - uint32_t graph_id) { - return true; -} -} // namespace tasksink } // namespace ascend void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { return; } bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr) { return true; } diff --git a/tests/ut/cpp/stub/tasksink/task_sink_stub.cc b/tests/ut/cpp/stub/tasksink/task_sink_stub.cc new file mode 100644 index 0000000000..0b12a3862c --- /dev/null +++ b/tests/ut/cpp/stub/tasksink/task_sink_stub.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/tasksink/task_generator.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace tasksink { +bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *const task_info_list, + uint32_t graph_id) { + return true; +} +} // namespace tasksink +} // namespace ascend +} // namespace device +} // namespace mindspore \ No newline at end of file diff --git a/tests/ut/cpp/transform/convert_test.cc b/tests/ut/cpp/transform/convert_test.cc index f8f48920e0..6902f7d90d 100644 --- a/tests/ut/cpp/transform/convert_test.cc +++ b/tests/ut/cpp/transform/convert_test.cc @@ -20,16 +20,16 @@ #include "transform/transform_base_test.h" #include "common/py_func_graph_fetcher.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" #include "debug/anf_ir_dump.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" #include "common/common_test.h" #define private public -#include "transform/types.h" -#include "transform/convert.h" +#include "transform/graph_ir/types.h" +#include "transform/graph_ir/convert.h" #include "securec/include/securec.h" #include "utils/utils.h" using std::cout; diff --git a/tests/ut/cpp/transform/graph_builder_test.cc b/tests/ut/cpp/transform/graph_builder_test.cc index e92463e2dc..e4d72b33cb 100644 --- a/tests/ut/cpp/transform/graph_builder_test.cc +++ b/tests/ut/cpp/transform/graph_builder_test.cc @@ -25,8 +25,8 @@ #endif #define private public -#include "transform/graph_builder.h" -#include "transform/df_graph_manager.h" +#include "transform/graph_ir/graph_builder.h" +#include "transform/graph_ir/df_graph_manager.h" using UT::Common; diff --git a/tests/ut/cpp/transform/graph_manager_test.cc b/tests/ut/cpp/transform/graph_manager_test.cc index 699f81ca4c..9e55e1725b 100644 --- a/tests/ut/cpp/transform/graph_manager_test.cc +++ b/tests/ut/cpp/transform/graph_manager_test.cc @@ -25,7 +25,7 @@ #endif #define private public -#include "transform/df_graph_manager.h" +#include "transform/graph_ir/df_graph_manager.h" using UT::Common; diff --git a/tests/ut/cpp/transform/graph_runner_test.cc b/tests/ut/cpp/transform/graph_runner_test.cc index 1b87cea464..b91ec959d2 100644 --- a/tests/ut/cpp/transform/graph_runner_test.cc +++ b/tests/ut/cpp/transform/graph_runner_test.cc @@ -21,10 +21,10 @@ #include "ir/tensor_py.h" #include "transform/transform_base_test.h" #include "common/py_func_graph_fetcher.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "operator/ops.h" -#include "transform/df_graph_manager.h" -#include "transform/convert.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "frontend/operator/ops.h" +#include "transform/graph_ir/df_graph_manager.h" +#include "transform/graph_ir/convert.h" #include "utils/utils.h" #ifdef OPEN_SOURCE @@ -34,7 +34,7 @@ #endif #define private public -#include "transform/graph_runner.h" +#include "transform/graph_ir/graph_runner.h" using mindspore::tensor::TensorPy; diff --git a/tests/ut/cpp/transform/op_adapter_test.cc b/tests/ut/cpp/transform/op_adapter_test.cc index 254452bb42..2aa6ba37e3 100644 --- a/tests/ut/cpp/transform/op_adapter_test.cc +++ b/tests/ut/cpp/transform/op_adapter_test.cc @@ -19,9 +19,9 @@ #include "common/common_test.h" -#include "transform/op_declare.h" +#include "transform/graph_ir/op_declare.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "./common.h" using std::cout; diff --git a/tests/ut/cpp/transform/transform_base_test.h b/tests/ut/cpp/transform/transform_base_test.h index 92147dfbbf..4886b25748 100644 --- a/tests/ut/cpp/transform/transform_base_test.h +++ b/tests/ut/cpp/transform/transform_base_test.h @@ -20,11 +20,11 @@ #include #include #include -#include "transform/util.h" +#include "transform/graph_ir/util.h" #include "ir/tensor.h" #include "common/common_test.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "./common.h" #include "graph/tensor.h" diff --git a/tests/ut/cpp/utils/any_test.cc b/tests/ut/cpp/utils/any_test.cc index d11831d602..8a49017d95 100644 --- a/tests/ut/cpp/utils/any_test.cc +++ b/tests/ut/cpp/utils/any_test.cc @@ -20,7 +20,7 @@ #include #include "common/common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "utils/any.h" #include "utils/misc.h" diff --git a/tests/ut/cpp/utils/callback_test.cc b/tests/ut/cpp/utils/callback_test.cc index c63f68f000..0a4ffb8190 100644 --- a/tests/ut/cpp/utils/callback_test.cc +++ b/tests/ut/cpp/utils/callback_test.cc @@ -18,9 +18,9 @@ #include "pybind11/pybind11.h" #include "utils/callbacks.h" #include "common/common_test.h" -#include "pipeline/pipeline.h" -#include "pipeline/parse/python_adapter.h" -#include "transform/df_graph_manager.h" +#include "pipeline/jit/pipeline.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "transform/graph_ir/df_graph_manager.h" #include "debug/draw.h" #ifdef ENABLE_GE #include "utils/callbacks_ge.h" diff --git a/tests/ut/cpp/utils/graph_utils_test.cc b/tests/ut/cpp/utils/graph_utils_test.cc index ce5a4318d3..35fa9cdc6a 100644 --- a/tests/ut/cpp/utils/graph_utils_test.cc +++ b/tests/ut/cpp/utils/graph_utils_test.cc @@ -24,8 +24,8 @@ #include "ir/anf.h" #include "utils/graph_utils.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/parse.h" namespace mindspore { diff --git a/tests/ut/cpp/utils/ir_import_test.cc b/tests/ut/cpp/utils/ir_import_test.cc index 5e7db98a38..374c36b4e8 100644 --- a/tests/ut/cpp/utils/ir_import_test.cc +++ b/tests/ut/cpp/utils/ir_import_test.cc @@ -19,10 +19,10 @@ #include "utils/log_adapter.h" #include "debug/anf_ir_utils.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "ir/manager.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" namespace mindspore { class TestIrImporter : public UT::Common { diff --git a/tests/ut/cpp/utils/symbolic_test.cc b/tests/ut/cpp/utils/symbolic_test.cc index f259b62d6b..c0abd388d5 100644 --- a/tests/ut/cpp/utils/symbolic_test.cc +++ b/tests/ut/cpp/utils/symbolic_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common_test.h" -#include "pipeline/static_analysis/static_analysis.h" +#include "pipeline/jit/static_analysis/static_analysis.h" #include "utils/symbolic.h" using std::cout; diff --git a/tests/ut/cpp/utils/validator_test.cc b/tests/ut/cpp/utils/validator_test.cc index 8eef44bde5..93334d7664 100644 --- a/tests/ut/cpp/utils/validator_test.cc +++ b/tests/ut/cpp/utils/validator_test.cc @@ -18,11 +18,11 @@ #include "common/common_test.h" #include "utils/log_adapter.h" -#include "pipeline/validator.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/validator.h" +#include "pipeline/jit/parse/parse.h" #include "ir/manager.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" namespace mindspore { namespace validator { diff --git a/tests/ut/cpp/vm/segment_runner_test.cc b/tests/ut/cpp/vm/segment_runner_test.cc index b9bc552d90..c83b1b3434 100644 --- a/tests/ut/cpp/vm/segment_runner_test.cc +++ b/tests/ut/cpp/vm/segment_runner_test.cc @@ -20,11 +20,11 @@ #include "ir/manager.h" #include "utils/log_adapter.h" #include "ir/func_graph_cloner.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "utils/graph_utils.h" -#include "pipeline/resource.h" +#include "pipeline/jit/resource.h" #include "debug/draw.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "vm/segment_runner.h" #include "vm/transform.h" #include "ir/tensor.h" diff --git a/tests/ut/cpp/vm/vm_test.cc b/tests/ut/cpp/vm/vm_test.cc index 04633043af..9168d408c3 100644 --- a/tests/ut/cpp/vm/vm_test.cc +++ b/tests/ut/cpp/vm/vm_test.cc @@ -15,7 +15,7 @@ */ #include "vm/vm.h" #include "common/common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "vm/backend.h" namespace mindspore { diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz index e4e92210d7..14ddc166e2 100644 Binary files a/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz and b/tests/ut/data/dataset/golden/bounding_box_augment_crop_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz index 8cc7e15e31..07ae4e5892 100644 Binary files a/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz and b/tests/ut/data/dataset/golden/bounding_box_augment_rotation_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz index dafea520fe..a72643457b 100644 Binary files a/tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz and b/tests/ut/data/dataset/golden/bounding_box_augment_valid_edge_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz b/tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz index 71e58406ac..9a6ae1cb99 100644 Binary files a/tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz and b/tests/ut/data/dataset/golden/bounding_box_augment_valid_ratio_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/cache_map_01_result.npz b/tests/ut/data/dataset/golden/cache_map_01_result.npz new file mode 100644 index 0000000000..7cff9ded88 Binary files /dev/null and b/tests/ut/data/dataset/golden/cache_map_01_result.npz differ diff --git a/tests/ut/data/dataset/golden/cache_map_02_result.npz b/tests/ut/data/dataset/golden/cache_map_02_result.npz new file mode 100644 index 0000000000..7cff9ded88 Binary files /dev/null and b/tests/ut/data/dataset/golden/cache_map_02_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz index 0c220fd09d..bb33f1bece 100644 Binary files a/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz and b/tests/ut/data/dataset/golden/random_crop_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz index d360bb98ec..416223ff4d 100644 Binary files a/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz and b/tests/ut/data/dataset/golden/random_horizontal_flip_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_coco_result.npz b/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_coco_result.npz new file mode 100644 index 0000000000..db62d6509e Binary files /dev/null and b/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_coco_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_voc_result.npz b/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_voc_result.npz new file mode 100644 index 0000000000..75f4447ded Binary files /dev/null and b/tests/ut/data/dataset/golden/random_resize_with_bbox_op_01_c_voc_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz index a909cbe88c..aa9778bd39 100644 Binary files a/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz and b/tests/ut/data/dataset/golden/random_resized_crop_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz b/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz index aba6fe97b0..e0e0eb2823 100644 Binary files a/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz and b/tests/ut/data/dataset/golden/random_vertical_flip_with_bbox_01_c_result.npz differ diff --git a/tests/ut/data/dataset/golden/repeat_list_result.npz b/tests/ut/data/dataset/golden/repeat_list_result.npz index c0240c6e21..883ac58be8 100644 Binary files a/tests/ut/data/dataset/golden/repeat_list_result.npz and b/tests/ut/data/dataset/golden/repeat_list_result.npz differ diff --git a/tests/ut/data/dataset/golden/repeat_result.npz b/tests/ut/data/dataset/golden/repeat_result.npz index 73b0a24b20..2df787cef8 100644 Binary files a/tests/ut/data/dataset/golden/repeat_result.npz and b/tests/ut/data/dataset/golden/repeat_result.npz differ diff --git a/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_coco_result.npz b/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_coco_result.npz new file mode 100644 index 0000000000..999c15e5f3 Binary files /dev/null and b/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_coco_result.npz differ diff --git a/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_voc_result.npz b/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_voc_result.npz new file mode 100644 index 0000000000..ca64884937 Binary files /dev/null and b/tests/ut/data/dataset/golden/resize_with_bbox_op_01_c_voc_result.npz differ diff --git a/tests/ut/data/dataset/golden/tf_file_no_schema.npz b/tests/ut/data/dataset/golden/tf_file_no_schema.npz deleted file mode 100644 index b823998521..0000000000 Binary files a/tests/ut/data/dataset/golden/tf_file_no_schema.npz and /dev/null differ diff --git a/tests/ut/data/dataset/golden/tf_file_padBytes10.npz b/tests/ut/data/dataset/golden/tf_file_padBytes10.npz deleted file mode 100644 index e3d6d9934b..0000000000 Binary files a/tests/ut/data/dataset/golden/tf_file_padBytes10.npz and /dev/null differ diff --git a/tests/ut/data/dataset/golden/tfreader_result.npz b/tests/ut/data/dataset/golden/tfreader_result.npz deleted file mode 100644 index 10cad9f2b0..0000000000 Binary files a/tests/ut/data/dataset/golden/tfreader_result.npz and /dev/null differ diff --git a/tests/ut/data/dataset/golden/tfrecord_files_basic.npz b/tests/ut/data/dataset/golden/tfrecord_files_basic.npz new file mode 100644 index 0000000000..810182faf9 Binary files /dev/null and b/tests/ut/data/dataset/golden/tfrecord_files_basic.npz differ diff --git a/tests/ut/data/dataset/golden/tfrecord_no_schema.npz b/tests/ut/data/dataset/golden/tfrecord_no_schema.npz new file mode 100644 index 0000000000..bda2807e89 Binary files /dev/null and b/tests/ut/data/dataset/golden/tfrecord_no_schema.npz differ diff --git a/tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz b/tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz new file mode 100644 index 0000000000..580e19de64 Binary files /dev/null and b/tests/ut/data/dataset/golden/tfrecord_pad_bytes10.npz differ diff --git a/tests/ut/data/dataset/imagefolder/ExpectedBoundingBoxAugmentOp0.jpg b/tests/ut/data/dataset/imagefolder/ExpectedBoundingBoxAugmentOp0.jpg new file mode 100644 index 0000000000..242559f276 Binary files /dev/null and b/tests/ut/data/dataset/imagefolder/ExpectedBoundingBoxAugmentOp0.jpg differ diff --git a/tests/ut/data/dataset/imagefolder/ExpectedRandomCropWithBBox_C0.jpg b/tests/ut/data/dataset/imagefolder/ExpectedRandomCropWithBBox_C0.jpg new file mode 100644 index 0000000000..362d841170 Binary files /dev/null and b/tests/ut/data/dataset/imagefolder/ExpectedRandomCropWithBBox_C0.jpg differ diff --git a/tests/ut/data/dataset/imagefolder/ExpectedRandomHorizontalFlipWithBBox0.jpg b/tests/ut/data/dataset/imagefolder/ExpectedRandomHorizontalFlipWithBBox0.jpg new file mode 100644 index 0000000000..3210a7b1fe Binary files /dev/null and b/tests/ut/data/dataset/imagefolder/ExpectedRandomHorizontalFlipWithBBox0.jpg differ diff --git a/tests/ut/data/dataset/imagefolder/ExpectedRandomResizeWithBBox_C0.jpg b/tests/ut/data/dataset/imagefolder/ExpectedRandomResizeWithBBox_C0.jpg new file mode 100644 index 0000000000..235516d75f Binary files /dev/null and b/tests/ut/data/dataset/imagefolder/ExpectedRandomResizeWithBBox_C0.jpg differ diff --git a/tests/ut/data/dataset/imagefolder/ExpectedRandomResizedCropWithBBox_C0.jpg b/tests/ut/data/dataset/imagefolder/ExpectedRandomResizedCropWithBBox_C0.jpg new file mode 100644 index 0000000000..d7666adb9b Binary files /dev/null and b/tests/ut/data/dataset/imagefolder/ExpectedRandomResizedCropWithBBox_C0.jpg differ diff --git a/tests/ut/data/dataset/imagefolder/ExpectedRandomVerticalFlipWithBBox_C0.jpg b/tests/ut/data/dataset/imagefolder/ExpectedRandomVerticalFlipWithBBox_C0.jpg new file mode 100644 index 0000000000..c5fe8ff540 Binary files /dev/null and b/tests/ut/data/dataset/imagefolder/ExpectedRandomVerticalFlipWithBBox_C0.jpg differ diff --git a/tests/ut/data/dataset/imagefolder/ExpectedResizeWithBBox_C0.jpg b/tests/ut/data/dataset/imagefolder/ExpectedResizeWithBBox_C0.jpg new file mode 100644 index 0000000000..f6dfd85547 Binary files /dev/null and b/tests/ut/data/dataset/imagefolder/ExpectedResizeWithBBox_C0.jpg differ diff --git a/tests/ut/data/dataset/testCifar100Data/datasetSchema.json b/tests/ut/data/dataset/testCifar100Data/datasetSchema.json deleted file mode 100644 index 474a806bf2..0000000000 --- a/tests/ut/data/dataset/testCifar100Data/datasetSchema.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "datasetType": "CIFAR100", - "numRows": 100, - "columns": { - "image": { - "type": "uint8", - "rank": 1, - "t_impl": "cvmat" - }, - "coarse_label" : { - "type": "uint32", - "rank": 1, - "t_impl": "flex" - }, - "fine_label" : { - "type": "uint32", - "rank": 1, - "t_impl": "flex" - } - } -} diff --git a/tests/ut/data/dataset/testCifar100Data/datasetSchemaTestRepeat.json b/tests/ut/data/dataset/testCifar100Data/datasetSchemaTestRepeat.json deleted file mode 100644 index a90edb342b..0000000000 --- a/tests/ut/data/dataset/testCifar100Data/datasetSchemaTestRepeat.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "datasetType": "CIFAR100", - "numRows": 33, - "columns": { - "image": { - "type": "uint8", - "rank": 1, - "t_impl": "cvmat" - }, - "coarse_label" : { - "type": "uint32", - "rank": 1, - "t_impl": "flex" - }, - "fine_label" : { - "type": "uint32", - "rank": 1, - "t_impl": "flex" - } - } -} diff --git a/tests/ut/data/dataset/testCifar10Data/data_batch_1.bin b/tests/ut/data/dataset/testCifar10Data/data_batch_1.bin index 7964f0952c..b3ec462f79 100644 Binary files a/tests/ut/data/dataset/testCifar10Data/data_batch_1.bin and b/tests/ut/data/dataset/testCifar10Data/data_batch_1.bin differ diff --git a/tests/ut/data/dataset/testCifar10Data/datasetDistributionAll.json b/tests/ut/data/dataset/testCifar10Data/datasetDistributionAll.json deleted file mode 100644 index 9234a6e033..0000000000 --- a/tests/ut/data/dataset/testCifar10Data/datasetDistributionAll.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "deviceNum" : 3, - "deviceId" : 1, - "shardConfig" : "ALL", - "shuffle" : "ON", - "seed" : 0, - "epoch" : 2 -} - diff --git a/tests/ut/data/dataset/testCifar10Data/datasetDistributionRandom.json b/tests/ut/data/dataset/testCifar10Data/datasetDistributionRandom.json deleted file mode 100644 index 3f61c582a5..0000000000 --- a/tests/ut/data/dataset/testCifar10Data/datasetDistributionRandom.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "deviceNum" : 3, - "deviceId" : 1, - "shardConfig" : "RANDOM", - "shuffle" : "ON", - "seed" : 0, - "epoch" : 1 -} - diff --git a/tests/ut/data/dataset/testCifar10Data/datasetDistributionUnique.json b/tests/ut/data/dataset/testCifar10Data/datasetDistributionUnique.json deleted file mode 100644 index 99e685132b..0000000000 --- a/tests/ut/data/dataset/testCifar10Data/datasetDistributionUnique.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "deviceNum" : 3, - "deviceId" : 1, - "shardConfig" : "UNIQUE", - "shuffle" : "ON", - "seed" : 0, - "epoch" : 3 -} - diff --git a/tests/ut/data/dataset/testCifar10Data/datasetSchema.json b/tests/ut/data/dataset/testCifar10Data/datasetSchema.json deleted file mode 100644 index 1a04b9af59..0000000000 --- a/tests/ut/data/dataset/testCifar10Data/datasetSchema.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "datasetType": "CIFAR10", - "numRows": 60000, - "columns": { - "image": { - "type": "uint8", - "rank": 1, - "t_impl": "cvmat" - }, - "label" : { - "type": "uint32", - "rank": 1, - "t_impl": "flex" - } - } -} diff --git a/tests/ut/data/dataset/testCifar10Data/datasetSchemaTestRepeat.json b/tests/ut/data/dataset/testCifar10Data/datasetSchemaTestRepeat.json deleted file mode 100644 index c25e11c30f..0000000000 --- a/tests/ut/data/dataset/testCifar10Data/datasetSchemaTestRepeat.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "datasetType": "CIFAR10", - "numRows": 33, - "columns": { - "image": { - "type": "uint8", - "rank": 1, - "t_impl": "cvmat" - }, - "label" : { - "type": "uint32", - "rank": 1, - "t_impl": "flex" - } - } -} diff --git a/tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json b/tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json deleted file mode 100644 index 0aa5a4577a..0000000000 --- a/tests/ut/data/dataset/test_tf_file_3_images_1/datasetSchema.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "datasetType": "TF", - "numRows": 3, - "columns": { - "label": { - "type": "int64", - "rank": 1, - "t_impl": "flex" - } - } -} diff --git a/tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data b/tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data deleted file mode 100644 index 829e8d70cb..0000000000 Binary files a/tests/ut/data/dataset/test_tf_file_3_images_1/train-0000-of-0001.data and /dev/null differ diff --git a/tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json b/tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json deleted file mode 100644 index b7b3cb9ea3..0000000000 --- a/tests/ut/data/dataset/test_tf_file_3_images_2/datasetSchema.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "datasetType": "TF", - "numRows": 3, - "columns": { - "image": { - "type": "uint8", - "rank": 1, - "t_impl": "cvmat" - } - } -} diff --git a/tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data b/tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data deleted file mode 100644 index 829e8d70cb..0000000000 Binary files a/tests/ut/data/dataset/test_tf_file_3_images_2/train-0000-of-0001.data and /dev/null differ diff --git a/tests/ut/data/mindrecord/testGraphData/testdata b/tests/ut/data/mindrecord/testGraphData/testdata index e206469ac6..5235973469 100644 Binary files a/tests/ut/data/mindrecord/testGraphData/testdata and b/tests/ut/data/mindrecord/testGraphData/testdata differ diff --git a/tests/ut/data/mindrecord/testGraphData/testdata.db b/tests/ut/data/mindrecord/testGraphData/testdata.db index 541da0e998..0f022589f4 100644 Binary files a/tests/ut/data/mindrecord/testGraphData/testdata.db and b/tests/ut/data/mindrecord/testGraphData/testdata.db differ diff --git a/tests/ut/python/automl/case.py b/tests/ut/python/automl/case.py new file mode 100644 index 0000000000..745376277c --- /dev/null +++ b/tests/ut/python/automl/case.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test case.""" +import numpy as np + +import mindspore +import mindspore.nn as nn +from mindspore import Tensor, context + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 3, 3) + self.conv2 = nn.Conv2d(1, 3, 5, has_bias=True) + self.layers = (self.conv1, self.conv2) + + def construct(self, x, index): + x = self.layers[index](x) + y = self.conv1(x) + return x + y + + +def test_case(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + net = Net() + data = Tensor(np.ones((1, 1, 224, 224)), mindspore.float32) + idx = Tensor(1, mindspore.int32) + net(data, idx) diff --git a/tests/ut/python/dataset/test_basic_tokenizer.py b/tests/ut/python/dataset/test_basic_tokenizer.py deleted file mode 100644 index 45c9f94da4..0000000000 --- a/tests/ut/python/dataset/test_basic_tokenizer.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Testing BasicTokenizer op in DE -""" -import numpy as np -import mindspore.dataset as ds -from mindspore import log as logger -import mindspore.dataset.text as nlp - -BASIC_TOKENIZER_FILE = "../data/dataset/testTokenizerData/basic_tokenizer.txt" - -test_paras = [ - dict( - first=1, - last=6, - expected_tokens= - [['Welcome', 'to', 'Beijing', '北', '京', '欢', '迎', '您'], - ['長', '風', '破', '浪', '會', '有', '時', ',', '直', '掛', '雲', '帆', '濟', '滄', '海'], - ['😀', '嘿', '嘿', '😃', '哈', '哈', '😄', '大', '笑', '😁', '嘻', '嘻'], - ['明', '朝', '(', '1368', '—', '1644', '年', ')', '和', '清', '朝', - '(', '1644', '—', '1911', '年', ')', ',', '是', '中', '国', '封', - '建', '王', '朝', '史', '上', '最', '后', '两', '个', '朝', '代'], - ['明', '代', '(', '1368', '-', '1644', ')', 'と', '清', '代', - '(', '1644', '-', '1911', ')', 'は', '、', '中', '国', 'の', '封', - '建', '王', '朝', 'の', '歴', '史', 'における', '最', '後', 'の2つの', '王', '朝', 'でした'], - ['명나라', '(', '1368', '-', '1644', ')', '와', '청나라', '(', '1644', '-', '1911', ')', '는', - '중국', '봉건', '왕조의', '역사에서', '마지막', '두', '왕조였다']] - ), - dict( - first=7, - last=7, - expected_tokens=[['this', 'is', 'a', 'funky', 'string']], - lower_case=True - ), -] - - -def check_basic_tokenizer(first, last, expected_tokens, lower_case=False, keep_whitespace=False, - normalization_form=nlp.utils.NormalizeForm.NONE, preserve_unused_token=False): - dataset = ds.TextFileDataset(BASIC_TOKENIZER_FILE, shuffle=False) - if first > 1: - dataset = dataset.skip(first - 1) - if last >= first: - dataset = dataset.take(last - first + 1) - - basic_tokenizer = nlp.BasicTokenizer(lower_case=lower_case, - keep_whitespace=keep_whitespace, - normalization_form=normalization_form, - preserve_unused_token=preserve_unused_token) - - dataset = dataset.map(operations=basic_tokenizer) - count = 0 - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']) - logger.info("Out:", text) - logger.info("Exp:", expected_tokens[count]) - np.testing.assert_array_equal(text, expected_tokens[count]) - count = count + 1 - - -def test_basic_tokenizer(): - """ - Test BasicTokenizer - """ - for paras in test_paras: - check_basic_tokenizer(**paras) - - -if __name__ == '__main__': - test_basic_tokenizer() diff --git a/tests/ut/python/dataset/test_bert_tokenizer.py b/tests/ut/python/dataset/test_bert_tokenizer.py deleted file mode 100644 index ba487343a0..0000000000 --- a/tests/ut/python/dataset/test_bert_tokenizer.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Testing BertTokenizer op in DE -""" -import numpy as np -import mindspore.dataset as ds -from mindspore import log as logger -import mindspore.dataset.text as nlp - -BERT_TOKENIZER_FILE = "../data/dataset/testTokenizerData/bert_tokenizer.txt" - -vocab_bert = [ - "床", "前", "明", "月", "光", "疑", "是", "地", "上", "霜", "举", "头", "望", "低", "思", "故", "乡", - "繁", "體", "字", "嘿", "哈", "大", "笑", "嘻", - "i", "am", "mak", "make", "small", "mistake", "##s", "during", "work", "##ing", "hour", - "😀", "😃", "😄", "😁", "+", "/", "-", "=", "12", "28", "40", "16", " ", "I", - "[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]", "[unused1]", "[unused10]" -] -pad = '' -test_paras = [ - # test chinese text - dict( - first=1, - last=4, - expect_str=[['床', '前', '明', '月', '光'], - ['疑', '是', '地', '上', '霜'], - ['举', '头', '望', '明', '月'], - ['低', '头', '思', '故', '乡']], - vocab_list=vocab_bert - ), - # test english text - dict( - first=5, - last=5, - expect_str=[['i', 'am', 'mak', '##ing', 'small', 'mistake', '##s', 'during', 'work', '##ing', 'hour', '##s']], - lower_case=True, - vocab_list=vocab_bert - ), - dict( - first=5, - last=5, - expect_str=[['I', "am", 'mak', '##ing', 'small', 'mistake', '##s', 'during', 'work', '##ing', 'hour', '##s']], - lower_case=False, - vocab_list=vocab_bert - ), - # test emoji tokens - dict( - first=6, - last=7, - expect_str=[ - ['😀', '嘿', '嘿', '😃', '哈', '哈', '😄', '大', '笑', '😁', '嘻', '嘻'], - ['繁', '體', '字']], - normalization_form=nlp.utils.NormalizeForm.NFKC, - vocab_list=vocab_bert - ), - # test preserved tokens - dict( - first=8, - last=14, - expect_str=[ - ['[UNK]', '[CLS]'], - ['[UNK]', '[SEP]'], - ['[UNK]', '[UNK]'], - ['[UNK]', '[PAD]'], - ['[UNK]', '[MASK]'], - ['[unused1]'], - ['[unused10]'] - ], - lower_case=False, - vocab_list=vocab_bert, - preserve_unused_token=True, - ), - dict( - first=8, - last=14, - expect_str=[ - ['[UNK]', '[CLS]'], - ['[UNK]', '[SEP]'], - ['[UNK]', '[UNK]'], - ['[UNK]', '[PAD]'], - ['[UNK]', '[MASK]'], - ['[unused1]'], - ['[unused10]'] - ], - lower_case=True, - vocab_list=vocab_bert, - preserve_unused_token=True, - ), - # test special symbol - dict( - first=15, - last=15, - expect_str=[['12', '+', '/', '-', '28', '=', '40', '/', '-', '16']], - preserve_unused_token=True, - vocab_list=vocab_bert - ), - # test non-default parms - dict( - first=8, - last=8, - expect_str=[['[UNK]', ' ', '[CLS]']], - lower_case=False, - vocab_list=vocab_bert, - preserve_unused_token=True, - keep_whitespace=True - ), - dict( - first=8, - last=8, - expect_str=[['unused', ' ', '[CLS]']], - lower_case=False, - vocab_list=vocab_bert, - preserve_unused_token=True, - keep_whitespace=True, - unknown_token='' - ), - dict( - first=8, - last=8, - expect_str=[['unused', ' ', '[', 'CLS', ']']], - lower_case=False, - vocab_list=vocab_bert, - preserve_unused_token=False, - keep_whitespace=True, - unknown_token='' - ), -] - - -def check_bert_tokenizer(first, last, expect_str, - vocab_list, - suffix_indicator='##', - max_bytes_per_token=100, unknown_token='[UNK]', - lower_case=False, keep_whitespace=False, - normalization_form=nlp.utils.NormalizeForm.NONE, - preserve_unused_token=False): - dataset = ds.TextFileDataset(BERT_TOKENIZER_FILE, shuffle=False) - if first > 1: - dataset = dataset.skip(first - 1) - if last >= first: - dataset = dataset.take(last - first + 1) - vocab = nlp.Vocab.from_list(vocab_list) - tokenizer_op = nlp.BertTokenizer( - vocab=vocab, suffix_indicator=suffix_indicator, - max_bytes_per_token=max_bytes_per_token, unknown_token=unknown_token, - lower_case=lower_case, keep_whitespace=keep_whitespace, - normalization_form=normalization_form, - preserve_unused_token=preserve_unused_token) - dataset = dataset.map(operations=tokenizer_op) - count = 0 - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']) - logger.info("Out:", text) - logger.info("Exp:", expect_str[count]) - np.testing.assert_array_equal(text, expect_str[count]) - count = count + 1 - - -def test_bert_tokenizer(): - """ - Test WordpieceTokenizer - """ - for paras in test_paras: - check_bert_tokenizer(**paras) - - -if __name__ == '__main__': - test_bert_tokenizer() diff --git a/tests/ut/python/dataset/test_bounding_box_augment.py b/tests/ut/python/dataset/test_bounding_box_augment.py index fbcb56514f..8924af968c 100644 --- a/tests/ut/python/dataset/test_bounding_box_augment.py +++ b/tests/ut/python/dataset/test_bounding_box_augment.py @@ -15,36 +15,21 @@ """ Testing the bounding box augment op in DE """ -from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ - config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 + import numpy as np import mindspore.log as logger import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as c_vision +from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \ + config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5 + GENERATE_GOLDEN = False +# updated VOC dataset with correct annotations DATA_DIR = "../data/dataset/testVOC2012_2" - - -def fix_annotate(bboxes): - """ - Fix annotations to format followed by mindspore. - :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format - :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format - """ - for bbox in bboxes: - if bbox.size == 7: - tmp = bbox[0] - bbox[0] = bbox[1] - bbox[1] = bbox[2] - bbox[2] = bbox[3] - bbox[3] = bbox[4] - bbox[4] = tmp - else: - print("ERROR: Invalid Bounding Box size provided") - break - return bboxes +DATA_DIR_2 = ["../data/dataset/testCOCO/train/", + "../data/dataset/testCOCO/annotations/train.json"] # DATA_DIR, ANNOTATION_DIR def test_bounding_box_augment_with_rotation_op(plot_vis=False): @@ -63,13 +48,6 @@ def test_bounding_box_augment_with_rotation_op(plot_vis=False): # Ratio is set to 1 to apply rotation on all bounding boxes. test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1) - # maps to fix annotations to minddata standard - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -100,22 +78,15 @@ def test_bounding_box_augment_with_crop_op(plot_vis=False): """ logger.info("test_bounding_box_augment_with_crop_op") - original_seed = config_get_set_seed(1) + original_seed = config_get_set_seed(0) original_num_parallel_workers = config_get_set_num_parallel_workers(1) dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - # Ratio is set to 1 to apply rotation on all bounding boxes. - test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(90), 1) - - # maps to fix annotations to minddata standard - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) + # Ratio is set to 0.9 to apply RandomCrop of size (50, 50) on 90% of the bounding boxes. + test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9) + # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -154,13 +125,6 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False): test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9) - # maps to fix annotations to minddata standard - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -183,6 +147,36 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False): ds.config.set_num_parallel_workers(original_num_parallel_workers) +def test_bounding_box_augment_op_coco_c(plot_vis=False): + """ + Prints images and bboxes side by side with and without BoundingBoxAugment Op applied, + Testing with COCO dataset + """ + logger.info("test_bounding_box_augment_op_coco_c") + + dataCoco1 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection", + decode=True, shuffle=False) + + dataCoco2 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection", + decode=True, shuffle=False) + + test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1) + + dataCoco2 = dataCoco2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataCoco1.create_dict_iterator(), dataCoco2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp, "bbox") + + def test_bounding_box_augment_valid_edge_c(plot_vis=False): """ Test BoundingBoxAugment op (testing with valid edge case, box covering full image). @@ -198,25 +192,18 @@ def test_bounding_box_augment_valid_edge_c(plot_vis=False): test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1) - # maps to fix annotations to minddata standard - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops # Add column for "annotation" dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], operations=lambda img, bbox: - (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.uint32))) + (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], operations=lambda img, bbox: - (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.uint32))) + (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], @@ -249,10 +236,6 @@ def test_bounding_box_augment_invalid_ratio_c(): try: # ratio range is from 0 - 1 test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5) - # maps to fix annotations to minddata standard - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -260,7 +243,7 @@ def test_bounding_box_augment_invalid_ratio_c(): operations=[test_op]) # Add column for "annotation" except ValueError as error: logger.info("Got an exception in DE: {}".format(str(error))) - assert "Input is not" in str(error) + assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error) def test_bounding_box_augment_invalid_bounds_c(): @@ -286,6 +269,7 @@ if __name__ == "__main__": # set to false to not show plots test_bounding_box_augment_with_rotation_op(plot_vis=False) test_bounding_box_augment_with_crop_op(plot_vis=False) + test_bounding_box_augment_op_coco_c(plot_vis=False) test_bounding_box_augment_valid_ratio_c(plot_vis=False) test_bounding_box_augment_valid_edge_c(plot_vis=False) test_bounding_box_augment_invalid_ratio_c() diff --git a/tests/ut/python/dataset/test_bucket_batch_by_length.py b/tests/ut/python/dataset/test_bucket_batch_by_length.py index febcc6483f..405b874110 100644 --- a/tests/ut/python/dataset/test_bucket_batch_by_length.py +++ b/tests/ut/python/dataset/test_bucket_batch_by_length.py @@ -17,6 +17,7 @@ import pytest import numpy as np import mindspore.dataset as ds + # generates 1 column [0], [0, 1], ..., [0, ..., n-1] def generate_sequential(n): for i in range(n): @@ -44,6 +45,7 @@ def test_bucket_batch_invalid_input(): bucket_boundaries = [1, 2, 3] empty_bucket_boundaries = [] invalid_bucket_boundaries = ["1", "2", "3"] + zero_start_bucket_boundaries = [0, 2, 3] negative_bucket_boundaries = [1, 2, -3] decreasing_bucket_boundaries = [3, 2, 1] non_increasing_bucket_boundaries = [1, 2, 2] @@ -58,7 +60,7 @@ def test_bucket_batch_invalid_input(): with pytest.raises(TypeError) as info: _ = dataset.bucket_batch_by_length(invalid_column_names, bucket_boundaries, bucket_batch_sizes) - assert "column_names should be a list of str" in str(info.value) + assert "Argument column_names[0] with value 1 is not of type (,)." in str(info.value) with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, empty_bucket_boundaries, bucket_batch_sizes) @@ -68,9 +70,13 @@ def test_bucket_batch_invalid_input(): _ = dataset.bucket_batch_by_length(column_names, invalid_bucket_boundaries, bucket_batch_sizes) assert "bucket_boundaries should be a list of int" in str(info.value) + with pytest.raises(ValueError) as info: + _ = dataset.bucket_batch_by_length(column_names, zero_start_bucket_boundaries, bucket_batch_sizes) + assert "bucket_boundaries must only contain positive numbers." in str(info.value) + with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, negative_bucket_boundaries, bucket_batch_sizes) - assert "bucket_boundaries cannot contain any negative numbers" in str(info.value) + assert "bucket_boundaries must only contain positive numbers." in str(info.value) with pytest.raises(ValueError) as info: _ = dataset.bucket_batch_by_length(column_names, decreasing_bucket_boundaries, bucket_batch_sizes) @@ -99,12 +105,12 @@ def test_bucket_batch_invalid_input(): with pytest.raises(TypeError) as info: _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, None, None, invalid_type_pad_to_bucket_boundary) - assert "Wrong input type for pad_to_bucket_boundary, should be " in str(info.value) + assert "Argument pad_to_bucket_boundary with value \"\" is not of type (,)." in str(info.value) with pytest.raises(TypeError) as info: _ = dataset.bucket_batch_by_length(column_names, bucket_boundaries, bucket_batch_sizes, None, None, False, invalid_type_drop_remainder) - assert "Wrong input type for drop_remainder, should be " in str(info.value) + assert "Argument drop_remainder with value \"\" is not of type (,)." in str(info.value) def test_bucket_batch_multi_bucket_no_padding(): @@ -272,7 +278,6 @@ def test_bucket_batch_default_pad(): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]]] - output = [] for data in dataset.create_dict_iterator(): output.append(data["col1"].tolist()) diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py new file mode 100644 index 0000000000..0e42b422aa --- /dev/null +++ b/tests/ut/python/dataset/test_cache_map.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing cache operator with mappable datasets +""" +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as c_vision +from mindspore import log as logger +from util import save_and_check_md5 + +DATA_DIR = "../data/dataset/testImageNetData/train/" + +GENERATE_GOLDEN = False + +def test_cache_map_basic1(): + """ + Test mappable leaf with cache op right over the leaf + + Repeat + | + Map(decode) + | + Cache + | + ImageFolder + """ + + logger.info("Test cache map basic 1") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + filename = "cache_map_01_result.npz" + save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN) + + logger.info("test_cache_map_basic1 Ended.\n") + + +def test_cache_map_basic2(): + """ + Test mappable leaf with the cache op later in the tree above the map(decode) + + Repeat + | + Cache + | + Map(decode) + | + ImageFolder + """ + + logger.info("Test cache map basic 2") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + filename = "cache_map_02_result.npz" + save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN) + + logger.info("test_cache_map_basic2 Ended.\n") + + +def test_cache_map_basic3(): + """ + Test a repeat under mappable cache + + Cache + | + Map(decode) + | + Repeat + | + ImageFolder + """ + + logger.info("Test cache basic 3") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR) + decode_op = c_vision.Decode() + ds1 = ds1.repeat(4) + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 8 + logger.info('test_cache_basic3 Ended.\n') + + +def test_cache_map_failure1(): + """ + Test nested cache (failure) + + Repeat + | + Cache + | + Map(decode) + | + Cache + | + ImageFolder + + """ + logger.info("Test cache failure 1") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # This DATA_DIR only has 2 images in it + ds1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + try: + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Nested cache operations is not supported!" in str(e) + + assert num_iter == 0 + logger.info('test_cache_failure1 Ended.\n') + +if __name__ == '__main__': + test_cache_map_basic1() + test_cache_map_basic2() + test_cache_map_basic3() + test_cache_map_failure1() diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py new file mode 100644 index 0000000000..39e00c0621 --- /dev/null +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -0,0 +1,429 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing cache operator with non-mappable datasets +""" +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as c_vision +from mindspore import log as logger + +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + +GENERATE_GOLDEN = False + +def test_cache_nomap_basic1(): + """ + A random dataset (a non mappable dataset) with a cache over it just after the leaf + """ + + logger.info("Test cache nomap basic 1") + + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, + shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) + schema.add_column('label', de_type=mstype.uint8, shape=[1]) + + # create a cache. arbitrary session_id for now + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # User-created sampler here + ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for data in ds1.create_dict_iterator(): + logger.info("printing the label: {}".format(data["label"])) + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 40 + logger.info("test_cache_nomap_basic1 Ended.\n") + + +def test_cache_nomap_basic2(): + """ + A random dataset (a non mappable dataset) with a cache over it just after the leaf + """ + + logger.info("Test cache nomap basic 2") + + schema = ds.Schema() + schema.add_column('image', de_type=mstype.uint8, + shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) + schema.add_column('label', de_type=mstype.uint8, shape=[1]) + + # create a cache. arbitrary session_id for now + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # sampler arg not given directly, however any of these args will auto-generate an appropriate sampler: + # num_samples, shuffle, num_shards, shard_id + # In this case, the presence of num_samples chooses a sampler. + ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache) + ds1 = ds1.repeat(2) + + num_iter = 0 + for data in ds1.create_dict_iterator(): + logger.info("printing the label: {}".format(data["label"])) + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 40 + logger.info("test_cache_nomap_basic2 Ended.\n") + + +def test_cache_nomap_basic3(): + """ + A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf + + Repeat + | + Map(decode) + | + Cache + | + TFReader + """ + + logger.info("Test cache nomap basic 3") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_basic3 Ended.\n") + + +def test_cache_nomap_basic4(): + """ + A TF reader dataset (a non mappable dataset) with a map decode and cache after it + Since a global shuffle is used for the tf reader, it will inject a shuffle op over the tf. + But, if there's a cache later, that shuffle becomes invalid and should be removed. + + Repeat + | + Cache + | + Map(decode) + | + TFReader + """ + + logger.info("Test cache nomap basic 4") + + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache + # in the picture. This causes a shuffle-injection over the TF. For clarify, this test will + # explicitly give the global option, even though it's the default in python. + # But, when caching is added in the ascendent tree above TF, we do global shuffling + # through the sampler over the cache, not by the shuffle op. In that case, tree prepare + # will remove the shuffle op that got injected by the initial tree creation. + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL) + decode_op = c_vision.Decode() + + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_basic4 Ended.\n") + + +def test_cache_nomap_basic5(): + """ + A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf + Same as test 3, but this one does not have shuffle arg, causing tf to default to global + shuffle which attempts to inject a shuffle operator. However, since there is a cache + we do not need global shuffle, so the shuffle will not be built. It ends up being + identical to test basic 3, however we arrive at the same tree in different codepaths + (if there was no cache, then the shuffle IS built) + + Repeat + | + Map(decode) + | + Cache + | + TFReader + """ + + logger.info("Test cache nomap basic 5") + + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_basic5 Ended.\n") + + +def test_cache_nomap_basic6(): + """ + A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf + In this one, the tf dataset will be given sharding configuration, however since a cache is + used, the tree prepare should undo the sharding configuration and instead, a distributed + sampler will be chosen with the same shard config. + + Repeat + | + Map(decode) + | + Cache + | + TFReader + """ + + logger.info("Test cache nomap basic 6") + + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + # With only 3 records shard into 3, we expect only 1 record returned for this shard + # However, the sharding will be done by the sampler, not by the tf record leaf node + # In this case, it is a row-based sharding, not the file-based sharding that would happen if + # there was not any cache. + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 4 + logger.info("test_cache_nomap_basic6 Ended.\n") + + +def test_cache_nomap_basic7(): + """ + A TF reader dataset (a non mappable dataset) that uses global shuffle, and is cached followed by + map. + In this one, the tf dataset with global shuffle might want to inject a shuffle op over top of the + tf reader, but since a cache is given, it will choose not to. + + Repeat + | + Map(decode) + | + cache + | + TFReader + """ + + logger.info("Test cache nomap basic 7") + + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache) + decode_op = c_vision.Decode() + ds1 = ds1.map(input_columns=["image"], operations=decode_op) + ds1 = ds1.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + logger.info("test_cache_nomap_basic7 Ended.\n") + + +def test_cache_nomap_allowed_share1(): + """ + It is allowed to share the cache between the following two trees: + + Repeat Shuffle + | | + Cache Cache + | | + TFReader TFReader + """ + + logger.info("Test cache nomap allowed share 1") + + ds.config.set_seed(1) + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) + ds1 = ds1.repeat(4) + + ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache) + ds2 = ds2.shuffle(buffer_size=2) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + assert num_iter == 12 + logger.info("Number of data in ds1: {} ".format(num_iter)) + + num_iter = 0 + for _ in ds2.create_dict_iterator(): + num_iter += 1 + assert num_iter == 3 + logger.info("test_cache_nomap_allowed_share1 Ended.\n") + + +def test_cache_nomap_allowed_share2(): + """ + It is allowed to share the cache between the following two trees (with map decode): + + Repeat Shuffle + | | + Cache Cache + | | + Map(decode) Map(decode) + | | + TFReader TFReader + """ + + logger.info("Test cache nomap allowed share 2") + + ds.config.set_seed(1) + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True) + decode_op = c_vision.Decode() + + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds1 = ds1.repeat(4) + + ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=some_cache) + ds2 = ds2.shuffle(buffer_size=2) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + + num_iter = 0 + for _ in ds2.create_dict_iterator(): + num_iter += 1 + assert num_iter == 3 + logger.info("test_cache_nomap_allowed_share2 Ended.\n") + + +def test_cache_nomap_allowed_share3(): + """ + It is allowed to share the cache between the following two trees (different shard ids): + + Repeat Repeat + | | + Cache Cache + | | + TFReader(shard_id = 0) TFReader(shard_id = 1) + """ + + logger.info("Test cache nomap allowed share 3") + + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + + tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"] + ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache) + ds1 = ds1.repeat(4) + + ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache) + ds2 = ds2.repeat(4) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 12 + + num_iter = 0 + for _ in ds2.create_dict_iterator(): + num_iter += 1 + assert num_iter == 12 + logger.info("test_cache_nomap_allowed_share3 Ended.\n") + + +def test_cache_nomap_disallowed_share1(): + """ + It is not allowed to share the cache between the following two trees: + + Cache Cache + | | + Map(decode) Map(rescale) + | | + TFReader TFReader + """ + + logger.info("Test cache nomap disallowed share1") + + # This dataset has 3 records in it only + some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True) + decode_op = c_vision.Decode() + rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0) + + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache) + + ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) + ds2 = ds2.map(input_columns=["image"], operations=rescale_op, cache=some_cache) + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == 3 + + try: + sum([1 for _ in ds2]) + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Attempt to re-use a cache for a different tree!" in str(e) + + logger.info("test_cache_nomap_disallowed_share1 Ended.\n") + + +if __name__ == '__main__': + test_cache_nomap_basic1() + test_cache_nomap_basic2() + test_cache_nomap_basic3() + test_cache_nomap_basic4() + test_cache_nomap_basic5() + test_cache_nomap_basic6() + test_cache_nomap_basic7() + test_cache_nomap_allowed_share1() + test_cache_nomap_allowed_share2() + test_cache_nomap_allowed_share3() + test_cache_nomap_disallowed_share1() diff --git a/tests/ut/python/dataset/test_cifarop.py b/tests/ut/python/dataset/test_cifarop.py deleted file mode 100644 index e944f8703d..0000000000 --- a/tests/ut/python/dataset/test_cifarop.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import os - -import numpy as np - -import mindspore.dataset as ds -from mindspore import log as logger - -# Data for CIFAR and MNIST are not part of build tree -# They need to be downloaded directly -# prep_data.py can be executed or code below -# import sys -# sys.path.insert(0,"../../data") -# import prep_data -# prep_data.download_all_for_test("../../data") -DATA_DIR_10 = "../data/dataset/testCifar10Data" -DATA_DIR_100 = "../data/dataset/testCifar100Data" - - -def load_cifar(path): - raw = np.empty(0, dtype=np.uint8) - for file_name in os.listdir(path): - if file_name.endswith(".bin"): - with open(os.path.join(path, file_name), mode='rb') as file: - raw = np.append(raw, np.fromfile(file, dtype=np.uint8), axis=0) - raw = raw.reshape(-1, 3073) - labels = raw[:, 0] - images = raw[:, 1:] - images = images.reshape(-1, 3, 32, 32) - images = images.transpose(0, 2, 3, 1) - return images, labels - - -def test_case_dataset_cifar10(): - """ - dataset parameter - """ - logger.info("Test dataset parameter") - # apply dataset operations - data1 = ds.Cifar10Dataset(DATA_DIR_10, 100) - - num_iter = 0 - for _ in data1.create_dict_iterator(): - # in this example, each dictionary has keys "image" and "label" - num_iter += 1 - assert num_iter == 100 - - -def test_case_dataset_cifar100(): - """ - dataset parameter - """ - logger.info("Test dataset parameter") - # apply dataset operations - data1 = ds.Cifar100Dataset(DATA_DIR_100, 100) - - num_iter = 0 - for _ in data1.create_dict_iterator(): - # in this example, each dictionary has keys "image" and "label" - num_iter += 1 - assert num_iter == 100 - - -def test_reading_cifar10(): - """ - Validate CIFAR10 image readings - """ - data1 = ds.Cifar10Dataset(DATA_DIR_10, 100, shuffle=False) - images, labels = load_cifar(DATA_DIR_10) - for i, d in enumerate(data1.create_dict_iterator()): - np.testing.assert_array_equal(d["image"], images[i]) - np.testing.assert_array_equal(d["label"], labels[i]) - - -if __name__ == '__main__': - test_case_dataset_cifar10() - test_case_dataset_cifar100() - test_reading_cifar10() diff --git a/tests/ut/python/dataset/test_concatenate_op.py b/tests/ut/python/dataset/test_concatenate_op.py index d04ff49724..f7a432e471 100644 --- a/tests/ut/python/dataset/test_concatenate_op.py +++ b/tests/ut/python/dataset/test_concatenate_op.py @@ -108,7 +108,7 @@ def test_concatenate_op_type_mismatch(): with pytest.raises(RuntimeError) as error_info: for _ in data: pass - assert "Tensor types do not match" in repr(error_info.value) + assert "Tensor types do not match" in str(error_info.value) def test_concatenate_op_type_mismatch2(): @@ -123,7 +123,7 @@ def test_concatenate_op_type_mismatch2(): with pytest.raises(RuntimeError) as error_info: for _ in data: pass - assert "Tensor types do not match" in repr(error_info.value) + assert "Tensor types do not match" in str(error_info.value) def test_concatenate_op_incorrect_dim(): @@ -138,13 +138,13 @@ def test_concatenate_op_incorrect_dim(): with pytest.raises(RuntimeError) as error_info: for _ in data: pass - assert "Only 1D tensors supported" in repr(error_info.value) + assert "Only 1D tensors supported" in str(error_info.value) def test_concatenate_op_wrong_axis(): with pytest.raises(ValueError) as error_info: data_trans.Concatenate(2) - assert "only 1D concatenation supported." in repr(error_info.value) + assert "only 1D concatenation supported." in str(error_info.value) def test_concatenate_op_negative_axis(): @@ -163,18 +163,11 @@ def test_concatenate_op_negative_axis(): def test_concatenate_op_incorrect_input_dim(): - def gen(): - yield (np.array(["ss", "ad"], dtype='S'),) - prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S') - data = ds.GeneratorDataset(gen, column_names=["col"]) - concatenate_op = data_trans.Concatenate(0, prepend_tensor) - data = data.map(input_columns=["col"], operations=concatenate_op) - with pytest.raises(RuntimeError) as error_info: - for _ in data: - pass - assert "Only 1D tensors supported" in repr(error_info.value) + with pytest.raises(ValueError) as error_info: + data_trans.Concatenate(0, prepend_tensor) + assert "can only prepend 1D arrays." in str(error_info.value) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py index 259f42d948..6783eea2fd 100644 --- a/tests/ut/python/dataset/test_config.py +++ b/tests/ut/python/dataset/test_config.py @@ -245,17 +245,17 @@ def test_deterministic_run_distribution(): # First dataset data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) - random_crop_op = c_vision.RandomHorizontalFlip(0.1) + random_horizontal_flip_op = c_vision.RandomHorizontalFlip(0.1) decode_op = c_vision.Decode() data1 = data1.map(input_columns=["image"], operations=decode_op) - data1 = data1.map(input_columns=["image"], operations=random_crop_op) + data1 = data1.map(input_columns=["image"], operations=random_horizontal_flip_op) # Second dataset data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data2 = data2.map(input_columns=["image"], operations=decode_op) # If seed is set up on constructor, so the two ops output deterministic sequence - random_crop_op2 = c_vision.RandomHorizontalFlip(0.1) - data2 = data2.map(input_columns=["image"], operations=random_crop_op2) + random_horizontal_flip_op2 = c_vision.RandomHorizontalFlip(0.1) + data2 = data2.map(input_columns=["image"], operations=random_horizontal_flip_op2) for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): np.testing.assert_equal(item1["image"], item2["image"]) diff --git a/tests/ut/python/dataset/test_dataset_numpy_slices.py b/tests/ut/python/dataset/test_dataset_numpy_slices.py index 4cd4e26a33..791a567408 100644 --- a/tests/ut/python/dataset/test_dataset_numpy_slices.py +++ b/tests/ut/python/dataset/test_dataset_numpy_slices.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import sys +import pytest import numpy as np +import pandas as pd import mindspore.dataset as de from mindspore import log as logger import mindspore.dataset.transforms.vision.c_transforms as vision -import pandas as pd def test_numpy_slices_list_1(): @@ -172,8 +174,26 @@ def test_numpy_slices_distributed_sampler(): assert sum([1 for _ in ds]) == 2 -def test_numpy_slices_sequential_sampler(): +def test_numpy_slices_distributed_shard_limit(): + logger.info("Test Slicing a 1D list.") + + np_data = [1, 2, 3] + num = sys.maxsize + with pytest.raises(ValueError) as err: + de.NumpySlicesDataset(np_data, num_shards=num, shard_id=0, shuffle=False) + assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value) + +def test_numpy_slices_distributed_zero_shard(): + logger.info("Test Slicing a 1D list.") + + np_data = [1, 2, 3] + with pytest.raises(ValueError) as err: + de.NumpySlicesDataset(np_data, num_shards=0, shard_id=0, shuffle=False) + assert "Input num_shards is not within the required interval of (1 to 2147483647)." in str(err.value) + + +def test_numpy_slices_sequential_sampler(): logger.info("Test numpy_slices_dataset with SequentialSampler and repeat.") np_data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]] @@ -183,6 +203,42 @@ def test_numpy_slices_sequential_sampler(): assert np.equal(data[0], np_data[i % 8]).all() +def test_numpy_slices_invalid_column_names_type(): + logger.info("Test incorrect column_names input") + np_data = [1, 2, 3] + + with pytest.raises(TypeError) as err: + de.NumpySlicesDataset(np_data, column_names=[1], shuffle=False) + assert "Argument column_names[0] with value 1 is not of type (,)." in str(err.value) + + +def test_numpy_slices_invalid_column_names_string(): + logger.info("Test incorrect column_names input") + np_data = [1, 2, 3] + + with pytest.raises(ValueError) as err: + de.NumpySlicesDataset(np_data, column_names=[""], shuffle=False) + assert "column_names[0] should not be empty" in str(err.value) + + +def test_numpy_slices_invalid_empty_column_names(): + logger.info("Test incorrect column_names input") + np_data = [1, 2, 3] + + with pytest.raises(ValueError) as err: + de.NumpySlicesDataset(np_data, column_names=[], shuffle=False) + assert "column_names should not be empty" in str(err.value) + + +def test_numpy_slices_invalid_empty_data_column(): + logger.info("Test incorrect column_names input") + np_data = [] + + with pytest.raises(ValueError) as err: + de.NumpySlicesDataset(np_data, shuffle=False) + assert "Argument data cannot be empty" in str(err.value) + + if __name__ == "__main__": test_numpy_slices_list_1() test_numpy_slices_list_2() @@ -196,4 +252,10 @@ if __name__ == "__main__": test_numpy_slices_csv_dict() test_numpy_slices_num_samplers() test_numpy_slices_distributed_sampler() + test_numpy_slices_distributed_shard_limit() + test_numpy_slices_distributed_zero_shard() test_numpy_slices_sequential_sampler() + test_numpy_slices_invalid_column_names_type() + test_numpy_slices_invalid_column_names_string() + test_numpy_slices_invalid_empty_column_names() + test_numpy_slices_invalid_empty_data_column() diff --git a/tests/ut/python/dataset/test_datasets_cifarop.py b/tests/ut/python/dataset/test_datasets_cifarop.py new file mode 100644 index 0000000000..d6d3029b53 --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_cifarop.py @@ -0,0 +1,387 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test Cifar10 and Cifar100 dataset operators +""" +import os +import pytest +import numpy as np +import matplotlib.pyplot as plt +import mindspore.dataset as ds +from mindspore import log as logger + +DATA_DIR_10 = "../data/dataset/testCifar10Data" +DATA_DIR_100 = "../data/dataset/testCifar100Data" + + +def load_cifar(path, kind="cifar10"): + """ + load Cifar10/100 data + """ + raw = np.empty(0, dtype=np.uint8) + for file_name in os.listdir(path): + if file_name.endswith(".bin"): + with open(os.path.join(path, file_name), mode='rb') as file: + raw = np.append(raw, np.fromfile(file, dtype=np.uint8), axis=0) + if kind == "cifar10": + raw = raw.reshape(-1, 3073) + labels = raw[:, 0] + images = raw[:, 1:] + elif kind == "cifar100": + raw = raw.reshape(-1, 3074) + labels = raw[:, :2] + images = raw[:, 2:] + else: + raise ValueError("Invalid parameter value") + images = images.reshape(-1, 3, 32, 32) + images = images.transpose(0, 2, 3, 1) + return images, labels + + +def visualize_dataset(images, labels): + """ + Helper function to visualize the dataset samples + """ + num_samples = len(images) + for i in range(num_samples): + plt.subplot(1, num_samples, i + 1) + plt.imshow(images[i]) + plt.title(labels[i]) + plt.show() + + +### Testcases for Cifar10Dataset Op ### + + +def test_cifar10_content_check(): + """ + Validate Cifar10Dataset image readings + """ + logger.info("Test Cifar10Dataset Op with content check") + data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100, shuffle=False) + images, labels = load_cifar(DATA_DIR_10) + num_iter = 0 + # in this example, each dictionary has keys "image" and "label" + for i, d in enumerate(data1.create_dict_iterator()): + np.testing.assert_array_equal(d["image"], images[i]) + np.testing.assert_array_equal(d["label"], labels[i]) + num_iter += 1 + assert num_iter == 100 + + +def test_cifar10_basic(): + """ + Validate CIFAR10 + """ + logger.info("Test Cifar10Dataset Op") + + # case 1: test num_samples + data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100) + num_iter1 = 0 + for _ in data1.create_dict_iterator(): + num_iter1 += 1 + assert num_iter1 == 100 + + # case 2: test num_parallel_workers + data2 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=50, num_parallel_workers=1) + num_iter2 = 0 + for _ in data2.create_dict_iterator(): + num_iter2 += 1 + assert num_iter2 == 50 + + # case 3: test repeat + data3 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100) + data3 = data3.repeat(3) + num_iter3 = 0 + for _ in data3.create_dict_iterator(): + num_iter3 += 1 + assert num_iter3 == 300 + + # case 4: test batch with drop_remainder=False + data4 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100) + assert data4.get_dataset_size() == 100 + assert data4.get_batch_size() == 1 + data4 = data4.batch(batch_size=7) # drop_remainder is default to be False + assert data4.get_dataset_size() == 15 + assert data4.get_batch_size() == 7 + num_iter4 = 0 + for _ in data4.create_dict_iterator(): + num_iter4 += 1 + assert num_iter4 == 15 + + # case 5: test batch with drop_remainder=True + data5 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=100) + assert data5.get_dataset_size() == 100 + assert data5.get_batch_size() == 1 + data5 = data5.batch(batch_size=7, drop_remainder=True) # the rest of incomplete batch will be dropped + assert data5.get_dataset_size() == 14 + assert data5.get_batch_size() == 7 + num_iter5 = 0 + for _ in data5.create_dict_iterator(): + num_iter5 += 1 + assert num_iter5 == 14 + + +def test_cifar10_pk_sampler(): + """ + Test Cifar10Dataset with PKSampler + """ + logger.info("Test Cifar10Dataset Op with PKSampler") + golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, + 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9] + sampler = ds.PKSampler(3) + data = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler) + num_iter = 0 + label_list = [] + for item in data.create_dict_iterator(): + label_list.append(item["label"]) + num_iter += 1 + np.testing.assert_array_equal(golden, label_list) + assert num_iter == 30 + + +def test_cifar10_sequential_sampler(): + """ + Test Cifar10Dataset with SequentialSampler + """ + logger.info("Test Cifar10Dataset Op with SequentialSampler") + num_samples = 30 + sampler = ds.SequentialSampler(num_samples=num_samples) + data1 = ds.Cifar10Dataset(DATA_DIR_10, sampler=sampler) + data2 = ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_samples=num_samples) + num_iter = 0 + for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): + np.testing.assert_equal(item1["label"], item2["label"]) + num_iter += 1 + assert num_iter == num_samples + + +def test_cifar10_exception(): + """ + Test error cases for Cifar10Dataset + """ + logger.info("Test error cases for Cifar10Dataset") + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, sampler=ds.PKSampler(3)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.Cifar10Dataset(DATA_DIR_10, sampler=ds.PKSampler(3), num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.Cifar10Dataset(DATA_DIR_10, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.Cifar10Dataset(DATA_DIR_10, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.Cifar10Dataset(DATA_DIR_10, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.Cifar10Dataset(DATA_DIR_10, shuffle=False, num_parallel_workers=88) + + +def test_cifar10_visualize(plot=False): + """ + Visualize Cifar10Dataset results + """ + logger.info("Test Cifar10Dataset visualization") + + data1 = ds.Cifar10Dataset(DATA_DIR_10, num_samples=10, shuffle=False) + num_iter = 0 + image_list, label_list = [], [] + for item in data1.create_dict_iterator(): + image = item["image"] + label = item["label"] + image_list.append(image) + label_list.append("label {}".format(label)) + assert isinstance(image, np.ndarray) + assert image.shape == (32, 32, 3) + assert image.dtype == np.uint8 + assert label.dtype == np.uint32 + num_iter += 1 + assert num_iter == 10 + if plot: + visualize_dataset(image_list, label_list) + + +### Testcases for Cifar100Dataset Op ### + +def test_cifar100_content_check(): + """ + Validate Cifar100Dataset image readings + """ + logger.info("Test Cifar100Dataset with content check") + data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, shuffle=False) + images, labels = load_cifar(DATA_DIR_100, kind="cifar100") + num_iter = 0 + # in this example, each dictionary has keys "image", "coarse_label" and "fine_image" + for i, d in enumerate(data1.create_dict_iterator()): + np.testing.assert_array_equal(d["image"], images[i]) + np.testing.assert_array_equal(d["coarse_label"], labels[i][0]) + np.testing.assert_array_equal(d["fine_label"], labels[i][1]) + num_iter += 1 + assert num_iter == 100 + + +def test_cifar100_basic(): + """ + Test Cifar100Dataset + """ + logger.info("Test Cifar100Dataset") + + # case 1: test num_samples + data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100) + num_iter1 = 0 + for _ in data1.create_dict_iterator(): + num_iter1 += 1 + assert num_iter1 == 100 + + # case 2: test repeat + data1 = data1.repeat(2) + num_iter2 = 0 + for _ in data1.create_dict_iterator(): + num_iter2 += 1 + assert num_iter2 == 200 + + # case 3: test num_parallel_workers + data2 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100, num_parallel_workers=1) + num_iter3 = 0 + for _ in data2.create_dict_iterator(): + num_iter3 += 1 + assert num_iter3 == 100 + + # case 4: test batch with drop_remainder=False + data3 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100) + assert data3.get_dataset_size() == 100 + assert data3.get_batch_size() == 1 + data3 = data3.batch(batch_size=3) + assert data3.get_dataset_size() == 34 + assert data3.get_batch_size() == 3 + num_iter4 = 0 + for _ in data3.create_dict_iterator(): + num_iter4 += 1 + assert num_iter4 == 34 + + # case 4: test batch with drop_remainder=True + data4 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=100) + data4 = data4.batch(batch_size=3, drop_remainder=True) + assert data4.get_dataset_size() == 33 + assert data4.get_batch_size() == 3 + num_iter5 = 0 + for _ in data4.create_dict_iterator(): + num_iter5 += 1 + assert num_iter5 == 33 + + +def test_cifar100_pk_sampler(): + """ + Test Cifar100Dataset with PKSampler + """ + logger.info("Test Cifar100Dataset with PKSampler") + golden = [i for i in range(20)] + sampler = ds.PKSampler(1) + data = ds.Cifar100Dataset(DATA_DIR_100, sampler=sampler) + num_iter = 0 + label_list = [] + for item in data.create_dict_iterator(): + label_list.append(item["coarse_label"]) + num_iter += 1 + np.testing.assert_array_equal(golden, label_list) + assert num_iter == 20 + + +def test_cifar100_exception(): + """ + Test error cases for Cifar100Dataset + """ + logger.info("Test error cases for Cifar100Dataset") + error_msg_1 = "sampler and shuffle cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_1): + ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, sampler=ds.PKSampler(3)) + + error_msg_2 = "sampler and sharding cannot be specified at the same time" + with pytest.raises(RuntimeError, match=error_msg_2): + ds.Cifar100Dataset(DATA_DIR_100, sampler=ds.PKSampler(3), num_shards=2, shard_id=0) + + error_msg_3 = "num_shards is specified and currently requires shard_id as well" + with pytest.raises(RuntimeError, match=error_msg_3): + ds.Cifar100Dataset(DATA_DIR_100, num_shards=10) + + error_msg_4 = "shard_id is specified but num_shards is not" + with pytest.raises(RuntimeError, match=error_msg_4): + ds.Cifar100Dataset(DATA_DIR_100, shard_id=0) + + error_msg_5 = "Input shard_id is not within the required interval" + with pytest.raises(ValueError, match=error_msg_5): + ds.Cifar100Dataset(DATA_DIR_100, num_shards=2, shard_id=-1) + with pytest.raises(ValueError, match=error_msg_5): + ds.Cifar10Dataset(DATA_DIR_100, num_shards=2, shard_id=5) + + error_msg_6 = "num_parallel_workers exceeds" + with pytest.raises(ValueError, match=error_msg_6): + ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=0) + with pytest.raises(ValueError, match=error_msg_6): + ds.Cifar100Dataset(DATA_DIR_100, shuffle=False, num_parallel_workers=88) + + +def test_cifar100_visualize(plot=False): + """ + Visualize Cifar100Dataset results + """ + logger.info("Test Cifar100Dataset visualization") + + data1 = ds.Cifar100Dataset(DATA_DIR_100, num_samples=10, shuffle=False) + num_iter = 0 + image_list, label_list = [], [] + for item in data1.create_dict_iterator(): + image = item["image"] + coarse_label = item["coarse_label"] + fine_label = item["fine_label"] + image_list.append(image) + label_list.append("coarse_label {}\nfine_label {}".format(coarse_label, fine_label)) + assert isinstance(image, np.ndarray) + assert image.shape == (32, 32, 3) + assert image.dtype == np.uint8 + assert coarse_label.dtype == np.uint32 + assert fine_label.dtype == np.uint32 + num_iter += 1 + assert num_iter == 10 + if plot: + visualize_dataset(image_list, label_list) + + +if __name__ == '__main__': + test_cifar10_content_check() + test_cifar10_basic() + test_cifar10_pk_sampler() + test_cifar10_sequential_sampler() + test_cifar10_exception() + test_cifar10_visualize(plot=False) + + test_cifar100_content_check() + test_cifar100_basic() + test_cifar100_pk_sampler() + test_cifar100_exception() + test_cifar100_visualize(plot=False) diff --git a/tests/ut/python/dataset/test_datasets_imagenet.py b/tests/ut/python/dataset/test_datasets_imagenet.py deleted file mode 100644 index a6e2afa65a..0000000000 --- a/tests/ut/python/dataset/test_datasets_imagenet.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import mindspore.dataset as ds -import mindspore.dataset.transforms.c_transforms as data_trans -import mindspore.dataset.transforms.vision.c_transforms as vision -from mindspore import log as logger - -DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] -SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" - - -def test_case_repeat(): - """ - a simple repeat operation. - """ - logger.info("Test Simple Repeat") - # define parameters - repeat_count = 2 - - # apply dataset operations - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) - data1 = data1.repeat(repeat_count) - - num_iter = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - # in this example, each dictionary has keys "image" and "label" - logger.info("image is: {}".format(item["image"])) - logger.info("label is: {}".format(item["label"])) - num_iter += 1 - - logger.info("Number of data in data1: {}".format(num_iter)) - - -def test_case_shuffle(): - """ - a simple shuffle operation. - """ - logger.info("Test Simple Shuffle") - # define parameters - buffer_size = 8 - seed = 10 - - # apply dataset operations - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) - ds.config.set_seed(seed) - data1 = data1.shuffle(buffer_size=buffer_size) - - for item in data1.create_dict_iterator(): - logger.info("image is: {}".format(item["image"])) - logger.info("label is: {}".format(item["label"])) - - -def test_case_0(): - """ - Test Repeat then Shuffle - """ - logger.info("Test Repeat then Shuffle") - # define parameters - repeat_count = 2 - buffer_size = 7 - seed = 9 - - # apply dataset operations - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) - data1 = data1.repeat(repeat_count) - ds.config.set_seed(seed) - data1 = data1.shuffle(buffer_size=buffer_size) - - num_iter = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - # in this example, each dictionary has keys "image" and "label" - logger.info("image is: {}".format(item["image"])) - logger.info("label is: {}".format(item["label"])) - num_iter += 1 - - logger.info("Number of data in data1: {}".format(num_iter)) - - -def test_case_0_reverse(): - """ - Test Shuffle then Repeat - """ - logger.info("Test Shuffle then Repeat") - # define parameters - repeat_count = 2 - buffer_size = 10 - seed = 9 - - # apply dataset operations - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) - ds.config.set_seed(seed) - data1 = data1.shuffle(buffer_size=buffer_size) - data1 = data1.repeat(repeat_count) - - num_iter = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - # in this example, each dictionary has keys "image" and "label" - logger.info("image is: {}".format(item["image"])) - logger.info("label is: {}".format(item["label"])) - num_iter += 1 - - logger.info("Number of data in data1: {}".format(num_iter)) - - -def test_case_3(): - """ - Test Map - """ - logger.info("Test Map Rescale and Resize, then Shuffle") - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) - # define data augmentation parameters - rescale = 1.0 / 255.0 - shift = 0.0 - resize_height, resize_width = 224, 224 - - # define map operations - decode_op = vision.Decode() - rescale_op = vision.Rescale(rescale, shift) - # resize_op = vision.Resize(resize_height, resize_width, - # InterpolationMode.DE_INTER_LINEAR) # Bilinear mode - resize_op = vision.Resize((resize_height, resize_width)) - - # apply map operations on images - data1 = data1.map(input_columns=["image"], operations=decode_op) - data1 = data1.map(input_columns=["image"], operations=rescale_op) - data1 = data1.map(input_columns=["image"], operations=resize_op) - - # # apply ont-hot encoding on labels - num_classes = 4 - one_hot_encode = data_trans.OneHot(num_classes) # num_classes is input argument - data1 = data1.map(input_columns=["label"], operations=one_hot_encode) - # - # # apply Datasets - buffer_size = 100 - seed = 10 - batch_size = 2 - ds.config.set_seed(seed) - data1 = data1.shuffle(buffer_size=buffer_size) # 10000 as in imageNet train script - data1 = data1.batch(batch_size, drop_remainder=True) - - num_iter = 0 - for item in data1.create_dict_iterator(): # each data is a dictionary - # in this example, each dictionary has keys "image" and "label" - logger.info("image is: {}".format(item["image"])) - logger.info("label is: {}".format(item["label"])) - num_iter += 1 - - logger.info("Number of data in data1: {}".format(num_iter)) - - -if __name__ == '__main__': - logger.info('===========now test Repeat============') - # logger.info('Simple Repeat') - test_case_repeat() - logger.info('\n') - - logger.info('===========now test Shuffle===========') - # logger.info('Simple Shuffle') - test_case_shuffle() - logger.info('\n') - - # Note: cannot work with different shapes, hence not for image - # logger.info('===========now test Batch=============') - # # logger.info('Simple Batch') - # test_case_batch() - # logger.info('\n') - - logger.info('===========now test case 0============') - # logger.info('Repeat then Shuffle') - test_case_0() - logger.info('\n') - - logger.info('===========now test case 0 reverse============') - # # logger.info('Shuffle then Repeat') - test_case_0_reverse() - logger.info('\n') - - # logger.info('===========now test case 1============') - # # logger.info('Repeat with Batch') - # test_case_1() - # logger.info('\n') - - # logger.info('===========now test case 2============') - # # logger.info('Batch with Shuffle') - # test_case_2() - # logger.info('\n') - - # for image augmentation only - logger.info('===========now test case 3============') - logger.info('Map then Shuffle') - test_case_3() - logger.info('\n') diff --git a/tests/ut/python/dataset/test_datasets_imagenet_distribution.py b/tests/ut/python/dataset/test_datasets_imagenet_distribution.py deleted file mode 100644 index 92bdb68dc5..0000000000 --- a/tests/ut/python/dataset/test_datasets_imagenet_distribution.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import mindspore.dataset as ds -from mindspore import log as logger - -DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", - "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] - -SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" - - -def test_tf_file_normal(): - # apply dataset operations - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) - data1 = data1.repeat(1) - num_iter = 0 - for _ in data1.create_dict_iterator(): # each data is a dictionary - num_iter += 1 - - logger.info("Number of data in data1: {}".format(num_iter)) - assert num_iter == 12 - - -if __name__ == '__main__': - logger.info('=======test normal=======') - test_tf_file_normal() diff --git a/tests/ut/python/dataset/test_datasets_voc.py b/tests/ut/python/dataset/test_datasets_voc.py index 8db65e9734..37f4a8c123 100644 --- a/tests/ut/python/dataset/test_datasets_voc.py +++ b/tests/ut/python/dataset/test_datasets_voc.py @@ -37,7 +37,7 @@ def test_voc_detection(): for item in data1.create_dict_iterator(): assert item["image"].shape[0] == IMAGE_SHAPE[num] for bbox in item["annotation"]: - count[bbox[0]] += 1 + count[int(bbox[6])] += 1 num += 1 assert num == 9 assert count == [3, 2, 1, 2, 4, 3] @@ -55,8 +55,8 @@ def test_voc_class_index(): count = [0, 0, 0, 0, 0, 0] for item in data1.create_dict_iterator(): for bbox in item["annotation"]: - assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 5) - count[bbox[0]] += 1 + assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 5) + count[int(bbox[6])] += 1 num += 1 assert num == 6 assert count == [3, 2, 0, 0, 0, 3] @@ -73,8 +73,9 @@ def test_voc_get_class_indexing(): count = [0, 0, 0, 0, 0, 0] for item in data1.create_dict_iterator(): for bbox in item["annotation"]: - assert (bbox[0] == 0 or bbox[0] == 1 or bbox[0] == 2 or bbox[0] == 3 or bbox[0] == 4 or bbox[0] == 5) - count[bbox[0]] += 1 + assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 2 or int(bbox[6]) == 3 + or int(bbox[6]) == 4 or int(bbox[6]) == 5) + count[int(bbox[6])] += 1 num += 1 assert num == 9 assert count == [3, 2, 1, 2, 4, 3] diff --git a/tests/ut/python/dataset/test_exceptions.py b/tests/ut/python/dataset/test_exceptions.py index cbfa402bb0..253eb564ae 100644 --- a/tests/ut/python/dataset/test_exceptions.py +++ b/tests/ut/python/dataset/test_exceptions.py @@ -28,9 +28,9 @@ def test_exception_01(): """ logger.info("test_exception_01") data = ds.TFRecordDataset(DATA_DIR, columns_list=["image"]) - with pytest.raises(ValueError) as info: - data = data.map(input_columns=["image"], operations=vision.Resize(100, 100)) - assert "Invalid interpolation mode." in str(info.value) + with pytest.raises(TypeError) as info: + data.map(input_columns=["image"], operations=vision.Resize(100, 100)) + assert "Argument interpolation with value 100 is not of type (,)" in str(info.value) def test_exception_02(): @@ -40,8 +40,8 @@ def test_exception_02(): logger.info("test_exception_02") num_samples = -1 with pytest.raises(ValueError) as info: - data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) - assert "num_samples cannot be less than 0" in str(info.value) + ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) + assert 'Input num_samples is not within the required interval of (0 to 2147483647).' in str(info.value) num_samples = 1 data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) diff --git a/tests/ut/python/dataset/test_fill_op.py b/tests/ut/python/dataset/test_fill_op.py index f138dd15ec..657a529723 100644 --- a/tests/ut/python/dataset/test_fill_op.py +++ b/tests/ut/python/dataset/test_fill_op.py @@ -82,9 +82,9 @@ def test_fillop_error_handling(): data = data.map(input_columns=["col"], operations=fill_op) with pytest.raises(RuntimeError) as error_info: - for data_row in data: - print(data_row) - assert "Types do not match" in repr(error_info.value) + for _ in data: + pass + assert "Types do not match" in str(error_info.value) if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_from_dataset.py b/tests/ut/python/dataset/test_from_dataset.py index 207a6be6a1..983052ea08 100644 --- a/tests/ut/python/dataset/test_from_dataset.py +++ b/tests/ut/python/dataset/test_from_dataset.py @@ -23,9 +23,10 @@ import mindspore.dataset.text as text def test_demo_basic_from_dataset(): """ this is a tutorial on how from_dataset should be used in a normal use case""" data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) - vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, special_tokens=["", ""], + vocab = text.Vocab.from_dataset(data, "text", freq_range=None, top_k=None, + special_tokens=["", ""], special_first=True) - data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) + data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "")) res = [] for d in data.create_dict_iterator(): res.append(d["text"].item()) @@ -38,7 +39,7 @@ def test_demo_basic_from_dataset_with_tokenizer(): data = data.map(input_columns=["text"], operations=text.UnicodeCharTokenizer()) vocab = text.Vocab.from_dataset(data, None, freq_range=None, top_k=None, special_tokens=["", ""], special_first=True) - data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) + data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "")) res = [] for d in data.create_dict_iterator(): res.append(list(d["text"])) @@ -59,7 +60,7 @@ def test_from_dataset(): corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"]) vocab = text.Vocab.from_dataset(corpus_dataset, None, freq_range, top_k, special_tokens=["", ""], special_first=True) - corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab)) + corpus_dataset = corpus_dataset.map(input_columns="text", operations=text.Lookup(vocab, "")) res = [] for d in corpus_dataset.create_dict_iterator(): res.append(list(d["text"])) @@ -107,7 +108,7 @@ def test_from_dataset_special_token(): corpus_dataset = ds.GeneratorDataset(gen_corpus, column_names=["text"]) vocab = text.Vocab.from_dataset(corpus_dataset, None, None, top_k, special_tokens, special_first) data = ds.GeneratorDataset(gen_input(texts), column_names=["text"]) - data = data.map(input_columns="text", operations=text.Lookup(vocab)) + data = data.map(input_columns="text", operations=text.Lookup(vocab, "")) res = [] for d in data.create_dict_iterator(): res.append(d["text"].item()) @@ -127,15 +128,16 @@ def test_from_dataset_exceptions(): data = ds.TextFileDataset("../data/dataset/testVocab/words.txt", shuffle=False) vocab = text.Vocab.from_dataset(data, columns, freq_range, top_k) assert isinstance(vocab.text.Vocab) - except ValueError as e: + except (TypeError, ValueError) as e: assert s in str(e), str(e) - test_config("text", (), 1, "freq_range needs to be either None or a tuple of 2 integers") - test_config("text", (2, 3), 1.2345, "top_k needs to be a positive integer") - test_config(23, (2, 3), 1.2345, "columns need to be a list of strings") - test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b") - test_config("text", (2, 3), 0, "top_k needs to be a positive integer") - test_config([123], (2, 3), 0, "columns need to be a list of strings") + test_config("text", (), 1, "freq_range needs to be a tuple of 2 integers or an int and a None.") + test_config("text", (2, 3), 1.2345, + "Argument top_k with value 1.2345 is not of type (, )") + test_config(23, (2, 3), 1.2345, "Argument col_0 with value 23 is not of type (,)") + test_config("text", (100, 1), 12, "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)") + test_config("text", (2, 3), 0, "top_k must be greater than 0") + test_config([123], (2, 3), -1, "top_k must be greater than 0") if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 4083336623..0f78cfd03a 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -23,6 +23,10 @@ SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" def test_graphdata_getfullneighbor(): + """ + Test get all neighbors + """ + logger.info('test get all neighbors.\n') g = ds.GraphData(DATASET_FILE, 2) nodes = g.get_all_nodes(1) assert len(nodes) == 10 @@ -33,6 +37,10 @@ def test_graphdata_getfullneighbor(): def test_graphdata_getnodefeature_input_check(): + """ + Test get node feature input check + """ + logger.info('test getnodefeature input check.\n') g = ds.GraphData(DATASET_FILE) with pytest.raises(TypeError): input_list = [1, [1, 1]] @@ -80,6 +88,10 @@ def test_graphdata_getnodefeature_input_check(): def test_graphdata_getsampledneighbors(): + """ + Test sampled neighbors + """ + logger.info('test get sampled neighbors.\n') g = ds.GraphData(DATASET_FILE, 1) edges = g.get_all_edges(0) nodes = g.get_nodes_from_edges(edges) @@ -90,6 +102,10 @@ def test_graphdata_getsampledneighbors(): def test_graphdata_getnegsampledneighbors(): + """ + Test neg sampled neighbors + """ + logger.info('test get negative sampled neighbors.\n') g = ds.GraphData(DATASET_FILE, 2) nodes = g.get_all_nodes(1) assert len(nodes) == 10 @@ -98,6 +114,10 @@ def test_graphdata_getnegsampledneighbors(): def test_graphdata_graphinfo(): + """ + Test graph info + """ + logger.info('test graph info.\n') g = ds.GraphData(DATASET_FILE, 2) graph_info = g.graph_info() assert graph_info['node_type'] == [1, 2] @@ -105,7 +125,7 @@ def test_graphdata_graphinfo(): assert graph_info['node_num'] == {1: 10, 2: 10} assert graph_info['edge_num'] == {0: 40} assert graph_info['node_feature_type'] == [1, 2, 3, 4] - assert graph_info['edge_feature_type'] == [] + assert graph_info['edge_feature_type'] == [1, 2] class RandomBatchedSampler(ds.Sampler): @@ -155,6 +175,10 @@ class GNNGraphDataset(): def test_graphdata_generatordataset(): + """ + Test generator dataset + """ + logger.info('test generator dataset.\n') g = ds.GraphData(DATASET_FILE) batch_num = 2 edge_num = g.graph_info()['edge_num'][0] @@ -173,10 +197,13 @@ def test_graphdata_generatordataset(): assert i == 40 -def test_graphdata_randomwalk(): +def test_graphdata_randomwalkdefault(): + """ + Test random walk defaults + """ + logger.info('test randomwalk with default parameters.\n') g = ds.GraphData(SOCIAL_DATA_FILE, 1) nodes = g.get_all_nodes(1) - print(len(nodes)) assert len(nodes) == 33 meta_path = [1 for _ in range(39)] @@ -184,18 +211,39 @@ def test_graphdata_randomwalk(): assert walks.shape == (33, 40) +def test_graphdata_randomwalk(): + """ + Test random walk + """ + logger.info('test random walk with given parameters.\n') + g = ds.GraphData(SOCIAL_DATA_FILE, 1) + nodes = g.get_all_nodes(1) + assert len(nodes) == 33 + + meta_path = [1 for _ in range(39)] + walks = g.random_walk(nodes, meta_path, 2.0, 0.5, -1) + assert walks.shape == (33, 40) + + +def test_graphdata_getedgefeature(): + """ + Test get edge feature + """ + logger.info('test get_edge_feature.\n') + g = ds.GraphData(DATASET_FILE) + edges = g.get_all_edges(0) + features = g.get_edge_feature(edges, [1, 2]) + assert features[0].shape == (40,) + assert features[1].shape == (40,) + + if __name__ == '__main__': test_graphdata_getfullneighbor() - logger.info('test_graphdata_getfullneighbor Ended.\n') test_graphdata_getnodefeature_input_check() - logger.info('test_graphdata_getnodefeature_input_check Ended.\n') test_graphdata_getsampledneighbors() - logger.info('test_graphdata_getsampledneighbors Ended.\n') test_graphdata_getnegsampledneighbors() - logger.info('test_graphdata_getnegsampledneighbors Ended.\n') test_graphdata_graphinfo() - logger.info('test_graphdata_graphinfo Ended.\n') test_graphdata_generatordataset() - logger.info('test_graphdata_generatordataset Ended.\n') + test_graphdata_randomwalkdefault() test_graphdata_randomwalk() - logger.info('test_graphdata_randomwalk Ended.\n') + test_graphdata_getedgefeature() diff --git a/tests/ut/python/dataset/test_linear_transformation.py b/tests/ut/python/dataset/test_linear_transformation.py index 0dd25a4da1..f932916ed8 100644 --- a/tests/ut/python/dataset/test_linear_transformation.py +++ b/tests/ut/python/dataset/test_linear_transformation.py @@ -73,6 +73,7 @@ def test_linear_transformation_op(plot=False): if plot: visualize_list(image, image_transformed) + def test_linear_transformation_md5(): """ Test LinearTransformation op: valid params (transformation_matrix, mean_vector) @@ -102,6 +103,7 @@ def test_linear_transformation_md5(): filename = "linear_transformation_01_result.npz" save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN) + def test_linear_transformation_exception_01(): """ Test LinearTransformation op: transformation_matrix is not provided @@ -126,9 +128,10 @@ def test_linear_transformation_exception_01(): ] transform = py_vision.ComposeOp(transforms) data1 = data1.map(input_columns=["image"], operations=transform()) - except ValueError as e: + except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "not provided" in str(e) + assert "Argument transformation_matrix with value None is not of type (,)" in str(e) + def test_linear_transformation_exception_02(): """ @@ -154,9 +157,10 @@ def test_linear_transformation_exception_02(): ] transform = py_vision.ComposeOp(transforms) data1 = data1.map(input_columns=["image"], operations=transform()) - except ValueError as e: + except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "not provided" in str(e) + assert "Argument mean_vector with value None is not of type (,)" in str(e) + def test_linear_transformation_exception_03(): """ @@ -187,6 +191,7 @@ def test_linear_transformation_exception_03(): logger.info("Got an exception in DE: {}".format(str(e))) assert "square matrix" in str(e) + def test_linear_transformation_exception_04(): """ Test LinearTransformation op: mean_vector does not match dimension of transformation_matrix @@ -199,7 +204,7 @@ def test_linear_transformation_exception_04(): weight = 50 dim = 3 * height * weight transformation_matrix = np.ones([dim, dim]) - mean_vector = np.zeros(dim-1) + mean_vector = np.zeros(dim - 1) # Generate dataset data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) @@ -216,6 +221,7 @@ def test_linear_transformation_exception_04(): logger.info("Got an exception in DE: {}".format(str(e))) assert "should match" in str(e) + if __name__ == '__main__': test_linear_transformation_op(plot=True) test_linear_transformation_md5() diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index b15944d76b..0b4d0dfc8f 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -184,24 +184,26 @@ def test_minddataset_invalidate_num_shards(): create_cv_mindrecord(1) columns_list = ["data", "label"] num_readers = 4 - with pytest.raises(Exception, match="shard_id is invalid, "): + with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 + assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) + os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) - def test_minddataset_invalidate_shard_id(): create_cv_mindrecord(1) columns_list = ["data", "label"] num_readers = 4 - with pytest.raises(Exception, match="shard_id is invalid, "): + with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 + assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) @@ -210,17 +212,19 @@ def test_minddataset_shard_id_bigger_than_num_shard(): create_cv_mindrecord(1) columns_list = ["data", "label"] num_readers = 4 - with pytest.raises(Exception, match="shard_id is invalid, "): + with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 + assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) - with pytest.raises(Exception, match="shard_id is invalid, "): + with pytest.raises(Exception) as error_info: data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5) num_iter = 0 for _ in data_set.create_dict_iterator(): num_iter += 1 + assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info) os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) diff --git a/tests/ut/python/dataset/test_ngram_op.py b/tests/ut/python/dataset/test_ngram_op.py index 73b2702378..777fca8764 100644 --- a/tests/ut/python/dataset/test_ngram_op.py +++ b/tests/ut/python/dataset/test_ngram_op.py @@ -15,9 +15,9 @@ """ Testing Ngram in mindspore.dataset """ +import numpy as np import mindspore.dataset as ds import mindspore.dataset.text as text -import numpy as np def test_multiple_ngrams(): @@ -61,7 +61,7 @@ def test_simple_ngram(): yield (np.array(line.split(" "), dtype='S'),) dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"]) - dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=None)) + dataset = dataset.map(input_columns=["text"], operations=text.Ngram(3, separator=" ")) i = 0 for data in dataset.create_dict_iterator(): @@ -72,43 +72,36 @@ def test_simple_ngram(): def test_corner_cases(): """ testing various corner cases and exceptions""" - def test_config(input_line, output_line, n, l_pad=None, r_pad=None, sep=None): + def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "): def gen(texts): yield (np.array(texts.split(" "), dtype='S'),) - dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"]) - dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep)) - for data in dataset.create_dict_iterator(): - assert [d.decode("utf8") for d in data["text"]] == output_line, output_line + try: + dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"]) + dataset = dataset.map(input_columns=["text"], operations=text.Ngram(n, l_pad, r_pad, separator=sep)) + for data in dataset.create_dict_iterator(): + return [d.decode("utf8") for d in data["text"]] + except (ValueError, TypeError) as e: + return str(e) # test tensor length smaller than n - test_config("Lone Star", ["Lone Star", "", "", ""], [2, 3, 4, 5]) + assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""] # test empty separator - test_config("Beautiful British Columbia", ['BeautifulBritish', 'BritishColumbia'], 2, sep="") + assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia'] # test separator with longer length - test_config("Beautiful British Columbia", ['Beautiful^-^British^-^Columbia'], 3, sep="^-^") + assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia'] # test left pad != right pad - test_config("Lone Star", ['The Lone Star State'], 4, ("The", 1), ("State", 1)) + assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State'] # test invalid n - try: - test_config("Yours to Discover", "", [0, [1]]) - except Exception as e: - assert "ngram needs to be a positive number" in str(e) - # test empty n - try: - test_config("Yours to Discover", "", []) - except Exception as e: - assert "n needs to be a non-empty list" in str(e) - # test invalid pad - try: - test_config("Yours to Discover", "", [1], ("str", -1)) - except Exception as e: - assert "padding width need to be positive numbers" in str(e) + assert "gram[1] with value [1] is not of type (,)" in test_config("Yours to Discover", [1, [1]]) + assert "n needs to be a non-empty list" in test_config("Yours to Discover", []) # test invalid pad - try: - test_config("Yours to Discover", "", [1], ("str", "rts")) - except Exception as e: - assert "pad needs to be a tuple of (str, int)" in str(e) + assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1)) + assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts")) + # test 0 as in valid input + assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0) + assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0]) + assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0]) if __name__ == '__main__': diff --git a/tests/ut/python/dataset/test_nlp.py b/tests/ut/python/dataset/test_nlp.py index 6b44cfc80b..cb517160a1 100644 --- a/tests/ut/python/dataset/test_nlp.py +++ b/tests/ut/python/dataset/test_nlp.py @@ -34,13 +34,32 @@ def test_on_tokenized_line(): jieba_op.add_word(word) data = data.map(input_columns=["text"], operations=jieba_op) vocab = text.Vocab.from_file(VOCAB_FILE, ",", special_tokens=["", ""]) - lookup = text.Lookup(vocab) + lookup = text.Lookup(vocab, "") data = data.map(input_columns=["text"], operations=lookup) res = np.array([[10, 1, 11, 1, 12, 1, 15, 1, 13, 1, 14], [11, 1, 12, 1, 10, 1, 14, 1, 13, 1, 15]], dtype=np.int32) for i, d in enumerate(data.create_dict_iterator()): - _ = (np.testing.assert_array_equal(d["text"], res[i]), i) + np.testing.assert_array_equal(d["text"], res[i]) + + +def test_on_tokenized_line_with_no_special_tokens(): + data = ds.TextFileDataset("../data/dataset/testVocab/lines.txt", shuffle=False) + jieba_op = text.JiebaTokenizer(HMM_FILE, MP_FILE, mode=text.JiebaMode.MP) + with open(VOCAB_FILE, 'r') as f: + for line in f: + word = line.split(',')[0] + jieba_op.add_word(word) + + data = data.map(input_columns=["text"], operations=jieba_op) + vocab = text.Vocab.from_file(VOCAB_FILE, ",") + lookup = text.Lookup(vocab, "not") + data = data.map(input_columns=["text"], operations=lookup) + res = np.array([[8, 0, 9, 0, 10, 0, 13, 0, 11, 0, 12], + [9, 0, 10, 0, 8, 0, 12, 0, 11, 0, 13]], dtype=np.int32) + for i, d in enumerate(data.create_dict_iterator()): + np.testing.assert_array_equal(d["text"], res[i]) if __name__ == '__main__': test_on_tokenized_line() + test_on_tokenized_line_with_no_special_tokens() diff --git a/tests/ut/python/dataset/test_nlp_jieop.py b/tests/ut/python/dataset/test_nlp_jieop.py deleted file mode 100644 index 1ab53205d0..0000000000 --- a/tests/ut/python/dataset/test_nlp_jieop.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import numpy as np -import mindspore.dataset as ds -from mindspore.dataset.text import JiebaTokenizer -from mindspore.dataset.text import JiebaMode, to_str - -DATA_FILE = "../data/dataset/testJiebaDataset/3.txt" -DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*" - -HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8" -MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8" - - -def test_jieba_1(): - """Test jieba tokenizer with MP mode""" - data = ds.TextFileDataset(DATA_FILE) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=1) - expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] - ret = [] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_1_1(): - """Test jieba tokenizer with HMM mode""" - data = ds.TextFileDataset(DATA_FILE) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.HMM) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=1) - expect = ['今天', '天气', '太', '好', '了', '我们', '一起', '去', '外面', '玩', '吧'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_1_2(): - """Test jieba tokenizer with HMM MIX""" - data = ds.TextFileDataset(DATA_FILE) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MIX) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=1) - expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_2(): - """Test add_word""" - DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" - data = ds.TextFileDataset(DATA_FILE4) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - jieba_op.add_word("男默女泪") - expect = ['男默女泪', '市', '长江大桥'] - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=2) - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_2_1(): - """Test add_word with freq""" - DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" - data = ds.TextFileDataset(DATA_FILE4) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - jieba_op.add_word("男默女泪", 10) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=2) - expect = ['男默女泪', '市', '长江大桥'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_2_2(): - """Test add_word with invalid None Input""" - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - try: - jieba_op.add_word(None) - except ValueError: - pass - - -def test_jieba_2_3(): - """Test add_word with freq, the value of freq affects the result of segmentation""" - DATA_FILE4 = "../data/dataset/testJiebaDataset/6.txt" - data = ds.TextFileDataset(DATA_FILE4) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - jieba_op.add_word("江大桥", 20000) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=2) - expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_3(): - """Test add_dict with dict""" - DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" - user_dict = { - "男默女泪": 10 - } - data = ds.TextFileDataset(DATA_FILE4) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - jieba_op.add_dict(user_dict) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=1) - expect = ['男默女泪', '市', '长江大桥'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_3_1(): - """Test add_dict with dict""" - DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" - user_dict = { - "男默女泪": 10, - "江大桥": 20000 - } - data = ds.TextFileDataset(DATA_FILE4) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - jieba_op.add_dict(user_dict) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=1) - expect = ['男默女泪', '市长', '江大桥'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_4(): - DATA_FILE4 = "../data/dataset/testJiebaDataset/3.txt" - DICT_FILE = "../data/dataset/testJiebaDataset/user_dict.txt" - - data = ds.TextFileDataset(DATA_FILE4) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - jieba_op.add_dict(DICT_FILE) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=1) - expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def test_jieba_4_1(): - """Test add dict with invalid file path""" - DICT_FILE = "" - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - try: - jieba_op.add_dict(DICT_FILE) - except ValueError: - pass - - -def test_jieba_5(): - """Test add dict with file path""" - DATA_FILE4 = "../data/dataset/testJiebaDataset/6.txt" - - data = ds.TextFileDataset(DATA_FILE4) - jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) - jieba_op.add_word("江大桥", 20000) - data = data.map(input_columns=["text"], - operations=jieba_op, num_parallel_workers=1) - expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -def gen(): - text = np.array("今天天气太好了我们一起去外面玩吧".encode("UTF8"), dtype='S') - yield (text,) - - -def pytoken_op(input_data): - te = str(to_str(input_data)) - tokens = [] - tokens.append(te[:5].encode("UTF8")) - tokens.append(te[5:10].encode("UTF8")) - tokens.append(te[10:].encode("UTF8")) - return np.array(tokens, dtype='S') - - -def test_jieba_6(): - data = ds.GeneratorDataset(gen, column_names=["text"]) - data = data.map(input_columns=["text"], - operations=pytoken_op, num_parallel_workers=1) - expect = ['今天天气太', '好了我们一', '起去外面玩吧'] - for i in data.create_dict_iterator(): - ret = to_str(i["text"]) - for index, item in enumerate(ret): - assert item == expect[index] - - -if __name__ == "__main__": - test_jieba_1() - test_jieba_1_1() - test_jieba_1_2() - test_jieba_2() - test_jieba_2_1() - test_jieba_2_2() - test_jieba_3() - test_jieba_3_1() - test_jieba_4() - test_jieba_4_1() - test_jieba_5() - test_jieba_5() - test_jieba_6() diff --git a/tests/ut/python/dataset/test_noop_mode.py b/tests/ut/python/dataset/test_noop_mode.py new file mode 100644 index 0000000000..0ea9673200 --- /dev/null +++ b/tests/ut/python/dataset/test_noop_mode.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Test No-op mode support with Dummy Iterator +""" +import os +import mindspore.dataset as ds + +DATA_DIR = "../data/dataset/testVOC2012" + +def test_noop_pserver(): + os.environ['MS_ROLE'] = 'MS_PSERVER' + data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False) + num = 0 + for _ in data1.create_dict_iterator(): + num += 1 + assert num == 0 + del os.environ['MS_ROLE'] + + +def test_noop_sched(): + os.environ['MS_ROLE'] = 'MS_SCHED' + data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False) + num = 0 + for _ in data1.create_dict_iterator(): + num += 1 + assert num == 0 + del os.environ['MS_ROLE'] + + +if __name__ == '__main__': + test_noop_pserver() + test_noop_sched() diff --git a/tests/ut/python/dataset/test_normalizeOp.py b/tests/ut/python/dataset/test_normalizeOp.py index af97ee0c08..d5ebc799f9 100644 --- a/tests/ut/python/dataset/test_normalizeOp.py +++ b/tests/ut/python/dataset/test_normalizeOp.py @@ -279,7 +279,7 @@ def test_normalize_exception_invalid_range_py(): _ = py_vision.Normalize([0.75, 1.25, 0.5], [0.1, 0.18, 1.32]) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not within the required range" in str(e) + assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e) def test_normalize_grayscale_md5_01(): diff --git a/tests/ut/python/dataset/test_onehot_op.py b/tests/ut/python/dataset/test_onehot_op.py index 500f770b9b..44d98b0ae0 100644 --- a/tests/ut/python/dataset/test_onehot_op.py +++ b/tests/ut/python/dataset/test_onehot_op.py @@ -13,12 +13,13 @@ # limitations under the License. # ============================================================================== """ -Testing the one_hot op in DE +Testing the OneHot Op """ import numpy as np import mindspore.dataset as ds import mindspore.dataset.transforms.c_transforms as data_trans +import mindspore.dataset.transforms.vision.c_transforms as c_vision from mindspore import log as logger from util import diff_mse @@ -37,15 +38,15 @@ def one_hot(index, depth): def test_one_hot(): """ - Test one_hot + Test OneHot Tensor Operator """ - logger.info("Test one_hot") + logger.info("test_one_hot") depth = 10 # First dataset data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) - one_hot_op = data_trans.OneHot(depth) + one_hot_op = data_trans.OneHot(num_classes=depth) data1 = data1.map(input_columns=["label"], operations=one_hot_op, columns_order=["label"]) # Second dataset @@ -58,8 +59,54 @@ def test_one_hot(): label2 = one_hot(item2["label"][0], depth) mse = diff_mse(label1, label2) logger.info("DE one_hot: {}, Numpy one_hot: {}, diff: {}".format(label1, label2, mse)) + assert mse == 0 num_iter += 1 + assert num_iter == 3 + +def test_one_hot_post_aug(): + """ + Test One Hot Encoding after Multiple Data Augmentation Operators + """ + logger.info("test_one_hot_post_aug") + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + + # Define data augmentation parameters + rescale = 1.0 / 255.0 + shift = 0.0 + resize_height, resize_width = 224, 224 + + # Define map operations + decode_op = c_vision.Decode() + rescale_op = c_vision.Rescale(rescale, shift) + resize_op = c_vision.Resize((resize_height, resize_width)) + + # Apply map operations on images + data1 = data1.map(input_columns=["image"], operations=decode_op) + data1 = data1.map(input_columns=["image"], operations=rescale_op) + data1 = data1.map(input_columns=["image"], operations=resize_op) + + # Apply one-hot encoding on labels + depth = 4 + one_hot_encode = data_trans.OneHot(depth) + data1 = data1.map(input_columns=["label"], operations=one_hot_encode) + + # Apply datasets ops + buffer_size = 100 + seed = 10 + batch_size = 2 + ds.config.set_seed(seed) + data1 = data1.shuffle(buffer_size=buffer_size) + data1 = data1.batch(batch_size, drop_remainder=True) + + num_iter = 0 + for item in data1.create_dict_iterator(): + logger.info("image is: {}".format(item["image"])) + logger.info("label is: {}".format(item["label"])) + num_iter += 1 + + assert num_iter == 1 if __name__ == "__main__": test_one_hot() + test_one_hot_post_aug() diff --git a/tests/ut/python/dataset/test_pad_end_op.py b/tests/ut/python/dataset/test_pad_end_op.py index 5742d73665..c25d6b9a95 100644 --- a/tests/ut/python/dataset/test_pad_end_op.py +++ b/tests/ut/python/dataset/test_pad_end_op.py @@ -61,6 +61,10 @@ def test_pad_end_exceptions(): pad_compare([3, 4, 5], ["2"], 1, []) assert "a value in the list is not an integer." in str(info.value) + with pytest.raises(TypeError) as info: + pad_compare([1, 2], 3, -1, [1, 2, -1]) + assert "Argument pad_end with value 3 is not of type (,)" in str(info.value) + if __name__ == "__main__": test_pad_end_basics() diff --git a/tests/ut/python/dataset/test_random_affine.py b/tests/ut/python/dataset/test_random_affine.py index b856684ed1..ec829eb53a 100644 --- a/tests/ut/python/dataset/test_random_affine.py +++ b/tests/ut/python/dataset/test_random_affine.py @@ -103,7 +103,7 @@ def test_random_affine_exception_negative_degrees(): _ = py_vision.RandomAffine(degrees=-15) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "If degrees is a single number, it cannot be negative." + assert str(e) == "Input degrees is not within the required interval of (0 to inf)." def test_random_affine_exception_translation_range(): @@ -115,7 +115,7 @@ def test_random_affine_exception_translation_range(): _ = py_vision.RandomAffine(degrees=15, translate=(0.1, 1.5)) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "translation values should be between 0 and 1" + assert str(e) == "Input translate at 1 is not within the required interval of (0.0 to 1.0)." def test_random_affine_exception_scale_value(): @@ -127,7 +127,7 @@ def test_random_affine_exception_scale_value(): _ = py_vision.RandomAffine(degrees=15, scale=(0.0, 1.1)) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "scale values should be positive" + assert str(e) == "Input scale[0] must be greater than 0." def test_random_affine_exception_shear_value(): @@ -139,7 +139,7 @@ def test_random_affine_exception_shear_value(): _ = py_vision.RandomAffine(degrees=15, shear=-5) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "If shear is a single number, it must be positive." + assert str(e) == "Input shear must be greater than 0." def test_random_affine_exception_degrees_size(): @@ -165,7 +165,9 @@ def test_random_affine_exception_translate_size(): _ = py_vision.RandomAffine(degrees=15, translate=(0.1)) except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "translate should be a list or tuple of length 2." + assert str( + e) == "Argument translate with value 0.1 is not of type (," \ + " )." def test_random_affine_exception_scale_size(): @@ -178,7 +180,8 @@ def test_random_affine_exception_scale_size(): _ = py_vision.RandomAffine(degrees=15, scale=(0.5)) except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "scale should be a list or tuple of length 2." + assert str(e) == "Argument scale with value 0.5 is not of type (," \ + " )." def test_random_affine_exception_shear_size(): @@ -191,7 +194,7 @@ def test_random_affine_exception_shear_size(): _ = py_vision.RandomAffine(degrees=15, shear=(-5, 5, 10)) except TypeError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "shear should be a list or tuple and it must be of length 2 or 4." + assert str(e) == "shear must be of length 2 or 4." if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_random_color.py b/tests/ut/python/dataset/test_random_color.py index 45847ba653..0015e8498f 100644 --- a/tests/ut/python/dataset/test_random_color.py +++ b/tests/ut/python/dataset/test_random_color.py @@ -97,7 +97,7 @@ def test_random_color_md5(): data = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) transforms = F.ComposeOp([F.Decode(), - F.RandomColor((0.5, 1.5)), + F.RandomColor((0.1, 1.9)), F.ToTensor()]) data = data.map(input_columns="image", operations=transforms()) diff --git a/tests/ut/python/dataset/test_random_crop_and_resize.py b/tests/ut/python/dataset/test_random_crop_and_resize.py index de039e6d82..486d2cd5ed 100644 --- a/tests/ut/python/dataset/test_random_crop_and_resize.py +++ b/tests/ut/python/dataset/test_random_crop_and_resize.py @@ -232,7 +232,7 @@ def test_random_crop_and_resize_04_c(): data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input range is not valid" in str(e) + assert "Input is not within the required interval of (0 to 16777216)." in str(e) def test_random_crop_and_resize_04_py(): @@ -255,7 +255,7 @@ def test_random_crop_and_resize_04_py(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input range is not valid" in str(e) + assert "Input is not within the required interval of (0 to 16777216)." in str(e) def test_random_crop_and_resize_05_c(): @@ -275,7 +275,7 @@ def test_random_crop_and_resize_05_c(): data = data.map(input_columns=["image"], operations=random_crop_and_resize_op) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input range is not valid" in str(e) + assert "Input is not within the required interval of (0 to 16777216)." in str(e) def test_random_crop_and_resize_05_py(): @@ -298,7 +298,7 @@ def test_random_crop_and_resize_05_py(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input range is not valid" in str(e) + assert "Input is not within the required interval of (0 to 16777216)." in str(e) def test_random_crop_and_resize_comp(plot=False): diff --git a/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py b/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py index b13dc466f7..599acc9560 100644 --- a/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_random_crop_and_resize_with_bbox.py @@ -25,34 +25,16 @@ from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, GENERATE_GOLDEN = False -# updated VOC dataset with correct annotations -DATA_DIR = "../data/dataset/testVOC2012_2" - - -def fix_annotate(bboxes): - """ - Fix annotations to format followed by mindspore. - :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format - :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format - """ - for bbox in bboxes: - if bbox.size == 7: - tmp = bbox[0] - bbox[0] = bbox[1] - bbox[1] = bbox[2] - bbox[2] = bbox[3] - bbox[3] = bbox[4] - bbox[4] = tmp - else: - print("ERROR: Invalid Bounding Box size provided") - break - return bboxes +# Updated VOC dataset with correct annotations - DATA_DIR +DATA_DIR_VOC = "../data/dataset/testVOC2012_2" +# COCO dataset - DATA_DIR, ANNOTATION_DIR +DATA_DIR_COCO = ["../data/dataset/testCOCO/train/", "../data/dataset/testCOCO/annotations/train.json"] def test_random_resized_crop_with_bbox_op_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, - tests with MD5 check, expected to pass + Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, + tests with MD5 check, expected to pass """ logger.info("test_random_resized_crop_with_bbox_op_c") @@ -60,22 +42,16 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False): original_num_parallel_workers = config_get_set_num_parallel_workers(1) # Load dataset - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + operations=[test_op]) filename = "random_resized_crop_with_bbox_01_c_result.npz" save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) @@ -94,26 +70,49 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False): ds.config.set_num_parallel_workers(original_num_parallel_workers) +def test_random_resized_crop_with_bbox_op_coco_c(plot_vis=False): + """ + Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, + Testing with Coco dataset + """ + logger.info("test_random_resized_crop_with_bbox_op_coco_c") + # load dataset + dataCoco1 = ds.CocoDataset(DATA_DIR_COCO[0], annotation_file=DATA_DIR_COCO[1], task="Detection", + decode=True, shuffle=False) + + dataCoco2 = ds.CocoDataset(DATA_DIR_COCO[0], annotation_file=DATA_DIR_COCO[1], task="Detection", + decode=True, shuffle=False) + + test_op = c_vision.RandomResizedCropWithBBox((512, 512), (0.5, 1), (0.5, 1)) + + dataCoco2 = dataCoco2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataCoco1.create_dict_iterator(), dataCoco2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp, "bbox") + + def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, - tests on dynamically generated edge case, expected to pass + Prints images and bboxes side by side with and without RandomResizedCropWithBBox Op applied, + tests on dynamically generated edge case, expected to pass """ logger.info("test_random_resized_crop_with_bbox_op_edge_c") # Load dataset - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # maps to convert data into valid edge case data dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -138,20 +137,17 @@ def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False): def test_random_resized_crop_with_bbox_op_invalid_c(): """ - Tests RandomResizedCropWithBBox on invalid constructor parameters, expected to raise ValueError + Tests RandomResizedCropWithBBox on invalid constructor parameters, expected to raise ValueError """ logger.info("test_random_resized_crop_with_bbox_op_invalid_c") # Load dataset, only Augmented Dataset as test will raise ValueError - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) try: # If input range of scale is not in the order of (min, max), ValueError will be raised. test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5)) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -163,7 +159,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input range is not valid" in str(err) + assert "Input is not within the required interval of (0 to 16777216)." in str(err) def test_random_resized_crop_with_bbox_op_invalid2_c(): @@ -172,15 +168,12 @@ def test_random_resized_crop_with_bbox_op_invalid2_c(): """ logger.info("test_random_resized_crop_with_bbox_op_invalid2_c") # Load dataset # only loading the to AugDataset as test will fail on this - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) try: # If input range of ratio is not in the order of (min, max), ValueError will be raised. test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5)) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -192,7 +185,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input range is not valid" in str(err) + assert "Input is not within the required interval of (0 to 16777216)." in str(err) def test_random_resized_crop_with_bbox_op_bad_c(): @@ -202,18 +195,19 @@ def test_random_resized_crop_with_bbox_op_bad_c(): logger.info("test_random_resized_crop_with_bbox_op_bad_c") test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5)) - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") if __name__ == "__main__": test_random_resized_crop_with_bbox_op_c(plot_vis=True) + test_random_resized_crop_with_bbox_op_coco_c(plot_vis=True) test_random_resized_crop_with_bbox_op_edge_c(plot_vis=True) test_random_resized_crop_with_bbox_op_invalid_c() test_random_resized_crop_with_bbox_op_invalid2_c() diff --git a/tests/ut/python/dataset/test_random_crop_with_bbox.py b/tests/ut/python/dataset/test_random_crop_with_bbox.py index 9262dfd65d..b93c638f41 100644 --- a/tests/ut/python/dataset/test_random_crop_with_bbox.py +++ b/tests/ut/python/dataset/test_random_crop_with_bbox.py @@ -26,49 +26,25 @@ from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, GENERATE_GOLDEN = False -# updated VOC dataset with correct annotations -DATA_DIR = "../data/dataset/testVOC2012_2" - - -def fix_annotate(bboxes): - """ - Fix annotations to format followed by mindspore. - :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format - :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format - """ - for bbox in bboxes: - if bbox.size == 7: - tmp = bbox[0] - bbox[0] = bbox[1] - bbox[1] = bbox[2] - bbox[2] = bbox[3] - bbox[3] = bbox[4] - bbox[4] = tmp - else: - print("ERROR: Invalid Bounding Box size provided") - break - return bboxes +# Updated VOC dataset with correct annotations - DATA_DIR +DATA_DIR_VOC = "../data/dataset/testVOC2012_2" +# COCO dataset - DATA_DIR, ANNOTATION_DIR +DATA_DIR_COCO = ["../data/dataset/testCOCO/train/", "../data/dataset/testCOCO/annotations/train.json"] def test_random_crop_with_bbox_op_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomCropWithBBox Op applied + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied """ logger.info("test_random_crop_with_bbox_op_c") # Load dataset - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) # define test OP with values to match existing Op UT test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -85,33 +61,57 @@ def test_random_crop_with_bbox_op_c(plot_vis=False): visualize_with_bounding_boxes(unaugSamp, augSamp) +def test_random_crop_with_bbox_op_coco_c(plot_vis=False): + """ + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, + Testing with Coco dataset + """ + logger.info("test_random_crop_with_bbox_op_coco_c") + # load dataset + dataCoco1 = ds.CocoDataset(DATA_DIR_COCO[0], annotation_file=DATA_DIR_COCO[1], task="Detection", + decode=True, shuffle=False) + + dataCoco2 = ds.CocoDataset(DATA_DIR_COCO[0], annotation_file=DATA_DIR_COCO[1], task="Detection", + decode=True, shuffle=False) + + test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) + + dataCoco2 = dataCoco2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataCoco1.create_dict_iterator(), dataCoco2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp, "bbox") + + def test_random_crop_with_bbox_op2_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, - with md5 check, expected to pass + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, + with md5 check, expected to pass """ logger.info("test_random_crop_with_bbox_op2_c") original_seed = config_get_set_seed(593447) original_num_parallel_workers = config_get_set_num_parallel_workers(1) # Load dataset - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) # define test OP with values to match existing Op unit - test test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255)) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + operations=[test_op]) filename = "random_crop_with_bbox_01_c_result.npz" save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) @@ -132,29 +132,23 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False): def test_random_crop_with_bbox_op3_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, - with Padding Mode explicitly passed + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, + with Padding Mode explicitly passed """ logger.info("test_random_crop_with_bbox_op3_c") # Load dataset - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) # define test OP with values to match existing Op unit - test test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], - operations=[test_op]) # Add column for "annotation" + operations=[test_op]) unaugSamp, augSamp = [], [] @@ -168,25 +162,18 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False): def test_random_crop_with_bbox_op_edge_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, - applied on dynamically generated edge case, expected to pass + Prints images and bboxes side by side with and without RandomCropWithBBox Op applied, + applied on dynamically generated edge case, expected to pass """ logger.info("test_random_crop_with_bbox_op_edge_c") # Load dataset - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) # define test OP with values to match existing Op unit - test test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # maps to convert data into valid edge case data dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -216,16 +203,12 @@ def test_random_crop_with_bbox_op_invalid_c(): logger.info("test_random_crop_with_bbox_op_invalid_c") # Load dataset - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) try: # define test OP with values to match existing Op unit - test test_op = c_vision.RandomCropWithBBox([512, 512, 375]) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -246,18 +229,19 @@ def test_random_crop_with_bbox_op_bad_c(): logger.info("test_random_crop_with_bbox_op_bad_c") test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200]) - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") if __name__ == "__main__": test_random_crop_with_bbox_op_c(plot_vis=True) + test_random_crop_with_bbox_op_coco_c(plot_vis=True) test_random_crop_with_bbox_op2_c(plot_vis=True) test_random_crop_with_bbox_op3_c(plot_vis=True) test_random_crop_with_bbox_op_edge_c(plot_vis=True) diff --git a/tests/ut/python/dataset/test_random_dataset.py b/tests/ut/python/dataset/test_random_dataset.py index 4d50be254c..56a2a93113 100644 --- a/tests/ut/python/dataset/test_random_dataset.py +++ b/tests/ut/python/dataset/test_random_dataset.py @@ -16,17 +16,16 @@ import mindspore.common.dtype as mstype import mindspore.dataset as ds from mindspore import log as logger - # just a basic test with parallel random data op def test_randomdataset_basic1(): - logger.info("Test randomdataset basic") + logger.info("Test randomdataset basic 1") schema = ds.Schema() schema.add_column('image', de_type=mstype.uint8, shape=[2]) schema.add_column('label', de_type=mstype.uint8, shape=[1]) # apply dataset operations - ds1 = ds.RandomDataset(schema=schema, num_samples=50, num_parallel_workers=4) + ds1 = ds.RandomDataset(schema=schema, total_rows=50, num_parallel_workers=4) ds1 = ds1.repeat(4) num_iter = 0 @@ -36,8 +35,9 @@ def test_randomdataset_basic1(): logger.info("{} label: {}".format(num_iter, data["label"])) num_iter += 1 - logger.info("Number of data in ds1: ", num_iter) + logger.info("Number of data in ds1: {}".format(num_iter)) assert num_iter == 200 + logger.info("Test randomdataset basic 1 complete") # Another simple test @@ -49,10 +49,8 @@ def test_randomdataset_basic2(): shape=[640, 480, 3]) # 921600 bytes (a bit less than 1 MB per image) schema.add_column('label', de_type=mstype.uint8, shape=[1]) - # Make up about 10 samples - ds1 = ds.RandomDataset(schema=schema, num_samples=10, num_parallel_workers=1) - - # cache size allows for about 4 images since each image just a bit less than 1MB, after that we will have to spill + # Make up 10 rows + ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=1) ds1 = ds1.repeat(4) num_iter = 0 @@ -62,11 +60,31 @@ def test_randomdataset_basic2(): logger.info("printing the label: {}".format(data["label"])) num_iter += 1 - logger.info("Number of data in ds1: ", num_iter) + logger.info("Number of data in ds1: {}".format(num_iter)) assert num_iter == 40 + logger.info("Test randomdataset basic 2 complete") + + +# Another simple test +def test_randomdataset_basic3(): + logger.info("Test randomdataset basic 3") + + # Make up 10 samples, but here even the schema is randomly created + # The columns are named like this "c0", "c1", "c2" etc + # But, we will use a tuple iterator instead of dict iterator so the column names + # are not needed to iterate + ds1 = ds.RandomDataset(total_rows=10, num_parallel_workers=1) + ds1 = ds1.repeat(2) + + num_iter = 0 + for _ in ds1.create_tuple_iterator(): + num_iter += 1 + logger.info("Number of data in ds1: {}".format(num_iter)) + assert num_iter == 20 + logger.info("Test randomdataset basic 3 Complete") if __name__ == '__main__': test_randomdataset_basic1() test_randomdataset_basic2() - logger.info('test_randomdataset_basic Ended.\n') + test_randomdataset_basic3() diff --git a/tests/ut/python/dataset/test_random_grayscale.py b/tests/ut/python/dataset/test_random_grayscale.py index 83514a55f6..4cb25c3a3a 100644 --- a/tests/ut/python/dataset/test_random_grayscale.py +++ b/tests/ut/python/dataset/test_random_grayscale.py @@ -179,7 +179,7 @@ def test_random_grayscale_invalid_param(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not within the required range" in str(e) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) if __name__ == "__main__": test_random_grayscale_valid_prob(True) diff --git a/tests/ut/python/dataset/test_random_horizontal_flip.py b/tests/ut/python/dataset/test_random_horizontal_flip.py index 1272148e4f..ef4f5b8eb6 100644 --- a/tests/ut/python/dataset/test_random_horizontal_flip.py +++ b/tests/ut/python/dataset/test_random_horizontal_flip.py @@ -141,7 +141,7 @@ def test_random_horizontal_invalid_prob_c(): data = data.map(input_columns=["image"], operations=random_horizontal_op) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) def test_random_horizontal_invalid_prob_py(): @@ -164,7 +164,7 @@ def test_random_horizontal_invalid_prob_py(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(e) def test_random_horizontal_comp(plot=False): diff --git a/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py b/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py index 94ab843ce1..4fd51a7a03 100644 --- a/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py +++ b/tests/ut/python/dataset/test_random_horizontal_flip_with_bbox.py @@ -24,33 +24,15 @@ from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, GENERATE_GOLDEN = False +# updated VOC dataset with correct annotations DATA_DIR = "../data/dataset/testVOC2012_2" - - -def fix_annotate(bboxes): - """ - Fix annotations to format followed by mindspore. - :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format - :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format - """ - for bbox in bboxes: - if bbox.size == 7: - tmp = bbox[0] - bbox[0] = bbox[1] - bbox[1] = bbox[2] - bbox[2] = bbox[3] - bbox[3] = bbox[4] - bbox[4] = tmp - else: - print("ERROR: Invalid Bounding Box size provided") - break - return bboxes +DATA_DIR_2 = ["../data/dataset/testCOCO/train/", + "../data/dataset/testCOCO/annotations/train.json"] # DATA_DIR, ANNOTATION_DIR def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False): """ - Prints images side by side with and without Aug applied + bboxes to - compare and test + Prints images and bboxes side by side with and without RandomHorizontalFlipWithBBox Op applied """ logger.info("test_random_horizontal_flip_with_bbox_op_c") @@ -63,14 +45,6 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False): test_op = c_vision.RandomHorizontalFlipWithBBox(1) - # maps to fix annotations to minddata standard - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], @@ -86,7 +60,37 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False): visualize_with_bounding_boxes(unaugSamp, augSamp) -def test_random_horizontal_bbox_with_bbox_valid_rand_c(plot_vis=False): +def test_random_horizontal_flip_with_bbox_op_coco_c(plot_vis=False): + """ + Prints images and bboxes side by side with and without RandomHorizontalFlipWithBBox Op applied, + Testing with COCO dataset + """ + logger.info("test_random_horizontal_flip_with_bbox_op_coco_c") + + dataCoco1 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection", + decode=True, shuffle=False) + + dataCoco2 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection", + decode=True, shuffle=False) + + test_op = c_vision.RandomHorizontalFlipWithBBox(1) + + dataCoco2 = dataCoco2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataCoco1.create_dict_iterator(), dataCoco2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp, "bbox") + + +def test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False): """ Uses a valid non-default input, expect to pass Prints images side by side with and without Aug applied + bboxes to @@ -106,13 +110,6 @@ def test_random_horizontal_bbox_with_bbox_valid_rand_c(plot_vis=False): test_op = c_vision.RandomHorizontalFlipWithBBox(0.6) - # maps to fix annotations to minddata standard - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -148,25 +145,18 @@ def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False): test_op = c_vision.RandomHorizontalFlipWithBBox(1) - # maps to fix annotations to minddata standard - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops # Add column for "annotation" dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], operations=lambda img, bbox: - (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.uint32))) + (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], operations=lambda img, bbox: - (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.uint32))) + (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32))) dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], @@ -193,9 +183,6 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c(): try: # Note: Valid range of prob should be [0.0, 1.0] test_op = c_vision.RandomHorizontalFlipWithBBox(1.5) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -203,7 +190,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c(): operations=[test_op]) # Add column for "annotation" except ValueError as error: logger.info("Got an exception in DE: {}".format(str(error))) - assert "Input is not" in str(error) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error) def test_random_horizontal_flip_with_bbox_invalid_bounds_c(): @@ -227,7 +214,8 @@ def test_random_horizontal_flip_with_bbox_invalid_bounds_c(): if __name__ == "__main__": # set to false to not show plots test_random_horizontal_flip_with_bbox_op_c(plot_vis=False) - test_random_horizontal_bbox_with_bbox_valid_rand_c(plot_vis=False) + test_random_horizontal_flip_with_bbox_op_coco_c(plot_vis=False) + test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False) test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False) test_random_horizontal_flip_with_bbox_invalid_prob_c() test_random_horizontal_flip_with_bbox_invalid_bounds_c() diff --git a/tests/ut/python/dataset/test_random_perspective.py b/tests/ut/python/dataset/test_random_perspective.py index 507c9cdb80..992bf2b222 100644 --- a/tests/ut/python/dataset/test_random_perspective.py +++ b/tests/ut/python/dataset/test_random_perspective.py @@ -67,7 +67,7 @@ def test_random_perspective_op(plot=False): visualize_list(image_original, image_perspective) -def test_random_perspective_md5(): +def skip_test_random_perspective_md5(): """ Test RandomPerspective with md5 comparison """ @@ -107,7 +107,7 @@ def test_random_perspective_exception_distortion_scale_range(): _ = py_vision.RandomPerspective(distortion_scale=1.5) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "Input is not within the required range" + assert str(e) == "Input distortion_scale is not within the required interval of (0.0 to 1.0)." def test_random_perspective_exception_prob_range(): @@ -119,11 +119,11 @@ def test_random_perspective_exception_prob_range(): _ = py_vision.RandomPerspective(prob=1.2) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert str(e) == "Input is not within the required range" + assert str(e) == "Input prob is not within the required interval of (0.0 to 1.0)." if __name__ == "__main__": test_random_perspective_op(plot=True) - test_random_perspective_md5() + skip_test_random_perspective_md5() test_random_perspective_exception_distortion_scale_range() test_random_perspective_exception_prob_range() diff --git a/tests/ut/python/dataset/test_random_resize_with_bbox.py b/tests/ut/python/dataset/test_random_resize_with_bbox.py index 4aadf9ef01..94f9d12427 100644 --- a/tests/ut/python/dataset/test_random_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_random_resize_with_bbox.py @@ -26,32 +26,18 @@ from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, GENERATE_GOLDEN = False DATA_DIR = "../data/dataset/testVOC2012_2" +DATA_DIR_2 = ["../data/dataset/testCOCO/train/", + "../data/dataset/testCOCO/annotations/train.json"] # DATA_DIR, ANNOTATION_DIR -def fix_annotate(bboxes): +def test_random_resize_with_bbox_op_voc_c(plot_vis=False): """ - Fix annotations to format followed by mindspore. - :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format - :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format + Prints images and bboxes side by side with and without RandomResizeWithBBox Op applied + testing with VOC dataset """ - for (i, box) in enumerate(bboxes): - if box.size == 7: - bboxes[i] = np.roll(box, -1) - else: - print("ERROR: Invalid Bounding Box size provided") - break - return bboxes - - -def test_random_resize_with_bbox_op_rand_c(plot_vis=False): - """ - Prints images and bboxes side by side with and without RandomResizeWithBBox Op applied, - tests with MD5 check, expected to pass - """ - logger.info("test_random_resize_with_bbox_rand_c") - original_seed = config_get_set_seed(1) + logger.info("test_random_resize_with_bbox_op_voc_c") + original_seed = config_get_set_seed(123) original_num_parallel_workers = config_get_set_num_parallel_workers(1) - # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) @@ -59,21 +45,15 @@ def test_random_resize_with_bbox_op_rand_c(plot_vis=False): dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - test_op = c_vision.RandomResizeWithBBox(200) + test_op = c_vision.RandomResizeWithBBox(100) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], operations=[test_op]) - filename = "random_resize_with_bbox_op_01_c_result.npz" + filename = "random_resize_with_bbox_op_01_c_voc_result.npz" save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) unaugSamp, augSamp = [], [] @@ -90,6 +70,49 @@ def test_random_resize_with_bbox_op_rand_c(plot_vis=False): ds.config.set_num_parallel_workers(original_num_parallel_workers) +def test_random_resize_with_bbox_op_rand_coco_c(plot_vis=False): + """ + Prints images and bboxes side by side with and without RandomResizeWithBBox Op applied, + tests with MD5 check, expected to pass + testing with COCO dataset + """ + logger.info("test_random_resize_with_bbox_op_rand_coco_c") + original_seed = config_get_set_seed(231) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + + # Load dataset + dataCoco1 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection", + decode=True, shuffle=False) + + dataCoco2 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection", + decode=True, shuffle=False) + + test_op = c_vision.RandomResizeWithBBox(200) + + # map to apply ops + + dataCoco2 = dataCoco2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + filename = "random_resize_with_bbox_op_01_c_coco_result.npz" + save_and_check_md5(dataCoco2, filename, generate_golden=GENERATE_GOLDEN) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataCoco1.create_dict_iterator(), dataCoco2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp, annot_name="bbox") + + # Restore config setting + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + def test_random_resize_with_bbox_op_edge_c(plot_vis=False): """ Prints images and bboxes side by side with and without RandomresizeWithBBox Op applied, @@ -105,13 +128,6 @@ def test_random_resize_with_bbox_op_edge_c(plot_vis=False): test_op = c_vision.RandomResizeWithBBox(500) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # maps to convert data into valid edge case data dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -147,7 +163,7 @@ def test_random_resize_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input is not" in str(err) + assert "Input is not within the required interval of (1 to 16777216)." in str(err) try: # one of the size values is zero @@ -155,7 +171,7 @@ def test_random_resize_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input is not" in str(err) + assert "Input size at dim 0 is not within the required interval of (1 to 2147483647)." in str(err) try: # negative value for resize @@ -163,7 +179,7 @@ def test_random_resize_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input is not" in str(err) + assert "Input is not within the required interval of (1 to 16777216)." in str(err) try: # invalid input shape @@ -192,7 +208,8 @@ def test_random_resize_with_bbox_op_bad_c(): if __name__ == "__main__": - test_random_resize_with_bbox_op_rand_c(plot_vis=False) + test_random_resize_with_bbox_op_voc_c(plot_vis=False) + test_random_resize_with_bbox_op_rand_coco_c(plot_vis=False) test_random_resize_with_bbox_op_edge_c(plot_vis=False) test_random_resize_with_bbox_op_invalid_c() test_random_resize_with_bbox_op_bad_c() diff --git a/tests/ut/python/dataset/test_random_sharpness.py b/tests/ut/python/dataset/test_random_sharpness.py index d8207ff099..22e5c66f1a 100644 --- a/tests/ut/python/dataset/test_random_sharpness.py +++ b/tests/ut/python/dataset/test_random_sharpness.py @@ -97,7 +97,7 @@ def test_random_sharpness_md5(): # define map operations transforms = [ F.Decode(), - F.RandomSharpness((0.5, 1.5)), + F.RandomSharpness((0.1, 1.9)), F.ToTensor() ] transform = F.ComposeOp(transforms) diff --git a/tests/ut/python/dataset/test_random_vertical_flip.py b/tests/ut/python/dataset/test_random_vertical_flip.py index 2fc9b12774..a3d02959fd 100644 --- a/tests/ut/python/dataset/test_random_vertical_flip.py +++ b/tests/ut/python/dataset/test_random_vertical_flip.py @@ -141,7 +141,7 @@ def test_random_vertical_invalid_prob_c(): data = data.map(input_columns=["image"], operations=random_horizontal_op) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) + assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e) def test_random_vertical_invalid_prob_py(): @@ -163,7 +163,7 @@ def test_random_vertical_invalid_prob_py(): data = data.map(input_columns=["image"], operations=transform()) except ValueError as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "Input is not" in str(e) + assert 'Input prob is not within the required interval of (0.0 to 1.0).' in str(e) def test_random_vertical_comp(plot=False): diff --git a/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py b/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py index f746bd50b0..490dc3e419 100644 --- a/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py +++ b/tests/ut/python/dataset/test_random_vertical_flip_with_bbox.py @@ -25,50 +25,26 @@ from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, GENERATE_GOLDEN = False -# updated VOC dataset with correct annotations -DATA_DIR = "../data/dataset/testVOC2012_2" - - -def fix_annotate(bboxes): - """ - Fix annotations to format followed by mindspore. - :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format - :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format - """ - for bbox in bboxes: - if bbox.size == 7: - tmp = bbox[0] - bbox[0] = bbox[1] - bbox[1] = bbox[2] - bbox[2] = bbox[3] - bbox[3] = bbox[4] - bbox[4] = tmp - else: - print("ERROR: Invalid Bounding Box size provided") - break - return bboxes +# Updated VOC dataset with correct annotations - DATA_DIR +DATA_DIR_VOC = "../data/dataset/testVOC2012_2" +# COCO dataset - DATA_DIR, ANNOTATION_DIR +DATA_DIR_COCO = ["../data/dataset/testCOCO/train/", "../data/dataset/testCOCO/annotations/train.json"] def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied + Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied """ logger.info("test_random_vertical_flip_with_bbox_op_c") # Load dataset - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) test_op = c_vision.RandomVerticalFlipWithBBox(1) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -84,31 +60,56 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False): if plot_vis: visualize_with_bounding_boxes(unaugSamp, augSamp) +def test_random_vertical_flip_with_bbox_op_coco_c(plot_vis=False): + """ + Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, + Testing with Coco dataset + """ + logger.info("test_random_vertical_flip_with_bbox_op_coco_c") + # load dataset + dataCoco1 = ds.CocoDataset(DATA_DIR_COCO[0], annotation_file=DATA_DIR_COCO[1], task="Detection", + decode=True, shuffle=False) + + dataCoco2 = ds.CocoDataset(DATA_DIR_COCO[0], annotation_file=DATA_DIR_COCO[1], task="Detection", + decode=True, shuffle=False) + + test_op = c_vision.RandomVerticalFlipWithBBox(1) + + dataCoco2 = dataCoco2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + test_op = c_vision.RandomVerticalFlipWithBBox(1) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataCoco1.create_dict_iterator(), dataCoco2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp, "bbox") + def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, - tests with MD5 check, expected to pass + Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, + tests with MD5 check, expected to pass """ logger.info("test_random_vertical_flip_with_bbox_op_rand_c") original_seed = config_get_set_seed(29847) original_num_parallel_workers = config_get_set_num_parallel_workers(1) # Load dataset - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) test_op = c_vision.RandomVerticalFlipWithBBox(0.8) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -134,25 +135,18 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False): def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False): """ - Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, + Prints images and bboxes side by side with and without RandomVerticalFlipWithBBox Op applied, applied on dynamically generated edge case, expected to pass """ logger.info("test_random_vertical_flip_with_bbox_op_edge_c") - dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) test_op = c_vision.RandomVerticalFlipWithBBox(1) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # maps to convert data into valid edge case data dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -177,17 +171,15 @@ def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False): def test_random_vertical_flip_with_bbox_op_invalid_c(): """ - Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError + Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError """ logger.info("test_random_vertical_flip_with_bbox_op_invalid_c") - dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", + dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) try: test_op = c_vision.RandomVerticalFlipWithBBox(2) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) + # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -199,7 +191,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c(): except ValueError as err: logger.info("Got an exception in DE: {}".format(str(err))) - assert "Input is not" in str(err) + assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(err) def test_random_vertical_flip_with_bbox_op_bad_c(): @@ -209,18 +201,19 @@ def test_random_vertical_flip_with_bbox_op_bad_c(): logger.info("test_random_vertical_flip_with_bbox_op_bad_c") test_op = c_vision.RandomVerticalFlipWithBBox(1) - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x") - data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) + data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False) check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features") if __name__ == "__main__": test_random_vertical_flip_with_bbox_op_c(plot_vis=True) + test_random_vertical_flip_with_bbox_op_coco_c(plot_vis=True) test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=True) test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=True) test_random_vertical_flip_with_bbox_op_invalid_c() diff --git a/tests/ut/python/dataset/test_repeat.py b/tests/ut/python/dataset/test_repeat.py index 4bdde7beeb..ca4702ff8c 100644 --- a/tests/ut/python/dataset/test_repeat.py +++ b/tests/ut/python/dataset/test_repeat.py @@ -12,25 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +""" +Test Repeat Op +""" import numpy as np -from util import save_and_check import mindspore.dataset as ds import mindspore.dataset.transforms.vision.c_transforms as vision from mindspore import log as logger +from util import save_and_check_dict DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"] SCHEMA_DIR_TF = "../data/dataset/testTFTestAllTypes/datasetSchema.json" -COLUMNS_TF = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", - "col_sint16", "col_sint32", "col_sint64"] -GENERATE_GOLDEN = False - -IMG_DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] -IMG_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" +GENERATE_GOLDEN = False + def test_tf_repeat_01(): """ @@ -39,14 +38,13 @@ def test_tf_repeat_01(): logger.info("Test Simple Repeat") # define parameters repeat_count = 2 - parameters = {"params": {'repeat_count': repeat_count}} # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, shuffle=False) data1 = data1.repeat(repeat_count) filename = "repeat_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def test_tf_repeat_02(): @@ -99,14 +97,13 @@ def test_tf_repeat_04(): logger.info("Test Simple Repeat Column List") # define parameters repeat_count = 2 - parameters = {"params": {'repeat_count': repeat_count}} columns_list = ["col_sint64", "col_sint32"] # apply dataset operations data1 = ds.TFRecordDataset(DATA_DIR_TF, SCHEMA_DIR_TF, columns_list=columns_list, shuffle=False) data1 = data1.repeat(repeat_count) filename = "repeat_list_result.npz" - save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN) def generator(): @@ -115,6 +112,7 @@ def generator(): def test_nested_repeat1(): + logger.info("test_nested_repeat1") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat(2) data = data.repeat(3) @@ -126,6 +124,7 @@ def test_nested_repeat1(): def test_nested_repeat2(): + logger.info("test_nested_repeat2") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat(1) data = data.repeat(1) @@ -137,6 +136,7 @@ def test_nested_repeat2(): def test_nested_repeat3(): + logger.info("test_nested_repeat3") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat(1) data = data.repeat(2) @@ -148,6 +148,7 @@ def test_nested_repeat3(): def test_nested_repeat4(): + logger.info("test_nested_repeat4") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat(2) data = data.repeat(1) @@ -159,6 +160,7 @@ def test_nested_repeat4(): def test_nested_repeat5(): + logger.info("test_nested_repeat5") data = ds.GeneratorDataset(generator, ["data"]) data = data.batch(3) data = data.repeat(2) @@ -171,6 +173,7 @@ def test_nested_repeat5(): def test_nested_repeat6(): + logger.info("test_nested_repeat6") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat(2) data = data.batch(3) @@ -183,6 +186,7 @@ def test_nested_repeat6(): def test_nested_repeat7(): + logger.info("test_nested_repeat7") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat(2) data = data.repeat(3) @@ -195,6 +199,7 @@ def test_nested_repeat7(): def test_nested_repeat8(): + logger.info("test_nested_repeat8") data = ds.GeneratorDataset(generator, ["data"]) data = data.batch(2, drop_remainder=False) data = data.repeat(2) @@ -210,6 +215,7 @@ def test_nested_repeat8(): def test_nested_repeat9(): + logger.info("test_nested_repeat9") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat() data = data.repeat(3) @@ -221,6 +227,7 @@ def test_nested_repeat9(): def test_nested_repeat10(): + logger.info("test_nested_repeat10") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat(3) data = data.repeat() @@ -232,6 +239,7 @@ def test_nested_repeat10(): def test_nested_repeat11(): + logger.info("test_nested_repeat11") data = ds.GeneratorDataset(generator, ["data"]) data = data.repeat(2) data = data.repeat(3) diff --git a/tests/ut/python/dataset/test_resize_with_bbox.py b/tests/ut/python/dataset/test_resize_with_bbox.py index 06f3937958..3bb731ee97 100644 --- a/tests/ut/python/dataset/test_resize_with_bbox.py +++ b/tests/ut/python/dataset/test_resize_with_bbox.py @@ -26,29 +26,16 @@ from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, GENERATE_GOLDEN = False DATA_DIR = "../data/dataset/testVOC2012_2" +DATA_DIR_2 = ["../data/dataset/testCOCO/train/", + "../data/dataset/testCOCO/annotations/train.json"] # DATA_DIR, ANNOTATION_DIR -def fix_annotate(bboxes): +def test_resize_with_bbox_op_voc_c(plot_vis=False): """ - Fix annotations to format followed by mindspore. - :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format - :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format + Prints images and bboxes side by side with and without ResizeWithBBox Op applied + testing with VOC dataset """ - for (i, box) in enumerate(bboxes): - if box.size == 7: - bboxes[i] = np.roll(box, -1) - else: - print("ERROR: Invalid Bounding Box size provided") - break - return bboxes - - -def test_resize_with_bbox_op_c(plot_vis=False): - """ - Prints images and bboxes side by side with and without ResizeWithBBox Op applied, - tests with MD5 check, expected to pass - """ - logger.info("test_resize_with_bbox_op_c") + logger.info("test_resize_with_bbox_op_voc_c") # Load dataset dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", @@ -57,21 +44,15 @@ def test_resize_with_bbox_op_c(plot_vis=False): dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False) - test_op = c_vision.ResizeWithBBox(200) + test_op = c_vision.ResizeWithBBox(100) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) # map to apply ops dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], operations=[test_op]) - filename = "resize_with_bbox_op_01_c_result.npz" + filename = "resize_with_bbox_op_01_c_voc_result.npz" save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN) unaugSamp, augSamp = [], [] @@ -84,6 +65,43 @@ def test_resize_with_bbox_op_c(plot_vis=False): visualize_with_bounding_boxes(unaugSamp, augSamp) +def test_resize_with_bbox_op_coco_c(plot_vis=False): + """ + Prints images and bboxes side by side with and without ResizeWithBBox Op applied, + tests with MD5 check, expected to pass + Testing with COCO dataset + """ + logger.info("test_resize_with_bbox_op_coco_c") + + # Load dataset + dataCOCO1 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection", + decode=True, shuffle=False) + + dataCOCO2 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection", + decode=True, shuffle=False) + + test_op = c_vision.ResizeWithBBox(200) + + # map to apply ops + + dataCOCO2 = dataCOCO2.map(input_columns=["image", "bbox"], + output_columns=["image", "bbox"], + columns_order=["image", "bbox"], + operations=[test_op]) + + filename = "resize_with_bbox_op_01_c_coco_result.npz" + save_and_check_md5(dataCOCO2, filename, generate_golden=GENERATE_GOLDEN) + + unaugSamp, augSamp = [], [] + + for unAug, Aug in zip(dataCOCO1.create_dict_iterator(), dataCOCO2.create_dict_iterator()): + unaugSamp.append(unAug) + augSamp.append(Aug) + + if plot_vis: + visualize_with_bounding_boxes(unaugSamp, augSamp, annot_name="bbox") + + def test_resize_with_bbox_op_edge_c(plot_vis=False): """ Prints images and bboxes side by side with and without ResizeWithBBox Op applied, @@ -99,13 +117,6 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False): test_op = c_vision.ResizeWithBBox(500) - dataVoc1 = dataVoc1.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - dataVoc2 = dataVoc2.map(input_columns=["annotation"], - output_columns=["annotation"], - operations=fix_annotate) - # maps to convert data into valid edge case data dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], @@ -113,7 +124,6 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False): operations=[lambda img, bboxes: ( img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))]) - # Test Op added to list of Operations here dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"], output_columns=["image", "annotation"], columns_order=["image", "annotation"], @@ -140,7 +150,7 @@ def test_resize_with_bbox_op_invalid_c(): # invalid interpolation value c_vision.ResizeWithBBox(400, interpolation="invalid") - except ValueError as err: + except TypeError as err: logger.info("Got an exception in DE: {}".format(str(err))) assert "interpolation" in str(err) @@ -163,7 +173,8 @@ def test_resize_with_bbox_op_bad_c(): if __name__ == "__main__": - test_resize_with_bbox_op_c(plot_vis=False) + test_resize_with_bbox_op_voc_c(plot_vis=False) + test_resize_with_bbox_op_coco_c(plot_vis=False) test_resize_with_bbox_op_edge_c(plot_vis=False) test_resize_with_bbox_op_invalid_c() test_resize_with_bbox_op_bad_c() diff --git a/tests/ut/python/dataset/test_shuffle.py b/tests/ut/python/dataset/test_shuffle.py index 56cc65a23b..460c491ca1 100644 --- a/tests/ut/python/dataset/test_shuffle.py +++ b/tests/ut/python/dataset/test_shuffle.py @@ -154,7 +154,7 @@ def test_shuffle_exception_01(): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "buffer_size" in str(e) + assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) def test_shuffle_exception_02(): @@ -172,7 +172,7 @@ def test_shuffle_exception_02(): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "buffer_size" in str(e) + assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) def test_shuffle_exception_03(): @@ -190,7 +190,7 @@ def test_shuffle_exception_03(): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "buffer_size" in str(e) + assert "Input buffer_size is not within the required interval of (2 to 2147483647)" in str(e) def test_shuffle_exception_05(): diff --git a/tests/ut/python/dataset/test_sync_wait.py b/tests/ut/python/dataset/test_sync_wait.py index a5727a2991..eb2261a5d3 100644 --- a/tests/ut/python/dataset/test_sync_wait.py +++ b/tests/ut/python/dataset/test_sync_wait.py @@ -14,7 +14,7 @@ # ============================================================================== import numpy as np - +import pytest import mindspore.dataset as ds from mindspore import log as logger @@ -163,7 +163,6 @@ def test_sync_exception_01(): """ logger.info("test_sync_exception_01") shuffle_size = 4 - batch_size = 10 dataset = ds.GeneratorDataset(gen, column_names=["input"]) @@ -171,11 +170,9 @@ def test_sync_exception_01(): dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) - try: - dataset = dataset.shuffle(shuffle_size) - except Exception as e: - assert "shuffle" in str(e) - dataset = dataset.batch(batch_size) + with pytest.raises(RuntimeError) as e: + dataset.shuffle(shuffle_size) + assert "No shuffle after sync operators" in str(e.value) def test_sync_exception_02(): @@ -183,7 +180,6 @@ def test_sync_exception_02(): Test sync: with duplicated condition name """ logger.info("test_sync_exception_02") - batch_size = 6 dataset = ds.GeneratorDataset(gen, column_names=["input"]) @@ -192,11 +188,9 @@ def test_sync_exception_02(): dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) - try: - dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") - except Exception as e: - assert "name" in str(e) - dataset = dataset.batch(batch_size) + with pytest.raises(RuntimeError) as e: + dataset.sync_wait(num_batch=2, condition_name="every batch") + assert "Condition name is already in use" in str(e.value) def test_sync_exception_03(): @@ -209,12 +203,9 @@ def test_sync_exception_03(): aug = Augment(0) # try to create dataset with batch_size < 0 - try: - dataset = dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update) - except Exception as e: - assert "num_batch" in str(e) - - dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + with pytest.raises(ValueError) as e: + dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update) + assert "num_batch need to be greater than 0." in str(e.value) def test_sync_exception_04(): @@ -230,14 +221,13 @@ def test_sync_exception_04(): dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) count = 0 - try: + with pytest.raises(RuntimeError) as e: for _ in dataset.create_dict_iterator(): count += 1 data = {"loss": count} - # dataset.disable_sync() dataset.sync_update(condition_name="every batch", num_batch=-1, data=data) - except Exception as e: - assert "batch" in str(e) + assert "Sync_update batch size can only be positive" in str(e.value) + def test_sync_exception_05(): """ @@ -251,15 +241,15 @@ def test_sync_exception_05(): # try to create dataset with batch_size < 0 dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) - try: + with pytest.raises(RuntimeError) as e: for _ in dataset.create_dict_iterator(): dataset.disable_sync() count += 1 data = {"loss": count} dataset.disable_sync() dataset.sync_update(condition_name="every", data=data) - except Exception as e: - assert "name" in str(e) + assert "Condition name not found" in str(e.value) + if __name__ == "__main__": test_simple_sync_wait() diff --git a/tests/ut/python/dataset/test_ten_crop.py b/tests/ut/python/dataset/test_ten_crop.py index 7bffea5cc9..d196bc05cf 100644 --- a/tests/ut/python/dataset/test_ten_crop.py +++ b/tests/ut/python/dataset/test_ten_crop.py @@ -62,7 +62,7 @@ def util_test_ten_crop(crop_size, vertical_flip=False, plot=False): logger.info("dtype of image_2: {}".format(image_2.dtype)) if plot: - visualize_list(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1)) + visualize_list(np.array([image_1] * 10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1)) # The output data should be of a 4D tensor shape, a stack of 10 images. assert len(image_2.shape) == 4 @@ -144,7 +144,7 @@ def test_ten_crop_invalid_size_error_msg(): vision.TenCrop(0), lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images ] - error_msg = "Input is not within the required range" + error_msg = "Input is not within the required interval of (1 to 16777216)." assert error_msg == str(info.value) with pytest.raises(ValueError) as info: diff --git a/tests/ut/python/dataset/test_text_basic_tokenizer.py b/tests/ut/python/dataset/test_text_basic_tokenizer.py new file mode 100644 index 0000000000..822790fd60 --- /dev/null +++ b/tests/ut/python/dataset/test_text_basic_tokenizer.py @@ -0,0 +1,138 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing BasicTokenizer op in DE +""" +import numpy as np +import mindspore.dataset as ds +from mindspore import log as logger +import mindspore.dataset.text as text + +BASIC_TOKENIZER_FILE = "../data/dataset/testTokenizerData/basic_tokenizer.txt" + +test_paras = [ + dict( + first=1, + last=6, + expected_tokens= + [['Welcome', 'to', 'Beijing', '北', '京', '欢', '迎', '您'], + ['長', '風', '破', '浪', '會', '有', '時', ',', '直', '掛', '雲', '帆', '濟', '滄', '海'], + ['😀', '嘿', '嘿', '😃', '哈', '哈', '😄', '大', '笑', '😁', '嘻', '嘻'], + ['明', '朝', '(', '1368', '—', '1644', '年', ')', '和', '清', '朝', + '(', '1644', '—', '1911', '年', ')', ',', '是', '中', '国', '封', + '建', '王', '朝', '史', '上', '最', '后', '两', '个', '朝', '代'], + ['明', '代', '(', '1368', '-', '1644', ')', 'と', '清', '代', + '(', '1644', '-', '1911', ')', 'は', '、', '中', '国', 'の', '封', + '建', '王', '朝', 'の', '歴', '史', 'における', '最', '後', 'の2つの', '王', '朝', 'でした'], + ['명나라', '(', '1368', '-', '1644', ')', '와', '청나라', '(', '1644', '-', '1911', ')', '는', + '중국', '봉건', '왕조의', '역사에서', '마지막', '두', '왕조였다']], + expected_offsets_start=[[0, 8, 11, 18, 21, 24, 27, 30], + [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42], + [0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37], + [0, 3, 6, 9, 13, 16, 20, 23, 26, 29, 32, 35, 38, 42, 45, 49, + 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 85, 88, 91, 94, 97, 100], + [0, 3, 6, 9, 13, 14, 18, 21, 24, 27, 30, 33, 37, 38, 42, 45, 48, 51, + 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 93, 96, 99, 109, 112, 115], + [0, 10, 11, 15, 16, 20, 21, 25, 35, 36, 40, 41, 45, 46, 50, 57, 64, 74, 87, 97, 101]], + expected_offsets_limit=[[7, 10, 18, 21, 24, 27, 30, 33], + [3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45], + [4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40], + [3, 6, 9, 13, 16, 20, 23, 26, 29, 32, 35, 38, 42, 45, 49, 52, 55, 58, + 61, 64, 67, 70, 73, 76, 79, 82, 85, 88, 91, 94, 97, 100, 103], + [3, 6, 9, 13, 14, 18, 21, 24, 27, 30, 33, 37, 38, 42, 45, 48, 51, 54, + 57, 60, 63, 66, 69, 72, 75, 78, 81, 93, 96, 99, 109, 112, 115, 124], + [9, 11, 15, 16, 20, 21, 24, 34, 36, 40, 41, 45, 46, 49, 56, 63, 73, 86, 96, 100, 113]] + ), + dict( + first=7, + last=7, + expected_tokens=[['this', 'is', 'a', 'funky', 'string']], + expected_offsets_start=[[0, 5, 8, 10, 16]], + expected_offsets_limit=[[4, 7, 9, 15, 22]], + lower_case=True + ), +] + + +def check_basic_tokenizer_default(first, last, expected_tokens, expected_offsets_start, expected_offsets_limit, + lower_case=False, keep_whitespace=False, + normalization_form=text.utils.NormalizeForm.NONE, preserve_unused_token=False): + dataset = ds.TextFileDataset(BASIC_TOKENIZER_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + + basic_tokenizer = text.BasicTokenizer(lower_case=lower_case, + keep_whitespace=keep_whitespace, + normalization_form=normalization_form, + preserve_unused_token=preserve_unused_token) + + dataset = dataset.map(operations=basic_tokenizer) + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']) + logger.info("Out:", token) + logger.info("Exp:", expected_tokens[count]) + np.testing.assert_array_equal(token, expected_tokens[count]) + count = count + 1 + + +def check_basic_tokenizer_with_offsets(first, last, expected_tokens, expected_offsets_start, expected_offsets_limit, + lower_case=False, keep_whitespace=False, + normalization_form=text.utils.NormalizeForm.NONE, preserve_unused_token=False): + dataset = ds.TextFileDataset(BASIC_TOKENIZER_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + + basic_tokenizer = text.BasicTokenizer(lower_case=lower_case, + keep_whitespace=keep_whitespace, + normalization_form=normalization_form, + preserve_unused_token=preserve_unused_token, + with_offsets=True) + + dataset = dataset.map(input_columns=['text'], output_columns=['token', 'offsets_start', 'offsets_limit'], + columns_order=['token', 'offsets_start', 'offsets_limit'], operations=basic_tokenizer) + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['token']) + logger.info("Out:", token) + logger.info("Exp:", expected_tokens[count]) + np.testing.assert_array_equal(token, expected_tokens[count]) + np.testing.assert_array_equal(i['offsets_start'], expected_offsets_start[count]) + np.testing.assert_array_equal(i['offsets_limit'], expected_offsets_limit[count]) + count = count + 1 + +def test_basic_tokenizer_with_offsets(): + """ + Test BasicTokenizer + """ + for paras in test_paras: + check_basic_tokenizer_with_offsets(**paras) + + +def test_basic_tokenizer_default(): + """ + Test BasicTokenizer + """ + for paras in test_paras: + check_basic_tokenizer_default(**paras) + + +if __name__ == '__main__': + test_basic_tokenizer_default() + test_basic_tokenizer_with_offsets() diff --git a/tests/ut/python/dataset/test_text_bert_tokenizer.py b/tests/ut/python/dataset/test_text_bert_tokenizer.py new file mode 100644 index 0000000000..b29f94eb32 --- /dev/null +++ b/tests/ut/python/dataset/test_text_bert_tokenizer.py @@ -0,0 +1,246 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing BertTokenizer op in DE +""" +import numpy as np +import mindspore.dataset as ds +from mindspore import log as logger +import mindspore.dataset.text as text + +BERT_TOKENIZER_FILE = "../data/dataset/testTokenizerData/bert_tokenizer.txt" + +vocab_bert = [ + "床", "前", "明", "月", "光", "疑", "是", "地", "上", "霜", "举", "头", "望", "低", "思", "故", "乡", + "繁", "體", "字", "嘿", "哈", "大", "笑", "嘻", + "i", "am", "mak", "make", "small", "mistake", "##s", "during", "work", "##ing", "hour", + "😀", "😃", "😄", "😁", "+", "/", "-", "=", "12", "28", "40", "16", " ", "I", + "[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]", "[unused1]", "[unused10]" +] +pad = '' +test_paras = [ + # test chinese text + dict( + first=1, + last=4, + expect_str=[['床', '前', '明', '月', '光'], + ['疑', '是', '地', '上', '霜'], + ['举', '头', '望', '明', '月'], + ['低', '头', '思', '故', '乡']], + expected_offsets_start=[[0, 3, 6, 9, 12], + [0, 3, 6, 9, 12], + [0, 3, 6, 9, 12], + [0, 3, 6, 9, 12]], + expected_offsets_limit=[[3, 6, 9, 12, 15], + [3, 6, 9, 12, 15], + [3, 6, 9, 12, 15], + [3, 6, 9, 12, 15]], + vocab_list=vocab_bert + ), + # test english text + dict( + first=5, + last=5, + expect_str=[['i', 'am', 'mak', '##ing', 'small', 'mistake', '##s', 'during', 'work', '##ing', 'hour', '##s']], + expected_offsets_start=[[0, 2, 5, 8, 12, 18, 25, 27, 34, 38, 42, 46]], + expected_offsets_limit=[[1, 4, 8, 11, 17, 25, 26, 33, 38, 41, 46, 47]], + lower_case=True, + vocab_list=vocab_bert + ), + dict( + first=5, + last=5, + expect_str=[['I', "am", 'mak', '##ing', 'small', 'mistake', '##s', 'during', 'work', '##ing', 'hour', '##s']], + expected_offsets_start=[[0, 2, 5, 8, 12, 18, 25, 27, 34, 38, 42, 46]], + expected_offsets_limit=[[1, 4, 8, 11, 17, 25, 26, 33, 38, 41, 46, 47]], + lower_case=False, + vocab_list=vocab_bert + ), + # test emoji tokens + dict( + first=6, + last=7, + expect_str=[ + ['😀', '嘿', '嘿', '😃', '哈', '哈', '😄', '大', '笑', '😁', '嘻', '嘻'], + ['繁', '體', '字']], + expected_offsets_start=[[0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37], [0, 3, 6]], + expected_offsets_limit=[[4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40], [3, 6, 9]], + normalization_form=text.utils.NormalizeForm.NFKC, + vocab_list=vocab_bert + ), + # test preserved tokens + dict( + first=8, + last=14, + expect_str=[ + ['[UNK]', '[CLS]'], + ['[UNK]', '[SEP]'], + ['[UNK]', '[UNK]'], + ['[UNK]', '[PAD]'], + ['[UNK]', '[MASK]'], + ['[unused1]'], + ['[unused10]'] + ], + expected_offsets_start=[[0, 7], [0, 7], [0, 7], [0, 7], [0, 7], [0], [0]], + expected_offsets_limit=[[6, 12], [6, 12], [6, 12], [6, 12], [6, 13], [9], [10]], + lower_case=False, + vocab_list=vocab_bert, + preserve_unused_token=True, + ), + dict( + first=8, + last=14, + expect_str=[ + ['[UNK]', '[CLS]'], + ['[UNK]', '[SEP]'], + ['[UNK]', '[UNK]'], + ['[UNK]', '[PAD]'], + ['[UNK]', '[MASK]'], + ['[unused1]'], + ['[unused10]'] + ], + expected_offsets_start=[[0, 7], [0, 7], [0, 7], [0, 7], [0, 7], [0], [0]], + expected_offsets_limit=[[6, 12], [6, 12], [6, 12], [6, 12], [6, 13], [9], [10]], + lower_case=True, + vocab_list=vocab_bert, + preserve_unused_token=True, + ), + # test special symbol + dict( + first=15, + last=15, + expect_str=[['12', '+', '/', '-', '28', '=', '40', '/', '-', '16']], + expected_offsets_start=[[0, 2, 3, 4, 5, 7, 8, 10, 11, 12]], + expected_offsets_limit=[[2, 3, 4, 5, 7, 8, 10, 11, 12, 14]], + preserve_unused_token=True, + vocab_list=vocab_bert + ), + # test non-default parms + dict( + first=8, + last=8, + expect_str=[['[UNK]', ' ', '[CLS]']], + expected_offsets_start=[[0, 6, 7]], + expected_offsets_limit=[[6, 7, 12]], + lower_case=False, + vocab_list=vocab_bert, + preserve_unused_token=True, + keep_whitespace=True + ), + dict( + first=8, + last=8, + expect_str=[['unused', ' ', '[CLS]']], + expected_offsets_start=[[0, 6, 7]], + expected_offsets_limit=[[6, 7, 12]], + lower_case=False, + vocab_list=vocab_bert, + preserve_unused_token=True, + keep_whitespace=True, + unknown_token='' + ), + dict( + first=8, + last=8, + expect_str=[['unused', ' ', '[', 'CLS', ']']], + expected_offsets_start=[[0, 6, 7, 8, 11]], + expected_offsets_limit=[[6, 7, 8, 11, 12]], + lower_case=False, + vocab_list=vocab_bert, + preserve_unused_token=False, + keep_whitespace=True, + unknown_token='' + ), +] + + +def check_bert_tokenizer_default(first, last, expect_str, + expected_offsets_start, expected_offsets_limit, + vocab_list, suffix_indicator='##', + max_bytes_per_token=100, unknown_token='[UNK]', + lower_case=False, keep_whitespace=False, + normalization_form=text.utils.NormalizeForm.NONE, + preserve_unused_token=False): + dataset = ds.TextFileDataset(BERT_TOKENIZER_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + vocab = text.Vocab.from_list(vocab_list) + tokenizer_op = text.BertTokenizer( + vocab=vocab, suffix_indicator=suffix_indicator, + max_bytes_per_token=max_bytes_per_token, unknown_token=unknown_token, + lower_case=lower_case, keep_whitespace=keep_whitespace, + normalization_form=normalization_form, + preserve_unused_token=preserve_unused_token) + dataset = dataset.map(operations=tokenizer_op) + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']) + logger.info("Out:", token) + logger.info("Exp:", expect_str[count]) + np.testing.assert_array_equal(token, expect_str[count]) + count = count + 1 + + +def check_bert_tokenizer_with_offsets(first, last, expect_str, + expected_offsets_start, expected_offsets_limit, + vocab_list, suffix_indicator='##', + max_bytes_per_token=100, unknown_token='[UNK]', + lower_case=False, keep_whitespace=False, + normalization_form=text.utils.NormalizeForm.NONE, + preserve_unused_token=False): + dataset = ds.TextFileDataset(BERT_TOKENIZER_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + vocab = text.Vocab.from_list(vocab_list) + tokenizer_op = text.BertTokenizer( + vocab=vocab, suffix_indicator=suffix_indicator, max_bytes_per_token=max_bytes_per_token, + unknown_token=unknown_token, lower_case=lower_case, keep_whitespace=keep_whitespace, + normalization_form=normalization_form, preserve_unused_token=preserve_unused_token, with_offsets=True) + dataset = dataset.map(input_columns=['text'], output_columns=['token', 'offsets_start', 'offsets_limit'], + columns_order=['token', 'offsets_start', 'offsets_limit'], operations=tokenizer_op) + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['token']) + logger.info("Out:", token) + logger.info("Exp:", expect_str[count]) + np.testing.assert_array_equal(token, expect_str[count]) + np.testing.assert_array_equal(i['offsets_start'], expected_offsets_start[count]) + np.testing.assert_array_equal(i['offsets_limit'], expected_offsets_limit[count]) + count = count + 1 + + +def test_bert_tokenizer_default(): + """ + Test WordpieceTokenizer when with_offsets=False + """ + for paras in test_paras: + check_bert_tokenizer_default(**paras) + + +def test_bert_tokenizer_with_offsets(): + """ + Test WordpieceTokenizer when with_offsets=True + """ + for paras in test_paras: + check_bert_tokenizer_with_offsets(**paras) + + +if __name__ == '__main__': + test_bert_tokenizer_default() + test_bert_tokenizer_with_offsets() diff --git a/tests/ut/python/dataset/test_text_jieba_tokenizer.py b/tests/ut/python/dataset/test_text_jieba_tokenizer.py new file mode 100644 index 0000000000..66665b61e6 --- /dev/null +++ b/tests/ut/python/dataset/test_text_jieba_tokenizer.py @@ -0,0 +1,471 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np +import mindspore.dataset as ds +from mindspore.dataset.text import JiebaTokenizer +from mindspore.dataset.text import JiebaMode, to_str + +DATA_FILE = "../data/dataset/testJiebaDataset/3.txt" +DATA_ALL_FILE = "../data/dataset/testJiebaDataset/*" + +HMM_FILE = "../data/dataset/jiebadict/hmm_model.utf8" +MP_FILE = "../data/dataset/jiebadict/jieba.dict.utf8" + + +def test_jieba_1(): + """Test jieba tokenizer with MP mode""" + data = ds.TextFileDataset(DATA_FILE) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=1) + expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] + ret = [] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_1_1(): + """Test jieba tokenizer with HMM mode""" + data = ds.TextFileDataset(DATA_FILE) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.HMM) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=1) + expect = ['今天', '天气', '太', '好', '了', '我们', '一起', '去', '外面', '玩', '吧'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_1_2(): + """Test jieba tokenizer with HMM MIX""" + data = ds.TextFileDataset(DATA_FILE) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MIX) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=1) + expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_2(): + """Test add_word""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + jieba_op.add_word("男默女泪") + expect = ['男默女泪', '市', '长江大桥'] + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=2) + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_2_1(): + """Test add_word with freq""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + jieba_op.add_word("男默女泪", 10) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=2) + expect = ['男默女泪', '市', '长江大桥'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_2_2(): + """Test add_word with invalid None Input""" + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + try: + jieba_op.add_word(None) + except ValueError: + pass + + +def test_jieba_2_3(): + """Test add_word with freq, the value of freq affects the result of segmentation""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/6.txt" + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + jieba_op.add_word("江大桥", 20000) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=2) + expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_3(): + """Test add_dict with dict""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" + user_dict = { + "男默女泪": 10 + } + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + jieba_op.add_dict(user_dict) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=1) + expect = ['男默女泪', '市', '长江大桥'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_3_1(): + """Test add_dict with dict""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" + user_dict = { + "男默女泪": 10, + "江大桥": 20000 + } + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + jieba_op.add_dict(user_dict) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=1) + expect = ['男默女泪', '市长', '江大桥'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_4(): + DATA_FILE4 = "../data/dataset/testJiebaDataset/3.txt" + DICT_FILE = "../data/dataset/testJiebaDataset/user_dict.txt" + + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + jieba_op.add_dict(DICT_FILE) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=1) + expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_4_1(): + """Test add dict with invalid file path""" + DICT_FILE = "" + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + try: + jieba_op.add_dict(DICT_FILE) + except ValueError: + pass + + +def test_jieba_5(): + """Test add dict with file path""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/6.txt" + + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP) + jieba_op.add_word("江大桥", 20000) + data = data.map(input_columns=["text"], + operations=jieba_op, num_parallel_workers=1) + expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +def test_jieba_with_offsets_1(): + """Test jieba tokenizer with MP mode""" + data = ds.TextFileDataset(DATA_FILE) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=1) + expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] + expected_offsets_start = [0, 12, 21, 27, 33, 36, 42] + expected_offsets_limit = [12, 21, 27, 33, 36, 42, 48] + ret = [] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_1_1(): + """Test jieba tokenizer with HMM mode""" + data = ds.TextFileDataset(DATA_FILE) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.HMM, with_offsets=True) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=1) + expect = ['今天', '天气', '太', '好', '了', '我们', '一起', '去', '外面', '玩', '吧'] + expected_offsets_start = [0, 6, 12, 15, 18, 21, 27, 33, 36, 42, 45] + expected_offsets_limit = [6, 12, 15, 18, 21, 27, 33, 36, 42, 45, 48] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_1_2(): + """Test jieba tokenizer with HMM MIX""" + data = ds.TextFileDataset(DATA_FILE) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MIX, with_offsets=True) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=1) + expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] + expected_offsets_start = [0, 12, 21, 27, 33, 36, 42] + expected_offsets_limit = [12, 21, 27, 33, 36, 42, 48] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_2(): + """Test add_word""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + jieba_op.add_word("男默女泪") + expect = ['男默女泪', '市', '长江大桥'] + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=2) + expected_offsets_start = [0, 12, 15] + expected_offsets_limit = [12, 15, 27] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_2_1(): + """Test add_word with freq""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + jieba_op.add_word("男默女泪", 10) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=2) + expect = ['男默女泪', '市', '长江大桥'] + expected_offsets_start = [0, 12, 15] + expected_offsets_limit = [12, 15, 27] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_2_2(): + """Test add_word with freq, the value of freq affects the result of segmentation""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/6.txt" + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + jieba_op.add_word("江大桥", 20000) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=2) + expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式'] + expected_offsets_start = [0, 6, 12, 21, 27, 30, 42, 45, 51] + expected_offsets_limit = [6, 12, 21, 27, 30, 42, 45, 51, 57] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_3(): + """Test add_dict with dict""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" + user_dict = { + "男默女泪": 10 + } + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + jieba_op.add_dict(user_dict) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=1) + expect = ['男默女泪', '市', '长江大桥'] + expected_offsets_start = [0, 12, 15] + expected_offsets_limit = [12, 15, 27] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_3_1(): + """Test add_dict with dict""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/4.txt" + user_dict = { + "男默女泪": 10, + "江大桥": 20000 + } + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + jieba_op.add_dict(user_dict) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=1) + expect = ['男默女泪', '市长', '江大桥'] + expected_offsets_start = [0, 12, 18] + expected_offsets_limit = [12, 18, 27] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_4(): + DATA_FILE4 = "../data/dataset/testJiebaDataset/3.txt" + DICT_FILE = "../data/dataset/testJiebaDataset/user_dict.txt" + + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + jieba_op.add_dict(DICT_FILE) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=1) + expect = ['今天天气', '太好了', '我们', '一起', '去', '外面', '玩吧'] + expected_offsets_start = [0, 12, 21, 27, 33, 36, 42] + expected_offsets_limit = [12, 21, 27, 33, 36, 42, 48] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + + +def test_jieba_with_offsets_5(): + """Test add dict with file path""" + DATA_FILE4 = "../data/dataset/testJiebaDataset/6.txt" + + data = ds.TextFileDataset(DATA_FILE4) + jieba_op = JiebaTokenizer(HMM_FILE, MP_FILE, mode=JiebaMode.MP, with_offsets=True) + jieba_op.add_word("江大桥", 20000) + data = data.map(input_columns=["text"], output_columns=["token", "offsets_start", "offsets_limit"], + columns_order=["token", "offsets_start", "offsets_limit"], + operations=jieba_op, num_parallel_workers=1) + expect = ['江州', '市长', '江大桥', '参加', '了', '长江大桥', '的', '通车', '仪式'] + expected_offsets_start = [0, 6, 12, 21, 27, 30, 42, 45, 51] + expected_offsets_limit = [6, 12, 21, 27, 30, 42, 45, 51, 57] + for i in data.create_dict_iterator(): + ret = to_str(i["token"]) + for index, item in enumerate(ret): + assert item == expect[index] + for index, item in enumerate(i["offsets_start"]): + assert item == expected_offsets_start[index] + for index, item in enumerate(i["offsets_limit"]): + assert item == expected_offsets_limit[index] + +def gen(): + text = np.array("今天天气太好了我们一起去外面玩吧".encode("UTF8"), dtype='S') + yield (text,) + + +def pytoken_op(input_data): + te = str(to_str(input_data)) + tokens = [] + tokens.append(te[:5].encode("UTF8")) + tokens.append(te[5:10].encode("UTF8")) + tokens.append(te[10:].encode("UTF8")) + return np.array(tokens, dtype='S') + + +def test_jieba_6(): + data = ds.GeneratorDataset(gen, column_names=["text"]) + data = data.map(input_columns=["text"], + operations=pytoken_op, num_parallel_workers=1) + expect = ['今天天气太', '好了我们一', '起去外面玩吧'] + for i in data.create_dict_iterator(): + ret = to_str(i["text"]) + for index, item in enumerate(ret): + assert item == expect[index] + + +if __name__ == "__main__": + test_jieba_1() + test_jieba_1_1() + test_jieba_1_2() + test_jieba_2() + test_jieba_2_1() + test_jieba_2_2() + test_jieba_3() + test_jieba_3_1() + test_jieba_4() + test_jieba_4_1() + test_jieba_5() + test_jieba_5() + test_jieba_6() + test_jieba_with_offsets_1() + test_jieba_with_offsets_1_1() + test_jieba_with_offsets_1_2() + test_jieba_with_offsets_2() + test_jieba_with_offsets_2_1() + test_jieba_with_offsets_2_2() + test_jieba_with_offsets_3() + test_jieba_with_offsets_3_1() + test_jieba_with_offsets_4() + test_jieba_with_offsets_5() diff --git a/tests/ut/python/dataset/test_text_tokenizer.py b/tests/ut/python/dataset/test_text_tokenizer.py new file mode 100644 index 0000000000..2e2b7b741d --- /dev/null +++ b/tests/ut/python/dataset/test_text_tokenizer.py @@ -0,0 +1,380 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing UnicodeCharTokenizer op in DE +""" +import numpy as np +import mindspore.dataset as ds +from mindspore import log as logger +import mindspore.dataset.text as text + +DATA_FILE = "../data/dataset/testTokenizerData/1.txt" +NORMALIZE_FILE = "../data/dataset/testTokenizerData/normalize.txt" +REGEX_REPLACE_FILE = "../data/dataset/testTokenizerData/regex_replace.txt" +REGEX_TOKENIZER_FILE = "../data/dataset/testTokenizerData/regex_tokenizer.txt" + + +def split_by_unicode_char(input_strs): + """ + Split utf-8 strings to unicode characters + """ + out = [] + for s in input_strs: + out.append([c for c in s]) + return out + + +def test_unicode_char_tokenizer_default(): + """ + Test UnicodeCharTokenizer + """ + input_strs = ("Welcome to Beijing!", "北京欢迎您!", "我喜欢English!", " ") + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + tokenizer = text.UnicodeCharTokenizer() + dataset = dataset.map(operations=tokenizer) + tokens = [] + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']).tolist() + tokens.append(token) + logger.info("The out tokens is : {}".format(tokens)) + assert split_by_unicode_char(input_strs) == tokens + + +def test_unicode_char_tokenizer_with_offsets(): + """ + Test UnicodeCharTokenizer + """ + input_strs = ("Welcome to Beijing!", "北京欢迎您!", "我喜欢English!", " ") + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + tokenizer = text.UnicodeCharTokenizer(with_offsets=True) + dataset = dataset.map(input_columns=['text'], output_columns=['token', 'offsets_start', 'offsets_limit'], + columns_order=['token', 'offsets_start', 'offsets_limit'], operations=tokenizer) + tokens = [] + expected_offsets_start = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], + [0, 3, 6, 9, 12, 15], [0, 3, 6, 9, 10, 11, 12, 13, 14, 15, 16], [0, 1]] + expected_offsets_limit = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], + [3, 6, 9, 12, 15, 18], [3, 6, 9, 10, 11, 12, 13, 14, 15, 16, 17], [1, 2]] + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['token']).tolist() + tokens.append(token) + np.testing.assert_array_equal(i['offsets_start'], expected_offsets_start[count]) + np.testing.assert_array_equal(i['offsets_limit'], expected_offsets_limit[count]) + count += 1 + logger.info("The out tokens is : {}".format(tokens)) + assert split_by_unicode_char(input_strs) == tokens + + +def test_whitespace_tokenizer_default(): + """ + Test WhitespaceTokenizer + """ + whitespace_strs = [["Welcome", "to", "Beijing!"], + ["北京欢迎您!"], + ["我喜欢English!"], + [""]] + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + tokenizer = text.WhitespaceTokenizer() + dataset = dataset.map(operations=tokenizer) + tokens = [] + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']).tolist() + tokens.append(token) + logger.info("The out tokens is : {}".format(tokens)) + assert whitespace_strs == tokens + + +def test_whitespace_tokenizer_with_offsets(): + """ + Test WhitespaceTokenizer + """ + whitespace_strs = [["Welcome", "to", "Beijing!"], + ["北京欢迎您!"], + ["我喜欢English!"], + [""]] + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + tokenizer = text.WhitespaceTokenizer(with_offsets=True) + dataset = dataset.map(input_columns=['text'], output_columns=['token', 'offsets_start', 'offsets_limit'], + columns_order=['token', 'offsets_start', 'offsets_limit'], operations=tokenizer) + tokens = [] + expected_offsets_start = [[0, 8, 11], [0], [0], [0]] + expected_offsets_limit = [[7, 10, 19], [18], [17], [0]] + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['token']).tolist() + tokens.append(token) + np.testing.assert_array_equal(i['offsets_start'], expected_offsets_start[count]) + np.testing.assert_array_equal(i['offsets_limit'], expected_offsets_limit[count]) + count += 1 + + logger.info("The out tokens is : {}".format(tokens)) + assert whitespace_strs == tokens + + +def test_unicode_script_tokenizer_default(): + """ + Test UnicodeScriptTokenizer when para keep_whitespace=False + """ + unicode_script_strs = [["Welcome", "to", "Beijing", "!"], + ["北京欢迎您", "!"], + ["我喜欢", "English", "!"], + [""]] + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + tokenizer = text.UnicodeScriptTokenizer(keep_whitespace=False) + dataset = dataset.map(operations=tokenizer) + + tokens = [] + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']).tolist() + tokens.append(token) + logger.info("The out tokens is : {}".format(tokens)) + assert unicode_script_strs == tokens + + +def test_unicode_script_tokenizer_default2(): + """ + Test UnicodeScriptTokenizer when para keep_whitespace=True + """ + unicode_script_strs2 = [["Welcome", " ", "to", " ", "Beijing", "!"], + ["北京欢迎您", "!"], + ["我喜欢", "English", "!"], + [" "]] + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + tokenizer = text.UnicodeScriptTokenizer(keep_whitespace=True) + dataset = dataset.map(operations=tokenizer) + tokens = [] + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']).tolist() + tokens.append(token) + logger.info("The out tokens is :", tokens) + assert unicode_script_strs2 == tokens + + +def test_unicode_script_tokenizer_with_offsets(): + """ + Test UnicodeScriptTokenizer when para keep_whitespace=False and with_offsets=True + """ + unicode_script_strs = [["Welcome", "to", "Beijing", "!"], + ["北京欢迎您", "!"], + ["我喜欢", "English", "!"], + [""]] + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + tokenizer = text.UnicodeScriptTokenizer(keep_whitespace=False, with_offsets=True) + dataset = dataset.map(input_columns=['text'], output_columns=['token', 'offsets_start', 'offsets_limit'], + columns_order=['token', 'offsets_start', 'offsets_limit'], operations=tokenizer) + tokens = [] + expected_offsets_start = [[0, 8, 11, 18], [0, 15], [0, 9, 16], [0]] + expected_offsets_limit = [[7, 10, 18, 19], [15, 18], [9, 16, 17], [0]] + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['token']).tolist() + tokens.append(token) + np.testing.assert_array_equal(i['offsets_start'], expected_offsets_start[count]) + np.testing.assert_array_equal(i['offsets_limit'], expected_offsets_limit[count]) + count += 1 + logger.info("The out tokens is : {}".format(tokens)) + assert unicode_script_strs == tokens + + +def test_unicode_script_tokenizer_with_offsets2(): + """ + Test UnicodeScriptTokenizer when para keep_whitespace=True and with_offsets=True + """ + unicode_script_strs2 = [["Welcome", " ", "to", " ", "Beijing", "!"], + ["北京欢迎您", "!"], + ["我喜欢", "English", "!"], + [" "]] + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + tokenizer = text.UnicodeScriptTokenizer(keep_whitespace=True, with_offsets=True) + dataset = dataset.map(input_columns=['text'], output_columns=['token', 'offsets_start', 'offsets_limit'], + columns_order=['token', 'offsets_start', 'offsets_limit'], operations=tokenizer) + tokens = [] + expected_offsets_start = [[0, 7, 8, 10, 11, 18], [0, 15], [0, 9, 16], [0]] + expected_offsets_limit = [[7, 8, 10, 11, 18, 19], [15, 18], [9, 16, 17], [2]] + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['token']).tolist() + tokens.append(token) + np.testing.assert_array_equal(i['offsets_start'], expected_offsets_start[count]) + np.testing.assert_array_equal(i['offsets_limit'], expected_offsets_limit[count]) + count += 1 + logger.info("The out tokens is :", tokens) + assert unicode_script_strs2 == tokens + + +def test_case_fold(): + """ + Test CaseFold + """ + expect_strs = ["welcome to beijing!", "北京欢迎您!", "我喜欢english!", " "] + dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) + op = text.CaseFold() + dataset = dataset.map(operations=op) + + lower_strs = [] + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']).tolist() + lower_strs.append(token) + assert lower_strs == expect_strs + + +def test_normalize_utf8(): + """ + Test NormalizeUTF8 + """ + + def normalize(normalize_form): + dataset = ds.TextFileDataset(NORMALIZE_FILE, shuffle=False) + normalize = text.NormalizeUTF8(normalize_form=normalize_form) + dataset = dataset.map(operations=normalize) + out_bytes = [] + out_texts = [] + for i in dataset.create_dict_iterator(): + out_bytes.append(i['text']) + out_texts.append(text.to_str(i['text']).tolist()) + logger.info("The out bytes is : ", out_bytes) + logger.info("The out texts is: ", out_texts) + return out_bytes + + expect_normlize_data = [ + # NFC + [b'\xe1\xb9\xa9', b'\xe1\xb8\x8d\xcc\x87', b'q\xcc\xa3\xcc\x87', + b'\xef\xac\x81', b'2\xe2\x81\xb5', b'\xe1\xba\x9b\xcc\xa3'], + # NFKC + [b'\xe1\xb9\xa9', b'\xe1\xb8\x8d\xcc\x87', b'q\xcc\xa3\xcc\x87', + b'fi', b'25', b'\xe1\xb9\xa9'], + # NFD + [b's\xcc\xa3\xcc\x87', b'd\xcc\xa3\xcc\x87', b'q\xcc\xa3\xcc\x87', + b'\xef\xac\x81', b'2\xe2\x81\xb5', b'\xc5\xbf\xcc\xa3\xcc\x87'], + # NFKD + [b's\xcc\xa3\xcc\x87', b'd\xcc\xa3\xcc\x87', b'q\xcc\xa3\xcc\x87', + b'fi', b'25', b's\xcc\xa3\xcc\x87'] + ] + assert normalize(text.utils.NormalizeForm.NFC) == expect_normlize_data[0] + assert normalize(text.utils.NormalizeForm.NFKC) == expect_normlize_data[1] + assert normalize(text.utils.NormalizeForm.NFD) == expect_normlize_data[2] + assert normalize(text.utils.NormalizeForm.NFKD) == expect_normlize_data[3] + + +def test_regex_replace(): + """ + Test RegexReplace + """ + + def regex_replace(first, last, expect_str, pattern, replace): + dataset = ds.TextFileDataset(REGEX_REPLACE_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + replace_op = text.RegexReplace(pattern, replace) + dataset = dataset.map(operations=replace_op) + out_text = [] + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']).tolist() + out_text.append(token) + logger.info("Out:", out_text) + logger.info("Exp:", expect_str) + assert expect_str == out_text + + regex_replace(1, 2, ['H____ W____', "L__'_ G_"], "\\p{Ll}", '_') + regex_replace(3, 5, ['hello', 'world', '31:beijing'], "^(\\d:|b:)", "") + regex_replace(6, 6, ["WelcometoChina!"], "\\s+", "") + regex_replace(7, 8, ['我不想长大', 'WelcometoShenzhen!'], "\\p{Cc}|\\p{Cf}|\\s+", "") + + +def test_regex_tokenizer_default(): + """ + Test RegexTokenizer + """ + + def regex_tokenizer(first, last, expect_str, delim_pattern, keep_delim_pattern): + dataset = ds.TextFileDataset(REGEX_TOKENIZER_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + tokenizer_op = text.RegexTokenizer(delim_pattern, keep_delim_pattern) + dataset = dataset.map(operations=tokenizer_op) + out_text = [] + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']).tolist() + np.testing.assert_array_equal(token, expect_str[count]) + count += 1 + out_text.append(token) + logger.info("Out:", out_text) + logger.info("Exp:", expect_str) + + regex_tokenizer(1, 1, [['Welcome', 'to', 'Shenzhen!']], "\\s+", "") + regex_tokenizer(1, 1, [['Welcome', ' ', 'to', ' ', 'Shenzhen!']], "\\s+", "\\s+") + regex_tokenizer(2, 2, [['北', '京', '欢', '迎', '您', '!Welcome to Beijing!']], r"\p{Han}", r"\p{Han}") + regex_tokenizer(3, 3, [['12', '¥+', '36', '¥=?']], r"[\p{P}|\p{S}]+", r"[\p{P}|\p{S}]+") + regex_tokenizer(3, 3, [['12', '36']], r"[\p{P}|\p{S}]+", "") + regex_tokenizer(3, 3, [['¥+', '¥=?']], r"[\p{N}]+", "") + + +def test_regex_tokenizer_with_offsets(): + """ + Test RegexTokenizer + """ + + def regex_tokenizer(first, last, expect_str, expected_offsets_start, expected_offsets_limit, delim_pattern, + keep_delim_pattern): + dataset = ds.TextFileDataset(REGEX_TOKENIZER_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + tokenizer_op = text.RegexTokenizer(delim_pattern, keep_delim_pattern, with_offsets=True) + dataset = dataset.map(input_columns=['text'], output_columns=['token', 'offsets_start', 'offsets_limit'], + columns_order=['token', 'offsets_start', 'offsets_limit'], operations=tokenizer_op) + out_text = [] + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['token']).tolist() + np.testing.assert_array_equal(token, expect_str[count]) + np.testing.assert_array_equal(i['offsets_start'], expected_offsets_start[count]) + np.testing.assert_array_equal(i['offsets_limit'], expected_offsets_limit[count]) + count += 1 + out_text.append(token) + logger.info("Out:", out_text) + logger.info("Exp:", expect_str) + + regex_tokenizer(1, 1, [['Welcome', 'to', 'Shenzhen!']], [[0, 8, 11]], [[7, 10, 20]], "\\s+", "") + regex_tokenizer(1, 1, [['Welcome', ' ', 'to', ' ', 'Shenzhen!']], [[0, 7, 8, 10, 11]], [[7, 8, 10, 11, 20]], + "\\s+", "\\s+") + regex_tokenizer(2, 2, [['北', '京', '欢', '迎', '您', '!Welcome to Beijing!']], [[0, 3, 6, 9, 12, 15]], + [[3, 6, 9, 12, 15, 35]], r"\p{Han}", r"\p{Han}") + regex_tokenizer(3, 3, [['12', '¥+', '36', '¥=?']], [[0, 2, 6, 8]], [[2, 6, 8, 13]], + r"[\p{P}|\p{S}]+", r"[\p{P}|\p{S}]+") + regex_tokenizer(3, 3, [['12', '36']], [[0, 6]], [[2, 8]], r"[\p{P}|\p{S}]+", "") + regex_tokenizer(3, 3, [['¥+', '¥=?']], [[2, 8]], [[6, 13]], r"[\p{N}]+", "") + + +if __name__ == '__main__': + test_unicode_char_tokenizer_default() + test_unicode_char_tokenizer_with_offsets() + test_whitespace_tokenizer_default() + test_whitespace_tokenizer_with_offsets() + test_unicode_script_tokenizer_default() + test_unicode_script_tokenizer_default2() + test_unicode_script_tokenizer_with_offsets() + test_unicode_script_tokenizer_with_offsets2() + test_case_fold() + test_normalize_utf8() + test_regex_replace() + test_regex_tokenizer_default() + test_regex_tokenizer_with_offsets() diff --git a/tests/ut/python/dataset/test_text_wordpiece_tokenizer.py b/tests/ut/python/dataset/test_text_wordpiece_tokenizer.py new file mode 100644 index 0000000000..8b47ec971e --- /dev/null +++ b/tests/ut/python/dataset/test_text_wordpiece_tokenizer.py @@ -0,0 +1,160 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing WordpieceTokenizer op in DE +""" +import numpy as np +import mindspore.dataset as ds +from mindspore import log as logger +import mindspore.dataset.text as text + +WORDPIECE_TOKENIZER_FILE = "../data/dataset/testTokenizerData/wordpiece_tokenizer.txt" + +vocab_english = [ + "book", "cholera", "era", "favor", "##ite", "my", "is", "love", "dur", "##ing", "the" +] + +vocab_chinese = [ + "我", '最', '喜', '欢', '的', '书', '是', '霍', '乱', '时', '期', '爱', '情' +] + +vocab_mix = vocab_chinese + vocab_english + +test_paras = [ + dict( + first=1, + last=10, + expect_str=[['my'], ['favor', '##ite'], ['book'], ['is'], ['love'], ['dur', '##ing'], ['the'], ['cholera'], + ['era'], ['[UNK]']], + expected_offsets_start=[[0], [0, 5], [0], [0], [0], [0, 3], [0], [0], [0], [0]], + expected_offsets_limit=[[2], [5, 8], [4], [2], [4], [3, 6], [3], [7], [3], [4]], + vocab_list=vocab_english + ), + dict( + first=1, + last=10, + expect_str=[['my'], ['favor', '##ite'], ['book'], ['is'], ['love'], ['dur', '##ing'], ['the'], ['cholera'], + ['era'], ['what']], + expected_offsets_start=[[0], [0, 5], [0], [0], [0], [0, 3], [0], [0], [0], [0]], + expected_offsets_limit=[[2], [5, 8], [4], [2], [4], [3, 6], [3], [7], [3], [4]], + vocab_list=vocab_english, + unknown_token="" + ), + dict( + first=1, + last=10, + expect_str=[['my'], ['[UNK]'], ['book'], ['is'], ['love'], ['[UNK]'], ['the'], ['[UNK]'], ['era'], ['[UNK]']], + expected_offsets_start=[[0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], + expected_offsets_limit=[[2], [5], [4], [2], [4], [5], [3], [5], [3], [4]], + vocab_list=vocab_english, + max_bytes_per_token=4 + ), + dict( + first=11, + last=25, + expect_str=[['我'], ['最'], ['喜'], ['欢'], ['的'], ['书'], ['是'], ['霍'], ['乱'], ['时'], ['期'], ['的'], ['爱'], ['情'], + ['[UNK]']], + expected_offsets_start=[[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], + expected_offsets_limit=[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3]], + vocab_list=vocab_chinese, + ), + dict( + first=25, + last=25, + expect_str=[['您']], + expected_offsets_start=[[0]], + expected_offsets_limit=[[3]], + vocab_list=vocab_chinese, + unknown_token="" + ), + dict( + first=1, + last=25, + expect_str=[ + ['my'], ['favor', '##ite'], ['book'], ['is'], ['love'], ['dur', '##ing'], ['the'], ['cholera'], ['era'], + ['[UNK]'], + ['我'], ['最'], ['喜'], ['欢'], ['的'], ['书'], ['是'], ['霍'], ['乱'], ['时'], ['期'], ['的'], ['爱'], ['情'], + ['[UNK]']], + expected_offsets_start=[[0], [0, 5], [0], [0], [0], [0, 3], [0], [0], [0], [0], + [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], + expected_offsets_limit=[[2], [5, 8], [4], [2], [4], [3, 6], [3], [7], [3], [4], + [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3]], + vocab_list=vocab_mix, + ), +] + + +def check_wordpiece_tokenizer_default(first, last, expect_str, expected_offsets_start, expected_offsets_limit, + vocab_list, unknown_token='[UNK]', max_bytes_per_token=100): + dataset = ds.TextFileDataset(WORDPIECE_TOKENIZER_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + vocab = text.Vocab.from_list(vocab_list) + tokenizer_op = text.WordpieceTokenizer(vocab=vocab, unknown_token=unknown_token, + max_bytes_per_token=max_bytes_per_token) + dataset = dataset.map(operations=tokenizer_op) + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['text']) + logger.info("Out:", token) + logger.info("Exp:", expect_str[count]) + np.testing.assert_array_equal(token, expect_str[count]) + count = count + 1 + + +def check_wordpiece_tokenizer_with_offsets(first, last, expect_str, expected_offsets_start, expected_offsets_limit, + vocab_list, unknown_token='[UNK]', max_bytes_per_token=100): + dataset = ds.TextFileDataset(WORDPIECE_TOKENIZER_FILE, shuffle=False) + if first > 1: + dataset = dataset.skip(first - 1) + if last >= first: + dataset = dataset.take(last - first + 1) + vocab = text.Vocab.from_list(vocab_list) + tokenizer_op = text.WordpieceTokenizer(vocab=vocab, with_offsets=True, unknown_token=unknown_token, + max_bytes_per_token=max_bytes_per_token) + dataset = dataset.map(input_columns=['text'], output_columns=['token', 'offsets_start', 'offsets_limit'], + columns_order=['token', 'offsets_start', 'offsets_limit'], operations=tokenizer_op) + count = 0 + for i in dataset.create_dict_iterator(): + token = text.to_str(i['token']) + logger.info("Out:", token) + logger.info("Exp:", expect_str[count]) + np.testing.assert_array_equal(token, expect_str[count]) + np.testing.assert_array_equal(i['offsets_start'], expected_offsets_start[count]) + np.testing.assert_array_equal(i['offsets_limit'], expected_offsets_limit[count]) + count = count + 1 + + +def test_wordpiece_tokenizer_default(): + """ + Test WordpieceTokenizer + """ + for paras in test_paras: + check_wordpiece_tokenizer_default(**paras) + + +def test_wordpiece_tokenizer_with_offsets(): + """ + Test WordpieceTokenizer + """ + for paras in test_paras: + check_wordpiece_tokenizer_with_offsets(**paras) + + +if __name__ == '__main__': + test_wordpiece_tokenizer_default() + test_wordpiece_tokenizer_with_offsets() diff --git a/tests/ut/python/dataset/test_tfreader_op.py b/tests/ut/python/dataset/test_tfreader_op.py index 5948b1e4c1..f57c387b35 100644 --- a/tests/ut/python/dataset/test_tfreader_op.py +++ b/tests/ut/python/dataset/test_tfreader_op.py @@ -12,21 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +""" +Test TFRecordDataset Ops +""" import numpy as np import pytest -from util import save_and_check import mindspore.common.dtype as mstype import mindspore.dataset as ds from mindspore import log as logger +from util import save_and_check_dict FILES = ["../data/dataset/testTFTestAllTypes/test.data"] DATASET_ROOT = "../data/dataset/testTFTestAllTypes/" SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json" +DATA_FILES2 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data", + "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"] +SCHEMA_FILE2 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json" GENERATE_GOLDEN = False -def test_case_tf_shape(): +def test_tfrecord_shape(): + logger.info("test_tfrecord_shape") schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaRank0.json" ds1 = ds.TFRecordDataset(FILES, schema_file) ds1 = ds1.batch(2) @@ -36,7 +45,8 @@ def test_case_tf_shape(): assert len(output_shape[-1]) == 1 -def test_case_tf_read_all_dataset(): +def test_tfrecord_read_all_dataset(): + logger.info("test_tfrecord_read_all_dataset") schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json" ds1 = ds.TFRecordDataset(FILES, schema_file) assert ds1.get_dataset_size() == 12 @@ -46,7 +56,8 @@ def test_case_tf_read_all_dataset(): assert count == 12 -def test_case_num_samples(): +def test_tfrecord_num_samples(): + logger.info("test_tfrecord_num_samples") schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8) assert ds1.get_dataset_size() == 8 @@ -56,7 +67,8 @@ def test_case_num_samples(): assert count == 8 -def test_case_num_samples2(): +def test_tfrecord_num_samples2(): + logger.info("test_tfrecord_num_samples2") schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" ds1 = ds.TFRecordDataset(FILES, schema_file) assert ds1.get_dataset_size() == 7 @@ -66,42 +78,41 @@ def test_case_num_samples2(): assert count == 7 -def test_case_tf_shape_2(): +def test_tfrecord_shape2(): + logger.info("test_tfrecord_shape2") ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE) ds1 = ds1.batch(2) output_shape = ds1.output_shapes() assert len(output_shape[-1]) == 2 -def test_case_tf_file(): - logger.info("reading data from: {}".format(FILES[0])) - parameters = {"params": {}} +def test_tfrecord_files_basic(): + logger.info("test_tfrecord_files_basic") data = ds.TFRecordDataset(FILES, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) - filename = "tfreader_result.npz" - save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN) + filename = "tfrecord_files_basic.npz" + save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) -def test_case_tf_file_no_schema(): - logger.info("reading data from: {}".format(FILES[0])) - parameters = {"params": {}} +def test_tfrecord_no_schema(): + logger.info("test_tfrecord_no_schema") data = ds.TFRecordDataset(FILES, shuffle=ds.Shuffle.FILES) - filename = "tf_file_no_schema.npz" - save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN) + filename = "tfrecord_no_schema.npz" + save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) -def test_case_tf_file_pad(): - logger.info("reading data from: {}".format(FILES[0])) - parameters = {"params": {}} +def test_tfrecord_pad(): + logger.info("test_tfrecord_pad") schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaPadBytes10.json" data = ds.TFRecordDataset(FILES, schema_file, shuffle=ds.Shuffle.FILES) - filename = "tf_file_padBytes10.npz" - save_and_check(data, parameters, filename, generate_golden=GENERATE_GOLDEN) + filename = "tfrecord_pad_bytes10.npz" + save_and_check_dict(data, filename, generate_golden=GENERATE_GOLDEN) -def test_tf_files(): +def test_tfrecord_read_files(): + logger.info("test_tfrecord_read_files") pattern = DATASET_ROOT + "/test.data" data = ds.TFRecordDataset(pattern, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) assert sum([1 for _ in data]) == 12 @@ -123,7 +134,19 @@ def test_tf_files(): assert sum([1 for _ in data]) == 24 -def test_tf_record_schema(): +def test_tfrecord_multi_files(): + logger.info("test_tfrecord_multi_files") + data1 = ds.TFRecordDataset(DATA_FILES2, SCHEMA_FILE2, shuffle=False) + data1 = data1.repeat(1) + num_iter = 0 + for _ in data1.create_dict_iterator(): + num_iter += 1 + + assert num_iter == 12 + + +def test_tfrecord_schema(): + logger.info("test_tfrecord_schema") schema = ds.Schema() schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2]) @@ -142,7 +165,8 @@ def test_tf_record_schema(): assert np.array_equal(t1, t2) -def test_tf_record_shuffle(): +def test_tfrecord_shuffle(): + logger.info("test_tfrecord_shuffle") ds.config.set_seed(1) data1 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.GLOBAL) data2 = ds.TFRecordDataset(FILES, schema=SCHEMA_FILE, shuffle=ds.Shuffle.FILES) @@ -153,7 +177,8 @@ def test_tf_record_shuffle(): assert np.array_equal(t1, t2) -def test_tf_record_shard(): +def test_tfrecord_shard(): + logger.info("test_tfrecord_shard") tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] @@ -181,7 +206,8 @@ def test_tf_record_shard(): assert set(worker2_res) == set(worker1_res) -def test_tf_shard_equal_rows(): +def test_tfrecord_shard_equal_rows(): + logger.info("test_tfrecord_shard_equal_rows") tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data", "../data/dataset/tf_file_dataset/test3.data", "../data/dataset/tf_file_dataset/test4.data"] @@ -209,7 +235,8 @@ def test_tf_shard_equal_rows(): assert len(worker4_res) == 40 -def test_case_tf_file_no_schema_columns_list(): +def test_tfrecord_no_schema_columns_list(): + logger.info("test_tfrecord_no_schema_columns_list") data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"]) row = data.create_dict_iterator().get_next() assert row["col_sint16"] == [-32768] @@ -219,7 +246,8 @@ def test_case_tf_file_no_schema_columns_list(): assert "col_sint32" in str(info.value) -def test_tf_record_schema_columns_list(): +def test_tfrecord_schema_columns_list(): + logger.info("test_tfrecord_schema_columns_list") schema = ds.Schema() schema.add_column('col_1d', de_type=mstype.int64, shape=[2]) schema.add_column('col_2d', de_type=mstype.int64, shape=[2, 2]) @@ -238,7 +266,8 @@ def test_tf_record_schema_columns_list(): assert "col_sint32" in str(info.value) -def test_case_invalid_files(): +def test_tfrecord_invalid_files(): + logger.info("test_tfrecord_invalid_files") valid_file = "../data/dataset/testTFTestAllTypes/test.data" invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt" files = [invalid_file, valid_file, SCHEMA_FILE] @@ -266,19 +295,20 @@ def test_case_invalid_files(): if __name__ == '__main__': - test_case_tf_shape() - test_case_tf_read_all_dataset() - test_case_num_samples() - test_case_num_samples2() - test_case_tf_shape_2() - test_case_tf_file() - test_case_tf_file_no_schema() - test_case_tf_file_pad() - test_tf_files() - test_tf_record_schema() - test_tf_record_shuffle() - test_tf_record_shard() - test_tf_shard_equal_rows() - test_case_tf_file_no_schema_columns_list() - test_tf_record_schema_columns_list() - test_case_invalid_files() + test_tfrecord_shape() + test_tfrecord_read_all_dataset() + test_tfrecord_num_samples() + test_tfrecord_num_samples2() + test_tfrecord_shape2() + test_tfrecord_files_basic() + test_tfrecord_no_schema() + test_tfrecord_pad() + test_tfrecord_read_files() + test_tfrecord_multi_files() + test_tfrecord_schema() + test_tfrecord_shuffle() + test_tfrecord_shard() + test_tfrecord_shard_equal_rows() + test_tfrecord_no_schema_columns_list() + test_tfrecord_schema_columns_list() + test_tfrecord_invalid_files() diff --git a/tests/ut/python/dataset/test_tokenizer.py b/tests/ut/python/dataset/test_tokenizer.py deleted file mode 100644 index 2ec988d8dc..0000000000 --- a/tests/ut/python/dataset/test_tokenizer.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Testing UnicodeCharTokenizer op in DE -""" -import numpy as np -import mindspore.dataset as ds -from mindspore import log as logger -import mindspore.dataset.text as nlp - -DATA_FILE = "../data/dataset/testTokenizerData/1.txt" -NORMALIZE_FILE = "../data/dataset/testTokenizerData/normalize.txt" -REGEX_REPLACE_FILE = "../data/dataset/testTokenizerData/regex_replace.txt" -REGEX_TOKENIZER_FILE = "../data/dataset/testTokenizerData/regex_tokenizer.txt" - - -def split_by_unicode_char(input_strs): - """ - Split utf-8 strings to unicode characters - """ - out = [] - for s in input_strs: - out.append([c for c in s]) - return out - - -def test_unicode_char_tokenizer(): - """ - Test UnicodeCharTokenizer - """ - input_strs = ("Welcome to Beijing!", "北京欢迎您!", "我喜欢English!", " ") - dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) - tokenizer = nlp.UnicodeCharTokenizer() - dataset = dataset.map(operations=tokenizer) - tokens = [] - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']).tolist() - tokens.append(text) - logger.info("The out tokens is : {}".format(tokens)) - assert split_by_unicode_char(input_strs) == tokens - - -def test_whitespace_tokenizer(): - """ - Test WhitespaceTokenizer - """ - whitespace_strs = [["Welcome", "to", "Beijing!"], - ["北京欢迎您!"], - ["我喜欢English!"], - [""]] - dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) - tokenizer = nlp.WhitespaceTokenizer() - dataset = dataset.map(operations=tokenizer) - tokens = [] - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']).tolist() - tokens.append(text) - logger.info("The out tokens is : {}".format(tokens)) - assert whitespace_strs == tokens - - -def test_unicode_script_tokenizer(): - """ - Test UnicodeScriptTokenizer when para keep_whitespace=False - """ - unicode_script_strs = [["Welcome", "to", "Beijing", "!"], - ["北京欢迎您", "!"], - ["我喜欢", "English", "!"], - [""]] - dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) - tokenizer = nlp.UnicodeScriptTokenizer(keep_whitespace=False) - dataset = dataset.map(operations=tokenizer) - - tokens = [] - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']).tolist() - tokens.append(text) - logger.info("The out tokens is : {}".format(tokens)) - assert unicode_script_strs == tokens - - -def test_unicode_script_tokenizer2(): - """ - Test UnicodeScriptTokenizer when para keep_whitespace=True - """ - unicode_script_strs2 = [["Welcome", " ", "to", " ", "Beijing", "!"], - ["北京欢迎您", "!"], - ["我喜欢", "English", "!"], - [" "]] - dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) - tokenizer = nlp.UnicodeScriptTokenizer(keep_whitespace=True) - dataset = dataset.map(operations=tokenizer) - tokens = [] - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']).tolist() - tokens.append(text) - logger.info("The out tokens is :", tokens) - assert unicode_script_strs2 == tokens - - -def test_case_fold(): - """ - Test CaseFold - """ - expect_strs = ["welcome to beijing!", "北京欢迎您!", "我喜欢english!", " "] - dataset = ds.TextFileDataset(DATA_FILE, shuffle=False) - op = nlp.CaseFold() - dataset = dataset.map(operations=op) - - lower_strs = [] - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']).tolist() - lower_strs.append(text) - assert lower_strs == expect_strs - - -def test_normalize_utf8(): - """ - Test NormalizeUTF8 - """ - - def normalize(normalize_form): - dataset = ds.TextFileDataset(NORMALIZE_FILE, shuffle=False) - normalize = nlp.NormalizeUTF8(normalize_form=normalize_form) - dataset = dataset.map(operations=normalize) - out_bytes = [] - out_texts = [] - for i in dataset.create_dict_iterator(): - out_bytes.append(i['text']) - out_texts.append(nlp.to_str(i['text']).tolist()) - logger.info("The out bytes is : ", out_bytes) - logger.info("The out texts is: ", out_texts) - return out_bytes - - expect_normlize_data = [ - # NFC - [b'\xe1\xb9\xa9', b'\xe1\xb8\x8d\xcc\x87', b'q\xcc\xa3\xcc\x87', - b'\xef\xac\x81', b'2\xe2\x81\xb5', b'\xe1\xba\x9b\xcc\xa3'], - # NFKC - [b'\xe1\xb9\xa9', b'\xe1\xb8\x8d\xcc\x87', b'q\xcc\xa3\xcc\x87', - b'fi', b'25', b'\xe1\xb9\xa9'], - # NFD - [b's\xcc\xa3\xcc\x87', b'd\xcc\xa3\xcc\x87', b'q\xcc\xa3\xcc\x87', - b'\xef\xac\x81', b'2\xe2\x81\xb5', b'\xc5\xbf\xcc\xa3\xcc\x87'], - # NFKD - [b's\xcc\xa3\xcc\x87', b'd\xcc\xa3\xcc\x87', b'q\xcc\xa3\xcc\x87', - b'fi', b'25', b's\xcc\xa3\xcc\x87'] - ] - assert normalize(nlp.utils.NormalizeForm.NFC) == expect_normlize_data[0] - assert normalize(nlp.utils.NormalizeForm.NFKC) == expect_normlize_data[1] - assert normalize(nlp.utils.NormalizeForm.NFD) == expect_normlize_data[2] - assert normalize(nlp.utils.NormalizeForm.NFKD) == expect_normlize_data[3] - - -def test_regex_replace(): - """ - Test RegexReplace - """ - - def regex_replace(first, last, expect_str, pattern, replace): - dataset = ds.TextFileDataset(REGEX_REPLACE_FILE, shuffle=False) - if first > 1: - dataset = dataset.skip(first - 1) - if last >= first: - dataset = dataset.take(last - first + 1) - replace_op = nlp.RegexReplace(pattern, replace) - dataset = dataset.map(operations=replace_op) - out_text = [] - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']).tolist() - out_text.append(text) - logger.info("Out:", out_text) - logger.info("Exp:", expect_str) - assert expect_str == out_text - - regex_replace(1, 2, ['H____ W____', "L__'_ G_"], "\\p{Ll}", '_') - regex_replace(3, 5, ['hello', 'world', '31:beijing'], "^(\\d:|b:)", "") - regex_replace(6, 6, ["WelcometoChina!"], "\\s+", "") - regex_replace(7, 8, ['我不想长大', 'WelcometoShenzhen!'], "\\p{Cc}|\\p{Cf}|\\s+", "") - - -def test_regex_tokenizer(): - """ - Test RegexTokenizer - """ - - def regex_tokenizer(first, last, expect_str, delim_pattern, keep_delim_pattern): - dataset = ds.TextFileDataset(REGEX_TOKENIZER_FILE, shuffle=False) - if first > 1: - dataset = dataset.skip(first - 1) - if last >= first: - dataset = dataset.take(last - first + 1) - tokenizer_op = nlp.RegexTokenizer(delim_pattern, keep_delim_pattern) - dataset = dataset.map(operations=tokenizer_op) - out_text = [] - count = 0 - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']).tolist() - np.testing.assert_array_equal(text, expect_str[count]) - count += 1 - out_text.append(text) - logger.info("Out:", out_text) - logger.info("Exp:", expect_str) - - regex_tokenizer(1, 1, [['Welcome', 'to', 'Shenzhen!']], "\\s+", "") - regex_tokenizer(1, 1, [['Welcome', ' ', 'to', ' ', 'Shenzhen!']], "\\s+", "\\s+") - regex_tokenizer(2, 2, [['北', '京', '欢', '迎', '您', '!Welcome to Beijing!']], r"\p{Han}", r"\p{Han}") - regex_tokenizer(3, 3, [['12', '¥+', '36', '¥=?']], r"[\p{P}|\p{S}]+", r"[\p{P}|\p{S}]+") - regex_tokenizer(3, 3, [['12', '36']], r"[\p{P}|\p{S}]+", "") - regex_tokenizer(3, 3, [['¥+', '¥=?']], r"[\p{N}]+", "") - - -if __name__ == '__main__': - test_unicode_char_tokenizer() - test_whitespace_tokenizer() - test_unicode_script_tokenizer() - test_unicode_script_tokenizer2() - test_case_fold() - test_normalize_utf8() - test_regex_replace() - test_regex_tokenizer() diff --git a/tests/ut/python/dataset/test_uniform_augment.py b/tests/ut/python/dataset/test_uniform_augment.py index a26b647265..e5b66696ea 100644 --- a/tests/ut/python/dataset/test_uniform_augment.py +++ b/tests/ut/python/dataset/test_uniform_augment.py @@ -16,6 +16,7 @@ Testing UniformAugment in DE """ import numpy as np +import pytest import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.c_transforms as C @@ -164,12 +165,13 @@ def test_cpp_uniform_augment_exception_pyops(num_ops=2): C.RandomRotation(degrees=45), F.Invert()] - try: + with pytest.raises(TypeError) as e: _ = C.UniformAugment(operations=transforms_ua, num_ops=num_ops) - except Exception as e: - logger.info("Got an exception in DE: {}".format(str(e))) - assert "operations" in str(e) + logger.info("Got an exception in DE: {}".format(str(e))) + assert "Argument tensor_op_5 with value" \ + " ,)" in str(e.value) def test_cpp_uniform_augment_exception_large_numops(num_ops=6): @@ -209,7 +211,7 @@ def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "num_ops" in str(e) + assert "Input num_ops must be greater than 0" in str(e) def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): @@ -229,7 +231,7 @@ def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5): except Exception as e: logger.info("Got an exception in DE: {}".format(str(e))) - assert "integer" in str(e) + assert "Argument num_ops with value 2.5 is not of type (,)" in str(e) def test_cpp_uniform_augment_random_crop_badinput(num_ops=1): diff --git a/tests/ut/python/dataset/test_vocab.py b/tests/ut/python/dataset/test_vocab.py index 35411e5c80..0545181360 100644 --- a/tests/ut/python/dataset/test_vocab.py +++ b/tests/ut/python/dataset/test_vocab.py @@ -26,7 +26,7 @@ SIMPLE_VOCAB_FILE = "../data/dataset/testVocab/simple_vocab_list.txt" def test_from_list_tutorial(): vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "), ["", ""], True) - lookup = text.Lookup(vocab) + lookup = text.Lookup(vocab, "") data = ds.TextFileDataset(DATA_FILE, shuffle=False) data = data.map(input_columns=["text"], operations=lookup) ind = 0 @@ -50,7 +50,7 @@ def test_from_file_tutorial(): def test_from_dict_tutorial(): vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "": 6}) - lookup = text.Lookup(vocab, 6) # default value is -1 + lookup = text.Lookup(vocab, "") # any unknown token will be mapped to the id of data = ds.TextFileDataset(DATA_FILE, shuffle=False) data = data.map(input_columns=["text"], operations=lookup) res = [3, 6, 2, 4, 5, 6] @@ -60,33 +60,51 @@ def test_from_dict_tutorial(): ind += 1 +def test_from_dict_exception(): + try: + vocab = text.Vocab.from_dict({"home": -1, "behind": 0}) + if not vocab: + raise ValueError("Vocab is None") + except ValueError as e: + assert "is not within the required interval" in str(e) + + def test_from_list(): def gen(texts): for word in texts.split(" "): yield (np.array(word, dtype='S'),) - def test_config(lookup_str, vocab_input, special_tokens, special_first): + def test_config(lookup_str, vocab_input, special_tokens, special_first, unknown_token): try: vocab = text.Vocab.from_list(vocab_input, special_tokens, special_first) data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"]) - data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) + data = data.map(input_columns=["text"], operations=text.Lookup(vocab, unknown_token)) res = [] for d in data.create_dict_iterator(): res.append(d["text"].item()) return res - except ValueError as e: + except (ValueError, RuntimeError, TypeError) as e: return str(e) + # test basic default config, special_token=None, unknown_token=None + assert test_config("w1 w2 w3", ["w1", "w2", "w3"], None, True, None) == [0, 1, 2] # test normal operations - assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], True) == [2, 3, 4, 0, 1] - assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False) == [0, 1, 2, 3, 4] - assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True) == [2, 1, 0] - assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False) == [2, 1, 0] + assert test_config("w1 w2 w3 s1 s2 ephemeral", ["w1", "w2", "w3"], ["s1", "s2"], True, "s2") == [2, 3, 4, 0, 1, 1] + assert test_config("w1 w2 w3 s1 s2", ["w1", "w2", "w3"], ["s1", "s2"], False, "s2") == [0, 1, 2, 3, 4] + assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, True, "w1") == [2, 1, 0] + assert test_config("w3 w2 w1", ["w1", "w2", "w3"], None, False, "w1") == [2, 1, 0] + # test unknown token lookup + assert test_config("w1 un1 w3 un2", ["w1", "w2", "w3"], ["", ""], True, "") == [2, 1, 4, 1] + assert test_config("w1 un1 w3 un2", ["w1", "w2", "w3"], ["", ""], False, "") == [0, 4, 2, 4] # test exceptions - assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True) - assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True) - assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True) + assert "doesn't exist in vocab." in test_config("un1", ["w1"], [], False, "unk") + assert "doesn't exist in vocab and no unknown token is specified." in test_config("un1", ["w1"], [], False, None) + assert "doesn't exist in vocab" in test_config("un1", ["w1"], [], False, None) + assert "word_list contains duplicate" in test_config("w1", ["w1", "w1"], [], True, "w1") + assert "special_tokens contains duplicate" in test_config("w1", ["w1", "w2"], ["s1", "s1"], True, "w1") + assert "special_tokens and word_list contain duplicate" in test_config("w1", ["w1", "w2"], ["s1", "w1"], True, "w1") + assert "is not of type" in test_config("w1", ["w1", "w2"], ["s1"], True, 123) def test_from_file(): @@ -99,7 +117,7 @@ def test_from_file(): vocab = text.Vocab.from_file(SIMPLE_VOCAB_FILE, vocab_size=vocab_size, special_tokens=special_tokens, special_first=special_first) data = ds.GeneratorDataset(gen(lookup_str), column_names=["text"]) - data = data.map(input_columns=["text"], operations=text.Lookup(vocab)) + data = data.map(input_columns=["text"], operations=text.Lookup(vocab, "s2")) res = [] for d in data.create_dict_iterator(): res.append(d["text"].item()) @@ -118,6 +136,7 @@ def test_from_file(): if __name__ == '__main__': + test_from_dict_exception() test_from_list_tutorial() test_from_file_tutorial() test_from_dict_tutorial() diff --git a/tests/ut/python/dataset/test_wordpiece_tokenizer.py b/tests/ut/python/dataset/test_wordpiece_tokenizer.py deleted file mode 100644 index 7934884740..0000000000 --- a/tests/ut/python/dataset/test_wordpiece_tokenizer.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Testing WordpieceTokenizer op in DE -""" -import numpy as np -import mindspore.dataset as ds -from mindspore import log as logger -import mindspore.dataset.text as nlp - -WORDPIECE_TOKENIZER_FILE = "../data/dataset/testTokenizerData/wordpiece_tokenizer.txt" - -vocab_english = [ - "book", "cholera", "era", "favor", "##ite", "my", "is", "love", "dur", "##ing", "the" -] - -vocab_chinese = [ - "我", '最', '喜', '欢', '的', '书', '是', '霍', '乱', '时', '期', '爱', '情' -] - -vocab_mix = vocab_chinese + vocab_english - -test_paras = [ - dict( - first=1, - last=10, - expect_str=[['my'], ['favor', '##ite'], ['book'], ['is'], ['love'], ['dur', '##ing'], ['the'], ['cholera'], - ['era'], ['[UNK]']], - vocab_list=vocab_english - ), - dict( - first=1, - last=10, - expect_str=[['my'], ['favor', '##ite'], ['book'], ['is'], ['love'], ['dur', '##ing'], ['the'], ['cholera'], - ['era'], ['what']], - vocab_list=vocab_english, - unknown_token="" - ), - dict( - first=1, - last=10, - expect_str=[['my'], ['[UNK]'], ['book'], ['is'], ['love'], ['[UNK]'], ['the'], ['[UNK]'], ['era'], ['[UNK]']], - vocab_list=vocab_english, - max_bytes_per_token=4 - ), - dict( - first=11, - last=25, - expect_str=[['我'], ['最'], ['喜'], ['欢'], ['的'], ['书'], ['是'], ['霍'], ['乱'], ['时'], ['期'], ['的'], ['爱'], ['情'], - ['[UNK]']], - vocab_list=vocab_chinese, - ), - dict( - first=25, - last=25, - expect_str=[['您']], - vocab_list=vocab_chinese, - unknown_token="" - ), - dict( - first=1, - last=25, - expect_str=[ - ['my'], ['favor', '##ite'], ['book'], ['is'], ['love'], ['dur', '##ing'], ['the'], ['cholera'], ['era'], - ['[UNK]'], - ['我'], ['最'], ['喜'], ['欢'], ['的'], ['书'], ['是'], ['霍'], ['乱'], ['时'], ['期'], ['的'], ['爱'], ['情'], - ['[UNK]']], - vocab_list=vocab_mix, - ), -] - - -def check_wordpiece_tokenizer(first, last, expect_str, vocab_list, unknown_token='[UNK]', max_bytes_per_token=100): - dataset = ds.TextFileDataset(WORDPIECE_TOKENIZER_FILE, shuffle=False) - if first > 1: - dataset = dataset.skip(first - 1) - if last >= first: - dataset = dataset.take(last - first + 1) - vocab = nlp.Vocab.from_list(vocab_list) - tokenizer_op = nlp.WordpieceTokenizer(vocab=vocab, unknown_token=unknown_token, - max_bytes_per_token=max_bytes_per_token) - dataset = dataset.map(operations=tokenizer_op) - count = 0 - for i in dataset.create_dict_iterator(): - text = nlp.to_str(i['text']) - logger.info("Out:", text) - logger.info("Exp:", expect_str[count]) - np.testing.assert_array_equal(text, expect_str[count]) - count = count + 1 - - -def test_wordpiece_tokenizer(): - """ - Test WordpieceTokenizer - """ - for paras in test_paras: - check_wordpiece_tokenizer(**paras) - - -if __name__ == '__main__': - test_wordpiece_tokenizer() diff --git a/tests/ut/python/dataset/util.py b/tests/ut/python/dataset/util.py index 2a8e93cd0b..11c5735406 100644 --- a/tests/ut/python/dataset/util.py +++ b/tests/ut/python/dataset/util.py @@ -288,12 +288,13 @@ def config_get_set_num_parallel_workers(num_parallel_workers_new): return num_parallel_workers_original -def visualize_with_bounding_boxes(orig, aug, plot_rows=3): +def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=3): """ Take a list of un-augmented and augmented images with "annotation" bounding boxes Plot images to compare test correct BBox augment functionality :param orig: list of original images and bboxes (without aug) :param aug: list of augmented images and bboxes + :param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "annotation" (VOC) :param plot_rows: number of rows on plot (rows = samples on one plot) :return: None """ @@ -301,9 +302,10 @@ def visualize_with_bounding_boxes(orig, aug, plot_rows=3): def add_bounding_boxes(ax, bboxes): for bbox in bboxes: rect = patches.Rectangle((bbox[0], bbox[1]), - bbox[2], bbox[3], - linewidth=1, edgecolor='r', facecolor='none') + bbox[2]*0.997, bbox[3]*0.997, + linewidth=1.80, edgecolor='r', facecolor='none') # Add the patch to the Axes + # Params to Rectangle slightly modified to prevent drawing overflow ax.add_patch(rect) # Quick check to confirm correct input parameters @@ -312,14 +314,15 @@ def visualize_with_bounding_boxes(orig, aug, plot_rows=3): if len(orig) != len(aug) or not orig: return - batch_size = int(len(orig)/plot_rows) # creates batches of images to plot together + batch_size = int(len(orig) / plot_rows) # creates batches of images to plot together split_point = batch_size * plot_rows orig, aug = np.array(orig), np.array(aug) if len(orig) > plot_rows: # Create batches of required size and add remainder to last batch - orig = np.split(orig[:split_point], batch_size) + ([orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added + orig = np.split(orig[:split_point], batch_size) + ( + [orig[split_point:]] if (split_point < orig.shape[0]) else []) # check to avoid empty arrays being added aug = np.split(aug[:split_point], batch_size) + ([aug[split_point:]] if (split_point < aug.shape[0]) else []) else: orig = [orig] @@ -334,18 +337,19 @@ def visualize_with_bounding_boxes(orig, aug, plot_rows=3): for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])): cur_ix = base_ix + x - (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) # select plotting axes based on number of image rows on plot - else case when 1 row + # select plotting axes based on number of image rows on plot - else case when 1 row + (axA, axB) = (axs[x, 0], axs[x, 1]) if (curPlot > 1) else (axs[0], axs[1]) axA.imshow(dataA["image"]) - add_bounding_boxes(axA, dataA["annotation"]) + add_bounding_boxes(axA, dataA[annot_name]) axA.title.set_text("Original" + str(cur_ix+1)) axB.imshow(dataB["image"]) - add_bounding_boxes(axB, dataB["annotation"]) + add_bounding_boxes(axB, dataB[annot_name]) axB.title.set_text("Augmented" + str(cur_ix+1)) - logger.info("Original **\n{} : {}".format(str(cur_ix+1), dataA["annotation"])) - logger.info("Augmented **\n{} : {}\n".format(str(cur_ix+1), dataB["annotation"])) + logger.info("Original **\n{} : {}".format(str(cur_ix+1), dataA[annot_name])) + logger.info("Augmented **\n{} : {}\n".format(str(cur_ix+1), dataB[annot_name])) plt.show() @@ -381,19 +385,19 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error): width = img.shape[1] if invalid_bbox_type_ == InvalidBBoxType.WidthOverflow: # use box that overflows on width - return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32) + return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.float32) if invalid_bbox_type_ == InvalidBBoxType.HeightOverflow: # use box that overflows on height - return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32) + return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.float32) if invalid_bbox_type_ == InvalidBBoxType.NegativeXY: # use box with negative xy - return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32) + return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.float32) if invalid_bbox_type_ == InvalidBBoxType.WrongShape: # use box that has incorrect shape - return img, np.array([[0, 0, width - 1]]).astype(np.uint32) + return img, np.array([[0, 0, width - 1]]).astype(np.float32) return img, bboxes try: diff --git a/tests/ut/python/ir/test_indexed_slices.py b/tests/ut/python/ir/test_indexed_slices.py index 8690183090..36dfe464cb 100644 --- a/tests/ut/python/ir/test_indexed_slices.py +++ b/tests/ut/python/ir/test_indexed_slices.py @@ -36,6 +36,8 @@ from mindspore._checkparam import Rel from mindspore.nn import Optimizer from mindspore.nn import TrainOneStepCell, WithLossCell +context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) + reduce_sum = P.ReduceSum() unsorted_segment_sum = P.UnsortedSegmentSum() transpose = P.Transpose() @@ -44,7 +46,6 @@ reshape = P.Reshape() size_op = P.Size() invert_permutation = P.InvertPermutation() logical_and = P.LogicalAnd() -context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) @constexpr def _generate_shape_index(out_shape, indices_shape, axis): @@ -103,10 +104,15 @@ def get_bprop_sparse_gather_v2(self): adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Undetermined", "Bool") -def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): - if gradient.is_indexed_slices(): - return gradient.values() + "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool") +def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param, + m, v, gradient, decay_flag): + return gradient.values() + +@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Tensor", "Bool") +def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param, + m, v, gradient, decay_flag): op_mul = P.Mul() op_square = P.Square() op_sqrt = P.Sqrt() @@ -182,7 +188,7 @@ def test_indexed_slices_make_indexed_slices(): self.dense_shape = (3, 4) def construct(self, indices, values): ret = (IndexedSlices(indices, values, self.dense_shape),) - return ret[0].is_indexed_slices() + return ret[0] indices = Tensor([[0, 0], [1, 2]]) values = Tensor([1, 2], dtype=ms.float32) MakeIndexedSlices()(indices, values) @@ -209,7 +215,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): self.network = network def construct(self, x, y): grad = grad_all(self.network)(x, y) - return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices() + return grad, grad[0], grad[1] class SparseGatherV2(nn.Cell): def __init__(self): super(SparseGatherV2, self).__init__() @@ -233,14 +239,13 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): weights = self.weights grad = grad_by_list(self.network, weights)(x) x = grad[0] - return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape() + return x, x.values(), x.indices(), x.dense_shape() class SparseGatherV2(nn.Cell): def __init__(self): super(SparseGatherV2, self).__init__() self.sparse_gatherv2 = MySparseGatherV2() self.axis = 0 - self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), - name="params", has_indexed_slices_grad=True) + self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params") def construct(self, indices): return self.sparse_gatherv2(self.params, indices, self.axis) indices = Tensor(np.array([0, 1]).astype(np.int32)) @@ -248,20 +253,6 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): network(indices) -def test_indexed_slices_is_indexed_slices(): - class MakeIndexedSlices(nn.Cell): - def __init__(self): - super(MakeIndexedSlices, self).__init__() - self.dense_shape = (3, 4) - def construct(self, indices, values): - indexed_slices = IndexedSlices(indices, values, self.dense_shape) - ret = indexed_slices.is_indexed_slices() - return ret - indices = Tensor([[0, 0], [1, 2]]) - values = Tensor([1, 2], dtype=ms.float32) - MakeIndexedSlices()(indices, values) - - def test_indexed_slices_env_get(): class Loss(nn.Cell): def __init__(self): @@ -271,7 +262,7 @@ def test_indexed_slices_env_get(): class NetWithSparseGatherV2(nn.Cell): def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True) + self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") self.gatherv2 = MySparseGatherV2() self.axis = 0 diff --git a/tests/ut/python/model/resnet.py b/tests/ut/python/model/resnet.py new file mode 100644 index 0000000000..001e1db0cf --- /dev/null +++ b/tests/ut/python/model/resnet.py @@ -0,0 +1,282 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ResNet.""" +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor + + +def _weight_variable(shape, factor=0.01): + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + + +def _conv3x3(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 3, 3) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv1x1(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 1, 1) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv7x7(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 7, 7) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _bn_last(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _fc(in_channel, out_channel): + weight_shape = (out_channel, in_channel) + weight = _weight_variable(weight_shape) + return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) + + +class ResidualBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, + in_channel, + out_channel, + stride=1): + super(ResidualBlock, self).__init__() + + channel = out_channel // self.expansion + self.conv1 = _conv1x1(in_channel, channel, stride=1) + self.bn1 = _bn(channel) + + self.conv2 = _conv3x3(channel, channel, stride=stride) + self.bn2 = _bn(channel) + + self.conv3 = _conv1x1(channel, out_channel, stride=1) + self.bn3 = _bn_last(out_channel) + + self.relu = nn.ReLU() + + self.down_sample = False + + if stride != 1 or in_channel != out_channel: + self.down_sample = True + self.down_sample_layer = None + + if self.down_sample: + self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride), + _bn(out_channel)]) + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.down_sample: + identity = self.down_sample_layer(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResNet(nn.Cell): + """ + ResNet architecture. + + Args: + block (Cell): Block for network. + layer_nums (list): Numbers of block in different layers. + in_channels (list): Input channel in each layer. + out_channels (list): Output channel in each layer. + strides (list): Stride size in each layer. + num_classes (int): The number of classes that the training images are belonging to. + Returns: + Tensor, output tensor. + + Examples: + >>> ResNet(ResidualBlock, + >>> [3, 4, 6, 3], + >>> [64, 256, 512, 1024], + >>> [256, 512, 1024, 2048], + >>> [1, 2, 2, 2], + >>> 10) + """ + + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + strides, + num_classes): + super(ResNet, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") + + self.conv1 = _conv7x7(3, 64, stride=2) + self.bn1 = _bn(64) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=strides[0]) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=strides[1]) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=strides[2]) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=strides[3]) + + self.mean = P.ReduceMean(keep_dims=True) + self.flatten = nn.Flatten() + self.end_point = _fc(out_channels[3], num_classes) + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + + Returns: + SequentialCell, the output layer. + + Examples: + >>> _make_layer(ResidualBlock, 3, 128, 256, 2) + """ + layers = [] + + resnet_block = block(in_channel, out_channel, stride=stride) + layers.append(resnet_block) + + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1) + layers.append(resnet_block) + + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + + out = self.mean(c5, (2, 3)) + out = self.flatten(out) + out = self.end_point(out) + + return out + + +def resnet50(class_num=10): + """ + Get ResNet50 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet50 neural network. + + Examples: + >>> net = resnet50(10) + """ + return ResNet(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) + +def resnet101(class_num=1001): + """ + Get ResNet101 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet101 neural network. + + Examples: + >>> net = resnet101(1001) + """ + return ResNet(ResidualBlock, + [3, 4, 23, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) diff --git a/tests/ut/python/model/test_mix_precision.py b/tests/ut/python/model/test_mix_precision.py index d0e77f901a..f1fc2cc2f7 100644 --- a/tests/ut/python/model/test_mix_precision.py +++ b/tests/ut/python/model/test_mix_precision.py @@ -219,3 +219,31 @@ def test_dict_cast(): y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32) net = FirstNet() net(x, y) + + +def test_kwarg_cast(): + class FirstNet(nn.Cell): + def __init__(self): + super(FirstNet, self).__init__() + self.net = SecondNet().add_flags_recursive(fp16=True) + self.add = P.TensorAdd() + + def construct(self, tensor_a, tensor_b): + tensor_c = self.add(tensor_a, tensor_b) + dictionary = {"key": tensor_a} + result = self.net(key1=tensor_c, key2=dictionary) + return result + + class SecondNet(nn.Cell): + def __init__(self): + super(SecondNet, self).__init__() + self.add = P.TensorAdd() + + def construct(self, key1=1, key2=2): + tensor_d = self.add(key1, key2["key"]) + return tensor_d + + x = Tensor(np.array([1, 2.5, 3.5]), mstype.float32) + y = Tensor(np.array([4, 5.5, 6.5]), mstype.float32) + net = FirstNet() + net(x, y) diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py index b435bf65b9..03a73893c5 100644 --- a/tests/ut/python/nn/optim/test_adam.py +++ b/tests/ut/python/nn/optim/test_adam.py @@ -17,12 +17,13 @@ import numpy as np import pytest import mindspore.nn as nn -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR from mindspore.ops import operations as P +context.set_context(enable_sparse=True) class Net(nn.Cell): """ Net definition """ @@ -53,8 +54,7 @@ class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), - name="weight1", sparse_grad="sparse_key_w1") + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.axis = 0 self.gather = P.SparseGatherV2() diff --git a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py index 7f9f341a93..23aad24c47 100644 --- a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py +++ b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py @@ -27,6 +27,7 @@ from mindspore.ops import functional as F from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel +context.set_context(enable_sparse=True) adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", @@ -154,7 +155,7 @@ def test_AdamWeightDecaySparse(): class NetWithSparseGatherV2(nn.Cell): def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1") + self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1") self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") self.gatherv2 = P.SparseGatherV2() self.axis = 0 diff --git a/tests/ut/python/nn/optim/test_ftrl.py b/tests/ut/python/nn/optim/test_ftrl.py index de59dfdbad..670bebc92d 100644 --- a/tests/ut/python/nn/optim/test_ftrl.py +++ b/tests/ut/python/nn/optim/test_ftrl.py @@ -17,12 +17,13 @@ import numpy as np import mindspore.nn as nn -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import FTRL from mindspore.ops import operations as P +context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self): @@ -41,8 +42,7 @@ class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), - name="weight1", sparse_grad="sparse_key_w1") + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.axis = 0 self.gather = P.SparseGatherV2() diff --git a/tests/ut/python/nn/optim/test_lazyadam.py b/tests/ut/python/nn/optim/test_lazyadam.py index ce66b404e2..7769597140 100644 --- a/tests/ut/python/nn/optim/test_lazyadam.py +++ b/tests/ut/python/nn/optim/test_lazyadam.py @@ -17,12 +17,13 @@ import numpy as np import pytest import mindspore.nn as nn -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import LazyAdam from mindspore.ops import operations as P +context.set_context(enable_sparse=True) class Net(nn.Cell): """ Net definition """ @@ -43,8 +44,7 @@ class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), - name="weight1", sparse_grad="sparse_key_w1") + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.axis = 0 self.gather = P.SparseGatherV2() diff --git a/tests/ut/python/nn/optim/test_proximal_ada_grad.py b/tests/ut/python/nn/optim/test_proximal_ada_grad.py index c7e6d3f88a..3077896fed 100644 --- a/tests/ut/python/nn/optim/test_proximal_ada_grad.py +++ b/tests/ut/python/nn/optim/test_proximal_ada_grad.py @@ -17,12 +17,13 @@ import numpy as np import mindspore.nn as nn -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, context from mindspore.common.api import _executor from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import ProximalAdagrad from mindspore.ops import operations as P +context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self): @@ -40,8 +41,7 @@ class NetWithSparseGatherV2(nn.Cell): """ NetWithSparseGatherV2 definition """ def __init__(self): super(NetWithSparseGatherV2, self).__init__() - self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", - sparse_grad="sparse_key_w1") + self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2") self.axis = 0 self.gather = P.SparseGatherV2() diff --git a/tests/ut/python/nn/test_distribution.py b/tests/ut/python/nn/test_distribution.py new file mode 100644 index 0000000000..845c64a110 --- /dev/null +++ b/tests/ut/python/nn/test_distribution.py @@ -0,0 +1,369 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test nn.Distribution. + +Including Normal Distribution and Bernoulli Distribution. +""" +import pytest +import numpy as np + +import mindspore.nn as nn +from mindspore import dtype +from mindspore import Tensor + +def test_normal_shape_errpr(): + """ + Invalid shapes. + """ + with pytest.raises(ValueError): + nn.Normal([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32) + +def test_no_arguments(): + """ + No args passed in during initialization. + """ + n = nn.Normal() + assert isinstance(n, nn.Distribution) + b = nn.Bernoulli() + assert isinstance(b, nn.Distribution) + +def test_with_arguments(): + """ + Args passed in during initialization. + """ + n = nn.Normal([3.0], [4.0], dtype=dtype.float32) + assert isinstance(n, nn.Distribution) + b = nn.Bernoulli([0.3, 0.5], dtype=dtype.int32) + assert isinstance(b, nn.Distribution) + +class NormalProb(nn.Cell): + """ + Normal distribution: initialize with mean/sd. + """ + def __init__(self): + super(NormalProb, self).__init__() + self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value): + x = self.normal('prob', value) + y = self.normal('log_prob', value) + return x, y + +def test_normal_prob(): + """ + Test pdf/log_pdf: passing value through construct. + """ + net = NormalProb() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + pdf, log_pdf = net(value) + assert isinstance(pdf, Tensor) + assert isinstance(log_pdf, Tensor) + +class NormalProb1(nn.Cell): + """ + Normal distribution: initialize without mean/sd. + """ + def __init__(self): + super(NormalProb1, self).__init__() + self.normal = nn.Normal() + + def construct(self, value, mean, sd): + x = self.normal('prob', value, mean, sd) + y = self.normal('log_prob', value, mean, sd) + return x, y + +def test_normal_prob1(): + """ + Test pdf/logpdf: passing mean/sd, value through construct. + """ + net = NormalProb1() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mean = Tensor([0.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + pdf, log_pdf = net(value, mean, sd) + assert isinstance(pdf, Tensor) + assert isinstance(log_pdf, Tensor) + +class NormalProb2(nn.Cell): + """ + Normal distribution: initialize with mean/sd. + """ + def __init__(self): + super(NormalProb2, self).__init__() + self.normal = nn.Normal(3.0, 4.0, dtype=dtype.float32) + + def construct(self, value, mean, sd): + x = self.normal('prob', value, mean, sd) + y = self.normal('log_prob', value, mean, sd) + return x, y + +def test_normal_prob2(): + """ + Test pdf/log_pdf: passing mean/sd through construct. + Overwrite original mean/sd. + """ + net = NormalProb2() + value = Tensor([0.5, 1.0], dtype=dtype.float32) + mean = Tensor([0.0], dtype=dtype.float32) + sd = Tensor([1.0], dtype=dtype.float32) + pdf, log_pdf = net(value, mean, sd) + assert isinstance(pdf, Tensor) + assert isinstance(log_pdf, Tensor) + +class BernoulliProb(nn.Cell): + """ + Bernoulli distribution: initialize with probs. + """ + def __init__(self): + super(BernoulliProb, self).__init__() + self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) + + def construct(self, value): + return self.bernoulli('prob', value) + +class BernoulliLogProb(nn.Cell): + """ + Bernoulli distribution: initialize with probs. + """ + def __init__(self): + super(BernoulliLogProb, self).__init__() + self.bernoulli = nn.Bernoulli(0.5, dtype=dtype.int32) + + def construct(self, value): + return self.bernoulli('log_prob', value) + + +def test_bernoulli_prob(): + """ + Test pmf/log_pmf: passing value through construct. + """ + net = BernoulliProb() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + pmf = net(value) + assert isinstance(pmf, Tensor) + +def test_bernoulli_log_prob(): + """ + Test pmf/log_pmf: passing value through construct. + """ + net = BernoulliLogProb() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + log_pmf = net(value) + assert isinstance(log_pmf, Tensor) + +class BernoulliProb1(nn.Cell): + """ + Bernoulli distribution: initialize without probs. + """ + def __init__(self): + super(BernoulliProb1, self).__init__() + self.bernoulli = nn.Bernoulli() + + def construct(self, value, probs): + return self.bernoulli('prob', value, probs) + +class BernoulliLogProb1(nn.Cell): + """ + Bernoulli distribution: initialize without probs. + """ + def __init__(self): + super(BernoulliLogProb1, self).__init__() + self.bernoulli = nn.Bernoulli() + + def construct(self, value, probs): + return self.bernoulli('log_prob', value, probs) + + +def test_bernoulli_prob1(): + """ + Test pmf/log_pmf: passing probs through construct. + """ + net = BernoulliProb1() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + probs = Tensor([0.3], dtype=dtype.float32) + pmf = net(value, probs) + assert isinstance(pmf, Tensor) + +def test_bernoulli_log_prob1(): + """ + Test pmf/log_pmf: passing probs through construct. + """ + net = BernoulliLogProb1() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + probs = Tensor([0.3], dtype=dtype.float32) + log_pmf = net(value, probs) + assert isinstance(log_pmf, Tensor) + +class BernoulliProb2(nn.Cell): + """ + Bernoulli distribution: initialize with probs. + """ + def __init__(self): + super(BernoulliProb2, self).__init__() + self.bernoulli = nn.Bernoulli(0.5) + + def construct(self, value, probs): + return self.bernoulli('prob', value, probs) + +class BernoulliLogProb2(nn.Cell): + """ + Bernoulli distribution: initialize with probs. + """ + def __init__(self): + super(BernoulliLogProb2, self).__init__() + self.bernoulli = nn.Bernoulli(0.5) + + def construct(self, value, probs): + return self.bernoulli('log_prob', value, probs) + + +def test_bernoulli_prob2(): + """ + Test pmf/log_pmf: passing probs/value through construct. + Overwrite original probs. + """ + net = BernoulliProb2() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + probs = Tensor([0.3], dtype=dtype.float32) + pmf = net(value, probs) + assert isinstance(pmf, Tensor) + +def test_bernoulli_log_prob2(): + """ + Test pmf/log_pmf: passing probs/value through construct. + Overwrite original probs. + """ + net = BernoulliLogProb2() + value = Tensor([1, 0, 1, 0, 1], dtype=dtype.float32) + probs = Tensor([0.3], dtype=dtype.float32) + log_pmf = net(value, probs) + assert isinstance(log_pmf, Tensor) + + +class NormalKl(nn.Cell): + """ + Test class: kl_loss of Normal distribution. + """ + def __init__(self): + super(NormalKl, self).__init__() + self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32) + + def construct(self, x_, y_): + return self.n('kl_loss', 'Normal', x_, y_) + +class BernoulliKl(nn.Cell): + """ + Test class: kl_loss between Bernoulli distributions. + """ + def __init__(self): + super(BernoulliKl, self).__init__() + self.b = nn.Bernoulli(0.7, dtype=dtype.int32) + + def construct(self, x_): + return self.b('kl_loss', 'Bernoulli', x_) + +def test_kl(): + """ + Test kl_loss function. + """ + nor_net = NormalKl() + mean_b = np.array([1.0]).astype(np.float32) + sd_b = np.array([1.0]).astype(np.float32) + mean = Tensor(mean_b, dtype=dtype.float32) + sd = Tensor(sd_b, dtype=dtype.float32) + loss = nor_net(mean, sd) + assert isinstance(loss, Tensor) + + ber_net = BernoulliKl() + probs_b = Tensor([0.3], dtype=dtype.float32) + loss = ber_net(probs_b) + assert isinstance(loss, Tensor) + + +class NormalKlNoArgs(nn.Cell): + """ + Test class: kl_loss of Normal distribution. + No args during initialization. + """ + def __init__(self): + super(NormalKlNoArgs, self).__init__() + self.n = nn.Normal(dtype=dtype.float32) + + def construct(self, x_, y_, w_, v_): + return self.n('kl_loss', 'Normal', x_, y_, w_, v_) + +class BernoulliKlNoArgs(nn.Cell): + """ + Test class: kl_loss between Bernoulli distributions. + No args during initialization. + """ + def __init__(self): + super(BernoulliKlNoArgs, self).__init__() + self.b = nn.Bernoulli(dtype=dtype.int32) + + def construct(self, x_, y_): + return self.b('kl_loss', 'Bernoulli', x_, y_) + +def test_kl_no_args(): + """ + Test kl_loss function. + """ + nor_net = NormalKlNoArgs() + mean_b = np.array([1.0]).astype(np.float32) + sd_b = np.array([1.0]).astype(np.float32) + mean_a = np.array([2.0]).astype(np.float32) + sd_a = np.array([3.0]).astype(np.float32) + mean_b = Tensor(mean_b, dtype=dtype.float32) + sd_b = Tensor(sd_b, dtype=dtype.float32) + mean_a = Tensor(mean_a, dtype=dtype.float32) + sd_a = Tensor(sd_a, dtype=dtype.float32) + loss = nor_net(mean_b, sd_b, mean_a, sd_a) + assert isinstance(loss, Tensor) + + ber_net = BernoulliKlNoArgs() + probs_b = Tensor([0.3], dtype=dtype.float32) + probs_a = Tensor([0.7], dtype=dtype.float32) + loss = ber_net(probs_b, probs_a) + assert isinstance(loss, Tensor) + + + +class NormalBernoulli(nn.Cell): + """ + Test class: basic mean/sd function. + """ + def __init__(self): + super(NormalBernoulli, self).__init__() + self.n = nn.Normal(3.0, 4.0, dtype=dtype.float32) + self.b = nn.Bernoulli(0.5, dtype=dtype.int32) + + def construct(self): + normal_mean = self.n('mean') + normal_sd = self.n('sd') + bernoulli_mean = self.b('mean') + bernoulli_sd = self.b('sd') + return normal_mean, normal_sd, bernoulli_mean, bernoulli_sd + +def test_bascis(): + """ + Test mean/sd functionality of Normal and Bernoulli. + """ + net = NormalBernoulli() + normal_mean, normal_sd, bernoulli_mean, bernoulli_sd = net() + assert isinstance(normal_mean, Tensor) + assert isinstance(normal_sd, Tensor) + assert isinstance(bernoulli_mean, Tensor) + assert isinstance(bernoulli_sd, Tensor) diff --git a/tests/ut/python/nn/test_msssim.py b/tests/ut/python/nn/test_msssim.py new file mode 100644 index 0000000000..b85d13c927 --- /dev/null +++ b/tests/ut/python/nn/test_msssim.py @@ -0,0 +1,135 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +test msssim +""" +import numpy as np +import pytest + +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common.api import _executor + +_MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) + +class MSSSIMNet(nn.Cell): + def __init__(self, max_val=1.0, power_factors=_MSSSIM_WEIGHTS, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): + super(MSSSIMNet, self).__init__() + self.net = nn.MSSSIM(max_val, power_factors, filter_size, filter_sigma, k1, k2) + + def construct(self, img1, img2): + return self.net(img1, img2) + + +def test_compile(): + factors = (0.033, 0.033, 0.033) + net = MSSSIMNet(power_factors=factors) + img1 = Tensor(np.random.random((8, 3, 128, 128))) + img2 = Tensor(np.random.random((8, 3, 128, 128))) + _executor.compile(net, img1, img2) + + +def test_compile_grayscale(): + max_val = 255 + factors = (0.033, 0.033, 0.033) + net = MSSSIMNet(max_val=max_val, power_factors=factors) + img1 = Tensor(np.random.randint(0, 256, (8, 3, 128, 128), np.uint8)) + img2 = Tensor(np.random.randint(0, 256, (8, 3, 128, 128), np.uint8)) + _executor.compile(net, img1, img2) + + +def test_msssim_max_val_negative(): + max_val = -1 + with pytest.raises(ValueError): + _ = MSSSIMNet(max_val) + + +def test_msssim_max_val_bool(): + max_val = True + with pytest.raises(TypeError): + _ = MSSSIMNet(max_val) + + +def test_msssim_max_val_zero(): + max_val = 0 + with pytest.raises(ValueError): + _ = MSSSIMNet(max_val) + + +def test_msssim_power_factors_set(): + with pytest.raises(TypeError): + _ = MSSSIMNet(power_factors={0.033, 0.033, 0.033}) + + +def test_msssim_filter_size_float(): + with pytest.raises(TypeError): + _ = MSSSIMNet(filter_size=1.1) + + +def test_msssim_filter_size_zero(): + with pytest.raises(ValueError): + _ = MSSSIMNet(filter_size=0) + + +def test_msssim_filter_sigma_zero(): + with pytest.raises(ValueError): + _ = MSSSIMNet(filter_sigma=0.0) + + +def test_msssim_filter_sigma_negative(): + with pytest.raises(ValueError): + _ = MSSSIMNet(filter_sigma=-0.1) + + +def test_msssim_different_shape(): + shape_1 = (8, 3, 128, 128) + shape_2 = (8, 3, 256, 256) + factors = (0.033, 0.033, 0.033) + img1 = Tensor(np.random.random(shape_1)) + img2 = Tensor(np.random.random(shape_2)) + net = MSSSIMNet(power_factors=factors) + with pytest.raises(ValueError): + _executor.compile(net, img1, img2) + + +def test_msssim_different_dtype(): + dtype_1 = mstype.float32 + dtype_2 = mstype.float16 + factors = (0.033, 0.033, 0.033) + img1 = Tensor(np.random.random((8, 3, 128, 128)), dtype=dtype_1) + img2 = Tensor(np.random.random((8, 3, 128, 128)), dtype=dtype_2) + net = MSSSIMNet(power_factors=factors) + with pytest.raises(TypeError): + _executor.compile(net, img1, img2) + + +def test_msssim_invalid_5d_input(): + shape_1 = (8, 3, 128, 128) + shape_2 = (8, 3, 256, 256) + invalid_shape = (8, 3, 128, 128, 1) + factors = (0.033, 0.033, 0.033) + img1 = Tensor(np.random.random(shape_1)) + invalid_img1 = Tensor(np.random.random(invalid_shape)) + img2 = Tensor(np.random.random(shape_2)) + invalid_img2 = Tensor(np.random.random(invalid_shape)) + + net = MSSSIMNet(power_factors=factors) + with pytest.raises(ValueError): + _executor.compile(net, invalid_img1, img2) + with pytest.raises(ValueError): + _executor.compile(net, img1, invalid_img2) + with pytest.raises(ValueError): + _executor.compile(net, invalid_img1, invalid_img2) diff --git a/tests/ut/python/nn/test_ssim.py b/tests/ut/python/nn/test_ssim.py index 5cf1b0c94c..8b7e441014 100644 --- a/tests/ut/python/nn/test_ssim.py +++ b/tests/ut/python/nn/test_ssim.py @@ -78,26 +78,6 @@ def test_ssim_filter_sigma_negative(): _ = SSIMNet(filter_sigma=-0.1) -def test_ssim_k1_k2_wrong_value(): - with pytest.raises(ValueError): - _ = SSIMNet(k1=1.1) - with pytest.raises(ValueError): - _ = SSIMNet(k1=1.0) - with pytest.raises(ValueError): - _ = SSIMNet(k1=0.0) - with pytest.raises(ValueError): - _ = SSIMNet(k1=-1.0) - - with pytest.raises(ValueError): - _ = SSIMNet(k2=1.1) - with pytest.raises(ValueError): - _ = SSIMNet(k2=1.0) - with pytest.raises(ValueError): - _ = SSIMNet(k2=0.0) - with pytest.raises(ValueError): - _ = SSIMNet(k2=-1.0) - - def test_ssim_different_shape(): shape_1 = (8, 3, 16, 16) shape_2 = (8, 3, 8, 8) diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index 064512b19a..53b42b8f66 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -600,3 +600,42 @@ def test_while_tensor(): x = Tensor(np.ones([6, 8, 10], np.int32)) y = Tensor(np.ones([6, 8, 10], np.int32)) out = net(x, y) + + +def test_large_for_loop(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.flatten = P.ReLU() #nn.Flatten() + + def construct(self, x): + for elem in range(1, 19000): + x = self.flatten(x + elem) + return x + + t = Tensor(np.ones([2, 3], dtype=np.float32)) + net = Net() + net(t) + + +def test_large_for_loop_with_continue_break(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.flatten = P.ReLU() #nn.Flatten() + + def construct(self, x): + idx = 0 + for elem1 in range(200): + idx = idx + 1 + if idx < 10: + x = x + 0.5 + continue + if idx > 500: + break + x = self.flatten(x + elem1) + return x + + t = Tensor(np.ones([2, 3], dtype=np.float32)) + net = Net() + net(t) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 029d49fe1c..31ca540f74 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -649,6 +649,15 @@ def test_strided_slice_const(): assert (ret.asnumpy() == np.array([], np.float32).reshape([0, 1, 7, 8, 9, 3, 1])).all() +class ParallelConcatNet(nn.Cell): + def __init__(self): + super(ParallelConcatNet, self).__init__() + self.parallel_concat = P.ParallelConcat() + + def construct(self, x1, x2): + return self.parallel_concat((x1, x2)) + + test_case_math_ops = [ ('BitwiseAnd', { 'block': P.BitwiseAnd(), @@ -1391,6 +1400,11 @@ test_case_nn_ops = [ 'desc_const': [4], 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([1, 2, 3]).astype(np.int32))], 'desc_bprop': [[4, 2, 1, 3]]}), + ('UnsortedSegmentProd', { + 'block': P.UnsortedSegmentProd(), + 'desc_const': [4], + 'desc_inputs': [[3, 2, 1, 3], Tensor(np.array([0, 1, 0]).astype(np.int32))], + 'desc_bprop': [[4, 2, 1, 3]]}), ('DropoutGenMask', { 'block': P.DropoutGenMask(), 'desc_const': [(2, 2), Tensor(0.5, mstype.float32)], @@ -1948,6 +1962,12 @@ test_case_array_ops = [ 'desc_inputs': [[1, 3, 24, 24]], 'desc_bprop': [[1, 12, 24, 24]], }), + ('ParallelConcat', { + 'block': ParallelConcatNet(), + 'desc_inputs': [Tensor([[1, 2]], mstype.float32), + Tensor([[5, 6]], mstype.float32)], + 'skip': ['backward'], + }), ] test_case_other_ops = [ @@ -2216,7 +2236,10 @@ test_case_other_ops = [ 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), Tensor(np.array([1.2]).astype(np.float32))], 'skip': ['backward']}), - + ('PopulationCount', { + 'block': P.PopulationCount(), + 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))], + 'skip': ['backward']}), ] test_case_quant_ops = [ diff --git a/tests/ut/python/optimizer/test_python_pass.py b/tests/ut/python/optimizer/test_python_pass.py new file mode 100644 index 0000000000..c3ce3d6c4e --- /dev/null +++ b/tests/ut/python/optimizer/test_python_pass.py @@ -0,0 +1,64 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore +import mindspore.nn as nn +from mindspore import context +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P +from mindspore.common.python_pass_register import registe_pass, PyPassManager +from mindspore.common.api import _generate_pip_args +from mindspore._c_expression import generate_key, Executor_ + +context.set_context(mode=context.GRAPH_MODE) + +def get_func_graph(obj, *args, phase="predict"): + args_names, args_list = _generate_pip_args(obj, *args) + dic = dict(zip(args_names, args_list)) + key = generate_key(phase, dic) + phase_prefix = str(key[1]) + if phase == 'export': + phase = phase + '.' + phase_prefix + '.' + str(obj.create_time) + else: + phase = phase_prefix + phase + '.' + str(obj.create_time) + _executor = Executor_.get_instance() + _executor.compile(obj, args_list, phase, False) + return _executor.get_func_graph(phase) + +def test_softmax_relu(): + """ + Use python pass to transform from Softmax to ReLU. + """ + inputs = Tensor(np.ones([42]), mindspore.float16) + softmax_model = nn.Softmax() + + @registe_pass(run_only_once=True) + def softmax_relu_pass(): + softmax = P.Softmax() + relu = P.ReLU() + def pattern(x): + x = softmax(x) + return x + def target(x): + x = relu(x) + return x + return pattern, target + + transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) + ppm = PyPassManager() + ppm.unregiste(softmax_relu_pass) + assert "ReLU" in transformed_repr + assert "Softmax" not in transformed_repr diff --git a/tests/ut/python/parallel/test_embeddinglookup.py b/tests/ut/python/parallel/test_embeddinglookup.py index 4ab5f5f878..db84ab26eb 100644 --- a/tests/ut/python/parallel/test_embeddinglookup.py +++ b/tests/ut/python/parallel/test_embeddinglookup.py @@ -19,7 +19,6 @@ import mindspore.nn as nn from mindspore.common.api import _executor from mindspore.ops import operations as P from mindspore.ops import composite as C -from mindspore.ops.operations import _inner_ops as inner from mindspore import Tensor, context from tests.ut.python.ops.test_math_ops import VirtualLoss @@ -42,17 +41,15 @@ class NetWithLoss(nn.Cell): return self.loss(predict) class Net(nn.Cell): - def __init__(self, shape, offset, reduce_scatter_flag, split_num): + def __init__(self, shape, offset, strategy1=None, strategy2=None, target="Device"): super().__init__() self.index = Tensor(np.ones(shape), dtype=ms.int32) self.offset = offset - self.reduce_scatter_flag = reduce_scatter_flag - self.split_num = split_num - self.elu = inner.EmbeddingLookup() - self.mm = P.BatchMatMul() + self.elu = P.EmbeddingLookup().set_strategy(strategy1).add_prim_attr("primitive_target", target) + self.mm = P.BatchMatMul().set_strategy(strategy2) def construct(self, x, y): - out = self.elu(x, self.index, self.offset, self.reduce_scatter_flag, self.split_num) + out = self.elu(x, self.index, self.offset) out = self.mm(out, y) return out @@ -60,9 +57,7 @@ class Net(nn.Cell): def test_embeddinglookup_reducescatter_false(): shape = [8, 8] offset = 8 - reduce_scatter_flag = False - split_num = 1 - net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)) + net = NetWithLoss(Net(shape, offset)) net.set_auto_parallel() x = Tensor(np.ones([64, 32]), dtype=ms.float32) @@ -71,11 +66,9 @@ def test_embeddinglookup_reducescatter_false(): def test_embeddinglookup_reducescatter_true(): - shape = [64, 8] + shape = [8, 8] offset = 8 - reduce_scatter_flag = True - split_num = 8 - net = NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num)) + net = NetWithLoss(Net(shape, offset)) net.set_auto_parallel() x = Tensor(np.ones([64, 32]), dtype=ms.float32) @@ -86,9 +79,7 @@ def test_embeddinglookup_reducescatter_true(): def test_embeddinglookup_reducescatter_false_grad(): shape = [8, 8] offset = 8 - reduce_scatter_flag = False - split_num = 1 - net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num))) + net = GradWrap(NetWithLoss(Net(shape, offset))) net.set_auto_parallel() x = Tensor(np.ones([64, 32]), dtype=ms.float32) @@ -98,13 +89,39 @@ def test_embeddinglookup_reducescatter_false_grad(): def test_embeddinglookup_reducescatter_true_grad(): context.set_context(save_graphs=True) - shape = [64, 8] + shape = [8, 8] offset = 8 - reduce_scatter_flag = True - split_num = 8 - net = GradWrap(NetWithLoss(Net(shape, offset, reduce_scatter_flag, split_num))) + net = GradWrap(NetWithLoss(Net(shape, offset))) net.set_auto_parallel() x = Tensor(np.ones([64, 32]), dtype=ms.float32) y = Tensor(np.ones([8, 32, 8]), dtype=ms.float32) _executor.compile(net, x, y) + + +def test_embeddinglookup_semi_auto1(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + shape = [64, 32] + offset = 0 + strategy1 = ((8, 1), (1, 1)) + strategy2 = ((4, 1, 2), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU"))) + + net.set_auto_parallel() + x = Tensor(np.ones([64, 64]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) + + +def test_embeddinglookup_semi_auto2(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + shape = [64, 32] + offset = 0 + strategy1 = ((1, 8), (1, 1)) + strategy2 = ((4, 1, 2), (4, 2, 1)) + net = GradWrap(NetWithLoss(Net(shape, offset, strategy1, strategy2, "CPU"))) + + net.set_auto_parallel() + x = Tensor(np.ones([64, 64]), dtype=ms.float32) + y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) + _executor.compile(net, x, y) diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py index 5d52089cbe..2e853875bf 100644 --- a/tests/ut/python/parallel/test_gather_v2.py +++ b/tests/ut/python/parallel/test_gather_v2.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================ import numpy as np - import mindspore as ms import mindspore.nn as nn from mindspore import Tensor @@ -182,39 +181,3 @@ def test_gatherv2_auto1(): x = Tensor(np.ones([64, 32]), dtype=ms.float32) y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) _executor.compile(net, x, y) - - -def test_gatherv2_cpu0(): - context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") - strategy1 = ((8, 1), (1, 1)) - strategy2 = ((4, 2, 1), (4, 2, 1)) - net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU")) - net.set_auto_parallel() - - x = Tensor(np.ones([64, 64]), dtype=ms.float32) - y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) - _executor.compile(net, x, y) - - -def test_gatherv2_cpu1(): - context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") - strategy1 = ((16, 1), (1, 1)) - strategy2 = ((4, 2, 1), (4, 2, 1)) - net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU")) - net.set_auto_parallel() - - x = Tensor(np.ones([64, 64]), dtype=ms.float32) - y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) - _executor.compile(net, x, y) - - -def test_gatherv2_cpu2(): - context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") - strategy1 = ((1, 8), (1, 1)) - strategy2 = ((4, 2, 1), (4, 2, 1)) - net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU")) - net.set_auto_parallel() - - x = Tensor(np.ones([64, 64]), dtype=ms.float32) - y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) - _executor.compile(net, x, y) diff --git a/tests/ut/python/parallel/test_manual_gatherv2.py b/tests/ut/python/parallel/test_manual_gatherv2.py new file mode 100644 index 0000000000..21d25ae720 --- /dev/null +++ b/tests/ut/python/parallel/test_manual_gatherv2.py @@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.common.api import _executor +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer + +class Net(Cell): + def __init__(self, strategy1=None, strategy2=None, strategy3=None): + super().__init__() + self.gatherv2 = P.GatherV2().set_strategy(strategy1) + self.gatherv2.add_prim_attr("manual_split", ((1, 0), (7, 1))) + self.mul = P.Mul().set_strategy(strategy2) + self.reshape = P.Reshape() + self.matmul = P.MatMul().set_strategy(strategy3) + self.matmul.add_prim_attr("forward_reduce_scatter", True) + self.param = Parameter(initializer("ones", (8, 64), ms.float32), name="gatherv2_param") + self.mul_weight = Parameter(initializer("ones", (2, 4, 64), ms.float32), name="mul_weight") + self.matmul_weight = Parameter(initializer("ones", (256, 16), ms.float32), name="matmul_weight") + + def construct(self, x, b): + out = self.gatherv2(self.param, x, 0) + out = self.mul(out, self.mul_weight) + out = self.reshape(out, (2, 256)) + out = self.matmul(out, self.matmul_weight) + return out + +_x = Tensor(np.ones([2, 4]), dtype=ms.int32) +_b = Tensor(np.ones([64, 8]), dtype=ms.float32) + +def compile_net(net): + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + _executor.compile(train_net, _x, _b) + context.reset_auto_parallel_context() + +def test_neg_data_parallel(): + context.set_context(save_graphs=True) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=2, global_rank=0) + strategy1 = ((2, 1), (1, 2)) + strategy2 = ((1, 2, 1), (1, 2, 1)) + strategy3 = ((1, 2), (2, 1)) + net = Net(strategy1, strategy2, strategy3) + compile_net(net) diff --git a/tests/ut/python/parallel/test_sparse_gather_v2.py b/tests/ut/python/parallel/test_sparse_gather_v2.py index dd0517a08e..2d4d0c2bf2 100644 --- a/tests/ut/python/parallel/test_sparse_gather_v2.py +++ b/tests/ut/python/parallel/test_sparse_gather_v2.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ import numpy as np +import pytest import mindspore as ms import mindspore.nn as nn @@ -184,6 +185,7 @@ def test_gatherv2_auto1(): _executor.compile(net, x, y) +@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") def test_gatherv2_cpu0(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") strategy1 = ((8, 1), (1, 1)) @@ -196,6 +198,7 @@ def test_gatherv2_cpu0(): _executor.compile(net, x, y) +@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") def test_gatherv2_cpu1(): context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel") strategy1 = ((16, 1), (1, 1)) @@ -208,6 +211,7 @@ def test_gatherv2_cpu1(): _executor.compile(net, x, y) +@pytest.mark.skip(reason="The transition from GatherV2 to EmbeddingLookup needs adjusting. by lichen") def test_gatherv2_cpu2(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") strategy1 = ((1, 8), (1, 1)) diff --git a/tests/ut/python/parameter_feature/test_var_grad.py b/tests/ut/python/parameter_feature/test_var_grad.py index 7a332b1c3b..f0358394e7 100644 --- a/tests/ut/python/parameter_feature/test_var_grad.py +++ b/tests/ut/python/parameter_feature/test_var_grad.py @@ -22,7 +22,7 @@ from mindspore.common.parameter import ParameterTuple from mindspore.nn import Cell from mindspore.ops import operations as P -context.set_context(mode=context.GRAPH_MODE) +context.set_context(mode=context.GRAPH_MODE, save_graphs=True) def test_net_vargs_expand(): @@ -184,6 +184,27 @@ def test_grad_var_args_with_sens(): _ = grad_net(x, y, sens) +def test_grad_with_param_sens(): + """"test grad_with_sens parameter""" + + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.weights = ParameterTuple(net.trainable_params()) + self.net = net + self.sens = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), name='sens', requires_grad=False) + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + + def construct(self, x, y): + return self.grad(self.net, self.weights)(x, y, self.sens) + + x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + net = SecondNet() + grad_net = GradNet(net) + _ = grad_net(x, y) + + def test_var_args_grad(): class VarNet(Cell): def __init__(self, net): diff --git a/tests/ut/python/pipeline/infer/test_hypermap_specialize.py b/tests/ut/python/pipeline/infer/test_hypermap_specialize.py index 1f669f7355..c292e3662d 100644 --- a/tests/ut/python/pipeline/infer/test_hypermap_specialize.py +++ b/tests/ut/python/pipeline/infer/test_hypermap_specialize.py @@ -53,4 +53,4 @@ def test_hypermap_specialize_param(): expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32))) ret = hypermap_specialize_param() - assert ret == (expected_ret, expected_ret) + assert ret == (expected_ret, list(expected_ret)) diff --git a/tests/ut/python/pipeline/infer/test_net_infer.py b/tests/ut/python/pipeline/infer/test_net_infer.py index 6b32a7617d..9c19f213f5 100644 --- a/tests/ut/python/pipeline/infer/test_net_infer.py +++ b/tests/ut/python/pipeline/infer/test_net_infer.py @@ -45,6 +45,7 @@ def test_net_infer(): def test_assign_in_while(): + context.set_context(device_target="Ascend") context.set_context(mode=context.GRAPH_MODE) class Net(nn.Cell): def __init__(self, input_shape): diff --git a/tests/ut/python/pipeline/parse/test_cell_bprop.py b/tests/ut/python/pipeline/parse/test_cell_bprop.py new file mode 100644 index 0000000000..e896ddc9ac --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_cell_bprop.py @@ -0,0 +1,405 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_cell_bprop """ +import numpy as np +import pytest + +import mindspore as ms +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore import Parameter +from mindspore import context +from mindspore.common.initializer import initializer +from mindspore.common.tensor import Tensor +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from .....mindspore_test_framework.utils.bprop_util import bprop + + +def setup_module(module): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + +def teardown_module(module): + context.set_context(device_target="Ascend") + +class MulAdd(nn.Cell): + def __init__(self): + super(MulAdd, self).__init__() + + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result + return 2 * dout, 2 * y + + +def test_grad_mul_add(): + mul_add = MulAdd() + x = Tensor(1, dtype=ms.int32) + y = Tensor(2, dtype=ms.int32) + assert C.grad_all(mul_add)(x, y) == (2, 4) + + +class InlineMulADD(nn.Cell): + def __init__(self): + super(InlineMulADD, self).__init__() + self.mul_add = MulAdd() + self.param = 2 + + def construct(self, x, y): + return self.mul_add(x, y) + x + self.param * y + + +def test_grad_inline_mul_add(): + inline_mul_add = InlineMulADD() + x = Tensor(1, dtype=ms.int32) + y = Tensor(2, dtype=ms.int32) + assert C.grad_all(inline_mul_add)(x, y) == (3, 6) + + +class WithParameter(nn.Cell): + def __init__(self): + super(WithParameter, self).__init__() + self.param1 = Parameter(1, 'param1') + self.param2 = Parameter(2, 'param2') + + def construct(self, x, y): + return self.param1 * self.param2 * x + y + + def bprop(self, x, y, out, dout): + # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result + return self.param1 * self.param2 * dout, 2 * y + + +def test_with_param(): + with_param = WithParameter() + with pytest.raises(RuntimeError): + C.grad_all(with_param)(1, 2) + + +class WithNoBprop(nn.Cell): + def __init__(self): + super(WithNoBprop, self).__init__() + + def construct(self, x, y): + return 2 * x + y + + +def test_with_no_bprop(): + with_no_bprop = WithNoBprop() + x = Tensor(1, dtype=ms.int32) + y = Tensor(2, dtype=ms.int32) + assert C.grad_all(with_no_bprop)(x, y) == (2, 1) + + +def test_grad_in_bprop_1(): + class GradInBprop_1(nn.Cell): + def __init__(self): + super(GradInBprop_1, self).__init__() + self.relu = P.ReLU() + + def construct(self, x, y): + return self.relu(x) + + class GradInBprop_2(nn.Cell): + def __init__(self): + super(GradInBprop_2, self).__init__() + self.f = GradInBprop_1() + + def construct(self, x, y): + return self.f(x, y), C.grad_all(self.f)(x, y) + + def bprop(self, x, y, out, dout): + grads = C.grad_all(self.f)(x, y) + return out[1][0], grads[1] + + class GradInBprop_3(nn.Cell): + def __init__(self): + super(GradInBprop_3, self).__init__() + self.f = GradInBprop_2() + + def construct(self, x, y): + return self.f(x, y) + + grad_in_bprop = GradInBprop_3() + grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), + Tensor(np.ones([2, 2]).astype(np.float32))) + assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all() + + +def test_grad_in_bprop_2(): + class GradInBprop_1(nn.Cell): + def __init__(self): + super(GradInBprop_1, self).__init__() + self.relu = P.ReLU() + + def construct(self, x, y): + return self.relu(x) + + def bprop(self, x, y, out, dout): + return x * y, y + x + + class GradInBprop_2(nn.Cell): + def __init__(self): + super(GradInBprop_2, self).__init__() + self.f = GradInBprop_1() + + def construct(self, x, y): + return self.f(x, y), C.grad_all(self.f)(x, y) + + def bprop(self, x, y, out, dout): + grads = C.grad_all(self.f)(x, y) + return out[1][0], grads[1] + + class GradInBprop_3(nn.Cell): + def __init__(self): + super(GradInBprop_3, self).__init__() + self.f = GradInBprop_2() + + def construct(self, x, y): + return self.f(x, y) + + grad_in_bprop = GradInBprop_3() + grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), + Tensor(np.ones([2, 2]).astype(np.float32))) + assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all() + + +def test_grad_in_bprop_3(): + class GradInBprop_1(nn.Cell): + def __init__(self): + super(GradInBprop_1, self).__init__() + self.relu = P.ReLU() + + def construct(self, x, y): + return self.relu(x) + + class GradInBprop_2(nn.Cell): + def __init__(self): + super(GradInBprop_2, self).__init__() + self.f = GradInBprop_1() + + def construct(self, x, y): + return self.f(x, y), C.grad_all(self.f)(x, y) + + def bprop(self, x, y, out, dout): + grads = C.grad_all(self.f)(x, y) + return out[1][0], grads[1] + + class GradInBprop_3(nn.Cell): + def __init__(self): + super(GradInBprop_3, self).__init__() + self.f = GradInBprop_2() + + def construct(self, x, y): + return self.f(x, y) + + def bprop(self, x, y, out, dout): + return x + y + y + out[0], x + x + y + y + dout[0] + + grad_in_bprop = GradInBprop_3() + grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), + Tensor(np.ones([2, 2]).astype(np.float32))) + assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all() + + +class OneInputBprop(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.ReLU() + + def construct(self, x): + return self.op(x) + + def bprop(self, x, out, dout): + return (5 * x,) + + +def test_grad_one_input_bprop(): + net = OneInputBprop() + input1 = Tensor(np.ones([2, 2]).astype(np.float32)) + grad = C.grad_all(net)(input1) + assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all() + + +class TwoInput(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, x, y): + return x * y + + +class InlineBpropTwoInput(nn.Cell): + def __init__(self): + super().__init__() + self.f = TwoInput() + + def construct(self, x, y): + return self.f(x, y), C.grad_all(self.f)(x, y) + + def bprop(self, x, y, out, dout): + grads = C.grad_all(self.f)(x, y) + return grads[0] * 2, grads[1] * 2 + + +def test_grad_inline_bprop_two_input(): + net = InlineBpropTwoInput() + input1 = Tensor(np.ones([2, 2]).astype(np.float32)) + input2 = Tensor(np.ones([2, 2]).astype(np.float32)) + grads = C.grad_all(net)(input1, input2) + assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all() + assert len(grads) == 2 + + +class TwoInputBprop(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + + def construct(self, x, y): + return self.op(x, y) + + def bprop(self, x, y, out, dout): + return 5 * x, 8 * y + + +class TwoInputWithParameter(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step") + + def construct(self, x, y): + x = self.inputdata + x + return self.op(x, y) + + +class TwoInputWithOnlyInitParameterBprop(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step") + + def construct(self, x, y): + return self.op(x, y) + + def bprop(self, x, y, out, dout): + return 5 * x, 8 * y + + +class InlineMutilTwoInputParameterCell(nn.Cell): + def __init__(self): + super().__init__() + self.f1 = TwoInputBprop() + self.f2 = TwoInput() + self.f3 = TwoInputWithParameter() + self.f4 = TwoInputWithOnlyInitParameterBprop() + + def construct(self, x, y): + output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y) + return output + + +def test_grad_inline_bprop_multi_input(): + net = InlineMutilTwoInputParameterCell() + input1 = Tensor(np.ones([2, 2]).astype(np.float32)) + input2 = Tensor(np.ones([2, 2]).astype(np.float32)) + net.init_parameters_data() + grads = C.grad_all(net)(input1, input2) + assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all() + assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all() + assert len(grads) == 2 + + +class MulAddWithParam(nn.Cell): + def __init__(self): + super(MulAddWithParam, self).__init__() + self.mul_add = MulAdd() + self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param') + + def construct(self, x): + return self.mul_add(self.param, x) + + +def test_refkey_bprop(): + net = MulAddWithParam() + input_data = Tensor(np.array([2, 2], np.float32)) + grads = bprop(net, input_data, + grads_wrt_outputs=(Tensor(np.ones([1, 2]).astype(np.float32))), + wrt=['params', 'inputs'], + params=net.trainable_params()) + assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all() + assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() + + +class MulAddWithWrongOutputNum(nn.Cell): + def __init__(self): + super(MulAddWithWrongOutputNum, self).__init__() + + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + return (2 * dout,) + + +def test_grad_mul_add_with_wrong_output_num(): + context.set_context(check_bprop=True) + mul_add = MulAddWithWrongOutputNum() + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, 2) + + +class MulAddWithWrongOutputType(nn.Cell): + def __init__(self): + super(MulAddWithWrongOutputType, self).__init__() + + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + return 2 * dout, 2 + + +def test_grad_mul_add_with_wrong_output_type(): + context.set_context(check_bprop=True) + mul_add = MulAddWithWrongOutputType() + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) + + +class MulAddWithWrongOutputShape(nn.Cell): + def __init__(self): + super(MulAddWithWrongOutputShape, self).__init__() + self.ones = Tensor(np.ones([2,])) + + def construct(self, x, y): + return 2 * x + y + + def bprop(self, x, y, out, dout): + return 2, self.ones + + +def test_grad_mul_add_with_wrong_output_shape(): + context.set_context(check_bprop=True) + mul_add = MulAddWithWrongOutputShape() + with pytest.raises(TypeError): + C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) diff --git a/tests/ut/python/pipeline/parse/test_enumerate.py b/tests/ut/python/pipeline/parse/test_enumerate.py index cd808696f1..37f9c603df 100644 --- a/tests/ut/python/pipeline/parse/test_enumerate.py +++ b/tests/ut/python/pipeline/parse/test_enumerate.py @@ -91,6 +91,7 @@ def test_enumerate_tuple_parameter(): index_sum += i ret += (j,) return index_sum, ret + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) net = Net() net(x, x, x) @@ -127,10 +128,12 @@ def test_enumerate_tuple_parameter_1(): index_sum += i[0] ret += (i[1],) return index_sum, ret + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) net = Net() net(x, x, x) + def test_enumerate_tuple_const_2(): class Net(nn.Cell): def __init__(self): @@ -162,20 +165,37 @@ def test_enumerate_tuple_parameter_2(): index_sum += i[0] ret += (i[1],) return index_sum, ret + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) net = Net() net(x, x, x) -def test_enumerate_parameter_type_error(): +def test_enumerate_first_input_type_error(): class Net(nn.Cell): def __init__(self): super(Net, self).__init__() def construct(self, x): return enumerate(x) + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) net = Net() with pytest.raises(TypeError) as ex: net(x) - assert "For 'enumerate', the input parameter should be tuple or list" in str(ex.value) + assert "For 'enumerate', the 'first input'" in str(ex.value) + + +def test_enumerate_start_type_error(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + return enumerate(x, start=1.2) + + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + net = Net() + with pytest.raises(TypeError) as ex: + net((x, x)) + assert "For 'enumerate', the 'start'" in str(ex.value) diff --git a/tests/ut/python/pipeline/parse/test_for_stmt.py b/tests/ut/python/pipeline/parse/test_for_stmt.py index 4930dae796..748c73e873 100644 --- a/tests/ut/python/pipeline/parse/test_for_stmt.py +++ b/tests/ut/python/pipeline/parse/test_for_stmt.py @@ -17,6 +17,9 @@ from dataclasses import dataclass import numpy as np from mindspore import Tensor, Model, context +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F from mindspore.nn import Cell from mindspore.nn import ReLU from ...ut_filter import non_graph_engine @@ -66,3 +69,58 @@ def function_access_base(number): def test_access_0040(): """ test_access_0040 """ function_access_base(2) + + +class OpSeqNet(Cell): + def __init__(self, loop_count=1): + super().__init__() + self.loop_count = loop_count + self.op_seq = (P.Sqrt(), P.Reciprocal(), P.Square()) + + def construct(self, x): + t = x + for op in self.op_seq: + t = op(t) + return t + + +def test_op_seq_test(): + context.set_context(mode=context.GRAPH_MODE) + net = OpSeqNet() + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me = Tensor(input_np) + net(input_me) + + +_grad_fusion = C.MultitypeFuncGraph("grad_fushion") + + +@_grad_fusion.register("Tensor", "Function") +def tensor_grad_scale(x, op): + return op(x) + + +class AllReduceTest(Cell): + def __init__(self, loop_count=1): + super().__init__() + self.op_list = () + self.fushion_flag = [0, 1, 1, 0, 1, 0] + for i in self.fushion_flag: + op = P.AllReduce().add_prim_attr('fusion', i) + self.op_list = self.op_list + (op,) + self.hyper_map = C.HyperMap() + + def construct(self, x): + ret = () + for _ in self.fushion_flag: + ret = ret + (x,) + fushion_res = self.hyper_map(F.partial(_grad_fusion), ret, self.op_list) + return fushion_res + + +def test_allreduce_fushio_test(): + context.set_context(mode=context.GRAPH_MODE) + net = AllReduceTest() + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me = Tensor(input_np) + net(input_me) diff --git a/tests/ut/python/pipeline/parse/test_parse.py b/tests/ut/python/pipeline/parse/test_parse.py index bbc32d0728..b295adcbec 100644 --- a/tests/ut/python/pipeline/parse/test_parse.py +++ b/tests/ut/python/pipeline/parse/test_parse.py @@ -19,21 +19,27 @@ @Desc : """ import logging +import pytest import numpy as np import mindspore as ms import mindspore.nn as nn from mindspore import Tensor +from mindspore import context +from mindspore.ops import composite as C from mindspore.common.api import ms_function, _executor +from mindspore.ops._grad.grad_base import bprop_getters +from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer from mindspore.ops.functional import tensor_add from ...ut_filter import non_graph_engine -# pylint: disable=W0613 +# pylint: disable=W0613,W0612 # W0613: unused-argument log = logging.getLogger("test") log.setLevel(level=logging.ERROR) +context.set_context(mode=context.GRAPH_MODE) # Test case: use the parse obj interface use default parameter @@ -135,3 +141,113 @@ def test_net_with_ndarray(): input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32') net(ms.Tensor(input_data)) + + +def test_bprop_with_wrong_output_num(): + context.set_context(check_bprop=True) + class BpropWithWrongOutputNum(PrimitiveWithInfer): + @prim_attr_register + def __init__(self): + super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum') + + def __call__(self, x, y): + return x + + def infer_shape(self, x_shape, yshape): + return x_shape + + def infer_dtype(self, x_type, y_type): + return x_type + + @bprop_getters.register(BpropWithWrongOutputNum) + def get_bprop_with_wrong_output_num(self): + """Generate bprop for BpropWithWrongOutputNum""" + + def bprop(x, y, out, dout): + return (dout,) + + return bprop + + class BpropWithWrongOutputNumCell(nn.Cell): + def __init__(self): + super(BpropWithWrongOutputNumCell, self).__init__() + + def construct(self, x, y): + return BpropWithWrongOutputNum()(x, y) + + with pytest.raises(TypeError): + C.grad_all(BpropWithWrongOutputNumCell())(1, 2) + +def test_bprop_with_wrong_output_type(): + context.set_context(check_bprop=True) + class BpropWithWrongOutputType(PrimitiveWithInfer): + @prim_attr_register + def __init__(self): + super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType') + + def __call__(self, x): + return x + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_type): + return x_type + + @bprop_getters.register(BpropWithWrongOutputType) + def get_bprop_with_wrong_output_type(self): + """Generate bprop for BpropWithWrongOutputType""" + + def bprop(x, out, dout): + return (1,) + + return bprop + + class BpropWithWrongOutputTypeCell(nn.Cell): + def __init__(self): + super(BpropWithWrongOutputTypeCell, self).__init__() + + def construct(self, x): + return BpropWithWrongOutputType()(x) + + with pytest.raises(TypeError): + C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) + + +def test_bprop_with_wrong_output_shape(): + context.set_context(check_bprop=True) + class BpropWithWrongOutputShape(PrimitiveWithInfer): + @prim_attr_register + def __init__(self): + super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape') + + def __call__(self, x): + return x + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_type): + return x_type + + @bprop_getters.register(BpropWithWrongOutputShape) + def get_bprop_with_wrong_output_shape(self): + """Generate bprop for BpropWithWrongOutputShape""" + ones = Tensor(np.ones([2,]).astype(np.int32)) + + def bprop(x, out, dout): + return (ones,) + + return bprop + + class BpropWithWrongOutputShapeCell(nn.Cell): + def __init__(self): + super(BpropWithWrongOutputShapeCell, self).__init__() + + def construct(self, x): + return BpropWithWrongOutputShape()(x) + + with pytest.raises(TypeError): + net = BpropWithWrongOutputShapeCell() + net.set_grad() + C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) diff --git a/tests/ut/python/pynative_mode/nn/test_tensor_operation.py b/tests/ut/python/pynative_mode/nn/test_tensor_operation.py index 306ba63c9f..eb8610bdf1 100644 --- a/tests/ut/python/pynative_mode/nn/test_tensor_operation.py +++ b/tests/ut/python/pynative_mode/nn/test_tensor_operation.py @@ -78,3 +78,9 @@ def test_tensor_imul(): y = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32)) x *= y assert x.asnumpy()[0][0][0][0] == 1.0 + + +def test_tensor_pow(): + x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2) + y = x ** 3 + assert y.asnumpy()[0][0][0][0] == 8.0 diff --git a/tests/ut/python/pynative_mode/ops/test_grad.py b/tests/ut/python/pynative_mode/ops/test_grad.py index 8d880a86d9..f028e91beb 100644 --- a/tests/ut/python/pynative_mode/ops/test_grad.py +++ b/tests/ut/python/pynative_mode/ops/test_grad.py @@ -89,7 +89,11 @@ def test_scalar_cast_grad(): output = F.scalar_cast(x, input_t) return output - gfn = C.grad(fx_cast)(input_x) + @ms_function + def grad_fx_cast(input_x): + return C.grad(fx_cast)(input_x) + + gfn = grad_fx_cast(input_x) expect_dx = 1 assert gfn == expect_dx @@ -133,25 +137,6 @@ def test_transpose_grad(): assert np.all(gout[0].asnumpy() == expect) -@non_graph_engine -def test_squeeze_grad(): - """ test_squeeze_grad """ - input_tensor = Tensor(np.ones(shape=[3, 2, 1])) - squeeze = P.Squeeze(2) - - def fn(x): - output = squeeze(x) - return output - - out = fn(input_tensor) - gfn = grad_all_with_sens(fn) - sens = Tensor(np.ones_like(out.asnumpy())) - args = [input_tensor, sens] - gout = gfn(*args) - expect = np.ones([3, 2, 1]) - assert np.all(gout[0].asnumpy() == expect) - - def test_select_grad(): """ test_select_grad """ select = P.Select() @@ -176,6 +161,25 @@ def test_select_grad(): assert np.all(gout[2].asnumpy() == expect_y) +@non_graph_engine +def test_squeeze_grad(): + """ test_squeeze_grad """ + input_tensor = Tensor(np.ones(shape=[3, 2, 1])) + squeeze = P.Squeeze(2) + + def fn(x): + output = squeeze(x) + return output + + out = fn(input_tensor) + gfn = grad_all_with_sens(fn) + sens = Tensor(np.ones_like(out.asnumpy())) + args = [input_tensor, sens] + gout = gfn(*args) + expect = np.ones([3, 2, 1]) + assert np.all(gout[0].asnumpy() == expect) + + def test_SubGrad(): """ test_SubGrad """ input_x = Tensor(np.array([[2, 2]])) diff --git a/tests/ut/python/pynative_mode/test_cell_bprop.py b/tests/ut/python/pynative_mode/test_cell_bprop.py deleted file mode 100644 index 09a096a090..0000000000 --- a/tests/ut/python/pynative_mode/test_cell_bprop.py +++ /dev/null @@ -1,396 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" test_cell_bprop """ -import numpy as np -import pytest - -import mindspore.common.dtype as mstype -import mindspore.nn as nn -from mindspore import Parameter -from mindspore import context -from mindspore.common.initializer import initializer -from mindspore.common.tensor import Tensor -from mindspore.ops import composite as C -from mindspore.ops import operations as P -from ....mindspore_test_framework.utils.bprop_util import bprop - - -def setup_module(module): - context.set_context(mode=context.PYNATIVE_MODE) - - -class MulAdd(nn.Cell): - def __init__(self): - super(MulAdd, self).__init__() - - def construct(self, x, y): - return 2 * x + y - - def bprop(self, x, y, out, dout): - # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result - return 2 * dout, 2 * y - - -def test_grad_mul_add(): - mul_add = MulAdd() - assert C.grad_all(mul_add)(1, 2) == (2, 4) - - -class InlineMulADD(nn.Cell): - def __init__(self): - super(InlineMulADD, self).__init__() - self.mul_add = MulAdd() - self.param = 2 - - def construct(self, x, y): - return self.mul_add(x, y) + x + self.param * y - - -def test_grad_inline_mul_add(): - inline_mul_add = InlineMulADD() - assert C.grad_all(inline_mul_add)(1, 2) == (3, 6) - - -class WithParameter(nn.Cell): - def __init__(self): - super(WithParameter, self).__init__() - self.param1 = Parameter(1, 'param1') - self.param2 = Parameter(2, 'param2') - - def construct(self, x, y): - return self.param1 * self.param2 * x + y - - def bprop(self, x, y, out, dout): - # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result - return self.param1 * self.param2 * dout, 2 * y - - -def test_with_param(): - with_param = WithParameter() - with pytest.raises(RuntimeError): - C.grad_all(with_param)(1, 2) - - -class WithNoBprop(nn.Cell): - def __init__(self): - super(WithNoBprop, self).__init__() - - def construct(self, x, y): - return 2 * x + y - - -def test_with_no_bprop(): - with_no_bprop = WithNoBprop() - assert C.grad_all(with_no_bprop)(1, 2) == (2, 1) - - -def test_grad_in_bprop_1(): - class GradInBprop_1(nn.Cell): - def __init__(self): - super(GradInBprop_1, self).__init__() - self.relu = P.ReLU() - - def construct(self, x, y): - return self.relu(x) - - class GradInBprop_2(nn.Cell): - def __init__(self): - super(GradInBprop_2, self).__init__() - self.f = GradInBprop_1() - - def construct(self, x, y): - return self.f(x, y), C.grad_all(self.f)(x, y) - - def bprop(self, x, y, out, dout): - grads = C.grad_all(self.f)(x, y) - return out[1][0], grads[1] - - class GradInBprop_3(nn.Cell): - def __init__(self): - super(GradInBprop_3, self).__init__() - self.f = GradInBprop_2() - - def construct(self, x, y): - return self.f(x, y) - - grad_in_bprop = GradInBprop_3() - grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), - Tensor(np.ones([2, 2]).astype(np.float32))) - assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all() - - -def test_grad_in_bprop_2(): - class GradInBprop_1(nn.Cell): - def __init__(self): - super(GradInBprop_1, self).__init__() - self.relu = P.ReLU() - - def construct(self, x, y): - return self.relu(x) - - def bprop(self, x, y, out, dout): - return x * y, y + x - - class GradInBprop_2(nn.Cell): - def __init__(self): - super(GradInBprop_2, self).__init__() - self.f = GradInBprop_1() - - def construct(self, x, y): - return self.f(x, y), C.grad_all(self.f)(x, y) - - def bprop(self, x, y, out, dout): - grads = C.grad_all(self.f)(x, y) - return out[1][0], grads[1] - - class GradInBprop_3(nn.Cell): - def __init__(self): - super(GradInBprop_3, self).__init__() - self.f = GradInBprop_2() - - def construct(self, x, y): - return self.f(x, y) - - grad_in_bprop = GradInBprop_3() - grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), - Tensor(np.ones([2, 2]).astype(np.float32))) - assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all() - - -def test_grad_in_bprop_3(): - class GradInBprop_1(nn.Cell): - def __init__(self): - super(GradInBprop_1, self).__init__() - self.relu = P.ReLU() - - def construct(self, x, y): - return self.relu(x) - - class GradInBprop_2(nn.Cell): - def __init__(self): - super(GradInBprop_2, self).__init__() - self.f = GradInBprop_1() - - def construct(self, x, y): - return self.f(x, y), C.grad_all(self.f)(x, y) - - def bprop(self, x, y, out, dout): - grads = C.grad_all(self.f)(x, y) - return out[1][0], grads[1] - - class GradInBprop_3(nn.Cell): - def __init__(self): - super(GradInBprop_3, self).__init__() - self.f = GradInBprop_2() - - def construct(self, x, y): - return self.f(x, y) - - def bprop(self, x, y, out, dout): - return x + y + y + out[0], x + x + y + y + dout[0] - - grad_in_bprop = GradInBprop_3() - grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)), - Tensor(np.ones([2, 2]).astype(np.float32))) - assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all() - - -class OneInputBprop(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.ReLU() - - def construct(self, x): - return self.op(x) - - def bprop(self, x, out, dout): - return (5 * x,) - - -def test_grad_one_input_bprop(): - net = OneInputBprop() - input1 = Tensor(np.ones([2, 2]).astype(np.float32)) - grad = C.grad_all(net)(input1) - assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all() - - -class TwoInput(nn.Cell): - def __init__(self): - super().__init__() - - def construct(self, x, y): - return x * y - - -class InlineBpropTwoInput(nn.Cell): - def __init__(self): - super().__init__() - self.f = TwoInput() - - def construct(self, x, y): - return self.f(x, y), C.grad_all(self.f)(x, y) - - def bprop(self, x, y, out, dout): - grads = C.grad_all(self.f)(x, y) - return grads[0] * 2, grads[1] * 2 - - -def test_grad_inline_bprop_two_input(): - net = InlineBpropTwoInput() - input1 = Tensor(np.ones([2, 2]).astype(np.float32)) - input2 = Tensor(np.ones([2, 2]).astype(np.float32)) - grads = C.grad_all(net)(input1, input2) - assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all() - assert len(grads) == 2 - - -class TwoInputBprop(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.Mul() - - def construct(self, x, y): - return self.op(x, y) - - def bprop(self, x, y, out, dout): - return 5 * x, 8 * y - - -class TwoInputWithParameter(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.Mul() - self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step") - - def construct(self, x, y): - x = self.inputdata + x - return self.op(x, y) - - -class TwoInputWithOnlyInitParameterBprop(nn.Cell): - def __init__(self): - super().__init__() - self.op = P.Mul() - self.inputdata = Parameter(initializer(1, (2, 2), mstype.float32), name="global_step") - - def construct(self, x, y): - return self.op(x, y) - - def bprop(self, x, y, out, dout): - return 5 * x, 8 * y - - -class InlineMutilTwoInputParameterCell(nn.Cell): - def __init__(self): - super().__init__() - self.f1 = TwoInputBprop() - self.f2 = TwoInput() - self.f3 = TwoInputWithParameter() - self.f4 = TwoInputWithOnlyInitParameterBprop() - - def construct(self, x, y): - output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y) - return output - - -def test_grad_inline_bprop_multi_input(): - net = InlineMutilTwoInputParameterCell() - input1 = Tensor(np.ones([2, 2]).astype(np.float32)) - input2 = Tensor(np.ones([2, 2]).astype(np.float32)) - net.init_parameters_data() - grads = C.grad_all(net)(input1, input2) - assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all() - assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all() - assert len(grads) == 2 - - -class MulAddWithParam(nn.Cell): - def __init__(self): - super(MulAddWithParam, self).__init__() - self.mul_add = MulAdd() - self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param') - - def construct(self, x): - return self.mul_add(self.param, x) - - -def test_refkey_bprop(): - net = MulAddWithParam() - input_data = Tensor(np.array([2, 2], np.float32)) - grads = bprop(net, input_data, - grads_wrt_outputs=(Tensor(np.ones([1, 2]).astype(np.float32))), - wrt=['params', 'inputs'], - params=net.trainable_params()) - assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all() - assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() - - -class MulAddWithWrongOutputNum(nn.Cell): - def __init__(self): - super(MulAddWithWrongOutputNum, self).__init__() - - def construct(self, x, y): - return 2 * x + y - - def bprop(self, x, y, out, dout): - return (2 * dout,) - - -def test_grad_mul_add_with_wrong_output_num(): - context.set_context(check_bprop=True) - mul_add = MulAddWithWrongOutputNum() - with pytest.raises(TypeError): - C.grad_all(mul_add)(1, 2) - - -class MulAddWithWrongOutputType(nn.Cell): - def __init__(self): - super(MulAddWithWrongOutputType, self).__init__() - - def construct(self, x, y): - return 2 * x + y - - def bprop(self, x, y, out, dout): - return 2 * dout, 2 - - -def test_grad_mul_add_with_wrong_output_type(): - context.set_context(check_bprop=True) - mul_add = MulAddWithWrongOutputType() - with pytest.raises(TypeError): - C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) - - -class MulAddWithWrongOutputShape(nn.Cell): - def __init__(self): - super(MulAddWithWrongOutputShape, self).__init__() - self.ones = Tensor(np.ones([2,])) - - def construct(self, x, y): - return 2 * x + y - - def bprop(self, x, y, out, dout): - return 2, self.ones - - -def test_grad_mul_add_with_wrong_output_shape(): - context.set_context(check_bprop=True) - mul_add = MulAddWithWrongOutputShape() - with pytest.raises(TypeError): - C.grad_all(mul_add)(1, Tensor(np.ones([2, 2]))) diff --git a/tests/ut/python/pynative_mode/test_context.py b/tests/ut/python/pynative_mode/test_context.py index 66dc0a4f58..e2d4e31412 100644 --- a/tests/ut/python/pynative_mode/test_context.py +++ b/tests/ut/python/pynative_mode/test_context.py @@ -118,6 +118,12 @@ def test_variable_memory_max_size(): context.set_context(variable_memory_max_size="3GB") +def test_print_file_path(): + """test_print_file_path""" + with pytest.raises(IOError): + context.set_context(print_file_path="./") + + def test_set_context(): """ test_set_context """ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py index 39a4c97ab9..3b99d0dc5f 100644 --- a/tests/ut/python/pynative_mode/test_framstruct.py +++ b/tests/ut/python/pynative_mode/test_framstruct.py @@ -16,6 +16,7 @@ import numpy as np import pytest +import mindspore as ms import mindspore.nn as nn from mindspore import context from mindspore.common import dtype as mstype @@ -23,8 +24,6 @@ from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.tensor import Tensor from mindspore.ops import composite as C from mindspore.ops import operations as P -from mindspore.ops._grad.grad_base import bprop_getters -from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer from ..ut_filter import non_graph_engine from ....mindspore_test_framework.utils.check_gradient import ( ms_function, check_jacobian, Tensor, NNGradChecker, @@ -156,14 +155,14 @@ def test_if_always_true(): @non_graph_engine def test_f(): """ test_f """ - res = mainf(3, 2) + res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) assert res == (2, 3) @non_graph_engine def test_grad_add_mul(): """ test_grad_add_mul """ - res = grad_add_mul(3, 2) + res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) assert res == (2, 7) @@ -262,17 +261,19 @@ def test_if_tensor(): assert res == Tensor(np.ones([1]).astype(np.int32) * 4) -@ms_function def rec(x): """ rec """ if x > 0: return rec(x - 1) return x +@ms_function +def grad_rec(input_x): + return C.grad(rec)(input_x) def test_grad_rec(): """ test_grad_rec """ - res = C.grad(rec)(10) + res = grad_rec(3) assert res == 1 @@ -282,7 +283,6 @@ def test_me_rec(): assert res == 0 -@ms_function def t2_while(x, y): out = y - x i = 0 @@ -298,8 +298,10 @@ def test_while2(): def test_grad_while2(): - res = C.grad(t2_while)(2, 3) - assert res == 3 + @ms_function + def df_t2_while(input_x, input_y): + return C.grad(t2_while)(input_x, input_y) + assert df_t2_while(2, 3) == 3 def if_test(a, b): @@ -316,7 +318,7 @@ def grad_if(x, y): def test_grad_if(): """ test_grad_if """ - assert grad_if(5, 4) == (3, 0) + assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0) # While loop is not unrolled in forward and backward graphs. @@ -421,7 +423,7 @@ def grad_while(x): def test_grad_while(): """ test_grad_while """ - assert grad_while(5) == (60,) + assert grad_while(Tensor(5, dtype=ms.int32)) == (60,) @ms_function @@ -438,8 +440,10 @@ def test_factorial(): def test_grad_factorial(): - res = C.grad(factorial)(3) - assert res == 11 + @ms_function + def df_factorial(x): + return C.grad(factorial)(x) + assert df_factorial(3) == 11 @ms_function @@ -513,7 +517,7 @@ def _for(x): ret = ret * i return ret - +@ms_function def grad_for(x): """ grad_for """ return C.grad_all(_for)(x) @@ -786,7 +790,10 @@ def multi_outputs(x, y): def test_grad_multi_outputs(): - assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4) + @ms_function + def df_multi_outputs(x, y): + return C.grad_all_with_sens(multi_outputs)(x, y, (1, 1)) + assert df_multi_outputs(2, 3) == (4, 4) @ms_function @@ -813,7 +820,7 @@ def grad_refactor_simple_1(x, y): def test_grad_refactor_simple_1(): - assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2) + assert C.grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2) def grad_refactor_simple_2(x, y, z): @@ -822,7 +829,10 @@ def grad_refactor_simple_2(x, y, z): def test_grad_refactor_simple_2(): - assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7) + x = Tensor(2, dtype=ms.int32) + y = Tensor(3, dtype=ms.int32) + z = Tensor(0, dtype=ms.int32) + assert C.grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7) def grad_refactor_1(a, b): @@ -835,7 +845,7 @@ def grad_refactor_1(a, b): def test_grad_refactor_1(): - assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2) + assert C.grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2) def grad_refactor_2(a, b): @@ -848,7 +858,7 @@ def grad_refactor_2(a, b): def test_grad_refactor_2(): - assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54) + assert C.grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54) def grad_refactor_3(a): @@ -859,7 +869,10 @@ def grad_refactor_3(a): def test_grad_refactor_3(): - assert C.grad_all(grad_refactor_3)(3) == (3,) + @ms_function + def df_refactor_3(x): + return C.grad_all(grad_refactor_3)(x) + assert df_refactor_3(3) == (3,) def grad_refactor_4(a): @@ -870,7 +883,7 @@ def grad_refactor_4(a): def test_grad_refactor_4(): - assert C.grad_all(grad_refactor_4)(4) == (3,) + assert C.grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,) def grad_refactor_5(a): @@ -881,7 +894,10 @@ def grad_refactor_5(a): def test_grad_refactor_5(): - assert C.grad_all(grad_refactor_5)(1) == (1,) + @ms_function + def df_refactor_5(x): + return C.grad_all(grad_refactor_5)(x) + assert df_refactor_5(1) == (1,) def grad_refactor_6(a, b): @@ -892,7 +908,7 @@ def grad_refactor_6(a, b): def test_grad_refactor_6(): - assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1) + assert C.grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1) def grad_refactor_while(x): @@ -904,7 +920,10 @@ def grad_refactor_while(x): def test_grad_refactor_9(): - assert C.grad_all(grad_refactor_while)(3) == (6,) + @ms_function + def df_refactor_while(input_x): + return C.grad_all(grad_refactor_while)(input_x) + assert df_refactor_while(3) == (6,) def grad_refactor__while_1(x): @@ -919,7 +938,7 @@ def grad_refactor__while_1(x): def test_grad_refactor_10(): """ test_grad_while """ - assert C.grad_all(grad_refactor__while_1)(5) == (60,) + assert C.grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,) def test_grad_refactor_11(): @@ -985,7 +1004,10 @@ def grad_refactor_14(a, b): def test_grad_refactor_14(): - assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9) + @ms_function + def df_refactor_14(x, y): + return C.grad_all(grad_refactor_14)(x, y) + assert df_refactor_14(2, 3) == (3, 9) # pylint: disable=using-constant-test @@ -1011,109 +1033,11 @@ def test_grad_if_defer_inline(): assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) -def test_bprop_with_wrong_output_num(): - context.set_context(check_bprop=True) - class BpropWithWrongOutputNum(PrimitiveWithInfer): - @prim_attr_register - def __init__(self): - super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum') - - def __call__(self, x, y): - return x - - def infer_shape(self, x_shape, yshape): - return x_shape - - def infer_dtype(self, x_type, y_type): - return x_type - - @bprop_getters.register(BpropWithWrongOutputNum) - def get_bprop_with_wrong_output_num(self): - """Generate bprop for BpropWithWrongOutputNum""" - - def bprop(x, y, out, dout): - return (dout,) - - return bprop - - class BpropWithWrongOutputNumCell(nn.Cell): - def __init__(self): - super(BpropWithWrongOutputNumCell, self).__init__() - - def construct(self, x, y): - return BpropWithWrongOutputNum()(x, y) - - with pytest.raises(TypeError): - C.grad_all(BpropWithWrongOutputNumCell())(1, 2) - -def test_bprop_with_wrong_output_type(): - context.set_context(check_bprop=True) - class BpropWithWrongOutputType(PrimitiveWithInfer): - @prim_attr_register - def __init__(self): - super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType') - - def __call__(self, x): - return x - - def infer_shape(self, x_shape): - return x_shape - - def infer_dtype(self, x_type): - return x_type - - @bprop_getters.register(BpropWithWrongOutputType) - def get_bprop_with_wrong_output_type(self): - """Generate bprop for BpropWithWrongOutputType""" - - def bprop(x, out, dout): - return (1,) - - return bprop - - class BpropWithWrongOutputTypeCell(nn.Cell): - def __init__(self): - super(BpropWithWrongOutputTypeCell, self).__init__() - - def construct(self, x): - return BpropWithWrongOutputType()(x) - - with pytest.raises(TypeError): - C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) - - -def test_bprop_with_wrong_output_shape(): - context.set_context(check_bprop=True) - class BpropWithWrongOutputShape(PrimitiveWithInfer): - @prim_attr_register - def __init__(self): - super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape') - - def __call__(self, x): - return x - - def infer_shape(self, x_shape): - return x_shape - - def infer_dtype(self, x_type): - return x_type - - @bprop_getters.register(BpropWithWrongOutputShape) - def get_bprop_with_wrong_output_shape(self): - """Generate bprop for BpropWithWrongOutputShape""" - ones = Tensor(np.ones([2,]).astype(np.int32)) - - def bprop(x, out, dout): - return (ones,) - - return bprop - - class BpropWithWrongOutputShapeCell(nn.Cell): +def test_dict_const(): + class Net(nn.Cell): def __init__(self): - super(BpropWithWrongOutputShapeCell, self).__init__() - - def construct(self, x): - return BpropWithWrongOutputShape()(x) - - with pytest.raises(TypeError): - C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) + super(Net, self).__init__() + self.res = {'1': 10} + def construct(self): + return self.res + Net()() diff --git a/tests/ut/python/pynative_mode/test_hook.py b/tests/ut/python/pynative_mode/test_hook.py index 07a7a7ad8b..f34a81ab5c 100644 --- a/tests/ut/python/pynative_mode/test_hook.py +++ b/tests/ut/python/pynative_mode/test_hook.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ import numpy as np +import pytest import mindspore.nn as nn import mindspore.ops.operations as P @@ -154,22 +155,47 @@ def test_hook(): print(loss_output.asnumpy().shape) +bprop_debug = False + class MulAdd(nn.Cell): def __init__(self): super(MulAdd, self).__init__() def construct(self, x, y): - return 2 * x + y + return 2 * x * x + y * y def bprop(self, x, y, out, dout): - assert (x == 1) - assert (y == 2) - assert (out == 4) - assert (dout == 1) - return 3 * dout, 2 * y + global bprop_debug + bprop_debug = True + return dout, 2 * y def test_custom_bprop(): mul_add = MulAdd() mul_add.bprop_debug = True - assert C.grad_all(mul_add)(1, 2) == (3, 4) + x = Tensor(np.array([1, 2, 3]).astype(np.int32)) + y = Tensor(np.array([2, 3, 4]).astype(np.int32)) + C.grad_all(mul_add)(x, y) + assert bprop_debug + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x, y): + return 2 * x * x + y * y + +def test_grad_all(): + net = Net() + x = Tensor(np.array([1, 2, 3]).astype(np.int32)) + y = Tensor(np.array([2, 3, 4]).astype(np.int32)) + res = C.grad_all(net)(x, y) + print(res) + +def test_check_input(): + net = Net() + x = np.array([1, 2, 3]) + y = np.array([2, 3, 4]) + with pytest.raises(TypeError): + net(x, y) diff --git a/tests/ut/python/pynative_mode/test_implicit_conversion.py b/tests/ut/python/pynative_mode/test_implicit_conversion.py new file mode 100644 index 0000000000..ecaffd87f2 --- /dev/null +++ b/tests/ut/python/pynative_mode/test_implicit_conversion.py @@ -0,0 +1,204 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test implicit conversion """ +import numpy as np + +from mindspore import Tensor, nn +from mindspore.ops import composite as C + + +def test_float_tensor_and_int_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = 2 + ret_actual = x + y + ret_expect = Tensor(np.array([[2.1, 2.2, 2.3], [2.4, 2.5, 2.6]], dtype=np.float32)) + assert ret_actual.dtype == ret_expect.dtype + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_bool_tensor_and_float_add(): + x = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) + y = 3.3 + ret_actual = x + y + ret_expect = Tensor(np.array([[4.3, 3.3], [3.3, 4.3]], dtype=np.float32)) + assert ret_actual.dtype == ret_expect.dtype + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_bool_tensor_and_int_add(): + x = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) + y = 3 + ret_actual = x + y + ret_expect = Tensor(np.array([[4, 3], [3, 4]], dtype=np.int32)) + assert ret_actual.dtype == ret_expect.dtype + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_bool_and_int_tensor_add(): + x = True + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) + ret_actual = x + y + ret_expect = Tensor(np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)) + assert ret_actual.dtype == ret_expect.dtype + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_float_tensor_and_int_tensor_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) + ret_actual = x + y + ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)) + assert ret_actual.dtype == ret_expect.dtype + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_float_tensor_and_float_tensor_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float16)) + ret_actual = x + y + ret_expect = Tensor(np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)) + assert ret_actual.dtype == ret_expect.dtype + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_int_tensor_and_int_tensor_add(): + x = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8)) + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) + ret_actual = x + y + ret_expect = Tensor(np.array([[2, 4, 6], [8, 10, 12]], dtype=np.int32)) + assert ret_actual.dtype == ret_expect.dtype + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_float_tensor_and_bool_tensors_add(): + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = Tensor(np.array([[True, True, True], [False, False, False]], dtype=np.bool_)) + ret_actual = x + y + ret_expect = Tensor(np.array([[1.1, 1.2, 1.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + assert (ret_actual.asnumpy() == ret_expect.asnumpy()).all() + + +def test_float_tensor_and_bool_tensors_add_grad(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x, y): + return x + y + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, x, y, sens): + + return C.grad_all_with_sens(self.net)(x, y, sens) + + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = Tensor(np.array([[True, True, True], [False, False, False]], dtype=np.bool_)) + sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32)) + net = Net() + grad_net = GradNet(net) + ret = grad_net(x, y, sens) + assert ret[0].dtype == x.dtype + assert ret[1].dtype == y.dtype + assert (ret[0].asnumpy() == sens.asnumpy()).all() + assert (ret[1].asnumpy() == sens.asnumpy().astype(np.bool_)).all() + + +def test_float_tensor_and_int_tensors_sub_grad(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x, y): + return x - y + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, x, y, sens): + + return C.grad_all_with_sens(self.net)(x, y, sens) + + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + y = Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)) + sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32)) + net = Net() + grad_net = GradNet(net) + ret = grad_net(x, y, sens) + print(ret) + assert ret[0].dtype == x.dtype + assert ret[1].dtype == y.dtype + assert (ret[0].asnumpy() == sens.asnumpy()).all() + assert (ret[1].asnumpy() == sens.asnumpy() * -1).all() + + +def test_float16_tensor_and_float32_tensors_sub_grad(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x, y): + return x - y + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, x, y, sens): + + return C.grad_all_with_sens(self.net)(x, y, sens) + + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.int32)) + y = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)) + sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32)) + net = Net() + grad_net = GradNet(net) + ret = grad_net(x, y, sens) + print(ret) + assert ret[0].dtype == x.dtype + assert ret[1].dtype == y.dtype + assert (ret[0].asnumpy() == sens.asnumpy()).all() + assert (ret[1].asnumpy() == sens.asnumpy() * -1).all() + + +def test_float_tensor_and_int_add_grad(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + return x + 2 + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, x, sens): + return C.grad_all_with_sens(self.net)(x, sens) + + x = Tensor(np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)) + sens = Tensor(np.array([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0]], dtype=np.float32)) + net = Net() + grad_net = GradNet(net) + ret = grad_net(x, sens) + assert ret[0].dtype == x.dtype + assert (ret[0].asnumpy() == sens.asnumpy()).all() diff --git a/tests/ut/python/pynative_mode/test_insert_grad_of.py b/tests/ut/python/pynative_mode/test_insert_grad_of.py index 0a28bbbb63..218a4ee253 100644 --- a/tests/ut/python/pynative_mode/test_insert_grad_of.py +++ b/tests/ut/python/pynative_mode/test_insert_grad_of.py @@ -46,6 +46,7 @@ def test_InsertGradientOf_1(): c = x * y return c + @ms_function def f(x, y): return C.grad_all(stop_test)(x, y) @@ -80,6 +81,7 @@ def test_InsertGradientOf_2(): def f(x, y): return clip_test(x, y) + @ms_function def fd(x, y): return C.grad_all(clip_test)(x, y) diff --git a/tests/ut/python/pynative_mode/test_stop_gradient.py b/tests/ut/python/pynative_mode/test_stop_gradient.py index a94f80adf0..09e4f25c54 100644 --- a/tests/ut/python/pynative_mode/test_stop_gradient.py +++ b/tests/ut/python/pynative_mode/test_stop_gradient.py @@ -16,6 +16,7 @@ import numpy as np import pytest +import mindspore as ms import mindspore.common.dtype as mstype import mindspore.nn as nn from mindspore import Parameter, ParameterTuple @@ -81,16 +82,24 @@ def stop_test4(x, y): return e +@ms_function def grad_stop_test(x, y): """ grad_stop_test """ return C.grad_all(stop_test2)(x, y) +@ms_function def grad_stop_test1(x, y): """ grad_stop_test1 """ return C.grad_all(stop_test3)(x, y) +@ms_function +def grad_stop_test5(x, y): + """ grad_stop_test5 """ + return C.grad_all(stop_test5)(x, y) + + def test_stop(): """ test_stop """ print("test_stop:", grad_stop_test(1, 1)) @@ -103,7 +112,7 @@ def test_stop1(): def test_stop5(): """ test_stop1 """ - print("test_stop5:", C.grad_all(stop_test5)(2, 3)) + print("test_stop5:", grad_stop_test5(2, 3)) class GradWrap(nn.Cell): @@ -247,7 +256,7 @@ def test_stop_gradient_4(): def stop_test(x): return stop_gradient(x) - assert C.grad_all(stop_test)(1) == (0,) + assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,) def test_stop_gradient_5(): @@ -257,7 +266,7 @@ def test_stop_gradient_5(): ret = x + y return ret - assert C.grad_all(stop_test)(1) == (1,) + assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,) def test_stop_gradient_6(): @@ -266,7 +275,7 @@ def test_stop_gradient_6(): ret = stop_gradient(ret) return ret - assert C.grad_all(stop_test)(1, 3) == (0, 0) + assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0) class PrimWithMultiOutputs(PrimitiveWithInfer): diff --git a/tests/ut/python/train/quant/mobilenetv2.py b/tests/ut/python/train/quant/mobilenetv2.py deleted file mode 100644 index 163b230e1e..0000000000 --- a/tests/ut/python/train/quant/mobilenetv2.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""MobileNetV2""" -from mindspore import nn -from mindspore.ops import operations as P - - -def make_divisible(input_x, div_by=8): - return int((input_x + div_by) // div_by) - - -def _conv_bn(in_channel, - out_channel, - ksize, - stride=1): - """Get a conv2d batchnorm and relu layer.""" - return nn.SequentialCell( - [nn.Conv2d(in_channel, - out_channel, - kernel_size=ksize, - stride=stride), - nn.BatchNorm2d(out_channel)]) - - -class InvertedResidual(nn.Cell): - def __init__(self, inp, oup, stride, expend_ratio): - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2] - - hidden_dim = int(inp * expend_ratio) - self.use_res_connect = self.stride == 1 and inp == oup - if expend_ratio == 1: - self.conv = nn.SequentialCell([ - nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim), - nn.BatchNorm2d(hidden_dim), - nn.ReLU6(), - nn.Conv2d(hidden_dim, oup, 1, 1), - nn.BatchNorm2d(oup) - ]) - else: - self.conv = nn.SequentialCell([ - nn.Conv2d(inp, hidden_dim, 1, 1), - nn.BatchNorm2d(hidden_dim), - nn.ReLU6(), - - nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim), - nn.BatchNorm2d(hidden_dim), - nn.ReLU6(), - - nn.Conv2d(hidden_dim, oup, 1, 1), - nn.BatchNorm2d(oup) - ]) - - def construct(self, input_x): - out = self.conv(input_x) - if self.use_res_connect: - out = input_x + out - return out - - -class MobileNetV2(nn.Cell): - def __init__(self, num_class=1000, input_size=224, width_mul=1.): - super(MobileNetV2, self).__init__() - _ = input_size - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - inverted_residual_setting = [ - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 230, 1, 1], - ] - if width_mul > 1.0: - last_channel = make_divisible(last_channel * width_mul) - self.last_channel = last_channel - features = [_conv_bn(3, input_channel, 3, 2)] - - for t, c, n, s in inverted_residual_setting: - out_channel = make_divisible(c * width_mul) if t > 1 else c - for i in range(n): - if i == 0: - features.append(block(input_channel, out_channel, s, t)) - else: - features.append(block(input_channel, out_channel, 1, t)) - input_channel = out_channel - - features.append(_conv_bn(input_channel, self.last_channel, 1)) - - self.features = nn.SequentialCell(features) - self.mean = P.ReduceMean(keep_dims=False) - self.classifier = nn.Dense(self.last_channel, num_class) - - def construct(self, input_x): - out = input_x - out = self.features(out) - out = self.mean(out, (2, 3)) - out = self.classifier(out) - return out diff --git a/tests/ut/python/train/quant/mobilenetv2_combined.py b/tests/ut/python/train/quant/mobilenetv2_combined.py deleted file mode 100644 index 51916192d8..0000000000 --- a/tests/ut/python/train/quant/mobilenetv2_combined.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""mobile net v2""" -from mindspore import nn -from mindspore.ops import operations as P - - -def make_divisible(input_x, div_by=8): - return int((input_x + div_by) // div_by) - - -def _conv_bn(in_channel, - out_channel, - ksize, - stride=1): - """Get a conv2d batchnorm and relu layer.""" - return nn.SequentialCell( - [nn.Conv2dBnAct(in_channel, - out_channel, - kernel_size=ksize, - stride=stride, - has_bn=True)]) - - -class InvertedResidual(nn.Cell): - def __init__(self, inp, oup, stride, expend_ratio): - super(InvertedResidual, self).__init__() - self.stride = stride - assert stride in [1, 2] - - hidden_dim = int(inp * expend_ratio) - self.use_res_connect = self.stride == 1 and inp == oup - if expend_ratio == 1: - self.conv = nn.SequentialCell([ - nn.Conv2dBnAct(hidden_dim, - hidden_dim, - 3, - stride, - group=hidden_dim, - has_bn=True, - activation='relu6'), - nn.Conv2dBnAct(hidden_dim, oup, 1, 1, - has_bn=True) - ]) - else: - self.conv = nn.SequentialCell([ - nn.Conv2dBnAct(inp, hidden_dim, 1, 1, - has_bn=True, - activation='relu6'), - nn.Conv2dBnAct(hidden_dim, - hidden_dim, - 3, - stride, - group=hidden_dim, - has_bn=True, - activation='relu6'), - nn.Conv2dBnAct(hidden_dim, oup, 1, 1, - has_bn=True) - ]) - self.add = P.TensorAdd() - - def construct(self, input_x): - out = self.conv(input_x) - if self.use_res_connect: - out = self.add(input_x, out) - return out - - -class MobileNetV2(nn.Cell): - def __init__(self, num_class=1000, input_size=224, width_mul=1.): - super(MobileNetV2, self).__init__() - _ = input_size - block = InvertedResidual - input_channel = 32 - last_channel = 1280 - inverted_residual_setting = [ - [1, 16, 1, 1], - [6, 24, 2, 2], - [6, 32, 3, 2], - [6, 64, 4, 2], - [6, 96, 3, 1], - [6, 160, 3, 2], - [6, 230, 1, 1], - ] - if width_mul > 1.0: - last_channel = make_divisible(last_channel * width_mul) - self.last_channel = last_channel - features = [_conv_bn(3, input_channel, 3, 2)] - - for t, c, n, s in inverted_residual_setting: - out_channel = make_divisible(c * width_mul) if t > 1 else c - for i in range(n): - if i == 0: - features.append(block(input_channel, out_channel, s, t)) - else: - features.append(block(input_channel, out_channel, 1, t)) - input_channel = out_channel - - features.append(_conv_bn(input_channel, self.last_channel, 1)) - - self.features = nn.SequentialCell(features) - self.mean = P.ReduceMean(keep_dims=False) - self.classifier = nn.DenseBnAct(self.last_channel, num_class) - - def construct(self, input_x): - out = input_x - out = self.features(out) - out = self.mean(out, (2, 3)) - out = self.classifier(out) - return out diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py index 1a21bc2c02..39e887170c 100644 --- a/tests/ut/python/train/quant/test_quant.py +++ b/tests/ut/python/train/quant/test_quant.py @@ -20,7 +20,7 @@ import mindspore.context as context from mindspore import Tensor from mindspore import nn from mindspore.train.quant import quant as qat -from mobilenetv2_combined import MobileNetV2 +from model_zoo.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") @@ -42,7 +42,7 @@ class LeNet5(nn.Cell): def __init__(self, num_class=10): super(LeNet5, self).__init__() self.num_class = num_class - self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid") + self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu', pad_mode="valid") self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid") self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu') @@ -67,20 +67,19 @@ def test_qat_lenet(): img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) net = LeNet5() net = qat.convert_quant_network( - net, freeze_bn=10000, num_bits=8) + net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) # should load the checkpoint. mock here for param in net.get_parameters(): param.init_data() - qat.export_geir(net, img, file_name="quant.pb") + qat.export(net, img, file_name="quant.pb") @pytest.mark.skip(reason="no `te.lang.cce` in ut env") def test_qat_mobile(): - net = MobileNetV2() + network = mobilenetV2(num_classes=1000) img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) - net = qat.convert_quant_network( - net, quant_delay=0, bn_fold=True, freeze_bn=10000, num_bits=8) + network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) # should load the checkpoint. mock here - for param in net.get_parameters(): + for param in network.get_parameters(): param.init_data() - qat.export_geir(net, img, file_name="quant.pb") + qat.export(network, img, file_name="quant.pb") diff --git a/tests/ut/python/train/test_amp.py b/tests/ut/python/train/test_amp.py index c7befb6c2b..6bb4ec5464 100644 --- a/tests/ut/python/train/test_amp.py +++ b/tests/ut/python/train/test_amp.py @@ -22,10 +22,10 @@ from mindspore import amp from mindspore import nn from mindspore.train import Model, ParallelMode from mindspore.common import dtype as mstype -from mindspore.model_zoo.resnet import resnet50 from ....dataset_mock import MindData from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.communication.management import init +from tests.ut.python.model.resnet import resnet50 def setup_module(module): _ = module diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index 035ea87845..7f85695a19 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -34,7 +34,7 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load _exec_save_checkpoint, export, _save_graph from ..ut_filter import non_graph_engine -context.set_context(mode=context.GRAPH_MODE, print_file_path="print.pb") +context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb") class Net(nn.Cell): @@ -374,10 +374,13 @@ def test_print(): def teardown_module(): - files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt', 'print.pb'] + files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt'] for item in files: file_name = './' + item if not os.path.exists(file_name): continue os.chmod(file_name, stat.S_IWRITE) os.remove(file_name) + import shutil + if os.path.exists('./print'): + shutil.rmtree('./print') diff --git a/tests/vm_impl/vm_me.py b/tests/vm_impl/vm_me.py index 89cc1569a9..7216ec613b 100644 --- a/tests/vm_impl/vm_me.py +++ b/tests/vm_impl/vm_me.py @@ -441,7 +441,7 @@ def max_pool_grad(x, dout, pool_h, pool_w, stride): """Grad of max pooling.""" dout = dout.transpose(0, 2, 3, 1) pool_size = pool_h * pool_w - dmax = np.zeros((dout.size, pool_size)) + dmax = np.zeros((dout.size, pool_size), dout.dtype) col = im2col(x, pool_h, pool_w, stride) col = col.reshape(-1, pool_h * pool_w) arg_max = np.argmax(col, axis=1) @@ -456,7 +456,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride): """Grad of max pooling with argmax.""" dout = dout.transpose(0, 2, 3, 1) pool_size = pool_h * pool_w - dmax = np.zeros((dout.size, pool_size)) + dmax = np.zeros((dout.size, pool_size), dout.dtype) dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten() dmax = dmax.reshape(dout.shape + (pool_size,)) dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1) diff --git a/third_party/icu4c/filter.json b/third_party/icu4c/filter.json deleted file mode 100644 index b3decad8fb..0000000000 --- a/third_party/icu4c/filter.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "strategy": "additive", - "featureFilters": { - "normalization": "include" - } -} \ No newline at end of file diff --git a/third_party/patch/pslite/ps_lite.patch001 b/third_party/patch/pslite/ps_lite.patch001 index bdc7b11a4b..e2e51e93c8 100644 --- a/third_party/patch/pslite/ps_lite.patch001 +++ b/third_party/patch/pslite/ps_lite.patch001 @@ -12,16 +12,7 @@ diff -Npur ps-lite-master/include/dmlc/base.h ps-lite-master-new/include/dmlc/ba /*! diff -Npur ps-lite-master/include/dmlc/logging.h ps-lite-master-new/include/dmlc/logging.h --- ps-lite-master/include/dmlc/logging.h 2020-02-29 13:59:55.000000000 +0800 -+++ ps-lite-master-new/include/dmlc/logging.h 2020-07-01 11:58:00.015919207 +0800 -@@ -13,7 +13,7 @@ - #include - #include - #include --#include "./base.h" -+//#include "./base.h" - - #if DMLC_LOG_STACK_TRACE - #include ++++ ps-lite-master-new/include/dmlc/logging.h 2020-07-08 21:35:33.334584767 +0800 @@ -52,7 +52,7 @@ struct Error : public std::runtime_error namespace dmlc {